Example #1
0
    def check_equivariance(self,
                           atol: float = 1e-6,
                           rtol: float = 1e-5) -> List[Tuple[Any, float]]:
        """Numerically check that this module commutes with the fiber transformations.

        A random input of shape ``(3, self.in_type.size, 10, 10)`` is wrapped in a
        ``GeometricTensor`` of type ``self.in_type``. For every element ``el`` in
        ``self.space.testing_elements``, ``self(x).transform_fibers(el)`` is compared
        with ``self(x.transform_fibers(el))``. The max / mean / variance of the
        absolute difference are printed for each element.

        Args:
            atol: absolute tolerance forwarded to ``torch.allclose``.
            rtol: relative tolerance forwarded to ``torch.allclose``.

        Returns:
            A list of ``(element, mean absolute error)`` pairs, one per testing
            element, in iteration order.

        Raises:
            AssertionError: if, for some element, the two outputs are not close
                within ``atol`` / ``rtol``. It is raised explicitly rather than via
                an ``assert`` statement so the check is not skipped under ``python -O``.
        """
        c = self.in_type.size

        x = torch.randn(3, c, 10, 10)

        x = GeometricTensor(x, self.in_type)

        errors = []

        for el in self.space.testing_elements:
            out1 = self(x).transform_fibers(el)
            out2 = self(x.transform_fibers(el))

            # .cpu() is a no-op for CPU tensors; it only matters if the outputs
            # end up on an accelerator, where .numpy() would otherwise fail.
            errs = (out1.tensor - out2.tensor).detach().cpu().numpy()
            errs = np.abs(errs).reshape(-1)
            err_max, err_mean, err_var = errs.max(), errs.mean(), errs.var()
            print(el, err_max, err_mean, err_var)

            if not torch.allclose(out1.tensor, out2.tensor, atol=atol, rtol=rtol):
                raise AssertionError(
                    'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}'
                    .format(el, err_max, err_mean, err_var)
                )

            errors.append((el, err_mean))

        return errors
Example #2
0
 def check_equivariance(self, atol: float = 2e-6, rtol: float = 1e-5, full_space_action: bool = True) -> List[Tuple[Any, float]]:
     """Numerically check the equivariance of this module.

     Args:
         atol: absolute tolerance forwarded to ``torch.allclose``.
         rtol: relative tolerance forwarded to ``torch.allclose``.
         full_space_action: if ``True``, delegate to the parent class's
             ``check_equivariance`` (resolved via ``super(MultipleModule, self)``).
             If ``False``, run the check implemented here instead. That check
             applies ``transform_fibers`` to a random ``(10, in_type.size, 9, 9)``
             input for each element of ``self.gspace.testing_elements`` and
             prints extra diagnostics (type sizes, representation names, and a
             per-(sample, channel) error table on failure).

     Returns:
         A list of ``(element, mean absolute error)`` pairs; with
         ``full_space_action=True`` this is whatever the parent method returns.

     Raises:
         AssertionError: in the local check, if for some element the outputs
             are not close within ``atol`` / ``rtol``. It is raised explicitly
             rather than via an ``assert`` statement so the check is not skipped
             under ``python -O``.
     """
     if full_space_action:
         return super(MultipleModule, self).check_equivariance(atol=atol, rtol=rtol)

     c = self.in_type.size

     x = torch.randn(10, c, 9, 9)
     print(c, self.out_type.size)
     print([r.name for r in self.in_type.representations])
     print([r.name for r in self.out_type.representations])
     x = GeometricTensor(x, self.in_type)

     errors = []

     for el in self.gspace.testing_elements:
         out1 = self(x).transform_fibers(el)
         out2 = self(x.transform_fibers(el))

         # Absolute difference, kept in its un-flattened shape so it can be
         # reused for the per-channel failure report below. .cpu() is a no-op
         # for CPU tensors.
         diff = np.abs((out1.tensor - out2.tensor).detach().cpu().numpy())
         errs = diff.reshape(-1)
         err_max, err_mean, err_var = errs.max(), errs.mean(), errs.var()
         print(el, err_max, err_mean, err_var)

         if not torch.allclose(out1.tensor, out2.tensor, atol=atol, rtol=rtol):
             # Max error per (axis 0, axis 1) pair, collapsing every remaining axis.
             per_channel = diff.reshape(diff.shape[0], diff.shape[1], -1).max(axis=2)

             # Scope the verbose print settings to this report instead of
             # overwriting NumPy's global print options for the whole process.
             with np.printoptions(precision=2, threshold=200000000, suppress=True, linewidth=500):
                 print(per_channel.shape)
                 print(per_channel)

             raise AssertionError(
                 'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}'
                 .format(el, err_max, err_mean, err_var)
             )

         errors.append((el, err_mean))

     return errors