def run():
    """Run the Bayes-by-Backprop regression experiment.

    The command-line arguments are parsed for the mode ``'regression_bbb'``.
    Afterwards, tasks and networks are generated and the tasks are learned
    one after another. After training on a task, the networks are tested on
    all tasks seen so far.

    Returns:
        (tuple): Tuple containing:

        - **final_mse**: Final MSE for each task.
        - **during_mse**: MSE achieved directly after training on each task.
    """
    script_start = time()

    mode = 'regression_bbb'
    config = train_args.parse_cmd_arguments(mode=mode)
    device, writer, logger = sutils.setup_environment(config, logger_name=mode)
    train_utils.backup_cli_command(config)

    ### Create tasks.
    dhandlers, num_tasks = train_utils.generate_tasks(config, writer)

    ### Generate networks.
    use_hnet = not config.mnet_only

    def _build_networks():
        # Used for the initial networks and whenever training restarts from
        # scratch for a new task.
        return train_utils.generate_gauss_networks(
            config, logger, dhandlers, device, create_hnet=use_hnet,
            non_gaussian=config.mean_only)

    mnet, hnet = _build_networks()

    ### Simple struct, that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = mode
    shared.all_dhandlers = dhandlers

    # Mean and variance of the prior that is used for variational inference.
    if config.mean_only:
        # No prior-matching can be performed.
        shared.prior_mean = shared.prior_logvar = shared.prior_std = None
    else:
        plogvar = np.log(config.prior_variance)
        pstd = np.sqrt(config.prior_variance)
        shapes = mnet.orig_param_shapes
        shared.prior_mean = [torch.zeros(*shape).to(device)
                             for shape in shapes]
        shared.prior_logvar = [plogvar * torch.ones(*shape).to(device)
                               for shape in shapes]
        shared.prior_std = [pstd * torch.ones(*shape).to(device)
                            for shape in shapes]

    # Note, all MSE values are measured on a validation set if given,
    # otherwise on the training set. All samples in the validation set are
    # expected to lie inside the training range. Test samples may lie outside
    # the training range.
    # The value -1 marks entries that have not been filled yet.
    # MSE achieved right after training on the corresponding task.
    shared.during_mse = np.full(num_tasks, -1.)
    # Weights of the main network right after training on a task (can be used
    # to assess how close the final weights are to the original ones). Note,
    # weights refer to means and variances (e.g., the hypernetwork output).
    shared.during_weights = [-1 for _ in range(num_tasks)]
    # MSE achieved after the most recent call of the test method.
    shared.current_mse = np.full(num_tasks, -1.)

    # Where to save network checkpoints?
    # Note, some main networks have stuff to store such as batch statistics
    # for batch norm. So it is wise to always checkpoint mnets as well!
    ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_dir = ckpt_dir
    shared.ckpt_mnet_fn = os.path.join(ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(ckpt_dir, 'hnet_task_%d')

    ### Initialize the performance measures that are tracked during training.
    train_utils.setup_summary_dict(config, shared, 'bbb', num_tasks, mnet,
                                   hnet=hnet)

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    hparams = dict(vars(config))
    hparams.update({
        'num_weights_main': shared.summary['aa_num_weights_main'],
        'num_weights_hyper': shared.summary['aa_num_weights_hyper'],
        'num_weights_ratio': shared.summary['aa_num_weights_ratio'],
    })
    writer.add_hparams(hparam_dict=hparams, metric_dict={})

    ### Train on tasks sequentially.
    last_task = num_tasks - 1
    for task_id in range(num_tasks):
        logger.info('### Training on task %d ###' % (task_id + 1))

        # Train the network.
        train(task_id, dhandlers[task_id], mnet, hnet, device, config, shared,
              logger, writer)

        ### Test networks on all tasks seen so far.
        seen_dhandlers = dhandlers[:(task_id + 1)]
        test(seen_dhandlers, mnet, hnet, device, config, shared, logger,
             writer)

        if not (config.train_from_scratch and task_id < last_task):
            continue

        # We have to checkpoint the networks, such that we can reload them
        # for task inference later during testing.
        pmutils.checkpoint_nets(config, shared, task_id, mnet, hnet)
        mnet, hnet = _build_networks()

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, last_task, mnet, hnet)

    during_str = np.array2string(shared.during_mse, precision=5,
                                 separator=',')
    final_str = np.array2string(shared.current_mse, precision=5,
                                separator=',')
    logger.info('During MSE values after training each task: %s' % during_str)
    logger.info('Final MSE values after training on all tasks: %s' %
                final_str)
    final_mean = shared.current_mse.mean()
    final_std = shared.current_mse.std()
    logger.info('Final MSE mean %.4f (std %.4f).' % (final_mean, final_std))

    ### Write final summary.
    shared.summary['finished'] = 1
    train_utils.save_summary_dict(config, shared)

    writer.close()

    elapsed = time() - script_start
    logger.info('Program finished successfully in %f sec.' % elapsed)

    return shared.current_mse, shared.during_mse
