예제 #1
0
class ExtractionRequestModelTest(unittest.TestCase):
    """Tests for ExtractionRequest (de)serialization and task_id derivation."""

    def setUp(self):
        self.extraction_request_1 = ExtractionRequest(
            "some_file_1.mp3", "some_vendor:some_plugin:some_output", {}, {})
        self.extraction_request_2 = ExtractionRequest(
            "some_file_2.mp3", "some_vendor:some_plugin:some_output",
            {"block_size": 1024}, {})
        self.extraction_request_3 = ExtractionRequest(
            "some_file_1.mp3", "some_vendor:some_plugin2:some_output",
            {"step_size": 2048}, {})
        self.extraction_request_4 = ExtractionRequest(
            "some_file_1.mp3", "some_vendor:some_plugin:some_output2", {
                "step_size": 2048,
                "block_size": 4096
            }, {
                "metric_1": {
                    "transformation": {
                        "name": "select_row",
                        "args": [5],
                        "kwargs": {}
                    }
                }
            })

    def test_should_serialize_and_deserialize_model(self):
        serializable_form = self.extraction_request_1.to_serializable()
        actual_object = ExtractionRequest.from_serializable(serializable_form)
        assert_that(actual_object).is_not_none().is_equal_to(
            self.extraction_request_1)

    def test_should_serialize_to_json_and_back(self):
        json_form = json.dumps(self.extraction_request_2.to_serializable())
        assert_that(json_form).is_not_none().is_type_of(str)

        actual_object = ExtractionRequest.from_serializable(
            json.loads(json_form))
        assert_that(actual_object).is_equal_to(self.extraction_request_2)

    def test_same_request_should_generate_same_uuid(self):
        task_id_1 = self.extraction_request_1.task_id
        # Sleep so that, if task_id depended on the current time, the ids
        # computed below would differ.
        time.sleep(0.001)
        # Build a separate but equal request. Re-reading task_id from the same
        # instance would pass trivially if the id were cached on the object,
        # so it would not prove the id is derived from the request content.
        equal_request = ExtractionRequest(
            "some_file_1.mp3", "some_vendor:some_plugin:some_output", {}, {})
        task_id_2 = equal_request.task_id
        assert_that(task_id_1).is_not_none().is_not_empty().is_equal_to(
            task_id_2)
        # The same instance must also report a stable id.
        assert_that(self.extraction_request_1.task_id).is_equal_to(task_id_1)

    def test_different_requests_should_generate_different_uuids(self):
        task_ids = [
            request.task_id for request in (
                self.extraction_request_1,
                self.extraction_request_2,
                self.extraction_request_3,
                self.extraction_request_4,
            )
        ]
        for task_id in task_ids:
            assert_that(task_id).is_not_none()
        # Every pair of distinct requests must map to distinct ids.
        for index, first in enumerate(task_ids):
            for second in task_ids[index + 1:]:
                assert_that(first).is_not_equal_to(second)
예제 #2
0
    def test_should_serialize_to_json_and_back(self):
        """A request survives a round trip through a JSON string unchanged."""
        original = self.extraction_request_2
        encoded = json.dumps(original.to_serializable())
        assert_that(encoded).is_not_none().is_type_of(str)

        decoded = json.loads(encoded)
        restored = ExtractionRequest.from_serializable(decoded)
        assert_that(restored).is_equal_to(original)
예제 #3
0
 def _parse_request(self, the_request: ApiRequest) -> ExtractionRequest:
     """Build an ExtractionRequest from the payload of an API request.

     When the payload's "plugin_config" or "metric_config" value is None, it
     is filled in from ``self.plugin_config_provider.get_for_plugin`` or
     ``self.metric_config_provider.get_for_plugin`` respectively, keyed by the
     payload's "plugin_full_key".

     Note: the defaults are assigned into the object returned by
     ``the_request.payload``. If that property returns the stored dict rather
     than a copy, the request's payload is mutated in place.

     Raises:
         ClientError: if any step fails, e.g. a missing key, a provider error,
             or ``ExtractionRequest.from_serializable`` rejecting the payload.
             The original exception is chained as ``__cause__``.
     """
     # NOTE(review): the broad ``except`` also turns failures raised by the
     # config providers (server-side problems) into ClientError; consider
     # narrowing it once the providers' exception types are known.
     try:
         request_json = the_request.payload
         if request_json["plugin_config"] is None:
             request_json[
                 "plugin_config"] = self.plugin_config_provider.get_for_plugin(
                     request_json["plugin_full_key"])
         if request_json["metric_config"] is None:
             request_json[
                 "metric_config"] = self.metric_config_provider.get_for_plugin(
                     request_json["plugin_full_key"])
         return ExtractionRequest.from_serializable(request_json)
     except Exception as e:
         # Chain explicitly so the root cause shows up as the direct cause.
         raise ClientError(f"Could not parse request body: {e}") from e
