container/stub: export stub helpers
All checks were successful
Test / Create distribution (push) Successful in 33s
Test / Sandbox (push) Successful in 1m53s
Test / Hakurei (push) Successful in 3m18s
Test / Sandbox (race detector) (push) Successful in 3m40s
Test / Hpkg (push) Successful in 3m35s
Test / Hakurei (race detector) (push) Successful in 5m19s
Test / Flake checks (push) Successful in 1m39s
All checks were successful
Test / Create distribution (push) Successful in 33s
Test / Sandbox (push) Successful in 1m53s
Test / Hakurei (push) Successful in 3m18s
Test / Sandbox (race detector) (push) Successful in 3m40s
Test / Hpkg (push) Successful in 3m35s
Test / Hakurei (race detector) (push) Successful in 5m19s
Test / Flake checks (push) Successful in 1m39s
These are very useful in many packages containing relatively large amount of code making calls to difficult or impossible to stub functions. Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
141
container/stub/stub.go
Normal file
141
container/stub/stub.go
Normal file
@@ -0,0 +1,141 @@
|
||||
// Package stub provides function call level stubbing and validation
|
||||
// for library functions that are impossible to check otherwise.
|
||||
package stub
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// this should prevent stub from being inadvertently imported outside tests
|
||||
var _ = func() {
|
||||
if !testing.Testing() {
|
||||
panic("stub imported while not in a test")
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// A CallSeparator denotes an injected separation between two groups of calls.
|
||||
CallSeparator = "\x00"
|
||||
)
|
||||
|
||||
// A Stub is a collection of tracks of expected calls.
|
||||
type Stub[K any] struct {
|
||||
testing.TB
|
||||
|
||||
// makeK creates a new K for a descendant [Stub].
|
||||
// This function may be called concurrently.
|
||||
makeK func(s *Stub[K]) K
|
||||
|
||||
// want is a hierarchy of expected calls.
|
||||
want Expect
|
||||
// pos is the current position in [Expect.Calls].
|
||||
pos int
|
||||
// goroutine counts the number of goroutines created by this [Stub].
|
||||
goroutine int
|
||||
// sub stores the addresses of descendant [Stub] created by New.
|
||||
sub []*Stub[K]
|
||||
// wg waits for all descendants to complete.
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
// New creates a root [Stub].
|
||||
func New[K any](tb testing.TB, makeK func(s *Stub[K]) K, want Expect) *Stub[K] {
|
||||
return &Stub[K]{TB: tb, makeK: makeK, want: want, wg: new(sync.WaitGroup)}
|
||||
}
|
||||
|
||||
// New calls f in a new goroutine
|
||||
func (s *Stub[K]) New(f func(k K)) {
|
||||
s.Helper()
|
||||
|
||||
s.Expects("New")
|
||||
if len(s.want.Tracks) <= s.goroutine {
|
||||
s.Fatal("New: track overrun")
|
||||
}
|
||||
ds := &Stub[K]{TB: s.TB, makeK: s.makeK, want: s.want.Tracks[s.goroutine], wg: s.wg}
|
||||
s.goroutine++
|
||||
s.sub = append(s.sub, ds)
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
s.Helper()
|
||||
|
||||
defer s.wg.Done()
|
||||
defer HandleExit()
|
||||
f(s.makeK(ds))
|
||||
}()
|
||||
}
|
||||
|
||||
// Pos returns the current position of [Stub] in its [Expect.Calls]
|
||||
func (s *Stub[K]) Pos() int { return s.pos }
|
||||
|
||||
// Len returns the length of [Expect.Calls].
|
||||
func (s *Stub[K]) Len() int { return len(s.want.Calls) }
|
||||
|
||||
// VisitIncomplete calls f on an incomplete s and all its descendants.
|
||||
func (s *Stub[K]) VisitIncomplete(f func(s *Stub[K])) {
|
||||
s.Helper()
|
||||
s.wg.Wait()
|
||||
|
||||
if s.want.Calls != nil && len(s.want.Calls) != s.pos {
|
||||
f(s)
|
||||
}
|
||||
for _, ds := range s.sub {
|
||||
ds.VisitIncomplete(f)
|
||||
}
|
||||
}
|
||||
|
||||
// Expects checks the name of and returns the current [Call] and advances pos.
|
||||
func (s *Stub[K]) Expects(name string) (expect *Call) {
|
||||
s.Helper()
|
||||
|
||||
if len(s.want.Calls) == s.pos {
|
||||
s.Fatal("Expects: advancing beyond expected calls")
|
||||
}
|
||||
expect = &s.want.Calls[s.pos]
|
||||
if name != expect.Name {
|
||||
if expect.Name == CallSeparator {
|
||||
s.Fatalf("Expects: func = %s, separator overrun", name)
|
||||
}
|
||||
if name == CallSeparator {
|
||||
s.Fatalf("Expects: separator, want %s", expect.Name)
|
||||
}
|
||||
s.Fatalf("Expects: func = %s, want %s", name, expect.Name)
|
||||
}
|
||||
s.pos++
|
||||
return
|
||||
}
|
||||
|
||||
// CheckArg checks an argument comparable with the == operator. Avoid using this with pointers.
|
||||
func CheckArg[T comparable, K any](s *Stub[K], arg string, got T, n int) bool {
|
||||
s.Helper()
|
||||
|
||||
pos := s.pos - 1
|
||||
if pos < 0 || pos >= len(s.want.Calls) {
|
||||
panic("invalid call to CheckArg")
|
||||
}
|
||||
expect := s.want.Calls[pos]
|
||||
want, ok := expect.Args[n].(T)
|
||||
if !ok || got != want {
|
||||
s.Errorf("%s: %s = %#v, want %#v (%d)", expect.Name, arg, got, want, pos)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// CheckArgReflect checks an argument of any type.
|
||||
func CheckArgReflect[K any](s *Stub[K], arg string, got any, n int) bool {
|
||||
s.Helper()
|
||||
|
||||
pos := s.pos - 1
|
||||
if pos < 0 || pos >= len(s.want.Calls) {
|
||||
panic("invalid call to CheckArgReflect")
|
||||
}
|
||||
expect := s.want.Calls[pos]
|
||||
want := expect.Args[n]
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
s.Errorf("%s: %s = %#v, want %#v (%d)", expect.Name, arg, got, want, pos)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user