diff --git a/helper/wrappedstreams/streams.go b/helper/wrappedstreams/streams.go index d3e5cf1c0..b661ed732 100644 --- a/helper/wrappedstreams/streams.go +++ b/helper/wrappedstreams/streams.go @@ -10,34 +10,31 @@ import ( // Stdin returns the true stdin of the process. func Stdin() *os.File { - stdin := os.Stdin - if panicwrap.Wrapped(nil) { - stdin = wrappedStdin - } - + stdin, _, _ := fds() return stdin } // Stdout returns the true stdout of the process. func Stdout() *os.File { - stdout := os.Stdout - if panicwrap.Wrapped(nil) { - stdout = wrappedStdout - } - + _, stdout, _ := fds() return stdout } // Stderr returns the true stderr of the process. func Stderr() *os.File { - stderr := os.Stderr - if panicwrap.Wrapped(nil) { - stderr = wrappedStderr - } - + _, _, stderr := fds() return stderr } +func fds() (stdin, stdout, stderr *os.File) { + stdin, stdout, stderr = os.Stdin, os.Stdout, os.Stderr + if panicwrap.Wrapped(nil) { + initPlatform() + stdin, stdout, stderr = wrappedStdin, wrappedStdout, wrappedStderr + } + return +} + // These are the wrapped standard streams. These are setup by the // platform specific code in initPlatform. var ( @@ -45,8 +42,3 @@ var ( wrappedStdout *os.File wrappedStderr *os.File ) - -func init() { - // Initialize the platform-specific code - initPlatform() -} diff --git a/helper/wrappedstreams/streams_other.go b/helper/wrappedstreams/streams_other.go index 5ffa413bc..82f1e150c 100644 --- a/helper/wrappedstreams/streams_other.go +++ b/helper/wrappedstreams/streams_other.go @@ -4,11 +4,18 @@ package wrappedstreams import ( "os" + "sync" ) +var initOnce sync.Once + func initPlatform() { - // The standard streams are passed in via extra file descriptors. - wrappedStdin = os.NewFile(uintptr(3), "stdin") - wrappedStdout = os.NewFile(uintptr(4), "stdout") - wrappedStderr = os.NewFile(uintptr(5), "stderr") + // These must be initialized lazily, once it's been determined that this is + // a wrapped process. + initOnce.Do(func() { + // The standard streams are passed in via extra file descriptors. + wrappedStdin = os.NewFile(uintptr(3), "stdin") + wrappedStdout = os.NewFile(uintptr(4), "stdout") + wrappedStderr = os.NewFile(uintptr(5), "stderr") + }) }