How I Created a Rate Limiting Proxy Middleware With Go

A quick post on how I implemented an IP-based rate limiting middleware for my API, running as a small proxy microservice written in Go.

You have just released your shiny new API, with a cool data set that you intend to share with the world. Have you thought about rate limiting?

What is Rate Limiting?

Rate limiting restricts each client of your service to a certain number of requests per time frame.

Why Rate Limit?

The most important reason is to prevent one user from making too many requests and hogging all the resources. Because each client is limited to a certain number of requests per time frame, no single client can starve the others. Rate limiting can also help mitigate DDoS attacks, where a malicious script hammers your API in an attempt to take it offline.

How is it done?

You will need a way of creating a unique identifier for each user of your API. This could be their IP address, a consumer token, or a composite key of both. A tally is then kept of the number of requests per client, and once a client reaches the limit they receive a 429 Too Many Requests status code.

The Code

Below is the main.go file for my rather rudimentary rate limiting service. Shout out to ulule for github.com/ulule/limiter.

Without further ado, here is the code:

package main

import (
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"strconv"
	"time"

	"github.com/ulule/limiter/v3"
	"github.com/ulule/limiter/v3/drivers/store/memory"
)

// RateLimitr holds most of our logic:
// limiter is ulule's rate limiter,
// callback is called if the user has not hit their limit
type RateLimitr struct {
	limiter  *limiter.Limiter
	callback func(http.ResponseWriter, *http.Request)
}

// ServeHTTP looks familiar? It satisfies http.Handler, so it runs on every request.
// This is where the rate limiting logic lives.
func (t RateLimitr) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// see below
	ip, err := resolveIP(r)

	if err != nil {
		log.Printf("error obtaining IP address: %s", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	// record this request and fetch the user's current usage by their IP
	limiterCtx, err := t.limiter.Get(r.Context(), ip)
	if err != nil {
		log.Printf("IPRateLimit - ipRateLimiter.Get - err: %v, %s on %s", err, ip, r.URL)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	// set some headers to inform the user of their current usage
	h := w.Header()
	h.Set("X-RateLimit-Limit", strconv.FormatInt(limiterCtx.Limit, 10))
	h.Set("X-RateLimit-Remaining", strconv.FormatInt(limiterCtx.Remaining, 10))
	h.Set("X-RateLimit-Reset", strconv.FormatInt(limiterCtx.Reset, 10))

	// check whether the client has reached their limit;
	// if they have, return a 429 and a helpful JSON message
	if limiterCtx.Reached {
		log.SetOutput(os.Stderr)
		log.Printf("Too Many Requests from %s on %s", ip, r.URL)
		h.Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusTooManyRequests)
		w.Write([]byte(`{"msg":"too many requests"}`))
		return
	}

	// the client is within their limit, so pass the request on
	t.callback(w, r)
}

func main() {
	// set up a remote endpoint
	remote, err := url.Parse("http://my-app")
	if err != nil {
		panic(err)
	}
	// set up the proxy to connect to this endpoint
	proxy := httputil.NewSingleHostReverseProxy(remote)

	// initialise the rate limiter from above
	t := RateLimitr{
		limiter: limiter.New(
			memory.NewStore(),
			limiter.Rate{
				Period: 1 * time.Hour,
				Limit:  100, // TODO add as env var
			}),
		callback: handler(proxy),
	}

	// begin to listen on port 8079
	err = http.ListenAndServe(":8079", t)
	if err != nil {
		panic(err)
	}
}


// handler runs for users who have not reached their limit:
// it logs the URL and forwards the request through the reverse proxy
func handler(p *httputil.ReverseProxy) func(http.ResponseWriter, *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		log.SetOutput(os.Stderr)
		log.Println(r.URL)
		p.ServeHTTP(w, r)
	}
}

// resolveIP is a helper function that finds the IP of the current user.
// It takes the IP from the connection's remote address, then prefers the
// X-Forwarded-For header if present (e.g. when running behind a proxy).
// Not foolproof, since clients can set that header themselves and it may
// hold a comma-separated chain of addresses, but good enough.
func resolveIP(r *http.Request) (string, error) {
	ip, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return "", err
	}

	if forIP := r.Header.Get("X-Forwarded-For"); forIP != "" {
		ip = forIP
	}
	return ip, nil
}


Possible Improvements

If I were to improve this service further, I would introduce consumer tokens and rate limit on a composite key of each client's IP and consumer token. Consumer tokens would also give me closer control over who has access to my API, and would let multiple people behind the same IP use it without sharing a single request quota.

You can check out the GitHub repo here.