added hmac validation to http based api

This commit is contained in:
Christopher Ramey 2021-03-01 07:26:27 -09:00
parent 5c50f4ac95
commit 5571401103
5 changed files with 114 additions and 16 deletions

View File

@ -14,7 +14,7 @@ type Config struct {
DebugLevel int DebugLevel int
Listen string Listen string
Path string Path string
APIKey string APIKey []byte
APIKeyFile string APIKeyFile string
} }
@ -77,12 +77,12 @@ func ReadConfig(fn string, debuglvl int) (*Config, error) {
return nil, err return nil, err
} }
if cfg.APIKey == "" { if len(cfg.APIKey) == 0 {
b, err := os.ReadFile(cfg.APIKeyFile) b, err := os.ReadFile(cfg.APIKeyFile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cfg.APIKey = string(b) cfg.APIKey = b
} }
return cfg, nil return cfg, nil

View File

@ -71,7 +71,7 @@ func (p *parser) parse() error {
case "listen": case "listen":
p.config.Listen = value p.config.Listen = value
case "api.key": case "api.key":
p.config.APIKey = value p.config.APIKey = []byte(value)
case "api.keyfile": case "api.keyfile":
p.config.APIKeyFile = value p.config.APIKeyFile = value
default: default:

72
server/apicommand.go Normal file
View File

@ -0,0 +1,72 @@
package server
import (
"crypto/hmac"
"crypto/sha256"
"encoding/json"
"fmt"
"time"
)
type APICommand struct {
Expires time.Time `json:"exp"`
Command string `json:"cmd"`
Scheme string `json:"sch"`
Signature []byte `json:"sig,omitempty"`
}
func ParseAPICommand(jsn []byte) (*APICommand, error) {
api := &APICommand{}
err := json.Unmarshal(jsn, api)
if err != nil {
return nil, err
}
return api, nil
}
func (ac *APICommand) JSON() ([]byte, error) {
return json.Marshal(ac)
}
func (ac *APICommand) Sign(key []byte) error {
switch ac.Scheme {
case "hmac-sha256":
j, err := ac.JSON()
if err != nil {
return fmt.Errorf("json encoding error")
}
mac := hmac.New(sha256.New, key)
mac.Write(j)
ac.Signature = mac.Sum(nil)
case "":
return fmt.Errorf("scheme may not be empty")
default:
return fmt.Errorf("unsupported scheme: %s", ac.Scheme)
}
return nil
}
func (ac *APICommand) Validate(key []byte) error {
cpy := &APICommand{
Expires: ac.Expires,
Command: ac.Command,
Scheme: ac.Scheme,
}
err := cpy.Sign(key)
if err != nil {
return err
}
if !hmac.Equal(cpy.Signature, ac.Signature) {
return fmt.Errorf("invalid signature")
}
if time.Now().After(ac.Expires) {
return fmt.Errorf("command expired")
}
return nil
}

View File

@ -7,13 +7,39 @@ import (
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path { switch r.URL.Path {
case "/shutdown": case "/api":
fmt.Fprintf(w, "shutting down .. ") g := r.URL.Query()
c := g.Get("cmd")
if c == "" {
http.Error(w, "no command given", http.StatusBadRequest)
return
}
cmd, err := ParseAPICommand([]byte(c))
if err != nil {
http.Error(w, fmt.Sprintf("error parsing command: %s", err.Error()),
http.StatusBadRequest)
return
}
err = cmd.Validate(s.config.APIKey)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
switch cmd.Command {
case "shutdown":
s.shutdownc <- false s.shutdownc <- false
case "/restart":
fmt.Fprintf(w, "restarting .. ") case "restart":
s.shutdownc <- true s.shutdownc <- true
default: default:
fmt.Fprintf(w, "Hello, world!") http.Error(w, fmt.Sprintf("unknown command: %s", cmd.Command),
http.StatusBadRequest)
}
default:
http.Error(w, "File not found", http.StatusNotFound)
} }
} }

View File

@ -11,13 +11,13 @@ import (
type Server struct { type Server struct {
workers []*worker workers []*worker
cfg *config.Config config *config.Config
shutdownc chan bool shutdownc chan bool
http http.Server http http.Server
} }
func (srv *Server) Start() (bool, error) { func (srv *Server) Start() (bool, error) {
listen, err := net.Listen("tcp", srv.cfg.Listen) listen, err := net.Listen("tcp", srv.config.Listen)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -29,12 +29,12 @@ func (srv *Server) Start() (bool, error) {
srv.http = http.Server{Handler: srv} srv.http = http.Server{Handler: srv}
go srv.http.Serve(listen) go srv.http.Serve(listen)
t := time.NewTicker(srv.cfg.Interval) t := time.NewTicker(srv.config.Interval)
defer t.Stop() defer t.Stop()
for { for {
select { select {
case r := <-t.C: case r := <-t.C:
if srv.cfg.DebugLevel > 0 { if srv.config.DebugLevel > 0 {
fmt.Printf("interval check at %s\n", r) fmt.Printf("interval check at %s\n", r)
} }
for _, w := range srv.workers { for _, w := range srv.workers {
@ -52,7 +52,7 @@ func (srv *Server) Start() (bool, error) {
func NewServer(cfg *config.Config) *Server { func NewServer(cfg *config.Config) *Server {
srv := &Server{ srv := &Server{
cfg: cfg, config: cfg,
shutdownc: make(chan bool, 1), shutdownc: make(chan bool, 1),
} }
for _, g := range cfg.Groups { for _, g := range cfg.Groups {