Example #1
    def beam_search(model,
                    batch,
                    beam_size,
                    start=1,
                    end=2,
                    pad=0,
                    min_length=3,
                    min_n_best=5,
                    max_ts=40,
                    block_ngram=0):
        """ Beam search given the model and Batch
        This function uses model with the following reqs:
        - model.encoder takes input returns tuple (enc_out, enc_hidden, attn_mask)
        - model.decoder takes decoder params and returns decoder outputs after attn
        - model.output takes decoder outputs and returns distr over dictionary

        Function arguments:
        model : nn.Module, here defined in modules.py
        batch : Batch structure with input and labels
        beam_size : size of each beam during the search
        start : start-of-sequence token
        end : end-of-sequence token
        pad : padding token
        min_length : minimum length of the decoded sequence
        min_n_best : minimum number of completed hypotheses generated from
                     each beam
        max_ts : maximum length of the decoded sequence
        block_ngram : if > 0, penalize hypotheses that repeat an n-gram of
                      this size (0 disables blocking)

        Return:
        beam_preds_scores : list of (prediction, score) tuples, one for each
                            sample in the batch
        n_best_preds_scores : list of n-best lists of (prediction, score)
                              tuples for each sample in the batch
        beams : list of Beam instances defined in the Beam class; can be used
                for any following postprocessing, e.g. dot logging
        """
        encoder_states = model.encoder(batch.text_vec)
        enc_out = encoder_states[0]
        enc_hidden = encoder_states[1]
        attn_mask = encoder_states[2]
        current_device = enc_out.device

        batch_size = len(batch.text_lengths)
        beams = [
            Beam(beam_size,
                 min_length=min_length,
                 padding_token=pad,
                 bos_token=start,
                 eos_token=end,
                 min_n_best=min_n_best,
                 cuda=current_device,
                 block_ngram=block_ngram) for i in range(batch_size)
        ]
        decoder_input = torch.Tensor([start]).detach().expand(
            batch_size, 1).long().to(current_device)
        # repeat encoder_outputs, hiddens, attn_mask
        decoder_input = decoder_input.repeat(1, beam_size).view(
            beam_size * batch_size, -1)
        enc_out = enc_out.unsqueeze(1).repeat(1, beam_size, 1,
                                              1).view(batch_size * beam_size,
                                                      -1, enc_out.size(-1))
        attn_mask = attn_mask.repeat(1, beam_size).view(
            attn_mask.size(0) * beam_size, -1)
        repeated_hiddens = []
        if isinstance(enc_hidden, tuple):  # LSTM
            for i in range(len(enc_hidden)):
                repeated_hiddens.append(enc_hidden[i].unsqueeze(2).repeat(
                    1, 1, beam_size, 1))
            num_layers = enc_hidden[0].size(0)
            hidden_size = enc_hidden[0].size(-1)
            enc_hidden = tuple([
                repeated_hiddens[i].view(num_layers, batch_size * beam_size,
                                         hidden_size)
                for i in range(len(repeated_hiddens))
            ])
        else:  # GRU
            num_layers = enc_hidden.size(0)
            hidden_size = enc_hidden.size(-1)
            enc_hidden = enc_hidden.unsqueeze(2).repeat(
                1, 1, beam_size, 1).view(num_layers, batch_size * beam_size,
                                         hidden_size)

        hidden = enc_hidden
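        # main decode loop: each step advances every beam by one token and
        # reorders the recurrent state to follow the surviving hypotheses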
        for ts in range(max_ts):
            if all((b.done() for b in beams)):
                break
            output, hidden = model.decoder(decoder_input, hidden,
                                           (enc_out, attn_mask))
            score = model.output(output)
            # score holds unnormalized vocabulary scores for
            # batch_size * beam_size samples; normalize to log-probs below
            score = score.view(batch_size, beam_size, -1)
            score = F.log_softmax(score, dim=-1)
            for i, b in enumerate(beams):
                b.advance(score[i])
            decoder_input = torch.cat([
                b.get_output_from_current_step() for b in beams
            ]).unsqueeze(-1)
            permute_hidden_idx = torch.cat([
                beam_size * i + b.get_backtrack_from_current_step()
                for i, b in enumerate(beams)
            ])
            # permute decoder hiddens with respect to chosen hypothesis now
            if isinstance(hidden, tuple):  # LSTM
                for i in range(len(hidden)):
                    hidden[i].data.copy_(hidden[i].data.index_select(
                        dim=1, index=permute_hidden_idx))
            else:  # GRU
                hidden.data.copy_(
                    hidden.data.index_select(dim=1, index=permute_hidden_idx))
        for b in beams:
            b.check_finished()

        beam_preds_scores = [list(b.get_top_hyp()) for b in beams]
        for pair in beam_preds_scores:
            pair[0] = Beam.get_pretty_hypothesis(pair[0])

        n_best_beams = [
            b.get_rescored_finished(n_best=min_n_best) for b in beams
        ]
        n_best_beam_preds_scores = []
        for i, beamhyp in enumerate(n_best_beams):
            this_beam = []
            for hyp in beamhyp:
                pred = beams[i].get_pretty_hypothesis(
                    beams[i].get_hyp_from_finished(hyp))
                score = hyp.score
                this_beam.append((pred, score))
            n_best_beam_preds_scores.append(this_beam)

        return beam_preds_scores, n_best_beam_preds_scores, beams
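A note on the replication above: `repeat` followed by `view` must keep the beam_size copies of each sample contiguous in the flattened batch, otherwise the `beam_size * i + backtrack` indexing inside the loop would read another sample's states. A minimal self-contained sketch of that layout check (plain PyTorch; the shapes are illustrative, not taken from the code above):

    import torch

    batch_size, beam_size, seq_len, hid = 2, 3, 4, 5
    enc_out = torch.arange(batch_size * seq_len * hid, dtype=torch.float)
    enc_out = enc_out.view(batch_size, seq_len, hid)

    # replicate each sample beam_size times, keeping one sample's beams contiguous
    rep = enc_out.unsqueeze(1).repeat(1, beam_size, 1, 1)
    rep = rep.view(batch_size * beam_size, seq_len, hid)

    # rows [i * beam_size, (i + 1) * beam_size) are all copies of sample i
    for i in range(batch_size):
        assert (rep[i * beam_size:(i + 1) * beam_size] == enc_out[i]).all()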
Example #2
    def translate_batch(self, src_seq, src_pos):
        ''' Translate one batch. '''
        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            ''' Indicate the position of an instance in a tensor. '''
            return {
                inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)
            }

        def collect_active_part(beamed_tensor, curr_active_inst_idx,
                                n_prev_active_inst, n_bm):
            ''' Collect tensor parts associated to active instances. '''

            _, *d_hs = beamed_tensor.size()
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, *d_hs)

            beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
            beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
            beamed_tensor = beamed_tensor.view(*new_shape)

            return beamed_tensor

        def collate_active_info(src_seq, src_enc, inst_idx_to_position_map,
                                active_inst_idx_list, device):
            # Sentences which are still active are collected,
            # so the decoder will not run on completed sentences.
            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [
                inst_idx_to_position_map[k] for k in active_inst_idx_list
            ]
            active_inst_idx = torch.LongTensor(active_inst_idx).to(device)

            active_src_seq = collect_active_part(src_seq, active_inst_idx,
                                                 n_prev_active_inst, n_bm)
            active_src_enc = collect_active_part(src_enc, active_inst_idx,
                                                 n_prev_active_inst, n_bm)
            active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            return active_src_seq, active_src_enc, active_inst_idx_to_position_map

        def beam_decode_step(inst_dec_beams, len_dec_seq, src_seq, enc_output,
                             inst_idx_to_position_map, n_bm, device):
            ''' Decode and update beam status, and then return active beam idx '''
            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq, device):
                dec_partial_seq = [
                    b.get_current_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = torch.stack(dec_partial_seq).to(device)
                dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
                return dec_partial_seq

            def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm, device):
                dec_partial_pos = torch.arange(1,
                                               len_dec_seq + 1,
                                               dtype=torch.long,
                                               device=device)
                dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(
                    n_active_inst * n_bm, 1)
                return dec_partial_pos

            def predict_word(dec_seq, dec_pos, src_seq, enc_output,
                             n_active_inst, n_bm):
                dec_output, *_ = self.decoder(dec_seq, dec_pos, src_seq,
                                              enc_output)
                # pick the last step: (n_active_inst * n_bm) x d_h
                dec_output = dec_output[:, -1, :]
                word_prob = F.log_softmax(self.tgt_word_prj(dec_output), dim=1)
                word_prob = word_prob.view(n_active_inst, n_bm, -1)

                return word_prob

            def collect_active_inst_idx_list(inst_beams, word_prob,
                                             inst_idx_to_position_map):
                active_inst_idx_list = []
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    is_inst_complete = inst_beams[inst_idx].advance(
                        word_prob[inst_position])
                    if not is_inst_complete:
                        active_inst_idx_list += [inst_idx]

                return active_inst_idx_list

            n_active_inst = len(inst_idx_to_position_map)

            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq, device)
            dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm,
                                           device)
            word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output,
                                     n_active_inst, n_bm)

            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)

            return active_inst_idx_list

        def collect_hypothesis_and_scores(inst_dec_beams, n_best):
            all_hyp, all_scores = [], []
            for inst_idx in range(len(inst_dec_beams)):
                scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
                all_scores += [scores[:n_best]]

                hyps = [
                    inst_dec_beams[inst_idx].get_hypothesis(i)
                    for i in tail_idxs[:n_best]
                ]
                all_hyp += [hyps]
            return all_hyp, all_scores

        with torch.no_grad():
            #-- Encode
            #src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
            src_enc, *_ = self.encoder(src_seq, src_pos)

            device = src_seq.device
            #-- Repeat data for beam search
            n_bm = self.opt['beam_size']
            n_inst, len_s, d_h = src_enc.size()
            src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
            src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s,
                                                      d_h)

            #-- Prepare beams
            inst_dec_beams = [
                Beam(n_bm, device=device) for _ in range(n_inst)
            ]

            #-- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            #-- Decode
            for len_dec_seq in range(1, self.opt['max_token_seq_len'] + 1):

                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, src_seq, src_enc,
                    inst_idx_to_position_map, n_bm, device)

                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>

                src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
                    src_seq, src_enc, inst_idx_to_position_map,
                    active_inst_idx_list, device)

        n_best = 1
        batch_hyp, batch_scores = collect_hypothesis_and_scores(
            inst_dec_beams, n_best)

        return batch_hyp, batch_scores
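Where Example #1 keeps the full batch_size * beam_size layout for the entire decode, this implementation shrinks its tensors to the active instances after every step, so the decoder never runs on finished sentences. The shrink in `collect_active_part` is just a `view` plus `index_select` over the instance dimension; a small self-contained sketch of the same move (plain PyTorch; shapes invented for illustration):

    import torch

    n_inst, n_bm, len_s = 4, 2, 3
    src_seq = torch.arange(n_inst * n_bm * len_s).view(n_inst * n_bm, len_s)

    # suppose instances 0 and 2 are still active; drop the finished ones
    active_idx = torch.LongTensor([0, 2])
    active = src_seq.view(n_inst, -1).index_select(0, active_idx)
    active = active.view(len(active_idx) * n_bm, len_s)

    # the decoder now runs on 2 instances instead of 4
    assert active.size() == (2 * n_bm, len_s)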
Example #3
    def forward(self, xs, ys=None, cands=None, valid_cands=None, prev_enc=None,
                rank_during_training=False, beam_size=1, topk=1):
        """Get output predictions from the model.

        Arguments:
        xs -- input to the encoder
        ys -- expected output from the decoder
        cands -- set of candidates to rank, if applicable
        valid_cands -- indices to match candidates with their appropriate xs
        prev_enc -- if you know you'll pass in the same xs multiple times and
            the model is in eval mode, you can pass in the encoder output from
            the last forward pass to skip recalcuating the same encoder output
        rank_during_training -- (default False) if set, ranks any available
            cands during training as well
        """
        input_xs = xs
        nbest_beam_preds, nbest_beam_scores = None, None
        bsz = len(xs)
        if ys is not None:
            # keep track of longest label we've ever seen
            # we'll never produce longer ones than that during prediction
            self.longest_label = max(self.longest_label, ys.size(1))

        if prev_enc is not None:
            enc_out, hidden, attn_mask = prev_enc
        else:
            enc_out, hidden = self.encoder(xs)
            attn_mask = xs.ne(0).float() if self.attn_type != 'none' else None
        encoder_states = (enc_out, hidden, attn_mask)
        start = self.START.detach()
        starts = start.expand(bsz, 1)

        predictions = []
        scores = []
        cand_preds, cand_scores = None, None
        if self.rank and cands is not None:
            decode_params = (start, hidden, enc_out, attn_mask)
            if self.training:
                if rank_during_training:
                    cand_preds, cand_scores = self.ranker.forward(cands, valid_cands, decode_params=decode_params)
            else:
                cand_preds, cand_scores = self.ranker.forward(cands, valid_cands, decode_params=decode_params)

        if ys is not None:
            y_in = ys.narrow(1, 0, ys.size(1) - 1)
            xs = torch.cat([starts, y_in], 1)
            if self.attn_type == 'none':
                preds, score, hidden = self.decoder(xs, hidden, enc_out, attn_mask)
                predictions.append(preds)
                scores.append(score)
            else:
                for i in range(ys.size(1)):
                    xi = xs.select(1, i)
                    preds, score, hidden = self.decoder(xi, hidden, enc_out, attn_mask)
                    predictions.append(preds)
                    scores.append(score)
        else:
            # here we run search; supported search types: greedy and beam search
            if beam_size == 1:
                done = [False for _ in range(bsz)]
                total_done = 0
                xs = starts

                for _ in range(self.longest_label):
                    # generate at most longest_label tokens
                    preds, score, hidden = self.decoder(xs, hidden, enc_out, attn_mask, topk)
                    scores.append(score)
                    xs = preds
                    predictions.append(preds)

                    # check if we've produced the end token
                    for b in range(bsz):
                        if not done[b]:
                            # only add more tokens for examples that aren't done
                            if preds.data[b][0] == self.END_IDX:
                                # if we produced END, we're done
                                done[b] = True
                                total_done += 1
                    if total_done == bsz:
                        # no need to generate any more
                        break

            elif beam_size > 1:
                enc_out, hidden = encoder_states[0], encoder_states[1]  # take it from encoder
                enc_out = enc_out.unsqueeze(1).repeat(1, beam_size, 1, 1)
                # create batch size num of beams
                data_device = enc_out.device
                # Beam(beam_size, min_length, padding_token, bos_token, eos_token)
                beams = [Beam(beam_size, 3, 0, 1, 2, min_n_best=beam_size // 2,
                              cuda=data_device) for _ in range(bsz)]
                # init the input with start token
                xs = starts
                # repeat tensors to support batched beam
                xs = xs.repeat(1, beam_size)
                attn_mask = input_xs.ne(0).float()
                attn_mask = attn_mask.unsqueeze(1).repeat(1, beam_size, 1)
                repeated_hidden = []

                if isinstance(hidden, tuple):
                    for i in range(len(hidden)):
                        repeated_hidden.append(hidden[i].unsqueeze(2).repeat(1, 1, beam_size, 1))
                    hidden = self.unbeamize_hidden(tuple(repeated_hidden), beam_size, bsz)
                else:  # GRU
                    repeated_hidden = hidden.unsqueeze(2).repeat(1, 1, beam_size, 1)
                    hidden = self.unbeamize_hidden(repeated_hidden, beam_size, bsz)
                enc_out = self.unbeamize_enc_out(enc_out, beam_size, bsz)
                xs = xs.view(bsz * beam_size, -1)
                for step in range(self.longest_label):
                    if all((b.done() for b in beams)):
                        break
                    out = self.decoder(xs, hidden, enc_out)
                    scores = out[1]
                    scores = scores.view(bsz, beam_size, -1)  # -1 is the vocab size
                    for i, b in enumerate(beams):
                        b.advance(F.log_softmax(scores[i, :], dim=-1))
                    xs = torch.cat([b.get_output_from_current_step() for b in beams]).unsqueeze(-1)
                    permute_hidden_idx = torch.cat(
                        [beam_size * i + b.get_backtrack_from_current_step() for i, b in enumerate(beams)])
                    new_hidden = out[2]
                    if isinstance(hidden, tuple):
                        for i in range(len(hidden)):
                            hidden[i].data.copy_(new_hidden[i].data.index_select(dim=1, index=permute_hidden_idx))
                    else:  # GRU
                        hidden.data.copy_(new_hidden.data.index_select(dim=1, index=permute_hidden_idx))

                for b in beams:
                    b.check_finished()
                beam_pred = [b.get_pretty_hypothesis(b.get_top_hyp()[0])[1:] for b in beams]
                # these beam scores are rescored with length penalty!
                beam_scores = torch.stack([b.get_top_hyp()[1] for b in beams])
                pad_length = max([t.size(0) for t in beam_pred])
                beam_pred = torch.stack([pad(t, length=pad_length, dim=0) for t in beam_pred], dim=0)

                #  prepare n best list for each beam
                n_best_beam_tails = [b.get_rescored_finished(n_best=len(b.finished)) for b in beams]
                nbest_beam_scores = []
                nbest_beam_preds = []
                for i, beamtails in enumerate(n_best_beam_tails):
                    perbeam_preds = []
                    perbeam_scores = []
                    for tail in beamtails:
                        perbeam_preds.append(beams[i].get_pretty_hypothesis(beams[i].get_hyp_from_finished(tail)))
                        perbeam_scores.append(tail.score)
                    nbest_beam_scores.append(perbeam_scores)
                    nbest_beam_preds.append(perbeam_preds)

                if self.beam_log_freq > 0.0:
                    num_dump = round(bsz * self.beam_log_freq)
                    for i in range(num_dump):
                        dot_graph = beams[i].get_beam_dot(dictionary=self.dict)
                        dot_graph.write_png(os.path.join(self.beam_dump_path, "{}.png".format(self.beam_dump_filecnt)))
                        self.beam_dump_filecnt += 1

                predictions = beam_pred
                scores = beam_scores

        if isinstance(predictions, list):
            predictions = torch.cat(predictions, 1)
        if isinstance(scores, list):
            scores = torch.cat(scores, 1)

        return predictions, scores, cand_preds, cand_scores, encoder_states, nbest_beam_preds, nbest_beam_scores
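The beam branch above and Example #1 share the same backpointer trick: after each step, the recurrent hidden state is reordered along its batch dimension so that every row follows the hypothesis that survived. Each beam's backpointers are offset by `beam_size * i` into the flat bsz * beam_size layout and applied with `index_select`. A self-contained sketch (plain PyTorch; the backpointer values are invented for illustration):

    import torch

    bsz, beam_size, num_layers, hid = 2, 3, 1, 4
    hidden = torch.randn(num_layers, bsz * beam_size, hid)

    # hypothetical backpointers chosen by each sample's beam at this step
    backtracks = [torch.LongTensor([2, 0, 1]),   # sample 0
                  torch.LongTensor([1, 1, 0])]   # sample 1

    # offset into the flat (bsz * beam_size) layout, as in the loops above
    permute_idx = torch.cat(
        [beam_size * i + bt for i, bt in enumerate(backtracks)])

    reordered = hidden.index_select(dim=1, index=permute_idx)
    # row 0 now holds the state of sample 0's previous beam 2
    assert (reordered[:, 0] == hidden[:, 2]).all()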