Ejemplo n.º 1
0
def test_self_interaction():
    """Check that ``self_interaction`` toggles the self edges of ``dh.DataNeighbors``.

    Two points are placed closer together than ``r_max``.  With
    ``self_interaction=False`` the test expects only the two cross edges
    (0 -> 1 and 1 -> 0); with ``self_interaction=True`` it additionally
    expects the self edges (0 -> 0 and 1 -> 1).  Edge sets are compared with
    the ``edge_index_set_equiv`` helper.
    """
    positions = torch.tensor([
        [0, 0, 0],
        [1.11646617, 0.7894608, 1.93377613],
    ])
    cutoff = 2.5
    n_points = len(positions)

    # Self-interaction disabled: only the cross edges should be present.
    without_self = dh.DataNeighbors(
        x=torch.zeros(size=(n_points, 1)),
        pos=positions,
        r_max=cutoff,
        self_interaction=False,
    )
    expected_without_self = torch.LongTensor([[0, 1], [1, 0]])
    assert edge_index_set_equiv(without_self.edge_index, expected_without_self)

    # Self-interaction enabled: the self edges are expected as well.
    with_self = dh.DataNeighbors(
        x=torch.zeros(size=(n_points, 1)),
        pos=positions,
        r_max=cutoff,
        self_interaction=True,
    )
    expected_with_self = torch.LongTensor([[0, 0, 1, 1], [0, 1, 0, 1]])
    assert edge_index_set_equiv(with_self.edge_index, expected_with_self)
Ejemplo n.º 2
0
def process_molecule(pipeline, max_radius, Rs_in, Rs_out):
    """Convert molecules pulled from ``pipeline`` into ``dh.DataNeighbors`` objects.

    The function loops forever: it has no ``break`` or ``return``, so it only
    ends when an exception propagates out of it.

    Args:
        pipeline: object providing ``get_molecule()`` and
            ``put_data_neighbor(dn)``.
        max_radius: forwarded to ``dh.DataNeighbors`` as ``r_max``.
        Rs_in: forwarded to ``dh.DataNeighbors`` as ``Rs_in``.
        Rs_out: forwarded to ``dh.DataNeighbors`` as ``Rs_out``.

    Raises:
        AssertionError: if ``pipeline.get_molecule()`` returns something that
            is not a ``Molecule``.
    """
    while True:
        molecule = pipeline.get_molecule()
        assert isinstance(molecule, Molecule), (
            f"expected Molecule but got {type(molecule)} instead!"
        )

        # Features and weights are shared by every perturbed example of the
        # molecule, so they are converted once per molecule.
        node_features = torch.tensor(molecule.features, dtype=torch.float64)
        node_weights = torch.tensor(molecule.weights, dtype=torch.float64)

        for example_idx in range(len(molecule.perturbed_geometries)):
            geometry = torch.tensor(
                molecule.perturbed_geometries[example_idx, :, :],
                dtype=torch.float64,
            )
            shieldings = torch.tensor(
                molecule.perturbed_shieldings[example_idx],
                dtype=torch.float64,
            )
            # unsqueeze(-1) appends a trailing singleton axis to the targets.
            shieldings = shieldings.unsqueeze(-1)

            data_neighbors = dh.DataNeighbors(
                x=node_features,
                Rs_in=Rs_in,
                pos=geometry,
                r_max=max_radius,
                self_interaction=True,
                ID=molecule.ID,
                weights=node_weights,
                y=shieldings,
                Rs_out=Rs_out,
            )
            pipeline.put_data_neighbor(data_neighbors)
Ejemplo n.º 3
0
def test_relative_vecs():
    """Check the edge list and relative vectors produced by ``dh.DataNeighbors``.

    For two points within ``r_max`` of each other the test expects all four
    (source, target) pairs in ``edge_index``.  It also expects the
    ``edge_attr`` row of edge (a, b) to equal ``pos[b] - pos[a]``.
    """
    positions = torch.tensor([
        [0, 0, 0],
        [1.11646617, 0.7894608, 1.93377613],
    ])
    cutoff = 2.5
    data = dh.DataNeighbors(
        x=torch.zeros(size=(len(positions), 1)),
        pos=positions,
        r_max=cutoff,
    )

    expected_edges = torch.LongTensor([[0, 0, 1, 1], [0, 1, 0, 1]])
    assert edge_index_set_equiv(data.edge_index, expected_edges)

    def edge_vector(first, second):
        # First edge_attr row whose edge_index column is (first, second).
        selector = (data.edge_index[0] == first) & (data.edge_index[1] == second)
        return data.edge_attr[selector][0]

    assert torch.allclose(positions[1] - positions[0], edge_vector(0, 1))
    assert torch.allclose(positions[0] - positions[1], edge_vector(1, 0))
