Don't use a global logger (#423)

This commit is contained in:
Nathan Brown 2021-03-26 09:46:30 -05:00 committed by GitHub
parent 7a9f9dbded
commit 3ea7e1b75f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
45 changed files with 590 additions and 470 deletions

View File

@ -26,7 +26,7 @@ func NewBits(bits uint64) *Bits {
} }
} }
func (b *Bits) Check(i uint64) bool { func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
// If i is the next number, return true. // If i is the next number, return true.
if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) { if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
return true return true
@ -47,7 +47,7 @@ func (b *Bits) Check(i uint64) bool {
return false return false
} }
func (b *Bits) Update(i uint64) bool { func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
// If i is the next number, return true and update current. // If i is the next number, return true and update current.
if i == b.current+1 { if i == b.current+1 {
// Report missed packets, we can only understand what was missed after the first window has been gone through // Report missed packets, we can only understand what was missed after the first window has been gone through

View File

@ -7,6 +7,7 @@ import (
) )
func TestBits(t *testing.T) { func TestBits(t *testing.T) {
l := NewTestLogger()
b := NewBits(10) b := NewBits(10)
// make sure it is the right size // make sure it is the right size
@ -14,46 +15,46 @@ func TestBits(t *testing.T) {
// This is initialized to zero - receive one. This should work. // This is initialized to zero - receive one. This should work.
assert.True(t, b.Check(1)) assert.True(t, b.Check(l, 1))
u := b.Update(1) u := b.Update(l, 1)
assert.True(t, u) assert.True(t, u)
assert.EqualValues(t, 1, b.current) assert.EqualValues(t, 1, b.current)
g := []bool{false, true, false, false, false, false, false, false, false, false} g := []bool{false, true, false, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Receive two // Receive two
assert.True(t, b.Check(2)) assert.True(t, b.Check(l, 2))
u = b.Update(2) u = b.Update(l, 2)
assert.True(t, u) assert.True(t, u)
assert.EqualValues(t, 2, b.current) assert.EqualValues(t, 2, b.current)
g = []bool{false, true, true, false, false, false, false, false, false, false} g = []bool{false, true, true, false, false, false, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Receive two again - it will fail // Receive two again - it will fail
assert.False(t, b.Check(2)) assert.False(t, b.Check(l, 2))
u = b.Update(2) u = b.Update(l, 2)
assert.False(t, u) assert.False(t, u)
assert.EqualValues(t, 2, b.current) assert.EqualValues(t, 2, b.current)
// Jump ahead to 15, which should clear everything and set the 6th element // Jump ahead to 15, which should clear everything and set the 6th element
assert.True(t, b.Check(15)) assert.True(t, b.Check(l, 15))
u = b.Update(15) u = b.Update(l, 15)
assert.True(t, u) assert.True(t, u)
assert.EqualValues(t, 15, b.current) assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, false, true, false, false, false, false} g = []bool{false, false, false, false, false, true, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Mark 14, which is allowed because it is in the window // Mark 14, which is allowed because it is in the window
assert.True(t, b.Check(14)) assert.True(t, b.Check(l, 14))
u = b.Update(14) u = b.Update(l, 14)
assert.True(t, u) assert.True(t, u)
assert.EqualValues(t, 15, b.current) assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false} g = []bool{false, false, false, false, true, true, false, false, false, false}
assert.Equal(t, g, b.bits) assert.Equal(t, g, b.bits)
// Mark 5, which is not allowed because it is not in the window // Mark 5, which is not allowed because it is not in the window
assert.False(t, b.Check(5)) assert.False(t, b.Check(l, 5))
u = b.Update(5) u = b.Update(l, 5)
assert.False(t, u) assert.False(t, u)
assert.EqualValues(t, 15, b.current) assert.EqualValues(t, 15, b.current)
g = []bool{false, false, false, false, true, true, false, false, false, false} g = []bool{false, false, false, false, true, true, false, false, false, false}
@ -61,63 +62,65 @@ func TestBits(t *testing.T) {
// make sure we handle wrapping around once to the current position // make sure we handle wrapping around once to the current position
b = NewBits(10) b = NewBits(10)
assert.True(t, b.Update(1)) assert.True(t, b.Update(l, 1))
assert.True(t, b.Update(11)) assert.True(t, b.Update(l, 11))
assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits) assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits)
// Walk through a few windows in order // Walk through a few windows in order
b = NewBits(10) b = NewBits(10)
for i := uint64(0); i <= 100; i++ { for i := uint64(0); i <= 100; i++ {
assert.True(t, b.Check(i), "Error while checking %v", i) assert.True(t, b.Check(l, i), "Error while checking %v", i)
assert.True(t, b.Update(i), "Error while updating %v", i) assert.True(t, b.Update(l, i), "Error while updating %v", i)
} }
} }
func TestBitsDupeCounter(t *testing.T) { func TestBitsDupeCounter(t *testing.T) {
l := NewTestLogger()
b := NewBits(10) b := NewBits(10)
b.lostCounter.Clear() b.lostCounter.Clear()
b.dupeCounter.Clear() b.dupeCounter.Clear()
b.outOfWindowCounter.Clear() b.outOfWindowCounter.Clear()
assert.True(t, b.Update(1)) assert.True(t, b.Update(l, 1))
assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.False(t, b.Update(1)) assert.False(t, b.Update(l, 1))
assert.Equal(t, int64(1), b.dupeCounter.Count()) assert.Equal(t, int64(1), b.dupeCounter.Count())
assert.True(t, b.Update(2)) assert.True(t, b.Update(l, 2))
assert.Equal(t, int64(1), b.dupeCounter.Count()) assert.Equal(t, int64(1), b.dupeCounter.Count())
assert.True(t, b.Update(3)) assert.True(t, b.Update(l, 3))
assert.Equal(t, int64(1), b.dupeCounter.Count()) assert.Equal(t, int64(1), b.dupeCounter.Count())
assert.False(t, b.Update(1)) assert.False(t, b.Update(l, 1))
assert.Equal(t, int64(0), b.lostCounter.Count()) assert.Equal(t, int64(0), b.lostCounter.Count())
assert.Equal(t, int64(2), b.dupeCounter.Count()) assert.Equal(t, int64(2), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
} }
func TestBitsOutOfWindowCounter(t *testing.T) { func TestBitsOutOfWindowCounter(t *testing.T) {
l := NewTestLogger()
b := NewBits(10) b := NewBits(10)
b.lostCounter.Clear() b.lostCounter.Clear()
b.dupeCounter.Clear() b.dupeCounter.Clear()
b.outOfWindowCounter.Clear() b.outOfWindowCounter.Clear()
assert.True(t, b.Update(20)) assert.True(t, b.Update(l, 20))
assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
assert.True(t, b.Update(21)) assert.True(t, b.Update(l, 21))
assert.True(t, b.Update(22)) assert.True(t, b.Update(l, 22))
assert.True(t, b.Update(23)) assert.True(t, b.Update(l, 23))
assert.True(t, b.Update(24)) assert.True(t, b.Update(l, 24))
assert.True(t, b.Update(25)) assert.True(t, b.Update(l, 25))
assert.True(t, b.Update(26)) assert.True(t, b.Update(l, 26))
assert.True(t, b.Update(27)) assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(28)) assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(29)) assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
assert.False(t, b.Update(0)) assert.False(t, b.Update(l, 0))
assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
//tODO: make sure lostcounter doesn't increase in orderly increment //tODO: make sure lostcounter doesn't increase in orderly increment
@ -127,23 +130,24 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
} }
func TestBitsLostCounter(t *testing.T) { func TestBitsLostCounter(t *testing.T) {
l := NewTestLogger()
b := NewBits(10) b := NewBits(10)
b.lostCounter.Clear() b.lostCounter.Clear()
b.dupeCounter.Clear() b.dupeCounter.Clear()
b.outOfWindowCounter.Clear() b.outOfWindowCounter.Clear()
//assert.True(t, b.Update(0)) //assert.True(t, b.Update(0))
assert.True(t, b.Update(0)) assert.True(t, b.Update(l, 0))
assert.True(t, b.Update(20)) assert.True(t, b.Update(l, 20))
assert.True(t, b.Update(21)) assert.True(t, b.Update(l, 21))
assert.True(t, b.Update(22)) assert.True(t, b.Update(l, 22))
assert.True(t, b.Update(23)) assert.True(t, b.Update(l, 23))
assert.True(t, b.Update(24)) assert.True(t, b.Update(l, 24))
assert.True(t, b.Update(25)) assert.True(t, b.Update(l, 25))
assert.True(t, b.Update(26)) assert.True(t, b.Update(l, 26))
assert.True(t, b.Update(27)) assert.True(t, b.Update(l, 27))
assert.True(t, b.Update(28)) assert.True(t, b.Update(l, 28))
assert.True(t, b.Update(29)) assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(20), b.lostCounter.Count()) assert.Equal(t, int64(20), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count())
assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
@ -153,56 +157,56 @@ func TestBitsLostCounter(t *testing.T) {
b.dupeCounter.Clear() b.dupeCounter.Clear()
b.outOfWindowCounter.Clear() b.outOfWindowCounter.Clear()
assert.True(t, b.Update(0)) assert.True(t, b.Update(l, 0))
assert.Equal(t, int64(0), b.lostCounter.Count()) assert.Equal(t, int64(0), b.lostCounter.Count())
assert.True(t, b.Update(9)) assert.True(t, b.Update(l, 9))
assert.Equal(t, int64(0), b.lostCounter.Count()) assert.Equal(t, int64(0), b.lostCounter.Count())
// 10 will set 0 index, 0 was already set, no lost packets // 10 will set 0 index, 0 was already set, no lost packets
assert.True(t, b.Update(10)) assert.True(t, b.Update(l, 10))
assert.Equal(t, int64(0), b.lostCounter.Count()) assert.Equal(t, int64(0), b.lostCounter.Count())
// 11 will set 1 index, 1 was missed, we should see 1 packet lost // 11 will set 1 index, 1 was missed, we should see 1 packet lost
assert.True(t, b.Update(11)) assert.True(t, b.Update(l, 11))
assert.Equal(t, int64(1), b.lostCounter.Count()) assert.Equal(t, int64(1), b.lostCounter.Count())
// Now let's fill in the window, should end up with 8 lost packets // Now let's fill in the window, should end up with 8 lost packets
assert.True(t, b.Update(12)) assert.True(t, b.Update(l, 12))
assert.True(t, b.Update(13)) assert.True(t, b.Update(l, 13))
assert.True(t, b.Update(14)) assert.True(t, b.Update(l, 14))
assert.True(t, b.Update(15)) assert.True(t, b.Update(l, 15))
assert.True(t, b.Update(16)) assert.True(t, b.Update(l, 16))
assert.True(t, b.Update(17)) assert.True(t, b.Update(l, 17))
assert.True(t, b.Update(18)) assert.True(t, b.Update(l, 18))
assert.True(t, b.Update(19)) assert.True(t, b.Update(l, 19))
assert.Equal(t, int64(8), b.lostCounter.Count()) assert.Equal(t, int64(8), b.lostCounter.Count())
// Jump ahead by a window size // Jump ahead by a window size
assert.True(t, b.Update(29)) assert.True(t, b.Update(l, 29))
assert.Equal(t, int64(8), b.lostCounter.Count()) assert.Equal(t, int64(8), b.lostCounter.Count())
// Now lets walk ahead normally through the window, the missed packets should fill in // Now lets walk ahead normally through the window, the missed packets should fill in
assert.True(t, b.Update(30)) assert.True(t, b.Update(l, 30))
assert.True(t, b.Update(31)) assert.True(t, b.Update(l, 31))
assert.True(t, b.Update(32)) assert.True(t, b.Update(l, 32))
assert.True(t, b.Update(33)) assert.True(t, b.Update(l, 33))
assert.True(t, b.Update(34)) assert.True(t, b.Update(l, 34))
assert.True(t, b.Update(35)) assert.True(t, b.Update(l, 35))
assert.True(t, b.Update(36)) assert.True(t, b.Update(l, 36))
assert.True(t, b.Update(37)) assert.True(t, b.Update(l, 37))
assert.True(t, b.Update(38)) assert.True(t, b.Update(l, 38))
// 39 packets tracked, 22 seen, 17 lost // 39 packets tracked, 22 seen, 17 lost
assert.Equal(t, int64(17), b.lostCounter.Count()) assert.Equal(t, int64(17), b.lostCounter.Count())
// Jump ahead by 2 windows, should have recording 1 full window missing // Jump ahead by 2 windows, should have recording 1 full window missing
assert.True(t, b.Update(58)) assert.True(t, b.Update(l, 58))
assert.Equal(t, int64(27), b.lostCounter.Count()) assert.Equal(t, int64(27), b.lostCounter.Count())
// Now lets walk ahead normally through the window, the missed packets should fill in from this window // Now lets walk ahead normally through the window, the missed packets should fill in from this window
assert.True(t, b.Update(59)) assert.True(t, b.Update(l, 59))
assert.True(t, b.Update(60)) assert.True(t, b.Update(l, 60))
assert.True(t, b.Update(61)) assert.True(t, b.Update(l, 61))
assert.True(t, b.Update(62)) assert.True(t, b.Update(l, 62))
assert.True(t, b.Update(63)) assert.True(t, b.Update(l, 63))
assert.True(t, b.Update(64)) assert.True(t, b.Update(l, 64))
assert.True(t, b.Update(65)) assert.True(t, b.Update(l, 65))
assert.True(t, b.Update(66)) assert.True(t, b.Update(l, 66))
assert.True(t, b.Update(67)) assert.True(t, b.Update(l, 67))
// 68 packets tracked, 32 seen, 36 missed // 68 packets tracked, 32 seen, 36 missed
assert.Equal(t, int64(36), b.lostCounter.Count()) assert.Equal(t, int64(36), b.lostCounter.Count())
assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count())

View File

@ -7,6 +7,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
) )
@ -119,7 +120,7 @@ func NewCertStateFromConfig(c *Config) (*CertState, error) {
return NewCertState(nebulaCert, rawKey) return NewCertState(nebulaCert, rawKey)
} }
func loadCAFromConfig(c *Config) (*cert.NebulaCAPool, error) { func loadCAFromConfig(l *logrus.Logger, c *Config) (*cert.NebulaCAPool, error) {
var rawCA []byte var rawCA []byte
var err error var err error

View File

@ -46,15 +46,16 @@ func main() {
os.Exit(1) os.Exit(1)
} }
config := nebula.NewConfig() l := logrus.New()
l.Out = os.Stdout
config := nebula.NewConfig(l)
err := config.Load(*configPath) err := config.Load(*configPath)
if err != nil { if err != nil {
fmt.Printf("failed to load config: %s", err) fmt.Printf("failed to load config: %s", err)
os.Exit(1) os.Exit(1)
} }
l := logrus.New()
l.Out = os.Stdout
c, err := nebula.Main(config, *configTest, Build, l, nil) c, err := nebula.Main(config, *configTest, Build, l, nil)
switch v := err.(type) { switch v := err.(type) {

View File

@ -24,14 +24,15 @@ func (p *program) Start(s service.Service) error {
// Start should not block. // Start should not block.
logger.Info("Nebula service starting.") logger.Info("Nebula service starting.")
config := nebula.NewConfig() l := logrus.New()
l.Out = os.Stdout
config := nebula.NewConfig(l)
err := config.Load(*p.configPath) err := config.Load(*p.configPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to load config: %s", err) return fmt.Errorf("failed to load config: %s", err)
} }
l := logrus.New()
l.Out = os.Stdout
p.control, err = nebula.Main(config, *p.configTest, Build, l, nil) p.control, err = nebula.Main(config, *p.configTest, Build, l, nil)
if err != nil { if err != nil {
return err return err

View File

@ -40,15 +40,16 @@ func main() {
os.Exit(1) os.Exit(1)
} }
config := nebula.NewConfig() l := logrus.New()
l.Out = os.Stdout
config := nebula.NewConfig(l)
err := config.Load(*configPath) err := config.Load(*configPath)
if err != nil { if err != nil {
fmt.Printf("failed to load config: %s", err) fmt.Printf("failed to load config: %s", err)
os.Exit(1) os.Exit(1)
} }
l := logrus.New()
l.Out = os.Stdout
c, err := nebula.Main(config, *configTest, Build, l, nil) c, err := nebula.Main(config, *configTest, Build, l, nil)
switch v := err.(type) { switch v := err.(type) {

View File

@ -26,11 +26,13 @@ type Config struct {
Settings map[interface{}]interface{} Settings map[interface{}]interface{}
oldSettings map[interface{}]interface{} oldSettings map[interface{}]interface{}
callbacks []func(*Config) callbacks []func(*Config)
l *logrus.Logger
} }
func NewConfig() *Config { func NewConfig(l *logrus.Logger) *Config {
return &Config{ return &Config{
Settings: make(map[interface{}]interface{}), Settings: make(map[interface{}]interface{}),
l: l,
} }
} }
@ -99,12 +101,12 @@ func (c *Config) HasChanged(k string) bool {
newVals, err := yaml.Marshal(nv) newVals, err := yaml.Marshal(nv)
if err != nil { if err != nil {
l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config") c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
} }
oldVals, err := yaml.Marshal(ov) oldVals, err := yaml.Marshal(ov)
if err != nil { if err != nil {
l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config") c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
} }
return string(newVals) != string(oldVals) return string(newVals) != string(oldVals)
@ -118,7 +120,7 @@ func (c *Config) CatchHUP() {
go func() { go func() {
for range ch { for range ch {
l.Info("Caught HUP, reloading config") c.l.Info("Caught HUP, reloading config")
c.ReloadConfig() c.ReloadConfig()
} }
}() }()
@ -132,7 +134,7 @@ func (c *Config) ReloadConfig() {
err := c.Load(c.path) err := c.Load(c.path)
if err != nil { if err != nil {
l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config") c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
return return
} }
@ -500,7 +502,7 @@ func configLogger(c *Config) error {
if err != nil { if err != nil {
return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels) return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
} }
l.SetLevel(logLevel) c.l.SetLevel(logLevel)
disableTimestamp := c.GetBool("logging.disable_timestamp", false) disableTimestamp := c.GetBool("logging.disable_timestamp", false)
timestampFormat := c.GetString("logging.timestamp_format", "") timestampFormat := c.GetString("logging.timestamp_format", "")
@ -512,13 +514,13 @@ func configLogger(c *Config) error {
logFormat := strings.ToLower(c.GetString("logging.format", "text")) logFormat := strings.ToLower(c.GetString("logging.format", "text"))
switch logFormat { switch logFormat {
case "text": case "text":
l.Formatter = &logrus.TextFormatter{ c.l.Formatter = &logrus.TextFormatter{
TimestampFormat: timestampFormat, TimestampFormat: timestampFormat,
FullTimestamp: fullTimestamp, FullTimestamp: fullTimestamp,
DisableTimestamp: disableTimestamp, DisableTimestamp: disableTimestamp,
} }
case "json": case "json":
l.Formatter = &logrus.JSONFormatter{ c.l.Formatter = &logrus.JSONFormatter{
TimestampFormat: timestampFormat, TimestampFormat: timestampFormat,
DisableTimestamp: disableTimestamp, DisableTimestamp: disableTimestamp,
} }

View File

@ -11,14 +11,15 @@ import (
) )
func TestConfig_Load(t *testing.T) { func TestConfig_Load(t *testing.T) {
l := NewTestLogger()
dir, err := ioutil.TempDir("", "config-test") dir, err := ioutil.TempDir("", "config-test")
// invalid yaml // invalid yaml
c := NewConfig() c := NewConfig(l)
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
// simple multi config merge // simple multi config merge
c = NewConfig() c = NewConfig(l)
os.RemoveAll(dir) os.RemoveAll(dir)
os.Mkdir(dir, 0755) os.Mkdir(dir, 0755)
@ -40,8 +41,9 @@ func TestConfig_Load(t *testing.T) {
} }
func TestConfig_Get(t *testing.T) { func TestConfig_Get(t *testing.T) {
l := NewTestLogger()
// test simple type // test simple type
c := NewConfig() c := NewConfig(l)
c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"} c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
assert.Equal(t, "hi", c.Get("firewall.outbound")) assert.Equal(t, "hi", c.Get("firewall.outbound"))
@ -55,13 +57,15 @@ func TestConfig_Get(t *testing.T) {
} }
func TestConfig_GetStringSlice(t *testing.T) { func TestConfig_GetStringSlice(t *testing.T) {
c := NewConfig() l := NewTestLogger()
c := NewConfig(l)
c.Settings["slice"] = []interface{}{"one", "two"} c.Settings["slice"] = []interface{}{"one", "two"}
assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{})) assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
} }
func TestConfig_GetBool(t *testing.T) { func TestConfig_GetBool(t *testing.T) {
c := NewConfig() l := NewTestLogger()
c := NewConfig(l)
c.Settings["bool"] = true c.Settings["bool"] = true
assert.Equal(t, true, c.GetBool("bool", false)) assert.Equal(t, true, c.GetBool("bool", false))
@ -88,7 +92,8 @@ func TestConfig_GetBool(t *testing.T) {
} }
func TestConfig_GetAllowList(t *testing.T) { func TestConfig_GetAllowList(t *testing.T) {
c := NewConfig() l := NewTestLogger()
c := NewConfig(l)
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0": true, "192.168.0.0": true,
} }
@ -181,20 +186,21 @@ func TestConfig_GetAllowList(t *testing.T) {
} }
func TestConfig_HasChanged(t *testing.T) { func TestConfig_HasChanged(t *testing.T) {
l := NewTestLogger()
// No reload has occurred, return false // No reload has occurred, return false
c := NewConfig() c := NewConfig(l)
c.Settings["test"] = "hi" c.Settings["test"] = "hi"
assert.False(t, c.HasChanged("")) assert.False(t, c.HasChanged(""))
// Test key change // Test key change
c = NewConfig() c = NewConfig(l)
c.Settings["test"] = "hi" c.Settings["test"] = "hi"
c.oldSettings = map[interface{}]interface{}{"test": "no"} c.oldSettings = map[interface{}]interface{}{"test": "no"}
assert.True(t, c.HasChanged("test")) assert.True(t, c.HasChanged("test"))
assert.True(t, c.HasChanged("")) assert.True(t, c.HasChanged(""))
// No key change // No key change
c = NewConfig() c = NewConfig(l)
c.Settings["test"] = "hi" c.Settings["test"] = "hi"
c.oldSettings = map[interface{}]interface{}{"test": "hi"} c.oldSettings = map[interface{}]interface{}{"test": "hi"}
assert.False(t, c.HasChanged("test")) assert.False(t, c.HasChanged("test"))
@ -202,12 +208,13 @@ func TestConfig_HasChanged(t *testing.T) {
} }
func TestConfig_ReloadConfig(t *testing.T) { func TestConfig_ReloadConfig(t *testing.T) {
l := NewTestLogger()
done := make(chan bool, 1) done := make(chan bool, 1)
dir, err := ioutil.TempDir("", "config-test") dir, err := ioutil.TempDir("", "config-test")
assert.Nil(t, err) assert.Nil(t, err)
ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
c := NewConfig() c := NewConfig(l)
assert.Nil(t, c.Load(dir)) assert.Nil(t, c.Load(dir))
assert.False(t, c.HasChanged("outer.inner")) assert.False(t, c.HasChanged("outer.inner"))

View File

@ -28,10 +28,11 @@ type connectionManager struct {
checkInterval int checkInterval int
pendingDeletionInterval int pendingDeletionInterval int
l *logrus.Logger
// I wanted to call one matLock // I wanted to call one matLock
} }
func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager { func newConnectionManager(l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
nc := &connectionManager{ nc := &connectionManager{
hostMap: intf.hostMap, hostMap: intf.hostMap,
in: make(map[uint32]struct{}), in: make(map[uint32]struct{}),
@ -47,6 +48,7 @@ func newConnectionManager(intf *Interface, checkInterval, pendingDeletionInterva
pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60), pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
checkInterval: checkInterval, checkInterval: checkInterval,
pendingDeletionInterval: pendingDeletionInterval, pendingDeletionInterval: pendingDeletionInterval,
l: l,
} }
nc.Start() nc.Start()
return nc return nc
@ -166,8 +168,8 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
// If we saw incoming packets from this ip, just return // If we saw incoming packets from this ip, just return
if traf { if traf {
if l.Level >= logrus.DebugLevel { if n.l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIP)). n.l.WithField("vpnIp", IntIp(vpnIP)).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status") Debug("Tunnel status")
} }
@ -179,13 +181,13 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
// If we didn't we may need to probe or destroy the conn // If we didn't we may need to probe or destroy the conn
hostinfo, err := n.hostMap.QueryVpnIP(vpnIP) hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
if err != nil { if err != nil {
l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
n.ClearIP(vpnIP) n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP) n.ClearPendingDeletion(vpnIP)
continue continue
} }
hostinfo.logger(). hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}). WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status") Debug("Tunnel status")
@ -194,7 +196,7 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out) n.intf.SendMessageToVpnIp(test, testRequest, vpnIP, p, nb, out)
} else { } else {
hostinfo.logger().Debugf("Hostinfo sadness: %s", IntIp(vpnIP)) hostinfo.logger(n.l).Debugf("Hostinfo sadness: %s", IntIp(vpnIP))
} }
n.AddPendingDeletion(vpnIP) n.AddPendingDeletion(vpnIP)
} }
@ -214,7 +216,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
// If we saw incoming packets from this ip, just return // If we saw incoming packets from this ip, just return
traf := n.CheckIn(vpnIP) traf := n.CheckIn(vpnIP)
if traf { if traf {
l.WithField("vpnIp", IntIp(vpnIP)). n.l.WithField("vpnIp", IntIp(vpnIP)).
WithField("tunnelCheck", m{"state": "alive", "method": "active"}). WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
Debug("Tunnel status") Debug("Tunnel status")
n.ClearIP(vpnIP) n.ClearIP(vpnIP)
@ -226,7 +228,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
if err != nil { if err != nil {
n.ClearIP(vpnIP) n.ClearIP(vpnIP)
n.ClearPendingDeletion(vpnIP) n.ClearPendingDeletion(vpnIP)
l.Debugf("Not found in hostmap: %s", IntIp(vpnIP)) n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
continue continue
} }
@ -236,7 +238,7 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil { if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
cn = hostinfo.ConnectionState.peerCert.Details.Name cn = hostinfo.ConnectionState.peerCert.Details.Name
} }
hostinfo.logger(). hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "dead", "method": "active"}). WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
WithField("certName", cn). WithField("certName", cn).
Info("Tunnel status") Info("Tunnel status")

View File

@ -13,6 +13,7 @@ import (
var vpnIP uint32 var vpnIP uint32
func Test_NewConnectionManagerTest(t *testing.T) { func Test_NewConnectionManagerTest(t *testing.T) {
l := NewTestLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
@ -20,7 +21,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects // Very incomplete mock objects
hostMap := NewHostMap("test", vpncidr, preferredRanges) hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
cs := &CertState{ cs := &CertState{
rawCertificate: []byte{}, rawCertificate: []byte{},
privateKey: []byte{}, privateKey: []byte{},
@ -28,7 +29,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
rawCertificateNoKey: []byte{}, rawCertificateNoKey: []byte{},
} }
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false) lh := NewLightHouse(l, false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &Tun{}, inside: &Tun{},
@ -36,12 +37,13 @@ func Test_NewConnectionManagerTest(t *testing.T) {
certState: cs, certState: cs,
firewall: &Firewall{}, firewall: &Firewall{},
lightHouse: lh, lightHouse: lh,
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig), handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
l: l,
} }
now := time.Now() now := time.Now()
// Create manager // Create manager
nc := newConnectionManager(ifce, 5, 10) nc := newConnectionManager(l, ifce, 5, 10)
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)
@ -79,13 +81,14 @@ func Test_NewConnectionManagerTest(t *testing.T) {
} }
func Test_NewConnectionManagerTest2(t *testing.T) { func Test_NewConnectionManagerTest2(t *testing.T) {
l := NewTestLogger()
//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
// Very incomplete mock objects // Very incomplete mock objects
hostMap := NewHostMap("test", vpncidr, preferredRanges) hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
cs := &CertState{ cs := &CertState{
rawCertificate: []byte{}, rawCertificate: []byte{},
privateKey: []byte{}, privateKey: []byte{},
@ -93,7 +96,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
rawCertificateNoKey: []byte{}, rawCertificateNoKey: []byte{},
} }
lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false) lh := NewLightHouse(l, false, 0, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
ifce := &Interface{ ifce := &Interface{
hostMap: hostMap, hostMap: hostMap,
inside: &Tun{}, inside: &Tun{},
@ -101,12 +104,13 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
certState: cs, certState: cs,
firewall: &Firewall{}, firewall: &Firewall{},
lightHouse: lh, lightHouse: lh,
handshakeManager: NewHandshakeManager(vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig), handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
l: l,
} }
now := time.Now() now := time.Now()
// Create manager // Create manager
nc := newConnectionManager(ifce, 5, 10) nc := newConnectionManager(l, ifce, 5, 10)
p := []byte("") p := []byte("")
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
out := make([]byte, mtu) out := make([]byte, mtu)

View File

@ -7,6 +7,7 @@ import (
"sync/atomic" "sync/atomic"
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
) )
@ -26,7 +27,7 @@ type ConnectionState struct {
ready bool ready bool
} }
func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256) cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256)
if f.cipher == "chachapoly" { if f.cipher == "chachapoly" {
cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256) cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
@ -37,7 +38,7 @@ func (f *Interface) newConnectionState(initiator bool, pattern noise.HandshakePa
b := NewBits(ReplayWindow) b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss // Clear out bit 0, we never transmit it and we don't want it showing as packet loss
b.Update(0) b.Update(l, 0)
hs, err := noise.NewHandshakeState(noise.Config{ hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: cs, CipherSuite: cs,

View File

@ -13,9 +13,10 @@ import (
) )
func TestControl_GetHostInfoByVpnIP(t *testing.T) { func TestControl_GetHostInfoByVpnIP(t *testing.T) {
l := NewTestLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller // To properly ensure we are not exposing core memory to the caller
hm := NewHostMap("test", &net.IPNet{}, make([]*net.IPNet, 0)) hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
remote1 := NewUDPAddr(int2ip(100), 4444) remote1 := NewUDPAddr(int2ip(100), 4444)
remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444) remote2 := NewUDPAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
ipNet := net.IPNet{ ipNet := net.IPNet{

View File

@ -7,6 +7,7 @@ import (
"sync" "sync"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/sirupsen/logrus"
) )
// This whole thing should be rewritten to use context // This whole thing should be rewritten to use context
@ -63,7 +64,7 @@ func (d *dnsRecords) Add(host, data string) {
d.Unlock() d.Unlock()
} }
func parseQuery(m *dns.Msg, w dns.ResponseWriter) { func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
for _, q := range m.Question { for _, q := range m.Question {
switch q.Qtype { switch q.Qtype {
case dns.TypeA: case dns.TypeA:
@ -95,34 +96,38 @@ func parseQuery(m *dns.Msg, w dns.ResponseWriter) {
} }
} }
func handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)
m.Compress = false m.Compress = false
switch r.Opcode { switch r.Opcode {
case dns.OpcodeQuery: case dns.OpcodeQuery:
parseQuery(m, w) parseQuery(l, m, w)
} }
w.WriteMsg(m) w.WriteMsg(m)
} }
func dnsMain(hostMap *HostMap, c *Config) { func dnsMain(l *logrus.Logger, hostMap *HostMap, c *Config) {
dnsR = newDnsRecords(hostMap) dnsR = newDnsRecords(hostMap)
// attach request handler func // attach request handler func
dns.HandleFunc(".", handleDnsRequest) dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
handleDnsRequest(l, w, r)
})
c.RegisterReloadCallback(reloadDns) c.RegisterReloadCallback(func(c *Config) {
startDns(c) reloadDns(l, c)
})
startDns(l, c)
} }
func getDnsServerAddr(c *Config) string { func getDnsServerAddr(c *Config) string {
return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)) return c.GetString("lighthouse.dns.host", "") + ":" + strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))
} }
func startDns(c *Config) { func startDns(l *logrus.Logger, c *Config) {
dnsAddr = getDnsServerAddr(c) dnsAddr = getDnsServerAddr(c)
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"} dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
l.Debugf("Starting DNS responder at %s\n", dnsAddr) l.Debugf("Starting DNS responder at %s\n", dnsAddr)
@ -133,7 +138,7 @@ func startDns(c *Config) {
} }
} }
func reloadDns(c *Config) { func reloadDns(l *logrus.Logger, c *Config) {
if dnsAddr == getDnsServerAddr(c) { if dnsAddr == getDnsServerAddr(c) {
l.Debug("No DNS server config change detected") l.Debug("No DNS server config change detected")
return return
@ -141,5 +146,5 @@ func reloadDns(c *Config) {
l.Debug("Restarting DNS server") l.Debug("Restarting DNS server")
dnsServer.Shutdown() dnsServer.Shutdown()
go startDns(c) go startDns(l, c)
} }

View File

@ -70,6 +70,7 @@ type Firewall struct {
trackTCPRTT bool trackTCPRTT bool
metricTCPRTT metrics.Histogram metricTCPRTT metrics.Histogram
l *logrus.Logger
} }
type FirewallConntrack struct { type FirewallConntrack struct {
@ -156,7 +157,7 @@ func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
} }
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall { func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
//TODO: error on 0 duration //TODO: error on 0 duration
var min, max time.Duration var min, max time.Duration
@ -195,11 +196,13 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
DefaultTimeout: defaultTimeout, DefaultTimeout: defaultTimeout,
localIps: localIps, localIps: localIps,
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)), metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
l: l,
} }
} }
func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) { func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
fw := NewFirewall( fw := NewFirewall(
l,
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
@ -207,12 +210,12 @@ func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, er
//TODO: max_connections //TODO: max_connections
) )
err := AddFirewallRulesFromConfig(false, c, fw) err := AddFirewallRulesFromConfig(l, false, c, fw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = AddFirewallRulesFromConfig(true, c, fw) err = AddFirewallRulesFromConfig(l, true, c, fw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -240,7 +243,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming { if !incoming {
direction = "outgoing" direction = "outgoing"
} }
l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}). f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
Info("Firewall rule added") Info("Firewall rule added")
var ( var (
@ -276,7 +279,7 @@ func (f *Firewall) GetRuleHash() string {
return hex.EncodeToString(sum[:]) return hex.EncodeToString(sum[:])
} }
func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterface) error { func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error {
var table string var table string
if inbound { if inbound {
table = "firewall.inbound" table = "firewall.inbound"
@ -296,7 +299,7 @@ func AddFirewallRulesFromConfig(inbound bool, config *Config, fw FirewallInterfa
for i, t := range rs { for i, t := range rs {
var groups []string var groups []string
r, err := convertRule(t, table, i) r, err := convertRule(l, t, table, i)
if err != nil { if err != nil {
return fmt.Errorf("%s rule #%v; %s", table, i, err) return fmt.Errorf("%s rule #%v; %s", table, i, err)
} }
@ -459,8 +462,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
// We now know which firewall table to check against // We now know which firewall table to check against
if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) { if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
if l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
h.logger(). h.logger(f.l).
WithField("fwPacket", fp). WithField("fwPacket", fp).
WithField("incoming", c.incoming). WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion). WithField("rulesVersion", f.rulesVersion).
@ -472,8 +475,8 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
return false return false
} }
if l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
h.logger(). h.logger(f.l).
WithField("fwPacket", fp). WithField("fwPacket", fp).
WithField("incoming", c.incoming). WithField("incoming", c.incoming).
WithField("rulesVersion", f.rulesVersion). WithField("rulesVersion", f.rulesVersion).
@ -795,7 +798,7 @@ type rule struct {
CASha string CASha string
} }
func convertRule(p interface{}, table string, i int) (rule, error) { func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
r := rule{} r := rule{}
m, ok := p.(map[interface{}]interface{}) m, ok := p.(map[interface{}]interface{})
@ -968,14 +971,14 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) {
// Get checks if the cache ticker has moved to the next version before returning // Get checks if the cache ticker has moved to the next version before returning
// the map. If it has moved, we reset the map. // the map. If it has moved, we reset the map.
func (c *ConntrackCacheTicker) Get() ConntrackCache { func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
if c == nil { if c == nil {
return nil return nil
} }
if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV { if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
c.cacheV = tick c.cacheV = tick
if ll := len(c.cache); ll > 0 { if ll := len(c.cache); ll > 0 {
if l.GetLevel() == logrus.DebugLevel { if l.Level == logrus.DebugLevel {
l.WithField("len", ll).Debug("resetting conntrack cache") l.WithField("len", ll).Debug("resetting conntrack cache")
} }
c.cache = make(ConntrackCache, ll) c.cache = make(ConntrackCache, ll)

View File

@ -15,8 +15,9 @@ import (
) )
func TestNewFirewall(t *testing.T) { func TestNewFirewall(t *testing.T) {
l := NewTestLogger()
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
fw := NewFirewall(time.Second, time.Minute, time.Hour, c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
conntrack := fw.Conntrack conntrack := fw.Conntrack
assert.NotNil(t, conntrack) assert.NotNil(t, conntrack)
assert.NotNil(t, conntrack.Conns) assert.NotNil(t, conntrack.Conns)
@ -31,35 +32,34 @@ func TestNewFirewall(t *testing.T) {
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Second, time.Hour, time.Minute, c) fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Second, time.Minute, c) fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Hour, time.Minute, time.Second, c) fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Hour, time.Second, c) fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(time.Minute, time.Second, time.Hour, c) fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen) assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
} }
func TestFirewall_AddRule(t *testing.T) { func TestFirewall_AddRule(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob) l.SetOutput(ob)
defer l.SetOutput(out)
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
fw := NewFirewall(time.Second, time.Minute, time.Hour, c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules) assert.NotNil(t, fw.OutRules)
@ -74,7 +74,7 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right) assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value) assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
assert.False(t, fw.InRules.UDP[1].Any.Any) assert.False(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1") assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
@ -83,7 +83,7 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right) assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value) assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", "")) assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
assert.False(t, fw.InRules.ICMP[1].Any.Any) assert.False(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
@ -92,23 +92,23 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right) assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value) assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", "")) assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
assert.False(t, fw.OutRules.AnyProto[1].Any.Any) assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts) assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP))) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
fw = NewFirewall(time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", "")) assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha")) assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
// Set any and clear fields // Set any and clear fields
fw = NewFirewall(time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", "")) assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0]) assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1") assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
@ -125,26 +125,25 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right) assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value) assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", "")) assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
fw = NewFirewall(time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
_, anyIp, _ := net.ParseCIDR("0.0.0.0/0") _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", "")) assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", "")) assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", "")) assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
} }
func TestFirewall_Drop(t *testing.T) { func TestFirewall_Drop(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob) l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{ p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)), ip2int(net.IPv4(1, 2, 3, 4)),
@ -177,7 +176,7 @@ func TestFirewall_Drop(t *testing.T) {
} }
h.CreateRemoteCIDR(&c) h.CreateRemoteCIDR(&c)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
@ -196,27 +195,27 @@ func TestFirewall_Drop(t *testing.T) {
p.RemoteIP = oldRemote p.RemoteIP = oldRemote
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil)) assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
@ -317,10 +316,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
} }
func TestFirewall_Drop2(t *testing.T) { func TestFirewall_Drop2(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob) l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{ p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)), ip2int(net.IPv4(1, 2, 3, 4)),
@ -365,7 +363,7 @@ func TestFirewall_Drop2(t *testing.T) {
} }
h1.CreateRemoteCIDR(&c1) h1.CreateRemoteCIDR(&c1)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
@ -377,10 +375,9 @@ func TestFirewall_Drop2(t *testing.T) {
} }
func TestFirewall_Drop3(t *testing.T) { func TestFirewall_Drop3(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob) l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{ p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)), ip2int(net.IPv4(1, 2, 3, 4)),
@ -448,7 +445,7 @@ func TestFirewall_Drop3(t *testing.T) {
} }
h3.CreateRemoteCIDR(&c3) h3.CreateRemoteCIDR(&c3)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
cp := cert.NewCAPool() cp := cert.NewCAPool()
@ -464,10 +461,9 @@ func TestFirewall_Drop3(t *testing.T) {
} }
func TestFirewall_DropConntrackReload(t *testing.T) { func TestFirewall_DropConntrackReload(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob) l.SetOutput(ob)
defer l.SetOutput(out)
p := FirewallPacket{ p := FirewallPacket{
ip2int(net.IPv4(1, 2, 3, 4)), ip2int(net.IPv4(1, 2, 3, 4)),
@ -500,7 +496,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
} }
h.CreateRemoteCIDR(&c) h.CreateRemoteCIDR(&c)
fw := NewFirewall(time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
@ -513,7 +509,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
oldFw := fw oldFw := fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@ -522,7 +518,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil)) assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
oldFw = fw oldFw = fw
fw = NewFirewall(time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", "")) assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@ -647,124 +643,126 @@ func Test_parsePort(t *testing.T) {
} }
func TestNewFirewallFromConfig(t *testing.T) { func TestNewFirewallFromConfig(t *testing.T) {
l := NewTestLogger()
// Test a bad rule definition // Test a bad rule definition
c := &cert.NebulaCertificate{} c := &cert.NebulaCertificate{}
conf := NewConfig() conf := NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
_, err := NewFirewallFromConfig(c, conf) _, err := NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code // Test both port and code
conf = NewConfig() conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha // Test missing host, group, cidr, ca_name and ca_sha
conf = NewConfig() conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
_, err = NewFirewallFromConfig(c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided") assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
// Test code/port error // Test code/port error
conf = NewConfig() conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error // Test proto error
conf = NewConfig() conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error // Test cidr parse error
conf = NewConfig() conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh") assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
// Test both group and groups // Test both group and groups
conf = NewConfig() conf = NewConfig(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(c, conf) _, err = NewFirewallFromConfig(l, c, conf)
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
} }
func TestAddFirewallRulesFromConfig(t *testing.T) { func TestAddFirewallRulesFromConfig(t *testing.T) {
l := NewTestLogger()
// Test adding tcp rule // Test adding tcp rule
conf := NewConfig() conf := NewConfig(l)
mf := &mockFirewall{} mf := &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding udp rule // Test adding udp rule
conf = NewConfig() conf = NewConfig(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding icmp rule // Test adding icmp rule
conf = NewConfig() conf = NewConfig(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding any rule // Test adding any rule
conf = NewConfig() conf = NewConfig(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
// Test adding rule with ca_sha // Test adding rule with ca_sha
conf = NewConfig() conf = NewConfig(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name // Test adding rule with ca_name
conf = NewConfig() conf = NewConfig(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
// Test single group // Test single group
conf = NewConfig() conf = NewConfig(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
// Test single groups // Test single groups
conf = NewConfig() conf = NewConfig(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
// Test multiple AND groups // Test multiple AND groups
conf = NewConfig() conf = NewConfig(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf)) assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
// Test Add error // Test Add error
conf = NewConfig() conf = NewConfig(l)
mf = &mockFirewall{} mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error") mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.EqualError(t, AddFirewallRulesFromConfig(true, conf, mf), "firewall.inbound rule #0; `test error`") assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
} }
func TestTCPRTTTracking(t *testing.T) { func TestTCPRTTTracking(t *testing.T) {
@ -859,17 +857,16 @@ func TestTCPRTTTracking(t *testing.T) {
} }
func TestFirewall_convertRule(t *testing.T) { func TestFirewall_convertRule(t *testing.T) {
l := NewTestLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}
out := l.Out
l.SetOutput(ob) l.SetOutput(ob)
defer l.SetOutput(out)
// Ensure group array of 1 is converted and a warning is printed // Ensure group array of 1 is converted and a warning is printed
c := map[interface{}]interface{}{ c := map[interface{}]interface{}{
"group": []interface{}{"group1"}, "group": []interface{}{"group1"},
} }
r, err := convertRule(c, "test", 1) r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "group1", r.Group) assert.Equal(t, "group1", r.Group)
@ -880,7 +877,7 @@ func TestFirewall_convertRule(t *testing.T) {
"group": []interface{}{"group1", "group2"}, "group": []interface{}{"group1", "group2"},
} }
r, err = convertRule(c, "test", 1) r, err = convertRule(l, c, "test", 1)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided") assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
@ -890,7 +887,7 @@ func TestFirewall_convertRule(t *testing.T) {
"group": "group1", "group": "group1",
} }
r, err = convertRule(c, "test", 1) r, err = convertRule(l, c, "test", 1)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "group1", r.Group) assert.Equal(t, "group1", r.Group)
} }

View File

@ -7,7 +7,7 @@ const (
func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) { func HandleIncomingHandshake(f *Interface, addr *udpAddr, packet []byte, h *Header, hostinfo *HostInfo) {
if !f.lightHouse.remoteAllowList.Allow(addr.IP) { if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
return return
} }

View File

@ -27,7 +27,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
err := f.handshakeManager.AddIndexHostInfo(hostinfo) err := f.handshakeManager.AddIndexHostInfo(hostinfo)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return return
} }
@ -48,7 +48,7 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
hsBytes, err = proto.Marshal(hs) hsBytes, err = proto.Marshal(hs)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return return
} }
@ -58,14 +58,14 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
msg, _, _, err := ci.H.WriteMessage(header, hsBytes) msg, _, _, err := ci.H.WriteMessage(header, hsBytes)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)). f.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return return
} }
// We are sending handshake packet 1, so we don't expect to receive // We are sending handshake packet 1, so we don't expect to receive
// handshake packet 1 from the responder // handshake packet 1 from the responder
ci.window.Update(1) ci.window.Update(f.l, 1)
hostinfo.HandshakePacket[0] = msg hostinfo.HandshakePacket[0] = msg
hostinfo.HandshakeReady = true hostinfo.HandshakeReady = true
@ -74,13 +74,13 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
} }
func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) { func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
ci := f.newConnectionState(false, noise.HandshakeIX, []byte{}, 0) ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed // Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(1) ci.window.Update(f.l, 1)
msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) msg, _, _, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
if err != nil { if err != nil {
l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
return return
} }
@ -91,14 +91,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex) l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
*/ */
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
return return
} }
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
if err != nil { if err != nil {
l.WithError(err).WithField("udpAddr", addr). f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Info("Invalid certificate from host") Info("Invalid certificate from host")
return return
@ -108,16 +108,16 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
fingerprint, _ := remoteCert.Sha256Sum() fingerprint, _ := remoteCert.Sha256Sum()
if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) { if vpnIP == ip2int(f.certState.certificate.Details.Ips[0].IP) {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
return return
} }
myIndex, err := generateIndex() myIndex, err := generateIndex(f.l)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
@ -133,7 +133,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
HandshakePacket: make(map[uint8][]byte, 0), HandshakePacket: make(map[uint8][]byte, 0),
} }
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@ -145,7 +145,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
hsBytes, err := proto.Marshal(hs) hsBytes, err := proto.Marshal(hs)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
@ -155,13 +155,13 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2) header := HeaderEncode(make([]byte, HeaderLen), Version, uint8(handshake), handshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes) msg, dKey, eKey, err := ci.H.WriteMessage(header, hsBytes)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return return
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key") WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
@ -178,7 +178,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
// We are sending handshake packet 2, so we don't expect to receive // We are sending handshake packet 2, so we don't expect to receive
// handshake packet 2 from the initiator. // handshake packet 2 from the initiator.
ci.window.Update(2) ci.window.Update(f.l, 2)
ci.peerCert = remoteCert ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey) ci.dKey = NewNebulaCipherState(dKey)
@ -203,11 +203,11 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err := f.outside.WriteTo(msg, addr) err := f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr). f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message") WithError(err).Error("Failed to send handshake message")
} else { } else {
l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr). f.l.WithField("vpnIp", IntIp(existing.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent") Info("Handshake message sent")
} }
@ -215,7 +215,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
case ErrExistingHostInfo: case ErrExistingHostInfo:
// This means there was an existing tunnel and we didn't win // This means there was an existing tunnel and we didn't win
// handshake avoidance // handshake avoidance
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@ -227,7 +227,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
return return
case ErrLocalIndexCollision: case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@ -238,7 +238,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
default: default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here // And we forget to update it here
l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@ -252,14 +252,14 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1) f.messageMetrics.Tx(handshake, NebulaMessageSubType(msg[1]), 1)
err = f.outside.WriteTo(msg, addr) err = f.outside.WriteTo(msg, addr)
if err != nil { if err != nil {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake") WithError(err).Error("Failed to send handshake")
} else { } else {
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@ -267,7 +267,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
Info("Handshake message sent") Info("Handshake message sent")
} }
hostinfo.handshakeComplete() hostinfo.handshakeComplete(f.l)
return return
} }
@ -280,7 +280,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
defer hostinfo.Unlock() defer hostinfo.Unlock()
if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) { if bytes.Equal(hostinfo.HandshakePacket[2], packet[HeaderLen:]) {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Info("Already seen this handshake packet") Info("Already seen this handshake packet")
return false return false
@ -288,14 +288,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
// Mark packet 2 as seen so it doesn't show up as missed // Mark packet 2 as seen so it doesn't show up as missed
ci.window.Update(2) ci.window.Update(f.l, 2)
hostinfo.HandshakePacket[2] = make([]byte, len(packet[HeaderLen:])) hostinfo.HandshakePacket[2] = make([]byte, len(packet[HeaderLen:]))
copy(hostinfo.HandshakePacket[2], packet[HeaderLen:]) copy(hostinfo.HandshakePacket[2], packet[HeaderLen:])
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:]) msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[HeaderLen:])
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage") Error("Failed to call noise.ReadMessage")
@ -304,7 +304,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
// near future // near future
return false return false
} else if dKey == nil || eKey == nil { } else if dKey == nil || eKey == nil {
l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key") Error("Noise did not arrive at a key")
return true return true
@ -313,14 +313,14 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hs := &NebulaHandshake{} hs := &NebulaHandshake{}
err = proto.Unmarshal(msg, hs) err = proto.Unmarshal(msg, hs)
if err != nil || hs.Details == nil { if err != nil || hs.Details == nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
return true return true
} }
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert) remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr). f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Invalid certificate from host") Error("Invalid certificate from host")
return true return true
@ -330,7 +330,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
fingerprint, _ := remoteCert.Sha256Sum() fingerprint, _ := remoteCert.Sha256Sum()
duration := time.Since(hostinfo.handshakeStart).Nanoseconds() duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr). f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
WithField("certName", certName). WithField("certName", certName).
WithField("fingerprint", fingerprint). WithField("fingerprint", fingerprint).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@ -362,7 +362,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
hostinfo.CreateRemoteCIDR(remoteCert) hostinfo.CreateRemoteCIDR(remoteCert)
f.handshakeManager.Complete(hostinfo, f) f.handshakeManager.Complete(hostinfo, f)
hostinfo.handshakeComplete() hostinfo.handshakeComplete(f.l)
f.metricHandshakes.Update(duration) f.metricHandshakes.Update(duration)
return false return false

