Esempio n. 1
0
class TestTagPropagation(unittest.TestCase):
    """Check that analysis tags of an included file reach its firmware.

    The fixture chains an unpacking scheduler into an analysis scheduler and
    attaches a tagging daemon to the latter. Everything works on a temporary
    configuration and a ``MongoMgr`` instance created for the test.
    """

    def setUp(self):
        """Create config, database handles and the three schedulers."""
        self._tmp_dir = TemporaryDirectory()
        self._config = initialize_config(self._tmp_dir)
        self.analysis_finished_event = Event()
        # UID the post-analysis callback waits for before signalling the test.
        self.uid_of_key_file = '530bf2f1203b789bfe054d3118ebd29a04013c587efd22235b3b9677cee21c0e_2048'

        self._mongo_server = MongoMgr(config=self._config, auth=False)
        self.backend_interface = BackEndDbInterface(config=self._config)

        self._analysis_scheduler = AnalysisScheduler(
            config=self._config,
            post_analysis=self.count_analysis_finished_event)
        self._tagging_scheduler = TaggingDaemon(
            analysis_scheduler=self._analysis_scheduler)
        self._unpack_scheduler = UnpackingScheduler(
            config=self._config, post_unpack=self._analysis_scheduler.add_task)

    def count_analysis_finished_event(self, fw_object):
        """Post-analysis callback: store the object and flag the key file.

        :param fw_object: the analysed object handed over by the analysis
            scheduler; it is written to the database and, if its UID equals
            ``uid_of_key_file``, ``analysis_finished_event`` is set.
        """
        self.backend_interface.add_object(fw_object)
        if fw_object.uid == self.uid_of_key_file:
            self.analysis_finished_event.set()

    def _wait_for_empty_tag_queue(self, timeout=30.0, poll_interval=0.1):
        """Block until the analysis scheduler's tag queue is empty.

        :param timeout: maximum number of seconds to wait before the test is
            failed; prevents an endless loop if the queue is never drained.
        :param poll_interval: seconds to sleep between two queue checks.
        """
        # Local import keeps this block self-contained; only ``sleep`` is
        # known to be imported at module level.
        from time import monotonic

        deadline = monotonic() + timeout
        while not self._analysis_scheduler.tag_queue.empty():
            if monotonic() > deadline:
                self.fail('tag queue was not emptied within {} seconds'.format(timeout))
            sleep(poll_interval)

    def tearDown(self):
        """Stop all schedulers, release the database and remove temp data."""
        self._unpack_scheduler.shutdown()
        self._tagging_scheduler.shutdown()
        self._analysis_scheduler.shutdown()

        # Close the backend connection before the database server goes away.
        self.backend_interface.shutdown()

        clean_test_database(self._config, get_database_names(self._config))
        self._mongo_server.shutdown()

        self._tmp_dir.cleanup()
        gc.collect()

    def test_run_analysis_with_tag(self):
        """A tag set on the key file must show up in the firmware's tags."""
        test_fw = Firmware(
            file_path='{}/container/with_key.7z'.format(get_test_data_dir()))
        test_fw.release_date = '2017-01-01'
        test_fw.scheduled_analysis = ['crypto_material']

        self._unpack_scheduler.add_task(test_fw)

        assert self.analysis_finished_event.wait(timeout=20)

        processed_fo = self.backend_interface.get_object(
            self.uid_of_key_file, analysis_filter=['crypto_material'])
        assert processed_fo.processed_analysis['crypto_material'][
            'tags'], 'no tags set in analysis'

        self._wait_for_empty_tag_queue()

        processed_fw = self.backend_interface.get_object(
            test_fw.uid, analysis_filter=['crypto_material'])
        assert processed_fw.analysis_tags, 'tags not propagated properly'
        assert processed_fw.analysis_tags['crypto_material'][
            'private_key_inside']
Esempio n. 2
0
class TestAcceptanceAdvancedSearch(TestAcceptanceBase):
    """Acceptance tests for the advanced search page and the recursive REST search.

    Each test runs against a database that holds one firmware and one file
    object whose virtual file path points into that firmware.
    """

    _SEARCH_ROUTE = '/database/advanced_search'

    def setUp(self):
        """Start the backend and store a parent firmware plus one child file."""
        super().setUp()
        self._start_backend()
        self.db_backend_interface = BackEndDbInterface(self.config)

        self.parent_fw = create_test_firmware()
        self.child_fo = create_test_file_object()
        parent_uid = self.parent_fw.uid
        virtual_path = '|{}|/folder/{}'.format(parent_uid, self.child_fo.file_name)
        self.child_fo.virtual_file_path = {parent_uid: [virtual_path]}
        self.db_backend_interface.add_object(self.parent_fw)

        analysis = self.child_fo.processed_analysis
        analysis['unpacker'] = {'plugin_used': 'test'}
        analysis['file_type']['mime'] = 'some_type'
        self.db_backend_interface.add_object(self.child_fo)

    def tearDown(self):
        """Release the database handle, then stop the backend."""
        self.db_backend_interface.shutdown()
        self._stop_backend()
        super().tearDown()

    def _post_search(self, form_data):
        """POST ``form_data`` to the advanced search route and return the response body."""
        response = self.test_client.post(
            self._SEARCH_ROUTE,
            content_type='multipart/form-data',
            data=form_data,
            follow_redirects=True,
        )
        return response.data

    def test_advanced_search_get(self):
        page = self.test_client.get(self._SEARCH_ROUTE).data
        self.assertIn(b'<h2>Advanced Search</h2>', page)

    def test_advanced_search(self):
        page = self._post_search({'advanced_search': '{}'})
        self.assertNotIn(b'Please enter a valid search request', page)
        self.assertIn(self.parent_fw.uid.encode(), page)
        self.assertNotIn(self.child_fo.uid.encode(), page)

    def test_advanced_search_file_object(self):
        query = json.dumps({'_id': self.child_fo.uid})
        page = self._post_search({'advanced_search': query})
        uid_label = b'<strong>UID:</strong> '
        self.assertNotIn(b'Please enter a valid search request', page)
        self.assertNotIn(uid_label + self.parent_fw.uid.encode(), page)
        self.assertIn(uid_label + self.child_fo.uid.encode(), page)

    def test_advanced_search_only_firmwares(self):
        query = json.dumps({'_id': self.child_fo.uid})
        page = self._post_search({'advanced_search': query, 'only_firmwares': 'True'})
        self.assertNotIn(b'Please enter a valid search request', page)
        self.assertNotIn(self.child_fo.uid.encode(), page)
        self.assertIn(self.parent_fw.uid.encode(), page)

    def test_rest_recursive_firmware_search(self):
        encoded_query = quote(json.dumps({'file_name': self.child_fo.file_name}))
        url = '/rest/firmware?recursive=true&query={}'.format(encoded_query)
        body = self.test_client.get(url).data
        self.assertNotIn(b'error_message', body)
        self.assertIn(self.parent_fw.uid.encode(), body)
Esempio n. 3
0
class TestAcceptanceBaseFullStart(TestAcceptanceBase):
    """Acceptance-test base that starts the backend with completion callbacks.

    Subclasses can wait on ``analysis_finished_event`` and
    ``compare_finished_event``. The analysis event is set once
    ``NUMBER_OF_FILES_TO_ANALYZE * NUMBER_OF_PLUGINS`` analysis callbacks
    have been counted.
    """

    NUMBER_OF_FILES_TO_ANALYZE = 4
    NUMBER_OF_PLUGINS = 2

    def setUp(self):
        """Create events and counter, then start the backend with callbacks."""
        super().setUp()
        self.analysis_finished_event = Event()
        self.compare_finished_event = Event()
        self.elements_finished_analyzing = Value('i', 0)
        self.db_backend_service = BackEndDbInterface(config=self.config)
        self._start_backend(post_analysis=self._analysis_callback, compare_callback=self._compare_callback)
        time.sleep(2)  # wait for systems to start

    def tearDown(self):
        """Stop the backend and release the database handle."""
        self._stop_backend()
        self.db_backend_service.shutdown()
        super().tearDown()

    def _analysis_callback(self, fo):
        """Store an analysed object and count it towards the finished event.

        :param fo: the analysed object passed in by the backend.
        """
        self.db_backend_service.add_object(fo)
        # ``value += 1`` is a separate read and write; hold the Value's lock
        # so concurrent callbacks cannot lose an increment or skip the
        # exact-match check below.
        with self.elements_finished_analyzing.get_lock():
            self.elements_finished_analyzing.value += 1
            finished = self.elements_finished_analyzing.value
        if finished == self.NUMBER_OF_FILES_TO_ANALYZE * self.NUMBER_OF_PLUGINS:
            self.analysis_finished_event.set()

    def _compare_callback(self):
        """Signal that a compare has finished."""
        self.compare_finished_event.set()

    def upload_test_firmware(self, test_fw):
        """Upload ``test_fw`` through the web form and check the success page.

        :param test_fw: object providing ``path`` (relative to the test data
            directory), ``file_name``, ``name`` and ``uid``.
        """
        testfile_path = Path(get_test_data_dir()) / test_fw.path
        with open(str(testfile_path), 'rb') as fp:
            data = {
                'file': (fp, test_fw.file_name),
                'device_name': test_fw.name,
                'device_part': 'test_part',
                'device_class': 'test_class',
                'version': '1.0',
                'vendor': 'test_vendor',
                'release_date': '1970-01-01',
                'tags': '',
                'analysis_systems': []
            }
            # The request must be sent while the file handle is still open.
            rv = self.test_client.post('/upload', content_type='multipart/form-data', data=data, follow_redirects=True)
        self.assertIn(b'Upload Successful', rv.data, 'upload not successful')
        self.assertIn(test_fw.uid.encode(), rv.data, 'uid not found on upload success page')
