Don't use a global logger (#423)

This commit is contained in:
Nathan Brown 2021-03-26 09:46:30 -05:00 committed by GitHub
parent 7a9f9dbded
commit 3ea7e1b75f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
45 changed files with 590 additions and 470 deletions

View File

@ -26,7 +26,7 @@ func NewBits(bits uint64) *Bits {
}
}
func (b *Bits) Check(i uint64) bool {
func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
// If i is the next number, return true.
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
return true
@ -47,7 +47,7 @@ func (b *Bits) Check(i uint64) bool {
return false
}
func (b *Bits) Update(i uint64) bool {
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true and update current.
if i == b.current+1 {
// Report missed packets, we can only understand what was missed after the first window has been gone through

View File

@ -7,6 +7,7 @@ import (
)
func TestBits(t *testing.T) {
l := NewTestLogger()
b := NewBits(10)
// make sure it is the right size
@ -14,46 +15,46 @@ func TestBits(t *testing.T) {
// This is initialized to zero - receive one. This should work.
assert.True(t, b.Check(1))
u := b.Update(1)
assert.True(t, b.Check(l, 1))
u := b.Update(l, 1)
assert.True(t, u)
assert.EqualValues(t, 1, b.current)
g := []bool{false, true, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
// Receive two
assert.True(t, b.Check(2))
u = b.Update(2)
assert.True(t, b.Check(l, 2))
u = b.Update(l, 2)
assert.True(t, u)
assert.EqualValues(t, 2, b.current)
g = []bool{false, true, true, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits)
// Receive two again - it will fail
assert.False(t, b.Check(2))
u = b.Update(2)
assert.False(t, b.Check(l, 2))
u = b.Update(l, 2)
assert.False(t, u)
assert.EqualValues(t, 2, b.current)
// Jump ahead to 15, which should clear everything and set the 6th element
assert.True(t, b.Check(15))
u = b.Update(15)
assert.True(t, b.Check(l, 15))
u = b.Update(l, 15)
assert.True(t, u)
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, false, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 14, which is allowed because it is in the window
assert.True(t, b.Check(14))
u = b.Update(14)
assert.True(t, b.Check(l, 14))
u = b.Update(l, 14)
assert.True(t, u)
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits)
// Mark 5, which is not allowed because it is not in the window
assert.False(t, b.Check(5))
u = b.Update(5)
assert.False(t, b.Check(l, 5))
u = b.Update(l, 5)
assert.False(t, u)
assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false}
@ -61,63 +62,65 @@ func TestBits(t *testing.T) {
// make sure we handle wrapping around once to the current position
b = NewBits(10)
assert.True(t, b.Update(1))
assert.True(t, b.Update(11))
assert.True(t, b.Update(l, 1))
assert.True(t, b.Update(l, 11))
assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits)
// Walk through a few windows in order
b = NewBits(10)
for i := uint64(0); i <= 100; i++ {
assert.True(t, b.Check(i), "Error while checking %v", i)
assert.True(t, b.Update(i), "Error while updating %v", i)
assert.True(t, b.Check(l, i), "Error while checking %v", i)
assert.True(t, b.Update(l, i), "Error while updating %v", i)
}
}
func TestBitsDupeCounter(t *testing.T) {
l := NewTestLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(1))
assert.True(t, b.Update(l, 1))
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.False(t, b.Update(1))
assert.False(t, b.Update(l, 1))
assert.Equal(t, int64(1), b.dupeCounter.Count())
assert.True(t, b.Update(2))
assert.True(t, b.Update(l, 2))
assert.Equal(t, int64(1), b.dupeCounter.Count())
assert.True(t, b.Update(3))
assert.True(t, b.Update(l, 3))
assert.Equal(t, int64(1), b.dupeCounter.Count())
assert.False(t, b.Update(1))
assert.False(t, b.Update(l, 1))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.Equal(t, int64(2), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}
func TestBitsOutOfWindowCounter(t *testing.T) {
l := NewTestLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(20))
assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
assert.True(t, b.Update(21))
assert.True(t, b.Update(22))
assert.True(t, b.Update(23))
assert.True(t, b.Update(24))
assert.True(t, b.Update(25))
assert.True(t, b.Update(26))
assert.True(t, b.Update(27))
assert.True(t, b.Update(28))
assert.True(t, b.Update(29))
assert.True(t, b.Update(l, 21))
assert.True(t, b.Update(l, 22))
assert.True(t, b.Update(l, 23))
assert.True(t, b.Update(l, 24))
assert.True(t, b.Update(l, 25))
assert.True(t, b.Update(l, 26))
assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
assert.False(t, b.Update(0))
assert.False(t, b.Update(l, 0))
assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
//tODO: make sure lostcounter doesn't increase in orderly increment
@ -127,23 +130,24 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
}
func TestBitsLostCounter(t *testing.T) {
l := NewTestLogger()
b := NewBits(10)
b.lostCounter.Clear()
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
//assert.True(t, b.Update(0))
assert.True(t, b.Update(0))
assert.True(t, b.Update(20))
assert.True(t, b.Update(21))
assert.True(t, b.Update(22))
assert.True(t, b.Update(23))
assert.True(t, b.Update(24))
assert.True(t, b.Update(25))
assert.True(t, b.Update(26))
assert.True(t, b.Update(27))
assert.True(t, b.Update(28))
assert.True(t, b.Update(29))
assert.True(t, b.Update(l, 0))
assert.True(t, b.Update(l, 20))
assert.True(t, b.Update(l, 21))
assert.True(t, b.Update(l, 22))
assert.True(t, b.Update(l, 23))
assert.True(t, b.Update(l, 24))
assert.True(t, b.Update(l, 25))
assert.True(t, b.Update(l, 26))
assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
@ -153,56 +157,56 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear()
b.outOfWindowCounter.Clear()
assert.True(t, b.Update(0))
assert.True(t, b.Update(l, 0))
assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(9))
assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count())
// 10 will set 0 index, 0 was already set, no lost packets
assert.True(t, b.Update(10))
assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count())
// 11 will set 1 index, 1 was missed, we should see 1 packet lost
assert.True(t, b.Update(11))
assert.True(t, b.Update(l, 11))
assert.Equal(t, int64(1), b.lostCounter.Count())
// Now let's fill in the window, should end up with 8 lost packets
assert.True(t, b.Update(12))
assert.True(t, b.Update(13))
assert.True(t, b.Update(14))
assert.True(t, b.Update(15))
assert.True(t, b.Update(16))
assert.True(t, b.Update(17))
assert.True(t, b.Update(18))
assert.True(t, b.Update(19))
assert.True(t, b.Update(l, 12))
assert.True(t, b.Update(l, 13))
assert.True(t, b.Update(l, 14))
assert.True(t, b.Update(l, 15))
assert.True(t, b.Update(l, 16))
assert.True(t, b.Update(l, 17))
assert.True(t, b.Update(l, 18))
assert.True(t, b.Update(l, 19))
assert.Equal(t, int64(8), b.lostCounter.Count())
// Jump ahead by a window size
assert.True(t, b.Update(29))
assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(8), b.lostCounter.Count())
// Now lets walk ahead normally through the window, the missed packets should fill in
assert.True(t, b.Update(30))
assert.True(t, b.Update(31))
assert.True(t, b.Update(32))
assert.True(t, b.Update(33))
assert.True(t, b.Update(34))
assert.True(t, b.Update(35))
assert.True(t, b.Update(36))
assert.True(t, b.Update(37))
assert.True(t, b.Update(38))
assert.True(t, b.Update(l, 30))
assert.True(t, b.Update(l, 31))
assert.True(t, b.Update(l, 32))
assert.True(t, b.Update(l, 33))
assert.True(t, b.Update(l, 34))
assert.True(t, b.Update(l, 35))
assert.True(t, b.Update(l, 36))
assert.True(t, b.Update(l, 37))
assert.True(t, b.Update(l, 38))
// 39 packets tracked, 22 seen, 17 lost
assert.Equal(t, int64(17), b.lostCounter.Count())
// Jump ahead by 2 windows, should have recording 1 full window missing
assert.True(t, b.Update(58))
assert.True(t, b.Update(l, 58))
assert.Equal(t, int64(27), b.lostCounter.Count())
// Now lets walk ahead normally through the window, the missed packets should fill in from this window
assert.True(t, b.Update(59))
assert.True(t, b.Update(60))
assert.True(t, b.Update(61))
assert.True(t, b.Update(62))
assert.True(t, b.Update(63))
assert.True(t, b.Update(64))
assert.True(t, b.Update(65))
assert.True(t, b.Update(66))
assert.True(t, b.Update(67))
assert.True(t, b.Update(l, 59))
assert.True(t, b.Update(l, 60))
assert.True(t, b.Update(l, 61))
assert.True(t, b.Update(l, 62))
assert.True(t, b.Update(l, 63))
assert.True(t, b.Update(l, 64))
assert.True(t, b.Update(l, 65))
assert.True(t, b.Update(l, 66))
assert.True(t, b.Update(l, 67))
// 68 packets tracked, 32 seen, 36 missed
assert.Equal(t, int64(36), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count())

