def test_self_interaction():
    """Self-loops appear in ``edge_index`` only when ``self_interaction=True``.

    The two atoms are neighbors under the chosen cutoff. Without
    self-interaction the graph holds only the two cross edges. With it, each
    atom additionally gets an edge to itself.
    """
    positions = torch.tensor(
        [
            [0, 0, 0],
            [1.11646617, 0.7894608, 1.93377613],
        ]
    )
    cutoff = 2.5
    n_atoms = len(positions)

    # Case 1: self-interaction disabled -> only i != j edges.
    without_loops = dh.DataNeighbors(
        x=torch.zeros(size=(n_atoms, 1)),
        pos=positions,
        r_max=cutoff,
        self_interaction=False,
    )
    expected_without = torch.LongTensor([[0, 1], [1, 0]])
    assert edge_index_set_equiv(without_loops.edge_index, expected_without)

    # Case 2: self-interaction enabled -> (0, 0) and (1, 1) are added.
    with_loops = dh.DataNeighbors(
        x=torch.zeros(size=(n_atoms, 1)),
        pos=positions,
        r_max=cutoff,
        self_interaction=True,
    )
    expected_with = torch.LongTensor([[0, 0, 1, 1], [0, 1, 0, 1]])
    assert edge_index_set_equiv(with_loops.edge_index, expected_with)
def process_molecule(pipeline, max_radius, Rs_in, Rs_out):
    """Worker loop converting molecules from ``pipeline`` into neighbor graphs.

    Repeatedly pulls a molecule with ``pipeline.get_molecule()``. For every
    perturbed geometry it carries, it builds one ``dh.DataNeighbors`` graph
    (with self-interaction edges) and hands it back through
    ``pipeline.put_data_neighbor``. The loop has no exit condition of its own;
    it runs until the process is stopped or one of the calls raises.

    Args:
        pipeline: object providing ``get_molecule()`` and
            ``put_data_neighbor(dn)``.
        max_radius: neighbor cutoff, passed to ``DataNeighbors`` as ``r_max``.
        Rs_in: input representation, passed through to ``DataNeighbors``.
        Rs_out: output representation, passed through to ``DataNeighbors``.

    Raises:
        AssertionError: if the pipeline yields an object that is not a
            ``Molecule``.
    """
    while True:
        molecule = pipeline.get_molecule()
        # Explicit check instead of ``assert`` so the validation still runs
        # under ``python -O``. AssertionError is kept so callers see the same
        # exception type as before.
        if not isinstance(molecule, Molecule):
            raise AssertionError(
                f"expected Molecule but got {type(molecule)} instead!"
            )

        # Per-molecule tensors, shared by every perturbed geometry below.
        features = torch.tensor(molecule.features, dtype=torch.float64)
        weights = torch.tensor(molecule.weights, dtype=torch.float64)

        n_examples = len(molecule.perturbed_geometries)
        for j in range(n_examples):
            g = torch.tensor(
                molecule.perturbed_geometries[j, :, :], dtype=torch.float64
            )
            # unsqueeze(-1) appends a trailing length-1 axis, e.g. [N] -> [N, 1]
            # (assumes perturbed_shieldings[j] is per-atom -- TODO confirm).
            s = torch.tensor(
                molecule.perturbed_shieldings[j], dtype=torch.float64
            ).unsqueeze(-1)
            dn = dh.DataNeighbors(
                x=features,
                Rs_in=Rs_in,
                pos=g,
                r_max=max_radius,
                self_interaction=True,
                ID=molecule.ID,
                weights=weights,
                y=s,
                Rs_out=Rs_out,
            )
            pipeline.put_data_neighbor(dn)
