Design Doc: Synchronous SGD

  1. # Basic design philosophy:
  2. #
  3. # - Workers are gRPC clients; the master is the gRPC server.
  4. # - The master keeps three task queues: todo, doing, and done.
  5. # - The master keeps the current global model and the model ID is the hash of the model parameters.
  6. #------- worker.py -------#
  7. model_params = NULL
  8. model_version = NULL
  9. module = parse_command_line_for_class_name()(
  10. parse_command_line_for_ctor_params())
  11. master = grpc.create_client(getenv("MASTER_ADDR"))
  12. while True:
  13. # Claim a task from the master. A task consists of a data segment and the
  14. # model id. The meaning of a task is "to update the specified model with
  15. # the given data segment".
  16. task, err = master.GetTask()
  17. if err == NO_MORE_TASK:
  18. break # Training completed.
  19. if err != NULL:
  20. continue # Retry getting a task.
  21. task_status = SUCCEED
  22. for minibatch in read_data(task):
  23. accepted = False
  24. report_count = 0
  25. while not accepted:
  26. try:
  27. # If the current model_version on the worker is older than the model version
  28. # on the master, this call updates model_version and
  29. # model_params; otherwise, it leaves these two variables unchanged.
  30. master.UpdateModelIfOutOfDate(&model_version, &model_params)
  31. cost = module.forward(minibatch, model_params)
  32. gradients = module.backward(cost, model_params)
  33. except:
  34. task_status = FAILED
  35. break
  36. else:
  37. # If the master rejects the reported gradients because model_version is stale,
  38. # retry the minibatch with the updated model in the next iteration of the while loop.
  39. # Fail the task if the minibatch report count reaches a predefined threshold.
  40. accepted = master.ReportGradients(model_version, gradients)
  41. if not accepted:
  42. report_count += 1
  43. if report_count == PREDEFINED_MAX_REPORT_COUNT:
  44. task_status = FAILED
  45. break
  46. if task_status == FAILED:
  47. break
  48. master.ReportTask(task, task_status)
  49. #------- master.py -------#
  50. inputs = parse_command_line().recordio_files() # No online learning in this version.
  51. todo = partition_data_into_tasks(inputs)
  52. doing = Queue()
  53. done = Queue()
  54. module = parse_command_line_for_class_name()(
  55. parse_command_line_for_ctor_params())
  56. model_params = module.create_parameters().random_initialize()
  57. model_version = 0
  58. gradients = []
  59. @grpc
  60. def UpdateModelIfOutOfDate(mv, mp):
  61. if model_version != mv:
  62. copy(*mv, model_version)
  63. copy(*mp, model_params)
  64. @grpc
  65. def GetTask():
  66. task = todo.pop()
  67. doing.push(task)
  68. return task
  69. @grpc
  70. def ReportGradients(mv, grads):
  71. accepted = False
  72. if mv == model_version:
  73. gradients += grads
  74. accepted = True
  75. if len(gradients) >= num_gradients_sufficient_to_update_model():
  76. model_params = optimize_model(model_params, gradients)
  77. model_version = model_version + 1
  78. gradients = [] # Clear out the buffer.
  79. return accepted
  80. @grpc
  81. def ReportTask(task, status):
  82. if status == FAILED:
  83. move_task(task, doing, todo) # Move the task from doing back to todo
  84. else:
  85. move_task(task, doing, done) # Move the task from doing to done