def test_alignments_to_state_indices(self):
        exp_out = tf.convert_to_tensor(
            [[0, 0, 0], [0, 1, 1], [0, 2, 3], [0, 5, 4], [0, 6, 5], [0, 7, 6]],
            tf.int32)
        match_indices = alignment.alignments_to_state_indices(
            self.alignments, 'match')
        self.assertAllEqual(match_indices, exp_out)

        # Invariance to padding.
        padded_alignments = tf.concat(
            [self.alignments, tf.zeros([1, 3, 3], tf.int32)], 2)
        match_indices = alignment.alignments_to_state_indices(
            padded_alignments, 'match')
        self.assertAllEqual(match_indices, exp_out)

        # Deals with empty ground-truth alignments.
        exp_out = tf.zeros([0, 3], tf.int32)
        match_indices = alignment.alignments_to_state_indices(
            tf.zeros_like(self.alignments), 'match')
        self.assertAllEqual(match_indices, exp_out)

        # Gaps.
        exp_out = tf.convert_to_tensor([[0, 1, 2], [0, 3, 3]], tf.int32)
        gap_open_indices = alignment.alignments_to_state_indices(
            self.alignments, 'gap_open')
        self.assertAllEqual(gap_open_indices, exp_out)

        exp_out = tf.convert_to_tensor([[0, 4, 3]], tf.int32)
        gap_extend_indices = alignment.alignments_to_state_indices(
            self.alignments, 'gap_extend')
        self.assertAllEqual(gap_extend_indices, exp_out)
    def update_state(self,
                     alignments_true,
                     alignments_pred,
                     sample_weight=None):
        """Updates precision, recall for a batch of true, pred alignments."""
        if alignments_pred[1] is None:
            return

        _, match_indicators_pred, sw_params = alignments_pred
        sim_mat, _, _ = sw_params
        shape, dtype = sim_mat.shape, match_indicators_pred.dtype

        match_indices_true = alignment.alignments_to_state_indices(
            alignments_true, 'match')
        updates_true = tf.ones([tf.shape(match_indices_true)[0]], dtype=dtype)
        match_indicators_true = tf.scatter_nd(match_indices_true,
                                              updates_true,
                                              shape=shape)

        if sample_weight is None:  # Default to uniform sample weights.
            sample_weight = tf.ones([tf.shape(sim_mat)[0]], dtype=dtype)
        batch = tf.shape(sample_weight)[0]
        sample_weight = tf.reshape(sample_weight, [batch, 1, 1])
        mask = alignment.mask_from_similarities(sim_mat, dtype=dtype)

        self._precision.update_state(match_indicators_true,
                                     match_indicators_pred,
                                     sample_weight * mask)
        self._recall.update_state(match_indicators_true, match_indicators_pred,
                                  sample_weight * mask)
def _confusion_matrix(alignments_true, sol_paths_pred):
    """Computes true, predicted and actual positives for a batch of alignments."""
    batch_size = tf.shape(alignments_true)[0]

    # Computes the number of true positives per example as a (sparse) inner
    # product of two binary tensors of shape (batch_size, len_x, len_y) via
    # indexing. Entirely avoids materializing one of the two tensors explicitly.
    match_indices_true = alignment.alignments_to_state_indices(
        alignments_true, 'match')  # [n_aligned_chars_true, 3]
    match_indicators_pred = alignment.paths_to_state_indicators(
        sol_paths_pred, 'match')  # [batch, len_x, len_y]
    batch_indicators = match_indices_true[:, 0]  # [n_aligned_chars_true]
    matches_flat = tf.gather_nd(match_indicators_pred,
                                match_indices_true)  # [n_aligned_chars_true]
    true_positives = tf.math.unsorted_segment_sum(matches_flat,
                                                  batch_indicators,
                                                  batch_size)  # [batch]

    # Compute number of predicted and ground-truth positives per example.
    pred_positives = tf.reduce_sum(match_indicators_pred, axis=[1, 2])
    # Note(fllinares): tf.math.bincount unsupported in TPU :(
    cond_positives = tf.math.unsorted_segment_sum(
        tf.ones_like(batch_indicators, tf.float32), batch_indicators,
        batch_size)  # [batch]
    return true_positives, pred_positives, cond_positives
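
# Illustrative sketch, not part of the original snippet: the hypothetical
# helper below shows how the _confusion_matrix outputs map to per-example
# precision, recall and F1.
def _per_example_f1(alignments_true, sol_paths_pred):
    true_pos, pred_pos, cond_pos = _confusion_matrix(alignments_true,
                                                     sol_paths_pred)
    precision = tf.math.divide_no_nan(true_pos, pred_pos)
    recall = tf.math.divide_no_nan(true_pos, cond_pos)
    f1 = tf.math.divide_no_nan(2.0 * precision * recall, precision + recall)
    return precision, recall, f1
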
    def call(self, true_alignments, alignment_output):
        """Computes a brute-force BCE loss for pairwise sequence alignment.

    Args:
      true_alignments: The ground-truth alignments for the batch, given by a
        expected tf.Tensor<int>[batch, 3, align_len] = tf.stack([pos_x, pos_y,
        enc_trans], 1) such that (pos_x[b][i], pos_y[b][i], enc_trans[b][i])
        represents the i-th transition in the ground-truth alignment for example
        b in the minibatch. Both pos_x and pos_y are assumed to use one-based
        indexing and enc_trans follows the (categorical) 9-state encoding of
        edge types used throughout alignment.py.
      alignment_output: A NaiveAlignmentOutput, which is a 3-tuple made of:
        + The alignment scores: tf.Tensor<float>[batch].
        + The pairwise match probabilities: tf.Tensor<int>[batch, len, len].
        + A 3-tuple containing the Smith-Waterman parameters: similarities, gap
          open and gap extend. Similaries is tf.Tensor<float>[batch, len, len],
          the gap penalties can be either tf.Tensor<float>[batch] or
          tf.Tensor<float>[batch, len, len].

    Returns:
      The loss value for each example in the batch.
    """
        _, match_indicators_pred, sw_params = alignment_output
        sim_mat, _, _ = sw_params
        shape, dtype = sim_mat.shape, match_indicators_pred.dtype

        match_indices_true = alignment.alignments_to_state_indices(
            true_alignments, 'match')
        updates_true = tf.ones([tf.shape(match_indices_true)[0]], dtype=dtype)
        match_indicators_true = tf.scatter_nd(match_indices_true,
                                              updates_true,
                                              shape=shape)

        raw_losses = tf.losses.binary_crossentropy(
            match_indicators_true[..., tf.newaxis],
            match_indicators_pred[..., tf.newaxis])

        mask = alignment.mask_from_similarities(sim_mat,
                                                dtype=dtype,
                                                pad_penalty=self._pad_penalty)
        return tf.reduce_sum(mask * raw_losses, axis=[1, 2])
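
# Illustrative usage sketch, not part of the original snippet. The enclosing
# loss class name and its constructor argument are assumptions; tensor shapes
# follow the docstring of call() above.
#
#   loss_fn = NaiveAlignmentLoss(pad_penalty=1e8)  # hypothetical class name
#   batch, align_len, seq_len = 2, 6, 8
#   true_alignments = tf.zeros([batch, 3, align_len], tf.int32)
#   scores = tf.zeros([batch])
#   match_probs = tf.random.uniform([batch, seq_len, seq_len])
#   sim_mat = tf.random.normal([batch, seq_len, seq_len])
#   gap_open = tf.fill([batch], 11.0)
#   gap_extend = tf.fill([batch], 0.5)
#   output = (scores, match_probs, (sim_mat, gap_open, gap_extend))
#   per_example_loss = loss_fn(true_alignments, output)  # [batch] float tensor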