def generate_not_mnist():
    """
        Runs the generative MNIST experiment from the VCL paper, in which each task is
        a generative task for one of the digits in the MNIST dataset.
    """
    z_dim = 50
    h_dim = 500
    layer_width = 500
    n_tasks = 10
    multiheaded = True
    coreset_size = 40
    epochs = 400
    batch_size = 50

    transform = Compose([Flatten(), Scale()])

    # download dataset
    not_mnist_train = NOTMNIST(train=True,
                               overwrite=False,
                               transform=transform)
    not_mnist_test = NOTMNIST(train=False,
                              overwrite=False,
                              transform=transform)

    model = GenerativeVCL(z_dim=z_dim,
                          h_dim=h_dim,
                          x_dim=MNIST_FLATTENED_DIM,
                          n_heads=n_tasks,
                          encoder_h_dims=(layer_width, layer_width,
                                          layer_width),
                          decoder_head_h_dims=(layer_width, ),
                          decoder_shared_h_dims=(layer_width, ),
                          initial_posterior_variance=INITIAL_POSTERIOR_VAR,
                          mc_sampling_n=10,
                          device=device).to(device)
    evaluation_classifier = load_model(NOTMNIST_CLASSIFIER_FILENAME)
    # the evaluation classifier is a ResNet; switch it to eval mode so that
    # batch norm uses its running statistics rather than batch statistics
    evaluation_classifier.eval()

    # optimizer = Adam(model.parameters(), lr=LR)
    coreset = RandomCoreset(size=coreset_size)

    # each label is its own task, so no need to define a dictionary like in the discriminative experiments
    if isinstance(not_mnist_train[0][1], int):
        train_task_ids = torch.Tensor([y for _, y in not_mnist_train])
        test_task_ids = torch.Tensor([y for _, y in not_mnist_test])
    elif isinstance(not_mnist_train[0][1], torch.Tensor):
        train_task_ids = torch.Tensor([y.item() for _, y in not_mnist_train])
        test_task_ids = torch.Tensor([y.item() for _, y in not_mnist_test])
    else:
        raise TypeError(
            f"unexpected label type {type(not_mnist_train[0][1])}")

    summary_logdir = os.path.join("logs", "gen_n_mnist",
                                  datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)

    for task_idx in range(n_tasks):
        run_generative_task(model=model,
                            train_data=not_mnist_train,
                            train_task_ids=train_task_ids,
                            test_data=not_mnist_test,
                            test_task_ids=test_task_ids,
                            coreset=coreset,
                            task_idx=task_idx,
                            epochs=epochs,
                            batch_size=batch_size,
                            lr=LR,
                            save_as="gen_n_mnist",
                            device=device,
                            evaluation_classifier=evaluation_classifier,
                            multiheaded=multiheaded,
                            summary_writer=writer)

    writer.close()
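

# A minimal sketch of the evaluation idea behind evaluation_classifier above:
# score a batch of generated images by how confidently a frozen classifier
# assigns them to the intended class. `sample_fn` is a hypothetical stand-in
# for the generative model's sampling interface; it is not defined in this
# file, and is assumed to return a batch of images in the format the
# classifier expects.
def mean_classifier_confidence(sample_fn, classifier, task_idx, n_samples=100):
    with torch.no_grad():
        x = sample_fn(task_idx, n_samples)
        probs = torch.softmax(classifier(x), dim=1)
        # average probability assigned to the class the samples should depict
        return probs[:, task_idx].mean().item()

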
def permuted_mnist_cnn():
    """
    Runs the 'Permuted MNIST' experiment from the VCL paper, in which each task
    is obtained by applying a fixed random permutation to the pixels of each
    image. This variant uses the convolutional MultiHeadCNN model; the
    permuted_mnist function below runs the experiment with DiscriminativeVCL.
    """
    N_CLASSES = 10
    LAYER_WIDTH = 100
    N_HIDDEN_LAYERS = 2
    N_TASKS = 10
    MULTIHEADED = False
    CORESET_SIZE = 200
    EPOCHS = 100
    BATCH_SIZE = 256
    TRAIN_FULL_CORESET = False

    # one fixed random permutation per task (if every task shared a single
    # permutation, the ten tasks would be identical and there would be nothing
    # to forget)
    rng_permute = np.random.RandomState(92916)

    def make_permute_transform():
        idx_permute = torch.from_numpy(
            rng_permute.permutation(MNIST_FLATTENED_DIM))
        # bind the permutation as a default argument so each Lambda keeps its own
        return Compose([
            ToTensor(),
            Normalize((0.1307, ), (0.3081, )),
            Lambda(lambda x, idx=idx_permute: x.view(-1)[idx].view(1, 28, 28))
        ])

    transforms = [make_permute_transform() for _ in range(N_TASKS)]

    # create model, single-headed in permuted MNIST experiment
    model = Variationalize(
        MultiHeadCNN(n_heads=(N_TASKS if MULTIHEADED else 1),
                     split_mnist=False).to(device))
    coreset = RandomCoreset(size=CORESET_SIZE)

    mnist_train = ConcatDataset([
        MNIST(root="data", train=True, download=True, transform=t)
        for t in transforms
    ])
    task_size = len(mnist_train) // N_TASKS
    train_task_ids = torch.cat([
        torch.full((task_size, ), task_id, dtype=torch.long)
        for task_id in range(N_TASKS)
    ])

    mnist_test = ConcatDataset([
        MNIST(root="data", train=False, download=True, transform=t)
        for t in transforms
    ])
    task_size = len(mnist_test) // N_TASKS
    test_task_ids = torch.cat([
        torch.full((task_size, ), task_id, dtype=torch.long)
        for task_id in range(N_TASKS)
    ])

    summary_logdir = os.path.join("logs", "disc_p_mnist",
                                  datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)
    # run_point_estimate_initialisation(model=model, data=mnist_train,
    #                                   epochs=EPOCHS, batch_size=BATCH_SIZE,
    #                                   device=device, lr=LR,
    #                                   multiheaded=MULTIHEADED,
    #                                   task_ids=train_task_ids)

    # each task is classification of MNIST images with permuted pixels
    for task in range(N_TASKS):
        run_task(model=model,
                 train_data=mnist_train,
                 train_task_ids=train_task_ids,
                 test_data=mnist_test,
                 test_task_ids=test_task_ids,
                 task_idx=task,
                 coreset=coreset,
                 epochs=EPOCHS,
                 batch_size=BATCH_SIZE,
                 device=device,
                 lr=LR,
                 save_as="disc_p_mnist",
                 multiheaded=MULTIHEADED,
                 train_full_coreset=TRAIN_FULL_CORESET,
                 summary_writer=writer)

    writer.close()
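

# A standalone sanity check (a sketch, not used by the experiments) of the
# assumption behind the task-id tensors built above: ConcatDataset keeps its
# constituent datasets contiguous and in order, so task t owns the index range
# [t * task_size, (t + 1) * task_size).
from torch.utils.data import TensorDataset  # ConcatDataset is imported above


def _check_task_id_layout(n_tasks=4, task_size=3):
    parts = [TensorDataset(torch.full((task_size, ), float(t)))
             for t in range(n_tasks)]
    combined = ConcatDataset(parts)
    task_ids = torch.cat([torch.full((task_size, ), t, dtype=torch.long)
                          for t in range(n_tasks)])
    for i in range(len(combined)):
        assert int(combined[i][0].item()) == int(task_ids[i])

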
def split_mnist():
    """
    Runs the 'Split MNIST' experiment from the VCL paper, in which each task is
    a binary classification task carried out on a subset of the MNIST dataset.
    """
    N_CLASSES = 2  # TODO does it make sense to do binary classification with out_size=2 ?
    LAYER_WIDTH = 256
    N_HIDDEN_LAYERS = 2
    N_TASKS = 5
    MULTIHEADED = True
    CORESET_SIZE = 40
    EPOCHS = 1
    BATCH_SIZE = 50000
    TRAIN_FULL_CORESET = True

    transform = Compose([Scale()])

    # download dataset
    mnist_train = MNIST(root="data",
                        train=True,
                        download=True,
                        transform=transform)
    mnist_test = MNIST(root="data",
                       train=False,
                       download=True,
                       transform=transform)

    model = Variationalize(
        MultiHeadCNN(n_heads=(N_TASKS if MULTIHEADED else 1),
                     split_mnist=True).to(device))

    coreset = RandomCoreset(size=CORESET_SIZE)

    label_to_task_mapping = {
        0: 0,
        1: 0,
        2: 1,
        3: 1,
        4: 2,
        5: 2,
        6: 3,
        7: 3,
        8: 4,
        9: 4,
    }

    if isinstance(mnist_train[0][1], int):
        train_task_ids = torch.Tensor(
            [label_to_task_mapping[y] for _, y in mnist_train])
        test_task_ids = torch.Tensor(
            [label_to_task_mapping[y] for _, y in mnist_test])
    elif isinstance(mnist_train[0][1], torch.Tensor):
        train_task_ids = torch.Tensor(
            [label_to_task_mapping[y.item()] for _, y in mnist_train])
        test_task_ids = torch.Tensor(
            [label_to_task_mapping[y.item()] for _, y in mnist_test])
    else:
        raise TypeError(
            f"unexpected label type {type(mnist_train[0][1])}")

    summary_logdir = os.path.join("logs", "disc_s_mnist",
                                  datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)

    # each task is a binary classification task for a different pair of digits;
    # the odd digit of the pair (2 * task + 1) is the positive class
    binarize_y = lambda y, task: (y == (2 * task + 1)).long()

    # run_point_estimate_initialisation(model=model, data=mnist_train,
    #                                   epochs=EPOCHS, batch_size=BATCH_SIZE,
    #                                   device=device, multiheaded=MULTIHEADED,
    #                                   lr=LR, task_ids=train_task_ids,
    #                                   y_transform=binarize_y)

    for task_idx in range(N_TASKS):
        run_task(model=model,
                 train_data=mnist_train,
                 train_task_ids=train_task_ids,
                 test_data=mnist_test,
                 test_task_ids=test_task_ids,
                 coreset=coreset,
                 task_idx=task_idx,
                 epochs=EPOCHS,
                 batch_size=BATCH_SIZE,
                 lr=LR,
                 save_as="disc_s_mnist",
                 device=device,
                 multiheaded=MULTIHEADED,
                 y_transform=binarize_y,
                 train_full_coreset=TRAIN_FULL_CORESET,
                 summary_writer=writer)

    writer.close()
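

# A small self-contained illustration (not used above) of the Split MNIST label
# handling: label_to_task_mapping is just integer division by two, and
# binarize_y marks the odd digit of each pair as the positive class.
def _illustrate_split_mnist_labels():
    digits = torch.arange(10)
    # digit d belongs to task d // 2: (0, 1) -> 0, (2, 3) -> 1, ..., (8, 9) -> 4
    assert torch.equal(digits // 2,
                       torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]))
    # task 1 distinguishes digits 2 and 3; digit 3 (= 2 * 1 + 1) maps to class 1
    y = torch.tensor([2, 3, 2, 3])
    assert torch.equal((y == (2 * 1 + 1)).long(), torch.tensor([0, 1, 0, 1]))

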
def split_not_mnist():
    """
    Runs the 'Split not MNIST' experiment from the VCL paper, in which each task
    is a binary classification task carried out on a subset of the not MNIST
    character recognition dataset.
    """
    N_CLASSES = 2  # TODO does it make sense to do binary classification with out_size=2 ?
    LAYER_WIDTH = 150
    N_HIDDEN_LAYERS = 4
    N_TASKS = 5
    MULTIHEADED = True
    CORESET_SIZE = 40
    EPOCHS = 120
    BATCH_SIZE = 400000
    TRAIN_FULL_CORESET = True

    transform = Compose([Scale()])

    not_mnist_train = NOTMNIST(train=True,
                               overwrite=False,
                               transform=transform)
    not_mnist_test = NOTMNIST(train=False,
                              overwrite=False,
                              transform=transform)

    model = Variationalize(
        MultiHeadCNN(n_heads=(N_TASKS if MULTIHEADED else 1),
                     split_mnist=True).to(device))

    # optimizer = optim.Adam(model.parameters(), lr=LR)
    coreset = RandomCoreset(size=CORESET_SIZE)

    # the y labels are integers 0-9, corresponding to the letters A-J
    label_to_task_mapping = {
        0: 0,
        1: 1,
        2: 2,
        3: 3,
        4: 4,
        5: 0,
        6: 1,
        7: 2,
        8: 3,
        9: 4,
    }

    train_task_ids = torch.from_numpy(
        np.array([label_to_task_mapping[y] for _, y in not_mnist_train]))
    test_task_ids = torch.from_numpy(
        np.array([label_to_task_mapping[y] for _, y in not_mnist_test]))

    summary_logdir = os.path.join("logs", "disc_s_n_mnist",
                                  datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)

    # each task is a binary classification task for a different pair of letters;
    # binarize_y(y, n) is 1 when y is the nth letter class - A for task 0, B for task 1
    binarize_y = lambda y, task: (y == task).long()

    # run_point_estimate_initialisation(model=model, data=not_mnist_train,
    #                                   epochs=EPOCHS, batch_size=BATCH_SIZE,
    #                                   device=device, multiheaded=MULTIHEADED,
    #                                   task_ids=train_task_ids, lr=LR,
    #                                   y_transform=binarize_y)

    for task_idx in range(N_TASKS):
        run_task(model=model,
                 train_data=not_mnist_train,
                 train_task_ids=train_task_ids,
                 test_data=not_mnist_test,
                 test_task_ids=test_task_ids,
                 coreset=coreset,
                 task_idx=task_idx,
                 epochs=EPOCHS,
                 lr=LR,
                 batch_size=BATCH_SIZE,
                 save_as="disc_s_n_mnist",
                 device=device,
                 multiheaded=MULTIHEADED,
                 y_transform=binarize_y,
                 train_full_coreset=TRAIN_FULL_CORESET,
                 summary_writer=writer)

    writer.close()
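

# The notMNIST analogue (again a standalone sketch): the mapping above is
# y % 5, pairing letters (A, F), (B, G), ..., (E, J), and binarize_y marks the
# letter whose label equals the task index as the positive class.
def _illustrate_split_not_mnist_labels():
    labels = torch.arange(10)
    assert torch.equal(labels % 5,
                       torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4]))
    # task 0 distinguishes A (label 0) from F (label 5); A maps to class 1
    y = torch.tensor([0, 5, 0, 5])
    assert torch.equal((y == 0).long(), torch.tensor([1, 0, 1, 0]))

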
def permuted_mnist():
    """
    Runs the 'Permuted MNIST' experiment from the VCL paper, in which each task
    is obtained by applying a fixed random permutation to the pixels of each image.
    """
    N_CLASSES = 10
    LAYER_WIDTH = 100
    N_HIDDEN_LAYERS = 2
    N_TASKS = 10
    MULTIHEADED = False
    CORESET_SIZE = 200
    EPOCHS = 100
    BATCH_SIZE = 256
    TRAIN_FULL_CORESET = True

    # flattening and permutation used for each task
    transforms = [
        Compose(
            [Flatten(),
             Scale(),
             Permute(torch.randperm(MNIST_FLATTENED_DIM))])
        for _ in range(N_TASKS)
    ]

    # create model, single-headed in permuted MNIST experiment
    model = DiscriminativeVCL(
        in_size=MNIST_FLATTENED_DIM,
        out_size=N_CLASSES,
        layer_width=LAYER_WIDTH,
        n_hidden_layers=N_HIDDEN_LAYERS,
        n_heads=(N_TASKS if MULTIHEADED else 1),
        initial_posterior_var=INITIAL_POSTERIOR_VAR).to(device)
    coreset = RandomCoreset(size=CORESET_SIZE)

    mnist_train = ConcatDataset([
        MNIST(root="data", train=True, download=True, transform=t)
        for t in transforms
    ])
    task_size = len(mnist_train) // N_TASKS
    train_task_ids = torch.cat(
        [torch.full((task_size, ), task_id, dtype=torch.long)
         for task_id in range(N_TASKS)])

    mnist_test = ConcatDataset([
        MNIST(root="data", train=False, download=True, transform=t)
        for t in transforms
    ])
    task_size = len(mnist_test) // N_TASKS
    test_task_ids = torch.cat(
        [torch.full((task_size, ), task_id, dtype=torch.long)
         for task_id in range(N_TASKS)])

    summary_logdir = os.path.join("logs", "disc_p_mnist",
                                  datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)
    run_point_estimate_initialisation(model=model,
                                      data=mnist_train,
                                      epochs=EPOCHS,
                                      batch_size=BATCH_SIZE,
                                      device=device,
                                      lr=LR,
                                      multiheaded=MULTIHEADED,
                                      task_ids=train_task_ids)

    # each task is classification of MNIST images with permuted pixels
    for task in range(N_TASKS):
        run_task(model=model,
                 train_data=mnist_train,
                 train_task_ids=train_task_ids,
                 test_data=mnist_test,
                 test_task_ids=test_task_ids,
                 task_idx=task,
                 coreset=coreset,
                 epochs=EPOCHS,
                 batch_size=BATCH_SIZE,
                 device=device,
                 lr=LR,
                 save_as="disc_p_mnist",
                 multiheaded=MULTIHEADED,
                 train_full_coreset=TRAIN_FULL_CORESET,
                 summary_writer=writer)

    writer.close()
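

# A minimal command-line entry point (a sketch; adjust to taste) that
# dispatches to one of the experiments defined above by name.
if __name__ == "__main__":
    import argparse

    experiments = {
        "generate_not_mnist": generate_not_mnist,
        "permuted_mnist_cnn": permuted_mnist_cnn,
        "split_mnist": split_mnist,
        "split_not_mnist": split_not_mnist,
        "permuted_mnist": permuted_mnist,
    }
    parser = argparse.ArgumentParser(description="Run a VCL experiment.")
    parser.add_argument("experiment", choices=sorted(experiments))
    args = parser.parse_args()
    experiments[args.experiment]()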