A type-safe UDAF in Spark can be implemented by extending the abstract class Aggregator and applying it to Datasets. It is loosely based on the Aggregator from Algebird: https://github.com/twitter/algebird. The following methods need to be defined after extending the Aggregator class:
- public Encoder<Average> bufferEncoder(): Specifies the encoder for the intermediate value type; similar to the untyped UDAF function bufferSchema().
- public Encoder<Double> outputEncoder(): Specifies the encoder for the final output value type, that is, the output of the UDAF; similar to the untyped UDAF function dataType().
- public Average zero(): Returns the identity element of the aggregation, which should satisfy the property that b + zero = b for any b; similar to the untyped UDAF function initialize().
- public Average reduce(Average buffer, Employee employee): Folds an input value into the intermediate buffer and returns the updated buffer; similar to the untyped UDAF function update().
- public Average merge(Average b1, Average b2): Merges two intermediate values; similar to the untyped UDAF function merge().
- public Double finish(Average reduction): Computes the final output of the UDAF; similar to the untyped UDAF function evaluate().
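Before wiring anything into Spark, the algebraic contract behind these methods can be illustrated in plain Java. The sketch below uses a simplified sum/count pair with illustrative names (Buf, AggregatorContract are not part of the Spark API); it shows the zero identity property and how per-partition buffers are reduced and then merged:

```java
// Illustrative stand-in for the intermediate buffer type: a mutable (sum, count) pair.
class Buf {
    double sum;
    long count;
    Buf(double sum, long count) { this.sum = sum; this.count = count; }
}

public class AggregatorContract {
    // zero(): the identity element -- merging it into any buffer changes nothing.
    static Buf zero() { return new Buf(0.0, 0L); }

    // reduce(): fold one input value into the buffer.
    static Buf reduce(Buf b, double value) {
        b.sum += value;
        b.count += 1;
        return b;
    }

    // merge(): combine two partial buffers (e.g. from two partitions).
    static Buf merge(Buf b1, Buf b2) {
        b1.sum += b2.sum;
        b1.count += b2.count;
        return b1;
    }

    // finish(): turn the final buffer into the output value.
    static double finish(Buf b) { return b.sum / b.count; }

    public static void main(String[] args) {
        // Two "partitions" of values, reduced independently and then merged --
        // this mirrors what Spark does across executors.
        Buf p1 = zero();
        for (double v : new double[] {1000, 2000}) p1 = reduce(p1, v);
        Buf p2 = zero();
        for (double v : new double[] {3000}) p2 = reduce(p2, v);

        double avg = finish(merge(p1, p2));
        System.out.println(avg); // prints 2000.0, i.e. (1000 + 2000 + 3000) / 3
    }
}
```

The same average comes out however the inputs are split into partitions, which is exactly what zero, reduce, and merge must guarantee for a distributed aggregation to be correct.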
To draw a comparison with the previous untyped UDAF, let's pursue the same objective of calculating the average, but this time with a type-safe UDAF.
Create a class holding the intermediate state whose final values determine the output of the UDAF; in the case of the average, the two variables are the sum and the count:
public class Average implements Serializable {
    private static final long serialVersionUID = 1L;
    private double sumVal;
    private long countVal;

    public Average() {
    }

    public Average(double sumVal, long countVal) {
        this.sumVal = sumVal;
        this.countVal = countVal;
    }

    public double getSumVal() {
        return sumVal;
    }

    public void setSumVal(double sumVal) {
        this.sumVal = sumVal;
    }

    public long getCountVal() {
        return countVal;
    }

    public void setCountVal(long countVal) {
        this.countVal = countVal;
    }
}
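Encoders.bean (used for the buffer encoder below) relies on standard JavaBean conventions: a public no-argument constructor plus getter/setter pairs. As a quick sanity check, the JDK's own Introspector can list the bean properties such an encoder would see; the BeanCheck class and its nested copy of Average are illustrative, not part of any Spark API:

```java
import java.beans.BeanInfo;
import java.beans.IntrospectionException;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.io.Serializable;
import java.util.Set;
import java.util.TreeSet;

public class BeanCheck {
    // The Average bean from the text, reproduced here so the check is self-contained.
    public static class Average implements Serializable {
        private static final long serialVersionUID = 1L;
        private double sumVal;
        private long countVal;

        public Average() { }

        public Average(double sumVal, long countVal) {
            this.sumVal = sumVal;
            this.countVal = countVal;
        }

        public double getSumVal() { return sumVal; }
        public void setSumVal(double sumVal) { this.sumVal = sumVal; }
        public long getCountVal() { return countVal; }
        public void setCountVal(long countVal) { this.countVal = countVal; }
    }

    // Returns the readable and writable bean properties of the given class.
    static Set<String> beanProperties(Class<?> cls) {
        try {
            BeanInfo info = Introspector.getBeanInfo(cls, Object.class);
            Set<String> names = new TreeSet<>();
            for (PropertyDescriptor pd : info.getPropertyDescriptors()) {
                if (pd.getReadMethod() != null && pd.getWriteMethod() != null) {
                    names.add(pd.getName());
                }
            }
            return names;
        } catch (IntrospectionException e) {
            throw new RuntimeException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(beanProperties(Average.class)); // prints [countVal, sumVal]
    }
}
```

If a field is missing its getter or setter, it silently drops out of this list, which is a common source of empty or partial buffer columns when using bean encoders.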
Create a UDAF class extending org.apache.spark.sql.expressions.Aggregator:
public class TypeSafeUDAF extends Aggregator<Employee, Average, Double> implements Serializable {
    private static final long serialVersionUID = 1L;

    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    @Override
    public Average reduce(Average buffer, Employee employee) {
        double newSum = buffer.getSumVal() + employee.getSalary();
        long newCount = buffer.getCountVal() + 1;
        buffer.setSumVal(newSum);
        buffer.setCountVal(newCount);
        return buffer;
    }

    @Override
    public Average merge(Average b1, Average b2) {
        double mergedSum = b1.getSumVal() + b2.getSumVal();
        long mergedCount = b1.getCountVal() + b2.getCountVal();
        b1.setSumVal(mergedSum);
        b1.setCountVal(mergedCount);
        return b1;
    }

    @Override
    public Double finish(Average reduction) {
        return reduction.getSumVal() / reduction.getCountVal();
    }

    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}
Instantiate the UDAF class, convert the instance to a TypedColumn, give it an alias to use as the column name, and then pass it to the Dataset:
TypeSafeUDAF typeSafeUDAF = new TypeSafeUDAF();
Dataset<Employee> emf = emp_ds.as(Encoders.bean(Employee.class));
TypedColumn<Employee, Double> averageSalary = typeSafeUDAF.toColumn().name("averageTypeSafe");
Dataset<Double> result = emf.select(averageSalary);
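Because zero(), reduce(), merge(), and finish() are ordinary Java methods, their logic can be sanity-checked without a SparkSession. The sketch below mirrors the TypeSafeUDAF logic with minimal local stand-ins for Employee and Average (hypothetical simplified classes, not the full beans from the text), splitting the hypothetical salaries across two "partitions" the way Spark might:

```java
import java.util.Arrays;
import java.util.List;

public class TypeSafeUDAFCheck {
    // Minimal stand-in for the Employee bean: only the salary matters here.
    static class Employee {
        private final double salary;
        Employee(double salary) { this.salary = salary; }
        double getSalary() { return salary; }
    }

    // Mirrors the Average buffer from the text.
    static class Average {
        double sumVal;
        long countVal;
        Average(double sumVal, long countVal) { this.sumVal = sumVal; this.countVal = countVal; }
    }

    // Same logic as TypeSafeUDAF.reduce(): fold one employee into the buffer.
    static Average reduce(Average buffer, Employee e) {
        buffer.sumVal += e.getSalary();
        buffer.countVal += 1;
        return buffer;
    }

    // Same logic as TypeSafeUDAF.merge(): combine two partial buffers.
    static Average merge(Average b1, Average b2) {
        b1.sumVal += b2.sumVal;
        b1.countVal += b2.countVal;
        return b1;
    }

    // Same logic as TypeSafeUDAF.finish(): produce the final average.
    static double finish(Average reduction) {
        return reduction.sumVal / reduction.countVal;
    }

    // Reduce each partition independently, merge the buffers, then finish.
    static double averageSalary(List<List<Employee>> partitions) {
        Average total = new Average(0.0, 0L); // zero()
        for (List<Employee> partition : partitions) {
            Average buffer = new Average(0.0, 0L); // zero() per partition
            for (Employee e : partition) {
                buffer = reduce(buffer, e);
            }
            total = merge(total, buffer);
        }
        return finish(total);
    }

    public static void main(String[] args) {
        List<List<Employee>> partitions = Arrays.asList(
            Arrays.asList(new Employee(3000), new Employee(5000)),
            Arrays.asList(new Employee(4000)));
        System.out.println(averageSalary(partitions)); // prints 4000.0
    }
}
```

This kind of local check is a cheap way to unit-test aggregator logic before running it through the full Spark pipeline shown above.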