理解Redis分布式锁-1(简单实现)

Posted by AceKei on January 15, 2020

1. 为什么要使用分布式锁

在单机情况下,当我们有多个线程一起操作某一个共享变量的时候,如果不使用锁(单机锁),可能发生非原子性的操作,导致最后的结果有问题,此时,一般会使用一个来锁 互斥,以保证共享变量的原子性和正确性。

但是在微服务的集群下,是没有单机锁这个概念的。如果有多个请求同时发起来修改MySQL的某一条数据,为了避免程序数据错乱,此时,我们可以使用 分布式锁 来解决这个问题。

对于接口的幂等,一般也是使用分布式锁来解决的。

2. 注意问题

2.1 为什么需要设置过期时间

如果key没有设置过期时间,那么服务端运行期间,突然宕机了,那么这个锁将永远不会过期,导致后续的请求,都获取不到分布式锁。

2.2 为什么使用lua脚本

redis 设置分布式锁的时候,一般分为2个步骤

  1. setnx(key, value) = 如果key不存在,则设置key=value,如果key已经存在,则不操作。
  2. expire(key, time) = 对key设置过期时间。 如果 redis 在执行命令的时候,第一步执行完成了,第二步还没执行就宕机了,那么也就相当于这个key没有设置过期时间。
2.3 为什么要有一个随机唯一的value

一般来说,我们会在方法的 finally 里执行 redis.remove 操作,如果没有设置 一个随机唯一的value,那么会存在当前的进程释放了其他进程的锁,导致分布式失效。 对于 redis的释放锁操作,也是配好lua脚本。

3. 怎么实现

搭配自定义的注解和AOP来实现简单的分布式锁。

3.1 声明注解 AnIdempotent

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import java.lang.annotation.*;

@Target(ElementType.METHOD)
@Documented
@Retention(RetentionPolicy.RUNTIME)
public @interface AnIdempotent {

    /**
     * 过期时间,秒
     * 默认 60 秒
     */
    int expireTimes() default (60);

    /**
     * 可用于自定义的类里的某一个字段
     * 通过反射获取对应的值,进行hash后作为redis的key
     * @return
     */
    String[] fields() default {};

    /**
     * redis 的 key 的hash策略
     * 1 = 方法的所有入参拼接后hash
     * 2 = 无参,直接竞争
     * 3 = 使用 fields
     *
     * @return
     */
    HashType hashType() default HashType.REQUEST;

    enum HashType {
        //方法的所有入参拼接后hash
        REQUEST,
        //无参,直接竞争
        NO,
        //使用 fields
        FIELDS,
        ;
    }

}

3.2 构造hash策略

先将hash声明成bean,交给spring管理。然后通过策略模式+工厂模式选择对应的策略解析。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public interface LockHash {

    /**
     * hash策略类型
     * @return
     */
    AnIdempotent.HashType hashType();

    /**
     * hash策略
     * @param os
     * @param anIdempotent
     * @return
     */
    String hash(Object[] os, AnIdempotent anIdempotent);

}
fields策略
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import cn.hutool.crypto.digest.MD5;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.springframework.stereotype.Component;

import java.lang.reflect.Field;
import java.util.Optional;

@Slf4j
@Component
public class FieldsHash implements LockHash {

    MD5 md5 = new MD5();

    @Override
    public AnIdempotent.HashType hashType() {
        return AnIdempotent.HashType.FIELDS;
    }

    public String hash(Object[] o, AnIdempotent anIdempotent) {
        String[] spels = anIdempotent.fields();
        String sessionId = "";
        String hash;
        if (spels.length == 0) {
            //没有定义fields时,相当于直接竞争
            hash = "no fields";
        } else {
            StringBuilder sb = new StringBuilder();
            for (String s : spels) {
                Object o0 = o[0];
                Field field = FieldUtils.getDeclaredField(o0.getClass(), s, true);
                if (field == null) {
                    log.warn("field 不存在{}", s);
                    sb.append(s).append("=null");
                } else {
                    try {
                        Object temp = field.get(o0);
                        if (temp == null) {
                            log.warn("field 获取值为空{}", s);
                            sb.append(s).append("=null");
                        } else {
                            sb.append(temp);
                        }
                    } catch (Exception e) {
                        log.error("field 获取值失败{}", s, e);
                        sb.append(s).append("=null");
                    }
                }
            }
            hash = md5.digestHex(sb.toString());
        }
        return hash;
    }
}
无参策略
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import org.springframework.stereotype.Component;

@Component
public class NoParamsHash implements LockHash {

    @Override
    public AnIdempotent.HashType hashType() {
        return AnIdempotent.HashType.NO;
    }

    public String hash(Object[] os, AnIdempotent anIdempotent) {
        return "no params";
    }

}
所有入参hash策略
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import cn.hutool.json.JSONUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

@Slf4j
@Component
public class RequestHash implements LockHash {

    @Override
    public AnIdempotent.HashType hashType() {
        return AnIdempotent.HashType.REQUEST;
    }

    public String hash(Object[] os, AnIdempotent anIdempotent) {
        StringBuilder sb = new StringBuilder();
        for (Object o1 : os) {
            sb.append(getString(o1));
        }
        return fnvHash(sb.toString());
    }

