Diagnostic Attribute Error

Motivation

SQLFlow extends SQL syntax to describe an end-to-end machine learning job. A typical SQLFlow program looks like this:

    SELECT * FROM iris.train
    TO TRAIN DNNClassifier
    WITH
      model.n_classes = 2,
      model.hidden_units = [128, 64]
    LABEL class
    INTO my_model;

    SELECT * FROM iris.test
    TO PREDICT iris.pred.class
    USING my_model;

SQLFlow compiles each statement in the program into an execution plan and executes it. In the TRAIN statement above, the TO TRAIN clause trains a model of the class DNNClassifier, and the WITH clause configures the training arguments.

Users sometimes misconfigure the WITH clause. The job then fails during execution and returns an unclear error message.

The model parameter documentation describes the parameters and their acceptable values for human readers. We want to extend it so that the SQLFlow compiler can also read it and warn about wrongly set parameters. This brings at least three advantages:

  1. Early testing: parameters can be checked before the job runs, so users wait less and the cluster saves resources.
  2. More accurate diagnostic messages.
  3. Model developers do not need to introduce dependencies other than Keras or TensorFlow.

Design

We want to put the compiler-readable description of model parameters in the docstring of the Python function or class that defines the model.

A docstring contains multiple lines:

  • A line starting with # is a check rule, written as a Python expression.
  • A line starting with an argument name followed by a colon : documents that argument.

An example:

    from tensorflow import keras

    class MyDNNClassifier(keras.Model):
        def __init__(self, n_classes=2, hidden_units=[32, 64]):
            """
            Args:
                # isinstance(n_classes, int) and n_classes > 1
                n_classes: Number of label classes. Defaults to 2, namely binary
                    classification. Must be > 1.
                # isinstance(hidden_units, list) and all(isinstance(item, int) for item in hidden_units)
                hidden_units: Iterable of number hidden units per layer. All layers are
                    fully connected. Ex. `[64, 32]` means first layer has 64 nodes and
                    second one has 32.
            """
            super().__init__()
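To make the convention concrete, here is a minimal sketch of how such a docstring could be parsed into rules and documentation. The function `parse_contract` and the dictionary layout it returns are hypothetical names chosen for illustration; they are not part of SQLFlow. The sketch also assumes that continuation lines of an argument's documentation never begin with `identifier:`.

```python
import re

# Hypothetical helper, not part of SQLFlow: it parses the docstring
# convention described above (a "#" rule line, then "name: documentation").
ARG_LINE = re.compile(r"^([A-Za-z_]\w*):\s*(.*)$")


def parse_contract(docstring):
    """Return {arg_name: {"rule": str or None, "doc": str}}."""
    contract = {}
    pending_rule = None  # rule from a "#" line, waiting for its argument
    current = None       # argument whose documentation is being collected
    in_args = False
    for raw in docstring.splitlines():
        line = raw.strip()
        if not line:
            continue
        if line == "Args:":
            in_args = True
            continue
        if not in_args:
            continue
        if line.startswith("#"):
            pending_rule = line[1:].strip()
            current = None
            continue
        match = ARG_LINE.match(line)
        if match:
            current = match.group(1)
            contract[current] = {"rule": pending_rule, "doc": match.group(2)}
            pending_rule = None
        elif current is not None:
            # Continuation line: append to the current argument's documentation.
            contract[current]["doc"] += " " + line
    return contract


DOC = """
Args:
    # isinstance(n_classes, int) and n_classes > 1
    n_classes: Number of label classes. Defaults to 2, namely binary
        classification. Must be > 1.
    # isinstance(hidden_units, list) and all(isinstance(item, int) for item in hidden_units)
    hidden_units: Iterable of number hidden units per layer.
"""

contract = parse_contract(DOC)
print(contract["n_classes"]["rule"])
```

Each rule is kept as a plain string, so a later compile step can decide how to evaluate it.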

If a user sets invalid parameters, as in the following SQL statement:

    SELECT ... TO TRAIN sqlflow_models.MyDNNClassifier
    WITH
      model.n_classes=1,
      model.hidden_units=64
    LABEL class
    INTO my_dnn_model;

We expect the SQLFlow GUI to show an error message such as:

    SQLFlow received attribute error:
    > model.n_classes received unexpected value: 1, attribute usage:
    Number of label classes. Defaults to 2, namely binary classification. Must be > 1.
    > model.hidden_units received unexpected value: 64, attribute usage:
    Iterable of number hidden units per layer. All layers are fully connected. Ex. `[64, 32]` means first layer has 64 nodes and second one has 32.

The implementation is straightforward: extract the check rules and argument documentation from the docstring, then check the attributes during the compile phase.

    def attribute_check(estimator, **args):
        # Extract argument names, documentation and contracts from the docstring.
        contract = extract_symbol(estimator)
        # A SQLFlowDiagnosticError message can be piped to the SQLFlow GUI
        # via the SQLFlow gRPC server.
        diag_err = SQLFlowDiagnosticError()
        for name, value in args.items():
            if not contract.check(name, value):
                # Compose the received value and the argument documentation.
                diag_err.append_message(contract.diag_message(name, value))
        if not diag_err.empty():
            raise diag_err
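The pseudocode above leaves the rule evaluation and the error type open. The following self-contained sketch shows one possible way to fill them in, assuming each rule is a Python expression evaluated with `eval` and the attribute bound to its own name. `DiagnosticError`, `check_rule`, `SAFE_BUILTINS` and the hand-written `CONTRACT` dictionary are all hypothetical stand-ins, not SQLFlow APIs. Restricting builtins narrows what a rule can call but is not a real sandbox, so rules should only come from trusted model code.

```python
class DiagnosticError(Exception):
    """Hypothetical stand-in for the SQLFlow diagnostic error type."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def __str__(self):
        return "SQLFlow received attribute error:\n" + "\n".join(self.messages)


# Builtins that rules may call; all other builtins are hidden from eval.
SAFE_BUILTINS = {"isinstance": isinstance, "all": all, "any": any, "len": len,
                 "int": int, "float": float, "str": str, "list": list}


def check_rule(rule, name, value):
    """Evaluate one rule expression with the attribute bound to its name."""
    try:
        return bool(eval(rule, {"__builtins__": SAFE_BUILTINS, name: value}))
    except Exception:
        # A rule that cannot be evaluated is treated as a failed check.
        return False


def attribute_check(contract, **attrs):
    err = DiagnosticError()
    for name, value in attrs.items():
        entry = contract.get(name)
        # Assumption: attributes without a contract entry are not checked.
        if entry is None or check_rule(entry["rule"], name, value):
            continue
        err.messages.append(
            f"> model.{name} received unexpected value: {value!r}, "
            f"attribute usage:\n{entry['doc']}")
    if err.messages:
        raise err


# Hand-written contract for illustration, mirroring the docstring example.
CONTRACT = {
    "n_classes": {
        "rule": "isinstance(n_classes, int) and n_classes > 1",
        "doc": "Number of label classes. Must be > 1.",
    },
    "hidden_units": {
        "rule": "isinstance(hidden_units, list) and "
                "all(isinstance(item, int) for item in hidden_units)",
        "doc": "Iterable of number hidden units per layer.",
    },
}

attribute_check(CONTRACT, n_classes=2, hidden_units=[128, 64])  # passes silently
try:
    attribute_check(CONTRACT, n_classes=1, hidden_units=64)
except DiagnosticError as e:
    print(e)
```

Collecting all failures before raising lets the user fix every misconfigured attribute in one round instead of one per run.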

Future

This design expresses check rules in native Python code. Another PR proposes a new Python library that makes the rules shorter and simpler; it will be discussed further in the future.