Example #1
0
def main():
    """Evaluate a trained model on validation images with a synthetic perturbation.

    Command line:
        --expdir  (required) experiment directory containing
                  ``best_miniplaces.pth.tar``.

    ``perturbation/perturb_synth.npy`` is rescaled so its standard deviation
    equals that of ``perturbation/VGG-19.npy``.  A 112x112 slice of it is
    then added (via ``AddPerturbation``) to every resized, 112x112
    center-cropped miniplaces validation image.  Running loss and accuracy
    are shown while evaluating.  The final values are printed and written to
    ``<expdir>/adversarial_test.json`` as ``adversarial_acc`` and
    ``adversarial_loss``.

    Raises:
        ValueError: if the synthetic perturbation has zero standard
            deviation and therefore cannot be rescaled.
    """
    parser = argparse.ArgumentParser(description='Adversarial test')
    parser.add_argument('--expdir', type=str, default=None, required=True,
                        help='experiment directory containing model')
    args = parser.parse_args()
    progress = default_progress()
    experiment_dir = args.expdir

    reference = numpy.load('perturbation/VGG-19.npy')
    perturbation = numpy.load('perturbation/perturb_synth.npy')
    reference_std = numpy.std(reference)
    synthetic_std = numpy.std(perturbation)
    print('Original std %e new std %e' % (reference_std, synthetic_std))
    if synthetic_std == 0:
        # Without this guard the rescale below divides by zero and fills the
        # perturbation with inf/nan, silently corrupting every image.
        raise ValueError('perturbation/perturb_synth.npy has zero standard '
                         'deviation and cannot be rescaled')
    # Rescale so the synthetic perturbation has the same standard deviation
    # (overall magnitude) as the VGG-19 perturbation.
    perturbation *= reference_std / synthetic_std
    # To desaturate (copy channel 1 into all three channels), uncomment:
    # perturbation = numpy.repeat(perturbation[:,:,1:2], 3, axis=2)

    val_loader = torch.utils.data.DataLoader(
        CachedImageFolder(
            'dataset/miniplaces/simple/val',
            transform=transforms.Compose([
                transforms.Resize(128),
                transforms.CenterCrop(112),
                # 112x112 slice so it matches the cropped image size.
                AddPerturbation(perturbation[48:160, 48:160]),
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV),
            ])),
        batch_size=32,
        shuffle=False,
        num_workers=0,
        pin_memory=True)

    # Create a simplified ResNet with half resolution.
    model = CustomResNet(18, num_classes=100, halfsize=True)
    checkpoint = torch.load(
        os.path.join(experiment_dir, 'best_miniplaces.pth.tar'))
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    val_loss, val_acc = AverageMeter(), AverageMeter()
    for images, labels in progress(val_loader):
        batch_size = images.size(0)
        images, labels = images.cuda(), labels.cuda()
        with torch.no_grad():
            output = model(images)
            loss = criterion(output, labels)
            _, pred = output.max(1)
            accuracy = labels.eq(pred).float().sum().item() / batch_size
        val_loss.update(loss.item(), batch_size)
        val_acc.update(accuracy, batch_size)
        post_progress(l=val_loss.avg, a=val_acc.avg)
    print_progress('Loss %e, validation accuracy %.4f' %
                   (val_loss.avg, val_acc.avg))
    with open(os.path.join(experiment_dir, 'adversarial_test.json'), 'w') as f:
        json.dump(dict(
            adversarial_acc=val_acc.avg,
            adversarial_loss=val_loss.avg), f)
