def test_build_heatmap_no_exit():
    """A user who entered 'a' then 'b' but only exited 'a' shows up in 'b' at t=100."""
    movements = initialise_user_movements(
        1,
        [(2, 'a'), (4, 'b')],
        [(6, 'a')],
    )
    generator = HeatmapGenerator(1, movements)

    heatmap_at_100 = generator.build_heat_map(100)
    assert heatmap_at_100 == {'b': 1}
Example #2
0
def pre_processing(dict_images):
    """Build one training sample for a single, randomly chosen IVD.

    The MR volume is scaled by 1/2048 and clipped to [0, 1]. A landmark index
    in 10..18 whose coordinates contain no NaN is chosen uniformly at random.
    A heatmap is rendered for that landmark, and the mask is binarised to that
    landmark's label (``index + 1``). If the volume has more than 12 slices, a
    random 12-slice window along z is cropped. Finally every array is cropped
    around the landmark to ``(12, 64, 96)``.

    Args:
        dict_images: mapping with keys ``'MR'``, ``'Mask'`` and
            ``'list_landmarks'``; ``'MR'`` is unpacked as ``(_, D, H, W)``.

    Returns:
        ``[input, mask]`` where ``input`` is ``MR`` and the heatmap
        concatenated along the first axis.

    Raises:
        ValueError: if every landmark with index 10..18 contains NaN. The
            previous retry loop would have spun forever in that case.
    """
    MR = dict_images['MR']
    MR = np.clip(MR / 2048, a_min=0, a_max=1)
    Mask = dict_images['Mask']
    _, D, H, W = MR.shape

    heatmap_generator = HeatmapGenerator(image_size=(D, H, W),
                                         sigma=2.,
                                         scale_factor=1.,
                                         normalize=True,
                                         size_sigma_factor=8,
                                         sigma_scale_factor=2,
                                         dtype=np.float32)
    list_landmarks = dict_images['list_landmarks']

    # Picking uniformly among the valid (NaN-free) indices has the same
    # distribution as re-drawing randint(10, 18) until a valid index appears,
    # but it cannot loop forever when no valid landmark exists.
    candidates = [i for i in range(10, 19)
                  if not np.isnan(list_landmarks[i]).any()]
    if not candidates:
        raise ValueError(
            'pre_processing: every landmark with index 10..18 contains NaN')
    index = random.choice(candidates)
    landmark = list_landmarks[index]

    heatmap = heatmap_generator.generate_heatmap(landmark=landmark)[np.newaxis, :, :, :]
    # Keep only the chosen disc, whose label in Mask is index + 1.
    Mask = np.where(Mask == index + 1, 1, 0)

    if D > 12:
        # Apply the same random 12-slice z window to all three arrays.
        start = random.randint(0, D - 12)
        end = start + 12
        MR = crop(MR, start=start, end=end, axis='z')
        heatmap = crop(heatmap, start=start, end=end, axis='z')
        Mask = crop(Mask, start=start, end=end, axis='z')

    dsize = (12, 64, 96)
    MR = crop_to_center(MR, landmark, dsize=dsize)
    heatmap = crop_to_center(heatmap, landmark, dsize=dsize)
    Mask = crop_to_center(Mask, landmark, dsize=dsize)

    return [np.concatenate((MR, heatmap)), Mask]
    def execute(self, event_id):
        """Update the saved heatmap generator / predictor pair for an event.

        Returns the predictor's average region population by 30-minute buckets.

        On the first call for ``event_id``, all of the event's movements are
        loaded and a new ``HeatmapGenerator`` and ``RegionPopulationPredictor``
        are created. On later calls the saved pair is restored and only the
        movements after ``heatmap_gen.last_movement`` (up to the current whole
        second) are appended. The generator's heatmap history is fed into the
        predictor, then cleared before the pair is saved again.
        """
        saved = self.state_data.get_task_state(self.task_id, event_id)
        if saved is None:
            all_movements = self.log_source.retrieve_event_movements(
                event_id)
            heatmap_gen = HeatmapGenerator(event_id, all_movements)
            predictor = RegionPopulationPredictor()
        else:
            heatmap_gen, predictor = saved
            # Resume just after the last ingested movement. A falsy
            # last_movement (e.g. None or 0) restarts from time 0. The intent
            # is to exclude the still-running current second (assumes time_end
            # is exclusive; TODO confirm).
            previous = heatmap_gen.last_movement or -1
            new_movements = self.log_source.retrieve_event_movements(
                event_id,
                time_start=previous + 1,
                time_end=math.floor(time.time()))
            heatmap_gen.append_movements(new_movements)

        history = heatmap_gen.build_heat_map_history(100)[1]
        for timestamp, heatmap in history.items():
            predictor.add_heat_table(int(timestamp), heatmap)

        result = predictor.calculate_average_region_population_by_30_min_buckets()

        # Reset the generator's cached history before persisting the pair.
        heatmap_gen.historical_heatmaps = {}
        heatmap_gen.heatmap_times = []
        self.state_data.save_task_state(self.task_id, event_id,
                                        (heatmap_gen, predictor))

        return result
