def translate_beam_search(self, img):
    """Decode an encoded image into text with a width-2 beam search.

    Runs the transformer encoder once, then advances a ``Beam``
    (start token id 1, end token id 2) for at most 128 decoder steps,
    stopping early when the beam reports it is done. The single best
    finished hypothesis is decoded through ``self.vocab``.
    """
    with torch.no_grad():
        memory = self.transformer(img)
        beam = Beam(beam_size=2, min_length=0, n_top=1, ranker=None,
                    start_token_id=1, end_token_id=2)

        for _ in range(128):
            # Beam state transposed to time-major (T x N) and moved to
            # the model's device before each decoder call.
            step_input = beam.get_current_state().transpose(0, 1).to(self.device)
            step_output = self.transformer.transformer.forward_decoder(step_input, memory)
            # Log-probabilities of the most recent position only.
            step_log_prob = log_softmax(step_output[:, -1, :].squeeze(0), dim=-1)
            beam.advance(step_log_prob.cpu())
            if beam.done():
                break

        _, finished = beam.sort_finished(minimum=1)
        hypotheses = [beam.get_hypothesis(t, k) for t, k in finished]

        # Prepend the start token (id 1); drop the hypothesis' final token
        # (presumably the end token, id 2 — per the Beam configuration).
        token_ids = [1] + [int(tok) for tok in hypotheses[0][:-1]]
        return self.vocab.decode(token_ids)
def predict_one(self, source, num_candidates=5):
    """Beam-search decode a single source sequence.

    Args:
        source: raw source input; converted to ids by ``self.preprocess``.
        num_candidates: number of finished hypotheses to return.

    Returns:
        List of postprocessed hypotheses — the output of
        ``beam.sort_finished`` truncated to ``num_candidates`` and
        reversed.

    Side effects:
        Stores per-candidate data on ``self.attentions`` and
        ``self.hypothesises``.
    """
    source_preprocessed = self.preprocess(source)
    # unsqueeze(0) adds a leading batch dimension of 1 — the encoder and
    # the (beam_size, seq_len, hidden) repeat below both assume it.
    source_tensor = torch.tensor(source_preprocessed).unsqueeze(0)
    # NOTE: the original also built an unused ``length_tensor`` here;
    # removed as dead code.

    sources_mask = pad_masking(source_tensor, source_tensor.size(1))
    memory_mask = pad_masking(source_tensor, 1)
    memory = self.model.encoder(source_tensor, sources_mask)

    decoder_state = self.model.decoder.init_decoder_state()

    # Replicate the encoder memory once per beam:
    # (beam_size, seq_len, hidden_size).
    memory_beam = memory.detach().repeat(self.beam_size, 1, 1)

    beam = Beam(beam_size=self.beam_size, min_length=0,
                n_top=num_candidates, ranker=None)

    for _ in range(self.max_length):
        new_inputs = beam.get_current_state().unsqueeze(1)  # (beam_size, 1)
        decoder_outputs, decoder_state = self.model.decoder(
            new_inputs, memory_beam, memory_mask, state=decoder_state)
        # decoder_outputs: (beam_size, target_seq_len=1, vocabulary_size)
        # Attention of the last decoder layer over the encoder memory.
        attention = self.model.decoder.decoder_layers[
            -1].memory_attention_layer.sublayer.attention

        beam.advance(decoder_outputs.squeeze(1), attention)
        # Re-align the cached decoder state with the surviving beams.
        decoder_state.beam_update(beam.get_current_origin())
        if beam.done():
            break

    # ``sort_finished`` also returns scores; they were never used.
    _, ks = beam.sort_finished(minimum=num_candidates)
    hypothesises, attentions = [], []
    for times, k in ks[:num_candidates]:
        hypothesis, attn = beam.get_hypothesis(times, k)
        hypothesises.append(hypothesis)
        attentions.append(attn)

    self.attentions = attentions
    self.hypothesises = [[token.item() for token in h] for h in hypothesises]
    hs = [self.postprocess(h) for h in self.hypothesises]
    return list(reversed(hs))