class TestStorageDbInterfaceFrontend(unittest.TestCase):
    """Tests for ``FrontEndDbInterface`` queries.

    Test data is written through a ``BackEndDbInterface`` and read back
    through the frontend interface under test.
    """

    def setUp(self):
        """Start the database and open one frontend and one backend handle."""
        self._config = get_config_for_testing(TMP_DIR)
        self.mongo_server = MongoMgr(config=self._config)
        self.db_frontend_interface = FrontEndDbInterface(config=self._config)
        self.db_backend_interface = BackEndDbInterface(config=self._config)
        self.test_firmware = create_test_firmware()

    def tearDown(self):
        """Drop the main database and shut down all handles and the server."""
        self.db_frontend_interface.shutdown()
        self.db_backend_interface.client.drop_database(
            self._config.get('data_storage', 'main_database'))
        self.db_backend_interface.shutdown()
        self.mongo_server.shutdown()
        TMP_DIR.cleanup()
        gc.collect()

    def test_get_meta_list(self):
        """The meta list entry carries the firmware's human-readable id."""
        self.db_backend_interface.add_firmware(self.test_firmware)
        list_of_firmwares = self.db_frontend_interface.get_meta_list()
        test_output = list_of_firmwares.pop()
        self.assertEqual(test_output[1],
                         'test_vendor test_router - 0.1 (Router)',
                         'Firmware not successfully received')

    def test_get_hid_firmware(self):
        """``get_hid`` of a firmware returns vendor, name, version and class."""
        self.db_backend_interface.add_firmware(self.test_firmware)
        result = self.db_frontend_interface.get_hid(
            self.test_firmware.get_uid())
        self.assertEqual(result, 'test_vendor test_router - 0.1 (Router)',
                         'fw hid not correct')

    def test_get_hid_fo(self):
        """``get_hid`` of a file object returns a path, with or without root."""
        test_fo = create_test_file_object(bin_path='get_files_test/testfile2')
        test_fo.virtual_file_path = {
            'a': ['|a|/test_file'],
            'b': ['|b|/get_files_test/testfile2']
        }
        self.db_backend_interface.add_file_object(test_fo)
        result = self.db_frontend_interface.get_hid(test_fo.get_uid(),
                                                    root_uid='b')
        self.assertEqual(result, '/get_files_test/testfile2',
                         'fo hid not correct')
        result = self.db_frontend_interface.get_hid(test_fo.get_uid())
        self.assertIsInstance(result, str, 'result is not a string')
        self.assertEqual(result[0], '/',
                         'first character not correct if no root_uid set')
        result = self.db_frontend_interface.get_hid(test_fo.get_uid(),
                                                    root_uid='c')
        self.assertEqual(
            result[0], '/',
            'first character not correct if invalid root_uid set')

    def test_get_file_name(self):
        """``get_file_name`` returns the stored file name of the firmware."""
        self.db_backend_interface.add_firmware(self.test_firmware)
        result = self.db_frontend_interface.get_file_name(
            self.test_firmware.get_uid())
        self.assertEqual(result, 'test.zip', 'name not correct')

    def test_get_hid_invalid_uid(self):
        """An unknown uid yields an empty human-readable id."""
        result = self.db_frontend_interface.get_hid('foo')
        self.assertEqual(result, '',
                         'invalid uid should result in empty string')

    def test_get_firmware_attribute_list(self):
        """Attribute list helpers return the values of the stored firmware."""
        self.db_backend_interface.add_firmware(self.test_firmware)
        self.assertEqual(self.db_frontend_interface.get_device_class_list(),
                         ['Router'])
        self.assertEqual(self.db_frontend_interface.get_vendor_list(),
                         ['test_vendor'])
        self.assertEqual(
            self.db_frontend_interface.get_firmware_attribute_list(
                'device_name', {
                    'vendor': 'test_vendor',
                    'device_class': 'Router'
                }), ['test_router'])
        self.assertEqual(
            self.db_frontend_interface.get_firmware_attribute_list('version'),
            ['0.1'])
        self.assertEqual(self.db_frontend_interface.get_device_name_dict(),
                         {'Router': {
                             'test_vendor': ['test_router']
                         }})

    def test_get_data_for_nice_list(self):
        """Nice-list entries expose exactly the expected keys and the uid."""
        uid_list = [self.test_firmware.get_uid()]
        self.db_backend_interface.add_firmware(self.test_firmware)
        nice_list_data = self.db_frontend_interface.get_data_for_nice_list(
            uid_list, uid_list[0])
        # assertEqual instead of the deprecated alias assertEquals, which
        # was removed from unittest in Python 3.12.
        self.assertEqual(
            sorted([
                'size', 'virtual_file_paths', 'uid', 'mime-type',
                'files_included'
            ]), sorted(nice_list_data[0].keys()))
        self.assertEqual(nice_list_data[0]['uid'],
                         self.test_firmware.get_uid())

    def test_generic_search(self):
        """``generic_search`` accepts the query as JSON string and as dict."""
        self.db_backend_interface.add_firmware(self.test_firmware)
        # str input
        result = self.db_frontend_interface.generic_search(
            '{"file_name": "test.zip"}')
        self.assertEqual(result, [self.test_firmware.get_uid()],
                         'Firmware not successfully received')
        # dict input
        result = self.db_frontend_interface.generic_search(
            {'file_name': 'test.zip'})
        self.assertEqual(result, [self.test_firmware.get_uid()],
                         'Firmware not successfully received')

    def test_all_uids_found_in_database(self):
        """The check is false on an empty database and true after insertion."""
        self.db_backend_interface.client.drop_database(
            self._config.get('data_storage', 'main_database'))
        uid_list = [self.test_firmware.get_uid()]
        self.assertFalse(
            self.db_frontend_interface.all_uids_found_in_database(uid_list))
        self.db_backend_interface.add_firmware(self.test_firmware)
        self.assertTrue(
            self.db_frontend_interface.all_uids_found_in_database(
                [self.test_firmware.get_uid()]))

    def test_get_number_of_firmwares_in_db(self):
        """The firmware count goes from 0 to 1 when one firmware is added."""
        self.assertEqual(
            self.db_frontend_interface.get_number_of_firmwares_in_db(), 0)
        self.db_backend_interface.add_firmware(self.test_firmware)
        self.assertEqual(
            self.db_frontend_interface.get_number_of_firmwares_in_db(), 1)

    def test_get_x_last_added_firmwares(self):
        """The most recently added firmwares come first and honour the limit."""
        self.assertEqual(self.db_frontend_interface.get_last_added_firmwares(),
                         [], 'empty db should result in empty list')
        test_fw_one = create_test_firmware(device_name='fw_one')
        self.db_backend_interface.add_firmware(test_fw_one)
        test_fw_two = create_test_firmware(device_name='fw_two',
                                           bin_path='container/test.7z')
        self.db_backend_interface.add_firmware(test_fw_two)
        test_fw_three = create_test_firmware(device_name='fw_three',
                                             bin_path='container/test.cab')
        self.db_backend_interface.add_firmware(test_fw_three)
        result = self.db_frontend_interface.get_last_added_firmwares(limit_x=2)
        self.assertEqual(len(result), 2, 'Number of results should be 2')
        self.assertEqual(result[0]['device_name'], 'fw_three',
                         'last firmware is not first entry')
        self.assertEqual(result[1]['device_name'], 'fw_two',
                         'second last firmware is not the second entry')

    def test_generate_file_tree_node(self):
        """File tree nodes are produced for a firmware and its child file."""
        parent_fw = create_test_firmware()
        child_fo = create_test_file_object()
        child_fo.processed_analysis['file_type'] = {'mime': 'sometype'}
        uid = parent_fw.get_uid()
        child_fo.virtual_file_path = {
            uid: ['|{}|/folder/{}'.format(uid, child_fo.file_name)]
        }
        parent_fw.files_included = {child_fo.get_uid()}
        self.db_backend_interface.add_object(parent_fw)
        self.db_backend_interface.add_object(child_fo)
        # NOTE(review): both loops pass vacuously if no node is generated.
        for node in self.db_frontend_interface.generate_file_tree_node(
                uid, uid):
            self.assertIsInstance(node, FileTreeNode)
            self.assertEqual(node.name, parent_fw.file_name)
            self.assertTrue(node.has_children)
        for node in self.db_frontend_interface.generate_file_tree_node(
                child_fo.get_uid(), uid):
            self.assertIsInstance(node, FileTreeNode)
            self.assertEqual(node.name, 'folder')
            self.assertTrue(node.has_children)
            virtual_grand_child = node.get_list_of_child_nodes()[0]
            self.assertEqual(virtual_grand_child.type, 'sometype')
            self.assertFalse(virtual_grand_child.has_children)
            self.assertEqual(virtual_grand_child.name, child_fo.file_name)

    def test_get_number_of_total_matches(self):
        """Match counting distinguishes all objects from parent firmwares."""
        parent_fw = create_test_firmware()
        child_fo = create_test_file_object()
        uid = parent_fw.get_uid()
        child_fo.virtual_file_path = {
            uid: ['|{}|/folder/{}'.format(uid, child_fo.file_name)]
        }
        self.db_backend_interface.add_object(parent_fw)
        self.db_backend_interface.add_object(child_fo)
        # Doubled braces are literal braces in the formatted JSON query.
        query = '{{"$or": [{{"_id": "{}"}}, {{"_id": "{}"}}]}}'.format(
            uid, child_fo.get_uid())
        self.assertEqual(
            self.db_frontend_interface.get_number_of_total_matches(
                query, only_parent_firmwares=False), 2)
        self.assertEqual(
            self.db_frontend_interface.get_number_of_total_matches(
                query, only_parent_firmwares=True), 1)

    def test_get_other_versions_of_firmware(self):
        """Other versions of a firmware are listed with uid and version."""
        parent_fw1 = create_test_firmware(version='1')
        self.db_backend_interface.add_object(parent_fw1)
        parent_fw2 = create_test_firmware(version='2',
                                          bin_path='container/test.7z')
        self.db_backend_interface.add_object(parent_fw2)
        parent_fw3 = create_test_firmware(version='3',
                                          bin_path='container/test.cab')
        self.db_backend_interface.add_object(parent_fw3)

        other_versions = self.db_frontend_interface.get_other_versions_of_firmware(
            parent_fw1)
        self.assertEqual(len(other_versions), 2,
                         'wrong number of other versions')
        self.assertIn({
            '_id': parent_fw2.get_uid(),
            'version': '2'
        }, other_versions)
        self.assertIn({
            '_id': parent_fw3.get_uid(),
            'version': '3'
        }, other_versions)

        other_versions = self.db_frontend_interface.get_other_versions_of_firmware(
            parent_fw2)
        self.assertIn({
            '_id': parent_fw3.get_uid(),
            'version': '3'
        }, other_versions)
Esempio n. 5
0
class TestAcceptanceCompareFirmwares(TestAcceptanceBase):
    """Acceptance test for uploading two firmwares and comparing them.

    The backend is started with callbacks that signal, via events, when the
    expected number of analyses and the compare have finished.
    """

    def setUp(self):
        """Create events and counter, then start the backend with callbacks."""
        super().setUp()
        self.analysis_finished_event = Event()
        self.compare_finished_event = Event()
        self.elements_finished_analyzing = Value('i', 0)
        self.db_backend_service = BackEndDbInterface(config=self.config)
        self._start_backend(post_analysis=self._analysis_callback,
                            compare_callback=self._compare_callback)
        time.sleep(2)  # wait for systems to start

    def tearDown(self):
        """Stop the backend and release the database handle."""
        self._stop_backend()
        self.db_backend_service.shutdown()
        super().tearDown()

    def _analysis_callback(self, fo):
        """Store an analysed object and count it towards the finished event.

        :param fo: the analysed object passed in by the backend.
        """
        self.db_backend_service.add_object(fo)
        # ``value += 1`` is a separate read and write; hold the Value's lock
        # so concurrent callbacks cannot lose an increment or skip the
        # exact-match check below.
        with self.elements_finished_analyzing.get_lock():
            self.elements_finished_analyzing.value += 1
            finished = self.elements_finished_analyzing.value
        if finished == 4 * 2 * 2:  # two firmware container with 3 included files each times two plugins
            self.analysis_finished_event.set()

    def _compare_callback(self):
        """Signal that the compare has finished."""
        self.compare_finished_event.set()

    def _upload_firmware_get(self):
        """The upload page must be reachable."""
        rv = self.test_client.get('/upload')
        self.assertIn(b'<h2>Upload Firmware</h2>', rv.data,
                      'upload page not displayed correctly')

    def _upload_firmware_put(self, path, device_name, uid):
        """Upload one firmware through the web form and check the result page.

        :param path: file path relative to the test data directory.
        :param device_name: value sent in the ``device_name`` form field.
        :param uid: uid expected to appear on the upload success page.
        """
        testfile_path = os.path.join(get_test_data_dir(), path)
        with open(testfile_path, 'rb') as fp:
            data = {
                'file': fp,
                'device_name': device_name,
                'device_part': 'full',
                'device_class': 'test_class',
                'version': '1.0',
                'vendor': 'test_vendor',
                'release_date': '01.01.1970',
                'tags': '',
                'analysis_systems': []
            }
            # The request must be sent while the file handle is still open.
            rv = self.test_client.post('/upload',
                                       content_type='multipart/form-data',
                                       data=data,
                                       follow_redirects=True)
        self.assertIn(b'Upload Successful', rv.data, 'upload not successful')
        self.assertIn(uid.encode(), rv.data,
                      'uid not found on upload success page')

    def _add_firmwares_to_compare(self):
        """Put both test firmwares on the comparison list via the web UI."""
        rv = self.test_client.get('/analysis/{}'.format(self.test_fw_a.uid))
        self.assertIn(self.test_fw_a.uid, rv.data.decode(), '')
        rv = self.test_client.get('/comparison/add/{}'.format(
            self.test_fw_a.uid),
                                  follow_redirects=True)
        self.assertIn('Firmwares Selected for Comparison', rv.data.decode())

        rv = self.test_client.get('/analysis/{}'.format(self.test_fw_c.uid))
        self.assertIn(self.test_fw_c.uid, rv.data.decode())
        self.assertIn(self.test_fw_c.name, rv.data.decode())
        rv = self.test_client.get('/comparison/add/{}'.format(
            self.test_fw_c.uid),
                                  follow_redirects=True)
        self.assertIn('Remove All', rv.data.decode())

    def _start_compare(self):
        """Trigger the compare and expect the in-progress page."""
        rv = self.test_client.get('/compare', follow_redirects=True)
        self.assertIn(b'Your compare task is in progress.', rv.data,
                      'compare wait page not displayed correctly')

    def _show_comparison_results(self):
        """The compare result page must name both firmwares."""
        rv = self.test_client.get('/compare/{};{}'.format(
            self.test_fw_a.uid, self.test_fw_c.uid))
        self.assertIn(self.test_fw_a.name.encode(), rv.data,
                      'test firmware a comparison not displayed correctly')
        self.assertIn(self.test_fw_c.name.encode(), rv.data,
                      'test firmware b comparison not displayed correctly')
        self.assertIn(b'File Coverage', rv.data,
                      'comparison page not displayed correctly')

    def _show_home_page(self):
        """The home page must list the latest comparisons."""
        rv = self.test_client.get('/')
        self.assertIn(b'Latest Comparisons', rv.data,
                      'latest comparisons not displayed on "home"')

    def _show_compare_browse(self):
        """The compare browse page must show the finished comparison."""
        rv = self.test_client.get('/database/browse_compare')
        self.assertIn(self.test_fw_a.name.encode(), rv.data,
                      'no compare result shown in browse')

    def _show_analysis_without_compare_list(self):
        """Before any compare the analysis page offers no comparison list."""
        rv = self.test_client.get('/analysis/{}'.format(self.test_fw_a.uid))
        assert b'Show List of Known Comparisons' not in rv.data

    def _show_analysis_with_compare_list(self):
        """After the compare the analysis page offers the comparison list."""
        rv = self.test_client.get('/analysis/{}'.format(self.test_fw_a.uid))
        assert b'Show List of Known Comparisons' in rv.data

    def test_compare_firmwares(self):
        """Upload two firmwares, compare them and check all result pages."""
        self._upload_firmware_get()
        for fw in [self.test_fw_a, self.test_fw_c]:
            self._upload_firmware_put(fw.path, fw.name, fw.uid)
        # NOTE(review): the wait() results are not checked, so a timeout only
        # shows up in the later page assertions.
        self.analysis_finished_event.wait(timeout=20)
        self._show_analysis_without_compare_list()
        self._add_firmwares_to_compare()
        self._start_compare()
        self.compare_finished_event.wait(timeout=20)
        self._show_comparison_results()
        self._show_home_page()
        self._show_compare_browse()
        self._show_analysis_with_compare_list()
