# Should be set in your program when the execution finished successfully.
    'finished'
]
# The following settings are taken over unchanged from the module `hpbbb`, so
# that this configuration stays in sync with it.
# NOTE(review): judging by the names, these are the command-line option that
# sets the output directory, the summary-file parser and the performance
# evaluation function -- confirm in `hpbbb`.
_OUT_ARG = hpbbb._OUT_ARG
_SUMMARY_PARSER_HANDLE = hpbbb._SUMMARY_PARSER_HANDLE
_PERFORMANCE_EVAL_HANDLE = hpbbb._PERFORMANCE_EVAL_HANDLE

# The performance key is also taken over from `hpbbb`.
_PERFORMANCE_KEY = hpbbb._PERFORMANCE_KEY
# Sanity check: if a performance key is given (i.e., it is not `None`), it has
# to be one of the entries of `_SUMMARY_KEYWORDS`.
assert (_PERFORMANCE_KEY is None or _PERFORMANCE_KEY in _SUMMARY_KEYWORDS)
# Whether the CSV should be sorted ascending or descending based on the
# `_PERFORMANCE_KEY`.
_PERFORMANCE_SORT_ASC = hpbbb._PERFORMANCE_SORT_ASC

# FIXME: This attribute will vanish in future releases.
# This attribute is only required by the `hpsearch_postprocessing` script.
# A function handle to the argument parser function used by the simulation
# script. The function handle should expect the list of command line options
# as its only parameter.
# Example:
# >>> from probabilistic.prob_mnist import train_args as targs
# >>> f = lambda argv : targs.parse_cmd_arguments(mode='split_mnist_bbb',
# ...                                             argv=argv)
# >>> _ARGPARSE_HANDLE = f
# Imported here, directly next to the handle that uses it.
import probabilistic.regression.train_args as targs
# Parses the given list of command line options `argv` with the regression
# argument parser in mode 'regression_avb'.
_ARGPARSE_HANDLE = lambda argv : targs.parse_cmd_arguments( \
    mode='regression_avb', argv=argv)

# Nothing is done when this module is executed as a script.
if __name__ == '__main__':
    pass
def run():
    """Train and evaluate the ``'regression_bbb'`` experiment.

    The command-line arguments are parsed, the environment, tasks and
    networks are set up and the tasks are then trained one after another.
    After each task, all tasks seen so far are tested. Finally, the summary
    is written and the MSE values are reported.

    Returns:
        (tuple): Tuple containing:

        - **final_mse**: MSE of each task after the most recent testing, i.e.,
          after training on all tasks.
        - **during_mse**: MSE of each task directly after training on it.
    """
    t_begin = time()

    experiment = 'regression_bbb'
    config = train_args.parse_cmd_arguments(mode=experiment)
    device, writer, logger = sutils.setup_environment(config,
                                                      logger_name=experiment)
    train_utils.backup_cli_command(config)

    ### Tasks.
    task_handlers, n_tasks = train_utils.generate_tasks(config, writer)

    ### Networks.
    with_hnet = not config.mnet_only

    def build_networks():
        # Used for the initial networks and whenever fresh networks are
        # needed because every task is trained from scratch.
        return train_utils.generate_gauss_networks(
            config, logger, task_handlers, device, create_hnet=with_hnet,
            non_gaussian=config.mean_only)

    mnet, hnet = build_networks()

    ### Container for everything that is shared among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = task_handlers

    # Prior (mean, log-variance and standard deviation) used for variational
    # inference; one tensor per entry of `mnet.orig_param_shapes`.
    if not config.mean_only:
        log_var = np.log(config.prior_variance)
        std = np.sqrt(config.prior_variance)
        shapes = mnet.orig_param_shapes
        shared.prior_mean = [
            torch.zeros(*shape).to(device) for shape in shapes
        ]
        shared.prior_logvar = [
            log_var * torch.ones(*shape).to(device) for shape in shapes
        ]
        shared.prior_std = [
            std * torch.ones(*shape).to(device) for shape in shapes
        ]
    else:
        # Without variances there is no prior to match.
        shared.prior_mean = None
        shared.prior_logvar = None
        shared.prior_std = None

    # All MSE values are measured on a validation set if one is given and on
    # the training set otherwise. Validation samples are expected to lie
    # inside the training range, test samples may lie outside of it.
    # Entries stay at -1 until the corresponding value has been recorded.
    # MSE of a task right after training on it.
    shared.during_mse = -np.ones(n_tasks)
    # Main network weights (means and variances, e.g., the hypernetwork
    # output) right after training on a task, to compare them with the final
    # weights later on.
    shared.during_weights = [-1] * n_tasks
    # MSE values from the most recent call of the test method.
    shared.current_mse = -np.ones(n_tasks)

    # Checkpoint locations. Main networks are checkpointed too, as they may
    # hold state of their own (e.g., batch norm statistics).
    ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_dir = ckpt_dir
    shared.ckpt_mnet_fn = os.path.join(ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(ckpt_dir, 'hnet_task_%d')

    ### Performance measures that are tracked during training.
    train_utils.setup_summary_dict(config, shared, 'bbb', n_tasks, mnet,
                                   hnet=hnet)

    # Store the hyperparameters in tensorboard to ease identifying the run.
    hparams = dict(vars(config))
    hparams.update({
        'num_weights_main': shared.summary['aa_num_weights_main'],
        'num_weights_hyper': shared.summary['aa_num_weights_hyper'],
        'num_weights_ratio': shared.summary['aa_num_weights_ratio'],
    })
    writer.add_hparams(hparam_dict=hparams, metric_dict={})

    ### Sequential training.
    for task_id in range(n_tasks):
        logger.info('### Training on task %d ###' % (task_id + 1))
        train(task_id, task_handlers[task_id], mnet, hnet, device, config,
              shared, logger, writer)

        # Evaluate on all tasks seen so far.
        test(task_handlers[:(task_id + 1)], mnet, hnet, device, config,
             shared, logger, writer)

        if not config.train_from_scratch or task_id >= n_tasks - 1:
            continue

        # Training from scratch: save the current networks (they are needed
        # for task inference during later testing) and start the next task
        # with fresh ones.
        pmutils.checkpoint_nets(config, shared, task_id, mnet, hnet)
        mnet, hnet = build_networks()

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, n_tasks - 1, mnet, hnet)

    def as_text(values):
        return np.array2string(values, precision=5, separator=',')

    final_mse = shared.current_mse
    logger.info('During MSE values after training each task: %s' %
                as_text(shared.during_mse))
    logger.info('Final MSE values after training on all tasks: %s' %
                as_text(final_mse))
    logger.info('Final MSE mean %.4f (std %.4f).' %
                (final_mse.mean(), final_mse.std()))

    ### Final summary.
    shared.summary['finished'] = 1
    train_utils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.' %
                (time() - t_begin))

    return shared.current_mse, shared.during_mse
