From 557140110343a55545185f3e41e0072e79a17357 Mon Sep 17 00:00:00 2001 From: Christopher Ramey Date: Mon, 1 Mar 2021 07:26:27 -0900 Subject: [PATCH] added hmac validation to http based api --- config/config.go | 6 ++-- config/parser.go | 2 +- server/apicommand.go | 72 ++++++++++++++++++++++++++++++++++++++++++++ server/http.go | 40 +++++++++++++++++++----- server/server.go | 10 +++--- 5 files changed, 114 insertions(+), 16 deletions(-) create mode 100644 server/apicommand.go diff --git a/config/config.go b/config/config.go index 71b47e7..1ab5d4b 100644 --- a/config/config.go +++ b/config/config.go @@ -14,7 +14,7 @@ type Config struct { DebugLevel int Listen string Path string - APIKey string + APIKey []byte APIKeyFile string } @@ -77,12 +77,12 @@ func ReadConfig(fn string, debuglvl int) (*Config, error) { return nil, err } - if cfg.APIKey == "" { + if len(cfg.APIKey) == 0 { b, err := os.ReadFile(cfg.APIKeyFile) if err != nil { return nil, err } - cfg.APIKey = string(b) + cfg.APIKey = b } return cfg, nil diff --git a/config/parser.go b/config/parser.go index 133348b..8040024 100644 --- a/config/parser.go +++ b/config/parser.go @@ -71,7 +71,7 @@ func (p *parser) parse() error { case "listen": p.config.Listen = value case "api.key": - p.config.APIKey = value + p.config.APIKey = []byte(value) case "api.keyfile": p.config.APIKeyFile = value default: diff --git a/server/apicommand.go b/server/apicommand.go new file mode 100644 index 0000000..009d796 --- /dev/null +++ b/server/apicommand.go @@ -0,0 +1,72 @@ +package server + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/json" + "fmt" + "time" +) + +type APICommand struct { + Expires time.Time `json:"exp"` + Command string `json:"cmd"` + Scheme string `json:"sch"` + Signature []byte `json:"sig,omitempty"` +} + +func ParseAPICommand(jsn []byte) (*APICommand, error) { + api := &APICommand{} + err := json.Unmarshal(jsn, api) + if err != nil { + return nil, err + } + return api, nil +} + +func (ac *APICommand) JSON() ([]byte, error) { + return json.Marshal(ac) +} + +func (ac *APICommand) Sign(key []byte) error { + switch ac.Scheme { + case "hmac-sha256": + j, err := ac.JSON() + if err != nil { + return fmt.Errorf("json encoding error") + } + + mac := hmac.New(sha256.New, key) + mac.Write(j) + ac.Signature = mac.Sum(nil) + + case "": + return fmt.Errorf("scheme may not be empty") + + default: + return fmt.Errorf("unsupported scheme: %s", ac.Scheme) + } + return nil +} + +func (ac *APICommand) Validate(key []byte) error { + cpy := &APICommand{ + Expires: ac.Expires, + Command: ac.Command, + Scheme: ac.Scheme, + } + err := cpy.Sign(key) + if err != nil { + return err + } + + if !hmac.Equal(cpy.Signature, ac.Signature) { + return fmt.Errorf("invalid signature") + } + + if time.Now().After(ac.Expires) { + return fmt.Errorf("command expired") + } + + return nil +} diff --git a/server/http.go b/server/http.go index 2fcb110..b2af9f0 100644 --- a/server/http.go +++ b/server/http.go @@ -7,13 +7,39 @@ import ( func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { - case "/shutdown": - fmt.Fprintf(w, "shutting down .. ") - s.shutdownc <- false - case "/restart": - fmt.Fprintf(w, "restarting .. ") - s.shutdownc <- true + case "/api": + g := r.URL.Query() + c := g.Get("cmd") + if c == "" { + http.Error(w, "no command given", http.StatusBadRequest) + return + } + cmd, err := ParseAPICommand([]byte(c)) + if err != nil { + http.Error(w, fmt.Sprintf("error parsing command: %s", err.Error()), + http.StatusBadRequest) + return + } + + err = cmd.Validate(s.config.APIKey) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + switch cmd.Command { + case "shutdown": + s.shutdownc <- false + + case "restart": + s.shutdownc <- true + + default: + http.Error(w, fmt.Sprintf("unknown command: %s", cmd.Command), + http.StatusBadRequest) + } + default: - fmt.Fprintf(w, "Hello, world!") + http.Error(w, "File not found", http.StatusNotFound) } } diff --git a/server/server.go b/server/server.go index 8bde9d3..cf33688 100644 --- a/server/server.go +++ b/server/server.go @@ -11,13 +11,13 @@ import ( type Server struct { workers []*worker - cfg *config.Config + config *config.Config shutdownc chan bool http http.Server } func (srv *Server) Start() (bool, error) { - listen, err := net.Listen("tcp", srv.cfg.Listen) + listen, err := net.Listen("tcp", srv.config.Listen) if err != nil { return false, err } @@ -29,12 +29,12 @@ func (srv *Server) Start() (bool, error) { srv.http = http.Server{Handler: srv} go srv.http.Serve(listen) - t := time.NewTicker(srv.cfg.Interval) + t := time.NewTicker(srv.config.Interval) defer t.Stop() for { select { case r := <-t.C: - if srv.cfg.DebugLevel > 0 { + if srv.config.DebugLevel > 0 { fmt.Printf("interval check at %s\n", r) } for _, w := range srv.workers { @@ -52,7 +52,7 @@ func (srv *Server) Start() (bool, error) { func NewServer(cfg *config.Config) *Server { srv := &Server{ - cfg: cfg, + config: cfg, shutdownc: make(chan bool, 1), } for _, g := range cfg.Groups {