# 分布式限流解决方案

# 依赖

  • 增加aop依赖
<!-- Spring AOP support, needed for the @Aspect-based RateLimiterHandler to be woven. -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- AspectJ runtime API (@Aspect, @Pointcut, @Around, ProceedingJoinPoint).
     No <version> is given: presumably managed by the Spring Boot parent/BOM - confirm. -->
<dependency>
    <groupId>org.aspectj</groupId>
    <artifactId>aspectjrt</artifactId>
</dependency>
  • 修改 pom 文件，将 .lua 后缀的文件打包进项目（注意：声明 `<includes>` 后只有匹配的文件会被复制到 classpath，项目原有的其他资源文件也需要一并列出）
    <build>
        <resources>
            <resource>
                <directory>src/main/resources</directory>
                <filtering>true</filtering>
                <!-- Once <includes> is declared, ONLY matching files are copied to the
                     classpath, so every other resource pattern must be listed here too. -->
                <includes>
                    <!-- ...此处省略... -->
                    <include>**/*.lua</include>
                    <!-- ...此处省略... -->
                </includes>
            </resource>
        </resources>
    </build>

# 代码清单

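  • 新增 Lua 脚本 `src/main/resources/RateLimiter.lua`（RateLimiterHandler 通过 `new ClassPathResource("RateLimiter.lua")` 从 classpath 根目录加载该脚本，文件名需保持一致）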
-- Fixed-window rate limiter script.
--
-- KEYS[1]  counter key for one business id
-- ARGV[1]  window length in seconds, applied as the key's TTL
-- ARGV[2]  maximum number of requests allowed per window
--
-- Returns 1 when this request is within the limit, 0 when it exceeds it.
-- Returns an error reply, without touching the counter, when ARGV[1] or
-- ARGV[2] is not numeric.
local key = KEYS[1]

-- Parse and validate the arguments BEFORE mutating anything: writes made by a
-- script are not rolled back if it fails later, so an error raised after INCR
-- (e.g. EXPIRE called with a nil TTL) would leave a counter without any TTL.
local expire = tonumber(ARGV[1])
local times = tonumber(ARGV[2])

redis.log(redis.LOG_DEBUG, tostring(ARGV[2]))
redis.log(redis.LOG_DEBUG, tostring(ARGV[1]))

if expire == nil or times == nil then
  return redis.error_reply("RateLimiter.lua: ARGV[1] (expire) and ARGV[2] (times) must be numeric, got "
    .. tostring(ARGV[1]) .. ", " .. tostring(ARGV[2]))
end

local val = redis.call('incr', key)
local ttl = redis.call('ttl', key)

-- Per-request trace. DEBUG rather than NOTICE so it is not written to the
-- server log on every call under the default log level.
redis.log(redis.LOG_DEBUG, "incr " .. key .. " " .. val)

-- Start the window on the first hit (val == 1), and repair a counter that
-- exists without a TTL (ttl == -1) so it cannot keep growing forever.
if val == 1 or ttl == -1 then
  redis.call('expire', key, expire)
end

if val > times then
  return 0
end

return 1

# 限流工具类

  • RateLimiter
/**
* The type IntelliJ IDEA.
* <p>
*
* @author liuxiaolu
* @date 2021/4/25
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {

    /**
    * 限流key keyPattern
    *
    * @return keyPattern
    */
    String keyPattern() default "rate:limiter:%s";

    /**
    * 业务id
    *
    * @return key
    */
    String bizId();

    /**
    * 单位时间限制通过请求数
    *
    * @return limit
    */
    String limit();

    /**
    * 限流的单位时间 单位:s
    *
    * @return expire
    */
    String expire();
}
  • RateLimiterParam
@Target(ElementType.PARAMETER)
@Inherited
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimiterParam {
    String value();
}
  • RateLimiterHandler
/**
 * Aspect enforcing {@code @RateLimiter}: a fixed-window request counter per
 * business id, kept in Redis and maintained by the {@code RateLimiter.lua} script.
 * <p>
 * When the limit for the current window is exceeded, the calling thread sleeps for
 * one full window and retries. Callers block instead of receiving an error.
 */
@Slf4j
@Aspect
@Component
public class RateLimiterHandler {

    /**
     * Sentinel limit meaning "no limit". An empty {@code limit} expression, or one
     * evaluating to {@code null}, resolves to this value; any negative limit is
     * treated the same way.
     */
    private static final Integer DEFAULT_LIMIT_TIME = -1;

    /**
     * Default window length in seconds. Used when the {@code expire} expression is
     * empty, evaluates to {@code null}, or yields a non-positive value.
     */
    private static final Integer DEFAULT_LIMIT_TIME_EXPIRE = 60;

    /**
     * Shared SpEL parser; Spring documents {@code SpelExpressionParser} instances as
     * reusable and thread-safe, so one is kept instead of one per evaluation.
     */
    private static final ExpressionParser EXPRESSION_PARSER = new SpelExpressionParser();

    @Autowired
    private RedisTemplate redisTemplate;

    private DefaultRedisScript<Long> redisScript;

    /**
     * Loads {@code RateLimiter.lua} from the classpath root into a script whose
     * result is read as a {@link Long}.
     */
    @PostConstruct
    public void init() {
        redisScript = new DefaultRedisScript<>();
        redisScript.setResultType(Long.class);
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("RateLimiter.lua")));
        log.info("RateLimiterHandler[分布式限流处理器]脚本加载完成");
    }

    /**
     * Matches every method annotated with {@code @RateLimiter}.
     */
    @Pointcut("@annotation(cn.com.xxx.xxx.xxx.annotation.RateLimiter)")
    public void pointCut() {
    }

    /**
     * Applies the limit described by the target method's {@code @RateLimiter}.
     * <p>
     * The call proceeds immediately in any of these cases:
     * <ul>
     *     <li>the annotation is not found on the target method;</li>
     *     <li>the business id evaluates to {@code null};</li>
     *     <li>the limit is negative (unlimited).</li>
     * </ul>
     * Otherwise the Lua script is run, and while it reports the limit as exceeded
     * the thread sleeps for one window and retries.
     *
     * @param proceedingJoinPoint the intercepted invocation
     * @return the intercepted method's result
     * @throws InterruptedException if interrupted while waiting for the next window
     *                              (the interrupt flag is restored before rethrowing)
     * @throws Throwable            anything raised by the intercepted method, by
     *                              expression evaluation or by Redis access
     */
    @Around(value = "pointCut()")
    public Object around(ProceedingJoinPoint proceedingJoinPoint) throws Throwable {
        if (log.isDebugEnabled()) {
            log.debug("RateLimiterHandler[分布式限流处理器]开始执行限流操作");
        }
        Method targetMethod = findTargetMethod(proceedingJoinPoint);
        RateLimiter rateLimiter = targetMethod.getDeclaredAnnotation(RateLimiter.class);
        if (rateLimiter == null) {
            // Method annotations are not inherited, so this can presumably happen when the
            // annotation is declared on an interface method rather than on the target class.
            log.warn("未在目标方法[{}]上获取到限流注解, 请求放行。", targetMethod);
            return proceedingJoinPoint.proceed();
        }

        // bizId, limit and expire are all evaluated against the same variables: build them once.
        StandardEvaluationContext context = buildEvaluationContext(targetMethod, proceedingJoinPoint.getArgs());

        // Object rather than Long: the id is only formatted into the key and logged,
        // so String/Integer ids must not fail with a ClassCastException.
        Object bizId = evaluate(context, rateLimiter.bizId());
        if (Objects.isNull(bizId)) {
            log.warn("未获取到业务ID, 请求放行。");
            return proceedingJoinPoint.proceed();
        }

        // Counter key for this business id.
        String key = String.format(rateLimiter.keyPattern(), bizId);

        // Limit threshold; a missing value means unlimited.
        Integer limitTimes = Optional.ofNullable(toInteger(evaluate(context, rateLimiter.limit()), "limit"))
                .orElse(DEFAULT_LIMIT_TIME);
        if (limitTimes < 0) {
            // Any negative limit would otherwise always be exceeded and block forever.
            if (log.isDebugEnabled()) {
                log.debug("业务[{}]调用频率无限制, 请求放行", bizId);
            }
            return proceedingJoinPoint.proceed();
        }

        // Window length in seconds. A non-positive value would make the retry loop
        // below spin without sleeping, so fall back to the default window.
        Integer expire = toInteger(evaluate(context, rateLimiter.expire()), "expire");
        Integer expireTime = (expire == null || expire <= 0) ? DEFAULT_LIMIT_TIME_EXPIRE : expire;

        if (log.isDebugEnabled()) {
            log.debug("RateLimiterHandler[分布式限流处理器]参数值为-limitTimes={},limitTimeout={}", limitTimes, expireTime);
        }

        List<String> keyList = new ArrayList<>();
        keyList.add(key);

        while (!tryAcquire(keyList, expireTime, limitTimes)) {
            if (log.isDebugEnabled()) {
                log.debug("由于业务[{}]的调用频率, 在单位时间[{}]s内, 超出允许的最大请求次数[{}], 触发[限流机制], 该任务将等待下一个时间窗口运行。", bizId,
                          expireTime, limitTimes);
            }
            waitForNextWindow(expireTime);
        }

        if (log.isDebugEnabled()) {
            log.debug("业务[{}]的请求, [正常]响应", bizId);
        }
        return proceedingJoinPoint.proceed();
    }

    /**
     * Runs the Lua script once for the given counter key.
     * <p>
     * The script reads {@code ARGV[1]} (window seconds) and {@code ARGV[2]} (limit)
     * with {@code tonumber}. Both are therefore passed as decimal strings through the
     * template's String serializer, independent of the template's value serializer
     * (a JDK serializer would hand the script binary blobs). Keys still use the
     * template's key serializer.
     *
     * @return {@code true} if the script returned a non-zero result; {@code null} or
     *         {@code 0} is treated as "limit exceeded"
     */
    private boolean tryAcquire(List<String> keys, Integer expireTime, Integer limitTimes) {
        Long result = (Long) redisTemplate.execute(redisScript,
                redisTemplate.getStringSerializer(), redisTemplate.getStringSerializer(),
                keys, String.valueOf(expireTime), String.valueOf(limitTimes));
        return result != null && result != 0L;
    }

    /**
     * Sleeps for one full window. If interrupted, restores the thread's interrupt
     * flag (cleared by {@code sleep}) before rethrowing.
     */
    private static void waitForNextWindow(Integer expireTime) throws InterruptedException {
        try {
            TimeUnit.SECONDS.sleep(expireTime);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw e;
        }
    }

    /**
     * Resolves the invoked method on the target object's class, so annotations
     * declared on the implementation are visible.
     * <p>
     * NOTE(review): {@link Class#getMethod} only finds public methods; a non-public
     * advised method would raise {@link NoSuchMethodException} here.
     */
    private Method findTargetMethod(ProceedingJoinPoint pjp) throws NoSuchMethodException {
        MethodSignature signature = (MethodSignature) pjp.getSignature();
        return pjp.getTarget().getClass().getMethod(signature.getName(), signature.getParameterTypes());
    }

    /**
     * Builds a SpEL context binding each {@code @RateLimiterParam}-annotated argument
     * as a variable named by the annotation's value (referenced as {@code #name}).
     */
    private static StandardEvaluationContext buildEvaluationContext(Method targetMethod, Object[] args) {
        StandardEvaluationContext context = new StandardEvaluationContext();
        Parameter[] parameters = targetMethod.getParameters();
        for (int index = 0; index < parameters.length; index++) {
            RateLimiterParam param = parameters[index].getAnnotation(RateLimiterParam.class);
            if (param != null) {
                context.setVariable(param.value(), args[index]);
            }
        }
        return context;
    }

    /**
     * Evaluates a SpEL expression against the context.
     *
     * @return the expression's value, or {@code null} for a {@code null}/empty expression
     */
    private static Object evaluate(StandardEvaluationContext context, String expression) {
        if (expression == null || expression.isEmpty()) {
            return null;
        }
        return EXPRESSION_PARSER.parseExpression(expression).getValue(context);
    }

    /**
     * Converts an expression result to an {@link Integer}.
     * <p>
     * Accepts any {@link Number} (fractional parts are truncated) or a numeric string.
     * {@code null} and blank strings map to {@code null}.
     *
     * @param value     the evaluated expression result
     * @param attribute annotation attribute name, used in the error message
     * @return the integer value, or {@code null}
     * @throws IllegalArgumentException if a string value is not a valid integer
     * @throws ArithmeticException      if a numeric value does not fit in an int
     */
    private static Integer toInteger(Object value, String attribute) {
        if (value == null) {
            return null;
        }
        if (value instanceof Number) {
            return Math.toIntExact(((Number) value).longValue());
        }
        String text = value.toString().trim();
        if (text.isEmpty()) {
            return null;
        }
        try {
            return Integer.valueOf(text);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("@RateLimiter." + attribute + " 的表达式结果无法转换为整数: " + value, e);
        }
    }
}
Last Updated: 12/2/2021, 9:29:16 PM