コード例 #1
0
def test_compute_centroid_d():
    """
    Check that the centroid distance between a mask batch and itself is 0.

    The batch holds three 2x2x2 masks: the first is all ones and the other
    two are all zeros. Comparing the batch with itself must give a distance
    of 0 for every one of the three samples.
    """
    mask_np = np.zeros((3, 2, 2, 2))
    mask_np[0] = 1.0  # only the first sample is foreground
    mask = tf.convert_to_tensor(mask_np, dtype=tf.float32)

    # reference grid: first coordinate is 1 everywhere, the other two are 0
    grid_np = np.zeros((2, 2, 2, 3))
    grid_np[..., 0] = 1.0
    grid = tf.convert_to_tensor(grid_np, dtype=tf.float32)

    distances = label.compute_centroid_distance(mask, mask, grid)
    assert is_equal_tf(distances, np.zeros(3))
コード例 #2
0
def _select_sample(tensor, sample_index: int):
    """
    Take one sample from a batched tensor, keeping a batch axis of size 1.

    The sample is indexed and the batch axis is re-added, rather than
    sliced with ``tensor[i:i + 1]``. Slicing breaks for negative indices:
    ``tensor[-1:0]`` is empty, whereas ``tensor[-1]`` is the last sample.
    An out-of-range index raises an error instead of silently producing
    an empty slice.

    :param tensor: batched tensor, shape=(batch, ...)
    :param sample_index: index along the batch axis (negative counts from the end)
    :return: tensor of shape (1, ...)
    """
    return tf.expand_dims(tensor[sample_index], axis=0)


def calculate_metrics(
    fixed_image: tf.Tensor,
    fixed_label: (tf.Tensor, None),
    pred_fixed_image: (tf.Tensor, None),
    pred_fixed_label: (tf.Tensor, None),
    fixed_grid_ref: tf.Tensor,
    sample_index: int,
) -> dict:
    """
    Calculate image/label based metrics for one sample of a batch.

    :param fixed_image: shape=(batch, f_dim1, f_dim2, f_dim3)
    :param fixed_label: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param pred_fixed_image: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param pred_fixed_label: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param sample_index: index of the sample within the batch; negative
        values count from the end, as in Python indexing
    :return: dictionary with keys
        - ``image_ssd``: SumSquaredDifference between fixed and predicted
          image, or None when ``pred_fixed_image`` is None
        - ``label_binary_dice``: binary DiceScore between fixed and
          predicted label, or None when either label is None
        - ``label_tre``: centroid distance between fixed and predicted
          label, or None when either label is None
    """

    if pred_fixed_image is not None:
        # Append a trailing channel axis -> (1, f_dim1, f_dim2, f_dim3, 1),
        # matching how the image loss is called here.
        y_true = tf.expand_dims(_select_sample(fixed_image, sample_index), axis=4)
        y_pred = tf.expand_dims(
            _select_sample(pred_fixed_image, sample_index), axis=4
        )
        ssd = image_loss.SumSquaredDifference()(
            y_true=y_true, y_pred=y_pred
        ).numpy()
    else:
        ssd = None

    if fixed_label is not None and pred_fixed_label is not None:
        y_true = _select_sample(fixed_label, sample_index)
        y_pred = _select_sample(pred_fixed_label, sample_index)
        dice = label_loss.DiceScore(binary=True)(
            y_true=y_true, y_pred=y_pred
        ).numpy()
        # Drop the leading size-1 axis of the grid reference. The result
        # presumably holds one distance per batch element (batch is 1 here),
        # so [0] extracts the scalar for this sample.
        tre = label_loss.compute_centroid_distance(
            y_true=y_true, y_pred=y_pred, grid=fixed_grid_ref[0]
        ).numpy()[0]
    else:
        dice = None
        tre = None

    return dict(image_ssd=ssd, label_binary_dice=dice, label_tre=tre)
コード例 #3
0
ファイル: network.py プロジェクト: mianasbat/DeepReg
    def build_loss(self):
        """
        Register the model's losses and metrics according to the configs.

        Logs statistics of the moving/fixed images. Adds an image loss when
        the model outputs ``pred_fixed_image``. For labeled data, it also logs
        statistics of the moving/fixed labels, adds a label loss, and adds a
        mean-aggregated "metric/TRE" metric from ``compute_centroid_distance``.
        """
        fixed_image = self._inputs["fixed_image"]
        moving_image = self._inputs["moving_image"]
        for stat_name, stat_tensor in (
            ("moving_image", moving_image),
            ("fixed_image", fixed_image),
        ):
            self.log_tensor_stats(tensor=stat_tensor, name=stat_name)

        # Only models that predict an image get an image loss;
        # the conditional model has no "pred_fixed_image" output.
        if "pred_fixed_image" in self._outputs:
            warped_image = self._outputs["pred_fixed_image"]
            self._build_loss(
                name="image",
                inputs_dict=dict(y_true=fixed_image, y_pred=warped_image),
            )

        # Everything below needs labels.
        if not self.labeled:
            return

        fixed_label = self._inputs["fixed_label"]
        moving_label = self._inputs["moving_label"]
        for stat_name, stat_tensor in (
            ("moving_label", moving_label),
            ("fixed_label", fixed_label),
        ):
            self.log_tensor_stats(tensor=stat_tensor, name=stat_name)

        warped_label = self._outputs["pred_fixed_label"]
        self._build_loss(
            name="label",
            inputs_dict=dict(y_true=fixed_label, y_pred=warped_label),
        )

        # Extra label metric: centroid distance, reported as the mean TRE.
        centroid_dist = compute_centroid_distance(
            y_true=fixed_label, y_pred=warped_label, grid=self.grid_ref
        )
        self._model.add_metric(
            centroid_dist, name="metric/TRE", aggregation="mean"
        )