Ejemplo n.º 4
0
def test_silicon_neighbors():
    """Check periodic neighbor finding for a two-site cell.

    The expected edge list gives each site one self edge and four edges to
    the other site.  The same expectation is checked twice: once against the
    low-level ``dh._neighbor_list_and_relative_vec_lattice`` helper and once
    against ``dh.DataNeighbors`` built with ``cell=...`` and ``pbc=True``.
    """
    cell_vectors = torch.tensor([
        [3.34939851, 0, 1.93377613],
        [1.11646617, 3.1578432, 1.93377613],
        [0, 0, 3.86755226],
    ])
    positions = torch.tensor([
        [0, 0, 0],
        [1.11646617, 0.7894608, 1.93377613],
    ])
    cutoff = 2.5

    # Low-level helper; the relative vectors are not inspected here.
    found_edges, _relative_vecs = dh._neighbor_list_and_relative_vec_lattice(
        positions, cell_vectors, r_max=cutoff
    )
    expected_edges = torch.LongTensor([
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1, 0, 0, 0, 0],
    ])
    assert edge_index_set_equiv(found_edges, expected_edges)

    # The high-level wrapper with periodic boundaries must agree.
    periodic_data = dh.DataNeighbors(
        x=torch.zeros(size=(len(positions), 1)),
        pos=positions,
        r_max=cutoff,
        cell=cell_vectors,
        pbc=True,
    )
    assert edge_index_set_equiv(periodic_data.edge_index, expected_edges)
Ejemplo n.º 5
0
def test_data_helpers():
    """Smoke-test the neighbor helpers and data classes on random input.

    Nothing is asserted about the results; the test passes as long as none
    of the four calls raises.
    """
    n_points = 7
    Rs_in = [(3, 0), (1, 1)]
    cutoff = 1

    # Random draws are kept in the order lattice, positions, features.
    cell_vectors = torch.randn(3, 3)
    positions = torch.randn(n_points, 3)
    features = torch.randn(n_points, rs.dim(Rs_in))

    # Periodic variants first, then the non-periodic ones.
    dh._neighbor_list_and_relative_vec_lattice(positions, cell_vectors, cutoff)
    dh.DataPeriodicNeighbors(features, positions, cell_vectors, cutoff)
    dh._neighbor_list_and_relative_vec(positions, cutoff)
    dh.DataNeighbors(features, positions, cutoff)
Ejemplo n.º 6
0
def test_positions_grad():
    """Check that gradients flow from ``edge_attr`` back to the positions.

    For both ``dh.DataNeighbors`` and ``dh.DataPeriodicNeighbors`` the test
    requires that the input positions still require grad after construction,
    that ``edge_attr`` requires grad, and that ``torch.autograd.grad`` of the
    summed ``edge_attr`` with respect to the positions succeeds.
    """
    n_points = 7
    cell_vectors = torch.randn(3, 3)
    # requires_grad_ works in place and returns the same tensor.
    positions = torch.randn(n_points, 3).requires_grad_(True)
    Rs_in = [(3, 0), (1, 1)]
    features = torch.randn(n_points, rs.dim(Rs_in))
    cutoff = 1

    # Deferred constructors, so each object is only built when its turn comes.
    constructors = (
        lambda: dh.DataNeighbors(features, positions, cutoff),
        lambda: dh.DataPeriodicNeighbors(features, positions, cell_vectors, cutoff),
    )
    for construct in constructors:
        data = construct()
        assert positions.requires_grad
        assert data.edge_attr.requires_grad
        torch.autograd.grad(data.edge_attr.sum(), positions, create_graph=True)
Ejemplo n.º 7
0
def test_data_neighbors(example, Rs_in, Rs_out, max_radius, molecule_dict):
    """Rebuild a ``dh.DataNeighbors`` for ``example`` and compare the two.

    The molecule is looked up in ``molecule_dict`` by ``example.ID``.  A fresh
    ``dh.DataNeighbors`` is built from its first perturbed geometry and
    shieldings, then checked against ``example`` with
    ``compare_data_neighbors``.

    Args:
        example: existing data object; only its ``ID`` attribute is read here.
        Rs_in: forwarded to ``dh.DataNeighbors`` as ``Rs_in``.
        Rs_out: forwarded to ``dh.DataNeighbors`` as ``Rs_out``.
        max_radius: forwarded to ``dh.DataNeighbors`` as ``r_max``.
        molecule_dict: mapping from ID to molecule.
    """
    expected = example
    molecule = molecule_dict[expected.ID]

    node_features = torch.tensor(molecule.features, dtype=torch.float64)
    node_weights = torch.tensor(molecule.weights, dtype=torch.float64)

    # A 3-D geometry array is reduced to its first entry along axis 0.
    geometry = torch.tensor(molecule.perturbed_geometries, dtype=torch.float64)
    if geometry.ndim == 3:
        geometry = geometry[0]

    shieldings = torch.tensor(molecule.perturbed_shieldings, dtype=torch.float64)
    if shieldings.ndim == 3:
        # NOTE(review): this print looks like leftover debug output -- confirm
        # whether it can be removed.
        print("Hello!")
        # Take index 0 of the last axis.
        shieldings = shieldings[..., 0]
    if shieldings.ndim == 2:
        # Keep only the first entry along axis 0.
        shieldings = shieldings[0]

    rebuilt = dh.DataNeighbors(
        x=node_features,
        Rs_in=Rs_in,
        pos=geometry,
        r_max=max_radius,
        self_interaction=True,
        ID=molecule.ID,
        weights=node_weights,
        y=shieldings,
        Rs_out=Rs_out,
    )
    compare_data_neighbors(expected, rebuilt)
