container/stub: export stub helpers
All checks were successful
Test / Create distribution (push) Successful in 33s
Test / Sandbox (push) Successful in 1m53s
Test / Hakurei (push) Successful in 3m18s
Test / Sandbox (race detector) (push) Successful in 3m40s
Test / Hpkg (push) Successful in 3m35s
Test / Hakurei (race detector) (push) Successful in 5m19s
Test / Flake checks (push) Successful in 1m39s

These are very useful in many packages containing relatively large amount of code making calls to difficult or impossible to stub functions.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
2025-08-31 23:11:25 +09:00
parent b489a3bba1
commit 49600a6f46
23 changed files with 3808 additions and 3247 deletions

37
container/stub/call.go Normal file
View File

@@ -0,0 +1,37 @@
package stub
import (
"slices"
)
// ExpectArgs is an array primarily for storing expected function arguments.
// Its actual use is defined by the implementation.
type ExpectArgs = [5]any
// An Expect stores expected calls of a goroutine.
type Expect struct {
Calls []Call
// Tracks are handed out to descendant goroutines in order.
Tracks []Expect
}
// A Call holds expected arguments of a function call and its outcome.
type Call struct {
// Name is the function Name of this call. Must be unique.
Name string
// Args are the expected arguments of this Call.
Args ExpectArgs
// Ret is the return value of this Call.
Ret any
// Err is the returned error of this Call.
Err error
}
// Error returns [Call.Err] if all arguments are true, or [ErrCheck] otherwise.
func (k *Call) Error(ok ...bool) error {
if !slices.Contains(ok, false) {
return k.Err
}
return ErrCheck
}

View File

@@ -0,0 +1,23 @@
package stub_test
import (
"reflect"
"testing"
"hakurei.app/container/stub"
)
func TestCallError(t *testing.T) {
t.Run("contains false", func(t *testing.T) {
if err := new(stub.Call).Error(true, false, true); !reflect.DeepEqual(err, stub.ErrCheck) {
t.Errorf("Error: %#v, want %#v", err, stub.ErrCheck)
}
})
t.Run("passthrough", func(t *testing.T) {
wantErr := stub.UniqueError(0xbabe)
if err := (&stub.Call{Err: wantErr}).Error(true); !reflect.DeepEqual(err, wantErr) {
t.Errorf("Error: %#v, want %#v", err, wantErr)
}
})
}

25
container/stub/errors.go Normal file
View File

@@ -0,0 +1,25 @@
package stub
import (
"errors"
"strconv"
)
var (
ErrCheck = errors.New("one or more arguments did not match")
)
// UniqueError is an error that only equivalates to other [UniqueError] with the same magic value.
type UniqueError uintptr
func (e UniqueError) Error() string {
return "unique error " + strconv.Itoa(int(e)) + " injected by the test suite"
}
func (e UniqueError) Is(target error) bool {
var u UniqueError
if !errors.As(target, &u) {
return false
}
return e == u
}

View File

@@ -0,0 +1,35 @@
package stub_test
import (
"errors"
"syscall"
"testing"
"hakurei.app/container/stub"
)
func TestUniqueError(t *testing.T) {
t.Run("format", func(t *testing.T) {
want := "unique error 2989 injected by the test suite"
if got := stub.UniqueError(0xbad).Error(); got != want {
t.Errorf("Error: %q, want %q", got, want)
}
})
t.Run("is", func(t *testing.T) {
t.Run("type", func(t *testing.T) {
if errors.Is(stub.UniqueError(0), syscall.ENOTRECOVERABLE) {
t.Error("Is: unexpected true")
}
})
t.Run("val", func(t *testing.T) {
if errors.Is(stub.UniqueError(0), stub.UniqueError(1)) {
t.Error("Is: unexpected true")
}
if !errors.Is(stub.UniqueError(0xbad), stub.UniqueError(0xbad)) {
t.Error("Is: unexpected false")
}
})
})
}

15
container/stub/exit.go Normal file
View File

@@ -0,0 +1,15 @@
package stub
// PanicExit is a magic panic value treated as a simulated exit.
const PanicExit = 0xdeadbeef
// HandleExit must be deferred before calling with the stub.
func HandleExit() {
r := recover()
if r == PanicExit {
return
}
if r != nil {
panic(r)
}
}

View File

