Beispiel #1
0
 def test_simple_molecule_graph(self):
     """Convert a linear C-H-O molecule and check atoms, bond values and both index lists."""
     molecule = Molecule(["C", "H", "O"], [[0, 0, 0], [1, 0, 0], [2, 0, 0]])
     converted = SimpleMolGraph().convert(molecule)
     # Every ordered atom pair (i, j) with i != j is expected as a bond.
     expected_indices = {
         "index1": [0, 0, 1, 1, 2, 2],
         "index2": [1, 2, 0, 2, 0, 1],
     }
     self.assertListEqual(to_list(converted["atom"]), [6, 1, 8])
     # Bond values mirror the pairwise separations along the x axis (1, 2, 1, ...).
     self.assertTrue(np.allclose(converted["bond"], [1, 2, 1, 1, 2, 1]))
     for key, expected in expected_indices.items():
         self.assertListEqual(to_list(converted[key]), expected)
Beispiel #2
0
 def test_simple_molecule_graph(self):
     """Check the graph produced by SimpleMolGraph for a three-atom chain."""
     species = ['C', 'H', 'O']
     coords = [[0, 0, 0], [1, 0, 0], [2, 0, 0]]
     result = SimpleMolGraph().convert(Molecule(species, coords))

     atom_numbers = to_list(result['atom'])
     self.assertListEqual(atom_numbers, [6, 1, 8])

     bond_ok = np.allclose(result['bond'], [1, 2, 1, 1, 2, 1])
     self.assertTrue(bond_ok)

     # Sender / receiver indices of the fully connected three-atom graph.
     self.assertListEqual(to_list(result['index1']), [0, 0, 1, 1, 2, 2])
     self.assertListEqual(to_list(result['index2']), [1, 2, 0, 2, 0, 1])
Beispiel #3
0
    def get_flat_data(self, graphs, targets=None):
        """
        Flatten a list of graph dictionaries into per-feature lists of arrays.
        Useful when the model is trained on graphs assembled on the fly.

        Args:
            graphs: (list of dictionary) one graph dictionary per structure; each
                must provide the keys 'atom', 'bond', 'state', 'index1', 'index2'
            targets: (list of float or list) Optional: corresponding target
                values for each structure

        Returns:
            tuple(node_features, edges_features, global_values, index1, index2[, targets])
            Each of the first five entries is a list holding one np.ndarray per
            graph. The targets entry (each target passed through to_list) is
            only appended when targets is not None.
        """
        feature_keys = ('atom', 'bond', 'state', 'index1', 'index2')

        # Feature-major layout: one list per feature, one array per graph inside it.
        flattened = [
            [np.array(graph[key]) for graph in graphs]
            for key in feature_keys
        ]

        if targets is not None:
            flattened.append([to_list(target) for target in targets])

        return tuple(flattened)
Beispiel #4
0
 def test_to_list(self):
     """Check that to_list returns a list for scalar, list, tuple and ndarray inputs."""
     candidates = [1, [1], tuple([1, 2, 3]), np.array([1, 2, 3])]
     for value in candidates:
         # assertIsInstance really checks the type; assertTrue(x, list) would treat
         # `list` as the failure message and pass for any truthy x.
         self.assertIsInstance(to_list(value), list)
Beispiel #5
0
 def test_crystal_graph_with_bond_types(self):
     """Bond values of a hand-built graph are replaced by integer bond types."""
     # Three atoms (Z = 11, 8, 8) and six bonds with distinct bond values.
     graph = dict(
         atom=[11, 8, 8],
         index1=[0, 0, 1, 1, 2, 2],
         index2=[0, 1, 2, 2, 1, 1],
         bond=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
         state=[[0, 0]],
     )
     converter = CrystalGraphWithBondTypes(nn_strategy='VoronoiNN')
     typed_graph = converter._get_bond_type(graph)
     bond_types = to_list(typed_graph['bond'])
     self.assertListEqual(bond_types, [2, 1, 0, 0, 0, 0])
