def run():
    """Run the Bayes-by-Backprop regression experiment.

    The experiment is configured from the command line (mode
    ``'regression_bbb'``). A Gaussian main network -- and, unless
    ``config.mnet_only`` is set, a hypernetwork -- is trained on every task
    in sequence. After each task, all tasks seen so far are evaluated via
    :func:`test`. With ``config.train_from_scratch``, the networks are
    checkpointed and replaced by freshly generated ones before every new
    task.

    Returns:
        (tuple): Tuple containing:

        - **final_mse**: Final MSE for each task.
        - **during_mse**: MSE achieved directly after training on each task.
    """
    start_time = time()

    mode = 'regression_bbb'
    config = train_args.parse_cmd_arguments(mode=mode)
    device, writer, logger = sutils.setup_environment(config, logger_name=mode)
    train_utils.backup_cli_command(config)

    ### Create tasks.
    dhandlers, num_tasks = train_utils.generate_tasks(config, writer)

    ### Generate networks.
    use_hnet = not config.mnet_only

    def build_networks():
        # Fresh main network (+ optional hypernetwork). Used for the initial
        # networks and again whenever we train from scratch.
        return train_utils.generate_gauss_networks(
            config, logger, dhandlers, device,
            create_hnet=use_hnet, non_gaussian=config.mean_only)

    mnet, hnet = build_networks()

    ### Simple struct, that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = mode
    shared.all_dhandlers = dhandlers

    # Mean, log-variance and std of the prior used for variational inference.
    # A deterministic (mean-only) network cannot be prior-matched.
    if config.mean_only:
        shared.prior_mean = None
        shared.prior_logvar = None
        shared.prior_std = None
    else:
        log_var = np.log(config.prior_variance)
        std = np.sqrt(config.prior_variance)
        shared.prior_mean = [torch.zeros(*shape).to(device)
                             for shape in mnet.orig_param_shapes]
        shared.prior_logvar = [log_var * torch.ones(*shape).to(device)
                               for shape in mnet.orig_param_shapes]
        shared.prior_std = [std * torch.ones(*shape).to(device)
                            for shape in mnet.orig_param_shapes]

    # All MSE values are measured on the validation set if one exists,
    # otherwise on the training set. Validation samples are expected to lie
    # inside the training range; test samples may lie outside of it.
    # A value of -1 marks "not measured yet".
    # MSE right after training on the corresponding task.
    shared.during_mse = np.full(num_tasks, -1.)
    # Main-network weights (means and variances, e.g. the hypernetwork output)
    # right after training on each task, used to check how far the final
    # weights drifted. -1 marks "not stored yet".
    shared.during_weights = [-1] * num_tasks
    # MSE after the most recent call of the test method.
    shared.current_mse = np.full(num_tasks, -1.)

    # Checkpoint locations. Main networks are always checkpointed too, as some
    # of them hold state such as batch-norm statistics.
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d')

    ### Performance measures that are tracked during training.
    train_utils.setup_summary_dict(config, shared, 'bbb', num_tasks, mnet,
                                   hnet=hnet)

    # Register hparams in tensorboard, which eases identifying runs.
    size_hparams = {
        'num_weights_main': shared.summary['aa_num_weights_main'],
        'num_weights_hyper': shared.summary['aa_num_weights_hyper'],
        'num_weights_ratio': shared.summary['aa_num_weights_ratio'],
    }
    writer.add_hparams(hparam_dict={**vars(config), **size_hparams},
                       metric_dict={})

    ### Train on tasks sequentially.
    for task_id in range(num_tasks):
        logger.info('### Training on task %d ###' % (task_id + 1))
        train(task_id, dhandlers[task_id], mnet, hnet, device, config, shared,
              logger, writer)

        ### Test on all tasks seen so far.
        test(dhandlers[:task_id + 1], mnet, hnet, device, config, shared,
             logger, writer)

        more_tasks_follow = task_id < num_tasks - 1
        if config.train_from_scratch and more_tasks_follow:
            # Checkpoint the current networks (they are reloaded for task
            # inference during later tests) before replacing them.
            pmutils.checkpoint_nets(config, shared, task_id, mnet, hnet)
            mnet, hnet = build_networks()

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, num_tasks - 1, mnet, hnet)

    def as_str(values):
        # Compact, comma-separated rendering of an MSE array for logging.
        return np.array2string(values, precision=5, separator=',')

    logger.info('During MSE values after training each task: %s' %
                as_str(shared.during_mse))
    logger.info('Final MSE values after training on all tasks: %s' %
                as_str(shared.current_mse))
    logger.info('Final MSE mean %.4f (std %.4f).' %
                (shared.current_mse.mean(), shared.current_mse.std()))

    ### Write final summary.
    shared.summary['finished'] = 1
    train_utils.save_summary_dict(config, shared)

    writer.close()

    elapsed = time() - start_time
    logger.info('Program finished successfully in %f sec.' % elapsed)

    return shared.current_mse, shared.during_mse
