def interrater_iccs(ratings, rater_col_name='rater', index_label='onset_ms', column_labels=None):
    """Compute inter-rater intraclass correlations using the Pingouin library.

    By default this computes absolute agreement between raters assuming a
    random sample of raters at each target (ICC2). See
    https://pingouin-stats.org/generated/pingouin.intraclass_corr.html

    Parameters
    ----------
    ratings : DataFrame
        Ratings information stored in long format. The input is NOT
        modified (a copy is taken if an index column must be added).
    rater_col_name : str
        Name of the column containing rater information. Default "rater".
    index_label : str
        Label denoting each measurement; must be consistent across all
        raters. Default "onset_ms".
    column_labels : list, optional
        Variables to compute inter-rater ICCs for. Default None, meaning
        every column other than ``rater_col_name`` and ``index_label``.

    Returns
    -------
    icc_df : DataFrame
        Instance-level ICC values plus qualitative consistency labels
        (<0.50 poor, <0.75 moderate, <0.90 good, >=0.90 excellent).
    """
    def _consistency(value):
        # Map a single ICC value onto its qualitative band. NaN is
        # propagated explicitly (the previous elif-chain silently left the
        # cell unset when the ICC was NaN).
        if pd.isna(value):
            return float('nan')
        if value < 0.50:
            return 'poor'
        if value < 0.75:
            return 'moderate'
        if value < 0.90:
            return 'good'
        return 'excellent'

    if not column_labels:
        column_labels = ratings.columns.to_list()
        column_labels.remove(rater_col_name)
    if index_label in column_labels:
        # The measurement index is one of the candidate columns; exclude it
        # from the variables to score.
        column_labels = [c for c in column_labels if c != index_label]
    elif index_label not in ratings.columns:
        # Materialize the measurement index as a column, on a copy so the
        # caller's DataFrame is not mutated (the original wrote into it).
        ratings = ratings.copy()
        ratings[index_label] = ratings.index
        ratings.index.name = 'index'

    icc_df = pd.DataFrame(
        columns=['instance_level_ICC', 'instance_level_consistency'])
    for col in column_labels:
        icc = pg.intraclass_corr(data=ratings, targets=index_label,
                                 raters=rater_col_name, ratings=col,
                                 nan_policy='omit')
        icc2 = icc.loc[1, 'ICC']  # row 1 of the Pingouin table is ICC2
        icc_df.loc[col, 'instance_level_ICC'] = icc2
        icc_df.loc[col, 'instance_level_consistency'] = _consistency(icc2)
    return icc_df
def icc(ds, i, attr):
    """Return the first ICC estimate (ICC1) for subject-by-task ratings.

    Builds the long-format frame via ``regression`` and feeds it to
    ``intraclass_corr`` with subjects as targets and tasks as raters.
    """
    long_frame = regression(ds, i, attr)
    results = intraclass_corr(long_frame, targets='subject',
                              raters='task', ratings='y')
    return results['ICC'][0]
def remove_metabolites_icc(self, cutoff: float = 0.65):
    '''
    Compute the intra-class correlation among duplicates or triplicates
    for each metabolite, then remove every metabolite whose ICC falls
    below ``cutoff`` (via the class's removal helpers).

    Parameters
    ----------
    cutoff: float
        ICC metabolite removal cutoff.
    '''
    print('-----Removing metabolites with ICC values lower than ' + str(cutoff) + '-----')
    for batch_idx in range(len(self.data)):
        frame = self.data[batch_idx]
        dup_ids = frame.index[frame.index.duplicated()].unique()
        dup_rows = frame.loc[dup_ids]
        # Replicate number (1, 2, ...) for each repeated sample, in the
        # same order the duplicated rows appear.
        replicate_labels = [
            rep + 1
            for sample in dup_ids
            for rep in range(len(dup_rows.loc[sample]))
        ]
        icc_estimates = []
        for metabolite in dup_rows.columns:
            long_form = pd.DataFrame()
            long_form['raters'] = replicate_labels
            long_form['value'] = list(dup_rows[metabolite])
            long_form['targets'] = list(dup_rows.index)
            result = pg.intraclass_corr(long_form, targets='targets',
                                        raters='raters', ratings='value',
                                        nan_policy='omit')
            # iloc[2, 2]: third row / 'ICC' column of Pingouin's table
            # (ICC3 under Pingouin's row ordering).
            icc_estimates.append(result.iloc[2, 2])
        icc_table = pd.DataFrame(index=dup_rows.columns)
        icc_table['ICC'] = icc_estimates
        below_cutoff = icc_table[icc_table['ICC'] < cutoff]
        # Print and remove metabolites
        self._print_metabolites_removed(below_cutoff, batch_idx)
        self._remove_metabolites(below_cutoff, batch_idx)
def calculate_icc(df):
    """Compute the ICC of per-building damage ratings across judges.

    Reshapes the wide per-judge columns (``BUILDING_CLASS_<judge>``,
    ``DAMAGE_<judge>``) into long format and runs Pingouin's
    intraclass correlation on the damage scores.

    Parameters
    ----------
    df : DataFrame
        One row per building, with an 'ID' column plus per-judge
        building-class and damage columns for every judge in ``JUDGES``.

    Returns
    -------
    DataFrame
        Pingouin ICC table, rounded to 3 decimals.

    Raises
    ------
    ValueError
        If any damage rating falls outside [0, 1] or any building class
        is not one of 'G', 'M', 'B'.
    """
    icc_data = []
    for _, row in df.iterrows():
        for judge in JUDGES:
            icc_data.append((row['ID'], judge,
                             row[f'BUILDING_CLASS_{judge}'],
                             row[f'DAMAGE_{judge}']))
    icc_df = pd.DataFrame(icc_data,
                          columns=['ID', 'judge', 'class', 'damage'])
    # Validate with explicit exceptions instead of `assert`, which is
    # silently stripped when Python runs with -O.
    if not icc_df['damage'].between(0, 1).all():
        raise ValueError('damage ratings must lie within [0, 1]')
    if not icc_df['class'].isin(['G', 'M', 'B']).all():
        raise ValueError("building classes must be one of 'G', 'M', 'B'")
    icc = pg.intraclass_corr(data=icc_df, targets='ID', raters='judge',
                             ratings='damage', nan_policy='raise').round(3)
    return icc
