Exemplo n.º 1
0
def test_resample():
    """
    Exercise layer_util.resample on three scenarios.

    The first two compare a linear resampling of a tiny volume against
    hand-computed values (within check_equal's tolerance): once without a
    feature channel and once with a duplicated feature channel. The third
    passes a volume whose shape is incompatible with ``loc`` and expects a
    ValueError that mentions the inconsistency.
    """
    method = "linear"

    # Sampling locations shared by both passing scenarios. The volume values
    # follow 3x + y, so the expected values are 3x + y; the first row holds
    # locations outside the frame.
    sample_points = tf.constant(
        np.array(
            [[
                [[0, 0], [0, 1], [0, 3]],
                [[0.4, 0], [0.5, 1], [0.6, 2]],
                [[0.4, 0.7], [0.5, 0.5], [0.6, 0.3]],
            ]],
            dtype=np.float32,
        ))  # shape = [1,3,3,2]
    base_vol = np.array([[[0, 1, 2], [3, 4, 5]]],
                        dtype=np.float32)  # shape = [1,2,3]
    base_expected = np.array([[[0, 1, 2], [1.2, 2.5, 3.8], [1.9, 2, 2.1]]],
                             dtype=np.float32)  # shape = [1,3,3]

    # Scenario 1: vol has no feature channel - Pass
    result = layer_util.resample(vol=tf.constant(base_vol),
                                 loc=sample_points,
                                 interpolation=method)
    assert check_equal(tf.constant(base_expected), result)

    # Scenario 2: vol has two identical feature channels - Pass
    channel_vol = np.repeat(base_vol[..., None], 2,
                            axis=-1)  # shape = [1,2,3,2]
    channel_expected = np.repeat(base_expected[..., None], 2,
                                 axis=-1)  # shape = [1,3,3,2]
    result = layer_util.resample(vol=tf.constant(channel_vol),
                                 loc=sample_points,
                                 interpolation=method)
    assert check_equal(tf.constant(channel_expected), result)

    # Scenario 3: vol [1,1] is incompatible with loc [2,2] - Fail
    bad_vol = tf.constant(np.array([[0]], dtype=np.float32))
    bad_loc = tf.constant(np.array([[0, 0], [0, 0]], dtype=np.float32))
    with pytest.raises(ValueError) as execinfo:
        layer_util.resample(vol=bad_vol, loc=bad_loc, interpolation=method)
    # collapse whitespace so a multi-line error message still matches
    normalised_msg = " ".join(execinfo.value.args[0].split())
    assert "vol shape inconsistent with loc" in normalised_msg
Exemplo n.º 2
0
 def test_shape_error(self):
     """
     resample must raise a ValueError when vol and loc have incompatible
     shapes (here vol is [1,1] and loc is [2,2]).
     """
     bad_vol = tf.constant(np.array([[0]], dtype=np.float32))
     bad_loc = tf.constant(np.array([[0, 0], [0, 0]], dtype=np.float32))
     with pytest.raises(ValueError) as raised:
         layer_util.resample(vol=bad_vol, loc=bad_loc)
     assert "vol shape inconsistent with loc" in str(raised.value)
Exemplo n.º 3
0
 def test_interpolation_error(self):
     """
     resample must reject an interpolation mode other than linear with a
     ValueError.

     NOTE(review): vol ([1,1]) and loc ([2,2]) are also shape-inconsistent,
     so this test relies on resample validating the interpolation mode
     before the shapes - confirm against layer_util.resample.
     """
     bad_vol = tf.constant(np.array([[0]], dtype=np.float32))
     bad_loc = tf.constant(np.array([[0, 0], [0, 0]], dtype=np.float32))
     with pytest.raises(ValueError) as raised:
         layer_util.resample(vol=bad_vol, loc=bad_loc, interpolation="nearest")
     expected_msg = "resample supports only linear interpolation"
     assert expected_msg in str(raised.value)
Exemplo n.º 4
0
    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Resample ``image`` at the reference grid shifted by ``ddf``.

        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
          - image, shape = (batch, m_dim1, m_dim2, m_dim3)
        :param kwargs: additional arguments, not used here.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        displacement, moving = inputs
        # self.grid_ref offset by the per-voxel displacement gives the
        # coordinates at which the image is sampled
        sample_locations = self.grid_ref + displacement
        return layer_util.resample(vol=moving, loc=sample_locations)
