示例#1
0
def main(args):
    """Render the confusion matrices of a single NIP model.

    The output mode is ``args.output`` resolved against ``supported_out``:

    * ``'plot'`` -- a grid of seaborn heatmaps, one per confusion matrix,
      each titled with the mean of its diagonal,
    * ``'raw'``  -- text tables printed via ``conf2txt``,
    * ``'tex'``  -- LaTeX tables printed via ``conf2tex``.

    Every matched mode terminates the process with ``sys.exit(0)``. When no
    confusion data is found the process exits with status 2.

    Args:
        args: parsed command-line namespace; reads ``output``, ``nips``,
            ``cameras``, ``run``, ``reg`` and ``dir`` (``nips`` may be
            replaced by its first element).
    """
    output = coreutils.match_option(args.output, supported_out)

    # Confusion data is loaded for a single NIP only - keep the first one.
    if isinstance(args.nips, list):
        if len(args.nips) > 1:
            print('WARNING Only one NIP will be used for this plot!')
        args.nips = args.nips[0]

    conf = confusion_data(args.nips,
                          args.cameras,
                          args.run,
                          args.reg,
                          root_dir=args.dir)

    # Without this guard the plot grid below would be 0 x (0 / 0).
    if len(conf) == 0:
        print('ERROR No results found!')
        sys.exit(2)

    # (name, matrix scaled by 100 and rounded to integers, class labels)
    tables = [(k, (100 * c['data']).round(0), c['labels'])
              for k, c in conf.items()]

    if output == 'plot':
        import seaborn as sns

        # Subplot grid sizes must be ints: np.ceil returns floats, which
        # recent matplotlib versions reject in add_subplot().
        images_x = int(np.ceil(np.sqrt(len(tables))))
        images_y = int(np.ceil(len(tables) / images_x))
        f_size = 3
        sns.set()
        fig = plt.figure(figsize=(images_x * f_size, images_y * f_size))

        for i, (k, data, labels) in enumerate(tables):
            # Mean of the diagonal, reported as the accuracy (in percent).
            acc = np.mean(np.diag(data))
            ax = fig.add_subplot(images_y, images_x, i + 1)
            sns.heatmap(data,
                        annot=True,
                        fmt=".0f",
                        linewidths=.5,
                        xticklabels=[x[0] for x in labels],
                        yticklabels=labels)
            ax.set_title('{} : acc={:.1f}'.format(k, acc))

        plt.tight_layout()
        plt.show()
        sys.exit(0)

    if output in ('raw', 'tex'):
        to_text = conf2txt if output == 'raw' else conf2tex
        for k, data, labels in tables:
            print(to_text(data, labels, k))

        sys.exit(0)

    print('No output mode matched!')
示例#2
0
def main():
    """Command-line entry point for inspecting a trained DCN codec.

    Plot types (``plot`` argument, matched against ``supported_plots``):

    * ``batch`` -- restore the DCN and show its output on a validation batch
      (via ``show_example``),
    * ``jpeg-match-ssim`` / ``jpeg-match-bpp`` -- compare the DCN with JPEG
      on a single image (``--image``), matching SSIM or bit-rate respectively
      (via ``match_jpeg``),
    * ``jpg-trade-off`` / ``jp2-trade-off`` / ``dcn-trade-off`` /
      ``bpg-trade-off`` -- print a rate-distortion data frame for the codec.
    """
    parser = argparse.ArgumentParser(
        description='Test a neural imaging pipeline')
    parser.add_argument('plot',
                        help='Plot type ({})'.format(
                            ', '.join(supported_plots)))
    parser.add_argument(
        '--data',
        dest='data',
        action='store',
        default='./data/rgb/clic256/',
        help='directory with training & validation images (png)')
    parser.add_argument('--images',
                        dest='images',
                        action='store',
                        default=10,
                        type=int,
                        help='number of images to test')
    parser.add_argument('--image',
                        dest='image_id',
                        action='store',
                        default=1,
                        type=int,
                        help='ID of the image to load')
    parser.add_argument('--patch',
                        dest='patch_size',
                        action='store',
                        default=128,
                        type=int,
                        help='training patch size')
    parser.add_argument('--dcn',
                        dest='dcn',
                        action='store',
                        help='directory with a trained DCN model')

    args = parser.parse_args()

    # Resolve the (possibly abbreviated) plot name to a supported option
    args.plot = coreutils.match_option(args.plot, supported_plots)

    if args.plot == 'batch':
        model, stats = codec.restore_model(args.dcn,
                                           args.patch_size,
                                           fetch_stats=True)
        print('Training stats:', stats)

        data = dataset.IPDataset(args.data,
                                 load='y',
                                 n_images=0,
                                 v_images=args.images,
                                 val_rgb_patch_size=args.patch_size)
        batch_x = data.next_validation_batch(0, args.images)

        show_example(model, batch_x)
        plt.show()
        plt.close()

    elif args.plot in ('jpeg-match-ssim', 'jpeg-match-bpp'):
        files, _ = loading.discover_files(args.data, n_images=-1, v_images=0)
        # Keep only the requested image
        files = files[args.image_id:args.image_id + 1]
        if not files:
            parser.error('image ID {} is out of range ({} images found)'.format(
                args.image_id, len(files)))

        batch_x = loading.load_images(files, args.data, load='y')
        # Scale 8-bit pixel values to [0, 1]
        batch_x = batch_x['y'].astype(np.float32) / (2**8 - 1)

        model = codec.restore_model(args.dcn, batch_x.shape[1])

        # The plot name's suffix ('ssim' or 'bpp') is the matched quantity
        match_jpeg(model, batch_x, match=args.plot.split('-')[-1])
        plt.show()
        plt.close()

    elif args.plot == 'jpg-trade-off':
        df = ratedistortion.get_jpeg_df(args.data, write_files=True)
        print(df.to_string())

    elif args.plot == 'jp2-trade-off':
        df = ratedistortion.get_jpeg2k_df(args.data, write_files=True)
        print(df.to_string())

    elif args.plot == 'dcn-trade-off':
        df = ratedistortion.get_dcn_df(args.data, args.dcn, write_files=False)
        print(df.to_string())

    elif args.plot == 'bpg-trade-off':
        df = ratedistortion.get_bpg_df(args.data, write_files=False)
        print(df.to_string())

    else:
        print('Error: Unknown plot!')
