Spark SQL 2.自定义函数


1. SparkSQL 中自定义函数类型

在 Spark SQL 中,用户自定义函数(User-Defined Function,简称 UDF)是一种特殊的函数,允许用户定义自己的逻辑来处理数据。这些函数可以直接在 Spark SQL 查询中使用,就像使用内置函数一样。在Spark中,用户自定义函数可以分为以下三种类型:

  1. UDF(User-Defined Function):这是最常见的用户自定义函数类型。UDF接收一行输入,返回一个结果。这种类型的函数在实现上就是普通的Scala函数。例如,你可能有一个UDF,它接收一个字符串,然后返回该字符串的长度。
  2. UDTF(User-Defined Table-Generating Functions):这种类型的函数接收一行输入,返回多行(也就是一个表)。在SparkSQL中,这种类型的函数并没有明确的定义,因为Spark中的flatMap函数已经可以实现这个功能。例如,你可能有一个UDTF,它接收一个字符串,然后返回一个包含该字符串中每个字符的表。
  3. UDAF(User-Defined Aggregate Functions):这种类型的函数接收多行输入,返回一行结果。这里的”A”代表的是”aggregate”,也就是聚合的意思。如果业务逻辑复杂,可能需要自己实现聚合函数。例如,你可能有一个UDAF,它接收一个包含多个数字的表,然后返回这些数字的平均值。

实质上讲,例如说 UDF 会被 Spark SQL 中的 Catalyst 封装成为 Expression,最终会通过 eval 方法来计算输入的数据 Row(此处的Row和 DataFrame 中的 Row 没有任何关系)

2. UDF 介绍

Spark SQL 中的用户自定义函数(User-Defined Functions,简称 UDF)是一种特殊类型的函数,允许用户定义自己的逻辑来处理数据。这些函数可以直接在 Spark SQL 查询中使用,就像使用内置函数一样。

UDF 是一种接收一行输入并返回一个结果的函数。这种类型的函数在实现上就是普通的 Scala 或者 Python 函数。例如,你可能有一个 UDF,它接收一个字符串,然后返回该字符串的长度。

创建 Spark SQL 的 UDF 通常包含以下步骤:

  1. 定义一个函数:这个函数包含你想要应用到数据上的逻辑。这个函数可以是任何接收适当类型的输入并返回一个结果的函数。
  2. 注册函数:在你可以在 SQL 查询中使用你的 UDF 之前,你需要将其注册到 Spark SQL 的上下文中。你可以通过调用 spark.udf.register 方法来完成这个步骤,其中 spark 是你的 SparkSession 对象,register 方法接收两个参数:你的 UDF 的名称(一个字符串),以及你的函数。
  3. 在 SQL 查询中使用 UDF:一旦你的 UDF 被注册,你就可以在 SQL 查询中像使用其他 SQL 函数一样使用它了。

使用 UDF 可以让你在处理数据时有更大的灵活性,因为你可以定义自己的数据处理逻辑。这在处理复杂的数据转换或计算时特别有用。

代码示例

下面是一个 UDF代码示例,主要功能是读取CSV文件中的数据,然后使用一个用户自定义函数(UDF)对数据进行处理。这个UDF会检查输入的字符串是否全为数字,如果是,则返回输入本身,否则返回”0”。

import java.util.regex.{Matcher, Pattern}
import org.apache.spark.SparkConf
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.{DataFrame, SparkSession}

object SparkUDF {

  def main(args: Array[String]): Unit = {
    // 创建一个SparkConf对象,设置Spark的运行模式和应用名称
    val sparkConf: SparkConf = new SparkConf().setMaster("local[8]").setAppName("sparkCSV")

    // 创建一个SparkSession对象,使用SparkConf对象的配置
    val session: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
    // 设置Spark的日志级别为WARN
    session.sparkContext.setLogLevel("WARN")
    // 读取CSV文件,将其加载为一个DataFrame对象
    val frame: DataFrame = session
      .read
      .format("csv")
      .option("timestampFormat", "yyyy/MM/dd HH:mm:ss ZZ")
      .option("header", "true")
      .option("multiLine", true)
      .load("file:///D:\\datas\\datafiles")

    // 创建或替换一个临时视图,视图名为"house_sale"
    frame.createOrReplaceTempView("house_sale")

    // 创建一个UDF1对象,用于定义一个用户自定义函数,然后将其注册为udf
    session.udf.register("house_udf",new UDF1[String,String] {

      // 定义一个正则表达式模式,用于匹配全数字的字符串
      val pattern: Pattern = Pattern.compile("^[0-9]*$")
      // 定义UDF1对象的call方法,该方法接收一个字符串输入,返回一个字符串输出
      override def call(input: String): String = {
        // 使用定义的正则表达式模式对输入进行匹配 
        val matcher: Matcher = pattern.matcher(input)
        // 如果输入匹配全数字的模式,则返回输入本身,否则返回 "0"
        if(matcher.matches()){
          input
        }else{
          "0"
        }
      }
    },DataTypes.StringType)

    // 使用注册的udf对"house_sale"视图进行查询,并显示查询结果的前200行
    session.sql("select house_udf(house_age) from house_sale  limit 200").show()
    // 停止SparkSession
    session.stop()
  }
}

