Exemplo n.º 1
0
def full_training(args):
    """Train and validate a network for the task named by ``args.task``.

    The function returns immediately when ``<args.expdir>/results.npy``
    already exists.  Otherwise it builds the data loaders and the model,
    optionally initialises the model from the checkpoint at ``args.source``,
    runs up to ``args.nb_epochs`` epochs of ``utils_pytorch.train_<task>`` and
    ``utils_pytorch.test_<task>``, and rewrites ``results.npy`` after every
    epoch.  For every task except ``'inst_disc'`` the stored checkpoint is
    finally evaluated on the test loader.

    Args:
        args: Namespace of experiment settings.  Attributes read here:
            ``expdir``, ``task``, ``train_batch_size``, ``test_batch_size``,
            ``datadir``, ``dataset``, ``image_dim``, ``train_on_10``,
            ``train_on_half``, ``low_dim``, ``code_dim``, ``num_perms``,
            ``gray_prob``, ``mlp_depth``, ``nce_t``, ``nce_m``, ``source``,
            ``nb_epochs``, ``lr``, ``wd``, ``test_first`` and
            ``early_stopping``.  The namespace is also mutated:
            ``num_classes`` and ``criterion`` are set on it, and for
            ``'inst_disc'`` so are ``train_lemniscate`` / ``test_lemniscate``
            (both reset to ``None`` once training ends).

    Returns:
        None.
    """
    # A stored results file marks an experiment that has already been run.
    if os.path.exists(args.expdir + '/results.npy'):
        return
    os.makedirs(args.expdir, exist_ok=True)

    # Substring test: any task name containing 'ae' gets a figure directory.
    # exist_ok keeps a re-run of an interrupted experiment from crashing here.
    if 'ae' in args.task:
        os.makedirs(args.expdir + '/figs/', exist_ok=True)

    # The 'rot' task uses a quarter of the configured batch sizes
    # (presumably because each image is expanded into four rotations --
    # TODO confirm against utils_pytorch.train_rot).
    if args.task == 'rot':
        train_batch_size = args.train_batch_size // 4
        test_batch_size = args.test_batch_size // 4
    else:
        train_batch_size = args.train_batch_size
        test_batch_size = args.test_batch_size
    yield_indices = (args.task == 'inst_disc')
    datadir = args.datadir + args.dataset

    # Training/validation loaders honour the subset flags ...
    trainloader, valloader, num_classes = general_dataset_loader.prepare_data_loaders(
        datadir,
        image_dim=args.image_dim,
        yield_indices=yield_indices,
        train_batch_size=train_batch_size,
        test_batch_size=test_batch_size,
        train_on_10_percent=args.train_on_10,
        train_on_half_classes=args.train_on_half)
    # ... while the final test loader comes from a call without them.
    _, testloader, _ = general_dataset_loader.prepare_data_loaders(
        datadir,
        image_dim=args.image_dim,
        yield_indices=yield_indices,
        train_batch_size=train_batch_size,
        test_batch_size=test_batch_size,
    )

    # args keeps the dataset's class count; the local num_classes becomes the
    # network's output width for the tasks that override it.
    args.num_classes = num_classes
    if args.task == 'rot':
        num_classes = 4
    elif args.task == 'inst_disc':
        num_classes = args.low_dim

    if args.task == 'ae':
        net = models.AE([args.code_dim], image_dim=args.image_dim)
    elif args.task == 'jigsaw':
        net = JigsawModel(num_perms=args.num_perms,
                          code_dim=args.code_dim,
                          gray_prob=args.gray_prob,
                          image_dim=args.image_dim)
    else:
        net = models.resnet26(num_classes,
                              mlp_depth=args.mlp_depth,
                              normalize=(args.task == 'inst_disc'))

    if args.task == 'inst_disc':
        # One LinearAverage per split, sized by the number of samples in it.
        train_lemniscate = LinearAverage(args.low_dim,
                                         len(trainloader.dataset),
                                         args.nce_t, args.nce_m)
        train_lemniscate.cuda()
        args.train_lemniscate = train_lemniscate
        test_lemniscate = LinearAverage(args.low_dim,
                                        len(valloader.dataset),
                                        args.nce_t, args.nce_m)
        test_lemniscate.cuda()
        args.test_lemniscate = test_lemniscate

    if args.source:
        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code -- only point args.source at trusted checkpoints.
        try:
            old_net = torch.load(args.source)
        except Exception:
            # Retry with latin1 decoding and CPU placement.
            # NOTE(review): this rebinds pickle.load / pickle.Unpickler on the
            # global pickle module for the remainder of the process.
            print("Falling back encoding")
            from functools import partial
            import pickle
            pickle.load = partial(pickle.load, encoding="latin1")
            pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
            old_net = torch.load(args.source,
                                 map_location=lambda storage, loc: storage,
                                 pickle_module=pickle)

        old_net = old_net['net']
        if hasattr(old_net, "module"):
            # Unwrap a container (e.g. DataParallel) that exposes `.module`.
            old_net = old_net.module
        old_state_dict = old_net.state_dict()
        new_state_dict = net.state_dict()
        for key, weight in old_state_dict.items():
            if 'linear' not in key:
                # Everything whose name does not contain 'linear' is copied.
                new_state_dict[key] = weight
            elif key == 'linears.0.weight' and weight.shape[0] == num_classes:
                # The first linear layer is reused only when its output size
                # matches, under the new model's 'linears.0.0.*' naming.
                new_state_dict['linears.0.0.weight'] = weight
            elif key == 'linears.0.bias' and weight.shape[0] == num_classes:
                new_state_dict['linears.0.0.bias'] = weight
        net.load_state_dict(new_state_dict)

        del old_net

    net = torch.nn.DataParallel(net).cuda()
    cudnn.benchmark = True

    start_epoch = 0
    # best_acc starts at +inf for 'ae' / 'inst_disc' and at -1 otherwise
    # (presumably the former track a quantity that is minimised -- confirm in
    # the utils_pytorch test functions).
    initial_best = np.inf if args.task in ['ae', 'inst_disc'] else -1
    best_acc = initial_best
    # Rows: train loss, train acc, validation loss, validation acc.
    results = np.zeros((4, start_epoch + args.nb_epochs))

    if args.task in ['ae']:
        args.criterion = nn.MSELoss()
    else:
        args.criterion = nn.CrossEntropyLoss()
    # Only parameters that require gradients are optimised.
    trainable_params = [p for p in net.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(trainable_params,
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=args.wd)

    print("Start training")
    # Resolve the task-specific routines by attribute lookup instead of eval.
    train_func = getattr(utils_pytorch, 'train_' + args.task)
    test_func = getattr(utils_pytorch, 'test_' + args.task)
    if args.test_first:
        with torch.no_grad():
            test_func(0, valloader, net, best_acc, args, optimizer)

    for epoch in range(start_epoch, start_epoch + args.nb_epochs):
        utils_pytorch.adjust_learning_rate(optimizer, epoch, args)
        st_time = time.time()

        # Training and validation
        train_acc, train_loss = train_func(epoch, trainloader, net, args,
                                           optimizer)
        test_acc, test_loss, best_acc = test_func(epoch, valloader, net,
                                                  best_acc, args, optimizer)

        # Record statistics; the file is rewritten after every epoch.
        results[0:2, epoch] = [train_loss, train_acc]
        results[2:4, epoch] = [test_loss, test_acc]
        np.save(args.expdir + '/results.npy', results)
        print('Epoch lasted {0}'.format(time.time() - st_time))
        sys.stdout.flush()
        if (args.task == 'rot') and (train_acc >= 98) and args.early_stopping:
            break

    if args.task == 'inst_disc':
        args.train_lemniscate = None
        args.test_lemniscate = None
    else:
        # NOTE(review): unlike '/results.npy' this path is built without a
        # separator, so it only lands inside expdir when expdir ends with '/'.
        # Left unchanged because the code that writes the checkpoint is not
        # visible here -- confirm both sides agree.
        best_net = torch.load(args.expdir + 'checkpoint.t7')['net']
        final_acc, final_loss, _ = test_func(0, testloader, best_net,
                                             initial_best, args, None)
Exemplo n.º 2
0
    lemniscate = checkpoint['lemniscate']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')
    net = models.__dict__['ResNet50'](low_dim=args.low_dim)
    # define lemniscate
    lemniscate = LinearAverage(args.low_dim, ndata, args.temperature, args.memory_momentum)

# Define the loss function: NCACrossEntropy is built from the training set's
# targets, converted to a LongTensor.
criterion = NCACrossEntropy(torch.LongTensor(trainloader.dataset.targets))

# When CUDA is in use, move the network, the lemniscate and the loss to the GPU.
if use_cuda:
    net.cuda()
    # Wrap the network in DataParallel over every CUDA device torch reports.
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    lemniscate.cuda()
    criterion.cuda()
    cudnn.benchmark = True

# Evaluation-only mode: run kNN once and exit before any optimizer is created.
# NOTE(review): 30 is presumably the number of neighbours -- confirm against
# the kNN signature.
if args.test_only:
    acc = kNN(0, net, lemniscate, trainloader, testloader, 30, args.temperature)
    sys.exit(0)

# SGD over all network parameters with Nesterov momentum 0.9 and weight
# decay 5e-4; the learning rate starts at args.lr.
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)

def adjust_learning_rate(optimizer, epoch, decay_epochs=50, decay_factor=0.1):
    """Apply a step-decay schedule to every parameter group of *optimizer*.

    The learning rate is ``args.lr * decay_factor ** (epoch // decay_epochs)``,
    i.e. with the defaults the initial rate (read from the module-level
    ``args``) is divided by 10 every 50 epochs.  The computed rate is printed.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``'lr'`` entry is overwritten.
        epoch: Zero-based index of the current epoch.
        decay_epochs: Number of epochs between successive decays.
        decay_factor: Multiplier applied at each decay step.
    """
    lr = args.lr * (decay_factor ** (epoch // decay_epochs))
    print(lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr