Пример #1
0
def evaluate_dice_score(model_path, num_classes, data_dir, label_dir, volumes_txt_file, remap_config, orientation,
                        prediction_path, data_id, device=0, logWriter=None, mode='eval'):
    """Segment every listed volume with a saved model, save NIfTI predictions and score per-class Dice.

    Args:
        model_path: path to a model object loadable with ``torch.load``.
        num_classes: number of classes; passed to ``dice_score_perclass`` and
            used to split the scores into one list per class.
        data_dir, label_dir, data_id: forwarded to ``du.load_file_paths``.
        volumes_txt_file: text file with one volume id per line; forwarded to
            ``du.load_file_paths`` and used as the names given to ``logWriter``.
        remap_config: forwarded to ``du.load_and_preprocess`` and to
            ``preprocessor.remap_labels_back`` before saving.
        orientation: slicing orientation; "COR" and "AXI" predictions have
            their axes reordered before saving.
        prediction_path: output directory (created if missing). Each
            prediction is saved under the input's basename with ".nii"
            replaced by "_seg1.nii".
        device: CUDA device index (int) or anything ``torch.device`` accepts
            (e.g. ``'cpu'``). An int falls back to the CPU when CUDA is
            unavailable.
        logWriter: optional writer exposing ``plot_dice_score`` and
            ``plot_eval_box_plot``.
        mode: forwarded to ``dice_score_perclass``.

    Returns:
        ``(avg_dice_score, class_dist)``: the mean Dice over all volumes and
        classes, and a list holding, per class, the array of per-volume scores.
    """
    log.info("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**")

    # BORIS: a batch size of 20 did not fit in memory.
    batch_size = 10

    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    cuda_available = torch.cuda.is_available()
    # An integer device means "this CUDA device"; without CUDA, warn and fall back to the CPU.
    if isinstance(device, int) and not cuda_available:
        log.warning(
            'CUDA is not available, trying with CPU.'
            'This can take much longer (> 1 hour). Cancel and '
            'investigate if this behavior is not desired.'
        )
        device = 'cpu'

    if isinstance(device, int):
        model = torch.load(model_path)
        torch.cuda.empty_cache()
        model.cuda(device)
        run_device = torch.device('cuda', device)
    else:
        # Covers 'cpu' as well as explicit device strings such as 'cuda:1'.
        run_device = torch.device(device)
        model = torch.load(model_path, map_location=run_device)

    model.eval()

    common_utils.create_if_not(prediction_path)
    volume_dice_score_list = []
    log.info("Evaluating now...")
    file_paths = du.load_file_paths(data_dir, label_dir, data_id, volumes_txt_file)
    with torch.no_grad():
        for vol_idx, file_path in enumerate(file_paths):
            volume, labelmap, _, _, header = du.load_and_preprocess(file_path,
                                                                    orientation=orientation,
                                                                    remap_config=remap_config)

            # Add a channel axis when the volume comes back without one.
            volume = volume if len(volume.shape) == 4 else volume[:, np.newaxis, :, :]
            volume = torch.tensor(volume).type(torch.FloatTensor)
            labelmap = torch.tensor(labelmap).type(torch.LongTensor)

            volume_prediction = []
            for i in range(0, len(volume), batch_size):
                # Batches must live on the model's device; previously a string CUDA
                # device left them on the CPU while the model was on the GPU.
                batch_x = volume[i: i + batch_size].to(run_device)
                out = model(batch_x)
                _, batch_output = torch.max(out, dim=1)
                volume_prediction.append(batch_output)

            volume_prediction = torch.cat(volume_prediction)
            # Put the ground truth next to the predictions; labelmap.cuda(device) crashed on the CPU path.
            volume_dice_score = dice_score_perclass(volume_prediction, labelmap.to(run_device), num_classes,
                                                    mode=mode)

            volume_prediction = (volume_prediction.cpu().numpy()).astype('float32')

            # Build the output affine from the input header's srow_x/y/z rows.
            Mat = np.array([
                header['srow_x'],
                header['srow_y'],
                header['srow_z'],
                [0, 0, 0, 1]
            ])

            volume_prediction = np.squeeze(volume_prediction)

            volume_prediction = preprocessor.remap_labels_back(volume_prediction, remap_config)  # BORIS

            # BORIS: reorder axes for COR/AXI slicing (presumably restores the input's
            # axis order produced by du.load_and_preprocess -- TODO confirm).
            if orientation == "COR":
                volume_prediction = volume_prediction.transpose((1, 2, 0))
            elif orientation == "AXI":
                volume_prediction = volume_prediction.transpose((2, 0, 1))

            # Apply original image affine to prediction volume
            nifti_img = nib.Nifti1Image(volume_prediction, Mat, header=header)  # BORIS
            outputfilename = os.path.join(prediction_path,
                                          os.path.basename(file_path[0]).replace(".nii", "_seg1.nii"))  # BORIS
            nib.save(nifti_img, outputfilename)

            if logWriter:
                logWriter.plot_dice_score('val', 'eval_dice_score', volume_dice_score, volumes_to_use[vol_idx], vol_idx)

            volume_dice_score = volume_dice_score.cpu().numpy()
            volume_dice_score_list.append(volume_dice_score)
            # Build the message first: passing the array as the format string plus an extra
            # positional argument makes stdlib logging fail while formatting the record.
            log.info(str(volume_dice_score) + " " + str(np.mean(volume_dice_score)))
        dice_score_arr = np.asarray(volume_dice_score_list)
        avg_dice_score = np.mean(dice_score_arr)
        log.info("Mean of dice score : " + str(avg_dice_score))
        class_dist = [dice_score_arr[:, c] for c in range(num_classes)]

        if logWriter:
            logWriter.plot_eval_box_plot('eval_dice_score_box_plot', class_dist, 'Box plot Dice Score')
    log.info("DONE")

    return avg_dice_score, class_dist