    private String getString(Object o) {
        StringBuilder sb = new StringBuilder();
        if (o == null) {
            sb.append("null");
        } else if (o instanceof String || o instanceof Number) {
            sb.append(o);
        } else {
            try {
                sb.append(JSONUtil.toJsonStr(o));
            } catch (Exception e) {
                log.error("参数JSON失败", e);
                sb.append(o);
            }
        }
        return sb.toString();
    }

    /**
     * 改进的32位FNV算法1
     *
     * @param data 字符串
     * @return hash结果
     */
    private String fnvHash(String data) {
        final int p = 16777619;
        int hash = (int) 2166136261L;
        for (int i = 0; i < data.length(); i++) {
            hash = (hash ^ data.charAt(i)) * p;
        }
        hash += hash << 13;
        hash ^= hash >> 7;
        hash += hash << 3;
        hash ^= hash >> 17;
        hash += hash << 5;
        int r = Math.abs(hash);
        return String.valueOf(r);
    }

}
工厂
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import java.util.HashMap;
import java.util.Map;

@Component
public class LockHashFactory {

    @Autowired
    private ApplicationContext applicationContext;
    private final Map<AnIdempotent.HashType, LockHash> map = new HashMap<>();

    @PostConstruct
    public void initLock() {
        Map<String, LockHash> oo = applicationContext.getBeansOfType(LockHash.class);
        oo.forEach((k, v) -> {
            AnIdempotent.HashType hashType = v.hashType();
            map.putIfAbsent(hashType, v);
        });
    }

    public LockHash getLockHash(AnIdempotent.HashType hashType) {
        return map.get(hashType);
    }

}

3.3 编写Redis工具类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import lombok.extern.slf4j.Slf4j;
import redis.clients.jedis.Jedis;

import java.util.Arrays;
import java.util.Collections;
import java.util.Objects;

/**
 * @author sk
 * @date 2022/4/24
 * @description 描述
 */
@Slf4j
public class MyRedisUtil {

    /**
     * 1 = 设置成功
     * 2 = 设置失败
     * @param key
     * @param value
     * @param times
     * @return
     */
    public static boolean acquire(String key, String value, int times) {
        log.info("MyRedisUtil-acquire:key={},value={},times={}", key, value, times);
        String defaultTime = String.valueOf(times);
        String lua = "if redis.call('setnx',KEYS[1],ARGV[1]) == 1 then return redis.call('expire',KEYS[1],ARGV[2]) else return 0 end";
        Jedis jedis = new Jedis("host", 17379);
        try {
            Object result = jedis.evalsha(jedis.scriptLoad(lua), Collections.singletonList(key), Arrays.asList(value, defaultTime));
            return Objects.equals(result, 1L);
        } finally {
        }
    }

    /**
     * 1 = 设置成功
     * 2 = 设置失败
     * @param key
     * @param value
     * @return
     */
    public static boolean release(String key, String value) {
        log.info("MyRedisUtil-release:key={},value={}", key, value);
        String lua = "if redis.call('get',KEYS[1]) == ARGV[1] then return redis.call('del',KEYS[1]) else return 0 end";
        Jedis jedis = new Jedis("host", 17379);
        try {
            Object result = jedis.evalsha(jedis.scriptLoad(lua), Collections.singletonList(key), Collections.singletonList(value));
            return Objects.equals(result, 1L);
        } finally {
        }
    }

    
}

3.4 编写AOP,拦截方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.RandomStringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

@Aspect
@Component
@Slf4j
public class IdempotentAspect {

    @Autowired
    private LockHashFactory lockHashFactory;

    /**
     * 注解AOP切面
     */
    @Pointcut("execution(* com.example.demo.main.AspectTest.*(..)) && @annotation(AnIdempotent)")
    public void filter() {

    }
    @Around(value = "filter()")
    public Object around(ProceedingJoinPoint pj) {
        String key = null;
        //随机唯一的value
        String value = RandomStringUtils.randomAlphanumeric(32);
        try {
            AnIdempotent anIdempotent = ((MethodSignature) pj.getSignature()).getMethod().getAnnotation(AnIdempotent.class);
            if (anIdempotent != null) {
                Object[] o = pj.getArgs();
                String hash;
                String methodName = pj.getSignature().getName();
                if (o != null && o.length > 0) {
                    LockHash lockHash = lockHashFactory.getLockHash(anIdempotent.hashType());
                    hash = lockHash.hash(o, anIdempotent);
                } else {
                    hash = lockHashFactory.getLockHash(AnIdempotent.HashType.NO).hash(o, anIdempotent);
                }
                //防止key重复,需要拼接一下前缀
                key = "AnIdempotent:" + methodName + ":" + hash;
                if (!MyRedisUtil.acquire(key, value, anIdempotent.expireTimes())) {
                    return ResponseVo.error("操作已提交,请勿频繁操作");
                }
            }
            return pj.proceed();
        } catch (Exception e) {
            log.error("Exception:", e);
            return ResponseVo.error("系统异常");
        } catch (Throwable throwable) {
            log.error("环绕通知切面处理失败:", throwable);
            return ResponseVo.error("系统异常");
        } finally {
            if (key != null) {
                try {
                    //释放锁,附带value,防止释放别人的锁
                    MyRedisUtil.release(key, value);
                } catch (Exception e) {
                    log.error("删除redis失败:key={}", key, e);
                }
            }
        }
    }

}

至此,一个简单的分布式锁就实现了。