Working with State
In this section you will learn about the APIs that Flink provides for writing stateful programs. Please take a look at Stateful Stream Processing to learn about the concepts behind stateful stream processing.
Keyed DataStream
If you want to use keyed state, you first need to specify a key on a DataStream that should be used to partition the state (and also the records in the stream themselves). You can specify a key using keyBy(KeySelector) in the Java/Scala API or key_by(KeySelector) in the Python API on a DataStream. This will yield a KeyedStream, which then allows operations that use keyed state.
A key selector function takes a single record as input and returns the key for that record. The key can be of any type and must be derived from deterministic computations.
The data model of Flink is not based on key-value pairs. Therefore, you do not need to physically pack the data set types into keys and values. Keys are “virtual”: they are defined as functions over the actual data to guide the grouping operator.
The following example shows a key selector function that simply returns the field of an object:
Java
// some ordinary POJO
public class WC {
public String word;
public int count;
public String getWord() { return word; }
}
DataStream<WC> words = // [...]
KeyedStream<WC, String> keyed = words
  .keyBy(WC::getWord);
Scala
// some ordinary case class
case class WC(word: String, count: Int)
val words: DataStream[WC] = // [...]
val keyed = words.keyBy( _.word )
Python
words = ...  # type: DataStream[Row]
keyed = words.key_by(lambda row: row[0])
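To make the idea of "virtual" keys concrete, here is a minimal plain-Python sketch that does not use Flink at all. It only illustrates that a key is a function over the record, not a field packed into a key-value pair. The helper `group_by_key` and the sample records are hypothetical; in a real job, `key_by` additionally partitions the records across parallel instances.

```python
from collections import defaultdict


def group_by_key(records, key_selector):
    """Group records by the key that key_selector derives from each record."""
    groups = defaultdict(list)
    for record in records:
        groups[key_selector(record)].append(record)
    return dict(groups)


# Records stay plain (word, count) tuples; nothing is packed into key/value pairs.
records = [("flink", 1), ("state", 2), ("flink", 3)]

# Key on the word field.
by_word = group_by_key(records, lambda r: r[0])

# A composite key is simply a selector that returns a tuple.
by_word_and_parity = group_by_key(records, lambda r: (r[0], r[1] % 2))
```

Because the selector must be deterministic, the same record always lands in the same group, which is what allows state to be scoped per key.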
Tuple Keys and Expression Keys
Flink also has two alternative ways of defining keys in the Java/Scala API: tuple keys and expression keys (neither is supported in the Python API yet). With these you can specify keys using tuple field indices or expressions for selecting fields of objects. We don’t recommend using them today, but you can refer to the Javadoc of DataStream to learn about them. Using a KeySelector function is strictly superior: with Java lambdas, key selectors are easy to use and potentially have less overhead at runtime.
Using Keyed State
The keyed state interfaces provide access to different types of state that are all scoped to the key of the current input element. This means that this type of state can only be used on a KeyedStream, which can be created via stream.keyBy(…) in the Java/Scala API or stream.key_by(…) in the Python API.
Now, we will first look at the different types of state available and then we will see how they can be used in a program. The available state primitives are:
- ValueState<T>: This keeps a value that can be updated and retrieved (scoped to the key of the input element as mentioned above, so there will possibly be one value for each key that the operation sees). The value can be set using update(T) and retrieved using T value().
- ListState<T>: This keeps a list of elements. You can append elements and retrieve an Iterable over all currently stored elements. Elements are added using add(T) or addAll(List<T>), and the Iterable can be retrieved using Iterable<T> get(). You can also override the existing list with update(List<T>).
- ReducingState<T>: This keeps a single value that represents the aggregation of all values added to the state. The interface is similar to ListState, but elements added using add(T) are reduced to an aggregate using a specified ReduceFunction.
- AggregatingState<IN, OUT>: This keeps a single value that represents the aggregation of all values added to the state. Contrary to ReducingState, the aggregate type may be different from the type of elements that are added to the state. The interface is the same as for ListState, but elements added using add(IN) are aggregated using a specified AggregateFunction.
- MapState<UK, UV>: This keeps a list of mappings. You can put key-value pairs into the state and retrieve an Iterable over all currently stored mappings. Mappings are added using put(UK, UV) or putAll(Map<UK, UV>). The value associated with a user key can be retrieved using get(UK). The iterable views for mappings, keys and values can be retrieved using entries(), keys() and values() respectively. You can also use isEmpty() to check whether this map contains any key-value mappings.
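The difference between ReducingState and AggregatingState can be sketched in plain Python. This is a toy model under simplifying assumptions (no key scoping, no persistence, no Flink classes); the names `ToyReducingState` and `ToyAggregatingState` and their constructor parameters are hypothetical and only loosely mirror the roles of a ReduceFunction and an AggregateFunction.

```python
class ToyReducingState:
    """Keeps one value of the same type as the inputs, combined by reduce_fn."""

    def __init__(self, reduce_fn):
        self._reduce_fn = reduce_fn
        self._value = None

    def add(self, element):
        if self._value is None:
            self._value = element
        else:
            self._value = self._reduce_fn(self._value, element)

    def get(self):
        return self._value

    def clear(self):
        self._value = None


class ToyAggregatingState:
    """Keeps an accumulator whose type may differ from both input and output."""

    def __init__(self, create_accumulator, add_fn, get_result):
        self._create_accumulator = create_accumulator
        self._add_fn = add_fn
        self._get_result = get_result
        self._acc = None

    def add(self, element):
        if self._acc is None:
            self._acc = self._create_accumulator()
        self._acc = self._add_fn(self._acc, element)

    def get(self):
        return None if self._acc is None else self._get_result(self._acc)

    def clear(self):
        self._acc = None


# Reducing: inputs and result are both ints.
summed = ToyReducingState(lambda a, b: a + b)
for x in (3, 5, 7):
    summed.add(x)

# Aggregating: int inputs, (sum, count) accumulator, float result.
average = ToyAggregatingState(
    create_accumulator=lambda: (0, 0),
    add_fn=lambda acc, x: (acc[0] + x, acc[1] + 1),
    get_result=lambda acc: acc[0] / acc[1],
)
for x in (3, 5, 7):
    average.add(x)
```

Here `summed.get()` yields the integer sum, while `average.get()` yields a float computed from an accumulator of a different type, which is the distinction the list above describes.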
All types of state also have a method clear() that clears the state for the currently active key, i.e. the key of the input element.
It is important to keep in mind that these state objects are only used for interfacing with state. The state is not necessarily stored inside but might reside on disk or somewhere else. The second thing to keep in mind is that the value you get from the state depends on the key of the input element. So the value you get in one invocation of your user function can differ from the value in another invocation if the keys involved are different.
To get a state handle, you have to create a StateDescriptor. This holds the name of the state (as we will see later, you can create several states, and they have to have unique names so that you can reference them), the type of the values that the state holds, and possibly a user-specified function, such as a ReduceFunction. Depending on what type of state you want to retrieve, you create either a ValueStateDescriptor, a ListStateDescriptor, an AggregatingStateDescriptor, a ReducingStateDescriptor, or a MapStateDescriptor.
State is accessed using the RuntimeContext, so it is only possible in rich functions. See the documentation on rich functions for details; an example also follows shortly. The RuntimeContext that is available in a RichFunction has these methods for accessing state:
ValueState<T> getState(ValueStateDescriptor<T>)
ReducingState<T> getReducingState(ReducingStateDescriptor<T>)
ListState<T> getListState(ListStateDescriptor<T>)
AggregatingState<IN, OUT> getAggregatingState(AggregatingStateDescriptor<IN, ACC, OUT>)
MapState<UK, UV> getMapState(MapStateDescriptor<UK, UV>)
This is an example FlatMapFunction that shows how all of the parts fit together:
Java
public class CountWindowAverage extends RichFlatMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>> {
/**
* The ValueState handle. The first field is the count, the second field a running sum.
*/
private transient ValueState<Tuple2<Long, Long>> sum;
@Override
public void flatMap(Tuple2<Long, Long> input, Collector<Tuple2<Long, Long>> out) throws Exception {
// access the state value
Tuple2<Long, Long> currentSum = sum.value();
// update the count
currentSum.f0 += 1;
// add the second field of the input value
currentSum.f1 += input.f1;
// update the state
sum.update(currentSum);
// if the count reaches 2, emit the average and clear the state
if (currentSum.f0 >= 2) {
out.collect(new Tuple2<>(input.f0, currentSum.f1 / currentSum.f0));
sum.clear();
}
}
@Override
public void open(Configuration config) {
ValueStateDescriptor<Tuple2<Long, Long>> descriptor =
new ValueStateDescriptor<>(
"average", // the state name
TypeInformation.of(new TypeHint<Tuple2<Long, Long>>() {}), // type information
Tuple2.of(0L, 0L)); // default value of the state, if nothing was set
sum = getRuntimeContext().getState(descriptor);
}
}
// this can be used in a streaming program like this (assuming we have a StreamExecutionEnvironment env)
env.fromElements(Tuple2.of(1L, 3L), Tuple2.of(1L, 5L), Tuple2.of(1L, 7L), Tuple2.of(1L, 4L), Tuple2.of(1L, 2L))
.keyBy(value -> value.f0)
.flatMap(new CountWindowAverage())
.print();
// the printed output will be (1,4) and (1,5)
Scala
class CountWindowAverage extends RichFlatMapFunction[(Long, Long), (Long, Long)] {
private var sum: ValueState[(Long, Long)] = _
override def flatMap(input: (Long, Long), out: Collector[(Long, Long)]): Unit = {
// access the state value
val tmpCurrentSum = sum.value
// If it hasn't been used before, it will be null
val currentSum = if (tmpCurrentSum != null) {
tmpCurrentSum
} else {
(0L, 0L)
}
// update the count
val newSum = (currentSum._1 + 1, currentSum._2 + input._2)
// update the state
sum.update(newSum)
// if the count reaches 2, emit the average and clear the state
if (newSum._1 >= 2) {
out.collect((input._1, newSum._2 / newSum._1))
sum.clear()
}
}
override def open(parameters: Configuration): Unit = {
sum = getRuntimeContext.getState(
new ValueStateDescriptor[(Long, Long)]("average", createTypeInformation[(Long, Long)])
)
}
}
object ExampleCountWindowAverage extends App {
val env = StreamExecutionEnvironment.getExecutionEnvironment
env.fromCollection(List(
(1L, 3L),
(1L, 5L),
(1L, 7L),
(1L, 4L),
(1L, 2L)
)).keyBy(_._1)
.flatMap(new CountWindowAverage())
.print()
// the printed output will be (1,4) and (1,5)
env.execute("ExampleKeyedState")
}
Python
from pyflink.common.typeinfo import Types
from pyflink.datastream import StreamExecutionEnvironment, FlatMapFunction, RuntimeContext
from pyflink.datastream.state import ValueStateDescriptor
class CountWindowAverage(FlatMapFunction):

    def __init__(self):
        self.sum = None

    def open(self, runtime_context: RuntimeContext):
        descriptor = ValueStateDescriptor(
            "average",  # the state name
            Types.PICKLED_BYTE_ARRAY()  # type information
        )
        self.sum = runtime_context.get_state(descriptor)

    def flat_map(self, value):
        # access the state value
        current_sum = self.sum.value()
        if current_sum is None:
            current_sum = (0, 0)
        # update the count
        current_sum = (current_sum[0] + 1, current_sum[1] + value[1])
        # update the state
        self.sum.update(current_sum)
        # if the count reaches 2, emit the average and clear the state
        if current_sum[0] >= 2:
            self.sum.clear()
            yield value[0], int(current_sum[1] / current_sum[0])

env = StreamExecutionEnvironment.get_execution_environment()
env.from_collection([(1, 3), (1, 5), (1, 7), (1, 4), (1, 2)]) \
.key_by(lambda row: row[0]) \
.flat_map(CountWindowAverage()) \
.print()
env.execute()
# the printed output will be (1,4) and (1,5)
This example implements a poor man’s counting window. We key the tuples by the first field (in the example all have the same key 1). The function stores the count and a running sum in a ValueState. Once the count reaches 2 it will emit the average and clear the state so that we start over from 0. Note that this would keep a different state value for each different input key if we had tuples with different values in the first field.
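The per-key behaviour can be traced with a small plain-Python simulation. This is only a sketch of the logic, not Flink code: a dictionary stands in for the keyed ValueState, and the function name `count_window_average` and the two-key input are made up for illustration.

```python
def count_window_average(records, window_size=2):
    """Emit (key, average) each time a key has seen window_size records."""
    state = {}   # key -> (count, running_sum); stands in for keyed ValueState
    output = []
    for key, value in records:
        count, total = state.get(key, (0, 0))
        count, total = count + 1, total + value
        if count >= window_size:
            # Integer division mirrors the long arithmetic of the Java example.
            output.append((key, total // count))
            state.pop(key, None)   # clear the state for this key only
        else:
            state[key] = (count, total)
    return output, state


# Two keys interleaved: each key keeps its own count and running sum.
output, remaining = count_window_average(
    [(1, 3), (2, 10), (1, 5), (2, 20), (1, 7)]
)
```

With this input, key 1 emits after its second record and key 2 emits independently after its own second record; the last record for key 1 stays in state waiting for a partner.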
State Time-To-Live (TTL)
A time-to-live (TTL) can be assigned to the keyed state of any type. If a TTL is configured and a state value has expired, the stored value will be cleaned up on a best effort basis which is discussed in more detail below.
All state collection types support per-entry TTLs. This means that list elements and map entries expire independently.
In order to use state TTL one must first build a StateTtlConfig configuration object. The TTL functionality can then be enabled in any state descriptor by passing the configuration:
Java
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.time.Time;
StateTtlConfig ttlConfig = StateTtlConfig
.newBuilder(Time.seconds(1))
.setUpdateType(StateTtlConfig.UpdateType.OnCreateAndWrite)
.setStateVisibility(StateTtlConfig.StateVisibility.NeverReturnExpired)
.build();
ValueStateDescriptor<String> stateDescriptor = new ValueStateDescriptor<>("text state", String.class);
stateDescriptor.enableTimeToLive(ttlConfig);
Scala
import org.apache.flink.api.common.state.StateTtlConfig
import org.apache.flink.api.common.state.ValueStateDescriptor
import org.apache.flink.api.common.time.Time
val ttlConfig = StateTtlConfig
.newBuilder(Time.seconds(1))
.setUpdateType(StateTtlConfig.UpdateType.OnCreateAndWrite)
.setStateVisibility(StateTtlConfig.StateVisibility.NeverReturnExpired)
.build
val stateDescriptor = new ValueStateDescriptor[String]("text state", classOf[String])
stateDescriptor.enableTimeToLive(ttlConfig)
Python
from pyflink.common.time import Time
from pyflink.common.typeinfo import Types
from pyflink.datastream.state import ValueStateDescriptor, StateTtlConfig
ttl_config = StateTtlConfig \
.new_builder(Time.seconds(1)) \
.set_update_type(StateTtlConfig.UpdateType.OnCreateAndWrite) \
.set_state_visibility(StateTtlConfig.StateVisibility.NeverReturnExpired) \
.build()
state_descriptor = ValueStateDescriptor("text state", Types.STRING())
state_descriptor.enable_time_to_live(ttl_config)
The configuration has several options to consider:
- The first parameter of the newBuilder method is mandatory; it is the time-to-live value.
- The update type configures when the state TTL is refreshed (by default OnCreateAndWrite):
  - StateTtlConfig.UpdateType.OnCreateAndWrite - only on creation and write access
  - StateTtlConfig.UpdateType.OnReadAndWrite - also on read access (Note: if you set the state visibility to StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp at the same time, the state read cache will be disabled, which will cause some performance loss in PyFlink)
- The state visibility configures whether the expired value is returned on read access if it is not cleaned up yet (by default NeverReturnExpired):
  - StateTtlConfig.StateVisibility.NeverReturnExpired - expired value is never returned (Note: the state read/write cache will be disabled, which will cause some performance loss in PyFlink)
  - StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp - returned if still available
In case of NeverReturnExpired, the expired state behaves as if it does not exist anymore, even if it has not been removed yet. This option is useful for use cases where data has to become unavailable for read access strictly after the TTL, e.g. applications working with privacy-sensitive data.
The other option, ReturnExpiredIfNotCleanedUp, allows expired state to be returned before it is cleaned up.
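How the update type and the state visibility interact can be sketched with a toy model in plain Python. This is an illustration under stated assumptions, not Flink's implementation: the class `ToyTtlValue`, its boolean flags, the injected clock, and the exact expiry boundary (expired once the age reaches the TTL) are all hypothetical simplifications.

```python
class ToyTtlValue:
    """Single value with a TTL, driven by an injected clock (toy model)."""

    def __init__(self, ttl, clock, refresh_on_read=False, return_expired=False):
        self._ttl = ttl
        self._clock = clock                      # callable returning "now"
        self._refresh_on_read = refresh_on_read  # models OnReadAndWrite
        self._return_expired = return_expired    # models ReturnExpiredIfNotCleanedUp
        self._value = None
        self._last_access = None

    def update(self, value):
        # Creation and writes always refresh the timestamp.
        self._value = value
        self._last_access = self._clock()

    def value(self):
        if self._last_access is None:
            return None
        if self._clock() - self._last_access >= self._ttl:
            # Expired values are removed on read; whether the caller still
            # sees the value once depends on the visibility setting.
            expired_value = self._value
            self._value, self._last_access = None, None
            return expired_value if self._return_expired else None
        if self._refresh_on_read:
            self._last_access = self._clock()
        return self._value


now = [0]
clock = lambda: now[0]

write_only = ToyTtlValue(10, clock)                         # defaults
read_refresh = ToyTtlValue(10, clock, refresh_on_read=True)
lenient = ToyTtlValue(10, clock, return_expired=True)
for state in (write_only, read_refresh, lenient):
    state.update("x")

now[0] = 6    # within the TTL for all three
at_6 = (write_only.value(), read_refresh.value(), lenient.value())
now[0] = 12   # past the TTL unless the read at time 6 refreshed it
at_12 = (write_only.value(), read_refresh.value(), lenient.value())
```

In this model the read at time 6 keeps `read_refresh` alive at time 12, the default configuration hides the expired value, and the lenient configuration hands the expired value out because no cleanup had removed it before that read.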
Notes:
- The state backends store the timestamp of the last modification along with the user value, which means that enabling this feature increases consumption of state storage. The heap state backend stores an additional Java object with a reference to the user state object and a primitive long value in memory. The RocksDB state backend adds 8 bytes per stored value, list entry or map entry.
- Only TTLs in reference to processing time are currently supported.
- Trying to restore state that was previously configured without TTL using a TTL-enabled descriptor, or vice versa, will lead to a compatibility failure and a StateMigrationException.
- The TTL configuration is not part of check- or savepoints but rather a way of how Flink treats it in the currently running job.
- The map state with TTL currently supports null user values only if the user value serializer can handle null values. If the serializer does not support null values, it can be wrapped with NullableSerializer at the cost of an extra byte in the serialized form.
- With TTL enabled, the defaultValue in StateDescriptor, which is actually already deprecated, will no longer take effect. This aims to make the semantics clearer and lets users manually manage the default value if the contents of the state are null or expired.
Cleanup of Expired State
By default, expired values are explicitly removed on read, such as ValueState#value, and periodically garbage collected in the background if supported by the configured state backend. Background cleanup can be disabled in the StateTtlConfig:
Java
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.time.Time;
StateTtlConfig ttlConfig = StateTtlConfig
.newBuilder(Time.seconds(1))
.disableCleanupInBackground()
.build();
Scala
import org.apache.flink.api.common.state.StateTtlConfig
import org.apache.flink.api.common.time.Time
val ttlConfig = StateTtlConfig
.newBuilder(Time.seconds(1))
.disableCleanupInBackground
.build
Python
from pyflink.common.time import Time
from pyflink.datastream.state import StateTtlConfig
ttl_config = StateTtlConfig \
.new_builder(Time.seconds(1)) \
.disable_cleanup_in_background() \
.build()
For more fine-grained control over specific background cleanup strategies, you can configure them separately as described below. Currently, the heap state backend relies on incremental cleanup and the RocksDB backend uses a compaction filter for background cleanup.
Cleanup in full snapshot
Additionally, you can activate cleanup at the moment a full state snapshot is taken, which reduces the snapshot's size. Under the current implementation the local state is not cleaned up, but the expired state that was removed from the snapshot will not be included when restoring from that snapshot. It can be configured in StateTtlConfig:
Java
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.time.Time;
StateTtlConfig ttlConfig = StateTtlConfig
.newBuilder(Time.seconds(1))
.cleanupFullSnapshot()
.build();
Scala
import org.apache.flink.api.common.state.StateTtlConfig
import org.apache.flink.api.common.time.Time
val ttlConfig = StateTtlConfig
.newBuilder(Time.seconds(1))
.cleanupFullSnapshot
.build
Python
from pyflink.common.time import Time
from pyflink.datastream.state import StateTtlConfig
ttl_config = StateTtlConfig \
.new_builder(Time.seconds(1)) \
.cleanup_full_snapshot() \
.build()
This option is not applicable for incremental checkpointing in the RocksDB state backend.
For existing jobs, this cleanup strategy can be activated or deactivated anytime in StateTtlConfig, e.g. after restart from savepoint.
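The effect of cleanup in a full snapshot can be pictured with a few lines of plain Python. This is a toy illustration only: the function `take_full_snapshot`, the dictionary layout, and the timestamps are hypothetical and do not reflect how Flink serializes snapshots.

```python
def take_full_snapshot(local_state, now, ttl):
    """Copy only unexpired entries; local_state itself is left untouched."""
    return {
        key: (value, last_write)
        for key, (value, last_write) in local_state.items()
        if now - last_write < ttl
    }


# key -> (value, time of last write)
local_state = {"a": ("old", 0), "b": ("fresh", 95)}
snapshot = take_full_snapshot(local_state, now=100, ttl=10)
```

The snapshot is smaller because the expired entry is filtered out, while the running job still holds it locally; a job restored from this snapshot would no longer see it.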
Incremental cleanup
Another option is to trigger cleanup of some state entries incrementally. The trigger can be a callback from each state access and/or each record processing. If this cleanup strategy is active for a certain state, the storage backend keeps a lazy global iterator for this state over all its entries. Every time incremental cleanup is triggered, the iterator is advanced. The traversed state entries are checked and expired ones are cleaned up.
This feature can be configured in StateTtlConfig:
Java
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.time.Time;
StateTtlConfig ttlConfig = StateTtlConfig
.newBuilder(Time.seconds(1))
.cleanupIncrementally(10, true)
.build();
Scala
import org.apache.flink.api.common.state.StateTtlConfig
import org.apache.flink.api.common.time.Time
val ttlConfig = StateTtlConfig
.newBuilder(Time.seconds(1))
.cleanupIncrementally(10, true)
.build
Python
from pyflink.common.time import Time
from pyflink.datastream.state import StateTtlConfig
ttl_config = StateTtlConfig \
.new_builder(Time.seconds(1)) \
.cleanup_incrementally(10, True) \
.build()
This strategy has two parameters. The first is the number of state entries checked each time cleanup is triggered. Cleanup is always triggered on each state access. The second parameter defines whether cleanup is additionally triggered for each processed record. The default background cleanup for the heap backend checks 5 entries and does not run cleanup per processed record.
Notes:
- If no access happens to the state or no records are processed, expired state will persist.
- Time spent for the incremental cleanup increases record processing latency.
- At the moment incremental cleanup is implemented only for the heap state backend. Setting it for RocksDB will have no effect.
- If the heap state backend is used with synchronous snapshotting, the global iterator keeps a copy of all keys while iterating because its implementation does not support concurrent modifications. Enabling this feature will then increase memory consumption. Asynchronous snapshotting does not have this problem.
- For existing jobs, this cleanup strategy can be activated or deactivated anytime in StateTtlConfig, e.g. after restart from savepoint.
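The lazy global iterator behind incremental cleanup can be modelled in plain Python. This is a rough sketch, not the heap backend's implementation: the class `ToyIncrementalCleanupStore`, the explicit `now` argument, and the choice to iterate over a copied key list are assumptions made to keep the example short.

```python
class ToyIncrementalCleanupStore:
    """Key/value store that checks a few entries for expiry on every access."""

    def __init__(self, ttl, cleanup_size):
        self.ttl = ttl
        self.cleanup_size = cleanup_size   # entries checked per trigger
        self.entries = {}                  # key -> (value, last_write_time)
        self._cursor = iter(())            # lazy "global iterator" over keys

    def put(self, key, value, now):
        self._cleanup_step(now)
        self.entries[key] = (value, now)

    def get(self, key, now):
        self._cleanup_step(now)
        entry = self.entries.get(key)
        if entry is None or now - entry[1] >= self.ttl:
            return None
        return entry[0]

    def _cleanup_step(self, now):
        for _ in range(self.cleanup_size):
            key = next(self._cursor, None)
            if key is None:
                # Exhausted: start a new pass over a copy of the current keys.
                self._cursor = iter(list(self.entries))
                key = next(self._cursor, None)
                if key is None:
                    return  # the store is empty
            entry = self.entries.get(key)
            if entry is not None and now - entry[1] >= self.ttl:
                del self.entries[key]


store = ToyIncrementalCleanupStore(ttl=10, cleanup_size=2)
for k in "abcdef":
    store.entries[k] = (k.upper(), 0)   # preload directly, bypassing cleanup

size_before_access = len(store.entries)  # expired entries persist without access
sizes = []
for _ in range(3):
    store.get("a", now=20)               # each access checks two entries
    sizes.append(len(store.entries))
```

Each access removes at most `cleanup_size` expired entries, so the store shrinks gradually, and nothing is removed at all while no access happens.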
Cleanup during RocksDB compaction
If the RocksDB state backend is used, a Flink-specific compaction filter will be called for the background cleanup. RocksDB periodically runs asynchronous compactions to merge state updates and reduce storage. The Flink compaction filter checks the expiration timestamp of state entries with TTL and excludes expired values.
This feature can be configured in StateTtlConfig:
Java
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.time.Time;
StateTtlConfig ttlConfig = StateTtlConfig
.newBuilder(Time.seconds(1))
.cleanupInRocksdbCompactFilter(1000)
.build();
Scala
import org.apache.flink.api.common.state.StateTtlConfig
import org.apache.flink.api.common.time.Time
val ttlConfig = StateTtlConfig
.newBuilder(Time.seconds(1))
.cleanupInRocksdbCompactFilter(1000)
.build
Python
from pyflink.common.time import Time
from pyflink.datastream.state import StateTtlConfig
ttl_config = StateTtlConfig \
.new_builder(Time.seconds(1)) \
.cleanup_in_rocksdb_compact_filter(1000) \
.build()
The RocksDB compaction filter queries the current timestamp, which is used to check expiration, from Flink each time it has processed a certain number of state entries. You can change this number by passing a custom value to the StateTtlConfig.newBuilder(...).cleanupInRocksdbCompactFilter(long queryTimeAfterNumEntries) method. Updating the timestamp more often can improve cleanup speed, but it decreases compaction performance because it uses a JNI call from native code. The default background cleanup for the RocksDB backend queries the current timestamp each time 1000 entries have been processed.
You can activate debug logs from the native code of the RocksDB filter by activating the debug level for FlinkCompactionFilter:
log4j.logger.org.rocksdb.FlinkCompactionFilter=DEBUG
Notes:
- Calling the TTL filter during compaction slows it down. The TTL filter has to parse the timestamp of last access and check its expiration for every stored state entry per key that is being compacted. For collection state types (list or map), the check is also invoked per stored element.
- If this feature is used with a list state whose elements have non-fixed byte length, the native TTL filter additionally has to call the Flink Java type serializer of the element over JNI for each state entry where at least the first element has expired, in order to determine the offset of the next unexpired element.
- For existing jobs, this cleanup strategy can be activated or deactivated anytime in StateTtlConfig, e.g. after restart from savepoint.
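The trade-off behind queryTimeAfterNumEntries can be illustrated with a toy filter in plain Python. This is not the native RocksDB filter: the class `ToyCompactionFilter`, the `compact` helper, and the counting of clock queries are hypothetical stand-ins for the JNI timestamp lookups described above.

```python
class ToyCompactionFilter:
    """Drops expired entries, refreshing its cached time every N entries."""

    def __init__(self, ttl, query_time_after_num_entries, clock):
        self.ttl = ttl
        self.query_every = query_time_after_num_entries
        self.clock = clock
        self.clock_queries = 0        # stands in for the costly JNI calls
        self._seen_since_query = 0
        self._cached_now = None

    def _current_time(self):
        if self._cached_now is None or self._seen_since_query >= self.query_every:
            self._cached_now = self.clock()
            self.clock_queries += 1
            self._seen_since_query = 0
        self._seen_since_query += 1
        return self._cached_now

    def keep(self, last_access):
        return self._current_time() - last_access < self.ttl


def compact(entries, compaction_filter):
    """Return only the entries the filter decides to keep."""
    return {k: ts for k, ts in entries.items() if compaction_filter.keep(ts)}


# key -> timestamp of last access; the clock is fixed at 10 for the example.
entries = {f"k{i}": i for i in range(10)}

coarse = ToyCompactionFilter(ttl=5, query_time_after_num_entries=4, clock=lambda: 10)
kept = compact(entries, coarse)

fine = ToyCompactionFilter(ttl=5, query_time_after_num_entries=1, clock=lambda: 10)
compact(entries, fine)
```

Both filters keep the same entries here because the clock is fixed, but the fine-grained one performs a clock query per entry, which is the extra cost the paragraph above warns about.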
State in the Scala DataStream API
In addition to the interface described above, the Scala API has shortcuts for stateful map() or flatMap() functions with a single ValueState on KeyedStream. The user function gets the current value of the ValueState in an Option and must return an updated value that will be used to update the state.
val stream: DataStream[(String, Int)] = ...
val counts: DataStream[(String, Int)] = stream
.keyBy(_._1)
.mapWithState((in: (String, Int), count: Option[Int]) =>
count match {
case Some(c) => ( (in._1, c), Some(c + in._2) )
case None => ( (in._1, 0), Some(in._2) )
})
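To trace what the mapWithState example computes, here is a plain-Python model of the same logic. It is a sketch only: the helper `map_with_state` is hypothetical, a dictionary stands in for the per-key ValueState, and Python's None plays the role of Scala's empty Option.

```python
def map_with_state(records, key_selector, fn):
    """Apply fn(record, state_or_None) per key; fn returns (output, new_state)."""
    state = {}
    outputs = []
    for record in records:
        key = key_selector(record)
        output, new_state = fn(record, state.get(key))
        if new_state is None:
            state.pop(key, None)      # returning no state clears it
        else:
            state[key] = new_state
        outputs.append(output)
    return outputs


def running_count(record, count):
    # Emit the sum seen *before* this record, then add the record's value.
    word, n = record
    if count is None:
        return (word, 0), n
    return (word, count), count + n


counts = map_with_state(
    [("a", 1), ("b", 5), ("a", 2), ("a", 3)],
    key_selector=lambda r: r[0],
    fn=running_count,
)
```

Each output carries the state value from before the current record was applied, and the two keys accumulate independently.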
Operator State
Operator State (or non-keyed state) is state that is bound to one parallel operator instance. The Kafka Connector is a good motivating example for the use of Operator State in Flink. Each parallel instance of the Kafka consumer maintains a map of topic partitions and offsets as its Operator State.
The Operator State interfaces support redistributing state among parallel operator instances when the parallelism is changed. There are different schemes for doing this redistribution.
In a typical stateful Flink application you don’t need operator state. It is mostly a special type of state that is used in source/sink implementations and in scenarios where you don’t have a key by which state can be partitioned.
Note: Operator state is not yet supported in the Python DataStream API.
Broadcast State
Broadcast State is a special type of Operator State. It was introduced to support use cases where records of one stream need to be broadcast to all downstream tasks, where they are used to maintain the same state among all subtasks. This state can then be accessed while processing records of a second stream. As an example where broadcast state is a natural fit, imagine a low-throughput stream containing a set of rules that we want to evaluate against all elements coming from another stream. With these use cases in mind, broadcast state differs from the rest of operator states in that:
- it has a map format,
- it is only available to specific operators that have as inputs a broadcast stream and a non-broadcast one, and
- such an operator can have multiple broadcast states with different names.
Note: Broadcast state is not yet supported in the Python DataStream API.
Using Operator State
To use operator state, a stateful function can implement the CheckpointedFunction interface.
CheckpointedFunction
The CheckpointedFunction interface provides access to non-keyed state with different redistribution schemes. It requires the implementation of two methods:
void snapshotState(FunctionSnapshotContext context) throws Exception;
void initializeState(FunctionInitializationContext context) throws Exception;
Whenever a checkpoint has to be performed, snapshotState() is called. The counterpart, initializeState(), is called every time the user-defined function is initialized, whether the function is being initialized for the first time or is recovering from an earlier checkpoint. Given this, initializeState() is not only the place where different types of state are initialized, but also where state recovery logic is included.
Currently, list-style operator state is supported. The state is expected to be a List of serializable objects, independent from each other, thus eligible for redistribution upon rescaling. In other words, these objects are the finest granularity at which non-keyed state can be redistributed. Depending on the state accessing method, the following redistribution schemes are defined:
- Even-split redistribution: Each operator returns a List of state elements. The whole state is logically a concatenation of all lists. On restore/redistribution, the list is evenly divided into as many sublists as there are parallel operators. Each operator gets a sublist, which can be empty, or contain one or more elements. As an example, if with parallelism 1 the checkpointed state of an operator contains elements element1 and element2, when increasing the parallelism to 2, element1 may end up in operator instance 0, while element2 will go to operator instance 1.
- Union redistribution: Each operator returns a List of state elements. The whole state is logically a concatenation of all lists. On restore/redistribution, each operator gets the complete list of state elements. Do not use this feature if your list may have high cardinality. Checkpoint metadata will store an offset to each list entry, which could lead to RPC framesize or out-of-memory errors.
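The two redistribution schemes can be sketched as plain-Python functions. This is an illustration only: the helper names `even_split` and `union` are hypothetical, and the assumption that sublists are contiguous chunks whose sizes differ by at most one is mine; Flink does not promise which element ends up in which operator instance.

```python
def _concatenate(per_operator_lists):
    """The whole state is logically the concatenation of all operators' lists."""
    return [element for lst in per_operator_lists for element in lst]


def even_split(per_operator_lists, new_parallelism):
    """Divide the concatenated state into new_parallelism sublists."""
    elements = _concatenate(per_operator_lists)
    base, extra = divmod(len(elements), new_parallelism)
    sublists, start = [], 0
    for i in range(new_parallelism):
        size = base + (1 if i < extra else 0)
        sublists.append(elements[start:start + size])
        start += size
    return sublists


def union(per_operator_lists, new_parallelism):
    """Give every operator instance the complete concatenated state."""
    elements = _concatenate(per_operator_lists)
    return [list(elements) for _ in range(new_parallelism)]


# Scaling up from parallelism 1 to 2, as in the example in the list above.
scaled_up = even_split([["element1", "element2"]], 2)

# Scaling down from 3 to 2, and a union restore onto 3 instances.
scaled_down = even_split([[1], [2], [3]], 2)
everywhere = union([[1], [2]], 3)
```

Note how union multiplies the total amount of restored state by the parallelism, which is one way to see why it is discouraged for large lists.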
Below is an example of a stateful SinkFunction that uses CheckpointedFunction to buffer elements before sending them to the outside world. It demonstrates the basic even-split redistribution list state:
Java
public class BufferingSink
implements SinkFunction<Tuple2<String, Integer>>,
CheckpointedFunction {
private final int threshold;
private transient ListState<Tuple2<String, Integer>> checkpointedState;
private List<Tuple2<String, Integer>> bufferedElements;
public BufferingSink(int threshold) {
this.threshold = threshold;
this.bufferedElements = new ArrayList<>();
}
@Override
public void invoke(Tuple2<String, Integer> value, Context context) throws Exception {
bufferedElements.add(value);
if (bufferedElements.size() >= threshold) {
for (Tuple2<String, Integer> element: bufferedElements) {
// send it to the sink
}
bufferedElements.clear();
}
}
@Override
public void snapshotState(FunctionSnapshotContext context) throws Exception {
checkpointedState.clear();
for (Tuple2<String, Integer> element : bufferedElements) {
checkpointedState.add(element);
}
}
@Override
public void initializeState(FunctionInitializationContext context) throws Exception {
ListStateDescriptor<Tuple2<String, Integer>> descriptor =
new ListStateDescriptor<>(
"buffered-elements",
TypeInformation.of(new TypeHint<Tuple2<String, Integer>>() {}));
checkpointedState = context.getOperatorStateStore().getListState(descriptor);
if (context.isRestored()) {
for (Tuple2<String, Integer> element : checkpointedState.get()) {
bufferedElements.add(element);
}
}
}
}
Scala
class BufferingSink(threshold: Int = 0)
extends SinkFunction[(String, Int)]
with CheckpointedFunction {
@transient
private var checkpointedState: ListState[(String, Int)] = _
private val bufferedElements = ListBuffer[(String, Int)]()
override def invoke(value: (String, Int), context: Context): Unit = {
bufferedElements += value
if (bufferedElements.size >= threshold) {
for (element <- bufferedElements) {
// send it to the sink
}
bufferedElements.clear()
}
}
override def snapshotState(context: FunctionSnapshotContext): Unit = {
checkpointedState.clear()
for (element <- bufferedElements) {
checkpointedState.add(element)
}
}
override def initializeState(context: FunctionInitializationContext): Unit = {
val descriptor = new ListStateDescriptor[(String, Int)](
"buffered-elements",
TypeInformation.of(new TypeHint[(String, Int)]() {})
)
checkpointedState = context.getOperatorStateStore.getListState(descriptor)
if(context.isRestored) {
for(element <- checkpointedState.get().asScala) {
bufferedElements += element
}
}
}
}
The initializeState method takes as argument a FunctionInitializationContext. This is used to initialize the non-keyed state “containers”. These are containers of type ListState where the non-keyed state objects are going to be stored upon checkpointing.
Note how the state is initialized, similar to keyed state, with a StateDescriptor that contains the state name and information about the type of the value that the state holds:
Java
ListStateDescriptor<Tuple2<String, Integer>> descriptor =
new ListStateDescriptor<>(
"buffered-elements",
TypeInformation.of(new TypeHint<Tuple2<String, Integer>>() {}));
checkpointedState = context.getOperatorStateStore().getListState(descriptor);
Scala
val descriptor = new ListStateDescriptor[(String, Int)](
"buffered-elements",
TypeInformation.of(new TypeHint[(String, Int)]() {})
)
checkpointedState = context.getOperatorStateStore.getListState(descriptor)
The names of the state access methods contain the redistribution pattern followed by the state structure. For example, to use list state with the union redistribution scheme on restore, access the state by using getUnionListState(descriptor). If the method name does not contain the redistribution pattern, e.g. getListState(descriptor), it simply implies that the basic even-split redistribution scheme will be used.
After initializing the container, we use the isRestored() method of the context to check if we are recovering after a failure. If this is true, i.e. we are recovering, the restore logic is applied.
As shown in the code of the BufferingSink, the ListState recovered during state initialization is kept in a class variable for future use in snapshotState(). There the ListState is cleared of all objects included by the previous checkpoint, and is then filled with the new ones we want to checkpoint.
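The interplay of initialization, snapshotting and restore can be walked through with a plain-Python stand-in for the sink. This is a toy model, not Flink code: the class `ToyBufferingSink`, its method names, and the use of an ordinary list in place of ListState are assumptions for illustration.

```python
class ToyBufferingSink:
    """Buffers values and flushes them once the threshold is reached."""

    def __init__(self, threshold):
        self.threshold = threshold
        self.buffered = []
        self.sent = []                    # stands in for the outside world
        self.checkpointed_state = None    # stands in for the ListState handle

    def initialize_state(self, list_state, is_restored):
        # Keep the handle for later snapshots; refill the buffer on recovery.
        self.checkpointed_state = list_state
        if is_restored:
            self.buffered.extend(list_state)

    def invoke(self, value):
        self.buffered.append(value)
        if len(self.buffered) >= self.threshold:
            self.sent.extend(self.buffered)
            self.buffered.clear()

    def snapshot_state(self):
        # Drop what the previous checkpoint stored, then store the current buffer.
        self.checkpointed_state.clear()
        self.checkpointed_state.extend(self.buffered)


# First run: two elements are buffered, then a checkpoint is taken.
sink = ToyBufferingSink(threshold=3)
sink.initialize_state([], is_restored=False)
sink.invoke("a")
sink.invoke("b")
sink.snapshot_state()
checkpoint = list(sink.checkpointed_state)

# Simulated failure: a fresh instance restores the buffer from the checkpoint.
recovered = ToyBufferingSink(threshold=3)
recovered.initialize_state(list(checkpoint), is_restored=True)
recovered.invoke("c")
```

After recovery the buffered elements from before the failure are flushed together with the new one, so nothing that was buffered at checkpoint time is lost.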
As a side note, the keyed state can also be initialized in the initializeState() method. This can be done using the provided FunctionInitializationContext.
Stateful Source Functions
Stateful sources require a bit more care than other operators. To make the updates to the state and output collection atomic (required for exactly-once semantics on failure/recovery), the user is required to get a lock from the source’s context.
Java
public static class CounterSource
extends RichParallelSourceFunction<Long>
implements CheckpointedFunction {
/** current offset for exactly once semantics */
private Long offset = 0L;
/** flag for job cancellation */
private volatile boolean isRunning = true;
/** Our state object. */
private ListState<Long> state;
@Override
public void run(SourceContext<Long> ctx) {
final Object lock = ctx.getCheckpointLock();
while (isRunning) {
// output and state update are atomic
synchronized (lock) {
ctx.collect(offset);
offset += 1;
}
}
}
@Override
public void cancel() {
isRunning = false;
}
@Override
public void initializeState(FunctionInitializationContext context) throws Exception {
state = context.getOperatorStateStore().getListState(new ListStateDescriptor<>(
"state",
LongSerializer.INSTANCE));
// restore any state that we might already have to our fields, initialize state
// is also called in case of restore.
for (Long l : state.get()) {
offset = l;
}
}
@Override
public void snapshotState(FunctionSnapshotContext context) throws Exception {
state.clear();
state.add(offset);
}
}
Scala
class CounterSource
extends RichParallelSourceFunction[Long]
with CheckpointedFunction {
@volatile
private var isRunning = true
private var offset = 0L
private var state: ListState[Long] = _
override def run(ctx: SourceFunction.SourceContext[Long]): Unit = {
val lock = ctx.getCheckpointLock
while (isRunning) {
// output and state update are atomic
lock.synchronized({
ctx.collect(offset)
offset += 1
})
}
}
override def cancel(): Unit = isRunning = false
override def initializeState(context: FunctionInitializationContext): Unit = {
state = context.getOperatorStateStore.getListState(
new ListStateDescriptor[Long]("state", classOf[Long]))
for (l <- state.get().asScala) {
offset = l
}
}
override def snapshotState(context: FunctionSnapshotContext): Unit = {
state.clear()
state.add(offset)
}
}
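Why the sources above emit and advance the offset under one lock can be shown with a small threading sketch in plain Python. This is an analogy rather than Flink's mechanism: the class `ToyCounterSource`, its `snapshot` method, and the `limit` parameter are hypothetical, and an ordinary `threading.Lock` stands in for the checkpoint lock.

```python
import threading


class ToyCounterSource:
    """Emits consecutive offsets; emit and offset update share one lock."""

    def __init__(self, limit):
        self.lock = threading.Lock()   # stands in for the checkpoint lock
        self.offset = 0
        self.emitted = []
        self.limit = limit

    def run(self):
        while True:
            with self.lock:
                if self.offset >= self.limit:
                    return
                # Output and state update are atomic with respect to snapshot().
                self.emitted.append(self.offset)
                self.offset += 1

    def snapshot(self):
        # Taken under the same lock, so it never sees a half-applied step.
        with self.lock:
            return self.offset, len(self.emitted)


source = ToyCounterSource(limit=10_000)
worker = threading.Thread(target=source.run)
worker.start()
snapshots = [source.snapshot() for _ in range(100)]   # concurrent "checkpoints"
worker.join()
```

Every snapshot observes an offset equal to the number of emitted records; without the shared lock, a snapshot could land between the emit and the increment and a restore would then repeat or skip a record.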
Some operators might need to know when a checkpoint has been fully acknowledged by Flink in order to communicate that to the outside world. In this case see the org.apache.flink.api.common.state.CheckpointListener interface.