分布式数据并行(DDP)入门

作者Shen Li

译者Hamish

校验Hamish

DistributedDataParallel(DDP)在模块级别实现数据并行性。它使用torch.distributed包中的通信集合体来同步梯度,参数和缓冲区。并行性在流程内和跨流程均可用。在一个过程中,DDP将输入模块复制到device_ids中指定的设备,相应地沿批处理维度分散输入,并将输出收集到output_device,这与DataParallel相似。在整个过程中,DDP在正向传递中插入必要的参数同步,在反向传递中插入梯度同步。用户可以将进程映射到可用资源,只要进程不共享GPU设备即可。推荐的方法(通常是最快的方法)是为每个模块副本创建一个过程,即在一个过程中不进行任何模块复制。本教程中的代码在8-GPU服务器上运行,但可以轻松地推广到其他环境。

DataParallelDistributedDataParallel之间的比较

在深入研究之前,让我们澄清一下为什么,尽管增加了复杂性,您还是会考虑使用DistributedDataParallel而不是DataParallel

  • 首先,回想一下之前的教程,如果模型太大,无法被单个GPU容纳,则必须使用模型并行化将其拆分至多个GPU。DistributedDataParallel可以与模型并行化一起工作;DataParallel此时不工作。
  • DataParallel是单进程、多线程的,并且只在一台机器上工作;而DistributedDataParallel是多进程的,可用于单机和多机训练。因此,即使对于单机训练,数据足够小,可以放在一台机器上,DistributedDataParallel也会比DataParallel更快。DistributedDataParallel还可以预先复制模型,而不是在每次迭代时复制模型,从而可以避免全局解释器锁定。
  • 如果您的数据太大,无法在一台机器上容纳,并且您的模型也太大,无法在单个GPU上容纳,则可以将模型并行化(跨多个GPU拆分单个模型)与DistributedDataParallel结合起来。在这种机制下,每个DistributedDataParallel进程都可以使用模型并行化,同时所有进程都可以使用数据并行。

基本用例

要创建DDP模块,请首先正确设置进程组。更多细节可以在使用PyTorch编写分布式应用程序中找到。

  1. import os
  2. import tempfile
  3. import torch
  4. import torch.distributed as dist
  5. import torch.nn as nn
  6. import torch.optim as optim
  7. import torch.multiprocessing as mp
  8. from torch.nn.parallel import DistributedDataParallel as DDP
  9. def setup(rank, world_size):
  10. os.environ['MASTER_ADDR'] = 'localhost'
  11. os.environ['MASTER_PORT'] = '12355'
  12. # initialize the process group
  13. dist.init_process_group("gloo", rank=rank, world_size=world_size)
  14. # Explicitly setting seed to make sure that models created in two processes
  15. # start from same random weights and biases.
  16. torch.manual_seed(42)
  17. def cleanup():
  18. dist.destroy_process_group()

现在,让我们创建一个玩具模块,用DDP包装它,并用一些虚拟输入数据给它输入。请注意,如果训练是从随机参数开始的,您可能需要确保所有DDP进程使用相同的初始值。否则,全局梯度同步将没有意义。

  1. class ToyModel(nn.Module):
  2. def __init__(self):
  3. super(ToyModel, self).__init__()
  4. self.net1 = nn.Linear(10, 10)
  5. self.relu = nn.ReLU()
  6. self.net2 = nn.Linear(10, 5)
  7. def forward(self, x):
  8. return self.net2(self.relu(self.net1(x)))
  9. def demo_basic(rank, world_size):
  10. setup(rank, world_size)
  11. # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
  12. # rank 2 uses GPUs [4, 5, 6, 7].
  13. n = torch.cuda.device_count() // world_size
  14. device_ids = list(range(rank * n, (rank + 1) * n))
  15. # create model and move it to device_ids[0]
  16. model = ToyModel().to(device_ids[0])
  17. # output_device defaults to device_ids[0]
  18. ddp_model = DDP(model, device_ids=device_ids)
  19. loss_fn = nn.MSELoss()
  20. optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
  21. optimizer.zero_grad()
  22. outputs = ddp_model(torch.randn(20, 10))
  23. labels = torch.randn(20, 5).to(device_ids[0])
  24. loss_fn(outputs, labels).backward()
  25. optimizer.step()
  26. cleanup()
  27. def run_demo(demo_fn, world_size):
  28. mp.spawn(demo_fn,
  29. args=(world_size,),
  30. nprocs=world_size,
  31. join=True)

如您所见,DDP包装了较低级别的分布式通信细节,并提供了一个干净的API,就好像它是一个本地模型一样。对于基本用例,DDP只需要几个loc来设置流程组。在将DDP应用于更高级的用例时,需要注意一些注意事项。

不均衡的处理速度

