Example #1
    def append(
        self,
        ids,
        predict,
        target,
        predict_len=None,
        target_len=None,
        ind2lab=None,
    ):
        """Add stats to the relevant containers.

        * See MetricStats.append()

        Arguments
        ---------
        ids : list
            List of ids corresponding to utterances.
        predict : torch.tensor
            A predicted output, for comparison with the target output.
        target : torch.tensor
            The correct reference output, for comparison with the prediction.
        predict_len : torch.tensor
            The predictions' relative lengths, used to undo padding if
            there is padding present in the predictions.
        target_len : torch.tensor
            The target outputs' relative lengths, used to undo padding if
            there is padding present in the target.
        ind2lab : callable
            Callable that maps from indices to labels, operating on batches,
            for writing alignments.
        """
        self.ids.extend(ids)

        if predict_len is not None:
            predict = undo_padding(predict, predict_len)

        if target_len is not None:
            target = undo_padding(target, target_len)

        if ind2lab is not None:
            predict = ind2lab(predict)
            target = ind2lab(target)

        if self.merge_tokens:
            predict = merge_char(predict)
            target = merge_char(target)

        if self.split_tokens:
            predict = split_word(predict)
            target = split_word(target)

        scores = edit_distance.wer_details_for_batch(
            ids, target, predict, compute_alignments=True
        )

        self.scores.extend(scores)
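
A minimal usage sketch, assuming this `append` is the one on SpeechBrain's `ErrorRateStats` (from `speechbrain.utils.metric_stats`); the utterance ids and word lists below are made up, and the inputs are passed as already-unpadded word lists so that `predict_len`, `target_len`, and `ind2lab` can be omitted:

    from speechbrain.utils.metric_stats import ErrorRateStats

    wer_stats = ErrorRateStats()

    # Hypothetical references and hypotheses, already word lists, so no
    # padding removal or index-to-label mapping is needed.
    wer_stats.append(
        ids=["utt1", "utt2"],
        predict=[["the", "cat", "sat"], ["hello", "world"]],
        target=[["the", "cat", "sat"], ["hello", "word"]],
    )

    summary = wer_stats.summarize()
    print(summary["WER"])  # aggregate word error rate over both utterances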
Example #2
    def _json2text(self):
        """Read JSON file and convert specific data entries into text file.
        """
        if not os.path.isfile(os.path.abspath(self.annotation_train)):
            raise ValueError(
                self.annotation_train +
                " is not a file. Please provide an annotation file for training.")
        logger.info("Extract " + self.annotation_read + " sequences from: " +
                    self.annotation_train)

        # Read JSON
        with open(self.annotation_train, "r") as f:
            out_json = json.load(f)

        # Save text file
        with open(self.text_file, "w+") as text_file:
            row_idx = 0
            for snt_id in out_json.keys():
                if (self.num_sequences is not None
                        and row_idx >= self.num_sequences):
                    logger.info("Using %d sequences to train the tokenizer."
                                % self.num_sequences)
                    break
                row_idx += 1
                sent = out_json[snt_id][self.annotation_read]
                if self.char_format_input:
                    (sent,) = merge_char([sent.split()])
                    sent = " ".join(sent)
                text_file.write(sent + "\n")

        logger.info("Text file created at: " + self.text_file)
Example #3
    def _csv2text(self):
        """Read CSV file and convert specific data entries into text file.
        """
        if not os.path.isfile(os.path.abspath(self.annotation_train)):
            raise ValueError(
                self.annotation_train +
                " is not a file. Please provide an annotation file for training.")
        logger.info("Extract " + self.annotation_read + " sequences from: " +
                    self.annotation_train)
        with open(self.annotation_train, "r") as annotation_file:
            reader = csv.reader(annotation_file)
            headers = next(reader, None)
            if self.annotation_read not in headers:
                raise ValueError(self.annotation_read + " must exist in: " +
                                 self.annotation_train)
            index_label = headers.index(self.annotation_read)
            with open(self.text_file, "w+") as text_file:
                row_idx = 0
                for row in reader:
                    if (self.num_sequences is not None
                            and row_idx >= self.num_sequences):
                        logger.info(
                            "Using %d sequences to train the tokenizer."
                            % self.num_sequences)
                        break
                    row_idx += 1
                    sent = row[index_label]
                    if self.char_format_input:
                        (sent,) = merge_char([sent.split()])
                        sent = " ".join(sent)
                    text_file.write(sent + "\n")
        logger.info("Text file created at: " + self.text_file)
Example #4
    def _check_coverage_from_bpe(self, list_csv_files=None):
        """Logs the accuracy of the BPE model in recovering words from the
        training text.

        Arguments
        ---------
        list_csv_files : list
            List of the csv files used for checking the accuracy of
            recovering words with the tokenizer.
        """
        if list_csv_files is None:
            list_csv_files = []
        for csv_file in list_csv_files:
            if os.path.isfile(os.path.abspath(csv_file)):
                logger.info(
                    "==== Accuracy checking for recovering text from tokenizer ==="
                )
                with open(csv_file, "r") as fcsv_file:
                    reader = csv.reader(fcsv_file)
                    headers = next(reader, None)
                    if self.csv_read not in headers:
                        raise ValueError(self.csv_read + " must exist in: " +
                                         csv_file)
                    index_label = headers.index(self.csv_read)
                    wrong_recover_list = []
                    for row in reader:
                        row = row[index_label]
                        if self.char_format_input:
                            (row,) = merge_char([row.split()])
                            row = " ".join(row)
                        row = row.split("\n")[0]
                        encoded_id = self.sp.encode_as_ids(row)
                        decode_text = self.sp.decode_ids(encoded_id)
                        (details,) = edit_distance.wer_details_for_batch(
                            ["utt1"],
                            [row.split(" ")],
                            [decode_text.split(" ")],
                            compute_alignments=True,
                        )
                        if details["WER"] > 0:
                            for align in details["alignment"]:
                                if (align[0] != "="
                                        and align[1] is not None
                                        and align[1] not in wrong_recover_list):
                                    wrong_recover_list.append(align[1])
                logger.info("recover words from: " + csv_file)
                if len(wrong_recover_list) > 0:
                    logger.warn("Wrong recover words: " +
                                str(len(wrong_recover_list)))
                    logger.warn("Tokenizer vocab size: " +
                                str(self.sp.vocab_size()))
                    logger.warn("accuracy recovering words: " +
                                str(1 - float(len(wrong_recover_list)) /
                                    self.sp.vocab_size()))
                else:
                    logger.info("Wrong recover words: 0")
                    logger.warning("accuracy recovering words: " + str(1.0))
            else:
                logger.info("No accuracy recover checking for" + csv_file)
Example #5
    def __call__(
        self,
        batch,
        batch_lens=None,
        ind2lab=None,
        task="encode",
    ):
        """This __call__ function implements the tokenizer encoder and decoder
        (restoring the string of word) for BPE, Regularized BPE (with unigram),
        and char (speechbrain/nnet/RNN.py).
        Arguments
        ----------
        batch : tensor.IntTensor or list
            List if ( batch_lens = None and task = "decode_from_list")
            Contains the original labels. Shape: [batch_size, max_length]
        batch_lens : tensor.LongTensor
            Containing the relative length of each label sequences. Must be 1D
            tensor of shape: [batch_size]. (default: None)
        ind2lab : dict
            Dictionary which maps the index from label sequences
            (batch tensor) to string label.
        task : str
            ("encode", "decode", "decode_from_list)
            "encode": convert the batch tensor into sequence of tokens.
                the output contain a list of (tokens_seq, tokens_lens)
            "decode": convert a tensor of tokens to a list of word sequences.
            "decode_from_list": convert a list of token sequences to a list
                of word sequences.
        """
        if task == "encode" and ind2lab is None:
            raise ValueError(
                "Tokenizer encoder must have the ind2lab function")

        if task == "encode":
            # Convert list of words/chars to bpe ids
            bpe = []
            max_bpe_len = 0
            batch_lens = (batch_lens * batch.shape[1]).int()
            for i, utt_seq in enumerate(batch):
                tokens = [
                    ind2lab[int(index)] for index in utt_seq[:batch_lens[i]]
                ]
                if self.char_format_input:
                    (words_list, ) = merge_char([tokens])
                    sent = " ".join(words_list)
                else:
                    sent = " ".join(tokens)
                bpe_encode = self.sp.encode_as_ids(sent)
                bpe.append(bpe_encode)
                # Save the longest bpe sequence length;
                # it helps to compute the relative length of each utterance.
                if len(bpe_encode) > max_bpe_len:
                    max_bpe_len = len(bpe_encode)
            # Create bpe tensor
            bpe_tensor = torch.zeros((batch.shape[0], max_bpe_len),
                                     device=batch.device)
            bpe_lens = torch.zeros((batch.shape[0]), device=batch.device)
            for i, bpe_utt in enumerate(bpe):
                bpe_tensor[i, :len(bpe_utt)] = torch.Tensor(bpe_utt)
                bpe_lens[i] = len(bpe_utt) / max_bpe_len
            return bpe_tensor, bpe_lens
        elif task == "decode_from_list":
            # From list of hyps (not padded outputs)
            # do decoding
            return [
                self.sp.decode_ids(utt_seq).split(" ") for utt_seq in batch
            ]
        elif task == "decode":
            # From a batch tensor and a length tensor
            # find the absolute batch lengths and do decoding
            batch_lens = (batch_lens * batch.shape[1]).int()
            return [
                self.sp.decode_ids(
                    utt_seq[:batch_lens[i]].int().tolist()).split(" ")
                for i, utt_seq in enumerate(batch)
            ]
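
A usage sketch for the encode and decode paths, assuming `tokenizer` is an already-constructed instance of the wrapper class this `__call__` belongs to (e.g. SpeechBrain's `SentencePiece` wrapper, with `char_format_input=False`); the label indices and `ind2lab` mapping are made up:

    import torch

    # Hypothetical index-to-label mapping and a padded batch of label indices.
    ind2lab = {0: "hello", 1: "world", 2: "good", 3: "morning"}
    batch = torch.IntTensor([[0, 1, 0, 0], [2, 3, 0, 1]])
    batch_lens = torch.Tensor([0.5, 1.0])  # utt1 has 2 valid labels, utt2 has 4

    # Encode: label indices -> padded token-id tensor plus relative lengths.
    tokens, token_lens = tokenizer(batch, batch_lens, ind2lab, task="encode")

    # Decode: token-id tensor plus relative lengths -> lists of words.
    words = tokenizer(tokens, token_lens, task="decode")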