How to define and use a user-defined aggregate function in Spark SQL?

I know how to write a UDF in Spark SQL:

    def belowThreshold(power: Int): Boolean = power < -40

    sqlContext.udf.register("belowThreshold", belowThreshold _)

Can I do something similar to define an aggregate function? How is this done?

For context, I want to run the following SQL query:

    val aggDF = sqlContext.sql("""
      SELECT span, belowThreshold(opticalReceivePower), timestamp
      FROM ifDF
      WHERE opticalReceivePower IS NOT NULL
      GROUP BY span, timestamp
      ORDER BY span
    """)

It should return something like:

    Row(span1, false, T0)

I want the aggregate function to tell me whether any values of opticalReceivePower in the group defined by span and timestamp are below the threshold. Do I need to write my UDAF differently from the UDF I pasted above?

Supported methods

Spark 2.0+ (optionally 1.6+, but with a slightly different API):

You can use Aggregators on typed Datasets:

    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.{Encoder, Encoders}

    class BelowThreshold[I](f: I => Boolean)
        extends Aggregator[I, Boolean, Boolean] with Serializable {
      def zero = false
      def reduce(acc: Boolean, x: I) = acc | f(x)
      def merge(acc1: Boolean, acc2: Boolean) = acc1 | acc2
      def finish(acc: Boolean) = acc

      def bufferEncoder: Encoder[Boolean] = Encoders.scalaBoolean
      def outputEncoder: Encoder[Boolean] = Encoders.scalaBoolean
    }

    val belowThreshold = new BelowThreshold[(String, Int)](_._2 < -40).toColumn

    df.as[(String, Int)].groupByKey(_._1).agg(belowThreshold)
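Since Spark 3.0, an Aggregator can also be registered for untyped DataFrame / SQL use via functions.udaf. A minimal sketch under that assumption (spark is an assumed SparkSession, and the Int-input variant below targets the query from the question rather than the typed example above):

    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.functions.udaf
    import org.apache.spark.sql.{Encoder, Encoders}

    // Variant of the aggregator above that takes the raw Int column
    // as input, so it can be called like a built-in aggregate function.
    object BelowThresholdAgg extends Aggregator[Int, Boolean, Boolean] {
      def zero = false
      def reduce(acc: Boolean, x: Int) = acc || x < -40
      def merge(acc1: Boolean, acc2: Boolean) = acc1 || acc2
      def finish(acc: Boolean) = acc
      def bufferEncoder: Encoder[Boolean] = Encoders.scalaBoolean
      def outputEncoder: Encoder[Boolean] = Encoders.scalaBoolean
    }

    // Wrap the Aggregator and register it under a SQL-visible name.
    spark.udf.register("belowThreshold", udaf(BelowThresholdAgg, Encoders.scalaInt))
    spark.sql("""
      SELECT span, belowThreshold(opticalReceivePower), timestamp
      FROM ifDF
      GROUP BY span, timestamp
    """)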

Spark >= 1.5

In Spark 1.5 you can create a UDAF like this, although it is most likely overkill:

    import org.apache.spark.sql.expressions._
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.Row

    object belowThreshold extends UserDefinedAggregateFunction {
      // Schema you get as an input
      def inputSchema = new StructType().add("power", IntegerType)
      // Schema of the row which is used for aggregation
      def bufferSchema = new StructType().add("ind", BooleanType)
      // Returned type
      def dataType = BooleanType
      // Self-explaining
      def deterministic = true
      // zero value
      def initialize(buffer: MutableAggregationBuffer) = buffer.update(0, false)
      // Similar to seqOp in aggregate
      def update(buffer: MutableAggregationBuffer, input: Row) = {
        if (!input.isNullAt(0))
          buffer.update(0, buffer.getBoolean(0) | input.getInt(0) < -40)
      }
      // Similar to combOp in aggregate
      def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
        buffer1.update(0, buffer1.getBoolean(0) | buffer2.getBoolean(0))
      }
      // Called on exit to get return value
      def evaluate(buffer: Row) = buffer.getBoolean(0)
    }

Example usage:

    df
      .groupBy($"group")
      .agg(belowThreshold($"power").alias("belowThreshold"))
      .show

    // +-----+--------------+
    // |group|belowThreshold|
    // +-----+--------------+
    // |    a|         false|
    // |    b|          true|
    // +-----+--------------+
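The same UDAF can also be registered by name and called from SQL, which matches the query from the question. A sketch, assuming ifDF has been registered as a temporary table:

    sqlContext.udf.register("belowThreshold", belowThreshold)

    val aggDF = sqlContext.sql("""
      SELECT span, belowThreshold(opticalReceivePower), timestamp
      FROM ifDF
      WHERE opticalReceivePower IS NOT NULL
      GROUP BY span, timestamp
      ORDER BY span
    """)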

Spark 1.4 workaround

I am not sure if I correctly understand your requirements, but as far as I can tell, plain old aggregation should be enough here:

    import org.apache.spark.sql.functions.sum
    import org.apache.spark.sql.types.IntegerType

    val df = sc.parallelize(Seq(
      ("a", 0), ("a", 1), ("b", 30), ("b", -50))).toDF("group", "power")

    df
      .withColumn("belowThreshold", ($"power".lt(-40)).cast(IntegerType))
      .groupBy($"group")
      .agg(sum($"belowThreshold").notEqual(0).alias("belowThreshold"))
      .show

    // +-----+--------------+
    // |group|belowThreshold|
    // +-----+--------------+
    // |    a|         false|
    // |    b|          true|
    // +-----+--------------+

Spark <= 1.4

As far as I know, at this moment (Spark 1.4.1), there is no support for UDAFs other than Hive ones. It should be possible with Spark 1.5 (see SPARK-3947).
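For completeness, a sketch of the Hive route, assuming a HiveContext and a Java Hive UDAF class com.example.BelowThresholdUDAF available on the classpath (both the class name and the ifDF table are hypothetical):

    import org.apache.spark.sql.hive.HiveContext

    val hiveContext = new HiveContext(sc)

    // Register a Hive UDAF (a class implementing
    // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFResolver2).
    hiveContext.sql(
      "CREATE TEMPORARY FUNCTION belowThreshold AS 'com.example.BelowThresholdUDAF'")

    hiveContext.sql(
      "SELECT span, belowThreshold(opticalReceivePower) FROM ifDF GROUP BY span")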

Unsupported / internal methods

Internally, Spark uses a number of classes, including ImperativeAggregates and DeclarativeAggregates.

These are intended for internal use and may change without further notice, so they are probably not something you want to use in production code, but just for completeness, BelowThreshold with DeclarativeAggregate could be implemented like this (tested with Spark 2.2-SNAPSHOT):

    import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types._

    case class BelowThreshold(child: Expression, threshold: Expression)
        extends DeclarativeAggregate {
      override def children: Seq[Expression] = Seq(child, threshold)

      override def nullable: Boolean = false
      override def dataType: DataType = BooleanType

      private lazy val belowThreshold = AttributeReference(
        "belowThreshold", BooleanType, nullable = false
      )()

      // Used to derive schema
      override lazy val aggBufferAttributes = belowThreshold :: Nil

      override lazy val initialValues = Seq(
        Literal(false)
      )

      override lazy val updateExpressions = Seq(Or(
        belowThreshold,
        If(IsNull(child), Literal(false), LessThan(child, threshold))
      ))

      override lazy val mergeExpressions = Seq(
        Or(belowThreshold.left, belowThreshold.right)
      )

      override lazy val evaluateExpression = belowThreshold

      override def defaultResult: Option[Literal] = Option(Literal(false))
    }

It should be further wrapped with an equivalent of withAggregateFunction.
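A minimal sketch of such a wrapper, mirroring the private withAggregateFunction helper in org.apache.spark.sql.functions (copying it here is an assumption, since the original is private[sql]):

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction

    // Turns a Catalyst AggregateFunction into a Column usable in agg(...).
    def withAggregateFunction(
        func: AggregateFunction,
        isDistinct: Boolean = false): Column = {
      new Column(func.toAggregateExpression(isDistinct))
    }

    def belowThreshold(col: Column, threshold: Column): Column =
      withAggregateFunction(BelowThreshold(col.expr, threshold.expr))

    // Usage: df.groupBy($"group").agg(belowThreshold($"power", lit(-40)))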