diff --git a/communicator/ssh/communicator.go b/communicator/ssh/communicator.go index 2afd10700..f55216366 100644 --- a/communicator/ssh/communicator.go +++ b/communicator/ssh/communicator.go @@ -3,6 +3,7 @@ package ssh import ( "bufio" "bytes" + "context" "errors" "fmt" "io" @@ -26,6 +27,9 @@ import ( const ( // DefaultShebang is added at the top of a SSH script file DefaultShebang = "#!/bin/sh\n" + + // enable ssh keeplive probes by default + keepAliveInterval = 2 * time.Second ) // randShared is a global random generator object that is shared. @@ -37,11 +41,12 @@ var randShared *rand.Rand // Communicator represents the SSH communicator type Communicator struct { - connInfo *connectionInfo - client *ssh.Client - config *sshConfig - conn net.Conn - address string + connInfo *connectionInfo + client *ssh.Client + config *sshConfig + conn net.Conn + address string + cancelKeepAlive context.CancelFunc lock sync.Mutex } @@ -203,11 +208,39 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) { } } + if err != nil { + return err + } + if o != nil { o.Output("Connected!") } - return err + ctx, cancelKeepAlive := context.WithCancel(context.TODO()) + c.cancelKeepAlive = cancelKeepAlive + + // Start a keepalive goroutine to help maintain the connection for + // long-running commands. + log.Printf("[DEBUG] starting ssh KeepAlives") + go func() { + t := time.NewTicker(keepAliveInterval) + defer t.Stop() + for { + select { + case <-t.C: + // there's no useful response to these, just abort when there's + // an error. + _, _, err := c.client.SendRequest("keepalive@terraform.io", true, nil) + if err != nil { + return + } + case <-ctx.Done(): + return + } + } + }() + + return nil } // Disconnect implementation of communicator.Communicator interface @@ -215,6 +248,10 @@ func (c *Communicator) Disconnect() error { c.lock.Lock() defer c.lock.Unlock() + if c.cancelKeepAlive != nil { + c.cancelKeepAlive() + } + if c.config.sshAgent != nil { if err := c.config.sshAgent.Close(); err != nil { return err diff --git a/communicator/ssh/communicator_test.go b/communicator/ssh/communicator_test.go index 546d6f88c..4d9fd848f 100644 --- a/communicator/ssh/communicator_test.go +++ b/communicator/ssh/communicator_test.go @@ -179,6 +179,48 @@ func TestStart(t *testing.T) { } } +// TestKeepAlives verifies that the keepalive messages don't interfere with +// normal operation of the client. +func TestKeepAlives(t *testing.T) { + address := newMockLineServer(t, nil) + parts := strings.Split(address, ":") + + r := &terraform.InstanceState{ + Ephemeral: terraform.EphemeralState{ + ConnInfo: map[string]string{ + "type": "ssh", + "user": "user", + "password": "pass", + "host": parts[0], + "port": parts[1], + "timeout": "30s", + }, + }, + } + + c, err := New(r) + if err != nil { + t.Fatalf("error creating communicator: %s", err) + } + + if err := c.Connect(nil); err != nil { + t.Fatal(err) + } + + var cmd remote.Cmd + stdout := new(bytes.Buffer) + cmd.Command = "echo foo" + cmd.Stdout = stdout + + // wait a bit before executing the command, so that at least 1 keepalive is sent + time.Sleep(3 * time.Second) + + err = c.Start(&cmd) + if err != nil { + t.Fatalf("error executing remote command: %s", err) + } +} + func TestLostConnection(t *testing.T) { address := newMockLineServer(t, nil) parts := strings.Split(address, ":")