Example #1
0
def interrater_iccs(ratings,
                    rater_col_name='rater',
                    index_label='onset_ms',
                    column_labels=None):
    """
    This function computes the interrater ICCs using the Pingouin library. By default it computes the absolute agreement
    between raters assuming a random sample of raters at each target (each rating at each instance).
    Read more on ICC2 at https://pingouin-stats.org/generated/pingouin.intraclass_corr.html#pingouin.intraclass_corr

    Parameters
    ----------
    index_label: str
        The label denoting each measurement. This must be consistent across all raters. Default is "onset_ms".
    ratings: DataFrame
        DataFrame with the ratings information stored in a long format. Note: if ``index_label`` is not one of its
        columns, a new column is added to this DataFrame in place.
    rater_col_name: str
        The name of the column containing rater information. Default is "rater"
    column_labels: list
        The list of variables to compute inter-rater ICCs for. Default is None, which means it will compute ICCs for
        every column in the DataFrame not equal to the rater_col_name or the index_label.

    Returns
    -------
    icc_df: DataFrame
        The dataframe object containing instance-level and overall intraclass correlation values. Variables whose
        ICC is NaN are left with a NaN consistency label.

    """
    if not column_labels:
        column_labels = ratings.columns.to_list()
        column_labels.remove(rater_col_name)
    else:
        # Work on a copy so the caller's list is not mutated below.
        column_labels = list(column_labels)

    if index_label in column_labels:
        column_labels.remove(index_label)
    else:
        # No measurement-index column present: derive one from the index.
        ratings[index_label] = ratings.index
        ratings.index.name = 'index'

    icc_df = pd.DataFrame(
        columns=['instance_level_ICC', 'instance_level_consistency'])

    for label in column_labels:
        icc = pg.intraclass_corr(data=ratings,
                                 targets=index_label,
                                 raters=rater_col_name,
                                 ratings=label,
                                 nan_policy='omit')
        # Row 1 of pingouin's output table is ICC2 (absolute agreement,
        # two-way random effects) -- hoist it instead of re-indexing.
        icc2 = icc.loc[1, 'ICC']
        icc_df.loc[label, 'instance_level_ICC'] = icc2

        # Qualitative bands; an NaN ICC matches no branch and stays unset.
        if icc2 < 0.50:
            icc_df.loc[label, 'instance_level_consistency'] = 'poor'
        elif icc2 < 0.75:
            icc_df.loc[label, 'instance_level_consistency'] = 'moderate'
        elif icc2 < 0.90:
            icc_df.loc[label, 'instance_level_consistency'] = 'good'
        elif icc2 >= 0.90:
            icc_df.loc[label, 'instance_level_consistency'] = 'excellent'

    return icc_df
Example #2
0
def icc(ds, i, attr):
    """Return the first ICC estimate for the subject-by-task regression data."""
    result = intraclass_corr(regression(ds, i, attr),
                             targets='subject',
                             raters='task',
                             ratings='y')
    return result['ICC'][0]
Example #3
0
    def remove_metabolites_icc(self,
                               cutoff: float = 0.65):
        '''
        Compute the intra-class correlation across duplicate/triplicate
        samples for every metabolite and drop metabolites whose ICC falls
        below the given cutoff.

        Parameters
        ----------
        cutoff: float
            ICC metabolite removal cutoff.

        Returns
        ----------
        data: pd.Dataframe
            Dataframe with metabolites removed due to low ICC.
        pool: pd.Dataframe
            Dataframe with metabolites removed due to low ICC.
        '''
        print('-----Removing metabolites with ICC values lower than ' +
              str(cutoff) + '-----')
        for idx in range(len(self.data)):
            frame = self.data[idx]
            dup_ids = frame.index[frame.index.duplicated()].unique()
            dup_dat = frame.loc[dup_ids]

            # Label each replicate 1..k within its sample ID so pingouin
            # can treat replicates as "raters".
            rater_labels = []
            for sample in dup_ids:
                n_reps = len(dup_dat.loc[sample])
                rater_labels.extend(range(1, n_reps + 1))

            icc_estimates = []
            for metabolite in dup_dat.columns:
                long_df = pd.DataFrame({
                    'raters': rater_labels,
                    'value': list(dup_dat[metabolite]),
                    'targets': list(dup_dat.index),
                })
                result = pg.intraclass_corr(long_df,
                                            targets='targets',
                                            raters='raters',
                                            ratings='value',
                                            nan_policy='omit')
                # Third row / third column of pingouin's table (ICC3 value).
                icc_estimates.append(result.iloc[2, 2])

            icc_values = pd.DataFrame(index=dup_dat.columns)
            icc_values['ICC'] = icc_estimates
            remove_met_table = icc_values[icc_values['ICC'] < cutoff]

            # Report and drop the low-ICC metabolites for this dataset.
            self._print_metabolites_removed(remove_met_table, idx)
            self._remove_metabolites(remove_met_table, idx)
Example #4
0
def calculate_icc(df):
    """Reshape per-judge columns into long format and compute the damage ICC."""
    records = [(row['ID'], judge,
                row[f'BUILDING_CLASS_{judge}'], row[f'DAMAGE_{judge}'])
               for _, row in df.iterrows()
               for judge in JUDGES]

    icc_df = pd.DataFrame(records, columns=['ID', 'judge', 'class', 'damage'])

    # Sanity-check the reshaped table before computing the ICC.
    assert (icc_df['damage'] <= 1).all()
    assert (icc_df['damage'] >= 0).all()
    assert (icc_df['class'].isin(['G', 'M', 'B'])).all()

    return pg.intraclass_corr(data=icc_df,
                              targets='ID',
                              raters='judge',
                              ratings='damage',
                              nan_policy='raise').round(3)
def icc_by_metric(df, select_metric):
    """Compute a test-retest ICC for one metric across athletes and trials."""
    # Pivot to athlete x trial, keep only trials observed for everyone.
    wide = (df[['athlete', 'id', select_metric]]
            .pivot_table(index=['athlete'],
                         values=select_metric,
                         columns='id')
            .dropna(axis=1))
    # Back to long format for pingouin.
    long_form = wide.unstack().reset_index(name='value')
    long_form.rename(columns={'value': select_metric}, inplace=True)
    # Third row of pingouin's output table (ICC3).
    icc_row = pg.intraclass_corr(data=long_form,
                                 targets='athlete',
                                 raters='id',
                                 ratings=select_metric).iloc[2]
    return {
        'Metric': select_metric,
        'Type': icc_row['Type'],
        'ICC': icc_row['ICC'],
        'CI95': icc_row['CI95%'],
    }
