def test_equal(self):
    """Contrast ``EnvironmentNode.__eq__`` with ``EnvironmentNode.everything_equal``.

    The reference node is built on site 0 with ``i_central_site=0`` and
    ``ce_symbol="T:4"``. Each case below gives the structure site index,
    ``i_central_site`` and ``ce_symbol`` of another node, followed by the
    expected ``==`` outcome and the expected ``everything_equal`` outcome.
    The cases show that ``==`` is insensitive to the central site and the
    ce_symbol but sensitive to ``i_central_site``.
    """
    structure = self.get_structure("SiO2")
    reference = EnvironmentNode(central_site=structure[0], i_central_site=0, ce_symbol="T:4")

    cases = (
        # Other central site, same index: equal, but not everything-equal.
        (2, 0, "T:4", True, False),
        # Same central site, other index: not equal.
        (0, 3, "T:4", False, False),
        # Other ce_symbol, same index: equal, but not everything-equal.
        (0, 0, "O:6", True, False),
        # Identical inputs: equal and everything-equal.
        (0, 0, "T:4", True, True),
    )
    for site_index, isite, symbol, expect_equal, expect_everything in cases:
        candidate = EnvironmentNode(
            central_site=structure[site_index], i_central_site=isite, ce_symbol=symbol
        )
        if expect_equal:
            assert reference == candidate
        else:
            assert reference != candidate
        assert bool(reference.everything_equal(candidate)) == expect_everything
def test_cycle(self):
    """Check equality of SimpleGraphCycle / MultiGraphCycle built from EnvironmentNodes.

    Uses plain ``assert`` statements, like the other tests here, instead of
    ``unittest`` assertion methods.
    """
    e1 = EnvironmentNode(central_site="Si", i_central_site=0, ce_symbol="T:4")
    e2 = EnvironmentNode(central_site="Si", i_central_site=3, ce_symbol="T:4")
    e3 = EnvironmentNode(central_site="Si", i_central_site=2, ce_symbol="T:4")
    e4 = EnvironmentNode(central_site="Si", i_central_site=5, ce_symbol="T:4")
    e5 = EnvironmentNode(central_site="Si", i_central_site=1, ce_symbol="T:4")

    # Tests of SimpleGraphCycle with EnvironmentNodes
    c1 = SimpleGraphCycle([e2])
    c2 = SimpleGraphCycle([e2])
    assert c1 == c2
    c1 = SimpleGraphCycle([e1])
    c2 = SimpleGraphCycle([e2])
    assert c1 != c2
    c1 = SimpleGraphCycle([e1, e2, e3])
    c2 = SimpleGraphCycle([e2, e1, e3])
    assert c1 == c2
    c2 = SimpleGraphCycle([e2, e3, e1])
    assert c1 == c2
    c1 = SimpleGraphCycle([e3, e2, e4, e1, e5])
    c2 = SimpleGraphCycle([e1, e4, e2, e3, e5])
    assert c1 == c2
    c2 = SimpleGraphCycle([e2, e3, e5, e1, e4])
    assert c1 == c2
    c1 = SimpleGraphCycle([e2, e3, e4, e1, e5])
    c2 = SimpleGraphCycle([e2, e3, e5, e1, e4])
    assert c1 != c2

    # Tests of MultiGraphCycle with EnvironmentNodes
    c1 = MultiGraphCycle([e1], [2])
    c2 = MultiGraphCycle([e1], [2])
    assert c1 == c2
    c2 = MultiGraphCycle([e1], [1])
    assert c1 != c2
    c2 = MultiGraphCycle([e2], [2])
    assert c1 != c2
    c1 = MultiGraphCycle([e1, e2], [0, 1])
    c2 = MultiGraphCycle([e1, e2], [1, 0])
    assert c1 == c2
    c2 = MultiGraphCycle([e2, e1], [1, 0])
    assert c1 == c2
    c2 = MultiGraphCycle([e2, e1], [0, 1])
    assert c1 == c2
    c2 = MultiGraphCycle([e2, e1], [2, 1])
    assert c1 != c2
    c1 = MultiGraphCycle([e1, e2, e3], [0, 1, 2])
    c2 = MultiGraphCycle([e2, e1, e3], [0, 2, 1])
    assert c1 == c2
    c2 = MultiGraphCycle([e2, e3, e1], [1, 2, 0])
    assert c1 == c2
