コード例 #1
0
  def test_build(self):
    """Checks that `pairs_lib.build` stacks the rows selected by each index pair."""
    n_pairs, feature_dim = 32, 5
    features = tf.random.uniform((2 * n_pairs, feature_dim))
    # Row k of `pair_idx` is [2k, 2k + 1].
    pair_idx = tf.reshape(tf.range(2 * n_pairs), (-1, 2))

    paired, = pairs_lib.build(pair_idx, features)
    self.assertAllEqual(paired[4], tf.stack([features[8], features[9]]))

    # Rolling the second column by one position pairs the first element of
    # row k with the second element of row k - 1, so row 4 becomes [8, 7].
    rolled_second = tf.roll(pair_idx[:, 1], shift=1, axis=0)
    shifted_idx = tf.stack([pair_idx[:, 0], rolled_second], axis=1)
    paired, = pairs_lib.build(shifted_idx, features)
    self.assertAllEqual(paired[4], tf.stack([features[8], features[7]]))
コード例 #2
0
  def call(self, inputs, training=False, embeddings_only=False):
    """Encodes `inputs`, aligns consecutive pairs and scores their homology.

    Args:
      inputs: batch forwarded to `self.encoder` and used to recompute masks.
      training: forwarded to the encoder, the aligner and the homology head.
      embeddings_only: if True, return the encoder output without aligning.

    Returns:
      The encoder output if `embeddings_only` is True. Otherwise a dict with
      'sw_params', 'sw_scores' and 'paths' (elements 2, 0 and 1 of the aligner
      output) and 'homology_logits' (the homology head output with its
      trailing size-1 axis squeezed away).
    """
    encoded = self.encoder(inputs, training=training)
    if embeddings_only:
      return encoded

    # The mask is recomputed with mask_special_tokens=True so that special
    # tokens such as EOS or MASK are zeroed out as well, whatever the value of
    # `self.encoder._mask_special_tokens`.
    token_masks = self.encoder.compute_mask(inputs, mask_special_tokens=True)
    paired_embeddings, paired_masks = pairs_lib.build(
        pairs_lib.consecutive_indices(inputs), encoded, token_masks)
    aligner_out = self.aligner(
        paired_embeddings, mask=paired_masks, training=training)

    # The homology head turns the alignment outputs (SW scores) and pair masks
    # (sequence lengths) into one logit per pair; the dummy last axis is dropped.
    logits = self.homology_head(
        aligner_out, mask=paired_masks, training=training)
    logits = tf.squeeze(logits, axis=-1)

    return {
        'sw_params': aligner_out[2],
        'sw_scores': aligner_out[0],
        'paths': aligner_out[1],
        'homology_logits': logits,
    }
コード例 #3
0
 def call(self, inputs, training=False, embeddings_only=False):
     """Encodes `inputs` and aligns the pairs from `consecutive_indices`.

     Args:
       inputs: batch passed to `self.encoder` and used to build pair indices.
       training: forwarded to both the encoder and the aligner.
       embeddings_only: if True, skip alignment and return the encoder output.

     Returns:
       The encoder output when `embeddings_only` is set; otherwise whatever
       `self.aligner` returns for the paired embeddings.
     """
     encoded = self.encoder(inputs, training=training)
     if not embeddings_only:
         # NOTE(review): uses compute_mask's default arguments, so whether
         # special tokens are masked presumably follows the encoder's own
         # setting — confirm this is intended for alignment.
         token_masks = self.encoder.compute_mask(inputs)
         pair_indices = pairs_lib.consecutive_indices(inputs)
         paired_embeddings, paired_masks = pairs_lib.build(
             pair_indices, encoded, token_masks)
         return self.aligner(paired_embeddings,
                             mask=paired_masks,
                             training=training)
     return encoded
コード例 #4
0
    def forward(self, inputs, selector=None, training=True):
        """Runs the model on one input batch, optionally restricted to some heads.

        Args:
          inputs: a Tensor<int32>[batch, seq_len] representing protein sequences.
          selector: optional multi_task.Backbone[bool] choosing which heads to
            apply. For non selected heads, a None will replace the output. When
            omitted, every head is applied.
          training: whether to run in training mode or eval mode.

        Returns:
          A multi_task.Backbone of tensors holding the output of each head of
          the model.
        """
        if selector is None:
            selector = self.heads.constant_copy(True)

        embeddings = self.encoder(inputs, training=training)
        masks = self.encoder.compute_mask(inputs)

        result = multi_task.Backbone()
        embedding_heads = zip(self.heads.embeddings, selector.embeddings,
                              self.backprop.embeddings)
        for emb_head, is_on, do_backprop in embedding_heads:
            result.embeddings.append(
                self.head_output(emb_head, is_on, do_backprop, embeddings,
                                 mask=masks, training=training))

        skip_alignments = (not self.heads.alignments
                           or not any(selector.alignments))
        if skip_alignments:
            # Keep `result` structurally matching self.heads when the alignment
            # phase is skipped: one placeholder per alignment head. Note the
            # placeholder here is tf.constant([]), not None.
            for _ in selector.alignments:
                result.alignments.append(tf.constant([]))
            return result

        # Alignment heads run once on the positive pairs and, when
        # self.process_negatives is set, once more on the rolled (negative)
        # pairs. merge() then combines the per-pass outputs so that the first
        # half of each output batch is positive and the second half negative.
        pos_indices = pairs_lib.consecutive_indices(inputs)
        index_sets = [pos_indices]
        if self.process_negatives:
            index_sets.append(pairs_lib.roll_indices(pos_indices))

        per_pass_outputs = []
        for pair_indices in index_sets:
            embedding_pairs, mask_pairs = pairs_lib.build(
                pair_indices, embeddings, masks)
            alignments = self.aligner(embedding_pairs, mask=mask_pairs,
                                      training=training)
            alignment_heads = zip(self.heads.alignments, selector.alignments,
                                  self.backprop.alignments)
            per_pass_outputs.append([
                self.head_output(aln_head, is_on, do_backprop, alignments,
                                 mask=mask_pairs, training=training)
                for aln_head, is_on, do_backprop in alignment_heads
            ])

        for merged in merge(*per_pass_outputs):
            result.alignments.append(merged)
        return result