View File

@ -53,11 +53,12 @@ type HandshakeManager struct {
InboundHandshakeTimer *SystemTimerWheel InboundHandshakeTimer *SystemTimerWheel
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
l *logrus.Logger
} }
func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager { func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
return &HandshakeManager{ return &HandshakeManager{
pendingHostMap: NewHostMap("pending", tunCidr, preferredRanges), pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
mainHostMap: mainHostMap, mainHostMap: mainHostMap,
lightHouse: lightHouse, lightHouse: lightHouse,
outside: outside, outside: outside,
@ -70,6 +71,7 @@ func NewHandshakeManager(tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainH
InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)), InboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
messageMetrics: config.messageMetrics, messageMetrics: config.messageMetrics,
l: l,
} }
} }
@ -78,7 +80,7 @@ func (c *HandshakeManager) Run(f EncWriter) {
for { for {
select { select {
case vpnIP := <-c.trigger: case vpnIP := <-c.trigger:
l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered") c.l.WithField("vpnIp", IntIp(vpnIP)).Debug("HandshakeManager: triggered")
c.handleOutbound(vpnIP, f, true) c.handleOutbound(vpnIP, f, true)
case now := <-clockSource: case now := <-clockSource:
c.NextOutboundHandshakeTimerTick(now, f) c.NextOutboundHandshakeTimerTick(now, f)
@ -149,7 +151,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1) c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote) err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
if err != nil { if err != nil {
hostinfo.logger().WithField("udpAddr", hostinfo.remote). hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId). WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId). WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
@ -157,7 +159,7 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
} else { } else {
//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should //TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
// keep the real packet struct around for logging purposes // keep the real packet struct around for logging purposes
hostinfo.logger().WithField("udpAddr", hostinfo.remote). hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
WithField("initiatorIndex", hostinfo.localIndexId). WithField("initiatorIndex", hostinfo.localIndexId).
WithField("remoteIndex", hostinfo.remoteIndexId). WithField("remoteIndex", hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
@ -245,7 +247,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId { if found && existingRemoteIndex != nil && existingRemoteIndex.hostId != hostinfo.hostId {
// We have a collision, but this can happen since we can't control // We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note. // the remote ID. Just log about the situation as a note.
hostinfo.logger(). hostinfo.logger(c.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)). WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
Info("New host shadows existing host remoteIndex") Info("New host shadows existing host remoteIndex")
} }
@ -280,7 +282,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
if found && existingRemoteIndex != nil { if found && existingRemoteIndex != nil {
// We have a collision, but this can happen since we can't control // We have a collision, but this can happen since we can't control
// the remote ID. Just log about the situation as a note. // the remote ID. Just log about the situation as a note.
hostinfo.logger(). hostinfo.logger(c.l).
WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)). WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", IntIp(existingRemoteIndex.hostId)).
Info("New host shadows existing host remoteIndex") Info("New host shadows existing host remoteIndex")
} }
@ -298,7 +300,7 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
defer c.mainHostMap.RUnlock() defer c.mainHostMap.RUnlock()
for i := 0; i < 32; i++ { for i := 0; i < 32; i++ {
index, err := generateIndex() index, err := generateIndex(c.l)
if err != nil { if err != nil {
return err return err
} }
@ -336,7 +338,7 @@ func (c *HandshakeManager) EmitStats() {
// Utility functions below // Utility functions below
func generateIndex() (uint32, error) { func generateIndex(l *logrus.Logger) (uint32, error) {
b := make([]byte, 4) b := make([]byte, 4)
// Let zero mean we don't know the ID, so don't generate zero // Let zero mean we don't know the ID, so don't generate zero

View File

@ -12,15 +12,15 @@ import (
var ips []uint32 var ips []uint32
func Test_NewHandshakeManagerIndex(t *testing.T) { func Test_NewHandshakeManagerIndex(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))} ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap("test", vpncidr, preferredRanges) mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now() now := time.Now()
blah.NextInboundHandshakeTimerTick(now) blah.NextInboundHandshakeTimerTick(now)
@ -63,15 +63,16 @@ func Test_NewHandshakeManagerIndex(t *testing.T) {
} }
func Test_NewHandshakeManagerVpnIP(t *testing.T) { func Test_NewHandshakeManagerVpnIP(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))} ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{} mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges) mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now() now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw) blah.NextOutboundHandshakeTimerTick(now, mw)
@ -112,16 +113,17 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
} }
func Test_NewHandshakeManagerTrigger(t *testing.T) { func Test_NewHandshakeManagerTrigger(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := ip2int(net.ParseIP("172.1.1.2")) ip := ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{} mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges) mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
lh := &LightHouse{} lh := &LightHouse{}
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig) blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
now := time.Now() now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw) blah.NextOutboundHandshakeTimerTick(now, mw)
@ -162,15 +164,16 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
} }
func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) { func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
vpnIP = ip2int(net.ParseIP("172.1.1.2")) vpnIP = ip2int(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
mw := &mockEncWriter{} mw := &mockEncWriter{}
mainHM := NewHostMap("test", vpncidr, preferredRanges) mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now() now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw) blah.NextOutboundHandshakeTimerTick(now, mw)
@ -216,13 +219,14 @@ func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
} }
func Test_NewHandshakeManagerIndexcleanup(t *testing.T) { func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
l := NewTestLogger()
_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24") _, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange} preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap("test", vpncidr, preferredRanges) mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
blah := NewHandshakeManager(tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig) blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
now := time.Now() now := time.Now()
blah.NextInboundHandshakeTimerTick(now) blah.NextInboundHandshakeTimerTick(now)

View File

@ -33,6 +33,7 @@ type HostMap struct {
defaultRoute uint32 defaultRoute uint32
unsafeRoutes *CIDRTree unsafeRoutes *CIDRTree
metricsEnabled bool metricsEnabled bool
l *logrus.Logger
} }
type HostInfo struct { type HostInfo struct {
@ -83,7 +84,7 @@ type Probe struct {
Counter int Counter int
} }
func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap { func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
h := map[uint32]*HostInfo{} h := map[uint32]*HostInfo{}
i := map[uint32]*HostInfo{} i := map[uint32]*HostInfo{}
r := map[uint32]*HostInfo{} r := map[uint32]*HostInfo{}
@ -96,6 +97,7 @@ func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *
vpnCIDR: vpnCIDR, vpnCIDR: vpnCIDR,
defaultRoute: 0, defaultRoute: 0,
unsafeRoutes: NewCIDRTree(), unsafeRoutes: NewCIDRTree(),
l: l,
} }
return &m return &m
} }
@ -160,8 +162,8 @@ func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
} }
hm.Unlock() hm.Unlock()
if l.Level >= logrus.DebugLevel { if hm.l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}). hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap vpnIp deleted") Debug("Hostmap vpnIp deleted")
} }
} }
@ -173,8 +175,8 @@ func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
hm.RemoteIndexes[index] = h hm.RemoteIndexes[index] = h
hm.Unlock() hm.Unlock()
if l.Level > logrus.DebugLevel { if hm.l.Level > logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes), hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap remoteIndex added") Debug("Hostmap remoteIndex added")
} }
@ -188,8 +190,8 @@ func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
hm.RemoteIndexes[h.remoteIndexId] = h hm.RemoteIndexes[h.remoteIndexId] = h
hm.Unlock() hm.Unlock()
if l.Level > logrus.DebugLevel { if hm.l.Level > logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts), hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}). "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap vpnIp added") Debug("Hostmap vpnIp added")
} }
@ -212,8 +214,8 @@ func (hm *HostMap) DeleteIndex(index uint32) {
} }
hm.Unlock() hm.Unlock()
if l.Level >= logrus.DebugLevel { if hm.l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}). hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
Debug("Hostmap index deleted") Debug("Hostmap index deleted")
} }
} }
@ -236,8 +238,8 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
} }
hm.Unlock() hm.Unlock()
if l.Level >= logrus.DebugLevel { if hm.l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}). hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
Debug("Hostmap remote index deleted") Debug("Hostmap remote index deleted")
} }
} }
@ -269,8 +271,8 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
} }
hm.Unlock() hm.Unlock()
if l.Level >= logrus.DebugLevel { if hm.l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts), hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
"vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). "vpnIp": IntIp(hostinfo.hostId), "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
Debug("Hostmap hostInfo deleted") Debug("Hostmap hostInfo deleted")
} }
@ -313,8 +315,10 @@ func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo {
} }
i.remote = i.Remotes[0].addr i.remote = i.Remotes[0].addr
hm.Hosts[vpnIp] = i hm.Hosts[vpnIp] = i
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}). if hm.l.Level >= logrus.DebugLevel {
Debug("Hostmap remote ip added") hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap remote ip added")
}
} }
i.ForcePromoteBest(hm.preferredRanges) i.ForcePromoteBest(hm.preferredRanges)
hm.Unlock() hm.Unlock()
@ -377,8 +381,8 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
hm.Indexes[hostinfo.localIndexId] = hostinfo hm.Indexes[hostinfo.localIndexId] = hostinfo
hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
if l.Level >= logrus.DebugLevel { if hm.l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts), hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(hostinfo.hostId), "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}). "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": IntIp(hostinfo.hostId)}}).
Debug("Hostmap vpnIp added") Debug("Hostmap vpnIp added")
} }
@ -436,7 +440,7 @@ func (hm *HostMap) Punchy(conn *udpConn) {
func (hm *HostMap) addUnsafeRoutes(routes *[]route) { func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
for _, r := range *routes { for _, r := range *routes {
l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route") hm.l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via)) hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
} }
} }
@ -566,7 +570,7 @@ func (i *HostInfo) rotateRemote() {
i.remote = i.Remotes[0].addr i.remote = i.Remotes[0].addr
} }
func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) { func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
//TODO: return the error so we can log with more context //TODO: return the error so we can log with more context
if len(i.packetStore) < 100 { if len(i.packetStore) < 100 {
tempPacket := make([]byte, len(packet)) tempPacket := make([]byte, len(packet))
@ -574,14 +578,14 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket) //l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket}) i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
if l.Level >= logrus.DebugLevel { if l.Level >= logrus.DebugLevel {
i.logger(). i.logger(l).
WithField("length", len(i.packetStore)). WithField("length", len(i.packetStore)).
WithField("stored", true). WithField("stored", true).
Debugf("Packet store") Debugf("Packet store")
} }
} else if l.Level >= logrus.DebugLevel { } else if l.Level >= logrus.DebugLevel {
i.logger(). i.logger(l).
WithField("length", len(i.packetStore)). WithField("length", len(i.packetStore)).
WithField("stored", false). WithField("stored", false).
Debugf("Packet store") Debugf("Packet store")
@ -589,7 +593,7 @@ func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, pac
} }
// handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets // handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets
func (i *HostInfo) handshakeComplete() { func (i *HostInfo) handshakeComplete(l *logrus.Logger) {
//TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because: //TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because:
//TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send //TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send
//TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical //TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical
@ -601,7 +605,7 @@ func (i *HostInfo) handshakeComplete() {
atomic.StoreUint64(&i.ConnectionState.atomicMessageCounter, 2) atomic.StoreUint64(&i.ConnectionState.atomicMessageCounter, 2)
if l.Level >= logrus.DebugLevel { if l.Level >= logrus.DebugLevel {
i.logger().Debugf("Sending %d stored packets", len(i.packetStore)) i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore))
} }
if len(i.packetStore) > 0 { if len(i.packetStore) > 0 {
@ -689,7 +693,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
i.remoteCidr = remoteCidr i.remoteCidr = remoteCidr
} }
func (i *HostInfo) logger() *logrus.Entry { func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
if i == nil { if i == nil {
return logrus.NewEntry(l) return logrus.NewEntry(l)
} }
@ -804,7 +808,7 @@ func (d *HostInfoDest) ProbeReceived(probeCount int) {
// Utility functions // Utility functions
func localIps(allowList *AllowList) *[]net.IP { func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP {
//FIXME: This function is pretty garbage //FIXME: This function is pretty garbage
var ips []net.IP var ips []net.IP
ifaces, _ := net.Interfaces() ifaces, _ := net.Interfaces()

View File

@ -64,12 +64,13 @@ func TestHostInfoDestProbe(t *testing.T) {
*/ */
func TestHostmap(t *testing.T) { func TestHostmap(t *testing.T) {
l := NewTestLogger()
_, myNet, _ := net.ParseCIDR("10.128.0.0/16") _, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24") _, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
myNets := []*net.IPNet{myNet} myNets := []*net.IPNet{myNet}
preferredRanges := []*net.IPNet{localToMe} preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap("test", myNet, preferredRanges) m := NewHostMap(l, "test", myNet, preferredRanges)
a := NewUDPAddrFromString("10.127.0.3:11111") a := NewUDPAddrFromString("10.127.0.3:11111")
b := NewUDPAddrFromString("1.0.0.1:22222") b := NewUDPAddrFromString("1.0.0.1:22222")
@ -103,10 +104,11 @@ func TestHostmap(t *testing.T) {
} }
func TestHostmapdebug(t *testing.T) { func TestHostmapdebug(t *testing.T) {
l := NewTestLogger()
_, myNet, _ := net.ParseCIDR("10.128.0.0/16") _, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24") _, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
preferredRanges := []*net.IPNet{localToMe} preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap("test", myNet, preferredRanges) m := NewHostMap(l, "test", myNet, preferredRanges)
a := NewUDPAddrFromString("10.127.0.3:11111") a := NewUDPAddrFromString("10.127.0.3:11111")
b := NewUDPAddrFromString("1.0.0.1:22222") b := NewUDPAddrFromString("1.0.0.1:22222")
@ -151,11 +153,12 @@ func TestHostMap_rotateRemote(t *testing.T) {
} }
func BenchmarkHostmappromote2(b *testing.B) { func BenchmarkHostmappromote2(b *testing.B) {
l := NewTestLogger()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
_, myNet, _ := net.ParseCIDR("10.128.0.0/16") _, myNet, _ := net.ParseCIDR("10.128.0.0/16")
_, localToMe, _ := net.ParseCIDR("192.168.1.0/24") _, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
preferredRanges := []*net.IPNet{localToMe} preferredRanges := []*net.IPNet{localToMe}
m := NewHostMap("test", myNet, preferredRanges) m := NewHostMap(l, "test", myNet, preferredRanges)
y := NewUDPAddrFromString("10.128.0.3:11111") y := NewUDPAddrFromString("10.128.0.3:11111")
a := NewUDPAddrFromString("10.127.0.3:11111") a := NewUDPAddrFromString("10.127.0.3:11111")
g := NewUDPAddrFromString("1.0.0.1:22222") g := NewUDPAddrFromString("1.0.0.1:22222")

View File

@ -10,7 +10,7 @@ import (
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) { func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) {
err := newPacket(packet, false, fwPacket) err := newPacket(packet, false, fwPacket)
if err != nil { if err != nil {
l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
return return
} }
@ -31,8 +31,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
hostinfo := f.getOrHandshake(fwPacket.RemoteIP) hostinfo := f.getOrHandshake(fwPacket.RemoteIP)
if hostinfo == nil { if hostinfo == nil {
if l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)). f.l.WithField("vpnIp", IntIp(fwPacket.RemoteIP)).
WithField("fwPacket", fwPacket). WithField("fwPacket", fwPacket).
Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes") Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
} }
@ -45,7 +45,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
// the packet queue. // the packet queue.
ci.queueLock.Lock() ci.queueLock.Lock()
if !ci.ready { if !ci.ready {
hostinfo.cachePacket(message, 0, packet, f.sendMessageNow) hostinfo.cachePacket(f.l, message, 0, packet, f.sendMessageNow)
ci.queueLock.Unlock() ci.queueLock.Unlock()
return return
} }
@ -59,8 +59,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
f.lightHouse.Query(fwPacket.RemoteIP, f) f.lightHouse.Query(fwPacket.RemoteIP, f)
} }
} else if l.Level >= logrus.DebugLevel { } else if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(). hostinfo.logger(f.l).
WithField("fwPacket", fwPacket). WithField("fwPacket", fwPacket).
WithField("reason", dropReason). WithField("reason", dropReason).
Debugln("dropping outbound packet") Debugln("dropping outbound packet")
@ -104,7 +104,7 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
if ci == nil { if ci == nil {
// if we don't have a connection state, then send a handshake initiation // if we don't have a connection state, then send a handshake initiation
ci = f.newConnectionState(true, noise.HandshakeIX, []byte{}, 0) ci = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us. // FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
//ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0) //ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0)
hostinfo.ConnectionState = ci hostinfo.ConnectionState = ci
@ -135,15 +135,15 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
fp := &FirewallPacket{} fp := &FirewallPacket{}
err := newPacket(p, false, fp) err := newPacket(p, false, fp)
if err != nil { if err != nil {
l.Warnf("error while parsing outgoing packet for firewall check; %v", err) f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
return return
} }
// check if packet is in outbound fw rules // check if packet is in outbound fw rules
dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil) dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil)
if dropReason != nil { if dropReason != nil {
if l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
l.WithField("fwPacket", fp). f.l.WithField("fwPacket", fp).
WithField("reason", dropReason). WithField("reason", dropReason).
Debugln("dropping cached packet") Debugln("dropping cached packet")
} }
@ -160,8 +160,8 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp) hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil { if hostInfo == nil {
if l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIp)). f.l.WithField("vpnIp", IntIp(vpnIp)).
Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes") Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
} }
return return
@ -172,7 +172,7 @@ func (f *Interface) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
// the packet queue. // the packet queue.
hostInfo.ConnectionState.queueLock.Lock() hostInfo.ConnectionState.queueLock.Lock()
if !hostInfo.ConnectionState.ready { if !hostInfo.ConnectionState.ready {
hostInfo.cachePacket(t, st, p, f.sendMessageToVpnIp) hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToVpnIp)
hostInfo.ConnectionState.queueLock.Unlock() hostInfo.ConnectionState.queueLock.Unlock()
return return
} }
@ -191,8 +191,8 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) { func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
hostInfo := f.getOrHandshake(vpnIp) hostInfo := f.getOrHandshake(vpnIp)
if hostInfo == nil { if hostInfo == nil {
if l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", IntIp(vpnIp)). f.l.WithField("vpnIp", IntIp(vpnIp)).
Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes") Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
} }
return return
@ -203,7 +203,7 @@ func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubTyp
// the packet queue. // the packet queue.
hostInfo.ConnectionState.queueLock.Lock() hostInfo.ConnectionState.queueLock.Lock()
if !hostInfo.ConnectionState.ready { if !hostInfo.ConnectionState.ready {
hostInfo.cachePacket(t, st, p, f.sendMessageToAll) hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToAll)
hostInfo.ConnectionState.queueLock.Unlock() hostInfo.ConnectionState.queueLock.Unlock()
return return
} }
@ -247,8 +247,8 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.Query(hostinfo.hostId, f) f.lightHouse.Query(hostinfo.hostId, f)
hostinfo.lastRebindCount = f.rebindCount hostinfo.lastRebindCount = f.rebindCount
if l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter") f.l.WithField("vpnIp", hostinfo.hostId).Debug("Lighthouse update triggered for punch due to rebind counter")
} }
} }
@ -256,7 +256,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
//TODO: see above note on lock //TODO: see above note on lock
//ci.writeLock.Unlock() //ci.writeLock.Unlock()
if err != nil { if err != nil {
hostinfo.logger().WithError(err). hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).WithField("counter", c). WithField("udpAddr", remote).WithField("counter", c).
WithField("attemptedCounter", c). WithField("attemptedCounter", c).
Error("Failed to encrypt outgoing packet") Error("Failed to encrypt outgoing packet")
@ -265,7 +265,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
err = f.writers[q].WriteTo(out, remote) err = f.writers[q].WriteTo(out, remote)
if err != nil { if err != nil {
hostinfo.logger().WithError(err). hostinfo.logger(f.l).WithError(err).
WithField("udpAddr", remote).Error("Failed to write outgoing packet") WithField("udpAddr", remote).Error("Failed to write outgoing packet")
} }
return c return c

View File

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
) )
const mtu = 9001 const mtu = 9001
@ -42,6 +43,7 @@ type InterfaceConfig struct {
version string version string
ConntrackCacheTimeout time.Duration ConntrackCacheTimeout time.Duration
l *logrus.Logger
} }
type Interface struct { type Interface struct {
@ -73,6 +75,7 @@ type Interface struct {
metricHandshakes metrics.Histogram metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
l *logrus.Logger
} }
func NewInterface(c *InterfaceConfig) (*Interface, error) { func NewInterface(c *InterfaceConfig) (*Interface, error) {
@ -113,9 +116,10 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
messageMetrics: c.MessageMetrics, messageMetrics: c.MessageMetrics,
l: c.l,
} }
ifce.connectionManager = newConnectionManager(ifce, c.checkInterval, c.pendingDeletionInterval) ifce.connectionManager = newConnectionManager(c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
return ifce, nil return ifce, nil
} }
@ -125,10 +129,10 @@ func (f *Interface) run() {
addr, err := f.outside.LocalAddr() addr, err := f.outside.LocalAddr()
if err != nil { if err != nil {
l.WithError(err).Error("Failed to get udp listen address") f.l.WithError(err).Error("Failed to get udp listen address")
} }
l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()). f.l.WithField("interface", f.inside.DeviceName()).WithField("network", f.inside.CidrNet().String()).
WithField("build", f.version).WithField("udpAddr", addr). WithField("build", f.version).WithField("udpAddr", addr).
Info("Nebula interface is active") Info("Nebula interface is active")
@ -140,14 +144,14 @@ func (f *Interface) run() {
if i > 0 { if i > 0 {
reader, err = f.inside.NewMultiQueueReader() reader, err = f.inside.NewMultiQueueReader()
if err != nil { if err != nil {
l.Fatal(err) f.l.Fatal(err)
} }
} }
f.readers[i] = reader f.readers[i] = reader
} }
if err := f.inside.Activate(); err != nil { if err := f.inside.Activate(); err != nil {
l.Fatal(err) f.l.Fatal(err)
} }
// Launch n queues to read packets from udp // Launch n queues to read packets from udp
@ -187,12 +191,12 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
for { for {
n, err := reader.Read(packet) n, err := reader.Read(packet)
if err != nil { if err != nil {
l.WithError(err).Error("Error while reading outbound packet") f.l.WithError(err).Error("Error while reading outbound packet")
// This only seems to happen when something fatal happens to the fd, so exit. // This only seems to happen when something fatal happens to the fd, so exit.
os.Exit(2) os.Exit(2)
} }
f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get()) f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
} }
} }
@ -208,21 +212,21 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *Config) {
func (f *Interface) reloadCA(c *Config) { func (f *Interface) reloadCA(c *Config) {
// reload and check regardless // reload and check regardless
// todo: need mutex? // todo: need mutex?
newCAs, err := loadCAFromConfig(c) newCAs, err := loadCAFromConfig(f.l, c)
if err != nil { if err != nil {
l.WithError(err).Error("Could not refresh trusted CA certificates") f.l.WithError(err).Error("Could not refresh trusted CA certificates")
return return
} }
trustedCAs = newCAs trustedCAs = newCAs
l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed") f.l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed")
} }
func (f *Interface) reloadCertKey(c *Config) { func (f *Interface) reloadCertKey(c *Config) {
// reload and check in all cases // reload and check in all cases
cs, err := NewCertStateFromConfig(c) cs, err := NewCertStateFromConfig(c)
if err != nil { if err != nil {
l.WithError(err).Error("Could not refresh client cert") f.l.WithError(err).Error("Could not refresh client cert")
return return
} }
@ -230,24 +234,24 @@ func (f *Interface) reloadCertKey(c *Config) {
oldIPs := f.certState.certificate.Details.Ips oldIPs := f.certState.certificate.Details.Ips
newIPs := cs.certificate.Details.Ips newIPs := cs.certificate.Details.Ips
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old") f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
return return
} }
f.certState = cs f.certState = cs
l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk") f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
} }
func (f *Interface) reloadFirewall(c *Config) { func (f *Interface) reloadFirewall(c *Config) {
//TODO: need to trigger/detect if the certificate changed too //TODO: need to trigger/detect if the certificate changed too
if c.HasChanged("firewall") == false { if c.HasChanged("firewall") == false {
l.Debug("No firewall config change detected") f.l.Debug("No firewall config change detected")
return return
} }
fw, err := NewFirewallFromConfig(f.certState.certificate, c) fw, err := NewFirewallFromConfig(f.l, f.certState.certificate, c)
if err != nil { if err != nil {
l.WithError(err).Error("Error while creating firewall during reload") f.l.WithError(err).Error("Error while creating firewall during reload")
return return
} }
@ -260,7 +264,7 @@ func (f *Interface) reloadFirewall(c *Config) {
// If rulesVersion is back to zero, we have wrapped all the way around. Be // If rulesVersion is back to zero, we have wrapped all the way around. Be
// safe and just reset conntrack in this case. // safe and just reset conntrack in this case.
if fw.rulesVersion == 0 { if fw.rulesVersion == 0 {
l.WithField("firewallHash", fw.GetRuleHash()). f.l.WithField("firewallHash", fw.GetRuleHash()).
WithField("oldFirewallHash", oldFw.GetRuleHash()). WithField("oldFirewallHash", oldFw.GetRuleHash()).
WithField("rulesVersion", fw.rulesVersion). WithField("rulesVersion", fw.rulesVersion).
Warn("firewall rulesVersion has overflowed, resetting conntrack") Warn("firewall rulesVersion has overflowed, resetting conntrack")
@ -271,7 +275,7 @@ func (f *Interface) reloadFirewall(c *Config) {
f.firewall = fw f.firewall = fw
oldFw.Destroy() oldFw.Destroy()
l.WithField("firewallHash", fw.GetRuleHash()). f.l.WithField("firewallHash", fw.GetRuleHash()).
WithField("oldFirewallHash", oldFw.GetRuleHash()). WithField("oldFirewallHash", oldFw.GetRuleHash()).
WithField("rulesVersion", fw.rulesVersion). WithField("rulesVersion", fw.rulesVersion).
Info("New firewall has been installed") Info("New firewall has been installed")

View File

@ -48,6 +48,7 @@ type LightHouse struct {
metrics *MessageMetrics metrics *MessageMetrics
metricHolepunchTx metrics.Counter metricHolepunchTx metrics.Counter
l *logrus.Logger
} }
type EncWriter interface { type EncWriter interface {
@ -55,7 +56,7 @@ type EncWriter interface {
SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte)
} }
func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse { func NewLightHouse(l *logrus.Logger, amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort uint32, pc *udpConn, punchBack bool, punchDelay time.Duration, metricsEnabled bool) *LightHouse {
h := LightHouse{ h := LightHouse{
amLighthouse: amLighthouse, amLighthouse: amLighthouse,
myIp: myIp, myIp: myIp,
@ -67,6 +68,7 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, n
punchConn: pc, punchConn: pc,
punchBack: punchBack, punchBack: punchBack,
punchDelay: punchDelay, punchDelay: punchDelay,
l: l,
} }
if metricsEnabled { if metricsEnabled {
@ -126,7 +128,7 @@ func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
// Send a query to the lighthouses and hope for the best next time // Send a query to the lighthouses and hope for the best next time
query, err := proto.Marshal(NewLhQueryByInt(ip)) query, err := proto.Marshal(NewLhQueryByInt(ip))
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload") lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
return return
} }
@ -159,7 +161,7 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
lh.Lock() lh.Lock()
//l.Debugln(lh.addrMap) //l.Debugln(lh.addrMap)
delete(lh.addrMap, vpnIP) delete(lh.addrMap, vpnIP)
l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP)) lh.l.Debugf("deleting %s from lighthouse.", IntIp(vpnIP))
lh.Unlock() lh.Unlock()
} }
@ -181,7 +183,7 @@ func (lh *LightHouse) AddRemote(vpnIP uint32, toIp *udpAddr, static bool) {
} }
allow := lh.remoteAllowList.Allow(toIp.IP) allow := lh.remoteAllowList.Allow(toIp.IP)
l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow") lh.l.WithField("remoteIp", toIp).WithField("allow", allow).Debug("remoteAllowList.Allow")
if !allow { if !allow {
return return
} }
@ -270,7 +272,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
var v4 []*IpAndPort var v4 []*IpAndPort
var v6 []*Ip6AndPort var v6 []*Ip6AndPort
for _, e := range *localIps(lh.localAllowList) { for _, e := range *localIps(lh.l, lh.localAllowList) {
// Only add IPs that aren't my VPN/tun IP // Only add IPs that aren't my VPN/tun IP
if ip2int(e) != lh.myIp { if ip2int(e) != lh.myIp {
ipp := NewIpAndPort(e, lh.nebulaPort) ipp := NewIpAndPort(e, lh.nebulaPort)
@ -297,7 +299,7 @@ func (lh *LightHouse) SendUpdate(f EncWriter) {
for vpnIp := range lh.lighthouses { for vpnIp := range lh.lighthouses {
mm, err := proto.Marshal(m) mm, err := proto.Marshal(m)
if err != nil { if err != nil {
l.Debugf("Invalid marshal to update") lh.l.Debugf("Invalid marshal to update")
} }
//l.Error("LIGHTHOUSE PACKET SEND", mm) //l.Error("LIGHTHOUSE PACKET SEND", mm)
f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out) f.SendMessageToVpnIp(lightHouse, 0, vpnIp, mm, nb, out)
@ -368,14 +370,14 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
n := lhh.resetMeta() n := lhh.resetMeta()
err := proto.UnmarshalMerge(p, n) err := proto.UnmarshalMerge(p, n)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr). lh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
Error("Failed to unmarshal lighthouse packet") Error("Failed to unmarshal lighthouse packet")
//TODO: send recv_error? //TODO: send recv_error?
return return
} }
if n.Details == nil { if n.Details == nil {
l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr). lh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("udpAddr", rAddr).
Error("Invalid lighthouse update") Error("Invalid lighthouse update")
//TODO: send recv_error? //TODO: send recv_error?
return return
@ -387,7 +389,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
case NebulaMeta_HostQuery: case NebulaMeta_HostQuery:
// Exit if we don't answer queries // Exit if we don't answer queries
if !lh.amLighthouse { if !lh.amLighthouse {
l.Debugln("I don't answer queries, but received from: ", rAddr) lh.l.Debugln("I don't answer queries, but received from: ", rAddr)
return return
} }
@ -422,7 +424,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
reply, err := proto.Marshal(n) reply, err := proto.Marshal(n)
if err != nil { if err != nil {
l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply") lh.l.WithError(err).WithField("vpnIp", IntIp(vpnIp)).Error("Failed to marshal lighthouse host query reply")
return return
} }
lh.metricTx(NebulaMeta_HostQueryReply, 1) lh.metricTx(NebulaMeta_HostQueryReply, 1)
@ -431,7 +433,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
// This signals the other side to punch some zero byte udp packets // This signals the other side to punch some zero byte udp packets
ips, err = lh.Query(vpnIp, f) ips, err = lh.Query(vpnIp, f)
if err != nil { if err != nil {
l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch") lh.l.WithField("vpnIp", IntIp(vpnIp)).Debugln("Can't notify host to punch")
return return
} else { } else {
//l.Debugln("Notify host to punch", iap) //l.Debugln("Notify host to punch", iap)
@ -492,7 +494,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
case NebulaMeta_HostUpdateNotification: case NebulaMeta_HostUpdateNotification:
//Simple check that the host sent this not someone else //Simple check that the host sent this not someone else
if n.Details.VpnIp != vpnIp { if n.Details.VpnIp != vpnIp {
l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update") lh.l.WithField("vpnIp", IntIp(vpnIp)).WithField("answer", IntIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
return return
} }
@ -530,9 +532,9 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
}() }()
if l.Level >= logrus.DebugLevel { if lh.l.Level >= logrus.DebugLevel {
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp)) lh.l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
} }
} }
@ -549,9 +551,9 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
}() }()
if l.Level >= logrus.DebugLevel { if lh.l.Level >= logrus.DebugLevel {
//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp)) lh.l.Debugf("Punching on %d for %s", a.Port, IntIp(n.Details.VpnIp))
} }
} }
@ -561,7 +563,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []by
if lh.punchBack { if lh.punchBack {
go func() { go func() {
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp)) lh.l.Debugf("Sending a nebula test packet to vpn ip %s", IntIp(n.Details.VpnIp))
// TODO we have to allocate a new output buffer here since we are spawning a new goroutine // TODO we have to allocate a new output buffer here since we are spawning a new goroutine
// for each punchBack packet. We should move this into a timerwheel or a single goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine
// managed by a channel. // managed by a channel.

View File

@ -65,12 +65,13 @@ func TestSetipandportsfromudpaddrs(t *testing.T) {
} }
func Test_lhStaticMapping(t *testing.T) { func Test_lhStaticMapping(t *testing.T) {
l := NewTestLogger()
lh1 := "10.128.0.2" lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1) lh1IP := net.ParseIP(lh1)
udpServer, _ := NewListener("0.0.0.0", 0, true) udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) meh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true) meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
err := meh.ValidateLHStaticEntries() err := meh.ValidateLHStaticEntries()
assert.Nil(t, err) assert.Nil(t, err)
@ -78,19 +79,20 @@ func Test_lhStaticMapping(t *testing.T) {
lh2 := "10.128.0.3" lh2 := "10.128.0.3"
lh2IP := net.ParseIP(lh2) lh2IP := net.ParseIP(lh2)
meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false) meh = NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true) meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
err = meh.ValidateLHStaticEntries() err = meh.ValidateLHStaticEntries()
assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry") assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
} }
func BenchmarkLighthouseHandleRequest(b *testing.B) { func BenchmarkLighthouseHandleRequest(b *testing.B) {
l := NewTestLogger()
lh1 := "10.128.0.2" lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1) lh1IP := net.ParseIP(lh1)
udpServer, _ := NewListener("0.0.0.0", 0, true) udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
lh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) lh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
hAddr := NewUDPAddrFromString("4.5.6.7:12345") hAddr := NewUDPAddrFromString("4.5.6.7:12345")
hAddr2 := NewUDPAddrFromString("4.5.6.7:12346") hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
@ -136,7 +138,8 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
} }
func Test_lhRemoteAllowList(t *testing.T) { func Test_lhRemoteAllowList(t *testing.T) {
c := NewConfig() l := NewTestLogger()
c := NewConfig(l)
c.Settings["remoteallowlist"] = map[interface{}]interface{}{ c.Settings["remoteallowlist"] = map[interface{}]interface{}{
"10.20.0.0/12": false, "10.20.0.0/12": false,
} }
@ -146,9 +149,9 @@ func Test_lhRemoteAllowList(t *testing.T) {
lh1 := "10.128.0.2" lh1 := "10.128.0.2"
lh1IP := net.ParseIP(lh1) lh1IP := net.ParseIP(lh1)
udpServer, _ := NewListener("0.0.0.0", 0, true) udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
lh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false) lh := NewLightHouse(l, true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
lh.SetRemoteAllowList(allowList) lh.SetRemoteAllowList(allowList)
remote1 := "10.20.0.3" remote1 := "10.20.0.3"

29
main.go
View File

@ -11,13 +11,10 @@ import (
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
// The caller should provide a real logger, we have one just in case
var l = logrus.New()
type m map[string]interface{} type m map[string]interface{}
func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) { func Main(config *Config, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (*Control, error) {
l = logger l := logger
l.Formatter = &logrus.TextFormatter{ l.Formatter = &logrus.TextFormatter{
FullTimestamp: true, FullTimestamp: true,
} }
@ -46,7 +43,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
}) })
// trustedCAs is currently a global, so loadCA operates on that global directly // trustedCAs is currently a global, so loadCA operates on that global directly
trustedCAs, err = loadCAFromConfig(config) trustedCAs, err = loadCAFromConfig(l, config)
if err != nil { if err != nil {
//The errors coming out of loadCA are already nicely formatted //The errors coming out of loadCA are already nicely formatted
return nil, NewContextualError("Failed to load ca from config", nil, err) return nil, NewContextualError("Failed to load ca from config", nil, err)
@ -60,7 +57,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
} }
l.WithField("cert", cs.certificate).Debug("Client nebula certificate") l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
fw, err := NewFirewallFromConfig(cs.certificate, config) fw, err := NewFirewallFromConfig(l, cs.certificate, config)
if err != nil { if err != nil {
return nil, NewContextualError("Error while loading firewall rules", nil, err) return nil, NewContextualError("Error while loading firewall rules", nil, err)
} }
@ -78,9 +75,9 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
} }
ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
wireSSHReload(ssh, config) wireSSHReload(l, ssh, config)
if config.GetBool("sshd.enabled", false) { if config.GetBool("sshd.enabled", false) {
err = configSSH(ssh, config) err = configSSH(l, ssh, config)
if err != nil { if err != nil {
return nil, NewContextualError("Error while configuring the sshd", nil, err) return nil, NewContextualError("Error while configuring the sshd", nil, err)
} }
@ -136,6 +133,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l) tun = newDisabledTun(tunCidr, config.GetInt("tun.tx_queue", 500), config.GetBool("stats.message_metrics", false), l)
case tunFd != nil: case tunFd != nil:
tun, err = newTunFromFd( tun, err = newTunFromFd(
l,
*tunFd, *tunFd,
tunCidr, tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU), config.GetInt("tun.mtu", DEFAULT_MTU),
@ -145,6 +143,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
) )
default: default:
tun, err = newTun( tun, err = newTun(
l,
config.GetString("tun.dev", ""), config.GetString("tun.dev", ""),
tunCidr, tunCidr,
config.GetInt("tun.mtu", DEFAULT_MTU), config.GetInt("tun.mtu", DEFAULT_MTU),
@ -166,7 +165,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
if !configTest { if !configTest {
for i := 0; i < routines; i++ { for i := 0; i < routines; i++ {
udpServer, err := NewListener(config.GetString("listen.host", "0.0.0.0"), port, routines > 1) udpServer, err := NewListener(l, config.GetString("listen.host", "0.0.0.0"), port, routines > 1)
if err != nil { if err != nil {
return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err) return nil, NewContextualError("Failed to open udp listener", m{"queue": i}, err)
} }
@ -222,7 +221,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
} }
} }
hostMap := NewHostMap("main", tunCidr, preferredRanges) hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0")))) hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
hostMap.addUnsafeRoutes(&unsafeRoutes) hostMap.addUnsafeRoutes(&unsafeRoutes)
hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false) hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
@ -266,6 +265,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
} }
lightHouse := NewLightHouse( lightHouse := NewLightHouse(
l,
amLighthouse, amLighthouse,
ip2int(tunCidr.IP), ip2int(tunCidr.IP),
lighthouseHosts, lighthouseHosts,
@ -337,7 +337,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
messageMetrics: messageMetrics, messageMetrics: messageMetrics,
} }
handshakeManager := NewHandshakeManager(tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig) handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
lightHouse.handshakeTrigger = handshakeManager.trigger lightHouse.handshakeTrigger = handshakeManager.trigger
//TODO: These will be reused for psk //TODO: These will be reused for psk
@ -367,6 +367,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
version: buildVersion, version: buildVersion,
ConntrackCacheTimeout: conntrackCacheTimeout, ConntrackCacheTimeout: conntrackCacheTimeout,
l: l,
} }
switch ifConfig.Cipher { switch ifConfig.Cipher {
@ -395,7 +396,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
go lightHouse.LhUpdateWorker(ifce) go lightHouse.LhUpdateWorker(ifce)
} }
err = startStats(config, configTest) err = startStats(l, config, configTest)
if err != nil { if err != nil {
return nil, NewContextualError("Failed to start stats emitter", nil, err) return nil, NewContextualError("Failed to start stats emitter", nil, err)
} }
@ -407,12 +408,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
//TODO: check if we _should_ be emitting stats //TODO: check if we _should_ be emitting stats
go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10)) go ifce.emitStats(config.GetDuration("stats.interval", time.Second*10))
attachCommands(ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) attachCommands(l, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host // Start DNS server last to allow using the nebula IP as lighthouse.dns.host
if amLighthouse && serveDns { if amLighthouse && serveDns {
l.Debugln("Starting dns server") l.Debugln("Starting dns server")
go dnsMain(hostMap, config) go dnsMain(l, hostMap, config)
} }
return &Control{ifce, l}, nil return &Control{ifce, l}, nil

View File

@ -1 +1,30 @@
package nebula package nebula
import (
"io/ioutil"
"os"
"github.com/sirupsen/logrus"
)
func NewTestLogger() *logrus.Logger {
l := logrus.New()
v := os.Getenv("TEST_LOGS")
if v == "" {
l.SetOutput(ioutil.Discard)
return l
}
switch v {
case "1":
// This is the default level but we are being explicit
l.SetLevel(logrus.InfoLevel)
case "2":
l.SetLevel(logrus.DebugLevel)
case "3":
l.SetLevel(logrus.TraceLevel)
}
return l
}

View File

@ -24,7 +24,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that? // TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 { if len(packet) > 1 {
l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err) f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
} }
return return
} }
@ -57,7 +57,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb) d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
if err != nil { if err != nil {
hostinfo.logger().WithError(err).WithField("udpAddr", addr). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt lighthouse packet") Error("Failed to decrypt lighthouse packet")
@ -78,7 +78,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb) d, err := f.decrypt(hostinfo, header.MessageCounter, out, packet, header, nb)
if err != nil { if err != nil {
hostinfo.logger().WithError(err).WithField("udpAddr", addr). hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
WithField("packet", packet). WithField("packet", packet).
Error("Failed to decrypt test packet") Error("Failed to decrypt test packet")
@ -115,7 +115,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
return return
} }
hostinfo.logger().WithField("udpAddr", addr). hostinfo.logger(f.l).WithField("udpAddr", addr).
Info("Close tunnel received, tearing down.") Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo) f.closeTunnel(hostinfo)
@ -123,7 +123,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
default: default:
f.messageMetrics.Rx(header.Type, header.Subtype, 1) f.messageMetrics.Rx(header.Type, header.Subtype, 1)
hostinfo.logger().Debugf("Unexpected packet received from %s", addr) hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
return return
} }
@ -143,18 +143,18 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
if hostDidRoam(hostinfo.remote, addr) { if hostDidRoam(hostinfo.remote, addr) {
if !f.lightHouse.remoteAllowList.Allow(addr.IP) { if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
hostinfo.logger().WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming") hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
return return
} }
if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
} }
return return
} }
hostinfo.logger().WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
Info("Host roamed to new udp ip/port.") Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now() hostinfo.lastRoam = time.Now()
remoteCopy := *hostinfo.remote remoteCopy := *hostinfo.remote
@ -170,7 +170,7 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool { func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udpAddr, header *Header) bool {
// If connectionstate exists and the replay protector allows, process packet // If connectionstate exists and the replay protector allows, process packet
// Else, send recv errors for 300 seconds after a restart to allow fast reconnection. // Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
if ci == nil || !ci.window.Check(header.MessageCounter) { if ci == nil || !ci.window.Check(f.l, header.MessageCounter) {
f.sendRecvError(addr, header.RemoteIndex) f.sendRecvError(addr, header.RemoteIndex)
return false return false
} }
@ -247,8 +247,8 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return nil, err return nil, err
} }
if !hostinfo.ConnectionState.window.Update(mc) { if !hostinfo.ConnectionState.window.Update(f.l, mc) {
hostinfo.logger().WithField("header", header). hostinfo.logger(f.l).WithField("header", header).
Debugln("dropping out of window packet") Debugln("dropping out of window packet")
return nil, errors.New("out of window packet") return nil, errors.New("out of window packet")
} }
@ -261,7 +261,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb) out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
if err != nil { if err != nil {
hostinfo.logger().WithError(err).Error("Failed to decrypt packet") hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
//TODO: maybe after build 64 is out? 06/14/2018 - NB //TODO: maybe after build 64 is out? 06/14/2018 - NB
//f.sendRecvError(hostinfo.remote, header.RemoteIndex) //f.sendRecvError(hostinfo.remote, header.RemoteIndex)
return return
@ -269,21 +269,21 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
err = newPacket(out, true, fwPacket) err = newPacket(out, true, fwPacket)
if err != nil { if err != nil {
hostinfo.logger().WithError(err).WithField("packet", out). hostinfo.logger(f.l).WithError(err).WithField("packet", out).
Warnf("Error while validating inbound packet") Warnf("Error while validating inbound packet")
return return
} }
if !hostinfo.ConnectionState.window.Update(messageCounter) { if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
hostinfo.logger().WithField("fwPacket", fwPacket). hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
Debugln("dropping out of window packet") Debugln("dropping out of window packet")
return return
} }
dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache) dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache)
if dropReason != nil { if dropReason != nil {
if l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
hostinfo.logger().WithField("fwPacket", fwPacket). hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
WithField("reason", dropReason). WithField("reason", dropReason).
Debugln("dropping inbound packet") Debugln("dropping inbound packet")
} }
@ -293,7 +293,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
f.connectionManager.In(hostinfo.hostId) f.connectionManager.In(hostinfo.hostId)
_, err = f.readers[q].Write(out) _, err = f.readers[q].Write(out)
if err != nil { if err != nil {
l.WithError(err).Error("Failed to write to tun") f.l.WithError(err).Error("Failed to write to tun")
} }
} }
@ -303,16 +303,16 @@ func (f *Interface) sendRecvError(endpoint *udpAddr, index uint32) {
//TODO: this should be a signed message so we can trust that we should drop the index //TODO: this should be a signed message so we can trust that we should drop the index
b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0) b := HeaderEncode(make([]byte, HeaderLen), Version, uint8(recvError), 0, index, 0)
f.outside.WriteTo(b, endpoint) f.outside.WriteTo(b, endpoint)
if l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
l.WithField("index", index). f.l.WithField("index", index).
WithField("udpAddr", endpoint). WithField("udpAddr", endpoint).
Debug("Recv error sent") Debug("Recv error sent")
} }
} }
func (f *Interface) handleRecvError(addr *udpAddr, h *Header) { func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
if l.Level >= logrus.DebugLevel { if f.l.Level >= logrus.DebugLevel {
l.WithField("index", h.RemoteIndex). f.l.WithField("index", h.RemoteIndex).
WithField("udpAddr", addr). WithField("udpAddr", addr).
Debug("Recv error received") Debug("Recv error received")
} }
@ -322,7 +322,7 @@ func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex) hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex)
if err != nil { if err != nil {
l.Debugln(err, ": ", h.RemoteIndex) f.l.Debugln(err, ": ", h.RemoteIndex)
return return
} }
@ -333,7 +333,7 @@ func (f *Interface) handleRecvError(addr *udpAddr, h *Header) {
return return
} }
if hostinfo.remote != nil && hostinfo.remote.String() != addr.String() { if hostinfo.remote != nil && hostinfo.remote.String() != addr.String() {
l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
return return
} }