@@ -0,0 +1,30 @@
package stub_test
import (
"testing"
"hakurei.app/container/stub"
)
func TestHandleExit(t *testing.T) {
t.Run("exit", func(t *testing.T) {
defer stub.HandleExit()
panic(stub.PanicExit)
})
t.Run("nil", func(t *testing.T) {
defer stub.HandleExit()
})
t.Run("passthrough", func(t *testing.T) {
defer func() {
want := 0xcafebabe
if r := recover(); r != want {
t.Errorf("recover: %v, want %v", r, want)
}
}()
defer stub.HandleExit()
panic(0xcafebabe)
})
}

141
container/stub/stub.go Normal file
View File

@@ -0,0 +1,141 @@
// Package stub provides function call level stubbing and validation
// for library functions that are impossible to check otherwise.
package stub
import (
"reflect"
"sync"
"testing"
)
// this should prevent stub from being inadvertently imported outside tests
var _ = func() {
if !testing.Testing() {
panic("stub imported while not in a test")
}
}
const (
// A CallSeparator denotes an injected separation between two groups of calls.
CallSeparator = "\x00"
)
// A Stub is a collection of tracks of expected calls.
type Stub[K any] struct {
testing.TB
// makeK creates a new K for a descendant [Stub].
// This function may be called concurrently.
makeK func(s *Stub[K]) K
// want is a hierarchy of expected calls.
want Expect
// pos is the current position in [Expect.Calls].
pos int
// goroutine counts the number of goroutines created by this [Stub].
goroutine int
// sub stores the addresses of descendant [Stub] created by New.
sub []*Stub[K]
// wg waits for all descendants to complete.
wg *sync.WaitGroup
}
// New creates a root [Stub].
func New[K any](tb testing.TB, makeK func(s *Stub[K]) K, want Expect) *Stub[K] {
return &Stub[K]{TB: tb, makeK: makeK, want: want, wg: new(sync.WaitGroup)}
}
// New calls f in a new goroutine
func (s *Stub[K]) New(f func(k K)) {
s.Helper()
s.Expects("New")
if len(s.want.Tracks) <= s.goroutine {
s.Fatal("New: track overrun")
}
ds := &Stub[K]{TB: s.TB, makeK: s.makeK, want: s.want.Tracks[s.goroutine], wg: s.wg}
s.goroutine++
s.sub = append(s.sub, ds)
s.wg.Add(1)
go func() {
s.Helper()
defer s.wg.Done()
defer HandleExit()
f(s.makeK(ds))
}()
}
// Pos returns the current position of [Stub] in its [Expect.Calls]
func (s *Stub[K]) Pos() int { return s.pos }
// Len returns the length of [Expect.Calls].
func (s *Stub[K]) Len() int { return len(s.want.Calls) }
// VisitIncomplete calls f on an incomplete s and all its descendants.
func (s *Stub[K]) VisitIncomplete(f func(s *Stub[K])) {
s.Helper()
s.wg.Wait()
if s.want.Calls != nil && len(s.want.Calls) != s.pos {
f(s)
}
for _, ds := range s.sub {
ds.VisitIncomplete(f)
}
}
// Expects checks the name of and returns the current [Call] and advances pos.
func (s *Stub[K]) Expects(name string) (expect *Call) {
s.Helper()
if len(s.want.Calls) == s.pos {
s.Fatal("Expects: advancing beyond expected calls")
}
expect = &s.want.Calls[s.pos]
if name != expect.Name {
if expect.Name == CallSeparator {
s.Fatalf("Expects: func = %s, separator overrun", name)
}
if name == CallSeparator {
s.Fatalf("Expects: separator, want %s", expect.Name)
}
s.Fatalf("Expects: func = %s, want %s", name, expect.Name)
}
s.pos++
return
}
// CheckArg checks an argument comparable with the == operator. Avoid using this with pointers.
func CheckArg[T comparable, K any](s *Stub[K], arg string, got T, n int) bool {
s.Helper()
pos := s.pos - 1
if pos < 0 || pos >= len(s.want.Calls) {
panic("invalid call to CheckArg")
}
expect := s.want.Calls[pos]
want, ok := expect.Args[n].(T)
if !ok || got != want {
s.Errorf("%s: %s = %#v, want %#v (%d)", expect.Name, arg, got, want, pos)
return false
}
return true
}
// CheckArgReflect checks an argument of any type.
func CheckArgReflect[K any](s *Stub[K], arg string, got any, n int) bool {
s.Helper()
pos := s.pos - 1
if pos < 0 || pos >= len(s.want.Calls) {
panic("invalid call to CheckArgReflect")
}
expect := s.want.Calls[pos]
want := expect.Args[n]
if !reflect.DeepEqual(got, want) {
s.Errorf("%s: %s = %#v, want %#v (%d)", expect.Name, arg, got, want, pos)
return false
}
return true
}

