def test_norm(irreps_in, squared):
    """Check that ``o3.Norm`` runs, is equivariant and is TorchScript-compatible."""
    norm_module = o3.Norm(irreps_in, squared=squared)
    # Smoke-test a single forward pass first (this also runs for empty irreps).
    norm_module(torch.randn(norm_module.irreps_in.dim))
    # With zero-dimensional input there is nothing further to check.
    if norm_module.irreps_in.dim != 0:
        assert_equivariant(norm_module)
        assert_auto_jitable(norm_module)
def test_jit(l1, p1, l2, p2, lo, po, mode, weight, special_code, opt_ein):
    """Test the JIT.

    This test is separate from ``test_optimizations`` to ensure that
    JIT-compiling a model on its own introduces minimal numerical error,
    if any.
    """
    orig_tp = make_tp(
        l1, p1, l2, p2, lo, po, mode, weight,
        _specialized_code=special_code,
        _optimize_einsums=opt_ein,
    )
    opt_tp = assert_auto_jitable(orig_tp)

    # Confirm equivariance of the compiled model.
    assert_equivariant(
        opt_tp,
        irreps_in=[orig_tp.irreps_in1, orig_tp.irreps_in2],
        irreps_out=orig_tp.irreps_out,
    )

    # Confirm that it gives the same results as the eager model.
    x1 = orig_tp.irreps_in1.randn(2, -1)
    x2 = orig_tp.irreps_in2.randn(2, -1)
    # TorchScript should cause very little, if any, numerical error,
    # so the default allclose tolerances are used.
    assert torch.allclose(orig_tp(x1, x2), opt_tp(x1, x2))
    assert torch.allclose(orig_tp.right(x2), opt_tp.right(x2))
def test_bias():
    """Exercise per-block output biases in ``o3.Linear``."""
    irreps_in = o3.Irreps("2x0e + 1e + 2x0e + 0o")
    irreps_out = o3.Irreps("3x0e + 1e + 3x0e + 5x0e + 0o")
    # One flag per output block: only the first 3x0e and the 5x0e are biased.
    bias_flags = [True, False, False, True, False]
    partial = o3.Linear(irreps_in, irreps_out, biases=bias_flags)

    with torch.no_grad():
        partial.bias[:].fill_(1.0)

    # A zero input isolates the bias: ones on the biased blocks, zeros elsewhere.
    out = partial(torch.zeros(irreps_in.dim))
    expected = torch.tensor(
        [1.0] * 3    # 3x0e (biased)
        + [0.0] * 3  # 1e
        + [0.0] * 3  # 3x0e
        + [1.0] * 5  # 5x0e (biased)
        + [0.0] * 1  # 0o
    )
    assert torch.allclose(out, expected)
    assert_equivariant(partial)
    assert_auto_jitable(partial)

    # biases=True on a mix of even/odd irreps.
    full = o3.Linear("0e + 0o + 1e + 1o", "10x0e + 0o + 1e + 1o", biases=True)
    assert_equivariant(full)
    assert_auto_jitable(full)
    assert_normalized(full, n_weight=100, n_input=10_000, atol=0.5, weights=[full.weight])
def test_full():
    """Smoke-test ``FullTensorProduct``: printable, equivariant and scriptable."""
    tp = FullTensorProduct(
        o3.Irreps("1e + 2e + 3x3o"),
        o3.Irreps("1e + 2x2e + 2x3o"),
    )
    print(tp)
    assert_equivariant(tp)
    assert_auto_jitable(tp)
def test_equivariance(act, lmax):
    """Check ``SO3Activation`` equivariance within a 0.04 tolerance over 10 trials."""
    activation = SO3Activation(lmax, lmax, act, 6)
    assert_equivariant(
        activation,
        ntrials=10,
        tolerance=0.04,
        irreps_in=so3_irreps(lmax),
        irreps_out=so3_irreps(lmax),
    )
def test_module(normalization, normalize):
    """Compare the scripted ``SphericalHarmonics`` module with the functional form."""
    irreps = o3.Irreps("0e + 1o + 3o")
    sh_module = o3.SphericalHarmonics(irreps, normalize, normalization)
    sh_scripted = assert_auto_jitable(sh_module)

    points = torch.randn(11, 3)
    actual = sh_scripted(points)
    expected = o3.spherical_harmonics(irreps, points, normalize, normalization)
    assert torch.allclose(actual, expected)

    assert_equivariant(sh_module)
