diff --git a/container/dispatcher_test.go b/container/dispatcher_test.go index a4fa3c4..3a806ce 100644 --- a/container/dispatcher_test.go +++ b/container/dispatcher_test.go @@ -149,7 +149,7 @@ func checkSimple(t *testing.T, fname string, testCases []simpleTestCase) { t.Helper() k := &kstub{stub.New(t, func(s *stub.Stub[syscallDispatcher]) syscallDispatcher { return &kstub{s} }, tc.want)} - defer k.HandleExit() + defer stub.HandleExit(t) if err := tc.f(k); !reflect.DeepEqual(err, tc.wantErr) { t.Errorf("%s: error = %v, want %v", fname, err, tc.wantErr) } @@ -189,7 +189,7 @@ func checkOpBehaviour(t *testing.T, testCases []opBehaviourTestCase) { func(s *stub.Stub[syscallDispatcher]) syscallDispatcher { return &kstub{s} }, stub.Expect{Calls: slices.Concat(tc.early, []stub.Call{{Name: stub.CallSeparator}}, tc.apply)}, )} - defer k.HandleExit() + defer stub.HandleExit(t) errEarly := tc.op.early(state, k) k.Expects(stub.CallSeparator) if !reflect.DeepEqual(errEarly, tc.wantErrEarly) { diff --git a/container/stub/exit.go b/container/stub/exit.go index 0f44b5d..6470fe6 100644 --- a/container/stub/exit.go +++ b/container/stub/exit.go @@ -12,20 +12,13 @@ const ( ) // HandleExit must be deferred before calling with the stub. -func (s *Stub[K]) HandleExit() { handleExit(s.TB, true) } - -func handleExit(t testing.TB, root bool) { +func HandleExit(t testing.TB) { switch r := recover(); r { case PanicExit: break case panicFailNow: - if root { - t.FailNow() - } else { - t.Fail() - } - break + t.FailNow() case panicFatal, panicFatalf, nil: break @@ -34,3 +27,18 @@ func handleExit(t testing.TB, root bool) { panic(r) } } + +// handleExitNew handles exits from goroutines created by [Stub.New]. +func handleExitNew(t testing.TB) { + switch r := recover(); r { + case PanicExit, panicFatal, panicFatalf, nil: + break + + case panicFailNow: + t.Fail() + break + + default: + panic(r) + } +} diff --git a/container/stub/exit_test.go b/container/stub/exit_test.go index 02c4af8..4e935ff 100644 --- a/container/stub/exit_test.go +++ b/container/stub/exit_test.go @@ -7,8 +7,8 @@ import ( "hakurei.app/container/stub" ) -//go:linkname handleExit hakurei.app/container/stub.handleExit -func handleExit(_ testing.TB, _ bool) +//go:linkname handleExitNew hakurei.app/container/stub.handleExitNew +func handleExitNew(_ testing.TB) // overrideTFailNow overrides the Fail and FailNow method. type overrideTFailNow struct { @@ -33,7 +33,7 @@ func (o *overrideTFailNow) Fail() { func TestHandleExit(t *testing.T) { t.Run("exit", func(t *testing.T) { - defer handleExit(t, true) + defer stub.HandleExit(t) panic(stub.PanicExit) }) @@ -45,7 +45,7 @@ func TestHandleExit(t *testing.T) { t.Errorf("FailNow was never called") } }() - defer handleExit(ot, true) + defer stub.HandleExit(ot) panic(0xcafe0000) }) @@ -56,24 +56,38 @@ func TestHandleExit(t *testing.T) { t.Errorf("Fail was never called") } }() - defer handleExit(ot, false) + defer handleExitNew(ot) panic(0xcafe0000) }) }) t.Run("nil", func(t *testing.T) { - defer handleExit(t, true) + defer stub.HandleExit(t) }) t.Run("passthrough", func(t *testing.T) { - defer func() { - want := 0xcafebabe - if r := recover(); r != want { - t.Errorf("recover: %v, want %v", r, want) - } + t.Run("toplevel", func(t *testing.T) { + defer func() { + want := 0xcafebabe + if r := recover(); r != want { + t.Errorf("recover: %v, want %v", r, want) + } - }() - defer handleExit(t, true) - panic(0xcafebabe) + }() + defer stub.HandleExit(t) + panic(0xcafebabe) + }) + + t.Run("new", func(t *testing.T) { + defer func() { + want := 0xcafe + if r := recover(); r != want { + t.Errorf("recover: %v, want %v", r, want) + } + + }() + defer handleExitNew(t) + panic(0xcafe) + }) }) } diff --git a/container/stub/stub.go b/container/stub/stub.go index e0ade3b..a536892 100644 --- a/container/stub/stub.go +++ b/container/stub/stub.go @@ -68,7 +68,7 @@ func (s *Stub[K]) New(f func(k K)) { s.Helper() defer s.wg.Done() - defer handleExit(s.TB, false) + defer handleExitNew(s.TB) f(s.makeK(ds)) }() } diff --git a/container/stub/stub_test.go b/container/stub/stub_test.go index d5defbb..c3d79a4 100644 --- a/container/stub/stub_test.go +++ b/container/stub/stub_test.go @@ -118,7 +118,7 @@ func TestStub(t *testing.T) { {"New", ExpectArgs{}, nil, nil}, {"panic", ExpectArgs{"unreachable"}, nil, nil}, }}) - func() { defer s.HandleExit(); s.New(func(k stubHolder) { panic("unreachable") }) }() + func() { defer HandleExit(t); s.New(func(k stubHolder) { panic("unreachable") }) }() var visit int s.VisitIncomplete(func(s *Stub[stubHolder]) { @@ -139,7 +139,7 @@ func TestStub(t *testing.T) { ot := &overrideT{T: t} ot.error.Store(checkError(t, "Expects: advancing beyond expected calls")) s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{}) - func() { defer s.HandleExit(); s.Expects("unreachable") }() + func() { defer HandleExit(t); s.Expects("unreachable") }() }) t.Run("separator", func(t *testing.T) { @@ -149,7 +149,7 @@ func TestStub(t *testing.T) { s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{Calls: []Call{ {CallSeparator, ExpectArgs{}, nil, nil}, }}) - func() { defer s.HandleExit(); s.Expects("meow") }() + func() { defer HandleExit(t); s.Expects("meow") }() }) t.Run("mismatch", func(t *testing.T) { @@ -158,7 +158,7 @@ func TestStub(t *testing.T) { s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{Calls: []Call{ {"panic", ExpectArgs{}, nil, nil}, }}) - func() { defer s.HandleExit(); s.Expects(CallSeparator) }() + func() { defer HandleExit(t); s.Expects(CallSeparator) }() }) }) @@ -168,7 +168,7 @@ func TestStub(t *testing.T) { s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{Calls: []Call{ {"nya", ExpectArgs{}, nil, nil}, }}) - func() { defer s.HandleExit(); s.Expects("meow") }() + func() { defer HandleExit(t); s.Expects("meow") }() }) }) }) @@ -198,7 +198,7 @@ func TestCheckArg(t *testing.T) { } }) t.Run("mismatch", func(t *testing.T) { - defer s.HandleExit() + defer HandleExit(t) s.Expects("meow") ot.errorf.Store(checkErrorf(t, "%s: %s = %#v, want %#v (%d)", "meow", "time", 0, -1, 1)) if CheckArg(s, "time", 0, 0) { @@ -241,7 +241,7 @@ func TestCheckArgReflect(t *testing.T) { } }) t.Run("mismatch", func(t *testing.T) { - defer s.HandleExit() + defer HandleExit(t) s.Expects("meow") ot.errorf.Store(checkErrorf(t, "%s: %s = %#v, want %#v (%d)", "meow", "time", 0, -1, 1)) if CheckArgReflect(s, "time", 0, 0) { diff --git a/system/dispatcher_test.go b/system/dispatcher_test.go index 7fbb75c..53d2b66 100644 --- a/system/dispatcher_test.go +++ b/system/dispatcher_test.go @@ -44,7 +44,7 @@ func checkOpBehaviour(t *testing.T, testCases []opBehaviourTestCase) { } sys, s := InternalNew(t, stub.Expect{Calls: slices.Concat(tc.apply, []stub.Call{{Name: stub.CallSeparator}}, tc.revert)}, tc.uid) - defer s.HandleExit() + defer stub.HandleExit(t) errApply := tc.op.apply(sys) s.Expects(stub.CallSeparator) if !reflect.DeepEqual(errApply, tc.wantErrApply) { @@ -91,7 +91,7 @@ func checkOpsBuilder(t *testing.T, fname string, testCases []opsBuilderTestCase) t.Helper() sys, s := InternalNew(t, tc.exp, tc.uid) - defer s.HandleExit() + defer stub.HandleExit(t) tc.f(sys) s.VisitIncomplete(func(s *stub.Stub[syscallDispatcher]) { t.Helper()