1. SparkSQL 中自定义函数类型
在 Spark SQL 中,用户自定义函数(User-Defined Function,简称 UDF)是一种特殊的函数,允许用户定义自己的逻辑来处理数据。这些函数可以直接在 Spark SQL 查询中使用,就像使用内置函数一样。在Spark中,用户自定义函数可以分为以下三种类型:
- UDF(User-Defined Function):这是最常见的用户自定义函数类型。UDF接收一行输入,返回一个结果。这种类型的函数在实现上就是普通的Scala函数。例如,你可能有一个UDF,它接收一个字符串,然后返回该字符串的长度。
- UDTF(User-Defined Table-Generating Functions):这种类型的函数接收一行输入,返回多行(也就是一个表)。在SparkSQL中,这种类型的函数并没有明确的定义,因为Spark中的flatMap函数已经可以实现这个功能。例如,你可能有一个UDTF,它接收一个字符串,然后返回一个包含该字符串中每个字符的表。
- UDAF(User-Defined Aggregate Functions):这种类型的函数接收多行输入,返回一行结果。这里的”A”代表的是”aggregate”,也就是聚合的意思。如果业务逻辑复杂,可能需要自己实现聚合函数。例如,你可能有一个UDAF,它接收一个包含多个数字的表,然后返回这些数字的平均值。
实质上讲,例如说 UDF 会被 Spark SQL 中的 Catalyst 封装成为 Expression,最终会通过 eval 方法来计算输入的数据 Row(此处的Row和 DataFrame 中的 Row 没有任何关系)
2. UDF 介绍
Spark SQL 中的用户自定义函数(User-Defined Functions,简称 UDF)是一种特殊类型的函数,允许用户定义自己的逻辑来处理数据。这些函数可以直接在 Spark SQL 查询中使用,就像使用内置函数一样。
UDF 是一种接收一行输入并返回一个结果的函数。这种类型的函数在实现上就是普通的 Scala 或者 Python 函数。例如,你可能有一个 UDF,它接收一个字符串,然后返回该字符串的长度。
创建 Spark SQL 的 UDF 通常包含以下步骤:
- 定义一个函数:这个函数包含你想要应用到数据上的逻辑。这个函数可以是任何接收适当类型的输入并返回一个结果的函数。
- 注册函数:在你可以在 SQL 查询中使用你的 UDF 之前,你需要将其注册到 Spark SQL 的上下文中。你可以通过调用
spark.udf.register
方法来完成这个步骤,其中spark
是你的SparkSession
对象,register
方法接收两个参数:你的 UDF 的名称(一个字符串),以及你的函数。 - 在 SQL 查询中使用 UDF:一旦你的 UDF 被注册,你就可以在 SQL 查询中像使用其他 SQL 函数一样使用它了。
使用 UDF 可以让你在处理数据时有更大的灵活性,因为你可以定义自己的数据处理逻辑。这在处理复杂的数据转换或计算时特别有用。
代码示例
下面是一个 UDF代码示例,主要功能是读取CSV文件中的数据,然后使用一个用户自定义函数(UDF)对数据进行处理。这个UDF会检查输入的字符串是否全为数字,如果是,则返回输入本身,否则返回”0”。
import java.util.regex.{Matcher, Pattern}
import org.apache.spark.SparkConf
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.{DataFrame, SparkSession}
object SparkUDF {
def main(args: Array[String]): Unit = {
// 创建一个SparkConf对象,设置Spark的运行模式和应用名称
val sparkConf: SparkConf = new SparkConf().setMaster("local[8]").setAppName("sparkCSV")
// 创建一个SparkSession对象,使用SparkConf对象的配置
val session: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
// 设置Spark的日志级别为WARN
session.sparkContext.setLogLevel("WARN")
// 读取CSV文件,将其加载为一个DataFrame对象
val frame: DataFrame = session
.read
.format("csv")
.option("timestampFormat", "yyyy/MM/dd HH:mm:ss ZZ")
.option("header", "true")
.option("multiLine", true)
.load("file:///D:\\datas\\datafiles")
// 创建或替换一个临时视图,视图名为"house_sale"
frame.createOrReplaceTempView("house_sale")
// 创建一个UDF1对象,用于定义一个用户自定义函数,然后将其注册为udf
session.udf.register("house_udf",new UDF1[String,String] {
// 定义一个正则表达式模式,用于匹配全数字的字符串
val pattern: Pattern = Pattern.compile("^[0-9]*$")
// 定义UDF1对象的call方法,该方法接收一个字符串输入,返回一个字符串输出
override def call(input: String): String = {
// 使用定义的正则表达式模式对输入进行匹配
val matcher: Matcher = pattern.matcher(input)
// 如果输入匹配全数字的模式,则返回输入本身,否则返回 "0"
if(matcher.matches()){
input
}else{
"0"
}
}
},DataTypes.StringType)
// 使用注册的udf对"house_sale"视图进行查询,并显示查询结果的前200行
session.sql("select house_udf(house_age) from house_sale limit 200").show()
// 停止SparkSession
session.stop()
}
}
3. UDTF 介绍
3.1 UDTF 介绍
UDTF (User-Defined Table-Generating Functions,简称 )是用户自定义表生成函数,它可以接收一行输入并返回多行数据。
在 Spark SQL 中,UDTF 的概念并没有明确的定义,因为 Spark 中的 flatMap
函数已经可以实现这个功能。flatMap
函数接收一个函数作为参数,这个函数应该接收一个输入并返回一个迭代器。flatMap
函数将这个函数应用到数据集中的每个元素,然后将返回的所有迭代器连接成一个新的数据集。
例如,你可能有一个 UDTF,它接收一个字符串,然后返回一个包含该字符串中每个字符的表。你可以使用 flatMap
函数来实现这个功能。
val words = Seq("hello", "world").toDF("word")
val characters = words.flatMap(_.getString(0).toSeq)
在这个例子中,flatMap
函数接收一个函数,这个函数将一个字符串转换为一个字符序列。然后,flatMap
函数将这个函数应用到 “word” 列中的每个值,生成一个新的 DataFrame,其中包含每个单词中的每个字符。
3.2 代码示例
下面代码中首先创建了一个简单的 DataFrame,包含一个字符串列。然后,我们定义了一个 UDTF,它将一个字符串转换为一个字符数组。最后,我们使用 flatMap
函数将这个 UDTF 应用到 DataFrame 中的每一行,生成一个新的 DataFrame,其中包含每个单词中的每个字符。
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
object UDTFExample {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().appName("UDTF Example").getOrCreate()
// 创建一个简单的 DataFrame,包含一个字符串列
val wordsDataFrame = spark.sparkContext.parallelize(Seq("hello", "world")).toDF("word")
// 定义一个 UDTF,将一个字符串转换为一个字符数组
val explode = (s: String) => s.toCharArray.map(_.toString)
// 使用 flatMap 函数应用 UDTF
val charactersDataFrame = wordsDataFrame.flatMap(row => explode(row.getAs[String]("word")))
charactersDataFrame.show()
spark.stop()
}
}
4. UDAF 介绍
4.1 UserDefinedAggregateFunction
在 Spark 编写自定义的 UDAF 时需要继承 UserDefinedAggregateFunction
类,UserDefinedAggregateFunction
是 Spark SQL 中的一个抽象类,用于创建用户自定义的聚合函数(User-Defined Aggregate Functions,简称 UDAF)。UDAF 是一种特殊类型的函数,可以处理多行输入并返回一个聚合的输出结果。
当你创建一个 UDAF 时,你需要继承 UserDefinedAggregateFunction
类,并实现以下方法:
方法名 | 描述 |
---|---|
inputSchema | 返回一个 StructType 对象,表示输入数据的模式。这个模式应该匹配你的 UDAF 所期望的输入 |
bufferSchema | 返回一个 StructType 对象,表示中间缓冲区的模式。在聚合过程中,你的 UDAF 将使用这个模式来存储中间结果。 |
dataType | 返回一个 DataType 对象,表示你的 UDAF 的返回类型。 |
deterministic | 返回一个布尔值,表示你的 UDAF 是否是确定性的。 如果对于相同的输入,你的 UDAF 总是产生相同的输出,那么这个方法应该返回 true 。 |
initialize | 接收一个 MutableAggregationBuffer 对象,并将其初始化为你的 UDAF 的初始值。MutableAggregationBuffer 是一个可变的行对象,可以用来存储中间结果。 |
update | 方法接收一个 MutableAggregationBuffer 对象和一个 Row 对象。MutableAggregationBuffer 对象用来存储中间结果,Row 对象表示输入的数据。这个方法应该根据输入的 Row 更新 MutableAggregationBuffer 。 |
merge | 这个方法接收两个 MutableAggregationBuffer 对象,表示两个中间结果。这个方法应该将这两个中间结果合并。 |
evaluate | 这个方法接收一个 Row 对象,表示最终的中间结果。这个方法应该返回你的 UDAF 的最终结果。 |
4.2 代码示例
代码示例如下, 功能为计算输入数据中每个单词的出现次数。
import org.apache.spark.sql.expressions.{MutableAggregationBuffer,UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkContext, SparkConf}
object SparkSQLUDFUDAF {
def main (args: Array[String]) {
/**
* 第1步:创建Spark的配置对象SparkConf *
*/
val conf = new SparkConf() //创建SparkConf对象
conf.setAppName("SparkSQLUDFUDAF") //设置应用程序的名称,在程序运行的监控界面可以看到名称
// conf.setMaster("spark://Master:7077") //此时,程序在Spark集群
conf.setMaster("local[4]")
/**
* 第2步:创建SparkContext对象
* SparkContext是Spark程序所有功能的唯一入口,无论是采用Scala、Java、Python、R等都必须有一个SparkContext
* SparkContext核心作用:初始化Spark应用程序运行所需要的核心组件,包括DAGScheduler、TaskScheduler、SchedulerBackend
* 同时还会负责Spark程序往Master注册程序等
* SparkContext是整个Spark应用程序中最为至关重要的一个对象
*/
val sc = new SparkContext(conf) //创建SparkContext对象,通过传入SparkConf实例来定制Spark运行的具体参数和配置信息
val sqlContext = new SQLContext(sc) //构建SQL上下文
//模拟实际使用的数据
val bigData = Array("Spark", "Spark", "Hadoop", "Spark", "Hadoop", "Spark", "Spark", "Hadoop", "Spark", "Hadoop")
/**
* 基于提供的数据创建DataFrame
*/
val bigDataRDD = sc.parallelize(bigData)
val bigDataRDDRow = bigDataRDD.map(item => Row(item))
val structType = StructType(Array(StructField("word", StringType, true)))
val bigDataDF = sqlContext.createDataFrame(bigDataRDDRow,structType)
bigDataDF.registerTempTable("bigDataTable") //注册成为临时表
/**
* 通过SQLContext注册UDF,在Scala 2.10.x版本UDF函数最多可以接受22个输入参数
*/
sqlContext.udf.register("computeLength", (input: String) => input.length)
//直接在SQL语句中使用UDF,就像使用SQL自动的内部函数一样
sqlContext.sql("select word, computeLength(word) as length from bigDataTable").show
sqlContext.udf.register("wordCount", new MyUDAF)
sqlContext.sql("select word,wordCount(word) as count,computeLength(word) as length" +
" from bigDataTable group by word").show()
while(true)()
}
}
/**
* 按照模板实现UDAF
*/
class MyUDAF extends UserDefinedAggregateFunction {
/**
* 该方法指定具体输入数据的类型
* @return
*/
override def inputSchema: StructType = StructType(Array(StructField("input", StringType, true)))
/**
* 在进行聚合操作的时候所要处理的数据的结果的类型
* @return
*/
override def bufferSchema: StructType = StructType(Array(StructField("count", IntegerType, true)))
/**
* 指定UDAF函数计算后返回的结果类型
* @return
*/
override def dataType: DataType = IntegerType
override def deterministic: Boolean = true
/**
* 在Aggregate之前每组数据的初始化结果
* @param buffer
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {buffer(0) =0}
/**
* 在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
* 本地的聚合操作,相当于Hadoop MapReduce模型中的Combiner
* @param buffer
* @param input
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getAs[Int](0) + 1
}
/**
* 最后在分布式节点进行Local Reduce完成后需要进行全局级别的Merge操作
* @param buffer1
* @param buffer2
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
}
/**
* 返回UDAF最后的计算结果
* @param buffer
* @return
*/
override def evaluate(buffer: Row): Any = buffer.getAs[Int](0)
}