示例#3
0
def _guess_component_name(values, template):
    """Guess a readable column name for one component of the scenario path.

    The name is chosen from the component's content:

    * first value ends with ``'Net'``        -> ``'nip'``
    * first value starts with ``'ln-'``      -> ``'nip reg.'``
    * first value starts with ``'lc-'``      -> ``'dcn reg.'``
    * unique values are {'4k', '8k', '16k'}  -> ``'dcn'``
    * all unique values are 2-3 digit numbers -> ``'jpeg'``
    * otherwise                              -> ``template``

    Args:
        values: pandas Series with this component for every row; may contain
            None when scenario paths have different depths.
        template: fallback name, e.g. ``'scenario:2'``.
    """
    first = values.iloc[0]
    first = first if isinstance(first, str) else ''
    uniques = values.unique()

    if first.endswith('Net'):
        return 'nip'
    if first.startswith('ln-'):
        return 'nip reg.'
    if first.startswith('lc-'):
        return 'dcn reg.'
    if set(uniques) == {'4k', '8k', '16k'}:
        return 'dcn'
    if all(isinstance(v, str) and re.match(r'^[0-9]{2,3}$', v)
           for v in uniques):
        return 'jpeg'
    return template


def _summarize_by_scenario(df):
    """Return a text table with repetition counts and numeric means per scenario."""
    gb = df.groupby('scenario')
    counts = gb.size().to_frame(name='reps')
    # numeric_only: text columns cannot be averaged (pandas >= 2.0 raises a
    # TypeError instead of silently dropping them).
    return counts.join(gb.mean(numeric_only=True)).reset_index().to_string()