3. UDTF 介绍

3.1 UDTF 介绍

UDTF (User-Defined Table-Generating Functions,简称 )是用户自定义表生成函数,它可以接收一行输入并返回多行数据。

在 Spark SQL 中,UDTF 的概念并没有明确的定义,因为 Spark 中的 flatMap 函数已经可以实现这个功能。flatMap 函数接收一个函数作为参数,这个函数应该接收一个输入并返回一个迭代器。flatMap 函数将这个函数应用到数据集中的每个元素,然后将返回的所有迭代器连接成一个新的数据集。

例如,你可能有一个 UDTF,它接收一个字符串,然后返回一个包含该字符串中每个字符的表。你可以使用 flatMap 函数来实现这个功能。

val words = Seq("hello", "world").toDF("word")
val characters = words.flatMap(_.getString(0).toSeq)

在这个例子中,flatMap 函数接收一个函数,这个函数将一个字符串转换为一个字符序列。然后,flatMap 函数将这个函数应用到 “word” 列中的每个值,生成一个新的 DataFrame,其中包含每个单词中的每个字符。

3.2 代码示例

下面代码中首先创建了一个简单的 DataFrame,包含一个字符串列。然后,我们定义了一个 UDTF,它将一个字符串转换为一个字符数组。最后,我们使用 flatMap 函数将这个 UDTF 应用到 DataFrame 中的每一行,生成一个新的 DataFrame,其中包含每个单词中的每个字符。

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}

object UDTFExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("UDTF Example").getOrCreate()

    // 创建一个简单的 DataFrame,包含一个字符串列
    val wordsDataFrame = spark.sparkContext.parallelize(Seq("hello", "world")).toDF("word")

    // 定义一个 UDTF,将一个字符串转换为一个字符数组
    val explode = (s: String) => s.toCharArray.map(_.toString)

    // 使用 flatMap 函数应用 UDTF
    val charactersDataFrame = wordsDataFrame.flatMap(row => explode(row.getAs[String]("word")))

    charactersDataFrame.show()

    spark.stop()
  }
}

4. UDAF 介绍

4.1 UserDefinedAggregateFunction

在 Spark 编写自定义的 UDAF 时需要继承 UserDefinedAggregateFunction 类,UserDefinedAggregateFunction 是 Spark SQL 中的一个抽象类,用于创建用户自定义的聚合函数(User-Defined Aggregate Functions,简称 UDAF)。UDAF 是一种特殊类型的函数,可以处理多行输入并返回一个聚合的输出结果。

当你创建一个 UDAF 时,你需要继承 UserDefinedAggregateFunction 类,并实现以下方法:

方法名 描述
inputSchema 返回一个 StructType 对象,表示输入数据的模式。这个模式应该匹配你的 UDAF 所期望的输入
bufferSchema 返回一个 StructType 对象,表示中间缓冲区的模式。
在聚合过程中,你的 UDAF 将使用这个模式来存储中间结果。
dataType 返回一个 DataType 对象,表示你的 UDAF 的返回类型。
deterministic 返回一个布尔值,表示你的 UDAF 是否是确定性的。
如果对于相同的输入,你的 UDAF 总是产生相同的输出,那么这个方法应该返回 true
initialize 接收一个 MutableAggregationBuffer 对象,并将其初始化为你的 UDAF 的初始值。
MutableAggregationBuffer 是一个可变的行对象,可以用来存储中间结果。
update 方法接收一个 MutableAggregationBuffer 对象和一个 Row 对象。MutableAggregationBuffer 对象用来存储中间结果,Row 对象表示输入的数据。这个方法应该根据输入的 Row 更新 MutableAggregationBuffer
merge 这个方法接收两个 MutableAggregationBuffer 对象,表示两个中间结果。这个方法应该将这两个中间结果合并。
evaluate 这个方法接收一个 Row 对象,表示最终的中间结果。这个方法应该返回你的 UDAF 的最终结果。

4.2 代码示例

