如何select每个组的第一行?

我有一个DataFrame生成如下:

df.groupBy($"Hour", $"Category") .agg(sum($"value") as "TotalValue") .sort($"Hour".asc, $"TotalValue".desc)) 

结果如下所示:

 +----+--------+----------+ |Hour|Category|TotalValue| +----+--------+----------+ | 0| cat26| 30.9| | 0| cat13| 22.1| | 0| cat95| 19.6| | 0| cat105| 1.3| | 1| cat67| 28.5| | 1| cat4| 26.8| | 1| cat13| 12.6| | 1| cat23| 5.3| | 2| cat56| 39.6| | 2| cat40| 29.7| | 2| cat187| 27.9| | 2| cat68| 9.8| | 3| cat8| 35.6| | ...| ....| ....| +----+--------+----------+ 

正如你所看到的,DataFrame按Hour顺序sorting,然后由TotalValue按降序排列。

我想select每个组的第一行,即

  • 从小时数== 0组中select(0,cat26,30.9)
  • 从小时数== 1组中select(1,cat67,28.5)
  • 从小时数== 2组中select(2,cat56,39.6)
  • 等等

所以期望的输出将是:

 +----+--------+----------+ |Hour|Category|TotalValue| +----+--------+----------+ | 0| cat26| 30.9| | 1| cat67| 28.5| | 2| cat56| 39.6| | 3| cat8| 35.6| | ...| ...| ...| +----+--------+----------+ 

能够select每个组的前N行也许是方便的。

任何帮助,高度赞赏。

窗口function

像这样的事情应该做的伎俩:

 import org.apache.spark.sql.functions.{row_number, max, broadcast} import org.apache.spark.sql.expressions.Window val df = sc.parallelize(Seq( (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3), (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3), (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8), (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue") val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc) val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn") dfTop.show // +----+--------+----------+ // |Hour|Category|TotalValue| // +----+--------+----------+ // | 0| cat26| 30.9| // | 1| cat67| 28.5| // | 2| cat56| 39.6| // | 3| cat8| 35.6| // +----+--------+----------+ 

这种方法在数据歪斜严重的情况下效率不高。

简单的SQL聚合,然后join

或者,你可以join聚合数据框:

 val dfMax = df.groupBy($"hour").agg(max($"TotalValue")) val dfTopByJoin = df.join(broadcast(dfMax), ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value")) .drop("max_hour") .drop("max_value") dfTopByJoin.show // +----+--------+----------+ // |Hour|Category|TotalValue| // +----+--------+----------+ // | 0| cat26| 30.9| // | 1| cat67| 28.5| // | 2| cat56| 39.6| // | 3| cat8| 35.6| // +----+--------+----------+ 

它会保持重复的值(如果每小时有多个类别的总值相同)。 您可以删除这些如下:

 dfTopByJoin .groupBy($"hour") .agg( first("category").alias("category"), first("TotalValue").alias("TotalValue")) 

使用顺序structs

整洁,虽然没有很好的testing,不需要连接或窗口function的技巧:

 val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs")) .groupBy($"hour") .agg(max("vs").alias("vs")) .select($"Hour", $"vs.Category", $"vs.TotalValue") dfTop.show // +----+--------+----------+ // |Hour|Category|TotalValue| // +----+--------+----------+ // | 0| cat26| 30.9| // | 1| cat67| 28.5| // | 2| cat56| 39.6| // | 3| cat8| 35.6| // +----+--------+----------+ 

使用DataSet API (Spark 1.6+,2.0+):

Spark 1.6

 case class Record(Hour: Integer, Category: String, TotalValue: Double) df.as[Record] .groupBy($"hour") .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y) .show // +---+--------------+ // | _1| _2| // +---+--------------+ // |[0]|[0,cat26,30.9]| // |[1]|[1,cat67,28.5]| // |[2]|[2,cat56,39.6]| // |[3]| [3,cat8,35.6]| // +---+--------------+ 

Spark 2.0或更高版本

 df.as[Record] .groupByKey(_.Hour) .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y) 

最后两种方法可以利用地图边合并,不需要完全洗牌,所以大部分时间应该比窗口函数和连接performance出更好的性能。

不要使用

 df.orderBy(...).groupBy(...).agg(first(...), ...) 

它可能似乎工作(特别是在local模式),但它是不可靠的( SPARK-16207 )。 向Tzach Zohar提供连接相关JIRA问题的信息 。

同样的说明适用于

 df.orderBy(...).dropDuplicates(...) 

内部使用等效的执行计划。

对于由多列分组的Spark 2.0.2:

 import org.apache.spark.sql.functions.row_number import org.apache.spark.sql.expressions.Window val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc) val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn") 

如果dataframe必须按多列分组,则可以提供帮助

 val keys = List("Hour", "Category"); val selectFirstValueOfNoneGroupedColumns = df.columns .filterNot(keys.toSet) .map(_ -> "first").toMap val grouped = df.groupBy(keys.head, keys.tail: _*) .agg(selectFirstValueOfNoneGroupedColumns) 

希望这可以帮助有类似问题的人

对于spark> 2.0,我们可以简单地做:
groupBy($"Hour").agg(df_op.columns.map((_, "first")).toMap)

详细使用OP的设置:

 val df_op = df.groupBy($"Hour", $"Category") .agg(sum($"value") as "TotalValue") .sort($"Hour".asc, $"TotalValue".desc)) df_op.groupBy($"Hour").agg(df_op.columns.map((_, "first")).toMap) 

这是使用RelationalGroupedDatasetagg方法Compute aggregates by specifying a map from column name to aggregate methods.Compute aggregates by specifying a map from column name to aggregate methods.first是一个sql聚合函数。

我们可以使用rank()窗口函数(在这里你会selectrank = 1)rank只是为一个组的每一行添加一个数字(在这种情况下,它将是小时)

这里是一个例子。 (从https://github.com/jaceklaskowski/mastering-apache-spark-book/blob/master/spark-sql-functions.adoc#rank

 val dataset = spark.range(9).withColumn("bucket", 'id % 3) import org.apache.spark.sql.expressions.Window val byBucket = Window.partitionBy('bucket).orderBy('id) scala> dataset.withColumn("rank", rank over byBucket).show +---+------+----+ | id|bucket|rank| +---+------+----+ | 0| 0| 1| | 3| 0| 2| | 6| 0| 3| | 1| 1| 1| | 4| 1| 2| | 7| 1| 3| | 2| 2| 1| | 5| 2| 2| | 8| 2| 3| +---+------+----+