def test_crystal_2(data):
    """Check CrystalGraphFeaturizer output shapes with non-default settings.

    Uses ``max_num_nbr=15``, ``radius=10`` and ``atom_feature='elements'``.
    ``data`` is a pytest fixture defined elsewhere; ``data[0]`` is presumably
    a single crystal structure with 16 sites (TODO confirm against fixture).
    """
    featurizer = CrystalGraphFeaturizer(
        max_num_nbr=15, radius=10, atom_feature='elements'
    )

    # Per-atom feature matrix: one row per site, 58 columns for 'elements'.
    node_feats = featurizer.node_features(data[0])
    assert isinstance(node_feats, torch.Tensor)
    assert node_feats.shape == (16, 58)

    # Edge features and neighbour indices; the second axis matches max_num_nbr.
    edge_feats, nbr_idx = featurizer.edge_features(data[0])
    expectations = (
        (edge_feats, (16, 15, 51)),
        (nbr_idx, (16, 15)),
    )
    for tensor, expected_shape in expectations:
        assert isinstance(tensor, torch.Tensor)
        assert tensor.shape == expected_shape
def test_crystal_1(data):
    """Check CrystalGraphFeaturizer labels and output shapes with defaults.

    ``data`` is a pytest fixture defined elsewhere; ``data[0]`` is presumably
    a single crystal structure with 16 sites (TODO confirm against fixture).
    """
    featurizer = CrystalGraphFeaturizer()
    assert featurizer.feature_labels == ['atom_feature', 'bond_feature']

    # Per-atom feature matrix with the default atom featurization (92 columns).
    node_feats = featurizer.node_features(data[0])
    assert isinstance(node_feats, torch.Tensor)
    assert node_feats.shape == (16, 92)

    # Edge features and neighbour indices; 12 neighbours per site by default.
    edge_feats, nbr_idx = featurizer.edge_features(data[0])
    expectations = (
        (edge_feats, (16, 12, 41)),
        (nbr_idx, (16, 12)),
    )
    for tensor, expected_shape in expectations:
        assert isinstance(tensor, torch.Tensor)
        assert tensor.shape == expected_shape