A vibe-coded tangled fork that supports Pijul.
at master 489 lines 14 kB view raw
1package oauth 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "errors" 8 "fmt" 9 "log/slog" 10 "net/http" 11 "slices" 12 "strings" 13 "time" 14 15 comatproto "github.com/bluesky-social/indigo/api/atproto" 16 "github.com/bluesky-social/indigo/atproto/auth/oauth" 17 lexutil "github.com/bluesky-social/indigo/lex/util" 18 xrpc "github.com/bluesky-social/indigo/xrpc" 19 "github.com/go-chi/chi/v5" 20 "github.com/posthog/posthog-go" 21 "tangled.org/core/api/tangled" 22 "tangled.org/core/appview/db" 23 "tangled.org/core/appview/models" 24 "tangled.org/core/consts" 25 "tangled.org/core/idresolver" 26 "tangled.org/core/orm" 27 "tangled.org/core/tid" 28) 29 30func (o *OAuth) Router() http.Handler { 31 r := chi.NewRouter() 32 33 r.Get("/oauth/client-metadata.json", o.clientMetadata) 34 r.Get("/oauth/jwks.json", o.jwks) 35 r.Get("/oauth/callback", o.callback) 36 return r 37} 38 39func (o *OAuth) clientMetadata(w http.ResponseWriter, r *http.Request) { 40 doc := o.ClientApp.Config.ClientMetadata() 41 doc.JWKSURI = &o.JwksUri 42 doc.ClientName = &o.ClientName 43 doc.ClientURI = &o.ClientUri 44 doc.Scope = doc.Scope + " identity:handle" 45 46 w.Header().Set("Content-Type", "application/json") 47 if err := json.NewEncoder(w).Encode(doc); err != nil { 48 http.Error(w, err.Error(), http.StatusInternalServerError) 49 return 50 } 51} 52 53func (o *OAuth) jwks(w http.ResponseWriter, r *http.Request) { 54 w.Header().Set("Content-Type", "application/json") 55 body := o.ClientApp.Config.PublicJWKS() 56 if err := json.NewEncoder(w).Encode(body); err != nil { 57 http.Error(w, err.Error(), http.StatusInternalServerError) 58 return 59 } 60} 61 62func (o *OAuth) callback(w http.ResponseWriter, r *http.Request) { 63 ctx := r.Context() 64 l := o.Logger.With("query", r.URL.Query()) 65 66 authReturn := o.GetAuthReturn(r) 67 _ = o.ClearAuthReturn(w, r) 68 69 sessData, err := o.ClientApp.ProcessCallback(ctx, r.URL.Query()) 70 if err != nil { 71 var callbackErr *oauth.AuthRequestCallbackError 
72 if errors.As(err, &callbackErr) { 73 l.Debug("callback error", "err", callbackErr) 74 http.Redirect(w, r, fmt.Sprintf("/login?error=%s", callbackErr.ErrorCode), http.StatusFound) 75 return 76 } 77 l.Error("failed to process callback", "err", err) 78 http.Redirect(w, r, "/login?error=oauth", http.StatusFound) 79 return 80 } 81 82 if err := o.SaveSession(w, r, sessData); err != nil { 83 l.Error("failed to save session", "data", sessData, "err", err) 84 errorCode := "session" 85 if errors.Is(err, ErrMaxAccountsReached) { 86 errorCode = "max_accounts" 87 } 88 http.Redirect(w, r, fmt.Sprintf("/login?error=%s", errorCode), http.StatusFound) 89 return 90 } 91 92 o.Logger.Debug("session saved successfully") 93 94 go o.addToDefaultKnot(sessData.AccountDID.String()) 95 go o.addToDefaultSpindle(sessData.AccountDID.String()) 96 go o.ensureTangledProfile(sessData) 97 go o.autoClaimTnglShDomain(sessData.AccountDID.String()) 98 99 if !o.Config.Core.Dev { 100 err = o.Posthog.Enqueue(posthog.Capture{ 101 DistinctId: sessData.AccountDID.String(), 102 Event: "signin", 103 }) 104 if err != nil { 105 o.Logger.Error("failed to enqueue posthog event", "err", err) 106 } 107 } 108 109 redirectURL := "/" 110 if authReturn.ReturnURL != "" { 111 redirectURL = authReturn.ReturnURL 112 } 113 114 if o.isAccountDeactivated(sessData) { 115 redirectURL = "/settings/profile" 116 } 117 118 http.Redirect(w, r, redirectURL, http.StatusFound) 119} 120 121func (o *OAuth) isAccountDeactivated(sessData *oauth.ClientSessionData) bool { 122 pdsClient := &xrpc.Client{ 123 Host: sessData.HostURL, 124 Client: &http.Client{Timeout: 5 * time.Second}, 125 } 126 127 _, err := comatproto.RepoDescribeRepo( 128 context.Background(), 129 pdsClient, 130 sessData.AccountDID.String(), 131 ) 132 if err == nil { 133 return false 134 } 135 136 var xrpcErr *xrpc.Error 137 var xrpcBody *xrpc.XRPCError 138 return errors.As(err, &xrpcErr) && 139 errors.As(xrpcErr.Wrapped, &xrpcBody) && 140 xrpcBody.ErrStr == "RepoDeactivated" 
141} 142 143func (o *OAuth) addToDefaultSpindle(did string) { 144 l := o.Logger.With("subject", did) 145 146 // use the tangled.sh app password to get an accessJwt 147 // and create an sh.tangled.spindle.member record with that 148 spindleMembers, err := db.GetSpindleMembers( 149 o.Db, 150 orm.FilterEq("instance", "spindle.tangled.sh"), 151 orm.FilterEq("subject", did), 152 ) 153 if err != nil { 154 l.Error("failed to get spindle members", "err", err) 155 return 156 } 157 158 if len(spindleMembers) != 0 { 159 l.Warn("already a member of the default spindle") 160 return 161 } 162 163 l.Debug("adding to default spindle") 164 session, err := o.getAppPasswordSession() 165 if err != nil { 166 l.Error("failed to create session", "err", err) 167 return 168 } 169 170 record := tangled.SpindleMember{ 171 LexiconTypeID: tangled.SpindleMemberNSID, 172 Subject: did, 173 Instance: consts.DefaultSpindle, 174 CreatedAt: time.Now().Format(time.RFC3339), 175 } 176 177 if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil { 178 l.Error("failed to add to default spindle", "err", err) 179 return 180 } 181 182 l.Debug("successfully added to default spindle", "did", did) 183} 184 185func (o *OAuth) addToDefaultKnot(did string) { 186 l := o.Logger.With("subject", did) 187 188 // use the tangled.sh app password to get an accessJwt 189 // and create an sh.tangled.spindle.member record with that 190 191 allKnots, err := o.Enforcer.GetKnotsForUser(did) 192 if err != nil { 193 l.Error("failed to get knot members for did", "err", err) 194 return 195 } 196 197 if slices.Contains(allKnots, consts.DefaultKnot) { 198 l.Warn("already a member of the default knot") 199 return 200 } 201 202 l.Debug("adding to default knot") 203 session, err := o.getAppPasswordSession() 204 if err != nil { 205 l.Error("failed to create session", "err", err) 206 return 207 } 208 209 record := tangled.KnotMember{ 210 LexiconTypeID: tangled.KnotMemberNSID, 211 Subject: did, 212 Domain: 
consts.DefaultKnot, 213 CreatedAt: time.Now().Format(time.RFC3339), 214 } 215 216 if err := session.putRecord(record, tangled.KnotMemberNSID); err != nil { 217 l.Error("failed to add to default knot", "err", err) 218 return 219 } 220 221 if err := o.Enforcer.AddKnotMember(consts.DefaultKnot, did); err != nil { 222 l.Error("failed to set up enforcer rules", "err", err) 223 return 224 } 225 226 l.Debug("successfully added to default knot") 227} 228 229func (o *OAuth) ensureTangledProfile(sessData *oauth.ClientSessionData) { 230 ctx := context.Background() 231 did := sessData.AccountDID.String() 232 l := o.Logger.With("did", did) 233 234 profile, _ := db.GetProfile(o.Db, did) 235 if profile != nil { 236 l.Debug("profile already exists in DB") 237 return 238 } 239 240 l.Debug("creating empty Tangled profile") 241 242 sess, err := o.ClientApp.ResumeSession(ctx, sessData.AccountDID, sessData.SessionID) 243 if err != nil { 244 l.Error("failed to resume session for profile creation", "err", err) 245 return 246 } 247 client := sess.APIClient() 248 249 _, err = comatproto.RepoPutRecord(ctx, client, &comatproto.RepoPutRecord_Input{ 250 Collection: tangled.ActorProfileNSID, 251 Repo: did, 252 Rkey: "self", 253 Record: &lexutil.LexiconTypeDecoder{Val: &tangled.ActorProfile{}}, 254 }) 255 256 if err != nil { 257 l.Error("failed to create empty profile on PDS", "err", err) 258 return 259 } 260 261 tx, err := o.Db.BeginTx(ctx, nil) 262 if err != nil { 263 l.Error("failed to start transaction", "err", err) 264 return 265 } 266 267 emptyProfile := &models.Profile{Did: did} 268 if err := db.UpsertProfile(tx, emptyProfile); err != nil { 269 l.Error("failed to create empty profile in DB", "err", err) 270 return 271 } 272 273 l.Debug("successfully created empty Tangled profile on PDS and DB") 274} 275 276// create a AppPasswordSession using apppasswords 277type AppPasswordSession struct { 278 AccessJwt string `json:"accessJwt"` 279 RefreshJwt string `json:"refreshJwt"` 280 PdsEndpoint 
string 281 Did string 282 Logger *slog.Logger 283 ExpiresAt time.Time 284} 285 286func CreateAppPasswordSession(res *idresolver.Resolver, appPassword, did string, logger *slog.Logger) (*AppPasswordSession, error) { 287 if appPassword == "" { 288 return nil, fmt.Errorf("no app password configured") 289 } 290 291 resolved, err := res.ResolveIdent(context.Background(), did) 292 if err != nil { 293 return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %v", did, err) 294 } 295 296 pdsEndpoint := resolved.PDSEndpoint() 297 if pdsEndpoint == "" { 298 return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did) 299 } 300 301 sessionPayload := map[string]string{ 302 "identifier": did, 303 "password": appPassword, 304 } 305 sessionBytes, err := json.Marshal(sessionPayload) 306 if err != nil { 307 return nil, fmt.Errorf("failed to marshal session payload: %v", err) 308 } 309 310 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession" 311 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes)) 312 if err != nil { 313 return nil, fmt.Errorf("failed to create session request: %v", err) 314 } 315 sessionReq.Header.Set("Content-Type", "application/json") 316 317 logger.Debug("creating app password session", "url", sessionURL, "headers", sessionReq.Header) 318 319 client := &http.Client{Timeout: 30 * time.Second} 320 sessionResp, err := client.Do(sessionReq) 321 if err != nil { 322 return nil, fmt.Errorf("failed to create session: %v", err) 323 } 324 defer sessionResp.Body.Close() 325 326 if sessionResp.StatusCode != http.StatusOK { 327 return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode) 328 } 329 330 var session AppPasswordSession 331 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil { 332 return nil, fmt.Errorf("failed to decode session response: %v", err) 333 } 334 335 session.PdsEndpoint = pdsEndpoint 336 session.Did = did 337 
session.Logger = logger 338 session.ExpiresAt = time.Now().Add(115 * time.Minute) 339 340 return &session, nil 341} 342 343func (s *AppPasswordSession) refreshSession() error { 344 refreshURL := s.PdsEndpoint + "/xrpc/com.atproto.server.refreshSession" 345 req, err := http.NewRequestWithContext(context.Background(), "POST", refreshURL, nil) 346 if err != nil { 347 return fmt.Errorf("failed to create refresh request: %w", err) 348 } 349 350 req.Header.Set("Authorization", "Bearer "+s.RefreshJwt) 351 352 s.Logger.Debug("refreshing app password session", "url", refreshURL) 353 354 client := &http.Client{Timeout: 30 * time.Second} 355 resp, err := client.Do(req) 356 if err != nil { 357 return fmt.Errorf("failed to refresh session: %w", err) 358 } 359 defer resp.Body.Close() 360 361 if resp.StatusCode != http.StatusOK { 362 var errorResponse map[string]any 363 if err := json.NewDecoder(resp.Body).Decode(&errorResponse); err != nil { 364 return fmt.Errorf("failed to refresh session: HTTP %d (failed to decode error response: %w)", resp.StatusCode, err) 365 } 366 errorBytes, _ := json.Marshal(errorResponse) 367 return fmt.Errorf("failed to refresh session: HTTP %d, response: %s", resp.StatusCode, string(errorBytes)) 368 } 369 370 var refreshResponse struct { 371 AccessJwt string `json:"accessJwt"` 372 RefreshJwt string `json:"refreshJwt"` 373 } 374 if err := json.NewDecoder(resp.Body).Decode(&refreshResponse); err != nil { 375 return fmt.Errorf("failed to decode refresh response: %w", err) 376 } 377 378 s.AccessJwt = refreshResponse.AccessJwt 379 s.RefreshJwt = refreshResponse.RefreshJwt 380 // Set new expiry time with 5 minute buffer 381 s.ExpiresAt = time.Now().Add(115 * time.Minute) 382 383 s.Logger.Debug("successfully refreshed app password session") 384 return nil 385} 386 387func (s *AppPasswordSession) isValid() bool { 388 return time.Now().Before(s.ExpiresAt) 389} 390 391func (s *AppPasswordSession) putRecord(record any, collection string) error { 392 if 
!s.isValid() { 393 s.Logger.Debug("access token expired, refreshing session") 394 if err := s.refreshSession(); err != nil { 395 return fmt.Errorf("failed to refresh session: %w", err) 396 } 397 s.Logger.Debug("session refreshed") 398 } 399 400 recordBytes, err := json.Marshal(record) 401 if err != nil { 402 return fmt.Errorf("failed to marshal knot member record: %w", err) 403 } 404 405 payload := map[string]any{ 406 "repo": s.Did, 407 "collection": collection, 408 "rkey": tid.TID(), 409 "record": json.RawMessage(recordBytes), 410 } 411 412 payloadBytes, err := json.Marshal(payload) 413 if err != nil { 414 return fmt.Errorf("failed to marshal request payload: %w", err) 415 } 416 417 url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord" 418 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes)) 419 if err != nil { 420 return fmt.Errorf("failed to create HTTP request: %w", err) 421 } 422 423 req.Header.Set("Content-Type", "application/json") 424 req.Header.Set("Authorization", "Bearer "+s.AccessJwt) 425 426 s.Logger.Debug("putting record", "url", url, "collection", collection) 427 428 client := &http.Client{Timeout: 30 * time.Second} 429 resp, err := client.Do(req) 430 if err != nil { 431 return fmt.Errorf("failed to add user to default service: %w", err) 432 } 433 defer resp.Body.Close() 434 435 if resp.StatusCode != http.StatusOK { 436 var errorResponse map[string]any 437 if err := json.NewDecoder(resp.Body).Decode(&errorResponse); err != nil { 438 return fmt.Errorf("failed to add user to default service: HTTP %d (failed to decode error response: %w)", resp.StatusCode, err) 439 } 440 return fmt.Errorf("failed to add user to default service: HTTP %d, response: %v", resp.StatusCode, errorResponse) 441 } 442 443 return nil 444} 445 446// autoClaimTnglShDomain checks if the user has a .tngl.sh handle and, if so, 447// ensures their corresponding sites domain is claimed. 
This is idempotent — 448// ClaimDomain is a no-op if the claim already exists. 449func (o *OAuth) autoClaimTnglShDomain(did string) { 450 l := o.Logger.With("did", did) 451 452 pdsDomain := strings.TrimPrefix(o.Config.Pds.Host, "https://") 453 pdsDomain = strings.TrimPrefix(pdsDomain, "http://") 454 455 resolved, err := o.IdResolver.ResolveIdent(context.Background(), did) 456 if err != nil { 457 l.Error("autoClaimTnglShDomain: failed to resolve ident", "err", err) 458 return 459 } 460 461 handle := resolved.Handle.String() 462 if !strings.HasSuffix(handle, "."+pdsDomain) { 463 return 464 } 465 466 if err := db.ClaimDomain(o.Db, did, handle); err != nil { 467 l.Warn("autoClaimTnglShDomain: failed to claim domain", "domain", handle, "err", err) 468 } else { 469 l.Info("autoClaimTnglShDomain: claimed domain", "domain", handle) 470 } 471} 472 473// getAppPasswordSession returns a cached AppPasswordSession, creating one if needed. 474func (o *OAuth) getAppPasswordSession() (*AppPasswordSession, error) { 475 o.appPasswordSessionMu.Lock() 476 defer o.appPasswordSessionMu.Unlock() 477 478 if o.appPasswordSession != nil { 479 return o.appPasswordSession, nil 480 } 481 482 session, err := CreateAppPasswordSession(o.IdResolver, o.Config.Core.AppPassword, consts.TangledDid, o.Logger) 483 if err != nil { 484 return nil, err 485 } 486 487 o.appPasswordSession = session 488 return session, nil 489}