Clean up the DatabaseCreate call.

This commit is contained in:
Sean Chittenden 2016-11-07 10:32:51 -08:00
parent db5d7b0438
commit e2448473cb
No known key found for this signature in database
GPG Key ID: 4EBC9DC16C2E5E16
3 changed files with 89 additions and 139 deletions

View File

@ -0,0 +1,24 @@
package postgresql
import (
"fmt"
"strings"
)
// pqQuoteLiteral returns a string literal safe for inclusion in a PostgreSQL
// query as a parameter. The resulting string still needs to be wrapped in
// single quotes in SQL (i.e. fmt.Sprintf(`'%s'`, pqQuoteLiteral("str"))). See
// quote_literal_internal() in postgresql/backend/utils/adt/quote.c:77.
func pqQuoteLiteral(in string) string {
in = strings.Replace(in, `\`, `\\`, -1)
in = strings.Replace(in, `'`, `''`, -1)
return in
}
func validateConnLimit(v interface{}, key string) (warnings []string, errors []error) {
value := v.(int)
if value < -1 {
errors = append(errors, fmt.Errorf("%d can not be less than -1", key))
}
return
}

View File

@ -1,6 +1,7 @@
package postgresql
import (
"bytes"
"database/sql"
"errors"
"fmt"
@ -112,134 +113,70 @@ func resourcePostgreSQLDatabaseCreate(d *schema.ResourceData, meta interface{})
}
defer conn.Close()
stringOpts := []struct {
hclKey string
sqlKey string
}{
{dbOwnerAttr, "OWNER"},
{dbTemplateAttr, "TEMPLATE"},
{dbEncodingAttr, "ENCODING"},
{dbCollationAttr, "LC_COLLATE"},
{dbCTypeAttr, "LC_CTYPE"},
{dbTablespaceAttr, "TABLESPACE"},
}
intOpts := []struct {
hclKey string
sqlKey string
}{
{dbConnLimitAttr, "CONNECTION LIMIT"},
}
boolOpts := []struct {
hclKey string
sqlKey string
}{
{dbAllowConnsAttr, "ALLOW_CONNECTIONS"},
{dbIsTemplateAttr, "IS_TEMPLATE"},
}
createOpts := make([]string, 0, len(stringOpts)+len(intOpts)+len(boolOpts))
for _, opt := range stringOpts {
v, ok := d.GetOk(opt.hclKey)
var val string
if !ok {
switch {
case opt.hclKey == dbOwnerAttr && v.(string) == "":
// No owner specified in the config, default to using
// the connecting username.
val = c.username
case strings.ToUpper(v.(string)) == "DEFAULT" &&
(opt.hclKey == dbTemplateAttr ||
opt.hclKey == dbEncodingAttr ||
opt.hclKey == dbCollationAttr ||
opt.hclKey == dbCTypeAttr):
// Use the defaults from the template database
// as opposed to best practices.
fallthrough
default:
continue
}
}
val = v.(string)
switch {
case opt.hclKey == dbOwnerAttr && (val == "" || strings.ToUpper(val) == "DEFAULT"):
// Owner was blank/DEFAULT, default to using the connecting username.
val = c.username
d.Set(dbOwnerAttr, val)
case opt.hclKey == dbTablespaceAttr && (val == "" || strings.ToUpper(val) == "DEFAULT"):
val = "pg_default"
d.Set(dbTablespaceAttr, val)
case opt.hclKey == dbTemplateAttr:
switch {
case val == "":
val = "template0"
d.Set(dbTemplateAttr, val)
case strings.ToUpper(val) == "DEFAULT":
val = ""
default:
d.Set(dbTemplateAttr, val)
}
case opt.hclKey == dbEncodingAttr:
switch {
case val == "":
val = "UTF8"
d.Set(dbEncodingAttr, val)
case strings.ToUpper(val) == "DEFAULT":
val = ""
default:
d.Set(dbEncodingAttr, val)
}
case opt.hclKey == dbCollationAttr:
switch {
case val == "":
val = "C"
d.Set(dbCollationAttr, val)
case strings.ToUpper(val) == "DEFAULT":
val = ""
default:
d.Set(dbCollationAttr, val)
}
case opt.hclKey == dbCTypeAttr:
switch {
case val == "":
val = "C"
d.Set(dbCTypeAttr, val)
case strings.ToUpper(val) == "DEFAULT":
val = ""
default:
d.Set(dbCTypeAttr, val)
}
}
if val != "" {
createOpts = append(createOpts, fmt.Sprintf("%s=%s", opt.sqlKey, pq.QuoteIdentifier(val)))
}
}
for _, opt := range intOpts {
val := d.Get(opt.hclKey).(int)
createOpts = append(createOpts, fmt.Sprintf("%s=%d", opt.sqlKey, val))
}
for _, opt := range boolOpts {
val := d.Get(opt.hclKey).(bool)
valStr := "FALSE"
if val {
valStr = "TRUE"
}
createOpts = append(createOpts, fmt.Sprintf("%s=%s", opt.sqlKey, valStr))
}
dbName := d.Get(dbNameAttr).(string)
createStr := strings.Join(createOpts, " ")
if len(createOpts) > 0 {
createStr = " WITH " + createStr
b := bytes.NewBufferString("CREATE DATABASE ")
fmt.Fprint(b, pq.QuoteIdentifier(dbName))
// Handle each option individually and stream results into the query
// buffer.
switch v, ok := d.GetOk(dbOwnerAttr); {
case ok:
fmt.Fprint(b, " OWNER ", pq.QuoteIdentifier(v.(string)))
default:
// No owner specified in the config, default to using
// the connecting username.
fmt.Fprint(b, " OWNER ", pq.QuoteIdentifier(c.username))
}
query := fmt.Sprintf("CREATE DATABASE %s%s", pq.QuoteIdentifier(dbName), createStr)
switch v, ok := d.GetOk(dbTemplateAttr); {
case ok:
fmt.Fprint(b, " TEMPLATE ", pq.QuoteIdentifier(v.(string)))
case v.(string) == "", strings.ToUpper(v.(string)) != "DEFAULT":
fmt.Fprint(b, " TEMPLATE template0")
}
switch v, ok := d.GetOk(dbEncodingAttr); {
case ok:
fmt.Fprint(b, " ENCODING ", pq.QuoteIdentifier(v.(string)))
case v.(string) == "", strings.ToUpper(v.(string)) != "DEFAULT":
fmt.Fprint(b, ` ENCODING "UTF8"`)
}
switch v, ok := d.GetOk(dbCollationAttr); {
case ok:
fmt.Fprint(b, " LC_COLLATE ", pq.QuoteIdentifier(v.(string)))
case v.(string) == "", strings.ToUpper(v.(string)) != "DEFAULT":
fmt.Fprint(b, ` LC_COLLATE "C"`)
}
switch v, ok := d.GetOk(dbCTypeAttr); {
case ok:
fmt.Fprint(b, " LC_CTYPE ", pq.QuoteIdentifier(v.(string)))
case v.(string) == "", strings.ToUpper(v.(string)) != "DEFAULT":
fmt.Fprint(b, ` LC_CTYPE "C"`)
}
if v, ok := d.GetOk(dbTablespaceAttr); ok {
fmt.Fprint(b, " TABLESPACE ", pq.QuoteIdentifier(v.(string)))
}
{
val := d.Get(dbAllowConnsAttr).(bool)
fmt.Fprint(b, " ALLOW_CONNECTIONS ", val)
}
{
val := d.Get(dbConnLimitAttr).(int)
fmt.Fprint(b, " CONNECTION LIMIT ", val)
}
{
val := d.Get(dbIsTemplateAttr).(bool)
fmt.Fprint(b, " IS_TEMPLATE ", val)
}
query := b.String()
_, err = conn.Query(query)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("Error creating database %s: {{err}}", dbName), err)
@ -296,7 +233,7 @@ func resourcePostgreSQLDatabaseRead(d *schema.ResourceData, meta interface{}) er
err = conn.QueryRow("SELECT d.datname, pg_catalog.pg_get_userbyid(d.datdba) from pg_database d WHERE datname=$1", dbId).Scan(&dbName, &ownerName)
switch {
case err == sql.ErrNoRows:
log.Printf("[WARN] PostgreSQL database (%s) not found", d.Id())
log.Printf("[WARN] PostgreSQL database (%s) not found", dbId)
d.SetId("")
return nil
case err != nil:
@ -313,7 +250,7 @@ func resourcePostgreSQLDatabaseRead(d *schema.ResourceData, meta interface{}) er
)
switch {
case err == sql.ErrNoRows:
log.Printf("[WARN] PostgreSQL database (%s) not found", d.Id())
log.Printf("[WARN] PostgreSQL database (%s) not found", dbId)
d.SetId("")
return nil
case err != nil:

View File

@ -1,11 +0,0 @@
package postgresql
import "fmt"
func validateConnLimit(v interface{}, key string) (warnings []string, errors []error) {
value := v.(int)
if value < -1 {
errors = append(errors, fmt.Errorf("%d can not be less than -1", key))
}
return
}