def test_norm(irreps_in, squared):
    """Run ``o3.Norm`` once, then check equivariance and TorchScript support.

    Modules whose input dimension is zero only get the smoke-test forward call.
    """
    norm_module = o3.Norm(irreps_in, squared=squared)
    input_dim = norm_module.irreps_in.dim
    norm_module(torch.randn(input_dim))

    if input_dim != 0:
        assert_equivariant(norm_module)
        assert_auto_jitable(norm_module)
def test_bias():
    """Check that ``o3.Linear`` adds biases only to the requested output blocks."""
    irreps_in = o3.Irreps("2x0e + 1e + 2x0e + 0o")
    irreps_out = o3.Irreps("3x0e + 1e + 3x0e + 5x0e + 0o")
    bias_mask = [True, False, False, True, False]
    linear = o3.Linear(irreps_in, irreps_out, biases=bias_mask)

    with torch.no_grad():
        linear.bias[:].fill_(1.0)

    # A zero input leaves only the biases: the first 3x0e block and the 5x0e
    # block are ones; the 1e, second 3x0e and 0o blocks stay zero.
    expected = torch.tensor([1.0] * 3 + [0.0] * 3 + [0.0] * 3 + [1.0] * 5 + [0.0])
    out = linear(torch.zeros(irreps_in.dim))
    assert torch.allclose(out, expected)
    assert_equivariant(linear)
    assert_auto_jitable(linear)

    # Second case: biases requested with a single boolean.
    linear = o3.Linear("0e + 0o + 1e + 1o", "10x0e + 0o + 1e + 1o", biases=True)
    assert_equivariant(linear)
    assert_auto_jitable(linear)
    assert_normalized(
        linear,
        n_weight=100,
        n_input=10_000,
        atol=0.5,
        weights=[linear.weight],
    )
def test_full():
    """``FullTensorProduct`` should print, be equivariant and be TorchScript-compatible."""
    tensor_product = FullTensorProduct(
        o3.Irreps("1e + 2e + 3x3o"),
        o3.Irreps("1e + 2x2e + 2x3o"),
    )
    print(tensor_product)
    assert_equivariant(tensor_product)
    assert_auto_jitable(tensor_product)
def test_id():
    """``Identity`` between equal irreps runs, is equivariant, and compiles."""
    spec = "1e + 2e + 3x3o"
    irreps_in = o3.Irreps(spec)
    identity = Identity(irreps_in, o3.Irreps(spec))
    print(identity)
    identity(torch.randn(irreps_in.dim))
    assert_equivariant(identity)
    assert_auto_jitable(identity, strict_shapes=False)
def test_input_weights_jit():
    """Check externally supplied weights for eager and traced ``FullyConnectedTensorProduct``.

    The first half uses ``shared_weights=False``, where weights carry batch
    dimensions (e.g. ``(2, weight_numel)``). The second half uses
    ``shared_weights=True``, where a single ``(weight_numel,)`` vector is used.
    Both halves check that:
      * calling without weights raises,
      * ill-shaped weights are rejected,
      * eager and traced modules agree.
    The first half also checks broadcasting over unusual batch shapes.
    """
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    # - shared_weights = False -
    m = FullyConnectedTensorProduct(irreps_in1,
                                    irreps_in2,
                                    irreps_out,
                                    internal_weights=False,
                                    shared_weights=False)
    traced = assert_auto_jitable(m)
    x1 = irreps_in1.randn(2, -1)
    x2 = irreps_in2.randn(2, -1)
    w = torch.randn(2, m.weight_numel)
    with pytest.raises((RuntimeError, torch.jit.Error)):
        m(x1, x2)  # it should require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2)  # it should also require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2, w[0])  # it should reject insufficient weights
    # Does the trace give right results?
    assert torch.allclose(m(x1, x2, w), traced(x1, x2, w))

    # Confirm that weird batch dimensions give the same results
    for f in (m, traced):
        # Batch shapes (2, 1, 4), (2, 3, 1) and (3, 4) broadcast to (2, 3, 4);
        # the result must match running the explicitly expanded, flattened batch of 24.
        x1 = irreps_in1.randn(2, 1, 4, -1)
        x2 = irreps_in2.randn(2, 3, 1, -1)
        w = torch.randn(3, 4, f.weight_numel)
        assert torch.allclose(
            f(x1, x2, w).reshape(24, -1),
            f(
                x1.expand(2, 3, 4, -1).reshape(24, -1),
                x2.expand(2, 3, 4, -1).reshape(24, -1),
                w[None].expand(2, 3, 4, -1).reshape(24, -1)))
        assert torch.allclose(
            f.right(x2, w).reshape(24, -1),
            f.right(
                x2.expand(2, 3, 4, -1).reshape(24, -1),
                w[None].expand(2, 3, 4, -1).reshape(24, -1)).reshape(24, -1))

    # - shared_weights = True -
    # NOTE(review): x1/x2 below are whatever the last loop iteration left behind
    # (batch shapes (2, 1, 4) and (2, 3, 1)), not the (2, -1) inputs created
    # above. Confirm this is intended.
    m = FullyConnectedTensorProduct(irreps_in1,
                                    irreps_in2,
                                    irreps_out,
                                    internal_weights=False,
                                    shared_weights=True)
    traced = assert_auto_jitable(m)
    w = torch.randn(m.weight_numel)
    with pytest.raises((RuntimeError, torch.jit.Error)):
        m(x1, x2)  # it should require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2)  # it should also require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2, torch.randn(
            2, m.weight_numel))  # it should reject too many weights
    # Does the trace give right results?
    assert torch.allclose(m(x1, x2, w), traced(x1, x2, w))
