def generate_xy_plots(output_dir, output_file_dir, skip_download):
    """Write a LaTeX table of per-model results to ``model_table.tex``.

    The results grid is loaded, passed through the metadata / aggregation
    helpers, rounded to two decimals and restricted to rows whose
    ``show_in_table`` flag is set. One table row is emitted per remaining
    model.

    Args:
        output_dir: Directory holding (or receiving) the cached ``grid_df.pkl``.
        output_file_dir: Directory the ``.tex`` file is written to; created
            if it does not exist.
        skip_download: When truthy, read the cached pickle instead of calling
            ``download_data.download_plotting_data``.

    Raises:
        Exception: If ``skip_download`` is set and the cached pickle is missing.
    """
    if skip_download:
        cache_path = join(output_dir, 'grid_df.pkl')
        if not exists(cache_path):
            raise Exception(f'Downloaded data not found at {cache_path}. Please run python src/plotting/download_data.py first')
        df = pd.read_pickle(cache_path)
    else:
        df = download_data.download_plotting_data(output_dir, store_data=True, verbose=True)

    df, df_metadata = dataframe.extract_metadata(df)
    df, df_metadata = dataframe.replace_10percent_with_metadata(df, df_metadata)
    df, df_metadata = dataframe.aggregate_corruptions_with_metadata(df, df_metadata)

    table_columns = ('val', 'imagenetv2-matched-frequency-format-val', 'avg_corruptions', 'avg_pgd')
    # Each helper gets its own list so neither can observe mutations made by the other.
    df = prepare_df_for_plotting(df, df_metadata, list(table_columns))
    df = plotter.add_plotting_data(df, list(table_columns))
    df = format_eff_robust(df, 'val', 'imagenetv2-matched-frequency-format-val', 'logit')

    df = df.round(2)
    df = df[df.show_in_table]

    rows = []
    for name, row in df.iterrows():
        # Underscores must be escaped to be valid in LaTeX text mode.
        escaped_name = name.replace('_', '\\_')
        # Missing optional metrics are rendered as empty cells rather than "nan".
        corruptions_cell = row.avg_corruptions if pd.notna(row.avg_corruptions) else ''
        pgd_cell = row.avg_pgd if pd.notna(row.avg_pgd) else ''
        rows.append(
            f"\\foo{{{escaped_name}}} & {row.val} & {row['eff_robust']} & "
            f"{corruptions_cell} & {pgd_cell} \\\\ \n"
        )

    os.makedirs(output_file_dir, exist_ok=True)
    table_path = join(output_file_dir, "model_table.tex")
    # The context manager guarantees the handle is closed even if write() fails.
    with open(table_path, "w") as table_file:
        table_file.write("".join(rows))
    print(f'written to {table_path}')
def generate_xy_plot(x_axis, y_axis, transform, num_bootstrap_samples, output_dir, output_file_dir, skip_download):
    """Render the ImageNet-R scatter plot and save it as ``imagenet_r.pdf``.

    Args:
        x_axis: Column used for the x coordinate.
        y_axis: Column used for the y coordinate.
        transform: Axis transform forwarded to ``plotter.model_scatter_plot``.
        num_bootstrap_samples: Forwarded to ``plotter.model_scatter_plot``.
        output_dir: Directory holding (or receiving) the cached ``grid_df.pkl``.
        output_file_dir: Directory the PDF is written to; created if missing.
        skip_download: When truthy, read the cached pickle instead of
            downloading.

    Raises:
        Exception: If ``skip_download`` is set and the cached pickle is missing.
    """
    if not skip_download:
        df = download_data.download_plotting_data(output_dir, store_data=True, verbose=True)
    else:
        cache_path = join(output_dir, 'grid_df.pkl')
        if not exists(cache_path):
            raise Exception(f'Downloaded data not found at {cache_path}. Please run python src/plotting/download_data.py first')
        df = pd.read_pickle(cache_path)

    df, df_metadata = dataframe.extract_metadata(df)
    df, df_metadata = dataframe.replace_10percent_with_metadata(df, df_metadata)
    df, df_metadata = dataframe.aggregate_corruptions_with_metadata(df, df_metadata)

    df = prepare_df_for_plotting(df, df_metadata, [x_axis, y_axis])
    df = plotter.add_plotting_data(df, [x_axis, y_axis])

    # Axis limits are derived from the points that will actually be drawn.
    shown = df[df.show_in_plot == True]
    x_shown = shown[x_axis]
    xlim = [x_shown.min() - 1, x_shown.max() + 0.5]
    y_shown = shown[y_axis]
    ylim = [y_shown.min() - 1, y_shown.values.max() + 1]

    fig, _ax = plotter.model_scatter_plot(
        df, x_axis, y_axis, xlim, ylim, DeepAugmentModelTypes,
        transform=transform,
        tick_multiplier=5,
        num_bootstrap_samples=num_bootstrap_samples,
        title='Distribution Shift to ImageNet-R',
        x_label='ImageNet (class-subsampled)',
        y_label='ImageNet-R',
        figsize=(12, 8),
        include_legend=False,
        return_separate_legend=False,
    )

    os.makedirs(output_file_dir, exist_ok=True)
    pdf_path = join(output_file_dir, 'imagenet_r.pdf')
    fig.savefig(pdf_path, dpi='figure', bbox_inches='tight', pad_inches=0.1)
    print(f"Plot saved to {pdf_path}")