Example #4
0
def pre_processing(dict_images):
    """Return ``[MR, spine_heatmap]`` for one case.

    ``MR`` is scaled by 1/2048 and clipped to [0, 1]. The spine heatmap is
    rendered on the same (D, H, W) grid from ``dict_images['list_landmarks']``.
    When the volume has more than 12 slices, both arrays are cropped to the
    same random 12-slice window along z.
    """
    volume = np.clip(dict_images['MR'] / 2048, a_min=0, a_max=1)
    _, depth, height, width = volume.shape

    generator = HeatmapGenerator(image_size=(depth, height, width),
                                 sigma=2.,
                                 spine_heatmap_sigma=20,
                                 scale_factor=3.,
                                 spine_heatmap_scale_factor=10.,
                                 normalize=True,
                                 size_sigma_factor=8,
                                 sigma_scale_factor=2,
                                 dtype=np.float32)
    spine = generator.generate_spine_heatmap(
        list_landmarks=dict_images['list_landmarks'])

    if depth > 12:
        # Uniform start offset in [0, depth - 12].
        z_begin = random.randint(0, depth - 12)
        z_end = z_begin + 12
        volume = crop(volume, start=z_begin, end=z_end, axis='z')
        spine = crop(spine, start=z_begin, end=z_end, axis='z')

    return [volume, spine]
Example #5
0
def evaluate(prediction_dir, gt_dir):
    """Mean per-axis centroid error between predicted and ground-truth spine heatmaps.

    Each sub-directory of ``prediction_dir`` is treated as one case. For each
    case, ``pred_heatmap.nii.gz`` is read, and a ground-truth spine heatmap of
    the same (D, H, W) size is generated from
    ``<gt_dir>/<case_id>/landmarks_512.csv``. The absolute difference between
    their centres of mass is then recorded for each axis, in (D, H, W) array
    order.

    Args:
        prediction_dir: directory with one sub-directory per case.
        gt_dir: directory with matching case sub-directories holding
            ``landmarks_512.csv``.

    Returns:
        ``np.mean`` over cases of the per-axis absolute errors. If there are
        no case directories, numpy returns nan and emits a RuntimeWarning.
    """
    list_errors = []
    # Only directories are cases: a stray file (e.g. a summary) would make the
    # per-case paths below invalid. Sorting gives a reproducible print order.
    list_case_ids = sorted(
        name for name in os.listdir(prediction_dir)
        if os.path.isdir(os.path.join(prediction_dir, name)))
    for case_id in list_case_ids:
        pred = sitk.ReadImage(os.path.join(prediction_dir, case_id, 'pred_heatmap.nii.gz'))
        pred = sitk.GetArrayFromImage(pred)
        D, H, W = pred.shape

        landmarks = pd.read_csv(os.path.join(gt_dir, case_id, 'landmarks_512.csv'))
        list_landmarks = landmark_extractor(landmarks)
        heatmap_generator = HeatmapGenerator(image_size=(D, H, W),
                                             sigma=2.,
                                             spine_heatmap_sigma=20,
                                             scale_factor=1.,
                                             normalize=True,
                                             size_sigma_factor=6,
                                             sigma_scale_factor=1,
                                             dtype=np.float32)
        gt = heatmap_generator.generate_spine_heatmap(list_landmarks=list_landmarks)  # 4D

        pred_centroid = ndimage.center_of_mass(pred)
        gt_centroid = ndimage.center_of_mass(gt[0])
        error = [abs(p - g) for p, g in zip(pred_centroid, gt_centroid)]
        print(f"\n{case_id}:,\npred: {pred_centroid},\ngt: {gt_centroid},\nerror: {error}\n")
        list_errors.append(error)

    return np.mean(list_errors, axis=0)
