internal/app/shim: use syscall dispatcher
	
		
			
	
		
	
	
		
	
		
			All checks were successful
		
		
	
	
		
			
				
	
				Test / Create distribution (push) Successful in 33s
				
			
		
			
				
	
				Test / Sandbox (push) Successful in 2m14s
				
			
		
			
				
	
				Test / Hakurei (push) Successful in 3m9s
				
			
		
			
				
	
				Test / Sandbox (race detector) (push) Successful in 3m58s
				
			
		
			
				
	
				Test / Hpkg (push) Successful in 4m5s
				
			
		
			
				
	
				Test / Hakurei (race detector) (push) Successful in 4m46s
				
			
		
			
				
	
				Test / Flake checks (push) Successful in 1m28s
				
			
		
		
	
	
				
					
				
			
		
			All checks were successful
		
		
	
	Test / Create distribution (push) Successful in 33s
				
			Test / Sandbox (push) Successful in 2m14s
				
			Test / Hakurei (push) Successful in 3m9s
				
			Test / Sandbox (race detector) (push) Successful in 3m58s
				
			Test / Hpkg (push) Successful in 4m5s
				
			Test / Hakurei (race detector) (push) Successful in 4m46s
				
			Test / Flake checks (push) Successful in 1m28s
				
			This enables instrumented testing of the shim. Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
		
							parent
							
								
									c5f59c5488
								
							
						
					
					
						commit
						8accd3b219
					
				| @ -50,7 +50,7 @@ func buildCommand(ctx context.Context, msg message.Msg, early *earlyHardeningErr | |||||||
| 		Flag(&flagVerbose, "v", command.BoolFlag(false), "Increase log verbosity"). | 		Flag(&flagVerbose, "v", command.BoolFlag(false), "Increase log verbosity"). | ||||||
| 		Flag(&flagJSON, "json", command.BoolFlag(false), "Serialise output in JSON when applicable") | 		Flag(&flagJSON, "json", command.BoolFlag(false), "Serialise output in JSON when applicable") | ||||||
| 
 | 
 | ||||||
| 	c.Command("shim", command.UsageInternal, func([]string) error { app.ShimMain(); return errSuccess }) | 	c.Command("shim", command.UsageInternal, func([]string) error { app.Shim(msg); return errSuccess }) | ||||||
| 
 | 
 | ||||||
| 	c.Command("app", "Load and start container from configuration file", func(args []string) error { | 	c.Command("app", "Load and start container from configuration file", func(args []string) error { | ||||||
| 		if len(args) < 1 { | 		if len(args) < 1 { | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ func Main(ctx context.Context, msg message.Msg, config *hst.Config) { | |||||||
| 		log.Fatal(err) | 		log.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	seal := outcome{syscallDispatcher: direct{}} | 	seal := outcome{syscallDispatcher: direct{msg}} | ||||||
| 	if err := seal.finalise(ctx, msg, &id, config); err != nil { | 	if err := seal.finalise(ctx, msg, &id, config); err != nil { | ||||||
| 		printMessageError("cannot seal app:", err) | 		printMessageError("cannot seal app:", err) | ||||||
| 		os.Exit(1) | 		os.Exit(1) | ||||||
|  | |||||||
| @ -574,10 +574,9 @@ func (s stubOsFileReadCloser) Stat() (fs.FileInfo, error) { panic("attempting to | |||||||
| 
 | 
 | ||||||
| type stubNixOS struct { | type stubNixOS struct { | ||||||
| 	usernameErr map[string]error | 	usernameErr map[string]error | ||||||
|  | 	panicDispatcher | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (k *stubNixOS) new(func(k syscallDispatcher)) { panic("not implemented") } |  | ||||||
| 
 |  | ||||||
| func (k *stubNixOS) getpid() int { return 0xdeadbeef } | func (k *stubNixOS) getpid() int { return 0xdeadbeef } | ||||||
| func (k *stubNixOS) getuid() int { return 1971 } | func (k *stubNixOS) getuid() int { return 1971 } | ||||||
| func (k *stubNixOS) getgid() int { return 100 } | func (k *stubNixOS) getgid() int { return 100 } | ||||||
|  | |||||||
| @ -1,16 +1,18 @@ | |||||||
| package app | package app | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"io" | 	"io" | ||||||
| 	"io/fs" | 	"io/fs" | ||||||
| 	"log" |  | ||||||
| 	"os" | 	"os" | ||||||
| 	"os/exec" | 	"os/exec" | ||||||
|  | 	"os/signal" | ||||||
| 	"os/user" | 	"os/user" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
| 
 | 
 | ||||||
| 	"hakurei.app/container" | 	"hakurei.app/container" | ||||||
| 	"hakurei.app/container/check" | 	"hakurei.app/container/check" | ||||||
|  | 	"hakurei.app/container/seccomp" | ||||||
| 	"hakurei.app/internal" | 	"hakurei.app/internal" | ||||||
| 	"hakurei.app/message" | 	"hakurei.app/message" | ||||||
| 	"hakurei.app/system/dbus" | 	"hakurei.app/system/dbus" | ||||||
| @ -28,7 +30,7 @@ type syscallDispatcher interface { | |||||||
| 	// new starts a goroutine with a new instance of syscallDispatcher. | 	// new starts a goroutine with a new instance of syscallDispatcher. | ||||||
| 	// A syscallDispatcher must never be used in any goroutine other than the one owning it, | 	// A syscallDispatcher must never be used in any goroutine other than the one owning it, | ||||||
| 	// just synchronising access is not enough, as this is for test instrumentation. | 	// just synchronising access is not enough, as this is for test instrumentation. | ||||||
| 	new(f func(k syscallDispatcher)) | 	new(f func(k syscallDispatcher, msg message.Msg)) | ||||||
| 
 | 
 | ||||||
| 	// getpid provides [os.Getpid]. | 	// getpid provides [os.Getpid]. | ||||||
| 	getpid() int | 	getpid() int | ||||||
| @ -38,6 +40,8 @@ type syscallDispatcher interface { | |||||||
| 	getgid() int | 	getgid() int | ||||||
| 	// lookupEnv provides [os.LookupEnv]. | 	// lookupEnv provides [os.LookupEnv]. | ||||||
| 	lookupEnv(key string) (string, bool) | 	lookupEnv(key string) (string, bool) | ||||||
|  | 	// pipe provides os.Pipe. | ||||||
|  | 	pipe() (r, w *os.File, err error) | ||||||
| 	// stat provides [os.Stat]. | 	// stat provides [os.Stat]. | ||||||
| 	stat(name string) (os.FileInfo, error) | 	stat(name string) (os.FileInfo, error) | ||||||
| 	// open provides [os.Open]. | 	// open provides [os.Open]. | ||||||
| @ -46,6 +50,8 @@ type syscallDispatcher interface { | |||||||
| 	readdir(name string) ([]os.DirEntry, error) | 	readdir(name string) ([]os.DirEntry, error) | ||||||
| 	// tempdir provides [os.TempDir]. | 	// tempdir provides [os.TempDir]. | ||||||
| 	tempdir() string | 	tempdir() string | ||||||
|  | 	// exit provides [os.Exit]. | ||||||
|  | 	exit(code int) | ||||||
| 
 | 
 | ||||||
| 	// evalSymlinks provides [filepath.EvalSymlinks]. | 	// evalSymlinks provides [filepath.EvalSymlinks]. | ||||||
| 	evalSymlinks(path string) (string, error) | 	evalSymlinks(path string) (string, error) | ||||||
| @ -56,10 +62,29 @@ type syscallDispatcher interface { | |||||||
| 	// cmdOutput provides the Output method of [exec.Cmd]. | 	// cmdOutput provides the Output method of [exec.Cmd]. | ||||||
| 	cmdOutput(cmd *exec.Cmd) ([]byte, error) | 	cmdOutput(cmd *exec.Cmd) ([]byte, error) | ||||||
| 
 | 
 | ||||||
|  | 	// notifyContext provides [signal.NotifyContext]. | ||||||
|  | 	notifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) | ||||||
|  | 
 | ||||||
|  | 	// prctl provides [container.Prctl]. | ||||||
|  | 	prctl(op, arg2, arg3 uintptr) error | ||||||
| 	// overflowUid provides [container.OverflowUid]. | 	// overflowUid provides [container.OverflowUid]. | ||||||
| 	overflowUid(msg message.Msg) int | 	overflowUid(msg message.Msg) int | ||||||
| 	// overflowGid provides [container.OverflowGid]. | 	// overflowGid provides [container.OverflowGid]. | ||||||
| 	overflowGid(msg message.Msg) int | 	overflowGid(msg message.Msg) int | ||||||
|  | 	// setDumpable provides [container.SetDumpable]. | ||||||
|  | 	setDumpable(dumpable uintptr) error | ||||||
|  | 	// receive provides [container.Receive]. | ||||||
|  | 	receive(key string, e any, fdp *uintptr) (closeFunc func() error, err error) | ||||||
|  | 
 | ||||||
|  | 	// containerStart provides the Start method of [container.Container]. | ||||||
|  | 	containerStart(z *container.Container) error | ||||||
|  | 	// containerStart provides the Serve method of [container.Container]. | ||||||
|  | 	containerServe(z *container.Container) error | ||||||
|  | 	// containerStart provides the Wait method of [container.Container]. | ||||||
|  | 	containerWait(z *container.Container) error | ||||||
|  | 
 | ||||||
|  | 	// seccompLoad provides [seccomp.Load]. | ||||||
|  | 	seccompLoad(rules []seccomp.NativeRule, flags seccomp.ExportFlag) error | ||||||
| 
 | 
 | ||||||
| 	// mustHsuPath provides [internal.MustHsuPath]. | 	// mustHsuPath provides [internal.MustHsuPath]. | ||||||
| 	mustHsuPath() *check.Absolute | 	mustHsuPath() *check.Absolute | ||||||
| @ -67,23 +92,32 @@ type syscallDispatcher interface { | |||||||
| 	// dbusAddress provides [dbus.Address]. | 	// dbusAddress provides [dbus.Address]. | ||||||
| 	dbusAddress() (session, system string) | 	dbusAddress() (session, system string) | ||||||
| 
 | 
 | ||||||
|  | 	// setupContSignal provides setupContSignal. | ||||||
|  | 	setupContSignal(pid int) (io.ReadCloser, func(), error) | ||||||
|  | 
 | ||||||
|  | 	// getMsg returns the [message.Msg] held by syscallDispatcher. | ||||||
|  | 	getMsg() message.Msg | ||||||
|  | 	// fatal provides [log.Fatal]. | ||||||
|  | 	fatal(v ...any) | ||||||
| 	// fatalf provides [log.Fatalf]. | 	// fatalf provides [log.Fatalf]. | ||||||
| 	fatalf(format string, v ...any) | 	fatalf(format string, v ...any) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // direct implements syscallDispatcher on the current kernel. | // direct implements syscallDispatcher on the current kernel. | ||||||
| type direct struct{} | type direct struct{ msg message.Msg } | ||||||
| 
 | 
 | ||||||
| func (k direct) new(f func(k syscallDispatcher)) { go f(k) } | func (k direct) new(f func(k syscallDispatcher, msg message.Msg)) { go f(k, k.msg) } | ||||||
| 
 | 
 | ||||||
| func (direct) getpid() int                                { return os.Getpid() } | func (direct) getpid() int                                { return os.Getpid() } | ||||||
| func (direct) getuid() int                                { return os.Getuid() } | func (direct) getuid() int                                { return os.Getuid() } | ||||||
| func (direct) getgid() int                                { return os.Getgid() } | func (direct) getgid() int                                { return os.Getgid() } | ||||||
| func (direct) lookupEnv(key string) (string, bool)        { return os.LookupEnv(key) } | func (direct) lookupEnv(key string) (string, bool)        { return os.LookupEnv(key) } | ||||||
|  | func (direct) pipe() (r, w *os.File, err error)           { return os.Pipe() } | ||||||
| func (direct) stat(name string) (os.FileInfo, error)      { return os.Stat(name) } | func (direct) stat(name string) (os.FileInfo, error)      { return os.Stat(name) } | ||||||
| func (direct) open(name string) (osFile, error)           { return os.Open(name) } | func (direct) open(name string) (osFile, error)           { return os.Open(name) } | ||||||
| func (direct) readdir(name string) ([]os.DirEntry, error) { return os.ReadDir(name) } | func (direct) readdir(name string) ([]os.DirEntry, error) { return os.ReadDir(name) } | ||||||
| func (direct) tempdir() string                            { return os.TempDir() } | func (direct) tempdir() string                            { return os.TempDir() } | ||||||
|  | func (direct) exit(code int)                              { os.Exit(code) } | ||||||
| 
 | 
 | ||||||
| func (direct) evalSymlinks(path string) (string, error) { return filepath.EvalSymlinks(path) } | func (direct) evalSymlinks(path string) (string, error) { return filepath.EvalSymlinks(path) } | ||||||
| 
 | 
 | ||||||
| @ -98,11 +132,32 @@ func (direct) lookupGroupId(name string) (gid string, err error) { | |||||||
| 
 | 
 | ||||||
| func (direct) cmdOutput(cmd *exec.Cmd) ([]byte, error) { return cmd.Output() } | func (direct) cmdOutput(cmd *exec.Cmd) ([]byte, error) { return cmd.Output() } | ||||||
| 
 | 
 | ||||||
| func (direct) overflowUid(msg message.Msg) int { return container.OverflowUid(msg) } | func (direct) notifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) { | ||||||
| func (direct) overflowGid(msg message.Msg) int { return container.OverflowGid(msg) } | 	return signal.NotifyContext(parent, signals...) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (direct) prctl(op, arg2, arg3 uintptr) error { return container.Prctl(op, arg2, arg3) } | ||||||
|  | func (direct) overflowUid(msg message.Msg) int    { return container.OverflowUid(msg) } | ||||||
|  | func (direct) overflowGid(msg message.Msg) int    { return container.OverflowGid(msg) } | ||||||
|  | func (direct) setDumpable(dumpable uintptr) error { return container.SetDumpable(dumpable) } | ||||||
|  | func (direct) receive(key string, e any, fdp *uintptr) (func() error, error) { | ||||||
|  | 	return container.Receive(key, e, fdp) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (direct) containerStart(z *container.Container) error { return z.Start() } | ||||||
|  | func (direct) containerServe(z *container.Container) error { return z.Serve() } | ||||||
|  | func (direct) containerWait(z *container.Container) error  { return z.Wait() } | ||||||
|  | 
 | ||||||
|  | func (direct) seccompLoad(rules []seccomp.NativeRule, flags seccomp.ExportFlag) error { | ||||||
|  | 	return seccomp.Load(rules, flags) | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| func (direct) mustHsuPath() *check.Absolute { return internal.MustHsuPath() } | func (direct) mustHsuPath() *check.Absolute { return internal.MustHsuPath() } | ||||||
| 
 | 
 | ||||||
| func (k direct) dbusAddress() (session, system string) { return dbus.Address() } | func (direct) dbusAddress() (session, system string) { return dbus.Address() } | ||||||
| 
 | 
 | ||||||
| func (direct) fatalf(format string, v ...any) { log.Fatalf(format, v...) } | func (direct) setupContSignal(pid int) (io.ReadCloser, func(), error) { return setupContSignal(pid) } | ||||||
|  | 
 | ||||||
|  | func (k direct) getMsg() message.Msg            { return k.msg } | ||||||
|  | func (k direct) fatal(v ...any)                 { k.msg.GetLogger().Fatal(v...) } | ||||||
|  | func (k direct) fatalf(format string, v ...any) { k.msg.GetLogger().Fatalf(format, v...) } | ||||||
|  | |||||||
| @ -2,6 +2,7 @@ package app | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"context" | ||||||
| 	"io" | 	"io" | ||||||
| 	"io/fs" | 	"io/fs" | ||||||
| 	"log" | 	"log" | ||||||
| @ -15,6 +16,7 @@ import ( | |||||||
| 
 | 
 | ||||||
| 	"hakurei.app/container" | 	"hakurei.app/container" | ||||||
| 	"hakurei.app/container/check" | 	"hakurei.app/container/check" | ||||||
|  | 	"hakurei.app/container/seccomp" | ||||||
| 	"hakurei.app/container/stub" | 	"hakurei.app/container/stub" | ||||||
| 	"hakurei.app/hst" | 	"hakurei.app/hst" | ||||||
| 	"hakurei.app/internal/app/state" | 	"hakurei.app/internal/app/state" | ||||||
| @ -492,20 +494,38 @@ func (panicMsgContext) Value(any) any               { panic("unreachable") } | |||||||
| // This type is meant to be embedded in partial syscallDispatcher implementations. | // This type is meant to be embedded in partial syscallDispatcher implementations. | ||||||
| type panicDispatcher struct{} | type panicDispatcher struct{} | ||||||
| 
 | 
 | ||||||
| func (panicDispatcher) new(func(k syscallDispatcher))         { panic("unreachable") } | func (panicDispatcher) new(func(k syscallDispatcher, msg message.Msg))      { panic("unreachable") } | ||||||
| func (panicDispatcher) getpid() int                           { panic("unreachable") } | func (panicDispatcher) getpid() int                                         { panic("unreachable") } | ||||||
| func (panicDispatcher) getuid() int                           { panic("unreachable") } | func (panicDispatcher) getuid() int                                         { panic("unreachable") } | ||||||
| func (panicDispatcher) getgid() int                           { panic("unreachable") } | func (panicDispatcher) getgid() int                                         { panic("unreachable") } | ||||||
| func (panicDispatcher) lookupEnv(string) (string, bool)       { panic("unreachable") } | func (panicDispatcher) lookupEnv(string) (string, bool)                     { panic("unreachable") } | ||||||
| func (panicDispatcher) stat(string) (os.FileInfo, error)      { panic("unreachable") } | func (panicDispatcher) pipe() (*os.File, *os.File, error)                   { panic("unreachable") } | ||||||
| func (panicDispatcher) open(string) (osFile, error)           { panic("unreachable") } | func (panicDispatcher) stat(string) (os.FileInfo, error)                    { panic("unreachable") } | ||||||
| func (panicDispatcher) readdir(string) ([]os.DirEntry, error) { panic("unreachable") } | func (panicDispatcher) open(string) (osFile, error)                         { panic("unreachable") } | ||||||
| func (panicDispatcher) tempdir() string                       { panic("unreachable") } | func (panicDispatcher) readdir(string) ([]os.DirEntry, error)               { panic("unreachable") } | ||||||
| func (panicDispatcher) evalSymlinks(string) (string, error)   { panic("unreachable") } | func (panicDispatcher) tempdir() string                                     { panic("unreachable") } | ||||||
| func (panicDispatcher) lookupGroupId(string) (string, error)  { panic("unreachable") } | func (panicDispatcher) exit(int)                                            { panic("unreachable") } | ||||||
| func (panicDispatcher) cmdOutput(*exec.Cmd) ([]byte, error)   { panic("unreachable") } | func (panicDispatcher) evalSymlinks(string) (string, error)                 { panic("unreachable") } | ||||||
| func (panicDispatcher) overflowUid(message.Msg) int           { panic("unreachable") } | func (panicDispatcher) prctl(uintptr, uintptr, uintptr) error               { panic("unreachable") } | ||||||
| func (panicDispatcher) overflowGid(message.Msg) int           { panic("unreachable") } | func (panicDispatcher) lookupGroupId(string) (string, error)                { panic("unreachable") } | ||||||
| func (panicDispatcher) mustHsuPath() *check.Absolute          { panic("unreachable") } | func (panicDispatcher) cmdOutput(*exec.Cmd) ([]byte, error)                 { panic("unreachable") } | ||||||
| func (panicDispatcher) dbusAddress() (string, string)         { panic("unreachable") } | func (panicDispatcher) overflowUid(message.Msg) int                         { panic("unreachable") } | ||||||
| func (panicDispatcher) fatalf(string, ...any)                 { panic("unreachable") } | func (panicDispatcher) overflowGid(message.Msg) int                         { panic("unreachable") } | ||||||
|  | func (panicDispatcher) setDumpable(uintptr) error                           { panic("unreachable") } | ||||||
|  | func (panicDispatcher) receive(string, any, *uintptr) (func() error, error) { panic("unreachable") } | ||||||
|  | func (panicDispatcher) containerStart(*container.Container) error           { panic("unreachable") } | ||||||
|  | func (panicDispatcher) containerServe(*container.Container) error           { panic("unreachable") } | ||||||
|  | func (panicDispatcher) containerWait(*container.Container) error            { panic("unreachable") } | ||||||
|  | func (panicDispatcher) mustHsuPath() *check.Absolute                        { panic("unreachable") } | ||||||
|  | func (panicDispatcher) dbusAddress() (string, string)                       { panic("unreachable") } | ||||||
|  | func (panicDispatcher) setupContSignal(int) (io.ReadCloser, func(), error)  { panic("unreachable") } | ||||||
|  | func (panicDispatcher) getMsg() message.Msg                                 { panic("unreachable") } | ||||||
|  | func (panicDispatcher) fatal(...any)                                        { panic("unreachable") } | ||||||
|  | func (panicDispatcher) fatalf(string, ...any)                               { panic("unreachable") } | ||||||
|  | 
 | ||||||
|  | func (panicDispatcher) notifyContext(context.Context, ...os.Signal) (context.Context, context.CancelFunc) { | ||||||
|  | 	panic("unreachable") | ||||||
|  | } | ||||||
|  | func (panicDispatcher) seccompLoad([]seccomp.NativeRule, seccomp.ExportFlag) error { | ||||||
|  | 	panic("unreachable") | ||||||
|  | } | ||||||
|  | |||||||
| @ -21,7 +21,9 @@ type Hsu struct { | |||||||
| 	id     int | 	id     int | ||||||
| 
 | 
 | ||||||
| 	kOnce sync.Once | 	kOnce sync.Once | ||||||
| 	k     syscallDispatcher | 
 | ||||||
|  | 	// msg is not populated | ||||||
|  | 	k syscallDispatcher | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| var ErrHsuAccess = errors.New("current user is not in the hsurc file") | var ErrHsuAccess = errors.New("current user is not in the hsurc file") | ||||||
|  | |||||||
| @ -7,7 +7,6 @@ import ( | |||||||
| 	"log" | 	"log" | ||||||
| 	"os" | 	"os" | ||||||
| 	"os/exec" | 	"os/exec" | ||||||
| 	"os/signal" |  | ||||||
| 	"runtime" | 	"runtime" | ||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
| 	"syscall" | 	"syscall" | ||||||
| @ -23,6 +22,20 @@ import ( | |||||||
| //#include "shim-signal.h" | //#include "shim-signal.h" | ||||||
| import "C" | import "C" | ||||||
| 
 | 
 | ||||||
|  | // setupContSignal sets up the SIGCONT signal handler for the cross-uid shim exit hack. | ||||||
|  | // The signal handler is implemented in C, signals can be processed by reading from the returned reader. | ||||||
|  | // The returned function must be called after all signal processing concludes. | ||||||
|  | func setupContSignal(pid int) (io.ReadCloser, func(), error) { | ||||||
|  | 	if r, w, err := os.Pipe(); err != nil { | ||||||
|  | 		return nil, nil, err | ||||||
|  | 	} else if _, err = C.hakurei_shim_setup_cont_signal(C.pid_t(pid), C.int(w.Fd())); err != nil { | ||||||
|  | 		_, _ = r.Close(), w.Close() | ||||||
|  | 		return nil, nil, err | ||||||
|  | 	} else { | ||||||
|  | 		return r, func() { runtime.KeepAlive(w) }, nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // shimEnv is the name of the environment variable storing decimal representation of | // shimEnv is the name of the environment variable storing decimal representation of | ||||||
| // setup pipe fd for [container.Receive]. | // setup pipe fd for [container.Receive]. | ||||||
| const shimEnv = "HAKUREI_SHIM" | const shimEnv = "HAKUREI_SHIM" | ||||||
| @ -46,76 +59,102 @@ type shimParams struct { | |||||||
| // valid checks shimParams to be safe for use. | // valid checks shimParams to be safe for use. | ||||||
| func (p *shimParams) valid() bool { return p != nil && p.PrivPID > 0 } | func (p *shimParams) valid() bool { return p != nil && p.PrivPID > 0 } | ||||||
| 
 | 
 | ||||||
| // ShimMain is the main function of the shim process and runs as the unconstrained target user. | // shimName is the prefix used by log.std in the shim process. | ||||||
| func ShimMain() { | const shimName = "shim" | ||||||
| 	log.SetPrefix("shim: ") |  | ||||||
| 	log.SetFlags(0) |  | ||||||
| 	msg := message.NewMsg(log.Default()) |  | ||||||
| 
 | 
 | ||||||
| 	if err := container.SetDumpable(container.SUID_DUMP_DISABLE); err != nil { | // Shim is called by the main function of the shim process and runs as the unconstrained target user. | ||||||
| 		log.Fatalf("cannot set SUID_DUMP_DISABLE: %s", err) | // Shim does not return. | ||||||
|  | func Shim(msg message.Msg) { | ||||||
|  | 	if msg == nil { | ||||||
|  | 		msg = message.NewMsg(log.Default()) | ||||||
|  | 	} | ||||||
|  | 	shimEntrypoint(direct{msg}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func shimEntrypoint(k syscallDispatcher) { | ||||||
|  | 	msg := k.getMsg() | ||||||
|  | 	if msg == nil { | ||||||
|  | 		panic("attempting to call shimEntrypoint with nil msg") | ||||||
|  | 	} else if logger := msg.GetLogger(); logger != nil { | ||||||
|  | 		logger.SetPrefix(shimName + ": ") | ||||||
|  | 		logger.SetFlags(0) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := k.setDumpable(container.SUID_DUMP_DISABLE); err != nil { | ||||||
|  | 		k.fatalf("cannot set SUID_DUMP_DISABLE: %s", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var ( | 	var ( | ||||||
| 		state      outcomeState | 		state      outcomeState | ||||||
| 		closeSetup func() error | 		closeSetup func() error | ||||||
| 	) | 	) | ||||||
| 	if f, err := container.Receive(shimEnv, &state, nil); err != nil { | 	if f, err := k.receive(shimEnv, &state, nil); err != nil { | ||||||
| 		if errors.Is(err, syscall.EBADF) { | 		if errors.Is(err, syscall.EBADF) { | ||||||
| 			log.Fatal("invalid config descriptor") | 			k.fatal("invalid config descriptor") | ||||||
| 		} | 		} | ||||||
| 		if errors.Is(err, container.ErrReceiveEnv) { | 		if errors.Is(err, container.ErrReceiveEnv) { | ||||||
| 			log.Fatal(shimEnv + " not set") | 			k.fatal(shimEnv + " not set") | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		log.Fatalf("cannot receive shim setup params: %v", err) | 		k.fatalf("cannot receive shim setup params: %v", err) | ||||||
| 	} else { | 	} else { | ||||||
| 		msg.SwapVerbose(state.Shim.Verbose) | 		msg.SwapVerbose(state.Shim.Verbose) | ||||||
| 		closeSetup = f | 		closeSetup = f | ||||||
| 
 | 
 | ||||||
| 		if err = state.populateLocal(direct{}, msg); err != nil { | 		if err = state.populateLocal(k, msg); err != nil { | ||||||
| 			if m, ok := message.GetMessage(err); ok { | 			if m, ok := message.GetMessage(err); ok { | ||||||
| 				log.Fatal(m) | 				k.fatal(m) | ||||||
| 			} else { | 			} else { | ||||||
| 				log.Fatalf("cannot populate local state: %v", err) | 				k.fatalf("cannot populate local state: %v", err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// the Go runtime does not expose siginfo_t so SIGCONT is handled in C to check si_pid | 	// the Go runtime does not expose siginfo_t so SIGCONT is handled in C to check si_pid | ||||||
| 	var signalPipe io.ReadCloser | 	var signalPipe io.ReadCloser | ||||||
| 	if r, w, err := os.Pipe(); err != nil { | 	if r, wKeepAlive, err := k.setupContSignal(state.Shim.PrivPID); err != nil { | ||||||
| 		log.Fatalf("cannot pipe: %v", err) | 		switch { | ||||||
| 	} else if _, err = C.hakurei_shim_setup_cont_signal(C.pid_t(state.Shim.PrivPID), C.int(w.Fd())); err != nil { | 		case errors.As(err, new(*os.SyscallError)): // returned by os.Pipe | ||||||
| 		log.Fatalf("cannot install SIGCONT handler: %v", err) | 			k.fatal(err.Error()) | ||||||
|  | 			return | ||||||
|  | 
 | ||||||
|  | 		case errors.As(err, new(syscall.Errno)): // returned by hakurei_shim_setup_cont_signal | ||||||
|  | 			k.fatalf("cannot install SIGCONT handler: %v", err) | ||||||
|  | 			return | ||||||
|  | 
 | ||||||
|  | 		default: // unreachable | ||||||
|  | 			k.fatalf("cannot set up exit request: %v", err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 	} else { | 	} else { | ||||||
| 		defer runtime.KeepAlive(w) | 		defer wKeepAlive() | ||||||
| 		signalPipe = r | 		signalPipe = r | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// pdeath_signal delivery is checked as if the dying process called kill(2), see kernel/exit.c | 	// pdeath_signal delivery is checked as if the dying process called kill(2), see kernel/exit.c | ||||||
| 	if _, _, errno := syscall.Syscall(syscall.SYS_PRCTL, syscall.PR_SET_PDEATHSIG, uintptr(syscall.SIGCONT), 0); errno != 0 { | 	if err := k.prctl(syscall.PR_SET_PDEATHSIG, uintptr(syscall.SIGCONT), 0); err != nil { | ||||||
| 		log.Fatalf("cannot set parent-death signal: %v", errno) | 		k.fatalf("cannot set parent-death signal: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	stateParams := state.newParams() | 	stateParams := state.newParams() | ||||||
| 	for _, op := range state.Shim.Ops { | 	for _, op := range state.Shim.Ops { | ||||||
| 		if err := op.toContainer(stateParams); err != nil { | 		if err := op.toContainer(stateParams); err != nil { | ||||||
| 			if m, ok := message.GetMessage(err); ok { | 			if m, ok := message.GetMessage(err); ok { | ||||||
| 				log.Fatal(m) | 				k.fatal(m) | ||||||
| 			} else { | 			} else { | ||||||
| 				log.Fatalf("cannot create container state: %v", err) | 				k.fatalf("cannot create container state: %v", err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// shim exit outcomes | 	// shim exit outcomes | ||||||
| 	var cancelContainer atomic.Pointer[context.CancelFunc] | 	var cancelContainer atomic.Pointer[context.CancelFunc] | ||||||
| 	go func() { | 	k.new(func(k syscallDispatcher, msg message.Msg) { | ||||||
| 		buf := make([]byte, 1) | 		buf := make([]byte, 1) | ||||||
| 		for { | 		for { | ||||||
| 			if _, err := signalPipe.Read(buf); err != nil { | 			if _, err := signalPipe.Read(buf); err != nil { | ||||||
| 				log.Fatalf("cannot read from signal pipe: %v", err) | 				k.fatalf("cannot read from signal pipe: %v", err) | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			switch buf[0] { | 			switch buf[0] { | ||||||
| @ -128,37 +167,37 @@ func ShimMain() { | |||||||
| 
 | 
 | ||||||
| 				// setup has not completed, terminate immediately | 				// setup has not completed, terminate immediately | ||||||
| 				msg.Resume() | 				msg.Resume() | ||||||
| 				os.Exit(hst.ExitRequest) | 				k.exit(hst.ExitRequest) | ||||||
| 				return | 				return | ||||||
| 
 | 
 | ||||||
| 			case 1: // got SIGCONT after adoption: monitor died before delivering signal | 			case 1: // got SIGCONT after adoption: monitor died before delivering signal | ||||||
| 				msg.BeforeExit() | 				msg.BeforeExit() | ||||||
| 				os.Exit(hst.ExitOrphan) | 				k.exit(hst.ExitOrphan) | ||||||
| 				return | 				return | ||||||
| 
 | 
 | ||||||
| 			case 2: // unreachable | 			case 2: // unreachable | ||||||
| 				log.Println("sa_sigaction got invalid siginfo") | 				msg.Verbose("sa_sigaction got invalid siginfo") | ||||||
| 
 | 
 | ||||||
| 			case 3: // got SIGCONT from unexpected process: hopefully the terminal driver | 			case 3: // got SIGCONT from unexpected process: hopefully the terminal driver | ||||||
| 				log.Println("got SIGCONT from unexpected process") | 				msg.Verbose("got SIGCONT from unexpected process") | ||||||
| 
 | 
 | ||||||
| 			default: // unreachable | 			default: // unreachable | ||||||
| 				log.Fatalf("got invalid message %d from signal handler", buf[0]) | 				k.fatalf("got invalid message %d from signal handler", buf[0]) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	}() | 	}) | ||||||
| 
 | 
 | ||||||
| 	if stateParams.params.Ops == nil { | 	if stateParams.params.Ops == nil { | ||||||
| 		log.Fatal("invalid container params") | 		k.fatal("invalid container params") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// close setup socket | 	// close setup socket | ||||||
| 	if err := closeSetup(); err != nil { | 	if err := closeSetup(); err != nil { | ||||||
| 		log.Printf("cannot close setup pipe: %v", err) | 		msg.Verbosef("cannot close setup pipe: %v", err) | ||||||
| 		// not fatal | 		// not fatal | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) | 	ctx, stop := k.notifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) | ||||||
| 	cancelContainer.Store(&stop) | 	cancelContainer.Store(&stop) | ||||||
| 	z := container.New(ctx, msg) | 	z := container.New(ctx, msg) | ||||||
| 	z.Params = *stateParams.params | 	z.Params = *stateParams.params | ||||||
| @ -167,30 +206,30 @@ func ShimMain() { | |||||||
| 	// bounds and default enforced in finalise.go | 	// bounds and default enforced in finalise.go | ||||||
| 	z.WaitDelay = state.Shim.WaitDelay | 	z.WaitDelay = state.Shim.WaitDelay | ||||||
| 
 | 
 | ||||||
| 	if err := z.Start(); err != nil { | 	if err := k.containerStart(z); err != nil { | ||||||
| 		printMessageError("cannot start container:", err) | 		printMessageError("cannot start container:", err) | ||||||
| 		os.Exit(hst.ExitFailure) | 		k.exit(hst.ExitFailure) | ||||||
| 	} | 	} | ||||||
| 	if err := z.Serve(); err != nil { | 	if err := k.containerServe(z); err != nil { | ||||||
| 		printMessageError("cannot configure container:", err) | 		printMessageError("cannot configure container:", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := seccomp.Load( | 	if err := k.seccompLoad( | ||||||
| 		seccomp.Preset(comp.PresetStrict, seccomp.AllowMultiarch), | 		seccomp.Preset(comp.PresetStrict, seccomp.AllowMultiarch), | ||||||
| 		seccomp.AllowMultiarch, | 		seccomp.AllowMultiarch, | ||||||
| 	); err != nil { | 	); err != nil { | ||||||
| 		log.Fatalf("cannot load syscall filter: %v", err) | 		k.fatalf("cannot load syscall filter: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := z.Wait(); err != nil { | 	if err := k.containerWait(z); err != nil { | ||||||
| 		var exitError *exec.ExitError | 		var exitError *exec.ExitError | ||||||
| 		if !errors.As(err, &exitError) { | 		if !errors.As(err, &exitError) { | ||||||
| 			if errors.Is(err, context.Canceled) { | 			if errors.Is(err, context.Canceled) { | ||||||
| 				os.Exit(hst.ExitCancel) | 				k.exit(hst.ExitCancel) | ||||||
| 			} | 			} | ||||||
| 			log.Printf("wait: %v", err) | 			msg.Verbosef("cannot wait: %v", err) | ||||||
| 			os.Exit(127) | 			k.exit(127) | ||||||
| 		} | 		} | ||||||
| 		os.Exit(exitError.ExitCode()) | 		k.exit(exitError.ExitCode()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user