View File

@ -8,7 +8,8 @@ import (
) )
func TestNewPunchyFromConfig(t *testing.T) { func TestNewPunchyFromConfig(t *testing.T) {
c := NewConfig() l := NewTestLogger()
c := NewConfig(l)
// Test defaults // Test defaults
p := NewPunchyFromConfig(c) p := NewPunchyFromConfig(c)

20
ssh.go
View File

@ -44,10 +44,10 @@ type sshCreateTunnelFlags struct {
Address string Address string
} }
func wireSSHReload(ssh *sshd.SSHServer, c *Config) { func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) {
c.RegisterReloadCallback(func(c *Config) { c.RegisterReloadCallback(func(c *Config) {
if c.GetBool("sshd.enabled", false) { if c.GetBool("sshd.enabled", false) {
err := configSSH(ssh, c) err := configSSH(l, ssh, c)
if err != nil { if err != nil {
l.WithError(err).Error("Failed to reconfigure the sshd") l.WithError(err).Error("Failed to reconfigure the sshd")
ssh.Stop() ssh.Stop()
@ -58,7 +58,7 @@ func wireSSHReload(ssh *sshd.SSHServer, c *Config) {
}) })
} }
func configSSH(ssh *sshd.SSHServer, c *Config) error { func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *Config) error {
//TODO conntrack list //TODO conntrack list
//TODO print firewall rules or hash? //TODO print firewall rules or hash?
@ -149,7 +149,7 @@ func configSSH(ssh *sshd.SSHServer, c *Config) error {
return nil return nil
} }
func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) { func attachCommands(l *logrus.Logger, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "list-hostmap", Name: "list-hostmap",
ShortDescription: "List all known previously connected hosts", ShortDescription: "List all known previously connected hosts",
@ -225,13 +225,17 @@ func attachCommands(ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostM
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "log-level", Name: "log-level",
ShortDescription: "Gets or sets the current log level", ShortDescription: "Gets or sets the current log level",
Callback: sshLogLevel, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshLogLevel(l, fs, a, w)
},
}) })
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
Name: "log-format", Name: "log-format",
ShortDescription: "Gets or sets the current log format", ShortDescription: "Gets or sets the current log format",
Callback: sshLogFormat, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
return sshLogFormat(l, fs, a, w)
},
}) })
ssh.RegisterCommand(&sshd.Command{ ssh.RegisterCommand(&sshd.Command{
@ -629,7 +633,7 @@ func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error {
return err return err
} }
func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error { func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
} }
@ -643,7 +647,7 @@ func sshLogLevel(fs interface{}, a []string, w sshd.StringWriter) error {
return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
} }
func sshLogFormat(fs interface{}, a []string, w sshd.StringWriter) error { func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
if len(a) == 0 { if len(a) == 0 {
return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
} }

View File

@ -13,9 +13,10 @@ import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
) )
func startStats(c *Config, configTest bool) error { func startStats(l *logrus.Logger, c *Config, configTest bool) error {
mType := c.GetString("stats.type", "") mType := c.GetString("stats.type", "")
if mType == "" || mType == "none" { if mType == "" || mType == "none" {
return nil return nil
@ -28,9 +29,9 @@ func startStats(c *Config, configTest bool) error {
switch mType { switch mType {
case "graphite": case "graphite":
startGraphiteStats(interval, c, configTest) startGraphiteStats(l, interval, c, configTest)
case "prometheus": case "prometheus":
startPrometheusStats(interval, c, configTest) startPrometheusStats(l, interval, c, configTest)
default: default:
return fmt.Errorf("stats.type was not understood: %s", mType) return fmt.Errorf("stats.type was not understood: %s", mType)
} }
@ -44,7 +45,7 @@ func startStats(c *Config, configTest bool) error {
return nil return nil
} }
func startGraphiteStats(i time.Duration, c *Config, configTest bool) error { func startGraphiteStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
proto := c.GetString("stats.protocol", "tcp") proto := c.GetString("stats.protocol", "tcp")
host := c.GetString("stats.host", "") host := c.GetString("stats.host", "")
if host == "" { if host == "" {
@ -64,7 +65,7 @@ func startGraphiteStats(i time.Duration, c *Config, configTest bool) error {
return nil return nil
} }
func startPrometheusStats(i time.Duration, c *Config, configTest bool) error { func startPrometheusStats(l *logrus.Logger, i time.Duration, c *Config, configTest bool) error {
namespace := c.GetString("stats.namespace", "") namespace := c.GetString("stats.namespace", "")
subsystem := c.GetString("stats.subsystem", "") subsystem := c.GetString("stats.subsystem", "")

View File

@ -6,6 +6,7 @@ import (
"net" "net"
"os" "os"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -19,9 +20,10 @@ type Tun struct {
TXQueueLen int TXQueueLen int
Routes []route Routes []route
UnsafeRoutes []route UnsafeRoutes []route
l *logrus.Logger
} }
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
ifce = &Tun{ ifce = &Tun{
@ -33,6 +35,7 @@ func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route,
TXQueueLen: txQueueLen, TXQueueLen: txQueueLen,
Routes: routes, Routes: routes,
UnsafeRoutes: unsafeRoutes, UnsafeRoutes: unsafeRoutes,
l: l,
} }
return return
} }

View File

@ -9,6 +9,7 @@ import (
"os/exec" "os/exec"
"strconv" "strconv"
"github.com/sirupsen/logrus"
"github.com/songgao/water" "github.com/songgao/water"
) )
@ -17,11 +18,11 @@ type Tun struct {
Cidr *net.IPNet Cidr *net.IPNet
MTU int MTU int
UnsafeRoutes []route UnsafeRoutes []route
l *logrus.Logger
*water.Interface *water.Interface
} }
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) { func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
if len(routes) > 0 { if len(routes) > 0 {
return nil, fmt.Errorf("route MTU not supported in Darwin") return nil, fmt.Errorf("route MTU not supported in Darwin")
} }
@ -31,10 +32,11 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
Cidr: cidr, Cidr: cidr,
MTU: defaultMTU, MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes, UnsafeRoutes: unsafeRoutes,
l: l,
}, nil }, nil
} }
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin") return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
} }

