def check_equivariance(self, atol: float = 1e-6, rtol: float = 1e-5) -> List[Tuple[Any, float]]:
    """Numerically test that this module commutes with ``transform_fibers``.

    A random input of shape ``(3, self.in_type.size, 10, 10)`` is wrapped in a
    ``GeometricTensor`` of type ``self.in_type``. For every element in
    ``self.space.testing_elements`` the module output is compared between
    "apply the module, then transform" and "transform, then apply the module".
    Error statistics for each element are printed to stdout.

    Args:
        atol: absolute tolerance forwarded to ``torch.allclose``.
        rtol: relative tolerance forwarded to ``torch.allclose``.

    Returns:
        A list with one ``(element, mean absolute error)`` pair per tested
        element.

    Raises:
        AssertionError: if the two outputs are not close for some element.
    """
    channels = self.in_type.size
    sample = GeometricTensor(torch.randn(3, channels, 10, 10), self.in_type)

    report = []
    for element in self.space.testing_elements:
        # Keep this evaluation order: module-then-transform first,
        # transform-then-module second.
        transformed_after = self(sample).transform_fibers(element)
        transformed_before = self(sample.transform_fibers(element))

        difference = transformed_after.tensor - transformed_before.tensor
        abs_err = np.abs(difference.detach().numpy()).reshape(-1)
        err_max, err_mean, err_var = abs_err.max(), abs_err.mean(), abs_err.var()

        print(element, err_max, err_mean, err_var)

        assert torch.allclose(
            transformed_after.tensor, transformed_before.tensor, atol=atol, rtol=rtol
        ), (
            'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}'
            .format(element, err_max, err_mean, err_var)
        )

        report.append((element, err_mean))

    return report
def check_equivariance(self, atol: float = 2e-6, rtol: float = 1e-5, full_space_action: bool = True) -> List[Tuple[Any, float]]:
    """Numerically test the equivariance of this module.

    When ``full_space_action`` is true the check is delegated to the parent
    class implementation. Otherwise a random input of shape
    ``(10, self.in_type.size, 9, 9)`` is wrapped in a ``GeometricTensor`` and,
    for every element in ``self.gspace.testing_elements``, the output of
    "apply the module, then ``transform_fibers``" is compared with
    "``transform_fibers``, then apply the module".

    Input/output sizes, representation names and per-element error statistics
    are printed to stdout. On a mismatch, the per-sample, per-channel maximum
    absolute error is printed as well before the assertion fails.

    Args:
        atol: absolute tolerance forwarded to ``torch.allclose``.
        rtol: relative tolerance forwarded to ``torch.allclose``.
        full_space_action: if true, use the parent class check instead of the
            fiber-only check implemented here.

    Returns:
        A list with one ``(element, mean absolute error)`` pair per tested
        element (or whatever the parent implementation returns when
        ``full_space_action`` is true).

    Raises:
        AssertionError: if the two outputs are not close for some element.
    """
    if full_space_action:
        return super(MultipleModule, self).check_equivariance(atol=atol, rtol=rtol)

    c = self.in_type.size
    x = torch.randn(10, c, 9, 9)

    print(c, self.out_type.size)
    print([r.name for r in self.in_type.representations])
    print([r.name for r in self.out_type.representations])

    x = GeometricTensor(x, self.in_type)

    errors = []
    for el in self.gspace.testing_elements:
        out1 = self(x).transform_fibers(el)
        out2 = self(x.transform_fibers(el))

        # Absolute error with the original tensor shape; reused for the
        # diagnostic dump below instead of being recomputed.
        abs_err = np.abs((out1.tensor - out2.tensor).detach().numpy())
        errs = abs_err.reshape(-1)
        print(el, errs.max(), errs.mean(), errs.var())

        # Evaluate the tolerance check once and reuse the result.
        close = torch.allclose(out1.tensor, out2.tensor, atol=atol, rtol=rtol)

        if not close:
            # Largest error over the trailing (spatial) axes for each
            # (sample, channel) pair.
            tmp = abs_err.reshape(abs_err.shape[0], abs_err.shape[1], -1).max(axis=2)
            # Scope the print options so that a failed check does not
            # permanently alter NumPy's process-wide print configuration.
            with np.printoptions(precision=2, threshold=200000000, suppress=True, linewidth=500):
                print(tmp.shape)
                print(tmp)

        assert close, (
            'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}'
            .format(el, errs.max(), errs.mean(), errs.var())
        )

        errors.append((el, errs.mean()))

    return errors