def main(args):
    """Plot the mean of ``args.metric`` over epochs for every fine-tuning dataset.

    The ``all_scores`` column of the run logs is expanded into one column per
    position (used as the ``epoch`` axis), averaged per
    ``(finetune_data, init_checkpoint_index)`` and drawn as one line-plot row
    per ``finetune_data`` with one hue per ``init_checkpoint_index``.

    Args:
        args: namespace providing ``run_prefix``, ``metric`` and ``version``.
    """
    df = get_run_logs(pattern=args.run_prefix)
    metric = args.metric

    def metric_at(scores):
        """Return ``scores[metric]``, or NaN if the entry or the metric is missing."""
        # Rows with fewer entries in ``all_scores`` are padded with None/NaN
        # when expanded into a frame; ``metric in None`` would raise TypeError.
        if scores is None or isinstance(scores, float):
            return np.nan
        return scores[metric] if metric in scores else np.nan

    # expand all_scores into one column per position
    # (presumably each entry is a per-epoch mapping of metric name -> value)
    _df = pd.DataFrame(df.all_scores.tolist(), index=df.index)
    score_cols = []
    for col in _df:
        score_col = int(col)
        df[score_col] = _df[col].apply(metric_at)
        score_cols.append(score_col)

    # average all runs sharing the same dataset and initial checkpoint
    id_cols = ['finetune_data', 'init_checkpoint_index']
    df = df[score_cols + id_cols]
    df = df.groupby(id_cols).mean().reset_index()

    # long format: one row per (dataset, checkpoint, epoch)
    df = df.melt(value_vars=score_cols,
                 id_vars=id_cols,
                 var_name='epoch',
                 value_name=metric)

    # plotting
    fig = sns.relplot(x='epoch',
                      y=metric,
                      hue='init_checkpoint_index',
                      row='finetune_data',
                      kind='line',
                      data=df,
                      height=2,
                      aspect=1.3)
    # fig.axes is indexed per facet row; only the first column is styled
    for ax in fig.axes:
        ax[0].grid()
        ax[0].set_ylim((None, 1))

    save_fig(fig,
             'metrics_vs_time',
             version=args.version,
             plot_formats=['png'])
Beispiel #2
0
def main(args):
    """Plot the marginal performance increase per initial checkpoint.

    For every ``finetune_data`` group the mean of ``args.metric`` over the
    runs with ``init_checkpoint_index == 0`` is used as baseline, and each
    run gets ``metric_potential = (metric - baseline) / (1 - baseline)``.
    One line per dataset plus a dashed black average line is drawn and the
    figure is saved as ``'fig2'`` (PNG and PDF).

    Exits via ``sys.exit()`` after logging when no run logs are found.

    Args:
        args: namespace providing ``run_prefix``, ``project_name``,
            ``bucket_name``, ``metric`` and ``version``.
    """
    df = get_run_logs(pattern=args.run_prefix,
                      project_name=args.project_name,
                      bucket_name=args.bucket_name)
    # due to multiple runs that have been run with the same prefix
    if len(df) == 0:
        logger.info('No run logs found')
        sys.exit()
    # make run names nicer: keep only the last path component
    df.finetune_data = df.finetune_data.apply(lambda s: s.split('/')[-1])
    df = df[['init_checkpoint_index', 'finetune_data', args.metric]]
    df = df.dropna(subset=['init_checkpoint_index'])

    for finetune_data, grp in df.groupby('finetune_data'):
        # Select the metric column before averaging: reducing the whole frame
        # would also try to average the string column ``finetune_data``,
        # which raises TypeError on pandas >= 2.0.
        mean_base = grp.loc[grp.init_checkpoint_index == 0,
                            args.metric].mean()
        df.loc[df.finetune_data == finetune_data, 'metric_potential'] = (
            df.loc[df.finetune_data == finetune_data, args.metric] -
            mean_base) / (1 - mean_base)

    # NOTE: the x axis is labelled 'Pretraining step' below, but the plotted
    # values are the raw checkpoint indices; the conversion is disabled:
    # df['init_checkpoint_index'] = df.init_checkpoint_index.apply(lambda s: s*25000)

    # drop runs whose metric is not above 0.5
    df = df[df[args.metric] > .5]

    # convert dataset names (disabled)
    # convert_dataset_names = {'maternal_vaccine_stance_lshtm': 'Maternal Vaccine Stance (MVS)', 'covid_category': 'COVID-19 Category (CC)', 'twitter_sentiment_semeval': 'SemEval 2016 (SE)', 'vaccine_sentiment_epfl': 'Vaccine Sentiment (VS)', 'SST-2': 'Stanford Sentiment Treebank (SST-2)'}
    # df['finetune_data'] = df.finetune_data.apply(lambda s: convert_dataset_names[s])

    # plotting (golden-ratio figure size)
    height = 2.6
    width = 1.61803398875 * height
    fig, ax = plt.subplots(1, 1, figsize=(width, height))
    sns.lineplot(x='init_checkpoint_index',
                 y='metric_potential',
                 hue='finetune_data',
                 data=df,
                 ax=ax,
                 markers=True)
    # average over all datasets; reduce only the column that is plotted
    df.groupby('init_checkpoint_index')['metric_potential'].mean().plot(
        ax=ax, color='k', ls='dashed', label='Average')
    ax.grid()
    ax.set_ylabel(r'Marginal performance increase $\Delta$MP')
    legend = plt.legend(bbox_to_anchor=(1.02, 1),
                        loc=2,
                        borderaxespad=0.,
                        frameon=False,
                        title=False)
    # NOTE(review): overwrites the first legend entry; presumably that entry
    # is the hue variable name in the seaborn version used - confirm.
    legend.texts[0].set_text("Evaluation dataset")
    ax.set_xlabel('Pretraining step')

    save_fig(plt.gcf(),
             'fig2',
             version=args.version,
             plot_formats=['png', 'pdf'])