def test_linear():
    """``o3.Linear`` between equal irreps runs, is equivariant, compiles, and is normalized."""
    spec = "1e + 2e + 3x3o"
    irreps_in = o3.Irreps(spec)
    linear = o3.Linear(irreps_in, o3.Irreps(spec))
    linear(torch.randn(irreps_in.dim))
    assert_equivariant(linear)
    assert_auto_jitable(linear)
    assert_normalized(linear, n_weight=100, n_input=10_000, atol=0.5)
def test_fully_connected():
    """``FullyConnectedTensorProduct`` prints, runs, is equivariant, and compiles."""
    spec = "1e + 2e + 3x3o"
    irreps_in1 = o3.Irreps(spec)
    irreps_in2 = o3.Irreps(spec)
    irreps_out = o3.Irreps(spec)

    tensor_product = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    print(tensor_product)

    sample_inputs = (torch.randn(irreps_in1.dim), torch.randn(irreps_in2.dim))
    tensor_product(*sample_inputs)
    assert_equivariant(tensor_product)
    assert_auto_jitable(tensor_product)
def test_extract_single(squeeze):
    """Extract irreps entry 1 (the first ``0e``).

    With ``squeeze_out`` the result is a bare tensor; otherwise it is a
    length-1 container.
    """
    extractor = Extract('1e + 0e + 0e', ['0e'], [(1, )], squeeze_out=squeeze)
    result = extractor(torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]))

    if not squeeze:
        assert len(result) == 1
        (result, ) = result
    else:
        assert isinstance(result, torch.Tensor)

    assert result == torch.Tensor([1.])
    assert_auto_jitable(extractor)
    assert_equivariant(extractor, irreps_out=list(extractor.irreps_outs))
def test_trace_dtypes():
    """A (traced) ``FullyConnectedNet`` converted to the non-default float dtype still compiles."""
    net = FullyConnectedNet([8, 16, 8])

    # Swap float32 <-> float64 relative to the current default dtype.
    swap = {
        torch.float32: torch.float64,
        torch.float64: torch.float32,
    }
    other_dtype = swap[torch.get_default_dtype()]

    net = net.to(dtype=other_dtype)
    for param in net.parameters():
        assert param.dtype == other_dtype
    assert_auto_jitable(net)
def test_antisymmetric_matrix(float_tolerance):
    """``ReducedTensorProducts('ij=-ji')`` matches its change of basis, which is antisymmetric."""
    tp = o3.ReducedTensorProducts('ij=-ji', i='5x0e + 1e')
    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_auto_jitable(tp)

    Q = tp.change_of_basis
    # '5x0e + 1e' has dimension 5 + 3
    x_i, x_j = torch.randn(2, 5 + 3)
    reconstructed = torch.einsum('xij,i,j', Q, x_i, x_j)
    assert (tp(x_i, x_j) - reconstructed).abs().max() < float_tolerance

    Q_swapped = torch.einsum("xij->xji", Q)
    assert (Q + Q_swapped).abs().max() < float_tolerance