View File

@ -9,24 +9,23 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
) )
type disabledTun struct { type disabledTun struct {
read chan []byte read chan []byte
cidr *net.IPNet cidr *net.IPNet
logger *log.Logger
// Track these metrics since we don't have the tun device to do it for us // Track these metrics since we don't have the tun device to do it for us
tx metrics.Counter tx metrics.Counter
rx metrics.Counter rx metrics.Counter
l *logrus.Logger
} }
func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *log.Logger) *disabledTun { func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
tun := &disabledTun{ tun := &disabledTun{
cidr: cidr, cidr: cidr,
read: make(chan []byte, queueLen), read: make(chan []byte, queueLen),
logger: l, l: l,
} }
if metricsEnabled { if metricsEnabled {
@ -63,8 +62,8 @@ func (t *disabledTun) Read(b []byte) (int, error) {
} }
t.tx.Inc(1) t.tx.Inc(1)
if l.Level >= logrus.DebugLevel { if t.l.Level >= logrus.DebugLevel {
t.logger.WithField("raw", prettyPacket(r)).Debugf("Write payload") t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload")
} }
return copy(b, r), nil return copy(b, r), nil
@ -103,7 +102,7 @@ func (t *disabledTun) handleICMPEchoRequest(b []byte) bool {
select { select {
case t.read <- buf: case t.read <- buf:
default: default:
t.logger.Debugf("tun_disabled: dropped ICMP Echo Reply response") t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response")
} }
return true return true
@ -114,11 +113,11 @@ func (t *disabledTun) Write(b []byte) (int, error) {
// Check for ICMP Echo Request before spending time doing the full parsing // Check for ICMP Echo Request before spending time doing the full parsing
if t.handleICMPEchoRequest(b) { if t.handleICMPEchoRequest(b) {
if l.Level >= logrus.DebugLevel { if t.l.Level >= logrus.DebugLevel {
t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request") t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request")
} }
} else if l.Level >= logrus.DebugLevel { } else if t.l.Level >= logrus.DebugLevel {
t.logger.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload") t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload")
} }
return len(b), nil return len(b), nil
} }

