def run():
    """Run the sequential SMNIST continual-learning script.

    Parses the command-line configuration, creates the tasks and networks
    (plus optional context masks, EWC loss and replay helpers), trains on all
    tasks through ``sts.train_tasks`` and logs the resulting test accuracies.
    """
    ### Setup.
    config = train_args_seq_smnist.parse_cmd_arguments()
    device, writer, logger = sutils.setup_environment(config)
    data_handlers = ctu._generate_tasks(config, logger)

    # Namespace used to pass miscellaneous information between functions.
    shared = Namespace()
    shared.feature_size = data_handlers[0].in_shape[0]

    if config.show_plots:
        fig_dir = os.path.join(config.out_dir, 'figures')
        if not os.path.exists(fig_dir):
            os.makedirs(fig_dir)

        for task_idx, handler in enumerate(data_handlers):
            # NOTE(review): title and filename say "test", but the samples
            # are drawn from the training set -- confirm this is intended.
            handler.plot_samples(
                'Test Samples - Task %d' % task_idx,
                handler.get_train_inputs()[:8],
                outputs=handler.get_train_outputs()[:8],
                show=True,
                filename=os.path.join(fig_dir,
                                      'test_samples_task_%d.png' % task_idx))

    target_net, hnet, dnet = stu.generate_networks(config, shared,
                                                  data_handlers, device)

    # Binary context masks are only generated when requested.
    ctx_masks = (stu.generate_binary_masks(config, device, target_net)
                 if config.use_masks else None)

    # Target-network weights (excluding potential context-mod weights) are
    # stored after every task to quantify weight changes, e.g. the
    # "stiffness" of EWC.
    shared.tnet_weights = []
    # Context-mod weights (or all weights) produced by the hypernetwork are
    # stored after every task to quantify "forgetting"; the hnet regularizer
    # is meant to keep them fixed.
    shared.hnet_out = []

    # Task-specific loss and accuracy functions.
    task_loss_func = ctu.get_loss_func(config, device, logger, ewc_loss=False)
    accuracy_func = ctu.get_accuracy_func(config)
    if config.use_ewc:
        ewc_loss_func = ctu.get_loss_func(config, device, logger,
                                          ewc_loss=True)
    else:
        ewc_loss_func = None

    if config.use_replay:
        replay_fcts = {
            'rec_loss': ctu.get_vae_rec_loss_func(),
            'distill_loss': ctu.get_distill_loss_func(),
            'soft_trgt_acc': ctu.get_soft_trgt_acc_func(),
        }
    else:
        replay_fcts = None

    # Summary keywords/filename depend on multitask vs. continual learning.
    hpsearch_module = hpsearch_mt if config.multitask else hpsearch_cl
    summary_keywords = hpsearch_module._SUMMARY_KEYWORDS
    summary_filename = hpsearch_module._SUMMARY_FILENAME

    ### Train the classifier task by task; all tasks are tested after each.
    ret, _, _, test_acc = sts.train_tasks(
        data_handlers, target_net, hnet, dnet, device, config, shared,
        logger, writer, ctx_masks, summary_keywords, summary_filename,
        task_loss_func=task_loss_func, accuracy_func=accuracy_func,
        ewc_loss_func=ewc_loss_func, replay_fcts=replay_fcts)

    stu.log_results(test_acc, config, logger)
    writer.close()

    if ret != -1:
        # ``ret + 1`` is reported as the number of completed tasks, so ``ret``
        # is presumably the index of the last finished task.
        logger.error('Only %d tasks have completed training.' % (ret + 1))
        return

    logger.info('Program finished successfully.')
    if config.show_plots:
        plt.show()