def test_build_heatmap_basic():
    """Two users: at t=2 'a' and 'b' each hold one person; by t=10 all have left."""
    first_user = initialise_user_movements(
        1, [(2, 'a'), (6, 'b')], [(4, 'a'), (8, 'b')])
    second_user = initialise_user_movements(
        2, [(1, 'b')], [(9, 'b')])
    heatmap_gen = HeatmapGenerator(
        1, merge_user_movements(first_user, second_user))

    for moment, expected in ((2, {'a': 1, 'b': 1}), (10, {})):
        assert expected == heatmap_gen.build_heat_map(moment)
Example #7
0
def pre_processing(dict_images):
    """Return ``[MR, spine_heatmap]`` for one case, without any cropping.

    ``MR`` is scaled by 1/2048 and clipped to [0, 1]. The spine heatmap is
    rendered on the same (D, H, W) grid from ``dict_images['list_landmarks']``.
    """
    scaled_mr = np.clip(dict_images['MR'] / 2048, a_min=0, a_max=1)
    _, depth, height, width = scaled_mr.shape

    spine_heatmap = HeatmapGenerator(
        image_size=(depth, height, width),
        sigma=2.,
        spine_heatmap_sigma=20,
        scale_factor=1.,
        normalize=True,
        size_sigma_factor=6,
        sigma_scale_factor=1,
        dtype=np.float32,
    ).generate_spine_heatmap(list_landmarks=dict_images['list_landmarks'])

    return [scaled_mr, spine_heatmap]
