6-3 Model Training Using a Single GPU

Training a deep learning model is usually time-consuming. Runs that last hours or days are common, and some take tens of days.

Most of the time is spent in two stages: data preparation and parameter iteration.

If data preparation takes most of the time, we can alleviate the problem by using more processes to prepare the data.
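As a minimal sketch of this idea with `tf.data`, the `map` transformation accepts a `num_parallel_calls` argument so that several elements can be preprocessed concurrently. The `preprocess` function and the toy dataset below are placeholders for illustration and are not part of the original tutorial.

```python
import tensorflow as tf

def preprocess(x):
    # Placeholder for an expensive per-example transformation.
    return x * x

ds = (tf.data.Dataset.range(8)
      # Let tf.data choose how many elements to process in parallel.
      .map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      # Overlap data preparation with whatever consumes the dataset.
      .prefetch(tf.data.experimental.AUTOTUNE))

squares = [int(v) for v in ds.as_numpy_iterator()]
print(squares)
```

Whether this helps in practice depends on how costly the per-example work is. For a trivial function like the one above, the parallelism overhead can outweigh the gain.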

However, if parameter iteration takes most of the time, we need a GPU or a Google TPU for acceleration.

You may refer to this article for further details: “GPU acceleration for Keras Models - How to Use Free Colab GPUs (in Chinese)”

Whether you use the pre-defined fit method or a customized training loop, switching from CPU to GPU requires no change to the source code. When a GPU is available and no device is specified, TensorFlow automatically places tensor creation and computation on the GPU.
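The automatic placement can be observed directly. The sketch below lists the GPUs TensorFlow can see and prints the device on which a tensor was created. The tensor names `x` and `y` are illustrative only.

```python
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
print("GPUs found:", len(gpus))

# No device specified: TensorFlow picks one itself.
x = tf.random.uniform((2, 2))
print("x was placed on:", x.device)

# An explicit device scope overrides the automatic choice.
with tf.device("/cpu:0"):
    y = tf.random.uniform((2, 2))
print("y was placed on:", y.device)
```

On a machine without a GPU both tensors are reported on the CPU, so the same script runs unchanged in either environment.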

However, when a GPU is shared among multiple users, for example on a company or lab server, we need to add the following code to specify which GPU to use and how much GPU memory to use. By default TensorFlow claims all GPUs and all of their memory, so without this setting a single user would occupy all the GPU resources and prevent others from training on the same machine.

In a Colab notebook, select GPU under Edit -> Notebook Settings -> Hardware Accelerator.

Note: the following code runs only on Colab.

You may use the following link for testing (tf_singleGPU, in Chinese):

https://colab.research.google.com/drive/1r5dLoeJq5z01sU72BX2M5UiNSkuxsEFe

    %tensorflow_version 2.x
    import tensorflow as tf
    print(tf.__version__)

    from tensorflow.keras import *

    # Print a separator bar followed by the current time (HH:MM:SS)
    @tf.function
    def printbar():
        ts = tf.timestamp()
        today_ts = ts%(24*60*60)

        # tf.timestamp() is in UTC; +8 shifts the hour to UTC+8
        hour = tf.cast(today_ts//3600+8,tf.int32)%tf.constant(24)
        minute = tf.cast((today_ts%3600)//60,tf.int32)
        second = tf.cast(tf.floor(today_ts%60),tf.int32)

        # Left-pad single-digit values with a zero
        def timeformat(m):
            if tf.strings.length(tf.strings.format("{}",m))==1:
                return(tf.strings.format("0{}",m))
            else:
                return(tf.strings.format("{}",m))

        timestring = tf.strings.join([timeformat(hour),timeformat(minute),
                    timeformat(second)],separator = ":")
        tf.print("=========="*8,end = "")
        tf.print(timestring)
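To make the arithmetic inside `printbar` easier to follow, here is a plain-Python sketch of the same time-of-day computation. The function name `format_clock` and its `utc_offset_hours` parameter are hypothetical and exist only for this illustration.

```python
def format_clock(ts, utc_offset_hours=8):
    """Format a Unix timestamp (seconds, UTC) as HH:MM:SS in a fixed offset."""
    today = ts % (24 * 60 * 60)              # seconds elapsed since UTC midnight
    hour = (int(today // 3600) + utc_offset_hours) % 24
    minute = int((today % 3600) // 60)
    second = int(today % 60)
    # {:02d} does the zero-padding that timeformat() does in printbar
    return "{:02d}:{:02d}:{:02d}".format(hour, minute, second)

stamp = format_clock(9 * 3600 + 37 * 60 + 1)   # 09:37:01 UTC
print(stamp)                                   # 17:37:01
```

The TensorFlow version builds the string from tensor operations so that it can run inside `@tf.function`. The plain-Python version is only meant to show the logic.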

1. GPU Configuration

    gpus = tf.config.list_physical_devices("GPU")

    if gpus:
        gpu0 = gpus[0] # Use only GPU 0 when multiple GPUs are present
        tf.config.experimental.set_memory_growth(gpu0, True) # Allocate GPU memory on demand
        # Alternatively, cap the GPU memory at a fixed amount (e.g. 4 GB)
        # tf.config.experimental.set_virtual_device_configuration(gpu0,
        #     [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096)])
        tf.config.set_visible_devices([gpu0],"GPU")
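On shared servers, another widely used way to restrict a process to one GPU is the CUDA environment variable `CUDA_VISIBLE_DEVICES`. It is not used in the original tutorial and is shown here only as a hedged alternative. The index `"0"` is an assumption about which GPU has been assigned to you, and the variable needs to be set before TensorFlow initializes the GPUs.

```python
import os

# Assumption: GPU index 0 is the device you are allowed to use.
# Set this before importing TensorFlow so that CUDA only exposes that device.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# import tensorflow as tf   # import TensorFlow only after the variable is set

visible = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
print(visible)
```

This hides the other GPUs from the whole process, but it does not limit memory usage. The memory growth or memory limit settings above are still needed for that.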

Compare the computing speed of the GPU and the CPU.

    printbar()
    with tf.device("/gpu:0"):
        tf.random.set_seed(0)
        a = tf.random.uniform((10000,100),minval = 0,maxval = 3.0)
        b = tf.random.uniform((100,100000),minval = 0,maxval = 3.0)
        c = a@b
        tf.print(tf.reduce_sum(tf.reduce_sum(c,axis = 0),axis=0))
    printbar()

    ================================================================================17:37:01
    2.24953778e+11
    ================================================================================17:37:01

    printbar()
    with tf.device("/cpu:0"):
        tf.random.set_seed(0)
        a = tf.random.uniform((10000,100),minval = 0,maxval = 3.0)
        b = tf.random.uniform((100,100000),minval = 0,maxval = 3.0)
        c = a@b
        tf.print(tf.reduce_sum(tf.reduce_sum(c,axis = 0),axis=0))
    printbar()

    ================================================================================17:37:34
    2.24953795e+11
    ================================================================================17:37:40
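The printed sums can be sanity-checked by hand. Each entry of `a` and `b` is drawn uniformly from [0, 3), so its mean is 1.5, and the two matrices are drawn independently. The sketch below computes the expected value of the sum over `c` and the approximate number of floating-point operations in the product. The variable names are illustrative.

```python
m, k, n = 10000, 100, 100000      # a is (m, k), b is (k, n), so c is (m, n)

mean_entry = (0.0 + 3.0) / 2      # mean of a uniform draw on [0, 3)

# Each c[i, j] is a sum of k products of two independent entries.
expected_sum = m * n * k * mean_entry ** 2
print(expected_sum)               # 225000000000.0

# One multiply and one add per term of every inner product.
flops = 2 * m * k * n
print(flops)                      # 200000000000
```

Both printed results, about 2.2495e+11, lie within a fraction of a percent of the expected 2.25e+11. The small difference between the GPU and CPU values is plausibly due to float32 rounding and a different order of accumulation, although the tutorial itself does not state the cause.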

2. Data Preparation

    MAX_LEN = 300
    BATCH_SIZE = 32
    (x_train,y_train),(x_test,y_test) = datasets.reuters.load_data()
    x_train = preprocessing.sequence.pad_sequences(x_train,maxlen=MAX_LEN)
    x_test = preprocessing.sequence.pad_sequences(x_test,maxlen=MAX_LEN)

    MAX_WORDS = x_train.max()+1
    CAT_NUM = y_train.max()+1

    ds_train = tf.data.Dataset.from_tensor_slices((x_train,y_train)) \
              .shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
              .prefetch(tf.data.experimental.AUTOTUNE).cache()

    ds_test = tf.data.Dataset.from_tensor_slices((x_test,y_test)) \
              .shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
              .prefetch(tf.data.experimental.AUTOTUNE).cache()
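The call to `pad_sequences` above relies on its defaults, which pad and truncate at the front of each sequence with the value 0. The following plain-Python sketch mimics that default behavior for a single sequence. The helper name `pad_pre` is hypothetical and is not a Keras function.

```python
def pad_pre(seq, maxlen, value=0):
    """Mimic the pad_sequences defaults (padding='pre', truncating='pre')."""
    seq = list(seq)[-maxlen:]                     # keep only the last maxlen tokens
    return [value] * (maxlen - len(seq)) + seq    # fill the front with `value`

short_seq = pad_pre([5, 6, 7], maxlen=5)
long_seq = pad_pre([1, 2, 3, 4, 5, 6, 7], maxlen=5)
print(short_seq)   # [0, 0, 5, 6, 7]
print(long_seq)    # [3, 4, 5, 6, 7]
```

With `MAX_LEN = 300`, every Reuters article therefore becomes a vector of exactly 300 word indices, which is what lets `from_tensor_slices` treat the data as one rectangular array.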

3. Model Definition

    tf.keras.backend.clear_session()

    def create_model():

        model = models.Sequential()

        model.add(layers.Embedding(MAX_WORDS,7,input_length=MAX_LEN))
        model.add(layers.Conv1D(filters = 64,kernel_size = 5,activation = "relu"))
        model.add(layers.MaxPool1D(2))
        model.add(layers.Conv1D(filters = 32,kernel_size = 3,activation = "relu"))
        model.add(layers.MaxPool1D(2))
        model.add(layers.Flatten())
        model.add(layers.Dense(CAT_NUM,activation = "softmax"))
        return(model)

    model = create_model()
    model.summary()
    Model: "sequential"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    embedding (Embedding)        (None, 300, 7)            216874
    _________________________________________________________________
    conv1d (Conv1D)              (None, 296, 64)           2304
    _________________________________________________________________
    max_pooling1d (MaxPooling1D) (None, 148, 64)           0
    _________________________________________________________________
    conv1d_1 (Conv1D)            (None, 146, 32)           6176
    _________________________________________________________________
    max_pooling1d_1 (MaxPooling1 (None, 73, 32)            0
    _________________________________________________________________
    flatten (Flatten)            (None, 2336)              0
    _________________________________________________________________
    dense (Dense)                (None, 46)                107502
    =================================================================
    Total params: 332,856
    Trainable params: 332,856
    Non-trainable params: 0
    _________________________________________________________________
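The shapes and parameter counts in the summary can be reproduced with a few lines of arithmetic. The sketch below assumes the `Conv1D` defaults used in the model ("valid" padding, stride 1). The helper names are hypothetical, and the vocabulary size is inferred from the summary rather than stated in the text.

```python
def conv1d_len(length, kernel_size):
    # Output length for "valid" padding and stride 1
    return length - kernel_size + 1

def conv1d_params(in_channels, filters, kernel_size):
    # One kernel of shape (kernel_size, in_channels) per filter, plus one bias each
    return in_channels * kernel_size * filters + filters

EMBED_DIM, MAX_LEN, CAT_NUM = 7, 300, 46
max_words = 216874 // EMBED_DIM             # 30982, implied by the embedding row

length = conv1d_len(MAX_LEN, 5)             # 296
p_conv1 = conv1d_params(EMBED_DIM, 64, 5)   # 2304
length //= 2                                # 148 after MaxPool1D(2)
length = conv1d_len(length, 3)              # 146
p_conv2 = conv1d_params(64, 32, 3)          # 6176
length //= 2                                # 73 after MaxPool1D(2)
flat = length * 32                          # 2336
p_dense = flat * CAT_NUM + CAT_NUM          # 107502

total = max_words * EMBED_DIM + p_conv1 + p_conv2 + p_dense
print(flat, total)                          # 2336 332856
```

Roughly two thirds of the parameters sit in the embedding layer, so the vocabulary size is the main driver of model size here.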

4. Model Training

    optimizer = optimizers.Nadam()
    loss_func = losses.SparseCategoricalCrossentropy()

    train_loss = metrics.Mean(name='train_loss')
    train_metric = metrics.SparseCategoricalAccuracy(name='train_accuracy')

    valid_loss = metrics.Mean(name='valid_loss')
    valid_metric = metrics.SparseCategoricalAccuracy(name='valid_accuracy')

    # One optimization step on a single batch
    @tf.function
    def train_step(model, features, labels):
        with tf.GradientTape() as tape:
            predictions = model(features,training = True)
            loss = loss_func(labels, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        train_loss.update_state(loss)
        train_metric.update_state(labels, predictions)

    # Evaluate a single batch without updating the weights
    @tf.function
    def valid_step(model, features, labels):
        predictions = model(features)
        batch_loss = loss_func(labels, predictions)
        valid_loss.update_state(batch_loss)
        valid_metric.update_state(labels, predictions)

    def train_model(model,ds_train,ds_valid,epochs):
        for epoch in tf.range(1,epochs+1):

            for features, labels in ds_train:
                train_step(model,features,labels)

            for features, labels in ds_valid:
                valid_step(model,features,labels)

            logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}'

            if epoch%1 ==0:
                printbar()
                tf.print(tf.strings.format(logs,
                (epoch,train_loss.result(),train_metric.result(),valid_loss.result(),valid_metric.result())))
                tf.print("")

            # Clear the metrics before the next epoch
            # (newer Keras releases name this method reset_state())
            train_loss.reset_states()
            valid_loss.reset_states()
            train_metric.reset_states()
            valid_metric.reset_states()

    train_model(model,ds_train,ds_test,10)
    ================================================================================17:13:26
    Epoch=1,Loss:1.96735072,Accuracy:0.489200622,Valid Loss:1.64124215,Valid Accuracy:0.582813919

    ================================================================================17:13:28
    Epoch=2,Loss:1.4640888,Accuracy:0.624805152,Valid Loss:1.5559175,Valid Accuracy:0.607747078

    ================================================================================17:13:30
    Epoch=3,Loss:1.20681274,Accuracy:0.68581605,Valid Loss:1.58494771,Valid Accuracy:0.622439921

    ================================================================================17:13:31
    Epoch=4,Loss:0.937500894,Accuracy:0.75361836,Valid Loss:1.77466083,Valid Accuracy:0.621994674

    ================================================================================17:13:33
    Epoch=5,Loss:0.693960547,Accuracy:0.822199941,Valid Loss:2.00267363,Valid Accuracy:0.6197685

    ================================================================================17:13:35
    Epoch=6,Loss:0.519614,Accuracy:0.870296121,Valid Loss:2.23463202,Valid Accuracy:0.613980412

    ================================================================================17:13:37
    Epoch=7,Loss:0.408562034,Accuracy:0.901246965,Valid Loss:2.46969271,Valid Accuracy:0.612199485

    ================================================================================17:13:39
    Epoch=8,Loss:0.339028627,Accuracy:0.920062363,Valid Loss:2.68585229,Valid Accuracy:0.615316093

    ================================================================================17:13:41
    Epoch=9,Loss:0.293798745,Accuracy:0.92930305,Valid Loss:2.88995624,Valid Accuracy:0.613535166

    ================================================================================17:13:43
    Epoch=10,Loss:0.263130337,Accuracy:0.936651051,Valid Loss:3.09705234,Valid Accuracy:0.612644672
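The training loop relies on the update, result, and reset cycle of `metrics.Mean`: values are accumulated batch by batch, read once per epoch, and then cleared. The plain-Python class below is a simplified sketch of that pattern, not the Keras implementation, and the name `RunningMean` is hypothetical.

```python
class RunningMean:
    """Simplified stand-in for the update/result/reset cycle of a mean metric."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        # Called once per batch with that batch's loss
        self.total += value
        self.count += 1

    def result(self):
        # Mean over everything seen since the last reset (0.0 if nothing seen)
        return self.total / self.count if self.count else 0.0

    def reset_states(self):
        # Called at the end of each epoch
        self.total = 0.0
        self.count = 0

epoch_loss = RunningMean()
for batch_loss in [2.0, 1.0]:
    epoch_loss.update_state(batch_loss)
first_epoch = epoch_loss.result()
print(first_epoch)            # 1.5

epoch_loss.reset_states()
print(epoch_loss.result())    # 0.0
```

Without the reset at the end of each epoch, the reported loss would be an average over all epochs so far rather than over the current one.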

If you would like to discuss the content with the author, please leave a comment in the WeChat official account “Python与算法之美” (Elegance of Python and Algorithms). The author will do his best to reply within the limited time available.

You are also welcome to join the readers' group chat by replying 加群 (join group) in the WeChat official account.