예제 #4
0
 def _generate_extraction_requests(
         self, audio_file_names: List[str], plugins: List[VampyPlugin],
         plugin_configs: Dict[str, Dict[str, Any]]) -> List[ExtractionRequest]:
     """Build one ExtractionRequest for every (audio file, plugin) pair.

     Requests are ordered by audio file first, then by plugin, matching the
     order of the inputs.

     Args:
         audio_file_names: Audio files to extract features from.
         plugins: Plugins to run on each audio file.
         plugin_configs: Plugin configuration keyed by ``plugin.full_key``.
             Plugins without an entry get ``plugin_config=None``.

     Returns:
         A list of ``len(audio_file_names) * len(plugins)`` requests. Each
         request's ``metric_config`` comes from
         ``self.metric_config_provider.get_for_plugin``. A falsy result
         (e.g. an empty dict) is normalized to None.
     """
     if not audio_file_names:
         # No pairs to build; skip querying the metric config provider.
         return []
     # Everything except the audio file name depends only on the plugin, so
     # resolve it once per plugin instead of once per (file, plugin) pair.
     # This also iterates ``plugins`` exactly once.
     per_plugin_args = [
         (plugin.full_key,
          plugin_configs.get(plugin.full_key),
          self.metric_config_provider.get_for_plugin(
              plugin_full_key=plugin.full_key) or None)
         for plugin in plugins
     ]
     return [
         ExtractionRequest(audio_file_name=audio_file_name,
                           plugin_full_key=plugin_full_key,
                           plugin_config=plugin_config,
                           metric_config=metric_config)
         for audio_file_name in audio_file_names
         for plugin_full_key, plugin_config, metric_config in per_plugin_args
     ]
예제 #5
0
def extract_feature(extraction_request: Dict[str, Any]) -> Dict[str, Any]:
    """Run one feature extraction described by a serialized ExtractionRequest.

    The dict is deserialized with ``ExtractionRequest.from_serializable``. A
    FeatureExtractionService is then wired up from fresh repositories and its
    ``extract_feature_and_store`` is called with the request.

    Args:
        extraction_request: Serialized form of an ExtractionRequest.

    Returns:
        The ``extraction_request`` argument object itself.

    Raises:
        SoftTimeLimitExceeded: logged via ``logger.exception`` and then
            re-raised. Any other exception propagates without being logged here.
    """
    logger = get_logger()
    request = ExtractionRequest.from_serializable(extraction_request)
    extraction_service = _build_extraction_service(logger)
    try:
        extraction_service.extract_feature_and_store(request)
    except SoftTimeLimitExceeded as e:
        logger.exception(e)
        raise  # bare raise re-raises the active exception unchanged
    return extraction_request


def _build_extraction_service(logger):
    """Wire a FeatureExtractionService to a new SessionProvider and repositories."""
    plugin_provider = VampyPluginProvider(plugin_black_list=[], logger=logger)
    mp3_file_store = Mp3FileStore(AUDIO_FILES_DIR)

    db_session_provider = SessionProvider()
    audio_tag_repo = AudioTagRepository(db_session_provider)
    audio_meta_repo = AudioFileRepository(db_session_provider)
    plugin_repo = VampyPluginRepository(db_session_provider)
    plugin_config_repo = PluginConfigRepository(db_session_provider)
    feature_data_repo = FeatureDataRepository(db_session_provider)
    feature_meta_repo = FeatureMetaRepository(db_session_provider)
    metric_definition_repo = MetricDefinitionRepository(
        db_session_provider, plugin_repo)
    metric_value_repo = MetricValueRepository(db_session_provider,
                                              metric_definition_repo)
    result_repo = RequestRepository(db_session_provider, audio_meta_repo,
                                    audio_tag_repo, plugin_repo,
                                    plugin_config_repo)
    result_stats_repo = ResultStatsRepository(db_session_provider)

    return FeatureExtractionService(
        plugin_provider=plugin_provider,
        audio_file_store=mp3_file_store,
        audio_tag_repo=audio_tag_repo,
        audio_meta_repo=audio_meta_repo,
        plugin_repo=plugin_repo,
        plugin_config_repo=plugin_config_repo,
        metric_definition_repo=metric_definition_repo,
        metric_value_repo=metric_value_repo,
        feature_data_repo=feature_data_repo,
        feature_meta_repo=feature_meta_repo,
        request_repo=result_repo,
        result_stats_repo=result_stats_repo,
        logger=logger)
예제 #6
0
 def setUp(self):
     """Create four sample ExtractionRequest fixtures.

     They vary in audio file, plugin key, plugin config and metric config.
     """
     first_file = "some_file_1.mp3"
     second_file = "some_file_2.mp3"
     base_plugin_key = "some_vendor:some_plugin:some_output"
     select_row_metric_config = {
         "metric_1": {
             "transformation": {
                 "name": "select_row",
                 "args": [5],
                 "kwargs": {},
             },
         },
     }

     self.extraction_request_1 = ExtractionRequest(
         first_file, base_plugin_key, {}, {})
     self.extraction_request_2 = ExtractionRequest(
         second_file, base_plugin_key, {"block_size": 1024}, {})
     self.extraction_request_3 = ExtractionRequest(
         first_file, "some_vendor:some_plugin2:some_output",
         {"step_size": 2048}, {})
     self.extraction_request_4 = ExtractionRequest(
         first_file, "some_vendor:some_plugin:some_output2",
         {"step_size": 2048, "block_size": 4096},
         select_row_metric_config)
예제 #7
0
 def test_should_serialize_and_deserialize_model(self):
     """Serializing and then deserializing yields an equal request."""
     expected = self.extraction_request_1
     round_tripped = ExtractionRequest.from_serializable(
         expected.to_serializable())
     assert_that(round_tripped).is_not_none().is_equal_to(expected)