Example #6
0
    # sns.lineplot(x='EVLP ID', y=df[y].mean(), data=df, ax=ax)

# Finalize and save the biochemistry %CV figure assembled above.
fig_biochem.tight_layout()
fig_biochem.savefig('Biochem_%CV.png', dpi=200)


### Correlations of dPO2 with Other Important EVLP Parameters ###


# Load the per-biomarker dPO2 correlation sheet from the local workbook.
df = pd.read_excel(r'C:\Users\chaob\Documents\Perfusate Heterogeneity %CV Plotting.xlsx', 'dPO2 Correlations')
# markers = ['o', '^', 's', 'P', '*', 'D', 'v', 'X', '>']

# One swarm of correlation values per EVLP case, colored by biomarker.
fig = plt.figure(figsize=(12, 6))
sns.swarmplot(x='EVLP ID', y='Correlation with dPO2', data=df, hue='Biomarker', size=10)
plt.legend(bbox_to_anchor=(1.01, 1),borderaxespad=0)
fig.tight_layout()
fig.savefig('dPO2_Correlation_Scatterplot.png', dpi=200)


### Intra-class Correlations ###


import pingouin as pg

# One sheet per donor case; compute location-agreement ICCs per cytokine.
sheet_names = ['#389', '#393', '#398', '#406', '#422', '#426', '#472', '#474']

for s in sheet_names:
    df = pd.read_excel(r'C:\Users\chaob\Documents\Perfusate Protein Data by Donor.xlsx', s)
    # NOTE(review): assumes each sheet is already long-format with these
    # three columns ('Target Cytokine', 'Location', 'Expression') -- confirm
    # against the workbook.
    icc = pg.intraclass_corr(data=df, targets='Target Cytokine', raters='Location',
                             ratings='Expression').round(3)
    print(s, '\n', icc)
Example #7
0
def benchmark_reproducibility(comb, modality, alg, sub_dict_clean, disc,
                              int_consist, final_missingness_summary):
    """
    Summarize test-retest reproducibility for one hyperparameter combination.

    Builds a single-row DataFrame with the grid combination, modality and
    embedding, then optionally adds internal consistency (Cronbach's alpha),
    ICC per metric, and discriminability.

    NOTE(review): relies on module-level names not defined in this block
    (``mets``, ``ids``, ``icc``, ``gen_sub_vec``, ``discr_stat``,
    ``flatten``, imputers/scaler) -- confirm they are in scope at call time.

    Parameters
    ----------
    comb : tuple
        Hyperparameter combination to unpack (func: 6-tuple with smoothing,
        possibly 5-tuple with smoothing missing; otherwise dwi 6-tuple).
    modality : str
        'func' or a diffusion modality.
    alg : str
        Embedding/algorithm name; 'topology' enables alpha/ICC branches.
    sub_dict_clean : dict
        Nested {ID: {ses: {modality: {alg: {comb_tuple: data}}}}} store.
    disc : bool
        Whether to compute discriminability.
    int_consist : bool
        Whether to compute Cronbach's alpha across visits.
    final_missingness_summary : DataFrame
        Missingness table filtered per grid combination (currently unused
        beyond the local ``missing_sub_seshes`` lookup).

    Returns
    -------
    pandas.Series or pandas.DataFrame
        Empty Series on failure/empty data; otherwise the one-row summary.
    """
    df_summary = pd.DataFrame(
        columns=['grid', 'modality', 'embedding', 'discriminability'])
    print(comb)
    df_summary.at[0, "modality"] = modality
    df_summary.at[0, "embedding"] = alg

    if modality == 'func':
        try:
            extract, hpass, model, res, atlas, smooth = comb
        except BaseException:
            # 5-tuple grids omit smoothing; default it to '0'.
            print(f"Missing {comb}...")
            extract, hpass, model, res, atlas = comb
            smooth = '0'
        comb_tuple = (atlas, extract, hpass, model, res, smooth)
    else:
        directget, minlength, model, res, atlas, tol = comb
        comb_tuple = (atlas, directget, minlength, model, res, tol)

    df_summary.at[0, "grid"] = comb_tuple

    missing_sub_seshes = \
        final_missingness_summary.loc[(final_missingness_summary['alg']==alg)
                                      & (final_missingness_summary[
                                             'modality']==modality) &
                                      (final_missingness_summary[
                                           'grid']==comb_tuple)
                                      ].drop_duplicates(subset='id')

    # int_consist: Cronbach's alpha of each metric across visits.
    if int_consist is True and alg == 'topology':
        try:
            import pingouin as pg
        except ImportError:
            print("Cannot evaluate test-retest int_consist. pingouin"
                  " must be installed!")
        for met in mets:
            id_dict = {}
            for ID in ids:
                id_dict[ID] = {}
                for ses in sub_dict_clean[ID].keys():
                    if comb_tuple in sub_dict_clean[ID][ses][modality][
                            alg].keys():
                        id_dict[ID][ses] = \
                        sub_dict_clean[ID][ses][modality][alg][comb_tuple][
                            mets.index(met)][0]
            df_wide = pd.DataFrame(id_dict).T
            if df_wide.empty:
                del df_wide
                return pd.Series()
            df_wide = df_wide.add_prefix(f"{met}_visit_")
            df_wide.replace(0, np.nan, inplace=True)
            try:
                c_alpha = pg.cronbach_alpha(data=df_wide)
            except BaseException:
                print('FAILED...')
                print(df_wide)
                del df_wide
                return pd.Series()
            df_summary.at[0, f"cronbach_alpha_{met}"] = c_alpha[0]
            del df_wide

    # icc: test-retest ICC of each metric across sessions 1-10.
    # NOTE(review): ``icc`` is not a parameter here -- it resolves to a
    # module-level name; verify it is the intended boolean flag.
    if icc is True and alg == 'topology':
        try:
            import pingouin as pg
        except ImportError:
            print("Cannot evaluate ICC. pingouin" " must be installed!")
        for met in mets:
            id_dict = {}
            dfs = []
            for ses in [str(i) for i in range(1, 11)]:
                for ID in ids:
                    id_dict[ID] = {}
                    if comb_tuple in sub_dict_clean[ID][ses][modality][
                            alg].keys():
                        id_dict[ID][ses] = \
                        sub_dict_clean[ID][ses][modality][alg][comb_tuple][
                            mets.index(met)][0]
                    df = pd.DataFrame(id_dict).T
                    if df.empty:
                        # BUGFIX: the original deleted ``df_long`` here,
                        # which does not exist yet (NameError); delete the
                        # empty per-ID frame instead.
                        del df
                        return pd.Series()
                    df.columns.values[0] = f"{met}"
                    df.replace(0, np.nan, inplace=True)
                    df['id'] = df.index
                    df['ses'] = ses
                    df.reset_index(drop=True, inplace=True)
                    dfs.append(df)
            df_long = pd.concat(dfs, names=[
                'id', 'ses', f"{met}"
            ]).drop(columns=[str(i) for i in range(1, 10)])
            try:
                c_icc = pg.intraclass_corr(data=df_long,
                                           targets='id',
                                           raters='ses',
                                           ratings=f"{met}",
                                           nan_policy='omit').round(3)
                c_icc = c_icc.set_index("Type")
                # Mean of the average-measure ICCs (ICC1k/ICC2k/ICC3k).
                df_summary.at[0, f"icc_{met}"] = pd.DataFrame(
                    c_icc.drop(
                        index=['ICC1', 'ICC2', 'ICC3'])['ICC']).mean()[0]
            except BaseException:
                print('FAILED...')
                print(df_long)
                del df_long
                return pd.Series()
            del df_long

    # discriminability across subjects.
    if disc is True:
        vect_all = []
        for ID in ids:
            try:
                out = gen_sub_vec(sub_dict_clean, ID, modality, alg,
                                  comb_tuple)
            except BaseException:
                print(f"{ID} {modality} {alg} {comb_tuple} failed...")
                continue
            # print(out)
            vect_all.append(out)
        vect_all = [
            i for i in vect_all if i is not None and not np.isnan(i).all()
        ]
        if len(vect_all) > 0:
            if alg == 'topology':
                X_top = np.swapaxes(np.hstack(vect_all), 0, 1)
                # Drop feature columns that are NaN for >=50% of subjects.
                bad_ixs = [i[1] for i in np.argwhere(np.isnan(X_top))]
                for m in set(bad_ixs):
                    if (X_top.shape[0] - bad_ixs.count(m)) / \
                        X_top.shape[0] < 0.50:
                        X_top = np.delete(X_top, m, axis=1)
            else:
                if len(vect_all) > 0:
                    X_top = np.array(pd.concat(vect_all, axis=0))
                else:
                    return pd.Series()
            shapes = []
            for ix, i in enumerate(vect_all):
                shapes.append(i.shape[0] * [list(ids)[ix]])
            Y = np.array(list(flatten(shapes)))
            if alg == 'topology':
                imp = IterativeImputer(max_iter=50, random_state=42)
            else:
                imp = SimpleImputer()
            X_top = imp.fit_transform(X_top)
            scaler = StandardScaler()
            X_top = scaler.fit_transform(X_top)
            try:
                discr_stat_val, rdf = discr_stat(X_top, Y)
            except BaseException:
                return pd.Series()
            df_summary.at[0, "discriminability"] = discr_stat_val
            print(discr_stat_val)
            print("\n")
            # print(rdf)
            del discr_stat_val
        del vect_all
    return df_summary
