# Shared imports for the examples below. BCNN, Entropy, KLDivergence,
# PruneNormal and print_unique_percentage are assumed to come from the
# Bayesian library these examples accompany.
import torch
import matplotlib.pyplot as plt
from statistics import mean
from torchsummary import summary
from torchvision import transforms
from torchvision.datasets import CIFAR10, FashionMNIST
from tqdm import tqdm


def main():

    # Hyperparameters

    device = 'cpu'

    # Dataset

    transform = transforms.ToTensor()
    test_dataset = CIFAR10(root='./examples/data/',
                           train=False,
                           transform=transform,
                           download=True)

    # Model

    model = BCNN(3, 10).to(device)
    model.load_state_dict(
        torch.load('./examples/CIFAR10/cifar10_pretrained.pth',
                   map_location=device))
    model.eval()

    # Metrics

    entropy_function = Entropy(dim=-1)

    # Uncertainty measurements

    n_images = 5
    fig, axes = plt.subplots(n_images, 2)
    random_indices = torch.randint(0, len(test_dataset), (n_images, ))

    for axs, index in zip(axes, random_indices):
        image, _ = test_dataset[index]
        image = image.to(device).unsqueeze(0)
        # Slightly perturb the image (contrast, saturation and hue jitter)
        # so predictions on original and transformed inputs can be compared.
        transformation = transforms.Compose(
            [transforms.ToPILImage(),
             transforms.ColorJitter(0, .1, .1, .2)])
        transformed = transformation(image[0])
        transformed = transforms.ToTensor()(transformed).unsqueeze(0)

        for ax, name, x in zip(axs, ['original', 'transformed'],
                               [image, transformed]):
            with torch.no_grad():
                preds = model(x)

            # Average the Monte Carlo forward passes into a single
            # predictive distribution.
            agg_preds = torch.stack(preds, dim=0).mean(dim=0)
            predicted_class = agg_preds.argmax(dim=-1).item()
            entropy = entropy_function(agg_preds).item()

            ax.axis('off')
            ax.imshow(x.cpu().numpy()[0, ...].transpose(1, 2, 0))
            ax.set_title(f'{name}, E: {entropy:.3f}, C: {predicted_class}')

    plt.show()
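
The entropy metric above can be reproduced directly from the Monte Carlo
passes; a minimal sketch, assuming preds is a list of per-pass probability
tensors as returned in the loop (predictive_entropy is a hypothetical
helper, not part of the library):

def predictive_entropy(preds):
    # Stack the Monte Carlo passes and average them into a single
    # predictive distribution of shape (batch, classes).
    probs = torch.stack(preds, dim=0).mean(dim=0)
    # Shannon entropy over the class dimension; the epsilon guards
    # against log(0).
    return -(probs * torch.log(probs + 1e-12)).sum(dim=-1)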
Example #2
def main():

    # Hyperparameters

    batch_size = 1024
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Dataset

    transform = transforms.ToTensor()
    test_dataset = CIFAR10(root='./examples/data/',
                           train=False,
                           transform=transform,
                           download=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)

    # Model

    model = BCNN(3, 10).to(device)
    model.load_state_dict(
        torch.load('./examples/CIFAR10/cifar10_pretrained.pth',
                   map_location=device))
    model.eval()

    # Pruning

    pruner = PruneNormal()

    for drop_percentage in torch.linspace(0, .3, 7):
        # Prune progressively more weights, then re-measure test accuracy.
        pruner(model, drop_percentage)

        count = 0
        correct = 0
        with torch.no_grad():
            for test_x, test_y in test_loader:
                test_x, test_y = test_x.to(device), test_y.to(device)

                test_preds = model(test_x)
                # Average the Monte Carlo passes, then take the most
                # likely class per sample.
                test_predictions = torch.stack(
                    test_preds, dim=-1).mean(dim=-1).argmax(dim=-1)

                count += len(test_y)
                correct += (test_predictions == test_y).to(torch.float).sum()

        test_accuracy = correct / count

        print(f'dropped {100 * drop_percentage:.2f}% of weights',
              f'accuracy: {100 * test_accuracy:.2f}%',
              sep=', ')

        print_unique_percentage(model)
        print()
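
PruneNormal's pruning criterion is not shown here; a common heuristic for
Gaussian-posterior weights is to zero those with the lowest signal-to-noise
ratio |mu| / sigma. A hypothetical sketch along those lines (prune_by_snr
and the rho parameterisation are assumptions, not the library's API):

def prune_by_snr(mu, rho, drop_percentage):
    # sigma is recovered from rho via softplus, as is common in
    # Bayes-by-Backprop style layers.
    sigma = torch.nn.functional.softplus(rho)
    snr = mu.abs() / sigma
    # Zero the means of the k weights with the weakest signal.
    k = int(drop_percentage * snr.numel())
    if k > 0:
        threshold = snr.flatten().kthvalue(k).values
        mu = mu.masked_fill(snr <= threshold, 0.0)
    return mu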
Example #3
def main():

    # Hyperparameters

    epochs = 10
    batch_size = 1024
    learning_rate = 5e-4
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Dataset

    transform = transforms.ToTensor()
    train_dataset = FashionMNIST(root='./examples/data/',
                                 train=True,
                                 transform=transform,
                                 download=True)
    test_dataset = FashionMNIST(root='./examples/data/',
                                train=False,
                                transform=transform,
                                download=True)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)

    # Model

    model = BCNN(1, 10).to(device)
    summary(model, (1, 28, 28), device=device)

    # Loss, Metrics and Optimizer

    kld_function = KLDivergence(number_of_batches=len(train_loader))
    loss_function = torch.nn.CrossEntropyLoss()
    entropy_function = Entropy(dim=-1)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Training

    epochs_logger = tqdm(range(1, epochs + 1), desc='epoch')
    for epoch in epochs_logger:
        running_acc = []
        running_ent = []
        running_loss = []
        steps_logger = tqdm(train_loader, total=len(train_loader), desc='step')
        for x, y in steps_logger:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            preds = model(x)
            # ELBO-style objective: Monte Carlo estimate of the negative
            # log-likelihood plus the per-batch-scaled KL divergence.
            divergence = kld_function(model)
            likelihood = torch.stack(
                [loss_function(pred, y) for pred in preds]).mean()

            loss = likelihood + divergence
            loss.backward()
            optimizer.step()

            agg_preds = torch.stack(preds, dim=0).mean(dim=0)
            predictions = agg_preds.argmax(dim=-1)
            accuracy = (predictions == y).float().mean()
            entropy = entropy_function(agg_preds)

            running_acc.append(accuracy.item())
            running_ent.append(entropy.item())
            running_loss.append(loss.item())

            log_str = f'L: {mean(running_loss):.4f}, A: {mean(running_acc):.4f}, E: {mean(running_ent):.4f}'
            steps_logger.set_postfix_str(log_str)

        count = 0
        correct = 0
        with torch.no_grad():
            for test_x, test_y in test_loader:
                test_x, test_y = test_x.to(device), test_y.to(device)

                test_preds = model(test_x)
                test_predictions = torch.stack(
                    test_preds, dim=-1).mean(dim=-1).argmax(dim=-1)

                count += len(test_y)
                correct += (test_predictions == test_y).to(torch.float).sum()

        test_accuracy = correct / count

        log_str += f', TA: {test_accuracy:.4f}'
        epochs_logger.set_postfix_str(log_str)

    torch.save(model.state_dict(),
               './examples/FashionMNIST/fmnist_pretrained.pth')
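
KLDivergence(number_of_batches=len(train_loader)) suggests the KL term is
scaled so that one epoch of batches accumulates the full divergence exactly
once; a minimal sketch of that scaling, assuming a Gaussian posterior
N(mu, sigma^2) and a standard normal prior (the helper name is
hypothetical):

def scaled_gaussian_kl(mu, sigma, number_of_batches):
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over all
    # weights, divided by the number of batches per epoch.
    kl = 0.5 * (sigma.pow(2) + mu.pow(2) - 1 - 2 * torch.log(sigma)).sum()
    return kl / number_of_batches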