def main():
    """Train the tetris example network, then check its equivariance.

    Runs 200 Adam steps on the tetris data, printing loss and accuracy every
    10 steps, then checks equivariance twice: directly, by comparing outputs on
    ``data`` with outputs on a second ``tetris()`` sample, and with the
    library's ``assert_equivariant`` helper.
    """
    data, labels = tetris()
    f = Network()

    print("Built a model:")
    print(f)

    optim = torch.optim.Adam(f.parameters(), lr=1e-3)

    # == Training ==
    for step in range(200):
        pred = f(data)
        loss = (pred - labels).pow(2).sum()

        optim.zero_grad()
        loss.backward()
        optim.step()

        if step % 10 == 0:
            # A sample counts as correct only if every rounded output matches its label.
            accuracy = pred.round().eq(labels).all(dim=1).double().mean(dim=0).item()
            print(
                f"epoch {step:5d} | loss {loss:<10.1f} | {100 * accuracy:5.1f}% accuracy"
            )

    # == Check equivariance ==
    # Because the model outputs (pseudo)scalars, we can easily directly
    # check its equivariance to the same data with new rotations:
    print("Testing equivariance directly...")
    rotated_data, _ = tetris()
    error = f(rotated_data) - f(data)
    print(f"Equivariance error = {error.abs().max().item():.1e}")

    print("Testing equivariance using `assert_equivariant`...")
    # We can also use the library's `assert_equivariant` helper.
    # `assert_equivariant` also tests parity and translation, and
    # can handle non-(pseudo)scalar outputs.
    # To interface between it and torch_geometric, we use a small wrapper:
    def wrapper(pos, batch):
        return f(Data(pos=pos, batch=batch))

    # `assert_equivariant` uses logging to print a summary of the equivariance error,
    # so we enable logging.
    logging.basicConfig(level=logging.INFO)
    assert_equivariant(
        wrapper,
        # We provide the original data that `assert_equivariant` will transform...
        args_in=[data.pos, data.batch],
        # ...in accordance with these irreps...
        irreps_in=[
            "cartesian_points",  # pos has vector 1o irreps, but is also translation equivariant
            None,  # `None` indicates invariant, possibly non-floating-point data
        ],
        # ...and confirm that the outputs transform correspondingly for these irreps:
        irreps_out=[f.irreps_out],
    )
def test_equivariant():
    """Check ``BatchNorm`` is equivariant in both training and evaluation modes."""
    irreps = o3.Irreps("3x0e + 3x0o + 4x1e")
    bn = BatchNorm(irreps)

    # Two forward passes on random batches before the checks
    # (presumably to populate running statistics used in eval mode).
    for _ in range(2):
        bn(irreps.randn(16, -1))

    for set_mode in (bn.train, bn.eval):
        set_mode()
        assert_equivariant(bn, irreps_in=[irreps], irreps_out=[irreps])
def test_linear():
    """Smoke-test ``o3.Linear``: forward, equivariance, TorchScript, normalization."""
    spec = "1e + 2e + 3x3o"
    irreps_in = o3.Irreps(spec)
    irreps_out = o3.Irreps(spec)
    lin = o3.Linear(irreps_in, irreps_out)

    lin(torch.randn(irreps_in.dim))
    assert_equivariant(lin)
    assert_auto_jitable(lin)
    assert_normalized(lin, n_weight=100, n_input=10_000, atol=0.5)
def test_id():
    """Smoke-test ``Identity``: printable, forward, equivariant, scriptable."""
    spec = "1e + 2e + 3x3o"
    irreps_in = o3.Irreps(spec)
    irreps_out = o3.Irreps(spec)
    identity = Identity(irreps_in, irreps_out)
    print(identity)

    identity(torch.randn(irreps_in.dim))
    assert_equivariant(identity)
    # Scripted with relaxed shape checking (strict_shapes=False).
    assert_auto_jitable(identity, strict_shapes=False)
def test_fully_connected():
    """Smoke-test ``FullyConnectedTensorProduct`` with identical irreps on all sides."""
    spec = "1e + 2e + 3x3o"
    irreps_in1 = o3.Irreps(spec)
    irreps_in2 = o3.Irreps(spec)
    irreps_out = o3.Irreps(spec)
    tp = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    print(tp)

    tp(torch.randn(irreps_in1.dim), torch.randn(irreps_in2.dim))
    assert_equivariant(tp)
    assert_auto_jitable(tp)
