def tetris():
    """Build the tetris toy problem as one batched graph plus its target labels.

    Eight 4-point shapes are defined on an integer grid. Each shape is rotated by
    its own random rotation from ``o3.rand_matrix``, and all of them are collated
    into a single batch by a ``DataLoader`` whose batch size equals the dataset
    size.

    Returns
    -------
    data
        the first (and only) batch produced by the ``DataLoader``; each graph has
        ``pos`` (its rotated 4 points) and ``x = torch.ones(4, 1)``
    labels : torch.Tensor
        an 8 x 7 tensor of targets, one row per shape, in the default dtype
    """
    shapes = [
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],  # chiral_shape_1
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)],  # chiral_shape_2
        [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],  # square
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],  # line
        [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],  # corner
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],  # L
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],  # T
        [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],  # zigzag
    ]
    dtype = torch.get_default_dtype()
    coords = torch.tensor(shapes, dtype=dtype)

    # The two chiral shapes are mirror images of each other, so a single *odd*
    # scalar (+1 / -1 in the first column) is what tells them apart.
    labels = torch.tensor(
        [
            [+1, 0, 0, 0, 0, 0, 0],  # chiral_shape_1
            [-1, 0, 0, 0, 0, 0, 0],  # chiral_shape_2
            [0, 1, 0, 0, 0, 0, 0],  # square
            [0, 0, 1, 0, 0, 0, 0],  # line
            [0, 0, 0, 1, 0, 0, 0],  # corner
            [0, 0, 0, 0, 1, 0, 0],  # L
            [0, 0, 0, 0, 0, 1, 0],  # T
            [0, 0, 0, 0, 0, 0, 1],  # zigzag
        ],
        dtype=dtype,
    )

    # One random rotation per shape: rotated[z, a] = R[z] @ coords[z, a].
    rotations = o3.rand_matrix(len(coords))
    coords = torch.einsum('zij,zaj->zai', rotations, coords)

    # Wrap each shape as a torch_geometric graph and collate all of them at once.
    graphs = [Data(pos=shape_pos, x=torch.ones(4, 1)) for shape_pos in coords]
    loader = DataLoader(graphs, batch_size=len(graphs))
    data = next(iter(loader))
    return data, labels
def test_xyz(float_tolerance):
    """Check ``angles_to_xyz`` / ``xyz_to_angles`` against rotating the y unit vector.

    For 10 random rotation matrices this asserts that:
      * each matrix is orthogonal (``R @ R^T == I``),
      * ``angles_to_xyz`` of the first two angles from ``matrix_to_angles`` equals
        ``R @ [0, 1, 0]``,
      * ``xyz_to_angles`` recovers those same two angles from that vector.
    """
    rotations = o3.rand_matrix(10)
    identity_gap = rotations @ rotations.transpose(-1, -2) - torch.eye(3)
    assert identity_gap.abs().max() < float_tolerance

    # Only the first two angles matter for the image of a single vector.
    alpha, beta, _ = o3.matrix_to_angles(rotations)
    via_angles = o3.angles_to_xyz(alpha, beta)
    via_matrix = rotations @ torch.tensor([0, 1.0, 0])
    assert torch.allclose(via_angles, via_matrix, atol=float_tolerance)

    alpha_back, beta_back = o3.xyz_to_angles(via_matrix)
    assert (alpha - alpha_back).abs().max() < float_tolerance
    assert (beta - beta_back).abs().max() < float_tolerance
def test_conversions(float_tolerance):
    """Round-trip 100 random rotations through every rotation parametrization.

    Representations are indexed 0 = angles, 1 = matrix, 2 = axis-angle,
    3 = quaternion. The rotations start as matrices and walk ``route`` through
    the conversion table, ending as matrices again. The median element-wise
    deviation from the starting matrices must stay below ``float_tolerance``.
    """
    def splat(fn):
        # Some converters return tuples (e.g. the three angles from
        # matrix_to_angles); spread those into positional arguments.
        def call(value):
            return fn(*value) if isinstance(value, tuple) else fn(value)
        return call

    def unchanged(value):
        return value

    # table[src][dst] converts representation ``src`` into representation ``dst``.
    table = [
        [unchanged, splat(o3.angles_to_matrix), splat(o3.angles_to_axis_angle), splat(o3.angles_to_quaternion)],
        [splat(o3.matrix_to_angles), unchanged, splat(o3.matrix_to_axis_angle), splat(o3.matrix_to_quaternion)],
        [splat(o3.axis_angle_to_angles), splat(o3.axis_angle_to_matrix), unchanged, splat(o3.axis_angle_to_quaternion)],
        [splat(o3.quaternion_to_angles), splat(o3.quaternion_to_matrix), splat(o3.quaternion_to_axis_angle), unchanged],
    ]

    start = o3.rand_matrix(100)
    # Starts and ends at index 1 (matrix) and covers all 12 off-diagonal conversions.
    route = [1, 2, 3, 0, 2, 0, 3, 1, 3, 2, 1, 0, 1]

    current = start
    for src, dst in zip(route[:-1], route[1:]):
        current = table[src][dst](current)
    final = current

    # The median (not the max) of the deviation is what is bounded here.
    assert (start - final).abs().median() < float_tolerance
