def test(config, fold, model, loader, dir_to_load, dir_confusion):
    """Evaluate ``model`` on ``loader`` with the checkpoint found for ``fold``.

    The evaluation helper is chosen from the ``fst`` flags, checked in this
    order: cropped input, translation ensemble, MC dropout, Bayesian; if none
    is set the plain ``ut.eval_classification_model`` is used.

    Args:
        config: run configuration forwarded to the ``ut.eval_*`` helper.
        fold: fold index, used to locate the checkpoint.
        model: network to evaluate; weights are loaded into it in place.
        loader: test data loader.
        dir_to_load: directory searched by ``ut.model_dir_to_load``.
        dir_confusion: not used here; kept so existing call sites still work.

    Returns:
        The result dictionary returned by the selected ``ut.eval_*`` helper.
    """
    # Release cached, currently unused GPU memory before evaluating.
    torch.cuda.empty_cache()

    # Alternatives tried previously: nn.CrossEntropyLoss(), ut.FocalLoss(...).
    criterion_cls = nn.BCELoss()

    # Load the checkpoint when one exists; otherwise the weights currently held
    # by ``model`` are evaluated as-is.
    model_dir = ut.model_dir_to_load(fold, dir_to_load)
    if model_dir is not None:
        model.load_state_dict(torch.load(model_dir))
    model.eval()

    if fst.flag_eval_cropping:
        return ut.eval_classification_model_cropped_input(
            config, fold, loader, model, criterion_cls)
    if fst.flag_eval_translation:
        return ut.eval_classification_model_esemble(
            config, fold, loader, model, criterion_cls)
    if fst.flag_MC_dropout:
        return ut.eval_classification_model_MC_dropout(
            config, fold, loader, model, criterion_cls)
    if fst.flag_bayesian:
        return ut.eval_classification_model_bayesian(
            config, fold, loader, model, criterion_cls)
    return ut.eval_classification_model(
        config, fold, loader, model, criterion_cls, flag_heatmap=False)
def test(config, fold, model, loader, dir_to_load, dir_confusion):
    """Evaluate a multi-task model on ``loader`` with the checkpoint for ``fold``.

    Args:
        config: run configuration forwarded to ``ut.eval_multi_task_model``.
        fold: fold index, used to locate the checkpoint.
        model: network to evaluate; weights are loaded into it in place.
        loader: test data loader.
        dir_to_load: directory searched by ``ut.model_dir_to_load``.
        dir_confusion: not used here; kept so existing call sites still work.

    Returns:
        The result dictionary returned by ``ut.eval_multi_task_model``.
    """
    # Release cached, currently unused GPU memory before evaluating.
    torch.cuda.empty_cache()

    # Classification criterion (alternatives tried previously:
    # nn.CrossEntropyLoss(), ut.FocalLoss(...)) plus a summed L1 criterion,
    # both handed to the multi-task evaluation helper.
    criterion_cls = nn.BCELoss()
    criterion_L1 = nn.L1Loss(reduction='sum').cuda()

    # Load the checkpoint when one exists; otherwise the weights currently held
    # by ``model`` are evaluated as-is.
    model_dir = ut.model_dir_to_load(fold, dir_to_load)
    if model_dir is not None:
        model.load_state_dict(torch.load(model_dir))
    model.eval()

    return ut.eval_multi_task_model(
        config, fold, loader, model, criterion_cls, criterion_L1)
def test(config, fold, model_1, model_2, loader, dir_to_load, dir_confusion):
    """Evaluate ``model_1`` (with helper ``model_2``) on ``loader``.

    Only ``model_1`` is (re)loaded from the checkpoint for ``fold``;
    ``model_2`` is passed through untouched, so the caller is expected to have
    already loaded its weights.

    Args:
        config: run configuration forwarded to ``ut.eval_classification_model_2``.
        fold: fold index, used to locate the checkpoint.
        model_1: network to evaluate; weights are loaded into it in place.
        model_2: second network handed to the evaluation helper as-is.
        loader: test data loader.
        dir_to_load: directory searched by ``ut.model_dir_to_load``.
        dir_confusion: not used here; kept so existing call sites still work.

    Returns:
        The result dictionary returned by ``ut.eval_classification_model_2``.
    """
    # Release cached, currently unused GPU memory before evaluating.
    torch.cuda.empty_cache()

    # Alternatives tried previously: nn.CrossEntropyLoss(), ut.FocalLoss(...),
    # nn.MSELoss(reduction='mean').
    criterion_cls = nn.BCELoss()

    # Load the checkpoint when one exists; otherwise the weights currently held
    # by ``model_1`` are evaluated as-is.
    model_dir = ut.model_dir_to_load(fold, dir_to_load)
    if model_dir is not None:
        model_1.load_state_dict(torch.load(model_dir))
    model_1.eval()

    return ut.eval_classification_model_2(
        config, fold, loader, model_1, model_2, criterion_cls)