def test_assert_equivariant():
    """``assert_equivariant`` must raise for a function that is not equivariant."""

    def not_equivariant(x1, x2):
        # Plain elementwise product of two differently-typed feature vectors.
        return x1 * x2

    irreps_in1 = o3.Irreps("2x0e + 1x1e + 3x2o + 1x4e")
    irreps_in2 = o3.Irreps("2x0o + 3x0o + 3x2e + 1x4o")
    irreps_out = o3.Irreps("1x1e + 2x0o + 3x2e + 1x4o")
    # Attach the irreps as attributes for assert_equivariant to read.
    not_equivariant.irreps_in1 = irreps_in1
    not_equivariant.irreps_in2 = irreps_in2
    not_equivariant.irreps_out = irreps_out

    # All three must share a total dimension so the elementwise product is shape-valid.
    assert irreps_in1.dim == irreps_in2.dim
    assert irreps_in1.dim == irreps_out.dim

    with pytest.raises(AssertionError):
        assert_equivariant(not_equivariant)
def test_optimizations(l1, p1, l2, p2, lo, po, mode, weight, special_code, opt_ein, jit, float_tolerance):
    """Check that code specialization / einsum optimization (and optionally JIT) preserve results."""
    tp_args = (l1, p1, l2, p2, lo, po, mode, weight)
    orig_tp = make_tp(*tp_args, _specialized_code=False, _optimize_einsums=False)
    opt_tp = make_tp(*tp_args, _specialized_code=special_code, _optimize_einsums=opt_ein)

    # Copy only the weights: a state_dict would also carry things like Wigner
    # coefficients, which can differ between optimized and unoptimized TPs.
    with torch.no_grad():
        opt_tp.weight[:] = orig_tp.weight

    assert opt_tp._specialized_code == special_code
    assert opt_tp._optimize_einsums == opt_ein

    if jit:
        opt_tp = assert_auto_jitable(opt_tp)

    # The optimized model must still be equivariant...
    assert_equivariant(
        opt_tp,
        irreps_in=[orig_tp.irreps_in1, orig_tp.irreps_in2],
        irreps_out=orig_tp.irreps_out,
    )

    # ...and agree with the reference. Optimizations change the operations
    # performed, which can introduce real numerical error, hence the atol.
    x1 = orig_tp.irreps_in1.randn(2, -1)
    x2 = orig_tp.irreps_in2.randn(2, -1)
    assert torch.allclose(orig_tp(x1, x2), opt_tp(x1, x2), atol=float_tolerance)
    assert torch.allclose(orig_tp.right(x2), opt_tp.right(x2), atol=float_tolerance)

    # Moving to the other float dtype checks that every optimization stores
    # its constants in a way that `.to()` converts.
    default_dtype = torch.get_default_dtype()
    other_dtype = [d for d in (torch.float32, torch.float64) if d != default_dtype][0]
    x1 = x1.to(other_dtype)
    x2 = x2.to(other_dtype)
    opt_tp = opt_tp.to(other_dtype)
    assert opt_tp(x1, x2).dtype == other_dtype
def test_extract_single(squeeze):
    """Extract a single ``0e`` component, with and without ``squeeze_out``."""
    c = Extract('1e + 0e + 0e', ['0e'], [(1, )], squeeze_out=squeeze)
    out = c(torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]))
    if squeeze:
        # squeeze_out=True: the single output is returned as a bare tensor.
        assert isinstance(out, torch.Tensor)
    else:
        # Otherwise the outputs come back as a one-element sequence.
        assert len(out) == 1
        out = out[0]
    # torch.equal checks shape and values. A bare `out == tensor` comparison
    # broadcasts and, for more than one element, raises an ambiguous
    # truth-value error instead of a clean assertion failure.
    assert torch.equal(out, torch.tensor([1.0]))
    assert_auto_jitable(c)
    assert_equivariant(c, irreps_out=list(c.irreps_outs))
def test_equivariance(float_tolerance, act, normalization, p_val, p_arg):
    """Check ``S2Activation`` equivariance within sqrt(float_tolerance)."""
    irreps = io.SphericalTensor(3, p_val, p_arg)
    activation = S2Activation(
        irreps,
        act,
        120,
        normalization=normalization,
        lmax_out=6,
        random_rot=True,
    )
    # The tolerance is loosened to the square root of the dtype's float tolerance.
    tolerance = torch.sqrt(float_tolerance)
    assert_equivariant(activation, ntrials=10, tolerance=tolerance)
