internal/app: filter ops in implementation
All checks were successful
Test / Create distribution (push) Successful in 35s
Test / Sandbox (push) Successful in 2m18s
Test / Hpkg (push) Successful in 4m1s
Test / Sandbox (race detector) (push) Successful in 4m28s
Test / Hakurei (race detector) (push) Successful in 5m19s
Test / Hakurei (push) Successful in 2m14s
Test / Flake checks (push) Successful in 1m33s

This is cleaner and less error-prone, and should also result in negligibly less memory allocation.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
Ophestra 2025-10-10 02:23:34 +09:00
parent 4246256d78
commit 22ee5ae151
Signed by: cat
SSH Key Fingerprint: SHA256:gQ67O0enBZ7UdZypgtspB2FDM1g3GVw8nX0XSdcFw8Q
9 changed files with 57 additions and 84 deletions

View File

@ -458,18 +458,15 @@ func TestApp(t *testing.T) {
Container: tc.config.Container, Container: tc.config.Container,
} }
sPriv.populateEarly(tc.k, msg, tc.config) sPriv.populateEarly(tc.k, msg)
if err := sPriv.populateLocal(tc.k, msg); err != nil { if err := sPriv.populateLocal(tc.k, msg); err != nil {
t.Fatalf("populateLocal: error = %#v", err) t.Fatalf("populateLocal: error = %#v", err)
} }
gotSys = system.New(t.Context(), msg, sPriv.uid.unwrap()) gotSys = system.New(t.Context(), msg, sPriv.uid.unwrap())
stateSys := outcomeStateSys{config: tc.config, sys: gotSys, outcomeState: &sPriv} if err := (&outcomeStateSys{config: tc.config, sys: gotSys, outcomeState: &sPriv}).toSystem(); err != nil {
for _, op := range sPriv.Shim.Ops {
if err := op.toSystem(&stateSys); err != nil {
t.Fatalf("toSystem: error = %#v", err) t.Fatalf("toSystem: error = %#v", err)
} }
}
go func() { go func() {
e := gob.NewEncoder(gw) e := gob.NewEncoder(gw)

View File

@ -75,18 +75,15 @@ func (k *outcome) finalise(ctx context.Context, msg message.Msg, id *state.ID, c
EnvPaths: copyPaths(k.syscallDispatcher), EnvPaths: copyPaths(k.syscallDispatcher),
Container: config.Container, Container: config.Container,
} }
s.populateEarly(k.syscallDispatcher, msg, config) s.populateEarly(k.syscallDispatcher, msg)
if err := s.populateLocal(k.syscallDispatcher, msg); err != nil { if err := s.populateLocal(k.syscallDispatcher, msg); err != nil {
return err return err
} }
sys := system.New(k.ctx, msg, s.uid.unwrap()) sys := system.New(k.ctx, msg, s.uid.unwrap())
stateSys := outcomeStateSys{config: config, sys: sys, outcomeState: &s} if err := (&outcomeStateSys{config: config, sys: sys, outcomeState: &s}).toSystem(); err != nil {
for _, op := range s.Shim.Ops {
if err := op.toSystem(&stateSys); err != nil {
return err return err
} }
}
k.sys = sys k.sys = sys
k.supp = supp k.supp = supp

View File

@ -1,6 +1,7 @@
package app package app
import ( import (
"errors"
"os" "os"
"strconv" "strconv"
@ -76,8 +77,8 @@ func (s *outcomeState) valid() bool {
// populateEarly populates exported fields via syscallDispatcher. // populateEarly populates exported fields via syscallDispatcher.
// This must only be called from the priv side. // This must only be called from the priv side.
func (s *outcomeState) populateEarly(k syscallDispatcher, msg message.Msg, config *hst.Config) { func (s *outcomeState) populateEarly(k syscallDispatcher, msg message.Msg) {
s.Shim = &shimParams{PrivPID: os.Getpid(), Verbose: msg.IsVerbose(), Ops: fromConfig(config)} s.Shim = &shimParams{PrivPID: os.Getpid(), Verbose: msg.IsVerbose()}
// enforce bounds and default early // enforce bounds and default early
if s.Container.WaitDelay <= 0 { if s.Container.WaitDelay <= 0 {
@ -203,6 +204,9 @@ type outcomeStateParams struct {
*outcomeState *outcomeState
} }
// errNotEnabled is returned by outcomeOp.toSystem and used internally to exclude an outcomeOp from transmission.
var errNotEnabled = errors.New("op not enabled in the configuration")
// An outcomeOp inflicts an outcome on [system.I] and contains enough information to // An outcomeOp inflicts an outcome on [system.I] and contains enough information to
// inflict it on [container.Params] in a separate process. // inflict it on [container.Params] in a separate process.
// An implementation of outcomeOp must store cross-process states in exported fields only. // An implementation of outcomeOp must store cross-process states in exported fields only.
@ -216,11 +220,15 @@ type outcomeOp interface {
toContainer(state *outcomeStateParams) error toContainer(state *outcomeStateParams) error
} }
// fromConfig returns a corresponding slice of outcomeOp for [hst.Config]. // toSystem calls the outcomeOp.toSystem method on all outcomeOp implementations and populates shimParams.Ops.
// This function assumes the caller has already called the Validate method on [hst.Config] // This function assumes the caller has already called the Validate method on [hst.Config]
// and checked that it returns nil. // and checked that it returns nil.
func fromConfig(config *hst.Config) (ops []outcomeOp) { func (state *outcomeStateSys) toSystem() error {
ops = []outcomeOp{ if state.Shim == nil || state.Shim.Ops != nil {
return newWithMessage("invalid ops state reached")
}
ops := [...]outcomeOp{
// must run first // must run first
&spParamsOp{}, &spParamsOp{},
@ -230,22 +238,27 @@ func fromConfig(config *hst.Config) (ops []outcomeOp) {
spRuntimeOp{}, spRuntimeOp{},
spTmpdirOp{}, spTmpdirOp{},
spAccountOp{}, spAccountOp{},
// optional via enablements
&spWaylandOp{},
&spX11Op{},
&spPulseOp{},
&spDBusOp{},
spFinal{},
} }
et := config.Enablements.Unwrap() state.Shim.Ops = make([]outcomeOp, 0, len(ops))
if et&hst.EWayland != 0 { for _, op := range ops {
ops = append(ops, &spWaylandOp{}) if err := op.toSystem(state); err != nil {
} // this error is used internally to exclude this outcomeOp from transmission
if et&hst.EX11 != 0 { if errors.Is(err, errNotEnabled) {
ops = append(ops, &spX11Op{}) continue
}
if et&hst.EPulse != 0 {
ops = append(ops, &spPulseOp{})
}
if et&hst.EDBus != 0 {
ops = append(ops, &spDBusOp{})
} }
ops = append(ops, spFinal{}) return err
return }
state.Shim.Ops = append(state.Shim.Ops, op)
}
return nil
} }

View File

@ -1,7 +1,6 @@
package app package app
import ( import (
"reflect"
"testing" "testing"
"hakurei.app/hst" "hakurei.app/hst"
@ -30,49 +29,3 @@ func TestOutcomeStateValid(t *testing.T) {
}) })
} }
} }
func TestFromConfig(t *testing.T) {
testCases := []struct {
name string
config *hst.Config
want []outcomeOp
}{
{"ne", new(hst.Config), []outcomeOp{
&spParamsOp{},
spFilesystemOp{},
spRuntimeOp{},
spTmpdirOp{},
spAccountOp{},
spFinal{},
}},
{"wayland pulse", &hst.Config{Enablements: hst.NewEnablements(hst.EWayland | hst.EPulse)}, []outcomeOp{
&spParamsOp{},
spFilesystemOp{},
spRuntimeOp{},
spTmpdirOp{},
spAccountOp{},
&spWaylandOp{},
&spPulseOp{},
spFinal{},
}},
{"all", &hst.Config{Enablements: hst.NewEnablements(0xff)}, []outcomeOp{
&spParamsOp{},
spFilesystemOp{},
spRuntimeOp{},
spTmpdirOp{},
spAccountOp{},
&spWaylandOp{},
&spX11Op{},
&spPulseOp{},
&spDBusOp{},
spFinal{},
}},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if got := fromConfig(tc.config); !reflect.DeepEqual(got, tc.want) {
t.Errorf("fromConfig: %#v, want %#v", got, tc.want)
}
})
}
}