Esempio n. 6
0
class TestSummary(unittest.TestCase):
    """Tests for the included-file and summary helpers of ``MongoInterfaceCommon``."""

    def setUp(self):
        """Start the database and open a common and a backend handle."""
        config = get_config_for_testing(TMP_DIR)
        self._config = config
        self.mongo_server = MongoMgr(config=config)
        self.db_interface = MongoInterfaceCommon(config=config)
        self.db_interface_backend = BackEndDbInterface(config=config)

    def tearDown(self):
        """Drop the main database, then shut down handles and server."""
        self.db_interface.client.drop_database(
            self._config.get('data_storage', 'main_database'))
        for component in (self.db_interface, self.db_interface_backend, self.mongo_server):
            component.shutdown()
        TMP_DIR.cleanup()

    def create_and_add_test_fimrware_and_file_object(self):
        """Store a test firmware that includes one test file object."""
        firmware = create_test_firmware()
        file_object = create_test_file_object()
        self.test_fw = firmware
        self.test_fo = file_object
        firmware.add_included_file(file_object)
        for stored_object in (firmware, file_object):
            self.db_interface_backend.add_object(stored_object)

    def test_get_set_of_all_included_files(self):
        self.create_and_add_test_fimrware_and_file_object()

        # A plain file object only contains itself.
        files_of_fo = self.db_interface.get_set_of_all_included_files(
            self.test_fo)
        self.assertIsInstance(files_of_fo, set, 'result is not a set')
        self.assertEqual(len(files_of_fo), 1, 'number of files not correct')
        self.assertIn(self.test_fo.uid, files_of_fo,
                      'object not in its own result set')

        # The firmware contains itself and the included file.
        files_of_fw = self.db_interface.get_set_of_all_included_files(
            self.test_fw)
        self.assertEqual(len(files_of_fw), 2, 'number of files not correct')
        self.assertIn(self.test_fo.uid, files_of_fw,
                      'test file not in result set firmware')
        self.assertIn(self.test_fw.uid, files_of_fw,
                      'fw not in result set firmware')

    def test_get_uids_of_all_included_files(self):
        def store_file_object(uid, parent_uids: Set[str]):
            file_object = create_test_file_object()
            file_object.parent_firmware_uids = parent_uids
            file_object.uid = uid
            self.db_interface_backend.add_object(file_object)

        fixtures = (
            ('uid1', {'foo'}),
            ('uid2', {'foo', 'bar'}),
            ('uid3', {'bar'}),
        )
        for uid, parents in fixtures:
            store_file_object(uid, parents)

        uids_below_foo = self.db_interface.get_uids_of_all_included_files('foo')
        assert uids_below_foo == {'uid1', 'uid2'}

        uids_below_unknown = self.db_interface.get_uids_of_all_included_files(
            'uid not in db')
        assert uids_below_unknown == set()

    def test_get_summary(self):
        self.create_and_add_test_fimrware_and_file_object()
        summary = self.db_interface.get_summary(self.test_fw, 'dummy')
        self.assertIsInstance(summary, dict, 'summary is not a dict')

        self.assertIn('sum a', summary, 'summary entry of parent missing')
        shared_origins = summary['sum a']
        self.assertIn(self.test_fw.uid, shared_origins,
                      'origin (parent) missing in parent summary entry')
        self.assertIn(self.test_fo.uid, shared_origins,
                      'origin (child) missing in parent summary entry')
        self.assertNotIn(self.test_fo.uid, summary['fw exclusive sum a'],
                         'child as origin but should not be')

        self.assertIn('file exclusive sum b', summary,
                      'file exclusive summary missing')
        file_only_origins = summary['file exclusive sum b']
        self.assertIn(self.test_fo.uid, file_only_origins,
                      'origin of file exclusive missing')
        self.assertNotIn(self.test_fw.uid, file_only_origins,
                         'parent as origin but should not be')

    def test_collect_summary(self):
        self.create_and_add_test_fimrware_and_file_object()
        collected = self.db_interface._collect_summary(
            [self.test_fo.uid], 'dummy')
        expected_items = self.test_fo.processed_analysis['dummy']['summary']
        for item in expected_items:
            assert item in collected
        for origins in collected.values():
            assert origins == [self.test_fo.uid]

    def test_get_summary_of_one_error_handling(self):
        summary_of_none = self.db_interface._get_summary_of_one(None, 'foo')
        self.assertEqual(summary_of_none, {},
                         'None object should result in empty dict')

        self.create_and_add_test_fimrware_and_file_object()
        summary_of_missing = self.db_interface._get_summary_of_one(
            self.test_fw, 'none_existing_analysis')
        self.assertEqual(summary_of_missing, {},
                         'analysis not existend should lead to empty dict')

    def test_update_summary(self):
        merged = self.db_interface._update_summary(
            {'a': ['a']}, {'a': ['aa'], 'b': ['aa']})
        for key in ('a', 'b'):
            self.assertIn(key, merged)
        for key, origin in (('a', 'a'), ('a', 'aa'), ('b', 'aa')):
            self.assertIn(origin, merged[key])
