def TestMicroCalcificationReconstruction(args):
    """Evaluate a reconstruction network for micro-calcification detection.

    Loads the checkpoint selected by ``args.epoch_idx`` (the best
    validation checkpoint when ``args.epoch_idx`` < 0), runs inference over
    the whole dataset ``args.dataset_type``, accumulates dataset-level
    recall / false-positive counts, and writes qualitative visualizations
    plus a quantitative log under ``args.model_saving_dir``.
    """
    prediction_saving_dir = os.path.join(
        args.model_saving_dir,
        'reconstruction_results_dataset_{}_epoch_{}'.format(
            args.dataset_type, args.epoch_idx))
    visualization_saving_dir = os.path.join(prediction_saving_dir,
                                            'qualitative_results')
    visualization_TP_saving_dir = os.path.join(visualization_saving_dir,
                                               'TPs_only')
    visualization_FP_saving_dir = os.path.join(visualization_saving_dir,
                                               'FPs_only')
    visualization_FN_saving_dir = os.path.join(visualization_saving_dir,
                                               'FNs_only')
    visualization_FP_FN_saving_dir = os.path.join(visualization_saving_dir,
                                                  'FPs_FNs_both')

    # remove existing dir which has the same name and create a clean tree
    if os.path.exists(prediction_saving_dir):
        shutil.rmtree(prediction_saving_dir)
    for saving_dir in (prediction_saving_dir, visualization_saving_dir,
                       visualization_TP_saving_dir,
                       visualization_FP_saving_dir,
                       visualization_FN_saving_dir,
                       visualization_FP_FN_saving_dir):
        os.mkdir(saving_dir)

    # initialize logger
    logger = Logger(prediction_saving_dir, 'quantitative_results.txt')
    logger.write_and_print('Dataset: {}'.format(args.data_root_dir))
    logger.write_and_print('Dataset type: {}'.format(args.dataset_type))

    # define the network
    network = VNet2d(num_in_channels=cfg.net.in_channels,
                     num_out_channels=cfg.net.out_channels)

    ckpt_dir = os.path.join(args.model_saving_dir, 'ckpt')
    if args.epoch_idx >= 0:
        # epoch_idx is specified -> load the specified ckpt
        ckpt_path = os.path.join(ckpt_dir,
                                 'net_epoch_{}.pth'.format(args.epoch_idx))
    else:
        # epoch_idx is not specified -> load the best ckpt
        best_ckpt_filename = [
            ckpt_filename for ckpt_filename in os.listdir(ckpt_dir)
            if 'net_best_on_validation_set' in ckpt_filename
        ][0]
        ckpt_path = os.path.join(ckpt_dir, best_ckpt_filename)

    # transfer net into gpu devices and switch to inference mode
    net = copy.deepcopy(network)
    net = torch.nn.DataParallel(net).cuda()
    net.load_state_dict(torch.load(ckpt_path))
    net = net.eval()

    logger.write_and_print(
        'Load ckpt: {0} for evaluating...'.format(ckpt_path))

    # uncertainty maps are generated only when MC-dropout epochs are given
    calculate_uncertainty = len(args.mc_epoch_indexes) > 0

    # get net list for imitating MC dropout process
    net_list = None
    if calculate_uncertainty:
        net_list = get_net_list(network, ckpt_dir, args.mc_epoch_indexes,
                                logger)

    # create dataset
    dataset = MicroCalcificationDataset(
        data_root_dir=args.data_root_dir,
        mode=args.dataset_type,
        enable_random_sampling=False,
        pos_to_neg_ratio=cfg.dataset.pos_to_neg_ratio,
        image_channels=cfg.dataset.image_channels,
        cropping_size=cfg.dataset.cropping_size,
        dilation_radius=args.dilation_radius,
        load_uncertainty_map=False,
        calculate_micro_calcification_number=cfg.dataset.
        calculate_micro_calcification_number,
        enable_data_augmentation=False)

    # create data loader
    data_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=cfg.train.num_threads)

    metrics = MetricsReconstruction(args.prob_threshold, args.area_threshold,
                                    args.distance_threshold,
                                    args.slack_for_recall)

    # dataset-level accumulators
    calcification_num = 0
    recall_num = 0
    FP_num = 0

    for batch_idx, (images_tensor, pixel_level_labels_tensor,
                    pixel_level_labels_dilated_tensor, _,
                    image_level_labels_tensor, _,
                    filenames) in enumerate(data_loader):
        logger.write_and_print('Evaluating batch: {}'.format(batch_idx))

        # start time of this batch
        start_time_for_batch = time()

        # transfer the tensor into gpu device
        images_tensor = images_tensor.cuda()

        # network forward
        reconstructed_images_tensor, prediction_residues_tensor = net(
            images_tensor)

        # MC dropout
        uncertainty_maps_np = generate_uncertainty_maps(
            net_list, images_tensor) if calculate_uncertainty else None

        # evaluation
        post_process_preds_np, calcification_num_batch_level, recall_num_batch_level, FP_num_batch_level, \
        result_flag_list = metrics.metric_batch_level(prediction_residues_tensor, pixel_level_labels_tensor)

        calcification_num += calcification_num_batch_level
        recall_num += recall_num_batch_level
        FP_num += FP_num_batch_level

        # print logging information
        logger.write_and_print(
            'The number of the annotated calcifications of this batch = {}'.
            format(calcification_num_batch_level))
        logger.write_and_print(
            'The number of the recalled calcifications of this batch = {}'.
            format(recall_num_batch_level))
        logger.write_and_print(
            'The number of the false positive calcifications of this batch = {}'
            .format(FP_num_batch_level))
        logger.write_and_print(
            'Consuming time: {:.4f}s'.format(time() - start_time_for_batch))
        logger.write_and_print(
            '--------------------------------------------------------------------------------------'
        )

        save_tensor_in_png_and_nii_format(images_tensor,
                                          reconstructed_images_tensor,
                                          prediction_residues_tensor,
                                          post_process_preds_np,
                                          pixel_level_labels_tensor,
                                          pixel_level_labels_dilated_tensor,
                                          uncertainty_maps_np,
                                          filenames,
                                          result_flag_list,
                                          visualization_saving_dir,
                                          save_nii=args.save_nii)

        logger.flush()

    logger.write_and_print(
        'The number of the annotated calcifications of this dataset = {}'.
        format(calcification_num))
    logger.write_and_print(
        'The number of the recalled calcifications of this dataset = {}'.
        format(recall_num))
    logger.write_and_print(
        'The number of the false positive calcifications of this dataset = {}'.
        format(FP_num))