def test_gate():
    """``Gate`` and its helper ``_Sortcut`` compile; ``Gate`` is also equivariant and normalized."""
    irreps_scalars = Irreps("16x0o")
    act_scalars = [torch.tanh]
    irreps_gates = Irreps("32x0o")
    act_gates = [torch.tanh]
    irreps_gated = Irreps("16x1e+16x1o")

    sortcut = _Sortcut(irreps_scalars, irreps_gates)
    assert_auto_jitable(sortcut)

    gate = Gate(irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated)
    assert_equivariant(gate)
    assert_auto_jitable(gate)
    assert_normalized(gate)
def test_reduce_tensor_antisymmetric_L2(float_tolerance):
    """Check a rank-3 ``2e`` tensor product antisymmetric under j<->k and i<->j."""
    tp = o3.ReducedTensorProducts('ijk=-ikj=-jik', i='2e')
    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_auto_jitable(tp)

    Q = tp.change_of_basis
    x = torch.randn(3, 5)  # three '2e' inputs of dimension 5
    assert (tp(*x) - torch.einsum('xijk,i,j,k', Q, *x)).abs().max() < float_tolerance

    for permutation in ("xijk->xikj", "xijk->xjik"):
        assert (Q + torch.einsum(permutation, Q)).abs().max() < float_tolerance
def test_norm():
    """``Norm`` returns |s| for each scalar and the Euclidean length for each vector."""
    irreps_in = o3.Irreps("3x0e + 5x1o")
    scalars = torch.randn(3)
    vecs = torch.randn(5, 3)
    norm = Norm(irreps_in=irreps_in)

    # Single-sample batch laid out as [3 scalars, 5 flattened vectors].
    features = torch.cat((scalars, vecs.flatten()), dim=-1).unsqueeze(0)
    norms = norm(features)[0]

    assert torch.allclose(norms[:3], torch.abs(scalars))
    assert torch.allclose(norms[3:], torch.linalg.norm(vecs, dim=-1))
    assert_equivariant(norm)
    assert_auto_jitable(norm)
def test_jit_trace():
    """``assert_auto_jitable`` must fail for a trace-mode module whose output depends on input shape."""

    @compile_mode('trace')
    class NotTracable(torch.nn.Module):
        def forward(self, param):
            # Shape-dependent Python branching: a trace records only the branch taken.
            return torch.ones(8) if param.shape[0] == 7 else torch.randn(8, 3)

    module = NotTracable()
    module.irreps_in = o3.Irreps("2x0e")
    module.irreps_out = o3.Irreps("1x1o")

    # TorchScript raises assorted exception types here, so accept any Exception.
    with pytest.raises(Exception):
        assert_auto_jitable(module)
def test_reduce_tensor_elasticity_tensor(float_tolerance):
    """Check the reduced basis of a rank-4 elasticity tensor.

    The basis must have 21 independent components and satisfy
    ``ijkl = jikl = klij``.
    """
    tp = o3.ReducedTensorProducts('ijkl=jikl=klij', i='1e')
    assert tp.irreps_out.dim == 21
    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_auto_jitable(tp)

    Q = tp.change_of_basis
    x = torch.randn(4, 3)  # four '1e' vectors
    assert (tp(*x) - torch.einsum('xijkl,i,j,k,l', Q, *x)).abs().max() < float_tolerance

    for symmetry in ("xijkl->xjikl", "xijkl->xijlk", "xijkl->xklij"):
        assert (Q - torch.einsum(symmetry, Q)).abs().max() < float_tolerance
def test_activation(irreps_in, acts):
    """Check that ``Activation`` applies each act to its block, up to a batch-constant factor."""
    irreps_in = o3.Irreps(irreps_in)
    activation = Activation(irreps_in, acts)
    assert_auto_jitable(activation)
    assert_equivariant(activation)

    inputs = irreps_in.randn(13, -1)
    outputs = activation(inputs)
    for block, act in zip(irreps_in.slices(), acts):
        # Output / act(input) must be the same for every row of the batch.
        ratio = outputs[:, block] / act(inputs[:, block])
        assert torch.allclose(ratio, ratio[0])

    assert_normalized(activation)