Esempio n. 7
0
class TestMongoInterface(unittest.TestCase):
    """Tests for ``MongoInterfaceCommon``; data is written via ``BackEndDbInterface``.

    A single ``MongoMgr`` is created for the whole class. After every test
    the main and the sanitize database are dropped again.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared test config and create the ``MongoMgr``."""
        cls._config = get_config_for_testing(TMP_DIR)
        cls._config.set('data_storage', 'report_threshold', '32')
        cls._config.set('data_storage', 'sanitize_database', 'tmp_sanitize')
        cls.mongo_server = MongoMgr(config=cls._config)

    def setUp(self):
        """Open both DB interfaces and create the per-test fixtures."""
        self.db_interface = MongoInterfaceCommon(config=self._config)
        self.db_interface_backend = BackEndDbInterface(config=self._config)

        self.test_firmware = create_test_firmware()

        self.test_yara_match = {
            'rule': 'OpenSSH',
            'tags': [],
            'namespace': 'default',
            'strings': [(0, '$a', b'OpenSSH')],
            'meta': {
                'description': 'SSH library',
                'website': 'http://www.openssh.com',
                'open_source': True,
                'software_name': 'OpenSSH'
            },
            'matches': True
        }

        self.test_fo = create_test_file_object()

    def tearDown(self):
        """Drop the main and sanitize databases and shut both interfaces down."""
        self.db_interface_backend.client.drop_database(
            self._config.get('data_storage', 'main_database'))
        self.db_interface_backend.shutdown()
        self.db_interface.client.drop_database(
            self._config.get('data_storage', 'sanitize_database'))
        self.db_interface.shutdown()
        gc.collect()

    @classmethod
    def tearDownClass(cls):
        """Shut the ``MongoMgr`` down and remove the temporary directory."""
        cls.mongo_server.shutdown()
        TMP_DIR.cleanup()

    def _get_all_firmware_uids(self):
        """Return the ``_id`` of every entry in the firmware collection."""
        return [entry['_id'] for entry in self.db_interface.firmwares.find()]

    def _count_sanitized_files(self, file_name):
        """Return how many ``sanitize_fs`` entries are stored as *file_name*."""
        return self.db_interface.sanitize_fs.find({
            'filename': file_name
        }).count()

    def test_existence_quick_check(self):
        """Objects are only reported as existing after they were added."""
        self.assertFalse(
            self.db_interface.existence_quick_check('none_existing'),
            'none existing firmware found')
        self.db_interface_backend.add_firmware(self.test_firmware)
        self.assertTrue(
            self.db_interface.existence_quick_check(self.test_firmware.uid),
            'existing firmware not found')
        self.db_interface_backend.add_file_object(self.test_fo)
        self.assertTrue(
            self.db_interface.existence_quick_check(self.test_fo.uid),
            'existing file not found')

    def test_get_firmware(self):
        """A stored firmware is returned with its meta data intact."""
        self.db_interface_backend.add_firmware(self.test_firmware)
        fobject = self.db_interface.get_firmware(self.test_firmware.uid)
        self.assertEqual(fobject.vendor, 'test_vendor')
        self.assertEqual(fobject.device_name, 'test_router')
        self.assertEqual(fobject.part, '')

    def test_get_object(self):
        """``get_object`` returns ``None``, a ``Firmware`` or a ``FileObject``."""
        fo = self.db_interface.get_object(self.test_firmware.uid)
        self.assertIsNone(
            fo, 'found something but there is nothing in the database')
        self.db_interface_backend.add_firmware(self.test_firmware)
        fo = self.db_interface.get_object(self.test_firmware.uid)
        self.assertIsInstance(fo, Firmware, 'firmware has wrong type')
        self.assertEqual(fo.device_name, 'test_router',
                         'Device name in Firmware not correct')
        test_file = FileObject(file_path=path.join(get_test_data_dir(),
                                                   'get_files_test/testfile2'))
        self.db_interface_backend.add_file_object(test_file)
        fo = self.db_interface.get_object(test_file.uid)
        self.assertIsInstance(fo, FileObject, 'file object has wrong type')

    def test_get_complete_object_including_all_summaries(self):
        """The complete firmware object carries its own and its file's summary."""
        self.db_interface_backend.report_threshold = 1024
        test_file = create_test_file_object()
        self.test_firmware.add_included_file(test_file)
        self.db_interface_backend.add_firmware(self.test_firmware)
        self.db_interface_backend.add_file_object(test_file)
        tmp = self.db_interface.get_complete_object_including_all_summaries(
            self.test_firmware.uid)
        self.assertIsInstance(tmp, Firmware, 'wrong type')
        dummy_analysis = tmp.processed_analysis['dummy']
        self.assertIn('summary', dummy_analysis,
                      'summary not found in processed analysis')
        self.assertIn('sum a', dummy_analysis['summary'],
                      'summary of original file not included')
        self.assertIn('file exclusive sum b', dummy_analysis['summary'],
                      'summary of included file not found')

    def test_sanitize_analysis(self):
        """Check ``file_system_flag`` and ``sanitize_fs`` for a short and a long result."""
        short_dict = {'stub_plugin': {'result': 0}}
        long_dict = {
            'stub_plugin': {
                'result': 10000000000,
                'misc': 'Bananarama',
                'summary': []
            }
        }

        self.test_firmware.processed_analysis = short_dict
        sanitized_dict = self.db_interface.sanitize_analysis(
            self.test_firmware.processed_analysis, self.test_firmware.uid)
        self.assertIn('file_system_flag', sanitized_dict['stub_plugin'])
        self.assertFalse(sanitized_dict['stub_plugin']['file_system_flag'])
        self.assertEqual(self.db_interface.sanitize_fs.list(), [],
                         'file stored in db but should not')

        self.test_firmware.processed_analysis = long_dict
        sanitized_dict = self.db_interface.sanitize_analysis(
            self.test_firmware.processed_analysis, self.test_firmware.uid)
        self.assertIn('stub_plugin_result_{}'.format(self.test_firmware.uid),
                      self.db_interface.sanitize_fs.list(),
                      'sanitized file not stored')
        self.assertNotIn('summary_result_{}'.format(self.test_firmware.uid),
                         self.db_interface.sanitize_fs.list(),
                         'summary is erroneously stored')
        self.assertIn('file_system_flag', sanitized_dict['stub_plugin'])
        self.assertTrue(sanitized_dict['stub_plugin']['file_system_flag'])
        # isinstance check instead of comparing type objects directly
        self.assertIsInstance(sanitized_dict['stub_plugin']['summary'], list)

    def test_sanitize_db_duplicates(self):
        """Sanitizing the same result repeatedly must not create duplicate files."""
        long_dict = {
            'stub_plugin': {
                'result': 10000000000,
                'misc': 'Bananarama',
                'summary': []
            }
        }
        gridfs_file_name = 'stub_plugin_result_{}'.format(
            self.test_firmware.uid)

        self.test_firmware.processed_analysis = long_dict
        assert self._count_sanitized_files(gridfs_file_name) == 0
        self.db_interface.sanitize_analysis(
            self.test_firmware.processed_analysis, self.test_firmware.uid)
        assert self._count_sanitized_files(gridfs_file_name) == 1
        self.db_interface.sanitize_analysis(
            self.test_firmware.processed_analysis, self.test_firmware.uid)
        assert self._count_sanitized_files(
            gridfs_file_name) == 1, 'duplicate entry was created'
        md5 = self.db_interface.sanitize_fs.find_one({
            'filename': gridfs_file_name
        }).md5

        long_dict['stub_plugin']['result'] += 1  # new analysis result
        self.db_interface.sanitize_analysis(
            self.test_firmware.processed_analysis, self.test_firmware.uid)
        assert self._count_sanitized_files(
            gridfs_file_name) == 1, 'duplicate entry was created'
        assert self.db_interface.sanitize_fs.find_one({
            'filename': gridfs_file_name
        }).md5 != md5, 'hash of new file did not change'

    def test_retrieve_analysis(self):
        """Flagged results are loaded from ``sanitize_fs``; flags are removed."""
        self.db_interface.sanitize_fs.put(pickle.dumps('This is a test!'),
                                          filename='test_file_path')

        sanitized_dict = {
            'stub_plugin': {
                'result': 'test_file_path',
                'file_system_flag': True
            },
            'inbound_result': {
                'result': 'inbound result',
                'file_system_flag': False
            }
        }
        retrieved_dict = self.db_interface.retrieve_analysis(sanitized_dict)

        self.assertNotIn('file_system_flag', retrieved_dict['stub_plugin'])
        self.assertIn('result', retrieved_dict['stub_plugin'])
        self.assertEqual(retrieved_dict['stub_plugin']['result'],
                         'This is a test!')
        self.assertNotIn('file_system_flag', retrieved_dict['inbound_result'])
        self.assertEqual(retrieved_dict['inbound_result']['result'],
                         'inbound result')

    def test_retrieve_analysis_filter(self):
        """Only plugins named in ``analysis_filter`` are retrieved."""
        self.db_interface.sanitize_fs.put(pickle.dumps('This is a test!'),
                                          filename='test_file_path')
        sanitized_dict = {
            'selected_plugin': {
                'result': 'test_file_path',
                'file_system_flag': True
            },
            'other_plugin': {
                'result': 'test_file_path',
                'file_system_flag': True
            }
        }
        retrieved_dict = self.db_interface.retrieve_analysis(
            sanitized_dict, analysis_filter=['selected_plugin'])
        self.assertEqual(retrieved_dict['selected_plugin']['result'],
                         'This is a test!')
        self.assertIn('file_system_flag', retrieved_dict['other_plugin'])

    def test_get_objects_by_uid_list(self):
        """Objects fetched by uid list keep their concrete type."""
        self.db_interface_backend.add_firmware(self.test_firmware)
        fo_list = self.db_interface.get_objects_by_uid_list(
            [self.test_firmware.uid])
        self.assertIsInstance(fo_list[0], Firmware, 'firmware has wrong type')
        self.assertEqual(fo_list[0].device_name, 'test_router',
                         'Device name in Firmware not correct')
        test_file = FileObject(file_path=path.join(get_test_data_dir(),
                                                   'get_files_test/testfile2'))
        self.db_interface_backend.add_file_object(test_file)
        fo_list = self.db_interface.get_objects_by_uid_list([test_file.uid])
        self.assertIsInstance(fo_list[0], FileObject,
                              'file object has wrong type')

    def test_sanitize_extract_and_retrieve_binary(self):
        """A value moved out by ``_extract_binaries`` is restored by ``_retrieve_binaries``."""
        test_data = {'dummy': {'test_key': 'test_value'}}
        test_data['dummy'] = self.db_interface._extract_binaries(
            test_data, 'dummy', 'uid')
        self.assertEqual(self.db_interface.sanitize_fs.list(),
                         ['dummy_test_key_uid'], 'file not written')
        self.assertEqual(test_data['dummy']['test_key'], 'dummy_test_key_uid',
                         'new file path not set')
        test_data['dummy'] = self.db_interface._retrieve_binaries(
            test_data, 'dummy')
        self.assertEqual(test_data['dummy']['test_key'], 'test_value',
                         'value not recoverd')

    def test_get_firmware_number(self):
        """Firmware counts follow the stored entries for dict and string queries."""
        result = self.db_interface.get_firmware_number()
        self.assertEqual(result, 0)

        self.db_interface_backend.add_firmware(self.test_firmware)
        result = self.db_interface.get_firmware_number(query={})
        self.assertEqual(result, 1)
        result = self.db_interface.get_firmware_number(
            query={'_id': self.test_firmware.uid})
        self.assertEqual(result, 1)

        test_fw_2 = create_test_firmware(bin_path='container/test.7z')
        self.db_interface_backend.add_firmware(test_fw_2)
        result = self.db_interface.get_firmware_number(query='{}')
        self.assertEqual(result, 2)
        result = self.db_interface.get_firmware_number(
            query={'_id': self.test_firmware.uid})
        self.assertEqual(result, 1)

    def test_get_file_object_number(self):
        """File object counts, including the ``zero_on_empty_query`` switch."""
        result = self.db_interface.get_file_object_number()
        self.assertEqual(result, 0)

        self.db_interface_backend.add_file_object(self.test_fo)
        result = self.db_interface.get_file_object_number(
            query={}, zero_on_empty_query=False)
        self.assertEqual(result, 1)
        result = self.db_interface.get_file_object_number(
            query={'_id': self.test_fo.uid})
        self.assertEqual(result, 1)
        result = self.db_interface.get_file_object_number(
            query=json.dumps({'_id': self.test_fo.uid}))
        self.assertEqual(result, 1)
        result = self.db_interface.get_file_object_number(
            query={}, zero_on_empty_query=True)
        self.assertEqual(result, 0)
        result = self.db_interface.get_file_object_number(
            query='{}', zero_on_empty_query=True)
        self.assertEqual(result, 0)

        test_fo_2 = create_test_file_object(
            bin_path='get_files_test/testfile2')
        self.db_interface_backend.add_file_object(test_fo_2)
        result = self.db_interface.get_file_object_number(
            query={}, zero_on_empty_query=False)
        self.assertEqual(result, 2)
        result = self.db_interface.get_file_object_number(
            query={'_id': self.test_fo.uid})
        self.assertEqual(result, 1)

    def test_unpacking_lock(self):
        """Locks can be set, checked, released one by one and dropped together."""
        first_uid, second_uid = 'id1', 'id2'
        is_locked = self.db_interface.check_unpacking_lock

        # one assert per lock so a failure points at the offending uid
        assert not is_locked(first_uid), 'locks should not be set at start'
        assert not is_locked(second_uid), 'locks should not be set at start'

        self.db_interface.set_unpacking_lock(first_uid)
        assert is_locked(first_uid), 'locks should have been set'

        self.db_interface.set_unpacking_lock(second_uid)
        assert is_locked(first_uid), 'both locks should be set'
        assert is_locked(second_uid), 'both locks should be set'

        self.db_interface.release_unpacking_lock(first_uid)
        assert not is_locked(
            first_uid), 'lock 1 should be released, lock 2 not'
        assert is_locked(second_uid), 'lock 1 should be released, lock 2 not'

        self.db_interface.drop_unpacking_locks()
        assert not is_locked(second_uid), 'all locks should be dropped'

    def test_lock_is_released(self):
        """Adding an object through the backend releases its unpacking lock."""
        self.db_interface.set_unpacking_lock(self.test_fo.uid)
        assert self.db_interface.check_unpacking_lock(
            self.test_fo.uid), 'setting lock did not work'

        self.db_interface_backend.add_object(self.test_fo)
        assert not self.db_interface.check_unpacking_lock(
            self.test_fo.uid), 'add_object should release lock'

    def test_is_firmware(self):
        """``is_firmware`` is ``False`` before and ``True`` after adding the firmware."""
        assert self.db_interface.is_firmware(self.test_firmware.uid) is False

        self.db_interface_backend.add_firmware(self.test_firmware)
        assert self.db_interface.is_firmware(self.test_firmware.uid) is True

    def test_is_file_object(self):
        """``is_file_object`` is ``False`` before and ``True`` after adding the file."""
        assert self.db_interface.is_file_object(self.test_fo.uid) is False

        self.db_interface_backend.add_file_object(self.test_fo)
        assert self.db_interface.is_file_object(self.test_fo.uid) is True