def run(config, experiment='regression_ewc'):
    """Run the EWC training for the given experiment.

    The tasks are learned sequentially with a single main network. After
    training on a task, the network is tested on the tasks seen so far. For
    classification experiments, the run is aborted with exit code 1 if the
    accuracy right after training a task (except the last one) falls below
    the configured criterion.

    Args:
        config (argparse.Namespace): Command-line arguments. For regression
            experiments, the attribute ``num_tasks`` is set by this function.
        experiment (str): Which kind of experiment should be performed?

            - ``'regression_ewc'``: Regression tasks with EWC
            - ``'gmm_ewc'``: GMM Data with EWC
            - ``'split_mnist_ewc'``: SplitMNIST with EWC
            - ``'perm_mnist_ewc'``: PermutedMNIST with EWC
            - ``'cifar_resnet_ewc'``: CIFAR-10/100 with EWC

    Raises:
        SystemExit: If a classification task does not reach its required
            "during" accuracy.
    """
    assert experiment in ['regression_ewc', 'gmm_ewc', 'split_mnist_ewc',
                          'perm_mnist_ewc', 'cifar_resnet_ewc']

    script_start = time()

    device, writer, logger = sutils.setup_environment(config,
        logger_name=experiment)

    rutils.backup_cli_command(config)

    is_classification = 'regression' not in experiment

    ### Create tasks.
    if is_classification:
        dhandlers = pmutils.load_datasets(config, logger, experiment, writer)
    else:
        dhandlers, num_tasks = rutils.generate_tasks(config, writer)
        config.num_tasks = num_tasks

    ### Simple struct, that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.num_trained = 0

    def _generate_mnet():
        # EWC only needs a main network. The flag `create_hhnet` must be a
        # boolean; previously the (truthy) list of data handlers was passed,
        # even though the hyper-hypernetwork return value is discarded.
        net, _, _, _ = pcutils.generate_networks(config, shared, logger,
            dhandlers, device, create_mnet=True, create_hnet=False,
            create_hhnet=False, create_dis=False)
        return net

    def _log_accuracies(prefix, accs_key, avg_key):
        # Log the per-task accuracies and their average stored in the summary.
        logger.info('%s: %s (avg: %.2f%%).' % \
            (prefix,
             np.array2string(np.array(shared.summary[accs_key]),
                             precision=2, separator=','),
             shared.summary[avg_key]))

    ### Generate network.
    mnet = _generate_mnet()

    if not is_classification:
        # The value -1 marks entries that have not been filled yet.
        # MSE achieved right after training on the corresponding task.
        shared.during_mse = np.ones(config.num_tasks) * -1.
        # MSE achieved after most recent call of test method.
        shared.current_mse = np.ones(config.num_tasks) * -1.
    # The weights of the main network right after training on that task
    # (can be used to assess how close the final weights are to the original
    # ones).
    shared.during_weights = [-1] * config.num_tasks

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')

    # Initialize the softmax temperature per-task with one. Might be changed
    # later on to calibrate the temperature.
    if is_classification:
        shared.softmax_temp = [torch.ones(1).to(device) \
                               for _ in range(config.num_tasks)]

    ### Initialize summary.
    if is_classification:
        pcutils.setup_summary_dict(config, shared, experiment, mnet)
        num_weights_key = 'num_weights_main'
    else:
        rutils.setup_summary_dict(config, shared, 'ewc', config.num_tasks, mnet)
        # The regression summary uses prefixed keys.
        num_weights_key = 'aa_num_weights_main'

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    writer.add_hparams(hparam_dict={
        **vars(config),
        'num_weights_main': shared.summary[num_weights_key],
    }, metric_dict={})

    if is_classification:
        during_acc_criterion = pmutils.parse_performance_criterion(config,
            shared, logger)

    # A value of -1 (or a missing option) means that all tasks seen so far
    # are tested after every task.
    full_test_interval = getattr(config, 'full_test_interval', -1)

    ### Train on tasks sequentially.
    last_task = config.num_tasks - 1
    for i in range(config.num_tasks):
        logger.info('### Training on task %d ###' % (i+1))
        data = dhandlers[i]

        # Train the network.
        shared.num_trained += 1
        train(i, data, mnet, device, config, shared, logger, writer)

        ### Test networks.
        test_ids = None # Test on all tasks.
        if full_test_interval != -1:
            is_full_test = i == last_task or \
                (i > 0 and i % full_test_interval == 0)
            if not is_full_test:
                test_ids = [i] # Only test on current task.
        test(dhandlers[:(i+1)], mnet, device, config, shared, logger, writer,
             test_ids=test_ids)

        ### Check if last task got "acceptable" accuracy ###
        if is_classification:
            curr_dur_acc = shared.summary['acc_task_given_during'][i]
            if i < last_task and during_acc_criterion[i] != -1 \
                    and during_acc_criterion[i] > curr_dur_acc:
                logger.error('During accuracy of task %d too small (%f < %f).'
                             % (i+1, curr_dur_acc, during_acc_criterion[i]))
                logger.error('Training of future tasks will be skipped')
                writer.close()
                # The builtin `exit` is injected by the `site` module and is
                # meant for interactive use; raise `SystemExit` directly.
                raise SystemExit(1)

        if config.train_from_scratch and i < last_task:
            # We have to checkpoint the networks, such that we can reload them
            # for task inference later during testing.
            pmutils.checkpoint_nets(config, shared, i, mnet, None)

            mnet = _generate_mnet()

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, last_task, mnet, None)

    ### Log final classification scores.
    if is_classification:
        _log_accuracies('During accuracies (task identity given)',
                        'acc_task_given_during', 'acc_avg_task_given_during')
        _log_accuracies('Final accuracies (task identity given)',
                        'acc_task_given', 'acc_avg_task_given')

    if is_classification and config.cl_scenario == 3 and config.split_head_cl3:
        _log_accuracies('During accuracies (task identity inferred)',
                        'acc_task_inferred_ent_during',
                        'acc_avg_task_inferred_ent_during')
        _log_accuracies('Final accuracies (task identity inferred)',
                        'acc_task_inferred_ent', 'acc_avg_task_inferred_ent')

    ### Log final regression scores.
    if not is_classification:
        final_mse = np.array(shared.summary['aa_mse_final'])
        logger.info('During MSE values after training each task: %s' % \
            np.array2string(np.array(shared.summary['aa_mse_during']),
                            precision=5, separator=','))
        logger.info('Final MSE values after training on all tasks: %s' % \
            np.array2string(final_mse, precision=5, separator=','))
        # Report the mean of the *final* MSE values. Previously, the mean of
        # the "during" values was printed under this label.
        logger.info('Final MSE mean %.4f.' % final_mse.mean())

    ### Write final summary.
    shared.summary['finished'] = 1
    if is_classification:
        pmutils.save_summary_dict(config, shared, experiment)
    else:
        rutils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))
