internal/pkg: fine-grained cancellation

This enables a specific artifact to be targeted for cancellation.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
2026-04-17 18:56:43 +09:00
parent 9036986156
commit 3942272c30
3 changed files with 121 additions and 46 deletions

View File

@@ -8,7 +8,6 @@ import (
"testing" "testing"
"testing/fstest" "testing/fstest"
"unique" "unique"
"unsafe"
"hakurei.app/check" "hakurei.app/check"
"hakurei.app/internal/pkg" "hakurei.app/internal/pkg"
@@ -33,20 +32,14 @@ func TestHTTPGet(t *testing.T) {
checkWithCache(t, []cacheTestCase{ checkWithCache(t, []cacheTestCase{
{"direct", pkg.CValidateKnown, nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) { {"direct", pkg.CValidateKnown, nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) {
var r pkg.RContext r := newRContext(t, c)
rCacheVal := reflect.ValueOf(&r).Elem().FieldByName("cache")
reflect.NewAt(
rCacheVal.Type(),
unsafe.Pointer(rCacheVal.UnsafeAddr()),
).Elem().Set(reflect.ValueOf(c))
f := pkg.NewHTTPGet( f := pkg.NewHTTPGet(
&client, &client,
"file:///testdata", "file:///testdata",
testdataChecksum.Value(), testdataChecksum.Value(),
) )
var got []byte var got []byte
if rc, err := f.Cure(&r); err != nil { if rc, err := f.Cure(r); err != nil {
t.Fatalf("Cure: error = %v", err) t.Fatalf("Cure: error = %v", err)
} else if got, err = io.ReadAll(rc); err != nil { } else if got, err = io.ReadAll(rc); err != nil {
t.Fatalf("ReadAll: error = %v", err) t.Fatalf("ReadAll: error = %v", err)
@@ -65,7 +58,7 @@ func TestHTTPGet(t *testing.T) {
wantErrMismatch := &pkg.ChecksumMismatchError{ wantErrMismatch := &pkg.ChecksumMismatchError{
Got: testdataChecksum.Value(), Got: testdataChecksum.Value(),
} }
if rc, err := f.Cure(&r); err != nil { if rc, err := f.Cure(r); err != nil {
t.Fatalf("Cure: error = %v", err) t.Fatalf("Cure: error = %v", err)
} else if got, err = io.ReadAll(rc); err != nil { } else if got, err = io.ReadAll(rc); err != nil {
t.Fatalf("ReadAll: error = %v", err) t.Fatalf("ReadAll: error = %v", err)
@@ -76,7 +69,7 @@ func TestHTTPGet(t *testing.T) {
} }
// check fallback validation // check fallback validation
if rc, err := f.Cure(&r); err != nil { if rc, err := f.Cure(r); err != nil {
t.Fatalf("Cure: error = %v", err) t.Fatalf("Cure: error = %v", err)
} else if err = rc.Close(); !reflect.DeepEqual(err, wantErrMismatch) { } else if err = rc.Close(); !reflect.DeepEqual(err, wantErrMismatch) {
t.Fatalf("Close: error = %#v, want %#v", err, wantErrMismatch) t.Fatalf("Close: error = %#v, want %#v", err, wantErrMismatch)
@@ -89,18 +82,13 @@ func TestHTTPGet(t *testing.T) {
pkg.Checksum{}, pkg.Checksum{},
) )
wantErrNotFound := pkg.ResponseStatusError(http.StatusNotFound) wantErrNotFound := pkg.ResponseStatusError(http.StatusNotFound)
if _, err := f.Cure(&r); !reflect.DeepEqual(err, wantErrNotFound) { if _, err := f.Cure(r); !reflect.DeepEqual(err, wantErrNotFound) {
t.Fatalf("Cure: error = %#v, want %#v", err, wantErrNotFound) t.Fatalf("Cure: error = %#v, want %#v", err, wantErrNotFound)
} }
}, pkg.MustDecode("E4vEZKhCcL2gPZ2Tt59FS3lDng-d_2SKa2i5G_RbDfwGn6EemptFaGLPUDiOa94C")}, }, pkg.MustDecode("E4vEZKhCcL2gPZ2Tt59FS3lDng-d_2SKa2i5G_RbDfwGn6EemptFaGLPUDiOa94C")},
{"cure", pkg.CValidateKnown, nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) { {"cure", pkg.CValidateKnown, nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) {
var r pkg.RContext r := newRContext(t, c)
rCacheVal := reflect.ValueOf(&r).Elem().FieldByName("cache")
reflect.NewAt(
rCacheVal.Type(),
unsafe.Pointer(rCacheVal.UnsafeAddr()),
).Elem().Set(reflect.ValueOf(c))
f := pkg.NewHTTPGet( f := pkg.NewHTTPGet(
&client, &client,
@@ -120,7 +108,7 @@ func TestHTTPGet(t *testing.T) {
} }
var got []byte var got []byte
if rc, err := f.Cure(&r); err != nil { if rc, err := f.Cure(r); err != nil {
t.Fatalf("Cure: error = %v", err) t.Fatalf("Cure: error = %v", err)
} else if got, err = io.ReadAll(rc); err != nil { } else if got, err = io.ReadAll(rc); err != nil {
t.Fatalf("ReadAll: error = %v", err) t.Fatalf("ReadAll: error = %v", err)
@@ -136,7 +124,7 @@ func TestHTTPGet(t *testing.T) {
"file:///testdata", "file:///testdata",
testdataChecksum.Value(), testdataChecksum.Value(),
) )
if rc, err := f.Cure(&r); err != nil { if rc, err := f.Cure(r); err != nil {
t.Fatalf("Cure: error = %v", err) t.Fatalf("Cure: error = %v", err)
} else if got, err = io.ReadAll(rc); err != nil { } else if got, err = io.ReadAll(rc); err != nil {
t.Fatalf("ReadAll: error = %v", err) t.Fatalf("ReadAll: error = %v", err)

View File

@@ -72,6 +72,10 @@ func MustDecode(s string) (checksum Checksum) {
// common holds elements and receives methods shared between different contexts. // common holds elements and receives methods shared between different contexts.
type common struct { type common struct {
// Context specific to this [Artifact]. The toplevel context in [Cache] must
// not be exposed directly.
ctx context.Context
// Address of underlying [Cache], should be zeroed or made unusable after // Address of underlying [Cache], should be zeroed or made unusable after
// Cure returns and must not be exposed directly. // Cure returns and must not be exposed directly.
cache *Cache cache *Cache
@@ -183,7 +187,7 @@ func (t *TContext) destroy(errP *error) {
} }
// Unwrap returns the underlying [context.Context]. // Unwrap returns the underlying [context.Context].
func (c *common) Unwrap() context.Context { return c.cache.ctx } func (c *common) Unwrap() context.Context { return c.ctx }
// GetMessage returns [message.Msg] held by the underlying [Cache]. // GetMessage returns [message.Msg] held by the underlying [Cache].
func (c *common) GetMessage() message.Msg { return c.cache.msg } func (c *common) GetMessage() message.Msg { return c.cache.msg }
@@ -211,7 +215,7 @@ func (t *TContext) GetTempDir() *check.Absolute { return t.temp }
// [ChecksumMismatchError], or the underlying implementation may block on Close. // [ChecksumMismatchError], or the underlying implementation may block on Close.
func (c *common) Open(a Artifact) (r io.ReadCloser, err error) { func (c *common) Open(a Artifact) (r io.ReadCloser, err error) {
if f, ok := a.(FileArtifact); ok { if f, ok := a.(FileArtifact); ok {
return c.cache.openFile(f) return c.cache.openFile(c.ctx, f)
} }
var pathname *check.Absolute var pathname *check.Absolute
@@ -376,6 +380,9 @@ type KnownChecksum interface {
} }
// FileArtifact refers to an [Artifact] backed by a single file. // FileArtifact refers to an [Artifact] backed by a single file.
//
// FileArtifact does not support fine-grained cancellation. Its context is
// inherited from the first [TrivialArtifact] or [FloodArtifact] that opens it.
type FileArtifact interface { type FileArtifact interface {
// Cure returns [io.ReadCloser] of the full contents of [FileArtifact]. If // Cure returns [io.ReadCloser] of the full contents of [FileArtifact]. If
// [FileArtifact] implements [KnownChecksum], Cure is responsible for // [FileArtifact] implements [KnownChecksum], Cure is responsible for
@@ -535,6 +542,14 @@ const (
CHostAbstract CHostAbstract
) )
// pendingCure provides synchronisation and cancellation for pending cures.
type pendingCure struct {
// Closed on cure completion.
done <-chan struct{}
// Cancels the corresponding cure.
cancel context.CancelFunc
}
// Cache is a support layer that implementations of [Artifact] can use to store // Cache is a support layer that implementations of [Artifact] can use to store
// cured [Artifact] data in a content addressed fashion. // cured [Artifact] data in a content addressed fashion.
type Cache struct { type Cache struct {
@@ -570,7 +585,7 @@ type Cache struct {
// Identifier to error pair for unrecoverably faulted [Artifact]. // Identifier to error pair for unrecoverably faulted [Artifact].
identErr map[unique.Handle[ID]]error identErr map[unique.Handle[ID]]error
// Pending identifiers, accessed through Cure for entries not in ident. // Pending identifiers, accessed through Cure for entries not in ident.
identPending map[unique.Handle[ID]]<-chan struct{} identPending map[unique.Handle[ID]]*pendingCure
// Synchronises access to ident and corresponding filesystem entries. // Synchronises access to ident and corresponding filesystem entries.
identMu sync.RWMutex identMu sync.RWMutex
@@ -1007,6 +1022,7 @@ func (c *Cache) Scrub(checks int) error {
// wait for a pending [Artifact] to cure. If neither is possible, the current // wait for a pending [Artifact] to cure. If neither is possible, the current
// identifier is stored in identPending and a non-nil channel is returned. // identifier is stored in identPending and a non-nil channel is returned.
func (c *Cache) loadOrStoreIdent(id unique.Handle[ID]) ( func (c *Cache) loadOrStoreIdent(id unique.Handle[ID]) (
ctx context.Context,
done chan<- struct{}, done chan<- struct{},
checksum unique.Handle[Checksum], checksum unique.Handle[Checksum],
err error, err error,
@@ -1023,10 +1039,10 @@ func (c *Cache) loadOrStoreIdent(id unique.Handle[ID]) (
return return
} }
var notify <-chan struct{} var pending *pendingCure
if notify, ok = c.identPending[id]; ok { if pending, ok = c.identPending[id]; ok {
c.identMu.Unlock() c.identMu.Unlock()
<-notify <-pending.done
c.identMu.RLock() c.identMu.RLock()
if checksum, ok = c.ident[id]; !ok { if checksum, ok = c.ident[id]; !ok {
err = c.identErr[id] err = c.identErr[id]
@@ -1036,7 +1052,9 @@ func (c *Cache) loadOrStoreIdent(id unique.Handle[ID]) (
} }
d := make(chan struct{}) d := make(chan struct{})
c.identPending[id] = d pending = &pendingCure{done: d}
ctx, pending.cancel = context.WithCancel(c.ctx)
c.identPending[id] = pending
c.identMu.Unlock() c.identMu.Unlock()
done = d done = d
return return
@@ -1062,11 +1080,43 @@ func (c *Cache) finaliseIdent(
close(done) close(done)
} }
// Done returns a channel that is closed when the ongoing cure of an [Artifact]
// referred to by the specified identifier completes. Done may return nil if
// no ongoing cure of the specified identifier exists.
func (c *Cache) Done(id unique.Handle[ID]) <-chan struct{} {
c.identMu.RLock()
pending, ok := c.identPending[id]
c.identMu.RUnlock()
if !ok || pending == nil {
return nil
}
return pending.done
}
// Cancel cancels the ongoing cure of an [Artifact] referred to by the specified
// identifier. Cancel returns whether the [context.CancelFunc] has been killed.
// Cancel does not wait for the cure to complete.
func (c *Cache) Cancel(id unique.Handle[ID]) bool {
c.identMu.RLock()
pending, ok := c.identPending[id]
c.identMu.RUnlock()
if !ok || pending == nil || pending.cancel == nil {
return false
}
pending.cancel()
return true
}
// openFile tries to load [FileArtifact] from [Cache], and if that fails, // openFile tries to load [FileArtifact] from [Cache], and if that fails,
// obtains it via [FileArtifact.Cure] instead. Notably, it does not cure // obtains it via [FileArtifact.Cure] instead. Notably, it does not cure
// [FileArtifact] to the filesystem. If err is nil, the caller is responsible // [FileArtifact] to the filesystem. If err is nil, the caller is responsible
// for closing the resulting [io.ReadCloser]. // for closing the resulting [io.ReadCloser].
func (c *Cache) openFile(f FileArtifact) (r io.ReadCloser, err error) { //
// The context must originate from loadOrStoreIdent to enable cancellation.
func (c *Cache) openFile(
ctx context.Context,
f FileArtifact,
) (r io.ReadCloser, err error) {
if kc, ok := f.(KnownChecksum); c.flags&CAssumeChecksum != 0 && ok { if kc, ok := f.(KnownChecksum); c.flags&CAssumeChecksum != 0 && ok {
c.checksumMu.RLock() c.checksumMu.RLock()
r, err = os.Open(c.base.Append( r, err = os.Open(c.base.Append(
@@ -1097,7 +1147,7 @@ func (c *Cache) openFile(f FileArtifact) (r io.ReadCloser, err error) {
} }
}() }()
} }
return f.Cure(&RContext{common{c}}) return f.Cure(&RContext{common{ctx, c}})
} }
return return
} }
@@ -1245,12 +1295,8 @@ func (c *Cache) Cure(a Artifact) (
checksum unique.Handle[Checksum], checksum unique.Handle[Checksum],
err error, err error,
) { ) {
select { if err = c.ctx.Err(); err != nil {
case <-c.ctx.Done():
err = c.ctx.Err()
return return
default:
} }
return c.cure(a, true) return c.cure(a, true)
@@ -1461,8 +1507,11 @@ func (c *Cache) cure(a Artifact, curesExempt bool) (
} }
}() }()
var done chan<- struct{} var (
done, checksum, err = c.loadOrStoreIdent(id) ctx context.Context
done chan<- struct{}
)
ctx, done, checksum, err = c.loadOrStoreIdent(id)
if done == nil { if done == nil {
return return
} else { } else {
@@ -1575,7 +1624,7 @@ func (c *Cache) cure(a Artifact, curesExempt bool) (
if err = c.enterCure(a, curesExempt); err != nil { if err = c.enterCure(a, curesExempt); err != nil {
return return
} }
r, err = f.Cure(&RContext{common{c}}) r, err = f.Cure(&RContext{common{ctx, c}})
if err == nil { if err == nil {
if checksumPathname == nil || c.flags&CValidateKnown != 0 { if checksumPathname == nil || c.flags&CValidateKnown != 0 {
h := sha512.New384() h := sha512.New384()
@@ -1655,7 +1704,7 @@ func (c *Cache) cure(a Artifact, curesExempt bool) (
c.base.Append(dirWork, ids), c.base.Append(dirWork, ids),
c.base.Append(dirTemp, ids), c.base.Append(dirTemp, ids),
ids, nil, nil, nil, ids, nil, nil, nil,
common{c}, common{ctx, c},
} }
switch ca := a.(type) { switch ca := a.(type) {
case TrivialArtifact: case TrivialArtifact:
@@ -1878,7 +1927,7 @@ func open(
ident: make(map[unique.Handle[ID]]unique.Handle[Checksum]), ident: make(map[unique.Handle[ID]]unique.Handle[Checksum]),
identErr: make(map[unique.Handle[ID]]error), identErr: make(map[unique.Handle[ID]]error),
identPending: make(map[unique.Handle[ID]]<-chan struct{}), identPending: make(map[unique.Handle[ID]]*pendingCure),
brPool: sync.Pool{New: func() any { return new(bufio.Reader) }}, brPool: sync.Pool{New: func() any { return new(bufio.Reader) }},
bwPool: sync.Pool{New: func() any { return new(bufio.Writer) }}, bwPool: sync.Pool{New: func() any { return new(bufio.Writer) }},

View File

@@ -40,6 +40,23 @@ func unsafeOpen(
lock bool, lock bool,
) (*pkg.Cache, error) ) (*pkg.Cache, error)
// newRContext returns the address of a new [pkg.RContext] unsafely created for
// the specified [testing.TB].
func newRContext(tb testing.TB, c *pkg.Cache) *pkg.RContext {
var r pkg.RContext
rContextVal := reflect.ValueOf(&r).Elem().FieldByName("ctx")
reflect.NewAt(
rContextVal.Type(),
unsafe.Pointer(rContextVal.UnsafeAddr()),
).Elem().Set(reflect.ValueOf(tb.Context()))
rCacheVal := reflect.ValueOf(&r).Elem().FieldByName("cache")
reflect.NewAt(
rCacheVal.Type(),
unsafe.Pointer(rCacheVal.UnsafeAddr()),
).Elem().Set(reflect.ValueOf(c))
return &r
}
func TestMain(m *testing.M) { container.TryArgv0(nil); os.Exit(m.Run()) } func TestMain(m *testing.M) { container.TryArgv0(nil); os.Exit(m.Run()) }
// overrideIdent overrides the ID method of [Artifact]. // overrideIdent overrides the ID method of [Artifact].
@@ -876,17 +893,38 @@ func TestCache(t *testing.T) {
t.Fatalf("Scrub: error = %#v, want %#v", err, wantErrScrub) t.Fatalf("Scrub: error = %#v, want %#v", err, wantErrScrub)
} }
identPendingVal := reflect.ValueOf(c).Elem().FieldByName("identPending") notify := c.Done(unique.Make(pkg.ID{0xff}))
identPending := reflect.NewAt(
identPendingVal.Type(),
unsafe.Pointer(identPendingVal.UnsafeAddr()),
).Elem().Interface().(map[unique.Handle[pkg.ID]]<-chan struct{})
notify := identPending[unique.Make(pkg.ID{0xff})]
go close(n) go close(n)
<-notify if notify != nil {
<-notify
}
for c.Done(unique.Make(pkg.ID{0xff})) != nil {
}
<-wCureDone <-wCureDone
}, pkg.MustDecode("E4vEZKhCcL2gPZ2Tt59FS3lDng-d_2SKa2i5G_RbDfwGn6EemptFaGLPUDiOa94C")}, }, pkg.MustDecode("E4vEZKhCcL2gPZ2Tt59FS3lDng-d_2SKa2i5G_RbDfwGn6EemptFaGLPUDiOa94C")},
{"cancel hanging", pkg.CValidateKnown, nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) {
started := make(chan struct{})
go func() {
<-started
if !c.Cancel(unique.Make(pkg.ID{0xff})) {
panic("missed cancellation")
}
}()
if _, _, err := c.Cure(overrideIdent{pkg.ID{0xff}, &stubArtifact{
kind: pkg.KindTar,
cure: func(t *pkg.TContext) error {
close(started)
<-t.Unwrap().Done()
return stub.UniqueError(0xbad)
},
}}); !reflect.DeepEqual(err, stub.UniqueError(0xbad)) {
t.Fatalf("Cure: error = %v", err)
}
for c.Cancel(unique.Make(pkg.ID{0xff})) {
}
}, pkg.MustDecode("E4vEZKhCcL2gPZ2Tt59FS3lDng-d_2SKa2i5G_RbDfwGn6EemptFaGLPUDiOa94C")},
{"no assume checksum", 0, nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) { {"no assume checksum", 0, nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) {
makeGarbage := func(work *check.Absolute, wantErr error) error { makeGarbage := func(work *check.Absolute, wantErr error) error {
if err := os.Mkdir(work.String(), 0700); err != nil { if err := os.Mkdir(work.String(), 0700); err != nil {