Esempio n. 8
0
class TestStorageDbInterfaceFrontendEditing(unittest.TestCase):
    """Tests for adding, listing and deleting comments via the frontend DB interfaces."""

    def setUp(self):
        """Create the config, the ``MongoMgr`` and all three DB interfaces."""
        self._config = get_config_for_testing(TMP_DIR)
        self.mongo_server = MongoMgr(config=self._config)
        self.db_frontend_editing = FrontendEditingDbInterface(
            config=self._config)
        self.db_frontend_interface = FrontEndDbInterface(config=self._config)
        self.db_backend_interface = BackEndDbInterface(config=self._config)

    def tearDown(self):
        """Shut the interfaces down, drop the main database and stop the ``MongoMgr``."""
        self.db_frontend_editing.shutdown()
        self.db_frontend_interface.shutdown()
        self.db_backend_interface.client.drop_database(
            self._config.get('data_storage', 'main_database'))
        self.db_backend_interface.shutdown()
        self.mongo_server.shutdown()
        # NOTE(review): TMP_DIR is not created in this class, yet it is cleaned
        # up after every single test and handed to get_config_for_testing again
        # in the next setUp -- confirm this is intended.
        TMP_DIR.cleanup()
        gc.collect()

    @staticmethod
    def _create_test_comments():
        """Return a fresh list with the two comment dicts used as test fixture."""
        return [{
            'time': '1234567890',
            'author': 'author1',
            'comment': 'test comment'
        }, {
            'time': '1234567899',
            'author': 'author2',
            'comment': 'test comment2'
        }]

    def test_add_comment(self):
        """A comment added via the editing interface is stored on the object."""
        test_fw = create_test_firmware()
        self.db_backend_interface.add_object(test_fw)
        # named 'timestamp' so the local does not shadow the 'time' import
        comment, author, timestamp = 'this is a test comment!', 'author', 1234567890
        uid = test_fw.get_uid()
        self.db_frontend_editing.add_comment_to_object(uid, comment, author,
                                                       timestamp)
        test_fw = self.db_backend_interface.get_object(uid)
        self.assertEqual(test_fw.comments[0], {
            'time': str(timestamp),
            'author': author,
            'comment': comment
        })

    def test_get_latest_comments(self):
        """Latest comments come back newest first and carry the object's uid."""
        comments = self._create_test_comments()
        test_fw = self._add_test_fw_with_comments_to_db()
        latest_comments = self.db_frontend_interface.get_latest_comments()
        comments.sort(key=lambda x: x['time'], reverse=True)
        uid = test_fw.get_uid()
        for index, expected in enumerate(comments):
            actual = latest_comments[index]
            self.assertEqual(actual['time'], expected['time'])
            self.assertEqual(actual['author'], expected['author'])
            self.assertEqual(actual['comment'], expected['comment'])
            self.assertEqual(actual['uid'], uid)

    def test_remove_element_from_array_in_field(self):
        """Removing a matching element from ``comments`` leaves one comment."""
        test_fw = self._add_test_fw_with_comments_to_db()
        retrieved_fw = self.db_backend_interface.get_object(test_fw.get_uid())
        self.assertEqual(len(retrieved_fw.comments), 2,
                         'comments were not saved correctly')

        self.db_frontend_editing.remove_element_from_array_in_field(
            test_fw.get_uid(), 'comments', {'time': '1234567899'})
        retrieved_fw = self.db_backend_interface.get_object(test_fw.get_uid())
        self.assertEqual(len(retrieved_fw.comments), 1,
                         'comment was not deleted')

    def test_delete_comment(self):
        """Deleting a comment by its time stamp leaves one comment."""
        test_fw = self._add_test_fw_with_comments_to_db()
        retrieved_fw = self.db_backend_interface.get_object(test_fw.get_uid())
        self.assertEqual(len(retrieved_fw.comments), 2,
                         'comments were not saved correctly')

        self.db_frontend_editing.delete_comment(test_fw.get_uid(),
                                                '1234567899')
        retrieved_fw = self.db_backend_interface.get_object(test_fw.get_uid())
        self.assertEqual(len(retrieved_fw.comments), 1,
                         'comment was not deleted')

    def _add_test_fw_with_comments_to_db(self):
        """Store a test firmware carrying the two fixture comments and return it."""
        test_fw = create_test_firmware()
        test_fw.comments.extend(self._create_test_comments())
        self.db_backend_interface.add_object(test_fw)
        return test_fw
Esempio n. 9
0
class TestStorageDbInterfaceBackend(unittest.TestCase):
    """Tests for writing firmware, file objects, tags and analyses via ``BackEndDbInterface``."""

    @classmethod
    def setUpClass(cls):
        """Build the shared test config and create the ``MongoMgr``."""
        config = get_config_for_testing(TMP_DIR)
        cls._config = config
        cls.mongo_server = MongoMgr(config=config)

    def setUp(self):
        """Open both DB interfaces and create the per-test fixtures."""
        self.db_interface = MongoInterfaceCommon(config=self._config)
        self.db_interface_backend = BackEndDbInterface(config=self._config)

        self.test_firmware = create_test_firmware()

        yara_meta = {
            'description': 'SSH library',
            'website': 'http://www.openssh.com',
            'open_source': True,
            'software_name': 'OpenSSH',
        }
        self.test_yara_match = {
            'rule': 'OpenSSH',
            'tags': [],
            'namespace': 'default',
            'strings': [(0, '$a', b'OpenSSH')],
            'meta': yara_meta,
            'matches': True,
        }

        self.test_fo = create_test_file_object()

    def tearDown(self):
        """Drop the main database and shut both interfaces down."""
        main_database = self._config.get('data_storage', 'main_database')
        self.db_interface.client.drop_database(main_database)
        self.db_interface_backend.shutdown()
        self.db_interface.shutdown()
        gc.collect()

    @classmethod
    def tearDownClass(cls):
        """Shut the ``MongoMgr`` down and remove the temporary directory."""
        cls.mongo_server.shutdown()
        TMP_DIR.cleanup()

    def _get_all_firmware_uids(self):
        """Return the ``_id`` of every entry in the firmware collection."""
        return [entry['_id'] for entry in self.db_interface.firmwares.find()]

    def test_add_firmware(self):
        """Adding a firmware creates an entry with a current submission date."""
        self.db_interface_backend.add_firmware(self.test_firmware)
        stored_uids = self._get_all_firmware_uids()
        self.assertGreater(len(stored_uids), 0, 'No entry added to DB')

        stored_entry = self.db_interface_backend.firmwares.find_one()
        self.assertAlmostEqual(
            stored_entry['submission_date'],
            time(),
            msg='submission time not set correctly',
            delta=5.0,
        )

    def test_add_and_get_firmware(self):
        """The backend result carries the binary, the common result does not."""
        uid = self.test_firmware.uid
        self.db_interface_backend.add_firmware(self.test_firmware)

        from_backend = self.db_interface_backend.get_firmware(uid)
        self.assertIsNotNone(from_backend.binary,
                             'binary not set in backend result')

        from_common = self.db_interface.get_firmware(uid)
        self.assertIsNone(from_common.binary, 'binary set in common result')
        self.assertEqual(from_common.size, 787,
                         'file size not correct in common')
        self.assertIsInstance(from_common.tags, dict,
                              'tag field type not correct')

    def test_add_and_get_file_object(self):
        """Same binary/size checks as for firmware, but for a file object."""
        uid = self.test_fo.uid
        self.db_interface_backend.add_file_object(self.test_fo)

        from_backend = self.db_interface_backend.get_file_object(uid)
        self.assertIsNotNone(from_backend.binary,
                             'binary not set in backend result')

        from_common = self.db_interface.get_file_object(uid)
        self.assertIsNone(from_common.binary, 'binary set in common result')
        self.assertEqual(from_common.size, 62,
                         'file size not correct in common')

    def test_update_firmware(self):
        """Re-adding a firmware updates given plugins and keeps the others."""
        uid = self.test_firmware.uid

        def stored_analysis():
            # fetch the object anew on every call to see the current DB state
            return self.db_interface.get_object(uid).processed_analysis

        self.test_firmware.processed_analysis = {
            'stub_plugin': {'result': 0},
            'other_plugin': {'field': 'day'},
        }
        self.db_interface_backend.add_firmware(self.test_firmware)
        self.assertEqual(0, stored_analysis()['stub_plugin']['result'])

        self.test_firmware.processed_analysis = {'stub_plugin': {'result': 1}}
        self.db_interface_backend.add_firmware(self.test_firmware)
        self.assertEqual(1, stored_analysis()['stub_plugin']['result'])
        self.assertIn('other_plugin', stored_analysis().keys())

    def test_update_file_object(self):
        """Re-adding a file object merges analyses and included files."""
        versions = (
            ({'other_plugin': {'result': 0}}, {'file a', 'file b'}),
            ({'stub_plugin': {'result': 1}}, {'file b', 'file c'}),
        )
        for analysis, included_files in versions:
            self.test_fo.processed_analysis = analysis
            self.test_fo.files_included = included_files
            self.db_interface_backend.add_file_object(self.test_fo)

        stored = self.db_interface.get_object(self.test_fo.uid)
        self.assertEqual(
            0, stored.processed_analysis['other_plugin']['result'])
        self.assertEqual(
            1, stored.processed_analysis['stub_plugin']['result'])
        self.assertEqual(3, len(stored.files_included))

    def test_add_and_get_object_including_comment(self):
        """A comment attached before adding is returned with the object."""
        comment = 'this is a test comment!'
        author = 'author'
        date = '1473431685'
        uid = self.test_fo.uid
        new_comment = {'time': str(date), 'author': author, 'comment': comment}
        self.test_fo.comments.append(new_comment)
        self.db_interface_backend.add_file_object(self.test_fo)

        stored_comment = self.db_interface.get_object(uid).comments[0]
        self.assertEqual(author, stored_comment['author'])
        self.assertEqual(comment, stored_comment['comment'])
        self.assertEqual(date, stored_comment['time'])

    def test_update_analysis_tag_no_firmware(self):
        """Updating tags of a plain file object leaves its tags empty."""
        self.db_interface_backend.add_file_object(self.test_fo)
        new_tag = {'value': 'yay', 'color': 'default', 'propagate': True}

        self.db_interface_backend.update_analysis_tags(
            self.test_fo.uid,
            plugin_name='dummy',
            tag_name='some_tag',
            tag=new_tag,
        )
        stored_fo = self.db_interface_backend.get_object(self.test_fo.uid)

        assert not stored_fo.analysis_tags

    def test_update_analysis_tag_uid_not_found(self):
        """Updating tags of an unknown uid does not create an object."""
        self.db_interface_backend.update_analysis_tags(
            self.test_fo.uid,
            plugin_name='dummy',
            tag_name='some_tag',
            tag='should not matter',
        )
        stored_fo = self.db_interface_backend.get_object(self.test_fo.uid)
        assert not stored_fo

    def test_update_analysis_tag_bad_tag(self):
        """A tag that is a plain string is not stored on the firmware."""
        self.db_interface_backend.add_firmware(self.test_firmware)

        self.db_interface_backend.update_analysis_tags(
            self.test_firmware.uid,
            plugin_name='dummy',
            tag_name='some_tag',
            tag='bad_tag',
        )
        stored_fw = self.db_interface_backend.get_object(
            self.test_firmware.uid)

        assert not stored_fw.analysis_tags

    def test_update_analysis_tag_success(self):
        """A tag dict is stored under its plugin and tag name."""
        self.db_interface_backend.add_firmware(self.test_firmware)
        new_tag = {'value': 'yay', 'color': 'default', 'propagate': True}

        self.db_interface_backend.update_analysis_tags(
            self.test_firmware.uid,
            plugin_name='dummy',
            tag_name='some_tag',
            tag=new_tag,
        )
        stored_fw = self.db_interface_backend.get_object(
            self.test_firmware.uid)

        assert stored_fw.analysis_tags
        assert stored_fw.analysis_tags['dummy']['some_tag'] == new_tag

    def test_add_analysis_firmware(self):
        """``add_analysis`` adds a new plugin result to a stored firmware."""
        uid = self.test_firmware.uid
        self.db_interface_backend.add_object(self.test_firmware)
        analysis_before = self.db_interface_backend.get_object(
            uid).processed_analysis

        self.test_firmware.processed_analysis['foo'] = {'bar': 5}
        self.db_interface_backend.add_analysis(self.test_firmware)
        analysis_after = self.db_interface_backend.get_object(
            uid).processed_analysis

        assert analysis_before != analysis_after
        assert 'foo' not in analysis_before
        assert 'foo' in analysis_after
        assert analysis_after['foo'] == {'bar': 5}

    def test_add_analysis_file_object(self):
        """``add_analysis`` adds a new plugin result to a stored file object."""
        self.db_interface_backend.add_object(self.test_fo)

        self.test_fo.processed_analysis['foo'] = {'bar': 5}
        self.db_interface_backend.add_analysis(self.test_fo)
        stored_analysis = self.db_interface_backend.get_object(
            self.test_fo.uid).processed_analysis

        assert 'foo' in stored_analysis
        assert stored_analysis['foo'] == {'bar': 5}

    def test_crash_add_analysis(self):
        """Plain dicts are rejected by ``add_analysis`` and ``_update_analysis``."""
        with self.assertRaises(RuntimeError):
            self.db_interface_backend.add_analysis({})

        with self.assertRaises(AttributeError):
            self.db_interface_backend._update_analysis({}, 'dummy', {})
