MrShi
2025-08-19 ff087240b3dee29ce4e14ad0836e76b9fdf312cf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package com.doumee.config.mybatis;
 
import com.doumee.core.model.LoginUserInfo;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.shiro.SecurityUtils;
import org.springframework.stereotype.Component;
import org.springframework.util.ReflectionUtils;
 
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*;
 
/**
 * MyBatis 拦截器
 * - INSERT语句默认填充创建人和创建时间字段
 * - UPDATE语句默认填充更新人和更新时间字段
 * @author  dm
 * @since 2025/03/31 16:44
 */
@Slf4j
@Component
@Intercepts({
    @Signature(type= Executor.class, method = "update", args={MappedStatement.class, Object.class})
})
public class MyBatisInterceptor implements Interceptor {
 
    private static final String CREATE_TIME = "createTime";
 
    private static final String CREATE_USER = "createUser";
 
    private static final String UPDATE_TIME = "updateTime";
 
    private static final String UPDATE_USER = "updateUser";
 
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
        Object target = invocation.getArgs()[1];
 
        if(target instanceof MapperMethod.ParamMap) {
            try {
                target = ((MapperMethod.ParamMap) target).get("param1");
            } catch (Exception e) {
            }
        }
        if (target == null)
            return invocation.proceed();
        // 创建语句
        if (SqlCommandType.INSERT == sqlCommandType) {
            this.handleOperaStatement(target, CREATE_TIME, CREATE_USER);
        }
        // 更新语句
        else if (SqlCommandType.UPDATE == sqlCommandType) {
            this.handleOperaStatement(target, UPDATE_TIME, UPDATE_USER);
        }
        return invocation.proceed();
    }
 
    @Override
    public Object plugin(Object o) {
        return Plugin.wrap(o, this);
    }
 
    /**
     * 处理新增和编辑语句
     */
    private void handleOperaStatement(Object target, String... fieldNames) throws Exception{
        // 操作时间
        Field operaTimeField = ReflectionUtils.findField(target.getClass(), fieldNames[0]);
        if (operaTimeField != null) {
            Object operaTime = this.getFieldValue(operaTimeField, target);
            if (operaTime == null) {
                this.setFieldValue(operaTimeField, target, new Date());
            }
        }
        // 操作人
        Field operaUserField = ReflectionUtils.findField(target.getClass(), fieldNames[1]);
        if (operaUserField != null) {
            Object operaUser = this.getFieldValue(operaUserField, target);
            LoginUserInfo user = this.getLoginUser();
            if (operaUser == null && user!=null) {
                this.setFieldValue(operaUserField, target, user.getId());
            }
        }
    }
 
    /**
     * 给属性赋值
     */
    private void setFieldValue(Field field, Object target, Object value)   {
        try {
        field.setAccessible(true);
        field.set(target, value);
        field.setAccessible(false);
        }catch (Exception e){
            e.printStackTrace();
        }
    }
 
    /**
     * 获取属性值
     */
    private Object getFieldValue(Field field, Object target)  {
        try {
            field.setAccessible(true);
            Object value = field.get(target);
            field.setAccessible(false);
            return value;
        }catch (Exception e){
            e.printStackTrace();
        }
      return  null;
    }
 
    /**
     * 获取登录用户信息
     */
    private LoginUserInfo getLoginUser () {
        try {
              return (LoginUserInfo) SecurityUtils.getSubject().getPrincipal();
        }catch (Exception e){
 
        }
        return null;
    }
}