1200字范文,内容丰富有趣,写作的好帮手!
1200字范文 > 多线程如何实现事务回滚?一招帮你解决

多线程如何实现事务回滚?一招帮你解决

时间:2022-03-06 16:16:06

相关推荐

多线程如何实现事务回滚?一招帮你解决

特别说明CountDownLatch

**CountDownLatch是一个类springboot自带的类,可以直接用,**变量AtomicBoolean 也是可以直接使用

CountDownLatch的用法

CountDownLatch典型用法:

1、某一线程在开始运行前等待n个线程执行完毕。将CountDownLatch的计数器初始化为new CountDownLatch(n),每当一个任务线程执行完毕,就将计数器减1 countdownLatch.countDown(),当计数器的值变为0时,在CountDownLatch上await()的线程就会被唤醒。一个典型应用场景就是启动一个服务时,主线程需要等待多个组件加载完毕,之后再继续执行。

2、实现多个线程开始执行任务的最大并行性。注意是并行性,不是并发,强调的是多个线程在某一时刻同时开始执行。类似于赛跑,将多个线程放到起点,等待发令枪响,然后同时开跑。做法是初始化一个共享的CountDownLatch(1),将其计算器初始化为1,多个线程在开始执行任务前首先countdownlatch.await(),当主线程调用countDown()时,计数器变为0,多个线程同时被唤醒。

CountDownLatch(num) 简单说明

new 一个 CountDownLatch(num) 对象

建立对象的时候 num 代表的是需要等待 num 个线程

// 建立对象的时候 num 代表的是需要等待 num 个线程//主线程CountDownLatch mainThreadLatch = new CountDownLatch(num);//子线程CountDownLatch rollBackLatch = new CountDownLatch(1);

主线程:mainThreadLatch.await() 和mainThreadLatch.countDown()

新建对象

CountDownLatch mainThreadLatch = new CountDownLatch(num);

卡住主线程,让其等待子线程,代码mainThreadLatch.await(),放在主线程里

mainThreadLatch.await();

代码mainThreadLatch.countDown(),放在子线程里,每一个子线程运行一到这个代码,意味着CountDownLatch(num),里面的num-1(自动减一)

mainThreadLatch.countDown();

CountDownLatch(num)里面的num减到0,也就是CountDownLatch(0),被卡住的主线程mainThreadLatch.await(),就会往下执行

子线程:rollBackLatch.await() 和rollBackLatch.countDown()

新建对象,特别注意:子线程这个num就是1(关于只能为1的解答在后面)

CountDownLatch rollBackLatch = new CountDownLatch(1);

卡住子线程,阻止每一个子线程的事务提交和回滚

rollBackLatch.await();

代码rollBackLatch.countDown();放在主线程里,而且是放在主线程的等待代码mainThreadLatch.await();后面。

rollBackLatch.countDown();

为什么所有的子线程会在一瞬间就被所有都释放了?

事务的回滚是怎么结合进去的?

假设总共20个子线程,那么其中一个线程报错了怎么实现所有线程回滚。

引入变量

AtomicBoolean rollbackFlag = new AtomicBoolean(false)

和字面意思是一样的:根据 rollbackFlag 的true或者false 判断子线程里面,是否回滚。

首先我们确定的一点:rollbackFlag 是所有的子线程都用着这一个判断

主线程类Entry