Esempio n. 10
0
class TestAcceptanceAnalyzeFirmware(TestAcceptanceBase):
    """Acceptance test: upload a firmware, wait for analysis, check the web pages."""

    def setUp(self):
        """Start the backend with a callback that stores and counts analysed objects."""
        super().setUp()
        self.analysis_finished_event = Event()
        self.elements_finished_analyzing = Value('i', 0)
        self.db_backend_service = BackEndDbInterface(config=self.config)
        self._start_backend(post_analysis=self._analysis_callback)
        time.sleep(2)  # wait for systems to start

    def _analysis_callback(self, fo):
        """Store *fo* and set the finished event once more than 3 objects were seen."""
        self.db_backend_service.add_object(fo)
        # '+=' on a shared Value is a separate read and write, so do both under
        # the Value's own lock.
        # NOTE(review): assumes Value is multiprocessing.Value created with
        # its default lock -- confirm against the file's imports.
        with self.elements_finished_analyzing.get_lock():
            self.elements_finished_analyzing.value += 1
            finished_count = self.elements_finished_analyzing.value
        if finished_count > 3:
            self.analysis_finished_event.set()

    def tearDown(self):
        """Stop the backend and the DB interface before the base class clean-up."""
        self._stop_backend()
        self.db_backend_service.shutdown()
        super().tearDown()

    def _upload_firmware_get(self):
        """Check the upload page and how each available plugin is rendered on it."""
        rv = self.test_client.get('/upload')
        self.assertIn(b'<h2>Upload Firmware</h2>', rv.data,
                      'upload page not displayed correctly')

        with ConnectTo(InterComFrontEndBinding, self.config) as connection:
            plugins = connection.get_available_analysis_plugins()

        # index 1 / index 2 of a plugin entry are used as the
        # "mandatory" / "default" flags
        mandatory_plugins = [p for p in plugins if plugins[p][1]]
        default_plugins = [p for p in plugins if plugins[p][2]]
        optional_plugins = [
            p for p in plugins if not (plugins[p][1] or plugins[p][2])
        ]
        for mandatory_plugin in mandatory_plugins:
            self.assertNotIn(
                mandatory_plugin.encode(), rv.data,
                'mandatory plugin {} found erroneously'.format(
                    mandatory_plugin))
        for default_plugin in default_plugins:
            self.assertIn(
                'value="{}" checked'.format(default_plugin).encode(), rv.data,
                'default plugin {} erroneously unchecked or not found'.format(
                    default_plugin))
        for optional_plugin in optional_plugins:
            self.assertIn(
                'value="{}" unchecked'.format(optional_plugin).encode(),
                rv.data,
                'optional plugin {} erroneously checked or not found'.format(
                    optional_plugin))

    def _upload_firmware_post(self):
        """Upload test firmware A through the form and check the success page."""
        testfile_path = os.path.join(get_test_data_dir(), self.test_fw_a.path)
        with open(testfile_path, 'rb') as fp:
            data = {
                'file': (fp, self.test_fw_a.file_name),
                'device_name': 'test_device',
                'device_class': 'test_class',
                'firmware_version': '1.0',
                'vendor': 'test_vendor',
                'release_date': '1970-01-01',
                'tags': '',
                'analysis_systems': []
            }
            rv = self.test_client.post('/upload',
                                       content_type='multipart/form-data',
                                       data=data,
                                       follow_redirects=True)
        self.assertIn(b'Upload Successful', rv.data, 'upload not successful')
        self.assertIn(self.test_fw_a.uid.encode(), rv.data,
                      'uid not found on upload success page')

    def _show_analysis_page(self):
        """Check that the firmware is in the DB and its analysis page shows the meta data."""
        with ConnectTo(FrontEndDbInterface, self.config) as connection:
            self.assertIsNotNone(
                connection.firmwares.find_one({'_id': self.test_fw_a.uid}),
                'Error: Test firmware not found in DB!')
        rv = self.test_client.get('/analysis/{}'.format(self.test_fw_a.uid))
        self.assertIn(self.test_fw_a.uid.encode(), rv.data)
        self.assertIn(b'test_device', rv.data)
        self.assertIn(b'test_class', rv.data)
        self.assertIn(b'test_vendor', rv.data)
        self.assertIn(b'unknown', rv.data)
        self.assertIn(self.test_fw_a.file_name.encode(), rv.data,
                      'file name not found')
        self.assertIn(b'admin options:', rv.data,
                      'admin options not shown with disabled auth')

    def _check_ajax_file_tree_routes(self):
        """Both file tree AJAX routes must return a node with children."""
        rv = self.test_client.get('/ajax_tree/{}/{}'.format(
            self.test_fw_a.uid, self.test_fw_a.uid))
        self.assertIn(b'"children":', rv.data)
        rv = self.test_client.get('/ajax_root/{}'.format(self.test_fw_a.uid))
        self.assertIn(b'"children":', rv.data)

    def _check_ajax_on_demand_binary_load(self):
        """The on-demand binary route returns the content of the requested file."""
        rv = self.test_client.get(
            '/ajax_get_binary/text_plain/d558c9339cb967341d701e3184f863d3928973fccdc1d96042583730b5c7b76a_62'
        )
        self.assertIn(b'test file', rv.data)

    def _show_analysis_details_file_type(self):
        """The file type result is rendered, and not with the generic template."""
        rv = self.test_client.get('/analysis/{}/file_type'.format(
            self.test_fw_a.uid))
        self.assertIn(b'application/zip', rv.data)
        self.assertIn(b'Zip archive data', rv.data)
        self.assertNotIn(
            b'<pre><code>', rv.data,
            'generic template used instead of specific template -> sync view error!'
        )

    def _show_home_page(self):
        """The home page lists the test firmware."""
        rv = self.test_client.get('/')
        self.assertIn(
            self.test_fw_a.uid.encode(), rv.data,
            'test firmware not found under recent analysis on home page')

    def _re_do_analysis_get(self):
        """The re-do analysis page is pre-filled with the firmware's file name."""
        rv = self.test_client.get('/admin/re-do_analysis/{}'.format(
            self.test_fw_a.uid))
        self.assertIn(
            b'<input type="hidden" name="file_name" id="file_name" value="' +
            self.test_fw_a.file_name.encode() + b'">', rv.data,
            'file name not set in re-do page')

    def test_run_from_upload_to_show_analysis(self):
        """Run the whole workflow from upload form to the analysis views."""
        self._upload_firmware_get()
        self._upload_firmware_post()
        # wait() returns False on timeout; fail right here instead of running
        # the page checks against an unfinished analysis
        self.assertTrue(self.analysis_finished_event.wait(timeout=15),
                        'analysis did not finish within timeout')
        self._show_analysis_page()
        self._show_analysis_details_file_type()
        self._check_ajax_file_tree_routes()
        self._check_ajax_on_demand_binary_load()
        self._show_home_page()
        self._re_do_analysis_get()
class TestFileAddition(unittest.TestCase):
    """Unpack two test containers, analyse them and compare the two firmware images."""

    @patch('unpacker.unpack.FS_Organizer', MockFSOrganizer)
    def setUp(self):
        """Create the config, the database server/interface and the three schedulers."""
        self._tmp_dir = TemporaryDirectory()
        self._config = initialize_config(self._tmp_dir)
        # shared counter of objects handed to the post-analysis callback
        self.elements_finished_analyzing = Value('i', 0)
        self.analysis_finished_event = Event()
        self.compare_finished_event = Event()

        self._mongo_server = MongoMgr(config=self._config, auth=False)
        self.backend_interface = BackEndDbInterface(config=self._config)

        self._analysis_scheduler = AnalysisScheduler(config=self._config, post_analysis=self.count_analysis_finished_event)
        self._unpack_scheduler = UnpackingScheduler(config=self._config, post_unpack=self._analysis_scheduler.add_task)
        self._compare_scheduler = CompareScheduler(config=self._config, callback=self.trigger_compare_finished_event)

    def count_analysis_finished_event(self, fw_object):
        """Store an analysed object and signal once more than 7 objects were counted.

        :param fw_object: the object passed to the post-analysis callback
        """
        self.backend_interface.add_object(fw_object)
        # ``+=`` on a multiprocessing ``Value`` is a separate read and write;
        # hold its lock so increments from concurrent callbacks are not lost
        with self.elements_finished_analyzing.get_lock():
            self.elements_finished_analyzing.value += 1
            finished_count = self.elements_finished_analyzing.value
        if finished_count > 7:
            self.analysis_finished_event.set()

    def trigger_compare_finished_event(self):
        """Callback of the compare scheduler: signal that the compare is done."""
        self.compare_finished_event.set()

    def tearDown(self):
        """Shut down the schedulers, clean the test database and remove the temp dir."""
        self._compare_scheduler.shutdown()
        self._unpack_scheduler.shutdown()
        self._analysis_scheduler.shutdown()

        clean_test_database(self._config, get_database_names(self._config))
        self._mongo_server.shutdown()

        self._tmp_dir.cleanup()
        gc.collect()

    def test_unpack_analyse_and_compare(self):
        """Run two containers through unpacking and analysis, then compare them."""
        test_fw_1 = Firmware(file_path='{}/container/test.zip'.format(get_test_data_dir()))
        test_fw_1.release_date = '2017-01-01'
        test_fw_2 = Firmware(file_path='{}/container/test.7z'.format(get_test_data_dir()))
        test_fw_2.release_date = '2017-01-01'

        self._unpack_scheduler.add_task(test_fw_1)
        self._unpack_scheduler.add_task(test_fw_2)

        # wait at most 10 seconds; the test continues even if the event was not set
        self.analysis_finished_event.wait(timeout=10)

        compare_id = unify_string_list(';'.join([fw.uid for fw in [test_fw_1, test_fw_2]]))

        self.assertIsNone(self._compare_scheduler.add_task((compare_id, False)), 'adding compare task creates error')

        self.compare_finished_event.wait(timeout=10)

        with ConnectTo(CompareDbInterface, self._config) as sc:
            result = sc.get_compare_result(compare_id)

        expected = self._expected_result()
        self.assertNotIsInstance(result, str, 'compare result should exist')
        self.assertEqual(result['plugins']['Software'], expected['Software'])
        # both sides are dicts, so this compares their keys regardless of order
        self.assertCountEqual(result['plugins']['File_Coverage']['exclusive_files'], expected['File_Coverage']['exclusive_files'])

    @staticmethod
    def _expected_result():
        """Return the compare plugin results expected for the two test containers."""
        return {
            'File_Coverage': {
                'exclusive_files': {
                    '418a54d78550e8584291c96e5d6168133621f352bfc1d43cf84e81187fef4962_787': [],
                    'd38970f8c5153d1041810d0908292bc8df21e7fd88aab211a8fb96c54afe6b01_319': [],
                    'collapse': False
                },
                'files_in_common': {
                    'all': [
                        'faa11db49f32a90b51dfc3f0254f9fd7a7b46d0b570abd47e1943b86d554447a_28',
                        '289b5a050a83837f192d7129e4c4e02570b94b4924e50159fad5ed1067cfbfeb_20',
                        'd558c9339cb967341d701e3184f863d3928973fccdc1d96042583730b5c7b76a_62'
                    ],
                    'collapse': False
                },
                'similar_files': {}
            },
            'Software': {
                'Compare Skipped': {
                    'all': 'Required analysis not present: [\'software_components\', \'software_components\']'
                }
            }
        }