def test():
    """Check that ``InvariantPolynomial`` gives the same output for 10 random rotations.

    One 4-atom structure is rotated by 10 random rotation matrices, the copies
    are batched together, and the per-feature standard deviation of the model
    output across the batch must stay below 1e-5.

    The global default dtype is switched to float64 only for the duration of
    this test and is restored afterwards. Without the restore, the change leaks
    into every later test or caller (e.g. code that reads
    ``torch.get_default_dtype()``).
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        pos = torch.tensor([
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.5],
        ])

        # atom type
        z = torch.tensor([0, 1, 2, 2])

        # Each graph is the same structure under a different random rotation
        # (positions are row vectors, hence ``pos @ R.T``).
        dataset = [Data(pos=pos @ R.T, z=z) for R in o3.rand_matrix(10)]
        data = next(iter(DataLoader(dataset, batch_size=len(dataset))))

        f = InvariantPolynomial("0e+0o", num_z=3, lmax=3)

        out = f(data)

        # expect invariant output: identical rows across the rotated copies
        assert out.std(0).max() < 1e-5
    finally:
        torch.set_default_dtype(previous_dtype)
def equivariance_error(func, args_in, irreps_in=None, irreps_out=None, ntrials=1, do_parity=True, do_translation=True):
    r"""Get the maximum equivariance error for ``func`` over ``ntrials``

    Each trial randomizes the equivariant transformation tested.

    Parameters
    ----------
    func : callable
        the function to test
    args_in : list
        the original inputs to pass to ``func``.
    irreps_in : list of `e3nn.o3.Irreps` or `e3nn.o3.Irreps`
        the input irreps for each of the arguments in ``args_in``. If left as the default of ``None``, ``_get_io_irreps`` will be used to try to infer them. If a sequence is provided, valid elements are also the string ``'cartesian_points'``, which denotes that the corresponding input should be dealt with as cartesian points in 3D, and ``None``, which indicates that the argument should not be transformed.
    irreps_out : list of `e3nn.o3.Irreps` or `e3nn.o3.Irreps`
        the output irreps for each of the return values of ``func``. Accepts similar values to ``irreps_in``.
    ntrials : int
        run this many trials with random transforms
    do_parity : bool
        whether to test parity
    do_translation : bool
        whether to test translation for ``'cartesian_points'`` inputs. Ignored
        (treated as ``False``) when no input is marked ``'cartesian_points'``.

    Returns
    -------
    dictionary mapping tuples ``(parity_k, did_translate)`` to the largest
    absolute difference observed for that combination over all trials
    """
    irreps_in, irreps_out = _get_io_irreps(func, irreps_in=irreps_in, irreps_out=irreps_out)

    if do_parity:
        parity_ks = [0, 1]
    else:
        parity_ks = [0]

    if 'cartesian_points' not in irreps_in:
        # There's nothing to translate
        do_translation = False
    if do_translation:
        do_translation = [False, True]
    else:
        do_translation = [False]

    # Materialize the combinations. A bare ``itertools.product`` iterator is
    # exhausted after the first trial, so every later trial would silently run
    # no tests and ``ntrials > 1`` would have no effect.
    tests = list(itertools.product(parity_ks, do_translation))

    neg_inf = -float("Inf")
    biggest_errs = {}

    for _ in range(ntrials):
        for this_test in tests:
            parity_k, this_do_translate = this_test

            # Build a rotation matrix for point data
            rot_mat = o3.rand_matrix()
            # add parity: negating the matrix composes the rotation with inversion
            rot_mat *= (-1)**parity_k
            # build translation (a scalar 0. means "no translation")
            translation = 10 * torch.randn(1, 3, dtype=rot_mat.dtype) if this_do_translate else 0.

            # Evaluate the function on rotated arguments:
            rot_args = _transform(args_in, irreps_in, rot_mat, translation)
            x1 = func(*rot_args)

            # Evaluate the function on the arguments, then apply group action:
            x2 = func(*args_in)

            # Deal with output shapes
            assert type(x1) == type(x2), f"Inconsistant return types {type(x1)} and {type(x2)}"  # pylint: disable=unidiomatic-typecheck
            if isinstance(x1, torch.Tensor):
                # Make sequences
                x1 = [x1]
                x2 = [x2]
            elif isinstance(x1, (list, tuple)):
                # They're already tuples
                x1 = list(x1)
                x2 = list(x2)
            else:
                raise TypeError(f"equivariance_error cannot handle output type {type(x1)}")
            assert len(x1) == len(x2)
            assert len(x1) == len(irreps_out)

            # apply the group action to x2
            x2 = _transform(x2, irreps_out, rot_mat, translation)

            # Worst element-wise discrepancy across all outputs for this trial.
            error = max((a - b).abs().max() for a, b in zip(x1, x2))

            if error > biggest_errs.get(this_test, neg_inf):
                biggest_errs[this_test] = error

    return biggest_errs