package org.apache.dolphinscheduler.api.utils;import com.alibaba.fastjson.JSONArray;import com.alibaba.fastjson.JSONObject;import org.apache.dolphinscheduler.api.controller.WorkThread;import org.mon.enums.DbType;import org.springframework.web.bind.annotation.*;import java.text.SimpleDateFormat;import java.util.ArrayList;import java.util.Date;import java.util.List;import java.util.TimeZone;import java.util.concurrent.CountDownLatch;import java.util.concurrent.atomic.AtomicBoolean;@RestController@RequestMapping("importDatabase")public class Entry {/*** @param dbid 数据库的id* @param tablename 表名* @param sftpFileName 文件名称* @param head 是否有头文件* @param splitSign 分隔符* @param type 数据库类型*/private static String SFTP_HOST = "192.168.1.92";private static int SFTP_PORT = 22;private static String SFTP_USERNAME = "root";private static String SFTP_PASSWORD = "rootroot";private static String SFTP_BASEPATH = "/opt/testSFTP/";@PostMapping("/thread")@ResponseBodypublic static JSONObject importDatabase(@RequestParam("dbid") int dbid,@RequestParam("tablename") String tablename,@RequestParam("sftpFileName") String sftpFileName,@RequestParam("head") String head,@RequestParam("splitSign") String splitSign,@RequestParam("type") DbType type,@RequestParam("heads") String heads,@RequestParam("scolumns") String scolumns,@RequestParam("tcolumns") String tcolumns ) throws Exception {JSONObject obForRetrun = new JSONObject();try {JSONArray jsonArray = JSONArray.parseArray(tcolumns);JSONArray scolumnArray = JSONArray.parseArray(scolumns);JSONArray headsArray = JSONArray.parseArray(heads);List<Integer> listInteger = getRrightDataNum(headsArray,scolumnArray);JSONArray bodys = SFTPUtils.getSftpContent(SFTP_HOST,SFTP_PORT,SFTP_USERNAME,SFTP_PASSWORD,SFTP_BASEPATH,sftpFileName,head,splitSign);int total = bodys.size();int num = 20; //一个批次的数据有多少int count = total/num;//周期int lastNum =total- count*num;//余数List<Thread> list = new ArrayList<Thread>();SimpleDateFormat sdf = new SimpleDateFormat("HH:mm:ss:SS");TimeZone t = sdf.getTimeZone();t.setRawOffset(0);sdf.setTimeZone(t);Long startTime=System.currentTimeMillis();int countForCountDownLatch = 0;if(lastNum==0){//整除countForCountDownLatch= count;}else{countForCountDownLatch= count + 1;}//子线程CountDownLatch rollBackLatch = new CountDownLatch(1);//主线程CountDownLatch mainThreadLatch = new CountDownLatch(countForCountDownLatch);AtomicBoolean rollbackFlag = new AtomicBoolean(false);StringBuffer message = new StringBuffer();message.append("报错信息:");//子线程for(int i=0;i<count;i++) {//这里的count代表有几个线程Thread g = new Thread(new WorkThread(i,num,tablename,jsonArray,dbid,type,bodys,listInteger,mainThreadLatch,rollBackLatch,rollbackFlag,message ));g.start();list.add(g);}if(lastNum!=0){//有小数的情况下Thread g = new Thread(new WorkThread(0,lastNum,tablename,jsonArray,dbid,type,bodys,listInteger,mainThreadLatch,rollBackLatch,rollbackFlag,message ));g.start();list.add(g);}// for(Thread thread:list){//System.out.println(thread.getState());//thread.join();//是等待这个线程结束;// }mainThreadLatch.await();//所有等待的子线程全部放开rollBackLatch.countDown();//是主线程等待子线程的终止。也就是说主线程的代码块中,如果碰到了t.join()方法,此时主线程需要等待(阻塞),等待子线程结束了(Waits for this thread to die.),才能继续执行t.join()之后的代码块。Long endTime=System.currentTimeMillis();System.out.println("总共用时: "+sdf.format(new Date(endTime-startTime)));if(rollbackFlag.get()){obForRetrun.put("code",500);obForRetrun.put("msg",message);}else{obForRetrun.put("code",200);obForRetrun.put("msg","提交成功!");}obForRetrun.put("data",null);}catch (InterruptedException e){e.printStackTrace();obForRetrun.put("code",500);obForRetrun.put("msg",e.getMessage());obForRetrun.put("data",null);}return obForRetrun;}/*** 文件里第几列被作为导出列* @param headsArray* @param scolumnArray* @return*/public static List<Integer> getRrightDataNum(JSONArray headsArray, JSONArray scolumnArray){List<Integer> list = new ArrayList<Integer>();String arrayA [] = new String[headsArray.size()];for(int i=0;i<headsArray.size();i++){JSONObject ob = (JSONObject)headsArray.get(i);arrayA[i] =String.valueOf(ob.get("title"));}String arrayB [] = new String[scolumnArray.size()];for(int i=0;i<scolumnArray.size();i++){JSONObject ob = (JSONObject)scolumnArray.get(i);arrayB[i] =String.valueOf(ob.get("columnName"));}for(int i =0;i<arrayA.length;i++){for(int j=0;j<arrayB.length;j++){if(arrayA[i].equals(arrayB[j])){list.add(i);break;}}}return list;}}

子线程类WorkThread

package org.apache.dolphinscheduler.api.controller;import com.alibaba.fastjson.JSONArray;import com.alibaba.fastjson.JSONObject;import org.apache.dolphinscheduler.api.service.DataSourceService;import org.mon.enums.DbType;import org.apache.dolphinscheduler.dao.entity.DataSource;import org.apache.dolphinscheduler.dao.mapper.DataSourceMapper;import org.apache.dolphinscheduler.service.bean.SpringApplicationContext;import org.springframework.transaction.PlatformTransactionManager;import java.sql.Connection;import java.sql.PreparedStatement;import java.sql.SQLException;import java.text.ParseException;import java.text.SimpleDateFormat;import java.util.Date;import java.util.List;import java.util.TimeZone;import java.util.concurrent.CountDownLatch;import java.util.concurrent.atomic.AtomicBoolean;/*** 多线程*/public class WorkThread implements Runnable{//建立线程的两种方法 1 实现Runnable 接口 2 继承 Thread 类private DataSourceService dataSourceService;private DataSourceMapper dataSourceMapper;private Integer begin;private Integer end;private String tableName;private JSONArray columnArray;private Integer dbid;private DbType type;private JSONArray bodys;private List<Integer> listInteger;private PlatformTransactionManager transactionManager;private CountDownLatch mainThreadLatch;private CountDownLatch rollBackLatch;private AtomicBoolean rollbackFlag;private StringBuffer message;/*** @param i* @param num* @param tableFrom* @param columnArrayFrom* @param dbidFrom* @param typeFrom*/public WorkThread(int i, int num, String tableFrom, JSONArray columnArrayFrom, int dbidFrom, DbType typeFrom, JSONArray bodysFrom, List<Integer> listIntegerFrom,CountDownLatch mainThreadLatch,CountDownLatch rollBackLatch,AtomicBoolean rollbackFlag,StringBuffer messageFrom) {begin=i*num;end=begin+num;tableName = tableFrom;columnArray = columnArrayFrom;dbid = dbidFrom;type = typeFrom;bodys = bodysFrom;listInteger = listIntegerFrom;this.dataSourceMapper = SpringApplicationContext.getBean(DataSourceMapper.class);this.dataSourceService = SpringApplicationContext.getBean(DataSourceService.class);this.transactionManager = SpringApplicationContext.getBean(PlatformTransactionManager.class);this.mainThreadLatch = mainThreadLatch;this.rollBackLatch = rollBackLatch;this.rollbackFlag = rollbackFlag;this.message = messageFrom;}public void run() {DataSource dataSource = dataSourceMapper.queryDataSourceByID(dbid);String cp = dataSource.getConnectionParams();Connection con=null;con = dataSourceService.getConnection(type,cp);if(con!=null){SimpleDateFormat sdf = new SimpleDateFormat("HH:mm:ss:SS");TimeZone t = sdf.getTimeZone();t.setRawOffset(0);sdf.setTimeZone(t);Long startTime = System.currentTimeMillis();try {con.setAutoCommit(false);//---------------------------- 获取字段和类型String columnString = null;//活动的字段int intForType = 0;String type[] = new String[columnArray.size()];//类型集合for(int i=0;i<columnArray.size();i++){JSONObject ob = (JSONObject)columnArray.get(i);if(columnString==null){columnString = String.valueOf(ob.get("name"));}else{columnString = columnString + ","+ String.valueOf(ob.get("name"));}type[intForType] = String.valueOf(ob.get("type"));intForType = intForType + 1;}intForType = 0;//这一步是为了形成 insert into "+tableName+"(id,name,age) values (?,?,?);String dataString = null;for(int i=0;i<columnArray.size();i++){if(dataString==null){dataString = "?";}else{dataString = dataString +","+"?";}}//--------------------------------StringBuffer sql = new StringBuffer();sql = sql.append("insert into "+tableName+"("+columnString+") values ("+dataString+")") ;PreparedStatement pst= (PreparedStatement)con.prepareStatement(sql.toString());for(int i=begin;i<end;i++) {JSONObject ob = (JSONObject)bodys.get(i);if(ob!=null){String [] array = ob.get(i).toString().split("\\,");String [] arrayFinal = getFinalData(listInteger,array);for(int j=0;j<type.length;j++){String typeString = type[j].toLowerCase();int z = j+1;if("string".equals(typeString)||"varchar".equals(typeString)){pst.setString(z,arrayFinal[j]);//这里的第一个参数 是指 替换第几个?}else if("int".equals(typeString)||"bigint".equals(typeString)){pst.setInt(z,Integer.valueOf(arrayFinal[j]));//这里的第一个参数 是指 替换第几个?}else if("long".equals(typeString)){pst.setLong(z,Long.valueOf(arrayFinal[j]));//这里的第一个参数 是指 替换第几个?}else if("double".equals(typeString)){pst.setDouble(z,Double.parseDouble(arrayFinal[j]));}else if("date".equals(typeString)||"datetime".equals(typeString)){pst.setDate(z, setDateback(arrayFinal[j]));}else if("Timestamp".equals(typeString)){pst.setTimestamp(z, setTimestampback(arrayFinal[j]));}}}pst.addBatch();}pst.executeBatch();mainThreadLatch.countDown();rollBackLatch.await();if(rollbackFlag.get()){con.rollback();//表示回滚事务;}else{mit();//事务提交}con.close();} catch (Exception e) {System.out.println(e.getMessage());message = message.append(e.getMessage());rollbackFlag.set(true);mainThreadLatch.countDown();try {con.close();} catch (SQLException throwables) {throwables.printStackTrace();}}Long endTime = System.currentTimeMillis();System.out.println(Thread.currentThread().getName()+":startTime= "+sdf.format(new Date(startTime))+",endTime= "+sdf.format(new Date(endTime))+" 用时:"+sdf.format(new Date(endTime - startTime)));}}public java.sql.Date setDateback(String dateString) throws ParseException {SimpleDateFormat sdf = new SimpleDateFormat( "yyyy-MM-dd HH:mm:ss" );java.util.Date date = sdf.parse( "-5-6 10:30:00" );long lg = date.getTime();// 日期 转 时间戳return new java.sql.Date( lg );}public java.sql.Timestamp setTimestampback(String dateString) throws ParseException {SimpleDateFormat sdf = new SimpleDateFormat( "yyyy-MM-dd HH:mm:ss" );java.util.Date date = sdf.parse( "-5-6 10:30:00" );long lg = date.getTime();// 日期 转 时间戳return new java.sql.Timestamp( lg );}public String [] getFinalData(List<Integer> listInteger,String[] array){String [] arrayFinal = new String [listInteger.size()];for(int i=0;i<listInteger.size();i++){int a = listInteger.get(i);arrayFinal[i] = array[a];}return arrayFinal;}}

代码实际运用踩坑!!!!

还记得这里有个一批次处理多少数据么,我这边设置了20,实际到运用中的时候客户给了个20W的数据,我批次设置为20,那就有1W个子线程!!!!

这还不是最糟糕的,最糟糕的是每个子线程都会创建一个数据库连接,数据库直接被我搞炸了

所以这里需要把:

int num = 20; //一个批次的数据有多少

改成:

int num = 20000; //一个批次的数据有多少

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。