Example #1
class Predictor(object):
    def __init__(self):
        super().__init__()
        self.labels = [' '] + PHONEME_MAP
        self.decoder = CTCBeamDecoder(labels=self.labels,
                                      beam_width=100,
                                      blank_id=0,
                                      num_processes=32)

    def __call__(self, logits, labels, label_lens):
        return self.evaluateError(logits, labels, label_lens)

    def evaluateError(self, logits, labels, label_lens):
        logits = torch.transpose(logits, 0, 1).cpu()
        probs = F.softmax(logits, dim=2)
        output, scores, timesteps, out_seq_len = self.decoder.decode(probs=probs)

        pos, ls = 0, 0.
        for i in range(output.size(0)):
            pred = "".join(self.labels[o]
                           for o in output[i, 0, :out_seq_len[i, 0]])
            true = "".join(self.labels[l]
                           for l in labels[pos:pos + label_lens[i]])
            pos += label_lens[i]
            ls += L.distance(pred, true)

        assert pos == labels.size(0)
        return ls

    def predict(self, logits):
        logits = torch.transpose(logits, 0, 1).cpu()
        probs = F.softmax(logits, dim=2)
        parse_res = self.decoder.decode(probs=probs)
        output, scores, timesteps, out_seq_len = parse_res

        result = []
        for i in range(output.size(0)):
            pred = "".join(self.labels[o]
                           for o in output[i, 0, :out_seq_len[i, 0]])
            result.append(pred)

        return result
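A minimal driver for the class above might look like the following sketch; the tensor shapes and the PHONEME_MAP contents are placeholders, not values from the original project:

import torch

predictor = Predictor()
# (T, B, C) logits straight from a CTC model; predict() transposes them to (B, T, C).
logits = torch.randn(200, 4, len(predictor.labels))
transcripts = predictor.predict(logits)  # one decoded string per batch item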
Example #2
 def decode_beamsearch(self, preds):
     texts = []
     preds = preds.softmax(2)
     decoder = CTCBeamDecoder(
         self.character,
         model_path=None,
         alpha=0,
         beta=0,
         cutoff_top_n=10,
         cutoff_prob=1.0,
         beam_width=4,
         num_processes=16,
         blank_id=0,
         log_probs_input=False
     )
     beam_results, beam_scores, timesteps, out_lens = decoder.decode(preds)
     for i in range(preds.shape[0]):
         seq = "".join(self.character[n] for n in beam_results[i][0][:out_lens[i][0]])
         texts.append(seq)

     return texts
Example #3
class BeamDecoder(object):
    def __init__(self,
                 labels,
                 lm_path,
                 alpha=0.8,
                 beta=0.3,
                 cutoff_top_n=40,
                 cutoff_prob=1.0,
                 beam_width=32,
                 num_processes=4,
                 blank_index=0):
        from ctcdecode import CTCBeamDecoder
        self.decoder = CTCBeamDecoder(
            labels,
            lm_path,
            alpha,
            beta,
            cutoff_top_n,
            cutoff_prob,
            beam_width,
            num_processes,
            blank_index,
        )
        self.labels = labels

    def decode(self, probs, sizes=None):
        out, score, offset, outlen = self.decoder.decode(probs, sizes)
        out_, out_len = out[0][0], outlen[0][0]
        return "".join([self.labels[x] for x in out_[0:out_len]]), -1
Example #4
class ER:
    def __init__(self):
        self.label_map = [' '] + DIGITS_MAP
        self.decoder = CTCBeamDecoder(labels=self.label_map, blank_id=0)

    def __call__(self, prediction, target):
        return self.forward(prediction, target)

    def forward(self, prediction, target):
        logits = prediction[0]
        feature_lengths = prediction[1].int()
        labels = target + 1
        logits = torch.transpose(logits, 0, 1)
        logits = logits.cpu()
        probs = F.softmax(logits, dim=2)
        output, scores, timesteps, out_seq_len = self.decoder.decode(
            probs=probs, seq_lens=feature_lengths)

        pos = 0
        ls = 0.
        for i in range(output.size(0)):
            pred = "".join(self.label_map[o]
                           for o in output[i, 0, :out_seq_len[i, 0]])
            true = "".join(self.label_map[l] for l in labels[pos:pos + 10])
            #print("Pred: {}, True: {}".format(pred, true))
            pos += 10
            ls += L.distance(pred, true)
        assert pos == labels.size(0)
        return ls / output.size(0)
Example #5
def validate(model, dev_loader):
    decoder = CTCBeamDecoder(['$'] * 47, beam_width=100, log_probs_input=True)
    with torch.no_grad():
        model.eval()
        model.cuda()
        count = 0
        dist_sum = 0
        for batch_idx, lst in enumerate(dev_loader):
            X, X_lens, Y, Y_lens = process_train_lst(lst)
            out, out_lens = model(X, X_lens)
            val_Y, _, _, val_Y_lens = decoder.decode(out.transpose(0, 1),
                                                     out_lens)
            this_batch_size = val_Y.shape[0]

            predicted_list = [
                val_Y[i, 0, :val_Y_lens[i, 0]] for i in range(this_batch_size)
            ]
            ground_truth_list = [
                Y[i, 0:Y_lens[i]] for i in range(this_batch_size)
            ]
            ground_truth_phoneme_list = convert_to_phoneme(ground_truth_list)
            predicted_phoneme_list = convert_to_phoneme(predicted_list)

            for i in range(len(predicted_list)):
                count += 1
                cur_predicted_str = "".join(predicted_phoneme_list[i])
                cur_label_str = "".join(ground_truth_phoneme_list[i])
                cur_dist = Levenshtein.distance(cur_predicted_str,
                                                cur_label_str)
                dist_sum += cur_dist
            print(f"Batch: {batch_idx} | Avg Distance: {dist_sum / count}")
        print("Dev Avg Distance: {:.4f}".format(dist_sum / count))
Example #6
class CTCDecoder(object):
    def __init__(self, args, tgt_dict):
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = args.nbest
        self.beam = args.beam
        self.blank = (tgt_dict.index("<ctc_blank>")
                      if "<ctc_blank>" in tgt_dict.indices else tgt_dict.bos())
        self.decode_fn = CTCBeamDecoder(tgt_dict.symbols,
                                        beam_width=self.beam,
                                        blank_id=self.blank,
                                        num_processes=10)

    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v
            for k, v in sample["net_input"].items()
            if k != "prev_output_tokens"
        }
        emissions, seq_lens = self.get_emissions(models, encoder_input)

        return self.decode(emissions, seq_lens)

    def get_emissions(self, models, encoder_input):
        """Run encoder and normalize emissions"""
        # encoder_out = models[0].encoder(**encoder_input)
        encoder_out = models[0](**encoder_input)
        emissions = models[0].get_normalized_probs(encoder_out,
                                                   log_probs=False)
        seq_lens = (~encoder_out['encoder_padding_mask']).sum(-1)

        return emissions.transpose(0, 1), seq_lens

    def get_tokens(self, idxs):
        """Normalize tokens by handling CTC blank, ASG replabels, etc."""
        idxs = (g[0] for g in it.groupby(idxs))
        idxs = filter(lambda x: x != self.blank, idxs)

        return torch.LongTensor(list(idxs))

    def decode(self, emissions, seq_lens):
        hypos = []

        beam_results, beam_scores, timesteps, out_seq_len = self.decode_fn.decode(
            emissions, seq_lens)
        for beam_result, scores, lengths in zip(beam_results, beam_scores,
                                                out_seq_len):
            # beam_ids: beam x id; score: beam; length: beam
            top = []
            for result, score, length in zip(beam_result, scores, lengths):
                top.append({
                    'tokens': self.get_tokens(result[:length]),
                    "score": score
                })
            hypos.append(top)

        return hypos
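The hypotheses returned by generate() mimic fairseq's output format; a hedged way to read the best hypothesis back into text, assuming tgt_dict is a fairseq Dictionary and ctc_decoder is an instance of the class above:

hypos = ctc_decoder.generate(models, sample)
for sample_hypos in hypos:
    best = sample_hypos[0]                  # ctcdecode returns beams best-first
    text = tgt_dict.string(best["tokens"])  # Dictionary.string() turns IDs into text
    print(text, best["score"])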
Example #7
class Levenshtein:
    def __init__(self, charmap):
        self.label_map = [' '] + charmap  # add blank to first entry
        self.decoder = CTCBeamDecoder(
            labels=self.label_map,
            blank_id=0,
            beam_width=100
        )

    def __call__(self, prediction, target, feature_lengths):
        return self.forward(prediction, target, feature_lengths)

    def forward(self, prediction, target, feature_lengths):
        feature_lengths = torch.Tensor(feature_lengths)
        prediction = torch.transpose(prediction, 0, 1)
        prediction = prediction.cpu()
        probs = F.softmax(prediction, dim=2)
        output, scores, timesteps, out_seq_len = self.decoder.decode(probs=probs, seq_lens=feature_lengths)

        ls = 0.
        for i in range(output.size(0)):
            pred = "".join(self.label_map[o] for o in output[i, 0, :out_seq_len[i, 0]])
            true = "".join(self.label_map[l] for l in target[i].numpy())
            # print("Pred: {}, True: {}".format(pred, true))
            ls += L.distance(pred, true)
        return ls
Example #8
def pred_model(model, test_loader):
    with torch.no_grad():
        model.eval()

        predLabel = []

        phonemes = [" "] + PHONEME_MAP
        decoder = CTCBeamDecoder(phonemes,
                                 beam_width=10,
                                 log_probs_input=True)

        for batch_idx, (padinp, xlens) in enumerate(test_loader):
            padinp = padinp.to(device)

            batchlabel = []

            out, out_lens = model(padinp, xlens)
            out_lens = torch.LongTensor(out_lens)

            pred, _, _, pred_lens = decoder.decode(out.transpose(0, 1),
                                                   out_lens)

            for i in range(len(pred)):
                seq = ""
                for j in range(pred_lens[i, 0]):
                    seq += phonemes[int(pred[i, 0, j])]

                batchlabel.append(seq)

            predLabel = predLabel + batchlabel

    return predLabel
Example #9
class BeamCTCDecoder(Decoder):
    def __init__(self,
                 alphabet,
                 lm_path=None,
                 alpha=0,
                 beta=0,
                 cutoff_top_n=40,
                 cutoff_prob=1.0,
                 beam_width=100,
                 num_processes=4):
        super().__init__(alphabet)

        try:
            from ctcdecode import CTCBeamDecoder
        except ImportError:
            raise ImportError("BeamCTCDecoder requires ctcdecode package.")

        self._decoder = CTCBeamDecoder(alphabet.tokens,
                                       lm_path,
                                       alpha,
                                       beta,
                                       cutoff_top_n,
                                       cutoff_prob,
                                       beam_width,
                                       num_processes,
                                       alphabet.blank_index,
                                       log_probs_input=True)

    def decode(self, log_probs, sizes=None):
        """
        Given a matrix of character probabilities, returns the decoder's
        best guess of the transcription

        Arguments:
            log_probs (tensor): Tensor of log probabilities with shape (B, T, L),
                where `log_probs[b, t, l]` is the log probability of label `l` at time `t`
                in batch `b`
            sizes (optional): Size of each sequence in the batch
        Returns:
            decoded (list of string): sequence of the model's best guess for the transcription
            scores (tensor): tensor of size B with the negative log probability of each best hypothesis
            offsets (tensor): time-step offset of each predicted character
        """
        log_probs = log_probs.cpu()

        out, scores, offsets, seq_lens = self._decoder.decode(log_probs, sizes)

        strings = self.tensor2str(out[:, 0, :], seq_lens[:, 0])

        scores = scores[:, 0]
        offsets = offsets[:, 0]

        return strings, scores, offsets

    def reset_params(self, alpha, beta):
        self._decoder.reset_params(alpha, beta)
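Since this decoder is constructed with log_probs_input=True, it expects log-probabilities. A hedged call site, where model_output and alphabet are placeholders for project-specific objects:

import torch.nn.functional as F

decoder = BeamCTCDecoder(alphabet)                    # alphabet: project-specific object with .tokens and .blank_index
log_probs = F.log_softmax(model_output, dim=-1)       # (B, T, L) acoustic model output
strings, scores, offsets = decoder.decode(log_probs)  # decode() moves the tensor to CPU itself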
Example #10
def recognize(image_path, model, label_dict, device):
    img = Image.open(image_path).convert("RGB")
    tgt_height = 64

    width, height = img.size
    reshape_width = tgt_height * (width / height)
    img = img.resize([int(reshape_width), int(tgt_height)])
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    img = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(img)

    _, ind2ch = get_label_dict(label_dict)


    labels = list(ind2ch.values())
    replace_label = {
        'UNK': '_',
        'SOS': '_',
        'EOS': '_',
        'SPACE': ' ',
        'BLANK': '_'
    }
    labels = ''.join(
        [replace_label[l] if l in replace_label else l for l in labels])
    decoder = CTCBeamDecoder(labels,
                             model_path=None,
                             alpha=0,
                             beta=0,
                             cutoff_top_n=40,
                             cutoff_prob=1.0,
                             beam_width=20,
                             num_processes=8,
                             blank_id=98,
                             log_probs_input=True)
    output = output.permute(1, 0, 2)
    beam_results, beam_scores, timesteps, out_lens = decoder.decode(output)
    results = beam_results[0][0][:out_lens[0][0]].cpu().tolist()

    pred = ''
    for ch in results:
        ch = ind2ch[ch]
        if ch in ['UNK', 'SOS', 'EOS', 'BLANK']:
            continue
        elif ch == 'SPACE':
            pred += ' '
        else:
            pred += ch
    return pred
Example #11
class BeamCTCDecoder(Decoder):
    def __init__(self, labels, lm_path=None, alpha=0, beta=0, cutoff_top_n=40, cutoff_prob=1.0, beam_width=100,
                 num_processes=4, blank_index=0):
        super(BeamCTCDecoder, self).__init__(labels)
        try:
            from ctcdecode import CTCBeamDecoder
        except ImportError:
            raise ImportError("BeamCTCDecoder requires paddledecoder package.")
        self._decoder = CTCBeamDecoder(labels, lm_path, alpha, beta, cutoff_top_n, cutoff_prob, beam_width,
                                       num_processes, blank_index)

    def convert_to_strings(self, out, seq_len):
        results = []
        for b, batch in enumerate(out):
            utterances = []
            for p, utt in enumerate(batch):
                size = seq_len[b][p]
                if size > 0:
                    transcript = ''.join(map(lambda x: self.int_to_char[x], utt[0:size]))
                else:
                    transcript = ''
                utterances.append(transcript)
            results.append(utterances)
        return results

    def convert_tensor(self, offsets, sizes):
        results = []
        for b, batch in enumerate(offsets):
            utterances = []
            for p, utt in enumerate(batch):
                size = sizes[b][p]
                if sizes[b][p] > 0:
                    utterances.append(utt[0:size])
                else:
                    utterances.append(torch.IntTensor())
            results.append(utterances)
        return results

    def decode(self, probs, sizes=None):
        """
        Decodes probability output using ctcdecode package.
        Arguments:
            probs: Tensor of character probabilities, where probs[c,t]
                            is the probability of character c at time t
            sizes: Size of each sequence in the mini-batch
        Returns:
            string: sequences of the model's best guess for the transcription
        """
        probs = probs.cpu().transpose(0, 1).contiguous()
        out, scores, offsets, seq_lens = self._decoder.decode(probs)

        strings = self.convert_to_strings(out, seq_lens)
        offsets = self.convert_tensor(offsets, seq_lens)
        return strings, offsets
Example #12
def fast_beam_search_decode(logprobs,
                            logprobs_lens,
                            vocab,
                            beam_size,
                            cutoff_top_n,
                            cutoff_prob,
                            ext_scoring_func,
                            alpha,
                            beta,
                            num_processes,
                            rescorer=None):
    blank_index = vocab['<blank>']

    labels = ''.join(vocab.indices2tokens()).replace('<blank>',
                                                     '_').replace('<unk>', '')
    decoder = CTCBeamDecoder(labels=labels,
                             blank_id=blank_index,
                             cutoff_top_n=cutoff_top_n,
                             cutoff_prob=cutoff_prob,
                             beam_width=beam_size,
                             model_path=ext_scoring_func,
                             alpha=alpha,
                             beta=beta,
                             num_processes=num_processes,
                             log_probs_input=True)
    beam_results, beam_scores, timesteps, out_lens = decoder.decode(
        torch.transpose(logprobs, 0, 1), logprobs_lens)

    predictions = []
    for idx in range(beam_results.shape[0]):
        beam = []
        for jdx in range(beam_results.shape[1]):
            hypo = ''.join(
                vocab.lookup_tokens(
                    beam_results[idx, jdx, :out_lens[idx, jdx]].tolist()))
            hypo_score = -beam_scores[idx, jdx]
            beam.append((hypo, hypo_score))
        predictions.append(beam)

    if rescorer is not None:
        all_hypos = [hypo for beam in predictions for hypo, _ in beam]
        scoring_results = rescorer.score(all_hypos)
        all_lm_scores = [
            scoring_result['positional_scores'].mean().item()
            for scoring_result in scoring_results
        ]
        all_lm_scores = torch.tensor(all_lm_scores).reshape(beam_scores.shape)
        all_lm_scores = torch.softmax(all_lm_scores, dim=1)
        predictions = [[(predictions[idx][jdx][0], all_lm_scores[idx, jdx])
                        for jdx in range(beam_results.shape[1])]
                       for idx in range(beam_results.shape[0])]

    return predictions
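A hedged invocation of the helper above, assuming vocab behaves like the vocabulary object the function indexes into and that no external LM or rescorer is used:

predictions = fast_beam_search_decode(
    logprobs, logprobs_lens, vocab,
    beam_size=16, cutoff_top_n=40, cutoff_prob=1.0,
    ext_scoring_func=None,                  # no KenLM file
    alpha=0.5, beta=1.0, num_processes=4)
best_hypo, best_score = predictions[0][0]   # top beam of the first utterance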
Example #13
def val():
  model.eval()
  distances = []
  decoder = CTCBeamDecoder(PHONEME_LIST, beam_width=3)
  for batch_idx, (data, target, in_lens, target_lens) in enumerate(test_loader):
    data, in_lens = data.to(device), in_lens.to(device)
    out, out_lens = model(data, in_lens)
    decoded_out, _, _, decoded_lens = decoder.decode(out.transpose(0, 1).cpu(), out_lens.cpu())
    decoded_strings = [label_to_short_phoneme(decoded_out[i, 0, :decoded_lens[i, 0]]) for i in range(decoded_out.shape[0])]
    decoded_labels = [label_to_short_phoneme(target[i, :target_lens[i]]) for i in range(target.shape[0])]
    batch_distances = [distance(o, l) for o, l in zip(decoded_strings, decoded_labels)]
    distances.extend(batch_distances)
    print('Distance = ', np.mean(distances))
Example #14
 def cpp_beam_search(predictions, labels, beam_width=5, beam_cut_threshold=0.1):
     """
     C++ Beam search CTC decoder https://github.com/parlance/ctcdecode
     """
     # add batch dimension expected by CTCBeamDecoder
     predictions = np.expand_dims(predictions, 0)
     predictions = torch.FloatTensor(predictions)
     decoder = CTCBeamDecoder(
         labels, beam_width=beam_width, cutoff_prob=beam_cut_threshold
     )
     beam_result, _, _, out_seq_len = decoder.decode(predictions)
     beam_result = beam_result[0][0][0:out_seq_len[0][0]]
     return ''.join(labels[x] for x in beam_result)
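A hedged call of the helper above; the (T, C) posterior matrix and the "NACGT" label string (whose index 0 serves as the CTC blank, matching the decoder's default blank_id=0) are assumptions:

import numpy as np

probs = np.random.rand(100, 5).astype(np.float32)  # placeholder (T, C) posteriors
probs /= probs.sum(axis=1, keepdims=True)          # rows normalized like softmax output
seq = cpp_beam_search(probs, "NACGT", beam_width=5)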
Example #15
def test(model, test_loader, ocr_dataset):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ratios = []
    lv_ratios = []

    BLANK = ocr_dataset.get_num_classes()-1

    decoder = CTCBeamDecoder(ocr_dataset.char_vec,
                             blank_id=BLANK,
                             log_probs_input=True)

    with torch.no_grad():
        for ((x, input_lengths), (y, target_lengths)) in test_loader:
            print("Run eval")
            x = x.to(device)

            outputs = model.forward(x)
            outputs = outputs.permute(1, 0, 2)

            output, scores, ts, out_seq_len = decoder.decode(
                outputs.data, torch.IntTensor(input_lengths))

            results = []
            for b, batch in enumerate(output):
                size = out_seq_len[b][0]
                dec = batch[0]

                text = ''
                if size > 0:
                    text = ocr_dataset.get_decoded_label(dec[0:size])

                results.append(text)

            ptr = 0
            for i, p in enumerate(target_lengths):
                yi = y[ptr:ptr + p]

                s1 = results[i]
                s2 = ocr_dataset.get_decoded_label(yi)

                ratios.append(SequenceMatcher(None, s1, s2).quick_ratio())

                lv_ratios.append(char_err_rate(s1, s2))

                ptr += p

    print("SequenceMatcher acc:", np.mean(ratios), np.std(ratios))
    print("Levenshtein acc:", np.mean(lv_ratios), np.std(lv_ratios))
Example #16
def run(config):
    batch_size = config["batch_size"]
    seq_len = config["seg_len"]
    n_iter = config["epoch"]
    input_size = config["input_size"]
    device = config["device"]
    vocab_size = config["vocab_size"]
    # num_processes = config["num_processes"]

    beam_width = config["beam_width"]
    # print("num_processes_cpu: ", os.cpu_count())
    num_threads = config["num_threads"]
    if device == "cpu":
        torch.set_num_threads(num_threads)
        print("num_threads: ", torch.get_num_threads())
    model = DeepSpeech(config)
    decoder = CTCBeamDecoder(['$'] * (vocab_size + 1),
                             beam_width=beam_width,
                             blank_id=0,
                             num_processes=num_threads,
                             log_probs_input=True)

    # inp = torch.ones((batch_size, seq_len, input_size+2*input_size*n_context))

    model = model.to(device)

    forward_time = 0
    decode_time = 0
    overall_time = 0
    for i in range(n_iter):
        start_time = time.perf_counter()
        inp = torch.rand(
            (batch_size, seq_len, input_size + 2 * input_size * n_context))
        inp = inp.to(device)
        out = model(inp)
        end_time1 = time.perf_counter()
        start_time1 = time.perf_counter()
        out = out.transpose(0, 1)
        out_lens = torch.tensor([seq_len for _ in range(batch_size)])
        output, scores, timesteps, out_seq_len = decoder.decode(
            out,
            out_lens)  # [b, seq_len, vocab_size] -> [b, beam_width, seq_len]
        end_time2 = time.perf_counter()

        forward_time += end_time1 - start_time
        decode_time += end_time2 - start_time1
        overall_time += end_time2 - start_time

    print("Forward: %f s" % (forward_time / n_iter))
    print("CTC Decode %f s" % (decode_time / n_iter))
    print("Overall %f s" % (overall_time / n_iter))
Example #17
def run_decoder(model, inputs):
    inputlen = torch.IntTensor([len(seq) for seq in inputs]).to(net.DEVICE)
    phonemes = [' '] + PL.PHONEME_MAP
    decoder = CTCBeamDecoder(['$'] * (len(phonemes)),
                             beam_width=200,
                             log_probs_input=True)
    with torch.no_grad():
        out, out_lens = model(inputs, inputlen)
    test_Y, _, _, test_Y_lens = decoder.decode(out.transpose(0, 1), out_lens)
    results = []
    for i in range(len(inputs)):
        # For the i-th sample in the batch, get the best output
        best_seq = test_Y[i, 0, :test_Y_lens[i, 0]]
        best_pron = ''.join(phonemes[p + 1] for p in best_seq)
        results.append(best_pron)
    return results
Example #18
class Decoder:
    def __init__(
            self,
            labels: list = LABELS,
            beam_width: int = 100,
            model_path: str = None,
            alpha: float = 0.0,
            beta: float = 0.0,
            cutoff_top_n: int = 40,
            cutoff_prob: float = 1.0,
            blank_id: int = LABELS.index('_'),
            log_probs_input: bool = False,
    ):
        self.labels = labels
        self.beam_width = beam_width
        self.model_path = model_path
        self.alpha = alpha
        self.beta = beta
        self.cutoff_top_n = cutoff_top_n
        self.cutoff_prob = cutoff_prob
        self.blank_id = blank_id
        self.log_probs_input = log_probs_input

        self.decoder = CTCBeamDecoder(labels=labels,
                                      beam_width=beam_width,
                                      model_path=model_path,
                                      alpha=alpha,
                                      beta=beta,
                                      cutoff_top_n=cutoff_top_n,
                                      cutoff_prob=cutoff_prob,
                                      num_processes=max(os.cpu_count(), 1),
                                      blank_id=blank_id,
                                      log_probs_input=log_probs_input)

    def __call__(self, token_probs: torch.Tensor) -> str:
        """Generate a decoded string from token probabilities.

        Args:
            token_probs (torch.Tensor): The output from an acoustic model,
                shaped (T, L) for a single utterance.

        Returns:
            str: The decoded output string.
        """
        token_probs = torch.Tensor(token_probs[None, ...])
        beam_results, beam_scores, timesteps, out_lens = self.decoder.decode(
            token_probs)
        tokens = beam_results[0][0]
        seq_len = out_lens[0][0]
        return ''.join([self.labels[x] for x in tokens[0:seq_len]])
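__call__ adds the batch dimension itself, so it takes a single utterance's (T, L) probability matrix. A hedged usage, with the shape as a placeholder:

import torch

decoder = Decoder()                                          # uses the module-level LABELS
token_probs = torch.rand(120, len(LABELS)).softmax(dim=-1)   # one utterance, shape (T, L)
transcript = decoder(token_probs)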
Example #19
class BeamCTCDecoder():
    def __init__(self, PHONEME_MAP, blank_index=0, beam_width=100):
        # Add the blank to the phoneme_map as the first element
        if PHONEME_MAP[blank_index] != ' ':
            PHONEME_MAP.insert(0, ' ')
        # Define the int_to_char dictionary
        self.int_to_char = dict([(i, c) for (i, c) in enumerate(PHONEME_MAP)])
        self._decoder = CTCBeamDecoder(PHONEME_MAP,
                                       blank_id=blank_index,
                                       beam_width=beam_width,
                                       log_probs_input=True)

    def decode(self, probs, sizes=None):
        probs, sizes = probs.cpu(), sizes.cpu()
        out, _, _, seq_lens = self._decoder.decode(probs, sizes)
        # out: shape (batch_size, beam_width, seq_len)
        # seq_lens: shape (batch_size, beam_width)
        # The best sequences are indexed 0 in the beam_width dimension.
        strings = self.convert_to_strings(out[:, 0, :], seq_lens[:, 0])
        return strings

    def convert_to_strings(self, out, seq_len):
        """
        :param out: (batch_size, sequence_length)
        :param seq_len: (batch_size)
        :return:
        """
        out = out.cpu()
        results = []
        for b, utt in enumerate(out):
            size = seq_len[b]
            if size > 0:
                # Map each integer to the char using the int_to_char dictionary
                # Only get the original len and remove all the padding elements
                transcript = ''.join(
                    map(lambda x: self.int_to_char[x.item()], utt[:size]))
            else:
                transcript = ''
            transcript = transcript.replace(' ', '')
            results.append(transcript)
        return results

    def Lev_dist(self, s1, s2):
        s1, s2 = s1.replace(' ', ''), s2.replace(' ', '')
        return Lev.distance(s1, s2)
Example #20
def decode(output, seq_sizes, beam_width=40):
    decoder = CTCBeamDecoder(labels=PHONEME_MAP, blank_id=0, beam_width=beam_width)
    output = torch.transpose(output, 0, 1)  # batch, seq_len, probs
    probs = F.softmax(output, dim=2).data.cpu()

    output, scores, timesteps, out_seq_len = decoder.decode(probs=probs,
                                                            seq_lens=torch.IntTensor(seq_sizes))
    #     print("output", output)
    #     print("scores", scores)
    #     print("timesteps", timesteps)
    #     print("out_seq_len", out_seq_len)
    decoded = []
    for i in range(output.size(0)):
        chrs = ""
        if out_seq_len[i, 0] != 0:
            chrs = "".join(PHONEME_MAP[o] for o in output[i, 0, :out_seq_len[i, 0]])
        decoded.append(chrs)
    return decoded
Example #21
class CTCDecoder:
    def __init__(self,
                 blank_id: int,
                 alphabet: List[str],
                 count_prediction=10):
        self.decoder = CTCBeamDecoder(alphabet,
                                      beam_width=count_prediction,
                                      blank_id=blank_id)

    def __call__(self, output):
        result, _, _, sec_len = self.decoder.decode(output)
        len_best_result = sec_len[:, 0]
        best_results = []
        for i, res in enumerate(result):
            best_results.append(res[0, :len_best_result[i]])

        tensor_res = nn.utils.rnn.pad_sequence(sequences=best_results,
                                               batch_first=True)
        return tensor_res, len_best_result
Example #22
class CTCDecoder:

    def __init__(self, 
            labels, 
            lm_path=None, 
            alpha=1.5, 
            beta=0.8,
            cutoff_top_n=15,
            cutoff_prob=1.0,
            beam_width=256,
            num_processes=4,
            blank_id=31,
            log_probs_input=False):

        print("Initializing Decoder")
        self.decoder = CTCBeamDecoder(
            labels,
            model_path=lm_path,
            alpha=alpha,
            beta=beta,
            cutoff_top_n=cutoff_top_n,
            cutoff_prob=cutoff_prob,
            beam_width=beam_width,
            num_processes=num_processes,
            blank_id=blank_id,
            log_probs_input=log_probs_input
        )

        self.decode_dict = self._dict_from_labels(labels)
        print("Decoder ready")

    def _dict_from_labels(self, labels):
        d = {}
        for i in range(len(labels)):
            d[i] = labels[i]
        return d

    def map_to_chars(self, ids):
        return "".join([self.decode_dict[i] for i in ids])

    def decode(self, probs):
        beam_results, beam_scores, timesteps, out_lens = self.decoder.decode(probs)
        return self.map_to_chars(beam_results[0][0][:out_lens[0][0]].numpy())
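decode() above reads only the top beam of the first batch element, so a hedged call site decodes one utterance at a time; the label set and shapes are placeholders:

import torch

labels = list("_abcdefghijklmnopqrstuvwxyz' ")           # index 0 is the blank
decoder = CTCDecoder(labels, blank_id=labels.index('_'))
probs = torch.rand(1, 80, len(labels)).softmax(dim=2)    # (1, T, C), log_probs_input=False
text = decoder.decode(probs)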
Example #23
class BeamDecoder(Decoder):
    def __init__(self,
                 vocab,
                 lm_path=None,
                 alpha=1,
                 beta=1.5,
                 cutoff_top_n=40,
                 cutoff_prob=0.99,
                 beam_width=100,
                 num_processes=4):
        super(BeamDecoder, self).__init__(vocab)

        self._decoder = CTCBeamDecoder(vocab,
                                       lm_path,
                                       alpha,
                                       beta,
                                       cutoff_top_n,
                                       cutoff_prob,
                                       beam_width,
                                       num_processes,
                                       blank_id=0)
        self.int2char = dict([(i, c) for (i, c) in enumerate(vocab)])

    def decode(self, logits, seq_lens):
        tlogits = logits.transpose(0, 1)
        results, scores, _, out_lens = self._decoder.decode(tlogits, seq_lens)
        return self.convert_to_strings(results, out_lens)

    def convert_to_strings(self, out, seq_len):
        results = []
        for b, batch in enumerate(out):
            utterances = []
            for p, utt in enumerate(batch):
                size = seq_len[b][p]
                if size > 0:
                    transcript = ''.join(
                        map(lambda x: self.int2char[x.item()], utt[0:size]))
                else:
                    transcript = ''
                utterances.append(transcript)
            results.append(utterances)
        return results
Example #24
class ER:
    def __init__(self):
        self.label_map = PHONEME_MAP + [' ']
        self.phoneme_list = PHONEME_LIST + [' ']
        self.decoder = CTCBeamDecoder(labels=self.label_map,
                                      blank_id=phonemes_len,
                                      log_probs_input=True,
                                      beam_width=200)
        self.greedy_decoder = GreedyDecoder(labels=self.label_map,
                                            blank_index=phonemes_len)

    def __call__(self, prediction, target=None, test=False):
        return self.forward(prediction, target)

    def forward(self, prediction, target):
        logits = prediction[0]  # (logits, len)
        feature_lengths = prediction[1].int()
        labels = target
        logits = torch.transpose(logits, 0, 1)
        logits = logits.cpu()
        # beam decoder
        output, scores, timesteps, out_seq_len = self.decoder.decode(
            probs=logits, seq_lens=feature_lengths)

        ############# GREEDY DECODE ##########################
        _, max_probs = torch.max(logits, 2)
        strings, offsets = self.greedy_decoder.decode(probs=logits)
        predictions = []
        time_stamps = []
        ls = 0
        for i in range(len(strings)):
            pred = strings[i][0]
            phone_pred = []
            for j in pred:
                phone_pred.append(self.phoneme_list[self.label_map.index(j)])
            predictions.append(phone_pred)
            time_stamps.append(offsets[i][0].float() / 100)
            if target is not None:
                true = "".join(self.label_map[l] for l in labels[i])
                ls += stringdist.levenshtein(strings[i][0], true)
        return predictions, time_stamps, ls / len(strings)
Example #25
def train_wakeword_model(audio_train_loader,
                         vocab_list,
                         label_model,
                         beam_size=3,
                         num_hypotheses=5,
                         query_by_string=False):
    wakeword_model = {}

    if query_by_string:
        # load ww model produced by MFA from config
        keywords = config["wakeword_model"]
        # load blick
        b = BlickLoader()

        for i, (_, y_hat) in enumerate(keywords.items()):
            w = b.assessWord(y_hat)
            # for each keyword, append the tuple(hypotheses + weights) to the list
            # only one hypothesis if using MFA
            wakeword_model[i] = (y_hat, w)

    else:
        # train ww model from scratch
        # decode using CTC; vocab_list is A (labels)
        decoder = CTCBeamDecoder(vocab_list,
                                 beam_width=beam_size,
                                 blank_id=vocab_list.index('_'))

        for i in audio_train_loader:
            posteriors_i = label_model(i)

            beam, beam_scores, _, _ = decoder.decode(posteriors_i)

            wakeword_model[i] = []
            for j in range(num_hypotheses):
                y_hat = beam[j]  # hypothesis
                log_prob_post = beam_scores[j]
                w = log_prob_post**-1

                # for each keyword, append the tuple (hypothesis, weight) to the list
                wakeword_model[i].append((y_hat, w))

    return wakeword_model
Example #26
def predict(model, test_loader, result_path):
    decoder = CTCBeamDecoder(['$'] * 47, beam_width=100, log_probs_input=True)
    with torch.no_grad():
        model.eval()
        model.cuda()
        predicted_list = []
        for batch_idx, lst in enumerate(test_loader):
            X, X_lens = process_test_lst(lst)
            out, out_lens = model(X, X_lens)
            test_Y, _, _, test_Y_lens = decoder.decode(out.transpose(0, 1),
                                                       out_lens)
            this_batch_size = test_Y.shape[0]

            predicted_list += [
                test_Y[i, 0, :test_Y_lens[i, 0]]
                for i in range(this_batch_size)
            ]
            print(CONFIG.model_name, f" Predicting... | Batch: {batch_idx}")
        predicted_phoneme_list = convert_to_phoneme(predicted_list)
        labels = [i for i in range(len(predicted_list))]
        export_to_csv(labels, "id", predicted_phoneme_list, "Predicted",
                      result_path)
        print(CONFIG.model_name, " Validation Finished")
Example #27
def decode(output_probs, dataLens, beamWidth):
    decoder = CTCBeamDecoder(labels=PHONEME_MAP,
                             beam_width=beamWidth,
                             num_processes=os.cpu_count(),
                             log_probs_input=True)
    output_probs = torch.transpose(output_probs, 0,
                                   1)  # post transpose: (B, T, C=47)
    output, _, _, out_seq_len = decoder.decode(
        output_probs, dataLens
    )  # output dim: (BatchSize, BeamWidth, T); out_seq_len dim: (BatchSize, BeamWidth)
    decodedListShort = []
    decodedListLong = []
    for b in range(output_probs.size(0)):
        currDecodeShort = ""
        currDecodeLong = ""
        if out_seq_len[b][0] != 0:
            currDecodeShort = "".join(
                [PHONEME_MAP[i] for i in output[b, 0, :out_seq_len[b][0]]])
            currDecodeLong = "".join(
                [PHONEME_LIST[i] for i in output[b, 0, :out_seq_len[b][0]]])
        decodedListShort.append(currDecodeShort)
        decodedListLong.append(currDecodeLong)

    return decodedListShort, decodedListLong
Example #28
class ER:
    def __init__(self):
        self.label_map = [' '] + PHONEME_MAP
        self.decoder = CTCBeamDecoder(labels=self.label_map,
                                      beam_width=50,
                                      blank_id=0)

    def __call__(self, prediction, target):
        return self.forward(prediction, target)

    def forward(self, prediction, target):
        logits = prediction[0]
        feature_lengths = prediction[1].int()
        logits = torch.transpose(logits, 0, 1)
        logits = logits.cpu()
        probs = F.softmax(logits, dim=2)
        output, scores, timesteps, out_seq_len = self.decoder.decode(
            probs=probs, seq_lens=feature_lengths)
        predictions = []
        if target is None:
            for i in range(output.size(0)):
                pred = "".join(self.label_map[o]
                               for o in output[i, 0, :out_seq_len[i, 0]])
                predictions.append(pred)
            return predictions
        ls = 0.
        for i in range(output.size(0)):
            pred = "".join(self.label_map[o]
                           for o in output[i, 0, :out_seq_len[i, 0]])
            true = "".join(self.label_map[l] for l in target[i])
            ls += L.distance(pred, true)
        return ls / output.size(0)
Example #29
class TextRecognizer(object):
    def __init__(self, args):
        if not args.use_pdserving:
            self.predictor, self.input_tensor, self.output_tensors =\
                utility.create_predictor(args, mode="rec")
            self.use_zero_copy_run = args.use_zero_copy_run
        self.rec_image_shape = [
            int(v) for v in args.rec_image_shape.split(",")
        ]
        self.rec_render = args.rec_render
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        self.rec_whitelist = args.rec_whitelist
        self.rec_blacklist = args.rec_blacklist
        self.text_len = args.max_text_length

        char_ops_params = {
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char,
            "max_text_length": args.max_text_length
        }
        if self.rec_algorithm in ["CRNN", "Rosetta", "STAR-Net"]:
            char_ops_params['loss_type'] = 'ctc'
            self.loss_type = 'ctc'
        elif self.rec_algorithm == "RARE":
            char_ops_params['loss_type'] = 'attention'
            self.loss_type = 'attention'
        elif self.rec_algorithm == "SRN":
            char_ops_params['loss_type'] = 'srn'
            self.loss_type = 'srn'
        self.char_ops = CharacterOps(char_ops_params)
        self.len_chars = self.char_ops.len_characters()

        # If both a blacklist and a whitelist are provided, only the whitelist is used
        if self.rec_whitelist == '' and self.rec_blacklist != '':
            self.mod_chars = np.arange(start=0,
                                       stop=self.len_chars + 1,
                                       step=1)
            black_list = self.char_ops.encode(self.rec_blacklist)
            self.mod_chars = np.setdiff1d(self.mod_chars, black_list)
        elif self.rec_whitelist != '':
            white_list = self.char_ops.encode(self.rec_whitelist)
            self.mod_chars = np.append(white_list, [self.len_chars])
        elif self.rec_whitelist == '' and self.rec_blacklist == '':
            self.mod_chars = []

        self.use_beam_search = args.use_beam_search
        if self.use_beam_search:
            self.beam_width = args.beam_width
            self.beam_lm_dir = args.beam_lm_dir if args.beam_lm_dir != '' else None
            self.beam_alpha = args.beam_alpha
            self.beam_beta = args.beam_beta
            self.beam_cutoff_top = args.beam_cutoff_top
            self.beam_cutoff_prob = args.beam_cutoff_prob
            # self.labels = self.char_ops.decode(np.arange(0,self.len_chars,1)) + '_'
            self.labels = list(self.char_ops.decode(self.mod_chars) + '_')
            self.blank_id = len(self.mod_chars)

            self.decoder = CTCBeamDecoder(labels=self.labels,
                                          model_path=self.beam_lm_dir,
                                          alpha=self.beam_alpha,
                                          beta=self.beam_beta,
                                          cutoff_top_n=self.beam_cutoff_top,
                                          cutoff_prob=self.beam_cutoff_prob,
                                          beam_width=self.beam_width,
                                          num_processes=os.cpu_count(),
                                          blank_id=self.blank_id,
                                          log_probs_input=False)

        self.use_spell_check = args.use_spell_check
        if self.use_spell_check:
            self.spell_case_sensitive = args.spell_case_sensitive
            self.spell_language = "" if self.spell_case_sensitive else args.spell_language
            self.spell_tokenizer = word_tokenize if args.spell_tokenizer == 'NLTK' else None
            self.spell_word_freq = args.spell_word_freq if args.spell_word_freq != '' else None
            self.spell_text_corpus = args.spell_text_corpus

            self.spell = SpellChecker(
                language=self.spell_language,
                local_dictionary=self.spell_word_freq,
                tokenizer=self.spell_tokenizer,
                case_sensitive=self.spell_case_sensitive,
            )

            if self.spell_text_corpus != '':
                self.spell.word_frequency.load_text_file(
                    self.spell_text_corpus)

    def resize_norm_img(self, img, max_wh_ratio):
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        if self.character_type == "ch":
            imgW = int((32 * max_wh_ratio))
        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def resize_norm_img_srn(self, img, image_shape):
        imgC, imgH, imgW = image_shape

        img_black = np.zeros((imgH, imgW))
        im_hei = img.shape[0]
        im_wid = img.shape[1]

        if im_wid <= im_hei * 1:
            img_new = cv2.resize(img, (imgH * 1, imgH))
        elif im_wid <= im_hei * 2:
            img_new = cv2.resize(img, (imgH * 2, imgH))
        elif im_wid <= im_hei * 3:
            img_new = cv2.resize(img, (imgH * 3, imgH))
        else:
            img_new = cv2.resize(img, (imgW, imgH))

        img_np = np.asarray(img_new)
        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
        img_black[:, 0:img_np.shape[1]] = img_np
        img_black = img_black[:, :, np.newaxis]

        row, col, c = img_black.shape
        c = 1

        return np.reshape(img_black, (c, row, col)).astype(np.float32)

    def srn_other_inputs(self, image_shape, num_heads, max_text_length,
                         char_num):

        imgC, imgH, imgW = image_shape
        feature_dim = int((imgH / 8) * (imgW / 8))

        encoder_word_pos = np.array(range(0, feature_dim)).reshape(
            (feature_dim, 1)).astype('int64')
        gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
            (max_text_length, 1)).astype('int64')

        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias1 = np.tile(
            gsrm_slf_attn_bias1,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias2 = np.tile(
            gsrm_slf_attn_bias2,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        encoder_word_pos = encoder_word_pos[np.newaxis, :]
        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]

        return [
            encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2
        ]

    def process_image_srn(self,
                          img,
                          image_shape,
                          num_heads,
                          max_text_length,
                          char_ops=None):
        norm_img = self.resize_norm_img_srn(img, image_shape)
        norm_img = norm_img[np.newaxis, :]
        char_num = char_ops.get_char_num()

        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
            self.srn_other_inputs(image_shape, num_heads, max_text_length, char_num)

        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)

        return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
                gsrm_slf_attn_bias2)

    def __call__(self, img_list):
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))
        # Sorting can speed up the recognition process
        indices = np.argsort(np.array(width_list))

        #rec_res = []
        rec_res = [['', 0.0]] * img_num
        batch_num = self.rec_batch_num
        predict_time = 0
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            max_wh_ratio = 0
            for ino in range(beg_img_no, end_img_no):
                # h, w = img_list[ino].shape[0:2]
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)
            for ino in range(beg_img_no, end_img_no):
                if self.loss_type != "srn":
                    norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                    max_wh_ratio)
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
                else:
                    norm_img = self.process_image_srn(img_list[indices[ino]],
                                                      self.rec_image_shape, 8,
                                                      25, self.char_ops)
                    encoder_word_pos_list = []
                    gsrm_word_pos_list = []
                    gsrm_slf_attn_bias1_list = []
                    gsrm_slf_attn_bias2_list = []
                    encoder_word_pos_list.append(norm_img[1])
                    gsrm_word_pos_list.append(norm_img[2])
                    gsrm_slf_attn_bias1_list.append(norm_img[3])
                    gsrm_slf_attn_bias2_list.append(norm_img[4])
                    norm_img_batch.append(norm_img[0])

            norm_img_batch = np.concatenate(norm_img_batch, axis=0)
            norm_img_batch = norm_img_batch.copy()

            if self.loss_type == "srn":
                starttime = time.time()
                encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
                gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
                gsrm_slf_attn_bias1_list = np.concatenate(
                    gsrm_slf_attn_bias1_list)
                gsrm_slf_attn_bias2_list = np.concatenate(
                    gsrm_slf_attn_bias2_list)
                starttime = time.time()

                norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
                encoder_word_pos_list = fluid.core.PaddleTensor(
                    encoder_word_pos_list)
                gsrm_word_pos_list = fluid.core.PaddleTensor(
                    gsrm_word_pos_list)
                gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor(
                    gsrm_slf_attn_bias1_list)
                gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor(
                    gsrm_slf_attn_bias2_list)

                inputs = [
                    norm_img_batch, encoder_word_pos_list,
                    gsrm_slf_attn_bias1_list, gsrm_slf_attn_bias2_list,
                    gsrm_word_pos_list
                ]

                self.predictor.run(inputs)
            else:
                starttime = time.time()
                if self.use_zero_copy_run:
                    self.input_tensor.copy_from_cpu(norm_img_batch)
                    self.predictor.zero_copy_run()
                else:
                    norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
                    self.predictor.run([norm_img_batch])

            if len(self.mod_chars) != 0:
                mod_onehot = np.zeros((self.len_chars + 1))
                mod_onehot[self.mod_chars] = 1

            if self.loss_type == "ctc":
                rec_idx_batch = self.output_tensors[0].copy_to_cpu()
                rec_idx_lod = self.output_tensors[0].lod()[0]
                predict_batch = self.output_tensors[1].copy_to_cpu()
                predict_lod = self.output_tensors[1].lod()[0]

                if len(self.mod_chars) != 0:
                    # blacklist/whitelist filtering is applied here
                    predict_batch = np.multiply(predict_batch, mod_onehot)

                for rno in range(len(rec_idx_lod) - 1):

                    beg = predict_lod[rno]
                    end = predict_lod[rno + 1]
                    probs = predict_batch[beg:end, :]
                    ind = np.argmax(probs, axis=1)
                    valid_ind = range(ind.shape[0])
                    preds_text = self.char_ops.decode(ind[valid_ind],
                                                      is_remove_duplicate=True)
                    if len(valid_ind) == 0:
                        continue
                    score = np.mean(probs[valid_ind, ind[valid_ind]])

                    # if both beam search and spell check are enabled, the spell-check result is the final one
                    if self.use_beam_search:
                        mod_probs = probs[:, self.mod_chars]
                        mod_probs = torch.Tensor(mod_probs).unsqueeze(0)
                        beams, scores, _, out_lens = self.decoder.decode(
                            mod_probs)
                        res_beam = beams[0][0][:out_lens[0][0]]
                        res_list = [self.mod_chars[i] for i in res_beam]
                        res_text = self.char_ops.decode(res_list)
                        score_beam = 1 / np.exp(scores[0][0])
                        if preds_text != res_text:
                            print(
                                f'original: {preds_text} || beam_search_corrected: {res_text}'
                            )
                        rec_res[indices[beg_img_no +
                                        rno]] = [res_text, score_beam]

                        if self.use_spell_check:
                            corrected = self.spell.correction(res_text)
                            if preds_text != corrected:
                                print(
                                    f'original: {preds_text} || spell_check_corrected: {corrected}'
                                )
                            rec_res[indices[beg_img_no +
                                            rno]] = [corrected, score_beam]

                    elif self.use_spell_check:
                        corrected = self.spell.correction(preds_text)
                        if preds_text != corrected:
                            print(
                                f'original: {preds_text} || spell_check_corrected: {corrected}'
                            )
                        rec_res[indices[beg_img_no + rno]] = [corrected, score]

                    else:
                        rec_res[indices[beg_img_no +
                                        rno]] = [preds_text, score]

            elif self.loss_type == 'srn':
                rec_idx_batch = self.output_tensors[0].copy_to_cpu()
                probs = self.output_tensors[1].copy_to_cpu()

                # TODO: implement whitelist and blacklist for srn loss

                char_num = self.char_ops.get_char_num()
                preds = rec_idx_batch.reshape(-1)
                elapse = time.time() - starttime
                predict_time += elapse
                total_preds = preds.copy()
                for ino in range(int(len(rec_idx_batch) / self.text_len)):
                    preds = total_preds[ino * self.text_len:(ino + 1) *
                                        self.text_len]
                    ind = np.argmax(probs, axis=1)
                    valid_ind = np.where(preds != int(char_num - 1))[0]
                    if len(valid_ind) == 0:
                        continue
                    score = np.mean(probs[valid_ind, ind[valid_ind]])
                    preds = preds[:valid_ind[-1] + 1]
                    preds_text = self.char_ops.decode(preds)

                    rec_res[indices[beg_img_no + ino]] = [preds_text, score]
            else:
                rec_idx_batch = self.output_tensors[0].copy_to_cpu()
                predict_batch = self.output_tensors[1].copy_to_cpu()

                # TODO: implement whitelist and blacklist for attention loss

                elapse = time.time() - starttime
                predict_time += elapse
                for rno in range(len(rec_idx_batch)):
                    end_pos = np.where(rec_idx_batch[rno, :] == 1)[0]
                    if len(end_pos) <= 1:
                        preds = rec_idx_batch[rno, 1:]
                        score = np.mean(predict_batch[rno, 1:])
                    else:
                        preds = rec_idx_batch[rno, 1:end_pos[1]]
                        score = np.mean(predict_batch[rno, 1:end_pos[1]])
                    preds_text = self.char_ops.decode(preds)
                    # rec_res.append([preds_text, score])
                    rec_res[indices[beg_img_no + rno]] = [preds_text, score]

        # *TODO: COMPLETE (if self.rec_render:)

        return rec_res, predict_time
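
The recognizer above turns ctcdecode's beam score into a confidence with 1 / np.exp(score); the decoder reports a negative log-likelihood per beam, so this is simply exp(-score). A minimal, self-contained sketch of that decode-and-score round trip (toy labels and random probabilities, not the recognizer's real alphabet or model):

import torch
from ctcdecode import CTCBeamDecoder

labels = ["_", "a", "b", "c"]  # index 0 is the CTC blank
decoder = CTCBeamDecoder(labels, beam_width=4, blank_id=0, num_processes=1)

# decode() expects post-softmax probabilities shaped (batch, time, vocab).
probs = torch.softmax(torch.randn(1, 10, len(labels)), dim=2)
beams, scores, _, out_lens = decoder.decode(probs)

best = beams[0][0][:out_lens[0][0]]           # top beam, trimmed to its length
text = "".join(labels[i] for i in best)
confidence = float(torch.exp(-scores[0][0]))  # same as 1 / exp(score)
print(text, confidence)
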
Example No. 30
0
class Translator(object):
    """
    Uses a model to translate a batch of sentences.


    Args:
       model (:obj:`onmt.modules.NMTModel`):
          NMT model to use for translation
       lm: language model
       fields (dict of Fields): data fields
       beam_size (int): size of beam to use
       n_best (int): number of translations produced
       max_length (int): maximum length output to produce
       global_scorer (:obj:`GlobalScorer`):
         object to rescore final translations
       copy_attn (bool): use copy attention during translation
       cuda (bool): use cuda
       beam_trace (bool): trace beam search for debugging
       logger(logging.Logger): logger.
    """

    def __init__(self,
                 model,
                 lm,
                 fields,
                 opt,
                 model_opt,
                 global_scorer=None,
                 out_file=None,
                 report_score=True,
                 logger=None):

        self.model = model
        self.lm = lm
        self.fields = fields
        self.gpu = opt.gpu
        self.cuda = opt.gpu > -1

        self.n_best = opt.n_best
        self.max_length = opt.max_length
        self.beam_size = opt.beam_size
        self.min_length = opt.min_length
        self.stepwise_penalty = opt.stepwise_penalty
        self.dump_beam = opt.dump_beam
        self.block_ngram_repeat = opt.block_ngram_repeat
        self.ignore_when_blocking = set(opt.ignore_when_blocking)
        self.sample_rate = opt.sample_rate
        self.window_size = opt.window_size
        self.window_stride = opt.window_stride
        self.window = opt.window
        self.image_channel_size = opt.image_channel_size
        self.replace_unk = opt.replace_unk
        self.data_type = opt.data_type
        self.verbose = opt.verbose
        self.report_bleu = opt.report_bleu
        self.report_rouge = opt.report_rouge
        self.fast = opt.fast
        self.lbda = opt.lbda

        self.copy_attn = model_opt.copy_attn
        self.ctc_ratio = model_opt.ctc_ratio

        self.global_scorer = global_scorer
        self.out_file = out_file
        self.report_score = report_score
        self.logger = logger

        self.use_filter_pred = False

        # for debugging
        self.beam_trace = self.dump_beam != ""
        self.beam_accum = None
        if self.beam_trace:
            self.beam_accum = {
                "predicted_ids": [],
                "beam_parent_ids": [],
                "scores": [],
                "log_probs": []}

        if self.ctc_ratio > 0:
            from ctcdecode import CTCBeamDecoder
            ctc_vocab_field = "tgt_feat_0" if "tgt_feat_0" in fields else "tgt"
            self.ctc_vocab = fields[ctc_vocab_field].vocab
            dummy_vocab = self.ctc_vocab.itos
            blank_id = 0 if "tgt_feat_0" in fields else self.ctc_vocab.stoi[inputters.BOS_WORD]
            self.ctc_loss = nn.CTCLoss(blank=blank_id, reduction='none')
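            # Map each vocab entry to a unique placeholder ASCII character
            # (starting at 'A') for the CTC decoder; ctc_c2v holds the
            # inverse map back to the original vocab entries.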
            self.ctc_c2v = {chr(65+i): v for i, v in enumerate(dummy_vocab)}
            dummy_vocab = [chr(65+i) for i, _ in enumerate(dummy_vocab)]
            self.ctc_dec = CTCBeamDecoder(dummy_vocab,
                                          beam_width=opt.beam_size,
                                          blank_id=blank_id,
                                          log_probs_input=True,
                                          num_processes=4)


    def translate(self,
                  src_path=None,
                  src_data_iter=None,
                  tgt_path=None,
                  tgt_data_iter=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False):
        """
        Translate content of `src_data_iter` (if not None) or `src_path`
        and get gold scores if one of `tgt_data_iter` or `tgt_path` is set.

        Note: batch_size must not be None
        Note: one of ('src_path', 'src_data_iter') must not be None

        Args:
            src_path (str): filepath of source data
            src_data_iter (iterator): an iterator generating source data
                e.g. it may be a list or an opened file
            tgt_path (str): filepath of target data
            tgt_data_iter (iterator): an iterator generating target data
            src_dir (str): source directory path
                (used for Audio and Image datasets)
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging

        Returns:
            (`list`, `list`)

            * all_scores is a list of `batch_size` lists of `n_best` scores
            * all_predictions is a list of `batch_size` lists
                of `n_best` predictions
        """
        assert src_data_iter is not None or src_path is not None

        if batch_size is None:
            raise ValueError("batch_size must be set")
        data = inputters. \
            build_dataset(self.fields,
                          self.data_type,
                          src_path=src_path,
                          src_data_iter=src_data_iter,
                          tgt_path=tgt_path,
                          tgt_data_iter=tgt_data_iter,
                          src_dir=src_dir,
                          sample_rate=self.sample_rate,
                          window_size=self.window_size,
                          window_stride=self.window_stride,
                          window=self.window,
                          use_filter_pred=self.use_filter_pred,
                          image_channel_size=self.image_channel_size)

        if self.cuda:
            cur_device = "cuda"
        else:
            cur_device = "cpu"

        data_iter = inputters.OrderedIterator(
            dataset=data, device=cur_device,
            batch_size=batch_size, train=False, sort=False,
            sort_within_batch=True, shuffle=False)

        builder = onmt.translate.TranslationBuilder(
            data, self.fields,
            self.n_best, self.replace_unk, tgt_path)

        # Statistics
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        for batch in data_iter:
            batch_data = self.translate_batch(batch, data, attn_debug,
                                              fast=self.fast)
            translations = builder.from_batch(batch_data)

            for trans in translations:
                all_scores += [trans.pred_scores[:self.n_best]]
                pred_score_total += trans.pred_scores[0]
                pred_words_total += len(trans.pred_sents[0])
                if tgt_path is not None:
                    gold_score_total += trans.gold_score
                    gold_words_total += len(trans.gold_sent) + 1

                n_best_preds = [" ".join(pred)
                                for pred in trans.pred_sents[:self.n_best]]
                all_predictions += [n_best_preds]
                self.out_file.write('\n'.join(n_best_preds) + '\n')
                self.out_file.flush()

                if self.verbose:
                    sent_number = next(counter)
                    output = trans.log(sent_number)
                    if self.logger:
                        self.logger.info(output)
                    else:
                        os.write(1, output.encode('utf-8'))

                # Debug attention.
                if attn_debug:
                    preds = trans.pred_sents[0]
                    preds.append('</s>')
                    attns = trans.attns[0].tolist()
                    if self.data_type == 'text':
                        srcs = trans.src_raw
                    else:
                        srcs = [str(item) for item in range(len(attns[0]))]
                    header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
                    row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    output = header_format.format("", *srcs) + '\n'
                    for word, row in zip(preds, attns):
                        max_index = row.index(max(row))
                        row_format = row_format.replace(
                            "{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
                        row_format = row_format.replace(
                            "{:*>10.7f} ", "{:>10.7f} ", max_index)
                        output += row_format.format(word, *row) + '\n'
                        row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
                    os.write(1, output.encode('utf-8'))

        if self.report_score:
            msg = self._report_score('PRED', pred_score_total,
                                     pred_words_total)
            if self.logger:
                self.logger.info(msg)
            else:
                print(msg)
            if tgt_path is not None:
                msg = self._report_score('GOLD', gold_score_total,
                                         gold_words_total)
                if self.logger:
                    self.logger.info(msg)
                else:
                    print(msg)
                if self.report_bleu:
                    msg = self._report_bleu(tgt_path)
                    if self.logger:
                        self.logger.info(msg)
                    else:
                        print(msg)
                if self.report_rouge:
                    msg = self._report_rouge(tgt_path)
                    if self.logger:
                        self.logger.info(msg)
                    else:
                        print(msg)

        if self.dump_beam:
            import json
            json.dump(self.translator.beam_accum,
                      codecs.open(self.dump_beam, 'w', 'utf-8'))
        return all_scores, all_predictions

    def translate_batch(self, batch, data, attn_debug, fast=False):
        """
        Translate a batch of sentences.

        Mostly a wrapper around :obj:`Beam`.

        Args:
           batch (:obj:`Batch`): a batch from a dataset object
           data (:obj:`Dataset`): the dataset object
           fast (bool): enables fast beam search (may not support all features)

        Todo:
           Shouldn't need the original dataset.
        """
        with torch.no_grad():
            if self.ctc_ratio == 1:
                return self._ctc_translate_batch(
                    batch,
                    data,
                    self.max_length,
                    min_length=self.min_length,
                    n_best=self.n_best,
                    return_attention=attn_debug or self.replace_unk)
            elif fast:
                return self._fast_translate_batch(
                    batch,
                    data,
                    self.max_length,
                    min_length=self.min_length,
                    n_best=self.n_best,
                    return_attention=attn_debug or self.replace_unk)
            else:
                return self._translate_batch(batch, data)

    def _run_encoder(self, batch, data_type):
        src = inputters.make_features(batch, 'src', data_type)
        src_lengths = None
        if data_type == 'text':
            _, src_lengths = batch.src
        elif data_type == 'audio':
            src_lengths = batch.src_lengths
        enc_states, memory_bank, src_lengths, enc_out = self.model.encoder(
            src, src_lengths)

        ctc_scores = None
        if self.ctc_ratio > 0:
            batch_size = enc_out.size(1)
            bottled_enc_out = self._bottle(enc_out)
            ctc_scores = self.model.encoder.ctc_gen(bottled_enc_out)
            ctc_scores = ctc_scores.view(-1, batch_size, ctc_scores.size(-1))

        if src_lengths is None:
            assert not isinstance(memory_bank, tuple), \
                'Ensemble decoding only supported for text data'
            src_lengths = torch.Tensor(batch.batch_size) \
                               .type_as(memory_bank) \
                               .long() \
                               .fill_(memory_bank.size(0))
        return src, enc_states, memory_bank, src_lengths, ctc_scores

    def _decode_and_generate(self, decoder_input, memory_bank, batch, data,
                             memory_lengths, src_map=None,
                             step=None, batch_offset=None):

        if self.copy_attn:
            # Turn any copied words to UNKs (index 0).
            decoder_input = decoder_input.masked_fill(
                decoder_input.gt(len(self.fields["tgt"].vocab) - 1), 0)

        # Decoder forward, takes [tgt_len, batch, nfeats] as input
        # and [src_len, batch, hidden] as memory_bank
        # in case of inference tgt_len = 1, batch = beam times batch_size
        # in case of Gold Scoring tgt_len = actual length, batch = 1 batch
        dec_out, dec_attn = self.model.decoder(
            decoder_input,
            memory_bank,
            memory_lengths=memory_lengths,
            step=step)

        # Generator forward.
        if not self.copy_attn:
            attn = dec_attn["std"]
            log_probs = self.model.generator(dec_out.squeeze(0))
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence
        else:
            attn = dec_attn["copy"]
            scores = self.model.generator(dec_out.view(-1, dec_out.size(2)),
                                          attn.view(-1, attn.size(2)),
                                          src_map)
            # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
            if batch_offset is None:
                scores = scores.view(batch.batch_size, -1, scores.size(-1))
            else:
                scores = scores.view(-1, self.beam_size, scores.size(-1))
            scores = data.collapse_copy_scores(
                scores,
                batch,
                self.fields["tgt"].vocab,
                data.src_vocabs,
                batch_dim=0,
                batch_offset=batch_offset)
            scores = scores.view(decoder_input.size(0), -1, scores.size(-1))
            log_probs = scores.squeeze(0).log()
            # returns [(batch_size x beam_size) , vocab ] when 1 step
            # or [ tgt_len, batch_size, vocab ] when full sentence

        return log_probs, attn

    def _bottle(self, _v):
        return _v.view(-1, _v.size(2))

    def _ctc_decode(self, scores):
        scores = scores.transpose(0, 1)  # ctcdecode expects (batch, time, vocab)
        beam_preds, beam_scores, _, out_seq_len = \
                self.ctc_dec.decode(scores)
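        # Trim each beam to its decoded length so callers get
        # variable-length token lists rather than the padded tensor.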
        beam_preds = [[[i for i in beam[0:l]]
                      for beam, l in zip(beams, lens)] for beams, lens in
                   zip(beam_preds.tolist(), out_seq_len.tolist())]
        return beam_scores, beam_preds

    def _ctc_translate_batch(self,
                             batch,
                             data,
                             max_length,
                             min_length=0,
                             n_best=1,
                             return_attention=False):
        batch_size = batch.batch_size
        beam_size = self.beam_size
        results = {}
        results["attention"] = [[[] for _ in range(beam_size)]
                for _ in range(batch_size)]  # noqa: F812
        results["batch"] = batch
        results["gold_score"] = [0] * batch_size
        src, enc_states, memory_bank, src_lengths, ctc_scores = self._run_encoder(
            batch, data.data_type)
        scores, predictions = self._ctc_decode(ctc_scores)
        results["scores"] = scores
        results["predictions"] = predictions

        return results

    def _fast_translate_batch(self,
                              batch,
                              data,
                              max_length,
                              min_length=0,
                              n_best=1,
                              return_attention=False):
        # TODO: faster code path for beam_size == 1.

        # TODO: support these blacklisted features.
        assert not self.dump_beam
        assert not self.use_filter_pred
        assert self.block_ngram_repeat == 0
        assert self.global_scorer.beta == 0

        beam_size = self.beam_size
        batch_size = batch.batch_size
        vocab = self.fields["tgt"].vocab
        start_token = vocab.stoi[inputters.BOS_WORD]
        end_token = vocab.stoi[inputters.EOS_WORD]

        # Encoder forward.
        src, enc_states, memory_bank, src_lengths, ctc_scores = self._run_encoder(
            batch, data.data_type)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {}
        results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["attention"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["batch"] = batch
        if "tgt" in batch.__dict__:
            results["gold_score"] = self._score_target(
                batch, memory_bank, src_lengths, data, batch.src_map
                if data.data_type == 'text' and self.copy_attn else None)
            self.model.decoder.init_state(src, memory_bank, enc_states)
        else:
            results["gold_score"] = [0] * batch_size

        if "tgt_feat_0" in batch.__dict__:
            results["ctc_gold_score"] = self._ctc_score_target(batch, ctc_scores)
        else:
            results["ctc_gold_score"] = [0] * batch_size

        results["ctc_scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["ctc_predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        if self.ctc_ratio > 0:
            ctc_scores, ctc_predictions = self._ctc_decode(ctc_scores)
            results["ctc_scores"] = ctc_scores
            results["ctc_predictions"] = ctc_predictions

        # Tile states and memory beam_size times.
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))
        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
            mb_device = memory_bank.device

        memory_lengths = tile(src_lengths, beam_size)
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if data.data_type == 'text' and self.copy_attn else None)

        top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8)
        batch_offset = torch.arange(batch_size, dtype=torch.long)
        beam_offset = torch.arange(
            0,
            batch_size * beam_size,
            step=beam_size,
            dtype=torch.long,
            device=mb_device)
        alive_seq = torch.full(
            [batch_size * beam_size, 1],
            start_token,
            dtype=torch.long,
            device=mb_device)
        alive_attn = None

        # Give full probability to the first beam on the first step.
        topk_log_probs = (
            torch.tensor([0.0] + [float("-inf")] * (beam_size - 1),
                         device=mb_device).repeat(batch_size))

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812
        if self.lm:
            hidden = self.lm.init_hidden(batch_size)
            hidden = self.lm.map_state(hidden,
                     lambda state, dim: tile(state, beam_size, dim=dim))

        for step in range(max_length):
            decoder_input = alive_seq[:, -1].view(1, -1, 1)

            log_probs, attn = \
                self._decode_and_generate(decoder_input, memory_bank,
                                          batch, data,
                                          memory_lengths=memory_lengths,
                                          src_map=src_map,
                                          step=step,
                                          batch_offset=batch_offset)
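            # Shallow fusion: add the external LM's log-probabilities,
            # scaled by lbda, to the decoder's scores before pruning beams.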
            if self.lm:
                lm_outputs, hidden, _, _ = self.lm(decoder_input.squeeze(-1), hidden, return_h=True)
                lm_logits = self.lm.decoder(lm_outputs)
                lm_log_probs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
                log_probs += lm_log_probs * self.lbda

            vocab_size = log_probs.size(-1)

            if step < min_length:
                log_probs[:, end_token] = -1e20

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

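            # GNMT-style length penalty (Wu et al., 2016):
            # lp(len) = ((5 + len) / 6) ** alpha, with len = step + 1.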
            alpha = self.global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty
            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
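            # Integer division recovers the beam each candidate came from;
            # fmod recovers the token id within the vocab. (On newer PyTorch,
            # the division needs torch.div(..., rounding_mode='floor').)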
            topk_beam_index = topk_ids.div(vocab_size)
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = (
                    topk_beam_index
                    + beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat(
                [alive_seq.index_select(0, select_indices),
                 topk_ids.view(-1, 1)], -1)
            if return_attention:
                current_attn = attn.index_select(1, select_indices)
                if alive_attn is None:
                    alive_attn = current_attn
                else:
                    alive_attn = alive_attn.index_select(1, select_indices)
                    alive_attn = torch.cat([alive_attn, current_attn], 0)

            is_finished = topk_ids.eq(end_token)
            if step + 1 == max_length:
                is_finished.fill_(1)

            # Save finished hypotheses.
            if is_finished.any():
                # Penalize beams that finished.
                topk_log_probs.masked_fill_(is_finished, -1e10)
                is_finished = is_finished.to('cpu')
                top_beam_finished |= is_finished[:, 0].eq(1)
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                attention = (
                    alive_attn.view(
                        alive_attn.size(0), -1, beam_size, alive_attn.size(-1))
                    if alive_attn is not None else None)
                non_finished_batch = []
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append((
                            topk_scores[i, j],
                            predictions[i, j, 1:],  # Ignore start_token.
                            attention[:, i, j, :memory_lengths[i]]
                            if attention is not None else None))
                    # End condition is the top beam finished and we can return
                    # n_best hypotheses.
                    if top_beam_finished[i] and len(hypotheses[b]) >= n_best:
                        best_hyp = sorted(
                            hypotheses[b], key=lambda x: x[0], reverse=True)
                        for n, (score, pred, attn) in enumerate(best_hyp):
                            if n >= n_best:
                                break
                            results["scores"][b].append(score)
                            results["predictions"][b].append(pred)
                            results["attention"][b].append(
                                attn if attn is not None else [])
                    else:
                        non_finished_batch.append(i)
                non_finished = torch.tensor(non_finished_batch)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                top_beam_finished = top_beam_finished.index_select(
                    0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                non_finished = non_finished.to(topk_ids.device)
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                select_indices = batch_index.view(-1)
                alive_seq = predictions.index_select(0, non_finished) \
                    .view(-1, alive_seq.size(-1))
                if alive_attn is not None:
                    alive_attn = attention.index_select(1, non_finished) \
                        .view(alive_attn.size(0),
                              -1, alive_attn.size(-1))

            # Reorder states.
            if isinstance(memory_bank, tuple):
                memory_bank = tuple(x.index_select(1, select_indices)
                                    for x in memory_bank)
            else:
                memory_bank = memory_bank.index_select(1, select_indices)

            memory_lengths = memory_lengths.index_select(0, select_indices)
            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))
            if self.lm:
                hidden = self.lm.map_state(hidden,
                    lambda state, dim: state.index_select(dim, select_indices))
            if src_map is not None:
                src_map = src_map.index_select(1, select_indices)

        return results

    def _translate_batch(self, batch, data):
        # (0) Prep each of the components of the search.
        # And helper method for reducing verbosity.
        beam_size = self.beam_size
        batch_size = batch.batch_size
        data_type = data.data_type
        vocab = self.fields["tgt"].vocab

        # Define a list of tokens to exclude from ngram-blocking
        # exclusion_list = ["<t>", "</t>", "."]
        exclusion_tokens = set([vocab.stoi[t]
                                for t in self.ignore_when_blocking])

        beam = [onmt.translate.Beam(beam_size, n_best=self.n_best,
                                    cuda=self.cuda,
                                    global_scorer=self.global_scorer,
                                    pad=vocab.stoi[inputters.PAD_WORD],
                                    eos=vocab.stoi[inputters.EOS_WORD],
                                    bos=vocab.stoi[inputters.BOS_WORD],
                                    min_length=self.min_length,
                                    stepwise_penalty=self.stepwise_penalty,
                                    block_ngram_repeat=self.block_ngram_repeat,
                                    exclusion_tokens=exclusion_tokens)
                for __ in range(batch_size)]

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths, _ = self._run_encoder(
            batch, data_type)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {}
        results["predictions"] = []
        results["scores"] = []
        results["attention"] = []
        results["batch"] = batch
        if "tgt" in batch.__dict__:
            results["gold_score"] = self._score_target(
                batch, memory_bank, src_lengths, data, batch.src_map
                if data_type == 'text' and self.copy_attn else None)
            self.model.decoder.init_state(src, memory_bank, enc_states)
        else:
            results["gold_score"] = [0] * batch_size

        # (2) Repeat src objects `beam_size` times.
        # We use now  batch_size x beam_size (same as fast mode)
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if data.data_type == 'text' and self.copy_attn else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
        memory_lengths = tile(src_lengths, beam_size)

        # (3) run the decoder to generate sentences, using beam search.
        for i in range(self.max_length):
            if all((b.done() for b in beam)):
                break

            # (a) Construct batch x beam_size nxt words.
            # Get all the pending current beam words and arrange for forward.

            inp = torch.stack([b.get_current_state() for b in beam])
            inp = inp.view(1, -1, 1)

            # (b) Decode and forward
            out, beam_attn = \
                self._decode_and_generate(inp, memory_bank, batch, data,
                                          memory_lengths=memory_lengths,
                                          src_map=src_map, step=i)

            out = out.view(batch_size, beam_size, -1)
            beam_attn = beam_attn.view(batch_size, beam_size, -1)

            # (c) Advance each beam.
            select_indices_array = []
            # Loop over the batch_size number of beam
            for j, b in enumerate(beam):
                b.advance(out[j, :],
                          beam_attn.data[j, :, :memory_lengths[j]])
                select_indices_array.append(
                    b.get_current_origin() + j * beam_size)
            select_indices = torch.cat(select_indices_array)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        # (4) Extract sentences from beam.
        for b in beam:
            n_best = self.n_best
            scores, ks = b.sort_finished(minimum=n_best)
            hyps, attn = [], []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = b.get_hyp(times, k)
                hyps.append(hyp)
                attn.append(att)
            results["predictions"].append(hyps)
            results["scores"].append(scores)
            results["attention"].append(attn)

        return results

    def _lm_rescore(self, results):
        vocab = self.fields["tgt"].vocab
        start_token = vocab.stoi[inputters.BOS_WORD]
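        # N-best rescoring: add an LM log-probability term weighted by lbda
        # to each hypothesis score, then re-sort each beam by the new score.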
        def _rescore(beam, scores):
            def modify_score(sent, score):
                sent = torch.cat([sent.new([start_token]), sent])
                hidden = self.lm.init_hidden(1)
                lm_outputs, hidden, _, _ = self.lm(sent[:-1].unsqueeze(1), hidden, return_h=True)
                lm_logits = self.lm.decoder(lm_outputs)
                lm_log_probs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
                lm_score = lm_log_probs.gather(1, sent[1:].unsqueeze(1)).sum()
                return sent[1:], score + self.lbda * lm_score
            beam, scores = zip(*sorted(list(map(modify_score, beam, scores)), key=lambda x: x[1]))
            return beam, scores
        return zip(*list(map(_rescore, results['predictions'], results['scores'])))

    def _score_target(self, batch, memory_bank, src_lengths, data, src_map):
        tgt_in = inputters.make_features(batch, 'tgt')[:-1]

        log_probs, attn = \
            self._decode_and_generate(tgt_in, memory_bank, batch, data,
                                      memory_lengths=src_lengths,
                                      src_map=src_map)
        tgt_pad = self.fields["tgt"].vocab.stoi[inputters.PAD_WORD]

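        # Zero the pad column so padded gold positions contribute nothing.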
        log_probs[:, :, tgt_pad] = 0
        gold = batch.tgt[1:].unsqueeze(2)
        gold_scores = log_probs.gather(2, gold)
        gold_scores = gold_scores.sum(dim=0).view(-1)

        return gold_scores

    def _ctc_score_target(self, batch, ctc_scores):
        ctc_target = batch.tgt_feat_0[1:]
        in_len = ctc_scores.size(0)
        batch_size = ctc_scores.size(1)
        ctc_target = ctc_target.transpose(0, 1)
        input_lengths = torch.full((batch_size,), in_len, dtype=torch.int32)

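        # Drop padding, EOS, and the '$' silence symbol from the targets;
        # nn.CTCLoss expects unpadded label sequences plus their lengths.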
        padding_idx = self.ctc_vocab.stoi[inputters.PAD_WORD]
        eos_idx = self.ctc_vocab.stoi[inputters.EOS_WORD]
        sil_idx = self.ctc_vocab.stoi['$']
        valid_indices = ctc_target.ne(padding_idx) * \
                        ctc_target.ne(eos_idx) * \
                        ctc_target.ne(sil_idx)
        target_lengths = valid_indices.sum(dim=1).cpu().int()
        ctc_target = ctc_target.masked_select(valid_indices)
        ctc_gold_scores = self.ctc_loss(ctc_scores, ctc_target, input_lengths, target_lengths)
        return ctc_gold_scores

    def _report_score(self, name, score_total, words_total):
        if words_total == 0:
            msg = "%s No words predicted" % (name,)
        else:
            msg = ("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
                name, score_total / words_total,
                name, math.exp(-score_total / words_total)))
        return msg

    def _report_bleu(self, tgt_path):
        import subprocess
        base_dir = os.path.abspath(__file__ + "/../../..")
        # Rollback pointer to the beginning.
        self.out_file.seek(0)
        print()

        res = subprocess.check_output("perl %s/tools/multi-bleu.perl %s"
                                      % (base_dir, tgt_path),
                                      stdin=self.out_file,
                                      shell=True).decode("utf-8")

        msg = ">> " + res.strip()
        return msg

    def _report_rouge(self, tgt_path):
        import subprocess
        path = os.path.split(os.path.realpath(__file__))[0]
        res = subprocess.check_output(
            "python %s/tools/test_rouge.py -r %s -c STDIN"
            % (path, tgt_path),
            shell=True,
            stdin=self.out_file).decode("utf-8")
        msg = res.strip()
        return msg
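
The flattened top-k step in _fast_translate_batch ranks all beam_size * vocab_size continuations of a sentence at once, then recovers each survivor's beam of origin with integer division and its token id with modulo. A small self-contained sketch of that arithmetic (toy sizes and random scores; rounding_mode="floor" is the newer-PyTorch spelling of the integer division used above):

import torch

batch_size, beam_size, vocab_size = 2, 3, 5
log_probs = torch.randn(batch_size * beam_size, vocab_size)

# Flatten each sentence's beams into one row of beam_size * vocab_size scores.
curr_scores = log_probs.reshape(batch_size, beam_size * vocab_size)
topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

# Recover beam origin and true word ids, as the translator's loop does.
beam_origin = torch.div(topk_ids, vocab_size, rounding_mode="floor")
token_ids = topk_ids.fmod(vocab_size)
print(beam_origin)
print(token_ids)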