def UncertaintySTA(args):
    """Collect statistics of MC-dropout uncertainty values over a dataset.

    Builds per-dataset histograms of uncertainty values inside the dilated
    positive label area, split into TP / FN / FP pixel locations, saves the
    raw and Gaussian-filtered histogram plots to ``args.sta_save_dir``, and
    logs the mean/variance of the filtered distributions.

    Raises:
        ValueError: when ``args.mc_epoch_indexes`` is empty — every statistic
            here is computed from uncertainty maps, so MC dropout is mandatory.
    """
    # fail fast with a clear message instead of crashing later on
    # ``None.reshape`` when no MC-dropout checkpoints were requested
    if len(args.mc_epoch_indexes) == 0:
        raise ValueError(
            'UncertaintySTA requires a non-empty args.mc_epoch_indexes to '
            'generate uncertainty maps via MC dropout')

    # initialize logger (recreate a clean statistics saving dir)
    if os.path.exists(args.sta_save_dir):
        shutil.rmtree(args.sta_save_dir)
    os.mkdir(args.sta_save_dir)
    logger = Logger(args.sta_save_dir, 'uncertainty_distribution_sta.txt')
    logger.write_and_print('Dataset: {}'.format(args.data_root_dir))
    logger.write_and_print('Dataset type: {}'.format(args.dataset_type))

    # define the network
    network = VNet2d(num_in_channels=cfg.net.in_channels,
                     num_out_channels=cfg.net.out_channels)

    ckpt_dir = os.path.join(args.model_saving_dir, 'ckpt')
    if args.epoch_idx >= 0:
        # epoch_idx is specified -> load the specified ckpt
        ckpt_path = os.path.join(ckpt_dir,
                                 'net_epoch_{}.pth'.format(args.epoch_idx))
    else:
        # epoch_idx is not specified -> load the best ckpt
        best_ckpt_filename = [
            ckpt_filename for ckpt_filename in os.listdir(ckpt_dir)
            if 'net_best_on_validation_set' in ckpt_filename
        ][0]
        ckpt_path = os.path.join(ckpt_dir, best_ckpt_filename)

    # transfer net into gpu devices and switch to inference mode
    net = copy.deepcopy(network)
    net = torch.nn.DataParallel(net).cuda()
    net.load_state_dict(torch.load(ckpt_path))
    net = net.eval()

    # get net list for imitating MC dropout process
    net_list = get_net_list(network, ckpt_dir, args.mc_epoch_indexes, logger)

    # create dataset
    dataset = MicroCalcificationDataset(
        data_root_dir=args.data_root_dir,
        mode=args.dataset_type,
        enable_random_sampling=False,
        pos_to_neg_ratio=cfg.dataset.pos_to_neg_ratio,
        image_channels=cfg.dataset.image_channels,
        cropping_size=cfg.dataset.cropping_size,
        dilation_radius=args.dilation_radius,
        load_uncertainty_map=False,
        calculate_micro_calcification_number=cfg.dataset.
        calculate_micro_calcification_number,
        enable_data_augmentation=False)

    # create data loader
    data_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=cfg.train.num_threads)

    metrics = MetricsReconstruction(args.prob_threshold, args.area_threshold,
                                    args.distance_threshold,
                                    args.slack_for_recall)

    # dataset-level histogram accumulators, one bin count per uncertainty bin
    all_positive_uncertainty_in_dataset = np.zeros(args.bins)
    tp_uncertainty_in_dataset = np.zeros(args.bins)
    fn_uncertainty_in_dataset = np.zeros(args.bins)
    fp_uncertainty_in_dataset = np.zeros(args.bins)
    uncertainty_max = 0

    for batch_idx, (images_tensor, pixel_level_labels_tensor,
                    pixel_level_labels_dilated_tensor, _,
                    image_level_labels_tensor, _,
                    filenames) in enumerate(data_loader):
        # transfer the tensor into gpu device
        images_tensor = images_tensor.cuda()

        # network forward
        reconstructed_images_tensor, prediction_residues_tensor = net(
            images_tensor)

        # MC dropout
        uncertainty_maps_np = generate_uncertainty_maps(
            net_list, images_tensor)

        # post-process predictions to locate TP / FN / FP pixel areas
        post_process_preds_np, calcification_num_batch_level, recall_num_batch_level, FP_num_batch_level, \
        result_flag_list = metrics.metric_batch_level(prediction_residues_tensor, pixel_level_labels_tensor)

        # flatten everything to 1-D so pixel locations align across arrays
        pixel_level_labels_dilated = \
            pixel_level_labels_dilated_tensor.view(-1).numpy()
        preds_positive = post_process_preds_np.reshape(-1)
        uncertainty_maps = uncertainty_maps_np.reshape(-1)
        uncertainty_max = max(uncertainty_max, np.amax(uncertainty_maps))

        # uncertainty distribution inside the whole dilated positive area
        all_positive_uncertainty_batch = \
            uncertainty_maps[pixel_level_labels_dilated == 1]
        all_positive_uncertainty_distr_batch, _ = np.histogram(
            all_positive_uncertainty_batch,
            bins=args.bins, range=(0, args.bin_range))
        all_positive_uncertainty_in_dataset += \
            all_positive_uncertainty_distr_batch

        # FP: predicted positive outside the (dilated) label;
        # TP: predicted positive inside it;
        # FN: labelled positive but not predicted
        pixel_level_unlabels_dilated = np.subtract(
            np.ones_like(pixel_level_labels_dilated),
            pixel_level_labels_dilated)
        fp_location = np.multiply(pixel_level_unlabels_dilated,
                                  preds_positive)
        tp_location = np.multiply(pixel_level_labels_dilated, preds_positive)
        fn_location = np.zeros_like(preds_positive)
        fn_location[pixel_level_labels_dilated == 1] = 1
        fn_location[preds_positive == 1] = 0

        tp_uncertainty_batch = uncertainty_maps[tp_location == 1]
        tp_uncertainty_distr_batch, _ = np.histogram(
            tp_uncertainty_batch, bins=args.bins, range=(0, args.bin_range))
        tp_uncertainty_in_dataset += tp_uncertainty_distr_batch

        fn_uncertainty_batch = uncertainty_maps[fn_location == 1]
        fn_uncertainty_distr_batch, _ = np.histogram(
            fn_uncertainty_batch, bins=args.bins, range=(0, args.bin_range))
        fn_uncertainty_in_dataset += fn_uncertainty_distr_batch

        fp_uncertainty_batch = uncertainty_maps[fp_location == 1]
        fp_uncertainty_distr_batch, _ = np.histogram(
            fp_uncertainty_batch, bins=args.bins, range=(0, args.bin_range))
        fp_uncertainty_in_dataset += fp_uncertainty_distr_batch

    # clip dominating bins so the plots stay readable
    all_positive_uncertainty_in_dataset[
        all_positive_uncertainty_in_dataset > 1000] = 1000
    tp_uncertainty_in_dataset[tp_uncertainty_in_dataset > 1000] = 1000
    fn_uncertainty_in_dataset[fn_uncertainty_in_dataset > 1000] = 1000
    fp_uncertainty_in_dataset[fp_uncertainty_in_dataset > 1000] = 1000

    pltsave(all_positive_uncertainty_in_dataset, dir=args.sta_save_dir,
            name='all positive uncertainty')
    pltsave(tp_uncertainty_in_dataset, dir=args.sta_save_dir,
            name='True Positive uncertainty')
    pltsave(fn_uncertainty_in_dataset, dir=args.sta_save_dir,
            name='False Negative uncertainty')
    pltsave(fp_uncertainty_in_dataset, dir=args.sta_save_dir,
            name='False Positive uncertainty')

    # smooth the histograms before estimating their mean / variance
    fp_uncertainty_in_dataset_filtered = gaussian_filter1d(
        fp_uncertainty_in_dataset, sigma=3)
    tp_uncertainty_in_dataset_filtered = gaussian_filter1d(
        tp_uncertainty_in_dataset, sigma=3)
    fn_uncertainty_in_dataset_filtered = gaussian_filter1d(
        fn_uncertainty_in_dataset, sigma=3)
    fp_and_fn = \
        fp_uncertainty_in_dataset_filtered + fn_uncertainty_in_dataset_filtered

    pltsave(fp_and_fn, dir=args.sta_save_dir,
            name='FP & FN uncertainty filtered')
    pltsave(tp_uncertainty_in_dataset_filtered, dir=args.sta_save_dir,
            name='True Positive uncertainty filtered')
    pltsave(fn_uncertainty_in_dataset_filtered, dir=args.sta_save_dir,
            name='False Negative uncertainty filtered')
    pltsave(fp_uncertainty_in_dataset_filtered, dir=args.sta_save_dir,
            name='False Positive uncertainty filtered')

    # the first 10% of bins is skipped when computing mean / variance
    fp_mean, fp_var = find_mean_and_var(fp_uncertainty_in_dataset_filtered,
                                        start=args.bins / 10, end=args.bins,
                                        bins=args.bins,
                                        bin_range=args.bin_range)
    tp_mean, tp_var = find_mean_and_var(tp_uncertainty_in_dataset_filtered,
                                        start=args.bins / 10, end=args.bins,
                                        bins=args.bins,
                                        bin_range=args.bin_range)
    fn_mean, fn_var = find_mean_and_var(fn_uncertainty_in_dataset_filtered,
                                        start=args.bins / 10, end=args.bins,
                                        bins=args.bins,
                                        bin_range=args.bin_range)
    logger.write_and_print('max uncertainty is {0}  '.format(uncertainty_max))
    logger.write_and_print(
        'fp uncertainty mean is {0}  variance is {1}'.format(fp_mean, fp_var))
    logger.write_and_print(
        'tp uncertainty mean is {0}  variance is {1}'.format(tp_mean, tp_var))
    logger.write_and_print(
        'fn uncertainty mean is {0}  variance is {1}'.format(fn_mean, fn_var))
