Example #1
0
    def get_hidden(self, examples):
        """Encode ``examples`` and return the bridged encoder hidden state.

        Args:
            examples: A single example or a list of examples; a non-list
                value is wrapped in a one-element list.

        Returns:
            ``self.bridge`` applied to the hidden state returned by
            ``self.encode``. Its shape depends on the encoder/bridge
            implementation, which is not visible here.
        """
        args = self.args
        if not isinstance(examples, list):
            examples = [examples]

        # NOTE(review): the target side is requested (use_tgt/tgt_append)
        # although only 'src' and 'src_len' are read below; left unchanged in
        # case to_input_dict's batching depends on it — confirm.
        input_dict = to_input_dict(
            examples=examples,
            vocab=self.vocab,
            max_tgt_len=-1,
            cuda=args.cuda,
            training=self.training,
            src_append=False,
            tgt_append=True,
            use_tgt=True,
            use_tag=False,
            use_dst=False,
        )

        # Encoder outputs are not needed here; only the hidden state is bridged.
        _, encoder_hidden = self.encode(src_var=input_dict['src'],
                                        src_length=input_dict['src_len'])
        return self.bridge(encoder_hidden)
Example #2
0
    def score(self, examples, to_word=False, **kwargs):
        """Return the sum of the word-level and syntax-level scores of ``examples``.

        Args:
            examples: One example or a list of examples; a bare example is
                wrapped into a one-element list.
            to_word: Accepted for interface compatibility; not used here.
            **kwargs: Accepted and ignored.

        Returns:
            ``self.decoder.scoring(...) + self.syntax_score(...)``. The
            operand types come from those helpers (presumably tensors —
            confirm).
        """
        opts = self.args
        batch = examples if isinstance(examples, list) else [examples]

        inputs = to_input_dict(
            examples=batch,
            vocab=self.vocab,
            max_tgt_len=self.max_len,
            cuda=opts.cuda,
            training=self.training,
            use_tgt=True,
            use_tag=opts.use_arc,
            use_dst=opts.use_dst,
        )
        encoded = self.encode(seqs_x=inputs['src'],
                              seqs_length=inputs['src_len'])
        # ret_syn=True presumably makes decode also return the syntax-side
        # outputs that syntax_score reads — confirm.
        decoded = self.decode(encoded, ret_syn=True)

        word_part = self.decoder.scoring(raw_score=decoded['word'],
                                         tgt_var=inputs['tgt'])
        syntax_part = self.syntax_score(decoded,
                                        inputs,
                                        use_arc=opts.use_arc,
                                        use_dst=opts.use_dst)
        return word_part + syntax_part
Example #3
0
    def score(self, examples, return_enc_state=False, **kwargs):
        """Score the target sequences of ``examples`` with ``self.decoder``.

        Args:
            examples: One example or a list of examples; a bare example is
                wrapped into a one-element list.
            return_enc_state: When true, also return the bridged encoder
                hidden state.
            **kwargs: Accepted and ignored.

        Returns:
            The result of ``self.decoder.score(...)``, or a
            ``(scores, encoder_hidden)`` tuple when ``return_enc_state`` is true.
        """
        opts = self.args
        batch = examples if isinstance(examples, list) else [examples]

        inputs = to_input_dict(
            examples=batch,
            vocab=self.vocab,
            max_tgt_len=-1,
            cuda=opts.cuda,
            training=self.training,
            src_append=False,
            tgt_append=True,
            use_tgt=True,
            use_tag=False,
            use_dst=False,
        )
        source_ids = inputs['src']
        target_ids = inputs['tgt']
        source_lens = inputs['src_len']

        enc_outputs, enc_hidden = self.encode(src_var=source_ids,
                                              src_length=source_lens)
        # The bridged hidden state initialises the decoder and is also what
        # gets returned when return_enc_state is requested.
        enc_hidden = self.bridge(enc_hidden)
        result = self.decoder.score(inputs=target_ids,
                                    encoder_hidden=enc_hidden,
                                    encoder_outputs=enc_outputs)

        return (result, enc_hidden) if return_enc_state else result
Example #4
0
    def predict(self, examples, to_word=True):
        """Run ``self.forward`` on ``examples`` and return its ``'pred'`` entry.

        Args:
            examples: One example or a list of examples; a bare example is
                wrapped into a one-element list.
            to_word: Forwarded unchanged to ``self.forward``.

        Returns:
            ``self.forward(...)['pred']``.
        """
        opts = self.args
        batch = examples if isinstance(examples, list) else [examples]

        # No target side is requested: prediction works from the source only.
        inputs = to_input_dict(
            examples=batch,
            vocab=self.vocab,
            max_tgt_len=self.max_len,
            cuda=opts.cuda,
            training=self.training,
            use_tgt=False,
            use_tag=False,
            use_dst=False,
        )
        outputs = self.forward(seqs_x=inputs['src'],
                               x_length=inputs['src_len'],
                               to_word=to_word)
        return outputs['pred']
Example #5
0
    def predict(self, examples, to_word=True):
        """Decode ``examples`` with ``self.decoder`` and return one entry per example.

        Args:
            examples: One example or a list of examples; a non-list value is
                wrapped in a one-element list.
            to_word: If true, each hypothesis is mapped to target-vocabulary
                tokens with ``id2word``; otherwise the raw token ids are kept.

        Returns:
            A list with one ``[[hyp], [len(hyp)]]`` entry per example. ``hyp``
            is the ``id2word`` output when ``to_word`` is true, or the decoded
            id list when it is false. Earlier versions returned ``[]`` for
            ``to_word=False``.
        """
        args = self.args
        if not isinstance(examples, list):
            examples = [examples]

        # use_tgt=False: only the source side is batched for prediction.
        input_dict = to_input_dict(
            examples=examples,
            vocab=self.vocab,
            max_tgt_len=-1,
            cuda=args.cuda,
            training=self.training,
            src_append=False,
            tgt_append=True,
            use_tgt=False,
            use_tag=False,
            use_dst=False,
        )

        encoder_outputs, encoder_hidden = self.encode(
            src_var=input_dict['src'], src_length=input_dict['src_len'])
        encoder_hidden = self.bridge(encoder_hidden)

        # teacher_forcing_ratio=0.0 presumably means free-running decoding
        # (no target was batched above) — confirm against the decoder.
        _, _, ret_dict, _ = self.decoder.forward(
            encoder_hidden=encoder_hidden,
            encoder_outputs=encoder_outputs,
            teacher_forcing_ratio=0.0)

        # Build a (steps, batch) id matrix by reshaping with the known step
        # count. A bare .squeeze() would drop *every* size-1 dim, so a 1-step
        # decode of a multi-example batch would be misread as a single
        # example with `batch` steps.
        # Assumes each step tensor holds one id per example, e.g. (batch, 1)
        # or (batch,) — TODO confirm against the decoder.
        step_symbols = ret_dict['sequence']
        result = torch.stack(step_symbols).view(len(step_symbols), -1)

        # Transpose once so each row is one example's id sequence.
        final_result = []
        for hyp in result.t().tolist():
            if to_word:
                hyp = id2word(hyp, self.vocab.tgt)
            final_result.append([[hyp], [len(hyp)]])
        return final_result