RedisとGolangを使ったレート制限アルゴリズム

Cover image for Algorithms for Rate limiting with Redis and Golang

はじめに

今日のインターネットで結ばれた世界では、何百万というユーザーが同時にインターネットにアクセスしています。そのため、ネットワークを通るトラフィック量を調整することが不可欠です。過負荷やDoS攻撃(Denial of Service)を防ぐため、トラフィックの流れを制御するプロセスをレート制限と呼びます。この記事では、Redisをデータベースとして使い、Golangのコードを用いていくつかの人気のあるレート制限アルゴリズムについて説明します。

  • 固定ウィンドウアルゴリズム
  • スライディングログアルゴリズム
  • リーキーバケットアルゴリズム
  • スライディングウィンドウアルゴリズム
  • トークンバケットアルゴリズム

それでは、一つずつ見ていきましょう。

固定ウィンドウアルゴリズム

固定ウィンドウアルゴリズムは最もシンプルなレート制限アルゴリズムです。このアルゴリズムでは、特定の時間ウィンドウ内にネットワークを通過できるトラフィックがあらかじめ決められています。例えば、レート制限が1秒あたり10リクエストに設定されている場合、1秒間のウィンドウ内で10リクエストまでがネットワークを通過できます。設定された制限を超えるリクエストがある場合、残りのリクエストは破棄されるか後で処理するためにキューに入れられます。

func fixedWindowAlgorithm(client *redis.Client, key string, limit int64, window time.Duration) bool {
    currentTime := time.Now()
    keyWindow := fmt.Sprintf("%s_%d", key, currentTime.Unix()/int64(window.Seconds()))

    client.Incr(key)
    count, err := client.Get(keyWindow).Int64()
    if err != nil && err != redis.Nil {
        panic(err)
    }
    if count >= limit {
        return false
    }

    pipe := client.TxPipeline()
    pipe.Incr(keyWindow)
    pipe.Expire(keyWindow, window)
    _, err = pipe.Exec()
    if err != nil {
        panic(err)
    }

    return true
}

スライディングログアルゴリズム

スライディングログアルゴリズムは固定ウィンドウアルゴリズムを改良したものです。固定ウィンドウを使う代わりに、このアルゴリズムはレート制限を計算するためにスライディングウィンドウを使用します。スライディングウィンドウは特定の時間内に行なわれたリクエストを追跡し、レート制限を超過した場合はリクエストを破棄します。このアルゴリズムはより長い時間を考慮に入れるため、レート制限のより細かい制御を提供します。

func slidingLogsAlgorithm(client *redis.Client, key string, limit int64, window time.Duration) bool {
    currentTime := time.Now().UnixNano()
    keyLogs := fmt.Sprintf("%s_logs", key)
    keyTimestamps := fmt.Sprintf("%s_timestamps", key)

    pipe := client.TxPipeline()
    pipe.ZRemRangeByScore(keyLogs, "0", fmt.Sprintf("%d", currentTime-int64(window)))
    pipe.ZAdd(keyLogs, &redis.Z{
        Score:  float64(currentTime),
        Member: currentTime,
    })
    pipe.ZCard(keyLogs)
    _, err := pipe.Exec()
    if err != nil {
        panic(err)
    }

    count, err := client.Get(keyTimestamps).Int64()
    if err != nil && err != redis.Nil {
        panic(err)
    }
    if count >= limit {
        return false
    }

    pipe = client.TxPipeline()
    pipe.Incr(keyTimestamps)
    pipe.Expire(keyTimestamps, window)
    _, err = pipe.Exec()
    if err != nil {
        panic(err)
    }

    return true
}

リーキーバケットアルゴリズム

リーキーバケットアルゴリズムは人気のあるレート制限アルゴリズムで、ネットワークを通過できるトラフィックの突発的な増加を許可しつつ、全体的なトラフィック率を制限します。このアルゴリズムではバケットが使われており、各リクエストにトークンが割り当てられます。バケットには固定の容量があり、リクエストが到着するとトークンがバケットに追加されます。バケットがいっぱいの場合、追加されるトークンは破棄されます。ネットワークを通過する各リクエストはバケットから引き去られ、バケットが空の場合、リクエストは破棄されます。

func leakyBucketAlgorithm(client *redis.Client, key string, burst int64, rate time.Duration) bool {
    keyBucket := fmt.Sprintf("%s_bucket", key)

    currentTime := time.Now().UnixNano()
    pipe := client.TxPipeline()
    pipe.ZRemRangeByScore(keyBucket, "0", fmt.Sprintf("%d", currentTime-int64(rate)))
    pipe.ZAdd(keyBucket, &redis.Z{
        Score:  float64(currentTime),
        Member: currentTime,
    })
    pipe.ZCard(keyBucket)
    _, err := pipe.Exec()
    if err != nil {
        panic(err)
    }

    count, err := client.Get(keyBucket).Int64()
    if err != nil && err != redis.Nil {
        panic(err)
    }
    if count > burst {
        return false
    }

    return true
}

