def main_worker(local_rank, args):
    """Per-process entry point for distributed (DDP) Xception training.

    Intended to be spawned once per GPU (e.g. via ``torch.multiprocessing.spawn``).
    Initializes the NCCL process group, builds the model/optimizer/dataloaders,
    runs ``cfg.epoch_size`` training steps through ``Trainer``, and tears the
    process group down on exit.

    Args:
        local_rank: Rank of this process on the local node; also used as the
            CUDA device index and the global rank (single-node assumption).
        args: Parsed CLI namespace; must provide ``world_size``, ``data_path``
            and ``ckpt_path``. ``args.local_rank`` is overwritten here.

    Side effects: writes TensorBoard summaries (rank 0 only) and checkpoints.
    """
    import os  # local import so this edit is self-contained

    args.local_rank = local_rank

    # Prepare the distributed environment. NOTE(review): rank == local_rank
    # assumes a single-node setup — confirm if multi-node is ever used.
    dist.init_process_group(backend='nccl',
                            rank=args.local_rank,
                            world_size=args.world_size)
    torch.cuda.set_device(args.local_rank)

    try:
        network = Xception(num_classes=cfg.num_classes)
        network = network.cuda()
        network = torch.nn.parallel.DistributedDataParallel(
            network, device_ids=[args.local_rank])

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(network.parameters(),
                              lr=cfg.lr_init,
                              momentum=cfg.SGD_momentum)

        # os.path.join tolerates a missing trailing separator in --data_path
        # (plain concatenation would yield e.g. '/datatrain').
        dataloader_train = create_dataset_pytorch_imagenet_dist_train(
            data_path=os.path.join(args.data_path, 'train'),
            local_rank=local_rank,
            n_workers=cfg.n_workers)
        dataloader_test = create_dataset_pytorch_imagenet(
            data_path=os.path.join(args.data_path, 'val'),
            is_train=False,
            n_workers=cfg.n_workers)

        # Decay the LR every cfg.lr_decay_epoch epochs, expressed in optimizer
        # steps — presumably the scheduler is stepped once per batch inside
        # Trainer; TODO confirm against Trainer.step().
        step_per_epoch = len(dataloader_train)
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            gamma=cfg.lr_decay_rate,
            step_size=cfg.lr_decay_epoch * step_per_epoch)

        # Only rank 0 writes TensorBoard summaries to avoid duplicate logs.
        summary_writer = None
        if local_rank == 0:
            summary_writer = SummaryWriter(log_dir='./summary')

        trainer = Trainer(network=network,
                          criterion=criterion,
                          optimizer=optimizer,
                          scheduler=scheduler,
                          dataloader_train=dataloader_train,
                          dataloader_test=dataloader_test,
                          summary_writer=summary_writer,
                          epoch_size=cfg.epoch_size,
                          ckpt_path=args.ckpt_path,
                          local_rank=local_rank)

        for _ in range(cfg.epoch_size):
            trainer.step()

        if local_rank == 0:
            summary_writer.close()
    finally:
        # Original leaked the NCCL process group; release it even on error.
        dist.destroy_process_group()
help='path where the dataset is saved') parser.add_argument('--ckpt_path', type=str, default="./checkpoint", help='path where the checkpoint to be saved') parser.add_argument('--device_id', type=int, default=0, help='device id of GPU. (Default: 0)') args = parser.parse_args() args.local_rank = 0 args.world_size = 1 network = Xception(num_classes=cfg.num_classes) # network = nn.DataParallel(network) network = network.cuda() criterion = nn.CrossEntropyLoss() # optimizer = optim.RMSprop(network.parameters(), # lr=cfg.lr_init, # eps=cfg.rmsprop_epsilon, # momentum=cfg.rmsprop_momentum, # alpha=cfg.rmsprop_decay) optimizer = optim.SGD(network.parameters(), lr=cfg.lr_init, momentum=cfg.SGD_momentum) # prepare data # dataloader = create_dataset_pytorch(args.data_path + "/train") pipe = HybridTrainPipe(batch_size=cfg.batch_size, num_threads=cfg.n_workers, device_id=args.local_rank, data_dir=args.data_path,