diff --git a/main.go b/main.go index 6222b40..f78ea96 100644 --- a/main.go +++ b/main.go @@ -32,24 +32,38 @@ const appVersion = "1.0.1" var src = rand.NewSource(time.Now().UnixNano()) var pool = cache.New(240*time.Hour, 1*time.Hour) -var okTmpl = template.Must(template.ParseFiles("templates/ok.html")) -var indexTmpl = template.Must(template.ParseFiles("templates/index.html")) -var returnTmpl = template.Must(template.ParseFiles("templates/returnPage.html")) -var notFoundTmpl = template.Must(template.ParseFiles("templates/404.html")) -var badRequestTmpl = template.Must(template.ParseFiles("templates/400.html")) -var internalErrorTmpl = template.Must(template.ParseFiles("templates/500.html")) +var templates *template.Template +var allFiles []string + +func init() { + files, err := ioutil.ReadDir("templates") + if err != nil { + log.Fatalln(err) + } + for _, file := range files { + filename := file.Name() + if strings.HasSuffix(filename, ".html") { + allFiles = append(allFiles, "templates/"+filename) + } + } + templates, err = template.ParseFiles(allFiles...) + if err != nil { + log.Fatalln(err) + } +} func index(w http.ResponseWriter, r *http.Request) { - indexTmpl.Execute(w, indexTmpl) + t := templates.Lookup("index.html") + t.Execute(w, nil) } // get executes the GET command func get(key string) (string, bool) { value, status := pool.Get(key) - if status { - return value.(string), status + if !status { + return "", status } - return "", false + return value.(string), status } // set executes the redis SET command @@ -63,6 +77,7 @@ func set(key, suffix string) { func redirect(w http.ResponseWriter, r *http.Request, path string) { vals := mux.Vars(r) key := vals["key"] + t := templates.Lookup("404.html") if path != "" { key = strings.Replace(key, path, "", 1) } @@ -71,18 +86,14 @@ func redirect(w http.ResponseWriter, r *http.Request, path string) { key, status := get(key) if !status { w.WriteHeader(http.StatusNotFound) - notFoundTmpl.Execute(w, nil) + t.Execute(w, nil) return } - u, _ := url.Parse(key) - if u.Scheme == "" { - u.Scheme = "https" - } - http.Redirect(w, r, u.String(), http.StatusFound) - } else { - w.WriteHeader(http.StatusNotFound) - notFoundTmpl.Execute(w, notFoundTmpl) + u, _ := url.Parse(key) + if u.Scheme == "" { + u.Scheme = "https" } + http.Redirect(w, r, u.String(), http.StatusFound) } // shortner reads url from a POST request, validates the url, generate a @@ -91,24 +102,26 @@ func redirect(w http.ResponseWriter, r *http.Request, path string) { // then if writes the kv pair suffix, url to the database and return the // shortened url to the user func shortner(w http.ResponseWriter, r *http.Request, proto, domain, hostSuf, path string, urlSize int) { + ret := templates.Lookup("returnPage.html") + badR := templates.Lookup("400.html") if !govalidator.IsURL(r.FormValue("url")) { w.WriteHeader(http.StatusBadRequest) - badRequestTmpl.Execute(w, nil) + badR.Execute(w, nil) return } - u, _ := url.Parse(r.FormValue("url")) + u, _ := url.Parse(r.FormValue("url")) suffix := randStringBytesMaskImprSrc(urlSize) - for { - _, status := get(suffix) + for { + _, status := get(suffix) if !status { - break - } - suffix = randStringBytesMaskImprSrc(urlSize) + break } - set(u.String(), suffix) - shortend := proto + "://" + domain + hostSuf + path + suffix - returnTmpl.Execute(w, shortend) + suffix = randStringBytesMaskImprSrc(urlSize) + } + set(u.String(), suffix) + shortend := proto + "://" + domain + hostSuf + path + suffix + ret.Execute(w, shortend) } // randStringBytesMaskImprSrc Generate random string of n size @@ -132,9 +145,10 @@ func randStringBytesMaskImprSrc(n int) string { // internalError receives a http.ResponseWriter, msg and error and // return a internal error page with http code 500 to the user func internalError(w http.ResponseWriter, msg string, err error) { + t := templates.Lookup("500.html") log.Println(err) w.WriteHeader(http.StatusInternalServerError) - internalErrorTmpl.Execute(w, msg+err.Error()) + t.Execute(w, msg+err.Error()) } // itemsCount returns the number of kv pairs on the in meomry database @@ -163,6 +177,7 @@ func itemsDump(w http.ResponseWriter, r *http.Request) { // itemsFromFile loads kv pairs from the dumpFile json to the in memory database func itemsFromFile(w http.ResponseWriter, r *http.Request, dumpFile string) { + t := templates.Lookup("ok.html") jsonFile, err := ioutil.ReadFile(dumpFile) var dumpObj map[string]cache.Item json.Unmarshal([]byte(jsonFile), &dumpObj) @@ -170,12 +185,13 @@ func itemsFromFile(w http.ResponseWriter, r *http.Request, dumpFile string) { internalError(w, "Cannot open file "+dumpFile+": ", err) } else { pool = cache.NewFrom(240*time.Hour, 1*time.Hour, dumpObj) - okTmpl.Execute(w, "Imported "+strconv.Itoa(len(dumpObj))+" items to the DB") + t.Execute(w, "Imported "+strconv.Itoa(len(dumpObj))+" items to the DB") } } // itemsFromPost loads kv pairs from a json POST to the in memory database func itemsFromPost(w http.ResponseWriter, r *http.Request) { + t := templates.Lookup("ok.html") decoder := json.NewDecoder(r.Body) var dumpObj map[string]cache.Item err := decoder.Decode(&dumpObj) @@ -183,12 +199,13 @@ func itemsFromPost(w http.ResponseWriter, r *http.Request) { internalError(w, "Cannot parse JSON: ", err) } else { pool = cache.NewFrom(240*time.Hour, 1*time.Hour, dumpObj) - okTmpl.Execute(w, "Imported "+strconv.Itoa(len(dumpObj))+" items to the DB") + t.Execute(w, "Imported "+strconv.Itoa(len(dumpObj))+" items to the DB") } } // itemsDumpToFile dumps the kv pairs from the in memory database to the dumpFile func itemsDumpToFile(w http.ResponseWriter, r *http.Request, dumpFile string) { + t := templates.Lookup("ok.html") dumpObj, _ := json.Marshal( pool.Items(), ) @@ -196,7 +213,7 @@ func itemsDumpToFile(w http.ResponseWriter, r *http.Request, dumpFile string) { if err != nil { internalError(w, "Failed to open json file: ", err) } else { - okTmpl.Execute(w, "Dump writen to: "+dumpFile) + t.Execute(w, "Dump writen to: "+dumpFile) } } @@ -220,7 +237,7 @@ func main() { } if *port > 65535 || *port < 1 { - + log.Fatalln("Invalid port number") } if *path != "" && !strings.HasSuffix(*path, "/") { *path = *path + "/"