def test_bad_normalize():
    """Check that ``assert_normalized`` fails on a function that is not normalized.

    The function under test scales ``relu(x ** 2)`` by a large constant, so
    ``assert_normalized`` is expected to raise ``AssertionError``.
    """

    def not_normal(x1):
        # The 870.0 factor drives the output statistics far from unit scale.
        return 870.0 * x1.square().relu()

    shared_irreps = random_irreps(clean=True, allow_empty=False)
    # Input and output irreps are the same object, mirroring an elementwise map.
    not_normal.irreps_in = shared_irreps
    not_normal.irreps_out = shared_irreps

    with pytest.raises(AssertionError):
        assert_normalized(not_normal)
def test_bias():
    """Exercise biases in ``o3.Linear``.

    Part 1 selects biases per output irrep block and verifies the output on a
    zero input. Part 2 enables all biases and checks equivariance,
    TorchScript compatibility and normalization.
    """
    irreps_in = o3.Irreps("2x0e + 1e + 2x0e + 0o")
    irreps_out = o3.Irreps("3x0e + 1e + 3x0e + 5x0e + 0o")

    # One flag per output block: only the first 3x0e and the 5x0e blocks get a bias.
    linear = o3.Linear(irreps_in, irreps_out, biases=[True, False, False, True, False])
    with torch.no_grad():
        linear.bias[:].fill_(1.0)

    # A zero input leaves only the bias contribution in the output.
    out = linear(torch.zeros(irreps_in.dim))
    expected = torch.tensor(
        [1.0] * 3    # 3x0e (biased)
        + [0.0] * 3  # 1e
        + [0.0] * 3  # 3x0e (unbiased)
        + [1.0] * 5  # 5x0e (biased)
        + [0.0]      # 0o
    )
    assert torch.allclose(out, expected)

    assert_equivariant(linear)
    assert_auto_jitable(linear)

    # Part 2: a boolean ``biases=True`` flag.
    linear = o3.Linear("0e + 0o + 1e + 1o", "10x0e + 0o + 1e + 1o", biases=True)

    assert_equivariant(linear)
    assert_auto_jitable(linear)
    # Only ``linear.weight`` is passed as ``weights``; presumably the bias is
    # left as-is during the normalization check — confirm against assert_normalized.
    assert_normalized(
        linear,
        n_weight=100,
        n_input=10_000,
        atol=0.5,
        weights=[linear.weight],
    )
def test_linear():
    """Smoke-test ``o3.Linear`` between equal irreps.

    Runs a forward pass, then checks equivariance, TorchScript compatibility
    and statistical normalization.
    """
    irreps_in = o3.Irreps("1e + 2e + 3x3o")
    irreps_out = o3.Irreps("1e + 2e + 3x3o")
    layer = o3.Linear(irreps_in, irreps_out)

    # Forward pass on a random input; the result is not inspected.
    layer(torch.randn(irreps_in.dim))

    assert_equivariant(layer)
    assert_auto_jitable(layer)
    # Large sample counts keep the statistical estimate stable.
    assert_normalized(layer, n_weight=100, n_input=10_000, atol=0.5)
def test_gate():
    """Check ``_Sortcut`` scriptability and ``Gate`` equivariance, scriptability and normalization."""
    irreps_scalars = Irreps("16x0o")
    act_scalars = [torch.tanh]
    irreps_gates = Irreps("32x0o")
    act_gates = [torch.tanh]
    irreps_gated = Irreps("16x1e+16x1o")

    sortcut = _Sortcut(irreps_scalars, irreps_gates)
    assert_auto_jitable(sortcut)

    gate = Gate(irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated)
    assert_equivariant(gate)
    assert_auto_jitable(gate)
    assert_normalized(gate)
def test_activation(irreps_in, acts):
    """Check that ``Activation`` applies each activation to its slice.

    For each irreps slice, the output must equal ``act(input)`` up to a
    per-column factor that is constant across the batch. Equivariance,
    TorchScript compatibility and normalization are checked as well.
    ``irreps_in`` and ``acts`` are supplied by pytest parametrization.
    """
    irreps_in = o3.Irreps(irreps_in)
    activation = Activation(irreps_in, acts)
    assert_auto_jitable(activation)
    assert_equivariant(activation)

    batch = irreps_in.randn(13, -1)
    result = activation(batch)

    for ir_slice, act in zip(irreps_in.slices(), acts):
        # The ratio to the raw activation must match the first row,
        # i.e. be constant along the batch dimension.
        ratio = result[:, ir_slice] / act(batch[:, ir_slice])
        assert torch.allclose(ratio, ratio[0])

    assert_normalized(activation)
def test_normalized(l1, p1, l2, p2, lo, po, mode, weight):
    """Statistically check output normalization of a tensor product built by ``make_tp``.

    Arguments are supplied by pytest parametrization. The test runs only when
    the default dtype is float32.
    """
    is_float32 = torch.get_default_dtype() == torch.float32
    if not is_float32:
        pytest.skip(
            "No reason to run expensive normalization tests again at float64 expense."
        )

    # Explicit fixed path weights would distort the output normalization,
    # so they are disabled here.
    tp = make_tp(l1, p1, l2, p2, lo, po, mode, weight, mul=5, path_weights=False)

    # High n_weight / n_input are needed for the statistics to converge,
    # especially in "uvuv" mode.
    assert_normalized(tp, n_weight=100, n_input=10_000, atol=0.5)
def test_normalized_ident():
    """Check that ``assert_normalized`` accepts the identity function on random irreps."""

    def ident(x1):
        return x1

    shared_irreps = random_irreps(clean=True, allow_empty=False)
    # The identity maps irreps to themselves, so input and output share one object.
    ident.irreps_in = shared_irreps
    ident.irreps_out = shared_irreps

    assert_normalized(ident)