A vibe-coded fork of Tangled that supports Pijul.
1package oauth
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "net/http"
9 "net/url"
10 "sync"
11 "time"
12
13 comatproto "github.com/bluesky-social/indigo/api/atproto"
14 "github.com/bluesky-social/indigo/atproto/auth/oauth"
15 atpclient "github.com/bluesky-social/indigo/atproto/client"
16 atcrypto "github.com/bluesky-social/indigo/atproto/crypto"
17 "github.com/bluesky-social/indigo/atproto/syntax"
18 xrpc "github.com/bluesky-social/indigo/xrpc"
19 "github.com/gorilla/sessions"
20 "github.com/posthog/posthog-go"
21 "tangled.org/core/appview/config"
22 "tangled.org/core/appview/db"
23 "tangled.org/core/idresolver"
24 "tangled.org/core/rbac"
25)
26
// OAuth bundles the atproto OAuth client together with the cookie-based
// session store and the collaborators the appview needs while handling
// authentication.
type OAuth struct {
	// ClientApp is the atproto OAuth client used to resume sessions, log out
	// and send authorization requests.
	ClientApp *oauth.ClientApp
	// SessStore holds the browser session cookie (looked up by SessionName).
	SessStore *sessions.CookieStore
	// Config is the appview configuration this instance was built from.
	Config *config.Config
	// JwksUri is ClientUri + "/oauth/jwks.json".
	JwksUri string
	// ClientName is taken from Config.Core.AppviewName.
	ClientName string
	// ClientUri is the base URL of this client: http://127.0.0.1:3000 in dev
	// mode, https://<AppviewHost> otherwise.
	ClientUri  string
	Posthog    posthog.Client
	Db         *db.DB
	Enforcer   *rbac.Enforcer
	// IdResolver resolves DIDs to identities (used to look up handles).
	IdResolver *idresolver.Resolver
	Logger     *slog.Logger

	// NOTE(review): neither field is used in this part of the file; the mutex
	// presumably guards appPasswordSession — confirm against its users.
	appPasswordSession   *AppPasswordSession
	appPasswordSessionMu sync.Mutex
}
43
44func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) {
45 var oauthConfig oauth.ClientConfig
46 var clientUri string
47 if config.Core.Dev {
48 clientUri = "http://127.0.0.1:3000"
49 callbackUri := clientUri + "/oauth/callback"
50 oauthConfig = oauth.NewLocalhostConfig(callbackUri, TangledScopes)
51 } else {
52 clientUri = "https://" + config.Core.AppviewHost
53 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri)
54 callbackUri := clientUri + "/oauth/callback"
55 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, TangledScopes)
56 }
57
58 // configure client secret
59 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret)
60 if err != nil {
61 return nil, err
62 }
63 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil {
64 return nil, err
65 }
66
67 jwksUri := clientUri + "/oauth/jwks.json"
68
69 authStore, err := NewRedisStore(&RedisStoreConfig{
70 RedisURL: config.Redis.ToURL(),
71 SessionExpiryDuration: time.Hour * 24 * 90,
72 SessionInactivityDuration: time.Hour * 24 * 14,
73 AuthRequestExpiryDuration: time.Minute * 30,
74 })
75 if err != nil {
76 return nil, err
77 }
78
79 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret))
80
81 clientApp := oauth.NewClientApp(&oauthConfig, authStore)
82 clientApp.Dir = res.Directory()
83 // allow non-public transports in dev mode
84 if config.Core.Dev {
85 clientApp.Resolver.Client.Transport = http.DefaultTransport
86 }
87
88 clientName := config.Core.AppviewName
89
90 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential())
91 return &OAuth{
92 ClientApp: clientApp,
93 Config: config,
94 SessStore: sessStore,
95 JwksUri: jwksUri,
96 ClientName: clientName,
97 ClientUri: clientUri,
98 Posthog: ph,
99 Db: db,
100 Enforcer: enforcer,
101 IdResolver: res,
102 Logger: logger,
103 }, nil
104}
105
106func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
107 userSession, err := o.SessStore.Get(r, SessionName)
108 if err != nil {
109 o.Logger.Warn("failed to decode existing session cookie, will create new", "err", err)
110 }
111
112 userSession.Values[SessionDid] = sessData.AccountDID.String()
113 userSession.Values[SessionPds] = sessData.HostURL
114 userSession.Values[SessionId] = sessData.SessionID
115 userSession.Values[SessionAuthenticated] = true
116
117 if err := userSession.Save(r, w); err != nil {
118 return err
119 }
120
121 handle := ""
122 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String())
123 if err == nil && resolved.Handle.String() != "" {
124 handle = resolved.Handle.String()
125 }
126
127 registry := o.GetAccounts(r)
128 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil {
129 return err
130 }
131 return o.SaveAccounts(w, r, registry)
132}
133
134func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
135 userSession, err := o.SessStore.Get(r, SessionName)
136 if err != nil {
137 return nil, fmt.Errorf("error getting user session: %w", err)
138 }
139 if userSession.IsNew {
140 return nil, fmt.Errorf("no session available for user")
141 }
142
143 d := userSession.Values[SessionDid].(string)
144 sessDid, err := syntax.ParseDID(d)
145 if err != nil {
146 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
147 }
148
149 sessId := userSession.Values[SessionId].(string)
150
151 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId)
152 if err != nil {
153 return nil, fmt.Errorf("failed to resume session: %w", err)
154 }
155
156 return clientSess, nil
157}
158
159func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error {
160 userSession, err := o.SessStore.Get(r, SessionName)
161 if err != nil {
162 return fmt.Errorf("error getting user session: %w", err)
163 }
164 if userSession.IsNew {
165 return fmt.Errorf("no session available for user")
166 }
167
168 d := userSession.Values[SessionDid].(string)
169 sessDid, err := syntax.ParseDID(d)
170 if err != nil {
171 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
172 }
173
174 sessId := userSession.Values[SessionId].(string)
175
176 // delete the session
177 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId)
178 if err1 != nil {
179 err1 = fmt.Errorf("failed to logout: %w", err1)
180 }
181
182 // remove the cookie
183 userSession.Options.MaxAge = -1
184 err2 := o.SessStore.Save(r, w, userSession)
185 if err2 != nil {
186 err2 = fmt.Errorf("failed to save into session store: %w", err2)
187 }
188
189 return errors.Join(err1, err2)
190}
191
192func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
193 registry := o.GetAccounts(r)
194 account := registry.FindAccount(targetDid)
195 if account == nil {
196 return fmt.Errorf("account not found in registry: %s", targetDid)
197 }
198
199 did, err := syntax.ParseDID(targetDid)
200 if err != nil {
201 return fmt.Errorf("invalid DID: %w", err)
202 }
203
204 sess, err := o.ClientApp.ResumeSession(r.Context(), did, account.SessionId)
205 if err != nil {
206 registry.RemoveAccount(targetDid)
207 _ = o.SaveAccounts(w, r, registry)
208 return fmt.Errorf("session expired for account: %w", err)
209 }
210
211 userSession, err := o.SessStore.Get(r, SessionName)
212 if err != nil {
213 return err
214 }
215
216 userSession.Values[SessionDid] = sess.Data.AccountDID.String()
217 userSession.Values[SessionPds] = sess.Data.HostURL
218 userSession.Values[SessionId] = sess.Data.SessionID
219 userSession.Values[SessionAuthenticated] = true
220
221 return userSession.Save(r, w)
222}
223
224func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
225 registry := o.GetAccounts(r)
226 account := registry.FindAccount(targetDid)
227 if account == nil {
228 return nil
229 }
230
231 did, err := syntax.ParseDID(targetDid)
232 if err == nil {
233 _ = o.ClientApp.Logout(r.Context(), did, account.SessionId)
234 }
235
236 registry.RemoveAccount(targetDid)
237 return o.SaveAccounts(w, r, registry)
238}
239
// User identifies the account behind a resumed session.
type User struct {
	Did string // account DID in string form
	Pds string // host URL recorded in the session data
}
244
245func (o *OAuth) GetUser(r *http.Request) *User {
246 sess, err := o.ResumeSession(r)
247 if err != nil {
248 return nil
249 }
250
251 return &User{
252 Did: sess.Data.AccountDID.String(),
253 Pds: sess.Data.HostURL,
254 }
255}
256
257func (o *OAuth) GetDid(r *http.Request) string {
258 if u := o.GetMultiAccountUser(r); u != nil {
259 return u.Did()
260 }
261
262 return ""
263}
264
265func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) {
266 session, err := o.ResumeSession(r)
267 if err != nil {
268 return nil, fmt.Errorf("error getting session: %w", err)
269 }
270 return session.APIClient(), nil
271}
272
// ServiceClientOpts holds the settings for a service client; it is a
// higher-level abstraction on ServerGetServiceAuth.
type ServiceClientOpts struct {
	service string        // host of the target service, without a scheme
	exp     int64         // token expiry as a Unix timestamp, in seconds
	lxm     string        // lxm value forwarded to ServerGetServiceAuth
	dev     bool          // when set, Host uses http:// instead of https://
	timeout time.Duration // timeout of the resulting HTTP client
}

