diff --git a/cluster.go b/cluster.go index 4b9c48c..e66a0f8 100644 --- a/cluster.go +++ b/cluster.go @@ -184,6 +184,7 @@ func (c *cluster) members() (<-chan []node, <-chan error) { continue } nodes = append(nodes, node{ + Name: n.Name, Addr: n.Addr, nodeMeta: meta, }) @@ -200,6 +201,7 @@ func (c *cluster) members() (<-chan []node, <-chan error) { } type node struct { + Name string Addr net.IP nodeMeta } diff --git a/config.go b/config.go index 728b61b..5d0ade6 100644 --- a/config.go +++ b/config.go @@ -17,6 +17,7 @@ type config struct { WireguardPort int `id:"wireguard-port" desc:"port used for wireguard traffic (UDP); must be the same across cluster" default:"51820"` OverlayNet *network `id:"overlay-net" desc:"the network in which to allocate addresses for the overlay mesh network (CIDR format); smaller networks increase the chance of IP collision" default:"10.0.0.0/8"` Interface string `desc:"name of the wireguard interface to create and manage" default:"wgoverlay"` + NoEtcHosts bool `id:"no-etc-hosts" desc:"disable writing of entries to /etc/hosts"` LogLevel string `id:"log-level" desc:"set the verbosity (debug/info/warn/error)" default:"warn"` // for easier local testing diff --git a/etchosts/etchosts.go b/etchosts/etchosts.go new file mode 100644 index 0000000..1c33d42 --- /dev/null +++ b/etchosts/etchosts.go @@ -0,0 +1,164 @@ +package etchosts + +import ( + "bufio" + "fmt" + "io" + "io/ioutil" + "os" + "path" + "strings" + + log "github.com/sirupsen/logrus" +) + +// DefaultBanner is the default magic comment used to identify entries managed by etchosts +const DefaultBanner = "# ! MANAGED AUTOMATICALLY !" + +// DefaultPath is the default path used to write hosts entries +const DefaultPath = "/etc/hosts" + +// EtcHosts contains the options used to write hosts entries. +// The zero value can be used to write to DefaultPath using DefaultBanner as a marker. +type EtcHosts struct { + // Banner is the magic comment used to identify entries managed by etchosts; if not set, will use DefaultBanner. + // It must start with "#" to mark it as a comment. + Banner string + // Path is the path to the /etc/hosts file; if not set, will use DefaultPath. + Path string + // Logger is an optional logrus.StdLogger interface, used for debugging. + Logger log.StdLogger +} + +// WriteEntries is used to write the hosts entries to EtcHosts.Path +// Each IP address with their (potentially multiple) hostnames are written to a line marked with EtcHosts.Banner, to +// avoid overwriting preexisting entries. +func (eh *EtcHosts) WriteEntries(ipsToNames map[string][]string) error { + hostsPath := eh.Path + if hostsPath == "" { + hostsPath = DefaultPath + } + + // We do not want to create the hosts file; if it's not there, we probably have the wrong path. + etcHosts, err := os.OpenFile(hostsPath, os.O_RDWR, 0644) + if err != nil { + return fmt.Errorf("could not open %s for reading: %s", hostsPath, err) + } + defer etcHosts.Close() + + // create tmpfile in same folder as + tmp, err := ioutil.TempFile(path.Dir(hostsPath), "etchosts") + if err != nil { + return fmt.Errorf("could not create tempfile") + } + + // remove tempfile; this might fail if we managed to move it, which is ok + defer func(file *os.File) { + file.Close() + if err := os.Remove(file.Name()); err != nil && !os.IsNotExist(err) { + if eh.Logger != nil { + eh.Logger.Printf("unexpected error trying to remove temp file %s: %s", file.Name(), err) + } + } + }(tmp) + + if err := eh.writeEntries(etcHosts, tmp, ipsToNames); err != nil { + return err + } + + if err = eh.movePreservePerms(tmp, etcHosts); err != nil { + return err + } + + return nil +} + +func (eh *EtcHosts) writeEntries(orig io.Reader, dest io.Writer, ipsToNames map[string][]string) error { + banner := eh.Banner + if banner == "" { + banner = DefaultBanner + } + + // go through file and update existing entries/prune nonexistent entries + scanner := bufio.NewScanner(orig) + for scanner.Scan() { + line := scanner.Text() + if strings.HasSuffix(strings.TrimSpace(line), strings.TrimSpace(banner)) { + tokens := strings.Fields(line) + if len(tokens) < 1 { + continue // remove empty managed line + } + ip := tokens[0] + if names, ok := ipsToNames[ip]; ok { + err := eh.writeEntryWithBanner(dest, banner, ip, names) + if err != nil { + return err + } + delete(ipsToNames, ip) // otherwise we'll append it again below + } + } else { + // keep original unmanaged line + fmt.Fprintf(dest, "%s\n", line) + } + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("error reading hosts file: %s", err) + } + + // append remaining entries to file + for ip, names := range ipsToNames { + if err := eh.writeEntryWithBanner(dest, banner, ip, names); err != nil { + return err + } + } + + return nil +} + +func (eh *EtcHosts) writeEntryWithBanner(tmp io.Writer, banner, ip string, names []string) error { + if ip != "" && len(names) > 0 { + if eh.Logger != nil { + eh.Logger.Printf("writing entry for %s (%s)", ip, names) + } + if _, err := fmt.Fprintf(tmp, "%s\t%s\t%s\n", ip, strings.Join(names, " "), banner); err != nil { + return fmt.Errorf("error writing entry for %s: %s", ip, err) + } + } + return nil +} + +func (eh *EtcHosts) movePreservePerms(src, dst *os.File) error { + if err := src.Sync(); err != nil { + return fmt.Errorf("could not sync changes to %s: %s", src.Name(), err) + } + + etcHostsInfo, err := dst.Stat() + if err != nil { + return fmt.Errorf("could not stat %s: %s", dst.Name(), err) + } + + if err = os.Rename(src.Name(), dst.Name()); err != nil { + log.Infof("could not rename to %s; falling back to copy (%s)", dst.Name(), err) + + if _, err := src.Seek(0, io.SeekStart); err != nil { + return err + } + if _, err := dst.Seek(0, io.SeekStart); err != nil { + return err + } + if err := dst.Truncate(0); err != nil { + return err + } + _, err = io.Copy(dst, src) + return err + } + + // ensure we're not running with some umask that might break things + + if err := src.Chmod(etcHostsInfo.Mode()); err != nil { + return fmt.Errorf("could not chmod %s: %s", src.Name(), err) + } + // TODO: also keep user? + + return nil +} diff --git a/etchosts/etchosts_test.go b/etchosts/etchosts_test.go new file mode 100644 index 0000000..7bc438f --- /dev/null +++ b/etchosts/etchosts_test.go @@ -0,0 +1,120 @@ +package etchosts + +import ( + "bytes" + "fmt" + "io" + "strings" + "testing" + + log "github.com/sirupsen/logrus" +) + +func TestEtcHosts_writeEntryWithBanner(t *testing.T) { + type args struct { + banner string + ip string + names []string + } + + eh := &EtcHosts{} + + tests := []struct { + name string + args args + wantTmp string + wantErr bool + }{ + {"do not write empty ip", args{DefaultBanner, "", []string{"somename", "someothername"}}, "", false}, + {"do not write empty names", args{DefaultBanner, "1.2.3.4", []string{}}, "", false}, + {"complete entry", args{DefaultBanner, "1.2.3.4", []string{"somename", "someothername"}}, fmt.Sprintf("1.2.3.4\tsomename someothername\t%s\n", DefaultBanner), false}, + {"custom banner", args{"# somebanner", "1.2.3.4", []string{"somename", "someothername"}}, fmt.Sprintf("1.2.3.4\tsomename someothername\t%s\n", "# somebanner"), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmp := &bytes.Buffer{} + if err := eh.writeEntryWithBanner(tmp, tt.args.banner, tt.args.ip, tt.args.names); (err != nil) != tt.wantErr { + t.Errorf("writeEntryWithBanner() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotTmp := tmp.String(); gotTmp != tt.wantTmp { + t.Errorf("writeEntryWithBanner() got:\n%#v, want\n%#v", gotTmp, tt.wantTmp) + } + }) + } +} + +func TestEtcHosts_writeEntries(t *testing.T) { + type fields struct { + Banner string + Path string + Logger log.StdLogger + } + type args struct { + orig io.Reader + ipsToNames map[string][]string + } + tests := []struct { + name string + fields fields + args args + wantDest string + wantErr bool + }{ + { + "simple empty write", + fields{}, + args{strings.NewReader(""), map[string][]string{"1.2.3.4": []string{"foo", "bar"}}}, + "1.2.3.4\tfoo bar\t# ! MANAGED AUTOMATICALLY !\n", + false, + }, + { + "do not touch comments", + fields{}, + args{strings.NewReader("# some comment\n"), map[string][]string{"1.2.3.4": []string{"foo", "bar"}}}, + "# some comment\n1.2.3.4\tfoo bar\t# ! MANAGED AUTOMATICALLY !\n", + false, + }, + { + "do not touch existing entries", + fields{}, + args{strings.NewReader("4.3.2.1 hostname1 hostname2\n"), map[string][]string{"1.2.3.4": []string{"foo", "bar"}}}, + "4.3.2.1 hostname1 hostname2\n1.2.3.4\tfoo bar\t# ! MANAGED AUTOMATICALLY !\n", + false, + }, + { + "remove managed entry not in map", + fields{}, + args{strings.NewReader("4.3.2.1 fooz baarz # ! MANAGED AUTOMATICALLY !\n"), map[string][]string{"1.2.3.4": []string{"foo", "bar"}}}, + "1.2.3.4\tfoo bar\t# ! MANAGED AUTOMATICALLY !\n", + false, + }, + { + "custom banner", + fields{Banner: "# somebanner"}, + args{strings.NewReader(""), map[string][]string{"1.2.3.4": []string{"foo", "bar"}}}, + "1.2.3.4\tfoo bar\t# somebanner\n", + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + eh := &EtcHosts{ + Banner: tt.fields.Banner, + Path: tt.fields.Path, + Logger: tt.fields.Logger, + } + dest := &bytes.Buffer{} + if err := eh.writeEntries(tt.args.orig, dest, tt.args.ipsToNames); (err != nil) != tt.wantErr { + t.Errorf("EtcHosts.writeEntries() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotDest := dest.String(); gotDest != tt.wantDest { + t.Errorf("EtcHosts.writeEntries() = '%#v', want '%#v'", gotDest, tt.wantDest) + } + }) + } +} + +// 1.2.3.4 foo bar # ! MANAGED AUTOMATICALLY ! +// 1.2.3.4 foo bar # ! MANAGED AUTOMATICALLY ! diff --git a/main.go b/main.go index 30e38d9..318e009 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,8 @@ import ( "syscall" "github.com/sirupsen/logrus" + + "github.com/costela/wesher/etchosts" ) func main() { @@ -52,12 +54,31 @@ func main() { if err := wg.upInterface(); err != nil { logrus.Errorf("could not up interface: %s", err) } + if !config.NoEtcHosts { + if err := writeToEtcHosts(nodes); err != nil { + logrus.Errorf("could not write hosts entries: %s", err) + } + } case errs := <-errc: logrus.Errorf("could not receive node info: %s", errs) case <-incomingSigs: logrus.Info("terminating...") cluster.leave() + if err := writeToEtcHosts(nil); err != nil { + logrus.Errorf("could not remove stale hosts entries: %s", err) + } os.Exit(0) } } } + +func writeToEtcHosts(nodes []node) error { + hosts := make(map[string][]string, len(nodes)) + for _, n := range nodes { + hosts[n.OverlayAddr.String()] = []string{n.Name} + } + hostsFile := &etchosts.EtcHosts{ + Logger: logrus.StandardLogger(), + } + return hostsFile.WriteEntries(hosts) +}