Training a Text Classification Model Using SQLFlow

This tutorial shows how to train a text classification model using SQLFlow. The steps may change as SQLFlow evolves; they describe one approach that works with the current version.

For custom models such as CNN-based text classification, check out the current design for ongoing development.

This tutorial uses two datasets, one for English and one for Chinese text classification. The Chinese case is more complicated because Chinese sentences cannot be segmented into words by spaces. You can download the full datasets from:

  1. IMDB-Movie-Reviews-Dataset
  2. chinese-text-classification-dataset
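
A quick Python illustration of the segmentation problem, using two made-up sentences: splitting on whitespace yields words for English, but leaves a whole Chinese sentence as a single chunk. The character-level split at the end is only a crude fallback shown for comparison; it is an assumption for this example, not necessarily what the preprocessing script in this tutorial does.

```python
# Whitespace splitting works for English but not for Chinese text.
english = "this movie is surprisingly good"
chinese = "这部电影出乎意料地好看"  # roughly: "this movie is surprisingly good"

english_tokens = english.split()
chinese_tokens = chinese.split()

print(english_tokens)  # ['this', 'movie', 'is', 'surprisingly', 'good']
print(chinese_tokens)  # ['这部电影出乎意料地好看']: the whole sentence is one "token"

# Crude fallback: one token per character. A real word segmenter
# (for example the third-party jieba package) would produce words like 电影.
chinese_chars = list(chinese)
print(len(chinese_chars))  # 11
```
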

Steps to Process and Train with the IMDB Dataset

  1. The imdb database is already loaded in our Docker image. Alternatively, you can use this script to download and preprocess the data and insert it into your own MySQL database.
  2. Use the following statements to train and predict using SQLFlow:

    %%sqlflow
    SELECT content, class FROM imdb.train
    TO TRAIN DNNClassifier
    WITH
      model.n_classes = 2,
      model.hidden_units = [128, 64]
    LABEL class
    INTO sqlflow_models.my_text_model_en;

    %%sqlflow
    SELECT *
    FROM imdb.test
    TO PREDICT imdb.predict.class
    USING sqlflow_models.my_text_model_en;
  3. You can then read the prediction results from the table imdb.predict.

Train and Predict Using Custom Keras Model

To train your own custom model written in Keras, follow the steps below:

  1. Check out our “models” repo: git clone https://github.com/sql-machine-learning/models.git
  2. Put your custom model under the sqlflow_models/ directory and add an import line for it in sqlflow_models/__init__.py. Only custom models written as Keras subclass models (subclasses of tf.keras.Model) are supported.
  3. Install the models repo on the server where you want to run the training: python setup.py install.
  4. Modify the SQL statement above to use your custom model by changing the model name to sqlflow_models.YourAwesomeModel, for example:

    %%sqlflow
    SELECT content, class
    FROM imdb.train LIMIT 100
    TO TRAIN sqlflow_models.StackedBiLSTMClassifier
    WITH
      model.n_classes = 2,
      model.stack_units = [64, 32],
      model.hidden_size = 64,
      train.epoch = 10,
      train.batch_size = 64
    COLUMN EMBEDDING(SEQ_CATEGORY_ID(content, 16000), 128, sum)
    LABEL class
    INTO sqlflow_models.my_custom_model;
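
As a rough illustration of how the WITH attributes relate to a custom model, the sketch below assumes that SQLFlow passes the model.* attributes to the model class's constructor as keyword arguments and uses the train.* attributes (epoch, batch_size) to configure training. MyAwesomeClassifier and split_with_attributes are hypothetical names, and the class is a plain-Python stand-in so that the example runs without TensorFlow; a real model would subclass tf.keras.Model and build its layers in the constructor.

```python
# Hypothetical sketch: mapping WITH-clause attributes onto a custom model.
# Assumption: model.* attributes become constructor keyword arguments and
# train.* attributes configure the training loop.


class MyAwesomeClassifier:
    """Stand-in for a tf.keras.Model subclass placed under sqlflow_models/.

    A real model would subclass tf.keras.Model and create its layers here;
    this stub only records the constructor arguments.
    """

    def __init__(self, n_classes=2, stack_units=(32,), hidden_size=64):
        self.n_classes = n_classes
        self.stack_units = list(stack_units)
        self.hidden_size = hidden_size


def split_with_attributes(attrs):
    """Split {'model.x': v, 'train.y': w} into model kwargs and train kwargs."""
    model_kwargs, train_kwargs = {}, {}
    for key, value in attrs.items():
        prefix, _, name = key.partition(".")
        if prefix == "model":
            model_kwargs[name] = value
        elif prefix == "train":
            train_kwargs[name] = value
        else:
            raise ValueError(f"unknown attribute prefix: {key}")
    return model_kwargs, train_kwargs


# The WITH clause of the StackedBiLSTMClassifier example, written as a dict.
attrs = {
    "model.n_classes": 2,
    "model.stack_units": [64, 32],
    "model.hidden_size": 64,
    "train.epoch": 10,
    "train.batch_size": 64,
}
model_kwargs, train_kwargs = split_with_attributes(attrs)
model = MyAwesomeClassifier(**model_kwargs)
print(model.stack_units, train_kwargs)  # [64, 32] {'epoch': 10, 'batch_size': 64}
```

The practical consequence under this assumption is that your model's constructor parameter names must match the model.* names used in the statement; a misspelled attribute would surface as an unexpected keyword argument. The matching import line in sqlflow_models/__init__.py would then look something like "from .my_awesome_model import MyAwesomeClassifier" (module name hypothetical).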

Steps to Process and Train with the Chinese Text Classification Dataset

  1. Download the dataset from the link above and unpack toutiao_cat_data.txt.zip.
  2. Copy toutiao_cat_data.txt to /var/lib/mysql-files/ on the server where MySQL runs, because MySQL may refuse to import data from an untrusted location (this is controlled by the secure_file_priv setting).
  3. Log in to the MySQL command line (for example, mysql -uroot -p) and create a database and tables to load the dataset. Note that the tables must be created with CHARSET=utf8 COLLATE=utf8_unicode_ci so that the Chinese text is displayed correctly.

    CREATE DATABASE toutiao;
    USE toutiao;
    CREATE TABLE `train` (
      `id` bigint(20) NOT NULL,
      `class_id` int(3) NOT NULL,
      `class_name` varchar(100) NOT NULL,
      `news_title` varchar(255) NOT NULL,
      `news_keywords` varchar(255) NOT NULL
    ) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
    CREATE TABLE `train_processed` (
      `id` bigint(20) NOT NULL,
      `class_id` int(3) NOT NULL,
      `class_name` varchar(100) NOT NULL,
      `news_title` TEXT NOT NULL,
      `news_keywords` varchar(255) NOT NULL
    ) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
    CREATE TABLE `test_processed` (
      `id` bigint(20) NOT NULL,
      `class_id` int(3) NOT NULL,
      `class_name` varchar(100) NOT NULL,
      `news_title` TEXT NOT NULL,
      `news_keywords` varchar(255) NOT NULL
    ) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
    COMMIT;
  4. In the MySQL shell, run the following statement to load the dataset into the train table:

    LOAD DATA LOCAL
    INFILE '/var/lib/mysql-files/toutiao_cat_data.txt'
    INTO TABLE toutiao.train
    CHARACTER SET utf8
    FIELDS TERMINATED BY '_!_'
    LINES TERMINATED BY '\n';
  5. Run this python script to generate a vocabulary and convert the raw news titles into padded sequences of word IDs. The maximum length of a segmented sentence is 92. Note that the script also remaps the class_id values from their original range of 100~116 to 0~16 (17 label values, matching model.n_classes = 17 below), since labels are expected to start from 0.

  6. Move 5,000 randomly chosen rows into the test_processed table, which serves as held-out data, and delete those rows from the training table:

    %%sqlflow
    INSERT INTO `toutiao`.`test_processed` (`id`, `class_id`, `class_name`, `news_title`, `news_keywords`)
    SELECT `id`, `class_id`, `class_name`, `news_title`, `news_keywords`
    FROM `toutiao`.`train_processed`
    ORDER BY RAND()
    LIMIT 5000;
    DELETE FROM `toutiao`.`train_processed`
    WHERE `id` IN (SELECT `id` FROM `toutiao`.`test_processed`);
  7. Then use the following statements to train and predict using SQLFlow:

    %%sqlflow
    SELECT news_title, class_id
    FROM toutiao.train_processed
    TO TRAIN DNNClassifier
    WITH
      model.n_classes = 17,
      model.hidden_units = [128, 512]
    LABEL class_id
    INTO sqlflow_models.my_text_model;

    %%sqlflow
    SELECT *
    FROM toutiao.test_processed
    TO PREDICT toutiao.predict.class_id
    USING sqlflow_models.my_text_model;
  8. You can then read the prediction results from the table toutiao.predict.
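
The preprocessing script referenced in step 5 is not reproduced in this tutorial. As a rough sketch of what that step involves, the Python code below parses lines in the '_!_'-separated layout read by the LOAD DATA statement, builds a vocabulary, converts each title into exactly 92 word IDs, and shifts class_id down by 100. Several details are assumptions made for illustration: titles are tokenized character by character instead of with a real Chinese word segmenter, ID 0 is reserved for padding and ID 1 for unknown tokens, the padded IDs are stored as a comma-separated string (the processed news_title column is TEXT), and the two sample lines are hypothetical. The helper names are illustrative only.

```python
from collections import Counter

MAX_LEN = 92           # maximum segmented-sentence length quoted in step 5
PAD_ID, UNK_ID = 0, 1  # assumed reserved IDs for padding and unknown tokens
LABEL_OFFSET = 100     # raw class_id values start at 100


def parse_line(line):
    """Split one raw line into the five columns of the `train` table."""
    news_id, class_id, class_name, title, keywords = line.rstrip("\n").split("_!_")
    return int(news_id), int(class_id), class_name, title, keywords


def tokenize(title):
    """Character-level tokenization (assumption: stands in for a word segmenter)."""
    return [ch for ch in title if not ch.isspace()]


def build_vocab(titles):
    """Assign IDs to tokens, most frequent first, after the two reserved IDs."""
    counts = Counter(tok for title in titles for tok in tokenize(title))
    return {tok: i + 2 for i, (tok, _) in enumerate(counts.most_common())}


def encode_title(title, vocab, max_len=MAX_LEN):
    """Map tokens to IDs, truncate to max_len, then pad with PAD_ID."""
    ids = [vocab.get(tok, UNK_ID) for tok in tokenize(title)][:max_len]
    ids += [PAD_ID] * (max_len - len(ids))
    return ",".join(str(i) for i in ids)


def to_processed_row(line, vocab):
    """Produce one row for `train_processed` with a zero-based label."""
    news_id, class_id, class_name, title, keywords = parse_line(line)
    return (news_id, class_id - LABEL_OFFSET, class_name,
            encode_title(title, vocab), keywords)


# Hypothetical sample lines in the '_!_'-separated layout.
lines = [
    "1001_!_103_!_news_sports_!_球队赢了比赛_!_球队,比赛\n",
    "1002_!_109_!_news_tech_!_新手机发布_!_手机\n",
]
vocab = build_vocab(parse_line(line)[3] for line in lines)
rows = [to_processed_row(line, vocab) for line in lines]
print(rows[0][1], rows[1][1])      # 3 9
print(len(rows[0][3].split(",")))  # 92
```

Tokens missing from the vocabulary map to the unknown ID, so a title made only of unseen characters encodes as a run of 1s followed by padding zeros.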