コード例 #1
0
ファイル: train.py プロジェクト: samuellees/ms_models
def main_worker(local_rank, args):
    """Run distributed Xception training in one worker process.

    Initializes an NCCL process group, pins this process to its GPU, wraps
    the network in ``DistributedDataParallel`` and runs ``cfg.epoch_size``
    calls to ``Trainer.step()``. Only rank 0 creates a TensorBoard
    ``SummaryWriter`` (logging to ``./summary``).

    Args:
        local_rank: Index of this worker. It is used both as the process
            rank passed to ``init_process_group`` and as the CUDA device id.
        args: Namespace that must provide ``world_size``, ``data_path`` and
            ``ckpt_path``. ``args.local_rank`` is overwritten with
            ``local_rank``.

    The TensorBoard writer (if any) is closed and the process group is
    destroyed even if training raises.
    """
    args.local_rank = local_rank
    # prepare dist environment
    # NOTE(review): the local rank is passed as the *global* rank, which is
    # only correct for single-node runs — confirm before multi-node use.
    dist.init_process_group(backend='nccl',
                            rank=args.local_rank,
                            world_size=args.world_size)
    summary_writer = None
    try:
        torch.cuda.set_device(args.local_rank)
        network = Xception(num_classes=cfg.num_classes)
        network = network.cuda()
        network = torch.nn.parallel.DistributedDataParallel(
            network, device_ids=[args.local_rank])
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(network.parameters(),
                              lr=cfg.lr_init,
                              momentum=cfg.SGD_momentum)
        # NOTE(review): subset names are appended by plain concatenation,
        # so args.data_path presumably has to end with a path separator.
        dataloader_train = create_dataset_pytorch_imagenet_dist_train(
            data_path=args.data_path + 'train',
            local_rank=local_rank,
            n_workers=cfg.n_workers)
        dataloader_test = create_dataset_pytorch_imagenet(
            data_path=args.data_path + 'val',
            is_train=False,
            n_workers=cfg.n_workers)

        step_per_epoch = len(dataloader_train)
        # step_size is expressed in iterations (epochs * batches per epoch),
        # so the scheduler is presumably stepped once per batch by Trainer.
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              gamma=cfg.lr_decay_rate,
                                              step_size=cfg.lr_decay_epoch *
                                              step_per_epoch)
        # Only rank 0 writes TensorBoard summaries; other ranks pass None.
        if local_rank == 0:
            summary_writer = SummaryWriter(log_dir='./summary')
        trainer = Trainer(network=network,
                          criterion=criterion,
                          optimizer=optimizer,
                          scheduler=scheduler,
                          dataloader_train=dataloader_train,
                          dataloader_test=dataloader_test,
                          summary_writer=summary_writer,
                          epoch_size=cfg.epoch_size,
                          ckpt_path=args.ckpt_path,
                          local_rank=local_rank)

        for _ in range(cfg.epoch_size):
            trainer.step()
    finally:
        # Release the event file and the NCCL communicator even when
        # training fails part-way through.
        if summary_writer is not None:
            summary_writer.close()
        dist.destroy_process_group()
コード例 #2
0
                        help='path where the dataset is saved')
    parser.add_argument('--ckpt_path',
                        type=str,
                        default="./checkpoint",
                        help='path where the checkpoint to be saved')
    parser.add_argument('--device_id',
                        type=int,
                        default=0,
                        help='device id of GPU. (Default: 0)')
    args = parser.parse_args()
    args.local_rank = 0
    args.world_size = 1

    network = Xception(num_classes=cfg.num_classes)
    # network = nn.DataParallel(network)
    network = network.cuda()
    criterion = nn.CrossEntropyLoss()
    #     optimizer = optim.RMSprop(network.parameters(),
    #                                 lr=cfg.lr_init,
    #                                 eps=cfg.rmsprop_epsilon,
    #                                 momentum=cfg.rmsprop_momentum,
    #                                 alpha=cfg.rmsprop_decay)
    optimizer = optim.SGD(network.parameters(),
                          lr=cfg.lr_init,
                          momentum=cfg.SGD_momentum)
    # prepare data
    # dataloader = create_dataset_pytorch(args.data_path + "/train")
    pipe = HybridTrainPipe(batch_size=cfg.batch_size,
                           num_threads=cfg.n_workers,
                           device_id=args.local_rank,
                           data_dir=args.data_path,