265
container/stub/stub_test.go Normal file
View File

@@ -0,0 +1,265 @@
package stub
import (
"reflect"
"sync/atomic"
"testing"
)
// stubHolder embeds [Stub].
type stubHolder struct{ *Stub[stubHolder] }
// overrideT allows some methods of [testing.T] to be overridden.
type overrideT struct {
*testing.T
fatal atomic.Pointer[func(args ...any)]
fatalf atomic.Pointer[func(format string, args ...any)]
errorf atomic.Pointer[func(format string, args ...any)]
}
func (t *overrideT) Fatal(args ...any) {
fp := t.fatal.Load()
if fp == nil || *fp == nil {
t.T.Fatal(args...)
return
}
(*fp)(args...)
}
func (t *overrideT) Fatalf(format string, args ...any) {
fp := t.fatalf.Load()
if fp == nil || *fp == nil {
t.T.Fatalf(format, args...)
return
}
(*fp)(format, args...)
}
func (t *overrideT) Errorf(format string, args ...any) {
fp := t.errorf.Load()
if fp == nil || *fp == nil {
t.T.Errorf(format, args...)
return
}
(*fp)(format, args...)
}
func TestStub(t *testing.T) {
t.Run("new", func(t *testing.T) {
t.Run("success", func(t *testing.T) {
s := New(t, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{Calls: []Call{
{"New", ExpectArgs{}, nil, nil},
}, Tracks: []Expect{{Calls: []Call{
{"done", ExpectArgs{0xbabe}, nil, nil},
}}}})
s.New(func(k stubHolder) {
expect := k.Expects("done")
if expect.Name != "done" {
t.Errorf("New: Name = %s, want done", expect.Name)
}
if expect.Args != (ExpectArgs{0xbabe}) {
t.Errorf("New: Args = %#v", expect.Args)
}
if expect.Ret != nil {
t.Errorf("New: Ret = %#v", expect.Ret)
}
if expect.Err != nil {
t.Errorf("New: Err = %#v", expect.Err)
}
})
if pos := s.Pos(); pos != 1 {
t.Errorf("Pos: %d, want 1", pos)
}
if l := s.Len(); l != 1 {
t.Errorf("Len: %d, want 1", l)
}
s.VisitIncomplete(func(s *Stub[stubHolder]) { panic("unreachable") })
})
t.Run("overrun", func(t *testing.T) {
ot := &overrideT{T: t}
ot.fatal.Store(checkFatal(t, "New: track overrun"))
s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{Calls: []Call{
{"New", ExpectArgs{}, nil, nil},
{"panic", ExpectArgs{"unreachable"}, nil, nil},
}})
func() { defer HandleExit(); s.New(func(k stubHolder) { panic("unreachable") }) }()
var visit int
s.VisitIncomplete(func(s *Stub[stubHolder]) {
visit++
if visit > 1 {
panic("unexpected visit count")
}
want := Call{"panic", ExpectArgs{"unreachable"}, nil, nil}
if got := s.want.Calls[s.pos]; !reflect.DeepEqual(got, want) {
t.Errorf("VisitIncomplete: %#v, want %#v", got, want)
}
})
})
t.Run("expects", func(t *testing.T) {
t.Run("overrun", func(t *testing.T) {
ot := &overrideT{T: t}
ot.fatal.Store(checkFatal(t, "Expects: advancing beyond expected calls"))
s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{})
func() { defer HandleExit(); s.Expects("unreachable") }()
})
t.Run("separator", func(t *testing.T) {
t.Run("overrun", func(t *testing.T) {
ot := &overrideT{T: t}
ot.fatalf.Store(checkFatalf(t, "Expects: func = %s, separator overrun", "meow"))
s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{Calls: []Call{
{CallSeparator, ExpectArgs{}, nil, nil},
}})
func() { defer HandleExit(); s.Expects("meow") }()
})
t.Run("mismatch", func(t *testing.T) {
ot := &overrideT{T: t}
ot.fatalf.Store(checkFatalf(t, "Expects: separator, want %s", "panic"))
s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{Calls: []Call{
{"panic", ExpectArgs{}, nil, nil},
}})
func() { defer HandleExit(); s.Expects(CallSeparator) }()
})
})
t.Run("mismatch", func(t *testing.T) {
ot := &overrideT{T: t}
ot.fatalf.Store(checkFatalf(t, "Expects: func = %s, want %s", "meow", "nya"))
s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{Calls: []Call{
{"nya", ExpectArgs{}, nil, nil},
}})
func() { defer HandleExit(); s.Expects("meow") }()
})
})
})
}
func TestCheckArg(t *testing.T) {
t.Run("oob negative", func(t *testing.T) {
defer func() {
want := "invalid call to CheckArg"
if r := recover(); r != want {
t.Errorf("recover: %v, want %v", r, want)
}
}()
s := New(t, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{})
CheckArg(s, "unreachable", struct{}{}, 0)
})
ot := &overrideT{T: t}
s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{Calls: []Call{
{"panic", ExpectArgs{PanicExit}, nil, nil},
{"meow", ExpectArgs{-1}, nil, nil},
}})
t.Run("match", func(t *testing.T) {
s.Expects("panic")
if !CheckArg(s, "v", PanicExit, 0) {
t.Errorf("CheckArg: unexpected false")
}
})
t.Run("mismatch", func(t *testing.T) {
defer HandleExit()
s.Expects("meow")
ot.errorf.Store(checkFatalf(t, "%s: %s = %#v, want %#v (%d)", "meow", "time", 0, -1, 1))
if CheckArg(s, "time", 0, 0) {
t.Errorf("CheckArg: unexpected true")
}
})
t.Run("oob", func(t *testing.T) {
s.pos++
defer func() {
want := "invalid call to CheckArg"
if r := recover(); r != want {
t.Errorf("recover: %v, want %v", r, want)
}
}()
CheckArg(s, "unreachable", struct{}{}, 0)
})
}
func TestCheckArgReflect(t *testing.T) {
t.Run("oob lower", func(t *testing.T) {
defer func() {
want := "invalid call to CheckArgReflect"
if r := recover(); r != want {
t.Errorf("recover: %v, want %v", r, want)
}
}()
s := New(t, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{})
CheckArgReflect(s, "unreachable", struct{}{}, 0)
})
ot := &overrideT{T: t}
s := New(ot, func(s *Stub[stubHolder]) stubHolder { return stubHolder{s} }, Expect{Calls: []Call{
{"panic", ExpectArgs{PanicExit}, nil, nil},
{"meow", ExpectArgs{-1}, nil, nil},
}})
t.Run("match", func(t *testing.T) {
s.Expects("panic")
if !CheckArgReflect(s, "v", PanicExit, 0) {
t.Errorf("CheckArgReflect: unexpected false")
}
})
t.Run("mismatch", func(t *testing.T) {
defer HandleExit()
s.Expects("meow")
ot.errorf.Store(checkFatalf(t, "%s: %s = %#v, want %#v (%d)", "meow", "time", 0, -1, 1))
if CheckArgReflect(s, "time", 0, 0) {
t.Errorf("CheckArgReflect: unexpected true")
}
})
t.Run("oob", func(t *testing.T) {
s.pos++
defer func() {
want := "invalid call to CheckArgReflect"
if r := recover(); r != want {
t.Errorf("recover: %v, want %v", r, want)
}
}()
CheckArgReflect(s, "unreachable", struct{}{}, 0)
})
}
func checkFatal(t *testing.T, wantArgs ...any) *func(args ...any) {
var called bool
f := func(args ...any) {
if called {
panic("invalid call to fatal")
}
called = true
if !reflect.DeepEqual(args, wantArgs) {
t.Errorf("Fatal: %#v, want %#v", args, wantArgs)
}
panic(PanicExit)
}
return &f
}
func checkFatalf(t *testing.T, wantFormat string, wantArgs ...any) *func(format string, args ...any) {
var called bool
f := func(format string, args ...any) {
if called {
panic("invalid call to fatalf")
}
called = true
if format != wantFormat {
t.Errorf("Fatalf: format = %q, want %q", format, wantFormat)
}
if !reflect.DeepEqual(args, wantArgs) {
t.Errorf("Fatalf: args = %#v, want %#v", args, wantArgs)
}
panic(PanicExit)
}
return &f
}