cmd/mbf: cancel remote cure
All checks were successful
Test / Create distribution (push) Successful in 1m3s
Test / Sandbox (push) Successful in 2m43s
Test / Hakurei (push) Successful in 3m49s
Test / ShareFS (push) Successful in 3m50s
Test / Sandbox (race detector) (push) Successful in 5m20s
Test / Hakurei (race detector) (push) Successful in 6m19s
Test / Flake checks (push) Successful in 1m22s

This exposes the new fine-grained cancel API in cmd/mbf.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
2026-04-17 21:30:00 +09:00
parent ae9b9adfd2
commit 8d657b6fdf
3 changed files with 177 additions and 12 deletions

View File

@@ -6,11 +6,14 @@ import (
"errors" "errors"
"io" "io"
"log" "log"
"math"
"net" "net"
"os" "os"
"sync" "sync"
"syscall"
"testing" "testing"
"time" "time"
"unique"
"hakurei.app/check" "hakurei.app/check"
"hakurei.app/internal/pkg" "hakurei.app/internal/pkg"
@@ -63,6 +66,42 @@ func cureFromIR(
return a, errors.Join(err, conn.Close()) return a, errors.Join(err, conn.Close())
} }
const (
// specialCancel is a message consisting of a single identifier referring
// to a curing artifact to be cancelled.
specialCancel = iota
// remoteSpecial denotes a special message with custom layout.
remoteSpecial = math.MaxUint64
)
// writeSpecialHeader writes the header of a remoteSpecial message.
func writeSpecialHeader(conn net.Conn, kind uint64) error {
var sh [16]byte
binary.LittleEndian.PutUint64(sh[:], remoteSpecial)
binary.LittleEndian.PutUint64(sh[8:], kind)
if n, err := conn.Write(sh[:]); err != nil {
return err
} else if n != len(sh) {
return io.ErrShortWrite
}
return nil
}
// cancelIdent reads an identifier from conn and cancels the corresponding cure.
func cancelIdent(
cache *pkg.Cache,
conn net.Conn,
) (*pkg.ID, bool, error) {
var ident pkg.ID
if _, err := io.ReadFull(conn, ident[:]); err != nil {
return nil, false, errors.Join(err, conn.Close())
} else if err = conn.Close(); err != nil {
return nil, false, err
}
return &ident, cache.Cancel(unique.Make(ident)), nil
}
// serve services connections from a [net.UnixListener]. // serve services connections from a [net.UnixListener].
func serve( func serve(
ctx context.Context, ctx context.Context,
@@ -94,7 +133,18 @@ func serve(
continue continue
} }
wg.Go(func() { wg.Go(func() {
go func() { <-ctx.Done(); _ = conn.SetDeadline(time.Now()) }() done := make(chan struct{})
defer close(done)
go func() {
select {
case <-ctx.Done():
_ = conn.SetDeadline(time.Now())
case <-done:
return
}
}()
if _err := conn.SetReadDeadline(daemonDeadline()); _err != nil { if _err := conn.SetReadDeadline(daemonDeadline()); _err != nil {
log.Println(_err) log.Println(_err)
if _err = conn.Close(); _err != nil { if _err = conn.Close(); _err != nil {
@@ -103,15 +153,46 @@ func serve(
return return
} }
var flagsWire [8]byte var word [8]byte
if _, _err := io.ReadFull(conn, flagsWire[:]); _err != nil { if _, _err := io.ReadFull(conn, word[:]); _err != nil {
log.Println(_err) log.Println(_err)
if _err = conn.Close(); _err != nil { if _err = conn.Close(); _err != nil {
log.Println(_err) log.Println(_err)
} }
return return
} }
flags := binary.LittleEndian.Uint64(flagsWire[:]) flags := binary.LittleEndian.Uint64(word[:])
if flags == remoteSpecial {
if _, _err := io.ReadFull(conn, word[:]); _err != nil {
log.Println(_err)
if _err = conn.Close(); _err != nil {
log.Println(_err)
}
return
}
switch special := binary.LittleEndian.Uint64(word[:]); special {
default:
log.Printf("invalid special %d", special)
case specialCancel:
if id, ok, _err := cancelIdent(cm.c, conn); _err != nil {
log.Println(_err)
} else if !ok {
log.Println(
"attempting to cancel invalid artifact",
pkg.Encode(*id),
)
} else {
log.Println(
"cancelled artifact",
pkg.Encode(*id),
)
}
}
return
}
if a, _err := cureFromIR(cm.c, conn, flags); _err != nil { if a, _err := cureFromIR(cm.c, conn, flags); _err != nil {
log.Println(_err) log.Println(_err)
@@ -133,6 +214,31 @@ func serve(
return ul.Close() return ul.Close()
} }
// dial wraps [net.DialUnix] with a context.
func dial(ctx context.Context, addr *net.UnixAddr) (
done chan<- struct{},
conn *net.UnixConn,
err error,
) {
conn, err = net.DialUnix("unix", nil, addr)
if err != nil {
return
}
d := make(chan struct{})
done = d
go func() {
select {
case <-ctx.Done():
_ = conn.SetDeadline(time.Now())
case <-d:
return
}
}()
return
}
// cureRemote cures a [pkg.Artifact] on a daemon. // cureRemote cures a [pkg.Artifact] on a daemon.
func cureRemote( func cureRemote(
ctx context.Context, ctx context.Context,
@@ -140,11 +246,15 @@ func cureRemote(
a pkg.Artifact, a pkg.Artifact,
flags uint64, flags uint64,
) (*check.Absolute, error) { ) (*check.Absolute, error) {
conn, err := net.DialUnix("unix", nil, addr) if flags == remoteSpecial {
return nil, syscall.EINVAL
}
done, conn, err := dial(ctx, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go func() { <-ctx.Done(); _ = conn.SetDeadline(time.Now()) }() defer close(done)
if n, flagErr := conn.Write(binary.LittleEndian.AppendUint64(nil, flags)); flagErr != nil { if n, flagErr := conn.Write(binary.LittleEndian.AppendUint64(nil, flags)); flagErr != nil {
return nil, errors.Join(flagErr, conn.Close()) return nil, errors.Join(flagErr, conn.Close())
@@ -165,7 +275,9 @@ func cureRemote(
payload, recvErr := io.ReadAll(conn) payload, recvErr := io.ReadAll(conn)
if err = errors.Join(recvErr, conn.Close()); err != nil { if err = errors.Join(recvErr, conn.Close()); err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) { if errors.Is(err, os.ErrDeadlineExceeded) {
err = ctx.Err() if cancelErr := ctx.Err(); cancelErr != nil {
err = cancelErr
}
} }
return nil, err return nil, err
} }
@@ -178,3 +290,29 @@ func cureRemote(
p, err = check.NewAbs(string(payload)) p, err = check.NewAbs(string(payload))
return p, err return p, err
} }
// cureRemote cancels a [pkg.Artifact] curing on a daemon.
func cancelRemote(
ctx context.Context,
addr *net.UnixAddr,
a pkg.Artifact,
) error {
done, conn, err := dial(ctx, addr)
if err != nil {
return err
}
defer close(done)
if err = writeSpecialHeader(conn, specialCancel); err != nil {
return errors.Join(err, conn.Close())
}
var n int
id := pkg.NewIR().Ident(a).Value()
if n, err = conn.Write(id[:]); err != nil {
return errors.Join(err, conn.Close())
} else if n != len(id) {
return errors.Join(io.ErrShortWrite, conn.Close())
}
return conn.Close()
}

View File

@@ -9,6 +9,8 @@ import (
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"slices"
"strings"
"testing" "testing"
"time" "time"
@@ -104,11 +106,17 @@ func TestDaemon(t *testing.T) {
} }
}() }()
if err = cancelRemote(ctx, &addr, pkg.NewFile("nonexistent", nil)); err != nil {
t.Fatalf("cancelRemote: error = %v", err)
}
// keep this last for synchronisation
var p *check.Absolute var p *check.Absolute
p, err = cureRemote(ctx, &addr, pkg.NewFile("check", []byte{0}), 0) p, err = cureRemote(ctx, &addr, pkg.NewFile("check", []byte{0}), 0)
if err != nil { if err != nil {
t.Fatalf("cureRemote: error = %v", err) t.Fatalf("cureRemote: error = %v", err)
} }
cancel() cancel()
<-done <-done
@@ -117,9 +125,17 @@ func TestDaemon(t *testing.T) {
t.Errorf("cureRemote: %s, want %s", got, want) t.Errorf("cureRemote: %s, want %s", got, want)
} }
const wantLog = `daemon: fulfilled artifact fiZf-ZY_Yq6qxJNrHbMiIPYCsGkUiKCRsZrcSELXTqZWtCnESlHmzV5ThhWWGGYG wantLog := []string{
` "",
if gotLog := buf.String(); gotLog != wantLog { "daemon: attempting to cancel invalid artifact kQm9fmnCmXST1-MMmxzcau2oKZCXXrlZydo4PkeV5hO_2PKfeC8t98hrbV_ZZx_j",
t.Errorf("serve: logged\n%s\nwant\n%s", gotLog, wantLog) "daemon: fulfilled artifact fiZf-ZY_Yq6qxJNrHbMiIPYCsGkUiKCRsZrcSELXTqZWtCnESlHmzV5ThhWWGGYG",
}
gotLog := strings.Split(buf.String(), "\n")
slices.Sort(gotLog)
if !slices.Equal(gotLog, wantLog) {
t.Errorf(
"serve: logged\n%s\nwant\n%s",
strings.Join(gotLog, "\n"), strings.Join(wantLog, "\n"),
)
} }
} }

View File

@@ -479,10 +479,21 @@ func main() {
if flagNoReply { if flagNoReply {
flags |= remoteNoReply flags |= remoteNoReply
} }
pathname, err := cureRemote(ctx, &addr, rosa.Std.Load(p), flags) a := rosa.Std.Load(p)
pathname, err := cureRemote(ctx, &addr, a, flags)
if !flagNoReply && err == nil { if !flagNoReply && err == nil {
log.Println(pathname) log.Println(pathname)
} }
if errors.Is(err, context.Canceled) {
cc, cancel := context.WithDeadline(context.Background(), daemonDeadline())
defer cancel()
if _err := cancelRemote(cc, &addr, a); _err != nil {
log.Println(err)
}
}
return err return err
} }
}, },