def generate_xy_plot(x_axis, y_axis, output_dir, skip_download):
    """Plot ``y_axis`` against ``x_axis`` and save ``<x_axis>_vs_<y_axis>.pdf``.

    Unlike a two-directory layout, this variant writes the PDF into
    ``output_dir`` itself. It also overrides the ``plotter`` module's
    ``label_fontsize``, ``legend_fontsize`` and ``tick_fontsize`` attributes
    before drawing.

    Args:
        x_axis: Column used for the x coordinate (also used as the x label).
        y_axis: Column used for the y coordinate (also used as the y label).
        output_dir: Directory holding the cached ``grid_df.pkl`` and receiving
            the PDF.
        skip_download: When truthy, read the cached pickle instead of
            downloading.

    Raises:
        Exception: If ``skip_download`` is set and the cached pickle is missing.
    """
    if not skip_download:
        df = download_data.download_plotting_data(output_dir, store_data=True, verbose=True)
    else:
        cache_path = join(output_dir, 'grid_df.pkl')
        if not exists(cache_path):
            raise Exception(
                f'Downloaded data not found at {cache_path}. Please run python src/plotting/download_data.py first'
            )
        df = pd.read_pickle(cache_path)

    df, df_metadata = dataframe.extract_metadata(df)
    df, df_metadata = dataframe.replace_10percent_with_metadata(df, df_metadata)
    df, df_metadata = dataframe.aggregate_corruptions_with_metadata(df, df_metadata)

    df = prepare_df_for_plotting(df, df_metadata, [x_axis, y_axis])
    df = plotter.add_plotting_data(df, [x_axis, y_axis])

    # Module-level font settings consumed by the plotter.
    plotter.label_fontsize = 18
    plotter.legend_fontsize = 15
    plotter.tick_fontsize = 15

    # Limits are the visible data range widened by two units on each side.
    shown = df[df.show_in_plot == True]
    x_shown = shown[x_axis]
    xlim = [x_shown.min() - 2, x_shown.max() + 2]
    y_shown = shown[y_axis]
    ylim = [y_shown.min() - 2, y_shown.max() + 2]

    fig, _ax = plotter.model_scatter_plot(
        df, x_axis, y_axis, xlim, ylim, ModelTypes,
        transform='logit',
        tick_multiplier=5,
        num_bootstrap_samples=1000,
        title='ImageNet',
        x_label=x_axis,
        y_label=y_axis,
        figsize=(12, 8),
        include_legend=True,
        return_separate_legend=False,
    )

    os.makedirs(output_dir, exist_ok=True)
    pdf_path = join(output_dir, f'{x_axis}_vs_{y_axis}.pdf')
    fig.savefig(pdf_path, dpi='figure')
    print(f"Plot saved to {pdf_path}")
def generate_xy_plot(x_axis, y_axis, transform, num_bootstrap_samples, output_dir, output_file_dir, skip_download, option):
    """Render the YTBB-Robust scatter plot and save it as a PDF.

    ``option`` selects a variant:

    * ``'no-subsample'``: the x axis is forced to ``'val'``, the x range gets
      wider padding, and the file name gains ``_subsample``.
    * ``'show-yx'``: the upper y limit is taken over both plotted columns and
      the file name gains ``_yx``.
    * ``'only-standard'``: the file name gains ``_standard``.

    ``option`` is also forwarded to ``prepare_df_for_plotting``.

    Args:
        x_axis: Column used for the x coordinate (ignored for
            ``'no-subsample'``).
        y_axis: Column used for the y coordinate.
        transform: Axis transform forwarded to the plotter.
        num_bootstrap_samples: Forwarded to the plotter.
        output_dir: Directory holding (or receiving) the cached ``grid_df.pkl``.
        output_file_dir: Directory the PDF is written to; created if missing.
        skip_download: When truthy, read the cached pickle instead of
            downloading.
        option: Variant selector, see above.

    Raises:
        Exception: If ``skip_download`` is set and the cached pickle is missing.
    """
    if not skip_download:
        df = download_data.download_plotting_data(output_dir, store_data=True, verbose=True)
    else:
        cache_path = join(output_dir, 'grid_df.pkl')
        if not exists(cache_path):
            raise Exception(f'Downloaded data not found at {cache_path}. Please run python src/plotting/download_data.py first')
        df = pd.read_pickle(cache_path)

    df, df_metadata = dataframe.extract_metadata(df)
    df, df_metadata = dataframe.replace_10percent_with_metadata(df, df_metadata)
    df, df_metadata = dataframe.aggregate_corruptions_with_metadata(df, df_metadata)

    no_subsample = option == 'no-subsample'
    show_yx = option == 'show-yx'

    if no_subsample:
        x_axis = 'val'

    df = prepare_df_for_plotting(df, df_metadata, [x_axis, y_axis], option)
    df = plotter.add_plotting_data(df, [x_axis, y_axis])

    # Axis limits are derived from the points that will actually be drawn.
    shown = df[df.show_in_plot == True]

    x_upper_pad = 1 if no_subsample else 0.1
    xlim = [shown[x_axis].min() - 1, shown[x_axis].max() + x_upper_pad]

    y_lower = shown[y_axis].min() - 2
    if show_yx:
        # The y range must also cover the x values when both are shown.
        ylim = [y_lower, shown[[x_axis, y_axis]].values.max() + 0.1]
    else:
        ylim = [y_lower, shown[y_axis].values.max() + 2]

    if no_subsample:
        x_label = 'ImageNet'
        x_ticks = 5
    else:
        x_label = 'ImageNet (class-subsampled)'
        x_ticks = 1

    fig, _ax = plotter.model_scatter_plot(
        df, x_axis, y_axis, xlim, ylim, NatModelTypes,
        transform=transform,
        num_bootstrap_samples=num_bootstrap_samples,
        x_tick_multiplier=x_ticks,
        y_tick_multiplier=5,
        y_unit='pm-0, %',
        title='Distribution Shift to YTBB-Anchors',
        x_label=x_label,
        y_label='YTBB-Robust',
        figsize=(12, 8),
        include_legend=False,
        return_separate_legend=False,
    )

    os.makedirs(output_file_dir, exist_ok=True)

    stem = 'ytbb_robust_benign'
    if option == 'only-standard':
        stem += '_standard'
    elif no_subsample:
        stem += '_subsample'
    elif show_yx:
        stem += '_yx'

    pdf_path = join(output_file_dir, f'{stem}.pdf')
    fig.savefig(pdf_path, dpi='figure', bbox_inches='tight', pad_inches=0.1)
    print(f"Plot saved to {pdf_path}")
