Beispiel #1
0
  def test_zeros(self):
    """Tests the loss is ~0 on true distances and invariant under rotation.

    The loss of a chain's contacts against its own pairwise square distances
    should be close to zero. Applying a random orthogonal transform to the
    embeddings (which preserves pairwise distances) should leave the loss
    unchanged up to floating-point error.
    """
    loss = losses.ContactLoss(no_contact_fun=gauss_decrease,
                              contact_fun=gauss_increase)
    batch_size = 8
    num_embs = 100
    dims = 17
    sigma = 0.01
    threshold = 1e-4
    embs_true = embeddings_chain(batch_size, num_embs, dims, sigma)

    # The left singular vectors U of a random square matrix form an
    # orthogonal matrix, which preserves pairwise distances.
    m_random = tf.random.uniform((batch_size, dims, dims))
    _, rotation_random, _ = tf.linalg.svd(m_random)

    # Equivalent to (R @ X^T)^T: every embedding is multiplied by R.
    embs_rotate = tf.matmul(embs_true, rotation_random, transpose_b=True)
    pairw_true = pairs.square_distances(embs_true, embs_true)
    pairw_rotate = pairs.square_distances(embs_rotate, embs_rotate)
    contact_true = tf.cast(pairw_true < threshold, dtype=tf.float32)
    loss_value = loss(contact_true, pairw_true)
    self.assertAllClose(loss_value, 0, atol=1e-4)
    # Rotated distances only match the originals up to float32 rounding, so
    # an exact equality check would be flaky; compare with a tolerance.
    self.assertAllClose(loss_value, loss(contact_true, pairw_rotate),
                        atol=1e-4)
 def test_from_embs(self):
     """Checks the embeddings path of ContactLoss against the matrix path.

     The same ground-truth contacts are scored twice. The first loss receives
     raw predicted embeddings (``from_embs=True``). The second receives their
     precomputed pairwise square distances. Both values must agree.
     """
     shared_kwargs = {'no_contact_fun': gauss_decrease,
                      'contact_fun': gauss_increase}
     loss_embs = losses.ContactLoss(from_embs=True, **shared_kwargs)
     loss_mat = losses.ContactLoss(**shared_kwargs)

     batch_size, num_embs, dims = 16, 50, 53
     sigma = threshold = 1e-4
     # Drawn in order: first the ground-truth chain, then the predicted one.
     chain_true, chain_pred = [
         embeddings_chain(batch_size, num_embs, dims, sigma) for _ in range(2)
     ]

     dist_true = pairs.square_distances(chain_true, chain_true)
     dist_pred = pairs.square_distances(chain_pred, chain_pred)
     contacts = tf.cast(dist_true < threshold, dtype=tf.float32)

     self.assertAllClose(loss_embs(contacts, chain_pred),
                         loss_mat(contacts, dist_pred),
                         atol=1e-5)
    def call(self, contact_true, pred):
        """Computes the Contact loss between contact / distance matrices.

        Args:
          contact_true: a tf.Tensor<float>[batch_size, num_embs, num_embs], a
            batch of binary contact matrices for 'num_embs' embeddings.
          pred: a tf.Tensor<float> whose shape depends on 'from_embs':
            + [batch_size, num_embs, dims] if 'from_embs' is True (embeddings
              case): a batch of 'num_embs' embeddings in dimension 'dims'.
            + [batch_size, num_embs, num_embs] if 'from_embs' is False (matrix
              case): a batch of pairwise distances for 'num_embs' embeddings.

        Returns:
          The per-entry weighted contact losses (no reduction is applied here;
          NOTE(review): presumably the enclosing Loss wrapper reduces them,
          confirm). Summed over i, j for one instance of the batch, they give:
            loss(y, p) = sum_ij w_|i-j| fun(y_ij, p_ij),
          where y is the ground truth contact matrix and p is the predicted
          pairwise distance matrix.
            + fun(y_ij, _) is no_contact_fun if y_ij = 0, contact_fun if y_ij = 1.
            + w_|i-j| is weights_fun(|i-j|), and just |i-j| if None.
          If from_embs is true, the predicted matrix is the pairwise distance of
          the predicted embeddings.
        """
        if self._from_embs:
            pairw_dist_pred = pairs.square_distances(embs_1=pred, embs_2=pred)
        else:
            pairw_dist_pred = pred
        # Every term is built in the prediction dtype. Hard-coding float32 for
        # the weights would make the final product fail for float64/float16
        # predictions, since TF does not promote dtypes implicitly.
        dtype = pairw_dist_pred.dtype
        num_embs = tf.shape(pairw_dist_pred)[1]
        weights_range = tf.cast(tf.range(num_embs), dtype)
        # |i - j| for every (i, j), with a leading axis of size 1 so it
        # broadcasts over the batch dimension.
        weights_range_square = tf.abs(weights_range[tf.newaxis, :,
                                                    tf.newaxis] -
                                      weights_range[tf.newaxis, tf.newaxis, :])
        if self._weights_fun is None:
            # Documented default: the weight is |i - j| itself.
            weights_batch_square = weights_range_square
        else:
            weights_batch_square = self._weights_fun(weights_range_square)
        contact_true = tf.cast(contact_true, dtype=dtype)
        # Select contact_fun where y_ij = 1 and no_contact_fun where y_ij = 0.
        mat_losses = contact_true * self._contact_fun(pairw_dist_pred) + (
            1 - contact_true) * self._no_contact_fun(pairw_dist_pred)
        return weights_batch_square * mat_losses
Beispiel #4
0
 def single_call(self, positions):
     """Returns the binary contact map of a single set of positions.

     Args:
       positions: expected to be a tf.Tensor<float>[n, 3].

     Returns:
       A float32 tensor with 1.0 where the square distance between two
       positions is below the squared threshold, and 0.0 elsewhere.
     """
     # pairs.square_distances operates on batches: wrap the input in a batch
     # of size 1, then drop that batch axis from the result.
     batched = tf.expand_dims(positions, 0)
     square_dists = pairs.square_distances(batched, batched)
     in_contact = square_dists[0] < self._threshold**2
     return tf.cast(in_contact, dtype=tf.float32)