def _generate_1d_tasks(show_plots=True, data_random_seed=42, writer=None):
    """Generate a set of tasks for 1D regression.

    Args:
        show_plots: Visualize the generated datasets.
        data_random_seed: Random seed that should be applied to the
            synthetic data generation.
        writer: Tensorboard writer, in case plots should be logged.

    Returns:
        data_handlers: A data handler for each task (instance of class
            "ToyRegression").
        num_tasks: Number of generated tasks.
    """
    # FIXME Task generation is currently not controlled by the user via the
    # command line.
    if False:
        # Set of random polynomials.
        num_tasks = 20
        x_domains = [[-10, 10]] * num_tasks
        # Disjoint x domains.
        #tmp = np.linspace(-10, 10, num_tasks+1)
        #x_domains = list(zip(tmp[:-1], tmp[1:]))

        max_degree = 6
        pcoeffs = np.random.uniform(-1, 1, size=(num_tasks, max_degree+1))

        map_funcs = []
        for i in range(num_tasks):
            d = np.random.randint(0, max_degree)
            # Ignore highest degrees.
            c = pcoeffs[i, d:]

            # Decrease the magnitude of higher-order coefficients.
            f = .5
            for j in range(c.size - 1, -1, -1):
                c[j] *= f
                f *= f

            # We want the border points of all polynomials to not exceed a
            # certain absolute magnitude.
            bp = np.polyval(c, x_domains[i])
            s = np.max(np.abs(bp)) + 1e-5
            c = c / s * 10.

            map_funcs.append(lambda x, c=c: np.polyval(c, x))

        std = .1
    else:
        # Manually selected tasks.
        """
        tmp = np.linspace(-15, 15, num_tasks + 1)
        x_domains = list(zip(tmp[:-1], tmp[1:]))
        map_funcs = [lambda x: 2. * (x+10.),
                     lambda x: np.power(x, 2) * 2./2.5 - 10,
                     lambda x: np.power(x-10., 3) * 1./12.5]
        std = 1.
        """

        """
        map_funcs = [lambda x: 0.1 * x,
                     lambda x: np.power(x, 2) * 1e-2,
                     lambda x: np.power(x, 3) * 1e-3]
        num_tasks = len(map_funcs)
        x_domains = [[-10, 10]] * num_tasks
        std = .1
        """

        map_funcs = [lambda x: (x + 3.),
                     lambda x: 2. * np.power(x, 2) - 1,
                     lambda x: np.power(x - 3., 3)]
        num_tasks = len(map_funcs)
        x_domains = [[-4, -2], [-1, 1], [2, 4]]
        std = .05

        """
        map_funcs = [lambda x: (x+30.),
                     lambda x: .2 * np.power(x, 2) - 10,
                     lambda x: 1e-2 * np.power(x-30., 3)]
        num_tasks = len(map_funcs)
        x_domains = [[-40, -20], [-10, 10], [20, 40]]
        std = .5
        """

    dhandlers = []
    for i in range(num_tasks):
        print('Generating %d-th task.' % (i))
        dhandlers.append(ToyRegression(train_inter=x_domains[i],
            num_train=100, test_inter=x_domains[i], num_test=50,
            map_function=map_funcs[i], std=std, rseed=data_random_seed))

        if writer is not None:
            dhandlers[-1].plot_dataset(show=False)
            writer.add_figure('task_%d/dataset' % i, plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())
        elif show_plots:
            dhandlers[-1].plot_dataset()

    return dhandlers, num_tasks
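
# The following is an illustrative, self-contained sketch (not called by the
# code above; the `_demo` name is hypothetical). It verifies the
# coefficient-rescaling invariant used in `_generate_1d_tasks`: after
# `c = c / s * 10.`, the polynomial magnitude at the domain borders is
# bounded by (roughly) 10. Note, `np.polyval` treats `c[0]` as the
# highest-order coefficient, so the damping loop below suppresses
# higher-order terms the most.
def _demo_polynomial_rescaling(rseed=42):
    rng = np.random.RandomState(rseed)
    x_domain = [-10, 10]
    c = rng.uniform(-1, 1, size=7) # Degree-6 polynomial coefficients.

    # Damp higher-order coefficients, as done above.
    f = .5
    for j in range(c.size - 1, -1, -1):
        c[j] *= f
        f *= f

    # Rescale such that the border values have magnitude <= ~10.
    bp = np.polyval(c, x_domain)
    s = np.max(np.abs(bp)) + 1e-5
    c = c / s * 10.

    assert np.all(np.abs(np.polyval(c, x_domain)) <= 10.)
    return c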

def evaluate(task_id, data, mnet, hnet, device, config, shared, logger,
             writer, train_iter=None):
    """Evaluate the training progress.

    Evaluate the performance of the network on a single task (that is
    currently being trained) on the validation set.

    Note, if no validation set is available, the test set will be used
    instead.

    Args:
        (....): See docstring of method :func:`train`. Note, `hnet` can be
            passed as :code:`None`. In this case, no weights are passed to
            the `forward` method of the main network.
        train_iter: The current training iteration. If not given, the
            `writer` will not be used.
    """
    if train_iter is None:
        logger.info('# Evaluating training ...')
    else:
        logger.info('# Evaluating network on task %d ' % (task_id+1) +
                    'before running training step %d ...' % (train_iter))

    # TODO Write histograms of weight samples to tensorboard.

    mnet.eval()
    if hnet is not None:
        hnet.eval()

    with torch.no_grad():
        # Note, if no validation set exists, we use the training data to
        # compute the MSE (note, test data may contain out-of-distribution
        # data in our setup).
        split_type = 'train' if data.num_val_samples == 0 else 'val'
        if split_type == 'train':
            logger.debug('Eval - Using training set as no validation set ' +
                         'is available.')

        mse_val, val_struct = train_utils.compute_mse(task_id, data, mnet,
            hnet, device, config, shared, split_type=split_type)

        ident = 'training' if split_type == 'train' else 'validation'
        logger.info('Eval - Mean MSE on %s set: %f (std: %g).'
                    % (ident, mse_val, val_struct.mse_vals.std()))

        # In contrast, we visualize uncertainty using the test set.
        mse_test, test_struct = train_utils.compute_mse(task_id, data, mnet,
            hnet, device, config, shared, split_type='test',
            return_dataset=True, return_predictions=True)
        logger.debug('Eval - Mean MSE on test set: %f (std: %g).'
                     % (mse_test, test_struct.mse_vals.std()))

        if config.show_plots or train_iter is not None:
            train_utils.plot_predictive_distribution(data,
                test_struct.inputs, test_struct.predictions,
                show_raw_pred=True, figsize=(10, 4),
                show=train_iter is None)
            if train_iter is not None:
                writer.add_figure('task_%d/predictions' % task_id,
                                  plt.gcf(), train_iter,
                                  close=not config.show_plots)
                if config.show_plots:
                    utils.repair_canvas_and_show_fig(plt.gcf())

                writer.add_scalar('eval/task_%d/val_mse' % task_id, mse_val,
                                  train_iter)
                writer.add_scalar('eval/task_%d/test_mse' % task_id,
                                  mse_test, train_iter)

    logger.info('# Evaluating training ... Done')

def test(data_handlers, mnet, hnet, device, config, shared, logger, writer,
         hhnet=None, save_fig=True):
    """Test the performance on all tasks.

    Tasks are assumed to be regression tasks.

    Args:
        (....): See docstring of method
            :func:`probabilistic.train_vi.train`.
        data_handlers: A list of data handlers, each representing a task.
        save_fig: Whether the figures should be saved in the output folder.
    """
    logger.info('### Testing all trained tasks ... ###')

    if hasattr(config, 'mean_only') and config.mean_only:
        warn('Task inference as performed in this test method doesn\'t ' +
             'make any sense, as the deterministic main network has no ' +
             'notion of uncertainty.')

    pcutils.set_train_mode(False, mnet, hnet, hhnet, None)

    n = len(data_handlers)

    disable_lrt_test = config.disable_lrt_test if \
        hasattr(config, 'disable_lrt_test') else None

    if not hasattr(shared, 'current_mse'):
        shared.current_mse = np.ones(n) * -1.
    elif shared.current_mse.size < n:
        tmp = shared.current_mse
        shared.current_mse = np.ones(n) * -1.
        shared.current_mse[:tmp.size] = tmp

    # Current MSE value on test set.
    test_mse = np.ones(n) * -1.
    # Current MSE value using the mean prediction of the inferred embedding.
    inferred_val_mse = np.ones(n) * -1.
    # Task inference accuracies.
    task_infer_val_accs = np.ones(n) * -1.

    with torch.no_grad():
        # We need to keep data for plotting results on all tasks later on.
        val_inputs = []
        val_targets = [] # Needed to compute MSE values.
        val_preds_mean = []
        val_preds_std = []

        # Which uncertainties have been measured per sample and task. The
        # argmin over all tasks gives the predicted task.
        val_task_preds = []

        test_inputs = []
        test_preds_mean = []
        test_preds_std = []

        normal_post = None
        if 'ewc' in shared.experiment_type:
            assert hnet is None
            normal_post = ewcutil.build_ewc_posterior(data_handlers, mnet,
                device, config, shared, logger, writer, n, task_id=n-1)

        if config.train_from_scratch and n > 1:
            # We need to iterate over different networks when we want to
            # measure the uncertainty of dataset i on task j.
            # Note, we will always load the corresponding checkpoint of
            # task j before using these networks.
            if 'avb' in shared.experiment_type or \
                    'ssge' in shared.experiment_type or \
                    'ewc' in shared.experiment_type:
                mnet_other, hnet_other, hhnet_other, _ = \
                    pcutils.generate_networks(config, shared, logger,
                        shared.all_dhandlers, device, create_dis=False)
            else:
                assert hhnet is None
                hhnet_other = None
                non_gaussian = config.mean_only \
                    if hasattr(config, 'mean_only') else True
                mnet_other, hnet_other = train_utils.generate_gauss_networks( \
                    config, logger, shared.all_dhandlers, device,
                    create_hnet=hnet is not None,
                    non_gaussian=non_gaussian)

            pcutils.set_train_mode(False, mnet_other, hnet_other,
                                   hhnet_other, None)

        task_n_mnet = mnet
        task_n_hnet = hnet
        task_n_hhnet = hhnet
        task_n_normal_post = normal_post

        # This renaming is just a protection against myself, that I don't
        # use any of those networks (`mnet`, `hnet`, `hhnet`) in the future
        # inside the loop when training from scratch.
        if config.train_from_scratch:
            mnet = None
            hnet = None
            hhnet = None
            normal_post = None

        ### For each dataset (i.e., for each task).
        for i in range(n):
            data = data_handlers[i]

            ### We want to measure MSE values within the training range only!
            split_type = 'val'
            num_val_samples = data.num_val_samples
            if num_val_samples == 0:
                split_type = 'train'
                num_val_samples = data.num_train_samples

                logger.debug('Test: Task %d - Using training set as no '
                             % i + 'validation set is available.')

            ### Task inference.
            # We need to iterate over each task embedding and measure the
            # predictive uncertainty in order to decide which embedding to
            # use.
            data_preds = np.empty((num_val_samples, config.val_sample_size,
                                   n))
            data_preds_mean = np.empty((num_val_samples, n))
            data_preds_std = np.empty((num_val_samples, n))

            for j in range(n):
                ckpt_score_j = None
                if config.train_from_scratch and j == (n-1):
                    # Note, the networks trained on dataset (n-1) haven't
                    # been checkpointed yet.
                    mnet_j = task_n_mnet
                    hnet_j = task_n_hnet
                    hhnet_j = task_n_hhnet
                    normal_post_j = task_n_normal_post
                elif config.train_from_scratch:
                    ckpt_score_j = pmutils.load_networks(shared, j, device,
                        logger, mnet_other, hnet_other, hhnet=hhnet_other,
                        dis=None)
                    mnet_j = mnet_other
                    hnet_j = hnet_other
                    hhnet_j = hhnet_other

                    if 'ewc' in shared.experiment_type:
                        normal_post_j = ewcutil.build_ewc_posterior( \
                            data_handlers, mnet_j, device, config, shared,
                            logger, writer, n, task_id=j)
                else:
                    mnet_j = mnet
                    hnet_j = hnet
                    hhnet_j = hhnet
                    normal_post_j = normal_post

                mse_val, val_struct = train_utils.compute_mse(j, data,
                    mnet_j, hnet_j, device, config, shared, hhnet=hhnet_j,
                    split_type=split_type, return_dataset=i==j,
                    return_predictions=True, disable_lrt=disable_lrt_test,
                    normal_post=normal_post_j)

                if i == j: # I.e., we used the correct embedding.
                    # This sanity check is likely to fail as we don't
                    # deterministically sample the models.
                    #if ckpt_score_j is not None:
                    #    assert np.allclose(-mse_val, ckpt_score_j)

                    val_inputs.append(val_struct.inputs)
                    val_targets.append(val_struct.targets)
                    val_preds_mean.append( \
                        val_struct.predictions.mean(axis=1))
                    val_preds_std.append(val_struct.predictions.std(axis=1))

                    shared.current_mse[i] = mse_val

                    logger.debug('Test: Task %d - Mean MSE on %s set: %f '
                                 % (i, split_type, mse_val) +
                                 '(std: %g).' % (val_struct.mse_vals.std()))
                    writer.add_scalar('test/task_%d/val_mse' % i,
                                      shared.current_mse[i], n)

                    # The test set spans into the OOD range and can be used
                    # to visualize how uncertainty behaves outside the
                    # in-distribution range.
                    mse_test, test_struct = train_utils.compute_mse(i, data,
                        mnet_j, hnet_j, device, config, shared,
                        hhnet=hhnet_j, split_type='test',
                        return_dataset=True, return_predictions=True,
                        disable_lrt=disable_lrt_test,
                        normal_post=normal_post_j)

                data_preds[:, :, j] = val_struct.predictions
                data_preds_mean[:, j] = val_struct.predictions.mean(axis=1)
                ### We interpret this value as the certainty of the
                ### prediction, i.e., how certain is our system that each of
                ### the samples belongs to task j?
                data_preds_std[:, j] = val_struct.predictions.std(axis=1)

            val_task_preds.append(data_preds_std)

            ### Compute task inference accuracy.
            inferred_task_ids = data_preds_std.argmin(axis=1)
            num_correct = np.sum(inferred_task_ids == i)
            accuracy = 100. * num_correct / num_val_samples

            task_infer_val_accs[i] = accuracy

            logger.debug('Test: Task %d - Accuracy of task inference ' % i +
                         'on %s set: %.2f%%.' % (split_type, accuracy))
            writer.add_scalar('test/task_%d/accuracy' % i, accuracy, n)

            ### Compute MSE based on inferred embedding.
            # Note, this (commented) way of computing the mean does not take
            # into account the variance of the predictive distribution,
            # which is why we don't use it (see docstring of `compute_mse`).
            #means_of_inferred_preds = data_preds_mean[np.arange( \
            #    data_preds_mean.shape[0]), inferred_task_ids]
            #inferred_val_mse[i] = np.power(means_of_inferred_preds -
            #    val_targets[-1].squeeze(), 2).mean()

            inferred_preds = data_preds[np.arange(data_preds.shape[0]), :,
                                        inferred_task_ids]
            inferred_val_mse[i] = np.power(inferred_preds - \
                val_targets[-1].squeeze()[:, np.newaxis], 2).mean()

            logger.debug('Test: Task %d - Mean MSE on %s set using ' \
                         % (i, split_type) +
                         'inferred embeddings: %f.' % (inferred_val_mse[i]))
            writer.add_scalar('test/task_%d/inferred_val_mse' % i,
                              inferred_val_mse[i], n)

            ### We are interested in the predictive uncertainty across the
            ### whole test range!
            test_mse[i] = mse_test
            writer.add_scalar('test/task_%d/test_mse' % i, test_mse[i], n)

            test_inputs.append(test_struct.inputs.squeeze())
            test_preds_mean.append(test_struct.predictions.mean(axis=1). \
                                   squeeze())
            test_preds_std.append(test_struct.predictions.std(axis=1). \
                                  squeeze())

            if hasattr(shared, 'during_mse') and \
                    shared.during_mse[i] == -1:
                shared.during_mse[i] = shared.current_mse[i]

            if test_struct.w_hnet is not None or \
                    test_struct.w_mean is not None:
                assert hasattr(shared, 'during_weights')

                if test_struct.w_hnet is not None:
                    # We have a hyper-hypernetwork. In this case, the CL
                    # regularizer is applied to its output and therefore,
                    # these are the during weights whose Euclidean distance
                    # we want to track.
                    assert task_n_hhnet is not None
                    w_all = test_struct.w_hnet
                else:
                    assert test_struct.w_mean is not None
                    # We will be here whenever the hnet is deterministic
                    # (i.e., doesn't represent an implicit distribution).
                    w_all = list(test_struct.w_mean)
                    if test_struct.w_std is not None:
                        w_all += list(test_struct.w_std)

                W_curr = torch.cat([d.clone().view(-1) for d in w_all])
                if type(shared.during_weights[i]) == int:
                    assert shared.during_weights[i] == -1
                    shared.during_weights[i] = W_curr
                else:
                    W_during = shared.during_weights[i]
                    W_dis = torch.norm(W_curr - W_during, 2)
                    logger.info('Euclidean distance between hypernet ' +
                                'output for task %d: %g' % (i, W_dis))

        ### Compute overall task inference accuracy.
        num_correct = 0
        num_samples = 0
        for i, uncertainties in enumerate(val_task_preds):
            pred_task_ids = uncertainties.argmin(axis=1)
            num_correct += np.sum(pred_task_ids == i)
            num_samples += pred_task_ids.size

        accuracy = 100. * num_correct / num_samples
        logger.info('Task inference accuracy: %.2f%%.' % accuracy)

    # TODO Compute overall MSE on all tasks using inferred embeddings.

    ### Plot the mean predictions on all tasks.
    # (Using the validation set and the correct embedding per dataset.)
    plot_x_ranges = []
    for i in range(n):
        plot_x_ranges.append(data_handlers[i].train_x_range)

    fig_fn = None
    if save_fig:
        fig_fn = os.path.join(config.out_dir, 'val_predictions_%d' % n)

    data_inputs = val_inputs
    mean_preds = val_preds_mean

    data_handlers[0].plot_datasets(data_handlers, data_inputs, mean_preds,
        fun_xranges=plot_x_ranges, filename=fig_fn, show=False,
        publication_style=config.publication_style)
    writer.add_figure('test/val_predictions', plt.gcf(), n,
                      close=not config.show_plots)
    if config.show_plots:
        utils.repair_canvas_and_show_fig(plt.gcf())

    ### Scatter plot showing MSE per task (original + current one).
    during_mse = None
    if hasattr(shared, 'during_mse'):
        during_mse = shared.during_mse[:n]
    train_utils.plot_mse(config, writer, n, shared.current_mse[:n],
                         during_mse, save_fig=save_fig)
    additional_plots = {
        'Current Inferred Val MSE': inferred_val_mse,
        #'Current Test MSE': test_mse
    }
    train_utils.plot_mse(config, writer, n, shared.current_mse[:n],
        during_mse, baselines=additional_plots, save_fig=False,
        summary_label='test/mse_detailed')

    ### Plot predictive distributions over the test range for all tasks.
    data_inputs = test_inputs
    mean_preds = test_preds_mean
    std_preds = test_preds_std

    train_utils.plot_predictive_distributions(config, writer, data_handlers,
        data_inputs, mean_preds, std_preds, save_fig=save_fig,
        publication_style=config.publication_style)

    logger.info('Mean task MSE: %f (std: %g)'
                % (shared.current_mse[:n].mean(),
                   shared.current_mse[:n].std()))

    ### Update performance summary.
    s = shared.summary
    s['aa_mse_during'][:n] = shared.during_mse[:n].tolist()
    s['aa_mse_during_mean'] = shared.during_mse[:n].mean()
    s['aa_mse_final'][:n] = shared.current_mse[:n].tolist()
    s['aa_mse_final_mean'] = shared.current_mse[:n].mean()

    s['aa_task_inference'][:n] = task_infer_val_accs.tolist()
    s['aa_task_inference_mean'] = task_infer_val_accs.mean()

    s['aa_mse_during_inferred'][n-1] = inferred_val_mse[n-1]
    s['aa_mse_during_inferred_mean'] = \
        np.mean(s['aa_mse_during_inferred'][:n])
    s['aa_mse_final_inferred'] = inferred_val_mse[:n].tolist()
    s['aa_mse_final_inferred_mean'] = inferred_val_mse[:n].mean()

    train_utils.save_summary_dict(config, shared)

    logger.info('### Testing all trained tasks ... Done ###')
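
# Minimal, self-contained sketch (hypothetical helper, not called above) of
# the uncertainty-based task inference performed in `test`: for each input
# sample, we pick the task embedding whose predictive distribution has the
# smallest standard deviation across Monte-Carlo weight samples.
def _demo_task_inference(data_preds, correct_task_id):
    """`data_preds` has shape (num_samples, num_w_samples, num_tasks)."""
    # Per-sample, per-task predictive uncertainty.
    data_preds_std = data_preds.std(axis=1)
    # The least uncertain task embedding is the inferred task.
    inferred_task_ids = data_preds_std.argmin(axis=1)
    accuracy = 100. * np.mean(inferred_task_ids == correct_task_id)
    return inferred_task_ids, accuracy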

def generate_datasets(config, logger, writer):
    """Create a data handler per task.

    Note:
        The datasets are hard-coded in this function.

    Args:
        config (argparse.Namespace): Command-line arguments. This function
            will add the key ``num_tasks`` to this namespace if not
            existing.

            Note, this function will also add the keys ``gmm_grid_size``,
            ``gmm_grid_range_1``, ``gmm_grid_range_2``, which are used for
            plotting.
        logger: Logger object.
        writer (tensorboardX.SummaryWriter): Tensorboard logger.

    Returns:
        (list): A list of data handlers.
    """
    NUM_TRAIN = 10
    NUM_TEST = 100

    config.gmm_grid_size = 250

    dhandlers = []

    TASK_SET = 3
    if TASK_SET == 0:
        config.gmm_grid_range_1 = config.gmm_grid_range_2 = [-1, 1]

        means = [[np.array([0, 1]), np.array([0, -1])]]
        variances = [[0.05**2 * np.eye(len(mean)) for mean in means[0]]]
    elif TASK_SET == 1:
        config.gmm_grid_range_1 = config.gmm_grid_range_2 = [-6, 6]

        means = [gauss_mod.CHE_MEANS[i:i+2] for i in range(0, 6, 2)]
        variances = [gauss_mod.CHE_VARIANCES[i:i+2] \
                     for i in range(0, 6, 2)]
    elif TASK_SET == 2:
        config.gmm_grid_range_1 = config.gmm_grid_range_2 = [-9, 9]

        means = [gauss_mod.CHE_MEANS[i:i+2] for i in range(0, 6, 2)]
        variances = [[1.**2 * np.eye(len(m)) for m in mm] for mm in means]
    elif TASK_SET == 3:
        config.gmm_grid_range_1 = config.gmm_grid_range_2 = [-9, 9]

        means = [gauss_mod.CHE_MEANS[i:i+2] for i in range(0, 6, 2)]
        variances = [[.2**2 * np.eye(len(m)) for m in mm] for mm in means]
    else:
        raise NotImplementedError()

    # Note, this is a synthetic dataset where the number of tasks and the
    # number of classes per task is hard-coded inside this function.
    if hasattr(config, 'num_tasks') and config.num_tasks > len(means):
        raise ValueError('Command-line argument "num_tasks" has ' +
                         'impossible value %d (maximum value would be %d).'
                         % (config.num_tasks, len(means)))
    elif not hasattr(config, 'num_tasks'):
        config.num_tasks = len(means)
    else:
        means = means[:config.num_tasks]
        variances = variances[:config.num_tasks]

    if hasattr(config, 'num_classes_per_task'):
        raise ValueError('Command-line argument "num_classes_per_task" ' +
                         'cannot be considered by this function.')

    if hasattr(config, 'val_set_size') and config.val_set_size > 0:
        raise ValueError('GMM Dataset does not support a validation set!')

    show_plots = False
    if hasattr(config, 'show_plots'):
        show_plots = config.show_plots

    # For multiple tasks, generate a combined dataset just to create some
    # plots.
    gauss_bumps_all = get_gmm_tasks(means=list(itertools.chain(*means)),
        covs=list(itertools.chain(*variances)), num_train=NUM_TRAIN,
        num_test=NUM_TEST, map_functions=None,
        rseed=config.data_random_seed)

    if config.num_tasks > 1:
        full_data = GMMData(gauss_bumps_all, classification=True,
                            use_one_hot=True, mixing_coefficients=None)

        input_mesh = full_data.get_input_mesh( \
            x1_range=config.gmm_grid_range_1,
            x2_range=config.gmm_grid_range_2,
            grid_size=config.gmm_grid_size)

        # Plot data distribution.
        if writer is not None:
            full_data.plot_uncertainty_map(title='All Data',
                input_mesh=input_mesh, use_generative_uncertainty=True,
                sketch_components=False, show=False)
            writer.add_figure('all_tasks/data_dist', plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot ground-truth conditional uncertainty.
        if writer is not None:
            full_data.plot_uncertainty_map(title='Conditional Uncertainty',
                input_mesh=input_mesh, sketch_components=True, show=False)
            writer.add_figure('all_tasks/cond_entropy', plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot ground-truth class boundaries.
        if writer is not None:
            full_data.plot_optimal_classification(title='Class-Boundaries',
                input_mesh=input_mesh, sketch_components=True, show=False)
            writer.add_figure('all_tasks/class_boundaries', plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot ground-truth class boundaries together with all training
        # data. Note, this might visualize training points that would even
        # be misclassified by the true underlying model (due to the
        # stochastic drawing of samples).
        if writer is not None:
            full_data.plot_optimal_classification(
                title='Class-Boundaries - Training Data',
                input_mesh=input_mesh, sketch_components=True, show=False,
                sample_inputs=full_data.get_train_inputs(),
                sample_modes=np.argmax(full_data.get_train_outputs(),
                                       axis=1))
            writer.add_figure('all_tasks/class_boundaries_train',
                              plt.gcf(), close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot ground-truth class boundaries together with all test data.
        if writer is not None:
            full_data.plot_optimal_classification(
                title='Class-Boundaries - Test Data',
                input_mesh=input_mesh, sketch_components=True, show=False,
                sample_inputs=full_data.get_test_inputs(),
                sample_modes=np.argmax(full_data.get_test_outputs(),
                                       axis=1))
            writer.add_figure('all_tasks/class_boundaries_test', plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

    # Create individual task datasets.
    ii = 0
    for i in range(len(means)):
        gauss_bumps = gauss_bumps_all[ii:ii+len(means[i])]
        ii += len(means[i])

        dhandlers.append(GMMData(gauss_bumps, classification=True,
                                 use_one_hot=True,
                                 mixing_coefficients=None))

        input_mesh = dhandlers[-1].get_input_mesh( \
            x1_range=config.gmm_grid_range_1,
            x2_range=config.gmm_grid_range_2,
            grid_size=config.gmm_grid_size)

        # Plot training data.
        if writer is not None:
            dhandlers[-1].plot_uncertainty_map(title='Training Data',
                input_mesh=input_mesh, use_generative_uncertainty=True,
                sample_inputs=dhandlers[-1].get_train_inputs(),
                sample_modes=np.argmax(dhandlers[-1].get_train_outputs(), \
                                       axis=1),
                #sample_label='Training data',
                sketch_components=True, show=False)
            writer.add_figure('task_%d/train_data' % i, plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot test data.
        if writer is not None:
            dhandlers[-1].plot_uncertainty_map(title='Test Data',
                input_mesh=input_mesh, use_generative_uncertainty=True,
                sample_inputs=dhandlers[-1].get_test_inputs(),
                sample_modes=np.argmax(dhandlers[-1].get_test_outputs(), \
                                       axis=1),
                #sample_label='Test data',
                sketch_components=True, show=False)
            writer.add_figure('task_%d/test_data' % i, plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot ground-truth conditional uncertainty.
        if writer is not None:
            dhandlers[-1].plot_uncertainty_map(
                title='Conditional Uncertainty', input_mesh=input_mesh,
                sketch_components=True, show=False)
            writer.add_figure('task_%d/cond_entropy' % i, plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot ground-truth class boundaries.
        if writer is not None:
            dhandlers[-1].plot_optimal_classification(
                title='Class-Boundaries', input_mesh=input_mesh,
                sketch_components=True, show=False)
            writer.add_figure('task_%d/class_boundaries' % i, plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

    return dhandlers
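
# Illustrative sketch (hypothetical helper) that mirrors the bookkeeping
# above: the per-task mean lists are flattened into one list to build the
# combined dataset, and the same index arithmetic recovers the per-task
# slices afterwards.
def _demo_task_slices(means):
    flat = list(itertools.chain(*means))
    ii = 0
    slices = []
    for task_means in means:
        slices.append(flat[ii:ii+len(task_means)])
        ii += len(task_means)
    assert all(len(s) == len(m) for s, m in zip(slices, means))
    return slices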

def plot_gmm_prior_preds(task_id, data, mnet, hnet, hhnet, device, config,
                         shared, logger, writer, prior_mean, prior_std,
                         prior_theta=None):
    """Visualize the prior predictive entropy over the whole input space.

    Similar to function :func:`plot_gmm_preds`, but rather than sampling
    from the approximate posterior, samples are drawn from a given prior
    distribution.

    Args:
        (....): See docstring of function :func:`plot_gmm_preds`.
        prior_mean (list): A list of tensors that represent the mean of an
            explicit prior. Is expected to be ``None`` if ``prior_theta``
            is specified.
        prior_std (list): A list of tensors that represent the std of an
            explicit prior. See ``prior_mean`` for more details.
        prior_theta (list): The weights passed to ``hnet`` when drawing
            samples from the current implicit distribution, which
            represents the prior.
    """
    # FIXME Code in this function is almost identical to the one in
    # function `plot_gmm_preds`.
    assert prior_mean is None and prior_std is None or \
        prior_mean is not None and prior_std is not None
    assert (prior_theta is None or prior_mean is None) and \
        (prior_theta is not None or prior_std is not None)

    # Gather prior samples.
    prior_samples = []
    for i in range(config.val_sample_size):
        if prior_theta is not None:
            z = torch.normal(torch.zeros(1, shared.noise_dim),
                             config.latent_std).to(device)
            prior_samples.append(hnet.forward(uncond_input=z,
                                              weights=prior_theta))
        else:
            prior_samples.append(putils.sample_diag_gauss(prior_mean,
                                                          prior_std))

    input_mesh = data.get_input_mesh(x1_range=config.gmm_grid_range_1,
                                     x2_range=config.gmm_grid_range_2,
                                     grid_size=config.gmm_grid_size)

    if 'bbb' in shared.experiment_type:
        _, ret_fig = pmutils.compute_acc(task_id, data, mnet, hnet, device,
            config, shared, split_type=None, return_entropies=True,
            return_pred_labels=True, deterministic_sampling=True,
            disable_lrt=config.disable_lrt_test, in_samples=input_mesh[2],
            w_samples=prior_samples)
    else:
        assert 'avb' in shared.experiment_type or \
               'ssge' in shared.experiment_type
        _, ret_fig = pcutils.compute_acc(task_id, data, mnet, hnet, hhnet,
            device, config, shared, split_type=None, return_entropies=True,
            return_pred_labels=True, deterministic_sampling=True,
            in_samples=input_mesh[2], w_samples=prior_samples)

    # The means of other tasks.
    other_means = np.concatenate([dh.means for dh in shared.all_dhandlers],
                                 axis=0)

    # Plot entropies over the whole input space (according to
    # `input_mesh`).
    data.plot_uncertainty_map( \
        title='Entropy of prior predictive distribution',
        input_mesh=input_mesh,
        uncertainties=ret_fig.entropies.reshape(-1, 1),
        sample_inputs=other_means, sketch_components=True, show=False)
    writer.add_figure('prior/task_%d/pred_entropies' % task_id, plt.gcf(),
                      close=not config.show_plots)
    if config.show_plots:
        misc.repair_canvas_and_show_fig(plt.gcf())

    # Plot predicted class boundaries over the whole input space.
    data.plot_optimal_classification( \
        title='Prior Predicted Class-Boundaries', input_mesh=input_mesh,
        mesh_modes=ret_fig.pred_labels.reshape(-1, 1),
        sample_inputs=other_means, sketch_components=True, show=False)
    writer.add_figure('prior/task_%d/pred_class_boundaries' % task_id,
                      plt.gcf(), close=not config.show_plots)
    if config.show_plots:
        misc.repair_canvas_and_show_fig(plt.gcf())

    # Plots for single samples from the prior.
    assert 'bbb' in shared.experiment_type and not config.mean_only or \
           'avb' in shared.experiment_type and hnet is not None or \
           'ssge' in shared.experiment_type and hnet is not None
    for ii in range(min(10, len(prior_samples))):
        if 'bbb' in shared.experiment_type:
            _, ret_fig = pmutils.compute_acc(task_id, data, mnet, hnet,
                device, config, shared, split_type=None,
                return_entropies=True, return_pred_labels=True,
                deterministic_sampling=False,
                disable_lrt=config.disable_lrt_test,
                in_samples=input_mesh[2], w_samples=[prior_samples[ii]])
        else:
            assert 'avb' in shared.experiment_type or \
                   'ssge' in shared.experiment_type
            _, ret_fig = pcutils.compute_acc(task_id, data, mnet, hnet,
                hhnet, device, config, shared, split_type=None,
                return_entropies=True, return_pred_labels=True,
                deterministic_sampling=False, in_samples=input_mesh[2],
                w_samples=[prior_samples[ii]])

        # Plot entropies over the whole input space.
        data.plot_uncertainty_map( \
            title='Entropy of prior predictive distribution',
            input_mesh=input_mesh,
            uncertainties=ret_fig.entropies.reshape(-1, 1),
            sample_inputs=other_means, sketch_components=True, show=False)
        writer.add_figure('prior/task_%d/single_samples_pred_entropies' \
                          % task_id, plt.gcf(), ii,
                          close=not config.show_plots)
        if config.show_plots:
            misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot predicted class boundaries over the whole input space.
        data.plot_optimal_classification( \
            title='Prior Predicted Class-Boundaries',
            input_mesh=input_mesh,
            mesh_modes=ret_fig.pred_labels.reshape(-1, 1),
            sample_inputs=other_means, sketch_components=True, show=False)
        writer.add_figure( \
            'prior/task_%d/single_samples_pred_class_boundaries' % task_id,
            plt.gcf(), ii, close=not config.show_plots)
        if config.show_plots:
            misc.repair_canvas_and_show_fig(plt.gcf())
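
# Minimal sketch of drawing one weight sample from an explicit diagonal
# Gaussian prior via the reparametrization trick. This is an assumption
# about what `putils.sample_diag_gauss` does; the actual implementation
# may differ.
def _demo_sample_diag_gauss(mean_list, std_list):
    return [m + s * torch.randn_like(s) for m, s in zip(mean_list,
                                                        std_list)]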

def plot_gmm_preds(task_id, data, mnet, hnet, hhnet, device, config, shared,
                   logger, writer, tb_step, draw_samples=False,
                   normal_post=None):
    """Visualize the predictive entropy over the whole input space.

    The advantage of the GMM toy example is that we can visualize
    quantities such as predictions and predictive entropies over an
    arbitrarily large part of the 2D input space.

    Here, we use the current model associated with ``task_id``.

    All plots are logged to tensorboard.

    Args:
        (....): See docstring of function
            :func:`probabilistic.prob_cifar.train_avb.test`.
        task_id (int): ID of current task.
        tb_step (int): Tensorboard step for plots to be logged.
        draw_samples (bool): If ``True``, the method will also draw plots
            for single samples (if the model is non-deterministic).
        normal_post (tuple, optional): See docstring of function
            :func:`probabilistic.regression.train_utils.compute_mse`.
    """
    input_mesh = data.get_input_mesh(x1_range=config.gmm_grid_range_1,
                                     x2_range=config.gmm_grid_range_2,
                                     grid_size=config.gmm_grid_size)

    if 'bbb' in shared.experiment_type or 'ewc' in shared.experiment_type:
        disable_lrt = config.disable_lrt_test if \
            hasattr(config, 'disable_lrt_test') else False
        _, ret_fig = pmutils.compute_acc(task_id, data, mnet, hnet, device,
            config, shared, split_type=None, return_entropies=True,
            return_pred_labels=True, deterministic_sampling=True,
            disable_lrt=disable_lrt, in_samples=input_mesh[2],
            normal_post=normal_post)
    else:
        assert 'avb' in shared.experiment_type or \
               'ssge' in shared.experiment_type
        _, ret_fig = pcutils.compute_acc(task_id, data, mnet, hnet, hhnet,
            device, config, shared, split_type=None, return_entropies=True,
            return_pred_labels=True, deterministic_sampling=True,
            in_samples=input_mesh[2])

    # The means of other tasks.
    other_means = np.concatenate([dh.means for dh in shared.all_dhandlers],
                                 axis=0)

    # Plot entropies over the whole input space (according to
    # `input_mesh`).
    data.plot_uncertainty_map( \
        title='Entropy of predictive distribution',
        input_mesh=input_mesh,
        uncertainties=ret_fig.entropies.reshape(-1, 1),
        sample_inputs=other_means, sketch_components=True, show=False)
    writer.add_figure('task_%d/pred_entropies' % task_id, plt.gcf(),
                      tb_step, close=not config.show_plots)
    if config.show_plots:
        misc.repair_canvas_and_show_fig(plt.gcf())

    # Plot predicted class boundaries over the whole input space.
    data.plot_optimal_classification(title='Predicted Class-Boundaries',
        input_mesh=input_mesh,
        mesh_modes=ret_fig.pred_labels.reshape(-1, 1),
        sample_inputs=other_means, sketch_components=True, show=False)
    writer.add_figure('task_%d/pred_class_boundaries' % task_id, plt.gcf(),
                      tb_step, close=not config.show_plots)
    if config.show_plots:
        misc.repair_canvas_and_show_fig(plt.gcf())

    # If not deterministic, plot single weight samples.
    # TODO We could also plot them for EWC.
    if 'bbb' in shared.experiment_type and not config.mean_only or \
            ('avb' in shared.experiment_type or \
             'ssge' in shared.experiment_type) and hnet is not None and \
            draw_samples:
        for ii in range(10):
            if 'bbb' in shared.experiment_type:
                _, ret_fig = pmutils.compute_acc(task_id, data, mnet, hnet,
                    device, config, shared, split_type=None,
                    return_entropies=True, return_pred_labels=True,
                    deterministic_sampling=False,
                    disable_lrt=config.disable_lrt_test,
                    in_samples=input_mesh[2], num_w_samples=1)
            else:
                assert 'avb' in shared.experiment_type or \
                       'ssge' in shared.experiment_type
                _, ret_fig = pcutils.compute_acc(task_id, data, mnet, hnet,
                    hhnet, device, config, shared, split_type=None,
                    return_entropies=True, return_pred_labels=True,
                    deterministic_sampling=False,
                    in_samples=input_mesh[2], num_w_samples=1)

            # Plot entropies over the whole input space.
            data.plot_uncertainty_map( \
                title='Entropy of predictive distribution',
                input_mesh=input_mesh,
                uncertainties=ret_fig.entropies.reshape(-1, 1),
                sample_inputs=other_means, sketch_components=True,
                show=False)
            writer.add_figure('task_%d/single_samples_pred_entropies' \
                              % task_id, plt.gcf(), ii,
                              close=not config.show_plots)
            if config.show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

            # Plot predicted class boundaries over the whole input space.
            data.plot_optimal_classification( \
                title='Predicted Class-Boundaries', input_mesh=input_mesh,
                mesh_modes=ret_fig.pred_labels.reshape(-1, 1),
                sample_inputs=other_means, sketch_components=True,
                show=False)
            writer.add_figure( \
                'task_%d/single_samples_pred_class_boundaries' % task_id,
                plt.gcf(), ii, close=not config.show_plots)
            if config.show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())
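
# Minimal sketch (hypothetical helper) of how a predictive entropy map can
# be computed from per-weight-sample class probabilities: average the
# softmax outputs over weight samples, then take the entropy of the
# resulting predictive distribution at every mesh point.
def _demo_predictive_entropy(probs):
    """`probs` has shape (num_w_samples, num_mesh_points, num_classes)."""
    pred_dist = probs.mean(axis=0) # Model average.
    eps = 1e-10 # Avoid log(0).
    return -np.sum(pred_dist * np.log(pred_dist + eps), axis=1)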

def evaluate(task_id, data, mnet, hnet, hhnet, dis, device, config, shared,
             logger, writer, train_iter=None):
    """Evaluate the training progress.

    Evaluate the performance of the network on a single task (that is
    currently being trained) on the validation set.

    Note, if no validation set is available, the test set will be used
    instead.

    Args:
        (....): See docstring of method
            :func:`probabilistic.prob_cifar.train_avb.evaluate`.
    """
    # FIXME Code below is almost identical to
    # `probabilistic.regression.train_bbb.evaluate`.
    if train_iter is None:
        logger.info('# Evaluating training ...')
    else:
        logger.info('# Evaluating network on task %d ' % (task_id+1) +
                    'before running training step %d ...' % train_iter)

    pcu.set_train_mode(False, mnet, hnet, hhnet, dis)

    with torch.no_grad():
        # Note, if no validation set exists, we use the training data to
        # compute the MSE (note, test data may contain out-of-distribution
        # data in our setup).
        split_type = 'train' if data.num_val_samples == 0 else 'val'
        if split_type == 'train':
            logger.debug('Eval - Using training set as no validation set ' +
                         'is available.')

        mse_val, val_struct = train_utils.compute_mse(task_id, data, mnet,
            hnet, device, config, shared, hhnet=hhnet,
            split_type=split_type)

        ident = 'training' if split_type == 'train' else 'validation'
        logger.info('Eval - Mean MSE on %s set: %f (std: %g).'
                    % (ident, mse_val, val_struct.mse_vals.std()))

        # In contrast, we visualize uncertainty using the test set.
        mse_test, test_struct = train_utils.compute_mse(task_id, data, mnet,
            hnet, device, config, shared, hhnet=hhnet, split_type='test',
            return_dataset=True, return_predictions=True)
        logger.debug('Eval - Mean MSE on test set: %f (std: %g).'
                     % (mse_test.mean(), mse_test.std()))

        if config.show_plots or train_iter is not None:
            train_utils.plot_predictive_distribution(data,
                test_struct.inputs, test_struct.predictions,
                show_raw_pred=True, figsize=(10, 4),
                show=train_iter is None)
            if train_iter is not None:
                writer.add_figure('task_%d/predictions' % task_id,
                                  plt.gcf(), train_iter,
                                  close=not config.show_plots)
                if config.show_plots:
                    utils.repair_canvas_and_show_fig(plt.gcf())

                writer.add_scalar('eval/task_%d/val_mse' % task_id,
                                  mse_val.mean(), train_iter)
                writer.add_scalar('eval/task_%d/test_mse' % task_id,
                                  mse_test.mean(), train_iter)

        # FIXME Code below copied from
        # `probabilistic.prob_cifar.train_avb.evaluate`.
        ### Compute discriminator accuracy.
        if dis is not None and hnet is not None:
            hnet_theta = None
            if hhnet is not None:
                hnet_theta = hhnet.forward(cond_id=task_id)

            # FIXME Is it ok if I only look at how samples from the current
            # implicit distribution are classified?
            dis_out, dis_inputs = pcu.process_dis_batch(config, shared,
                config.val_sample_size, device, dis, hnet, hnet_theta,
                dist=None)
            dis_acc = (dis_out > 0).sum().detach().cpu().numpy() / \
                config.val_sample_size * 100.

            logger.debug('Eval - Discriminator accuracy: %.2f%%.'
                         % (dis_acc))
            writer.add_scalar('eval/task_%d/dis_acc' % task_id, dis_acc,
                              train_iter)

            # FIXME Summary results should be written in the test method
            # after training on a task has finished (note, eval is not
            # guaranteed to be called after or even during training). But I
            # just want to get an overview.
            s = shared.summary
            s['aa_acc_dis'][task_id] = dis_acc
            s['aa_acc_avg_dis'] = np.mean(s['aa_acc_dis'][:(task_id+1)])

            # Visualize weight samples.
            # FIXME A bit hacky.
            w_samples = dis_inputs
            if config.use_batchstats:
                w_samples = dis_inputs[:, (dis_inputs.shape[1] // 2):]
            pcu.visualize_implicit_dist(config, task_id, writer,
                                        train_iter, w_samples,
                                        figsize=(10, 6))

    logger.info('# Evaluating training ... Done')

def plot_predictive_distributions(config, writer, data_handlers, inputs,
                                  preds_mean, preds_std, save_fig=True,
                                  publication_style=False):
    """Plot the predictive distributions of several tasks into one plot.

    Args:
        config: Command-line arguments.
        writer: Tensorboard summary writer.
        data_handlers: A set of data loaders.
        inputs: A list of arrays containing the x values per task.
        preds_mean: The mean predictions corresponding to each task in
            `inputs`.
        preds_std: The std of all predictions in `preds_mean`.
        save_fig: Whether the figure should be saved in the output folder.
        publication_style: Whether plots should be made in publication
            style.
    """
    num_tasks = len(data_handlers)
    assert len(inputs) == num_tasks
    assert len(preds_mean) == num_tasks
    assert len(preds_std) == num_tasks

    colors = utils.get_colorbrewer2_colors(family='Dark2')
    if num_tasks > len(colors):
        warn('Changing to automatic color scheme as we don\'t have ' +
             'as many manual colors as tasks.')
        colors = cm.rainbow(np.linspace(0, 1, num_tasks))

    fig, axes = plt.subplots(figsize=(12, 6))

    if publication_style:
        ts, lw, ms = 60, 15, 10 # Text fontsize, line width, marker size.
    else:
        ts, lw, ms = 12, 5, 8

    # The default matplotlib setting is usually too high for most plots.
    plt.locator_params(axis='y', nbins=2)
    plt.locator_params(axis='x', nbins=6)

    for i, data in enumerate(data_handlers):
        assert isinstance(data, ToyRegression)

        # In what range to plot the real function values?
        train_range = data.train_x_range
        range_offset = (train_range[1] - train_range[0]) * 0.05
        sample_x, sample_y = data._get_function_vals( \
            x_range=[train_range[0]-range_offset,
                     train_range[1]+range_offset])
        #sample_x, sample_y = data._get_function_vals()

        #plt.plot(sample_x, sample_y, color='k', label='f(x)',
        #         linestyle='dashed', linewidth=.5)
        plt.plot(sample_x, sample_y, color='k', linestyle='dashed',
                 linewidth=lw/7.)

        train_x = data.get_train_inputs().squeeze()
        train_y = data.get_train_outputs().squeeze()

        if i == 0:
            plt.plot(train_x, train_y, 'o', color='k',
                     label='Training Data', markersize=ms)
        else:
            plt.plot(train_x, train_y, 'o', color='k', markersize=ms)

        plt.plot(inputs[i], preds_mean[i], color=colors[i],
                 label='Task %d' % (i+1), lw=lw/3.)

        plt.fill_between(inputs[i], preds_mean[i] + preds_std[i],
                         preds_mean[i] - preds_std[i], color=colors[i],
                         alpha=0.3)
        plt.fill_between(inputs[i], preds_mean[i] + 2.*preds_std[i],
                         preds_mean[i] - 2.*preds_std[i], color=colors[i],
                         alpha=0.2)
        plt.fill_between(inputs[i], preds_mean[i] + 3.*preds_std[i],
                         preds_mean[i] - 3.*preds_std[i], color=colors[i],
                         alpha=0.1)

    if publication_style:
        axes.grid(False)
        axes.set_facecolor('w')
        axes.set_ylim([-2.5, 3])
        axes.axhline(y=axes.get_ylim()[0], color='k', lw=lw)
        axes.axvline(x=axes.get_xlim()[0], color='k', lw=lw)
        if len(data_handlers) == 3:
            plt.yticks([-2, 0, 2], fontsize=ts)
            plt.xticks([-3, 0, 3], fontsize=ts)
        else:
            for tick in axes.yaxis.get_major_ticks():
                tick.label.set_fontsize(ts)
            for tick in axes.xaxis.get_major_ticks():
                tick.label.set_fontsize(ts)
        axes.tick_params(axis='both', length=lw, direction='out',
                         width=lw/2.)
        if config.train_from_scratch:
            plt.title('training from scratch', fontsize=ts, pad=ts)
        elif config.beta == 0:
            plt.title('fine-tuning', fontsize=ts, pad=ts)
        else:
            plt.title('CL with prob. hnet reg.', fontsize=ts, pad=ts)
    else:
        plt.legend()
        plt.title('Predictive distributions', fontsize=ts, pad=ts)
        plt.xlabel('$x$', fontsize=ts)
        plt.ylabel('$y$', fontsize=ts)

    if save_fig:
        plt.savefig(os.path.join(config.out_dir,
                                 'pred_dists_%d' % num_tasks),
                    bbox_inches='tight')

    writer.add_figure('test/pred_dists', plt.gcf(), num_tasks,
                      close=not config.show_plots)
    if config.show_plots:
        utils.repair_canvas_and_show_fig(plt.gcf())
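
# Minimal, standalone sketch (hypothetical helper, dummy data) of the
# 1/2/3-std shading scheme used above: overlapping `fill_between` bands
# with decreasing alpha values visualize the predictive distribution
# around the mean prediction.
def _demo_std_bands():
    x = np.linspace(-1, 1, 100)
    mean = np.sin(3 * x)
    std = 0.1 + 0.2 * np.abs(x) # Heteroscedastic toy uncertainty.
    plt.plot(x, mean, color='b')
    for k, alpha in zip([1., 2., 3.], [0.3, 0.2, 0.1]):
        plt.fill_between(x, mean + k * std, mean - k * std, color='b',
                         alpha=alpha)
    plt.show()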

def generate_1d_tasks(show_plots=True, data_random_seed=42, writer=None,
                      task_set=1):
    """Generate a set of tasks for 1D regression.

    Args:
        show_plots: Visualize the generated datasets.
        data_random_seed: Random seed that should be applied to the
            synthetic data generation.
        writer: Tensorboard writer, in case plots should be logged.
        task_set (int): The set of tasks to be used. All sets are
            hard-coded inside this function.

    Returns:
        (tuple): Tuple containing:

        - **data_handlers**: A data handler for each task (instance of
          class :class:`data.special.regression1d_data.ToyRegression`).
        - **num_tasks**: Number of generated tasks.
    """
    if task_set == 0:
        # Here, we define as benchmark a typical dataset used in the
        # uncertainty literature: the regression task y = x**3 + eps,
        # where eps ~ N(0, 9*I).
        # For instance, see here:
        # https://arxiv.org/pdf/1703.01961.pdf

        # How far outside the regime of the training data do we want to
        # predict samples?
        test_offset = 1.5

        map_funcs = [lambda x: (x**3.)]
        num_tasks = len(map_funcs)
        x_domains = [[-4, 4]]
        std = 3 # 3**2 == 9
        num_train = 20

        # Range of pred. dist. plots.
        test_domains = [[x_domains[0][0]-test_offset,
                         x_domains[0][1]+test_offset]]
    elif task_set == 1:
        #test_offset = 1
        map_funcs = [lambda x: (x + 3.),
                     lambda x: 2. * np.power(x, 2) - 1,
                     lambda x: np.power(x - 3., 3)]
        num_tasks = len(map_funcs)
        x_domains = [[-4, -2], [-1, 1], [2, 4]]
        std = .05
        #test_domains = [[-4.1, -0.5], [-2.5, 2.5], [.5, 4.1]]
        test_domains = [[-4.1, 4.1], [-4.1, 4.1], [-4.1, 4.1]]

        #m = 32 # magnitude
        #s = -.25 * np.pi
        #map_funcs = [lambda x: m*np.pi * np.sin(x + s),
        #             lambda x: m*np.pi * np.sin(x + s),
        #             lambda x: m*np.pi * np.sin(x + s)]
        #x_domains = [[-2*np.pi, -1*np.pi], [-0.5*np.pi, 0.5*np.pi],
        #             [1*np.pi, 2*np.pi]]
        #x_domains_test = [[-2*np.pi, 2*np.pi], [-2*np.pi, 2*np.pi],
        #                  [-2*np.pi, 2*np.pi]]
        #std = 3

        num_train = 20
    elif task_set == 2:
        map_funcs = [lambda x: np.power(x + 3., 3),
                     lambda x: 2. * np.power(x, 2) - 1,
                     lambda x: -np.power(x - 3., 3)]
        num_tasks = len(map_funcs)
        x_domains = [[-4, -2], [-1, 1], [2, 4]]
        std = .1
        test_domains = [[-4.1, 4.1], [-4.1, 4.1], [-4.1, 4.1]]
        num_train = 20
    elif task_set == 3:
        # Same as task set 2, but with less aleatoric uncertainty.
        map_funcs = [lambda x: np.power(x + 3., 3),
                     lambda x: 2. * np.power(x, 2) - 1,
                     lambda x: -np.power(x - 3., 3)]
        num_tasks = len(map_funcs)
        x_domains = [[-4, -2], [-1, 1], [2, 4]]
        std = .05
        test_domains = [[-4.1, 4.1], [-4.1, 4.1], [-4.1, 4.1]]
        num_train = 20
    else:
        raise NotImplementedError('Set of tasks "%d" unknown!' % task_set)

    dhandlers = []
    for i in range(num_tasks):
        print('Generating %d-th task.' % (i))
        #test_inter = [x_domains[i][0] - test_offset,
        #              x_domains[i][1] + test_offset]
        dhandlers.append(ToyRegression(train_inter=x_domains[i],
            num_train=num_train, test_inter=test_domains[i], num_test=50,
            val_inter=x_domains[i], num_val=50,
            map_function=map_funcs[i], std=std, rseed=data_random_seed))

        if writer is not None:
            dhandlers[-1].plot_dataset(show=False)
            writer.add_figure('task_%d/dataset' % i, plt.gcf(),
                              close=not show_plots)
            if show_plots:
                utils.repair_canvas_and_show_fig(plt.gcf())
        elif show_plots:
            dhandlers[-1].plot_dataset()

    return dhandlers, num_tasks
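
# Example usage sketch (hypothetical helper, not invoked anywhere): build
# the default set of three polynomial tasks and print their training
# ranges. This assumes `ToyRegression` exposes the `train_x_range`
# attribute that is also used by the test code in this codebase.
def _demo_generate_1d_tasks():
    dhandlers, num_tasks = generate_1d_tasks(show_plots=False,
                                             data_random_seed=42,
                                             task_set=1)
    for i, dh in enumerate(dhandlers):
        print('Task %d train range: %s' % (i, str(dh.train_x_range)))
    return dhandlers, num_tasks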

def plot_mse(config, writer, num_tasks, current_mse, during_mse=None,
             baselines=None, save_fig=True, summary_label='test/mse'):
    """Produce a scatter plot that shows the current and during-training
    MSE values (and optionally a set of baselines) for each task.

    This visualization helps to understand the impact of forgetting.

    Args:
        config: Command-line arguments.
        writer: Tensorboard summary writer.
        num_tasks: Number of tasks.
        current_mse: Array of MSE values currently achieved.
        during_mse (optional): Array of MSE values achieved right after
            training on the corresponding task.
        baselines (optional): A dictionary of label names mapping onto
            arrays. Can be used to plot additional baselines.
        save_fig: Whether the figure should be saved in the output folder.
        summary_label: Label used for the figure when writing it to
            tensorboard.
    """
    x_vals = np.arange(1, num_tasks+1)

    num_plots = 1
    if during_mse is not None:
        num_plots += 1
    if baselines is not None:
        num_plots += len(baselines.keys())

    colors = utils.get_colorbrewer2_colors(family='Dark2')
    if num_plots > len(colors):
        warn('Changing to automatic color scheme as we don\'t have ' +
             'as many manual colors as plots.')
        colors = cm.rainbow(np.linspace(0, 1, num_plots))

    fig, axes = plt.subplots(figsize=(10, 6))
    plt.title('Current MSE on each task')

    plt.scatter(x_vals, current_mse, color=colors[0],
                label='Current Val MSE')
    if during_mse is not None:
        plt.scatter(x_vals, during_mse, label='During MSE',
                    color=colors[1], marker='*')
    if baselines is not None:
        for i, (label, vals) in enumerate(baselines.items()):
            plt.scatter(x_vals, vals, label=label, color=colors[2+i],
                        marker='x')

    plt.ylabel('MSE')
    plt.xlabel('Task')
    plt.xticks(x_vals)
    plt.legend()

    if save_fig:
        plt.savefig(os.path.join(config.out_dir, 'mse_%d' % num_tasks),
                    bbox_inches='tight')

    writer.add_figure(summary_label, plt.gcf(), num_tasks,
                      close=not config.show_plots)
    if config.show_plots:
        utils.repair_canvas_and_show_fig(plt.gcf())
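
# Usage sketch (hypothetical values; `config` and `writer` come from the
# surrounding training pipeline): compare the final per-task MSEs against
# those measured right after training each task.
#
#     current_mse = np.array([0.12, 0.08, 0.05])
#     during_mse = np.array([0.04, 0.05, 0.05])
#     plot_mse(config, writer, num_tasks=3, current_mse=current_mse,
#              during_mse=during_mse, save_fig=False)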