# Example #1
# 0
class TestSpineDBFetcher(unittest.TestCase):
    """Tests for the fetcher returned by ``SpineDBManager.get_fetcher()``.

    Each test imports a small amount of data into an in-memory SQLite database,
    runs the fetcher with a mock listener, and checks both the items handed to
    the listener's ``receive_*_added`` callbacks and the items returned by
    ``SpineDBManager.get_item()`` afterwards.
    """

    @classmethod
    def setUpClass(cls):
        # The DB manager is a Qt object; make sure an application instance exists.
        if not QApplication.instance():
            QApplication()

    def setUp(self):
        app_settings = MagicMock()
        self._logger = MagicMock()  # Collects error messages therefore handy for debugging.
        self._db_mngr = SpineDBManager(app_settings, None)
        self._db_map = self._db_mngr.get_db_map("sqlite://",
                                                self._logger,
                                                codename="test_db",
                                                create=True)
        self._listener = MagicMock()
        self._fetcher = self._db_mngr.get_fetcher()

    def tearDown(self):
        self._db_mngr.close_all_sessions()
        self._db_mngr.clean_up()

    def _fetch(self):
        """Runs the fetcher on the test database and blocks until it emits ``finished``."""
        waiter = SignalWaiter()
        self._fetcher.finished.connect(waiter.trigger)
        self._fetcher.fetch(self._listener, [self._db_map])
        waiter.wait()

    def test_fetch_empty_database(self):
        """A fresh database yields only the 'Base' alternative and nothing else."""
        self._fetch()
        # NOTE(review): self._listener is a MagicMock, so an attribute that was
        # never assigned is itself a truthy MagicMock; this assertion only has
        # teeth if the fetcher explicitly assigns ``silenced``.
        self.assertTrue(self._listener.silenced)
        self._listener.receive_alternatives_added.assert_called_once_with({
            self._db_map: [{
                "id": 1,
                "name": "Base",
                "description": "Base alternative",
                "commit_id": 1
            }]
        })
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "alternative", 1),
            {
                'commit_id': 1,
                'description': 'Base alternative',
                'id': 1,
                'name': 'Base'
            },
        )
        self._listener.receive_scenarios_added.assert_not_called()
        self._listener.receive_scenario_alternatives_added.assert_not_called()
        self._listener.receive_object_classes_added.assert_not_called()
        self._listener.receive_objects_added.assert_not_called()
        self._listener.receive_relationship_classes_added.assert_not_called()
        self._listener.receive_relationships_added.assert_not_called()
        self._listener.receive_entity_groups_added.assert_not_called()
        self._listener.receive_parameter_definitions_added.assert_not_called()
        self._listener.receive_parameter_definition_tags_added.assert_not_called()
        self._listener.receive_parameter_values_added.assert_not_called()
        self._listener.receive_parameter_value_lists_added.assert_not_called()
        self._listener.receive_parameter_tags_added.assert_not_called()
        self._listener.receive_features_added.assert_not_called()
        self._listener.receive_tools_added.assert_not_called()
        self._listener.receive_tool_features_added.assert_not_called()
        self._listener.receive_tool_feature_methods_added.assert_not_called()

    def _import_data(self, **data):
        """Imports ``data`` into the test database and commits it.

        Blocks until the manager emits ``data_imported`` and then until it emits
        ``session_committed``.

        Args:
            **data: keyword arguments describing the items to import,
                e.g. ``object_classes=("oc",)``; forwarded to
                ``SpineDBManager.import_data``.
        """
        import_waiter = SignalWaiter()
        self._db_mngr.data_imported.connect(import_waiter.trigger)
        self._db_mngr.import_data({self._db_map: data})
        import_waiter.wait()
        self._db_mngr.data_imported.disconnect(import_waiter.trigger)
        # Replace the commit message getter with a fixed message
        # (presumably it would otherwise ask the user for one).
        self._db_mngr._get_commit_msg = lambda *args, **kwargs: "Add test data."
        # Use a fresh waiter: the one above has already been triggered, so
        # waiting on it again could return before the commit has finished.
        commit_waiter = SignalWaiter()
        self._db_mngr.session_committed.connect(commit_waiter.trigger)
        self._db_mngr.commit_session(self._db_map)
        commit_waiter.wait()
        self._db_mngr.session_committed.disconnect(commit_waiter.trigger)

    def test_fetch_alternatives(self):
        """Imported alternatives are fetched together with 'Base'."""
        self._import_data(alternatives=("alt", ))
        self._fetch()
        self._listener.receive_alternatives_added.assert_called_once_with({
            self._db_map: [
                {
                    'id': 1,
                    'name': 'Base',
                    'description': 'Base alternative',
                    'commit_id': 1
                },
                {
                    'id': 2,
                    'name': 'alt',
                    'description': None,
                    'commit_id': 2
                },
            ]
        })
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "alternative", 2),
            {
                'commit_id': 2,
                'description': None,
                'id': 2,
                'name': 'alt'
            },
        )

    def test_fetch_scenarios(self):
        """Imported scenarios are fetched."""
        self._import_data(scenarios=("scenario", ))
        self._fetch()
        self._listener.receive_scenarios_added.assert_called_once_with({
            self._db_map: [{
                'id': 1,
                'name': 'scenario',
                'description': None,
                'active': False,
                'alternative_id_list': None,
                'alternative_name_list': None,
            }]
        })
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "scenario", 1),
            {
                'active': False,
                'alternative_id_list': None,
                'alternative_name_list': None,
                'description': None,
                'id': 1,
                'name': 'scenario',
            },
        )

    def test_fetch_scenario_alternatives(self):
        """Imported scenario alternatives are fetched."""
        self._import_data(alternatives=("alt", ),
                          scenarios=("scenario", ),
                          scenario_alternatives=(("scenario", "alt"), ))
        self._fetch()
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "scenario_alternative", 1),
            {
                'alternative_id': 2,
                'commit_id': 2,
                'id': 1,
                'rank': 1,
                'scenario_id': 1
            },
        )

    def test_fetch_object_classes(self):
        """Imported object classes are fetched and get an icon."""
        self._import_data(object_classes=("oc", ))
        self._fetch()
        self._listener.receive_object_classes_added.assert_called_once_with({
            self._db_map: [{
                'id': 1,
                'name': 'oc',
                'description': None,
                'display_order': 99,
                'display_icon': None,
                'hidden': 0,
                'commit_id': 2,
            }]
        })
        self.assertIsInstance(
            self._db_mngr.entity_class_icon(self._db_map, "object_class", 1),
            QIcon)
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "object_class", 1),
            {
                'commit_id': 2,
                'description': None,
                'display_icon': None,
                'display_order': 99,
                'hidden': 0,
                'id': 1,
                'name': 'oc',
            },
        )

    def test_fetch_objects(self):
        """Imported objects are fetched."""
        self._import_data(object_classes=("oc", ), objects=(("oc", "obj"), ))
        self._fetch()
        self._listener.receive_objects_added.assert_called_once_with({
            self._db_map: [{
                'id': 1,
                'class_id': 1,
                'class_name': 'oc',
                'name': 'obj',
                'description': None
            }]
        })
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "object", 1),
            {
                'class_id': 1,
                'class_name': 'oc',
                'description': None,
                'id': 1,
                'name': 'obj'
            },
        )

    def test_fetch_relationship_classes(self):
        """Imported relationship classes are fetched."""
        self._import_data(object_classes=("oc", ),
                          relationship_classes=(("rc", ("oc", )), ))
        self._fetch()
        self._listener.receive_relationship_classes_added.assert_called_once_with(
            {
                self._db_map: [{
                    'id': 2,
                    'name': 'rc',
                    'description': None,
                    'object_class_id_list': '1',
                    'object_class_name_list': 'oc',
                }]
            })
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "relationship_class", 2),
            {
                'description': None,
                'id': 2,
                'name': 'rc',
                'object_class_id_list': '1',
                'object_class_name_list': 'oc'
            },
        )

    def test_fetch_relationships(self):
        """Imported relationships are fetched."""
        self._import_data(
            object_classes=("oc", ),
            objects=(("oc", "obj"), ),
            relationship_classes=(("rc", ("oc", )), ),
            relationships=(("rc", ("obj", )), ),
        )
        self._fetch()
        self._listener.receive_relationships_added.assert_called_once_with({
            self._db_map: [{
                'id': 2,
                'name': 'rc_obj',
                'class_id': 2,
                'class_name': 'rc',
                'object_id_list': '1',
                'object_name_list': 'obj',
                'object_class_id_list': '1',
                'object_class_name_list': 'oc',
            }]
        })
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "relationship", 2),
            {
                'class_id': 2,
                'class_name': 'rc',
                'id': 2,
                'name': 'rc_obj',
                'object_class_id_list': '1',
                'object_class_name_list': 'oc',
                'object_id_list': '1',
                'object_name_list': 'obj',
            },
        )

    def test_fetch_object_groups(self):
        """Imported object groups are fetched as entity groups."""
        self._import_data(object_classes=("oc", ),
                          objects=(("oc", "obj"), ("oc", "group")),
                          object_groups=(("oc", "group", "obj"), ))
        self._fetch()
        self._listener.receive_entity_groups_added.assert_called_once_with({
            self._db_map: [{
                'id': 1,
                'class_id': 1,
                'group_id': 2,
                'member_id': 1,
                'class_name': 'oc',
                'group_name': 'group',
                'member_name': 'obj',
            }]
        })
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "entity_group", 1),
            {
                'id': 1,
                'class_id': 1,
                'group_id': 2,
                'member_id': 1,
                'class_name': 'oc',
                'group_name': 'group',
                'member_name': 'obj',
            },
        )

    def test_fetch_parameter_definitions(self):
        """Imported object parameter definitions are fetched."""
        self._import_data(object_classes=("oc", ),
                          object_parameters=(("oc", "param"), ))
        self._fetch()
        self._listener.receive_parameter_definitions_added.assert_called_once_with(
            {
                self._db_map: [{
                    'id': 1,
                    'entity_class_id': 1,
                    'object_class_id': 1,
                    'object_class_name': 'oc',
                    'parameter_name': 'param',
                    'value_list_id': None,
                    'value_list_name': None,
                    'parameter_tag_id_list': None,
                    'parameter_tag_list': None,
                    'default_value': None,
                    'description': None,
                }]
            })
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "parameter_definition", 1),
            {
                'default_value': None,
                'description': None,
                'entity_class_id': 1,
                'id': 1,
                'object_class_id': 1,
                'object_class_name': 'oc',
                'parameter_name': 'param',
                'parameter_tag_id_list': None,
                'parameter_tag_list': None,
                'value_list_id': None,
                'value_list_name': None,
            },
        )

    def test_fetch_parameter_values(self):
        """Imported object parameter values are fetched for the 'Base' alternative."""
        self._import_data(
            object_classes=("oc", ),
            objects=(("oc", "obj"), ),
            object_parameters=(("oc", "param"), ),
            object_parameter_values=(("oc", "obj", "param", 2.3), ),
        )
        self._fetch()
        self._listener.receive_parameter_values_added.assert_called_once_with({
            self._db_map: [{
                'id': 1,
                'entity_class_id': 1,
                'object_class_id': 1,
                'object_class_name': 'oc',
                'entity_id': 1,
                'object_id': 1,
                'object_name': 'obj',
                'parameter_id': 1,
                'parameter_name': 'param',
                'alternative_id': 1,
                'alternative_name': 'Base',
                'value': '2.3',
            }]
        })
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "parameter_value", 1),
            {
                'alternative_id': 1,
                'alternative_name': 'Base',
                'entity_class_id': 1,
                'entity_id': 1,
                'id': 1,
                'object_class_id': 1,
                'object_class_name': 'oc',
                'object_id': 1,
                'object_name': 'obj',
                'parameter_id': 1,
                'parameter_name': 'param',
                'value': '2.3',
            },
        )

    def test_fetch_parameter_value_lists(self):
        """Imported parameter value lists are fetched."""
        self._import_data(parameter_value_lists=(("value_list", (2.3, )), ))
        self._fetch()
        self._listener.receive_parameter_value_lists_added.assert_called_once_with(
            {
                self._db_map: [{
                    'id': 1,
                    'name': 'value_list',
                    'value_index_list': '0',
                    'value_list': '[2.3]'
                }]
            })
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "parameter_value_list", 1),
            {
                'id': 1,
                'name': 'value_list',
                'value_index_list': '0',
                'value_list': '[2.3]'
            },
        )

    def test_fetch_features(self):
        """Imported features are fetched."""
        self._import_data(
            object_classes=("oc", ),
            parameter_value_lists=(("value_list", (2.3, )), ),
            object_parameters=(("oc", "param", 2.3, "value_list"), ),
            features=(("oc", "param"), ),
        )
        self._fetch()
        self._listener.receive_features_added.assert_called_once_with({
            self._db_map: [{
                'id': 1,
                'entity_class_id': 1,
                'entity_class_name': 'oc',
                'parameter_definition_id': 1,
                'parameter_definition_name': 'param',
                'parameter_value_list_id': 1,
                'parameter_value_list_name': 'value_list',
                'description': None,
            }]
        })
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "feature", 1),
            {
                'description': None,
                'entity_class_id': 1,
                'entity_class_name': 'oc',
                'id': 1,
                'parameter_definition_id': 1,
                'parameter_definition_name': 'param',
                'parameter_value_list_id': 1,
                'parameter_value_list_name': 'value_list',
            },
        )

    def test_fetch_tools(self):
        """Imported tools are fetched."""
        self._import_data(tools=("tool", ))
        self._fetch()
        self._listener.receive_tools_added.assert_called_once_with({
            self._db_map: [{
                'id': 1,
                'name': 'tool',
                'description': None,
                'commit_id': 2
            }]
        })
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "tool", 1),
            {
                'commit_id': 2,
                'description': None,
                'id': 1,
                'name': 'tool'
            },
        )

    def test_fetch_tool_features(self):
        """Imported tool features are fetched."""
        self._import_data(
            object_classes=("oc", ),
            parameter_value_lists=(("value_list", (2.3, )), ),
            object_parameters=(("oc", "param", 2.3, "value_list"), ),
            features=(("oc", "param"), ),
            tools=("tool", ),
            tool_features=(("tool", "oc", "param"), ),
        )
        self._fetch()
        self._listener.receive_tool_features_added.assert_called_once_with({
            self._db_map: [{
                'id': 1,
                'tool_id': 1,
                'feature_id': 1,
                'parameter_value_list_id': 1,
                'required': False,
                'commit_id': 2,
            }]
        })
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "tool_feature", 1),
            {
                'commit_id': 2,
                'feature_id': 1,
                'id': 1,
                'parameter_value_list_id': 1,
                'required': False,
                'tool_id': 1
            },
        )

    def test_fetch_tool_feature_methods(self):
        """Imported tool feature methods are fetched."""
        self._import_data(
            object_classes=("oc", ),
            parameter_value_lists=(("value_list", "m"), ),
            object_parameters=(("oc", "param", "m", "value_list"), ),
            features=(("oc", "param"), ),
            tools=("tool", ),
            tool_features=(("tool", "oc", "param"), ),
            tool_feature_methods=(("tool", "oc", "param", "m"), ),
        )
        self._fetch()
        self._listener.receive_tool_feature_methods_added.assert_called_once_with(
            {
                self._db_map: [{
                    'id': 1,
                    'tool_feature_id': 1,
                    'parameter_value_list_id': 1,
                    'method_index': 0,
                    'commit_id': 2
                }]
            })
        self.assertEqual(
            self._db_mngr.get_item(self._db_map, "tool_feature_method", 1),
            {
                'commit_id': 2,
                'id': 1,
                'method_index': 0,
                'parameter_value_list_id': 1,
                'tool_feature_id': 1
            },
        )
