Ejemplo n.º 1
0
def test(config, fold, model, loader, dir_to_load, dir_confusion):
    """Evaluate a classification model on one fold and return the result.

    Frees cached GPU memory, restores the fold checkpoint into ``model`` when
    ``ut.model_dir_to_load`` returns a path, switches the model to eval mode
    and dispatches to the evaluation routine selected by the ``fst`` flags
    (cropping, translation ensemble, MC dropout, Bayesian, or plain).

    Args:
        config: run configuration, forwarded to the evaluation routine.
        fold: fold index; selects the checkpoint and is forwarded.
        model: network to evaluate.
        loader: data loader for the test split.
        dir_to_load: directory searched for the fold checkpoint.
        dir_confusion: not used by this function; kept for interface
            compatibility with callers.

    Returns:
        Whatever the selected ``ut.eval_classification_model*`` routine
        returns (named ``dict_result`` here).
    """
    # Free all cached GPU memory before evaluation.
    torch.cuda.empty_cache()
    # Alternatives previously used:
    #   nn.CrossEntropyLoss()
    #   ut.FocalLoss(gamma=st.focal_gamma, alpha=st.focal_alpha)
    criterion_cls = nn.BCELoss()
    test_loader = loader

    # Load the model; keep the current weights when no checkpoint is found.
    model_dir = ut.model_dir_to_load(fold, dir_to_load)
    if model_dir is not None:
        model.load_state_dict(torch.load(model_dir))
    model.eval()

    # The first enabled flag wins; plain evaluation is the fallback.
    if fst.flag_eval_cropping:
        dict_result = ut.eval_classification_model_cropped_input(
            config, fold, test_loader, model, criterion_cls)
    elif fst.flag_eval_translation:
        dict_result = ut.eval_classification_model_esemble(
            config, fold, test_loader, model, criterion_cls)
    elif fst.flag_MC_dropout:
        dict_result = ut.eval_classification_model_MC_dropout(
            config, fold, test_loader, model, criterion_cls)
    elif fst.flag_bayesian:
        dict_result = ut.eval_classification_model_bayesian(
            config, fold, test_loader, model, criterion_cls)
    else:
        dict_result = ut.eval_classification_model(config,
                                                   fold,
                                                   test_loader,
                                                   model,
                                                   criterion_cls,
                                                   flag_heatmap=False)

    return dict_result
Ejemplo n.º 2
0
def test(config, fold, model, loader, dir_to_load, dir_confusion):
    """Evaluate a multi-task (classification + regression) model on one fold.

    Frees cached GPU memory, restores the fold checkpoint into ``model`` when
    ``ut.model_dir_to_load`` returns a path, switches the model to eval mode
    and runs ``ut.eval_multi_task_model`` with a BCE classification loss and
    a summed L1 loss.

    Args:
        config: run configuration, forwarded to the evaluation routine.
        fold: fold index; selects the checkpoint and is forwarded.
        model: network to evaluate.
        loader: data loader for the test split.
        dir_to_load: directory searched for the fold checkpoint.
        dir_confusion: not used by this function; kept for interface
            compatibility with callers.

    Returns:
        Whatever ``ut.eval_multi_task_model`` returns.
    """
    # Free all cached GPU memory before evaluation.
    torch.cuda.empty_cache()
    # Alternatives previously used:
    #   nn.CrossEntropyLoss()
    #   ut.FocalLoss(gamma=st.focal_gamma, alpha=st.focal_alpha)
    criterion_cls = nn.BCELoss()
    criterion_L1 = nn.L1Loss(reduction='sum').cuda()
    test_loader = loader

    # Load the model; keep the current weights when no checkpoint is found.
    model_dir = ut.model_dir_to_load(fold, dir_to_load)
    if model_dir is not None:
        model.load_state_dict(torch.load(model_dir))
    model.eval()

    dict_result = ut.eval_multi_task_model(config, fold, test_loader, model,
                                           criterion_cls, criterion_L1)

    return dict_result
def test(config, fold, model_1, model_2, loader, dir_to_load, dir_confusion):
    """Evaluate a two-network classifier (``model_1`` + ``model_2``) on one fold.

    Only ``model_1`` is restored from the fold checkpoint; ``model_2`` keeps
    the weights it already has. Both networks are switched to eval mode before
    ``ut.eval_classification_model_2`` is run with a BCE loss.

    Args:
        config: run configuration, forwarded to the evaluation routine.
        fold: fold index; selects the checkpoint and is forwarded.
        model_1: network restored from the fold checkpoint.
        model_2: second network passed through to the evaluation routine.
        loader: data loader for the test split.
        dir_to_load: directory searched for the fold checkpoint.
        dir_confusion: not used by this function; kept for interface
            compatibility with callers.

    Returns:
        Whatever ``ut.eval_classification_model_2`` returns.
    """
    # Free all cached GPU memory before evaluation.
    torch.cuda.empty_cache()
    # Alternatives previously used:
    #   nn.CrossEntropyLoss()
    #   ut.FocalLoss(gamma=st.focal_gamma, alpha=st.focal_alpha)
    #   nn.MSELoss(reduction='mean').cuda()
    criterion_cls = nn.BCELoss()
    test_loader = loader

    # Load model_1; keep the current weights when no checkpoint is found.
    model_dir = ut.model_dir_to_load(fold, dir_to_load)
    if model_dir is not None:
        model_1.load_state_dict(torch.load(model_dir))
    model_1.eval()
    # model_2 was previously left in whatever mode the caller set, so dropout
    # and batch-norm could run in training mode during testing.
    if isinstance(model_2, nn.Module):
        model_2.eval()

    dict_result = ut.eval_classification_model_2(config, fold, test_loader,
                                                 model_1, model_2,
                                                 criterion_cls)

    return dict_result
