开发者

Go+Redis实现常见限流算法的示例代码

开发者 https://www.devze.com 2023-04-03 08:57 出处:网络 作者: jxwu
目录固定窗口滑动窗口hash实现list实现漏桶算法令牌桶滑动日志总结限流是项目中经常需要使用到的一种工具,一般用于限制用户的请求的频率,也可以避免瞬间流量过大导致系统崩溃,或者稳定消息处理速率。并且有时候我
目录
  • 固定窗口
  • 滑动窗口
    • hash实现
    • list实现
  • 漏桶算法
    • 令牌桶
      • 滑动日志
        • 总结

          限流是项目中经常需要使用到的一种工具,一般用于限制用户的请求的频率,也可以避免瞬间流量过大导致系统崩溃,或者稳定消息处理速率。并且有时候我们还需要使用到分布式限流,常见的实现方式是使用Redis作为中心存储。

          这个文章主要是使用Go+Redis实现常见的限流算法,如果需要了解每种限流算法的原理可以阅读文章 Go实现常见的限流算法

          下面的代码使用到了go-redis客户端

          固定窗口

          使用Redis实现固定窗口比较简单,主要是由于固定窗口同时只会存在一个窗口,所以我们可以在第一次进入窗口时使用pexpire命令设置过期时间为窗口时间大小,这样窗口会随过期时间而失效,同时我们使用incr命令增加窗口计数。

          因为我们需要在counter==1的时候设置窗口的过期时间,为了保证原子性,我们使用简单的Lua脚本实现。

          const fixedwindowLimiterTryAcquireRedisScript = `
          -- ARGV[1]: 窗口时间大小
          -- ARGV[2]: 窗口请求上限
          
          local window = tonumber(ARGV[1])
          local limit = tonumber(ARGV[2])
          
          -- 获取原始值
          local counter = tonumber(redis.call("get", KEYS[1]))
          if counter == nil then 
             counter = 0
          end
          -- 若到达窗口请求上限,请求失败
          if counter >= limit then
             return 0
          end
          -- 窗口值+1
          redjavascriptis.call("incr", KEYS[1])
          if counter == 0 then
              redis.call("pexpire", KEYS[1], window)
          end
          return 1
          `
          package redis
          
          import (
             "context"
             "errors"
             "github.com/go-redis/redis/v8"
             "time"
          )
          
          // FixedWindowLimiter 固定窗口限流器
          type FixedWindowLimiter struct {
             limit  int           // 窗口请求上限
             window int           // 窗口时间大小
             client *redis.Client // Redis客户端
             script *redis.Script // TryAcquire脚本
          }
          
          func NewFixedWindowLimiter(client *redis.Client, limit int, window time.Duration) (*FixedWindowLimiter, error) {
             // redis过期时间精度最大到毫秒,因此窗口必须能被毫秒整除
             if window%time.Millisecond != 0 {
                return nil, errors.New("the window uint must not be less than millisecond")
             }
          
             return开发者_数据库 &FixedWindowLimiter{
                limit:  limit,
                window: int(window / time.Millisecond),
                client: client,
                script: redis.NewScript(fixedWindowLimiterTryAcquireRedisScript),
             }, nil
          }
          
          func (l *FixedWindowLimiter) TryAcquire(ctx context.Context, resource string) error {
             success, err := l.script.Run(ctx, l.client, []string{resource}, l.window, l.limit).Bool()
             if err != nil {
                return err
             }
             // 若到达窗口请求上限,请求失败
             if !success {
                return ErrAcquireFailed
             }
             return nil
          }

          滑动窗口

          hash实现

          我们使用Redis的hash存储每个小窗口的计数,每次请求会把所有有效窗口的计数累加到count,使用hdel删除失效窗口,最后判断窗口的总计数是否大于上限。

          我们基本上把所有的逻辑都放到Lua脚本里面,其中大头是对hash的遍历,时间复杂度是O(N),N是小窗口数量,所以小窗口数量最好不要太多。

          const slidingWindowLimiterTryAcquireRedisScriptHashImpl = `
          -- ARGV[1]: 窗口时间大小
          -- ARGV[2]: 窗口请求上限
          -- ARGV[3]: 当前小窗口值
          -- ARGV[4]: 起始小窗口值
          
          local window = tonumber(ARGV[1])
          local limit = tonumber(ARGV[2])
          local currentSmallWindow = tonumber(ARGV[3])
          local startSmallWindow = tonumber(ARGV[4])
          
          -- 计算当前窗口的请求总数
          local counters = redis.call("hgetall", KEYS[1])
          local count = 0
          for i = 1, #(counters) / 2 do 
             local smallWindow = tonumber(counters[i * 2 - 1])
             local counter = tonumber(counters[i * 2])
             if smallWindow < startSmallWindow then
                redis.call("hdel", KEYS[1], smallWindow)
             else 
                count = count + counter
             end
          end
          
          -- 若到达窗口请求上限,请求失败
          if count >= limit then
             return 0
          end
          
          -- 若没到窗口请求上限,当前小窗口计数器+1,请求成功
          redis.call("hincrby", KEYS[1], currentSmallWindow, 1)
          redis.call("pexpire", KEYS[1], window)
          return 1
          `
          package redis
          
          import (
             "context"
             "errors"
             "github.com/go-redis/redis/v8"
             "time"
          )
          
          // SlidingWindowLimiter 滑动窗口限流器
          type SlidingWindowLimiter struct {
             limit        int           // 窗口请求上限
             window       int64         // 窗口时间大小
             smallWindow  int64         // 小窗口时间大小
             smallWindows int64         // 小窗口数量
             client       *redis.Client // Redis客户端
             script       *redis.Script // TryAcquire脚本
          }
          
          func NewSlidingWindowLimiter(client *redis.Client, limit int, window, smallWindow time.Duration) (
             *SlidingWindowLimiter, error) {
             // redis过期时间精度最大到毫秒,因此窗口必须能被毫秒整除
             if window%time.Millisecond != 0 || smallWindow%time.Millisecond != 0 {
                return nil, errors.New("the window uint must not be less than millisecond")
             }
          
             // 窗口时间必须能够被小窗口时间整除
             if window%smapythonllWindow != 0 {
                return nil, errors.New("window cannot be split by integers")
             }
          
             return &SlidingWindowLimiter{
                limit:        limit,
                window:       int64(window / time.Millisecond),
                smallWindow:  int64(smallWindow / time.Millisecond),
                smallWindows: int64(window / smallWindow),
                client:       client,
                script:       redis.NewScript(slidingWindowLimiterTryAcquireRedisScriptHashImpl),
             }, nil
          }
          
          func (l *SlidingWindowLimiter) TryAcquire(ctx context.Context, resource string) error {
             // 获取当前小窗口值
             currentSmallWindow := time.Now().UnixMilli() / l.smallWindow * l.smallWindow
             // 获取起始小窗口值
             startSmallWindow := currentSmallWindow - l.smallWindow*(l.smallWindows-1)
          
             success, err := l.script.Run(
                ctx, l.client, []string{resource}, l.window, l.limit, currentSmallWindow, startSmallWindow).Bool()
             if err != nil {
                return err
             }
             // 若到达窗口请求上限,请求失败
             if !success {
                return ErrAcquireFailed
             }
             return nil
          }

          list实现

          如果小窗口数量特别多,可以使用list优化时间复杂度,list的结构是:

          [counter, smallWindow1, count1, smallWindow2, count2, smallWindow3, count3...]

          也就是我们使用list的第一个元素存储计数器,每个窗口用两个元素表示,第一个元素表示小窗口值,第二个元素表示这个小窗口的计数。不直接把小窗口值和计数放到一个元素里是因为Redis Lua脚本里没有分割字符串的函数。

          具体操作流程:

          1.获取list长度

          2.如果长度是0,设置counter,长度+1

          3.如果长度大于1,获取第二第三个元素

          如果该值小于起始小窗口值,counter-第三个元素的值,删除第二第三个元素,长度-2

          4.如果counter大于等于limit,请求失败

          5.如果长度大于1,获取倒数第二第一个元素

          • 如果倒数第二个元素小窗口值大于等于当前小窗口值,表示当前请求因为网络延迟的问题,到达服务器的时候,窗口已经过时了,把倒数第二个元素当成当前小窗口(因为它更新),倒数第一个元素值+1
          • 否则,添加新的窗口值,添加新的计数(1),更新过期时间

          6.否则,添加新的窗口值,添加新的计数(1),更新过期时间

          7.counter + 1

          8.返回成功

          const slidingWindowLimiterTryAcquireRedisScriptListImpl = `
          -- ARGV[1]: 窗口时间大小
          -- ARGV[2]: 窗口请求上限
          -- ARGV[3]: 当前小窗口值
          -- ARGV[4]: 起始小窗口值
          
          local window = tonumber(ARGV[1])
          local limit = tonumber(ARGV[2])
          local currentSmallWindow = tonumber(ARGV[3])
          local startSmallWindow = tonumber(ARGV[4])
          
          -- 获取list长度
          local len = redis.call("llen", KEYS[1])
          -- 如果长度是0,设置counter,长度+1
          local counter = 0
          if len == 0 then 
             redis.call("rpush", KEYS[1], 0)
             redis.call("pexpire", KEYS[1], window)
             len = len + 1
          else
             -- 如果长度大于1,获取第二第个元素
             local smallWindow1 = tonumber(redis.call("lindex", KEYS[1], 1))
             counter = tonumber(redis.call("lindex", KEYS[1], 0))
             -- 如果该值小于起始小窗口值
             if smallWindow1 < startSmallWindow then 
                local count1 = redis.call("lindex", KEYS[1], 2)
                -- counter-第三个元素的值
                counter = counter - count1
                -- 长度-2
                len = len - 2
                -- 删除第二第三个元素
                redis.call("lrem", KEYS[1], 1, smallWindow1)
                redis.call("lrem", KEYS[1], 1, count1)
             end
          end
          
          -- 若到达窗口请求上限,请求失败
          if counter >= limit then 
             return 0
          end 
          
          -- 如果长度大于1,获取倒数第二第一个元素
          if len > 1 then
             local smallWindown = tonumber(redis.call("lindex", KEYS[1], -2))
             -- 如果倒数第二个元素小窗口值大于等于当前小窗口值
             if smallWindown >= currentSmallWindow then
                -- 把倒数第二个元素当成当前小窗口(因为它更新),倒数第一个元素值+1
                local countn = redis.call("lindex", KEYS[1], -1)
                redis.call("lset", KEYS[1], -1, countn + 1)
             else 
                -- 否则,添加新的窗口值,添加新的计数(1),更新过期时间
                redis.call("rpush", KEYS[1], currentSmallWindow, 1)
                redis.call("pexpire", KEYS[1], window)
             end
          else 
             -- 否则,添加新的窗口值,添加新的计数(1),更新过期时间
             redis.call("rpush", KEYS[1], currentSmallWindow, 1)
             redis.call("pexpire", KEYS[1], window)
          end 
          
          -- counter + 1并更新
          redis.call("lset", KEYS[1], 0, counter + 1)
          return 1
          `

          算法都是操作list头部或者尾部,所以时间复杂度接近O(1)

          漏桶算法

          漏桶需要保存当前水位和上次放水时间,因此我们使用hash来保存这两个值。

          const leakyBucketLimiterTryAcquireRedisScript = `
          -- ARGV[1]: 最高水位
          -- ARGV[2]: 水流速度/秒
          -- ARGV[3]: 当前时间(秒)
          
          local peakLevel = tonumber(ARGV[1])
          local currentVelocity = tonumber(ARGV[2])
          local now = tonumber(ARGV[3])
          
          local lastTime = tonumber(redis.call("hget", KEYS[1], "lastTime"))
          local currentLevel = tonumber(redis.call("hget", KEYS[1], "currentLevel"))
          -- 初始化
          if lastTime == nil then 
             lastTime = now
             currentLevel = 0
             redis.call("hmset", KEYS[1], "currentLevel", currentLevel, "lastTime", lastTime)
          end 
          
          -- 尝试放水
          -- 距离上次放水的时间
          local interval = now - lastTime
          if interval > 0 then
             -- 当前水位-距离上次放水的时间(秒)*水流速度
             local newLevel = currentLevel - interval * currentVwww.devze.comelocity
             if newLevel < 0 then 
                newLevel = 0
             end 
             currentLevel = newLevel
             redis.call("hmset", KEYS[1], "currentLevel", newLevel, "lastTime", now)
          end
          
          -- 若到达最高水位,请求失败
          if currentLevel >= peakLevel then
             return 0
          end
          -- 若没有到达最高水位,当前水位+1,请求成功
          redis.call("hincrby", KEYS[1], "currentLevel", 1)
          redis.call("expire", KEYS[1], peakLevel / currentVelocity)
          return 1
          `
          package redis
          
          import (
             "context"
             "github.com/go-redis/redis/v8"
             "time"
          )
          
          // LeakyBucketLimiter 漏桶限流器
          type LeakyBucketLimiter struct {
             peakLevel       int           // 最高水位
             currentVelocity int           // 水流速度/秒
             client          *redis.Client // Redis客户端
             script          *redis.Script // TryAcquire脚本
          }
          
          func NewLeakyBucketLimiter(client *redis.Client, peakLevel, currentVelocity int) *LeakyBucketLimiter {
             return &LeakyBucketLimiter{
                peakLevel:       peakLevel,
                currentVelocity: currentVelocity,
                client:          client,
                script:          redis.NewScript(leakyBucketLimiterTryAcquireRedisScript),
             }
          }
          
          func (l *LeakyBucketLimiter) TryAcquire(ctx context.Context, resource string) error {
             // 当前时间
             now := time.Now().Unix()
             success, err := l.script.Run(ctx, l.client, []string{resource}, l.peakLevel, l.currentVelocity, now).Bool()
             if err != nil {
                return err
             }
             // 若到达窗口请求上限,请求失败
             if !success {
                return ErrAcquireFailed
             }
             return nil
          }

          令牌桶

          令牌桶可以看作是漏桶的相反算法,它们一个是把水倒进桶里,一个是从桶里获取令牌。

          const tokenBucketLimiterTryAcquireRedisScript = `
          -- ARGV[1]: 容量
          -- ARGV[2]: 发放令牌速率/秒
          -- ARGV[3]: 当前时间(秒)
          
          local capacity = tonumber(ARGV[1])
          local rate = tonumber(ARGV[2])
          local now = tonumber(ARGV[3])
          
          local lastTime = tonumber(redis.call("hget", KEYS[1], "lastTime"))
          local currentTokens = tonumber(redis.call("hget", KEYS[1], "currentTokens"))
          -- 初始化
          if lastTime == nil then 
             lastTime = now
             currentTokens = capacity
             redis.call("hmset", KEYS[1], "currentTokens", currentTokens, "lastTime", lastTime)
          end 
          
          -- 尝试发放令牌
          -- 距离上次发放令牌的时间
          local interval = now - lastTime
          if interval > 0 then
             -- 当前令牌数量+距离上次发放令牌的时间(秒)*发放令牌速率
             local newTokens = currentTokens + interval * rate
             if newTokens > capacity then 
                newTokens = capacity
             end 
             currentTokens = newTokens
             redis.call("hmset", KEYS[1], "currentTokens", newTokens, "lastTime", now)
          end
          
          -- 如果没有令牌,请求失败
          if currentTokens == 0 then
             return 0
          end
          -- 果有令牌,当前令牌-1,请求成功www.devze.com
          redis.call("hincrby", KEYS[1], "currentTokens", -1)
          redis.call("expire", KEYS[1], capacity / rate)
          return 1
          `
          package redis
          
          import (
             "context"
             "github.com/go-redis/redis/v8"
             "time"
          )
          
          // TokenBucketLimiter 令牌桶限流器
          type TokenBucketLimiter struct {
             capacity int           // 容量
             rate     int           // 发放令牌速率/秒
             client   *redis.Client // Redis客户端
             script   *redis.Script // TryAcquire脚本
          }
          
          func NewTokenBucketLimiter(client *redis.Client, capacity, rate int) *TokenBucketLimiter {
             return &TokenBucketLimiter{
                capacity: capacity,
                rate:     rate,
                client:   client,
                script:   redis.NewScript(tokenBucketLimiterTryAcquireRedisScript),
             }
          }
          
          func (l *TokenBucketLimiter) TryAcquire(ctx context.Context, resource string) error {
             // 当前时间
             now := time.Now().Unix()
             success, err := l.script.Run(ctx, l.client, []string{resource}, l.capacity, l.rate, now).Bool()
             if err != nil {
                return err
             }
             // 若到达窗口请求上限,请求失败
             if !success {
                return ErrAcquireFailed
             }
             return nil
          }

          滑动日志

          算法流程与滑动窗口相同,只是它可以指定多个策略,同时在请求失败的时候,需要通知调用方是被哪个策略所拦截。

          const slidingLogLimiterTryAcquireRedisScriptHashImpl = `
          -- ARGV[1]: 当前小窗口值
          -- ARGV[2]: 第一个策略的窗口时间大小
          -- ARGV[i * 2 + 1]: 每个策略的起始小窗口值
          -- ARGV[i * 2 + 2]: 每个策略的窗口请求上限
          
          local currentSmallWindow = tonumber(ARGV[1])
          -- 第一个策略的窗口时间大小
          local window = tonumber(ARGV[2])
          -- 第一个策略的起始小窗口值
          local startSmallWindow = tonumber(ARGV[3])
          local strategiesLen = #(ARGV) / 2 - 1
          
          -- 计算每个策略当前窗口的请求总数
          local counters = redis.call("hgetall", KEYS[1])
          local counts = {}
          -- 初始化counts
          for j = 1, strategiesLen do
             counts[j] = 0
          end
          
          for i = 1, #(counters) / 2 do 
             local smallWindow = tonumber(counters[i * 2 - 1])
             local counter = tonumber(counters[i * 2])
             if smallWindow < startSmallWindow then
                redis.call("hdel", KEYS[1], smallWindow)
             else 
                for j = 1, strategiesLen do
                   if smallWindow >= tonumber(ARGV[j * 2 + 1]) then
                      counts[j] = counts[j] + counter
                   end
                end
             end
          end
          
          -- 若到达对应策略窗口请求上限,请求失败,返回违背的策略下标
          for i = 1, strategiesLen do
             if counts[i] >= tonumber(ARGV[i * 2 + 2]) then
                return i - 1
             end
          end
          
          -- 若没到窗口请求上限,当前小窗口计数器+1,请求成功
          redis.call("hincrby", KEYS[1], currentSmallWindow, 1)
          redis.call("pexpire", KEYS[1], window)
          return -1
          `
          package redis
          
          import (
             "context"
             "errors"
             "fmt"
             "github.com/go-redis/redis/v8"
             "sort"
             "time"
          )
          
          // ViolationStrategyError 违背策略错误
          type ViolationStrategyError struct {
             Limit  int           // 窗口请求上限
             Window time.Duration // 窗口时间大小
          }
          
          func (e *ViolationStrategyError) Error() string {
             return fmt.Sprintf("violation strategy that limit = %d and window = %d", e.Limit, e.Window)
          }
          
          // SlidingLogLimiterStrategy 滑动日志限流器的策略
          type SlidingLogLimiterStrategy struct {
             limit        int   // 窗口请求上限
             window       int64 // 窗口时间大小
             smallWindows int64 // 小窗口数量
          }
          
          func NewSlidingLogLimiterStrategy(limit int, window time.Duration) *SlidingLogLimiterStrategy {
             return &SlidingLogLimiterStrategy{
                limit:  limit,
                window: int64(window),
             }
          }
          
          // SlidingLogLimiter 滑动日志限流器
          type SlidingLogLimiter struct {
             strategies  []*SlidingLogLimiterStrategy // 滑动日志限流器策略列表
             smallWindow int64                        // 小窗口时间大小
             client      *redis.Client                // Redis客户端
             script      *redis.Script                // TryAcquire脚本
          }
          
          func NewSlidingLogLimiter(client *redis.Client, smallWindow time.Duration, strategies ...*SlidingLogLimiterStrategy) (
             *SlidingLogLimiter, error) {
             // 复制策略避免被修改
             strategies = append(make([]*SlidingLogLimiterStrategy, 0, len(strategies)), strategies...)
          
             // 不能不设置策略
             if len(strategies) == 0 {
                return nil, errors.New("must be set strategies")
             }
          
             // redis过期时间精度最大到毫秒,因此窗口必须能被毫秒整除
             if smallWindow%time.Millisecond != 0 {
                return nil, errors.New("the window uint must not be less than millisecond")
             }
             smallWindow = smallWindow / time.Millisecond
             for _, strategy := range strategies {
                if strategy.window%int64(time.Millisecond) != 0 {
                   return nil, errors.New("the window uint must not be less than millisecond")
                }
                strategy.window = strategy.window / int64(time.Millisecond)
             }
          
             // 排序策略,窗口时间大的排前面,相同窗口上限大的排前面
             sort.Slice(strategies, func(i, j int) bool {
                a, b := strategies[i], strategies[j]
                if a.window == b.window {
                   return a.limit > b.limit
                }
                return a.window > b.window
             })
          
             for i, strategy := range strategies {
                // 随着窗口时间变小,窗口上限也应该变小
                if i > 0 {
                   if strategy.limit >= strategies[i-1].limit {
                      return nil, errors.New("the smaller window should be the smaller limit")
                   }
                }
                // 窗口时间必须能够被小窗口时间整除
                if strategy.window%int64(smallWindow) != 0 {
                   return nil, errors.New("window cannot be split by integers")
                }
           编程客栈     strategy.smallWindows = strategy.window / int64(smallWindow)
             }
          
             return &SlidingLogLimiter{
                strategies:  strategies,
                smallWindow: int64(smallWindow),
                client:      client,
                script:      redis.NewScript(slidingLogLimiterTryAcquireRedisScriptHashImpl),
             }, nil
          }
          
          func (l *SlidingLogLimiter) TryAcquire(ctx context.Context, resource string) error {
             // 获取当前小窗口值
             currentSmallWindow := time.Now().UnixMilli() / l.smallWindow * l.smallWindow
             args := make([]interface{}, len(l.strategies)*2+2)
             args[0] = currentSmallWindow
             args[1] = l.strategies[0].window
             // 获取每个策略的起始小窗口值
             for i, strategy := range l.strategies {
                args[i*2+2] = currentSmallWindow - l.smallWindow*(strategy.smallWindows-1)
                args[i*2+3] = strategy.limit
             }
          
             index, err := l.script.Run(
                ctx, l.client, []string{resource}, args...).Int()
             if err != nil {
                return err
             }
             // 若到达窗口请求上限,请求失败
             if index != -1 {
                return &ViolationStrategyError{
                   Limit:  l.strategies[index].limit,
                   Window: time.Duration(l.strategies[index].window),
                }
             }
             return nil
          }

          总结

          由于Redis拥有丰富而且高性能的数据类型,因此使用Redis实现限流算法并不困难,但是每个算法都需要编写Lua脚本,所以如果不熟悉Lua可能会踩一些坑。

          需要完整代码和测试代码可以查看:github.com/jiaxwu/limiter/tree/main/redis

          以上就是Go+Redis实现常见限流算法的示例代码的详细内容,更多关于Go Redis限流算法的资料请关注我们其它相关文章!

          0

          精彩评论

          暂无评论...
          验证码 换一张
          取 消

          关注公众号