Example #1
def training_loop_lars(model,
                       criterion,
                       train_loader,
                       valid_loader,
                       epochs,
                       device,
                       print_every=1):
    '''
    Function defining the entire training loop
    '''

    # set objects for storing metrics
    train_losses = []
    valid_losses = []
    train_accuracy = []
    valid_accuracy = []

    # Train model
    for epoch in range(0, epochs):

        # training: rebuild the LARS optimizer each epoch with the scheduled lr
        if epoch < 10:
            # linear warmup over the first 10 epochs
            optimizer = LARS(model.parameters(),
                             lr=0.1 * (epoch + 1) / 10,
                             momentum=0.9)
        elif epoch < 15:
            # hold at the base learning rate
            optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
        else:
            # exponential decay afterwards
            optimizer = LARS(model.parameters(),
                             lr=0.1 * (0.95**(epoch - 15)),
                             momentum=0.9)
        model, optimizer, train_loss = train(train_loader, model, criterion,
                                             optimizer, device)
        train_losses.append(train_loss)

        # validation
        with torch.no_grad():
            model, valid_loss = validate(valid_loader, model, criterion,
                                         device)
            valid_losses.append(valid_loss)

        if epoch % print_every == (print_every - 1):
            train_acc = get_accuracy(model, train_loader, device=device)
            valid_acc = get_accuracy(model, valid_loader, device=device)
            train_accuracy.append(float(train_acc))
            valid_accuracy.append(float(valid_acc))
            print(f'{datetime.now().time().replace(microsecond=0)} --- '
                  f'Epoch: {epoch}\t'
                  f'Train loss: {train_loss:.4f}\t'
                  f'Valid loss: {valid_loss:.4f}\t'
                  f'Train accuracy: {100 * train_acc:.2f}\t'
                  f'Valid accuracy: {100 * valid_acc:.2f}\n')

    #plot_losses(train_losses, valid_losses)

    return model, [train_losses, valid_losses, train_accuracy, valid_accuracy]
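
Note: the loop above constructs a fresh LARS optimizer every epoch to change the learning rate, which also discards the accumulated momentum buffers. Below is a minimal sketch of the same warmup/hold/decay schedule expressed as a LambdaLR on a single optimizer instance; the optimizer passed in can be any LARS implementation built with lr=0.1.

import torch

def warmup_hold_decay(optimizer):
    # Reproduces the schedule from Example #1: linear warmup for 10 epochs,
    # a constant phase until epoch 15, then exponential decay by 0.95 per epoch.
    def lr_lambda(epoch):
        if epoch < 10:
            return (epoch + 1) / 10
        if epoch < 15:
            return 1.0
        return 0.95 ** (epoch - 15)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# usage: scheduler = warmup_hold_decay(optimizer); call scheduler.step() once per epoch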
Example #2
def build_byol_optimizer(hparams: AttrDict, model: nn.Module) -> Optimizer:
    """
    Build optimizer for BYOL self-supervised network, including backbone.
    """
    regular_parameters = []
    excluded_parameters = []
    for name, parameter in model.named_parameters():
        if parameter.requires_grad is False:
            continue
        # batch-norm parameters and biases are excluded from LARS adaptation
        # and from weight decay (the BYOL recipe)
        if any(x in name for x in [".bn", ".bias"]):
            excluded_parameters.append(parameter)
        else:
            regular_parameters.append(parameter)
    param_groups = [
        {
            "params": regular_parameters,
            "use_lars": True
        },
        {
            "params": excluded_parameters,
            "use_lars": False,
            "weight_decay": 0,
        },
    ]
    return LARS(
        param_groups,
        lr=hparams.self_supervised.learning_rate.base,
        eta=hparams.self_supervised.lars_eta,
        momentum=hparams.self_supervised.momentum,
        weight_decay=hparams.self_supervised.weight_decay,
    )
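
For reference, the per-layer adaptation that the use_lars flag enables is the LARS trust ratio from You et al. (2017): each parameter tensor receives a local learning-rate multiplier proportional to ||w|| / (||g|| + wd * ||w||). A simplified single-tensor sketch of that rule (not the exact implementation imported above):

import torch

def lars_trust_ratio(weight: torch.Tensor, grad: torch.Tensor,
                     eta: float = 1e-3, weight_decay: float = 1e-6) -> float:
    # Local learning-rate multiplier for one parameter tensor.
    w_norm = weight.norm()
    g_norm = grad.norm()
    if w_norm == 0 or g_norm == 0:
        return 1.0  # fall back to the global learning rate
    # trust ratio: eta * ||w|| / (||grad|| + weight_decay * ||w||)
    return float(eta * w_norm / (g_norm + weight_decay * w_norm))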
Example #3
def load_model(args):
    model = SimCLR(backbone=args.backbone,
                   projection_dim=args.projection_dim,
                   pretrained=args.pretrained,
                   normalize=args.normalize)

    if args.inference:
        model.load_state_dict(
            torch.load("SimCLR_{}_epoch90.pth".format(args.backbone)))

    model = model.to(args.device)

    scheduler = None
    if args.optimizer == "Adam":
        optimizer = Adam(model.parameters(), lr=3e-4)  # TODO: LARS
    elif args.optimizer == "LARS":
        # optimized using LARS with linear learning rate scaling
        # (i.e. LearningRate = 0.3 × BatchSize/256) and a weight decay of 10^-6.
        learning_rate = 0.3 * args.batch_size / 256
        optimizer = LARS(
            model.parameters(),
            lr=learning_rate,
            weight_decay=args.weight_decay,
            exclude_from_weight_decay=["batch_normalization", "bias"],
        )

        # "decay the learning rate with the cosine decay schedule without restarts"
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               args.epochs,
                                                               eta_min=0,
                                                               last_epoch=-1)
    else:
        raise NotImplementedError

    return model, optimizer, scheduler
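
The comment in the LARS branch quotes the SimCLR linear scaling rule: the base learning rate grows with the batch size as 0.3 × BatchSize/256. A quick worked example of what that evaluates to for a few common batch sizes:

# linear learning-rate scaling used above: lr = 0.3 * batch_size / 256
for batch_size in (256, 1024, 4096):
    print(batch_size, 0.3 * batch_size / 256)   # -> 0.3, 1.2, 4.8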
Example #4
    def configure_optimizers(self):
        # Keep only the parameters that require gradients (the weights being trained)
        params = filter(lambda p: p.requires_grad, self.model.parameters())

        if self.hparams.optimizer == "SGD":
            self.optimizer = torch.optim.SGD(params,
                                             self.hparams.lr,
                                             momentum=self.hparams.momentum,
                                             weight_decay=self.hparams.wd)
        elif self.hparams.optimizer == "LARS":
            self.optimizer = LARS(params,
                                  lr=self.hparams.lr,
                                  momentum=self.hparams.momentum,
                                  weight_decay=self.hparams.wd,
                                  max_epoch=self.hparams.epochs)

        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer,
                                                             self.hparams.lr,
                                                             epochs=self.hparams.epochs,
                                                             steps_per_epoch=1,
                                                             pct_start=self.hparams.pct_start)
        sched_dict = {'scheduler': self.scheduler}

        return [self.optimizer], [sched_dict]
Example #5
def do_training(args):
    trainloader, testloader = build_dataset(
        args.dataset,
        dataroot=args.dataroot,
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        num_workers=2)
    model = build_model(args.arch, num_classes=num_classes(args.dataset))
    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    # Calculate total number of model parameters
    num_params = sum(p.numel() for p in model.parameters())
    track.metric(iteration=0, num_params=num_params)

    num_chunks = max(1, args.batch_size // args.max_samples_per_gpu)

    optimizer = LARS(params=model.parameters(),
                     lr=args.lr,
                     momentum=args.momentum,
                     weight_decay=args.weight_decay,
                     eta=args.eta,
                     max_epoch=args.epochs)

    criterion = torch.nn.CrossEntropyLoss()

    best_acc = 0.0
    for epoch in range(args.epochs):
        track.debug("Starting epoch %d" % epoch)
        train_loss, train_acc = train(trainloader,
                                      model,
                                      criterion,
                                      optimizer,
                                      epoch,
                                      args.cuda,
                                      num_chunks=num_chunks)
        test_loss, test_acc = test(testloader, model, criterion, epoch,
                                   args.cuda)
        track.debug('Finished epoch %d... | train loss %.3f | train acc %.3f '
                    '| test loss %.3f | test acc %.3f' %
                    (epoch, train_loss, train_acc, test_loss, test_acc))
        # Save model
        model_fname = os.path.join(track.trial_dir(),
                                   "model{}.ckpt".format(epoch))
        torch.save(model, model_fname)
        if test_acc > best_acc:
            best_acc = test_acc
            best_fname = os.path.join(track.trial_dir(), "best.ckpt")
            track.debug("New best score! Saving model")
            torch.save(model, best_fname)
Example #6
    def lars(self):
        print("Enable LARS Optimizer Algorithm")
        self.optimizer = LARS(self.optimizer)
Example #7
def main(lr=0.1):
    global best_acc
    args.lr = lr
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # Data
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='/tmp/cifar10',
                                            train=True,
                                            download=True,
                                            transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='/tmp/cifar10',
                                           train=False,
                                           download=True,
                                           transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=2)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    # Model
    print('==> Building model..')
    # net = VGG('VGG19')
    # net = ResNet18()
    # net = PreActResNet18()
    # net = GoogLeNet()
    # net = DenseNet121()
    # net = ResNeXt29_2x64d()
    # net = MobileNet()
    # net = MobileNetV2()
    # net = DPN92()
    # net = ShuffleNetG2()
    # net = SENet18()
    # net = ShuffleNetV2(1)
    # net = EfficientNetB0()
    # net = RegNetX_200MF()
    net = ResNet50()
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    ckpt = './checkpoint/' + args.optimizer + str(lr) + '_ckpt.pth'

    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isdir(
            'checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(ckpt)
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']

    criterion = nn.CrossEntropyLoss()
    if args.optimizer.lower() == 'sgd':
        optimizer = optim.SGD(net.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'sgdwm':
        optimizer = optim.SGD(net.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'adam':
        optimizer = torch.optim.Adam(net.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'rmsprop':
        optimizer = optim.RMSprop(net.parameters(),
                                  lr=args.lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'adagrad':
        optimizer = optim.Adagrad(net.parameters(),
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'radam':
        from radam import RAdam
        optimizer = RAdam(net.parameters(),
                          lr=args.lr,
                          weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'lars':  #no tensorboardX
        from lars import LARS
        optimizer = LARS(net.parameters(),
                         lr=args.lr,
                         momentum=args.momentum,
                         weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'lamb':
        from lamb import Lamb
        optimizer = Lamb(net.parameters(),
                         lr=args.lr,
                         weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'novograd':
        from novograd import NovoGrad
        optimizer = NovoGrad(net.parameters(),
                             lr=args.lr,
                             weight_decay=args.weight_decay)
    else:
        optimizer = optim.SGD(net.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    # lrs = create_lr_scheduler(args.warmup_epochs, args.lr_decay)
    # lr_scheduler = LambdaLR(optimizer,lrs)
    # lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay, gamma=0.1)
    train_acc = []
    valid_acc = []

    # Training
    def train(epoch):
        print('\nEpoch: %d' % epoch)
        net.train()
        train_loss = 0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            print(batch_idx)
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            # lr_scheduler.step()
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        print(100. * correct / total)
        train_acc.append(correct / total)

    def test(epoch):
        global best_acc
        net.eval()
        test_loss = 0
        correct = 0
        total = 0
        print('test')
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                print(batch_idx)
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        # Save checkpoint.
        acc = 100. * correct / total
        print(acc)
        valid_acc.append(correct / total)

        if acc > best_acc:
            print('Saving..')
            state = {
                'net': net.state_dict(),
                'acc': acc,
                'epoch': epoch,
            }
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(state, ckpt)
            best_acc = acc

    for epoch in range(200):
        if epoch in args.lr_decay:
            checkpoint = torch.load(ckpt)
            net.load_state_dict(checkpoint['net'])
            best_acc = checkpoint['acc']
            args.lr *= 0.1
            if args.optimizer.lower() == 'sgd':
                optimizer = optim.SGD(net.parameters(),
                                      lr=args.lr,
                                      weight_decay=args.weight_decay)
            elif args.optimizer.lower() == 'sgdwm':
                optimizer = optim.SGD(net.parameters(),
                                      lr=args.lr,
                                      momentum=args.momentum,
                                      weight_decay=args.weight_decay)
            elif args.optimizer.lower() == 'adam':
                optimizer = optim.Adam(net.parameters(),
                                       lr=args.lr,
                                       weight_decay=args.weight_decay)
            elif args.optimizer.lower() == 'rmsprop':
                optimizer = optim.RMSprop(net.parameters(),
                                          lr=args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.weight_decay)
            elif args.optimizer.lower() == 'adagrad':
                optimizer = optim.Adagrad(net.parameters(),
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
            elif args.optimizer.lower() == 'radam':
                from radam import RAdam

                optimizer = RAdam(net.parameters(),
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
            elif args.optimizer.lower() == 'lars':  # no tensorboardX
                optimizer = LARS(net.parameters(),
                                 lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 dampening=args.damping)
            elif args.optimizer.lower() == 'lamb':
                optimizer = Lamb(net.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
            elif args.optimizer.lower() == 'novograd':
                optimizer = NovoGrad(net.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
            else:
                optimizer = optim.SGD(net.parameters(),
                                      lr=args.lr,
                                      momentum=args.momentum,
                                      weight_decay=args.weight_decay)
        train(epoch)
        test(epoch)
    file = open(args.optimizer + str(lr) + 'log.json', 'w+')
    json.dump([train_acc, valid_acc], file)
    return best_acc
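
The commented-out MultiStepLR lines in this example point at the more idiomatic alternative to reloading the best checkpoint and rebuilding the optimizer at every decay milestone. A minimal sketch of that variant, reusing the names from the example above (net, args, train, test) and assuming args.lr_decay is the list of milestone epochs:

import torch

optimizer = LARS(net.parameters(), lr=args.lr,
                 momentum=args.momentum, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=list(args.lr_decay),
                                                 gamma=0.1)

for epoch in range(200):
    train(epoch)
    test(epoch)
    scheduler.step()   # multiplies the lr by 0.1 at each milestone epoch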
Example #8
                              lr=args.base_lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'adagrad':
    optimizer = optim.Adagrad(model.parameters(),
                              lr=args.base_lr,
                              weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'radam':
    from radam import RAdam
    optimizer = RAdam(model.parameters(),
                      lr=args.base_lr,
                      weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'lars':  #no tensorboardX
    from lars import LARS
    optimizer = LARS(model.parameters(),
                     lr=args.base_lr,
                     momentum=args.momentum,
                     weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'lamb':
    from lamb import Lamb
    optimizer = Lamb(model.parameters(),
                     lr=args.base_lr,
                     weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'novograd':
    from novograd import NovoGrad
    optimizer = NovoGrad(model.parameters(),
                         lr=args.base_lr,
                         weight_decay=args.weight_decay)
    lr_scheduler = [
        optim.lr_scheduler.CosineAnnealingLR(optimizer, 3 * len(train_loader),
                                             1e-4)
    ]
Example #9
epochs = 400
stop_at_epoch = 100
batch_size = 64
image_size = (92, 92)

train_loader, mem_loader, test_loader = get_train_mem_test_dataloaders(
    batch_size=batch_size)
train_transform, test_transform = gpu_transformer(image_size)

loss_ls = []
acc_ls = []

model = BYOL().to(device)

optimizer = LARS(model.named_modules(),
                 lr=lr,
                 momentum=momentum,
                 weight_decay=weight_decay)

scheduler = LR_Scheduler(optimizer,
                         warmup_epochs,
                         warmup_lr * batch_size / 8,
                         epochs,
                         lr * batch_size / 8,
                         final_lr * batch_size / 8,
                         len(train_loader),
                         constant_predictor_lr=True)

min_loss = np.inf
accuracy = 0

# start training
Example #10
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'adam':
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'rmsprop':
    optimizer = optim.RMSprop(net.parameters(), lr=args.lr, momentum=args.momentum,
                              weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'adagrad':
    optimizer = optim.Adagrad(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'radam':
    from radam import RAdam
    optimizer = RAdam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'lars':  # no tensorboardX
    from lars import LARS
    optimizer = LARS(net.parameters(), lr=args.lr, momentum=args.momentum,
                     weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'lamb':
    from lamb import Lamb
    optimizer = Lamb(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'novograd':
    from novograd import NovoGrad
    optimizer = NovoGrad(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
else:
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)
# lrs = create_lr_scheduler(args.warmup_epochs, args.lr_decay)
# lr_scheduler = LambdaLR(optimizer,lrs)
# lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay, gamma=0.1)
train_acc = []
valid_acc = []
Example #11
if len(sys.argv) == 1:
    optimizer = optim.SGD(model.parameters(), lr=0.01)
elif sys.argv[1] == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
elif sys.argv[1] == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=0.01)
elif sys.argv[1] == 'sgdwm':
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
elif sys.argv[1] == 'rmsprop':
    optimizer = optim.RMSprop(model.parameters(), lr=0.001, momentum=0.9)
elif sys.argv[1] == 'adagrad':
    optimizer = optim.Adagrad(model.parameters(), lr=0.01)
elif sys.argv[1] == 'radam':
    optimizer = RAdam(model.parameters())
elif sys.argv[1] == 'lars':  #no tensorboardX
    optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
elif sys.argv[1] == 'lamb':
    optimizer = Lamb(model.parameters())
elif sys.argv[1] == 'novograd':
    optimizer = NovoGrad(model.parameters(), lr=0.01, weight_decay=0.001)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     3 * len(train_loader),
                                                     1e-4)

    def train(train_loader, model, criterion, optimizer, scheduler, device):
        '''
        Function for the training step of the training loop
        '''

        model.train()
        running_loss = 0
Example #12
    def lars(self):
        self.optimizer = LARS(self.optimizer)
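
Examples #6 and #12 treat LARS as a wrapper around an already-built optimizer (a LARC-style interface) rather than as a replacement for it. A minimal usage sketch under that assumption, with model, criterion and loader standing in for whatever the surrounding class provides:

import torch

base = torch.optim.SGD(model.parameters(), lr=0.1,
                       momentum=0.9, weight_decay=1e-4)
# assumes a wrapper-style LARS/LARC implementation that rescales the
# per-layer learning rates of the inner optimizer on each step
optimizer = LARS(base)

for inputs, targets in loader:
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()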
Example #13
knn_k = 200

min_loss = np.inf  #ironic
accuracy = 0

train_loader, memory_loader, test_loader = get_train_mem_test_dataloaders(
    dataset="cifar10",
    data_dir="./dataset",
    batch_size=batch_size,
    num_workers=4,
    download=True)

train_transform, test_transform = gpu_transformer(image_size)

optimizer = LARS(model.named_modules(),
                 lr=lr * batch_size / 256,
                 momentum=momentum,
                 weight_decay=weight_decay)

scheduler = LR_Scheduler(optimizer,
                         warmup_epochs,
                         warmup_lr * batch_size / 256,
                         num_epochs,
                         base_lr * batch_size / 256,
                         final_lr * batch_size / 256,
                         len(train_loader),
                         constant_predictor_lr=True)

global_progress = tqdm(range(0, epochs), desc=f'Training')
data_dict = {"loss": 100}
for epoch in global_progress:
    model.train()
Example #14
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'adam':
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'rmsprop':
    optimizer = optim.RMSprop(net.parameters(), lr=args.lr, momentum=args.momentum,
                              weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'adagrad':
    optimizer = optim.Adagrad(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'radam':
    from radam import RAdam
    optimizer = RAdam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'lars':  # no tensorboardX
    from lars import LARS
    optimizer = LARS(net.parameters(), lr=args.lr, momentum=args.momentum,
                     weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'lamb':
    from lamb import Lamb
    optimizer = Lamb(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'novograd':
    from novograd import NovoGrad
    optimizer = NovoGrad(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif args.optimizer.lower() == 'dyna':
    from dyna import Dyna
    optimizer = Dyna(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
else:
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)
# lrs = create_lr_scheduler(args.warmup_epochs, args.lr_decay)
# lr_scheduler = LambdaLR(optimizer,lrs)
# lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay, gamma=0.1)
Example #15
def main():
    args = parse_args()

    if args.name is None:
        args.name = '%s_WideResNet%s-%s_%d' % (args.dataset, args.depth,
                                               args.width, args.batch_size)
    if args.linear_scaling:
        args.name += '_wLS'
    if args.lars:
        args.name += '_wLARS'

    if not os.path.exists('models/%s' % args.name):
        os.makedirs('models/%s' % args.name)

    print('Config -----')
    for arg in vars(args):
        print('%s: %s' % (arg, getattr(args, arg)))
    print('------------')

    with open('models/%s/args.txt' % args.name, 'w') as f:
        for arg in vars(args):
            print('%s: %s' % (arg, getattr(args, arg)), file=f)

    joblib.dump(args, 'models/%s/args.pkl' % args.name)

    criterion = nn.CrossEntropyLoss().cuda()

    cudnn.benchmark = True

    # data loading code
    if args.dataset == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        train_set = datasets.CIFAR10(root='~/data',
                                     train=True,
                                     download=True,
                                     transform=transform_train)
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=8)

        test_set = datasets.CIFAR10(root='~/data',
                                    train=False,
                                    download=True,
                                    transform=transform_test)
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=128,
                                                  shuffle=False,
                                                  num_workers=8)

        num_classes = 10

    elif args.dataset == 'cifar100':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        train_set = datasets.CIFAR100(root='~/data',
                                      train=True,
                                      download=True,
                                      transform=transform_train)
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=128,
                                                   shuffle=True,
                                                   num_workers=8)

        test_set = datasets.CIFAR100(root='~/data',
                                     train=False,
                                     download=True,
                                     transform=transform_test)
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=128,
                                                  shuffle=False,
                                                  num_workers=8)

        num_classes = 100

    # create model
    model = WideResNet(args.depth, args.width, num_classes=num_classes)
    model = model.cuda()

    if args.lars:
        optimizer = LARS(filter(lambda p: p.requires_grad, model.parameters()),
                         lr=args.lr,
                         momentum=args.momentum,
                         weight_decay=args.weight_decay)
    else:
        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    if args.linear_scaling:
        scheduler = WarmupMultiStepLR(
            optimizer,
            milestones=[int(e) for e in args.milestones.split(',')],
            target_lr=args.lr * args.batch_size / base_batch_size,
            gamma=args.gamma)
    else:
        scheduler = lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(e) for e in args.milestones.split(',')],
            gamma=args.gamma)

    log = pd.DataFrame(
        index=[],
        columns=['epoch', 'lr', 'loss', 'acc', 'val_loss', 'val_acc'])

    best_acc = 0
    for epoch in range(args.epochs):
        print('Epoch [%d/%d]' % (epoch + 1, args.epochs))

        scheduler.step()

        # train for one epoch
        train_log = train(args, train_loader, model, criterion, optimizer,
                          epoch)
        # evaluate on validation set
        val_log = validate(args, test_loader, model, criterion)

        print('loss %.4f - acc %.4f - val_loss %.4f - val_acc %.4f' %
              (train_log['loss'], train_log['acc'], val_log['loss'],
               val_log['acc']))

        tmp = pd.Series(
            [
                epoch,
                scheduler.get_lr()[0],
                train_log['loss'],
                train_log['acc'],
                val_log['loss'],
                val_log['acc'],
            ],
            index=['epoch', 'lr', 'loss', 'acc', 'val_loss', 'val_acc'])

        log = log.append(tmp, ignore_index=True)
        log.to_csv('models/%s/log.csv' % args.name, index=False)

        if val_log['acc'] > best_acc:
            torch.save(model.state_dict(), 'models/%s/model.pth' % args.name)
            best_acc = val_log['acc']
            print("=> saved best model")
Example #16
args = parser.parse_args()

if args.optimizer.lower() == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
elif args.optimizer.lower() == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=args.lr)
elif args.optimizer.lower() == 'sgdwm':
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
elif args.optimizer.lower() == 'rmsprop':
    optimizer = optim.RMSprop(model.parameters(), lr=args.lr, momentum=0.9)
elif args.optimizer.lower() == 'adagrad':
    optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
elif args.optimizer.lower() == 'radam':
    optimizer = RAdam(model.parameters(), lr=args.lr)
elif args.optimizer.lower() == 'lars':  #no tensorboardX
    optimizer = LARS(model.parameters(), lr=args.lr, momentum=0.9)
elif args.optimizer.lower() == 'lamb':
    optimizer = Lamb(model.parameters(), lr=args.lr)
elif args.optimizer.lower() == 'novograd':
    optimizer = NovoGrad(model.parameters(), lr=args.lr, weight_decay=0.0001)
else:
    optimizer = optim.SGD(model.parameters(), lr=0.01)

optname = args.optimizer if len(sys.argv) >= 2 else 'sgd'

# log = open(optname+'log.txt','w+')

log = None

criterion = nn.CrossEntropyLoss()
Example #17
def main(is_distributed, rank, ip, sync_bn):
    world_size = 1
    if is_distributed:
        world_size = 2
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=ip,
                                             world_size=world_size,
                                             rank=rank)
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
    print("Connect")
    # set hyper parameters
    batch_size = 128
    lr = 0.01  # based on batch size 256
    momentum = 0.9
    weight_decay = 0.0001
    epoch = 100

    # recompute lr
    lr = lr * world_size

    # create model
    model = AlexNet(10)

    # synchronization batch normal
    if sync_bn:
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()

    # define loss function
    criterion = nn.CrossEntropyLoss().cuda()

    # define optimizer strategy
    optimizer = torch.optim.SGD(model.parameters(),
                                lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

    model, optimizer = apex.amp.initialize(model, optimizer, opt_level='O0')
    optimizer = LARS(optimizer)

    if is_distributed:
        # for distribute training
        model = nn.parallel.DistributedDataParallel(model)
        # model = apex.parallel.DistributedDataParallel(model, delay_allreduce=True)

    # load train data
    data_path = '~/datasets/cifar10/train'
    train_set = LoadClassifyDataSets(data_path, 227)
    train_sampler = None
    if is_distributed:
        train_sampler = distributed.DistributedSampler(train_set)
    train_loader = DataLoader(train_set,
                              batch_size,
                              shuffle=(train_sampler is None),
                              num_workers=4,
                              pin_memory=True,
                              sampler=train_sampler,
                              collate_fn=collate_fn)

    for epoch in range(100):
        # for distribute
        if is_distributed:
            train_sampler.set_epoch(epoch)

        model.train()
        # train_iter = iter(train_loader)
        # inputs, target = next(train_iter)
        prefetcher = DataPrefetcher(train_loader)
        inputs, target = prefetcher.next()

        step = 0
        print("Epoch is {}".format(epoch))
        while inputs is not None:
            step += 1
            print("Step is {}".format(step))

            time_model_1 = time.time()
            output = model(inputs)
            time_model_2 = time.time()
            print("model time: {}".format(time_model_2 - time_model_1))
            time_loss_1 = time.time()
            loss = criterion(output, target.cuda(non_blocking=True))
            time_loss_2 = time.time()
            print("loss time: {}".format(time_loss_2 - time_loss_1))
            optimizer.zero_grad()
            time_back_1 = time.time()
            # loss.backward()
            with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            time_back_2 = time.time()
            print("back time: {}".format(time_back_2 - time_back_1))
            optimizer.step()
            # if step % 10 == 0:
            #     print("loss is : {}", loss.item())
            # inputs, target = next(train_iter, (None, None))
            inputs, target = prefetcher.next()