diff --git a/dns/dns.go b/dns/dns.go new file mode 100644 index 0000000..f97cb24 --- /dev/null +++ b/dns/dns.go @@ -0,0 +1,115 @@ +package dns + +import ( + "strconv" + + "git.xbudex.com/buddy/update-dns/utils" + "github.com/dnsimple/dnsimple-go/dnsimple" + "github.com/pkg/errors" +) + +// Client holds metadata for the dns api client +type Client struct { + client *dnsimple.Client + accountID string +} + +// Record is the type used for a Zone record +type Record dnsimple.ZoneRecord + +// New returns an implementation of a DNS interface +func New(oauthToken string) *Client { + credentials := dnsimple.NewOauthTokenCredentials(oauthToken) + client := dnsimple.NewClient(credentials) + return &Client{ + client: client, + } +} + +// GetRecords returns all record for a host +func (d *Client) GetRecords(host, kind string) ([]Record, error) { + accountID, err := d.getAccountID() + if err != nil { + return nil, errors.Wrap(err, "unable to get account id") + } + + subdomain, domain, err := utils.SplitHost(host) + if err != nil { + return nil, errors.Wrap(err, "unable to parse host") + } + + records, err := d.client.Zones.ListRecords(accountID, domain, &dnsimple.ZoneRecordListOptions{}) + if err != nil { + return nil, errors.Wrap(err, "could not get list of records") + } + + ret := []Record{} + for _, record := range records.Data { + if record.ZoneID == domain && record.Name == subdomain && record.Type == kind { + ret = append(ret, Record(record)) + } + } + return ret, nil +} + +// CreateRecord adds a record to DNS +func (d *Client) CreateRecord(host string, record Record) error { + accountID, err := d.getAccountID() + if err != nil { + return errors.Wrap(err, "unable to get account id") + } + + subdomain, domain, err := utils.SplitHost(host) + if err != nil { + return errors.Wrap(err, "unable to parse host") + } + + newRecord := dnsimple.ZoneRecord(record) + newRecord.Name = subdomain + if _, err := d.client.Zones.CreateRecord(accountID, domain, newRecord); err != nil { + return errors.Wrap(err, "could not create record") + } + return nil +} + +// UpdateRecord returns all record for a host +func (d *Client) UpdateRecord(record Record) error { + accountID, err := d.getAccountID() + if err != nil { + return errors.Wrap(err, "unable to get account id") + } + + if _, err := d.client.Zones.UpdateRecord(accountID, record.ZoneID, record.ID, dnsimple.ZoneRecord(record)); err != nil { + return errors.Wrap(err, "could not update record") + } + return nil +} + +// DeleteRecord returns all record for a host +func (d *Client) DeleteRecord(record Record) error { + accountID, err := d.getAccountID() + if err != nil { + return errors.Wrap(err, "unable to get account id") + } + + if _, err := d.client.Zones.DeleteRecord(accountID, record.ZoneID, record.ID); err != nil { + return errors.Wrap(err, "could not delete record") + } + return nil +} + +func (d *Client) getAccountID() (string, error) { + if d.accountID != "" { + return d.accountID, nil + } + + whoamiResponse, err := d.client.Identity.Whoami() + if err != nil { + return "", errors.Wrap(err, "could not get account information from dnsimple") + } + if whoamiResponse.Data.Account == nil { + return "", errors.New("could not get account information") + } + d.accountID = strconv.Itoa(whoamiResponse.Data.Account.ID) + return d.accountID, nil +} diff --git a/main.go b/main.go index 740c87c..cdeee23 100644 --- a/main.go +++ b/main.go @@ -5,11 +5,10 @@ import ( "fmt" "net/http" "os" - "strconv" - "strings" "sync" - "github.com/dnsimple/dnsimple-go/dnsimple" + "git.xbudex.com/buddy/update-dns/dns" + "github.com/pkg/errors" "github.com/urfave/cli" ) @@ -18,6 +17,7 @@ func main() { app := cli.NewApp() app.Name = "update-dns" app.Usage = "update dnsimple with public ip address" + app.Version = "0.2.0-dev" app.Action = action app.Flags = []cli.Flag{ @@ -64,22 +64,19 @@ func action(context *cli.Context) error { ip = result.IP }() - var dnSimpleAccountID string - var getRecordError error - var record *dnsimple.ZoneRecord - credentials := dnsimple.NewOauthTokenCredentials(context.String("dnsimple-token")) - dnSimpleClient := dnsimple.NewClient(credentials) + client := dns.New(context.String("dnsimple-token")) + records := []dns.Record{} + var getRecordsError error + wg.Add(1) go func() { defer wg.Done() - - accountID, result, err := getRecord(dnSimpleClient, host, "AAAA") - dnSimpleAccountID = accountID - if err != nil { - getRecordError = errors.Wrap(err, "could not get ip for host") - return + fmt.Println("Getting records") + if results, err := client.GetRecords(host, "AAAA"); err != nil { + getRecordsError = errors.Wrap(err, "could not get records for host") + } else { + records = results } - record = result }() wg.Wait() @@ -87,93 +84,39 @@ func action(context *cli.Context) error { if getIPError != nil { return errors.Wrap(getIPError, "could no get ip address") } - if getRecordError != nil { - switch errors.Cause(getRecordError).(type) { - case notFoundError: - break - default: - return errors.Wrap(getRecordError, "could no get record") - } + + if getRecordsError != nil { + return errors.Wrap(getRecordsError, "could no get record") } - subdomain, domain, _ := splitHost(host) - if record == nil { - _, err := dnSimpleClient.Zones.CreateRecord(dnSimpleAccountID, domain, dnsimple.ZoneRecord{ - Name: subdomain, - Type: "AAAA", - Content: ip, - }) - if err != nil { - return errors.Wrap(err, "could not create record") - } - fmt.Printf("%s record created for %s as %s\n", "AAAA", host, ip) - } else if ip != record.Content { - record.Content = ip - _, err := dnSimpleClient.Zones.UpdateRecord(dnSimpleAccountID, domain, record.ID, *record) - if err != nil { - return errors.Wrap(err, "could not update record") - } - fmt.Printf("%s record for %s updated to %s\n", "AAAA", host, ip) - } else { + if len(records) == 1 && records[0].Content == ip { fmt.Printf("%s record for %s already set as %s\n", "AAAA", host, ip) + return nil } - return nil -} - -type whatIsMyIPResult struct { - IP string `json:ip` - Version string `json:version` -} -type notFoundError error - -func getRecord(client *dnsimple.Client, host string, kind string) (string, *dnsimple.ZoneRecord, error) { - subdomain, domain, err := splitHost(host) - if err != nil { - return "", nil, errors.Wrap(err, "unable to parse host") - } - accountID, err := getAccountID(client) - if err != nil { - return "", nil, errors.Wrap(err, "could not get account id") + if err := client.CreateRecord(host, dns.Record{Type: "AAAA", Content: ip}); err != nil { + return errors.Wrap(err, "could not create record") } - records, err := client.Zones.ListRecords(accountID, domain, &dnsimple.ZoneRecordListOptions{}) - if err != nil { - return "", nil, errors.Wrap(err, "could not get list of records") - } + fmt.Printf("%s record created for %s as %s\n", "AAAA", host, ip) - for _, record := range records.Data { - if record.ZoneID == domain && record.Name == subdomain && record.Type == kind { - return accountID, &record, nil - } + var deleteWG sync.WaitGroup + for _, record := range records { + deleteWG.Add(1) + go func(r dns.Record) { + defer deleteWG.Done() + fmt.Printf("%s record for %s being deleting %s\n", r.Type, host, r.Content) + client.DeleteRecord(r) + }(record) } + deleteWG.Wait() - return accountID, nil, notFoundError(errors.New("could not find record")) -} - -func getAccountID(client *dnsimple.Client) (string, error) { - whoamiResponse, err := client.Identity.Whoami() - if err != nil { - return "", errors.Wrap(err, "could not get account information from dnsimple") - } - if whoamiResponse.Data.Account == nil { - return "", errors.New("could not get account information") - } - return strconv.Itoa(whoamiResponse.Data.Account.ID), nil + return nil } -func splitHost(host string) (subdomain string, domain string, err error) { - subdomains := strings.Split(host, ".") - - if len(subdomains) >= 3 { - domain = strings.Join(subdomains[len(subdomains)-2:], ".") - subdomain = strings.Join(subdomains[:len(subdomains)-2], ".") - } else if len(subdomains) == 2 { - domain = strings.Join(subdomains, ".") - } else { - err = errors.New("invalid domain") - } - return subdomain, domain, err +type whatIsMyIPResult struct { + IP string `json:ip` + Version string `json:version` } func getIP(url string) (*whatIsMyIPResult, error) { diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 0000000..ca3903e --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,25 @@ +package utils + +import ( + "strings" + + "github.com/pkg/errors" +) + +// SplitHost returns hostname into domain and subdomain strings. +// For example: +// "some.subdomain.example.com" → ("some.subdomain", "example.com", nil) +// "example.com" → ("", "example.com", nil) +func SplitHost(host string) (subdomain string, domain string, err error) { + subdomains := strings.Split(host, ".") + + if len(subdomains) >= 3 { + domain = strings.Join(subdomains[len(subdomains)-2:], ".") + subdomain = strings.Join(subdomains[:len(subdomains)-2], ".") + } else if len(subdomains) == 2 { + domain = strings.Join(subdomains, ".") + } else { + err = errors.New("invalid domain") + } + return subdomain, domain, err +}