// ServiceClientOpt mutates a ServiceClientOpts in place.
type ServiceClientOpt func(*ServiceClientOpts)

// DefaultServiceClientOpts returns options with a five second timeout and
// every other field at its zero value.
func DefaultServiceClientOpts() ServiceClientOpts {
	var defaults ServiceClientOpts
	defaults.timeout = 5 * time.Second
	return defaults
}

// WithService sets the host of the service to talk to.
func WithService(service string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.service = service }
}

// WithExp sets the token lifetime, given as a duration in seconds.
//
// The expiry is computed as time.Now().Unix() + exp at the moment the
// returned option is applied.
func WithExp(exp int64) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		now := time.Now().Unix()
		opts.exp = now + exp
	}
}

// WithLxm sets the lxm value of the requested token.
func WithLxm(lxm string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.lxm = lxm }
}

// WithDev toggles dev mode, which makes Host use plain http.
func WithDev(dev bool) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.dev = dev }
}

// WithTimeout overrides the default HTTP client timeout.
func WithTimeout(timeout time.Duration) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.timeout = timeout }
}

// Audience returns the did:web identifier built from the service host.
func (s *ServiceClientOpts) Audience() string {
	const format = "did:web:%s"
	return fmt.Sprintf(format, s.service)
}

// Host returns the base URL of the service: http:// in dev mode, https://
// otherwise.
func (s *ServiceClientOpts) Host() string {
	if s.dev {
		return "http://" + s.service
	}
	return "https://" + s.service
}
335
336func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) {
337 opts := DefaultServiceClientOpts()
338 for _, o := range os {
339 o(&opts)
340 }
341
342 client, err := o.AuthorizedClient(r)
343 if err != nil {
344 return nil, err
345 }
346
347 // force expiry to atleast 60 seconds in the future
348 sixty := time.Now().Unix() + 60
349 if opts.exp < sixty {
350 opts.exp = sixty
351 }
352
353 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm)
354 if err != nil {
355 return nil, err
356 }
357
358 return &xrpc.Client{
359 Auth: &xrpc.AuthInfo{
360 AccessJwt: resp.Token,
361 },
362 Host: opts.Host(),
363 Client: &http.Client{
364 Timeout: opts.timeout,
365 },
366 }, nil
367}
368
369func (o *OAuth) StartElevatedAuthFlow(ctx context.Context, w http.ResponseWriter, r *http.Request, did string, extraScopes []string, returnURL string) (string, error) {
370 parsedDid, err := syntax.ParseDID(did)
371 if err != nil {
372 return "", fmt.Errorf("invalid DID: %w", err)
373 }
374
375 ident, err := o.ClientApp.Dir.Lookup(ctx, parsedDid.AtIdentifier())
376 if err != nil {
377 return "", fmt.Errorf("failed to resolve DID (%s): %w", did, err)
378 }
379
380 host := ident.PDSEndpoint()
381 if host == "" {
382 return "", fmt.Errorf("identity does not link to an atproto host (PDS)")
383 }
384
385 authserverURL, err := o.ClientApp.Resolver.ResolveAuthServerURL(ctx, host)
386 if err != nil {
387 return "", fmt.Errorf("resolving auth server: %w", err)
388 }
389
390 authserverMeta, err := o.ClientApp.Resolver.ResolveAuthServerMetadata(ctx, authserverURL)
391 if err != nil {
392 return "", fmt.Errorf("fetching auth server metadata: %w", err)
393 }
394
395 scopes := make([]string, 0, len(TangledScopes)+len(extraScopes))
396 scopes = append(scopes, TangledScopes...)
397 scopes = append(scopes, extraScopes...)
398
399 loginHint := did
400 if ident.Handle != "" && !ident.Handle.IsInvalidHandle() {
401 loginHint = ident.Handle.String()
402 }
403
404 info, err := o.ClientApp.SendAuthRequest(ctx, authserverMeta, scopes, loginHint)
405 if err != nil {
406 return "", fmt.Errorf("auth request failed: %w", err)
407 }
408
409 info.AccountDID = &parsedDid
410 o.ClientApp.Store.SaveAuthRequestInfo(ctx, *info)
411
412 if err := o.SetAuthReturn(w, r, returnURL, false); err != nil {
413 return "", fmt.Errorf("failed to set auth return: %w", err)
414 }
415
416 redirectURL := fmt.Sprintf("%s?client_id=%s&request_uri=%s",
417 authserverMeta.AuthorizationEndpoint,
418 url.QueryEscape(o.ClientApp.Config.ClientID),
419 url.QueryEscape(info.RequestURI),
420 )
421
422 return redirectURL, nil
423}