def test_jit(l1, p1, l2, p2, lo, po, mode, weight, special_code, opt_ein):
    """Test the JIT.

    This test is separate from test_optimizations to ensure that just jitting
    a model introduces minimal error, if any.
    """
    eager_tp = make_tp(l1, p1, l2, p2, lo, po, mode, weight,
                       _specialized_code=special_code,
                       _optimize_einsums=opt_ein)
    compiled_tp = assert_auto_jitable(eager_tp)

    # The compiled model must still be equivariant.
    assert_equivariant(compiled_tp,
                       irreps_in=[eager_tp.irreps_in1, eager_tp.irreps_in2],
                       irreps_out=eager_tp.irreps_out)

    # Same inputs, same outputs: TorchScript should cause very little
    # numerical error, if any.
    x1 = eager_tp.irreps_in1.randn(2, -1)
    x2 = eager_tp.irreps_in2.randn(2, -1)
    assert torch.allclose(eager_tp(x1, x2), compiled_tp(x1, x2))
    assert torch.allclose(eager_tp.right(x2), compiled_tp.right(x2))
def test_jit(float_tolerance):
    """A TorchScript-compiled ``SphericalHarmonicsAlphaBeta`` must match eager mode.

    The comparison is asserted, not returned. pytest ignores a test's return
    value (recent versions only emit ``PytestReturnNotNoneWarning``), so a
    returned ``False`` would never fail the test.
    """
    sh = o3.SphericalHarmonicsAlphaBeta([0, 1, 2])
    jited = assert_auto_jitable(sh)

    alpha = torch.randn(5, 4)
    beta = torch.randn(5, 4)
    max_error = (sh(alpha, beta) - jited(alpha, beta)).abs().max()
    assert max_error < float_tolerance
def test_module(normalization, normalize):
    """Check the compiled ``o3.SphericalHarmonics`` module.

    It must agree with the functional ``o3.spherical_harmonics``, and the eager
    module must be equivariant.
    """
    irreps = o3.Irreps("0e + 1o + 3o")
    sh_module = o3.SphericalHarmonics(irreps, normalize, normalization)
    sh_compiled = assert_auto_jitable(sh_module)

    xyz = torch.randn(11, 3)
    expected = o3.spherical_harmonics(irreps, xyz, normalize, normalization)
    assert torch.allclose(sh_compiled(xyz), expected)
    assert_equivariant(sh_module)
def test():
    """End-to-end check of ``MessagePassing`` on a small random radius graph.

    Covers output shape, equivariance, and TorchScript support of the first
    layer's ``first`` submodule.
    """
    from torch_cluster import radius_graph
    from e3nn.util.test import assert_equivariant, assert_auto_jitable

    mp = MessagePassing(
        irreps_node_input="0e",
        irreps_node_hidden="0e + 1e",
        irreps_node_output="1e",
        irreps_node_attr="0e + 1e",
        irreps_edge_attr="1e",
        layers=3,
        fc_neurons=[2, 100],
        num_neighbors=3.0,
    )

    num_nodes = 4
    node_pos = torch.randn(num_nodes, 3)
    edge_index = radius_graph(node_pos, 3.0)
    edge_src, edge_dst = edge_index
    num_edges = edge_index.shape[1]
    edge_attr = node_pos[edge_src] - node_pos[edge_dst]

    node_features = torch.randn(num_nodes, 1)  # "0e" -> dim 1
    node_attr = torch.randn(num_nodes, 4)  # "0e + 1e" -> dim 4
    edge_scalars = torch.randn(num_edges, 2)  # matches fc_neurons[0] == 2

    inputs = [node_features, node_attr, edge_src, edge_dst, edge_attr, edge_scalars]
    assert mp(*inputs).shape == (num_nodes, 3)  # "1e" output -> dim 3

    # None marks the inputs without irreps: edge indices and edge scalars.
    assert_equivariant(
        mp,
        irreps_in=[
            mp.irreps_node_input, mp.irreps_node_attr, None, None,
            mp.irreps_edge_attr, None
        ],
        args_in=inputs,
        irreps_out=[mp.irreps_node_output],
    )
    assert_auto_jitable(mp.layers[0].first)
