Example #1
def test_norm(irreps_in, squared):
    m = o3.Norm(irreps_in, squared=squared)
    m(torch.randn(m.irreps_in.dim))
    if m.irreps_in.dim == 0:
        # with no irreps there is nothing to check for equivariance
        return
    assert_equivariant(m)
    assert_auto_jitable(m)
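
This test receives irreps_in and squared as pytest parameters and, like the rest of these examples, relies on imports that are not shown. A minimal sketch of the scaffolding it assumes (the parameter values below are illustrative, not the library's own fixtures):

import pytest
import torch
from e3nn import o3
from e3nn.util.test import assert_equivariant, assert_auto_jitable, assert_normalized

# hypothetical parametrization; "" exercises the empty-irreps early return above
@pytest.mark.parametrize("squared", [False, True])
@pytest.mark.parametrize("irreps_in", ["", "5x0e", "1e + 2e", "2x0e + 3x1o"])
def test_norm(irreps_in, squared):
    ...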
Example #2
def test_bias():
    irreps_in = o3.Irreps("2x0e + 1e + 2x0e + 0o")
    irreps_out = o3.Irreps("3x0e + 1e + 3x0e + 5x0e + 0o")
    m = o3.Linear(irreps_in,
                  irreps_out,
                  biases=[True, False, False, True, False])
    with torch.no_grad():
        m.bias[:].fill_(1.0)
    x = m(torch.zeros(irreps_in.dim))

    # biases=[True, False, False, True, False] puts a bias only on the first 3x0e
    # block and on the 5x0e block, so only those scalar entries become 1.0
    assert torch.allclose(
        x,
        torch.tensor([
            1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0,
            1.0, 0.0
        ]))

    assert_equivariant(m)
    assert_auto_jitable(m)

    m = o3.Linear("0e + 0o + 1e + 1o", "10x0e + 0o + 1e + 1o", biases=True)

    assert_equivariant(m)
    assert_auto_jitable(m)
    assert_normalized(m,
                      n_weight=100,
                      n_input=10_000,
                      atol=0.5,
                      weights=[m.weight])
Example #3
def test_full():
    irreps_in1 = o3.Irreps("1e + 2e + 3x3o")
    irreps_in2 = o3.Irreps("1e + 2x2e + 2x3o")
    m = FullTensorProduct(irreps_in1, irreps_in2)
    print(m)

    assert_equivariant(m)
    assert_auto_jitable(m)
Example #4
def test_id():
    irreps_in = o3.Irreps("1e + 2e + 3x3o")
    irreps_out = o3.Irreps("1e + 2e + 3x3o")
    m = Identity(irreps_in, irreps_out)
    print(m)
    m(torch.randn(irreps_in.dim))

    assert_equivariant(m)
    assert_auto_jitable(m, strict_shapes=False)
Example #5
def test_input_weights_jit():
    irreps_in1 = Irreps("1e + 2e + 3x3o")
    irreps_in2 = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    # - shared_weights = False -
    m = FullyConnectedTensorProduct(irreps_in1,
                                    irreps_in2,
                                    irreps_out,
                                    internal_weights=False,
                                    shared_weights=False)
    traced = assert_auto_jitable(m)
    x1 = irreps_in1.randn(2, -1)
    x2 = irreps_in2.randn(2, -1)
    w = torch.randn(2, m.weight_numel)
    with pytest.raises((RuntimeError, torch.jit.Error)):
        m(x1, x2)  # it should require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2)  # it should also require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2, w[0])  # it should reject insufficient weights
    # Does the trace give right results?
    assert torch.allclose(m(x1, x2, w), traced(x1, x2, w))

    # Confirm that weird batch dimensions give the same results
    for f in (m, traced):
        x1 = irreps_in1.randn(2, 1, 4, -1)
        x2 = irreps_in2.randn(2, 3, 1, -1)
        w = torch.randn(3, 4, f.weight_numel)
        assert torch.allclose(
            f(x1, x2, w).reshape(24, -1),
            f(
                x1.expand(2, 3, 4, -1).reshape(24, -1),
                x2.expand(2, 3, 4, -1).reshape(24, -1),
                w[None].expand(2, 3, 4, -1).reshape(24, -1)))
        assert torch.allclose(
            f.right(x2, w).reshape(24, -1),
            f.right(
                x2.expand(2, 3, 4, -1).reshape(24, -1),
                w[None].expand(2, 3, 4, -1).reshape(24, -1)).reshape(24, -1))

    # - shared_weights = True -
    m = FullyConnectedTensorProduct(irreps_in1,
                                    irreps_in2,
                                    irreps_out,
                                    internal_weights=False,
                                    shared_weights=True)
    traced = assert_auto_jitable(m)
    w = torch.randn(m.weight_numel)
    with pytest.raises((RuntimeError, torch.jit.Error)):
        m(x1, x2)  # it should require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2)  # it should also require weights
    with pytest.raises((RuntimeError, torch.jit.Error)):
        traced(x1, x2, torch.randn(
            2, m.weight_numel))  # it should reject too many weights
    # Does the trace give right results?
    assert torch.allclose(m(x1, x2, w), traced(x1, x2, w))
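
To summarize what the example above exercises (a restatement, not additional API): with shared_weights=False each batch element supplies its own weight vector, while with shared_weights=True a single weight vector of length weight_numel is shared across the whole batch.

# sketch only, reusing irreps_in1 / irreps_in2 / irreps_out as defined in the example above
m_unshared = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out,
                                         internal_weights=False, shared_weights=False)
m_shared = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out,
                                       internal_weights=False, shared_weights=True)
x1 = irreps_in1.randn(7, -1)
x2 = irreps_in2.randn(7, -1)
out_unshared = m_unshared(x1, x2, torch.randn(7, m_unshared.weight_numel))  # one weight set per element
out_shared = m_shared(x1, x2, torch.randn(m_shared.weight_numel))           # one weight set for all elements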
Example #6
def test_linear():
    irreps_in = o3.Irreps("1e + 2e + 3x3o")
    irreps_out = o3.Irreps("1e + 2e + 3x3o")
    m = o3.Linear(irreps_in, irreps_out)
    m(torch.randn(irreps_in.dim))

    assert_equivariant(m)
    assert_auto_jitable(m)
    assert_normalized(m, n_weight=100, n_input=10_000, atol=0.5)
Example #7
def test_fully_connected():
    irreps_in1 = o3.Irreps("1e + 2e + 3x3o")
    irreps_in2 = o3.Irreps("1e + 2e + 3x3o")
    irreps_out = o3.Irreps("1e + 2e + 3x3o")
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    print(m)
    m(torch.randn(irreps_in1.dim), torch.randn(irreps_in2.dim))

    assert_equivariant(m)
    assert_auto_jitable(m)
Example #8
def test_extract_single(squeeze):
    c = Extract('1e + 0e + 0e', ['0e'], [(1, )], squeeze_out=squeeze)
    out = c(torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]))
    if squeeze:
        assert isinstance(out, torch.Tensor)
    else:
        assert len(out) == 1
        out = out[0]
    assert out == torch.Tensor([1.])
    assert_auto_jitable(c)
    assert_equivariant(c, irreps_out=list(c.irreps_outs))
Example #9
def test_trace_dtypes():
    # FullyConnectedNet is traced
    fc = FullyConnectedNet([8, 16, 8])
    # compile in a dtype other than the default
    target_dtype = {
        torch.float32: torch.float64,
        torch.float64: torch.float32
    }[torch.get_default_dtype()]
    fc = fc.to(dtype=target_dtype)
    for weight in fc.parameters():
        assert weight.dtype == target_dtype
    assert_auto_jitable(fc)
Example #10
def test_antisymmetric_matrix(float_tolerance):
    tp = o3.ReducedTensorProducts('ij=-ji', i='5x0e + 1e')

    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_auto_jitable(tp)

    Q = tp.change_of_basis
    x = torch.randn(2, 5 + 3)
    assert (tp(*x) -
            torch.einsum('xij,i,j', Q, *x)).abs().max() < float_tolerance

    assert (Q + torch.einsum("xij->xji", Q)).abs().max() < float_tolerance
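    # Antisymmetry of Q also means that swapping the two inputs flips the sign of
    # the output; an extra spot-check along the same lines (not in the original test):
    assert (tp(x[0], x[1]) + tp(x[1], x[0])).abs().max() < float_tolerance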
Example #11
def test_gate():
    irreps_scalars = Irreps("16x0o")
    act_scalars = [torch.tanh]
    irreps_gates = Irreps("32x0o")
    act_gates = [torch.tanh]
    irreps_gated = Irreps("16x1e + 16x1o")

    sc = _Sortcut(irreps_scalars, irreps_gates)
    assert_auto_jitable(sc)

    g = Gate(irreps_scalars, act_scalars, irreps_gates, act_gates,
             irreps_gated)
    assert_equivariant(g)
    assert_auto_jitable(g)
    assert_normalized(g)
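    # Hypothetical shape check, assuming Gate exposes irreps_in/irreps_out like
    # other e3nn modules (a sketch, not part of the original test):
    x = g.irreps_in.randn(10, -1)
    assert g(x).shape == (10, g.irreps_out.dim)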
Example #12
def test_reduce_tensor_antisymmetric_L2(float_tolerance):
    tp = o3.ReducedTensorProducts('ijk=-ikj=-jik', i='2e')

    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_auto_jitable(tp)

    Q = tp.change_of_basis
    x = torch.randn(3, 5)
    assert (tp(*x) -
            torch.einsum('xijk,i,j,k', Q, *x)).abs().max() < float_tolerance

    assert (Q + torch.einsum("xijk->xikj", Q)).abs().max() < float_tolerance
    assert (Q + torch.einsum("xijk->xjik", Q)).abs().max() < float_tolerance
Example #13
def test_norm():
    irreps_in = o3.Irreps("3x0e + 5x1o")
    scalars = torch.randn(3)
    vecs = torch.randn(5, 3)
    norm = Norm(irreps_in=irreps_in)
    out_norms = norm(
        torch.cat((scalars.reshape(1, -1), vecs.reshape(1, -1)), dim=-1))
    true_scalar_norms = torch.abs(scalars)
    true_vec_norms = torch.linalg.norm(vecs, dim=-1)
    assert torch.allclose(out_norms[0, :3], true_scalar_norms)
    assert torch.allclose(out_norms[0, 3:], true_vec_norms)

    assert_equivariant(norm)
    assert_auto_jitable(norm)
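    # Explicit invariance spot-check, redundant with assert_equivariant but shown
    # for illustration; assumes rotated features are obtained as x @ D.T:
    rot = o3.rand_matrix()
    D = irreps_in.D_from_matrix(rot)
    x = irreps_in.randn(1, -1)
    assert torch.allclose(norm(x @ D.T), norm(x), atol=1e-5)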
Example #14
def test_jit_trace():
    @compile_mode('trace')
    class NotTracable(torch.nn.Module):
        def forward(self, param):
            # data-dependent control flow: the output shape depends on the
            # input's shape, which torch.jit.trace cannot record
            if param.shape[0] == 7:
                return torch.ones(8)
            else:
                return torch.randn(8, 3)
    not_tracable = NotTracable()
    not_tracable.irreps_in = o3.Irreps("2x0e")
    not_tracable.irreps_out = o3.Irreps("1x1o")
    # TorchScript can raise several different exception types here, so catch the generic Exception
    with pytest.raises(Exception):
        assert_auto_jitable(not_tracable)
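    # For contrast, a hypothetical module whose output does not depend on the
    # input's shape or values would be expected to trace cleanly (a sketch, not
    # part of the original test):
    @compile_mode('trace')
    class Tracable(torch.nn.Module):
        def forward(self, param):
            return 2 * param
    tracable = Tracable()
    tracable.irreps_in = o3.Irreps("2x0e")
    tracable.irreps_out = o3.Irreps("2x0e")
    assert_auto_jitable(tracable)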
Example #15
def test_reduce_tensor_elasticity_tensor(float_tolerance):
    tp = o3.ReducedTensorProducts('ijkl=jikl=klij', i='1e')
    assert tp.irreps_out.dim == 21  # an elasticity tensor has 21 independent components

    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_auto_jitable(tp)

    Q = tp.change_of_basis
    x = torch.randn(4, 3)
    assert (tp(*x) -
            torch.einsum('xijkl,i,j,k,l', Q, *x)).abs().max() < float_tolerance

    assert (Q - torch.einsum("xijkl->xjikl", Q)).abs().max() < float_tolerance
    assert (Q - torch.einsum("xijkl->xijlk", Q)).abs().max() < float_tolerance
    assert (Q - torch.einsum("xijkl->xklij", Q)).abs().max() < float_tolerance
Example #16
def test_activation(irreps_in, acts):
    irreps_in = o3.Irreps(irreps_in)
    a = Activation(irreps_in, acts)
    assert_auto_jitable(a)
    assert_equivariant(a)

    inp = irreps_in.randn(13, -1)
    out = a(inp)
    # each output block should equal the activation of the corresponding input
    # block up to a single normalization constant per irrep
    for ir_slice, act in zip(irreps_in.slices(), acts):
        this_out = out[:, ir_slice]
        true_up_to_factor = act(inp[:, ir_slice])
        factors = this_out / true_up_to_factor
        assert torch.allclose(factors, factors[0])

    assert_normalized(a)
Example #17
def test_jit(l1, p1, l2, p2, lo, po, mode, weight, special_code, opt_ein):
    """Test the JIT.

    This test is seperate from test_optimizations to ensure that just jitting a model has minimal error if any.
    """
    orig_tp = make_tp(l1,
                      p1,
                      l2,
                      p2,
                      lo,
                      po,
                      mode,
                      weight,
                      _specialized_code=special_code,
                      _optimize_einsums=opt_ein)
    opt_tp = assert_auto_jitable(orig_tp)

    # Confirm equivariance of optimized model
    assert_equivariant(opt_tp,
                       irreps_in=[orig_tp.irreps_in1, orig_tp.irreps_in2],
                       irreps_out=orig_tp.irreps_out)

    # Confirm that it gives same results
    x1 = orig_tp.irreps_in1.randn(2, -1)
    x2 = orig_tp.irreps_in2.randn(2, -1)
    # TorchScript should cause very little, if any, numerical error
    assert torch.allclose(
        orig_tp(x1, x2),
        opt_tp(x1, x2),
    )
    assert torch.allclose(
        orig_tp.right(x2),
        opt_tp.right(x2),
    )
Example #18
def test_jit(float_tolerance):
    sh = o3.SphericalHarmonicsAlphaBeta([0, 1, 2])
    jited = assert_auto_jitable(sh)

    a = torch.randn(5, 4)
    b = torch.randn(5, 4)
    assert (sh(a, b) - jited(a, b)).abs().max() < float_tolerance
Example #19
def test_module(normalization, normalize):
    l = o3.Irreps("0e + 1o + 3o")
    sp = o3.SphericalHarmonics(l, normalize, normalization)
    sp_jit = assert_auto_jitable(sp)
    xyz = torch.randn(11, 3)
    assert torch.allclose(
        sp_jit(xyz), o3.spherical_harmonics(l, xyz, normalize, normalization))
    assert_equivariant(sp)
Example #20
def test():
    from torch_cluster import radius_graph
    from e3nn.util.test import assert_equivariant, assert_auto_jitable

    mp = MessagePassing(
        irreps_node_input="0e",
        irreps_node_hidden="0e + 1e",
        irreps_node_output="1e",
        irreps_node_attr="0e + 1e",
        irreps_edge_attr="1e",
        layers=3,
        fc_neurons=[2, 100],
        num_neighbors=3.0,
    )

    num_nodes = 4
    node_pos = torch.randn(num_nodes, 3)
    edge_index = radius_graph(node_pos, 3.0)
    edge_src, edge_dst = edge_index
    num_edges = edge_index.shape[1]
    edge_attr = node_pos[edge_src] - node_pos[edge_dst]

    node_features = torch.randn(num_nodes, 1)  # irreps_node_input = "0e" -> dim 1
    node_attr = torch.randn(num_nodes, 4)      # irreps_node_attr = "0e + 1e" -> dim 4
    edge_scalars = torch.randn(num_edges, 2)   # fc_neurons[0] = 2

    assert mp(node_features, node_attr, edge_src, edge_dst, edge_attr,
              edge_scalars).shape == (num_nodes, 3)

    assert_equivariant(
        mp,
        irreps_in=[
            mp.irreps_node_input, mp.irreps_node_attr, None, None,
            mp.irreps_edge_attr, None
        ],
        args_in=[
            node_features, node_attr, edge_src, edge_dst, edge_attr,
            edge_scalars
        ],
        irreps_out=[mp.irreps_node_output],
    )

    assert_auto_jitable(mp.layers[0].first)
Example #21
def test_optimizations(l1, p1, l2, p2, lo, po, mode, weight, special_code,
                       opt_ein, jit, float_tolerance):
    orig_tp = make_tp(l1,
                      p1,
                      l2,
                      p2,
                      lo,
                      po,
                      mode,
                      weight,
                      _specialized_code=False,
                      _optimize_einsums=False)
    opt_tp = make_tp(l1,
                     p1,
                     l2,
                     p2,
                     lo,
                     po,
                     mode,
                     weight,
                     _specialized_code=special_code,
                     _optimize_einsums=opt_ein)
    # We don't use state_dict here since that contains things like wigners that can differ between optimized and unoptimized TPs
    with torch.no_grad():
        opt_tp.weight[:] = orig_tp.weight
    assert opt_tp._specialized_code == special_code
    assert opt_tp._optimize_einsums == opt_ein

    if jit:
        opt_tp = assert_auto_jitable(opt_tp)

    # Confirm equivariance of optimized model
    assert_equivariant(opt_tp,
                       irreps_in=[orig_tp.irreps_in1, orig_tp.irreps_in2],
                       irreps_out=orig_tp.irreps_out)

    # Confirm that it gives same results
    x1 = orig_tp.irreps_in1.randn(2, -1)
    x2 = orig_tp.irreps_in2.randn(2, -1)
    assert torch.allclose(
        orig_tp(x1, x2),
        opt_tp(x1, x2),
        # numerical optimizations can cause meaningful numerical error by changing operations
        atol=float_tolerance,
    )
    assert torch.allclose(orig_tp.right(x2),
                          opt_tp.right(x2),
                          atol=float_tolerance)

    # We also test .to(), even if only with a dtype, to ensure that various optimizations still always store constants in correct ways
    other_dtype = next(d for d in [torch.float32, torch.float64]
                       if d != torch.get_default_dtype())
    x1, x2 = x1.to(other_dtype), x2.to(other_dtype)
    opt_tp = opt_tp.to(other_dtype)
    assert opt_tp(x1, x2).dtype == other_dtype
Example #22
def test_norm_activation_equivariant(do_bias, nonlin):
    irreps_in = e3nn.o3.Irreps(
        # test lots of different irreps
        "2x0e + 3x0o + 5x1o + 1x1e + 2x2e + 1x2o + 1x3e + 1x3o + 1x5e + 1x6o"
    )

    norm_act = NormActivation(
        irreps_in=irreps_in,
        scalar_nonlinearity=nonlin,
        bias=do_bias
    )

    if do_bias:
        # Set up some nonzero biases
        assert len(list(norm_act.parameters())) == 1
        with torch.no_grad():
            norm_act.biases[:] = torch.randn(norm_act.biases.shape)

    assert_equivariant(norm_act)
    assert_auto_jitable(norm_act)
Example #23
def test_f():
    m = o3.Linear("0e + 1e + 2e",
                  "0e + 2x1e + 2e",
                  f_in=44,
                  f_out=5,
                  _optimize_einsums=False)
    assert_equivariant(m, args_in=[torch.randn(10, 44, 9)])
    m = assert_auto_jitable(m)
    y = m(torch.randn(10, 44, 9))
    assert m.weight_numel == 4  # 4 paths: 0e->0e, 1e->1e (x2 outputs), 2e->2e
    assert m.weight.numel() == 44 * 5 * 4  # each path carries an f_in x f_out (44 x 5) block of weights
    assert 0.7 < y.pow(2).mean() < 1.4
Example #24
def test_variance(act, var_in, var_out, out_act):
    hs = (1000, 500, 1500, 4)

    f = FullyConnectedNet(hs, act, var_in, var_out, out_act)

    x = torch.randn(2000, hs[0]) * var_in**0.5
    y = f(x) / var_out**0.5

    if not out_act:
        assert y.mean().abs() < 0.5
    assert y.pow(2).mean().log10().abs() < torch.tensor(1.5).log10()

    f = assert_auto_jitable(f)
    f(x)
Example #25
def test_dropout():
    c = Dropout(irreps='10x1e + 10x0e', p=0.75)
    x = c.irreps.randn(5, 2, -1)

    for c in [c, assert_auto_jitable(c)]:
        c.eval()
        assert c(x).eq(x).all()

        c.train()
        y = c(x)
        # with p = 0.75, surviving entries are rescaled by 1 / (1 - p) = 4, i.e. x / 0.25
        assert (y.eq(x / 0.25) | y.eq(0)).all()

        def wrap(x):
            torch.manual_seed(0)
            return c(x)

        assert_equivariant(wrap, args_in=[x], irreps_in=[c.irreps], irreps_out=[c.irreps])
Example #26
def test_save(l1, p1, l2, p2, lo, po, mode, weight):
    tp = make_tp(l1, p1, l2, p2, lo, po, mode, weight)
    # Saved TP
    with tempfile.NamedTemporaryFile(suffix=".pth") as tmp:
        torch.save(tp, tmp.name)
        tp2 = torch.load(tmp.name)
    # JITed, saved TP
    with tempfile.NamedTemporaryFile(suffix=".pth") as tmp:
        tp_jit = assert_auto_jitable(tp)
        tp_jit.save(tmp.name)
        tp3 = torch.jit.load(tmp.name)
    # Double-saved TP
    with tempfile.NamedTemporaryFile(suffix=".pth") as tmp:
        torch.save(tp2, tmp.name)
        tp4 = torch.load(tmp.name)
    x1 = torch.randn(2, tp.irreps_in1.dim)
    x2 = torch.randn(2, tp.irreps_in2.dim)
    res1 = tp(x1, x2)
    res2 = tp2(x1, x2)
    res3 = tp3(x1, x2)
    res4 = tp4(x1, x2)
    assert torch.allclose(res1, res2)
    assert torch.allclose(res1, res3)
    assert torch.allclose(res1, res4)
Example #27
def test_extract_ir():
    c = ExtractIr('1e + 0e + 0e', '0e')
    out = c(torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]))
    assert torch.all(out == torch.Tensor([1., 2.]))
    assert_auto_jitable(c)
    assert_equivariant(c)
Example #28
def test_convolution_jit(network):
    f, _ = network
    # Get a convolution from the network
    assert_auto_jitable(f.layers[0].first)
Example #29
def test_extract():
    c = Extract('1e + 0e + 0e', ['0e', '0e'], [(1, ), (2, )])
    out = c(torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]))
    assert out == (torch.Tensor([1.]), torch.Tensor([2.]))
    assert_auto_jitable(c)
    assert_equivariant(c, irreps_out=list(c.irreps_outs))