Example #2
0
def main():
    """Evaluate the best miniplaces ResNet checkpoint on the validation set.

    Loads ``experiment/resnet/best_miniplaces.pth.tar`` into a half-resolution
    ResNet-18 with 100 classes.  It then runs the unperturbed miniplaces
    validation images through it; these are resized with
    ``transforms.Resize(128)`` and not cropped.  Running loss and accuracy
    are shown while evaluating, and the final values are printed.
    """
    progress = default_progress()
    experiment_dir = 'experiment/resnet'
    val_loader = torch.utils.data.DataLoader(
        CachedImageFolder(
            'dataset/miniplaces/simple/val',
            transform=transforms.Compose([
                transforms.Resize(128),
                # transforms.CenterCrop(112),
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV),
            ])),
        batch_size=32,
        shuffle=False,
        num_workers=24,
        pin_memory=True)

    # Create a simplified ResNet with half resolution.
    model = CustomResNet(18, num_classes=100, halfsize=True)
    checkpoint = torch.load(
        os.path.join(experiment_dir, 'best_miniplaces.pth.tar'))
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    val_loss, val_acc = AverageMeter(), AverageMeter()
    for images, labels in progress(val_loader):
        batch_size = images.size(0)
        images, labels = images.cuda(), labels.cuda()
        with torch.no_grad():
            output = model(images)
            loss = criterion(output, labels)
            _, pred = output.max(1)
            accuracy = labels.eq(pred).float().sum().item() / batch_size
        val_loss.update(loss.item(), batch_size)
        val_acc.update(accuracy, batch_size)
        post_progress(l=val_loss.avg, a=val_acc.avg)
    print_progress('Loss %e, validation accuracy %.4f' %
                   (val_loss.avg, val_acc.avg))
