Esempio n. 1
0
    def make_transform_values(transcript: Any) -> List[str]:
        """
        Make transcripts from a string/json-string.

        If ``transcript`` decodes as JSON, the decoded value is normalized.
        Otherwise the input is normalized as-is. This covers invalid JSON and
        non-str/bytes inputs, which ``json.loads`` rejects with ``TypeError``.

        :param transcript: A plain string, a JSON-encoded string, or an
            already-decoded value (e.g. a list of transcripts).
        :type transcript: Any
        :return: List of transcripts.
        :rtype: List[str]
        :raises TypeError: propagated from ``normalize`` when the (decoded)
            value has an unsupported shape.
        """
        try:
            decoded = json.loads(transcript)
        except (json.JSONDecodeError, TypeError):
            # Not JSON text (or not text at all): fall back to the raw input.
            decoded = transcript
        # Normalize exactly once. A TypeError raised here is no longer caught
        # and pointlessly retried with the very same value.
        return normalize(decoded)
Esempio n. 2
0
def merge_asr_output(utterances: Any) -> List[str]:
    """
    .. _merge_asr_output:

    Merge n-best ASR transcripts into one tagged string.

    The ASR output is first normalized by :ref:`normalize<normalize>`. Each
    resulting transcript is then wrapped so that:

    - its start is marked by ``<s>`` and
    - its end is marked by ``</s>``,

    and the wrapped transcripts are joined by single spaces, e.g.
    ``"<s> a </s> <s> b </s>"``.

    When the input consists of dicts, the key "transcript" is expected and its
    value is what gets merged.

    .. ipython:: python

        from dialogy.plugins.text.merge_asr_output import merge_asr_output

        utterances = ["This is a sentence", "That is a sentence"]
        merge_asr_output(utterances)

    :param utterances: A structure representing ASR output. We support only:

        1. :code:`List[str]`
        2. :code:`List[List[str]]`
        3. :code:`List[List[Dict[str, str]]]`
        4. :code:`List[Dict[str, str]]`

    :type utterances: Any
    :return: A one-element list holding the merged string. Returns an empty
        list when normalization yields no transcripts. It also returns an
        empty list when normalization yields exactly one transcript and that
        transcript equals the lowercased form of an entry in
        ``const.INVALID_TOKENS``.
    :rtype: List[str]
    :raises: TypeError (chained from the original error) whenever a TypeError
        occurs while normalizing or merging. An example is a missing
        "transcript" key in the :code:`List[List[Dict[str, str]]]` or
        :code:`List[Dict[str, str]]` cases.
    """
    try:
        normalized: List[str] = normalize(utterances)
        # A lone transcript that is just an invalid token carries no content.
        lone_invalid_token = len(normalized) == 1 and any(
            token.lower() in normalized for token in const.INVALID_TOKENS
        )
        if lone_invalid_token or not normalized:
            return []
        body = " </s> <s> ".join(normalized)
        return ["<s> " + body + " </s>"]
    except TypeError as type_error:
        raise TypeError("`transcript` is expected in the ASR output.") from type_error
Esempio n. 3
0
    def inference(
        self, transcripts: List[str], utterances: List[Utterance]
    ) -> List[str]:
        """
        Filter the ASR alternatives when the transcripts are long enough.

        The mean whitespace-separated word count across ``transcripts`` is
        computed (0.0 when there are none).

        - If it does not exceed ``const.WORD_THRESHOLD``, ``transcripts`` is
          returned unchanged.
        - Otherwise ``utterances`` are passed through
          ``self.filter_asr_output`` and the result is normalized.

        :param transcripts: Transcripts used only to measure length.
        :param utterances: Raw ASR output to filter.
        :return: Either the original transcripts or the filtered, normalized ones.
        """
        word_counts: List[int] = [len(text.split()) for text in transcripts]
        if transcripts:
            mean_words: float = sum(word_counts) / len(word_counts)
        else:
            mean_words = 0.0

        # Per the original author: below WORD_THRESHOLD words the WER is mostly
        # high, so filtering is only applied to longer inputs. Any fallback-label
        # prediction happens outside this method; nothing here sets a label.
        if mean_words > const.WORD_THRESHOLD:
            return normalize(self.filter_asr_output(utterances))
        return transcripts
Esempio n. 4
0
    def transform(self, training_data: pd.DataFrame) -> pd.DataFrame:
        """
        Write masked transcripts for every row into ``self.output_column``.

        Does nothing when ``self.use_transform`` is falsy. Otherwise, for each
        row:

        1. The JSON in ``self.input_column`` is decoded and normalized.
        2. The result is passed, together with the row's
           ``self.entity_column`` value, to ``self.mask_transcript``.
        3. That call's result is stored JSON-encoded in
           ``self.output_column``.

        A row that fails for any reason is logged with its traceback and
        skipped. The frame is modified in place and also returned.

        NOTE(review): assignment goes through ``.loc[index, ...]``, i.e. by
        index *label*. With duplicate labels, every row sharing the label is
        overwritten. Confirm the training data has a unique index.
        """
        if self.use_transform:
            logger.debug(f"Transforming dataset via {self.__class__.__name__}")
            progress = tqdm(training_data.iterrows(), total=len(training_data))
            for index, row in progress:
                try:
                    entities = row[self.entity_column]
                    transcripts = normalize(json.loads(row[self.input_column]))
                    masked = self.mask_transcript(entities, transcripts)
                    training_data.loc[index, self.output_column] = json.dumps(masked)
                except Exception as error:  # pylint: disable=broad-except
                    # Best-effort per row: one bad record must not abort the dataset.
                    logger.error(
                        f"{error} -- {row[self.input_column]}\n{traceback.format_exc()}"
                    )
        return training_data
def test_cant_normalize_utterance(utterance: Any) -> None:
    """An utterance of an unsupported shape makes ``normalize`` raise TypeError."""
    with pytest.raises(TypeError):
        normalize(utterance)
def test_normalize_utterance(utterance: Any, expected: List[str]) -> None:
    """``normalize`` maps a supported utterance shape to the expected transcripts."""
    assert normalize(utterance) == expected
Esempio n. 7
0
 def __attrs_post_init__(self) -> None:
     """
     Populate ``transcripts`` from ``utterances`` after initialization.

     The value is assigned via ``object.__setattr__``, which bypasses the
     class's own ``__setattr__``. Presumably this is needed because the class
     is frozen; confirm against the class decorator.

     This is best-effort: if ``normalize`` rejects ``utterances`` with
     ``TypeError``, ``transcripts`` is left at its prior value.
     """
     try:
         transcripts = normalize(self.utterances)
     except TypeError:
         # Deliberate: unsupported utterance shapes are tolerated here rather
         # than failing object construction.
         return
     object.__setattr__(self, "transcripts", transcripts)