User-defined Functions

User-defined functions are an important feature, because they significantly extend the expressiveness of queries.

Register User-Defined Functions

In most cases, a user-defined function must be registered before it can be used in a query. It is not necessary to register functions for the Scala Table API.

Functions are registered at the TableEnvironment by calling a registerFunction() method. When a user-defined function is registered, it is inserted into the function catalog of the TableEnvironment such that the Table API or SQL parser can recognize and properly translate it.
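
As a minimal sketch (Java), registration looks the same for all three kinds of functions; the classes used here (HashCode, Split, WeightedAvg) are the example functions defined in the sub-sections below:

    BatchTableEnvironment tableEnv = BatchTableEnvironment.create(env);

    // the function classes are defined in the following sub-sections
    tableEnv.registerFunction("hashCode", new HashCode(10));  // ScalarFunction
    tableEnv.registerFunction("split", new Split("#"));       // TableFunction
    tableEnv.registerFunction("wAvg", new WeightedAvg());     // AggregateFunction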

Please find detailed examples of how to register and how to call each type of user-defined function (ScalarFunction, TableFunction, and AggregateFunction) in the following sub-sections.

Scalar Functions

If a required scalar function is not contained in the built-in functions, it is possible to define custom, user-defined scalar functions for both the Table API and SQL. A user-defined scalar function maps zero, one, or multiple scalar values to a new scalar value.

In order to define a scalar function, one has to extend the base class ScalarFunction in org.apache.flink.table.functions and implement (one or more) evaluation methods. The behavior of a scalar function is determined by the evaluation method. An evaluation method must be declared publicly and named eval. The parameter types and return type of the evaluation method also determine the parameter and return types of the scalar function. Evaluation methods can also be overloaded by implementing multiple methods named eval. Evaluation methods can also support variable arguments, such as eval(String… strs).

The following example shows how to define your own hash code function, register it in the TableEnvironment, and call it in a query. Note that you can configure your scalar function via a constructor before it is registered:

    public class HashCode extends ScalarFunction {
        private int factor = 12;

        public HashCode(int factor) {
            this.factor = factor;
        }

        public int eval(String s) {
            return s.hashCode() * factor;
        }
    }

    BatchTableEnvironment tableEnv = BatchTableEnvironment.create(env);

    // register the function
    tableEnv.registerFunction("hashCode", new HashCode(10));

    // use the function in Java Table API
    myTable.select("string, string.hashCode(), hashCode(string)");

    // use the function in SQL API
    tableEnv.sqlQuery("SELECT string, hashCode(string) FROM MyTable");
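
As a hedged sketch of the overloading and variable-argument support mentioned above (the class name and logic are illustrative, not part of the original example), a scalar function may declare several eval methods:

    import org.apache.flink.table.functions.ScalarFunction;

    public class ConcatHash extends ScalarFunction {

        // overloaded evaluation method for Integer arguments
        public int eval(Integer i) {
            return i.hashCode();
        }

        // variable-argument evaluation method, as in eval(String... strs)
        public int eval(String... strs) {
            return String.join("|", strs).hashCode();
        }
    }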

By default the result type of an evaluation method is determined by Flink’s type extraction facilities. This is sufficient for basic types or simple POJOs but might be wrong for more complex, custom, or composite types. In these cases TypeInformation of the result type can be manually defined by overriding ScalarFunction#getResultType().

The following advanced example takes the internal timestamp representation and also returns the internal timestamp representation as a long value. By overriding ScalarFunction#getResultType() we define that the returned long value should be interpreted as Types.SQL_TIMESTAMP by the code generation.

    public static class TimestampModifier extends ScalarFunction {
        public long eval(long t) {
            return t % 1000;
        }

        public TypeInformation<?> getResultType(Class<?>[] signature) {
            return Types.SQL_TIMESTAMP;
        }
    }

In order to define a scalar function, one has to extend the base class ScalarFunction in org.apache.flink.table.functions and implement (one or more) evaluation methods. The behavior of a scalar function is determined by the evaluation method. An evaluation method must be declared publicly and named eval. The parameter types and return type of the evaluation method also determine the parameter and return types of the scalar function. Evaluation methods can also be overloaded by implementing multiple methods named eval. Evaluation methods can also support variable arguments, such as @varargs def eval(str: String*).

The following example shows how to define your own hash code function, register it in the TableEnvironment, and call it in a query. Note that you can configure your scalar function via a constructor before it is registered:

    // must be defined in static/object context
    class HashCode(factor: Int) extends ScalarFunction {
      def eval(s: String): Int = {
        s.hashCode() * factor
      }
    }

    val tableEnv = BatchTableEnvironment.create(env)

    // use the function in Scala Table API
    val hashCode = new HashCode(10)
    myTable.select('string, hashCode('string))

    // register and use the function in SQL
    tableEnv.registerFunction("hashCode", new HashCode(10))
    tableEnv.sqlQuery("SELECT string, hashCode(string) FROM MyTable")

By default the result type of an evaluation method is determined by Flink’s type extraction facilities. This is sufficient for basic types or simple POJOs but might be wrong for more complex, custom, or composite types. In these cases TypeInformation of the result type can be manually defined by overriding ScalarFunction#getResultType().

The following advanced example takes the internal timestamp representation and also returns the internal timestamp representation as a long value. By overriding ScalarFunction#getResultType() we define that the returned long value should be interpreted as Types.TIMESTAMP by the code generation.

    object TimestampModifier extends ScalarFunction {
      def eval(t: Long): Long = {
        t % 1000
      }

      override def getResultType(signature: Array[Class[_]]): TypeInformation[_] = {
        Types.TIMESTAMP
      }
    }

Note Python 3.5+ and apache-beam==2.15.0 are required to run the Python scalar function.

Note By default PyFlink uses the command “python” to run the Python UDF workers. Before starting the cluster, run the following commands to confirm that it meets the requirements:

    $ python --version
    # the version printed here must be 3.5+
    $ python -m pip install apache-beam==2.15.0

Note Currently, Python UDFs are supported in the Blink planner both under streaming and batch mode, while they are only supported under streaming mode in the old planner.

Both Java/Scala scalar functions and Python scalar functions can be used in the Python Table API and SQL. In order to define a Python scalar function, one can extend the base class ScalarFunction in pyflink.table.udf and implement an evaluation method. The behavior of a Python scalar function is determined by the evaluation method. An evaluation method must be named eval. The evaluation method can also support variable arguments, such as eval(*args).

The following example shows how to define your own Java and Python hash code functions, register them in the TableEnvironment, and call them in a query. Note that you can configure your scalar function via a constructor before it is registered:

    '''
    Java code:

    // The Java class must have a public no-argument constructor and must be found in the current Java classloader.
    public class HashCode extends ScalarFunction {
        private int factor = 12;

        public int eval(String s) {
            return s.hashCode() * factor;
        }
    }
    '''

    class PyHashCode(ScalarFunction):
        def __init__(self):
            self.factor = 12

        def eval(self, s):
            return hash(s) * self.factor

    table_env = BatchTableEnvironment.create(env)

    # register the Java function
    table_env.register_java_function("hashCode", "my.java.function.HashCode")

    # register the Python function
    table_env.register_function("py_hash_code", udf(PyHashCode(), DataTypes.BIGINT(), DataTypes.BIGINT()))

    # use the function in Python Table API
    my_table.select("string, bigint, string.hashCode(), hashCode(string), bigint.py_hash_code(), py_hash_code(bigint)")

    # use the function in SQL API
    table_env.sql_query("SELECT string, bigint, hashCode(string), py_hash_code(bigint) FROM MyTable")

There are many ways to define a Python scalar function besides extending the base class ScalarFunction. The following example shows the different ways to define a Python scalar function which takes two columns of bigint as input parameters and returns the sum of them as the result.

    # option 1: extending the base class `ScalarFunction`
    class Add(ScalarFunction):
        def eval(self, i, j):
            return i + j

    add = udf(Add(), [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())

    # option 2: Python function
    @udf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()], result_type=DataTypes.BIGINT())
    def add(i, j):
        return i + j

    # option 3: lambda function
    add = udf(lambda i, j: i + j, [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())

    # option 4: callable function
    class CallableAdd(object):
        def __call__(self, i, j):
            return i + j

    add = udf(CallableAdd(), [DataTypes.BIGINT(), DataTypes.BIGINT()], DataTypes.BIGINT())

    # option 5: partial function
    def partial_add(i, j, k):
        return i + j + k

    add = udf(functools.partial(partial_add, k=1), [DataTypes.BIGINT(), DataTypes.BIGINT()],
              DataTypes.BIGINT())

    # register the Python function
    table_env.register_function("add", add)

    # use the function in Python Table API
    my_table.select("add(a, b)")

If a Python scalar function depends on third-party packages, you can specify the dependencies with the following Table APIs or directly through the command line when submitting the job.

add_python_file(file_path)
Adds python file dependencies, which could be python files, python packages or local directories. They will be added to the PYTHONPATH of the python UDF worker.

    table_env.add_python_file(file_path)

set_python_requirements(requirements_file_path, requirements_cache_dir=None)
Specifies a requirements.txt file which defines the third-party dependencies. These dependencies will be installed to a temporary directory and added to the PYTHONPATH of the python UDF worker. For dependencies which cannot be accessed in the cluster, a directory which contains the installation packages of these dependencies can be specified using the parameter "requirements_cache_dir". It will be uploaded to the cluster to support offline installation.

    # commands executed in shell
    echo numpy==1.16.5 > requirements.txt
    pip download -d cached_dir -r requirements.txt --no-binary :all:

    # python code
    table_env.set_python_requirements("requirements.txt", "cached_dir")

Please make sure the installation packages match the platform of the cluster and the Python version used. These packages will be installed using pip, so also make sure the versions of pip (>= 7.1.0) and setuptools (>= 37.0.0) meet the requirements.

add_python_archive(archive_path, target_dir=None)
Adds a python archive file dependency. The file will be extracted to the working directory of the python UDF worker. If the parameter "target_dir" is specified, the archive file will be extracted to a directory named "target_dir". Otherwise, the archive file will be extracted to a directory with the same name as the archive file.

    # command executed in shell
    # assert the relative path of the python interpreter is py_env/bin/python
    zip -r py_env.zip py_env

    # python code
    table_env.add_python_archive("py_env.zip")
    # or
    table_env.add_python_archive("py_env.zip", "myenv")

    # the files contained in the archive file can be accessed in the UDF
    def my_udf():
        with open("myenv/py_env/data/data.txt") as f:

Please make sure the uploaded Python environment matches the platform that the cluster is running on. Currently only zip format is supported, i.e. zip, jar, whl, egg, etc.

set_python_executable(python_exec)
Sets the path of the python interpreter which is used to execute the python UDF workers, e.g., "/usr/local/bin/python3".

    table_env.add_python_archive("py_env.zip")
    table_env.get_config().set_python_executable("py_env.zip/py_env/bin/python")

Please make sure that the specified environment matches the platform that the cluster is running on.

Table Functions

Similar to a user-defined scalar function, a user-defined table function takes zero, one, or multiple scalar values as input parameters. However in contrast to a scalar function, it can return an arbitrary number of rows as output instead of a single value. The returned rows may consist of one or more columns.

In order to define a table function one has to extend the base class TableFunction in org.apache.flink.table.functions and implement (one or more) evaluation methods. The behavior of a table function is determined by its evaluation methods. An evaluation method must be declared public and named eval. The TableFunction can be overloaded by implementing multiple methods named eval. The parameter types of the evaluation methods determine all valid parameters of the table function. Evaluation methods can also support variable arguments, such as eval(String… strs). The type of the returned table is determined by the generic type of TableFunction. Evaluation methods emit output rows using the protected collect(T) method.

In the Table API, a table function is used with .joinLateral or .leftOuterJoinLateral. The joinLateral operator (cross) joins each row from the outer table (table on the left of the operator) with all rows produced by the table-valued function (which is on the right side of the operator). The leftOuterJoinLateral operator joins each row from the outer table (table on the left of the operator) with all rows produced by the table-valued function (which is on the right side of the operator) and preserves outer rows for which the table function returns an empty table. In SQL use LATERAL TABLE(<TableFunction>) with CROSS JOIN and LEFT JOIN with an ON TRUE join condition (see examples below).

The following example shows how to define a table-valued function, register it in the TableEnvironment, and call it in a query. Note that you can configure your table function via a constructor before it is registered:

    // The generic type "Tuple2<String, Integer>" determines the schema of the returned table as (String, Integer).
    public class Split extends TableFunction<Tuple2<String, Integer>> {
        private String separator = " ";

        public Split(String separator) {
            this.separator = separator;
        }

        public void eval(String str) {
            for (String s : str.split(separator)) {
                // use collect(...) to emit a row
                collect(new Tuple2<String, Integer>(s, s.length()));
            }
        }
    }

    BatchTableEnvironment tableEnv = BatchTableEnvironment.create(env);
    Table myTable = ...         // table schema: [a: String]

    // Register the function.
    tableEnv.registerFunction("split", new Split("#"));

    // Use the table function in the Java Table API. "as" specifies the field names of the table.
    myTable.joinLateral("split(a) as (word, length)")
        .select("a, word, length");
    myTable.leftOuterJoinLateral("split(a) as (word, length)")
        .select("a, word, length");

    // Use the table function in SQL with LATERAL and TABLE keywords.
    // CROSS JOIN a table function (equivalent to "join" in Table API).
    tableEnv.sqlQuery("SELECT a, word, length FROM MyTable, LATERAL TABLE(split(a)) as T(word, length)");
    // LEFT JOIN a table function (equivalent to "leftOuterJoin" in Table API).
    tableEnv.sqlQuery("SELECT a, word, length FROM MyTable LEFT JOIN LATERAL TABLE(split(a)) as T(word, length) ON TRUE");
    // The generic type "(String, Int)" determines the schema of the returned table as (String, Integer).
    class Split(separator: String) extends TableFunction[(String, Int)] {
      def eval(str: String): Unit = {
        // use collect(...) to emit a row.
        str.split(separator).foreach(x => collect((x, x.length)))
      }
    }

    val tableEnv = BatchTableEnvironment.create(env)
    val myTable = ...         // table schema: [a: String]

    // Use the table function in the Scala Table API (Note: No registration required in Scala Table API).
    val split = new Split("#")
    // "as" specifies the field names of the generated table.
    myTable.joinLateral(split('a) as ('word, 'length)).select('a, 'word, 'length)
    myTable.leftOuterJoinLateral(split('a) as ('word, 'length)).select('a, 'word, 'length)

    // Register the table function to use it in SQL queries.
    tableEnv.registerFunction("split", new Split("#"))

    // Use the table function in SQL with LATERAL and TABLE keywords.
    // CROSS JOIN a table function (equivalent to "join" in Table API)
    tableEnv.sqlQuery("SELECT a, word, length FROM MyTable, LATERAL TABLE(split(a)) as T(word, length)")
    // LEFT JOIN a table function (equivalent to "leftOuterJoin" in Table API)
    tableEnv.sqlQuery("SELECT a, word, length FROM MyTable LEFT JOIN LATERAL TABLE(split(a)) as T(word, length) ON TRUE")

IMPORTANT: Do not implement TableFunction as a Scala object. A Scala object is a singleton and will cause concurrency issues.

    '''
    Java code:

    // The generic type "Tuple2<String, Integer>" determines the schema of the returned table as (String, Integer).
    // The Java class must have a public no-argument constructor and must be found in the current Java classloader.
    public class Split extends TableFunction<Tuple2<String, Integer>> {
        private String separator = " ";

        public void eval(String str) {
            for (String s : str.split(separator)) {
                // use collect(...) to emit a row
                collect(new Tuple2<String, Integer>(s, s.length()));
            }
        }
    }
    '''

    table_env = BatchTableEnvironment.create(env)
    my_table = ...  # type: Table, table schema: [a: String]

    # Register the Java function.
    table_env.register_java_function("split", "my.java.function.Split")

    # Use the table function in the Python Table API. "as" specifies the field names of the table.
    my_table.join_lateral("split(a) as (word, length)").select("a, word, length")
    my_table.left_outer_join_lateral("split(a) as (word, length)").select("a, word, length")

    # Use the table function in SQL with LATERAL and TABLE keywords.
    # CROSS JOIN a table function (equivalent to "join" in Table API).
    table_env.sql_query("SELECT a, word, length FROM MyTable, LATERAL TABLE(split(a)) as T(word, length)")
    # LEFT JOIN a table function (equivalent to "left_outer_join" in Table API).
    table_env.sql_query("SELECT a, word, length FROM MyTable LEFT JOIN LATERAL TABLE(split(a)) as T(word, length) ON TRUE")

Please note that POJO types do not have a deterministic field order. Therefore, you cannot rename the fields of a POJO returned by a table function using AS.
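
As a hedged sketch of this point (the POJO and function names are hypothetical): when a table function returns a POJO, its public fields simply become the columns of the returned table and are addressed by their field names rather than renamed with AS:

    import org.apache.flink.table.functions.TableFunction;

    public class PojoSplit extends TableFunction<PojoSplit.WordInfo> {

        // hypothetical POJO result type; its public fields become the columns of the returned table
        public static class WordInfo {
            public String word;
            public int length;
        }

        public void eval(String str) {
            for (String s : str.split(" ")) {
                WordInfo info = new WordInfo();
                info.word = s;
                info.length = s.length();
                collect(info);
            }
        }
    }

    // the columns are referenced by their POJO field names instead of being renamed with "as", e.g.:
    // tableEnv.registerFunction("pojoSplit", new PojoSplit());
    // tableEnv.sqlQuery("SELECT a, word, length FROM MyTable, LATERAL TABLE(pojoSplit(a))");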

By default the result type of a TableFunction is determined by Flink’s automatic type extraction facilities. This works well for basic types and simple POJOs but might be wrong for more complex, custom, or composite types. In such a case, the type of the result can be manually specified by overriding TableFunction#getResultType() which returns its TypeInformation.

The following example shows a TableFunction that returns a Row type, which requires explicit type information. We define that the returned table type should be RowTypeInfo(String, Integer) by overriding TableFunction#getResultType().

    public class CustomTypeSplit extends TableFunction<Row> {
        public void eval(String str) {
            for (String s : str.split(" ")) {
                Row row = new Row(2);
                row.setField(0, s);
                row.setField(1, s.length());
                collect(row);
            }
        }

        @Override
        public TypeInformation<Row> getResultType() {
            return Types.ROW(Types.STRING(), Types.INT());
        }
    }
    class CustomTypeSplit extends TableFunction[Row] {
      def eval(str: String): Unit = {
        str.split(" ").foreach({ s =>
          val row = new Row(2)
          row.setField(0, s)
          row.setField(1, s.length)
          collect(row)
        })
      }

      override def getResultType: TypeInformation[Row] = {
        Types.ROW(Types.STRING, Types.INT)
      }
    }
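
A brief usage sketch (Java), assuming the same registration and join pattern as the earlier Split example; since the result type is a generic Row, the fields are named with "as" in the query:

    BatchTableEnvironment tableEnv = BatchTableEnvironment.create(env);

    // register the function
    tableEnv.registerFunction("customSplit", new CustomTypeSplit());

    // the Row fields are named via "as", exactly like in the Split example above
    tableEnv.sqlQuery(
        "SELECT a, word, length FROM MyTable, LATERAL TABLE(customSplit(a)) as T(word, length)");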

Aggregation Functions

User-Defined Aggregate Functions (UDAGGs) aggregate a table (one or more rows with one or more attributes) to a scalar value.

UDAGG mechanism

The above figure shows an example of an aggregation. Assume you have a table that contains data about beverages. The table consists of three columns (id, name, and price) and 5 rows. Imagine you need to find the highest price of all beverages in the table, i.e., perform a max() aggregation. You would need to check each of the 5 rows, and the result would be a single numeric value.

User-defined aggregation functions are implemented by extending the AggregateFunction class. An AggregateFunction works as follows. First, it needs an accumulator, which is the data structure that holds the intermediate result of the aggregation. An empty accumulator is created by calling the createAccumulator() method of the AggregateFunction. Subsequently, the accumulate() method of the function is called for each input row to update the accumulator. Once all rows have been processed, the getValue() method of the function is called to compute and return the final result.

The following methods are mandatory for each AggregateFunction:

  • createAccumulator()
  • accumulate()
  • getValue()
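
As a minimal sketch of this lifecycle (the class is illustrative; the full method contract is listed further below), a simple sum aggregate only needs these three methods:

    import org.apache.flink.table.functions.AggregateFunction;

    // a minimal, illustrative aggregate: sums long values
    public class LongSum extends AggregateFunction<Long, LongSum.SumAccum> {

        // the accumulator holds the intermediate result
        public static class SumAccum {
            public long sum = 0L;
        }

        @Override
        public SumAccum createAccumulator() {
            return new SumAccum();
        }

        // called for each input row
        public void accumulate(SumAccum acc, Long value) {
            acc.sum += value;
        }

        @Override
        public Long getValue(SumAccum acc) {
            return acc.sum;
        }
    }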

Flink’s type extraction facilities can fail to identify complex data types, e.g., if they are not basic types or simple POJOs. So similar to ScalarFunction and TableFunction, AggregateFunction provides methods to specify the TypeInformation of the result type (through AggregateFunction#getResultType()) and the type of the accumulator (through AggregateFunction#getAccumulatorType()).

Besides the above methods, there are a few contracted methods that can be optionally implemented. While some of these methods allow the system to execute queries more efficiently, others are mandatory for certain use cases. For instance, the merge() method is mandatory if the aggregation function should be applied in the context of a session group window (the accumulators of two session windows need to be joined when a row is observed that “connects” them).

The following methods of AggregateFunction are required depending on the use case:

  • retract() is required for aggregations on bounded OVER windows.
  • merge() is required for many batch aggregations and session window aggregations.
  • resetAccumulator() is required for many batch aggregations.

All methods of AggregateFunction must be declared as public, not static, and named exactly as the names mentioned above. The methods createAccumulator, getValue, getResultType, and getAccumulatorType are defined in the AggregateFunction abstract class, while the others are contracted methods. In order to define an aggregate function, one has to extend the base class org.apache.flink.table.functions.AggregateFunction and implement one (or more) accumulate methods. The method accumulate can be overloaded with different parameter types and supports variable arguments.

Detailed documentation for all methods of AggregateFunction is given below.

    /**
     * Base class for user-defined aggregates and table aggregates.
     *
     * @param <T>   the type of the aggregation result.
     * @param <ACC> the type of the aggregation accumulator. The accumulator is used to keep the
     *              aggregated values which are needed to compute an aggregation result.
     */
    public abstract class UserDefinedAggregateFunction<T, ACC> extends UserDefinedFunction {

        /**
         * Creates and init the Accumulator for this (table)aggregate function.
         *
         * @return the accumulator with the initial value
         */
        public ACC createAccumulator(); // MANDATORY

        /**
         * Returns the TypeInformation of the (table)aggregate function's result.
         *
         * @return The TypeInformation of the (table)aggregate function's result or null if the result
         *         type should be automatically inferred.
         */
        public TypeInformation<T> getResultType = null; // PRE-DEFINED

        /**
         * Returns the TypeInformation of the (table)aggregate function's accumulator.
         *
         * @return The TypeInformation of the (table)aggregate function's accumulator or null if the
         *         accumulator type should be automatically inferred.
         */
        public TypeInformation<ACC> getAccumulatorType = null; // PRE-DEFINED
    }

    /**
     * Base class for aggregation functions.
     *
     * @param <T>   the type of the aggregation result
     * @param <ACC> the type of the aggregation accumulator. The accumulator is used to keep the
     *              aggregated values which are needed to compute an aggregation result.
     *              AggregateFunction represents its state using accumulator, thereby the state of the
     *              AggregateFunction must be put into the accumulator.
     */
    public abstract class AggregateFunction<T, ACC> extends UserDefinedAggregateFunction<T, ACC> {

        /**
         * Processes the input values and update the provided accumulator instance. The method
         * accumulate can be overloaded with different custom types and arguments. An AggregateFunction
         * requires at least one accumulate() method.
         *
         * @param accumulator           the accumulator which contains the current aggregated results
         * @param [user defined inputs] the input value (usually obtained from a new arrived data).
         */
        public void accumulate(ACC accumulator, [user defined inputs]); // MANDATORY

        /**
         * Retracts the input values from the accumulator instance. The current design assumes the
         * inputs are the values that have been previously accumulated. The method retract can be
         * overloaded with different custom types and arguments. This function must be implemented for
         * datastream bounded over aggregate.
         *
         * @param accumulator           the accumulator which contains the current aggregated results
         * @param [user defined inputs] the input value (usually obtained from a new arrived data).
         */
        public void retract(ACC accumulator, [user defined inputs]); // OPTIONAL

        /**
         * Merges a group of accumulator instances into one accumulator instance. This function must be
         * implemented for datastream session window grouping aggregate and dataset grouping aggregate.
         *
         * @param accumulator the accumulator which will keep the merged aggregate results. It should
         *                    be noted that the accumulator may contain the previous aggregated
         *                    results. Therefore user should not replace or clean this instance in the
         *                    custom merge method.
         * @param its         an {@link java.lang.Iterable} pointed to a group of accumulators that will be
         *                    merged.
         */
        public void merge(ACC accumulator, java.lang.Iterable<ACC> its); // OPTIONAL

        /**
         * Called every time when an aggregation result should be materialized.
         * The returned value could be either an early and incomplete result
         * (periodically emitted as data arrive) or the final result of the
         * aggregation.
         *
         * @param accumulator the accumulator which contains the current
         *                    aggregated results
         * @return the aggregation result
         */
        public T getValue(ACC accumulator); // MANDATORY

        /**
         * Resets the accumulator for this [[AggregateFunction]]. This function must be implemented for
         * dataset grouping aggregate.
         *
         * @param accumulator the accumulator which needs to be reset
         */
        public void resetAccumulator(ACC accumulator); // OPTIONAL

        /**
         * Returns true if this AggregateFunction can only be applied in an OVER window.
         *
         * @return true if the AggregateFunction requires an OVER window, false otherwise.
         */
        public Boolean requiresOver = false; // PRE-DEFINED
    }
    /**
     * Base class for user-defined aggregates and table aggregates.
     *
     * @tparam T   the type of the aggregation result.
     * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
     *             aggregated values which are needed to compute an aggregation result.
     */
    abstract class UserDefinedAggregateFunction[T, ACC] extends UserDefinedFunction {

      /**
       * Creates and init the Accumulator for this (table)aggregate function.
       *
       * @return the accumulator with the initial value
       */
      def createAccumulator(): ACC // MANDATORY

      /**
       * Returns the TypeInformation of the (table)aggregate function's result.
       *
       * @return The TypeInformation of the (table)aggregate function's result or null if the result
       *         type should be automatically inferred.
       */
      def getResultType: TypeInformation[T] = null // PRE-DEFINED

      /**
       * Returns the TypeInformation of the (table)aggregate function's accumulator.
       *
       * @return The TypeInformation of the (table)aggregate function's accumulator or null if the
       *         accumulator type should be automatically inferred.
       */
      def getAccumulatorType: TypeInformation[ACC] = null // PRE-DEFINED
    }

    /**
     * Base class for aggregation functions.
     *
     * @tparam T   the type of the aggregation result
     * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
     *             aggregated values which are needed to compute an aggregation result.
     *             AggregateFunction represents its state using accumulator, thereby the state of the
     *             AggregateFunction must be put into the accumulator.
     */
    abstract class AggregateFunction[T, ACC] extends UserDefinedAggregateFunction[T, ACC] {

      /**
       * Processes the input values and update the provided accumulator instance. The method
       * accumulate can be overloaded with different custom types and arguments. An AggregateFunction
       * requires at least one accumulate() method.
       *
       * @param accumulator           the accumulator which contains the current aggregated results
       * @param [user defined inputs] the input value (usually obtained from a new arrived data).
       */
      def accumulate(accumulator: ACC, [user defined inputs]): Unit // MANDATORY

      /**
       * Retracts the input values from the accumulator instance. The current design assumes the
       * inputs are the values that have been previously accumulated. The method retract can be
       * overloaded with different custom types and arguments. This function must be implemented for
       * datastream bounded over aggregate.
       *
       * @param accumulator           the accumulator which contains the current aggregated results
       * @param [user defined inputs] the input value (usually obtained from a new arrived data).
       */
      def retract(accumulator: ACC, [user defined inputs]): Unit // OPTIONAL

      /**
       * Merges a group of accumulator instances into one accumulator instance. This function must be
       * implemented for datastream session window grouping aggregate and dataset grouping aggregate.
       *
       * @param accumulator the accumulator which will keep the merged aggregate results. It should
       *                    be noted that the accumulator may contain the previous aggregated
       *                    results. Therefore user should not replace or clean this instance in the
       *                    custom merge method.
       * @param its         an [[java.lang.Iterable]] pointed to a group of accumulators that will be
       *                    merged.
       */
      def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit // OPTIONAL

      /**
       * Called every time when an aggregation result should be materialized.
       * The returned value could be either an early and incomplete result
       * (periodically emitted as data arrive) or the final result of the
       * aggregation.
       *
       * @param accumulator the accumulator which contains the current
       *                    aggregated results
       * @return the aggregation result
       */
      def getValue(accumulator: ACC): T // MANDATORY

      /**
       * Resets the accumulator for this [[AggregateFunction]]. This function must be implemented for
       * dataset grouping aggregate.
       *
       * @param accumulator the accumulator which needs to be reset
       */
      def resetAccumulator(accumulator: ACC): Unit // OPTIONAL

      /**
       * Returns true if this AggregateFunction can only be applied in an OVER window.
       *
       * @return true if the AggregateFunction requires an OVER window, false otherwise.
       */
      def requiresOver: Boolean = false // PRE-DEFINED
    }

The following example shows how to

  • define an AggregateFunction that calculates the weighted average on a given column,
  • register the function in the TableEnvironment, and
  • use the function in a query.

To calculate a weighted average value, the accumulator needs to store the weighted sum and count of all the data that has been accumulated. In our example we define a class WeightedAvgAccum to be the accumulator. Accumulators are automatically backed up by Flink’s checkpointing mechanism and restored in case of a failure to ensure exactly-once semantics.

The accumulate() method of our WeightedAvg AggregateFunction has three inputs. The first one is the WeightedAvgAccum accumulator, the other two are user-defined inputs: the input value iValue and the weight of the input iWeight. Although the retract(), merge(), and resetAccumulator() methods are not mandatory for most aggregation types, we provide them below as examples. Please note that we used Java primitive types and defined getResultType() and getAccumulatorType() methods in the Scala example because Flink type extraction does not work very well for Scala types.

    /**
     * Accumulator for WeightedAvg.
     */
    public static class WeightedAvgAccum {
        public long sum = 0;
        public int count = 0;
    }

    /**
     * Weighted Average user-defined aggregate function.
     */
    public static class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccum> {

        @Override
        public WeightedAvgAccum createAccumulator() {
            return new WeightedAvgAccum();
        }

        @Override
        public Long getValue(WeightedAvgAccum acc) {
            if (acc.count == 0) {
                return null;
            } else {
                return acc.sum / acc.count;
            }
        }

        public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
            acc.sum += iValue * iWeight;
            acc.count += iWeight;
        }

        public void retract(WeightedAvgAccum acc, long iValue, int iWeight) {
            acc.sum -= iValue * iWeight;
            acc.count -= iWeight;
        }

        public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
            Iterator<WeightedAvgAccum> iter = it.iterator();
            while (iter.hasNext()) {
                WeightedAvgAccum a = iter.next();
                acc.count += a.count;
                acc.sum += a.sum;
            }
        }

        public void resetAccumulator(WeightedAvgAccum acc) {
            acc.count = 0;
            acc.sum = 0L;
        }
    }

    // register function
    StreamTableEnvironment tEnv = ...
    tEnv.registerFunction("wAvg", new WeightedAvg());

    // use function
    tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user");
    import java.lang.{Long => JLong, Integer => JInteger}
    import org.apache.flink.api.common.typeinfo.TypeInformation
    import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
    import org.apache.flink.api.java.typeutils.TupleTypeInfo
    import org.apache.flink.table.api.Types
    import org.apache.flink.table.functions.AggregateFunction

    /**
     * Accumulator for WeightedAvg.
     */
    class WeightedAvgAccum extends JTuple2[JLong, JInteger] {
      var sum: JLong = 0L
      var count: JInteger = 0
    }

    /**
     * Weighted Average user-defined aggregate function.
     */
    class WeightedAvg extends AggregateFunction[JLong, WeightedAvgAccum] {

      override def createAccumulator(): WeightedAvgAccum = {
        new WeightedAvgAccum
      }

      override def getValue(acc: WeightedAvgAccum): JLong = {
        if (acc.count == 0) {
          null
        } else {
          acc.sum / acc.count
        }
      }

      def accumulate(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {
        acc.sum += iValue * iWeight
        acc.count += iWeight
      }

      def retract(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {
        acc.sum -= iValue * iWeight
        acc.count -= iWeight
      }

      def merge(acc: WeightedAvgAccum, it: java.lang.Iterable[WeightedAvgAccum]): Unit = {
        val iter = it.iterator()
        while (iter.hasNext) {
          val a = iter.next()
          acc.count += a.count
          acc.sum += a.sum
        }
      }

      def resetAccumulator(acc: WeightedAvgAccum): Unit = {
        acc.count = 0
        acc.sum = 0L
      }

      override def getAccumulatorType: TypeInformation[WeightedAvgAccum] = {
        new TupleTypeInfo(classOf[WeightedAvgAccum], Types.LONG, Types.INT)
      }

      override def getResultType: TypeInformation[JLong] = Types.LONG
    }

    // register function
    val tEnv: StreamTableEnvironment = ???
    tEnv.registerFunction("wAvg", new WeightedAvg())

    // use function
    tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user")
    '''
    Java code:

    /**
     * Accumulator for WeightedAvg.
     */
    public static class WeightedAvgAccum {
        public long sum = 0;
        public int count = 0;
    }

    // The Java class must have a public no-argument constructor and must be found in the current Java classloader.

    /**
     * Weighted Average user-defined aggregate function.
     */
    public static class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccum> {

        @Override
        public WeightedAvgAccum createAccumulator() {
            return new WeightedAvgAccum();
        }

        @Override
        public Long getValue(WeightedAvgAccum acc) {
            if (acc.count == 0) {
                return null;
            } else {
                return acc.sum / acc.count;
            }
        }

        public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
            acc.sum += iValue * iWeight;
            acc.count += iWeight;
        }

        public void retract(WeightedAvgAccum acc, long iValue, int iWeight) {
            acc.sum -= iValue * iWeight;
            acc.count -= iWeight;
        }

        public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
            Iterator<WeightedAvgAccum> iter = it.iterator();
            while (iter.hasNext()) {
                WeightedAvgAccum a = iter.next();
                acc.count += a.count;
                acc.sum += a.sum;
            }
        }

        public void resetAccumulator(WeightedAvgAccum acc) {
            acc.count = 0;
            acc.sum = 0L;
        }
    }
    '''

    # register function
    t_env = ...  # type: StreamTableEnvironment
    t_env.register_java_function("wAvg", "my.java.function.WeightedAvg")

    # use function
    t_env.sql_query("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user")

Table Aggregation Functions

User-Defined Table Aggregate Functions (UDTAGGs) aggregate a table (one or more rows with one or more attributes) to a result table with multiple rows and columns.

UDAGG mechanism

The above figure shows an example of a table aggregation. Assume you have a table that contains data about beverages. The table consists of three columns (id, name, and price) and 5 rows. Imagine you need to find the top 2 highest prices of all beverages in the table, i.e., perform a top2() table aggregation. You would need to check each of the 5 rows, and the result would be a table with the top 2 values.

User-defined table aggregation functions are implemented by extending the TableAggregateFunction class. A TableAggregateFunction works as follows. First, it needs an accumulator, which is the data structure that holds the intermediate result of the aggregation. An empty accumulator is created by calling the createAccumulator() method of the TableAggregateFunction. Subsequently, the accumulate() method of the function is called for each input row to update the accumulator. Once all rows have been processed, the emitValue() method of the function is called to compute and return the final results.

The following methods are mandatory for each TableAggregateFunction:

  • createAccumulator()
  • accumulate()
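
As a minimal sketch (illustrative only; the full contract and a complete Top2 example follow below), a table aggregate that simply emits every accumulated value could look like this:

    import java.util.ArrayList;
    import java.util.List;

    import org.apache.flink.table.functions.TableAggregateFunction;
    import org.apache.flink.util.Collector;

    // a minimal, illustrative table aggregate: collects values and emits each of them as a row
    public class EmitAll extends TableAggregateFunction<Long, EmitAll.ListAccum> {

        // the accumulator holds the intermediate result
        public static class ListAccum {
            public List<Long> values = new ArrayList<>();
        }

        @Override
        public ListAccum createAccumulator() {
            return new ListAccum();
        }

        // called for each input row
        public void accumulate(ListAccum acc, Long value) {
            acc.values.add(value);
        }

        // emits the final (multi-row) result
        public void emitValue(ListAccum acc, Collector<Long> out) {
            for (Long v : acc.values) {
                out.collect(v);
            }
        }
    }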

Flink’s type extraction facilities can fail to identify complex data types, e.g., if they are not basic types or simple POJOs. So similar to ScalarFunction and TableFunction, TableAggregateFunction provides methods to specify the TypeInformation of the result type (through TableAggregateFunction#getResultType()) and the type of the accumulator (through TableAggregateFunction#getAccumulatorType()).

Besides the above methods, there are a few contracted methods that can be optionally implemented. While some of these methods allow the system to execute queries more efficiently, others are mandatory for certain use cases. For instance, the merge() method is mandatory if the aggregation function should be applied in the context of a session group window (the accumulators of two session windows need to be joined when a row is observed that “connects” them).

The following methods of TableAggregateFunction are required depending on the use case:

  • retract() is required for aggregations on bounded OVER windows.
  • merge() is required for many batch aggregations and session window aggregations.
  • resetAccumulator() is required for many batch aggregations.
  • emitValue() is required for batch and window aggregations.

The following methods of TableAggregateFunction are used to improve the performance of streaming jobs:

  • emitUpdateWithRetract() is used to emit values that have been updated under retract mode.

The emitValue method always emits the full data according to the accumulator. Taking TopN as an example, emitValue emits all top n values each time, which may cause performance problems for streaming jobs. To improve performance, a user can also implement the emitUpdateWithRetract method, which outputs data incrementally in retract mode, i.e., once there is an update, we have to retract old records before sending new updated ones. The method is used in preference to the emitValue method if both are defined in the table aggregate function, because emitUpdateWithRetract is treated as more efficient than emitValue since it can output values incrementally.

All methods of TableAggregateFunction must be declared as public, not static, and named exactly as the names mentioned above. The methods createAccumulator, getResultType, and getAccumulatorType are defined in the parent abstract class of TableAggregateFunction, while the others are contracted methods. In order to define a table aggregate function, one has to extend the base class org.apache.flink.table.functions.TableAggregateFunction and implement one (or more) accumulate methods. The method accumulate can be overloaded with different parameter types and supports variable arguments.

Detailed documentation for all methods of TableAggregateFunction is given below.

    /**
     * Base class for user-defined aggregates and table aggregates.
     *
     * @param <T>   the type of the aggregation result.
     * @param <ACC> the type of the aggregation accumulator. The accumulator is used to keep the
     *              aggregated values which are needed to compute an aggregation result.
     */
    public abstract class UserDefinedAggregateFunction<T, ACC> extends UserDefinedFunction {

        /**
         * Creates and init the Accumulator for this (table)aggregate function.
         *
         * @return the accumulator with the initial value
         */
        public ACC createAccumulator(); // MANDATORY

        /**
         * Returns the TypeInformation of the (table)aggregate function's result.
         *
         * @return The TypeInformation of the (table)aggregate function's result or null if the result
         *         type should be automatically inferred.
         */
        public TypeInformation<T> getResultType = null; // PRE-DEFINED

        /**
         * Returns the TypeInformation of the (table)aggregate function's accumulator.
         *
         * @return The TypeInformation of the (table)aggregate function's accumulator or null if the
         *         accumulator type should be automatically inferred.
         */
        public TypeInformation<ACC> getAccumulatorType = null; // PRE-DEFINED
    }

    /**
     * Base class for table aggregation functions.
     *
     * @param <T>   the type of the aggregation result
     * @param <ACC> the type of the aggregation accumulator. The accumulator is used to keep the
     *              aggregated values which are needed to compute a table aggregation result.
     *              TableAggregateFunction represents its state using accumulator, thereby the state of
     *              the TableAggregateFunction must be put into the accumulator.
     */
    public abstract class TableAggregateFunction<T, ACC> extends UserDefinedAggregateFunction<T, ACC> {

        /**
         * Processes the input values and update the provided accumulator instance. The method
         * accumulate can be overloaded with different custom types and arguments. A TableAggregateFunction
         * requires at least one accumulate() method.
         *
         * @param accumulator           the accumulator which contains the current aggregated results
         * @param [user defined inputs] the input value (usually obtained from a new arrived data).
         */
        public void accumulate(ACC accumulator, [user defined inputs]); // MANDATORY

        /**
         * Retracts the input values from the accumulator instance. The current design assumes the
         * inputs are the values that have been previously accumulated. The method retract can be
         * overloaded with different custom types and arguments. This function must be implemented for
         * datastream bounded over aggregate.
         *
         * @param accumulator           the accumulator which contains the current aggregated results
         * @param [user defined inputs] the input value (usually obtained from a new arrived data).
         */
        public void retract(ACC accumulator, [user defined inputs]); // OPTIONAL

        /**
         * Merges a group of accumulator instances into one accumulator instance. This function must be
         * implemented for datastream session window grouping aggregate and dataset grouping aggregate.
         *
         * @param accumulator the accumulator which will keep the merged aggregate results. It should
         *                    be noted that the accumulator may contain the previous aggregated
         *                    results. Therefore user should not replace or clean this instance in the
         *                    custom merge method.
         * @param its         an {@link java.lang.Iterable} pointed to a group of accumulators that will be
         *                    merged.
         */
        public void merge(ACC accumulator, java.lang.Iterable<ACC> its); // OPTIONAL

        /**
         * Called every time when an aggregation result should be materialized. The returned value
         * could be either an early and incomplete result (periodically emitted as data arrive) or
         * the final result of the aggregation.
         *
         * @param accumulator the accumulator which contains the current
         *                    aggregated results
         * @param out         the collector used to output data
         */
        public void emitValue(ACC accumulator, Collector<T> out); // OPTIONAL

        /**
         * Called every time when an aggregation result should be materialized. The returned value
         * could be either an early and incomplete result (periodically emitted as data arrive) or
         * the final result of the aggregation.
         *
         * Different from emitValue, emitUpdateWithRetract is used to emit values that have been updated.
         * This method outputs data incrementally in retract mode, i.e., once there is an update, we
         * have to retract old records before sending new updated ones. The emitUpdateWithRetract
         * method will be used in preference to the emitValue method if both methods are defined in the
         * table aggregate function, because the method is treated to be more efficient than emitValue
         * as it can output values incrementally.
         *
         * @param accumulator the accumulator which contains the current
         *                    aggregated results
         * @param out         the retractable collector used to output data. Use collect method
         *                    to output(add) records and use retract method to retract(delete)
         *                    records.
         */
        public void emitUpdateWithRetract(ACC accumulator, RetractableCollector<T> out); // OPTIONAL

        /**
         * Collects a record and forwards it. The collector can output retract messages with the retract
         * method. Note: only use it in {@code emitRetractValueIncrementally}.
         */
        public interface RetractableCollector<T> extends Collector<T> {

            /**
             * Retract a record.
             *
             * @param record The record to retract.
             */
            void retract(T record);
        }
    }
    /**
     * Base class for user-defined aggregates and table aggregates.
     *
     * @tparam T   the type of the aggregation result.
     * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
     *             aggregated values which are needed to compute an aggregation result.
     */
    abstract class UserDefinedAggregateFunction[T, ACC] extends UserDefinedFunction {

      /**
       * Creates and init the Accumulator for this (table)aggregate function.
       *
       * @return the accumulator with the initial value
       */
      def createAccumulator(): ACC // MANDATORY

      /**
       * Returns the TypeInformation of the (table)aggregate function's result.
       *
       * @return The TypeInformation of the (table)aggregate function's result or null if the result
       *         type should be automatically inferred.
       */
      def getResultType: TypeInformation[T] = null // PRE-DEFINED

      /**
       * Returns the TypeInformation of the (table)aggregate function's accumulator.
       *
       * @return The TypeInformation of the (table)aggregate function's accumulator or null if the
       *         accumulator type should be automatically inferred.
       */
      def getAccumulatorType: TypeInformation[ACC] = null // PRE-DEFINED
    }

    /**
     * Base class for table aggregation functions.
     *
     * @tparam T   the type of the aggregation result
     * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
     *             aggregated values which are needed to compute an aggregation result.
     *             TableAggregateFunction represents its state using accumulator, thereby the state of
     *             the TableAggregateFunction must be put into the accumulator.
     */
    abstract class TableAggregateFunction[T, ACC] extends UserDefinedAggregateFunction[T, ACC] {

      /**
       * Processes the input values and update the provided accumulator instance. The method
       * accumulate can be overloaded with different custom types and arguments. A TableAggregateFunction
       * requires at least one accumulate() method.
       *
       * @param accumulator           the accumulator which contains the current aggregated results
       * @param [user defined inputs] the input value (usually obtained from a new arrived data).
       */
      def accumulate(accumulator: ACC, [user defined inputs]): Unit // MANDATORY

      /**
       * Retracts the input values from the accumulator instance. The current design assumes the
       * inputs are the values that have been previously accumulated. The method retract can be
       * overloaded with different custom types and arguments. This function must be implemented for
       * datastream bounded over aggregate.
       *
       * @param accumulator           the accumulator which contains the current aggregated results
       * @param [user defined inputs] the input value (usually obtained from a new arrived data).
       */
      def retract(accumulator: ACC, [user defined inputs]): Unit // OPTIONAL

      /**
       * Merges a group of accumulator instances into one accumulator instance. This function must be
       * implemented for datastream session window grouping aggregate and dataset grouping aggregate.
       *
       * @param accumulator the accumulator which will keep the merged aggregate results. It should
       *                    be noted that the accumulator may contain the previous aggregated
       *                    results. Therefore user should not replace or clean this instance in the
       *                    custom merge method.
       * @param its         an [[java.lang.Iterable]] pointed to a group of accumulators that will be
       *                    merged.
       */
      def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit // OPTIONAL

      /**
       * Called every time when an aggregation result should be materialized. The returned value
       * could be either an early and incomplete result (periodically emitted as data arrive) or
       * the final result of the aggregation.
       *
       * @param accumulator the accumulator which contains the current
       *                    aggregated results
       * @param out         the collector used to output data
       */
      def emitValue(accumulator: ACC, out: Collector[T]): Unit // OPTIONAL

      /**
       * Called every time when an aggregation result should be materialized. The returned value
       * could be either an early and incomplete result (periodically emitted as data arrive) or
       * the final result of the aggregation.
       *
       * Different from emitValue, emitUpdateWithRetract is used to emit values that have been updated.
       * This method outputs data incrementally in retract mode, i.e., once there is an update, we
       * have to retract old records before sending new updated ones. The emitUpdateWithRetract
       * method will be used in preference to the emitValue method if both methods are defined in the
       * table aggregate function, because the method is treated to be more efficient than emitValue
       * as it can output values incrementally.
       *
       * @param accumulator the accumulator which contains the current
       *                    aggregated results
       * @param out         the retractable collector used to output data. Use collect method
       *                    to output(add) records and use retract method to retract(delete)
       *                    records.
       */
      def emitUpdateWithRetract(accumulator: ACC, out: RetractableCollector[T]): Unit // OPTIONAL

      /**
       * Collects a record and forwards it. The collector can output retract messages with the retract
       * method. Note: only use it in `emitRetractValueIncrementally`.
       */
      trait RetractableCollector[T] extends Collector[T] {

        /**
         * Retract a record.
         *
         * @param record The record to retract.
         */
        def retract(record: T): Unit
      }
    }

The following example shows how to

  • define a TableAggregateFunction that calculates the top 2 values on a given column,
  • register the function in the TableEnvironment, and
  • use the function in a Table API query (TableAggregateFunction is only supported by the Table API).

To calculate the top 2 values, the accumulator needs to store the 2 biggest values of all the data that has been accumulated. In our example we define a class Top2Accum to be the accumulator. Accumulators are automatically backed up by Flink’s checkpointing mechanism and restored in case of a failure to ensure exactly-once semantics.

The accumulate() method of our Top2 TableAggregateFunction has two inputs. The first one is the Top2Accum accumulator, the other one is the user-defined input value v. Although the merge() method is not mandatory for most table aggregation types, we provide it below as an example. Please note that we used Java primitive types and defined getResultType() and getAccumulatorType() methods in the Scala example because Flink type extraction does not work very well for Scala types.

    /**
     * Accumulator for Top2.
     */
    public class Top2Accum {
        public Integer first;
        public Integer second;
    }

    /**
     * The top2 user-defined table aggregate function.
     */
    public static class Top2 extends TableAggregateFunction<Tuple2<Integer, Integer>, Top2Accum> {

        @Override
        public Top2Accum createAccumulator() {
            Top2Accum acc = new Top2Accum();
            acc.first = Integer.MIN_VALUE;
            acc.second = Integer.MIN_VALUE;
            return acc;
        }

        public void accumulate(Top2Accum acc, Integer v) {
            if (v > acc.first) {
                acc.second = acc.first;
                acc.first = v;
            } else if (v > acc.second) {
                acc.second = v;
            }
        }

        public void merge(Top2Accum acc, java.lang.Iterable<Top2Accum> iterable) {
            for (Top2Accum otherAcc : iterable) {
                accumulate(acc, otherAcc.first);
                accumulate(acc, otherAcc.second);
            }
        }

        public void emitValue(Top2Accum acc, Collector<Tuple2<Integer, Integer>> out) {
            // emit the value and rank
            if (acc.first != Integer.MIN_VALUE) {
                out.collect(Tuple2.of(acc.first, 1));
            }
            if (acc.second != Integer.MIN_VALUE) {
                out.collect(Tuple2.of(acc.second, 2));
            }
        }
    }

    // register function
    StreamTableEnvironment tEnv = ...
    tEnv.registerFunction("top2", new Top2());

    // init table
    Table tab = ...;

    // use function
    tab.groupBy("key")
        .flatAggregate("top2(a) as (v, rank)")
        .select("key, v, rank");
  1. import java.lang.{Integer => JInteger}
  2. import org.apache.flink.table.api.Types
  3. import org.apache.flink.table.functions.TableAggregateFunction
  4. /**
  5. * Accumulator for top2.
  6. */
  7. class Top2Accum {
  8. var first: JInteger = _
  9. var second: JInteger = _
  10. }
  11. /**
  12. * The top2 user-defined table aggregate function.
  13. */
  14. class Top2 extends TableAggregateFunction[JTuple2[JInteger, JInteger], Top2Accum] {
  15. override def createAccumulator(): Top2Accum = {
  16. val acc = new Top2Accum
  17. acc.first = Int.MinValue
  18. acc.second = Int.MinValue
  19. acc
  20. }
  21. def accumulate(acc: Top2Accum, v: Int) {
  22. if (v > acc.first) {
  23. acc.second = acc.first
  24. acc.first = v
  25. } else if (v > acc.second) {
  26. acc.second = v
  27. }
  28. }
  29. def merge(acc: Top2Accum, its: JIterable[Top2Accum]): Unit = {
  30. val iter = its.iterator()
  31. while (iter.hasNext) {
  32. val top2 = iter.next()
  33. accumulate(acc, top2.first)
  34. accumulate(acc, top2.second)
  35. }
  36. }
  37. def emitValue(acc: Top2Accum, out: Collector[JTuple2[JInteger, JInteger]]): Unit = {
  38. // emit the value and rank
  39. if (acc.first != Int.MinValue) {
  40. out.collect(JTuple2.of(acc.first, 1))
  41. }
  42. if (acc.second != Int.MinValue) {
  43. out.collect(JTuple2.of(acc.second, 2))
  44. }
  45. }
  46. }
  47. // init table
  48. val tab = ...
  49. // instantiate the function
  50. val top2 = new Top2
  51. // use function
  52. tab
  53. .groupBy('key)
  54. .flatAggregate(top2('a) as ('v, 'rank))
  55. .select('key, 'v, 'rank)

The following example shows how to use the emitUpdateWithRetract method to emit only updates. To emit only updates, the accumulator in this example keeps both the old and the new top 2 values. Note: if N (of the top N) is large, it may be inefficient to keep both the old and the new values. One way to solve this is to store only the input records in the accumulator in the accumulate method and then perform the calculation in emitUpdateWithRetract (a sketch of this alternative follows the two listings below).

  1. /**
  2. * Accumulator for Top2.
  3. */
  4. public class Top2Accum {
  5. public Integer first;
  6. public Integer second;
  7. public Integer oldFirst;
  8. public Integer oldSecond;
  9. }
  10. /**
  11. * The top2 user-defined table aggregate function.
  12. */
  13. public static class Top2 extends TableAggregateFunction<Tuple2<Integer, Integer>, Top2Accum> {
  14. @Override
  15. public Top2Accum createAccumulator() {
  16. Top2Accum acc = new Top2Accum();
  17. acc.first = Integer.MIN_VALUE;
  18. acc.second = Integer.MIN_VALUE;
  19. acc.oldFirst = Integer.MIN_VALUE;
  20. acc.oldSecond = Integer.MIN_VALUE;
  21. return acc;
  22. }
  23. public void accumulate(Top2Accum acc, Integer v) {
  24. if (v > acc.first) {
  25. acc.second = acc.first;
  26. acc.first = v;
  27. } else if (v > acc.second) {
  28. acc.second = v;
  29. }
  30. }
  31. public void emitUpdateWithRetract(Top2Accum acc, RetractableCollector<Tuple2<Integer, Integer>> out) {
  32. if (!acc.first.equals(acc.oldFirst)) {
  33. // if there is an update, retract old value then emit new value.
  34. if (acc.oldFirst != Integer.MIN_VALUE) {
  35. out.retract(Tuple2.of(acc.oldFirst, 1));
  36. }
  37. out.collect(Tuple2.of(acc.first, 1));
  38. acc.oldFirst = acc.first;
  39. }
  40. if (!acc.second.equals(acc.oldSecond)) {
  41. // if there is an update, retract old value then emit new value.
  42. if (acc.oldSecond != Integer.MIN_VALUE) {
  43. out.retract(Tuple2.of(acc.oldSecond, 2));
  44. }
  45. out.collect(Tuple2.of(acc.second, 2));
  46. acc.oldSecond = acc.second;
  47. }
  48. }
  49. }
  50. // register function
  51. StreamTableEnvironment tEnv = ...
  52. tEnv.registerFunction("top2", new Top2());
  53. // init table
  54. Table tab = ...;
  55. // use function
  56. tab.groupBy("key")
  57. .flatAggregate("top2(a) as (v, rank)")
  58. .select("key, v, rank");
  1. import java.lang.{Integer => JInteger}
  2. import org.apache.flink.table.api.Types
  3. import org.apache.flink.table.functions.TableAggregateFunction
  4. /**
  5. * Accumulator for top2.
  6. */
  7. class Top2Accum {
  8. var first: JInteger = _
  9. var second: JInteger = _
  10. var oldFirst: JInteger = _
  11. var oldSecond: JInteger = _
  12. }
  13. /**
  14. * The top2 user-defined table aggregate function.
  15. */
  16. class Top2 extends TableAggregateFunction[JTuple2[JInteger, JInteger], Top2Accum] {
  17. override def createAccumulator(): Top2Accum = {
  18. val acc = new Top2Accum
  19. acc.first = Int.MinValue
  20. acc.second = Int.MinValue
  21. acc.oldFirst = Int.MinValue
  22. acc.oldSecond = Int.MinValue
  23. acc
  24. }
  25. def accumulate(acc: Top2Accum, v: Int) {
  26. if (v > acc.first) {
  27. acc.second = acc.first
  28. acc.first = v
  29. } else if (v > acc.second) {
  30. acc.second = v
  31. }
  32. }
  33. def emitUpdateWithRetract(
  34. acc: Top2Accum,
  35. out: RetractableCollector[JTuple2[JInteger, JInteger]])
  36. : Unit = {
  37. if (acc.first != acc.oldFirst) {
  38. // if there is an update, retract old value then emit new value.
  39. if (acc.oldFirst != Int.MinValue) {
  40. out.retract(JTuple2.of(acc.oldFirst, 1))
  41. }
  42. out.collect(JTuple2.of(acc.first, 1))
  43. acc.oldFirst = acc.first
  44. }
  45. if (acc.second != acc.oldSecond) {
  46. // if there is an update, retract old value then emit new value.
  47. if (acc.oldSecond != Int.MinValue) {
  48. out.retract(JTuple2.of(acc.oldSecond, 2))
  49. }
  50. out.collect(JTuple2.of(acc.second, 2))
  51. acc.oldSecond = acc.second
  52. }
  53. }
  54. }
  55. // init table
  56. val tab = ...
  57. // instantiate the function
  58. val top2 = new Top2
  59. // use function
  60. tab
  61. .groupBy('key)
  62. .flatAggregate(top2('a) as ('v, 'rank))
  63. .select('key, 'v, 'rank)
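
The alternative mentioned above (storing the raw inputs in the accumulator and doing the ranking only in emitUpdateWithRetract) could look roughly like the following Java sketch. The TopNAccum and TopN names, the fixed N, and the plain List/Map accumulator fields are assumptions made for this illustration rather than part of Flink’s API; a real implementation would also need to consider accumulator state size and retraction of inputs.

  import java.util.ArrayList;
  import java.util.Collections;
  import java.util.HashMap;
  import java.util.List;
  import java.util.Map;

  import org.apache.flink.api.java.tuple.Tuple2;
  import org.apache.flink.table.functions.TableAggregateFunction;

  /** Hypothetical accumulator: stores all inputs plus the values emitted last time (rank -> value). */
  public class TopNAccum {
      public List<Integer> inputs = new ArrayList<>();
      public Map<Integer, Integer> emitted = new HashMap<>();
  }

  /** Sketch of a top-N table aggregate that defers all ranking work to emitUpdateWithRetract. */
  public class TopN extends TableAggregateFunction<Tuple2<Integer, Integer>, TopNAccum> {

      private static final int N = 2;  // illustrative constant; could be a constructor parameter

      @Override
      public TopNAccum createAccumulator() {
          return new TopNAccum();
      }

      public void accumulate(TopNAccum acc, Integer v) {
          // only remember the input; no ranking work on the per-record path
          acc.inputs.add(v);
      }

      public void emitUpdateWithRetract(TopNAccum acc, RetractableCollector<Tuple2<Integer, Integer>> out) {
          // recompute the current top N from the stored inputs
          List<Integer> sorted = new ArrayList<>(acc.inputs);
          sorted.sort(Collections.reverseOrder());

          for (int rank = 1; rank <= N && rank <= sorted.size(); rank++) {
              Integer newValue = sorted.get(rank - 1);
              Integer oldValue = acc.emitted.get(rank);
              if (!newValue.equals(oldValue)) {
                  // retract the value previously emitted for this rank, then emit the new one
                  if (oldValue != null) {
                      out.retract(Tuple2.of(oldValue, rank));
                  }
                  out.collect(Tuple2.of(newValue, rank));
                  acc.emitted.put(rank, newValue);
              }
          }
      }
  }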

Best Practices for Implementing UDFs

The Table API and SQL code generation internally tries to work with primitive values as much as possible. A user-defined function can introduce significant overhead through object creation, casting, and (un)boxing. Therefore, it is highly recommended to declare parameters and result types as primitive types instead of their boxed classes. Types.DATE and Types.TIME can also be represented as int. Types.TIMESTAMP can be represented as long.
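
As a rough illustration of this guideline, the two hypothetical scalar functions below are functionally equivalent, but the first declares primitive parameter and result types and therefore avoids object creation and (un)boxing in the generated code (the class names and the millisecond-difference logic are made up for this example):

  import org.apache.flink.table.functions.ScalarFunction;

  // preferred: primitive parameter and result types
  public class MillisBetween extends ScalarFunction {
      public long eval(long start, long end) {
          return end - start;
      }
  }

  // discouraged: boxed classes force additional object creation and (un)boxing
  public class MillisBetweenBoxed extends ScalarFunction {
      public Long eval(Long start, Long end) {
          return end - start;
      }
  }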

We recommend that user-defined functions be written in Java instead of Scala, as Scala types pose a challenge for Flink’s type extractor.

Integrating UDFs with the Runtime

Sometimes it might be necessary for a user-defined function to obtain global runtime information or to do some setup/cleanup work before the actual evaluation. User-defined functions provide open() and close() methods that can be overridden and that offer functionality similar to the methods in RichFunction of the DataSet or DataStream API.

The open() method is called once before the evaluation method. The close() method is called after the last call to the evaluation method.

The open() method provides a FunctionContext that contains information about the context in which user-defined functions are executed, such as the metric group, the distributed cache files, or the global job parameters.

The following information can be obtained by calling the corresponding methods of FunctionContext:

  • getMetricGroup(): Metric group for this parallel subtask.
  • getCachedFile(name): Local temporary file copy of a distributed cache file.
  • getJobParameter(name, defaultValue): Global job parameter value associated with the given key.
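
For completeness, the sketch below shows how the other two methods might be used inside open(). The function itself, the metric name "blacklistHits", and the cached-file name "blacklist" (which would have to be registered up front, e.g. via env.registerCachedFile(...)) are assumptions made for this example:

  import java.io.File;
  import java.nio.file.Files;
  import java.util.HashSet;
  import java.util.Set;

  import org.apache.flink.metrics.Counter;
  import org.apache.flink.table.functions.FunctionContext;
  import org.apache.flink.table.functions.ScalarFunction;

  // hypothetical function: checks whether a value is contained in a blacklist file
  public class IsBlacklisted extends ScalarFunction {

      private transient Counter hits;            // counts positive matches
      private transient Set<String> blacklist;   // contents of the cached file

      @Override
      public void open(FunctionContext context) throws Exception {
          // register a counter on the metric group of this parallel subtask
          hits = context.getMetricGroup().counter("blacklistHits");

          // obtain the local copy of the distributed cache file and load it
          File file = context.getCachedFile("blacklist");
          blacklist = new HashSet<>(Files.readAllLines(file.toPath()));
      }

      public boolean eval(String value) {
          boolean contained = blacklist.contains(value);
          if (contained) {
              hits.inc();
          }
          return contained;
      }
  }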

The following example snippet shows how to use FunctionContext in a scalar function for accessing a global job parameter:

  1. public class HashCode extends ScalarFunction {
  2. private int factor = 0;
  3. @Override
  4. public void open(FunctionContext context) throws Exception {
  5. // access "hashcode_factor" parameter
  6. // "12" would be the default value if parameter does not exist
  7. factor = Integer.valueOf(context.getJobParameter("hashcode_factor", "12"));
  8. }
  9. public int eval(String s) {
  10. return s.hashCode() * factor;
  11. }
  12. }
  13. ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
  14. BatchTableEnvironment tableEnv = BatchTableEnvironment.create(env);
  15. // set job parameter
  16. Configuration conf = new Configuration();
  17. conf.setString("hashcode_factor", "31");
  18. env.getConfig().setGlobalJobParameters(conf);
  19. // register the function
  20. tableEnv.registerFunction("hashCode", new HashCode());
  21. // use the function in Java Table API
  22. myTable.select("string, string.hashCode(), hashCode(string)");
  23. // use the function in SQL
  24. tableEnv.sqlQuery("SELECT string, HASHCODE(string) FROM MyTable");
  1. object hashCode extends ScalarFunction {
  2. var hashcode_factor = 12
  3. override def open(context: FunctionContext): Unit = {
  4. // access "hashcode_factor" parameter
  5. // "12" would be the default value if parameter does not exist
  6. hashcode_factor = context.getJobParameter("hashcode_factor", "12").toInt
  7. }
  8. def eval(s: String): Int = {
  9. s.hashCode() * hashcode_factor
  10. }
  11. }
  12. val tableEnv = BatchTableEnvironment.create(env)
  13. // use the function in Scala Table API
  14. myTable.select('string, hashCode('string))
  15. // register and use the function in SQL
  16. tableEnv.registerFunction("hashCode", hashCode)
  17. tableEnv.sqlQuery("SELECT string, HASHCODE(string) FROM MyTable")