def _parse_pca_analysis_args(argv=None):
    """Parse the command-line arguments of the hidden-space analysis.

    Args:
        argv (list, optional): Arguments to parse. If ``None``, argparse reads
            them from ``sys.argv``.

    Returns:
        (argparse.Namespace): The parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description='Studying the dimensionality of the hidden space.')
    # ``nargs='?'`` is needed for the default of a positional argument to
    # take effect; without it argparse treats ``out_dir`` as mandatory.
    parser.add_argument('out_dir', type=str, nargs='?',
                        default='./out/hyperparam_search',
                        help='The output directory of the runs. ' +
                             'Default: %(default)s.')
    parser.add_argument('--redo_analyses', action='store_true',
                        help='If enabled, all analyses will be done even ' +
                             'if some previous results had been stored and ' +
                             'could have been loaded.')
    parser.add_argument('--do_kernel_pca', action='store_true',
                        help='If enabled, kernel PCA will also be used to ' +
                             'compute the number of hidden dimensions.')
    parser.add_argument('--do_supervised_dimred', action='store_true',
                        help='If enabled, supervised dimensionality ' +
                             'reduction will be performed to compute the ' +
                             'number of task-relevant hidden dimensions.')
    # The explained-variance threshold is a fraction, hence ``float``;
    # ``type=int`` rejected fractional values such as ``0.9``.
    parser.add_argument('--p_var', type=float, default=0.75,
                        help='The amount of variance that needs to be ' +
                             'explained to determine the number of ' +
                             'dimensions. Default: %(default)s.')
    parser.add_argument('--n_pcs', type=int, default=10,
                        help='The number of principal components to be ' +
                             'taken into account for the plots and ' +
                             'analyses. Default: %(default)s.')
    parser.add_argument('--n_samples', type=int, default=1000,
                        help='The number of samples to be used for the PCA ' +
                             'analyses in the different settings. This ' +
                             'ensures that comparisons when pulling data ' +
                             'from different tasks are made with an equal ' +
                             'amount of data. Default: %(default)s.')
    parser.add_argument('--timesteps_for_analysis', type=str,
                        default='output',
                        choices=['all', 'input', 'output', 'stop', 'last',
                                 'stop_plus_one'],
                        help='The timesteps to be used for the PCA ' +
                             'analyses. Options are: all timesteps, only ' +
                             'during input presentation, only during ' +
                             'output presentation, only upon stop flag, ' +
                             'only one timestep after the flag or only the ' +
                             'very last output timestep. ' +
                             'Default: %(default)s.')
    parser.add_argument('--num_tasks', type=int, default=-1,
                        help='Number of tasks to consider. Default: ' +
                             '%(default)s.')
    parser.add_argument('--sup_dimred_criterion', type=int, default=-1,
                        help='Accuracy to be obtained when projecting the ' +
                             'hidden activity into a lower-dimensional ' +
                             'subspace in a supervised fashion. For a ' +
                             'value of -1, no criterion is used and so all ' +
                             'possible number of dimensions are explored. ' +
                             'Default: %(default)s.')
    parser.add_argument('--sdr_orth_strength', type=float, default=1e3,
                        help='The strength of the orthogonal regularizer ' +
                             'when doing supervised dimensionality ' +
                             'reduction (option "do_supervised_dimred"). ' +
                             'Default: %(default)s.')
    parser.add_argument('--sdr_lr', type=float, default=1e-2,
                        help='The learning rate when doing supervised ' +
                             'dimensionality reduction ' +
                             '(option "do_supervised_dimred"). ' +
                             'Default: %(default)s.')
    parser.add_argument('--sdr_batch_size', type=int, default=64,
                        help='The batch size when doing supervised ' +
                             'dimensionality reduction ' +
                             '(option "do_supervised_dimred"). ' +
                             'Default: %(default)s.')
    parser.add_argument('--sdr_n_iter', type=int, default=100,
                        help='The number of training iterations per ' +
                             'projection column when doing supervised ' +
                             'dimensionality reduction ' +
                             '(option "do_supervised_dimred"). ' +
                             'Default: %(default)s.')
    return parser.parse_args(argv)


def run(config, *args, dataset='copy'):
    """Run the script.

    Args:
        config: The default config for the current dataset.
        *args: Extra positional arguments, forwarded unchanged to
            ``analyse_single_run``.
        dataset (str, optional): The dataset being analysed.

    Raises:
        ValueError: If ``--num_tasks`` exceeds the number of tasks of the
            analysed runs.
    """
    fig_size = [2.8, 2.55]
    configure_matplotlib_params(fig_size=fig_size)

    cmd_args = _parse_pca_analysis_args()

    # Selecting a subset of timesteps is only supported for the Copy Task
    # (AudioSet additionally allows 'last'). For Sequential SMNIST the last
    # timestep would need to be selected correctly, which is not implemented.
    ts_choice = cmd_args.timesteps_for_analysis
    if (dataset == 'audioset' and ts_choice not in ['all', 'last']) or \
            (dataset in ('seq_smnist', 'student_teacher') and
             ts_choice != 'all'):
        warnings.warn('A subset of timesteps for the analysis can only be ' +
                      'selected for the Copy Task currently. Using all ' +
                      'timesteps.')
        cmd_args.timesteps_for_analysis = 'all'

    copy_task = dataset == 'copy'

    # Directory where the results of all current analyses are stored.
    results_dir = os.path.join(cmd_args.out_dir, 'pca_analyses')
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    # Overwrite the output directory.
    config.out_dir = os.path.join(results_dir, 'sim_files')

    ### Set up environment using some general command line arguments.
    device, writer, logger = sutils.setup_environment(config)

    ### Check whether the directory holds a single run or several runs.
    # Folder naming changed over time; accept both old and new names.
    out_dirs = glob.glob(os.path.join(cmd_args.out_dir, '20*'))
    out_dirs.extend(glob.glob(os.path.join(cmd_args.out_dir, 'sim*')))
    if len(out_dirs) == 0:
        out_dirs = [cmd_args.out_dir]

    # Supervised dimensionality-reduction settings (identical for all runs).
    sdr_args = {
        'lambda_ortho': cmd_args.sdr_orth_strength,
        'lr': cmd_args.sdr_lr,
        'n_iter': cmd_args.sdr_n_iter,
        'batch_size': cmd_args.sdr_batch_size
    }

    ### Store the results of the different runs in a same dictionary.
    results = {}
    common_settings = None
    for out_dir in out_dirs:
        # A fresh copy per run, so that the callee cannot leak changes to
        # the settings into subsequent runs.
        results[out_dir], settings = analyse_single_run(out_dir, device,
            writer, logger, *args, redo_analyses=cmd_args.redo_analyses,
            n_samples=cmd_args.n_samples,
            do_kernel_pca=cmd_args.do_kernel_pca,
            timesteps_for_analysis=cmd_args.timesteps_for_analysis,
            copy_task=copy_task, num_tasks=cmd_args.num_tasks,
            do_supervised_dimred=cmd_args.do_supervised_dimred,
            sup_dimred_criterion=cmd_args.sup_dimred_criterion,
            sup_dimred_args=dict(sdr_args))

        # Ensure all runs have comparable properties.
        if common_settings is None:
            common_settings = settings.copy()
        for key in settings.keys():
            assert settings[key] == common_settings[key]

    num_tasks = common_settings['num_tasks']
    if cmd_args.num_tasks != -1:
        if cmd_args.num_tasks > common_settings['num_tasks']:
            raise ValueError('Option "num_tasks" (%d) exceeds the number of '
                             'tasks of the analysed runs (%d).'
                             % (cmd_args.num_tasks,
                                common_settings['num_tasks']))
        num_tasks = cmd_args.num_tasks

    ### Check if there are identical runs with different random seeds.
    seed_groups, num_seeds, num_runs = group_runs(results)
    print('\nThe analysis was done with %i seeds ' % num_seeds +
          'for each of the %i different runs.' % num_runs)

    ### Pickle the results.
    with open(os.path.join(results_dir, 'results.pickle'), 'wb') as handle:
        pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(results_dir, 'seed_groups.pickle'), 'wb') as \
            handle:
        pickle.dump(seed_groups, handle, protocol=pickle.HIGHEST_PROTOCOL)

    ### Plot.
    for task_id in range(num_tasks):
        if num_runs > 1:
            if copy_task:
                ssp.plot_per_ts(results, seed_groups, task_id=task_id,
                                path=results_dir, key_name='accuracy')
            ssp.plot_per_ts(results, seed_groups, task_id=task_id,
                            path=results_dir, key_name='dimension')
            ssp.plot_per_ts(results, seed_groups, task_id=task_id,
                            path=results_dir, key_name='dimension',
                            internal=False)
            if cmd_args.do_kernel_pca:
                ssp.plot_per_ts(results, seed_groups, key_name='dimension',
                                task_id=task_id, path=results_dir,
                                kernel=True)
                ssp.plot_per_ts(results, seed_groups, key_name='dimension',
                                task_id=task_id, path=results_dir,
                                kernel=True, internal=False)
        # All runs share the same settings (asserted above).
        if copy_task and common_settings['permute_time'] and \
                not common_settings['permute_width'] and \
                not common_settings['permute_xor']:
            ssp.plot_accuracy_vs_bptt_steps(results, seed_groups,
                                            task_id=task_id, path=results_dir)

    # Make multi-task plots.
    if num_tasks > 1:
        ssp.plot_importance_vs_task(results, seed_groups, path=results_dir)
        ssp.plot_dimension_vs_task(results, seed_groups, path=results_dir)
        ssp.plot_dimension_vs_task(results, seed_groups, path=results_dir,
                                   internal=False)
        ssp.plot_dimension_vs_task(results, seed_groups, path=results_dir,
                                   onto_task_1=True)
        ssp.print_dimension_vs_task(results, seed_groups,
                                    p_var=cmd_args.p_var)

        if num_runs == 1:
            ssp.plot_per_ts(results, seed_groups, path=results_dir,
                            key_name='dimension')
            ssp.plot_per_ts(results, seed_groups, path=results_dir,
                            key_name='dimension', internal=False)
            if cmd_args.do_kernel_pca:
                ssp.plot_per_ts(results, seed_groups, key_name='dimension',
                                path=results_dir, kernel=True)
            if copy_task:
                ssp.plot_per_ts(results, seed_groups, path=results_dir,
                                key_name='accuracy')

        if cmd_args.do_supervised_dimred:
            ssp.plot_supervised_dimension_vs_task(results, seed_groups,
                                                  path=results_dir)
            ssp.plot_supervised_dimension_vs_task(results, seed_groups,
                                                  path=results_dir,
                                                  key='accu')
            ssp.plot_supervised_dimension_vs_task(results, seed_groups,
                                                  path=results_dir,
                                                  key='accu', stop_bit=True)

    # Make multi-run plots.
    if num_runs > 1:
        ssp.plot_across_runs(results, seed_groups, path=results_dir,
                             do_kernel_pca=cmd_args.do_kernel_pca,
                             n_pcs=cmd_args.n_pcs, p_var=cmd_args.p_var)
def run(config, experiment='resnet'):
    """Train a model on a sequence of CIFAR tasks and evaluate it.

    Args:
        config (argparse.Namespace): Command-line arguments.
        experiment (str): Which kind of experiment should be performed?

            - ``resnet``: CIFAR-10/100 with Resnet-32.
            - ``zenke``: CIFAR-10/100 with Zenkenet.
    """
    assert experiment in ('resnet', 'zenke')

    start_time = time()

    device, writer, logger = sutils.setup_environment(
        config, logger_name='det_cl_cifar_%s' % experiment)

    # TODO Adapt script to allow checkpointing of models using
    # `utils.torch_ckpts` (i.e., we should be able to continue training or just
    # test an existing checkpoint).
    #config.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')

    # Container for variables shared across functions.
    shared = Namespace()
    shared.experiment = experiment

    ### Load datasets (one data handler per task).
    dhandlers = tutils.load_datasets(config, shared, logger,
                                     data_dir='../datasets')
    print('handlers traninig: ', dhandlers[0].num_train_samples)
    # data,taskcla,inputsize=mixceleba.get(seed=config.seed,args=config)
    # NOTE(review): only ``taskcla``/``inputsize`` are used (printed); the
    # returned data is never consumed -- looks like a debugging leftover.
    _, taskcla, inputsize = mixemnist.get(seed=config.seed, args=config)
    print('Input size =', inputsize, '\nTask info =', taskcla)

    ### Main network (weightless unless only the main net is trained).
    # TODO Allow main net only training.
    mnet = tutils.get_main_model(config, shared, logger, device,
                                 no_weights=not config.mnet_only)

    ### Hypernetwork (absent in main-net-only mode).
    hnet = None if config.mnet_only else \
        tutils.get_hnet_model(config, mnet, logger, device)

    ### Performance measures tracked during training.
    tutils.setup_summary_dict(config, shared, mnet, hnet=hnet)

    # Log the hyperparameters to tensorboard to simplify run identification.
    hparams = dict(vars(config))
    hparams.update({
        'num_weights_main': shared.summary['num_weights_main'],
        'num_weights_hyper': shared.summary['num_weights_hyper'],
        'num_weights_ratio': shared.summary['num_weights_ratio'],
    })
    writer.add_hparams(hparam_dict=hparams, metric_dict={})

    # FIXME: Method "calc_fix_target_reg" expects a None value, but
    # `writer.add_hparams` can't deal with `None` values, hence the late swap.
    if config.cl_reg_batch_size == -1:
        config.cl_reg_batch_size = None

    # Hypernetwork outputs right after training each task (to measure
    # forgetting later on).
    weights_after_training = []

    ### Training.
    for task_id in range(config.num_tasks):
        task_num = task_id + 1
        logger.info('Starting training of task %d ...' % task_num)

        task_data = dhandlers[task_id]

        # Similar tasks may benefit from starting at the previous embedding.
        if hnet is not None and config.init_with_prev_emb and task_id > 0:
            prev_emb = hnet.get_task_emb(task_id - 1).detach().clone()
            hnet.get_task_emb(task_id).data = prev_emb

        # Training from scratch: replace the trainable network, so no
        # knowledge transfer between tasks is possible.
        if task_id > 0 and config.train_from_scratch:
            # FIXME Since we simply override the current network, future
            # testing on this new network for old tasks doesn't make sense.
            # So we shouldn't report `final` accuracies.
            if config.mnet_only:
                logger.info(
                    'From scratch training: Creating new main network.')
                mnet = tutils.get_main_model(config, shared, logger, device,
                                             no_weights=not config.mnet_only)
            else:
                logger.info(
                    'From scratch training: Creating new hypernetwork.')
                hnet = tutils.get_hnet_model(config, mnet, logger, device)

        train(task_id, task_data, mnet, hnet, device, config, shared, writer,
              logger)

        ### Final test run for this task.
        if hnet is not None:
            # Kept on the CPU so GPU memory does not grow with long task
            # sequences.
            weights_after_training.append(
                [w.detach().clone().cpu() for w in hnet.forward(task_id)])

        test_acc, _ = test(task_id, task_data, mnet, hnet, device, shared,
                           config, writer, logger)

        logger.info('### Accuracy of task %d / %d: %.3f' %
                    (task_num, config.num_tasks, test_acc))
        logger.info('### Finished training task: %d' % task_num)
        shared.summary['acc_during'][task_id] = test_acc

        # Backup results so far.
        tutils.save_summary_dict(config, shared, experiment)

    shared.summary['acc_avg_during'] = np.mean(shared.summary['acc_during'])

    logger.info('### Accuracy of individual tasks after training %s' %
                str(shared.summary['acc_during']))
    logger.info('### Average of these accuracies %.2f' %
                shared.summary['acc_avg_during'])
    writer.add_scalar('final/during_acc_avg', shared.summary['acc_avg_during'])

    ### Test continual learning scenarios.
    test_multiple(dhandlers, mnet, hnet, device, config, shared, writer,
                  logger)

    ### Run some analysis (requires a hypernetwork).
    if not config.mnet_only:
        analysis(dhandlers, mnet, hnet, device, config, shared, writer,
                 logger, weights_after_training)

    ### Write final summary.
    shared.summary['finished'] = 1
    tutils.save_summary_dict(config, shared, experiment)

    writer.close()

    logger.info('Program finished successfully in %f sec.' %
                (time() - start_time))
def run(config, experiment='regression_ewc'):
    """Run the EWC training for the given experiment.

    Args:
        config (argparse.Namespace): Command-line arguments.
        experiment (str): Which kind of experiment should be performed?

            - ``'regression_ewc'``: Regression tasks with EWC
            - ``'gmm_ewc'``: GMM Data with EWC
            - ``'split_mnist_ewc'``: SplitMNIST with EWC
            - ``'perm_mnist_ewc'``: PermutedMNIST with EWC
            - ``'cifar_resnet_ewc'``: CIFAR-10/100 with EWC

    Note:
        For classification experiments, the process exits with status 1 if
        a task's "during" accuracy falls below its performance criterion.
    """
    import sys

    assert experiment in ['regression_ewc', 'gmm_ewc', 'split_mnist_ewc',
                          'perm_mnist_ewc', 'cifar_resnet_ewc']

    def _create_mnet():
        """Create a fresh main network (no hypernetworks, no discriminator)."""
        # ``create_hhnet`` is a boolean flag. A hyper-hypernetwork makes no
        # sense without a hypernetwork (``create_hnet=False``). Previously
        # the (truthy) list of data handlers was passed here by mistake.
        new_mnet, _, _, _ = pcutils.generate_networks(config, shared, logger,
            dhandlers, device, create_mnet=True, create_hnet=False,
            create_hhnet=False, create_dis=False)
        return new_mnet

    def _fmt(values, precision):
        """Format a sequence of scores as a comma-separated array string."""
        return np.array2string(np.array(values), precision=precision,
                               separator=',')

    script_start = time()

    device, writer, logger = sutils.setup_environment(config,
        logger_name=experiment)

    rutils.backup_cli_command(config)

    is_classification = 'regression' not in experiment

    ### Create tasks.
    if is_classification:
        dhandlers = pmutils.load_datasets(config, logger, experiment, writer)
    else:
        dhandlers, num_tasks = rutils.generate_tasks(config, writer)
        config.num_tasks = num_tasks

    ### Simple struct, that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.num_trained = 0

    ### Generate network.
    mnet = _create_mnet()

    if not is_classification:
        # MSE achieved right after training on each task.
        shared.during_mse = np.ones(config.num_tasks) * -1.
        # MSE achieved after most recent call of test method.
        shared.current_mse = np.ones(config.num_tasks) * -1.

    # The weights of the main network right after training on that task
    # (can be used to assess how close the final weights are to the original
    # ones).
    shared.during_weights = [-1] * config.num_tasks

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')

    # Initialize the softmax temperature per-task with one. Might be changed
    # later on to calibrate the temperature.
    if is_classification:
        shared.softmax_temp = [torch.ones(1).to(device)
                               for _ in range(config.num_tasks)]

    ### Initialize summary.
    if is_classification:
        pcutils.setup_summary_dict(config, shared, experiment, mnet)
    else:
        rutils.setup_summary_dict(config, shared, 'ewc', config.num_tasks,
                                  mnet)

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    num_weights_key = 'num_weights_main' if is_classification \
        else 'aa_num_weights_main'
    writer.add_hparams(hparam_dict={**vars(config), **{
        'num_weights_main': shared.summary[num_weights_key]
    }}, metric_dict={})

    if is_classification:
        during_acc_criterion = pmutils.parse_performance_criterion(config,
            shared, logger)

    ### Train on tasks sequentially.
    for i in range(config.num_tasks):
        logger.info('### Training on task %d ###' % (i+1))
        data = dhandlers[i]

        # Train the network.
        shared.num_trained += 1
        train(i, data, mnet, device, config, shared, logger, writer)

        ### Test networks.
        # With a full-test interval, only the current task is tested except
        # on interval boundaries and after the last task (``None`` = all).
        test_ids = None
        if getattr(config, 'full_test_interval', -1) != -1:
            full_test = i == config.num_tasks-1 or \
                (i > 0 and i % config.full_test_interval == 0)
            test_ids = None if full_test else [i]
        test(dhandlers[:(i+1)], mnet, device, config, shared, logger, writer,
             test_ids=test_ids)

        ### Check if last task got "acceptable" accuracy ###
        if is_classification:
            curr_dur_acc = shared.summary['acc_task_given_during'][i]
            if i < config.num_tasks-1 \
                    and during_acc_criterion[i] != -1 \
                    and during_acc_criterion[i] > curr_dur_acc:
                logger.error('During accuracy of task %d too small (%f < %f).'
                             % (i+1, curr_dur_acc, during_acc_criterion[i]))
                logger.error('Training of future tasks will be skipped')
                writer.close()
                sys.exit(1)

        if config.train_from_scratch and i < config.num_tasks-1:
            # We have to checkpoint the networks, such that we can reload them
            # for task inference later during testing.
            pmutils.checkpoint_nets(config, shared, i, mnet, None)
            mnet = _create_mnet()

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, config.num_tasks-1, mnet, None)

    ### Report final classification scores.
    if is_classification:
        logger.info('During accuracies (task identity given): ' +
                    '%s (avg: %.2f%%).' %
                    (_fmt(shared.summary['acc_task_given_during'], 2),
                     shared.summary['acc_avg_task_given_during']))
        logger.info('Final accuracies (task identity given): ' +
                    '%s (avg: %.2f%%).' %
                    (_fmt(shared.summary['acc_task_given'], 2),
                     shared.summary['acc_avg_task_given']))

    if is_classification and config.cl_scenario == 3 and config.split_head_cl3:
        logger.info('During accuracies (task identity inferred): ' +
                    '%s (avg: %.2f%%).' %
                    (_fmt(shared.summary['acc_task_inferred_ent_during'], 2),
                     shared.summary['acc_avg_task_inferred_ent_during']))
        logger.info('Final accuracies (task identity inferred): ' +
                    '%s (avg: %.2f%%).' %
                    (_fmt(shared.summary['acc_task_inferred_ent'], 2),
                     shared.summary['acc_avg_task_inferred_ent']))

    ### Report final regression scores.
    if not is_classification:
        logger.info('During MSE values after training each task: %s' %
                    _fmt(shared.summary['aa_mse_during'], 5))
        logger.info('Final MSE values after training on all tasks: %s' %
                    _fmt(shared.summary['aa_mse_final'], 5))
        # ``aa_mse_during_mean`` averages the *during* MSEs, so it is
        # labelled as such; the final mean is taken over ``aa_mse_final``.
        logger.info('During MSE mean %.4f.' %
                    shared.summary['aa_mse_during_mean'])
        logger.info('Final MSE mean %.4f.' %
                    np.mean(shared.summary['aa_mse_final']))

    ### Write final summary.
    shared.summary['finished'] = 1
    if is_classification:
        pmutils.save_summary_dict(config, shared, experiment)
    else:
        rutils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))
def run():
    """Train and evaluate Bayes-by-Backprop on the regression task sequence.

    Returns:
        (tuple): Tuple containing:

        - **final_mse**: Final MSE for each task.
        - **during_mse**: MSE achieved directly after training on each task.
    """
    start = time()
    mode = 'regression_bbb'

    config = train_args.parse_cmd_arguments(mode=mode)
    device, writer, logger = sutils.setup_environment(config,
                                                      logger_name=mode)
    train_utils.backup_cli_command(config)

    ### Tasks and networks.
    dhandlers, num_tasks = train_utils.generate_tasks(config, writer)

    use_hnet = not config.mnet_only
    mnet, hnet = train_utils.generate_gauss_networks(
        config, logger, dhandlers, device, create_hnet=use_hnet,
        non_gaussian=config.mean_only)

    ### Struct shared among the helper functions.
    shared = Namespace()
    shared.experiment_type = mode
    shared.all_dhandlers = dhandlers

    # Prior used for variational inference. In mean-only mode no prior
    # matching is performed, so no prior is defined.
    if config.mean_only:
        shared.prior_mean = shared.prior_logvar = shared.prior_std = None
    else:
        log_var = np.log(config.prior_variance)
        std = np.sqrt(config.prior_variance)
        shared.prior_mean = [torch.zeros(*shape).to(device)
                             for shape in mnet.orig_param_shapes]
        shared.prior_logvar = [log_var * torch.ones(*shape).to(device)
                               for shape in mnet.orig_param_shapes]
        shared.prior_std = [std * torch.ones(*shape).to(device)
                            for shape in mnet.orig_param_shapes]

    # MSE values are measured on the validation set if given, otherwise on
    # the training set. Validation samples are expected to lie inside the
    # training range; test samples may lie outside of it. -1 marks
    # "not measured yet".
    # MSE right after training on the corresponding task.
    shared.during_mse = np.full(num_tasks, -1.)
    # Main-network weights (means and variances, e.g. the hypernetwork
    # output) right after training on each task.
    shared.during_weights = [-1] * num_tasks
    # MSE after the most recent call of the test method.
    shared.current_mse = np.full(num_tasks, -1.)

    # Checkpoint locations. Main networks are always checkpointed too, since
    # they may hold state such as batch-norm statistics.
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d')

    ### Performance measures tracked during training.
    train_utils.setup_summary_dict(config, shared, 'bbb', num_tasks, mnet,
                                   hnet=hnet)

    # Log the hyperparameters to tensorboard to simplify run identification.
    hparams = dict(vars(config))
    for key in ('num_weights_main', 'num_weights_hyper', 'num_weights_ratio'):
        hparams[key] = shared.summary['aa_' + key]
    writer.add_hparams(hparam_dict=hparams, metric_dict={})

    ### Train (and test) on the tasks sequentially.
    for task_id in range(num_tasks):
        logger.info('### Training on task %d ###' % (task_id + 1))

        train(task_id, dhandlers[task_id], mnet, hnet, device, config, shared,
              logger, writer)
        test(dhandlers[:(task_id + 1)], mnet, hnet, device, config, shared,
             logger, writer)

        if not config.train_from_scratch or task_id >= num_tasks - 1:
            continue

        # Checkpoint the current networks so they can be reloaded for task
        # inference during testing, then continue with fresh networks.
        pmutils.checkpoint_nets(config, shared, task_id, mnet, hnet)
        mnet, hnet = train_utils.generate_gauss_networks(
            config, logger, dhandlers, device, create_hnet=use_hnet,
            non_gaussian=config.mean_only)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, num_tasks - 1, mnet, hnet)

    during_str = np.array2string(shared.during_mse, precision=5,
                                 separator=',')
    final_str = np.array2string(shared.current_mse, precision=5,
                                separator=',')
    logger.info('During MSE values after training each task: %s' %
                during_str)
    logger.info('Final MSE values after training on all tasks: %s' %
                final_str)
    logger.info('Final MSE mean %.4f (std %.4f).' %
                (shared.current_mse.mean(), shared.current_mse.std()))

    ### Write final summary.
    shared.summary['finished'] = 1
    train_utils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.' %
                (time() - start))

    return shared.current_mse, shared.during_mse
parser.add_argument('out_dir', type=str, help='The output directory of the simulation to be ' + 'analyzed.') args = parser.parse_args() out_dir = args.out_dir # Temporary simulation directory, required by method `setup_environment`. args.out_dir = os.path.join( tempfile.gettempdir(), 'tmp_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S')) args.loglevel_info = False args.random_seed = 42 # Note, this script doesn't perform random computation args.deterministic_run = False args.no_cuda = True device, writer, logger = sutils.setup_environment(args) # FIXME Code below copied from script `state_space_analysis`. # Load the config if not os.path.exists(out_dir): raise ValueError('The directory "%s" does not exist.' % out_dir) with open(os.path.join(out_dir, "config.pickle"), "rb") as f: config = pickle.load(f) # Overwrite the directory it it's not the same as the original. if config.out_dir != out_dir: config.out_dir = out_dir # Check for old command line arguments and make compatible with new version. config = sta.update_cli_args(config) # FIXME only for copy task! generate_tasks_func = copytu.generate_copy_tasks
import matplotlib.pyplot as plt import numpy as np import torch from data.timeseries.rnd_rec_teacher import RndRecTeacher from mnets.simple_rnn import SimpleRNN import sequential.ht_analyses.state_space_analysis as ssa import sequential.student_teacher.train_args_st as sta import sequential.student_teacher.train_utils_st as stu from sequential.ht_analyses import pca_utils import utils.ewc_regularizer as ewc import utils.sim_utils as sutils if __name__ == '__main__': config = sta.parse_cmd_arguments(mode='student_teacher') device, writer, logger = sutils.setup_environment(config) writer.close() num_samples = 1000 ### Construct datasets ### scenario = 2 if scenario == 0: # All tasks identical -> hidden dim should remain constant d1 = RndRecTeacher(num_test=num_samples, rseed=1) d2 = RndRecTeacher(num_test=num_samples, rseed=1) d3 = RndRecTeacher(num_test=num_samples, rseed=1) dhandlers = [d1, d2, d3] elif scenario == 1: # All tasks different -> hidden dim should increase d1 = RndRecTeacher(num_test=num_samples, rseed=1) d2 = RndRecTeacher(num_test=num_samples, rseed=2) d3 = RndRecTeacher(num_test=num_samples, rseed=3) dhandlers = [d1, d2, d3]
def run(method='avb'):
    """Run the script.

    Trains a sequence of regression tasks with a variational-inference
    algorithm, tests after every task and reports MSE values.

    Args:
        method (str, optional): The VI algorithm. Possible values are:

            - ``'avb'``
            - ``'ssge'``

    Returns:
        (tuple): Tuple containing:

        - **final_mse**: Final MSE for each task.
        - **during_mse**: MSE achieved directly after training on each task.
    """
    script_start = time()

    mode = 'regression_' + method
    # Whether a discriminator network is used (only AVB requires one).
    use_dis = method == 'avb'

    config = train_args.parse_cmd_arguments(mode=mode)

    device, writer, logger = sutils.setup_environment(config, logger_name=mode)

    train_utils.backup_cli_command(config)

    if config.prior_focused:
        logger.info('Running a prior-focused CL experiment ...')
    else:
        logger.info('Learning task-specific posteriors sequentially ...')

    ### Create tasks.
    dhandlers, num_tasks = train_utils.generate_tasks(config, writer)

    ### Simple struct, that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = mode
    shared.all_dhandlers = dhandlers
    shared.prior_focused = config.prior_focused

    ### Generate networks and environment
    mnet, hnet, hhnet, dnet = pcu.generate_networks(
        config, shared, logger, dhandlers, device, create_dis=use_dis)

    # Mean and variance of prior that is used for variational inference.
    # For a prior-focused training, this prior will only be used for the
    # first task.
    pstd = np.sqrt(config.prior_variance)
    shared.prior_mean = [torch.zeros(*s).to(device)
                         for s in mnet.param_shapes]
    shared.prior_std = [pstd * torch.ones(*s).to(device)
                        for s in mnet.param_shapes]

    # Note, all MSE values are measured on a validation set if given,
    # otherwise on the training set. All samples in the validation set are
    # expected to lie inside the training range. Test samples may lie outside
    # the training range.

    # The MSE value achieved right after training on the corresponding task.
    shared.during_mse = np.ones(num_tasks) * -1.
    # The weights of the main network right after training on that task
    # (can be used to assess how close the final weights are to the original
    # ones). Note, weights refer to mean and variances (e.g., the output of
    # the hypernetwork).
    shared.during_weights = [-1] * num_tasks
    # MSE achieved after most recent call of test method.
    shared.current_mse = np.ones(num_tasks) * -1.

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    # Note, some networks have stuff to store such as batch statistics for
    # batch norm. So it is wise to always checkpoint all networks, even if
    # they were constructed without weights.
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d')
    shared.ckpt_hhnet_fn = os.path.join(shared.ckpt_dir, 'hhnet_task_%d')

    ### Initialize the performance measures, that should be tracked during
    ### training.
    train_utils.setup_summary_dict(config, shared, method, num_tasks, mnet,
                                   hnet=hnet, hhnet=hhnet, dis=dnet)
    logger.info('Ratio num hnet weights / num mnet weights: %f.'
                % shared.summary['aa_num_weights_hm_ratio'])
    if hhnet is not None:
        logger.info('Ratio num hyper-hnet weights / num mnet weights: %f.'
                    % shared.summary['aa_num_weights_hhm_ratio'])
    if mode == 'regression_avb' and dnet is not None:
        # A discriminator only exists for AVB.
        logger.info('Ratio num dis weights / num mnet weights: %f.'
                    % shared.summary['aa_num_weights_dm_ratio'])

    # Add hparams to tensorboard, such that the identification of runs is
    # easier. Ratios are only added if the summary provides them, mirroring
    # the guarded logging above (e.g. the hyper-hnet ratio is only logged
    # when a hyper-hnet exists).
    hparams_extra_dict = {
        'num_weights_hm_ratio': shared.summary['aa_num_weights_hm_ratio'],
    }
    if 'aa_num_weights_hhm_ratio' in shared.summary:
        hparams_extra_dict['num_weights_hhm_ratio'] = \
            shared.summary['aa_num_weights_hhm_ratio']
    if mode == 'regression_avb' and \
            'aa_num_weights_dm_ratio' in shared.summary:
        hparams_extra_dict['num_weights_dm_ratio'] = \
            shared.summary['aa_num_weights_dm_ratio']
    writer.add_hparams(hparam_dict={**vars(config), **hparams_extra_dict},
                       metric_dict={})

    ### Train on tasks sequentially.
    for i in range(num_tasks):
        logger.info('### Training on task %d ###' % (i + 1))
        data = dhandlers[i]
        # Train the network.
        tvi.train(i, data, mnet, hnet, hhnet, dnet, device, config, shared,
                  logger, writer, method=method)

        # Test networks.
        train_bbb.test(dhandlers[:(i + 1)], mnet, hnet, device, config,
                       shared, logger, writer, hhnet=hhnet)

        if config.train_from_scratch and i < num_tasks - 1:
            # We have to checkpoint the networks, such that we can reload
            # them for task inference later during testing.
            # Note, we only need the discriminator as helper for training,
            # so we don't checkpoint it.
            pmu.checkpoint_nets(config, shared, i, mnet, hnet, hhnet=hhnet,
                                dis=None)

            # Pass `create_dis` explicitly (as for the initial networks), so
            # the recreated networks match the chosen method instead of
            # depending on the default of `generate_networks`.
            mnet, hnet, hhnet, dnet = pcu.generate_networks(
                config, shared, logger, dhandlers, device,
                create_dis=use_dis)
        elif config.store_during_models:
            logger.info('Checkpointing current model ...')
            pmu.checkpoint_nets(config, shared, i, mnet, hnet, hhnet=hhnet,
                                dis=None)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmu.checkpoint_nets(config, shared, num_tasks - 1, mnet, hnet,
                            hhnet=hhnet, dis=None)

    logger.info('During MSE values after training each task: %s' %
                np.array2string(shared.during_mse, precision=5,
                                separator=','))
    logger.info('Final MSE values after training on all tasks: %s' %
                np.array2string(shared.current_mse, precision=5,
                                separator=','))
    logger.info('Final MSE mean %.4f (std %.4f).' % (shared.current_mse.mean(),
                                                     shared.current_mse.std()))

    ### Write final summary.
    shared.summary['finished'] = 1
    train_utils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))

    return shared.current_mse, shared.during_mse
def run(config, experiment='split_mnist_avb'):
    """Run the training.

    Args:
        config (argparse.Namespace): Command-line arguments.
        experiment (str): Which kind of experiment should be performed?

            - "gmm_avb": Synthetic GMM data with Posterior Replay via AVB
            - "gmm_avb_pf": Synthetic GMM data with Prior-Focused CL via AVB
            - "split_mnist_avb": Split MNIST with Posterior Replay via AVB
            - "perm_mnist_avb": Permuted MNIST with Posterior Replay via AVB
            - "split_mnist_avb_pf": Split MNIST with Prior-Focused CL via AVB
            - "perm_mnist_avb_pf": Permuted MNIST with Prior-Focused CL via
              AVB
            - "cifar_zenke_avb": CIFAR-10/100 with Posterior Replay using a
              ZenkeNet and AVB
            - "cifar_resnet_avb": CIFAR-10/100 with Posterior Replay using a
              Resnet and AVB
            - "cifar_zenke_avb_pf": CIFAR-10/100 with Prior-Focused CL using a
              ZenkeNet and AVB
            - "cifar_resnet_avb_pf": CIFAR-10/100 with Prior-Focused CL using
              a Resnet and AVB
            - "gmm_ssge": Synthetic GMM data with Posterior Replay via SSGE
            - "gmm_ssge_pf": Synthetic GMM data with Prior-Focused CL via SSGE
            - "split_mnist_ssge": Split MNIST with Posterior Replay via SSGE
            - "perm_mnist_ssge": Permuted MNIST with Posterior Replay via SSGE
            - "split_mnist_ssge_pf": Split MNIST with Prior-Focused CL via
              SSGE
            - "perm_mnist_ssge_pf": Permuted MNIST with Prior-Focused CL via
              SSGE
            - "cifar_resnet_ssge": CIFAR-10/100 with Posterior Replay using a
              Resnet and SSGE
            - "cifar_resnet_ssge_pf": CIFAR-10/100 with Prior-Focused CL using
              a Resnet and SSGE

    Note:
        If the during-accuracy of a task falls below the configured
        performance criterion, training is aborted via ``sys.exit(1)``.
    """
    # `sys.exit` is used instead of the interactive-only `site` builtin
    # `exit`, which is not guaranteed to exist (e.g. under `python -S`).
    import sys

    assert experiment in [
        'gmm_avb', 'gmm_avb_pf',
        'split_mnist_avb', 'split_mnist_avb_pf',
        'perm_mnist_avb', 'perm_mnist_avb_pf',
        'cifar_zenke_avb', 'cifar_zenke_avb_pf',
        'cifar_resnet_avb', 'cifar_resnet_avb_pf',
        'gmm_ssge', 'gmm_ssge_pf',
        'split_mnist_ssge', 'split_mnist_ssge_pf',
        'perm_mnist_ssge', 'perm_mnist_ssge_pf',
        'cifar_resnet_ssge', 'cifar_resnet_ssge_pf'
    ]

    script_start = time()

    if 'avb' in experiment:
        method = 'avb'
        use_dis = True # whether a discriminator network is used
    elif 'ssge' in experiment:
        method = 'ssge'
        use_dis = False

    device, writer, logger = sutils.setup_environment(
        config, logger_name=experiment + 'logger')

    rutils.backup_cli_command(config)

    if experiment.endswith('pf'):
        prior_focused_cl = True
        logger.info('Running a prior-focused CL experiment ...')
    else:
        prior_focused_cl = False
        logger.info('Learning task-specific posteriors sequentially ...')

    ### Create tasks.
    dhandlers = pmutils.load_datasets(config, logger, experiment, writer)

    ### Simple struct, that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.prior_focused = prior_focused_cl

    ### Generate networks.
    mnet, hnet, hhnet, dis = pcutils.generate_networks(
        config, shared, logger, dhandlers, device, create_dis=use_dis)
    if method == 'ssge':
        assert dis is None

    ### Add more information to shared.
    # Mean and variance of prior that is used for variational inference.
    # For a prior-focused training, this prior will only be used for the
    # first task.
    pstd = np.sqrt(config.prior_variance)
    shared.prior_mean = [torch.zeros(*s).to(device)
                         for s in mnet.param_shapes]
    shared.prior_std = [pstd * torch.ones(*s).to(device)
                        for s in mnet.param_shapes]

    # The output weights of the hyper-hyper network right after training on
    # a task (can be used to assess how close the final weights are to the
    # original ones).
    shared.during_weights = [-1] * config.num_tasks if hhnet is not None \
        else None

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    # Note, some networks have stuff to store such as batch statistics for
    # batch norm. So it is wise to always checkpoint all networks, even if
    # they were constructed without weights.
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d')
    shared.ckpt_hhnet_fn = os.path.join(shared.ckpt_dir, 'hhnet_task_%d')

    # Initialize the softmax temperature per-task with one. Might be changed
    # later on to calibrate the temperature.
    shared.softmax_temp = [torch.ones(1).to(device)
                           for _ in range(config.num_tasks)]
    shared.num_trained = 0

    # Setup coresets iff regularization on all tasks is allowed.
    if config.coreset_size != -1 and config.past_and_future_coresets:
        for i in range(config.num_tasks):
            # NOTE(review): `method` is hard-coded to 'avb' here, also for
            # SSGE experiments -- confirm this is intended.
            pmutils.update_coreset(config, shared, i, dhandlers[i], None,
                                   None, device, logger, None, hhnet=None,
                                   method='avb')

    ### Initialize summary.
    pcutils.setup_summary_dict(config, shared, experiment, mnet, hnet=hnet,
                               hhnet=hhnet, dis=dis)
    logger.info('Ratio num hnet weights / num mnet weights: %f.'
                % shared.summary['num_weights_hm_ratio'])
    if 'num_weights_hhm_ratio' in shared.summary.keys():
        logger.info('Ratio num hyper-hnet weights / num mnet weights: %f.'
                    % shared.summary['num_weights_hhm_ratio'])
    if method == 'avb':
        logger.info('Ratio num dis weights / num mnet weights: %f.'
                    % shared.summary['num_weights_dm_ratio'])

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    hparams_extra_dict = {
        'num_weights_hm_ratio': shared.summary['num_weights_hm_ratio'],
    }
    if 'num_weights_dm_ratio' in shared.summary.keys():
        hparams_extra_dict['num_weights_dm_ratio'] = \
            shared.summary['num_weights_dm_ratio']
    if 'num_weights_hhm_ratio' in shared.summary.keys():
        hparams_extra_dict['num_weights_hhm_ratio'] = \
            shared.summary['num_weights_hhm_ratio']
    writer.add_hparams(hparam_dict={**vars(config), **hparams_extra_dict},
                       metric_dict={})

    during_acc_criterion = pmutils.parse_performance_criterion(
        config, shared, logger)

    ### Train on tasks sequentially.
    for i in range(config.num_tasks):
        logger.info('### Training on task %d ###' % (i + 1))
        data = dhandlers[i]

        # Train the network.
        shared.num_trained += 1
        if config.distill_iter == -1:
            tvi.train(i, data, mnet, hnet, hhnet, dis, device, config, shared,
                      logger, writer, method=method)
        else:
            assert hhnet is not None
            # Train main network only.
            tvi.train(i, data, mnet, hnet, None, dis, device, config, shared,
                      logger, writer, method=method)

            # Distill `hnet` into `hhnet`.
            train_bbb.distill_net(i, data, mnet, hnet, hhnet, device, config,
                                  shared, logger, writer)

            # Create a new main network before training the next task.
            mnet, hnet, _, _ = pcutils.generate_networks(
                config, shared, logger, dhandlers, device, create_dis=False,
                create_hhnet=False)

        ### Temperature Calibration.
        if config.calibrate_temp:
            pcutils.calibrate_temperature(i, data, mnet, hnet, hhnet, device,
                                          config, shared, logger, writer)

        ### Test networks.
        test_ids = None
        if config.full_test_interval != -1:
            if i == config.num_tasks-1 or \
                    (i > 0 and i % config.full_test_interval == 0):
                test_ids = None # Test on all tasks.
            else:
                test_ids = [i] # Only test on current task.
        tvi.test(dhandlers[:(i + 1)], mnet, hnet, hhnet, device, config,
                 shared, logger, writer, test_ids=test_ids, method=method)

        ### Check if last task got "acceptable" accuracy ###
        curr_dur_acc = shared.summary['acc_task_given_during'][i]
        if i < config.num_tasks-1 and during_acc_criterion[i] != -1 \
                and during_acc_criterion[i] > curr_dur_acc:
            logger.error('During accuracy of task %d too small (%f < %f).' %
                         (i+1, curr_dur_acc, during_acc_criterion[i]))
            logger.error('Training of future tasks will be skipped')
            writer.close()
            sys.exit(1)

        if config.train_from_scratch and i < config.num_tasks - 1:
            # We have to checkpoint the networks, such that we can reload
            # them for task inference later during testing.
            # Note, we only need the discriminator as helper for training,
            # so we don't checkpoint it.
            pmutils.checkpoint_nets(config, shared, i, mnet, hnet,
                                    hhnet=hhnet, dis=None)

            mnet, hnet, hhnet, dis = pcutils.generate_networks(
                config, shared, logger, dhandlers, device,
                create_dis=use_dis)

        elif dis is not None and not config.no_dis_reinit and \
                i < config.num_tasks-1:
            logger.debug('Reinitializing discriminator network ...')
            # FIXME Build a new network as this init doesn't effect batchnorm
            # weights atm.
            dis.custom_init(normal_init=config.normal_init,
                            normal_std=config.std_normal_init, zero_bias=True)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, config.num_tasks - 1, mnet,
                                hnet, hhnet=hhnet, dis=None)

    logger.info('During accuracies (task identity given): %s (avg: %.2f%%).'
                % (np.array2string(
                    np.array(shared.summary['acc_task_given_during']),
                    precision=2, separator=','),
                   shared.summary['acc_avg_task_given_during']))
    logger.info('Final accuracies (task identity given): %s (avg: %.2f%%).'
                % (np.array2string(
                    np.array(shared.summary['acc_task_given']),
                    precision=2, separator=','),
                   shared.summary['acc_avg_task_given']))

    logger.info('During accuracies (task identity inferred): ' +
                '%s (avg: %.2f%%).' %
                (np.array2string(
                    np.array(shared.summary['acc_task_inferred_ent_during']),
                    precision=2, separator=','),
                 shared.summary['acc_avg_task_inferred_ent_during']))
    logger.info('Final accuracies (task identity inferred): ' +
                '%s (avg: %.2f%%).' %
                (np.array2string(
                    np.array(shared.summary['acc_task_inferred_ent']),
                    precision=2, separator=','),
                 shared.summary['acc_avg_task_inferred_ent']))

    logger.info('### Avg. during accuracy (CL scenario %d): %.4f.'
                % (config.cl_scenario, shared.summary['acc_avg_during']))
    logger.info('### Avg. final accuracy (CL scenario %d): %.4f.'
                % (config.cl_scenario, shared.summary['acc_avg_final']))

    ### Write final summary.
    shared.summary['finished'] = 1
    pmutils.save_summary_dict(config, shared, experiment)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))
def run():
    """Train and evaluate networks on a sequence of copy tasks.

    Steps performed:

    1. Parse the copy-task command-line arguments and set up the environment
       (device, tensorboard writer, logger).
    2. Generate and visualise the task data.
    3. Build the target network, hypernetwork and optional extra network,
       plus optional context masks.
    4. Select loss, accuracy, EWC and replay helpers.
    5. Train on all tasks sequentially via ``sts.train_tasks``, which also
       tests on all tasks after each one.
    """
    #############
    ### Setup ###
    #############
    config = train_args_copy.parse_cmd_arguments()
    device, writer, logger = sutils.setup_environment(config)

    dhandlers = ctu.generate_copy_tasks(config, logger, writer=writer)
    plc.visualise_data(dhandlers, config, device)

    # Container for miscellaneous information shared between functions.
    shared = Namespace()
    shared.feature_size = dhandlers[0].in_shape[0]

    # The chance level is only reported for plain time/width permutations,
    # i.e. without scatter patterns, separate XOR permutations or more than
    # one XOR iteration.
    if config.permute_time or config.permute_width:
        if not (config.scatter_pattern or config.permute_xor_separate or
                config.permute_xor_iter > 1):
            chance = ctu.compute_chance_level(dhandlers, config)
            logger.info('Chance level for perfect during accuracies: %.2f' %
                        chance)

    # This override has to happen after the tasks were generated; doing it
    # earlier would always select the task with the shortest sequences.
    if config.last_task_only:
        config.num_tasks = 1

    target_net, hnet, dnet = stu.generate_networks(config, shared, dhandlers,
                                                   device)

    # Binary context masks are only built on request.
    ctx_masks = (stu.generate_binary_masks(config, device, target_net)
                 if config.use_masks else None)

    # Target-network weights (without context-mod weights) collected after
    # every task, used to quantify weight changes, e.g. the "stiffness" of
    # EWC.
    shared.tnet_weights = []
    # Hypernetwork outputs collected after every task, used to quantify
    # "forgetting" (the hnet regularizer should keep them fixed).
    shared.hnet_out = []

    # Task-specific loss and accuracy helpers.
    task_loss_func = ctu.get_copy_loss_func(config, device, logger,
                                            ewc_loss=False)
    accuracy_func = ctu.get_accuracy
    ewc_loss_func = None
    if config.use_ewc:
        ewc_loss_func = ctu.get_copy_loss_func(config, device, logger,
                                               ewc_loss=True)

    replay_fcts = None
    if config.use_replay:
        replay_fcts = {
            'rec_loss': ctu.get_vae_rec_loss_func(),
            'distill_loss': ctu.get_distill_loss_func(),
            'soft_trgt_acc': ctu.get_soft_trgt_acc_func(),
        }

    hpsearch_module = hpsearch_mt if config.multitask else hpsearch_cl
    summary_keywords = hpsearch_module._SUMMARY_KEYWORDS
    summary_filename = hpsearch_module._SUMMARY_FILENAME

    ########################
    ### Train classifier ###
    ########################
    # Tasks are trained one after another; after each task, all tasks are
    # tested.
    ret, _, _, test_acc = sts.train_tasks(
        dhandlers, target_net, hnet, dnet, device, config, shared, logger,
        writer, ctx_masks, summary_keywords, summary_filename,
        task_loss_func=task_loss_func, accuracy_func=accuracy_func,
        ewc_loss_func=ewc_loss_func, replay_fcts=replay_fcts)

    stu.log_results(test_acc, config, logger)

    writer.close()

    # A return value of -1 signals that all tasks were trained.
    if ret != -1:
        logger.error('Only %d tasks have completed training.' % (ret + 1))
        return

    logger.info('Program finished successfully.')
    if config.show_plots:
        plt.show()
def run():
    """Run the script.

    Sets up the environment and tasks configured via ``train_args_pos``, loads
    preprocessed word embeddings (one lookup per task) and trains the networks
    on all tasks sequentially via ``sts.train_tasks``.
    """
    import os

    #############
    ### Setup ###
    #############
    config = train_args_pos.parse_cmd_arguments()

    device, writer, logger = sutils.setup_environment(config)

    dhandlers = ctu.generate_tasks(config, logger, writer=writer)

    # Load preprocessed word embeddings, see
    # :mod:`data.timeseries.preprocess_mud` for details.
    # The path is resolved relative to this source file instead of the
    # current working directory, so the script also works when it is not
    # launched from its own directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    wembs_path = os.path.normpath(os.path.join(
        script_dir, '../../datasets/sequential/mud/embeddings.pickle'))
    wemb_lookups = eu.generate_emb_lookups(config, filename=wembs_path,
                                           device=device)
    assert len(wemb_lookups) == config.num_tasks

    # We will use the namespace below to share miscellaneous information
    # between functions.
    shared = Namespace()
    # The embedding size is fixed due to the use of pretrained polyglot
    # embeddings.
    # FIXME Could be made configurable in the future in case we don't
    # initialize embeddings via polyglot.
    shared.feature_size = 64
    shared.word_emb_lookups = wemb_lookups

    target_net, hnet, dnet = stu.generate_networks(config, shared, dhandlers,
                                                   device)

    # generate masks if needed
    ctx_masks = None
    if config.use_masks:
        ctx_masks = stu.generate_binary_masks(config, device, target_net)

    # We store the target network weights (excluding potential context-mod
    # weights after every task). In this way, we can quantify changes and
    # observe the "stiffness" of EWC.
    shared.tnet_weights = []
    # We store the context-mod weights (or all weights) coming from the
    # hypernet after every task, in order to quantify "forgetting". Note, the
    # hnet regularizer should keep them fixed.
    shared.hnet_out = []

    # Get the task-specific functions for loss and accuracy.
    task_loss_func = ctu.get_loss_func(config, device, logger, ewc_loss=False)
    accuracy_func = ctu.get_accuracy_func(config)
    ewc_loss_func = ctu.get_loss_func(config, device, logger,
                                      ewc_loss=True) if config.use_ewc else None

    replay_fcts = None
    if config.use_replay:
        replay_fcts = dict()
        replay_fcts['rec_loss'] = ctu.get_vae_rec_loss_func()
        replay_fcts['distill_loss'] = ctu.get_distill_loss_func()
        replay_fcts['soft_trgt_acc'] = ctu.get_soft_trgt_acc_func()

    if config.multitask:
        summary_keywords = hpsearch_mt._SUMMARY_KEYWORDS
        summary_filename = hpsearch_mt._SUMMARY_FILENAME
    else:
        summary_keywords = hpsearch_cl._SUMMARY_KEYWORDS
        summary_filename = hpsearch_cl._SUMMARY_FILENAME

    ########################
    ### Train classifier ###
    ########################
    shared.f_scores = None

    # Train the network task by task. Testing on all tasks is run after
    # finishing training on each task.
    ret, _, _, test_acc = sts.train_tasks(
        dhandlers, target_net, hnet, dnet, device, config, shared, logger,
        writer, ctx_masks, summary_keywords, summary_filename,
        task_loss_func=task_loss_func, accuracy_func=accuracy_func,
        ewc_loss_func=ewc_loss_func, replay_fcts=replay_fcts)

    stu.log_results(test_acc, config, logger)

    writer.close()

    # A return value of -1 signals that all tasks were trained.
    if ret == -1:
        logger.info('Program finished successfully.')

        if config.show_plots:
            plt.show()
    else:
        logger.error('Only %d tasks have completed training.' % (ret + 1))
def run(config, experiment='regression_mt'):
    """Run the Multitask training for the given experiment.

    All tasks are trained jointly with a single main network (no
    hypernetworks, no discriminator), followed by one test on all tasks.

    Args:
        config (argparse.Namespace): Command-line arguments.
        experiment (str): Which kind of experiment should be performed?

            - ``'regression_mt'``: Regression tasks with multitask training
            - ``'gmm_mt'``: GMM Data with multitask training
            - ``'split_mnist_mt'``: SplitMNIST with multitask training
            - ``'perm_mnist_mt'``: PermutedMNIST with multitask training
            - ``'cifar_resnet_mt'``: CIFAR-10/100 with multitask training
    """
    assert experiment in [
        'regression_mt', 'gmm_mt', 'split_mnist_mt', 'perm_mnist_mt',
        'cifar_resnet_mt'
    ]

    script_start = time()

    device, writer, logger = sutils.setup_environment(config,
                                                      logger_name=experiment)

    rutils.backup_cli_command(config)

    is_classification = 'regression' not in experiment

    ### Create tasks.
    if is_classification:
        dhandlers = pmutils.load_datasets(config, logger, experiment, writer)
    else:
        dhandlers, num_tasks = rutils.generate_tasks(config, writer)
        config.num_tasks = num_tasks

    ### Simple struct, that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.num_trained = 0

    ### Generate network.
    # Multitask training only needs the main network, so neither a
    # hypernetwork, nor a hyper-hypernetwork, nor a discriminator is created.
    # (`create_hhnet` previously received `dhandlers`, a truthy list, by
    # mistake.)
    mnet, _, _, _ = pcutils.generate_networks(
        config, shared, logger, dhandlers, device, create_mnet=True,
        create_hnet=False, create_hhnet=False, create_dis=False)

    if not is_classification:
        shared.during_mse = np.ones(config.num_tasks) * -1.
        # MSE achieved after most recent call of test method.
        shared.current_mse = np.ones(config.num_tasks) * -1.

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')

    # Initialize the softmax temperature per-task with one. Might be changed
    # later on to calibrate the temperature.
    if is_classification:
        shared.softmax_temp = [torch.ones(1).to(device)
                               for _ in range(config.num_tasks)]

    ### Initialize summary.
    if is_classification:
        pcutils.setup_summary_dict(config, shared, experiment, mnet)
    else:
        rutils.setup_summary_dict(config, shared, 'mt', config.num_tasks,
                                  mnet)

    # Add hparams to tensorboard, such that the identification of runs is
    # easier. The two summary helpers use different key names.
    num_weights_key = 'num_weights_main' if is_classification \
        else 'aa_num_weights_main'
    writer.add_hparams(hparam_dict={
        **vars(config),
        **{'num_weights_main': shared.summary[num_weights_key]}
    }, metric_dict={})

    ### Train on all tasks.
    logger.info('### Training ###')
    # Note, since we are training on all tasks; all output heads can at all
    # times be considered as trained!
    shared.num_trained = config.num_tasks
    train(dhandlers, mnet, device, config, shared, logger, writer)

    logger.info('### Testing ###')
    test(dhandlers, mnet, device, config, shared, logger, writer,
         test_ids=None)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, config.num_tasks - 1, mnet,
                                None)

    ### Plot final classification scores.
    if is_classification:
        logger.info('Final accuracies (task identity given): ' +
                    '%s (avg: %.2f%%).' %
                    (np.array2string(
                        np.array(shared.summary['acc_task_given']),
                        precision=2, separator=','),
                     shared.summary['acc_avg_task_given']))

    ### Plot final regression scores.
    if not is_classification:
        logger.info('Final MSE values after training on all tasks: %s' %
                    np.array2string(np.array(shared.summary['aa_mse_final']),
                                    precision=5, separator=','))
        # NOTE(review): this "Final" message reports the *during* mean
        # (`aa_mse_during_mean`); presumably identical for multitask
        # training -- confirm.
        logger.info('Final MSE mean %.4f.' %
                    (shared.summary['aa_mse_during_mean']))

    ### Write final summary.
    shared.summary['finished'] = 1
    if is_classification:
        pmutils.save_summary_dict(config, shared, experiment)
    else:
        rutils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))