上图阐释了如何基于mq实现2pc的分布式事务
- 一阶段为红线部分。
- 二阶段为蓝线部分。
图中展示了较为复杂的调用方式,S1调用S2、S3,S3又调用了S4。
感谢seata开源社区大佬的帮助。虽然2pc本身存在很多问题,但自己手动实现一遍还是学到了很多。
本文仅做参考,不具备生产意义。
seata社区陈建斌大佬指正的问题列表如下:
问题
第一:tm需要有事务记录表,用来恢复事务。此外还要考虑这种情况:rm没有任何异常,只是因为tm宕机,导致tm的二阶段提交决议没有入库,结果rm本应提交的事务变成了回滚。
第二:需要把connection换为xaconnection,使用xa协议来保证rm宕机后事务数据可恢复。
第三:要保证消息队列中间件的高可用。
第四:要防止资源悬挂问题。因为没有了分支事务注册,很可能因为网络或者其它因素导致消息先发后至,使tm没有感知到这个rm的存在,这个rm就可能因为使用了xa协议而导致死锁。
show me the code
根据上图,我们可以实现如下代码,此处基于rocketmq实现。
引入以下包
<!-- Spring Boot AOP starter: AspectJ annotation support for the @Aspect classes below (GlobalTrxAspect, DataSourceAspect). -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- Dubbo: only needed by DubboTrxPropagationFilter. Scope is "provided", so the host application must put Dubbo on the runtime classpath itself. -->
<dependency>
<groupId>org.apache.dubbo</groupId>
<artifactId>dubbo</artifactId>
<version>2.7.2</version>
<scope>provided</scope>
</dependency>
<!-- RocketMQ Spring starter: provides RocketMQTemplate (used to publish the phase-two decision) and @RocketMQMessageListener. -->
<!-- NOTE(review): the code also imports com.alibaba.fastjson, org.apache.commons.lang3 and org.apache.commons.collections, none of which is declared here; presumably they arrive transitively, verify before relying on it. -->
<dependency>
<groupId>org.apache.rocketmq</groupId>
<artifactId>rocketmq-spring-boot-starter</artifactId>
<version>2.1.1</version>
</dependency>
全局事务注解:此注解用于开启全局事务,真正的本地事务仍交给@Transactional注解去执行。
package com.xxx.mq.trx.config;
import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Marks a method as the entry point of an MQ-coordinated global transaction.
 * <p>
 * The annotation only opens the global transaction (it is matched by the pointcut in
 * {@code GlobalTrxAspect}); the local database work is still expected to run under
 * Spring's {@code @Transactional}.
 * <p>
 * Note: {@link Inherited} only has an effect on annotations placed on classes, so it does
 * nothing for this method-level annotation.
 *
 * @author 姚仲杰
 * @since 2021/1/2
 */
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface GlobalTransaction {
    // marker annotation: no attributes
}
全局事务切面
package com.xxx.mq.trx.aspect;
import com.xxx.mq.trx.config.TransactionConst;
import com.xxx.mq.trx.core.TrxContextHolder;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import org.apache.rocketmq.spring.core.RocketMQTemplate;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.StringUtils;
/**
* @Description TODO
* @Author 姚仲杰
* @Date 2021/1/2 21:38
*/
public class GlobalTrxAspect {
@Autowired
RocketMQTemplate rocketMQTemplate;
@Pointcut("@annotation(com.xxx.mq.trx.config.GlobalTransaction)")
public void pointcut(){}
@Around("pointcut()")
public void around(ProceedingJoinPoint joinPoint) throws Throwable {
//方法执行前需生成trx_id
//判断是否事务发起者,如果能从线程上下文取到事务id说明是参与者,如果取不到则是事务管理者。
String trx_id = TrxContextHolder.getTrxId();
boolean isManager = false;
if (StringUtils.isEmpty(trx_id)) {
UUID uuid = UUID.randomUUID();
TrxContextHolder.setTrxId(uuid.toString());
isManager=true;
}
Map map=new HashMap(2);
map.put(TransactionConst.TRX_ID,trx_id);
try {
joinPoint.proceed();
map.put(trx_id, TransactionConst.COMMIT);
} catch (Throwable throwable) {
map.put(trx_id, TransactionConst.ROLLBACK);
throw throwable;
}finally {
//方法执行后需发送消息告知所有事务参与者是提交还是回滚
if(isManager) {
Message msg = MessageBuilder.withPayload(map).build();
rocketMQTemplate.send(TransactionConst.TRX_TOPIC, msg);
}
}
}
}
事务常量定义
package com.xxx.mq.trx.config;
/**
 * Constants shared by the global-transaction aspect, the MQ listener and the RPC propagation
 * filter.
 *
 * @author 姚仲杰
 * @since 2021/1/4
 */