View File

@ -9,6 +9,8 @@ import (
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"github.com/sirupsen/logrus"
) )
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
@ -18,15 +20,16 @@ type Tun struct {
Cidr *net.IPNet Cidr *net.IPNet
MTU int MTU int
UnsafeRoutes []route UnsafeRoutes []route
l *logrus.Logger
io.ReadWriteCloser io.ReadWriteCloser
} }
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
} }
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) { func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
if len(routes) > 0 { if len(routes) > 0 {
return nil, fmt.Errorf("Route MTU not supported in FreeBSD") return nil, fmt.Errorf("Route MTU not supported in FreeBSD")
} }
@ -41,6 +44,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
Cidr: cidr, Cidr: cidr,
MTU: defaultMTU, MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes, UnsafeRoutes: unsafeRoutes,
l: l,
}, nil }, nil
} }
@ -52,21 +56,21 @@ func (c *Tun) Activate() error {
} }
// TODO use syscalls instead of exec.Command // TODO use syscalls instead of exec.Command
l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()) c.l.Debug("command: ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String())
if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil { if err = exec.Command("/sbin/ifconfig", c.Device, c.Cidr.String(), c.Cidr.IP.String()).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device) c.l.Debug("command: route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil { if err = exec.Command("/sbin/route", "-n", "add", "-net", c.Cidr.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add': %s", err) return fmt.Errorf("failed to run 'route add': %s", err)
} }
l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)) c.l.Debug("command: ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU))
if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil { if err = exec.Command("/sbin/ifconfig", c.Device, "mtu", strconv.Itoa(c.MTU)).Run(); err != nil {
return fmt.Errorf("failed to run 'ifconfig': %s", err) return fmt.Errorf("failed to run 'ifconfig': %s", err)
} }
// Unsafe path routes // Unsafe path routes
for _, r := range c.UnsafeRoutes { for _, r := range c.UnsafeRoutes {
l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device) c.l.Debug("command: route", "-n", "add", "-net", r.route.String(), "-interface", c.Device)
if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil { if err = exec.Command("/sbin/route", "-n", "add", "-net", r.route.String(), "-interface", c.Device).Run(); err != nil {
return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err) return fmt.Errorf("failed to run 'route add' for unsafe_route %s: %s", r.route.String(), err)
} }

View File

@ -10,6 +10,7 @@ import (
"strings" "strings"
"unsafe" "unsafe"
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -24,6 +25,7 @@ type Tun struct {
TXQueueLen int TXQueueLen int
Routes []route Routes []route
UnsafeRoutes []route UnsafeRoutes []route
l *logrus.Logger
} }
type ifReq struct { type ifReq struct {
@ -78,7 +80,7 @@ type ifreqQLEN struct {
pad [8]byte pad [8]byte
} }
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@ -91,11 +93,12 @@ func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route,
TXQueueLen: txQueueLen, TXQueueLen: txQueueLen,
Routes: routes, Routes: routes,
UnsafeRoutes: unsafeRoutes, UnsafeRoutes: unsafeRoutes,
l: l,
} }
return return
} }
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) { func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil { if err != nil {
return nil, err return nil, err
@ -131,6 +134,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
TXQueueLen: txQueueLen, TXQueueLen: txQueueLen,
Routes: routes, Routes: routes,
UnsafeRoutes: unsafeRoutes, UnsafeRoutes: unsafeRoutes,
l: l,
} }
return return
} }
@ -233,14 +237,14 @@ func (c Tun) Activate() error {
ifm := ifreqMTU{Name: devName, MTU: int32(c.MaxMTU)} ifm := ifreqMTU{Name: devName, MTU: int32(c.MaxMTU)}
if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
// This is currently a non fatal condition because the route table must have the MTU set appropriately as well // This is currently a non fatal condition because the route table must have the MTU set appropriately as well
l.WithError(err).Error("Failed to set tun mtu") c.l.WithError(err).Error("Failed to set tun mtu")
} }
// Set the transmit queue length // Set the transmit queue length
ifrq := ifreqQLEN{Name: devName, Value: int32(c.TXQueueLen)} ifrq := ifreqQLEN{Name: devName, Value: int32(c.TXQueueLen)}
if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
// If we can't set the queue length nebula will still work but it may lead to packet loss // If we can't set the queue length nebula will still work but it may lead to packet loss
l.WithError(err).Error("Failed to set tun tx queue length") c.l.WithError(err).Error("Failed to set tun tx queue length")
} }
// Bring up the interface // Bring up the interface