Example #8
0
def pre_processing(dict_images):
    """Build a spine-centred training patch: MR, spine heatmap and per-landmark heatmaps.

    Steps:
      1. Scale MR by 1/2048 and clip it to [0, 1].
      2. Render the spine heatmap and the per-landmark heatmaps on the
         (D, H, W) grid.
      3. Crop a ``2 * (W // 4)``-wide window along x, centred on the x
         coordinate of the spine heatmap's centre of mass. The window is
         clamped so it stays inside the image.
      4. If ``D > 12``, crop a random 12-slice window along z.
      5. Crop an ``H // 2``-high window along y, starting at 0, H // 4 or
         H // 2 (chosen at random).

    Returns:
        ``[MR, spine_heatmap, heatmaps]`` after the same crops.
    """
    MR = dict_images['MR']
    MR = np.clip(MR / 2048, a_max=1, a_min=0)
    _, D, H, W = MR.shape

    heatmap_generator = HeatmapGenerator(image_size=(D, H, W),
                                         sigma=2.,
                                         spine_heatmap_sigma=20,
                                         scale_factor=1.,
                                         normalize=True,
                                         size_sigma_factor=6,
                                         sigma_scale_factor=1,
                                         dtype=np.float32)
    spine_heatmap = heatmap_generator.generate_spine_heatmap(
        list_landmarks=dict_images['list_landmarks'])
    heatmaps = heatmap_generator.generate_heatmaps(
        list_landmarks=dict_images['list_landmarks'])
    centroid_coordinate = [
        round(i) for i in ndimage.center_of_mass(spine_heatmap)
    ]  # (0, z, y, x)

    half_w = W // 4
    # Keep [start_x, start_x + 2 * half_w) inside [0, W). With the spine near
    # the left or right border, centroid - W // 4 would be negative (or the end
    # would pass W), giving a wrong or empty slice.
    # NOTE(review): if crop() already pads out-of-range windows, this shifts
    # the window instead of padding it; confirm which is intended.
    start_x = min(max(centroid_coordinate[-1] - half_w, 0), W - 2 * half_w)
    end_x = start_x + 2 * half_w
    MR = crop(MR, start=start_x, end=end_x, axis='x')
    spine_heatmap = crop(spine_heatmap, start=start_x, end=end_x, axis='x')
    heatmaps = crop(heatmaps, start=start_x, end=end_x, axis='x')

    if D > 12:
        start_z = random.randint(0, D - 12)
        end_z = start_z + 12
        MR = crop(MR, start=start_z, end=end_z, axis='z')
        spine_heatmap = crop(spine_heatmap, start=start_z, end=end_z, axis='z')
        heatmaps = crop(heatmaps, start=start_z, end=end_z, axis='z')

    # FIXME crop patches
    start_y = random.choice((0, H // 4, H // 2))
    end_y = start_y + H // 2
    MR = crop(MR, start=start_y, end=end_y, axis='y')
    spine_heatmap = crop(spine_heatmap, start=start_y, end=end_y, axis='y')
    heatmaps = crop(heatmaps, start=start_y, end=end_y, axis='y')

    # Shapes noted by the original author (presumably for H == W == 512):
    # (1, 12, 256, 256), (1, 12, 256, 256), (19, 12, 256, 256)
    return [MR, spine_heatmap, heatmaps]
def test_build_heatmap_advanced():
    """Three users with overlapping visits.

    At t=5, 'a' holds 1 person and 'b' holds 2; at t=19 the heatmap is empty.
    """
    user_one = initialise_user_movements(
        1, [(2, 'a'), (4, 'b')], [(6, 'a'), (8, 'b')])
    user_two = initialise_user_movements(
        2, [(1, 'b')], [(19, 'b')])
    user_three = initialise_user_movements(
        3, [(1, 'a'), (3, 'a')], [(2, 'a'), (7, 'a')])

    heatmap_gen = HeatmapGenerator(
        1, merge_user_movements(user_one, user_two, user_three))

    for moment, expected in ((5, {'a': 1, 'b': 2}), (19, {})):
        assert expected == heatmap_gen.build_heat_map(moment)
Example #10
0
    def execute(self, event_id):
        """Bring the saved heatmap generator for ``event_id`` up to date.

        The generator's heatmap history is returned, and the updated generator
        is saved back to ``self.state_data``.

        Returns:
            ``{"timestamps": times, "data": heatmaps}``, as produced by
            ``build_heat_map_history(time_interval=DEFAULT_TIME_INTERVAL)``.
        """
        heatmap_gen = self.state_data.get_task_state(self.task_id, event_id)
        if heatmap_gen is None:
            heatmap_gen = HeatmapGenerator(
                event_id, self.log_source.retrieve_event_movements(event_id))
        else:
            # Fetch only movements after the last one already ingested. A falsy
            # last_movement (e.g. None or 0) restarts from time 0. The upper
            # bound is the current whole second (intended to exclude it;
            # assumes time_end is exclusive, TODO confirm).
            resume_from = (heatmap_gen.last_movement or -1) + 1
            fresh_movements = self.log_source.retrieve_event_movements(
                event_id,
                time_start=resume_from,
                time_end=math.floor(time.time()))
            heatmap_gen.append_movements(fresh_movements)

        times, heatmaps = heatmap_gen.build_heat_map_history(
            time_interval=DEFAULT_TIME_INTERVAL)

        # Persist the updated generator so the next call resumes from it.
        self.state_data.save_task_state(self.task_id, event_id, heatmap_gen)
        return {"timestamps": times, "data": heatmaps}
def test_build_heatmap_history_basic_auto_duration():
    """A history sampled every 2 time units is keyed by string timestamps '3'..'9'."""
    merged = merge_user_movements(
        initialise_user_movements(1, [(2, 'a'), (6, 'b')], [(4, 'a'), (8, 'b')]),
        initialise_user_movements(2, [(1, 'b')], [(9, 'b')]),
    )
    heatmap_gen = HeatmapGenerator(1, merged)

    expected = dict(zip(
        ('3', '5', '7', '9'),
        ({'a': 1, 'b': 1}, {'b': 1}, {'b': 2}, {}),
    ))
    assert expected == heatmap_gen.build_heat_map_history(2)[1]
def test_heatmap_history_with_appended_movements_halfway_1():
    """The heatmap history is extended correctly after movements are appended.

    The history is built once from the initial movements ('3' and '5').
    Movements for user 1 in room 'c' are then appended, and the rebuilt
    history must contain both the earlier entries and the new ones
    ('7' to '13').
    """
    entries1 = [(2, 'a'), (6, 'b')]
    exits1 = [(4, 'a')]
    m1 = initialise_user_movements(1, entries1, exits1)

    entries2 = [(1, 'b')]
    exits2 = []
    m2 = initialise_user_movements(2, entries2, exits2)

    heatmap_gen = HeatmapGenerator(1, merge_user_movements(m1, m2))

    expected_at_3 = {'a': 1, 'b': 1}
    expected_at_5 = {'b': 1}

    expected = {'3': expected_at_3, '5': expected_at_5}
    assert expected == heatmap_gen.build_heat_map_history(2)[1]

    # Movements that arrive after the first history build.
    appended = initialise_user_movements(1, [(10, 'c')], [(12, 'c'), (13, 'c')])
    heatmap_gen.append_movements(appended)

    expected_at_7 = {'b': 2}
    expected_at_9 = {'b': 2}
    expected_at_11 = {'b': 1, 'c': 1}
    expected_at_13 = {'b': 2}

    expected = {
        '3': expected_at_3,
        '5': expected_at_5,
        '7': expected_at_7,
        '9': expected_at_9,
        '11': expected_at_11,
        '13': expected_at_13
    }
    assert expected == heatmap_gen.build_heat_map_history(2)[1]
Example #13
0
def inference(trainer, list_case_dirs, save_path, do_TTA=False):
    """Predict a per-IVD segmentation mask for every case and save it as NIfTI.

    For each case directory, the images are read and pre-processed. Each IVD
    landmark without NaN coordinates gets its own heatmap, which is stacked
    with the MR volume as a 2-channel input, cropped around the landmark to
    (12, 64, 96) and segmented via ``test_time_augmentation``. Foreground
    voxels for landmark ``index`` are labelled ``index + 11`` and added into a
    full-size mask. Predictions from different landmarks are summed, so
    overlapping predictions add their labels. The mask is written to
    ``<save_path>/<case_id>/pred_IVDMask.nii.gz`` using the image information
    of ``<case_dir>/MR_512.nii.gz``.

    Args:
        trainer: object exposing ``setting.network`` and ``setting.device``.
        list_case_dirs: paths of the case directories to process.
        save_path: output root directory; created (with parents) if missing.
        do_TTA: if True, the modes ``[[], [2], [4], [2, 4]]`` are passed to
            ``test_time_augmentation`` (presumably flip dims); otherwise
            only ``[[]]`` is passed.

    Raises:
        FileNotFoundError: if a case directory does not exist.
    """
    os.makedirs(save_path, exist_ok=True)

    TTA_mode = [[], [2], [4], [2, 4]] if do_TTA else [[]]
    dsize = (12, 64, 96)

    with torch.no_grad():
        trainer.setting.network.eval()
        for case_dir in tqdm(list_case_dirs):
            if not os.path.exists(case_dir):
                raise FileNotFoundError(case_dir + ' does not exist!')
            # normpath first so a trailing '/' does not yield an empty id.
            case_id = os.path.basename(os.path.normpath(case_dir))

            dict_images = read_data(case_dir)
            list_images = pre_processing(dict_images)
            MR = list_images[0]
            list_IVD_landmarks = list_images[1]

            C, D, H, W = MR.shape
            # Every per-landmark prediction is added into this tensor.
            pred_Mask = torch.zeros(C, D, H, W).to(trainer.setting.device)
            heatmap_generator = HeatmapGenerator(image_size=(D, H, W),
                                                 sigma=2.,
                                                 scale_factor=1.,
                                                 normalize=True,
                                                 size_sigma_factor=8,
                                                 sigma_scale_factor=2,
                                                 dtype=np.float32)

            for index, landmark in enumerate(list_IVD_landmarks):
                if np.isnan(landmark).any():
                    continue

                heatmap = heatmap_generator.generate_heatmap(landmark)[
                    np.newaxis, :, :, :]  # (1, D, H, W)
                input_ = np.concatenate((MR, heatmap), axis=0)  # (2, D, H, W)
                input_, patch, pad = crop_to_center(input_,
                                                    landmark=landmark,
                                                    dsize=dsize)

                if D > 12:
                    # Run the first and the last 12 slices as a batch of two;
                    # post_processing receives D to combine them (output
                    # noted as (1, 2, D, 128, 128)).
                    input_ = np.stack(
                        (input_[:, :12, :, :], input_[:, -12:, :, :]),
                        axis=0)  # (2, 2, 12, H, W)
                    input_ = torch.from_numpy(input_).to(
                        trainer.setting.device)
                    pred_IVDMask = test_time_augmentation(
                        trainer, input_, TTA_mode)
                    pred_IVDMask = post_processing(
                        pred_IVDMask, D,
                        device=trainer.setting.device)  # (1, 2, D, 128, 128)
                else:
                    input_ = torch.from_numpy(input_).unsqueeze(0).to(
                        trainer.setting.device)
                    pred_IVDMask = test_time_augmentation(
                        trainer, input_, TTA_mode)

                pred_IVDMask = nn.Softmax(dim=1)(pred_IVDMask)
                pred_IVDMask = torch.argmax(pred_IVDMask, dim=1)  # (1, d, h, w)

                bh, eh, bw, ew = patch
                pad_h_1, pad_h_2, pad_w_1, pad_w_2 = pad
                # Strip the padding crop_to_center added, so the prediction
                # lines up with the in-image region (bh:eh, bw:ew).
                if pad_h_1 > 0:
                    pred_IVDMask = pred_IVDMask[:, :, pad_h_1:, :]
                if pad_h_2 > 0:
                    pred_IVDMask = pred_IVDMask[:, :, :-pad_h_2, :]
                if pad_w_1 > 0:
                    pred_IVDMask = pred_IVDMask[:, :, :, pad_w_1:]
                if pad_w_2 > 0:
                    pred_IVDMask = pred_IVDMask[:, :, :, :-pad_w_2]

                # Label this disc index + 11 and add it in place. This matches
                # the previous zero-filled full-size temp followed by
                # pred_Mask += temp, without allocating a new volume per landmark.
                pred_IVDMask = torch.where(pred_IVDMask > 0, index + 11, 0)
                pred_Mask[:, :, bh:eh, bw:ew] += pred_IVDMask

            pred_Mask = pred_Mask.cpu().numpy()  # (C, D, H, W)

            # Save the prediction as a NIfTI image with the template's image info.
            template_nii = sitk.ReadImage(os.path.join(case_dir, 'MR_512.nii.gz'))

            prediction_nii = sitk.GetImageFromArray(pred_Mask[0])
            prediction_nii = copy_sitk_imageinfo(template_nii, prediction_nii)
            case_save_dir = os.path.join(save_path, case_id)
            os.makedirs(case_save_dir, exist_ok=True)
            sitk.WriteImage(prediction_nii,
                            os.path.join(case_save_dir, 'pred_IVDMask.nii.gz'))