Pivoting a table with Spark

A pivot table is an old friend of Business Intelligence: it allows you to summarize the data of another table, usually by performing aggregations (sum, mean, max, ...) on grouped data. Pivot tables are found in multiple reporting tools (e.g. Qlik, Tableau and Excel), and analytical RDBMSs (e.g. Oracle) also implement them. Finally, open source languages and libraries offer the option to pivot a table as well: pandas (Python) and reshape (R).

In the Hadoop environment, query engines like Hive and Impala don't have this option built in: you have to rely on user-defined functions or on a GROUP BY and CASE combination. However, Spark 1.6 changed this situation by adding the method pivot() to its DataFrame API.
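
As a reference, here is a minimal sketch of the GROUP BY and CASE workaround, written as a Spark SQL query; the table sales and its columns shop_id, product_sold and quantity are illustrative (they match the example further down):

spark.sql("""
  SELECT shop_id,
         SUM(CASE WHEN product_sold = 'tv'       THEN quantity END) AS tv,
         SUM(CASE WHEN product_sold = 'pc'       THEN quantity END) AS pc,
         SUM(CASE WHEN product_sold = 'keyboard' THEN quantity END) AS keyboard
  FROM sales
  GROUP BY shop_id
""")

Note that every product has to be listed by hand, which is exactly the work that pivot() spares you.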

Transposition and pivoting

Transposition: a permutation of a set of elements that interchanges two elements and leaves the remaining elements in their original positions.
Dictionary.com

In mathematics, transposition is widely used in linear algebra: the operation exchanges rows for columns and columns for rows.

$$[A^T]_{ij} = A_{ji}$$

In the following example we can see how \( \beta \) and \( \gamma \) change positions.

$$\begin{pmatrix}\alpha & \beta \\ \gamma & \delta \end{pmatrix}^T =
\begin{pmatrix}\alpha & \gamma \\ \beta & \delta \end{pmatrix} $$

This exact operation is not that useful when working with tables of data; however, a similar transformation is widely used: pivoting. Pivot tables are a common way to summarize the contents of a table, spreading the values of a single column into new columns. A simple example could be a log-like table of sales in a small company with 4 stores.

shop_id  product_sold  quantity
3        tv            2
2        pc            1
3        tv            1
3        keyboard      1
1        tv            2
2        keyboard      2
3        pc            1
2        tv            3
1        pc            2

A way to summarize the data could be a groupBy(shop_id, product_sold).agg(sum), but the result is not one where a simple glance is enough to get a rough idea of the statistics. It is better to use another kind of table: keep shop_id in place, convert the contents of product_sold into columns, and then perform the aggregation on the column quantity. The resulting table would be the following:

shop_id  keyboard  pc  tv
1        null      2   2
2        2         1   3
3        1         1   3

Here you can easily see that shop 1 didn't sell any keyboards, and it is a good format to create plots and graphs. In Spark, you need to chain the calls in this order:

ds.groupBy("shop_id").
    pivot("product_sold").
    agg(sum("quantity"))
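
Putting it all together, here is a self-contained sketch you can paste into spark-shell; the dataset simply mirrors the sales log above (spark.implicits._ is already in scope in the shell, so toDF works out of the box):

import org.apache.spark.sql.functions.sum

val ds = Seq(
  (3, "tv", 2), (2, "pc", 1), (3, "tv", 1), (3, "keyboard", 1), (1, "tv", 2),
  (2, "keyboard", 2), (3, "pc", 1), (2, "tv", 3), (1, "pc", 2)
).toDF("shop_id", "product_sold", "quantity")

ds.groupBy("shop_id").
    pivot("product_sold").
    agg(sum("quantity")).
    show()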

Custom aggregations

While the previous code is useful for working with numeric values, sometimes it is not enough due to the nature of the data we are working with. If the type of the column is a string or a case class, then it is mandatory to create a user-defined aggregate function (UDAF).

If you are used to working with Spark RDDs, the concept is something you have done every working day: aggregate a key-value RDD and then process the resulting Seq(...) with map and a custom function (it could be a sum, a concatenation...).
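
In RDD terms, the idea looks roughly like this sketch (sc is the SparkContext provided by spark-shell; the data and the concatenation are just illustrative):

val kv = sc.parallelize(Seq(("a", "x"), ("a", "y"), ("b", "z")))
// Group by key, then post-process each Iterable of values with a custom function
val processed = kv.groupByKey().mapValues(values => values.mkString(","))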

However, it won't be that easy with Spark SQL due to the complex typing and optimization that Spark applies under the hood. We will need to extend the class UserDefinedAggregateFunction and implement its methods.

In the following example, which you can execute directly in the spark-shell, I have developed a function that only takes the first element it finds in the content to aggregate. All the methods implemented are the required ones, and it works at least in Spark version 2.2.0.

class AppendStringAgg extends org.apache.spark.sql.expressions.UserDefinedAggregateFunction {
  import org.apache.spark.sql.Row
  import org.apache.spark.sql.expressions.MutableAggregationBuffer
  import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}
  // Schema-related methods: the input is a single string column
  override def inputSchema: StructType = StructType(StructField("value", StringType) :: Nil)
  // This will be the schema used internally in the UDAF (I almost wrote UASF, bitcoin folks will understand ;) )
  override def bufferSchema: StructType = StructType(StructField("v", StringType) :: Nil)
  // The type returned once the aggregation is done
  override def dataType: DataType = StringType
  override def deterministic: Boolean = true
  // Start the aggregation with an empty buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ""
  }
  // Helper used below: despite the name, it keeps the first non-empty value it sees
  private def sumIfNotNull(current: String, incoming: String): String =
    if (current != null && current.nonEmpty) current
    else if (incoming != null) incoming
    else ""
  // What happens when it has to merge a buffer with another partial aggregate
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, sumIfNotNull(buffer1.getString(0), buffer2.getString(0)))
  }
  // Iterate over all the input rows, updating the buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, sumIfNotNull(buffer.getString(0), input.getString(0)))
  }
  // This is the end, my friend: return the aggregated value
  override def evaluate(buffer: Row): Any = { buffer(0) }
}

Then, the only thing left to do is to instantiate the class and apply the function to a column.
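
The snippet below assumes a DataFrame df with grouping columns A and B, a column category to pivot and a string column value; an illustrative one could be built like this:

val df = Seq(
  ("a1", "b1", "cat1", "v1"),
  ("a1", "b1", "cat2", "v2"),
  ("a2", "b2", "cat1", "v3")
).toDF("A", "B", "category", "value")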

val aggregator = new AppendStringAgg()

val pivoted = df.groupBy("A", "B").pivot("category").agg(aggregator($"value"))

Note: what we have done with the UDAF is also possible with the method first, which is already implemented. I have written it just as a toy example.
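
For reference, a sketch of that alternative with the built-in first, assuming the same df as above:

import org.apache.spark.sql.functions.first

val pivotedWithFirst = df.groupBy("A", "B").pivot("category").agg(first($"value"))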

For more information about pivot tables, please go to this Databricks blog post, or to this Office support page if you want to play a bit with pivot tables.