View File

@ -7,6 +7,7 @@ import (
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
)
@ -119,7 +120,7 @@ func NewCertStateFromConfig(c *Config) (*CertState, error) {
return NewCertState(nebulaCert, rawKey)
}
func loadCAFromConfig(c *Config) (*cert.NebulaCAPool, error) {
func loadCAFromConfig(l *logrus.Logger, c *Config) (*cert.NebulaCAPool, error) {
var rawCA []byte
var err error

View File

@ -46,15 +46,16 @@ func main() {
os.Exit(1)
}
config := nebula.NewConfig()
l := logrus.New()
l.Out = os.Stdout
config := nebula.NewConfig(l)
err := config.Load(*configPath)
if err != nil {
fmt.Printf("failed to load config: %s", err)
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
c, err := nebula.Main(config, *configTest, Build, l, nil)
switch v := err.(type) {

View File

@ -24,14 +24,15 @@ func (p *program) Start(s service.Service) error {
// Start should not block.
logger.Info("Nebula service starting.")
config := nebula.NewConfig()
l := logrus.New()
l.Out = os.Stdout
config := nebula.NewConfig(l)
err := config.Load(*p.configPath)
if err != nil {
return fmt.Errorf("failed to load config: %s", err)
}
l := logrus.New()
l.Out = os.Stdout
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
if err != nil {
return err

View File

@ -40,15 +40,16 @@ func main() {
os.Exit(1)
}
config := nebula.NewConfig()
l := logrus.New()
l.Out = os.Stdout
config := nebula.NewConfig(l)
err := config.Load(*configPath)
if err != nil {
fmt.Printf("failed to load config: %s", err)
os.Exit(1)
}
l := logrus.New()
l.Out = os.Stdout
c, err := nebula.Main(config, *configTest, Build, l, nil)
switch v := err.(type) {

View File

@ -26,11 +26,13 @@ type Config struct {
Settings map[interface{}]interface{}
oldSettings map[interface{}]interface{}
callbacks []func(*Config)
l *logrus.Logger
}
func NewConfig() *Config {
func NewConfig(l *logrus.Logger) *Config {
return &Config{
Settings: make(map[interface{}]interface{}),
l: l,
}
}
@ -99,12 +101,12 @@ func (c *Config) HasChanged(k string) bool {
newVals, err := yaml.Marshal(nv)
if err != nil {
l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
}
oldVals, err := yaml.Marshal(ov)
if err != nil {
l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
}
return string(newVals) != string(oldVals)
@ -118,7 +120,7 @@ func (c *Config) CatchHUP() {
go func() {
for range ch {
l.Info("Caught HUP, reloading config")
c.l.Info("Caught HUP, reloading config")
c.ReloadConfig()
}
}()
@ -132,7 +134,7 @@ func (c *Config) ReloadConfig() {
err := c.Load(c.path)
if err != nil {
l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
return
}
@ -500,7 +502,7 @@ func configLogger(c *Config) error {
if err != nil {
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
}
l.SetLevel(logLevel)
c.l.SetLevel(logLevel)
disableTimestamp := c.GetBool("logging.disable_timestamp", false)
timestampFormat := c.GetString("logging.timestamp_format", "")
@ -512,13 +514,13 @@ func configLogger(c *Config) error {
logFormat := strings.ToLower(c.GetString("logging.format", "text"))
switch logFormat {
case "text":
l.Formatter = &logrus.TextFormatter{
c.l.Formatter = &logrus.TextFormatter{
TimestampFormat: timestampFormat,
FullTimestamp: fullTimestamp,
DisableTimestamp: disableTimestamp,
}
case "json":
l.Formatter = &logrus.JSONFormatter{
c.l.Formatter = &logrus.JSONFormatter{
TimestampFormat: timestampFormat,
DisableTimestamp: disableTimestamp,
}

View File

@ -11,14 +11,15 @@ import (
)
func TestConfig_Load(t *testing.T) {
l := NewTestLogger()
dir, err := ioutil.TempDir("", "config-test")
// invalid yaml
c := NewConfig()
c := NewConfig(l)
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
// simple multi config merge
c = NewConfig()
c = NewConfig(l)
os.RemoveAll(dir)
os.Mkdir(dir, 0755)
@ -40,8 +41,9 @@ func TestConfig_Load(t *testing.T) {
}
func TestConfig_Get(t *testing.T) {
l := NewTestLogger()
// test simple type
c := NewConfig()
c := NewConfig(l)
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
assert.Equal(t, "hi", c.Get("firewall.outbound"))
@ -55,13 +57,15 @@ func TestConfig_Get(t *testing.T) {
}
func TestConfig_GetStringSlice(t *testing.T) {
c := NewConfig()
l := NewTestLogger()
c := NewConfig(l)
c.Settings["slice"] = []interface{}{"one", "two"}
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
}
func TestConfig_GetBool(t *testing.T) {
c := NewConfig()
l := NewTestLogger()
c := NewConfig(l)
c.Settings["bool"] = true
assert.Equal(t, true, c.GetBool("bool", false))
@ -88,7 +92,8 @@ func TestConfig_GetBool(t *testing.T) {
}
func TestConfig_GetAllowList(t *testing.T) {
c := NewConfig()
l := NewTestLogger()
c := NewConfig(l)
c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0": true,
}
@ -181,20 +186,21 @@ func TestConfig_GetAllowList(t *testing.T) {
}
func TestConfig_HasChanged(t *testing.T) {
l := NewTestLogger()
// No reload has occurred, return false
c := NewConfig()
c := NewConfig(l)
c.Settings["test"] = "hi"
assert.False(t, c.HasChanged(""))
// Test key change
c = NewConfig()
c = NewConfig(l)
c.Settings["test"] = "hi"
c.oldSettings = map[interface{}]interface{}{"test": "no"}
assert.True(t, c.HasChanged("test"))
assert.True(t, c.HasChanged(""))
// No key change
c = NewConfig()
c = NewConfig(l)
c.Settings["test"] = "hi"
c.oldSettings = map[interface{}]interface{}{"test": "hi"}
assert.False(t, c.HasChanged("test"))
@ -202,12 +208,13 @@ func TestConfig_HasChanged(t *testing.T) {
}
func TestConfig_ReloadConfig(t *testing.T) {
l := NewTestLogger()
done := make(chan bool, 1)
dir, err := ioutil.TempDir("", "config-test")
assert.Nil(t, err)
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
c := NewConfig()
c := NewConfig(l)
assert.Nil(t, c.Load(dir))
assert.False(t, c.HasChanged("outer.inner"))

View File

@ -28,10 +28,11 @@ type connectionManager struct {
checkInterval int
pendingDeletionInterval int
l *logrus.Logger
// I wanted to call one matLock
}
func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
func newConnectionManager(l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
nc := &connectionManager{
hostMap: intf.hostMap,
in: make(map[uint32]struct{}),
@ -47,6 +48,7 @@ func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterva
pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
checkInterval: checkInterval,
pendingDeletionInterval: pendingDeletionInterval,
l: l,
}
nc.Start()
return nc
@ -166,8 +168,8 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
// If we saw incoming packets from this ip, just return
if traf {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIP)).
if n.l.Level >= logrus.DebugLevel {
n.l.WithField("vpnIp", IntIp(vpnIP)).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status")
}
@ -179,13 +181,13 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
// If we didn't we may need to probe or destroy the conn
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
if err != nil {
l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
continue
}
hostinfo.logger().
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status")
@ -194,7 +196,7 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out)
} else {
hostinfo.logger().Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
}
n.AddPendingDeletion(vpnIP)
}
@ -214,7 +216,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
// If we saw incoming packets from this ip, just return
traf := n.CheckIn(vpnIP)
if traf {
l.WithField("vpnIp", IntIp(vpnIP)).
n.l.WithField("vpnIp", IntIp(vpnIP)).
WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
Debug("Tunnel status")
n.ClearIP(vpnIP)
@ -226,7 +228,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
if err != nil {
n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP)
l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
continue
}
@ -236,7 +238,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
cn = hostinfo.ConnectionState.peerCert.Details.Name
}
hostinfo.logger().
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
WithField("certName", cn).
Info("Tunnel status")

View File

@ -13,6 +13,7 @@ import (
var vpnIP uint32
func Test_NewConnectionManagerTest(t *testing.T) {
l := NewTestLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@ -20,7 +21,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects
hostMap := NewHostMap("test", vpncidr, preferredRanges)
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
@ -28,7 +29,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
rawCertificateNoKey: []byte{},
}
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
lh := NewLightHouse(l, false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
ifce := &Interface{
hostMap: hostMap,
inside: &Tun{},
@ -36,12 +37,13 @@ func Test_NewConnectionManagerTest(t *testing.T) {
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
l: l,
}
now := time.Now()
// Create manager
nc := newConnectionManager(ifce, 5, 10)
nc := newConnectionManager(l, ifce, 5, 10)
p := []byte("")
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
@ -79,13 +81,14 @@ func Test_NewConnectionManagerTest(t *testing.T) {
}
func Test_NewConnectionManagerTest2(t *testing.T) {
l := NewTestLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects
hostMap := NewHostMap("test", vpncidr, preferredRanges)
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
@ -93,7 +96,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
rawCertificateNoKey: []byte{},
}
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
lh := NewLightHouse(l, false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
ifce := &Interface{
hostMap: hostMap,
inside: &Tun{},
@ -101,12 +104,13 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
certState: cs,
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
l: l,
}
now := time.Now()
// Create manager
nc := newConnectionManager(ifce, 5, 10)
nc := newConnectionManager(l, ifce, 5, 10)
p := []byte("")
nb := make([]byte, 12, 12)
out := make([]byte, mtu)

View File

@ -7,6 +7,7 @@ import (
"sync/atomic"
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
)
@ -26,7 +27,7 @@ type ConnectionState struct {
ready bool
}
func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256)
if f.cipher == "chachapoly" {
cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
@ -37,7 +38,7 @@ func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePa
b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
b.Update(0)
b.Update(l, 0)
hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: cs,

View File

@ -13,9 +13,10 @@ import (
)
func TestControl_GetHostInfoByVpnIP(t *testing.T) {
l := NewTestLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0))
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
remote1 := NewUDPAddr(int2ip(100), 4444)
remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
ipNet := net.IPNet{

View File

@ -7,6 +7,7 @@ import (
"sync"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
// This whole thing should be rewritten to use context
@ -63,7 +64,7 @@ func (d *dnsRecords) Add(host, data string) {
d.Unlock()
}
func parseQuery(m *dns.Msg, w dns.ResponseWriter) {
func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeA:
@ -95,34 +96,38 @@ func parseQuery(m *dns.Msg, w dns.ResponseWriter) {
}
}
func handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Compress = false
switch r.Opcode {
case dns.OpcodeQuery:
parseQuery(m, w)
parseQuery(l, m, w)
}
w.WriteMsg(m)
}
func dnsMain(hostMap *HostMap, c *Config) {
func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) {
dnsR = newDnsRecords(hostMap)
// attach request handler func
dns.HandleFunc(".", handleDnsRequest)
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
handleDnsRequest(l, w, r)
})
c.RegisterReloadCallback(reloadDns)
startDns(c)
c.RegisterReloadCallback(func(c *Config) {
reloadDns(l, c)
})
startDns(l, c)
}
func getDnsServerAddr(c *Config) string {
return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
}
func startDns(c *Config) {
func startDns(l *logrus.Logger, c *Config) {
dnsAddr = getDnsServerAddr(c)
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
l.Debugf("Starting DNS responder at %s\n", dnsAddr)
@ -133,7 +138,7 @@ func startDns(c *Config) {
}
}
func reloadDns(c *Config) {
func reloadDns(l *logrus.Logger, c *Config) {
if dnsAddr == getDnsServerAddr(c) {
l.Debug("No DNS server config change detected")
return
@ -141,5 +146,5 @@ func reloadDns(c *Config) {
l.Debug("Restarting DNS server")
dnsServer.Shutdown()
go startDns(c)
go startDns(l, c)
}

View File

@ -70,6 +70,7 @@ type Firewall struct {
trackTCPRTT bool
metricTCPRTT metrics.Histogram
l *logrus.Logger
}
type FirewallConntrack struct {
@ -156,7 +157,7 @@ func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
//TODO: error on 0 duration
var min, max time.Duration
@ -195,11 +196,13 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
DefaultTimeout: defaultTimeout,
localIps: localIps,
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
l: l,
}
}
func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
fw := NewFirewall(
l,
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
@ -207,12 +210,12 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er
//TODO: max_connections
)
err := AddFirewallRulesFromConfig(false, c, fw)
err := AddFirewallRulesFromConfig(l, false, c, fw)
if err != nil {
return nil, err
}
err = AddFirewallRulesFromConfig(true, c, fw)
err = AddFirewallRulesFromConfig(l, true, c, fw)
if err != nil {
return nil, err
}
@ -240,7 +243,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming {
direction = "outgoing"
}
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
Info("Firewall rule added")
var (
@ -276,7 +279,7 @@ func (f *Firewall) GetRuleHash() string {
return hex.EncodeToString(sum[:])
}
func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterface) error {
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error {
var table string
if inbound {
table = "firewall.inbound"
@ -296,7 +299,7 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa
for i, t := range rs {
var groups []string
r, err := convertRule(t, table, i)
r, err := convertRule(l, t, table, i)
if err != nil {
return fmt.Errorf("%s rule #%v; %s", table, i, err)
}
@ -459,8 +462,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
// We now know which firewall table to check against
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
if l.Level >= logrus.DebugLevel {
h.logger().
if f.l.Level >= logrus.DebugLevel {
h.logger(f.l).
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
@ -472,8 +475,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
return false
}
if l.Level >= logrus.DebugLevel {
h.logger().
if f.l.Level >= logrus.DebugLevel {
h.logger(f.l).
WithField("fwPacket", fp).
WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion).
@ -795,7 +798,7 @@ type rule struct {
CASha string
}
func convertRule(p interface{}, table string, i int) (rule, error) {
func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
r := rule{}
m, ok := p.(map[interface{}]interface{})
@ -968,14 +971,14 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) {
// Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get() ConntrackCache {
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
if c == nil {
return nil
}
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
c.cacheV = tick
if ll := len(c.cache); ll > 0 {
if l.GetLevel() == logrus.DebugLevel {
if l.Level == logrus.DebugLevel {
l.WithField("len", ll).Debug("resetting conntrack cache")
}
c.cache = make(ConntrackCache, ll)

View File

@ -15,8 +15,9 @@ import (
)
func TestNewFirewall(t *testing.T) {
l := NewTestLogger()
c := &cert.NebulaCertificate{}
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
conntrack := fw.Conntrack
assert.NotNil(t, conntrack)
assert.NotNil(t, conntrack.Conns)
@ -31,35 +32,34 @@ func TestNewFirewall(t *testing.T) {
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
}
func TestFirewall_AddRule(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
c := &cert.NebulaCertificate{}
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules)
@ -74,7 +74,7 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
assert.False(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
@ -83,7 +83,7 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
assert.False(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
@ -92,23 +92,23 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
// Set any and clear fields
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
@ -125,26 +125,25 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
// Test error conditions
fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
}
func TestFirewall_Drop(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
@ -177,7 +176,7 @@ func TestFirewall_Drop(t *testing.T) {
}
h.CreateRemoteCIDR(&c)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool()
@ -196,27 +195,27 @@ func TestFirewall_Drop(t *testing.T) {
p.RemoteIP = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
@ -317,10 +316,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}
func TestFirewall_Drop2(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
@ -365,7 +363,7 @@ func TestFirewall_Drop2(t *testing.T) {
}
h1.CreateRemoteCIDR(&c1)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
cp := cert.NewCAPool()
@ -377,10 +375,9 @@ func TestFirewall_Drop2(t *testing.T) {
}
func TestFirewall_Drop3(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
@ -448,7 +445,7 @@ func TestFirewall_Drop3(t *testing.T) {
}
h3.CreateRemoteCIDR(&c3)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
cp := cert.NewCAPool()
@ -464,10 +461,9 @@ func TestFirewall_Drop3(t *testing.T) {
}
func TestFirewall_DropConntrackReload(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)),
@ -500,7 +496,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
}
h.CreateRemoteCIDR(&c)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool()
@ -513,7 +509,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
oldFw := fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@ -522,7 +518,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
oldFw = fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@ -647,124 +643,126 @@ func Test_parsePort(t *testing.T) {
}
func TestNewFirewallFromConfig(t *testing.T) {
l := NewTestLogger()
// Test a bad rule definition
c := &cert.NebulaCertificate{}
conf := NewConfig()
conf := NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
_, err := NewFirewallFromConfig(c, conf)
_, err := NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code
conf = NewConfig()
conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(c, conf)
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha
conf = NewConfig()
conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
_, err = NewFirewallFromConfig(c, conf)
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
// Test code/port error
conf = NewConfig()
conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(c, conf)
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(c, conf)
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error
conf = NewConfig()
conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(c, conf)
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error
conf = NewConfig()
conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(c, conf)
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
// Test both group and groups
conf = NewConfig()
conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(c, conf)
_, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
}
func TestAddFirewallRulesFromConfig(t *testing.T) {
l := NewTestLogger()
// Test adding tcp rule
conf := NewConfig()
conf := NewConfig(l)
mf := &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding udp rule
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding icmp rule
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding any rule
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding rule with ca_sha
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
// Test single group
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
// Test single groups
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
// Test multiple AND groups
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
// Test Add error
conf = NewConfig()
conf = NewConfig(l)
mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.EqualError(t, AddFirewallRulesFromConfig(true, conf, mf), "firewall.inbound rule #0; `test error`")
assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
}
func TestTCPRTTTracking(t *testing.T) {
@ -859,17 +857,16 @@ func TestTCPRTTTracking(t *testing.T) {
}
func TestFirewall_convertRule(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob)
defer l.SetOutput(out)
// Ensure group array of 1 is converted and a warning is printed
c := map[interface{}]interface{}{
"group": []interface{}{"group1"},
}
r, err := convertRule(c, "test", 1)
r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
assert.Nil(t, err)
assert.Equal(t, "group1", r.Group)
@ -880,7 +877,7 @@ func TestFirewall_convertRule(t *testing.T) {
"group": []interface{}{"group1", "group2"},
}
r, err = convertRule(c, "test", 1)
r, err = convertRule(l, c, "test", 1)
assert.Equal(t, "", ob.String())
assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
@ -890,7 +887,7 @@ func TestFirewall_convertRule(t *testing.T) {
"group": "group1",
}
r, err = convertRule(c, "test", 1)
r, err = convertRule(l, c, "test", 1)
assert.Nil(t, err)
assert.Equal(t, "group1", r.Group)
}

View File

@ -7,7 +7,7 @@ const (
func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}

View File

@ -27,7 +27,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
err := f.handshakeManager.AddIndexHostInfo(hostinfo)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return
}
@ -48,7 +48,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
hsBytes, err = proto.Marshal(hs)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return
}
@ -58,14 +58,14 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
msg, _, _, err := ci.H.WriteMessage(header, hsBytes)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return
}
// We are sending handshake packet 1, so we don't expect to receive
// handshake packet 1 from the responder
ci.window.Update(1)
ci.window.Update(f.l, 1)
hostinfo.HandshakePacket[0] = msg
hostinfo.HandshakeReady = true
@ -74,13 +74,13 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
}
func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0)
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(1)
ci.window.Update(f.l, 1)
msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
if err != nil {
l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
return
}
@ -91,14 +91,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
*/
if err != nil || hs.Details == nil {
l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
return
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
if err != nil {
l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Info("Invalid certificate from host")
return
@ -108,16 +108,16 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
fingerprint, _ := remoteCert.Sha256Sum()
if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
return
}
myIndex, err := generateIndex()
myIndex, err := generateIndex(f.l)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
@ -133,7 +133,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
HandshakePacket: make(map[uint8][]byte, 0),
}
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@ -145,7 +145,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
hsBytes, err := proto.Marshal(hs)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
@ -155,13 +155,13 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return
} else if dKey == nil || eKey == nil {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
@ -178,7 +178,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
// We are sending handshake packet 2, so we don't expect to receive
// handshake packet 2 from the initiator.
ci.window.Update(2)
ci.window.Update(f.l, 2)
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
@ -203,11 +203,11 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr)
if err != nil {
l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message")
} else {
l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
}
@ -215,7 +215,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
case ErrExistingHostInfo:
// This means there was an existing tunnel and we didn't win
// handshake avoidance
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@ -227,7 +227,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
return
case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@ -238,7 +238,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@ -252,14 +252,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err = f.outside.WriteTo(msg, addr)
if err != nil {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake")
} else {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@ -267,7 +267,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
Info("Handshake message sent")
}
hostinfo.handshakeComplete()
hostinfo.handshakeComplete(f.l)
return
}
@ -280,7 +280,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
defer hostinfo.Unlock()
if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Info("Already seen this handshake packet")
return false
@ -288,14 +288,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
ci := hostinfo.ConnectionState
// Mark packet 2 as seen so it doesn't show up as missed
ci.window.Update(2)
ci.window.Update(f.l, 2)
hostinfo.HandshakePacket[2] = make([]byte, len(packet[HeaderLen:]))
copy(hostinfo.HandshakePacket[2], packet[HeaderLen:])
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage")
@ -304,7 +304,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
// near future
return false
} else if dKey == nil || eKey == nil {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
return true
@ -313,14 +313,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hs := &NebulaHandshake{}
err = proto.Unmarshal(msg, hs)
if err != nil || hs.Details == nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
return true
}
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Invalid certificate from host")
return true
@ -330,7 +330,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
fingerprint, _ := remoteCert.Sha256Sum()
duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@ -362,7 +362,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo.CreateRemoteCIDR(remoteCert)
f.handshakeManager.Complete(hostinfo, f)
hostinfo.handshakeComplete()
hostinfo.handshakeComplete(f.l)
f.metricHandshakes.Update(duration)
return false

View File

@ -53,11 +53,12 @@ type HandshakeManager struct {
InboundHandshakeTimer *SystemTimerWheel
messageMetrics *MessageMetrics
l *logrus.Logger
}
func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{
pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges),
pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
mainHostMap: mainHostMap,
lightHouse: lightHouse,
outside: outside,
@ -70,6 +71,7 @@ func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainH
InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
messageMetrics: config.messageMetrics,
l: l,
}
}
@ -78,7 +80,7 @@ func (c *HandshakeManager) Run(f EncWriter) {
for {
select {
case vpnIP := <-c.trigger:
l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
c.handleOutbound(vpnIP, f, true)
case now := <-clockSource:
c.NextOutboundHandshakeTimerTick(now, f)
@ -149,7 +151,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
if err != nil {
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
@ -157,7 +159,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
} else {
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
// keep the real packet struct around for logging purposes
hostinfo.logger().WithField("udpAddr", hostinfo.remote).
hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
@ -245,7 +247,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger().
hostinfo.logger(c.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
Info("New host shadows existing host remoteIndex")
}
@ -280,7 +282,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
if found && existingRemoteIndex != nil {
// We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note.
hostinfo.logger().
hostinfo.logger(c.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
Info("New host shadows existing host remoteIndex")
}
@ -298,7 +300,7 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
defer c.mainHostMap.RUnlock()
for i := 0; i < 32; i++ {
index, err := generateIndex()
index, err := generateIndex(c.l)
if err != nil {
return err
}
@ -336,7 +338,7 @@ func (c *HandshakeManager) EmitStats() {
// Utility functions below
func generateIndex() (uint32, error) {
func generateIndex(l *logrus.Logger) (uint32, error) {
b := make([]byte, 4)
// Let zero mean we don't know the ID, so don't generate zero

View File

@ -12,15 +12,15 @@ import (
var ips []uint32
func Test_NewHandshakeManagerIndex(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextInboundHandshakeTimerTick(now)
@ -63,15 +63,16 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
}
func Test_NewHandshakeManagerVpnIP(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
@ -112,16 +113,17 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
}
func Test_NewHandshakeManagerTrigger(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
lh := &LightHouse{}
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
@ -162,15 +164,16 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
}
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
vpnIP = ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)
@ -216,13 +219,14 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
}
func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap("test", vpncidr, preferredRanges)
mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now()
blah.NextInboundHandshakeTimerTick(now)

View File

@ -33,6 +33,7 @@ type HostMap struct {
defaultRoute uint32
unsafeRoutes *CIDRTree
metricsEnabled bool
l *logrus.Logger
}
type HostInfo struct {
@ -83,7 +84,7 @@ type Probe struct {
Counter int
}
func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
h := map[uint32]*HostInfo{}
i := map[uint32]*HostInfo{}
r := map[uint32]*HostInfo{}
@ -96,6 +97,7 @@ func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *
vpnCIDR: vpnCIDR,
defaultRoute: 0,
unsafeRoutes: NewCIDRTree(),
l: l,
}
return &m
}
@ -160,8 +162,8 @@ func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
}
hm.Unlock()
if l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap vpnIp deleted")
}
}
@ -173,8 +175,8 @@ func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
hm.RemoteIndexes[index] = h
hm.Unlock()
if l.Level > logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
if hm.l.Level > logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap remoteIndex added")
}
@ -188,8 +190,8 @@ func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
hm.RemoteIndexes[h.remoteIndexId] = h
hm.Unlock()
if l.Level > logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
if hm.l.Level > logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap vpnIp added")
}
@ -212,8 +214,8 @@ func (hm *HostMap) DeleteIndex(index uint32) {
}
hm.Unlock()
if l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
Debug("Hostmap index deleted")
}
}
@ -236,8 +238,8 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
}
hm.Unlock()
if l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
Debug("Hostmap remote index deleted")
}
}
@ -269,8 +271,8 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
}
hm.Unlock()
if l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
"vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Hostmap hostInfo deleted")
}
@ -313,8 +315,10 @@ func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo {
}
i.remote = i.Remotes[0].addr
hm.Hosts[vpnIp] = i
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap remote ip added")
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap remote ip added")
}
}
i.ForcePromoteBest(hm.preferredRanges)
hm.Unlock()
@ -377,8 +381,8 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
hm.Indexes[hostinfo.localIndexId] = hostinfo
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
if hm.l.Level >= logrus.DebugLevel {
hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
Debug("Hostmap vpnIp added")
}
@ -436,7 +440,7 @@ func (hm *HostMap) Punchy(conn *udpConn) {
func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
for _, r := range *routes {
l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
}
}
@ -566,7 +570,7 @@ func (i *HostInfo) rotateRemote() {
i.remote = i.Remotes[0].addr
}
func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
//TODO: return the error so we can log with more context
if len(i.packetStore) < 100 {
tempPacket := make([]byte, len(packet))
@ -574,14 +578,14 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
if l.Level >= logrus.DebugLevel {
i.logger().
i.logger(l).
WithField("length", len(i.packetStore)).
WithField("stored", true).
Debugf("Packet store")
}
} else if l.Level >= logrus.DebugLevel {
i.logger().
i.logger(l).
WithField("length", len(i.packetStore)).
WithField("stored", false).
Debugf("Packet store")
@ -589,7 +593,7 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
}
// handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets
func (i *HostInfo) handshakeComplete() {
func (i *HostInfo) handshakeComplete(l *logrus.Logger) {
//TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because:
//TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send
//TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical
@ -601,7 +605,7 @@ func (i *HostInfo) handshakeComplete() {
atomic.StoreUint64(&i.ConnectionState.atomicMessageCounter, 2)
if l.Level >= logrus.DebugLevel {
i.logger().Debugf("Sending %d stored packets", len(i.packetStore))
i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore))
}
if len(i.packetStore) > 0 {
@ -689,7 +693,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
i.remoteCidr = remoteCidr
}
func (i *HostInfo) logger() *logrus.Entry {
func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
if i == nil {
return logrus.NewEntry(l)
}
@ -804,7 +808,7 @@ func (d *HostInfoDest) ProbeReceived(probeCount int) {
// Utility functions
func localIps(allowList *AllowList) *[]net.IP {
func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP {
//FIXME: This function is pretty garbage
var ips []net.IP
ifaces, _ := net.Interfaces()

View File

@ -64,12 +64,13 @@ func TestHostInfoDestProbe(t *testing.T) {
*/
func TestHostmap(t *testing.T) {
l := NewTestLogger()
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
myNets := []*net.IPNet{myNet}
preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap("test", myNet, preferredRanges)
m := NewHostMap(l, "test", myNet, preferredRanges)
a := NewUDPAddrFromString("10.127.0.3:11111")
b := NewUDPAddrFromString("1.0.0.1:22222")
@ -103,10 +104,11 @@ func TestHostmap(t *testing.T) {
}
func TestHostmapdebug(t *testing.T) {
l := NewTestLogger()
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap("test", myNet, preferredRanges)
m := NewHostMap(l, "test", myNet, preferredRanges)
a := NewUDPAddrFromString("10.127.0.3:11111")
b := NewUDPAddrFromString("1.0.0.1:22222")
@ -151,11 +153,12 @@ func TestHostMap_rotateRemote(t *testing.T) {
}
func BenchmarkHostmappromote2(b *testing.B) {
l := NewTestLogger()
for n := 0; n < b.N; n++ {
_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap("test", myNet, preferredRanges)
m := NewHostMap(l, "test", myNet, preferredRanges)
y := NewUDPAddrFromString("10.128.0.3:11111")
a := NewUDPAddrFromString("10.127.0.3:11111")
g := NewUDPAddrFromString("1.0.0.1:22222")

View File

@ -10,7 +10,7 @@ import (
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) {
err := newPacket(packet, false, fwPacket)
if err != nil {
l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
return
}
@ -31,8 +31,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
if hostinfo == nil {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
}
@ -45,7 +45,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
// the packet queue.
ci.queueLock.Lock()
if !ci.ready {
hostinfo.cachePacket(message, 0, packet, f.sendMessageNow)
hostinfo.cachePacket(f.l, message, 0, packet, f.sendMessageNow)
ci.queueLock.Unlock()
return
}
@ -59,8 +59,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
f.lightHouse.Query(fwPacket.RemoteIP, f)
}
} else if l.Level >= logrus.DebugLevel {
hostinfo.logger().
} else if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).
WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping outbound packet")
@ -104,7 +104,7 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
if ci == nil {
// if we don't have a connection state, then send a handshake initiation
ci = f.newConnectionState(true, noise.HandshakeIX, []byte{}, 0)
ci = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
//ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0)
hostinfo.ConnectionState = ci
@ -135,15 +135,15 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
fp := &FirewallPacket{}
err := newPacket(p, false, fp)
if err != nil {
l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
return
}
// check if packet is in outbound fw rules
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil)
if dropReason != nil {
if l.Level >= logrus.DebugLevel {
l.WithField("fwPacket", fp).
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp).
WithField("reason", dropReason).
Debugln("dropping cached packet")
}
@ -160,8 +160,8 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIp)).
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", IntIp(vpnIp)).
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
}
return
@ -172,7 +172,7 @@ func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
// the packet queue.
hostInfo.ConnectionState.queueLock.Lock()
if !hostInfo.ConnectionState.ready {
hostInfo.cachePacket(t, st, p, f.sendMessageToVpnIp)
hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToVpnIp)
hostInfo.ConnectionState.queueLock.Unlock()
return
}
@ -191,8 +191,8 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil {
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIp)).
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", IntIp(vpnIp)).
Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
}
return
@ -203,7 +203,7 @@ func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubTyp
// the packet queue.
hostInfo.ConnectionState.queueLock.Lock()
if !hostInfo.ConnectionState.ready {
hostInfo.cachePacket(t, st, p, f.sendMessageToAll)
hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToAll)
hostInfo.ConnectionState.queueLock.Unlock()
return
}
@ -247,8 +247,8 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.Query(hostinfo.hostId, f)
hostinfo.lastRebindCount = f.rebindCount
if l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter")
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter")
}
}
@ -256,7 +256,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
//TODO: see above note on lock
//ci.writeLock.Unlock()
if err != nil {
hostinfo.logger().WithError(err).
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet")
@ -265,7 +265,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
err = f.writers[q].WriteTo(out, remote)
if err != nil {
hostinfo.logger().WithError(err).
hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet")
}
return c

View File

@ -9,6 +9,7 @@ import (
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
)
const mtu = 9001
@ -42,6 +43,7 @@ type InterfaceConfig struct {
version string
ConntrackCacheTimeout time.Duration
l *logrus.Logger
}
type Interface struct {
@ -73,6 +75,7 @@ type Interface struct {
metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics
l *logrus.Logger
}
func NewInterface(c *InterfaceConfig) (*Interface, error) {
@ -113,9 +116,10 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
messageMetrics: c.MessageMetrics,
l: c.l,
}
ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval)
ifce.connectionManager = newConnectionManager(c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
return ifce, nil
}
@ -125,10 +129,10 @@ func (f *Interface) run() {
addr, err := f.outside.LocalAddr()
if err != nil {
l.WithError(err).Error("Failed to get udp listen address")
f.l.WithError(err).Error("Failed to get udp listen address")
}
l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
f.l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
WithField("build", f.version).WithField("udpAddr", addr).
Info("Nebula interface is active")
@ -140,14 +144,14 @@ func (f *Interface) run() {
if i > 0 {
reader, err = f.inside.NewMultiQueueReader()
if err != nil {
l.Fatal(err)
f.l.Fatal(err)
}
}
f.readers[i] = reader
}
if err := f.inside.Activate(); err != nil {
l.Fatal(err)
f.l.Fatal(err)
}
// Launch n queues to read packets from udp
@ -187,12 +191,12 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
for {
n, err := reader.Read(packet)
if err != nil {
l.WithError(err).Error("Error while reading outbound packet")
f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2)
}
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
}
}
@ -208,21 +212,21 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *Config) {
func (f *Interface) reloadCA(c *Config) {
// reload and check regardless
// todo: need mutex?
newCAs, err := loadCAFromConfig(c)
newCAs, err := loadCAFromConfig(f.l, c)
if err != nil {
l.WithError(err).Error("Could not refresh trusted CA certificates")
f.l.WithError(err).Error("Could not refresh trusted CA certificates")
return
}
trustedCAs = newCAs
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed")
f.l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed")
}
func (f *Interface) reloadCertKey(c *Config) {
// reload and check in all cases
cs, err := NewCertStateFromConfig(c)
if err != nil {
l.WithError(err).Error("Could not refresh client cert")
f.l.WithError(err).Error("Could not refresh client cert")
return
}
@ -230,24 +234,24 @@ func (f *Interface) reloadCertKey(c *Config) {
oldIPs := f.certState.certificate.Details.Ips
newIPs := cs.certificate.Details.Ips
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
return
}
f.certState = cs
l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
}
func (f *Interface) reloadFirewall(c *Config) {
//TODO: need to trigger/detect if the certificate changed too
if c.HasChanged("firewall") == false {
l.Debug("No firewall config change detected")
f.l.Debug("No firewall config change detected")
return
}
fw, err := NewFirewallFromConfig(f.certState.certificate, c)
fw, err := NewFirewallFromConfig(f.l, f.certState.certificate, c)
if err != nil {
l.WithError(err).Error("Error while creating firewall during reload")
f.l.WithError(err).Error("Error while creating firewall during reload")
return
}
@ -260,7 +264,7 @@ func (f *Interface) reloadFirewall(c *Config) {
// If rulesVersion is back to zero, we have wrapped all the way around. Be
// safe and just reset conntrack in this case.
if fw.rulesVersion == 0 {
l.WithField("firewallHash", fw.GetRuleHash()).
f.l.WithField("firewallHash", fw.GetRuleHash()).
WithField("oldFirewallHash", oldFw.GetRuleHash()).
WithField("rulesVersion", fw.rulesVersion).
Warn("firewall rulesVersion has overflowed, resetting conntrack")
@ -271,7 +275,7 @@ func (f *Interface) reloadFirewall(c *Config) {
f.firewall = fw
oldFw.Destroy()
l.WithField("firewallHash", fw.GetRuleHash()).
f.l.WithField("firewallHash", fw.GetRuleHash()).
WithField("oldFirewallHash", oldFw.GetRuleHash()).
WithField("rulesVersion", fw.rulesVersion).
Info("New firewall has been installed")

View File

@ -48,6 +48,7 @@ type LightHouse struct {
metrics *MessageMetrics
metricHolepunchTx metrics.Counter
l *logrus.Logger
}
type EncWriter interface {
@ -55,7 +56,7 @@ type EncWriter interface {
SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
}
func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
func NewLightHouse(l *logrus.Logger, amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
h := LightHouse{
amLighthouse: amLighthouse,
myIp: myIp,
@ -67,6 +68,7 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n
punchConn: pc,
punchBack: punchBack,
punchDelay: punchDelay,
l: l,
}
if metricsEnabled {
@ -126,7 +128,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
// Send a query to the lighthouses and hope for the best next time
query, err := proto.Marshal(NewLhQueryByInt(ip))
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
return
}
@ -159,7 +161,7 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
lh.Lock()
//l.Debugln(lh.addrMap)
delete(lh.addrMap, vpnIP)
l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
lh.l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
lh.Unlock()
}
@ -181,7 +183,7 @@ func (lh *LightHouse) AddRemote(vpnIP uint32, toIp *udpAddr, static bool) {
}
allow := lh.remoteAllowList.Allow(toIp.IP)
l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow")
lh.l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow")
if !allow {
return
}
@ -270,7 +272,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
var v4 []*IpAndPort
var v6 []*Ip6AndPort
for _, e := range *localIps(lh.localAllowList) {
for _, e := range *localIps(lh.l, lh.localAllowList) {
// Only add IPs that aren't my VPN/tun IP
if ip2int(e) != lh.myIp {
ipp := NewIpAndPort(e, lh.nebulaPort)
@ -297,7 +299,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
for vpnIp := range lh.lighthouses {
mm, err := proto.Marshal(m)
if err != nil {
l.Debugf("Invalid marshal to update")
lh.l.Debugf("Invalid marshal to update")
}
//l.Error("LIGHTHOUSE PACKET SEND", mm)
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
@ -368,14 +370,14 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
n := lhh.resetMeta()
err := proto.UnmarshalMerge(p, n)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
lh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
Error("Failed to unmarshal lighthouse packet")
//TODO: send recv_error?
return
}
if n.Details == nil {
l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
lh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
Error("Invalid lighthouse update")
//TODO: send recv_error?
return
@ -387,7 +389,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
case NebulaMeta_HostQuery:
// Exit if we don't answer queries
if !lh.amLighthouse {
l.Debugln("I don't answer queries, but received from: ", rAddr)
lh.l.Debugln("I don't answer queries, but received from: ", rAddr)
return
}
@ -422,7 +424,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
reply, err := proto.Marshal(n)
if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
lh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
return
}
lh.metricTx(NebulaMeta_HostQueryReply, 1)
@ -431,7 +433,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
// This signals the other side to punch some zero byte udp packets
ips, err = lh.Query(vpnIp, f)
if err != nil {
l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch")
lh.l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch")
return
} else {
//l.Debugln("Notify host to punch", iap)
@ -492,7 +494,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
case NebulaMeta_HostUpdateNotification:
//Simple check that the host sent this not someone else
if n.Details.VpnIp != vpnIp {
l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
lh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
return
}
@ -530,9 +532,9 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
}()
if l.Level >= logrus.DebugLevel {
if lh.l.Level >= logrus.DebugLevel {
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
lh.l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
}
}
@ -549,9 +551,9 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
}()
if l.Level >= logrus.DebugLevel {
if lh.l.Level >= logrus.DebugLevel {
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
lh.l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
}
}
@ -561,7 +563,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
if lh.punchBack {
go func() {
time.Sleep(time.Second * 5)
l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
lh.l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
// TODO we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel.

View File

@ -65,12 +65,13 @@ func TestSetipandportsfromudpaddrs(t *testing.T) {
}
func Test_lhStaticMapping(t *testing.T) {
l := NewTestLogger()
lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1)
udpServer, _ := NewListener("0.0.0.0", 0, true)
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
meh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
err := meh.ValidateLHStaticEntries()
assert.Nil(t, err)
@ -78,19 +79,20 @@ func Test_lhStaticMapping(t *testing.T) {
lh2 := "10.128.0.3"
lh2IP := net.ParseIP(lh2)
meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
meh = NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
err = meh.ValidateLHStaticEntries()
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
}
func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := NewTestLogger()
lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1)
udpServer, _ := NewListener("0.0.0.0", 0, true)
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
lh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
lh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
hAddr := NewUDPAddrFromString("4.5.6.7:12345")
hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
@ -136,7 +138,8 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
}
func Test_lhRemoteAllowList(t *testing.T) {
c := NewConfig()
l := NewTestLogger()
c := NewConfig(l)
c.Settings["remoteallowlist"] = map[interface{}]interface{}{
"10.20.0.0/12": false,
}
@ -146,9 +149,9 @@ func Test_lhRemoteAllowList(t *testing.T) {
lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1)
udpServer, _ := NewListener("0.0.0.0", 0, true)
udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
lh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
lh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
lh.SetRemoteAllowList(allowList)
remote1 := "10.20.0.3"

29
main.go
View File

@ -11,13 +11,10 @@ import (
"gopkg.in/yaml.v2"
)
// The caller should provide a real logger, we have one just in case
var l = logrus.New()
type m map[string]interface{}
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
l = logger
l := logger
l.Formatter = &logrus.TextFormatter{
FullTimestamp: true,
}
@ -46,7 +43,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
})
// trustedCAs is currently a global, so loadCA operates on that global directly
trustedCAs, err = loadCAFromConfig(config)
trustedCAs, err = loadCAFromConfig(l, config)
if err != nil {
//The errors coming out of loadCA are already nicely formatted
return nil, NewContextualError("Failed to load ca from config", nil, err)
@ -60,7 +57,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
fw, err := NewFirewallFromConfig(cs.certificate, config)
fw, err := NewFirewallFromConfig(l, cs.certificate, config)
if err != nil {
return nil, NewContextualError("Error while loading firewall rules", nil, err)
}
@ -78,9 +75,9 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
wireSSHReload(ssh, config)
wireSSHReload(l, ssh, config)
if config.GetBool("sshd.enabled", false) {
err = configSSH(ssh, config)
err = configSSH(l, ssh, config)
if err != nil {
return nil, NewContextualError("Error while configuring the sshd", nil, err)
}
@ -136,6 +133,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l)
case tunFd != nil:
tun, err = newTunFromFd(
l,
*tunFd,
tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU),
@ -145,6 +143,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
)
default:
tun, err = newTun(
l,
config.GetString("tun.dev", ""),
tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU),
@ -166,7 +165,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
if !configTest {
for i := 0; i < routines; i++ {
udpServer, err := NewListener(config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
udpServer, err := NewListener(l, config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
if err != nil {
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
}
@ -222,7 +221,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
}
hostMap := NewHostMap("main", tunCidr, preferredRanges)
hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
hostMap.addUnsafeRoutes(&unsafeRoutes)
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
@ -266,6 +265,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}
lightHouse := NewLightHouse(
l,
amLighthouse,
ip2int(tunCidr.IP),
lighthouseHosts,
@ -337,7 +337,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
messageMetrics: messageMetrics,
}
handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
lightHouse.handshakeTrigger = handshakeManager.trigger
//TODO: These will be reused for psk
@ -367,6 +367,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
version: buildVersion,
ConntrackCacheTimeout: conntrackCacheTimeout,
l: l,
}
switch ifConfig.Cipher {
@ -395,7 +396,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
go lightHouse.LhUpdateWorker(ifce)
}
err = startStats(config, configTest)
err = startStats(l, config, configTest)
if err != nil {
return nil, NewContextualError("Failed to start stats emitter", nil, err)
}
@ -407,12 +408,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
//TODO: check if we _should_ be emitting stats
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
if amLighthouse && serveDns {
l.Debugln("Starting dns server")
go dnsMain(hostMap, config)
go dnsMain(l, hostMap, config)
}
return &Control{ifce, l}, nil

View File

@ -1 +1,30 @@
package nebula
import (
"io/ioutil"
"os"
"github.com/sirupsen/logrus"
)
func NewTestLogger() *logrus.Logger {
l := logrus.New()
v := os.Getenv("TEST_LOGS")
if v == "" {
l.SetOutput(ioutil.Discard)
return l
}
switch v {
case "1":
// This is the default level but we are being explicit
l.SetLevel(logrus.InfoLevel)
case "2":
l.SetLevel(logrus.DebugLevel)
case "3":
l.SetLevel(logrus.TraceLevel)
}
return l
}

View File

@ -24,7 +24,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 {
l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
}
return
}
@ -57,7 +57,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
@ -78,7 +78,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
if err != nil {
hostinfo.logger().WithError(err).WithField("udpAddr", addr).
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
WithField("packet", packet).
Error("Failed to decrypt test packet")
@ -115,7 +115,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
return
}
hostinfo.logger().WithField("udpAddr", addr).
hostinfo.logger(f.l).WithField("udpAddr", addr).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
@ -123,7 +123,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
default:
f.messageMetrics.Rx(header.Type, header.Subtype, 1)
hostinfo.logger().Debugf("Unexpected packet received from %s", addr)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
return
}
@ -143,18 +143,18 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
if hostDidRoam(hostinfo.remote, addr) {
if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
hostinfo.logger().WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
return
}
if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if l.Level >= logrus.DebugLevel {
hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
}
return
}
hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now()
remoteCopy := *hostinfo.remote
@ -170,7 +170,7 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool {
// If connectionstate exists and the replay protector allows, process packet
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
if ci == nil || !ci.window.Check(header.MessageCounter) {
if ci == nil || !ci.window.Check(f.l, header.MessageCounter) {
f.sendRecvError(addr, header.RemoteIndex)
return false
}
@ -247,8 +247,8 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return nil, err
}
if !hostinfo.ConnectionState.window.Update(mc) {
hostinfo.logger().WithField("header", header).
if !hostinfo.ConnectionState.window.Update(f.l, mc) {
hostinfo.logger(f.l).WithField("header", header).
Debugln("dropping out of window packet")
return nil, errors.New("out of window packet")
}
@ -261,7 +261,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
if err != nil {
hostinfo.logger().WithError(err).Error("Failed to decrypt packet")
hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
return
@ -269,21 +269,21 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
err = newPacket(out, true, fwPacket)
if err != nil {
hostinfo.logger().WithError(err).WithField("packet", out).
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
Warnf("Error while validating inbound packet")
return
}
if !hostinfo.ConnectionState.window.Update(messageCounter) {
hostinfo.logger().WithField("fwPacket", fwPacket).
if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
Debugln("dropping out of window packet")
return
}
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache)
if dropReason != nil {
if l.Level >= logrus.DebugLevel {
hostinfo.logger().WithField("fwPacket", fwPacket).
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason).
Debugln("dropping inbound packet")
}
@ -293,7 +293,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
f.connectionManager.In(hostinfo.hostId)
_, err = f.readers[q].Write(out)
if err != nil {
l.WithError(err).Error("Failed to write to tun")
f.l.WithError(err).Error("Failed to write to tun")
}
}
@ -303,16 +303,16 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
//TODO: this should be a signed message so we can trust that we should drop the index
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
f.outside.WriteTo(b, endpoint)
if l.Level >= logrus.DebugLevel {
l.WithField("index", index).
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", index).
WithField("udpAddr", endpoint).
Debug("Recv error sent")
}
}
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
if l.Level >= logrus.DebugLevel {
l.WithField("index", h.RemoteIndex).
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr).
Debug("Recv error received")
}
@ -322,7 +322,7 @@ func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex)
if err != nil {
l.Debugln(err, ": ", h.RemoteIndex)
f.l.Debugln(err, ": ", h.RemoteIndex)
return
}
@ -333,7 +333,7 @@ func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
return
}
if hostinfo.remote != nil && hostinfo.remote.String() != addr.String() {
l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
return
}

View File

@ -8,7 +8,8 @@ import (
)
func TestNewPunchyFromConfig(t *testing.T) {
c := NewConfig()
l := NewTestLogger()
c := NewConfig(l)
// Test defaults
p := NewPunchyFromConfig(c)

20
ssh.go
View File

@ -44,10 +44,10 @@ type sshCreateTunnelFlags struct {
Address string
}
func wireSSHReload(ssh *sshd.SSHServer, c *Config) {
func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
c.RegisterReloadCallback(func(c *Config) {
if c.GetBool("sshd.enabled", false) {
err := configSSH(ssh, c)
err := configSSH(l, ssh, c)
if err != nil {
l.WithError(err).Error("Failed to reconfigure the sshd")
ssh.Stop()
@ -58,7 +58,7 @@ func wireSSHReload(ssh *sshd.SSHServer, c *Config) {
})
}
func configSSH(ssh *sshd.SSHServer, c *Config) error {
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) error {
//TODO conntrack list
//TODO print firewall rules or hash?
@ -149,7 +149,7 @@ func configSSH(ssh *sshd.SSHServer, c *Config) error {
return nil
}
func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
func attachCommands(l *logrus.Logger, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
ssh.RegisterCommand(&sshd.Command{
Name: "list-hostmap",
ShortDescription: "List all known previously connected hosts",
@ -225,13 +225,17 @@ func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostM
ssh.RegisterCommand(&sshd.Command{
Name: "log-level",
ShortDescription: "Gets or sets the current log level",
Callback: sshLogLevel,
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshLogLevel(l, fs, a, w)
},
})
ssh.RegisterCommand(&sshd.Command{
Name: "log-format",
ShortDescription: "Gets or sets the current log format",
Callback: sshLogFormat,
Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshLogFormat(l, fs, a, w)
},
})
ssh.RegisterCommand(&sshd.Command{
@ -629,7 +633,7 @@ func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error {
return err
}
func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error {
func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
}
@ -643,7 +647,7 @@ func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error {
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
}
func sshLogFormat(fs interface{}, a []string, w sshd.StringWriter) error {
func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
}

View File

@ -13,9 +13,10 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
)
func startStats(c *Config, configTest bool) error {
func startStats(l *logrus.Logger, c *Config, configTest bool) error {
mType := c.GetString("stats.type", "")
if mType == "" || mType == "none" {
return nil
@ -28,9 +29,9 @@ func startStats(c *Config, configTest bool) error {
switch mType {
case "graphite":
startGraphiteStats(interval, c, configTest)
startGraphiteStats(l, interval, c, configTest)
case "prometheus":
startPrometheusStats(interval, c, configTest)
startPrometheusStats(l, interval, c, configTest)
default:
return fmt.Errorf("stats.type was not understood: %s", mType)
}
@ -44,7 +45,7 @@ func startStats(c *Config, configTest bool) error {
return nil
}
func startGraphiteStats(i time.Duration, c *Config, configTest bool) error {
func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
proto := c.GetString("stats.protocol", "tcp")
host := c.GetString("stats.host", "")
if host == "" {
@ -64,7 +65,7 @@ func startGraphiteStats(i time.Duration, c *Config, configTest bool) error {
return nil
}
func startPrometheusStats(i time.Duration, c *Config, configTest bool) error {
func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
namespace := c.GetString("stats.namespace", "")
subsystem := c.GetString("stats.subsystem", "")

View File

@ -6,6 +6,7 @@ import (
"net"
"os"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
@ -19,9 +20,10 @@ type Tun struct {
TXQueueLen int
Routes []route
UnsafeRoutes []route
l *logrus.Logger
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
ifce = &Tun{
@ -33,6 +35,7 @@ func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
l: l,
}
return
}

View File

@ -9,6 +9,7 @@ import (
"os/exec"
"strconv"
"github.com/sirupsen/logrus"
"github.com/songgao/water"
)
@ -17,11 +18,11 @@ type Tun struct {
Cidr *net.IPNet
MTU int
UnsafeRoutes []route
l *logrus.Logger
*water.Interface
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("route MTU not supported in Darwin")
}
@ -31,10 +32,11 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
Cidr: cidr,
MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
l: l,
}, nil
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}

View File

@ -9,24 +9,23 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)
type disabledTun struct {
read chan []byte
cidr *net.IPNet
logger *log.Logger
read chan []byte
cidr *net.IPNet
// Track these metrics since we don't have the tun device to do it for us
tx metrics.Counter
rx metrics.Counter
l *logrus.Logger
}
func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *log.Logger) *disabledTun {
func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
tun := &disabledTun{
cidr: cidr,
read: make(chan []byte, queueLen),
logger: l,
cidr: cidr,
read: make(chan []byte, queueLen),
l: l,
}
if metricsEnabled {
@ -63,8 +62,8 @@ func (t *disabledTun) Read(b []byte) (int, error) {
}
t.tx.Inc(1)
if l.Level >= logrus.DebugLevel {
t.logger.WithField("raw", prettyPacket(r)).Debugf("Write payload")
if t.l.Level >= logrus.DebugLevel {
t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload")
}
return copy(b, r), nil
@ -103,7 +102,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
select {
case t.read <- buf:
default:
t.logger.Debugf("tun_disabled: dropped ICMP Echo Reply response")
t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response")
}
return true
@ -114,11 +113,11 @@ func (t *disabledTun) Write(b []byte) (int, error) {
// Check for ICMP Echo Request before spending time doing the full parsing
if t.handleICMPEchoRequest(b) {
if l.Level >= logrus.DebugLevel {
t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
if t.l.Level >= logrus.DebugLevel {
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
}
} else if l.Level >= logrus.DebugLevel {
t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
} else if t.l.Level >= logrus.DebugLevel {
t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
}
return len(b), nil
}

View File

@ -9,6 +9,8 @@ import (
"regexp"
"strconv"
"strings"
"github.com/sirupsen/logrus"
)
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
@ -18,15 +20,16 @@ type Tun struct {
Cidr *net.IPNet
MTU int
UnsafeRoutes []route
l *logrus.Logger
io.ReadWriteCloser
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in FreeBSD")
}
@ -41,6 +44,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
Cidr: cidr,
MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
l: l,
}, nil
}
@ -52,21 +56,21 @@ func (c *Tun) Activate() error {
}
// TODO use syscalls instead of exec.Command
l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
c.l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
c.l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err)
}
l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
c.l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err)
}
// Unsafe path routes
for _, r := range c.UnsafeRoutes {
l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
c.l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
}

View File

@ -10,6 +10,7 @@ import (
"strings"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
@ -24,6 +25,7 @@ type Tun struct {
TXQueueLen int
Routes []route
UnsafeRoutes []route
l *logrus.Logger
}
type ifReq struct {
@ -78,7 +80,7 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@ -91,11 +93,12 @@ func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
l: l,
}
return
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
@ -131,6 +134,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
TXQueueLen: txQueueLen,
Routes: routes,
UnsafeRoutes: unsafeRoutes,
l: l,
}
return
}
@ -233,14 +237,14 @@ func (c Tun) Activate() error {
ifm := ifreqMTU{Name: devName, MTU: int32(c.MaxMTU)}
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
l.WithError(err).Error("Failed to set tun mtu")
c.l.WithError(err).Error("Failed to set tun mtu")
}
// Set the transmit queue length
ifrq := ifreqQLEN{Name: devName, Value: int32(c.TXQueueLen)}
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
// If we can't set the queue length nebula will still work but it may lead to packet loss
l.WithError(err).Error("Failed to set tun tx queue length")
c.l.WithError(err).Error("Failed to set tun tx queue length")
}
// Bring up the interface

View File

@ -9,7 +9,8 @@ import (
)
func Test_parseRoutes(t *testing.T) {
c := NewConfig()
l := NewTestLogger()
c := NewConfig(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24")
// test no routes config
@ -104,7 +105,8 @@ func Test_parseRoutes(t *testing.T) {
}
func Test_parseUnsafeRoutes(t *testing.T) {
c := NewConfig()
l := NewTestLogger()
c := NewConfig(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24")
// test no routes config

View File

@ -7,6 +7,7 @@ import (
"os/exec"
"strconv"
"github.com/sirupsen/logrus"
"github.com/songgao/water"
)
@ -15,15 +16,16 @@ type Tun struct {
Cidr *net.IPNet
MTU int
UnsafeRoutes []route
l *logrus.Logger
*water.Interface
}
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
if len(routes) > 0 {
return nil, fmt.Errorf("route MTU not supported in Windows")
}
@ -33,6 +35,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
Cidr: cidr,
MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes,
l: l,
}, nil
}

View File

@ -1,3 +1,5 @@
// +build !e2e_testing
package nebula
import (

View File

@ -1,3 +1,5 @@
// +build !e2e_testing
package nebula
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig

View File

@ -1,3 +1,5 @@
// +build !e2e_testing
package nebula
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig

View File

@ -1,4 +1,5 @@
// +build !linux android
// +build !e2e_testing
// udp_generic implements the nebula UDP interface in pure Go stdlib. This
// means it can be used on platforms like Darwin and Windows.
@ -9,20 +10,23 @@ import (
"context"
"fmt"
"net"
"github.com/sirupsen/logrus"
)
type udpConn struct {
*net.UDPConn
l *logrus.Logger
}
func NewListener(ip string, port int, multi bool) (*udpConn, error) {
func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port))
if err != nil {
return nil, err
}
if uc, ok := pc.(*net.UDPConn); ok {
return &udpConn{UDPConn: uc}, nil
return &udpConn{UDPConn: uc, l: l}, nil
}
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
}
@ -76,13 +80,13 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
// Just read one packet at a time
n, rua, err := u.ReadFromUDP(buffer)
if err != nil {
l.WithError(err).Error("Failed to read packets")
f.l.WithError(err).Error("Failed to read packets")
continue
}
udpAddr.IP = rua.IP
udpAddr.Port = uint16(rua.Port)
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get())
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get(f.l))
}
}

View File

@ -1,4 +1,5 @@
// +build !android
// +build !e2e_testing
package nebula
@ -10,6 +11,7 @@ import (
"unsafe"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
@ -17,6 +19,7 @@ import (
type udpConn struct {
sysFd int
l *logrus.Logger
}
var x int
@ -38,7 +41,7 @@ const (
type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
func NewListener(ip string, port int, multi bool) (*udpConn, error) {
func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
syscall.ForkLock.RLock()
fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
if err == nil {
@ -70,7 +73,7 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
//l.Println(v, err)
return &udpConn{sysFd: fd}, err
return &udpConn{sysFd: fd, l: l}, err
}
func (u *udpConn) Rebind() error {
@ -153,7 +156,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
for {
n, err := read(msgs)
if err != nil {
l.WithError(err).Error("Failed to read packets")
u.l.WithError(err).Error("Failed to read packets")
continue
}
@ -161,7 +164,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
for i := 0; i < n; i++ {
udpAddr.IP = names[i][8:24]
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get())
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l))
}
}
}
@ -244,12 +247,12 @@ func (u *udpConn) reloadConfig(c *Config) {
if err == nil {
s, err := u.GetRecvBuffer()
if err == nil {
l.WithField("size", s).Info("listen.read_buffer was set")
u.l.WithField("size", s).Info("listen.read_buffer was set")
} else {
l.WithError(err).Warn("Failed to get listen.read_buffer")
u.l.WithError(err).Warn("Failed to get listen.read_buffer")
}
} else {
l.WithError(err).Error("Failed to set listen.read_buffer")
u.l.WithError(err).Error("Failed to set listen.read_buffer")
}
}
@ -259,12 +262,12 @@ func (u *udpConn) reloadConfig(c *Config) {
if err == nil {
s, err := u.GetSendBuffer()
if err == nil {
l.WithField("size", s).Info("listen.write_buffer was set")
u.l.WithField("size", s).Info("listen.write_buffer was set")
} else {
l.WithError(err).Warn("Failed to get listen.write_buffer")
u.l.WithError(err).Warn("Failed to get listen.write_buffer")
}
} else {
l.WithError(err).Error("Failed to set listen.write_buffer")
u.l.WithError(err).Error("Failed to set listen.write_buffer")
}
}
}

View File

@ -1,6 +1,7 @@
// +build linux
// +build 386 amd64p32 arm mips mipsle
// +build !android
// +build !e2e_testing
package nebula

View File

@ -1,6 +1,7 @@
// +build linux
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
// +build !android
// +build !e2e_testing
package nebula