def icc_by_metric(df, select_metric):
    """Return a test-retest reliability summary for one metric.

    Pivots athlete-by-trial values wide, drops trials with missing data,
    melts back to long format, and reports Pingouin's third ICC row
    (with its type and 95% CI) as a dict.
    """
    wide = (df[['athlete', 'id', select_metric]]
            .pivot_table(index=['athlete'], values=select_metric,
                         columns='id')
            .dropna(axis=1))
    long_form = wide.unstack().reset_index(name='value')
    long_form = long_form.rename(columns={'value': select_metric})
    reliability = pg.intraclass_corr(data=long_form, targets='athlete',
                                     raters='id',
                                     ratings=select_metric).iloc[2]
    return {
        'Metric': select_metric,
        'Type': reliability['Type'],
        'ICC': reliability['ICC'],
        'CI95': reliability['CI95%'],
    }
# sns.lineplot(x='EVLP ID', y=df[y].mean(), data=df, ax=ax) fig_biochem.tight_layout() fig_biochem.savefig('Biochem_%CV.png', dpi=200) ### Correlations of dPO2 with Other Important EVLP Parameters ### df = pd.read_excel(r'C:\Users\chaob\Documents\Perfusate Heterogeneity %CV Plotting.xlsx', 'dPO2 Correlations') # markers = ['o', '^', 's', 'P', '*', 'D', 'v', 'X', '>'] fig = plt.figure(figsize=(12, 6)) sns.swarmplot(x='EVLP ID', y='Correlation with dPO2', data=df, hue='Biomarker', size=10) plt.legend(bbox_to_anchor=(1.01, 1),borderaxespad=0) fig.tight_layout() fig.savefig('dPO2_Correlation_Scatterplot.png', dpi=200) ### Intra-class Correlations ### import pingouin as pg sheet_names = ['#389', '#393', '#398', '#406', '#422', '#426', '#472', '#474'] for s in sheet_names: df = pd.read_excel(r'C:\Users\chaob\Documents\Perfusate Protein Data by Donor.xlsx', s) icc = pg.intraclass_corr(data=df, targets='Target Cytokine', raters='Location', ratings='Expression').round(3) print(s, '\n', icc)
def benchmark_reproducibility(comb, modality, alg, sub_dict_clean, disc,
                              int_consist, final_missingness_summary):
    """Summarize test-retest reproducibility for one hyperparameter grid combo.

    Builds a one-row DataFrame with the grid tuple, modality and embedding,
    then optionally adds Cronbach's alpha per metric (``int_consist``),
    per-metric ICCs (module-level ``icc`` flag), and a discriminability
    statistic (``disc``). Returns an empty ``pd.Series()`` on any failure.

    NOTE(review): ``mets``, ``ids``, ``icc``, ``gen_sub_vec``, ``flatten``,
    ``discr_stat`` and the sklearn imputers/scaler are read from module
    scope here — confirm they are defined where this function is used.
    """
    df_summary = pd.DataFrame(
        columns=['grid', 'modality', 'embedding', 'discriminability'])
    print(comb)
    df_summary.at[0, "modality"] = modality
    df_summary.at[0, "embedding"] = alg
    # Unpack the grid combination; functional grids carry a smoothing
    # value that may be absent (fall back to '0').
    if modality == 'func':
        try:
            extract, hpass, model, res, atlas, smooth = comb
        except:
            print(f"Missing {comb}...")
            extract, hpass, model, res, atlas = comb
            smooth = '0'
        comb_tuple = (atlas, extract, hpass, model, res, smooth)
    else:
        directget, minlength, model, res, atlas, tol = comb
        comb_tuple = (atlas, directget, minlength, model, res, tol)
    df_summary.at[0, "grid"] = comb_tuple
    # NOTE(review): computed but never used afterwards in this function.
    missing_sub_seshes = \
        final_missingness_summary.loc[(final_missingness_summary['alg']==alg) &
                                      (final_missingness_summary[
                                          'modality']==modality) &
                                      (final_missingness_summary[
                                          'grid']==comb_tuple)
                                      ].drop_duplicates(subset='id')
    # int_consist: internal consistency (Cronbach's alpha) across sessions.
    if int_consist is True and alg == 'topology':
        try:
            import pingouin as pg
        except ImportError:
            print("Cannot evaluate test-retest int_consist. pingouin"
                  " must be installed!")
        for met in mets:
            id_dict = {}
            for ID in ids:
                id_dict[ID] = {}
                for ses in sub_dict_clean[ID].keys():
                    if comb_tuple in sub_dict_clean[ID][ses][modality][
                            alg].keys():
                        id_dict[ID][ses] = \
                            sub_dict_clean[ID][ses][modality][alg][comb_tuple][
                                mets.index(met)][0]
            df_wide = pd.DataFrame(id_dict).T
            if df_wide.empty:
                del df_wide
                return pd.Series()
            df_wide = df_wide.add_prefix(f"{met}_visit_")
            # Zeros are treated as missing measurements.
            df_wide.replace(0, np.nan, inplace=True)
            try:
                c_alpha = pg.cronbach_alpha(data=df_wide)
            except:
                print('FAILED...')
                print(df_wide)
                del df_wide
                return pd.Series()
            df_summary.at[0, f"cronbach_alpha_{met}"] = c_alpha[0]
            del df_wide
    # icc: per-metric intraclass correlation across sessions 1..10.
    if icc is True and alg == 'topology':
        try:
            import pingouin as pg
        except ImportError:
            print("Cannot evaluate ICC. pingouin"
                  " must be installed!")
        for met in mets:
            id_dict = {}
            dfs = []
            for ses in [str(i) for i in range(1, 11)]:
                for ID in ids:
                    id_dict[ID] = {}
                    if comb_tuple in sub_dict_clean[ID][ses][modality][
                            alg].keys():
                        id_dict[ID][ses] = \
                            sub_dict_clean[ID][ses][modality][alg][comb_tuple][
                                mets.index(met)][0]
                df = pd.DataFrame(id_dict).T
                if df.empty:
                    # NOTE(review): `df_long` is not yet bound on the first
                    # pass, so this `del` would raise NameError — presumably
                    # `del df` was intended. Confirm before fixing.
                    del df_long
                    return pd.Series()
                df.columns.values[0] = f"{met}"
                df.replace(0, np.nan, inplace=True)
                df['id'] = df.index
                df['ses'] = ses
                df.reset_index(drop=True, inplace=True)
                dfs.append(df)
            df_long = pd.concat(dfs, names=[
                'id', 'ses', f"{met}"
            ]).drop(columns=[str(i) for i in range(1, 10)])
            try:
                c_icc = pg.intraclass_corr(data=df_long, targets='id',
                                           raters='ses', ratings=f"{met}",
                                           nan_policy='omit').round(3)
                c_icc = c_icc.set_index("Type")
                # Average the remaining (k-rater) ICC variants.
                df_summary.at[0, f"icc_{met}"] = pd.DataFrame(
                    c_icc.drop(
                        index=['ICC1', 'ICC2', 'ICC3'])['ICC']).mean()[0]
            except:
                print('FAILED...')
                print(df_long)
                del df_long
                return pd.Series()
            del df_long
    # disc: discriminability statistic over per-subject feature vectors.
    if disc is True:
        vect_all = []
        for ID in ids:
            try:
                out = gen_sub_vec(sub_dict_clean, ID, modality, alg,
                                  comb_tuple)
            except:
                print(f"{ID} {modality} {alg} {comb_tuple} failed...")
                continue
            # print(out)
            vect_all.append(out)
        vect_all = [
            i for i in vect_all if i is not None and not np.isnan(i).all()
        ]
        if len(vect_all) > 0:
            if alg == 'topology':
                X_top = np.swapaxes(np.hstack(vect_all), 0, 1)
                # Drop feature columns that are >=50% NaN across subjects.
                bad_ixs = [i[1] for i in np.argwhere(np.isnan(X_top))]
                for m in set(bad_ixs):
                    if (X_top.shape[0] - bad_ixs.count(m)) / \
                            X_top.shape[0] < 0.50:
                        X_top = np.delete(X_top, m, axis=1)
            else:
                if len(vect_all) > 0:
                    X_top = np.array(pd.concat(vect_all, axis=0))
                else:
                    return pd.Series()
            shapes = []
            for ix, i in enumerate(vect_all):
                shapes.append(i.shape[0] * [list(ids)[ix]])
            Y = np.array(list(flatten(shapes)))
            if alg == 'topology':
                imp = IterativeImputer(max_iter=50, random_state=42)
            else:
                imp = SimpleImputer()
            X_top = imp.fit_transform(X_top)
            scaler = StandardScaler()
            X_top = scaler.fit_transform(X_top)
            try:
                discr_stat_val, rdf = discr_stat(X_top, Y)
            except:
                return pd.Series()
            df_summary.at[0, "discriminability"] = discr_stat_val
            print(discr_stat_val)
            print("\n")
            # print(rdf)
            del discr_stat_val
        del vect_all
    return df_summary
