Example #1
0
    def test_save_reload_then_reconcile_bnn_graph(self):
        """Save and reload a BNN model, then reconcile it with a freshly built BNN.

        ``m1`` is written to ``self.TESTFILE`` and loaded back into an empty
        ``Model``. The test checks that the component uuids survive the round
        trip. A second, independently built model ``m2`` is then reconciled
        against the reloaded graph. The resulting ``component_map`` must be
        one-to-one, its keys must equal ``m1``'s component uuids, and its values
        must equal ``m2``'s component uuids.
        """
        import os
        from mxfusion.util.serialization import ModelComponentDecoder, load_json_file

        def _remove_testfile():
            # Guarded so an already-removed file does not turn cleanup into an error.
            if os.path.exists(self.TESTFILE):
                os.remove(self.TESTFILE)

        # Register removal before writing. Removing at the end of the test only
        # happens when every assertion passes, which leaks the file on failure.
        self.addCleanup(_remove_testfile)

        m1, _ = self.make_bnn_model(self.make_net())
        FactorGraph.save(self.TESTFILE, m1.as_json())
        m1_loaded = Model()
        FactorGraph.load_graphs(
            load_json_file(self.TESTFILE, ModelComponentDecoder), [m1_loaded])
        self.assertEqual(set(m1.components), set(m1_loaded.components))

        m2, _ = self.make_bnn_model(self.make_net())
        component_map = mf.models.FactorGraph.reconcile_graphs([m2], m1_loaded)
        self.assertEqual(
            len(component_map.values()), len(set(component_map.values())),
            "Assert there are only 1:1 mappings.")
        self.assertEqual(len(component_map), len(m1.components))

        # Compare whole uuid sets. Zipping two sorted lists silently truncates to
        # the shorter one and can miss missing or extra entries.
        m1_uuids = {component.uuid for component in m1.components.values()}
        m2_uuids = {component.uuid for component in m2.components.values()}
        self.assertEqual(set(component_map.keys()), m1_uuids)
        self.assertEqual(set(component_map.values()), m2_uuids)
Example #2
0
    def test_save_reload_then_reconcile_gp_module(self):
        """Save and reload a GP regression model, then reconcile it with a fresh one.

        ``m1`` is written to ``self.TESTFILE`` and loaded back into an empty
        ``Model``. The test checks that:

        - the top-level component uuids are restored;
        - the GP factor's ``_module_graph`` is restored with the same number of
          components;
        - the factor's first ``_extra_graphs`` entry is restored with the same
          number of components.

        A freshly built ``m2`` is then reconciled against the reloaded graph.
        The ``component_map`` must be one-to-one. Its keys must equal every uuid
        of ``m1``, including the module's inner graphs, and its values must equal
        every uuid of ``m2``.
        """
        import os

        def _remove_testfile():
            # Guarded so an already-removed file does not turn cleanup into an error.
            if os.path.exists(self.TESTFILE):
                os.remove(self.TESTFILE)

        # Register removal before writing so a failing assertion cannot leak the file.
        self.addCleanup(_remove_testfile)

        def _all_uuids(model):
            """Return uuids of the model's components plus the GP factor's inner graphs."""
            factor = model.Y.factor
            components = set(model.components.values())
            components |= set(factor._module_graph.components.values())
            components |= set(factor._extra_graphs[0].components.values())
            return {component.uuid for component in components}

        m1 = self.make_gpregr_model()
        FactorGraph.save(self.TESTFILE, m1.as_json())
        m1_loaded = Model()
        FactorGraph.load_graphs(self.TESTFILE, [m1_loaded])
        self.assertEqual(set(m1.components), set(m1_loaded.components))

        loaded_factor = m1_loaded[m1.Y.factor.uuid]
        self.assertEqual(
            len(set(m1.Y.factor._module_graph.components)),
            len(set(loaded_factor._module_graph.components)))
        self.assertEqual(
            len(set(m1.Y.factor._extra_graphs[0].components)),
            len(set(loaded_factor._extra_graphs[0].components)))

        m2 = self.make_gpregr_model()
        component_map = mf.models.FactorGraph.reconcile_graphs([m2], m1_loaded)
        self.assertEqual(
            len(component_map.values()), len(set(component_map.values())),
            "Assert there are only 1:1 mappings.")

        # Compare whole uuid sets. Zipping two sorted lists silently truncates to
        # the shorter one, and this test has no separate length check.
        self.assertEqual(set(component_map.keys()), _all_uuids(m1))
        self.assertEqual(set(component_map.values()), _all_uuids(m2))
