Example #1
  def translate_batch(self, batch):
    def get_inst_idx_to_tensor_position_map(inst_idx_list):
      ''' Map each instance index to its position in the tensor. '''
      return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)}
    
    def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
      ''' Collect tensor parts associated to active instances. '''

      _, *d_hs = beamed_tensor.size()
      n_curr_active_inst = len(curr_active_inst_idx)
      new_shape = (n_curr_active_inst * n_bm, *d_hs)

      beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
      beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
      beamed_tensor = beamed_tensor.view(*new_shape)

      return beamed_tensor
    
    def beam_decode_step(
      inst_dec_beams, len_dec_seq, inst_idx_to_position_map, n_bm):
      ''' Decode one step, update beam status, and return the indices of active beams. '''
      # len_dec_seq: i (starting from 0)

      def prepare_beam_dec_seq(inst_dec_beams):
        dec_seq = [b.get_last_target_word() for b in inst_dec_beams if not b.done]
        # dec_seq: [(beam_size)] * n_active_inst
        dec_seq = torch.stack(dec_seq).to(self.device)
        # dec_seq: (n_active_inst, beam_size)
        dec_seq = dec_seq.view(1, -1)
        # dec_seq: (1, n_active_inst * beam_size)
        return dec_seq

      def predict_word(dec_seq, n_active_inst, n_bm, len_dec_seq):
        # dec_seq: (1, n_active_inst * beam_size)
        dec_output, *_ = self.model.decoder(dec_seq, step=len_dec_seq)
        # dec_output: (1, n_active_inst * beam_size, hid_size)
        word_prob = self.model.generator(dec_output.squeeze(0))
        # word_prob: (n_active_inst * beam_size, vocab_size)

        word_prob = word_prob.view(n_active_inst, n_bm, -1)
        # word_prob: (n_active_inst, beam_size, vocab_size)

        return word_prob

      def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
        active_inst_idx_list = []
        select_indices_array = []
        for inst_idx, inst_position in inst_idx_to_position_map.items():
          is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
          if not is_inst_complete:
            active_inst_idx_list += [inst_idx]
            select_indices_array.append(inst_beams[inst_idx].get_current_origin() + inst_position * n_bm)
        if len(select_indices_array) > 0:
          select_indices = torch.cat(select_indices_array)
        else:
          select_indices = None
        return active_inst_idx_list, select_indices

      n_active_inst = len(inst_idx_to_position_map)

      dec_seq = prepare_beam_dec_seq(inst_dec_beams)
      # dec_seq: (1, n_active_inst * beam_size)
      word_prob = predict_word(dec_seq, n_active_inst, n_bm, len_dec_seq)

      # Update the beam with predicted word prob information and collect incomplete instances
      active_inst_idx_list, select_indices = collect_active_inst_idx_list(
        inst_dec_beams, word_prob, inst_idx_to_position_map)
      
      if select_indices is not None:
        assert len(active_inst_idx_list) > 0
        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))
      return active_inst_idx_list
    
    def collate_active_info(
        src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list):
      # Sentences which are still active are collected,
      # so the decoder will not run on completed sentences.
      n_prev_active_inst = len(inst_idx_to_position_map)
      active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
      active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

      active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm)
      active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm)
      active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

      return active_src_seq, active_src_enc, active_inst_idx_to_position_map

    def collect_best_hypothesis_and_score(inst_dec_beams):
      hyps, scores = [], []
      for inst_idx in range(len(inst_dec_beams)):
        hyp, score = inst_dec_beams[inst_idx].get_best_hypothesis()
        hyps.append(hyp)
        scores.append(score)
        
      return hyps, scores
    
    with torch.no_grad():
      #-- Encode
      src_seq = make_features(batch, 'src')
      # src_seq: (seq_len_src, batch_size)
      src_emb, src_enc, _ = self.model.encoder(src_seq)
      # src_emb: (seq_len_src, batch_size, emb_size)
      # src_enc: (seq_len_src, batch_size, hid_size)

      self.model.decoder.init_state(src_seq, src_enc)
      src_len = src_seq.size(0)
      
      #-- Repeat data for beam search
      n_bm = self.beam_size
      n_inst = src_seq.size(1)
      self.model.decoder.map_state(lambda state, dim: tile(state, n_bm, dim=dim))
      # src_enc: (seq_len_src, batch_size * beam_size, hid_size)
      
      #-- Prepare beams
      decode_length = src_len + self.decode_extra_length
      decode_min_length = 0
      if self.decode_min_length >= 0:
        decode_min_length = src_len - self.decode_min_length
      if self.task_type == 'task':
        bos_id, eos_id = self.tgt_bos_id, self.tgt_eos_id
      else:
        bos_id, eos_id = self.tgt2_bos_id, self.tgt2_eos_id
      inst_dec_beams = [
        Beam(n_bm, decode_length=decode_length,
             minimal_length=decode_min_length,
             minimal_relative_prob=self.minimal_relative_prob,
             bos_id=bos_id, eos_id=eos_id, device=self.device)
        for _ in range(n_inst)]

      #-- Bookkeeping for active or not
      active_inst_idx_list = list(range(n_inst))
      inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
      
      #-- Decode
      for len_dec_seq in range(0, decode_length):
        active_inst_idx_list = beam_decode_step(
          inst_dec_beams, len_dec_seq, inst_idx_to_position_map, n_bm)
        
        if not active_inst_idx_list:
          break  # all instances have finished their path to <EOS>

        inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
        
    batch_hyps, batch_scores = collect_best_hypothesis_and_score(inst_dec_beams)
    return batch_hyps, batch_scores
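
The bookkeeping trick this example leans on is collect_active_part: a (n_inst * n_bm, ...) tensor is folded into one row per instance, the rows of finished instances are dropped with index_select, and the result is unfolded back. A minimal, self-contained sketch of that reshaping (the toy sizes and variable names here are illustrative, not from the original code):

import torch

# Hypothetical toy sizes: 3 instances, beam size 2, hidden size 4.
n_inst, n_bm, hid = 3, 2, 4
beamed = torch.arange(n_inst * n_bm * hid, dtype=torch.float32)
beamed = beamed.view(n_inst * n_bm, hid)    # rows 0-1: inst 0, rows 2-3: inst 1, rows 4-5: inst 2
active = torch.LongTensor([0, 2])           # instances 0 and 2 are still active

_, *d_hs = beamed.size()                    # trailing dims, here [4]
flat = beamed.view(n_inst, -1)              # (n_inst, n_bm * hid): one row per instance
kept = flat.index_select(0, active)         # keep only the rows of active instances
kept = kept.view(len(active) * n_bm, *d_hs) # back to (n_active * n_bm, hid)
print(kept.shape)                           # torch.Size([4, 4])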
      
Example #2
    def translate_batch(self,
                        src,
                        src_len,
                        tgt,
                        max_length,
                        min_length=0,
                        ratio=0.,
                        n_best=1,
                        return_attention=False):

        with torch.no_grad():
            # (0) Prep the components of the search.
            beam_size = self.beam_size
            batch_size = src.size(1)

            # (1) pt 1, Run the encoder on the src.
            enc_states, memory_bank, src_lengths = self.model.encoder(
                src, src_len)
            # (1) pt 2, Convert encoder state to decoder state
            enc_states = self.model.bridge(enc_states)
            # (1) pt 3, Project the memory bank to the decoder's hidden size
            memory_bank = self.model.W(memory_bank)
            self.model.decoder.init_state(enc_states)

            results = {
                "predictions": None,
                "scores": None,
                "attention": None,
                "gold_score": self._gold_score(tgt, enc_states, memory_bank, src_lengths) \
                    if tgt is not None else None
            }

            # (2) Repeat src objects `beam_size` times.
            # We use batch_size x beam_size
            self.model.decoder.map_state(
                lambda state, dim: tile(state, beam_size, dim=dim))

            memory_bank = tile(memory_bank, beam_size, dim=1)
            memory_lengths = tile(src_lengths, beam_size)

            # (0) pt 2, prep the beam object
            beam = BeamSearch(beam_size=beam_size,
                              batch_size=batch_size,
                              pad=self._tgt_pad_idx,
                              bos=self._tgt_bos_idx,
                              eos=self._tgt_eos_idx,
                              n_best=n_best,
                              device=self._dev,
                              min_length=min_length,
                              max_length=max_length,
                              return_attention=return_attention,
                              block_ngram_repeat=self.block_ngram_repeat,
                              exclusion_tokens=self._exclusion_idxs,
                              memory_lengths=memory_lengths,
                              ratio=ratio)

            for step in range(max_length):
                decoder_input = beam.current_predictions.view(1, -1, 1)

                log_probs, attn = self.model.decoder(
                    decoder_input, memory_bank, memory_lengths=memory_lengths)
                log_probs = log_probs.squeeze(0)

                beam.advance(log_probs, attn)
                any_beam_is_finished = beam.is_finished.any()
                if any_beam_is_finished:
                    beam.update_finished()
                    if beam.done:
                        break

                select_indices = beam.current_origin

                if any_beam_is_finished:
                    # Reorder states.
                    memory_bank = memory_bank.index_select(1, select_indices)
                    memory_lengths = memory_lengths.index_select(
                        0, select_indices)

                self.model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))

            results["scores"] = beam.scores
            results["predictions"] = beam.predictions
            results["attention"] = beam.attention
            return results
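
Both examples repeat the encoder state beam_size times with a tile helper whose implementation is not shown here. A possible stand-in (an assumption; the original may differ) built on torch.repeat_interleave, which keeps the copies for one batch entry adjacent and so matches the batch_size x beam_size layout the comments describe:

import torch

def tile(x, count, dim=0):
    """Repeat every slice of x `count` times along `dim`, keeping the
    copies for the same batch entry adjacent (beam layout: b0 b0 b1 b1 ...)."""
    return torch.repeat_interleave(x, count, dim=dim)

memory_bank = torch.randn(7, 3, 16)   # (seq_len, batch_size, hid_size)
tiled = tile(memory_bank, 2, dim=1)   # (seq_len, batch_size * beam_size, hid_size)
print(tiled.shape)                    # torch.Size([7, 6, 16])

With this layout, reordering the surviving beams is exactly the index_select on dim 1 seen in the loop above: beam.current_origin indexes back into the flattened batch_size * beam_size axis.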