diff --git a/config.go b/config.go index 05f692c..2ba0b6e 100644 --- a/config.go +++ b/config.go @@ -65,7 +65,7 @@ func tryTemplate() { } else { fmt.Println(string(s)) } - os.Exit(0) + fmsg.Exit(0) } } diff --git a/error.go b/error.go index 112427a..18f5002 100644 --- a/error.go +++ b/error.go @@ -2,7 +2,6 @@ package main import ( "errors" - "fmt" "git.ophivana.moe/security/fortify/internal/app" "git.ophivana.moe/security/fortify/internal/fmsg" @@ -51,6 +50,6 @@ func logBaseError(err error, message string) { if fmsg.AsBaseError(err, &e) { fmsg.Print(e.Message()) } else { - fmt.Println(message, err) + fmsg.Println(message, err) } } diff --git a/helper/stub.go b/helper/stub.go index c0be010..5b45896 100644 --- a/helper/stub.go +++ b/helper/stub.go @@ -11,6 +11,7 @@ import ( "testing" "git.ophivana.moe/security/fortify/helper/bwrap" + "git.ophivana.moe/security/fortify/internal/fmsg" ) // InternalChildStub is an internal function but exported because it is cross-package; @@ -33,7 +34,7 @@ func InternalChildStub() { genericStub(argsFD, statFD) } - os.Exit(0) + fmsg.Exit(0) } // InternalReplaceExecCommand is an internal function but exported because it is cross-package; diff --git a/internal/fmsg/defer.go b/internal/fmsg/defer.go new file mode 100644 index 0000000..cee3c73 --- /dev/null +++ b/internal/fmsg/defer.go @@ -0,0 +1,70 @@ +package fmsg + +import ( + "os" + "sync" + "sync/atomic" +) + +var ( + wstate atomic.Bool + withhold = make(chan struct{}, 1) + msgbuf = make(chan dOp, 64) // these ops are tiny so a large buffer is allocated for withholding output + + dequeueOnce sync.Once + queueSync sync.WaitGroup +) + +func dequeue() { + go func() { + for { + select { + case op := <-msgbuf: + op.Do() + queueSync.Done() + case <-withhold: + <-withhold + } + } + }() +} + +type dOp interface{ Do() } + +func Exit(code int) { + queueSync.Wait() + os.Exit(code) +} + +func Withhold() { + if wstate.CompareAndSwap(false, true) { + withhold <- struct{}{} + } +} + +func Resume() { + if wstate.CompareAndSwap(true, false) { + withhold <- struct{}{} + } +} + +type dPrint []any + +func (v dPrint) Do() { + std.Print(v...) +} + +type dPrintf struct { + format string + v []any +} + +func (d *dPrintf) Do() { + std.Printf(d.format, d.v...) +} + +type dPrintln []any + +func (v dPrintln) Do() { + std.Println(v...) +} diff --git a/internal/fmsg/fmsg.go b/internal/fmsg/fmsg.go index c72913f..dbe590b 100644 --- a/internal/fmsg/fmsg.go +++ b/internal/fmsg/fmsg.go @@ -4,38 +4,40 @@ package fmsg import ( "log" "os" - "sync/atomic" ) -var ( - std = log.New(os.Stdout, "fortify: ", 0) - warn = log.New(os.Stderr, "fortify: ", 0) - - verbose = new(atomic.Bool) -) +var std = log.New(os.Stderr, "fortify: ", 0) func SetPrefix(prefix string) { prefix += ": " std.SetPrefix(prefix) - warn.SetPrefix(prefix) + std.SetPrefix(prefix) } func Print(v ...any) { - warn.Print(v...) + dequeueOnce.Do(dequeue) + queueSync.Add(1) + msgbuf <- dPrint(v) } func Printf(format string, v ...any) { - warn.Printf(format, v...) + dequeueOnce.Do(dequeue) + queueSync.Add(1) + msgbuf <- &dPrintf{format, v} } func Println(v ...any) { - warn.Println(v...) + dequeueOnce.Do(dequeue) + queueSync.Add(1) + msgbuf <- dPrintln(v) } func Fatal(v ...any) { - warn.Fatal(v...) + Print(v...) + Exit(1) } func Fatalf(format string, v ...any) { - warn.Fatalf(format, v...) + Printf(format, v...) + Exit(1) } diff --git a/internal/fmsg/verbose.go b/internal/fmsg/verbose.go index 36faade..72a92a6 100644 --- a/internal/fmsg/verbose.go +++ b/internal/fmsg/verbose.go @@ -1,5 +1,9 @@ package fmsg +import "sync/atomic" + +var verbose = new(atomic.Bool) + func Verbose() bool { return verbose.Load() } @@ -10,12 +14,12 @@ func SetVerbose(v bool) { func VPrintf(format string, v ...any) { if verbose.Load() { - std.Printf(format, v...) + Printf(format, v...) } } func VPrintln(v ...any) { if verbose.Load() { - std.Println(v...) + Println(v...) } } diff --git a/internal/init/main.go b/internal/init/main.go index 7a340a8..f292772 100644 --- a/internal/init/main.go +++ b/internal/init/main.go @@ -129,7 +129,7 @@ func doInit(fd uintptr) { select { case s := <-sig: fmsg.VPrintln("received", s.String()) - os.Exit(0) + fmsg.Exit(0) case w := <-info: if w.wpid == cmd.Process.Pid { switch { @@ -147,10 +147,10 @@ func doInit(fd uintptr) { }() } case <-done: - os.Exit(r) + fmsg.Exit(r) case <-timeout: fmsg.Println("timeout exceeded waiting for lingering processes") - os.Exit(r) + fmsg.Exit(r) } } } diff --git a/internal/shim/main.go b/internal/shim/main.go index f53e8fb..52a8fa9 100644 --- a/internal/shim/main.go +++ b/internal/shim/main.go @@ -134,9 +134,9 @@ func doShim(socket string) { fmsg.VPrintln("wait:", err) } if b.Unwrap().ProcessState != nil { - os.Exit(b.Unwrap().ProcessState.ExitCode()) + fmsg.Exit(b.Unwrap().ProcessState.ExitCode()) } else { - os.Exit(127) + fmsg.Exit(127) } } } diff --git a/internal/state/print.go b/internal/state/print.go index d8f08fa..8ca3529 100644 --- a/internal/state/print.go +++ b/internal/state/print.go @@ -21,8 +21,7 @@ func MustPrintLauncherStateSimpleGlobal(w **tabwriter.Writer, runDir string) { // read runtime directory to get all UIDs if dirs, err := os.ReadDir(path.Join(runDir, "state")); err != nil && !errors.Is(err, os.ErrNotExist) { - fmsg.Println("cannot read runtime directory:", err) - os.Exit(1) + fmsg.Fatal("cannot read runtime directory:", err) } else { for _, e := range dirs { // skip non-directories @@ -112,13 +111,11 @@ func (s *simpleStore) mustPrintLauncherState(w **tabwriter.Writer, now time.Time }); err != nil { fmsg.Printf("cannot perform action on store %q: %s", path.Join(s.path...), err) if !ok { - fmsg.Println("store faulted before printing") - os.Exit(1) + fmsg.Fatal("store faulted before printing") } } if innerErr != nil { - fmsg.Printf("cannot print launcher state for store %q: %s", path.Join(s.path...), innerErr) - os.Exit(1) + fmsg.Fatalf("cannot print launcher state for store %q: %s", path.Join(s.path...), innerErr) } } diff --git a/internal/system.go b/internal/system.go index a8d915c..e98bb67 100644 --- a/internal/system.go +++ b/internal/system.go @@ -109,7 +109,7 @@ func (s *Std) Open(name string) (fs.File, error) { return os.Open(name) } func (s *Std) Exit(code int) { - os.Exit(code) + fmsg.Exit(code) } const xdgRuntimeDir = "XDG_RUNTIME_DIR" diff --git a/state.go b/state.go index 40b550c..3a3f9a6 100644 --- a/state.go +++ b/state.go @@ -30,6 +30,6 @@ func tryState() { fmt.Println("No information available") } - os.Exit(0) + fmsg.Exit(0) } }