# Agreement between ground-truth and predicted Lenke probabilities:
# flatten predictions, stack them after the (already flattened)
# ground-truth vector, and label the two halves as raters 'gt'/'pred'.
# NOTE(review): assumes exactly 768 flattened values per rater (1536
# total) — confirm against the upstream array shapes.
pred_lenke_prob_data_col = pred_lenke_prob_data.reshape(-1)
icc_ratings = np.concatenate(
    (gt_lenke_prob_data_col, pred_lenke_prob_data_col), axis=0)
icc_targets = []
icc_raters = []
for k in range(1536):
    if k < 768:
        icc_targets.append(str(k))
        icc_raters.append('gt')
    else:
        # Second half mirrors the first half's target ids.
        icc_targets.append(str(k - 768))
        icc_raters.append('pred')
icc_df = pd.DataFrame({
    'Targets': icc_targets,
    'Raters': icc_raters,
    'Ratings': icc_ratings
})
icc = pg.intraclass_corr(data=icc_df, targets='Targets',
                         raters='Raters', ratings='Ratings')
print(icc.to_string())

###### curve type classification from probabilities
# Class index = argmax over probability columns, shifted to 1-based
# Lenke types; agreement measured with Cohen's kappa.
gt_prob_class = np.argmax(gt_lenke_prob_data, axis=1) + 1
pred_prob_class = np.argmax(pred_lenke_prob_data, axis=1) + 1
kap_prob_class = cohen_kappa_score(gt_prob_class, pred_prob_class)
# Build long-format frames comparing timing-gate maxima against each
# alternative device (radar, Optojump) for every athlete.
ICCRadar = pd.DataFrame()
ICCRadar['Athlete'] = NewData['Athlete']
ICCRadar['TimingGate'] = NewData['TimingGate_Max']
ICCRadar['Radar'] = NewData['Radar_Max']
ICCRadar = ICCRadar.melt(id_vars=['Athlete'])
ICCRadar.sort_values('Athlete', inplace=True, ascending=True)
ICCOpto = pd.DataFrame()
ICCOpto['Athlete'] = NewData['Athlete']
ICCOpto['TimingGate'] = NewData['TimingGate_Max']
ICCOpto['Optojump'] = NewData['Optojump_Max']
ICCOpto = ICCOpto.melt(id_vars=['Athlete'])
ICCOpto.sort_values('Athlete', inplace=True, ascending=True)
# NOTE(review): targets='variable' / raters='Athlete' looks inverted —
# conventionally athletes are the targets and the device ('variable')
# the rater. Confirm the intended design before relying on these ICCs.
iccRadar = pg.intraclass_corr(data=ICCRadar, targets='variable',
                              raters='Athlete', ratings='value')
iccOpto = pg.intraclass_corr(data=ICCOpto, targets='variable',
                             raters='Athlete', ratings='value')
