Example 1
def test_bad_normalize():
    # A deliberately badly scaled map: with the arbitrary factor of 870 its
    # outputs cannot have unit second moment, so assert_normalized must reject it.
    def not_normal(x1):
        return 870.0 * x1.square().relu()

    not_normal.irreps_in = random_irreps(clean=True, allow_empty=False)
    not_normal.irreps_out = not_normal.irreps_in
    with pytest.raises(AssertionError):
        assert_normalized(not_normal)
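These snippets exercise the e3nn library's o3/nn modules and its test utilities, and they omit their imports. A sketch of the names they rely on follows; the module paths match recent e3nn releases, the commented-out `_Sortcut` location is an assumption, and `make_tp` (used in Example 6) is a helper defined locally in the test module rather than a library import.

import pytest
import torch

from e3nn import o3
from e3nn.o3 import Irreps
from e3nn.nn import Activation, Gate
from e3nn.util.test import (
    assert_auto_jitable,
    assert_equivariant,
    assert_normalized,
    random_irreps,
)
# Assumed location of the private helper exercised in Example 4:
# from e3nn.nn._gate import _Sortcut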
Example 2
def test_bias():
    irreps_in = o3.Irreps("2x0e + 1e + 2x0e + 0o")
    irreps_out = o3.Irreps("3x0e + 1e + 3x0e + 5x0e + 0o")
    m = o3.Linear(irreps_in,
                  irreps_out,
                  biases=[True, False, False, True, False])
    with torch.no_grad():
        m.bias[:].fill_(1.0)
    x = m(torch.zeros(irreps_in.dim))

    assert torch.allclose(
        x,
        torch.tensor([
            1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0,
            1.0, 0.0
        ]))

    assert_equivariant(m)
    assert_auto_jitable(m)

    m = o3.Linear("0e + 0o + 1e + 1o", "10x0e + 0o + 1e + 1o", biases=True)

    assert_equivariant(m)
    assert_auto_jitable(m)
    assert_normalized(m,
                      n_weight=100,
                      n_input=10_000,
                      atol=0.5,
                      weights=[m.weight])
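The expected vector in the first assertion follows from the bias layout: with a zero input, only the biased scalar blocks of `irreps_out` (the leading `3x0e` and the `5x0e`) come out as ones. The same expectation built programmatically, reusing the names from the snippet above (a sketch with hand-computed slice offsets):

expected = torch.zeros(irreps_out.dim)  # irreps_out.dim == 15
expected[0:3] = 1.0                     # first 3x0e block, biases[0] is True
expected[9:14] = 1.0                    # 5x0e block, biases[3] is True
assert torch.allclose(x, expected)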
Example 3
def test_linear():
    irreps_in = o3.Irreps("1e + 2e + 3x3o")
    irreps_out = o3.Irreps("1e + 2e + 3x3o")
    m = o3.Linear(irreps_in, irreps_out)
    m(torch.randn(irreps_in.dim))

    assert_equivariant(m)
    assert_auto_jitable(m)
    assert_normalized(m, n_weight=100, n_input=10_000, atol=0.5)
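`assert_equivariant` checks the transformation behaviour automatically; for rotations the same property can also be verified by hand with the Wigner-D matrices e3nn exposes. A minimal sketch using `o3.rand_matrix` and `Irreps.D_from_matrix`, reusing `m`, `irreps_in`, and `irreps_out` from the snippet above:

R = o3.rand_matrix()                # random rotation matrix
D_in = irreps_in.D_from_matrix(R)   # its representation on the input irreps
D_out = irreps_out.D_from_matrix(R) # ... and on the output irreps
x = irreps_in.randn(-1)
# rotating the input and then applying m must equal applying m and then rotating
assert torch.allclose(m(D_in @ x), D_out @ m(x), atol=1e-5)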
Example 4
def test_gate():
    irreps_scalars, act_scalars = Irreps("16x0o"), [torch.tanh]
    irreps_gates, act_gates = Irreps("32x0o"), [torch.tanh]
    irreps_gated = Irreps("16x1e+16x1o")

    sc = _Sortcut(irreps_scalars, irreps_gates)
    assert_auto_jitable(sc)

    g = Gate(irreps_scalars, act_scalars, irreps_gates, act_gates,
             irreps_gated)
    assert_equivariant(g)
    assert_auto_jitable(g)
    assert_normalized(g)
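The gate scalars are consumed by the nonlinearity, so the module's output irreps are the scalars followed by the gated irreps rather than everything that went in. A quick way to see this (a sketch relying on the module's `irreps_out` attribute):

print(g.irreps_out)  # 16x0o+16x1e+16x1o (the 32x0o gates do not appear in the output)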
Example 5
def test_activation(irreps_in, acts):
    irreps_in = o3.Irreps(irreps_in)
    a = Activation(irreps_in, acts)
    assert_auto_jitable(a)
    assert_equivariant(a)

    inp = irreps_in.randn(13, -1)
    out = a(inp)
    # Each output block must match act(input block) up to a single scalar factor
    # (the normalization constant the module folds into each activation).
    for ir_slice, act in zip(irreps_in.slices(), acts):
        this_out = out[:, ir_slice]
        true_up_to_factor = act(inp[:, ir_slice])
        factors = this_out / true_up_to_factor
        assert torch.allclose(factors, factors[0])

    assert_normalized(a)
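In the test suite this function is driven by a `pytest.mark.parametrize` decorator defined alongside it; called directly it would look something like the following (the argument values here are purely illustrative):

test_activation("2x0e + 3x0o", [torch.tanh, torch.tanh])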
Example 6
def test_normalized(l1, p1, l2, p2, lo, po, mode, weight):
    if torch.get_default_dtype() != torch.float32:
        pytest.skip(
            "No reason to re-run the expensive normalization tests in float64."
        )
    # Explicit fixed path weights screw with the output normalization,
    # so don't use them
    m = make_tp(l1,
                p1,
                l2,
                p2,
                lo,
                po,
                mode,
                weight,
                mul=5,
                path_weights=False)
    # normalization
    # n_weight and n_input have to be decently high to ensure statistical
    # convergence, especially for the 'uvuv' mode
    assert_normalized(m, n_weight=100, n_input=10_000, atol=0.5)
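`make_tp` is a helper defined in the same test module rather than part of the library; the object it returns is an `o3.TensorProduct`. A stand-alone stand-in for the kind of module being normalized might look like this (a sketch under that assumption, not the actual helper):

tp = o3.TensorProduct(
    "5x1e",                    # first input irreps (mul=5, l=1, even parity)
    "5x1e",                    # second input irreps
    "5x1e",                    # output irreps
    [(0, 0, 0, "uvu", True)],  # a single weighted path in 'uvu' mode
)
assert_normalized(tp, n_weight=100, n_input=10_000, atol=0.5)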
Example 7
def test_normalized_ident():
    # The identity map preserves the input distribution, so it trivially
    # satisfies the normalization check (the counterpart to Example 1).
    def ident(x1):
        return x1

    ident.irreps_in = random_irreps(clean=True, allow_empty=False)
    ident.irreps_out = ident.irreps_in
    assert_normalized(ident)