def test_alignments_to_paths(self):
        """Checks `alignment.alignments_to_paths` against a hand-built target.

        Covers the plain conversion, extra zero padding appended to the sparse
        alignments, padding requested through larger length arguments, and an
        all-zero (empty) ground-truth alignment.
        """
        expected = tf.convert_to_tensor(
            [[[1., 0., 0., 0., 0., 0., 0.], [0., 2., 5., 0., 0., 0., 0.],
              [0., 0., 0., 3., 0., 0., 0.], [0., 0., 0., 7., 0., 0., 0.],
              [0., 0., 0., 9., 0., 0., 0.], [0., 0., 0., 0., 4., 0., 0.],
              [0., 0., 0., 0., 0., 2., 0.], [0., 0., 0., 0., 0., 0., 2.]]],
            tf.float32)

        def squeezed(aligns, n_x, n_y):
            # Dense paths passed through `path_label_squeeze`.
            dense = alignment.alignments_to_paths(aligns, n_x, n_y)
            return alignment.path_label_squeeze(dense)

        n_x, n_y = self.len_x, self.len_y

        # Baseline conversion.
        self.assertAllEqual(squeezed(self.alignments, n_x, n_y), expected)

        # Appending zero entries along the last axis must change nothing.
        zero_pad = tf.zeros([1, 3, 3], tf.int32)
        padded = tf.concat([self.alignments, zero_pad], 2)
        self.assertAllEqual(squeezed(padded, n_x, n_y), expected)

        # Larger length arguments must yield zero-filled margins.
        extra = 3
        grown = squeezed(self.alignments, n_x + extra, n_y + extra)
        self.assertAllEqual(grown[:, :n_x, :n_y], expected)
        self.assertAllEqual(grown[:, n_x:, :],
                            tf.zeros([1, extra, n_y + extra], tf.float32))
        self.assertAllEqual(grown[..., n_y:],
                            tf.zeros([1, n_x + extra, extra], tf.float32))

        # An all-zero alignment must map to an all-zero path.
        empty = squeezed(tf.zeros_like(self.alignments), n_x, n_y)
        self.assertAllEqual(empty, tf.zeros_like(expected))
    def update_state(self,
                     alignments_true,
                     alignments_pred,
                     sample_weight=None):
        """Updates mean squared error for a batch of true vs pred alignments.

        Args:
          alignments_true: ground-truth alignments, converted to dense paths
            via `alignment.alignments_to_paths`.
          alignments_pred: indexable whose element 1 holds the predicted dense
            paths; when that element is None the batch is skipped entirely.
          sample_weight: optional weights forwarded to the parent
            `update_state`.
        """
        pred_paths = alignments_pred[1]
        if pred_paths is not None:
            pred_shape = tf.shape(pred_paths)
            # Build the true dense paths with the same len_x / len_y extents
            # as the predictions so the two tensors can be subtracted.
            true_paths = alignment.alignments_to_paths(
                alignments_true, pred_shape[1], pred_shape[2])
            # Per-example squared error summed over axes 1-3 (paths presumably
            # rank 4: batch, len_x, len_y, label -- TODO confirm).
            per_example = tf.reduce_sum((pred_paths - true_paths)**2,
                                        axis=[1, 2, 3])
            super().update_state(per_example, sample_weight)
    def test_sw_score(self):
        """Checks `alignment.sw_score` on sparse and dense inputs, incl. empty."""
        # Score from the sparse representation.
        self.assertAllEqual(
            alignment.sw_score(self.sw_params, self.alignments), [95.0])

        # Score from the dense representation; len_y + 1 also exercises
        # padding of the dense paths.
        dense = alignment.alignments_to_paths(
            self.alignments, self.len_x, self.len_y + 1)
        self.assertAllEqual(
            alignment.sw_score(self.sw_params, dense), [95.0])

        # All-zero alignments / paths must score zero in either form.
        for source in (self.alignments, dense):
            blank = tf.zeros_like(source)
            self.assertAllEqual(
                alignment.sw_score(self.sw_params, blank), [0.0])