message: relocate from container
All checks were successful
Test / Create distribution (push) Successful in 35s
Test / Sandbox (push) Successful in 2m22s
Test / Hpkg (push) Successful in 4m2s
Test / Sandbox (race detector) (push) Successful in 4m28s
Test / Hakurei (race detector) (push) Successful in 5m21s
Test / Hakurei (push) Successful in 2m9s
Test / Flake checks (push) Successful in 1m29s

This package is quite useful. This change allows it to be imported without importing container.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
2025-10-09 05:04:08 +09:00
parent df9b77b077
commit 87b5c30ef6
47 changed files with 210 additions and 185 deletions

View File

@@ -6,6 +6,7 @@ import (
"hakurei.app/container/check"
"hakurei.app/container/fhs"
"hakurei.app/message"
)
func init() { gob.Register(new(AutoRootOp)) }
@@ -81,7 +82,7 @@ func (r *AutoRootOp) String() string {
}
// IsAutoRootBindable returns whether a dir entry name is selected for AutoRoot.
func IsAutoRootBindable(msg Msg, name string) bool {
func IsAutoRootBindable(msg message.Msg, name string) bool {
switch name {
case "proc", "dev", "tmp", "mnt", "etc":

View File

@@ -8,6 +8,7 @@ import (
"hakurei.app/container/bits"
"hakurei.app/container/check"
"hakurei.app/container/stub"
"hakurei.app/message"
)
func TestAutoRootOp(t *testing.T) {
@@ -195,7 +196,7 @@ func TestIsAutoRootBindable(t *testing.T) {
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var msg Msg
var msg message.Msg
if tc.log {
msg = &kstub{nil, stub.New(t, func(s *stub.Stub[syscallDispatcher]) syscallDispatcher { panic("unreachable") }, stub.Expect{Calls: []stub.Call{
call("verbose", stub.ExpectArgs{[]any{"got unexpected root entry"}}, nil, nil),

View File

@@ -18,6 +18,7 @@ import (
"hakurei.app/container/check"
"hakurei.app/container/fhs"
"hakurei.app/container/seccomp"
"hakurei.app/message"
)
const (
@@ -52,7 +53,7 @@ type (
cmd *exec.Cmd
ctx context.Context
msg Msg
msg message.Msg
Params
}
@@ -396,9 +397,9 @@ func (p *Container) ProcessState() *os.ProcessState {
}
// New returns the address to a new instance of [Container] that requires further initialisation before use.
func New(ctx context.Context, msg Msg) *Container {
func New(ctx context.Context, msg message.Msg) *Container {
if msg == nil {
msg = NewMsg(nil)
msg = message.NewMsg(nil)
}
p := &Container{ctx: ctx, msg: msg, Params: Params{Ops: new(Ops)}}
@@ -409,7 +410,7 @@ func New(ctx context.Context, msg Msg) *Container {
}
// NewCommand calls [New] and initialises the [Params.Path] and [Params.Args] fields.
func NewCommand(ctx context.Context, msg Msg, pathname *check.Absolute, name string, args ...string) *Container {
func NewCommand(ctx context.Context, msg message.Msg, pathname *check.Absolute, name string, args ...string) *Container {
z := New(ctx, msg)
z.Path = pathname
z.Args = append([]string{name}, args...)

View File

@@ -26,6 +26,7 @@ import (
"hakurei.app/container/vfs"
"hakurei.app/hst"
"hakurei.app/ldd"
"hakurei.app/message"
)
func TestStartError(t *testing.T) {
@@ -152,13 +153,13 @@ func TestStartError(t *testing.T) {
})
t.Run("msg", func(t *testing.T) {
if got, ok := container.GetErrorMessage(tc.err); !ok {
if got, ok := message.GetMessage(tc.err); !ok {
if tc.msg != "" {
t.Errorf("GetErrorMessage: err does not implement MessageError")
t.Errorf("GetMessage: err does not implement MessageError")
}
return
} else if got != tc.msg {
t.Errorf("GetErrorMessage: %q, want %q", got, tc.msg)
t.Errorf("GetMessage: %q, want %q", got, tc.msg)
}
})
})
@@ -545,7 +546,7 @@ func testContainerCancel(
}
func TestContainerString(t *testing.T) {
msg := container.NewMsg(nil)
msg := message.NewMsg(nil)
c := container.NewCommand(t.Context(), msg, check.MustAbs("/run/current-system/sw/bin/ldd"), "ldd", "/usr/bin/env")
c.SeccompFlags |= seccomp.AllowMultiarch
c.SeccompRules = seccomp.Preset(
@@ -711,7 +712,7 @@ func TestMain(m *testing.M) {
}
func helperNewContainerLibPaths(ctx context.Context, libPaths *[]*check.Absolute, args ...string) (c *container.Container) {
msg := container.NewMsg(nil)
msg := message.NewMsg(nil)
c = container.NewCommand(ctx, msg, absHelperInnerPath, "helper", args...)
c.Env = append(c.Env, envDoCheck+"=1")
c.Bind(check.MustAbs(os.Args[0]), absHelperInnerPath, 0)

View File

@@ -11,6 +11,7 @@ import (
"syscall"
"hakurei.app/container/seccomp"
"hakurei.app/message"
)
type osFile interface {
@@ -37,7 +38,7 @@ type syscallDispatcher interface {
setNoNewPrivs() error
// lastcap provides [LastCap].
lastcap(msg Msg) uintptr
lastcap(msg message.Msg) uintptr
// capset provides capset.
capset(hdrp *capHeader, datap *[2]capData) error
// capBoundingSetDrop provides capBoundingSetDrop.
@@ -52,9 +53,9 @@ type syscallDispatcher interface {
receive(key string, e any, fdp *uintptr) (closeFunc func() error, err error)
// bindMount provides procPaths.bindMount.
bindMount(msg Msg, source, target string, flags uintptr) error
bindMount(msg message.Msg, source, target string, flags uintptr) error
// remount provides procPaths.remount.
remount(msg Msg, target string, flags uintptr) error
remount(msg message.Msg, target string, flags uintptr) error
// mountTmpfs provides mountTmpfs.
mountTmpfs(fsname, target string, flags uintptr, size int, perm os.FileMode) error
// ensureFile provides ensureFile.
@@ -122,11 +123,11 @@ type syscallDispatcher interface {
wait4(pid int, wstatus *syscall.WaitStatus, options int, rusage *syscall.Rusage) (wpid int, err error)
// printf provides the Printf method of [log.Logger].
printf(msg Msg, format string, v ...any)
printf(msg message.Msg, format string, v ...any)
// fatal provides the Fatal method of [log.Logger]
fatal(msg Msg, v ...any)
fatal(msg message.Msg, v ...any)
// fatalf provides the Fatalf method of [log.Logger]
fatalf(msg Msg, format string, v ...any)
fatalf(msg message.Msg, format string, v ...any)
}
// direct implements syscallDispatcher on the current kernel.
@@ -140,7 +141,7 @@ func (direct) setPtracer(pid uintptr) error { return SetPtracer(pid) }
func (direct) setDumpable(dumpable uintptr) error { return SetDumpable(dumpable) }
func (direct) setNoNewPrivs() error { return SetNoNewPrivs() }
func (direct) lastcap(msg Msg) uintptr { return LastCap(msg) }
func (direct) lastcap(msg message.Msg) uintptr { return LastCap(msg) }
func (direct) capset(hdrp *capHeader, datap *[2]capData) error { return capset(hdrp, datap) }
func (direct) capBoundingSetDrop(cap uintptr) error { return capBoundingSetDrop(cap) }
func (direct) capAmbientClearAll() error { return capAmbientClearAll() }
@@ -150,10 +151,10 @@ func (direct) receive(key string, e any, fdp *uintptr) (func() error, error) {
return Receive(key, e, fdp)
}
func (direct) bindMount(msg Msg, source, target string, flags uintptr) error {
func (direct) bindMount(msg message.Msg, source, target string, flags uintptr) error {
return hostProc.bindMount(msg, source, target, flags)
}
func (direct) remount(msg Msg, target string, flags uintptr) error {
func (direct) remount(msg message.Msg, target string, flags uintptr) error {
return hostProc.remount(msg, target, flags)
}
func (k direct) mountTmpfs(fsname, target string, flags uintptr, size int, perm os.FileMode) error {
@@ -221,6 +222,6 @@ func (direct) wait4(pid int, wstatus *syscall.WaitStatus, options int, rusage *s
return syscall.Wait4(pid, wstatus, options, rusage)
}
func (direct) printf(msg Msg, format string, v ...any) { msg.GetLogger().Printf(format, v...) }
func (direct) fatal(msg Msg, v ...any) { msg.GetLogger().Fatal(v...) }
func (direct) fatalf(msg Msg, format string, v ...any) { msg.GetLogger().Fatalf(format, v...) }
func (direct) printf(msg message.Msg, format string, v ...any) { msg.GetLogger().Printf(format, v...) }
func (direct) fatal(msg message.Msg, v ...any) { msg.GetLogger().Fatal(v...) }
func (direct) fatalf(msg message.Msg, format string, v ...any) { msg.GetLogger().Fatalf(format, v...) }

View File

@@ -18,6 +18,7 @@ import (
"hakurei.app/container/seccomp"
"hakurei.app/container/stub"
"hakurei.app/message"
)
type opValidTestCase struct {
@@ -329,7 +330,7 @@ func (k *kstub) setDumpable(dumpable uintptr) error {
}
func (k *kstub) setNoNewPrivs() error { k.Helper(); return k.Expects("setNoNewPrivs").Err }
func (k *kstub) lastcap(msg Msg) uintptr {
func (k *kstub) lastcap(msg message.Msg) uintptr {
k.Helper()
k.checkMsg(msg)
return k.Expects("lastcap").Ret.(uintptr)
@@ -409,7 +410,7 @@ func (k *kstub) receive(key string, e any, fdp *uintptr) (closeFunc func() error
return
}
func (k *kstub) bindMount(msg Msg, source, target string, flags uintptr) error {
func (k *kstub) bindMount(msg message.Msg, source, target string, flags uintptr) error {
k.Helper()
k.checkMsg(msg)
return k.Expects("bindMount").Error(
@@ -418,7 +419,7 @@ func (k *kstub) bindMount(msg Msg, source, target string, flags uintptr) error {
stub.CheckArg(k.Stub, "flags", flags, 2))
}
func (k *kstub) remount(msg Msg, target string, flags uintptr) error {
func (k *kstub) remount(msg message.Msg, target string, flags uintptr) error {
k.Helper()
k.checkMsg(msg)
return k.Expects("remount").Error(
@@ -702,7 +703,7 @@ func (k *kstub) wait4(pid int, wstatus *syscall.WaitStatus, options int, rusage
return
}
func (k *kstub) printf(_ Msg, format string, v ...any) {
func (k *kstub) printf(_ message.Msg, format string, v ...any) {
k.Helper()
if k.Expects("printf").Error(
stub.CheckArg(k.Stub, "format", format, 0),
@@ -711,7 +712,7 @@ func (k *kstub) printf(_ Msg, format string, v ...any) {
}
}
func (k *kstub) fatal(_ Msg, v ...any) {
func (k *kstub) fatal(_ message.Msg, v ...any) {
k.Helper()
if k.Expects("fatal").Error(
stub.CheckArgReflect(k.Stub, "v", v, 0)) != nil {
@@ -720,7 +721,7 @@ func (k *kstub) fatal(_ Msg, v ...any) {
panic(stub.PanicExit)
}
func (k *kstub) fatalf(_ Msg, format string, v ...any) {
func (k *kstub) fatalf(_ message.Msg, format string, v ...any) {
k.Helper()
if k.Expects("fatalf").Error(
stub.CheckArg(k.Stub, "format", format, 0),
@@ -730,7 +731,7 @@ func (k *kstub) fatalf(_ Msg, format string, v ...any) {
panic(stub.PanicExit)
}
func (k *kstub) checkMsg(msg Msg) {
func (k *kstub) checkMsg(msg message.Msg) {
k.Helper()
var target *kstub

View File

@@ -3,6 +3,8 @@ package container
import (
"os"
"sync"
"hakurei.app/message"
)
var (
@@ -10,7 +12,7 @@ var (
executableOnce sync.Once
)
func copyExecutable(msg Msg) {
func copyExecutable(msg message.Msg) {
if name, err := os.Executable(); err != nil {
msg.BeforeExit()
msg.GetLogger().Fatalf("cannot read executable path: %v", err)
@@ -19,7 +21,7 @@ func copyExecutable(msg Msg) {
}
}
func MustExecutable(msg Msg) string {
func MustExecutable(msg message.Msg) string {
executableOnce.Do(func() { copyExecutable(msg) })
return executable
}

View File

@@ -5,11 +5,12 @@ import (
"testing"
"hakurei.app/container"
"hakurei.app/message"
)
func TestExecutable(t *testing.T) {
for i := 0; i < 16; i++ {
if got := container.MustExecutable(container.NewMsg(nil)); got != os.Args[0] {
if got := container.MustExecutable(message.NewMsg(nil)); got != os.Args[0] {
t.Errorf("MustExecutable: %q, want %q",
got, os.Args[0])
}

View File

@@ -14,6 +14,7 @@ import (
"hakurei.app/container/fhs"
"hakurei.app/container/seccomp"
"hakurei.app/message"
)
const (
@@ -61,7 +62,7 @@ type (
setupState struct {
nonrepeatable uintptr
*Params
Msg
message.Msg
}
)
@@ -95,14 +96,14 @@ type initParams struct {
}
// Init is called by [TryArgv0] if the current process is the container init.
func Init(msg Msg) {
func Init(msg message.Msg) {
if msg == nil {
panic("attempting to call initEntrypoint with nil msg")
}
initEntrypoint(direct{}, msg)
}
func initEntrypoint(k syscallDispatcher, msg Msg) {
func initEntrypoint(k syscallDispatcher, msg message.Msg) {
k.lockOSThread()
if k.getpid() != 1 {
@@ -125,7 +126,7 @@ func initEntrypoint(k syscallDispatcher, msg Msg) {
k.fatal(msg, "invalid setup descriptor")
}
if errors.Is(err, ErrReceiveEnv) {
k.fatal(msg, "HAKUREI_SETUP not set")
k.fatal(msg, setupEnv+" not set")
}
k.fatalf(msg, "cannot decode init setup payload: %v", err)
@@ -448,11 +449,11 @@ const initName = "init"
// TryArgv0 calls [Init] if the last element of argv0 is "init".
// If a nil msg is passed, the system logger is used instead.
func TryArgv0(msg Msg) {
func TryArgv0(msg message.Msg) {
if msg == nil {
log.SetPrefix(initName + ": ")
log.SetFlags(0)
msg = NewMsg(log.Default())
msg = message.NewMsg(log.Default())
}
if len(os.Args) > 0 && path.Base(os.Args[0]) == initName {

View File

@@ -7,6 +7,7 @@ import (
. "syscall"
"hakurei.app/container/vfs"
"hakurei.app/message"
)
/*
@@ -87,7 +88,7 @@ const (
)
// bindMount mounts source on target and recursively applies flags if MS_REC is set.
func (p *procPaths) bindMount(msg Msg, source, target string, flags uintptr) error {
func (p *procPaths) bindMount(msg message.Msg, source, target string, flags uintptr) error {
// syscallDispatcher.bindMount and procPaths.remount must not be called from this function
if err := p.k.mount(source, target, FstypeNULL, MS_SILENT|MS_BIND|flags&MS_REC, zeroString); err != nil {
@@ -97,7 +98,7 @@ func (p *procPaths) bindMount(msg Msg, source, target string, flags uintptr) err
}
// remount applies flags on target, recursively if MS_REC is set.
func (p *procPaths) remount(msg Msg, target string, flags uintptr) error {
func (p *procPaths) remount(msg message.Msg, target string, flags uintptr) error {
// syscallDispatcher methods bindMount, remount must not be called from this function
var targetFinal string
@@ -159,7 +160,7 @@ func (p *procPaths) remount(msg Msg, target string, flags uintptr) error {
}
// remountWithFlags remounts mount point described by [vfs.MountInfoNode].
func remountWithFlags(k syscallDispatcher, msg Msg, n *vfs.MountInfoNode, mf uintptr) error {
func remountWithFlags(k syscallDispatcher, msg message.Msg, n *vfs.MountInfoNode, mf uintptr) error {
// syscallDispatcher methods bindMount, remount must not be called from this function
kf, unmatched := n.Flags()

View File

@@ -1,110 +0,0 @@
package container
import (
"errors"
"log"
"sync/atomic"
)
// MessageError is an error with a user-facing message.
type MessageError interface {
// Message returns a user-facing error message.
Message() string
error
}
// GetErrorMessage returns whether an error implements [MessageError], and the message if it does.
func GetErrorMessage(err error) (string, bool) {
var e MessageError
if !errors.As(err, &e) || e == nil {
return zeroString, false
}
return e.Message(), true
}
// Msg is used for package-wide verbose logging.
type Msg interface {
// GetLogger returns the address of the underlying [log.Logger].
GetLogger() *log.Logger
// IsVerbose atomically loads and returns whether [Msg] has verbose logging enabled.
IsVerbose() bool
// SwapVerbose atomically stores a new verbose state and returns the previous value held by [Msg].
SwapVerbose(verbose bool) bool
// Verbose passes its argument to the Println method of the underlying [log.Logger] if IsVerbose returns true.
Verbose(v ...any)
// Verbosef passes its argument to the Printf method of the underlying [log.Logger] if IsVerbose returns true.
Verbosef(format string, v ...any)
// Suspend causes the embedded [Suspendable] to withhold writes to its downstream [io.Writer].
// Suspend returns false and is a noop if called between calls to Suspend and Resume.
Suspend() bool
// Resume dumps the entire buffer held by the embedded [Suspendable] and stops withholding future writes.
// Resume returns false and is a noop if a call to Suspend does not precede it.
Resume() bool
// BeforeExit runs implementation-specific cleanup code, and optionally prints warnings.
// BeforeExit is called before [os.Exit].
BeforeExit()
}
// defaultMsg is the default implementation of the [Msg] interface.
// The zero value is not safe for use. Callers should use the [NewMsg] function instead.
type defaultMsg struct {
verbose atomic.Bool
logger *log.Logger
Suspendable
}
// NewMsg initialises a downstream [log.Logger] for a new [Msg].
// The [log.Logger] should no longer be configured after NewMsg returns.
// If downstream is nil, a new logger is initialised in its place.
func NewMsg(downstream *log.Logger) Msg {
if downstream == nil {
downstream = log.New(log.Writer(), "container: ", 0)
}
m := defaultMsg{logger: downstream}
m.Suspendable.Downstream = downstream.Writer()
downstream.SetOutput(&m.Suspendable)
return &m
}
func (msg *defaultMsg) GetLogger() *log.Logger { return msg.logger }
func (msg *defaultMsg) IsVerbose() bool { return msg.verbose.Load() }
func (msg *defaultMsg) SwapVerbose(verbose bool) bool { return msg.verbose.Swap(verbose) }
func (msg *defaultMsg) Verbose(v ...any) {
if msg.verbose.Load() {
msg.logger.Println(v...)
}
}
func (msg *defaultMsg) Verbosef(format string, v ...any) {
if msg.verbose.Load() {
msg.logger.Printf(format, v...)
}
}
// Resume calls [Suspendable.Resume] and prints a message if buffer was filled
// between calls to [Suspendable.Suspend] and Resume.
func (msg *defaultMsg) Resume() bool {
resumed, dropped, _, err := msg.Suspendable.Resume()
if err != nil {
// probably going to result in an error as well, so this message is as good as unreachable
msg.logger.Printf("cannot dump buffer on resume: %v", err)
}
if resumed && dropped > 0 {
msg.logger.Printf("dropped %d bytes while output is suspended", dropped)
}
return resumed
}
// BeforeExit prints a message if called between calls to [Suspendable.Suspend] and Resume.
func (msg *defaultMsg) BeforeExit() {
if msg.Resume() {
msg.logger.Printf("beforeExit reached on suspended output")
}
}

View File

@@ -1,177 +0,0 @@
package container_test
import (
"bytes"
"errors"
"io"
"log"
"strings"
"syscall"
"testing"
"hakurei.app/container"
"hakurei.app/container/stub"
)
func TestMessageError(t *testing.T) {
testCases := []struct {
name string
err error
want string
wantOk bool
}{
{"nil", nil, "", false},
{"new", errors.New(":3"), "", false},
{"start", &container.StartError{
Step: "meow",
Err: syscall.ENOTRECOVERABLE,
}, "cannot meow: state not recoverable", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, ok := container.GetErrorMessage(tc.err)
if got != tc.want {
t.Errorf("GetErrorMessage: %q, want %q", got, tc.want)
}
if ok != tc.wantOk {
t.Errorf("GetErrorMessage: ok = %v, want %v", ok, tc.wantOk)
}
})
}
}
func TestDefaultMsg(t *testing.T) {
// copied from output.go
const suspendBufMax = 1 << 24
t.Run("logger", func(t *testing.T) {
t.Run("nil", func(t *testing.T) {
got := container.NewMsg(nil).GetLogger()
if out := got.Writer().(*container.Suspendable).Downstream; out != log.Writer() {
t.Errorf("GetLogger: Downstream = %#v", out)
}
if prefix := got.Prefix(); prefix != "container: " {
t.Errorf("GetLogger: prefix = %q", prefix)
}
})
t.Run("takeover", func(t *testing.T) {
l := log.New(io.Discard, "\x00", 0xdeadbeef)
got := container.NewMsg(l)
if logger := got.GetLogger(); logger != l {
t.Errorf("GetLogger: %#v, want %#v", logger, l)
}
if ds := l.Writer().(*container.Suspendable).Downstream; ds != io.Discard {
t.Errorf("GetLogger: Downstream = %#v", ds)
}
})
})
dw := expectWriter{t: t}
steps := []struct {
name string
pt, next []byte
err error
f func(t *testing.T, msg container.Msg)
}{
{"zero verbose", nil, nil, nil, func(t *testing.T, msg container.Msg) {
if msg.IsVerbose() {
t.Error("IsVerbose unexpected true")
}
}},
{"swap false", nil, nil, nil, func(t *testing.T, msg container.Msg) {
if msg.SwapVerbose(false) {
t.Error("SwapVerbose unexpected true")
}
}},
{"write discard", nil, nil, nil, func(_ *testing.T, msg container.Msg) {
msg.Verbose("\x00")
msg.Verbosef("\x00")
}},
{"verbose false", nil, nil, nil, func(t *testing.T, msg container.Msg) {
if msg.IsVerbose() {
t.Error("IsVerbose unexpected true")
}
}},
{"swap true", nil, nil, nil, func(t *testing.T, msg container.Msg) {
if msg.SwapVerbose(true) {
t.Error("SwapVerbose unexpected true")
}
}},
{"write verbose", []byte("test: \x00\n"), nil, nil, func(_ *testing.T, msg container.Msg) {
msg.Verbose("\x00")
}},
{"write verbosef", []byte(`test: "\x00"` + "\n"), nil, nil, func(_ *testing.T, msg container.Msg) {
msg.Verbosef("%q", "\x00")
}},
{"verbose true", nil, nil, nil, func(t *testing.T, msg container.Msg) {
if !msg.IsVerbose() {
t.Error("IsVerbose unexpected false")
}
}},
{"resume noop", nil, nil, nil, func(t *testing.T, msg container.Msg) {
if msg.Resume() {
t.Error("Resume unexpected success")
}
}},
{"beforeExit noop", nil, nil, nil, func(_ *testing.T, msg container.Msg) {
msg.BeforeExit()
}},
{"beforeExit suspend", nil, nil, nil, func(_ *testing.T, msg container.Msg) {
msg.Suspend()
}},
{"beforeExit message", []byte("test: beforeExit reached on suspended output\n"), nil, nil, func(_ *testing.T, msg container.Msg) {
msg.BeforeExit()
}},
{"post beforeExit resume noop", nil, nil, nil, func(t *testing.T, msg container.Msg) {
if msg.Resume() {
t.Error("Resume unexpected success")
}
}},
{"suspend", nil, nil, nil, func(_ *testing.T, msg container.Msg) {
msg.Suspend()
}},
{"suspend write", nil, nil, nil, func(_ *testing.T, msg container.Msg) {
msg.GetLogger().Print("\x00")
}},
{"resume error", []byte("test: \x00\n"), []byte("test: cannot dump buffer on resume: unique error 0 injected by the test suite\n"), stub.UniqueError(0), func(t *testing.T, msg container.Msg) {
if !msg.Resume() {
t.Error("Resume unexpected failure")
}
}},
{"suspend drop", nil, nil, nil, func(_ *testing.T, msg container.Msg) {
msg.Suspend()
}},
{"suspend write fill", nil, nil, nil, func(_ *testing.T, msg container.Msg) {
msg.GetLogger().Print(strings.Repeat("\x00", suspendBufMax))
}},
{"resume dropped", append([]byte("test: "), bytes.Repeat([]byte{0}, suspendBufMax-6)...), []byte("test: dropped 7 bytes while output is suspended\n"), nil, func(t *testing.T, msg container.Msg) {
if !msg.Resume() {
t.Error("Resume unexpected failure")
}
}},
}
msg := container.NewMsg(log.New(&dw, "test: ", 0))
for _, step := range steps {
// these share the same writer, so cannot be subtests
t.Logf("running step %q", step.name)
dw.expect, dw.next, dw.err = step.pt, step.next, step.err
step.f(t, msg)
if dw.expect != nil {
t.Errorf("expect: %q", string(dw.expect))
}
}
}

View File

@@ -1,77 +0,0 @@
package container
import (
"bytes"
"io"
"sync"
"sync/atomic"
"syscall"
)
const (
suspendBufInitial = 1 << 12
suspendBufMax = 1 << 24
)
// Suspendable proxies writes to a downstream [io.Writer] but optionally withholds writes
// between calls to Suspend and Resume.
type Suspendable struct {
Downstream io.Writer
s atomic.Bool
buf bytes.Buffer
// for growing buf
bufOnce sync.Once
// for synchronising all other buf operations
bufMu sync.Mutex
dropped int
}
func (s *Suspendable) Write(p []byte) (n int, err error) {
if !s.s.Load() {
return s.Downstream.Write(p)
}
s.bufOnce.Do(func() { s.buf.Grow(suspendBufInitial) })
s.bufMu.Lock()
defer s.bufMu.Unlock()
if free := suspendBufMax - s.buf.Len(); free < len(p) {
// fast path
if free <= 0 {
s.dropped += len(p)
return 0, syscall.ENOMEM
}
n, _ = s.buf.Write(p[:free])
err = syscall.ENOMEM
s.dropped += len(p) - n
return
}
return s.buf.Write(p)
}
// IsSuspended returns whether [Suspendable] is currently between a call to Suspend and Resume.
func (s *Suspendable) IsSuspended() bool { return s.s.Load() }
// Suspend causes [Suspendable] to start withholding output in its buffer.
func (s *Suspendable) Suspend() bool { return s.s.CompareAndSwap(false, true) }
// Resume undoes the effect of Suspend and dumps the buffered into the downstream [io.Writer].
func (s *Suspendable) Resume() (resumed bool, dropped uintptr, n int64, err error) {
if s.s.CompareAndSwap(true, false) {
s.bufMu.Lock()
defer s.bufMu.Unlock()
resumed = true
dropped = uintptr(s.dropped)
s.dropped = 0
n, err = io.Copy(s.Downstream, &s.buf)
s.buf.Reset()
}
return
}

View File

@@ -1,173 +0,0 @@
package container_test
import (
"bytes"
"errors"
"reflect"
"strconv"
"syscall"
"testing"
"hakurei.app/container"
"hakurei.app/container/stub"
)
func TestSuspendable(t *testing.T) {
// copied from output.go
const suspendBufMax = 1 << 24
const (
// equivalent to len(want.pt)
nSpecialPtEquiv = -iota - 1
// equivalent to len(want.w)
nSpecialWEquiv
// suspends writer before executing test case, implies nSpecialWEquiv
nSpecialSuspend
// offset: resume writer and measure against dump instead, implies nSpecialPtEquiv
nSpecialDump
)
// shares the same writer
steps := []struct {
name string
w, pt []byte
err error
wantErr error
n int
}{
{"simple", []byte{0xde, 0xad, 0xbe, 0xef}, []byte{0xde, 0xad, 0xbe, 0xef},
nil, nil, nSpecialPtEquiv},
{"error", []byte{0xb, 0xad}, []byte{0xb, 0xad},
stub.UniqueError(0), stub.UniqueError(0), nSpecialPtEquiv},
{"suspend short", []byte{0}, nil,
nil, nil, nSpecialSuspend},
{"sw short 0", []byte{0xca, 0xfe, 0xba, 0xbe}, nil,
nil, nil, nSpecialWEquiv},
{"sw short 1", []byte{0xff}, nil,
nil, nil, nSpecialWEquiv},
{"resume short", nil, []byte{0, 0xca, 0xfe, 0xba, 0xbe, 0xff}, nil, nil,
nSpecialDump},
{"long pt", bytes.Repeat([]byte{0xff}, suspendBufMax+1), bytes.Repeat([]byte{0xff}, suspendBufMax+1),
nil, nil, nSpecialPtEquiv},
{"suspend fill", bytes.Repeat([]byte{0xfe}, suspendBufMax), nil,
nil, nil, nSpecialSuspend},
{"drop", []byte{0}, nil,
nil, syscall.ENOMEM, 0},
{"drop error", []byte{0}, nil,
stub.UniqueError(1), syscall.ENOMEM, 0},
{"resume fill", nil, bytes.Repeat([]byte{0xfe}, suspendBufMax),
nil, nil, nSpecialDump - 2},
{"suspend fill partial", bytes.Repeat([]byte{0xfd}, suspendBufMax-0xf), nil,
nil, nil, nSpecialSuspend},
{"partial write", bytes.Repeat([]byte{0xad}, 0x1f), nil,
nil, syscall.ENOMEM, 0xf},
{"full drop", []byte{0}, nil,
nil, syscall.ENOMEM, 0},
{"resume fill partial", nil, append(bytes.Repeat([]byte{0xfd}, suspendBufMax-0xf), bytes.Repeat([]byte{0xad}, 0xf)...),
nil, nil, nSpecialDump - 0x10 - 1},
}
var dw expectWriter
w := container.Suspendable{Downstream: &dw}
for _, step := range steps {
// these share the same writer, so cannot be subtests
t.Logf("writing step %q", step.name)
dw.expect, dw.err = step.pt, step.err
var (
gotN int
gotErr error
)
wantN := step.n
switch wantN {
case nSpecialPtEquiv:
wantN = len(step.pt)
gotN, gotErr = w.Write(step.w)
case nSpecialWEquiv:
wantN = len(step.w)
gotN, gotErr = w.Write(step.w)
case nSpecialSuspend:
s := w.IsSuspended()
if ok := w.Suspend(); s && ok {
t.Fatal("Suspend: unexpected success")
}
wantN = len(step.w)
gotN, gotErr = w.Write(step.w)
default:
if wantN <= nSpecialDump {
if !w.IsSuspended() {
t.Fatal("IsSuspended unexpected false")
}
resumed, dropped, n, err := w.Resume()
if !resumed {
t.Fatal("Resume: resumed = false")
}
if wantDropped := nSpecialDump - wantN; int(dropped) != wantDropped {
t.Errorf("Resume: dropped = %d, want %d", dropped, wantDropped)
}
wantN = len(step.pt)
gotN, gotErr = int(n), err
} else {
gotN, gotErr = w.Write(step.w)
}
}
if gotN != wantN {
t.Errorf("Write: n = %d, want %d", gotN, wantN)
}
if !reflect.DeepEqual(gotErr, step.wantErr) {
t.Errorf("Write: %v", gotErr)
}
if dw.expect != nil {
t.Errorf("expect: %q", string(dw.expect))
}
}
}
// expectWriter compares Write calls to expect.
type expectWriter struct {
expect []byte
err error
// optional consecutive write
next []byte
// optional, calls Error on failure if not nil
t *testing.T
}
func (w *expectWriter) Write(p []byte) (n int, err error) {
defer func() { w.expect = w.next; w.next = nil }()
n, err = len(p), w.err
if w.expect == nil {
n, err = 0, errors.New("unexpected call to Write: "+strconv.Quote(string(p)))
if w.t != nil {
w.t.Error(err.Error())
}
return
}
if string(p) != string(w.expect) {
n, err = 0, errors.New("p = "+strconv.Quote(string(p))+", want "+strconv.Quote(string(w.expect)))
if w.t != nil {
w.t.Error(err.Error())
}
return
}
return
}

View File

@@ -7,6 +7,7 @@ import (
"sync"
"hakurei.app/container/fhs"
"hakurei.app/message"
)
var (
@@ -23,7 +24,7 @@ const (
kernelCapLastCapPath = fhs.ProcSys + "kernel/cap_last_cap"
)
func mustReadSysctl(msg Msg) {
func mustReadSysctl(msg message.Msg) {
sysctlOnce.Do(func() {
if v, err := os.ReadFile(kernelOverflowuidPath); err != nil {
msg.GetLogger().Fatalf("cannot read %q: %v", kernelOverflowuidPath, err)
@@ -45,6 +46,6 @@ func mustReadSysctl(msg Msg) {
})
}
func OverflowUid(msg Msg) int { mustReadSysctl(msg); return kernelOverflowuid }
func OverflowGid(msg Msg) int { mustReadSysctl(msg); return kernelOverflowgid }
func LastCap(msg Msg) uintptr { mustReadSysctl(msg); return uintptr(kernelCapLastCap) }
func OverflowUid(msg message.Msg) int { mustReadSysctl(msg); return kernelOverflowuid }
func OverflowGid(msg message.Msg) int { mustReadSysctl(msg); return kernelOverflowgid }
func LastCap(msg message.Msg) uintptr { mustReadSysctl(msg); return uintptr(kernelCapLastCap) }