Exemplo n.º 1
0
 def test_equivariance(self, tol=1e-4):
     """Assert the model output is unchanged by a random rotation of the positions.

     The model is evaluated twice on ``self.mb``; the relative difference of
     those two runs is printed as a run-to-run baseline. It is then evaluated
     once more with ``mb['positions']`` rotated by one random rotation per
     example. The test fails unless the mean relative error between the
     rotated and unrotated outputs is below ``tol``.

     ``mb['positions']`` is put back afterwards, so the shared fixture is left
     as it was found even if the model raises on the rotated input.
     """
     self.model.eval()
     mb = self.mb
     positions = mb['positions']
     outs = self.model(mb).detach().cpu().numpy()
     outs2 = self.model(mb).detach().cpu().numpy()

     # Random unit quaternions (normalised Gaussian samples), one per example.
     # The singleton dim broadcasts each rotation over every point.
     bs = positions.shape[0]
     q = torch.randn(bs, 1, 4, device=positions.device, dtype=positions.dtype)
     q /= norm(q, dim=-1).unsqueeze(-1)
     # q and -q encode the same rotation; choose w >= 0 so the angle lies in [0, pi].
     q = torch.where(q[..., :1] < 0, -q, q)
     # Quaternion (w, v) -> axis-angle vector: rotation angle 2*atan2(|v|, w)
     # about the unit axis v/|v|. The clamp only guards the measure-zero |v| == 0 case.
     v = q[..., 1:]
     v_norm = norm(v, dim=-1).unsqueeze(-1)
     angle = 2 * torch.atan2(v_norm, q[..., :1])
     so3_elem = angle * v / v_norm.clamp(min=torch.finfo(v_norm.dtype).eps)
     Rs = SO3.exp(so3_elem)

     try:
         mb['positions'] = (Rs @ positions.unsqueeze(-1)).squeeze(-1)
         outs3 = self.model(mb).detach().cpu().numpy()
     finally:
         # Restore the fixture so later tests see the unrotated coordinates.
         mb['positions'] = positions

     rerun_err = np.abs(outs2 - outs).mean() / np.abs(outs).mean()
     print('run through twice rel err:', rerun_err)
     rot_err = np.abs(outs2 - outs3).mean() / np.abs(outs2).mean()
     print('rotation equivariance rel err:', rot_err)
     self.assertLess(rot_err, tol)
Exemplo n.º 2
0
 def forward(self, inp, withquery=False):
     """Subsample points with farthest point sampling (FPS) indices.

     Args:
         inp: tuple ``(abq_pairs, vals, mask)``. As used below, dim 0 of each
             tensor is the batch, dims 1 and 2 of ``abq_pairs`` index points,
             and dim 1 of ``vals`` / ``mask`` indexes points.
         withquery: if True, also return the selected point indices.

     Returns:
         ``(abq_pairs, vals, mask)`` restricted to the selected points, with
         ``abq_pairs`` restricted on both point axes. If ``withquery`` is set,
         ``query_idx`` is appended (``None`` when ``self.ds_frac == 1``).

     When ``self.cache`` is set, the indices computed on the first call are
     stored in ``self.cached_indices`` and reused on every later call.
     NOTE(review): reuse assumes later inputs have the same batch layout;
     confirm with callers.
     """
     abq_pairs, vals, mask = inp
     dist = self.group.distance if self.group else lambda ab: norm(ab, dim=-1)
     if self.ds_frac == 1:
         # No downsampling requested: pass everything through untouched.
         return (abq_pairs, vals, mask, None) if withquery else (abq_pairs, vals, mask)

     if not self.cache:
         query_idx = FPSindices(dist(abq_pairs), self.ds_frac, mask)
     else:
         if self.cached_indices is None:
             self.cached_indices = FPSindices(
                 dist(abq_pairs), self.ds_frac, mask).detach()
         query_idx = self.cached_indices

     batch = torch.arange(query_idx.shape[0], device=query_idx.device,
                          dtype=torch.long)[:, None]
     # Select rows and columns in one step with adjacent advanced indices, so
     # sub_pairs[b, i, j] == abq_pairs[b, query_idx[b, i], query_idx[b, j]].
     # The two-step form abq_pairs[B, q][B, :, q] has advanced indices separated
     # by a slice; NumPy/PyTorch then move the broadcast index dims to the front,
     # which yields the TRANSPOSE of this block.
     # NOTE(review): checkpoints trained with the transposed layout may need
     # re-validation.
     sub_pairs = abq_pairs[batch[:, :, None], query_idx[:, :, None],
                           query_idx[:, None, :]]
     sub_vals = vals[batch, query_idx]
     sub_mask = mask[batch, query_idx]
     if withquery:
         return (sub_pairs, sub_vals, sub_mask, query_idx)
     return (sub_pairs, sub_vals, sub_mask)