Example #1
criterion = nn.CrossEntropyLoss()
criterion = criterion.cuda()
criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
criterion_smooth = criterion_smooth.cuda()

# Horovod: scale learning rate by the number of GPUs.
optimizer = optim.SGD(model.parameters(),
                      lr=args.base_lr * hvd.size(),
                      momentum=args.momentum,
                      weight_decay=args.wd)

# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters(),
                                     compression=compression)

# Restore from a previous checkpoint if resume_from_epoch is set.
# Horovod: restore on the first worker, which then broadcasts the weights to the other workers.
if resume_from_epoch > 0 and hvd.rank() == 0:
    filepath = args.checkpoint_format.format(exp=args.save,
                                             epoch=resume_from_epoch)
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])

# Horovod: broadcast parameters & optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
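
Both examples pick up mid-script: they assume horovod.torch has already been imported as hvd, initialized, and each process pinned to its own GPU. A minimal sketch of that setup (the seed value and surrounding script are assumptions; hvd.init(), hvd.local_rank(), and torch.cuda.set_device() are the standard calls):

import torch
import torch.backends.cudnn as cudnn
import horovod.torch as hvd

# Start Horovod; one process per GPU is launched externally (e.g. with horovodrun).
hvd.init()

# Pin this process to the GPU that matches its local rank on the node.
torch.cuda.set_device(hvd.local_rank())
cudnn.benchmark = True

# Seed every worker identically so the initial weights match before the rank-0 broadcast.
torch.manual_seed(0)  # assumption: the real scripts pass args.seed here
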
Example #2
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    print(torch.cuda.device_count())
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    if hvd.rank() == 0:
        logging.info('gpu device = %d' % args.gpu)
        logging.info("args = %s", args)

    best_acc_top1 = 0
    start_epoch = 0
    if args.resume:
        checkpoint = torch.load(os.path.join(args.save, 'checkpoint.pth.tar'))
        best_checkpoint = torch.load(
            os.path.join(args.save, 'model_best.pth.tar'))
        start_epoch = checkpoint['epoch']
        best_acc_top1 = best_checkpoint['best_acc_top1']
        start_epoch = hvd.broadcast(torch.tensor(start_epoch),
                                    root_rank=0,
                                    name='start_epoch').item()
        best_acc_top1 = hvd.broadcast(torch.tensor(best_acc_top1),
                                      root_rank=0,
                                      name='best_acc_top1').item()

    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary,
                    genotype)

    if args.parallel:
        model = nn.DataParallel(model).cuda()
    else:
        model = model.cuda()

    if hvd.rank() == 0:
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()

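    # Horovod: scale the base learning rate by the number of workers.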
    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate * hvd.size(),
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # ***************** horovod *******************
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())
    # ***************** horovod *******************

    traindir = os.path.join(args.data, 'train')
    validdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_data = dset.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4,
                                   hue=0.2),
            transforms.ToTensor(),
            normalize,
        ]))
    valid_data = dset.ImageFolder(
        validdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

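    # Horovod: partition the training set across workers with DistributedSampler;
    # the validation set (no sampler) is evaluated in full on every worker.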
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_data, num_replicas=hvd.size(), rank=hvd.rank())
    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              pin_memory=True,
                                              num_workers=args.num_workers,
                                              sampler=train_sampler)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=args.num_workers)

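    # Resume: load model/optimizer state on rank 0 only; the broadcasts below
    # propagate it to the remaining workers.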
    if start_epoch > 0 and hvd.rank() == 0:
        checkpoint = torch.load(os.path.join(args.save, 'checkpoint.pth.tar'))
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("checkpoint {}, model, optimizer was loaded".format(start_epoch))

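    # Horovod: broadcast initial parameters and optimizer state from rank 0 so
    # all workers start from the same weights.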
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    if not args.resume:
        set_lr(0, 0, len(train_queue), optimizer, args.scheduler)

    for epoch in range(start_epoch, args.epochs + args.warmup_epochs):
        if hvd.rank() == 0:
            lr = optimizer.param_groups[0]['lr']
            logging.info('epoch %d lr %e', epoch, lr)
            with open(os.path.join(args.save, 'learning_rate.txt'),
                      mode='a') as f:
                f.write(str(lr) + '\n')

        if args.parallel:
            model.module.drop_path_prob = args.drop_path_prob * epoch / args.epochs
        else:
            model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

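        # Horovod: re-sync parameters from rank 0 at the start of every epoch.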
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)

        train_acc, train_obj = train(train_queue, train_sampler, model,
                                     criterion_smooth, optimizer, epoch)
        if hvd.rank() == 0:
            logging.info('train_acc %f', train_acc)
            with open(os.path.join(args.save, "train_acc.txt"), mode='a') as f:
                f.write(str(train_acc) + '\n')
            with open(os.path.join(args.save, "train_loss.txt"),
                      mode='a') as f:
                f.write(str(train_obj) + '\n')

        valid_acc_top1, valid_acc_top5, valid_obj = infer(
            valid_queue, model, criterion)
        if hvd.rank() == 0:
            logging.info('valid_acc_top1 %f', valid_acc_top1)
            logging.info('valid_acc_top5 %f', valid_acc_top5)
            with open(os.path.join(args.save, "test_acc_1.txt"),
                      mode='a') as f:
                f.write(str(valid_acc_top1) + '\n')
            with open(os.path.join(args.save, "test_acc_5.txt"),
                      mode='a') as f:
                f.write(str(valid_acc_top5) + '\n')
            with open(os.path.join(args.save, "test_loss.txt"), mode='a') as f:
                f.write(str(valid_obj) + '\n')

        is_best = False
        if valid_acc_top1 > best_acc_top1:
            best_acc_top1 = valid_acc_top1
            is_best = True

        if hvd.rank() == 0:
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_acc_top1': best_acc_top1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.save)
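
Two helpers used above are not defined in either snippet. CrossEntropyLabelSmooth is the label-smoothing loss commonly found in DARTS-style ImageNet scripts; the version below is that common implementation and may differ in detail from the authors' class:

import torch
import torch.nn as nn

class CrossEntropyLabelSmooth(nn.Module):
    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        # Smoothed targets: (1 - eps) * one_hot + eps / num_classes on every class.
        log_probs = self.logsoftmax(inputs)
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (-targets * log_probs).mean(0).sum()
        return loss

utils.save_checkpoint is also external; given the filenames loaded at the top of main() ('checkpoint.pth.tar', 'model_best.pth.tar') and the keys stored in the state dict, a plausible sketch (hypothetical, not necessarily the repository's actual helper) is:

import os
import shutil
import torch

def save_checkpoint(state, is_best, save_dir, filename='checkpoint.pth.tar'):
    # Always write the latest checkpoint; copy it aside when it is the best so far.
    path = os.path.join(save_dir, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(save_dir, 'model_best.pth.tar'))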