def test_optimizations(l1, p1, l2, p2, lo, po, mode, weight, special_code,
                       opt_ein, jit, float_tolerance):
    """Compare an optimized tensor product against an unoptimized reference.

    ``orig_tp`` is built with ``_specialized_code=False`` and
    ``_optimize_einsums=False``. ``opt_tp`` uses the parametrized flags and,
    when ``jit`` is true, is also passed through ``assert_auto_jitable``.
    ``opt_tp`` is given ``orig_tp``'s weights, so both the forward outputs and
    the ``right`` outputs must agree within ``float_tolerance``. Finally
    ``opt_tp`` is moved to the non-default float dtype and must produce
    outputs of that dtype.
    """
    orig_tp = make_tp(l1, p1, l2, p2, lo, po, mode, weight,
                      _specialized_code=False,
                      _optimize_einsums=False)
    opt_tp = make_tp(l1, p1, l2, p2, lo, po, mode, weight,
                     _specialized_code=special_code,
                     _optimize_einsums=opt_ein)
    # We don't use state_dict here since that contains things like wigners that can differ between optimized and unoptimized TPs
    with torch.no_grad():
        opt_tp.weight[:] = orig_tp.weight
    assert opt_tp._specialized_code == special_code
    assert opt_tp._optimize_einsums == opt_ein
    if jit:
        opt_tp = assert_auto_jitable(opt_tp)
    # Confirm equivariance of optimized model
    assert_equivariant(opt_tp,
                       irreps_in=[orig_tp.irreps_in1, orig_tp.irreps_in2],
                       irreps_out=orig_tp.irreps_out)
    # Confirm that it gives same results
    x1 = orig_tp.irreps_in1.randn(2, -1)
    x2 = orig_tp.irreps_in2.randn(2, -1)
    assert torch.allclose(
        orig_tp(x1, x2),
        opt_tp(x1, x2),
        atol=float_tolerance  # numerical optimizations can cause meaningful numerical error by changing operations
    )
    assert torch.allclose(orig_tp.right(x2),
                          opt_tp.right(x2),
                          atol=float_tolerance)
    # We also test .to(), even if only with a dtype, to ensure that various optimizations still always store constants in correct ways
    # Pick whichever of float32/float64 is not the current default dtype.
    other_dtype = next(d for d in [torch.float32, torch.float64]
                       if d != torch.get_default_dtype())
    x1, x2 = x1.to(other_dtype), x2.to(other_dtype)
    opt_tp = opt_tp.to(other_dtype)
    assert opt_tp(x1, x2).dtype == other_dtype
def test_norm_activation_equivariant(do_bias, nonlin):
    """``NormActivation`` over many irreps is equivariant and compiles, with or without a bias."""
    irreps_in = e3nn.o3.Irreps(
        # a wide mix of irreps: several l values, both parities
        "2x0e + 3x0o + 5x1o + 1x1e + 2x2e + 1x2o + 1x3e + 1x3o + 1x5e + 1x6o"
    )
    norm_act = NormActivation(
        irreps_in=irreps_in,
        scalar_nonlinearity=nonlin,
        bias=do_bias,
    )

    if do_bias:
        # The biases are the only parameter; fill them with random, nonzero values.
        params = list(norm_act.parameters())
        assert len(params) == 1
        with torch.no_grad():
            norm_act.biases.copy_(torch.randn(norm_act.biases.shape))

    assert_equivariant(norm_act)
    assert_auto_jitable(norm_act)
def test_f():
    """``o3.Linear`` with extra channel dimensions ``f_in=44`` -> ``f_out=5``."""
    linear = o3.Linear(
        "0e + 1e + 2e",
        "0e + 2x1e + 2e",
        f_in=44,
        f_out=5,
        _optimize_einsums=False,
    )
    # Input feature dim 9 = 1 + 3 + 5 for "0e + 1e + 2e".
    assert_equivariant(linear, args_in=[torch.randn(10, 44, 9)])
    linear = assert_auto_jitable(linear)
    out = linear(torch.randn(10, 44, 9))

    # Paths: 0e->0e (1) + 1e->2x1e (2) + 2e->2e (1) = 4, repeated per (f_in, f_out) pair.
    assert linear.weight_numel == 4
    assert linear.weight.numel() == 44 * 5 * 4
    assert 0.7 < out.pow(2).mean() < 1.4
