def call(self, embeddings, mask=None, training=None):
  """Computes alignments and scores a la Bepler et al. ICLR 2019.

  Args:
    embeddings: a tf.Tensor<float>[batch, 2, len, dim] containing pairs of
      sequence embeddings (with the sequence lengths).
    mask: An optional token mask to account for padding.
    training: whether to run the layer for train (True), eval (False) or let
      the Keras backend decide (None).

  Returns:
    A NaiveAlignmentOutput, which is a 3-tuple made of:
      - The alignment scores: tf.Tensor<float>[batch].
      - The pairwise match probabilities: tf.Tensor<float>[batch, len, len].
      - A 3-tuple containing the similarities, gap open and gap extend
        penalties. Here similarities is a tf.Tensor<float>[batch, len, len]
        that simply encodes the padding mask, taking value 0.0 for "real"
        tokens and 1e9 for padding / special tokens. The gap penalties are
        tf.Tensor<float>[batch] of zeros, present only for consistency in the
        output signature.
  """
  batch, dtype = tf.shape(embeddings)[0], embeddings.dtype
  scores, match_indicators_pred = self._similarity(embeddings, mask=mask)
  # Here sim_mat has no real purpose other than passing the padding mask to
  # the loss and metrics for the corresponding output head.
  sim_mat = tf.where(pairs_lib.pair_masks(mask[:, 0], mask[:, 1]), 0.0, 1e9)
  sw_params = (sim_mat, tf.zeros([batch], dtype), tf.zeros([batch], dtype))
  return scores, match_indicators_pred, sw_params
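
# The pairwise-mask helper used above is not shown in this excerpt. The sketch
# below captures the behavior assumed of pairs_lib.pair_masks (an assumption,
# not its actual implementation): it broadcasts two per-token masks into one
# boolean pair mask.
import tensorflow as tf

def pair_masks_sketch(mask_x, mask_y):
  """Minimal sketch of a pairwise padding mask.

  Args:
    mask_x: tf.Tensor<bool>[batch, len_x], token mask of the first sequence.
    mask_y: tf.Tensor<bool>[batch, len_y], token mask of the second sequence.

  Returns:
    tf.Tensor<bool>[batch, len_x, len_y] with out[n][i][j] True iff both
    mask_x[n][i] and mask_y[n][j] are True.
  """
  mask_x = tf.cast(mask_x, tf.bool)
  mask_y = tf.cast(mask_y, tf.bool)
  return mask_x[:, :, None] & mask_y[:, None, :]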
def call(self, embeddings, mask=None):
  """Computes the forward pass for the soft symmetric alignment layer.

  Args:
    embeddings: A tf.Tensor<float>[batch, 2, len, dim] with the embeddings of
      the two sequences.
    mask: A tf.Tensor[batch, 2, len] with the padding masks of the two
      sequences.

  Returns:
    The soft symmetric alignment similarity scores, as defined in Bepler et
    al., "Learning protein sequence embeddings using information from
    structure", ICLR 2019, represented by a 1D tf.Tensor of dimension
    batch_size. If self._return_att_weights is True, it will additionally
    return the soft symmetric alignment weights as a
    tf.Tensor<float>[batch, len, len] with entries in [0, 1].
  """
  if self._proj:
    embeddings = self.dense(embeddings)
  pair_dist = self.pairwise_distance(embeddings[:, 0], embeddings[:, 1])
  pair_mask = pairs_lib.pair_masks(mask[:, 0], mask[:, 1])
  a = self.softmax_a(-pair_dist, pair_mask)
  b = self.softmax_b(-pair_dist, pair_mask)
  att_weights = tf.where(pair_mask, a + b - a * b, 0.0)
  scores = -tf.reduce_sum(att_weights * pair_dist, (1, 2))
  scores /= tf.reduce_sum(att_weights, (1, 2))
  return (scores, att_weights) if self._return_att_weights else scores
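
# For reference, the SSA score above can be written without the layer
# machinery. This standalone sketch (an illustration, not part of the module)
# uses L1 pairwise distances as in Bepler et al. and omits padding masks for
# brevity.
import tensorflow as tf

def ssa_score_sketch(z_x, z_y):
  """Soft symmetric alignment score for unpadded sequence pairs.

  Args:
    z_x: tf.Tensor<float>[batch, len_x, dim], embeddings of sequence x.
    z_y: tf.Tensor<float>[batch, len_y, dim], embeddings of sequence y.

  Returns:
    tf.Tensor<float>[batch] of SSA similarity scores.
  """
  # dist[n, i, j] = ||z_x[n, i] - z_y[n, j]||_1.
  dist = tf.reduce_sum(tf.abs(z_x[:, :, None] - z_y[:, None, :]), axis=-1)
  a = tf.nn.softmax(-dist, axis=2)  # Attention of each x position over y.
  b = tf.nn.softmax(-dist, axis=1)  # Attention of each y position over x.
  att = a + b - a * b               # Symmetrized weights, entries in [0, 1].
  return -tf.reduce_sum(att * dist, (1, 2)) / tf.reduce_sum(att, (1, 2))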
def call(self, inputs, mask=None, training=None):
  """Evaluates a bilinear form for (batched) sets of vector pairs.

  Args:
    inputs: a tf.Tensor<float>[batch, 2, len, dim] representing two inputs.
    mask: a tf.Tensor<float>[batch, 2, len] to account for padding.
    training: whether to run the layer for train (True), eval (False) or let
      the Keras backend decide (None).

  Returns:
    A tf.Tensor<float>[batch, len, len] such that
      out[n][i][j] := activation((x[n][i]^T W y[n][j]) / norm_factor + b),
    where the bilinear form matrix W can optionally be set to the identity
    matrix (use_kernel = False) or frozen to its initialization value
    (trainable_kernel = False), and the scalar bias b can optionally be set
    to zero (use_bias = False) or likewise frozen to its initialization value
    (trainable_bias = False). If sqrt_norm is True, the scalar norm_factor
    above is set to sqrt(dim), following dot-product attention; otherwise,
    norm_factor = 1.0. Finally, if mask_penalty is not None and either
    mask[n][0][i] = 0 or mask[n][1][j] = 0, then out[n][i][j] = mask_penalty
    instead.
  """
  inputs = self.dropout(inputs, training=training)
  x, y = inputs[:, 0], inputs[:, 1]
  if not self._use_kernel:
    output = tf.einsum('ijk,ilk->ijl', x, y)
  else:
    w = self.kernel
    if self._symmetric_kernel:
      w = 0.5 * (w + tf.transpose(w))
    output = tf.einsum('nir,rs,njs->nij', x, w, y)
  if self._sqrt_norm:
    dim_x, dim_y = tf.shape(x)[-1], tf.shape(y)[-1]
    # Geometric mean of the two embedding dims; dividing by its square root
    # reduces to the usual 1 / sqrt(dim) scaling when dim_x == dim_y.
    dim = tf.sqrt(tf.cast(dim_x * dim_y, output.dtype))
    output /= tf.sqrt(dim)
  if self._use_bias:
    output += self.bias
  if self._activation is not None:
    output = self._activation(output)
  if self._mask_penalty is not None and mask is not None:
    paired_masks = pairs_lib.pair_masks(mask[:, 0], mask[:, 1])
    output = tf.where(paired_masks, output, self._mask_penalty)
  return output
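
# Usage sketch for the bilinear similarity above (illustrative shapes and
# names, assuming the learnable-kernel path with sqrt normalization and no
# mask). With W initialized to the identity, the form starts out as a plain
# scaled dot product.
import tensorflow as tf

batch, length, dim = 2, 5, 8
x = tf.random.normal([batch, length, dim])
y = tf.random.normal([batch, length, dim])
w = tf.Variable(tf.eye(dim))  # Bilinear kernel W, identity-initialized.
sim = tf.einsum('nir,rs,njs->nij', x, w, y) / tf.sqrt(float(dim))
print(sim.shape)  # (2, 5, 5), i.e. [batch, len, len].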