def init_data_processing():
    """Build the train/val transform pipelines and store them in the global KVS.

    Estimates per-channel mean/std over the training dataset for normalization,
    composes the training and validation torchvision pipelines, and persists
    the session state to ``session.pkl``.
    """
    kvs = GlobalKVS()
    augs = init_train_augs()

    ds = OAProgressionDataset(dataset=kvs['args'].dataset_root,
                              split=kvs['metadata'],
                              trf=augs)

    # Per-channel statistics estimated (or loaded from cache) for normalization.
    mean_vec, std_vec = init_mean_std(snapshots_dir=kvs['args'].snapshots,
                                      dataset=ds,
                                      batch_size=kvs['args'].bs,
                                      n_threads=kvs['args'].n_threads)

    print(colored('====> ', 'red') + 'Mean:', mean_vec)
    print(colored('====> ', 'red') + 'Std:', std_vec)

    normalization = tv_transforms.Normalize(mean_vec.tolist(), std_vec.tolist())

    # Training: augmentations followed by normalization of the image (tuple index 0).
    train_pipeline = tv_transforms.Compose([
        augs,
        partial(apply_by_index, transform=normalization, idx=0)
    ])

    # Validation: deterministic resize / center-crop / grayscale-to-RGB,
    # then tensor conversion and normalization of the image (tuple index 0).
    val_pipeline = tv_transforms.Compose([
        img_labels2solt,
        slc.Stream([
            slt.ResizeTransform((310, 310)),
            slt.CropTransform(crop_size=(300, 300), crop_mode='c'),
            slt.ImageColorTransform(mode='gs2rgb'),
        ], interpolation='bicubic'),
        unpack_solt_data,
        partial(apply_by_index, transform=tv_transforms.ToTensor(), idx=0),
        partial(apply_by_index, transform=normalization, idx=0)
    ])

    kvs.update('train_trf', train_pipeline)
    kvs.update('val_trf', val_pipeline)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
def save_checkpoint(model, val_metric_name, comparator='lt'):
    """Save the model weights when the tracked validation metric improves.

    The first call always saves; subsequent calls save only when
    ``comparator(current, best)`` holds (``comparator`` names an
    ``operator`` function, e.g. ``'lt'`` or ``'gt'``), deleting the
    previously saved snapshot. The KVS session is persisted either way.

    Args:
        model: the network to checkpoint; unwrapped if it is a DataParallel.
        val_metric_name: key of the metric inside the latest fold metrics dict.
        comparator: name of an ``operator`` module function used to compare
            the current metric against the best one seen so far.
    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    kvs = GlobalKVS()
    fold_id = kvs['cur_fold']
    epoch = kvs['cur_epoch']
    val_metric = kvs[f'val_metrics_fold_[{fold_id}]'][-1][0][val_metric_name]
    comparator = getattr(operator, comparator)
    cur_snapshot_name = os.path.join(kvs['args'].snapshots, kvs['snapshot_name'],
                                     f'fold_{fold_id}_epoch_{epoch + 1}.pth')

    # The original duplicated the save logic across both branches; it is
    # merged here — behavior is identical.
    no_previous = kvs['prev_model'] is None
    if no_previous or comparator(val_metric, kvs['best_val_metric']):
        print(colored('====> ', 'red') + 'Snapshot was saved to', cur_snapshot_name)
        if not no_previous:
            # Keep only the best snapshot on disk.
            os.remove(kvs['prev_model'])
        torch.save(model.state_dict(), cur_snapshot_name)
        kvs.update('prev_model', cur_snapshot_name)
        kvs.update('best_val_metric', val_metric)

    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
def init_session(snapshot_name_prefix=None):
    """Initialize a GPU training session: args, seeds, snapshot dir, env info.

    Args:
        snapshot_name_prefix: optional prefix forwarded to ``init_snapshot_dir``.

    Returns:
        Tuple of (parsed args, snapshot name stored in the KVS).

    Raises:
        EnvironmentError: when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise EnvironmentError('The code must be run on GPU.')

    kvs = GlobalKVS()

    # Getting the arguments
    args = parse_args()
    kvs.update('args', args)

    # Initializing the seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Creating the snapshot
    init_snapshot_dir(snapshot_name_prefix)

    kvs.update('pytorch_version', torch.__version__)
    # CUDA availability is guaranteed by the guard above, so the previous
    # `else` branch that stored None for 'cuda'/'gpus' was unreachable
    # and has been removed.
    kvs.update('cuda', torch.version.cuda)
    kvs.update('gpus', torch.cuda.device_count())

    kvs.save_pkl(os.path.join(args.snapshots, kvs['snapshot_name'], 'session.pkl'))
    return args, kvs['snapshot_name']