def test_equivariance(lmax, res_b, res_a):
    """Check equivariance of ToS2Grid -> pointwise exp -> FromS2Grid."""
    from_grid = FromS2Grid((res_b, res_a), lmax)
    to_grid = ToS2Grid(lmax, (res_b, res_a))

    def f(x):
        # Map to the grid, apply exp elementwise, map back.
        return from_grid(to_grid(x).exp())

    f.irreps_in = f.irreps_out = Irreps.spherical_harmonics(lmax)
    assert_equivariant(f)
def test_antisymmetric_matrix(float_tolerance):
    """``ReducedTensorProducts`` for an antisymmetric two-index tensor ('ij=-ji')."""
    tp = o3.ReducedTensorProducts('ij=-ji', i='5x0e + 1e')
    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_auto_jitable(tp)

    Q = tp.change_of_basis
    x = torch.randn(2, 5 + 3)  # two vectors, each of dimension 5 (5x0e) + 3 (1e)

    # The module must match an explicit contraction with its change of basis.
    direct = tp(*x)
    via_basis = torch.einsum('xij,i,j', Q, *x)
    assert (direct - via_basis).abs().max() < float_tolerance

    # The change of basis itself must be antisymmetric in (i, j).
    Q_swapped = torch.einsum("xij->xji", Q)
    assert (Q + Q_swapped).abs().max() < float_tolerance
def test_equivariant():
    """Confirm that a compiled ``Linear`` module is still equivariant."""
    irreps_in = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    mod = Linear(irreps_in, irreps_out)
    mod_script = compile(mod)
    assert_equivariant(
        mod_script,
        # Pass the irreps explicitly: inferring them from a script module is not reliable.
        irreps_in=irreps_in,
        irreps_out=irreps_out,
    )
def test_f():
    """``o3.Linear`` with extra f_in/f_out channel dimensions."""
    m = o3.Linear("0e + 1e + 2e", "0e + 2x1e + 2e", f_in=44, f_out=5, _optimize_einsums=False)
    # Inputs are (10, f_in=44, irreps_in.dim=1+3+5=9).
    input_shape = (10, 44, 9)

    assert_equivariant(m, args_in=[torch.randn(*input_shape)])
    m = assert_auto_jitable(m)
    y = m(torch.randn(*input_shape))

    # Paths 0e->0e, 1e->2x1e, 2e->2e give 1 + 2 + 1 = 4 weights per (f_in, f_out) pair.
    assert m.weight_numel == 4
    assert m.weight.numel() == 44 * 5 * 4
    # The output second moment should be close to 1.
    assert 0.7 < y.pow(2).mean() < 1.4
def test_gate():
    """Check ``_Sortcut`` and ``Gate`` with odd scalars/gates and mixed-parity gated vectors."""
    irreps_scalars = Irreps("16x0o")
    act_scalars = [torch.tanh]
    irreps_gates = Irreps("32x0o")
    act_gates = [torch.tanh]
    irreps_gated = Irreps("16x1e+16x1o")

    sortcut = _Sortcut(irreps_scalars, irreps_gates)
    assert_auto_jitable(sortcut)

    gate = Gate(irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated)
    assert_equivariant(gate)
    assert_auto_jitable(gate)
    assert_normalized(gate)
def test_reduce_tensor_antisymmetric_L2(float_tolerance):
    """``ReducedTensorProducts`` for 'ijk=-ikj=-jik' built from l=2 (2e) inputs."""
    tp = o3.ReducedTensorProducts('ijk=-ikj=-jik', i='2e')
    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_auto_jitable(tp)

    Q = tp.change_of_basis
    x = torch.randn(3, 5)  # three 2e vectors of dimension 5

    # The module must match an explicit contraction with its change of basis.
    assert (tp(*x) - torch.einsum('xijk,i,j,k', Q, *x)).abs().max() < float_tolerance

    # Q must flip sign under both declared index swaps.
    for swap in ("xijk->xikj", "xijk->xjik"):
        assert (Q + torch.einsum(swap, Q)).abs().max() < float_tolerance
