From 41578ca971ab1c0e291bbe0706072c7b62d7d658 Mon Sep 17 00:00:00 2001 From: Nathan Brown Date: Tue, 30 Jun 2020 13:48:58 -0500 Subject: [PATCH] Be more like a library to support mobile (#247) --- bits_test.go | 4 +- cmd/nebula-service/main.go | 26 +++++- cmd/nebula-service/service.go | 13 ++- cmd/nebula/main.go | 23 ++++- config.go | 20 +++++ logger.go | 31 +++++++ logger_test.go | 66 +++++++++++++++ main.go | 154 ++++++++++++++++++++++------------ tun_darwin.go | 14 +++- tun_freebsd.go | 4 + tun_ios.go | 105 +++++++++++++++++++++++ tun_linux.go | 17 ++++ tun_windows.go | 6 +- udp_android.go | 36 ++++++++ udp_darwin.go | 9 ++ udp_freebsd.go | 4 + udp_generic.go | 2 +- udp_linux.go | 6 ++ udp_linux_32.go | 1 + udp_linux_64.go | 1 + udp_windows.go | 4 + 21 files changed, 477 insertions(+), 69 deletions(-) create mode 100644 logger.go create mode 100644 logger_test.go create mode 100644 tun_ios.go create mode 100644 udp_android.go diff --git a/bits_test.go b/bits_test.go index f918c82..880fcd9 100644 --- a/bits_test.go +++ b/bits_test.go @@ -212,10 +212,10 @@ func TestBitsLostCounter(t *testing.T) { func BenchmarkBits(b *testing.B) { z := NewBits(10) for n := 0; n < b.N; n++ { - for i, _ := range z.bits { + for i := range z.bits { z.bits[i] = true } - for i, _ := range z.bits { + for i := range z.bits { z.bits[i] = false } diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 7e0634e..b5ea062 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -3,9 +3,9 @@ package main import ( "flag" "fmt" - "os" - + "github.com/sirupsen/logrus" "github.com/slackhq/nebula" + "os" ) // A version string that can be set with @@ -45,5 +45,25 @@ func main() { os.Exit(1) } - nebula.Main(*configPath, *configTest, Build) + config := nebula.NewConfig() + err := config.Load(*configPath) + if err != nil { + fmt.Printf("failed to load config: %s", err) + os.Exit(1) + } + + l := logrus.New() + l.Out = os.Stdout + err = nebula.Main(config, *configTest, true, Build, l, nil, nil) + + switch v := err.(type) { + case nebula.ContextualError: + v.Log(l) + os.Exit(1) + case error: + l.WithError(err).Error("Failed to start") + os.Exit(1) + } + + os.Exit(0) } diff --git a/cmd/nebula-service/service.go b/cmd/nebula-service/service.go index 5d88160..e96bbf4 100644 --- a/cmd/nebula-service/service.go +++ b/cmd/nebula-service/service.go @@ -1,6 +1,8 @@ package main import ( + "fmt" + "github.com/sirupsen/logrus" "log" "os" "path/filepath" @@ -27,8 +29,15 @@ func (p *program) Start(s service.Service) error { } func (p *program) run() error { - nebula.Main(*p.configPath, *p.configTest, Build) - return nil + config := nebula.NewConfig() + err := config.Load(*p.configPath) + if err != nil { + return fmt.Errorf("failed to load config: %s", err) + } + + l := logrus.New() + l.Out = os.Stdout + return nebula.Main(config, *p.configTest, true, Build, l, nil, nil) } func (p *program) Stop(s service.Service) error { diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index 62156a3..9e23aca 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -3,6 +3,7 @@ package main import ( "flag" "fmt" + "github.com/sirupsen/logrus" "os" "github.com/slackhq/nebula" @@ -39,5 +40,25 @@ func main() { os.Exit(1) } - nebula.Main(*configPath, *configTest, Build) + config := nebula.NewConfig() + err := config.Load(*configPath) + if err != nil { + fmt.Printf("failed to load config: %s", err) + os.Exit(1) + } + + l := logrus.New() + l.Out = os.Stdout + err = nebula.Main(config, *configTest, true, Build, l, nil, nil) + + switch v := err.(type) { + case nebula.ContextualError: + v.Log(l) + os.Exit(1) + case error: + l.WithError(err).Error("Failed to start") + os.Exit(1) + } + + os.Exit(0) } diff --git a/config.go b/config.go index 570cd85..558c4c5 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,7 @@ package nebula import ( + "errors" "fmt" "github.com/imdario/mergo" "github.com/sirupsen/logrus" @@ -56,6 +57,13 @@ func (c *Config) Load(path string) error { return nil } +func (c *Config) LoadString(raw string) error { + if raw == "" { + return errors.New("Empty configuration") + } + return c.parseRaw([]byte(raw)) +} + // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered // here should decide if they need to make a change to the current process before making the change. HasChanged can be // used to help decide if a change is necessary. @@ -407,6 +415,18 @@ func (c *Config) addFile(path string, direct bool) error { return nil } +func (c *Config) parseRaw(b []byte) error { + var m map[interface{}]interface{} + + err := yaml.Unmarshal(b, &m) + if err != nil { + return err + } + + c.Settings = m + return nil +} + func (c *Config) parse() error { var m map[interface{}]interface{} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..c8c7fdf --- /dev/null +++ b/logger.go @@ -0,0 +1,31 @@ +package nebula + +import ( + "github.com/sirupsen/logrus" +) + +type ContextualError struct { + RealError error + Fields map[string]interface{} + Context string +} + +func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError { + return ContextualError{Context: msg, Fields: fields, RealError: realError} +} + +func (ce ContextualError) Error() string { + return ce.RealError.Error() +} + +func (ce ContextualError) Unwrap() error { + return ce.RealError +} + +func (ce *ContextualError) Log(lr *logrus.Logger) { + if ce.RealError != nil { + lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context) + } else { + lr.WithFields(ce.Fields).Error(ce.Context) + } +} diff --git a/logger_test.go b/logger_test.go new file mode 100644 index 0000000..2cd82d4 --- /dev/null +++ b/logger_test.go @@ -0,0 +1,66 @@ +package nebula + +import ( + "errors" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "testing" +) + +type TestLogWriter struct { + Logs []string +} + +func NewTestLogWriter() *TestLogWriter { + return &TestLogWriter{Logs: make([]string, 0)} +} + +func (tl *TestLogWriter) Write(p []byte) (n int, err error) { + tl.Logs = append(tl.Logs, string(p)) + return len(p), nil +} + +func (tl *TestLogWriter) Reset() { + tl.Logs = tl.Logs[:0] +} + +func TestContextualError_Log(t *testing.T) { + l := logrus.New() + l.Formatter = &logrus.TextFormatter{ + DisableTimestamp: true, + DisableColors: true, + } + + tl := NewTestLogWriter() + l.Out = tl + + // Test a full context line + tl.Reset() + e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) + e.Log(l) + assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) + + // Test a line with an error and msg but no fields + tl.Reset() + e = NewContextualError("test message", nil, errors.New("error")) + e.Log(l) + assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs) + + // Test just a context and fields + tl.Reset() + e = NewContextualError("test message", m{"field": "1"}, nil) + e.Log(l) + assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs) + + // Test just a context + tl.Reset() + e = NewContextualError("test message", nil, nil) + e.Log(l) + assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs) + + // Test just an error + tl.Reset() + e = NewContextualError("", nil, errors.New("error")) + e.Log(l) + assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs) +} diff --git a/main.go b/main.go index 9eb6584..470804e 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,9 @@ package nebula import ( "encoding/binary" "fmt" + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/sshd" + "gopkg.in/yaml.v2" "net" "os" "os/signal" @@ -10,42 +13,38 @@ import ( "strings" "syscall" "time" - - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/sshd" - "gopkg.in/yaml.v2" ) +// The caller should provide a real logger, we have one just in case var l = logrus.New() type m map[string]interface{} -func Main(configPath string, configTest bool, buildVersion string) { - l.Out = os.Stdout +type CommandRequest struct { + Command string + Callback chan error +} + +func Main(config *Config, configTest bool, block bool, buildVersion string, logger *logrus.Logger, tunFd *int, commandChan <-chan CommandRequest) error { + l = logger l.Formatter = &logrus.TextFormatter{ FullTimestamp: true, } - config := NewConfig() - err := config.Load(configPath) - if err != nil { - l.WithError(err).Error("Failed to load config") - os.Exit(1) - } - // Print the config if in test, the exit comes later if configTest { b, err := yaml.Marshal(config.Settings) if err != nil { - l.Println(err) - os.Exit(1) + return err } + + // Print the final config l.Println(string(b)) } - err = configLogger(config) + err := configLogger(config) if err != nil { - l.WithError(err).Error("Failed to configure the logger") + return NewContextualError("Failed to configure the logger", nil, err) } config.RegisterReloadCallback(func(c *Config) { @@ -59,20 +58,20 @@ func Main(configPath string, configTest bool, buildVersion string) { trustedCAs, err = loadCAFromConfig(config) if err != nil { //The errors coming out of loadCA are already nicely formatted - l.WithError(err).Fatal("Failed to load ca from config") + return NewContextualError("Failed to load ca from config", nil, err) } l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints") cs, err := NewCertStateFromConfig(config) if err != nil { //The errors coming out of NewCertStateFromConfig are already nicely formatted - l.WithError(err).Fatal("Failed to load certificate from config") + return NewContextualError("Failed to load certificate from config", nil, err) } l.WithField("cert", cs.certificate).Debug("Client nebula certificate") fw, err := NewFirewallFromConfig(cs.certificate, config) if err != nil { - l.WithError(err).Fatal("Error while loading firewall rules") + return NewContextualError("Error while loading firewall rules", nil, err) } l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started") @@ -80,11 +79,11 @@ func Main(configPath string, configTest bool, buildVersion string) { tunCidr := cs.certificate.Details.Ips[0] routes, err := parseRoutes(config, tunCidr) if err != nil { - l.WithError(err).Fatal("Could not parse tun.routes") + return NewContextualError("Could not parse tun.routes", nil, err) } unsafeRoutes, err := parseUnsafeRoutes(config, tunCidr) if err != nil { - l.WithError(err).Fatal("Could not parse tun.unsafe_routes") + return NewContextualError("Could not parse tun.unsafe_routes", nil, err) } ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) @@ -92,7 +91,7 @@ func Main(configPath string, configTest bool, buildVersion string) { if config.GetBool("sshd.enabled", false) { err = configSSH(ssh, config) if err != nil { - l.WithError(err).Fatal("Error while configuring the sshd") + return NewContextualError("Error while configuring the sshd", nil, err) } } @@ -105,17 +104,28 @@ func Main(configPath string, configTest bool, buildVersion string) { if !configTest { config.CatchHUP() - // set up our tun dev - tun, err = newTun( - config.GetString("tun.dev", ""), - tunCidr, - config.GetInt("tun.mtu", DEFAULT_MTU), - routes, - unsafeRoutes, - config.GetInt("tun.tx_queue", 500), - ) + if tunFd != nil { + tun, err = newTunFromFd( + *tunFd, + tunCidr, + config.GetInt("tun.mtu", DEFAULT_MTU), + routes, + unsafeRoutes, + config.GetInt("tun.tx_queue", 500), + ) + } else { + tun, err = newTun( + config.GetString("tun.dev", ""), + tunCidr, + config.GetInt("tun.mtu", DEFAULT_MTU), + routes, + unsafeRoutes, + config.GetInt("tun.tx_queue", 500), + ) + } + if err != nil { - l.WithError(err).Fatal("Failed to get a tun/tap device") + return NewContextualError("Failed to get a tun/tap device", nil, err) } } @@ -126,11 +136,28 @@ func Main(configPath string, configTest bool, buildVersion string) { if !configTest { udpServer, err = NewListener(config.GetString("listen.host", "0.0.0.0"), config.GetInt("listen.port", 0), udpQueues > 1) if err != nil { - l.WithError(err).Fatal("Failed to open udp listener") + return NewContextualError("Failed to open udp listener", nil, err) } udpServer.reloadConfig(config) } + sigChan := make(chan os.Signal) + killChan := make(chan CommandRequest) + if commandChan != nil { + go func() { + cmd := CommandRequest{} + for { + cmd = <-commandChan + switch cmd.Command { + case "rebind": + udpServer.Rebind() + case "exit": + killChan <- cmd + } + } + }() + } + // Set up my internal host map var preferredRanges []*net.IPNet rawPreferredRanges := config.GetStringSlice("preferred_ranges", []string{}) @@ -139,7 +166,7 @@ func Main(configPath string, configTest bool, buildVersion string) { for _, rawPreferredRange := range rawPreferredRanges { _, preferredRange, err := net.ParseCIDR(rawPreferredRange) if err != nil { - l.WithError(err).Fatal("Failed to parse preferred ranges") + return NewContextualError("Failed to parse preferred ranges", nil, err) } preferredRanges = append(preferredRanges, preferredRange) } @@ -152,7 +179,7 @@ func Main(configPath string, configTest bool, buildVersion string) { if rawLocalRange != "" { _, localRange, err := net.ParseCIDR(rawLocalRange) if err != nil { - l.WithError(err).Fatal("Failed to parse local range") + return NewContextualError("Failed to parse local_range", nil, err) } // Check if the entry for local_range was already specified in @@ -192,7 +219,7 @@ func Main(configPath string, configTest bool, buildVersion string) { if port == 0 && !configTest { uPort, err := udpServer.LocalAddr() if err != nil { - l.WithError(err).Fatal("Failed to get listening port") + return NewContextualError("Failed to get listening port", nil, err) } port = int(uPort.Port) } @@ -209,10 +236,10 @@ func Main(configPath string, configTest bool, buildVersion string) { for i, host := range rawLighthouseHosts { ip := net.ParseIP(host) if ip == nil { - l.WithField("host", host).Fatalf("Unable to parse lighthouse host entry %v", i+1) + return NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) } if !tunCidr.Contains(ip) { - l.WithField("vpnIp", ip).WithField("network", tunCidr.String()).Fatalf("lighthouse host is not in our subnet, invalid") + return NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) } lighthouseHosts[i] = ip2int(ip) } @@ -232,13 +259,13 @@ func Main(configPath string, configTest bool, buildVersion string) { remoteAllowList, err := config.GetAllowList("lighthouse.remote_allow_list", false) if err != nil { - l.WithError(err).Fatal("Invalid lighthouse.remote_allow_list") + return NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) } lightHouse.SetRemoteAllowList(remoteAllowList) localAllowList, err := config.GetAllowList("lighthouse.local_allow_list", true) if err != nil { - l.WithError(err).Fatal("Invalid lighthouse.local_allow_list") + return NewContextualError("Invalid lighthouse.local_allow_list", nil, err) } lightHouse.SetLocalAllowList(localAllowList) @@ -246,7 +273,7 @@ func Main(configPath string, configTest bool, buildVersion string) { for k, v := range config.GetMap("static_host_map", map[interface{}]interface{}{}) { vpnIp := net.ParseIP(fmt.Sprintf("%v", k)) if !tunCidr.Contains(vpnIp) { - l.WithField("vpnIp", vpnIp).WithField("network", tunCidr.String()).Fatalf("static_host_map key is not in our subnet, invalid") + return NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": tunCidr.String()}, nil) } vals, ok := v.([]interface{}) if ok { @@ -257,7 +284,7 @@ func Main(configPath string, configTest bool, buildVersion string) { ip := addr.IP port, err := strconv.Atoi(parts[1]) if err != nil { - l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v) + return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) } lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true) } @@ -270,7 +297,7 @@ func Main(configPath string, configTest bool, buildVersion string) { ip := addr.IP port, err := strconv.Atoi(parts[1]) if err != nil { - l.Fatalf("Static host address for %s could not be parsed: %s", vpnIp, v) + return NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err) } lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip2int(ip), uint16(port)), true) } @@ -330,14 +357,14 @@ func Main(configPath string, configTest bool, buildVersion string) { case "chachapoly": noiseEndianness = binary.LittleEndian default: - l.Fatalf("Unknown cipher: %v", ifConfig.Cipher) + return fmt.Errorf("unknown cipher: %v", ifConfig.Cipher) } var ifce *Interface if !configTest { ifce, err = NewInterface(ifConfig) if err != nil { - l.WithError(err).Fatal("Failed to initialize interface") + return fmt.Errorf("failed to initialize interface: %s", err) } ifce.RegisterConfigChangeCallbacks(config) @@ -348,11 +375,11 @@ func Main(configPath string, configTest bool, buildVersion string) { err = startStats(config, configTest) if err != nil { - l.WithError(err).Fatal("Failed to start stats emitter") + return NewContextualError("Failed to start stats emitter", nil, err) } if configTest { - os.Exit(0) + return nil } //TODO: check if we _should_ be emitting stats @@ -367,19 +394,33 @@ func Main(configPath string, configTest bool, buildVersion string) { go dnsMain(hostMap, config) } - // Just sit here and be friendly, main thread. - shutdownBlock(ifce) + if block { + // Just sit here and be friendly, main thread. + shutdownBlock(ifce, sigChan, killChan) + } else { + // Even though we aren't blocking we still want to shutdown gracefully + go shutdownBlock(ifce, sigChan, killChan) + } + return nil } -func shutdownBlock(ifce *Interface) { - var sigChan = make(chan os.Signal) +func shutdownBlock(ifce *Interface, sigChan chan os.Signal, killChan chan CommandRequest) { + var cmd CommandRequest + var sig string + signal.Notify(sigChan, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT) - sig := <-sigChan + select { + case rawSig := <-sigChan: + sig = rawSig.String() + case cmd = <-killChan: + sig = "controlling app" + } + l.WithField("signal", sig).Info("Caught signal, shutting down") - //TODO: stop tun and udp routines, the lock on hostMap does effectively does that though + //TODO: stop tun and udp routines, the lock on hostMap effectively does that though //TODO: this is probably better as a function in ConnectionManager or HostMap directly ifce.hostMap.Lock() for _, h := range ifce.hostMap.Hosts { @@ -392,5 +433,8 @@ func shutdownBlock(ifce *Interface) { ifce.hostMap.Unlock() l.WithField("signal", sig).Info("Goodbye") - os.Exit(0) + select { + case cmd.Callback <- nil: + default: + } } diff --git a/tun_darwin.go b/tun_darwin.go index 0562559..aff7afc 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -1,12 +1,13 @@ +// +build !ios + package nebula import ( "fmt" + "github.com/songgao/water" "net" "os/exec" "strconv" - - "github.com/songgao/water" ) type Tun struct { @@ -20,8 +21,9 @@ type Tun struct { func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { if len(routes) > 0 { - return nil, fmt.Errorf("Route MTU not supported in Darwin") + return nil, fmt.Errorf("route MTU not supported in Darwin") } + // NOTE: You cannot set the deviceName under Darwin, so you must check tun.Device after calling .Activate() return &Tun{ Cidr: cidr, @@ -30,13 +32,17 @@ func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, }, nil } +func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { + return nil, fmt.Errorf("newTunFromFd not supported in Darwin") +} + func (c *Tun) Activate() error { var err error c.Interface, err = water.New(water.Config{ DeviceType: water.TUN, }) if err != nil { - return fmt.Errorf("Activate failed: %v", err) + return fmt.Errorf("activate failed: %v", err) } c.Device = c.Interface.Name() diff --git a/tun_freebsd.go b/tun_freebsd.go index 9570443..b294aa3 100644 --- a/tun_freebsd.go +++ b/tun_freebsd.go @@ -22,6 +22,10 @@ type Tun struct { io.ReadWriteCloser } +func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { + return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") +} + func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { if len(routes) > 0 { return nil, fmt.Errorf("Route MTU not supported in FreeBSD") diff --git a/tun_ios.go b/tun_ios.go new file mode 100644 index 0000000..df1078d --- /dev/null +++ b/tun_ios.go @@ -0,0 +1,105 @@ +// +build ios + +package nebula + +import ( + "errors" + "fmt" + "io" + "net" + "os" + "sync" + "syscall" +) + +type Tun struct { + io.ReadWriteCloser + Device string + Cidr *net.IPNet +} + +func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { + return nil, fmt.Errorf("newTun not supported in iOS") +} + +func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { + if len(routes) > 0 { + return nil, fmt.Errorf("route MTU not supported in Darwin") + } + + file := os.NewFile(uintptr(deviceFd), "/dev/tun") + ifce = &Tun{ + Cidr: cidr, + ReadWriteCloser: &tunReadCloser{f: file}, + } + return +} + +func (c *Tun) Activate() error { + c.Device = "iOS" + return nil +} + +func (c *Tun) WriteRaw(b []byte) error { + _, err := c.Write(b) + return err +} + +// The following is hoisted up from water, we do this so we can inject our own fd on iOS +type tunReadCloser struct { + f io.ReadWriteCloser + + rMu sync.Mutex + rBuf []byte + + wMu sync.Mutex + wBuf []byte +} + +func (t *tunReadCloser) Read(to []byte) (int, error) { + t.rMu.Lock() + defer t.rMu.Unlock() + + if cap(t.rBuf) < len(to)+4 { + t.rBuf = make([]byte, len(to)+4) + } + t.rBuf = t.rBuf[:len(to)+4] + + n, err := t.f.Read(t.rBuf) + copy(to, t.rBuf[4:]) + return n - 4, err +} + +func (t *tunReadCloser) Write(from []byte) (int, error) { + + if len(from) == 0 { + return 0, syscall.EIO + } + + t.wMu.Lock() + defer t.wMu.Unlock() + + if cap(t.wBuf) < len(from)+4 { + t.wBuf = make([]byte, len(from)+4) + } + t.wBuf = t.wBuf[:len(from)+4] + + // Determine the IP Family for the NULL L2 Header + ipVer := from[0] >> 4 + if ipVer == 4 { + t.wBuf[3] = syscall.AF_INET + } else if ipVer == 6 { + t.wBuf[3] = syscall.AF_INET6 + } else { + return 0, errors.New("unable to determine IP version from packet") + } + + copy(t.wBuf[4:], from) + + n, err := t.f.Write(t.wBuf) + return n - 4, err +} + +func (t *tunReadCloser) Close() error { + return t.f.Close() +} diff --git a/tun_linux.go b/tun_linux.go index 6a9cb09..1cce919 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -75,6 +75,23 @@ type ifreqQLEN struct { pad [8]byte } +func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { + + file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") + + ifce = &Tun{ + ReadWriteCloser: file, + fd: int(file.Fd()), + Device: "tun0", + Cidr: cidr, + DefaultMTU: defaultMTU, + TXQueueLen: txQueueLen, + Routes: routes, + UnsafeRoutes: unsafeRoutes, + } + return +} + func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { diff --git a/tun_windows.go b/tun_windows.go index b15c1fc..8795507 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -18,9 +18,13 @@ type Tun struct { *water.Interface } +func newTunFromFd(deviceFd int, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { + return nil, fmt.Errorf("newTunFromFd not supported in Windows") +} + func newTun(deviceName string, cidr *net.IPNet, defaultMTU int, routes []route, unsafeRoutes []route, txQueueLen int) (ifce *Tun, err error) { if len(routes) > 0 { - return nil, fmt.Errorf("Route MTU not supported in Windows") + return nil, fmt.Errorf("route MTU not supported in Windows") } // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() diff --git a/udp_android.go b/udp_android.go new file mode 100644 index 0000000..7b6fea5 --- /dev/null +++ b/udp_android.go @@ -0,0 +1,36 @@ +package nebula + +import ( + "fmt" + "net" + "syscall" + + "golang.org/x/sys/unix" +) + +func NewListenConfig(multi bool) net.ListenConfig { + return net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + if multi { + var controlErr error + err := c.Control(func(fd uintptr) { + if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { + controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err) + return + } + }) + if err != nil { + return err + } + if controlErr != nil { + return controlErr + } + } + return nil + }, + } +} + +func (u *udpConn) Rebind() { + return +} diff --git a/udp_darwin.go b/udp_darwin.go index 61fda7a..4f7c0b2 100644 --- a/udp_darwin.go +++ b/udp_darwin.go @@ -32,3 +32,12 @@ func NewListenConfig(multi bool) net.ListenConfig { }, } } + +func (u *udpConn) Rebind() error { + file, err := u.File() + if err != nil { + return err + } + + return syscall.SetsockoptInt(int(file.Fd()), unix.IPPROTO_IP, unix.IP_BOUND_IF, 0) +} diff --git a/udp_freebsd.go b/udp_freebsd.go index 88ff618..88730be 100644 --- a/udp_freebsd.go +++ b/udp_freebsd.go @@ -32,3 +32,7 @@ func NewListenConfig(multi bool) net.ListenConfig { }, } } + +func (u *udpConn) Rebind() { + return +} diff --git a/udp_generic.go b/udp_generic.go index 0c988ab..c59f8c1 100644 --- a/udp_generic.go +++ b/udp_generic.go @@ -1,4 +1,4 @@ -// +build !linux +// +build !linux android // udp_generic implements the nebula UDP interface in pure Go stdlib. This // means it can be used on platforms like Darwin and Windows. diff --git a/udp_linux.go b/udp_linux.go index c810cc2..e79dac1 100644 --- a/udp_linux.go +++ b/udp_linux.go @@ -1,3 +1,5 @@ +// +build !android + package nebula import ( @@ -85,6 +87,10 @@ func NewListener(ip string, port int, multi bool) (*udpConn, error) { return &udpConn{sysFd: fd}, err } +func (u *udpConn) Rebind() { + return +} + func (u *udpConn) SetRecvBuffer(n int) error { return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) } diff --git a/udp_linux_32.go b/udp_linux_32.go index cb8233f..752b529 100644 --- a/udp_linux_32.go +++ b/udp_linux_32.go @@ -1,5 +1,6 @@ // +build linux // +build 386 amd64p32 arm mips mipsle +// +build !android package nebula diff --git a/udp_linux_64.go b/udp_linux_64.go index bd5e3dc..4d65d5a 100644 --- a/udp_linux_64.go +++ b/udp_linux_64.go @@ -1,5 +1,6 @@ // +build linux // +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x +// +build !android package nebula diff --git a/udp_windows.go b/udp_windows.go index 6376503..463f79d 100644 --- a/udp_windows.go +++ b/udp_windows.go @@ -20,3 +20,7 @@ func NewListenConfig(multi bool) net.ListenConfig { }, } } + +func (u *udpConn) Rebind() { + return +}