def log_metrics_prog(boardlogger, train_loss, val_loss, gt_progression, preds_progression, gt_kl, preds_kl):
    """Print, tensorboard-log, and persist validation metrics for progression/KL.

    Args:
        boardlogger: TensorBoard SummaryWriter-like logger.
        train_loss: training loss for the current epoch.
        val_loss: validation loss for the current epoch.
        gt_progression / preds_progression: ground truth / predictions for progression.
        gt_kl / preds_kl: ground truth / predictions for KL grade.
    """
    kvs = GlobalKVS()

    res = testtools.calc_metrics(gt_progression, gt_kl, preds_progression, preds_kl)
    # BUGFIX: a stray trailing comma previously stored `val_loss` as a 1-tuple.
    res['val_loss'] = val_loss
    res['epoch'] = kvs['cur_epoch']

    print(colored('====> ', 'green') + f'Train loss: {train_loss:.5f}')
    print(colored('====> ', 'green') + f'Validation loss: {val_loss:.5f}')
    print(colored('====> ', 'green') + f'Validation AUC [prog]: {res["auc_prog"]:.5f}')
    print(colored('====> ', 'green') + f'Validation F1 @ 0.3 [prog]: {res["f1_score_03_prog"]:.5f}')
    print(colored('====> ', 'green') + f'Validation F1 @ 0.4 [prog]: {res["f1_score_04_prog"]:.5f}')
    print(colored('====> ', 'green') + f'Validation F1 @ 0.5 [prog]: {res["f1_score_05_prog"]:.5f}')
    print(colored('====> ', 'green') + f'Validation AP [prog]: {res["ap_prog"]:.5f}')
    print(colored('====> ', 'green') + f'Validation AUC [oa]: {res["auc_oa"]:.5f}')
    print(colored('====> ', 'green') + f'Kappa [oa]: {res["kappa_kl"]:.5f}')

    boardlogger.add_scalars('Losses', {'train': train_loss, 'val': val_loss}, kvs['cur_epoch'])
    boardlogger.add_scalars('AUC progression', {'val': res['auc_prog']}, kvs['cur_epoch'])
    boardlogger.add_scalars('F1-score @ 0.3 progression', {'val': res['f1_score_03_prog']}, kvs['cur_epoch'])
    boardlogger.add_scalars('F1-score @ 0.4 progression', {'val': res['f1_score_04_prog']}, kvs['cur_epoch'])
    boardlogger.add_scalars('F1-score @ 0.5 progression', {'val': res['f1_score_05_prog']}, kvs['cur_epoch'])
    boardlogger.add_scalars('Average Precision progression', {'val': res['ap_prog']}, kvs['cur_epoch'])

    kvs.update(f'losses_fold_[{kvs["cur_fold"]}]',
               {'epoch': kvs['cur_epoch'], 'train_loss': train_loss, 'val_loss': val_loss})
    kvs.update(f'val_metrics_fold_[{kvs["cur_fold"]}]', res)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
