Example #1
    def test_switch(self, interpolation_input):
        """The inverse of almost identity."""

        eps = 1e-2  # we are predicting integers so the eps can be relatively large
        interpolation_method, interpolator_kwargs = interpolation_input
        shape = (30, 20)

        delta_x = np.zeros(shape)
        delta_y = np.zeros(shape)

        delta_x[5, 9] = 1
        delta_y[5, 9] = 1
        delta_x[6, 10] = -1
        delta_y[6, 10] = -1

        df = DisplacementField(delta_x, delta_y)

        assert TestInvert.eps_equal(
            df,
            df.pseudo_inverse(
                interpolation_method=interpolation_method,
                interpolator_kwargs=interpolator_kwargs,
            ),
            eps=eps,
        )
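
The assertion above relies on TestInvert.eps_equal, which is not part of this example. A minimal sketch of such a tolerance helper, assuming DisplacementField exposes delta_x and delta_y arrays as in the other examples (the helper body below is an assumption, not the library's actual implementation):

import numpy as np


class TestInvert:
    @staticmethod
    def eps_equal(df_1, df_2, eps=1e-6):
        """Hypothetical helper: check that two displacement fields agree within eps."""
        return np.allclose(df_1.delta_x, df_2.delta_x, atol=eps) and np.allclose(
            df_1.delta_y, df_2.delta_y, atol=eps
        )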
Example #2
    def test_invert_compose(self, random_state):
        """Create a bijective mapping other than identity and make sure that compose and invert work.

        Notes
        -----
        To each grid element assign another grid element. Both pseudo_inverse and __call__ should
        work perfectly on grid elements since no interpolation is needed.

        """

        shape = (_, w) = (40, 30)

        n_pixels = np.prod(shape)
        np.random.seed(random_state)
        perm = np.random.permutation(n_pixels)

        delta_x = np.zeros(shape)
        delta_y = np.zeros(shape)

        for i, x in enumerate(perm):
            r_inp, c_inp = i // w, i % w
            r_out, c_out = x // w, x % w

            delta_x[r_inp, c_inp] = c_out - c_inp
            delta_y[r_inp, c_inp] = r_out - r_inp

        df = DisplacementField(delta_x, delta_y)

        df_id = DisplacementField(np.zeros(shape), np.zeros(shape))

        assert df(df.pseudo_inverse()) == df_id
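
The per-pixel loop above can equivalently be written with vectorized NumPy indexing; a minimal sketch of the same delta construction, assuming the identical row-major pixel ordering:

import numpy as np

shape = (h, w) = (40, 30)
n_pixels = h * w
perm = np.random.permutation(n_pixels)

# Row/column of every input pixel and of the pixel it is mapped to.
rows_inp, cols_inp = np.divmod(np.arange(n_pixels), w)
rows_out, cols_out = np.divmod(perm, w)

delta_x = (cols_out - cols_inp).reshape(shape).astype(float)
delta_y = (rows_out - rows_inp).reshape(shape).astype(float)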
Example #3
    def test_cached_transformation(self, df_cached, interpolation_input):
        """Test inversion approaches on a relatively mild transform that is anchored in the corners."""

        eps = 0.1  # missing quarter of a pixel is not a big deal

        delta_x, delta_y, delta_x_inv, delta_y_inv = df_cached

        interpolation_method, interpolator_kwargs = interpolation_input

        df = DisplacementField(delta_x, delta_y)
        df_inv_true = DisplacementField(delta_x_inv, delta_y_inv)

        df_inv_numerical = df.pseudo_inverse(
            interpolation_method=interpolation_method,
            interpolator_kwargs=interpolator_kwargs,
        )

        assert TestInvert.eps_equal(df_inv_true, df_inv_numerical, eps=eps)
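
The df_cached fixture is not shown in this example; the test only needs it to yield matching forward and inverse deltas of some mild transform. A hypothetical pytest fixture of that shape, assuming the precomputed arrays are stored in an .npz file (the path and keys below are placeholders):

import numpy as np
import pytest


@pytest.fixture()
def df_cached():
    """Hypothetical fixture yielding (delta_x, delta_y, delta_x_inv, delta_y_inv)."""
    cached = np.load("tests/data/df_cached.npz")  # placeholder path
    return (
        cached["delta_x"],
        cached["delta_y"],
        cached["delta_x_inv"],
        cached["delta_y_inv"],
    )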
Example #4
    def test_identity(self, interpolation_input):
        """The inverse of identity is identity."""

        eps = 1e-6
        interpolation_method, interpolator_kwargs = interpolation_input
        shape = (30, 20)

        delta_x = np.zeros(shape)
        delta_y = np.zeros(shape)

        df = DisplacementField(delta_x, delta_y)

        assert TestInvert.eps_equal(
            df,
            df.pseudo_inverse(
                interpolation_method=interpolation_method,
                interpolator_kwargs=interpolator_kwargs,
            ),
            eps=eps,
        )
Example #5
    def test_off_grid(self):
        """Map one pixel into an offgrid elemenent."""

        shape = (30, 20)

        delta_x = np.zeros(shape)
        delta_y = np.zeros(shape)

        delta_x[5, 9] = 0.7
        delta_y[5, 9] = 0.7

        ix = np.ones(shape, dtype=bool)
        ix[5, 9] = False

        df = DisplacementField(delta_x, delta_y)
        df_inv = df.pseudo_inverse()

        assert np.allclose(df.delta_x[ix], df_inv.delta_x[ix]) and np.allclose(
            df.delta_y[ix], df_inv.delta_y[ix])

        # Interpolation will never give the precise value
        assert not np.isclose(df.delta_x[5, 9],
                              df_inv.delta_x[5, 9]) and not np.isclose(
                                  df.delta_y[5, 9], df_inv.delta_y[5, 9])
Example #6
def chain_predict(model, inp, n_iterations=1):
    """Run alignment recursively.

    Parameters
    ----------
    model : keras.models.Model
        A trained model whose inputs have shape (batch_size, h, w, 2) - the last dimension represents
        the stacking of the atlas and the input image. The outputs have the same shape, where the last
        dimension represents the stacking of delta_x and delta_y of the displacement field.

    inp : np.ndarray
        An array of shape (h, w, 2) or (1, h, w, 2) representing the atlas and input image.

    n_iterations : int
        Number of alignment iterations to run.

    Returns
    -------
    unwarped_img_list : list
        List of np.ndarrays of shape (h, w) representing the unwarped image at each iteration.

    """
    # Checks
    if inp.ndim == 3:
        inp_ = np.array([inp])

    elif inp.ndim == 4 and inp.shape[0] == 1:
        inp_ = inp

    else:
        raise ValueError("Input has incorrect shape of {}".format(inp.shape))

    shape = inp_.shape[1:3]

    df = DisplacementField.generate(shape, approach="identity")

    img_atlas = inp_[0, :, :, 0]
    img_warped = inp_[0, :, :, 1]

    unwarped_img_list = [img_warped]

    for i in range(n_iterations):
        new_inputs = np.concatenate(
            (
                img_atlas[np.newaxis, :, :, np.newaxis],
                unwarped_img_list[-1][np.newaxis, :, :, np.newaxis],
            ),
            axis=3,
        )

        pred = model.predict(new_inputs)

        delta_x_pred = pred[0, ..., 0]
        delta_y_pred = pred[0, ..., 1]

        df_pred = DisplacementField(delta_x_pred, delta_y_pred)

        df_pred_inv = df_pred.pseudo_inverse(ds_f=8)

        df = df_pred_inv(df).adjust()
        img_unwarped_pred = df.warp(img_warped)

        unwarped_img_list.append(img_unwarped_pred)

    return unwarped_img_list
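
A possible usage sketch for chain_predict, assuming a trained Keras model with the input/output layout described in the docstring; the file paths below are placeholders:

import numpy as np
from keras.models import load_model

model = load_model("registration_model.h5")  # placeholder path to a trained model
img_atlas = np.load("atlas_slice.npy")       # placeholder, shape (h, w)
img_mov = np.load("moving_slice.npy")        # placeholder, shape (h, w)

# Stack atlas and moving image along the last axis to get shape (h, w, 2).
inp = np.stack([img_atlas, img_mov], axis=-1)

# Each iteration predicts a displacement field, inverts it and re-warps the
# moving image, so later entries should be progressively better aligned.
unwarped = chain_predict(model, inp, n_iterations=3)
img_final = unwarped[-1]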
Example #7
    def augment(
        self,
        output_path,
        n_iter=10,
        anchor=True,
        p_reg=0.5,
        random_state=None,
        max_corrupted_pixels=500,
        ds_f=8,
        max_trials=5,
    ):
        """Augment the original dataset and create a new one.

        Note that this does not modify the original dataset.

        Parameters
        ----------
        output_path : str
            Path where the new h5 file will be stored.

        n_iter : int
            Number of augmented samples per each sample in the original dataset.

        anchor : bool
            If True, then the dvf is anchored before being inverted.

        p_reg : float
            Probability that we start from a registered image
            (rather than the moving).

        random_state : int or None
            Random seed passed to np.random.seed.

        max_corrupted_pixels : int
            Maximum number of corrupted pixels allowed for a dvf - the actual
            number is computed as np.sum(df.jacobian < 0).

        ds_f : int
            Downsampling factor for inverses. 1 creates the least artifacts.

        max_trials : int
            Maximum number of augmentation attempts before the original sample
            is replicated unchanged.
        """
        np.random.seed(random_state)

        n_new = n_iter * self.n_orig
        print(n_new)

        with h5py.File(self.original_path, "r") as f_orig:
            # extract
            dset_img_orig = f_orig["img"]
            dset_image_id_orig = f_orig["image_id"]
            dset_dataset_id_orig = f_orig["dataset_id"]
            dset_deltas_xy_orig = f_orig["deltas_xy"]
            dset_inv_deltas_xy_orig = f_orig["inv_deltas_xy"]
            dset_p_orig = f_orig["p"]

            with h5py.File(output_path, "w") as f_aug:
                dset_img_aug = f_aug.create_dataset(
                    "img", (n_new, 320, 456), dtype="uint8"
                )
                dset_image_id_aug = f_aug.create_dataset(
                    "image_id", (n_new,), dtype="int"
                )
                dset_dataset_id_aug = f_aug.create_dataset(
                    "dataset_id", (n_new,), dtype="int"
                )
                dset_p_aug = f_aug.create_dataset("p", (n_new,), dtype="int")
                dset_deltas_xy_aug = f_aug.create_dataset(
                    "deltas_xy", (n_new, 320, 456, 2), dtype=np.float16
                )
                dset_inv_deltas_xy_aug = f_aug.create_dataset(
                    "inv_deltas_xy", (n_new, 320, 456, 2), dtype=np.float16
                )

                for i in range(n_new):
                    print(i)
                    i_orig = i % self.n_orig

                    mov2reg = DisplacementField(
                        dset_deltas_xy_orig[i_orig, ..., 0],
                        dset_deltas_xy_orig[i_orig, ..., 1],
                    )

                    # copy
                    dset_image_id_aug[i] = dset_image_id_orig[i_orig]
                    dset_dataset_id_aug[i] = dset_dataset_id_orig[i_orig]
                    dset_p_aug[i] = dset_p_orig[i_orig]

                    use_reg = np.random.random() < p_reg
                    print("Using registered: {}".format(use_reg))

                    if not use_reg:
                        # mov != reg
                        img_mov = dset_img_orig[i_orig]
                    else:
                        # mov=reg
                        img_mov = mov2reg.warp(dset_img_orig[i_orig])
                        mov2reg = DisplacementField.generate(
                            (320, 456), approach="identity"
                        )

                    is_nice = False
                    n_trials = 0

                    while not is_nice:
                        n_trials += 1

                        if n_trials == max_trials:
                            print("Replicating original: out of trials")
                            dset_img_aug[i] = dset_img_orig[i_orig]
                            dset_deltas_xy_aug[i] = dset_deltas_xy_orig[i_orig]
                            dset_inv_deltas_xy_aug[i] = dset_inv_deltas_xy_orig[i_orig]
                            break

                        else:
                            mov2art = self.generate_mov2art(img_mov)

                        reg2mov = mov2reg.pseudo_inverse(ds_f=ds_f)
                        reg2art = reg2mov(mov2art)

                        # anchor
                        if anchor:
                            print("ANCHORING")
                            reg2art = reg2art.anchor(
                                ds_f=50, smooth=0, h_kept=0.9, w_kept=0.9
                            )

                        art2reg = reg2art.pseudo_inverse(ds_f=ds_f)

                        validity_check = np.all(
                            np.isfinite(reg2art.delta_x)
                        ) and np.all(np.isfinite(reg2art.delta_y))
                        validity_check &= np.all(
                            np.isfinite(art2reg.delta_x)
                        ) and np.all(np.isfinite(art2reg.delta_y))
                        jacobian_check = (
                            np.sum(reg2art.jacobian < 0) < max_corrupted_pixels
                        )
                        jacobian_check &= (
                            np.sum(art2reg.jacobian < 0) < max_corrupted_pixels
                        )

                        if validity_check and jacobian_check:
                            is_nice = True
                            print("Check passed")
                        else:
                            print("Check failed")

                    if n_trials != max_trials:
                        dset_img_aug[i] = mov2art.warp(img_mov)
                        dset_deltas_xy_aug[i] = np.stack(
                            [art2reg.delta_x, art2reg.delta_y], axis=-1
                        )
                        dset_inv_deltas_xy_aug[i] = np.stack(
                            [reg2art.delta_x, reg2art.delta_y], axis=-1
                        )
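
The augmented file written by augment can be read back with plain h5py; a minimal sketch using the dataset names created above, assuming DisplacementField is importable as in the other examples (the output path is a placeholder):

import h5py
import numpy as np

with h5py.File("augmented.h5", "r") as f_aug:  # placeholder path, i.e. `output_path`
    img = f_aug["img"][0]                      # (320, 456) uint8 artificially warped image
    deltas_xy = f_aug["deltas_xy"][0]          # (320, 456, 2) art -> reg deltas
    inv_deltas_xy = f_aug["inv_deltas_xy"][0]  # (320, 456, 2) reg -> art deltas

    df = DisplacementField(
        deltas_xy[..., 0].astype(np.float32),
        deltas_xy[..., 1].astype(np.float32),
    )
    img_reg = df.warp(img)  # warp the augmented image back towards the registered one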
Example #8
def evaluate_single(
        deltas_true,
        deltas_pred,
        img_mov,
        p=None,
        avol=None,
        collapsing_labels=None,
        deltas_pred_inv=None,
        deltas_true_inv=None,
        ds_f=4,
        depths=(),
):
    """Evaluate a single sample.

    Parameters
    ----------
    deltas_true : DisplacementField or np.ndarray
        If np.ndarray then of shape (height, width, 2) representing deltas_xy of ground truth.

    deltas_pred : DisplacementField or np.ndarray
        If np.ndarray then of shape (height, width, 2) representing deltas_xy of prediction.

    img_mov : np.ndarray
        Moving image.

    p : int
        Coronal section in microns.

    avol : np.ndarray or None
        Annotation volume of shape (528, 320, 456). If None then loaded via `annotation_volume`.

    collapsing_labels : dict or None
        Dictionary for segmentation collapsing. If None then loaded via `segmentation_collapsing_labels`.

    deltas_pred_inv : None or np.ndarray
        If np.ndarray then of shape (height, width, 2) representing inv_deltas_xy of the prediction. If not
        provided, it is computed from `df_pred`.

    deltas_true_inv : None or np.ndarray
        If np.ndarray then of shape (height, width, 2) representing inv_deltas_xy of the ground truth. If not
        provided, it is computed from `df_true`.

    ds_f : int
        Downsampling factor for numerical inverses.

    depths : tuple
        Tuple of integers representing all depths to compute IoU for. If empty, no IoU computation takes place.

    Returns
    -------
    all_metrics : dict
        Dictionary of the relevant metrics. If `depths` is non-empty, a tuple
        (all_metrics, images) is returned, where `images` maps each depth to the
        warped (true, predicted) segmentation images.
    """
    n_pixels = 320 * 456
    if not (deltas_true.shape == (320, 456, 2)
            and deltas_pred.shape == (320, 456, 2)):
        raise ValueError("Incorrect shape of input")

    df_true = DisplacementField(deltas_true[..., 0], deltas_true[..., 1])
    df_pred = DisplacementField(deltas_pred[..., 0], deltas_pred[..., 1])

    img_reg_true = df_true.warp(img_mov)
    img_reg_pred = df_pred.warp(img_mov)

    all_metrics = {
        "mse_img": mse_img(img_reg_true, img_reg_pred),
        "mae_img": mae_img(img_reg_true, img_reg_pred),
        "psnr_img": psnr_img(img_reg_true, img_reg_pred),
        "ssmi_img": ssmi_img(img_reg_true, img_reg_pred),
        "mi_img": mi_img(img_reg_true, img_reg_pred),
        "cc_img": cross_correlation_img(img_reg_true, img_reg_pred),
        # 'perceptual_img': perceptual_loss_img(img_reg_true, img_reg_pred),
        "norm": df_pred.norm.mean(),
        "corrupted_pixels": np.sum(df_pred.jacobian < 0) / n_pixels,
        "euclidean_distance": vector_distance_combined([df_true],
                                                       [df_pred])[0],
        "angular_error": angular_error_of([df_true], [df_pred],
                                          weighted=True)[0],
    }

    # segmentations metrics
    if not depths:
        return all_metrics
    else:
        # checks
        if avol is None or avol.shape != (528, 320, 456):
            raise ValueError(
                "Incorrectly shaped annotation volume or not provided")

        # Prepare inverses
        if deltas_true_inv is not None:
            df_true_inv = DisplacementField(deltas_true_inv[..., 0],
                                            deltas_true_inv[..., 1])
        else:
            df_true_inv = df_true.pseudo_inverse(ds_f=ds_f)

        if deltas_pred_inv is not None:
            df_pred_inv = DisplacementField(deltas_pred_inv[..., 0],
                                            deltas_pred_inv[..., 1])
        else:
            df_pred_inv = df_pred.pseudo_inverse(ds_f=ds_f)

        # Extract data
        avol_ = annotation_volume() if avol is None else avol
        collapsing_labels_ = (segmentation_collapsing_labels() if
                              collapsing_labels is None else collapsing_labels)

        # Compute
        images = {}
        for depth in depths:
            segm_ref = find_labels_dic(avol_[p // 25], collapsing_labels_,
                                       depth)
            segm_true = df_true_inv.warp_annotation(segm_ref)
            segm_pred = df_pred_inv.warp_annotation(segm_ref)

            images[depth] = (segm_true, segm_pred)

            all_metrics["iou_{}".format(depth)] = iou_score(
                np.array([segm_true]),
                np.array([segm_pred]),
                k=None,
                excluded_labels=[0],
            )[0]

            all_metrics["dice_{}".format(depth)] = dice_score(
                np.array([segm_true]),
                np.array([segm_pred]),
                k=None,
                excluded_labels=[0],
            )[0]

        return all_metrics, images
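
A usage sketch for evaluate_single with arrays of the documented shapes; the inputs here are random placeholders and depths is left empty, so only the image and displacement metrics are computed:

import numpy as np

deltas_true = np.random.randn(320, 456, 2).astype(np.float32)  # placeholder ground truth
deltas_pred = np.random.randn(320, 456, 2).astype(np.float32)  # placeholder prediction
img_mov = np.random.rand(320, 456).astype(np.float32)          # placeholder moving image

# With depths=() no annotation volume is needed and a dict of metrics is returned.
all_metrics = evaluate_single(deltas_true, deltas_pred, img_mov, depths=())
print(all_metrics["mse_img"], all_metrics["euclidean_distance"])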