diff --git a/cmd/peakHandling.go b/cmd/peakHandling.go new file mode 100644 index 0000000..c7c0977 --- /dev/null +++ b/cmd/peakHandling.go @@ -0,0 +1,79 @@ +package main + +import ( + "log" + "net/http" + "sync" + "time" + + "github.com/robfig/cron" +) + +var peakRequest30 sync.Map +var peakRequest60 sync.Map + +func initPeakHandling() { + c := cron.New() + // cronTime := fmt.Sprintf("%d,%d * * * *", 30-prefetchInterval/60, 60-prefetchInterval/60) + c.AddFunc("24 * * * *", prefetchPeakRequests30) + c.AddFunc("54 * * * *", prefetchPeakRequests60) + c.Start() +} + +func savePeakRequest(cacheDigest string, r *http.Request) { + _, min, _ := time.Now().Clock() + if min == 30 { + peakRequest30.Store(cacheDigest, *r) + } else if min == 0 { + peakRequest60.Store(cacheDigest, *r) + } +} + +func prefetchRequest(r *http.Request) { + processRequest(r) +} + +func syncMapLen(sm *sync.Map) int { + count := 0 + + f := func(key, value interface{}) bool { + + // Not really certain about this part, don't know for sure + // if this is a good check for an entry's existence + if key == "" { + return false + } + count++ + + return true + } + + sm.Range(f) + + return count +} + +func prefetchPeakRequests(peakRequestMap *sync.Map) { + peakRequestLen := syncMapLen(peakRequestMap) + log.Printf("PREFETCH: Prefetching %d requests\n", peakRequestLen) + if peakRequestLen == 0 { + return + } + sleepBetweenRequests := time.Duration(prefetchInterval*1000/peakRequestLen) * time.Millisecond + peakRequestMap.Range(func(key interface{}, value interface{}) bool { + r := value.(http.Request) + log.Printf("Prefetching %s\n", key) + prefetchRequest(&r) + peakRequestMap.Delete(key) + time.Sleep(sleepBetweenRequests) + return true + }) +} + +func prefetchPeakRequests30() { + prefetchPeakRequests(&peakRequest30) +} + +func prefetchPeakRequests60() { + prefetchPeakRequests(&peakRequest60) +} diff --git a/cmd/processRequest.go b/cmd/processRequest.go new file mode 100644 index 0000000..25209b6 --- /dev/null +++ b/cmd/processRequest.go @@ -0,0 +1,130 @@ +package main + +import ( + "fmt" + "io/ioutil" + "log" + "math/rand" + "net" + "net/http" + "time" +) + +func processRequest(r *http.Request) responseWithHeader { + var response responseWithHeader + + foundInCache := false + cacheDigest := getCacheDigest(r) + + savePeakRequest(cacheDigest, r) + + cacheBody, ok := lruCache.Get(cacheDigest) + if ok { + cacheEntry := cacheBody.(responseWithHeader) + + // if after all attempts we still have no answer, + // we try to make the query on our own + for attempts := 0; attempts < 300; attempts++ { + if !ok || !cacheEntry.InProgress { + break + } + time.Sleep(30 * time.Millisecond) + cacheBody, ok = lruCache.Get(cacheDigest) + cacheEntry = cacheBody.(responseWithHeader) + } + if cacheEntry.InProgress { + log.Printf("TIMEOUT: %s\n", cacheDigest) + } + if ok && !cacheEntry.InProgress && cacheEntry.Expires.After(time.Now()) { + response = cacheEntry + foundInCache = true + } + } + + if !foundInCache { + lruCache.Add(cacheDigest, responseWithHeader{InProgress: true}) + response = get(r) + if response.StatusCode == 200 || response.StatusCode == 304 { + lruCache.Add(cacheDigest, response) + } else { + log.Printf("REMOVE: %d response for %s from cache\n", response.StatusCode, cacheDigest) + lruCache.Remove(cacheDigest) + } + } + return response +} + +func get(req *http.Request) responseWithHeader { + + client := &http.Client{} + + queryURL := fmt.Sprintf("http://%s%s", req.Host, req.RequestURI) + + proxyReq, err := http.NewRequest(req.Method, queryURL, req.Body) + if err != nil { + log.Printf("Request: %s\n", err) + } + + // proxyReq.Header.Set("Host", req.Host) + // proxyReq.Header.Set("X-Forwarded-For", req.RemoteAddr) + + for header, values := range req.Header { + for _, value := range values { + proxyReq.Header.Add(header, value) + } + } + + res, err := client.Do(proxyReq) + + if err != nil { + panic(err) + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + log.Println(err) + } + + return responseWithHeader{ + InProgress: false, + Expires: time.Now().Add(time.Duration(randInt(1000, 1500)) * time.Second), + Body: body, + Header: res.Header, + StatusCode: res.StatusCode, + } +} + +// implementation of the cache.get_signature of original wttr.in +func getCacheDigest(req *http.Request) string { + + userAgent := req.Header.Get("User-Agent") + + queryHost := req.Host + queryString := req.RequestURI + + clientIPAddress := readUserIP(req) + + lang := req.Header.Get("Accept-Language") + + return fmt.Sprintf("%s:%s%s:%s:%s", userAgent, queryHost, queryString, clientIPAddress, lang) +} + +func readUserIP(r *http.Request) string { + IPAddress := r.Header.Get("X-Real-Ip") + if IPAddress == "" { + IPAddress = r.Header.Get("X-Forwarded-For") + } + if IPAddress == "" { + IPAddress = r.RemoteAddr + var err error + IPAddress, _, err = net.SplitHostPort(IPAddress) + if err != nil { + log.Printf("ERROR: userip: %q is not IP:port\n", IPAddress) + } + } + return IPAddress +} + +func randInt(min int, max int) int { + return min + rand.Intn(max-min) +} diff --git a/cmd/srv.go b/cmd/srv.go index 371cf21..603391b 100644 --- a/cmd/srv.go +++ b/cmd/srv.go @@ -2,113 +2,48 @@ package main import ( "context" - "fmt" - "io/ioutil" "log" "net" "net/http" "time" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" ) +const uplinkSrvAddr = "127.0.0.1:9002" +const uplinkTimeout = 30 +const prefetchInterval = 300 +const lruCacheSize = 12800 + var lruCache *lru.Cache -type ResponseWithHeader struct { +type responseWithHeader struct { + InProgress bool // true if the request is being processed + Expires time.Time // expiration time of the cache entry + Body []byte Header http.Header StatusCode int // e.g. 200 - } func init() { var err error - lruCache, err = lru.New(12800) + lruCache, err = lru.New(lruCacheSize) if err != nil { panic(err) } dialer := &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, + Timeout: uplinkTimeout * time.Second, + KeepAlive: uplinkTimeout * time.Second, DualStack: true, } - http.DefaultTransport.(*http.Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - addr = "127.0.0.1:8002" - return dialer.DialContext(ctx, network, addr) - } - -} - -func readUserIP(r *http.Request) string { - IPAddress := r.Header.Get("X-Real-Ip") - if IPAddress == "" { - IPAddress = r.Header.Get("X-Forwarded-For") - } - if IPAddress == "" { - IPAddress = r.RemoteAddr - var err error - IPAddress, _, err = net.SplitHostPort(IPAddress) - if err != nil { - fmt.Printf("userip: %q is not IP:port\n", IPAddress) - } - } - return IPAddress -} - -// implementation of the cache.get_signature of original wttr.in -func findCacheDigest(req *http.Request) string { - - userAgent := req.Header.Get("User-Agent") - - queryHost := req.Host - queryString := req.RequestURI - - clientIpAddress := readUserIP(req) - - lang := req.Header.Get("Accept-Language") - - now := time.Now() - secs := now.Unix() - timestamp := secs / 1000 - - return fmt.Sprintf("%s:%s%s:%s:%s:%d", userAgent, queryHost, queryString, clientIpAddress, lang, timestamp) -} - -func get(req *http.Request) ResponseWithHeader { - - client := &http.Client{} - - queryURL := fmt.Sprintf("http://%s%s", req.Host, req.RequestURI) - - proxyReq, err := http.NewRequest(req.Method, queryURL, req.Body) - if err != nil { - // handle error - } - - // proxyReq.Header.Set("Host", req.Host) - // proxyReq.Header.Set("X-Forwarded-For", req.RemoteAddr) - - for header, values := range req.Header { - for _, value := range values { - proxyReq.Header.Add(header, value) - } + http.DefaultTransport.(*http.Transport).DialContext = func(ctx context.Context, network, _ string) (net.Conn, error) { + return dialer.DialContext(ctx, network, uplinkSrvAddr) } - res, err := client.Do(proxyReq) - - if err != nil { - panic(err) - } - - body, err := ioutil.ReadAll(res.Body) - - return ResponseWithHeader{ - Body: body, - Header: res.Header, - StatusCode: res.StatusCode, - } + initPeakHandling() } func copyHeader(dst, src http.Header) { @@ -120,26 +55,14 @@ func copyHeader(dst, src http.Header) { } func main() { - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - var response ResponseWithHeader + // printStat() + response := processRequest(r) - cacheDigest := findCacheDigest(r) - cacheBody, ok := lruCache.Get(cacheDigest) - if ok { - response = cacheBody.(ResponseWithHeader) - } else { - fmt.Println(cacheDigest) - response = get(r) - if response.StatusCode == 200 { - lruCache.Add(cacheDigest, response) - } - } copyHeader(w.Header(), response.Header) w.WriteHeader(response.StatusCode) w.Write(response.Body) }) - log.Fatal(http.ListenAndServe(":8081", nil)) - + log.Fatal(http.ListenAndServe(":8082", nil)) } diff --git a/cmd/stat.go b/cmd/stat.go new file mode 100644 index 0000000..1fc135e --- /dev/null +++ b/cmd/stat.go @@ -0,0 +1,40 @@ +package main + +import ( + "log" + "sync" + "time" +) + +type safeCounter struct { + v map[int]int + mux sync.Mutex +} + +func (c *safeCounter) inc(key int) { + c.mux.Lock() + c.v[key]++ + c.mux.Unlock() +} + +// func (c *safeCounter) val(key int) int { +// c.mux.Lock() +// defer c.mux.Unlock() +// return c.v[key] +// } +// +// func (c *safeCounter) reset(key int) int { +// c.mux.Lock() +// defer c.mux.Unlock() +// result := c.v[key] +// c.v[key] = 0 +// return result +// } + +var queriesPerMinute safeCounter + +func printStat() { + _, min, _ := time.Now().Clock() + queriesPerMinute.inc(min) + log.Printf("Processed %d requests\n", min) +}