public interface TransactionConst {

    /** Phase-two state telling participants to commit their local transaction. */
    int COMMIT = 1;

    /** Phase-two state telling participants to roll back their local transaction. */
    int ROLLBACK = 0;

    /** Key carrying the global transaction id (message payload and RPC attachment). */
    String TRX_ID = "trx_id";

    /** Topic on which the phase-two decision is published. */
    String TRX_TOPIC = "global_trx_topic";

    /** Consumer group of the phase-two decision listener. */
    String TRX_GROUP = "global_trx_group";
}
package com.xxx.mq.trx.aspect;
import com.xxx.mq.trx.core.ConnectionProxy;
import com.xxx.mq.trx.core.TrxContextHolder;
import java.sql.Connection;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.springframework.beans.factory.ObjectFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
/**
 * Intercepts {@code DataSource.getConnection(..)} so that, inside a global transaction, the
 * returned connection is wrapped in a {@link ConnectionProxy} whose commit is taken over by the
 * MQ-driven phase two instead of the local transaction manager.
 * <p>
 * Outside a global transaction (no trx id on the thread) the physical connection is returned
 * untouched.
 *
 * @author 姚仲杰
 * @since 2021/01/04
 */
@Aspect
@Component
public class DataSourceAspect {

    @Autowired
    ObjectFactory<ConnectionProxy> bean;

    ReentrantLock lock = new ReentrantLock();

    @Around("execution(* javax.sql.DataSource.getConnection(..))")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        Connection physical = (Connection) joinPoint.proceed();
        String currentTrxId = TrxContextHolder.getTrxId();
        if (!StringUtils.isNotBlank(currentTrxId)) {
            // Not part of a global transaction: hand back the connection as-is.
            return physical;
        }
        ConnectionProxy wrapper = bean.getObject();
        wrapper.setConnection(physical);
        enlist(currentTrxId, wrapper);
        return wrapper;
    }

    /** Appends {@code wrapper} to the connection list registered for {@code trxId}, creating it if absent. */
    private void enlist(String trxId, ConnectionProxy wrapper) {
        lock.lock();
        try {
            List<ConnectionProxy> existing = TrxContextHolder.getConnections(trxId);
            List<ConnectionProxy> target = (existing != null) ? existing : new ArrayList<>();
            target.add(wrapper);
            TrxContextHolder.setConnections(trxId, target);
        } finally {
            lock.unlock();
        }
    }
}
连接代理:让@Transactional注解触发的事务提交(以及回滚、关闭连接)变为空操作,转而由我们自己的mq通知来执行真正的提交或回滚。
package com.xxx.mq.trx.core;
import java.sql.Array;
import java.sql.Blob;
import java.sql.CallableStatement;
import java.sql.Clob;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.NClob;
import java.sql.PreparedStatement;
import java.sql.SQLClientInfoException;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.sql.SQLXML;
import java.sql.Savepoint;
import java.sql.Statement;
import java.sql.Struct;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.Executor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.context.annotation.Scope;
import org.springframework.stereotype.Component;
/**
 * JDBC {@link Connection} wrapper that defers completion of the local transaction to the
 * global-transaction decision delivered over MQ.
 * <p>
 * Local transaction management (e.g. Spring's {@code @Transactional}) still drives this object,
 * but {@link #commit()}, {@link #rollback()} and {@link #close()} are no-ops and
 * {@link #setAutoCommit(boolean)} always keeps auto-commit off, so the work stays in an open
 * transaction. The real commit or rollback, followed by closing the wrapped connection, happens
 * in {@link #notify(int)}. Every other method delegates unchanged to the wrapped connection.
 * <p>
 * Prototype-scoped: each instance wraps exactly one physical connection.
 *
 * @author 姚仲杰
 * @since 2021/1/4
 */