Пример #2
0
def evaluate_dice_score(model_path, num_classes, data_dir, label_dir, volumes_txt_file, remap_config, orientation,
                        prediction_path, data_id, device=0, logWriter=None, mode='eval'):
    """Segment every listed volume with a saved model, save MGH predictions and score per-class Dice.

    Args:
        model_path: path to a model object loadable with ``torch.load``.
        num_classes: number of classes; passed to ``dice_score_perclass`` and
            used to split the scores into one list per class.
        data_dir, label_dir, data_id: forwarded to ``du.load_file_paths``.
        volumes_txt_file: text file with one volume id per line; the ids name
            the saved ``<id>.mgz`` predictions and the ``logWriter`` entries.
        remap_config, orientation: forwarded to ``du.load_and_preprocess``.
        prediction_path: output directory, created if missing.
        device: CUDA device index used when CUDA is available.
        logWriter: optional writer exposing ``plot_dice_score`` and
            ``plot_eval_box_plot``.
        mode: forwarded to ``dice_score_perclass``.

    Returns:
        ``(avg_dice_score, class_dist)``: the mean Dice over all volumes and
        classes, and a list holding, per class, the array of per-volume scores.
    """
    print("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**")

    batch_size = 20

    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    cuda_available = torch.cuda.is_available()
    # Without CUDA, map GPU-saved tensors onto the CPU instead of failing inside torch.load.
    model = torch.load(model_path, map_location=None if cuda_available else 'cpu')
    if cuda_available:
        torch.cuda.empty_cache()
        model.cuda(device)

    model.eval()

    common_utils.create_if_not(prediction_path)
    volume_dice_score_list = []
    print("Evaluating now...")
    file_paths = du.load_file_paths(data_dir, label_dir, data_id, volumes_txt_file)
    with torch.no_grad():
        for vol_idx, file_path in enumerate(file_paths):
            volume, labelmap, _, _, header = du.load_and_preprocess(file_path,
                                                                    orientation=orientation,
                                                                    remap_config=remap_config)

            # Add a channel axis when the volume comes back without one.
            volume = volume if len(volume.shape) == 4 else volume[:, np.newaxis, :, :]
            volume = torch.tensor(volume).type(torch.FloatTensor)
            labelmap = torch.tensor(labelmap).type(torch.LongTensor)

            volume_prediction = []
            for i in range(0, len(volume), batch_size):
                batch_x = volume[i: i + batch_size]
                if cuda_available:
                    batch_x = batch_x.cuda(device)
                out = model(batch_x)
                _, batch_output = torch.max(out, dim=1)
                volume_prediction.append(batch_output)

            volume_prediction = torch.cat(volume_prediction)
            # Put the ground truth next to the predictions; labelmap.cuda(device) crashed without CUDA.
            volume_dice_score = dice_score_perclass(volume_prediction, labelmap.to(volume_prediction.device),
                                                    num_classes, mode=mode)

            volume_prediction = (volume_prediction.cpu().numpy()).astype('float32')
            nifti_img = nib.MGHImage(np.squeeze(volume_prediction), np.eye(4), header=header)
            nib.save(nifti_img, os.path.join(prediction_path, volumes_to_use[vol_idx] + str('.mgz')))
            if logWriter:
                logWriter.plot_dice_score('val', 'eval_dice_score', volume_dice_score, volumes_to_use[vol_idx], vol_idx)

            volume_dice_score = volume_dice_score.cpu().numpy()
            volume_dice_score_list.append(volume_dice_score)
            print(volume_dice_score, np.mean(volume_dice_score))
        dice_score_arr = np.asarray(volume_dice_score_list)
        avg_dice_score = np.mean(dice_score_arr)
        print("Mean of dice score : " + str(avg_dice_score))
        class_dist = [dice_score_arr[:, c] for c in range(num_classes)]

        if logWriter:
            logWriter.plot_eval_box_plot('eval_dice_score_box_plot', class_dist, 'Box plot Dice Score')
    print("DONE")

    return avg_dice_score, class_dist