def test(data_handlers,
         mnet,
         hnet,
         device,
         config,
         shared,
         logger,
         writer,
         hhnet=None,
         save_fig=True):
    """Test the performance of all tasks.

    Tasks are assumed to be regression tasks.

    For every task ``i`` in ``data_handlers``, the MSE is measured on the
    validation set (or on the training set if no validation set exists) using
    the model of task ``i``. In addition, the model of every task ``j`` is
    applied to the data of task ``i``. The standard deviation across sampled
    predictions serves as an uncertainty measure, and the model with the
    lowest uncertainty is taken as the inferred task (task inference).
    Results are logged, written to tensorboard, plotted, and stored in
    ``shared.summary``. ``shared.current_mse``, ``shared.during_mse`` (if
    present) and ``shared.during_weights`` are updated in place.

    If ``config.train_from_scratch`` is set, the networks of tasks
    ``0, ..., n-2`` are reloaded from their checkpoints. The passed networks
    are used for the most recent task ``n-1``, which has not been
    checkpointed yet.

    Args:
        (....): See docstring of method
            :func:`probabilistic.train_vi.train`.
        data_handlers: A list of data handlers, each representing a task.
        save_fig: Whether the figures should be saved in the output folder.
    """
    logger.info('### Testing all trained tasks ... ###')

    if hasattr(config, 'mean_only') and config.mean_only:
        warn('Task inference calculated in test method doesn\'t make any ' +
             'sense, as the deterministic main network has no notion of ' +
             'uncertainty.')

    pcutils.set_train_mode(False, mnet, hnet, hhnet, None)

    n = len(data_handlers)

    disable_lrt_test = config.disable_lrt_test if \
        hasattr(config, 'disable_lrt_test') else None

    if not hasattr(shared, 'current_mse'):
        shared.current_mse = np.ones(n) * -1.
    elif shared.current_mse.size < n:
        # Grow the array while keeping the values of previously tested tasks.
        tmp = shared.current_mse
        shared.current_mse = np.ones(n) * -1.
        shared.current_mse[:tmp.size] = tmp

    # Current MSE value on test set.
    test_mse = np.ones(n) * -1.
    # Current MSE value using the predictions of the inferred model.
    inferred_val_mse = np.ones(n) * -1.
    # Task inference accuracies.
    task_infer_val_accs = np.ones(n) * -1.

    with torch.no_grad():
        # We need to keep data for plotting results on all tasks later on.
        val_inputs = []
        val_targets = []  # Needed to compute MSE values.
        val_preds_mean = []
        val_preds_std = []

        # Which uncertainties have been measured per sample and task. The argmin
        # over all tasks gives the predicted task.
        val_task_preds = []

        test_inputs = []
        test_preds_mean = []
        test_preds_std = []

        normal_post = None
        if 'ewc' in shared.experiment_type:
            assert hnet is None
            normal_post = ewcutil.build_ewc_posterior(data_handlers,
                                                      mnet,
                                                      device,
                                                      config,
                                                      shared,
                                                      logger,
                                                      writer,
                                                      n,
                                                      task_id=n - 1)

        if config.train_from_scratch and n > 1:
            # We need to iterate over different networks when we want to
            # measure the uncertainty of dataset i on task j.
            # Note, we will always load the corresponding checkpoint of task j
            # before using these networks.
            if 'avb' in shared.experiment_type \
                    or 'ssge' in shared.experiment_type\
                    or 'ewc' in shared.experiment_type:
                mnet_other, hnet_other, hhnet_other, _ = \
                    pcutils.generate_networks(config, shared, logger,
                        shared.all_dhandlers, device, create_dis=False)
            else:
                assert hhnet is None
                hhnet_other = None
                non_gaussian = config.mean_only \
                    if hasattr(config, 'mean_only') else True
                mnet_other, hnet_other = train_utils.generate_gauss_networks( \
                    config, logger, shared.all_dhandlers, device,
                    create_hnet=hnet is not None, non_gaussian=non_gaussian)

            pcutils.set_train_mode(False, mnet_other, hnet_other, hhnet_other,
                                   None)

        task_n_mnet = mnet
        task_n_hnet = hnet
        task_n_hhnet = hhnet
        task_n_normal_post = normal_post

        # This renaming is just a protection against myself, that I don't use
        # any of those networks (`mnet`, `hnet`, `hhnet`) in the future
        # inside the loop when training from scratch.
        if config.train_from_scratch:
            mnet = None
            hnet = None
            hhnet = None
            normal_post = None

        ### For each data set (i.e., for each task).
        for i in range(n):
            data = data_handlers[i]

            ### We want to measure MSE values within the training range only!
            split_type = 'val'
            num_val_samples = data.num_val_samples
            if num_val_samples == 0:
                split_type = 'train'
                num_val_samples = data.num_train_samples

                logger.debug('Test: Task %d - Using training set as no ' % i +
                             'validation set is available.')

            ### Task inference.
            # We need to iterate over each task embedding and measure the
            # predictive uncertainty in order to decide which embedding to use.
            data_preds = np.empty((num_val_samples, config.val_sample_size, n))
            data_preds_mean = np.empty((num_val_samples, n))
            data_preds_std = np.empty((num_val_samples, n))

            for j in range(n):
                ckpt_score_j = None
                if config.train_from_scratch and j == (n - 1):
                    # Note, the networks trained on dataset (n-1) haven't been
                    # checkpointed yet.
                    mnet_j = task_n_mnet
                    hnet_j = task_n_hnet
                    hhnet_j = task_n_hhnet
                    normal_post_j = task_n_normal_post
                elif config.train_from_scratch:
                    ckpt_score_j = pmutils.load_networks(shared,
                                                         j,
                                                         device,
                                                         logger,
                                                         mnet_other,
                                                         hnet_other,
                                                         hhnet=hhnet_other,
                                                         dis=None)
                    mnet_j = mnet_other
                    hnet_j = hnet_other
                    hhnet_j = hhnet_other

                    if 'ewc' in shared.experiment_type:
                        normal_post_j = ewcutil.build_ewc_posterior( \
                            data_handlers, mnet_j, device, config, shared,
                            logger, writer, n, task_id=j)
                    else:
                        # Only EWC uses a normal posterior. Without this
                        # assignment `normal_post_j` would be unbound on the
                        # first iteration (i = j = 0) and raise an
                        # UnboundLocalError below.
                        normal_post_j = None
                else:
                    mnet_j = mnet
                    hnet_j = hnet
                    hhnet_j = hhnet
                    normal_post_j = normal_post

                mse_val, val_struct = train_utils.compute_mse(
                    j,
                    data,
                    mnet_j,
                    hnet_j,
                    device,
                    config,
                    shared,
                    hhnet=hhnet_j,
                    split_type=split_type,
                    return_dataset=i == j,
                    return_predictions=True,
                    disable_lrt=disable_lrt_test,
                    normal_post=normal_post_j)

                if i == j:  # I.e., we used the correct embedding.
                    # This sanity check is likely to fail as we don't
                    # deterministically sample the models.
                    #if ckpt_score_j is not None:
                    #    assert np.allclose(-mse_val, ckpt_score_j)

                    val_inputs.append(val_struct.inputs)
                    val_targets.append(val_struct.targets)
                    val_preds_mean.append(val_struct.predictions.mean(axis=1))
                    val_preds_std.append(val_struct.predictions.std(axis=1))

                    shared.current_mse[i] = mse_val

                    logger.debug('Test: Task %d - Mean MSE on %s set: %f ' %
                                 (i, split_type, mse_val) + '(std: %g).' %
                                 (val_struct.mse_vals.std()))
                    writer.add_scalar('test/task_%d/val_mse' % i,
                                      shared.current_mse[i], n)

                    # The test set spans into the OOD range and can be used to
                    # visualize how uncertainty behaves outside the
                    # in-distribution range.
                    mse_test, test_struct = train_utils.compute_mse(
                        i,
                        data,
                        mnet_j,
                        hnet_j,
                        device,
                        config,
                        shared,
                        hhnet=hhnet_j,
                        split_type='test',
                        return_dataset=True,
                        return_predictions=True,
                        disable_lrt=disable_lrt_test,
                        normal_post=normal_post_j)

                data_preds[:, :, j] = val_struct.predictions
                data_preds_mean[:, j] = val_struct.predictions.mean(axis=1)
                ### We interpret this value as the certainty of the prediction.
                # I.e., how certain is our system that each of the samples
                # belong to task j?
                data_preds_std[:, j] = val_struct.predictions.std(axis=1)

            val_task_preds.append(data_preds_std)

            ### Compute task inference accuracy.
            # The model with the smallest predictive std is the inferred task.
            inferred_task_ids = data_preds_std.argmin(axis=1)
            num_correct = np.sum(inferred_task_ids == i)
            accuracy = 100. * num_correct / num_val_samples
            task_infer_val_accs[i] = accuracy

            logger.debug('Test: Task %d - Accuracy of task inference ' % i +
                         'on %s set: %.2f%%.' % (split_type, accuracy))
            writer.add_scalar('test/task_%d/accuracy' % i, accuracy, n)

            ### Compute MSE based on inferred embedding.

            # Note, this (commented) way of computing the mean does not take
            # into account the variance of the predictive distribution, which is
            # why we don't use it (see docstring of `compute_mse`).
            #means_of_inferred_preds = data_preds_mean[np.arange( \
            #    data_preds_mean.shape[0]), inferred_task_ids]
            #inferred_val_mse[i] = np.power(means_of_inferred_preds -
            #                               val_targets[-1].squeeze(), 2).mean()

            inferred_preds = data_preds[np.arange(data_preds.shape[0]), :,
                                        inferred_task_ids]
            inferred_val_mse[i] = np.power(inferred_preds - \
                val_targets[-1].squeeze()[:, np.newaxis], 2).mean()

            logger.debug('Test: Task %d - Mean MSE on %s set using inferred '\
                         % (i, split_type) + 'embeddings: %f.'
                         % (inferred_val_mse[i]))
            writer.add_scalar('test/task_%d/inferred_val_mse' % i,
                              inferred_val_mse[i], n)

            ### We are interested in the predictive uncertainty across the
            ### whole test range!
            test_mse[i] = mse_test
            writer.add_scalar('test/task_%d/test_mse' % i, test_mse[i], n)

            test_inputs.append(test_struct.inputs.squeeze())
            test_preds_mean.append(test_struct.predictions.mean(axis=1). \
                                   squeeze())
            test_preds_std.append(
                test_struct.predictions.std(axis=1).squeeze())

            if hasattr(shared, 'during_mse') and \
                    shared.during_mse[i] == -1:
                shared.during_mse[i] = shared.current_mse[i]

            if test_struct.w_hnet is not None or test_struct.w_mean is not None:
                assert hasattr(shared, 'during_weights')
                if test_struct.w_hnet is not None:
                    # We have a hyper-hypernetwork. In this case, the CL
                    # regularizer is applied to its output and therefore, these
                    # are the during weights whose Euclidean distance we want to
                    # track.
                    assert task_n_hhnet is not None
                    w_all = test_struct.w_hnet
                else:
                    assert test_struct.w_mean is not None
                    # We will be here whenever the hnet is deterministic (i.e.,
                    # doesn't represent an implicit distribution).
                    w_all = list(test_struct.w_mean)
                    if test_struct.w_std is not None:
                        w_all += list(test_struct.w_std)

                W_curr = torch.cat([d.clone().view(-1) for d in w_all])
                # The placeholder -1 (an int) means no weights were stored yet.
                if isinstance(shared.during_weights[i], int):
                    assert (shared.during_weights[i] == -1)
                    shared.during_weights[i] = W_curr
                else:
                    W_during = shared.during_weights[i]
                    W_dis = torch.norm(W_curr - W_during, 2)
                    logger.info('Euclidean distance between hypernet output ' +
                                'for task %d: %g' % (i, W_dis))

    ### Compute overall task inference accuracy.
    num_correct = 0
    num_samples = 0
    for i, uncertainties in enumerate(val_task_preds):
        pred_task_ids = uncertainties.argmin(axis=1)
        num_correct += np.sum(pred_task_ids == i)
        num_samples += pred_task_ids.size

    accuracy = 100. * num_correct / num_samples
    logger.info('Task inference accuracy: %.2f%%.' % accuracy)

    # TODO Compute overall MSE on all tasks using inferred embeddings.

    ### Plot the mean predictions on all tasks.
    # (Using the validation set and the correct embedding per dataset)
    plot_x_ranges = []
    for i in range(n):
        plot_x_ranges.append(data_handlers[i].train_x_range)

    fig_fn = None
    if save_fig:
        fig_fn = os.path.join(config.out_dir, 'val_predictions_%d' % n)

    data_inputs = val_inputs
    mean_preds = val_preds_mean
    data_handlers[0].plot_datasets(data_handlers,
                                   data_inputs,
                                   mean_preds,
                                   fun_xranges=plot_x_ranges,
                                   filename=fig_fn,
                                   show=False,
                                   publication_style=config.publication_style)
    writer.add_figure('test/val_predictions',
                      plt.gcf(),
                      n,
                      close=not config.show_plots)
    if config.show_plots:
        utils.repair_canvas_and_show_fig(plt.gcf())

    ### Scatter plot showing MSE per task (original + current one).
    during_mse = None
    if hasattr(shared, 'during_mse'):
        during_mse = shared.during_mse[:n]
    train_utils.plot_mse(config,
                         writer,
                         n,
                         shared.current_mse[:n],
                         during_mse,
                         save_fig=save_fig)
    additional_plots = {
        'Current Inferred Val MSE': inferred_val_mse,
        #'Current Test MSE': test_mse
    }
    train_utils.plot_mse(config,
                         writer,
                         n,
                         shared.current_mse[:n],
                         during_mse,
                         baselines=additional_plots,
                         save_fig=False,
                         summary_label='test/mse_detailed')

    ### Plot predictive distributions over test range for all tasks.
    data_inputs = test_inputs
    mean_preds = test_preds_mean
    std_preds = test_preds_std

    train_utils.plot_predictive_distributions(
        config,
        writer,
        data_handlers,
        data_inputs,
        mean_preds,
        std_preds,
        save_fig=save_fig,
        publication_style=config.publication_style)

    # Note, the std is a float and must not be formatted with `%d`.
    logger.info('Mean task MSE: %f (std: %f)' %
                (shared.current_mse[:n].mean(), shared.current_mse[:n].std()))

    ### Update performance summary.
    s = shared.summary
    # `during_mse` is optional (see the `hasattr` checks above).
    if during_mse is not None:
        s['aa_mse_during'][:n] = during_mse.tolist()
        s['aa_mse_during_mean'] = during_mse.mean()
    s['aa_mse_final'][:n] = shared.current_mse[:n].tolist()
    s['aa_mse_final_mean'] = shared.current_mse[:n].mean()

    s['aa_task_inference'][:n] = task_infer_val_accs.tolist()
    s['aa_task_inference_mean'] = task_infer_val_accs.mean()

    s['aa_mse_during_inferred'][n - 1] = inferred_val_mse[n - 1]
    s['aa_mse_during_inferred_mean'] = np.mean(s['aa_mse_during_inferred'][:n])
    s['aa_mse_final_inferred'] = inferred_val_mse[:n].tolist()
    s['aa_mse_final_inferred_mean'] = inferred_val_mse[:n].mean()

    train_utils.save_summary_dict(config, shared)

    logger.info('### Testing all trained tasks ... Done ###')