def make_plot(x_axis, y_axis, df, df_metadata):
    """Build a distribution-shift scatter plot in the selected style.

    ``plot_style``, ``model_types`` and ``transform`` are not parameters; they
    are read from the enclosing / module scope.

    Args:
        x_axis: Column used for the x coordinate and x label.
        y_axis: Column used for the y coordinate and y label.
        df: Results grid; passed through ``plotter.add_plotting_data``.
        df_metadata: Accepted for interface compatibility; not used here.

    Returns:
        ``(fig, slope, intercept, df)`` where ``df`` is the frame returned by
        ``plotter.add_plotting_data``.

    Raises:
        ValueError: If ``plot_style`` is neither ``'Pretty'`` nor
            ``'Interactive'``. Previously this fell through to the ``return``
            and failed with an ``UnboundLocalError``.
    """
    df = plotter.add_plotting_data(df, [x_axis, y_axis])
    df_visible = df[df.show_in_plot == True]

    def _clamped_limits(values):
        # Pad the visible range by 2, then keep it inside [0.05, 99.95].
        return [max(values.min() - 2, 0.05), min(values.max() + 2, 99.95)]

    xlim = _clamped_limits(df_visible[x_axis])
    ylim = _clamped_limits(df_visible[y_axis])

    # Keyword arguments shared by both plotting back ends.
    shared_kwargs = dict(
        transform=transform.lower(),
        tick_multiplier=5,
        num_bootstrap_samples=100,
        title=f'Distribution Shift Plot ({transform} Scaling)',
        x_label=x_axis,
        y_label=y_axis,
        include_legend=True,
        return_separate_legend=False,
    )

    if plot_style == 'Pretty':
        fig, _, slope, intercept = plotter.model_scatter_plot(
            df, x_axis, y_axis, xlim, ylim, model_types,
            figsize=(9, 8), **shared_kwargs)
    elif plot_style == 'Interactive':
        fig, slope, intercept = plotter.model_scatter_plot_interactive(
            df, x_axis, y_axis, xlim, ylim, model_types,
            height=650, width=750, **shared_kwargs)
    else:
        raise ValueError(
            f"Unknown plot_style {plot_style!r}; expected 'Pretty' or 'Interactive'")

    return fig, slope, intercept, df
def generate_xy_plot(x_axis, y_axis, y_axis_fit, transform, num_bootstrap_samples, output_dir, output_file_dir, skip_download):
    """Produce the averaged-corruptions plot, its legend and a robustness scatter.

    Three PDFs are written into ``output_file_dir``:
    ``syn_shift_legend.pdf``, ``syn_shift_corruptions.pdf`` and
    ``eff_robust_corruptions.pdf``. Between the two figures the
    ``show_in_plot`` column is recomputed row-wise with ``show_in_plot2``.

    Args:
        x_axis: Column used for the x coordinate of the first plot.
        y_axis: Column used for the y coordinate of the first plot.
        y_axis_fit: Extra column forwarded to the preparation helpers and to
            ``format_eff_robust``.
        transform: Axis transform forwarded to the plotter and to
            ``format_eff_robust``.
        num_bootstrap_samples: Forwarded to ``plotter.model_scatter_plot``.
        output_dir: Directory holding (or receiving) the cached ``grid_df.pkl``.
        output_file_dir: Directory the PDFs are written to; created if missing.
        skip_download: When truthy, read the cached pickle instead of
            downloading.

    Raises:
        Exception: If ``skip_download`` is set and the cached pickle is missing.
    """
    if not skip_download:
        df = download_data.download_plotting_data(output_dir, store_data=True, verbose=True)
    else:
        cache_path = join(output_dir, 'grid_df.pkl')
        if not exists(cache_path):
            raise Exception(
                f'Downloaded data not found at {cache_path}. Please run python src/plotting/download_data.py first'
            )
        df = pd.read_pickle(cache_path)

    df, df_metadata = dataframe.extract_metadata(df)
    df, df_metadata = dataframe.replace_10percent_with_metadata(df, df_metadata)
    df, df_metadata = dataframe.aggregate_corruptions_with_metadata(df, df_metadata)

    df = prepare_df_for_plotting(df, df_metadata, [x_axis, y_axis, y_axis_fit])
    df = plotter.add_plotting_data(df, [x_axis, y_axis, y_axis_fit])
    df = format_eff_robust(df, x_axis, y_axis, y_axis_fit, transform)

    # First figure: limits follow the currently visible points.
    shown = df[df.show_in_plot == True]
    x_shown = shown[x_axis]
    xlim = [x_shown.min() - 1, x_shown.max() + 0.5]
    y_shown = shown[y_axis]
    ylim = [y_shown.min() - 1, y_shown.values.max() + 1]

    os.makedirs(output_file_dir, exist_ok=True)

    fig, _, legend = plotter.model_scatter_plot(
        df, x_axis, y_axis, xlim, ylim, ModelTypes,
        transform=transform,
        tick_multiplier=5,
        num_bootstrap_samples=num_bootstrap_samples,
        title='Distribution Shift to Corruptions Averaged',
        x_label='ImageNet',
        y_label='Corruptions Averaged',
        figsize=(12, 8),
        include_legend=False,
        return_separate_legend=True,
    )

    legend_path = join(output_file_dir, 'syn_shift_legend.pdf')
    legend.savefig(legend_path, dpi='figure', bbox_inches='tight', pad_inches=0.1)
    print(f"Legend saved to {legend_path}")

    shift_path = join(output_file_dir, 'syn_shift_corruptions.pdf')
    fig.savefig(shift_path, dpi='figure', bbox_inches='tight', pad_inches=0.1)
    print(f"Plot saved to {shift_path}")

    # Second figure uses a different visibility rule.
    df.show_in_plot = df.apply(show_in_plot2, axis=1)

    shown = df[df.show_in_plot == True]
    eff_x = shown['eff_robust_x']
    xlim = [eff_x.min() - 1, eff_x.max() + 1]
    eff_y = shown['eff_robust_y']
    ylim = [eff_y.min() - 0.5, eff_y.values.max() + 0.5]

    fig, _ = plotter.simple_scatter_plot(
        df, 'eff_robust_x', 'eff_robust_y', xlim, ylim, ModelTypes,
        title='Effective Robustness Scatterplot',
        x_tick_multiplier=5,
        y_tick_multiplier=1,
        x_label='Corruptions Averaged Effective Robustness',
        y_label='ImageNetV2 Effective Robustness',
        figsize=(12, 8),
        include_legend=False,
        return_separate_legend=False,
    )

    scatter_path = join(output_file_dir, 'eff_robust_corruptions.pdf')
    fig.savefig(scatter_path, dpi='figure', bbox_inches='tight', pad_inches=0.1)
    print(f"Plot saved to {scatter_path}")
