def test_build(self):
  """Checks that pairs_lib.build gathers input rows according to index pairs."""
  n_pairs, n_dims = 32, 5
  features = tf.random.uniform((n_pairs * 2, n_dims))
  # Consecutive integers reshaped into pairs: row i holds (2i, 2i + 1).
  pair_indices = tf.reshape(tf.range(n_pairs * 2), (-1, 2))
  paired, = pairs_lib.build(pair_indices, features)
  # Pair 4 gathers rows (8, 9).
  self.assertAllEqual(paired[4], tf.stack([features[8], features[9]]))
  # Shift by one the second column.
  pair_indices = tf.stack(
      [pair_indices[:, 0], tf.roll(pair_indices[:, 1], shift=1, axis=0)],
      axis=1)
  paired, = pairs_lib.build(pair_indices, features)
  # After the roll, pair 4's second element comes from the previous row (7).
  self.assertAllEqual(paired[4], tf.stack([features[8], features[7]]))
def call(self, inputs, training=False, embeddings_only=False):
  """Embeds sequences, aligns consecutive pairs and scores their homology.

  Args:
    inputs: batch of (tokenized) sequences to embed and pair up.
    training: whether to run the submodules in training mode.
    embeddings_only: if True, returns the encoder output and skips the
      alignment / homology phases entirely.

  Returns:
    The raw embeddings when `embeddings_only` is True; otherwise a dict with
    keys 'sw_params', 'sw_scores', 'paths' and 'homology_logits'.
  """
  embeddings = self.encoder(inputs, training=training)
  if embeddings_only:
    return embeddings

  # Recomputes padding mask, this time ensuring special tokens such as EOS or
  # MASK are zeroed out as well, regardless of the value of the flag
  # `self.encoder._mask_special_tokens`.
  masks = self.encoder.compute_mask(inputs, mask_special_tokens=True)
  pair_indices = pairs_lib.consecutive_indices(inputs)
  embedding_pairs, mask_pairs = pairs_lib.build(
      pair_indices, embeddings, masks)

  alignments = self.aligner(
      embedding_pairs, mask=mask_pairs, training=training)

  # Computes homology scores from SW scores and sequence lengths, then drops
  # the "dummy" trailing dimension.
  homology_scores = self.homology_head(
      alignments, mask=mask_pairs, training=training)
  homology_scores = tf.squeeze(homology_scores, axis=-1)

  return {
      'sw_params': alignments[2],
      'sw_scores': alignments[0],
      'paths': alignments[1],
      'homology_logits': homology_scores,
  }
def call(self, inputs, training=False, embeddings_only=False):
  """Embeds sequences and aligns consecutive pairs of them.

  Args:
    inputs: batch of (tokenized) sequences to embed and pair up.
    training: whether to run the submodules in training mode.
    embeddings_only: if True, returns the encoder output and skips alignment.

  Returns:
    The raw embeddings when `embeddings_only` is True; otherwise the aligner
    output for consecutive sequence pairs.
  """
  embeddings = self.encoder(inputs, training=training)
  if embeddings_only:
    return embeddings

  masks = self.encoder.compute_mask(inputs)
  pair_indices = pairs_lib.consecutive_indices(inputs)
  embedding_pairs, mask_pairs = pairs_lib.build(
      pair_indices, embeddings, masks)
  return self.aligner(embedding_pairs, mask=mask_pairs, training=training)
def forward(self, inputs, selector=None, training=True):
  """Run the models on a single input and potentially selects some heads only.

  Args:
    inputs: a Tensor<int32>[batch, seq_len] representing protein sequences.
    selector: If set a multi_task.Backbone[bool] to specify which head to
      apply. For non selected heads, a None will replace the output. If not
      set, all the heads will be output.
    training: whether to run in training mode or eval mode.

  Returns:
    A multi_task.Backbone of tensor corresponding to the output of the
    different heads of the model.
  """
  if selector is None:
    selector = self.heads.constant_copy(True)

  embeddings = self.encoder(inputs, training=training)
  masks = self.encoder.compute_mask(inputs)

  result = multi_task.Backbone()
  for head, on, backprop in zip(self.heads.embeddings,
                                selector.embeddings,
                                self.backprop.embeddings):
    result.embeddings.append(
        self.head_output(
            head, on, backprop, embeddings, mask=masks, training=training))

  if not self.heads.alignments or not any(selector.alignments):
    # Ensures structure of result matches self.heads even when method skips
    # alignment phase due to selector.
    for _ in selector.alignments:
      result.alignments.append(tf.constant([]))
    return result

  # For each head, we compute the output of positive pairs and negative ones,
  # then concatenate to obtain an output batch where the first half is
  # positive and the second half is negative.
  pos_indices = pairs_lib.consecutive_indices(inputs)
  index_batches = [pos_indices]
  if self.process_negatives:
    index_batches.append(pairs_lib.roll_indices(pos_indices))

  outputs = []
  for pair_indices in index_batches:
    embedding_pairs, mask_pairs = pairs_lib.build(
        pair_indices, embeddings, masks)
    alignments = self.aligner(
        embedding_pairs, mask=mask_pairs, training=training)
    outputs.append([
        self.head_output(
            head, on, backprop, alignments, mask=mask_pairs,
            training=training)
        for head, on, backprop in zip(self.heads.alignments,
                                      selector.alignments,
                                      self.backprop.alignments)
    ])

  for merged in merge(*outputs):
    result.alignments.append(merged)
  return result