Exemplo n.º 5
0
    def transform(image: tf.Tensor, grid_ref: tf.Tensor,
                  params: tf.Tensor) -> tf.Tensor:
        """
        Displace the reference grid by ``params`` and resample the image on it.

        A leading axis of size 1 is added to ``grid_ref`` so it broadcasts
        against the batched displacements.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = (dim1, dim2, dim3, 3)
        :param params: DDF, shape = (batch, dim1, dim2, dim3, 3)
        :return: shape = (batch, dim1, dim2, dim3)
        """
        batched_grid = grid_ref[None, ...]  # (1, dim1, dim2, dim3, 3)
        displaced_grid = batched_grid + params
        return resample(vol=image, loc=displaced_grid)
Exemplo n.º 6
0
    def _transform(image, grid_ref, transforms):
        """
        Warp ``grid_ref`` with ``transforms`` and resample ``image`` on it.

        :param image: shape = [batch, dim1, dim2, dim3]
        :param grid_ref: shape = [dim1, dim2, dim3, 3]
        :param transforms: shape = [batch, 4, 3]
        :return: shape = [batch, dim1, dim2, dim3]
        """
        warped_grid = layer_util.warp_grid(grid_ref, transforms)
        return layer_util.resample(vol=image, loc=warped_grid)
Exemplo n.º 7
0
    def test_repeat_extrapolation(self, channel):
        """
        With zero_boundary=False, resample should equal 3x + y evaluated at
        the sampling locations clipped to [x_min, x_max] x [y_min, y_max].

        :param channel: when > 0, a trailing feature axis of this size is
            added to both the volume and the expectation via tf.repeat.
        """
        clipped_x = tf.clip_by_value(self.loc[..., 0], self.x_min, self.x_max)
        clipped_y = tf.clip_by_value(self.loc[..., 1], self.y_min, self.y_max)
        expected = 3 * clipped_x + clipped_y

        vol = self.vol
        if channel > 0:
            vol, expected = [
                tf.repeat(tensor[..., None], channel, axis=-1)
                for tensor in (vol, expected)
            ]

        got = layer_util.resample(vol=vol, loc=self.loc, zero_boundary=False)
        assert is_equal_tf(expected, got)
Exemplo n.º 8
0
    def _transform(image, grid_ref, transforms):
        """
        Resample an input image over the reference grid after the grid has
        been warped by the batch of input transforms.

        :param image: shape = (batch, dim1, dim2, dim3)
        :param grid_ref: shape = [dim1, dim2, dim3, 3]
        :param transforms: shape = [batch, 4, 3]
        :return: shape = (batch, dim1, dim2, dim3)
        """
        return layer_util.resample(
            vol=image,
            loc=layer_util.warp_grid(grid_ref, transforms),
        )
Exemplo n.º 9
0
    def test_repeat_zero_bound(self, channel):
        """
        With zero_boundary=True, resample should equal 3x + y where
        x_min < x <= x_max and y_min < y <= y_max, and 0 elsewhere.

        :param channel: when > 0, a trailing feature axis of this size is
            added to both the volume and the expectation via tf.repeat.
        """
        x_coord = self.loc[..., 0]
        y_coord = self.loc[..., 1]
        # 0/1 masks for each half-open bound
        above_x_min = tf.cast(x_coord > self.x_min, tf.float32)
        within_x_max = tf.cast(x_coord <= self.x_max, tf.float32)
        above_y_min = tf.cast(y_coord > self.y_min, tf.float32)
        within_y_max = tf.cast(y_coord <= self.y_max, tf.float32)
        expected = ((3 * x_coord + y_coord) * above_x_min * within_x_max *
                    above_y_min * within_y_max)

        vol = self.vol
        if channel > 0:
            vol, expected = [
                tf.repeat(tensor[..., None], channel, axis=-1)
                for tensor in (vol, expected)
            ]

        got = layer_util.resample(vol=vol, loc=self.loc, zero_boundary=True)
        assert is_equal_tf(expected, got)
Exemplo n.º 10
0
 def call(self, inputs, **kwargs):
     """
     Warp an image onto a fixed-size grid using a ddf.

     Same functionality as ``transform`` in neuron
     (https://github.com/adalca/neuron/blob/master/neuron/utils.py), where
     neuron's ``vol`` is the image and ``loc_shift`` is the ddf.

     :param inputs: [ddf, image]
                     ddf.shape = [batch, f_dim1, f_dim2, f_dim3, 3]
                     image.shape = [batch, m_dim1, m_dim2, m_dim3]
     :param kwargs: not used.
     :return: shape = [batch, f_dim1, f_dim2, f_dim3]
     """
     ddf, image = inputs[0], inputs[1]
     # reference grid shifted by the ddf: [batch, f_dim1, f_dim2, f_dim3, 3]
     sample_locations = self.grid_ref + ddf
     return layer_util.resample(vol=image, loc=sample_locations)