def display_results(args):
    """Plot or summarize manipulation-detection results stored in ``args.dir``.

    ``args.plot`` is matched against ``supported_plots``. Handled modes:

    * ``ssim`` / ``psnr`` / ``accuracy`` -- box plots of the metric,
    * ``scatter-psnr`` / ``scatter-ssim`` -- accuracy vs. image quality,
    * ``progress`` -- PSNR and accuracy vs. training step,
    * ``conf`` / ``conf-tex`` -- confusion matrices (text/LaTeX + heatmaps),
    * ``df`` -- per-scenario summary table,
    * ``auto`` -- box plot with axes/facets guessed from scenario paths.

    Data frames are passed to ``save_df`` together with ``args.df``.

    Raises:
        FileNotFoundError: if ``args.dir`` is not a directory.
        RuntimeError: if the plot type matches none of the modes.
    """
    sns.set('paper', font_scale=1, style="ticks")
    plot = coreutils.match_option(args.plot, supported_plots)

    if not os.path.isdir(args.dir):
        raise FileNotFoundError('Directory {} not found!'.format(args.dir))

    print('Results from: {}'.format(args.dir))
    print('Matched plotting command: {}'.format(plot))

    # Suffix for saved data frames: <last dir component>-<nips>-<cameras>
    postfix = [
        coreutils.splitall(args.dir)[-1],
        ','.join(args.nips) if args.nips is not None else None,
        ','.join(args.cameras) if args.cameras is not None else None,
    ]
    postfix = '-'.join(x for x in postfix if x is not None)

    if plot in ['ssim', 'psnr', 'accuracy']:

        df = results_data.manipulation_metrics(args.nips,
                                               args.cameras,
                                               root_dir=args.dir)
        sns.catplot(x='ln',
                    y=plot,
                    col='camera',
                    row='nip',
                    data=df,
                    kind='box')
        save_df(df, args.df, 'manipulation_metrics-{}.csv'.format(postfix))
        plt.show()
        return

    if plot == 'scatter-psnr' or plot == 'scatter-ssim':

        df = results_data.manipulation_metrics(args.nips,
                                               args.cameras,
                                               root_dir=args.dir)

        if len(df) == 0:
            print('ERROR No results found!')
            sys.exit(2)

        print(df)
        sns.relplot(x=plot.split('-')[-1],
                    y='accuracy',
                    hue='ln',
                    col='camera',
                    row='nip',
                    data=df,
                    palette=sns.color_palette("Set2",
                                              len(df['ln'].unique())))
        save_df(df, args.df, 'manipulation_metrics-{}.csv'.format(postfix))
        plt.show()
        return

    if plot == 'progress':

        cases = []

        if args.cameras is None:
            args.cameras = coreutils.listdir(args.dir, '.', dirs_only=True)

        for cam in args.cameras:

            nip_models = args.nips or coreutils.listdir(
                os.path.join(args.dir, cam), '.', dirs_only=True)

            for nip in nip_models:

                reg_path = os.path.join(args.dir, cam, nip)

                if args.regularization:
                    # If given, use specified regularization strengths
                    reg_list = args.regularization
                else:
                    # Otherwise, auto-detect available scenarios
                    reg_list = coreutils.listdir(reg_path,
                                                 '.*',
                                                 dirs_only=True)

                    if len(reg_list) > 4:
                        # Keep 4 evenly spaced scenarios to keep plots legible
                        indices = np.linspace(0,
                                              len(reg_list) - 1,
                                              4).astype(np.int32)
                        reg_list = [reg_list[i] for i in indices]
                        print(
                            '! warning - too many experiments to show - sampling: {}'
                            .format(reg_list))

                for reg in reg_list:
                    for r in coreutils.listdir(os.path.join(reg_path, reg),
                                               '[0-9]+',
                                               dirs_only=True):
                        print('* found scenario {}'.format(
                            (cam, nip, reg, int(r))))
                        cases.append((cam, nip, reg, int(r)))

        df, labels = results_data.manipulation_progress(cases,
                                                        root_dir=args.dir)
        save_df(df, args.df, 'progress-{}.csv'.format(postfix))

        for col in ['psnr', 'accuracy']:
            sns.relplot(x="step",
                        y=col,
                        hue='exp',
                        row='nip',
                        col='camera',
                        style='exp',
                        kind="line",
                        legend="full",
                        aspect=2,
                        height=3,
                        data=df)

        plt.show()
        return

    if plot == 'conf' or plot == 'conf-tex':

        if isinstance(args.nips, list):
            if len(args.nips) > 1:
                print('WARNING Only one NIP will be used for this plot!')
            args.nips = args.nips[0]

        conf = results_data.confusion_data(args.run, root_dir=args.dir)

        if len(conf) == 0:
            print('ERROR No results found!')
            return

        tex_output = plot == 'conf-tex'
        # Heatmaps only for text output and a manageable number of matrices
        plot_data = not tex_output and len(conf) < 20

        if plot_data:
            # add_subplot() requires integer grid sizes (np.ceil returns floats)
            images_x = int(np.ceil(np.sqrt(len(conf))))
            images_y = int(np.ceil(len(conf) / images_x))
            f_size = 3
            fig = plt.figure(figsize=(images_x * f_size, images_y * f_size))

        for i, (k, c) in enumerate(conf.items()):
            data = (100 * c['data']).round(0)
            labels = c['labels']
            print(results_data.confusion_to_text(
                data, labels, k, 'tex' if tex_output else 'txt'))

            if plot_data:
                acc = np.mean(np.diag(data))
                ax = fig.add_subplot(images_y, images_x, i + 1)
                sns.heatmap(data,
                            annot=True,
                            fmt=".0f",
                            linewidths=.5,
                            xticklabels=[x[0] for x in labels],
                            yticklabels=labels)
                ax.set_title('{} : acc={:.1f}'.format(k, acc))

        if plot_data:
            plt.tight_layout()
            plt.show()

        return

    if plot == 'df':

        print('Searching for "training.json" in', args.dir)
        df = results_data.manipulation_summary(args.dir)

        if len(df) > 0:
            print(_summarize_by_scenario(df))

        save_df(df, args.df, 'summary-{}.csv'.format(postfix))

        return

    if plot == 'auto':

        print('Searching for "training.json" in', args.dir)
        df = results_data.manipulation_summary(args.dir)

        # Checked before sorting: an empty frame may lack the 'scenario' column
        if len(df) == 0:
            print('ERROR No results found!')
            return

        df = df.sort_values('scenario')

        # Split scenario paths into components and give each a guessed name
        guessed_names = {}
        components = df['scenario'].str.split("/", expand=True)
        for i in components:
            template = 'scenario:{}'.format(i)
            guessed_names[template] = _guess_component_name(
                components[i], template)
            df[guessed_names[template]] = components[i]

        df['scenario'] = coreutils.remove_commons(df['scenario'])

        # Choose the feature with most unique values as x axis
        uniques = [
            len(df[guessed_names['scenario:{}'.format(i)]].unique())
            for i in components
        ]
        x_feature = np.argmax(uniques)

        # Map the remaining varying features to distinct facets / encodings
        # (each target used once - a repeated key would silently drop a feature)
        mapping = {}
        mapping_targets = ['col', 'row', 'hue', 'style', 'size']
        mapping_id = 0

        for i in components:
            if i == x_feature:
                continue

            name = guessed_names['scenario:{}'.format(i)]
            if len(df[name].unique()) > 1:
                if mapping_id >= len(mapping_targets):
                    print('! warning - too many varying features - ignoring: {}'
                          .format(name))
                    continue
                mapping[mapping_targets[mapping_id]] = name
                mapping_id += 1

        sns.catplot(x=guessed_names['scenario:{}'.format(x_feature)],
                    y='accuracy',
                    data=df,
                    kind='box',
                    **mapping)
        plt.show()

        print(_summarize_by_scenario(df))

        return

    raise RuntimeError('No plot matched! Available plots {}'.format(
        ', '.join(supported_plots)))