def generate_xy_plot(x_axis, y_axis, transform, num_bootstrap_samples, output_dir, output_file_dir, skip_download, x_label, y_label, x_unit='top-1, %', y_unit='top-1, %', imagenet_a=False):
    """Render a ResNet50-family quadrant plot plus its separate legend.

    With ``imagenet_a`` false the figure comes from
    ``plotter.model_scatter_plot_quadrants`` and the legend is saved as
    ``resnet50_legend.pdf``. With it true the figure comes from
    ``plotter.model_scatter_plot_quadrants_imagenet_a`` (called with
    ``pivot=91.86``) and the legend is saved as ``resnet50_legend2.pdf``.

    In both cases the figure is saved as ``resnet50_<y_axis>.pdf`` with any
    ``"1.0"`` in ``y_axis`` shortened to ``"1"``. The ``imagenet_a`` branch
    used to compute that normalised path and then ignore it; both branches
    now use it.

    Args:
        x_axis: Column used for the x coordinate.
        y_axis: Column used for the y coordinate; also part of the file name.
        transform: Axis transform forwarded to the plotter.
        num_bootstrap_samples: Forwarded to the plotter.
        output_dir: Directory holding (or receiving) the cached ``grid_df.pkl``.
        output_file_dir: Directory the PDFs are written to; created if missing.
        skip_download: When truthy, read the cached pickle instead of
            downloading.
        x_label: Label for the x axis.
        y_label: Label for the y axis.
        x_unit: Unit string forwarded to the plotter.
        y_unit: Unit string forwarded to the plotter.
        imagenet_a: Selects the ImageNet-A plotting variant.

    Raises:
        Exception: If ``skip_download`` is set and the cached pickle is missing.
    """
    if skip_download:
        cache_path = join(output_dir, 'grid_df.pkl')
        if not exists(cache_path):
            raise Exception(
                f'Downloaded data not found at {cache_path}. Please run python src/plotting/download_data.py first'
            )
        df = pd.read_pickle(cache_path)
    else:
        df = download_data.download_plotting_data(output_dir, store_data=True, verbose=True)

    df, df_metadata = dataframe.extract_metadata(df)
    df, df_metadata = dataframe.replace_10percent_with_metadata(df, df_metadata)
    df, df_metadata = dataframe.aggregate_corruptions_with_metadata(df, df_metadata)

    df = prepare_df_for_plotting(df, df_metadata, [x_axis, y_axis])
    df = plotter.add_plotting_data(df, [x_axis, y_axis])
    df = df.dropna()

    df_visible = df[df.show_in_plot == True]
    # The upper x limit is capped at 99.5.
    xlim = [df_visible[x_axis].min() - 1, min(df_visible[x_axis].max() + 1, 99.5)]
    ylim = [df_visible[y_axis].min() - 1, df_visible[y_axis].values.max() + 1]

    os.makedirs(output_file_dir, exist_ok=True)

    # The two variants differ only in plotting function, one extra keyword
    # and the legend file name.
    if imagenet_a:
        plot_fn = plotter.model_scatter_plot_quadrants_imagenet_a
        variant_kwargs = {'pivot': 91.86}
        legend_name = 'resnet50_legend2.pdf'
    else:
        plot_fn = plotter.model_scatter_plot_quadrants
        variant_kwargs = {}
        legend_name = 'resnet50_legend.pdf'

    fig, _, legend = plot_fn(
        df, x_axis, y_axis, xlim, ylim, ModelTypes,
        transform=transform,
        tick_multiplier=5,
        num_bootstrap_samples=num_bootstrap_samples,
        title='Relative and Effective Robustness - ResNet50 Family',
        alpha=0.8,
        x_label=x_label,
        y_label=y_label,
        x_unit=x_unit,
        y_unit=y_unit,
        figsize=(12, 8),
        include_legend=False,
        return_separate_legend=True,
        **variant_kwargs,
    )

    legend_path = join(output_file_dir, legend_name)
    legend.savefig(legend_path, dpi='figure', bbox_inches='tight', pad_inches=0.1)
    print(f"Legend saved to {legend_path}")

    filename_fig = join(output_file_dir, f'resnet50_{y_axis.replace("1.0", "1")}.pdf')
    fig.savefig(filename_fig, dpi='figure', bbox_inches='tight', pad_inches=0.1)
    print(f"Plot saved to {filename_fig}")
