diff --git a/config/module/registry.go b/config/module/registry.go index 10209c4bf..da67c5ab9 100644 --- a/config/module/registry.go +++ b/config/module/registry.go @@ -44,8 +44,8 @@ func (e errModuleNotFound) Error() string { return `module "` + string(e) + `" not found` } -func (s *Storage) discoverRegURL(module *regsrc.Module) *url.URL { - regURL := s.Services.DiscoverServiceURL(svchost.Hostname(module.RawHost.Normalized()), serviceID) +func (s *Storage) discoverRegURL(host svchost.Hostname) *url.URL { + regURL := s.Services.DiscoverServiceURL(host, serviceID) if regURL == nil { return nil } @@ -75,13 +75,14 @@ func (s *Storage) addRequestCreds(host svchost.Hostname, req *http.Request) { // Lookup module versions in the registry. func (s *Storage) lookupModuleVersions(module *regsrc.Module) (*response.ModuleVersions, error) { - if module.RawHost == nil { - module.RawHost = regsrc.NewFriendlyHost(defaultRegistry) + host, err := module.SvcHost() + if err != nil { + return nil, err } - service := s.discoverRegURL(module) + service := s.discoverRegURL(host) if service == nil { - return nil, fmt.Errorf("host %s does not provide Terraform modules", module.RawHost.Display()) + return nil, fmt.Errorf("host %s does not provide Terraform modules", host) } p, err := url.Parse(path.Join(module.Module(), "versions")) @@ -98,7 +99,7 @@ func (s *Storage) lookupModuleVersions(module *regsrc.Module) (*response.ModuleV return nil, err } - s.addRequestCreds(svchost.Hostname(module.RawHost.Normalized()), req) + s.addRequestCreds(host, req) req.Header.Set(xTerraformVersion, tfVersion) resp, err := httpClient.Do(req) @@ -134,17 +135,17 @@ func (s *Storage) lookupModuleVersions(module *regsrc.Module) (*response.ModuleV // lookup the location of a specific module version in the registry func (s *Storage) lookupModuleLocation(module *regsrc.Module, version string) (string, error) { - if module.RawHost == nil { - module.RawHost = regsrc.NewFriendlyHost(defaultRegistry) + host, err := module.SvcHost() + if err != nil { + return "", err } - service := s.discoverRegURL(module) + service := s.discoverRegURL(host) if service == nil { - return "", fmt.Errorf("host %s does not provide Terraform modules", module.RawHost.Display()) + return "", fmt.Errorf("host %s does not provide Terraform modules", host.ForDisplay()) } var p *url.URL - var err error if version == "" { p, err = url.Parse(path.Join(module.Module(), "download")) } else { @@ -162,7 +163,7 @@ func (s *Storage) lookupModuleLocation(module *regsrc.Module, version string) (s return "", err } - s.addRequestCreds(svchost.Hostname(module.RawHost.Normalized()), req) + s.addRequestCreds(host, req) req.Header.Set(xTerraformVersion, tfVersion) resp, err := httpClient.Do(req) diff --git a/config/module/registry_test.go b/config/module/registry_test.go index 54c8f818e..dab7444c2 100644 --- a/config/module/registry_test.go +++ b/config/module/registry_test.go @@ -2,6 +2,7 @@ package module import ( "os" + "strings" "testing" version "github.com/hashicorp/go-version" @@ -93,6 +94,7 @@ func TestRegistryAuth(t *testing.T) { } } + func TestLookupModuleLocationRelative(t *testing.T) { server := mockRegistry() defer server.Close() @@ -117,6 +119,7 @@ func TestLookupModuleLocationRelative(t *testing.T) { } } + func TestAccLookupModuleVersions(t *testing.T) { if os.Getenv("TF_ACC") == "" { t.Skip() @@ -163,3 +166,29 @@ func TestAccLookupModuleVersions(t *testing.T) { } } } + +// the error should reference the config source exatly, not the discovered path. +func TestLookupLookupModuleError(t *testing.T) { + server := mockRegistry() + defer server.Close() + + regDisco := testDisco(server) + storage := testStorage(t, regDisco) + + // this should not be found in teh registry + src := "bad/local/path" + mod, err := regsrc.ParseModuleSource(src) + if err != nil { + t.Fatal(err) + } + + _, err = storage.lookupModuleLocation(mod, "0.2.0") + if err == nil { + t.Fatal("expected error") + } + + // check for the exact quoted string to ensure we didn't prepend a hostname. + if !strings.Contains(err.Error(), `"bad/local/path"`) { + t.Fatal("error should not include the hostname. got:", err) + } +} diff --git a/config/module/storage.go b/config/module/storage.go index 05065b3c6..121719765 100644 --- a/config/module/storage.go +++ b/config/module/storage.go @@ -343,7 +343,9 @@ func (s Storage) findRegistryModule(mSource, constraint string) (moduleRecord, e return rec, err } - s.output(fmt.Sprintf(" Found version %s of %s on %s", rec.Version, mod.Module(), mod.RawHost.Display())) + // we've already validated this by now + host, _ := mod.SvcHost() + s.output(fmt.Sprintf(" Found version %s of %s on %s", rec.Version, mod.Module(), host.ForDisplay())) } return rec, nil diff --git a/registry/regsrc/friendly_host.go b/registry/regsrc/friendly_host.go index 648e2a193..14b4dce9c 100644 --- a/registry/regsrc/friendly_host.go +++ b/registry/regsrc/friendly_host.go @@ -101,20 +101,16 @@ func (h *FriendlyHost) Valid() bool { // Display returns the host formatted for display to the user in CLI or web // output. func (h *FriendlyHost) Display() string { - hostname, err := svchost.ForComparison(h.Raw) - if err != nil { - return InvalidHostString - } - return hostname.ForDisplay() + return svchost.ForDisplay(h.Raw) } // Normalized returns the host formatted for internal reference or comparison. func (h *FriendlyHost) Normalized() string { - hostname, err := svchost.ForComparison(h.Raw) + host, err := svchost.ForComparison(h.Raw) if err != nil { return InvalidHostString } - return hostname.String() + return string(host) } // String returns the host formatted as the user originally typed it assuming it @@ -124,19 +120,21 @@ func (h *FriendlyHost) String() string { } // Equal compares the FriendlyHost against another instance taking normalization -// into account. +// into account. Invalid hosts cannot be compared and will always return false. func (h *FriendlyHost) Equal(other *FriendlyHost) bool { if other == nil { return false } - return h.Normalized() == other.Normalized() -} -func containsPuny(host string) bool { - for _, lbl := range strings.Split(host, ".") { - if strings.HasPrefix(strings.ToLower(lbl), "xn--") { - return true - } + otherHost, err := svchost.ForComparison(other.Raw) + if err != nil { + return false } - return false + + host, err := svchost.ForComparison(h.Raw) + if err != nil { + return false + } + + return otherHost == host } diff --git a/registry/regsrc/friendly_host_test.go b/registry/regsrc/friendly_host_test.go index 740395bf6..37589685d 100644 --- a/registry/regsrc/friendly_host_test.go +++ b/registry/regsrc/friendly_host_test.go @@ -59,7 +59,7 @@ func TestFriendlyHost(t *testing.T) { source: "xn--s-fka0wmm0zea7g8b.xn--o-8ta85a3b1dwcda1k.io", wantHost: "xn--s-fka0wmm0zea7g8b.xn--o-8ta85a3b1dwcda1k.io", wantDisplay: "ʎɹʇsıƃǝɹ.ɯɹoɟɐɹɹǝʇ.io", - wantNorm: "xn--s-fka0wmm0zea7g8b.xn--o-8ta85a3b1dwcda1k.io", + wantNorm: InvalidHostString, wantValid: false, }, { @@ -95,38 +95,47 @@ func TestFriendlyHost(t *testing.T) { if v := gotHost.String(); v != tt.wantHost { t.Fatalf("String() = %v, want %v", v, tt.wantHost) } - if v := gotHost.Valid(); v != tt.wantValid { - t.Fatalf("Valid() = %v, want %v", v, tt.wantValid) - } - - // FIXME: should we allow punycode as input - if !tt.wantValid { - return - } - if v := gotHost.Display(); v != tt.wantDisplay { t.Fatalf("Display() = %v, want %v", v, tt.wantDisplay) } if v := gotHost.Normalized(); v != tt.wantNorm { t.Fatalf("Normalized() = %v, want %v", v, tt.wantNorm) } + if v := gotHost.Valid(); v != tt.wantValid { + t.Fatalf("Valid() = %v, want %v", v, tt.wantValid) + } if gotRest != strings.TrimLeft(sfx, "/") { t.Fatalf("ParseFriendlyHost() rest = %v, want %v", gotRest, strings.TrimLeft(sfx, "/")) } // Also verify that host compares equal with all the variants. - if !gotHost.Equal(&FriendlyHost{Raw: tt.wantDisplay}) { - t.Fatalf("Equal() should be true for %s and %t", tt.wantHost, tt.wantValid) + if gotHost.Valid() && !gotHost.Equal(&FriendlyHost{Raw: tt.wantDisplay}) { + t.Fatalf("Equal() should be true for %s and %s", tt.wantHost, tt.wantDisplay) } - - // FIXME: Do we need to accept normalized input? - //if !gotHost.Equal(&FriendlyHost{Raw: tt.wantNorm}) { - // fmt.Println(gotHost.Normalized(), tt.wantNorm) - // fmt.Println(" ", (&FriendlyHost{Raw: tt.wantNorm}).Normalized()) - // t.Fatalf("Equal() should be true for %s and %s", tt.wantHost, tt.wantNorm) - //} - }) } } } + +func TestInvalidHostEquals(t *testing.T) { + invalid := NewFriendlyHost("NOT_A_HOST_NAME") + valid := PublicRegistryHost + + // invalid hosts are not comparable + if invalid.Equal(invalid) { + t.Fatal("invalid host names are not comparable") + } + + if valid.Equal(invalid) { + t.Fatalf("%q is not equal to %q", valid, invalid) + } + + puny := NewFriendlyHost("xn--s-fka0wmm0zea7g8b.xn--o-8ta85a3b1dwcda1k.io") + display := NewFriendlyHost("ʎɹʇsıƃǝɹ.ɯɹoɟɐɹɹǝʇ.io") + + // The pre-normalized host is not a valid source, and therefore not + // comparable to the display version. + if display.Equal(puny) { + t.Fatalf("invalid host %q should not be comparable", puny) + } +} diff --git a/registry/regsrc/module.go b/registry/regsrc/module.go index b6671c8a4..325706ec2 100644 --- a/registry/regsrc/module.go +++ b/registry/regsrc/module.go @@ -5,6 +5,8 @@ import ( "fmt" "regexp" "strings" + + "github.com/hashicorp/terraform/svchost" ) var ( @@ -33,8 +35,16 @@ var ( fmt.Sprintf("^(%s)\\/(%s)\\/(%s)(?:\\/\\/(.*))?$", nameSubRe, nameSubRe, providerSubRe)) - // disallowed is a set of hostnames that have special usage in modules and - // can't be registry hosts + // NameRe is a regular expression defining the format allowed for namespace + // or name fields in module registry implementations. + NameRe = regexp.MustCompile("^" + nameSubRe + "$") + + // ProviderRe is a regular expression defining the format allowed for + // provider fields in module registry implementations. + ProviderRe = regexp.MustCompile("^" + providerSubRe + "$") + + // these hostnames are not allowed as registry sources, because they are + // already special case module sources in terraform. disallowed = map[string]bool{ "github.com": true, "bitbucket.org": true, @@ -59,7 +69,7 @@ type Module struct { // NewModule construct a new module source from separate parts. Pass empty // string if host or submodule are not needed. -func NewModule(host, namespace, name, provider, submodule string) *Module { +func NewModule(host, namespace, name, provider, submodule string) (*Module, error) { m := &Module{ RawNamespace: namespace, RawName: name, @@ -67,9 +77,16 @@ func NewModule(host, namespace, name, provider, submodule string) *Module { RawSubmodule: submodule, } if host != "" { - m.RawHost = NewFriendlyHost(host) + h := NewFriendlyHost(host) + if h != nil { + fmt.Println("HOST:", h) + if !h.Valid() || disallowed[h.Display()] { + return nil, ErrInvalidModuleSource + } + } + m.RawHost = h } - return m + return m, nil } // ParseModuleSource attempts to parse source as a Terraform registry module @@ -132,12 +149,6 @@ func (m *Module) String() string { return m.formatWithPrefix(hostPrefix, true) } -// Module returns just the registry ID of the module, without a hostname or -// suffix. -func (m *Module) Module() string { - return fmt.Sprintf("%s/%s/%s", m.RawNamespace, m.RawName, m.RawProvider) -} - // Equal compares the module source against another instance taking // normalization into account. func (m *Module) Equal(other *Module) bool { @@ -175,3 +186,20 @@ func (m *Module) formatWithPrefix(hostPrefix string, preserveCase bool) string { } return str } + +// Module returns just the registry ID of the module, without a hostname or +// suffix. +func (m *Module) Module() string { + return fmt.Sprintf("%s/%s/%s", m.RawNamespace, m.RawName, m.RawProvider) +} + +// SvcHost returns the svchost.Hostname for this module. Since FriendlyHost may +// contain an invalid hostname, this also returns an error indicating if it +// could be converted to a svchost.Hostname. If no host is specified, the +// default PublicRegistryHost is returned. +func (m *Module) SvcHost() (svchost.Hostname, error) { + if m.RawHost == nil { + return svchost.ForComparison(PublicRegistryHost.Raw) + } + return svchost.ForComparison(m.RawHost.Raw) +}