Example #1
0
class TestTextToSpeechOperator(unittest.TestCase):
    """Unit tests for ``CloudSpeechToTextHook``.

    NOTE(review): despite the class name, these tests exercise the
    speech-to-text hook, not a text-to-speech operator.
    """

    def setUp(self):
        # Construct the hook while CloudBaseHook.__init__ is replaced by the
        # stub, so building it does not go through the real base initializer.
        base_init_target = "airflow.providers.google.cloud.hooks.base.CloudBaseHook.__init__"
        with patch(base_init_target, new=mock_base_gcp_hook_default_project_id):
            self.gcp_speech_to_text_hook = CloudSpeechToTextHook(gcp_conn_id="test")

    @patch(
        "airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook.client_info",
        new_callable=PropertyMock,
    )
    @patch("airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook._get_credentials")
    @patch("airflow.providers.google.cloud.hooks.speech_to_text.SpeechClient")
    def test_speech_client_creation(self, speech_client_cls, get_credentials, client_info_prop):
        """get_conn() builds a SpeechClient from credentials + client info and stores it on ``_client``."""
        conn = self.gcp_speech_to_text_hook.get_conn()

        expected_kwargs = {
            "credentials": get_credentials.return_value,
            "client_info": client_info_prop.return_value,
        }
        speech_client_cls.assert_called_once_with(**expected_kwargs)
        self.assertEqual(speech_client_cls.return_value, conn)
        self.assertEqual(self.gcp_speech_to_text_hook._client, conn)

    @patch("airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook.get_conn")
    def test_synthesize_speech(self, get_conn):
        """recognize_speech() forwards config/audio to client.recognize with retry=None, timeout=None."""
        # NOTE(review): the test name says "synthesize", but it covers recognize_speech.
        client = get_conn.return_value
        client.recognize.return_value = None

        self.gcp_speech_to_text_hook.recognize_speech(config=CONFIG, audio=AUDIO)

        client.recognize.assert_called_once_with(config=CONFIG, audio=AUDIO, retry=None, timeout=None)
Example #2
0
 def execute(self, context):
     """Run speech recognition and return the response as a plain dict.

     Builds a ``CloudSpeechToTextHook`` for ``self.gcp_conn_id``, calls
     ``recognize_speech`` with this operator's ``config``, ``audio``,
     ``retry`` and ``timeout``, and converts the returned message with
     ``MessageToDict``.
     """
     hook = CloudSpeechToTextHook(gcp_conn_id=self.gcp_conn_id)
     response = hook.recognize_speech(config=self.config,
                                      audio=self.audio,
                                      retry=self.retry,
                                      timeout=self.timeout)
     return MessageToDict(response)
Example #3
0
 def execute(self, context):
     """Recognize ``self.audio`` using ``self.config`` and return the result as a dict.

     The hook is built from ``gcp_conn_id`` and ``impersonation_chain``; the
     recognition call passes through this operator's ``retry`` and ``timeout``.
     """
     speech_hook = CloudSpeechToTextHook(
         gcp_conn_id=self.gcp_conn_id,
         impersonation_chain=self.impersonation_chain,
     )
     call_kwargs = dict(
         config=self.config,
         audio=self.audio,
         retry=self.retry,
         timeout=self.timeout,
     )
     return MessageToDict(speech_hook.recognize_speech(**call_kwargs))
    def execute(self, context: 'Context') -> dict:
        """Transcribe ``self.audio`` and translate the first transcript.

        Steps:

        1. Run speech recognition with ``self.config`` / ``self.audio``.
        2. If the response has no results, log it and return ``{}``.
        3. Take the transcript of the first alternative of the first result.
        4. Translate it with ``target_language``, ``format_``,
           ``source_language`` and ``model``.
        5. Persist a ``FileDetailsLink`` for the audio URI and return the translation.

        :raises AirflowException: if the recognition response lacks an expected
            field, or if translation raises ``ValueError``.
        """
        speech_to_text_hook = CloudSpeechToTextHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        translate_hook = CloudTranslateHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )

        recognize_result = speech_to_text_hook.recognize_speech(
            config=self.config, audio=self.audio)
        recognize_dict = MessageToDict(recognize_result)

        self.log.info("Recognition operation finished")

        # MessageToDict omits empty repeated fields, so a response with no
        # results has no 'results' key at all; plain indexing would raise
        # KeyError instead of reaching the "no results" branch.
        if not recognize_dict.get('results'):
            self.log.info("No recognition results")
            return {}
        self.log.debug("Recognition result: %s", recognize_dict)

        try:
            transcript = recognize_dict['results'][0]['alternatives'][0][
                'transcript']
        except KeyError as key:
            raise AirflowException(
                f"Wrong response '{recognize_dict}' returned - it should contain {key} field"
            ) from key

        # Only the translate call is guarded, so the error message below
        # is accurate about where the ValueError came from.
        try:
            translation = translate_hook.translate(
                values=transcript,
                target_language=self.target_language,
                format_=self.format_,
                source_language=self.source_language,
                model=self.model,
            )
        except ValueError as e:
            self.log.error(
                'An error has been thrown from translate speech method:')
            self.log.error(e)
            raise AirflowException(e) from e

        self.log.info('Translated output: %s', translation)
        FileDetailsLink.persist(
            context=context,
            task_instance=self,
            # Drops the first 5 characters, presumably the "gs://" scheme;
            # assumes audio is referenced by a 'uri' key — TODO confirm.
            uri=self.audio["uri"][5:],
            project_id=self.project_id or translate_hook.project_id,
        )
        return translation