def generate_xy_plot(x_axis, y_axis, transform, num_bootstrap_samples, output_dir, output_file_dir, skip_download):
    """Render the iid-subsampling scatter plot and save ``subsample_iid.pdf``.

    After plotting, the legend is rebuilt in the upper-left corner with two
    columns and its handles are enlarged and made semi-transparent.

    Args:
        x_axis: Column used for the x coordinate.
        y_axis: Column used for the y coordinate.
        transform: Axis transform forwarded to the plotter.
        num_bootstrap_samples: Forwarded to the plotter.
        output_dir: Directory holding (or receiving) the cached ``grid_df.pkl``.
        output_file_dir: Directory the PDF is written to; created if missing.
        skip_download: When truthy, read the cached pickle instead of
            downloading.

    Raises:
        Exception: If ``skip_download`` is set and the cached pickle is missing.
    """
    if skip_download:
        cache_path = join(output_dir, 'grid_df.pkl')
        if not exists(cache_path):
            raise Exception(
                f'Downloaded data not found at {cache_path}. Please run python src/plotting/download_data.py first'
            )
        df = pd.read_pickle(cache_path)
    else:
        df = download_data.download_plotting_data(output_dir, store_data=True, verbose=True)

    df, df_metadata = dataframe.extract_metadata(df)
    df, df_metadata = dataframe.replace_10percent_with_metadata(df, df_metadata)
    df, df_metadata = dataframe.aggregate_corruptions_with_metadata(df, df_metadata)

    df = prepare_df_for_plotting(df, df_metadata, [x_axis, y_axis])
    df = plotter.add_plotting_data(df, [x_axis, y_axis])
    df = df.dropna()

    df_visible = df[df.show_in_plot == True]
    xlim = [df_visible[x_axis].min() - 1, df_visible[x_axis].max() + 1]
    ylim = [df_visible[y_axis].min() - 2, df_visible[y_axis].values.max() + 5]

    fig, ax = plotter.model_scatter_plot(
        df, x_axis, y_axis, xlim, ylim, ModelTypes,
        transform=transform,
        tick_multiplier=5,
        num_bootstrap_samples=num_bootstrap_samples,
        title='Robustness for Subsampling ImageNet',
        x_label='ImageNet (iid-subsampled)',
        y_label='ImageNetV2 (iid-\nsubsampled)',
        figsize=(12, 8),
        include_legend=True,
        return_separate_legend=False,
    )

    legend = ax.legend(
        loc='upper left',
        ncol=2,
        bbox_to_anchor=(0, 1),
        fontsize=plotter.legend_fontsize,
        scatterpoints=1,
        columnspacing=0,
        handlelength=1.5,
        borderpad=0.2,
    )
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and 3.9
    # removed the old name; prefer the new attribute, fall back to the old.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:
        handle._sizes = [100]
        handle.set_alpha(0.8)

    os.makedirs(output_file_dir, exist_ok=True)
    pdf_path = join(output_file_dir, 'subsample_iid.pdf')
    fig.savefig(pdf_path, dpi='figure', bbox_inches='tight', pad_inches=0.1)
    print(f"Plot saved to {pdf_path}")
def generate_xy_plot(x_axis, y_axis, transform, output_dir, output_file_dir, skip_download):
    """Render the hypothetical-intervention plot as ``hyp_robust_imagenetv2.pdf``.

    A synthetic model row is derived from ``vgg19``: its ImageNetV2 accuracy
    is raised by 8, it is tagged ``HypModelTypes.HYP_ROBUST``, renamed
    ``vgg19_hyp_robust`` and excluded from the line fit. An arrow starting
    just above the original vgg19 point is passed to the plotter.

    Args:
        x_axis: Column used for the x coordinate.
        y_axis: Column used for the y coordinate.
        transform: Axis transform forwarded to the plotter.
        output_dir: Directory holding (or receiving) the cached ``grid_df.pkl``.
        output_file_dir: Directory the PDF is written to; created if missing.
        skip_download: When truthy, read the cached pickle instead of
            downloading.

    Raises:
        Exception: If ``skip_download`` is set and the cached pickle is missing.
    """
    if skip_download:
        cache_path = join(output_dir, 'grid_df.pkl')
        if not exists(cache_path):
            raise Exception(
                f'Downloaded data not found at {cache_path}. Please run python src/plotting/download_data.py first'
            )
        df = pd.read_pickle(cache_path)
    else:
        df = download_data.download_plotting_data(output_dir, store_data=True, verbose=True)

    df, df_metadata = dataframe.extract_metadata(df)
    df, df_metadata = dataframe.replace_10percent_with_metadata(df, df_metadata)
    df, df_metadata = dataframe.aggregate_corruptions_with_metadata(df, df_metadata)

    df = prepare_df_for_plotting(df, df_metadata, [x_axis, y_axis])
    df = plotter.add_plotting_data(df, [x_axis, y_axis])
    df = df.dropna()

    hyp_robust_model = df.loc['vgg19'].copy()
    # Captured before the row is shifted, so the arrow starts at the real model.
    arrow_params = (
        hyp_robust_model['val'],
        hyp_robust_model['imagenetv2-matched-frequency-format-val'] + 0.3,
        0,
        0.285,
    )
    hyp_robust_model.model_type = HypModelTypes.HYP_ROBUST
    hyp_robust_model['imagenetv2-matched-frequency-format-val'] += 8
    hyp_robust_model.name = 'vgg19_hyp_robust'
    hyp_robust_model.use_for_line_fit = False

    # DataFrame.append was removed in pandas 2.0. This is the equivalent
    # concat: a one-row frame with inferred dtypes and the same index name.
    hyp_row = hyp_robust_model.to_frame().T.infer_objects()
    hyp_row.index.name = df.index.name
    df = pd.concat([df, hyp_row])

    # auto set xlim and ylim based on visible points
    df_visible = df[df.show_in_plot == True]
    xlim = [df_visible[x_axis].min() - 1, df_visible[x_axis].max() + 1]
    ylim = [df_visible[y_axis].min() - 2, df_visible[y_axis].values.max() + 2]

    fig, ax = plotter.model_scatter_plot_hyp(
        df, x_axis, y_axis, xlim, ylim, HypModelTypes,
        transform=transform,
        tick_multiplier=5,
        title='Hypothetical Robustness Intervention',
        x_label='ImageNet',
        y_label='ImageNetV2',
        figsize=(12, 9),
        include_legend=True,
        return_separate_legend=False,
        alpha=0.7,
        arrow_params=arrow_params,
    )

    legend = ax.legend(
        loc='lower right',
        ncol=1,
        bbox_to_anchor=(1, 0),
        fontsize=plotter.legend_fontsize,
        scatterpoints=1,
        columnspacing=0,
        handlelength=1.5,
        borderpad=0.2,
    )
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and 3.9
    # removed the old name; prefer the new attribute, fall back to the old.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    for position, handle in enumerate(handles):
        # The third legend entry is drawn larger than the others.
        handle._sizes = [400] if position == 2 else [100]

    os.makedirs(output_file_dir, exist_ok=True)
    pdf_path = join(output_file_dir, 'hyp_robust_imagenetv2.pdf')
    fig.savefig(pdf_path, dpi='figure', bbox_inches='tight', pad_inches=0.1)
    print(f"Plot saved to {pdf_path}")