Esempio n. 12
0
class TestStorageDbInterfaceFrontend(unittest.TestCase):
    """Tests of the frontend database interface; data is written through the backend interface."""

    @classmethod
    def setUpClass(cls):
        """Create the test config and start the database server once for all tests."""
        cls._config = get_config_for_testing(TMP_DIR)
        cls.mongo_server = MongoMgr(config=cls._config)

    def setUp(self):
        """Open fresh frontend/backend interfaces and create a test firmware."""
        self.db_frontend_interface = FrontEndDbInterface(config=self._config)
        self.db_backend_interface = BackEndDbInterface(config=self._config)
        self.test_firmware = create_test_firmware()

    def tearDown(self):
        """Close both interfaces and drop the main database in between."""
        self.db_frontend_interface.shutdown()
        mongo_client = self.db_backend_interface.client
        mongo_client.drop_database(self._config.get('data_storage', 'main_database'))
        self.db_backend_interface.shutdown()
        gc.collect()

    @classmethod
    def tearDownClass(cls):
        """Stop the database server and remove the temporary directory."""
        cls.mongo_server.shutdown()
        TMP_DIR.cleanup()

    def test_regression_meta_list(self):
        """A firmware without the 'unpacker' analysis still yields a 'NOP' entry in the meta list."""
        removed_analysis = self.test_firmware.processed_analysis.pop('unpacker')
        assert removed_analysis
        self.db_backend_interface.add_firmware(self.test_firmware)
        meta_entries = self.db_frontend_interface.get_meta_list()
        last_entry = meta_entries.pop()
        assert 'NOP' in last_entry[2]

    def test_get_meta_list(self):
        """The meta list entry of a firmware holds its hid and a tag dict."""
        self.db_backend_interface.add_firmware(self.test_firmware)
        meta_entries = self.db_frontend_interface.get_meta_list()
        last_entry = meta_entries.pop()
        self.assertEqual(
            last_entry[1],
            'test_vendor test_router - 0.1 (Router)',
            'Firmware not successfully received',
        )
        self.assertIsInstance(last_entry[2], dict, 'tag field is not a dict')

    def test_get_meta_list_of_fo(self):
        """The meta list of a file object holds its uid and 0 as submission date."""
        file_object = create_test_file_object()
        self.db_backend_interface.add_file_object(file_object)
        stored_files = self.db_frontend_interface.file_objects.find()
        first_entry = self.db_frontend_interface.get_meta_list(stored_files)[0]
        self.assertEqual(first_entry[0], file_object.uid, 'uid of object not correct')
        self.assertEqual(first_entry[3], 0, 'non existing submission date should lead to 0')

    def test_get_hid_firmware(self):
        """The hid of a firmware is built from vendor, device name, version and class."""
        self.db_backend_interface.add_firmware(self.test_firmware)
        firmware_hid = self.db_frontend_interface.get_hid(self.test_firmware.uid)
        self.assertEqual(firmware_hid, 'test_vendor test_router - 0.1 (Router)', 'fw hid not correct')

    def test_get_hid_fo(self):
        """The hid of a file object is a path, with or without a (valid) root uid."""
        file_object = create_test_file_object(bin_path='get_files_test/testfile2')
        file_object.virtual_file_path = {
            'a': ['|a|/test_file'],
            'b': ['|b|/get_files_test/testfile2'],
        }
        self.db_backend_interface.add_file_object(file_object)
        get_hid = self.db_frontend_interface.get_hid

        hid_for_known_root = get_hid(file_object.uid, root_uid='b')
        self.assertEqual(hid_for_known_root, '/get_files_test/testfile2', 'fo hid not correct')

        hid_without_root = get_hid(file_object.uid)
        self.assertIsInstance(hid_without_root, str, 'result is not a string')
        self.assertEqual(hid_without_root[0], '/', 'first character not correct if no root_uid set')

        hid_for_unknown_root = get_hid(file_object.uid, root_uid='c')
        self.assertEqual(hid_for_unknown_root[0], '/', 'first character not correct if invalid root_uid set')

    def test_get_file_name(self):
        """The file name of the stored test firmware is 'test.zip'."""
        self.db_backend_interface.add_firmware(self.test_firmware)
        file_name = self.db_frontend_interface.get_file_name(self.test_firmware.uid)
        self.assertEqual(file_name, 'test.zip', 'name not correct')

    def test_get_hid_invalid_uid(self):
        """An unknown uid results in an empty hid."""
        hid = self.db_frontend_interface.get_hid('foo')
        self.assertEqual(hid, '', 'invalid uid should result in empty string')

    def test_get_firmware_attribute_list(self):
        """The attribute lists reflect the single firmware in the database."""
        frontend = self.db_frontend_interface
        self.db_backend_interface.add_firmware(self.test_firmware)

        self.assertEqual(frontend.get_device_class_list(), ['Router'])
        self.assertEqual(frontend.get_vendor_list(), ['test_vendor'])

        device_filter = {'vendor': 'test_vendor', 'device_class': 'Router'}
        device_names = frontend.get_firmware_attribute_list('device_name', device_filter)
        self.assertEqual(device_names, ['test_router'])

        versions = frontend.get_firmware_attribute_list('version')
        self.assertEqual(versions, ['0.1'])

        self.assertEqual(
            frontend.get_device_name_dict(),
            {'Router': {'test_vendor': ['test_router']}},
        )

    def test_get_data_for_nice_list(self):
        """The nice list entry of a firmware has the expected keys and uid."""
        uid_list = [self.test_firmware.uid]
        self.db_backend_interface.add_firmware(self.test_firmware)
        nice_list_data = self.db_frontend_interface.get_data_for_nice_list(uid_list, uid_list[0])
        expected_keys = sorted(['size', 'virtual_file_paths', 'uid', 'mime-type', 'files_included'])
        first_entry = nice_list_data[0]
        self.assertEqual(expected_keys, sorted(first_entry.keys()))
        self.assertEqual(nice_list_data[0]['uid'], self.test_firmware.uid)

    def test_generic_search(self):
        """A generic search accepts the query as JSON string as well as dict."""
        self.db_backend_interface.add_firmware(self.test_firmware)
        queries = (
            '{"file_name": "test.zip"}',  # str input
            {'file_name': 'test.zip'},  # dict input
        )
        for query in queries:
            matching_uids = self.db_frontend_interface.generic_search(query)
            self.assertEqual(matching_uids, [self.test_firmware.uid], 'Firmware not successfully received')

    def test_all_uids_found_in_database(self):
        """The uid is only reported as found after the firmware was added."""
        main_database = self._config.get('data_storage', 'main_database')
        self.db_backend_interface.client.drop_database(main_database)
        uid_list = [self.test_firmware.uid]
        found_before_adding = self.db_frontend_interface.all_uids_found_in_database(uid_list)
        self.assertFalse(found_before_adding)
        self.db_backend_interface.add_firmware(self.test_firmware)
        found_after_adding = self.db_frontend_interface.all_uids_found_in_database([self.test_firmware.uid])
        self.assertTrue(found_after_adding)

    def test_get_x_last_added_firmwares(self):
        """With a limit of 2, the two firmwares added last are returned, newest first."""
        self.assertEqual(
            self.db_frontend_interface.get_last_added_firmwares(),
            [],
            'empty db should result in empty list',
        )
        creation_arguments = (
            {'device_name': 'fw_one'},
            {'device_name': 'fw_two', 'bin_path': 'container/test.7z'},
            {'device_name': 'fw_three', 'bin_path': 'container/test.cab'},
        )
        added_firmwares = []
        for arguments in creation_arguments:
            firmware = create_test_firmware(**arguments)
            self.db_backend_interface.add_firmware(firmware)
            added_firmwares.append(firmware)
        _, fw_two, fw_three = added_firmwares

        last_added = self.db_frontend_interface.get_last_added_firmwares(limit_x=2)
        self.assertEqual(len(last_added), 2, 'Number of results should be 2')
        self.assertEqual(last_added[0][0], fw_three.uid, 'last firmware is not first entry')
        self.assertEqual(last_added[1][0], fw_two.uid, 'second last firmware is not the second entry')

    def test_generate_file_tree_level(self):
        """Tree nodes are generated for the firmware itself and for the folder of its child."""
        parent_fw = create_test_firmware()
        child_fo = create_test_file_object()
        child_fo.processed_analysis['file_type'] = {'mime': 'sometype'}
        root_uid = parent_fw.uid
        child_fo.virtual_file_path = {
            root_uid: ['|{}|/folder/{}'.format(root_uid, child_fo.file_name)]
        }
        parent_fw.files_included = {child_fo.uid}
        for stored_object in (parent_fw, child_fo):
            self.db_backend_interface.add_object(stored_object)

        generate_level = self.db_frontend_interface.generate_file_tree_level
        for fw_node in generate_level(root_uid, root_uid):
            self.assertIsInstance(fw_node, FileTreeNode)
            self.assertEqual(fw_node.name, parent_fw.file_name)
            self.assertTrue(fw_node.has_children)
        for folder_node in generate_level(child_fo.uid, root_uid):
            self.assertIsInstance(folder_node, FileTreeNode)
            self.assertEqual(folder_node.name, 'folder')
            self.assertTrue(folder_node.has_children)
            file_node = folder_node.get_list_of_child_nodes()[0]
            self.assertEqual(file_node.type, 'sometype')
            self.assertFalse(file_node.has_children)
            self.assertEqual(file_node.name, child_fo.file_name)

    def test_get_number_of_total_matches(self):
        """Two objects match in total, but only one when restricted to parent firmwares."""
        parent_fw = create_test_firmware()
        child_fo = create_test_file_object()
        root_uid = parent_fw.uid
        child_fo.virtual_file_path = {
            root_uid: ['|{}|/folder/{}'.format(root_uid, child_fo.file_name)]
        }
        self.db_backend_interface.add_object(parent_fw)
        self.db_backend_interface.add_object(child_fo)
        query = '{{"$or": [{{"_id": "{}"}}, {{"_id": "{}"}}]}}'.format(root_uid, child_fo.uid)
        for only_parents, expected_count in ((False, 2), (True, 1)):
            match_count = self.db_frontend_interface.get_number_of_total_matches(
                query, only_parent_firmwares=only_parents)
            self.assertEqual(match_count, expected_count)

    def test_get_other_versions_of_firmware(self):
        """The other versions of a firmware are listed with uid and version."""
        fw_version_1 = create_test_firmware(version='1')
        self.db_backend_interface.add_object(fw_version_1)
        fw_version_2 = create_test_firmware(version='2', bin_path='container/test.7z')
        self.db_backend_interface.add_object(fw_version_2)
        fw_version_3 = create_test_firmware(version='3', bin_path='container/test.cab')
        self.db_backend_interface.add_object(fw_version_3)

        entry_version_2 = {'_id': fw_version_2.uid, 'version': '2'}
        entry_version_3 = {'_id': fw_version_3.uid, 'version': '3'}

        others_of_1 = self.db_frontend_interface.get_other_versions_of_firmware(fw_version_1)
        self.assertEqual(len(others_of_1), 2, 'wrong number of other versions')
        self.assertIn(entry_version_2, others_of_1)
        self.assertIn(entry_version_3, others_of_1)

        others_of_2 = self.db_frontend_interface.get_other_versions_of_firmware(fw_version_2)
        self.assertIn(entry_version_3, others_of_2)

    def test_get_specific_fields_for_multiple_entries(self):
        """Only the requested fields (plus '_id') are returned for every requested uid."""
        test_fw_1 = create_test_firmware(device_name='fw_one', vendor='test_vendor_one')
        self.db_backend_interface.add_firmware(test_fw_1)
        test_fw_2 = create_test_firmware(
            device_name='fw_two', vendor='test_vendor_two', bin_path='container/test.7z')
        self.db_backend_interface.add_firmware(test_fw_2)
        test_fo = create_test_file_object()
        self.db_backend_interface.add_file_object(test_fo)

        # first two firmwares, then a firmware together with a file object
        cases = (
            ([test_fw_1.uid, test_fw_2.uid], {'vendor': 1, 'device_name': 1}),
            ([test_fw_1.uid, test_fo.uid], {'virtual_file_path': 1}),
        )
        for requested_uids, requested_fields in cases:
            entries = list(
                self.db_frontend_interface.get_specific_fields_for_multiple_entries(
                    uid_list=requested_uids, field_dict=requested_fields))
            assert len(entries) == 2
            expected_keys = {'_id'} | set(requested_fields)
            assert all(set(entry.keys()) == expected_keys for entry in entries)
            returned_uids = [entry['_id'] for entry in entries]
            assert all(uid in returned_uids for uid in requested_uids)

    def test_find_missing_files(self):
        """An included uid without database entry is reported until the file is added."""
        firmware = create_test_firmware()
        firmware.files_included.add('uid1234')
        self.db_backend_interface.add_firmware(firmware)
        missing_before = self.db_frontend_interface.find_missing_files()
        assert firmware.uid in missing_before
        assert missing_before[firmware.uid] == {'uid1234'}

        file_object = create_test_file_object()
        file_object.uid = 'uid1234'
        self.db_backend_interface.add_file_object(file_object)
        missing_after = self.db_frontend_interface.find_missing_files()
        assert missing_after == {}

    def test_find_missing_analyses(self):
        """A file lacking an analysis that its firmware has is reported as missing it."""
        firmware = create_test_firmware()
        file_object = create_test_file_object()
        firmware.files_included.add(file_object.uid)
        file_object.virtual_file_path = {firmware.uid: ['|foo|bar|']}
        self.db_backend_interface.add_firmware(firmware)
        self.db_backend_interface.add_file_object(file_object)

        missing_before = self.db_frontend_interface.find_missing_analyses()
        assert missing_before == {}

        # add an analysis to the firmware only
        firmware.processed_analysis['foobar'] = {'foo': 'bar'}
        self.db_backend_interface.add_analysis(firmware)
        missing_after = self.db_frontend_interface.find_missing_analyses()
        assert firmware.uid in missing_after
        assert missing_after[firmware.uid] == {file_object.uid}