Пример #3
0
def evaluate_dice_score(model_path,
                        num_classes,
                        query_labels,
                        data_dir,
                        query_txt_file,
                        support_txt_file,
                        remap_config,
                        orientation,
                        prediction_path, device=0, logWriter=None, mode='eval', fold=None):
    """Evaluate a few-shot (conditioner + segmentor) model with binary Dice per query label.

    For each label in ``query_labels``:

    * the first support path is loaded and passed through ``binarize_label``
      for that label, and ``Num_support`` conditioning slices are picked at
      evenly spaced positions;
    * every query volume is cropped to the slice range ``get_range`` reports
      for the label and split into ``Num_support`` consecutive chunks; each
      chunk is segmented with the matching conditioning slice;
    * binary Dice is computed per query volume, and the median over volumes
      is recorded for the label.

    Predictions, query input/ground truth and support input/ground truth are
    written as ``.mgz`` files (identity affine) into ``prediction_path``.

    Args:
        model_path: file loaded with ``torch.load``; the object must provide
            ``conditioner`` and ``segmentor``.
        num_classes: unused; kept for interface compatibility.
        query_labels: label values to evaluate, one at a time.
        data_dir: directory passed to ``du.load_file_paths``.
        query_txt_file: text file of query volume ids (one per line); the ids
            also name the output files.
        support_txt_file: text file of support volumes; only the first is used.
        remap_config, orientation: forwarded to ``du.load_and_preprocess``.
        prediction_path: output directory, created if missing.
        device: CUDA device index used when CUDA is available.
        logWriter: unused; kept for interface compatibility.
        mode: forwarded to ``dice_score_binary`` as ``phase``.
        fold: tag used in output file names and the progress message;
            ``None`` is treated as an empty tag.

    Returns:
        Mean over ``query_labels`` of the per-label median Dice score.
    """
    print("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**")
    print("Loading model => " + model_path)
    Num_support = 10
    # Each query chunk is pushed through the network in sub-batches of this many slices.
    sub_batch_size = 10
    # fold defaults to None, which used to crash every string concatenation below.
    fold_tag = '' if fold is None else str(fold)

    with open(query_txt_file) as file_handle:
        volumes_query = file_handle.read().splitlines()

    cuda_available = torch.cuda.is_available()
    # Without CUDA, map GPU-saved tensors onto the CPU instead of failing inside torch.load.
    model = torch.load(model_path, map_location=None if cuda_available else 'cpu')
    if cuda_available:
        torch.cuda.empty_cache()
        model.cuda(device)

    model.eval()

    common_utils.create_if_not(prediction_path)

    print("Evaluating now... " + fold_tag)
    query_file_paths = du.load_file_paths(data_dir, data_dir, query_txt_file)
    support_file_paths = du.load_file_paths(data_dir, data_dir, support_txt_file)

    with torch.no_grad():
        all_query_dice_score_list = []
        for query_label in query_labels:
            volume_dice_score_list = []

            # Loading support (only the first support path is used)
            support_volume, support_labelmap, _, _ = du.load_and_preprocess(support_file_paths[0],
                                                                            orientation=orientation,
                                                                            remap_config=remap_config)
            if len(support_volume.shape) != 4:
                support_volume = support_volume[:, np.newaxis, :, :]
            support_volume = torch.tensor(support_volume).type(torch.FloatTensor)
            support_labelmap = torch.tensor(support_labelmap).type(torch.LongTensor)

            support_volume, _ = binarize_label(support_volume, support_labelmap, query_label)

            # Num_support evenly spaced positions, each shifted forward by half a chunk length.
            # linspace(..., Num_support + 1)[:-1] always yields exactly Num_support entries, so
            # no padding is needed (the old ndarray.append padding could never run and would
            # have raised AttributeError).
            support_slice_indexes = np.round(np.linspace(0, len(support_volume) - 1, Num_support + 1)).astype(int)
            support_slice_indexes += (len(support_volume) // Num_support) // 2
            support_slice_indexes = support_slice_indexes[:-1]

            for vol_idx, file_path in enumerate(query_file_paths):

                query_volume, query_labelmap, _, _ = du.load_and_preprocess(file_path,
                                                                            orientation=orientation,
                                                                            remap_config=remap_config)

                if len(query_volume.shape) != 4:
                    query_volume = query_volume[:, np.newaxis, :, :]
                query_volume = torch.tensor(query_volume).type(torch.FloatTensor)
                query_labelmap = torch.tensor(query_labelmap).type(torch.LongTensor)

                # Binary target for this label, cropped to the slice range reported by get_range.
                query_labelmap = query_labelmap == query_label
                range_query = get_range(query_labelmap)
                query_volume = query_volume[range_query[0]: range_query[1] + 1]
                query_labelmap = query_labelmap[range_query[0]: range_query[1] + 1]

                # Exactly Num_support chunk start indexes; values repeat when the cropped
                # volume has fewer than Num_support slices.
                query_slice_indexes = np.round(np.linspace(0, len(query_volume) - 1, Num_support)).astype(int)
                last_chunk = len(query_slice_indexes) - 1

                volume_prediction = []
                for i, query_start_slice in enumerate(query_slice_indexes):
                    # Test the chunk's position, not its value: with repeated start indexes a value
                    # comparison made every duplicate of the last start take the whole tail,
                    # predicting those slices several times.
                    if i == last_chunk:
                        query_batch_x = query_volume[query_start_slice:]
                    else:
                        query_batch_x = query_volume[query_start_slice:query_slice_indexes[i + 1]]

                    support_batch_x = support_volume[support_slice_indexes[i]]

                    # Run larger chunks in smaller sub-batches.
                    for b in range(0, len(query_batch_x), sub_batch_size):
                        query_sub = query_batch_x[b:b + sub_batch_size]
                        support_sub = support_batch_x.repeat(len(query_sub), 1, 1, 1)
                        if cuda_available:
                            query_sub = query_sub.cuda(device)
                            support_sub = support_sub.cuda(device)

                        weights_sub = model.conditioner(support_sub)
                        out_sub = model.segmentor(query_sub, weights_sub)

                        _, batch_output_sub = torch.max(F.softmax(out_sub, dim=1), dim=1)
                        volume_prediction.append(batch_output_sub)

                volume_prediction = torch.cat(volume_prediction)

                # The slice is defensive: predictions should now match the ground-truth slice count.
                # The ground truth goes onto the predictions' device (.cuda(device) crashed without CUDA).
                volume_dice_score = dice_score_binary(volume_prediction[:len(query_labelmap)],
                                                      query_labelmap.to(volume_prediction.device), phase=mode)

                volume_prediction = (volume_prediction.cpu().numpy()).astype('float32')
                nifti_img = nib.MGHImage(np.squeeze(volume_prediction), np.eye(4))
                nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_' + fold_tag + str('.mgz')))

                # Save Input
                nifti_img = nib.MGHImage(np.squeeze(query_volume.cpu().numpy()), np.eye(4))
                nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_Input_' + str('.mgz')))

                # Condition Input
                nifti_img = nib.MGHImage(np.squeeze(support_volume.cpu().numpy()), np.eye(4))
                nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_CondInput_' + str('.mgz')))
                # Cond GT
                nifti_img = nib.MGHImage(np.squeeze(support_labelmap.cpu().numpy()).astype('float32'), np.eye(4))
                nib.save(nifti_img,
                         os.path.join(prediction_path, volumes_query[vol_idx] + '_CondInputGT_' + str('.mgz')))

                # Save Ground Truth; cast the boolean mask to float32 like the support GT above,
                # since MGH supports only a few numeric dtypes (bool may be rejected).
                nifti_img = nib.MGHImage(np.squeeze(query_labelmap.cpu().numpy()).astype('float32'), np.eye(4))
                nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_GT_' + fold_tag
                                                 + str('.mgz')))

                volume_dice_score = volume_dice_score.item()
                volume_dice_score_list.append(volume_dice_score)

                print(volume_dice_score)

            print(volume_dice_score_list)
            dice_score_arr = np.asarray(volume_dice_score_list)
            avg_dice_score = np.median(dice_score_arr)
            print('Query Label -> ' + str(query_label) + ' ' + str(avg_dice_score))
            all_query_dice_score_list.append(avg_dice_score)
    print("DONE")

    return np.mean(all_query_dice_score_list)