def from_dict(cls, d):
    """
    Reconstructs the ConnectedComponent object from a dict representation of the
    ConnectedComponent object created using the as_dict method.

    Args:
        d (dict): dict representation of the ConnectedComponent object. Its
            "nodes" entry maps a node key to a (node_dict, node_data) pair, and its
            "graph" entry is a dict-of-dicts of edge data keyed by those node keys.

    Returns:
        ConnectedComponent: The connected component representing the links of a
            given set of environments.
    """
    # NOTE(review): this uses ``cls`` and must be bound as a classmethod; no
    # @classmethod decorator is visible in this chunk — confirm it exists.

    # One pass over the serialized nodes: rebuild each EnvironmentNode and keep
    # its attribute dict under the same (string) key.
    node_objects = {}
    node_attributes = {}
    for node_key, (node_dict, node_data) in d["nodes"].items():
        node_objects[node_key] = EnvironmentNode.from_dict(node_dict)
        node_attributes[node_key] = node_data

    # Convert edge keys and edge data back through the class helpers; presumably
    # these restore types (ints/tuples) lost by JSON/BSON serialization.
    dict_of_dicts = {
        source_key: {
            target_key: {
                cls._edgedictkey_to_edgekey(edge_key): cls._retuplify_edgedata(edge_data)
                for edge_key, edge_data in edges.items()
            }
            for target_key, edges in neighbours.items()
        }
        for source_key, neighbours in d["graph"].items()
    }

    multigraph = nx.from_dict_of_dicts(dict_of_dicts, create_using=nx.MultiGraph, multigraph_input=True)
    nx.set_node_attributes(multigraph, node_attributes)
    # Swap the string keys for the EnvironmentNode objects, in place.
    nx.relabel_nodes(multigraph, node_objects, copy=False)
    return cls(graph=multigraph)
def test_as_dict(self):
    """An EnvironmentNode survives a dict round trip and, when available, a BSON round trip."""
    structure = self.get_structure("SiO2")
    node = EnvironmentNode(central_site=structure[2], i_central_site=2, ce_symbol="T:4")

    # Plain dict round trip.
    restored = EnvironmentNode.from_dict(node.as_dict())
    assert node.everything_equal(restored)

    # The BSON round trip only runs when the module-level ``bson`` name is set.
    if bson is None:
        return
    encoded = bson.BSON.encode(node.as_dict())
    restored_from_bson = EnvironmentNode.from_dict(encoded.decode())
    assert node.everything_equal(restored_from_bson)
