import argparse
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing
import torch.utils.data
import torch.utils.data.distributed
from torchvision.models import resnet50


def myargs():
    parser = argparse.ArgumentParser(description='training example')
    parser.add_argument('--batch_size', default=256, type=int,
                        help='total batch size of all GPUs on the current node when using Distributed Data Parallel')
    parser.add_argument('--workers', default=8, type=int,
                        help='number of data loading workers')
    parser.add_argument('--epochs', default=100, type=int,
                        help='number of total epochs to run')
    parser.add_argument('--seed', default=20, type=int,
                        help='seed for initializing training')
    parser.add_argument('--dist_url', default='tcp://127.0.0.1:23456', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist_backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use')
    return parser.parse_args()
def main_worker(gpu, ngpus_per_node, args):
    # Seed every RNG so each process starts from the same initial weights.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    args.gpu = gpu
    print("Use GPU: {} for distributed training".format(args.gpu))

    # Global rank = node rank * GPUs per node + local GPU index.
    args.rank = args.rank * ngpus_per_node + gpu
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    # One model replica per process, pinned to its own GPU.
    model = resnet50()
    torch.cuda.set_device(args.gpu)
    model.cuda(args.gpu)

    # Split the per-node batch size and data-loading workers across this node's GPUs.
    args.batch_size = int(args.batch_size / ngpus_per_node)
    args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
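    # NOTE: the original listing uses `train_dataset` without defining it. The
    # placeholder below (torchvision's FakeData) is only an assumption so the
    # script runs end to end; replace it with your real dataset and transforms.
    from torchvision import datasets, transforms
    train_dataset = datasets.FakeData(size=1024, image_size=(3, 224, 224),
                                      num_classes=1000,
                                      transform=transforms.ToTensor())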
    criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.SGD(model.parameters(), 0.1,
                                momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    # DistributedSampler shards the dataset across processes; shuffling is done
    # by the sampler, so the DataLoader itself must not shuffle.
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                               shuffle=False, sampler=train_sampler,
                                               num_workers=args.workers, pin_memory=True,
                                               drop_last=True)

    # Gradient scaler for mixed-precision (AMP) training.
    torch_scaler = torch.cuda.amp.GradScaler()
    for epoch in range(args.epochs):
        # Re-seed the sampler each epoch so every process sees a different shuffle.
        train_sampler.set_epoch(epoch)
        model.train()
        for i, (images, target) in enumerate(train_loader):
            optimizer.zero_grad()
            images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # Forward pass in mixed precision; scale the loss before backward
            # to avoid underflow in the fp16 gradients.
            with torch.cuda.amp.autocast():
                output = model(images)
                loss = criterion(output, target)
            torch_scaler.scale(loss).backward()
            torch_scaler.step(optimizer)
            torch_scaler.update()
        scheduler.step()

        # Save the checkpoint from one process per node only, and unwrap the
        # DDP container via `model.module` before saving.
        if args.rank % ngpus_per_node == 0:
            torch.save(model.module.state_dict(), 'checkpoint.pth.tar')
    torch.cuda.empty_cache()
if __name__ == '__main__':
    args = myargs()
    # One process per GPU on this node; world_size becomes the total number of
    # processes across all nodes.
    ngpus_per_node = torch.cuda.device_count()
    args.world_size = ngpus_per_node * args.world_size
    torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node,
                                args=(ngpus_per_node, args))
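# A minimal launch sketch (the filename train_ddp.py and the node address are
# illustrative assumptions, not from the original listing):
#   single node:  python train_ddp.py --batch_size 256 --epochs 100
#   two nodes:    python train_ddp.py --world_size 2 --rank 0 --dist_url tcp://NODE0_IP:23456   # on node 0
#                 python train_ddp.py --world_size 2 --rank 1 --dist_url tcp://NODE0_IP:23456   # on node 1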