Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/**
* @vitest-environment node
*/
import { beforeEach, describe, expect, it, vi } from 'vitest'

vi.mock('@/lib/auth/auth-client', () => ({
client: { oauth2: { link: vi.fn() } },
useSession: vi.fn(() => ({ data: null, isPending: false, error: null })),
}))

vi.mock('@/lib/credentials/client-state', () => ({
writeOAuthReturnContext: vi.fn(),
}))

vi.mock('@/hooks/queries/credentials', () => ({
useCreateCredentialDraft: vi.fn(() => ({ mutateAsync: vi.fn(), isPending: false })),
}))

vi.mock('@/lib/oauth', () => ({
getCanonicalScopesForProvider: vi.fn(() => []),
getProviderIdFromServiceId: vi.fn((id: string) => id),
OAUTH_PROVIDERS: {},
parseProvider: vi.fn((p: string) => ({ baseProvider: p, variant: null })),
}))

vi.mock('@/lib/oauth/utils', () => ({
getScopeDescription: vi.fn((s: string) => s),
}))

import { getDefaultCredentialName } from '@/app/workspace/[workspaceId]/components/oauth-modal'

describe('getDefaultCredentialName', () => {
beforeEach(() => {
vi.clearAllMocks()
})

it('uses the user name when available', () => {
expect(getDefaultCredentialName('Waleed', 'Google Drive', 0)).toBe("Waleed's Google Drive 1")
})

it('increments the number based on existing credential count', () => {
expect(getDefaultCredentialName('Waleed', 'Google Drive', 2)).toBe("Waleed's Google Drive 3")
})

it('falls back to "My" when user name is null', () => {
expect(getDefaultCredentialName(null, 'Slack', 0)).toBe('My Slack 1')
})

it('falls back to "My" when user name is undefined', () => {
expect(getDefaultCredentialName(undefined, 'Gmail', 1)).toBe('My Gmail 2')
})

it('falls back to "My" when user name is empty string', () => {
expect(getDefaultCredentialName('', 'GitHub', 0)).toBe('My GitHub 1')
})

it('falls back to "My" when user name is whitespace-only', () => {
expect(getDefaultCredentialName(' ', 'Notion', 0)).toBe('My Notion 1')
})

it('trims whitespace from user name', () => {
expect(getDefaultCredentialName(' Waleed ', 'Linear', 0)).toBe("Waleed's Linear 1")
})

it('works with zero existing credentials', () => {
expect(getDefaultCredentialName('Alice', 'Jira', 0)).toBe("Alice's Jira 1")
})

it('works with many existing credentials', () => {
expect(getDefaultCredentialName('Bob', 'Slack', 9)).toBe("Bob's Slack 10")
})
})
304 changes: 304 additions & 0 deletions apps/sim/app/workspace/[workspaceId]/components/oauth-modal.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
'use client'

import { useMemo, useState } from 'react'
import { createLogger } from '@sim/logger'
import { Check } from 'lucide-react'
import {
Badge,
Button,
Input,
Label,
Modal,
ModalBody,
ModalContent,
ModalFooter,
ModalHeader,
} from '@/components/emcn'
import { client, useSession } from '@/lib/auth/auth-client'
import type { OAuthReturnContext } from '@/lib/credentials/client-state'
import { writeOAuthReturnContext } from '@/lib/credentials/client-state'
import {
getCanonicalScopesForProvider,
getProviderIdFromServiceId,
OAUTH_PROVIDERS,
type OAuthProvider,
parseProvider,
} from '@/lib/oauth'
import { getScopeDescription } from '@/lib/oauth/utils'
import { useCreateCredentialDraft } from '@/hooks/queries/credentials'

const logger = createLogger('OAuthModal')
const EMPTY_SCOPES: string[] = []

/**
* Generates a default credential display name.
* Format: "{User}'s {Provider} {N}" or "My {Provider} {N}" when no user name is available.
*/
export function getDefaultCredentialName(
userName: string | null | undefined,
providerName: string,
credentialCount: number
): string {
const trimmed = userName?.trim()
const num = credentialCount + 1
if (trimmed) {
return `${trimmed}'s ${providerName} ${num}`
}
return `My ${providerName} ${num}`
}

interface OAuthModalBaseProps {
isOpen: boolean
onClose: () => void
provider: OAuthProvider
serviceId: string
}

type OAuthModalConnectProps = OAuthModalBaseProps & {
mode: 'connect'
workspaceId: string
credentialCount: number
} & (
| { workflowId: string; knowledgeBaseId?: never }
| { workflowId?: never; knowledgeBaseId: string }
)