スライディングウィンドウアルゴリズム

スライディングウィンドウアルゴリズムは、スライディングウィンドウを使いトラフィック率を追跡する別の人気のあるレート制限アルゴリズムです。このアルゴリズムでは、特定の時間内に行なわれたリクエストの数を追跡するスライディングウィンドウが使用されます。スライディングウィンドウは時とともに動き、レート制限を超過した場合にはリクエストが破棄されます。固定ウィンドウアルゴリズムよりも柔軟性があり、スライディングタイムウィンドウと動的なレート制限を許容します。

func slidingWindowAlgorithm(client *redis.Client, key string, limit int64, window time.Duration) bool {
    keyWindow := fmt.Sprintf("%s_window", key)
    currentTime := time.Now().UnixNano()

    // Remove old entries from the window
    pipe := client.TxPipeline()
    pipe.ZRemRangeByScore(keyWindow, "0", fmt.Sprintf("%d", currentTime-int64(window)))
    // Add the current request to the window
    pipe.ZAdd(keyWindow, &redis.Z{
        Score:  float64(currentTime),
        Member: currentTime,
    })
    // Count the number of requests in the current window
    pipe.ZCard(keyWindow)
    // Execute the pipeline
    _, err := pipe.Exec()
    if err != nil {
        panic(err)
    }

    // Check if the number of requests is within the limit
    count, err := client.Get(keyWindow).Int64()
    if err != nil && err != redis.Nil {
        panic(err)
    }
    if count > limit {
        return false
    }

    return true
}

トークンバケットアルゴリズム

トークンバケットアルゴリズムはリーキーバケットアルゴリズムと似ていますが、トークンを用いてトラフィック率を制御します。このアルゴリズムでは、バケットがトークンを格納するために使われ、各リクエストはネットワークを通過するために特定のトークン数を必要とします。トークンは定率でバケットに加えられ、バケットが空の場合はリクエストが破棄されます。このアルゴリズムはリーキーバケットアルゴリズムよりも柔軟性があり、動的なレート制限を許容します。

func tokenBucketAlgorithm(client *redis.Client, key string, capacity int64, rate time.Duration) bool {
    keyBucket := fmt.Sprintf("%s_bucket", key)
    currentTime := time.Now().UnixNano()

    pipe := client.TxPipeline()
    pipe.ZRemRangeByScore(keyBucket, "0", fmt.Sprintf("%d", currentTime-int64(rate)))
    pipe.ZAdd(keyBucket, &redis.Z{
        Score:  float64(currentTime),
        Member: currentTime,
    })
    pipe.ZCard(keyBucket)
    _, err := pipe.Exec()
    if err != nil {
        panic(err)
    }

    count, err := client.Get(keyBucket).Int64()
    if err != nil && err != redis.Nil {
        panic(err)
    }
    if count > capacity {
        return false
    }

    return true
}

コード内ですべてのアルゴリズム

一か所ですべてのアルゴリズムをテストできます。Redis DBに接続した状態でコードを実行してください。

Redis DB用のDocker-composeファイル。
ファイル名:docker-compose.yml
Redisを実行するコマンド:docker-compose up

version: '3'
services:
  redis:
    image: redis
    ports:
      - 6379:6379
    volumes:
      - redis-data:/data
volumes:
  redis-data:

package main

import (
    "context"
    "fmt"
    "time"

    "github.com/go-redis/redis/v8"
)

var (
    ctx = context.Background()
)

// Fixed Window Algorithm using Redis
func fixedWindowAlgorithm(client *redis.Client, key string, limit int64, window time.Duration) bool {
    currentTime := time.Now().UnixNano()
    keyWindow := fmt.Sprintf("%s:%d", key, currentTime/window.Nanoseconds())

    count, err := client.Incr(ctx, keyWindow).Result()
    if err != nil {
        panic(err)
    }
    if count > limit {
        return false
    }

    // Expire the key after the fixed window duration
    if err := client.Expire(ctx, keyWindow, window).Err(); err != nil {
        panic(err)
    }

    return true
}

// Sliding Logs Algorithm using Redis
func slidingLogsAlgorithm(client *redis.Client, key string, limit int64, window time.Duration) bool {
    keyWindow := fmt.Sprintf("%s_window", key)
    currentTime := time.Now().UnixNano()

    // Trim the old entries from the window
    if _, err := client.ZRemRangeByScore(ctx, keyWindow, "0", fmt.Sprintf("%d", currentTime-int64(window))).Result(); err != nil {
        panic(err)
    }

    // Add the current request to the window
    if _, err := client.ZAdd(ctx, keyWindow, &redis.Z{
        Score:  float64(currentTime),
        Member: currentTime,
    }).Result(); err != nil {
        panic(err)
    }

    // Get the count of requests within the window
    count, err := client.ZCard(ctx, keyWindow).Result()
    if err != nil && err != redis.Nil {
        panic(err)
    }
    if count > limit {
        return false
    }

    return true
}

