diff --git a/main.go b/main.go index ae52755..29b7251 100644 --- a/main.go +++ b/main.go @@ -13,10 +13,7 @@ import ( "github.com/gorilla/mux" ) -var ( - l int - C counter -) +var C counter type counter struct { mu sync.Mutex @@ -42,38 +39,48 @@ func (c *counter) Reset() { c.mu.Unlock() } -func getenv(k string, d int) int { +func getenv[D ~string | int](k string, d D) D { v := os.Getenv(k) if len(v) == 0 { return d } - i, err := strconv.Atoi(v) - if err != nil { - log.Fatalf("Invalid Value, %s not a valid integer: %v", k, err) + var r any + switch any(d).(type) { + case string: + r = v + case int: + i, err := strconv.Atoi(v) + if err != nil { + } + r = i + default: + log.Fatalf("Invalid Value, %s not a valid", k) } - return i + return r.(D) } -func serve() bool { +func serve(l int) bool { return C.Get() < l } -func handler(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - jm, err := json.Marshal(r.PostForm) - if err != nil || !serve() { - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprintln(w, "why did you do that?") +func handler(l int) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + jm, err := json.Marshal(r.PostForm) + if err != nil || !serve(l) { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintln(w, "why did you do that?") + C.Add() + return + } + fmt.Fprintf(w, "%v", string(jm)) C.Add() - return } - fmt.Fprintf(w, "%v", string(jm)) - C.Add() } -func httpHealth() http.HandlerFunc { +func httpHealth(l int) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - if !serve() { + if !serve(l) { w.WriteHeader(http.StatusInternalServerError) fmt.Fprintf(w, `{"status":"FAIL","requests":"%v"}`, C.Get()) return @@ -82,13 +89,30 @@ func httpHealth() http.HandlerFunc { } } +func reset(rt string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + t := r.FormValue("TOKEN") + if len(t) == 0 || t != rt { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, "Bad request, invalid token") + return + } + C.Reset() + fmt.Fprintf(w, `{"status":"ok","requests":"%v"}`, C.Get()) + } +} + func main() { - l = getenv("MAX_REQUESTS", 500) + l := getenv("MAX_REQUESTS", 500) + t := getenv("TOKEN", "token") r := mux.NewRouter() - r.HandleFunc("/", handler). + r.HandleFunc("/", handler(l)). Methods("POST") - r.HandleFunc("/healthz", httpHealth()). + r.HandleFunc("/healthz", httpHealth(l)). Methods("GET") + r.HandleFunc("/reset", reset(t)). + Methods("PUT") logger := handlers.LoggingHandler(os.Stdout, r) log.Fatal(http.ListenAndServe(":8080", logger)) }