Ejemplo n.º 8
0
def main():
    """Train an ``MLNetwork`` on the tetris dataset and print progress.

    Each of the 100 optimisation steps:

    1. runs the network on the fixed training batch and sums the per-node
       outputs per shape with ``scatter_add``;
    2. rebuilds a second batch from a fresh ``get_dataset()`` call and runs
       the network on it without gradients (presumably a re-transformed copy
       of the same shapes -- confirm against ``get_dataset``);
    3. takes an Adam step on the mean squared error against ``labels``;
    4. prints wall time, step, loss, accuracy and the RMS difference between
       the two outputs (reported as "equivariance error").

    The default dtype is set to ``float64`` and CUDA is used when available.
    ``labels`` stays on the CPU for the accuracy computation and the final
    print, while a copy on ``device`` is used for the loss, so the script
    also runs when the network lives on the GPU.
    """
    torch.set_default_dtype(torch.float64)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)

    x = torch.ones(4, 1)
    Rs_in = [(1, 0, 1)]
    r_max = 1.1

    tetris, labels = get_dataset()
    tetris_dataset = [
        dh.DataNeighbors(x, shape, r_max, y=label)
        for shape, label in zip(tetris, labels)
    ]
    # The loss subtracts the labels from the on-device network output, so it
    # needs them on the same device. `.to` is a no-op when device is 'cpu'.
    labels_on_device = labels.to(device)

    Rs_out = [(1, 0, -1), (6, 0, 1)]
    lmax = 3

    f = MLNetwork(Rs_in, Rs_out, Convolution,
                  partial(make_gated_block, mul=16, lmax=lmax), 2)
    f = f.to(device)

    batch = Batch.from_data_list(tetris_dataset)
    batch = batch.to(device)
    # The training batch never changes, so its spherical harmonics are
    # computed once outside the loop.
    sh = rsh.spherical_harmonics_xyz(list(range(lmax + 1)), batch.edge_attr,
                                     'component')

    optimizer = torch.optim.Adam(f.parameters(), lr=3e-3)

    wall = time.perf_counter()
    for step in range(100):
        out = f(batch.x, batch.edge_index, batch.edge_attr, sh=sh, n_norm=3)
        out = scatter_add(out, batch.batch, dim=0)
        out = torch.tanh(out)

        # Accuracy is evaluated on the CPU against the CPU labels.
        acc = out.cpu().round().eq(labels).double().mean().item()

        # Second batch built from a fresh get_dataset() call, used only for
        # the equivariance-error report below.
        r_tetris_dataset = [
            dh.DataNeighbors(x, shape, r_max, y=label)
            for shape, label in zip(*get_dataset())
        ]
        r_batch = Batch.from_data_list(r_tetris_dataset)
        r_batch = r_batch.to(device)
        r_sh = rsh.spherical_harmonics_xyz(list(range(lmax + 1)),
                                           r_batch.edge_attr, 'component')

        with torch.no_grad():
            r_out = f(r_batch.x,
                      r_batch.edge_index,
                      r_batch.edge_attr,
                      sh=r_sh,
                      n_norm=3)
            r_out = scatter_add(r_out, r_batch.batch, dim=0)
            r_out = torch.tanh(r_out)

        loss = (out - labels_on_device).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(
            "wall={:.1f} step={} loss={:.2e} accuracy={:.2f} equivariance error={:.1e}"
            .format(time.perf_counter() - wall, step, loss.item(), acc,
                    (out - r_out).pow(2).mean().sqrt().item()))

    print(labels.numpy().round(1))
    # A CUDA tensor cannot be converted with .numpy() directly; copy it to
    # the host first (a no-op on CPU).
    print(out.detach().cpu().numpy().round(1))