def test_periodicity(self):
    """Check 0D/1D/2D/3D periodicity detection of ConnectedComponent.

    Also checks the ValueErrors raised by
    ``compute_periodicity_all_simple_paths_algorithm`` for invalid self loops
    (repeated or opposite delta images, and a zero delta image).
    """
    en1 = EnvironmentNode(central_site="Si", i_central_site=3, ce_symbol="T:4")
    en2 = EnvironmentNode(central_site="Ag", i_central_site=5, ce_symbol="T:4")
    en3 = EnvironmentNode(central_site="Ag", i_central_site=8, ce_symbol="O:6")
    en4 = EnvironmentNode(central_site="Fe", i_central_site=23, ce_symbol="C:8")

    ligands_a = [(2, (0, 0, 1), (0, 0, 1)), (1, (0, 0, 1), (0, 0, 1))]
    ligands_b = [(10, (0, 0, 1), (0, 0, 1)), (11, (0, 0, 1), (0, 0, 1))]

    def build_component(nodes, edges):
        """Build a ConnectedComponent from ``nodes`` and edge 6-tuples.

        Each edge is ``(node_a, node_b, start_node, end_node, delta, ligands)``.
        The ``isite`` of ``start_node``/``end_node`` is stored as the edge's
        ``start``/``end``. Every edge gets its own copy of ``ligands`` so that
        edges never share a mutable list.
        """
        graph = nx.MultiGraph()
        graph.add_nodes_from(nodes)
        for node_a, node_b, start_node, end_node, delta, ligands in edges:
            graph.add_edge(
                node_a,
                node_b,
                start=start_node.isite,
                end=end_node.isite,
                delta=delta,
                ligands=list(ligands),
            )
        return ConnectedComponent(graph=graph)

    # 0D: a tree (no cycle). The en1-en3 edge ends on en3; it previously
    # recorded en2's index by copy-paste mistake.
    cc = build_component(
        [en1, en2, en3],
        [
            (en1, en2, en1, en2, (0, 0, 0), ligands_a),
            (en1, en3, en1, en3, (0, 0, 0), ligands_b),
        ],
    )
    assert cc.is_0d
    assert not cc.is_1d
    assert not cc.is_2d
    assert not cc.is_3d
    assert not cc.is_periodic
    assert cc.periodicity == "0D"

    # 1D: going around en1 -> en2 -> en3 -> en1 accumulates a (0, 0, 1) shift.
    cc = build_component(
        [en1, en2, en3],
        [
            (en1, en2, en1, en2, (0, 0, 0), ligands_a),
            (en1, en3, en1, en3, (0, 0, 0), ligands_b),
            (en2, en3, en2, en3, (0, 0, 1), ligands_a),
        ],
    )
    assert not cc.is_0d
    assert cc.is_1d
    assert not cc.is_2d
    assert not cc.is_3d
    assert cc.is_periodic
    assert cc.periodicity == "1D"

    # 0D: the same cycle, but the deltas along it cancel out.
    cc = build_component(
        [en1, en2, en3],
        [
            (en1, en2, en1, en2, (0, 0, 1), ligands_a),
            (en1, en3, en1, en3, (0, 0, 0), ligands_b),
            (en2, en3, en2, en3, (0, 0, -1), ligands_a),
        ],
    )
    assert cc.periodicity == "0D"

    # Test errors when computing periodicity
    same_or_opposite_msg = (
        r"There should not be self loops with the same "
        r"\x28or opposite\x29 delta image\x2E"
    )

    # Two self loops with the same delta image.
    cc = build_component(
        [en1, en2, en3],
        [
            (en1, en1, en1, en1, (0, 0, 1), ligands_a),
            (en1, en1, en1, en1, (0, 0, 1), ligands_a),
        ],
    )
    with pytest.raises(ValueError, match=same_or_opposite_msg):
        cc.compute_periodicity_all_simple_paths_algorithm()

    # Two self loops with opposite delta images.
    cc = build_component(
        [en1, en2, en3],
        [
            (en1, en1, en1, en1, (3, 2, -1), ligands_a),
            (en1, en1, en1, en1, (-3, -2, 1), ligands_a),
        ],
    )
    with pytest.raises(ValueError, match=same_or_opposite_msg):
        cc.compute_periodicity_all_simple_paths_algorithm()

    # A self loop with a zero delta image.
    cc = build_component(
        [en1, en2, en3],
        [(en1, en1, en1, en1, (0, 0, 0), ligands_a)],
    )
    with pytest.raises(
        ValueError,
        match=r"There should not be self loops with delta image = " r"\x280, 0, 0\x29\x2E",
    ):
        cc.compute_periodicity_all_simple_paths_algorithm()

    # Test a 2d periodicity
    edges_2d = [
        (en1, en2, en1, en2, (0, 0, 0), ligands_a),
        (en1, en3, en1, en3, (0, 0, 0), ligands_a),
        (en4, en2, en4, en2, (0, 0, 0), ligands_a),
        (en3, en4, en4, en3, (0, 0, 0), ligands_a),
        (en3, en4, en4, en3, (0, -1, 0), ligands_a),
        (en3, en2, en2, en3, (-1, -1, 0), ligands_a),
    ]
    cc = build_component([en1, en2, en3, en4], edges_2d)
    assert not cc.is_0d
    assert not cc.is_1d
    assert cc.is_2d
    assert not cc.is_3d
    assert cc.is_periodic
    assert cc.periodicity == "2D"
    assert np.allclose(cc.periodicity_vectors, [np.array([0, 1, 0]), np.array([1, 1, 0])])
    assert type(cc.periodicity_vectors) is list
    assert cc.periodicity_vectors[0].dtype is np.dtype(int)

    # Test a 3d periodicity: the 2D network plus a self loop on en3.
    edges_3d = [*edges_2d, (en3, en3, en3, en3, (-1, -1, -1), ligands_a)]
    cc = build_component([en1, en2, en3, en4], edges_3d)
    assert not cc.is_0d
    assert not cc.is_1d
    assert not cc.is_2d
    assert cc.is_3d
    assert cc.is_periodic
    assert cc.periodicity == "3D"
    assert np.allclose(
        cc.periodicity_vectors,
        [np.array([0, 1, 0]), np.array([1, 1, 0]), np.array([1, 1, 1])],
    )
    assert type(cc.periodicity_vectors) is list
    assert cc.periodicity_vectors[0].dtype is np.dtype(int)
