Example #1
def _generate_1d_tasks(show_plots=True, data_random_seed=42, writer=None):
    """Generate a set of tasks for 1D regression.

    Args:
        show_plots: Visualize the generated datasets.
        data_random_seed: Random seed that should be applied to the
            synthetic data generation.
        writer: Tensorboard writer, in case plots should be logged.

    Returns:
        data_handlers: A data handler for each task (instance of class
            "ToyRegression").
        num_tasks: Number of generated tasks.
    """
    # FIXME task generation currently not controlled by user via command-line.
    if False:  # Set of random polynomials.
        num_tasks = 20
        x_domains = [[-10, 10]] * num_tasks

        # Disjoint x domains.
        # tmp = np.linspace(-10, 10, num_tasks+1)
        # x_domains = list(zip(tmp[:-1], tmp[1:]))

        max_degree = 6
        pcoeffs = np.random.uniform(-1, 1, size=(num_tasks, max_degree + 1))

        map_funcs = []
        for i in range(num_tasks):
            d = np.random.randint(0, max_degree)
            # Ignore highest degrees.
            c = pcoeffs[i, d:]

            # Decrease the magnitude of higher order coefficients.
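            # (The damping factors below are .5, .25, .0625, ..., applied
            # from the constant term up to the highest-order coefficient.)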
            f = .5
            for j in range(c.size - 1, -1, -1):
                c[j] *= f
                f *= f

            # We want the border points of all polynomials to not exceed a
            # certain absolute magnitude.
            bp = np.polyval(c, x_domains[i])
            s = np.max(np.abs(bp)) + 1e-5
            c = c / s * 10.

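            # Note, the default argument `c=c` binds the current coefficients
            # at definition time; a plain `lambda x: np.polyval(c, x)` would
            # late-bind `c` and make every task use the last polynomial.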
            map_funcs.append(lambda x, c=c: np.polyval(c, x))

        std = .1

    else:  # Manually selected tasks.
        """
        tmp = np.linspace(-15, 15, num_tasks + 1)
        x_domains = list(zip(tmp[:-1], tmp[1:]))
        map_funcs = [lambda x: 2. * (x+10.),
                     lambda x: np.power(x, 2) * 2./2.5 - 10,
                     lambda x: np.power(x-10., 3) * 1./12.5]
        std = 1.
        """
        """
        map_funcs = [lambda x : 0.1 * x, lambda x : np.power(x, 2) * 1e-2,
                     lambda x : np.power(x, 3) * 1e-3]
        num_tasks = len(map_funcs)
        x_domains = [[-10, 10]] * num_tasks
        std = .1
        """

        map_funcs = [
            lambda x: (x + 3.), lambda x: 2. * np.power(x, 2) - 1,
            lambda x: np.power(x - 3., 3)
        ]
        num_tasks = len(map_funcs)
        x_domains = [[-4, -2], [-1, 1], [2, 4]]
        std = .05
        """
        map_funcs = [lambda x : (x+30.),
                     lambda x : .2 * np.power(x, 2) - 10,
                     lambda x : 1e-2 * np.power(x-30., 3)]
        num_tasks = len(map_funcs)
        x_domains = [[-40,-20], [-10,10], [20,40]]
        std = .5
        """

    dhandlers = []
    for i in range(num_tasks):
        print('Generating task %d ...' % (i + 1))
        dhandlers.append(
            ToyRegression(train_inter=x_domains[i],
                          num_train=100,
                          test_inter=x_domains[i],
                          num_test=50,
                          map_function=map_funcs[i],
                          std=std,
                          rseed=data_random_seed))

        if writer is not None:
            dhandlers[-1].plot_dataset(show=False)
            writer.add_figure('task_%d/dataset' % i,
                              plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

        elif show_plots:
            dhandlers[-1].plot_dataset()

    return dhandlers, num_tasks
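
# A minimal usage sketch for `_generate_1d_tasks` (a hypothetical demo
# helper, not part of the original code base; it only relies on the
# `ToyRegression` interface used throughout this file).
def _demo_generate_1d_tasks():
    dhandlers, num_tasks = _generate_1d_tasks(show_plots=False,
                                              data_random_seed=42)
    for t, dh in enumerate(dhandlers):
        x = dh.get_train_inputs()   # Shape: (num_train, 1).
        y = dh.get_train_outputs()  # Shape: (num_train, 1).
        print('Task %d/%d: %d training samples.' % (t + 1, num_tasks,
                                                    x.shape[0]))
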
def evaluate(task_id,
             data,
             mnet,
             hnet,
             device,
             config,
             shared,
             logger,
             writer,
             train_iter=None):
    """Evaluate the training progress.

    Evaluate the performance of the network on a single task (that is currently
    being trained) on the validation set.

    Note, if no validation set is available, the training set will be used
    instead.

    Args:
        (....): See docstring of method :func:`train`. Note, `hnet` can be
            passed as :code:`None`. In this case, no weights are passed to the
            `forward` method of the main network.
        train_iter: The current training iteration. If not given, the `writer`
            will not be used.
    """
    if train_iter is None:
        logger.info('# Evaluating training ...')
    else:
        logger.info('# Evaluating network on task %d ' % (task_id + 1) +
                    'before running training step %d ...' % (train_iter))

    # TODO: write histograms of weight samples to tensorboard.

    mnet.eval()
    if hnet is not None:
        hnet.eval()

    with torch.no_grad():
        # Note, if no validation set exists, we use the training data to compute
        # the MSE (note, test data may contain out-of-distribution data in our
        # setup).
        split_type = 'train' if data.num_val_samples == 0 else 'val'
        if split_type == 'train':
            logger.debug('Eval - Using training set as no validation set is ' +
                         'available.')

        mse_val, val_struct = train_utils.compute_mse(task_id,
                                                      data,
                                                      mnet,
                                                      hnet,
                                                      device,
                                                      config,
                                                      shared,
                                                      split_type=split_type)
        ident = 'training' if split_type == 'train' else 'validation'

        logger.info('Eval - Mean MSE on %s set: %f (std: %g).' %
                    (ident, mse_val, val_struct.mse_vals.std()))

        # In contrast, we visualize uncertainty using the test set.
        mse_test, test_struct = train_utils.compute_mse(
            task_id,
            data,
            mnet,
            hnet,
            device,
            config,
            shared,
            split_type='test',
            return_dataset=True,
            return_predictions=True)
        logger.debug('Eval - Mean MSE on test set: %f (std: %g).' %
                     (mse_test, test_struct.mse_vals.std()))

        if config.show_plots or train_iter is not None:
            train_utils.plot_predictive_distribution(data,
                                                     test_struct.inputs,
                                                     test_struct.predictions,
                                                     show_raw_pred=True,
                                                     figsize=(10, 4),
                                                     show=train_iter is None)
            if train_iter is not None:
                writer.add_figure('task_%d/predictions' % task_id,
                                  plt.gcf(),
                                  train_iter,
                                  close=not config.show_plots)
                if config.show_plots:
                    utils.repair_canvas_and_show_fig(plt.gcf())

                writer.add_scalar('eval/task_%d/val_mse' % task_id, mse_val,
                                  train_iter)
                writer.add_scalar('eval/task_%d/test_mse' % task_id, mse_test,
                                  train_iter)

        logger.info('# Evaluating training ... Done')

def test(data_handlers,
         mnet,
         hnet,
         device,
         config,
         shared,
         logger,
         writer,
         hhnet=None,
         save_fig=True):
    """Test the performance of all tasks.

    Tasks are assumed to be regression tasks.

    Args:
        (....): See docstring of method
            :func:`probabilistic.train_vi.train`.
        data_handlers: A list of data handlers, each representing a task.
        save_fig: Whether the figures should be saved in the output folder.
    """
    logger.info('### Testing all trained tasks ... ###')

    if hasattr(config, 'mean_only') and config.mean_only:
        warn('Task inference calculated in test method doesn\'t make any ' +
             'sense, as the deterministic main network has no notion of ' +
             'uncertainty.')

    pcutils.set_train_mode(False, mnet, hnet, hhnet, None)

    n = len(data_handlers)

    disable_lrt_test = config.disable_lrt_test if \
        hasattr(config, 'disable_lrt_test') else None

    if not hasattr(shared, 'current_mse'):
        shared.current_mse = np.ones(n) * -1.
    elif shared.current_mse.size < n:
        tmp = shared.current_mse
        shared.current_mse = np.ones(n) * -1.
        shared.current_mse[:tmp.size] = tmp

    # Current MSE value on test set.
    test_mse = np.ones(n) * -1.
    # Current MSE value using the mean prediction of the inferred embedding.
    inferred_val_mse = np.ones(n) * -1.
    # Task inference accuracies.
    task_infer_val_accs = np.ones(n) * -1.

    with torch.no_grad():
        # We need to keep data for plotting results on all tasks later on.
        val_inputs = []
        val_targets = []  # Needed to compute MSE values.
        val_preds_mean = []
        val_preds_std = []

        # Which uncertainties have been measured per sample and task. The argmax
        # over all tasks gives the predicted task.
        val_task_preds = []

        test_inputs = []
        test_preds_mean = []
        test_preds_std = []

        normal_post = None
        if 'ewc' in shared.experiment_type:
            assert hnet is None
            normal_post = ewcutil.build_ewc_posterior(data_handlers,
                                                      mnet,
                                                      device,
                                                      config,
                                                      shared,
                                                      logger,
                                                      writer,
                                                      n,
                                                      task_id=n - 1)

        if config.train_from_scratch and n > 1:
            # We need to iterate over different networks when we want to
            # measure the uncertainty of dataset i on task j.
            # Note, we will always load the corresponding checkpoint of task j
            # before using these networks.
            if 'avb' in shared.experiment_type \
                    or 'ssge' in shared.experiment_type\
                    or 'ewc' in shared.experiment_type:
                mnet_other, hnet_other, hhnet_other, _ = \
                    pcutils.generate_networks(config, shared, logger,
                        shared.all_dhandlers, device, create_dis=False)
            else:
                assert hhnet is None
                hhnet_other = None
                non_gaussian = config.mean_only \
                    if hasattr(config, 'mean_only') else True
                mnet_other, hnet_other = train_utils.generate_gauss_networks( \
                    config, logger, shared.all_dhandlers, device,
                    create_hnet=hnet is not None, non_gaussian=non_gaussian)

            pcutils.set_train_mode(False, mnet_other, hnet_other, hhnet_other,
                                   None)

        task_n_mnet = mnet
        task_n_hnet = hnet
        task_n_hhnet = hhnet
        task_n_normal_post = normal_post

        # This renaming is just a safeguard against accidentally using any of
        # those networks (`mnet`, `hnet`, `hhnet`) later inside the loop when
        # training from scratch.
        if config.train_from_scratch:
            mnet = None
            hnet = None
            hhnet = None
            normal_post = None

        ### For each data set (i.e., for each task).
        for i in range(n):
            data = data_handlers[i]

            ### We want to measure MSE values within the training range only!
            split_type = 'val'
            num_val_samples = data.num_val_samples
            if num_val_samples == 0:
                split_type = 'train'
                num_val_samples = data.num_train_samples

                logger.debug('Test: Task %d - Using training set as no ' % i +
                             'validation set is available.')

            ### Task inference.
            # We need to iterate over each task embedding and measure the
            # predictive uncertainty in order to decide which embedding to use.
            data_preds = np.empty((num_val_samples, config.val_sample_size, n))
            data_preds_mean = np.empty((num_val_samples, n))
            data_preds_std = np.empty((num_val_samples, n))

            for j in range(n):
                ckpt_score_j = None
                if config.train_from_scratch and j == (n - 1):
                    # Note, the networks trained on dataset (n-1) haven't been
                    # checkpointed yet.
                    mnet_j = task_n_mnet
                    hnet_j = task_n_hnet
                    hhnet_j = task_n_hhnet
                    normal_post_j = task_n_normal_post
                elif config.train_from_scratch:
                    ckpt_score_j = pmutils.load_networks(shared,
                                                         j,
                                                         device,
                                                         logger,
                                                         mnet_other,
                                                         hnet_other,
                                                         hhnet=hhnet_other,
                                                         dis=None)
                    mnet_j = mnet_other
                    hnet_j = hnet_other
                    hhnet_j = hhnet_other

                    if 'ewc' in shared.experiment_type:
                        normal_post_j = ewcutil.build_ewc_posterior( \
                            data_handlers, mnet_j, device, config, shared,
                            logger, writer, n, task_id=j)
                else:
                    mnet_j = mnet
                    hnet_j = hnet
                    hhnet_j = hhnet
                    normal_post_j = normal_post

                mse_val, val_struct = train_utils.compute_mse(
                    j,
                    data,
                    mnet_j,
                    hnet_j,
                    device,
                    config,
                    shared,
                    hhnet=hhnet_j,
                    split_type=split_type,
                    return_dataset=i == j,
                    return_predictions=True,
                    disable_lrt=disable_lrt_test,
                    normal_post=normal_post_j)

                if i == j:  # I.e., we used the correct embedding.
                    # This sanity check is likely to fail as we don't
                    # deterministically sample the models.
                    #if ckpt_score_j is not None:
                    #    assert np.allclose(-mse_val, ckpt_score_j)

                    val_inputs.append(val_struct.inputs)
                    val_targets.append(val_struct.targets)
                    val_preds_mean.append(val_struct.predictions.mean(axis=1))
                    val_preds_std.append(val_struct.predictions.std(axis=1))

                    shared.current_mse[i] = mse_val

                    logger.debug('Test: Task %d - Mean MSE on %s set: %f ' %
                                 (i, split_type, mse_val) + '(std: %g).' %
                                 (val_struct.mse_vals.std()))
                    writer.add_scalar('test/task_%d/val_mse' % i,
                                      shared.current_mse[i], n)

                    # The test set spans into the OOD range and can be used to
                    # visualize how uncertainty behaves outside the
                    # in-distribution range.
                    mse_test, test_struct = train_utils.compute_mse(
                        i,
                        data,
                        mnet_j,
                        hnet_j,
                        device,
                        config,
                        shared,
                        hhnet=hhnet_j,
                        split_type='test',
                        return_dataset=True,
                        return_predictions=True,
                        disable_lrt=disable_lrt_test,
                        normal_post=normal_post_j)

                data_preds[:, :, j] = val_struct.predictions
                data_preds_mean[:, j] = val_struct.predictions.mean(axis=1)
                ### We interpret this std as the uncertainty of the prediction.
                # I.e., the lower it is, the more certain our system is that
                # a sample belongs to task j.
                data_preds_std[:, j] = val_struct.predictions.std(axis=1)

            val_task_preds.append(data_preds_std)

            ### Compute task inference accuracy.
            inferred_task_ids = data_preds_std.argmin(axis=1)
            num_correct = np.sum(inferred_task_ids == i)
            accuracy = 100. * num_correct / num_val_samples
            task_infer_val_accs[i] = accuracy

            logger.debug('Test: Task %d - Accuracy of task inference ' % i +
                         'on %s set: %.2f%%.' % (split_type, accuracy))
            writer.add_scalar('test/task_%d/accuracy' % i, accuracy, n)

            ### Compute MSE based on inferred embedding.

            # Note, this (commented) way of computing the mean does not take
            # into account the variance of the predictive distribution, which is
            # why we don't use it (see docstring of `compute_mse`).
            #means_of_inferred_preds = data_preds_mean[np.arange( \
            #    data_preds_mean.shape[0]), inferred_task_ids]
            #inferred_val_mse[i] = np.power(means_of_inferred_preds -
            #                               val_targets[-1].squeeze(), 2).mean()

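            # Advanced indexing: for each validation sample, select the
            # predictions (across all weight samples) of its inferred task.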
            inferred_preds = data_preds[np.arange(data_preds.shape[0]), :,
                                        inferred_task_ids]
            inferred_val_mse[i] = np.power(inferred_preds - \
                val_targets[-1].squeeze()[:, np.newaxis], 2).mean()

            logger.debug('Test: Task %d - Mean MSE on %s set using inferred '\
                         % (i, split_type) + 'embeddings: %f.'
                         % (inferred_val_mse[i]))
            writer.add_scalar('test/task_%d/inferred_val_mse' % i,
                              inferred_val_mse[i], n)

            ### We are interested in the predictive uncertainty across the
            ### whole test range!
            test_mse[i] = mse_test
            writer.add_scalar('test/task_%d/test_mse' % i, test_mse[i], n)

            test_inputs.append(test_struct.inputs.squeeze())
            test_preds_mean.append(test_struct.predictions.mean(axis=1). \
                                   squeeze())
            test_preds_std.append(
                test_struct.predictions.std(axis=1).squeeze())

            if hasattr(shared, 'during_mse') and \
                    shared.during_mse[i] == -1:
                shared.during_mse[i] = shared.current_mse[i]

            if test_struct.w_hnet is not None or test_struct.w_mean is not None:
                assert hasattr(shared, 'during_weights')
                if test_struct.w_hnet is not None:
                    # We have a hyper-hypernetwork. In this case, the CL
                    # regularizer is applied to its output and therefore, these
                    # are the during weights whose Euclidean distance we want to
                    # track.
                    assert task_n_hhnet is not None
                    w_all = test_struct.w_hnet
                else:
                    assert test_struct.w_mean is not None
                    # We will be here whenever the hnet is deterministic (i.e.,
                    # doesn't represent an implicit distribution).
                    w_all = list(test_struct.w_mean)
                    if test_struct.w_std is not None:
                        w_all += list(test_struct.w_std)

                W_curr = torch.cat([d.clone().view(-1) for d in w_all])
                if type(shared.during_weights[i]) == int:
                    assert (shared.during_weights[i] == -1)
                    shared.during_weights[i] = W_curr
                else:
                    W_during = shared.during_weights[i]
                    W_dis = torch.norm(W_curr - W_during, 2)
                    logger.info('Euclidean distance between current and ' +
                                'during-training hypernet outputs for ' +
                                'task %d: %g' % (i, W_dis))

    ### Compute overall task inference accuracy.
    num_correct = 0
    num_samples = 0
    for i, uncertainties in enumerate(val_task_preds):
        pred_task_ids = uncertainties.argmin(axis=1)
        num_correct += np.sum(pred_task_ids == i)
        num_samples += pred_task_ids.size

    accuracy = 100. * num_correct / num_samples
    logger.info('Task inference accuracy: %.2f%%.' % accuracy)

    # TODO Compute overall MSE on all tasks using inferred embeddings.

    ### Plot the mean predictions on all tasks.
    # (Using the validation set and the correct embedding per dataset)
    plot_x_ranges = []
    for i in range(n):
        plot_x_ranges.append(data_handlers[i].train_x_range)

    fig_fn = None
    if save_fig:
        fig_fn = os.path.join(config.out_dir, 'val_predictions_%d' % n)

    data_inputs = val_inputs
    mean_preds = val_preds_mean
    data_handlers[0].plot_datasets(data_handlers,
                                   data_inputs,
                                   mean_preds,
                                   fun_xranges=plot_x_ranges,
                                   filename=fig_fn,
                                   show=False,
                                   publication_style=config.publication_style)
    writer.add_figure('test/val_predictions',
                      plt.gcf(),
                      n,
                      close=not config.show_plots)
    if config.show_plots:
        utils.repair_canvas_and_show_fig(plt.gcf())

    ### Scatter plot showing MSE per task (original + current one).
    during_mse = None
    if hasattr(shared, 'during_mse'):
        during_mse = shared.during_mse[:n]
    train_utils.plot_mse(config,
                         writer,
                         n,
                         shared.current_mse[:n],
                         during_mse,
                         save_fig=save_fig)
    additional_plots = {
        'Current Inferred Val MSE': inferred_val_mse,
        #'Current Test MSE': test_mse
    }
    train_utils.plot_mse(config,
                         writer,
                         n,
                         shared.current_mse[:n],
                         during_mse,
                         baselines=additional_plots,
                         save_fig=False,
                         summary_label='test/mse_detailed')

    ### Plot predictive distributions over test range for all tasks.
    data_inputs = test_inputs
    mean_preds = test_preds_mean
    std_preds = test_preds_std

    train_utils.plot_predictive_distributions(
        config,
        writer,
        data_handlers,
        data_inputs,
        mean_preds,
        std_preds,
        save_fig=save_fig,
        publication_style=config.publication_style)

    logger.info('Mean task MSE: %f (std: %g)' %
                (shared.current_mse[:n].mean(), shared.current_mse[:n].std()))

    ### Update performance summary.
    s = shared.summary
    s['aa_mse_during'][:n] = shared.during_mse[:n].tolist()
    s['aa_mse_during_mean'] = shared.during_mse[:n].mean()
    s['aa_mse_final'][:n] = shared.current_mse[:n].tolist()
    s['aa_mse_final_mean'] = shared.current_mse[:n].mean()

    s['aa_task_inference'][:n] = task_infer_val_accs.tolist()
    s['aa_task_inference_mean'] = task_infer_val_accs.mean()

    s['aa_mse_during_inferred'][n - 1] = inferred_val_mse[n - 1]
    s['aa_mse_during_inferred_mean'] = np.mean(s['aa_mse_during_inferred'][:n])
    s['aa_mse_final_inferred'] = inferred_val_mse[:n].tolist()
    s['aa_mse_final_inferred_mean'] = inferred_val_mse[:n].mean()

    train_utils.save_summary_dict(config, shared)

    logger.info('### Testing all trained tasks ... Done ###')
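
# The task-inference step in `test` reduces to an argmin over per-task
# predictive standard deviations. Below is a self-contained sketch of that
# decision rule on dummy data (illustrative shapes; not part of the
# original code base).
def _demo_task_inference():
    import numpy as np

    def infer_task_ids(preds):
        """Infer a task ID per sample from predictive uncertainty.

        `preds` has shape (num_samples, num_w_samples, num_tasks); a low
        predictive std signals high certainty that a sample belongs to the
        corresponding task.
        """
        return preds.std(axis=1).argmin(axis=1)

    rng = np.random.RandomState(42)
    # 5 inputs, 20 weight samples, 3 tasks; task 0 has the lowest spread.
    dummy = rng.randn(5, 20, 3) * np.array([.1, 1., 1.])
    print(infer_task_ids(dummy))  # -> [0 0 0 0 0]
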
def generate_datasets(config, logger, writer):
    """Create a data handler per task.

    Note:
        The datasets are hard-coded in this function.

    Args:
        config (argparse.Namespace): Command-line arguments. This function will
            add the key ``num_tasks`` to this namespace if not existing.

            Note, this function will also add the keys ``gmm_grid_size``,
            ``gmm_grid_range_1``, ``gmm_grid_range_2``, which are used for
            plotting.

        logger: Logger object.
        writer (tensorboardX.SummaryWriter): Tensorboard logger.

    Returns:
        (list): A list of data handlers.
    """
    NUM_TRAIN = 10
    NUM_TEST = 100

    config.gmm_grid_size = 250

    dhandlers = []

    TASK_SET = 3

    if TASK_SET == 0:
        config.gmm_grid_range_1 = config.gmm_grid_range_2 = [-1, 1]
        means = [[np.array([0, 1]), np.array([0, -1])]]
        variances = [[0.05**2 * np.eye(len(mean)) for mean in means[0]]]
    elif TASK_SET == 1:
        config.gmm_grid_range_1 = config.gmm_grid_range_2 = [-6, 6]

        means = [gauss_mod.CHE_MEANS[i:i + 2] for i in range(0, 6, 2)]
        variances = [gauss_mod.CHE_VARIANCES[i:i + 2] for i in range(0, 6, 2)]
    elif TASK_SET == 2:
        config.gmm_grid_range_1 = config.gmm_grid_range_2 = [-9, 9]

        means = [gauss_mod.CHE_MEANS[i:i + 2] for i in range(0, 6, 2)]
        variances = [[1.**2 * np.eye(len(m)) for m in mm] for mm in means]
    elif TASK_SET == 3:
        config.gmm_grid_range_1 = config.gmm_grid_range_2 = [-9, 9]

        means = [gauss_mod.CHE_MEANS[i:i + 2] for i in range(0, 6, 2)]
        variances = [[.2**2 * np.eye(len(m)) for m in mm] for mm in means]
    else:
        raise NotImplementedError()

    # Note, this is a synthetic dataset where the number of tasks and the
    # number of classes per task is hard-coded inside this function.
    if hasattr(config, 'num_tasks') and config.num_tasks > len(means):
        raise ValueError('Command-line argument "num_tasks" has impossible ' +
                         'value %d (maximum value would be %d).' %
                         (config.num_tasks, len(means)))
    elif not hasattr(config, 'num_tasks'):
        config.num_tasks = len(means)
    else:
        means = means[:config.num_tasks]
        variances = variances[:config.num_tasks]

    if hasattr(config, 'num_classes_per_task'):
        raise ValueError('Command-line argument "num_classes_per_task" ' +
                         'cannot be considered by this function.')

    if hasattr(config, 'val_set_size') and config.val_set_size > 0:
        raise ValueError('GMM Dataset does not support a validation set!')

    show_plots = False
    if hasattr(config, 'show_plots'):
        show_plots = config.show_plots

    # For multiple tasks, generate a combined dataset just to create some plots.
    gauss_bumps_all = get_gmm_tasks(means=list(itertools.chain(*means)),
                                    covs=list(itertools.chain(*variances)),
                                    num_train=NUM_TRAIN,
                                    num_test=NUM_TEST,
                                    map_functions=None,
                                    rseed=config.data_random_seed)
    if config.num_tasks > 1:
        full_data = GMMData(gauss_bumps_all,
                            classification=True,
                            use_one_hot=True,
                            mixing_coefficients=None)

        input_mesh = full_data.get_input_mesh(x1_range=config.gmm_grid_range_1,
                                              x2_range=config.gmm_grid_range_2,
                                              grid_size=config.gmm_grid_size)

        # Plot data distribution.
        if writer is not None:
            full_data.plot_uncertainty_map(title='All Data',
                                           input_mesh=input_mesh,
                                           use_generative_uncertainty=True,
                                           sketch_components=False,
                                           show=False)
            writer.add_figure('all_tasks/data_dist',
                              plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot ground-truth conditional uncertainty.
        if writer is not None:
            full_data.plot_uncertainty_map(title='Conditional Uncertainty',
                                           input_mesh=input_mesh,
                                           sketch_components=True,
                                           show=False)
            writer.add_figure('all_tasks/cond_entropy',
                              plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot ground-truth class boundaries.
        if writer is not None:
            full_data.plot_optimal_classification(title='Class-Boundaries',
                                                  input_mesh=input_mesh,
                                                  sketch_components=True,
                                                  show=False)
            writer.add_figure('all_tasks/class_boundaries',
                              plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot ground-truth class boundaries together with all training data.
        # Note, this might visualize training points that would even be
        # misclassified by the true underlying model (due to the stochastic
        # drawing of samples).
        if writer is not None:
            full_data.plot_optimal_classification(
                title='Class-Boundaries - Training Data',
                input_mesh=input_mesh,
                sketch_components=True,
                show=False,
                sample_inputs=full_data.get_train_inputs(),
                sample_modes=np.argmax(full_data.get_train_outputs(), axis=1))
            writer.add_figure('all_tasks/class_boundaries_train',
                              plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot ground-truth class boundaries together with all test data.
        if writer is not None:
            full_data.plot_optimal_classification(
                title='Class-Boundaries - Test Data',
                input_mesh=input_mesh,
                sketch_components=True,
                show=False,
                sample_inputs=full_data.get_test_inputs(),
                sample_modes=np.argmax(full_data.get_test_outputs(), axis=1))
            writer.add_figure('all_tasks/class_boundaries_test',
                              plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

    # Create individual task datasets.
    ii = 0
    for i in range(len(means)):
        gauss_bumps = gauss_bumps_all[ii:ii + len(means[i])]
        ii += len(means[i])

        dhandlers.append(
            GMMData(gauss_bumps,
                    classification=True,
                    use_one_hot=True,
                    mixing_coefficients=None))

        input_mesh = dhandlers[-1].get_input_mesh( \
            x1_range=config.gmm_grid_range_1, x2_range=config.gmm_grid_range_2,
            grid_size=config.gmm_grid_size)

        # Plot training data.
        if writer is not None:
            dhandlers[-1].plot_uncertainty_map(title='Training Data',
                     input_mesh=input_mesh, use_generative_uncertainty=True,
                     sample_inputs=dhandlers[-1].get_train_inputs(),
                     sample_modes=np.argmax(dhandlers[-1].get_train_outputs(), \
                                            axis=1),
                     #sample_label='Training data',
                     sketch_components=True, show=False)
            writer.add_figure('task_%d/train_data' % i,
                              plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot test data.
        if writer is not None:
            dhandlers[-1].plot_uncertainty_map(title='Test Data',
                     input_mesh=input_mesh, use_generative_uncertainty=True,
                     sample_inputs=dhandlers[-1].get_test_inputs(),
                     sample_modes=np.argmax(dhandlers[-1].get_test_outputs(), \
                                            axis=1),
                     #sample_label='Training data',
                     sketch_components=True, show=False)
            writer.add_figure('task_%d/test_data' % i,
                              plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot ground-truth conditional uncertainty.
        if writer is not None:
            dhandlers[-1].plot_uncertainty_map(title='Conditional Uncertainty',
                                               input_mesh=input_mesh,
                                               sketch_components=True,
                                               show=False)
            writer.add_figure('task_%d/cond_entropy' % i,
                              plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot ground-truth class boundaries.
        if writer is not None:
            dhandlers[-1].plot_optimal_classification(title='Class-Boundaries',
                                                      input_mesh=input_mesh,
                                                      sketch_components=True,
                                                      show=False)
            writer.add_figure('task_%d/class_boundaries' % i,
                              plt.gcf(),
                              close=not show_plots)
            if show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

    return dhandlers
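
# A small sketch of how such a hard-coded GMM task set is assembled and
# sliced into tasks of two modes each (illustrative constants; the actual
# values come from `gauss_mod.CHE_MEANS` / `CHE_VARIANCES`).
def _demo_gmm_task_set():
    import numpy as np

    # Six 2D means, grouped into three tasks of two modes each.
    all_means = [np.array(m, dtype=float) for m in
                 [[-6, 6], [6, 6], [-6, -6], [6, -6], [0, 6], [0, -6]]]
    means = [all_means[i:i + 2] for i in range(0, 6, 2)]
    # Isotropic covariance per mode, mirroring `.2**2 * np.eye(2)` above.
    variances = [[.2**2 * np.eye(2) for _ in mm] for mm in means]
    print('%d tasks, modes per task: %s'
          % (len(means), [len(mm) for mm in means]))
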
def plot_gmm_prior_preds(task_id,
                         data,
                         mnet,
                         hnet,
                         hhnet,
                         device,
                         config,
                         shared,
                         logger,
                         writer,
                         prior_mean,
                         prior_std,
                         prior_theta=None):
    """Visualize the prior predictive entropy over the whole input space.

    Similar to function :func:`plot_gmm_preds`, but rather than sampling from
    the approximate posterior, samples are drawn from a given prior
    distribution.

    Args:
        (....): See docstring of function :func:`plot_gmm_preds`.
        prior_mean (list): A list of tensors that represent the mean of an
            explicit prior. Is expected to be ``None`` if ``prior_theta`` is
            specified.
        prior_std (list): A list of tensors that represent the std of an
            explicit prior. See ``prior_mean`` for more details.
        prior_theta (list): The weights passed to ``hnet`` when drawing samples
            from the current implicit distribution, which represents the prior.
    """
    # FIXME Code in this function is almost identical to the one in function
    # `plot_gmm_preds`.
    # Either an explicit prior (`prior_mean` and `prior_std`) or an implicit
    # prior (`prior_theta`) must be provided, but not both.
    assert (prior_mean is None) == (prior_std is None)
    assert (prior_theta is None) != (prior_mean is None)

    # Gather prior samples.
    prior_samples = []
    for i in range(config.val_sample_size):
        if prior_theta is not None:
            z = torch.normal(torch.zeros(1, shared.noise_dim),
                             config.latent_std).to(device)
            prior_samples.append(
                hnet.forward(uncond_input=z, weights=prior_theta))
        else:
            prior_samples.append(
                putils.sample_diag_gauss(prior_mean, prior_std))

    input_mesh = data.get_input_mesh(x1_range=config.gmm_grid_range_1,
                                     x2_range=config.gmm_grid_range_2,
                                     grid_size=config.gmm_grid_size)

    if 'bbb' in shared.experiment_type:
        _, ret_fig = pmutils.compute_acc(task_id,
                                         data,
                                         mnet,
                                         hnet,
                                         device,
                                         config,
                                         shared,
                                         split_type=None,
                                         return_entropies=True,
                                         return_pred_labels=True,
                                         deterministic_sampling=True,
                                         disable_lrt=config.disable_lrt_test,
                                         in_samples=input_mesh[2],
                                         w_samples=prior_samples)
    else:
        assert 'avb' in shared.experiment_type or 'ssge' in \
            shared.experiment_type
        _, ret_fig = pcutils.compute_acc(task_id,
                                         data,
                                         mnet,
                                         hnet,
                                         hhnet,
                                         device,
                                         config,
                                         shared,
                                         split_type=None,
                                         return_entropies=True,
                                         return_pred_labels=True,
                                         deterministic_sampling=True,
                                         in_samples=input_mesh[2],
                                         w_samples=prior_samples)

    # The means of other tasks.
    other_means = np.concatenate([dh.means for dh in shared.all_dhandlers],
                                 axis=0)

    # Plot entropies over whole input space (according to `input_mesh`).
    data.plot_uncertainty_map( \
        title='Entropy of prior predictive distribution',
        input_mesh=input_mesh, uncertainties=ret_fig.entropies.reshape(-1, 1),
        sample_inputs=other_means,
        sketch_components=True, show=False)
    writer.add_figure('prior/task_%d/pred_entropies' % task_id,
                      plt.gcf(),
                      close=not config.show_plots)
    if config.show_plots:
        misc.repair_canvas_and_show_fig(plt.gcf())

    # Plot entropies over whole input space (according to `input_mesh`).
    data.plot_optimal_classification(title='Prior Predicted Class-Boundaries',
                                     input_mesh=input_mesh,
                                     mesh_modes=ret_fig.pred_labels.reshape(
                                         -1, 1),
                                     sample_inputs=other_means,
                                     sketch_components=True,
                                     show=False)
    writer.add_figure('prior/task_%d/pred_class_boundaries' % task_id,
                      plt.gcf(),
                      close=not config.show_plots)
    if config.show_plots:
        misc.repair_canvas_and_show_fig(plt.gcf())

    # Plots for single samples from the prior.
    assert 'bbb' in shared.experiment_type and not config.mean_only or \
           'avb' in shared.experiment_type and hnet is not None or \
           'ssge' in shared.experiment_type and hnet is not None
    for ii in range(min(10, len(prior_samples))):
        if 'bbb' in shared.experiment_type:
            _, ret_fig = pmutils.compute_acc(
                task_id,
                data,
                mnet,
                hnet,
                device,
                config,
                shared,
                split_type=None,
                return_entropies=True,
                return_pred_labels=True,
                deterministic_sampling=False,
                disable_lrt=config.disable_lrt_test,
                in_samples=input_mesh[2],
                w_samples=[prior_samples[ii]])
        else:
            assert 'avb' in shared.experiment_type or 'ssge' in \
                shared.experiment_type
            _, ret_fig = pcutils.compute_acc(task_id,
                                             data,
                                             mnet,
                                             hnet,
                                             hhnet,
                                             device,
                                             config,
                                             shared,
                                             split_type=None,
                                             return_entropies=True,
                                             return_pred_labels=True,
                                             deterministic_sampling=False,
                                             in_samples=input_mesh[2],
                                             w_samples=[prior_samples[ii]])

        # Plot entropies over whole input space (according to `input_mesh`).
        data.plot_uncertainty_map( \
            title='Entropy of prior predictive distribution',
            input_mesh=input_mesh,
            uncertainties=ret_fig.entropies.reshape(-1, 1),
            sample_inputs=other_means,
            sketch_components=True, show=False)
        writer.add_figure('prior/task_%d/single_samples_pred_entropies' \
            % task_id, plt.gcf(), ii, close=not config.show_plots)
        if config.show_plots:
            misc.repair_canvas_and_show_fig(plt.gcf())

        # Plot entropies over whole input space (according to `input_mesh`).
        data.plot_optimal_classification( \
            title='Prior Predicted Class-Boundaries',
            input_mesh=input_mesh,
            mesh_modes=ret_fig.pred_labels.reshape(-1, 1),
            sample_inputs=other_means,
            sketch_components=True, show=False)
        writer.add_figure(\
            'prior/task_%d/single_samples_pred_class_boundaries' % task_id,
            plt.gcf(), ii, close=not config.show_plots)
        if config.show_plots:
            misc.repair_canvas_and_show_fig(plt.gcf())
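
# For the explicit-prior branch above, `putils.sample_diag_gauss` draws one
# weight sample per call. Below is a minimal stand-in showing the underlying
# reparameterized draw (a sketch, assuming `prior_mean`/`prior_std` are lists
# of tensors as documented in `plot_gmm_prior_preds`).
def _demo_sample_diag_gauss():
    import torch

    def sample_diag_gauss(mean, std):
        """Draw one sample from a diagonal Gaussian over a list of tensors."""
        return [m + s * torch.randn_like(s) for m, s in zip(mean, std)]

    prior_mean = [torch.zeros(3, 2), torch.zeros(2)]
    prior_std = [.1 * torch.ones(3, 2), .1 * torch.ones(2)]
    sample = sample_diag_gauss(prior_mean, prior_std)
    print([tuple(s.shape) for s in sample])  # -> [(3, 2), (2,)]
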
def plot_gmm_preds(task_id,
                   data,
                   mnet,
                   hnet,
                   hhnet,
                   device,
                   config,
                   shared,
                   logger,
                   writer,
                   tb_step,
                   draw_samples=False,
                   normal_post=None):
    """Visualize the predictive entropy over the whole input space.

    The advantage of the GMM toy example is that we can visualize quantities
    such as predictions and predictive entropies over an arbitrarily large part
    of the 2D input space.

    Here, we use the current model associated with ``task_id``. All plots are
    logged to tensorboard.

    Args:
        (....): See docstring of function
            :func:`probabilistic.prob_cifar.train_avb.test`.
        task_id (int): ID of current task.
        tb_step (int): Tensorboard step for plots to be logged.
        draw_samples (bool): If ``True``, the method will also draw plots for
            single samples (if model is non-deterministic).
        normal_post (tuple, optional): See docstring of function
            :func:`probabilistic.regression.train_utils.compute_mse`.
    """
    input_mesh = data.get_input_mesh(x1_range=config.gmm_grid_range_1,
                                     x2_range=config.gmm_grid_range_2,
                                     grid_size=config.gmm_grid_size)

    if 'bbb' in shared.experiment_type or 'ewc' in shared.experiment_type:
        disable_lrt = config.disable_lrt_test if \
            hasattr(config, 'disable_lrt_test') else False
        _, ret_fig = pmutils.compute_acc(task_id,
                                         data,
                                         mnet,
                                         hnet,
                                         device,
                                         config,
                                         shared,
                                         split_type=None,
                                         return_entropies=True,
                                         return_pred_labels=True,
                                         deterministic_sampling=True,
                                         disable_lrt=disable_lrt,
                                         in_samples=input_mesh[2],
                                         normal_post=normal_post)
    else:
        assert 'avb' in shared.experiment_type or 'ssge' in \
            shared.experiment_type
        _, ret_fig = pcutils.compute_acc(task_id,
                                         data,
                                         mnet,
                                         hnet,
                                         hhnet,
                                         device,
                                         config,
                                         shared,
                                         split_type=None,
                                         return_entropies=True,
                                         return_pred_labels=True,
                                         deterministic_sampling=True,
                                         in_samples=input_mesh[2])

    # The means of other tasks.
    other_means = np.concatenate([dh.means for dh in shared.all_dhandlers],
                                 axis=0)

    # Plot entropies over whole input space (according to `input_mesh`).
    data.plot_uncertainty_map( \
        title='Entropy of predictive distribution',
        input_mesh=input_mesh, uncertainties=ret_fig.entropies.reshape(-1, 1),
        sample_inputs=other_means,
        sketch_components=True, show=False)
    writer.add_figure('task_%d/pred_entropies' % task_id,
                      plt.gcf(),
                      tb_step,
                      close=not config.show_plots)
    if config.show_plots:
        misc.repair_canvas_and_show_fig(plt.gcf())

    # Plot entropies over whole input space (according to `input_mesh`).
    data.plot_optimal_classification(title='Predicted Class-Boundaries',
                                     input_mesh=input_mesh,
                                     mesh_modes=ret_fig.pred_labels.reshape(
                                         -1, 1),
                                     sample_inputs=other_means,
                                     sketch_components=True,
                                     show=False)
    writer.add_figure('task_%d/pred_class_boundaries' % task_id,
                      plt.gcf(),
                      tb_step,
                      close=not config.show_plots)
    if config.show_plots:
        misc.repair_canvas_and_show_fig(plt.gcf())

    # If not deterministic, plot single weight samples.
    # TODO We could also plot them for EWC
    if 'bbb' in shared.experiment_type and not config.mean_only or \
            ('avb' in shared.experiment_type or \
             'ssge' in shared.experiment_type) and hnet is not None and \
            draw_samples:
        for ii in range(10):
            if 'bbb' in shared.experiment_type:
                _, ret_fig = pmutils.compute_acc(
                    task_id,
                    data,
                    mnet,
                    hnet,
                    device,
                    config,
                    shared,
                    split_type=None,
                    return_entropies=True,
                    return_pred_labels=True,
                    deterministic_sampling=False,
                    disable_lrt=config.disable_lrt_test,
                    in_samples=input_mesh[2],
                    num_w_samples=1)
            else:
                assert 'avb' in shared.experiment_type or \
                    'ssge' in shared.experiment_type
                _, ret_fig = pcutils.compute_acc(task_id,
                                                 data,
                                                 mnet,
                                                 hnet,
                                                 hhnet,
                                                 device,
                                                 config,
                                                 shared,
                                                 split_type=None,
                                                 return_entropies=True,
                                                 return_pred_labels=True,
                                                 deterministic_sampling=False,
                                                 in_samples=input_mesh[2],
                                                 num_w_samples=1)

            # Plot entropies over whole input space (according to `input_mesh`).
            data.plot_uncertainty_map( \
                title='Entropy of predictive distribution',
                input_mesh=input_mesh,
                uncertainties=ret_fig.entropies.reshape(-1, 1),
                sample_inputs=other_means,
                sketch_components=True, show=False)
            writer.add_figure('task_%d/single_samples_pred_entropies' \
                % task_id, plt.gcf(), ii, close=not config.show_plots)
            if config.show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())

            # Plot entropies over whole input space (according to `input_mesh`).
            data.plot_optimal_classification(
                title='Predicted Class-Boundaries',
                input_mesh=input_mesh,
                mesh_modes=ret_fig.pred_labels.reshape(-1, 1),
                sample_inputs=other_means,
                sketch_components=True,
                show=False)
            writer.add_figure(\
                'task_%d/single_samples_pred_class_boundaries' % task_id,
                plt.gcf(), ii, close=not config.show_plots)
            if config.show_plots:
                misc.repair_canvas_and_show_fig(plt.gcf())
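
# The entropies visualized in `plot_gmm_preds` are those of the
# model-averaged predictive distribution. A self-contained sketch of that
# computation on dummy softmax outputs (illustrative shapes; in the actual
# code, `compute_acc` returns these entropies precomputed).
def _demo_predictive_entropy():
    import numpy as np

    def predictive_entropy(probs):
        """Entropy (in nats) of the mean predictive distribution.

        `probs` has shape (num_w_samples, num_inputs, num_classes).
        """
        mean_probs = probs.mean(axis=0)
        return -(mean_probs * np.log(mean_probs + 1e-12)).sum(axis=1)

    rng = np.random.RandomState(0)
    logits = rng.randn(10, 4, 3)  # 10 weight samples, 4 inputs, 3 classes.
    probs = np.exp(logits) / np.exp(logits).sum(axis=2, keepdims=True)
    print(predictive_entropy(probs))  # One entropy value per input.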
Example #7
def evaluate(task_id,
             data,
             mnet,
             hnet,
             hhnet,
             dis,
             device,
             config,
             shared,
             logger,
             writer,
             train_iter=None):
    """Evaluate the training progress.

    Evaluate the performance of the network on a single task (that is currently
    being trained) on the validation set.

    Note, if no validation set is available, the training set will be used
    instead.

    Args:
        (....): See docstring of method
            :func:`probabilistic.prob_cifar.train_avb.evaluate`.
    """
    # FIXME Code below almost identical to
    # `probabilistic.regression.train_bbb.evaluate`.
    if train_iter is None:
        logger.info('# Evaluating training ...')
    else:
        logger.info('# Evaluating network on task %d ' % (task_id + 1) +
                    'before running training step %d ...' % train_iter)

    pcu.set_train_mode(False, mnet, hnet, hhnet, dis)

    with torch.no_grad():
        # Note, if no validation set exists, we use the training data to compute
        # the MSE (note, test data may contain out-of-distribution data in our
        # setup).
        split_type = 'train' if data.num_val_samples == 0 else 'val'
        if split_type == 'train':
            logger.debug('Eval - Using training set as no validation set is ' +
                         'available.')

        mse_val, val_struct = train_utils.compute_mse(task_id,
                                                      data,
                                                      mnet,
                                                      hnet,
                                                      device,
                                                      config,
                                                      shared,
                                                      hhnet=hhnet,
                                                      split_type=split_type)
        ident = 'training' if split_type == 'train' else 'validation'

        logger.info('Eval - Mean MSE on %s set: %f (std: %g).' %
                    (ident, mse_val, val_struct.mse_vals.std()))

        # In contrast, we visualize uncertainty using the test set.
        mse_test, test_struct = train_utils.compute_mse(
            task_id,
            data,
            mnet,
            hnet,
            device,
            config,
            shared,
            hhnet=hhnet,
            split_type='test',
            return_dataset=True,
            return_predictions=True)
        logger.debug('Eval - Mean MSE on test set: %f (std: %g).' %
                     (mse_test.mean(), mse_test.std()))

        if config.show_plots or train_iter is not None:
            train_utils.plot_predictive_distribution(data,
                                                     test_struct.inputs,
                                                     test_struct.predictions,
                                                     show_raw_pred=True,
                                                     figsize=(10, 4),
                                                     show=train_iter is None)
            if train_iter is not None:
                writer.add_figure('task_%d/predictions' % task_id,
                                  plt.gcf(),
                                  train_iter,
                                  close=not config.show_plots)
                if config.show_plots:
                    utils.repair_canvas_and_show_fig(plt.gcf())

                writer.add_scalar('eval/task_%d/val_mse' % task_id,
                                  mse_val.mean(), train_iter)
                writer.add_scalar('eval/task_%d/test_mse' % task_id,
                                  mse_test.mean(), train_iter)

        # FIXME Code below copied from
        # `probabilistic.prob_cifar.train_avb.evaluate`.
        ### Compute discriminator accuracy.
        if dis is not None and hnet is not None:
            hnet_theta = None
            if hhnet is not None:
                hnet_theta = hhnet.forward(cond_id=task_id)

            # FIXME Is it ok if I only look at how samples from the current
            # implicit distribution are classified?
            dis_out, dis_inputs = pcu.process_dis_batch(config,
                                                        shared,
                                                        config.val_sample_size,
                                                        device,
                                                        dis,
                                                        hnet,
                                                        hnet_theta,
                                                        dist=None)
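            # Accuracy here: the fraction of the `val_sample_size` samples
            # that receive a positive discriminator logit.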
            dis_acc = (dis_out > 0).sum().detach().cpu().numpy() / \
                config.val_sample_size * 100.

            logger.debug('Eval - Discriminator accuracy: %.2f%%.' % (dis_acc))
            writer.add_scalar('eval/task_%d/dis_acc' % task_id, dis_acc,
                              train_iter)

            # FIXME Summary results should be written in the test method after
            # training on a task has finished (note, eval is not guaranteed to
            # be
            # called after or even during training). But I just want to get an
            # overview.
            s = shared.summary
            s['aa_acc_dis'][task_id] = dis_acc
            s['aa_acc_avg_dis'] = np.mean(s['aa_acc_dis'][:(task_id + 1)])

            # Visualize weight samples.
            # FIXME A bit hacky.
            w_samples = dis_inputs
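            # If batch statistics were concatenated onto the discriminator
            # inputs, the raw weight samples occupy the second half of each row.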
            if config.use_batchstats:
                w_samples = dis_inputs[:, (dis_inputs.shape[1] // 2):]
            pcu.visualize_implicit_dist(config,
                                        task_id,
                                        writer,
                                        train_iter,
                                        w_samples,
                                        figsize=(10, 6))

        logger.info('# Evaluating training ... Done')
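Note: the discriminator accuracy computed above is simply the fraction of sampled weight vectors that receive a positive logit (i.e., are classified as "real"), expressed as a percentage of `config.val_sample_size`. A minimal standalone sketch of that computation, with made-up logits:

import torch

dis_out = torch.tensor([1.2, -0.3, 0.7, 2.1])  # hypothetical discriminator logits
val_sample_size = dis_out.numel()
# Count positive logits and convert to a percentage, as in the snippet above.
dis_acc = (dis_out > 0).sum().item() / val_sample_size * 100.
print('%.2f%%' % dis_acc)  # -> 75.00%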
Example #8
def plot_predictive_distributions(config,
                                  writer,
                                  data_handlers,
                                  inputs,
                                  preds_mean,
                                  preds_std,
                                  save_fig=True,
                                  publication_style=False):
    """Plot the predictive distribution of several tasks into one plot.

    Args:
        config: Command-line arguments.
        writer: Tensorboard summary writer.
        data_handlers: A set of data loaders.
        inputs: A list of arrays containing the x values per task.
        preds_mean: The mean predictions corresponding to each task in `inputs`.
        preds_std: The std of all predictions in `preds_mean`.
        save_fig: Whether the figure should be saved in the output folder.
        publication_style: Whether plots should be made in publication style.
    """
    num_tasks = len(data_handlers)
    assert (len(inputs) == num_tasks)
    assert (len(preds_mean) == num_tasks)
    assert (len(preds_std) == num_tasks)

    colors = utils.get_colorbrewer2_colors(family='Dark2')
    if num_tasks > len(colors):
        warn('Changing to automatic color scheme as we don\'t have ' +
             'as many manual colors as tasks.')
        colors = cm.rainbow(np.linspace(0, 1, num_tasks))

    fig, axes = plt.subplots(figsize=(12, 6))

    if publication_style:
        ts, lw, ms = 60, 15, 10  # text fontsize, line width, marker size
    else:
        ts, lw, ms = 12, 5, 8

    # The default number of matplotlib axis ticks is usually too high for most
    # plots.
    plt.locator_params(axis='y', nbins=2)
    plt.locator_params(axis='x', nbins=6)

    for i, data in enumerate(data_handlers):
        assert (isinstance(data, ToyRegression))

        # Determine the range in which to plot the ground-truth function
        # values.
        train_range = data.train_x_range
        range_offset = (train_range[1] - train_range[0]) * 0.05
        sample_x, sample_y = data._get_function_vals( \
            x_range=[train_range[0]-range_offset, train_range[1]+range_offset])
        #sample_x, sample_y = data._get_function_vals()

        #plt.plot(sample_x, sample_y, color='k', label='f(x)',
        #         linestyle='dashed', linewidth=.5)
        plt.plot(sample_x,
                 sample_y,
                 color='k',
                 linestyle='dashed',
                 linewidth=lw / 7.)

        train_x = data.get_train_inputs().squeeze()
        train_y = data.get_train_outputs().squeeze()

        if i == 0:
            plt.plot(train_x,
                     train_y,
                     'o',
                     color='k',
                     label='Training Data',
                     markersize=ms)
        else:
            plt.plot(train_x, train_y, 'o', color='k', markersize=ms)

        plt.plot(inputs[i],
                 preds_mean[i],
                 color=colors[i],
                 label='Task %d' % (i + 1),
                 lw=lw / 3.)

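        # Shade the 1-, 2- and 3-sigma bands of the predictive distribution;
        # wider bands get a lower alpha so the inner band remains visible.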
        plt.fill_between(inputs[i],
                         preds_mean[i] + preds_std[i],
                         preds_mean[i] - preds_std[i],
                         color=colors[i],
                         alpha=0.3)
        plt.fill_between(inputs[i],
                         preds_mean[i] + 2. * preds_std[i],
                         preds_mean[i] - 2. * preds_std[i],
                         color=colors[i],
                         alpha=0.2)
        plt.fill_between(inputs[i],
                         preds_mean[i] + 3. * preds_std[i],
                         preds_mean[i] - 3. * preds_std[i],
                         color=colors[i],
                         alpha=0.1)

    if publication_style:
        axes.grid(False)
        axes.set_facecolor('w')
        axes.set_ylim([-2.5, 3])
        axes.axhline(y=axes.get_ylim()[0], color='k', lw=lw)
        axes.axvline(x=axes.get_xlim()[0], color='k', lw=lw)
        if len(data_handlers) == 3:
            plt.yticks([-2, 0, 2], fontsize=ts)
            plt.xticks([-3, 0, 3], fontsize=ts)
        else:
            # Note: `Tick.label` was removed in newer matplotlib versions;
            # `label1` refers to the same (primary) tick label.
            for tick in axes.yaxis.get_major_ticks():
                tick.label1.set_fontsize(ts)
            for tick in axes.xaxis.get_major_ticks():
                tick.label1.set_fontsize(ts)
        axes.tick_params(axis='both',
                         length=lw,
                         direction='out',
                         width=lw / 2.)
        if config.train_from_scratch:
            plt.title('training from scratch', fontsize=ts, pad=ts)
        elif config.beta == 0:
            plt.title('fine-tuning', fontsize=ts, pad=ts)
        else:
            plt.title('CL with prob. hnet reg.', fontsize=ts, pad=ts)
    else:
        plt.legend()
        plt.title('Predictive distributions', fontsize=ts, pad=ts)

    plt.xlabel('$x$', fontsize=ts)
    plt.ylabel('$y$', fontsize=ts)

    if save_fig:
        plt.savefig(os.path.join(config.out_dir, 'pred_dists_%d' % num_tasks),
                    bbox_inches='tight')

    writer.add_figure('test/pred_dists',
                      plt.gcf(),
                      num_tasks,
                      close=not config.show_plots)
    if config.show_plots:
        utils.repair_canvas_and_show_fig(plt.gcf())
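The layered `fill_between` calls above are the core of this visualization. Here is a self-contained sketch of the same shading idiom, using synthetic mean/std arrays (all values are made up for illustration):

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-4., 4., 100)
mean = x**3 / 10.              # synthetic predictive mean
std = 0.5 + 0.1 * np.abs(x)    # synthetic predictive std

plt.plot(x, mean, color='b')
for k, alpha in zip((1., 2., 3.), (0.3, 0.2, 0.1)):
    # Wider uncertainty bands get lighter shading.
    plt.fill_between(x, mean + k * std, mean - k * std, color='b', alpha=alpha)
plt.show()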
Example #9
def generate_1d_tasks(show_plots=True,
                      data_random_seed=42,
                      writer=None,
                      task_set=1):
    """Generate a set of tasks for 1D regression.

    Args:
        show_plots: Visualize the generated datasets.
        data_random_seed: Random seed that should be applied to the
            synthetic data generation.
        writer: Tensorboard writer, in case plots should be logged.
        task_set (int): The set of tasks to be used. All sets are hard-coded
            inside this function.

    Returns:
        (tuple): Tuple containing:

        - **data_handlers**: A data handler for each task (instance of class
            :class:`data.special.regression1d_data.ToyRegression`).
        - **num_tasks**: Number of generated tasks.
    """
    if task_set == 0:
        # As a benchmark, we use a regression task that is common in the
        # uncertainty literature:
        #   y = x**3 + eps, where eps ~ N(0, 9*I).
        # See, for instance:
        #   https://arxiv.org/pdf/1703.01961.pdf

        # How far outside the training data regime do we want to evaluate
        # predictions?
        test_offset = 1.5

        map_funcs = [lambda x: (x**3.)]
        num_tasks = len(map_funcs)
        x_domains = [[-4, 4]]
        std = 3  # 3**2 == 9
        num_train = 20

        # Range of pred. dist. plots.
        test_domains = [[
            x_domains[0][0] - test_offset, x_domains[0][1] + test_offset
        ]]

    elif task_set == 1:
        #test_offset = 1
        map_funcs = [
            lambda x: (x + 3.), lambda x: 2. * np.power(x, 2) - 1,
            lambda x: np.power(x - 3., 3)
        ]
        num_tasks = len(map_funcs)
        x_domains = [[-4, -2], [-1, 1], [2, 4]]
        std = .05

        #test_domains = [[-4.1, -0.5], [-2.5,2.5], [.5, 4.1]]
        test_domains = [[-4.1, 4.1], [-4.1, 4.1], [-4.1, 4.1]]

        #m = 32 # magnitude
        #s = -.25 * np.pi
        #map_funcs = [lambda x : m*np.pi * np.sin(x + s),
        #             lambda x : m*np.pi * np.sin(x + s),
        #             lambda x : m*np.pi * np.sin(x + s)]
        #x_domains = [[-2*np.pi, -1*np.pi], [-0.5*np.pi, 0.5*np.pi],
        #             [1*np.pi, 2*np.pi]]
        #x_domains_test = [[-2*np.pi, 2*np.pi], [-2*np.pi, 2*np.pi],
        #                  [-2*np.pi, 2*np.pi]]
        #std = 3

        num_train = 20

    elif task_set == 2:
        map_funcs = [
            lambda x: np.power(x + 3., 3), lambda x: 2. * np.power(x, 2) - 1,
            lambda x: -np.power(x - 3., 3)
        ]
        num_tasks = len(map_funcs)
        x_domains = [[-4, -2], [-1, 1], [2, 4]]
        std = .1

        test_domains = [[-4.1, 4.1], [-4.1, 4.1], [-4.1, 4.1]]

        num_train = 20

    elif task_set == 3:
        # Same as task set 2, but with less aleatoric uncertainty.

        map_funcs = [
            lambda x: np.power(x + 3., 3), lambda x: 2. * np.power(x, 2) - 1,
            lambda x: -np.power(x - 3., 3)
        ]
        num_tasks = len(map_funcs)
        x_domains = [[-4, -2], [-1, 1], [2, 4]]
        std = .05

        test_domains = [[-4.1, 4.1], [-4.1, 4.1], [-4.1, 4.1]]

        num_train = 20

    else:
        raise NotImplementedError('Set of tasks "%d" unknown!' % task_set)

    dhandlers = []
    for i in range(num_tasks):
        print('Generating task %d.' % i)
        #test_inter = [x_domains[i][0] - test_offset,
        #              x_domains[i][1] + test_offset]
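        # Note: validation data is drawn from the training interval, while the
        # test interval may extend beyond it to probe out-of-distribution
        # inputs.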
        dhandlers.append(
            ToyRegression(train_inter=x_domains[i],
                          num_train=num_train,
                          test_inter=test_domains[i],
                          num_test=50,
                          val_inter=x_domains[i],
                          num_val=50,
                          map_function=map_funcs[i],
                          std=std,
                          rseed=data_random_seed))

        if writer is not None:
            dhandlers[-1].plot_dataset(show=False)
            writer.add_figure('task_%d/dataset' % i,
                              plt.gcf(),
                              close=not show_plots)
            if show_plots:
                utils.repair_canvas_and_show_fig(plt.gcf())

        elif show_plots:
            dhandlers[-1].plot_dataset()

    return dhandlers, num_tasks
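A hypothetical usage sketch (assuming the module's imports are in scope): generate the default task set and inspect each handler's training range via the `train_x_range` attribute used elsewhere in this file:

dhandlers, num_tasks = generate_1d_tasks(show_plots=False, task_set=1)
for i, dh in enumerate(dhandlers):
    print('Task %d: train range %s' % (i, str(dh.train_x_range)))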
Example #10
def plot_mse(config,
             writer,
             num_tasks,
             current_mse,
             during_mse=None,
             baselines=None,
             save_fig=True,
             summary_label='test/mse'):
    """Produce a scatter plot that shows the current and immediate mse values
    (and maybe a set of baselines) of each task. This visualization helps to
    understand the impact of forgetting.

    Args:
        config: Command-line arguments.
        writer: Tensorboard summary writer.
        num_tasks: Number of tasks.
        current_mse: Array of MSE values currently achieved.
        during_mse (optional): Array of MSE values achieved right after
            training on the corresponding task.
        baselines (optional): A dictionary of label names mapping onto arrays.
            Can be used to plot additional baselines.
        save_fig: Whether the figure should be saved in the output folder.
        summary_label: Label used for the figure when writing it to tensorboard.
    """
    x_vals = np.arange(1, num_tasks + 1)

    num_plots = 1
    if during_mse is not None:
        num_plots += 1
    if baselines is not None:
        num_plots += len(baselines.keys())

    colors = utils.get_colorbrewer2_colors(family='Dark2')
    if num_plots > len(colors):
        warn('Changing to automatic color scheme as we don\'t have ' +
             'as many manual colors as plotted series.')
        colors = cm.rainbow(np.linspace(0, 1, num_plots))

    fig, axes = plt.subplots(figsize=(10, 6))
    plt.title('Current MSE on each task')

    plt.scatter(x_vals, current_mse, color=colors[0], label='Current Val MSE')

    if during_mse is not None:
        plt.scatter(x_vals,
                    during_mse,
                    label='During MSE',
                    color=colors[1],
                    marker='*')

    if baselines is not None:
        for i, (label, vals) in enumerate(baselines.items()):
            plt.scatter(x_vals,
                        vals,
                        label=label,
                        color=colors[2 + i],
                        marker='x')

    plt.ylabel('MSE')
    plt.xlabel('Task')
    plt.xticks(x_vals)
    plt.legend()

    if save_fig:
        plt.savefig(os.path.join(config.out_dir, 'mse_%d' % num_tasks),
                    bbox_inches='tight')

    writer.add_figure(summary_label,
                      plt.gcf(),
                      num_tasks,
                      close=not config.show_plots)
    if config.show_plots:
        utils.repair_canvas_and_show_fig(plt.gcf())
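A hypothetical call with made-up MSE arrays for three tasks; `config` and `writer` are assumed to come from the surrounding training script:

current = np.array([0.021, 0.054, 0.012])  # MSE after training on all tasks
during = np.array([0.011, 0.023, 0.012])   # MSE right after each task
plot_mse(config, writer, num_tasks=3, current_mse=current, during_mse=during,
         baselines={'From scratch': np.array([0.010, 0.012, 0.011])})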