class TestEmptyParameterModel(unittest.TestCase):
    """Tests for the 'empty' parameter value and parameter definition models.

    Filling in a row of an empty model through ``batch_set_data`` should add
    the corresponding item to the database when the row is complete and valid,
    and should add nothing when it is not.
    """

    @classmethod
    def setUpClass(cls):
        # The DB manager and models are Qt objects; an application must exist.
        if not QApplication.instance():
            QApplication()

    def setUp(self):
        """Overridden method. Runs before each test."""
        app_settings = mock.MagicMock()
        logger = mock.MagicMock()
        with mock.patch("spinetoolbox.spine_db_manager.SpineDBManager.thread",
                        new_callable=mock.PropertyMock) as mock_thread:
            # Make the manager report the application's main thread as its own.
            mock_thread.return_value = QApplication.instance().thread()
            self._db_mngr = SpineDBManager(app_settings, None)
            fetcher = self._db_mngr.get_fetcher()
        self._db_map = self._db_mngr.get_db_map("sqlite://",
                                                logger,
                                                codename="mock_db",
                                                create=True)
        import_object_classes(self._db_map, ("dog", "fish"))
        import_object_parameters(self._db_map, (("dog", "breed"), ))
        import_objects(self._db_map, (("dog", "pluto"), ("fish", "nemo")))
        import_relationship_classes(self._db_map,
                                    (("dog__fish", ("dog", "fish")), ))
        import_relationship_parameters(self._db_map,
                                       (("dog__fish", "relative_speed"), ))
        # The class name must match the relationship class imported above
        # ("dog__fish", double underscore); an unknown class name presumably
        # makes the import skip this relationship.
        import_relationships(self._db_map, (("dog__fish", ("pluto", "nemo")), ))
        self._db_map.commit_session("Add test data")
        fetcher.fetch([self._db_map])
        # Column headers of the parameter value tables.
        self.object_table_header = [
            "object_class_name",
            "object_name",
            "parameter_name",
            "alternative_id",
            "value",
            "database",
        ]
        self.relationship_table_header = [
            "relationship_class_name",
            "object_name_list",
            "parameter_name",
            "alternative_id",
            "value",
            "database",
        ]
        # Column headers of the parameter definition tables.
        self.object_definition_header = [
            "object_class_name",
            "parameter_name",
            "value_list_name",
            "parameter_tag_list",
            "database",
        ]
        self.relationship_definition_header = [
            "relationship_class_name",
            "parameter_name",
            "value_list_name",
            "parameter_tag_list",
            "database",
        ]

    def tearDown(self):
        self._db_mngr.close_all_sessions()
        self._db_mngr.clean_up()
        self._db_mngr.deleteLater()

    def test_add_object_parameter_values_to_db(self):
        """Test that object parameter values are added to the db when editing the table."""
        header = self.object_table_header
        model = EmptyObjectParameterValueModel(None, header, self._db_mngr)
        model.fetchMore()
        self.assertTrue(
            model.batch_set_data(
                _empty_indexes(model),
                ["dog", "pluto", "breed", 1, "bloodhound", "mock_db"]))
        values = next(self._db_mngr.get_object_parameter_values(self._db_map),
                      [])
        self.assertEqual(len(values), 1)
        self.assertEqual(values[0]["object_class_name"], "dog")
        self.assertEqual(values[0]["object_name"], "pluto")
        self.assertEqual(values[0]["parameter_name"], "breed")
        self.assertEqual(values[0]["value"], "bloodhound")

    def test_do_not_add_invalid_object_parameter_values(self):
        """Test that object parameter values aren't added to the db if data is incomplete."""
        header = self.object_table_header
        model = EmptyObjectParameterValueModel(None, header, self._db_mngr)
        model.fetchMore()
        # One value per header column; 'water' is not a parameter of 'fish'.
        self.assertTrue(
            model.batch_set_data(
                _empty_indexes(model),
                ["fish", "nemo", "water", 1, "salty", "mock_db"]))
        values = next(self._db_mngr.get_object_parameter_values(self._db_map),
                      [])
        self.assertEqual(values, [])

    def test_infer_class_from_object_and_parameter(self):
        """Test that object classes are inferred from the object and parameter if possible."""
        header = self.object_table_header
        model = EmptyObjectParameterValueModel(None, header, self._db_mngr)
        model.fetchMore()
        indexes = _empty_indexes(model)
        self.assertTrue(
            model.batch_set_data(
                indexes,
                ["cat", "pluto", "breed", 1, "bloodhound", "mock_db"]))
        self.assertEqual(indexes[0].data(), "dog")
        values = next(self._db_mngr.get_object_parameter_values(self._db_map),
                      [])
        self.assertEqual(len(values), 1)
        self.assertEqual(values[0]["object_class_name"], "dog")
        self.assertEqual(values[0]["object_name"], "pluto")
        self.assertEqual(values[0]["parameter_name"], "breed")
        self.assertEqual(values[0]["value"], "bloodhound")

    def test_add_relationship_parameter_values_to_db(self):
        """Test that relationship parameter values are added to the db when editing the table."""
        header = self.relationship_table_header
        model = EmptyRelationshipParameterValueModel(None, header,
                                                     self._db_mngr)
        model.fetchMore()
        self.assertTrue(
            model.batch_set_data(_empty_indexes(model), [
                "dog__fish", "pluto,nemo", "relative_speed", 1, -1, "mock_db"
            ]))
        values = next(
            self._db_mngr.get_relationship_parameter_values(self._db_map), [])
        self.assertEqual(len(values), 1)
        self.assertEqual(values[0]["relationship_class_name"], "dog__fish")
        self.assertEqual(values[0]["object_name_list"], "pluto,nemo")
        self.assertEqual(values[0]["parameter_name"], "relative_speed")
        self.assertEqual(values[0]["value"], "-1")

    def test_do_not_add_invalid_relationship_parameter_values(self):
        """Test that relationship parameter values aren't added to the db if data is incomplete."""
        header = self.relationship_table_header
        model = EmptyRelationshipParameterValueModel(None, header,
                                                     self._db_mngr)
        model.fetchMore()
        # One value per header column; 'combined_mojo' is not a parameter of 'dog__fish'.
        self.assertTrue(
            model.batch_set_data(
                _empty_indexes(model),
                ["dog__fish", "pluto,nemo", "combined_mojo", 1, 100, "mock_db"]))
        values = next(
            self._db_mngr.get_relationship_parameter_values(self._db_map), [])
        self.assertEqual(values, [])

    def test_add_object_parameter_definitions_to_db(self):
        """Test that object parameter definitions are added to the db when editing the table."""
        header = self.object_definition_header
        model = EmptyObjectParameterDefinitionModel(None, header,
                                                    self._db_mngr)
        model.fetchMore()
        self.assertTrue(
            model.batch_set_data(_empty_indexes(model),
                                 ["dog", "color", None, None, "mock_db"]))
        definitions = next(
            self._db_mngr.get_object_parameter_definitions(self._db_map), [])
        self.assertEqual(len(definitions), 2)
        names = {d["parameter_name"] for d in definitions}
        self.assertEqual(names, {"breed", "color"})

    def test_do_not_add_invalid_object_parameter_definitions(self):
        """Test that object parameter definitions aren't added to the db if data is incomplete."""
        # Use the definition header so each value lands in its intended column;
        # 'cat' is not an existing object class.
        header = self.object_definition_header
        model = EmptyObjectParameterDefinitionModel(None, header,
                                                    self._db_mngr)
        model.fetchMore()
        self.assertTrue(
            model.batch_set_data(_empty_indexes(model),
                                 ["cat", "color", None, None, "mock_db"]))
        definitions = next(
            self._db_mngr.get_object_parameter_definitions(self._db_map), [])
        self.assertEqual(len(definitions), 1)
        self.assertEqual(definitions[0]["parameter_name"], "breed")

    def test_add_relationship_parameter_definitions_to_db(self):
        """Test that relationship parameter definitions are added to the db when editing the table."""
        header = self.relationship_definition_header
        model = EmptyRelationshipParameterDefinitionModel(
            None, header, self._db_mngr)
        model.fetchMore()
        self.assertTrue(
            model.batch_set_data(
                _empty_indexes(model),
                ["dog__fish", "combined_mojo", None, None, "mock_db"]))
        definitions = next(
            self._db_mngr.get_relationship_parameter_definitions(self._db_map),
            [])
        self.assertEqual(len(definitions), 2)
        names = {d["parameter_name"] for d in definitions}
        self.assertEqual(names, {"relative_speed", "combined_mojo"})

    def test_do_not_add_invalid_relationship_parameter_definitions(self):
        """Test that relationship parameter definitions aren't added to the db if data is incomplete."""
        # Use the definition header so each value lands in its intended column;
        # 'fish__dog' is not an existing relationship class.
        header = self.relationship_definition_header
        model = EmptyRelationshipParameterDefinitionModel(
            None, header, self._db_mngr)
        model.fetchMore()
        self.assertTrue(
            model.batch_set_data(
                _empty_indexes(model),
                ["fish__dog", "each_others_opinion", None, None, "mock_db"]))
        definitions = next(
            self._db_mngr.get_relationship_parameter_definitions(self._db_map),
            [])
        self.assertEqual(len(definitions), 1)
        self.assertEqual(definitions[0]["parameter_name"], "relative_speed")