def test_save_reload_then_reconcile_bnn_graph(self):
    """Round-trip a BNN model through ``self.TESTFILE`` and reconcile it with a fresh build.

    ``m1`` is saved, then loaded into an empty ``Model``; the loaded model must
    have the same component uuids as ``m1``. An independently built ``m2`` is
    then reconciled against the loaded model. The resulting ``component_map``
    must be 1:1, have one entry per component of ``m1``, have keys equal to the
    uuids of ``m1``, and have values equal to the uuids of ``m2``.
    """
    import os
    from mxfusion.util.serialization import ModelComponentDecoder, load_json_file

    m1, _ = self.make_bnn_model(self.make_net())
    FactorGraph.save(self.TESTFILE, m1.as_json())
    # try/finally so the file is removed even when an assertion fails;
    # otherwise a stale TESTFILE leaks into later tests.
    try:
        m1_loaded = Model()
        # NOTE(review): some tests in this file pass ``self.TESTFILE`` directly to
        # ``load_graphs`` instead of pre-parsed JSON; confirm which signature the
        # current ``FactorGraph.load_graphs`` expects.
        FactorGraph.load_graphs(
            load_json_file(self.TESTFILE, ModelComponentDecoder), [m1_loaded])
        self.assertEqual(set(m1.components), set(m1_loaded.components))

        m2, _ = self.make_bnn_model(self.make_net())
        component_map = mf.models.FactorGraph.reconcile_graphs([m2], m1_loaded)
        mapped_uuids = list(component_map.values())
        self.assertEqual(len(mapped_uuids), len(set(mapped_uuids)),
                         "Assert there are only 1:1 mappings.")
        self.assertEqual(len(component_map), len(m1.components))

        # Compare whole sorted lists rather than zip()-ing them: zip stops at the
        # shorter input and would hide a missing or extra mapping.
        m1_uuids = sorted({component.uuid for component in m1.components.values()})
        m2_uuids = sorted({component.uuid for component in m2.components.values()})
        self.assertEqual(m2_uuids, sorted(set(mapped_uuids)))
        self.assertEqual(m1_uuids, sorted(set(component_map.keys())))
    finally:
        os.remove(self.TESTFILE)
def test_save_reload_then_reconcile_gp_module(self):
    """Round-trip a GP regression model (with a module) and reconcile it with a fresh build.

    Checks the following:
    - after save/load, the top-level component uuids match;
    - the GP factor's ``_module_graph`` and ``_extra_graphs[0]`` have the same
      number of components before and after the round trip;
    - reconciling an independently built ``m2`` against the loaded model
      yields a 1:1 ``component_map``;
    - the keys of that map are exactly the uuids of ``m1`` (top level plus both
      module graphs);
    - its values are exactly the corresponding uuids of ``m2``.
    """
    import os

    def sorted_uuids_with_module_graphs(model):
        # Uuids of the top-level components plus those inside the GP module's graphs.
        gp_factor = model.Y.factor
        components = (set(model.components.values())
                      | set(gp_factor._module_graph.components.values())
                      | set(gp_factor._extra_graphs[0].components.values()))
        return sorted({component.uuid for component in components})

    m1 = self.make_gpregr_model()
    FactorGraph.save(self.TESTFILE, m1.as_json())
    # try/finally so the file is removed even when an assertion fails.
    try:
        m1_loaded = Model()
        FactorGraph.load_graphs(self.TESTFILE, [m1_loaded])
        self.assertEqual(set(m1.components), set(m1_loaded.components))
        loaded_factor = m1_loaded[m1.Y.factor.uuid]
        self.assertEqual(
            len(set(m1.Y.factor._module_graph.components)),
            len(set(loaded_factor._module_graph.components)))
        self.assertEqual(
            len(set(m1.Y.factor._extra_graphs[0].components)),
            len(set(loaded_factor._extra_graphs[0].components)))

        m2 = self.make_gpregr_model()
        component_map = mf.models.FactorGraph.reconcile_graphs([m2], m1_loaded)
        mapped_uuids = list(component_map.values())
        self.assertEqual(len(mapped_uuids), len(set(mapped_uuids)),
                         "Assert there are only 1:1 mappings.")

        # Whole-list equality instead of zip(): zip truncates to the shorter
        # list, and this test has no separate length check on component_map.
        self.assertEqual(sorted_uuids_with_module_graphs(m2), sorted(set(mapped_uuids)))
        self.assertEqual(sorted_uuids_with_module_graphs(m1),
                         sorted(set(component_map.keys())))
    finally:
        os.remove(self.TESTFILE)