def generate_coordinate_and_score_list(images_tensor,
                                       classification_net,
                                       pixel_level_label_np,
                                       raw_residue_radiograph_np,
                                       processed_residue_radiograph_np,
                                       filename,
                                       saving_dir,
                                       crop_patch_size,
                                       upsampled_patch_size,
                                       net_list,
                                       mode='detected'):
    """Score each connected component of a radiograph with a patch classifier.

    For every connected component of either the processed residue
    (``mode='detected'``) or the pixel-level label (``mode='annotated'``),
    a ``crop_patch_size`` patch centred on the component is cropped from
    ``images_tensor``, bilinearly upsampled to ``upsampled_patch_size``, and
    fed through ``classification_net``; the component's score is the
    positive-class softmax probability multiplied by the mean residue value
    over the component.  When ``saving_dir`` is not None, per-patch
    visualizations (image / mask / residues / uncertainty / score images and
    a stacked .nii volume) are written under ``saving_dir/<filename stem>/<mode>``.

    Returns:
        (coordinate_list, score_list): component centroids (np.ndarray each)
        and the corresponding scores, in connected-component order.
    """
    # mode must be either 'detected' or 'annotated'
    assert mode in ['detected', 'annotated']

    if saving_dir is not None:
        # make the related dirs
        patch_level_root_saving_dir = os.path.join(saving_dir, filename[:-4])
        patch_visualization_dir = os.path.join(patch_level_root_saving_dir,
                                               mode)
        if not os.path.exists(patch_level_root_saving_dir):
            os.mkdir(patch_level_root_saving_dir)
        os.mkdir(patch_visualization_dir)

    height, width = processed_residue_radiograph_np.shape

    if mode == 'detected':
        # mode: detected -> iterate each connected component on processed_residue_radiograph_np
        mask_np = copy.copy(processed_residue_radiograph_np)
        mask_np[processed_residue_radiograph_np > 0] = 1
    else:
        # mode: annotated -> iterate each connected component on pixel_level_label_np
        mask_np = copy.copy(pixel_level_label_np)
        # remain micro calcifications and normal tissue label only
        mask_np[mask_np > 1] = 0

    # generate information of each connected component
    connected_components = measure.label(mask_np)
    props = measure.regionprops(connected_components)

    # created for saving the coordinates and the detected score for this connected component
    coordinate_list = list()
    score_list = list()

    # NOTE(review): the manual counter assumes regionprops yields components
    # in ascending label order (1..N) — confirm against skimage docs
    connected_idx = 0
    if len(props) > 0:
        for prop in props:
            connected_idx += 1

            # generate logical indexes for this connected component
            indexes = connected_components == connected_idx

            # record the centroid of this connected component
            coordinate_list.append(np.array(prop.centroid))

            # generate legal start and end idx for row and column
            centroid_row_idx = prop.centroid[0]
            centroid_column_idx = prop.centroid[1]
            # clamp the centroid so the crop window stays inside the image
            centroid_row_idx = np.clip(centroid_row_idx,
                                       crop_patch_size[0] / 2,
                                       height - crop_patch_size[0] / 2)
            centroid_column_idx = np.clip(centroid_column_idx,
                                          crop_patch_size[1] / 2,
                                          width - crop_patch_size[1] / 2)
            #
            start_row_idx = int(centroid_row_idx - crop_patch_size[0] / 2)
            end_row_idx = int(centroid_row_idx + crop_patch_size[0] / 2)
            start_column_idx = int(centroid_column_idx -
                                   crop_patch_size[1] / 2)
            end_column_idx = int(centroid_column_idx + crop_patch_size[1] / 2)

            # crop this patch for model inference
            patch_image_tensor = images_tensor[:, :, start_row_idx:end_row_idx,
                                               start_column_idx:end_column_idx]
            upsampled_patch_image_tensor = \
                torch.nn.functional.interpolate(patch_image_tensor, size=(upsampled_patch_size[0],
                                                                          upsampled_patch_size[1]),
                                                scale_factor=None, mode='bilinear', align_corners=False)

            # generate the positive class prediction probability
            classification_preds_tensor = classification_net(
                upsampled_patch_image_tensor)
            classification_preds_tensor = torch.softmax(
                classification_preds_tensor, dim=1)
            # index [1] picks the positive-class probability after softmax
            positive_prob = classification_preds_tensor.cpu().detach().numpy(
            ).squeeze()[1]

            # MC dropout
            uncertainty_maps_np = generate_uncertainty_maps(
                net_list, upsampled_patch_image_tensor)
            uncertainty_map_np = uncertainty_maps_np.squeeze()

            # calculate the mean value of this connected component on the residue
            residue_mean = (processed_residue_radiograph_np[indexes]).mean()

            # calculate and record the score of this connected component
            score = positive_prob * residue_mean
            score_list.append(score)

            if saving_dir is not None:
                # process the visualization results
                image_patch_np = copy.copy(
                    patch_image_tensor.cpu().detach().numpy().squeeze())
                #
                pixel_level_label_patch_np = copy.copy(
                    pixel_level_label_np[start_row_idx:end_row_idx,
                                         start_column_idx:end_column_idx])
                #
                raw_residue_patch_np = copy.copy(
                    raw_residue_radiograph_np[start_row_idx:end_row_idx,
                                              start_column_idx:end_column_idx])
                #
                processed_residue_patch_np = copy.copy(
                    processed_residue_radiograph_np[
                        start_row_idx:end_row_idx,
                        start_column_idx:end_column_idx])
                # stack the raw (un-scaled) patches into one .nii volume
                stacked_np = np.concatenate(
                    (np.expand_dims(image_patch_np, axis=0),
                     np.expand_dims(pixel_level_label_patch_np, axis=0),
                     np.expand_dims(raw_residue_patch_np, axis=0),
                     np.expand_dims(processed_residue_patch_np, axis=0)),
                    axis=0)
                stacked_image = sitk.GetImageFromArray(stacked_np)
                # scale to 8-bit range for PNG saving; the uncertainty map is
                # amplified x4 — presumably because its values are small,
                # TODO confirm the expected uncertainty value range
                image_patch_np *= 255
                raw_residue_patch_np *= 255
                processed_residue_patch_np *= 255
                uncertainty_map_np *= 4 * 255
                # remap label values to distinct grayscale intensities
                pixel_level_label_patch_np[pixel_level_label_patch_np ==
                                           1] = 255
                pixel_level_label_patch_np[pixel_level_label_patch_np ==
                                           2] = 165
                pixel_level_label_patch_np[pixel_level_label_patch_np ==
                                           3] = 85
                #
                image_patch_np = image_patch_np.astype(np.uint8)
                raw_residue_patch_np = raw_residue_patch_np.astype(np.uint8)
                processed_residue_patch_np = processed_residue_patch_np.astype(
                    np.uint8)
                pixel_level_label_patch_np = pixel_level_label_patch_np.astype(
                    np.uint8)
                uncertainty_map_np = uncertainty_map_np.astype(np.uint8)
                uncertainty_map_np = cv2.applyColorMap(uncertainty_map_np,
                                                       cv2.COLORMAP_JET)
                # render probability / mean residue / score as text images
                prob_saving_image = np.zeros(
                    (crop_patch_size[0], crop_patch_size[1], 3), np.uint8)
                mean_residue_saving_image = np.zeros(
                    (crop_patch_size[0], crop_patch_size[1], 3), np.uint8)
                score_saving_image = np.zeros(
                    (crop_patch_size[0], crop_patch_size[1], 3), np.uint8)
                font = cv2.FONT_HERSHEY_SIMPLEX
                cv2.putText(prob_saving_image, '{:.4f}'.format(positive_prob),
                            (0, 64), font, 1, (0, 255, 255), 2)
                cv2.putText(mean_residue_saving_image,
                            '{:.4f}'.format(residue_mean), (0, 64), font, 1,
                            (255, 0, 255), 2)
                cv2.putText(score_saving_image,
                            '{:.4f}'.format(positive_prob * residue_mean),
                            (0, 64), font, 1, (255, 255, 0), 2)

                # saving
                cv2.imwrite(
                    os.path.join(
                        patch_visualization_dir,
                        filename.replace(
                            '.png', '_patch_{:0>3d}_{}_{}_image.png'.format(
                                connected_idx, int(centroid_row_idx),
                                int(centroid_column_idx)))), image_patch_np)
                cv2.imwrite(
                    os.path.join(
                        patch_visualization_dir,
                        filename.replace(
                            '.png',
                            '_patch_{:0>3d}_mask.png'.format(connected_idx))),
                    pixel_level_label_patch_np)
                cv2.imwrite(
                    os.path.join(
                        patch_visualization_dir,
                        filename.replace(
                            '.png', '_patch_{:0>3d}_raw_residue.png'.format(
                                connected_idx))), raw_residue_patch_np)
                cv2.imwrite(
                    os.path.join(
                        patch_visualization_dir,
                        filename.replace(
                            '.png',
                            '_patch_{:0>3d}_processed_residue.png'.format(
                                connected_idx))), processed_residue_patch_np)
                cv2.imwrite(
                    os.path.join(
                        patch_visualization_dir,
                        filename.replace(
                            '.png', '_patch_{:0>3d}_uncertainty.png'.format(
                                connected_idx))), uncertainty_map_np)
                cv2.imwrite(
                    os.path.join(
                        patch_visualization_dir,
                        filename.replace(
                            '.png', '_patch_{:0>3d}_positive_prob.png'.format(
                                connected_idx))), prob_saving_image)
                cv2.imwrite(
                    os.path.join(
                        patch_visualization_dir,
                        filename.replace(
                            '.png', '_patch_{:0>3d}_mean_residue.png'.format(
                                connected_idx))), mean_residue_saving_image)
                cv2.imwrite(
                    os.path.join(
                        patch_visualization_dir,
                        filename.replace(
                            '.png',
                            '_patch_{:0>3d}_score.png'.format(connected_idx))),
                    score_saving_image)
                sitk.WriteImage(
                    stacked_image,
                    os.path.join(
                        patch_visualization_dir,
                        filename.replace(
                            '.png',
                            '_patch_{:0>3d}.nii'.format(connected_idx))))

    return coordinate_list, score_list