def init_age_sex_bmi_metadata():
    """Load OAI (train) / MOST (test) metadata joined with clinical variables,
    build a grouped 5-fold CV split, and persist the session.
    """
    kvs = GlobalKVS()
    meta_root = kvs['args'].metadata_root

    # OAI: progression labels joined with participant-level clinical data.
    oai_meta = pd.read_csv(os.path.join(meta_root, 'OAI_progression.csv'))
    clinical_data_oai = pd.read_csv(os.path.join(meta_root, 'OAI_participants.csv'))
    oai_meta = pd.merge(oai_meta, clinical_data_oai, on=('ID', 'Side'))
    # Keep only rows with all three clinical variables present.
    complete_rows = ~oai_meta.BMI.isna() & ~oai_meta.AGE.isna() & ~oai_meta.SEX.isna()
    oai_meta = oai_meta[complete_rows]

    # MOST: test set, same join with its participant data.
    clinical_data_most = pd.read_csv(os.path.join(meta_root, 'MOST_participants.csv'))
    metadata_test = pd.read_csv(os.path.join(meta_root, 'MOST_progression.csv'))
    metadata_test = pd.merge(metadata_test, clinical_data_most, on=('ID', 'Side'))

    kvs.update('metadata', oai_meta)
    kvs.update('metadata_test', metadata_test)

    # Grouped split on subject ID (string-cast) so each subject stays in one fold.
    splitter = GroupKFold(n_splits=5)
    cv_split = list(splitter.split(kvs['metadata'],
                                   kvs['metadata'][kvs['args'].target_var],
                                   kvs['metadata']['ID'].astype(str)))
    kvs.update('cv_split_all_folds', cv_split)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
def init_folds(project='OA_progression'):
    """Initialize per-fold loss/metric slots, training splits, and TensorBoard writers.

    Honors ``--fold``: when it is not -1, only that single fold is set up.

    Args:
        project: sub-directory name under the logs root.

    Returns:
        Dict mapping fold id to its SummaryWriter.
    """
    kvs = GlobalKVS()
    writers = {}
    cv_split_train = {}
    requested_fold = kvs['args'].fold

    for fold_id, split in enumerate(kvs['cv_split_all_folds']):
        # -1 means "train all folds"; otherwise skip everything else.
        if requested_fold != -1 and fold_id != requested_fold:
            continue
        kvs.update(f'losses_fold_[{fold_id}]', None, list)
        kvs.update(f'val_metrics_fold_[{fold_id}]', None, list)
        cv_split_train[fold_id] = split
        writers[fold_id] = SummaryWriter(os.path.join(kvs['args'].logs,
                                                      project,
                                                      f'fold_{fold_id}',
                                                      kvs['snapshot_name']))

    kvs.update('cv_split_train', cv_split_train)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))
    return writers
def init_progression_metadata():
    """Load OAI (train) and MOST (test) progression metadata, optionally
    sub-sample the training set preserving the progressor prevalence,
    build a grouped 5-fold CV split, and persist the session.
    """
    # We should get rid of non-progressors from MOST because we can check
    # non-progressors only up to 84 months
    kvs = GlobalKVS()
    most_meta = pd.read_csv(os.path.join(kvs['args'].metadata_root, 'MOST_progression.csv'))
    oai_meta = pd.read_csv(os.path.join(kvs['args'].metadata_root, 'OAI_progression.csv'))

    # Optional training-set sub-sampling; -1 disables it.
    if kvs['args'].subsample_train != -1:
        n_train = oai_meta.shape[0]
        # Fraction of progressor knees in the full train set; kept in the sub-sample.
        prevalence = (oai_meta.Progressor > 0).sum() / n_train
        sample_pos = int(kvs['args'].subsample_train * prevalence)
        sample_neg = kvs['args'].subsample_train - sample_pos
        train_pos = oai_meta[oai_meta.Progressor > 0]
        train_neg = oai_meta[oai_meta.Progressor == 0]
        # NOTE(review): np.random.choice samples WITH replacement here, so the
        # sub-sample (and the "shuffle" below) can contain duplicate rows and
        # drop others — confirm this bootstrap-style sampling is intended.
        pos_sampled = train_pos.iloc[np.random.choice(train_pos.shape[0], sample_pos)]
        neg_sampled = train_neg.iloc[np.random.choice(train_neg.shape[0], sample_neg)]
        new_meta = pd.concat((pos_sampled, neg_sampled))
        # Re-draws all rows to mix positives and negatives (see NOTE above).
        oai_meta = new_meta.iloc[np.random.choice(new_meta.shape[0], new_meta.shape[0])]
        print(colored("==> ", 'red') + f"Train set has been sub-sampled. New # pos/neg {sample_pos}/{sample_neg}")

    kvs.update('metadata', oai_meta)
    kvs.update('metadata_test', most_meta)

    # Grouped split on subject ID (string-cast) so each subject stays in one fold.
    gkf = GroupKFold(n_splits=5)
    cv_split = [x for x in gkf.split(kvs['metadata'], kvs['metadata']['Progressor'], kvs['metadata']['ID'].astype(str))]
    kvs.update('cv_split_all_folds', cv_split)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))

    print(colored("==> ", 'green') + f"Train dataset has {(kvs['metadata'].Progressor == 0).sum()} non-progressed knees")
    print(colored("==> ", 'green') + f"Train dataset has {(kvs['metadata'].Progressor > 0).sum()} progressed knees")
    print(colored("==> ", 'green') + f"Test dataset has {(kvs['metadata_test'].Progressor == 0).sum()} non-progressed knees")
    print(colored("==> ", 'green') + f"Test dataset has {(kvs['metadata_test'].Progressor > 0).sum()} progressed knees")
