def setUp(self):
        """Open both DB interfaces and create the fixture objects for a test."""
        config = self._config
        self.db_interface = MongoInterfaceCommon(config=config)
        self.db_interface_backend = BackEndDbInterface(config=config)

        self.test_firmware = create_test_firmware()

        # Metadata part of the sample match entry, kept separate for readability.
        yara_meta = {
            'description': 'SSH library',
            'website': 'http://www.openssh.com',
            'open_source': True,
            'software_name': 'OpenSSH',
        }
        self.test_yara_match = {
            'rule': 'OpenSSH',
            'tags': [],
            'namespace': 'default',
            'strings': [(0, '$a', b'OpenSSH')],
            'meta': yara_meta,
            'matches': True,
        }

        self.test_fo = create_test_file_object()
Beispiel #2
0
class TestStorageDbInterfaceBackend(unittest.TestCase):
    """Tests for storing and reading objects via ``BackEndDbInterface``.

    Every test builds its own config, ``MongoMgr`` and DB interfaces in
    ``setUp`` and tears all of them down again in ``tearDown``.
    """

    def setUp(self):
        """Create config, MongoMgr, both DB interfaces and the fixture objects."""
        config = get_config_for_testing(TMP_DIR)
        self._config = config
        self.mongo_server = MongoMgr(config=config)
        self.db_interface = MongoInterfaceCommon(config=config)
        self.db_interface_backend = BackEndDbInterface(config=config)

        self.test_firmware = create_test_firmware()

        yara_meta = {
            'description': 'SSH library',
            'website': 'http://www.openssh.com',
            'open_source': True,
            'software_name': 'OpenSSH',
        }
        self.test_yara_match = {
            'rule': 'OpenSSH',
            'tags': [],
            'namespace': 'default',
            'strings': [(0, '$a', b'OpenSSH')],
            'meta': yara_meta,
            'matches': True,
        }

        self.test_fo = create_test_file_object()

    def tearDown(self):
        """Drop the main database, shut everything down and clean up TMP_DIR."""
        main_database = self._config.get('data_storage', 'main_database')
        self.db_interface.client.drop_database(main_database)
        # Shutdown order: backend interface, common interface, mongo server.
        for component in (self.db_interface_backend, self.db_interface,
                          self.mongo_server):
            component.shutdown()
        TMP_DIR.cleanup()
        gc.collect()

    def _get_all_firmware_uids(self):
        """Return the ``_id`` of every entry in the firmwares collection."""
        return [entry['_id'] for entry in self.db_interface.firmwares.find()]

    def test_add_firmware(self):
        # Adding a firmware must create an entry with a current submission date.
        self.db_interface_backend.add_firmware(self.test_firmware)
        stored_uids = self._get_all_firmware_uids()
        self.assertGreater(len(stored_uids), 0, 'No entry added to DB')
        stored_entry = self.db_interface_backend.firmwares.find_one()
        self.assertAlmostEqual(
            stored_entry['submission_date'],
            time(),
            msg='submission time not set correctly',
            delta=5.0,
        )

    def test_add_and_get_firmware(self):
        # The backend interface returns the binary, the common interface does not.
        self.db_interface_backend.add_firmware(self.test_firmware)

        from_backend = self.db_interface_backend.get_firmware(
            self.test_firmware.get_uid())
        self.assertIsNotNone(from_backend.binary,
                             'binary not set in backend result')

        from_common = self.db_interface.get_firmware(
            self.test_firmware.get_uid())
        self.assertIsNone(from_common.binary, 'binary set in common result')
        self.assertEqual(from_common.size, 787,
                         'file size not correct in common')
        self.assertIsInstance(from_common.tags, dict,
                              'tag field type not correct')

    def test_add_and_get_file_object(self):
        # Same check as for firmware, but with a plain file object.
        self.db_interface_backend.add_file_object(self.test_fo)

        from_backend = self.db_interface_backend.get_file_object(
            self.test_fo.get_uid())
        self.assertIsNotNone(from_backend.binary,
                             'binary not set in backend result')

        from_common = self.db_interface.get_file_object(
            self.test_fo.get_uid())
        self.assertIsNone(from_common.binary, 'binary set in common result')
        self.assertEqual(from_common.size, 62,
                         'file size not correct in common')

    def test_update_firmware(self):
        # Re-adding a firmware updates existing analysis results and keeps
        # results of plugins that are not part of the update.
        initial_analysis = {
            'stub_plugin': {'result': 0},
            'other_plugin': {'field': 'day'},
        }
        updated_analysis = {'stub_plugin': {'result': 1}}

        self.test_firmware.processed_analysis = initial_analysis
        self.db_interface_backend.add_firmware(self.test_firmware)
        stored = self.db_interface.get_object(self.test_firmware.get_uid())
        self.assertEqual(0, stored.processed_analysis['stub_plugin']['result'])

        self.test_firmware.processed_analysis = updated_analysis
        self.db_interface_backend.add_firmware(self.test_firmware)
        stored = self.db_interface.get_object(self.test_firmware.get_uid())
        self.assertEqual(1, stored.processed_analysis['stub_plugin']['result'])
        stored = self.db_interface.get_object(self.test_firmware.get_uid())
        self.assertIn('other_plugin', stored.processed_analysis.keys())

    def test_update_file_object(self):
        # Analysis results and included files of both additions are merged.
        initial_analysis = {'other_plugin': {'result': 0}}
        updated_analysis = {'stub_plugin': {'result': 1}}

        self.test_fo.processed_analysis = initial_analysis
        self.test_fo.files_included = {'file a', 'file b'}
        self.db_interface_backend.add_file_object(self.test_fo)

        self.test_fo.processed_analysis = updated_analysis
        self.test_fo.files_included = {'file b', 'file c'}
        self.db_interface_backend.add_file_object(self.test_fo)

        stored = self.db_interface.get_object(self.test_fo.get_uid())
        analysis = stored.processed_analysis
        self.assertEqual(0, analysis['other_plugin']['result'])
        self.assertEqual(1, analysis['stub_plugin']['result'])
        self.assertEqual(3, len(stored.files_included))

    def test_add_and_get_object_including_comment(self):
        # A comment attached before storing must be readable afterwards.
        comment = 'this is a test comment!'
        author = 'author'
        date = '1473431685'
        uid = self.test_fo.get_uid()

        entry = {'time': str(date), 'author': author, 'comment': comment}
        self.test_fo.comments.append(entry)
        self.db_interface_backend.add_file_object(self.test_fo)

        stored_comment = self.db_interface.get_object(uid).comments[0]
        self.assertEqual(author, stored_comment['author'])
        self.assertEqual(comment, stored_comment['comment'])
        self.assertEqual(date, stored_comment['time'])
Beispiel #3
0
 def setUp(self):
     """Create the test config, start the MongoMgr and open both DB interfaces."""
     config = get_config_for_testing(TMP_DIR)
     self._config = config
     self.mongo_server = MongoMgr(config=config)
     self.db_interface = MongoInterfaceCommon(config=config)
     self.db_interface_backend = BackEndDbInterface(config=config)