View File

@ -9,7 +9,8 @@ import (
) )
func Test_parseRoutes(t *testing.T) { func Test_parseRoutes(t *testing.T) {
c := NewConfig() l := NewTestLogger()
c := NewConfig(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24") _, n, _ := net.ParseCIDR("10.0.0.0/24")
// test no routes config // test no routes config
@ -104,7 +105,8 @@ func Test_parseRoutes(t *testing.T) {
} }
func Test_parseUnsafeRoutes(t *testing.T) { func Test_parseUnsafeRoutes(t *testing.T) {
c := NewConfig() l := NewTestLogger()
c := NewConfig(l)
_, n, _ := net.ParseCIDR("10.0.0.0/24") _, n, _ := net.ParseCIDR("10.0.0.0/24")
// test no routes config // test no routes config

View File

@ -7,6 +7,7 @@ import (
"os/exec" "os/exec"
"strconv" "strconv"
"github.com/sirupsen/logrus"
"github.com/songgao/water" "github.com/songgao/water"
) )
@ -15,15 +16,16 @@ type Tun struct {
Cidr *net.IPNet Cidr *net.IPNet
MTU int MTU int
UnsafeRoutes []route UnsafeRoutes []route
l *logrus.Logger
*water.Interface *water.Interface
} }
func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { func newTunFromFd(l *logrus.Logger, deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows") return nil, fmt.Errorf("newTunFromFd not supported in Windows")
} }
func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) { func newTun(l *logrus.Logger, deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int, multiqueue bool) (ifce *Tun, err error) {
if len(routes) > 0 { if len(routes) > 0 {
return nil, fmt.Errorf("route MTU not supported in Windows") return nil, fmt.Errorf("route MTU not supported in Windows")
} }
@ -33,6 +35,7 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route,
Cidr: cidr, Cidr: cidr,
MTU: defaultMTU, MTU: defaultMTU,
UnsafeRoutes: unsafeRoutes, UnsafeRoutes: unsafeRoutes,
l: l,
}, nil }, nil
} }

