UtilBase

class paddle.distributed.fleet. UtilBase [源代码]

分布式训练工具类,主要提供集合通信、文件系统操作等接口。

all_reduce ( input, mode=’sum’, comm_world=’worker’ )

在指定的通信集合间进行归约操作,并将归约结果返回给集合中每个实例。

参数:

  • input (list|numpy.array) – 归约操作的输入。

  • mode (str) - 归约操作的模式,包含求和,取最大值和取最小值,默认为求和归约。

  • comm_world (str) - 归约操作的通信集合,包含: server集合(“server”),worker集合(“worker”)及所有节点集合(“all”),默认为worker集合。

返回:

  • Numpy.array|None: 一个和 input 形状一致的numpy数组或None.

代码示例:

  1. # Save the following code in `train.py` , and then execute the command `fleetrun --server_num 2 --worker_num 2 train.py` .
  2. import paddle.distributed.fleet as fleet
  3. from paddle.distributed.fleet import PaddleCloudRoleMaker
  4. import sys
  5. import numpy as np
  6. import os
  7. os.environ["PADDLE_WITH_GLOO"] = "2"
  8. def train():
  9. role = PaddleCloudRoleMaker(
  10. is_collective=False,
  11. init_gloo=True,
  12. path="./tmp_gloo")
  13. fleet.init(role)
  14. if fleet.is_server():
  15. input = [1, 2]
  16. output = fleet.util.all_reduce(input, "sum", "server")
  17. print(output)
  18. # [2, 4]
  19. elif fleet.is_worker():
  20. input = np.array([3, 4])
  21. output = fleet.util.all_reduce(input, "sum", "worker")
  22. print(output)
  23. # [6, 8]
  24. output = fleet.util.all_reduce(input, "sum", "all")
  25. print(output)
  26. # [8, 12]
  27. if __name__ == "__main__":
  28. train()

barrier ( comm_world=’worker’ )

在指定的通信集合间进行阻塞操作,以实现集合间进度同步。

参数:

  • comm_world (str) - 阻塞操作的通信集合,包含: server集合(“server”),worker集合(“worker”)及所有节点集合(“all”),默认为worker集合。

代码示例:

  1. # Save the following code in `train.py` , and then execute the command `fleetrun --server_num 2 --worker_num 2 train.py` .
  2. import paddle.distributed.fleet as fleet
  3. from paddle.distributed.fleet import PaddleCloudRoleMaker
  4. import sys
  5. import os
  6. os.environ["PADDLE_WITH_GLOO"] = "2"
  7. def train():
  8. role = PaddleCloudRoleMaker(
  9. is_collective=False,
  10. init_gloo=True,
  11. path="./tmp_gloo")
  12. fleet.init(role)
  13. if fleet.is_server():
  14. fleet.util.barrier("server")
  15. print("all server arrive here")
  16. elif fleet.is_worker():
  17. fleet.util.barrier("worker")
  18. print("all server arrive here")
  19. fleet.util.barrier("all")
  20. print("all servers and workers arrive here")
  21. if __name__ == "__main__":
  22. train()

all_gather ( input, comm_world=’worker’ )

在指定的通信集合间进行聚合操作,并将聚合的结果返回给集合中每个实例。

参数:

  • input (int|float) - 聚合操作的输入。

  • comm_world (str) - 聚合操作的通信集合,包含: server集合(“server”),worker集合(“worker”)及所有节点集合(“all”),默认为worker集合。

返回:

  • output (List): List格式的聚合结果。

代码示例:

  1. # Save the following code in `train.py` , and then execute the command `fleetrun --server_num 2 --worker_num 2 train.py` .
  2. import paddle.distributed.fleet as fleet
  3. from paddle.distributed.fleet import PaddleCloudRoleMaker
  4. import sys
  5. import os
  6. os.environ["PADDLE_WITH_GLOO"] = "2"
  7. def train():
  8. role = PaddleCloudRoleMaker(
  9. is_collective=False,
  10. init_gloo=True,
  11. path="./tmp_gloo")
  12. fleet.init(role)
  13. if fleet.is_server():
  14. input = fleet.server_index()
  15. output = fleet.util.all_gather(input, "server")
  16. print(output)
  17. # output = [0, 1]
  18. elif fleet.is_worker():
  19. input = fleet.worker_index()
  20. output = fleet.util.all_gather(input, "worker")
  21. # output = [0, 1]
  22. print(output)
  23. output = fleet.util.all_gather(input, "all")
  24. print(output)
  25. # output = [0, 1, 0, 1]
  26. if __name__ == "__main__":
  27. train()

get_file_shard ( files )

在数据并行的分布式训练中,获取属于当前训练节点的文件列表。

  1. 示例 1: 原始所有文件列表 `files` = [a, b, c ,d, e],训练节点个数 `trainer_num` = 2,那么属于零号节点的训练文件为[a, b, c],属于1号节点的训练文件为[d, e]。
  2. 示例 2: 原始所有文件列表 `files` = [a, b],训练节点个数 `trainer_num` = 3,那么属于零号节点的训练文件为[a],属于1号节点的训练文件为[b],属于2号节点的训练文件为[]。

参数:

  • files (List):原始所有文件列表。

返回:

  • List: 属于当前训练节点的文件列表。

代码示例:

  1. import paddle.distributed.fleet as fleet
  2. import paddle.distributed.fleet.base.role_maker as role_maker
  3. role = role_maker.UserDefinedRoleMaker(
  4. is_collective=False,
  5. init_gloo=False,
  6. current_id=0,
  7. role=role_maker.Role.WORKER,
  8. worker_endpoints=["127.0.0.1:6003", "127.0.0.1:6004"],
  9. server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
  10. fleet.init(role)
  11. files = fleet.util.get_file_shard(["file1", "file2", "file3"])
  12. print(files)
  13. # files = ["file1", "file2"]

print_on_rank ( message, rank_id )

在编号为 rank_id 的节点上打印指定信息。

参数:

  • message (str) – 打印内容。

  • rank_id (int) - 节点编号。

代码示例:

  1. import paddle.distributed.fleet as fleet
  2. import paddle.distributed.fleet.base.role_maker as role_maker
  3. role = role_maker.UserDefinedRoleMaker(
  4. is_collective=False,
  5. init_gloo=False,
  6. current_id=0,
  7. role=role_maker.Role.WORKER,
  8. worker_endpoints=["127.0.0.1:6003", "127.0.0.1:6004"],
  9. server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
  10. fleet.init(role)
  11. fleet.util.print_on_rank("I'm worker 0", 0)