@Component
@Scope(value = ConfigurableBeanFactory.SCOPE_PROTOTYPE)
public class ConnectionProxy implements Connection {

    private static final Logger LOGGER = LoggerFactory.getLogger(ConnectionProxy.class);

    /** Phase-two state meaning "commit"; mirrors TransactionConst.COMMIT. Any other value rolls back. */
    private static final int COMMIT_STATE = 1;

    private Connection connection;

    /**
     * Completes the deferred local transaction once the global outcome is known, then closes the
     * wrapped connection. Invoked by the MQ listener. Failures are logged rather than thrown.
     *
     * @param state {@code 1} to commit, any other value to roll back
     */
    public void notify(int state) {
        if (connection == null) {
            LOGGER.warn("no connection bound to this proxy, ignoring phase-two state {}", state);
            return;
        }
        try {
            if (state == COMMIT_STATE) {
                connection.commit();
            } else {
                connection.rollback();
            }
        } catch (Exception e) {
            LOGGER.error(e.getLocalizedMessage(), e);
        } finally {
            // Always release the physical connection, even if commit/rollback failed;
            // closing only on success leaked the connection on every failed completion.
            closeQuietly();
        }
    }

    /** Closes the wrapped connection, logging (not throwing) any failure. */
    private void closeQuietly() {
        try {
            connection.close();
        } catch (Exception e) {
            LOGGER.error(e.getLocalizedMessage(), e);
        }
    }

    /**
     * Ignores the requested mode and always disables auto-commit, so statements stay in an open
     * transaction until {@link #notify(int)}.
     */
    @Override
    public void setAutoCommit(boolean autoCommit) throws SQLException {
        connection.setAutoCommit(false);
    }

    /** Intentionally a no-op: the real commit is performed by {@link #notify(int)}. */
    @Override
    public void commit() throws SQLException {
        // deferred to notify(int)
    }

    /** Intentionally a no-op: the real rollback is performed by {@link #notify(int)}. */
    @Override
    public void rollback() throws SQLException {
        // deferred to notify(int)
    }

    /**
     * Intentionally a no-op: the wrapped connection must stay open until {@link #notify(int)}
     * completes the transaction and closes it.
     */
    @Override
    public void close() throws SQLException {
        // deferred to notify(int)
    }

    @Override
    public boolean getAutoCommit() throws SQLException {
        return connection.getAutoCommit();
    }

    @Override
    public Statement createStatement() throws SQLException {
        return connection.createStatement();
    }

    @Override
    public PreparedStatement prepareStatement(String sql) throws SQLException {
        return connection.prepareStatement(sql);
    }

    @Override
    public CallableStatement prepareCall(String sql) throws SQLException {
        return connection.prepareCall(sql);
    }

    @Override
    public String nativeSQL(String sql) throws SQLException {
        return connection.nativeSQL(sql);
    }

    /** Reports the wrapped connection's state; it stays open after {@link #close()} until {@link #notify(int)}. */
    @Override
    public boolean isClosed() throws SQLException {
        return connection.isClosed();
    }

    @Override
    public DatabaseMetaData getMetaData() throws SQLException {
        return connection.getMetaData();
    }

