diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 96b3670..5edc9b8 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -450,15 +450,7 @@ func TestApp(t *testing.T) { var gotSys *system.I { - sPriv := outcomeState{ - ID: &tc.id, - Identity: tc.config.Identity, - UserID: (&Hsu{k: tc.k}).MustIDMsg(msg), - EnvPaths: copyPaths(tc.k), - Container: tc.config.Container, - } - - sPriv.populateEarly(tc.k, msg) + sPriv := newOutcomeState(tc.k, msg, &tc.id, tc.config, &Hsu{k: tc.k}) if err := sPriv.populateLocal(tc.k, msg); err != nil { t.Fatalf("populateLocal: error = %#v", err) } @@ -574,6 +566,7 @@ type stubNixOS struct { func (k *stubNixOS) new(func(k syscallDispatcher)) { panic("not implemented") } +func (k *stubNixOS) getpid() int { return 0xdeadbeef } func (k *stubNixOS) getuid() int { return 1971 } func (k *stubNixOS) getgid() int { return 100 } diff --git a/internal/app/dispatcher.go b/internal/app/dispatcher.go index 6e00a9b..d56278a 100644 --- a/internal/app/dispatcher.go +++ b/internal/app/dispatcher.go @@ -29,6 +29,8 @@ type syscallDispatcher interface { // just synchronising access is not enough, as this is for test instrumentation. new(f func(k syscallDispatcher)) + // getpid provides [os.Getpid]. + getpid() int // getuid provides [os.Getuid]. getuid() int // getgid provides [os.Getgid]. @@ -70,6 +72,7 @@ type direct struct{} func (k direct) new(f func(k syscallDispatcher)) { go f(k) } +func (direct) getpid() int { return os.Getpid() } func (direct) getuid() int { return os.Getuid() } func (direct) getgid() int { return os.Getgid() } func (direct) lookupEnv(key string) (string, bool) { return os.LookupEnv(key) } diff --git a/internal/app/dispatcher_test.go b/internal/app/dispatcher_test.go index f6d10ba..23a5e52 100644 --- a/internal/app/dispatcher_test.go +++ b/internal/app/dispatcher_test.go @@ -11,6 +11,7 @@ import ( type panicDispatcher struct{} func (panicDispatcher) new(func(k syscallDispatcher)) { panic("unreachable") } +func (panicDispatcher) getpid() int { panic("unreachable") } func (panicDispatcher) getuid() int { panic("unreachable") } func (panicDispatcher) getgid() int { panic("unreachable") } func (panicDispatcher) lookupEnv(string) (string, bool) { panic("unreachable") } diff --git a/internal/app/finalise.go b/internal/app/finalise.go index 0adfcb4..52a38fa 100644 --- a/internal/app/finalise.go +++ b/internal/app/finalise.go @@ -68,14 +68,7 @@ func (k *outcome) finalise(ctx context.Context, msg message.Msg, id *state.ID, c } // early validation complete at this point - s := outcomeState{ - ID: id, - Identity: config.Identity, - UserID: (&Hsu{k: k}).MustIDMsg(msg), - EnvPaths: copyPaths(k.syscallDispatcher), - Container: config.Container, - } - s.populateEarly(k.syscallDispatcher, msg) + s := newOutcomeState(k.syscallDispatcher, msg, id, config, &Hsu{k: k}) if err := s.populateLocal(k.syscallDispatcher, msg); err != nil { return err } @@ -87,7 +80,7 @@ func (k *outcome) finalise(ctx context.Context, msg message.Msg, id *state.ID, c k.sys = sys k.supp = supp - k.state = &s + k.state = s k.config = config return nil } diff --git a/internal/app/outcome.go b/internal/app/outcome.go index 4d2cd71..b79b075 100644 --- a/internal/app/outcome.go +++ b/internal/app/outcome.go @@ -2,7 +2,6 @@ package app import ( "errors" - "os" "strconv" "hakurei.app/container" @@ -72,10 +71,16 @@ func (s *outcomeState) valid() bool { s.EnvPaths != nil } -// populateEarly populates exported fields via syscallDispatcher. -// This must only be called from the priv side. -func (s *outcomeState) populateEarly(k syscallDispatcher, msg message.Msg) { - s.Shim = &shimParams{PrivPID: os.Getpid(), Verbose: msg.IsVerbose()} +// newOutcomeState returns the address of a new outcomeState with its exported fields populated via syscallDispatcher. +func newOutcomeState(k syscallDispatcher, msg message.Msg, id *state.ID, config *hst.Config, hsu *Hsu) *outcomeState { + s := outcomeState{ + ID: id, + Identity: config.Identity, + UserID: hsu.MustIDMsg(msg), + EnvPaths: copyPaths(k), + Container: config.Container, + } + s.Shim = &shimParams{PrivPID: k.getpid(), Verbose: msg.IsVerbose()} // enforce bounds and default early if s.Container.WaitDelay < 0 { @@ -94,7 +99,7 @@ func (s *outcomeState) populateEarly(k syscallDispatcher, msg message.Msg) { s.Mapuid, s.Mapgid = k.overflowUid(msg), k.overflowGid(msg) } - return + return &s } // populateLocal populates unexported fields from transmitted exported fields.