From 8d657b6fdf5570e3c0324ec5086ee8aad88171ac Mon Sep 17 00:00:00 2001 From: Ophestra Date: Fri, 17 Apr 2026 21:30:00 +0900 Subject: [PATCH] cmd/mbf: cancel remote cure This exposes the new fine-grained cancel API in cmd/mbf. Signed-off-by: Ophestra --- cmd/mbf/daemon.go | 152 +++++++++++++++++++++++++++++++++++++++-- cmd/mbf/daemon_test.go | 24 +++++-- cmd/mbf/main.go | 13 +++- 3 files changed, 177 insertions(+), 12 deletions(-) diff --git a/cmd/mbf/daemon.go b/cmd/mbf/daemon.go index 4a5f7d6d..cadbf57e 100644 --- a/cmd/mbf/daemon.go +++ b/cmd/mbf/daemon.go @@ -6,11 +6,14 @@ import ( "errors" "io" "log" + "math" "net" "os" "sync" + "syscall" "testing" "time" + "unique" "hakurei.app/check" "hakurei.app/internal/pkg" @@ -63,6 +66,42 @@ func cureFromIR( return a, errors.Join(err, conn.Close()) } +const ( + // specialCancel is a message consisting of a single identifier referring + // to a curing artifact to be cancelled. + specialCancel = iota + + // remoteSpecial denotes a special message with custom layout. + remoteSpecial = math.MaxUint64 +) + +// writeSpecialHeader writes the header of a remoteSpecial message. +func writeSpecialHeader(conn net.Conn, kind uint64) error { + var sh [16]byte + binary.LittleEndian.PutUint64(sh[:], remoteSpecial) + binary.LittleEndian.PutUint64(sh[8:], kind) + if n, err := conn.Write(sh[:]); err != nil { + return err + } else if n != len(sh) { + return io.ErrShortWrite + } + return nil +} + +// cancelIdent reads an identifier from conn and cancels the corresponding cure. +func cancelIdent( + cache *pkg.Cache, + conn net.Conn, +) (*pkg.ID, bool, error) { + var ident pkg.ID + if _, err := io.ReadFull(conn, ident[:]); err != nil { + return nil, false, errors.Join(err, conn.Close()) + } else if err = conn.Close(); err != nil { + return nil, false, err + } + return &ident, cache.Cancel(unique.Make(ident)), nil +} + // serve services connections from a [net.UnixListener]. func serve( ctx context.Context, @@ -94,7 +133,18 @@ func serve( continue } wg.Go(func() { - go func() { <-ctx.Done(); _ = conn.SetDeadline(time.Now()) }() + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-ctx.Done(): + _ = conn.SetDeadline(time.Now()) + + case <-done: + return + } + }() + if _err := conn.SetReadDeadline(daemonDeadline()); _err != nil { log.Println(_err) if _err = conn.Close(); _err != nil { @@ -103,15 +153,46 @@ func serve( return } - var flagsWire [8]byte - if _, _err := io.ReadFull(conn, flagsWire[:]); _err != nil { + var word [8]byte + if _, _err := io.ReadFull(conn, word[:]); _err != nil { log.Println(_err) if _err = conn.Close(); _err != nil { log.Println(_err) } return } - flags := binary.LittleEndian.Uint64(flagsWire[:]) + flags := binary.LittleEndian.Uint64(word[:]) + + if flags == remoteSpecial { + if _, _err := io.ReadFull(conn, word[:]); _err != nil { + log.Println(_err) + if _err = conn.Close(); _err != nil { + log.Println(_err) + } + return + } + switch special := binary.LittleEndian.Uint64(word[:]); special { + default: + log.Printf("invalid special %d", special) + + case specialCancel: + if id, ok, _err := cancelIdent(cm.c, conn); _err != nil { + log.Println(_err) + } else if !ok { + log.Println( + "attempting to cancel invalid artifact", + pkg.Encode(*id), + ) + } else { + log.Println( + "cancelled artifact", + pkg.Encode(*id), + ) + } + } + + return + } if a, _err := cureFromIR(cm.c, conn, flags); _err != nil { log.Println(_err) @@ -133,6 +214,31 @@ func serve( return ul.Close() } +// dial wraps [net.DialUnix] with a context. +func dial(ctx context.Context, addr *net.UnixAddr) ( + done chan<- struct{}, + conn *net.UnixConn, + err error, +) { + conn, err = net.DialUnix("unix", nil, addr) + if err != nil { + return + } + + d := make(chan struct{}) + done = d + go func() { + select { + case <-ctx.Done(): + _ = conn.SetDeadline(time.Now()) + + case <-d: + return + } + }() + return +} + // cureRemote cures a [pkg.Artifact] on a daemon. func cureRemote( ctx context.Context, @@ -140,11 +246,15 @@ func cureRemote( a pkg.Artifact, flags uint64, ) (*check.Absolute, error) { - conn, err := net.DialUnix("unix", nil, addr) + if flags == remoteSpecial { + return nil, syscall.EINVAL + } + + done, conn, err := dial(ctx, addr) if err != nil { return nil, err } - go func() { <-ctx.Done(); _ = conn.SetDeadline(time.Now()) }() + defer close(done) if n, flagErr := conn.Write(binary.LittleEndian.AppendUint64(nil, flags)); flagErr != nil { return nil, errors.Join(flagErr, conn.Close()) @@ -165,7 +275,9 @@ func cureRemote( payload, recvErr := io.ReadAll(conn) if err = errors.Join(recvErr, conn.Close()); err != nil { if errors.Is(err, os.ErrDeadlineExceeded) { - err = ctx.Err() + if cancelErr := ctx.Err(); cancelErr != nil { + err = cancelErr + } } return nil, err } @@ -178,3 +290,29 @@ func cureRemote( p, err = check.NewAbs(string(payload)) return p, err } + +// cureRemote cancels a [pkg.Artifact] curing on a daemon. +func cancelRemote( + ctx context.Context, + addr *net.UnixAddr, + a pkg.Artifact, +) error { + done, conn, err := dial(ctx, addr) + if err != nil { + return err + } + defer close(done) + + if err = writeSpecialHeader(conn, specialCancel); err != nil { + return errors.Join(err, conn.Close()) + } + + var n int + id := pkg.NewIR().Ident(a).Value() + if n, err = conn.Write(id[:]); err != nil { + return errors.Join(err, conn.Close()) + } else if n != len(id) { + return errors.Join(io.ErrShortWrite, conn.Close()) + } + return conn.Close() +} diff --git a/cmd/mbf/daemon_test.go b/cmd/mbf/daemon_test.go index 6fd95a96..3306eea3 100644 --- a/cmd/mbf/daemon_test.go +++ b/cmd/mbf/daemon_test.go @@ -9,6 +9,8 @@ import ( "net" "os" "path/filepath" + "slices" + "strings" "testing" "time" @@ -104,11 +106,17 @@ func TestDaemon(t *testing.T) { } }() + if err = cancelRemote(ctx, &addr, pkg.NewFile("nonexistent", nil)); err != nil { + t.Fatalf("cancelRemote: error = %v", err) + } + + // keep this last for synchronisation var p *check.Absolute p, err = cureRemote(ctx, &addr, pkg.NewFile("check", []byte{0}), 0) if err != nil { t.Fatalf("cureRemote: error = %v", err) } + cancel() <-done @@ -117,9 +125,17 @@ func TestDaemon(t *testing.T) { t.Errorf("cureRemote: %s, want %s", got, want) } - const wantLog = `daemon: fulfilled artifact fiZf-ZY_Yq6qxJNrHbMiIPYCsGkUiKCRsZrcSELXTqZWtCnESlHmzV5ThhWWGGYG -` - if gotLog := buf.String(); gotLog != wantLog { - t.Errorf("serve: logged\n%s\nwant\n%s", gotLog, wantLog) + wantLog := []string{ + "", + "daemon: attempting to cancel invalid artifact kQm9fmnCmXST1-MMmxzcau2oKZCXXrlZydo4PkeV5hO_2PKfeC8t98hrbV_ZZx_j", + "daemon: fulfilled artifact fiZf-ZY_Yq6qxJNrHbMiIPYCsGkUiKCRsZrcSELXTqZWtCnESlHmzV5ThhWWGGYG", + } + gotLog := strings.Split(buf.String(), "\n") + slices.Sort(gotLog) + if !slices.Equal(gotLog, wantLog) { + t.Errorf( + "serve: logged\n%s\nwant\n%s", + strings.Join(gotLog, "\n"), strings.Join(wantLog, "\n"), + ) } } diff --git a/cmd/mbf/main.go b/cmd/mbf/main.go index bee97b01..4369565a 100644 --- a/cmd/mbf/main.go +++ b/cmd/mbf/main.go @@ -479,10 +479,21 @@ func main() { if flagNoReply { flags |= remoteNoReply } - pathname, err := cureRemote(ctx, &addr, rosa.Std.Load(p), flags) + a := rosa.Std.Load(p) + pathname, err := cureRemote(ctx, &addr, a, flags) if !flagNoReply && err == nil { log.Println(pathname) } + + if errors.Is(err, context.Canceled) { + cc, cancel := context.WithDeadline(context.Background(), daemonDeadline()) + defer cancel() + + if _err := cancelRemote(cc, &addr, a); _err != nil { + log.Println(err) + } + } + return err } },