Ejemplo n.º 4
0
def get_heatmap_1class(config, fold, model, dir_to_load, dir_heatmap,
                       max_batches=2):
    """Build dense ``predMap`` heatmaps for the first test batches and save them.

    Loads the test split of ``fold`` and restores the fold checkpoint (if any)
    into ``model``. Then, for every offset ``(i, j, k)`` in
    ``range(st.patch_stride)`` cubed, it feeds the zero-padded input, shifted
    by that offset, through ``model``. The coarse ``predMap`` returned is
    written into a full-resolution map at voxel
    ``(a * stride + i, b * stride + j, c * stride + k)``. The input volume,
    the dense map and an overlay plot are saved under
    ``dir_heatmap + '/fold_<fold>'``.

    Args:
        config: run configuration; ``config.v_batch_size`` sets the batch size.
        fold: fold index used to select test data and the checkpoint.
        model: network returning a dict with a ``'predMap'`` entry (assumed
            5-D, indexed as ``predMap[:, 0, a, b, c]``).
        dir_to_load: directory searched for the fold checkpoint.
        dir_heatmap: root directory for the saved outputs.
        max_batches: number of leading test batches to process. The default
            of 2 matches the previous hard-coded behaviour.
    """
    # Free all cached GPU memory.
    torch.cuda.empty_cache()

    # Load the fold list for test.
    list_test_data = DL.concat_class_of_interest(
        config, fold, list_class=st.list_class_for_test, flag_tr_val_te='test')
    test_loader = DL.convert_Dloader_3(config.v_batch_size,
                                       list_test_data[0],
                                       list_test_data[1],
                                       list_test_data[2],
                                       list_test_data[3],
                                       is_training=False,
                                       num_workers=0,
                                       shuffle=False)

    # Load the model; like the sibling functions, tolerate a missing checkpoint.
    model_dir = ut.model_dir_to_load(fold, dir_to_load)
    if model_dir is not None:
        model.load_state_dict(torch.load(model_dir))
    model.eval()

    stride = st.patch_stride
    pad = nn.ConstantPad3d(st.patch_size // 2, 0)
    tmp_save_dir = dir_heatmap + '/fold_{0}'.format(fold)

    with torch.no_grad():
        for count, (datas, _labels, _alabels, _mlabel) in enumerate(
                test_loader, start=1):
            if count > max_batches:
                # Nothing after this batch is used, so stop loading data.
                break

            dense_map = np.zeros_like(datas)  # unpadded input resolution
            padded = pad(datas)
            coarse_map = None  # stays None if the stride loop never runs
            for i in range(stride):
                print("i : {0}".format(i))
                for j in range(stride):
                    for k in range(stride):
                        data = padded[:, :, i:, j:, k:].cuda()
                        coarse_map = model(data)['predMap']
                        if coarse_map is None:
                            continue
                        na, nb, nc = coarse_map.shape[-3:]
                        # Voxels a*stride+i (a < na) along each axis form the
                        # strided slice i::stride; one bulk copy replaces the
                        # former per-voxel GPU->CPU transfers.
                        window = dense_map[:, 0, i::stride, j::stride, k::stride]
                        window[:, :na, :nb, :nc] = \
                            coarse_map[:, 0].cpu().numpy()

            torch.cuda.empty_cache()
            print("finished a sample!")
            for sample in range(dense_map.shape[0]):
                ut.make_dir(dir=tmp_save_dir, flag_rm=False)
                orig_img = datas[sample][0]
                ut.save_featureMap_tensor(orig_img, tmp_save_dir,
                                          'input_{}'.format(count))

                if coarse_map is not None:
                    heatmap_img = dense_map[sample][0]
                    ut.save_featureMap_numpy(heatmap_img, tmp_save_dir,
                                             'pred_map_{}'.format(count))
                    # Suffix with the batch counter, as the sibling heatmap
                    # functions do, so each batch gets its own plot directory.
                    ut.plot_heatmap_with_overlay(
                        orig_img=orig_img,
                        heatmap_img=heatmap_img,
                        save_dir=tmp_save_dir +
                        '/1_logit_map_{}'.format(count),
                        fig_title='Original Logit Map',
                        thresh=0.2,
                        percentile=1)
Ejemplo n.º 5
0
def get_multi_heatmap_2class(config, fold, model, dir_to_load, dir_heatmap):
    """Build dense attention / evidence heatmaps at three patch scales and save them.

    For each test batch the whole volume is classified once (the prediction
    only labels the output file names). Then, for every offset ``(i, j, k)``
    in ``range(st.patch_stride)`` cubed, the input is zero-padded for each
    patch size (9, 17, 33) and shifted by the offset. The model's coarse
    ``attn_<n>`` and ``final_evidence_<a|b|c>`` maps are written into
    full-resolution maps at voxel ``(a*stride+i, b*stride+j, c*stride+k)``.
    Inputs, dense maps and overlay plots go to ``dir_heatmap + '/fold_<fold>'``.

    Args:
        config: run configuration; ``config.v_batch_size`` sets the batch size.
        fold: fold index used to select test data and the checkpoint.
        model: network returning a dict with ``'logits'``, ``'attn_1..3'`` and
            ``'final_evidence_a..c'`` (maps assumed 5-D, indexed
            ``[:, 0, a, b, c]``).
        dir_to_load: directory searched for the fold checkpoint.
        dir_heatmap: root directory for the saved outputs.
    """
    # Free all cached GPU memory.
    torch.cuda.empty_cache()

    # Load the fold list for test.
    list_test_data = DL.concat_class_of_interest(
        config, fold, list_class=st.list_class_for_test, flag_tr_val_te='test')
    test_loader = DL.convert_Dloader_3(config.v_batch_size,
                                       list_test_data[0],
                                       list_test_data[1],
                                       list_test_data[2],
                                       list_test_data[3],
                                       is_training=False,
                                       num_workers=0,
                                       shuffle=False)

    # Load the model; keep the current weights when no checkpoint is found.
    model_dir = ut.model_dir_to_load(fold, dir_to_load)
    if model_dir is not None:
        model.load_state_dict(torch.load(model_dir))
    model.eval()

    stride = st.patch_stride
    # Per scale: (patch size, attention-map key, evidence-map key).
    scales = ((9, 'attn_1', 'final_evidence_a'),
              (17, 'attn_2', 'final_evidence_b'),
              (33, 'attn_3', 'final_evidence_c'))
    pads = [nn.ConstantPad3d(size // 2, 0) for size, _, _ in scales]
    tmp_save_dir = dir_heatmap + '/fold_{0}'.format(fold)

    def scatter(dense, coarse, offset, extent):
        """Copy coarse[:, 0, a, b, c] to dense[:, 0, a*stride+i, ...] for a < A, etc."""
        i, j, k = offset
        na, nb, nc = extent
        # Voxels a*stride+i (a < na) form the strided slice i::stride.
        window = dense[:, 0, i::stride, j::stride, k::stride]
        window[:, :na, :nb, :nc] = coarse[:, 0, :na, :nb, :nc].cpu().numpy()

    def save_and_plot(orig_img, heatmap_img, file_name, plot_dir, title):
        """Save one dense map and its overlay plot under tmp_save_dir."""
        ut.save_featureMap_numpy(heatmap_img, tmp_save_dir, file_name)
        ut.plot_heatmap_with_overlay(
            orig_img=orig_img,
            heatmap_img=heatmap_img,
            save_dir=tmp_save_dir + plot_dir,
            fig_title=title,
            thresh=0.2,
            percentile=1)

    with torch.no_grad():
        for count, (datas, labels, _alabel, _mlabel) in enumerate(
                test_loader, start=1):
            # Whole-volume prediction, used only to label the saved files.
            output_logit = model(datas.cuda())['logits']
            prob = nn.Softmax(dim=1)(output_logit)
            pred = prob.argmax(dim=1, keepdim=True)  # batch, 1

            attn_maps = [np.zeros_like(datas) for _ in scales]
            evidence_maps = [np.zeros_like(datas) for _ in scales]
            # Last tensor returned per scale (None if never produced).
            attn_last = [None] * len(scales)
            evidence_last = [None] * len(scales)
            padded = [pad(datas) for pad in pads]

            for i in range(stride):
                print("i : {0}".format(i))
                for j in range(stride):
                    for k in range(stride):
                        for s, (_, attn_key, evidence_key) in enumerate(scales):
                            out = model(padded[s][:, :, i:, j:, k:].cuda())
                            attn = out[attn_key]
                            evidence = out[evidence_key]
                            attn_last[s] = attn
                            evidence_last[s] = evidence
                            # The evidence map defines the written extent,
                            # as before; fall back to the attention map.
                            ref = evidence if evidence is not None else attn
                            if ref is None:
                                continue
                            extent = ref.shape[-3:]
                            if attn is not None:
                                scatter(attn_maps[s], attn, (i, j, k), extent)
                            if evidence is not None:
                                scatter(evidence_maps[s], evidence, (i, j, k),
                                        extent)

            torch.cuda.empty_cache()
            print("finished a sample!")
            for sample in range(datas.shape[0]):
                ut.make_dir(dir=tmp_save_dir, flag_rm=False)
                tag = '{0}_gt_{1}_pred_{2}'.format(
                    count,
                    st.list_selected_for_test[int(labels[sample])],
                    st.list_selected_for_test[int(pred[sample, 0])])
                orig_img = datas[sample][0]
                ut.save_featureMap_tensor(orig_img, tmp_save_dir,
                                          'input_' + tag)

                for n, (attn, dense) in enumerate(
                        zip(attn_last, attn_maps), start=1):
                    if attn is not None:
                        save_and_plot(orig_img, dense[sample][0],
                                      'attn_{}_map_'.format(n) + tag,
                                      '/2_attn_map_{}_{}'.format(n, count),
                                      'Attention Map {}'.format(n))

                for n, (evidence, dense) in enumerate(
                        zip(evidence_last, evidence_maps), start=1):
                    if evidence is not None:
                        save_and_plot(orig_img, dense[sample][0],
                                      'final_evidence_map_{}_'.format(n) + tag,
                                      '/3_final_evidence_{}_{}'.format(n, count),
                                      'Final Evidence {}'.format(n))
Ejemplo n.º 6
0
def get_heatmap_2class(config, fold, model, dir_to_load, dir_heatmap):
    """Build dense logit / attention / evidence heatmaps for every test batch.

    For each test batch, the whole volume is classified to get each sample's
    predicted class. Then, for every offset ``(i, j, k)`` in
    ``range(st.patch_stride)`` cubed, the zero-padded input, shifted by that
    offset, is passed through ``model``. Its coarse maps are written into
    full-resolution maps at voxel ``(a*stride+i, b*stride+j, c*stride+k)``.
    For ``logitMap`` and ``final_evidence``, the channel of each sample's
    own predicted class is used; previously sample 0's class was applied to
    the whole batch. Inputs, dense maps and overlay plots are saved under
    ``dir_heatmap + '/fold_<fold>'``.

    Args:
        config: run configuration; ``config.v_batch_size`` sets the batch size.
        fold: fold index used to select test data and the checkpoint.
        model: network returning a dict with ``'logits'``, ``'logitMap'``,
            ``'attn_1'``, ``'attn_2'`` and ``'final_evidence'`` (maps assumed
            5-D, ``(batch, channel, A, B, C)``).
        dir_to_load: directory searched for the fold checkpoint.
        dir_heatmap: root directory for the saved outputs.
    """
    # Free all cached GPU memory.
    torch.cuda.empty_cache()

    # Load the fold list for test.
    list_test_data = DL.concat_class_of_interest(
        config, fold, list_class=st.list_class_for_test, flag_tr_val_te='test')
    test_loader = DL.convert_Dloader_3(config.v_batch_size,
                                       list_test_data[0],
                                       list_test_data[1],
                                       list_test_data[2],
                                       list_test_data[3],
                                       is_training=False,
                                       num_workers=0,
                                       shuffle=False)

    # Load the model; keep the current weights when no checkpoint is found.
    model_dir = ut.model_dir_to_load(fold, dir_to_load)
    if model_dir is not None:
        model.load_state_dict(torch.load(model_dir))
    model.eval()

    stride = st.patch_stride
    pad = nn.ConstantPad3d(st.patch_size // 2, 0)
    tmp_save_dir = dir_heatmap + '/fold_{0}'.format(fold)

    def scatter(dense, coarse, offset, extent):
        """Copy coarse[:, a, b, c] (batch, A, B, C) to dense[:, 0, a*stride+i, ...]."""
        i, j, k = offset
        na, nb, nc = extent
        # Voxels a*stride+i (a < na) form the strided slice i::stride.
        window = dense[:, 0, i::stride, j::stride, k::stride]
        window[:, :na, :nb, :nc] = coarse[:, :na, :nb, :nc].cpu().numpy()

    with torch.no_grad():
        for count, (datas, labels, _alabel, _mlabel) in enumerate(
                test_loader, start=1):
            # Prediction on the whole volume.
            output_logit = model(datas.cuda())['logits']
            prob = nn.Softmax(dim=1)(output_logit)
            pred = prob.argmax(dim=1, keepdim=True)  # batch, 1
            # (sample, predicted class) index pairs for per-sample selection.
            batch_idx = torch.arange(pred.shape[0], device=pred.device)
            class_idx = pred[:, 0]

            logit_map = np.zeros_like(datas)
            attn_1_map = np.zeros_like(datas)
            attn_2_map = np.zeros_like(datas)
            final_evidence_map = np.zeros_like(datas)
            padded = pad(datas)
            # Last tensors returned by the model (None if never produced).
            logitMap = attn_1 = attn_2 = final_evidence = None

            for i in range(stride):
                print("i : {0}".format(i))
                for j in range(stride):
                    for k in range(stride):
                        out = model(padded[:, :, i:, j:, k:].cuda())
                        logitMap = out['logitMap']
                        attn_1 = out['attn_1']
                        attn_2 = out['attn_2']
                        final_evidence = out['final_evidence']
                        # logitMap defines the written extent, as before;
                        # fall back to the first map that is present.
                        ref = next((t for t in (logitMap, attn_1, attn_2,
                                                final_evidence)
                                    if t is not None), None)
                        if ref is None:
                            continue
                        extent = ref.shape[-3:]
                        offset = (i, j, k)
                        if logitMap is not None:
                            scatter(logit_map, logitMap[batch_idx, class_idx],
                                    offset, extent)
                        if attn_1 is not None:
                            scatter(attn_1_map, attn_1[:, 0], offset, extent)
                        if attn_2 is not None:
                            scatter(attn_2_map, attn_2[:, 0], offset, extent)
                        if final_evidence is not None:
                            scatter(final_evidence_map,
                                    final_evidence[batch_idx, class_idx],
                                    offset, extent)

            torch.cuda.empty_cache()
            print("finished a sample!")
            # (source tensor, dense map, file prefix, plot dir prefix, title)
            outputs = (
                (logitMap, logit_map, 'logit_map_', '/1_logit_map_',
                 'Original Logit Map'),
                (attn_1, attn_1_map, 'attn_1_map_', '/2_attn_map_1_',
                 'Attention Map 1'),
                (attn_2, attn_2_map, 'attn_2_map_', '/2_attn_map_2_',
                 'Attention Map 2'),
                (final_evidence, final_evidence_map, 'final_evidence_map_',
                 '/3_final_evidence_', 'Final Evidence'),
            )
            for sample in range(datas.shape[0]):
                ut.make_dir(dir=tmp_save_dir, flag_rm=False)
                tag = '{0}_gt_{1}_pred_{2}'.format(
                    count,
                    st.list_selected_for_test[int(labels[sample])],
                    st.list_selected_for_test[int(pred[sample, 0])])
                orig_img = datas[sample][0]
                ut.save_featureMap_tensor(orig_img, tmp_save_dir,
                                          'input_' + tag)

                for source, dense, file_prefix, plot_prefix, title in outputs:
                    if source is None:
                        continue
                    heatmap_img = dense[sample][0]
                    ut.save_featureMap_numpy(heatmap_img, tmp_save_dir,
                                             file_prefix + tag)
                    ut.plot_heatmap_with_overlay(
                        orig_img=orig_img,
                        heatmap_img=heatmap_img,
                        save_dir=tmp_save_dir + plot_prefix +
                        '{}'.format(count),
                        fig_title=title,
                        thresh=0.2,
                        percentile=1)
Ejemplo n.º 7
0
def _prepare_orig_numpy():
    """Build the original numpy dataset selected by ``st.data_type_num``."""
    print('preparation of the numpy')
    os.makedirs(st.orig_npy_dir, exist_ok=True)
    data_type = st.list_data_type[st.data_type_num]
    # Exact names are checked before the substring matches below, because
    # 'ADNI_Jacob_256' would otherwise also match the 'ADNI_Jacob' branch.
    if data_type == 'Density':
        cDL.Prepare_data_GM_AGE_MMSE()
    elif data_type == 'ADNI_JSY':
        jDL.Prepare_data_1()
    elif data_type == 'ADNI_Jacob_256':
        jcDL.Prepare_data_GM_WM_CSF()
    elif 'ADNI_Jacob' in data_type:
        jcDL.Prepare_data_GM()
    elif 'ADNI_AAL_256' in data_type:
        aDL.Prepare_data_GM()


def _make_fold_dirs(fold):
    """Create every output directory used for ``fold``.

    Returns:
        (list_dir_save_model, list_dir_confusion, dir_pyplot): the per-eval-
        standard weight and confusion directories, plus the pyplot directory.
        The weights_2 / age_pred / heatmap / MMSE_dist directories are
        created as well, but main() does not use them itself.
    """
    list_dir_save_model = []
    list_dir_confusion = []
    for eval_dir in st.list_standard_eval_dir:
        base = st.dir_to_save_1 + eval_dir

        dir_save_model = base + '/weights/fold_{}'.format(fold)
        ut.make_dir(dir=dir_save_model, flag_rm=False)
        list_dir_save_model.append(dir_save_model)

        ut.make_dir(dir=base + '/weights_2/fold_{}'.format(fold), flag_rm=False)

        dir_confusion = base + '/confusion'
        ut.make_dir(dir=dir_confusion, flag_rm=False)
        list_dir_confusion.append(dir_confusion)

        ut.make_dir(dir=base + '/age_pred', flag_rm=False)
        ut.make_dir(dir=base + '/heatmap', flag_rm=False)

    dir_pyplot = st.dir_to_save_1 + '/pyplot/fold_{}'.format(fold)
    ut.make_dir(dir=dir_pyplot, flag_rm=False)
    ut.make_dir(dir=st.dir_to_save_1 + '/MMSE_dist', flag_rm=False)
    return list_dir_save_model, list_dir_confusion, dir_pyplot


def _load_pretrained_weights(model, fold):
    """Copy the matching entries of the fold's pretrained checkpoint into ``model``.

    The checkpoint is looked up in ``st.dir_preTrain_1 + '/weights/fold_<fold>'``.
    Only keys that also exist in ``model.state_dict()`` are loaded (and
    printed), so a partially matching architecture can be initialised.

    Raises:
        FileNotFoundError: if no checkpoint is found for this fold.
    """
    dir_load_model = st.dir_preTrain_1 + '/weights/fold_{}'.format(fold)
    model_dir = ut.model_dir_to_load(fold, dir_load_model)
    if model_dir is None:
        raise FileNotFoundError(
            'no pretrained weights found in {}'.format(dir_load_model))
    pretrained_dict = torch.load(model_dir)
    model_dict = model.state_dict()
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    for k in pretrained_dict:
        print(k)
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)


def _build_optimizer_and_scheduler(config, model):
    """Return Adam over ``model`` and a warm-up scheduler chained into cosine annealing."""
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config.lr,
                                 betas=(0.9, 0.999),
                                 eps=1e-8,
                                 weight_decay=st.weight_decay)
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, st.epoch)
    scheduler = GradualWarmupScheduler(optimizer,
                                       multiplier=1,
                                       total_epoch=5,
                                       after_scheduler=scheduler_cosine)
    return optimizer, scheduler


def _build_data_loaders(config, fold):
    """Return the train/val/test loaders of ``fold`` keyed by split name."""

    def _loader(split, list_class, shuffle, drop_last):
        return DL.convert_Dloader_3(fold,
                                    list_class=list_class,
                                    flag_tr_val_te=split,
                                    batch_size=config.batch_size,
                                    num_workers=0,
                                    shuffle=shuffle,
                                    drop_last=drop_last)

    # Only the training split is shuffled and drops its incomplete last batch.
    return {
        'train': _loader('train', st.list_class_for_train, True, True),
        'val': _loader('val', st.list_class_for_test, False, False),
        'test': _loader('test', st.list_class_for_test, False, False),
    }


def _write_fold_results(fold, end_fold, list_test_result, list_ws, list_wb,
                        list_dir_result, metric_avg):
    """Write one fold's metrics and the running mean ± std into the workbooks.

    Args:
        fold: current fold; its metric values go in column ``fold + 1``.
        end_fold: last fold; the mean ± std goes in column ``end_fold + 2``.
        list_test_result: one metric dict per evaluation standard.
        list_ws, list_wb, list_dir_result: per-standard worksheet, workbook
            and output directory.
        metric_avg: ``metric_avg[i_standard][i_metric]`` lists the values of
            earlier folds; appended to in place.
    """
    if not list_test_result:
        # No task branch ran (no task flag enabled): nothing to record.
        print('no test results for fold {}; skipping the excel update'.format(
            fold))
        return
    list_eval_metric = st.list_eval_metric
    for i_standard, dict_result in enumerate(list_test_result):
        ws = list_ws[i_standard]
        for i, metric in enumerate(list_eval_metric):
            if metric in dict_result:
                ws.cell(row=2 + i + st.push_start_row,
                        column=fold + 1,
                        value="%.4f" % (dict_result[metric]))
                metric_avg[i_standard][i].append(dict_result[metric])

        for i, values in enumerate(metric_avg[i_standard]):
            if values:
                avg = round(np.mean(values), 4)
                std = round(np.std(values), 4)
                ws.cell(row=2 + st.push_start_row + i,
                        column=end_fold + 2,
                        value="%.4f \u00B1 %.4f" % (avg, std))

        # Saved after every fold so partial results survive a crash.
        list_wb[i_standard].save(list_dir_result[i_standard] +
                                 "/results.xlsx")


def _center_align_and_close(ws, wb, path):
    """Center-align every used cell of ``ws``, then save ``wb`` to ``path`` and close it."""
    n_row = ws.max_row
    n_col = ws.max_column
    for i_row in range(1, n_row + 1):
        for i_col in range(1, n_col + 1):
            ws.cell(row=i_row, column=i_col).alignment = Alignment(
                horizontal='center', vertical='center')
    wb.save(path)
    wb.close()


def main(config):
    """Run the cross-validation experiment configured by ``st`` and ``fst``.

    Steps:
    1. Optionally prepare the numpy data and the fold indices.
    2. For every fold in ``[st.start_fold, st.end_fold]``:
       - create the output directories;
       - build ``model_1``, optionally initialised from ``st.dir_preTrain_1``;
       - build the optimizer/scheduler and the data loaders;
       - train according to the enabled task flag;
       - test once per entry of ``st.list_standard_eval_dir``;
       - record the metrics (plus running mean ± std) in one
         ``results.xlsx`` per evaluation standard.

    Args:
        config: run configuration. ``config.lr`` and ``config.batch_size``
            are read here, and the object is forwarded to the train/test
            functions.
    """
    # 1. data process
    if fst.flag_orig_npy:
        _prepare_orig_numpy()

    if fst.flag_orig_npy_other_dataset:
        for dataset in ('ABIDE', 'ICBM', 'Cam', 'IXI', 'PPMI'):
            cDL.Prepare_data_GM_age_others(dataset=dataset)

    # 2. fold index processing
    if fst.flag_fold_index:
        print('preparation of the fold index')
        os.makedirs(st.fold_index_dir, exist_ok=True)
        ut.preparation_fold_index(config)

    start_fold = st.start_fold
    end_fold = st.end_fold

    # One result directory / workbook / worksheet per evaluation standard.
    list_dir_result = []
    list_wb = []
    list_ws = []
    for eval_dir in st.list_standard_eval_dir:
        dir_result = st.dir_to_save_1 + eval_dir
        ut.make_dir(dir=dir_result, flag_rm=False)
        out = ut.excel_setting(start_fold=start_fold,
                               end_fold=end_fold,
                               result_dir=dir_result,
                               f_name='results')
        list_dir_result.append(dir_result)
        list_wb.append(out[0])
        list_ws.append(out[1])

    # metric_avg[i_standard][i_metric] collects that metric across folds.
    metric_avg = [[[] for _ in st.list_eval_metric]
                  for _ in st.list_standard_eval_dir]

    for fold in range(start_fold, end_fold + 1):
        print("FOLD : {}".format(fold))

        print('-' * 10 + 'Directory preparation' + '-' * 10)
        list_dir_save_model, list_dir_confusion, dir_pyplot = _make_fold_dirs(
            fold)

        print('-' * 10 + 'Model construction' + '-' * 10)
        model_1 = nn.DataParallel(
            construct_model.construct_model(config, flag_model_num=0))
        model_2 = None
        if fst.flag_classification_fine_tune:
            _load_pretrained_weights(model_1, fold)
        elif fst.flag_classification_using_pretrained:
            # model_2 is the pretrained (frozen, eval-mode) helper model.
            model_2 = nn.DataParallel(
                construct_model.construct_model(config, flag_model_num=1))
            _load_pretrained_weights(model_2, fold)
            model_2.eval()

        optimizer, scheduler = _build_optimizer_and_scheduler(config, model_1)

        print('-' * 10 + 'data loader' + '-' * 10)
        dict_data_loader = _build_data_loaders(config, fold)

        list_test_result = []
        print('-' * 10 + 'start training' + '-' * 10)
        if fst.flag_classification or fst.flag_classification_fine_tune:
            train.train(config,
                        fold,
                        model_1,
                        dict_data_loader,
                        optimizer,
                        scheduler,
                        list_dir_save_model,
                        dir_pyplot,
                        Validation=True,
                        Test_flag=True)
            for dir_save_model, dir_confusion in zip(list_dir_save_model,
                                                     list_dir_confusion):
                list_test_result.append(
                    test.test(config, fold, model_1, dict_data_loader['test'],
                              dir_save_model, dir_confusion))

        elif fst.flag_classification_using_pretrained:
            # using pretrained patch level model
            train_using_pretrained.train(config,
                                         fold,
                                         model_1,
                                         model_2,
                                         dict_data_loader,
                                         optimizer,
                                         scheduler,
                                         list_dir_save_model,
                                         dir_pyplot,
                                         Validation=True,
                                         Test_flag=True)
            for dir_save_model, dir_confusion in zip(list_dir_save_model,
                                                     list_dir_confusion):
                list_test_result.append(
                    test_using_pretrained.test(config, fold, model_1, model_2,
                                               dict_data_loader['test'],
                                               dir_save_model, dir_confusion))

        elif fst.flag_multi_task:
            train_multi_task.train(config,
                                   fold,
                                   model_1,
                                   dict_data_loader,
                                   optimizer,
                                   scheduler,
                                   list_dir_save_model,
                                   dir_pyplot,
                                   Validation=True,
                                   Test_flag=True)
            for dir_save_model, dir_confusion in zip(list_dir_save_model,
                                                     list_dir_confusion):
                list_test_result.append(
                    test_multi_task.test(config, fold, model_1,
                                         dict_data_loader['test'],
                                         dir_save_model, dir_confusion))

        _write_fold_results(fold, end_fold, list_test_result, list_ws,
                            list_wb, list_dir_result, metric_avg)

    for ws, wb, dir_result in zip(list_ws, list_wb, list_dir_result):
        _center_align_and_close(ws, wb, dir_result + "/results.xlsx")

    print("finished")
def _stack_batch_outputs(list_batches):
    """Concatenate per-batch numpy arrays into one flat 1-D array.

    ``np.vstack`` of squeezed batches fails when the last batch is smaller
    than the others, or holds one sample (which squeezes to 0-d). Flattening
    each batch first works for any batch sizes. This assumes one value per
    sample (e.g. preds of shape (B,) or (B, 1)) — TODO confirm for 'preds'.
    """
    if not list_batches:
        return np.empty(0)
    return np.concatenate([np.ravel(batch) for batch in list_batches])


def _predict_split(model, loader):
    """Run ``model`` over ``loader``.

    Returns:
        Flat arrays ``(ages, labels, preds)``, one entry per sample.
    """
    list_age = []
    list_lbl = []
    list_pred = []
    with torch.no_grad():
        for datas, labels, alabel, _mlabel in loader:
            dict_result = model(datas.cuda())
            list_pred.append(dict_result['preds'].data.cpu().numpy())
            # Labels and ages are only used on the host, so they are not
            # sent to the GPU (.cpu() is a no-op for host tensors).
            list_lbl.append(labels.long().cpu().numpy())
            list_age.append(alabel.float().cpu().numpy())
    return (_stack_batch_outputs(list_age), _stack_batch_outputs(list_lbl),
            _stack_batch_outputs(list_pred))


def _annotate_distribution(ax, x, y_top, title, values, gap, fontsize):
    """Write ``title`` and the min/max/mean/std of ``values`` as a text column.

    The title is drawn at ``y_top`` in ``fontsize + 5``. Each statistic is
    drawn ``gap`` below the previous line in ``fontsize``.
    """
    ax.text(x, y_top, title, fontsize=fontsize + 5)
    if values.size == 0:
        # min()/max() raise on an empty array (class absent from this split).
        ax.text(x, y_top - gap, 'no samples', fontsize=fontsize)
        return
    stats = (('min', values.min()), ('max', values.max()),
             ('mean', values.mean()), ('std', values.std()))
    for k, (name, value) in enumerate(stats, start=1):
        ax.text(x, y_top - k * gap, '{}: {:.2f}'.format(name, value),
                fontsize=fontsize)


def test(config, fold, model, dir_to_load, dir_age_pred):
    """Plot labeled vs. predicted age for every split and class of one fold.

    The function:
    - loads the weights returned by ``ut.model_dir_to_load(fold, dir_to_load)``;
    - runs ``model`` over non-shuffled train/val/test loaders built from
      ``st.list_class_for_total``;
    - draws a 3 x n_class grid of scatter plots (labeled age on x,
      ``model(datas)['preds']`` on y), each annotated with min/max/mean/std
      of both ages.

    The figure is saved once to
    ``dir_age_pred + '/fold<fold>_age_prediction.png'``.

    Raises:
        FileNotFoundError: if no model weights are found in ``dir_to_load``.
    """
    # free all GPU memory
    torch.cuda.empty_cache()

    # non-shuffled loaders over all classes for every split
    list_loader = ('train', 'val', 'test')
    dict_loader = {}
    for split in list_loader:
        list_data = DL.concat_class_of_interest(
            config,
            fold,
            list_class=st.list_class_for_total,
            flag_tr_val_te=split)
        dict_loader[split] = DL.convert_Dloader_3(config.v_batch_size,
                                                  list_data[0],
                                                  list_data[1],
                                                  list_data[2],
                                                  list_data[3],
                                                  is_training=False,
                                                  num_workers=1,
                                                  shuffle=False)
        del list_data

    # load the model
    model_dir = ut.model_dir_to_load(fold, dir_to_load)
    if model_dir is None:
        raise FileNotFoundError(
            'no model weights found in {}'.format(dir_to_load))
    model.load_state_dict(torch.load(model_dir))
    model.eval()

    n_class = len(st.list_selected_for_total)
    fig = plt.figure(figsize=(n_class * 12, 25))
    try:
        # Global matplotlib setting: it also affects figures created later.
        plt.rcParams.update({'font.size': 22})

        if fst.flag_estimate_age:
            fig.suptitle(
                'Comparing between labeled and predicted ages in fold{0} ({1})'.
                format(fold,
                       st.list_age_estimating_function[st.selected_function]),
                fontsize=50)
        else:
            fig.suptitle(
                'Comparing between labeled and predicted ages in fold{0}'.
                format(fold),
                fontsize=50)

        # Each split uses a tall plot row plus a thin spacer row. Each class
        # uses a wide plot column plus a narrow column for the text stats.
        gs = gridspec.GridSpec(
            nrows=6,
            ncols=n_class * 2,
            height_ratios=[10, 2, 10, 2, 10, 2],
            width_ratios=[10, 3] * n_class)
        age_left = 50
        age_right = 110
        pred_left = 50
        pred_right = 110
        gap_1 = 4
        gap_2 = 10
        text_fontsize = 15

        for i_loader, dataset in enumerate(list_loader):
            np_age, np_lbl, np_pred = _predict_split(model,
                                                     dict_loader[dataset])

            for j_disease, disease in enumerate(st.list_selected_for_total):
                mask = np_lbl == j_disease
                ages = np_age[mask]
                preds = np_pred[mask]

                ax1 = fig.add_subplot(gs[i_loader * 2, j_disease * 2])
                ax1.scatter(ages, preds)
                # identity line: points on it are perfectly predicted
                ax1.plot(range(age_left, age_right),
                         range(age_left, age_right))
                ax1.set_title('{}  {}'.format(dataset, disease), fontsize=25)
                ax1.set_xlim([age_left, age_right])
                ax1.set_ylim([pred_left, pred_right])
                ax1.grid(True)
                ax1.set_ylabel('predicted age')
                ax1.set_xlabel('labeled age')

                # Text stats go to the right of the axes. The
                # predicted-age column starts below the labeled-age column
                # (title + 4 stat lines) with an extra gap_2 spacing.
                _annotate_distribution(ax1, age_right + 1, pred_right,
                                       'labeled age', ages, gap_1,
                                       text_fontsize)
                _annotate_distribution(ax1, age_right + 1,
                                       pred_right - 5 * gap_1 - gap_2,
                                       'pred age', preds, gap_1,
                                       text_fontsize)

        # save the figure once, after every subplot has been drawn
        plt.savefig(dir_age_pred + '/fold{}_age_prediction.png'.format(fold))
    finally:
        # close all plots, even if plotting failed part-way
        plt.close('all')