iccRadar = iccRadar.round(decimals=2)
iccOpto = iccOpto.round(decimals=2)
iccRadar = iccRadar.drop("Description", axis=1)
iccOpto = iccOpto.drop("Description", axis=1)
table1 = ff.create_table(iccRadar)
table2 = ff.create_table(iccOpto)
# NOTE(review): this call is truncated in the source as received; the
# remaining arguments are missing from this chunk.
fig = make_subplots(rows=2, cols=1, print_grid=True,
def benchmark_reproducibility(base_dir, comb, modality, alg, par_dict, disc,
                              final_missingness_summary, icc_tmps_dir, icc,
                              mets, ids, template):
    """Summarize test-retest reproducibility for one grid combination.

    For topology metrics, computes per-metric ICC3k (with 95% CI) across
    sessions 1..10; for embeddings, loads per-node embedding data and node
    metadata, then computes a per-node ICC, writing the summary row to a
    CSV under ``icc_tmps_dir``. Optionally appends a discriminability
    statistic. Returns the (possibly partial) one-row summary DataFrame
    on any failure path.

    NOTE(review): ``flatten``, ``discr_stat``, ``os``, ``np``, ``pd`` and
    the sklearn imputers/scaler are read from module scope; ``json`` and
    ``re`` are imported but unused below.
    """
    import gc
    import json
    import glob
    from pathlib import Path
    import ast
    import matplotlib
    from pynets.stats.utils import gen_sub_vec
    matplotlib.use('Agg')
    df_summary = pd.DataFrame(
        columns=['grid', 'modality', 'embedding', 'discriminability'])
    print(comb)
    df_summary.at[0, "modality"] = modality
    df_summary.at[0, "embedding"] = alg
    # Unpack the grid combo; functional grids may lack a smoothing value.
    if modality == 'func':
        try:
            extract, hpass, model, res, atlas, smooth = comb
        except BaseException:
            print(f"Missing {comb}...")
            extract, hpass, model, res, atlas = comb
            smooth = '0'
        # comb_tuple = (atlas, extract, hpass, model, res, smooth)
        comb_tuple = comb
    else:
        directget, minlength, model, res, atlas, tol = comb
        # comb_tuple = (atlas, directget, minlength, model, res, tol)
        comb_tuple = comb
    df_summary.at[0, "grid"] = comb_tuple
    # missing_sub_seshes = \
    #     final_missingness_summary.loc[(final_missingness_summary['alg']==alg)
    #                                   & (final_missingness_summary[
    #                                       'modality']==modality) &
    #                                   (final_missingness_summary[
    #                                       'grid']==comb_tuple)
    #                                   ].drop_duplicates(subset='id')
    # icc over topology metrics.
    if icc is True and alg == 'topology':
        try:
            import pingouin as pg
        except ImportError:
            print("Cannot evaluate ICC. pingouin"
                  " must be installed!")
        for met in mets:
            id_dict = {}
            dfs = []
            for ses in [str(i) for i in range(1, 11)]:
                for ID in ids:
                    id_dict[ID] = {}
                    if comb_tuple in par_dict[ID][str(
                            ses)][modality][alg].keys():
                        id_dict[ID][str(ses)] = \
                            par_dict[ID][str(ses)][modality][alg][comb_tuple][
                                mets.index(met)][0]
                df = pd.DataFrame(id_dict).T
                if df.empty:
                    del df
                    return df_summary
                df.columns.values[0] = f"{met}"
                # Zeros are treated as missing measurements.
                df.replace(0, np.nan, inplace=True)
                df['id'] = df.index
                df['ses'] = ses
                df.reset_index(drop=True, inplace=True)
                dfs.append(df)
            df_long = pd.concat(dfs, names=[
                'id', 'ses', f"{met}"
            ]).drop(columns=[str(i) for i in range(1, 10)])
            # Session 10 can surface as a stray '10' column; fold it back
            # into the metric column.
            if '10' in df_long.columns:
                df_long[f"{met}"] = df_long[f"{met}"].fillna(df_long['10'])
                df_long = df_long.drop(columns='10')
            try:
                c_icc = pg.intraclass_corr(data=df_long, targets='id',
                                           raters='ses', ratings=f"{met}",
                                           nan_policy='omit').round(3)
                c_icc = c_icc.set_index("Type")
                # Keep only the ICC3k row (all other variants dropped).
                c_icc3 = c_icc.drop(
                    index=['ICC1', 'ICC2', 'ICC1k', 'ICC2k', 'ICC3'])
                df_summary.at[0, f"icc_{met}"] = c_icc3['ICC'].values[0]
                df_summary.at[0, f"icc_{met}_CI95%_L"] = \
                    c_icc3['CI95%'].values[0][0]
                df_summary.at[0, f"icc_{met}_CI95%_U"] = \
                    c_icc3['CI95%'].values[0][1]
            except BaseException:
                print('FAILED...')
                print(df_long)
                del df_long
                return df_summary
            del df_long
    elif icc is True and alg != 'topology':
        # icc over embedding dimensions (per node).
        import re
        from pynets.stats.utils import parse_closest_ixs
        try:
            import pingouin as pg
        except ImportError:
            print("Cannot evaluate ICC. pingouin"
                  " must be installed!")
        dfs = []
        coords_frames = []
        labels_frames = []
        for ses in [str(i) for i in range(1, 11)]:
            for ID in ids:
                if ses in par_dict[ID].keys():
                    if comb_tuple in par_dict[ID][str(
                            ses)][modality][alg].keys():
                        if 'data' in par_dict[ID][str(
                                ses)][modality][alg][comb_tuple].keys():
                            if par_dict[ID][str(ses)][modality][alg][
                                    comb_tuple]['data'] is not None:
                                # Embedding data is either a file path or
                                # an in-memory array.
                                if isinstance(
                                        par_dict[ID][str(ses)][modality][alg]
                                        [comb_tuple]['data'], str):
                                    data_path = par_dict[ID][str(ses)][
                                        modality][alg][comb_tuple]['data']
                                    parent_dir = Path(
                                        os.path.dirname(
                                            par_dict[ID][str(ses)][modality]
                                            [alg][comb_tuple]['data'])).parent
                                    if os.path.isfile(data_path):
                                        try:
                                            if data_path.endswith('.npy'):
                                                emb_data = np.load(data_path)
                                            elif data_path.endswith('.csv'):
                                                emb_data = np.array(
                                                    pd.read_csv(
                                                        data_path)).reshape(
                                                            -1, 1)
                                            else:
                                                emb_data = np.nan
                                            node_files = glob.glob(
                                                f"{parent_dir}/nodes/*.json")
                                        except:
                                            print(f"Failed to load data from "
                                                  f"{data_path}..")
                                            continue
                                    else:
                                        continue
                                else:
                                    node_files = glob.glob(
                                        f"{base_dir}/pynets/sub-{ID}/ses-"
                                        f"{ses}/{modality}/rsn-"
                                        f"{atlas}_res-{res}/nodes/*.json")
                                    emb_data = par_dict[ID][str(ses)][
                                        modality][alg][comb_tuple]['data']
                                emb_shape = emb_data.shape[0]
                                if len(node_files) > 0:
                                    # Match embedding rows to node metadata;
                                    # retry without the template if counts
                                    # disagree.
                                    ixs, node_dict = parse_closest_ixs(
                                        node_files, emb_shape,
                                        template=template)
                                    if len(ixs) != emb_shape:
                                        ixs, node_dict = parse_closest_ixs(
                                            node_files, emb_shape)
                                    if isinstance(node_dict, dict):
                                        coords = [
                                            node_dict[i]['coord']
                                            for i in node_dict.keys()
                                        ]
                                        labels = [
                                            node_dict[i]['label'][
                                                'BrainnetomeAtlas'
                                                'Fan2016']
                                            for i in node_dict.keys()
                                        ]
                                    else:
                                        print(f"Failed to parse coords/"
                                              f"labels from {node_files}. "
                                              f"Skipping...")
                                        continue
                                    df_coords = pd.DataFrame(
                                        [str(tuple(x)) for x in coords]).T
                                    df_coords.columns = [
                                        f"rsn-{atlas}_res-"
                                        f"{res}_{i}" for i in ixs
                                    ]
                                    # labels = [
                                    #     list(i['label'])[7] for i
                                    #     in node_dict]
                                    df_labels = pd.DataFrame(labels).T
                                    df_labels.columns = [
                                        f"rsn-{atlas}_res-"
                                        f"{res}_{i}" for i in ixs
                                    ]
                                    coords_frames.append(df_coords)
                                    labels_frames.append(df_labels)
                                else:
                                    print(f"No node files detected for "
                                          f"{comb_tuple} and {ID}-{ses}...")
                                    ixs = [
                                        i for i in par_dict[ID][str(ses)]
                                        [modality][alg][comb_tuple]['index']
                                        if i is not None
                                    ]
                                    coords_frames.append(pd.Series())
                                    labels_frames.append(pd.Series())
                                if len(ixs) == emb_shape:
                                    df_pref = pd.DataFrame(emb_data.T,
                                                           columns=[
                                                               f"{alg}_{i}_rsn"
                                                               f"-{atlas}_res-"
                                                               f"{res}"
                                                               for i in ixs
                                                           ])
                                    df_pref['id'] = ID
                                    df_pref['ses'] = ses
                                    df_pref.replace(0, np.nan, inplace=True)
                                    df_pref.reset_index(drop=True,
                                                        inplace=True)
                                    dfs.append(df_pref)
                                else:
                                    print(
                                        f"Embedding shape {emb_shape} for "
                                        f"{comb_tuple} does not correspond to "
                                        f"{len(ixs)} indices found for "
                                        f"{ID}-{ses}. Skipping...")
                                    continue
                            else:
                                print(
                                    f"data not found in {comb_tuple}. "
                                    f"Skipping...")
                                continue
                        else:
                            continue
        if len(dfs) == 0:
            return df_summary
        if len(coords_frames) > 0 and len(labels_frames) > 0:
            coords_frames_icc = pd.concat(coords_frames)
            labels_frames_icc = pd.concat(labels_frames)
            nodes = True
        else:
            nodes = False
        df_long = pd.concat(dfs, axis=0)
        # Keep only columns with >=75% coverage, then drop empty rows.
        df_long = df_long.dropna(axis='columns', thresh=0.75 * len(df_long))
        df_long = df_long.dropna(axis='rows', how='all')
        dict_sum = df_summary.drop(
            columns=['grid', 'modality', 'embedding', 'discriminability'
                     ]).to_dict()
        # One ICC per embedding dimension (column), skipping id/ses.
        for lp in [
                i for i in df_long.columns if 'ses' not in i and 'id' not in i
        ]:
            ix = int(lp.split(f"{alg}_")[1].split('_')[0])
            rsn = lp.split(f"{alg}_{ix}_")[1]
            df_long_clean = df_long[['id', 'ses', lp]]
            # (commented-out resampling/balancing experiment retained)
            # df_long_clean = df_long[['id', 'ses', lp]].loc[...]
            try:
                c_icc = pg.intraclass_corr(data=df_long_clean, targets='id',
                                           raters='ses', ratings=lp,
                                           nan_policy='omit').round(3)
                c_icc = c_icc.set_index("Type")
                c_icc3 = c_icc.drop(
                    index=['ICC1', 'ICC2', 'ICC1k', 'ICC2k', 'ICC3'])
                icc_val = c_icc3['ICC'].values[0]
                if nodes is True:
                    # Modal coordinate/label across sessions for this node.
                    coord_in = np.array(ast.literal_eval(
                        coords_frames_icc[f"{rsn}_{ix}"].mode().values[0]),
                        dtype=np.dtype("O"))
                    label_in = np.array(
                        labels_frames_icc[f"{rsn}_{ix}"].mode().values[0],
                        dtype=np.dtype("O"))
                else:
                    coord_in = np.nan
                    label_in = np.nan
                dict_sum[f"{lp}_icc"] = icc_val
                del c_icc, c_icc3, icc_val
            except BaseException:
                print(f"FAILED for {lp}...")
                # print(df_long)
                # df_summary.at[0, f"{lp}_icc"] = np.nan
                coord_in = np.nan
                label_in = np.nan
            dict_sum[f"{lp}_coord"] = coord_in
            dict_sum[f"{lp}_label"] = label_in
        df_summary = pd.concat(
            [df_summary, pd.DataFrame(pd.Series(dict_sum).T).T], axis=1)
        print(df_summary)
        # Flatten the grid tuple into a filesystem-safe name.
        tup_name = str(comb_tuple).replace('\', \'', '_').replace(
            '(', '').replace(')', '').replace('\'', '')
        df_summary.to_csv(f"{icc_tmps_dir}/{alg}_{tup_name}.csv",
                          index=False, header=True)
        del df_long
    # discriminability
    if disc is True:
        vect_all = []
        for ID in ids:
            try:
                out = gen_sub_vec(base_dir, par_dict, ID, modality, alg,
                                  comb_tuple)
            except BaseException:
                print(f"{ID} {modality} {alg} {comb_tuple} failed...")
                continue
            # print(out)
            vect_all.append(out)
        # ## TODO: Remove the .iloc below to include global efficiency.
        # vect_all = [pd.DataFrame(i).iloc[1:] for i in vect_all if i is not
        # None and not np.isnan(np.array(i)).all()]
        vect_all = [
            pd.DataFrame(i) for i in vect_all
            if i is not None and not np.isnan(np.array(i)).all()
        ]
        if len(vect_all) > 0:
            # NOTE(review): this inner len(vect_all) > 0 check is redundant
            # with the enclosing one — the else branch is unreachable.
            if len(vect_all) > 0:
                X_top = pd.concat(vect_all, axis=0, join="outer")
                X_top = np.array(
                    X_top.dropna(axis='columns', thresh=0.50 * len(X_top)))
            else:
                print('Empty dataframe!')
                return df_summary
            shapes = []
            for ix, i in enumerate(vect_all):
                shapes.append(i.shape[0] * [list(ids)[ix]])
            Y = np.array(list(flatten(shapes)))
            if alg == 'topology':
                imp = IterativeImputer(max_iter=50, random_state=42)
            else:
                imp = SimpleImputer()
            X_top = imp.fit_transform(X_top)
            scaler = StandardScaler()
            X_top = scaler.fit_transform(X_top)
            try:
                discr_stat_val, rdf = discr_stat(X_top, Y)
                df_summary.at[0, "discriminability"] = discr_stat_val
                print(discr_stat_val)
                print("\n")
                del discr_stat_val
            except BaseException:
                print('Discriminability calculation failed...')
                return df_summary
            # print(rdf)
            del vect_all
    gc.collect()
    return df_summary
def main():
    """Entry point: fine-tune or evaluate the BreastPathQ teacher-student model.

    In 'fine-tuning' mode, builds labeled/unlabeled/validation loaders,
    initializes teacher and student from an SSL checkpoint, partially
    freezes both, and trains with the student copied into the teacher
    each epoch. In 'evaluation' mode, loads a trained student and reports
    ICC agreement (model vs. each pathologist, and between pathologists)
    plus scatter and Bland-Altman plots.
    """
    # parse the args
    args = parse_args()

    # Set the data loaders (train, val, test)
    ### BreastPathQ ##################
    if args.mode == 'fine-tuning':
        # Train set
        transform_train = transforms.Compose([])  # None
        train_labeled_dataset = DatasetBreastPathQ_Supervised_train(
            args.train_image_pth, args.image_size, transform=transform_train)
        train_unlabeled_dataset = DatasetBreastPathQ_SSLtrain(
            args.train_image_pth,
            transform=TransformFix(args.image_size, args.NAug))

        # Validation set
        transform_val = transforms.Compose(
            [transforms.Resize(size=args.image_size)])
        val_dataset = DatasetBreastPathQ_SSLtrain(
            args.train_image_pth, transform=transform_val)

        # train and validation split
        num_train = len(train_labeled_dataset.datalist)
        indices = list(range(num_train))
        split = int(np.floor(args.validation_split * num_train))
        np.random.shuffle(indices)
        train_idx, val_idx = indices[split:], indices[:split]

        #### Semi-Supervised Split (10, 25, 50, 100)
        # Fraction of the training indices used as labeled data.
        labeled_train_idx = np.random.choice(
            train_idx, int(args.labeled_train * len(train_idx)))
        unlabeled_train_sampler = SubsetRandomSampler(train_idx)
        labeled_train_sampler = SubsetRandomSampler(labeled_train_idx)
        val_sampler = SubsetRandomSampler(val_idx)

        # Data loaders (shuffle must be False whenever a sampler is given;
        # the conditional below always yields False here).
        labeled_train_loader = torch.utils.data.DataLoader(
            train_labeled_dataset, batch_size=args.batch_size,
            sampler=labeled_train_sampler,
            shuffle=True if labeled_train_sampler is None else False,
            num_workers=args.num_workers, pin_memory=True, drop_last=True)
        unlabeled_train_loader = torch.utils.data.DataLoader(
            train_unlabeled_dataset, batch_size=args.batch_size*args.mu,
            sampler=unlabeled_train_sampler,
            shuffle=True if unlabeled_train_sampler is None else False,
            num_workers=args.num_workers, pin_memory=True, drop_last=True)
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size, sampler=val_sampler,
            shuffle=False, num_workers=args.num_workers, pin_memory=True,
            drop_last=False)

        # number of samples
        num_label_data = len(labeled_train_sampler)
        print('number of labeled training samples: {}'.format(num_label_data))
        num_unlabel_data = len(unlabeled_train_sampler)
        print('number of unlabeled training samples: {}'.format(
            num_unlabel_data))
        num_val_data = len(val_sampler)
        print('number of validation samples: {}'.format(num_val_data))

    elif args.mode == 'evaluation':
        # Test set
        test_transforms = transforms.Compose(
            [transforms.Resize(size=args.image_size)])
        test_dataset = DatasetBreastPathQ_eval(
            args.test_image_pth, args.image_size, test_transforms)
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True)
        # number of samples
        n_data = len(test_dataset)
        print('number of testing samples: {}'.format(n_data))
    else:
        raise NotImplementedError('invalid mode {}'.format(args.mode))
    ########################################

    # set the model
    if args.model == 'resnet18':
        model_teacher = net.TripletNet_Finetune(args.model)
        model_student = net.TripletNet_Finetune(args.model)
        classifier_teacher = net.FinetuneResNet(args.num_classes)
        classifier_student = net.FinetuneResNet(args.num_classes)

        if args.mode == 'fine-tuning':
            ###### Intialize both teacher and student network with
            ###### fine-tuned SSL model ###############
            # Load model
            state_dict = torch.load(args.model_path_finetune)
            # Load fine-tuned model
            model_teacher.load_state_dict(state_dict['model'])
            model_student.load_state_dict(state_dict['model'])
            # Load fine-tuned classifier
            classifier_teacher.load_state_dict(state_dict['classifier'])
            classifier_student.load_state_dict(state_dict['classifier'])

            ################# Freeze Teacher model (Entire network) #######
            # look at the contents of the teacher model and freeze it
            idx = 0
            for layer_name, param in model_teacher.named_parameters():
                print(layer_name, '-->', idx)
                idx += 1
            # Freeze the teacher model.
            # NOTE(review): `param = param[1]` rebinds the local to the
            # parameter tensor from the (name, tensor) tuple; the
            # requires_grad flag is set on the real tensor, so freezing
            # does take effect.
            for name, param in enumerate(model_teacher.named_parameters()):
                if name < args.modules_teacher:  # No of layers(modules) to be freezed
                    print("module", name, "was frozen")
                    param = param[1]
                    param.requires_grad = False
                else:
                    print("module", name, "was not frozen")
                    param = param[1]
                    param.requires_grad = True

            ############## Freeze Student model (Except last FC layer) ####
            # look at the contents of the student model and freeze it
            idx = 0
            for layer_name, param in model_student.named_parameters():
                print(layer_name, '-->', idx)
                idx += 1
            # Freeze the student model up to args.modules_student.
            for name, param in enumerate(model_student.named_parameters()):
                if name < args.modules_student:  # No of layers(modules) to be freezed
                    print("module", name, "was frozen")
                    param = param[1]
                    param.requires_grad = False
                else:
                    print("module", name, "was not frozen")
                    param = param[1]
                    param.requires_grad = True

        elif args.mode == 'evaluation':
            # Load fine-tuned model
            state = torch.load(args.model_path_eval)

            # create new OrderedDict that does not contain `module.`
            # (checkpoint was saved from a DataParallel-wrapped model)
            new_state_dict = OrderedDict()
            for k, v in state['model_student'].items():
                name = k[7:]  # remove `module.`
                new_state_dict[name] = v
            model_student.load_state_dict(new_state_dict)

            # create new OrderedDict that does not contain `module.`
            new_state_dict_cls = OrderedDict()
            for k, v in state['classifier_student'].items():
                name = k[7:]  # remove `module.`
                new_state_dict_cls[name] = v
            classifier_student.load_state_dict(new_state_dict_cls)
        else:
            raise NotImplementedError('invalid training {}'.format(args.mode))
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    # Load model to CUDA
    if torch.cuda.is_available():
        model_teacher = torch.nn.DataParallel(model_teacher)
        model_student = torch.nn.DataParallel(model_student)
        classifier_teacher = torch.nn.DataParallel(classifier_teacher)
        classifier_student = torch.nn.DataParallel(classifier_student)
        cudnn.benchmark = True

    # Optimiser & scheduler (only trainable student params are optimized)
    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad,
               list(model_student.parameters()) +
               list(classifier_student.parameters())),
        lr=args.lr, betas=(args.beta1, args.beta2),
        weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[30, 60], gamma=0.1)

    # Training Model
    start_epoch = 1
    prev_best_val_loss = float('inf')

    # Start log (writing into XL sheet)
    with open(os.path.join(args.save_loss, 'fine_tuned_results.csv'),
              'w') as f:
        f.write('epoch, train_loss, train_losses_x, train_losses_u, val_loss\n')

    # Routine
    for epoch in range(start_epoch, args.num_epoch + 1):
        if args.mode == 'fine-tuning':
            print("==> fine-tuning the pretrained SSL model...")
            time_start = time.time()
            train_losses, train_losses_x, train_losses_u, final_feats, \
                final_targets = train(args, model_teacher, model_student,
                                      classifier_teacher, classifier_student,
                                      labeled_train_loader,
                                      unlabeled_train_loader, optimizer,
                                      epoch)
            print('Epoch time: {:.2f} s.'.format(time.time() - time_start))

            print("==> validating the fine-tuned model...")
            val_losses = validate(args, model_student, classifier_student,
                                  val_loader, epoch)

            # Log results
            with open(os.path.join(args.save_loss,
                                   'fine_tuned_results.csv'), 'a') as f:
                f.write('%03d,%0.6f,%0.6f,%0.6f,%0.6f,\n' %
                        ((epoch + 1), train_losses, train_losses_x,
                         train_losses_u, val_losses))

            # NOTE(review): the bare string below is a no-op expression
            # statement (kept as-is), serving as an inline remark.
            'adjust learning rate --- Note that step should be called after validate()'
            scheduler.step()

            # Iterative training: Use the student as a teacher after every
            # epoch
            model_teacher = copy.deepcopy(model_student)
            classifier_teacher = copy.deepcopy(classifier_student)

            # Save model every 10 epochs
            if epoch % args.save_freq == 0:
                print('==> Saving...')
                state = {
                    'args': args,
                    'model_student': model_student.state_dict(),
                    'model_teacher': model_teacher.state_dict(),
                    'classifier_teacher': classifier_teacher.state_dict(),
                    'classifier_student': classifier_student.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch,
                    'train_loss': train_losses,
                    'train_losses_x': train_losses_x,
                    'train_losses_u': train_losses_u,
                }
                torch.save(state,
                           '{}/fine_CR_trained_model_{}.pt'.format(
                               args.model_save_pth, epoch))
                # help release GPU memory
                del state
                torch.cuda.empty_cache()

            # Save model for the best val
            # NOTE(review): bitwise `&` on two bools — works here, but
            # `and` is the conventional operator; confirm before changing.
            if (val_losses < prev_best_val_loss) & (epoch>1):
                print('==> Saving...')
                state = {
                    'args': args,
                    'model_student': model_student.state_dict(),
                    'model_teacher': model_teacher.state_dict(),
                    'classifier_teacher': classifier_teacher.state_dict(),
                    'classifier_student': classifier_student.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch,
                    'train_loss': train_losses,
                    'train_losses_x': train_losses_x,
                    'train_losses_u': train_losses_u,
                }
                torch.save(state,
                           '{}/best_CR_trained_model_{}.pt'.format(
                               args.model_save_pth, epoch))
                prev_best_val_loss = val_losses
                # help release GPU memory
                del state
                torch.cuda.empty_cache()

        elif args.mode == 'evaluation':
            print("==> testing final test data...")
            final_predicitions, final_feats, final_targetsA, \
                final_targetsB = test(args, model_student,
                                      classifier_student, test_loader)
            final_predicitions = final_predicitions.numpy()
            final_targetsA = final_targetsA.numpy()
            final_targetsB = final_targetsB.numpy()

            # BreastPathQ dataset #######
            # ICC: model (M) vs pathologist A, stacked long format.
            d = {'targets': np.hstack(
                [np.arange(1, len(final_predicitions) + 1, 1),
                 np.arange(1, len(final_predicitions) + 1, 1)]),
                 'raters': np.hstack(
                     [np.tile(np.array(['M']), len(final_predicitions)),
                      np.tile(np.array(['A']), len(final_predicitions))]),
                 'scores': np.hstack([final_predicitions, final_targetsA])}
            df = pd.DataFrame(data=d)
            iccA = pg.intraclass_corr(data=df, targets='targets',
                                      raters='raters', ratings='scores')
            iccA.to_csv(os.path.join(args.save_loss,
                                     'BreastPathQ_ICC_Eval_2way_MA.csv'))
            print(iccA)

            # ICC: model (M) vs pathologist B.
            d = {'targets': np.hstack(
                [np.arange(1, len(final_predicitions) + 1, 1),
                 np.arange(1, len(final_predicitions) + 1, 1)]),
                 'raters': np.hstack(
                     [np.tile(np.array(['M']), len(final_predicitions)),
                      np.tile(np.array(['B']), len(final_predicitions))]),
                 'scores': np.hstack([final_predicitions, final_targetsB])}
            df = pd.DataFrame(data=d)
            iccB = pg.intraclass_corr(data=df, targets='targets',
                                      raters='raters', ratings='scores')
            iccB.to_csv(os.path.join(args.save_loss,
                                     'BreastPathQ_ICC_Eval_2way_MB.csv'))
            print(iccB)

            # ICC: pathologist A vs pathologist B.
            d = {'targets': np.hstack(
                [np.arange(1, len(final_targetsA) + 1, 1),
                 np.arange(1, len(final_targetsB) + 1, 1)]),
                 'raters': np.hstack(
                     [np.tile(np.array(['A']), len(final_targetsA)),
                      np.tile(np.array(['B']), len(final_targetsB))]),
                 'scores': np.hstack([final_targetsA, final_targetsB])}
            df = pd.DataFrame(data=d)
            iccC = pg.intraclass_corr(data=df, targets='targets',
                                      raters='raters', ratings='scores')
            iccC.to_csv(os.path.join(args.save_loss,
                                     'BreastPathQ_ICC_Eval_2way_AB.csv'))
            print(iccC)

            # Plots
            fig, ax = plt.subplots()  # P1 vs automated
            ax.scatter(final_targetsA, final_predicitions,
                       edgecolors=(0, 0, 0))
            ax.plot([final_targetsA.min(), final_targetsA.max()],
                    [final_targetsA.min(), final_targetsA.max()],
                    'k--', lw=2)
            ax.set_xlabel('Pathologist1')
            ax.set_ylabel('Automated Method')
            plt.savefig(os.path.join(args.save_loss,
                                     'BreastPathQ_Eval_2way_MA_plot.png'),
                        dpi=300)
            plt.show()

            fig, ax = plt.subplots()  # P2 vs automated
            ax.scatter(final_targetsB, final_predicitions,
                       edgecolors=(0, 0, 0))
            ax.plot([final_targetsB.min(), final_targetsB.max()],
                    [final_targetsB.min(), final_targetsB.max()],
                    'k--', lw=2)
            ax.set_xlabel('Pathologist2')
            ax.set_ylabel('Automated Method')
            plt.savefig(os.path.join(args.save_loss,
                                     'BreastPathQ_Eval_2way_MB_plot.png'),
                        dpi=300)
            plt.show()

            fig, ax = plt.subplots()  # P1 vs P2
            ax.scatter(final_targetsA, final_targetsB, edgecolors=(0, 0, 0))
            ax.plot([final_targetsA.min(), final_targetsA.max()],
                    [final_targetsA.min(), final_targetsA.max()],
                    'k--', lw=2)
            ax.set_xlabel('Pathologist1')
            ax.set_ylabel('Pathologist2')
            plt.savefig(os.path.join(args.save_loss,
                                     'BreastPathQ_Eval_2way_AB_plot.png'),
                        dpi=300)
            plt.show()

            # Bland altman plot
            fig, ax = plt.subplots(1, figsize=(8, 8))
            sm.graphics.mean_diff_plot(final_targetsA, final_predicitions,
                                       ax=ax)
            plt.savefig(os.path.join(args.save_loss,
                                     'BDPlot_Eval_2way_MA_plot.png'),
                        dpi=300)
            plt.show()

            fig, ax = plt.subplots(1, figsize=(8, 8))
            sm.graphics.mean_diff_plot(final_targetsB, final_predicitions,
                                       ax=ax)
            plt.savefig(os.path.join(args.save_loss,
                                     'BDPlot_Eval_2way_MB_plot.png'),
                        dpi=300)
            plt.show()

            fig, ax = plt.subplots(1, figsize=(8, 8))
            sm.graphics.mean_diff_plot(final_targetsA, final_targetsB,
                                       ax=ax)
            plt.savefig(os.path.join(args.save_loss,
                                     'BDPlot_Eval_2way_AB_plot.png'),
                        dpi=300)
            plt.show()
        else:
            raise NotImplementedError('mode not supported {}'.format(
                args.mode))