Example #3
0
def main():
    """Compare accuracy and layer activations across input perturbations.

    Command line:
        --expdir  experiment directory containing ``best_miniplaces.pth.tar``
                  (default ``experiment/resnet``).

    The miniplaces validation set is run through the model once per
    perturbation.  The loaders are unshuffled and zipped, so every
    perturbed batch holds the same images as the baseline batch.  Entry 0 is
    the all-zeros "unattacked" baseline.  For every other perturbation and
    every retained layer, these statistics are accumulated relative to the
    baseline activations:

    * RMS diff  - square root of the mean squared activation difference;
    * maxdiff   - per-image maximum of the *signed* difference
                  (perturbed - baseline), averaged over images;
    * signflip  - fraction of activations whose ``> 0`` status changes.

    Per-perturbation accuracy and these statistics are reported through
    ``print_progress``.
    """
    parser = argparse.ArgumentParser(description='Adversarial test')
    parser.add_argument('--expdir',
                        type=str,
                        default='experiment/resnet',
                        help='experiment directory containing model')
    args = parser.parse_args()
    progress = default_progress()
    experiment_dir = args.expdir
    # (display name, array) pairs, kept together so a name can never drift
    # out of alignment with its perturbation.  Index 0 is the baseline.
    named_perturbations = [
        ("unattacked", numpy.zeros((224, 224, 3), dtype='float32')),
        ("universal adversary", numpy.load('perturbation/VGG-19.npy')),
        ("synthetic", numpy.load('perturbation/perturb_synth.npy')),
        ("histmatch",
         numpy.load('perturbation/perturb_synth_histmatch.npy')),
        ("newspectrum",
         numpy.load('perturbation/perturb_synth_newspectrum.npy')),
        ("rotated",
         numpy.load('perturbation/perturbation_rotated.npy')),
        ("rotated_averaged",
         numpy.load('perturbation/perturbation_rotated_averaged.npy')),
        ("noisefree",
         numpy.load('perturbation/perturbation_noisefree.npy')),
        ("noisefree_nodc",
         numpy.load('perturbation/perturbation_noisefree_nodc.npy')),
    ]
    perturbation_name = [name for name, _ in named_perturbations]
    perturbations = [array for _, array in named_perturbations]
    print('Original std %e new std %e' %
          (numpy.std(perturbations[1]), numpy.std(perturbations[2])))
    # Rescale the synthetic perturbation in place so its standard deviation
    # matches the universal adversary's.
    perturbations[2] *= (numpy.std(perturbations[1]) /
                         numpy.std(perturbations[2]))
    # One unshuffled loader per perturbation; each adds a 128x128 slice of
    # its perturbation to the resized image.
    loaders = [
        torch.utils.data.DataLoader(
            CachedImageFolder(
                'dataset/miniplaces/simple/val',
                transform=transforms.Compose([
                    transforms.Resize(128),
                    # transforms.CenterCrop(112),
                    AddPerturbation(perturbation[40:168, 40:168]),
                    transforms.ToTensor(),
                    transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV),
                ])),
            batch_size=32,
            shuffle=False,
            num_workers=0,
            pin_memory=True) for perturbation in perturbations
    ]

    # Create a simplified ResNet with half resolution, recording the
    # activations of the named layers on every forward pass.
    model = CustomResNet(18, num_classes=100, halfsize=True)
    layernames = ['relu', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']
    retain_layers(model, layernames)
    checkpoint = torch.load(
        os.path.join(experiment_dir, 'best_miniplaces.pth.tar'))
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.cuda()

    val_acc = [AverageMeter() for _ in loaders]
    # diffs[i][layer] etc. hold statistics for perturbation i vs. baseline.
    diffs, maxdiffs, signflips = [
        [defaultdict(AverageMeter) for _ in perturbations] for _ in [1, 2, 3]
    ]
    for all_batches in progress(zip(*loaders), total=len(loaders[0])):
        indata = [batch[0].cuda() for batch in all_batches]
        # Labels are identical across loaders; take them from the baseline.
        target = all_batches[0][1].cuda()
        batch_size = len(target)
        retained = defaultdict(list)
        with torch.no_grad():
            for i, inp in enumerate(indata):
                output = model(inp)
                _, pred = output.max(1)
                accuracy = target.eq(pred).float().sum().item() / inp.size(0)
                val_acc[i].update(accuracy, inp.size(0))
                for layer, data in model.retained.items():
                    retained[layer].append(data)
        for layer, vals in retained.items():
            baseline = vals[0]
            baseline_positive = (baseline > 0).float()
            for i in range(1, len(indata)):
                delta = vals[i] - baseline
                diffs[i][layer].update(
                    delta.pow(2).mean().item(), batch_size)
                maxdiffs[i][layer].update(
                    delta.view(batch_size, -1).max(1)[0].mean().item(),
                    batch_size)
                signflips[i][layer].update(
                    ((vals[i] > 0).float() -
                     baseline_positive).abs().mean().item(), batch_size)
        post_progress(a=val_acc[0].avg)

    # Report on findings.
    for i, acc in enumerate(val_acc):
        print_progress('Test #%d (%s), validation accuracy %.4f' %
                       (i, perturbation_name[i], acc.avg))
        if i == 0:
            continue  # The baseline has no difference statistics.
        for layer in layernames:
            print_progress(
                'Layer %s RMS diff %.3e maxdiff %.3e signflip %.3e' %
                (layer, math.sqrt(diffs[i][layer].avg),
                 maxdiffs[i][layer].avg, signflips[i][layer].avg))
Example #4
0
def main():
    """Train a half-resolution ResNet-18 on miniplaces with FiltDoubleBackpropLoss.

    The model is built with ``extra_output=['maxpool']``.  Its forward
    result is handed whole to the criterion, which returns
    ``(loss, unreg_loss, grad_loss)``; element 0 of the result is used as the
    class scores.  Adam is used with a learning rate decaying linearly from
    ``init_lr``.  Every 1000 iterations, and once more when training ends,
    the model is validated and checkpointed to ``experiment_dir``.  The
    best-accuracy checkpoint is also copied to ``best_miniplaces.pth.tar``.
    """
    progress = default_progress()
    experiment_dir = 'experiment/filt4_resnet'
    # Here's our data
    train_loader = torch.utils.data.DataLoader(
        CachedImageFolder(
            'dataset/miniplaces/simple/train',
            transform=transforms.Compose([
                transforms.Resize(128),
                transforms.RandomCrop(112),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV),
            ])),
        batch_size=32,
        shuffle=True,
        num_workers=24,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CachedImageFolder(
            'dataset/miniplaces/simple/val',
            transform=transforms.Compose([
                transforms.Resize(128),
                # transforms.CenterCrop(112),
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV),
            ])),
        batch_size=32,
        shuffle=False,
        num_workers=24,
        pin_memory=True)
    # Create a simplified ResNet with half resolution.
    model = CustomResNet(18,
                         num_classes=100,
                         halfsize=True,
                         extra_output=['maxpool'])  # right after conv1
    model.cuda()

    # Training schedule: max_iter batches with linear learning-rate decay.
    # TODO: tune these hyperparameters.
    # init_lr = 0.002
    init_lr = 1e-4
    # max_iter = 40000 - 34.5% @1
    # max_iter = 50000 - 37% @1
    # max_iter = 80000 - 39.7% @1
    # max_iter = 100000 - 40.1% @1
    max_iter = 50000
    criterion = FiltDoubleBackpropLoss(1e4)
    # Pass init_lr explicitly.  Otherwise the first step runs at Adam's own
    # default learning rate, because the decay loop below only overwrites
    # param_group['lr'] after that step.
    optimizer = torch.optim.Adam(model.parameters(), lr=init_lr)
    iter_num = 0
    best = dict(val_accuracy=0.0)
    model.train()
    # Oh, hold on.  Let's actually resume training if we already have a model.
    checkpoint_filename = 'miniplaces.pth.tar'
    best_filename = 'best_%s' % checkpoint_filename
    best_checkpoint = os.path.join(experiment_dir, best_filename)
    try_to_resume_training = False
    if try_to_resume_training and os.path.exists(best_checkpoint):
        checkpoint = torch.load(best_checkpoint)
        iter_num = checkpoint['iter']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best['val_accuracy'] = checkpoint['accuracy']

    def save_checkpoint(state, is_best):
        """Save ``state`` as the latest checkpoint; copy it if it is the best."""
        filename = os.path.join(experiment_dir, checkpoint_filename)
        ensure_dir_for(filename)
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename,
                            os.path.join(experiment_dir, best_filename))

    def validate_and_checkpoint():
        """Measure validation accuracy, checkpoint, and update ``best``."""
        model.eval()
        # Only accuracy is tracked; the training criterion is not evaluated
        # during validation.
        val_acc = AverageMeter()
        for images, labels in progress(val_loader):
            batch_size = images.size(0)
            images, labels = images.cuda(), labels.cuda()
            with torch.no_grad():
                output = model(images)
                # Element 0 holds the class scores (presumably the remaining
                # elements are the extra 'maxpool' output).
                _, pred = output[0].max(1)
                accuracy = labels.eq(pred).float().sum().item() / batch_size
            val_acc.update(accuracy, batch_size)
            post_progress(a=val_acc.avg * 100.0)
        save_checkpoint(
            {
                'iter': iter_num,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'accuracy': val_acc.avg,
            },
            val_acc.avg > best['val_accuracy'])
        best['val_accuracy'] = max(val_acc.avg, best['val_accuracy'])
        print_progress('Iteration %d val accuracy %.2f' %
                       (iter_num, val_acc.avg * 100.0))

    # Here is our training loop.
    while iter_num < max_iter:
        # Track the average training losses/accuracy over each epoch; the
        # meters are created once per pass over train_loader, not per batch.
        train_acc = AverageMeter()
        train_loss_u = AverageMeter()
        train_loss_g = AverageMeter()
        for images, labels in progress(train_loader):
            batch_size = images.size(0)
            images, labels = images.cuda(), labels.cuda()
            # Evaluate model
            output = model(images)
            loss, unreg_loss, grad_loss = criterion(output, labels)
            train_loss_u.update(unreg_loss.item(), batch_size)
            train_loss_g.update(grad_loss.item(), batch_size)
            # Perform one optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Also check training set accuracy
            _, pred = output[0].max(1)
            accuracy = labels.eq(pred).float().sum().item() / batch_size
            train_acc.update(accuracy)
            remaining = 1 - iter_num / float(max_iter)
            post_progress(g=train_loss_g.avg,
                          u=train_loss_u.avg,
                          a=train_acc.avg * 100.0)
            # Advance
            iter_num += 1
            if iter_num >= max_iter:
                break
            # Linear learning rate decay
            lr = init_lr * remaining
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            # Occasionally check validation set accuracy and checkpoint
            if iter_num % 1000 == 0:
                validate_and_checkpoint()
                model.train()
    # The break above fires before the periodic check.  Without this call the
    # weights from the final stretch of training would never be validated
    # or saved.
    validate_and_checkpoint()
