Example #1
0
def test(args, model, epoch=None, grad=False):
    """Evaluate ``model`` on the test split of ``args.dataset``.

    Args:
        args: Namespace-like config. ``args.dataset`` selects the loader
            ('mnist', or 'fashion_mnist' / 'fmnist'); ``args`` is also
            forwarded to the ``datagen`` loader.
        model: Classifier whose output is fed to ``nn.CrossEntropyLoss``.
            Batches are moved with ``.cuda()``, so the model is expected to
            live on the GPU.
        epoch: When not ``None`` (including epoch 0), a summary line is
            printed.
        grad: If ``False`` (the default), evaluation runs under
            ``torch.no_grad()`` and the returned loss is a Python float.
            Any other value keeps the autograd graph, and the returned loss
            is a tensor that can be back-propagated through.

    Returns:
        ``(acc, test_loss)``: the accuracy in [0, 1] as a float, and the mean
        per-sample cross-entropy over the whole test set.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
    """
    import contextlib

    import torch

    model.eval()
    if args.dataset == 'mnist':
        _, test_loader = datagen.load_mnist(args)
    elif args.dataset in ('fashion_mnist', 'fmnist'):
        _, test_loader = datagen.load_fashion_mnist(args)
    else:
        raise ValueError(
            'Unsupported dataset {!r}; choose --dataset <mnist, fashion_mnist>'
            .format(args.dataset))

    n_samples = len(test_loader.dataset)
    # reduction='sum' adds up per-sample losses. Dividing by the dataset size
    # below then yields the true per-sample mean. The default 'mean' reduction
    # would make the result too small by roughly the batch size.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    test_loss = 0.
    correct = 0
    # Only track gradients when the caller asked for a differentiable loss.
    grad_ctx = torch.no_grad() if grad is False else contextlib.nullcontext()
    with grad_ctx:
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            output = model(data)
            batch_loss = criterion(output, target)
            test_loss += batch_loss.item() if grad is False else batch_loss
            pred = output.argmax(dim=1)  # index of the highest logit per sample
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= n_samples
    acc = correct / n_samples
    print(acc)

    if epoch is not None:
        print('Average loss: {}, Accuracy: {}/{} ({}%)'.format(
            float(test_loss), correct, n_samples,
            100. * correct / n_samples))
    return acc, test_loss
Example #2
0
def train(args, model, grad=False, lr=1e-3):
    """Train ``model`` on ``args.dataset`` with Adam, evaluating after each epoch.

    Args:
        args: Namespace-like config. Reads ``args.dataset`` ('mnist', or
            'fashion_mnist' / 'fmnist'), ``args.ft`` and ``args.epochs``.
            It is also forwarded to the ``datagen`` loaders and to ``test``.
        model: Classifier trained with ``nn.CrossEntropyLoss``. Batches are
            moved with ``.cuda()``, so the model is expected to be on the GPU.
        grad: Unused; kept for interface compatibility.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-3).

    Returns:
        The ``(acc, loss)`` tuple returned by ``test`` after the last epoch.
        If ``args.epochs`` is 0, the model is evaluated once so a result is
        still returned.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
    """
    if args.dataset == 'mnist':
        train_loader, _ = datagen.load_mnist(args)
    elif args.dataset in ('fashion_mnist', 'fmnist'):
        train_loader, _ = datagen.load_fashion_mnist(args)
    else:
        raise ValueError(
            'Unsupported dataset {!r}; choose --dataset <mnist, fashion_mnist>'
            .format(args.dataset))

    criterion = nn.CrossEntropyLoss()
    if args.ft:
        # Fine-tuning: freeze every parameter of the first two child modules
        # so only the remaining layers are optimised.
        for child in list(model.children())[:2]:
            print('removing {}'.format(child))
            for param in child.parameters():
                param.requires_grad = False
    optimizer = optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=lr)

    metrics = None
    for epoch in range(args.epochs):
        model.train()  # test() switches the model to eval mode each epoch
        for data, target in train_loader:
            data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            batch_loss = criterion(model(data), target)
            batch_loss.backward()
            optimizer.step()
        metrics = test(args, model, epoch)
    if metrics is None:
        # No epochs ran; still report the model's current performance.
        metrics = test(args, model)
    acc, test_loss = metrics
    return acc, test_loss
Example #3
0
def load_data(args):
    """Return the loaders produced by the ``datagen`` loader for ``args.dataset``.

    Supported values of ``args.dataset``:
      * 'mnist'                   -> ``datagen.load_mnist(args)``
      * 'cifar'                   -> ``datagen.load_cifar(args)``
      * 'fmnist' / 'fashion_mnist' -> ``datagen.load_fashion_mnist(args)``
      * 'cifar_hidden'            -> ``datagen.load_cifar_hidden(args, [0])``
        (only class 0 is loaded)

    Raises:
        ValueError: if ``args.dataset`` is not one of the values above.
            Previously the function printed a message and returned ``None``,
            so callers failed later with an opaque unpacking error.
    """
    dataset = args.dataset
    if dataset == 'mnist':
        return datagen.load_mnist(args)
    if dataset == 'cifar':
        return datagen.load_cifar(args)
    if dataset in ('fmnist', 'fashion_mnist'):
        return datagen.load_fashion_mnist(args)
    if dataset == 'cifar_hidden':
        class_list = [0]  # just load class 0
        return datagen.load_cifar_hidden(args, class_list)
    raise ValueError(
        'Dataset not specified correctly ({!r}); '
        'choose --dataset <mnist, fmnist, cifar, cifar_hidden>'.format(dataset))