Example #1 (0 votes)
File: ocr.py — Project: hedinang/ocr
 def translate_beam_search(self, img):
     """Decode an image into a string via beam search (beam width 2, top-1).

     Runs the encoder once, then expands the beam step by step (up to 128
     decoding steps) and decodes the best finished hypothesis with the vocab.
     """
     with torch.no_grad():
         # Encoder pass: memory the decoder attends over at every step.
         memory = self.transformer(img)
         beam = Beam(beam_size=2,
                     min_length=0,
                     n_top=1,
                     ranker=None,
                     start_token_id=1,
                     end_token_id=2)
         for _ in range(128):
             # Current beam state transposed to (T, N) on the model device.
             decoder_input = (beam.get_current_state()
                              .transpose(0, 1)
                              .to(self.device))
             step_output = self.transformer.transformer.forward_decoder(
                 decoder_input, memory)
             # Log-probabilities over the vocabulary for the last position.
             step_log_probs = log_softmax(step_output[:, -1, :].squeeze(0),
                                          dim=-1)
             beam.advance(step_log_probs.cpu())
             if beam.done():
                 break
         _scores, finished = beam.sort_finished(minimum=1)
         hypotheses = [beam.get_hypothesis(t, k) for t, k in finished]
         # Prepend the start token id (1) and drop the trailing end token
         # before handing the ids to the vocabulary decoder.
         token_ids = [1] + [int(tok) for tok in hypotheses[0][:-1]]
         return self.vocab.decode(token_ids)
Example #2 (0 votes)
    def predict_one(self, source, num_candidates=5):
        """Beam-search decode a single source sequence.

        Parameters
        ----------
        source : raw input accepted by ``self.preprocess``.
        num_candidates : int
            Number of hypotheses to return; also used as the beam's ``n_top``
            and as the ``minimum`` for ``beam.sort_finished``.

        Returns
        -------
        list
            Post-processed hypotheses, reversed from the order produced by
            ``beam.sort_finished``. Also stores ``self.attentions`` and
            ``self.hypothesises`` as side effects.
        """
        source_preprocessed = self.preprocess(source)
        # Add a batch dimension: (1, seq_len).
        source_tensor = torch.tensor(source_preprocessed).unsqueeze(0)

        sources_mask = pad_masking(source_tensor, source_tensor.size(1))
        memory_mask = pad_masking(source_tensor, 1)
        memory = self.model.encoder(source_tensor, sources_mask)

        decoder_state = self.model.decoder.init_decoder_state()

        # Replicate the encoder memory once per beam:
        # (beam_size, seq_len, hidden_size).
        memory_beam = memory.detach().repeat(self.beam_size, 1, 1)

        beam = Beam(beam_size=self.beam_size,
                    min_length=0,
                    n_top=num_candidates,
                    ranker=None)

        for _ in range(self.max_length):
            # (beam_size, seq_len=1): last token chosen on each beam.
            new_inputs = beam.get_current_state().unsqueeze(1)
            decoder_outputs, decoder_state = self.model.decoder(
                new_inputs, memory_beam, memory_mask, state=decoder_state)
            # decoder_outputs: (beam_size, target_seq_len=1, vocabulary_size)

            # Attention of the last decoder layer over the source; recorded
            # with each hypothesis by the Beam.
            attention = self.model.decoder.decoder_layers[
                -1].memory_attention_layer.sublayer.attention
            beam.advance(decoder_outputs.squeeze(1), attention)

            # Reorder the cached decoder state to follow the surviving beams.
            decoder_state.beam_update(beam.get_current_origin())

            if beam.done():
                break

        # Scores are not needed here; sort_finished orders the hypotheses.
        _scores, ks = beam.sort_finished(minimum=num_candidates)
        hypothesises, attentions = [], []
        for times, k in ks[:num_candidates]:
            hypothesis, attention = beam.get_hypothesis(times, k)
            hypothesises.append(hypothesis)
            attentions.append(attention)

        self.attentions = attentions
        self.hypothesises = [[token.item() for token in h]
                             for h in hypothesises]
        hs = [self.postprocess(h) for h in self.hypothesises]
        return list(reversed(hs))