
Spark SQL User-Defined Functions (UDF) and User-Defined Aggregate Functions (UDAF) Tutorial (Java Pitfalls Edition)

Date: 2023-03-18 12:51:06


Spark also supports Hive-style user-defined functions. They fall roughly into three kinds:

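- UDF (User-Defined Function): the most basic kind of custom function, similar to to_char, to_date, etc.
- UDAF (User-Defined Aggregation Function): a user-defined aggregate function, similar to sum, avg, etc. used after GROUP BY
- UDTF (User-Defined Table-Generating Function): a user-defined table-generating function, somewhat like flatMap on a stream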
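As a rough mental model (a plain-Java analogy built on java.util.stream, not Spark API), the three kinds differ in how many rows go in and come out: a UDF maps one value to one value, a UDAF folds a group of rows into one value, and a UDTF expands one row into zero or more rows. The class and method names below are made up for illustration.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Plain-Java analogy (no Spark) for the three kinds of user-defined functions.
// FunctionShapes, udf, udaf and udtf are illustrative names only.
class FunctionShapes {

    // UDF: one input value -> one output value (like to_char / to_date)
    static List<String> udf(List<String> values) {
        return values.stream()
                .map(v -> v.isEmpty() ? "null" : v)
                .collect(Collectors.toList());
    }

    // UDAF: all rows of a group -> one value (like SUM after GROUP BY)
    static Map<String, Integer> udaf(List<String[]> rows) {
        return rows.stream()
                .collect(Collectors.groupingBy(r -> r[0],
                        Collectors.summingInt(r -> Integer.parseInt(r[1]))));
    }

    // UDTF: one input row -> zero or more output rows (like flatMap)
    static List<String> udtf(List<String> lines) {
        return lines.stream()
                .flatMap(line -> Arrays.stream(line.split(",")))
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(udf(Arrays.asList("a", "", "c")));
        System.out.println(udaf(Arrays.asList(
                new String[]{"a", "3"}, new String[]{"a", "6"}, new String[]{"b", "2"})));
        System.out.println(udtf(Arrays.asList("x,y", "z")));
    }
}
```

The only point of the analogy is the input/output cardinality; Spark's real interfaces for each kind look quite different.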

This post walks you through writing a UDF and a UDAF step by step.

A simple UDF first

Scenario:

We have a text file like this:

1^^d
2^b^d
3^c^d
4^^d

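When reading the data, if the second column is empty it should be displayed as 'null'; otherwise its value is output as-is. Once the function is registered, it can be used directly in Spark SQL.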
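One detail worth checking before writing the function: Java's String.split returns an empty string, not null, for an empty field between two separators, so a plain `field == null` test never fires on a line like `1^^d` unless empty strings are converted to null first. A small standalone check (SplitDemo and its methods are illustrative names):

```java
import java.util.Arrays;

// Shows how String.split treats the empty second column in lines like "1^^d".
class SplitDemo {

    // Splits on '^' the same way the Spark job does (the caret must be escaped in a regex).
    static String[] parse(String line) {
        return line.split("\\^");
    }

    // Converts empty fields to real nulls so that a "field == null" check can fire.
    static String[] parseWithNulls(String line) {
        String[] parts = parse(line);
        for (int i = 0; i < parts.length; i++) {
            if (parts[i].isEmpty()) {
                parts[i] = null;
            }
        }
        return parts;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(parse("1^^d")));          // [1, , d]
        System.out.println(parse("1^^d")[1] == null);                // false
        System.out.println(Arrays.toString(parseWithNulls("1^^d"))); // [1, null, d]
    }
}
```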

The code:

package test;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

/**
 * Created by xinghailong on /2/23.
 */
public class test3 {
    public static void main(String[] args) {
        // Create the Spark runtime environment
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("test-udf");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        SQLContext sqlContext = new SQLContext(sc);

        // Register the custom function: return defaultValue when field is null
        sqlContext.udf().register("isNull", (String field, String defaultValue) -> field == null ? defaultValue : field, DataTypes.StringType);

        // Read the file
        JavaRDD<String> lines = sc.textFile("C:\\test-udf.txt");
        JavaRDD<Row> rows = lines.map(line -> {
            // split() yields "" (not null) for the empty middle column of "1^^d",
            // so turn empty fields into real NULLs for the isNull check to fire
            String[] parts = line.split("\\^");
            Object[] values = new Object[parts.length];
            for (int i = 0; i < parts.length; i++) {
                values[i] = parts[i].isEmpty() ? null : parts[i];
            }
            return RowFactory.create(values);
        });

        List<StructField> structFields = new ArrayList<StructField>();
        structFields.add(DataTypes.createStructField("a", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("c", DataTypes.StringType, true));
        StructType structType = DataTypes.createStructType(structFields);

        DataFrame test = sqlContext.createDataFrame(rows, structType);
        test.registerTempTable("test");

        // Show 'null' in place of an empty second column
        sqlContext.sql("SELECT a, isNull(b, 'null'), c FROM test").show();
        sc.stop();
    }
}

The output is:

+---+----+---+
|  a| _c1|  c|
+---+----+---+
|  1|null|  d|
|  2|   b|  d|
|  3|   c|  d|
|  4|null|  d|
+---+----+---+

The key line is this one:

sqlContext.udf().register("isNull", (String field,String defaultValue)->field==null?defaultValue:field, DataTypes.StringType);

I wrote it with Java 8 lambda syntax. On versions before Java 8, you need an anonymous class implementing org.apache.spark.sql.api.java.UDF2 instead.
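For reference, here is roughly what the pre-Java-8 form looks like. With Spark you would pass `new UDF2<String, String, String>() { ... }` to `register`; to keep this sketch compilable without Spark on the classpath, it declares a hypothetical stand-in interface, TwoArgUdf, with the same single `call` method.

```java
// Pre-Java-8 style: an anonymous class instead of a lambda.
// NOTE: TwoArgUdf is a hypothetical stand-in with the same shape as Spark's
// org.apache.spark.sql.api.java.UDF2<T1, T2, R> (one method: R call(T1, T2)),
// declared here only so the sketch compiles without Spark.
interface TwoArgUdf<T1, T2, R> extends java.io.Serializable {
    R call(T1 t1, T2 t2) throws Exception;
}

class AnonymousUdfDemo {

    // The same "isNull" logic as the lambda, written as an anonymous class
    // (no diamond operator, which anonymous classes only allow from Java 9).
    static final TwoArgUdf<String, String, String> IS_NULL = new TwoArgUdf<String, String, String>() {
        @Override
        public String call(String field, String defaultValue) {
            return field == null ? defaultValue : field;
        }
    };

    public static void main(String[] args) throws Exception {
        System.out.println(IS_NULL.call(null, "null")); // null
        System.out.println(IS_NULL.call("b", "null"));  // b
    }
}
```

In the Spark job, the body of `call` would stay the same; only the interface changes to UDF2 and the instance is passed to `sqlContext.udf().register("isNull", ..., DataTypes.StringType)`.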

Next, a custom UDAF: computing an average

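Let's start with the simplest UDAF: computing an average. Many similar operations, such as max, min, sum and string concatenation, can be implemented the same way.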

First, define the UDAF:

package test;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

/**
 * Created by xinghailong on /2/23.
 */
public class MyAvg extends UserDefinedAggregateFunction {

    // Input: one string column holding the number
    @Override
    public StructType inputSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("field1", DataTypes.StringType, true));
        return DataTypes.createStructType(structFields);
    }

    // Intermediate buffer: field1 = count, field2 = sum
    @Override
    public StructType bufferSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("field1", DataTypes.IntegerType, true));
        structFields.add(DataTypes.createStructField("field2", DataTypes.IntegerType, true));
        return DataTypes.createStructType(structFields);
    }

    // Return a Double so that (3 + 6) / 2 is 4.5 rather than 4
    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    // The same input always produces the same average
    @Override
    public boolean deterministic() {
        return true;
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0);
        buffer.update(1, 0);
    }

    // Called once per input row within a partition
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        if (input.isNullAt(0)) {
            return; // skip NULLs, like the built-in avg
        }
        buffer.update(0, buffer.getInt(0) + 1);
        buffer.update(1, buffer.getInt(1) + Integer.parseInt(input.getString(0)));
    }

    // Combines the partial buffers of different partitions
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
        buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1));
    }

    // Cast before dividing; int / int would truncate the result
    @Override
    public Object evaluate(Row buffer) {
        if (buffer.getInt(0) == 0) {
            return null; // no non-NULL input rows
        }
        return (double) buffer.getInt(1) / buffer.getInt(0);
    }
}

To use it, register it first; after that it can be called directly in Spark SQL:

package test;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

/**
 * Created by xinghailong on /2/23.
 */
public class test4 {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("test");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        SQLContext sqlContext = new SQLContext(sc);

        // Register the UDAF under the name used in SQL (MyAvg is in the same package, so no import is needed)
        sqlContext.udf().register("my_avg", new MyAvg());

        JavaRDD<String> lines = sc.textFile("C:\\test4.txt");
        JavaRDD<Row> rows = lines.map(line -> RowFactory.create(line.split("\\^")));

        List<StructField> structFields = new ArrayList<StructField>();
        structFields.add(DataTypes.createStructField("a", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
        StructType structType = DataTypes.createStructType(structFields);

        DataFrame test = sqlContext.createDataFrame(rows, structType);
        test.registerTempTable("test");

        sqlContext.sql("SELECT a, my_avg(b) FROM test GROUP BY a").show();
        sc.stop();
    }
}

The input text is:

a^3
a^6
b^2
b^4
b^6
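To see why the buffer holds two fields and why the division needs care, here is a plain-Java model of the lifecycle Spark drives for a UDAF (no Spark; AvgLifecycle and its methods are illustrative names, and count/sum are kept as long): `initialize` creates an empty buffer per partition, `update` runs once per row, `merge` combines the partition buffers, and `evaluate` produces the final value. For the sample data, group a averages (3 + 6) / 2 = 4.5 and group b averages (2 + 4 + 6) / 3 = 4.0; plain int division would report 4 for group a.

```java
import java.util.Arrays;
import java.util.List;

// Plain-Java model (no Spark) of the buffer lifecycle a UDAF like MyAvg goes through:
// initialize -> update (per row, inside a partition) -> merge (across partitions) -> evaluate.
class AvgLifecycle {

    // buffer[0] = count, buffer[1] = sum, the same order as MyAvg's bufferSchema
    static long[] initialize() {
        return new long[]{0L, 0L};
    }

    static void update(long[] buffer, String input) {
        buffer[0] += 1;
        buffer[1] += Integer.parseInt(input);
    }

    static void merge(long[] buffer1, long[] buffer2) {
        buffer1[0] += buffer2[0];
        buffer1[1] += buffer2[1];
    }

    // Cast before dividing; integer division would turn 9 / 2 into 4.
    static double evaluate(long[] buffer) {
        return (double) buffer[1] / buffer[0];
    }

    // Runs the whole lifecycle over values spread across (simulated) partitions.
    static double average(List<List<String>> partitions) {
        long[] total = initialize();
        for (List<String> partition : partitions) {
            long[] partial = initialize();
            for (String value : partition) {
                update(partial, value);
            }
            merge(total, partial);
        }
        return evaluate(total);
    }

    public static void main(String[] args) {
        // group "a": 3 and 6, pretending they landed in different partitions
        System.out.println(average(Arrays.asList(Arrays.asList("3"), Arrays.asList("6"))));      // 4.5
        // group "b": 2, 4 and 6
        System.out.println(average(Arrays.asList(Arrays.asList("2", "4"), Arrays.asList("6")))); // 4.0
    }
}
```

The result does not depend on how the rows are split into partitions, because count and sum are both simply added up in `merge`.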

Finally, a do-anything UDAF

Real business scenarios always come with all sorts of odd requirements, for example:

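- group by a field and take the maximum within each group
- group by a field and sum a specific field within each group
- group by a field and concatenate strings that meet a specific condition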

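Or take another scenario: group by one field, deduplicate the rows within each group by a certain column, and finally compute a value. That breaks down into three steps: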

1. group by a field
2. check a condition within each group
3. then process the field

Without a UDAF, in plain Spark you might write something like this:

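rdd.groupBy(r -> r.getKey())                // 1. group by a field
   .map(t2 -> {
       Set<String> set = new LinkedHashSet<>();
       for (Record p : t2._2()) {
           if (p.getBs() > 0) {              // 2. check the condition within the group
               set.add(p.getValue());        // 3. collect the field, deduplicated by the Set
           }
       }
       return String.join(",", set);
   });

This is pseudocode (Record, getKey, getBs and getValue are placeholders), so it isn't guaranteed to run as-is.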

This does meet the requirement, but the code is a bit ugly and not as clear as Spark SQL.

So let's try another version, this time with a UDAF in Spark SQL!

First, create the UDAF class:

package test;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.*;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Created by xinghailong on /2/23.
 */
public class ConditionJoinUDAF extends UserDefinedAggregateFunction {

    // Input: field1 = the condition value (int), field2 = the string to join
    @Override
    public StructType inputSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("field1", DataTypes.IntegerType, true));
        structFields.add(DataTypes.createStructField("field2", DataTypes.StringType, true));
        return DataTypes.createStructType(structFields);
    }

    // Buffer: the comma-separated values collected so far
    @Override
    public StructType bufferSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("field", DataTypes.StringType, true));
        return DataTypes.createStructType(structFields);
    }

    @Override
    public DataType dataType() {
        return DataTypes.StringType;
    }

    @Override
    public boolean deterministic() {
        // Whether the same input always yields the same result; the joined order
        // depends on how rows are spread across partitions, so report false
        return false;
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        // Start with an empty string
        buffer.update(0, "");
    }

    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        // Called for each row within one partition
        if (input.isNullAt(0) || input.isNullAt(1)) {
            return;
        }
        int bs = input.getInt(0);
        String field = buffer.getString(0);
        String in = input.getString(1);
        // Compare whole values, not substrings: "111" must not match "1111"
        List<String> seen = Arrays.asList(field.split(","));
        if (bs > 0 && !in.isEmpty() && !seen.contains(in)) {
            field += "," + in;
        }
        buffer.update(0, field);
    }

    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        // Combine partial results from different partitions (duplicates are removed in evaluate)
        String field1 = buffer1.getString(0);
        String field2 = buffer2.getString(0);
        if (!"".equals(field2)) {
            field1 += "," + field2;
        }
        buffer1.update(0, field1);
    }

    @Override
    public Object evaluate(Row buffer) {
        // Compute the final result: drop empty entries and values duplicated by merge
        return Arrays.stream(buffer.getString(0).split(","))
                .filter(s -> !s.isEmpty())
                .distinct()
                .collect(Collectors.joining(","));
    }
}

Let's try it on an example:

a^1111^2
a^1111^2
a^1111^2
a^1111^2
a^1111^2
a^2222^0
a^3333^1
b^4444^0
b^5555^3
c^6666^0

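Group by the first column and, for rows whose third column is greater than 0, concatenate the distinct values of the second column.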

package test;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

/**
 * Created by xinghailong on /2/23.
 */
public class test2 {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("test");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        SQLContext sqlContext = new SQLContext(sc);

        // ConditionJoinUDAF is in the same package, so no import is needed
        sqlContext.udf().register("con_join", new ConditionJoinUDAF());

        JavaRDD<String> lines = sc.textFile("C:\\test2.txt");
        JavaRDD<Row> rows = lines.map(line -> RowFactory.create(line.split("\\^")));

        List<StructField> structFields = new ArrayList<StructField>();
        structFields.add(DataTypes.createStructField("a", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("c", DataTypes.StringType, true));
        StructType structType = DataTypes.createStructType(structFields);

        DataFrame test = sqlContext.createDataFrame(rows, structType);
        test.registerTempTable("test");

        // c is read as a string, so cast it to match the UDAF's IntegerType condition input
        sqlContext.sql("SELECT a, con_join(CAST(c AS INT), b) FROM test GROUP BY a").show();
        sc.stop();
    }
}

This way the SQL stays short and clearly expresses the intent.
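To sanity-check what the con_join query should return on the sample file, the same rule can be evaluated in plain Java without Spark (a sketch; ConditionJoinCheck is an illustrative name). It groups by the first column and keeps the distinct second-column values whose third column is greater than 0, using a Set so that whole values are compared.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Plain-Java check (no Spark) of the expected con_join result:
// per first-column value, the distinct second-column values whose third column is > 0.
class ConditionJoinCheck {

    static Map<String, String> expected(List<String> lines) {
        // LinkedHashMap / LinkedHashSet keep first-seen order, so the output is stable
        Map<String, Set<String>> groups = new LinkedHashMap<>();
        for (String line : lines) {
            String[] cols = line.split("\\^");
            Set<String> values = groups.computeIfAbsent(cols[0], k -> new LinkedHashSet<>());
            if (Integer.parseInt(cols[2]) > 0 && !cols[1].isEmpty()) {
                values.add(cols[1]); // a Set compares whole values, so "111" never matches "1111"
            }
        }
        Map<String, String> result = new LinkedHashMap<>();
        for (Map.Entry<String, Set<String>> e : groups.entrySet()) {
            result.put(e.getKey(), String.join(",", e.getValue()));
        }
        return result;
    }

    public static void main(String[] args) {
        List<String> lines = Arrays.asList(
                "a^1111^2", "a^1111^2", "a^1111^2", "a^1111^2", "a^1111^2",
                "a^2222^0", "a^3333^1", "b^4444^0", "b^5555^3", "c^6666^0");
        System.out.println(expected(lines)); // {a=1111,3333, b=5555, c=}
    }
}
```

For this data it gives 1111,3333 for a, 5555 for b, and an empty string for c, whose only row has 0 in the third column. Spark itself does not guarantee the order of the groups, nor the order of the joined values after `merge`.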

This article is reposted from xingoo's blog on cnblogs (博客园). Original title: Spark SQL 用户自定义函数UDF、用户自定义聚合函数UDAF 教程(Java踩坑教学版). To republish it, please contact the original author.