def test_variance(act, var_in, var_out, out_act):
    """``FullyConnectedNet`` maps inputs of variance ``var_in`` to outputs of variance ~``var_out``."""
    hs = (1000, 500, 1500, 4)
    net = FullyConnectedNet(hs, act, var_in, var_out, out_act)

    x = torch.randn(2000, hs[0]) * var_in**0.5
    y = net(x) / var_out**0.5  # rescaled so the target second moment is 1

    if not out_act:
        assert y.mean().abs() < 0.5
    # Second moment within a factor of 1.5 of 1 (compared in log10 space).
    assert y.pow(2).mean().log10().abs() < torch.tensor(1.5).log10()

    net = assert_auto_jitable(net)
    net(x)
def test_dropout():
    """Check ``Dropout`` in eager and compiled form.

    In eval mode it is the identity. In train mode every output entry is
    either 0 or ``x / (1 - p)``. With a fixed seed it is equivariant.
    """
    dropout = Dropout(irreps='10x1e + 10x0e', p=0.75)
    x = dropout.irreps.randn(5, 2, -1)

    for module in [dropout, assert_auto_jitable(dropout)]:
        module.eval()
        assert module(x).eq(x).all()

        module.train()
        y = module(x)
        # 1 - p == 0.25
        assert (y.eq(x / 0.25) | y.eq(0)).all()

        def seeded_forward(x):
            # Re-seed so every call draws the same dropout mask.
            torch.manual_seed(0)
            return module(x)

        assert_equivariant(seeded_forward,
                           args_in=[x],
                           irreps_in=[module.irreps],
                           irreps_out=[module.irreps])
def test_save(l1, p1, l2, p2, lo, po, mode, weight): tp = make_tp(l1, p1, l2, p2, lo, po, mode, weight) # Saved TP with tempfile.NamedTemporaryFile(suffix=".pth") as tmp: torch.save(tp, tmp.name) tp2 = torch.load(tmp.name) # JITed, saved TP with tempfile.NamedTemporaryFile(suffix=".pth") as tmp: tp_jit = assert_auto_jitable(tp) tp_jit.save(tmp.name) tp3 = torch.jit.load(tmp.name) # Double-saved TP with tempfile.NamedTemporaryFile(suffix=".pth") as tmp: torch.save(tp2, tmp.name) tp4 = torch.load(tmp.name) x1 = torch.randn(2, tp.irreps_in1.dim) x2 = torch.randn(2, tp.irreps_in2.dim) res1 = tp(x1, x2) res2 = tp2(x1, x2) res3 = tp3(x1, x2) res4 = tp4(x1, x2) assert torch.allclose(res1, res2) assert torch.allclose(res1, res3) assert torch.allclose(res1, res4)
def test_extract_ir():
    """``ExtractIr`` selects every ``0e`` component of ``1e + 0e + 0e``."""
    extractor = ExtractIr('1e + 0e + 0e', '0e')
    selected = extractor(torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]))
    assert (selected == torch.Tensor([1., 2.])).all()
    assert_auto_jitable(extractor)
    assert_equivariant(extractor)
def test_convolution_jit(network):
    """The first layer's ``first`` submodule of ``network``'s model must be TorchScript-compatible."""
    # ``network`` unpacks into a pair; only the first element (the model) is used here.
    model, _ = network
    first_convolution = model.layers[0].first
    assert_auto_jitable(first_convolution)
def test_extract():
    """``Extract`` returns one tensor per requested irreps entry (entries 1 and 2, both ``0e``)."""
    extractor = Extract('1e + 0e + 0e', ['0e', '0e'], [(1, ), (2, )])
    result = extractor(torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]))
    expected = (torch.Tensor([1.]), torch.Tensor([2.]))
    assert result == expected
    assert_auto_jitable(extractor)
    assert_equivariant(extractor, irreps_out=list(extractor.irreps_outs))