def test_serialization(self):
    """A ConnectedComponent survives dict, JSON and (when available) BSON round trips.

    Checks node/edge counts, node identity, sorted edges, per-edge data, and
    that every reloaded node's central site is still a PeriodicSite.
    """
    lat = Lattice.hexagonal(a=2.0, c=2.5)
    en1 = EnvironmentNode(
        central_site=PeriodicSite("Si", coords=np.array([0.0, 0.0, 0.0]), lattice=lat),
        i_central_site=3,
        ce_symbol="T:4",
    )
    en2 = EnvironmentNode(
        central_site=PeriodicSite("Ag", coords=np.array([0.0, 0.0, 0.5]), lattice=lat),
        i_central_site=5,
        ce_symbol="T:4",
    )
    en3 = EnvironmentNode(
        central_site=PeriodicSite("Ag", coords=np.array([0.0, 0.5, 0.5]), lattice=lat),
        i_central_site=8,
        ce_symbol="O:6",
    )

    graph = nx.MultiGraph()
    graph.add_nodes_from([en1, en2, en3])
    graph.add_edge(
        en1,
        en2,
        start=en1.isite,
        end=en2.isite,
        delta=(0, 0, 0),
        ligands=[(2, (0, 0, 1), (0, 0, 1)), (1, (0, 0, 1), (0, 0, 1))],
    )
    # The en1-en3 edge ends on en3 (it previously recorded en2's index by
    # copy-paste mistake).
    graph.add_edge(
        en1,
        en3,
        start=en1.isite,
        end=en3.isite,
        delta=(0, 0, 0),
        ligands=[(10, (0, 0, 1), (0, 0, 1)), (11, (0, 0, 1), (0, 0, 1))],
    )
    cc = ConnectedComponent(graph=graph)

    ref_sorted_edges = [[en1, en2], [en1, en3]]
    sorted_edges = sorted(sorted(e) for e in cc.graph.edges())
    assert sorted_edges == ref_sorted_edges

    ccfromdict = ConnectedComponent.from_dict(cc.as_dict())
    ccfromjson = ConnectedComponent.from_dict(json.loads(json.dumps(cc.as_dict())))
    loaded_cc_list = [ccfromdict, ccfromjson]
    if bson is not None:
        bson_data = bson.BSON.encode(cc.as_dict())
        ccfrombson = ConnectedComponent.from_dict(bson_data.decode())
        loaded_cc_list.append(ccfrombson)

    for loaded_cc in loaded_cc_list:
        assert loaded_cc.graph.number_of_nodes() == 3
        assert loaded_cc.graph.number_of_edges() == 2
        assert set(cc.graph.nodes()) == set(loaded_cc.graph.nodes())
        assert sorted_edges == sorted(sorted(e) for e in loaded_cc.graph.edges())
        for node_a, node_b in sorted_edges:
            assert cc.graph[node_a][node_b] == loaded_cc.graph[node_a][node_b]
        for node in loaded_cc.graph.nodes():
            assert isinstance(node.central_site, PeriodicSite)
def test_str(self):
    """str(EnvironmentNode) shows the node index, the central element and the ce_symbol.

    Defined once: a second, identical ``test_str`` definition used to shadow
    this one silently.
    """
    s = self.get_structure("SiO2")
    en = EnvironmentNode(central_site=s[2], i_central_site=2, ce_symbol="T:4")
    assert str(en) == "Node #2 Si (T:4)"