Beispiel #6
0
 def test_crystal_graph_with_bond_types(self):
     """_get_bond_type maps the bond values of a small graph to integer types."""
     atoms = [11, 8, 8]
     bond_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
     graph = {
         "atom": atoms,
         "index1": [0, 0, 1, 1, 2, 2],
         "index2": [0, 1, 2, 2, 1, 1],
         "bond": bond_values,
         "state": [[0, 0]],
     }
     bond_typer = CrystalGraphWithBondTypes(nn_strategy="VoronoiNN")
     expected_types = [2, 1, 0, 0, 0, 0]
     self.assertListEqual(
         to_list(bond_typer._get_bond_type(graph)["bond"]), expected_types
     )
Beispiel #7
0
 def test_crystalgraph(self):
     """CrystalGraph stores its cutoff, emits the expected keys, and is callable."""
     structure = self.structures[0]
     expected_keys = {"bond", "atom", "index1", "index2", "state"}

     # Converter with cutoff 4: check the stored cutoff and the graph keys.
     converter_4 = CrystalGraph(cutoff=4)
     graph_4 = converter_4.convert(structure)
     self.assertEqual(converter_4.cutoff, 4)
     self.assertSetEqual(expected_keys, set(graph_4.keys()))

     # Converter with cutoff 6: its first state entry should be [0, 0].
     converter_6 = CrystalGraph(cutoff=6)
     self.assertEqual(converter_6.cutoff, 6)
     graph_6 = converter_6.convert(structure)
     self.assertListEqual(to_list(graph_6["state"][0]), [0, 0])

     # Calling the converter directly must yield the same atoms as convert().
     graph_called = converter_4(structure)
     np.testing.assert_almost_equal(graph_4["atom"], graph_called["atom"])
Beispiel #8
0
def index_rep_from_structure(structure, r=4):
    """
    Take a pymatgen structure and convert it to an index-type graph representation.

    For every site, the neighbors within distance ``r`` are collected and
    ordered by neighbor index. Each bond is stored as the center-atom index
    (index1), the neighbor-atom index (index2) and the separating distance.

    Args:
        structure: (pymatgen structure) structure to convert
        r: (float) distance cutoff passed to ``structure.get_all_neighbors``

    Returns:
        (dictionary or None) with keys
            'distance': np.ndarray of bond distances, concatenated over sites
            'index1': list of int, center-atom index of each bond
            'index2': list of int, neighbor-atom index of each bond
            'node': list of int, atomic number Z of every site
        None is returned when no site has any neighbor within ``r``.
    """
    atom_i_segment_id = []  # index list for the center atom i for all bonds (row index)
    atom_i_j_id = []  # index list for atom j
    atom_number = []
    distances = []
    all_neighbors = structure.get_all_neighbors(r, include_index=True)

    for k, neighbors in enumerate(all_neighbors):
        atom_number.append(structure.sites[k].specie.Z)
        if len(neighbors) < 1:
            # A site without neighbors still contributes a node, but no bonds.
            continue

        # Neighbor entries unpack as (site, distance, index, ...). Only the first
        # three fields are used, so entries carrying extra fields (e.g. a
        # periodic image) are accepted as well as plain 3-tuples.
        _, distance, index = list(zip(*neighbors))[:3]
        index = np.array(index)
        distance = np.array(distance)

        # Apply the same permutation to indices and distances so they stay paired.
        # Fancy indexing always returns an array, even for a single neighbor.
        order = np.argsort(index)
        atom_i_j_id.extend(int(j) for j in index[order])
        distances.append(distance[order])
        atom_i_segment_id.extend([k] * len(order))

    if not distances:
        return None
    return {'distance': np.concatenate(distances),
            'index1': atom_i_segment_id,
            'index2': atom_i_j_id,
            'node': atom_number}
Beispiel #9
0
 def test_convert(self):
     """The 'atom' entry of a converted graph lists the Z of every site."""
     structure = self.structures[0]
     graph = CrystalGraph(cutoff=4).convert(structure)
     expected_atoms = [site.specie.Z for site in structure]
     self.assertListEqual(to_list(graph["atom"]), expected_atoms)