Example #5
0
def main():
    """Train a half-resolution AlexNet on miniplaces from scratch.

    Weights use Kaiming-uniform initialization.  Biases start at 0 in
    conv1/conv3/fc8 and at 1 elsewhere.  Training uses SGD with momentum
    and a learning rate decaying linearly from ``init_lr`` over ``max_iter``
    batches.  Every 1000 iterations, and once more when training ends, the
    model is validated and checkpointed to ``experiment_dir``.  The
    best-accuracy checkpoint is also copied to ``best_miniplaces.pth.tar``.
    """
    progress = default_progress()
    experiment_dir = 'experiment/miniplaces'
    # Here's our data
    train_loader = torch.utils.data.DataLoader(
        CachedImageFolder(
            'dataset/miniplaces/simple/train',
            transform=transforms.Compose([
                transforms.Resize(128),
                transforms.RandomCrop(119),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)
            ])),
        batch_size=64,
        shuffle=True,
        num_workers=6,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CachedImageFolder(
            'dataset/miniplaces/simple/val',
            transform=transforms.Compose([
                transforms.Resize(128),
                transforms.CenterCrop(119),
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)
            ])),
        batch_size=512,
        shuffle=False,
        num_workers=6,
        pin_memory=True)
    # Create a simplified AlexNet with half resolution.
    model = AlexNet(first_layer='conv1',
                    last_layer='fc8',
                    layer_sizes=dict(fc6=2048, fc7=2048),
                    output_channels=100,
                    half_resolution=True,
                    include_lrn=False,
                    split_groups=False).cuda()
    # Use Kaiming initialization for the weights.  Biases in layers whose
    # names start with one of these prefixes start at 0.  All other biases
    # start positive to avoid dead neurons.
    zero_bias_prefixes = ('conv1', 'conv3', 'fc8')
    for name, val in model.named_parameters():
        if 'weight' in name:
            init.kaiming_uniform_(val)
        else:
            assert 'bias' in name
            init.constant_(val,
                           0 if name.startswith(zero_bias_prefixes) else 1)
    # Training schedule: max_iter batches with linear learning-rate decay.
    # TODO: tune these hyperparameters.
    init_lr = 0.002
    # max_iter = 40000 - 34.5% @1
    # max_iter = 50000 - 37% @1
    # max_iter = 80000 - 39.7% @1
    # max_iter = 100000 - 40.1% @1
    max_iter = 100000
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=init_lr,
        momentum=0.9,
        weight_decay=0.001)
    iter_num = 0
    best = dict(val_accuracy=0.0)
    model.train()
    # Oh, hold on.  Let's actually resume training if we already have a model.
    checkpoint_filename = 'miniplaces.pth.tar'
    best_filename = 'best_%s' % checkpoint_filename
    best_checkpoint = os.path.join(experiment_dir, best_filename)
    try_to_resume_training = False
    if try_to_resume_training and os.path.exists(best_checkpoint):
        checkpoint = torch.load(best_checkpoint)
        iter_num = checkpoint['iter']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best['val_accuracy'] = checkpoint['accuracy']

    def save_checkpoint(state, is_best):
        """Save ``state`` as the latest checkpoint; copy it if it is the best."""
        filename = os.path.join(experiment_dir, checkpoint_filename)
        ensure_dir_for(filename)
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename,
                            os.path.join(experiment_dir, best_filename))

    def validate_and_checkpoint():
        """Measure validation loss/accuracy, checkpoint, and update ``best``."""
        model.eval()
        val_loss, val_acc = AverageMeter(), AverageMeter()
        for images, labels in progress(val_loader):
            batch_size = images.size(0)
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            with torch.no_grad():
                output = model(images)
                loss = criterion(output, labels)
                _, pred = output.max(1)
                accuracy = labels.eq(pred).float().sum().item() / batch_size
            val_loss.update(loss.item(), batch_size)
            val_acc.update(accuracy, batch_size)
            post_progress(l=val_loss.avg, a=val_acc.avg)
        save_checkpoint(
            {
                'iter': iter_num,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'accuracy': val_acc.avg,
                'loss': val_loss.avg,
            }, val_acc.avg > best['val_accuracy'])
        best['val_accuracy'] = max(val_acc.avg, best['val_accuracy'])
        post_progress(v=val_acc.avg)

    # Here is our training loop.
    while iter_num < max_iter:
        # Track the average training loss/accuracy over each epoch; the
        # meters are created once per pass over train_loader, not per batch.
        train_loss, train_acc = AverageMeter(), AverageMeter()
        for images, labels in progress(train_loader):
            batch_size = images.size(0)
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            # Evaluate model
            output = model(images)
            loss = criterion(output, labels)
            train_loss.update(loss.item(), batch_size)
            # Perform one step of SGD
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Also check training set accuracy
            _, pred = output.max(1)
            accuracy = labels.eq(pred).float().sum().item() / batch_size
            train_acc.update(accuracy)
            remaining = 1 - iter_num / float(max_iter)
            post_progress(l=train_loss.avg,
                          a=train_acc.avg,
                          v=best['val_accuracy'])
            # Advance
            iter_num += 1
            if iter_num >= max_iter:
                break
            # Linear learning rate decay
            lr = init_lr * remaining
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            # Occasionally check validation set accuracy and checkpoint
            if iter_num % 1000 == 0:
                validate_and_checkpoint()
                model.train()
    # The break above fires before the periodic check.  Without this call the
    # weights from the final stretch of training would never be validated
    # or saved.
    validate_and_checkpoint()