def test_gate_points_2101_equivariant(network):
    """Check equivariance of the gate_points_2101 network with all nodes in one graph."""
    f, _ = network

    # -- Test equivariance: --
    def wrapper(pos, x, z):
        # Every node is assigned to batch 0.
        batch = torch.zeros(pos.shape[0], dtype=torch.long)
        return f(Data(pos=pos, x=x, z=z, batch=batch))

    assert_equivariant(
        wrapper,
        irreps_in=['cartesian_points', f.irreps_in, f.irreps_node_attr],
        irreps_out=[f.irreps_out],
    )
def test_norm():
    """``Norm`` yields |s| for each scalar and the Euclidean norm for each vector."""
    irreps_in = o3.Irreps("3x0e + 5x1o")
    scalars = torch.randn(3)
    vecs = torch.randn(5, 3)
    norm = Norm(irreps_in=irreps_in)

    # One feature row: the 3 scalars followed by the 5 * 3 vector components.
    features = torch.cat([scalars.reshape(1, -1), vecs.reshape(1, -1)], dim=-1)
    out_norms = norm(features)

    assert torch.allclose(out_norms[0, :3], scalars.abs())
    assert torch.allclose(out_norms[0, 3:], torch.linalg.norm(vecs, dim=-1))
    assert_equivariant(norm)
    assert_auto_jitable(norm)
def test_reduce_tensor_elasticity_tensor(float_tolerance):
    """``ReducedTensorProducts`` for the elasticity-tensor symmetry 'ijkl=jikl=klij'."""
    tp = o3.ReducedTensorProducts('ijkl=jikl=klij', i='1e')
    assert tp.irreps_out.dim == 21
    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_auto_jitable(tp)

    Q = tp.change_of_basis
    x = torch.randn(4, 3)  # four 1e vectors

    # The module must match an explicit contraction with its change of basis.
    assert (tp(*x) - torch.einsum('xijkl,i,j,k,l', Q, *x)).abs().max() < float_tolerance

    # Q must be unchanged by i<->j, k<->l and the pair swap (ij)<->(kl).
    for perm in ("xijkl->xjikl", "xijkl->xijlk", "xijkl->xklij"):
        assert (Q - torch.einsum(perm, Q)).abs().max() < float_tolerance
def test_activation(irreps_in, acts):
    """``Activation`` applies each act to its block, up to a batch-independent factor."""
    irreps_in = o3.Irreps(irreps_in)
    activation = Activation(irreps_in, acts)
    assert_auto_jitable(activation)
    assert_equivariant(activation)

    inp = irreps_in.randn(13, -1)
    out = activation(inp)
    for ir_slice, act in zip(irreps_in.slices(), acts):
        ratio = out[:, ir_slice] / act(inp[:, ir_slice])
        # Every sample in the batch must show the same per-component ratio.
        assert torch.allclose(ratio, ratio[0])

    assert_normalized(activation)
def test_dropout():
    """``Dropout`` is the identity in eval mode and zero-or-rescale in train mode."""
    dropout = Dropout(irreps='10x1e + 10x0e', p=0.75)
    x = dropout.irreps.randn(5, 2, -1)

    # Run the same checks on the eager module and on its scripted version.
    for module in [dropout, assert_auto_jitable(dropout)]:
        module.eval()
        assert module(x).eq(x).all()

        # With p=0.75 each entry is either dropped (0) or scaled by 1 / 0.25.
        module.train()
        y = module(x)
        assert (y.eq(x / 0.25) | y.eq(0)).all()

        def wrap(x):
            # Reseed before every call so each call draws the same dropout mask.
            torch.manual_seed(0)
            return module(x)

        assert_equivariant(wrap, args_in=[x], irreps_in=[module.irreps], irreps_out=[module.irreps])