Beispiel #3
0
def run(method='avb'):
    """Run the script.

    Learns posteriors over the main-network weights sequentially on all
    regression tasks with the VI method ``method`` (see ``tvi.train``), and
    tests on all tasks seen so far after each task. With
    ``config.prior_focused``, the configured prior is only used for the first
    task. With ``config.train_from_scratch``, the networks are checkpointed
    and regenerated before every new task.

    Args:
        method (str, optional): The VI algorithm. Possible values are:

            - ``'avb'``
            - ``'ssge'``

    Returns:
        (tuple): Tuple containing:

        - **final_mse**: Final MSE for each task.
        - **during_mse**: MSE achieved directly after training on each task.
    """
    script_start = time()
    mode = 'regression_' + method
    use_dis = False  # whether a discriminator network is used
    if method == 'avb':
        use_dis = True
    config = train_args.parse_cmd_arguments(mode=mode)

    device, writer, logger = sutils.setup_environment(config, logger_name=mode)

    train_utils.backup_cli_command(config)

    if config.prior_focused:
        logger.info('Running a prior-focused CL experiment ...')
    else:
        logger.info('Learning task-specific posteriors sequentially ...')

    ### Create tasks.
    dhandlers, num_tasks = train_utils.generate_tasks(config, writer)

    ### Simple struct, that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = mode
    shared.all_dhandlers = dhandlers
    shared.prior_focused = config.prior_focused

    ### Generate networks and environment
    mnet, hnet, hhnet, dnet = pcu.generate_networks(config,
                                                    shared,
                                                    logger,
                                                    dhandlers,
                                                    device,
                                                    create_dis=use_dis)

    # Mean and variance of prior that is used for variational inference.
    # For a prior-focused training, this prior will only be used for the
    # first task.
    pstd = np.sqrt(config.prior_variance)
    shared.prior_mean = [torch.zeros(*s).to(device) \
                         for s in mnet.param_shapes]
    shared.prior_std = [pstd * torch.ones(*s).to(device) \
                        for s in mnet.param_shapes]

    # Note, all MSE values are measured on a validation set if given, otherwise
    # on the training set. All samples in the validation set are expected to
    # lie inside the training range. Test samples may lie outside the training
    # range.
    # The MSE value achieved right after training on the corresponding task.
    shared.during_mse = np.ones(num_tasks) * -1.
    # The weights of the main network right after training on that task
    # (can be used to assess how close the final weights are to the original
    # ones). Note, weights refer to mean and variances (e.g., the output of the
    # hypernetwork).
    shared.during_weights = [-1] * num_tasks
    # MSE achieved after most recent call of test method.
    shared.current_mse = np.ones(num_tasks) * -1.

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    # Note, some networks have stuff to store such as batch statistics for
    # batch norm. So it is wise to always checkpoint all networks, even if they
    # were constructed without weights.
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')
    shared.ckpt_hnet_fn = os.path.join(shared.ckpt_dir, 'hnet_task_%d')
    shared.ckpt_hhnet_fn = os.path.join(shared.ckpt_dir, 'hhnet_task_%d')
    #shared.ckpt_dis_fn = os.path.join(shared.ckpt_dir, 'dis_task_%d')

    ### Initialize the performance measures, that should be tracked during
    ### training.
    train_utils.setup_summary_dict(config,
                                   shared,
                                   method,
                                   num_tasks,
                                   mnet,
                                   hnet=hnet,
                                   hhnet=hhnet,
                                   dis=dnet)
    logger.info('Ratio num hnet weights / num mnet weights: %f.' %
                shared.summary['aa_num_weights_hm_ratio'])
    if hhnet is not None:
        logger.info('Ratio num hyper-hnet weights / num mnet weights: %f.' %
                    shared.summary['aa_num_weights_hhm_ratio'])
    if mode == 'regression_avb' and dnet is not None:
        # A discriminator only exists for AVB.
        logger.info('Ratio num dis weights / num mnet weights: %f.' %
                    shared.summary['aa_num_weights_dm_ratio'])

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    hparams_extra_dict = {
        'num_weights_hm_ratio': shared.summary['aa_num_weights_hm_ratio'],
        'num_weights_hhm_ratio': shared.summary['aa_num_weights_hhm_ratio']
    }
    if mode == 'regression_avb':
        hparams_extra_dict['num_weights_dm_ratio'] = \
            shared.summary['aa_num_weights_dm_ratio']
    writer.add_hparams(hparam_dict={
        **vars(config),
        **hparams_extra_dict
    },
                       metric_dict={})

    ### Train on tasks sequentially.
    for i in range(num_tasks):
        logger.info('### Training on task %d ###' % (i + 1))
        data = dhandlers[i]

        # Train the network.
        tvi.train(i,
                  data,
                  mnet,
                  hnet,
                  hhnet,
                  dnet,
                  device,
                  config,
                  shared,
                  logger,
                  writer,
                  method=method)

        # Test networks.
        train_bbb.test(dhandlers[:(i + 1)],
                       mnet,
                       hnet,
                       device,
                       config,
                       shared,
                       logger,
                       writer,
                       hhnet=hhnet)

        if config.train_from_scratch and i < num_tasks - 1:
            # We have to checkpoint the networks, such that we can reload them
            # for task inference later during testing.
            # Note, we only need the discriminator as helper for training,
            # so we don't checkpoint it.
            pmu.checkpoint_nets(config,
                                shared,
                                i,
                                mnet,
                                hnet,
                                hhnet=hhnet,
                                dis=None)

            # The fresh networks must be set up exactly like the initial ones;
            # in particular, AVB needs a new discriminator (and SSGE none).
            mnet, hnet, hhnet, dnet = pcu.generate_networks(
                config, shared, logger, dhandlers, device, create_dis=use_dis)

        elif config.store_during_models:
            logger.info('Checkpointing current model ...')
            pmu.checkpoint_nets(config,
                                shared,
                                i,
                                mnet,
                                hnet,
                                hhnet=hhnet,
                                dis=None)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmu.checkpoint_nets(config,
                            shared,
                            num_tasks - 1,
                            mnet,
                            hnet,
                            hhnet=hhnet,
                            dis=None)

    logger.info('During MSE values after training each task: %s' % \
                np.array2string(shared.during_mse, precision=5, separator=','))
    logger.info('Final MSE values after training on all tasks: %s' % \
                np.array2string(shared.current_mse, precision=5, separator=','))
    logger.info('Final MSE mean %.4f (std %.4f).' %
                (shared.current_mse.mean(), shared.current_mse.std()))

    ### Write final summary.
    shared.summary['finished'] = 1
    train_utils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.' %
                (time() - script_start))

    return shared.current_mse, shared.during_mse
