From 0166833431633f7acc8b0d35c4451494eff8710a Mon Sep 17 00:00:00 2001 From: Ophestra Date: Sat, 23 Aug 2025 21:47:06 +0900 Subject: [PATCH] container/dispatcher: start goroutine in dispatcher This allows instrumentation of calls from goroutine without relying on finalizers. Signed-off-by: Ophestra --- container/dispatcher.go | 6 +++--- container/dispatcher_test.go | 17 ++++++++++++----- container/init.go | 12 ++++++++++-- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/container/dispatcher.go b/container/dispatcher.go index c7a8bad..ae697ce 100644 --- a/container/dispatcher.go +++ b/container/dispatcher.go @@ -22,10 +22,10 @@ type osFile interface { // syscallDispatcher provides methods that make state-dependent system calls as part of their behaviour. type syscallDispatcher interface { - // new returns a new instance of syscallDispatcher for use in another goroutine. + // new starts a goroutine with a new instance of syscallDispatcher. // A syscallDispatcher must never be used in any goroutine other than the one owning it, // just synchronising access is not enough, as this is for test instrumentation. - new() syscallDispatcher + new(f func(k syscallDispatcher)) // lockOSThread provides [runtime.LockOSThread]. lockOSThread() @@ -145,7 +145,7 @@ type syscallDispatcher interface { // direct implements syscallDispatcher on the current kernel. type direct struct{} -func (k direct) new() syscallDispatcher { return k } +func (k direct) new(f func(k syscallDispatcher)) { go f(k) } func (direct) lockOSThread() { runtime.LockOSThread() } diff --git a/container/dispatcher_test.go b/container/dispatcher_test.go index bfb38cd..e369d0f 100644 --- a/container/dispatcher_test.go +++ b/container/dispatcher_test.go @@ -12,6 +12,7 @@ import ( "runtime" "slices" "strings" + "sync" "syscall" "testing" "time" @@ -113,7 +114,7 @@ type simpleTestCase struct { func checkSimple(t *testing.T, fname string, testCases []simpleTestCase) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - k := &kstub{t: t, want: tc.want} + k := &kstub{t: t, want: tc.want, wg: new(sync.WaitGroup)} if err := tc.f(k); !errors.Is(err, tc.wantErr) { t.Errorf("%s: error = %v, want %v", fname, err, tc.wantErr) } @@ -141,7 +142,7 @@ func checkOpBehaviour(t *testing.T, testCases []opBehaviourTestCase) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { state := &setupState{Params: tc.params} - k := &kstub{t: t, want: [][]kexpect{slices.Concat(tc.early, []kexpect{{name: "\x00"}}, tc.apply)}} + k := &kstub{t: t, want: [][]kexpect{slices.Concat(tc.early, []kexpect{{name: "\x00"}}, tc.apply)}, wg: new(sync.WaitGroup)} errEarly := tc.op.early(state, k) k.expect("\x00") if !errors.Is(errEarly, tc.wantErrEarly) { @@ -272,10 +273,14 @@ type kstub struct { track int // sub stores addresses of kstub created by new. sub []*kstub + // wg waits for all descendants to complete. + wg *sync.WaitGroup } // handleIncomplete calls f on an incomplete k and all its descendants. func (k *kstub) handleIncomplete(f func(k *kstub)) { + k.wg.Wait() + if k.want != nil && len(k.want[k.track]) != k.pos { f(k) } @@ -331,13 +336,15 @@ func checkArgReflect(k *kstub, arg string, got any, n int) bool { return true } -func (k *kstub) new() syscallDispatcher { +func (k *kstub) new(f func(k syscallDispatcher)) { k.expect("new") if len(k.want) <= k.track+1 { k.t.Fatalf("new: track overrun") } - k.sub = append(k.sub, &kstub{t: k.t, want: k.want, track: k.track + 1}) - return k.sub[len(k.sub)-1] + sk := &kstub{t: k.t, want: k.want, track: len(k.sub) + 1, wg: k.wg} + k.sub = append(k.sub, sk) + k.wg.Add(1) + go func() { defer k.wg.Done(); f(sk) }() } func (k *kstub) lockOSThread() { k.expect("lockOSThread") } diff --git a/container/init.go b/container/init.go index 6bb5a9e..44e89c7 100644 --- a/container/init.go +++ b/container/init.go @@ -178,6 +178,7 @@ func initEntrypoint(k syscallDispatcher, prepareLogger func(prefix string), setV fmt.Sprintf("cannot prepare op at index %d:", i)) k.beforeExit() k.exit(1) + return } } @@ -218,6 +219,7 @@ func initEntrypoint(k syscallDispatcher, prepareLogger func(prefix string), setV fmt.Sprintf("cannot apply op at index %d:", i)) k.beforeExit() k.exit(1) + return } } @@ -333,7 +335,7 @@ func initEntrypoint(k syscallDispatcher, prepareLogger func(prefix string), setV info := make(chan winfo, 1) done := make(chan struct{}) - go func(k syscallDispatcher) { + k.new(func(k syscallDispatcher) { var ( err error wpid = -2 @@ -360,7 +362,7 @@ func initEntrypoint(k syscallDispatcher, prepareLogger func(prefix string), setV } close(done) - }(k.new()) + }) // handle signals to dump withheld messages sig := make(chan os.Signal, 2) @@ -385,7 +387,9 @@ func initEntrypoint(k syscallDispatcher, prepareLogger func(prefix string), setV } continue } + k.beforeExit() k.exit(0) + return case w := <-info: if w.wpid == cmd.Process.Pid { @@ -396,9 +400,11 @@ func initEntrypoint(k syscallDispatcher, prepareLogger func(prefix string), setV case w.wstatus.Exited(): r = w.wstatus.ExitStatus() k.verbosef("initial process exited with code %d", w.wstatus.ExitStatus()) + case w.wstatus.Signaled(): r = 128 + int(w.wstatus.Signal()) k.verbosef("initial process exited with signal %s", w.wstatus.Signal()) + default: r = 255 k.verbosef("initial process exited with status %#x", w.wstatus) @@ -410,11 +416,13 @@ func initEntrypoint(k syscallDispatcher, prepareLogger func(prefix string), setV case <-done: k.beforeExit() k.exit(r) + return case <-timeout: k.printf("timeout exceeded waiting for lingering processes") k.beforeExit() k.exit(r) + return } } }