# Example 4
def TestUncertaintyMapLabelWeightsGeneration(args):
    """Generate and save MC-dropout uncertainty maps for a patch dataset.

    For every batch in ``args.src_data_root_dir`` / ``args.dataset_type``,
    uncertainty maps are produced by the checkpoint ensemble selected by
    ``args.mc_epoch_indexes`` and saved into the 'uncertainty-maps'
    sub-directories of the positive / negative patch trees under
    ``args.dst_data_root_dir``.
    """
    positive_patch_results_saving_dir = os.path.join(args.dst_data_root_dir,
                                                     'positive_patches',
                                                     args.dataset_type,
                                                     'uncertainty-maps')
    negative_patch_results_saving_dir = os.path.join(args.dst_data_root_dir,
                                                     'negative_patches',
                                                     args.dataset_type,
                                                     'uncertainty-maps')

    # create dirs when they do not exist (including any missing parents,
    # which plain os.mkdir would fail on)
    os.makedirs(positive_patch_results_saving_dir, exist_ok=True)
    os.makedirs(negative_patch_results_saving_dir, exist_ok=True)

    # initialize logger
    logger = Logger(args.src_data_root_dir, 'uncertainty.txt')
    logger.write_and_print('Dataset: {}'.format(args.src_data_root_dir))
    logger.write_and_print('Dataset type: {}'.format(args.dataset_type))

    # define the network
    network = VNet2d(num_in_channels=cfg.net.in_channels,
                     num_out_channels=cfg.net.out_channels)

    ckpt_dir = os.path.join(args.model_saving_dir, 'ckpt')

    # get net list for imitating MC dropout process
    net_list = get_net_list(network, ckpt_dir, args.mc_epoch_indexes, logger)

    # create dataset
    dataset = MicroCalcificationDataset(
        data_root_dir=args.src_data_root_dir,
        mode=args.dataset_type,
        enable_random_sampling=False,
        pos_to_neg_ratio=cfg.dataset.pos_to_neg_ratio,
        image_channels=cfg.dataset.image_channels,
        cropping_size=cfg.dataset.cropping_size,
        dilation_radius=0,
        load_uncertainty_map=False,
        calculate_micro_calcification_number=False,
        enable_data_augmentation=False)

    # create data loader
    data_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=cfg.train.num_threads)

    for batch_idx, (images_tensor, _, _, _, _, _,
                    filenames) in enumerate(data_loader):
        logger.write_and_print('Evaluating batch: {}'.format(batch_idx))

        # start time of this batch
        start_time_for_batch = time()

        # transfer the tensor into gpu device
        images_tensor = images_tensor.cuda()

        # imitating MC dropout
        uncertainty_maps_np = generate_uncertainty_maps(
            net_list, images_tensor)
        save_uncertainty_maps(uncertainty_maps_np, filenames,
                              positive_patch_results_saving_dir,
                              negative_patch_results_saving_dir, logger)

        logger.write_and_print(
            'Finished evaluating, consuming time = {:.4f}s'.format(
                time() - start_time_for_batch))
        logger.write_and_print(
            '--------------------------------------------------------------------------------------'
        )

        logger.flush()
