Spark SQL 2.自定义函数


1. SparkSQL 中自定义函数类型

用户自定义函数类别分为以下三种:

  1. UDF:输入一行,返回一个结果(一对一),实现上讲就是普通的Scala函数;

  2. UDTF:输入一行,返回多行(一对多),在SparkSQL中没有,因为Spark中使用flatMap即可实现这个功能

  3. UDAF:输入多行,返回一行,这里的A是aggregate,聚合的意思,如果业务复杂,需要自己实现聚合函数

实质上讲,例如说UDF会被Spark SQL中的Catalyst封装成为Expression,最终会通过eval方法来计算输入的数据Row(此处的Row和DataFrame中的Row没有任何关系)

2. UDF 代码示例


import java.util.regex.{Matcher, Pattern}
import org.apache.spark.SparkConf
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.{DataFrame, SparkSession}

object SparkUDF {

  def main(args: Array[String]): Unit = {
    val sparkConf: SparkConf = new SparkConf().setMaster("local[8]").setAppName("sparkCSV")

    val session: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
    session.sparkContext.setLogLevel("WARN")
    val frame: DataFrame = session
      .read
      .format("csv")
      .option("timestampFormat", "yyyy/MM/dd HH:mm:ss ZZ")
      .option("header", "true")
      .option("multiLine", true)
      .load("file:///D:\\datas\\datafiles")

    frame.createOrReplaceTempView("house_sale")


    // 创建 UDF1 对象并注册成 udf
    session.udf.register("house_udf",new UDF1[String,String] {

      val pattern: Pattern = Pattern.compile("^[0-9]*$")
      override def call(input: String): String = {
        val matcher: Matcher = pattern.matcher(input)
        if(matcher.matches()){
          input
        }else{
          "1990"
        }
      }
    },DataTypes.StringType)

    session.sql("select house_udf(house_age) from house_sale  limit 200").show()
    session.stop()
  }

}

3. UDAF 代码示例


package SparkSQLByScala
import org.apache.spark.sql.expressions.{MutableAggregationBuffer,UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkContext, SparkConf}

object SparkSQLUDFUDAF {

  def main (args: Array[String]) {
    /**
      * 第1步:创建Spark的配置对象SparkConf     *
      */
    val conf = new SparkConf() //创建SparkConf对象
    conf.setAppName("SparkSQLUDFUDAF") //设置应用程序的名称,在程序运行的监控界面可以看到名称
    //    conf.setMaster("spark://Master:7077") //此时,程序在Spark集群
    conf.setMaster("local[4]")
    /**
      * 第2步:创建SparkContext对象
      * SparkContext是Spark程序所有功能的唯一入口,无论是采用Scala、Java、Python、R等都必须有一个SparkContext
      * SparkContext核心作用:初始化Spark应用程序运行所需要的核心组件,包括DAGScheduler、TaskScheduler、SchedulerBackend
      * 同时还会负责Spark程序往Master注册程序等
      * SparkContext是整个Spark应用程序中最为至关重要的一个对象
      */
    val sc = new SparkContext(conf) //创建SparkContext对象,通过传入SparkConf实例来定制Spark运行的具体参数和配置信息

    val sqlContext = new SQLContext(sc) //构建SQL上下文

    //模拟实际使用的数据
    val bigData = Array("Spark", "Spark", "Hadoop", "Spark", "Hadoop", "Spark", "Spark", "Hadoop", "Spark", "Hadoop")

    /**
      * 基于提供的数据创建DataFrame
      */
    val bigDataRDD =  sc.parallelize(bigData)
    val bigDataRDDRow = bigDataRDD.map(item => Row(item))
    val structType = StructType(Array(StructField("word", StringType, true)))
    val bigDataDF = sqlContext.createDataFrame(bigDataRDDRow,structType)

    bigDataDF.registerTempTable("bigDataTable") //注册成为临时表

    /**
      * 通过SQLContext注册UDF,在Scala 2.10.x版本UDF函数最多可以接受22个输入参数
      */
    sqlContext.udf.register("computeLength", (input: String) => input.length)

    //直接在SQL语句中使用UDF,就像使用SQL自动的内部函数一样
    sqlContext.sql("select word, computeLength(word) as length from bigDataTable").show

    sqlContext.udf.register("wordCount", new MyUDAF)

    sqlContext.sql("select word,wordCount(word) as count,computeLength(word) as length" +
      " from bigDataTable group by word").show()

    while(true)()
  }
}

/**
  * 按照模板实现UDAF
  */
class  MyUDAF extends UserDefinedAggregateFunction {
  /**
    * 该方法指定具体输入数据的类型
    * @return
    */
  override def inputSchema: StructType = StructType(Array(StructField("input", StringType, true)))

  /**
    * 在进行聚合操作的时候所要处理的数据的结果的类型
    * @return
    */
  override def bufferSchema: StructType = StructType(Array(StructField("count", IntegerType, true)))

  /**
    * 指定UDAF函数计算后返回的结果类型
    * @return
    */
  override def dataType: DataType = IntegerType

  override def deterministic: Boolean = true

  /**
    * 在Aggregate之前每组数据的初始化结果
    * @param buffer
    */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {buffer(0) =0}

  /**
    * 在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
    * 本地的聚合操作,相当于Hadoop MapReduce模型中的Combiner
    * @param buffer
    * @param input
    */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + 1
  }

  /**
    * 最后在分布式节点进行Local Reduce完成后需要进行全局级别的Merge操作
    * @param buffer1
    * @param buffer2
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  }


  /**
    * 返回UDAF最后的计算结果
    * @param buffer
    * @return
    */
  override def evaluate(buffer: Row): Any = buffer.getAs[Int](0)
}

文章作者: hnbian
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 hnbian !
评论
 上一篇
倒排索引 倒排索引
1. 介绍倒排索引源于实际应用中需要根据属性的值来查找记录。 这种索引表中的每一项都包括一个属性值和具有该属性值的各记录的地址。 由于不是由记录来确定属性值, 而是由属性值来确定记录的位置, 因而称为倒排索引(inverted index)
2018-05-29
下一篇 
Spark SQL 1. 常见概念与基本操作 Spark SQL 1. 常见概念与基本操作
1. SparkSQL 概述1.1 Shark Shark 是 Databricks 开发出专门针对于spark的构建大规模数据仓库系统的一个框架 Shark 与 Hive 兼容,同时也依赖于Spark版本 Shark是把sql语句解析
2018-05-21
  目录