From 3942272c3039b16e4731855641887f2442a2dc79 Mon Sep 17 00:00:00 2001 From: Ophestra Date: Fri, 17 Apr 2026 18:56:43 +0900 Subject: [PATCH] internal/pkg: fine-grained cancellation This enables a specific artifact to be targeted for cancellation. Signed-off-by: Ophestra --- internal/pkg/net_test.go | 28 ++++--------- internal/pkg/pkg.go | 87 +++++++++++++++++++++++++++++++--------- internal/pkg/pkg_test.go | 52 ++++++++++++++++++++---- 3 files changed, 121 insertions(+), 46 deletions(-) diff --git a/internal/pkg/net_test.go b/internal/pkg/net_test.go index f525efbd..23c39356 100644 --- a/internal/pkg/net_test.go +++ b/internal/pkg/net_test.go @@ -8,7 +8,6 @@ import ( "testing" "testing/fstest" "unique" - "unsafe" "hakurei.app/check" "hakurei.app/internal/pkg" @@ -33,20 +32,14 @@ func TestHTTPGet(t *testing.T) { checkWithCache(t, []cacheTestCase{ {"direct", pkg.CValidateKnown, nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) { - var r pkg.RContext - rCacheVal := reflect.ValueOf(&r).Elem().FieldByName("cache") - reflect.NewAt( - rCacheVal.Type(), - unsafe.Pointer(rCacheVal.UnsafeAddr()), - ).Elem().Set(reflect.ValueOf(c)) - + r := newRContext(t, c) f := pkg.NewHTTPGet( &client, "file:///testdata", testdataChecksum.Value(), ) var got []byte - if rc, err := f.Cure(&r); err != nil { + if rc, err := f.Cure(r); err != nil { t.Fatalf("Cure: error = %v", err) } else if got, err = io.ReadAll(rc); err != nil { t.Fatalf("ReadAll: error = %v", err) @@ -65,7 +58,7 @@ func TestHTTPGet(t *testing.T) { wantErrMismatch := &pkg.ChecksumMismatchError{ Got: testdataChecksum.Value(), } - if rc, err := f.Cure(&r); err != nil { + if rc, err := f.Cure(r); err != nil { t.Fatalf("Cure: error = %v", err) } else if got, err = io.ReadAll(rc); err != nil { t.Fatalf("ReadAll: error = %v", err) @@ -76,7 +69,7 @@ func TestHTTPGet(t *testing.T) { } // check fallback validation - if rc, err := f.Cure(&r); err != nil { + if rc, err := f.Cure(r); err != nil { t.Fatalf("Cure: error = %v", err) } else if err = rc.Close(); !reflect.DeepEqual(err, wantErrMismatch) { t.Fatalf("Close: error = %#v, want %#v", err, wantErrMismatch) @@ -89,18 +82,13 @@ func TestHTTPGet(t *testing.T) { pkg.Checksum{}, ) wantErrNotFound := pkg.ResponseStatusError(http.StatusNotFound) - if _, err := f.Cure(&r); !reflect.DeepEqual(err, wantErrNotFound) { + if _, err := f.Cure(r); !reflect.DeepEqual(err, wantErrNotFound) { t.Fatalf("Cure: error = %#v, want %#v", err, wantErrNotFound) } }, pkg.MustDecode("E4vEZKhCcL2gPZ2Tt59FS3lDng-d_2SKa2i5G_RbDfwGn6EemptFaGLPUDiOa94C")}, {"cure", pkg.CValidateKnown, nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) { - var r pkg.RContext - rCacheVal := reflect.ValueOf(&r).Elem().FieldByName("cache") - reflect.NewAt( - rCacheVal.Type(), - unsafe.Pointer(rCacheVal.UnsafeAddr()), - ).Elem().Set(reflect.ValueOf(c)) + r := newRContext(t, c) f := pkg.NewHTTPGet( &client, @@ -120,7 +108,7 @@ func TestHTTPGet(t *testing.T) { } var got []byte - if rc, err := f.Cure(&r); err != nil { + if rc, err := f.Cure(r); err != nil { t.Fatalf("Cure: error = %v", err) } else if got, err = io.ReadAll(rc); err != nil { t.Fatalf("ReadAll: error = %v", err) @@ -136,7 +124,7 @@ func TestHTTPGet(t *testing.T) { "file:///testdata", testdataChecksum.Value(), ) - if rc, err := f.Cure(&r); err != nil { + if rc, err := f.Cure(r); err != nil { t.Fatalf("Cure: error = %v", err) } else if got, err = io.ReadAll(rc); err != nil { t.Fatalf("ReadAll: error = %v", err) diff --git a/internal/pkg/pkg.go b/internal/pkg/pkg.go index fea3147f..4fc2d3b2 100644 --- a/internal/pkg/pkg.go +++ b/internal/pkg/pkg.go @@ -72,6 +72,10 @@ func MustDecode(s string) (checksum Checksum) { // common holds elements and receives methods shared between different contexts. type common struct { + // Context specific to this [Artifact]. The toplevel context in [Cache] must + // not be exposed directly. + ctx context.Context + // Address of underlying [Cache], should be zeroed or made unusable after // Cure returns and must not be exposed directly. cache *Cache @@ -183,7 +187,7 @@ func (t *TContext) destroy(errP *error) { } // Unwrap returns the underlying [context.Context]. -func (c *common) Unwrap() context.Context { return c.cache.ctx } +func (c *common) Unwrap() context.Context { return c.ctx } // GetMessage returns [message.Msg] held by the underlying [Cache]. func (c *common) GetMessage() message.Msg { return c.cache.msg } @@ -211,7 +215,7 @@ func (t *TContext) GetTempDir() *check.Absolute { return t.temp } // [ChecksumMismatchError], or the underlying implementation may block on Close. func (c *common) Open(a Artifact) (r io.ReadCloser, err error) { if f, ok := a.(FileArtifact); ok { - return c.cache.openFile(f) + return c.cache.openFile(c.ctx, f) } var pathname *check.Absolute @@ -376,6 +380,9 @@ type KnownChecksum interface { } // FileArtifact refers to an [Artifact] backed by a single file. +// +// FileArtifact does not support fine-grained cancellation. Its context is +// inherited from the first [TrivialArtifact] or [FloodArtifact] that opens it. type FileArtifact interface { // Cure returns [io.ReadCloser] of the full contents of [FileArtifact]. If // [FileArtifact] implements [KnownChecksum], Cure is responsible for @@ -535,6 +542,14 @@ const ( CHostAbstract ) +// pendingCure provides synchronisation and cancellation for pending cures. +type pendingCure struct { + // Closed on cure completion. + done <-chan struct{} + // Cancels the corresponding cure. + cancel context.CancelFunc +} + // Cache is a support layer that implementations of [Artifact] can use to store // cured [Artifact] data in a content addressed fashion. type Cache struct { @@ -570,7 +585,7 @@ type Cache struct { // Identifier to error pair for unrecoverably faulted [Artifact]. identErr map[unique.Handle[ID]]error // Pending identifiers, accessed through Cure for entries not in ident. - identPending map[unique.Handle[ID]]<-chan struct{} + identPending map[unique.Handle[ID]]*pendingCure // Synchronises access to ident and corresponding filesystem entries. identMu sync.RWMutex @@ -1007,6 +1022,7 @@ func (c *Cache) Scrub(checks int) error { // wait for a pending [Artifact] to cure. If neither is possible, the current // identifier is stored in identPending and a non-nil channel is returned. func (c *Cache) loadOrStoreIdent(id unique.Handle[ID]) ( + ctx context.Context, done chan<- struct{}, checksum unique.Handle[Checksum], err error, @@ -1023,10 +1039,10 @@ func (c *Cache) loadOrStoreIdent(id unique.Handle[ID]) ( return } - var notify <-chan struct{} - if notify, ok = c.identPending[id]; ok { + var pending *pendingCure + if pending, ok = c.identPending[id]; ok { c.identMu.Unlock() - <-notify + <-pending.done c.identMu.RLock() if checksum, ok = c.ident[id]; !ok { err = c.identErr[id] @@ -1036,7 +1052,9 @@ func (c *Cache) loadOrStoreIdent(id unique.Handle[ID]) ( } d := make(chan struct{}) - c.identPending[id] = d + pending = &pendingCure{done: d} + ctx, pending.cancel = context.WithCancel(c.ctx) + c.identPending[id] = pending c.identMu.Unlock() done = d return @@ -1062,11 +1080,43 @@ func (c *Cache) finaliseIdent( close(done) } +// Done returns a channel that is closed when the ongoing cure of an [Artifact] +// referred to by the specified identifier completes. Done may return nil if +// no ongoing cure of the specified identifier exists. +func (c *Cache) Done(id unique.Handle[ID]) <-chan struct{} { + c.identMu.RLock() + pending, ok := c.identPending[id] + c.identMu.RUnlock() + if !ok || pending == nil { + return nil + } + return pending.done +} + +// Cancel cancels the ongoing cure of an [Artifact] referred to by the specified +// identifier. Cancel returns whether the [context.CancelFunc] has been killed. +// Cancel does not wait for the cure to complete. +func (c *Cache) Cancel(id unique.Handle[ID]) bool { + c.identMu.RLock() + pending, ok := c.identPending[id] + c.identMu.RUnlock() + if !ok || pending == nil || pending.cancel == nil { + return false + } + pending.cancel() + return true +} + // openFile tries to load [FileArtifact] from [Cache], and if that fails, // obtains it via [FileArtifact.Cure] instead. Notably, it does not cure // [FileArtifact] to the filesystem. If err is nil, the caller is responsible // for closing the resulting [io.ReadCloser]. -func (c *Cache) openFile(f FileArtifact) (r io.ReadCloser, err error) { +// +// The context must originate from loadOrStoreIdent to enable cancellation. +func (c *Cache) openFile( + ctx context.Context, + f FileArtifact, +) (r io.ReadCloser, err error) { if kc, ok := f.(KnownChecksum); c.flags&CAssumeChecksum != 0 && ok { c.checksumMu.RLock() r, err = os.Open(c.base.Append( @@ -1097,7 +1147,7 @@ func (c *Cache) openFile(f FileArtifact) (r io.ReadCloser, err error) { } }() } - return f.Cure(&RContext{common{c}}) + return f.Cure(&RContext{common{ctx, c}}) } return } @@ -1245,12 +1295,8 @@ func (c *Cache) Cure(a Artifact) ( checksum unique.Handle[Checksum], err error, ) { - select { - case <-c.ctx.Done(): - err = c.ctx.Err() + if err = c.ctx.Err(); err != nil { return - - default: } return c.cure(a, true) @@ -1461,8 +1507,11 @@ func (c *Cache) cure(a Artifact, curesExempt bool) ( } }() - var done chan<- struct{} - done, checksum, err = c.loadOrStoreIdent(id) + var ( + ctx context.Context + done chan<- struct{} + ) + ctx, done, checksum, err = c.loadOrStoreIdent(id) if done == nil { return } else { @@ -1575,7 +1624,7 @@ func (c *Cache) cure(a Artifact, curesExempt bool) ( if err = c.enterCure(a, curesExempt); err != nil { return } - r, err = f.Cure(&RContext{common{c}}) + r, err = f.Cure(&RContext{common{ctx, c}}) if err == nil { if checksumPathname == nil || c.flags&CValidateKnown != 0 { h := sha512.New384() @@ -1655,7 +1704,7 @@ func (c *Cache) cure(a Artifact, curesExempt bool) ( c.base.Append(dirWork, ids), c.base.Append(dirTemp, ids), ids, nil, nil, nil, - common{c}, + common{ctx, c}, } switch ca := a.(type) { case TrivialArtifact: @@ -1878,7 +1927,7 @@ func open( ident: make(map[unique.Handle[ID]]unique.Handle[Checksum]), identErr: make(map[unique.Handle[ID]]error), - identPending: make(map[unique.Handle[ID]]<-chan struct{}), + identPending: make(map[unique.Handle[ID]]*pendingCure), brPool: sync.Pool{New: func() any { return new(bufio.Reader) }}, bwPool: sync.Pool{New: func() any { return new(bufio.Writer) }}, diff --git a/internal/pkg/pkg_test.go b/internal/pkg/pkg_test.go index 387f7b31..dbaa0801 100644 --- a/internal/pkg/pkg_test.go +++ b/internal/pkg/pkg_test.go @@ -40,6 +40,23 @@ func unsafeOpen( lock bool, ) (*pkg.Cache, error) +// newRContext returns the address of a new [pkg.RContext] unsafely created for +// the specified [testing.TB]. +func newRContext(tb testing.TB, c *pkg.Cache) *pkg.RContext { + var r pkg.RContext + rContextVal := reflect.ValueOf(&r).Elem().FieldByName("ctx") + reflect.NewAt( + rContextVal.Type(), + unsafe.Pointer(rContextVal.UnsafeAddr()), + ).Elem().Set(reflect.ValueOf(tb.Context())) + rCacheVal := reflect.ValueOf(&r).Elem().FieldByName("cache") + reflect.NewAt( + rCacheVal.Type(), + unsafe.Pointer(rCacheVal.UnsafeAddr()), + ).Elem().Set(reflect.ValueOf(c)) + return &r +} + func TestMain(m *testing.M) { container.TryArgv0(nil); os.Exit(m.Run()) } // overrideIdent overrides the ID method of [Artifact]. @@ -876,17 +893,38 @@ func TestCache(t *testing.T) { t.Fatalf("Scrub: error = %#v, want %#v", err, wantErrScrub) } - identPendingVal := reflect.ValueOf(c).Elem().FieldByName("identPending") - identPending := reflect.NewAt( - identPendingVal.Type(), - unsafe.Pointer(identPendingVal.UnsafeAddr()), - ).Elem().Interface().(map[unique.Handle[pkg.ID]]<-chan struct{}) - notify := identPending[unique.Make(pkg.ID{0xff})] + notify := c.Done(unique.Make(pkg.ID{0xff})) go close(n) - <-notify + if notify != nil { + <-notify + } + for c.Done(unique.Make(pkg.ID{0xff})) != nil { + } <-wCureDone }, pkg.MustDecode("E4vEZKhCcL2gPZ2Tt59FS3lDng-d_2SKa2i5G_RbDfwGn6EemptFaGLPUDiOa94C")}, + {"cancel hanging", pkg.CValidateKnown, nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) { + started := make(chan struct{}) + go func() { + <-started + if !c.Cancel(unique.Make(pkg.ID{0xff})) { + panic("missed cancellation") + } + }() + if _, _, err := c.Cure(overrideIdent{pkg.ID{0xff}, &stubArtifact{ + kind: pkg.KindTar, + cure: func(t *pkg.TContext) error { + close(started) + <-t.Unwrap().Done() + return stub.UniqueError(0xbad) + }, + }}); !reflect.DeepEqual(err, stub.UniqueError(0xbad)) { + t.Fatalf("Cure: error = %v", err) + } + for c.Cancel(unique.Make(pkg.ID{0xff})) { + } + }, pkg.MustDecode("E4vEZKhCcL2gPZ2Tt59FS3lDng-d_2SKa2i5G_RbDfwGn6EemptFaGLPUDiOa94C")}, + {"no assume checksum", 0, nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) { makeGarbage := func(work *check.Absolute, wantErr error) error { if err := os.Mkdir(work.String(), 0700); err != nil {