def save_images_labels_uncertainty_maps(coord_list, image_tensor, image_np,
                                        pixel_level_label_np, net_list,
                                        filename, positive_dataset_type_dir,
                                        negative_dataset_type_dir,
                                        reconstruction_patch_size,
                                        saving_patch_size):
    """Crop patches around each coordinate and save image, label and uncertainty.

    For each coordinate in ``coord_list``, a ``saving_patch_size`` crop of
    the image and the pixel-level label is written as PNG, together with an
    MC-dropout uncertainty map (computed on the larger
    ``reconstruction_patch_size`` crop, then center-cropped back down) saved
    as NIfTI.  A patch counts as positive when its label crop contains at
    least one pixel equal to 1.

    Returns:
        (positive_patch_idx, negative_patch_idx): the number of positive and
        negative patches saved for this radiograph.
    """
    height, width = image_np.shape

    positive_patch_idx = 0
    negative_patch_idx = 0

    for coord in coord_list:
        # generate legal start and end idx for row and column
        saving_crop_indexes = generate_legal_indexes(coord, saving_patch_size,
                                                     height, width)
        reconstruction_crop_indexes = generate_legal_indexes(
            coord, reconstruction_patch_size, height, width)

        # crop this patch from image and label
        image_patch_np = copy.copy(
            image_np[saving_crop_indexes[0]:saving_crop_indexes[1],
                     saving_crop_indexes[2]:saving_crop_indexes[3]])
        pixel_level_label_patch_np = copy.copy(pixel_level_label_np[
            saving_crop_indexes[0]:saving_crop_indexes[1],
            saving_crop_indexes[2]:saving_crop_indexes[3]])
        image_level_label_patch_bool = \
            (pixel_level_label_patch_np == 1).sum() > 0

        # MC dropout for uncertainty map on the reconstruction-sized crop
        image_patch_tensor = copy.copy(
            image_tensor[:, :, reconstruction_crop_indexes[0]:
                         reconstruction_crop_indexes[1],
                         reconstruction_crop_indexes[2]:
                         reconstruction_crop_indexes[3]])
        uncertainty_map_np = generate_uncertainty_maps(net_list,
                                                       image_patch_tensor)
        uncertainty_map_np = uncertainty_map_np.squeeze()

        # center-crop the uncertainty map from reconstruction size down to
        # saving size (e.g. 112*112 -> 56*56)
        center_coord = [
            int(reconstruction_patch_size[0] / 2),
            int(reconstruction_patch_size[1] / 2)
        ]
        center_crop_indexes = generate_legal_indexes(
            center_coord, saving_patch_size, reconstruction_patch_size[0],
            reconstruction_patch_size[1])
        uncertainty_map_np = uncertainty_map_np[
            center_crop_indexes[0]:center_crop_indexes[1],
            center_crop_indexes[2]:center_crop_indexes[3]]
        uncertainty_map_image = sitk.GetImageFromArray(uncertainty_map_np)

        # transform into png format: image to [0, 255], label values to
        # distinct grayscale intensities
        image_patch_np *= 255
        pixel_level_label_patch_np[pixel_level_label_patch_np == 1] = 255
        pixel_level_label_patch_np[pixel_level_label_patch_np == 2] = 165
        pixel_level_label_patch_np[pixel_level_label_patch_np == 3] = 85
        image_patch_np = image_patch_np.astype(np.uint8)
        pixel_level_label_patch_np = pixel_level_label_patch_np.astype(
            np.uint8)

        # positive and negative patches differ only in destination dir,
        # filename prefix and counter — build the paths once
        if image_level_label_patch_bool:
            positive_patch_idx += 1
            dst_dir, prefix, patch_idx = (positive_dataset_type_dir,
                                          'positive_', positive_patch_idx)
        else:
            negative_patch_idx += 1
            dst_dir, prefix, patch_idx = (negative_dataset_type_dir,
                                          'negative_', negative_patch_idx)
        absolute_image_saving_path = os.path.join(
            dst_dir,
            prefix + filename.split('.')[0] + '_{}.png'.format(patch_idx))
        absolute_label_saving_path = absolute_image_saving_path.replace(
            'images', 'labels')
        absolute_uncertainty_map_saving_path = \
            absolute_image_saving_path.replace(
                'images', 'uncertainty-maps').replace('.png', '.nii')

        # saving
        cv2.imwrite(absolute_image_saving_path, image_patch_np)
        cv2.imwrite(absolute_label_saving_path, pixel_level_label_patch_np)
        sitk.WriteImage(uncertainty_map_image,
                        absolute_uncertainty_map_saving_path)

    return positive_patch_idx, negative_patch_idx