def generate_xy_plot(x_axis, y_axis_fit, transform, output_dir, output_file_dir, skip_download):
    """Plot effective robustness against network depth for several families.

    One line is drawn per model family using the ``eff_robust`` column
    produced by ``format_eff_robust``, looked up by model name. The figure is
    saved as ``eff_robust_inspect_plot.pdf`` in ``output_file_dir``.

    The figure used to be written into ``output_dir`` even though
    ``output_file_dir`` was the directory being created; it now goes to
    ``output_file_dir``.

    Args:
        x_axis: Column forwarded to the preparation helpers.
        y_axis_fit: Column forwarded to the preparation helpers.
        transform: Forwarded to ``format_eff_robust``.
        output_dir: Directory holding (or receiving) the cached ``grid_df.pkl``.
        output_file_dir: Directory the PDF is written to; created if missing.
        skip_download: When truthy, read the cached pickle instead of
            downloading.

    Raises:
        Exception: If ``skip_download`` is set and the cached pickle is missing.
    """
    if skip_download:
        cache_path = join(output_dir, 'grid_df.pkl')
        if not exists(cache_path):
            raise Exception(
                f'Downloaded data not found at {cache_path}. Please run python src/plotting/download_data.py first'
            )
        df = pd.read_pickle(cache_path)
    else:
        df = download_data.download_plotting_data(output_dir, store_data=True, verbose=True)

    df, df_metadata = dataframe.extract_metadata(df)
    df, df_metadata = dataframe.replace_10percent_with_metadata(df, df_metadata)
    df, df_metadata = dataframe.aggregate_corruptions_with_metadata(df, df_metadata)

    df = prepare_df_for_plotting(df, df_metadata, [x_axis, y_axis_fit])
    df = plotter.add_plotting_data(df, [x_axis, y_axis_fit])
    df = format_eff_robust(df, x_axis, y_axis_fit, transform)

    eff_robust = df.eff_robust

    fig, ax = plt.subplots(1, figsize=(8, 6))
    ax.set_xticks([18, 34, 50, 101, 152])
    ax.set_xlabel('Number of Layers')
    ax.set_ylabel('Effective Robustness (ImageNetV2)')

    # (depths, model-name template, line format or None, legend label, colour)
    families = [
        ([18, 34, 50, 101, 152], 'resnet{}', None, 'resnet', 'blue'),
        ([50, 101, 152], 'se_resnet{}', None, 'se_resnet', 'green'),
        ([18, 50], 'resnet{}_ssl', '--', 'resnet_ssl', 'blue'),
        ([18, 50], 'resnet{}_swsl', '-.', 'resnet_swsl', 'blue'),
        ([50, 101], 'resnext{}_32x4d', None, 'resnext (32x4d)', 'red'),
        ([50, 101], 'resnext{}_32x4d_ssl', '--', 'resnext_ssl (32x4d)', 'red'),
        ([50, 101], 'resnext{}_32x4d_swsl', '-.', 'resnext_swsl (32x4d)', 'red'),
    ]
    for depths, name_template, line_format, label, colour in families:
        values = eff_robust[[name_template.format(depth) for depth in depths]].values
        # Only pass a format string when one was specified, so the default
        # line style is left untouched for the plain families.
        format_args = () if line_format is None else (line_format,)
        ax.plot(depths, values, *format_args, label=label, c=colour)

    # ax.scatter(152, eff_robust['resnet152-imagenet11k'])
    ax.legend()

    os.makedirs(output_file_dir, exist_ok=True)
    pdf_path = join(output_file_dir, 'eff_robust_inspect_plot.pdf')
    fig.savefig(pdf_path, dpi='figure', bbox_inches='tight', pad_inches=0.1)
    print(f"Plot saved to {pdf_path}")
