diff --git a/container/container.go b/container/container.go index 941b663..09360df 100644 --- a/container/container.go +++ b/container/container.go @@ -21,6 +21,10 @@ const ( // Nonexistent is a path that cannot exist. // /proc is chosen because a system with covered /proc is unsupported by this package. Nonexistent = "/proc/nonexistent" + + // CancelSignal is the signal expected by container init on context cancel. + // A custom [Container.Cancel] function must eventually deliver this signal. + CancelSignal = SIGTERM ) type ( @@ -62,6 +66,8 @@ type ( Path string // Initial process argv. Args []string + // Deliver SIGINT to the initial process on context cancellation. + ForwardCancel bool // Mapped Uid in user namespace. Uid int @@ -129,7 +135,7 @@ func (p *Container) Start() error { if p.Cancel != nil { p.cmd.Cancel = func() error { return p.Cancel(p.cmd) } } else { - p.cmd.Cancel = func() error { return p.cmd.Process.Signal(SIGTERM) } + p.cmd.Cancel = func() error { return p.cmd.Process.Signal(CancelSignal) } } p.cmd.Dir = "/" p.cmd.SysProcAttr = &SysProcAttr{ @@ -226,6 +232,14 @@ func (p *Container) String() string { p.Args, !p.SeccompDisable, len(p.SeccompRules), int(p.SeccompFlags), int(p.SeccompPresets)) } +// ProcessState returns the address to os.ProcessState held by the underlying [exec.Cmd]. +func (p *Container) ProcessState() *os.ProcessState { + if p.cmd == nil { + return nil + } + return p.cmd.ProcessState +} + func New(ctx context.Context, name string, args ...string) *Container { return &Container{name: name, ctx: ctx, Params: Params{Args: append([]string{name}, args...), Dir: "/", Ops: new(Ops)}, diff --git a/container/container_test.go b/container/container_test.go index 8e59e4e..190fa5b 100644 --- a/container/container_test.go +++ b/container/container_test.go @@ -8,6 +8,8 @@ import ( "fmt" "log" "os" + "os/exec" + "os/signal" "strconv" "strings" "syscall" @@ -90,41 +92,32 @@ func TestContainer(t *testing.T) { t.Cleanup(func() { container.SetOutput(oldOutput) }) } - t.Run("cancel", func(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), helperDefaultTimeout) - - c := helperNewContainer(ctx, "block") - c.Stdout, c.Stderr = os.Stdout, os.Stderr - c.WaitDelay = helperDefaultTimeout - - ready := make(chan struct{}) - if r, w, err := os.Pipe(); err != nil { - t.Fatalf("cannot pipe: %v", err) - } else { - c.ExtraFiles = append(c.ExtraFiles, w) - go func() { - defer close(ready) - if _, err = r.Read(make([]byte, 1)); err != nil { - panic(err.Error()) - } - }() - } - - if err := c.Start(); err != nil { - hlog.PrintBaseError(err, "start:") - t.Fatalf("cannot start container: %v", err) - } else if err = c.Serve(); err != nil { - hlog.PrintBaseError(err, "serve:") - t.Errorf("cannot serve setup params: %v", err) - } - <-ready - cancel() + t.Run("cancel", testContainerCancel(nil, func(t *testing.T, c *container.Container) { wantErr := context.Canceled + wantExitCode := 0 if err := c.Wait(); !errors.Is(err, wantErr) { hlog.PrintBaseError(err, "wait:") - t.Fatalf("Wait: error = %v, want %v", err, wantErr) + t.Errorf("Wait: error = %v, want %v", err, wantErr) } - }) + if ps := c.ProcessState(); ps == nil { + t.Errorf("ProcessState unexpectedly returned nil") + } else if code := ps.ExitCode(); code != wantExitCode { + t.Errorf("ExitCode: %d, want %d", code, wantExitCode) + } + })) + + t.Run("forward", testContainerCancel(func(c *container.Container) { + c.ForwardCancel = true + }, func(t *testing.T, c *container.Container) { + var exitError *exec.ExitError + if err := c.Wait(); !errors.As(err, &exitError) { + hlog.PrintBaseError(err, "wait:") + t.Errorf("Wait: error = %v", err) + } + if code := exitError.ExitCode(); code != blockExitCodeInterrupt { + t.Errorf("ExitCode: %d, want %d", code, blockExitCodeInterrupt) + } + })) for i, tc := range containerTestCases { t.Run(tc.name, func(t *testing.T) { @@ -214,6 +207,46 @@ func hostnameFromTestCase(name string) string { return "test-" + strings.Join(strings.Fields(name), "-") } +func testContainerCancel( + containerExtra func(c *container.Container), + waitCheck func(t *testing.T, c *container.Container), +) func(t *testing.T) { + return func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), helperDefaultTimeout) + + c := helperNewContainer(ctx, "block") + c.Stdout, c.Stderr = os.Stdout, os.Stderr + c.WaitDelay = helperDefaultTimeout + if containerExtra != nil { + containerExtra(c) + } + + ready := make(chan struct{}) + if r, w, err := os.Pipe(); err != nil { + t.Fatalf("cannot pipe: %v", err) + } else { + c.ExtraFiles = append(c.ExtraFiles, w) + go func() { + defer close(ready) + if _, err = r.Read(make([]byte, 1)); err != nil { + panic(err.Error()) + } + }() + } + + if err := c.Start(); err != nil { + hlog.PrintBaseError(err, "start:") + t.Fatalf("cannot start container: %v", err) + } else if err = c.Serve(); err != nil { + hlog.PrintBaseError(err, "serve:") + t.Errorf("cannot serve setup params: %v", err) + } + <-ready + cancel() + waitCheck(t, c) + } +} + func TestContainerString(t *testing.T) { c := container.New(t.Context(), "ldd", "/usr/bin/env") c.SeccompFlags |= seccomp.AllowMultiarch @@ -227,12 +260,21 @@ func TestContainerString(t *testing.T) { } } +const ( + blockExitCodeInterrupt = 2 +) + func init() { helperCommands = append(helperCommands, func(c command.Command) { c.Command("block", command.UsageInternal, func(args []string) error { if _, err := os.NewFile(3, "sync").Write([]byte{0}); err != nil { return fmt.Errorf("write to sync pipe: %v", err) } + { + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt) + go func() { <-sig; os.Exit(blockExitCodeInterrupt) }() + } select {} }) diff --git a/container/init.go b/container/init.go index 757af58..9c4f810 100644 --- a/container/init.go +++ b/container/init.go @@ -277,7 +277,7 @@ func Init(prepare func(prefix string), setVerbose func(verbose bool)) { msg.Suspend() if err := closeSetup(); err != nil { - log.Println("cannot close setup pipe:", err) + log.Printf("cannot close setup pipe: %v", err) // not fatal } @@ -311,7 +311,7 @@ func Init(prepare func(prefix string), setVerbose func(verbose bool)) { } } if !errors.Is(err, ECHILD) { - log.Println("unexpected wait4 response:", err) + log.Printf("unexpected wait4 response: %v", err) } close(done) @@ -319,7 +319,7 @@ func Init(prepare func(prefix string), setVerbose func(verbose bool)) { // handle signals to dump withheld messages sig := make(chan os.Signal, 2) - signal.Notify(sig, SIGINT, SIGTERM) + signal.Notify(sig, os.Interrupt, CancelSignal) // closed after residualProcessTimeout has elapsed after initial process death timeout := make(chan struct{}) @@ -329,9 +329,16 @@ func Init(prepare func(prefix string), setVerbose func(verbose bool)) { select { case s := <-sig: if msg.Resume() { - msg.Verbosef("terminating on %s after process start", s.String()) + msg.Verbosef("%s after process start", s.String()) } else { - msg.Verbosef("terminating on %s", s.String()) + msg.Verbosef("got %s", s.String()) + } + if s == CancelSignal && params.ForwardCancel && cmd.Process != nil { + msg.Verbose("forwarding context cancellation") + if err := cmd.Process.Signal(os.Interrupt); err != nil { + log.Printf("cannot forward cancellation: %v", err) + } + continue } os.Exit(0) case w := <-info: @@ -351,10 +358,7 @@ func Init(prepare func(prefix string), setVerbose func(verbose bool)) { msg.Verbosef("initial process exited with status %#x", w.wstatus) } - go func() { - time.Sleep(residualProcessTimeout) - close(timeout) - }() + go func() { time.Sleep(residualProcessTimeout); close(timeout) }() } case <-done: msg.BeforeExit()