示例#4
0
def _save_dataframe(df, directory, filename):
    """Save ``df`` as CSV to ``directory/filename``; no-op when ``directory`` is None."""
    if directory is None:
        return
    os.makedirs(directory, exist_ok=True)
    df_filename = '{}/{}'.format(directory, filename)
    df.to_csv(df_filename, index=False)
    print('> saving dataframe to {}'.format(df_filename))


def display_results(args):
    """Plot experiment results stored under ``args.dir``.

    ``args.plot`` is matched against ``supported_plots``. Handled modes:

    * ``boxplot`` -- box plots from ``boxplot_data``, one figure per NIP,
    * ``psnr`` / ``ssim`` -- box plots of the given quality field per NIP,
    * ``scatter-psnr`` / ``scatter-ssim`` -- accuracy vs. quality, per camera,
    * ``progressplot`` -- PSNR and accuracy vs. training step,
    * ``confusion`` -- confusion-matrix heatmaps (first NIP only).

    When ``args.df`` is set, the plotted data frames are also written there
    as CSV files.

    Raises:
        RuntimeError: if the plot type matches none of the modes.
    """
    sns.set()
    plot = coreutils.match_option(args.plot, supported_plots)

    if plot == 'boxplot':

        for nip in args.nips:
            df = boxplot_data(nip, args.cameras, root_dir=args.dir)
            print(df)
            print('Averages')
            print(df.mean().round(2))
            plt.figure()
            sns.boxplot(data=df)
            plt.xticks(rotation=90)
            plt.gca().set_title(nip)
            _save_dataframe(df, args.df,
                            'box-{}-{}-{}.csv'.format('accuracy', nip, plot))

        plt.show()
        return

    if plot == 'psnr' or plot == 'ssim':

        for nip in args.nips:
            df = boxplot_data(nip, args.cameras, field=plot, root_dir=args.dir)
            print(df)
            print('Averages')
            print(df.mean().round(1 if plot == 'psnr' else 3))
            plt.figure()
            sns.boxplot(data=df)
            plt.xticks(rotation=90)
            plt.gca().set_title(nip)
            _save_dataframe(df, args.df,
                            'box-{}-{}-{}.csv'.format(plot, nip, plot))

        plt.show()
        return

    if plot == 'scatter-psnr' or plot == 'scatter-ssim':

        if args.cameras is None:
            args.cameras = coreutils.listdir(args.dir, '.', dirs_only=True)

        for cam in args.cameras:
            df = scatterplot_data(args.nips, cam, root_dir=args.dir)
            print(df)
            sns.relplot(x=plot.split('-')[-1], y='accuracy', hue='lr',
                        col='camera', data=df)
            _save_dataframe(df, args.df, 'scatter-{}-{}.csv'.format(
                cam, ','.join(args.nips)))

        plt.show()
        return

    if plot == 'progressplot':

        cases = []

        if args.cameras is None:
            args.cameras = coreutils.listdir(args.dir, '.', dirs_only=True)

        for cam in args.cameras:
            for nip in args.nips:

                reg_path = os.path.join(args.dir, cam, nip)

                if args.regularization:
                    # If given, use specified regularization strengths
                    reg_list = args.regularization
                else:
                    # Otherwise, auto-detect available scenarios (raw string:
                    # '\.' is a regex escape, not a Python one)
                    reg_list = coreutils.listdir(reg_path, r'lr-[0-9\.]+',
                                                 dirs_only=True)

                    if len(reg_list) > 4:
                        # Keep 4 evenly spaced scenarios to keep plots legible
                        indices = np.linspace(0, len(reg_list) - 1,
                                              4).astype(np.int32)
                        reg_list = [reg_list[i] for i in indices]
                        print('! warning - too many experiments to show - sampling: {}'.format(reg_list))

                for reg in reg_list:
                    for r in coreutils.listdir(os.path.join(reg_path, reg),
                                               '[0-9]+', dirs_only=True):
                        print('* found scenario {}'.format((cam, nip, reg, int(r))))
                        cases.append((cam, nip, reg, int(r)))

        df, labels = progressplot_data(cases, root_dir=args.dir)
        _save_dataframe(df, args.df, 'progress-{}-{}.csv'.format(
            ','.join(args.cameras), ','.join(args.nips)))

        for col in ['psnr', 'accuracy']:
            sns.relplot(x="step", y=col, hue='exp', col='nip', row='camera',
                        style='exp', kind="line", legend="full", aspect=2,
                        height=3, data=df)

        plt.show()
        return

    if plot == 'confusion':

        # Confusion data is loaded for a single NIP - keep the first one
        if isinstance(args.nips, list):
            if len(args.nips) > 1:
                print('WARNING Only one NIP will be used for this plot!')
            args.nips = args.nips[0]

        conf = confusion_data(args.nips, args.cameras, root_dir=args.dir)

        # Without this guard the subplot grid below would be 0 x (0 / 0)
        if len(conf) == 0:
            print('ERROR No results found!')
            return

        # add_subplot() requires integer grid sizes (np.ceil returns floats)
        images_x = int(np.ceil(np.sqrt(len(conf))))
        images_y = int(np.ceil(len(conf) / images_x))
        f_size = 3
        fig = plt.figure(figsize=(images_x * f_size, images_y * f_size))

        for i, (k, c) in enumerate(conf.items()):
            data = (100 * c['data']).round(0)
            labels = c['labels']
            print('\n', k, '=')
            print(data)
            print(labels)
            acc = np.mean(np.diag(data))
            print('Accuracy = {}'.format(acc))
            ax = fig.add_subplot(images_y, images_x, i + 1)
            sns.heatmap(data, annot=True, fmt=".0f", linewidths=.5,
                        xticklabels=[x[0] for x in labels], yticklabels=labels)
            ax.set_title('{} : acc={:.1f}'.format(k, acc))

        plt.tight_layout()
        plt.show()
        return

    raise RuntimeError('No plot matched! Available plots {}'.format(', '.join(supported_plots)))
