def plot_from_exps(
        exp_data,
        filters=None,
        split_figures_by=None,
        split_plots_by=None,
        x_key='n_timesteps',
        y_key=None,
        sup_y_key=None,
        plot_name='./bad-models.png',
        subfigure_titles=None,
        plot_labels=None,
        x_label=None,
        y_label=None,
        num_rows=1,
        y_limits=None,
        report_max_performance=False,
        log_scale=False,
        round_x=None,
):
    """Plot mean +/- std curves of grouped experiments on a subplot grid and save it.

    The experiments are filtered, grouped into subfigures by
    ``split_figures_by`` and, inside each subfigure, grouped into curves by
    ``split_plots_by``. Each curve is drawn as a mean line with a shaded
    band between ``y_mean - y_std`` and ``y_mean + y_std``. One shared
    legend is attached to the figure, which is then written to
    ``plot_name``.

    Args:
        exp_data: Experiment records, handed unchanged to the module's
            ``filter`` and ``group_by`` helpers (their expected schema is
            defined elsewhere -- TODO confirm).
        filters: Mapping forwarded to ``filter`` as ``filters=``. ``None``
            (the default) is treated as an empty dict.
        split_figures_by: Grouping key; each group becomes one subfigure.
        split_plots_by: Grouping key; each group becomes one curve.
        x_key: Used in this function only as the fallback x-axis label; it
            is not forwarded to ``prepare_data_for_plot``.
        y_key: Forwarded to ``prepare_data_for_plot``; also the fallback
            y-axis label.
        sup_y_key: Forwarded to ``prepare_data_for_plot``.
        plot_name: Destination passed to ``fig.savefig``.
        subfigure_titles: Optional titles indexed by subfigure position
            (subfigures are visited in sorted group order). Falls back to
            the group key.
        plot_labels: Optional curve labels indexed by curve position
            (curves are visited in ``sorting_legend`` order). Falls back to
            the group key.
        x_label: X-axis label; falls back to ``x_key``.
        y_label: Y-axis label (first column only); falls back to ``y_key``.
        num_rows: Number of rows in the subplot grid. Must be positive and
            divide the number of subfigures evenly.
        y_limits: Optional ``(low, high)`` pair unpacked into ``set_ylim``.
            When ``None`` the limits come from ``correct_limit``.
        report_max_performance: If true, draw a dashed black horizontal
            line at the highest mean value among the subfigure's curves.
        log_scale: If true, use a logarithmic x axis (``semilogx``).
        round_x: Forwarded to ``prepare_data_for_plot``.

    Raises:
        ValueError: If ``num_rows`` is not positive, if no experiments are
            left after filtering, or if the number of subfigures is not a
            multiple of ``num_rows``.
    """
    # Validate explicitly instead of with `assert`, which is stripped under -O.
    if num_rows < 1:
        raise ValueError(
            "num_rows must be a positive integer, got {!r}".format(num_rows))

    # A fresh dict per call avoids sharing one mutable default between calls.
    filters = {} if filters is None else filters

    # NOTE: `filter` must resolve to the project's own helper; the builtin
    # does not accept a `filters=` keyword.
    exp_data = filter(exp_data, filters=filters)
    exps_per_plot = group_by(exp_data, group_by_key=split_figures_by)

    num_figures = len(exps_per_plot)
    if num_figures == 0:
        raise ValueError("no experiments left to plot after filtering")
    if num_figures % num_rows != 0:
        raise ValueError(
            "number of subfigures ({}) is not divisible by num_rows ({})"
            .format(num_figures, num_rows))
    num_columns = num_figures // num_rows

    fig, axarr = plt.subplots(num_rows, num_columns, figsize=(14, 8))
    # plt.subplots returns a bare Axes or a 1-D array for small grids;
    # reshape so that [row, column] indexing always works.
    axarr = np.reshape(axarr, (num_rows, num_columns))
    fig.tight_layout(pad=5.0, w_pad=1, h_pad=2, rect=[0, 0, 1, 1])

    # iterate over subfigures
    for i, (default_plot_title, plot_exps) in enumerate(
            sorted(exps_per_plot.items())):
        plots_in_figure_exps = group_by(plot_exps, split_plots_by)
        subfigure_title = (
            subfigure_titles[i] if subfigure_titles else default_plot_title)

        r, c = i // num_columns, i % num_columns
        ax = axarr[r, c]
        ax.set_title(subfigure_title)
        ax.xaxis.set_major_locator(plt.MaxNLocator(5))
        ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))

        # axis labels (identical for every curve, so set once per subfigure);
        # only the first column carries a y label.
        ax.set_xlabel(x_label if x_label else x_key)
        if c == 0:
            ax.set_ylabel(y_label if y_label else y_key)

        # iterate over plots in figure
        y_max_mean = -1e10
        y_axis_min = 1e10
        y_axis_max = -1e10
        for j, default_label in enumerate(
                sorted(plots_in_figure_exps, key=sorting_legend)):
            exps = plots_in_figure_exps[default_label]
            x, y_mean, y_std, x_limits = prepare_data_for_plot(
                exps, y_key=y_key, sup_y_key=sup_y_key, round_x=round_x)

            label = plot_labels[j] if plot_labels else default_label
            # Only the first subfigure contributes legend entries; labels
            # starting with an underscore are skipped by matplotlib legends.
            _label = label if i == 0 else "__nolabel__"
            color = get_color(label)

            draw_line = ax.semilogx if log_scale else ax.plot
            draw_line(x, y_mean, label=_label, linewidth=LINEWIDTH,
                      color=color)
            ax.fill_between(x, y_mean + y_std, y_mean - y_std,
                            alpha=0.2, color=color)

            if x_limits is not None:
                ax.set_xlim(*x_limits)
            if y_limits is not None:
                ax.set_ylim(*y_limits)
            else:
                # Track the widest band over all curves of this subfigure.
                _y_axis_min, _y_axis_max = correct_limit(
                    ax, x, y_mean - y_std, y_mean + y_std)
                y_axis_max = max(_y_axis_max, y_axis_max)
                y_axis_min = min(_y_axis_min, y_axis_min)

            curve_max = max(y_mean)
            if curve_max > y_max_mean:
                y_max_mean = curve_max

        if report_max_performance:
            label = 'max' if i == 0 else "__nolabel__"
            ax.plot(ax.get_xlim(), [y_max_mean] * 2, 'k--', label=label)
        if y_limits is None:
            ax.set_ylim([y_axis_min, y_axis_max])

    # Use this figure's own transform rather than whatever figure pyplot
    # currently considers "current".
    fig.legend(loc='lower center', ncol=4, bbox_transform=fig.transFigure)
    fig.savefig(plot_name)
axarr[r, c].plot(axarr[r, c].get_xlim(), [y_max_mean] * 2, 'k--', label=label) if y_limits is None: axarr[r, c].set_ylim([y_axis_min, y_axis_max]) fig.legend(loc='lower center', ncol=3, bbox_transform=plt.gcf().transFigure) fig.savefig(plot_name) # filter_dict = {'fast_lr': 0.001, 'path_length':200} filter_dict = {} exps_data_filtered = filter(exps_data, filter_dict) plot_from_exps( exps_data, split_figures_by='env.$class', split_plots_by='algo', y_key='train-AverageReturn', x_key='Time', filters=filter_dict, sup_y_key=[ 'EnvTrajs-AverageReturn', 'Data-EnvTrajs-AverageReturn', 'EnvTrajs-AverageReturn', 'train-AverageReturn' ], # subfigure_titles=['HalfCheetah - output_bias_range [0.0, 0.1]', # 'HalfCheetah - output_bias_range [0.0, 0.5]', # 'HalfCheetah - output_bias_range [0.0, 1.0]'],