diff --git a/connection_manager_test.go b/connection_manager_test.go index 789b8ed..81bb049 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -6,8 +6,8 @@ import ( "time" "github.com/flynn/noise" - "github.com/stretchr/testify/assert" "github.com/slackhq/nebula/cert" + "github.com/stretchr/testify/assert" ) var vpnIP uint32 = uint32(12341234) @@ -27,7 +27,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { rawCertificateNoKey: []byte{}, } - lh := NewLightHouse(false, 0, []string{}, 1000, 0, &udpConn{}, false) + lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false) ifce := &Interface{ hostMap: hostMap, inside: &Tun{}, @@ -90,7 +90,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { rawCertificateNoKey: []byte{}, } - lh := NewLightHouse(false, 0, []string{}, 1000, 0, &udpConn{}, false) + lh := NewLightHouse(false, 0, []uint32{}, 1000, 0, &udpConn{}, false) ifce := &Interface{ hostMap: hostMap, inside: &Tun{}, diff --git a/lighthouse.go b/lighthouse.go index c19cbb7..7974be4 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -33,7 +33,7 @@ type EncWriter interface { SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) } -func NewLightHouse(amLighthouse bool, myIp uint32, ips []string, interval int, nebulaPort int, pc *udpConn, punchBack bool) *LightHouse { +func NewLightHouse(amLighthouse bool, myIp uint32, ips []uint32, interval int, nebulaPort int, pc *udpConn, punchBack bool) *LightHouse { h := LightHouse{ amLighthouse: amLighthouse, myIp: myIp, @@ -46,8 +46,8 @@ func NewLightHouse(amLighthouse bool, myIp uint32, ips []string, interval int, n punchBack: punchBack, } - for _, rIp := range ips { - h.lighthouses[ip2int(net.ParseIP(rIp))] = struct{}{} + for _, ip := range ips { + h.lighthouses[ip] = struct{}{} } return &h diff --git a/lighthouse_test.go b/lighthouse_test.go index 08a9857..96d77bd 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -48,18 +48,19 @@ func TestNewipandportsfromudpaddrs(t *testing.T) { func Test_lhStaticMapping(t *testing.T) { lh1 := "10.128.0.2" - lh1IP := net.ParseIP(lh1) udpServer, _ := NewListener("0.0.0.0", 0, true) - meh := NewLightHouse(true, 1, []string{lh1}, 10, 10003, udpServer, false) + meh := NewLightHouse(true, 1, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false) meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true) err := meh.ValidateLHStaticEntries() assert.Nil(t, err) lh2 := "10.128.0.3" - meh = NewLightHouse(true, 1, []string{lh1, lh2}, 10, 10003, udpServer, false) + lh2IP := net.ParseIP(lh2) + + meh = NewLightHouse(true, 1, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false) meh.AddRemote(ip2int(lh1IP), NewUDPAddr(ip2int(lh1IP), uint16(4242)), true) err = meh.ValidateLHStaticEntries() assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry") diff --git a/main.go b/main.go index f835aea..2fd6f83 100644 --- a/main.go +++ b/main.go @@ -190,16 +190,25 @@ func Main(configPath string, configTest bool, buildVersion string) { amLighthouse := config.GetBool("lighthouse.am_lighthouse", false) // warn if am_lighthouse is enabled but upstream lighthouses exists - lighthouseHosts := config.GetStringSlice("lighthouse.hosts", []string{}) - if amLighthouse && len(lighthouseHosts) != 0 { + rawLighthouseHosts := config.GetStringSlice("lighthouse.hosts", []string{}) + if amLighthouse && len(rawLighthouseHosts) != 0 { l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") } + lighthouseHosts := make([]uint32, len(rawLighthouseHosts)) + for i, host := range rawLighthouseHosts { + ip := net.ParseIP(host) + if ip == nil { + l.WithField("host", host).Fatalf("Unable to parse lighthouse host entry %v", i+1) + } + lighthouseHosts[i] = ip2int(ip) + } + serveDns := config.GetBool("lighthouse.serve_dns", false) lightHouse := NewLightHouse( amLighthouse, ip2int(tunCidr.IP), - config.GetStringSlice("lighthouse.hosts", []string{}), + lighthouseHosts, //TODO: change to a duration config.GetInt("lighthouse.interval", 10), port,