Beispiel #4
0
class TestSummary(unittest.TestCase):
    """Tests for the summary and included-file helpers of ``MongoInterfaceCommon``.

    Config, ``MongoMgr`` and DB interfaces are created per test in ``setUp``
    and removed again in ``tearDown``.
    """

    def setUp(self):
        """Create the test config, start the MongoMgr and open both DB interfaces."""
        config = get_config_for_testing(TMP_DIR)
        self._config = config
        self.mongo_server = MongoMgr(config=config)
        self.db_interface = MongoInterfaceCommon(config=config)
        self.db_interface_backend = BackEndDbInterface(config=config)

    def tearDown(self):
        """Drop the main database, shut everything down and clean up TMP_DIR."""
        main_database = self._config.get('data_storage', 'main_database')
        self.db_interface.client.drop_database(main_database)
        # Shutdown order: common interface, backend interface, mongo server.
        for component in (self.db_interface, self.db_interface_backend,
                          self.mongo_server):
            component.shutdown()
        TMP_DIR.cleanup()

    def create_and_add_test_fimrware_and_file_object(self):
        """Store a test firmware that includes a test file object, plus the file."""
        firmware = create_test_firmware()
        self.test_fw = firmware
        file_object = create_test_file_object()
        self.test_fo = file_object
        firmware.add_included_file(file_object)
        for stored_object in (firmware, file_object):
            self.db_interface_backend.add_object(stored_object)

    def test_get_set_of_all_included_files(self):
        self.create_and_add_test_fimrware_and_file_object()

        # A file object without children only contains itself.
        files_of_fo = self.db_interface.get_set_of_all_included_files(
            self.test_fo)
        self.assertIsInstance(files_of_fo, set, 'result is not a set')
        self.assertEqual(len(files_of_fo), 1, 'number of files not correct')
        self.assertIn(self.test_fo.uid, files_of_fo,
                      'object not in its own result set')

        # The firmware contains itself and the included file.
        files_of_fw = self.db_interface.get_set_of_all_included_files(
            self.test_fw)
        self.assertEqual(len(files_of_fw), 2, 'number of files not correct')
        self.assertIn(self.test_fo.uid, files_of_fw,
                      'test file not in result set firmware')
        self.assertIn(self.test_fw.uid, files_of_fw,
                      'fw not in result set firmware')

    def test_get_uids_of_all_included_files(self):
        def store_file(uid, parent_uids: Set[str]):
            # Store a test file object under the given uid and parent firmwares.
            file_object = create_test_file_object()
            file_object.parent_firmware_uids = parent_uids
            file_object.uid = uid
            self.db_interface_backend.add_object(file_object)

        fixtures = (
            ('uid1', {'foo'}),
            ('uid2', {'foo', 'bar'}),
            ('uid3', {'bar'}),
        )
        for uid, parents in fixtures:
            store_file(uid, parents)

        included = self.db_interface.get_uids_of_all_included_files('foo')
        assert included == {'uid1', 'uid2'}

        unknown = self.db_interface.get_uids_of_all_included_files(
            'uid not in db')
        assert unknown == set()

    def test_get_summary(self):
        self.create_and_add_test_fimrware_and_file_object()
        summary = self.db_interface.get_summary(self.test_fw, 'dummy')
        self.assertIsInstance(summary, dict, 'summary is not a dict')

        self.assertIn('sum a', summary, 'summary entry of parent missing')
        shared_origins = summary['sum a']
        self.assertIn(self.test_fw.uid, shared_origins,
                      'origin (parent) missing in parent summary entry')
        self.assertIn(self.test_fo.uid, shared_origins,
                      'origin (child) missing in parent summary entry')
        self.assertNotIn(self.test_fo.uid, summary['fw exclusive sum a'],
                         'child as origin but should not be')

        self.assertIn('file exclusive sum b', summary,
                      'file exclusive summary missing')
        file_only_origins = summary['file exclusive sum b']
        self.assertIn(self.test_fo.uid, file_only_origins,
                      'origin of file exclusive missing')
        self.assertNotIn(self.test_fw.uid, file_only_origins,
                         'parent as origin but should not be')

    def test_collect_summary(self):
        self.create_and_add_test_fimrware_and_file_object()
        uid_list = [self.test_fo.uid]
        collected = self.db_interface._collect_summary(uid_list, 'dummy')

        # Every summary item of the file object must appear in the result ...
        expected_items = self.test_fo.processed_analysis['dummy']['summary']
        assert all(item in collected for item in expected_items)
        # ... and each item must list exactly the file object as origin.
        assert all(origins == [self.test_fo.uid]
                   for origins in collected.values())

    def test_get_summary_of_one_error_handling(self):
        summary_of_none = self.db_interface._get_summary_of_one(None, 'foo')
        self.assertEqual(summary_of_none, {},
                         'None object should result in empty dict')

        self.create_and_add_test_fimrware_and_file_object()
        summary_of_unknown = self.db_interface._get_summary_of_one(
            self.test_fw, 'none_existing_analysis')
        self.assertEqual(summary_of_unknown, {},
                         'analysis not existend should lead to empty dict')

    def test_update_summary(self):
        original = {'a': ['a']}
        addition = {'a': ['aa'], 'b': ['aa']}
        merged = self.db_interface._update_summary(original, addition)

        for key in ('a', 'b'):
            self.assertIn(key, merged)
        # (expected member, summary key) pairs, checked in this order.
        for member, key in (('a', 'a'), ('aa', 'a'), ('aa', 'b')):
            self.assertIn(member, merged[key])