Example #8
0
# Flatten predicted probabilities into one column; gt_lenke_prob_data_col is
# presumably built the same way just above this excerpt -- confirm upstream.
pred_lenke_prob_data_col = pred_lenke_prob_data.reshape(-1)
icc_ratings = np.concatenate(
    (gt_lenke_prob_data_col, pred_lenke_prob_data_col), axis=0)

# Build target/rater labels: the first 768 ratings are from 'gt', the second
# 768 are 'pred' ratings of the same 768 targets (768 * 2 = 1536 rows).
icc_targets = []
icc_raters = []
for k in range(1536):
    if k < 768:
        icc_targets.append(str(k))
        icc_raters.append('gt')
    else:
        icc_targets.append(str(k - 768))
        icc_raters.append('pred')

# Long-format table expected by pingouin: one row per (target, rater) rating.
icc_df = pd.DataFrame({
    'Targets': icc_targets,
    'Raters': icc_raters,
    'Ratings': icc_ratings
})

icc = pg.intraclass_corr(data=icc_df,
                         targets='Targets',
                         raters='Raters',
                         ratings='Ratings')
print(icc.to_string())

###### curve type classification from probabilities
# Hard class = argmax over probability columns (+1 for 1-based class labels);
# agreement between ground truth and prediction scored with Cohen's kappa.
gt_prob_class = np.argmax(gt_lenke_prob_data, axis=1) + 1
pred_prob_class = np.argmax(pred_lenke_prob_data, axis=1) + 1
kap_prob_class = cohen_kappa_score(gt_prob_class, pred_prob_class)
Example #9
0
# Long-format table comparing timing-gate vs radar max speed per athlete.
ICCRadar = pd.DataFrame()
ICCRadar['Athlete'] = NewData['Athlete']
ICCRadar['TimingGate'] = NewData['TimingGate_Max']
ICCRadar['Radar'] = NewData['Radar_Max']
ICCRadar = ICCRadar.melt(id_vars=['Athlete'])
ICCRadar.sort_values('Athlete', inplace=True, ascending=True)

# Same comparison for timing gate vs Optojump.
ICCOpto = pd.DataFrame()
ICCOpto['Athlete'] = NewData['Athlete']
ICCOpto['TimingGate'] = NewData['TimingGate_Max']
ICCOpto['Optojump'] = NewData['Optojump_Max']
ICCOpto = ICCOpto.melt(id_vars=['Athlete'])
ICCOpto.sort_values('Athlete', inplace=True, ascending=True)

# NOTE(review): 'variable' (the measurement device) is passed as targets and
# 'Athlete' as raters -- this looks swapped relative to the usual ICC setup
# (athletes as targets, devices as raters); confirm this is intentional.
iccRadar = pg.intraclass_corr(data=ICCRadar,
                              targets='variable',
                              raters='Athlete',
                              ratings='value')
iccOpto = pg.intraclass_corr(data=ICCOpto,
                             targets='variable',
                             raters='Athlete',
                             ratings='value')
# Round and drop the verbose description column before tabulating.
iccRadar = iccRadar.round(decimals=2)
iccOpto = iccOpto.round(decimals=2)
iccRadar = iccRadar.drop("Description", axis=1)
iccOpto = iccOpto.drop("Description", axis=1)