Example #6
0
def main():
    """Fine-tune a pretrained miniplaces ResNet while shrinking negative weights.

    The model is initialized from ``experiment/resnet/best_miniplaces.pth.tar``
    and trained with Adam.  The learning rate decays linearly from
    ``init_lr``.  Every 1000 iterations, before that training step, every
    negative weight in parameters whose name contains 'layer4', 'conv' and
    'weight' is halved, and a summary of their signs is printed.  The model
    is also validated and checkpointed to ``experiment_dir`` every 1000
    iterations and once more when training ends.  The best-accuracy
    checkpoint is copied to ``best_miniplaces.pth.tar``.
    """
    progress = default_progress()
    experiment_dir = 'experiment/positive_resnet'
    # Here's our data
    train_loader = torch.utils.data.DataLoader(
        CachedImageFolder(
            'dataset/miniplaces/simple/train',
            transform=transforms.Compose([
                transforms.Resize(128),
                transforms.RandomCrop(112),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV),
            ])),
        batch_size=32,
        shuffle=True,
        num_workers=24,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CachedImageFolder(
            'dataset/miniplaces/simple/val',
            transform=transforms.Compose([
                transforms.Resize(128),
                # transforms.CenterCrop(112),
                transforms.ToTensor(),
                transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV),
            ])),
        batch_size=32,
        shuffle=False,
        num_workers=24,
        pin_memory=True)
    # Create a simplified ResNet with half resolution, starting from the
    # weights of the previously trained experiment/resnet model.
    model = CustomResNet(18, num_classes=100, halfsize=True)
    pretrained_checkpoint = os.path.join('experiment/resnet',
                                         'best_miniplaces.pth.tar')
    model.load_state_dict(torch.load(pretrained_checkpoint)['state_dict'])
    model.cuda()

    # Training schedule: max_iter batches with linear learning-rate decay.
    # TODO: tune these hyperparameters.
    # init_lr = 0.002
    init_lr = 1e-4
    # max_iter = 40000 - 34.5% @1
    # max_iter = 50000 - 37% @1
    # max_iter = 80000 - 39.7% @1
    # max_iter = 100000 - 40.1% @1
    max_iter = 50000
    criterion = nn.CrossEntropyLoss().cuda()
    # Pass init_lr explicitly.  Otherwise the first step runs at Adam's own
    # default learning rate, because the decay loop below only overwrites
    # param_group['lr'] after that step.
    optimizer = torch.optim.Adam(model.parameters(), lr=init_lr)
    iter_num = 0
    best = dict(val_accuracy=0.0)
    model.train()
    # Oh, hold on.  Let's actually resume training if we already have a model.
    checkpoint_filename = 'miniplaces.pth.tar'
    best_filename = 'best_%s' % checkpoint_filename
    best_checkpoint = os.path.join(experiment_dir, best_filename)
    try_to_resume_training = False
    if try_to_resume_training and os.path.exists(best_checkpoint):
        checkpoint = torch.load(best_checkpoint)
        iter_num = checkpoint['iter']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best['val_accuracy'] = checkpoint['accuracy']

    def save_checkpoint(state, is_best):
        """Save ``state`` as the latest checkpoint; copy it if it is the best."""
        filename = os.path.join(experiment_dir, checkpoint_filename)
        ensure_dir_for(filename)
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename,
                            os.path.join(experiment_dir, best_filename))

    def validate_and_checkpoint():
        """Measure validation loss/accuracy, checkpoint, and update ``best``."""
        model.eval()
        val_loss, val_acc = AverageMeter(), AverageMeter()
        for images, labels in progress(val_loader):
            batch_size = images.size(0)
            images, labels = images.cuda(), labels.cuda()
            with torch.no_grad():
                output = model(images)
                loss = criterion(output, labels)
                _, pred = output.max(1)
                accuracy = labels.eq(pred).float().sum().item() / batch_size
            val_loss.update(loss.item(), batch_size)
            val_acc.update(accuracy, batch_size)
            post_progress(l=val_loss.avg, a=val_acc.avg)
        save_checkpoint(
            {
                'iter': iter_num,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'accuracy': val_acc.avg,
                'loss': val_loss.avg,
            }, val_acc.avg > best['val_accuracy'])
        best['val_accuracy'] = max(val_acc.avg, best['val_accuracy'])
        print_progress('Iteration %d val accuracy %.2f' %
                       (iter_num, val_acc.avg * 100.0))

    def _mean(values):
        """Return the mean of ``values``, or NaN when the list is empty."""
        # Avoids ZeroDivisionError when no selected weight is negative (or
        # none is non-negative, or no parameter matched the filter).
        return sum(values) / len(values) if values else float('nan')

    def shrink_negative_weights():
        """Halve negative layer4 conv weights in place and report their signs."""
        neg_means = []
        pos_means = []
        neg_count = 0
        param_count = 0
        with torch.no_grad():
            for name, param in model.named_parameters():
                if not all(n in name for n in ['layer4', 'conv', 'weight']):
                    continue
                pc = param.numel()
                neg = param < 0
                nc = neg.int().sum().item()
                param_count += pc
                neg_count += nc
                if nc > 0:
                    neg_means.append(param[neg].mean().item())
                if nc < pc:
                    pos_means.append(param[~neg].mean().item())
                param[neg] *= 0.5
            print_progress(
                '%d/%d neg, mean %e vs %e pos' %
                (neg_count, param_count, _mean(neg_means),
                 _mean(pos_means)))

    # Here is our training loop.
    while iter_num < max_iter:
        # Track the average training loss/accuracy over each epoch; the
        # meters are created once per pass over train_loader, not per batch.
        train_loss, train_acc = AverageMeter(), AverageMeter()
        for images, labels in progress(train_loader):
            if iter_num % 1000 == 0:
                # Every 1000 iterations chop down the negative params.
                shrink_negative_weights()
            batch_size = images.size(0)
            images, labels = images.cuda(), labels.cuda()
            # Evaluate model
            output = model(images)
            loss = criterion(output, labels)
            train_loss.update(loss.item(), batch_size)
            # Perform one optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Also check training set accuracy
            _, pred = output.max(1)
            accuracy = labels.eq(pred).float().sum().item() / batch_size
            train_acc.update(accuracy)
            remaining = 1 - iter_num / float(max_iter)
            post_progress(l=train_loss.avg,
                          a=train_acc.avg,
                          v=best['val_accuracy'])
            # Advance
            iter_num += 1
            if iter_num >= max_iter:
                break
            # Linear learning rate decay
            lr = init_lr * remaining
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            # Occasionally check validation set accuracy and checkpoint
            if iter_num % 1000 == 0:
                validate_and_checkpoint()
                model.train()
    # The break above fires before the periodic check.  Without this call the
    # weights from the final stretch of training would never be validated
    # or saved.
    validate_and_checkpoint()