Beispiel #3
0
def main(args):
    """Draft of a confusion-matrix heatmap plot; currently disabled.

    The first statement raises ``NotImplementedError``, so none of the code
    below it is ever executed.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
    # NOTE(review): everything below is unreachable. ``labels`` and ``f_path``
    # are not assigned anywhere in this function (the ``labels`` assignment is
    # commented out), so the draft cannot run as written unless they exist at
    # module level - confirm before re-enabling.
    df = get_run_logs(pattern=args.run_name)
    # label_mapping = get_label_mapping(f_path)
    # labels = list(label_mapping.keys())
    # confusion matrix of the ``label`` column against the ``prediction`` column
    cnf_matrix = sklearn.metrics.confusion_matrix(df.label, df.prediction)
    df = pd.DataFrame(cnf_matrix, columns=labels, index=labels)
    # plotting
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    sns.heatmap(df, ax=ax, annot=True, fmt='d', annot_kws={"fontsize": 8})
    ax.set(xlabel='predicted label', ylabel='true label')
    save_fig(fig, f_path, 'confusion_matrix')
def main(args):
    """Plot an annotated heatmap of ``f1_macro`` over two run-log columns.

    The run logs are pivoted so that ``args.y`` forms the rows, ``args.x`` the
    columns and ``f1_macro`` the cell values. The heatmap is saved as
    ``'heatmap'`` (PNG only).

    Exits via ``sys.exit()`` after logging when no run logs are found.

    Args:
        args: namespace providing ``run_prefix``, ``x``, ``y`` and ``version``.
    """
    df = get_run_logs(pattern=args.run_prefix)
    if len(df) == 0:
        logger.info('No run logs found')
        sys.exit()
    # Keyword arguments are required: DataFrame.pivot is keyword-only in
    # pandas >= 2.0, so the positional form raises TypeError there.
    df_pivot = df.pivot(index=args.y, columns=args.x, values='f1_macro')
    # plotting
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    # NOTE(review): y limits are set before the heatmap is drawn; presumably a
    # workaround for clipped first/last rows - confirm it is still needed.
    ax.set_ylim(len(df_pivot) - 0.5, -0.5)
    sns.heatmap(df_pivot,
                ax=ax,
                annot=True,
                fmt='.2f',
                annot_kws={"fontsize": 8})
    save_fig(fig, 'heatmap', version=args.version, plot_formats=['png'])
Beispiel #5
0
def main(args):
    """Bar plot of ``args.metric`` per dataset for the best runs of each experiment.

    Runs are restricted to ``init_checkpoint_index`` values in
    ``[None, 0, 9]``; within every ``(finetune_data, init_checkpoint)`` group
    only the three runs with the highest ``args.metric`` are kept. Four
    ``SemEval2016_6_*`` datasets are excluded. Bars are coloured by ``exp``
    (the part of ``init_checkpoint`` before the first ``/``) and the figure is
    saved as ``'metrics_by_checkpoint'`` (PNG only).

    Exits via ``sys.exit()`` after logging when no rows remain.

    Args:
        args: namespace providing ``run_prefix``, ``bucket_name``,
            ``project_name``, ``metric`` and ``version``.
    """
    runs = get_run_logs(pattern=args.run_prefix,
                        bucket_name=args.bucket_name,
                        project_name=args.project_name)

    # give the reference models a fixed init_checkpoint label
    ct_bert_rows = runs.model_class == 'covid-twitter-bert'
    runs.loc[ct_bert_rows, 'init_checkpoint'] = 'ct-bert-v1/bla'
    base_rows = ((runs.model_class != 'covid-twitter-bert')
                 & (runs.init_checkpoint_index == 0))
    runs.loc[base_rows, 'init_checkpoint'] = 'bert-large-uncased-wwm/bla'

    runs = runs.reset_index(drop=True)
    wanted_checkpoints = [None, 0, 9]
    runs = runs[runs.init_checkpoint_index.isin(wanted_checkpoints)]

    # collect the index labels of everything beyond the top runs per group
    keep_top = 3
    surplus = []
    group_keys = ['finetune_data', 'init_checkpoint']
    for _, members in runs.groupby(group_keys):
        ranked = members[args.metric].sort_values(ascending=False)
        surplus.extend(ranked[keep_top:].index.tolist())
    runs = runs.drop(index=surplus)

    # keep only the last path component of the dataset name
    runs.finetune_data = runs.finetune_data.apply(
        lambda path: path.split('/')[-1])
    excluded_datasets = [
        'SemEval2016_6_climate_change_is_a_real_concern',
        'SemEval2016_6_feminist_movement',
        'SemEval2016_6_hillary_clinton',
        'SemEval2016_6_legalization_of_abortion',
    ]
    runs = runs[~runs.finetune_data.isin(excluded_datasets)]
    runs['exp'] = runs.init_checkpoint.apply(
        lambda checkpoint: checkpoint.split('/')[0])

    if not len(runs):
        logger.info('No run logs found')
        sys.exit()

    # plotting
    grid = sns.catplot(data=runs,
                       x='finetune_data',
                       y=args.metric,
                       hue='exp',
                       kind='bar',
                       height=4,
                       aspect=1.5,
                       ci='sd')
    grid.set(ylim=(0.5, 1))
    plt.gca().set_title(args.run_prefix)

    save_fig(grid,
             'metrics_by_checkpoint',
             version=args.version,
             plot_formats=['png'])
def main(args):
    """Plot the median of ``args.metric`` against the initial checkpoint index.

    Four ``SemEval2016_6_*`` datasets are excluded. Within every
    ``(finetune_data, init_checkpoint_index, exp)`` group the metric of each
    run is replaced by the group median before one line per ``exp`` (the part
    of ``init_checkpoint`` before the first ``/``) is drawn. Two dashed
    horizontal lines mark the mean metric of the ``covid-twitter-bert`` runs
    (black) and of the ``bert_large_uncased_wwm`` runs at checkpoint index 0
    (red). The figure is saved as ``'fig2'`` (PNG and PDF).

    Exits via ``sys.exit()`` after logging when no run logs are found.

    Args:
        args: namespace providing ``run_prefix``, ``project_name``,
            ``bucket_name``, ``metric`` and ``version``.
    """
    runs = get_run_logs(pattern=args.run_prefix,
                        project_name=args.project_name,
                        bucket_name=args.bucket_name)
    # several runs may share the same prefix; nothing to do without any
    if not len(runs):
        logger.info('No run logs found')
        sys.exit()
    metric = args.metric
    runs = runs.reset_index(drop=True)

    # keep only the last path component of the dataset name
    runs.finetune_data = runs.finetune_data.apply(
        lambda path: path.split('/')[-1])
    excluded_datasets = [
        'SemEval2016_6_climate_change_is_a_real_concern',
        'SemEval2016_6_feminist_movement',
        'SemEval2016_6_hillary_clinton',
        'SemEval2016_6_legalization_of_abortion',
    ]
    runs = runs[~runs.finetune_data.isin(excluded_datasets)]
    kept_columns = ['init_checkpoint_index', metric, 'init_checkpoint',
                    'model_class', 'finetune_data']
    runs = runs[kept_columns]

    # reference levels, computed before rows without init_checkpoint are dropped
    ct_bert_rows = runs.model_class == 'covid-twitter-bert'
    mean_ct_bert_v1 = runs.loc[ct_bert_rows, metric].mean()
    bert_large_rows = ((runs.model_class == 'bert_large_uncased_wwm')
                       & (runs.init_checkpoint_index == 0))
    mean_bert_large = runs.loc[bert_large_rows, metric].mean()

    runs = runs.dropna(subset=['init_checkpoint'])
    runs['exp'] = runs.init_checkpoint.apply(
        lambda checkpoint: checkpoint.split('/')[0])

    # plotting (golden-ratio figure size)
    height = 2.6
    width = 1.61803398875 * height
    _, ax = plt.subplots(1, 1, figsize=(width, height))

    # overwrite every run's metric with the median of its group
    group_keys = ['finetune_data', 'init_checkpoint_index', 'exp']
    for (dataset, checkpoint_index, exp_name), members in runs.groupby(group_keys):
        in_group = ((runs.finetune_data == dataset)
                    & (runs.init_checkpoint_index == checkpoint_index)
                    & (runs.exp == exp_name))
        runs.loc[in_group, metric] = members[metric].median()

    sns.lineplot(data=runs,
                 x='init_checkpoint_index',
                 y=metric,
                 hue='exp',
                 ax=ax)
    ax.axhline(y=mean_ct_bert_v1, ls='--', c='k')
    ax.axhline(y=mean_bert_large, ls='--', c='red')
    ax.grid()
    ax.set_ylabel(f'{args.metric}')
    plt.legend(bbox_to_anchor=(1.02, 1),
               loc=2,
               borderaxespad=0.,
               frameon=False,
               title=False)
    ax.set_xlabel('Pretraining step')

    save_fig(plt.gcf(),
             'fig2',
             version=args.version,
             plot_formats=['png', 'pdf'])
Beispiel #7
0
def main(args):
    """Bar plot of ``args.metric`` per fine-tuning dataset and initial checkpoint.

    Datasets (``finetune_data``) are placed on the x axis with tick labels
    rotated by 45 degrees, bars are coloured by ``init_checkpoint_index`` and
    the y axis is limited to ``(0.5, 1)``. The plot is titled with
    ``args.run_prefix`` and saved as ``'metrics_by_checkpoint'`` (PNG only).

    Exits via ``sys.exit()`` after logging when no run logs are found.

    Args:
        args: namespace providing ``run_prefix``, ``metric`` and ``version``.
    """
    run_logs = get_run_logs(pattern=args.run_prefix)
    if not len(run_logs):
        logger.info('No run logs found')
        sys.exit()

    # plotting
    bar_options = dict(x='finetune_data',
                       y=args.metric,
                       hue='init_checkpoint_index',
                       kind='bar',
                       height=4,
                       aspect=1.5)
    grid = sns.catplot(data=run_logs, **bar_options)
    grid.set_xticklabels(rotation=45)
    grid.set(ylim=(0.5, 1))
    plt.gca().set_title(args.run_prefix)

    save_fig(grid,
             'metrics_by_checkpoint',
             version=args.version,
             plot_formats=['png'])