Beispiel #3
0
def run(config, experiment='split_mnist_avb'):
    """Run the training.

    The variational inference method (AVB or SSGE) and whether a
    prior-focused approach is used (suffix ``_pf``) are derived from the
    experiment name. Tasks are learned sequentially and the networks are
    tested after each task. The run is aborted with exit code 1 if the
    accuracy right after training a task (except the last one) falls below
    the configured criterion.

    Args:
        config (argparse.Namespace): Command-line arguments.
        experiment (str): Which kind of experiment should be performed?

            - "gmm_avb": Synthetic GMM data with Posterior Replay via AVB
            - "gmm_avb_pf": Synthetic GMM data with Prior-Focused CL via AVB
            - "split_mnist_avb": Split MNIST with Posterior Replay via AVB
            - "perm_mnist_avb": Permuted MNIST with Posterior Replay via AVB
            - "split_mnist_avb_pf": Split MNIST with Prior-Focused CL via AVB
            - "perm_mnist_avb_pf": Permuted MNIST with Prior-Focused CL via AVB
            - "cifar_zenke_avb": CIFAR-10/100 with Posterior Replay using a
              ZenkeNet and AVB
            - "cifar_resnet_avb": CIFAR-10/100 with Posterior Replay using a
              Resnet and AVB
            - "cifar_zenke_avb_pf": CIFAR-10/100 with Prior-Focused CL using a
              ZenkeNet and AVB
            - "cifar_resnet_avb_pf": CIFAR-10/100 with Prior-Focused CL using a
              Resnet and AVB
            - "gmm_ssge": Synthetic GMM data with Posterior Replay via SSGE
            - "gmm_ssge_pf": Synthetic GMM data with Prior-Focused CL via SSGE
            - "split_mnist_ssge": Split MNIST with Posterior Replay via SSGE
            - "perm_mnist_ssge": Permuted MNIST with Posterior Replay via SSGE
            - "split_mnist_ssge_pf": Split MNIST with Prior-Focused CL via SSGE
            - "perm_mnist_ssge_pf": Permuted MNIST with Prior-Focused CL via
              SSGE
            - "cifar_resnet_ssge": CIFAR-10/100 with Posterior Replay using a
              Resnet and SSGE
            - "cifar_resnet_ssge_pf": CIFAR-10/100 with Prior-Focused CL using a
              Resnet and SSGE

    Raises:
        SystemExit: If a task does not reach its required "during" accuracy.
    """
    assert experiment in [
        'gmm_avb', 'gmm_avb_pf', 'split_mnist_avb', 'split_mnist_avb_pf',
        'perm_mnist_avb', 'perm_mnist_avb_pf', 'cifar_zenke_avb',
        'cifar_zenke_avb_pf', 'cifar_resnet_avb', 'cifar_resnet_avb_pf',
        'gmm_ssge', 'gmm_ssge_pf', 'split_mnist_ssge', 'split_mnist_ssge_pf',
        'perm_mnist_ssge', 'perm_mnist_ssge_pf', 'cifar_resnet_ssge',
        'cifar_resnet_ssge_pf'
    ]

    script_start = time()

    if 'avb' in experiment:
        method = 'avb'
        use_dis = True  # whether a discriminator network is used
    elif 'ssge' in experiment:
        method = 'ssge'
        use_dis = False

    device, writer, logger = sutils.setup_environment(
        config, logger_name=experiment + 'logger')

    rutils.backup_cli_command(config)

    prior_focused_cl = experiment.endswith('pf')
    if prior_focused_cl:
        logger.info('Running a prior-focused CL experiment ...')
    else:
        logger.info('Learning task-specific posteriors sequentially ...')

    ### Create tasks.
    dhandlers = pmutils.load_datasets(config, logger, experiment, writer)

    ### Simple struct, that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.prior_focused = prior_focused_cl

    ### Generate networks.
    mnet, hnet, hhnet, dis = pcutils.generate_networks(
        config, shared, logger, dhandlers, device, create_dis=use_dis)
    if method == 'ssge':
        assert dis is None

    ### Add more information to shared.
    # Mean and standard deviation of the prior that is used for variational
    # inference. For a prior-focused training, this prior will only be used
    # for the first task.
    pstd = np.sqrt(config.prior_variance)
    shared.prior_mean = [torch.zeros(*s).to(device)
                         for s in mnet.param_shapes]
    shared.prior_std = [pstd * torch.ones(*s).to(device)
                        for s in mnet.param_shapes]

    # The output weights of the hyper-hyper network right after training on
    # a task (can be used to assess how close the final weights are to the
    # original ones).
    shared.during_weights = [-1] * config.num_tasks if hhnet is not None \
        else None

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    # Note, some networks have stuff to store such as batch statistics for
    # batch norm. So it is wise to always checkpoint all networks, even if they
    # were constructed without weights. The discriminator is never
    # checkpointed in this function, so it needs no filename template.
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d')
    shared.ckpt_hhnet_fn = os.path.join(shared.ckpt_dir, 'hhnet_task_%d')

    # Initialize the softmax temperature per-task with one. Might be changed
    # later on to calibrate the temperature.
    shared.softmax_temp = [torch.ones(1).to(device)
                           for _ in range(config.num_tasks)]
    shared.num_trained = 0

    # Setup coresets iff regularization on all tasks is allowed.
    if config.coreset_size != -1 and config.past_and_future_coresets:
        for i in range(config.num_tasks):
            # NOTE(review): `method` is hard-coded to 'avb' here, also for
            # SSGE experiments - confirm that this is intended.
            pmutils.update_coreset(config, shared, i, dhandlers[i], None, None,
                                   device, logger, None, hhnet=None,
                                   method='avb')

    ### Initialize summary.
    pcutils.setup_summary_dict(config, shared, experiment, mnet, hnet=hnet,
                               hhnet=hhnet, dis=dis)
    logger.info('Ratio num hnet weights / num mnet weights: %f.' %
                shared.summary['num_weights_hm_ratio'])
    if 'num_weights_hhm_ratio' in shared.summary:
        logger.info('Ratio num hyper-hnet weights / num mnet weights: %f.' %
                    shared.summary['num_weights_hhm_ratio'])
    if method == 'avb':
        logger.info('Ratio num dis weights / num mnet weights: %f.' %
                    shared.summary['num_weights_dm_ratio'])

    # Add hparams to tensorboard, such that the identification of runs is
    # easier. Optional ratios are only added if they are part of the summary.
    hparams_extra_dict = {
        'num_weights_hm_ratio': shared.summary['num_weights_hm_ratio'],
    }
    for ratio_key in ('num_weights_dm_ratio', 'num_weights_hhm_ratio'):
        if ratio_key in shared.summary:
            hparams_extra_dict[ratio_key] = shared.summary[ratio_key]
    writer.add_hparams(hparam_dict={**vars(config), **hparams_extra_dict},
                       metric_dict={})

    during_acc_criterion = pmutils.parse_performance_criterion(
        config, shared, logger)

    def _checkpoint(task_id):
        # Note, we only need the discriminator as helper for training, so we
        # don't checkpoint it.
        pmutils.checkpoint_nets(config, shared, task_id, mnet, hnet,
                                hhnet=hhnet, dis=None)

    def _log_accuracies(prefix, accs_key, avg_key):
        # Log the per-task accuracies and their average stored in the summary.
        logger.info('%s: %s (avg: %.2f%%).' % \
            (prefix,
             np.array2string(np.array(shared.summary[accs_key]),
                             precision=2, separator=','),
             shared.summary[avg_key]))

    ### Train on tasks sequentially.
    last_task = config.num_tasks - 1
    for i in range(config.num_tasks):
        logger.info('### Training on task %d ###' % (i + 1))
        data = dhandlers[i]

        # Train the network.
        shared.num_trained += 1
        use_distillation = config.distill_iter != -1
        if use_distillation:
            assert hhnet is not None
        # With distillation, the hypernetwork is trained without the
        # hyper-hypernetwork and distilled into it afterwards.
        tvi.train(i, data, mnet, hnet, None if use_distillation else hhnet,
                  dis, device, config, shared, logger, writer, method=method)

        if use_distillation:
            # Distill `hnet` into `hhnet`.
            train_bbb.distill_net(i, data, mnet, hnet, hhnet, device, config,
                                  shared, logger, writer)

            # Create a new main network before training the next task.
            # NOTE(review): calibration and testing below already operate on
            # these freshly generated networks - confirm that this is
            # intended.
            mnet, hnet, _, _ = pcutils.generate_networks(
                config, shared, logger, dhandlers, device, create_dis=False,
                create_hhnet=False)

        ### Temperature Calibration.
        if config.calibrate_temp:
            pcutils.calibrate_temperature(i, data, mnet, hnet, hhnet, device,
                                          config, shared, logger, writer)

        ### Test networks.
        test_ids = None  # Test on all tasks.
        if config.full_test_interval != -1:
            is_full_test = i == last_task or \
                (i > 0 and i % config.full_test_interval == 0)
            if not is_full_test:
                test_ids = [i]  # Only test on current task.
        tvi.test(dhandlers[:(i + 1)], mnet, hnet, hhnet, device, config,
                 shared, logger, writer, test_ids=test_ids, method=method)

        ### Check if last task got "acceptable" accuracy ###
        curr_dur_acc = shared.summary['acc_task_given_during'][i]
        if i < last_task and during_acc_criterion[i] != -1 \
                and during_acc_criterion[i] > curr_dur_acc:
            logger.error('During accuracy of task %d too small (%f < %f).' % \
                         (i+1, curr_dur_acc, during_acc_criterion[i]))
            logger.error('Training of future tasks will be skipped')
            writer.close()
            # The builtin `exit` is injected by the `site` module and is
            # meant for interactive use; raise `SystemExit` directly.
            raise SystemExit(1)

        if config.train_from_scratch and i < last_task:
            # We have to checkpoint the networks, such that we can reload them
            # for task inference later during testing.
            _checkpoint(i)

            mnet, hnet, hhnet, dis = pcutils.generate_networks(
                config, shared, logger, dhandlers, device, create_dis=use_dis)

        elif dis is not None and not config.no_dis_reinit and i < last_task:
            logger.debug('Reinitializing discriminator network ...')
            # FIXME Build a new network as this init doesn't affect batchnorm
            # weights atm.
            dis.custom_init(normal_init=config.normal_init,
                            normal_std=config.std_normal_init,
                            zero_bias=True)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        _checkpoint(last_task)

    _log_accuracies('During accuracies (task identity given)',
                    'acc_task_given_during', 'acc_avg_task_given_during')
    _log_accuracies('Final accuracies (task identity given)',
                    'acc_task_given', 'acc_avg_task_given')
    _log_accuracies('During accuracies (task identity inferred)',
                    'acc_task_inferred_ent_during',
                    'acc_avg_task_inferred_ent_during')
    _log_accuracies('Final accuracies (task identity inferred)',
                    'acc_task_inferred_ent', 'acc_avg_task_inferred_ent')

    logger.info('### Avg. during accuracy (CL scenario %d): %.4f.' %
                (config.cl_scenario, shared.summary['acc_avg_during']))
    logger.info('### Avg. final accuracy (CL scenario %d): %.4f.' %
                (config.cl_scenario, shared.summary['acc_avg_final']))

    ### Write final summary.
    shared.summary['finished'] = 1
    pmutils.save_summary_dict(config, shared, experiment)

    writer.close()

    logger.info('Program finished successfully in %f sec.' %
                (time() - script_start))