Пример #4
0
def evaluate_dice_score_3view(model1_path,
                              model2_path,
                              model3_path,
                              num_classes,
                              query_labels,
                              data_dir,
                              query_txt_file,
                              support_txt_file,
                              remap_config,
                              orientation1,
                              prediction_path, device=0, logWriter=None, mode='eval', fold=None):
    """Evaluate three orientation-specific few-shot models and fuse their predictions.

    Each volume is used in three orientations: the axis order returned by
    ``du.load_and_preprocess`` (with ``orientation1``) and its ``(1, 2, 0)``
    and ``(2, 0, 1)`` transposes, handled by ``model1``, ``model2`` and
    ``model3`` respectively. For each label in ``query_labels`` the support
    volume is passed through ``binarize_label`` in all three orientations.
    Every query volume is then segmented per orientation, the three softmax
    maps are permuted back to the first orientation and averaged, and binary
    Dice against the first-orientation ground truth is computed.

    Args:
        model1_path, model2_path, model3_path: files loaded with
            ``torch.load``; each object must provide ``conditioner`` and
            ``segmentor``.
        num_classes: unused; kept for interface compatibility.
        query_labels: label values to evaluate, one at a time.
        data_dir: directory passed to ``du.load_file_paths``.
        query_txt_file: text file of query volume ids (one per line); the ids
            name the saved ``<id>_<fold>.mgz`` predictions.
        support_txt_file: text file of support volumes; only the last listed
            one is used.
        remap_config: forwarded to ``du.load_and_preprocess``.
        orientation1: orientation passed to ``du.load_and_preprocess``.
        prediction_path: output directory, created if missing.
        device: CUDA device index used when CUDA is available.
        logWriter: unused; kept for interface compatibility.
        mode: forwarded to ``dice_score_binary`` as ``phase``.
        fold: tag used in output file names and the progress message;
            ``None`` is treated as an empty tag.

    Returns:
        Mean over ``query_labels`` of the per-label median Dice score.

    Raises:
        ValueError: if no support volume is listed.
    """
    print("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**")
    print("Loading model => " + model1_path + ", " + model2_path + " and " + model3_path)
    batch_size = 10
    # fold defaults to None, which used to crash every string concatenation below.
    fold_tag = '' if fold is None else str(fold)

    with open(query_txt_file) as file_handle:
        volumes_query = file_handle.read().splitlines()

    cuda_available = torch.cuda.is_available()
    # Without CUDA, map GPU-saved tensors onto the CPU instead of failing inside torch.load.
    map_location = None if cuda_available else 'cpu'
    models = [torch.load(path, map_location=map_location) for path in (model1_path, model2_path, model3_path)]
    if cuda_available:
        torch.cuda.empty_cache()
        for model in models:
            model.cuda(device)
    for model in models:
        model.eval()
    model1, model2, model3 = models

    common_utils.create_if_not(prediction_path)

    print("Evaluating now... " + fold_tag)
    query_file_paths = du.load_file_paths(data_dir, data_dir, query_txt_file)
    support_file_paths = du.load_file_paths(data_dir, data_dir, support_txt_file)

    def _to_views(volume, labelmap):
        """Return [(volume, labelmap)] tensors for the three orientations of one loaded volume."""
        views = []
        for axes in ((0, 1, 2), (1, 2, 0), (2, 0, 1)):
            view_volume, view_labelmap = volume.transpose(axes), labelmap.transpose(axes)
            # Add a channel axis when the view comes back without one.
            if len(view_volume.shape) != 4:
                view_volume = view_volume[:, np.newaxis, :, :]
            views.append((torch.tensor(view_volume).type(torch.FloatTensor),
                          torch.tensor(view_labelmap).type(torch.LongTensor)))
        return views

    def _predict_view(model, query_volume, support_volume):
        """Return the concatenated raw segmentor outputs for every slice of one orientation.

        Query slices go through in batches of ``batch_size``. On even-numbered
        batches the conditioning slice is refreshed to the last support slice
        of the matching chunk; odd-numbered batches reuse the previous one.
        """
        outputs = []
        support_slice = None
        for batch_idx, start in enumerate(range(0, len(query_volume), batch_size)):
            query_batch_x = query_volume[start:start + batch_size]
            if batch_idx % 2 == 0:
                # Clamp to the end of the support volume: indexing the chunk with
                # [batch_size - 1] raised IndexError whenever the chunk was shorter.
                support_slice = support_volume[min(start + batch_size, len(support_volume)) - 1]
            support_batch_x = support_slice.repeat(len(query_batch_x), 1, 1, 1)
            if cuda_available:
                query_batch_x = query_batch_x.cuda(device)
                support_batch_x = support_batch_x.cuda(device)

            weights = model.conditioner(support_batch_x)
            outputs.append(model.segmentor(query_batch_x, weights))
        return torch.cat(outputs)

    with torch.no_grad():
        all_query_dice_score_list = []
        for query_label in query_labels:
            volume_dice_score_list = []

            # Only the last support volume was ever used (earlier loads were overwritten),
            # so load just that one.
            if not support_file_paths:
                raise ValueError("No support volumes listed in " + str(support_txt_file))
            support_volume, support_labelmap, _, _ = du.load_and_preprocess(support_file_paths[-1],
                                                                            orientation=orientation1,
                                                                            remap_config=remap_config)
            support_views = [binarize_label(view_volume, view_labelmap, query_label)
                             for view_volume, view_labelmap in _to_views(support_volume, support_labelmap)]

            for vol_idx, file_path in enumerate(query_file_paths):
                query_volume, query_labelmap, _, _ = du.load_and_preprocess(file_path,
                                                                            orientation=orientation1,
                                                                            remap_config=remap_config)
                query_views = _to_views(query_volume, query_labelmap)

                # Ground truth is taken from the first orientation only.
                query_labelmap1 = query_views[0][1] == query_label

                volume_prediction1 = _predict_view(model1, query_views[0][0], support_views[0])
                volume_prediction2 = _predict_view(model2, query_views[1][0], support_views[1])
                volume_prediction3 = _predict_view(model3, query_views[2][0], support_views[2])

                # The permutes undo the (1, 2, 0) and (2, 0, 1) transposes so all three maps
                # align with the first orientation (assumes outputs are (slices, classes, rows, cols)).
                volume_prediction = 0.33 * F.softmax(volume_prediction1, dim=1) + 0.33 * F.softmax(
                    volume_prediction2.permute(3, 1, 0, 2), dim=1) + 0.33 * F.softmax(
                    volume_prediction3.permute(2, 1, 3, 0), dim=1)
                _, batch_output = torch.max(volume_prediction, dim=1)
                # Ground truth goes onto the predictions' device (.cuda(device) crashed without CUDA).
                volume_dice_score = dice_score_binary(batch_output, query_labelmap1.to(batch_output.device),
                                                      phase=mode)

                batch_output = (batch_output.cpu().numpy()).astype('float32')
                nifti_img = nib.MGHImage(np.squeeze(batch_output), np.eye(4))
                nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_' + fold_tag + str('.mgz')))

                volume_dice_score = volume_dice_score.cpu().numpy()
                volume_dice_score_list.append(volume_dice_score)

                print(volume_dice_score)

            dice_score_arr = np.asarray(volume_dice_score_list)
            avg_dice_score = np.median(dice_score_arr)
            print('Query Label -> ' + str(query_label) + ' ' + str(avg_dice_score))
            all_query_dice_score_list.append(avg_dice_score)
    print("DONE")

    return np.mean(all_query_dice_score_list)
