Example No. 1
def run_transcribe(audio_path: str, spect_parser: SpectrogramParser,
                   model: DeepSpeech, decoder: Decoder, device: torch.device,
                   use_half: bool):
    spect = spect_parser.parse_audio(audio_path).contiguous()
    spect = spect.view(1, 1, spect.size(0), spect.size(1))
    spect = spect.to(device)
    if use_half:
        spect = spect.half()
    input_sizes = torch.IntTensor([spect.size(3)])
    out, output_sizes = model(spect, input_sizes)
    decoded_output, decoded_offsets = decoder.decode(out, output_sizes)

    # Additionally decode with a greedy decoder for comparison
    decoder2 = GreedyDecoder(labels=model.labels,
                             blank_index=model.labels.index('_'))
    decoded_output2, decoded_offsets2 = decoder2.decode(out, output_sizes)

    return decoded_output, decoded_output2, decoded_offsets, decoded_offsets2
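A minimal sketch of how run_transcribe might be driven end to end. The load_model and load_decoder helpers, the checkpoint path, and the SpectrogramParser constructor arguments are assumptions for illustration; deepspeech.pytorch ships similar utilities, but their exact names and signatures vary by version.

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = load_model(device, 'models/deepspeech_final.pth')     # assumed helper
decoder = load_decoder(labels=model.labels)                   # assumed helper
spect_parser = SpectrogramParser(audio_conf=model.spect_cfg,  # assumed ctor
                                 normalize=True)

beam_out, greedy_out, _, _ = run_transcribe(audio_path='sample.wav',
                                            spect_parser=spect_parser,
                                            model=model,
                                            decoder=decoder,
                                            device=device,
                                            use_half=False)
print('beam  :', beam_out[0][0])    # top beam-search hypothesis
print('greedy:', greedy_out[0][0])  # greedy hypothesis for comparison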
Example No. 2
class DeepSpeech(pl.LightningModule):
    def __init__(self, labels: List, model_cfg: Union[UniDirectionalConfig,
                                                      BiDirectionalConfig],
                 precision: int, optim_cfg: Union[AdamConfig, SGDConfig],
                 spect_cfg: SpectConfig):
        super().__init__()
        self.save_hyperparameters()
        self.model_cfg = model_cfg
        self.precision = precision
        self.optim_cfg = optim_cfg
        self.spect_cfg = spect_cfg
        self.bidirectional = OmegaConf.get_type(
            model_cfg) is BiDirectionalConfig

        self.labels = labels
        num_classes = len(self.labels)

        self.conv = MaskConv(
            nn.Sequential(
                nn.Conv2d(1,
                          32,
                          kernel_size=(41, 11),
                          stride=(2, 2),
                          padding=(20, 5)), nn.BatchNorm2d(32),
                nn.Hardtanh(0, 20, inplace=True),
                nn.Conv2d(32,
                          32,
                          kernel_size=(21, 11),
                          stride=(2, 1),
                          padding=(10, 5)), nn.BatchNorm2d(32),
                nn.Hardtanh(0, 20, inplace=True)))
        # Based on the convolutions above and the spectrogram size,
        # using the conv output formula: (W - F + 2P) / S + 1
        rnn_input_size = int(
            math.floor((self.spect_cfg.sample_rate *
                        self.spect_cfg.window_size) / 2) + 1)
        rnn_input_size = (rnn_input_size + 2 * 20 - 41) // 2 + 1
        rnn_input_size = (rnn_input_size + 2 * 10 - 21) // 2 + 1
        rnn_input_size *= 32
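        # Worked example (assuming sample_rate=16000 and window_size=0.02):
        #   frequency bins: 16000 * 0.02 / 2 + 1 = 161
        #   after conv1:    (161 + 2*20 - 41) // 2 + 1 = 81
        #   after conv2:    (81 + 2*10 - 21) // 2 + 1 = 41
        #   flattened over 32 channels: 41 * 32 = 1312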

        self.rnns = nn.Sequential(
            BatchRNN(input_size=rnn_input_size,
                     hidden_size=self.model_cfg.hidden_size,
                     rnn_type=self.model_cfg.rnn_type.value,
                     bidirectional=self.bidirectional,
                     batch_norm=False),
            *(BatchRNN(input_size=self.model_cfg.hidden_size,
                       hidden_size=self.model_cfg.hidden_size,
                       rnn_type=self.model_cfg.rnn_type.value,
                       bidirectional=self.bidirectional)
              for _ in range(self.model_cfg.hidden_layers - 1)))

        self.lookahead = nn.Sequential(
            # consider adding batch norm?
            Lookahead(self.model_cfg.hidden_size,
                      context=self.model_cfg.lookahead_context),
            nn.Hardtanh(0, 20,
                        inplace=True)) if not self.bidirectional else None

        fully_connected = nn.Sequential(
            nn.BatchNorm1d(self.model_cfg.hidden_size),
            nn.Linear(self.model_cfg.hidden_size, num_classes, bias=False))
        self.fc = nn.Sequential(SequenceWise(fully_connected))
        self.inference_softmax = InferenceBatchSoftmax()
        self.criterion = CTCLoss(blank=self.labels.index('_'),
                                 reduction='sum',
                                 zero_infinity=True)
        self.evaluation_decoder = GreedyDecoder(
            self.labels)  # Decoder used for validation
        self.wer = WordErrorRate(decoder=self.evaluation_decoder,
                                 target_decoder=self.evaluation_decoder)
        self.cer = CharErrorRate(decoder=self.evaluation_decoder,
                                 target_decoder=self.evaluation_decoder)

    def forward(self, x, lengths):
        lengths = lengths.cpu().int()
        output_lengths = self.get_seq_lens(lengths)
        x, _ = self.conv(x, output_lengths)

        sizes = x.size()
        x = x.view(sizes[0], sizes[1] * sizes[2],
                   sizes[3])  # Collapse feature dimension
        x = x.transpose(1, 2).transpose(0, 1).contiguous()  # TxNxH

        for rnn in self.rnns:
            x = rnn(x, output_lengths)

        if not self.bidirectional:  # no need for lookahead layer in bidirectional
            x = self.lookahead(x)

        x = self.fc(x)
        x = x.transpose(0, 1)
        # identity in training mode, softmax in eval mode
        x = self.inference_softmax(x)
        return x, output_lengths

    def training_step(self, batch, batch_idx):
        inputs, targets, input_percentages, target_sizes = batch
        input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
        out, output_sizes = self(inputs, input_sizes)
        out = out.transpose(0, 1)  # TxNxH
        out = out.log_softmax(-1)
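        # CTCLoss expects log-probabilities in (T, N, C) layout; log_softmax
        # is applied manually here because InferenceBatchSoftmax is the
        # identity while the model is in training mode.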

        loss = self.criterion(out, targets, output_sizes, target_sizes)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets, input_percentages, target_sizes = batch
        input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
        inputs = inputs.to(self.device)
        with autocast(enabled=self.precision == 16):
            out, output_sizes = self(inputs, input_sizes)
        decoded_output, _ = self.evaluation_decoder.decode(out, output_sizes)
        self.wer(preds=out,
                 preds_sizes=output_sizes,
                 targets=targets,
                 target_sizes=target_sizes)
        self.cer(preds=out,
                 preds_sizes=output_sizes,
                 targets=targets,
                 target_sizes=target_sizes)
        self.log('wer', self.wer.compute(), prog_bar=True, on_epoch=True)
        self.log('cer', self.cer.compute(), prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        if OmegaConf.get_type(self.optim_cfg) is SGDConfig:
            optimizer = torch.optim.SGD(
                params=self.parameters(),
                lr=self.optim_cfg.learning_rate,
                momentum=self.optim_cfg.momentum,
                nesterov=True,
                weight_decay=self.optim_cfg.weight_decay)
        elif OmegaConf.get_type(self.optim_cfg) is AdamConfig:
            optimizer = torch.optim.AdamW(
                params=self.parameters(),
                lr=self.optim_cfg.learning_rate,
                betas=self.optim_cfg.betas,
                eps=self.optim_cfg.eps,
                weight_decay=self.optim_cfg.weight_decay)
        else:
            raise ValueError("Optimizer has not been specified correctly.")

        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer=optimizer, gamma=self.optim_cfg.learning_anneal)
        return [optimizer], [scheduler]

    def get_seq_lens(self, input_length):
        """
        Given a 1D Tensor or Variable containing integer sequence lengths, return a 1D tensor or variable
        containing the size sequences that will be output by the network.
        :param input_length: 1D Tensor
        :return: 1D Tensor scaled by model
        """
        seq_len = input_length
        for m in self.conv.modules():
            if isinstance(m, nn.modules.conv.Conv2d):
                seq_len = ((seq_len + 2 * m.padding[1] - m.dilation[1] *
                            (m.kernel_size[1] - 1) - 1) // m.stride[1] + 1)
        return seq_len.int()
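As a quick sanity check of get_seq_lens: the time axis is the second entry of each kernel/stride/padding tuple, so only the first convolution (stride 2 in time) shrinks the sequence, while the second (stride 1) preserves it. A standalone sketch of the same arithmetic, with an illustrative helper name:

def expected_output_length(t: int) -> int:
    # conv1 along time: kernel 11, stride 2, padding 5
    t = (t + 2 * 5 - 11) // 2 + 1
    # conv2 along time: kernel 11, stride 1, padding 5 (length preserved)
    t = (t + 2 * 5 - 11) // 1 + 1
    return t

print(expected_output_length(1000))  # 500: roughly halved by the first conv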
Example No. 3
@torch.no_grad()  # inference only; gradients are not needed during evaluation
def run_evaluation(test_loader,
                   device,
                   model,
                   decoder,
                   target_decoder,
                   save_output=None,
                   verbose=False,
                   use_half=False):
    model.eval()
    total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
    total_cer2, total_wer2, num_tokens2, num_chars2 = 0, 0, 0, 0

    output_data = []
    for i, data in tqdm(enumerate(test_loader), total=len(test_loader)):
        inputs, targets, input_percentages, target_sizes = data
        # input_sizes: spectrogram width of each of the 32 samples in the batch
        input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
        inputs = inputs.to(device)
        if use_half:
            inputs = inputs.half()  # makes little difference in practice
        # unflatten targets
        split_targets = []
        offset = 0
        for size in target_sizes:
            split_targets.append(targets[offset:offset + size])
            offset += size

        out, output_sizes = model(inputs, input_sizes)

        decoded_output, _ = decoder.decode(out, output_sizes)
        target_strings = target_decoder.convert_to_strings(split_targets)

        if save_output is not None:
            # add output to data array, and continue
            output_data.append((out.cpu(), output_sizes, target_strings))

        # Decode the same outputs with a greedy decoder for comparison
        # (built here per batch for clarity; it could be constructed once
        # outside the loop)
        decoder2 = GreedyDecoder(labels=model.labels,
                                 blank_index=model.labels.index('_'))
        old_out, _ = decoder2.decode(out, output_sizes)
        for x in range(len(target_strings)):
            transcript, reference = decoded_output[x][0], target_strings[x][0]
            wer_inst = decoder.wer(transcript, reference)
            cer_inst = decoder.cer(transcript, reference)
            total_wer += wer_inst
            total_cer += cer_inst
            num_tokens += len(reference.split())
            num_chars += len(reference.replace(' ', ''))
            if verbose:
                print("TRUTH :", reference.lower())
                print("Beam  :", transcript.lower())
                print("WER:",
                      float(wer_inst) / len(reference.split()), "CER:",
                      float(cer_inst) / len(reference.replace(' ', '')))

            transcript2 = old_out[x][0]
            wer_inst2 = decoder2.wer(transcript2, reference)
            cer_inst2 = decoder2.cer(transcript2, reference)
            total_wer2 += wer_inst2
            total_cer2 += cer_inst2
            num_tokens2 += len(reference.split())
            num_chars2 += len(reference.replace(' ', ''))
            if verbose:
                print("Greedy:", transcript2.lower())
                print("WER2:",
                      float(wer_inst2) / len(reference.split()), "CER2:",
                      float(cer_inst2) / len(reference.replace(' ', '')), "\n")
    wer = float(total_wer) / num_tokens
    cer = float(total_cer) / num_chars
    wer2 = float(total_wer2) / num_tokens2
    cer2 = float(total_cer2) / num_chars2

    return wer * 100, cer * 100, output_data, wer2 * 100, cer2 * 100
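Note that the rates above are corpus-level: per-utterance edit distances are summed before dividing by the total reference length, which weights long utterances more heavily than averaging per-utterance rates would. A minimal self-contained sketch of the same aggregation, using a plain Levenshtein distance in place of decoder.wer (an assumption; the project's decoder may normalize differently):

def edit_distance(ref, hyp):
    # standard Levenshtein distance over token sequences
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (r != h)))   # substitution
        prev = cur
    return prev[-1]

def corpus_wer(pairs):
    # pairs: iterable of (reference, hypothesis) transcript strings
    errors = sum(edit_distance(ref.split(), hyp.split()) for ref, hyp in pairs)
    tokens = sum(len(ref.split()) for ref, _ in pairs)
    return errors / tokens

print(corpus_wer([('the cat sat', 'the cat sat'),
                  ('hello world', 'hello word')]))  # 1 error / 5 tokens = 0.2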