Example 1
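# The snippets below are excerpted from a VCL (Variational Continual Learning)
# experiments module. As a minimal sketch, these are the imports and
# module-level constants they appear to rely on; MnistResNet, Scale,
# save_model, save_generated_image, task_subset and write_as_json are
# project-local helpers defined elsewhere in the repository (class_accuracy
# and bernoulli_log_likelihood are sketched further below), and the constant
# values given here are assumptions.
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize, ToTensor
from tqdm import tqdm

CLASSIFIER_EPOCHS = 5                                # assumed value
CLASSIFIER_BATCH_SIZE = 256                          # assumed value
MNIST_CLASSIFIER_FILENAME = 'mnist_classifier.pth'   # assumed value
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
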
def train_mnist_classifier():
    """
    Train a non-VCL classifier for MNIST to be used to compute the 'classifier uncertainty'
    evaluation metric in the generative tasks.
    """
    # model and image transforms: inputs are resized to 224x224 for the ResNet
    # backbone; Scale() is presumably a project-local transform that rescales
    # pixel intensities (torchvision's own Scale is a deprecated alias of
    # Resize and takes a size argument, so it is not what is meant here)
    model = MnistResNet().to(device)
    transforms = Compose([Resize(size=(224, 224)), ToTensor(), Scale()])
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adadelta(model.parameters())

    # download dataset
    mnist_train = MNIST(root="data",
                        train=True,
                        download=True,
                        transform=transforms)
    mnist_test = MNIST(root="data",
                       train=False,
                       download=True,
                       transform=transforms)
    train_loader = DataLoader(mnist_train,
                              batch_size=CLASSIFIER_BATCH_SIZE,
                              shuffle=True)
    test_loader = DataLoader(mnist_test,
                             batch_size=CLASSIFIER_BATCH_SIZE,
                             shuffle=False)  # no need to shuffle the test set

    # train
    model.train()
    for epoch in tqdm(range(CLASSIFIER_EPOCHS), 'Epochs'):
        epoch_loss = 0
        for batch in tqdm(train_loader):
            optimizer.zero_grad()
            x, y = batch[0].to(device), batch[1].to(device)

            predictions = model(x)
            loss = loss_fn(predictions, y)
            epoch_loss += len(x) * loss.item()

            loss.backward()
            optimizer.step()

    # evaluate (no gradients needed at test time)
    model.eval()
    accuracies = []
    with torch.no_grad():
        for batch in test_loader:
            x, y = batch[0].to(device), batch[1].to(device)

            predictions = torch.argmax(model(x), dim=1)
            accuracies.append(class_accuracy(predictions, y))

    accuracy = sum(accuracies) / len(accuracies)

    print('Classifier accuracy: ' + str(accuracy))
    save_model(model, MNIST_CLASSIFIER_FILENAME)
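
# For reference, a minimal sketch of the class_accuracy helper used above and
# in run_task below; this is an assumption, and the repository's own
# implementation may differ.
def class_accuracy(y_pred, y_true):
    """Fraction of predicted class labels that match the true labels."""
    return (y_pred == y_true).float().mean().item()
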
def run_task(model,
             train_data,
             train_task_ids,
             test_data,
             test_task_ids,
             task_idx,
             coreset,
             epochs,
             batch_size,
             save_as,
             device,
             lr,
             y_transform=None,
             multiheaded=True,
             train_full_coreset=True,
             summary_writer=None):
    """
        Trains a VCL model using online variational inference on a task, and performs a coreset
        training run as well as an evaluation after training.

        :param model: the VCL model to train
        :param train_data: the complete dataset to train on, such as MNIST
        :param train_task_ids: the label-to-task mapping that defines which task each example in the dataset belongs to
        :param test_data: the complete dataset to evaluate on, such as MNIST
        :param test_task_ids: the label-to-task mapping that defines which task each test example belongs to
        :param task_idx: task being learned, maps to a specific head in the network
        :param coreset: coreset object to use in training
        :param epochs: number of training epochs
        :param batch_size: batch size used in training
        :param save_as: base directory to save into
        :param device: device to run the experiment on, either 'cpu' or 'cuda'
        :param lr: optimizer learning rate to use
        :param y_transform: transform to be applied to the dataset labels
        :param multiheaded: true if the network being trained is multi-headed
        :param train_full_coreset: if True, perform a single coreset-tuning pass over the coresets of all tasks seen so far; if False, tune separately on each task's coreset just before evaluating that task
        :param summary_writer: tensorboardX summary writer
        """

    print('TASK ', task_idx)

    # separate optimizer for each task
    optimizer = optim.Adam(model.parameters(), lr=lr)

    head = task_idx if multiheaded else 0

    # obtain correct subset of data for training, and set some aside for the coreset
    task_data = task_subset(train_data, train_task_ids, task_idx)
    non_coreset_data = coreset.select(task_data, task_id=task_idx)
    train_loader = DataLoader(non_coreset_data, batch_size, shuffle=True)

    # train
    for epoch in tqdm(range(epochs), 'Epochs: '):
        epoch_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()
            x, y_true = batch
            x = x.to(device)
            y_true = y_true.to(device)

            if y_transform is not None:
                y_true = y_transform(y_true, task_idx)

            # vcl_loss is presumably the negative online-VI ELBO: the expected
            # negative log-likelihood plus KL(posterior || previous posterior)
            # scaled by the size of the task's dataset
            loss = model.vcl_loss(x, y_true, head, len(task_data))
            epoch_loss += len(x) * loss.item()

            loss.backward()
            optimizer.step()

        if summary_writer is not None:
            summary_writer.add_scalars(
                "loss", {"TASK_" + str(task_idx): epoch_loss / len(task_data)},
                epoch)

    # after training, prepare for new task by copying posteriors into priors
    model.reset_for_new_task(head)

    # train using full coreset
    if train_full_coreset:
        model_cs_trained = coreset.coreset_train(model,
                                                 optimizer,
                                                 list(range(task_idx + 1)),
                                                 epochs,
                                                 device,
                                                 y_transform=y_transform,
                                                 multiheaded=multiheaded)

    # test
    task_accuracies = []
    tot_right = 0
    tot_tested = 0

    for test_task_idx in range(task_idx + 1):
        if not train_full_coreset:
            model_cs_trained = coreset.coreset_train(model,
                                                     optimizer,
                                                     test_task_idx,
                                                     epochs,
                                                     device,
                                                     y_transform=y_transform,
                                                     multiheaded=multiheaded)

        head = test_task_idx if multiheaded else 0

        task_data = task_subset(test_data, test_task_ids, test_task_idx)

        x = torch.stack([x for x, _ in task_data])
        y_true = torch.tensor([y for _, y in task_data])
        x = x.to(device)
        y_true = y_true.to(device)

        if y_transform is not None:
            y_true = y_transform(y_true, test_task_idx)

        y_pred = model_cs_trained.prediction(x, head)

        acc = class_accuracy(y_pred, y_true)
        print("After task {} perfomance on task {} is {}".format(
            task_idx, test_task_idx, acc))

        tot_right += acc * len(task_data)
        tot_tested += len(task_data)
        task_accuracies.append(acc)

    mean_accuracy = tot_right / tot_tested
    print("Mean accuracy:", mean_accuracy)

    if summary_writer is not None:
        task_accuracies_dict = dict(
            zip(["TASK_" + str(i) for i in range(task_idx + 1)],
                task_accuracies))
        summary_writer.add_scalars("test_accuracy", task_accuracies_dict,
                                   task_idx + 1)
        summary_writer.add_scalar("mean_posterior_variance",
                                  model._mean_posterior_variance(),
                                  task_idx + 1)
        summary_writer.add_scalar("mean_accuracy", mean_accuracy, task_idx + 1)

    write_as_json(save_as + '/accuracy.txt', task_accuracies)
    save_model(model, save_as + '_model_task_' + str(task_idx) + '.pth')
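
# A minimal driver sketch for run_task, assuming a Split-MNIST style setup with
# five binary tasks. transform_y, the model, the coreset object and the task-id
# tensors are hypothetical and only illustrate the expected call shape.
def transform_y(y, task_idx):
    # map global digit labels (2*t and 2*t + 1) to head-local labels 0/1
    return y - 2 * task_idx

for task_idx in range(5):
    run_task(model,
             train_data, train_task_ids,
             test_data, test_task_ids,
             task_idx=task_idx,
             coreset=coreset,
             epochs=100,
             batch_size=256,
             save_as='results/split_mnist',
             device=device,
             lr=0.001,
             y_transform=transform_y,
             multiheaded=True)
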
def run_generative_task(model,
                        train_data,
                        train_task_ids,
                        test_data,
                        test_task_ids,
                        task_idx,
                        coreset,
                        epochs,
                        batch_size,
                        save_as,
                        device,
                        lr,
                        evaluation_classifier,
                        multiheaded=True,
                        optimizer=None,
                        summary_writer=None):
    """
        Trains a VCL model using online variational inference on a task, and performs a coreset
        training run as well as an evaluation after training.

        :param model: the VCL model to train
        :param train_data: the complete dataset to train on, such as MNIST
        :param train_task_ids: the label-to-task mapping that defines which task each example in the dataset belongs to
        :param test_data: the complete dataset to evaluate on, such as MNIST
        :param test_task_ids: the label-to-task mapping that defines which task each test example belongs to
        :param task_idx: task being learned, maps to a specific head in the network
        :param coreset: coreset object to use in training
        :param epochs: number of training epochs
        :param batch_size: batch size used in training
        :param save_as: base directory to save into
        :param device: device to run the experiment on, either 'cpu' or 'cuda'
        :param lr: optimizer learning rate to use
        :param evaluation_classifier: classifier used for the 'classifier uncertainty' test metric
        :param optimizer: optionally, provide an existing optimizer instead of having the method create a new one
        :param multiheaded: true if the network being trained is multi-headed
        :param summary_writer: tensorboardX summary writer
        """

    print('TASK ', task_idx)

    # separate optimizer for each task
    optimizer = optimizer if optimizer is not None else optim.Adam(
        model.parameters(), lr=lr)
    head = task_idx if multiheaded else 0

    # obtain correct subset of data for training, and set some aside for the coreset
    task_data = task_subset(train_data, train_task_ids, task_idx)
    non_coreset_data = coreset.select(task_data, task_id=task_idx)
    train_loader = DataLoader(non_coreset_data, batch_size, shuffle=True)

    # train
    for epoch in tqdm(range(epochs), 'Epochs: '):
        epoch_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()
            x = batch[0].to(device)

            # vae_loss is presumably the negative VAE ELBO (reconstruction term
            # plus KL), normalised by the size of the task's dataset
            loss = model.vae_loss(x, head, len(task_data))
            epoch_loss += len(x) * loss.item()

            loss.backward()
            optimizer.step()

        if summary_writer is not None:
            summary_writer.add_scalars(
                "loss", {"TASK_" + str(task_idx): epoch_loss / len(task_data)},
                epoch)

    # after training, prepare for new task by copying posteriors into priors
    model.reset_for_new_task(head)
    # many of the model's parameters have just been reset, so the optimizer
    # could be refreshed as well (left disabled here):
    # optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # coreset train
    model_cs_trained = coreset.coreset_train_generative(
        model, optimizer, task_idx, epochs, device, multiheaded=multiheaded)

    task_confusions = []
    task_likelihoods = []

    for test_task_idx in range(task_idx + 1):
        head = test_task_idx if multiheaded else 0

        # first test using the classifier-confusion metric: generate from the
        # head under test and compare the classifier's output distribution
        # against a one-hot target for that task's digit
        y_true = torch.zeros(size=(batch_size, 10)).to(device)
        y_true[:, test_task_idx] = 1

        x_generated = model_cs_trained.generate(batch_size, head).view(
            batch_size, 1, 28, 28)
        # assumes evaluation_classifier outputs probabilities, since F.kl_div
        # expects log-probabilities as its first argument
        y_pred = evaluation_classifier(x_generated)
        task_confusions.append(F.kl_div(torch.log(y_pred), y_true).item())

        print("After task {} confusion on task {} is {}".format(
            task_idx, test_task_idx, task_confusions[-1]))

        # save a sample of 10 of the generated images
        images = x_generated[0:10]
        for count, image in enumerate(images, 0):
            save_generated_image(
                torch.squeeze(image.detach()).cpu().numpy(),
                'mnist_' + str(test_task_idx) + '_after_' + str(task_idx) +
                '_' + str(count) + '.png')

        # then test using the reconstruction log likelihood on real test data,
        # evaluating the coreset-tuned model for consistency with the
        # confusion metric above
        task_data = task_subset(test_data, test_task_ids, test_task_idx)
        x = torch.stack([x for x, _ in task_data])
        x = x.to(device)
        x_reconstructed = model_cs_trained(x, head)
        task_likelihoods.append(
            torch.mean(bernoulli_log_likelihood(x, x_reconstructed)).item())

        print("After task {} log likelihood on reconstruction task {} is {}".
              format(task_idx, test_task_idx, task_likelihoods[-1]))

    if summary_writer is not None:
        task_confusions_dict = dict(
            zip(["TASK_" + str(i) for i in range(task_idx + 1)],
                task_confusions))
        test_likelihoods_dict = dict(
            zip(["TASK_" + str(i) for i in range(task_idx + 1)],
                task_likelihoods))
        summary_writer.add_scalars("test_confusion", task_confusions_dict,
                                   task_idx + 1)
        summary_writer.add_scalars("test_likelihoods", test_likelihoods_dict,
                                   task_idx + 1)

    write_as_json(save_as + '/accuracy.txt', task_confusions)  # note: stores confusion values
    save_model(model, save_as + '_model_task_' + str(task_idx) + '.pth')
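
# A minimal sketch of the bernoulli_log_likelihood helper used in the
# reconstruction metric above, assuming the decoder outputs per-pixel Bernoulli
# means in (0, 1); the repository's own implementation may clamp or reduce
# differently.
def bernoulli_log_likelihood(x, x_reconstructed, eps=1e-8):
    # per-example log p(x | reconstruction), summed over pixels
    log_p = x * torch.log(x_reconstructed + eps) \
            + (1 - x) * torch.log(1 - x_reconstructed + eps)
    return log_p.view(log_p.shape[0], -1).sum(dim=1)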