A vibe-coded fork of Tangled that supports Pijul.
1package oauth
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "net/http"
9 "net/url"
10 "strings"
11 "sync"
12 "time"
13
14 comatproto "github.com/bluesky-social/indigo/api/atproto"
15 "github.com/bluesky-social/indigo/atproto/auth/oauth"
16 atpclient "github.com/bluesky-social/indigo/atproto/client"
17 atcrypto "github.com/bluesky-social/indigo/atproto/crypto"
18 "github.com/bluesky-social/indigo/atproto/syntax"
19 xrpc "github.com/bluesky-social/indigo/xrpc"
20 "github.com/gorilla/sessions"
21 "github.com/posthog/posthog-go"
22 "tangled.org/core/appview/config"
23 "tangled.org/core/appview/db"
24 "tangled.org/core/idresolver"
25 "tangled.org/core/rbac"
26)
27
28type OAuth struct {
29 ClientApp *oauth.ClientApp
30 SessStore *sessions.CookieStore
31 Config *config.Config
32 JwksUri string
33 ClientName string
34 ClientUri string
35 Posthog posthog.Client
36 Db *db.DB
37 Enforcer *rbac.Enforcer
38 IdResolver *idresolver.Resolver
39 Logger *slog.Logger
40
41 appPasswordSession *AppPasswordSession
42 appPasswordSessionMu sync.Mutex
43}
44
45func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) {
46 var oauthConfig oauth.ClientConfig
47 var clientUri string
48 if config.Core.Dev {
49 clientUri = "http://127.0.0.1:3000"
50 callbackUri := clientUri + "/oauth/callback"
51 oauthConfig = oauth.NewLocalhostConfig(callbackUri, TangledScopes)
52 } else {
53 clientUri = "https://" + config.Core.AppviewHost
54 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri)
55 callbackUri := clientUri + "/oauth/callback"
56 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, TangledScopes)
57 }
58
59 // configure client secret
60 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret)
61 if err != nil {
62 return nil, err
63 }
64 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil {
65 return nil, err
66 }
67
68 jwksUri := clientUri + "/oauth/jwks.json"
69
70 authStore, err := NewRedisStore(&RedisStoreConfig{
71 RedisURL: config.Redis.ToURL(),
72 SessionExpiryDuration: time.Hour * 24 * 90,
73 SessionInactivityDuration: time.Hour * 24 * 14,
74 AuthRequestExpiryDuration: time.Minute * 30,
75 })
76 if err != nil {
77 return nil, err
78 }
79
80 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret))
81
82 clientApp := oauth.NewClientApp(&oauthConfig, authStore)
83 clientApp.Dir = res.Directory()
84 // allow non-public transports in dev mode
85 if config.Core.Dev {
86 clientApp.Resolver.Client.Transport = http.DefaultTransport
87 }
88
89 clientName := config.Core.AppviewName
90
91 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential())
92 return &OAuth{
93 ClientApp: clientApp,
94 Config: config,
95 SessStore: sessStore,
96 JwksUri: jwksUri,
97 ClientName: clientName,
98 ClientUri: clientUri,
99 Posthog: ph,
100 Db: db,
101 Enforcer: enforcer,
102 IdResolver: res,
103 Logger: logger,
104 }, nil
105}
106
107func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
108 userSession, err := o.SessStore.Get(r, SessionName)
109 if err != nil {
110 o.Logger.Warn("failed to decode existing session cookie, will create new", "err", err)
111 }
112
113 userSession.Values[SessionDid] = sessData.AccountDID.String()
114 userSession.Values[SessionPds] = sessData.HostURL
115 userSession.Values[SessionId] = sessData.SessionID
116 userSession.Values[SessionAuthenticated] = true
117
118 if err := userSession.Save(r, w); err != nil {
119 return err
120 }
121
122 handle := ""
123 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String())
124 if err == nil && resolved.Handle.String() != "" {
125 handle = resolved.Handle.String()
126 }
127
128 registry := o.GetAccounts(r)
129 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil {
130 return err
131 }
132 return o.SaveAccounts(w, r, registry)
133}
134
135func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
136 userSession, err := o.SessStore.Get(r, SessionName)
137 if err != nil {
138 return nil, fmt.Errorf("error getting user session: %w", err)
139 }
140 if userSession.IsNew {
141 return nil, fmt.Errorf("no session available for user")
142 }
143
144 d := userSession.Values[SessionDid].(string)
145 sessDid, err := syntax.ParseDID(d)
146 if err != nil {
147 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
148 }
149
150 sessId := userSession.Values[SessionId].(string)
151
152 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId)
153 if err != nil {
154 return nil, fmt.Errorf("failed to resume session: %w", err)
155 }
156
157 return clientSess, nil
158}
159
160func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error {
161 userSession, err := o.SessStore.Get(r, SessionName)
162 if err != nil {
163 return fmt.Errorf("error getting user session: %w", err)
164 }
165 if userSession.IsNew {
166 return fmt.Errorf("no session available for user")
167 }
168
169 d := userSession.Values[SessionDid].(string)
170 sessDid, err := syntax.ParseDID(d)
171 if err != nil {
172 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
173 }
174
175 sessId := userSession.Values[SessionId].(string)
176
177 // delete the session
178 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId)
179 if err1 != nil {
180 err1 = fmt.Errorf("failed to logout: %w", err1)
181 }
182
183 // remove the cookie
184 userSession.Options.MaxAge = -1
185 err2 := o.SessStore.Save(r, w, userSession)
186 if err2 != nil {
187 err2 = fmt.Errorf("failed to save into session store: %w", err2)
188 }
189
190 return errors.Join(err1, err2)
191}
192
193func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
194 registry := o.GetAccounts(r)
195 account := registry.FindAccount(targetDid)
196 if account == nil {
197 return fmt.Errorf("account not found in registry: %s", targetDid)
198 }
199
200 did, err := syntax.ParseDID(targetDid)
201 if err != nil {
202 return fmt.Errorf("invalid DID: %w", err)
203 }
204
205 sess, err := o.ClientApp.ResumeSession(r.Context(), did, account.SessionId)
206 if err != nil {
207 registry.RemoveAccount(targetDid)
208 _ = o.SaveAccounts(w, r, registry)
209 return fmt.Errorf("session expired for account: %w", err)
210 }
211
212 userSession, err := o.SessStore.Get(r, SessionName)
213 if err != nil {
214 return err
215 }
216
217 userSession.Values[SessionDid] = sess.Data.AccountDID.String()
218 userSession.Values[SessionPds] = sess.Data.HostURL
219 userSession.Values[SessionId] = sess.Data.SessionID
220 userSession.Values[SessionAuthenticated] = true
221
222 return userSession.Save(r, w)
223}
224
225func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
226 registry := o.GetAccounts(r)
227 account := registry.FindAccount(targetDid)
228 if account == nil {
229 return nil
230 }
231
232 did, err := syntax.ParseDID(targetDid)
233 if err == nil {
234 _ = o.ClientApp.Logout(r.Context(), did, account.SessionId)
235 }
236
237 registry.RemoveAccount(targetDid)
238 return o.SaveAccounts(w, r, registry)
239}
240
// User identifies a signed-in account by its DID and its PDS host URL.
type User struct {
	Did, Pds string
}
245
246func (o *OAuth) GetUser(r *http.Request) *User {
247 sess, err := o.ResumeSession(r)
248 if err != nil {
249 return nil
250 }
251
252 return &User{
253 Did: sess.Data.AccountDID.String(),
254 Pds: sess.Data.HostURL,
255 }
256}
257
258func (o *OAuth) GetDid(r *http.Request) string {
259 if u := o.GetMultiAccountUser(r); u != nil {
260 return u.Did()
261 }
262
263 return ""
264}
265
266func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) {
267 session, err := o.ResumeSession(r)
268 if err != nil {
269 return nil, fmt.Errorf("error getting session: %w", err)
270 }
271 return session.APIClient(), nil
272}
273
type (
	// ServiceClientOpts configures ServiceClient, a higher-level wrapper
	// around com.atproto.server.getServiceAuth (ServerGetServiceAuth).
	ServiceClientOpts struct {
		service string        // target service host, possibly with ":port"
		exp     int64         // absolute token expiry, Unix seconds
		lxm     string        // lexicon method the token is scoped to
		dev     bool          // use http:// instead of https://
		timeout time.Duration // HTTP timeout of the returned client
	}

	// ServiceClientOpt mutates a ServiceClientOpts; see the With* helpers.
	ServiceClientOpt func(*ServiceClientOpts)
)
284
285func DefaultServiceClientOpts() ServiceClientOpts {
286 return ServiceClientOpts{
287 timeout: time.Second * 5,
288 }
289}
290
291func WithService(service string) ServiceClientOpt {
292 return func(s *ServiceClientOpts) {
293 s.service = service
294 }
295}
296
297// Specify the Duration in seconds for the expiry of this token
298//
299// The time of expiry is calculated as time.Now().Unix() + exp
300func WithExp(exp int64) ServiceClientOpt {
301 return func(s *ServiceClientOpts) {
302 s.exp = time.Now().Unix() + exp
303 }
304}
305
306func WithLxm(lxm string) ServiceClientOpt {
307 return func(s *ServiceClientOpts) {
308 s.lxm = lxm
309 }
310}
311
312func WithDev(dev bool) ServiceClientOpt {
313 return func(s *ServiceClientOpts) {
314 s.dev = dev
315 }
316}
317
318func WithTimeout(timeout time.Duration) ServiceClientOpt {
319 return func(s *ServiceClientOpts) {
320 s.timeout = timeout
321 }
322}
323
324func (s *ServiceClientOpts) Audience() string {
325 // did:web spec requires colons to be encoded as %3A
326 encoded := strings.ReplaceAll(s.service, ":", "%3A")
327 return fmt.Sprintf("did:web:%s", encoded)
328}
329
330func (s *ServiceClientOpts) Host() string {
331 scheme := "https://"
332 if s.dev {
333 scheme = "http://"
334 }
335
336 return scheme + s.service
337}
338
339func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) {
340 opts := DefaultServiceClientOpts()
341 for _, o := range os {
342 o(&opts)
343 }
344
345 client, err := o.AuthorizedClient(r)
346 if err != nil {
347 return nil, err
348 }
349
350 // force expiry to atleast 60 seconds in the future
351 sixty := time.Now().Unix() + 60
352 if opts.exp < sixty {
353 opts.exp = sixty
354 }
355
356 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm)
357 if err != nil {
358 return nil, err
359 }
360
361 return &xrpc.Client{
362 Auth: &xrpc.AuthInfo{
363 AccessJwt: resp.Token,
364 },
365 Host: opts.Host(),
366 Client: &http.Client{
367 Timeout: opts.timeout,
368 },
369 }, nil
370}
371
372func (o *OAuth) StartElevatedAuthFlow(ctx context.Context, w http.ResponseWriter, r *http.Request, did string, extraScopes []string, returnURL string) (string, error) {
373 parsedDid, err := syntax.ParseDID(did)
374 if err != nil {
375 return "", fmt.Errorf("invalid DID: %w", err)
376 }
377
378 ident, err := o.ClientApp.Dir.Lookup(ctx, parsedDid.AtIdentifier())
379 if err != nil {
380 return "", fmt.Errorf("failed to resolve DID (%s): %w", did, err)
381 }
382
383 host := ident.PDSEndpoint()
384 if host == "" {
385 return "", fmt.Errorf("identity does not link to an atproto host (PDS)")
386 }
387
388 authserverURL, err := o.ClientApp.Resolver.ResolveAuthServerURL(ctx, host)
389 if err != nil {
390 return "", fmt.Errorf("resolving auth server: %w", err)
391 }
392
393 authserverMeta, err := o.ClientApp.Resolver.ResolveAuthServerMetadata(ctx, authserverURL)
394 if err != nil {
395 return "", fmt.Errorf("fetching auth server metadata: %w", err)
396 }
397
398 scopes := make([]string, 0, len(TangledScopes)+len(extraScopes))
399 scopes = append(scopes, TangledScopes...)
400 scopes = append(scopes, extraScopes...)
401
402 loginHint := did
403 if ident.Handle != "" && !ident.Handle.IsInvalidHandle() {
404 loginHint = ident.Handle.String()
405 }
406
407 info, err := o.ClientApp.SendAuthRequest(ctx, authserverMeta, scopes, loginHint)
408 if err != nil {
409 return "", fmt.Errorf("auth request failed: %w", err)
410 }
411
412 info.AccountDID = &parsedDid
413 o.ClientApp.Store.SaveAuthRequestInfo(ctx, *info)
414
415 if err := o.SetAuthReturn(w, r, returnURL, false); err != nil {
416 return "", fmt.Errorf("failed to set auth return: %w", err)
417 }
418
419 redirectURL := fmt.Sprintf("%s?client_id=%s&request_uri=%s",
420 authserverMeta.AuthorizationEndpoint,
421 url.QueryEscape(o.ClientApp.Config.ClientID),
422 url.QueryEscape(info.RequestURI),
423 )
424
425 return redirectURL, nil
426}