Example #1
def test_norm(irreps_in, squared):
    m = o3.Norm(irreps_in, squared=squared)
    m(torch.randn(m.irreps_in.dim))
    if m.irreps_in.dim == 0:
        return
    assert_equivariant(m)
    assert_auto_jitable(m)
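For reference, here is a minimal sketch of the property `assert_equivariant` verifies in this example, written against the public o3 API (the `2x1o` irreps and the tolerance are arbitrary choices, not part of the original test): the norm of each irrep block is rotation-invariant, so the output must not change when the input is rotated.

import torch
from e3nn import o3

irreps = o3.Irreps("2x1o")
m = o3.Norm(irreps)
x = irreps.randn(1, -1)
# block-diagonal representation of a random rotation on these irreps
D = irreps.D_from_matrix(o3.rand_matrix())
# norms are rotation-invariant, so rotating the input leaves the output unchanged
assert torch.allclose(m(x @ D.T), m(x), atol=1e-6)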
Example #2
def test_jit(l1, p1, l2, p2, lo, po, mode, weight, special_code, opt_ein):
    """Test the JIT.

    This test is separate from test_optimizations to ensure that merely jitting a model introduces minimal numerical error, if any.
    """
    orig_tp = make_tp(l1,
                      p1,
                      l2,
                      p2,
                      lo,
                      po,
                      mode,
                      weight,
                      _specialized_code=special_code,
                      _optimize_einsums=opt_ein)
    opt_tp = assert_auto_jitable(orig_tp)

    # Confirm equivariance of optimized model
    assert_equivariant(opt_tp,
                       irreps_in=[orig_tp.irreps_in1, orig_tp.irreps_in2],
                       irreps_out=orig_tp.irreps_out)

    # Confirm that it gives same results
    x1 = orig_tp.irreps_in1.randn(2, -1)
    x2 = orig_tp.irreps_in2.randn(2, -1)
    # TorchScript should cause very little, if any, numerical error
    assert torch.allclose(
        orig_tp(x1, x2),
        opt_tp(x1, x2),
    )
    assert torch.allclose(
        orig_tp.right(x2),
        opt_tp.right(x2),
    )
Example #3
def test_bias():
    irreps_in = o3.Irreps("2x0e + 1e + 2x0e + 0o")
    irreps_out = o3.Irreps("3x0e + 1e + 3x0e + 5x0e + 0o")
    m = o3.Linear(irreps_in,
                  irreps_out,
                  biases=[True, False, False, True, False])
    with torch.no_grad():
        m.bias[:].fill_(1.0)
    x = m(torch.zeros(irreps_in.dim))

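    # only the irreps flagged True in `biases` (the first 3x0e and the 5x0e) receive the bias;
    # their slots read 1.0 below, while the 1e, the second 3x0e, and the 0o stay zero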
    assert torch.allclose(
        x,
        torch.tensor([
            1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0,
            1.0, 0.0
        ]))

    assert_equivariant(m)
    assert_auto_jitable(m)

    m = o3.Linear("0e + 0o + 1e + 1o", "10x0e + 0o + 1e + 1o", biases=True)

    assert_equivariant(m)
    assert_auto_jitable(m)
    assert_normalized(m,
                      n_weight=100,
                      n_input=10_000,
                      atol=0.5,
                      weights=[m.weight])
Example #4
def test_full():
    irreps_in1 = o3.Irreps("1e + 2e + 3x3o")
    irreps_in2 = o3.Irreps("1e + 2x2e + 2x3o")
    m = FullTensorProduct(irreps_in1, irreps_in2)
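    # the output irreps follow the selection rule |l1 - l2| <= l_out <= l1 + l2, with parity p1 * p2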
    print(m)

    assert_equivariant(m)
    assert_auto_jitable(m)
Example #5
def test_equivariance(act, lmax):
    m = SO3Activation(lmax, lmax, act, 6)

    assert_equivariant(m,
                       ntrials=10,
                       tolerance=0.04,
                       irreps_in=so3_irreps(lmax),
                       irreps_out=so3_irreps(lmax))
Example #6
def test_module(normalization, normalize):
    l = o3.Irreps("0e + 1o + 3o")
    sp = o3.SphericalHarmonics(l, normalize, normalization)
    sp_jit = assert_auto_jitable(sp)
    xyz = torch.randn(11, 3)
    assert torch.allclose(
        sp_jit(xyz), o3.spherical_harmonics(l, xyz, normalize, normalization))
    assert_equivariant(sp)
Example #7
def main():
    data, labels = tetris()
    f = Network()

    print("Built a model:")
    print(f)

    optim = torch.optim.Adam(f.parameters(), lr=1e-3)

    # == Training ==
    for step in range(200):
        pred = f(data)
        loss = (pred - labels).pow(2).sum()

        optim.zero_grad()
        loss.backward()
        optim.step()

        if step % 10 == 0:
            accuracy = pred.round().eq(labels).all(dim=1).double().mean(dim=0).item()
            print(f"step {step:5d} | loss {loss:<10.1f} | {100 * accuracy:5.1f}% accuracy")

    # == Check equivariance ==
    # Because the model outputs (pseudo)scalars, we can directly check its
    # equivariance by evaluating it on the same data with new rotations:
    print("Testing equivariance directly...")
    rotated_data, _ = tetris()
    error = f(rotated_data) - f(data)
    print(f"Equivariance error = {error.abs().max().item():.1e}")

    print("Testing equivariance using `assert_equivariant`...")

    # We can also use the library's `assert_equivariant` helper
    # `assert_equivariant` also tests parity and translation, and
    # can handle non-(pseudo)scalar outputs.
    # To translate between it and torch_geometric, we use a small wrapper:

    def wrapper(pos, batch):
        return f(Data(pos=pos, batch=batch))

    # `assert_equivariant` uses logging to print a summary of the equivariance error,
    # so we enable logging
    logging.basicConfig(level=logging.INFO)
    assert_equivariant(
        wrapper,
        # We provide the original data that `assert_equivariant` will transform...
        args_in=[data.pos, data.batch],
        # ...in accordance with these irreps...
        irreps_in=[
            "cartesian_points",  # pos has vector 1o irreps, but is also translation equivariant
            None,  # `None` indicates invariant, possibly non-floating-point data
        ],
        # ...and confirm that the outputs transform correspondingly for these irreps:
        irreps_out=[f.irreps_out],
    )
Example #8
def test_equivariant():
    irreps = o3.Irreps("3x0e + 3x0o + 4x1e")
    m = BatchNorm(irreps)
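    # two forward passes in training mode populate the running statistics used in eval mode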
    m(irreps.randn(16, -1))
    m(irreps.randn(16, -1))
    m.train()
    assert_equivariant(m, irreps_in=[irreps], irreps_out=[irreps])
    m.eval()
    assert_equivariant(m, irreps_in=[irreps], irreps_out=[irreps])
Example #9
def test_linear():
    irreps_in = o3.Irreps("1e + 2e + 3x3o")
    irreps_out = o3.Irreps("1e + 2e + 3x3o")
    m = o3.Linear(irreps_in, irreps_out)
    m(torch.randn(irreps_in.dim))

    assert_equivariant(m)
    assert_auto_jitable(m)
    assert_normalized(m, n_weight=100, n_input=10_000, atol=0.5)
Example #10
def test_id():
    irreps_in = o3.Irreps("1e + 2e + 3x3o")
    irreps_out = o3.Irreps("1e + 2e + 3x3o")
    m = Identity(irreps_in, irreps_out)
    print(m)
    m(torch.randn(irreps_in.dim))

    assert_equivariant(m)
    assert_auto_jitable(m, strict_shapes=False)
Example #11
def test_fully_connected():
    irreps_in1 = o3.Irreps("1e + 2e + 3x3o")
    irreps_in2 = o3.Irreps("1e + 2e + 3x3o")
    irreps_out = o3.Irreps("1e + 2e + 3x3o")
    m = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    print(m)
    m(torch.randn(irreps_in1.dim), torch.randn(irreps_in2.dim))

    assert_equivariant(m)
    assert_auto_jitable(m)
Example #12
def test_assert_equivariant():
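    # an elementwise product of inputs with mismatched irreps mixes components arbitrarily, so it is not equivariant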
    def not_equivariant(x1, x2):
        return x1*x2
    not_equivariant.irreps_in1 = o3.Irreps("2x0e + 1x1e + 3x2o + 1x4e")
    not_equivariant.irreps_in2 = o3.Irreps("2x0o + 3x0o + 3x2e + 1x4o")
    not_equivariant.irreps_out = o3.Irreps("1x1e + 2x0o + 3x2e + 1x4o")
    assert not_equivariant.irreps_in1.dim == not_equivariant.irreps_in2.dim
    assert not_equivariant.irreps_in1.dim == not_equivariant.irreps_out.dim
    with pytest.raises(AssertionError):
        assert_equivariant(not_equivariant)
Example #13
def test_optimizations(l1, p1, l2, p2, lo, po, mode, weight, special_code,
                       opt_ein, jit, float_tolerance):
    orig_tp = make_tp(l1,
                      p1,
                      l2,
                      p2,
                      lo,
                      po,
                      mode,
                      weight,
                      _specialized_code=False,
                      _optimize_einsums=False)
    opt_tp = make_tp(l1,
                     p1,
                     l2,
                     p2,
                     lo,
                     po,
                     mode,
                     weight,
                     _specialized_code=special_code,
                     _optimize_einsums=opt_ein)
    # We don't use state_dict here since that contains things like wigners that can differ between optimized and unoptimized TPs
    with torch.no_grad():
        opt_tp.weight[:] = orig_tp.weight
    assert opt_tp._specialized_code == special_code
    assert opt_tp._optimize_einsums == opt_ein

    if jit:
        opt_tp = assert_auto_jitable(opt_tp)

    # Confirm equivariance of optimized model
    assert_equivariant(opt_tp,
                       irreps_in=[orig_tp.irreps_in1, orig_tp.irreps_in2],
                       irreps_out=orig_tp.irreps_out)

    # Confirm that it gives same results
    x1 = orig_tp.irreps_in1.randn(2, -1)
    x2 = orig_tp.irreps_in2.randn(2, -1)
    assert torch.allclose(
        orig_tp(x1, x2),
        opt_tp(x1, x2),
        # numerical optimizations can cause meaningful numerical error by changing operations
        atol=float_tolerance,
    )
    assert torch.allclose(orig_tp.right(x2),
                          opt_tp.right(x2),
                          atol=float_tolerance)

    # We also test .to(), even if only with a dtype, to ensure that various optimizations still always store constants in correct ways
    other_dtype = next(d for d in [torch.float32, torch.float64]
                       if d != torch.get_default_dtype())
    x1, x2 = x1.to(other_dtype), x2.to(other_dtype)
    opt_tp = opt_tp.to(other_dtype)
    assert opt_tp(x1, x2).dtype == other_dtype
Example #14
def test_extract_single(squeeze):
    c = Extract('1e + 0e + 0e', ['0e'], [(1, )], squeeze_out=squeeze)
    out = c(torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]))
    if squeeze:
        assert isinstance(out, torch.Tensor)
    else:
        assert len(out) == 1
        out = out[0]
    assert out == torch.Tensor([1.])
    assert_auto_jitable(c)
    assert_equivariant(c, irreps_out=list(c.irreps_outs))
Example #15
def test_equivariance(float_tolerance, act, normalization, p_val, p_arg):
    irreps = io.SphericalTensor(3, p_val, p_arg)

    m = S2Activation(irreps,
                     act,
                     120,
                     normalization=normalization,
                     lmax_out=6,
                     random_rot=True)

    assert_equivariant(m, ntrials=10, tolerance=torch.sqrt(float_tolerance))
Example #16
def test_equivariance(lmax, res_b, res_a):
    m = FromS2Grid((res_b, res_a), lmax)
    k = ToS2Grid(lmax, (res_b, res_a))

    def f(x):
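        # exp applied pointwise on the sphere grid commutes with rotations (up to grid resolution)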
        y = k(x)
        y = y.exp()
        return m(y)

    f.irreps_in = f.irreps_out = Irreps.spherical_harmonics(lmax)

    assert_equivariant(f)
Example #17
def test_antisymmetric_matrix(float_tolerance):
    tp = o3.ReducedTensorProducts('ij=-ji', i='5x0e + 1e')

    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_auto_jitable(tp)

    Q = tp.change_of_basis
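    # x holds the two tensor factors, each a flattened 5x0e + 1e input (dim 5 + 3)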
    x = torch.randn(2, 5 + 3)
    assert (tp(*x) -
            torch.einsum('xij,i,j', Q, *x)).abs().max() < float_tolerance

    assert (Q + torch.einsum("xij->xji", Q)).abs().max() < float_tolerance
Example #18
def test_equivariant():
    # Confirm that a compiled tensorproduct is still equivariant
    irreps_in = Irreps("1e + 2e + 3x3o")
    irreps_out = Irreps("1e + 2e + 3x3o")
    mod = Linear(irreps_in, irreps_out)
    mod_script = compile(mod)
    assert_equivariant(
        mod_script,
        # we provide explicit irreps because inferring them from a script module is not reliable
        irreps_in=irreps_in,
        irreps_out=irreps_out
    )
Example #19
def test_f():
    m = o3.Linear("0e + 1e + 2e",
                  "0e + 2x1e + 2e",
                  f_in=44,
                  f_out=5,
                  _optimize_einsums=False)
    assert_equivariant(m, args_in=[torch.randn(10, 44, 9)])
    m = assert_auto_jitable(m)
    y = m(torch.randn(10, 44, 9))
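    # 4 paths: 0e -> 0e, two for 1e -> 2x1e, and 2e -> 2e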
    assert m.weight_numel == 4
    assert m.weight.numel() == 44 * 5 * 4
    assert 0.7 < y.pow(2).mean() < 1.4
Example #20
def test_gate():
    irreps_scalars = Irreps("16x0o")
    act_scalars = [torch.tanh]
    irreps_gates = Irreps("32x0o")
    act_gates = [torch.tanh]
    irreps_gated = Irreps("16x1e+16x1o")

    sc = _Sortcut(irreps_scalars, irreps_gates)
    assert_auto_jitable(sc)

    g = Gate(irreps_scalars, act_scalars, irreps_gates, act_gates,
             irreps_gated)
    assert_equivariant(g)
    assert_auto_jitable(g)
    assert_normalized(g)
Example #21
def test_reduce_tensor_antisymmetric_L2(float_tolerance):
    tp = o3.ReducedTensorProducts('ijk=-ikj=-jik', i='2e')

    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_auto_jitable(tp)

    Q = tp.change_of_basis
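    # three tensor factors, each a flattened 2e input (dim 5)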
    x = torch.randn(3, 5)
    assert (tp(*x) -
            torch.einsum('xijk,i,j,k', Q, *x)).abs().max() < float_tolerance

    assert (Q + torch.einsum("xijk->xikj", Q)).abs().max() < float_tolerance
    assert (Q + torch.einsum("xijk->xjik", Q)).abs().max() < float_tolerance
Example #22
def test_gate_points_2101_equivariant(network):
    f, random_graph = network

    # -- Test equivariance: --
    def wrapper(pos, x, z):
        data = Data(pos=pos, x=x, z=z, batch=torch.zeros(pos.shape[0], dtype=torch.long))
        return f(data)

    assert_equivariant(
        wrapper,
        irreps_in=['cartesian_points', f.irreps_in, f.irreps_node_attr],
        irreps_out=[f.irreps_out],
    )
Example #23
def test_norm():
    irreps_in = o3.Irreps("3x0e + 5x1o")
    scalars = torch.randn(3)
    vecs = torch.randn(5, 3)
    norm = Norm(irreps_in=irreps_in)
    out_norms = norm(
        torch.cat((scalars.reshape(1, -1), vecs.reshape(1, -1)), dim=-1))
    true_scalar_norms = torch.abs(scalars)
    true_vec_norms = torch.linalg.norm(vecs, dim=-1)
    assert torch.allclose(out_norms[0, :3], true_scalar_norms)
    assert torch.allclose(out_norms[0, 3:], true_vec_norms)

    assert_equivariant(norm)
    assert_auto_jitable(norm)
Example #24
def test_reduce_tensor_elasticity_tensor(float_tolerance):
    tp = o3.ReducedTensorProducts('ijkl=jikl=klij', i='1e')
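    # a rank-4 tensor with these index symmetries (e.g. the elasticity tensor) has 21 independent components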
    assert tp.irreps_out.dim == 21

    assert_equivariant(tp, irreps_in=tp.irreps_in, irreps_out=tp.irreps_out)
    assert_auto_jitable(tp)

    Q = tp.change_of_basis
    x = torch.randn(4, 3)
    assert (tp(*x) -
            torch.einsum('xijkl,i,j,k,l', Q, *x)).abs().max() < float_tolerance

    assert (Q - torch.einsum("xijkl->xjikl", Q)).abs().max() < float_tolerance
    assert (Q - torch.einsum("xijkl->xijlk", Q)).abs().max() < float_tolerance
    assert (Q - torch.einsum("xijkl->xklij", Q)).abs().max() < float_tolerance
Example #25
def test_activation(irreps_in, acts):
    irreps_in = o3.Irreps(irreps_in)
    a = Activation(irreps_in, acts)
    assert_auto_jitable(a)
    assert_equivariant(a)

    inp = irreps_in.randn(13, -1)
    out = a(inp)
    for ir_slice, act in zip(irreps_in.slices(), acts):
        this_out = out[:, ir_slice]
        true_up_to_factor = act(inp[:, ir_slice])
        factors = this_out / true_up_to_factor
        assert torch.allclose(factors, factors[0])

    assert_normalized(a)
Example #26
def test_dropout():
    c = Dropout(irreps='10x1e + 10x0e', p=0.75)
    x = c.irreps.randn(5, 2, -1)

    for c in [c, assert_auto_jitable(c)]:
        c.eval()
        assert c(x).eq(x).all()

        c.train()
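        # in training mode, surviving entries are rescaled by 1 / (1 - p) = 1 / 0.25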
        y = c(x)
        assert (y.eq(x / 0.25) | y.eq(0)).all()

        def wrap(x):
            torch.manual_seed(0)
            return c(x)

        assert_equivariant(wrap, args_in=[x], irreps_in=[c.irreps], irreps_out=[c.irreps])
Example #27
def test():
    from torch_cluster import radius_graph
    from e3nn.util.test import assert_equivariant, assert_auto_jitable

    mp = MessagePassing(
        irreps_node_input="0e",
        irreps_node_hidden="0e + 1e",
        irreps_node_output="1e",
        irreps_node_attr="0e + 1e",
        irreps_edge_attr="1e",
        layers=3,
        fc_neurons=[2, 100],
        num_neighbors=3.0,
    )

    num_nodes = 4
    node_pos = torch.randn(num_nodes, 3)
    edge_index = radius_graph(node_pos, 3.0)
    edge_src, edge_dst = edge_index
    num_edges = edge_index.shape[1]
    edge_attr = node_pos[edge_src] - node_pos[edge_dst]

    node_features = torch.randn(num_nodes, 1)
    node_attr = torch.randn(num_nodes, 4)
    edge_scalars = torch.randn(num_edges, 2)

    assert mp(node_features, node_attr, edge_src, edge_dst, edge_attr,
              edge_scalars).shape == (num_nodes, 3)

    assert_equivariant(
        mp,
        irreps_in=[
            mp.irreps_node_input, mp.irreps_node_attr, None, None,
            mp.irreps_edge_attr, None
        ],
        args_in=[
            node_features, node_attr, edge_src, edge_dst, edge_attr,
            edge_scalars
        ],
        irreps_out=[mp.irreps_node_output],
    )

    assert_auto_jitable(mp.layers[0].first)
Example #28
def test_bilinear_right_variance_equivariance(float_tolerance, l1, p1, l2, p2,
                                              lo, po, mode, weight):
    eps = float_tolerance
    n = 1_500
    tol = 3.0

    m = make_tp(l1, p1, l2, p2, lo, po, mode, weight)

    # bilinear
    x1 = torch.randn(2, m.irreps_in1.dim)
    x2 = torch.randn(2, m.irreps_in1.dim)
    y1 = torch.randn(2, m.irreps_in2.dim)
    y2 = torch.randn(2, m.irreps_in2.dim)

    z1 = m(x1 + 1.7 * x2, y1 - y2)
    z2 = m(x1, y1 - y2) + 1.7 * m(x2, y1 - y2)
    z3 = m(x1 + 1.7 * x2, y1) - m(x1 + 1.7 * x2, y2)
    assert (z1 - z2).abs().max() < eps
    assert (z1 - z3).abs().max() < eps

    # right: m.right(y) is the matrix of the linear map x -> m(x, y)
    z1 = m(x1, y1)
    z2 = torch.einsum('zi,zij->zj', x1, m.right(y1))
    assert (z1 - z2).abs().max() < eps

    # variance
    x1 = torch.randn(n, m.irreps_in1.dim)
    y1 = torch.randn(n, m.irreps_in2.dim)
    z1 = m(x1, y1).var(0)
    assert z1.mean().log10().abs() < torch.tensor(tol).log10()

    # equivariance
    assert_equivariant(m,
                       irreps_in=[m.irreps_in1, m.irreps_in2],
                       irreps_out=m.irreps_out)

    if weight:
        # linear in weights
        w1 = m.weight.clone().normal_()
        w2 = m.weight.clone().normal_()
        z1 = m(x1, y1, weight=w1) + 1.5 * m(x1, y1, weight=w2)
        z2 = m(x1, y1, weight=w1 + 1.5 * w2)
        assert (z1 - z2).abs().max() < eps
Example #29
def test_norm_activation_equivariant(do_bias, nonlin):
    irreps_in = e3nn.o3.Irreps(
        # test lots of different irreps
        "2x0e + 3x0o + 5x1o + 1x1e + 2x2e + 1x2o + 1x3e + 1x3o + 1x5e + 1x6o"
    )

    norm_act = NormActivation(
        irreps_in=irreps_in,
        scalar_nonlinearity=nonlin,
        bias=do_bias
    )

    if do_bias:
        # Set up some nonzero biases
        assert len(list(norm_act.parameters())) == 1
        with torch.no_grad():
            norm_act.biases[:] = torch.randn(norm_act.biases.shape)

    assert_equivariant(norm_act)
    assert_auto_jitable(norm_act)
Example #30
def test_extract_ir():
    c = ExtractIr('1e + 0e + 0e', '0e')
    out = c(torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]))
    assert torch.all(out == torch.Tensor([1., 2.]))
    assert_auto_jitable(c)
    assert_equivariant(c)