# Render both ICC tables as plotly figures.
table1 = ff.create_table(iccRadar)
table2 = ff.create_table(iccOpto)
fig = make_subplots(rows=2,
                    cols=1,
                    print_grid=True,
Example #10
0
def benchmark_reproducibility(base_dir, comb, modality, alg, par_dict, disc,
                              final_missingness_summary, icc_tmps_dir, icc,
                              mets, ids, template):
    """
    Summarize test-retest reproducibility for one hyperparameter combination.

    Builds a one-row summary DataFrame (grid, modality, embedding), then
    optionally adds per-metric ICCs (topology metrics or embedding features)
    and a discriminability statistic, and for embeddings also writes the
    summary CSV to ``icc_tmps_dir``.

    NOTE(review): relies on names not defined in this function (``pd``,
    ``np``, ``os``, ``discr_stat``, ``flatten``, ``IterativeImputer``,
    ``SimpleImputer``, ``StandardScaler``) -- assumed module-level imports.
    In the non-topology branch, ``atlas`` and ``res`` come from unpacking
    ``comb`` above.

    Parameters
    ----------
    base_dir : str
        Root output directory containing per-subject pynets folders.
    comb : tuple
        Hyperparameter combination (func: extract/hpass/model/res/atlas[/smooth];
        otherwise directget/minlength/model/res/atlas/tol).
    modality : str
        'func' or a diffusion modality.
    alg : str
        'topology' or an embedding name; selects the ICC branch.
    par_dict : dict
        Nested {ID: {ses: {modality: {alg: {comb_tuple: ...}}}}} store.
    disc : bool
        Whether to compute discriminability.
    final_missingness_summary : DataFrame
        Unused here (the lookup is commented out below).
    icc_tmps_dir : str
        Directory to which embedding ICC summaries are written.
    icc : bool
        Whether to compute ICCs.
    mets : list
        Topology metric names (indexes into the stored metric arrays).
    ids : iterable
        Subject IDs.
    template : str
        Template name passed to ``parse_closest_ixs``.

    Returns
    -------
    pandas.DataFrame
        One-row summary (possibly partial if a step failed).
    """
    import gc
    import json
    import glob
    from pathlib import Path
    import ast
    import matplotlib
    from pynets.stats.utils import gen_sub_vec
    matplotlib.use('Agg')

    df_summary = pd.DataFrame(
        columns=['grid', 'modality', 'embedding', 'discriminability'])
    print(comb)
    df_summary.at[0, "modality"] = modality
    df_summary.at[0, "embedding"] = alg

    # Unpack the grid combination; func grids may omit smoothing (default '0').
    if modality == 'func':
        try:
            extract, hpass, model, res, atlas, smooth = comb
        except BaseException:
            print(f"Missing {comb}...")
            extract, hpass, model, res, atlas = comb
            smooth = '0'
        # comb_tuple = (atlas, extract, hpass, model, res, smooth)
        comb_tuple = comb
    else:
        directget, minlength, model, res, atlas, tol = comb
        # comb_tuple = (atlas, directget, minlength, model, res, tol)
        comb_tuple = comb

    df_summary.at[0, "grid"] = comb_tuple

    # missing_sub_seshes = \
    #     final_missingness_summary.loc[(final_missingness_summary['alg']==alg)
    #                                   & (final_missingness_summary[
    #                                          'modality']==modality) &
    #                                   (final_missingness_summary[
    #                                        'grid']==comb_tuple)
    #                                   ].drop_duplicates(subset='id')

    # icc: test-retest ICC of each topology metric across sessions 1-10.
    if icc is True and alg == 'topology':
        try:
            import pingouin as pg
        except ImportError:
            print("Cannot evaluate ICC. pingouin" " must be installed!")
        for met in mets:
            id_dict = {}
            dfs = []
            for ses in [str(i) for i in range(1, 11)]:
                for ID in ids:
                    id_dict[ID] = {}
                    if comb_tuple in par_dict[ID][str(
                            ses)][modality][alg].keys():
                        id_dict[ID][str(ses)] = \
                            par_dict[ID][str(ses)][modality][alg][comb_tuple][
                                mets.index(met)][0]
                    df = pd.DataFrame(id_dict).T
                    if df.empty:
                        del df
                        return df_summary
                    df.columns.values[0] = f"{met}"
                    df.replace(0, np.nan, inplace=True)
                    df['id'] = df.index
                    df['ses'] = ses
                    df.reset_index(drop=True, inplace=True)
                    dfs.append(df)
            df_long = pd.concat(dfs, names=[
                'id', 'ses', f"{met}"
            ]).drop(columns=[str(i) for i in range(1, 10)])
            # Session-10 values can land in a leftover '10' column; fold
            # them back into the metric column before computing the ICC.
            if '10' in df_long.columns:
                df_long[f"{met}"] = df_long[f"{met}"].fillna(df_long['10'])
                df_long = df_long.drop(columns='10')
            try:
                c_icc = pg.intraclass_corr(data=df_long,
                                           targets='id',
                                           raters='ses',
                                           ratings=f"{met}",
                                           nan_policy='omit').round(3)
                c_icc = c_icc.set_index("Type")
                # Keep only ICC3k (average measures, fixed raters).
                c_icc3 = c_icc.drop(
                    index=['ICC1', 'ICC2', 'ICC1k', 'ICC2k', 'ICC3'])
                df_summary.at[0, f"icc_{met}"] = c_icc3['ICC'].values[0]
                df_summary.at[0, f"icc_{met}_CI95%_L"] = \
                    c_icc3['CI95%'].values[0][0]
                df_summary.at[0, f"icc_{met}_CI95%_U"] = \
                    c_icc3['CI95%'].values[0][1]
            except BaseException:
                print('FAILED...')
                print(df_long)
                del df_long
                return df_summary
            del df_long
    elif icc is True and alg != 'topology':
        # Embedding branch: ICC per embedding feature, with node
        # coordinates/labels recovered from the per-session node files.
        import re
        from pynets.stats.utils import parse_closest_ixs
        try:
            import pingouin as pg
        except ImportError:
            print("Cannot evaluate ICC. pingouin" " must be installed!")
        dfs = []
        coords_frames = []
        labels_frames = []
        for ses in [str(i) for i in range(1, 11)]:
            for ID in ids:
                if ses in par_dict[ID].keys():
                    if comb_tuple in par_dict[ID][str(
                            ses)][modality][alg].keys():
                        if 'data' in par_dict[ID][str(
                                ses)][modality][alg][comb_tuple].keys():
                            if par_dict[ID][str(ses)][modality][alg][
                                    comb_tuple]['data'] is not None:
                                # 'data' is either a path to load or an
                                # in-memory array.
                                if isinstance(
                                        par_dict[ID][str(ses)][modality][alg]
                                    [comb_tuple]['data'], str):
                                    data_path = par_dict[ID][str(ses)][
                                        modality][alg][comb_tuple]['data']
                                    parent_dir = Path(
                                        os.path.dirname(
                                            par_dict[ID][str(ses)][modality]
                                            [alg][comb_tuple]['data'])).parent
                                    if os.path.isfile(data_path):
                                        try:
                                            if data_path.endswith('.npy'):
                                                emb_data = np.load(data_path)
                                            elif data_path.endswith('.csv'):
                                                emb_data = np.array(
                                                    pd.read_csv(
                                                        data_path)).reshape(
                                                            -1, 1)
                                            else:
                                                emb_data = np.nan
                                            node_files = glob.glob(
                                                f"{parent_dir}/nodes/*.json")
                                        except:
                                            print(f"Failed to load data from "
                                                  f"{data_path}..")
                                            continue
                                    else:
                                        continue
                                else:
                                    node_files = glob.glob(
                                        f"{base_dir}/pynets/sub-{ID}/ses-"
                                        f"{ses}/{modality}/rsn-"
                                        f"{atlas}_res-{res}/nodes/*.json")
                                    emb_data = par_dict[ID][str(ses)][
                                        modality][alg][comb_tuple]['data']

                                emb_shape = emb_data.shape[0]

                                if len(node_files) > 0:
                                    # Match node metadata to embedding rows;
                                    # retry without the template if counts
                                    # disagree.
                                    ixs, node_dict = parse_closest_ixs(
                                        node_files,
                                        emb_shape,
                                        template=template)
                                    if len(ixs) != emb_shape:
                                        ixs, node_dict = parse_closest_ixs(
                                            node_files, emb_shape)
                                    if isinstance(node_dict, dict):
                                        coords = [
                                            node_dict[i]['coord']
                                            for i in node_dict.keys()
                                        ]
                                        labels = [
                                            node_dict[i]['label'][
                                                'BrainnetomeAtlas'
                                                'Fan2016']
                                            for i in node_dict.keys()
                                        ]
                                    else:
                                        print(f"Failed to parse coords/"
                                              f"labels from {node_files}. "
                                              f"Skipping...")
                                        continue
                                    df_coords = pd.DataFrame(
                                        [str(tuple(x)) for x in coords]).T
                                    df_coords.columns = [
                                        f"rsn-{atlas}_res-"
                                        f"{res}_{i}" for i in ixs
                                    ]
                                    # labels = [
                                    #     list(i['label'])[7] for i
                                    #     in
                                    #     node_dict]
                                    df_labels = pd.DataFrame(labels).T
                                    df_labels.columns = [
                                        f"rsn-{atlas}_res-"
                                        f"{res}_{i}" for i in ixs
                                    ]
                                    coords_frames.append(df_coords)
                                    labels_frames.append(df_labels)
                                else:
                                    print(f"No node files detected for "
                                          f"{comb_tuple} and {ID}-{ses}...")
                                    ixs = [
                                        i for i in par_dict[ID][str(ses)]
                                        [modality][alg][comb_tuple]['index']
                                        if i is not None
                                    ]
                                    coords_frames.append(pd.Series())
                                    labels_frames.append(pd.Series())

                                if len(ixs) == emb_shape:
                                    # One column per embedding feature,
                                    # tagged with atlas/resolution/index.
                                    df_pref = pd.DataFrame(emb_data.T,
                                                           columns=[
                                                               f"{alg}_{i}_rsn"
                                                               f"-{atlas}_res-"
                                                               f"{res}"
                                                               for i in ixs
                                                           ])
                                    df_pref['id'] = ID
                                    df_pref['ses'] = ses
                                    df_pref.replace(0, np.nan, inplace=True)
                                    df_pref.reset_index(drop=True,
                                                        inplace=True)
                                    dfs.append(df_pref)
                                else:
                                    print(
                                        f"Embedding shape {emb_shape} for "
                                        f"{comb_tuple} does not correspond to "
                                        f"{len(ixs)} indices found for "
                                        f"{ID}-{ses}. Skipping...")
                                    continue
                        else:
                            print(
                                f"data not found in {comb_tuple}. Skipping...")
                            continue
                else:
                    continue

        if len(dfs) == 0:
            return df_summary

        if len(coords_frames) > 0 and len(labels_frames) > 0:
            coords_frames_icc = pd.concat(coords_frames)
            labels_frames_icc = pd.concat(labels_frames)
            nodes = True
        else:
            nodes = False

        # Keep features observed for >=75% of rows; drop all-NaN rows.
        df_long = pd.concat(dfs, axis=0)
        df_long = df_long.dropna(axis='columns', thresh=0.75 * len(df_long))
        df_long = df_long.dropna(axis='rows', how='all')

        dict_sum = df_summary.drop(
            columns=['grid', 'modality', 'embedding', 'discriminability'
                     ]).to_dict()

        for lp in [
                i for i in df_long.columns if 'ses' not in i and 'id' not in i
        ]:
            ix = int(lp.split(f"{alg}_")[1].split('_')[0])
            rsn = lp.split(f"{alg}_{ix}_")[1]
            df_long_clean = df_long[['id', 'ses', lp]]
            # df_long_clean = df_long[['id', 'ses', lp]].loc[(df_long[['id',
            # 'ses', lp]]['id'].duplicated() == True) & (df_long[['id', 'ses',
            # lp]]['ses'].duplicated() == True) & (df_long[['id', 'ses',
            # lp]][lp].isnull()==False)]
            # df_long_clean[lp] = np.abs(df_long_clean[lp].round(6))
            # df_long_clean['ses'] = df_long_clean['ses'].astype('int')
            # g = df_long_clean.groupby(['ses'])
            # df_long_clean = pd.DataFrame(g.apply(
            #     lambda x: x.sample(g.size().min()).reset_index(drop=True))
            #     ).reset_index(drop=True)
            try:
                c_icc = pg.intraclass_corr(data=df_long_clean,
                                           targets='id',
                                           raters='ses',
                                           ratings=lp,
                                           nan_policy='omit').round(3)
                c_icc = c_icc.set_index("Type")
                # Keep only ICC3k (average measures, fixed raters).
                c_icc3 = c_icc.drop(
                    index=['ICC1', 'ICC2', 'ICC1k', 'ICC2k', 'ICC3'])
                icc_val = c_icc3['ICC'].values[0]
                if nodes is True:
                    # Most frequent coordinate/label for this node column.
                    coord_in = np.array(ast.literal_eval(
                        coords_frames_icc[f"{rsn}_{ix}"].mode().values[0]),
                                        dtype=np.dtype("O"))
                    label_in = np.array(
                        labels_frames_icc[f"{rsn}_{ix}"].mode().values[0],
                        dtype=np.dtype("O"))
                else:
                    coord_in = np.nan
                    label_in = np.nan
                dict_sum[f"{lp}_icc"] = icc_val
                del c_icc, c_icc3, icc_val
            except BaseException:
                print(f"FAILED for {lp}...")
                # print(df_long)
                #df_summary.at[0, f"{lp}_icc"] = np.nan
                coord_in = np.nan
                label_in = np.nan

            dict_sum[f"{lp}_coord"] = coord_in
            dict_sum[f"{lp}_label"] = label_in

        df_summary = pd.concat(
            [df_summary, pd.DataFrame(pd.Series(dict_sum).T).T], axis=1)

        print(df_summary)

        # Flatten the combination tuple into a filename-safe token.
        tup_name = str(comb_tuple).replace('\', \'',
                                           '_').replace('(', '').replace(
                                               ')', '').replace('\'', '')
        df_summary.to_csv(f"{icc_tmps_dir}/{alg}_{tup_name}.csv",
                          index=False,
                          header=True)
        del df_long

    # discriminability
    if disc is True:
        vect_all = []
        for ID in ids:
            try:
                out = gen_sub_vec(base_dir, par_dict, ID, modality, alg,
                                  comb_tuple)
            except BaseException:
                print(f"{ID} {modality} {alg} {comb_tuple} failed...")
                continue
            # print(out)
            vect_all.append(out)
        # ## TODO: Remove the .iloc below to include global efficiency.
        # vect_all = [pd.DataFrame(i).iloc[1:] for i in vect_all if i is not
        #             None and not np.isnan(np.array(i)).all()]
        vect_all = [
            pd.DataFrame(i) for i in vect_all
            if i is not None and not np.isnan(np.array(i)).all()
        ]

        if len(vect_all) > 0:
            if len(vect_all) > 0:
                X_top = pd.concat(vect_all, axis=0, join="outer")
                X_top = np.array(
                    X_top.dropna(axis='columns', thresh=0.50 * len(X_top)))
            else:
                print('Empty dataframe!')
                return df_summary

            shapes = []
            for ix, i in enumerate(vect_all):
                shapes.append(i.shape[0] * [list(ids)[ix]])
            Y = np.array(list(flatten(shapes)))
            if alg == 'topology':
                imp = IterativeImputer(max_iter=50, random_state=42)
            else:
                imp = SimpleImputer()
            X_top = imp.fit_transform(X_top)
            scaler = StandardScaler()
            X_top = scaler.fit_transform(X_top)
            try:
                discr_stat_val, rdf = discr_stat(X_top, Y)
                df_summary.at[0, "discriminability"] = discr_stat_val
                print(discr_stat_val)
                print("\n")
                del discr_stat_val
            except BaseException:
                print('Discriminability calculation failed...')
                return df_summary
            # print(rdf)
        del vect_all

    gc.collect()
    return df_summary
def main():
    """Fine-tune or evaluate a teacher-student ResNet18 on BreastPathQ.

    Behavior is selected by ``args.mode``:

    * ``'fine-tuning'``: builds labeled/unlabeled/validation loaders from the
      training images, initializes both teacher and student networks from a
      pretrained SSL checkpoint (``args.model_path_finetune``), partially
      freezes each (first ``args.modules_teacher`` / ``args.modules_student``
      modules), then trains the student with consistency regularization,
      copying the student into the teacher after every epoch. Checkpoints are
      written every ``args.save_freq`` epochs and whenever validation loss
      improves; per-epoch losses are appended to a CSV.
    * ``'evaluation'``: loads a fine-tuned student checkpoint
      (``args.model_path_eval``), runs inference on the test set, and writes
      intraclass-correlation (ICC) tables, scatter plots, and Bland-Altman
      plots comparing the model's predictions against two pathologists
      (raters ``A`` and ``B``).

    Raises
    ------
    NotImplementedError
        If ``args.mode`` or ``args.model`` is not a supported value.
    """

    # parse the args
    args = parse_args()

    # Set the data loaders (train, val, test)

    ### BreastPathQ ##################

    if args.mode == 'fine-tuning':

        # Train set
        transform_train = transforms.Compose([])  # None
        train_labeled_dataset = DatasetBreastPathQ_Supervised_train(args.train_image_pth, args.image_size, transform=transform_train)
        train_unlabeled_dataset = DatasetBreastPathQ_SSLtrain(args.train_image_pth, transform=TransformFix(args.image_size, args.NAug))

        # Validation set
        transform_val = transforms.Compose([transforms.Resize(size=args.image_size)])
        val_dataset = DatasetBreastPathQ_SSLtrain(args.train_image_pth, transform=transform_val)

        # train and validation split
        num_train = len(train_labeled_dataset.datalist)
        indices = list(range(num_train))
        split = int(np.floor(args.validation_split * num_train))
        np.random.shuffle(indices)
        train_idx, val_idx = indices[split:], indices[:split]

        #### Semi-Supervised Split (10, 25, 50, 100)
        # NOTE(review): np.random.choice samples WITH replacement by default,
        # so labeled_train_idx can contain duplicate indices — confirm this is
        # intended (replace=False would give a true subset).
        labeled_train_idx = np.random.choice(train_idx, int(args.labeled_train * len(train_idx)))

        unlabeled_train_sampler = SubsetRandomSampler(train_idx)
        labeled_train_sampler = SubsetRandomSampler(labeled_train_idx)
        val_sampler = SubsetRandomSampler(val_idx)

        # Data loaders
        # NOTE(review): the samplers above are never None, so the
        # "shuffle=True if ... is None else False" conditionals below always
        # evaluate to False (DataLoader forbids shuffle=True with a sampler).
        labeled_train_loader = torch.utils.data.DataLoader(train_labeled_dataset, batch_size=args.batch_size, sampler=labeled_train_sampler,
                                                           shuffle=True if labeled_train_sampler is None else False, num_workers=args.num_workers, pin_memory=True, drop_last=True)

        unlabeled_train_loader = torch.utils.data.DataLoader(train_unlabeled_dataset, batch_size=args.batch_size*args.mu, sampler=unlabeled_train_sampler,
                                                             shuffle=True if unlabeled_train_sampler is None else False, num_workers=args.num_workers, pin_memory=True, drop_last=True)

        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, sampler=val_sampler,
                                                 shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=False)

        # number of samples
        num_label_data = len(labeled_train_sampler)
        print('number of labeled training samples: {}'.format(num_label_data))

        num_unlabel_data = len(unlabeled_train_sampler)
        print('number of unlabeled training samples: {}'.format(num_unlabel_data))

        num_val_data = len(val_sampler)
        print('number of validation samples: {}'.format(num_val_data))

    elif args.mode == 'evaluation':

        # Test set
        test_transforms = transforms.Compose([transforms.Resize(size=args.image_size)])
        test_dataset = DatasetBreastPathQ_eval(args.test_image_pth, args.image_size, test_transforms)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                                  num_workers=args.num_workers, pin_memory=True)

        # number of samples
        n_data = len(test_dataset)
        print('number of testing samples: {}'.format(n_data))

    else:

        raise NotImplementedError('invalid mode {}'.format(args.mode))

    ########################################

    # set the model
    if args.model == 'resnet18':

        # Teacher/student pairs: backbone (TripletNet_Finetune) + classifier head.
        model_teacher = net.TripletNet_Finetune(args.model)
        model_student = net.TripletNet_Finetune(args.model)

        classifier_teacher = net.FinetuneResNet(args.num_classes)
        classifier_student = net.FinetuneResNet(args.num_classes)

        if args.mode == 'fine-tuning':

            ###### Intialize both teacher and student network with fine-tuned SSL model ###############

            # Load model
            state_dict = torch.load(args.model_path_finetune)

            # Load fine-tuned model
            model_teacher.load_state_dict(state_dict['model'])
            model_student.load_state_dict(state_dict['model'])

            # Load fine-tuned classifier
            classifier_teacher.load_state_dict(state_dict['classifier'])
            classifier_student.load_state_dict(state_dict['classifier'])


            ################# Freeze Teacher model (Entire network)  ####################

            # look at the contents of the teacher model and freeze it
            idx = 0
            for layer_name, param in model_teacher.named_parameters():
                print(layer_name, '-->', idx)
                idx += 1

            # Freeze the teacher model
            # enumerate() yields (index, (layer_name, Parameter)); `name` is the
            # positional index compared against args.modules_teacher, and
            # `param = param[1]` rebinds to the actual Parameter before toggling
            # requires_grad.
            for name, param in enumerate(model_teacher.named_parameters()):
                if name < args.modules_teacher:  # No of layers(modules) to be freezed
                    print("module", name, "was frozen")
                    param = param[1]
                    param.requires_grad = False
                else:
                    print("module", name, "was not frozen")
                    param = param[1]
                    param.requires_grad = True

            ############## Freeze Student model (Except last FC layer)  #########################

            # look at the contents of the student model and freeze it
            idx = 0
            for layer_name, param in model_student.named_parameters():
                print(layer_name, '-->', idx)
                idx += 1

            # Freeze the student model (original comment said "teacher", but
            # this loop iterates model_student; same index/rebind scheme as above)
            for name, param in enumerate(model_student.named_parameters()):
                if name < args.modules_student:  # No of layers(modules) to be freezed
                    print("module", name, "was frozen")
                    param = param[1]
                    param.requires_grad = False
                else:
                    print("module", name, "was not frozen")
                    param = param[1]
                    param.requires_grad = True

        elif args.mode == 'evaluation':

            # Load fine-tuned model
            state = torch.load(args.model_path_eval)

            # create new OrderedDict that does not contain `module.`
            # (the checkpoint was saved from a DataParallel-wrapped model, so
            # every key carries a 'module.' prefix that must be stripped)
            new_state_dict = OrderedDict()

            for k, v in state['model_student'].items():
                name = k[7:]  # remove `module.`
                new_state_dict[name] = v

            model_student.load_state_dict(new_state_dict)

            # create new OrderedDict that does not contain `module.`
            new_state_dict_cls = OrderedDict()

            for k, v in state['classifier_student'].items():
                name = k[7:]  # remove `module.`
                new_state_dict_cls[name] = v

            classifier_student.load_state_dict(new_state_dict_cls)

        else:
            raise NotImplementedError('invalid training {}'.format(args.mode))

    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    # Load model to CUDA
    if torch.cuda.is_available():
        model_teacher = torch.nn.DataParallel(model_teacher)
        model_student = torch.nn.DataParallel(model_student)
        classifier_teacher = torch.nn.DataParallel(classifier_teacher)
        classifier_student = torch.nn.DataParallel(classifier_student)
        cudnn.benchmark = True

    # Optimiser & scheduler
    # Only the student's trainable (requires_grad) parameters are optimized;
    # the teacher is updated by deepcopy from the student after each epoch.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, list(model_student.parameters()) + list(classifier_student.parameters())), lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60], gamma=0.1)

    # Training Model
    start_epoch = 1
    prev_best_val_loss = float('inf')

    # Start log (writing into XL sheet)
    # NOTE(review): this header is written unconditionally, i.e. also in
    # 'evaluation' mode, and truncates any existing results file.
    with open(os.path.join(args.save_loss, 'fine_tuned_results.csv'), 'w') as f:
        f.write('epoch, train_loss, train_losses_x, train_losses_u, val_loss\n')

    # Routine
    for epoch in range(start_epoch, args.num_epoch + 1):

        if args.mode == 'fine-tuning':

            print("==> fine-tuning the pretrained SSL model...")

            time_start = time.time()

            train_losses, train_losses_x, train_losses_u, final_feats, final_targets = train(args, model_teacher, model_student, classifier_teacher, classifier_student, labeled_train_loader, unlabeled_train_loader, optimizer, epoch)
            print('Epoch time: {:.2f} s.'.format(time.time() - time_start))

            print("==> validating the fine-tuned model...")
            val_losses = validate(args, model_student, classifier_student, val_loader, epoch)

            # Log results
            # NOTE(review): logs epoch + 1 while checkpoints below record
            # `epoch` — the CSV is offset by one relative to the .pt files.
            with open(os.path.join(args.save_loss, 'fine_tuned_results.csv'), 'a') as f:
                f.write('%03d,%0.6f,%0.6f,%0.6f,%0.6f,\n' % ((epoch + 1), train_losses, train_losses_x, train_losses_u, val_losses))

            # NOTE(review): the line below is a bare string literal — a no-op
            # statement evidently intended as a comment.
            'adjust learning rate --- Note that step should be called after validate()'
            scheduler.step()

            # Iterative training: Use the student as a teacher after every epoch
            model_teacher = copy.deepcopy(model_student)
            classifier_teacher = copy.deepcopy(classifier_student)

            # Save model every 10 epochs
            if epoch % args.save_freq == 0:
                print('==> Saving...')
                state = {
                    'args': args,
                    'model_student': model_student.state_dict(),
                    'model_teacher': model_teacher.state_dict(),
                    'classifier_teacher': classifier_teacher.state_dict(),
                    'classifier_student': classifier_student.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch,
                    'train_loss': train_losses,
                    'train_losses_x': train_losses_x,
                    'train_losses_u': train_losses_u,
                }
                torch.save(state, '{}/fine_CR_trained_model_{}.pt'.format(args.model_save_pth, epoch))

                # help release GPU memory
                del state
            torch.cuda.empty_cache()

            # Save model for the best val
            # NOTE(review): `&` is bitwise-and on two bools here — works, but
            # `and` is the conventional operator for this test.
            if (val_losses < prev_best_val_loss) & (epoch>1):
                print('==> Saving...')
                state = {
                    'args': args,
                    'model_student': model_student.state_dict(),
                    'model_teacher': model_teacher.state_dict(),
                    'classifier_teacher': classifier_teacher.state_dict(),
                    'classifier_student': classifier_student.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch,
                    'train_loss': train_losses,
                    'train_losses_x': train_losses_x,
                    'train_losses_u': train_losses_u,
                }
                torch.save(state, '{}/best_CR_trained_model_{}.pt'.format(args.model_save_pth, epoch))
                prev_best_val_loss = val_losses

                # help release GPU memory
                del state
            torch.cuda.empty_cache()

        elif args.mode == 'evaluation':

            # NOTE(review): evaluation runs inside the epoch loop, so it
            # repeats args.num_epoch times — presumably num_epoch is 1 in this
            # mode; confirm against the arg defaults.
            print("==> testing final test data...")
            final_predicitions, final_feats, final_targetsA, final_targetsB = test(args, model_student, classifier_student, test_loader)

            final_predicitions = final_predicitions.numpy()
            final_targetsA = final_targetsA.numpy()
            final_targetsB = final_targetsB.numpy()

            # BreastPathQ dataset #######
            # ICC between the model ('M') and pathologist A: long-format frame
            # with one row per (target, rater) pair, as pingouin expects.
            d = {'targets': np.hstack(
                [np.arange(1, len(final_predicitions) + 1, 1), np.arange(1, len(final_predicitions) + 1, 1)]),
                'raters': np.hstack([np.tile(np.array(['M']), len(final_predicitions)),
                                     np.tile(np.array(['A']), len(final_predicitions))]),
                'scores': np.hstack([final_predicitions, final_targetsA])}
            df = pd.DataFrame(data=d)
            iccA = pg.intraclass_corr(data=df, targets='targets', raters='raters', ratings='scores')
            iccA.to_csv(os.path.join(args.save_loss, 'BreastPathQ_ICC_Eval_2way_MA.csv'))
            print(iccA)

            # ICC between the model ('M') and pathologist B.
            d = {'targets': np.hstack(
                [np.arange(1, len(final_predicitions) + 1, 1), np.arange(1, len(final_predicitions) + 1, 1)]),
                'raters': np.hstack([np.tile(np.array(['M']), len(final_predicitions)),
                                     np.tile(np.array(['B']), len(final_predicitions))]),
                'scores': np.hstack([final_predicitions, final_targetsB])}
            df = pd.DataFrame(data=d)
            iccB = pg.intraclass_corr(data=df, targets='targets', raters='raters', ratings='scores')
            iccB.to_csv(os.path.join(args.save_loss, 'BreastPathQ_ICC_Eval_2way_MB.csv'))
            print(iccB)

            # Inter-pathologist ICC (A vs B) as a human-agreement baseline.
            d = {'targets': np.hstack(
                [np.arange(1, len(final_targetsA) + 1, 1), np.arange(1, len(final_targetsB) + 1, 1)]),
                'raters': np.hstack(
                    [np.tile(np.array(['A']), len(final_targetsA)), np.tile(np.array(['B']), len(final_targetsB))]),
                'scores': np.hstack([final_targetsA, final_targetsB])}
            df = pd.DataFrame(data=d)
            iccC = pg.intraclass_corr(data=df, targets='targets', raters='raters', ratings='scores')
            iccC.to_csv(os.path.join(args.save_loss, 'BreastPathQ_ICC_Eval_2way_AB.csv'))
            print(iccC)

            # Plots
            fig, ax = plt.subplots()  # P1 vs automated
            ax.scatter(final_targetsA, final_predicitions, edgecolors=(0, 0, 0))
            ax.plot([final_targetsA.min(), final_targetsA.max()], [final_targetsA.min(), final_targetsA.max()], 'k--',
                    lw=2)
            ax.set_xlabel('Pathologist1')
            ax.set_ylabel('Automated Method')
            plt.savefig(os.path.join(args.save_loss, 'BreastPathQ_Eval_2way_MA_plot.png'), dpi=300)
            plt.show()

            fig, ax = plt.subplots()  # P2 vs automated
            ax.scatter(final_targetsB, final_predicitions, edgecolors=(0, 0, 0))
            ax.plot([final_targetsB.min(), final_targetsB.max()], [final_targetsB.min(), final_targetsB.max()], 'k--',
                    lw=2)
            ax.set_xlabel('Pathologist2')
            ax.set_ylabel('Automated Method')
            plt.savefig(os.path.join(args.save_loss, 'BreastPathQ_Eval_2way_MB_plot.png'), dpi=300)
            plt.show()

            fig, ax = plt.subplots()  # P1 vs P2
            # NOTE(review): the identity line spans targetsA's range on both
            # axes even though the y-data is targetsB — confirm intended.
            ax.scatter(final_targetsA, final_targetsB, edgecolors=(0, 0, 0))
            ax.plot([final_targetsA.min(), final_targetsA.max()], [final_targetsA.min(), final_targetsA.max()], 'k--',
                    lw=2)
            ax.set_xlabel('Pathologist1')
            ax.set_ylabel('Pathologist2')
            plt.savefig(os.path.join(args.save_loss, 'BreastPathQ_Eval_2way_AB_plot.png'), dpi=300)
            plt.show()

            # Bland altman plot
            fig, ax = plt.subplots(1, figsize=(8, 8))
            sm.graphics.mean_diff_plot(final_targetsA, final_predicitions, ax=ax)
            plt.savefig(os.path.join(args.save_loss, 'BDPlot_Eval_2way_MA_plot.png'), dpi=300)
            plt.show()

            fig, ax = plt.subplots(1, figsize=(8, 8))
            sm.graphics.mean_diff_plot(final_targetsB, final_predicitions, ax=ax)
            plt.savefig(os.path.join(args.save_loss, 'BDPlot_Eval_2way_MB_plot.png'), dpi=300)
            plt.show()

            fig, ax = plt.subplots(1, figsize=(8, 8))
            sm.graphics.mean_diff_plot(final_targetsA, final_targetsB, ax=ax)
            plt.savefig(os.path.join(args.save_loss, 'BDPlot_Eval_2way_AB_plot.png'), dpi=300)
            plt.show()

        else:

            raise NotImplementedError('mode not supported {}'.format(args.mode))