示例#5
0
def plot_bulk(plots,
              dirname,
              plot_images,
              metric,
              plot,
              baseline_count=3,
              add_legend=True,
              max_bpp=5,
              draw_markers=1,
              marker_legend=False,
              update_ylim=False):
    """Draw rate-distortion curves in a grid with one subplot per image.

    Args:
        plots: specification forwarded to ``load_data``, which returns one data
            frame and one label per compared method.
        dirname: dataset directory forwarded to ``load_data``; its last path
            component appears in aggregate subplot titles.
        plot_images: image IDs to draw, one subplot each, placed in list order;
            ``-1`` means "all images". When empty, ``-1`` followed by every
            image ID of the first data frame is used.
        metric: name of the quality column shown on the y axis.
        plot: ``'fit'`` (average of per-image fitted curves) or ``'aggregate'``
            (mean bpp / quality per quality level), via ``match_option``.
        baseline_count: number of baseline methods; keeps line styles consistent
            across figures with a different number of baselines.
        add_legend: whether curves get legend labels.
        max_bpp: upper end of the x axis and of the fitted-curve range.
        draw_markers: 0 disables individual measurements. For data frames with
            an ``entropy_reg`` column, scatter markers are drawn on single-image
            subplots (or on every subplot when >= 2); for other data frames the
            raw points are drawn on single-image subplots.
        marker_legend: show the detailed scatter legend for the method at
            position ``baseline_count``.
        update_ylim: lower the y limit to include the drawn curves.

    Returns:
        The matplotlib figure.
    """
    plot = coreutils.match_option(plot, ['fit', 'aggregate'])
    if dirname.endswith('/') or dirname.endswith('\\'):
        dirname = dirname[:-1]

    # Load data and select images for plotting
    df_all, labels = load_data(plots, dirname)
    if len(plot_images) == 0:
        plot_images = [-1] + df_all[0].image_id.unique().tolist()
    print(plot_images)

    images_x = int(np.ceil(np.sqrt(len(plot_images))))
    images_y = int(np.ceil(len(plot_images) / images_x))

    # Plot setup
    func, fit_bounds = setup_fit(metric)
    y_min, y_max, metric_label = setup_plot(metric)

    # Setup drawing styles
    styles = [['r-', 'rx'], ['b--', 'b+'], ['k:', 'k2'], ['g-', 'gx'],
              ['m-', 'gx'], ['m--', 'gx'], ['m-.', 'gx'], ['m:', 'gx']]
    avg_markers = ['', '', '', 'o', 'o', '2', '+', 'X', '^', '.']

    # To retain consistent styles across plots, adjust the lists based on the number of baseline methods
    if baseline_count < 3:
        styles = styles[(3 - baseline_count):]
        avg_markers = avg_markers[(3 - baseline_count):]

    mse_labels = {}

    # squeeze=False: always a 2-D array of axes, whatever the grid size
    fig, ax = plt.subplots(images_y, images_x, sharex=True, sharey=True,
                           squeeze=False)
    fig.set_size_inches((images_x * 6, images_y * 4))

    for position, image_id in enumerate(plot_images):

        # Place subplots by list position (image IDs need not be contiguous)
        row, col = divmod(position, images_x)
        axes = ax[row, col]

        # Select measurements for a specific image, if specified
        for dfc in df_all:
            if image_id >= 0:
                dfc['selected'] = dfc['image_id'] == image_id
            else:
                dfc['selected'] = True

        for index, dfc in enumerate(df_all):

            selected = dfc.loc[dfc['selected']]
            x = selected['bpp'].values
            y = selected[metric].values

            if len(x) == 0:
                print('WARNING No data for {} (image={})'.format(
                    labels[index], image_id))
                continue

            X = np.linspace(max(0, x.min() * 0.9), min(max_bpp, x.max() * 1.1),
                            256)

            if plot == 'fit':
                # Fit individual images to a curve, then average the curves

                if image_id >= 0:
                    images = [image_id]
                else:
                    images = selected['image_id'].unique()

                # Rows of failed fits stay NaN and are ignored by nanmean
                Y = np.full((len(images), len(X)), np.nan)
                mse_l = []

                for image_no, imid in enumerate(images):

                    image_rows = selected.loc[selected['image_id'] == imid]
                    xi = image_rows['bpp'].values
                    yi = image_rows[metric].values

                    # Allow for larger errors for lower SSIM values
                    if metric in ['ssim', 'msssim']:
                        sigma = np.abs(1 - yi).reshape((-1, ))
                    else:
                        sigma = np.ones_like(yi).reshape((-1, ))

                    try:
                        popt, _ = curve_fit(func,
                                            xi,
                                            yi,
                                            bounds=fit_bounds,
                                            sigma=sigma,
                                            maxfev=100000)
                    except RuntimeError as err:
                        # Skip the image rather than reuse a previous image's fit
                        print('ERROR', labels[index], 'image =', imid, 'bpp =',
                              xi, 'y =', yi, 'err =', err)
                        continue

                    y_est = func(xi, *popt)
                    mse = np.mean(np.power(yi - y_est, 2))
                    mse_l.append(mse)
                    if mse > 0.1:
                        print('WARNING Large MSE for {} img=#{} = {:.2f}'.
                              format(labels[index], image_no, mse))

                    Y[image_no] = func(X, *popt)

                mse_avg = np.mean(mse_l) if mse_l else np.nan
                if image_id < 0 and mse_l:
                    print(
                        'Fit summary - MSE for {} av={:.2f} max={:.2f}'.format(
                            labels[index], mse_avg, np.max(mse_l)))
                mse_labels[labels[index]] = mse_avg

                yy = np.nanmean(Y, axis=0)
                axes.plot(X,
                          yy,
                          styles[index][0],
                          label='{} ({:.3f})'.format(labels[index], mse_avg)
                          if add_legend else None)
                if update_ylim:
                    y_min = min(y_min, np.nanmin(yy))

            elif plot == 'aggregate':
                # For each quality level (QF, #channels) find the average quality level
                group_key = 'n_features' if 'n_features' in selected else 'quality'
                # Average only the plotted columns (text columns cannot be averaged)
                means = selected.groupby(group_key)[['bpp', metric]].mean()
                xa = means['bpp'].values
                ya = means[metric].values

                axes.plot(xa,
                          ya,
                          styles[index][0],
                          label=labels[index] if add_legend else None,
                          marker=avg_markers[index],
                          alpha=0.65)
                if update_ylim:
                    y_min = min(y_min, np.nanmin(ya))

            elif plot == 'none':
                pass

            else:
                raise ValueError('Unsupported plot type!')

            if draw_markers > 0:

                if 'entropy_reg' in dfc:

                    if image_id >= 0 or draw_markers >= 2:

                        # No need to draw legend if multiple DCNs are plotted
                        detailed_legend = 'full' if marker_legend and index == baseline_count else False

                        style_mapping = {}

                        if 'n_features' in dfc and len(
                                dfc['n_features'].unique()) > 1:
                            style_mapping['hue'] = 'n_features'

                        if 'entropy_reg' in dfc and len(
                                dfc['entropy_reg'].unique()) > 1:
                            style_mapping['size'] = 'entropy_reg'

                        if 'quantization' in dfc and len(
                                dfc['quantization'].unique()) > 1:
                            style_mapping['style'] = 'quantization'

                        sns.scatterplot(data=selected,
                                        x='bpp',
                                        y=metric,
                                        palette="Set2",
                                        ax=axes,
                                        legend=detailed_legend,
                                        **style_mapping)

                elif image_id >= 0:
                    # Raw measurements of the selected image
                    axes.plot(x, y, styles[index][1], alpha=0.65)

        # Setup title (based on the last method's selection)
        last = df_all[-1]
        shown = last.loc[last['selected']]
        n_images = len(shown['image_id'].unique())
        if n_images > 1:
            title = '{} for {} images ({})'.format(plot, n_images,
                                                   os.path.split(dirname)[-1])
        else:
            filenames = shown['filename'].unique()
            name = filenames[0].replace('.png', '') if len(filenames) else '?'
            # '\#' is a literal LaTeX escape for '#'
            title = r'\#{} : {}'.format(image_id, name)

        # Create the legend once: a second legend() call would rebuild it and
        # discard the label fixes below.
        legend = axes.legend(loc='lower right')
        if add_legend:
            # Fixes problems with rendering using the LaTeX backend
            for t in legend.texts:
                t.set_text(t.get_text().replace('_', '-'))

        axes.set_xlim([-0.1, max_bpp + 0.1])
        axes.set_ylim([y_min * 0.95, y_max])
        axes.set_title(title)
        if row == images_y - 1:
            axes.set_xlabel('Effective bpp')
        if col == 0:
            axes.set_ylabel(metric_label)

    return fig