# Example #3
# 0
def run(method='avb'):
    """Run the script.

    The command-line arguments are parsed in mode ``'regression_' + method``,
    the environment, tasks and networks are set up and the tasks are then
    trained one after another. After each task, all tasks seen so far are
    tested. Finally, the summary is written and the MSE values are reported.

    Args:
        method (str, optional): The VI algorithm. Possible values are:

            - ``'avb'``
            - ``'ssge'``

            A discriminator network is only requested for ``'avb'``.

    Returns:
        (tuple): Tuple containing:

        - **final_mse**: Final MSE for each task.
        - **during_mse**: MSE achieved directly after training on each task.
    """
    script_start = time()
    mode = 'regression_' + method
    # Whether a discriminator network is used (only the case for AVB).
    use_dis = method == 'avb'
    config = train_args.parse_cmd_arguments(mode=mode)

    device, writer, logger = sutils.setup_environment(config, logger_name=mode)

    train_utils.backup_cli_command(config)

    if config.prior_focused:
        logger.info('Running a prior-focused CL experiment ...')
    else:
        logger.info('Learning task-specific posteriors sequentially ...')

    ### Create tasks.
    dhandlers, num_tasks = train_utils.generate_tasks(config, writer)

    ### Simple struct, that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = mode
    shared.all_dhandlers = dhandlers
    shared.prior_focused = config.prior_focused

    ### Generate networks and environment
    mnet, hnet, hhnet, dnet = pcu.generate_networks(config,
                                                    shared,
                                                    logger,
                                                    dhandlers,
                                                    device,
                                                    create_dis=use_dis)

    # Mean and variance of prior that is used for variational inference.
    # For a prior-focused training, this prior will only be used for the
    # first task.
    pstd = np.sqrt(config.prior_variance)
    shared.prior_mean = [torch.zeros(*s).to(device) \
                         for s in mnet.param_shapes]
    shared.prior_std = [pstd * torch.ones(*s).to(device) \
                        for s in mnet.param_shapes]

    # Note, all MSE values are measured on a validation set if given, otherwise
    # on the training set. All samples in the validation set are expected to
    # lie inside the training range. Test samples may lie outside the training
    # range.
    # The MSE value achieved right after training on the corresponding task.
    # Entries are initialized with -1.
    shared.during_mse = np.ones(num_tasks) * -1.
    # The weights of the main network right after training on that task
    # (can be used to assess how close the final weights are to the original
    # ones). Note, weights refer to mean and variances (e.g., the output of the
    # hypernetwork).
    shared.during_weights = [-1] * num_tasks
    # MSE achieved after most recent call of test method.
    shared.current_mse = np.ones(num_tasks) * -1.

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    # Note, some networks have stuff to store such as batch statistics for
    # batch norm. So it is wise to always checkpoint all networks, even if they
    # were constructed without weights.
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d')
    shared.ckpt_hhnet_fn = os.path.join(shared.ckpt_dir, 'hhnet_task_%d')
    #shared.ckpt_dis_fn = os.path.join(shared.ckpt_dir, 'dis_task_%d')

    ### Initialize the performance measures, that should be tracked during
    ### training.
    train_utils.setup_summary_dict(config,
                                   shared,
                                   method,
                                   num_tasks,
                                   mnet,
                                   hnet=hnet,
                                   hhnet=hhnet,
                                   dis=dnet)
    logger.info('Ratio num hnet weights / num mnet weights: %f.' %
                shared.summary['aa_num_weights_hm_ratio'])
    if hhnet is not None:
        logger.info('Ratio num hyper-hnet weights / num mnet weights: %f.' %
                    shared.summary['aa_num_weights_hhm_ratio'])
    if mode == 'regression_avb' and dnet is not None:
        # A discriminator only exists for AVB.
        logger.info('Ratio num dis weights / num mnet weights: %f.' %
                    shared.summary['aa_num_weights_dm_ratio'])

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    hparams_extra_dict = {
        'num_weights_hm_ratio': shared.summary['aa_num_weights_hm_ratio'],
        'num_weights_hhm_ratio': shared.summary['aa_num_weights_hhm_ratio']
    }
    if mode == 'regression_avb':
        hparams_extra_dict['num_weights_dm_ratio'] = \
            shared.summary['aa_num_weights_dm_ratio']
    writer.add_hparams(hparam_dict={
        **vars(config),
        **hparams_extra_dict
    },
                       metric_dict={})

    ### Train on tasks sequentially.
    for i in range(num_tasks):
        logger.info('### Training on task %d ###' % (i + 1))
        data = dhandlers[i]

        # Train the network.
        tvi.train(i,
                  data,
                  mnet,
                  hnet,
                  hhnet,
                  dnet,
                  device,
                  config,
                  shared,
                  logger,
                  writer,
                  method=method)

        # Test networks on all tasks seen so far.
        train_bbb.test(dhandlers[:(i + 1)],
                       mnet,
                       hnet,
                       device,
                       config,
                       shared,
                       logger,
                       writer,
                       hhnet=hhnet)

        if config.train_from_scratch and i < num_tasks - 1:
            # We have to checkpoint the networks, such that we can reload them
            # for task inference later during testing.
            # Note, we only need the discriminator as helper for training,
            # so we don't checkpoint it.
            pmu.checkpoint_nets(config,
                                shared,
                                i,
                                mnet,
                                hnet,
                                hhnet=hhnet,
                                dis=None)

            # Create fresh networks for the next task. `create_dis` is passed
            # exactly as for the initial networks, such that the existence of
            # a discriminator depends on the chosen method and not on the
            # default value of `generate_networks`.
            mnet, hnet, hhnet, dnet = pcu.generate_networks(
                config, shared, logger, dhandlers, device, create_dis=use_dis)

        elif config.store_during_models:
            logger.info('Checkpointing current model ...')
            pmu.checkpoint_nets(config,
                                shared,
                                i,
                                mnet,
                                hnet,
                                hhnet=hhnet,
                                dis=None)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmu.checkpoint_nets(config,
                            shared,
                            num_tasks - 1,
                            mnet,
                            hnet,
                            hhnet=hhnet,
                            dis=None)

    logger.info('During MSE values after training each task: %s' % \
                np.array2string(shared.during_mse, precision=5, separator=','))
    logger.info('Final MSE values after training on all tasks: %s' % \
                np.array2string(shared.current_mse, precision=5, separator=','))
    logger.info('Final MSE mean %.4f (std %.4f).' %
                (shared.current_mse.mean(), shared.current_mse.std()))

    ### Write final summary.
    shared.summary['finished'] = 1
    train_utils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.' %
                (time() - script_start))

    return shared.current_mse, shared.during_mse