如何select每个组的第一行?
我有一个DataFrame生成如下:
df.groupBy($"Hour", $"Category") .agg(sum($"value") as "TotalValue") .sort($"Hour".asc, $"TotalValue".desc))
结果如下所示:
+----+--------+----------+ |Hour|Category|TotalValue| +----+--------+----------+ | 0| cat26| 30.9| | 0| cat13| 22.1| | 0| cat95| 19.6| | 0| cat105| 1.3| | 1| cat67| 28.5| | 1| cat4| 26.8| | 1| cat13| 12.6| | 1| cat23| 5.3| | 2| cat56| 39.6| | 2| cat40| 29.7| | 2| cat187| 27.9| | 2| cat68| 9.8| | 3| cat8| 35.6| | ...| ....| ....| +----+--------+----------+
正如你所看到的,DataFrame按Hour
顺序sorting,然后由TotalValue
按降序排列。
我想select每个组的第一行,即
- 从小时数== 0组中select(0,cat26,30.9)
- 从小时数== 1组中select(1,cat67,28.5)
- 从小时数== 2组中select(2,cat56,39.6)
- 等等
所以期望的输出将是:
+----+--------+----------+ |Hour|Category|TotalValue| +----+--------+----------+ | 0| cat26| 30.9| | 1| cat67| 28.5| | 2| cat56| 39.6| | 3| cat8| 35.6| | ...| ...| ...| +----+--------+----------+
能够select每个组的前N行也许是方便的。
任何帮助,高度赞赏。
窗口function :
像这样的事情应该做的伎俩:
import org.apache.spark.sql.functions.{row_number, max, broadcast} import org.apache.spark.sql.expressions.Window val df = sc.parallelize(Seq( (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3), (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3), (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8), (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue") val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc) val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn") dfTop.show // +----+--------+----------+ // |Hour|Category|TotalValue| // +----+--------+----------+ // | 0| cat26| 30.9| // | 1| cat67| 28.5| // | 2| cat56| 39.6| // | 3| cat8| 35.6| // +----+--------+----------+
这种方法在数据歪斜严重的情况下效率不高。
简单的SQL聚合,然后join
:
或者,你可以join聚合数据框:
val dfMax = df.groupBy($"hour").agg(max($"TotalValue")) val dfTopByJoin = df.join(broadcast(dfMax), ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value")) .drop("max_hour") .drop("max_value") dfTopByJoin.show // +----+--------+----------+ // |Hour|Category|TotalValue| // +----+--------+----------+ // | 0| cat26| 30.9| // | 1| cat67| 28.5| // | 2| cat56| 39.6| // | 3| cat8| 35.6| // +----+--------+----------+
它会保持重复的值(如果每小时有多个类别的总值相同)。 您可以删除这些如下:
dfTopByJoin .groupBy($"hour") .agg( first("category").alias("category"), first("TotalValue").alias("TotalValue"))
使用顺序structs
:
整洁,虽然没有很好的testing,不需要连接或窗口function的技巧:
val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs")) .groupBy($"hour") .agg(max("vs").alias("vs")) .select($"Hour", $"vs.Category", $"vs.TotalValue") dfTop.show // +----+--------+----------+ // |Hour|Category|TotalValue| // +----+--------+----------+ // | 0| cat26| 30.9| // | 1| cat67| 28.5| // | 2| cat56| 39.6| // | 3| cat8| 35.6| // +----+--------+----------+
使用DataSet API (Spark 1.6+,2.0+):
Spark 1.6 :
case class Record(Hour: Integer, Category: String, TotalValue: Double) df.as[Record] .groupBy($"hour") .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y) .show // +---+--------------+ // | _1| _2| // +---+--------------+ // |[0]|[0,cat26,30.9]| // |[1]|[1,cat67,28.5]| // |[2]|[2,cat56,39.6]| // |[3]| [3,cat8,35.6]| // +---+--------------+
Spark 2.0或更高版本 :
df.as[Record] .groupByKey(_.Hour) .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
最后两种方法可以利用地图边合并,不需要完全洗牌,所以大部分时间应该比窗口函数和连接performance出更好的性能。
不要使用 :
df.orderBy(...).groupBy(...).agg(first(...), ...)
它可能似乎工作(特别是在local
模式),但它是不可靠的( SPARK-16207 )。 向Tzach Zohar提供连接相关JIRA问题的信息 。
同样的说明适用于
df.orderBy(...).dropDuplicates(...)
内部使用等效的执行计划。
对于由多列分组的Spark 2.0.2:
import org.apache.spark.sql.functions.row_number import org.apache.spark.sql.expressions.Window val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc) val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
如果dataframe必须按多列分组,则可以提供帮助
val keys = List("Hour", "Category"); val selectFirstValueOfNoneGroupedColumns = df.columns .filterNot(keys.toSet) .map(_ -> "first").toMap val grouped = df.groupBy(keys.head, keys.tail: _*) .agg(selectFirstValueOfNoneGroupedColumns)
希望这可以帮助有类似问题的人
对于spark> 2.0,我们可以简单地做:
groupBy($"Hour").agg(df_op.columns.map((_, "first")).toMap)
详细使用OP的设置:
val df_op = df.groupBy($"Hour", $"Category") .agg(sum($"value") as "TotalValue") .sort($"Hour".asc, $"TotalValue".desc)) df_op.groupBy($"Hour").agg(df_op.columns.map((_, "first")).toMap)
这是使用RelationalGroupedDataset
的agg
方法Compute aggregates by specifying a map from column name to aggregate methods.
来Compute aggregates by specifying a map from column name to aggregate methods.
。 first
是一个sql聚合函数。
我们可以使用rank()窗口函数(在这里你会selectrank = 1)rank只是为一个组的每一行添加一个数字(在这种情况下,它将是小时)
这里是一个例子。 (从https://github.com/jaceklaskowski/mastering-apache-spark-book/blob/master/spark-sql-functions.adoc#rank )
val dataset = spark.range(9).withColumn("bucket", 'id % 3) import org.apache.spark.sql.expressions.Window val byBucket = Window.partitionBy('bucket).orderBy('id) scala> dataset.withColumn("rank", rank over byBucket).show +---+------+----+ | id|bucket|rank| +---+------+----+ | 0| 0| 1| | 3| 0| 2| | 6| 0| 3| | 1| 1| 1| | 4| 1| 2| | 7| 1| 3| | 2| 2| 1| | 5| 2| 2| | 8| 2| 3| +---+------+----+