def test_save_reload_bnn_graph(self):
    """Save a BNN model to ``self.TESTFILE``, reload it, and compare structure.

    The reloaded model must have the same component uuids and the same
    ``components_graph`` edge set as the original, with no component object
    repeated under two keys.

    NOTE(review): a method with this exact name is defined again further down in
    this file. If both live in the same class, the later definition replaces
    this one and this version never runs. Rename or remove one of them.
    """
    import os
    from mxfusion.util.serialization import ModelComponentDecoder, load_json_file

    m1, _ = self.make_bnn_model(self.make_net())
    FactorGraph.save(self.TESTFILE, m1.as_json())
    # try/finally so the file is removed even when an assertion fails.
    try:
        m1_loaded = Model()
        FactorGraph.load_graphs(
            load_json_file(self.TESTFILE, ModelComponentDecoder), [m1_loaded])
        self.assertEqual(set(m1.components), set(m1_loaded.components))

        m1_edges = set(m1.components_graph.edges())
        m1_loaded_edges = set(m1_loaded.components_graph.edges())
        self.assertTrue(m1_edges == m1_loaded_edges,
                        m1_edges.symmetric_difference(m1_loaded_edges))

        # No component object may be registered under more than one key.
        loaded_components = list(m1_loaded.components.values())
        self.assertEqual(len(loaded_components), len(set(loaded_components)))
    finally:
        os.remove(self.TESTFILE)
def test_save_reload_bnn_graph(self):
    """Save a BNN model to ``self.TESTFILE``, reload it, and compare structure.

    The reloaded model must have the same component uuids and the same
    ``components_graph`` edge set as the original, with no component object
    repeated under two keys.

    NOTE(review): a method with this exact name is also defined earlier in this
    file. If both live in the same class, this later definition is the one that
    takes effect. Rename or remove one of them.
    """
    import os

    m1, _ = self.make_bnn_model(self.make_net())
    FactorGraph.save(self.TESTFILE, m1.as_json())
    # try/finally so the file is removed even when an assertion fails.
    try:
        m1_loaded = Model()
        FactorGraph.load_graphs(self.TESTFILE, [m1_loaded])
        self.assertEqual(set(m1.components), set(m1_loaded.components))

        m1_edges = set(m1.components_graph.edges())
        m1_loaded_edges = set(m1_loaded.components_graph.edges())
        self.assertTrue(m1_edges == m1_loaded_edges,
                        m1_edges.symmetric_difference(m1_loaded_edges))

        # No component object may be registered under more than one key.
        loaded_components = list(m1_loaded.components.values())
        self.assertEqual(len(loaded_components), len(set(loaded_components)))
    finally:
        os.remove(self.TESTFILE)
def test_save_reload_then_reconcile_simple_graph(self):
    """Round-trip a simple model through ``self.TESTFILE`` and reconcile it with a fresh build.

    After save/load, the component uuids must match. Reconciling an
    independently built ``m2`` against the loaded model must yield a 1:1
    ``component_map`` with one entry per component of ``m1``. Its keys must be
    the uuids of ``m1`` and its values the uuids of ``m2``.
    """
    import os

    m1 = self.make_simple_model()
    FactorGraph.save(self.TESTFILE, m1.as_json())
    # try/finally so the file is removed even when an assertion fails.
    try:
        m1_loaded = Model()
        FactorGraph.load_graphs(self.TESTFILE, [m1_loaded])
        self.assertEqual(set(m1.components), set(m1_loaded.components))

        m2 = self.make_simple_model()
        component_map = mf.models.FactorGraph.reconcile_graphs([m2], m1_loaded)
        mapped_uuids = list(component_map.values())
        # Same 1:1 check the other reconcile tests perform.
        self.assertEqual(len(mapped_uuids), len(set(mapped_uuids)),
                         "Assert there are only 1:1 mappings.")
        self.assertEqual(len(component_map), len(m1.components))

        # Whole-list equality instead of zip(), which truncates silently.
        m1_uuids = sorted({component.uuid for component in m1.components.values()})
        m2_uuids = sorted({component.uuid for component in m2.components.values()})
        self.assertEqual(m2_uuids, sorted(set(mapped_uuids)))
        self.assertEqual(m1_uuids, sorted(set(component_map.keys())))
    finally:
        os.remove(self.TESTFILE)