Beispiel #5
0
class TestMongoInterface(unittest.TestCase):
    """Tests for ``MongoInterfaceCommon``: lookups, sanitizing, counters, locks.

    One ``MongoMgr`` is started per class; the DB interfaces and the fixture
    objects are recreated for every single test.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared test config and start the MongoMgr once."""
        config = get_config_for_testing(TMP_DIR)
        cls._config = config
        config.set('data_storage', 'report_threshold', '32')
        config.set('data_storage', 'sanitize_database', 'tmp_sanitize')
        cls.mongo_server = MongoMgr(config=config)

    def setUp(self):
        """Open both DB interfaces and create the fixture objects."""
        config = self._config
        self.db_interface = MongoInterfaceCommon(config=config)
        self.db_interface_backend = BackEndDbInterface(config=config)

        self.test_firmware = create_test_firmware()

        yara_meta = {
            'description': 'SSH library',
            'website': 'http://www.openssh.com',
            'open_source': True,
            'software_name': 'OpenSSH',
        }
        self.test_yara_match = {
            'rule': 'OpenSSH',
            'tags': [],
            'namespace': 'default',
            'strings': [(0, '$a', b'OpenSSH')],
            'meta': yara_meta,
            'matches': True,
        }

        self.test_fo = create_test_file_object()

    def tearDown(self):
        """Drop main and sanitize databases and shut both interfaces down."""
        cleanup_plan = (
            (self.db_interface_backend, 'main_database'),
            (self.db_interface, 'sanitize_database'),
        )
        for interface, database_key in cleanup_plan:
            interface.client.drop_database(
                self._config.get('data_storage', database_key))
            interface.shutdown()
        gc.collect()

    @classmethod
    def tearDownClass(cls):
        """Stop the shared MongoMgr and clean up TMP_DIR."""
        server = cls.mongo_server
        server.shutdown()
        TMP_DIR.cleanup()

    def _get_all_firmware_uids(self):
        """Return the ``_id`` of every entry in the firmwares collection."""
        return [entry['_id'] for entry in self.db_interface.firmwares.find()]

    def test_existence_quick_check(self):
        quick_check = self.db_interface.existence_quick_check
        self.assertFalse(quick_check('none_existing'),
                         'none existing firmware found')

        self.db_interface_backend.add_firmware(self.test_firmware)
        self.assertTrue(quick_check(self.test_firmware.uid),
                        'existing firmware not found')

        self.db_interface_backend.add_file_object(self.test_fo)
        self.assertTrue(quick_check(self.test_fo.uid),
                        'existing file not found')

    def test_get_firmware(self):
        self.db_interface_backend.add_firmware(self.test_firmware)
        stored = self.db_interface.get_firmware(self.test_firmware.uid)
        expectations = (
            ('vendor', 'test_vendor'),
            ('device_name', 'test_router'),
            ('part', ''),
        )
        for attribute, expected in expectations:
            self.assertEqual(getattr(stored, attribute), expected)

    def test_get_object(self):
        # Nothing stored yet: the lookup must yield None.
        found = self.db_interface.get_object(self.test_firmware.uid)
        self.assertIsNone(
            found, 'found something but there is nothing in the database')

        self.db_interface_backend.add_firmware(self.test_firmware)
        found = self.db_interface.get_object(self.test_firmware.uid)
        self.assertIsInstance(found, Firmware, 'firmware has wrong type')
        self.assertEqual(found.device_name, 'test_router',
                         'Device name in Firmware not correct')

        testfile_path = path.join(get_test_data_dir(),
                                  'get_files_test/testfile2')
        test_file = FileObject(file_path=testfile_path)
        self.db_interface_backend.add_file_object(test_file)
        found = self.db_interface.get_object(test_file.uid)
        self.assertIsInstance(found, FileObject, 'file object has wrong type')

    def test_get_complete_object_including_all_summaries(self):
        self.db_interface_backend.report_threshold = 1024
        included_file = create_test_file_object()
        self.test_firmware.add_included_file(included_file)
        self.db_interface_backend.add_firmware(self.test_firmware)
        self.db_interface_backend.add_file_object(included_file)

        complete = self.db_interface.get_complete_object_including_all_summaries(
            self.test_firmware.uid)
        self.assertIsInstance(complete, Firmware, 'wrong type')
        self.assertIn('summary', complete.processed_analysis['dummy'].keys(),
                      'summary not found in processed analysis')
        self.assertIn('sum a', complete.processed_analysis['dummy']['summary'],
                      'summary of original file not included')
        self.assertIn('file exclusive sum b',
                      complete.processed_analysis['dummy']['summary'],
                      'summary of included file not found')

    def test_sanitize_analysis(self):
        small_analysis = {'stub_plugin': {'result': 0}}
        large_analysis = {
            'stub_plugin': {
                'result': 10000000000,
                'misc': 'Bananarama',
                'summary': [],
            }
        }
        firmware = self.test_firmware

        # Small results stay inline: flag is False and nothing is written.
        firmware.processed_analysis = small_analysis
        sanitized = self.db_interface.sanitize_analysis(
            firmware.processed_analysis, firmware.uid)
        self.assertIn('file_system_flag', sanitized['stub_plugin'].keys())
        self.assertFalse(sanitized['stub_plugin']['file_system_flag'])
        self.assertEqual(self.db_interface.sanitize_fs.list(), [],
                         'file stored in db but should not')

        # Large results are moved to sanitize_fs; the summary is not.
        firmware.processed_analysis = large_analysis
        sanitized = self.db_interface.sanitize_analysis(
            firmware.processed_analysis, firmware.uid)
        self.assertIn('stub_plugin_result_{}'.format(firmware.uid),
                      self.db_interface.sanitize_fs.list(),
                      'sanitized file not stored')
        self.assertNotIn('summary_result_{}'.format(firmware.uid),
                         self.db_interface.sanitize_fs.list(),
                         'summary is erroneously stored')
        self.assertIn('file_system_flag', sanitized['stub_plugin'].keys())
        self.assertTrue(sanitized['stub_plugin']['file_system_flag'])
        self.assertEqual(type(sanitized['stub_plugin']['summary']), list)

    def test_sanitize_db_duplicates(self):
        large_analysis = {
            'stub_plugin': {
                'result': 10000000000,
                'misc': 'Bananarama',
                'summary': [],
            }
        }
        gridfs_file_name = 'stub_plugin_result_{}'.format(
            self.test_firmware.uid)
        name_query = {'filename': gridfs_file_name}

        def stored_copies():
            # Number of sanitize_fs entries carrying the expected file name.
            return self.db_interface.sanitize_fs.find({
                'filename': gridfs_file_name
            }).count()

        def sanitize_once():
            self.db_interface.sanitize_analysis(
                self.test_firmware.processed_analysis, self.test_firmware.uid)

        self.test_firmware.processed_analysis = large_analysis
        assert stored_copies() == 0
        sanitize_once()
        assert stored_copies() == 1
        sanitize_once()
        assert stored_copies() == 1, 'duplicate entry was created'
        first_md5 = self.db_interface.sanitize_fs.find_one(name_query).md5

        large_analysis['stub_plugin']['result'] += 1  # new analysis result
        sanitize_once()
        assert stored_copies() == 1, 'duplicate entry was created'
        assert self.db_interface.sanitize_fs.find_one(
            name_query).md5 != first_md5, 'hash of new file did not change'

    def test_retrieve_analysis(self):
        self.db_interface.sanitize_fs.put(pickle.dumps('This is a test!'),
                                          filename='test_file_path')

        # One result stored in sanitize_fs, one stored inline.
        sanitized = {
            'stub_plugin': {
                'result': 'test_file_path',
                'file_system_flag': True,
            },
            'inbound_result': {
                'result': 'inbound result',
                'file_system_flag': False,
            },
        }
        retrieved = self.db_interface.retrieve_analysis(sanitized)

        self.assertNotIn('file_system_flag', retrieved['stub_plugin'].keys())
        self.assertIn('result', retrieved['stub_plugin'].keys())
        self.assertEqual(retrieved['stub_plugin']['result'],
                         'This is a test!')
        self.assertNotIn('file_system_flag',
                         retrieved['inbound_result'].keys())
        self.assertEqual(retrieved['inbound_result']['result'],
                         'inbound result')

    def test_retrieve_analysis_filter(self):
        self.db_interface.sanitize_fs.put(pickle.dumps('This is a test!'),
                                          filename='test_file_path')
        sanitized = {
            'selected_plugin': {
                'result': 'test_file_path',
                'file_system_flag': True,
            },
            'other_plugin': {
                'result': 'test_file_path',
                'file_system_flag': True,
            },
        }
        retrieved = self.db_interface.retrieve_analysis(
            sanitized, analysis_filter=['selected_plugin'])

        self.assertEqual(retrieved['selected_plugin']['result'],
                         'This is a test!')
        # The plugin that was not selected still carries its flag.
        self.assertIn('file_system_flag', retrieved['other_plugin'])

    def test_get_objects_by_uid_list(self):
        self.db_interface_backend.add_firmware(self.test_firmware)
        objects = self.db_interface.get_objects_by_uid_list(
            [self.test_firmware.uid])
        self.assertIsInstance(objects[0], Firmware, 'firmware has wrong type')
        self.assertEqual(objects[0].device_name, 'test_router',
                         'Device name in Firmware not correct')

        testfile_path = path.join(get_test_data_dir(),
                                  'get_files_test/testfile2')
        test_file = FileObject(file_path=testfile_path)
        self.db_interface_backend.add_file_object(test_file)
        objects = self.db_interface.get_objects_by_uid_list([test_file.uid])
        self.assertIsInstance(objects[0], FileObject,
                              'file object has wrong type')

    def test_sanitize_extract_and_retrieve_binary(self):
        analysis = {'dummy': {'test_key': 'test_value'}}

        analysis['dummy'] = self.db_interface._extract_binaries(
            analysis, 'dummy', 'uid')
        self.assertEqual(self.db_interface.sanitize_fs.list(),
                         ['dummy_test_key_uid'], 'file not written')
        self.assertEqual(analysis['dummy']['test_key'], 'dummy_test_key_uid',
                         'new file path not set')

        analysis['dummy'] = self.db_interface._retrieve_binaries(
            analysis, 'dummy')
        self.assertEqual(analysis['dummy']['test_key'], 'test_value',
                         'value not recoverd')

    def test_get_firmware_number(self):
        count_firmware = self.db_interface.get_firmware_number
        self.assertEqual(count_firmware(), 0)

        self.db_interface_backend.add_firmware(self.test_firmware)
        self.assertEqual(count_firmware(query={}), 1)
        self.assertEqual(
            count_firmware(query={'_id': self.test_firmware.uid}), 1)

        second_firmware = create_test_firmware(bin_path='container/test.7z')
        self.db_interface_backend.add_firmware(second_firmware)
        # The query is passed as a string here.
        self.assertEqual(count_firmware(query='{}'), 2)
        self.assertEqual(
            count_firmware(query={'_id': self.test_firmware.uid}), 1)

    def test_get_file_object_number(self):
        count_files = self.db_interface.get_file_object_number
        self.assertEqual(count_files(), 0)

        self.db_interface_backend.add_file_object(self.test_fo)
        self.assertEqual(count_files(query={}, zero_on_empty_query=False), 1)
        self.assertEqual(count_files(query={'_id': self.test_fo.uid}), 1)
        self.assertEqual(
            count_files(query=json.dumps({'_id': self.test_fo.uid})), 1)
        # With zero_on_empty_query an empty query counts as zero, both as a
        # dict and as a string.
        self.assertEqual(count_files(query={}, zero_on_empty_query=True), 0)
        self.assertEqual(count_files(query='{}', zero_on_empty_query=True), 0)

        second_file = create_test_file_object(
            bin_path='get_files_test/testfile2')
        self.db_interface_backend.add_file_object(second_file)
        self.assertEqual(count_files(query={}, zero_on_empty_query=False), 2)
        self.assertEqual(count_files(query={'_id': self.test_fo.uid}), 1)

    def test_unpacking_lock(self):
        first_uid, second_uid = 'id1', 'id2'
        is_locked = self.db_interface.check_unpacking_lock

        assert not (is_locked(first_uid) or is_locked(second_uid)), \
            'locks should not be set at start'

        self.db_interface.set_unpacking_lock(first_uid)
        assert is_locked(first_uid), 'locks should have been set'

        self.db_interface.set_unpacking_lock(second_uid)
        assert is_locked(first_uid) and is_locked(second_uid), \
            'both locks should be set'

        self.db_interface.release_unpacking_lock(first_uid)
        assert not is_locked(first_uid) and is_locked(second_uid), \
            'lock 1 should be released, lock 2 not'

        self.db_interface.drop_unpacking_locks()
        assert not is_locked(second_uid), 'all locks should be dropped'

    def test_lock_is_released(self):
        uid = self.test_fo.uid
        self.db_interface.set_unpacking_lock(uid)
        assert self.db_interface.check_unpacking_lock(uid), \
            'setting lock did not work'

        self.db_interface_backend.add_object(self.test_fo)
        assert not self.db_interface.check_unpacking_lock(
            self.test_fo.uid), 'add_object should release lock'

    def test_is_firmware(self):
        before_adding = self.db_interface.is_firmware(self.test_firmware.uid)
        assert before_adding is False

        self.db_interface_backend.add_firmware(self.test_firmware)
        after_adding = self.db_interface.is_firmware(self.test_firmware.uid)
        assert after_adding is True

    def test_is_file_object(self):
        before_adding = self.db_interface.is_file_object(self.test_fo.uid)
        assert before_adding is False

        self.db_interface_backend.add_file_object(self.test_fo)
        after_adding = self.db_interface.is_file_object(self.test_fo.uid)
        assert after_adding is True
