Example #1
    def create_plots(self, train_recorder, save_dir):
        from alfred.utils.plots import create_fig, plot_curves
        import matplotlib.pyplot as plt
        # remove_nones is assumed to be available at module level (its import is not shown in these excerpts)

        fig, axes = create_fig((3, 3))
        plot_curves(
            axes[0, 0],
            xs=[remove_nones(train_recorder.tape['update_i'])],
            ys=[remove_nones(train_recorder.tape['reconstruction_loss'])],
            xlabel='update_i',
            ylabel='reconstruction_loss')
        plot_curves(axes[0, 1],
                    xs=[remove_nones(train_recorder.tape['update_i'])],
                    ys=[remove_nones(train_recorder.tape['prior_loss'])],
                    xlabel='update_i',
                    ylabel='prior_loss')
        plot_curves(axes[0, 2],
                    xs=[remove_nones(train_recorder.tape['update_i'])],
                    ys=[remove_nones(train_recorder.tape['total_loss'])],
                    xlabel='update_i',
                    ylabel='total_loss')
        plot_curves(axes[1, 0],
                    xs=[remove_nones(train_recorder.tape['update_i'])],
                    ys=[remove_nones(train_recorder.tape['lr'])],
                    xlabel='update_i',
                    ylabel='lr')

        plt.tight_layout()

        fig.savefig(str(save_dir / 'graphs.png'))
        plt.close(fig)
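The tape lists recorded by alfred's Recorder can contain None placeholders for updates where a metric was not logged, which is why every series above is passed through remove_nones first. A minimal sketch of such a filter, assuming the tape values are plain Python lists (the actual alfred helper may differ):

def remove_nones(values):
    # Drop None placeholders so that matplotlib only receives numeric points
    return [v for v in values if v is not None]

Under that assumption, remove_nones(train_recorder.tape['lr']) returns only the learning-rate values that were actually recorded.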
Example #2
    def save_training_graphs(self, train_recorder, save_dir):
        from alfred.utils.plots import plot_curves, create_fig
        import matplotlib.pyplot as plt

        # Losses

        fig, axes = create_fig((1, 1))
        plot_curves(axes,
                    ys=[train_recorder.tape['d_loss']],
                    xs=[train_recorder.tape['episode']],
                    xlabel="Episode",
                    ylabel="d_loss")

        fig.savefig(str(save_dir / 'losses.png'))
        plt.close(fig)

        # True Returns
        fig, axes = create_fig((1, 2))
        fig.suptitle('True returns')
        plot_curves(axes[0],
                    ys=[train_recorder.tape['return']],
                    xs=[train_recorder.tape['episode']],
                    xlabel="Episode", ylabel="Mean Return")
        plot_curves(axes[1],
                    ys=[train_recorder.tape['eval_return']],
                    xs=[train_recorder.tape['episode']],
                    xlabel="Episode", ylabel="Mean Eval Return")
        fig.savefig(str(save_dir / 'true_returns.png'))
        plt.close(fig)
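Every save_training_graphs variant in this listing follows the same recipe: build a figure with create_fig, draw one or more tape series with plot_curves, save the figure, and close it. Below is a stripped-down, standalone version of the loss panel above, assuming alfred is installed and that the tape values are plain lists of numbers:

from pathlib import Path
from alfred.utils.plots import create_fig, plot_curves
import matplotlib.pyplot as plt

# Minimal stand-in for train_recorder.tape: metric name -> list of recorded values
tape = {'episode': [1, 2, 3, 4], 'd_loss': [0.9, 0.5, 0.35, 0.3]}

fig, axes = create_fig((1, 1))
plot_curves(axes,
            ys=[tape['d_loss']],
            xs=[tape['episode']],
            xlabel="Episode",
            ylabel="d_loss")

save_dir = Path('graphs')
save_dir.mkdir(exist_ok=True)
fig.savefig(str(save_dir / 'losses.png'))
plt.close(fig)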
Example #3
    def save_training_graphs(self, train_recorder, save_dir):
        from alfred.utils.plots import create_fig, plot_curves
        import matplotlib.pyplot as plt

        # Loss and return

        fig, axes = create_fig((3, 1))
        plot_curves(axes[0],
                    ys=[train_recorder.tape['loss']],
                    xs=[train_recorder.tape['total_transitions']],
                    xlabel='Transitions',
                    ylabel="loss")
        plot_curves(axes[1],
                    ys=[train_recorder.tape['return']],
                    xs=[train_recorder.tape['total_transitions']],
                    xlabel="Transitions",
                    ylabel="return")
        plot_curves(axes[2],
                    ys=[train_recorder.tape['eval_return']],
                    xs=[train_recorder.tape['total_transitions']],
                    xlabel="Transitions",
                    ylabel="Eval return")

        fig.savefig(str(save_dir / 'figures.png'))
        plt.close(fig)
Example #4
    def save_training_graphs(self, train_recorder, save_dir):
        from alfred.utils.plots import create_fig, plot_curves
        import matplotlib.pyplot as plt

        # Losses

        fig, axes = create_fig((1, 1))
        plot_curves(axes,
                    ys=[train_recorder.tape['d_loss']],
                    xs=[train_recorder.tape['episode']],
                    xlabel="Episode",
                    ylabel="d_loss")

        fig.savefig(str(save_dir / 'losses.png'))
        plt.close(fig)

        # True Returns
        fig, axes = create_fig((1, 2))
        fig.suptitle('True returns')
        plot_curves(axes[0],
                    ys=[train_recorder.tape['return']],
                    xs=[train_recorder.tape['episode']],
                    xlabel="Episode", ylabel="Mean Return")
        plot_curves(axes[1],
                    ys=[train_recorder.tape['eval_return']],
                    xs=[train_recorder.tape['episode']],
                    xlabel="Episode", ylabel="Mean Eval Return")
        fig.savefig(str(save_dir / 'true_returns.png'))
        plt.close(fig)

        # Accuracies
        to_plot = ('recall', 'specificity', 'precision', 'accuracy', 'F1')
        recorded = [k for k in to_plot if k in train_recorder.tape]
        if recorded:
            fig, axes = create_fig((1, 1))
            ys = [train_recorder.tape[key] for key in recorded]
            plot_curves(axes, ys=ys,
                        xs=[train_recorder.tape['episode']] * len(ys),
                        xlabel='Episode',
                        ylabel='-',
                        labels=recorded)
            fig.savefig(str(save_dir / 'Accuracy.png'))
            plt.close(fig)
Example #5
    def save_training_graphs(self, train_recorder, save_dir):
        from alfred.utils.plots import create_fig, plot_curves
        import matplotlib.pyplot as plt

        # Losses

        fig, axes = create_fig((1, 1))
        plot_curves(axes,
                    ys=[train_recorder.tape['d_loss']],
                    xs=[train_recorder.tape['episode']],
                    xlabel="Episode",
                    ylabel="d_loss")

        fig.savefig(str(save_dir / 'losses.png'))
        plt.close(fig)

        # True Returns
        fig, axes = create_fig((1, 2))
        fig.suptitle('True returns')
        plot_curves(axes[0],
                    ys=[train_recorder.tape['return']],
                    xs=[train_recorder.tape['episode']],
                    xlabel="Episode",
                    ylabel="Mean Return")
        plot_curves(axes[1],
                    ys=[train_recorder.tape['eval_return']],
                    xs=[train_recorder.tape['episode']],
                    xlabel="Episode",
                    ylabel="Mean Eval Return")
        fig.savefig(str(save_dir / 'true_returns.png'))
        plt.close(fig)

        # Estimated Returns
        to_plot = ('IRLAverageEntReward', 'IRLAverageF', 'IRLAverageLogPi',
                   'IRLMedianLogPi', 'ExpertIRLAverageEntReward',
                   'ExpertIRLAverageF', 'ExpertIRLAverageLogPi',
                   'ExpertIRLMedianLogPi')
        if any([k in train_recorder.tape for k in to_plot]):
            fig, axes = create_fig((2, 5))
            for i, key in enumerate(to_plot):
                if key in train_recorder.tape:
                    ax = axes[i // 5, i % 5]
                    plot_curves(ax,
                                ys=[train_recorder.tape[key]],
                                xs=[train_recorder.tape['episode']],
                                xlabel='Episode',
                                ylabel=key)
            fig.savefig(str(save_dir / 'estimated_rewards.png'))
            plt.close(fig)

        # Accuracies
        to_plot = ('recall', 'specificity', 'precision', 'accuracy', 'F1')
        recorded = [k for k in to_plot if k in train_recorder.tape]
        if recorded:
            fig, axes = create_fig((1, 1))
            ys = [train_recorder.tape[key] for key in recorded]
            plot_curves(axes,
                        ys=ys,
                        xs=[train_recorder.tape['episode']] * len(ys),
                        xlabel='Episode',
                        ylabel='-',
                        labels=recorded)
            fig.savefig(str(save_dir / 'Accuracy.png'))
            plt.close(fig)
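The estimated-returns block above places up to eight panels on a (2, 5) grid with axes[i // 5, i % 5]; the same row-major mapping reappears in Example #6 as axes[i // axes_shape[1], i % axes_shape[1]]. A quick standalone check of that arithmetic:

# Row-major placement of 8 panels on a 2 x 5 grid of axes
n_cols = 5
for i in range(8):
    row, col = i // n_cols, i % n_cols
    print(f"panel {i} -> axes[{row}, {col}]")
# panels 0-4 fill the first row, panels 5-7 start the second row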
Example #6
def _make_benchmark_learning_figure(x_data,
                                    y_data,
                                    x_metric,
                                    y_metric,
                                    y_error_bars,
                                    storage_dirs,
                                    save_dir,
                                    logger,
                                    n_labels=np.inf,
                                    visuals_file=None,
                                    additional_curves_file=None):
    # Initialize containers

    y_data_means = OrderedDict()
    y_data_err_up = OrderedDict()
    y_data_err_down = OrderedDict()
    long_labels = OrderedDict()
    titles = OrderedDict()
    x_axis_titles = OrderedDict()
    y_axis_titles = OrderedDict()
    labels = OrderedDict()
    colors = OrderedDict()
    markers = OrderedDict()

    for outer_key in y_data:
        y_data_means[outer_key] = OrderedDict()
        y_data_err_up[outer_key] = OrderedDict()
        y_data_err_down[outer_key] = OrderedDict()

    # Initialize figure

    n_graphs = len(y_data.keys())

    if n_graphs == 3:
        axes_shape = (1, 3)

    elif n_graphs > 1:
        i_max = int(np.ceil(np.sqrt(len(y_data.keys()))))
        axes_shape = (int(np.ceil(len(y_data.keys()) / i_max)), i_max)
    else:
        axes_shape = (1, 1)

    # Creates figure

    gs = gridspec.GridSpec(*axes_shape)
    fig = plt.figure(figsize=(8 * axes_shape[1], 4 * axes_shape[0]))

    # Compute means and error bars for each inner_key curve from the raw data

    for i, outer_key in enumerate(y_data.keys()):
        for inner_key in y_data[outer_key].keys():
            # assumes all seeds share the same x_data, so keep only the first series
            x_data[outer_key][inner_key] = x_data[outer_key][inner_key][0]

            if y_error_bars == "stderr":
                samples = np.stack(y_data[outer_key][inner_key], axis=-1)  # (n_time_steps, n_samples)
                y_data_means[outer_key][inner_key] = samples.mean(-1)
                # standard error of the mean: std across samples divided by sqrt(n_samples)
                y_data_err_up[outer_key][inner_key] = samples.std(-1) / samples.shape[-1] ** 0.5
                y_data_err_down[outer_key][inner_key] = y_data_err_up[outer_key][inner_key]

            elif y_error_bars == "bootstrapped_CI":
                y_data_samples = np.stack(
                    y_data[outer_key][inner_key],
                    axis=-1)  # dim=0 is accross time (n_time_steps, n_samples)
                mean, err_up, err_down = get_95_confidence_interval_of_sequence(
                    list_of_samples=y_data_samples, method=y_error_bars)
                y_data_means[outer_key][inner_key] = mean
                y_data_err_up[outer_key][inner_key] = err_up
                y_data_err_down[outer_key][inner_key] = err_down

            else:
                raise NotImplementedError

        long_labels[outer_key] = list(y_data_means[outer_key].keys())

        # Limits the number of labels to be displayed (only displays labels of n_labels best experiments)

        if n_labels < np.inf:
            mean_over_entire_curves = np.array(
                [array.mean() for array in y_data_means[outer_key].values()])
            n_max_idxs = (-mean_over_entire_curves).argsort()[:n_labels]

            for k in range(len(long_labels[outer_key])):
                if k not in n_max_idxs:
                    long_labels[outer_key][k] = None

        # Selects right ax object

        if axes_shape == (1, 1):
            current_ax = fig.add_subplot(gs[0, 0])
        elif any(np.array(axes_shape) == 1):
            current_ax = fig.add_subplot(gs[0, i])
        else:
            current_ax = fig.add_subplot(gs[i // axes_shape[1],
                                            i % axes_shape[1]])

        # Collect algorithm names

        if all([
                type(long_label) is pathlib.PosixPath
                for long_label in long_labels[outer_key]
        ]):
            algs = []
            for path in long_labels[outer_key]:
                _, _, alg, _, _ = DirectoryTree.extract_info_from_storage_name(
                    path.name)
                algs.append(alg)

        # Loads visuals dictionaries

        if visuals_file is not None:
            visuals = load_dict_from_json(visuals_file)
        else:
            visuals = None

        # Loads additional curves file

        if additional_curves_file is not None:
            additional_curves = load_dict_from_json(additional_curves_file)
        else:
            additional_curves = None

        # Sets visuals

        if type(visuals) is dict and 'titles_dict' in visuals.keys():
            titles[outer_key] = visuals['titles_dict'][outer_key]
        else:
            titles[outer_key] = outer_key

        if type(visuals) is dict and 'axis_titles_dict' in visuals.keys():
            x_axis_titles[outer_key] = visuals['axis_titles_dict'][x_metric]
            y_axis_titles[outer_key] = visuals['axis_titles_dict'][y_metric]
        else:
            x_axis_titles[outer_key] = x_metric
            y_axis_titles[outer_key] = y_metric

        if type(visuals) is dict and 'labels_dict' in visuals.keys():
            labels[outer_key] = [
                visuals['labels_dict'][inner_key]
                for inner_key in y_data_means[outer_key].keys()
            ]
        else:
            labels[outer_key] = long_labels[outer_key]

        if type(visuals) is dict and 'colors_dict' in visuals.keys():
            colors[outer_key] = [
                visuals['colors_dict'][inner_key]
                for inner_key in y_data_means[outer_key].keys()
            ]
        else:
            colors[outer_key] = [None for _ in long_labels[outer_key]]

        if type(visuals) is dict and 'markers_dict' in visuals.keys():
            markers[outer_key] = [
                visuals['markers_dict'][inner_key]
                for inner_key in y_data_means[outer_key].keys()
            ]
        else:
            markers[outer_key] = [None for _ in long_labels[outer_key]]

        logger.info(
            f"Graph for {outer_key}:\n\tlabels={labels[outer_key]}\n\tcolors={colors[outer_key]}\n\tmarkers={markers[outer_key]}"
        )

        if additional_curves_file is not None:
            hlines = additional_curves['hlines'][outer_key]
            n_lines = len(hlines)
        else:
            hlines = None
            n_lines = 0

        # Plots the curves

        plot_curves(
            current_ax,
            xs=list(x_data[outer_key].values()),
            ys=list(y_data_means[outer_key].values()),
            fill_up=list(y_data_err_up[outer_key].values()),
            fill_down=list(y_data_err_down[outer_key].values()),
            labels=labels[outer_key],
            colors=colors[outer_key],
            markers=markers[outer_key],
            xlabel=x_axis_titles[outer_key],
            ylabel=y_axis_titles[outer_key] if i == 0 else "",
            title=titles[outer_key].upper(),
            add_legend=True if i == (len(list(y_data.keys())) - 1) else False,
            legend_outside=True,
            legend_loc="upper right",
            legend_pos=(0.95, -0.2),
            legend_n_columns=len(list(y_data_means[outer_key].values())) +
            n_lines,
            hlines=hlines,
            tick_font_size=22,
            axis_font_size=26,
            legend_font_size=26,
            title_font_size=28)

    plt.tight_layout()

    for storage_dir in storage_dirs:
        os.makedirs(storage_dir / save_dir, exist_ok=True)
        fig.savefig(storage_dir / save_dir / f'{save_dir}_learning.pdf',
                    bbox_inches='tight')

    plt.close(fig)
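The bootstrapped_CI branch above delegates to get_95_confidence_interval_of_sequence, which is not shown in this listing. For intuition, here is a rough, purely illustrative sketch of a per-time-step bootstrap 95% interval that returns a mean curve plus upper/lower offsets; the function name and output convention are assumptions, not alfred's actual implementation:

import numpy as np

def bootstrap_95_ci(samples, n_boot=1000, seed=0):
    # samples: array of shape (n_time_steps, n_runs), one column per seed/run
    rng = np.random.default_rng(seed)
    mean = samples.mean(axis=1)
    # Resample runs with replacement and recompute the mean curve n_boot times
    boot_idx = rng.integers(0, samples.shape[1], size=(n_boot, samples.shape[1]))
    boot_means = samples[:, boot_idx].mean(axis=2)             # (n_time_steps, n_boot)
    low, high = np.percentile(boot_means, [2.5, 97.5], axis=1)
    return mean, high - mean, mean - low                       # mean, err_up, err_down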
Example #7
def create_plot_arrays(
        from_file,
        storage_name,
        root_dir,
        remove_none,
        logger,
        plots_to_make=alfred.defaults.DEFAULT_PLOTS_ARRAYS_TO_MAKE):
    """
    Creates and and saves comparative figure containing a plot of total reward for each different experiment
    :param storage_dir: pathlib.Path object of the model directory containing the experiments to compare
    :param plots_to_make: list of strings indicating which comparative plots to make
    :return: None
    """
    # Select storage_dirs to run over

    storage_dirs = select_storage_dirs(from_file, storage_name, root_dir)

    for storage_dir in storage_dirs:

        # Gets all experiment directories and sorts them numerically

        sorted_experiments = DirectoryTree.get_all_experiments(storage_dir)

        all_seeds_dir = []
        for experiment in sorted_experiments:
            all_seeds_dir += DirectoryTree.get_all_seeds(experiment)

        # Determines what type of search was done

        if (storage_dir / 'GRID_SEARCH').exists():
            search_type = 'grid'
        elif (storage_dir / 'RANDOM_SEARCH').exists():
            search_type = 'random'
        else:
            search_type = 'unknown'

        # Determines rows and columns of subplots

        if search_type == 'grid':
            variations = load_dict_from_json(filename=str(storage_dir /
                                                          'variations.json'))

            # experiment_groups accounts for the fact that the experiment_dirs in a storage_dir may have been created
            # through several runs of prepare_schedule.py, so several "groups" of experiments may coexist
            experiment_groups = {key: {} for key in variations.keys()}
            for group_key, properties in experiment_groups.items():
                properties['variations'] = variations[group_key]

                properties['variations_lengths'] = {
                    k: len(properties['variations'][k])
                    for k in properties['variations'].keys()
                }

                # Deletes alg_name and task_name from variations (they do not vary within a single storage_dir)

                hyperparam_variations_lengths = deepcopy(
                    properties['variations_lengths'])
                del hyperparam_variations_lengths['alg_name']
                del hyperparam_variations_lengths['task_name']

                i_max = sorted(hyperparam_variations_lengths.values())[-1]
                j_max = int(
                    np.prod(
                        sorted(hyperparam_variations_lengths.values())[:-1]))

                if i_max < 4 and j_max == 1:
                    # If only one hyperparameter was varied over, we order plots on a line
                    j_max = i_max
                    i_max = 1
                    ax_array_dim = 1

                elif i_max >= 4 and j_max == 1:
                    # ... unless there are 4 or more variations, then we put them in a square-ish fashion
                    j_max = int(np.sqrt(i_max))
                    i_max = int(np.ceil(float(i_max) / float(j_max)))
                    ax_array_dim = 2

                else:
                    ax_array_dim = 2

                properties['ax_array_shape'] = (i_max, j_max)
                properties['ax_array_dim'] = ax_array_dim

        else:
            experiment_groups = {"all": {}}
            for group_key, properties in experiment_groups.items():
                i_max = len(sorted_experiments)  # each experiment is on a different row
                j_max = len(all_seeds_dir) // i_max  # each seed is on a different column

                if i_max == 1:
                    ax_array_dim = 1
                else:
                    ax_array_dim = 2

                properties['ax_array_shape'] = (i_max, j_max)
                properties['ax_array_dim'] = ax_array_dim

        for group_key, properties in experiment_groups.items():
            logger.debug(
                f"\n===========================\nPLOTS FOR EXPERIMENT GROUP: {group_key}"
            )
            i_max, j_max = properties['ax_array_shape']
            ax_array_dim = properties['ax_array_dim']

            first_exp = group_key.split('-')[0] if group_key != "all" else 0
            if first_exp != 0:
                for seed_idx, seed_dir in enumerate(all_seeds_dir):
                    if seed_dir.parent.stem.strip('experiment') == first_exp:
                        first_seed_idx = seed_idx
                        break
            else:
                first_seed_idx = 0

            for plot_to_make in plots_to_make:
                x_metric, y_metric, x_lim, y_lim = plot_to_make
                logger.debug(f'\n{y_metric} as a function of {x_metric}:')

                # Creates the subplots

                fig, ax_array = plt.subplots(i_max,
                                             j_max,
                                             figsize=(10 * j_max, 6 * i_max))

                for i in range(i_max):
                    for j in range(j_max):

                        if ax_array_dim == 1 and i_max == 1 and j_max == 1:
                            current_ax = ax_array
                        elif ax_array_dim == 1 and (i_max > 1 or j_max > 1):
                            current_ax = ax_array[j]
                        elif ax_array_dim == 2:
                            current_ax = ax_array[i, j]
                        else:
                            raise Exception(
                                'ax_array should not have more than two dimensions'
                            )

                        try:
                            seed_dir = all_seeds_dir[first_seed_idx + (i * j_max + j)]
                            if group_key != 'all':
                                exp_num = int(str(seed_dir.parent).split('experiment')[1])
                                group_min, group_max = (int(x) for x in group_key.split('-'))
                                if exp_num < group_min or exp_num > group_max:
                                    raise IndexError
                            logger.debug(str(seed_dir))
                        except IndexError as e:
                            logger.debug(
                                f'experiment{i * j_max + j} does not exist')
                            current_ax.text(0.2,
                                            0.2,
                                            "no experiment\n found",
                                            transform=current_ax.transAxes,
                                            fontsize=24,
                                            fontweight='bold',
                                            color='red')
                            continue

                        logger.debug(seed_dir)

                        # Writes unique hyperparameters on plot

                        config = load_config_from_json(
                            filename=str(seed_dir / 'config.json'))
                        config_unique_dict = load_dict_from_json(
                            filename=str(seed_dir / 'config_unique.json'))
                        validate_config_unique(config, config_unique_dict)

                        if search_type == 'grid':
                            sorted_keys = sorted(
                                config_unique_dict.keys(),
                                key=lambda item:
                                (properties['variations_lengths'][item], item),
                                reverse=True)

                        else:
                            sorted_keys = config_unique_dict

                        info_str = f'{seed_dir.parent.stem}\n' + '\n'.join([
                            f'{k} = {config_unique_dict[k]}'
                            for k in sorted_keys
                        ])
                        bbox_props = dict(facecolor='gray', alpha=0.1)
                        current_ax.text(0.05,
                                        0.95,
                                        info_str,
                                        transform=current_ax.transAxes,
                                        fontsize=12,
                                        verticalalignment='top',
                                        bbox=bbox_props)

                        # Skip cases of UNHATCHED or CRASHED experiments

                        if (seed_dir / 'UNHATCHED').exists():
                            logger.debug('UNHATCHED')
                            current_ax.text(0.2,
                                            0.2,
                                            "UNHATCHED",
                                            transform=current_ax.transAxes,
                                            fontsize=24,
                                            fontweight='bold',
                                            color='blue')
                            continue

                        if (seed_dir / 'CRASH.txt').exists():
                            logger.debug('CRASHED')
                            current_ax.text(0.2,
                                            0.2,
                                            "CRASHED",
                                            transform=current_ax.transAxes,
                                            fontsize=24,
                                            fontweight='bold',
                                            color='red')
                            continue

                        try:

                            # Loading the recorder

                            loaded_recorder = Recorder.init_from_pickle_file(
                                filename=str(seed_dir / 'recorders' /
                                             'train_recorder.pkl'))

                            # Checking if provided metrics are present in the recorder

                            if y_metric not in loaded_recorder.tape.keys():
                                logger.debug(
                                    f"'{y_metric}' was not recorded in train_recorder."
                                )
                                current_ax.text(0.2,
                                                0.2,
                                                "ABSENT METRIC",
                                                transform=current_ax.transAxes,
                                                fontsize=24,
                                                fontweight='bold',
                                                color='red')
                                continue

                            if x_metric is not None and x_metric not in loaded_recorder.tape.keys():
                                logger.debug(
                                    f"'{x_metric}' was not recorded in train_recorder."
                                )
                                current_ax.text(0.2,
                                                0.2,
                                                "ABSENT METRIC",
                                                transform=current_ax.transAxes,
                                                fontsize=24,
                                                fontweight='bold',
                                                color='red')
                                continue

                            # Removing None entries

                            if remove_none:
                                if x_metric is not None:
                                    loaded_recorder.tape[x_metric] = remove_nones(
                                        loaded_recorder.tape[x_metric])
                                loaded_recorder.tape[y_metric] = remove_nones(
                                    loaded_recorder.tape[y_metric])

                            # Plotting

                            try:

                                if x_metric is not None:
                                    plot_curves(
                                        current_ax,
                                        ys=[loaded_recorder.tape[y_metric]],
                                        xs=[loaded_recorder.tape[x_metric]],
                                        xlim=x_lim,
                                        ylim=y_lim,
                                        xlabel=x_metric,
                                        title=y_metric)
                                else:
                                    plot_curves(
                                        current_ax,
                                        ys=[loaded_recorder.tape[y_metric]],
                                        xlim=x_lim,
                                        ylim=y_lim,
                                        title=y_metric)

                            except Exception as e:
                                logger.debug(f'Plotting error: {e}')

                        except FileNotFoundError:
                            logger.debug('Training recorder not found')
                            current_ax.text(0.2,
                                            0.2,
                                            "'train_recorder'\nnot found",
                                            transform=current_ax.transAxes,
                                            fontsize=24,
                                            fontweight='bold',
                                            color='red')
                            continue

                plt.tight_layout()
                fig.savefig(
                    str(storage_dir /
                        f'{group_key}_comparative_{y_metric}_over_{x_metric}.png'
                        ))
                plt.close(fig)
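The grid-search branch above chooses the subplot grid from how many values each hyperparameter takes: the most-varied hyperparameter fixes i_max, the product of the remaining variation counts fixes j_max, and two fallbacks handle the case where only a single hyperparameter was varied. A small standalone trace of that logic with a hypothetical variations dict (the keys are made up for illustration):

import numpy as np

# Hypothetical variation counts, after alg_name and task_name have been removed
hyperparam_variations_lengths = {'lr': 4, 'batch_size': 2, 'gamma': 2}

i_max = sorted(hyperparam_variations_lengths.values())[-1]                   # 4
j_max = int(np.prod(sorted(hyperparam_variations_lengths.values())[:-1]))    # 2 * 2 = 4

if i_max < 4 and j_max == 1:
    # only one hyperparameter varied, fewer than 4 values: lay the plots out on one row
    j_max, i_max = i_max, 1
elif i_max >= 4 and j_max == 1:
    # only one hyperparameter varied, 4 or more values: square-ish layout
    j_max = int(np.sqrt(i_max))
    i_max = int(np.ceil(i_max / j_max))

print((i_max, j_max))  # (4, 4): a 4 x 4 grid of subplots for the 16 experiments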
Example #8
    def save_training_graphs(self, train_recorder, save_dir):
        from alfred.utils.plots import create_fig, plot_curves
        import matplotlib.pyplot as plt

        # Losses

        fig, axes = create_fig((3, 1))
        plot_curves(axes[0],
                    ys=[train_recorder.tape['q1_loss']],
                    xs=[train_recorder.tape['total_transitions']],
                    xlabel='Transitions',
                    ylabel='q1_loss')
        plot_curves(axes[1],
                    ys=[train_recorder.tape['q2_loss']],
                    xs=[train_recorder.tape['total_transitions']],
                    xlabel='Transitions',
                    ylabel='q2_loss')
        plot_curves(axes[2],
                    ys=[train_recorder.tape['pi_loss']],
                    xs=[train_recorder.tape['total_transitions']],
                    xlabel='Transitions',
                    ylabel='pi_loss')

        fig.savefig(str(save_dir / 'losses.png'))
        plt.close(fig)

        # True Returns
        fig, axes = create_fig((3, 1))
        fig.suptitle('Returns')
        plot_curves(axes[0],
                    ys=[train_recorder.tape['return']],
                    xs=[train_recorder.tape['total_transitions']],
                    xlabel='Transitions',
                    ylabel="Mean Return")
        plot_curves(axes[1],
                    ys=[train_recorder.tape['eval_return']],
                    xs=[train_recorder.tape['total_transitions']],
                    xlabel='Transitions',
                    ylabel="Mean Eval Return")
        plot_curves(axes[2],
                    ys=[train_recorder.tape['pi_entropy']],
                    xs=[train_recorder.tape['total_transitions']],
                    xlabel='Transitions',
                    ylabel="pi_entropy")
        fig.savefig(str(save_dir / 'figures.png'))
        plt.close(fig)
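One last detail worth noting across these examples: the return value of create_fig is indexed differently depending on the requested shape, plain axes for (1, 1), axes[i] for (3, 1) or (1, 2), and axes[i, j] for (3, 3) or (2, 5). This mirrors the squeeze behaviour of matplotlib's plt.subplots, which create_fig presumably wraps (an assumption; the wrapper itself is not shown here):

import matplotlib.pyplot as plt

fig1, ax = plt.subplots(1, 1)       # a single Axes object
fig2, axes_1d = plt.subplots(3, 1)  # a 1-D array: axes_1d[0] ... axes_1d[2]
fig3, axes_2d = plt.subplots(2, 5)  # a 2-D array: axes_2d[1, 4] is the last panel
print(type(ax).__name__, axes_1d.shape, axes_2d.shape)  # prints the Axes type, then (3,) and (2, 5)
plt.close('all')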