Example #1
def get_random_geometry(dict_of_g):
    g = SpatialInfo(
        shape=np.random.randint(0, 8, size=[3]),
        origin=np.random.uniform(-4, 6, size=[3]),
        spacing=np.random.uniform(0.5, 1.5, size=[3]),
    )
    return g
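
dict_of_g is presumably the mapping of batch names to their current SpatialInfo geometries that TransformResample passes to the functor; this functor ignores it and returns an independent random geometry that is then shared by every volume (see Example #2). A minimal sketch of calling such a functor on its own, assuming SpatialInfo and numpy are imported as in the examples (the geometry values are illustrative):

geometries = {
    'v1': SpatialInfo(origin=(0, 0, 0), spacing=(1, 1, 1), shape=(10, 9, 8)),
    'v2': SpatialInfo(origin=(0, 0, 0), spacing=(1, 1, 1), shape=(10, 9, 8)),
}
# a single random geometry: applying it to every entry keeps the volumes aligned
g = get_random_geometry(geometries)
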
Example #2
    def test_transform_resampling_random_dependent(self):
        """
        Use a functor to randomly generate a resampling geometry.

        All volumes MUST use the same geometry.
        """
        batch = {
            'v1': torch.arange(10 * 9 * 8).reshape((1, 1, 10, 9, 8)),
            'v2': torch.arange(10 * 9 * 8).reshape((1, 1, 10, 9, 8)),
            'other': 'other_value'
        }

        get_spatial_info = functools.partial(
            TestTransformResample.get_spatial_info_generic,
            geometry=SpatialInfo(origin=(0, 0, 0),
                                 spacing=(1, 1, 1),
                                 shape=(10, 9, 8)))

        transform = trw.transforms.TransformResample(
            get_spatial_info_from_batch_name=get_spatial_info,
            resampling_geometry=TestTransformResample.get_random_geometry,
        )

        batch_transformed = transform(batch)
        assert len(batch_transformed) == 3
        assert batch_transformed['other'] == 'other_value'
        assert batch_transformed['v1'].shape == batch_transformed['v2'].shape
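
Because the functor appears to be evaluated once per batch and the resulting geometry is reused for every volume (otherwise the random shapes of 'v1' and 'v2' would rarely match), the final assert only needs to compare shapes. The same mechanism also accepts a deterministic functor; a minimal sketch under that assumption (the grid values are illustrative):

def fixed_resampling_geometry(dict_of_g):
    # ignore the per-volume geometries and always resample on the same grid
    return SpatialInfo(origin=(0, 0, 0), spacing=(2, 2, 2), shape=(5, 5, 4))

transform = trw.transforms.TransformResample(
    get_spatial_info_from_batch_name=get_spatial_info,
    resampling_geometry=fixed_resampling_geometry)
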
Example #3
    def test_spatial_info_coordinate_mapping(self):
        origin = torch.tensor([10, 11, 12], dtype=torch.float32)
        spacing = torch.tensor([2, 3, 4], dtype=torch.float32)
        pst = affine_transformation_translation(origin).mm(
            affine_transformation_scale(spacing))
        si = SpatialInfo(shape=[20, 21, 22], patient_scale_transform=pst)

        origin_flipped = torch.flip(origin, (0, ))
        spacing_flipped = torch.flip(spacing, (0, ))
        si_2 = SpatialInfo(shape=[20, 21, 22],
                           origin=origin_flipped,
                           spacing=spacing_flipped)
        assert (si.patient_scale_transform -
                si_2.patient_scale_transform).abs().max() < 1e-5

        # index (0, 0, 0) is origin!
        p = si.index_to_position(index_zyx=torch.tensor([0, 0, 0]))
        assert len(p.shape) == 1
        assert p.shape[0] == 3

        assert (p - torch.tensor([12, 11, 10],
                                 dtype=torch.float32)).abs().max() < 1e-5
        assert (si.position_to_index(position_zyx=p) - torch.tensor(
            [0, 0, 0], dtype=torch.float32)).abs().max() < 1e-5

        # move in one direction from the origin
        for n in range(20):
            dp = torch.empty(3).uniform_(0, 10).type(torch.float32)
            p = origin_flipped + dp * spacing_flipped
            i = si.position_to_index(position_zyx=p)
            assert (i - dp).abs().max() < 1e-5

            p_back = si.index_to_position(index_zyx=i)
            assert (p - p_back).abs().max() < 1e-5
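
With no rotation, the mapping these assertions exercise reduces to an elementwise affine relation between zyx indices and zyx positions: position = origin + spacing * index, using the origin/spacing values passed to the SpatialInfo constructor. A minimal numeric check in plain torch (no trw), reusing the flipped values above:

import torch

origin = torch.tensor([12., 11., 10.])    # origin_flipped above
spacing = torch.tensor([4., 3., 2.])      # spacing_flipped above
index = torch.tensor([5., 6., 7.])

position = origin + spacing * index           # index_to_position, no rotation
index_back = (position - origin) / spacing    # position_to_index, no rotation
assert (index_back - index).abs().max() < 1e-5
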
Example #4
    def test_spatial_info_pst(self):
        pst = affine_transformation_translation([10, 11, 12]).mm(
            affine_transformation_rotation_3d_x(0.3)).mm(
                affine_transformation_scale([2, 3, 4]))

        si = SpatialInfo(shape=[20, 21, 22], patient_scale_transform=pst)

        assert (si.spacing == np.asarray([4, 3, 2])).all()
        assert (si.origin == np.asarray([12, 11, 10])).all()
Example #5
    def test_sub_geometry(self):
        pst = affine_transformation_translation([10, 11, 12]).mm(
            affine_transformation_rotation_3d_x(0.3)).mm(
                affine_transformation_scale([2, 3, 4]))

        si = SpatialInfo(shape=[20, 21, 22], patient_scale_transform=pst)
        start_zyx = torch.tensor([5, 6, 7])
        end_zyx = torch.tensor([8, 10, 13])
        si_sub = si.sub_geometry(start_index_zyx=start_zyx,
                                 end_index_zyx_inclusive=end_zyx)

        o = si_sub.index_to_position(index_zyx=torch.tensor([0, 0, 0]))
        expected_o = si.index_to_position(index_zyx=start_zyx)
        assert (o - expected_o).abs().max() == 0

        index_e = end_zyx - start_zyx
        pos_e = si_sub.index_to_position(index_zyx=index_e)
        pos_e_expected = si.index_to_position(index_zyx=end_zyx)
        assert (pos_e - pos_e_expected).abs().max() <= 1e-5
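
Per these assertions, the sub-geometry keeps the parent's orientation and spacing and places its index (0, 0, 0) at the parent position of start_index_zyx, with its far corner on end_index_zyx_inclusive. The sub-geometry's shape is not asserted; assuming the _inclusive naming means both endpoints are kept, it would be end - start + 1:

import torch

start_zyx = torch.tensor([5, 6, 7])
end_zyx = torch.tensor([8, 10, 13])
# an assumption based on the parameter name, not asserted by the test above
expected_shape_zyx = end_zyx - start_zyx + 1   # tensor([4, 5, 7])
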
Example #6
    def test_transform_background_volume_dependent(self):
        batch = {
            'v1': torch.arange(10 * 9 * 8).reshape((1, 1, 10, 9, 8)),
            'v2': torch.arange(10 * 9 * 8).reshape((1, 1, 10, 9, 8)),
            'other': 'other_value'
        }

        get_spatial_info = functools.partial(
            TestTransformResample.get_spatial_info_generic,
            geometry=SpatialInfo(origin=(0, 0, 0),
                                 spacing=(1, 1, 1),
                                 shape=(10, 9, 8)))

        transform = trw.transforms.TransformResample(
            get_spatial_info_from_batch_name=get_spatial_info,
            resampling_geometry=SpatialInfo(origin=(0, 0, 0),
                                            spacing=(1, 1, 1),
                                            shape=(1, 9, 18)),
            constant_background_value={
                'v1': 42,
                'v2': 43
            })

        batch_transformed = transform(batch)
        assert len(batch_transformed) == 3
        assert batch_transformed['other'] == 'other_value'
        assert batch_transformed['v1'].shape == (1, 1, 1, 9, 18)
        assert batch_transformed['v2'].shape == (1, 1, 1, 9, 18)

        # voxels within the FoV of the source volumes
        assert (batch_transformed['v1'][0, 0, :, :, :8] ==
                batch['v1'][0, 0, 0:1, :, :8]).all()
        assert (batch_transformed['v2'][0, 0, :, :, :8] ==
                batch['v2'][0, 0, 0:1, :, :8]).all()

        # voxels in the background
        assert (batch_transformed['v1'][0, 0, 0:1, :, 8:] == 42).all()
        assert (batch_transformed['v2'][0, 0, 0:1, :, 8:] == 43).all()
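
The source volumes cover only x indices 0..7 of the wider target grid (x extent 18, same origin and spacing), so x indices 8..17 fall outside the field of view and receive the per-name constants from constant_background_value. A quick arithmetic check of the split the asserts rely on:

source_x_extent = 8       # x extent of 'v1' / 'v2'
target_x_extent = 18      # x extent of the resampling geometry
background_columns = target_x_extent - source_x_extent   # x indices 8..17 -> 10 columns filled with 42 / 43
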
Example #7
def get_spatial_info_type(batch: Batch, name: str) -> SpatialInfo:
    v = batch[name]
    assert len(v.shape) == 5
    assert v.shape[0] == 1
    assert v.shape[1] == 1

    affine_matrix = batch[name.replace('_voxels', '_affine')]
    assert len(affine_matrix.shape) == 3 and affine_matrix.shape[0] == 1
    affine_matrix = affine_matrix[0]

    spacing = get_spacing_from_4x4(affine_matrix)
    origin = get_translation_from_4x4(affine_matrix)
    return SpatialInfo(origin=origin, spacing=spacing, shape=v.shape[2:])
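
get_spacing_from_4x4 and get_translation_from_4x4 are trw helpers; as a rough guide, a minimal sketch of the standard decomposition they conceptually perform on a homogeneous 4x4 matrix (the real implementations, including their xyz/zyx ordering conventions, may differ):

import torch

def spacing_from_4x4_sketch(m: torch.Tensor) -> torch.Tensor:
    # voxel spacing = Euclidean norm of each column of the upper-left 3x3 block
    return m[:3, :3].norm(dim=0)

def translation_from_4x4_sketch(m: torch.Tensor) -> torch.Tensor:
    # translation = last column of the homogeneous matrix
    return m[:3, 3]
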
Example #8
    def test_transform_resampling_random_fixed(self):
        batch = {
            'v1': torch.arange(10 * 9 * 8).reshape((1, 1, 10, 9, 8)),
            'v2': torch.arange(1 * 1 * 1).reshape((1, 1, 1, 1, 1)),
        }

        get_spatial_info = functools.partial(
            TestTransformResample.get_spatial_info_generic,
            geometry=SpatialInfo(origin=(0, 0, 0),
                                 spacing=(1, 1, 1),
                                 shape=(10, 9, 8)))

        transform = trw.transforms.TransformResample(
            get_spatial_info_from_batch_name=get_spatial_info,
            resampling_geometry=functools.partial(
                random_fixed_geometry_within_geometries,
                fixed_geometry_shape=(3, 3, 3),
                fixed_geometry_spacing=(1, 1, 1)),
        )

        batch_transformed = transform(batch)
        assert len(batch_transformed) == 2
        assert batch_transformed['v1'].shape == batch_transformed['v2'].shape
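
Going by its name and parameters, random_fixed_geometry_within_geometries returns a geometry with the requested fixed shape and spacing placed randomly within the input geometries, so even the single-voxel 'v2' is resampled onto the same 3x3x3 grid as 'v1'. Following the shape pattern of Example #6, both outputs should be (1, 1, 3, 3, 3), although the test itself only asserts that the two shapes are equal.
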
Example #9
    def test_random_volumes(self):
        def fill_volume(v, nb):
            for n in range(nb):
                min_index = torch.tensor([
                    torch.randint(0, v.shape[0] - 5, size=(1, )),
                    torch.randint(0, v.shape[1] - 5, size=(1, )),
                    torch.randint(0, v.shape[2] - 5, size=(1, )),
                ])
                shape = torch.randint(10, 15, size=(3, ))
                value = torch.randint(255, size=(1, ))

                max_index = min_index + shape
                max_index = trw.utils.clamp_n(max_index,
                                              torch.tensor([0, 0, 0]),
                                              torch.tensor(v.shape)).squeeze(0)
                trw.utils.sub_tensor(v, min_index, max_index)[:] = value

        np.random.seed(0)
        torch.manual_seed(0)
        nb_points = 10000
        all_avg_errors = []
        for e in range(5):
            tfm = torch.from_numpy(np.asarray([
                [0.95, 0, 0, -0.5],
                [0, 0.9, 0, -5],
                [0, 0, 1, 1],
                [0, 0, 0, 1],
            ], dtype=np.float32))

            shape_moving_zyx = (42, 30, 40)
            shape_fixed_zyx = (37, 31, 42)

            moving_geometry = SpatialInfo(shape=shape_moving_zyx,
                                          origin=(10, 15, 20),
                                          spacing=(0.9, 1.1, 1.3))
            fixed_geometry = SpatialInfo(shape=shape_fixed_zyx,
                                         origin=(15, 5, 15),
                                         spacing=(1.3, 1.4, 1.1))

            moving = torch.zeros(shape_moving_zyx,
                                 dtype=torch.float32).unsqueeze(0).unsqueeze(0)
            fill_volume(moving[0, 0], 5)

            fixed = resample_spatial_info(geometry_moving=moving_geometry,
                                          moving_volume=moving,
                                          geometry_fixed=fixed_geometry,
                                          tfm=tfm)

            error = 0.0
            nb_voxels = 0

            # convert the moving shape to a tensor for the bounds check below
            shape_moving_zyx = torch.tensor(shape_moving_zyx)

            for _ in range(nb_points):
                index_fixed = torch.tensor([
                    torch.randint(shape_fixed_zyx[0], size=(1, )),
                    torch.randint(shape_fixed_zyx[1], size=(1, )),
                    torch.randint(shape_fixed_zyx[2], size=(1, )),
                ], dtype=torch.float32)

                index_fixed_i = index_fixed.type(torch.long)

                # transform: index (fixed) -> position in world space
                p = fixed_geometry.index_to_position(index_zyx=index_fixed)

                # apply moving transform
                p = apply_homogeneous_affine_transform_zyx(tfm, p)

                # transform position -> index (moving)
                index_moving = moving_geometry.position_to_index(
                    position_zyx=p)

                index_moving_rounded = index_moving.round().type(torch.long)
                if (index_moving_rounded >= 0).all() and \
                        ((index_moving_rounded - shape_moving_zyx) < 0).all():
                    value_resampled = fixed[0, 0][index_fixed_i[0],
                                                  index_fixed_i[1],
                                                  index_fixed_i[2]]
                    if value_resampled <= 2:
                        # discard background values
                        continue
                    value_moving = moving[0, 0][index_moving_rounded[0],
                                                index_moving_rounded[1],
                                                index_moving_rounded[2]]
                    error += (value_resampled - value_moving).abs()
                    nb_voxels += 1

            avg_error = error / nb_voxels
            print('avg=', avg_error, 'nb_voxels=', nb_voxels)
            # empirical value
            assert avg_error <= 25.0
            all_avg_errors.append(avg_error)
        assert np.mean(all_avg_errors) < 15
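
The loop validates resample_spatial_info through the chain fixed index -> world position (fixed_geometry.index_to_position) -> affine tfm (apply_homogeneous_affine_transform_zyx) -> moving index (moving_geometry.position_to_index), comparing each resampled value against a nearest-neighbour lookup in the moving volume. The thresholds (25.0 per run, 15 on average) are empirical: the resampler interpolates while the reference lookup rounds to the nearest voxel, so exact agreement is not expected.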