diff --git a/container/dispatcher_test.go b/container/dispatcher_test.go index 3a806ce..ad59503 100644 --- a/container/dispatcher_test.go +++ b/container/dispatcher_test.go @@ -148,7 +148,8 @@ func checkSimple(t *testing.T, fname string, testCases []simpleTestCase) { t.Run(tc.name, func(t *testing.T) { t.Helper() - k := &kstub{stub.New(t, func(s *stub.Stub[syscallDispatcher]) syscallDispatcher { return &kstub{s} }, tc.want)} + wait4signal := make(chan struct{}) + k := &kstub{wait4signal, stub.New(t, func(s *stub.Stub[syscallDispatcher]) syscallDispatcher { return &kstub{wait4signal, s} }, tc.want)} defer stub.HandleExit(t) if err := tc.f(k); !reflect.DeepEqual(err, tc.wantErr) { t.Errorf("%s: error = %v, want %v", fname, err, tc.wantErr) @@ -185,8 +186,8 @@ func checkOpBehaviour(t *testing.T, testCases []opBehaviourTestCase) { t.Helper() state := &setupState{Params: tc.params} - k := &kstub{stub.New(t, - func(s *stub.Stub[syscallDispatcher]) syscallDispatcher { return &kstub{s} }, + k := &kstub{nil, stub.New(t, + func(s *stub.Stub[syscallDispatcher]) syscallDispatcher { return &kstub{nil, s} }, stub.Expect{Calls: slices.Concat(tc.early, []stub.Call{{Name: stub.CallSeparator}}, tc.apply)}, )} defer stub.HandleExit(t) @@ -296,7 +297,18 @@ func (nameDentry) IsDir() bool { panic("unreachable") } func (nameDentry) Type() fs.FileMode { panic("unreachable") } func (nameDentry) Info() (fs.FileInfo, error) { panic("unreachable") } -type kstub struct{ *stub.Stub[syscallDispatcher] } +const ( + // magicWait4Signal must be used in a single pair of signal and wait4 calls across two goroutines + // originating from the same toplevel kstub. + // To enable this behaviour this value must be the last element of the args field in the wait4 call + // and the ret value of the signal call. + magicWait4Signal = 0xdef +) + +type kstub struct { + wait4signal chan struct{} + *stub.Stub[syscallDispatcher] +} func (k *kstub) new(f func(k syscallDispatcher)) { k.Helper(); k.New(f) } @@ -463,7 +475,14 @@ func (k *kstub) start(c *exec.Cmd) error { func (k *kstub) signal(c *exec.Cmd, sig os.Signal) error { k.Helper() - return k.Expects("signal").Error( + expect := k.Expects("signal") + if v, ok := expect.Ret.(int); ok && v == magicWait4Signal { + if k.wait4signal == nil { + panic("kstub not initialised for wait4 simulation") + } + defer func() { close(k.wait4signal) }() + } + return expect.Error( stub.CheckArg(k.Stub, "c.Path", c.Path, 0), stub.CheckArgReflect(k.Stub, "c.Args", c.Args, 1), stub.CheckArgReflect(k.Stub, "c.Env", c.Env, 2), @@ -648,9 +667,17 @@ func (k *kstub) unmount(target string, flags int) (err error) { func (k *kstub) wait4(pid int, wstatus *syscall.WaitStatus, options int, rusage *syscall.Rusage) (wpid int, err error) { k.Helper() expect := k.Expects("wait4") - // special case to prevent leaking the wait4 goroutine when testing initEntrypoint - if v, ok := expect.Args[4].(int); ok && v == stub.PanicExit { - panic(stub.PanicExit) + if v, ok := expect.Args[4].(int); ok { + switch v { + case stub.PanicExit: // special case to prevent leaking the wait4 goroutine while testing initEntrypoint + panic(stub.PanicExit) + + case magicWait4Signal: // block until corresponding signal call + if k.wait4signal == nil { + panic("kstub not initialised for wait4 simulation") + } + <-k.wait4signal + } } wpid = expect.Ret.(int) diff --git a/container/init_test.go b/container/init_test.go index 99e2976..cc04057 100644 --- a/container/init_test.go +++ b/container/init_test.go @@ -2046,7 +2046,8 @@ func TestInitEntrypoint(t *testing.T) { call("resume", stub.ExpectArgs{}, true, nil), call("verbosef", stub.ExpectArgs{"%s after process start", []any{"terminated"}}, nil, nil), call("verbose", stub.ExpectArgs{[]any{"forwarding context cancellation"}}, nil, nil), - call("signal", stub.ExpectArgs{"/run/current-system/sw/bin/bash", []string{"bash", "-c", "false"}, ([]string)(nil), "/.hakurei/nonexistent", os.Interrupt}, nil, stub.UniqueError(9)), + // magicWait4Signal as ret causes wait4 stub to unblock + call("signal", stub.ExpectArgs{"/run/current-system/sw/bin/bash", []string{"bash", "-c", "false"}, ([]string)(nil), "/.hakurei/nonexistent", os.Interrupt}, magicWait4Signal, stub.UniqueError(9)), call("printf", stub.ExpectArgs{"cannot forward cancellation: %v", []any{stub.UniqueError(9)}}, nil, nil), call("resume", stub.ExpectArgs{}, false, nil), call("verbosef", stub.ExpectArgs{"initial process exited with signal %s", []any{syscall.Signal(0x4e)}}, nil, nil), @@ -2057,9 +2058,10 @@ func TestInitEntrypoint(t *testing.T) { /* wait4 */ Tracks: []stub.Expect{{Calls: []stub.Call{ - call("wait4", stub.ExpectArgs{-1, syscall.WaitStatus(0xfade01ce), 0, nil}, 0xbad, nil), + // magicWait4Signal as args[4] causes this to block until simulated signal is delivered + call("wait4", stub.ExpectArgs{-1, syscall.WaitStatus(0xfade01ce), 0, nil, magicWait4Signal}, 0xbad, nil), // this terminates the goroutine at the call, preventing it from leaking while preserving behaviour - call("wait4", stub.ExpectArgs{-1, nil, 0, nil, 0xdeadbeef}, 0, syscall.ECHILD), + call("wait4", stub.ExpectArgs{-1, nil, 0, nil, stub.PanicExit}, 0, syscall.ECHILD), }}}, }, nil}, @@ -2149,7 +2151,7 @@ func TestInitEntrypoint(t *testing.T) { /* wait4 */ Tracks: []stub.Expect{{Calls: []stub.Call{ // this terminates the goroutine at the call, preventing it from leaking while preserving behaviour - call("wait4", stub.ExpectArgs{-1, nil, 0, nil, 0xdeadbeef}, 0, syscall.ECHILD), + call("wait4", stub.ExpectArgs{-1, nil, 0, nil, stub.PanicExit}, 0, syscall.ECHILD), }}}, }, nil},