From a0735dd7d5c9c8871357384a27e9e7327760fb68 Mon Sep 17 00:00:00 2001 From: Nathan Brown Date: Fri, 23 Apr 2021 14:43:16 -0500 Subject: [PATCH] Add locking around ssh conns to avoid concurrent map access on reload (#447) --- sshd/server.go | 43 +++++++++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/sshd/server.go b/sshd/server.go index 1ff32eb..4a78fdf 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -1,8 +1,10 @@ package sshd import ( + "errors" "fmt" "net" + "sync" "github.com/armon/go-radix" "github.com/sirupsen/logrus" @@ -20,8 +22,11 @@ type SSHServer struct { helpCommand *Command commands *radix.Tree listener net.Listener - conns map[int]*session - counter int + + // Locks the conns/counter to avoid concurrent map access + connsLock sync.Mutex + conns map[int]*session + counter int } // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen @@ -97,11 +102,24 @@ func (s *SSHServer) Run(addr string) error { } s.l.WithField("sshListener", addr).Info("SSH server is listening") + + // Run loops until there is an error + s.run() + s.closeSessions() + + s.l.Info("SSH server stopped listening") + // We don't return an error because run logs for us + return nil +} + +func (s *SSHServer) run() { for { c, err := s.listener.Accept() if err != nil { - s.l.WithError(err).Warn("Error in listener, shutting down") - return nil + if !errors.Is(err, net.ErrClosed) { + s.l.WithError(err).Warn("Error in listener, shutting down") + } + return } conn, chans, reqs, err := ssh.NewServerConn(c, s.config) @@ -127,37 +145,38 @@ func (s *SSHServer) Run(addr string) error { l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in") session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session")) + s.connsLock.Lock() s.counter++ counter := s.counter s.conns[counter] = session + s.connsLock.Unlock() go ssh.DiscardRequests(reqs) go func() { <-session.exitChan s.l.WithField("id", counter).Debug("closing conn") + s.connsLock.Lock() delete(s.conns, counter) + s.connsLock.Unlock() }() } } func (s *SSHServer) Stop() { - // Close the listener first, to prevent any new connections being accepted. + // Close the listener, this will cause all session to terminate as well, see SSHServer.Run if s.listener != nil { if err := s.listener.Close(); err != nil { s.l.WithError(err).Warn("Failed to close the sshd listener") - } else { - s.l.Info("SSH server stopped listening") } } +} - // Force close all existing connections. - // TODO I believe this has a slight race if the listener has just accepted - // a connection. Can fix by moving this to the goroutine that's accepting. +func (s *SSHServer) closeSessions() { + s.connsLock.Lock() for _, c := range s.conns { c.Close() } - - return + s.connsLock.Unlock() } func (s *SSHServer) matchPubKey(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {