A vibe-coded fork of Tangled that supports Pijul.
1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sort"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 lexutil "github.com/bluesky-social/indigo/lex/util"
16 "github.com/ipfs/go-cid"
17 "tangled.org/core/appview/models"
18 "tangled.org/core/appview/pagination"
19 "tangled.org/core/orm"
20 "tangled.org/core/sets"
21)
22
23func comparePullSource(existing, new *models.PullSource) bool {
24 if existing == nil && new == nil {
25 return true
26 }
27 if existing == nil || new == nil {
28 return false
29 }
30 if existing.Branch != new.Branch {
31 return false
32 }
33 if existing.RepoAt == nil && new.RepoAt == nil {
34 return true
35 }
36 if existing.RepoAt == nil || new.RepoAt == nil {
37 return false
38 }
39 return *existing.RepoAt == *new.RepoAt
40}
41
42func compareSubmissions(existing, new []*models.PullSubmission) bool {
43 if len(existing) != len(new) {
44 return false
45 }
46 for i := range existing {
47 if existing[i].Blob.Ref.String() != new[i].Blob.Ref.String() {
48 return false
49 }
50 if existing[i].Blob.MimeType != new[i].Blob.MimeType {
51 return false
52 }
53 if existing[i].Blob.Size != new[i].Blob.Size {
54 return false
55 }
56 }
57 return true
58}
59
60func PutPull(tx *sql.Tx, pull *models.Pull) error {
61 // ensure sequence exists
62 _, err := tx.Exec(`
63 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
64 values (?, 1)
65 `, pull.RepoAt)
66 if err != nil {
67 return err
68 }
69
70 pulls, err := GetPulls(
71 tx,
72 orm.FilterEq("owner_did", pull.OwnerDid),
73 orm.FilterEq("rkey", pull.Rkey),
74 )
75 switch {
76 case err != nil:
77 return err
78 case len(pulls) == 0:
79 return createNewPull(tx, pull)
80 case len(pulls) != 1: // should be unreachable
81 return fmt.Errorf("invalid number of pulls returned: %d", len(pulls))
82 default:
83 existingPull := pulls[0]
84 if existingPull.State == models.PullMerged {
85 return nil
86 }
87
88 dependentOnEqual := (existingPull.DependentOn == nil && pull.DependentOn == nil) ||
89 (existingPull.DependentOn != nil && pull.DependentOn != nil && *existingPull.DependentOn == *pull.DependentOn)
90
91 pullSourceEqual := comparePullSource(existingPull.PullSource, pull.PullSource)
92 submissionsEqual := compareSubmissions(existingPull.Submissions, pull.Submissions)
93
94 if existingPull.Title == pull.Title &&
95 existingPull.Body == pull.Body &&
96 existingPull.TargetBranch == pull.TargetBranch &&
97 existingPull.RepoAt == pull.RepoAt &&
98 dependentOnEqual &&
99 pullSourceEqual &&
100 submissionsEqual {
101 return nil
102 }
103
104 isLonger := len(existingPull.Submissions) < len(pull.Submissions)
105 if isLonger {
106 isAppendOnly := compareSubmissions(existingPull.Submissions, pull.Submissions[:len(existingPull.Submissions)])
107 if !isAppendOnly {
108 return fmt.Errorf("the new pull does not treat submissions as append-only")
109 }
110 } else if !submissionsEqual {
111 return fmt.Errorf("the new pull does not treat submissions as append-only")
112 }
113
114 pull.ID = existingPull.ID
115 pull.PullId = existingPull.PullId
116 return updatePull(tx, pull, existingPull)
117 }
118}
119
120func createNewPull(tx *sql.Tx, pull *models.Pull) error {
121 _, err := tx.Exec(`
122 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
123 values (?, 1)
124 `, pull.RepoAt)
125 if err != nil {
126 return err
127 }
128
129 var nextId int
130 err = tx.QueryRow(`
131 update repo_pull_seqs
132 set next_pull_id = next_pull_id + 1
133 where repo_at = ?
134 returning next_pull_id - 1
135 `, pull.RepoAt).Scan(&nextId)
136 if err != nil {
137 return err
138 }
139
140 pull.PullId = nextId
141 pull.State = models.PullOpen
142
143 var sourceBranch, sourceRepoAt *string
144 if pull.PullSource != nil {
145 sourceBranch = &pull.PullSource.Branch
146 if pull.PullSource.RepoAt != nil {
147 x := pull.PullSource.RepoAt.String()
148 sourceRepoAt = &x
149 }
150 }
151
152 // var stackId, changeId, parentChangeId *string
153 // if pull.StackId != "" {
154 // stackId = &pull.StackId
155 // }
156 // if pull.ChangeId != "" {
157 // changeId = &pull.ChangeId
158 // }
159 // if pull.ParentChangeId != "" {
160 // parentChangeId = &pull.ParentChangeId
161 // }
162
163 result, err := tx.Exec(
164 `
165 insert into pulls (
166 repo_at,
167 owner_did,
168 pull_id,
169 title,
170 target_branch,
171 body,
172 rkey,
173 state,
174 dependent_on,
175 source_branch,
176 source_repo_at
177 )
178 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
179 pull.RepoAt,
180 pull.OwnerDid,
181 pull.PullId,
182 pull.Title,
183 pull.TargetBranch,
184 pull.Body,
185 pull.Rkey,
186 pull.State,
187 pull.DependentOn,
188 sourceBranch,
189 sourceRepoAt,
190 )
191 if err != nil {
192 return err
193 }
194
195 // Set the database primary key ID
196 id, err := result.LastInsertId()
197 if err != nil {
198 return err
199 }
200 pull.ID = int(id)
201
202 for i, s := range pull.Submissions {
203 _, err = tx.Exec(`
204 insert into pull_submissions (
205 pull_at,
206 round_number,
207 patch,
208 combined,
209 source_rev,
210 patch_blob_ref,
211 patch_blob_mime,
212 patch_blob_size
213 )
214 values (?, ?, ?, ?, ?, ?, ?, ?)
215 `,
216 pull.AtUri(),
217 i,
218 s.Patch,
219 s.Combined,
220 s.SourceRev,
221 s.Blob.Ref.String(),
222 s.Blob.MimeType,
223 s.Blob.Size,
224 )
225 if err != nil {
226 return err
227 }
228 }
229
230 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil {
231 return fmt.Errorf("put reference_links: %w", err)
232 }
233
234 return nil
235}
236
237func updatePull(tx *sql.Tx, pull *models.Pull, existingPull *models.Pull) error {
238 var sourceBranch, sourceRepoAt *string
239 if pull.PullSource != nil {
240 sourceBranch = &pull.PullSource.Branch
241 if pull.PullSource.RepoAt != nil {
242 x := pull.PullSource.RepoAt.String()
243 sourceRepoAt = &x
244 }
245 }
246
247 _, err := tx.Exec(`
248 update pulls set
249 title = ?,
250 body = ?,
251 target_branch = ?,
252 dependent_on = ?,
253 source_branch = ?,
254 source_repo_at = ?
255 where owner_did = ? and rkey = ?
256 `, pull.Title, pull.Body, pull.TargetBranch, pull.DependentOn, sourceBranch, sourceRepoAt, pull.OwnerDid, pull.Rkey)
257 if err != nil {
258 return err
259 }
260
261 // insert new submissions (append-only)
262 for i := len(existingPull.Submissions); i < len(pull.Submissions); i++ {
263 s := pull.Submissions[i]
264 _, err = tx.Exec(`
265 insert into pull_submissions (
266 pull_at,
267 round_number,
268 patch,
269 combined,
270 source_rev,
271 patch_blob_ref,
272 patch_blob_mime,
273 patch_blob_size
274 )
275 values (?, ?, ?, ?, ?, ?, ?, ?)
276 `,
277 pull.AtUri(),
278 i,
279 s.Patch,
280 s.Combined,
281 s.SourceRev,
282 s.Blob.Ref.String(),
283 s.Blob.MimeType,
284 s.Blob.Size,
285 )
286 if err != nil {
287 return err
288 }
289 }
290
291 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil {
292 return fmt.Errorf("put reference_links: %w", err)
293 }
294 return nil
295}
296
297// func NewPull(tx *sql.Tx, pull *models.Pull) error {
298// _, err := tx.Exec(`
299// insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
300// values (?, 1)
301// `, pull.RepoAt)
302// if err != nil {
303// return err
304// }
305//
306// var nextId int
307// err = tx.QueryRow(`
308// update repo_pull_seqs
309// set next_pull_id = next_pull_id + 1
310// where repo_at = ?
311// returning next_pull_id - 1
312// `, pull.RepoAt).Scan(&nextId)
313// if err != nil {
314// return err
315// }
316//
317// pull.PullId = nextId
318// pull.State = models.PullOpen
319//
320// var sourceBranch, sourceRepoAt *string
321// if pull.PullSource != nil {
322// sourceBranch = &pull.PullSource.Branch
323// if pull.PullSource.RepoAt != nil {
324// x := pull.PullSource.RepoAt.String()
325// sourceRepoAt = &x
326// }
327// }
328//
329// // var stackId, changeId, parentChangeId *string
330// // if pull.StackId != "" {
331// // stackId = &pull.StackId
332// // }
333// // if pull.ChangeId != "" {
334// // changeId = &pull.ChangeId
335// // }
336// // if pull.ParentChangeId != "" {
337// // parentChangeId = &pull.ParentChangeId
338// // }
339//
340// result, err := tx.Exec(
341// `
342// insert into pulls (
343// repo_at,
344// owner_did,
345// pull_id,
346// title,
347// target_branch,
348// body,
349// rkey,
350// state,
351// dependent_on,
352// source_branch,
353// source_repo_at
354// )
355// values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
356// pull.RepoAt,
357// pull.OwnerDid,
358// pull.PullId,
359// pull.Title,
360// pull.TargetBranch,
361// pull.Body,
362// pull.Rkey,
363// pull.State,
364// pull.DependentOn,
365// sourceBranch,
366// sourceRepoAt,
367// )
368// if err != nil {
369// return err
370// }
371//
372// // Set the database primary key ID
373// id, err := result.LastInsertId()
374// if err != nil {
375// return err
376// }
377// pull.ID = int(id)
378//
379// _, err = tx.Exec(`
380// insert into pull_submissions (
381// pull_at,
382// round_number,
383// patch,
384// combined,
385// source_rev,
386// patch_blob_ref,
387// patch_blob_mime,
388// patch_blob_size
389// )
390// values (?, ?, ?, ?, ?, ?, ?, ?)
391// `,
392// pull.AtUri(),
393// 0,
394// pull.Submissions[0].Patch,
395// pull.Submissions[0].Combined,
396// pull.Submissions[0].SourceRev,
397// pull.Submissions[0].Blob.Ref.String(),
398// pull.Submissions[0].Blob.MimeType,
399// pull.Submissions[0].Blob.Size,
400// )
401// if err != nil {
402// return err
403// }
404//
405// if err := putReferences(tx, pull.AtUri(), pull.References); err != nil {
406// return fmt.Errorf("put reference_links: %w", err)
407// }
408//
409// return nil
410// }
411
412func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
413 var pullId int
414 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
415 return pullId - 1, err
416}
417
418func GetPullsPaginated(e Execer, page pagination.Page, filters ...orm.Filter) ([]*models.Pull, error) {
419 pulls := make(map[syntax.ATURI]*models.Pull)
420
421 var conditions []string
422 var args []any
423 for _, filter := range filters {
424 conditions = append(conditions, filter.Condition())
425 args = append(args, filter.Arg()...)
426 }
427
428 whereClause := ""
429 if conditions != nil {
430 whereClause = " where " + strings.Join(conditions, " and ")
431 }
432 pageClause := ""
433 if page.Limit != 0 {
434 pageClause = fmt.Sprintf(
435 " limit %d offset %d ",
436 page.Limit,
437 page.Offset,
438 )
439 }
440
441 query := fmt.Sprintf(`
442 select
443 id,
444 owner_did,
445 repo_at,
446 pull_id,
447 created,
448 title,
449 state,
450 target_branch,
451 body,
452 rkey,
453 source_branch,
454 source_repo_at,
455 dependent_on
456 from
457 pulls
458 %s
459 order by
460 created desc
461 %s
462 `, whereClause, pageClause)
463
464 rows, err := e.Query(query, args...)
465 if err != nil {
466 return nil, err
467 }
468 defer rows.Close()
469
470 for rows.Next() {
471 var pull models.Pull
472 var createdAt string
473 var sourceBranch, sourceRepoAt, dependentOn sql.NullString
474 err := rows.Scan(
475 &pull.ID,
476 &pull.OwnerDid,
477 &pull.RepoAt,
478 &pull.PullId,
479 &createdAt,
480 &pull.Title,
481 &pull.State,
482 &pull.TargetBranch,
483 &pull.Body,
484 &pull.Rkey,
485 &sourceBranch,
486 &sourceRepoAt,
487 &dependentOn,
488 )
489 if err != nil {
490 return nil, err
491 }
492
493 createdTime, err := time.Parse(time.RFC3339, createdAt)
494 if err != nil {
495 return nil, err
496 }
497 pull.Created = createdTime
498
499 if sourceBranch.Valid {
500 pull.PullSource = &models.PullSource{
501 Branch: sourceBranch.String,
502 }
503 if sourceRepoAt.Valid {
504 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
505 if err != nil {
506 return nil, err
507 }
508 pull.PullSource.RepoAt = &sourceRepoAtParsed
509 }
510 }
511
512 if dependentOn.Valid {
513 x := syntax.ATURI(dependentOn.String)
514 pull.DependentOn = &x
515 }
516
517 pulls[pull.AtUri()] = &pull
518 }
519
520 var pullAts []syntax.ATURI
521 for _, p := range pulls {
522 pullAts = append(pullAts, p.AtUri())
523 }
524 submissionsMap, err := GetPullSubmissions(e, orm.FilterIn("pull_at", pullAts))
525 if err != nil {
526 return nil, fmt.Errorf("failed to get submissions: %w", err)
527 }
528
529 for pullAt, submissions := range submissionsMap {
530 if p, ok := pulls[pullAt]; ok {
531 p.Submissions = submissions
532 }
533 }
534
535 // collect allLabels for each issue
536 allLabels, err := GetLabels(e, orm.FilterIn("subject", pullAts))
537 if err != nil {
538 return nil, fmt.Errorf("failed to query labels: %w", err)
539 }
540 for pullAt, labels := range allLabels {
541 if p, ok := pulls[pullAt]; ok {
542 p.Labels = labels
543 }
544 }
545
546 // collect pull source for all pulls that need it
547 var sourceAts []syntax.ATURI
548 for _, p := range pulls {
549 if p.PullSource != nil && p.PullSource.RepoAt != nil {
550 sourceAts = append(sourceAts, *p.PullSource.RepoAt)
551 }
552 }
553 sourceRepos, err := GetRepos(e, 0, orm.FilterIn("at_uri", sourceAts))
554 if err != nil && !errors.Is(err, sql.ErrNoRows) {
555 return nil, fmt.Errorf("failed to get source repos: %w", err)
556 }
557 sourceRepoMap := make(map[syntax.ATURI]*models.Repo)
558 for _, r := range sourceRepos {
559 sourceRepoMap[r.RepoAt()] = &r
560 }
561 for _, p := range pulls {
562 if p.PullSource != nil && p.PullSource.RepoAt != nil {
563 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok {
564 p.PullSource.Repo = sourceRepo
565 }
566 }
567 }
568
569 allReferences, err := GetReferencesAll(e, orm.FilterIn("from_at", pullAts))
570 if err != nil {
571 return nil, fmt.Errorf("failed to query reference_links: %w", err)
572 }
573 for pullAt, references := range allReferences {
574 if pull, ok := pulls[pullAt]; ok {
575 pull.References = references
576 }
577 }
578
579 orderedByPullId := []*models.Pull{}
580 for _, p := range pulls {
581 orderedByPullId = append(orderedByPullId, p)
582 }
583 sort.Slice(orderedByPullId, func(i, j int) bool {
584 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
585 })
586
587 return orderedByPullId, nil
588}
589
590func GetPulls(e Execer, filters ...orm.Filter) ([]*models.Pull, error) {
591 return GetPullsPaginated(e, pagination.Page{}, filters...)
592}
593
594func GetPull(e Execer, filters ...orm.Filter) (*models.Pull, error) {
595 pulls, err := GetPullsPaginated(e, pagination.Page{Limit: 1}, filters...)
596 if err != nil {
597 return nil, err
598 }
599 if len(pulls) == 0 {
600 return nil, sql.ErrNoRows
601 }
602
603 return pulls[0], nil
604}
605
606// mapping from pull -> pull submissions
607func GetPullSubmissions(e Execer, filters ...orm.Filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
608 var conditions []string
609 var args []any
610 for _, filter := range filters {
611 conditions = append(conditions, filter.Condition())
612 args = append(args, filter.Arg()...)
613 }
614
615 whereClause := ""
616 if conditions != nil {
617 whereClause = " where " + strings.Join(conditions, " and ")
618 }
619
620 query := fmt.Sprintf(`
621 select
622 id,
623 pull_at,
624 round_number,
625 patch,
626 combined,
627 created,
628 source_rev,
629 patch_blob_ref,
630 patch_blob_mime,
631 patch_blob_size
632 from
633 pull_submissions
634 %s
635 order by
636 round_number asc
637 `, whereClause)
638
639 rows, err := e.Query(query, args...)
640 if err != nil {
641 return nil, err
642 }
643 defer rows.Close()
644
645 submissionMap := make(map[int]*models.PullSubmission)
646
647 for rows.Next() {
648 var submission models.PullSubmission
649 var submissionCreatedStr string
650 var submissionSourceRev, submissionCombined sql.Null[string]
651 var patchBlobRef, patchBlobMime sql.Null[string]
652 var patchBlobSize sql.Null[int64]
653 err := rows.Scan(
654 &submission.ID,
655 &submission.PullAt,
656 &submission.RoundNumber,
657 &submission.Patch,
658 &submissionCombined,
659 &submissionCreatedStr,
660 &submissionSourceRev,
661 &patchBlobRef,
662 &patchBlobMime,
663 &patchBlobSize,
664 )
665 if err != nil {
666 return nil, err
667 }
668
669 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil {
670 submission.Created = t
671 }
672
673 if submissionSourceRev.Valid {
674 submission.SourceRev = submissionSourceRev.V
675 }
676
677 if submissionCombined.Valid {
678 submission.Combined = submissionCombined.V
679 }
680
681 if patchBlobRef.Valid {
682 submission.Blob.Ref = lexutil.LexLink(cid.MustParse(patchBlobRef.V))
683 }
684
685 if patchBlobMime.Valid {
686 submission.Blob.MimeType = patchBlobMime.V
687 }
688
689 if patchBlobSize.Valid {
690 submission.Blob.Size = patchBlobSize.V
691 }
692
693 submissionMap[submission.ID] = &submission
694 }
695
696 if err := rows.Err(); err != nil {
697 return nil, err
698 }
699
700 // Get comments for all submissions using GetPullComments
701 submissionIds := slices.Collect(maps.Keys(submissionMap))
702 comments, err := GetPullComments(e, orm.FilterIn("submission_id", submissionIds))
703 if err != nil {
704 return nil, fmt.Errorf("failed to get pull comments: %w", err)
705 }
706 for _, comment := range comments {
707 if submission, ok := submissionMap[comment.SubmissionId]; ok {
708 submission.Comments = append(submission.Comments, comment)
709 }
710 }
711
712 // group the submissions by pull_at
713 m := make(map[syntax.ATURI][]*models.PullSubmission)
714 for _, s := range submissionMap {
715 m[s.PullAt] = append(m[s.PullAt], s)
716 }
717
718 // sort each one by round number
719 for _, s := range m {
720 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
721 return cmp.Compare(a.RoundNumber, b.RoundNumber)
722 })
723 }
724
725 return m, nil
726}
727
728func GetPullComments(e Execer, filters ...orm.Filter) ([]models.PullComment, error) {
729 var conditions []string
730 var args []any
731 for _, filter := range filters {
732 conditions = append(conditions, filter.Condition())
733 args = append(args, filter.Arg()...)
734 }
735
736 whereClause := ""
737 if conditions != nil {
738 whereClause = " where " + strings.Join(conditions, " and ")
739 }
740
741 query := fmt.Sprintf(`
742 select
743 id,
744 pull_id,
745 submission_id,
746 repo_at,
747 owner_did,
748 comment_at,
749 body,
750 created
751 from
752 pull_comments
753 %s
754 order by
755 created asc
756 `, whereClause)
757
758 rows, err := e.Query(query, args...)
759 if err != nil {
760 return nil, err
761 }
762 defer rows.Close()
763
764 commentMap := make(map[string]*models.PullComment)
765 for rows.Next() {
766 var comment models.PullComment
767 var createdAt string
768 err := rows.Scan(
769 &comment.ID,
770 &comment.PullId,
771 &comment.SubmissionId,
772 &comment.RepoAt,
773 &comment.OwnerDid,
774 &comment.CommentAt,
775 &comment.Body,
776 &createdAt,
777 )
778 if err != nil {
779 return nil, err
780 }
781
782 if t, err := time.Parse(time.RFC3339, createdAt); err == nil {
783 comment.Created = t
784 }
785
786 atUri := comment.AtUri().String()
787 commentMap[atUri] = &comment
788 }
789
790 if err := rows.Err(); err != nil {
791 return nil, err
792 }
793
794 // collect references for each comments
795 commentAts := slices.Collect(maps.Keys(commentMap))
796 allReferencs, err := GetReferencesAll(e, orm.FilterIn("from_at", commentAts))
797 if err != nil {
798 return nil, fmt.Errorf("failed to query reference_links: %w", err)
799 }
800 for commentAt, references := range allReferencs {
801 if comment, ok := commentMap[commentAt.String()]; ok {
802 comment.References = references
803 }
804 }
805
806 var comments []models.PullComment
807 for _, c := range commentMap {
808 comments = append(comments, *c)
809 }
810
811 sort.Slice(comments, func(i, j int) bool {
812 return comments[i].Created.Before(comments[j].Created)
813 })
814
815 return comments, nil
816}
817
818// timeframe here is directly passed into the sql query filter, and any
819// timeframe in the past should be negative; e.g.: "-3 months"
820func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
821 var pulls []models.Pull
822
823 rows, err := e.Query(`
824 select
825 p.owner_did,
826 p.repo_at,
827 p.pull_id,
828 p.created,
829 p.title,
830 p.state,
831 r.did,
832 r.name,
833 r.knot,
834 r.rkey,
835 r.created
836 from
837 pulls p
838 join
839 repos r on p.repo_at = r.at_uri
840 where
841 p.owner_did = ? and p.created >= date ('now', ?)
842 order by
843 p.created desc`, did, timeframe)
844 if err != nil {
845 return nil, err
846 }
847 defer rows.Close()
848
849 for rows.Next() {
850 var pull models.Pull
851 var repo models.Repo
852 var pullCreatedAt, repoCreatedAt string
853 err := rows.Scan(
854 &pull.OwnerDid,
855 &pull.RepoAt,
856 &pull.PullId,
857 &pullCreatedAt,
858 &pull.Title,
859 &pull.State,
860 &repo.Did,
861 &repo.Name,
862 &repo.Knot,
863 &repo.Rkey,
864 &repoCreatedAt,
865 )
866 if err != nil {
867 return nil, err
868 }
869
870 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
871 if err != nil {
872 return nil, err
873 }
874 pull.Created = pullCreatedTime
875
876 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
877 if err != nil {
878 return nil, err
879 }
880 repo.Created = repoCreatedTime
881
882 pull.Repo = &repo
883
884 pulls = append(pulls, pull)
885 }
886
887 if err := rows.Err(); err != nil {
888 return nil, err
889 }
890
891 return pulls, nil
892}
893
894func NewPullComment(tx *sql.Tx, comment *models.PullComment) (int64, error) {
895 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
896 res, err := tx.Exec(
897 query,
898 comment.OwnerDid,
899 comment.RepoAt,
900 comment.SubmissionId,
901 comment.CommentAt,
902 comment.PullId,
903 comment.Body,
904 )
905 if err != nil {
906 return 0, err
907 }
908
909 i, err := res.LastInsertId()
910 if err != nil {
911 return 0, err
912 }
913
914 if err := putReferences(tx, comment.AtUri(), comment.References); err != nil {
915 return 0, fmt.Errorf("put reference_links: %w", err)
916 }
917
918 return i, nil
919}
920
921// use with transaction
922func SetPullsState(e Execer, pullState models.PullState, filters ...orm.Filter) error {
923 var conditions []string
924 var args []any
925
926 args = append(args, pullState)
927 for _, filter := range filters {
928 conditions = append(conditions, filter.Condition())
929 args = append(args, filter.Arg()...)
930 }
931 args = append(args, models.PullAbandoned) // only update state of non-deleted pulls
932 args = append(args, models.PullMerged) // only update state of non-merged pulls
933
934 whereClause := ""
935 if conditions != nil {
936 whereClause = " where " + strings.Join(conditions, " and ")
937 }
938
939 query := fmt.Sprintf("update pulls set state = ? %s and state <> ? and state <> ?", whereClause)
940
941 _, err := e.Exec(query, args...)
942 return err
943}
944
945func ClosePulls(e Execer, filters ...orm.Filter) error {
946 return SetPullsState(e, models.PullClosed, filters...)
947}
948
949func ReopenPulls(e Execer, filters ...orm.Filter) error {
950 return SetPullsState(e, models.PullOpen, filters...)
951}
952
953func MergePulls(e Execer, filters ...orm.Filter) error {
954 return SetPullsState(e, models.PullMerged, filters...)
955}
956
957func AbandonPulls(e Execer, filters ...orm.Filter) error {
958 return SetPullsState(e, models.PullAbandoned, filters...)
959}
960
961func ResubmitPull(
962 e Execer,
963 pullAt syntax.ATURI,
964 newRoundNumber int,
965 newPatch string,
966 combinedPatch string,
967 newSourceRev string,
968 blob *lexutil.LexBlob,
969) error {
970 _, err := e.Exec(`
971 insert into pull_submissions (
972 pull_at,
973 round_number,
974 patch,
975 combined,
976 source_rev,
977 patch_blob_ref,
978 patch_blob_mime,
979 patch_blob_size
980 )
981 values (?, ?, ?, ?, ?, ?, ?, ?)
982 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev, blob.Ref.String(), blob.MimeType, blob.Size)
983
984 return err
985}
986
987func SetDependentOn(e Execer, dependentOn syntax.ATURI, filters ...orm.Filter) error {
988 var conditions []string
989 var args []any
990
991 args = append(args, dependentOn)
992
993 for _, filter := range filters {
994 conditions = append(conditions, filter.Condition())
995 args = append(args, filter.Arg()...)
996 }
997
998 whereClause := ""
999 if conditions != nil {
1000 whereClause = " where " + strings.Join(conditions, " and ")
1001 }
1002
1003 query := fmt.Sprintf("update pulls set dependent_on = ? %s", whereClause)
1004 _, err := e.Exec(query, args...)
1005
1006 return err
1007}
1008
1009func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
1010 row := e.QueryRow(`
1011 select
1012 count(case when state = ? then 1 end) as open_count,
1013 count(case when state = ? then 1 end) as merged_count,
1014 count(case when state = ? then 1 end) as closed_count,
1015 count(case when state = ? then 1 end) as deleted_count
1016 from pulls
1017 where repo_at = ?`,
1018 models.PullOpen,
1019 models.PullMerged,
1020 models.PullClosed,
1021 models.PullAbandoned,
1022 repoAt,
1023 )
1024
1025 var count models.PullCount
1026 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
1027 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
1028 }
1029
1030 return count, nil
1031}
1032
1033// change-id dependent_on
1034//
1035// 4 w ,-------- at_uri(z) (TOP)
1036// 3 z <----',------- at_uri(y)
1037// 2 y <-----',------ at_uri(x)
1038// 1 x <------' nil (BOT)
1039//
1040// `w` has no dependents, so it is the top of the stack
1041//
1042// this unfortunately does a db query for *each* pull of the stack,
1043// ideally this would be a recursive query, but in the interest of implementation simplicity,
1044// we took the less performant route
1045//
1046// TODO: make this less bad
1047func GetStack(e Execer, atUri syntax.ATURI) (models.Stack, error) {
1048 // first get the pull for the given at-uri
1049 pull, err := GetPull(e, orm.FilterEq("at_uri", atUri))
1050 if err != nil {
1051 return nil, err
1052 }
1053
1054 // Collect all pulls in the stack by traversing up and down
1055 allPulls := []*models.Pull{pull}
1056 visited := sets.New[syntax.ATURI]()
1057
1058 // Traverse up to find all dependents
1059 current := pull
1060 for {
1061 dependent, err := GetPull(e,
1062 orm.FilterEq("dependent_on", current.AtUri()),
1063 orm.FilterNotEq("state", models.PullAbandoned),
1064 )
1065 if err != nil || dependent == nil {
1066 break
1067 }
1068 if visited.Contains(dependent.AtUri()) {
1069 return allPulls, fmt.Errorf("circular dependency detected in stack")
1070 }
1071 allPulls = append(allPulls, dependent)
1072 visited.Insert(dependent.AtUri())
1073 current = dependent
1074 }
1075
1076 // Traverse down to find all dependencies
1077 current = pull
1078 for current.DependentOn != nil {
1079 dependency, err := GetPull(
1080 e,
1081 orm.FilterEq("at_uri", current.DependentOn),
1082 orm.FilterNotEq("state", models.PullAbandoned),
1083 )
1084
1085 if err != nil {
1086 return allPulls, fmt.Errorf("failed to find parent pull request, stack is malformed, missing PR: %s", current.DependentOn)
1087 }
1088 if visited.Contains(dependency.AtUri()) {
1089 return allPulls, fmt.Errorf("circular dependency detected in stack")
1090 }
1091 allPulls = append(allPulls, dependency)
1092 visited.Insert(dependency.AtUri())
1093 current = dependency
1094 }
1095
1096 // sort the list: find the top and build ordered list
1097 atUriMap := make(map[syntax.ATURI]*models.Pull, len(allPulls))
1098 dependentMap := make(map[syntax.ATURI]*models.Pull, len(allPulls))
1099
1100 for _, p := range allPulls {
1101 atUriMap[p.AtUri()] = p
1102 if p.DependentOn != nil {
1103 dependentMap[*p.DependentOn] = p
1104 }
1105 }
1106
1107 // the top of the stack is the pull that no other pull depends on
1108 var topPull *models.Pull
1109 for _, maybeTop := range allPulls {
1110 if _, ok := dependentMap[maybeTop.AtUri()]; !ok {
1111 topPull = maybeTop
1112 break
1113 }
1114 }
1115
1116 pulls := []*models.Pull{}
1117 for {
1118 pulls = append(pulls, topPull)
1119 if topPull.DependentOn != nil {
1120 if next, ok := atUriMap[*topPull.DependentOn]; ok {
1121 topPull = next
1122 } else {
1123 return pulls, fmt.Errorf("failed to find parent pull request, stack is malformed")
1124 }
1125 } else {
1126 break
1127 }
1128 }
1129
1130 return pulls, nil
1131}
1132
1133func GetAbandonedPulls(e Execer, atUri syntax.ATURI) ([]*models.Pull, error) {
1134 stack, err := GetStack(e, atUri)
1135 if err != nil {
1136 return nil, err
1137 }
1138
1139 var abandoned []*models.Pull
1140 for _, p := range stack {
1141 if p.State == models.PullAbandoned {
1142 abandoned = append(abandoned, p)
1143 }
1144 }
1145
1146 return abandoned, nil
1147}