package main import ( "bytes" "html/template" "io" "net" "net/http" "strings" "github.com/pkg/errors" ) // Server that handles http responses type Server struct { headerNames []string logger LoggerHandler tmpl *template.Template } type responseFormat int const ( unknownResponse responseFormat = iota jsonResponse textResponse htmlResponse ) func (s *Server) ServeHTTP(resp http.ResponseWriter, req *http.Request) { s.handleHTTP(resp, req, s.getResponseType(req)) } func (s *Server) handleHTTP(resp http.ResponseWriter, req *http.Request, responseType responseFormat) { for _, headerName := range s.headerNames { possibleIP := req.Header.Get(headerName) if possibleIP == "" { continue } ip, err := NewIP(possibleIP) if err != nil { continue } s.sendResponse(resp, ip, responseType) return } addr, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { s.handleError(resp, http.StatusBadRequest, err, "Could not get IP address from request") return } ip, err := NewIP(addr) if err != nil { s.handleError(resp, http.StatusBadRequest, err, "Could not parse IP address"+addr) return } s.sendResponse(resp, ip, responseType) } func hasType(accepts, valid []string) bool { for _, accept := range accepts { value := strings.Trim(strings.Split(accept, ";")[0], " ") for _, check := range valid { if check == value { return true } } } return false } func (s *Server) getResponseType(req *http.Request) responseFormat { accepts := strings.Split(req.Header.Get("Accept"), ",") if hasType(accepts, []string{"text/html", "application/xhtml+xml"}) { return htmlResponse } if hasType(accepts, []string{"application/json", "text/json", "text/javascript"}) { return jsonResponse } if hasType(accepts, []string{"text/plain"}) { return textResponse } return unknownResponse } func (s *Server) handleHTTPAs(responseType responseFormat) func(http.ResponseWriter, *http.Request) { return func(resp http.ResponseWriter, req *http.Request) { s.handleHTTP(resp, req, responseType) } } func (s *Server) sendResponse(resp http.ResponseWriter, ip *IP, responseType responseFormat) { s.logger.Printf("Request from %s %s\n", ip.version, ip) var body []byte var contentType = "text/plain; charset=utf-8" var err error switch responseType { case jsonResponse: jsonBody, marshalErr := ip.MarshalJSON() if marshalErr != nil { err = errors.Wrap(marshalErr, "could not marshal json") } contentType = "application/json; charset=utf-8" body = jsonBody case htmlResponse: buffer := bytes.NewBuffer([]byte{}) exeErr := s.tmpl.Execute(buffer, map[string]string{"IP": ip.String()}) if exeErr != nil { err = errors.Wrap(exeErr, "could not get html") } contentType = "text/html; charset=utf-8" body = buffer.Bytes() default: body = []byte(ip.String()) } if err != nil { const errorMessage = "could not send response" s.handleError(resp, http.StatusInternalServerError, errors.Wrap(err, errorMessage), errorMessage) return } resp.Header().Set("Content-Type", contentType) resp.Write(body) } func (s *Server) handleError(resp http.ResponseWriter, status int, err error, message string) { resp.WriteHeader(status) io.WriteString(resp, message) s.logger.Printf("Error handling request: %s (%s)", message, err) }