def generate_xy_plot(x_axis, y_axis, transform, output_dir, output_file_dir, skip_download):
    """Render the hypothetical ResNet50 quadrant plot as ``small_resnet50.pdf``.

    A synthetic model row is derived from ``resnet50_aws_baseline``: its
    ``val`` accuracy is lowered by 8 and its ImageNetV2 accuracy by 3, it is
    tagged ``HypModelTypes.ROBUST``, renamed ``resnet50_aws_baseline_robust``
    and excluded from the line fit. ``arrow_params`` holds the original
    coordinates followed by the shifted ones. The axis limits are fixed
    constants rather than derived from the data.

    Args:
        x_axis: Column used for the x coordinate.
        y_axis: Column used for the y coordinate.
        transform: Axis transform forwarded to the plotter.
        output_dir: Directory holding (or receiving) the cached ``grid_df.pkl``.
        output_file_dir: Directory the PDF is written to; created if missing.
        skip_download: When truthy, read the cached pickle instead of
            downloading.

    Raises:
        Exception: If ``skip_download`` is set and the cached pickle is missing.
    """
    if skip_download:
        cache_path = join(output_dir, 'grid_df.pkl')
        if not exists(cache_path):
            raise Exception(
                f'Downloaded data not found at {cache_path}. Please run python src/plotting/download_data.py first'
            )
        df = pd.read_pickle(cache_path)
    else:
        df = download_data.download_plotting_data(output_dir, store_data=True, verbose=True)

    df, df_metadata = dataframe.extract_metadata(df)
    df, df_metadata = dataframe.replace_10percent_with_metadata(df, df_metadata)
    df, df_metadata = dataframe.aggregate_corruptions_with_metadata(df, df_metadata)

    df = prepare_df_for_plotting(df, df_metadata, [x_axis, y_axis])
    df = plotter.add_plotting_data(df, [x_axis, y_axis])
    df = df.dropna()

    hyp_robust_model = df.loc['resnet50_aws_baseline'].copy()
    # Arrow start: the real model's coordinates, captured before shifting.
    arrow_params = [
        hyp_robust_model['val'],
        hyp_robust_model['imagenetv2-matched-frequency-format-val'],
    ]
    hyp_robust_model.model_type = HypModelTypes.ROBUST
    hyp_robust_model['val'] -= 8
    hyp_robust_model['imagenetv2-matched-frequency-format-val'] -= 3
    # Arrow end: the shifted, hypothetical coordinates.
    arrow_params += [
        hyp_robust_model['val'],
        hyp_robust_model['imagenetv2-matched-frequency-format-val'],
    ]
    hyp_robust_model.name = 'resnet50_aws_baseline_robust'
    hyp_robust_model.use_for_line_fit = False

    # DataFrame.append was removed in pandas 2.0. This is the equivalent
    # concat: a one-row frame with inferred dtypes and the same index name.
    hyp_row = hyp_robust_model.to_frame().T.infer_objects()
    hyp_row.index.name = df.index.name
    df = pd.concat([df, hyp_row])

    xlim = [55.522000013427736, 85.72600555419922]
    ylim = [51.43000002441406, 76.83999633789062]

    fig, ax = plotter.hyp_model_scatter_plot_quadrants(
        df, x_axis, y_axis, xlim, ylim, HypModelTypes,
        transform=transform,
        tick_multiplier=5,
        title='Hypothetical Robustness Intervention',
        alpha=0.8,
        x_label='ImageNet',
        y_label='ImageNetV2',
        arrow_params=arrow_params,
        figsize=(12, 8),
        include_legend=True,
        return_separate_legend=False,
    )

    legend = ax.legend(
        loc='upper left',
        ncol=1,
        bbox_to_anchor=(0, 1),
        fontsize=plotter.legend_fontsize,
        scatterpoints=1,
        columnspacing=0,
        handlelength=1.5,
        borderpad=0.2,
    )
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and 3.9
    # removed the old name; prefer the new attribute, fall back to the old.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:
        handle._sizes = [100]

    os.makedirs(output_file_dir, exist_ok=True)
    pdf_path = join(output_file_dir, 'small_resnet50.pdf')
    fig.savefig(pdf_path, dpi='figure', bbox_inches='tight', pad_inches=0.1)
    print(f"Plot saved to {pdf_path}")
def generate_xy_plot(x_axis, y_axis, transform, num_bootstrap_samples, output_dir, output_file_dir, skip_download, option):
    """Render the ImageNet-A scatter plot and save it as a PDF.

    ``option`` selects a variant:

    * ``'no-subsample'``: the x axis is forced to ``'val'``, the pivot passed
      to the plotter is 76.13 instead of 91.86, and the file name gains
      ``_subsample``.
    * ``'show-yx'``: the upper y limit is taken over both plotted columns and
      the file name gains ``_yx``.
    * ``'only-standard'``: the file name gains ``_standard``.

    ``option`` is also forwarded to ``prepare_df_for_plotting``. For any
    truthy ``option`` the legend is removed from the axes before saving.

    Args:
        x_axis: Column used for the x coordinate (ignored for
            ``'no-subsample'``).
        y_axis: Column used for the y coordinate.
        transform: Axis transform forwarded to the plotter.
        num_bootstrap_samples: Forwarded to the plotter.
        output_dir: Directory holding (or receiving) the cached ``grid_df.pkl``.
        output_file_dir: Directory the PDF is written to; created if missing.
        skip_download: When truthy, read the cached pickle instead of
            downloading.
        option: Variant selector, see above.

    Raises:
        Exception: If ``skip_download`` is set and the cached pickle is missing.
    """
    if skip_download:
        cache_path = join(output_dir, 'grid_df.pkl')
        if not exists(cache_path):
            raise Exception(
                f'Downloaded data not found at {cache_path}. Please run python src/plotting/download_data.py first'
            )
        df = pd.read_pickle(cache_path)
    else:
        df = download_data.download_plotting_data(output_dir, store_data=True, verbose=True)

    df, df_metadata = dataframe.extract_metadata(df)
    df, df_metadata = dataframe.replace_10percent_with_metadata(df, df_metadata)
    df, df_metadata = dataframe.aggregate_corruptions_with_metadata(df, df_metadata)

    if option == 'no-subsample':
        x_axis = 'val'

    df = prepare_df_for_plotting(df, df_metadata, [x_axis, y_axis], option)
    df = plotter.add_plotting_data(df, [x_axis, y_axis])

    # auto set xlim and ylim based on visible points
    df_visible = df[df.show_in_plot == True]
    xlim = [df_visible[x_axis].min() - 1, df_visible[x_axis].max() + 0.8]
    if option == 'show-yx':
        # The y range must also cover the x values when both are shown.
        ylim = [
            df_visible[y_axis].min() - 0.2,
            df_visible[[x_axis, y_axis]].values.max() + 1,
        ]
    else:
        ylim = [
            df_visible[y_axis].min() - 0.2,
            df_visible[y_axis].values.max() + 2,
        ]

    x_label = 'ImageNet' if option == 'no-subsample' else 'ImageNet (class-subsampled)'

    fig, ax = plotter.model_scatter_plot_imagenet_a(
        df, x_axis, y_axis, xlim, ylim, NatModelTypes,
        transform=transform,
        tick_multiplier=10,
        extra_y_ticks=[5],
        num_bootstrap_samples=num_bootstrap_samples,
        title='Distribution Shift to Imagenet-A',
        x_label=x_label,
        y_label='ImageNet-A',
        alpha=0.6,
        figsize=(12, 8),
        include_legend=True,
        return_separate_legend=False,
        pivot=76.13 if option == 'no-subsample' else 91.86,
        extra_x_ticks=[95, 96, 97, 98, 99],
    )

    legend = ax.legend(
        loc='lower right',
        ncol=1,
        bbox_to_anchor=(1.01, -0.01),
        fontsize=plotter.legend_fontsize,
        scatterpoints=1,
        columnspacing=0,
        handlelength=1.5,
        borderpad=0.2,
    )
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and 3.9
    # removed the old name; prefer the new attribute, fall back to the old.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:
        handle._sizes = [100]

    if option:
        ax.get_legend().remove()

    os.makedirs(output_file_dir, exist_ok=True)

    filename = 'imagenet_a'
    if option == 'only-standard':
        filename += '_standard'
    if option == 'no-subsample':
        filename += '_subsample'
    if option == 'show-yx':
        filename += '_yx'

    pdf_path = join(output_file_dir, f'{filename}.pdf')
    fig.savefig(pdf_path, dpi='figure', bbox_inches='tight', pad_inches=0.1)
    print(f"Plot saved to {pdf_path}")