interface OAuthModalReauthorizeProps extends OAuthModalBaseProps {
mode: 'reauthorize'
toolName: string
requiredScopes?: string[]
newScopes?: string[]
onConnect?: () => Promise<void> | void
}

export type OAuthModalProps = OAuthModalConnectProps | OAuthModalReauthorizeProps

export function OAuthModal(props: OAuthModalProps) {
const { isOpen, onClose, provider, serviceId, mode } = props

const isConnect = mode === 'connect'
const credentialCount = isConnect ? props.credentialCount : 0
const workspaceId = isConnect ? props.workspaceId : ''
const workflowId = isConnect ? props.workflowId : undefined
const knowledgeBaseId = isConnect ? props.knowledgeBaseId : undefined
const toolName = !isConnect ? props.toolName : ''
const requiredScopes = !isConnect ? (props.requiredScopes ?? EMPTY_SCOPES) : EMPTY_SCOPES
const newScopes = !isConnect ? (props.newScopes ?? EMPTY_SCOPES) : EMPTY_SCOPES
const onConnectOverride = !isConnect ? props.onConnect : undefined

const { data: session } = useSession()
const [error, setError] = useState<string | null>(null)
const createDraft = useCreateCredentialDraft()

const { providerName, ProviderIcon } = useMemo(() => {
const { baseProvider } = parseProvider(provider)
const baseProviderConfig = OAUTH_PROVIDERS[baseProvider]
let name = baseProviderConfig?.name || provider
let Icon = baseProviderConfig?.icon || (() => null)
if (baseProviderConfig) {
for (const [key, service] of Object.entries(baseProviderConfig.services)) {
if (key === serviceId || service.providerId === provider) {
name = service.name
Icon = service.icon
break
}
}
}
return { providerName: name, ProviderIcon: Icon }
}, [provider, serviceId])

const providerId = getProviderIdFromServiceId(serviceId)

const [displayName, setDisplayName] = useState(() =>
isConnect ? getDefaultCredentialName(session?.user?.name, providerName, credentialCount) : ''
)

const newScopesSet = useMemo(
() =>
new Set(
newScopes.filter(
(scope) => !scope.includes('userinfo.email') && !scope.includes('userinfo.profile')
)
),
[newScopes]
)

const displayScopes = useMemo(() => {
if (isConnect) {
return getCanonicalScopesForProvider(providerId).filter(
(scope) => !scope.includes('userinfo.email') && !scope.includes('userinfo.profile')
)
}
const filtered = requiredScopes.filter(
(scope) => !scope.includes('userinfo.email') && !scope.includes('userinfo.profile')
)
return filtered.sort((a, b) => {
const aIsNew = newScopesSet.has(a)
const bIsNew = newScopesSet.has(b)
if (aIsNew && !bIsNew) return -1
if (!aIsNew && bIsNew) return 1
return 0
})
}, [isConnect, providerId, requiredScopes, newScopesSet])

const handleClose = () => {
setError(null)
onClose()
}

const handleConnect = async () => {
setError(null)

try {
if (isConnect) {
const trimmedName = displayName.trim()
if (!trimmedName) {
setError('Display name is required.')
return
}

await createDraft.mutateAsync({
workspaceId,
providerId,
displayName: trimmedName,
})

const baseContext = {
displayName: trimmedName,
providerId,
preCount: credentialCount,
workspaceId,
requestedAt: Date.now(),
}

const returnContext: OAuthReturnContext = knowledgeBaseId
? { ...baseContext, origin: 'kb-connectors' as const, knowledgeBaseId }
: { ...baseContext, origin: 'workflow' as const, workflowId: workflowId! }

writeOAuthReturnContext(returnContext)
}

if (!isConnect && onConnectOverride) {
await onConnectOverride()
onClose()
return
}

if (!isConnect) {
logger.info('Linking OAuth2:', {
providerId,
requiredScopes,
hasNewScopes: newScopes.length > 0,
})
}

if (providerId === 'trello') {
if (!isConnect) onClose()
window.location.href = '/api/auth/trello/authorize'
return
}

if (providerId === 'shopify') {
if (!isConnect) onClose()
const returnUrl = encodeURIComponent(window.location.href)
window.location.href = `/api/auth/shopify/authorize?returnUrl=${returnUrl}`
return
}

await client.oauth2.link({ providerId, callbackURL: window.location.href })
handleClose()
} catch (err) {
logger.error('Failed to initiate OAuth connection', { error: err })
setError('Failed to connect. Please try again.')
}
}

const isPending = isConnect && createDraft.isPending
const isConnectDisabled = isConnect ? !displayName.trim() || Boolean(isPending) : false

const subtitle = isConnect
? `Grant access to use ${providerName} in your ${knowledgeBaseId ? 'knowledge base' : 'workflow'}`
: `The "${toolName}" tool requires access to your account`

return (
<Modal open={isOpen} onOpenChange={(open) => !open && handleClose()}>
<ModalContent size='md'>
<ModalHeader>Connect {providerName}</ModalHeader>
<ModalBody>
<div className='flex flex-col gap-4'>
<div className='flex items-center gap-3'>
<div className='flex h-[40px] w-[40px] flex-shrink-0 items-center justify-center rounded-lg bg-[var(--surface-5)]'>
<ProviderIcon className='h-[18px] w-[18px]' />
</div>
<div className='flex-1'>
<p className='font-medium text-[var(--text-primary)] text-small'>
Connect your {providerName} account
</p>
<p className='text-[var(--text-tertiary)] text-caption'>{subtitle}</p>
</div>
</div>

{displayScopes.length > 0 && (
<div className='rounded-lg border border-[var(--border-1)] bg-[var(--surface-5)]'>
<div className='border-[var(--border-1)] border-b px-3.5 py-2.5'>
<h4 className='font-medium text-[var(--text-primary)] text-caption'>
Permissions requested
</h4>
</div>
<ul className='max-h-[200px] space-y-2.5 overflow-y-auto px-3.5 py-3'>
{displayScopes.map((scope) => (
<li key={scope} className='flex items-start gap-2.5'>
<div className='mt-0.5 flex h-[16px] w-[16px] flex-shrink-0 items-center justify-center'>
<Check className='h-[10px] w-[10px] text-[var(--text-primary)]' />
</div>
<div className='flex flex-1 items-center gap-2 text-[var(--text-primary)] text-caption'>
<span>{getScopeDescription(scope)}</span>
{!isConnect && newScopesSet.has(scope) && (
<Badge variant='amber' size='sm'>
New
</Badge>
)}
</div>
</li>
))}
</ul>
</div>
)}

{isConnect && (
<div>
<Label>
Display name <span className='text-[var(--text-muted)]'>*</span>
</Label>
<Input
value={displayName}
onChange={(e) => {
setDisplayName(e.target.value)
setError(null)
}}
onKeyDown={(e) => {
if (e.key === 'Enter' && !isPending) void handleConnect()
}}
placeholder={`My ${providerName} account`}
autoComplete='off'
data-lpignore='true'
className='mt-1.5'
/>
</div>
)}

{error && <p className='text-[var(--text-error)] text-caption'>{error}</p>}
</div>
</ModalBody>
<ModalFooter>
<Button variant='default' onClick={handleClose} disabled={Boolean(isPending)}>
Cancel
</Button>
<Button variant='primary' onClick={handleConnect} disabled={isConnectDisabled}>
{isPending ? 'Connecting...' : 'Connect'}
</Button>
</ModalFooter>
</ModalContent>
</Modal>
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import {
} from '@/components/emcn'
import { consumeOAuthReturnContext } from '@/lib/credentials/client-state'
import { getProviderIdFromServiceId, type OAuthProvider } from '@/lib/oauth'
import { OAuthModal } from '@/app/workspace/[workspaceId]/components/oauth-modal'
import { ConnectorSelectorField } from '@/app/workspace/[workspaceId]/knowledge/[id]/components/add-connector-modal/components/connector-selector-field'
import { ConnectCredentialModal } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/panel/components/editor/components/sub-block/components/credential-selector/components/connect-credential-modal'
import { getDependsOnFields } from '@/blocks/utils'
import { CONNECTOR_REGISTRY } from '@/connectors/registry'
import type { ConnectorConfig, ConnectorConfigField } from '@/connectors/types'
Expand Down Expand Up @@ -553,7 +553,8 @@ export function AddConnectorModal({ open, onOpenChange, knowledgeBaseId }: AddCo
connectorConfig &&
connectorConfig.auth.mode === 'oauth' &&
connectorProviderId && (
<ConnectCredentialModal
<OAuthModal
mode='connect'
isOpen={showOAuthModal}
onClose={() => {
consumeOAuthReturnContext()
Expand Down
Loading
Loading