在DDP中,构造函数、前向方法和输出的微分是分布式同步点。不同的进程将以相同的顺序到达同步点,并在大致相同的时间进入每个同步点。否则,快速进程可能会提前到达,并在等待散乱的进程时超时。因此,用户需要负责跨进程平衡工作负载的分配。有时,由于网络延迟、资源竞争、不可预测的工作量高峰,不均衡的处理速度是不可避免的。要避免在这些情况下超时,请确保在调用init_process_group时传递足够大的timeout值。

保存和载入检查点

在训练过程中,经常使用torch.savetorch.load为模块创建检查点,以及从检查点恢复。有关的详细信息,请参见保存和加载模型。在使用DDP时,一种优化方法是只在一个进程中保存模型,然后将其加载到所有进程中,从而减少写开销。这是正确的,因为所有进程都是从相同的参数开始的,并且梯度在反向过程中是同步的,因此优化器应该将参数设置为相同的值。如果使用这种优化方法,请确保在保存完成之前,所有进程都不会开始加载。此外,加载模块时,需要提供适当的map_location参数,以防止进程进入其他设备。如果缺少map_locationtorch.load将首先将模块加载到CPU,然后将每个参数复制到其保存的位置,这将导致同一台计算机上的所有进程使用同一组设备。

  1. def demo_checkpoint(rank, world_size):
  2. setup(rank, world_size)
  3. # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
  4. # rank 2 uses GPUs [4, 5, 6, 7].
  5. n = torch.cuda.device_count() // world_size
  6. device_ids = list(range(rank * n, (rank + 1) * n))
  7. model = ToyModel().to(device_ids[0])
  8. # output_device defaults to device_ids[0]
  9. ddp_model = DDP(model, device_ids=device_ids)
  10. loss_fn = nn.MSELoss()
  11. optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
  12. CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
  13. if rank == 0:
  14. # All processes should see same parameters as they all start from same
  15. # random parameters and gradients are synchronized in backward passes.
  16. # Therefore, saving it in one process is sufficient.
  17. torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)
  18. # Use a barrier() to make sure that process 1 loads the model after process
  19. # 0 saves it.
  20. dist.barrier()
  21. # configure map_location properly
  22. rank0_devices = [x - rank * len(device_ids) for x in device_ids]
  23. device_pairs = zip(rank0_devices, device_ids)
  24. map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
  25. ddp_model.load_state_dict(
  26. torch.load(CHECKPOINT_PATH, map_location=map_location))
  27. optimizer.zero_grad()
  28. outputs = ddp_model(torch.randn(20, 10))
  29. labels = torch.randn(20, 5).to(device_ids[0])
  30. loss_fn = nn.MSELoss()
  31. loss_fn(outputs, labels).backward()
  32. optimizer.step()
  33. # Use a barrier() to make sure that all processes have finished reading the
  34. # checkpoint
  35. dist.barrier()
  36. if rank == 0:
  37. os.remove(CHECKPOINT_PATH)
  38. cleanup()

结合DDP与模型并行化

DDP也适用于多GPU模型,但不支持进程内的复制。您需要为每个模块副本创建一个进程,这通常会比每个进程创建多个副本带来更好的性能。当使用大量数据训练大型模型时,DDP包装多GPU模型尤其有用。使用此功能时,需要小心地实现多GPU模型,以避免硬编码设备,因为不同的模型副本将被放置到不同的设备上。

  1. class ToyMpModel(nn.Module):
  2. def __init__(self, dev0, dev1):
  3. super(ToyMpModel, self).__init__()
  4. self.dev0 = dev0
  5. self.dev1 = dev1
  6. self.net1 = torch.nn.Linear(10, 10).to(dev0)
  7. self.relu = torch.nn.ReLU()
  8. self.net2 = torch.nn.Linear(10, 5).to(dev1)
  9. def forward(self, x):
  10. x = x.to(self.dev0)
  11. x = self.relu(self.net1(x))
  12. x = x.to(self.dev1)
  13. return self.net2(x)

将多GPU模型传递给DDP时,不能设置device_idsoutput_device。输入和输出数据将由应用程序或模型forward()方法放置在适当的设备中。

  1. def demo_model_parallel(rank, world_size):
  2. setup(rank, world_size)
  3. # setup mp_model and devices for this process
  4. dev0 = rank * 2
  5. dev1 = rank * 2 + 1
  6. mp_model = ToyMpModel(dev0, dev1)
  7. ddp_mp_model = DDP(mp_model)
  8. loss_fn = nn.MSELoss()
  9. optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)
  10. optimizer.zero_grad()
  11. # outputs will be on dev1
  12. outputs = ddp_mp_model(torch.randn(20, 10))
  13. labels = torch.randn(20, 5).to(dev1)
  14. loss_fn(outputs, labels).backward()
  15. optimizer.step()
  16. cleanup()
  17. if __name__ == "__main__":
  18. run_demo(demo_basic, 2)
  19. run_demo(demo_checkpoint, 2)
  20. if torch.cuda.device_count() >= 8:
  21. run_demo(demo_model_parallel, 4)