diff --git a/addrs/provider.go b/addrs/provider.go index 0514ed6c6..01284300f 100644 --- a/addrs/provider.go +++ b/addrs/provider.go @@ -126,6 +126,25 @@ func (pt Provider) LessThan(other Provider) bool { } } +// IsLegacy returns true if the provider is a legacy-style provider +func (pt Provider) IsLegacy() bool { + if pt.IsZero() { + panic("called IsLegacy() on zero-value addrs.Provider") + } + + return pt.Hostname == DefaultRegistryHost && pt.Namespace == LegacyProviderNamespace + +} + +// IsDefault returns true if the provider is a default hashicorp provider +func (pt Provider) IsDefault() bool { + if pt.IsZero() { + panic("called IsDefault() on zero-value addrs.Provider") + } + + return pt.Hostname == DefaultRegistryHost && pt.Namespace == "hashicorp" +} + // Equals returns true if the receiver and other provider have the same attributes. func (pt Provider) Equals(other Provider) bool { return pt == other diff --git a/addrs/provider_test.go b/addrs/provider_test.go index cf97bb0c3..fe5344a18 100644 --- a/addrs/provider_test.go +++ b/addrs/provider_test.go @@ -7,6 +7,84 @@ import ( svchost "github.com/hashicorp/terraform-svchost" ) +func TestProviderIsDefault(t *testing.T) { + tests := []struct { + Input Provider + Want bool + }{ + { + Provider{ + Type: "test", + Hostname: DefaultRegistryHost, + Namespace: "hashicorp", + }, + true, + }, + { + Provider{ + Type: "test", + Hostname: "registry.terraform.com", + Namespace: "hashicorp", + }, + false, + }, + { + Provider{ + Type: "test", + Hostname: DefaultRegistryHost, + Namespace: "othercorp", + }, + false, + }, + } + + for _, test := range tests { + got := test.Input.IsDefault() + if got != test.Want { + t.Errorf("wrong result for %s\n", test.Input.String()) + } + } +} + +func TestProviderIsLegacy(t *testing.T) { + tests := []struct { + Input Provider + Want bool + }{ + { + Provider{ + Type: "test", + Hostname: DefaultRegistryHost, + Namespace: LegacyProviderNamespace, + }, + true, + }, + { + Provider{ + Type: "test", + Hostname: "registry.terraform.com", + Namespace: LegacyProviderNamespace, + }, + false, + }, + { + Provider{ + Type: "test", + Hostname: DefaultRegistryHost, + Namespace: "hashicorp", + }, + false, + }, + } + + for _, test := range tests { + got := test.Input.IsLegacy() + if got != test.Want { + t.Errorf("wrong result for %s\n", test.Input.String()) + } + } +} + func TestParseProviderSourceStr(t *testing.T) { tests := map[string]struct { Want Provider diff --git a/configs/config_test.go b/configs/config_test.go index ec85208cd..61149920f 100644 --- a/configs/config_test.go +++ b/configs/config_test.go @@ -9,12 +9,18 @@ import ( ) func TestConfigProviderTypes(t *testing.T) { + // nil cfg should return an empty map + got := NewEmptyConfig().ProviderTypes() + if len(got) != 0 { + t.Fatal("expected empty result from empty config") + } + cfg, diags := testModuleConfigFromFile("testdata/valid-files/providers-explicit-implied.tf") if diags.HasErrors() { t.Fatal(diags.Error()) } - got := cfg.ProviderTypes() + got = cfg.ProviderTypes() want := []addrs.Provider{ addrs.NewLegacyProvider("aws"), addrs.NewLegacyProvider("null"), diff --git a/configs/module_merge_test.go b/configs/module_merge_test.go index d4cb3f5a7..991bb9fad 100644 --- a/configs/module_merge_test.go +++ b/configs/module_merge_test.go @@ -297,6 +297,27 @@ func TestMergeProviderVersionConstraints(t *testing.T) { }, }, }, + "merge with source constraint": { + map[string]ProviderRequirements{ + "random": ProviderRequirements{ + Type: addrs.Provider{Type: "random"}, + VersionConstraints: []VersionConstraint{vc1}, + }, + }, + []*RequiredProvider{ + &RequiredProvider{ + Name: "random", + Source: Source{SourceStr: "hashicorp/random"}, + Requirement: vc2, + }, + }, + map[string]ProviderRequirements{ + "random": ProviderRequirements{ + Type: addrs.NewDefaultProvider("random"), + VersionConstraints: []VersionConstraint{vc2}, + }, + }, + }, } for name, test := range tests { diff --git a/configs/module_test.go b/configs/module_test.go index aab08a4a2..81d05ea35 100644 --- a/configs/module_test.go +++ b/configs/module_test.go @@ -27,6 +27,12 @@ func TestNewModule_provider_local_name(t *testing.T) { if localName != "foo-test" { t.Fatal("provider local name not found") } + + // if there is not a local name for a provider, it should return the type name + localName = mod.LocalNameForProvider(addrs.NewLegacyProvider("nonexist")) + if localName != "nonexist" { + t.Error("wrong local name returned for a non-local provider") + } } // This test validates the provider FQNs set in each Resource