Beispiel #4
0
def run(method='avb'):
    """Run the script.

    The command-line arguments are parsed for the mode
    ``'regression_' + method``. Tasks are learned sequentially and the
    networks are tested on all tasks seen so far after each task.

    Args:
        method (str, optional): The VI algorithm. Possible values are:

            - ``'avb'``
            - ``'ssge'``

    Returns:
        (tuple): Tuple containing:

        - **final_mse**: Final MSE for each task.
        - **during_mse**: MSE achieved directly after training on each task.
    """
    script_start = time()
    mode = 'regression_' + method
    # Whether a discriminator network is used (only needed for AVB).
    use_dis = method == 'avb'
    config = train_args.parse_cmd_arguments(mode=mode)

    device, writer, logger = sutils.setup_environment(config, logger_name=mode)

    train_utils.backup_cli_command(config)

    if config.prior_focused:
        logger.info('Running a prior-focused CL experiment ...')
    else:
        logger.info('Learning task-specific posteriors sequentially ...')

    ### Create tasks.
    dhandlers, num_tasks = train_utils.generate_tasks(config, writer)

    ### Simple struct, that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = mode
    shared.all_dhandlers = dhandlers
    shared.prior_focused = config.prior_focused

    def _generate_networks():
        # Always forward `create_dis`, such that networks that are recreated
        # when training from scratch match the initially generated ones
        # (i.e., no discriminator is requested for SSGE).
        return pcu.generate_networks(config, shared, logger, dhandlers,
                                     device, create_dis=use_dis)

    def _checkpoint(task_id):
        # Note, we only need the discriminator as helper for training, so we
        # don't checkpoint it.
        pmu.checkpoint_nets(config, shared, task_id, mnet, hnet, hhnet=hhnet,
                            dis=None)

    ### Generate networks and environment
    mnet, hnet, hhnet, dnet = _generate_networks()

    # Mean and standard deviation of the prior that is used for variational
    # inference. For a prior-focused training, this prior will only be used
    # for the first task.
    pstd = np.sqrt(config.prior_variance)
    shared.prior_mean = [torch.zeros(*s).to(device)
                         for s in mnet.param_shapes]
    shared.prior_std = [pstd * torch.ones(*s).to(device)
                        for s in mnet.param_shapes]

    # Note, all MSE values are measured on a validation set if given, otherwise
    # on the training set. All samples in the validation set are expected to
    # lie inside the training range. Test samples may lie outside the training
    # range.
    # The value -1 marks entries that have not been filled yet.
    # The MSE value achieved right after training on the corresponding task.
    shared.during_mse = np.ones(num_tasks) * -1.
    # The weights of the main network right after training on that task
    # (can be used to assess how close the final weights are to the original
    # ones). Note, weights refer to mean and variances (e.g., the output of the
    # hypernetwork).
    shared.during_weights = [-1] * num_tasks
    # MSE achieved after most recent call of test method.
    shared.current_mse = np.ones(num_tasks) * -1.

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    # Note, some networks have stuff to store such as batch statistics for
    # batch norm. So it is wise to always checkpoint all networks, even if they
    # were constructed without weights. The discriminator is never
    # checkpointed in this function, so it needs no filename template.
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d')
    shared.ckpt_hhnet_fn = os.path.join(shared.ckpt_dir, 'hhnet_task_%d')

    ### Initialize the performance measures, that should be tracked during
    ### training.
    train_utils.setup_summary_dict(config, shared, method, num_tasks, mnet,
                                   hnet=hnet, hhnet=hhnet, dis=dnet)
    logger.info('Ratio num hnet weights / num mnet weights: %f.' %
                shared.summary['aa_num_weights_hm_ratio'])
    if hhnet is not None:
        logger.info('Ratio num hyper-hnet weights / num mnet weights: %f.' %
                    shared.summary['aa_num_weights_hhm_ratio'])
    if mode == 'regression_avb' and dnet is not None:
        # A discriminator only exists for AVB.
        logger.info('Ratio num dis weights / num mnet weights: %f.' %
                    shared.summary['aa_num_weights_dm_ratio'])

    # Add hparams to tensorboard, such that the identification of runs is
    # easier. The optional ratios are only added if the summary contains
    # them, such that a missing network cannot cause a `KeyError` here.
    hparams_extra_dict = {
        'num_weights_hm_ratio': shared.summary['aa_num_weights_hm_ratio'],
    }
    if 'aa_num_weights_hhm_ratio' in shared.summary:
        hparams_extra_dict['num_weights_hhm_ratio'] = \
            shared.summary['aa_num_weights_hhm_ratio']
    if mode == 'regression_avb' and \
            'aa_num_weights_dm_ratio' in shared.summary:
        hparams_extra_dict['num_weights_dm_ratio'] = \
            shared.summary['aa_num_weights_dm_ratio']
    writer.add_hparams(hparam_dict={**vars(config), **hparams_extra_dict},
                       metric_dict={})

    ### Train on tasks sequentially.
    for i in range(num_tasks):
        logger.info('### Training on task %d ###' % (i + 1))
        data = dhandlers[i]

        # Train the network.
        tvi.train(i, data, mnet, hnet, hhnet, dnet, device, config, shared,
                  logger, writer, method=method)

        # Test networks on all tasks seen so far.
        train_bbb.test(dhandlers[:(i + 1)], mnet, hnet, device, config,
                       shared, logger, writer, hhnet=hhnet)

        if config.train_from_scratch and i < num_tasks - 1:
            # We have to checkpoint the networks, such that we can reload them
            # for task inference later during testing.
            _checkpoint(i)

            mnet, hnet, hhnet, dnet = _generate_networks()

        elif config.store_during_models:
            logger.info('Checkpointing current model ...')
            _checkpoint(i)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        _checkpoint(num_tasks - 1)

    logger.info('During MSE values after training each task: %s' % \
                np.array2string(shared.during_mse, precision=5, separator=','))
    logger.info('Final MSE values after training on all tasks: %s' % \
                np.array2string(shared.current_mse, precision=5, separator=','))
    logger.info('Final MSE mean %.4f (std %.4f).' %
                (shared.current_mse.mean(), shared.current_mse.std()))

    ### Write final summary.
    shared.summary['finished'] = 1
    train_utils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.' %
                (time() - script_start))

    return shared.current_mse, shared.during_mse