def init_progression_metadata():
    """Load OAI (train) and MOST (test) progression metadata, build a grouped
    5-fold CV split, persist the session, and report class counts.
    """
    kvs = GlobalKVS()
    meta_root = kvs['args'].metadata_root

    # NOTE: we should get rid of non-progressors from MOST because we can
    # check non-progressors only up to 84 months.
    most_meta = pd.read_csv(os.path.join(meta_root, 'MOST_progression.csv'))
    oai_meta = pd.read_csv(os.path.join(meta_root, 'OAI_progression.csv'))

    kvs.update('metadata', oai_meta)
    kvs.update('metadata_test', most_meta)

    # Grouped split on subject ID (string-cast) so each subject stays in one fold.
    splitter = GroupKFold(n_splits=5)
    folds = list(splitter.split(kvs['metadata'],
                                kvs['metadata']['Progressor'],
                                kvs['metadata']['ID'].astype(str)))
    kvs.update('cv_split_all_folds', folds)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))

    n_train_neg = (kvs['metadata'].Progressor == 0).sum()
    n_train_pos = (kvs['metadata'].Progressor > 0).sum()
    n_test_neg = (kvs['metadata_test'].Progressor == 0).sum()
    n_test_pos = (kvs['metadata_test'].Progressor > 0).sum()
    print(colored("==> ", 'green') + f"Train dataset has {n_train_neg} non-progressed knees")
    print(colored("==> ", 'green') + f"Train dataset has {n_train_pos} progressed knees")
    print(colored("==> ", 'green') + f"Test dataset has {n_test_neg} non-progressed knees")
    print(colored("==> ", 'green') + f"Test dataset has {n_test_pos} progressed knees")
