# Shared imports for the plotting examples below. The functions are independent
# examples that all rely on these modules; `log_processor` and `plot_utils` are
# assumed to be project-local modules, so the exact import paths may differ in
# the original project.
import collections

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import log_processor
import plot_utils


def plot_tv_loss_over_time(all_exps):
    def normalize_errors(all_exps):
        for exp in all_exps:
            exp.progress['distributional_shift_abs_diff_loss'] = exp.progress[
                'distributional_shift_abs_diff_loss'] / exp.progress[
                    'returns_expert']

    normalize_errors(all_exps)

    # plot first 50
    frame = log_processor.timewise_data_frame(all_exps,
                                              time_min=1,
                                              time_max=50 + 2)
    frame['iteration'] -= 1  # shift iteration back by 1
    frame = log_processor.reduce_mean_keys(
        frame, col_keys=('weighting_scheme', 'iteration',
                         'env_name'))  # this step is slow

    frame = log_processor.rename_partitions(frame,
                                            plot_utils.WEIGHT_NAME_MAP,
                                            col_key="weighting_scheme")
    frame = log_processor.rename_values(frame, plot_utils.VALUE_NAME_MAP)

    sns.set(style="whitegrid")
    sns.lineplot(x="Iteration",
                 y="Distributional Shift (TV)",
                 hue="Weighting Scheme",
                 data=frame)
    plt.show()
    sns.lineplot(x="Iteration",
                 y="Normalized Loss Shift",
                 hue="Weighting Scheme",
                 data=frame)
    plt.show()
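
# Note on the data these functions operate on: each entry of `all_exps` is
# assumed to behave like the record sketched below, i.e. a `progress` dict of
# per-iteration numpy arrays plus a `params` dict of hyperparameters that the
# partition/filter helpers key on. This is only an illustration of the assumed
# interface, not log_processor's actual experiment class:
#
#     exp.params   = {'env_name': 'grid16', 'weighting_scheme': 'uniform', ...}
#     exp.progress = {
#         'returns_expert': np.array([...]),        # one value per iteration
#         'returns_normalized': np.array([...]),
#         'distributional_shift_abs_diff_loss': np.array([...]),
#     }
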
def plot_errors_over_time(all_exps):

    # do my processing manually (not using reduce_mean_keys), then arrange into a table
    split_exps = log_processor.partition_params(all_exps, ('validation_stop', 'env_name'))
    exps = list(reduce_mean_partitions_over_time(split_exps, ('validation_stop', 'env_name'), expected_len=300))
    exps = dict(exps)
    split_exps = collections.defaultdict(list)
    for partition_key in exps:
        split_exps[partition_key[0]].append(exps[partition_key])
    exps = list(reduce_mean_partitions_over_time(split_exps, 'validation_stop', expected_len=300))
    exps = [exp[1] for exp in exps]
    frame = log_processor.timewise_data_frame(exps, time_min=0, time_max=600)

    frame = log_processor.rename_partitions(frame, {False: 'None', 'returns': 'Oracle Returns', 'bellman': 'Bellman Error'}, col_key="validation_stop")
    frame = log_processor.rename_values(frame, plot_utils.VALUE_NAME_MAP)
    frame = log_processor.rename_values(frame, {'validation_stop': 'Stop Method'})

    sns.set(style="whitegrid")
    g = sns.relplot(x="Iteration", y="Normalized Returns",
                hue="Stop Method", 
                #palette=palette,
                height=5, aspect=1.5, facet_kws=dict(sharex=False),
                kind="line", legend='brief', data=frame)
    plt.legend(loc='best')
    plt.savefig('fig.png')
    plt.show()
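
# The aggregation helper `reduce_mean_partitions_over_time` used above (and in
# several of the functions below) is not defined in this excerpt. The sketch
# below is a hypothetical reconstruction inferred from the call sites: for each
# partition it averages every progress series elementwise across experiments,
# optionally dropping runs shorter than `expected_len`, and yields
# (partition_key, averaged_experiment) pairs. The project's real implementation
# may differ.
_ReducedExp = collections.namedtuple('_ReducedExp', ['params', 'progress'])


def reduce_mean_partitions_over_time(split_exps, partition_keys,
                                     expected_len=None):
    # partition_keys is unused in this sketch; it is kept to mirror the call sites.
    for key, exps in split_exps.items():
        if expected_len is not None:
            # drop runs that did not reach the expected number of iterations
            exps = [e for e in exps
                    if all(len(v) >= expected_len for v in e.progress.values())]
        if not exps:
            continue
        length = (expected_len if expected_len is not None else
                  min(len(v) for e in exps for v in e.progress.values()))
        averaged = {
            metric: np.mean(
                [np.asarray(e.progress[metric])[:length] for e in exps], axis=0)
            for metric in exps[0].progress
        }
        yield key, _ReducedExp(params=exps[0].params, progress=averaged)
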
def plot_errors_over_time(all_exps):
    def normalize_errors(all_exps):
        for exp in all_exps:
            exp.progress['q*_diff_abs_max'] = exp.progress[
                'q*_diff_abs_max'] / exp.progress['returns_expert']
            exp.progress['ground_truth_error_max'] = exp.progress[
                'ground_truth_error_max'] / exp.progress['returns_expert']

    normalize_errors(all_exps)

    # do my processing manually (not using reduce_mean_keys), then arrange into a table
    split_exps = log_processor.partition_params(
        all_exps, ('layers', 'smooth_target_tau'))
    exps = list(
        reduce_mean_partitions_over_time(split_exps,
                                         ('layers', 'smooth_target_tau')))

    frame = log_processor.timewise_data_frame(exps, time_min=0, time_max=600)

    frame = log_processor.rename_partitions(frame,
                                            plot_utils.ARCH_NAME_MAP,
                                            col_key="layers")
    frame = log_processor.rename_values(frame, plot_utils.VALUE_NAME_MAP)

    sns.set(style="whitegrid")
    palette = dict(
        zip(sorted([float(x) for x in frame['Alpha'].unique()]),
            sns.color_palette("rocket_r", 6)))
    sns.relplot(x="Iteration",
                y="Normalized Q* Error",
                hue="Alpha",
                col="Architecture",
                palette=palette,
                col_order=plot_utils.ARCH_ORDER,
                height=5,
                aspect=.75,
                facet_kws=dict(sharex=False),
                kind="line",
                legend="full",
                data=frame)
    plt.show()
    sns.relplot(x="Iteration",
                y="Normalized Returns",
                hue="Alpha",
                col="Architecture",
                palette=palette,
                col_order=plot_utils.ARCH_ORDER,
                height=5,
                aspect=.75,
                facet_kws=dict(sharex=False),
                kind="line",
                legend="full",
                data=frame)
    plt.show()
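
# The renaming tables referenced throughout this file (plot_utils.ARCH_NAME_MAP,
# plot_utils.WEIGHT_NAME_MAP, plot_utils.VALUE_NAME_MAP) are assumed to be plain
# lookup tables mapping raw log keys and parameter values to the display names
# used on the plots, and plot_utils.ARCH_ORDER a fixed ordering for the
# architecture facets. For illustration only (the real contents live in
# plot_utils and may differ), VALUE_NAME_MAP presumably contains entries such as
#
#     {'iteration': 'Iteration',
#      'layers': 'Architecture',
#      'smooth_target_tau': 'Alpha',
#      'returns_normalized': 'Normalized Returns'}
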
def returns_qstar_vs_arch(all_exps):
    sns.set(style="whitegrid")

    all_exps = log_processor.filter_params(all_exps, 'weighting_scheme',
                                           'uniform')
    fqi_exps = log_processor.filter_params(all_exps, 'target_mode', 'tq')

    num_diverge = 0

    # normalize errors by expert returns so they are on the same scale across environments
    def normalize_q_errors(all_exps):
        nonlocal num_diverge
        for exp in all_exps:
            # count a run as diverged if its final Q* error exceeds 10x the expert return
            if exp.progress['q*_diff_abs_max'][
                    -1] >= exp.progress['returns_expert'][0] * 10:
                num_diverge += 1
            exp.progress['q*_diff_abs_max'] = exp.progress[
                'q*_diff_abs_max'] / exp.progress['returns_expert']
            exp.progress['ground_truth_error_max'] = exp.progress[
                'ground_truth_error_max'] / exp.progress['returns_expert']

    normalize_q_errors(fqi_exps)
    print('Num diverge:', num_diverge)
    print('Total exps:', len(fqi_exps))
    print('Fraction diverge:', float(num_diverge) / len(fqi_exps))

    split_wts = log_processor.partition_params(fqi_exps, 'layers')
    frame = log_processor.aggregate_partitions(
        split_wts, aggregate_fn=log_processor.reduce_trimmed_mean)

    frame = double_plot(frame,
                        plot_keys={
                            'returns_normalized': "Normalized Returns",
                            'q*_diff_abs_max': "FQI Q* Error",
                            'ground_truth_error_max': "Project Q* Error"
                        },
                        plot_name='')
    frame = log_processor.rename_partitions(frame, {"tabular": "Tabular"})

    g = sns.catplot(x='split_key',
                    y='plot_val',
                    hue='',
                    data=frame,
                    order=plot_utils.ARCH_ORDER,
                    height=4,
                    aspect=1.4,
                    kind='bar',
                    palette='muted',
                    legend_out=False)
    g.despine(left=True)
    g.set_ylabels("Normalized Returns/Q-function Error")
    g.set_xlabels("Architecture")
    g.set(ylim=(-0.1, 1))
    plt.show()
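
# `double_plot`, used in returns_qstar_vs_arch above, is not defined in this
# excerpt. The sketch below is a hypothetical reconstruction inferred from the
# call site: it melts the metric columns listed in `plot_keys` into long form so
# that sns.catplot can place the partition on the x axis ('split_key'), the
# value in 'plot_val', and the metric's display name in a hue column named after
# `plot_name` (an empty string here, matching hue='' above). It assumes the
# aggregated frame already carries a 'split_key' column; the real helper may
# behave differently.
def double_plot(frame, plot_keys, plot_name=''):
    melted = frame.melt(id_vars=['split_key'],
                        value_vars=list(plot_keys.keys()),
                        var_name=plot_name,
                        value_name='plot_val')
    # replace raw metric keys with their display names for the legend
    melted[plot_name] = melted[plot_name].map(plot_keys)
    return melted
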
def plot_returns_over_time(all_exps):

    # do my processing manually (not using reduce_mean_keys), then arrange into a table
    split_exps = log_processor.partition_params(all_exps,
                                                ('layers', 'num_samples'))
    exps = list(
        reduce_mean_partitions_over_time(split_exps,
                                         ('layers', 'num_samples')))
    frame = log_processor.timewise_data_frame(exps, time_min=0, time_max=300)

    frame = log_processor.rename_partitions(frame,
                                            plot_utils.ARCH_NAME_MAP,
                                            col_key="layers")
    frame = log_processor.rename_values(frame, plot_utils.VALUE_NAME_MAP)

    sns.set(style="whitegrid")
    palette = dict(
        zip(sorted([float(x) for x in frame['Samples'].unique()]),
            sns.color_palette("rocket_r", 6)))

    frame256 = frame[frame['Architecture'] == '(256, 256)']
    g = sns.relplot(x="Iteration",
                    y="Normalized Returns",
                    hue="Samples",
                    palette=palette,
                    height=5,
                    aspect=1.5,
                    facet_kws=dict(sharex=False),
                    kind="line",
                    legend='brief',
                    data=frame256)
    plt.legend(loc='best')
    plt.show()
    sns.relplot(
        x="Iteration",
        y="Normalized Returns",
        hue="Samples",
        col="Architecture",
        palette=palette,
        col_order=plot_utils.ARCH_ORDER,
        height=5,
        aspect=1.25,
        facet_kws=dict(sharex=False),
        kind="line",
        legend="full",
        data=frame)
    plt.show()
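
# Design note: keying `palette` by the numeric hue values (instead of letting
# seaborn assign colors per figure) keeps the color for each sample count
# identical between the single-architecture panel and the faceted plot above,
# so the two figures can be read side by side.
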
def plot_shift_vs_returns(all_exps):
    """ Make a scatterplot of weighting schemes vs returns """

    # Reduce the experiment logs values across **time**
    def reducer(values, key=None):
        if key == 'distributional_shift_tv':
            return np.mean(values[1:])
        else:
            return values[-1]  # return last

    frame = log_processor.to_data_frame(all_exps, reduce_fn=reducer)
    # Take the mean of weighting schemes across **experiments**
    frame = log_processor.reduce_mean_key(frame, col_key='weighting_scheme')

    # Rename axes to more readable names on the plot
    frame = log_processor.rename_partitions(frame,
                                            plot_utils.WEIGHT_NAME_MAP,
                                            col_key="weighting_scheme")
    frame = log_processor.rename_values(frame, plot_utils.VALUE_NAME_MAP)
    frame = log_processor.rename_values(
        frame, {'Distributional Shift (TV)': 'Average Distributional Shift'})
    print(frame['Weighting Scheme'])
    print(frame['Normalized Returns'])
    print(frame["Average Distributional Shift"])

    sns.set(style="whitegrid")
    ax = sns.scatterplot(y="Normalized Returns",
                         x="Average Distributional Shift",
                         hue="Weighting Scheme",
                         legend=False,
                         data=frame)
    log_processor.label_scatter_points(frame['Average Distributional Shift'],
                                       frame['Normalized Returns'],
                                       frame["Weighting Scheme"],
                                       ax,
                                       global_x_offset=0.005,
                                       offsets={
                                           'Prioritized': (-0.060, -0.03),
                                           'Replay(10)': (0, -0.00),
                                           'Random': (0, 0.000)
                                       })
    plt.savefig('fig.png')
    plt.show()
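
# log_processor.label_scatter_points is assumed to write each weighting-scheme
# name next to its point, shifted by a global x offset plus an optional
# per-label (dx, dy) offset. The helper below is only an illustrative stand-in
# for that assumed behavior, not the real log_processor API:
def _label_scatter_points_sketch(xs, ys, labels, ax,
                                 global_x_offset=0.0, offsets=None):
    offsets = offsets or {}
    for x, y, name in zip(xs, ys, labels):
        dx, dy = offsets.get(name, (0.0, 0.0))
        ax.text(x + global_x_offset + dx, y + dy, name)
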
def plot_errors_over_time(all_exps):

    # do my processing manually (not using reduce_mean_keys), then arrange into a table
    split_exps = log_processor.partition_params(all_exps, 'max_project_steps')
    exps = list(
        reduce_mean_partitions_over_time(split_exps,
                                         'max_project_steps',
                                         expected_len=600))
    frame = log_processor.timewise_data_frame(exps, time_min=0, time_max=600)

    # rescale max_project_steps by 1/32 so the hue reads in gradient steps per sample
    frame = log_processor.rename_partitions(
        frame, {n: n / float(32)
                for n in split_exps.keys()},
        col_key="max_project_steps")
    frame = log_processor.rename_values(frame, plot_utils.VALUE_NAME_MAP)
    frame = log_processor.rename_values(
        frame, {'max_project_steps': 'Gradient Steps per Sample'})

    sns.set(style="whitegrid")
    palette = dict(
        zip(
            sorted([
                float(x) for x in frame['Gradient Steps per Sample'].unique()
            ]), sns.color_palette("rocket_r", 6)))

    g = sns.relplot(x="Iteration",
                    y="Normalized Returns",
                    hue="Gradient Steps per Sample",
                    palette=palette,
                    height=5,
                    aspect=1.5,
                    facet_kws=dict(sharex=False),
                    kind="line",
                    legend='brief',
                    data=frame)
    plt.legend(loc='best')
    plt.savefig('fig.png')
    plt.show()
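

if __name__ == '__main__':
    # Hypothetical driver, for illustration only: `load_experiments` and the log
    # directory below are assumptions standing in for however the experiment
    # logs are actually read; they are not part of log_processor's documented
    # interface.
    all_exps = log_processor.load_experiments('data/experiments')  # hypothetical
    plot_shift_vs_returns(all_exps)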