def run(config, experiment='regression_ewc'):
    """Run the EWC training for the given experiment.

    Tasks are trained sequentially. After each task the network is tested,
    either on all tasks seen so far or only on the current task, depending on
    ``config.full_test_interval`` (if present). For classification
    experiments, training is aborted (process exit code 1) if the
    during-accuracy of a non-final task falls below the criterion obtained
    from ``pmutils.parse_performance_criterion``. If
    ``config.train_from_scratch`` is set, the network of each non-final task
    is checkpointed and a fresh main network is generated for the next task.

    Args:
        config (argparse.Namespace): Command-line arguments. For regression
            experiments, ``config.num_tasks`` is overwritten with the number
            of generated tasks.
        experiment (str): Which kind of experiment should be performed?

            - ``'regression_ewc'``: Regression tasks with EWC
            - ``'gmm_ewc'``: GMM Data with EWC
            - ``'split_mnist_ewc'``: SplitMNIST with EWC
            - ``'perm_mnist_ewc'``: PermutedMNIST with EWC
            - ``'cifar_resnet_ewc'``: CIFAR-10/100 with EWC
    """
    import sys

    assert experiment in ['regression_ewc', 'gmm_ewc', 'split_mnist_ewc',
                          'perm_mnist_ewc', 'cifar_resnet_ewc']

    script_start = time()

    device, writer, logger = sutils.setup_environment(config,
        logger_name=experiment)

    rutils.backup_cli_command(config)

    is_classification = 'regression' not in experiment

    ### Create tasks.
    if is_classification:
        dhandlers = pmutils.load_datasets(config, logger, experiment, writer)
    else:
        dhandlers, num_tasks = rutils.generate_tasks(config, writer)
        config.num_tasks = num_tasks

    ### Simple struct, that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.num_trained = 0

    ### Generate network.
    # NOTE(review): `create_hhnet` receives the list of data handlers instead
    # of a bool (truthy whenever tasks exist), although the returned
    # hyper-hypernetwork is discarded. Looks like a typo for `False`; confirm
    # against `pcutils.generate_networks` before changing it.
    mnet, _, _, _ = pcutils.generate_networks(config, shared, logger, dhandlers,
        device, create_mnet=True, create_hnet=False, create_hhnet=dhandlers,
        create_dis=False)

    if not is_classification:
        shared.during_mse = np.ones(config.num_tasks) * -1.
        # MSE achieved after most recent call of test method.
        shared.current_mse = np.ones(config.num_tasks) * -1.
    # The weights of the main network right after training on that task
    # (can be used to assess how close the final weights are to the original
    # ones).
    shared.during_weights = [-1] * config.num_tasks

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')

    # Initialize the softmax temperature per-task with one. Might be changed
    # later on to calibrate the temperature.
    if is_classification:
        shared.softmax_temp = [torch.ones(1).to(device) \
                               for _ in range(config.num_tasks)]

    ### Initialize summary.
    if is_classification:
        pcutils.setup_summary_dict(config, shared, experiment, mnet)
    else:
        rutils.setup_summary_dict(config, shared, 'ewc', config.num_tasks, mnet)

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    writer.add_hparams(hparam_dict={**vars(config), **{
        'num_weights_main': shared.summary['num_weights_main'] \
            if is_classification else shared.summary['aa_num_weights_main']
    }}, metric_dict={})

    if is_classification:
        during_acc_criterion = pmutils.parse_performance_criterion(config,
            shared, logger)

    ### Train on tasks sequentially.
    for i in range(config.num_tasks):
        logger.info('### Training on task %d ###' % (i+1))
        data = dhandlers[i]

        # Train the network.
        shared.num_trained += 1
        train(i, data, mnet, device, config, shared, logger, writer)

        ### Test networks.
        # By default, all tasks trained so far are tested. If a full-test
        # interval is configured, only the current task is tested except on
        # the last task and on every `full_test_interval`-th task.
        test_ids = None # Test on all tasks.
        if getattr(config, 'full_test_interval', -1) != -1:
            is_last_task = i == config.num_tasks-1
            is_full_test = i > 0 and i % config.full_test_interval == 0
            if not (is_last_task or is_full_test):
                test_ids = [i] # Only test on current task.
        test(dhandlers[:(i+1)], mnet, device, config, shared, logger, writer,
             test_ids=test_ids)

        ### Check if last task got "acceptable" accuracy ###
        if is_classification:
            curr_dur_acc = shared.summary['acc_task_given_during'][i]
            # A criterion of -1 disables the check; the last task is never
            # checked since no future training would be skipped.
            if i < config.num_tasks-1 \
                    and during_acc_criterion[i] != -1 \
                    and during_acc_criterion[i] > curr_dur_acc:
                logger.error('During accuracy of task %d too small (%f < %f).' \
                             % (i+1, curr_dur_acc, during_acc_criterion[i]))
                logger.error('Training of future tasks will be skipped')
                writer.close()
                # `exit` is only meant for the interactive interpreter (it is
                # installed by `site`); `sys.exit` is the programmatic form.
                sys.exit(1)

        if config.train_from_scratch and i < config.num_tasks-1:
            # We have to checkpoint the networks, such that we can reload them
            # for task inference later during testing.
            pmutils.checkpoint_nets(config, shared, i, mnet, None)

            mnet, _, _, _ = pcutils.generate_networks(config, shared, logger,
                dhandlers, device, create_mnet=True, create_hnet=False,
                create_hhnet=dhandlers, create_dis=False)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, config.num_tasks-1, mnet, None)

    ### Plot final classification scores.
    if is_classification:
        logger.info('During accuracies (task identity given): ' + \
                    '%s (avg: %.2f%%).' % \
            (np.array2string(np.array(shared.summary['acc_task_given_during']),
                             precision=2, separator=','),
             shared.summary['acc_avg_task_given_during']))
        logger.info('Final accuracies (task identity given): ' + \
                    '%s (avg: %.2f%%).' % \
            (np.array2string(np.array(shared.summary['acc_task_given']),
                             precision=2, separator=','),
             shared.summary['acc_avg_task_given']))

    if is_classification and config.cl_scenario == 3 and config.split_head_cl3:
        logger.info('During accuracies (task identity inferred): ' +
                    '%s (avg: %.2f%%).' % \
            (np.array2string(np.array( \
                shared.summary['acc_task_inferred_ent_during']),
                             precision=2, separator=','),
             shared.summary['acc_avg_task_inferred_ent_during']))
        logger.info('Final accuracies (task identity inferred): ' +
                    '%s (avg: %.2f%%).' % \
            (np.array2string(np.array(shared.summary['acc_task_inferred_ent']),
                             precision=2, separator=','),
             shared.summary['acc_avg_task_inferred_ent']))

    ### Plot final regression scores.
    if not is_classification:
        logger.info('During MSE values after training each task: %s' % \
            np.array2string(np.array(shared.summary['aa_mse_during']),
                            precision=5, separator=','))
        logger.info('Final MSE values after training on all tasks: %s' % \
            np.array2string(np.array(shared.summary['aa_mse_final']),
                            precision=5, separator=','))
        # Each mean is reported under its own label. The "final" mean is
        # taken over the final MSE values logged just above (the last task
        # always tests all tasks, so every entry has been filled in).
        logger.info('During MSE mean %.4f.' % \
                    (shared.summary['aa_mse_during_mean']))
        logger.info('Final MSE mean %.4f.' % \
                    np.mean(shared.summary['aa_mse_final']))

    ### Write final summary.
    shared.summary['finished'] = 1
    if is_classification:
        pmutils.save_summary_dict(config, shared, experiment)
    else:
        rutils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.'
                % (time() - script_start))
