def test_alignments_to_state_indices(self):
  exp_out = tf.convert_to_tensor(
      [[0, 0, 0], [0, 1, 1], [0, 2, 3], [0, 5, 4], [0, 6, 5], [0, 7, 6]],
      tf.int32)
  match_indices = alignment.alignments_to_state_indices(
      self.alignments, 'match')
  self.assertAllEqual(match_indices, exp_out)

  # Invariance to padding.
  padded_alignments = tf.concat(
      [self.alignments, tf.zeros([1, 3, 3], tf.int32)], 2)
  match_indices = alignment.alignments_to_state_indices(
      padded_alignments, 'match')
  self.assertAllEqual(match_indices, exp_out)

  # Deals with empty ground-truth alignments.
  exp_out = tf.zeros([0, 3], tf.int32)
  match_indices = alignment.alignments_to_state_indices(
      tf.zeros_like(self.alignments), 'match')
  self.assertAllEqual(match_indices, exp_out)

  # Gaps.
  exp_out = tf.convert_to_tensor([[0, 1, 2], [0, 3, 3]], tf.int32)
  gap_open_indices = alignment.alignments_to_state_indices(
      self.alignments, 'gap_open')
  self.assertAllEqual(gap_open_indices, exp_out)

  exp_out = tf.convert_to_tensor([[0, 4, 3]], tf.int32)
  gap_extend_indices = alignment.alignments_to_state_indices(
      self.alignments, 'gap_extend')
  self.assertAllEqual(gap_extend_indices, exp_out)
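
# Minimal, illustrative sketch of how a ground-truth alignments tensor such as
# `self.alignments` above can be assembled. The layout follows the convention
# documented in this codebase: tf.Tensor<int>[batch, 3, align_len] =
# tf.stack([pos_x, pos_y, enc_trans], 1), with one-based positions. The integer
# values in `enc_trans` below are placeholders, not the real 9-state codes
# defined in alignment.py.
import tensorflow as tf

pos_x = tf.constant([[1, 2, 3, 4]], tf.int32)      # positions in sequence x
pos_y = tf.constant([[1, 2, 2, 3]], tf.int32)      # positions in sequence y
enc_trans = tf.constant([[1, 1, 2, 1]], tf.int32)  # placeholder edge codes
toy_alignments = tf.stack([pos_x, pos_y, enc_trans], axis=1)
print(toy_alignments.shape)  # (1, 3, 4)
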
def update_state(self, alignments_true, alignments_pred, sample_weight=None):
  """Updates precision, recall for a batch of true, pred alignments."""
  if alignments_pred[1] is None:
    return

  _, match_indicators_pred, sw_params = alignments_pred
  sim_mat, _, _ = sw_params
  shape, dtype = sim_mat.shape, match_indicators_pred.dtype

  match_indices_true = alignment.alignments_to_state_indices(
      alignments_true, 'match')
  updates_true = tf.ones([tf.shape(match_indices_true)[0]], dtype=dtype)
  match_indicators_true = tf.scatter_nd(
      match_indices_true, updates_true, shape=shape)

  batch = tf.shape(sample_weight)[0]
  sample_weight = tf.reshape(sample_weight, [batch, 1, 1])
  mask = alignment.mask_from_similarities(sim_mat, dtype=dtype)

  self._precision.update_state(
      match_indicators_true, match_indicators_pred, sample_weight * mask)
  self._recall.update_state(
      match_indicators_true, match_indicators_pred, sample_weight * mask)
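
# Illustrative, self-contained sketch (with made-up tensors) of the scatter
# step used above: sparse (batch, i, j) match coordinates, as returned by
# alignment.alignments_to_state_indices(..., 'match'), are expanded into a
# dense [batch, len_x, len_y] indicator tensor via tf.scatter_nd.
import tensorflow as tf

match_indices_true = tf.constant(
    [[0, 0, 0], [0, 1, 1], [1, 2, 3]], tf.int32)  # made-up coordinates
dense_shape = (2, 4, 5)                           # [batch, len_x, len_y]
updates = tf.ones([tf.shape(match_indices_true)[0]], tf.float32)
match_indicators = tf.scatter_nd(match_indices_true, updates, dense_shape)
print(match_indicators[0, 1, 1].numpy())  # 1.0: (0, 1, 1) is an aligned pair
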
def _confusion_matrix(alignments_true, sol_paths_pred):
  """Computes true, predicted and actual positives for a batch of alignments."""
  batch_size = tf.shape(alignments_true)[0]

  # Computes the number of true positives per example as a (sparse) inner
  # product of two binary tensors of shape (batch_size, len_x, len_y) via
  # indexing. Entirely avoids materializing one of the two tensors explicitly.
  match_indices_true = alignment.alignments_to_state_indices(
      alignments_true, 'match')  # [n_aligned_chars_true, 3]
  match_indicators_pred = alignment.paths_to_state_indicators(
      sol_paths_pred, 'match')  # [batch, len_x, len_y]
  batch_indicators = match_indices_true[:, 0]  # [n_aligned_chars_true]
  matches_flat = tf.gather_nd(
      match_indicators_pred, match_indices_true)  # [n_aligned_chars_true]
  true_positives = tf.math.unsorted_segment_sum(
      matches_flat, batch_indicators, batch_size)  # [batch]

  # Computes the number of predicted and ground-truth positives per example.
  pred_positives = tf.reduce_sum(match_indicators_pred, axis=[1, 2])
  # Note(fllinares): tf.math.bincount is unsupported on TPU :(
  cond_positives = tf.math.unsorted_segment_sum(
      tf.ones_like(batch_indicators, tf.float32),
      batch_indicators,
      batch_size)  # [batch]

  return true_positives, pred_positives, cond_positives
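
# Illustrative sketch (with made-up tensors) of the "sparse inner product"
# trick above: instead of materializing the dense ground-truth indicator, the
# predicted indicator is gathered at the true match coordinates and the
# gathered values are summed per example with tf.math.unsorted_segment_sum.
import tensorflow as tf

match_indicators_pred = tf.constant(
    [[[1., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
     [[0., 0., 0.], [0., 0., 1.], [0., 0., 0.]]])  # [batch=2, len_x=3, len_y=3]
match_indices_true = tf.constant(
    [[0, 0, 0], [0, 1, 1], [1, 1, 2], [1, 2, 2]])  # rows of (batch, i, j)
batch_indicators = match_indices_true[:, 0]
matches_flat = tf.gather_nd(match_indicators_pred, match_indices_true)
true_positives = tf.math.unsorted_segment_sum(matches_flat, batch_indicators, 2)
print(true_positives.numpy())  # [2. 1.]
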
def call(self, true_alignments, alignment_output):
  """Computes a brute-force BCE loss for pairwise sequence alignment.

  Args:
    true_alignments: The ground-truth alignments for the batch, given by a
      tf.Tensor<int>[batch, 3, align_len] = tf.stack([pos_x, pos_y, enc_trans], 1)
      such that (pos_x[b][i], pos_y[b][i], enc_trans[b][i]) represents the i-th
      transition in the ground-truth alignment for example b in the minibatch.
      Both pos_x and pos_y are assumed to use one-based indexing, and enc_trans
      follows the (categorical) 9-state encoding of edge types used throughout
      alignment.py.
    alignment_output: A NaiveAlignmentOutput, which is a 3-tuple made of:
      + The alignment scores: tf.Tensor<float>[batch].
      + The pairwise match probabilities: tf.Tensor<float>[batch, len, len].
      + A 3-tuple containing the Smith-Waterman parameters: similarities, gap
        open and gap extend. Similarities is tf.Tensor<float>[batch, len, len];
        the gap penalties can be either tf.Tensor<float>[batch] or
        tf.Tensor<float>[batch, len, len].

  Returns:
    The loss value for each example in the batch.
  """
  _, match_indicators_pred, sw_params = alignment_output
  sim_mat, _, _ = sw_params
  shape, dtype = sim_mat.shape, match_indicators_pred.dtype

  match_indices_true = alignment.alignments_to_state_indices(
      true_alignments, 'match')
  updates_true = tf.ones([tf.shape(match_indices_true)[0]], dtype=dtype)
  match_indicators_true = tf.scatter_nd(
      match_indices_true, updates_true, shape=shape)

  raw_losses = tf.losses.binary_crossentropy(
      match_indicators_true[..., tf.newaxis],
      match_indicators_pred[..., tf.newaxis])
  mask = alignment.mask_from_similarities(
      sim_mat, dtype=dtype, pad_penalty=self._pad_penalty)
  return tf.reduce_sum(mask * raw_losses, axis=[1, 2])
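
# Illustrative sketch (with made-up tensors) of the loss computation above:
# adding a trailing channel axis makes Keras' binary cross-entropy reduce over
# that axis only, leaving one loss value per (i, j) character pair, which is
# then masked and summed per example. The all-ones mask below is a stand-in
# for alignment.mask_from_similarities.
import tensorflow as tf

match_indicators_true = tf.constant([[[1., 0.], [0., 1.]]])    # [1, 2, 2]
match_indicators_pred = tf.constant([[[0.9, 0.1], [0.2, 0.8]]])
raw_losses = tf.keras.losses.binary_crossentropy(
    match_indicators_true[..., tf.newaxis],
    match_indicators_pred[..., tf.newaxis])                    # [1, 2, 2]
mask = tf.ones_like(raw_losses)
per_example_loss = tf.reduce_sum(mask * raw_losses, axis=[1, 2])  # [1]
print(per_example_loss.numpy())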