def TestMicroCalcificationPatchLevelQuantityRegression(args):
    start_time_for_epoch = time()

    prediction_saving_dir = os.path.join(args.model_saving_dir,
                                         'patch_level_quantity_regression_results_dataset_{}_epoch_{}'.format(
                                             args.dataset_type, args.epoch_idx))
    visualization_saving_dir = os.path.join(prediction_saving_dir, 'qualitative_results')

    over_preds_dir = os.path.join(visualization_saving_dir, 'over_preds')
    correct_preds_dir = os.path.join(visualization_saving_dir, 'correct_preds')
    under_preds_dir = os.path.join(visualization_saving_dir, 'under_preds')

    # remove any existing dir with the same name, then create a clean one
    if os.path.exists(prediction_saving_dir):
        shutil.rmtree(prediction_saving_dir)
    os.mkdir(prediction_saving_dir)
    os.mkdir(visualization_saving_dir)
    os.mkdir(over_preds_dir)
    os.mkdir(correct_preds_dir)
    os.mkdir(under_preds_dir)

    # initialize logger
    logger = Logger(prediction_saving_dir, 'quantitative_results.txt')
    logger.write_and_print('Dataset: {}'.format(args.data_root_dir))
    logger.write_and_print('Dataset type: {}'.format(args.dataset_type))

    # define the network
    net = ResNet18(in_channels=cfg.net.in_channels, num_classes=cfg.net.num_classes)

    # load the specified ckpt
    ckpt_dir = os.path.join(args.model_saving_dir, 'ckpt')
    # epoch_idx is specified -> load the specified ckpt
    if args.epoch_idx >= 0:
        ckpt_path = os.path.join(ckpt_dir, 'net_epoch_{}.pth'.format(args.epoch_idx))
    # epoch_idx is not specified -> load the best ckpt
    else:
        saved_ckpt_list = os.listdir(ckpt_dir)
        best_ckpt_filename = [best_ckpt_filename for best_ckpt_filename in saved_ckpt_list if
                              'net_best_on_validation_set' in best_ckpt_filename][0]
        ckpt_path = os.path.join(ckpt_dir, best_ckpt_filename)

    # transfer net into gpu devices
    net = torch.nn.DataParallel(net).cuda()
    net.load_state_dict(torch.load(ckpt_path))
    net = net.eval()

    logger.write_and_print('Load ckpt: {0}...'.format(ckpt_path))

    # create dataset and data loader
    dataset = MicroCalcificationDataset(data_root_dir=args.data_root_dir,
                                        mode=args.dataset_type,
                                        enable_random_sampling=False,
                                        pos_to_neg_ratio=cfg.dataset.pos_to_neg_ratio,
                                        image_channels=cfg.dataset.image_channels,
                                        cropping_size=cfg.dataset.cropping_size,
                                        dilation_radius=cfg.dataset.dilation_radius,
                                        load_uncertainty_map=False,
                                        calculate_micro_calcification_number=cfg.dataset.calculate_micro_calcification_number,
                                        enable_data_augmentation=False)

    data_loader = DataLoader(dataset, batch_size=args.batch_size,
                             shuffle=False, num_workers=cfg.train.num_threads)

    metrics = MetricsImageLEvelQuantityRegression(cfg.dataset.cropping_size)

    pred_num_epoch_level = 0
    distance_epoch_level = 0
    over_pred_epoch_level = 0
    correct_pred_epoch_level = 0
    under_pred_epoch_level = 0

    for batch_idx, (
            images_tensor, pixel_level_labels_tensor, _, _, image_level_labels_tensor,
            micro_calcification_number_label_tensor,
            filenames) in enumerate(
        data_loader):
        # start time of this batch
        start_time_for_batch = time()

        # transfer the tensor into gpu device
        images_tensor = images_tensor.cuda()
        pixel_level_labels_tensor = pixel_level_labels_tensor.cuda()
        micro_calcification_number_label_tensor = micro_calcification_number_label_tensor.type(torch.FloatTensor)
        micro_calcification_number_label_tensor = micro_calcification_number_label_tensor.cuda()

        # network forward
        preds_tensor = net(images_tensor)  # the shape of preds_tensor: [B, 1]

        # metrics
        classification_flag_np, visual_preds_np, visual_labels_np, distance_batch_level, over_preds_batch_level, \
        correct_preds_batch_level, under_preds_batch_level = \
            metrics.metric_batch_level(preds_tensor, micro_calcification_number_label_tensor)
        pred_num_epoch_level += preds_tensor.shape[0]
        distance_epoch_level += distance_batch_level
        over_pred_epoch_level += over_preds_batch_level
        correct_pred_epoch_level += correct_preds_batch_level
        under_pred_epoch_level += under_preds_batch_level

        # print logging information
        logger.write_and_print(
            'The number of the over predicted patches of this batch = {}'.format(over_preds_batch_level))
        logger.write_and_print(
            'The number of the correct predicted patches of this batch = {}'.format(correct_preds_batch_level))
        logger.write_and_print(
            'The number of the under predicted patches of this batch = {}'.format(under_preds_batch_level))
        logger.write_and_print('The value of the MSE of this batch = {}'.format(distance_batch_level))
        logger.write_and_print('batch: {}, batch_size: {}, consuming time: {:.4f}s'.format(batch_idx, args.batch_size,
                                                                                           time() - start_time_for_batch))
        logger.write_and_print('--------------------------------------------------------------------------------------')

        images_np = images_tensor.cpu().numpy()
        pixel_level_labels_np = pixel_level_labels_tensor.cpu().numpy()
        for patch_idx in range(images_tensor.shape[0]):
            image_np = images_np[patch_idx, 0, :, :]
            visual_pred_np = visual_preds_np[patch_idx, :, :]
            visual_label_np = visual_labels_np[patch_idx, :, :]
            pixel_level_label_np = pixel_level_labels_np[patch_idx, :, :]
            filename = filenames[patch_idx]
            classification_flag = classification_flag_np[patch_idx]

            assert image_np.shape == pixel_level_label_np.shape
            assert len(image_np.shape) == 2

            image_np *= 255
            image_np = image_np.astype(np.uint8)

            pixel_level_label_np *= 255
            pixel_level_label_np = pixel_level_label_np.astype(np.uint8)

            flag_2_dir_mapping = {0: 'over_preds', 1: 'correct_preds', 2: 'under_preds'}
            saving_dir_of_this_patch = os.path.join(visualization_saving_dir, flag_2_dir_mapping[classification_flag])
            cv2.imwrite(os.path.join(saving_dir_of_this_patch, filename.replace('.png', '_image.png')), image_np)
            cv2.imwrite(os.path.join(saving_dir_of_this_patch, filename.replace('.png', '_pixel_level_label.png')),
                        pixel_level_label_np)
            cv2.imwrite(os.path.join(saving_dir_of_this_patch, filename.replace('.png', '_mask_num.png')),
                        visual_label_np)
            cv2.imwrite(os.path.join(saving_dir_of_this_patch, filename.replace('.png', '_pred_num.png')),
                        visual_pred_np)

    # print logging information
    logger.write_and_print('##########################################################################################')
    logger.write_and_print('The number of the patches of this dataset = {}'.format(pred_num_epoch_level))
    logger.write_and_print(
        'The number of the over predicted patches of this dataset = {}'.format(over_pred_epoch_level))
    logger.write_and_print(
        'The number of the correct predicted patches of this dataset = {}'.format(correct_pred_epoch_level))
    logger.write_and_print(
        'The number of the under predicted patches of this dataset = {}'.format(under_pred_epoch_level))
    logger.write_and_print(
        'The value of the MSE of this dataset = {}'.format(distance_epoch_level / pred_num_epoch_level))
    logger.write_and_print('consuming time: {:.4f}s'.format(time() - start_time_for_epoch))
    logger.write_and_print('##########################################################################################')

    return
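

# A hypothetical CLI driver (not part of the original file): the test routines
# in this module read a handful of attributes from `args`. The argument names
# below mirror exactly what TestMicroCalcificationPatchLevelQuantityRegression
# accesses; the default values are illustrative assumptions only.
def ParseArgumentsForQuantityRegressionTest():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root_dir', type=str, help='The root dir of the patch-level dataset.')
    parser.add_argument('--model_saving_dir', type=str, help='The dir containing the ckpt/ sub-dir.')
    parser.add_argument('--dataset_type', type=str, default='test', help='training, validation or test.')
    parser.add_argument('--epoch_idx', type=int, default=-1,
                        help='>= 0: load that epoch\'s ckpt; < 0: load the best ckpt.')
    parser.add_argument('--batch_size', type=int, default=48)

    return parser.parse_args()


# Example invocation:
#     TestMicroCalcificationPatchLevelQuantityRegression(ParseArgumentsForQuantityRegressionTest())
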
def TestMicroCalcificationReconstruction(args):
    prediction_saving_dir = os.path.join(
        args.model_saving_dir,
        'reconstruction_results_dataset_{}_epoch_{}'.format(
            args.dataset_type, args.epoch_idx))
    visualization_saving_dir = os.path.join(prediction_saving_dir,
                                            'qualitative_results')
    visualization_TP_saving_dir = os.path.join(visualization_saving_dir,
                                               'TPs_only')
    visualization_FP_saving_dir = os.path.join(visualization_saving_dir,
                                               'FPs_only')
    visualization_FN_saving_dir = os.path.join(visualization_saving_dir,
                                               'FNs_only')
    visualization_FP_FN_saving_dir = os.path.join(visualization_saving_dir,
                                                  'FPs_FNs_both')

    # remove any existing dir with the same name, then create a clean one
    if os.path.exists(prediction_saving_dir):
        shutil.rmtree(prediction_saving_dir)
    os.mkdir(prediction_saving_dir)
    os.mkdir(visualization_saving_dir)
    os.mkdir(visualization_TP_saving_dir)
    os.mkdir(visualization_FP_saving_dir)
    os.mkdir(visualization_FN_saving_dir)
    os.mkdir(visualization_FP_FN_saving_dir)

    # initialize logger
    logger = Logger(prediction_saving_dir, 'quantitative_results.txt')
    logger.write_and_print('Dataset: {}'.format(args.data_root_dir))
    logger.write_and_print('Dataset type: {}'.format(args.dataset_type))

    # define the network
    network = VNet2d(num_in_channels=cfg.net.in_channels,
                     num_out_channels=cfg.net.out_channels)

    # load the specified ckpt
    ckpt_dir = os.path.join(args.model_saving_dir, 'ckpt')
    # epoch_idx is specified -> load the specified ckpt
    if args.epoch_idx >= 0:
        ckpt_path = os.path.join(ckpt_dir,
                                 'net_epoch_{}.pth'.format(args.epoch_idx))
    # epoch_idx is not specified -> load the best ckpt
    else:
        saved_ckpt_list = os.listdir(ckpt_dir)
        best_ckpt_filename = [
            best_ckpt_filename for best_ckpt_filename in saved_ckpt_list
            if 'net_best_on_validation_set' in best_ckpt_filename
        ][0]
        ckpt_path = os.path.join(ckpt_dir, best_ckpt_filename)

    # transfer net into gpu devices
    net = copy.deepcopy(network)
    net = torch.nn.DataParallel(net).cuda()
    net.load_state_dict(torch.load(ckpt_path))
    net = net.eval()

    logger.write_and_print(
        'Load ckpt: {0} for evaluating...'.format(ckpt_path))

    # decide whether uncertainty maps should be calculated (enabled when MC epoch indexes are given)
    calculate_uncertainty = len(args.mc_epoch_indexes) > 0

    # get net list for imitating MC dropout process
    net_list = None
    if calculate_uncertainty:
        net_list = get_net_list(network, ckpt_dir, args.mc_epoch_indexes,
                                logger)

    # create dataset
    dataset = MicroCalcificationDataset(
        data_root_dir=args.data_root_dir,
        mode=args.dataset_type,
        enable_random_sampling=False,
        pos_to_neg_ratio=cfg.dataset.pos_to_neg_ratio,
        image_channels=cfg.dataset.image_channels,
        cropping_size=cfg.dataset.cropping_size,
        dilation_radius=args.dilation_radius,
        load_uncertainty_map=False,
        calculate_micro_calcification_number=cfg.dataset.calculate_micro_calcification_number,
        enable_data_augmentation=False)

    # create data loader
    data_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=cfg.train.num_threads)

    metrics = MetricsReconstruction(args.prob_threshold, args.area_threshold,
                                    args.distance_threshold,
                                    args.slack_for_recall)

    calcification_num = 0
    recall_num = 0
    FP_num = 0

    for batch_idx, (images_tensor, pixel_level_labels_tensor,
                    pixel_level_labels_dilated_tensor, _,
                    image_level_labels_tensor, _,
                    filenames) in enumerate(data_loader):
        logger.write_and_print('Evaluating batch: {}'.format(batch_idx))

        # start time of this batch
        start_time_for_batch = time()

        # transfer the tensor into gpu device
        images_tensor = images_tensor.cuda()

        # network forward
        reconstructed_images_tensor, prediction_residues_tensor = net(
            images_tensor)

        # MC dropout
        uncertainty_maps_np = generate_uncertainty_maps(
            net_list, images_tensor) if calculate_uncertainty else None

        # evaluation
        post_process_preds_np, calcification_num_batch_level, recall_num_batch_level, FP_num_batch_level, \
        result_flag_list = metrics.metric_batch_level(prediction_residues_tensor, pixel_level_labels_tensor)

        calcification_num += calcification_num_batch_level
        recall_num += recall_num_batch_level
        FP_num += FP_num_batch_level

        # print logging information
        logger.write_and_print(
            'The number of the annotated calcifications of this batch = {}'.
            format(calcification_num_batch_level))
        logger.write_and_print(
            'The number of the recalled calcifications of this batch = {}'.
            format(recall_num_batch_level))
        logger.write_and_print(
            'The number of the false positive calcifications of this batch = {}'
            .format(FP_num_batch_level))
        logger.write_and_print(
            'Consuming time: {:.4f}s'.format(time() - start_time_for_batch))
        logger.write_and_print(
            '--------------------------------------------------------------------------------------'
        )

        save_tensor_in_png_and_nii_format(images_tensor,
                                          reconstructed_images_tensor,
                                          prediction_residues_tensor,
                                          post_process_preds_np,
                                          pixel_level_labels_tensor,
                                          pixel_level_labels_dilated_tensor,
                                          uncertainty_maps_np,
                                          filenames,
                                          result_flag_list,
                                          visualization_saving_dir,
                                          save_nii=args.save_nii)

        logger.flush()

    logger.write_and_print(
        'The number of the annotated calcifications of this dataset = {}'.
        format(calcification_num))
    logger.write_and_print(
        'The number of the recalled calcifications of this dataset = {}'.
        format(recall_num))
    logger.write_and_print(
        'The number of the false positive calcifications of this dataset = {}'.
        format(FP_num))

    return
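

# A minimal sketch of what get_net_list could look like (an assumption: the
# real helper is imported from elsewhere in this repo). It loads one checkpoint
# per index in mc_epoch_indexes onto a fresh copy of the network, mirroring the
# single-ckpt loading logic above, so that generate_uncertainty_maps can run an
# MC-dropout-style ensemble forward pass.
def get_net_list_sketch(network, ckpt_dir, mc_epoch_indexes, logger):
    net_list = []
    for epoch_idx in mc_epoch_indexes:
        ckpt_path = os.path.join(ckpt_dir, 'net_epoch_{}.pth'.format(epoch_idx))
        net = copy.deepcopy(network)
        net = torch.nn.DataParallel(net).cuda()
        net.load_state_dict(torch.load(ckpt_path))
        net_list.append(net.eval())
        logger.write_and_print('Load ckpt: {0} for MC-dropout imitation...'.format(ckpt_path))

    return net_list
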
def ResidueDistributionSTA(args):
    # define the network
    net = VNet2d(num_in_channels=cfg.net.in_channels, num_out_channels=cfg.net.out_channels)

    # load the specified ckpt
    ckpt_dir = os.path.join(args.model_saving_dir, 'ckpt')
    # epoch_idx is specified -> load the specified ckpt
    if args.epoch_idx >= 0:
        ckpt_path = os.path.join(ckpt_dir, 'net_epoch_{}.pth'.format(args.epoch_idx))
    # epoch_idx is not specified -> load the best ckpt
    else:
        saved_ckpt_list = os.listdir(ckpt_dir)
        best_ckpt_filename = [best_ckpt_filename for best_ckpt_filename in saved_ckpt_list if
                              'net_best_on_validation_set' in best_ckpt_filename][0]
        ckpt_path = os.path.join(ckpt_dir, best_ckpt_filename)

    # transfer net into gpu devices
    net = torch.nn.DataParallel(net).cuda()
    net.load_state_dict(torch.load(ckpt_path))
    net = net.eval()

    # create dataset
    dataset = MicroCalcificationDataset(data_root_dir=args.data_root_dir,
                                        mode=args.dataset_type,
                                        enable_random_sampling=False,
                                        pos_to_neg_ratio=cfg.dataset.pos_to_neg_ratio,
                                        image_channels=cfg.dataset.image_channels,
                                        cropping_size=cfg.dataset.cropping_size,
                                        dilation_radius=args.dilation_radius,
                                        load_uncertainty_map=False,
                                        calculate_micro_calcification_number=cfg.dataset.calculate_micro_calcification_number,
                                        enable_data_augmentation=False)

    # create data loader
    data_loader = DataLoader(dataset, batch_size=args.batch_size,
                             shuffle=False, num_workers=cfg.train.num_threads)

    metrics = MetricsReconstruction(args.prob_threshold, args.area_threshold, args.distance_threshold,
                                    args.slack_for_recall)

    residue_in_dataset = np.zeros(args.histogram_bins)
    mask_positive_residue_in_dataset = np.zeros(args.histogram_bins)
    mask_negative_residue_in_dataset = np.zeros(args.histogram_bins)
    recon_positive_residue_in_dataset = np.zeros(args.histogram_bins)
    recon_negative_residue_in_dataset = np.zeros(args.histogram_bins)

    for batch_idx, (images_tensor, pixel_level_labels_tensor, pixel_level_labels_dilated_tensor, _,
                    image_level_labels_tensor, _, filenames) in enumerate(data_loader):
        # start time of this batch
        start_time_for_batch = time()

        # transfer the tensor into gpu device
        images_tensor = images_tensor.cuda()

        # network forward
        reconstructed_images_tensor, prediction_residues_tensor = net(images_tensor)

        # evaluation
        post_process_preds_np, calcification_num_batch_level, recall_num_batch_level, FP_num_batch_level, _ = \
            metrics.metric_batch_level(prediction_residues_tensor, pixel_level_labels_tensor)

        # flatten the dilated label masks and the post-processed predictions
        pixel_level_labels_dilated = pixel_level_labels_dilated_tensor.cpu().view(-1).numpy()
        process_preds = post_process_preds_np.reshape(-1)

        residues = prediction_residues_tensor.cpu().view(-1).detach().numpy()
        residues_hist, _ = np.histogram(residues, bins=args.histogram_bins, range=(0, 1))
        residue_in_dataset += residues_hist

        assert residues.shape == pixel_level_labels_dilated.shape
        assert residues.shape == process_preds.shape

        mask_positive_residue = residues[pixel_level_labels_dilated == 1]
        mask_positive_residue_hist, _ = np.histogram(mask_positive_residue, bins=args.histogram_bins, range=(0, 1))
        mask_positive_residue_in_dataset += mask_positive_residue_hist

        mask_negative_residue = residues[pixel_level_labels_dilated == 0]
        mask_negative_residue_hist, _ = np.histogram(mask_negative_residue, bins=args.histogram_bins, range=(0, 1))
        mask_negative_residue_in_dataset += mask_negative_residue_hist

        process_positive_residue = residues[process_preds == 1]
        process_positive_residue_hist, _ = np.histogram(process_positive_residue, bins=args.histogram_bins,
                                                        range=(0, 1))
        recon_positive_residue_in_dataset += process_positive_residue_hist

        process_negative_residue = residues[process_preds == 0]
        process_negative_residue_hist, _ = np.histogram(process_negative_residue, bins=args.histogram_bins,
                                                        range=(0, 1))
        recon_negative_residue_in_dataset += process_negative_residue_hist

    # clip the bin counts so the plots are not dominated by the huge background bins
    residue_in_dataset[residue_in_dataset > 15000] = 15000
    mask_negative_residue_in_dataset[mask_negative_residue_in_dataset > 15000] = 15000
    recon_negative_residue_in_dataset[recon_negative_residue_in_dataset > 15000] = 15000
    pltsave(residue_in_dataset, args.data_save_dir, 'total residues')
    pltsave(mask_positive_residue_in_dataset, args.data_save_dir, 'mask positive residues')
    pltsave(mask_negative_residue_in_dataset, args.data_save_dir, 'mask negative residues')
    pltsave(recon_positive_residue_in_dataset, args.data_save_dir, 'predict positive residues')
    pltsave(recon_negative_residue_in_dataset, args.data_save_dir, 'predict negative residues')

    print('on dataset {0} with dilation radius {1} and {2} histogram bins'.format(args.dataset_type,
                                                                                  args.dilation_radius,
                                                                                  args.histogram_bins))
    print('the overall residue distribution is {}'.format(np.around(residue_in_dataset, 3)))
    print('in the dilated mask label, the positive residue distribution is {0}\n'
          'the negative residue distribution is {1}'.format(np.around(mask_positive_residue_in_dataset, 3),
                                                            np.around(mask_negative_residue_in_dataset, 3)))
    print('in the predicted label, the positive residue distribution is {0}\n'
          'the negative residue distribution is {1}'.format(np.around(recon_positive_residue_in_dataset, 3),
                                                            np.around(recon_negative_residue_in_dataset, 3)))

    return
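

# A minimal sketch of the pltsave helper used above (an assumption: the real
# implementation is imported from elsewhere in this repo). It renders a 1-D
# histogram array as a bar plot and saves it as '<name>.png' under the given
# dir; the parameter names mirror the call sites in this file.
def pltsave_sketch(data, dir, name):
    import matplotlib
    matplotlib.use('Agg')  # headless backend, suitable for server-side runs
    import matplotlib.pyplot as plt

    plt.figure()
    plt.bar(np.arange(len(data)), data)
    plt.title(name)
    plt.savefig(os.path.join(dir, '{}.png'.format(name)))
    plt.close()
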
def TestUncertaintyMapLabelWeightsGeneration(args):
    positive_patch_results_saving_dir = os.path.join(args.dst_data_root_dir,
                                                     'positive_patches',
                                                     args.dataset_type,
                                                     'uncertainty-maps')
    negative_patch_results_saving_dir = os.path.join(args.dst_data_root_dir,
                                                     'negative_patches',
                                                     args.dataset_type,
                                                     'uncertainty-maps')

    # create dir when it does not exist
    if not os.path.exists(positive_patch_results_saving_dir):
        os.mkdir(positive_patch_results_saving_dir)
    if not os.path.exists(negative_patch_results_saving_dir):
        os.mkdir(negative_patch_results_saving_dir)

    # initialize logger
    logger = Logger(args.src_data_root_dir, 'uncertainty.txt')
    logger.write_and_print('Dataset: {}'.format(args.src_data_root_dir))
    logger.write_and_print('Dataset type: {}'.format(args.dataset_type))

    # define the network
    network = VNet2d(num_in_channels=cfg.net.in_channels,
                     num_out_channels=cfg.net.out_channels)

    ckpt_dir = os.path.join(args.model_saving_dir, 'ckpt')

    # get net list for imitating MC dropout process
    net_list = get_net_list(network, ckpt_dir, args.mc_epoch_indexes, logger)

    # create dataset
    dataset = MicroCalcificationDataset(
        data_root_dir=args.src_data_root_dir,
        mode=args.dataset_type,
        enable_random_sampling=False,
        pos_to_neg_ratio=cfg.dataset.pos_to_neg_ratio,
        image_channels=cfg.dataset.image_channels,
        cropping_size=cfg.dataset.cropping_size,
        dilation_radius=0,
        load_uncertainty_map=False,
        calculate_micro_calcification_number=False,
        enable_data_augmentation=False)

    # create data loader
    data_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=cfg.train.num_threads)

    for batch_idx, (images_tensor, _, _, _, _, _,
                    filenames) in enumerate(data_loader):
        logger.write_and_print('Evaluating batch: {}'.format(batch_idx))

        # start time of this batch
        start_time_for_batch = time()

        # transfer the tensor into gpu device
        images_tensor = images_tensor.cuda()

        # imitating MC dropout
        uncertainty_maps_np = generate_uncertainty_maps(
            net_list, images_tensor)
        save_uncertainty_maps(uncertainty_maps_np, filenames,
                              positive_patch_results_saving_dir,
                              negative_patch_results_saving_dir, logger)

        logger.write_and_print(
            'Finished evaluating, consuming time = {:.4f}s'.format(
                time() - start_time_for_batch))
        logger.write_and_print(
            '--------------------------------------------------------------------------------------'
        )

        logger.flush()

    return
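

# A minimal sketch of generate_uncertainty_maps (an assumption: the real helper
# is imported from elsewhere in this repo). It imitates MC dropout with a
# checkpoint ensemble: every net in net_list predicts a residue map for the
# same batch, and the per-pixel standard deviation across those predictions is
# returned as the uncertainty map.
def generate_uncertainty_maps_sketch(net_list, images_tensor):
    residues_list = []
    with torch.no_grad():
        for net in net_list:
            _, prediction_residues_tensor = net(images_tensor)
            # assuming residues are shaped [B, 1, H, W] -> collect as [B, H, W]
            residues_list.append(prediction_residues_tensor.squeeze(dim=1).cpu().numpy())

    # standard deviation over the ensemble axis -> [B, H, W]
    return np.std(np.stack(residues_list, axis=0), axis=0)
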
def UncertaintySTA(args):
    prediction_saving_dir = os.path.join(args.model_saving_dir,
                                         'reconstruction_results_dataset_{}_epoch_{}'.format(args.dataset_type,
                                                                                             args.epoch_idx))
    # initialize logger
    if os.path.exists(args.sta_save_dir):
        shutil.rmtree(args.sta_save_dir)
    os.mkdir(args.sta_save_dir)
    logger = Logger(args.sta_save_dir, 'uncertainty_distribution_sta.txt')
    logger.write_and_print('Dataset: {}'.format(args.data_root_dir))
    logger.write_and_print('Dataset type: {}'.format(args.dataset_type))
    # define the network
    network = VNet2d(num_in_channels=cfg.net.in_channels, num_out_channels=cfg.net.out_channels)

    # load the specified ckpt
    ckpt_dir = os.path.join(args.model_saving_dir, 'ckpt')
    # epoch_idx is specified -> load the specified ckpt
    if args.epoch_idx >= 0:
        ckpt_path = os.path.join(ckpt_dir, 'net_epoch_{}.pth'.format(args.epoch_idx))
    # epoch_idx is not specified -> load the best ckpt
    else:
        saved_ckpt_list = os.listdir(ckpt_dir)
        best_ckpt_filename = [best_ckpt_filename for best_ckpt_filename in saved_ckpt_list if
                              'net_best_on_validation_set' in best_ckpt_filename][0]
        ckpt_path = os.path.join(ckpt_dir, best_ckpt_filename)

    # transfer net into gpu devices
    net = copy.deepcopy(network)
    net = torch.nn.DataParallel(net).cuda()
    net.load_state_dict(torch.load(ckpt_path))
    net = net.eval()

    # decide whether uncertainty maps should be calculated (enabled when MC epoch indexes are given)
    calculate_uncertainty = len(args.mc_epoch_indexes) > 0

    # get net list for imitating MC dropout process
    net_list = None
    if calculate_uncertainty:
        net_list = get_net_list(network, ckpt_dir, args.mc_epoch_indexes, logger)

    # create dataset
    dataset = MicroCalcificationDataset(data_root_dir=args.data_root_dir,
                                        mode=args.dataset_type,
                                        enable_random_sampling=False,
                                        pos_to_neg_ratio=cfg.dataset.pos_to_neg_ratio,
                                        image_channels=cfg.dataset.image_channels,
                                        cropping_size=cfg.dataset.cropping_size,
                                        dilation_radius=args.dilation_radius,
                                        load_uncertainty_map=False,
                                        calculate_micro_calcification_number=cfg.dataset.calculate_micro_calcification_number,
                                        enable_data_augmentation=False)

    # create data loader
    data_loader = DataLoader(dataset, batch_size=args.batch_size,
                             shuffle=False, num_workers=cfg.train.num_threads)

    metrics = MetricsReconstruction(args.prob_threshold, args.area_threshold, args.distance_threshold,
                                    args.slack_for_recall)

    all_positive_uncertainty_in_dataset = np.zeros(args.bins)
    tp_uncertainty_in_dataset = np.zeros(args.bins)
    fn_uncertainty_in_dataset = np.zeros(args.bins)
    fp_uncertainty_in_dataset = np.zeros(args.bins)
    uncertainty_max = 0

    for batch_idx, (images_tensor, pixel_level_labels_tensor, pixel_level_labels_dilated_tensor, _,
                    image_level_labels_tensor, _, filenames) in enumerate(data_loader):
        # transfer the tensor into gpu device
        images_tensor = images_tensor.cuda()

        # network forward
        reconstructed_images_tensor, prediction_residues_tensor = net(images_tensor)

        # MC dropout
        uncertainty_maps_np = generate_uncertainty_maps(net_list, images_tensor) if calculate_uncertainty else None

        # evaluate the predictions so that TP / FN / FP areas can be located for the uncertainty distributions
        post_process_preds_np, calcification_num_batch_level, recall_num_batch_level, FP_num_batch_level, \
        result_flag_list = metrics.metric_batch_level(prediction_residues_tensor, pixel_level_labels_tensor)

        pixel_level_labels_dilated = pixel_level_labels_dilated_tensor.view(-1).numpy()
        preds_positive = post_process_preds_np.reshape(-1)
        uncertainty_maps = uncertainty_maps_np.reshape(-1)
        uncertainty_max = max(uncertainty_max, np.amax(uncertainty_maps))

        all_positive_uncertainty_batch = uncertainty_maps[pixel_level_labels_dilated == 1]
        all_positive_uncertainty_distr_batch, _ = np.histogram(all_positive_uncertainty_batch,
                                                               bins=args.bins, range=(0, args.bin_range))
        all_positive_uncertainty_in_dataset += all_positive_uncertainty_distr_batch

        # the background mask is the logical complement of the dilated label mask
        pixel_level_background_dilated = 1 - pixel_level_labels_dilated
        fp_location = np.multiply(pixel_level_background_dilated, preds_positive)
        tp_location = np.multiply(pixel_level_labels_dilated, preds_positive)
        # FN locations: labeled positive but not predicted positive
        fn_location = np.zeros_like(preds_positive)
        fn_location[pixel_level_labels_dilated == 1] = 1
        fn_location[preds_positive == 1] = 0

        tp_uncertainty_batch = uncertainty_maps[tp_location == 1]
        tp_uncertainty_distr_batch, _ = np.histogram(tp_uncertainty_batch, bins=args.bins, range=(0, args.bin_range))
        tp_uncertainty_in_dataset += tp_uncertainty_distr_batch

        fn_uncertainty_batch = uncertainty_maps[fn_location == 1]
        fn_uncertainty_distr_batch, _ = np.histogram(fn_uncertainty_batch, bins=args.bins, range=(0, args.bin_range))
        fn_uncertainty_in_dataset += fn_uncertainty_distr_batch

        fp_uncertainty_batch = uncertainty_maps[fp_location == 1]
        fp_uncertainty_distr_batch, _ = np.histogram(fp_uncertainty_batch, bins=args.bins, range=(0, args.bin_range))
        fp_uncertainty_in_dataset += fp_uncertainty_distr_batch

    # debug only
    # print(all_positive_uncertainty_in_dataset[0:5])
    # print(tp_uncertainty_in_dataset[0:5])
    # print(fn_uncertainty_in_dataset[0:5])

    all_positive_uncertainty_in_dataset[all_positive_uncertainty_in_dataset > 1000] = 1000
    tp_uncertainty_in_dataset[tp_uncertainty_in_dataset > 1000] = 1000
    fn_uncertainty_in_dataset[fn_uncertainty_in_dataset > 1000] = 1000
    fp_uncertainty_in_dataset[fp_uncertainty_in_dataset > 1000] = 1000

    pltsave(all_positive_uncertainty_in_dataset, dir=args.sta_save_dir, name='all positive uncertainty')
    pltsave(tp_uncertainty_in_dataset, dir=args.sta_save_dir, name='True Positive uncertainty')
    pltsave(fn_uncertainty_in_dataset, dir=args.sta_save_dir, name='False Negative uncertainty')
    pltsave(fp_uncertainty_in_dataset, dir=args.sta_save_dir, name='False Positive uncertainty')

    fp_uncertainty_in_dataset_filtered = gaussian_filter1d(fp_uncertainty_in_dataset, sigma=3)
    tp_uncertainty_in_dataset_filtered = gaussian_filter1d(tp_uncertainty_in_dataset, sigma=3)
    fn_uncertainty_in_dataset_filtered = gaussian_filter1d(fn_uncertainty_in_dataset, sigma=3)
    fp_and_fn = fp_uncertainty_in_dataset_filtered + fn_uncertainty_in_dataset_filtered

    pltsave(fp_and_fn, dir=args.sta_save_dir, name='FP & FN uncertainty filtered')
    pltsave(tp_uncertainty_in_dataset_filtered, dir=args.sta_save_dir, name='True Positive uncertainty filtered')
    pltsave(fn_uncertainty_in_dataset_filtered, dir=args.sta_save_dir, name='False Negative uncertainty filtered')
    pltsave(fp_uncertainty_in_dataset_filtered, dir=args.sta_save_dir, name='False Positive uncertainty filtered')

    # the lowest tenth of the bins is excluded from the statistics (bin indexes must be integers)
    fp_mean, fp_var = find_mean_and_var(fp_uncertainty_in_dataset_filtered, start=args.bins // 10, end=args.bins,
                                        bins=args.bins, bin_range=args.bin_range)
    tp_mean, tp_var = find_mean_and_var(tp_uncertainty_in_dataset_filtered, start=args.bins // 10, end=args.bins,
                                        bins=args.bins, bin_range=args.bin_range)
    fn_mean, fn_var = find_mean_and_var(fn_uncertainty_in_dataset_filtered, start=args.bins // 10, end=args.bins,
                                        bins=args.bins, bin_range=args.bin_range)
    logger.write_and_print('max uncertainty is {0}'.format(uncertainty_max))
    logger.write_and_print('fp uncertainty mean is {0}, variance is {1}'.format(fp_mean, fp_var))
    logger.write_and_print('tp uncertainty mean is {0}, variance is {1}'.format(tp_mean, tp_var))
    logger.write_and_print('fn uncertainty mean is {0}, variance is {1}'.format(fn_mean, fn_var))

    return
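

# A minimal sketch of find_mean_and_var (an assumption: the real helper is
# imported from elsewhere in this repo). Given histogram counts over
# [0, bin_range] split into `bins` buckets, it estimates the mean and variance
# of the underlying values from the bin centers, restricted to buckets
# [start, end).
def find_mean_and_var_sketch(hist, start, end, bins, bin_range):
    start, end = int(start), int(end)
    bin_width = bin_range / bins
    centers = (np.arange(bins) + 0.5) * bin_width

    counts = hist[start:end].astype(np.float64)
    centers = centers[start:end]
    total = counts.sum()

    # count-weighted mean and variance over the selected buckets
    mean = (counts * centers).sum() / total
    var = (counts * (centers - mean) ** 2).sum() / total

    return mean, var
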
def TestMicroCalcificationImageLevelClassification(args):
    start_time_for_epoch = time()

    prediction_saving_dir = os.path.join(
        args.model_saving_dir,
        'image_level_classification_results_dataset_{}_epoch_{}'.format(
            args.dataset_type, args.epoch_idx))
    visualization_saving_dir = os.path.join(prediction_saving_dir,
                                            'qualitative_results')

    TPs_saving_dir = os.path.join(visualization_saving_dir, 'TPs')
    TNs_saving_dir = os.path.join(visualization_saving_dir, 'TNs')
    FPs_saving_dir = os.path.join(visualization_saving_dir, 'FPs')
    FNs_saving_dir = os.path.join(visualization_saving_dir, 'FNs')

    # remove any existing dir with the same name, then create a clean one
    if os.path.exists(prediction_saving_dir):
        shutil.rmtree(prediction_saving_dir)
    os.mkdir(prediction_saving_dir)
    os.mkdir(visualization_saving_dir)
    os.mkdir(TPs_saving_dir)
    os.mkdir(TNs_saving_dir)
    os.mkdir(FPs_saving_dir)
    os.mkdir(FNs_saving_dir)

    # initialize logger
    logger = Logger(prediction_saving_dir, 'quantitative_results.txt')
    logger.write_and_print('Dataset: {}'.format(args.data_root_dir))
    logger.write_and_print('Dataset type: {}'.format(args.dataset_type))

    # define the network
    net = ResNet18(in_channels=cfg.net.in_channels,
                   num_classes=cfg.net.num_classes)

    # load the specified ckpt
    ckpt_dir = os.path.join(args.model_saving_dir, 'ckpt')
    # epoch_idx is specified -> load the specified ckpt
    if args.epoch_idx >= 0:
        ckpt_path = os.path.join(ckpt_dir,
                                 'net_epoch_{}.pth'.format(args.epoch_idx))
    # epoch_idx is not specified -> load the best ckpt
    else:
        saved_ckpt_list = os.listdir(ckpt_dir)
        best_ckpt_filename = [
            best_ckpt_filename for best_ckpt_filename in saved_ckpt_list
            if 'net_best_on_validation_set' in best_ckpt_filename
        ][0]
        ckpt_path = os.path.join(ckpt_dir, best_ckpt_filename)

    # transfer net into gpu devices
    net = torch.nn.DataParallel(net).cuda()
    net.load_state_dict(torch.load(ckpt_path))
    net = net.eval()

    logger.write_and_print('Load ckpt: {0}...'.format(ckpt_path))

    # create dataset and data loader
    dataset = MicroCalcificationDataset(
        data_root_dir=args.data_root_dir,
        mode=args.dataset_type,
        enable_random_sampling=False,
        pos_to_neg_ratio=cfg.dataset.pos_to_neg_ratio,
        image_channels=cfg.dataset.image_channels,
        cropping_size=cfg.dataset.cropping_size,
        dilation_radius=cfg.dataset.dilation_radius,
        load_uncertainty_map=False,
        calculate_micro_calcification_number=cfg.dataset.calculate_micro_calcification_number,
        enable_data_augmentation=False)

    data_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=cfg.train.num_threads)

    metrics = MetricsImageLevelClassification(cfg.dataset.cropping_size)

    TPs_epoch_level = 0
    TNs_epoch_level = 0
    FPs_epoch_level = 0
    FNs_epoch_level = 0

    for batch_idx, (images_tensor, pixel_level_labels_tensor, _, _,
                    image_level_labels_tensor, _,
                    filenames) in enumerate(data_loader):
        # start time of this batch
        start_time_for_batch = time()

        # transfer the tensor into gpu device
        images_tensor = images_tensor.cuda()

        # reshape the label to meet the requirement of CrossEntropy
        image_level_labels_tensor = image_level_labels_tensor.view(
            -1)  # [B, C] -> [B]

        # network forward
        preds_tensor = net(images_tensor)

        # evaluation
        _, classification_flag_np, TPs_batch_level, TNs_batch_level, FPs_batch_level, FNs_batch_level = \
            metrics.metric_batch_level(preds_tensor, image_level_labels_tensor)

        TPs_epoch_level += TPs_batch_level
        TNs_epoch_level += TNs_batch_level
        FPs_epoch_level += FPs_batch_level
        FNs_epoch_level += FNs_batch_level

        # print logging information
        logger.write_and_print(
            'The number of the TPs of this batch = {}'.format(TPs_batch_level))
        logger.write_and_print(
            'The number of the TNs of this batch = {}'.format(TNs_batch_level))
        logger.write_and_print(
            'The number of the FPs of this batch = {}'.format(FPs_batch_level))
        logger.write_and_print(
            'The number of the FNs of this batch = {}'.format(FNs_batch_level))
        logger.write_and_print(
            'batch: {}, batch_size: {}, consuming time: {:.4f}s'.format(
                batch_idx, args.batch_size,
                time() - start_time_for_batch))
        logger.write_and_print(
            '--------------------------------------------------------------------------------------'
        )

        images_np = images_tensor.cpu().numpy()
        pixel_level_labels_np = pixel_level_labels_tensor.numpy()
        for patch_idx in range(images_tensor.shape[0]):
            image_np = images_np[patch_idx, 0, :, :]
            pixel_level_label_np = pixel_level_labels_np[patch_idx, :, :]
            filename = filenames[patch_idx]
            classification_flag = classification_flag_np[patch_idx]

            assert image_np.shape == pixel_level_label_np.shape
            assert len(image_np.shape) == 2

            image_np *= 255
            image_np = image_np.astype(np.uint8)

            pixel_level_label_np *= 255
            pixel_level_label_np = pixel_level_label_np.astype(np.uint8)

            flag_2_dir_mapping = {0: 'TPs', 1: 'TNs', 2: 'FPs', 3: 'FNs'}
            saving_dir_of_this_patch = os.path.join(
                visualization_saving_dir,
                flag_2_dir_mapping[classification_flag])

            cv2.imwrite(
                os.path.join(saving_dir_of_this_patch,
                             filename.replace('.png', '_image.png')), image_np)
            cv2.imwrite(
                os.path.join(
                    saving_dir_of_this_patch,
                    filename.replace('.png', '_pixel_level_label.png')),
                pixel_level_label_np)

            if args.enable_CAM:
                result = generateCAM(net, image_np, "layer3")
                cv2.imwrite(
                    os.path.join(saving_dir_of_this_patch,
                                 filename.replace('.png', '_cam.png')), result)

    # print logging information
    logger.write_and_print(
        '##########################################################################################'
    )
    logger.write_and_print(
        'The number of the TPs of this dataset = {}'.format(TPs_epoch_level))
    logger.write_and_print(
        'The number of the TNs of this dataset = {}'.format(TNs_epoch_level))
    logger.write_and_print(
        'The number of the FPs of this dataset = {}'.format(FPs_epoch_level))
    logger.write_and_print(
        'The number of the FNs of this dataset = {}'.format(FNs_epoch_level))
    logger.write_and_print(
        'consuming time: {:.4f}s'.format(time() - start_time_for_epoch))
    logger.write_and_print(
        '##########################################################################################'
    )

    return
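

# A minimal sketch of a generateCAM helper (an assumption: the real
# implementation is imported from elsewhere in this repo). It computes a
# classic class activation map (Zhou et al., 2016): grab the feature maps of
# the named layer with a forward hook, weight them by the fc weights of the
# predicted class, then normalize, resize and colorize the result. The `fc`
# and `layer3` attribute names assume a torchvision-style ResNet18 layout.
def generateCAM_sketch(net, image_np, layer_name):
    features = {}

    def hook(_module, _inputs, output):
        features['maps'] = output.detach()

    module = net.module  # unwrap torch.nn.DataParallel
    handle = getattr(module, layer_name).register_forward_hook(hook)

    # [H, W] uint8 -> [1, 1, H, W] float tensor in [0, 1]
    input_tensor = torch.from_numpy(image_np.astype(np.float32) / 255).unsqueeze(0).unsqueeze(0).cuda()
    with torch.no_grad():
        preds = module(input_tensor)
    handle.remove()

    class_idx = preds.argmax(dim=1).item()
    fc_weights = module.fc.weight[class_idx]  # [C]
    feature_maps = features['maps'][0]        # [C, h, w]
    cam = torch.einsum('c,chw->hw', fc_weights, feature_maps).cpu().numpy()

    # normalize to [0, 255], resize to the input resolution and colorize
    cam -= cam.min()
    cam /= cam.max() + 1e-8
    cam = cv2.resize((cam * 255).astype(np.uint8), image_np.shape[::-1])

    return cv2.applyColorMap(cam, cv2.COLORMAP_JET)
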
def MicroCalcificationPatchLevelDatasetTest(args):
    # create dataset
    dataset = MicroCalcificationDataset(
        data_root_dir=cfg.general.data_root_dir,
        mode=args.mode,
        enable_random_sampling=False,
        pos_to_neg_ratio=cfg.dataset.pos_to_neg_ratio,
        image_channels=cfg.dataset.image_channels,
        cropping_size=cfg.dataset.cropping_size,
        dilation_radius=cfg.dataset.dilation_radius,
        load_uncertainty_map=args.load_uncertainty_map,
        calculate_micro_calcification_number=cfg.dataset.calculate_micro_calcification_number,
        enable_data_augmentation=cfg.dataset.augmentation.enable_data_augmentation,
        enable_vertical_flip=cfg.dataset.augmentation.enable_vertical_flip,
        enable_horizontal_flip=cfg.dataset.augmentation.enable_horizontal_flip)

    # create data loader for training
    data_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    # enumerating
    for epoch_idx in range(args.num_epoch):
        # create folder for this epoch
        output_dir_epoch = os.path.join(args.dst_data_root_dir,
                                        'epoch_{0}'.format(epoch_idx))
        os.mkdir(output_dir_epoch)

        print(
            '-------------------------------------------------------------------------------------------------------'
        )
        print('Loading epoch {0}...'.format(epoch_idx))

        # the following two variables are used for counting positive and negative patch number in an epoch
        positive_patch_num_for_this_epoch = 0
        negative_patch_num_for_this_epoch = 0

        for batch_idx, (images_tensor, pixel_level_labels_tensor,
                        pixel_level_labels_dilated_tensor,
                        uncertainty_maps_tensor, image_level_labels_tensor, _,
                        filenames) in enumerate(data_loader):
            # create folder for this batch
            output_dir_batch = os.path.join(output_dir_epoch,
                                            'batch_{0}'.format(batch_idx))
            os.mkdir(output_dir_batch)

            # create folder for saving positive patches
            output_dir_positive = os.path.join(output_dir_batch, 'positive')
            os.mkdir(output_dir_positive)

            # create folder for saving negative patches
            output_dir_negative = os.path.join(output_dir_batch, 'negative')
            os.mkdir(output_dir_negative)

            # the following two variables are used for counting positive and negative patch number in a batch
            positive_patch_num_for_this_batch = 0
            negative_patch_num_for_this_batch = 0

            images_np = images_tensor.cpu().numpy()
            pixel_level_labels_np = pixel_level_labels_tensor.cpu().numpy()
            pixel_level_labels_dilated_np = pixel_level_labels_dilated_tensor.cpu().numpy()
            uncertainty_maps_np = uncertainty_maps_tensor.cpu().numpy()
            image_level_labels_np = image_level_labels_tensor.cpu().numpy()

            for image_idx in range(images_np.shape[0]):
                image_np = images_np[image_idx, 0, :, :]
                pixel_level_label_np = pixel_level_labels_np[image_idx, :, :]
                pixel_level_label_dilated_np = pixel_level_labels_dilated_np[image_idx, :, :]
                uncertainty_map_np = uncertainty_maps_np[image_idx, :, :]
                image_level_label = image_level_labels_np[image_idx, 0]
                filename = filenames[image_idx]

                image_np *= 255
                image_np = image_np.astype(np.uint8)

                pixel_level_label_np *= 255
                pixel_level_label_np = pixel_level_label_np.astype(np.uint8)

                pixel_level_label_dilated_np *= 255
                pixel_level_label_dilated_np = pixel_level_label_dilated_np.astype(np.uint8)

                uncertainty_map_np *= 255
                uncertainty_map_np = uncertainty_map_np.astype(np.uint8)
                uncertainty_map_np = cv2.applyColorMap(uncertainty_map_np,
                                                       cv2.COLORMAP_JET)

                # image_level_label is either 0 or 1
                assert image_level_label in [0, 1]

                if image_level_label == 1:
                    cv2.imwrite(os.path.join(output_dir_positive, filename),
                                image_np)
                    cv2.imwrite(
                        os.path.join(output_dir_positive,
                                     filename.replace('.png', '_mask.png')),
                        pixel_level_label_np)
                    cv2.imwrite(
                        os.path.join(
                            output_dir_positive,
                            filename.replace('.png', '_dilated_mask.png')),
                        pixel_level_label_dilated_np)
                    cv2.imwrite(
                        os.path.join(
                            output_dir_positive,
                            filename.replace('.png', '_uncertainty_map.png')),
                        uncertainty_map_np)
                    positive_patch_num_for_this_epoch += 1
                    positive_patch_num_for_this_batch += 1
                elif image_level_label == 0:
                    cv2.imwrite(os.path.join(output_dir_negative, filename),
                                image_np)
                    cv2.imwrite(
                        os.path.join(output_dir_negative,
                                     filename.replace('.png', '_mask.png')),
                        pixel_level_label_np)
                    cv2.imwrite(
                        os.path.join(
                            output_dir_negative,
                            filename.replace('.png', '_dilated_mask.png')),
                        pixel_level_label_dilated_np)
                    cv2.imwrite(
                        os.path.join(
                            output_dir_negative,
                            filename.replace('.png', '_uncertainty_map.png')),
                        uncertainty_map_np)
                    negative_patch_num_for_this_epoch += 1
                    negative_patch_num_for_this_batch += 1

            print('----batch {0} loading finished; '
                  'positive patches: {1}, negative patches: {2}'.format(
                      batch_idx, positive_patch_num_for_this_batch,
                      negative_patch_num_for_this_batch))

        print('epoch {0} loading finished; '
              'positive patches: {1}, negative patches: {2}'.format(
                  epoch_idx, positive_patch_num_for_this_epoch,
                  negative_patch_num_for_this_epoch))

    return
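

# For reference, inferred from the unpacking patterns throughout this file
# (treat the exact ordering as an assumption): each sample yielded by
# MicroCalcificationDataset is a 7-tuple, which DataLoader collates into
#     (images_tensor,                            # [B, C, H, W], float in [0, 1]
#      pixel_level_labels_tensor,                # [B, H, W], binary mask
#      pixel_level_labels_dilated_tensor,        # [B, H, W], dilated binary mask
#      uncertainty_maps_tensor,                  # [B, H, W], meaningful only if load_uncertainty_map=True
#      image_level_labels_tensor,                # [B, 1], patch-level positive / negative label
#      micro_calcification_number_label_tensor,  # [B, 1], calcification count per patch
#      filenames)                                # tuple of B '*.png' strings
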
def TestMicroCalcificationDetectionPatchLevel(args):
    visualization_saving_dir = os.path.join(args.prediction_saving_dir, 'qualitative_results')

    # remove any existing dir with the same name, then create a clean one
    if os.path.exists(args.prediction_saving_dir):
        shutil.rmtree(args.prediction_saving_dir)
    os.mkdir(args.prediction_saving_dir)
    os.mkdir(visualization_saving_dir)

    # initialize logger
    logger = Logger(args.prediction_saving_dir, 'quantitative_results.txt')
    logger.write_and_print('Dataset: {}'.format(args.data_root_dir))
    logger.write_and_print('Dataset type: {}'.format(args.dataset_type))
    logger.write_and_print('Reconstruction model saving dir: {}'.format(args.reconstruction_model_saving_dir))
    logger.write_and_print('Reconstruction ckpt index: {}'.format(args.reconstruction_epoch_idx))
    logger.write_and_print('Classification model saving dir: {}'.format(args.classification_model_saving_dir))
    logger.write_and_print('Classification ckpt index: {}'.format(args.classification_epoch_idx))

    # define the reconstruction network
    reconstruction_net = VNet2d(num_in_channels=r_cfg.net.in_channels, num_out_channels=r_cfg.net.out_channels)
    #
    # get the reconstruction absolute ckpt path
    reconstruction_ckpt_path = get_ckpt_path(args.reconstruction_model_saving_dir, args.reconstruction_epoch_idx)
    #
    # load ckpt and transfer net into gpu devices
    reconstruction_net = torch.nn.DataParallel(reconstruction_net).cuda()
    reconstruction_net.load_state_dict(torch.load(reconstruction_ckpt_path))
    reconstruction_net = reconstruction_net.eval()
    #
    logger.write_and_print('Load ckpt: {0}...'.format(reconstruction_ckpt_path))

    # define the classification network
    classification_net = ResNet18(in_channels=c_cfg.net.in_channels, num_classes=c_cfg.net.num_classes)
    #
    # get the classification absolute ckpt path
    classification_ckpt_path = get_ckpt_path(args.classification_model_saving_dir, args.classification_epoch_idx)
    #
    # load ckpt and transfer net into gpu devices
    classification_net = torch.nn.DataParallel(classification_net).cuda()
    classification_net.load_state_dict(torch.load(classification_ckpt_path))
    classification_net = classification_net.eval()
    #
    logger.write_and_print('Load ckpt: {0}...'.format(classification_ckpt_path))

    # create dataset and data loader
    dataset = MicroCalcificationDataset(data_root_dir=args.data_root_dir,
                                        mode=args.dataset_type,
                                        enable_random_sampling=False,
                                        pos_to_neg_ratio=r_cfg.dataset.pos_to_neg_ratio,
                                        image_channels=r_cfg.dataset.image_channels,
                                        cropping_size=r_cfg.dataset.cropping_size,
                                        dilation_radius=args.dilation_radius,
                                        load_uncertainty_map=False,
                                        calculate_micro_calcification_number=r_cfg.dataset.calculate_micro_calcification_number,
                                        enable_data_augmentation=False)
    #
    data_loader = DataLoader(dataset, batch_size=args.batch_size,
                             shuffle=False, num_workers=r_cfg.train.num_threads)

    metrics = MetricsReconstruction(args.prob_threshold, args.area_threshold, args.distance_threshold)

    calcification_num = 0
    recall_num = 0
    FP_num = 0

    for batch_idx, (images_tensor, pixel_level_labels_tensor, pixel_level_labels_dilated_tensor, _,
                    image_level_labels_tensor, _, filenames) in enumerate(data_loader):
        # start time of this batch
        start_time_for_batch = time()

        # transfer the tensor into gpu device
        images_tensor = images_tensor.cuda()

        # reconstruction network forward
        reconstructed_images_tensor, prediction_residues_tensor = reconstruction_net(images_tensor)

        # classification network forward
        classification_preds_tensor = classification_net(images_tensor)

        # merge the reconstruction and the classification results
        detection_results_np = micro_calcification_detection_batch_level(prediction_residues_tensor,
                                                                         classification_preds_tensor)

        # evaluation
        post_process_preds_np, calcification_num_batch_level, recall_num_batch_level, FP_num_batch_level = \
            metrics.metric_batch_level(detection_results_np, pixel_level_labels_tensor)

        calcification_num += calcification_num_batch_level
        recall_num += recall_num_batch_level
        FP_num += FP_num_batch_level

        # print logging information
        logger.write_and_print(
            'The number of the annotated calcifications of this batch = {}'.format(calcification_num_batch_level))
        logger.write_and_print(
            'The number of the recalled calcifications of this batch = {}'.format(recall_num_batch_level))
        logger.write_and_print(
            'The number of the false positive calcifications of this batch = {}'.format(FP_num_batch_level))
        logger.write_and_print('batch: {}, consuming time: {:.4f}s'.format(batch_idx, time() - start_time_for_batch))
        logger.write_and_print('--------------------------------------------------------------------------------------')

        save_tensor_in_png_and_nii_format(images_tensor, reconstructed_images_tensor, prediction_residues_tensor,
                                          post_process_preds_np, pixel_level_labels_tensor,
                                          pixel_level_labels_dilated_tensor, filenames, visualization_saving_dir)

        logger.flush()

    logger.write_and_print('The number of the annotated calcifications of this dataset = {}'.format(calcification_num))
    logger.write_and_print('The number of the recalled calcifications of this dataset = {}'.format(recall_num))
    logger.write_and_print('The number of the false positive calcifications of this dataset = {}'.format(FP_num))

    return
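

# A minimal sketch of get_ckpt_path (an assumption: the real helper is imported
# from elsewhere in this repo). It factors out the checkpoint-resolution
# pattern written inline in the other test functions above: a non-negative
# epoch_idx selects that epoch's ckpt, otherwise the best-on-validation ckpt.
def get_ckpt_path_sketch(model_saving_dir, epoch_idx):
    ckpt_dir = os.path.join(model_saving_dir, 'ckpt')

    # epoch_idx is specified -> load the specified ckpt
    if epoch_idx >= 0:
        return os.path.join(ckpt_dir, 'net_epoch_{}.pth'.format(epoch_idx))

    # epoch_idx is not specified -> load the best ckpt
    best_ckpt_filename = [filename for filename in os.listdir(ckpt_dir)
                          if 'net_best_on_validation_set' in filename][0]

    return os.path.join(ckpt_dir, best_ckpt_filename)
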
def TestMicroCalcificationPixelLevelClassification(args):
    prediction_saving_dir = os.path.join(
        args.model_saving_dir,
        'pixel_level_classification_results_dataset_{}_epoch_{}'.format(
            args.dataset_type, args.epoch_idx))
    visualization_saving_dir = os.path.join(prediction_saving_dir,
                                            'qualitative_results')

    # remove any existing dir with the same name, then create a clean one
    if os.path.exists(prediction_saving_dir):
        shutil.rmtree(prediction_saving_dir)
    os.mkdir(prediction_saving_dir)
    os.mkdir(visualization_saving_dir)

    # initialize logger
    logger = Logger(prediction_saving_dir, 'quantitative_results.txt')
    logger.write_and_print('Dataset: {}'.format(args.data_root_dir))
    logger.write_and_print('Dataset type: {}'.format(args.dataset_type))

    # define the network
    net = VNet2d(num_in_channels=cfg.net.in_channels,
                 num_out_channels=cfg.net.out_channels)

    # load the specified ckpt
    ckpt_dir = os.path.join(args.model_saving_dir, 'ckpt')
    # epoch_idx is specified -> load the specified ckpt
    if args.epoch_idx >= 0:
        ckpt_path = os.path.join(ckpt_dir,
                                 'net_epoch_{}.pth'.format(args.epoch_idx))
    # epoch_idx is not specified -> load the best ckpt
    else:
        saved_ckpt_list = os.listdir(ckpt_dir)
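        # note (added): if several files match 'net_best_on_validation_set', the
        # [0] pick below depends on os.listdir() order; sorting the matches first
        # (e.g. sorted(matches)[0]) would make the choice deterministic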
        best_ckpt_filename = [
            best_ckpt_filename for best_ckpt_filename in saved_ckpt_list
            if 'net_best_on_validation_set' in best_ckpt_filename
        ][0]
        ckpt_path = os.path.join(ckpt_dir, best_ckpt_filename)

    # transfer net into gpu devices
    net = torch.nn.DataParallel(net).cuda()
    net.load_state_dict(torch.load(ckpt_path))
    net = net.eval()

    logger.write_and_print('Load ckpt: {0}...'.format(ckpt_path))

    # create dataset and data loader
    dataset = MicroCalcificationDataset(
        data_root_dir=args.data_root_dir,
        mode=args.dataset_type,
        enable_random_sampling=False,
        pos_to_neg_ratio=cfg.dataset.pos_to_neg_ratio,
        image_channels=cfg.dataset.image_channels,
        cropping_size=cfg.dataset.cropping_size,
        dilation_radius=args.dilation_radius,
        load_uncertainty_map=False,
        calculate_micro_calcification_number=cfg.dataset.calculate_micro_calcification_number,
        enable_data_augmentation=False)
    data_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=cfg.train.num_threads)

    metrics = MetricsPixelLevelClassification(args.prob_threshold,
                                              args.area_threshold,
                                              args.distance_threshold)
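    # note (added, presumed semantics): prob_threshold binarizes the per-pixel
    # probabilities, area_threshold filters tiny connected components, and
    # distance_threshold sets the match tolerance between predicted and annotated
    # calcifications -- inferred from the parameter names, not verified here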

    calcification_num = 0
    recall_num = 0
    FP_num = 0

    for batch_idx, (images_tensor, pixel_level_labels_tensor,
                    pixel_level_labels_dilated_tensor, _,
                    image_level_labels_tensor, _,
                    filenames) in enumerate(data_loader):
        # start time of this batch
        start_time_for_batch = time()

        # transfer the tensor into gpu device
        images_tensor = images_tensor.cuda()

        # network forward
        predictions_tensor = net(images_tensor)

        # extract channel 1 (the foreground class) from the classification results
        predictions_tensor = extract_classification_preds_channel(predictions_tensor, 1)

        # evaluation
        post_process_preds_np, calcification_num_batch_level, recall_num_batch_level, FP_num_batch_level = \
            metrics.metric_batch_level(predictions_tensor, pixel_level_labels_tensor)

        calcification_num += calcification_num_batch_level
        recall_num += recall_num_batch_level
        FP_num += FP_num_batch_level

        # print logging information
        logger.write_and_print(
            'The number of the annotated calcifications of this batch = {}'.format(calcification_num_batch_level))
        logger.write_and_print(
            'The number of the recalled calcifications of this batch = {}'.format(recall_num_batch_level))
        logger.write_and_print(
            'The number of the false positive calcifications of this batch = {}'.format(FP_num_batch_level))
        logger.write_and_print('batch: {}, consuming time: {:.4f}s'.format(batch_idx, time() - start_time_for_batch))
        logger.write_and_print('--------------------------------------------------------------------------------------')

        save_tensor_in_png_and_nii_format(images_tensor, predictions_tensor,
                                          post_process_preds_np,
                                          pixel_level_labels_tensor,
                                          pixel_level_labels_dilated_tensor,
                                          filenames, visualization_saving_dir)

        logger.flush()

    logger.write_and_print('The number of the annotated calcifications of this dataset = {}'.format(calcification_num))
    logger.write_and_print('The number of the recalled calcifications of this dataset = {}'.format(recall_num))
    logger.write_and_print('The number of the false positive calcifications of this dataset = {}'.format(FP_num))

    return


def SoftPositiveSTA(args):
    dataset = MicroCalcificationDataset(
        data_root_dir=cfg.general.data_root_dir,
        mode=args.mode,
        enable_random_sampling=False,
        pos_to_neg_ratio=cfg.dataset.pos_to_neg_ratio,
        image_channels=cfg.dataset.image_channels,
        cropping_size=cfg.dataset.cropping_size,
        dilation_radius=args.dilation_radius,
        calculate_micro_calcification_number=cfg.dataset.calculate_micro_calcification_number,
        enable_data_augmentation=cfg.dataset.augmentation.enable_data_augmentation,
        enable_vertical_flip=cfg.dataset.augmentation.enable_vertical_flip,
        enable_horizontal_flip=cfg.dataset.augmentation.enable_horizontal_flip)
    data_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers)

    print(
        '-------------------------------------------------------------------------------------------------------'
    )
    print('Loading on {0} dataset...'.format(args.mode))

    # the following variables count patches and pixels over the whole epoch
    positive_patch_num_total = 0
    negative_patch_num_total = 0
    pixels_total = 0
    label_pixels_total = 0
    dilated_label_pixel_total = 0

    # the dataset yields seven items per batch (cf. the loops above); the two
    # unused slots are skipped with '_'
    for batch_idx, (images_tensor, pixel_level_labels_tensor,
                    pixel_level_labels_dilated_tensor, _,
                    image_level_labels_tensor, _,
                    filenames) in enumerate(data_loader):
        images_np = images_tensor.cpu().view(-1).numpy()
        pixel_level_labels_np = pixel_level_labels_tensor.cpu().view(-1).numpy()
        pixel_level_labels_dilated_np = pixel_level_labels_dilated_tensor.cpu().view(-1).numpy()
        image_level_labels_np = image_level_labels_tensor.cpu().view(-1).numpy()

        pixels_total += images_np.shape[0]
        positive_patches_num = sum(image_level_labels_np[image_level_labels_np == 1])
        positive_patch_num_total += positive_patches_num
        negative_patch_num = image_level_labels_np.shape[0] - positive_patches_num
        negative_patch_num_total += negative_patch_num
        labels = pixel_level_labels_np
        labeled_pixels = sum(labels[labels == 1])
        label_pixels_total += labeled_pixels
        dilated_labels = pixel_level_labels_dilated_np
        dilated_label_pixels = sum(dilated_labels[dilated_labels == 1])
        dilated_label_pixel_total += dilated_label_pixels

        print('----batch {0} loading finished; '
              'positive patches: {1}, negative patches: {2}, '
              'labeled pixels number is {3}, dilated labeled pixels number is {4}, total pixels number is {5}'
              .format(batch_idx, positive_patches_num, negative_patch_num,
                      labeled_pixels, dilated_label_pixels, images_np.shape[0]))

        print()

    # 'soft positive' pixels are those gained by dilating the annotation mask
    # (dilated minus original labels); the ratio compares them against all
    # non-annotated pixels
    soft_negative_ratio = (dilated_label_pixel_total - label_pixels_total) / (
        pixels_total - label_pixels_total)
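    # illustrative arithmetic (hypothetical numbers, added): with 1,000,000 total
    # pixels, 2,000 annotated pixels and 10,000 dilated-label pixels, the soft
    # positives are 10,000 - 2,000 = 8,000 and the ratio is 8,000 / 998,000 ~ 0.008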
    print('on dataset {0} with dilation radius {1} loading finished; \n'
          'total patches: {2}, positive patches: {3}, negative patches: {4}\n'
          'total pixels number is {5}\n'
          'masked label pixels number is {6}\n'
          'dilated label pixels number is {7}\n'
          'soft label pixels number is {8}\n'
          'the ratio of soft positive to negative pixels is {9}'.format(
              args.mode, args.dilation_radius,
              positive_patch_num_total + negative_patch_num_total,
              positive_patch_num_total, negative_patch_num_total, pixels_total,
              label_pixels_total, dilated_label_pixel_total,
              dilated_label_pixel_total - label_pixels_total,
              soft_negative_ratio))

    return
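

# Hedged sketch (added, not original code): a minimal argparse setup covering the
# attributes the entry points above read from `args`. Flag names mirror the
# attribute names; every default below is an illustrative guess, not a project value.
def build_test_arg_parser():
    import argparse

    parser = argparse.ArgumentParser(description='micro-calcification test entry points')
    parser.add_argument('--data_root_dir', type=str, help='root dir of the patch dataset')
    parser.add_argument('--model_saving_dir', type=str, help='dir containing ckpt/ and result dirs')
    parser.add_argument('--dataset_type', type=str, help='dataset split to evaluate')
    parser.add_argument('--mode', type=str, help='dataset mode used by SoftPositiveSTA')
    parser.add_argument('--epoch_idx', type=int, default=-1, help='-1 -> load the best ckpt')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=1)
    parser.add_argument('--dilation_radius', type=int, default=1)
    parser.add_argument('--prob_threshold', type=float, default=0.5)
    parser.add_argument('--area_threshold', type=float, default=0.0)
    parser.add_argument('--distance_threshold', type=float, default=0.0)
    return parser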