def init_session():
    """Initialize a GPU training session: args, seeds, snapshot dir, git/env info.

    Returns:
        Tuple of (parsed args, timestamped snapshot name).

    Raises:
        EnvironmentError: when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise EnvironmentError('The code must be run on GPU.')

    kvs = GlobalKVS()

    # Command-line arguments.
    args = parse_args()

    # Seed everything for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Timestamped snapshot directory.
    snapshot_name = time.strftime('%Y_%m_%d_%H_%M')
    os.makedirs(os.path.join(args.snapshots, snapshot_name), exist_ok=True)

    # Record git state when available.
    res = git_info()
    branch, commit = (res[0], res[1]) if res is not None else (None, None)
    kvs.update('git branch name', branch)
    kvs.update('git commit id', commit)

    kvs.update('pytorch_version', torch.__version__)
    if torch.cuda.is_available():
        kvs.update('cuda', torch.version.cuda)
        kvs.update('gpus', torch.cuda.device_count())
    else:
        # Unreachable given the guard above; kept for parity with the original.
        kvs.update('cuda', None)
        kvs.update('gpus', None)

    kvs.update('snapshot_name', snapshot_name)
    kvs.update('args', args)
    kvs.save_pkl(os.path.join(args.snapshots, snapshot_name, 'session.pkl'))
    return args, snapshot_name
def log_metrics_age_sex_bmi(boardlogger, train_loss, val_res):
    """Print, tensorboard-log, and persist validation metrics for age/sex/BMI models.

    Args:
        boardlogger: TensorBoard SummaryWriter-like logger.
        train_loss: training loss for the current epoch.
        val_res: validation outputs; layout depends on kvs['args'].predict_age_sex_bmi:
            single-target mode: (val_loss, ids, gt, preds)
            multi-target mode:  (val_loss, ids, gt_age, preds_age,
                                 gt_sex, preds_sex, gt_bmi, preds_bmi)
    """
    kvs = GlobalKVS()
    res = dict()
    val_loss = val_res[0]
    # BUGFIX: a stray trailing comma previously stored `val_loss` as a 1-tuple.
    res['val_loss'] = val_loss
    res['epoch'] = kvs['cur_epoch']

    print(colored('====> ', 'green') + f'Train loss: {train_loss:.5f}')
    print(colored('====> ', 'green') + f'Validation loss: {val_loss:.5f}')
    boardlogger.add_scalars('Losses', {'train': train_loss, 'val': val_loss}, kvs['cur_epoch'])

    if not kvs['args'].predict_age_sex_bmi:
        _, ids, gt, preds = val_res
        if kvs['args'].target_var == 'SEX':
            # Sex is binary -> AUC.
            val_auc = roc_auc_score(gt.astype(int), preds)
            res['sex_auc'] = val_auc
            print(colored('====> ', 'green') + f'Validation AUC: {val_auc:.5f}')
            boardlogger.add_scalars('AUC sex', {'val': res['sex_auc']}, kvs['cur_epoch'])
        else:
            # Age / BMI are continuous -> MSE and median absolute error.
            val_mse = mean_squared_error(gt, preds)
            val_mae = median_absolute_error(gt, preds)
            res[f"{kvs['args'].target_var}_mse"] = val_mse
            res[f"{kvs['args'].target_var}_mae"] = val_mae
            print(colored('====> ', 'green') + f'Validation mae: {val_mae:.5f}')
            print(colored('====> ', 'green') + f'Validation mse: {val_mse:.5f}')
            boardlogger.add_scalars(f"MSE [{kvs['args'].target_var}]", {'val': val_mse}, kvs['cur_epoch'])
            boardlogger.add_scalars(f"MAE [{kvs['args'].target_var}]", {'val': val_mae}, kvs['cur_epoch'])
    else:
        # Joint prediction of all three targets.
        _, ids, gt_age, preds_age, gt_sex, preds_sex, gt_bmi, preds_bmi = val_res
        val_mse_age = mean_squared_error(gt_age, preds_age)
        val_mae_age = median_absolute_error(gt_age, preds_age)
        val_sex_auc = roc_auc_score(gt_sex.astype(int), preds_sex)
        val_mse_bmi = mean_squared_error(gt_bmi, preds_bmi)
        val_mae_bmi = median_absolute_error(gt_bmi, preds_bmi)
        res["AGE_mse"] = val_mse_age
        res["AGE_mae"] = val_mae_age
        res["BMI_mse"] = val_mse_bmi
        res["BMI_mae"] = val_mae_bmi
        res["SEX_auc"] = val_sex_auc
        print(colored('====> ', 'green') + f'Validation mae [Age]: {val_mae_age:.5f}')
        print(colored('====> ', 'green') + f'Validation mse [Age]: {val_mse_age:.5f}')
        print(colored('====> ', 'green') + f'Validation val_auc [Sex]: {val_sex_auc:.5f}')
        print(colored('====> ', 'green') + f'Validation mae [BMI]: {val_mae_bmi:.5f}')
        print(colored('====> ', 'green') + f'Validation mse [BMI]: {val_mse_bmi:.5f}')

    kvs.update(f'losses_fold_[{kvs["cur_fold"]}]',
               {'epoch': kvs['cur_epoch'], 'train_loss': train_loss, 'val_loss': val_loss})
    kvs.update(f'val_metrics_fold_[{kvs["cur_fold"]}]', res)
    kvs.save_pkl(os.path.join(kvs['args'].snapshots, kvs['snapshot_name'], 'session.pkl'))