Example #5
0
    def execute(self, context: 'Context'):
        """Record a link to the audio file, then run recognition and return it as a dict.

        The ``FileDetailsLink`` is persisted before recognition runs, using
        ``self.project_id`` or, when that is falsy, the hook's ``project_id``.
        Recognition passes through this operator's ``retry`` and ``timeout``.
        """
        hook = CloudSpeechToTextHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )

        # "gs://{BUCKET_NAME}/{FILE_NAME}" -> "{BUCKET_NAME}/{FILE_NAME}"
        bucket_and_object = self.audio["uri"][5:]
        link_project = self.project_id or hook.project_id
        FileDetailsLink.persist(
            context=context,
            task_instance=self,
            uri=bucket_and_object,
            project_id=link_project,
        )

        recognition = hook.recognize_speech(
            config=self.config,
            audio=self.audio,
            retry=self.retry,
            timeout=self.timeout,
        )
        return MessageToDict(recognition)
Example #6
0
    def execute(self, context):
        """Transcribe ``self.audio`` and translate the first transcript.

        Returns ``{}`` when recognition produced no results; otherwise returns
        whatever ``CloudTranslateHook.translate`` returned for the transcript
        of the first alternative of the first result.

        :raises AirflowException: if the recognition response lacks an expected
            field, or if translation raises ``ValueError``.
        """
        speech_to_text_hook = CloudSpeechToTextHook(
            gcp_conn_id=self.gcp_conn_id)
        translate_hook = CloudTranslateHook(gcp_conn_id=self.gcp_conn_id)

        recognize_result = speech_to_text_hook.recognize_speech(
            config=self.config, audio=self.audio)
        recognize_dict = MessageToDict(recognize_result)

        self.log.info("Recognition operation finished")

        # MessageToDict omits empty repeated fields, so an empty response has
        # no 'results' key; .get() lets it reach the "no results" branch
        # instead of raising KeyError.
        if not recognize_dict.get('results'):
            self.log.info("No recognition results")
            return {}
        self.log.debug("Recognition result: %s", recognize_dict)

        try:
            transcript = recognize_dict['results'][0]['alternatives'][0][
                'transcript']
        except KeyError as key:
            raise AirflowException(
                "Wrong response '{}' returned - it should contain {} field".
                format(recognize_dict, key)) from key

        try:
            translation = translate_hook.translate(
                values=transcript,
                target_language=self.target_language,
                format_=self.format_,
                source_language=self.source_language,
                model=self.model)
        except ValueError as e:
            self.log.error(
                'An error has been thrown from translate speech method:')
            self.log.error(e)
            raise AirflowException(e) from e

        self.log.info('Translated output: %s', translation)
        return translation