def test_relative_vecs():
    """Edge attributes store the relative vector ``pos[dst] - pos[src]``.

    With default ``DataNeighbors`` settings, the expected edge set for the two
    nearby atoms includes self-loops. For each cross edge, the attribute must
    equal the destination position minus the source position.
    """
    positions = torch.tensor(
        [
            [0, 0, 0],
            [1.11646617, 0.7894608, 1.93377613],
        ]
    )
    cutoff = 2.5

    graph = dh.DataNeighbors(
        x=torch.zeros(size=(len(positions), 1)),
        pos=positions,
        r_max=cutoff,
    )

    expected_edges = torch.LongTensor([[0, 0, 1, 1], [0, 1, 0, 1]])
    assert edge_index_set_equiv(graph.edge_index, expected_edges)

    # Check both directed cross edges, (0 -> 1) first, then (1 -> 0).
    src_row, dst_row = graph.edge_index[0], graph.edge_index[1]
    for src, dst in ((0, 1), (1, 0)):
        selected = (src_row == src) & (dst_row == dst)
        assert torch.allclose(
            positions[dst] - positions[src],
            graph.edge_attr[selected][0],
        )
def test_silicon_neighbors():
    """Periodic neighbor search on a two-atom cell yields a known edge set.

    The same edges are expected from the low-level lattice helper and from
    ``DataNeighbors`` built with ``cell`` and ``pbc=True``. Edges to periodic
    images show up as repeated ``(i, j)`` pairs.
    """
    cell = torch.tensor(
        [
            [3.34939851, 0, 1.93377613],
            [1.11646617, 3.1578432, 1.93377613],
            [0, 0, 3.86755226],
        ]
    )
    positions = torch.tensor(
        [
            [0, 0, 0],
            [1.11646617, 0.7894608, 1.93377613],
        ]
    )
    cutoff = 2.5
    expected = torch.LongTensor(
        [
            [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
            [0, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        ]
    )

    # Low-level helper; the relative vectors are not checked here.
    helper_edges, _ = dh._neighbor_list_and_relative_vec_lattice(
        positions, cell, r_max=cutoff
    )
    assert edge_index_set_equiv(helper_edges, expected)

    # High-level data object must agree with the helper.
    graph = dh.DataNeighbors(
        x=torch.zeros(size=(len(positions), 1)),
        pos=positions,
        r_max=cutoff,
        cell=cell,
        pbc=True,
    )
    assert edge_index_set_equiv(graph.edge_index, expected)
def test_data_helpers():
    """Smoke test: every neighbor helper accepts random input without raising.

    Covers both the periodic (lattice) and non-periodic code paths. No outputs
    are inspected.
    """
    n_atoms = 7
    cell = torch.randn(3, 3)
    positions = torch.randn(n_atoms, 3)
    Rs_in = [(3, 0), (1, 1)]
    features = torch.randn(n_atoms, rs.dim(Rs_in))
    cutoff = 1

    # Periodic path: raw helper, then the data wrapper.
    dh._neighbor_list_and_relative_vec_lattice(positions, cell, cutoff)
    dh.DataPeriodicNeighbors(features, positions, cell, cutoff)

    # Non-periodic path: raw helper, then the data wrapper.
    dh._neighbor_list_and_relative_vec(positions, cutoff)
    dh.DataNeighbors(features, positions, cutoff)
def test_positions_grad():
    """Edge attributes stay differentiable with respect to atom positions.

    Runs the same check for the non-periodic and the periodic data object:
    ``edge_attr`` must require grad, and a second-order-capable gradient
    w.r.t. ``pos`` must be computable.
    """
    n_atoms = 7
    cell = torch.randn(3, 3)
    positions = torch.randn(n_atoms, 3, requires_grad=True)
    Rs_in = [(3, 0), (1, 1)]
    features = torch.randn(n_atoms, rs.dim(Rs_in))
    cutoff = 1

    builders = (
        lambda: dh.DataNeighbors(features, positions, cutoff),
        lambda: dh.DataPeriodicNeighbors(features, positions, cell, cutoff),
    )
    for build in builders:
        graph = build()
        assert positions.requires_grad
        assert graph.edge_attr.requires_grad
        torch.autograd.grad(graph.edge_attr.sum(), positions, create_graph=True)
def test_data_neighbors(example, Rs_in, Rs_out, max_radius, molecule_dict):
    """Rebuilding ``DataNeighbors`` from the source molecule reproduces ``example``.

    Looks up the molecule that ``example`` was built from (via ``example.ID``)
    and reconstructs a ``DataNeighbors`` from its first perturbed geometry and
    shieldings. The two are then compared with ``compare_data_neighbors``.
    The arguments are presumably pytest fixtures -- confirm in conftest.
    """
    dn1 = example
    molecule = molecule_dict[dn1.ID]

    features = torch.tensor(molecule.features, dtype=torch.float64)
    weights = torch.tensor(molecule.weights, dtype=torch.float64)

    g = torch.tensor(molecule.perturbed_geometries, dtype=torch.float64)
    # If several geometries are stacked along axis 0, use the first one.
    if g.ndim == 3:
        g = g[0, ...]

    s = torch.tensor(molecule.perturbed_shieldings, dtype=torch.float64)
    # Keep only index 0 of the last axis, then the first example, so that `s`
    # ends up 1-D whatever rank the stored shieldings have.
    if s.ndim == 3:
        s = s[..., 0]
    if s.ndim == 2:
        s = s[0, ...]

    dn2 = dh.DataNeighbors(
        x=features,
        Rs_in=Rs_in,
        pos=g,
        r_max=max_radius,
        self_interaction=True,
        ID=molecule.ID,
        weights=weights,
        y=s,
        Rs_out=Rs_out,
    )
    compare_data_neighbors(dn1, dn2)
def _tetris_forward(model, batch, sh):
    """Run ``model`` on ``batch``, sum node outputs per graph, and squash with tanh."""
    node_out = model(batch.x, batch.edge_index, batch.edge_attr, sh=sh, n_norm=3)
    graph_out = scatter_add(node_out, batch.batch, dim=0)
    return torch.tanh(graph_out)


def main():
    """Train an equivariant network on the tetris shapes and monitor equivariance.

    Each step does the following:
    - runs the network on the fixed tetris batch;
    - takes an Adam step on the MSE against ``labels``;
    - regenerates the dataset with ``get_dataset()`` and evaluates the network
      on that second batch without gradients.

    Each step prints the loss, the accuracy, and the RMS difference between the
    two outputs (reported as "equivariance error"). At the end it prints the
    rounded labels and predictions.
    """
    torch.set_default_dtype(torch.float64)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)

    x = torch.ones(4, 1)
    Rs_in = [(1, 0, 1)]
    r_max = 1.1

    tetris, labels = get_dataset()
    tetris_dataset = [
        dh.DataNeighbors(x, shape, r_max, y=label)
        for shape, label in zip(tetris, labels)
    ]

    Rs_out = [(1, 0, -1), (6, 0, 1)]
    lmax = 3

    f = MLNetwork(
        Rs_in, Rs_out, Convolution,
        partial(make_gated_block, mul=16, lmax=lmax), 2,
    )
    f = f.to(device)

    batch = Batch.from_data_list(tetris_dataset).to(device)
    # Positions are fixed, so the spherical harmonics are computed once.
    sh = rsh.spherical_harmonics_xyz(
        list(range(lmax + 1)), batch.edge_attr, 'component'
    )
    # The loss is computed on `device`, so the targets must live there too.
    # The CPU `labels` stay in use for accuracy and for the final print.
    labels_dev = labels.to(device)

    optimizer = torch.optim.Adam(f.parameters(), lr=3e-3)

    wall = time.perf_counter()
    for step in range(100):
        out = _tetris_forward(f, batch, sh)

        acc = out.detach().cpu().round().eq(labels).double().mean().item()

        # Fresh dataset from get_dataset(), used for the equivariance check.
        r_tetris_dataset = [
            dh.DataNeighbors(x, shape, r_max, y=label)
            for shape, label in zip(*get_dataset())
        ]
        r_batch = Batch.from_data_list(r_tetris_dataset).to(device)
        r_sh = rsh.spherical_harmonics_xyz(
            list(range(lmax + 1)), r_batch.edge_attr, 'component'
        )
        with torch.no_grad():
            r_out = _tetris_forward(f, r_batch, r_sh)

        loss = (out - labels_dev).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(
            "wall={:.1f} step={} loss={:.2e} accuracy={:.2f} equivariance error={:.1e}"
            .format(
                time.perf_counter() - wall, step, loss.item(), acc,
                (out - r_out).pow(2).mean().sqrt().item(),
            )
        )

    print(labels.numpy().round(1))
    # .cpu() is required before .numpy() when training ran on CUDA.
    print(out.detach().cpu().numpy().round(1))