三天半的清明假期开始啦!先来总结下PyTorch分布式训练相关的内容。
如GETTING STARTED WITH DISTRIBUTED DATA PARALLEL,DistributedDataParallel(DDP)可以跨多台机器运行,实现数据并行。DDP使用torch.distributed包中的集体通信来同步梯度和buffer。
研究生的时候一直使用的是DataParallel,但是DistributedDataParallel的效率更高,主要区别如下:
- DataParallel是单进程、多线程,并且只在单机上工作,而DistributedDataParallel是多进程的,适用于单机和多机训练。
- 由于跨线程的GIL争用、每次迭代的复制模型以及分散输入和收集输出引入的额外开销,即使在单台机器上,DataParallel通常也比DistributedDataParallel慢。
- DistributedDataParallel可以和model parallel一起使用,而DataParallel不可以。
因此使用DistributedDataParallel是非常有必要的!
下面就来介绍一下使用方法(TCP初始化方式)
输入参数
1 | def myargs(): |
batch_size
指一台机子上所有GPU的总batch_size;dist_url
为主机的ip:port,port未被占用,当只有一台机子时,使用本机ip 127.0.0.1;dist_backend
一般设置为nccl;world_size
为使用的机器总数;rank
为机器序号;gpu
不需要指定,表示当前机子当前进程的局部rank号,torch.multiprocessing.spawn会自动将其传入main_worker函数。
运行样例
1 | # 单机4GPU运行 |
训练代码
- 主函数,通过spawn开启多个进程,一个进程一个GPU,传入main_worker函数的参数为(当前局部rank号可以视为GPU号,args)。
1 | if __name__ == '__main__': |
- main函数,每个进程都会执行
1 | def main_worker(gpu, ngpus_per_node, args): |