代码示例如下, 功能为计算输入数据中每个单词的出现次数。

import org.apache.spark.sql.expressions.{MutableAggregationBuffer,UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkContext, SparkConf}

object SparkSQLUDFUDAF {

  def main (args: Array[String]) {
    /**
      * 第1步:创建Spark的配置对象SparkConf     *
      */
    val conf = new SparkConf() //创建SparkConf对象
    conf.setAppName("SparkSQLUDFUDAF") //设置应用程序的名称,在程序运行的监控界面可以看到名称
    //    conf.setMaster("spark://Master:7077") //此时,程序在Spark集群
    conf.setMaster("local[4]")
    /**
      * 第2步:创建SparkContext对象
      * SparkContext是Spark程序所有功能的唯一入口,无论是采用Scala、Java、Python、R等都必须有一个SparkContext
      * SparkContext核心作用:初始化Spark应用程序运行所需要的核心组件,包括DAGScheduler、TaskScheduler、SchedulerBackend
      * 同时还会负责Spark程序往Master注册程序等
      * SparkContext是整个Spark应用程序中最为至关重要的一个对象
      */
    val sc = new SparkContext(conf) //创建SparkContext对象,通过传入SparkConf实例来定制Spark运行的具体参数和配置信息

    val sqlContext = new SQLContext(sc) //构建SQL上下文

    //模拟实际使用的数据
    val bigData = Array("Spark", "Spark", "Hadoop", "Spark", "Hadoop", "Spark", "Spark", "Hadoop", "Spark", "Hadoop")

    /**
      * 基于提供的数据创建DataFrame
      */
    val bigDataRDD =  sc.parallelize(bigData)
    val bigDataRDDRow = bigDataRDD.map(item => Row(item))
    val structType = StructType(Array(StructField("word", StringType, true)))
    val bigDataDF = sqlContext.createDataFrame(bigDataRDDRow,structType)

    bigDataDF.registerTempTable("bigDataTable") //注册成为临时表

    /**
      * 通过SQLContext注册UDF,在Scala 2.10.x版本UDF函数最多可以接受22个输入参数
      */
    sqlContext.udf.register("computeLength", (input: String) => input.length)

    //直接在SQL语句中使用UDF,就像使用SQL自动的内部函数一样
    sqlContext.sql("select word, computeLength(word) as length from bigDataTable").show

    sqlContext.udf.register("wordCount", new MyUDAF)

    sqlContext.sql("select word,wordCount(word) as count,computeLength(word) as length" +
      " from bigDataTable group by word").show()

    while(true)()
  }
}

/**
  * 按照模板实现UDAF
  */
class  MyUDAF extends UserDefinedAggregateFunction {
  /**
    * 该方法指定具体输入数据的类型
    * @return
    */
  override def inputSchema: StructType = StructType(Array(StructField("input", StringType, true)))

  /**
    * 在进行聚合操作的时候所要处理的数据的结果的类型
    * @return
    */
  override def bufferSchema: StructType = StructType(Array(StructField("count", IntegerType, true)))

  /**
    * 指定UDAF函数计算后返回的结果类型
    * @return
    */
  override def dataType: DataType = IntegerType

  override def deterministic: Boolean = true

  /**
    * 在Aggregate之前每组数据的初始化结果
    * @param buffer
    */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {buffer(0) =0}

  /**
    * 在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
    * 本地的聚合操作,相当于Hadoop MapReduce模型中的Combiner
    * @param buffer
    * @param input
    */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + 1
  }

  /**
    * 最后在分布式节点进行Local Reduce完成后需要进行全局级别的Merge操作
    * @param buffer1
    * @param buffer2
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  }

  /**
    * 返回UDAF最后的计算结果
    * @param buffer
    * @return
    */
  override def evaluate(buffer: Row): Any = buffer.getAs[Int](0)
}

文章作者: hnbian
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 hnbian !
评论
 上一篇
倒排索引 倒排索引
1. 介绍倒排索引源于实际应用中需要根据属性的值来查找记录。 这种索引表中的每一项都包括一个属性值和具有该属性值的各记录的地址。 由于不是由记录来确定属性值, 而是由属性值来确定记录的位置, 因而称为倒排索引(inverted index)
2018-05-29
下一篇 
Spark SQL 1. 常见概念与基本操作 Spark SQL 1. 常见概念与基本操作
1. SparkSQL 概述1.1 Shark Shark 是 Databricks 开发出专门针对于spark的构建大规模数据仓库系统的一个框架 Shark 与 Hive 兼容,同时也依赖于Spark版本 Shark是把sql语句解析
2018-05-21
  目录