diff --git a/communicator/ssh/communicator_test.go b/communicator/ssh/communicator_test.go index b2344fa55..5064a7e82 100644 --- a/communicator/ssh/communicator_test.go +++ b/communicator/ssh/communicator_test.go @@ -5,6 +5,7 @@ package ssh import ( "bufio" "bytes" + "encoding/base64" "fmt" "io" "io/ioutil" @@ -13,6 +14,7 @@ import ( "os" "path/filepath" "regexp" + "strconv" "strings" "testing" @@ -171,6 +173,74 @@ func TestStart(t *testing.T) { } } +func TestHostKey(t *testing.T) { + // get the server's public key + signer, err := ssh.ParsePrivateKey([]byte(testServerPrivateKey)) + if err != nil { + panic("unable to parse private key: " + err.Error()) + } + pubKey := fmt.Sprintf("ssh-rsa %s", base64.StdEncoding.EncodeToString(signer.PublicKey().Marshal())) + + address := newMockLineServer(t) + host, p, _ := net.SplitHostPort(address) + port, _ := strconv.Atoi(p) + + connInfo := &connectionInfo{ + User: "user", + Password: "pass", + Host: host, + HostKey: pubKey, + Port: port, + Timeout: "30s", + } + + cfg, err := prepareSSHConfig(connInfo) + if err != nil { + t.Fatal(err) + } + + c := &Communicator{ + connInfo: connInfo, + config: cfg, + } + + var cmd remote.Cmd + stdout := new(bytes.Buffer) + cmd.Command = "echo foo" + cmd.Stdout = stdout + + if err := c.Start(&cmd); err != nil { + t.Fatal(err) + } + if err := c.Disconnect(); err != nil { + t.Fatal(err) + } + + // now check with the wrong HostKey + address = newMockLineServer(t) + _, p, _ = net.SplitHostPort(address) + port, _ = strconv.Atoi(p) + + connInfo.HostKey = testClientPublicKey + connInfo.Port = port + + cfg, err = prepareSSHConfig(connInfo) + if err != nil { + t.Fatal(err) + } + + c = &Communicator{ + connInfo: connInfo, + config: cfg, + } + + err = c.Start(&cmd) + if err == nil || !strings.Contains(err.Error(), "mismatch") { + t.Fatalf("expected host key mismatch, got error:%v", err) + } + +} + func TestAccUploadFile(t *testing.T) { // use the local ssh server and scp binary to check uploads if ok := os.Getenv("SSH_UPLOAD_TEST"); ok == "" {