def test_sw_score(self):
        """Checks `alignment.sw_score` on sparse, dense and all-zero inputs."""
        params = self.sw_params
        sparse = self.alignments

        # Sparse representation of the alignments.
        self.assertAllEqual(alignment.sw_score(params, sparse), [95.0])

        # Dense representation. The second length is bumped by one so that
        # padding is exercised as well.
        padded_len_y = self.len_y + 1
        dense = alignment.alignments_to_paths(sparse, self.len_x,
                                              padded_len_y)
        self.assertAllEqual(alignment.sw_score(params, dense), [95.0])

        # All-zero (empty) alignments, first sparse and then dense.
        empty_sparse = tf.zeros_like(sparse)
        self.assertAllEqual(alignment.sw_score(params, empty_sparse), [0.0])
        empty_dense = tf.zeros_like(dense)
        self.assertAllEqual(alignment.sw_score(params, empty_dense), [0.0])
    def update_state(self,
                     alignments_true,
                     alignments_pred,
                     sample_weight=None):
        """Accumulates alignment scores for one batch of true/predicted alignments.

        Args:
          alignments_true: The ground-truth alignments for the batch.
          alignments_pred: The predicted alignment output; entries 0 and 2 are
            read here (presumably the solution values and the Smith-Waterman
            parameters, respectively -- confirm against the caller).
          sample_weight: Ignored.
        """
        del sample_weight  # Logic in this metric controlled by process_negatives.

        # Score of the true alignments, tracked under the first key. The
        # arguments to sw_score are the split of alignments_pred[2] followed by
        # the split of the true alignments.
        params_parts = self._split(alignments_pred[2], False)
        true_parts = self._split(alignments_true, False)
        score_args = params_parts + true_parts
        true_mean = self._means[self._keys[0]]
        true_mean.update_state(alignment.sw_score(*score_args))

        # The remaining keys are paired, in order, with the pieces obtained by
        # splitting alignments_pred[0].
        pred_parts = self._split(alignments_pred[0])
        for key, values in zip(self._keys[1:], pred_parts):
            self._means[key].update_state(values)
Beispiel #3
0
  def call(self, true_alignments_or_paths,
           alignment_output):
    """Evaluates the loss tied to the Smith-Waterman DP for a batch.

    The result is the (soft) optimal Smith-Waterman score minus the value that
    `alignment.sw_score` returns for the ground-truth alignment under the same
    `sw_params`.

    Args:
      true_alignments_or_paths: The ground-truth alignments for the batch, in
        either a sparse or a dense representation.
          + Sparse: a tf.Tensor<int>[batch, 3, align_len] equal to
            tf.stack([pos_x, pos_y, enc_trans], 1), where (pos_x[b][i],
            pos_y[b][i], enc_trans[b][i]) is the i-th transition of the
            ground-truth alignment for example b in the minibatch. pos_x and
            pos_y use one-based indexing and enc_trans follows the
            (categorical) 9-state encoding of edge types used throughout
            alignment.py.
          + Dense: a tf.Tensor<float>[batch, len_x, len_y, 9] with binary
            entries that has a one along each edge taken by the alignment
            path, with nine possible edges for each i,j.
      alignment_output: An AlignmentOutput, i.e. a tuple (solution_values,
        solution_paths, sw_params) such that
          + solution_values is a tf.Tensor<float>[batch] with the (soft)
            optimal Smith-Waterman scores for the batch.
          + solution_paths, which is not used by the loss, is either None or
            a tf.Tensor<float>[batch, len1, len2, 9] describing the optimal
            soft alignments.
          + sw_params is a tuple (sim_mat, gap_open, gap_extend) of tf.Tensor
            objects parameterizing the Smith-Waterman LP, where
              - sim_mat is a tf.Tensor<float>[batch, len1, len2]
                (len1 <= len2) with the substitution values for pairs of
                sequences.
              - gap_open is a tf.Tensor<float>[], tf.Tensor<float>[batch] or
                tf.Tensor<float>[batch, len1, len2] (len1 <= len2) with the
                penalties for opening a gap. Must agree in rank with
                gap_extend.
              - gap_extend is a tf.Tensor<float>[], tf.Tensor<float>[batch]
                or tf.Tensor<float>[batch, len1, len2] (len1 <= len2) with
                the penalties for extending a gap. Must agree in rank with
                gap_open.

    Returns:
      The loss value for each example in the batch.
    """
    # The middle entry (solution_paths) plays no role in the loss.
    optimal_scores, _, dp_params = alignment_output
    true_scores = alignment.sw_score(dp_params, true_alignments_or_paths)
    return optimal_scores - true_scores