package main import ( "context" "encoding/binary" "errors" "io" "log" "math" "net" "os" "sync" "syscall" "testing" "time" "unique" "hakurei.app/check" "hakurei.app/internal/pkg" ) // daemonTimeout is the maximum amount of time cureFromIR will wait on I/O. const daemonTimeout = 30 * time.Second // daemonDeadline returns the deadline corresponding to daemonTimeout, or the // zero value when running in a test. func daemonDeadline() time.Time { if testing.Testing() { return time.Time{} } return time.Now().Add(daemonTimeout) } const ( // remoteNoReply notifies that the client will not receive a cure reply. remoteNoReply = 1 << iota ) // cureFromIR services an IR curing request. func cureFromIR( cache *pkg.Cache, conn net.Conn, flags uint64, ) (pkg.Artifact, error) { a, decodeErr := cache.NewDecoder(conn).Decode() if decodeErr != nil { _, err := conn.Write([]byte("\x00" + decodeErr.Error())) return nil, errors.Join(decodeErr, err, conn.Close()) } pathname, _, cureErr := cache.Cure(a) if flags&remoteNoReply != 0 { return a, errors.Join(cureErr, conn.Close()) } if err := conn.SetWriteDeadline(daemonDeadline()); err != nil { return a, errors.Join(cureErr, err, conn.Close()) } if cureErr != nil { _, err := conn.Write([]byte("\x00" + cureErr.Error())) return a, errors.Join(cureErr, err, conn.Close()) } _, err := conn.Write([]byte(pathname.String())) if testing.Testing() && errors.Is(err, io.ErrClosedPipe) { return a, nil } return a, errors.Join(err, conn.Close()) } const ( // specialCancel is a message consisting of a single identifier referring // to a curing artifact to be cancelled. specialCancel = iota // specialAbort requests for all pending cures to be aborted. It has no // message body. specialAbort // remoteSpecial denotes a special message with custom layout. remoteSpecial = math.MaxUint64 ) // writeSpecialHeader writes the header of a remoteSpecial message. func writeSpecialHeader(conn net.Conn, kind uint64) error { var sh [16]byte binary.LittleEndian.PutUint64(sh[:], remoteSpecial) binary.LittleEndian.PutUint64(sh[8:], kind) if n, err := conn.Write(sh[:]); err != nil { return err } else if n != len(sh) { return io.ErrShortWrite } return nil } // cancelIdent reads an identifier from conn and cancels the corresponding cure. func cancelIdent( cache *pkg.Cache, conn net.Conn, ) (*pkg.ID, bool, error) { var ident pkg.ID if _, err := io.ReadFull(conn, ident[:]); err != nil { return nil, false, errors.Join(err, conn.Close()) } else if err = conn.Close(); err != nil { return nil, false, err } return &ident, cache.Cancel(unique.Make(ident)), nil } // serve services connections from a [net.UnixListener]. func serve( ctx context.Context, log *log.Logger, cm *cache, ul *net.UnixListener, ) error { ul.SetUnlinkOnClose(true) if cm.c == nil { if err := cm.open(); err != nil { return errors.Join(err, ul.Close()) } } var wg sync.WaitGroup defer wg.Wait() wg.Go(func() { for { if ctx.Err() != nil { break } conn, err := ul.AcceptUnix() if err != nil { if !errors.Is(err, os.ErrDeadlineExceeded) { log.Println(err) } continue } wg.Go(func() { done := make(chan struct{}) defer close(done) go func() { select { case <-ctx.Done(): _ = conn.SetDeadline(time.Now()) case <-done: return } }() if _err := conn.SetReadDeadline(daemonDeadline()); _err != nil { log.Println(_err) if _err = conn.Close(); _err != nil { log.Println(_err) } return } var word [8]byte if _, _err := io.ReadFull(conn, word[:]); _err != nil { log.Println(_err) if _err = conn.Close(); _err != nil { log.Println(_err) } return } flags := binary.LittleEndian.Uint64(word[:]) if flags == remoteSpecial { if _, _err := io.ReadFull(conn, word[:]); _err != nil { log.Println(_err) if _err = conn.Close(); _err != nil { log.Println(_err) } return } switch special := binary.LittleEndian.Uint64(word[:]); special { default: log.Printf("invalid special %d", special) case specialCancel: if id, ok, _err := cancelIdent(cm.c, conn); _err != nil { log.Println(_err) } else if !ok { log.Println( "attempting to cancel invalid artifact", pkg.Encode(*id), ) } else { log.Println( "cancelled artifact", pkg.Encode(*id), ) } case specialAbort: if _err := conn.Close(); _err != nil { log.Println(_err) } log.Println("aborting all pending cures") cm.c.Abort() } return } if a, _err := cureFromIR(cm.c, conn, flags); _err != nil { log.Println(_err) } else { log.Printf( "fulfilled artifact %s", pkg.Encode(cm.c.Ident(a).Value()), ) } }) } }) <-ctx.Done() if err := ul.SetDeadline(time.Now()); err != nil { return errors.Join(err, ul.Close()) } wg.Wait() return ul.Close() } // dial wraps [net.DialUnix] with a context. func dial(ctx context.Context, addr *net.UnixAddr) ( done chan<- struct{}, conn *net.UnixConn, err error, ) { conn, err = net.DialUnix("unix", nil, addr) if err != nil { return } d := make(chan struct{}) done = d go func() { select { case <-ctx.Done(): _ = conn.SetDeadline(time.Now()) case <-d: return } }() return } // cureRemote cures a [pkg.Artifact] on a daemon. func cureRemote( ctx context.Context, addr *net.UnixAddr, a pkg.Artifact, flags uint64, ) (*check.Absolute, error) { if flags == remoteSpecial { return nil, syscall.EINVAL } done, conn, err := dial(ctx, addr) if err != nil { return nil, err } defer close(done) if n, flagErr := conn.Write(binary.LittleEndian.AppendUint64(nil, flags)); flagErr != nil { return nil, errors.Join(flagErr, conn.Close()) } else if n != 8 { return nil, errors.Join(io.ErrShortWrite, conn.Close()) } if err = pkg.NewIR().EncodeAll(conn, a); err != nil { return nil, errors.Join(err, conn.Close()) } else if err = conn.CloseWrite(); err != nil { return nil, errors.Join(err, conn.Close()) } if flags&remoteNoReply != 0 { return nil, conn.Close() } payload, recvErr := io.ReadAll(conn) if err = errors.Join(recvErr, conn.Close()); err != nil { if errors.Is(err, os.ErrDeadlineExceeded) { if cancelErr := ctx.Err(); cancelErr != nil { err = cancelErr } } return nil, err } if len(payload) > 0 && payload[0] == 0 { return nil, errors.New(string(payload[1:])) } var p *check.Absolute p, err = check.NewAbs(string(payload)) return p, err } // cancelRemote cancels a [pkg.Artifact] curing on a daemon. func cancelRemote( ctx context.Context, addr *net.UnixAddr, a pkg.Artifact, ) error { done, conn, err := dial(ctx, addr) if err != nil { return err } defer close(done) if err = writeSpecialHeader(conn, specialCancel); err != nil { return errors.Join(err, conn.Close()) } var n int id := pkg.NewIR().Ident(a).Value() if n, err = conn.Write(id[:]); err != nil { return errors.Join(err, conn.Close()) } else if n != len(id) { return errors.Join(io.ErrShortWrite, conn.Close()) } return conn.Close() } // abortRemote aborts all [pkg.Artifact] curing on a daemon. func abortRemote( ctx context.Context, addr *net.UnixAddr, ) error { done, conn, err := dial(ctx, addr) if err != nil { return err } defer close(done) err = writeSpecialHeader(conn, specialAbort) return errors.Join(err, conn.Close()) }