Exemplo n.º 11
0
def train_step(grid, weights, optimizer, mov, fix):
    """
    Run one optimisation step on the affine parameters.

    The moving image is resampled on the grid warped by ``weights``; the
    dissimilarity to the fixed image is differentiated with respect to
    ``weights`` and the optimizer applies that gradient.

    NOTE(review): relies on the module-level name ``image_loss_name``.

    :param grid: reference grid returned from layer_util.get_reference_grid
    :param weights: trainable affine parameters [1, 4, 3]
    :param optimizer: tf.optimizers
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return: loss, the image dissimilarity to minimise
    """
    trainable = [weights]
    with tf.GradientTape() as tape:
        warped_grid = layer_util.warp_grid(grid, weights)
        pred = layer_util.resample(vol=mov, loc=warped_grid)
        loss = image_loss.dissimilarity_fn(y_true=fix,
                                           y_pred=pred,
                                           name=image_loss_name)
    grads = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))
    return loss
Exemplo n.º 12
0
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Function to predict results from a dataset from some model.

    Any existing ``save_dir`` is removed first. Then, for every sample of
    every batch, the input images/labels, the predictions and any ddf / dvf /
    affine outputs are saved, and per-sample metrics are collected and
    finally written with save_metric_dict.

    :param dataset: where data is stored; each element is a dict of inputs
        passed to ``model.predict``
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: str, ddf / dvf / affine / conditional
    :param save_dir: str, path to store dir
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :raises ValueError: if the same sample indices are encountered twice,
        e.g. because the dataset has been repeated
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)

    # if model is conditional, the predictions depend on the input label, so
    # pred_fixed_image, ddf / dvf and affine are stored under label_dir
    conditional = model_method == "conditional"
    image_label_names = [
        "moving_image",
        "fixed_image",
        "pred_fixed_image",  # or warped moving image
        "moving_label",
        "fixed_label",
        "pred_fixed_label",  # or warped moving label
    ]
    field_names = ["ddf", "dvf"]

    seen_sample_index_strs = set()  # set gives O(1) duplicate detection
    metric_lists = []
    for inputs_dict in dataset:
        outputs_dict = model.predict(x=inputs_dict)

        # moving image/label
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = inputs_dict["moving_image"]
        moving_label = inputs_dict.get("moving_label", None)
        # fixed image/label
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = inputs_dict["fixed_image"]
        fixed_label = inputs_dict.get("fixed_label", None)

        # indices to identify the pair
        # (batch, num_indices) last index is for label, -1 means unlabeled data
        indices = inputs_dict.get("indices")
        # ddf / dvf
        # (batch, f_dim1, f_dim2, f_dim3, 3)
        ddf = outputs_dict.get("ddf", None)
        dvf = outputs_dict.get("dvf", None)
        affine = outputs_dict.get("affine", None)  # (batch, 4, 3)

        # prediction
        # (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_label = outputs_dict.get("pred_fixed_label", None)
        pred_fixed_image = (
            layer_util.resample(vol=moving_image, loc=fixed_grid_ref + ddf)
            if ddf is not None
            else None
        )

        image_label_arrs = [
            moving_image,
            fixed_image,
            pred_fixed_image,
            moving_label,
            fixed_label,
            pred_fixed_label,
        ]
        field_arrs = [ddf, dvf]

        # save images of inputs and outputs
        for sample_index in range(moving_image.shape[0]):
            # save moving/fixed image under pair_dir
            # save moving/fixed label, pred fixed image/label, ddf/dvf under label dir
            # if labeled, label dir is a sub dir of pair_dir, otherwise = pair_dir

            # init output path
            indices_i = indices[sample_index, :].numpy().astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(indices=indices_i,
                                                         save_dir=save_dir)
            output_dir = label_dir if conditional else pair_dir

            # save image/label
            arr_save_dirs = [
                pair_dir,
                pair_dir,
                output_dir,
                label_dir,
                label_dir,
                label_dir,
            ]
            for arr_save_dir, arr, name in zip(arr_save_dirs, image_label_arrs,
                                               image_label_names):
                if arr is not None:
                    # for files under pair_dir, do not overwrite
                    save_array(
                        save_dir=arr_save_dir,
                        arr=arr[sample_index, :, :, :],
                        name=name,
                        gray=True,
                        save_nifti=save_nifti,
                        save_png=save_png,
                        overwrite=arr_save_dir == label_dir,
                    )

            # save ddf / dvf
            for arr, name in zip(field_arrs, field_names):
                if arr is not None:
                    arr = normalize_array(arr=arr[sample_index, :, :, :])
                    save_array(
                        save_dir=output_dir,
                        arr=arr,
                        name=name,
                        gray=False,
                        save_nifti=save_nifti,
                        save_png=save_png,
                    )

            # save affine
            if affine is not None:
                np.savetxt(
                    fname=os.path.join(output_dir, "affine.txt"),
                    # numpy names this parameter ``X`` (``x=`` is a TypeError);
                    # np.asarray accepts both the NumPy arrays returned by
                    # Model.predict and eager tensors
                    X=np.asarray(affine[sample_index, :, :]),
                    delimiter=",",
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in seen_sample_index_strs:
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated.")
            seen_sample_index_strs.add(sample_index_str)

            metric = calculate_metrics(
                fixed_image=fixed_image,
                fixed_label=fixed_label,
                pred_fixed_image=pred_fixed_image,
                pred_fixed_label=pred_fixed_label,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)
Exemplo n.º 13
0
    raise ("Download the data using demo_data.py script")
if not os.path.exists(FILE_PATH):
    # Raising a bare string is itself a TypeError in Python 3, which would
    # hide this message; raise a real exception carrying it instead.
    raise FileNotFoundError("Download the data using demo_data.py script")

# NOTE(review): fid is not closed in the visible code; confirm whether it is
# used (and closed) further down the script.
fid = h5py.File(FILE_PATH, "r")
fixed_image = tf.cast(tf.expand_dims(fid["image"], axis=0), dtype=tf.float32)
fixed_image = (fixed_image - tf.reduce_min(fixed_image)) / (
    tf.reduce_max(fixed_image) - tf.reduce_min(fixed_image)
)  # normalisation to [0,1]

# generate a randomly-affine-transformed moving image
fixed_image_size = fixed_image.shape
transform_random = layer_util.random_transform_generator(batch_size=1, scale=0.2)
grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size[1:4])
grid_random = layer_util.warp_grid(grid_ref, transform_random)
moving_image = layer_util.resample(vol=fixed_image, loc=grid_random)
# warp the labels to get ground-truth using the same random affine, for validation
fixed_labels = tf.cast(tf.expand_dims(fid["label"], axis=0), dtype=tf.float32)
moving_labels = tf.stack(
    [
        layer_util.resample(vol=fixed_labels[..., idx], loc=grid_random)
        for idx in range(fixed_labels.shape[4])
    ],
    axis=4,
)


