diff --git a/main.go b/main.go index 6ab370753..7b88a5269 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "log" "os" "runtime" + "strings" "sync" "github.com/hashicorp/go-plugin" @@ -138,39 +139,27 @@ func wrappedMain() int { // Get the command line args. args := os.Args[1:] + // Build the CLI so far, we do this so we can query the subcommand. + cliRunner := &cli.CLI{ + Args: args, + Commands: Commands, + HelpFunc: helpFunc, + HelpWriter: os.Stdout, + } + // Prefix the args with any args from the EnvCLI - if v := os.Getenv(EnvCLI); v != "" { - log.Printf("[INFO] %s value: %q", EnvCLI, v) - extra, err := shellwords.Parse(v) - if err != nil { - Ui.Error(fmt.Sprintf( - "Error parsing extra CLI args from %s: %s", - EnvCLI, err)) - return 1 - } - - // Find the index to place the flags. We put them exactly - // after the first non-flag arg. - idx := -1 - for i, v := range args { - if len(v) > 0 && v[0] != '-' { - idx = i - break - } - } - - // idx points to the exact arg that isn't a flag. We increment - // by one so that all the copying below expects idx to be the - // insertion point. - idx++ - - // Copy the args - newArgs := make([]string, len(args)+len(extra)) - copy(newArgs, args[:idx]) - copy(newArgs[idx:], extra) - copy(newArgs[len(extra)+idx:], args[idx:]) - args = newArgs + args, err = mergeEnvArgs(EnvCLI, args) + if err != nil { + Ui.Error(err.Error()) + return 1 + } + // Prefix the args with any args from the EnvCLI targeting this command + suffix := strings.Replace(cliRunner.Subcommand(), "-", "_", -1) + args, err = mergeEnvArgs(fmt.Sprintf("%s_%s", EnvCLI, suffix), args) + if err != nil { + Ui.Error(err.Error()) + return 1 } // We shortcut "--version" and "-v" to just show the version @@ -184,8 +173,9 @@ func wrappedMain() int { } } + // Rebuild the CLI with any modified args. log.Printf("[INFO] CLI command args: %#v", args) - cli := &cli.CLI{ + cliRunner = &cli.CLI{ Args: args, Commands: Commands, HelpFunc: helpFunc, @@ -196,7 +186,7 @@ func wrappedMain() int { ContextOpts.Providers = config.ProviderFactories() ContextOpts.Provisioners = config.ProvisionerFactories() - exitCode, err := cli.Run() + exitCode, err := cliRunner.Run() if err != nil { Ui.Error(fmt.Sprintf("Error executing CLI: %s", err.Error())) return 1 @@ -284,3 +274,40 @@ func copyOutput(r io.Reader, doneCh chan<- struct{}) { wg.Wait() } + +func mergeEnvArgs(envName string, args []string) ([]string, error) { + v := os.Getenv(envName) + if v == "" { + return args, nil + } + + log.Printf("[INFO] %s value: %q", envName, v) + extra, err := shellwords.Parse(v) + if err != nil { + return nil, fmt.Errorf( + "Error parsing extra CLI args from %s: %s", + envName, err) + } + + // Find the index to place the flags. We put them exactly + // after the first non-flag arg. + idx := -1 + for i, v := range args { + if len(v) > 0 && v[0] != '-' { + idx = i + break + } + } + + // idx points to the exact arg that isn't a flag. We increment + // by one so that all the copying below expects idx to be the + // insertion point. + idx++ + + // Copy the args + newArgs := make([]string, len(args)+len(extra)) + copy(newArgs, args[:idx]) + copy(newArgs[idx:], extra) + copy(newArgs[len(extra)+idx:], args[idx:]) + return newArgs, nil +} diff --git a/main_test.go b/main_test.go index 154d27bf5..161c59715 100644 --- a/main_test.go +++ b/main_test.go @@ -110,6 +110,7 @@ func TestMain_cliArgsFromEnv(t *testing.T) { for i, tc := range cases { t.Run(fmt.Sprintf("%d-%s", i, tc.Name), func(t *testing.T) { os.Unsetenv(EnvCLI) + defer os.Unsetenv(EnvCLI) // Set the env var value if tc.Value != "" { @@ -142,6 +143,87 @@ func TestMain_cliArgsFromEnv(t *testing.T) { } } +// This test just has more options than the test above. Use this for +// more control over behavior at the expense of more complex test structures. +func TestMain_cliArgsFromEnvAdvanced(t *testing.T) { + // Restore original CLI args + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + cases := []struct { + Name string + Command string + EnvVar string + Args []string + Value string + Expected []string + Err bool + }{ + { + "targeted to another command", + "command", + EnvCLI + "_foo", + []string{"command", "foo", "bar"}, + "-flag", + []string{"foo", "bar"}, + false, + }, + + { + "targeted to this command", + "command", + EnvCLI + "_command", + []string{"command", "foo", "bar"}, + "-flag", + []string{"-flag", "foo", "bar"}, + false, + }, + } + + for i, tc := range cases { + t.Run(fmt.Sprintf("%d-%s", i, tc.Name), func(t *testing.T) { + // Setup test command and restore that + testCommandName := tc.Command + testCommand := &testCommandCLI{} + defer func() { delete(Commands, testCommandName) }() + Commands[testCommandName] = func() (cli.Command, error) { + return testCommand, nil + } + + os.Unsetenv(tc.EnvVar) + defer os.Unsetenv(tc.EnvVar) + + // Set the env var value + if tc.Value != "" { + if err := os.Setenv(tc.EnvVar, tc.Value); err != nil { + t.Fatalf("err: %s", err) + } + } + + // Setup the args + args := make([]string, len(tc.Args)+1) + args[0] = oldArgs[0] // process name + copy(args[1:], tc.Args) + + // Run it! + os.Args = args + testCommand.Args = nil + exit := wrappedMain() + if (exit != 0) != tc.Err { + t.Fatalf("bad: %d", exit) + } + if tc.Err { + return + } + + // Verify + if !reflect.DeepEqual(testCommand.Args, tc.Expected) { + t.Fatalf("bad: %#v", testCommand.Args) + } + }) + } +} + type testCommandCLI struct { Args []string }