Example #1
import torch
from torch.nn.functional import log_softmax


# Beam is an external beam-search helper; its import is project-specific.
def beamsearch(memory, model, device, beam_size=4, candidates=1, max_seq_length=128, bos_token=1, eos_token=2):
    # memory: Tx1xE (source length x batch x embedding)
    model.eval()

    beam = Beam(beam_size=beam_size, min_length=0, n_top=candidates, ranker=None, start_token_id=bos_token,
                end_token_id=eos_token)

    with torch.no_grad():
        # replicate the encoder memory across the beam: Tx1xE -> TxNxE (N = beam_size)
        memory = model.SequenceModeling.expand_memory(memory, beam_size)

        for _ in range(max_seq_length):
            # current partial hypotheses for all beam entries, transposed to TxN for the decoder
            tgt_inp = beam.get_current_state().transpose(0, 1).to(device)  # TxN
            decoder_outputs, memory = model.SequenceModeling.forward_decoder(tgt_inp, memory)

            # score only the newest time step and advance the beam with its log-probabilities
            log_prob = log_softmax(decoder_outputs[:, -1, :].squeeze(0), dim=-1)
            beam.advance(log_prob.cpu())

            if beam.done():
                break

        # rank the finished hypotheses and collect the top `candidates` of them
        scores, ks = beam.sort_finished(minimum=1)

        hypotheses = []
        for times, k in ks[:candidates]:
            hypotheses.append(beam.get_hypothesis(times, k))

    # prepend the BOS token and drop the final (EOS) token of the best hypothesis
    return [bos_token] + [int(i) for i in hypotheses[0][:-1]]
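
A minimal usage sketch, assuming model and memory (the Tx1xE encoder output for a single input) already exist in the surrounding project; the variable names here are illustrative assumptions, not part of the example above.

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# memory: encoder output for one input, shape Tx1xE (assumed to be precomputed)
token_ids = beamsearch(memory, model, device, beam_size=4, candidates=1,
                       max_seq_length=128, bos_token=1, eos_token=2)
# token_ids starts with the BOS id; the trailing EOS id has been stripped
print(token_ids)
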
Example #2
    def translate(self, src, trg, beam_size, Lang2):
        '''Beam-search decoding.

        :param src:  [src_max_len, batch]  (batch = 1)
        :param trg:  [trg_max_len, batch]  (batch = 1)
        :param beam_size: number of hypotheses kept per decoding step
        :param Lang2: target-language field providing vocab.stoi
        :return: best translation candidate(s) found by the beam search
        '''
        max_len = trg.size(0)
        encoder_output, hidden = self.encoder(src)
        # src: [src_max_len, batch]
        # encoder_output: [src_max_len, batch, hidden_size]
        # hidden: [num_layers * num_directions, batch, hidden_size] -> [2, batch, hidden_size]

        # keep only the layers the decoder expects
        hidden = hidden[:self.decoder.n_layers]  # [n_layers, batch, hidden_size]

        # trg: [trg_max_len, batch]; the first target token is <sos>
        output = Variable(trg.data[0, :])  # [batch]

        beam = Beam(beam_size, Lang2.vocab.stoi, True)  # external Beam helper over the target vocabulary
        input_feeding = None
        for t in range(1, max_len):
            # output:  [batch] -> [batch, output_size]
            output, hidden, attn_weights = self.decoder(
                output, hidden, encoder_output, input_feeding)

            # keep the raw decoder output as input feeding for the next step,
            # then project to the vocabulary and take log-probabilities
            input_feeding = output
            output = self.decoder.out(output)
            output = F.log_softmax(output, dim=1)

            word_lk = output
            if output.size(0) == 1:
                # first decoding step: the batch dimension is still 1, so tile the
                # scores, decoder state, encoder output, and input feeding across the beam
                output_prob = output.squeeze(0)  # [output_size]
                word_lk = output_prob.expand(beam_size, output_prob.size(0))  # [beam_size, output_size]

                # hidden: [n_layers, batch, hidden_size] -> [n_layers, beam_size, hidden_size]
                hidden = hidden.squeeze(1)  # [n_layers, hidden_size]
                hidden = hidden.expand(beam_size, hidden.size(0), hidden.size(1))  # [beam_size, n_layers, hidden_size]
                hidden = hidden.transpose(0, 1)  # [n_layers, beam_size, hidden_size]

                # encoder_output: [src_max_len, batch, hidden_size] -> [src_max_len, beam_size, hidden_size]
                encoder_output = encoder_output.squeeze(1)  # [src_max_len, hidden_size]
                encoder_output = encoder_output.expand(beam_size, encoder_output.size(0), encoder_output.size(1))
                encoder_output = encoder_output.transpose(0, 1)  # [src_max_len, beam_size, hidden_size]

                # input_feeding: [1, hidden_size] -> [beam_size, hidden_size]
                input_feeding = input_feeding.squeeze(0)
                input_feeding = input_feeding.expand(beam_size, input_feeding.size(0))

            # advance() returns True once the beam has finished decoding
            flag = beam.advance(word_lk)
            if flag:
                break

            # the surviving beam entries become the next decoder input
            output = beam.get_current_state()

            # reorder the recurrent state and input feeding along the beam's back-pointers
            originState = beam.get_current_origin()
            hidden = hidden[:, originState]
            input_feeding = input_feeding[originState]

        # best-scoring result and the full list of finished candidates from the Beam helper
        xx, yy = beam.get_best()
        zz = beam.get_final()
        return xx, yy, zz
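
The key bookkeeping step in Example #2 is hidden = hidden[:, originState]: after every beam.advance, the recurrent state is reindexed along the beam dimension using the beam's back-pointers, so each surviving hypothesis carries the state of the parent it was extended from. Below is a small self-contained sketch of that reordering, with made-up shapes and back-pointers.

import torch

beam_size, n_layers, hidden_size = 3, 2, 4
hidden = torch.randn(n_layers, beam_size, hidden_size)

# suppose the beam decided that new entries 0, 1, 2 descend from old entries 2, 0, 0
origin = torch.tensor([2, 0, 0])

# index along the beam dimension so every surviving hypothesis carries
# the recurrent state of the parent it was extended from
reordered = hidden[:, origin]
assert reordered.shape == (n_layers, beam_size, hidden_size)
assert torch.equal(reordered[:, 1], hidden[:, 0])
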