def test_zeros(self):
    """Tests the contact loss on a chain and on a rotated copy of that chain.

    Checks that (1) the loss of a chain against its own pairwise distances is
    close to zero, and (2) two chains equal up to a rotation yield the same
    loss, since rotations preserve pairwise distances.
    """
    loss = losses.ContactLoss(
        no_contact_fun=gauss_decrease, contact_fun=gauss_increase)
    batch_size = 8
    num_embs = 100
    dims = 17
    sigma = 0.01
    threshold = 1e-4
    embs_true = embeddings_chain(batch_size, num_embs, dims, sigma)

    # tf.linalg.svd returns (s, u, v); the `u` factor of a random square
    # matrix is orthogonal, i.e. a random rotation (possibly combined with a
    # reflection), which preserves pairwise distances.
    m_random = tf.random.uniform((batch_size, dims, dims))
    _, rotation_random, _ = tf.linalg.svd(m_random)
    # Rotate every embedding: [batch, dims, dims] x [batch, dims, num_embs],
    # then transpose back to [batch, num_embs, dims].
    embs_rotate = tf.transpose(
        tf.matmul(rotation_random, tf.transpose(embs_true, (0, 2, 1))),
        (0, 2, 1))

    pairw_true = pairs.square_distances(embs_true, embs_true)
    pairw_rotate = pairs.square_distances(embs_rotate, embs_rotate)
    contact_true = tf.cast(pairw_true < threshold, dtype=tf.float32)

    loss_value = loss(contact_true, pairw_true)
    self.assertAllClose(loss_value, 0, atol=1e-4)
    # The rotation only preserves distances up to float32 round-off, so the
    # two losses are compared with a tolerance rather than exact equality.
    self.assertAllClose(
        loss_value, loss(contact_true, pairw_rotate), atol=1e-4)
def test_from_embs(self):
    """Checks that feeding embeddings matches feeding their distance matrix.

    A ContactLoss built with `from_embs=True` derives the pairwise squared
    distances itself, so it must agree with a matrix-based ContactLoss that
    receives the precomputed distances of the very same embeddings.
    """
    shared_funs = dict(
        no_contact_fun=gauss_decrease, contact_fun=gauss_increase)
    loss_embs = losses.ContactLoss(from_embs=True, **shared_funs)
    loss_mat = losses.ContactLoss(**shared_funs)

    batch_size, num_embs, dims = 16, 50, 53
    sigma = threshold = 1e-4

    # Two independent chains: one defines the contacts, the other is the
    # prediction scored by both loss variants.
    chain_ref, chain_pred = [
        embeddings_chain(batch_size, num_embs, dims, sigma) for _ in range(2)
    ]
    dist_ref = pairs.square_distances(chain_ref, chain_ref)
    dist_pred = pairs.square_distances(chain_pred, chain_pred)
    contacts = tf.cast(dist_ref < threshold, dtype=tf.float32)

    value_from_embs = loss_embs(contacts, chain_pred)
    value_from_mat = loss_mat(contacts, dist_pred)
    self.assertAllClose(value_from_embs, value_from_mat, atol=1e-5)
def call(self, contact_true, pred):
    """Computes the Contact loss between contact / distance matrices.

    Args:
      contact_true: a tf.Tensor<float>[batch_size, num_embs, num_embs], a
        batch of binary contact matrices for 'num_embs' embeddings.
      pred: a tf.Tensor<float> of shape either
        + [batch_size, num_embs, dims] if 'from_embs' is True (embeddings
          case) a batch of 'num_embs' embeddings in dimension 'dims'.
        + [batch_size, num_embs, num_embs] if 'from_embs' is False (matrix
          case) a batch of pairwise distances for 'num_embs' embeddings.

    Returns:
      A tf.Tensor[batch_size, num_embs, num_embs] of weighted element-wise
      contact losses, in the dtype of the predicted distances. Summing it over
      (i, j) gives, for an instance matrix in the batch:
        loss(y, p) = sum_ij w_|i-j| fun(y_ij, p_ij),
      where y is the ground truth contact matrix and p is the predicted
      pairwise distance matrix.
        + fun(y_ij, _) is no_contact_fun if y_ij = 0, contact_fun if y_ij = 1.
        + w_|i-j| is weights_fun(|i-j|), and just |i-j| if None.
      The (i, j) reduction is not performed here; presumably it is done by
      the enclosing Loss reduction — confirm against the class definition.
      If from_embs is true, the predicted matrix is the pairwise distance of
      the predicted embeddings.
    """
    if self._from_embs:
        pairw_dist_pred = pairs.square_distances(embs_1=pred, embs_2=pred)
    else:
        pairw_dist_pred = pred
    # Build every intermediate in the prediction's dtype so that non-float32
    # predictions (e.g. float64) do not hit a dtype mismatch on the product.
    dtype = pairw_dist_pred.dtype

    # |i - j| for every pair of positions, with a leading axis of size 1 so
    # the weights broadcast over the batch dimension.
    num_embs = tf.shape(pairw_dist_pred)[1]
    weights_range = tf.range(num_embs, dtype=dtype)
    weights_range_square = tf.abs(
        weights_range[tf.newaxis, :, tf.newaxis] -
        weights_range[tf.newaxis, tf.newaxis, :])
    weights_batch_square = self._weights_fun(weights_range_square)

    contact_true = tf.cast(contact_true, dtype=dtype)
    # Select contact_fun where y_ij = 1 and no_contact_fun where y_ij = 0.
    mat_losses = contact_true * self._contact_fun(pairw_dist_pred) + (
        1 - contact_true) * self._no_contact_fun(pairw_dist_pred)
    return weights_batch_square * mat_losses
def single_call(self, positions):
    """Builds the contact matrix of one unbatched set of positions.

    Expects positions to be a tf.Tensor<float>[n, 3].

    Returns:
      A tf.float32 matrix with 1.0 where the squared distance between two
      positions is below self._threshold**2 and 0.0 elsewhere.
    """
    # pairs.square_distances operates on batches: wrap the input in a batch
    # of one, then take the single result back out.
    batched = tf.expand_dims(positions, axis=0)
    sq_dist = pairs.square_distances(batched, batched)[0]
    max_sq_dist = self._threshold**2
    in_contact = sq_dist < max_sq_dist
    return tf.cast(in_contact, dtype=tf.float32)