def train_mnist_classifier():
    """
    Train a non-VCL classifier for MNIST, used to compute the 'classifier
    uncertainty' evaluation metric in the generative tasks.
    """
    # image transforms and model
    # note: Scale() is assumed to be a transform defined elsewhere in this
    # repo, since it is applied after ToTensor() (torchvision's old Scale
    # operated on PIL images)
    model = MnistResNet().to(device)
    transforms = Compose([Resize(size=(224, 224)), ToTensor(), Scale()])
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adadelta(model.parameters())

    # download dataset
    mnist_train = MNIST(root="data", train=True, download=True,
                        transform=transforms)
    mnist_test = MNIST(root="data", train=False, download=True,
                       transform=transforms)
    train_loader = DataLoader(mnist_train, batch_size=CLASSIFIER_BATCH_SIZE,
                              shuffle=True)
    test_loader = DataLoader(mnist_test, batch_size=CLASSIFIER_BATCH_SIZE,
                             shuffle=True)

    # train
    model.train()
    for epoch in tqdm(range(CLASSIFIER_EPOCHS), 'Epochs'):
        epoch_loss = 0
        for batch in tqdm(train_loader):
            optimizer.zero_grad()
            x, y = batch[0].to(device), batch[1].to(device)
            predictions = model(x)
            loss = loss_fn(predictions, y)
            epoch_loss += len(x) * loss.item()
            loss.backward()
            optimizer.step()

    # evaluate (no gradients needed at test time)
    model.eval()
    accuracies = []
    with torch.no_grad():
        for batch in test_loader:
            x, y = batch[0].to(device), batch[1].to(device)
            predictions = torch.argmax(model(x), dim=1)
            accuracies.append(class_accuracy(predictions, y))
    accuracy = sum(accuracies) / len(accuracies)
    print('Classifier accuracy: ' + str(accuracy))

    save_model(model, MNIST_CLASSIFIER_FILENAME)
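
# A minimal usage sketch for the classifier above (hypothetical wiring: this
# file does not show the loading helper, so a plain state-dict round trip via
# torch.load is assumed to pair with save_model):
#
#   train_mnist_classifier()
#   evaluation_classifier = MnistResNet().to(device)
#   evaluation_classifier.load_state_dict(torch.load(MNIST_CLASSIFIER_FILENAME))
#   evaluation_classifier.eval()
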
def run_task(model, train_data, train_task_ids, test_data, test_task_ids,
             task_idx, coreset, epochs, batch_size, save_as, device, lr,
             y_transform=None, multiheaded=True, train_full_coreset=True,
             summary_writer=None):
    """
    Trains a VCL model using online variational inference on a task, and
    performs a coreset training run as well as an evaluation after training.

    :param model: the VCL model to train
    :param train_data: the complete dataset to train on, such as MNIST
    :param train_task_ids: the label-to-task mapping that defines which task
        each example in the training set belongs to
    :param test_data: the complete dataset to test on, such as MNIST
    :param test_task_ids: the label-to-task mapping that defines which task
        each example in the test set belongs to
    :param task_idx: task being learned, maps to a specific head in the network
    :param coreset: coreset object to use in training
    :param epochs: number of training epochs
    :param batch_size: batch size used in training
    :param save_as: base directory to save into
    :param device: device to run the experiment on, either 'cpu' or 'cuda'
    :param lr: optimizer learning rate to use
    :param y_transform: transform to be applied to the dataset labels
    :param multiheaded: true if the network being trained is multi-headed
    :param train_full_coreset: true to train once on the full coreset after
        learning the task, false to train on each task's coreset just before
        testing that task
    :param summary_writer: tensorboardX summary writer
    """
    print('TASK ', task_idx)

    # separate optimizer for each task
    optimizer = optim.Adam(model.parameters(), lr=lr)
    head = task_idx if multiheaded else 0

    # obtain the correct subset of the data for this task, and set some of it
    # aside for the coreset
    task_data = task_subset(train_data, train_task_ids, task_idx)
    non_coreset_data = coreset.select(task_data, task_id=task_idx)
    train_loader = DataLoader(non_coreset_data, batch_size)

    # train
    for epoch in tqdm(range(epochs), 'Epochs: '):
        epoch_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()
            x, y_true = batch
            x = x.to(device)
            y_true = y_true.to(device)
            if y_transform is not None:
                y_true = y_transform(y_true, task_idx)
            loss = model.vcl_loss(x, y_true, head, len(task_data))
            epoch_loss += len(x) * loss.item()
            loss.backward()
            optimizer.step()

        if summary_writer is not None:
            summary_writer.add_scalars(
                "loss", {"TASK_" + str(task_idx): epoch_loss / len(task_data)},
                epoch)

    # after training, prepare for the next task by copying posteriors into priors
    model.reset_for_new_task(head)

    # train using the full coreset
    if train_full_coreset:
        model_cs_trained = coreset.coreset_train(
            model, optimizer, list(range(task_idx + 1)), epochs, device,
            y_transform=y_transform, multiheaded=multiheaded)

    # test
    task_accuracies = []
    tot_right = 0
    tot_tested = 0
    for test_task_idx in range(task_idx + 1):
        if not train_full_coreset:
            model_cs_trained = coreset.coreset_train(
                model, optimizer, test_task_idx, epochs, device,
                y_transform=y_transform, multiheaded=multiheaded)

        head = test_task_idx if multiheaded else 0
        task_data = task_subset(test_data, test_task_ids, test_task_idx)

        # stack the task's examples into a single batch tensor; torch.stack
        # avoids the elementwise copy (and the newer-PyTorch error) that
        # torch.Tensor(list_of_tensors) incurs
        x = torch.stack([torch.as_tensor(x) for x, _ in task_data]).to(device)
        y_true = torch.Tensor([y for _, y in task_data]).to(device)
        if y_transform is not None:
            y_true = y_transform(y_true, test_task_idx)

        y_pred = model_cs_trained.prediction(x, head)
        acc = class_accuracy(y_pred, y_true)
        print("After task {} performance on task {} is {}".format(
            task_idx, test_task_idx, acc))

        tot_right += acc * len(task_data)
        tot_tested += len(task_data)
        task_accuracies.append(acc)

    mean_accuracy = tot_right / tot_tested
    print("Mean accuracy:", mean_accuracy)

    if summary_writer is not None:
        task_accuracies_dict = dict(
            zip(["TASK_" + str(i) for i in range(task_idx + 1)],
                task_accuracies))
        summary_writer.add_scalars("test_accuracy", task_accuracies_dict,
                                   task_idx + 1)
        summary_writer.add_scalar("mean_posterior_variance",
                                  model._mean_posterior_variance(),
                                  task_idx + 1)
        summary_writer.add_scalar("mean_accuracy", mean_accuracy, task_idx + 1)

    write_as_json(save_as + '/accuracy.txt', task_accuracies)
    save_model(model, save_as + '_model_task_' + str(task_idx) + '.pth')
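
# Usage sketch for run_task (hypothetical names: RandomCoreset and binarize_y
# are assumed to come from the rest of this repo; a Split-MNIST-style setup
# with five binary tasks is assumed here):
#
#   coreset = RandomCoreset(size=200)
#   for task_idx in range(5):
#       run_task(model, mnist_train, train_task_ids, mnist_test, test_task_ids,
#                task_idx, coreset, epochs=100, batch_size=256,
#                save_as='results/split_mnist', device=device, lr=0.001,
#                y_transform=binarize_y, multiheaded=True)
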
def run_generative_task(model, train_data, train_task_ids, test_data,
                        test_task_ids, task_idx, coreset, epochs, batch_size,
                        save_as, device, lr, evaluation_classifier,
                        multiheaded=True, optimizer=None, summary_writer=None):
    """
    Trains a generative VCL model (a VAE) using online variational inference
    on a task, and performs a coreset training run as well as an evaluation
    after training.

    :param model: the VCL model to train
    :param train_data: the complete dataset to train on, such as MNIST
    :param train_task_ids: the label-to-task mapping that defines which task
        each example in the training set belongs to
    :param test_data: the complete dataset to test on, such as MNIST
    :param test_task_ids: the label-to-task mapping that defines which task
        each example in the test set belongs to
    :param task_idx: task being learned, maps to a specific head in the network
    :param coreset: coreset object to use in training
    :param epochs: number of training epochs
    :param batch_size: batch size used in training
    :param save_as: base directory to save into
    :param device: device to run the experiment on, either 'cpu' or 'cuda'
    :param lr: optimizer learning rate to use
    :param evaluation_classifier: classifier used for the 'classifier
        uncertainty' test metric
    :param multiheaded: true if the network being trained is multi-headed
    :param optimizer: optionally, provide an existing optimizer instead of
        having the method create a new one
    :param summary_writer: tensorboardX summary writer
    """
    print('TASK ', task_idx)

    # separate optimizer for each task, unless one was passed in
    optimizer = optimizer if optimizer is not None else optim.Adam(
        model.parameters(), lr=lr)
    head = task_idx if multiheaded else 0

    # obtain the correct subset of the data for this task, and set some of it
    # aside for the coreset
    task_data = task_subset(train_data, train_task_ids, task_idx)
    non_coreset_data = coreset.select(task_data, task_id=task_idx)
    train_loader = DataLoader(non_coreset_data, batch_size)

    # train
    for epoch in tqdm(range(epochs), 'Epochs: '):
        epoch_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()
            x = batch[0].to(device)
            loss = model.vae_loss(x, head, len(task_data))
            epoch_loss += len(x) * loss.item()
            loss.backward()
            optimizer.step()

        if summary_writer is not None:
            summary_writer.add_scalars(
                "loss", {"TASK_" + str(task_idx): epoch_loss / len(task_data)},
                epoch)

    # after training, prepare for the next task by copying posteriors into priors
    model.reset_for_new_task(head)

    # we have reset a lot of parameters in the model, so the optimizer could
    # optionally be refreshed as well:
    # optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # coreset train
    model_cs_trained = coreset.coreset_train_generative(
        model, optimizer, task_idx, epochs, device, multiheaded=multiheaded)

    # test
    task_confusions = []
    task_likelihoods = []
    for test_task_idx in range(task_idx + 1):
        head = test_task_idx if multiheaded else 0

        # first test using the classifier confusion metric: the target
        # distribution puts all mass on the digit the tested head generates
        y_true = torch.zeros(size=(batch_size, 10)).to(device)
        y_true[:, test_task_idx] = 1
        x_generated = model_cs_trained.generate(batch_size, head).view(
            batch_size, 1, 28, 28)
        # the classifier outputs logits, so convert them to log-probabilities
        # before computing the KL divergence
        y_pred = evaluation_classifier(x_generated)
        task_confusions.append(
            F.kl_div(F.log_softmax(y_pred, dim=1), y_true).item())
        print("After task {} confusion on task {} is {}".format(
            task_idx, test_task_idx, task_confusions[-1]))

        # save a sample of 10 generated images
        images = x_generated[0:10]
        for count, image in enumerate(images, 0):
            save_generated_image(
                torch.squeeze(image.detach()).cpu().numpy(),
                'mnist_' + str(test_task_idx) + '_after_' + str(task_idx) +
                '_' + str(count) + '.png')

        # then test using the log likelihood of the reconstructions
        task_data = task_subset(test_data, test_task_ids, test_task_idx)
        x = torch.stack([torch.as_tensor(x) for x, _ in task_data]).to(device)
        x_reconstructed = model(x, head)
        task_likelihoods.append(
            torch.mean(bernoulli_log_likelihood(x, x_reconstructed)).item())
        print("After task {} log likelihood on reconstruction task {} is {}"
              .format(task_idx, test_task_idx, task_likelihoods[-1]))

    if summary_writer is not None:
        task_confusions_dict = dict(
            zip(["TASK_" + str(i) for i in range(task_idx + 1)],
                task_confusions))
        test_likelihoods_dict = dict(
            zip(["TASK_" + str(i) for i in range(task_idx + 1)],
                task_likelihoods))
        summary_writer.add_scalars("test_confusion", task_confusions_dict,
                                   task_idx + 1)
        summary_writer.add_scalars("test_likelihoods", test_likelihoods_dict,
                                   task_idx + 1)

    write_as_json(save_as + '/accuracy.txt', task_confusions)
    save_model(model, save_as + '_model_task_' + str(task_idx) + '.pth')
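
# Usage sketch for run_generative_task (hypothetical wiring: one digit per
# task, as in the generative MNIST experiment; RandomCoreset is an assumed
# coreset class from the rest of this repo):
#
#   evaluation_classifier = ...  # model produced by train_mnist_classifier()
#   coreset = RandomCoreset(size=200)
#   for task_idx in range(10):
#       run_generative_task(model, mnist_train, train_task_ids, mnist_test,
#                           test_task_ids, task_idx, coreset, epochs=20,
#                           batch_size=64, save_as='results/generative_mnist',
#                           device=device, lr=0.0001,
#                           evaluation_classifier=evaluation_classifier)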