    # Should be set in your program when the execution finished successfully.
    'finished'
]

_OUT_ARG = hpbbb._OUT_ARG

_SUMMARY_PARSER_HANDLE = hpbbb._SUMMARY_PARSER_HANDLE

_PERFORMANCE_EVAL_HANDLE = hpbbb._PERFORMANCE_EVAL_HANDLE

_PERFORMANCE_KEY = hpbbb._PERFORMANCE_KEY
assert (_PERFORMANCE_KEY is None or _PERFORMANCE_KEY in _SUMMARY_KEYWORDS)
# Whether the CSV should be sorted ascending or descending based on the
# `_PERFORMANCE_KEY`.
_PERFORMANCE_SORT_ASC = hpbbb._PERFORMANCE_SORT_ASC

# FIXME: This attribute will vanish in future releases.
# This attribute is only required by the `hpsearch_postprocessing` script.
# A function handle to the argument parser function used by the simulation
# script. The function handle should expect the list of command line options
# as its only parameter.
# Example:
# >>> from probabilistic.prob_mnist import train_args as targs
# >>> f = lambda argv: targs.parse_cmd_arguments(mode='split_mnist_bbb',
# ...                                            argv=argv)
# >>> _ARGPARSE_HANDLE = f
import probabilistic.regression.train_args as targs
_ARGPARSE_HANDLE = lambda argv: targs.parse_cmd_arguments( \
    mode='regression_avb', argv=argv)

if __name__ == '__main__':
    pass
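# Example (hypothetical sketch, not part of the original configuration):
# The attributes above are consumed by the hyperparameter-search scripts. The
# snippet below only illustrates how `_PERFORMANCE_KEY` and
# `_PERFORMANCE_SORT_ASC` could be used to rank finished runs in a results
# CSV; the file name and column layout are assumptions.
# >>> import pandas as pd
# >>> df = pd.read_csv('search_results.csv')
# >>> df = df[df['finished'] == 1]
# >>> df = df.sort_values(by=_PERFORMANCE_KEY,
# ...                     ascending=_PERFORMANCE_SORT_ASC)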
def run(): """Run the script. Returns: (tuple): Tuple containing: - **final_mse**: Final MSE for each task. - **during_mse**: MSE achieved directly after training on each task. """ script_start = time() mode = 'regression_bbb' config = train_args.parse_cmd_arguments(mode=mode) device, writer, logger = sutils.setup_environment(config, logger_name=mode) train_utils.backup_cli_command(config) ### Create tasks. dhandlers, num_tasks = train_utils.generate_tasks(config, writer) ### Generate networks. use_hnet = not config.mnet_only mnet, hnet = train_utils.generate_gauss_networks( config, logger, dhandlers, device, create_hnet=use_hnet, non_gaussian=config.mean_only) ### Simple struct, that is used to share data among functions. shared = Namespace() shared.experiment_type = mode shared.all_dhandlers = dhandlers # Mean and variance of prior that is used for variational inference. if config.mean_only: # No prior-matching can be performed. shared.prior_mean = None shared.prior_logvar = None shared.prior_std = None else: plogvar = np.log(config.prior_variance) pstd = np.sqrt(config.prior_variance) shared.prior_mean = [torch.zeros(*s).to(device) \ for s in mnet.orig_param_shapes] shared.prior_logvar = [plogvar * torch.ones(*s).to(device) \ for s in mnet.orig_param_shapes] shared.prior_std = [pstd * torch.ones(*s).to(device) \ for s in mnet.orig_param_shapes] # Note, all MSE values are measured on a validation set if given, otherwise # on the training set. All samples in the validation set are expected to # lay inside the training range. Test samples may lay outside the training # range. # The MSE value achieved right after training on the corresponding task. shared.during_mse = np.ones(num_tasks) * -1. # The weights of the main network right after training on that task # (can be used to assess how close the final weights are to the original # ones). Note, weights refer to mean and variances (e.g., the output of the # hypernetwork). shared.during_weights = [-1] * num_tasks # MSE achieved after most recent call of test method. shared.current_mse = np.ones(num_tasks) * -1. # Where to save network checkpoints? shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints') # Note, some main networks have stuff to store such as batch statistics for # batch norm. So it is wise to always checkpoint mnets as well! shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d') shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d') ### Initialize the performance measures, that should be tracked during ### training. train_utils.setup_summary_dict(config, shared, 'bbb', num_tasks, mnet, hnet=hnet) # Add hparams to tensorboard, such that the identification of runs is # easier. writer.add_hparams(hparam_dict={ **vars(config), **{ 'num_weights_main': shared.summary['aa_num_weights_main'], 'num_weights_hyper': shared.summary['aa_num_weights_hyper'], 'num_weights_ratio': shared.summary['aa_num_weights_ratio'], } }, metric_dict={}) ### Train on tasks sequentially. for i in range(num_tasks): logger.info('### Training on task %d ###' % (i + 1)) data = dhandlers[i] # Train the network. train(i, data, mnet, hnet, device, config, shared, logger, writer) ### Test networks. test(dhandlers[:(i + 1)], mnet, hnet, device, config, shared, logger, writer) if config.train_from_scratch and i < num_tasks - 1: # We have to checkpoint the networks, such that we can reload them # for task inference later during testing. 
            pmutils.checkpoint_nets(config, shared, i, mnet, hnet)

            mnet, hnet = train_utils.generate_gauss_networks(config, logger,
                dhandlers, device, create_hnet=use_hnet,
                non_gaussian=config.mean_only)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, num_tasks - 1, mnet, hnet)

    logger.info('During MSE values after training each task: %s' % \
        np.array2string(shared.during_mse, precision=5, separator=','))
    logger.info('Final MSE values after training on all tasks: %s' % \
        np.array2string(shared.current_mse, precision=5, separator=','))
    logger.info('Final MSE mean %.4f (std %.4f).'
                % (shared.current_mse.mean(), shared.current_mse.std()))

    ### Write final summary.
    shared.summary['finished'] = 1
    train_utils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))

    return shared.current_mse, shared.during_mse
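# Example (sketch of assumed usage):
# `run()` reads its configuration from the command line via
# `train_args.parse_cmd_arguments`, so the module is typically executed as a
# script, e.g.:
# >>> final_mse, during_mse = run()
# >>> print('Mean final MSE: %f' % final_mse.mean())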
def run(method='avb'):
    """Run the script.

    Args:
        method (str, optional): The VI algorithm. Possible values are:

            - ``'avb'``
            - ``'ssge'``

    Returns:
        (tuple): Tuple containing:

        - **final_mse**: Final MSE for each task.
        - **during_mse**: MSE achieved directly after training on each task.
    """
    script_start = time()

    mode = 'regression_' + method
    use_dis = False # Whether a discriminator network is used.
    if method == 'avb':
        use_dis = True

    config = train_args.parse_cmd_arguments(mode=mode)

    device, writer, logger = sutils.setup_environment(config, logger_name=mode)

    train_utils.backup_cli_command(config)

    if config.prior_focused:
        logger.info('Running a prior-focused CL experiment ...')
    else:
        logger.info('Learning task-specific posteriors sequentially ...')

    ### Create tasks.
    dhandlers, num_tasks = train_utils.generate_tasks(config, writer)

    ### Simple struct that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = mode
    shared.all_dhandlers = dhandlers
    shared.prior_focused = config.prior_focused

    ### Generate networks and environment.
    mnet, hnet, hhnet, dnet = pcu.generate_networks(config, shared, logger,
        dhandlers, device, create_dis=use_dis)

    # Mean and variance of the prior that is used for variational inference.
    # For prior-focused training, this prior will only be used for the
    # first task.
    pstd = np.sqrt(config.prior_variance)
    shared.prior_mean = [torch.zeros(*s).to(device) \
                         for s in mnet.param_shapes]
    shared.prior_std = [pstd * torch.ones(*s).to(device) \
                        for s in mnet.param_shapes]

    # Note, all MSE values are measured on a validation set if given, otherwise
    # on the training set. All samples in the validation set are expected to
    # lie inside the training range. Test samples may lie outside the training
    # range.
    # The MSE value achieved right after training on the corresponding task.
    shared.during_mse = np.ones(num_tasks) * -1.
    # The weights of the main network right after training on that task
    # (can be used to assess how close the final weights are to the original
    # ones). Note, weights refer to means and variances (e.g., the output of
    # the hypernetwork).
    shared.during_weights = [-1] * num_tasks
    # MSE achieved after the most recent call of the test method.
    shared.current_mse = np.ones(num_tasks) * -1.

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    # Note, some networks have stuff to store such as batch statistics for
    # batch norm. So it is wise to always checkpoint all networks, even if they
    # were constructed without weights.
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d')
    shared.ckpt_hhnet_fn = os.path.join(shared.ckpt_dir, 'hhnet_task_%d')
    #shared.ckpt_dis_fn = os.path.join(shared.ckpt_dir, 'dis_task_%d')

    ### Initialize the performance measures that should be tracked during
    ### training.
    train_utils.setup_summary_dict(config, shared, method, num_tasks, mnet,
        hnet=hnet, hhnet=hhnet, dis=dnet)
    logger.info('Ratio num hnet weights / num mnet weights: %f.'
                % shared.summary['aa_num_weights_hm_ratio'])
    if hhnet is not None:
        logger.info('Ratio num hyper-hnet weights / num mnet weights: %f.'
                    % shared.summary['aa_num_weights_hhm_ratio'])
    if mode == 'regression_avb' and dnet is not None:
        # A discriminator only exists for AVB.
        logger.info('Ratio num dis weights / num mnet weights: %f.'
                    % shared.summary['aa_num_weights_dm_ratio'])

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    hparams_extra_dict = {
        'num_weights_hm_ratio': shared.summary['aa_num_weights_hm_ratio'],
        'num_weights_hhm_ratio': shared.summary['aa_num_weights_hhm_ratio']
    }
    if mode == 'regression_avb':
        hparams_extra_dict['num_weights_dm_ratio'] = \
            shared.summary['aa_num_weights_dm_ratio']
    writer.add_hparams(hparam_dict={**vars(config), **hparams_extra_dict},
                       metric_dict={})

    ### Train on tasks sequentially.
    for i in range(num_tasks):
        logger.info('### Training on task %d ###' % (i + 1))
        data = dhandlers[i]

        # Train the network.
        tvi.train(i, data, mnet, hnet, hhnet, dnet, device, config, shared,
                  logger, writer, method=method)

        # Test networks.
        train_bbb.test(dhandlers[:(i + 1)], mnet, hnet, device, config, shared,
                       logger, writer, hhnet=hhnet)

        if config.train_from_scratch and i < num_tasks - 1:
            # We have to checkpoint the networks, such that we can reload them
            # for task inference later during testing.
            # Note, we only need the discriminator as a helper for training,
            # so we don't checkpoint it.
            pmu.checkpoint_nets(config, shared, i, mnet, hnet, hhnet=hhnet,
                                dis=None)

            mnet, hnet, hhnet, dnet = pcu.generate_networks(config, shared,
                logger, dhandlers, device)
        elif config.store_during_models:
            logger.info('Checkpointing current model ...')
            pmu.checkpoint_nets(config, shared, i, mnet, hnet, hhnet=hhnet,
                                dis=None)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmu.checkpoint_nets(config, shared, num_tasks - 1, mnet, hnet,
                            hhnet=hhnet, dis=None)

    logger.info('During MSE values after training each task: %s' % \
        np.array2string(shared.during_mse, precision=5, separator=','))
    logger.info('Final MSE values after training on all tasks: %s' % \
        np.array2string(shared.current_mse, precision=5, separator=','))
    logger.info('Final MSE mean %.4f (std %.4f).'
                % (shared.current_mse.mean(), shared.current_mse.std()))

    ### Write final summary.
    shared.summary['finished'] = 1
    train_utils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))

    return shared.current_mse, shared.during_mse
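# Example (sketch of assumed usage):
# The same entry point serves both VI algorithms; only the `method` argument
# differs:
# >>> final_mse, during_mse = run(method='avb')   # AVB, with discriminator.
# >>> final_mse, during_mse = run(method='ssge')  # SSGE, no discriminator.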