## optimisation
@tf.function
def train_step(grid, weights, optimizer, mov, fix):
    """
Exemplo n.º 14
0
    def test_resample(self):
        # linear, vol has no feature channel
        interpolation = "linear"
        vol = tf.constant(
            np.array([[
                [0, 1, 2],
                [3, 4, 5],
            ]], dtype=np.float32))  # shape = [1,2,3]
        loc = tf.constant(
            np.array(
                [[
                    [[0, 0], [0, 1], [0, 3]],  # outside frame
                    [[0.4, 0], [0.5, 1], [0.6, 2]],
                    [[0.4, 0.7], [0.5, 0.5], [0.6, 0.3]],  # resampled = 3x+y
                ]],
                dtype=np.float32))  # shape = [1,3,3,2]
        want = tf.constant(
            np.array([[
                [0, 1, 2],
                [1.2, 2.5, 3.8],
                [1.9, 2, 2.1],
            ]],
                     dtype=np.float32))  # shape = [1,3,3]
        get = layer_util.resample(vol=vol,
                                  loc=loc,
                                  interpolation=interpolation)
        self.check_equal(want, get)

        # linear, vol has feature channel
        interpolation = "linear"
        vol = tf.constant(
            np.array([[
                [
                    [0, 0],
                    [1, 1],
                    [2, 2],
                ],
                [
                    [3, 3],
                    [4, 4],
                    [5, 5],
                ],
            ]],
                     dtype=np.float32))  # shape = [1,2,3,2]
        loc = tf.constant(
            np.array(
                [[
                    [[0, 0], [0, 1], [0, 3]],  # outside frame
                    [[0.4, 0], [0.5, 1], [0.6, 2]],
                    [[0.4, 0.7], [0.5, 0.5], [0.6, 0.3]],  # resampled = 3x+y
                ]],
                dtype=np.float32))  # shape = [1,3,3,2]
        want = tf.constant(
            np.array([[
                [[0, 0], [1, 1], [2, 2]],
                [[1.2, 1.2], [2.5, 2.5], [3.8, 3.8]],
                [[1.9, 1.9], [2, 2], [2.1, 2.1]],
            ]],
                     dtype=np.float32))  # shape = [1,3,3,2]
        get = layer_util.resample(vol=vol,
                                  loc=loc,
                                  interpolation=interpolation)
        self.check_equal(want, get)