// Leaky Bucket Algorithm using Redis
func leakyBucketAlgorithm(client *redis.Client, key string, limit int64, window time.Duration) bool {
    keyWindow := fmt.Sprintf("%s:%d", key, window.Nanoseconds())

    // Get the current value of the bucket
    value, err := client.Get(ctx, keyWindow).Int64()
    if err != nil && err != redis.Nil {
        panic(err)
    }

    // Increment the value by 1 and set the new value
    value++
    if err := client.Set(ctx, keyWindow, value, window).Err(); err != nil {
        panic(err)
    }

    // Check if the value is within the limit
    if value > limit {
        return false
    }

    return true
}

// Sliding Window Algorithm using Redis
func slidingWindowAlgorithm(client *redis.Client, key string, limit int64, window time.Duration) bool {
    keyWindow := fmt.Sprintf("%s_window", key)
    currentTime := time.Now().UnixNano()

    // Remove old entries from the window
    if _, err := client.ZRemRangeByScore(ctx, keyWindow, "0", fmt.Sprintf("%d", currentTime-int64(window))).Result(); err != nil {
        panic(err)
    }

    // Add the current request to the window
    if _, err := client.ZAdd(ctx, keyWindow, &redis.Z{
        Score:  float64(currentTime),
        Member: currentTime,
    }).Result(); err != nil {
        panic(err)
    }

    // Count the number of requests in the current window
    count, err := client.ZCard(ctx, keyWindow).Result()
    if err != nil && err != redis.Nil {
        panic(err)
    }
    if count > limit {
        return false
    }

    return true
}

// Token Bucket Algorithm using Redis
func tokenBucketAlgorithm(client *redis.Client, key string, limit int64, refillTime time.Duration, refillAmount int64) bool {
    currentTime := time.Now().UnixNano()
    keyWindow := fmt.Sprintf("%s:%d", key, refillTime.Nanoseconds())
    // Calculate the available tokens in the bucket
    availableTokens, err := client.Get(ctx, keyWindow).Int64()
    if err != nil && err != redis.Nil {
        panic(err)
    }

    // Calculate the number of tokens that should be added to the bucket
    additionalTokens := int64(float64(currentTime) / float64(refillTime.Nanoseconds()) * float64(refillAmount))

    // Add the additional tokens to the bucket
    if _, err := client.SetNX(ctx, keyWindow, availableTokens+additionalTokens, refillTime).Result(); err != nil {
        panic(err)
    }

    // Check if there are enough tokens for the current request
    if availableTokens+1 > limit {
        return false
    }

    // Consume a token from the bucket
    if _, err := client.Decr(ctx, keyWindow).Result(); err != nil {
        panic(err)
    }

    return true
}

func main() {
    // Connect to Redis
    client := redis.NewClient(&redis.Options{
        Addr: "localhost:6379",
        DB:   0,
    })
    // Test Fixed Window Algorithm
    for i := 0; i < 10; i++ {
        fmt.Printf("Fixed Window Algorithm request %d: %t\n", i+1, fixedWindowAlgorithm(client, "fixed_window", 5, time.Second))
        time.Sleep(time.Millisecond * 100)
    }

    // Test Sliding Logs Algorithm
    for i := 0; i < 10; i++ {
        fmt.Printf("Sliding Logs Algorithm request %d: %t\n", i+1, slidingLogsAlgorithm(client, "sliding_logs", 5, time.Second))
        time.Sleep(time.Millisecond * 100)
    }

    // Test Leaky Bucket Algorithm
    for i := 0; i < 10; i++ {
        fmt.Printf("Leaky Bucket Algorithm request %d: %t\n", i+1, leakyBucketAlgorithm(client, "leaky_bucket", 5, time.Second))
        time.Sleep(time.Millisecond * 100)
    }

    // Test Sliding Window Algorithm
    for i := 0; i < 10; i++ {
        fmt.Printf("Sliding Window Algorithm request %d: %t\n", i+1, slidingWindowAlgorithm(client, "sliding_window", 5, time.Second))
        time.Sleep(time.Millisecond * 100)
    }

    // Test Token Bucket Algorithm
    for i := 0; i < 10<br><br>こちらの記事はdev.toの良い記事を日本人向けに翻訳しています。<br>[https://dev.to/ankitmalikg/algorithms-for-rate-limiting-with-redis-and-golang-93d](https://dev.to/ankitmalikg/algorithms-for-rate-limiting-with-redis-and-golang-93d)