Exemplo n.º 1
0
 def length_fn(x):
     """Return per-pair sequence lengths for the batch `x`.

     Lengths are obtained by summing the encoder mask of `x` along axis 1.
     They are then gathered with the indices from
     `pairs_lib.consecutive_indices` and with those indices passed through
     `pairs_lib.roll_indices`; the two gathered tensors are concatenated
     along axis 0, consecutive pairs first.
     """
     # Count the unmasked positions of every sequence.
     valid_positions = tf.cast(loop.model.encoder.compute_mask(x), tf.int32)
     lengths = tf.reduce_sum(valid_positions, 1)

     consecutive = pairs_lib.consecutive_indices(x)
     rolled = pairs_lib.roll_indices(consecutive)
     gathered = [tf.gather(lengths, idx) for idx in (consecutive, rolled)]
     return tf.concat(gathered, 0)
    def call(self, values):
        """Build integer targets telling whether paired values are equal.

        `values` is gathered with the indices from `pairs.consecutive_indices`
        and, when `self._process_negatives` is truthy, also with those indices
        passed through `pairs.roll_indices`. For every pair the target is 1 if
        both gathered entries are equal and 0 otherwise.

        Args:
          values: tensor of per-example values to compare pairwise
            (presumably rank 1, one value per example -- TODO confirm).

        Returns:
          An int32 tensor with the targets of the consecutive pairs followed,
          if enabled, by those of the rolled pairs, with a new axis inserted
          at position 1.
        """

        def pair_equality(pair_indices):
            # Gather both members of each pair and compare them.
            gathered = tf.gather(values, pair_indices)
            first, second = gathered[:, 0], gathered[:, 1]
            return tf.cast(first == second, tf.int32)

        positives = pairs.consecutive_indices(values)
        negatives = pairs.roll_indices(positives)
        if self._process_negatives:
            index_sets = [positives, negatives]
        else:
            index_sets = [positives]
        stacked = tf.concat([pair_equality(idx) for idx in index_sets], 0)
        return stacked[:, tf.newaxis]
Exemplo n.º 3
0
    def forward(self, inputs, selector=None, training=True):
        """Apply the encoder, the aligner and the selected heads to a batch.

        Embedding heads receive the encoder output for `inputs`. Alignment
        heads receive the aligner output computed on pairs of sequences: first
        on the pairs given by `pairs_lib.consecutive_indices` and then, if
        `self.process_negatives` is set, on those indices passed through
        `pairs_lib.roll_indices`. The per-head outputs of these passes are
        combined with `merge`.

        Args:
          inputs: a Tensor<int32>[batch, seq_len] representing protein
            sequences.
          selector: optional multi_task.Backbone[bool] specifying which heads
            to apply. When None, all heads are selected. How a non selected
            head is represented in the output is decided by
            `self.head_output`.
          training: whether to run in training mode or eval mode.

        Returns:
          A multi_task.Backbone holding the outputs of the embedding heads and
          of the alignment heads. When there are no alignment heads, or none
          of them is selected, the aligner is not run and every alignment slot
          holds an empty tensor instead.
        """
        if selector is None:
            selector = self.heads.constant_copy(True)

        embeddings = self.encoder(inputs, training=training)
        masks = self.encoder.compute_mask(inputs)

        result = multi_task.Backbone()
        embedding_heads = zip(self.heads.embeddings,
                              selector.embeddings,
                              self.backprop.embeddings)
        for head, is_on, backprop in embedding_heads:
            result.embeddings.append(
                self.head_output(head,
                                 is_on,
                                 backprop,
                                 embeddings,
                                 mask=masks,
                                 training=training))

        run_alignments = self.heads.alignments and any(selector.alignments)
        if not run_alignments:
            # Keep the structure of result aligned with self.heads even though
            # the alignment phase is skipped.
            for _ in selector.alignments:
                result.alignments.append(tf.constant([]))
            return result

        # One aligner pass per set of pair indices: positive pairs first,
        # then (optionally) negative pairs.
        pos_indices = pairs_lib.consecutive_indices(inputs)
        index_sets = [pos_indices]
        if self.process_negatives:
            index_sets.append(pairs_lib.roll_indices(pos_indices))

        per_pass_outputs = []
        for indices in index_sets:
            embedding_pairs, mask_pairs = pairs_lib.build(
                indices, embeddings, masks)
            alignments = self.aligner(embedding_pairs,
                                      mask=mask_pairs,
                                      training=training)
            per_pass_outputs.append([
                self.head_output(head,
                                 is_on,
                                 backprop,
                                 alignments,
                                 mask=mask_pairs,
                                 training=training)
                for head, is_on, backprop in zip(self.heads.alignments,
                                                 selector.alignments,
                                                 self.backprop.alignments)
            ])

        # merge combines, head by head, the outputs of the passes above; the
        # intent is an output batch whose first half is positive and whose
        # second half is negative.
        for merged in merge(*per_pass_outputs):
            result.alignments.append(merged)
        return result