Example #1
0
    def test_SO3Vec_init_channels(self, batch, maxl, channels):
        """Check that SO3Vec.rand reports back the uniform tau it was built from.

        The tau list has maxl + 1 entries, every one equal to ``channels``.
        """
        expected_tau = [channels for _ in range(maxl + 1)]
        rand_vec = SO3Vec.rand(batch, expected_tau)
        assert rand_vec.tau == expected_tau
Example #2
0
    def test_SO3Vec_init_arb_tau(self, batch, maxl, channels):
        """Check that SO3Vec.rand reports back an arbitrary tau it was built from.

        The tau list has maxl + 1 entries, each drawn uniformly from the
        integers 1..channels (torch.randint's upper bound is exclusive).
        """
        # Convert to a plain list of ints so the tau passed to SO3Vec.rand and
        # the equality check below use the same list form as the
        # uniform-channel test, instead of depending on how tau compares
        # against a multi-element tensor (which may not reduce to a bool).
        tau_list = torch.randint(1, channels + 1, [maxl + 1]).tolist()

        test_vec = SO3Vec.rand(batch, tau_list)

        assert test_vec.tau == tau_list
Example #3
0
def test_mix_SO3Vec(batch, maxl, channels1, channels2):
    """Mix a random SO3Vec with a random SO3Weight and check the output tau.

    Input and output tau lists have maxl + 1 entries, with ``channels1``
    channels in and ``channels2`` channels out per entry.
    """
    tau_in = [channels1] * (maxl + 1)
    tau_out = [channels2] * (maxl + 1)

    test_vec = SO3Vec.rand(batch, tau_in)
    test_weight = SO3Weight.rand(tau_in, tau_out)

    mixed = mix(test_weight, test_vec)

    # NOTE(review): assumes mix() returns an SO3Vec whose channel counts are
    # the weight's output channels (tau_out) -- confirm against mix().
    assert mixed.tau == tau_out
Example #4
0
    def test_cg_product_dict_maxl(self, maxl_dict, maxl_prod, maxl1, maxl2,
                                  chan, batch):
        """Check cg_product against the maxl supported by its CG dictionary.

        If the dictionary's maxl covers maxl_prod, maxl1 and maxl2, the
        product must succeed and its output tau must match cg_product_tau.
        Otherwise cg_product must raise a ValueError whose message starts
        with 'CG Dictionary maxl'.
        """
        cg_dict = CGDict(maxl=maxl_dict, dtype=torch.double)

        tau1, tau2 = [chan] * (maxl1 + 1), [chan] * (maxl2 + 1)

        rep1 = SO3Vec.rand(batch, tau1, dtype=torch.double)
        rep2 = SO3Vec.rand(batch, tau2, dtype=torch.double)

        if all(maxl_dict >= maxl for maxl in [maxl_prod, maxl1, maxl2]):
            cg_prod = cg_product(cg_dict, rep1, rep2, maxl=maxl_prod)

            # The type check lives on the success path: placed inside the
            # pytest.raises block it would never run, since cg_product
            # raises before reaching it.
            tau_out = cg_prod.tau
            # cg_product_tau is called without a maxl cap, so keep only the
            # first maxl_prod + 1 entries.
            # NOTE(review): assumes cg_product emits entries
            # 0..min(maxl1 + maxl2, maxl_prod) -- confirm.
            tau_pred = list(cg_product_tau(tau1, tau2))[:maxl_prod + 1]

            # Test to make sure the output type matches the expected output type
            assert list(tau_out) == tau_pred

        else:
            with pytest.raises(ValueError) as e_info:
                cg_product(cg_dict, rep1, rep2, maxl=maxl_prod)

            assert str(e_info.value).startswith('CG Dictionary maxl')
Example #5
0
 def test_apply_euler(self, batch, channels, maxl):
     """Smoke test: apply SO3WignerD.euler(maxl) to a random double SO3Vec.

     Only checks that so3_torch.apply_wigner runs; the result is not inspected.
     """
     dtype = torch.double
     channel_list = [channels] * (maxl + 1)
     tau = SO3Tau(channel_list)
     # Keep construction order: both constructors may draw from torch's RNG.
     rep = SO3Vec.rand(batch, tau, dtype=dtype)
     wigner_d = SO3WignerD.euler(maxl, dtype=dtype)
     so3_torch.apply_wigner(wigner_d, rep)