1
0
forked from rosa/hakurei

cmd/mbf: cancel remote cure

This exposes the new fine-grained cancel API in cmd/mbf.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
2026-04-17 21:30:00 +09:00
parent ae9b9adfd2
commit 8d657b6fdf
3 changed files with 177 additions and 12 deletions

View File

@@ -6,11 +6,14 @@ import (
"errors"
"io"
"log"
"math"
"net"
"os"
"sync"
"syscall"
"testing"
"time"
"unique"
"hakurei.app/check"
"hakurei.app/internal/pkg"
@@ -63,6 +66,42 @@ func cureFromIR(
return a, errors.Join(err, conn.Close())
}
const (
// specialCancel is a message consisting of a single identifier referring
// to a curing artifact to be cancelled.
specialCancel = iota
// remoteSpecial denotes a special message with custom layout.
remoteSpecial = math.MaxUint64
)
// writeSpecialHeader writes the header of a remoteSpecial message.
func writeSpecialHeader(conn net.Conn, kind uint64) error {
var sh [16]byte
binary.LittleEndian.PutUint64(sh[:], remoteSpecial)
binary.LittleEndian.PutUint64(sh[8:], kind)
if n, err := conn.Write(sh[:]); err != nil {
return err
} else if n != len(sh) {
return io.ErrShortWrite
}
return nil
}
// cancelIdent reads an identifier from conn and cancels the corresponding cure.
func cancelIdent(
cache *pkg.Cache,
conn net.Conn,
) (*pkg.ID, bool, error) {
var ident pkg.ID
if _, err := io.ReadFull(conn, ident[:]); err != nil {
return nil, false, errors.Join(err, conn.Close())
} else if err = conn.Close(); err != nil {
return nil, false, err
}
return &ident, cache.Cancel(unique.Make(ident)), nil
}
// serve services connections from a [net.UnixListener].
func serve(
ctx context.Context,
@@ -94,7 +133,18 @@ func serve(
continue
}
wg.Go(func() {
go func() { <-ctx.Done(); _ = conn.SetDeadline(time.Now()) }()
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-ctx.Done():
_ = conn.SetDeadline(time.Now())
case <-done:
return
}
}()
if _err := conn.SetReadDeadline(daemonDeadline()); _err != nil {
log.Println(_err)
if _err = conn.Close(); _err != nil {
@@ -103,15 +153,46 @@ func serve(
return
}
var flagsWire [8]byte
if _, _err := io.ReadFull(conn, flagsWire[:]); _err != nil {
var word [8]byte
if _, _err := io.ReadFull(conn, word[:]); _err != nil {
log.Println(_err)
if _err = conn.Close(); _err != nil {
log.Println(_err)
}
return
}
flags := binary.LittleEndian.Uint64(flagsWire[:])
flags := binary.LittleEndian.Uint64(word[:])
if flags == remoteSpecial {
if _, _err := io.ReadFull(conn, word[:]); _err != nil {
log.Println(_err)
if _err = conn.Close(); _err != nil {
log.Println(_err)
}
return
}
switch special := binary.LittleEndian.Uint64(word[:]); special {
default:
log.Printf("invalid special %d", special)
case specialCancel:
if id, ok, _err := cancelIdent(cm.c, conn); _err != nil {
log.Println(_err)
} else if !ok {
log.Println(
"attempting to cancel invalid artifact",
pkg.Encode(*id),
)
} else {
log.Println(
"cancelled artifact",
pkg.Encode(*id),
)
}
}
return
}
if a, _err := cureFromIR(cm.c, conn, flags); _err != nil {
log.Println(_err)
@@ -133,6 +214,31 @@ func serve(
return ul.Close()
}
// dial wraps [net.DialUnix] with a context.
func dial(ctx context.Context, addr *net.UnixAddr) (
done chan<- struct{},
conn *net.UnixConn,
err error,
) {
conn, err = net.DialUnix("unix", nil, addr)
if err != nil {
return
}
d := make(chan struct{})
done = d
go func() {
select {
case <-ctx.Done():
_ = conn.SetDeadline(time.Now())
case <-d:
return
}
}()
return
}
// cureRemote cures a [pkg.Artifact] on a daemon.
func cureRemote(
ctx context.Context,
@@ -140,11 +246,15 @@ func cureRemote(
a pkg.Artifact,
flags uint64,
) (*check.Absolute, error) {
conn, err := net.DialUnix("unix", nil, addr)
if flags == remoteSpecial {
return nil, syscall.EINVAL
}
done, conn, err := dial(ctx, addr)
if err != nil {
return nil, err
}
go func() { <-ctx.Done(); _ = conn.SetDeadline(time.Now()) }()
defer close(done)
if n, flagErr := conn.Write(binary.LittleEndian.AppendUint64(nil, flags)); flagErr != nil {
return nil, errors.Join(flagErr, conn.Close())
@@ -165,7 +275,9 @@ func cureRemote(
payload, recvErr := io.ReadAll(conn)
if err = errors.Join(recvErr, conn.Close()); err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = ctx.Err()
if cancelErr := ctx.Err(); cancelErr != nil {
err = cancelErr
}
}
return nil, err
}
@@ -178,3 +290,29 @@ func cureRemote(
p, err = check.NewAbs(string(payload))
return p, err
}
// cureRemote cancels a [pkg.Artifact] curing on a daemon.
func cancelRemote(
ctx context.Context,
addr *net.UnixAddr,
a pkg.Artifact,
) error {
done, conn, err := dial(ctx, addr)
if err != nil {
return err
}
defer close(done)
if err = writeSpecialHeader(conn, specialCancel); err != nil {
return errors.Join(err, conn.Close())
}
var n int
id := pkg.NewIR().Ident(a).Value()
if n, err = conn.Write(id[:]); err != nil {
return errors.Join(err, conn.Close())
} else if n != len(id) {
return errors.Join(io.ErrShortWrite, conn.Close())
}
return conn.Close()
}

View File

@@ -9,6 +9,8 @@ import (
"net"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"time"
@@ -104,11 +106,17 @@ func TestDaemon(t *testing.T) {
}
}()
if err = cancelRemote(ctx, &addr, pkg.NewFile("nonexistent", nil)); err != nil {
t.Fatalf("cancelRemote: error = %v", err)
}
// keep this last for synchronisation
var p *check.Absolute
p, err = cureRemote(ctx, &addr, pkg.NewFile("check", []byte{0}), 0)
if err != nil {
t.Fatalf("cureRemote: error = %v", err)
}
cancel()
<-done
@@ -117,9 +125,17 @@ func TestDaemon(t *testing.T) {
t.Errorf("cureRemote: %s, want %s", got, want)
}
const wantLog = `daemon: fulfilled artifact fiZf-ZY_Yq6qxJNrHbMiIPYCsGkUiKCRsZrcSELXTqZWtCnESlHmzV5ThhWWGGYG
`
if gotLog := buf.String(); gotLog != wantLog {
t.Errorf("serve: logged\n%s\nwant\n%s", gotLog, wantLog)
wantLog := []string{
"",
"daemon: attempting to cancel invalid artifact kQm9fmnCmXST1-MMmxzcau2oKZCXXrlZydo4PkeV5hO_2PKfeC8t98hrbV_ZZx_j",
"daemon: fulfilled artifact fiZf-ZY_Yq6qxJNrHbMiIPYCsGkUiKCRsZrcSELXTqZWtCnESlHmzV5ThhWWGGYG",
}
gotLog := strings.Split(buf.String(), "\n")
slices.Sort(gotLog)
if !slices.Equal(gotLog, wantLog) {
t.Errorf(
"serve: logged\n%s\nwant\n%s",
strings.Join(gotLog, "\n"), strings.Join(wantLog, "\n"),
)
}
}

View File

@@ -479,10 +479,21 @@ func main() {
if flagNoReply {
flags |= remoteNoReply
}
pathname, err := cureRemote(ctx, &addr, rosa.Std.Load(p), flags)
a := rosa.Std.Load(p)
pathname, err := cureRemote(ctx, &addr, a, flags)
if !flagNoReply && err == nil {
log.Println(pathname)
}
if errors.Is(err, context.Canceled) {
cc, cancel := context.WithDeadline(context.Background(), daemonDeadline())
defer cancel()
if _err := cancelRemote(cc, &addr, a); _err != nil {
log.Println(err)
}
}
return err
}
},