def test():
    """End-to-end check of ``MessagePassing`` on a small random radius graph."""
    from torch_cluster import radius_graph
    from e3nn.util.test import assert_equivariant, assert_auto_jitable

    mp = MessagePassing(
        irreps_node_input="0e",
        irreps_node_hidden="0e + 1e",
        irreps_node_output="1e",
        irreps_node_attr="0e + 1e",
        irreps_edge_attr="1e",
        layers=3,
        fc_neurons=[2, 100],
        num_neighbors=3.0,
    )

    num_nodes = 4
    node_pos = torch.randn(num_nodes, 3)
    edge_index = radius_graph(node_pos, 3.0)
    edge_src, edge_dst = edge_index
    num_edges = edge_index.shape[1]
    # Edge attribute: position difference between the two endpoints.
    edge_attr = node_pos[edge_src] - node_pos[edge_dst]

    node_features = torch.randn(num_nodes, 1)
    node_attr = torch.randn(num_nodes, 4)
    edge_scalars = torch.randn(num_edges, 2)

    call_args = [node_features, node_attr, edge_src, edge_dst, edge_attr, edge_scalars]
    # Output irreps are "1e", i.e. 3 components per node.
    assert mp(*call_args).shape == (num_nodes, 3)

    assert_equivariant(
        mp,
        # Index tensors and edge scalars are passed as invariant (None).
        irreps_in=[mp.irreps_node_input, mp.irreps_node_attr, None, None, mp.irreps_edge_attr, None],
        args_in=list(call_args),
        irreps_out=[mp.irreps_node_output],
    )

    assert_auto_jitable(mp.layers[0].first)
def test_bilinear_right_variance_equivariance(float_tolerance, l1, p1, l2, p2, lo, po, mode, weight):
    """Check a tensor product for bilinearity, ``right`` consistency, variance and equivariance."""
    eps = float_tolerance
    n = 1_500
    tol = 3.0

    def max_err(a, b):
        return (a - b).abs().max()

    m = make_tp(l1, p1, l2, p2, lo, po, mode, weight)

    # -- bilinearity in each argument --
    x1 = torch.randn(2, m.irreps_in1.dim)
    x2 = torch.randn(2, m.irreps_in1.dim)
    y1 = torch.randn(2, m.irreps_in2.dim)
    y2 = torch.randn(2, m.irreps_in2.dim)

    combined = m(x1 + 1.7 * x2, y1 - y2)
    split_left = m(x1, y1 - y2) + 1.7 * m(x2, y1 - y2)
    split_right = m(x1 + 1.7 * x2, y1) - m(x1 + 1.7 * x2, y2)
    assert max_err(combined, split_left) < eps
    assert max_err(combined, split_right) < eps

    # -- right(): contracting x1 with right(y1) must reproduce m(x1, y1) --
    direct = m(x1, y1)
    via_right = torch.einsum('zi,zij->zj', x1, m.right(y1))
    assert max_err(direct, via_right) < eps

    # -- variance: mean output variance within a factor `tol` of 1 --
    x1 = torch.randn(n, m.irreps_in1.dim)
    y1 = torch.randn(n, m.irreps_in2.dim)
    out_var = m(x1, y1).var(0)
    assert out_var.mean().log10().abs() < torch.tensor(tol).log10()

    # -- equivariance --
    assert_equivariant(m, irreps_in=[m.irreps_in1, m.irreps_in2], irreps_out=m.irreps_out)

    if weight:
        # -- linearity in externally supplied weights --
        w1 = m.weight.clone().normal_()
        w2 = m.weight.clone().normal_()
        summed = m(x1, y1, weight=w1) + 1.5 * m(x1, y1, weight=w2)
        merged = m(x1, y1, weight=w1 + 1.5 * w2)
        assert max_err(summed, merged) < eps
def test_norm_activation_equivariant(do_bias, nonlin):
    """Check ``NormActivation`` over many irreps, optionally with random biases."""
    # A wide variety of irreps, both parities and several l values.
    irreps_in = e3nn.o3.Irreps(
        "2x0e + 3x0o + 5x1o + 1x1e + 2x2e + 1x2o + 1x3e + 1x3o + 1x5e + 1x6o"
    )
    norm_act = NormActivation(irreps_in=irreps_in, scalar_nonlinearity=nonlin, bias=do_bias)

    if do_bias:
        # The module is expected to have exactly one parameter; randomize the biases so they are nonzero.
        assert len(list(norm_act.parameters())) == 1
        with torch.no_grad():
            norm_act.biases[:] = torch.randn(norm_act.biases.shape)

    assert_equivariant(norm_act)
    assert_auto_jitable(norm_act)
def test_extract_ir():
    """``ExtractIr`` pulls out every ``0e`` component, in order."""
    c = ExtractIr('1e + 0e + 0e', '0e')
    out = c(torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]))
    # torch.equal checks shape as well as values; `torch.all(out == ...)`
    # broadcasts and could accept a wrongly shaped output.
    assert torch.equal(out, torch.tensor([1.0, 2.0]))
    assert_auto_jitable(c)
    assert_equivariant(c)