def get_heatmap_1class(config, fold, model, dir_to_load, dir_heatmap):
    """Build a dense ``predMap`` heatmap for every test batch of ``fold`` and save it.

    Each volume is zero-padded by ``st.patch_size // 2`` on every spatial side.
    The model is then run on ``st.patch_stride ** 3`` shifted crops of it. The
    run with offset ``(i, j, k)`` writes its output cell ``(a, b, c)`` to voxel
    ``(a*stride + i, b*stride + j, c*stride + k)`` of the heatmap.

    Args:
        config: run configuration; ``config.v_batch_size`` sets the batch size.
        fold: fold index used to select the test split and the checkpoint.
        model: network whose output dict has a ``'predMap'`` entry.
        dir_to_load: directory searched by ``ut.model_dir_to_load``.
        dir_heatmap: root directory; results go to ``<dir_heatmap>/fold_<fold>``.
    """
    # Release cached, currently unused GPU memory before the sliding-window run.
    torch.cuda.empty_cache()

    # Test split of this fold.
    list_test_data = DL.concat_class_of_interest(
        config, fold, list_class=st.list_class_for_test, flag_tr_val_te='test')
    test_loader = DL.convert_Dloader_3(
        config.v_batch_size, list_test_data[0], list_test_data[1],
        list_test_data[2], list_test_data[3],
        is_training=False, num_workers=0, shuffle=False)

    # Load the checkpoint when one exists (same fallback as the test routines);
    # otherwise the weights currently held by ``model`` are used.
    model_dir = ut.model_dir_to_load(fold, dir_to_load)
    if model_dir is not None:
        model.load_state_dict(torch.load(model_dir))
    model.eval()

    patch_size = st.patch_size
    stride = st.patch_stride
    pad = nn.ConstantPad3d(patch_size // 2, 0)
    tmp_save_dir = dir_heatmap + '/fold_{0}'.format(fold)

    count = 0
    with torch.no_grad():
        for datas, _labels, _alabels, _mlabel in test_loader:
            count += 1
            # ``datas`` stays unpadded; it is what gets saved as the input
            # image, so input and heatmap always come from the same batch.
            pred_Map = np.zeros_like(datas)
            padded = pad(datas)

            predMap = None
            for i in range(stride):
                print("i : {0}".format(i))
                for j in range(stride):
                    for k in range(stride):
                        predMap = model(padded[:, :, i:, j:, k:].cuda())['predMap']
                        if predMap is None:
                            continue
                        # One strided slice assignment (a single device->host
                        # copy) instead of one copy per output voxel.
                        n_a, n_b, n_c = predMap.shape[-3:]
                        pred_Map[:, 0,
                                 i:i + n_a * stride:stride,
                                 j:j + n_b * stride:stride,
                                 k:k + n_c * stride:stride] = \
                            predMap[:, 0].cpu().numpy()
            torch.cuda.empty_cache()
            print("finished a sample!")

            for sample in range(pred_Map.shape[0]):
                ut.make_dir(dir=tmp_save_dir, flag_rm=False)
                # NOTE(review): names are keyed by the batch counter only, so with
                # v_batch_size > 1 later samples of a batch presumably overwrite
                # earlier ones - confirm against the ut.save_* helpers.
                ut.save_featureMap_tensor(
                    datas[sample][0], tmp_save_dir, 'input_{}'.format(count))
                if predMap is not None:
                    ut.save_featureMap_numpy(
                        pred_Map[sample][0], tmp_save_dir, 'pred_map_{}'.format(count))
                    ut.plot_heatmap_with_overlay(
                        orig_img=datas[sample][0], heatmap_img=pred_Map[sample][0],
                        save_dir=tmp_save_dir + '/1_logit_map',
                        fig_title='Original Logit Map', thresh=0.2, percentile=1)
def get_multi_heatmap_2class(config, fold, model, dir_to_load, dir_heatmap):
    """Build and save three-scale attention / final-evidence maps for the test split.

    For every test batch:

    * the whole volume is classified (softmax + argmax over ``'logits'``) to
      label the saved files with ground truth and prediction;
    * for each patch size (9, 17, 33) the volume is zero-padded by
      ``patch_size // 2``, and the model is run on ``st.patch_stride ** 3``
      shifted crops;
    * channel 0 of ``attn_1/2/3`` and ``final_evidence_a/b/c`` is scattered
      onto a stride-interleaved full-size map and saved with an overlay plot.

    Args:
        config: run configuration; ``config.v_batch_size`` sets the batch size.
        fold: fold index used to select the test split and the checkpoint.
        model: network whose output dict has ``'logits'``, ``'attn_1'``..``'attn_3'``
            and ``'final_evidence_a'``..``'final_evidence_c'`` entries.
        dir_to_load: directory searched by ``ut.model_dir_to_load``.
        dir_heatmap: root directory; results go to ``<dir_heatmap>/fold_<fold>``.
    """
    # Release cached, currently unused GPU memory before the sliding-window run.
    torch.cuda.empty_cache()

    # Test split of this fold.
    list_test_data = DL.concat_class_of_interest(
        config, fold, list_class=st.list_class_for_test, flag_tr_val_te='test')
    test_loader = DL.convert_Dloader_3(
        config.v_batch_size, list_test_data[0], list_test_data[1],
        list_test_data[2], list_test_data[3],
        is_training=False, num_workers=0, shuffle=False)

    # Load the checkpoint when one exists; otherwise use the current weights.
    model_dir = ut.model_dir_to_load(fold, dir_to_load)
    if model_dir is not None:
        model.load_state_dict(torch.load(model_dir))
    model.eval()

    patch_size_1 = 9
    patch_size_2 = 17
    patch_size_3 = 33
    stride_between_patches = st.patch_stride
    pads = [nn.ConstantPad3d(p // 2, 0)
            for p in (patch_size_1, patch_size_2, patch_size_3)]
    attn_keys = ('attn_1', 'attn_2', 'attn_3')
    evidence_keys = ('final_evidence_a', 'final_evidence_b', 'final_evidence_c')
    tmp_save_dir = dir_heatmap + '/fold_{0}'.format(fold)

    def _scatter(dst, src, offset, grid=None):
        """Copy ``src`` (batch, A, B, C) into ``dst[:, 0]`` at ``offset + stride * (a, b, c)``.

        ``grid`` limits the copied cells (defaults to ``src``'s own grid) so
        that maps from one run can share a reference grid.
        """
        n_a, n_b, n_c = src.shape[-3:] if grid is None else grid
        i, j, k = offset
        s = stride_between_patches
        dst[:, 0, i:i + n_a * s:s, j:j + n_b * s:s, k:k + n_c * s:s] = \
            src[:, :n_a, :n_b, :n_c]

    count = 0
    with torch.no_grad():
        for datas, labels, _alabel, _mlabel in test_loader:
            count += 1

            # Whole-volume prediction, used only to label the saved files.
            dict_result = model(datas.cuda())
            prob = nn.Softmax(dim=1)(dict_result['logits'])
            pred = prob.argmax(dim=1, keepdim=True)  # batch, 1

            padded = [pad(datas) for pad in pads]
            attn_maps = [np.zeros_like(datas) for _ in attn_keys]
            evidence_maps = [np.zeros_like(datas) for _ in evidence_keys]
            # Outputs of the most recent shifted run per scale; whether they are
            # None decides which maps get saved below.
            last_attn = [None] * len(attn_keys)
            last_evidence = [None] * len(evidence_keys)

            for i in range(stride_between_patches):
                print("i : {0}".format(i))
                for j in range(stride_between_patches):
                    for k in range(stride_between_patches):
                        for scale, data_s in enumerate(padded):
                            dict_result = model(data_s[:, :, i:, j:, k:].cuda())
                            attn = dict_result[attn_keys[scale]]
                            evidence = dict_result[evidence_keys[scale]]
                            last_attn[scale] = attn
                            last_evidence[scale] = evidence
                            # The evidence grid is the reference grid for both maps.
                            grid = evidence.shape[-3:] if evidence is not None else None
                            if attn is not None:
                                _scatter(attn_maps[scale], attn[:, 0].cpu().numpy(),
                                         (i, j, k), grid)
                            if evidence is not None:
                                _scatter(evidence_maps[scale], evidence[:, 0].cpu().numpy(),
                                         (i, j, k), grid)
            torch.cuda.empty_cache()
            print("finished a sample!")

            # (raw output, dense map, file prefix, plot sub-dir, plot title)
            outputs = [
                (last_attn[0], attn_maps[0], 'attn_1_map_',
                 '/2_attn_map_1_{}', 'Attention Map 1'),
                (last_attn[1], attn_maps[1], 'attn_2_map_',
                 '/2_attn_map_2_{}', 'Attention Map 2'),
                (last_attn[2], attn_maps[2], 'attn_3_map_',
                 '/2_attn_map_3_{}', 'Attention Map 3'),
                (last_evidence[0], evidence_maps[0], 'final_evidence_map_1_',
                 '/3_final_evidence_1_{}', 'Final Evidence 1'),
                (last_evidence[1], evidence_maps[1], 'final_evidence_map_2_',
                 '/3_final_evidence_2_{}', 'Final Evidence 2'),
                (last_evidence[2], evidence_maps[2], 'final_evidence_map_3_',
                 '/3_final_evidence_3_{}', 'Final Evidence 3'),
            ]
            for sample in range(datas.shape[0]):
                ut.make_dir(dir=tmp_save_dir, flag_rm=False)
                tag = '{0}_gt_{1}_pred_{2}'.format(
                    count,
                    st.list_selected_for_test[int(labels[sample])],
                    st.list_selected_for_test[int(pred[sample, 0])])
                orig_img = datas[sample][0]
                ut.save_featureMap_tensor(orig_img, tmp_save_dir, 'input_' + tag)
                for raw, dense, file_prefix, sub_dir, title in outputs:
                    if raw is None:
                        continue
                    ut.save_featureMap_numpy(dense[sample][0], tmp_save_dir, file_prefix + tag)
                    ut.plot_heatmap_with_overlay(
                        orig_img=orig_img, heatmap_img=dense[sample][0],
                        save_dir=tmp_save_dir + sub_dir.format(count),
                        fig_title=title, thresh=0.2, percentile=1)
def get_heatmap_2class(config, fold, model, dir_to_load, dir_heatmap):
    """Build and save logit / attention / final-evidence heatmaps for the test split.

    For every test batch the whole volume is classified (softmax + argmax over
    ``'logits'``). The zero-padded volume is then run through the model at
    ``st.patch_stride ** 3`` offsets, and each output grid is scattered onto a
    stride-interleaved full-size map. ``logitMap`` and ``final_evidence`` are
    read at each sample's own predicted class; ``attn_1``/``attn_2`` at
    channel 0.

    Args:
        config: run configuration; ``config.v_batch_size`` sets the batch size.
        fold: fold index used to select the test split and the checkpoint.
        model: network whose output dict has ``'logits'``, ``'logitMap'``,
            ``'attn_1'``, ``'attn_2'`` and ``'final_evidence'`` entries.
        dir_to_load: directory searched by ``ut.model_dir_to_load``.
        dir_heatmap: root directory; results go to ``<dir_heatmap>/fold_<fold>``.
    """
    # Release cached, currently unused GPU memory before the sliding-window run.
    torch.cuda.empty_cache()

    # Test split of this fold.
    list_test_data = DL.concat_class_of_interest(
        config, fold, list_class=st.list_class_for_test, flag_tr_val_te='test')
    test_loader = DL.convert_Dloader_3(
        config.v_batch_size, list_test_data[0], list_test_data[1],
        list_test_data[2], list_test_data[3],
        is_training=False, num_workers=0, shuffle=False)

    # Load the checkpoint when one exists; otherwise use the current weights.
    model_dir = ut.model_dir_to_load(fold, dir_to_load)
    if model_dir is not None:
        model.load_state_dict(torch.load(model_dir))
    model.eval()

    patch_size = st.patch_size
    stride_between_patches = st.patch_stride
    pad = nn.ConstantPad3d(patch_size // 2, 0)
    tmp_save_dir = dir_heatmap + '/fold_{0}'.format(fold)

    def _scatter(dst, src, offset, grid=None):
        """Copy ``src`` (batch, A, B, C) into ``dst[:, 0]`` at ``offset + stride * (a, b, c)``.

        ``grid`` limits the copied cells (defaults to ``src``'s own grid).
        """
        n_a, n_b, n_c = src.shape[-3:] if grid is None else grid
        i, j, k = offset
        s = stride_between_patches
        dst[:, 0, i:i + n_a * s:s, j:j + n_b * s:s, k:k + n_c * s:s] = \
            src[:, :n_a, :n_b, :n_c]

    def _select_class(tensor, cls_idx):
        """Return ``tensor[n, cls_idx[n]]`` for every sample ``n`` as a numpy array."""
        batch_idx = torch.arange(tensor.shape[0], device=tensor.device)
        return tensor[batch_idx, cls_idx.to(tensor.device)].cpu().numpy()

    count = 0
    with torch.no_grad():
        for datas, labels, _alabel, _mlabel in test_loader:
            count += 1
            # Integer labels are used as list indices when naming files.
            labels = labels.long()

            # Whole-volume prediction.
            dict_result = model(datas.cuda())
            prob = nn.Softmax(dim=1)(dict_result['logits'])
            pred = prob.argmax(dim=1, keepdim=True)  # batch, 1
            pred_cls = pred[:, 0]

            logit_map = np.zeros_like(datas)
            attn_1_map = np.zeros_like(datas)
            attn_2_map = np.zeros_like(datas)
            final_evidence_map = np.zeros_like(datas)
            padded = pad(datas)

            logitMap = attn_1 = attn_2 = final_evidence = None
            for i in range(stride_between_patches):
                print("i : {0}".format(i))
                for j in range(stride_between_patches):
                    for k in range(stride_between_patches):
                        dict_result = model(padded[:, :, i:, j:, k:].cuda())
                        logitMap = dict_result['logitMap']
                        attn_1 = dict_result['attn_1']
                        attn_2 = dict_result['attn_2']
                        final_evidence = dict_result['final_evidence']

                        offset = (i, j, k)
                        # logitMap's grid is the reference grid for all maps.
                        grid = logitMap.shape[-3:] if logitMap is not None else None
                        if logitMap is not None:
                            # Per-sample predicted class (not sample 0's for all).
                            _scatter(logit_map, _select_class(logitMap, pred_cls), offset, grid)
                        if attn_1 is not None:
                            _scatter(attn_1_map, attn_1[:, 0].cpu().numpy(), offset, grid)
                        if attn_2 is not None:
                            _scatter(attn_2_map, attn_2[:, 0].cpu().numpy(), offset, grid)
                        if final_evidence is not None:
                            _scatter(final_evidence_map,
                                     _select_class(final_evidence, pred_cls), offset, grid)
            torch.cuda.empty_cache()
            print("finished a sample!")

            # (raw output, dense map, file prefix, plot sub-dir, plot title)
            outputs = [
                (logitMap, logit_map, 'logit_map_',
                 '/1_logit_map_{}', 'Original Logit Map'),
                (attn_1, attn_1_map, 'attn_1_map_',
                 '/2_attn_map_1_{}', 'Attention Map 1'),
                (attn_2, attn_2_map, 'attn_2_map_',
                 '/2_attn_map_2_{}', 'Attention Map 2'),
                (final_evidence, final_evidence_map, 'final_evidence_map_',
                 '/3_final_evidence_{}', 'Final Evidence'),
            ]
            for sample in range(logit_map.shape[0]):
                ut.make_dir(dir=tmp_save_dir, flag_rm=False)
                tag = '{0}_gt_{1}_pred_{2}'.format(
                    count,
                    st.list_selected_for_test[int(labels[sample])],
                    st.list_selected_for_test[int(pred[sample, 0])])
                orig_img = datas[sample][0]
                ut.save_featureMap_tensor(orig_img, tmp_save_dir, 'input_' + tag)
                for raw, dense, file_prefix, sub_dir, title in outputs:
                    if raw is None:
                        continue
                    ut.save_featureMap_numpy(dense[sample][0], tmp_save_dir, file_prefix + tag)
                    ut.plot_heatmap_with_overlay(
                        orig_img=orig_img, heatmap_img=dense[sample][0],
                        save_dir=tmp_save_dir + sub_dir.format(count),
                        fig_title=title, thresh=0.2, percentile=1)
def main(config):
    """Run the cross-validation experiment configured by the ``st`` / ``fst`` modules.

    Steps, in order:
      1. Optionally regenerate the numpy data (``fst.flag_orig_npy``) and the
         extra-dataset numpy files (``fst.flag_orig_npy_other_dataset``).
      2. Optionally regenerate the fold index (``fst.flag_fold_index``).
      3. Create one result directory and Excel workbook per entry of
         ``st.list_standard_eval_dir``.
      4. For every fold in ``[st.start_fold, st.end_fold]``: create output
         directories, build the model(s), optimizer and warmup+cosine scheduler,
         build the train/val/test loaders, train and test according to the
         task flags, and write per-fold metrics plus the running mean ± std.
      5. Centre-align every cell, then save and close each workbook.

    Args:
        config: run configuration; ``config.lr`` and ``config.batch_size`` are
            read here, and it is forwarded to the data, model and train/test helpers.
    """
    """ 1. data process """
    if fst.flag_orig_npy == True:
        print('preparation of the numpy')
        if os.path.exists(st.orig_npy_dir) == False:
            os.makedirs(st.orig_npy_dir)

        """ processing """
        # Dispatch on the configured data type. Order matters: the exact
        # 'ADNI_Jacob_256' match must come before the 'ADNI_Jacob' substring
        # test, which would also match 'ADNI_Jacob_256'.
        if st.list_data_type[st.data_type_num] == 'Density':
            cDL.Prepare_data_GM_AGE_MMSE()
        elif st.list_data_type[st.data_type_num] == 'ADNI_JSY':
            jDL.Prepare_data_1()
        elif st.list_data_type[st.data_type_num] == 'ADNI_Jacob_256':
            jcDL.Prepare_data_GM_WM_CSF()
        elif 'ADNI_Jacob' in st.list_data_type[st.data_type_num]:
            jcDL.Prepare_data_GM()
        elif 'ADNI_AAL_256' in st.list_data_type[st.data_type_num]:
            aDL.Prepare_data_GM()

    if fst.flag_orig_npy_other_dataset == True:
        cDL.Prepare_data_GM_age_others(dataset='ABIDE')
        cDL.Prepare_data_GM_age_others(dataset='ICBM')
        cDL.Prepare_data_GM_age_others(dataset='Cam')
        cDL.Prepare_data_GM_age_others(dataset='IXI')
        cDL.Prepare_data_GM_age_others(dataset='PPMI')

    """ 2. fold index processing """
    if fst.flag_fold_index == True:
        print('preparation of the fold index')
        if os.path.exists(st.fold_index_dir) == False:
            os.makedirs(st.fold_index_dir)
        """ save the fold index """
        ut.preparation_fold_index(config)

    """ fold selection """
    start_fold = st.start_fold
    end_fold = st.end_fold

    """ workbook """
    # One result directory and one (workbook, worksheet) pair per evaluation
    # standard in st.list_standard_eval_dir; ut.excel_setting returns the pair.
    list_dir_result = []
    list_wb = []
    list_ws = []
    for i in range(len(st.list_standard_eval_dir)):
        list_dir_result.append(st.dir_to_save_1 + st.list_standard_eval_dir[i])
        ut.make_dir(dir=list_dir_result[i], flag_rm=False)
        out = ut.excel_setting(start_fold=start_fold, end_fold=end_fold, result_dir=list_dir_result[i], f_name='results')
        list_wb.append(out[0])
        list_ws.append(out[1])

    """ fold """
    list_eval_metric = st.list_eval_metric
    # metric_avg[i_standard][i_metric] collects one value per finished fold,
    # used for the running mean ± std column.
    metric_avg = [[[] for j in range(len(st.list_eval_metric))] for i in range(len(st.list_standard_eval_dir))]
    for fold in range(start_fold, end_fold + 1):
        print("FOLD : {}".format(fold))

        # Directory preparation (one set of directories per evaluation standard).
        print('-' * 10 + 'Directory preparation' + '-' * 10)
        # NOTE(review): list_dir_save_model_2, list_dir_age_pred, list_dir_heatmap
        # and dir_MMSE_dist are created but not otherwise used in this function
        # (list_dir_heatmap only appears in a commented-out call below).
        list_dir_save_model = []
        list_dir_save_model_2 = []
        list_dir_confusion = []
        list_dir_age_pred = []
        list_dir_heatmap = []
        for i in range(len(st.list_standard_eval_dir)):
            """ dir to save model """
            list_dir_save_model.append(st.dir_to_save_1 + st.list_standard_eval_dir[i] + '/weights/fold_{}'.format(fold))
            ut.make_dir(dir=list_dir_save_model[i], flag_rm=False)
            list_dir_save_model_2.append(st.dir_to_save_1 + st.list_standard_eval_dir[i] + '/weights_2/fold_{}'.format(fold))
            ut.make_dir(dir=list_dir_save_model_2[i], flag_rm=False)
            """ dir to save confusion matrix """
            list_dir_confusion.append(st.dir_to_save_1 + st.list_standard_eval_dir[i] + '/confusion')
            ut.make_dir(dir=list_dir_confusion[i], flag_rm=False)
            """ dir to save age pred """
            list_dir_age_pred.append(st.dir_to_save_1 + st.list_standard_eval_dir[i] + '/age_pred')
            ut.make_dir(dir=list_dir_age_pred[i], flag_rm=False)
            list_dir_heatmap.append(st.dir_to_save_1 + st.list_standard_eval_dir[i] + '/heatmap')
            ut.make_dir(dir=list_dir_heatmap[i], flag_rm=False)

        """ dir to save pyplot """
        dir_pyplot = st.dir_to_save_1 + '/pyplot/fold_{}'.format(fold)
        ut.make_dir(dir=dir_pyplot, flag_rm=False)
        """ dir to save MMSE dist """
        dir_MMSE_dist = st.dir_to_save_1 + '/MMSE_dist'
        ut.make_dir(dir=dir_MMSE_dist, flag_rm=False)

        # Model construction. model_1 is the model that gets trained.
        print('-' * 10 + 'Model construction' + '-' * 10)
        model_1 = construct_model.construct_model(config, flag_model_num=0)
        model_1 = nn.DataParallel(model_1)
        if fst.flag_classification_fine_tune == True:
            # Fine-tuning: copy every pretrained tensor whose key also exists in
            # model_1 (matching keys are printed); other keys are kept as built.
            dir_to_load = st.dir_preTrain_1
            dir_load_model = dir_to_load + '/weights/fold_{}'.format(fold)
            model_dir = ut.model_dir_to_load(fold, dir_load_model)
            pretrained_dict = torch.load(model_dir)
            model_dict = model_1.state_dict()
            for k, v in pretrained_dict.items():
                if k in model_dict:
                    print(k)
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            model_1.load_state_dict(model_dict)
        elif fst.flag_classification_using_pretrained == True:
            # A second model (flag_model_num=1) is built, filled with the
            # matching pretrained weights, put in eval mode and handed to the
            # train/test routines alongside model_1. It is only defined in this
            # branch, and only this branch's train/test calls use it.
            model_2 = construct_model.construct_model(config, flag_model_num=1)
            model_2 = nn.DataParallel(model_2)
            dir_to_load = st.dir_preTrain_1
            dir_load_model = dir_to_load + '/weights/fold_{}'.format(fold)
            model_dir = ut.model_dir_to_load(fold, dir_load_model)
            pretrained_dict = torch.load(model_dir)
            model_dict = model_2.state_dict()
            for k, v in pretrained_dict.items():
                if k in model_dict:
                    print(k)
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            model_2.load_state_dict(model_dict)
            model_2.eval()

        """ optimizer """
        # Adam on model_1 only. Alternatives tried previously: SGD(momentum=0.9),
        # AdamP, RAdam, ut.BayesianSGD.
        # optimizer = torch.optim.SGD(model_1.parameters(), lr=config.lr, momentum=0.9, weight_decay=st.weight_decay)
        optimizer = torch.optim.Adam(model_1.parameters(), lr=config.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=st.weight_decay)
        # optimizer = AdamP(model_1.parameters(), lr=config.lr, betas=(0.9, 0.999), weight_decay=st.weight_decay)
        # optimizer = RAdam(model_1.parameters(), lr=config.lr, betas=(0.9, 0.999), weight_decay=st.weight_decay)

        # LR schedule: 5-epoch warmup (multiplier=1), then cosine annealing
        # with T_max=st.epoch. Alternatives tried previously: StepLR,
        # CosineAnnealingWarmRestarts(T_0=50) with or without warmup,
        # ReduceLROnPlateau(mode='min', factor=0.2, patience=10).
        scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, st.epoch)
        scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=5, after_scheduler=scheduler_cosine)

        """ data loader """
        # Only the training loader is shuffled and drops its last partial batch.
        print('-' * 10 + 'data loader' + '-' * 10)
        train_loader = DL.convert_Dloader_3(fold, list_class=st.list_class_for_train, flag_tr_val_te='train', batch_size=config.batch_size, num_workers=0, shuffle=True, drop_last=True)
        val_loader = DL.convert_Dloader_3(fold, list_class=st.list_class_for_test, flag_tr_val_te='val', batch_size=config.batch_size, num_workers=0, shuffle=False, drop_last=False)
        test_loader = DL.convert_Dloader_3(fold, list_class=st.list_class_for_test, flag_tr_val_te='test', batch_size=config.batch_size, num_workers=0, shuffle=False, drop_last=False)
        dict_data_loader = {
            'train': train_loader,
            'val': val_loader,
            'test': test_loader
        }

        """ normal classification tasks """
        # Train once, then test once per evaluation standard, using the
        # checkpoint saved under that standard's weight directory.
        # list_test_result[i] therefore lines up with st.list_standard_eval_dir[i].
        # NOTE(review): if none of flag_classification, flag_classification_fine_tune,
        # flag_classification_using_pretrained or flag_multi_task is set,
        # list_test_result stays empty and the Excel loop below raises IndexError.
        list_test_result = []
        print('-' * 10 + 'start training' + '-' * 10)
        if fst.flag_classification == True or fst.flag_classification_fine_tune == True:
            train.train(config, fold, model_1, dict_data_loader, optimizer, scheduler, list_dir_save_model, dir_pyplot, Validation=True, Test_flag=True)
            for i_tmp in range(len(st.list_standard_eval_dir)):
                dict_test_output = test.test(config, fold, model_1, dict_data_loader['test'], list_dir_save_model[i_tmp], list_dir_confusion[i_tmp])
                list_test_result.append(dict_test_output)
                # if len(st.list_selected_for_train) == 2 and fold == 1 and st.list_standard_eval_dir[i_tmp] == '/val_auc':
                #     generate_heatmap.get_multi_heatmap_2class(config, fold, model, list_dir_save_model[i_tmp], list_dir_heatmap[i_tmp])
        elif fst.flag_classification_using_pretrained == True:
            """ using pretrained patch level model """
            train_using_pretrained.train(config, fold, model_1, model_2, dict_data_loader, optimizer, scheduler, list_dir_save_model, dir_pyplot, Validation=True, Test_flag=True)
            for i_tmp in range(len(st.list_standard_eval_dir)):
                dict_test_output = test_using_pretrained.test(
                    config, fold, model_1, model_2, dict_data_loader['test'], list_dir_save_model[i_tmp], list_dir_confusion[i_tmp])
                list_test_result.append(dict_test_output)
        elif fst.flag_multi_task == True:
            train_multi_task.train(config, fold, model_1, dict_data_loader, optimizer, scheduler, list_dir_save_model, dir_pyplot, Validation=True, Test_flag=True)
            for i_tmp in range(len(st.list_standard_eval_dir)):
                dict_test_output = test_multi_task.test(
                    config, fold, model_1, dict_data_loader['test'], list_dir_save_model[i_tmp], list_dir_confusion[i_tmp])
                list_test_result.append(dict_test_output)

        """ fill out the results on the excel sheet """
        # Column fold + 1 gets this fold's value for every metric present in the
        # result dict. Column end_fold + 2 is rewritten every fold with the
        # "mean ± std" over all folds finished so far. The workbook is saved
        # after each fold.
        for i_standard in range(len(st.list_standard_eval_dir)):
            for i in range(len(list_eval_metric)):
                if list_eval_metric[i] in list_test_result[i_standard]:
                    list_ws[i_standard].cell(
                        row=2 + i + st.push_start_row, column=fold + 1,
                        value="%.4f" % (list_test_result[i_standard][list_eval_metric[i]]))
                    metric_avg[i_standard][i].append(
                        list_test_result[i_standard][list_eval_metric[i]])

            for i in range(len(list_eval_metric)):
                if metric_avg[i_standard][i]:
                    avg = round(np.mean(metric_avg[i_standard][i]), 4)
                    std = round(np.std(metric_avg[i_standard][i]), 4)
                    tmp = "%.4f \u00B1 %.4f" % (avg, std)
                    list_ws[i_standard].cell(row=2 + st.push_start_row + i, column=end_fold + 2, value=tmp)

            list_wb[i_standard].save(list_dir_result[i_standard] + "/results.xlsx")

    # Final formatting pass: centre every used cell, then save and close.
    for i_standard in range(len(st.list_standard_eval_dir)):
        n_row = list_ws[i_standard].max_row
        n_col = list_ws[i_standard].max_column
        for i_row in range(1, n_row + 1):
            for i_col in range(1, n_col + 1):
                ca1 = list_ws[i_standard].cell(row=i_row, column=i_col)
                ca1.alignment = Alignment(horizontal='center', vertical='center')
        list_wb[i_standard].save(list_dir_result[i_standard] + "/results.xlsx")
        list_wb[i_standard].close()
    print("finished")
def test(config, fold, model, dir_to_load, dir_age_pred):
    """Plot labelled vs. predicted age per split and class, and save the figure.

    The figure has one row pair per split (train, val, test) and one column
    pair per class in ``st.list_selected_for_total``. Each panel scatters
    labelled age against ``model(...)['preds']`` for the samples of that class.
    Next to the panel it writes min/max/mean/std of both quantities.

    Args:
        config: run configuration; ``config.v_batch_size`` sets the batch size.
        fold: fold index used to select the splits and the checkpoint.
        model: network whose output dict has a ``'preds'`` entry; assumed to
            give one value per sample - TODO confirm.
        dir_to_load: directory searched by ``ut.model_dir_to_load``.
        dir_age_pred: directory the PNG ``fold<fold>_age_prediction.png`` is written to.
    """
    # Release cached, currently unused GPU memory before evaluating.
    torch.cuda.empty_cache()

    # Train / val / test splits over all classes of interest.
    list_train_data = DL.concat_class_of_interest(
        config, fold, list_class=st.list_class_for_total, flag_tr_val_te='train')
    list_val_data = DL.concat_class_of_interest(
        config, fold, list_class=st.list_class_for_total, flag_tr_val_te='val')
    list_test_data = DL.concat_class_of_interest(
        config, fold, list_class=st.list_class_for_total, flag_tr_val_te='test')
    train_loader = DL.convert_Dloader_3(
        config.v_batch_size, list_train_data[0], list_train_data[1],
        list_train_data[2], list_train_data[3],
        is_training=False, num_workers=1, shuffle=False)
    val_loader = DL.convert_Dloader_3(
        config.v_batch_size, list_val_data[0], list_val_data[1],
        list_val_data[2], list_val_data[3],
        is_training=False, num_workers=1, shuffle=False)
    test_loader = DL.convert_Dloader_3(
        config.v_batch_size, list_test_data[0], list_test_data[1],
        list_test_data[2], list_test_data[3],
        is_training=False, num_workers=1, shuffle=False)
    del list_train_data, list_val_data, list_test_data

    model_dir = ut.model_dir_to_load(fold, dir_to_load)
    model.load_state_dict(torch.load(model_dir))
    model.eval()

    n_class = len(st.list_selected_for_total)
    fig = plt.figure(figsize=(n_class * 12, 25))
    plt.rcParams.update({'font.size': 22})
    if fst.flag_estimate_age:
        fig.suptitle(
            'Comparing between labeled and predicted ages in fold{0} ({1})'.format(
                fold, st.list_age_estimating_function[st.selected_function]),
            fontsize=50)
    else:
        fig.suptitle(
            'Comparing between labeled and predicted ages in fold{0}'.format(fold),
            fontsize=50)

    # Tall rows hold plots, short rows are spacers. Each class gets a wide plot
    # column followed by a narrow one; the stats text is drawn past the
    # panel's right x-limit, presumably into that narrow column.
    heights = [10, 2, 10, 2, 10, 2]
    widths = [10, 3] * n_class
    gs = gridspec.GridSpec(
        nrows=6,  # row
        ncols=n_class * 2,  # col
        height_ratios=heights,
        width_ratios=widths)

    age_left = 50
    age_right = 110
    pred_left = 50
    pred_right = 110
    gap_1 = 4    # vertical spacing between text lines
    gap_2 = 10   # extra spacing between the two text groups
    text_fontsize = 15
    text_x = age_right + 1

    def _annotate(ax, y_top, header, values):
        """Write ``header`` at ``y_top`` and min/max/mean/std of ``values`` below it."""
        ax.text(text_x, y_top, header, fontsize=text_fontsize + 5)
        if values.size == 0:
            # min()/max() raise on an empty array (class absent from this split).
            return
        stats = (('min', values.min()), ('max', values.max()),
                 ('mean', values.mean()), ('std', values.std()))
        for n, (name, value) in enumerate(stats, start=1):
            ax.text(text_x, y_top - n * gap_1, '{}: {:.2f}'.format(name, value),
                    fontsize=text_fontsize)

    splits = (('train', train_loader), ('val', val_loader), ('test', test_loader))
    for i_loader, (dataset, loader) in enumerate(splits):
        list_age = []
        list_lbl = []
        list_pred = []
        with torch.no_grad():
            for datas, labels, alabel, _mlabel in loader:
                dict_result = model(datas.cuda())
                pred_age = dict_result['preds']
                # Flatten each batch so batches of different size (e.g. a short
                # last batch) can still be concatenated.
                list_pred.append(pred_age.cpu().numpy().reshape(-1))
                list_lbl.append(labels.long().cpu().numpy().reshape(-1))
                list_age.append(alabel.float().cpu().numpy().reshape(-1))
        np_age = np.concatenate(list_age)
        np_lbl = np.concatenate(list_lbl)
        np_pred = np.concatenate(list_pred)

        for j_disease in range(n_class):
            mask = np_lbl == j_disease
            ages = np_age[mask]
            preds = np_pred[mask]

            ax1 = fig.add_subplot(gs[i_loader * 2, j_disease * 2])
            ax1.scatter(ages, preds)
            # Identity line: a perfect prediction lies on it.
            ax1.plot(range(age_left, age_right), range(age_left, age_right))
            ax1.set_title('{} {}'.format(
                dataset, st.list_selected_for_total[j_disease]), fontsize=25)
            ax1.set_xlim([age_left, age_right])
            ax1.set_ylim([pred_left, pred_right])
            ax1.grid(True)
            ax1.set_ylabel('predicted age')
            ax1.set_xlabel('labeled age')

            _annotate(ax1, pred_right, 'labeled age', ages)
            _annotate(ax1, pred_right - 5 * gap_1 - gap_2, 'pred age', preds)

    plt.savefig(dir_age_pred + '/fold{}_age_prediction.png'.format(fold))
    plt.close('all')