diff --git a/container/dispatcher_test.go b/container/dispatcher_test.go index 1880046..20a100b 100644 --- a/container/dispatcher_test.go +++ b/container/dispatcher_test.go @@ -148,8 +148,8 @@ func checkSimple(t *testing.T, fname string, testCases []simpleTestCase) { t.Run(tc.name, func(t *testing.T) { t.Helper() - defer stub.HandleExit() k := &kstub{stub.New(t, func(s *stub.Stub[syscallDispatcher]) syscallDispatcher { return &kstub{s} }, tc.want)} + defer k.HandleExit() if err := tc.f(k); !reflect.DeepEqual(err, tc.wantErr) { t.Errorf("%s: error = %v, want %v", fname, err, tc.wantErr) } @@ -184,12 +184,12 @@ func checkOpBehaviour(t *testing.T, testCases []opBehaviourTestCase) { t.Run(tc.name, func(t *testing.T) { t.Helper() - defer stub.HandleExit() state := &setupState{Params: tc.params} k := &kstub{stub.New(t, func(s *stub.Stub[syscallDispatcher]) syscallDispatcher { return &kstub{s} }, stub.Expect{Calls: slices.Concat(tc.early, []stub.Call{{Name: stub.CallSeparator}}, tc.apply)}, )} + defer k.HandleExit() errEarly := tc.op.early(state, k) k.Expects(stub.CallSeparator) if !reflect.DeepEqual(errEarly, tc.wantErrEarly) { diff --git a/container/stub/exit.go b/container/stub/exit.go index 8d85946..0f44b5d 100644 --- a/container/stub/exit.go +++ b/container/stub/exit.go @@ -1,15 +1,36 @@ package stub +import "testing" + // PanicExit is a magic panic value treated as a simulated exit. const PanicExit = 0xdeadbeef +const ( + panicFailNow = 0xcafe0000 + iota + panicFatal + panicFatalf +) + // HandleExit must be deferred before calling with the stub. -func HandleExit() { - r := recover() - if r == PanicExit { - return - } - if r != nil { +func (s *Stub[K]) HandleExit() { handleExit(s.TB, true) } + +func handleExit(t testing.TB, root bool) { + switch r := recover(); r { + case PanicExit: + break + + case panicFailNow: + if root { + t.FailNow() + } else { + t.Fail() + } + break + + case panicFatal, panicFatalf, nil: + break + + default: panic(r) } } diff --git a/container/stub/exit_test.go b/container/stub/exit_test.go index 5021971..02c4af8 100644 --- a/container/stub/exit_test.go +++ b/container/stub/exit_test.go @@ -2,18 +2,67 @@ package stub_test import ( "testing" + _ "unsafe" "hakurei.app/container/stub" ) +//go:linkname handleExit hakurei.app/container/stub.handleExit +func handleExit(_ testing.TB, _ bool) + +// overrideTFailNow overrides the Fail and FailNow method. +type overrideTFailNow struct { + *testing.T + failNow bool + fail bool +} + +func (o *overrideTFailNow) FailNow() { + if o.failNow { + o.Errorf("attempted to FailNow twice") + } + o.failNow = true +} + +func (o *overrideTFailNow) Fail() { + if o.fail { + o.Errorf("attempted to Fail twice") + } + o.fail = true +} + func TestHandleExit(t *testing.T) { t.Run("exit", func(t *testing.T) { - defer stub.HandleExit() + defer handleExit(t, true) panic(stub.PanicExit) }) + t.Run("goexit", func(t *testing.T) { + t.Run("FailNow", func(t *testing.T) { + ot := &overrideTFailNow{T: t} + defer func() { + if !ot.failNow { + t.Errorf("FailNow was never called") + } + }() + defer handleExit(ot, true) + panic(0xcafe0000) + }) + + t.Run("Fail", func(t *testing.T) { + ot := &overrideTFailNow{T: t} + defer func() { + if !ot.fail { + t.Errorf("Fail was never called") + } + }() + defer handleExit(ot, false) + panic(0xcafe0000) + }) + }) + t.Run("nil", func(t *testing.T) { - defer stub.HandleExit() + defer handleExit(t, true) }) t.Run("passthrough", func(t *testing.T) { @@ -24,7 +73,7 @@ func TestHandleExit(t *testing.T) { } }() - defer stub.HandleExit() + defer handleExit(t, true) panic(0xcafebabe) }) } diff --git a/container/stub/stub.go b/container/stub/stub.go index 942a339..e0ade3b 100644 --- a/container/stub/stub.go +++ b/container/stub/stub.go @@ -45,6 +45,13 @@ func New[K any](tb testing.TB, makeK func(s *Stub[K]) K, want Expect) *Stub[K] { return &Stub[K]{TB: tb, makeK: makeK, want: want, wg: new(sync.WaitGroup)} } +func (s *Stub[K]) FailNow() { panic(panicFailNow) } +func (s *Stub[K]) Fatal(args ...any) { s.Error(args...); panic(panicFatal) } +func (s *Stub[K]) Fatalf(format string, args ...any) { s.Errorf(format, args...); panic(panicFatalf) } +func (s *Stub[K]) SkipNow() { panic("invalid call to SkipNow") } +func (s *Stub[K]) Skip(...any) { panic("invalid call to Skip") } +func (s *Stub[K]) Skipf(string, ...any) { panic("invalid call to Skipf") } + // New calls f in a new goroutine func (s *Stub[K]) New(f func(k K)) { s.Helper() @@ -61,7 +68,7 @@ func (s *Stub[K]) New(f func(k K)) { s.Helper() defer s.wg.Done() - defer HandleExit() + defer handleExit(s.TB, false) f(s.makeK(ds)) }() } diff --git a/container/stub/stub_test.go b/container/stub/stub_test.go index 3333ce7..d5defbb 100644 --- a/container/stub/stub_test.go +++ b/container/stub/stub_test.go @@ -13,29 +13,19 @@ type stubHolder struct{ *Stub[stubHolder] } type overrideT struct { *testing.T - fatal atomic.Pointer[func(args ...any)] - fatalf atomic.Pointer[func(format string, args ...any)] + error atomic.Pointer[func(args ...any)] errorf atomic.Pointer[func(format string, args ...any)] } -func (t *overrideT) Fatal(args ...any) { - fp := t.fatal.Load() +func (t *overrideT) Error(args ...any) { + fp := t.error.Load() if fp == nil || *fp == nil { - t.T.Fatal(args...) + t.T.Error(args...) return } (*fp)(args...) } -func (t *overrideT) Fatalf(format string, args ...any) { - fp := t.fatalf.Load() - if fp == nil || *fp == nil { - t.T.Fatalf(format, args...) - return - } - (*fp)(format, args...) -} - func (t *overrideT) Errorf(format string, args ...any) { fp := t.errorf.Load() if fp == nil || *fp == nil { @@ -46,6 +36,47 @@ func (t *overrideT) Errorf(format string, args ...any) { } func TestStub(t *testing.T) { + t.Run("goexit", func(t *testing.T) { + t.Run("FailNow", func(t *testing.T) { + defer func() { + if r := recover(); r != panicFailNow { + t.Errorf("recover: %v", r) + } + }() + new(stubHolder).FailNow() + }) + + t.Run("SkipNow", func(t *testing.T) { + defer func() { + want := "invalid call to SkipNow" + if r := recover(); r != want { + t.Errorf("recover: %v, want %v", r, want) + } + }() + new(stubHolder).SkipNow() + }) + + t.Run("Skip", func(t *testing.T) { + defer func() { + want := "invalid call to Skip" + if r := recover(); r != want { + t.Errorf("recover: %v, want %v", r, want) + } + }() + new(stubHolder).Skip() + }) + + t.Run("Skipf", func(t *testing.T) { + defer func() { + want := "invalid call to Skipf" + if r := recover(); r != want { + t.Errorf("recover: %v, want %v", r, want) + } + }() + new(stubHolder).Skipf("") + }) + }) + t.Run("new", func(t *testing.T) { t.Run("success", func(t *testing.T) { s := New(t, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{Calls: []Call{ @@ -82,12 +113,12 @@ func TestStub(t *testing.T) { t.Run("overrun", func(t *testing.T) { ot := &overrideT{T: t} - ot.fatal.Store(checkFatal(t, "New: track overrun")) + ot.error.Store(checkError(t, "New: track overrun")) s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{Calls: []Call{ {"New", ExpectArgs{}, nil, nil}, {"panic", ExpectArgs{"unreachable"}, nil, nil}, }}) - func() { defer HandleExit(); s.New(func(k stubHolder) { panic("unreachable") }) }() + func() { defer s.HandleExit(); s.New(func(k stubHolder) { panic("unreachable") }) }() var visit int s.VisitIncomplete(func(s *Stub[stubHolder]) { @@ -106,38 +137,38 @@ func TestStub(t *testing.T) { t.Run("expects", func(t *testing.T) { t.Run("overrun", func(t *testing.T) { ot := &overrideT{T: t} - ot.fatal.Store(checkFatal(t, "Expects: advancing beyond expected calls")) + ot.error.Store(checkError(t, "Expects: advancing beyond expected calls")) s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{}) - func() { defer HandleExit(); s.Expects("unreachable") }() + func() { defer s.HandleExit(); s.Expects("unreachable") }() }) t.Run("separator", func(t *testing.T) { t.Run("overrun", func(t *testing.T) { ot := &overrideT{T: t} - ot.fatalf.Store(checkFatalf(t, "Expects: func = %s, separator overrun", "meow")) + ot.errorf.Store(checkErrorf(t, "Expects: func = %s, separator overrun", "meow")) s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{Calls: []Call{ {CallSeparator, ExpectArgs{}, nil, nil}, }}) - func() { defer HandleExit(); s.Expects("meow") }() + func() { defer s.HandleExit(); s.Expects("meow") }() }) t.Run("mismatch", func(t *testing.T) { ot := &overrideT{T: t} - ot.fatalf.Store(checkFatalf(t, "Expects: separator, want %s", "panic")) + ot.errorf.Store(checkErrorf(t, "Expects: separator, want %s", "panic")) s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{Calls: []Call{ {"panic", ExpectArgs{}, nil, nil}, }}) - func() { defer HandleExit(); s.Expects(CallSeparator) }() + func() { defer s.HandleExit(); s.Expects(CallSeparator) }() }) }) t.Run("mismatch", func(t *testing.T) { ot := &overrideT{T: t} - ot.fatalf.Store(checkFatalf(t, "Expects: func = %s, want %s", "meow", "nya")) + ot.errorf.Store(checkErrorf(t, "Expects: func = %s, want %s", "meow", "nya")) s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{Calls: []Call{ {"nya", ExpectArgs{}, nil, nil}, }}) - func() { defer HandleExit(); s.Expects("meow") }() + func() { defer s.HandleExit(); s.Expects("meow") }() }) }) }) @@ -167,9 +198,9 @@ func TestCheckArg(t *testing.T) { } }) t.Run("mismatch", func(t *testing.T) { - defer HandleExit() + defer s.HandleExit() s.Expects("meow") - ot.errorf.Store(checkFatalf(t, "%s: %s = %#v, want %#v (%d)", "meow", "time", 0, -1, 1)) + ot.errorf.Store(checkErrorf(t, "%s: %s = %#v, want %#v (%d)", "meow", "time", 0, -1, 1)) if CheckArg(s, "time", 0, 0) { t.Errorf("CheckArg: unexpected true") } @@ -210,9 +241,9 @@ func TestCheckArgReflect(t *testing.T) { } }) t.Run("mismatch", func(t *testing.T) { - defer HandleExit() + defer s.HandleExit() s.Expects("meow") - ot.errorf.Store(checkFatalf(t, "%s: %s = %#v, want %#v (%d)", "meow", "time", 0, -1, 1)) + ot.errorf.Store(checkErrorf(t, "%s: %s = %#v, want %#v (%d)", "meow", "time", 0, -1, 1)) if CheckArgReflect(s, "time", 0, 0) { t.Errorf("CheckArgReflect: unexpected true") } @@ -229,35 +260,35 @@ func TestCheckArgReflect(t *testing.T) { }) } -func checkFatal(t *testing.T, wantArgs ...any) *func(args ...any) { +func checkError(t *testing.T, wantArgs ...any) *func(args ...any) { var called bool f := func(args ...any) { if called { - panic("invalid call to fatal") + panic("invalid call to error") } called = true if !reflect.DeepEqual(args, wantArgs) { - t.Errorf("Fatal: %#v, want %#v", args, wantArgs) + t.Errorf("Error: %#v, want %#v", args, wantArgs) } panic(PanicExit) } return &f } -func checkFatalf(t *testing.T, wantFormat string, wantArgs ...any) *func(format string, args ...any) { +func checkErrorf(t *testing.T, wantFormat string, wantArgs ...any) *func(format string, args ...any) { var called bool f := func(format string, args ...any) { if called { - panic("invalid call to fatalf") + panic("invalid call to errorf") } called = true if format != wantFormat { - t.Errorf("Fatalf: format = %q, want %q", format, wantFormat) + t.Errorf("Errorf: format = %q, want %q", format, wantFormat) } if !reflect.DeepEqual(args, wantArgs) { - t.Errorf("Fatalf: args = %#v, want %#v", args, wantArgs) + t.Errorf("Errorf: args = %#v, want %#v", args, wantArgs) } panic(PanicExit) } diff --git a/system/dispatcher_test.go b/system/dispatcher_test.go index e5b703f..7fbb75c 100644 --- a/system/dispatcher_test.go +++ b/system/dispatcher_test.go @@ -43,8 +43,8 @@ func checkOpBehaviour(t *testing.T, testCases []opBehaviourTestCase) { ec = (*Criteria)(&tc.ec) } - defer stub.HandleExit() sys, s := InternalNew(t, stub.Expect{Calls: slices.Concat(tc.apply, []stub.Call{{Name: stub.CallSeparator}}, tc.revert)}, tc.uid) + defer s.HandleExit() errApply := tc.op.apply(sys) s.Expects(stub.CallSeparator) if !reflect.DeepEqual(errApply, tc.wantErrApply) { @@ -90,8 +90,8 @@ func checkOpsBuilder(t *testing.T, fname string, testCases []opsBuilderTestCase) t.Run(tc.name, func(t *testing.T) { t.Helper() - defer stub.HandleExit() sys, s := InternalNew(t, tc.exp, tc.uid) + defer s.HandleExit() tc.f(sys) s.VisitIncomplete(func(s *stub.Stub[syscallDispatcher]) { t.Helper()