View File

@ -1,3 +1,5 @@
// +build !e2e_testing
package nebula package nebula
import ( import (

View File

@ -1,3 +1,5 @@
// +build !e2e_testing
package nebula package nebula
// Darwin support is primarily implemented in udp_generic, besides NewListenConfig // Darwin support is primarily implemented in udp_generic, besides NewListenConfig

View File

@ -1,3 +1,5 @@
// +build !e2e_testing
package nebula package nebula
// FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig // FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig

View File

@ -1,4 +1,5 @@
// +build !linux android // +build !linux android
// +build !e2e_testing
// udp_generic implements the nebula UDP interface in pure Go stdlib. This // udp_generic implements the nebula UDP interface in pure Go stdlib. This
// means it can be used on platforms like Darwin and Windows. // means it can be used on platforms like Darwin and Windows.
@ -9,20 +10,23 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"github.com/sirupsen/logrus"
) )
type udpConn struct { type udpConn struct {
*net.UDPConn *net.UDPConn
l *logrus.Logger
} }
func NewListener(ip string, port int, multi bool) (*udpConn, error) { func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
lc := NewListenConfig(multi) lc := NewListenConfig(multi)
pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port)) pc, err := lc.ListenPacket(context.TODO(), "udp", fmt.Sprintf("%s:%d", ip, port))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if uc, ok := pc.(*net.UDPConn); ok { if uc, ok := pc.(*net.UDPConn); ok {
return &udpConn{UDPConn: uc}, nil return &udpConn{UDPConn: uc, l: l}, nil
} }
return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc) return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
} }
@ -76,13 +80,13 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
// Just read one packet at a time // Just read one packet at a time
n, rua, err := u.ReadFromUDP(buffer) n, rua, err := u.ReadFromUDP(buffer)
if err != nil { if err != nil {
l.WithError(err).Error("Failed to read packets") f.l.WithError(err).Error("Failed to read packets")
continue continue
} }
udpAddr.IP = rua.IP udpAddr.IP = rua.IP
udpAddr.Port = uint16(rua.Port) udpAddr.Port = uint16(rua.Port)
f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get()) f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get(f.l))
} }
} }