Пример #5
0
def evaluate_dice_score(model_path,
                        num_classes,
                        query_labels,
                        data_dir,
                        query_txt_file,
                        support_txt_file,
                        remap_config,
                        orientation,
                        prediction_path,
                        device=0,
                        logWriter=None,
                        mode='eval',
                        fold=None,
                        num_support=15):
    """Evaluate a conditioner/segmentor few-shot model, one support slice per query chunk.

    For every label in ``query_labels``:

    * every support volume listed in ``support_txt_file`` is loaded, binarised
      for the label with ``binarize_label``, and roughly ``num_support`` evenly
      spaced slices are collected from it;
    * every query volume listed in ``query_txt_file`` is cropped to the slice
      range returned by ``get_range`` (presumably the slices containing the
      label — confirm against ``get_range``). It is then split into consecutive
      chunks, and chunk ``k`` is segmented with collected support slice ``k``
      as the conditioning input;
    * the per-volume Dice score is computed with ``dice_score_binary`` and the
      median over query volumes is taken.

    Args:
        model_path: path of a model saved with ``torch.save``; it must expose
            ``conditioner`` and ``segmentor`` sub-modules.
        num_classes: unused; kept for interface compatibility.
        query_labels: iterable of label ids to evaluate.
        data_dir: directory passed to ``du.load_file_paths`` for both lists.
        query_txt_file: file listing the query volumes.
        support_txt_file: file listing the support volumes.
        remap_config: forwarded to ``du.load_and_preprocess``.
        orientation: forwarded to ``du.load_and_preprocess``.
        prediction_path: output directory; created if it does not exist.
        device: CUDA device index, used only when CUDA is available.
        logWriter: unused.
        mode: forwarded as ``phase`` to ``dice_score_binary``.
        fold: fold name; only printed.
        num_support: target number of support slices taken per support volume.

    Returns:
        The mean over ``query_labels`` of the median per-volume Dice score.
    """
    print(
        "**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**"
    )
    print("Loading model => " + model_path)

    cuda_available = torch.cuda.is_available()
    # Without CUDA, remap GPU-saved tensors to the CPU so loading does not fail.
    model = torch.load(model_path,
                       map_location=None if cuda_available else 'cpu')
    if cuda_available:
        torch.cuda.empty_cache()
        model.cuda(device)
    model.eval()
    common_utils.create_if_not(prediction_path)

    # ``fold`` defaults to None, so it must be converted before concatenation.
    print("Evaluating now... " + str(fold))
    query_file_paths = du.load_file_paths(data_dir, data_dir, query_txt_file)
    support_file_paths = du.load_file_paths(data_dir, data_dir,
                                            support_txt_file)

    def to_device(tensor):
        # Move to the requested GPU only when one is available; stay on CPU otherwise.
        return tensor.cuda(device) if cuda_available else tensor

    with torch.no_grad():
        all_query_dice_score_list = []

        for query_label in query_labels:
            # Collect evenly spaced, label-binarised slices from every support volume.
            support_slices = []
            for file_path in support_file_paths:
                support_volume, support_labelmap, _, _ = du.load_and_preprocess(
                    file_path,
                    orientation=orientation,
                    remap_config=remap_config)
                if len(support_volume.shape) != 4:
                    # Add a singleton axis at position 1 (presumably channels).
                    support_volume = support_volume[:, np.newaxis, :, :]
                support_volume = torch.tensor(support_volume).type(
                    torch.FloatTensor)
                support_labelmap = torch.tensor(support_labelmap).type(
                    torch.LongTensor)

                support_volume, _ = binarize_label(support_volume,
                                                   support_labelmap,
                                                   query_label)

                slice_gap_support = int(
                    np.ceil(len(support_volume) / num_support))
                support_slice_indexes = list(
                    range(0, len(support_volume), slice_gap_support))
                if len(support_slice_indexes) < num_support:
                    # Make sure the last slice is also represented.
                    support_slice_indexes.append(len(support_volume) - 1)

                support_slices.extend(
                    support_volume[idx] for idx in support_slice_indexes)

            volume_dice_score_list = []
            for file_path in query_file_paths:
                query_volume, query_labelmap, _, _ = du.load_and_preprocess(
                    file_path,
                    orientation=orientation,
                    remap_config=remap_config)
                if len(query_volume.shape) != 4:
                    query_volume = query_volume[:, np.newaxis, :, :]
                query_volume = torch.tensor(query_volume).type(
                    torch.FloatTensor)
                query_labelmap = torch.tensor(query_labelmap).type(
                    torch.LongTensor)

                query_labelmap = query_labelmap == query_label
                range_query = get_range(query_labelmap)
                query_volume = query_volume[range_query[0]:range_query[1] + 1]
                query_labelmap = query_labelmap[range_query[0]:
                                                range_query[1] + 1]

                # Split the query into ~num_support consecutive chunks; chunk k
                # is conditioned on the k-th collected support slice.
                slice_gap_query = int(np.ceil(len(query_volume) / num_support))
                batch_output_arr = []
                for support_slice_idx, start in enumerate(
                        range(0, len(query_volume), slice_gap_query)):
                    query_batch_x = query_volume[start:start + slice_gap_query]
                    support_batch_x = support_slices[support_slice_idx].repeat(
                        query_batch_x.size()[0], 1, 1, 1)

                    query_batch_x = to_device(query_batch_x)
                    support_batch_x = to_device(support_batch_x)

                    weights = model.conditioner(support_batch_x)
                    out = model.segmentor(query_batch_x, weights)
                    _, batch_output = torch.max(F.softmax(out, dim=1), dim=1)
                    batch_output_arr.append(batch_output)

                volume_output = torch.cat(batch_output_arr)
                volume_dice_score = dice_score_binary(
                    volume_output, to_device(query_labelmap), phase=mode)
                volume_dice_score_list.append(volume_dice_score.item())
                print(str(file_path), volume_dice_score)

            avg_dice_score = np.median(np.asarray(volume_dice_score_list))
            print(volume_dice_score_list)
            print('Query Label -> ' + str(query_label) + ' ' +
                  str(avg_dice_score))
            all_query_dice_score_list.append(avg_dice_score)

    print("DONE")

    return np.mean(all_query_dice_score_list)
