package main import ( "context" "errors" "io" "log" "net" "os" "sync" "testing" "time" "hakurei.app/check" "hakurei.app/internal/pkg" ) // daemonTimeout is the maximum amount of time cureFromIR will wait on I/O. const daemonTimeout = 30 * time.Second // daemonDeadline returns the deadline corresponding to daemonTimeout, or the // zero value when running in a test. func daemonDeadline() time.Time { if testing.Testing() { return time.Time{} } return time.Now().Add(daemonTimeout) } // cureFromIR services an IR curing request. func cureFromIR( ctx context.Context, cache *pkg.Cache, conn net.Conn, ) (pkg.Artifact, error) { go func() { <-ctx.Done() _ = conn.SetDeadline(time.Now()) }() if err := conn.SetReadDeadline(daemonDeadline()); err != nil { return nil, errors.Join(err, conn.Close()) } a, decodeErr := cache.NewDecoder(conn).Decode() if decodeErr != nil { _, err := conn.Write([]byte("\x00" + decodeErr.Error())) return nil, errors.Join(decodeErr, err, conn.Close()) } pathname, _, cureErr := cache.Cure(a) if err := conn.SetWriteDeadline(daemonDeadline()); err != nil { if !testing.Testing() || !errors.Is(err, io.ErrClosedPipe) { return a, errors.Join(err, conn.Close()) } } if cureErr != nil { _, err := conn.Write([]byte("\x00" + cureErr.Error())) return a, errors.Join(cureErr, err, conn.Close()) } _, err := conn.Write([]byte(pathname.String())) if testing.Testing() && errors.Is(err, io.ErrClosedPipe) { return a, nil } return a, errors.Join(err, conn.Close()) } // serve services connections from a [net.UnixListener]. func serve(ctx context.Context, log *log.Logger, cm *cache, ul *net.UnixListener) error { ul.SetUnlinkOnClose(true) if cm.c == nil { if err := cm.open(); err != nil { return errors.Join(err, ul.Close()) } } var wg sync.WaitGroup defer wg.Wait() wg.Go(func() { for { if ctx.Err() != nil { break } conn, err := ul.AcceptUnix() if err != nil { if !errors.Is(err, os.ErrDeadlineExceeded) { log.Println(err) } continue } wg.Go(func() { if a, _err := cureFromIR(ctx, cm.c, conn); _err != nil { log.Println(_err) } else { log.Printf( "fulfilled artifact %s", pkg.Encode(cm.c.Ident(a).Value()), ) } }) } }) <-ctx.Done() if err := ul.SetDeadline(time.Now()); err != nil { return errors.Join(err, ul.Close()) } wg.Wait() return ul.Close() } // cureRemote cures a [pkg.Artifact] on a daemon. func cureRemote( ctx context.Context, addr *net.UnixAddr, a pkg.Artifact, ) (*check.Absolute, error) { conn, err := net.DialUnix("unix", nil, addr) if err != nil { return nil, err } go func() { <-ctx.Done() _ = conn.SetDeadline(time.Now()) }() if err = pkg.NewIR().EncodeAll(conn, a); err != nil { return nil, errors.Join(err, conn.Close()) } else if err = conn.CloseWrite(); err != nil { return nil, errors.Join(err, conn.Close()) } payload, recvErr := io.ReadAll(conn) if err = errors.Join(recvErr, conn.Close()); err != nil { if errors.Is(err, os.ErrDeadlineExceeded) { err = ctx.Err() } return nil, err } if len(payload) > 0 && payload[0] == 0 { return nil, errors.New(string(payload[1:])) } var p *check.Absolute p, err = check.NewAbs(string(payload)) return p, err }