def run(): """Run the script. Returns: (tuple): Tuple containing: - **final_mse**: Final MSE for each task. - **during_mse**: MSE achieved directly after training on each task. """ script_start = time() mode = 'regression_bbb' config = train_args.parse_cmd_arguments(mode=mode) device, writer, logger = sutils.setup_environment(config, logger_name=mode) train_utils.backup_cli_command(config) ### Create tasks. dhandlers, num_tasks = train_utils.generate_tasks(config, writer) ### Generate networks. use_hnet = not config.mnet_only mnet, hnet = train_utils.generate_gauss_networks( config, logger, dhandlers, device, create_hnet=use_hnet, non_gaussian=config.mean_only) ### Simple struct, that is used to share data among functions. shared = Namespace() shared.experiment_type = mode shared.all_dhandlers = dhandlers # Mean and variance of prior that is used for variational inference. if config.mean_only: # No prior-matching can be performed. shared.prior_mean = None shared.prior_logvar = None shared.prior_std = None else: plogvar = np.log(config.prior_variance) pstd = np.sqrt(config.prior_variance) shared.prior_mean = [torch.zeros(*s).to(device) \ for s in mnet.orig_param_shapes] shared.prior_logvar = [plogvar * torch.ones(*s).to(device) \ for s in mnet.orig_param_shapes] shared.prior_std = [pstd * torch.ones(*s).to(device) \ for s in mnet.orig_param_shapes] # Note, all MSE values are measured on a validation set if given, otherwise # on the training set. All samples in the validation set are expected to # lay inside the training range. Test samples may lay outside the training # range. # The MSE value achieved right after training on the corresponding task. shared.during_mse = np.ones(num_tasks) * -1. # The weights of the main network right after training on that task # (can be used to assess how close the final weights are to the original # ones). Note, weights refer to mean and variances (e.g., the output of the # hypernetwork). shared.during_weights = [-1] * num_tasks # MSE achieved after most recent call of test method. shared.current_mse = np.ones(num_tasks) * -1. # Where to save network checkpoints? shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints') # Note, some main networks have stuff to store such as batch statistics for # batch norm. So it is wise to always checkpoint mnets as well! shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d') shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d') ### Initialize the performance measures, that should be tracked during ### training. train_utils.setup_summary_dict(config, shared, 'bbb', num_tasks, mnet, hnet=hnet) # Add hparams to tensorboard, such that the identification of runs is # easier. writer.add_hparams(hparam_dict={ **vars(config), **{ 'num_weights_main': shared.summary['aa_num_weights_main'], 'num_weights_hyper': shared.summary['aa_num_weights_hyper'], 'num_weights_ratio': shared.summary['aa_num_weights_ratio'], } }, metric_dict={}) ### Train on tasks sequentially. for i in range(num_tasks): logger.info('### Training on task %d ###' % (i + 1)) data = dhandlers[i] # Train the network. train(i, data, mnet, hnet, device, config, shared, logger, writer) ### Test networks. test(dhandlers[:(i + 1)], mnet, hnet, device, config, shared, logger, writer) if config.train_from_scratch and i < num_tasks - 1: # We have to checkpoint the networks, such that we can reload them # for task inference later during testing. 
            pmutils.checkpoint_nets(config, shared, i, mnet, hnet)

            mnet, hnet = train_utils.generate_gauss_networks(config, logger,
                dhandlers, device, create_hnet=use_hnet,
                non_gaussian=config.mean_only)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, num_tasks - 1, mnet, hnet)

    logger.info('During MSE values after training each task: %s' %
                np.array2string(shared.during_mse, precision=5,
                                separator=','))
    logger.info('Final MSE values after training on all tasks: %s' %
                np.array2string(shared.current_mse, precision=5,
                                separator=','))
    logger.info('Final MSE mean %.4f (std %.4f).' %
                (shared.current_mse.mean(), shared.current_mse.std()))

    ### Write final summary.
    shared.summary['finished'] = 1
    train_utils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))

    return shared.current_mse, shared.during_mse
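

# Illustrative sketch, added for exposition; it is not called by the training
# code. It mirrors how ``run`` above constructs the per-parameter Gaussian
# prior from the scalar ``config.prior_variance`` and checks that the three
# representations (mean, log-variance, std) agree. The parameter shapes are
# hypothetical stand-ins for ``mnet.orig_param_shapes``.
def _sketch_gaussian_prior(prior_variance=1., param_shapes=((3, 2), (3,))):
    """Sketch: build prior mean/logvar/std tensors as ``run`` does."""
    plogvar = np.log(prior_variance)
    pstd = np.sqrt(prior_variance)
    prior_mean = [torch.zeros(*s) for s in param_shapes]
    prior_logvar = [plogvar * torch.ones(*s) for s in param_shapes]
    prior_std = [pstd * torch.ones(*s) for s in param_shapes]
    # Consistency check: std == exp(0.5 * logvar) for each parameter tensor.
    assert all(torch.allclose(torch.exp(0.5 * lv), sd)
               for lv, sd in zip(prior_logvar, prior_std))
    return prior_mean, prior_logvar, prior_std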


def run(config, experiment='regression_ewc'):
    """Run the EWC training for the given experiment.

    Args:
        config (argparse.Namespace): Command-line arguments.
        experiment (str): Which kind of experiment should be performed?

            - ``'regression_ewc'``: Regression tasks with EWC
            - ``'gmm_ewc'``: GMM data with EWC
            - ``'split_mnist_ewc'``: SplitMNIST with EWC
            - ``'perm_mnist_ewc'``: PermutedMNIST with EWC
            - ``'cifar_resnet_ewc'``: CIFAR-10/100 with EWC
    """
    assert experiment in ['regression_ewc', 'gmm_ewc', 'split_mnist_ewc',
                          'perm_mnist_ewc', 'cifar_resnet_ewc']

    script_start = time()

    device, writer, logger = sutils.setup_environment(config,
        logger_name=experiment)

    rutils.backup_cli_command(config)

    is_classification = True
    if 'regression' in experiment:
        is_classification = False

    ### Create tasks.
    if is_classification:
        dhandlers = pmutils.load_datasets(config, logger, experiment, writer)
    else:
        dhandlers, num_tasks = rutils.generate_tasks(config, writer)
        config.num_tasks = num_tasks

    ### Simple struct that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.num_trained = 0

    ### Generate network.
    mnet, _, _, _ = pcutils.generate_networks(config, shared, logger,
        dhandlers, device, create_mnet=True, create_hnet=False,
        create_hhnet=False, create_dis=False)

    if not is_classification:
        shared.during_mse = np.ones(config.num_tasks) * -1.
        # MSE achieved after the most recent call of the test method.
        shared.current_mse = np.ones(config.num_tasks) * -1.
        # The weights of the main network right after training on that task
        # (can be used to assess how close the final weights are to the
        # original ones).
        shared.during_weights = [-1] * config.num_tasks

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')

    # Initialize the softmax temperature per task with one. Might be changed
    # later on to calibrate the temperature.
    if is_classification:
        shared.softmax_temp = [torch.ones(1).to(device)
                               for _ in range(config.num_tasks)]

    ### Initialize summary.
    if is_classification:
        pcutils.setup_summary_dict(config, shared, experiment, mnet)
    else:
        rutils.setup_summary_dict(config, shared, 'ewc', config.num_tasks,
                                  mnet)

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    writer.add_hparams(hparam_dict={**vars(config), **{
        'num_weights_main': shared.summary['num_weights_main']
            if is_classification else shared.summary['aa_num_weights_main']
    }}, metric_dict={})

    if is_classification:
        during_acc_criterion = pmutils.parse_performance_criterion(config,
            shared, logger)

    ### Train on tasks sequentially.
    for i in range(config.num_tasks):
        logger.info('### Training on task %d ###' % (i+1))
        data = dhandlers[i]

        # Train the network.
        shared.num_trained += 1
        train(i, data, mnet, device, config, shared, logger, writer)

        ### Test networks.
        test_ids = None
        if hasattr(config, 'full_test_interval') and \
                config.full_test_interval != -1:
            if i == config.num_tasks-1 or \
                    (i > 0 and i % config.full_test_interval == 0):
                test_ids = None  # Test on all tasks.
            else:
                test_ids = [i]  # Only test on current task.
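        # Example: with ``config.full_test_interval == 3`` and 5 tasks, a
        # full sweep over all seen tasks happens after training task 4
        # (since 3 % 3 == 0) and after the last task; otherwise only the
        # current task is re-evaluated.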
        test(dhandlers[:(i+1)], mnet, device, config, shared, logger,
             writer, test_ids=test_ids)

        ### Check if the last task got an "acceptable" accuracy ###
        if is_classification:
            curr_dur_acc = shared.summary['acc_task_given_during'][i]
        if is_classification and i < config.num_tasks-1 \
                and during_acc_criterion[i] != -1 \
                and during_acc_criterion[i] > curr_dur_acc:
            logger.error('During accuracy of task %d too small (%f < %f).'
                         % (i+1, curr_dur_acc, during_acc_criterion[i]))
            logger.error('Training of future tasks will be skipped')
            writer.close()
            exit(1)

        if config.train_from_scratch and i < config.num_tasks-1:
            # We have to checkpoint the networks, such that we can reload
            # them for task inference later during testing.
            pmutils.checkpoint_nets(config, shared, i, mnet, None)

            mnet, _, _, _ = pcutils.generate_networks(config, shared, logger,
                dhandlers, device, create_mnet=True, create_hnet=False,
                create_hhnet=False, create_dis=False)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, config.num_tasks-1, mnet,
                                None)

    ### Plot final classification scores.
    if is_classification:
        logger.info('During accuracies (task identity given): '
                    '%s (avg: %.2f%%).' %
            (np.array2string(np.array(
                 shared.summary['acc_task_given_during']),
                 precision=2, separator=','),
             shared.summary['acc_avg_task_given_during']))
        logger.info('Final accuracies (task identity given): '
                    '%s (avg: %.2f%%).' %
            (np.array2string(np.array(shared.summary['acc_task_given']),
                             precision=2, separator=','),
             shared.summary['acc_avg_task_given']))

    if is_classification and config.cl_scenario == 3 and \
            config.split_head_cl3:
        logger.info('During accuracies (task identity inferred): '
                    '%s (avg: %.2f%%).' %
            (np.array2string(np.array(
                 shared.summary['acc_task_inferred_ent_during']),
                 precision=2, separator=','),
             shared.summary['acc_avg_task_inferred_ent_during']))
        logger.info('Final accuracies (task identity inferred): '
                    '%s (avg: %.2f%%).' %
            (np.array2string(np.array(
                 shared.summary['acc_task_inferred_ent']),
                 precision=2, separator=','),
             shared.summary['acc_avg_task_inferred_ent']))

    ### Plot final regression scores.
    if not is_classification:
        logger.info('During MSE values after training each task: %s' %
            np.array2string(np.array(shared.summary['aa_mse_during']),
                            precision=5, separator=','))
        logger.info('Final MSE values after training on all tasks: %s' %
            np.array2string(np.array(shared.summary['aa_mse_final']),
                            precision=5, separator=','))
        logger.info('Final MSE mean %.4f.' %
                    shared.summary['aa_mse_during_mean'])

    ### Write final summary.
    shared.summary['finished'] = 1
    if is_classification:
        pmutils.save_summary_dict(config, shared, experiment)
    else:
        rutils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))
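

# Illustrative sketch (added for exposition, not the repository's actual
# implementation): the quadratic penalty that EWC adds to the loss of the
# current task, ``lambda/2 * sum_i F_i * (theta_i - theta_i^*)^2``, where
# ``F`` is the diagonal Fisher estimate and ``theta^*`` are the parameters
# found after the previous task. All arguments are hypothetical stand-ins
# for state that ``train`` maintains internally.
def _sketch_ewc_penalty(params, prev_params, fisher, ewc_lambda):
    """Sketch: diagonal-Fisher EWC regularizer over lists of tensors."""
    penalty = 0.
    for p, p_star, f in zip(params, prev_params, fisher):
        penalty = penalty + (f * (p - p_star)**2).sum()
    return 0.5 * ewc_lambda * penalty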


def run(config, experiment='split_mnist_avb'):
    """Run the training.

    Args:
        config (argparse.Namespace): Command-line arguments.
        experiment (str): Which kind of experiment should be performed?

            - "gmm_avb": Synthetic GMM data with Posterior Replay via AVB
            - "gmm_avb_pf": Synthetic GMM data with Prior-Focused CL via AVB
            - "split_mnist_avb": Split MNIST with Posterior Replay via AVB
            - "perm_mnist_avb": Permuted MNIST with Posterior Replay via AVB
            - "split_mnist_avb_pf": Split MNIST with Prior-Focused CL via
              AVB
            - "perm_mnist_avb_pf": Permuted MNIST with Prior-Focused CL via
              AVB
            - "cifar_zenke_avb": CIFAR-10/100 with Posterior Replay using a
              ZenkeNet and AVB
            - "cifar_resnet_avb": CIFAR-10/100 with Posterior Replay using a
              ResNet and AVB
            - "cifar_zenke_avb_pf": CIFAR-10/100 with Prior-Focused CL using
              a ZenkeNet and AVB
            - "cifar_resnet_avb_pf": CIFAR-10/100 with Prior-Focused CL
              using a ResNet and AVB
            - "gmm_ssge": Synthetic GMM data with Posterior Replay via SSGE
            - "gmm_ssge_pf": Synthetic GMM data with Prior-Focused CL via
              SSGE
            - "split_mnist_ssge": Split MNIST with Posterior Replay via SSGE
            - "perm_mnist_ssge": Permuted MNIST with Posterior Replay via
              SSGE
            - "split_mnist_ssge_pf": Split MNIST with Prior-Focused CL via
              SSGE
            - "perm_mnist_ssge_pf": Permuted MNIST with Prior-Focused CL via
              SSGE
            - "cifar_resnet_ssge": CIFAR-10/100 with Posterior Replay using
              a ResNet and SSGE
            - "cifar_resnet_ssge_pf": CIFAR-10/100 with Prior-Focused CL
              using a ResNet and SSGE
    """
    assert experiment in ['gmm_avb', 'gmm_avb_pf',
                          'split_mnist_avb', 'split_mnist_avb_pf',
                          'perm_mnist_avb', 'perm_mnist_avb_pf',
                          'cifar_zenke_avb', 'cifar_zenke_avb_pf',
                          'cifar_resnet_avb', 'cifar_resnet_avb_pf',
                          'gmm_ssge', 'gmm_ssge_pf',
                          'split_mnist_ssge', 'split_mnist_ssge_pf',
                          'perm_mnist_ssge', 'perm_mnist_ssge_pf',
                          'cifar_resnet_ssge', 'cifar_resnet_ssge_pf']

    script_start = time()

    if 'avb' in experiment:
        method = 'avb'
        use_dis = True  # Whether a discriminator network is used.
    elif 'ssge' in experiment:
        method = 'ssge'
        use_dis = False

    device, writer, logger = sutils.setup_environment(config,
        logger_name=experiment + 'logger')

    rutils.backup_cli_command(config)

    if experiment.endswith('pf'):
        prior_focused_cl = True
        logger.info('Running a prior-focused CL experiment ...')
    else:
        prior_focused_cl = False
        logger.info('Learning task-specific posteriors sequentially ...')

    ### Create tasks.
    dhandlers = pmutils.load_datasets(config, logger, experiment, writer)

    ### Simple struct that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.prior_focused = prior_focused_cl

    ### Generate networks.
    mnet, hnet, hhnet, dis = pcutils.generate_networks(config, shared,
        logger, dhandlers, device, create_dis=use_dis)
    if method == 'ssge':
        assert dis is None

    ### Add more information to shared.
    # Mean and variance of the prior that is used for variational inference.
    # For prior-focused training, this prior will only be used for the
    # first task.
    #plogvar = np.log(config.prior_variance)
    pstd = np.sqrt(config.prior_variance)
    shared.prior_mean = [torch.zeros(*s).to(device)
                         for s in mnet.param_shapes]
    #shared.prior_logvar = [plogvar * torch.ones(*s).to(device)
    #                       for s in mnet.param_shapes]
    shared.prior_std = [pstd * torch.ones(*s).to(device)
                        for s in mnet.param_shapes]

    # The output weights of the hyper-hyper network right after training on
    # a task (can be used to assess how close the final weights are to the
    # original ones).
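    # (Initialized to -1; presumably overwritten once the corresponding
    # task has been trained.)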
    shared.during_weights = [-1] * config.num_tasks if hhnet is not None \
        else None

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    # Note, some networks have state to store, such as batch statistics for
    # batch norm. So it is wise to always checkpoint all networks, even if
    # they were constructed without weights.
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d')
    shared.ckpt_hhnet_fn = os.path.join(shared.ckpt_dir, 'hhnet_task_%d')
    #shared.ckpt_dis_fn = os.path.join(shared.ckpt_dir, 'dis_task_%d')

    # Initialize the softmax temperature per task with one. Might be changed
    # later on to calibrate the temperature.
    shared.softmax_temp = [torch.ones(1).to(device)
                           for _ in range(config.num_tasks)]
    shared.num_trained = 0

    # Set up coresets iff regularization on all tasks is allowed.
    if config.coreset_size != -1 and config.past_and_future_coresets:
        for i in range(config.num_tasks):
            pmutils.update_coreset(config, shared, i, dhandlers[i], None,
                None, device, logger, None, hhnet=None, method='avb')

    ### Initialize summary.
    pcutils.setup_summary_dict(config, shared, experiment, mnet, hnet=hnet,
                               hhnet=hhnet, dis=dis)
    logger.info('Ratio num hnet weights / num mnet weights: %f.'
                % shared.summary['num_weights_hm_ratio'])
    if 'num_weights_hhm_ratio' in shared.summary.keys():
        logger.info('Ratio num hyper-hnet weights / num mnet weights: %f.'
                    % shared.summary['num_weights_hhm_ratio'])
    if method == 'avb':
        logger.info('Ratio num dis weights / num mnet weights: %f.'
                    % shared.summary['num_weights_dm_ratio'])

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    hparams_extra_dict = {
        'num_weights_hm_ratio': shared.summary['num_weights_hm_ratio'],
    }
    if 'num_weights_dm_ratio' in shared.summary.keys():
        hparams_extra_dict = {**hparams_extra_dict,
            **{'num_weights_dm_ratio':
               shared.summary['num_weights_dm_ratio']}}
    if 'num_weights_hhm_ratio' in shared.summary.keys():
        hparams_extra_dict = {**hparams_extra_dict,
            **{'num_weights_hhm_ratio':
               shared.summary['num_weights_hhm_ratio']}}
    writer.add_hparams(hparam_dict={**vars(config), **hparams_extra_dict},
                       metric_dict={})

    during_acc_criterion = pmutils.parse_performance_criterion(config,
        shared, logger)

    ### Train on tasks sequentially.
    for i in range(config.num_tasks):
        logger.info('### Training on task %d ###' % (i + 1))
        data = dhandlers[i]

        # Train the network.
        shared.num_trained += 1
        if config.distill_iter == -1:
            tvi.train(i, data, mnet, hnet, hhnet, dis, device, config,
                      shared, logger, writer, method=method)
        else:
            assert hhnet is not None
            # Train the main network only.
            tvi.train(i, data, mnet, hnet, None, dis, device, config,
                      shared, logger, writer, method=method)

            # Distill `hnet` into `hhnet`.
            train_bbb.distill_net(i, data, mnet, hnet, hhnet, device,
                                  config, shared, logger, writer)

            # Create a new main network before training the next task.
            mnet, hnet, _, _ = pcutils.generate_networks(config, shared,
                logger, dhandlers, device, create_dis=False,
                create_hhnet=False)

        ### Temperature Calibration.
        if config.calibrate_temp:
            pcutils.calibrate_temperature(i, data, mnet, hnet, hhnet,
                device, config, shared, logger, writer)

        ### Test networks.
        test_ids = None
        if config.full_test_interval != -1:
            if i == config.num_tasks-1 or \
                    (i > 0 and i % config.full_test_interval == 0):
                test_ids = None  # Test on all tasks.
            else:
                test_ids = [i]  # Only test on current task.
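        # ``test_ids=None`` means: evaluate all tasks seen so far, i.e.,
        # ``dhandlers[:(i+1)]`` below.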
        tvi.test(dhandlers[:(i + 1)], mnet, hnet, hhnet, device, config,
                 shared, logger, writer, test_ids=test_ids, method=method)

        ### Check if the last task got an "acceptable" accuracy ###
        curr_dur_acc = shared.summary['acc_task_given_during'][i]
        if i < config.num_tasks-1 and during_acc_criterion[i] != -1 \
                and during_acc_criterion[i] > curr_dur_acc:
            logger.error('During accuracy of task %d too small (%f < %f).'
                         % (i+1, curr_dur_acc, during_acc_criterion[i]))
            logger.error('Training of future tasks will be skipped')
            writer.close()
            exit(1)

        if config.train_from_scratch and i < config.num_tasks - 1:
            # We have to checkpoint the networks, such that we can reload
            # them for task inference later during testing.
            # Note, we only need the discriminator as a helper for training,
            # so we don't checkpoint it.
            pmutils.checkpoint_nets(config, shared, i, mnet, hnet,
                                    hhnet=hhnet, dis=None)

            mnet, hnet, hhnet, dis = pcutils.generate_networks(config,
                shared, logger, dhandlers, device, create_dis=use_dis)
        elif dis is not None and not config.no_dis_reinit and \
                i < config.num_tasks-1:
            logger.debug('Reinitializing discriminator network ...')
            # FIXME Build a new network, as this init doesn't affect
            # batchnorm weights at the moment.
            dis.custom_init(normal_init=config.normal_init,
                            normal_std=config.std_normal_init,
                            zero_bias=True)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, config.num_tasks - 1, mnet,
                                hnet, hhnet=hhnet, dis=None)

    logger.info('During accuracies (task identity given): '
                '%s (avg: %.2f%%).' %
        (np.array2string(np.array(shared.summary['acc_task_given_during']),
                         precision=2, separator=','),
         shared.summary['acc_avg_task_given_during']))
    logger.info('Final accuracies (task identity given): '
                '%s (avg: %.2f%%).' %
        (np.array2string(np.array(shared.summary['acc_task_given']),
                         precision=2, separator=','),
         shared.summary['acc_avg_task_given']))
    logger.info('During accuracies (task identity inferred): '
                '%s (avg: %.2f%%).' %
        (np.array2string(np.array(
             shared.summary['acc_task_inferred_ent_during']),
             precision=2, separator=','),
         shared.summary['acc_avg_task_inferred_ent_during']))
    logger.info('Final accuracies (task identity inferred): '
                '%s (avg: %.2f%%).' %
        (np.array2string(np.array(shared.summary['acc_task_inferred_ent']),
                         precision=2, separator=','),
         shared.summary['acc_avg_task_inferred_ent']))

    logger.info('### Avg. during accuracy (CL scenario %d): %.4f.'
                % (config.cl_scenario, shared.summary['acc_avg_during']))
    logger.info('### Avg. final accuracy (CL scenario %d): %.4f.'
                % (config.cl_scenario, shared.summary['acc_avg_final']))

    ### Write final summary.
    shared.summary['finished'] = 1
    pmutils.save_summary_dict(config, shared, experiment)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))
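

# Illustrative sketch, added for exposition; ``pcutils.calibrate_temperature``
# above is the repository's actual implementation. Based on the comments in
# ``run``, the per-task temperatures in ``shared.softmax_temp`` are assumed
# to be used for standard temperature scaling: logits are divided by the
# temperature before the softmax, and a temperature > 1 softens the
# predicted distribution.
def _sketch_temperature_scaling(logits, temp):
    """Sketch: apply a calibrated softmax temperature to raw logits."""
    return torch.softmax(logits / temp, dim=-1)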


def run(method='avb'):
    """Run the script.

    Args:
        method (str, optional): The VI algorithm. Possible values are:

            - ``'avb'``
            - ``'ssge'``

    Returns:
        (tuple): Tuple containing:

        - **final_mse**: Final MSE for each task.
        - **during_mse**: MSE achieved directly after training on each task.
    """
    script_start = time()

    mode = 'regression_' + method

    use_dis = False  # Whether a discriminator network is used.
    if method == 'avb':
        use_dis = True

    config = train_args.parse_cmd_arguments(mode=mode)

    device, writer, logger = sutils.setup_environment(config,
        logger_name=mode)

    train_utils.backup_cli_command(config)

    if config.prior_focused:
        logger.info('Running a prior-focused CL experiment ...')
    else:
        logger.info('Learning task-specific posteriors sequentially ...')

    ### Create tasks.
    dhandlers, num_tasks = train_utils.generate_tasks(config, writer)

    ### Simple struct that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = mode
    shared.all_dhandlers = dhandlers
    shared.prior_focused = config.prior_focused

    ### Generate networks and environment.
    mnet, hnet, hhnet, dnet = pcu.generate_networks(config, shared, logger,
        dhandlers, device, create_dis=use_dis)

    # Mean and variance of the prior that is used for variational inference.
    # For prior-focused training, this prior will only be used for the
    # first task.
    pstd = np.sqrt(config.prior_variance)
    shared.prior_mean = [torch.zeros(*s).to(device)
                         for s in mnet.param_shapes]
    shared.prior_std = [pstd * torch.ones(*s).to(device)
                        for s in mnet.param_shapes]

    # Note, all MSE values are measured on a validation set if given,
    # otherwise on the training set. All samples in the validation set are
    # expected to lie inside the training range. Test samples may lie
    # outside the training range.
    # The MSE value achieved right after training on the corresponding task.
    shared.during_mse = np.ones(num_tasks) * -1.
    # The weights of the main network right after training on that task
    # (can be used to assess how close the final weights are to the
    # original ones). Note, "weights" refers to the means and variances
    # (e.g., the output of the hypernetwork).
    shared.during_weights = [-1] * num_tasks
    # MSE achieved after the most recent call of the test method.
    shared.current_mse = np.ones(num_tasks) * -1.

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    # Note, some networks have state to store, such as batch statistics for
    # batch norm. So it is wise to always checkpoint all networks, even if
    # they were constructed without weights.
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d')
    shared.ckpt_hhnet_fn = os.path.join(shared.ckpt_dir, 'hhnet_task_%d')
    #shared.ckpt_dis_fn = os.path.join(shared.ckpt_dir, 'dis_task_%d')

    ### Initialize the performance measures that should be tracked during
    ### training.
    train_utils.setup_summary_dict(config, shared, method, num_tasks, mnet,
                                   hnet=hnet, hhnet=hhnet, dis=dnet)
    logger.info('Ratio num hnet weights / num mnet weights: %f.'
                % shared.summary['aa_num_weights_hm_ratio'])
    if hhnet is not None:
        logger.info('Ratio num hyper-hnet weights / num mnet weights: %f.'
                    % shared.summary['aa_num_weights_hhm_ratio'])
    if mode == 'regression_avb' and dnet is not None:
        # A discriminator only exists for AVB.
        logger.info('Ratio num dis weights / num mnet weights: %f.'
                    % shared.summary['aa_num_weights_dm_ratio'])

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
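    # (``SummaryWriter.add_hparams`` populates TensorBoard's HPARAMS view,
    # where runs can be filtered and sorted by these values.)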
    hparams_extra_dict = {
        'num_weights_hm_ratio': shared.summary['aa_num_weights_hm_ratio'],
        'num_weights_hhm_ratio': shared.summary['aa_num_weights_hhm_ratio']
    }
    if mode == 'regression_avb':
        hparams_extra_dict['num_weights_dm_ratio'] = \
            shared.summary['aa_num_weights_dm_ratio']
    writer.add_hparams(hparam_dict={**vars(config), **hparams_extra_dict},
                       metric_dict={})

    ### Train on tasks sequentially.
    for i in range(num_tasks):
        logger.info('### Training on task %d ###' % (i + 1))
        data = dhandlers[i]

        # Train the network.
        tvi.train(i, data, mnet, hnet, hhnet, dnet, device, config, shared,
                  logger, writer, method=method)

        # Test networks.
        train_bbb.test(dhandlers[:(i + 1)], mnet, hnet, device, config,
                       shared, logger, writer, hhnet=hhnet)

        if config.train_from_scratch and i < num_tasks - 1:
            # We have to checkpoint the networks, such that we can reload
            # them for task inference later during testing.
            # Note, we only need the discriminator as a helper for training,
            # so we don't checkpoint it.
            pmu.checkpoint_nets(config, shared, i, mnet, hnet, hhnet=hhnet,
                                dis=None)

            mnet, hnet, hhnet, dnet = pcu.generate_networks(config, shared,
                logger, dhandlers, device)
        elif config.store_during_models:
            logger.info('Checkpointing current model ...')
            pmu.checkpoint_nets(config, shared, i, mnet, hnet, hhnet=hhnet,
                                dis=None)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmu.checkpoint_nets(config, shared, num_tasks - 1, mnet, hnet,
                            hhnet=hhnet, dis=None)

    logger.info('During MSE values after training each task: %s' %
                np.array2string(shared.during_mse, precision=5,
                                separator=','))
    logger.info('Final MSE values after training on all tasks: %s' %
                np.array2string(shared.current_mse, precision=5,
                                separator=','))
    logger.info('Final MSE mean %.4f (std %.4f).' %
                (shared.current_mse.mean(), shared.current_mse.std()))

    ### Write final summary.
    shared.summary['finished'] = 1
    train_utils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))

    return shared.current_mse, shared.during_mse
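

# Illustrative sketch (added for exposition, not used by the script): given
# the two arrays maintained above, the element-wise difference between the
# final and the "during" MSE is a simple per-task measure of forgetting;
# positive entries mean a task got worse after training on later tasks.
def _sketch_mse_forgetting(during_mse, current_mse):
    """Sketch: per-task MSE increase between 'during' and final testing."""
    return np.asarray(current_mse) - np.asarray(during_mse)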


def run(config, experiment='regression_mt'):
    """Run the Multitask training for the given experiment.

    Args:
        config (argparse.Namespace): Command-line arguments.
        experiment (str): Which kind of experiment should be performed?

            - ``'regression_mt'``: Regression tasks with multitask training
            - ``'gmm_mt'``: GMM data with multitask training
            - ``'split_mnist_mt'``: SplitMNIST with multitask training
            - ``'perm_mnist_mt'``: PermutedMNIST with multitask training
            - ``'cifar_resnet_mt'``: CIFAR-10/100 with multitask training
    """
    assert experiment in ['regression_mt', 'gmm_mt', 'split_mnist_mt',
                          'perm_mnist_mt', 'cifar_resnet_mt']

    script_start = time()

    device, writer, logger = sutils.setup_environment(config,
        logger_name=experiment)

    rutils.backup_cli_command(config)

    is_classification = True
    if 'regression' in experiment:
        is_classification = False

    ### Create tasks.
    if is_classification:
        dhandlers = pmutils.load_datasets(config, logger, experiment, writer)
    else:
        dhandlers, num_tasks = rutils.generate_tasks(config, writer)
        config.num_tasks = num_tasks

    ### Simple struct that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.num_trained = 0

    ### Generate network.
    mnet, _, _, _ = pcutils.generate_networks(config, shared, logger,
        dhandlers, device, create_mnet=True, create_hnet=False,
        create_hhnet=False, create_dis=False)

    if not is_classification:
        shared.during_mse = np.ones(config.num_tasks) * -1.
        # MSE achieved after the most recent call of the test method.
        shared.current_mse = np.ones(config.num_tasks) * -1.

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')

    # Initialize the softmax temperature per task with one. Might be changed
    # later on to calibrate the temperature.
    if is_classification:
        shared.softmax_temp = [torch.ones(1).to(device)
                               for _ in range(config.num_tasks)]

    ### Initialize summary.
    if is_classification:
        pcutils.setup_summary_dict(config, shared, experiment, mnet)
    else:
        rutils.setup_summary_dict(config, shared, 'mt', config.num_tasks,
                                  mnet)

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    writer.add_hparams(hparam_dict={**vars(config), **{
        'num_weights_main': shared.summary['num_weights_main']
            if is_classification else shared.summary['aa_num_weights_main']
    }}, metric_dict={})

    ### Train on all tasks.
    logger.info('### Training ###')
    # Note, since we are training on all tasks, all output heads can at all
    # times be considered as trained!
    shared.num_trained = config.num_tasks
    train(dhandlers, mnet, device, config, shared, logger, writer)

    logger.info('### Testing ###')
    test(dhandlers, mnet, device, config, shared, logger, writer,
         test_ids=None)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, config.num_tasks - 1, mnet,
                                None)

    ### Plot final classification scores.
    if is_classification:
        logger.info('Final accuracies (task identity given): '
                    '%s (avg: %.2f%%).' %
            (np.array2string(np.array(shared.summary['acc_task_given']),
                             precision=2, separator=','),
             shared.summary['acc_avg_task_given']))

    ### Plot final regression scores.
    if not is_classification:
        logger.info('Final MSE values after training on all tasks: %s' %
            np.array2string(np.array(shared.summary['aa_mse_final']),
                            precision=5, separator=','))
        logger.info('Final MSE mean %.4f.' %
                    shared.summary['aa_mse_during_mean'])

    ### Write final summary.
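    # (The ``finished`` flag presumably lets downstream tooling, e.g., a
    # hyperparameter search, distinguish completed runs from crashed ones.)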
    shared.summary['finished'] = 1
    if is_classification:
        pmutils.save_summary_dict(config, shared, experiment)
    else:
        rutils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))
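

# A minimal usage sketch (hypothetical; the repository's actual wrapper
# scripts may differ): the BBB and AVB/SSGE regression runners parse their
# own command line, whereas the EWC, multitask and classification runners
# expect an already parsed ``argparse.Namespace``, e.g.:
#
#     if __name__ == '__main__':
#         config = train_args.parse_cmd_arguments(mode='regression_mt')
#         run(config, experiment='regression_mt')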