Beispiel #5
0
def run(config, experiment='regression_mt'):
    """Run the Multitask training for the given experiment.

    The function sets up the environment (device, tensorboard writer and
    logger), creates the tasks and a single main network, trains this network
    via :func:`train` on all tasks and evaluates it afterwards via
    :func:`test`. Finally, the achieved scores are logged and the summary is
    written. If ``config.store_final_model`` is set, the trained main network
    is checkpointed.

    Args:
        config (argparse.Namespace): Command-line arguments. Note, for
            regression experiments the attribute ``num_tasks`` is overwritten
            with the number of generated tasks.
        experiment (str): Which kind of experiment should be performed?

            - ``'regression_mt'``: Regression tasks with multitask training
            - ``'gmm_mt'``: GMM Data with multitask training
            - ``'split_mnist_mt'``: SplitMNIST with multitask training
            - ``'perm_mnist_mt'``: PermutedMNIST with multitask training
            - ``'cifar_resnet_mt'``: CIFAR-10/100 with multitask training
    """
    assert experiment in [
        'regression_mt', 'gmm_mt', 'split_mnist_mt', 'perm_mnist_mt',
        'cifar_resnet_mt'
    ]

    script_start = time()

    device, writer, logger = sutils.setup_environment(config,
                                                      logger_name=experiment)

    rutils.backup_cli_command(config)

    # Every experiment that is not a regression experiment is treated as a
    # classification experiment.
    is_classification = 'regression' not in experiment

    ### Create tasks.
    if is_classification:
        dhandlers = pmutils.load_datasets(config, logger, experiment, writer)
    else:
        dhandlers, num_tasks = rutils.generate_tasks(config, writer)
        config.num_tasks = num_tasks

    ### Simple struct, that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.num_trained = 0

    ### Generate network.
    # Multitask training only requires the main network. Hence, all other
    # networks are explicitly disabled (note, the flags are expected to be
    # booleans).
    mnet, _, _, _ = pcutils.generate_networks(config,
                                              shared,
                                              logger,
                                              dhandlers,
                                              device,
                                              create_mnet=True,
                                              create_hnet=False,
                                              create_hhnet=False,
                                              create_dis=False)

    if not is_classification:
        # Both arrays are initialized with -1, i.e., a value that marks tasks
        # for which no MSE has been written yet.
        shared.during_mse = np.ones(config.num_tasks) * -1.
        # MSE achieved after most recent call of test method.
        shared.current_mse = np.ones(config.num_tasks) * -1.

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')

    # Initialize the softmax temperature per-task with one. Might be changed
    # later on to calibrate the temperature.
    if is_classification:
        shared.softmax_temp = [torch.ones(1).to(device)
                               for _ in range(config.num_tasks)]

    ### Initialize summary.
    if is_classification:
        pcutils.setup_summary_dict(config, shared, experiment, mnet)
    else:
        rutils.setup_summary_dict(config, shared, 'mt', config.num_tasks, mnet)

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    # Note, the summary key holding the number of main network weights differs
    # between classification and regression experiments.
    num_weights_key = 'num_weights_main' if is_classification \
        else 'aa_num_weights_main'
    writer.add_hparams(hparam_dict={
        **vars(config),
        'num_weights_main': shared.summary[num_weights_key],
    }, metric_dict={})

    ### Train on all tasks.
    logger.info('### Training ###')
    # Note, since we are training on all tasks; all output heads can at all
    # times be considered as trained!
    shared.num_trained = config.num_tasks
    train(dhandlers, mnet, device, config, shared, logger, writer)

    logger.info('### Testing ###')
    test(dhandlers,
         mnet,
         device,
         config,
         shared,
         logger,
         writer,
         test_ids=None)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        # There is no hypernetwork in multitask training, hence `None`.
        pmutils.checkpoint_nets(config, shared, config.num_tasks - 1, mnet,
                                None)

    if is_classification:
        ### Plot final classification scores.
        logger.info(
            'Final accuracies (task identity given): %s (avg: %.2f%%).' %
            (np.array2string(np.array(shared.summary['acc_task_given']),
                             precision=2, separator=','),
             shared.summary['acc_avg_task_given']))
    else:
        ### Plot final regression scores.
        logger.info('Final MSE values after training on all tasks: %s' %
                    np.array2string(np.array(shared.summary['aa_mse_final']),
                                    precision=5, separator=','))
        # Report the mean of the *final* MSE values, consistent with the
        # message and the values logged above.
        logger.info('Final MSE mean %.4f.' %
                    (shared.summary['aa_mse_final_mean']))

    ### Write final summary.
    shared.summary['finished'] = 1
    if is_classification:
        pmutils.save_summary_dict(config, shared, experiment)
    else:
        rutils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.' %
                (time() - script_start))