View File

@ -48,11 +48,7 @@ type shimParams struct {
} }
// valid checks shimParams to be safe for use. // valid checks shimParams to be safe for use.
func (p *shimParams) valid() bool { func (p *shimParams) valid() bool { return p != nil && p.PrivPID > 0 }
return p != nil &&
p.Ops != nil &&
p.PrivPID > 0
}
// ShimMain is the main function of the shim process and runs as the unconstrained target user. // ShimMain is the main function of the shim process and runs as the unconstrained target user.
func ShimMain() { func ShimMain() {

View File

@ -4,6 +4,7 @@ import (
"encoding/gob" "encoding/gob"
"hakurei.app/container/fhs" "hakurei.app/container/fhs"
"hakurei.app/hst"
"hakurei.app/system/acl" "hakurei.app/system/acl"
"hakurei.app/system/dbus" "hakurei.app/system/dbus"
) )
@ -18,6 +19,10 @@ type spDBusOp struct {
} }
func (s *spDBusOp) toSystem(state *outcomeStateSys) error { func (s *spDBusOp) toSystem(state *outcomeStateSys) error {
if state.config.Enablements.Unwrap()&hst.EDBus == 0 {
return errNotEnabled
}
if state.config.SessionBus == nil { if state.config.SessionBus == nil {
state.config.SessionBus = dbus.NewConfig(state.config.ID, true, true) state.config.SessionBus = dbus.NewConfig(state.config.ID, true, true)
} }

View File

@ -24,6 +24,10 @@ type spPulseOp struct {
} }
func (s *spPulseOp) toSystem(state *outcomeStateSys) error { func (s *spPulseOp) toSystem(state *outcomeStateSys) error {
if state.config.Enablements.Unwrap()&hst.EPulse == 0 {
return errNotEnabled
}
pulseRuntimeDir, pulseSocket := s.commonPaths(state.outcomeState) pulseRuntimeDir, pulseSocket := s.commonPaths(state.outcomeState)
if _, err := state.k.stat(pulseRuntimeDir.String()); err != nil { if _, err := state.k.stat(pulseRuntimeDir.String()); err != nil {

View File

@ -18,6 +18,10 @@ type spWaylandOp struct {
} }
func (s *spWaylandOp) toSystem(state *outcomeStateSys) error { func (s *spWaylandOp) toSystem(state *outcomeStateSys) error {
if state.config.Enablements.Unwrap()&hst.EWayland == 0 {
return errNotEnabled
}
// outer wayland socket (usually `/run/user/%d/wayland-%d`) // outer wayland socket (usually `/run/user/%d/wayland-%d`)
var socketPath *check.Absolute var socketPath *check.Absolute
if name, ok := state.k.lookupEnv(wayland.WaylandDisplay); !ok { if name, ok := state.k.lookupEnv(wayland.WaylandDisplay); !ok {

View File

@ -25,6 +25,10 @@ type spX11Op struct {
} }
func (s *spX11Op) toSystem(state *outcomeStateSys) error { func (s *spX11Op) toSystem(state *outcomeStateSys) error {
if state.config.Enablements.Unwrap()&hst.EX11 == 0 {
return errNotEnabled
}
if d, ok := state.k.lookupEnv("DISPLAY"); !ok { if d, ok := state.k.lookupEnv("DISPLAY"); !ok {
return newWithMessage("DISPLAY is not set") return newWithMessage("DISPLAY is not set")
} else { } else {