Beispiel #6
0
class TestCompare(unittest.TestCase):
    """Tests for storing and retrieving compare results via ``CompareDbInterface``.

    One ``MongoMgr`` is started per class. Each test opens fresh DB interfaces
    and creates two test firmwares (``fw_one``, ``fw_two``) plus a matching
    compare dict.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared test config and start the MongoMgr once."""
        cls._config = get_config_for_testing()
        cls.mongo_server = MongoMgr(config=cls._config)

    def setUp(self):
        """Open all DB interfaces and create the two firmwares to compare."""
        self.db_interface = MongoInterfaceCommon(config=self._config)
        self.db_interface_backend = BackEndDbInterface(config=self._config)
        self.db_interface_compare = CompareDbInterface(config=self._config)
        self.db_interface_admin = AdminDbInterface(config=self._config)

        self.fw_one = create_test_firmware()
        self.fw_two = create_test_firmware()
        # Give the second firmware a different binary than the first one.
        self.fw_two.set_binary(b'another firmware')
        self.compare_dict = self._create_compare_dict()

    def tearDown(self):
        """Shut the interfaces down and drop the main database."""
        self.db_interface_compare.shutdown()
        self.db_interface_admin.shutdown()
        self.db_interface_backend.shutdown()
        self.db_interface.client.drop_database(
            self._config.get('data_storage', 'main_database'))
        self.db_interface.shutdown()
        gc.collect()

    @classmethod
    def tearDownClass(cls):
        """Stop the shared MongoMgr."""
        cls.mongo_server.shutdown()

    def _create_compare_dict(self):
        """Return a minimal compare result dict for ``fw_one`` and ``fw_two``.

        The dict has a ``general`` section with ``hid`` and
        ``virtual_file_path`` entries keyed by both firmware uids, and an
        empty ``plugins`` section.
        """
        return {
            'general': {
                'hid': {
                    self.fw_one.uid: 'foo',
                    self.fw_two.uid: 'bar'
                },
                'virtual_file_path': {
                    self.fw_one.uid: 'dev_one_name',
                    self.fw_two.uid: 'dev_two_name'
                }
            },
            'plugins': {}
        }

    def test_add_and_get_compare_result(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        self.db_interface_compare.add_compare_result(self.compare_dict)
        retrieved = self.db_interface_compare.get_compare_result(
            '{};{}'.format(self.fw_one.uid, self.fw_two.uid))
        self.assertEqual(
            retrieved['general']['virtual_file_path'][self.fw_one.uid],
            'dev_one_name', 'content of retrieval not correct')

    def test_get_not_existing_compare_result(self):
        # Both firmwares exist, but no compare result was stored for them.
        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        result = self.db_interface_compare.get_compare_result('{};{}'.format(
            self.fw_one.uid, self.fw_two.uid))
        self.assertIsNone(result, 'result not none')

    def test_calculate_compare_result_id(self):
        comp_id = self.db_interface_compare._calculate_compare_result_id(
            self.compare_dict)
        self.assertEqual(comp_id, '{};{}'.format(self.fw_one.uid,
                                                 self.fw_two.uid))

    def test_calculate_compare_result_id__incomplete_entries(self):
        # Each 'general' entry only mentions one of the two uids.
        compare_dict = {
            'general': {
                'stat_1': {
                    'a': None
                },
                'stat_2': {
                    'b': None
                }
            }
        }
        comp_id = self.db_interface_compare._calculate_compare_result_id(
            compare_dict)
        self.assertEqual('a;b', comp_id)

    def test_check_objects_exist(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        assert self.db_interface_compare.check_objects_exist(
            self.fw_one.uid) is None, 'existing_object not found'
        with pytest.raises(FactCompareException):
            self.db_interface_compare.check_objects_exist(
                '{};none_existing_object'.format(self.fw_one.uid))

    def test_get_compare_result_of_nonexistent_uid(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        # pytest.raises fails the test when no exception is raised at all; the
        # former try/except variant passed silently in that case.
        with pytest.raises(FactCompareException) as exception_info:
            self.db_interface_compare.check_objects_exist(
                '{};none_existing_object'.format(self.fw_one.uid))
        assert exception_info.value.get_message(
        ) == 'none_existing_object not found in database', 'error message not correct'

    def test_get_latest_comparisons(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        before = time()
        self.db_interface_compare.add_compare_result(self.compare_dict)
        result = self.db_interface_compare.page_compare_results(limit=10)
        # Without this check an empty result would skip the loop below and the
        # test would pass without asserting anything.
        self.assertNotEqual(result, [], 'A compare result should be available')
        for id_, hids, submission_date in result:
            self.assertIn(self.fw_one.uid, hids)
            self.assertIn(self.fw_two.uid, hids)
            self.assertIn(self.fw_one.uid, id_)
            self.assertIn(self.fw_two.uid, id_)
            # The submission date must lie between the time taken before
            # adding the result and now.
            self.assertTrue(before <= submission_date <= time())

    def test_get_latest_comparisons_removed_firmware(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        self.db_interface_compare.add_compare_result(self.compare_dict)

        result = self.db_interface_compare.page_compare_results(limit=10)
        self.assertNotEqual(result, [], 'A compare result should be available')

        # After deleting one of the compared firmwares the result must no
        # longer be listed.
        self.db_interface_admin.delete_firmware(self.fw_two.uid)

        result = self.db_interface_compare.page_compare_results(limit=10)

        self.assertEqual(result, [], 'No compare result should be available')

    def test_get_total_number_of_results(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        self.db_interface_compare.add_compare_result(self.compare_dict)

        number = self.db_interface_compare.get_total_number_of_results()
        self.assertEqual(number, 1, 'no compare result found in database')
Beispiel #7
0
class TestStorageDbInterfaceBackend(unittest.TestCase):
    '''
    Tests for adding, updating and reading firmware and file objects via
    BackEndDbInterface and MongoInterfaceCommon.
    '''
    @classmethod
    def setUpClass(cls):
        # config and MongoMgr are created once for the whole test class
        test_config = get_config_for_testing(TMP_DIR)
        cls._config = test_config
        cls.mongo_server = MongoMgr(config=test_config)

    def setUp(self):
        self.db_interface = MongoInterfaceCommon(config=self._config)
        self.db_interface_backend = BackEndDbInterface(config=self._config)

        self.test_firmware = create_test_firmware()

        yara_meta = {
            'description': 'SSH library',
            'website': 'http://www.openssh.com',
            'open_source': True,
            'software_name': 'OpenSSH',
        }
        self.test_yara_match = dict(
            rule='OpenSSH',
            tags=[],
            namespace='default',
            strings=[(0, '$a', b'OpenSSH')],
            meta=yara_meta,
            matches=True,
        )

        self.test_fo = create_test_file_object()

    def tearDown(self):
        # drop the main database after each test, then close both interfaces
        main_database = self._config.get('data_storage', 'main_database')
        self.db_interface.client.drop_database(main_database)
        for interface in (self.db_interface_backend, self.db_interface):
            interface.shutdown()
        gc.collect()

    @classmethod
    def tearDownClass(cls):
        cls.mongo_server.shutdown()
        TMP_DIR.cleanup()

    def _get_all_firmware_uids(self):
        '''
        Return the `_id` of every entry in the firmwares collection.
        '''
        return [entry['_id'] for entry in self.db_interface.firmwares.find()]

    def test_add_firmware(self):
        self.db_interface_backend.add_firmware(self.test_firmware)
        stored_uids = self._get_all_firmware_uids()
        self.assertGreater(len(stored_uids), 0, 'No entry added to DB')
        stored_entry = self.db_interface_backend.firmwares.find_one()
        # the submission date is expected to be (roughly) "now"
        self.assertAlmostEqual(stored_entry['submission_date'],
                               time(),
                               msg='submission time not set correctly',
                               delta=5.0)

    def test_add_and_get_firmware(self):
        self.db_interface_backend.add_firmware(self.test_firmware)

        from_backend = self.db_interface_backend.get_firmware(
            self.test_firmware.uid)
        self.assertIsNotNone(from_backend.binary,
                             'binary not set in backend result')

        from_common = self.db_interface.get_firmware(self.test_firmware.uid)
        self.assertIsNone(from_common.binary, 'binary set in common result')
        self.assertEqual(from_common.size, 787,
                         'file size not correct in common')
        self.assertIsInstance(from_common.tags, dict,
                              'tag field type not correct')

    def test_add_and_get_file_object(self):
        self.db_interface_backend.add_file_object(self.test_fo)

        from_backend = self.db_interface_backend.get_file_object(
            self.test_fo.uid)
        self.assertIsNotNone(from_backend.binary,
                             'binary not set in backend result')

        from_common = self.db_interface.get_file_object(self.test_fo.uid)
        self.assertIsNone(from_common.binary, 'binary set in common result')
        self.assertEqual(from_common.size, 62,
                         'file size not correct in common')

    def test_update_firmware(self):
        initial_analysis = {
            'stub_plugin': {'result': 0},
            'other_plugin': {'field': 'day'},
        }
        updated_analysis = {'stub_plugin': {'result': 1}}

        def stored_analysis():
            # always re-read the object so each assertion sees the DB state
            return self.db_interface.get_object(
                self.test_firmware.uid).processed_analysis

        self.test_firmware.processed_analysis = initial_analysis
        self.db_interface_backend.add_firmware(self.test_firmware)
        self.assertEqual(0, stored_analysis()['stub_plugin']['result'])

        self.test_firmware.processed_analysis = updated_analysis
        self.db_interface_backend.add_firmware(self.test_firmware)
        self.assertEqual(1, stored_analysis()['stub_plugin']['result'])
        # the plugin missing from the second dict must still be present
        self.assertIn('other_plugin', stored_analysis().keys())

    def test_update_file_object(self):
        initial_analysis = {'other_plugin': {'result': 0}}
        updated_analysis = {'stub_plugin': {'result': 1}}

        self.test_fo.processed_analysis = initial_analysis
        self.test_fo.files_included = {'file a', 'file b'}
        self.db_interface_backend.add_file_object(self.test_fo)

        self.test_fo.processed_analysis = updated_analysis
        self.test_fo.files_included = {'file b', 'file c'}
        self.db_interface_backend.add_file_object(self.test_fo)

        stored_fo = self.db_interface.get_object(self.test_fo.uid)
        stored_analysis = stored_fo.processed_analysis
        self.assertEqual(0, stored_analysis['other_plugin']['result'])
        self.assertEqual(1, stored_analysis['stub_plugin']['result'])
        # 'file a', 'file b' and 'file c' are all expected to be kept
        self.assertEqual(3, len(stored_fo.files_included))

    def test_add_and_get_object_including_comment(self):
        comment = 'this is a test comment!'
        author = 'author'
        date = '1473431685'
        uid = self.test_fo.uid
        new_comment = {'time': str(date), 'author': author, 'comment': comment}
        self.test_fo.comments.append(new_comment)
        self.db_interface_backend.add_file_object(self.test_fo)

        stored_comment = self.db_interface.get_object(uid).comments[0]
        self.assertEqual(author, stored_comment['author'])
        self.assertEqual(comment, stored_comment['comment'])
        self.assertEqual(date, stored_comment['time'])

    def test_update_analysis_tag_no_firmware(self):
        self.db_interface_backend.add_file_object(self.test_fo)
        tag = {'value': 'yay', 'color': 'default', 'propagate': True}
        tag_arguments = dict(plugin_name='dummy', tag_name='some_tag', tag=tag)

        self.db_interface_backend.update_analysis_tags(self.test_fo.uid,
                                                       **tag_arguments)
        stored_fo = self.db_interface_backend.get_object(self.test_fo.uid)

        assert not stored_fo.analysis_tags

    def test_update_analysis_tag_uid_not_found(self):
        # nothing was added to the DB before the update is attempted
        tag_arguments = dict(plugin_name='dummy',
                             tag_name='some_tag',
                             tag='should not matter')
        self.db_interface_backend.update_analysis_tags(self.test_fo.uid,
                                                       **tag_arguments)
        assert not self.db_interface_backend.get_object(self.test_fo.uid)

    def test_update_analysis_tag_bad_tag(self):
        self.db_interface_backend.add_firmware(self.test_firmware)
        tag_arguments = dict(plugin_name='dummy',
                             tag_name='some_tag',
                             tag='bad_tag')

        self.db_interface_backend.update_analysis_tags(self.test_firmware.uid,
                                                       **tag_arguments)
        stored_firmware = self.db_interface_backend.get_object(
            self.test_firmware.uid)

        assert not stored_firmware.analysis_tags

    def test_update_analysis_tag_success(self):
        self.db_interface_backend.add_firmware(self.test_firmware)
        tag = {'value': 'yay', 'color': 'default', 'propagate': True}
        tag_arguments = dict(plugin_name='dummy', tag_name='some_tag', tag=tag)

        self.db_interface_backend.update_analysis_tags(self.test_firmware.uid,
                                                       **tag_arguments)
        stored_firmware = self.db_interface_backend.get_object(
            self.test_firmware.uid)

        assert stored_firmware.analysis_tags
        assert stored_firmware.analysis_tags['dummy']['some_tag'] == tag

    def test_add_analysis_firmware(self):
        def stored_analysis():
            return self.db_interface_backend.get_object(
                self.test_firmware.uid).processed_analysis

        self.db_interface_backend.add_object(self.test_firmware)
        before = stored_analysis()

        self.test_firmware.processed_analysis['foo'] = {'bar': 5}
        self.db_interface_backend.add_analysis(self.test_firmware)
        after = stored_analysis()

        assert before != after
        assert 'foo' not in before
        assert 'foo' in after
        assert after['foo'] == {'bar': 5}

    def test_add_analysis_file_object(self):
        self.db_interface_backend.add_object(self.test_fo)

        self.test_fo.processed_analysis['foo'] = {'bar': 5}
        self.db_interface_backend.add_analysis(self.test_fo)
        stored_fo = self.db_interface_backend.get_object(self.test_fo.uid)
        analysis = stored_fo.processed_analysis

        assert 'foo' in analysis
        assert analysis['foo'] == {'bar': 5}

    def test_crash_add_analysis(self):
        # a plain dict is neither firmware nor file object
        with self.assertRaises(RuntimeError):
            self.db_interface_backend.add_analysis({})

        with self.assertRaises(AttributeError):
            self.db_interface_backend._update_analysis({}, 'dummy', {})
Beispiel #8
0
class UnpackingScheduler(object):
    '''
    This scheduler performs unpacking on firmware objects
    '''
    def __init__(self,
                 config=None,
                 post_unpack=None,
                 analysis_workload=None,
                 db_interface=None):
        '''
        Set up the queue and shared flags, drop cached unpacking locks and
        start the unpack workers as well as the work load monitor.

        config: configuration; `unpack.threads`, `ExpertSettings.block_delay`,
            `ExpertSettings.unpack_throttle_limit` and
            `ExpertSettings.throw_exceptions` are read from it
        post_unpack: callable that is called with each object after unpacking
        analysis_workload: optional callable returning a dict whose values are
            summed up to the current analysis workload
        db_interface: optional DB interface; a MongoInterfaceCommon is created
            from `config` if none (or a falsy value) is given
        '''
        self.config = config
        # shared flags: != 0 means "stop" / "throttle"
        self.stop_condition = Value('i', 0)
        self.throttle_condition = Value('i', 0)
        self.get_analysis_workload = analysis_workload
        self.in_queue = Queue()
        # starts at the threshold so the first monitor pass logs at info level
        self.work_load_counter = 25
        self.workers = []
        # reference to the monitor process; set by start_work_load_monitor()
        self.work_load_process = None
        self.post_unpack = post_unpack
        self.db_interface = MongoInterfaceCommon(
            config) if not db_interface else db_interface
        self.drop_cached_locks()
        self.start_unpack_workers()
        self.start_work_load_monitor()
        logging.info('Unpacker Module online')

    def drop_cached_locks(self):
        '''
        drop the unpacking locks stored in the DB interface
        '''
        self.db_interface.drop_unpacking_locks()

    def add_task(self, fo):
        '''
        schedule a firmware_object for unpacking
        '''
        self.in_queue.put(fo)

    def get_scheduled_workload(self):
        '''
        return the current size of the unpacking queue as dict
        '''
        return {'unpacking_queue': self.in_queue.qsize()}

    def shutdown(self):
        '''
        shutdown the scheduler
        '''
        logging.debug('Shutting down...')
        # all worker loops and the monitor loop poll this flag
        self.stop_condition.value = 1
        for worker in self.workers:
            worker.join()
        self.in_queue.close()
        logging.info('Unpacker Module offline')

# ---- internal functions ----

    def start_unpack_workers(self):
        '''
        start one unpack worker process per configured thread
        '''
        thread_count = int(self.config['unpack']['threads'])
        logging.debug('Starting {} working threads'.format(thread_count))
        for process_index in range(thread_count):
            self._start_single_worker(process_index)

    def unpack_worker(self, worker_id):
        '''
        worker loop: take objects from the queue, unpack them, call
        `post_unpack` and schedule the extracted objects until the stop
        condition is set
        '''
        unpacker = Unpacker(self.config,
                            worker_id=worker_id,
                            db_interface=self.db_interface)
        while self.stop_condition.value == 0:
            # an empty queue (get timeout) simply leads to the next poll
            with suppress(Empty):
                fo = self.in_queue.get(timeout=float(
                    self.config['ExpertSettings']['block_delay']))
                extracted_objects = unpacker.unpack(fo)
                logging.debug(
                    '[worker {}] unpacking of {} complete: {} files extracted'.
                    format(worker_id, fo.get_uid(), len(extracted_objects)))
                self.post_unpack(fo)
                self.schedule_extracted_files(extracted_objects)

    def schedule_extracted_files(self, object_list):
        '''
        put all objects of `object_list` into the unpacking queue
        '''
        for item in object_list:
            self._add_object_to_unpack_queue(item)

    def _add_object_to_unpack_queue(self, item):
        '''
        enqueue `item` as soon as unpacking is not throttled; gives up
        (without enqueueing) when the stop condition is set
        '''
        while self.stop_condition.value == 0:
            if self.throttle_condition.value == 0:
                self.in_queue.put(item)
                break
            logging.debug(
                'throttle down unpacking to reduce memory consumption...')
            sleep(5)

    def start_work_load_monitor(self):
        '''
        start the work load monitor process and register it as worker
        '''
        logging.debug('Start work load monitor...')
        process = ExceptionSafeProcess(target=self._work_load_monitor)
        process.start()
        self.workers.append(process)
        # remembered so that check_exceptions() can tell it apart from the
        # unpack workers
        self.work_load_process = process

    def _work_load_monitor(self):
        '''
        monitor loop: log the queue lengths and set the throttle condition
        depending on the analysis workload
        '''
        while self.stop_condition.value == 0:
            workload = self._get_combined_analysis_workload()
            unpack_queue_size = self.in_queue.qsize()

            # log at info level only on every 26th pass, otherwise at debug
            if self.work_load_counter >= 25:
                self.work_load_counter = 0
                log_function = logging.info
            else:
                self.work_load_counter += 1
                log_function = logging.debug
            log_function('{}Queue Length (Analysis/Unpack): {} / {}{}'.format(
                bcolors.WARNING, workload, unpack_queue_size, bcolors.ENDC))

            if workload < int(
                    self.config['ExpertSettings']['unpack_throttle_limit']):
                self.throttle_condition.value = 0
            else:
                self.throttle_condition.value = 1
            sleep(2)

    def _get_combined_analysis_workload(self):
        '''
        return the sum of all values reported by the analysis workload
        callable (0 if no callable was given)
        '''
        if self.get_analysis_workload is not None:
            current_analysis_workload = self.get_analysis_workload()
            return sum(current_analysis_workload.values())
        return 0

    def check_exceptions(self):
        '''
        check all worker processes for exceptions

        A crashed process is terminated and removed from the worker list. If
        `ExpertSettings.throw_exceptions` is set, True is returned; otherwise
        a replacement process is started.

        returns True if an exception was found and exceptions shall be thrown
        '''
        return_value = False
        # iterate over a snapshot: the loop removes from and appends to
        # self.workers, which would otherwise skip entries
        for worker in list(self.workers):
            if not worker.exception:
                continue
            logging.error("{}Worker Exception Found!!{}".format(
                bcolors.FAIL, bcolors.ENDC))
            logging.error(worker.exception[1])
            terminate_process_and_childs(worker)
            self.workers.remove(worker)

            if self.config.getboolean('ExpertSettings', 'throw_exceptions'):
                return_value = True
            elif worker is self.work_load_process:
                # the monitor is started without a name, so it cannot be
                # restarted through the 'Unpacking-Worker-<index>' pattern
                self.start_work_load_monitor()
            else:
                # keep the index an int like the one used on first start
                process_index = int(worker.name.split('-')[2])
                self._start_single_worker(process_index)
        return return_value

    def _start_single_worker(self, process_index):
        '''
        start a single unpack worker named 'Unpacking-Worker-<process_index>'
        '''
        process = ExceptionSafeProcess(
            target=self.unpack_worker,
            name='Unpacking-Worker-{}'.format(process_index),
            args=(process_index, ))
        process.start()
        self.workers.append(process)
Beispiel #9
0
class UnpackingScheduler:
    '''
    This scheduler performs unpacking on firmware objects
    '''
    def __init__(self,
                 config=None,
                 post_unpack=None,
                 analysis_workload=None,
                 db_interface=None):
        '''
        Set up the queue and shared flags, drop cached unpacking locks and
        start the unpack workers as well as the work load monitor.

        config: configuration; `unpack.threads`, `ExpertSettings.block_delay`
            and `ExpertSettings.unpack_throttle_limit` are read from it
        post_unpack: callable that is called with each object after unpacking
        analysis_workload: optional callable returning the current analysis
            workload (compared against the throttle limit)
        db_interface: optional DB interface; a MongoInterfaceCommon is created
            from `config` if none (or a falsy value) is given
        '''
        self.config = config
        # shared flags: != 0 means "stop" / "throttle"
        self.stop_condition = Value('i', 0)
        self.throttle_condition = Value('i', 0)
        self.get_analysis_workload = analysis_workload
        self.in_queue = Queue()
        # starts at the threshold so the first monitor pass logs at info level
        self.work_load_counter = 25
        self.workers = []
        self.post_unpack = post_unpack
        self.db_interface = MongoInterfaceCommon(
            config) if not db_interface else db_interface
        self.drop_cached_locks()
        self.start_unpack_workers()
        self.work_load_process = self.start_work_load_monitor()
        logging.info('Unpacker Module online')

    def drop_cached_locks(self):
        '''
        drop the unpacking locks stored in the DB interface
        '''
        self.db_interface.drop_unpacking_locks()

    def add_task(self, fo):
        '''
        schedule a firmware_object for unpacking
        '''
        self.in_queue.put(fo)

    def get_scheduled_workload(self):
        '''
        return the current size of the unpacking queue as dict
        '''
        return {'unpacking_queue': self.in_queue.qsize()}

    def shutdown(self):
        '''
        shutdown the scheduler
        '''
        logging.debug('Shutting down...')
        # all worker loops and the monitor loop poll this flag
        self.stop_condition.value = 1
        for worker in self.workers:
            worker.join()
        self.work_load_process.join()
        self.in_queue.close()
        logging.info('Unpacker Module offline')

# ---- internal functions ----

    def start_unpack_workers(self):
        '''
        start one unpack worker per configured thread
        '''
        threads = int(self.config['unpack']['threads'])
        logging.debug(f'Starting {threads} working threads')
        for process_index in range(threads):
            self.workers.append(
                start_single_worker(process_index, 'Unpacking',
                                    self.unpack_worker))

    def unpack_worker(self, worker_id):
        '''
        worker loop: take objects from the queue, unpack them, call
        `post_unpack` and schedule the extracted objects until the stop
        condition is set
        '''
        unpacker = Unpacker(self.config,
                            worker_id=worker_id,
                            db_interface=self.db_interface)
        while self.stop_condition.value == 0:
            # an empty queue (get timeout) simply leads to the next poll
            with suppress(Empty):
                fo = self.in_queue.get(timeout=float(
                    self.config['ExpertSettings']['block_delay']))
                extracted_objects = unpacker.unpack(fo)
                logging.debug(
                    f'[worker {worker_id}] unpacking of {fo.uid} complete: {len(extracted_objects)} files extracted'
                )
                self.post_unpack(fo)
                self.schedule_extracted_files(extracted_objects)

    def schedule_extracted_files(self, object_list):
        '''
        put all objects of `object_list` into the unpacking queue
        '''
        for item in object_list:
            self._add_object_to_unpack_queue(item)

    def _add_object_to_unpack_queue(self, item):
        '''
        enqueue `item` as soon as unpacking is not throttled; gives up
        (without enqueueing) when the stop condition is set
        '''
        while self.stop_condition.value == 0:
            if self.throttle_condition.value == 0:
                self.in_queue.put(item)
                break
            logging.debug(
                'throttle down unpacking to reduce memory consumption...')
            sleep(5)

    def start_work_load_monitor(self):
        '''
        start the work load monitor and return what start_single_worker
        returns for it
        '''
        logging.debug('Start work load monitor...')
        return start_single_worker(None, 'unpack-load',
                                   self._work_load_monitor)

    def _work_load_monitor(self):
        '''
        monitor loop: log the queue lengths and set the throttle condition
        depending on the analysis workload
        '''
        while self.stop_condition.value == 0:
            workload = self._get_combined_analysis_workload()
            unpack_queue_size = self.in_queue.qsize()

            # log at info level only on every 26th pass, otherwise at debug
            if self.work_load_counter >= 25:
                self.work_load_counter = 0
                log_function = logging.info
            else:
                self.work_load_counter += 1
                log_function = logging.debug
            log_function(
                color_string(
                    f'Queue Length (Analysis/Unpack): {workload} / {unpack_queue_size}',
                    TerminalColors.WARNING))

            if workload < int(
                    self.config['ExpertSettings']['unpack_throttle_limit']):
                self.throttle_condition.value = 0
            else:
                self.throttle_condition.value = 1
            sleep(2)

    def _get_combined_analysis_workload(self):
        '''
        return the value of the analysis workload callable (0 if no callable
        was given)
        '''
        if self.get_analysis_workload is not None:
            return self.get_analysis_workload()
        return 0

    def check_exceptions(self):
        '''
        check the unpack workers and the work load monitor for exceptions

        returns the combined result of check_worker_exceptions for both
        '''
        shutdown = check_worker_exceptions(self.workers, 'Unpacking',
                                           self.config, self.unpack_worker)

        # the monitor is wrapped in a list because check_worker_exceptions
        # works on (and modifies) a list of processes
        list_with_load_process = [
            self.work_load_process,
        ]
        shutdown |= check_worker_exceptions(list_with_load_process,
                                            'unpack-load', self.config,
                                            self._work_load_monitor)
        # NOTE(review): presumably check_worker_exceptions can remove a crashed
        # process without adding a replacement (e.g. when exceptions shall be
        # thrown) -- confirm; the emptiness check avoids an IndexError then
        if list_with_load_process and new_worker_was_started(
                new_process=list_with_load_process[0],
                old_process=self.work_load_process):
            self.work_load_process = list_with_load_process.pop()

        return shutdown
Beispiel #10
0
class TestCompare:
    '''
    Tests for storing and retrieving compare results via CompareDbInterface.
    '''
    @classmethod
    def setup_class(cls):
        # config and MongoMgr are created once for the whole test class
        cls._config = get_config_for_testing()
        cls.mongo_server = MongoMgr(config=cls._config)

    # NOTE(review): pytest 8 dropped support for nose-style `setup` /
    # `teardown` methods; `setup_method` / `teardown_method` are the native
    # names -- confirm against the pytest version in use
    def setup(self):
        self.db_interface = MongoInterfaceCommon(config=self._config)
        self.db_interface_backend = BackEndDbInterface(config=self._config)
        self.db_interface_compare = CompareDbInterface(config=self._config)
        self.db_interface_admin = AdminDbInterface(config=self._config)

        self.fw_one = create_test_firmware()
        self.fw_two = create_test_firmware()
        # a different binary is set for the second firmware
        self.fw_two.set_binary(b'another firmware')
        self.compare_dict = self._create_compare_dict()
        self.compare_id = '{};{}'.format(self.fw_one.uid, self.fw_two.uid)

    def teardown(self):
        self.db_interface_compare.shutdown()
        self.db_interface_admin.shutdown()
        self.db_interface_backend.shutdown()
        # drop the main database after each test
        self.db_interface.client.drop_database(
            self._config.get('data_storage', 'main_database'))
        self.db_interface.shutdown()
        gc.collect()

    @classmethod
    def teardown_class(cls):
        cls.mongo_server.shutdown()

    def _create_compare_dict(self):
        '''
        Build a minimal compare result for the two test firmwares.
        '''
        return {
            'general': {
                'hid': {
                    self.fw_one.uid: 'foo',
                    self.fw_two.uid: 'bar'
                },
                'virtual_file_path': {
                    self.fw_one.uid: 'dev_one_name',
                    self.fw_two.uid: 'dev_two_name'
                }
            },
            'plugins': {},
        }

    def test_add_and_get_compare_result(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        self.db_interface_compare.add_compare_result(self.compare_dict)
        retrieved = self.db_interface_compare.get_compare_result(
            self.compare_id)
        assert retrieved['general']['virtual_file_path'][self.fw_one.uid] == 'dev_one_name',\
            'content of retrieval not correct'

    def test_get_not_existing_compare_result(self):
        # both firmwares exist, but no compare result was stored
        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        result = self.db_interface_compare.get_compare_result(self.compare_id)
        assert result is None, 'result not none'

    def test_calculate_compare_result_id(self):
        comp_id = self.db_interface_compare._calculate_compare_result_id(
            self.compare_dict)
        assert comp_id == self.compare_id

    def test_calculate_compare_result_id__incomplete_entries(self):
        # each UID appears in only one of the 'general' entries
        compare_dict = {
            'general': {
                'stat_1': {
                    'a': None
                },
                'stat_2': {
                    'b': None
                }
            }
        }
        comp_id = self.db_interface_compare._calculate_compare_result_id(
            compare_dict)
        assert comp_id == 'a;b'

    def test_check_objects_exist(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        assert self.db_interface_compare.check_objects_exist(
            self.fw_one.uid) is None, 'existing_object not found'
        with pytest.raises(FactCompareException):
            self.db_interface_compare.check_objects_exist(
                '{};none_existing_object'.format(self.fw_one.uid))

    def test_get_compare_result_of_nonexistent_uid(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        # pytest.raises fails the test if no exception is raised at all; the
        # former try/except passed silently in that case
        with pytest.raises(FactCompareException) as exception_info:
            self.db_interface_compare.check_objects_exist(
                '{};none_existing_object'.format(self.fw_one.uid))
        assert exception_info.value.get_message(
        ) == 'none_existing_object not found in database', 'error message not correct'

    def test_get_latest_comparisons(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        before = time()
        self.db_interface_compare.add_compare_result(self.compare_dict)
        result = self.db_interface_compare.page_compare_results(limit=10)
        # all remaining assertions live inside the loop; without this check an
        # empty page would let the test pass without verifying anything
        assert result != [], 'A compare result should be available'
        for id_, hids, submission_date in result:
            assert self.fw_one.uid in hids
            assert self.fw_two.uid in hids
            assert self.fw_one.uid in id_
            assert self.fw_two.uid in id_
            assert before <= submission_date <= time()

    def test_get_latest_comparisons_removed_firmware(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        self.db_interface_compare.add_compare_result(self.compare_dict)

        result = self.db_interface_compare.page_compare_results(limit=10)
        assert result != [], 'A compare result should be available'

        self.db_interface_admin.delete_firmware(self.fw_two.uid)

        result = self.db_interface_compare.page_compare_results(limit=10)

        assert result == [], 'No compare result should be available'

    def test_get_total_number_of_results(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        self.db_interface_compare.add_compare_result(self.compare_dict)

        number = self.db_interface_compare.get_total_number_of_results()
        assert number == 1, 'no compare result found in database'

    @pytest.mark.parametrize('root_uid, expected_result', [
        ('the_root_uid', ['uid1', 'uid2']),
        ('some_other_uid', []),
        (None, []),
    ])
    def test_get_exclusive_files(self, root_uid, expected_result):
        compare_dict = self._create_compare_dict()
        compare_dict['plugins'] = {
            'File_Coverage': {
                'exclusive_files': {
                    'the_root_uid': ['uid1', 'uid2']
                }
            }
        }

        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        self.db_interface_compare.add_compare_result(compare_dict)
        exclusive_files = self.db_interface_compare.get_exclusive_files(
            self.compare_id, root_uid)
        assert exclusive_files == expected_result
class TestCompare(unittest.TestCase):
    '''
    Tests for storing and retrieving compare results via CompareDbInterface.
    '''
    def setUp(self):
        self._config = get_config_for_testing()
        self.mongo_server = MongoMgr(config=self._config)
        self.db_interface = MongoInterfaceCommon(config=self._config)
        self.db_interface_backend = BackEndDbInterface(config=self._config)
        self.db_interface_compare = CompareDbInterface(config=self._config)
        self.db_interface_admin = AdminDbInterface(config=self._config)

        self.fw_one = create_test_firmware()
        self.fw_two = create_test_firmware()
        # a different binary is set for the second firmware
        self.fw_two.set_binary(b'another firmware')
        self.compare_dict = self._create_compare_dict()

    def tearDown(self):
        self.db_interface_compare.shutdown()
        self.db_interface_admin.shutdown()
        self.db_interface_backend.shutdown()
        # drop the main database after each test
        self.db_interface.client.drop_database(
            self._config.get('data_storage', 'main_database'))
        self.db_interface.shutdown()
        self.mongo_server.shutdown()
        gc.collect()

    def _create_compare_dict(self):
        '''
        Build a minimal compare result for the two test firmwares.
        '''
        comp_dict = {
            'general': {
                'hid': {
                    self.fw_one.get_uid(): 'foo',
                    self.fw_two.get_uid(): 'bar'
                }
            },
            'plugins': {}
        }
        comp_dict['general']['virtual_file_path'] = {
            self.fw_one.get_uid(): 'dev_one_name',
            self.fw_two.get_uid(): 'dev_two_name'
        }
        return comp_dict

    def test_add_and_get_compare_result(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        self.db_interface_compare.add_compare_result(self.compare_dict)
        retrieved = self.db_interface_compare.get_compare_result(
            '{};{}'.format(self.fw_one.get_uid(), self.fw_two.get_uid()))
        self.assertEqual(
            retrieved['general']['virtual_file_path'][self.fw_one.get_uid()],
            'dev_one_name', 'content of retrieval not correct')

    def test_get_not_existing_compare_result(self):
        # both firmwares exist, but no compare result was stored
        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        result = self.db_interface_compare.get_compare_result('{};{}'.format(
            self.fw_one.get_uid(), self.fw_two.get_uid()))
        self.assertIsNone(result, 'result not none')

    def test_calculate_compare_result_id(self):
        comp_id = self.db_interface_compare._calculate_compare_result_id(
            self.compare_dict)
        self.assertEqual(
            comp_id, '{};{}'.format(self.fw_one.get_uid(),
                                    self.fw_two.get_uid()))

    def test_object_existence_quick_check(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        self.assertIsNone(
            self.db_interface_compare.object_existence_quick_check(
                self.fw_one.get_uid()), 'existing_object not found')
        self.assertEqual(
            self.db_interface_compare.object_existence_quick_check(
                '{};none_existing_object'.format(self.fw_one.get_uid())),
            'none_existing_object not found in database',
            'error message not correct')

    def test_get_compare_result_of_none_existing_uid(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        result = self.db_interface_compare.get_compare_result(
            '{};none_existing_uid'.format(self.fw_one.get_uid()))
        self.assertEqual(result, 'none_existing_uid not found in database',
                         'no result not found error')

    def test_get_latest_comparisons(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        before = time()
        self.db_interface_compare.add_compare_result(self.compare_dict)
        result = self.db_interface_compare.page_compare_results(limit=10)
        # all remaining assertions live inside the loop; without this check an
        # empty page would let the test pass without verifying anything
        self.assertNotEqual(result, [], 'A compare result should be available')
        # `id_` instead of `id` to avoid shadowing the builtin
        for id_, hids, submission_date in result:
            self.assertIn(self.fw_one.get_uid(), hids)
            self.assertIn(self.fw_two.get_uid(), hids)
            self.assertIn(self.fw_one.get_uid(), id_)
            self.assertIn(self.fw_two.get_uid(), id_)
            self.assertTrue(before <= submission_date <= time())

    def test_get_latest_comparisons_removed_firmware(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        self.db_interface_compare.add_compare_result(self.compare_dict)

        result = self.db_interface_compare.page_compare_results(limit=10)
        self.assertNotEqual(result, [], 'A compare result should be available')

        self.db_interface_admin.delete_firmware(self.fw_two.uid)

        result = self.db_interface_compare.page_compare_results(limit=10)

        self.assertEqual(result, [], 'No compare result should be available')

    def test_get_total_number_of_results(self):
        self.db_interface_backend.add_firmware(self.fw_one)
        self.db_interface_backend.add_firmware(self.fw_two)
        self.db_interface_compare.add_compare_result(self.compare_dict)

        number = self.db_interface_compare.get_total_number_of_results()
        self.assertEqual(number, 1, 'no compare result found in database')