示例#6
0
def plot_curve(plots,
               axes,
               dirname='./data/rgb/clic256',
               images=None,
               plot='fit',
               draw_markers=None,
               metric='ssim',
               title=None,
               add_legend=True,
               marker_legend=True,
               baseline_count=3,
               update_ylim=False):
    """Draw rate-distortion curves of several methods on ``axes``.

    Args:
        plots: specification forwarded to ``load_data``, which returns one data
            frame and one label per compared method.
        axes: matplotlib axes to draw on.
        dirname: dataset directory forwarded to ``load_data``; its last path
            component is the default title prefix.
        images: image IDs to include; None or empty means every image ID of
            the first data frame.
        plot: ``'fit'`` (average of per-image fitted curves) or ``'aggregate'``
            (mean bpp / quality per quality level), via ``match_option``.
        draw_markers: draw individual measurements; defaults to True only
            when exactly one image was requested.
        metric: name of the quality column shown on the y axis.
        title: title prefix (defaults to the last component of ``dirname``).
        add_legend: whether curves get legend labels.
        marker_legend: show the detailed scatter legend for the method at
            position ``baseline_count``.
        baseline_count: number of baseline methods; keeps line styles consistent
            across figures with a different number of baselines.
        update_ylim: lower the y limit to include the drawn curves.
    """
    # None instead of a mutable [] default; accept any iterable of IDs
    images = [] if images is None else list(images)

    # Parse input parameters
    draw_markers = draw_markers if draw_markers is not None else len(
        images) == 1
    plot = coreutils.match_option(plot, ['fit', 'aggregate'])

    df_all, labels = load_data(plots, dirname)

    if len(images) == 0:
        images = df_all[0]['image_id'].unique().tolist()

    # Plot setup
    func, fit_bounds = setup_fit(metric)
    y_min, y_max, metric_label = setup_plot(metric)

    # Select measurements for specific images, if specified
    for dfc in df_all:
        if len(images) > 0:
            dfc['selected'] = dfc['image_id'].isin(images)
        else:
            dfc['selected'] = True

    # Setup drawing styles
    styles = [['r-', 'rx'], ['b--', 'b+'], ['k:', 'k2'], ['g-', 'gx'],
              ['m-', 'gx'], ['m--', 'gx'], ['m-.', 'gx'], ['m:', 'gx']]
    avg_markers = ['', '', '', 'o', 'o', '2', '+', 'x', '^', '.']

    # To retain consistent styles across plots, adjust the lists based on the number of baseline methods
    if baseline_count < 3:
        styles = styles[(3 - baseline_count):]
        avg_markers = avg_markers[(3 - baseline_count):]

    # Iterate over defined plots and draw data accordingly
    for index, dfc in enumerate(df_all):

        # All selected measurements of this method (also used for markers)
        selected = dfc.loc[dfc['selected']]
        x = selected['bpp'].values
        y = selected[metric].values

        if len(x) == 0:
            print('WARNING No data for {}'.format(labels[index]))
            continue

        X = np.linspace(max(0, x.min() * 0.9), min(5, x.max() * 1.1), 256)

        if plot == 'fit':
            # Fit individual images to a curve, then average the curves

            # Rows of skipped / failed images stay NaN and are ignored by nanmean
            Y = np.full((len(images), len(X)), np.nan)
            mse_l = []

            for image_no, image_id in enumerate(images):

                image_rows = selected.loc[selected['image_id'] == image_id]
                xi = image_rows['bpp'].values
                yi = image_rows[metric].values

                if len(xi) == 0:
                    continue

                # Allow for larger errors for lower SSIM values
                if metric in ['ssim', 'msssim']:
                    sigma = np.abs(1 - yi).reshape((-1, ))
                else:
                    sigma = np.ones_like(yi).reshape((-1, ))

                try:
                    popt, _ = curve_fit(func,
                                        xi,
                                        yi,
                                        bounds=fit_bounds,
                                        maxfev=10000,
                                        sigma=sigma)
                except RuntimeError:
                    # Skip the image rather than reuse a previous image's fit
                    print('ERROR', labels[index], 'image =', image_id, 'bpp =',
                          xi, 'y =', yi)
                    continue

                y_est = func(xi, *popt)
                mse = np.mean(np.power(yi - y_est, 2))
                mse_l.append(mse)
                if mse > 0.5:
                    print('WARNING Large MSE for {}:{} = {:.2f}'.format(
                        labels[index], image_no, mse))

                Y[image_no] = func(X, *popt)

            if len(images) > 1 and mse_l:
                print('Fit summary - MSE for {} av={:.2f} max={:.2f}'.format(
                    labels[index], np.mean(mse_l), np.max(mse_l)))

            yy = np.nanmean(Y, axis=0)
            axes.plot(X,
                      yy,
                      styles[index][0],
                      label=labels[index] if add_legend else None)
            if update_ylim:
                y_min = min(y_min, np.nanmin(yy))

        elif plot == 'aggregate':
            # For each quality level (QF, #channels) find the average quality level
            group_key = 'n_features' if 'n_features' in selected else 'quality'
            # Average only the plotted columns (text columns cannot be averaged)
            means = selected.groupby(group_key)[['bpp', metric]].mean()
            xa = means['bpp'].values
            ya = means[metric].values

            axes.plot(xa,
                      ya,
                      styles[index][0],
                      label=labels[index] if add_legend else None,
                      marker=avg_markers[index],
                      alpha=0.65)
            if update_ylim:
                y_min = min(y_min, np.nanmin(ya))

        elif plot == 'none':
            pass

        else:
            raise ValueError('Unsupported plot type!')

        if draw_markers:

            if 'entropy_reg' in dfc:

                # No need to draw legend if multiple DCNs are plotted
                detailed_legend = 'full' if marker_legend and index == baseline_count else False

                style_mapping = {}

                if 'n_features' in dfc and len(dfc['n_features'].unique()) > 1:
                    style_mapping['hue'] = 'n_features'

                if 'entropy_reg' in dfc and len(
                        dfc['entropy_reg'].unique()) > 1:
                    style_mapping['size'] = 'entropy_reg'

                if 'quantization' in dfc and len(
                        dfc['quantization'].unique()) > 1:
                    style_mapping['style'] = 'quantization'

                sns.scatterplot(data=selected,
                                x='bpp',
                                y=metric,
                                palette="Set2",
                                ax=axes,
                                legend=detailed_legend,
                                **style_mapping)

            else:
                # Raw measurements, fading as their number grows; matplotlib
                # requires alpha <= 1.
                axes.plot(x,
                          y,
                          styles[index][1],
                          alpha=min(1.0, 10 / len(x)))

    # Title is based on the last method's selection
    last = df_all[-1]
    shown = last.loc[last['selected']]
    n_images = len(shown['image_id'].unique())
    if n_images > 1:
        description = '{} images'.format(n_images)
    elif n_images == 1:
        description = shown['filename'].unique()[0].replace('.png', '')
    else:
        description = 'no data'

    title = '{} : {}'.format(
        title if title is not None else os.path.split(dirname)[-1],
        description)

    # Fixes problems with rendering using the LaTeX backend
    if add_legend:
        for t in axes.legend().texts:
            t.set_text(t.get_text().replace('_', '-'))

    axes.set_xlim([-0.1, 3.1])
    axes.set_ylim([y_min * 0.99, y_max])
    axes.set_title(title)
    axes.set_xlabel('Effective bpp')
    axes.set_ylabel(metric_label)