def generate_xy_plot(x_axis, y_axis, x_axis_fit, y_axis_fit, transform, num_bootstrap_samples, output_dir, output_file_dir, skip_download, x_label, y_label):
    """Render an effective-robustness scatter plot plus its separate legend.

    The plotted columns are ``eff_robust_x`` / ``eff_robust_y`` as produced
    by ``format_eff_robust``. The model-type enum comes from the
    module-level name ``cur_model_types``; when it has exactly three members
    the legend is saved as ``eff_robust_legend.pdf``, otherwise as
    ``eff_robust_legend2.pdf``. The figure file name is built from the second
    ``_``-separated token of ``y_axis`` and from ``y_axis_fit`` with ``"1.0"``
    shortened to ``"1"``.

    Args:
        x_axis: Column forwarded to the helpers and ``format_eff_robust``.
        y_axis: Column forwarded to the helpers; also used in the file name.
        x_axis_fit: Column forwarded to the helpers and ``format_eff_robust``.
        y_axis_fit: Column forwarded to the helpers; also used in the file name.
        transform: Forwarded to ``format_eff_robust``.
        num_bootstrap_samples: Accepted for interface compatibility; not used
            here.
        output_dir: Directory holding (or receiving) the cached ``grid_df.pkl``.
        output_file_dir: Directory the PDFs are written to; created if missing.
        skip_download: When truthy, read the cached pickle instead of
            downloading.
        x_label: Prefix of the x-axis label.
        y_label: Prefix of the y-axis label.

    Raises:
        Exception: If ``skip_download`` is set and the cached pickle is missing.
    """
    if not skip_download:
        df = download_data.download_plotting_data(output_dir, store_data=True, verbose=True)
    else:
        cache_path = join(output_dir, 'grid_df.pkl')
        if not exists(cache_path):
            raise Exception(
                f'Downloaded data not found at {cache_path}. Please run python src/plotting/download_data.py first'
            )
        df = pd.read_pickle(cache_path)

    df, df_metadata = dataframe.extract_metadata(df)
    df, df_metadata = dataframe.replace_10percent_with_metadata(df, df_metadata)
    df, df_metadata = dataframe.aggregate_corruptions_with_metadata(df, df_metadata)

    df = prepare_df_for_plotting(df, df_metadata, [x_axis, y_axis, x_axis_fit, y_axis_fit])
    df = plotter.add_plotting_data(df, [x_axis, y_axis, x_axis_fit, y_axis_fit])
    df = format_eff_robust(df, x_axis, y_axis, x_axis_fit, y_axis_fit, transform)

    # dfp = df[df.show_in_plot][['eff_robust_x', 'eff_robust_y']].dropna()
    # print("PEARSONR:", scipy.stats.pearsonr(dfp['eff_robust_x'], dfp['eff_robust_y'])[0])

    # Axis limits follow the visible points' effective-robustness values.
    shown = df[df.show_in_plot == True]
    eff_x = shown['eff_robust_x']
    xlim = [eff_x.min() - 1, eff_x.max() + 1]
    eff_y = shown['eff_robust_y']
    ylim = [eff_y.min() - 0.5, eff_y.values.max() + 0.5]

    fig, _, legend = plotter.simple_scatter_plot(
        df, 'eff_robust_x', 'eff_robust_y', xlim, ylim, cur_model_types,
        title='Effective Robustness Scatterplot',
        x_tick_multiplier=5,
        y_tick_multiplier=1,
        x_label=f'{x_label} Effective Robustness',
        y_label=f'{y_label}\nEffective Robustness',
        figsize=(12, 8),
        include_legend=False,
        return_separate_legend=True,
    )

    os.makedirs(output_file_dir, exist_ok=True)

    if len(cur_model_types) == 3:
        legend_name = 'eff_robust_legend.pdf'
    else:
        legend_name = 'eff_robust_legend2.pdf'
    legend_path = join(output_file_dir, legend_name)
    legend.savefig(legend_path, dpi='figure', bbox_inches='tight', pad_inches=0.1)
    print(f"Legend saved to {legend_path}")

    dataset_token = y_axis.split("_")[1]
    fit_token = y_axis_fit.replace("1.0", "1")
    fig_path = join(output_file_dir, f'eff_robust_{dataset_token}_{fit_token}.pdf')
    fig.savefig(fig_path, dpi='figure', bbox_inches='tight', pad_inches=0.1)
    print(f"Plot saved to {fig_path}")