Esempio n. 13
0
class TestRestFirmware(TestAcceptanceBase):
    """Acceptance test of the ``/rest/firmware`` endpoints: upload, download, search and update."""

    def setUp(self):
        """Start the backend with the counting post-analysis callback."""
        super().setUp()
        self.analysis_finished_event = Event()
        # shared counter of objects handed to the post-analysis callback
        self.elements_finished_analyzing = Value('i', 0)
        self.db_backend_service = BackEndDbInterface(config=self.config)
        self._start_backend(post_analysis=self._analysis_callback)
        self.test_container_uid = '418a54d78550e8584291c96e5d6168133621f352bfc1d43cf84e81187fef4962_787'
        time.sleep(2)  # wait for systems to start

    def tearDown(self):
        """Stop the backend and close the database interface."""
        self._stop_backend()
        self.db_backend_service.shutdown()
        super().tearDown()

    def _analysis_callback(self, fo):
        """Store an analysed object and signal once more than 3 objects were counted.

        :param fo: the object passed to the post-analysis callback
        """
        self.db_backend_service.add_object(fo)
        # ``+=`` on a multiprocessing ``Value`` is a separate read and write;
        # hold its lock so increments from concurrent callbacks are not lost
        with self.elements_finished_analyzing.get_lock():
            self.elements_finished_analyzing.value += 1
            finished_count = self.elements_finished_analyzing.value
        if finished_count > 3:
            self.analysis_finished_event.set()

    def _rest_upload_firmware(self):
        """PUT the base64 encoded test container and check status and uid of the reply."""
        testfile_path = os.path.join(get_test_data_dir(), 'container/test.zip')
        with open(testfile_path, 'rb') as fp:
            file_content = fp.read()
        data = {
            'binary': standard_b64encode(file_content).decode(),
            'file_name': 'test.zip',
            'device_name': 'test_device',
            'device_class': 'test_class',
            'firmware_version': '1.0',
            'vendor': 'test_vendor',
            'release_date': '01.01.1970',
            'tags': '',
            'requested_analysis_systems': ['software_components']
        }
        rv = self.test_client.put('/rest/firmware',
                                  data=json.dumps(data),
                                  follow_redirects=True)
        self.assertIn(b'"status": 0', rv.data, 'rest upload not successful')
        self.assertIn(self.test_container_uid.encode(), rv.data,
                      'uid not found in rest upload reply')

    def _rest_get_analysis_result(self):
        """GET the uploaded firmware and check that the requested analysis is part of the reply."""
        rv = self.test_client.get('/rest/firmware/{}'.format(
            self.test_container_uid),
                                  follow_redirects=True)
        self.assertIn(b'analysis_date', rv.data,
                      'rest analysis download not successful')
        self.assertIn(b'software_components', rv.data,
                      'rest analysis not successful')

    def _rest_search(self):
        """Search by device class and check that the uploaded firmware is found."""
        rv = self.test_client.get('/rest/firmware?query={}'.format(
            urllib.parse.quote('{"device_class": "test_class"}')),
                                  follow_redirects=True)
        self.assertIn(self.test_container_uid.encode(), rv.data,
                      'test firmware not found in rest search')

    def _rest_search_fw_only(self):
        """Search by sha256 (the uid part before the '_') and check the firmware is found."""
        query = json.dumps({'sha256': self.test_container_uid.split('_')[0]})
        rv = self.test_client.get('/rest/firmware?query={}'.format(
            urllib.parse.quote(query)),
                                  follow_redirects=True)
        self.assertIn(self.test_container_uid.encode(), rv.data,
                      'test firmware not found in rest search')

    def _rest_update_analysis_bad_analysis(self):
        """Request an update with an unknown analysis system and expect an error reply."""
        rv = self.test_client.put('/rest/firmware/{}?update={}'.format(
            self.test_container_uid, urllib.parse.quote('["unknown_system"]')),
                                  follow_redirects=True)
        self.assertIn(
            'Unknown analysis system'.encode(), rv.data,
            "rest analysis update should break on request of non existing system"
        )

    def _rest_update_analysis_success(self):
        """Request an update with 'crypto_material' and expect no error message."""
        rv = self.test_client.put('/rest/firmware/{}?update={}'.format(
            self.test_container_uid,
            urllib.parse.quote(json.dumps(['crypto_material']))),
                                  follow_redirects=True)
        self.assertNotIn(b'error_message', rv.data, 'Error on update request')

    def _rest_check_new_analysis_exists(self):
        """Check that 'crypto_material' exists and is newer than 'software_components'."""
        rv = self.test_client.get('/rest/firmware/{}'.format(
            self.test_container_uid),
                                  follow_redirects=True)
        response_data = json.loads(rv.data.decode())
        analysis = response_data['firmware']['analysis']
        assert analysis['crypto_material']
        assert analysis['crypto_material']['analysis_date'] > analysis['software_components']['analysis_date']

    def test_run_from_upload_to_show_analysis_and_search(self):
        """Upload via REST, check download and search, then trigger and check an analysis update."""
        self._rest_upload_firmware()
        # wait at most 15 seconds; the test continues even if the event was not set
        self.analysis_finished_event.wait(timeout=15)
        # reset counter and event so the update analysis below can signal again
        self.elements_finished_analyzing.value = 0
        self.analysis_finished_event.clear()
        self._rest_get_analysis_result()
        self._rest_search()
        self._rest_search_fw_only()
        self._rest_update_analysis_bad_analysis()
        self._rest_update_analysis_success()

        self.analysis_finished_event.wait(timeout=10)

        self._rest_check_new_analysis_exists()
class TestStorageDbInterfaceFrontendEditing(unittest.TestCase):
    """Tests of the frontend editing database interface (comments, fields, search query cache)."""

    @classmethod
    def setUpClass(cls):
        """Create the test config and start the database server once for all tests."""
        cls._config = get_config_for_testing(TMP_DIR)
        cls.mongo_server = MongoMgr(config=cls._config)

    def setUp(self):
        """Open fresh editing, frontend and backend interfaces."""
        self.db_frontend_editing = FrontendEditingDbInterface(config=self._config)
        self.db_frontend_interface = FrontEndDbInterface(config=self._config)
        self.db_backend_interface = BackEndDbInterface(config=self._config)

    def tearDown(self):
        """Close all interfaces and drop the main database before the backend one is closed."""
        self.db_frontend_editing.shutdown()
        self.db_frontend_interface.shutdown()
        mongo_client = self.db_backend_interface.client
        mongo_client.drop_database(self._config.get('data_storage', 'main_database'))
        self.db_backend_interface.shutdown()
        gc.collect()

    @classmethod
    def tearDownClass(cls):
        """Stop the database server and remove the temporary directory."""
        cls.mongo_server.shutdown()
        TMP_DIR.cleanup()

    def test_add_comment(self):
        """An added comment is stored on the object with the time as string."""
        firmware = create_test_firmware()
        self.db_backend_interface.add_object(firmware)
        comment = 'this is a test comment!'
        author = 'author'
        uid = firmware.uid
        timestamp = 1234567890
        self.db_frontend_editing.add_comment_to_object(uid, comment, author, timestamp)
        stored_firmware = self.db_backend_interface.get_object(uid)
        expected_comment = {
            'time': str(timestamp),
            'author': author,
            'comment': comment,
        }
        self.assertEqual(stored_firmware.comments[0], expected_comment)

    def test_get_latest_comments(self):
        """The latest comments are returned newest first, each with the firmware uid."""
        comments = [{
            'time': '1234567890',
            'author': 'author1',
            'comment': 'test comment'
        }, {
            'time': '1234567899',
            'author': 'author2',
            'comment': 'test comment2'
        }]
        firmware = self._add_test_fw_with_comments_to_db()
        latest_comments = self.db_frontend_interface.get_latest_comments()
        newest_first = sorted(comments, key=lambda entry: entry['time'], reverse=True)
        for index, expected in enumerate(newest_first):
            for key in ('time', 'author', 'comment'):
                assert latest_comments[index][key] == expected[key]
            assert latest_comments[index]['uid'] == firmware.uid

    def test_remove_element_from_array_in_field(self):
        """Removing an element from the 'comments' array leaves one comment."""
        firmware = self._add_test_fw_with_comments_to_db()
        before_removal = self.db_backend_interface.get_object(firmware.uid)
        self.assertEqual(len(before_removal.comments), 2, 'comments were not saved correctly')

        self.db_frontend_editing.remove_element_from_array_in_field(
            firmware.uid, 'comments', {'time': '1234567899'})
        after_removal = self.db_backend_interface.get_object(firmware.uid)
        self.assertEqual(len(after_removal.comments), 1, 'comment was not deleted')

    def test_delete_comment(self):
        """Deleting a comment by its time leaves one comment."""
        firmware = self._add_test_fw_with_comments_to_db()
        before_deletion = self.db_backend_interface.get_object(firmware.uid)
        self.assertEqual(len(before_deletion.comments), 2, 'comments were not saved correctly')

        self.db_frontend_editing.delete_comment(firmware.uid, '1234567899')
        after_deletion = self.db_backend_interface.get_object(firmware.uid)
        self.assertEqual(len(after_deletion.comments), 1, 'comment was not deleted')

    def _add_test_fw_with_comments_to_db(self):
        """Store a test firmware with two comments and return the firmware object."""
        firmware = create_test_firmware()
        first_comment = {
            'time': '1234567890',
            'author': 'author1',
            'comment': 'test comment'
        }
        second_comment = {
            'time': '1234567899',
            'author': 'author2',
            'comment': 'test comment2'
        }
        firmware.comments.extend([first_comment, second_comment])
        self.db_backend_interface.add_object(firmware)
        return firmware

    def test_update_object_field(self):
        """Updating the 'vendor' field changes the value of the stored object."""
        firmware = create_test_firmware(vendor='foo')
        self.db_backend_interface.add_object(firmware)

        before_update = self.db_frontend_editing.get_object(firmware.uid)
        assert before_update.vendor == 'foo'

        self.db_frontend_editing.update_object_field(firmware.uid, 'vendor', 'bar')
        after_update = self.db_frontend_editing.get_object(firmware.uid)
        assert after_update.vendor == 'bar'

    def test_add_to_search_query_cache(self):
        """A query is cached under its uid; adding it again creates no second entry."""
        query = '{"device_class": "Router"}'
        uid = create_uid(query)
        editing = self.db_frontend_editing
        assert editing.add_to_search_query_cache(query) == uid
        cached_entry = editing.search_query_cache.find_one({'_id': uid})
        assert cached_entry['search_query'] == query
        # check what happens if search is added again
        assert editing.add_to_search_query_cache(query) == uid
        entry_count = editing.search_query_cache.count_documents({'_id': uid})
        assert entry_count == 1