Example No. 1
    def score(self, examples, return_enc_state=False):
        args = self.args
        if not isinstance(examples, list):
            examples = [examples]

        src_words = [e.src for e in examples]
        tgt_words = [e.tgt for e in examples]

        src_length = [len(c) for c in src_words]
        seqs_x = to_input_variable(src_words, self.vocab.src, cuda=args.cuda, batch_first=True)
        seqs_y = to_input_variable(tgt_words, self.vocab.tgt, max_len=self.max_len,
                                   cuda=args.cuda,
                                   append_boundary_sym=True,
                                   batch_first=True)
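        # commented-out alternative below: compute the loss directly from forward() log-probabilities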
        # log_probs = self.forward(seqs_x, seqs_y, x_length=src_length, to_word=False)['prob']
        # # y_label = seqs_y[:, 1:].contiguous()
        # y_label = seqs_y.contiguous()
        #
        # words_norm = y_label.ne(-1).float().sum(1)
        #
        # loss = self.critic(inputs=log_probs, labels=y_label, reduce=False, normalization=self.normalization)
        #
        # if self.norm_by_words:
        #     loss = loss.div(words_norm).sum()
        # else:
        #     loss = loss.sum()
        # return loss
        enc_out = self.encode(seqs_x, seqs_length=src_length)
        # enc_out: expected shape (batch_size, hidden_size)
        return self.decoder.score(raw_score=enc_out, tgt_var=seqs_y)
Example No. 2
    def score(self, examples, return_enc_state=False):
        args = self.args
        if not isinstance(examples, list):
            examples = [examples]

        src_words = [e.src for e in examples]
        src_length = [len(c) for c in src_words]

        src_var = to_input_variable(src_words,
                                    self.vocab.src,
                                    cuda=args.cuda,
                                    batch_first=True)
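        # note: the source words are re-encoded with the target vocabulary below
        # and handed to the decoder as its targets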
        tgt_var = to_input_variable(src_words,
                                    self.vocab.tgt,
                                    cuda=args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)
        encoder_outputs, encoder_hidden = self.encode_var(
            src_var=src_var, src_length=src_length)
        scores = self.decoder.score(inputs=tgt_var,
                                    encoder_hidden=encoder_hidden,
                                    encoder_outputs=encoder_outputs)

        if return_enc_state:
            enc_states = self.decoder.init_state(encoder_hidden)
            return scores, enc_states
        else:
            return scores
Example No. 3
    def encode_to_hidden(self, examples, need_sort=False):
        if not isinstance(examples, list):
            examples = [examples]
        if not need_sort:
            sent_words = [e.src for e in examples]
            length = [len(e.src) for e in examples]
            src_var = to_input_variable(sent_words,
                                        self.vocab.src,
                                        training=False,
                                        cuda=self.args.cuda,
                                        batch_first=True)

            encoder_output, encoder_hidden = self.encode(input_var=src_var,
                                                         length=length)

            return {
                "outputs": encoder_output,
                "hidden": encoder_hidden,
                'length': length,
                'batch_size': len(examples)
            }
        sent_words = [e.src for e in examples]
        batch_size = len(sent_words)
        sent_lengths = [len(sent_word) for sent_word in sent_words]

        # sort the batch by descending length (as typically required to pack padded sequences)
        sorted_example_ids = sorted(range(batch_size),
                                    key=lambda x: -sent_lengths[x])

        example_old_pos_map = [-1] * batch_size
        for new_pos, old_pos in enumerate(sorted_example_ids):
            example_old_pos_map[old_pos] = new_pos

        sorted_sent_words = [sent_words[i] for i in sorted_example_ids]
        sorted_sent_var = to_input_variable(sorted_sent_words,
                                            self.vocab.src,
                                            cuda=self.args.cuda,
                                            batch_first=True)

        if self.training and self.args.src_wd:
            sorted_sent_var = unk_replace(sorted_sent_var, self.step_unk_rate,
                                          self.vocab.src)

        sorted_sent_lengths = [
            len(sent_word) for sent_word in sorted_sent_words
        ]

        _, sent_hidden = self.encode(sorted_sent_var, sorted_sent_lengths)

        if sent_hidden.dim() > 2:
            hidden = sent_hidden[:, example_old_pos_map, :]
        else:
            hidden = sent_hidden[example_old_pos_map, :]

        return {
            "outputs": None,
            "hidden": hidden,
            'length': sent_lengths,
            'batch_size': batch_size
        }
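The length-sorting bookkeeping above is a common pattern when feeding padded batches to
an RNN encoder: encode in sorted order, then map the hidden states back to the original
example order. A minimal standalone sketch of the index mapping (values illustrative only):

    lengths = [3, 5, 2]
    order = sorted(range(len(lengths)), key=lambda i: -lengths[i])  # [1, 0, 2]
    old_pos_map = [-1] * len(lengths)
    for new_pos, old_pos in enumerate(order):
        old_pos_map[old_pos] = new_pos                              # [1, 0, 2]
    # after encoding the sorted batch, indexing the batch dimension with old_pos_map
    # returns each hidden state to the position of its original example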
Example No. 4
    def greedy_search(self, examples, to_word=True):
        args = self.args
        if not isinstance(examples, list):
            examples = [examples]

        src_words = [e.src for e in examples]

        src_var = to_input_variable(src_words,
                                    self.vocab.src,
                                    cuda=args.cuda,
                                    batch_first=True)
        src_length = [len(c) for c in src_words]
        encoder_outputs, encoder_hidden = self.encode_var(
            src_var=src_var, src_length=src_length)

        decoder_output, decoder_hidden, ret_dict, _ = self.decoder.forward(
            encoder_hidden=encoder_hidden,
            encoder_outputs=encoder_outputs,
            teacher_forcing_ratio=0.0)

        # squeeze only the last dimension so a batch of one is not collapsed along with it
        # (assumes each per-step symbol tensor carries a trailing size-1 dimension)
        result = torch.stack(ret_dict['sequence']).squeeze(-1)
        final_result = []
        example_nums = result.size(1)
        if to_word:
            for i in range(example_nums):
                hyp = result[:, i].data.tolist()
                res = id2word(hyp, self.vocab)
                seems = [[res], [len(res)]]
                final_result.append(seems)
        return final_result
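A hedged usage sketch for the structure returned above (the names model and dev_examples
are assumptions, not from the snippet, and id2word is assumed to yield a list of token
strings): each entry of the result is a nested [[words], [length]] pair.

    outputs = model.greedy_search(dev_examples, to_word=True)
    for seems in outputs:
        words, length = seems[0][0], seems[1][0]
        print(" ".join(words), length)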
Example No. 5
    def beam_search(self, src_sent, beam_size=5, dmts=None):
        if dmts is None:
            dmts = self.args.tgt_max_time_step
        src_var = to_input_variable(src_sent,
                                    self.vocab.src,
                                    cuda=self.args.cuda,
                                    training=False,
                                    append_boundary_sym=False,
                                    batch_first=True)
        src_length = [len(src_sent)]

        encoder_outputs, encoder_hidden = self.encode_var(
            src_var=src_var, src_length=src_length)

        meta_data = self.beam_decoder.beam_search(
            encoder_hidden=encoder_hidden,
            encoder_outputs=encoder_outputs,
            beam_size=beam_size,
            decode_max_time_step=dmts)
        topk_sequence = meta_data['sequence']
        topk_score = meta_data['score'].squeeze()

        completed_hypotheses = torch.cat(topk_sequence, dim=-1)

        number_return = completed_hypotheses.size(0)
        final_result = []
        final_scores = []
        for i in range(number_return):
            hyp = completed_hypotheses[i, :].data.tolist()
            res = id2word(hyp, self.vocab.tgt)
            final_result.append(res)
            final_scores.append(topk_score[i].item())
        return final_result, final_scores
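A hedged usage sketch (the name model and the example tokens are assumptions, and id2word
is assumed to yield token strings): beam_search takes a single tokenized sentence and
returns the top hypotheses together with their scores.

    hyps, scores = model.beam_search(["the", "cat", "sat"], beam_size=5)
    for hyp, score in zip(hyps, scores):
        print(score, " ".join(hyp))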
Example No. 6
    def sentence_encode(self, sent_words):
        batch_size = len(sent_words)
        sent_lengths = [len(sent_word) for sent_word in sent_words]

        sorted_example_ids = sorted(range(batch_size),
                                    key=lambda x: -sent_lengths[x])

        example_old_pos_map = [-1] * batch_size
        for new_pos, old_pos in enumerate(sorted_example_ids):
            example_old_pos_map[old_pos] = new_pos

        sorted_sent_words = [sent_words[i] for i in sorted_example_ids]
        sorted_sent_var = to_input_variable(sorted_sent_words,
                                            self.vocab.src,
                                            cuda=self.args.cuda,
                                            batch_first=True)

        if self.training and self.args.src_wd:
            sorted_sent_var = unk_replace(sorted_sent_var, self.step_unk_rate,
                                          self.vocab.src)

        sorted_sent_lengths = [
            len(sent_word) for sent_word in sorted_sent_words
        ]

        _, sent_hidden = self.encoder.forward(sorted_sent_var,
                                              sorted_sent_lengths)

        # map the hidden states back to the original example order along the batch dimension
        hidden = sent_hidden[:, example_old_pos_map, :]

        return hidden
Example No. 7
    def predict(self, examples, to_word=True):
        args = self.args
        if not isinstance(examples, list):
            examples = [examples]

        src_words = [e.src for e in examples]
        src_length = [len(c) for c in src_words]
        seqs_x = to_input_variable(src_words, self.vocab.src, cuda=args.cuda, batch_first=True)
        predict = self.forward(seqs_x, x_length=src_length, to_word=to_word)
        return predict['pred']
Example No. 8
    def encode_to_hidden(self, examples):
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        sorted_example_ids = sorted(range(batch_size),
                                    key=lambda x: -len(examples[x].tgt))
        example_old_pos_map = [-1] * batch_size

        sorted_examples = [examples[i] for i in sorted_example_ids]

        syntax_word = [e.tgt for e in sorted_examples]
        syntax_var = to_input_variable(syntax_word,
                                       self.vocab.tgt,
                                       training=False,
                                       cuda=self.args.cuda,
                                       batch_first=True)
        length = [len(e.tgt) for e in sorted_examples]
        syntax_output, syntax_hidden = self.syntax_encode(syntax_var, length)

        sent_words = [e.src for e in sorted_examples]
        sentence_hidden = self.sentence_encode(sent_words)
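        # note: tgt_var is built from the (length-sorted) source words with boundary symbols added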
        tgt_var = to_input_variable(sent_words,
                                    self.vocab.src,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)

        for new_pos, old_pos in enumerate(sorted_example_ids):
            example_old_pos_map[old_pos] = new_pos

        return {
            'hidden': sentence_hidden,
            "syn_output": syntax_output,
            "syn_hidden": syntax_hidden,
            'tgt_var': tgt_var,
            'old_pos': example_old_pos_map
        }
Example No. 9
    def score(self, examples, return_enc_state=False):
        """
            Used for teacher-forcing training; returns the log-probability
            of the <input, output> pair.
        """
        args = self.args
        if isinstance(examples, list):
            src_words = [e.src for e in examples]
            tgt_words = [e.tgt for e in examples]
        else:
            src_words = examples.src
            tgt_words = examples.tgt

        seqs_x = to_input_variable(src_words,
                                   self.src_vocab,
                                   cuda=args.cuda,
                                   batch_first=True)
        seqs_y = to_input_variable(tgt_words,
                                   self.tgt_vocab,
                                   cuda=args.cuda,
                                   append_boundary_sym=True,
                                   batch_first=True)
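        # standard teacher-forcing shift: the decoder input drops the final token,
        # the labels drop the initial boundary symbol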
        y_inp = seqs_y[:, :-1].contiguous()
        y_label = seqs_y[:, 1:].contiguous()

        words_norm = y_label.ne(self.pad).float().sum(1)

        log_probs = self.forward(seqs_x, y_inp)
        loss = self.critic(inputs=log_probs,
                           labels=y_label,
                           reduce=False,
                           normalization=self.normalization)

        if self.norm_by_words:
            loss = loss.div(words_norm).sum()
        else:
            loss = loss.sum()
        return loss
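A toy illustration of the y_inp / y_label shift above (the ids are illustrative only):

    import torch

    seqs_y = torch.tensor([[1, 5, 6, 7, 2]])   # e.g. [BOS, w1, w2, w3, EOS]
    y_inp = seqs_y[:, :-1].contiguous()        # tensor([[1, 5, 6, 7]])
    y_label = seqs_y[:, 1:].contiguous()       # tensor([[5, 6, 7, 2]])
    # at step t the model predicts y_label[:, t] from the prefix y_inp[:, :t + 1]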
Example No. 10
    def encode(self, examples):
        args = self.args
        if isinstance(examples, list):
            src_words = [e.src for e in examples]
        else:
            src_words = examples.src

        src_var = to_input_variable(src_words,
                                    self.vocab.src,
                                    cuda=args.cuda,
                                    batch_first=True)
        src_length = [len(c) for c in src_words]

        encoder_outputs, encoder_hidden = self.encode_var(src_var, src_length)
        return encoder_hidden
Example No. 11
    def forward(self, examples):
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)
        sent_words = [e.src for e in examples]
        ret = self.encode_to_hidden(examples)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        decode_init = ret['decode_init']
        tgt_var = to_input_variable(sent_words,
                                    self.vocab.src,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)
        decode_init = self.bridger.forward(decode_init)
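        # word dropout on the decoder inputs during training: tokens are replaced
        # with UNK at rate step_unk_rate before teacher forcing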
        if self.training and self.args.tgt_wd > 0.:
            input_var = unk_replace(tgt_var, self.step_unk_rate,
                                    self.vocab.src)
            tgt_token_scores = self.decoder.generate(
                con_inputs=input_var,
                encoder_hidden=decode_init,
                encoder_outputs=None,
                teacher_forcing_ratio=1.0,
            )
            reconstruct_loss = -torch.sum(
                self.decoder.score_decoding_results(tgt_token_scores, tgt_var))
        else:
            reconstruct_loss = -torch.sum(
                self.decoder.score(
                    inputs=tgt_var,
                    encoder_outputs=None,
                    encoder_hidden=decode_init,
                ))

        return {
            "mean": ret['mean'],
            "logv": ret['logv'],
            "z": ret['latent'],
            'nll_loss': reconstruct_loss,
            'batch_size': batch_size
        }
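The returned 'mean' and 'logv' are presumably combined into a KL term elsewhere in the
training loop (not shown in this snippet). For reference, the standard KL of a diagonal
Gaussian against N(0, I), summed over the latent dimensions, is:

    import torch

    def gaussian_kl(mean, logv):
        # KL(q(z|x) || N(0, I)) for a diagonal Gaussian with log-variance logv
        return -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp(), dim=-1)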
Example No. 12
    def forward(self, examples, is_dis=False):
        if not isinstance(examples, list):
            examples = [examples]
        batch_size = len(examples)

        words = [e.src for e in examples]
        tgt_var = to_input_variable(words,
                                    self.vocab.src,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)
        syn_seqs = [e.tgt for e in examples]
        syn_var = to_input_variable(syn_seqs,
                                    self.vocab.tgt,
                                    training=False,
                                    cuda=self.args.cuda,
                                    append_boundary_sym=True,
                                    batch_first=True)

        ret = self.encode_to_hidden(examples)
        ret = self.hidden_to_latent(ret=ret, is_sampling=self.training)
        ret = self.latent_for_init(ret=ret)
        syn_hidden = ret['syn_hidden']
        sem_hidden = ret['sem_hidden']

        # discriminator-only pass: skip decoding and return just the dis_syn / dis_sem losses
        if is_dis:
            dis_syn_loss, dis_sem_loss = self.get_dis_loss(
                syntax_hidden=syn_hidden,
                semantic_hidden=sem_hidden,
                syn_tgt=syn_var,
                sem_tgt=tgt_var)
            ret['dis syn'] = dis_syn_loss
            ret['dis sem'] = dis_sem_loss
            return ret

        decode_init = ret['decode_init']

        sentence_decode_init = self.bridger.forward(decode_init)
        if self.training and self.args.tgt_wd:
            input_var = unk_replace(tgt_var, self.step_unk_rate,
                                    self.vocab.src)
            tgt_log_score = self.decoder.generate(
                con_inputs=input_var,
                encoder_hidden=sentence_decode_init,
                encoder_outputs=None,
                teacher_forcing_ratio=1.0)
            reconstruct_loss = -torch.sum(
                self.decoder.score_decoding_results(tgt_log_score, tgt_var))
        else:
            reconstruct_loss = -torch.sum(
                self.decoder.score(inputs=tgt_var,
                                   encoder_outputs=None,
                                   encoder_hidden=sentence_decode_init))

        mul_syn_loss, mul_sem_loss = self.get_mul_loss(
            syntax_hidden=syn_hidden,
            semantic_hidden=sem_hidden,
            syn_tgt=syn_var,
            sem_tgt=tgt_var)

        adv_syn_loss, adv_sem_loss = self.get_adv_loss(
            syntax_hidden=syn_hidden,
            semantic_hidden=sem_hidden,
            syn_tgt=syn_var,
            sem_tgt=tgt_var)
        ret['adv'] = adv_syn_loss + adv_sem_loss
        ret['mul'] = mul_syn_loss + mul_sem_loss

        ret['nll_loss'] = reconstruct_loss
        ret['sem_loss'] = mul_sem_loss
        ret['syn_loss'] = mul_syn_loss
        ret['batch_size'] = batch_size
        return ret