diff --git a/command/builder.go b/command/builder.go index 5c9b4de..c0782c6 100644 --- a/command/builder.go +++ b/command/builder.go @@ -7,8 +7,10 @@ import ( ) // New initialises a root Node. -func New(output io.Writer, logf LogFunc, name string) Command { - return rootNode{newNode(output, logf, name, "")} +func New(output io.Writer, logf LogFunc, name string, early HandlerFunc) Command { + c := rootNode{newNode(output, logf, name, "")} + c.f = early + return c } func newNode(output io.Writer, logf LogFunc, name, usage string) *node { diff --git a/command/builder_test.go b/command/builder_test.go index 1cd4ce3..22ce7ae 100644 --- a/command/builder_test.go +++ b/command/builder_test.go @@ -7,7 +7,7 @@ import ( ) func TestBuild(t *testing.T) { - c := command.New(nil, nil, "test") + c := command.New(nil, nil, "test", nil) stubHandler := func([]string) error { panic("unreachable") } t.Run("nil direct handler", func(t *testing.T) { diff --git a/command/parse.go b/command/parse.go index 6081ca6..0212046 100644 --- a/command/parse.go +++ b/command/parse.go @@ -34,9 +34,6 @@ func (n *node) Parse(arguments []string) error { match: if n.child != nil { - if n.f != nil { - panic("invalid subcommand tree state") - } // propagate help prefix early: flag set usage dereferences help n.child.prefix = append(n.prefix, n.name) } @@ -50,6 +47,17 @@ match: args := n.set.Args() if n.child != nil { + if n.f != nil { + if n.usage != "" { // root node early special case + panic("invalid subcommand tree state") + } + + // special case: root node calls HandlerFunc for initialisation + if err := n.f(nil); err != nil { + return err + } + } + if len(args) == 0 { return n.writeHelp() } diff --git a/command/parse_test.go b/command/parse_test.go index f424e97..3174dcc 100644 --- a/command/parse_test.go +++ b/command/parse_test.go @@ -24,13 +24,13 @@ func TestParse(t *testing.T) { }{ { "d=0 empty sub", - func(wout, wlog io.Writer) command.Command { return command.New(wout, newLogFunc(wlog), "root") }, + func(wout, wlog io.Writer) command.Command { return command.New(wout, newLogFunc(wlog), "root", nil) }, []string{""}, "", "test: \"root\" has no subcommands\n", command.ErrEmptyTree, }, { "d=0 empty sub garbage", - func(wout, wlog io.Writer) command.Command { return command.New(wout, newLogFunc(wlog), "root") }, + func(wout, wlog io.Writer) command.Command { return command.New(wout, newLogFunc(wlog), "root", nil) }, []string{"a", "b", "c", "d"}, "", "test: \"root\" has no subcommands\n", command.ErrEmptyTree, }, @@ -77,6 +77,18 @@ func TestParse(t *testing.T) { "flag provided but not defined: -val\n\nUsage:\ttest flag [-h | --help] COMMAND [OPTIONS]\n\n", "", errors.New("flag provided but not defined: -val"), }, + { + "d=0 bool flag", + buildTestCommand, + []string{"-v", "succeed"}, + "", "test: verbose\n", nil, + }, + { + "d=0 bool flag early error", + buildTestCommand, + []string{"--fail", "succeed"}, + "", "", errSuccess, + }, { "d=1 empty sub", @@ -126,7 +138,7 @@ func TestParse(t *testing.T) { buildTestCommand, []string{}, ` -Usage: test [-h | --help] [-v] [--val ] COMMAND [OPTIONS] +Usage: test [-h | --help] [-v] [--fail] [--val ] COMMAND [OPTIONS] Commands: error return an error @@ -144,7 +156,7 @@ Commands: buildTestCommand, []string{"-h"}, ` -Usage: test [-h | --help] [-v] [--val ] COMMAND [OPTIONS] +Usage: test [-h | --help] [-v] [--fail] [--val ] COMMAND [OPTIONS] Commands: error return an error @@ -156,7 +168,9 @@ Commands: deep top level of command tree with various levels Flags: - -v verbosity + -fail + fail early + -v verbose output -val string store val for the "flag" command (default "default") @@ -239,11 +253,24 @@ var ( ) func buildTestCommand(wout, wlog io.Writer) (c command.Command) { - var val string + var ( + flagVerbose bool + flagFail bool + flagVal string + ) logf := newLogFunc(wlog) - c = command.New(wout, logf, "test"). - Flag(new(bool), "v", command.BoolFlag(false), "verbosity"). + c = command.New(wout, logf, "test", func([]string) error { + if flagVerbose { + logf("verbose") + } + if flagFail { + return errSuccess + } + return nil + }). + Flag(&flagVerbose, "v", command.BoolFlag(false), "verbose output"). + Flag(&flagFail, "fail", command.BoolFlag(false), "fail early"). Command("error", "return an error", func([]string) error { return errSuccess }). @@ -255,9 +282,9 @@ func buildTestCommand(wout, wlog io.Writer) (c command.Command) { _, err := fmt.Fprint(wout, a...) return err }). - Flag(&val, "val", command.StringFlag("default"), "store val for the \"flag\" command"). + Flag(&flagVal, "val", command.StringFlag("default"), "store val for the \"flag\" command"). Command("flag", "print value passed by flag", func(args []string) error { - _, err := fmt.Fprint(wout, val) + _, err := fmt.Fprint(wout, flagVal) return err }) diff --git a/command/unreachable_test.go b/command/unreachable_test.go index 9ace479..7140818 100644 --- a/command/unreachable_test.go +++ b/command/unreachable_test.go @@ -24,10 +24,10 @@ func TestParseUnreachable(t *testing.T) { // a node with descendents must not have a direct handler t.Run("sub handle conflict", func(t *testing.T) { defer checkRecover(t, "Parse", "invalid subcommand tree state") - n := newNode(panicWriter{}, nil, " ", "") + n := newNode(panicWriter{}, nil, " ", " ") n.adopt(newNode(panicWriter{}, nil, " ", " ")) n.f = func([]string) error { panic("unreachable") } - _ = n.Parse(nil) + _ = n.Parse([]string{" "}) }) // this would only happen if a node was matched twice