    @Override
    public void setReadOnly(boolean readOnly) throws SQLException {
        connection.setReadOnly(readOnly);
    }

    @Override
    public boolean isReadOnly() throws SQLException {
        return connection.isReadOnly();
    }

    @Override
    public void setCatalog(String catalog) throws SQLException {
        connection.setCatalog(catalog);
    }

    @Override
    public String getCatalog() throws SQLException {
        return connection.getCatalog();
    }

    @Override
    public void setTransactionIsolation(int level) throws SQLException {
        connection.setTransactionIsolation(level);
    }

    @Override
    public int getTransactionIsolation() throws SQLException {
        return connection.getTransactionIsolation();
    }

    @Override
    public SQLWarning getWarnings() throws SQLException {
        return connection.getWarnings();
    }

    @Override
    public void clearWarnings() throws SQLException {
        connection.clearWarnings();
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException {
        return connection.createStatement(resultSetType, resultSetConcurrency);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency)
            throws SQLException {
        return connection.prepareStatement(sql, resultSetType, resultSetConcurrency);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        return connection.prepareCall(sql, resultSetType, resultSetConcurrency);
    }

    @Override
    public Map<String, Class<?>> getTypeMap() throws SQLException {
        return connection.getTypeMap();
    }

    @Override
    public void setTypeMap(Map<String, Class<?>> map) throws SQLException {
        connection.setTypeMap(map);
    }

    @Override
    public void setHoldability(int holdability) throws SQLException {
        connection.setHoldability(holdability);
    }

    @Override
    public int getHoldability() throws SQLException {
        return connection.getHoldability();
    }

    @Override
    public Savepoint setSavepoint() throws SQLException {
        return connection.setSavepoint();
    }

    @Override
    public Savepoint setSavepoint(String name) throws SQLException {
        return connection.setSavepoint(name);
    }

    /** Savepoint rollbacks are delegated directly; only the full-transaction rollback is deferred. */
    @Override
    public void rollback(Savepoint savepoint) throws SQLException {
        connection.rollback(savepoint);
    }

    @Override
    public void releaseSavepoint(Savepoint savepoint) throws SQLException {
        connection.releaseSavepoint(savepoint);
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability)
            throws SQLException {
        return connection.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency,
            int resultSetHoldability) throws SQLException {
        return connection.prepareStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency,
            int resultSetHoldability) throws SQLException {
        return connection.prepareCall(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException {
        return connection.prepareStatement(sql, autoGeneratedKeys);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException {
        return connection.prepareStatement(sql, columnIndexes);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException {
        return connection.prepareStatement(sql, columnNames);
    }

    @Override
    public Clob createClob() throws SQLException {
        return connection.createClob();
    }

    @Override
    public Blob createBlob() throws SQLException {
        return connection.createBlob();
    }

    @Override
    public NClob createNClob() throws SQLException {
        return connection.createNClob();
    }

    @Override
    public SQLXML createSQLXML() throws SQLException {
        return connection.createSQLXML();
    }

    @Override
    public boolean isValid(int timeout) throws SQLException {
        return connection.isValid(timeout);
    }

    @Override
    public void setClientInfo(String name, String value) throws SQLClientInfoException {
        connection.setClientInfo(name, value);
    }

    @Override
    public void setClientInfo(Properties properties) throws SQLClientInfoException {
        connection.setClientInfo(properties);
    }

    @Override
    public String getClientInfo(String name) throws SQLException {
        return connection.getClientInfo(name);
    }

    @Override
    public Properties getClientInfo() throws SQLException {
        return connection.getClientInfo();
    }

    @Override
    public Array createArrayOf(String typeName, Object[] elements) throws SQLException {
        return connection.createArrayOf(typeName, elements);
    }

    @Override
    public Struct createStruct(String typeName, Object[] attributes) throws SQLException {
        return connection.createStruct(typeName, attributes);
    }

    @Override
    public void setSchema(String schema) throws SQLException {
        connection.setSchema(schema);
    }

    @Override
    public String getSchema() throws SQLException {
        return connection.getSchema();
    }

    @Override
    public void abort(Executor executor) throws SQLException {
        connection.abort(executor);
    }

    @Override
    public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException {
        connection.setNetworkTimeout(executor, milliseconds);
    }

    @Override
    public int getNetworkTimeout() throws SQLException {
        return connection.getNetworkTimeout();
    }

    @Override
    public <T> T unwrap(Class<T> iface) throws SQLException {
        return connection.unwrap(iface);
    }

    @Override
    public boolean isWrapperFor(Class<?> iface) throws SQLException {
        return connection.isWrapperFor(iface);
    }

    /** Returns the wrapped physical connection (may be {@code null} before {@link #setConnection}). */
    public Connection getConnection() {
        return connection;
    }

    /** Binds the physical connection this proxy delegates to. */
    public void setConnection(Connection connection) {
        this.connection = connection;
    }
}
事务上下文
package com.xxx.mq.trx.core;
import java.util.HashMap;
import java.util.Map;
/**
 * Per-thread string key/value store that backs the transaction context.
 * <p>
 * Each thread lazily gets its own {@link HashMap}, so values put by one thread are never seen
 * by another.
 *
 * @author 姚仲杰
 * @since 2020/12/28
 */
public class TrxContext {

    private final ThreadLocal<Map<String, String>> perThreadEntries = ThreadLocal.withInitial(HashMap::new);

    /** Stores {@code value} under {@code key} for the current thread; returns the previous value or {@code null}. */
    public String put(String key, String value) {
        Map<String, String> current = perThreadEntries.get();
        return current.put(key, value);
    }

    /** Returns the current thread's value for {@code key}, or {@code null} if absent. */
    public String get(String key) {
        Map<String, String> current = perThreadEntries.get();
        return current.get(key);
    }

    /** Removes {@code key} for the current thread; returns the removed value or {@code null}. */
    public String remove(String key) {
        Map<String, String> current = perThreadEntries.get();
        return current.remove(key);
    }

    /** Returns the current thread's live, mutable map (not a copy). */
    public Map<String, String> entries() {
        return perThreadEntries.get();
    }
}
事务上下文持有者:缓存了trxId以及全局事务的连接列表等属性。
package com.xxx.mq.trx.core;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
/**
 * Static access point for the global-transaction state of this JVM.
 * <p>
 * Holds the current thread's trx id (via the thread-local {@link TrxContext}) and, per trx id,
 * the list of {@link ConnectionProxy} instances enlisted in that global transaction. The
 * per-transaction map is shared by all threads.
 *
 * @author 姚仲杰
 * @since 2020/12/28
 */
public class TrxContextHolder {

    private static final Logger LOGGER = LoggerFactory.getLogger(TrxContextHolder.class);

    public static final TrxContext TRX_CONTEXT_HOLDER = new TrxContext();

    // The reference is never reassigned, so final (not volatile) is the right way to publish it.
    private static final ConcurrentHashMap<String, List<ConnectionProxy>> CONNECTIONS = new ConcurrentHashMap<>();

    public static final String TRX_ID = "TRX_ID";

    /** Returns the trx id bound to the current thread, or {@code null} if none. */
    public static String getTrxId() {
        return TRX_CONTEXT_HOLDER.get(TRX_ID);
    }

    /** Binds {@code trxId} to the current thread. */
    public static void setTrxId(String trxId) {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("set trx_id:[{}]", trxId);
        }
        TRX_CONTEXT_HOLDER.put(TRX_ID, trxId);
    }

    /** Unbinds and returns the current thread's trx id ({@code null} if none was bound). */
    public static String removeTrxId() {
        String trxId = TRX_CONTEXT_HOLDER.remove(TRX_ID);
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("remove trx_id:[{}] ", trxId);
        }
        return trxId;
    }

    /**
     * Returns the connections enlisted for {@code trxId}, or {@code null} if none.
     *
     * @throws IllegalArgumentException if {@code trxId} is null or empty
     */
    public static List<ConnectionProxy> getConnections(String trxId) {
        requireTrxId(trxId);
        return CONNECTIONS.get(trxId);
    }

    /**
     * Registers (replaces) the connection list for {@code trxId}.
     *
     * @throws IllegalArgumentException if {@code trxId} is null/empty or {@code connections} is null/empty
     */
    public static void setConnections(String trxId, List<ConnectionProxy> connections) {
        requireTrxId(trxId);
        if (CollectionUtils.isEmpty(connections)) {
            LOGGER.error("connections can not be empty,require at least one connection");
            throw new IllegalArgumentException("connections can not be empty, require at least one connection");
        }
        CONNECTIONS.put(trxId, connections);
    }

    /**
     * Forgets the connection list of {@code trxId}.
     *
     * @throws IllegalArgumentException if {@code trxId} is null or empty
     */
    public static void removeConnections(String trxId) {
        requireTrxId(trxId);
        CONNECTIONS.remove(trxId);
    }

    /** Shared precondition: logs and rejects a null/empty trx id with a descriptive message. */
    private static void requireTrxId(String trxId) {
        if (StringUtils.isEmpty(trxId)) {
            LOGGER.error("trx_id can not be empty");
            throw new IllegalArgumentException("trx_id can not be empty");
        }
    }
}
二阶段提交mq监听器
package com.xxx.mq.trx.core;
import com.alibaba.fastjson.JSON;
import com.xxx.mq.trx.config.TransactionConst;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections.CollectionUtils;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus;
import org.apache.rocketmq.client.consumer.listener.MessageListenerConcurrently;
import org.apache.rocketmq.common.message.MessageExt;
import org.apache.rocketmq.spring.annotation.ConsumeMode;
import org.apache.rocketmq.spring.annotation.MessageModel;
import org.apache.rocketmq.spring.annotation.RocketMQMessageListener;
import org.apache.rocketmq.spring.core.RocketMQListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
/**
 * Phase-two listener: receives the global commit/rollback decision and completes every local
 * transaction this JVM enlisted for that trx id.
 * <p>
 * The decision must reach <em>every</em> participant instance, so the subscription is
 * {@link MessageModel#BROADCASTING}; the previous default (clustering) delivered each decision
 * to only one instance of the consumer group. Instances that did not take part in the
 * transaction simply find no connections and acknowledge.
 * <p>
 * Implements {@link RocketMQListener} so the rocketmq-spring starter can register it; the
 * {@link MessageListenerConcurrently} entry point is kept for direct use with a push consumer.
 * <p>
 * NOTE(review): RocketMQ does not redeliver failed messages in broadcasting mode, so a
 * failure here may effectively drop the decision. Verify this against the broker/client version.
 *
 * @author 姚仲杰
 * @since 2021/1/4
 */
@Component
@RocketMQMessageListener(consumeMode = ConsumeMode.CONCURRENTLY, messageModel = MessageModel.BROADCASTING,
        topic = TransactionConst.TRX_TOPIC, consumerGroup = TransactionConst.TRX_GROUP)
public class TransactionMassageListener implements MessageListenerConcurrently, RocketMQListener<MessageExt> {

    public static final Logger LOGGER = LoggerFactory.getLogger(TransactionMassageListener.class);

    /** Starter entry point: any exception is reported to the container as a consumption failure. */
    @Override
    public void onMessage(MessageExt message) {
        LOGGER.info("receive global transaction message: {}", message);
        handle(message);
    }

    /** Raw-consumer entry point: processes every message of the batch, not only the first one. */
    @Override
    public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs,
            ConsumeConcurrentlyContext context) {
        LOGGER.info("receive global transaction message: {}", msgs);
        if (msgs == null) {
            return ConsumeConcurrentlyStatus.CONSUME_SUCCESS;
        }
        try {
            for (MessageExt messageExt : msgs) {
                handle(messageExt);
            }
        } catch (Exception e) {
            LOGGER.error("failed to apply global transaction decision, will reconsume later", e);
            return ConsumeConcurrentlyStatus.RECONSUME_LATER;
        }
        return ConsumeConcurrentlyStatus.CONSUME_SUCCESS;
    }

    /**
     * Applies one decision message of the form {@code {trx_id: <id>, <id>: COMMIT|ROLLBACK}}.
     * Malformed messages are logged and dropped, because retrying them can never succeed.
     */
    private void handle(MessageExt messageExt) {
        String body = new String(messageExt.getBody(), StandardCharsets.UTF_8);
        Map<?, ?> payload = JSON.parseObject(body, HashMap.class);
        Object trxIdValue = payload == null ? null : payload.get(TransactionConst.TRX_ID);
        if (!(trxIdValue instanceof String) || ((String) trxIdValue).isEmpty()) {
            LOGGER.error("global transaction message without trx_id, dropped: {}", body);
            return;
        }
        String trxId = (String) trxIdValue;
        Object stateValue = payload.get(trxId);
        if (!(stateValue instanceof Number)) {
            LOGGER.error("global transaction message without state for trx_id [{}], dropped: {}", trxId, body);
            return;
        }
        int state = ((Number) stateValue).intValue();
        List<ConnectionProxy> connections = TrxContextHolder.getConnections(trxId);
        if (CollectionUtils.isEmpty(connections)) {
            // Broadcast: this instance holds no connections for the transaction.
            return;
        }
        try {
            connections.forEach(cp -> cp.notify(state));
        } finally {
            TrxContextHolder.removeConnections(trxId);
        }
    }
}
dubbo事务id传播过滤器
package com.xxx.mq.trx.integration.dubbo;
import com.xxx.mq.trx.config.TransactionConst;
import com.xxx.mq.trx.core.TrxContextHolder;
import org.apache.dubbo.common.extension.Activate;
import org.apache.dubbo.rpc.Filter;
import org.apache.dubbo.rpc.Invocation;
import org.apache.dubbo.rpc.Invoker;
import org.apache.dubbo.rpc.Result;
import org.apache.dubbo.rpc.RpcContext;
import org.apache.dubbo.rpc.RpcException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Dubbo filter that propagates the global trx id across RPC calls.
 * <p>
 * If the current thread already carries a trx id, it is attached to the outgoing invocation so
 * the downstream service joins the same global transaction. Otherwise, if the invocation arrived
 * with a trx id attachment, that id is bound to the current thread for the duration of the call
 * and unbound afterwards.
 *
 * @author 姚仲杰
 * @since 2021/01/04
 */
@Activate(group = {"provider", "consumer"}, order = 100)
public class DubboTrxPropagationFilter implements Filter {

    private static final Logger LOGGER = LoggerFactory.getLogger(DubboTrxPropagationFilter.class);

    @Override
    public Result invoke(Invoker<?> invoker, Invocation invocation) throws RpcException {
        final String localTrxId = TrxContextHolder.getTrxId();
        final RpcContext rpcContext = RpcContext.getContext();
        final String remoteTrxId = rpcContext.getAttachment(TransactionConst.TRX_ID);
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("trxId in TrxContext[{}] trxId in RpcContext[{}]", localTrxId, remoteTrxId);
        }

        // Only an id received from upstream (and not already present locally) is bound here.
        final boolean bindRemote = localTrxId == null && remoteTrxId != null;
        if (localTrxId != null) {
            rpcContext.setAttachment(TransactionConst.TRX_ID, localTrxId);
        } else if (bindRemote) {
            TrxContextHolder.setTrxId(remoteTrxId);
        }

        try {
            return invoker.invoke(invocation);
        } finally {
            if (bindRemote) {
                TrxContextHolder.removeTrxId();
            }
        }
    }
}