def check_equivariance(self, atol: float = 1e-7, rtol: float = 1e-5) -> List[Tuple[Any, float]]:
    r"""
    Method that automatically tests the equivariance of the current module when its
    output lives on a restriction of the input's gspace.

    A random input of shape ``(3, C, 10, 10)``, with ``C = self.in_type.size``, is wrapped
    in a :class:`GeometricTensor` of type ``self.in_type``. For every element ``el`` in
    ``self.out_type.testing_elements``, the output transformed by ``el`` is compared
    with the output computed on the input transformed by ``parent_mapping(el)``.
    Here ``parent_mapping`` is the second value returned by
    ``self.in_type.gspace.restrict(self._id)``.

    Args:
        atol (float): absolute tolerance passed to ``np.allclose``
        rtol (float): relative tolerance passed to ``np.allclose``

    Returns:
        a list containing, for each testing element, a pair with that element and the
        corresponding mean absolute equivariance error

    Raises:
        AssertionError: if the two outputs are not close for some testing element

    """
    # Only the middle value returned by ``restrict`` is needed. Judging by its use
    # below, it presumably maps elements of the restricted group to the parent group.
    _, parent_mapping, _ = self.in_type.gspace.restrict(self._id)

    c = self.in_type.size
    x = GeometricTensor(torch.randn(3, c, 10, 10), self.in_type)

    errors = []
    for el in self.out_type.testing_elements:
        # ``el`` belongs to the output (restricted) group, while the input is acted on
        # by the parent group, hence the ``parent_mapping`` on the input side.
        out1 = self(x).transform(el).tensor.detach().numpy()
        out2 = self(x.transform(parent_mapping(el))).tensor.detach().numpy()

        errs = np.abs(out1 - out2).reshape(-1)
        print(el, errs.max(), errs.mean(), errs.var())

        assert np.allclose(out1, out2, atol=atol, rtol=rtol), \
            'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}' \
            .format(el, errs.max(), errs.mean(), errs.var())

        errors.append((el, errs.mean()))

    return errors
def check_equivariance(self, atol: float = 1e-7, rtol: float = 1e-5) -> List[Tuple[Any, float]]:
    r"""
    Method that automatically tests the equivariance of the current module.

    The default implementation of this method relies on
    :meth:`e2cnn.nn.GeometricTensor.transform` and uses the group elements in
    :attr:`~e2cnn.nn.FieldType.testing_elements`.

    This method can be overridden for custom tests.

    Args:
        atol (float): absolute tolerance passed to ``np.allclose``
        rtol (float): relative tolerance passed to ``np.allclose``

    Returns:
        a list containing, for each testing element, a pair with that element and the
        corresponding mean absolute equivariance error

    Raises:
        AssertionError: if the two outputs are not close for some testing element

    """
    c = self.in_type.size
    x = GeometricTensor(torch.randn(3, c, 10, 10), self.in_type)

    errors = []
    for el in self.out_type.testing_elements:
        # Equivariance: f(x) transformed by el must match f(x transformed by el).
        out1 = self(x).transform(el).tensor.detach().numpy()
        out2 = self(x.transform(el)).tensor.detach().numpy()

        errs = np.abs(out1 - out2).reshape(-1)
        print(el, errs.max(), errs.mean(), errs.var())

        assert np.allclose(out1, out2, atol=atol, rtol=rtol), \
            'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}'\
            .format(el, errs.max(), errs.mean(), errs.var())

        errors.append((el, errs.mean()))

    return errors