View File

@ -1,4 +1,5 @@
// +build !android // +build !android
// +build !e2e_testing
package nebula package nebula
@ -10,6 +11,7 @@ import (
"unsafe" "unsafe"
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -17,6 +19,7 @@ import (
type udpConn struct { type udpConn struct {
sysFd int sysFd int
l *logrus.Logger
} }
var x int var x int
@ -38,7 +41,7 @@ const (
type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32 type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
func NewListener(ip string, port int, multi bool) (*udpConn, error) { func NewListener(l *logrus.Logger, ip string, port int, multi bool) (*udpConn, error) {
syscall.ForkLock.RLock() syscall.ForkLock.RLock()
fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP) fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
if err == nil { if err == nil {
@ -70,7 +73,7 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) {
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
//l.Println(v, err) //l.Println(v, err)
return &udpConn{sysFd: fd}, err return &udpConn{sysFd: fd, l: l}, err
} }
func (u *udpConn) Rebind() error { func (u *udpConn) Rebind() error {
@ -153,7 +156,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
for { for {
n, err := read(msgs) n, err := read(msgs)
if err != nil { if err != nil {
l.WithError(err).Error("Failed to read packets") u.l.WithError(err).Error("Failed to read packets")
continue continue
} }
@ -161,7 +164,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
udpAddr.IP = names[i][8:24] udpAddr.IP = names[i][8:24]
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get()) f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get(u.l))
} }
} }
} }
@ -244,12 +247,12 @@ func (u *udpConn) reloadConfig(c *Config) {
if err == nil { if err == nil {
s, err := u.GetRecvBuffer() s, err := u.GetRecvBuffer()
if err == nil { if err == nil {
l.WithField("size", s).Info("listen.read_buffer was set") u.l.WithField("size", s).Info("listen.read_buffer was set")
} else { } else {
l.WithError(err).Warn("Failed to get listen.read_buffer") u.l.WithError(err).Warn("Failed to get listen.read_buffer")
} }
} else { } else {
l.WithError(err).Error("Failed to set listen.read_buffer") u.l.WithError(err).Error("Failed to set listen.read_buffer")
} }
} }
@ -259,12 +262,12 @@ func (u *udpConn) reloadConfig(c *Config) {
if err == nil { if err == nil {
s, err := u.GetSendBuffer() s, err := u.GetSendBuffer()
if err == nil { if err == nil {
l.WithField("size", s).Info("listen.write_buffer was set") u.l.WithField("size", s).Info("listen.write_buffer was set")
} else { } else {
l.WithError(err).Warn("Failed to get listen.write_buffer") u.l.WithError(err).Warn("Failed to get listen.write_buffer")
} }
} else { } else {
l.WithError(err).Error("Failed to set listen.write_buffer") u.l.WithError(err).Error("Failed to set listen.write_buffer")
} }
} }
} }

View File

@ -1,6 +1,7 @@
// +build linux // +build linux
// +build 386 amd64p32 arm mips mipsle // +build 386 amd64p32 arm mips mipsle
// +build !android // +build !android
// +build !e2e_testing
package nebula package nebula

View File

@ -1,6 +1,7 @@
// +build linux // +build linux
// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x // +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x
// +build !android // +build !android
// +build !e2e_testing
package nebula package nebula