Пример #6
0
def evaluate_dice_score(model_path,
                        num_classes,
                        query_labels,
                        data_dir,
                        query_txt_file,
                        support_txt_file,
                        remap_config,
                        orientation,
                        prediction_path,
                        device=0,
                        logWriter=None,
                        mode='eval',
                        fold=None,
                        num_support=10,
                        batch_size=20):
    """Evaluate a conditioner/segmentor few-shot model by averaging over support slices.

    For every label in ``query_labels``:

    * only the first support volume listed in ``support_txt_file`` is loaded
      and binarised for the label with ``binarize_label``. Channels 0 and 1 of
      the result are saved as ``SupportInput_.mgz`` / ``SupportGT_.mgz`` in
      ``prediction_path``, and these files are overwritten for each label;
    * roughly ``num_support`` evenly spaced support slices are selected;
    * every query volume is cropped to the slice range returned by
      ``get_range`` and segmented once per selected support slice, in batches
      of ``batch_size``. The segmentor outputs are averaged over support
      slices, and the per-pixel argmax of their softmax is the prediction;
    * the per-volume Dice score is computed with ``dice_score_binary`` and the
      median over query volumes is taken.

    Args:
        model_path: path of a model saved with ``torch.save``; it must expose
            ``conditioner`` and ``segmentor`` sub-modules.
        num_classes: unused; kept for interface compatibility.
        query_labels: iterable of label ids to evaluate.
        data_dir: directory passed to ``du.load_file_paths`` for both lists.
        query_txt_file: file listing the query volumes.
        support_txt_file: file listing the support volumes (only the first is used).
        remap_config: forwarded to ``du.load_and_preprocess``.
        orientation: forwarded to ``du.load_and_preprocess``.
        prediction_path: output directory; created if it does not exist.
        device: CUDA device index, used only when CUDA is available.
        logWriter: unused.
        mode: forwarded as ``phase`` to ``dice_score_binary``.
        fold: fold name; only printed.
        num_support: target number of support slices to average over.
        batch_size: number of query slices per forward pass.

    Returns:
        The mean over ``query_labels`` of the median per-volume Dice score.
    """
    print(
        "**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**"
    )
    print("Loading model => " + model_path)

    cuda_available = torch.cuda.is_available()
    # Without CUDA, remap GPU-saved tensors to the CPU so loading does not fail.
    model = torch.load(model_path,
                       map_location=None if cuda_available else 'cpu')
    if cuda_available:
        torch.cuda.empty_cache()
        model.cuda(device)
    model.eval()

    common_utils.create_if_not(prediction_path)

    # ``fold`` defaults to None, so it must be converted before concatenation.
    print("Evaluating now... " + str(fold))
    query_file_paths = du.load_file_paths(data_dir, data_dir, query_txt_file)
    support_file_paths = du.load_file_paths(data_dir, data_dir,
                                            support_txt_file)

    def to_device(tensor):
        # Move to the requested GPU only when one is available; stay on CPU otherwise.
        return tensor.cuda(device) if cuda_available else tensor

    with torch.no_grad():
        all_query_dice_score_list = []
        for query_label in query_labels:
            volume_dice_score_list = []

            # Loading support (only the first listed support volume is used).
            support_volume, support_labelmap, _, _ = du.load_and_preprocess(
                support_file_paths[0],
                orientation=orientation,
                remap_config=remap_config)
            if len(support_volume.shape) != 4:
                # Add a singleton axis at position 1 (presumably channels).
                support_volume = support_volume[:, np.newaxis, :, :]
            support_volume = torch.tensor(support_volume).type(
                torch.FloatTensor)
            support_labelmap = torch.tensor(support_labelmap).type(
                torch.LongTensor)

            support_volume, _ = binarize_label(support_volume,
                                               support_labelmap, query_label)

            # Save channels 0 and 1 of the binarised support for inspection
            # (presumably image and label; overwritten for every query label).
            nifti_img = nib.MGHImage(
                np.squeeze(support_volume[:, 0, :, :].cpu().numpy()),
                np.eye(4))
            nib.save(
                nifti_img,
                os.path.join(prediction_path, 'SupportInput_' + str('.mgz')))

            nifti_img = nib.MGHImage(
                np.squeeze(support_volume[:, 1, :, :].cpu().numpy()),
                np.eye(4))
            nib.save(nifti_img,
                     os.path.join(prediction_path, 'SupportGT_' + str('.mgz')))

            print("Saved")

            slice_gap_support = int(np.ceil(len(support_volume) / num_support))
            support_slice_indexes = list(
                range(0, len(support_volume), slice_gap_support))
            if len(support_slice_indexes) < num_support:
                # Make sure the last slice is also represented.
                support_slice_indexes.append(len(support_volume) - 1)

            for file_path in query_file_paths:
                query_volume, query_labelmap, _, _ = du.load_and_preprocess(
                    file_path,
                    orientation=orientation,
                    remap_config=remap_config)
                if len(query_volume.shape) != 4:
                    query_volume = query_volume[:, np.newaxis, :, :]
                query_volume = torch.tensor(query_volume).type(
                    torch.FloatTensor)
                query_labelmap = torch.tensor(query_labelmap).type(
                    torch.LongTensor)

                query_labelmap = query_labelmap == query_label
                range_query = get_range(query_labelmap)
                query_volume = query_volume[range_query[0]:range_query[1] + 1]
                query_labelmap = query_labelmap[range_query[0]:
                                                range_query[1] + 1]

                # One full pass over the query volume per selected support slice.
                vol_output = []
                for support_slice_idx in support_slice_indexes:
                    batch_output = []
                    for i in range(0, len(query_volume), batch_size):
                        query_batch_x = query_volume[i:i + batch_size]
                        support_batch_x = support_volume[
                            support_slice_idx].repeat(query_batch_x.size()[0],
                                                      1, 1, 1)
                        query_batch_x = to_device(query_batch_x)
                        support_batch_x = to_device(support_batch_x)
                        weights = model.conditioner(support_batch_x)
                        out = model.segmentor(query_batch_x, weights)
                        batch_output.append(out)
                    vol_output.append(torch.cat(batch_output))

                # Average the outputs over support slices, then take the argmax.
                vol_output = torch.mean(torch.stack(vol_output), dim=0)
                _, vol_output = torch.max(F.softmax(vol_output, dim=1), dim=1)

                volume_dice_score = dice_score_binary(
                    vol_output, to_device(query_labelmap), phase=mode)
                # Convert to a Python float: np.asarray cannot convert CUDA tensors.
                volume_dice_score = volume_dice_score.item()
                volume_dice_score_list.append(volume_dice_score)

                print(volume_dice_score)

            avg_dice_score = np.median(np.asarray(volume_dice_score_list))
            print('Query Label -> ' + str(query_label) + ' ' +
                  str(avg_dice_score))
            all_query_dice_score_list.append(avg_dice_score)

    print("DONE")

    return np.mean(all_query_dice_score_list)