def check_equivariance(self, atol: float = 0.1, rtol: float = 0.1, assertion: bool = True, verbose: bool = True):
    r"""
    Test the equivariance of the module on a down-sampled natural image.

    The image ``'../group/testimage.jpeg'`` is loaded; the path is relative to the
    current working directory. It is scaled to [-0.5, 0.5] and resized to a square of
    side ``(33 * 5 - 1 + self.kernel_size) * 5``. Its channels are tiled or truncated
    to ``self.in_type.size``.

    For every element ``el`` of ``self.space.testing_elements``, two outputs are computed:

    * the module applied to the average-pooled image, then transformed by ``el``;
    * the module applied to the image transformed by ``el`` and then average-pooled.

    Both outputs are average-pooled again. Only the entries inside a central disc of
    radius ``h / 4`` are compared.

    An entry fails when ``|out1 - out2| >= rtol * max(|out1|, |out2|) + atol``.

    Args:
        atol (float): absolute tolerance
        rtol (float): relative tolerance, scaled by ``max(|out1|, |out2|)`` entry-wise
        assertion (bool): if ``True``, raise ``AssertionError`` when any entry fails
        verbose (bool): if ``True``, print error statistics for every element, and the
            failing values when any entry fails

    Returns:
        a list containing, for each testing element, a pair with that element and the
        mean absolute error inside the central disc

    """
    import matplotlib.image as mpimg
    from skimage.measure import block_reduce
    from skimage.transform import resize

    feature_map_size = 33
    last_downsampling = 5
    first_downsampling = 5
    initial_size = (feature_map_size * last_downsampling - 1 + self.kernel_size) * first_downsampling

    n_channels = self.in_type.size

    # (H, W, channels) -> (1, channels, H, W), keeping at most ``n_channels`` channels.
    x = mpimg.imread('../group/testimage.jpeg').transpose((2, 0, 1))[np.newaxis, 0:n_channels, :, :]
    # Normalize BEFORE resizing. JPEGs are read as uint8, and ``resize`` (with the
    # default ``preserve_range=False``) rescales integer images to [0, 1] by itself.
    # Dividing by 255 after resizing therefore collapsed the image to ~-0.5.
    x = x / 255.0
    x = resize(x, (x.shape[0], x.shape[1], initial_size, initial_size), anti_aliasing=True)
    x = x - 0.5

    # Tile the available channels until there are exactly ``n_channels`` of them.
    if x.shape[1] < n_channels:
        to_stack = [x] * (n_channels // x.shape[1])
        if n_channels % x.shape[1] > 0:
            to_stack.append(x[:, :(n_channels % x.shape[1]), ...])
        x = np.concatenate(to_stack, axis=1)

    x = GeometricTensor(torch.FloatTensor(x), self.in_type)

    def shrink(t: GeometricTensor, s) -> GeometricTensor:
        # Average-pool the underlying tensor with block shape ``s``, keeping the field type.
        return GeometricTensor(
            torch.FloatTensor(block_reduce(t.tensor.detach().numpy(), s, func=np.mean)),
            t.type)

    first_block = (1, 1, first_downsampling, first_downsampling)
    last_block = (1, 1, last_downsampling, last_downsampling)

    errors = []
    for el in self.space.testing_elements:
        out1 = self(shrink(x, first_block)).transform(el).tensor.detach().numpy()
        out2 = self(shrink(x.transform(el), first_block)).tensor.detach().numpy()

        out1 = block_reduce(out1, last_block, func=np.mean)
        out2 = block_reduce(out2, last_block, func=np.mean)

        # Keep only a central disc of radius h / 4. Presumably this avoids border
        # effects, where transformed and untransformed images are not expected to
        # match (verify).
        _, _, h, w = out2.shape
        rows, cols = np.ogrid[:h, :w]
        center_mask = (rows - h / 2) ** 2 + (cols - w / 2) ** 2 < (h / 4) ** 2

        out1 = out1[..., center_mask].reshape(-1)
        out2 = out2[..., center_mask].reshape(-1)

        errs = np.abs(out1 - out2)
        esum = np.maximum(np.abs(out1), np.abs(out2))
        # Avoid division by zero. Where both outputs are 0, errs is 0 as well.
        esum[esum == 0.0] = 1
        relerr = errs / esum
        tol = rtol * esum + atol

        # One criterion for both diagnostics and assertion. ``~(errs < tol)`` also
        # flags NaNs and entries with errs == tol exactly.
        failed = ~(errs < tol)

        if verbose:
            print(el, relerr.max(), relerr.mean(), relerr.var(),
                  errs.max(), errs.mean(), errs.var())
            if np.any(failed):
                print(out1[failed])
                print(out2[failed])
                print(tol[failed])

        if assertion:
            assert not np.any(
                failed
            ), 'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}'.format(
                el, errs.max(), errs.mean(), errs.var())

        errors.append((el, errs.mean()))

    return errors
def check_equivariance(self, atol: float = 0.1, rtol: float = 0.1):
    r"""
    Test the equivariance of the module on a small natural image.

    The image ``'../group/testimage.jpeg'`` is loaded; the path is relative to the
    current working directory. It is scaled to [0, 1] and resized to 55x55. Its
    channels are tiled or truncated to ``self.in_type.size``.

    For every element ``el`` of ``self.space.testing_elements``, two outputs are compared:

    * the module output transformed by ``el``;
    * the module output on the input transformed by ``el``.

    The comparison is restricted to a central disc of radius ``0.4 * h``.

    An entry fails when ``|out1 - out2| >= rtol * max(|out1|, |out2|) + atol``. Failing
    values are printed before the assertion is raised.

    Args:
        atol (float): absolute tolerance
        rtol (float): relative tolerance, scaled by ``max(|out1|, |out2|)`` entry-wise

    Returns:
        a list containing, for each testing element, a pair with that element and the
        mean absolute error inside the central disc

    Raises:
        AssertionError: if any compared entry fails for some testing element

    """
    import matplotlib.image as mpimg
    from skimage.transform import resize

    initial_size = 55
    n_channels = self.in_type.size

    # (H, W, channels) -> (1, channels, H, W), keeping at most ``n_channels`` channels.
    x = mpimg.imread('../group/testimage.jpeg').transpose((2, 0, 1))[np.newaxis, 0:n_channels, :, :]
    # Scale to [0, 1] before resizing, so ``resize`` sees a float image and keeps its range.
    x = x / 255.0
    x = resize(x, (x.shape[0], x.shape[1], initial_size, initial_size), anti_aliasing=True)

    # Tile the available channels until there are exactly ``n_channels`` of them.
    if x.shape[1] < n_channels:
        to_stack = [x] * (n_channels // x.shape[1])
        if n_channels % x.shape[1] > 0:
            to_stack.append(x[:, :(n_channels % x.shape[1]), ...])
        x = np.concatenate(to_stack, axis=1)

    x = GeometricTensor(torch.FloatTensor(x), self.in_type)

    errors = []
    for el in self.space.testing_elements:
        out1 = self(x).transform(el).tensor.detach().numpy()
        out2 = self(x.transform(el)).tensor.detach().numpy()

        # Keep only a central disc of radius 0.4 * h. Presumably this avoids border
        # effects of the transformation (verify).
        _, _, h, w = out2.shape
        rows, cols = np.ogrid[:h, :w]
        center_mask = (rows - h / 2) ** 2 + (cols - w / 2) ** 2 < (h * 0.4) ** 2

        out1 = out1[..., center_mask].reshape(-1)
        out2 = out2[..., center_mask].reshape(-1)

        errs = np.abs(out1 - out2)
        esum = np.maximum(np.abs(out1), np.abs(out2))
        # Avoid division by zero. Where both outputs are 0, errs is 0 as well.
        esum[esum == 0.0] = 1
        relerr = errs / esum
        tol = rtol * esum + atol

        # One criterion for both diagnostics and assertion, so a failing assertion
        # is always preceded by its diagnostic output.
        failed = ~(errs < tol)

        if np.any(failed):
            print(el, relerr.max(), relerr.mean(), relerr.var(), errs.max(), errs.mean(), errs.var())
            print(out1[failed])
            print(out2[failed])
            print(tol[failed])

        assert not np.any(
            failed
        ), 'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}'.format(
            el, errs.max(), errs.mean(), errs.var())

        errors.append((el, errs.mean()))

    return errors