Example #3
0
    def test_save_reload_bnn_graph(self):
        """Save and reload a BNN model and check its graph is preserved.

        The reloaded model must have the same component uuids and the same
        ``components_graph`` edges as the original. Every entry in its
        ``components`` mapping must also be a distinct object.
        """
        import os
        from mxfusion.util.serialization import ModelComponentDecoder, load_json_file

        def _remove_testfile():
            # Guarded so an already-removed file does not turn cleanup into an error.
            if os.path.exists(self.TESTFILE):
                os.remove(self.TESTFILE)

        # Register removal before writing so a failing assertion cannot leak the file.
        self.addCleanup(_remove_testfile)

        m1, _ = self.make_bnn_model(self.make_net())
        FactorGraph.save(self.TESTFILE, m1.as_json())
        m1_loaded = Model()
        FactorGraph.load_graphs(
            load_json_file(self.TESTFILE, ModelComponentDecoder), [m1_loaded])
        m1_loaded_edges = set(m1_loaded.components_graph.edges())
        m1_edges = set(m1.components_graph.edges())

        self.assertEqual(set(m1.components), set(m1_loaded.components))
        # On failure, report exactly which edges differ between the two graphs.
        self.assertEqual(m1_edges, m1_loaded_edges,
                         m1_edges.symmetric_difference(m1_loaded_edges))
        # Loading must not produce duplicate component objects.
        self.assertEqual(len(m1_loaded.components.values()),
                         len(set(m1_loaded.components.values())))
Example #4
0
    def test_save_reload_bnn_graph(self):
        """Save and reload a BNN model and check its graph is preserved.

        The reloaded model must have the same component uuids and the same
        ``components_graph`` edges as the original. Every entry in its
        ``components`` mapping must also be a distinct object.
        """
        import os

        def _remove_testfile():
            # Guarded so an already-removed file does not turn cleanup into an error.
            if os.path.exists(self.TESTFILE):
                os.remove(self.TESTFILE)

        # Register removal before writing so a failing assertion cannot leak the file.
        self.addCleanup(_remove_testfile)

        m1, _ = self.make_bnn_model(self.make_net())
        FactorGraph.save(self.TESTFILE, m1.as_json())
        m1_loaded = Model()
        FactorGraph.load_graphs(self.TESTFILE, [m1_loaded])
        m1_loaded_edges = set(m1_loaded.components_graph.edges())
        m1_edges = set(m1.components_graph.edges())

        self.assertEqual(set(m1.components), set(m1_loaded.components))
        # On failure, report exactly which edges differ between the two graphs.
        self.assertEqual(m1_edges, m1_loaded_edges,
                         m1_edges.symmetric_difference(m1_loaded_edges))
        # Loading must not produce duplicate component objects.
        self.assertEqual(len(m1_loaded.components.values()),
                         len(set(m1_loaded.components.values())))
Example #5
0
    def test_save_reload_then_reconcile_simple_graph(self):
        """Save and reload a simple model, then reconcile it with a fresh one.

        ``m1`` is written to ``self.TESTFILE`` and loaded back into an empty
        ``Model``. The test checks that the component uuids survive the round
        trip. A freshly built ``m2`` is then reconciled against the reloaded
        graph. The ``component_map`` must be one-to-one, its keys must equal
        ``m1``'s component uuids, and its values must equal ``m2``'s component
        uuids.
        """
        import os

        def _remove_testfile():
            # Guarded so an already-removed file does not turn cleanup into an error.
            if os.path.exists(self.TESTFILE):
                os.remove(self.TESTFILE)

        # Register removal before writing so a failing assertion cannot leak the file.
        self.addCleanup(_remove_testfile)

        m1 = self.make_simple_model()
        FactorGraph.save(self.TESTFILE, m1.as_json())
        m1_loaded = Model()
        FactorGraph.load_graphs(self.TESTFILE, [m1_loaded])
        self.assertEqual(set(m1.components), set(m1_loaded.components))

        m2 = self.make_simple_model()
        component_map = mf.models.FactorGraph.reconcile_graphs([m2], m1_loaded)
        # Same one-to-one requirement the other reconcile tests assert.
        self.assertEqual(
            len(component_map.values()), len(set(component_map.values())),
            "Assert there are only 1:1 mappings.")
        self.assertEqual(len(component_map), len(m1.components))

        # Compare whole uuid sets. Zipping two sorted lists silently truncates to
        # the shorter one and can miss missing or extra entries.
        m1_uuids = {component.uuid for component in m1.components.values()}
        m2_uuids = {component.uuid for component in m2.components.values()}
        self.assertEqual(set(component_map.keys()), m1_uuids)
        self.assertEqual(set(component_map.values()), m2_uuids)