def run(): """Run the script. Returns: (tuple): Tuple containing: - **final_mse**: Final MSE for each task. - **during_mse**: MSE achieved directly after training on each task. """ script_start = time() mode = 'regression_bbb' config = train_args.parse_cmd_arguments(mode=mode) device, writer, logger = sutils.setup_environment(config, logger_name=mode) train_utils.backup_cli_command(config) ### Create tasks. dhandlers, num_tasks = train_utils.generate_tasks(config, writer) ### Generate networks. use_hnet = not config.mnet_only mnet, hnet = train_utils.generate_gauss_networks( config, logger, dhandlers, device, create_hnet=use_hnet, non_gaussian=config.mean_only) ### Simple struct, that is used to share data among functions. shared = Namespace() shared.experiment_type = mode shared.all_dhandlers = dhandlers # Mean and variance of prior that is used for variational inference. if config.mean_only: # No prior-matching can be performed. shared.prior_mean = None shared.prior_logvar = None shared.prior_std = None else: plogvar = np.log(config.prior_variance) pstd = np.sqrt(config.prior_variance) shared.prior_mean = [torch.zeros(*s).to(device) \ for s in mnet.orig_param_shapes] shared.prior_logvar = [plogvar * torch.ones(*s).to(device) \ for s in mnet.orig_param_shapes] shared.prior_std = [pstd * torch.ones(*s).to(device) \ for s in mnet.orig_param_shapes] # Note, all MSE values are measured on a validation set if given, otherwise # on the training set. All samples in the validation set are expected to # lay inside the training range. Test samples may lay outside the training # range. # The MSE value achieved right after training on the corresponding task. shared.during_mse = np.ones(num_tasks) * -1. # The weights of the main network right after training on that task # (can be used to assess how close the final weights are to the original # ones). Note, weights refer to mean and variances (e.g., the output of the # hypernetwork). shared.during_weights = [-1] * num_tasks # MSE achieved after most recent call of test method. shared.current_mse = np.ones(num_tasks) * -1. # Where to save network checkpoints? shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints') # Note, some main networks have stuff to store such as batch statistics for # batch norm. So it is wise to always checkpoint mnets as well! shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d') shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d') ### Initialize the performance measures, that should be tracked during ### training. train_utils.setup_summary_dict(config, shared, 'bbb', num_tasks, mnet, hnet=hnet) # Add hparams to tensorboard, such that the identification of runs is # easier. writer.add_hparams(hparam_dict={ **vars(config), **{ 'num_weights_main': shared.summary['aa_num_weights_main'], 'num_weights_hyper': shared.summary['aa_num_weights_hyper'], 'num_weights_ratio': shared.summary['aa_num_weights_ratio'], } }, metric_dict={}) ### Train on tasks sequentially. for i in range(num_tasks): logger.info('### Training on task %d ###' % (i + 1)) data = dhandlers[i] # Train the network. train(i, data, mnet, hnet, device, config, shared, logger, writer) ### Test networks. test(dhandlers[:(i + 1)], mnet, hnet, device, config, shared, logger, writer) if config.train_from_scratch and i < num_tasks - 1: # We have to checkpoint the networks, such that we can reload them # for task inference later during testing. 
pmutils.checkpoint_nets(config, shared, i, mnet, hnet) mnet, hnet = train_utils.generate_gauss_networks( config, logger, dhandlers, device, create_hnet=use_hnet, non_gaussian=config.mean_only) if config.store_final_model: logger.info('Checkpointing final model ...') pmutils.checkpoint_nets(config, shared, num_tasks - 1, mnet, hnet) logger.info('During MSE values after training each task: %s' % \ np.array2string(shared.during_mse, precision=5, separator=',')) logger.info('Final MSE values after training on all tasks: %s' % \ np.array2string(shared.current_mse, precision=5, separator=',')) logger.info('Final MSE mean %.4f (std %.4f).' % (shared.current_mse.mean(), shared.current_mse.std())) ### Write final summary. shared.summary['finished'] = 1 train_utils.save_summary_dict(config, shared) writer.close() logger.info('Program finished successfully in %f sec.' % (time() - script_start)) return shared.current_mse, shared.during_mse
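
### Illustrative sketch (not part of the original script): how the factorized
### Gaussian prior N(0, prior_variance * I) used above is expanded into
### per-parameter tensors for a list of parameter shapes. The helper name and
### the placeholder shapes are assumptions; in `run()` the shapes come from
### `mnet.orig_param_shapes`. Relies on the module-level `np`/`torch` imports.
def _example_gaussian_prior(param_shapes, prior_variance, device='cpu'):
    """Return per-parameter prior means, log-variances and std deviations.

    Example (hypothetical shapes of a tiny two-layer network):
        means, logvars, stds = _example_gaussian_prior([[5, 3], [5]], 1.0)
    """
    plogvar = np.log(prior_variance)
    pstd = np.sqrt(prior_variance)
    prior_mean = [torch.zeros(*s).to(device) for s in param_shapes]
    prior_logvar = [plogvar * torch.ones(*s).to(device) for s in param_shapes]
    prior_std = [pstd * torch.ones(*s).to(device) for s in param_shapes]
    return prior_mean, prior_logvar, prior_std
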
def test(data_handlers, mnet, hnet, device, config, shared, logger, writer,
         hhnet=None, save_fig=True):
    """Test the performance on all tasks.

    Tasks are assumed to be regression tasks.

    Args:
        (....): See docstring of method :func:`probabilistic.train_vi.train`.
        data_handlers: A list of data handlers, each representing a task.
        save_fig: Whether the figures should be saved in the output folder.
    """
    logger.info('### Testing all trained tasks ... ###')

    if hasattr(config, 'mean_only') and config.mean_only:
        warn('Task inference calculated in the test method doesn\'t make ' +
             'any sense, as the deterministic main network has no notion ' +
             'of uncertainty.')

    pcutils.set_train_mode(False, mnet, hnet, hhnet, None)

    n = len(data_handlers)

    disable_lrt_test = config.disable_lrt_test if \
        hasattr(config, 'disable_lrt_test') else None

    if not hasattr(shared, 'current_mse'):
        shared.current_mse = np.ones(n) * -1.
    elif shared.current_mse.size < n:
        tmp = shared.current_mse
        shared.current_mse = np.ones(n) * -1.
        shared.current_mse[:tmp.size] = tmp

    # Current MSE value on the test set.
    test_mse = np.ones(n) * -1.
    # Current MSE value using the mean prediction of the inferred embedding.
    inferred_val_mse = np.ones(n) * -1.
    # Task inference accuracies.
    task_infer_val_accs = np.ones(n) * -1.

    with torch.no_grad():
        # We need to keep data for plotting results on all tasks later on.
        val_inputs = []
        val_targets = []
        # Needed to compute MSE values.
        val_preds_mean = []
        val_preds_std = []

        # Which uncertainties have been measured per sample and task. The
        # argmin over all tasks (i.e., the lowest uncertainty) gives the
        # predicted task.
        val_task_preds = []

        test_inputs = []
        test_preds_mean = []
        test_preds_std = []

        normal_post = None
        if 'ewc' in shared.experiment_type:
            assert hnet is None
            normal_post = ewcutil.build_ewc_posterior(data_handlers, mnet,
                device, config, shared, logger, writer, n, task_id=n-1)

        if config.train_from_scratch and n > 1:
            # We need to iterate over different networks when we want to
            # measure the uncertainty of dataset i on task j.
            # Note, we will always load the corresponding checkpoint of task j
            # before using these networks.
            if 'avb' in shared.experiment_type \
                    or 'ssge' in shared.experiment_type \
                    or 'ewc' in shared.experiment_type:
                mnet_other, hnet_other, hhnet_other, _ = \
                    pcutils.generate_networks(config, shared, logger,
                        shared.all_dhandlers, device, create_dis=False)
            else:
                assert hhnet is None
                hhnet_other = None
                non_gaussian = config.mean_only \
                    if hasattr(config, 'mean_only') else True
                mnet_other, hnet_other = train_utils.generate_gauss_networks( \
                    config, logger, shared.all_dhandlers, device,
                    create_hnet=hnet is not None, non_gaussian=non_gaussian)

            pcutils.set_train_mode(False, mnet_other, hnet_other, hhnet_other,
                                   None)

        task_n_mnet = mnet
        task_n_hnet = hnet
        task_n_hhnet = hhnet
        task_n_normal_post = normal_post

        # This renaming is just a safeguard, ensuring that none of these
        # networks (`mnet`, `hnet`, `hhnet`) is accidentally used inside the
        # loop below when training from scratch.
        if config.train_from_scratch:
            mnet = None
            hnet = None
            hhnet = None
            normal_post = None

        ### For each data set (i.e., for each task).
        for i in range(n):
            data = data_handlers[i]

            ### We want to measure MSE values within the training range only!
            split_type = 'val'
            num_val_samples = data.num_val_samples
            if num_val_samples == 0:
                split_type = 'train'
                num_val_samples = data.num_train_samples

                logger.debug('Test: Task %d - Using training set as no ' % i +
                             'validation set is available.')

            ### Task inference.
            # We need to iterate over each task embedding and measure the
            # predictive uncertainty in order to decide which embedding to use.
            data_preds = np.empty((num_val_samples, config.val_sample_size, n))
            data_preds_mean = np.empty((num_val_samples, n))
            data_preds_std = np.empty((num_val_samples, n))

            for j in range(n):
                ckpt_score_j = None
                if config.train_from_scratch and j == (n - 1):
                    # Note, the networks trained on dataset (n-1) haven't been
                    # checkpointed yet.
                    mnet_j = task_n_mnet
                    hnet_j = task_n_hnet
                    hhnet_j = task_n_hhnet
                    normal_post_j = task_n_normal_post
                elif config.train_from_scratch:
                    ckpt_score_j = pmutils.load_networks(shared, j, device,
                        logger, mnet_other, hnet_other, hhnet=hhnet_other,
                        dis=None)
                    mnet_j = mnet_other
                    hnet_j = hnet_other
                    hhnet_j = hhnet_other

                    if 'ewc' in shared.experiment_type:
                        normal_post_j = ewcutil.build_ewc_posterior( \
                            data_handlers, mnet_j, device, config, shared,
                            logger, writer, n, task_id=j)
                else:
                    mnet_j = mnet
                    hnet_j = hnet
                    hhnet_j = hhnet
                    normal_post_j = normal_post

                mse_val, val_struct = train_utils.compute_mse(j, data, mnet_j,
                    hnet_j, device, config, shared, hhnet=hhnet_j,
                    split_type=split_type, return_dataset=i == j,
                    return_predictions=True, disable_lrt=disable_lrt_test,
                    normal_post=normal_post_j)

                if i == j:  # I.e., we used the correct embedding.
                    # This sanity check is likely to fail as we don't
                    # deterministically sample the models.
                    #if ckpt_score_j is not None:
                    #    assert np.allclose(-mse_val, ckpt_score_j)

                    val_inputs.append(val_struct.inputs)
                    val_targets.append(val_struct.targets)
                    val_preds_mean.append(val_struct.predictions.mean(axis=1))
                    val_preds_std.append(val_struct.predictions.std(axis=1))

                    shared.current_mse[i] = mse_val

                    logger.debug('Test: Task %d - Mean MSE on %s set: %f '
                                 % (i, split_type, mse_val) +
                                 '(std: %g).' % (val_struct.mse_vals.std()))
                    writer.add_scalar('test/task_%d/val_mse' % i,
                                      shared.current_mse[i], n)

                    # The test set spans into the OOD range and can be used to
                    # visualize how uncertainty behaves outside the
                    # in-distribution range.
                    mse_test, test_struct = train_utils.compute_mse(i, data,
                        mnet_j, hnet_j, device, config, shared, hhnet=hhnet_j,
                        split_type='test', return_dataset=True,
                        return_predictions=True, disable_lrt=disable_lrt_test,
                        normal_post=normal_post_j)

                data_preds[:, :, j] = val_struct.predictions
                data_preds_mean[:, j] = val_struct.predictions.mean(axis=1)
                ### We interpret this value as the uncertainty of the
                ### prediction, i.e., how uncertain the system is that a given
                ### sample belongs to task j (lower std = more certain).
                data_preds_std[:, j] = val_struct.predictions.std(axis=1)

            val_task_preds.append(data_preds_std)

            ### Compute task inference accuracy.
            inferred_task_ids = data_preds_std.argmin(axis=1)
            num_correct = np.sum(inferred_task_ids == i)
            accuracy = 100. * num_correct / num_val_samples
            task_infer_val_accs[i] = accuracy

            logger.debug('Test: Task %d - Accuracy of task inference ' % i +
                         'on %s set: %.2f%%.' % (split_type, accuracy))
            writer.add_scalar('test/task_%d/accuracy' % i, accuracy, n)

            ### Compute MSE based on the inferred embedding.
            # Note, this (commented) way of computing the mean does not take
            # into account the variance of the predictive distribution, which
            # is why we don't use it (see docstring of `compute_mse`).
            #means_of_inferred_preds = data_preds_mean[np.arange( \
            #    data_preds_mean.shape[0]), inferred_task_ids]
            #inferred_val_mse[i] = np.power(means_of_inferred_preds -
            #    val_targets[-1].squeeze(), 2).mean()

            inferred_preds = data_preds[np.arange(data_preds.shape[0]), :,
                                        inferred_task_ids]
            inferred_val_mse[i] = np.power(inferred_preds - \
                val_targets[-1].squeeze()[:, np.newaxis], 2).mean()

            logger.debug('Test: Task %d - Mean MSE on %s set using inferred '
                         % (i, split_type) +
                         'embeddings: %f.' % (inferred_val_mse[i]))
            writer.add_scalar('test/task_%d/inferred_val_mse' % i,
                              inferred_val_mse[i], n)

            ### We are interested in the predictive uncertainty across the
            ### whole test range!
            test_mse[i] = mse_test
            writer.add_scalar('test/task_%d/test_mse' % i, test_mse[i], n)

            test_inputs.append(test_struct.inputs.squeeze())
            test_preds_mean.append(test_struct.predictions.mean(axis=1). \
                                   squeeze())
            test_preds_std.append(test_struct.predictions.std(axis=1). \
                                  squeeze())

            if hasattr(shared, 'during_mse') and shared.during_mse[i] == -1:
                shared.during_mse[i] = shared.current_mse[i]

            if test_struct.w_hnet is not None or \
                    test_struct.w_mean is not None:
                assert hasattr(shared, 'during_weights')

                if test_struct.w_hnet is not None:
                    # We have a hyper-hypernetwork. In this case, the CL
                    # regularizer is applied to its output and therefore, these
                    # are the weights whose Euclidean distance to the stored
                    # during weights we want to track.
                    assert task_n_hhnet is not None
                    w_all = test_struct.w_hnet
                else:
                    assert test_struct.w_mean is not None
                    # We will be here whenever the hnet is deterministic (i.e.,
                    # doesn't represent an implicit distribution).
                    w_all = list(test_struct.w_mean)
                    if test_struct.w_std is not None:
                        w_all += list(test_struct.w_std)

                W_curr = torch.cat([d.clone().view(-1) for d in w_all])
                if type(shared.during_weights[i]) == int:
                    assert shared.during_weights[i] == -1
                    shared.during_weights[i] = W_curr
                else:
                    W_during = shared.during_weights[i]
                    W_dis = torch.norm(W_curr - W_during, 2)
                    logger.info('Euclidean distance between current and ' +
                                'during hypernet output for task %d: %g'
                                % (i, W_dis))

        ### Compute overall task inference accuracy.
        num_correct = 0
        num_samples = 0
        for i, uncertainties in enumerate(val_task_preds):
            pred_task_ids = uncertainties.argmin(axis=1)
            num_correct += np.sum(pred_task_ids == i)
            num_samples += pred_task_ids.size

        accuracy = 100. * num_correct / num_samples
        logger.info('Task inference accuracy: %.2f%%.' % accuracy)

        # TODO Compute overall MSE on all tasks using inferred embeddings.

    ### Plot the mean predictions on all tasks.
    # (Using the validation set and the correct embedding per dataset.)
    plot_x_ranges = []
    for i in range(n):
        plot_x_ranges.append(data_handlers[i].train_x_range)

    fig_fn = None
    if save_fig:
        fig_fn = os.path.join(config.out_dir, 'val_predictions_%d' % n)

    data_inputs = val_inputs
    mean_preds = val_preds_mean

    data_handlers[0].plot_datasets(data_handlers, data_inputs, mean_preds,
        fun_xranges=plot_x_ranges, filename=fig_fn, show=False,
        publication_style=config.publication_style)
    writer.add_figure('test/val_predictions', plt.gcf(), n,
                      close=not config.show_plots)
    if config.show_plots:
        utils.repair_canvas_and_show_fig(plt.gcf())

    ### Scatter plot showing MSE per task (original + current one).
    during_mse = None
    if hasattr(shared, 'during_mse'):
        during_mse = shared.during_mse[:n]
    train_utils.plot_mse(config, writer, n, shared.current_mse[:n],
                         during_mse, save_fig=save_fig)
    additional_plots = {
        'Current Inferred Val MSE': inferred_val_mse,
        #'Current Test MSE': test_mse
    }
    train_utils.plot_mse(config, writer, n, shared.current_mse[:n],
                         during_mse, baselines=additional_plots,
                         save_fig=False, summary_label='test/mse_detailed')

    ### Plot predictive distributions over the test range for all tasks.
    data_inputs = test_inputs
    mean_preds = test_preds_mean
    std_preds = test_preds_std

    train_utils.plot_predictive_distributions(config, writer, data_handlers,
        data_inputs, mean_preds, std_preds, save_fig=save_fig,
        publication_style=config.publication_style)

    logger.info('Mean task MSE: %f (std: %f)' % (shared.current_mse[:n].mean(),
                                                 shared.current_mse[:n].std()))

    ### Update performance summary.
    s = shared.summary
    s['aa_mse_during'][:n] = shared.during_mse[:n].tolist()
    s['aa_mse_during_mean'] = shared.during_mse[:n].mean()
    s['aa_mse_final'][:n] = shared.current_mse[:n].tolist()
    s['aa_mse_final_mean'] = shared.current_mse[:n].mean()

    s['aa_task_inference'][:n] = task_infer_val_accs.tolist()
    s['aa_task_inference_mean'] = task_infer_val_accs.mean()

    s['aa_mse_during_inferred'][n - 1] = inferred_val_mse[n - 1]
    s['aa_mse_during_inferred_mean'] = np.mean(s['aa_mse_during_inferred'][:n])
    s['aa_mse_final_inferred'] = inferred_val_mse[:n].tolist()
    s['aa_mse_final_inferred_mean'] = inferred_val_mse[:n].mean()

    train_utils.save_summary_dict(config, shared)

    logger.info('### Testing all trained tasks ... Done ###')
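
### Toy illustration (shapes only, no real predictions; not part of the
### original script) of the task-inference rule used in `test()` above: for
### every validation sample, the task embedding whose predictive distribution
### has the smallest standard deviation (i.e., the most certain one) is
### selected, and the accuracy is the fraction of samples for which this
### matches the true task. The helper name is hypothetical; relies on the
### module-level `np` import.
def _example_task_inference(preds_std, true_task_id):
    """Infer task identities from per-task predictive uncertainties.

    Args:
        preds_std: Array of shape ``[num_samples, num_tasks]`` holding the
            predictive std of every sample under every task embedding.
        true_task_id: The task the samples actually belong to.

    Returns:
        Tuple ``(inferred_task_ids, accuracy_in_percent)``.
    """
    inferred_task_ids = preds_std.argmin(axis=1)
    accuracy = 100. * np.mean(inferred_task_ids == true_task_id)
    return inferred_task_ids, accuracy
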
def run(method='avb'):
    """Run the script.

    Args:
        method (str, optional): The VI algorithm. Possible values are:

            - ``'avb'``
            - ``'ssge'``

    Returns:
        (tuple): Tuple containing:

        - **final_mse**: Final MSE for each task.
        - **during_mse**: MSE achieved directly after training on each task.
    """
    script_start = time()

    mode = 'regression_' + method
    use_dis = False  # Whether a discriminator network is used.
    if method == 'avb':
        use_dis = True
    config = train_args.parse_cmd_arguments(mode=mode)

    device, writer, logger = sutils.setup_environment(config,
        logger_name=mode)

    train_utils.backup_cli_command(config)

    if config.prior_focused:
        logger.info('Running a prior-focused CL experiment ...')
    else:
        logger.info('Learning task-specific posteriors sequentially ...')

    ### Create tasks.
    dhandlers, num_tasks = train_utils.generate_tasks(config, writer)

    ### Simple struct that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = mode
    shared.all_dhandlers = dhandlers
    shared.prior_focused = config.prior_focused

    ### Generate networks and environment.
    mnet, hnet, hhnet, dnet = pcu.generate_networks(config, shared, logger,
        dhandlers, device, create_dis=use_dis)

    # Mean and variance of the prior that is used for variational inference.
    # For prior-focused training, this prior is only used for the first task.
    pstd = np.sqrt(config.prior_variance)
    shared.prior_mean = [torch.zeros(*s).to(device) \
                         for s in mnet.param_shapes]
    shared.prior_std = [pstd * torch.ones(*s).to(device) \
                        for s in mnet.param_shapes]

    # Note, all MSE values are measured on a validation set if given, otherwise
    # on the training set. All samples in the validation set are expected to
    # lie inside the training range. Test samples may lie outside the training
    # range.

    # The MSE value achieved right after training on the corresponding task.
    shared.during_mse = np.ones(num_tasks) * -1.
    # The weights of the main network right after training on that task
    # (can be used to assess how close the final weights are to the original
    # ones). Note, weights refer to means and variances (e.g., the output of
    # the hypernetwork).
    shared.during_weights = [-1] * num_tasks
    # MSE achieved after the most recent call of the test method.
    shared.current_mse = np.ones(num_tasks) * -1.

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    # Note, some networks have state to store, such as batch statistics for
    # batch norm. So it is wise to always checkpoint all networks, even if
    # they were constructed without weights.
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d')
    shared.ckpt_hhnet_fn = os.path.join(shared.ckpt_dir, 'hhnet_task_%d')
    #shared.ckpt_dis_fn = os.path.join(shared.ckpt_dir, 'dis_task_%d')

    ### Initialize the performance measures that should be tracked during
    ### training.
    train_utils.setup_summary_dict(config, shared, method, num_tasks, mnet,
        hnet=hnet, hhnet=hhnet, dis=dnet)
    logger.info('Ratio num hnet weights / num mnet weights: %f.'
                % shared.summary['aa_num_weights_hm_ratio'])
    if hhnet is not None:
        logger.info('Ratio num hyper-hnet weights / num mnet weights: %f.'
                    % shared.summary['aa_num_weights_hhm_ratio'])
    if mode == 'regression_avb' and dnet is not None:
        # A discriminator only exists for AVB.
        logger.info('Ratio num dis weights / num mnet weights: %f.'
                    % shared.summary['aa_num_weights_dm_ratio'])

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    hparams_extra_dict = {
        'num_weights_hm_ratio': shared.summary['aa_num_weights_hm_ratio'],
        'num_weights_hhm_ratio': shared.summary['aa_num_weights_hhm_ratio']
    }
    if mode == 'regression_avb':
        hparams_extra_dict['num_weights_dm_ratio'] = \
            shared.summary['aa_num_weights_dm_ratio']
    writer.add_hparams(hparam_dict={**vars(config), **hparams_extra_dict},
                       metric_dict={})

    ### Train on tasks sequentially.
    for i in range(num_tasks):
        logger.info('### Training on task %d ###' % (i + 1))
        data = dhandlers[i]

        # Train the network.
        tvi.train(i, data, mnet, hnet, hhnet, dnet, device, config, shared,
                  logger, writer, method=method)

        # Test networks.
        train_bbb.test(dhandlers[:(i + 1)], mnet, hnet, device, config,
                       shared, logger, writer, hhnet=hhnet)

        if config.train_from_scratch and i < num_tasks - 1:
            # We have to checkpoint the networks, such that we can reload them
            # for task inference later during testing.
            # Note, we only need the discriminator as a helper for training,
            # so we don't checkpoint it.
            pmu.checkpoint_nets(config, shared, i, mnet, hnet, hhnet=hhnet,
                                dis=None)

            mnet, hnet, hhnet, dnet = pcu.generate_networks(config, shared,
                logger, dhandlers, device)
        elif config.store_during_models:
            logger.info('Checkpointing current model ...')
            pmu.checkpoint_nets(config, shared, i, mnet, hnet, hhnet=hhnet,
                                dis=None)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmu.checkpoint_nets(config, shared, num_tasks - 1, mnet, hnet,
                            hhnet=hhnet, dis=None)

    logger.info('During MSE values after training each task: %s' % \
        np.array2string(shared.during_mse, precision=5, separator=','))
    logger.info('Final MSE values after training on all tasks: %s' % \
        np.array2string(shared.current_mse, precision=5, separator=','))
    logger.info('Final MSE mean %.4f (std %.4f).' % (shared.current_mse.mean(),
        shared.current_mse.std()))

    ### Write final summary.
    shared.summary['finished'] = 1
    train_utils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))

    return shared.current_mse, shared.during_mse
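
### Small sketch (illustrative only, not part of the original script) of the
### weight-ratio diagnostics logged above, e.g. "num hnet weights / num mnet
### weights". The two toy modules below merely stand in for the actual main
### network and hypernetwork; relies on the module-level `torch` import.
def _example_weight_ratio():
    """Compute the ratio of hypernetwork to main-network parameters."""
    mnet_toy = torch.nn.Linear(10, 10)    # Toy "main network": 10*10+10 weights.
    hnet_toy = torch.nn.Linear(10, 110)   # Toy "hypernetwork" emitting those 110 weights.
    num_mnet = sum(p.numel() for p in mnet_toy.parameters())
    num_hnet = sum(p.numel() for p in hnet_toy.parameters())
    return num_hnet / num_mnet
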
def run(config, experiment='regression_ewc'):
    """Run the EWC training for the given experiment.

    Args:
        config (argparse.Namespace): Command-line arguments.
        experiment (str): Which kind of experiment should be performed?

            - ``'regression_ewc'``: Regression tasks with EWC
            - ``'gmm_ewc'``: GMM Data with EWC
            - ``'split_mnist_ewc'``: SplitMNIST with EWC
            - ``'perm_mnist_ewc'``: PermutedMNIST with EWC
            - ``'cifar_resnet_ewc'``: CIFAR-10/100 with EWC
    """
    assert experiment in ['regression_ewc', 'gmm_ewc', 'split_mnist_ewc',
                          'perm_mnist_ewc', 'cifar_resnet_ewc']

    script_start = time()

    device, writer, logger = sutils.setup_environment(config,
        logger_name=experiment)

    rutils.backup_cli_command(config)

    is_classification = True
    if 'regression' in experiment:
        is_classification = False

    ### Create tasks.
    if is_classification:
        dhandlers = pmutils.load_datasets(config, logger, experiment, writer)
    else:
        dhandlers, num_tasks = rutils.generate_tasks(config, writer)
        config.num_tasks = num_tasks

    ### Simple struct that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.num_trained = 0

    ### Generate network.
    mnet, _, _, _ = pcutils.generate_networks(config, shared, logger,
        dhandlers, device, create_mnet=True, create_hnet=False,
        create_hhnet=False, create_dis=False)

    if not is_classification:
        shared.during_mse = np.ones(config.num_tasks) * -1.
        # MSE achieved after the most recent call of the test method.
        shared.current_mse = np.ones(config.num_tasks) * -1.
        # The weights of the main network right after training on that task
        # (can be used to assess how close the final weights are to the
        # original ones).
        shared.during_weights = [-1] * config.num_tasks

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')

    # Initialize the softmax temperature per task with one. Might be changed
    # later on to calibrate the temperature.
    if is_classification:
        shared.softmax_temp = [torch.ones(1).to(device) \
                               for _ in range(config.num_tasks)]

    ### Initialize summary.
    if is_classification:
        pcutils.setup_summary_dict(config, shared, experiment, mnet)
    else:
        rutils.setup_summary_dict(config, shared, 'ewc', config.num_tasks,
                                  mnet)

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    writer.add_hparams(hparam_dict={**vars(config), **{
        'num_weights_main': shared.summary['num_weights_main'] \
            if is_classification else shared.summary['aa_num_weights_main']
    }}, metric_dict={})

    if is_classification:
        during_acc_criterion = pmutils.parse_performance_criterion(config,
            shared, logger)

    ### Train on tasks sequentially.
    for i in range(config.num_tasks):
        logger.info('### Training on task %d ###' % (i+1))
        data = dhandlers[i]

        # Train the network.
        shared.num_trained += 1
        train(i, data, mnet, device, config, shared, logger, writer)

        ### Test networks.
        test_ids = None
        if hasattr(config, 'full_test_interval') and \
                config.full_test_interval != -1:
            if i == config.num_tasks-1 or \
                    (i > 0 and i % config.full_test_interval == 0):
                test_ids = None  # Test on all tasks.
            else:
                test_ids = [i]  # Only test on current task.
        test(dhandlers[:(i+1)], mnet, device, config, shared, logger, writer,
             test_ids=test_ids)

        ### Check whether the last task got an "acceptable" accuracy ###
        if is_classification:
            curr_dur_acc = shared.summary['acc_task_given_during'][i]
        if is_classification and i < config.num_tasks-1 \
                and during_acc_criterion[i] != -1 \
                and during_acc_criterion[i] > curr_dur_acc:
            logger.error('During accuracy of task %d too small (%f < %f).'
                         % (i+1, curr_dur_acc, during_acc_criterion[i]))
            logger.error('Training of future tasks will be skipped.')
            writer.close()
            exit(1)

        if config.train_from_scratch and i < config.num_tasks-1:
            # We have to checkpoint the networks, such that we can reload them
            # for task inference later during testing.
            pmutils.checkpoint_nets(config, shared, i, mnet, None)

            mnet, _, _, _ = pcutils.generate_networks(config, shared, logger,
                dhandlers, device, create_mnet=True, create_hnet=False,
                create_hhnet=False, create_dis=False)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, config.num_tasks-1, mnet, None)

    ### Plot final classification scores.
    if is_classification:
        logger.info('During accuracies (task identity given): ' + \
            '%s (avg: %.2f%%).' % \
            (np.array2string(np.array( \
                shared.summary['acc_task_given_during']),
                precision=2, separator=','),
             shared.summary['acc_avg_task_given_during']))
        logger.info('Final accuracies (task identity given): ' + \
            '%s (avg: %.2f%%).' % \
            (np.array2string(np.array(shared.summary['acc_task_given']),
                             precision=2, separator=','),
             shared.summary['acc_avg_task_given']))

    if is_classification and config.cl_scenario == 3 and config.split_head_cl3:
        logger.info('During accuracies (task identity inferred): ' +
            '%s (avg: %.2f%%).' % \
            (np.array2string(np.array( \
                shared.summary['acc_task_inferred_ent_during']),
                precision=2, separator=','),
             shared.summary['acc_avg_task_inferred_ent_during']))
        logger.info('Final accuracies (task identity inferred): ' +
            '%s (avg: %.2f%%).' % \
            (np.array2string(np.array( \
                shared.summary['acc_task_inferred_ent']),
                precision=2, separator=','),
             shared.summary['acc_avg_task_inferred_ent']))

    ### Plot final regression scores.
    if not is_classification:
        logger.info('During MSE values after training each task: %s' % \
            np.array2string(np.array(shared.summary['aa_mse_during']),
                            precision=5, separator=','))
        logger.info('Final MSE values after training on all tasks: %s' % \
            np.array2string(np.array(shared.summary['aa_mse_final']),
                            precision=5, separator=','))
        logger.info('Final MSE mean %.4f.' % \
            (shared.summary['aa_mse_final_mean']))

    ### Write final summary.
    shared.summary['finished'] = 1
    if is_classification:
        pmutils.save_summary_dict(config, shared, experiment)
    else:
        rutils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))
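
### Minimal sketch (illustrative, not part of the original script) of the
### "during accuracy criterion" check used in the classification branch above:
### a criterion of -1 means "no requirement"; otherwise training is aborted as
### soon as the accuracy achieved right after a task falls below the
### corresponding threshold. The helper name is hypothetical.
def _example_should_abort(during_accs, criteria):
    """Return the index of the first task violating its criterion, else None.

    Args:
        during_accs: Accuracies (in percent) achieved right after each task.
        criteria: Per-task minimum accuracies; -1 disables the check.
    """
    for i, (acc, crit) in enumerate(zip(during_accs, criteria)):
        if crit != -1 and crit > acc:
            return i
    return None
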
def run(config, experiment='regression_mt'):
    """Run the multitask training for the given experiment.

    Args:
        config (argparse.Namespace): Command-line arguments.
        experiment (str): Which kind of experiment should be performed?

            - ``'regression_mt'``: Regression tasks with multitask training
            - ``'gmm_mt'``: GMM Data with multitask training
            - ``'split_mnist_mt'``: SplitMNIST with multitask training
            - ``'perm_mnist_mt'``: PermutedMNIST with multitask training
            - ``'cifar_resnet_mt'``: CIFAR-10/100 with multitask training
    """
    assert experiment in ['regression_mt', 'gmm_mt', 'split_mnist_mt',
                          'perm_mnist_mt', 'cifar_resnet_mt']

    script_start = time()

    device, writer, logger = sutils.setup_environment(config,
        logger_name=experiment)

    rutils.backup_cli_command(config)

    is_classification = True
    if 'regression' in experiment:
        is_classification = False

    ### Create tasks.
    if is_classification:
        dhandlers = pmutils.load_datasets(config, logger, experiment, writer)
    else:
        dhandlers, num_tasks = rutils.generate_tasks(config, writer)
        config.num_tasks = num_tasks

    ### Simple struct that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.num_trained = 0

    ### Generate network.
    mnet, _, _, _ = pcutils.generate_networks(config, shared, logger,
        dhandlers, device, create_mnet=True, create_hnet=False,
        create_hhnet=False, create_dis=False)

    if not is_classification:
        shared.during_mse = np.ones(config.num_tasks) * -1.
        # MSE achieved after the most recent call of the test method.
        shared.current_mse = np.ones(config.num_tasks) * -1.

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')

    # Initialize the softmax temperature per task with one. Might be changed
    # later on to calibrate the temperature.
    if is_classification:
        shared.softmax_temp = [torch.ones(1).to(device) \
                               for _ in range(config.num_tasks)]

    ### Initialize summary.
    if is_classification:
        pcutils.setup_summary_dict(config, shared, experiment, mnet)
    else:
        rutils.setup_summary_dict(config, shared, 'mt', config.num_tasks, mnet)

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    writer.add_hparams(hparam_dict={**vars(config), **{
        'num_weights_main': shared.summary['num_weights_main'] \
            if is_classification else shared.summary['aa_num_weights_main']
    }}, metric_dict={})

    ### Train on all tasks.
    logger.info('### Training ###')
    # Note, since we are training on all tasks, all output heads can at all
    # times be considered as trained!
    shared.num_trained = config.num_tasks
    train(dhandlers, mnet, device, config, shared, logger, writer)

    logger.info('### Testing ###')
    test(dhandlers, mnet, device, config, shared, logger, writer,
         test_ids=None)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, config.num_tasks - 1, mnet,
                                None)

    ### Plot final classification scores.
    if is_classification:
        logger.info('Final accuracies (task identity given): ' + \
            '%s (avg: %.2f%%).' % \
            (np.array2string(np.array(shared.summary['acc_task_given']),
                             precision=2, separator=','),
             shared.summary['acc_avg_task_given']))

    ### Plot final regression scores.
    if not is_classification:
        logger.info('Final MSE values after training on all tasks: %s' % \
            np.array2string(np.array(shared.summary['aa_mse_final']),
                            precision=5, separator=','))
        logger.info('Final MSE mean %.4f.' % \
            (shared.summary['aa_mse_final_mean']))

    ### Write final summary.
    shared.summary['finished'] = 1
    if is_classification:
        pmutils.save_summary_dict(config, shared, experiment)
    else:
        rutils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))
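
### Conceptual sketch (an assumption for illustration, not the repository's
### `train()`): multitask training draws every update from the union of all
### tasks, which is why `shared.num_trained` is set to `config.num_tasks`
### before training starts, in contrast to the sequential loops above. The toy
### joint-batch sampler below is purely illustrative; relies on the
### module-level `np` import.
def _toy_joint_batch(task_datasets, batch_size, rng):
    """Sample one joint mini-batch across all tasks.

    Args:
        task_datasets: List of ``(inputs, targets)`` numpy arrays, one per
            task.
        batch_size: Number of samples drawn from the union of all tasks.
        rng: A ``numpy.random.Generator``.

    Returns:
        Tuple ``(x, y, task_ids)`` of stacked samples and their task labels.
    """
    xs, ys, tids = [], [], []
    for tid, (x, y) in enumerate(task_datasets):
        xs.append(x)
        ys.append(y)
        tids.append(np.full(len(x), tid))
    x_all = np.concatenate(xs)
    y_all = np.concatenate(ys)
    t_all = np.concatenate(tids)
    idx = rng.choice(len(x_all), size=batch_size, replace=False)
    return x_all[idx], y_all[idx], t_all[idx]
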