Beispiel #5
0
def run(config, experiment='regression_mt'):
    """Run the Multitask training for the given experiment.

    All tasks are trained jointly in a single call to ``train``, followed by
    one test pass over all tasks.

    Args:
        config (argparse.Namespace): Command-line arguments. For regression
            experiments, ``config.num_tasks`` is overwritten with the number
            of generated tasks.
        experiment (str): Which kind of experiment should be performed?

            - ``'regression_mt'``: Regression tasks with multitask training
            - ``'gmm_mt'``: GMM Data with multitask training
            - ``'split_mnist_mt'``: SplitMNIST with multitask training
            - ``'perm_mnist_mt'``: PermutedMNIST with multitask training
            - ``'cifar_resnet_mt'``: CIFAR-10/100 with multitask training
    """
    assert experiment in [
        'regression_mt', 'gmm_mt', 'split_mnist_mt', 'perm_mnist_mt',
        'cifar_resnet_mt'
    ]

    script_start = time()

    device, writer, logger = sutils.setup_environment(config,
                                                      logger_name=experiment)

    rutils.backup_cli_command(config)

    is_classification = 'regression' not in experiment

    ### Create tasks.
    if is_classification:
        dhandlers = pmutils.load_datasets(config, logger, experiment, writer)
    else:
        dhandlers, num_tasks = rutils.generate_tasks(config, writer)
        config.num_tasks = num_tasks

    ### Simple struct, that is used to share data among functions.
    shared = Namespace()
    shared.experiment_type = experiment
    shared.all_dhandlers = dhandlers
    shared.num_trained = 0

    ### Generate network.
    # NOTE(review): `create_hhnet` receives the list of data handlers instead
    # of a bool (truthy whenever tasks exist), although the returned
    # hyper-hypernetwork is discarded. Looks like a typo for `False`; confirm
    # against `pcutils.generate_networks` before changing it.
    mnet, _, _, _ = pcutils.generate_networks(config,
                                              shared,
                                              logger,
                                              dhandlers,
                                              device,
                                              create_mnet=True,
                                              create_hnet=False,
                                              create_hhnet=dhandlers,
                                              create_dis=False)

    if not is_classification:
        shared.during_mse = np.ones(config.num_tasks) * -1.
        # MSE achieved after most recent call of test method.
        shared.current_mse = np.ones(config.num_tasks) * -1.

    # Where to save network checkpoints?
    shared.ckpt_dir = os.path.join(config.out_dir, 'checkpoints')
    shared.ckpt_mnet_fn = os.path.join(shared.ckpt_dir, 'mnet_task_%d')

    # Initialize the softmax temperature per-task with one. Might be changed
    # later on to calibrate the temperature.
    if is_classification:
        shared.softmax_temp = [torch.ones(1).to(device) \
                               for _ in range(config.num_tasks)]

    ### Initialize summary.
    if is_classification:
        pcutils.setup_summary_dict(config, shared, experiment, mnet)
    else:
        rutils.setup_summary_dict(config, shared, 'mt', config.num_tasks, mnet)

    # Add hparams to tensorboard, such that the identification of runs is
    # easier.
    writer.add_hparams(hparam_dict={**vars(config), **{
        'num_weights_main': shared.summary['num_weights_main'] \
            if is_classification else shared.summary['aa_num_weights_main']
    }}, metric_dict={})

    ### Train on all tasks.
    logger.info('### Training ###')
    # Note, since we are training on all tasks; all output heads can at all
    # times be considered as trained!
    shared.num_trained = config.num_tasks
    train(dhandlers, mnet, device, config, shared, logger, writer)

    logger.info('### Testing ###')
    test(dhandlers,
         mnet,
         device,
         config,
         shared,
         logger,
         writer,
         test_ids=None)

    if config.store_final_model:
        logger.info('Checkpointing final model ...')
        pmutils.checkpoint_nets(config, shared, config.num_tasks - 1, mnet,
                                None)

    ### Plot final classification scores.
    if is_classification:
        logger.info('Final accuracies (task identity given): ' + \
                    '%s (avg: %.2f%%).' % \
            (np.array2string(np.array(shared.summary['acc_task_given']),
                             precision=2, separator=','),
             shared.summary['acc_avg_task_given']))

    ### Plot final regression scores.
    if not is_classification:
        logger.info('Final MSE values after training on all tasks: %s' % \
            np.array2string(np.array(shared.summary['aa_mse_final']),
                            precision=5, separator=','))
        # Report the mean of the final MSE values logged just above, so the
        # value matches its "Final" label (all tasks are tested above via
        # `test_ids=None`).
        logger.info('Final MSE mean %.4f.' % \
                    np.mean(shared.summary['aa_mse_final']))

    ### Write final summary.
    shared.summary['finished'] = 1
    if is_classification:
        pmutils.save_summary_dict(config, shared, experiment)
    else:
        rutils.save_summary_dict(config, shared)

    writer.close()

    logger.info('Program finished successfully in %f sec.' %
                (time() - script_start))