diff --git a/command/builder.go b/command/builder.go new file mode 100644 index 0000000..ae49c56 --- /dev/null +++ b/command/builder.go @@ -0,0 +1,58 @@ +package command + +import ( + "flag" + "fmt" + "io" +) + +// New initialises a root Node. +func New(output io.Writer, logf LogFunc, name string) Command { + return rootNode{newNode(output, logf, name, "")} +} + +func newNode(output io.Writer, logf LogFunc, name, usage string) *node { + n := &node{ + name: name, usage: usage, + out: output, logf: logf, + set: flag.NewFlagSet(name, flag.ContinueOnError), + } + n.set.SetOutput(output) + n.set.Usage = func() { + _ = n.writeHelp() + if n.suffix.Len() > 0 { + _, _ = fmt.Fprintln(output, "Flags:") + n.set.PrintDefaults() + _, _ = fmt.Fprintln(output) + } + } + + return n +} + +func (n *node) Command(name, usage string, f HandlerFunc) Node { + if f == nil { + panic("invalid handler") + } + if name == "" || usage == "" { + panic("invalid subcommand") + } + + s := newNode(n.out, n.logf, name, usage) + s.f = f + if !n.adopt(s) { + panic("attempted to initialise subcommand with non-unique name") + } + return n +} + +func (n *node) New(name, usage string) Node { + if name == "" || usage == "" { + panic("invalid subcommand tree") + } + s := newNode(n.out, n.logf, name, usage) + if !n.adopt(s) { + panic("attempted to initialise subcommand tree with non-unique name") + } + return s +} diff --git a/command/builder_test.go b/command/builder_test.go new file mode 100644 index 0000000..1cd4ce3 --- /dev/null +++ b/command/builder_test.go @@ -0,0 +1,56 @@ +package command_test + +import ( + "testing" + + "git.gensokyo.uk/security/fortify/command" +) + +func TestBuild(t *testing.T) { + c := command.New(nil, nil, "test") + stubHandler := func([]string) error { panic("unreachable") } + + t.Run("nil direct handler", func(t *testing.T) { + defer checkRecover(t, "Command", "invalid handler") + c.Command("name", "usage", nil) + }) + + t.Run("direct zero length", func(t *testing.T) { + wantPanic := "invalid subcommand" + t.Run("zero length name", func(t *testing.T) { defer checkRecover(t, "Command", wantPanic); c.Command("", "usage", stubHandler) }) + t.Run("zero length usage", func(t *testing.T) { defer checkRecover(t, "Command", wantPanic); c.Command("name", "", stubHandler) }) + }) + + t.Run("direct adopt unique names", func(t *testing.T) { + c.Command("d0", "usage", stubHandler) + c.Command("d1", "usage", stubHandler) + }) + + t.Run("direct adopt non-unique name", func(t *testing.T) { + defer checkRecover(t, "Command", "attempted to initialise subcommand with non-unique name") + c.Command("d0", "usage", stubHandler) + }) + + t.Run("zero length", func(t *testing.T) { + wantPanic := "invalid subcommand tree" + t.Run("zero length name", func(t *testing.T) { defer checkRecover(t, "New", wantPanic); c.New("", "usage") }) + t.Run("zero length usage", func(t *testing.T) { defer checkRecover(t, "New", wantPanic); c.New("name", "") }) + }) + + t.Run("direct adopt unique names", func(t *testing.T) { + c.New("t0", "usage") + c.New("t1", "usage") + }) + + t.Run("direct adopt non-unique name", func(t *testing.T) { + defer checkRecover(t, "Command", "attempted to initialise subcommand tree with non-unique name") + c.New("t0", "usage") + }) +} + +func checkRecover(t *testing.T, name, wantPanic string) { + if r := recover(); r != wantPanic { + t.Errorf("%s: panic = %v; wantPanic %v", + name, r, wantPanic) + } +} diff --git a/command/flag.go b/command/flag.go new file mode 100644 index 0000000..5765dd0 --- /dev/null +++ b/command/flag.go @@ -0,0 +1,49 @@ +package command + +import ( + "errors" + "flag" + "strings" +) + +// FlagError wraps errors returned by [flag]. +type FlagError struct{ error } + +func (e FlagError) Success() bool { return errors.Is(e.error, flag.ErrHelp) } +func (e FlagError) Is(target error) bool { + return (e.error == nil && target == nil) || + ((e.error != nil && target != nil) && e.error.Error() == target.Error()) +} + +func (n *node) Flag(p any, name string, value FlagDefiner, usage string) Node { + value.Define(&n.suffix, n.set, p, name, usage) + return n +} + +// StringFlag is the default value of a string flag. +type StringFlag string + +func (v StringFlag) Define(b *strings.Builder, set *flag.FlagSet, p any, name, usage string) { + set.StringVar(p.(*string), name, string(v), usage) + b.WriteString(" [" + prettyFlag(name) + " ]") +} + +// BoolFlag is the default value of a bool flag. +type BoolFlag bool + +func (v BoolFlag) Define(b *strings.Builder, set *flag.FlagSet, p any, name, usage string) { + set.BoolVar(p.(*bool), name, bool(v), usage) + b.WriteString(" [" + prettyFlag(name) + "]") +} + +// this has no effect on parse outcome +func prettyFlag(name string) string { + switch len(name) { + case 0: + panic("zero length flag name") + case 1: + return "-" + name + default: + return "--" + name + } +} diff --git a/command/parse.go b/command/parse.go new file mode 100644 index 0000000..6081ca6 --- /dev/null +++ b/command/parse.go @@ -0,0 +1,72 @@ +package command + +import ( + "errors" + "log" +) + +var ( + ErrEmptyTree = errors.New("subcommand tree has no nodes") + ErrNoMatch = errors.New("did not match any subcommand") +) + +func (n *node) Parse(arguments []string) error { + if n.usage == "" { // root node has zero length usage + if n.next != nil { + panic("invalid toplevel state") + } + goto match + } + + if len(arguments) == 0 { + // unreachable: zero length args cause upper level to return with a help message + panic("attempted to parse with zero length args") + } + if arguments[0] != n.name { + if n.next == nil { + n.printf("%q is not a valid command", arguments[0]) + return ErrNoMatch + } + n.next.prefix = n.prefix + return n.next.Parse(arguments) + } + arguments = arguments[1:] + +match: + if n.child != nil { + if n.f != nil { + panic("invalid subcommand tree state") + } + // propagate help prefix early: flag set usage dereferences help + n.child.prefix = append(n.prefix, n.name) + } + + if n.set.Parsed() { + panic("invalid set state") + } + if err := n.set.Parse(arguments); err != nil { + return FlagError{err} + } + args := n.set.Args() + + if n.child != nil { + if len(args) == 0 { + return n.writeHelp() + } + return n.child.Parse(args) + } + + if n.f == nil { + n.printf("%q has no subcommands", n.name) + return ErrEmptyTree + } + return n.f(args) +} + +func (n *node) printf(format string, a ...any) { + if n.logf == nil { + log.Printf(format, a...) + } else { + n.logf(format, a...) + } +} diff --git a/command/parse_test.go b/command/parse_test.go new file mode 100644 index 0000000..f424e97 --- /dev/null +++ b/command/parse_test.go @@ -0,0 +1,292 @@ +package command_test + +import ( + "bytes" + "errors" + "flag" + "fmt" + "io" + "log" + "strings" + "testing" + + "git.gensokyo.uk/security/fortify/command" +) + +func TestParse(t *testing.T) { + testCases := []struct { + name string + buildTree func(wout, wlog io.Writer) command.Command + args []string + want string + wantLog string + wantErr error + }{ + { + "d=0 empty sub", + func(wout, wlog io.Writer) command.Command { return command.New(wout, newLogFunc(wlog), "root") }, + []string{""}, + "", "test: \"root\" has no subcommands\n", command.ErrEmptyTree, + }, + { + "d=0 empty sub garbage", + func(wout, wlog io.Writer) command.Command { return command.New(wout, newLogFunc(wlog), "root") }, + []string{"a", "b", "c", "d"}, + "", "test: \"root\" has no subcommands\n", command.ErrEmptyTree, + }, + { + "d=0 no match", + buildTestCommand, + []string{"nonexistent"}, + "", "test: \"nonexistent\" is not a valid command\n", command.ErrNoMatch, + }, + { + "d=0 direct error", + buildTestCommand, + []string{"error"}, + "", "", errSuccess, + }, + { + "d=0 direct error garbage", + buildTestCommand, + []string{"error", "0", "1", "2"}, + "", "", errSuccess, + }, + { + "d=0 direct success out of order", + buildTestCommand, + []string{"succeed"}, + "", "", nil, + }, + { + "d=0 direct success output", + buildTestCommand, + []string{"print", "0", "1", "2"}, + "012", "", nil, + }, + { + "d=0 string flag", + buildTestCommand, + []string{"--val", "64d3b4b7b21788585845060e2199a78f", "flag"}, + "64d3b4b7b21788585845060e2199a78f", "", nil, + }, + { + "d=0 out of order string flag", + buildTestCommand, + []string{"flag", "--val", "64d3b4b7b21788585845060e2199a78f"}, + "flag provided but not defined: -val\n\nUsage:\ttest flag [-h | --help] COMMAND [OPTIONS]\n\n", "", + errors.New("flag provided but not defined: -val"), + }, + + { + "d=1 empty sub", + buildTestCommand, + []string{"empty"}, + "", "test: \"empty\" has no subcommands\n", command.ErrEmptyTree, + }, + { + "d=1 empty sub garbage", + buildTestCommand, + []string{"empty", "a", "b", "c", "d"}, + "", "test: \"empty\" has no subcommands\n", command.ErrEmptyTree, + }, + { + "d=1 empty sub help", + buildTestCommand, + []string{"empty", "-h"}, + "\nUsage:\ttest empty [-h | --help] COMMAND [OPTIONS]\n\n", "", flag.ErrHelp, + }, + { + "d=1 no match", + buildTestCommand, + []string{"join", "23aa3bb0", "34986782", "d8859355", "cd9ac317", ", "}, + "", "test: \"23aa3bb0\" is not a valid command\n", command.ErrNoMatch, + }, + { + "d=1 direct success out", + buildTestCommand, + []string{"join", "out", "23aa3bb0", "34986782", "d8859355", "cd9ac317", ", "}, + "23aa3bb0, 34986782, d8859355, cd9ac317", "", nil, + }, + { + "d=1 direct success log", + buildTestCommand, + []string{"join", "log", "23aa3bb0", "34986782", "d8859355", "cd9ac317", ", "}, + "", "test: 23aa3bb0, 34986782, d8859355, cd9ac317\n", nil, + }, + + { + "d=4 empty sub", + buildTestCommand, + []string{"deep", "d=2", "d=3", "d=4"}, + "", "test: \"d=4\" has no subcommands\n", command.ErrEmptyTree}, + + { + "d=0 help", + buildTestCommand, + []string{}, + ` +Usage: test [-h | --help] [-v] [--val ] COMMAND [OPTIONS] + +Commands: + error return an error + print wraps Fprint + flag print value passed by flag + empty empty subcommand + join wraps strings.Join + succeed this command succeeds + deep top level of command tree with various levels + +`, "", command.ErrHelp, + }, + { + "d=0 help flag", + buildTestCommand, + []string{"-h"}, + ` +Usage: test [-h | --help] [-v] [--val ] COMMAND [OPTIONS] + +Commands: + error return an error + print wraps Fprint + flag print value passed by flag + empty empty subcommand + join wraps strings.Join + succeed this command succeeds + deep top level of command tree with various levels + +Flags: + -v verbosity + -val string + store val for the "flag" command (default "default") + +`, "", flag.ErrHelp, + }, + + { + "d=1 help", + buildTestCommand, + []string{"join"}, + ` +Usage: test join [-h | --help] COMMAND [OPTIONS] + +Commands: + out write result to wout + log log result to wlog + +`, "", command.ErrHelp, + }, + { + "d=1 help flag", + buildTestCommand, + []string{"join", "-h"}, + ` +Usage: test join [-h | --help] COMMAND [OPTIONS] + +Commands: + out write result to wout + log log result to wlog + +`, "", flag.ErrHelp, + }, + + { + "d=2 help", + buildTestCommand, + []string{"deep", "d=2"}, + ` +Usage: test deep d=2 [-h | --help] COMMAND [OPTIONS] + +Commands: + d=3 relative third level + +`, "", command.ErrHelp, + }, + { + "d=2 help flag", + buildTestCommand, + []string{"deep", "d=2", "-h"}, + ` +Usage: test deep d=2 [-h | --help] COMMAND [OPTIONS] + +Commands: + d=3 relative third level + +`, "", flag.ErrHelp, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + wout, wlog := new(bytes.Buffer), new(bytes.Buffer) + c := tc.buildTree(wout, wlog) + + if err := c.Parse(tc.args); !errors.Is(err, tc.wantErr) { + t.Errorf("Parse: error = %v; wantErr %v", err, tc.wantErr) + } + if got := wout.String(); got != tc.want { + t.Errorf("Parse: %s want %s", got, tc.want) + } + if gotLog := wlog.String(); gotLog != tc.wantLog { + t.Errorf("Parse: log = %s wantLog %s", gotLog, tc.wantLog) + } + }) + } +} + +var ( + errJoinLen = errors.New("not enough arguments to join") + errSuccess = errors.New("success") +) + +func buildTestCommand(wout, wlog io.Writer) (c command.Command) { + var val string + + logf := newLogFunc(wlog) + c = command.New(wout, logf, "test"). + Flag(new(bool), "v", command.BoolFlag(false), "verbosity"). + Command("error", "return an error", func([]string) error { + return errSuccess + }). + Command("print", "wraps Fprint", func(args []string) error { + a := make([]any, len(args)) + for i, v := range args { + a[i] = v + } + _, err := fmt.Fprint(wout, a...) + return err + }). + Flag(&val, "val", command.StringFlag("default"), "store val for the \"flag\" command"). + Command("flag", "print value passed by flag", func(args []string) error { + _, err := fmt.Fprint(wout, val) + return err + }) + + c.New("empty", "empty subcommand") + + c.New("join", "wraps strings.Join"). + Command("out", "write result to wout", func(args []string) error { + if len(args) == 0 { + return errJoinLen + } + _, err := fmt.Fprint(wout, strings.Join(args[:len(args)-1], args[len(args)-1])) + return err + }). + Command("log", "log result to wlog", func(args []string) error { + if len(args) == 0 { + return errJoinLen + } + logf("%s", strings.Join(args[:len(args)-1], args[len(args)-1])) + return nil + }) + + c.Command("succeed", "this command succeeds", func([]string) error { return nil }) + + c.New("deep", "top level of command tree with various levels"). + New("d=2", "relative second level"). + New("d=3", "relative third level"). + New("d=4", "relative fourth level") + + return +} + +func newLogFunc(w io.Writer) command.LogFunc { return log.New(w, "test: ", 0).Printf } diff --git a/command/unreachable_test.go b/command/unreachable_test.go new file mode 100644 index 0000000..9ace479 --- /dev/null +++ b/command/unreachable_test.go @@ -0,0 +1,54 @@ +package command + +import ( + "flag" + "testing" +) + +func TestParseUnreachable(t *testing.T) { + // top level bypasses name matching and recursive calls to Parse + // returns when encountering zero-length args + t.Run("zero-length args", func(t *testing.T) { + defer checkRecover(t, "Parse", "attempted to parse with zero length args") + _ = newNode(panicWriter{}, nil, " ", " ").Parse(nil) + }) + + // top level must not have siblings + t.Run("toplevel siblings", func(t *testing.T) { + defer checkRecover(t, "Parse", "invalid toplevel state") + n := newNode(panicWriter{}, nil, " ", "") + n.append(newNode(panicWriter{}, nil, " ", " ")) + _ = n.Parse(nil) + }) + + // a node with descendents must not have a direct handler + t.Run("sub handle conflict", func(t *testing.T) { + defer checkRecover(t, "Parse", "invalid subcommand tree state") + n := newNode(panicWriter{}, nil, " ", "") + n.adopt(newNode(panicWriter{}, nil, " ", " ")) + n.f = func([]string) error { panic("unreachable") } + _ = n.Parse(nil) + }) + + // this would only happen if a node was matched twice + t.Run("parsed flag set", func(t *testing.T) { + defer checkRecover(t, "Parse", "invalid set state") + n := newNode(panicWriter{}, nil, " ", "") + set := flag.NewFlagSet("parsed", flag.ContinueOnError) + set.SetOutput(panicWriter{}) + _ = set.Parse(nil) + n.set = set + _ = n.Parse(nil) + }) +} + +type panicWriter struct{} + +func (p panicWriter) Write([]byte) (int, error) { panic("unreachable") } + +func checkRecover(t *testing.T, name, wantPanic string) { + if r := recover(); r != wantPanic { + t.Errorf("%s: panic = %v; wantPanic %v", + name, r, wantPanic) + } +} diff --git a/command/wrap.go b/command/wrap.go new file mode 100644 index 0000000..ee08a61 --- /dev/null +++ b/command/wrap.go @@ -0,0 +1,14 @@ +package command + +// the top level node wants [Command] returned for its builder methods +type rootNode struct{ *node } + +func (r rootNode) Command(name, usage string, f HandlerFunc) Command { + r.node.Command(name, usage, f) + return r +} + +func (r rootNode) Flag(p any, name string, value FlagDefiner, usage string) Command { + r.node.Flag(p, name, value, usage) + return r +}