Design to Support Custom Callback Using Keras API

This document describes the design for supporting callbacks that customize the behavior of a model during training, evaluation, and inference in ElasticDL.

Motivation

In deep learning, we generally need to customize the behavior of a model during training, evaluation, or inference, including reading and changing the model. Such behavior may run per batch, per epoch, or at the start and end of a job. tf.keras.callbacks.Callback is an abstract base class whose methods are invoked at different frequencies, such as on_batch_end and on_epoch_end. We therefore adopt the interface of tf.keras.callbacks.Callback to customize the behavior of the model in ElasticDL.

ElasticDL already implements some modules that behave like callbacks, such as LearningRateScheduler and PredictionOutputsProcessor, and users write a separate definition for each module in the model definition file, for example:

    def custom_model():
        ...

    # Adjust the learning rate according to iteration steps
    def learning_rate_scheduler(model_version):
        if model_version < 5000:
            return 0.0003
        elif model_version < 12000:
            return 0.0002
        else:
            return 0.0001

    # Process the prediction outputs for each batch
    class PredictionOutputsProcessor(BasePredictionOutputsProcessor):
        ...

Each of these modules has its own interface, so users have to learn a different definition style for each behavior of the model. It is more convenient for users to define all of these behaviors through the single tf.keras.callbacks.Callback interface.

The use cases we have observed and want to support are:

  • Case 1: Callback similar to PredictionOutputsProcessor that is executed after prediction outputs are made.
  • Case 2: Callback to modulate the learning rate, as LearningRateScheduler currently does.
  • Case 3: Callback to write additional summaries to TensorBoard after each evaluation job completes.
  • Case 4: Callback to perform early stopping when certain conditions are met, for example, when the evaluation metrics reach a target after an evaluation job.
  • Case 5: Callback to export the model using SavedModel after the training is completed.
  • Case 6: Callback to upload the model to remote storage after the training is completed.

The following sections describe how to define and implement callbacks to support these cases.

Define Callbacks for the Model in ElasticDL

In the model definition file, users can define a callbacks() function that returns the list of callbacks for the model, for example:

    from elasticdl.python.callbacks import (
        PredictionOutputsProcessor,
        LearningRateScheduler,
    )

    def custom_model():
        ...

    def callbacks():
        # process_fn and schedule_fn are user-defined functions (see the cases below)
        prediction_output_processor = PredictionOutputsProcessor(process_fn)
        learning_rate_scheduler = LearningRateScheduler(schedule_fn)
        return [prediction_output_processor, learning_rate_scheduler]
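On the framework side, ElasticDL only has to look up this optional callbacks function in the loaded model definition module. The following is a minimal sketch of that lookup; get_callbacks is a hypothetical helper name, not an existing ElasticDL function, and the stand-in modules exist only to exercise it:

```python
import types


def get_callbacks(model_module):
    """Hypothetical helper: collect callbacks from a model definition module.

    Returns an empty list when the module does not define callbacks().
    """
    callbacks_fn = getattr(model_module, "callbacks", None)
    if callbacks_fn is None:
        return []
    callbacks = callbacks_fn()
    if not isinstance(callbacks, (list, tuple)):
        raise TypeError("callbacks() must return a list of callbacks")
    return list(callbacks)


# Stand-in modules: one defines callbacks(), the other does not
with_callbacks = types.ModuleType("with_callbacks")
with_callbacks.callbacks = lambda: ["processor", "scheduler"]
without_callbacks = types.ModuleType("without_callbacks")

print(get_callbacks(with_callbacks))     # ['processor', 'scheduler']
print(get_callbacks(without_callbacks))  # []
```

Treating a missing callbacks() as an empty list keeps existing model definitions, which do not define it, working unchanged.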

Initialize Callbacks and Set Callback Attributes in ElasticDL

Use a Container for Callbacks Defined in the Model

A job may define several callbacks for the model. TensorFlow provides a container, CallbackList, that wraps the callbacks so that a method can be called on all of them at once. A simplified version looks like:

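    class CallbackList(object):  # simplified sketch of TensorFlow's CallbackList
        def __init__(self, callbacks=None):
            self.callbacks = callbacks or []

        def on_batch_end(self, batch, logs=None):
            # Forward the call to every wrapped callback
            for callback in self.callbacks:
                callback.on_batch_end(batch, logs)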

ElasticDL can use CallbackList in the same way to wrap the callbacks defined in the model:

    from tensorflow.python.keras.callbacks import CallbackList

    # callbacks() is the function defined in the model definition file
    callback_list = CallbackList(callbacks())
    callback_list.on_batch_end(batch)

Set Default Attributes for Callbacks

tf.keras.callbacks.Callback provides set_model and set_params. set_model assigns a model object to the model attribute of the callback, and set_params assigns a dict to its params attribute. CallbackList provides the same methods, which set the model and params on every callback it wraps.

    model = custom_model()
    callback_list = CallbackList(callbacks())
    callback_list.set_model(model)
    model.stop_training = False  # used by the early stopping callback
    params = {
        "batch_size": batch_size,
        "epochs": epochs,
        "saved_model_path": saved_model_path,
        "checkpoint_path": checkpoint_path,
    }
    callback_list.set_params(params)

Callbacks can then access the model and params through self.model and self.params in their methods, for example:

    import tensorflow as tf

    class CustomCallback(tf.keras.callbacks.Callback):
        def on_train_batch_begin(self, batch, logs=None):
            lr = self.model.optimizer.learning_rate
            saved_model_path = self.params.get("saved_model_path")
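The attribute propagation described above can be illustrated without TensorFlow. The sketch below re-implements minimal stand-ins for tf.keras.callbacks.Callback and CallbackList that keep the Keras method names; it is a simplified illustration (the real Keras classes handle more hooks and bookkeeping), and PathRecorder and DummyModel are hypothetical examples:

```python
class Callback(object):
    """Minimal stand-in for tf.keras.callbacks.Callback."""

    def __init__(self):
        self.model = None
        self.params = {}

    def set_model(self, model):
        self.model = model

    def set_params(self, params):
        self.params = params

    def on_train_batch_begin(self, batch, logs=None):
        pass


class CallbackList(object):
    """Minimal stand-in for CallbackList: forwards every call to each callback."""

    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])
        self.model = None
        self.params = {}

    def set_model(self, model):
        self.model = model
        for callback in self.callbacks:
            callback.set_model(model)

    def set_params(self, params):
        self.params = params
        for callback in self.callbacks:
            callback.set_params(params)

    def on_train_batch_begin(self, batch, logs=None):
        for callback in self.callbacks:
            callback.on_train_batch_begin(batch, logs)


class PathRecorder(Callback):
    """Hypothetical callback that reads self.params inside a batch hook."""

    def __init__(self):
        super(PathRecorder, self).__init__()
        self.seen = []

    def on_train_batch_begin(self, batch, logs=None):
        self.seen.append((batch, self.params.get("saved_model_path")))


class DummyModel(object):
    stop_training = False


recorder = PathRecorder()
callback_list = CallbackList([recorder])
callback_list.set_model(DummyModel())
callback_list.set_params({"batch_size": 64, "saved_model_path": "/tmp/model"})
callback_list.on_train_batch_begin(0)
print(recorder.seen)  # [(0, '/tmp/model')]
```

Because the container pushes the same model object and params dict into every callback, ElasticDL only needs to call set_model and set_params once per job.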

The Execution of Supported Methods in ElasticDL Using ParameterServerStrategy

Currently, ElasticDL supports only four methods of tf.keras.callbacks.Callback, which are enough to implement the cases above:

  1. on_train_batch_begin
  2. on_predict_batch_end
  3. on_train_end
  4. on_test_end

The worker executes on_train_batch_begin and on_predict_batch_end because the worker processes each batch. The worker also executes on_train_end because it exports the model using SavedModel after training; the details are in the Model Serving Design. However, only the master knows when the training is completed, so the master creates a train end task for the worker, and the worker calls on_train_end of its callbacks after receiving this task.

The master executes on_test_end because it receives the evaluation metrics from all workers.
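The split between worker and master can be sketched as a worker-side dispatch on the task type. The task type names, process_task, and the recording stand-in for CallbackList below are hypothetical and only illustrate which hook runs for which task; they are not ElasticDL's actual task definitions:

```python
from enum import Enum


class TaskType(Enum):
    # Hypothetical task types; ElasticDL's real task definitions may differ
    TRAINING = "training"
    PREDICTION = "prediction"
    TRAIN_END_CALLBACK = "train_end_callback"


class RecordingCallbacks(object):
    """Stand-in for a CallbackList that records which hooks were called."""

    def __init__(self):
        self.calls = []

    def on_train_batch_begin(self, batch, logs=None):
        self.calls.append(("on_train_batch_begin", batch))

    def on_predict_batch_end(self, batch, logs=None):
        self.calls.append(("on_predict_batch_end", batch))

    def on_train_end(self, logs=None):
        self.calls.append(("on_train_end", None))


def process_task(task_type, batches, callbacks):
    """Worker side (sketch): run the callback hooks that belong to a task."""
    if task_type is TaskType.TRAINING:
        # ElasticDL passes the model version as `batch`; an index is used here
        for batch, _ in enumerate(batches):
            callbacks.on_train_batch_begin(batch)
            # ... compute gradients and report them to the PS ...
    elif task_type is TaskType.PREDICTION:
        for batch, data in enumerate(batches):
            predictions = list(data)  # placeholder for the model's outputs
            callbacks.on_predict_batch_end(batch, logs={"predictions": predictions})
    elif task_type is TaskType.TRAIN_END_CALLBACK:
        # The master creates this task once it knows training is complete
        callbacks.on_train_end()


callbacks = RecordingCallbacks()
process_task(TaskType.TRAINING, [[1], [2]], callbacks)
process_task(TaskType.TRAIN_END_CALLBACK, [], callbacks)
print(callbacks.calls)
# [('on_train_batch_begin', 0), ('on_train_batch_begin', 1), ('on_train_end', None)]
```

In this sketch, a worker that receives the train end task after its training tasks runs on_train_end once, after all of its on_train_batch_begin calls.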

Implement Callbacks to Support the Cases in the Motivation

We split the callbacks into two groups: callbacks that ElasticDL configures automatically, and pre-made callbacks that users can configure in the model definition if needed.

Callbacks Configured by ElasticDL

ElasticDL automatically configures these callbacks for each job.

  • Case 5: Callback to export the model using SavedModel after the training is completed.

    import tensorflow as tf

    from elasticdl.python.common.constants import Mode


    class SavedModelExporter(tf.keras.callbacks.Callback):
        """Export model using SavedModel after training.

        Args:
            task_data_service: TaskDataService to process data according to the task.
            dataset_fn: function to process the dataset.
            model_handler: transforms the trained model with ElasticDL embedding
                layers into a Keras native model.
        """

        def __init__(self, task_data_service, dataset_fn, model_handler):
            super(SavedModelExporter, self).__init__()
            self._model_handler = model_handler
            self._task_data_service = task_data_service
            self._dataset_fn = dataset_fn

        def on_train_end(self, logs=None):
            """Call at the end of the training job.

            Args:
                logs: dict. Currently no data is passed to this argument for
                    this method but that may change in the future.
            """
            saved_model_path = self.params.get("saved_model_path", None)
            batch_size = self.params.get("batch_size")
            (task, dataset) = self._task_data_service.get_save_model_task_and_dataset()
            if task is not None and dataset is not None:
                dataset = self._dataset_fn(
                    dataset,
                    Mode.PREDICTION,
                    self._task_data_service.data_reader.metadata,
                )
                dataset = dataset.batch(batch_size)
                model = self._model_handler.get_model_to_export(
                    self.model, dataset
                )
                tf.saved_model.save(model, saved_model_path)

Pre-made Callbacks for Users to Configure in Model Definition

ElasticDL provides pre-made callbacks whose processing logic users customize by passing their own functions. Users instantiate these callbacks in the callbacks() function of the model definition, for example:

    from elasticdl.python.callbacks import (
        LearningRateScheduler,
        EarlyStopper,
    )

    def callbacks():
        # Define the user functions before passing them to the callbacks
        def schedule(batch):
            return 0.003 if batch < 1000 else 0.001

        def stop_train(metrics_list):
            latest_metrics = metrics_list[-1]
            return latest_metrics["auc"] > 0.8

        learning_rate_scheduler = LearningRateScheduler(schedule)
        early_stopper = EarlyStopper(stop_train)
        return [learning_rate_scheduler, early_stopper]

  • Case 1: Callback to process prediction outputs after the outputs of each batch are made.

    import tensorflow as tf


    def process_fn(predictions):
        """The function is defined by users.

        Args:
            predictions: prediction outputs of the model.
        """
        print(len(predictions))


    class PredictionOutputsProcessor(tf.keras.callbacks.Callback):
        def __init__(self, process_fn):
            super(PredictionOutputsProcessor, self).__init__()
            self.process_fn = process_fn

        def on_predict_batch_end(self, batch, logs=None):
            """Call at the end of prediction for each batch.

            Args:
                batch: integer, index of the batch in the current worker.
                logs: dict. Has the key "predictions", holding the prediction
                    outputs of the batch.
            """
            predictions = logs["predictions"]
            self.process_fn(predictions)

  • Case 2: Callback to modulate the learning rate.

    import tensorflow as tf


    def schedule_fn(version):
        """The function is defined by users.

        Args:
            version: model iteration version.
        """
        return 0.003 if version < 1000 else 0.001


    class LearningRateScheduler(tf.keras.callbacks.Callback):
        def __init__(self, schedule_fn):
            super(LearningRateScheduler, self).__init__()
            self.schedule_fn = schedule_fn

        def on_train_batch_begin(self, batch, logs=None):
            """Call at the beginning of each training batch.

            Args:
                batch: integer, the model version requested from PS.
                logs: dict. Has keys batch and size representing the current
                    batch number and the size of the batch.
            """
            if not hasattr(self.model.optimizer, "lr"):
                raise ValueError('Optimizer must have a "lr" attribute.')
            lr = self.schedule_fn(batch)
            tf.keras.backend.set_value(self.model.optimizer.lr, lr)

Using ParameterServerStrategy, the worker computes the gradients of each batch and sends them to the PS, and the PS updates the weights with its optimizer after receiving the gradients. The worker can call on_train_batch_begin of LearningRateScheduler to adjust the learning rate of the optimizer in its local model, but that change alone does not affect the update performed on the PS. So the worker should send the learning rate together with the gradients to the PS via gRPC, and the PS updates the weights using the received learning rate.
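A minimal sketch of this protocol, assuming a dict in place of the real gRPC message and plain SGD on the PS (the PS can use other optimizers; SGD is swapped in here only because it makes the use of the learning rate explicit). build_push_gradients_request and ParameterServer are hypothetical names:

```python
def schedule_fn(version):
    """User-defined schedule from Case 2: learning rate by model version."""
    return 0.003 if version < 1000 else 0.001


def build_push_gradients_request(gradients, model_version, schedule):
    """Worker side (sketch): attach the scheduled learning rate to the gradients.

    The dict layout is a hypothetical stand-in for the real gRPC message.
    """
    return {
        "gradients": gradients,
        "model_version": model_version,
        "learning_rate": schedule(model_version),
    }


class ParameterServer(object):
    """PS side (sketch): applies plain SGD with the learning rate it receives."""

    def __init__(self, weights):
        self.weights = dict(weights)
        self.version = 0

    def push_gradients(self, request):
        lr = request["learning_rate"]
        for name, grad in request["gradients"].items():
            self.weights[name] -= lr * grad
        self.version += 1
        return self.version


ps = ParameterServer({"w": 1.0})
request = build_push_gradients_request({"w": 10.0}, model_version=0, schedule=schedule_fn)
ps.push_gradients(request)
print(ps.weights["w"])  # approximately 0.97 (1.0 - 0.003 * 10.0)
```

With this design the schedule function only has to run on the worker, and the PS never needs to know the user's schedule.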

  • Case 3: Callback to write additional summaries to TensorBoard after each evaluation job completes.

    import tensorflow as tf


    class SummaryWriter(tf.keras.callbacks.Callback):
        def on_test_end(self, logs=None):
            """Call at the end of each evaluation job.

            Args:
                logs: dict. Has the key "metrics", holding the evaluation
                    metrics on the validation data.
            """
            metrics = (logs or {}).get("metrics", None)
            if metrics is None:
                return
            # write() is a placeholder for the summary-writing logic,
            # for example using tf.summary
            write(metrics)

The master determines whether an evaluation job is completed through EvaluationService.complete_task(), which returns the evaluation metrics once the job is done. The master therefore calls on_test_end after EvaluationService.complete_task() returns the metrics:

    if evaluation_task_completed:
        eval_metrics = self._evaluation_service.complete_task()
        if eval_metrics is not None:
            logs = {"metrics": eval_metrics}
            self._callbacks_list.on_test_end(logs)
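The following runnable sketch mirrors that master-side logic with hypothetical stand-ins: FakeEvaluationService imitates the assumed contract of EvaluationService.complete_task() (metrics once the whole job is done, None before), and InMemorySummaryWriter records metrics instead of writing TensorBoard summaries:

```python
class FakeEvaluationService(object):
    """Hypothetical stand-in for EvaluationService.

    complete_task() returns the aggregated metrics once every evaluation
    task of the current job has finished, and None before that.
    """

    def __init__(self, num_tasks, metrics):
        self._remaining = num_tasks
        self._metrics = metrics

    def complete_task(self):
        self._remaining -= 1
        return self._metrics if self._remaining == 0 else None


class InMemorySummaryWriter(object):
    """Stand-in for SummaryWriter that records metrics instead of writing them."""

    def __init__(self):
        self.written = []

    def on_test_end(self, logs=None):
        metrics = (logs or {}).get("metrics")
        if metrics is not None:
            self.written.append(metrics)


class CallbackListStub(object):
    """Forwards on_test_end to each wrapped callback."""

    def __init__(self, callbacks):
        self.callbacks = callbacks

    def on_test_end(self, logs=None):
        for callback in self.callbacks:
            callback.on_test_end(logs)


def on_evaluation_task_completed(evaluation_service, callbacks_list):
    """Master side (sketch): fire on_test_end only when the whole job is done."""
    eval_metrics = evaluation_service.complete_task()
    if eval_metrics is not None:
        callbacks_list.on_test_end({"metrics": eval_metrics})


writer = InMemorySummaryWriter()
service = FakeEvaluationService(num_tasks=2, metrics={"auc": 0.82})
callbacks_list = CallbackListStub([writer])
on_evaluation_task_completed(service, callbacks_list)  # first task: job not done
on_evaluation_task_completed(service, callbacks_list)  # last task: job done
print(writer.written)  # [{'auc': 0.82}]
```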
  • Case 4: Callback to perform early stopping when the evaluation metrics reach a target after an evaluation job.

    import tensorflow as tf


    def stop_fn(metrics_list):
        """Function to determine whether or not to stop training.

        Args:
            metrics_list: list containing the metrics of each evaluation job.

        Returns:
            boolean: stop training if true.
        """
        latest_metrics = metrics_list[-1]
        return latest_metrics["auc"] > 0.8


    class EarlyStopper(tf.keras.callbacks.Callback):
        def __init__(self, stop_fn):
            super(EarlyStopper, self).__init__()
            self.stop_fn = stop_fn
            self.metrics_list = []

        def on_test_end(self, logs=None):
            """Call at the end of each evaluation job.

            Args:
                logs: dict. Has the key "metrics", holding the evaluation
                    metrics on the validation data.
            """
            metrics = (logs or {}).get("metrics", None)
            if metrics is None:
                return
            self.metrics_list.append(metrics)
            self.model.stop_training = self.stop_fn(self.metrics_list)

As with SummaryWriter, the master calls on_test_end of EarlyStopper after an evaluation job is completed.
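After on_test_end returns, the master still has to act on model.stop_training. One way to do that is sketched below, assuming a simplified master loop (run_master is hypothetical) that checks the flag after each evaluation job and stops dispatching training tasks once it is set. The EarlyStopper here is a TensorFlow-free stand-in with the same logic as the Keras callback in Case 4:

```python
class EarlyStopper(object):
    """Stand-in with the same on_test_end logic as the Keras EarlyStopper."""

    def __init__(self, stop_fn, model):
        self.stop_fn = stop_fn
        self.model = model
        self.metrics_list = []

    def on_test_end(self, logs=None):
        metrics = (logs or {}).get("metrics")
        if metrics is None:
            return
        self.metrics_list.append(metrics)
        self.model.stop_training = self.stop_fn(self.metrics_list)


class Model(object):
    stop_training = False


def stop_fn(metrics_list):
    return metrics_list[-1]["auc"] > 0.8


def run_master(evaluation_results, callback, model):
    """Hypothetical master loop: one evaluation job per round of training tasks.

    Returns the number of evaluation rounds run before training stopped.
    """
    rounds = 0
    for metrics in evaluation_results:
        rounds += 1
        callback.on_test_end({"metrics": metrics})
        if model.stop_training:
            # Stop dispatching training tasks; the master would then create
            # the train end task so that workers call on_train_end
            break
    return rounds


model = Model()
stopper = EarlyStopper(stop_fn, model)
rounds = run_master([{"auc": 0.70}, {"auc": 0.81}, {"auc": 0.85}], stopper, model)
print(rounds)  # 2
```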

  • Case 6: Callback to upload the model to remote storage after the training is completed.

    import tensorflow as tf


    class ModelUploader(tf.keras.callbacks.Callback):
        """Upload model to remote storage."""

        def __init__(self, remote_url):
            super(ModelUploader, self).__init__()
            self.remote_url = remote_url

        def on_train_end(self, logs=None):
            """Call at the end of the training job.

            Args:
                logs: dict. Currently no data is passed to this argument for
                    this method but that may change in the future.
            """
            saved_model_path = self.params["saved_model_path"]
            # upload() is a storage-specific helper, not defined in this document
            upload(saved_model_path, self.remote_url)

As with SavedModelExporter, the worker calls on_train_end of ModelUploader after receiving a train end task from the master.
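The upload() helper used by ModelUploader is left undefined in this design. For the simplest case, a sketch could look like the following, assuming the remote storage is reachable through a file:// URL (for example, a mounted network file system). This upload() is hypothetical; object stores would need their own client libraries:

```python
import os
import shutil
import tempfile
from urllib.parse import urlparse


def upload(saved_model_path, remote_url):
    """Hypothetical upload() for ModelUploader.

    Only file:// URLs are handled; a real deployment would dispatch on the
    URL scheme to the matching storage client.
    """
    parsed = urlparse(remote_url)
    if parsed.scheme != "file":
        raise NotImplementedError("unsupported scheme: %r" % parsed.scheme)
    destination = parsed.path
    if os.path.exists(destination):
        shutil.rmtree(destination)  # replace any previously uploaded model
    shutil.copytree(saved_model_path, destination)
    return destination


# Usage: copy a fake SavedModel directory to a "remote" location
src = tempfile.mkdtemp()
with open(os.path.join(src, "saved_model.pb"), "wb") as f:
    f.write(b"placeholder")
dst = os.path.join(tempfile.mkdtemp(), "uploaded")
upload(src, "file://" + dst)
print(os.listdir(dst))  # ['saved_model.pb']
```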