def test_compute_centroid_d():
    """
    Check that the centroid distance between a mask and itself is zero
    for every sample in the batch.
    """
    # Batch of three 2x2x2 masks; only the first sample is non-empty
    # (the 2x2 block of ones broadcasts over the sample's first axis).
    mask = np.zeros((3, 2, 2, 2))
    mask[0] = np.ones((2, 2))
    mask = tf.convert_to_tensor(mask, dtype=tf.float32)

    # 2x2x2 grid of 3-vectors whose first component is one everywhere.
    grid = np.zeros((2, 2, 2, 3))
    grid[..., 0] = np.ones((2, 2))
    grid = tf.convert_to_tensor(grid, dtype=tf.float32)

    distance = label.compute_centroid_distance(mask, mask, grid)
    assert is_equal_tf(distance, np.zeros(3))
def calculate_metrics(
    fixed_image: tf.Tensor,
    fixed_label: (tf.Tensor, None),
    pred_fixed_image: (tf.Tensor, None),
    pred_fixed_label: (tf.Tensor, None),
    fixed_grid_ref: tf.Tensor,
    sample_index: int,
) -> dict:
    """
    Compute image- and label-based metrics for one sample of a batch.

    :param fixed_image: shape=(batch, f_dim1, f_dim2, f_dim3)
    :param fixed_label: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param pred_fixed_image: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param pred_fixed_label: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param sample_index: index of the sample within the batch to evaluate
    :return: dict with keys image_ssd, label_binary_dice and label_tre;
        image_ssd is None when pred_fixed_image is None, and the two label
        metrics are None unless both fixed_label and pred_fixed_label are given
    """
    # NOTE(review): `(tf.Tensor, None)` is a tuple, not a type; consider
    # `Optional[tf.Tensor]` once `typing` is imported at file level.

    # Slice with a range rather than an integer so the batch axis is kept.
    batch_slice = slice(sample_index, sample_index + 1)

    ssd = None
    if pred_fixed_image is not None:
        # Add a trailing axis (rank 4 -> 5) before calling the image loss.
        true_img = tf.expand_dims(fixed_image[batch_slice, :, :, :], axis=4)
        pred_img = tf.expand_dims(pred_fixed_image[batch_slice, :, :, :], axis=4)
        ssd = image_loss.SumSquaredDifference()(
            y_true=true_img, y_pred=pred_img
        ).numpy()

    dice = None
    tre = None
    if fixed_label is not None and pred_fixed_label is not None:
        true_lbl = fixed_label[batch_slice, :, :, :]
        pred_lbl = pred_fixed_label[batch_slice, :, :, :]
        dice = label_loss.DiceScore(binary=True)(
            y_true=true_lbl, y_pred=pred_lbl
        ).numpy()
        # The grid reference carries a leading singleton axis; drop it, and
        # keep only the first (and only) entry of the per-sample result.
        tre = label_loss.compute_centroid_distance(
            y_true=true_lbl, y_pred=pred_lbl, grid=fixed_grid_ref[0, :, :, :, :]
        ).numpy()[0]

    return dict(image_ssd=ssd, label_binary_dice=dice, label_tre=tre)
def build_loss(self):
    """
    Register losses and monitoring metrics according to the configs.

    Always logs statistics for the moving and fixed images. Adds the image
    loss only when the outputs contain ``pred_fixed_image`` (the conditional
    model does not have one). When the model is labeled, it also does three
    things:

    - logs the label statistics;
    - adds the label loss;
    - registers the centroid distance as the mean-aggregated ``metric/TRE``
      metric.
    """
    inputs = self._inputs

    # Input statistics for the image pair.
    fixed_image = inputs["fixed_image"]
    moving_image = inputs["moving_image"]
    for tensor, stat_name in ((moving_image, "moving_image"), (fixed_image, "fixed_image")):
        self.log_tensor_stats(tensor=tensor, name=stat_name)

    # Image loss; the conditional model does not produce pred_fixed_image.
    if "pred_fixed_image" in self._outputs:
        self._build_loss(
            name="image",
            inputs_dict=dict(y_true=fixed_image, y_pred=self._outputs["pred_fixed_image"]),
        )

    if not self.labeled:
        return

    # Input statistics for the label pair.
    fixed_label = inputs["fixed_label"]
    moving_label = inputs["moving_label"]
    for tensor, stat_name in ((moving_label, "moving_label"), (fixed_label, "fixed_label")):
        self.log_tensor_stats(tensor=tensor, name=stat_name)

    # Label loss.
    pred_fixed_label = self._outputs["pred_fixed_label"]
    self._build_loss(
        name="label",
        inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
    )

    # Extra label metric: centroid distance between true and predicted labels.
    tre = compute_centroid_distance(
        y_true=fixed_label, y_pred=pred_fixed_label, grid=self.grid_ref
    )
    self._model.add_metric(tre, name="metric/TRE", aggregation="mean")