def train(train_params, common_params, data_params, net_params):
    """Fine-tune a fresh classifier head of a pre-trained model on a few support slices.

    Data: one hard-coded support scan is loaded and binarised for a single
    label (``target_label``). Only slices where that label covers more than 10
    pixels are kept. ``n_support`` roughly evenly spaced slices among them form
    the training set, and the same dataset is also used for validation.

    Model: the network stored at ``train_params['pre_trained_path']`` is
    loaded and every parameter is frozen. Its ``classifier`` is then replaced
    by a new, trainable ``sm.ClassifierBlock``. A ``Solver`` trains it, and the
    best model is saved as
    ``<save_model_dir>/finetuned_segmentor_<fold>.pth.tar``.

    Note: mutates the caller's dicts. It sets ``train_params['exp_name']`` and
    forces ``net_params['num_channels'] = 64``. ``data_params`` is only read
    for ``'labels'``.
    """
    target_label = 8
    n_support = 10

    images, labels, _, _ = du.load_and_preprocess(
        "/home/deeplearning/Abhijit/nas_drive/Abhijit/WholeBody/CT_ce/Data/Visceral/10000132_1_CTce_ThAb.mat",
        orientation='AXI',
        remap_config="WholeBody")
    if len(images.shape) != 4:
        # Insert a singleton axis at position 1 (presumably the channel axis).
        images = images[:, np.newaxis, :, :]
    images = torch.tensor(images).type(torch.FloatTensor)
    labels = torch.tensor(labels).type(torch.LongTensor)

    # Float binary mask of the target label; drop slices where it covers <= 10 pixels.
    labels = (labels == target_label).type(torch.FloatTensor)
    n_slices, _, _ = labels.size()
    keep = labels.view(n_slices, -1).sum(dim=1) > 10
    labels = labels[keep]
    images = images[keep]

    # n_support + 1 evenly spaced anchors, shifted by half a gap; the last one is dropped.
    n_kept = len(images)
    anchors = np.round(np.linspace(0, n_kept - 1, n_support + 1)).astype(int)
    chosen = (anchors + (n_kept // n_support) // 2)[:-1]
    images = images[chosen]
    labels = labels[chosen]

    label_array = labels.numpy()
    train_data = ImdbData(images.numpy(), label_array,
                          np.ones_like(label_array))

    def make_loader(batch_size, shuffle):
        # Both the training and the validation loader read the same dataset.
        return torch.utils.data.DataLoader(train_data,
                                           batch_size=batch_size,
                                           shuffle=shuffle,
                                           num_workers=4,
                                           pin_memory=True)

    model_prefix = 'finetuned_segmentor_'
    for fold in ['fold4']:
        final_model_path = os.path.join(common_params['save_model_dir'],
                                        model_prefix + fold + '.pth.tar')
        train_params['exp_name'] = model_prefix + fold

        train_loader = make_loader(train_params['train_batch_size'], True)
        val_loader = make_loader(train_params['val_batch_size'], False)

        # Freeze the whole pre-trained network, then attach a trainable classifier head.
        few_shot_model = torch.load(train_params['pre_trained_path'])
        for weight in few_shot_model.parameters():
            weight.requires_grad_(False)
        net_params['num_channels'] = 64
        few_shot_model.classifier = sm.ClassifierBlock(net_params)
        for weight in few_shot_model.classifier.parameters():
            weight.requires_grad_(True)

        optim_args = {
            "lr": train_params['learning_rate'],
            "weight_decay": train_params['optim_weight_decay'],
            "momentum": train_params['momentum'],
        }
        solver = Solver(
            few_shot_model,
            device=common_params['device'],
            num_class=net_params['num_class'],
            optim_args=optim_args,
            model_name=common_params['model_name'],
            exp_name=train_params['exp_name'],
            labels=data_params['labels'],
            log_nth=train_params['log_nth'],
            num_epochs=train_params['num_epochs'],
            lr_scheduler_step_size=train_params['lr_scheduler_step_size'],
            lr_scheduler_gamma=train_params['lr_scheduler_gamma'],
            use_last_checkpoint=train_params['use_last_checkpoint'],
            log_dir=common_params['log_dir'],
            exp_dir=common_params['exp_dir'])

        solver.train(train_loader, val_loader)
        solver.save_best_model(final_model_path)
        print("final model saved @ " + str(final_model_path))