コード例 #1
0
ファイル: paraphrase.py プロジェクト: yzhen-li/tranX
    def _score(self, src_codes, tgt_nls):
        """Score a batch of code / utterance pairs with ``self.pi_model``.

        Both batches are turned into index variables (transposed with
        ``.t()``) and accompanied by float masks derived from the length of
        every example.  Callers are expected to pass examples sorted by code
        length; this method does not check that.

        Args:
            src_codes: batch of tokenized code sequences
            tgt_nls: batch of tokenized natural-language sequences

        Returns:
            The output of ``self.pi_model`` for the batch.
        """
        opts = self.args

        def build_mask(sequences):
            # mask is built from per-example lengths, cast to float and
            # excluded from gradient computation
            lengths = [len(seq) for seq in sequences]
            raw_mask = nn_utils.length_array_to_mask_tensor(
                lengths, cuda=opts.cuda, valid_entry_has_mask_one=True)
            return Variable(raw_mask.float(), requires_grad=False)

        if opts.tie_embed:
            # with ``tie_embed`` set, both sides use the unk-handling conversion
            code_ids = self.to_input_variable_with_unk_handling(
                src_codes, cuda=opts.cuda)
            src_code_var = code_ids.t()
            nl_ids = self.to_input_variable_with_unk_handling(
                tgt_nls, cuda=opts.cuda)
            tgt_nl_var = nl_ids.t()
        else:
            # otherwise each side is indexed with its own vocabulary
            code_ids = nn_utils.to_input_variable(
                src_codes, self.vocab.code, cuda=opts.cuda)
            src_code_var = code_ids.t()
            nl_ids = nn_utils.to_input_variable(
                tgt_nls, self.vocab.source, cuda=opts.cuda)
            tgt_nl_var = nl_ids.t()

        src_code_mask = build_mask(src_codes)
        tgt_nl_mask = build_mask(tgt_nls)

        return self.pi_model(src_code_var, tgt_nl_var, src_code_mask,
                             tgt_nl_mask)
コード例 #2
0
    def _score(self, src_codes, tgt_nls):
        """Score utterances given code with ``self.seq2seq``.

        Callers are expected to pass examples sorted by code length; this
        method does not check that.

        Args:
            src_codes: batch of tokenized code sequences (model input)
            tgt_nls: batch of tokenized natural-language sequences (model output)

        Returns:
            The output of ``self.seq2seq`` for the batch.
        """
        use_cuda = self.args.cuda

        code_inputs = nn_utils.to_input_variable(
            src_codes, self.vocab.code, cuda=use_cuda)
        # boundary symbols are only requested for the target side
        nl_inputs = nn_utils.to_input_variable(
            tgt_nls, self.vocab.source, cuda=use_cuda,
            append_boundary_sym=True)

        copy_idx_mask, gen_mask = self.get_generate_and_copy_meta_tensor(
            src_codes, tgt_nls)

        code_lengths = [len(code) for code in src_codes]

        # only the copy-enabled model takes the copy / generate masks
        if isinstance(self.seq2seq, Seq2SeqWithCopy):
            return self.seq2seq(code_inputs, code_lengths, nl_inputs,
                                copy_idx_mask, gen_mask)

        return self.seq2seq(code_inputs, code_lengths, nl_inputs)
コード例 #3
0
ファイル: prior.py プロジェクト: chubbymaggie/tranX
    def __call__(self, code_list):
        """Return the negated ``self.forward`` output for a list of code strings.

        Args:
            code_list: iterable of source-code strings

        Returns:
            ``-self.forward(code_var)`` where ``code_var`` is the indexed,
            boundary-padded batch of canonicalized code tokens.
        """
        # we assume the code is generated from astor and therefore has an astor style!
        token_batch = []
        for snippet in code_list:
            tokens = self.transition_system.tokenize_code(snippet, mode='canonicalize')
            token_batch.append(tokens)

        batch_var = nn_utils.to_input_variable(
            token_batch, self.vocab, cuda=self.args.cuda, append_boundary_sym=True)

        forward_out = self.forward(batch_var)
        return -forward_out
コード例 #4
0
 def sample(self, src_sents, sample_size):
     """Index a batch of source sentences and delegate to ``sample_from_variable``.

     Args:
         src_sents: batch of tokenized source sentences
         sample_size: forwarded unchanged to ``sample_from_variable``

     Returns:
         The result of ``self.sample_from_variable``.
     """
     sent_lengths = list(map(len, src_sents))
     # per the original author's note: Variable of (src_sent_len, batch_size)
     batch_var = nn_utils.to_input_variable(
         src_sents, self.vocab.src, cuda=self.cuda, training=False)
     return self.sample_from_variable(batch_var, sent_lengths, sample_size)
コード例 #5
0
    def _score(self, src_codes, tgt_nls):
        """Score utterances given code with the copy-enabled ``self.seq2seq``.

        Callers are expected to pass examples sorted by code length; this
        method does not check that.

        Args:
            src_codes: batch of tokenized code sequences (model input)
            tgt_nls: batch of tokenized natural-language sequences (model output)

        Returns:
            The output of ``self.seq2seq`` for the batch.
        """
        opts = self.args

        encoder_inputs = nn_utils.to_input_variable(
            src_codes, self.vocab.code, cuda=opts.cuda)
        # boundary symbols are only requested for the target side
        decoder_inputs = nn_utils.to_input_variable(
            tgt_nls, self.vocab.source, cuda=opts.cuda,
            append_boundary_sym=True)

        # copy positions / copy mask / generate mask for every target token
        copy_meta = self.get_generate_and_copy_meta_tensor(src_codes, tgt_nls)
        copy_pos, copy_mask, gen_mask = copy_meta

        code_lengths = [len(code) for code in src_codes]
        return self.seq2seq(encoder_inputs, code_lengths, decoder_inputs,
                            copy_pos, copy_mask, gen_mask)
コード例 #6
0
ファイル: prior.py プロジェクト: tomsonsgs/TRAN-MMA-master
    def __call__(self, code_list):
        """Return the negated ``self.forward`` output for a list of code strings.

        Args:
            code_list: iterable of source-code strings

        Returns:
            ``-self.forward(code_var)`` where ``code_var`` is the indexed,
            boundary-padded batch of canonicalized code tokens.
        """
        def canonical_tokens(snippet):
            # we assume the code is generated from astor and therefore has an astor style!
            return self.transition_system.tokenize_code(snippet, mode='canonicalize')

        token_batch = list(map(canonical_tokens, code_list))
        batch_var = nn_utils.to_input_variable(
            token_batch, self.vocab, cuda=self.args.cuda, append_boundary_sym=True)

        return -self.forward(batch_var)
コード例 #7
0
    def evaluate_ppl():
        """Compute the perplexity of ``model`` on the code of ``dev_set.examples``.

        The model is put in eval mode for the pass and back in train mode
        before returning.

        Returns:
            ``exp(total loss / total token count)``, where each sequence counts
            one extra token for the ending ``</s>``.
        """
        model.eval()
        total_loss = 0.
        total_words = 0.
        for batch_examples in nn_utils.batch_iter(dev_set.examples, args.batch_size):
            token_seqs = []
            for example in batch_examples:
                token_seqs.append(transition_system.tokenize_code(example.tgt_code))

            inputs = nn_utils.to_input_variable(
                token_seqs, vocab, cuda=args.cuda, append_boundary_sym=True)
            batch_loss = model.forward(inputs).sum()
            total_loss += batch_loss.data[0]
            # every sequence contributes its tokens plus one ending </s>
            total_words += sum(1 + len(seq) for seq in token_seqs)

        perplexity = np.exp(total_loss / total_words)
        model.train()
        return perplexity
コード例 #8
0
    def evaluate_ppl():
        """Compute the perplexity of ``model`` on the source sentences of ``dev_set``.

        The model is put in eval mode for the pass and back in train mode
        before returning.

        Returns:
            ``exp(total loss / total token count)``, where each sentence counts
            one extra token for the ending ``</s>``.
        """
        model.eval()
        total_loss = 0.
        total_words = 0.
        for batch in dev_set.batch_iter(args.batch_size):
            sentences = []
            for example in batch:
                sentences.append(example.src_sent)

            inputs = nn_utils.to_input_variable(
                sentences, vocab.source, cuda=args.cuda, append_boundary_sym=True)
            batch_loss = model(inputs).sum()
            total_loss += batch_loss.data[0]
            # every sentence contributes its tokens plus one ending </s>
            total_words += sum(1 + len(example.src_sent) for example in batch)

        perplexity = np.exp(total_loss / total_words)
        model.train()
        return perplexity
コード例 #9
0
    def parse(self, src_sent, context=None, beam_size=5):
        """Perform beam search to infer the target AST given a source utterance

        At every time step each live hypothesis is expanded with all valid
        ApplyRule / Reduce / GenToken continuations, the best candidates are
        kept, and hypotheses that become completed are moved to the result list.
        Decoding stops when `beam_size` hypotheses are completed, when
        `args.decode_max_time_step` is reached, or when no live hypothesis
        remains, so fewer than `beam_size` results may be returned.

        Args:
            src_sent: list of source utterance tokens
            context: other context used for prediction (not referenced in this
                implementation)
            beam_size: beam size

        Returns:
            A list of `DecodeHypothesis`, each representing an AST, sorted by
            descending `score`
        """
        args = self.args
        primitive_vocab = self.vocab.primitive

        # the single utterance is encoded as a batch of one
        src_sent_var = nn_utils.to_input_variable([src_sent], self.vocab.source, cuda=args.cuda, training=False)

        # Variable(1, src_sent_len, hidden_size * 2)
        src_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(src_sent)])
        # (1, src_sent_len, hidden_size)
        src_encodings_att_linear = self.att_src_linear(src_encodings)

        dec_init_vec = self.init_decoder_state(last_state, last_cell)
        if args.lstm == 'parent_feed':
            # the parent-feeding LSTM state carries two extra slots (parent
            # state and parent cell); they start as zeros and are replaced by
            # the real parent states inside the decoding loop
            h_tm1 = dec_init_vec[0], dec_init_vec[1], \
                    Variable(self.new_tensor(args.hidden_size).zero_()), \
                    Variable(self.new_tensor(args.hidden_size).zero_())
        else:
            h_tm1 = dec_init_vec

        # embedding used when a hypothesis has no (truthy) previous action
        zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_())

        # running scores of the live hypotheses, aligned with `hypotheses`
        hyp_scores = Variable(self.new_tensor([0.]), volatile=True)

        # primitive-vocabulary id of every source token, and the positions of
        # the source tokens that map to <unk>
        src_token_vocab_ids = [primitive_vocab[token] for token in src_sent]
        src_unk_pos_list = [pos for pos, token_id in enumerate(src_token_vocab_ids) if token_id == primitive_vocab.unk_id]
        # sometimes a word may appear multi-times in the source, in this case,
        # we just copy its first appearing position. Therefore we mask the words
        # appearing second and onwards to -1
        token_set = set()
        for i, tid in enumerate(src_token_vocab_ids):
            if tid in token_set:
                src_token_vocab_ids[i] = -1
            else: token_set.add(tid)

        t = 0
        hypotheses = [DecodeHypothesis()]
        # hyp_states[i] is the list of (h, cell) pairs produced so far for
        # hypothesis i, indexed by time step
        hyp_states = [[]]
        completed_hypotheses = []

        while len(completed_hypotheses) < beam_size and t < args.decode_max_time_step:
            hyp_num = len(hypotheses)

            # (hyp_num, src_sent_len, hidden_size * 2)
            exp_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1), src_encodings.size(2))
            # (hyp_num, src_sent_len, hidden_size)
            exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num, src_encodings_att_linear.size(1), src_encodings_att_linear.size(2))

            if t == 0:
                # first step: the input is all zeros except (optionally) the
                # slot holding the embedding of the grammar's root type
                x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)
                if args.no_parent_field_type_embed is False:
                    # skip over every enabled segment that precedes the
                    # field-type embedding in the concatenated input
                    offset = args.action_embed_size  # prev_action
                    offset += args.att_vec_size * (not args.no_input_feed)
                    offset += args.action_embed_size * (not args.no_parent_production_embed)
                    offset += args.field_embed_size * (not args.no_parent_field_embed)

                    x[0, offset: offset + args.type_embed_size] = \
                        self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
            else:
                # embedding of the previous action of every live hypothesis
                actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

                a_tm1_embeds = []
                for a_tm1 in actions_tm1:
                    if a_tm1:
                        if isinstance(a_tm1, ApplyRuleAction):
                            a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[a_tm1.production]]
                        elif isinstance(a_tm1, ReduceAction):
                            # Reduce uses the row right after the grammar's productions
                            a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                        else:
                            # any other action is treated as a token-generating action
                            a_tm1_embed = self.primitive_embed.weight[self.vocab.primitive[a_tm1.token]]

                        a_tm1_embeds.append(a_tm1_embed)
                    else:
                        a_tm1_embeds.append(zero_action_embed)
                a_tm1_embeds = torch.stack(a_tm1_embeds)

                # the segments below are concatenated in the same order that
                # the t == 0 branch assumes when computing `offset`
                inputs = [a_tm1_embeds]
                if args.no_input_feed is False:
                    # att_tm1 is set at the end of the previous iteration
                    inputs.append(att_tm1)
                if args.no_parent_production_embed is False:
                    # frontier production
                    frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                    frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                        [self.grammar.prod2id[prod] for prod in frontier_prods])))
                    inputs.append(frontier_prod_embeds)
                if args.no_parent_field_embed is False:
                    # frontier field
                    frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                    frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([
                        self.grammar.field2id[field] for field in frontier_fields])))

                    inputs.append(frontier_field_embeds)
                if args.no_parent_field_type_embed is False:
                    # frontier field type
                    frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                    frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([
                        self.grammar.type2id[type] for type in frontier_field_types])))
                    inputs.append(frontier_field_type_embeds)

                # parent states
                if args.no_parent_state is False:
                    # decoder state recorded at the step where each hypothesis'
                    # frontier node was created
                    p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                    parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)])
                    parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)])

                    if args.lstm == 'parent_feed':
                        # parent state/cell go into the LSTM state tuple
                        h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells)
                    else:
                        # otherwise the parent state is one more input segment
                        inputs.append(parent_states)

                x = torch.cat(inputs, dim=-1)

            if args.lstm == 'lstm_with_dropout':
                self.decoder_lstm.set_dropout_masks(hyp_num)

            (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
                                             exp_src_encodings_att_linear,
                                             src_token_mask=None)

            # Variable(batch_size, grammar_size)
            apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

            # Variable(batch_size, src_sent_len)
            primitive_copy_prob = self.src_pointer_net(src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

            # Variable(batch_size, primitive_vocab_size)
            gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

            # Variable(batch_size, 2)
            # column 0 weights generation from the vocabulary, column 1 weights
            # copying from the source (see their uses below)
            primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

            # Variable(batch_size, primitive_vocab_size)
            primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(1) * gen_from_vocab_prob
            if src_unk_pos_list:
                # the source contains <unk> tokens: generating <unk> from the
                # vocabulary is suppressed here; when copying is enabled the
                # entry is overwritten below with the copy score of the best
                # <unk> source token
                primitive_prob[:, primitive_vocab.unk_id] = 1.e-10

            # per-candidate bookkeeping; the gentoken_* lists are aligned with
            # each other, as are the applyrule_* lists
            gentoken_prev_hyp_ids = []
            gentoken_new_hyp_unks = []
            gentoken_copy_infos = []
            applyrule_new_hyp_scores = []
            applyrule_new_hyp_prod_ids = []
            applyrule_prev_hyp_ids = []

            for hyp_id, hyp in enumerate(hypotheses):
                # generate new continuations
                action_types = self.transition_system.get_valid_continuation_types(hyp)

                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_log_prob[hyp_id, prod_id].data[0]
                            new_hyp_score = hyp.score + prod_score

                            applyrule_new_hyp_scores.append(new_hyp_score)
                            applyrule_new_hyp_prod_ids.append(prod_id)
                            applyrule_prev_hyp_ids.append(hyp_id)
                    elif action_type == ReduceAction:
                        # Reduce is scored by the extra output at index len(grammar)
                        action_score = apply_rule_log_prob[hyp_id, len(self.grammar)].data[0]
                        new_hyp_score = hyp.score + action_score

                        applyrule_new_hyp_scores.append(new_hyp_score)
                        applyrule_new_hyp_prod_ids.append(len(self.grammar))
                        applyrule_prev_hyp_ids.append(hyp_id)
                    else:
                        # GenToken action
                        gentoken_prev_hyp_ids.append(hyp_id)
                        hyp_copy_info = dict()  # of (token_pos, copy_prob)
                        # first, we compute copy probabilities for tokens in the source sentence
                        for token_pos, token_vocab_id in enumerate(src_token_vocab_ids):
                            # skip repeated occurrences (-1) and <unk> tokens
                            if args.no_copy is False and token_vocab_id != -1 and token_vocab_id != primitive_vocab.unk_id:
                                p_copy = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, token_pos]
                                # total probability = generate from vocab + copy from source
                                primitive_prob[hyp_id, token_vocab_id] = primitive_prob[hyp_id, token_vocab_id] + p_copy

                                token = src_sent[token_pos]
                                hyp_copy_info[token] = (token_pos, p_copy.data[0])

                        # second, add the probability of copying the most probable unk word
                        if args.no_copy is False and src_unk_pos_list:
                            # argmax is taken over the <unk> positions only, then
                            # mapped back to a position in the source sentence
                            unk_pos = primitive_copy_prob[hyp_id][src_unk_pos_list].data.cpu().numpy().argmax()
                            unk_pos = src_unk_pos_list[unk_pos]
                            token = src_sent[unk_pos]
                            gentoken_new_hyp_unks.append(token)

                            unk_copy_score = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, unk_pos]
                            primitive_prob[hyp_id, primitive_vocab.unk_id] = unk_copy_score

                            hyp_copy_info[token] = (unk_pos, unk_copy_score.data[0])

                        gentoken_copy_infos.append(hyp_copy_info)

            # flat candidate layout: all ApplyRule/Reduce candidates first,
            # followed by one block of primitive_vocab_size entries for every
            # hypothesis in gentoken_prev_hyp_ids
            new_hyp_scores = None
            if applyrule_new_hyp_scores:
                new_hyp_scores = Variable(self.new_tensor(applyrule_new_hyp_scores))
            if gentoken_prev_hyp_ids:
                primitive_log_prob = torch.log(primitive_prob)
                gen_token_new_hyp_scores = (hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) + primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

                if new_hyp_scores is None: new_hyp_scores = gen_token_new_hyp_scores
                else: new_hyp_scores = torch.cat([new_hyp_scores, gen_token_new_hyp_scores])

            # keep at most as many candidates as there are open beam slots
            top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores,
                                                             k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses)))

            live_hyp_ids = []
            new_hypotheses = []
            for new_hyp_score, new_hyp_pos in zip(top_new_hyp_scores.data.cpu(), top_new_hyp_pos.data.cpu()):
                action_info = ActionInfo()
                if new_hyp_pos < len(applyrule_new_hyp_scores):
                    # it's an ApplyRule or Reduce action
                    prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                    prev_hyp = hypotheses[prev_hyp_id]

                    prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                    # ApplyRule action
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    # Reduce action
                    else:
                        action = ReduceAction()
                else:
                    # it's a GenToken action: decode the flat position into
                    # (index k into the gentoken_* lists, vocabulary id)
                    token_id = (new_hyp_pos - len(applyrule_new_hyp_scores)) % primitive_prob.size(1)

                    k = (new_hyp_pos - len(applyrule_new_hyp_scores)) // primitive_prob.size(1)
                    copy_info = gentoken_copy_infos[k]
                    prev_hyp_id = gentoken_prev_hyp_ids[k]
                    prev_hyp = hypotheses[prev_hyp_id]

                    if token_id == primitive_vocab.unk_id:
                        if gentoken_new_hyp_unks:
                            # emit the source <unk> token chosen for this hypothesis
                            token = gentoken_new_hyp_unks[k]
                        else:
                            # nothing to copy: fall back to the literal <unk> word
                            token = primitive_vocab.id2word[primitive_vocab.unk_id]
                    else:
                        token = primitive_vocab.id2word[token_id]

                    action = GenTokenAction(token)

                    if token in copy_info:
                        # the token is available in the source: record it as a copy
                        action_info.copy_from_src = True
                        action_info.src_token_position = copy_info[token][0]

                action_info.action = action
                action_info.t = t
                if t > 0:
                    # frontier information is only recorded after the first step
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score
                if new_hyp.completed:
                    # NOTE(review): completed hypotheses with exactly two actions
                    # are discarded; presumably this filters out degenerate
                    # (empty) parses -- confirm with the author
                    if len(new_hyp.actions) != 2:
                        completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                # carry the decoder state of each surviving hypothesis' parent
                # forward; everything is re-indexed by live_hyp_ids so it stays
                # aligned with the new `hypotheses` list
                hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids]
                h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                att_tm1 = att_t[live_hyp_ids]
                hypotheses = new_hypotheses
                hyp_scores = Variable(self.new_tensor([hyp.score for hyp in hypotheses]))
                t += 1
            else:
                # no live hypothesis left to expand
                break

        # best (highest-scoring) hypothesis first
        completed_hypotheses.sort(key=lambda hyp: -hyp.score)

        return completed_hypotheses
コード例 #10
0
    def parse(self, question, context, beam_size=5):
        table = context
        args = self.args
        src_sent_var = nn_utils.to_input_variable([question], self.vocab.source,
                                                  cuda=self.args.cuda, training=False)

        utterance_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(question)])
        dec_init_vec = self.init_decoder_state(last_state, last_cell)

        column_word_encodings, table_header_encoding, table_header_mask = self.encode_table_header([table])

        h_tm1 = dec_init_vec
        # (batch_size, query_len, hidden_size)
        utterance_encodings_att_linear = self.att_src_linear(utterance_encodings)

        zero_action_embed = Variable(self.new_tensor(self.args.action_embed_size).zero_())

        t = 0
        hypotheses = [DecodeHypothesis()]
        hyp_states = [[]]
        completed_hypotheses = []

        while len(completed_hypotheses) < beam_size and t < self.args.decode_max_time_step:
            hyp_num = len(hypotheses)

            # (hyp_num, src_sent_len, hidden_size * 2)
            exp_src_encodings = utterance_encodings.expand(hyp_num, utterance_encodings.size(1), utterance_encodings.size(2))
            # (hyp_num, src_sent_len, hidden_size)
            exp_src_encodings_att_linear = utterance_encodings_att_linear.expand(hyp_num,
                                                                                 utterance_encodings_att_linear.size(1),
                                                                                 utterance_encodings_att_linear.size(2))

            # x: [prev_action, parent_production_embed, parent_field_embed, parent_field_type_embed, parent_action_state]
            if t == 0:
                x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)

                if args.no_parent_field_type_embed is False:
                    offset = args.action_embed_size  # prev_action
                    offset += args.hidden_size * (not args.no_input_feed)
                    offset += args.action_embed_size * (not args.no_parent_production_embed)
                    offset += args.field_embed_size * (not args.no_parent_field_embed)

                    x[0, offset: offset + args.type_embed_size] = \
                        self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
            else:
                a_tm1_embeds = []
                for e_id, hyp in enumerate(hypotheses):
                    action_tm1 = hyp.actions[-1]
                    if action_tm1:
                        if isinstance(action_tm1, ApplyRuleAction):
                            a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
                        elif isinstance(action_tm1, ReduceAction):
                            a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                        elif isinstance(action_tm1, WikiSqlSelectColumnAction):
                            a_tm1_embed = self.column_rnn_input(table_header_encoding[0, action_tm1.column_id])
                        elif isinstance(action_tm1, GenTokenAction):
                            a_tm1_embed = self.src_embed.weight[self.vocab.source[action_tm1.token]]
                        else:
                            raise ValueError('unknown action %s' % action_tm1)
                    else:
                        a_tm1_embed = zero_action_embed

                    a_tm1_embeds.append(a_tm1_embed)

                a_tm1_embeds = torch.stack(a_tm1_embeds)

                inputs = [a_tm1_embeds]
                if args.no_input_feed is False:
                    inputs.append(att_tm1)
                if args.no_parent_production_embed is False:
                    # frontier production
                    frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                    frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                        [self.grammar.prod2id[prod] for prod in frontier_prods])))
                    inputs.append(frontier_prod_embeds)
                if args.no_parent_field_embed is False:
                    # frontier field
                    frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                    frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([
                        self.grammar.field2id[field] for field in frontier_fields])))

                    inputs.append(frontier_field_embeds)
                if args.no_parent_field_type_embed is False:
                    # frontier field type
                    frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                    frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([
                        self.grammar.type2id[type] for type in frontier_field_types])))
                    inputs.append(frontier_field_type_embeds)

                # parent states
                if args.no_parent_state is False:
                    p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                    parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)])
                    parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)])

                    if args.lstm == 'parent_feed':
                        h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells)
                    else:
                        inputs.append(parent_states)

                x = torch.cat(inputs, dim=-1)

            (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
                                             exp_src_encodings_att_linear,
                                             src_token_mask=None)

            # ApplyRule action probability
            # (batch_size, grammar_size)
            apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

            # column attention
            # (batch_size, max_head_num)
            column_attention_weights = self.column_pointer_net(table_header_encoding, table_header_mask,
                                                               att_t.unsqueeze(0)).squeeze(0)
            column_selection_log_prob = torch.log(column_attention_weights)

            # (batch_size, 2)
            primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

            # primitive copy prob
            # (batch_size, src_token_num)
            primitive_copy_prob = self.src_pointer_net(utterance_encodings, None,
                                                       att_t.unsqueeze(0)).squeeze(0)

            # (batch_size, primitive_vocab_size)
            primitive_gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

            new_hyp_meta = []

            for hyp_id, hyp in enumerate(hypotheses):
                # generate new continuations
                action_types = self.transition_system.get_valid_continuation_types(hyp)

                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_log_prob[hyp_id, prod_id]
                            new_hyp_score = hyp.score + prod_score

                            meta_entry = {'action_type': 'apply_rule', 'prod_id': prod_id,
                                          'score': prod_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    elif action_type == ReduceAction:
                        action_score = apply_rule_log_prob[hyp_id, len(self.grammar)]
                        new_hyp_score = hyp.score + action_score

                        meta_entry = {'action_type': 'apply_rule', 'prod_id': len(self.grammar),
                                      'score': action_score, 'new_hyp_score': new_hyp_score,
                                      'prev_hyp_id': hyp_id}
                        new_hyp_meta.append(meta_entry)
                    elif action_type == WikiSqlSelectColumnAction:
                        for col_id, column in enumerate(table.header):
                            col_sel_score = column_selection_log_prob[hyp_id, col_id]
                            new_hyp_score = hyp.score + col_sel_score

                            meta_entry = {'action_type': 'sel_col', 'col_id': col_id,
                                          'score': col_sel_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    elif action_type == GenTokenAction:
                        # remember that we can only copy stuff from the input!
                        # we only copy tokens sequentially!!
                        prev_action = hyp.action_infos[-1].action

                        valid_token_pos_list = []
                        if type(prev_action) is GenTokenAction and \
                                not prev_action.is_stop_signal():
                            token_pos = hyp.action_infos[-1].src_token_position + 1
                            if token_pos < len(question):
                                valid_token_pos_list = [token_pos]
                        else:
                            valid_token_pos_list = list(range(len(question)))

                        col_id = hyp.frontier_node['col_idx'].value
                        if table.header[col_id].type == 'real':
                            valid_token_pos_list = [i for i in valid_token_pos_list
                                                    if any(c.isdigit() for c in question[i]) or
                                                    hyp._value_buffer and question[i] in (',', '.', '-', '%')]

                        p_copies = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id]
                        for token_pos in valid_token_pos_list:
                            token = question[token_pos]
                            p_copy = p_copies[token_pos]
                            score_copy = torch.log(p_copy)

                            meta_entry = {'action_type': 'gen_token',
                                          'token': token, 'token_pos': token_pos,
                                          'score': score_copy, 'new_hyp_score': score_copy + hyp.score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)

                        # add generation probability for </primitive>
                        if hyp._value_buffer:
                            eos_prob = primitive_predictor_prob[hyp_id, 0] * \
                                       primitive_gen_from_vocab_prob[hyp_id, self.vocab.primitive['</primitive>']]
                            eos_score = torch.log(eos_prob)

                            meta_entry = {'action_type': 'gen_token',
                                          'token': '</primitive>',
                                          'score': eos_score, 'new_hyp_score': eos_score + hyp.score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)

            if not new_hyp_meta: break

            new_hyp_scores = torch.cat([x['new_hyp_score'] for x in new_hyp_meta])
            top_new_hyp_scores, meta_ids = torch.topk(new_hyp_scores,
                                                      k=min(new_hyp_scores.size(0),
                                                            beam_size - len(completed_hypotheses)))

            live_hyp_ids = []
            new_hypotheses = []
            for new_hyp_score, meta_id in zip(top_new_hyp_scores.data.cpu(), meta_ids.data.cpu()):
                action_info = ActionInfo()
                hyp_meta_entry = new_hyp_meta[meta_id]
                prev_hyp_id = hyp_meta_entry['prev_hyp_id']
                prev_hyp = hypotheses[prev_hyp_id]

                action_type_str = hyp_meta_entry['action_type']
                if action_type_str == 'apply_rule':
                    # ApplyRule action
                    prod_id = hyp_meta_entry['prod_id']
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    # Reduce action
                    else:
                        action = ReduceAction()
                elif action_type_str == 'sel_col':
                    action = WikiSqlSelectColumnAction(hyp_meta_entry['col_id'])
                else:
                    action = GenTokenAction(hyp_meta_entry['token'])
                    if 'token_pos' in hyp_meta_entry:
                        action_info.copy_from_src = True
                        action_info.src_token_position = hyp_meta_entry['token_pos']

                action_info.action = action
                action_info.t = t

                if t > 0:
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score

                if new_hyp.completed:
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids]
                h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                att_tm1 = att_t[live_hyp_ids]
                hypotheses = new_hypotheses
                t += 1
            else: break

        completed_hypotheses.sort(key=lambda hyp: -hyp.score)

        return completed_hypotheses
コード例 #11
0
    def beam_search(self,
                    src_sents,
                    decode_max_time_step,
                    beam_size=5,
                    to_word=True):
        """Beam-search decode one source sentence into its n-best target sequences.

        Args:
            src_sents: input handed to `nn_utils.to_input_variable`; only
                `len(src_sents[0])` is passed to the encoder as the length,
                so a single (non-batched) sentence is expected.
            decode_max_time_step: maximum number of decoding steps.
            beam_size: number of completed hypotheses to collect.
            to_word: if True, map target ids back to words through
                `self.tgt_vocab.id2word`; otherwise return id sequences.

        Returns:
            A list of hypotheses ordered by descending score. Each hypothesis
            is a list of words (or ids when `to_word` is False) without the
            `<s>` / `</s>` boundary symbols.
        """
        src_sents_var = nn_utils.to_input_variable(src_sents,
                                                   self.src_vocab,
                                                   cuda=self.cuda,
                                                   training=False,
                                                   append_boundary_sym=False)

        #TODO(junxian): check if src_sents_var(src_seq_length, embed_size) is ok
        src_encodings, (last_state,
                        last_cell) = self.encode(src_sents_var,
                                                 [len(src_sents[0])])
        # (1, query_len, hidden_size * 2)
        src_encodings = src_encodings.permute(1, 0, 2)
        src_encodings_att_linear = self.att_src_linear(src_encodings)
        h_tm1 = self.init_decoder_state(last_state, last_cell)

        # tensor constructors
        new_float_tensor = src_encodings.data.new
        if self.cuda:
            new_long_tensor = torch.cuda.LongTensor
        else:
            new_long_tensor = torch.LongTensor

        att_tm1 = Variable(torch.zeros(1, self.hidden_size), volatile=True)
        hyp_scores = Variable(torch.zeros(1), volatile=True)
        if self.cuda:
            att_tm1 = att_tm1.cuda()
            hyp_scores = hyp_scores.cuda()

        eos_id = self.tgt_vocab['</s>']
        bos_id = self.tgt_vocab['<s>']
        tgt_vocab_size = len(self.tgt_vocab)

        hypotheses = [[bos_id]]
        completed_hypotheses = []
        completed_hypothesis_scores = []

        t = 0
        while len(
                completed_hypotheses) < beam_size and t < decode_max_time_step:
            t += 1
            hyp_num = len(hypotheses)

            expanded_src_encodings = src_encodings.expand(
                hyp_num, src_encodings.size(1), src_encodings.size(2))
            expanded_src_encodings_att_linear = src_encodings_att_linear.expand(
                hyp_num, src_encodings_att_linear.size(1),
                src_encodings_att_linear.size(2))

            y_tm1 = Variable(new_long_tensor([hyp[-1] for hyp in hypotheses]),
                             volatile=True)
            y_tm1_embed = self.tgt_embed(y_tm1)

            x = torch.cat([y_tm1_embed, att_tm1], 1)

            # h_t: (hyp_num, hidden_size)
            (h_t, cell_t), att_t, score_t = self.step(
                x,
                h_tm1,
                expanded_src_encodings,
                expanded_src_encodings_att_linear,
                src_sent_masks=None)

            # explicit dim: the implicit-dim form is deprecated
            p_t = F.log_softmax(score_t, dim=-1)

            live_hyp_num = beam_size - len(completed_hypotheses)
            new_hyp_scores = (hyp_scores.unsqueeze(1).expand_as(p_t) +
                              p_t).view(-1)
            # never ask topk for more candidates than exist
            top_new_hyp_scores, top_new_hyp_pos = torch.topk(
                new_hyp_scores,
                k=min(live_hyp_num, new_hyp_scores.size(0)))

            new_hypotheses = []
            live_hyp_ids = []
            live_hyp_scores = []
            for flat_pos, new_hyp_score in zip(
                    top_new_hyp_pos.cpu().data,
                    top_new_hyp_scores.cpu().data):
                # Work with plain Python numbers: `/` on integer tensors is
                # true division on current PyTorch, and tensor scalars are
                # not usable as list indices or vocabulary keys.
                prev_hyp_id, word_id = divmod(int(flat_pos), tgt_vocab_size)
                new_hyp_score = float(new_hyp_score)

                hyp_tgt_words = hypotheses[prev_hyp_id] + [word_id]
                if word_id == eos_id:
                    completed_hypotheses.append(
                        hyp_tgt_words[1:-1]
                    )  # remove <s> and </s> in completed hypothesis
                    completed_hypothesis_scores.append(new_hyp_score)
                else:
                    new_hypotheses.append(hyp_tgt_words)
                    live_hyp_ids.append(prev_hyp_id)
                    live_hyp_scores.append(new_hyp_score)

            # stop when the beam is full or no live hypothesis is left to extend
            if len(completed_hypotheses) == beam_size or not new_hypotheses:
                break

            live_hyp_ids = new_long_tensor(live_hyp_ids)
            h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
            att_tm1 = att_t[live_hyp_ids]

            hyp_scores = Variable(new_float_tensor(live_hyp_scores),
                                  volatile=True)
            hypotheses = new_hypotheses

        if len(completed_hypotheses) == 0:
            # Nothing finished within the step budget: fall back to the best
            # live hypothesis. It carries <s> but no </s>, so only the leading
            # symbol is stripped.
            completed_hypotheses = [hypotheses[0][1:]]
            completed_hypothesis_scores = [0.0]

        if to_word:
            for i, hyp in enumerate(completed_hypotheses):
                completed_hypotheses[i] = [
                    self.tgt_vocab.id2word[w] for w in hyp
                ]

        ranked_hypotheses = sorted(zip(completed_hypotheses,
                                       completed_hypothesis_scores),
                                   key=lambda x: x[1],
                                   reverse=True)

        return [hyp for hyp, score in ranked_hypotheses]
コード例 #12
0
    def parse(self, src_sent, context=None, beam_size=5, debug=False):
        """Perform beam search to infer the target AST given a source utterance

        Args:
            src_sent: list of source utterance tokens
            context: other context used for prediction (not read by this method)
            beam_size: beam size
            debug: if True, record diagnostic values on every `ActionInfo`
                (`action_prob`, and for GenToken actions `gen_copy_switch`,
                `in_vocab`, `gen_token_prob` and `copy_token_prob`)

        Returns:
            A list of `DecodeHypothesis`, each representing an AST, sorted by
            descending score
        """

        with torch.no_grad():
            args = self.args
            primitive_vocab = self.vocab.primitive
            # the production readout uses index len(grammar) for the Reduce action
            reduce_id = len(self.grammar)

            src_sent_var = nn_utils.to_input_variable([src_sent],
                                                      self.vocab.source,
                                                      cuda=args.cuda,
                                                      training=False)

            # Variable(1, src_sent_len, hidden_size)
            src_encodings = self.encode(src_sent_var)

            zero_action_embed = torch.zeros(args.action_embed_size)

            hyp_scores = torch.tensor([0.0])

            # For computing copy probabilities, we marginalize over tokens with the same surface form
            # `aggregated_primitive_tokens` stores the position of occurrence of each source token
            aggregated_primitive_tokens = OrderedDict()
            for token_pos, token in enumerate(src_sent):
                aggregated_primitive_tokens.setdefault(token,
                                                       []).append(token_pos)

            t = 0
            hypotheses = [DecodeHypothesis()]
            completed_hypotheses = []

            while (len(completed_hypotheses) < beam_size
                   and t < args.decode_max_time_step):
                hyp_num = len(hypotheses)

                # (hyp_num, src_sent_len, hidden_size)
                exp_src_encodings = src_encodings.expand(
                    hyp_num, src_encodings.size(1), src_encodings.size(2))

                if t == 0:
                    x = torch.zeros(1, self.d_model)
                    parent_ids = np.array([[0]])
                    if args.no_parent_field_type_embed is False:
                        offset = args.action_embed_size  # prev_action
                        offset += args.action_embed_size * (
                            not args.no_parent_production_embed)
                        offset += args.field_embed_size * (
                            not args.no_parent_field_embed)

                        x[0, offset:offset +
                          self.type_embed_size] = self.type_embed(
                              torch.tensor(self.grammar.type2id[
                                  self.grammar.root_type]))
                    # Always add the step axis: later steps concatenate new
                    # inputs along dim=1, so x must be 3-D whether or not the
                    # field-type embedding is enabled.
                    x = x.unsqueeze(-2)
                else:
                    actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

                    a_tm1_embeds = []
                    for a_tm1 in actions_tm1:
                        if a_tm1:
                            if isinstance(a_tm1, ApplyRuleAction):
                                a_tm1_embed = self.production_embed(
                                    torch.tensor(self.grammar.prod2id[
                                        a_tm1.production]))
                            elif isinstance(a_tm1, ReduceAction):
                                a_tm1_embed = self.production_embed(
                                    torch.tensor(reduce_id))
                            else:
                                a_tm1_embed = self.primitive_embed(
                                    torch.tensor(
                                        self.vocab.primitive[a_tm1.token]))

                            a_tm1_embeds.append(a_tm1_embed)
                        else:
                            a_tm1_embeds.append(zero_action_embed)
                    a_tm1_embeds = torch.stack(a_tm1_embeds)

                    inputs = [a_tm1_embeds]
                    if args.no_parent_production_embed is False:
                        # frontier production
                        frontier_prods = [
                            hyp.frontier_node.production for hyp in hypotheses
                        ]
                        frontier_prod_embeds = self.production_embed(
                            torch.tensor(
                                [self.grammar.prod2id[prod]
                                 for prod in frontier_prods],
                                dtype=torch.long))
                        inputs.append(frontier_prod_embeds)
                    if args.no_parent_field_embed is False:
                        # frontier field
                        frontier_fields = [
                            hyp.frontier_field.field for hyp in hypotheses
                        ]
                        frontier_field_embeds = self.field_embed(
                            torch.tensor(
                                [self.grammar.field2id[field]
                                 for field in frontier_fields],
                                dtype=torch.long))

                        inputs.append(frontier_field_embeds)
                    if args.no_parent_field_type_embed is False:
                        # frontier field type
                        frontier_field_types = [
                            hyp.frontier_field.type for hyp in hypotheses
                        ]
                        frontier_field_type_embeds = self.type_embed(
                            torch.tensor(
                                [self.grammar.type2id[field_type]
                                 for field_type in frontier_field_types],
                                dtype=torch.long))
                        inputs.append(frontier_field_type_embeds)

                    x = torch.cat(
                        [x, torch.cat(inputs, dim=-1).unsqueeze(-2)], dim=1)
                    # One column per hypothesis; the fallback is wrapped in a
                    # list too so the array stays rectangular for hstack.
                    recent_parents = np.array(
                        [[hyp.frontier_node.created_time]
                         if hyp.frontier_node else [0] for hyp in hypotheses])
                    parent_ids = np.hstack([parent_ids, recent_parents])

                src_mask = torch.ones(
                    exp_src_encodings.shape[:-1],
                    dtype=torch.uint8,
                    device=exp_src_encodings.device).unsqueeze(-2)
                tgt_mask = subsequent_mask(x.shape[-2])

                # keep only the decoder output of the newest step
                att_t = self.decoder(x, exp_src_encodings, [parent_ids],
                                     src_mask, tgt_mask)[:, -1]

                # Variable(batch_size, grammar_size)
                apply_rule_log_prob = F.log_softmax(
                    self.production_readout(att_t), dim=-1)

                # Variable(batch_size, primitive_vocab_size)
                gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t),
                                                dim=-1)

                if args.no_copy:
                    primitive_prob = gen_from_vocab_prob
                else:
                    # Variable(batch_size, src_sent_len)
                    primitive_copy_prob = self.src_pointer_net(
                        src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

                    # Variable(batch_size, 2)
                    primitive_predictor_prob = F.softmax(
                        self.primitive_predictor(att_t), dim=-1)

                    # Variable(batch_size, primitive_vocab_size)
                    primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(
                        1) * gen_from_vocab_prob

                gentoken_prev_hyp_ids = []
                gentoken_new_hyp_unks = []
                applyrule_new_hyp_scores = []
                applyrule_new_hyp_prod_ids = []
                applyrule_prev_hyp_ids = []

                for hyp_id, hyp in enumerate(hypotheses):
                    # generate new continuations
                    action_types = self.transition_system.get_valid_continuation_types(
                        hyp)

                    for action_type in action_types:
                        if action_type == ApplyRuleAction:
                            productions = self.transition_system.get_valid_continuating_productions(
                                hyp)
                            for production in productions:
                                prod_id = self.grammar.prod2id[production]
                                prod_score = apply_rule_log_prob[
                                    hyp_id, prod_id].item()
                                new_hyp_score = hyp.score + prod_score

                                applyrule_new_hyp_scores.append(new_hyp_score)
                                applyrule_new_hyp_prod_ids.append(prod_id)
                                applyrule_prev_hyp_ids.append(hyp_id)
                        elif action_type == ReduceAction:
                            action_score = apply_rule_log_prob[
                                hyp_id, reduce_id].item()
                            new_hyp_score = hyp.score + action_score

                            applyrule_new_hyp_scores.append(new_hyp_score)
                            applyrule_new_hyp_prod_ids.append(reduce_id)
                            applyrule_prev_hyp_ids.append(hyp_id)
                        else:
                            # GenToken action
                            gentoken_prev_hyp_ids.append(hyp_id)
                            hyp_copy_info = dict()  # of (token_pos, copy_prob)
                            hyp_unk_copy_info = []

                            if args.no_copy is False:
                                for (token, token_pos_list
                                     ) in aggregated_primitive_tokens.items():
                                    # total copy mass over every occurrence of
                                    # this surface form in the source
                                    sum_copy_prob = torch.gather(
                                        primitive_copy_prob[hyp_id], 0,
                                        torch.tensor(token_pos_list,
                                                     dtype=torch.long)).sum()
                                    gated_copy_prob = primitive_predictor_prob[
                                        hyp_id, 1] * sum_copy_prob

                                    if token in primitive_vocab:
                                        token_id = primitive_vocab[token]
                                        primitive_prob[hyp_id, token_id] = (
                                            primitive_prob[hyp_id, token_id] +
                                            gated_copy_prob)

                                        hyp_copy_info[token] = (
                                            token_pos_list,
                                            gated_copy_prob.item())
                                    else:
                                        hyp_unk_copy_info.append({
                                            "token": token,
                                            "token_pos_list": token_pos_list,
                                            "copy_prob":
                                            gated_copy_prob.item(),
                                        })

                            if args.no_copy is False and len(
                                    hyp_unk_copy_info) > 0:
                                # the <unk> slot stands for the most probable
                                # out-of-vocabulary source token
                                unk_i = np.array([
                                    x["copy_prob"] for x in hyp_unk_copy_info
                                ]).argmax()
                                token = hyp_unk_copy_info[unk_i]["token"]
                                primitive_prob[
                                    hyp_id,
                                    primitive_vocab.unk_id] = hyp_unk_copy_info[
                                        unk_i]["copy_prob"]
                                gentoken_new_hyp_unks.append(token)

                                hyp_copy_info[token] = (
                                    hyp_unk_copy_info[unk_i]["token_pos_list"],
                                    hyp_unk_copy_info[unk_i]["copy_prob"],
                                )

                # Flat candidate vector: ApplyRule/Reduce scores first, then one
                # row of primitive_vocab_size scores per GenToken hypothesis.
                new_hyp_scores = None
                if applyrule_new_hyp_scores:
                    new_hyp_scores = torch.tensor(applyrule_new_hyp_scores)
                if gentoken_prev_hyp_ids:
                    primitive_log_prob = torch.log(primitive_prob)
                    gen_token_new_hyp_scores = (
                        hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) +
                        primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

                    if new_hyp_scores is None:
                        new_hyp_scores = gen_token_new_hyp_scores
                    else:
                        new_hyp_scores = torch.cat(
                            [new_hyp_scores, gen_token_new_hyp_scores])

                # no hypothesis has a valid continuation: stop searching
                if new_hyp_scores is None:
                    break

                top_new_hyp_scores, top_new_hyp_pos = torch.topk(
                    new_hyp_scores,
                    k=min(new_hyp_scores.size(0),
                          beam_size - len(completed_hypotheses)))

                live_hyp_ids = []
                new_hypotheses = []
                for new_hyp_score, new_hyp_pos in zip(
                        top_new_hyp_scores.data.cpu(),
                        top_new_hyp_pos.data.cpu()):
                    action_info = ActionInfo()
                    if new_hyp_pos < len(applyrule_new_hyp_scores):
                        # it's an ApplyRule or Reduce action
                        prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                        prev_hyp = hypotheses[prev_hyp_id]

                        prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                        # ApplyRule action
                        if prod_id < reduce_id:
                            production = self.grammar.id2prod[prod_id]
                            action = ApplyRuleAction(production)
                        # Reduce action
                        else:
                            action = ReduceAction()
                    else:
                        # it's a GenToken action
                        token_id = int(
                            (new_hyp_pos - len(applyrule_new_hyp_scores)) %
                            primitive_prob.size(1))

                        k = (new_hyp_pos - len(applyrule_new_hyp_scores)
                             ) // primitive_prob.size(1)
                        prev_hyp_id = gentoken_prev_hyp_ids[k]
                        prev_hyp = hypotheses[prev_hyp_id]

                        if token_id == int(primitive_vocab.unk_id):
                            if gentoken_new_hyp_unks:
                                token = gentoken_new_hyp_unks[k]
                            else:
                                token = primitive_vocab.id2word[
                                    primitive_vocab.unk_id]
                        else:
                            token = primitive_vocab.id2word[token_id]

                        action = GenTokenAction(token)

                        if token in aggregated_primitive_tokens:
                            action_info.copy_from_src = True
                            action_info.src_token_position = aggregated_primitive_tokens[
                                token]

                        if debug:
                            action_info.gen_copy_switch = (
                                "n/a"
                                if args.no_copy else primitive_predictor_prob[
                                    prev_hyp_id, :].log().cpu().data.numpy())
                            action_info.in_vocab = token in primitive_vocab
                            action_info.gen_token_prob = (
                                gen_from_vocab_prob[prev_hyp_id,
                                                    token_id].log().cpu().
                                item() if token in primitive_vocab else "n/a")
                            action_info.copy_token_prob = (
                                torch.gather(
                                    primitive_copy_prob[prev_hyp_id],
                                    0,
                                    torch.tensor(
                                        action_info.src_token_position,
                                        dtype=torch.long,
                                        device=self.device),
                                ).sum().log().cpu().item()
                                if args.no_copy is False
                                and action_info.copy_from_src else "n/a")

                    action_info.action = action
                    action_info.t = t
                    if t > 0:
                        action_info.parent_t = prev_hyp.frontier_node.created_time
                        action_info.frontier_prod = prev_hyp.frontier_node.production
                        action_info.frontier_field = prev_hyp.frontier_field.field

                    if debug:
                        action_info.action_prob = new_hyp_score - prev_hyp.score

                    new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                    new_hyp.score = new_hyp_score

                    if new_hyp.completed:
                        completed_hypotheses.append(new_hyp)
                    else:
                        new_hypotheses.append(new_hyp)
                        live_hyp_ids.append(prev_hyp_id)

                if live_hyp_ids:
                    # carry over the decoder inputs and parent indices of the
                    # hypotheses that are still being expanded
                    x = x[live_hyp_ids]
                    parent_ids = parent_ids[live_hyp_ids]
                    hypotheses = new_hypotheses
                    hyp_scores = torch.tensor(
                        [hyp.score for hyp in hypotheses])
                    t += 1
                else:
                    break

            completed_hypotheses.sort(key=lambda hyp: -hyp.score)

            return completed_hypotheses
コード例 #13
0
ファイル: parser.py プロジェクト: chubbymaggie/tranX
    def parse(self, src_sent, context=None, beam_size=5):
        """Perform beam search to infer the target AST given a source utterance

        Each decoding step scores every valid continuation (ApplyRule, Reduce
        or GenToken) of every live hypothesis, keeps the top-scoring
        candidates, and moves finished hypotheses to the completed list.

        Args:
            src_sent: list of source utterance tokens
            context: other context used for prediction (not read anywhere in
                this method)
            beam_size: beam size; decoding stops once this many hypotheses
                have completed

        Returns:
            A list of `DecodeHypothesis`, each representing an AST, sorted by
            descending score. It may hold fewer than `beam_size` entries (or
            be empty) when `args.decode_max_time_step` is reached or no live
            hypothesis remains.
        """

        args = self.args
        primitive_vocab = self.vocab.primitive

        src_sent_var = nn_utils.to_input_variable([src_sent], self.vocab.source, cuda=args.cuda, training=False)

        # Variable(1, src_sent_len, hidden_size * 2)
        src_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(src_sent)])
        # (1, src_sent_len, hidden_size)
        src_encodings_att_linear = self.att_src_linear(src_encodings)

        dec_init_vec = self.init_decoder_state(last_state, last_cell)
        if args.lstm == 'parent_feed':
            # the parent-feeding decoder state is a 4-tuple; the last two slots
            # start as zeros and are replaced by the parent hidden/cell states
            # on later steps (see the "parent states" section below)
            h_tm1 = dec_init_vec[0], dec_init_vec[1], \
                    Variable(self.new_tensor(args.hidden_size).zero_()), \
                    Variable(self.new_tensor(args.hidden_size).zero_())
        else:
            h_tm1 = dec_init_vec

        # stand-in embedding for hypotheses whose previous action is falsy
        zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_())

        # scores of the live hypotheses; the single initial hypothesis has score 0
        hyp_scores = Variable(self.new_tensor([0.]), volatile=True)

        src_token_vocab_ids = [primitive_vocab[token] for token in src_sent]
        # source positions whose token maps to the unknown-word id
        src_unk_pos_list = [pos for pos, token_id in enumerate(src_token_vocab_ids) if token_id == primitive_vocab.unk_id]
        # sometimes a word may appear multi-times in the source, in this case,
        # we just copy its first appearing position. Therefore we mask the words
        # appearing second and onwards to -1
        token_set = set()
        for i, tid in enumerate(src_token_vocab_ids):
            if tid in token_set:
                src_token_vocab_ids[i] = -1
            else: token_set.add(tid)

        t = 0
        hypotheses = [DecodeHypothesis()]
        # hyp_states[i] is the list of (h, cell) pairs hypothesis i went through,
        # one per past time step; indexed by `created_time` to fetch parent states
        hyp_states = [[]]
        completed_hypotheses = []

        while len(completed_hypotheses) < beam_size and t < args.decode_max_time_step:
            hyp_num = len(hypotheses)

            # (hyp_num, src_sent_len, hidden_size * 2)
            exp_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1), src_encodings.size(2))
            # (hyp_num, src_sent_len, hidden_size)
            exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num, src_encodings_att_linear.size(1), src_encodings_att_linear.size(2))

            if t == 0:
                # first step: the decoder input is all zeros, except that the
                # type-embedding slice is set to the embedding of the root type
                x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)
                if args.no_parent_field_type_embed is False:
                    # locate the type-embedding slice inside the input vector;
                    # each boolean multiplies as 0/1, so a disabled feature
                    # contributes no width to the offset
                    offset = args.action_embed_size  # prev_action
                    offset += args.att_vec_size * (not args.no_input_feed)
                    offset += args.action_embed_size * (not args.no_parent_production_embed)
                    offset += args.field_embed_size * (not args.no_parent_field_embed)

                    x[0, offset: offset + args.type_embed_size] = \
                        self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
            else:
                actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

                # embedding of the previous action of every live hypothesis
                a_tm1_embeds = []
                for a_tm1 in actions_tm1:
                    if a_tm1:
                        if isinstance(a_tm1, ApplyRuleAction):
                            a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[a_tm1.production]]
                        elif isinstance(a_tm1, ReduceAction):
                            # Reduce uses the row at index len(self.grammar),
                            # i.e. one past the ids used for productions
                            a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                        else:
                            # any other action is treated as a token-generating
                            # action and embedded through the primitive vocabulary
                            a_tm1_embed = self.primitive_embed.weight[self.vocab.primitive[a_tm1.token]]

                        a_tm1_embeds.append(a_tm1_embed)
                    else:
                        a_tm1_embeds.append(zero_action_embed)
                a_tm1_embeds = torch.stack(a_tm1_embeds)

                # the decoder input is the concatenation of the enabled features,
                # in the same order used to compute `offset` at t == 0
                inputs = [a_tm1_embeds]
                if args.no_input_feed is False:
                    # att_tm1 was set at the end of the previous iteration
                    inputs.append(att_tm1)
                if args.no_parent_production_embed is False:
                    # frontier production
                    frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                    frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                        [self.grammar.prod2id[prod] for prod in frontier_prods])))
                    inputs.append(frontier_prod_embeds)
                if args.no_parent_field_embed is False:
                    # frontier field
                    frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                    frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([
                        self.grammar.field2id[field] for field in frontier_fields])))

                    inputs.append(frontier_field_embeds)
                if args.no_parent_field_type_embed is False:
                    # frontier field type
                    frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                    frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([
                        self.grammar.type2id[type] for type in frontier_field_types])))
                    inputs.append(frontier_field_type_embeds)

                # parent states
                if args.no_parent_state is False:
                    # time step at which each hypothesis' frontier node was created
                    p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                    parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)])
                    parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)])

                    if args.lstm == 'parent_feed':
                        # parent state goes into the recurrent state tuple ...
                        h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells)
                    else:
                        # ... otherwise the parent hidden state is an input feature
                        inputs.append(parent_states)

                x = torch.cat(inputs, dim=-1)

            if args.lstm == 'lstm_with_dropout':
                self.decoder_lstm.set_dropout_masks(hyp_num)

            (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
                                             exp_src_encodings_att_linear,
                                             src_token_mask=None)

            # Variable(batch_size, grammar_size)
            apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

            # Variable(batch_size, src_sent_len)
            primitive_copy_prob = self.src_pointer_net(src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

            # Variable(batch_size, primitive_vocab_size)
            gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

            # Variable(batch_size, 2)
            # column 0 weights generation from the vocabulary, column 1 weights
            # copying from the source (see how the columns are used below)
            primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

            # Variable(batch_size, primitive_vocab_size)
            primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(1) * gen_from_vocab_prob
            if src_unk_pos_list:
                # the source contains unknown words: overwrite the probability of
                # generating UNK from the vocabulary with a tiny constant; for
                # token-generating hypotheses it is replaced below by the copy
                # probability of the best unknown source word (if copy is enabled)
                primitive_prob[:, primitive_vocab.unk_id] = 1.e-10

            # book-keeping for the candidate continuations of this step
            gentoken_prev_hyp_ids = []       # hypotheses that may generate a token
            gentoken_new_hyp_unks = []       # per such hypothesis: unknown source word to copy
            gentoken_copy_infos = []         # per such hypothesis: token -> (src position, copy prob)
            applyrule_new_hyp_scores = []    # score of each ApplyRule/Reduce candidate
            applyrule_new_hyp_prod_ids = []  # its production id (len(grammar) means Reduce)
            applyrule_prev_hyp_ids = []      # index of the hypothesis it extends

            for hyp_id, hyp in enumerate(hypotheses):
                # generate new continuations
                action_types = self.transition_system.get_valid_continuation_types(hyp)

                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_log_prob[hyp_id, prod_id].data[0]
                            new_hyp_score = hyp.score + prod_score

                            applyrule_new_hyp_scores.append(new_hyp_score)
                            applyrule_new_hyp_prod_ids.append(prod_id)
                            applyrule_prev_hyp_ids.append(hyp_id)
                    elif action_type == ReduceAction:
                        action_score = apply_rule_log_prob[hyp_id, len(self.grammar)].data[0]
                        new_hyp_score = hyp.score + action_score

                        applyrule_new_hyp_scores.append(new_hyp_score)
                        applyrule_new_hyp_prod_ids.append(len(self.grammar))
                        applyrule_prev_hyp_ids.append(hyp_id)
                    else:
                        # GenToken action
                        # NOTE: this branch mutates this hypothesis' row of
                        # `primitive_prob` in place; the row is scored after the loop
                        gentoken_prev_hyp_ids.append(hyp_id)
                        hyp_copy_info = dict()  # of (token_pos, copy_prob)
                        # first, we compute copy probabilities for tokens in the source sentence
                        for token_pos, token_vocab_id in enumerate(src_token_vocab_ids):
                            # skip repeated occurrences (-1) and unknown words
                            if args.no_copy is False and token_vocab_id != -1 and token_vocab_id != primitive_vocab.unk_id:
                                p_copy = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, token_pos]
                                # add the copy mass to the token's vocabulary entry
                                primitive_prob[hyp_id, token_vocab_id] = primitive_prob[hyp_id, token_vocab_id] + p_copy

                                token = src_sent[token_pos]
                                hyp_copy_info[token] = (token_pos, p_copy.data[0])

                        # second, add the probability of copying the most probable unk word
                        if args.no_copy is False and src_unk_pos_list:
                            # argmax is taken over the unknown positions only, then
                            # mapped back to a position in the source sentence
                            unk_pos = primitive_copy_prob[hyp_id][src_unk_pos_list].data.cpu().numpy().argmax()
                            unk_pos = src_unk_pos_list[unk_pos]
                            token = src_sent[unk_pos]
                            gentoken_new_hyp_unks.append(token)

                            unk_copy_score = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, unk_pos]
                            primitive_prob[hyp_id, primitive_vocab.unk_id] = unk_copy_score

                            hyp_copy_info[token] = (unk_pos, unk_copy_score.data[0])

                        gentoken_copy_infos.append(hyp_copy_info)

            # Candidate score vector layout:
            #   [ApplyRule/Reduce candidates ..., flattened GenToken scores]
            # where the GenToken part is a row-major
            # (len(gentoken_prev_hyp_ids), primitive_vocab_size) matrix.
            # NOTE(review): if no hypothesis has any valid continuation,
            # new_hyp_scores stays None and the torch.topk call below fails.
            new_hyp_scores = None
            if applyrule_new_hyp_scores:
                new_hyp_scores = Variable(self.new_tensor(applyrule_new_hyp_scores))
            if gentoken_prev_hyp_ids:
                primitive_log_prob = torch.log(primitive_prob)
                gen_token_new_hyp_scores = (hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) + primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

                if new_hyp_scores is None: new_hyp_scores = gen_token_new_hyp_scores
                else: new_hyp_scores = torch.cat([new_hyp_scores, gen_token_new_hyp_scores])

            # keep at most as many candidates as there are free slots in the beam
            top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores,
                                                             k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses)))

            live_hyp_ids = []
            new_hypotheses = []
            for new_hyp_score, new_hyp_pos in zip(top_new_hyp_scores.data.cpu(), top_new_hyp_pos.data.cpu()):
                action_info = ActionInfo()
                if new_hyp_pos < len(applyrule_new_hyp_scores):
                    # it's an ApplyRule or Reduce action
                    prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                    prev_hyp = hypotheses[prev_hyp_id]

                    prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                    # ApplyRule action
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    # Reduce action
                    else:
                        action = ReduceAction()
                else:
                    # it's a GenToken action
                    # decode the flat position into (row k, column token_id) of
                    # the GenToken score matrix
                    token_id = (new_hyp_pos - len(applyrule_new_hyp_scores)) % primitive_prob.size(1)

                    k = (new_hyp_pos - len(applyrule_new_hyp_scores)) // primitive_prob.size(1)
                    copy_info = gentoken_copy_infos[k]
                    prev_hyp_id = gentoken_prev_hyp_ids[k]
                    prev_hyp = hypotheses[prev_hyp_id]

                    if token_id == primitive_vocab.unk_id:
                        if gentoken_new_hyp_unks:
                            # emit the unknown source word chosen for this hypothesis
                            token = gentoken_new_hyp_unks[k]
                        else:
                            token = primitive_vocab.id2word[primitive_vocab.unk_id]
                    else:
                        token = primitive_vocab.id2word[token_id]

                    action = GenTokenAction(token)

                    if token in copy_info:
                        action_info.copy_from_src = True
                        action_info.src_token_position = copy_info[token][0]

                action_info.action = action
                action_info.t = t
                if t > 0:
                    # frontier information is only recorded after the first step
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score

                if new_hyp.completed:
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                # carry forward the state of the parent of each surviving
                # hypothesis, re-ordered to match `new_hypotheses`
                hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids]
                h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                att_tm1 = att_t[live_hyp_ids]
                hypotheses = new_hypotheses
                hyp_scores = Variable(self.new_tensor([hyp.score for hyp in hypotheses]))
                t += 1
            else:
                # every selected candidate completed: nothing left to expand
                break

        # best (highest score) hypothesis first
        completed_hypotheses.sort(key=lambda hyp: -hyp.score)

        return completed_hypotheses
コード例 #14
0
def train(args):
    """Train an LSTM language model on source utterances with early stopping.

    The model is trained epoch by epoch on `e.src_sent` of the training
    examples. After every epoch the dev-set perplexity is computed; the best
    model (lowest perplexity) and the optimizer state are checkpointed. When
    the dev score fails to improve `args.patience` times in a row, the best
    checkpoint is reloaded and the learning rate is multiplied by
    `args.lr_decay`. After `args.max_num_trial` such trials the process exits.

    Args:
        args: configuration object. Attributes read here: `train_file`,
            `dev_file`, `vocab`, `embed_size`, `hidden_size`, `dropout`,
            `cuda`, `lr`, `batch_size`, `clip_grad`, `log_every`, `save_to`,
            `patience`, `max_num_trial`, `lr_decay`, `reset_optimizer`.

    Returns:
        Never returns normally; the training loop only ends through
        `sys.exit(0)` on early stop.
    """
    train_set = Dataset.from_bin_file(args.train_file)
    dev_set = Dataset.from_bin_file(args.dev_file)
    # NOTE: unpickling can execute arbitrary code; only load trusted vocab files.
    # The context manager closes the file handle deterministically.
    with open(args.vocab, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    model = LSTMLanguageModel(vocab.source,
                              args.embed_size,
                              args.hidden_size,
                              dropout=args.dropout)
    model.train()
    if args.cuda:
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    def evaluate_ppl():
        """Return the per-word perplexity of the model on the dev set."""
        model.eval()
        cum_loss = 0.
        cum_tgt_words = 0.
        for batch in dev_set.batch_iter(args.batch_size):
            src_sents_var = nn_utils.to_input_variable(
                [e.src_sent for e in batch],
                vocab.source,
                cuda=args.cuda,
                append_boundary_sym=True)
            loss = model(src_sents_var).sum()
            cum_loss += loss.data[0]
            cum_tgt_words += sum(len(e.src_sent) + 1
                                 for e in batch)  # add ending </s>

        ppl = np.exp(cum_loss / cum_tgt_words)
        # switch back to training mode for the caller
        model.train()
        return ppl

    print('begin training decoder, %d training examples, %d dev examples' %
          (len(train_set), len(dev_set)))
    print('vocab size: %d' % len(vocab.source))

    epoch = train_iter = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    num_trial = patience = 0
    while True:
        epoch += 1
        epoch_begin = time.time()

        for batch_examples in train_set.batch_iter(batch_size=args.batch_size,
                                                   shuffle=True):
            # drop examples with overly long action sequences
            batch_examples = [
                e for e in batch_examples if len(e.tgt_actions) <= 100
            ]
            if not batch_examples:
                # every example was filtered out: nothing to train on, and
                # skipping keeps `report_examples` non-zero at logging time
                continue

            src_sents = [e.src_sent for e in batch_examples]
            src_sents_var = nn_utils.to_input_variable(
                src_sents,
                vocab.source,
                cuda=args.cuda,
                append_boundary_sym=True)

            train_iter += 1
            optimizer.zero_grad()

            loss = model(src_sents_var)
            # the summed loss is accumulated for reporting, while the mean is
            # what gets back-propagated
            loss_val = torch.sum(loss).data[0]
            report_loss += loss_val
            report_examples += len(batch_examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' %
                      (train_iter, report_loss / report_examples),
                      file=sys.stderr)

                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        # evaluate ppl
        ppl = evaluate_ppl()
        print('[Epoch %d] ppl=%.5f took %ds' %
              (epoch, ppl, time.time() - eval_start),
              file=sys.stderr)
        # negate so that "higher is better" like an accuracy
        dev_acc = -ppl
        is_better = history_dev_scores == [] or dev_acc > max(
            history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

            if patience == args.patience:
                num_trial += 1
                print('hit #%d trial' % num_trial, file=sys.stderr)
                if num_trial == args.max_num_trial:
                    print('early stop!', file=sys.stderr)
                    # sys.exit instead of the interactive-only `exit` helper
                    sys.exit(0)

                # decay lr, and restore from previously best checkpoint
                lr = optimizer.param_groups[0]['lr'] * args.lr_decay
                print(
                    'load previously best model and decay learning rate to %f'
                    % lr,
                    file=sys.stderr)

                # load model (onto CPU first, then move back to GPU if needed)
                # assumes model.save() stores the weights under 'state_dict'
                # -- TODO confirm against LSTMLanguageModel.save
                params = torch.load(args.save_to + '.bin',
                                    map_location=lambda storage, loc: storage)
                model.load_state_dict(params['state_dict'])
                if args.cuda:
                    model = model.cuda()

                # load optimizers
                if args.reset_optimizer:
                    print('reset optimizer', file=sys.stderr)
                    # optimize the same parameter set as the initial optimizer
                    # (previously referenced `model.inference_model`, which the
                    # initial optimizer above never used)
                    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
                else:
                    print('restore parameters of the optimizers',
                          file=sys.stderr)
                    optimizer.load_state_dict(
                        torch.load(args.save_to + '.optim.bin'))

                # set new lr
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                # reset patience
                patience = 0
コード例 #15
0
ファイル: parser.py プロジェクト: thu-spmi/seq2seq-JAE
    def parse_sample(self, src_sent):
        # get samples of q_{\phi}(z|x) by sampling
        args = self.args
        sample_size = self.sample_size
        primitive_vocab = self.vocab.primitive

        src_sent_var = nn_utils.to_input_variable([src_sent],
                                                  self.vocab.source,
                                                  cuda=args.cuda,
                                                  training=False)

        # Variable(1, src_sent_len, hidden_size * 2)
        src_encodings, (last_state,
                        last_cell) = self.encode(src_sent_var, [len(src_sent)])
        # (1, src_sent_len, hidden_size)
        src_encodings_att_linear = self.att_src_linear(src_encodings)

        h_tm1 = self.init_decoder_state(last_state, last_cell)
        zero_action_embed = Variable(
            self.new_tensor(args.action_embed_size).zero_())

        hyp_scores = Variable(self.new_tensor([0.]), volatile=True)

        src_token_vocab_ids = [primitive_vocab[token] for token in src_sent]
        src_unk_pos_list = [
            pos for pos, token_id in enumerate(src_token_vocab_ids)
            if token_id == primitive_vocab.unk_id
        ]
        # sometimes a word may appear multi-times in the source, in this case,
        # we just copy its first appearing position. Therefore we mask the words
        # appearing second and onwards to -1
        token_set = set()
        for i, tid in enumerate(src_token_vocab_ids):
            if tid in token_set:
                src_token_vocab_ids[i] = -1
            else:
                token_set.add(tid)

        completed_hypotheses = []
        while len(completed_hypotheses) < sample_size:
            t = 0
            hypotheses = [DecodeHypothesis()]
            hyp_states = [[]]
            while t < args.decode_max_time_step:
                hyp_num = len(hypotheses)

                # (hyp_num, src_sent_len, hidden_size * 2)
                exp_src_encodings = src_encodings.expand(
                    hyp_num, src_encodings.size(1), src_encodings.size(2))
                # (hyp_num, src_sent_len, hidden_size)
                exp_src_encodings_att_linear = src_encodings_att_linear.expand(
                    hyp_num, src_encodings_att_linear.size(1),
                    src_encodings_att_linear.size(2))

                if t == 0:
                    x = Variable(self.new_tensor(
                        1, self.decoder_lstm.input_size).zero_(),
                                 volatile=True)
                    offset = args.action_embed_size * 2 + args.field_embed_size
                    x[0, offset:offset +
                      args.type_embed_size] = self.type_embed.weight[
                          self.grammar.type2id[self.grammar.root_type]]
                else:
                    actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

                    a_tm1_embeds = []
                    for a_tm1 in actions_tm1:
                        if a_tm1:
                            if isinstance(a_tm1, ApplyRuleAction):
                                a_tm1_embed = self.production_embed.weight[
                                    self.grammar.prod2id[a_tm1.production]]
                            elif isinstance(a_tm1, ReduceAction):
                                a_tm1_embed = self.production_embed.weight[len(
                                    self.grammar)]
                            else:
                                a_tm1_embed = self.primitive_embed.weight[
                                    self.vocab.primitive[a_tm1.token]]

                            a_tm1_embeds.append(a_tm1_embed)
                        else:
                            a_tm1_embeds.append(zero_action_embed)
                    a_tm1_embeds = torch.stack(a_tm1_embeds)

                    # frontier production
                    frontier_prods = [
                        hyp.frontier_node.production for hyp in hypotheses
                    ]
                    frontier_prod_embeds = self.production_embed(
                        Variable(
                            self.new_long_tensor([
                                self.grammar.prod2id[prod]
                                for prod in frontier_prods
                            ])))

                    # frontier field
                    frontier_fields = [
                        hyp.frontier_field.field for hyp in hypotheses
                    ]
                    frontier_field_embeds = self.field_embed(
                        Variable(
                            self.new_long_tensor([
                                self.grammar.field2id[field]
                                for field in frontier_fields
                            ])))

                    # frontier field type
                    frontier_field_types = [
                        hyp.frontier_field.type for hyp in hypotheses
                    ]
                    frontier_field_type_embeds = self.type_embed(
                        Variable(
                            self.new_long_tensor([
                                self.grammar.type2id[type]
                                for type in frontier_field_types
                            ])))

                    # parent states
                    p_ts = [
                        hyp.frontier_node.created_time for hyp in hypotheses
                    ]
                    hist_states = torch.stack([
                        hyp_states[hyp_id][p_t]
                        for hyp_id, p_t in enumerate(p_ts)
                    ])

                    x = torch.cat([
                        a_tm1_embeds, att_tm1, frontier_prod_embeds,
                        frontier_field_embeds, frontier_field_type_embeds,
                        hist_states
                    ],
                                  dim=-1)

                if args.lstm == 'lstm_with_dropout':
                    self.decoder_lstm.set_dropout_masks(hyp_num)

                (h_t, cell_t), att_t = self.step(x,
                                                 h_tm1,
                                                 exp_src_encodings,
                                                 exp_src_encodings_att_linear,
                                                 src_token_mask=None)

                # Variable(batch_size, grammar_size)
                apply_rule_log_prob = F.log_softmax(
                    self.production_readout(att_t), dim=-1)

                # Variable(batch_size, src_sent_len)
                primitive_copy_prob = self.src_pointer_net(
                    src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

                # Variable(batch_size, primitive_vocab_size)
                gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t),
                                                dim=-1)

                # Variable(batch_size, 2)
                primitive_predictor_prob = F.softmax(
                    self.primitive_predictor(att_t), dim=-1)

                # Variable(batch_size, primitive_vocab_size)
                primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(
                    1) * gen_from_vocab_prob
                if src_unk_pos_list:
                    primitive_prob[:, primitive_vocab.unk_id] = 1.e-10

                gentoken_prev_hyp_ids = []
                gentoken_new_hyp_unks = []
                gentoken_copy_infos = []
                applyrule_new_hyp_scores = []
                applyrule_new_hyp_prod_ids = []
                applyrule_prev_hyp_ids = []

                for hyp_id, hyp in enumerate(hypotheses):
                    # generate new continuations
                    action_types = self.transition_system.get_valid_continuation_types(
                        hyp)

                    for action_type in action_types:
                        if action_type == ApplyRuleAction:
                            productions = self.transition_system.get_valid_continuating_productions(
                                hyp)
                            for production in productions:
                                prod_id = self.grammar.prod2id[production]
                                prod_score = apply_rule_log_prob[
                                    hyp_id, prod_id].data[0]
                                new_hyp_score = hyp.score + prod_score

                                applyrule_new_hyp_scores.append(new_hyp_score)
                                applyrule_new_hyp_prod_ids.append(prod_id)
                                applyrule_prev_hyp_ids.append(hyp_id)
                        elif action_type == ReduceAction:
                            action_score = apply_rule_log_prob[
                                hyp_id, len(self.grammar)].data[0]
                            new_hyp_score = hyp.score + action_score

                            applyrule_new_hyp_scores.append(new_hyp_score)
                            applyrule_new_hyp_prod_ids.append(len(
                                self.grammar))
                            applyrule_prev_hyp_ids.append(hyp_id)
                        else:
                            # GenToken action
                            gentoken_prev_hyp_ids.append(hyp_id)
                            hyp_copy_info = dict()  # of (token_pos, copy_prob)
                            # first, we compute copy probabilities for tokens in the source sentence
                            for token_pos, token_vocab_id in enumerate(
                                    src_token_vocab_ids):
                                if token_vocab_id != -1 and token_vocab_id != primitive_vocab.unk_id:
                                    p_copy = primitive_predictor_prob[
                                        hyp_id,
                                        1] * primitive_copy_prob[hyp_id,
                                                                 token_pos]
                                    primitive_prob[
                                        hyp_id,
                                        token_vocab_id] = primitive_prob[
                                            hyp_id, token_vocab_id] + p_copy

                                    token = src_sent[token_pos]
                                    hyp_copy_info[token] = (token_pos,
                                                            p_copy.data[0])

                            # second, add the probability of copying the most probable unk word
                            if src_unk_pos_list:
                                unk_pos = primitive_copy_prob[
                                    hyp_id][src_unk_pos_list].data.cpu().numpy(
                                    ).argmax()
                                unk_pos = src_unk_pos_list[unk_pos]
                                token = src_sent[unk_pos]
                                gentoken_new_hyp_unks.append(token)

                                unk_copy_score = primitive_predictor_prob[
                                    hyp_id, 1] * primitive_copy_prob[hyp_id,
                                                                     unk_pos]
                                primitive_prob[
                                    hyp_id,
                                    primitive_vocab.unk_id] = unk_copy_score

                                hyp_copy_info[token] = (unk_pos,
                                                        unk_copy_score.data[0])

                            gentoken_copy_infos.append(hyp_copy_info)

                new_hyp_scores = None
                if applyrule_new_hyp_scores:
                    new_hyp_scores = Variable(
                        self.new_tensor(applyrule_new_hyp_scores))
                if gentoken_prev_hyp_ids:
                    primitive_log_prob = torch.log(primitive_prob)
                    gen_token_new_hyp_scores = (
                        hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) +
                        primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

                    if new_hyp_scores is None:
                        new_hyp_scores = gen_token_new_hyp_scores
                    else:
                        new_hyp_scores = torch.cat(
                            [new_hyp_scores, gen_token_new_hyp_scores])

                # top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores,
                #                                                  k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses)))

                pro = F.softmax(new_hyp_scores, 0)
                new_hyp_pos = torch.multinomial(pro, 1).data[0]
                new_hyp_score = new_hyp_scores[new_hyp_pos].data[0]

                live_hyp_ids = []
                new_hypotheses = []

                action_info = ActionInfo()
                if new_hyp_pos < len(applyrule_new_hyp_scores):
                    # it's an ApplyRule or Reduce action
                    prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                    prev_hyp = hypotheses[prev_hyp_id]

                    prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                    # ApplyRule action
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    # Reduce action
                    else:
                        action = ReduceAction()
                else:
                    # it's a GenToken action
                    token_id = (new_hyp_pos - len(applyrule_new_hyp_scores)
                                ) % primitive_prob.size(1)

                    k = (new_hyp_pos - len(applyrule_new_hyp_scores)
                         ) // primitive_prob.size(1)
                    # try:
                    copy_info = gentoken_copy_infos[k]
                    prev_hyp_id = gentoken_prev_hyp_ids[k]
                    prev_hyp = hypotheses[prev_hyp_id]

                    if token_id == primitive_vocab.unk_id:
                        if gentoken_new_hyp_unks:
                            token = gentoken_new_hyp_unks[k]
                        else:
                            token = primitive_vocab.id2word[
                                primitive_vocab.unk_id]
                    else:
                        token = primitive_vocab.id2word[token_id]

                    action = GenTokenAction(token)

                    if token in copy_info:
                        action_info.copy_from_src = True
                        action_info.src_token_position = copy_info[token][0]

                action_info.action = action
                action_info.t = t
                if t > 0:
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score

                if new_hyp.completed:
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

                if live_hyp_ids:
                    hyp_states = [
                        hyp_states[i] + [h_t[i]] for i in live_hyp_ids
                    ]
                    h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                    att_tm1 = att_t[live_hyp_ids]
                    hypotheses = new_hypotheses
                    hyp_scores = Variable(
                        self.new_tensor([hyp.score for hyp in hypotheses]))
                    t += 1
                else:
                    break
        return completed_hypotheses
コード例 #16
0
    def sample(self,
               src_sent,
               sample_size,
               decode_max_time_step,
               cuda=False,
               mode='sample'):
        """Decode target word sequences for a single source sentence.

        At every step the next-word distribution mixes two sources, weighted
        by a two-way predictor: generating a word from the target vocabulary
        and copying a word from a source position.

        Two decoding modes are supported:

        * ``mode == 'beam_search'``: start from one hypothesis and, at each
          step, keep the ``sample_size - len(completed)`` best continuations
          (by accumulated log-probability) via ``torch.topk``.
        * any other value: start from ``sample_size`` hypotheses and extend
          each one with a single word drawn by ``torch.multinomial``.

        :param src_sent: source sentence as a sequence of tokens
        :param sample_size: number of completed hypotheses to collect
        :param decode_max_time_step: maximum number of decoding steps
        :param cuda: build the working tensors with ``torch.cuda`` constructors
        :param mode: ``'beam_search'`` or sampling (the default)
        :return: list of completed hypotheses, each a list of words without
            the leading ``<s>`` and without ``</s>``. Hypotheses still
            unfinished when the step limit is hit are not returned, so the
            list may hold fewer than ``sample_size`` entries. Scores are
            collected internally but not returned.
        """
        new_float_tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
        new_long_tensor = torch.cuda.LongTensor if cuda else torch.LongTensor

        src_sent_var = nn_utils.to_input_variable([src_sent],
                                                  self.src_vocab,
                                                  cuda=cuda,
                                                  training=False)

        # analyze which tokens can be copied from the source
        src_token_tgt_vocab_ids = [self.tgt_vocab[token] for token in src_sent]
        # positions of source tokens that map to <unk> in the target vocab;
        # computed before the de-duplication below overwrites repeated ids
        src_unk_pos_list = [
            pos for pos, token_id in enumerate(src_token_tgt_vocab_ids)
            if token_id == self.tgt_vocab.unk_id
        ]
        # sometimes a word may appear multi-times in the source, in this case,
        # we just copy its first appearing position. Therefore we mask the words
        # appearing second and onwards to -1
        token_set = set()
        for i, tid in enumerate(src_token_tgt_vocab_ids):
            if tid in token_set:
                src_token_tgt_vocab_ids[i] = -1
            else:
                token_set.add(tid)

        src_encodings, (last_state,
                        last_cell) = self.encode(src_sent_var, [len(src_sent)])
        h_tm1 = self.init_decoder_state(last_state, last_cell)

        # (batch_size, 1, hidden_size)
        src_encodings_att_linear = self.att_src_linear(src_encodings)

        t = 0
        eos_id = self.tgt_vocab['</s>']

        completed_hypotheses = []
        completed_hypothesis_scores = []

        # beam search grows the beam from a single <s> hypothesis, whereas
        # sampling runs `sample_size` independent hypotheses from the start
        if mode == 'beam_search':
            hypotheses = [['<s>']]
            hypotheses_word_ids = [[self.tgt_vocab['<s>']]]
        else:
            hypotheses = [['<s>'] for _ in xrange(sample_size)]
            hypotheses_word_ids = [[self.tgt_vocab['<s>']]
                                   for _ in xrange(sample_size)]

        # previous attentional vector and accumulated score, one row/entry per
        # live hypothesis, both initialised to zero
        att_tm1 = Variable(new_float_tensor(len(hypotheses),
                                            self.hidden_size).zero_(),
                           volatile=True)
        hyp_scores = Variable(new_float_tensor(len(hypotheses)).zero_(),
                              volatile=True)

        while len(completed_hypotheses
                  ) < sample_size and t < decode_max_time_step:
            t += 1
            hyp_num = len(hypotheses)

            # repeat the single-sentence encodings once per live hypothesis
            expanded_src_encodings = src_encodings.expand(
                hyp_num, src_encodings.size(1), src_encodings.size(2))
            expanded_src_encodings_att_linear = src_encodings_att_linear.expand(
                hyp_num, src_encodings_att_linear.size(1),
                src_encodings_att_linear.size(2))

            # last word id of every live hypothesis is the decoder input
            y_tm1 = Variable(new_long_tensor(
                [hyp[-1] for hyp in hypotheses_word_ids]),
                             volatile=True)
            y_tm1_embed = self.tgt_embed(y_tm1)

            # input feeding: concatenate word embedding with previous attention
            x = torch.cat([y_tm1_embed, att_tm1], 1)

            (h_t,
             cell_t), att_t = self.step(x, h_tm1, expanded_src_encodings,
                                        expanded_src_encodings_att_linear)

            # (batch_size, 2)
            # column 0 weights generation, column 1 weights copying (see the
            # indexing below)
            tgt_token_predictor = F.softmax(self.tgt_token_predictor(att_t),
                                            dim=-1)

            # (batch_size, tgt_vocab_size)
            token_gen_prob = F.softmax(self.readout(att_t), dim=-1)

            # (batch_size, src_sent_len)
            token_copy_prob = self.src_pointer_net(
                src_encodings,
                src_token_mask=None,
                query_vec=att_t.unsqueeze(0)).squeeze(0)

            # (batch_size, tgt_vocab_size)
            token_gen_prob = tgt_token_predictor[:, 0].unsqueeze(
                1) * token_gen_prob

            # first, for every in-vocabulary source token (first occurrence
            # only), add its copy probability onto that word's vocab entry
            for token_pos, token_vocab_id in enumerate(
                    src_token_tgt_vocab_ids):
                if token_vocab_id != -1 and token_vocab_id != self.tgt_vocab.unk_id:
                    p_copy = tgt_token_predictor[:,
                                                 1] * token_copy_prob[:,
                                                                      token_pos]
                    token_gen_prob[:,
                                   token_vocab_id] = token_gen_prob[:,
                                                                    token_vocab_id] + p_copy

            # second, add the probability of copying the most probable unk word
            # gentoken_new_hyp_unks[hyp_id] remembers which source token the
            # <unk> slot stands for in that hypothesis
            gentoken_new_hyp_unks = []
            if src_unk_pos_list:
                for hyp_id in xrange(hyp_num):
                    unk_pos = token_copy_prob[hyp_id][
                        src_unk_pos_list].data.cpu().numpy().argmax()
                    unk_pos = src_unk_pos_list[unk_pos]
                    token = src_sent[unk_pos]
                    gentoken_new_hyp_unks.append(token)

                    # the <unk> entry is overwritten (not incremented) with
                    # the copy score of the chosen source position
                    unk_copy_score = tgt_token_predictor[
                        hyp_id, 1] * token_copy_prob[hyp_id, unk_pos]
                    token_gen_prob[hyp_id,
                                   self.tgt_vocab.unk_id] = unk_copy_score

            live_hyp_num = sample_size - len(completed_hypotheses)

            if mode == 'beam_search':
                # score every (hypothesis, word) pair, flatten, keep the best
                log_token_gen_prob = torch.log(token_gen_prob)
                new_hyp_scores = (
                    hyp_scores.unsqueeze(1).expand_as(token_gen_prob) +
                    log_token_gen_prob).view(-1)
                top_new_hyp_scores, top_new_hyp_pos = torch.topk(
                    new_hyp_scores, k=live_hyp_num)
                # recover (row, column) from the flattened index
                # NOTE(review): relies on `/` performing integer division on
                # long tensors - confirm for the PyTorch version in use
                prev_hyp_ids = (top_new_hyp_pos /
                                len(self.tgt_vocab)).cpu().data
                word_ids = (top_new_hyp_pos % len(self.tgt_vocab)).cpu().data
                top_new_hyp_scores = top_new_hyp_scores.cpu().data
            else:
                # draw one word per live hypothesis; hypothesis i is extended
                # by its own sample, so the previous-hypothesis ids are 0..n-1
                word_ids = torch.multinomial(token_gen_prob, num_samples=1)
                prev_hyp_ids = range(live_hyp_num)
                top_new_hyp_scores = hyp_scores + torch.log(
                    torch.gather(token_gen_prob, dim=1,
                                 index=word_ids)).squeeze(1)
                top_new_hyp_scores = top_new_hyp_scores.cpu().data
                word_ids = word_ids.view(-1).cpu().data

            new_hypotheses = []
            new_hypotheses_word_ids = []
            live_hyp_ids = []
            # the name is reused here as a plain list of surviving scores
            new_hyp_scores = []
            for prev_hyp_id, word_id, new_hyp_score in zip(
                    prev_hyp_ids, word_ids, top_new_hyp_scores):
                if word_id == eos_id:
                    hyp_tgt_words = hypotheses[prev_hyp_id][1:]
                    completed_hypotheses.append(
                        hyp_tgt_words
                    )  # remove <s> and </s> in completed hypothesis
                    completed_hypothesis_scores.append(new_hyp_score)
                else:
                    if word_id == self.tgt_vocab.unk_id:
                        # emit the copied source token when one was chosen
                        # for this hypothesis, else the literal <unk> word
                        if gentoken_new_hyp_unks:
                            word = gentoken_new_hyp_unks[prev_hyp_id]
                        else:
                            word = self.tgt_vocab.id2word[
                                self.tgt_vocab.unk_id]
                    else:
                        word = self.tgt_vocab.id2word[word_id]

                    hyp_tgt_words = hypotheses[prev_hyp_id] + [word]
                    new_hypotheses.append(hyp_tgt_words)
                    new_hypotheses_word_ids.append(
                        hypotheses_word_ids[prev_hyp_id] + [word_id])
                    live_hyp_ids.append(prev_hyp_id)
                    new_hyp_scores.append(new_hyp_score)

            if len(completed_hypotheses) == sample_size:
                break

            # carry over the decoder state of the surviving hypotheses only
            live_hyp_ids = new_long_tensor(live_hyp_ids)
            h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
            att_tm1 = att_t[live_hyp_ids]

            hyp_scores = Variable(
                new_float_tensor(new_hyp_scores),
                volatile=True)  # new_hyp_scores[live_hyp_ids]
            hypotheses = new_hypotheses
            hypotheses_word_ids = new_hypotheses_word_ids

        return completed_hypotheses
コード例 #17
0
ファイル: seq2seq.py プロジェクト: chubbymaggie/tranX
    def beam_search(self, src_sents, decode_max_time_step, beam_size=5, to_word=True):
        """
        Run beam search for a single (not batched) source sentence and return
        the n-best target sequences.

        :param src_sents: list holding exactly one source sentence; only
            ``src_sents[0]`` is used to determine the source length
        :param decode_max_time_step: maximum number of decoding steps
        :param beam_size: number of completed hypotheses to collect
        :param to_word: if True, map predicted word ids back to words through
            ``self.tgt_vocab.id2word``; otherwise return word ids
        :return: list of hypotheses (each a list of words or word ids, without
            ``<s>`` and ``</s>``) sorted by score, best first. If no hypothesis
            finished within ``decode_max_time_step`` steps, the first live
            hypothesis is returned as the only result.
        """
        src_sents_var = nn_utils.to_input_variable(src_sents, self.src_vocab,
                                                   cuda=self.cuda, training=False, append_boundary_sym=False)

        #TODO(junxian): check if src_sents_var(src_seq_length, embed_size) is ok
        src_encodings, (last_state, last_cell) = self.encode(src_sents_var, [len(src_sents[0])])
        # (1, query_len, hidden_size * 2)
        src_encodings = src_encodings.permute(1, 0, 2)
        src_encodings_att_linear = self.att_src_linear(src_encodings)
        h_tm1 = self.init_decoder_state(last_state, last_cell)

        # tensor constructors
        new_float_tensor = src_encodings.data.new
        if self.cuda:
            new_long_tensor = torch.cuda.LongTensor
        else:
            new_long_tensor = torch.LongTensor

        # previous attentional vector and accumulated score of the single
        # initial hypothesis, both zero
        att_tm1 = Variable(torch.zeros(1, self.hidden_size), volatile=True)
        hyp_scores = Variable(torch.zeros(1), volatile=True)
        if self.cuda:
            att_tm1 = att_tm1.cuda()
            hyp_scores = hyp_scores.cuda()

        eos_id = self.tgt_vocab['</s>']
        bos_id = self.tgt_vocab['<s>']
        tgt_vocab_size = len(self.tgt_vocab)

        hypotheses = [[bos_id]]
        completed_hypotheses = []
        completed_hypothesis_scores = []

        t = 0
        while len(completed_hypotheses) < beam_size and t < decode_max_time_step:
            t += 1
            hyp_num = len(hypotheses)

            # repeat the single-sentence encodings once per live hypothesis
            expanded_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1), src_encodings.size(2))
            expanded_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num, src_encodings_att_linear.size(1), src_encodings_att_linear.size(2))

            # the last word id of every live hypothesis is the decoder input
            y_tm1 = Variable(new_long_tensor([hyp[-1] for hyp in hypotheses]), volatile=True)
            y_tm1_embed = self.tgt_embed(y_tm1)

            # input feeding: concatenate word embedding with previous attention
            x = torch.cat([y_tm1_embed, att_tm1], 1)

            # h_t: (hyp_num, hidden_size)
            (h_t, cell_t), att_t, score_t = self.step(x, h_tm1,
                                                      expanded_src_encodings, expanded_src_encodings_att_linear,
                                                      src_sent_masks=None)

            # normalise over the vocabulary axis explicitly instead of relying
            # on the deprecated implicit-dim behaviour of log_softmax
            p_t = F.log_softmax(score_t, dim=-1)

            # score every (hypothesis, word) continuation, flatten, and keep
            # only as many as there are free slots left in the beam
            live_hyp_num = beam_size - len(completed_hypotheses)
            new_hyp_scores = (hyp_scores.unsqueeze(1).expand_as(p_t) + p_t).view(-1)
            top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores, k=live_hyp_num)
            # recover (hypothesis row, word column) from the flattened index
            prev_hyp_ids = top_new_hyp_pos / tgt_vocab_size
            word_ids = top_new_hyp_pos % tgt_vocab_size

            new_hypotheses = []

            live_hyp_ids = []
            # the name is reused here as a plain list of surviving scores
            new_hyp_scores = []
            for prev_hyp_id, word_id, new_hyp_score in zip(prev_hyp_ids.cpu().data, word_ids.cpu().data, top_new_hyp_scores.cpu().data):
                hyp_tgt_words = hypotheses[prev_hyp_id] + [word_id]
                if word_id == eos_id:
                    completed_hypotheses.append(hyp_tgt_words[1:-1])  # remove <s> and </s> in completed hypothesis
                    completed_hypothesis_scores.append(new_hyp_score)
                else:
                    new_hypotheses.append(hyp_tgt_words)
                    live_hyp_ids.append(prev_hyp_id)
                    new_hyp_scores.append(new_hyp_score)

            if len(completed_hypotheses) == beam_size:
                break

            # carry over the decoder state of the surviving hypotheses only
            live_hyp_ids = new_long_tensor(live_hyp_ids)
            h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
            att_tm1 = att_t[live_hyp_ids]

            hyp_scores = Variable(new_float_tensor(new_hyp_scores), volatile=True)  # new_hyp_scores[live_hyp_ids]
            hypotheses = new_hypotheses

        if len(completed_hypotheses) == 0:
            # Nothing reached </s>: fall back to the first live hypothesis.
            # Live hypotheses never end with </s> (those are moved to
            # `completed_hypotheses` above), so only the leading <s> is
            # stripped; slicing off the last element too would drop a real word.
            completed_hypotheses = [hypotheses[0][1:]]
            completed_hypothesis_scores = [0.0]

        if to_word:
            for i, hyp in enumerate(completed_hypotheses):
                completed_hypotheses[i] = [self.tgt_vocab.id2word[w] for w in hyp]

        ranked_hypotheses = sorted(zip(completed_hypotheses, completed_hypothesis_scores), key=lambda x: x[1], reverse=True)

        return [hyp for hyp, score in ranked_hypotheses]
コード例 #18
0
ファイル: dataset.py プロジェクト: liuhuigmail/CoDas4CG
 def src_sents_var(self):
     """Return ``self.src_sents`` converted by ``nn_utils.to_input_variable``
     using the source vocabulary and this object's ``cuda`` setting."""
     sentences, source_vocab = self.src_sents, self.vocab.source
     return nn_utils.to_input_variable(sentences, source_vocab, cuda=self.cuda)
コード例 #19
0
def train_lstm_lm(args):
    """Train an ``LSTMPrior`` language model over code tokens.

    Data is loaded from ``args.code_dir`` and shuffled; the last 1000
    examples are held out as the dev set. After every epoch the dev
    perplexity is computed. On improvement, the model (``args.save_to +
    '.bin'``) and the optimizer state (``args.save_to + '.optim.bin'``) are
    saved. After ``args.patience`` epochs without improvement, the best
    checkpoint is reloaded and the learning rate is multiplied by
    ``args.lr_decay``. After ``args.max_num_trial`` such trials the process
    exits with status 0.

    :param args: namespace providing ``code_dir``, ``vocab_size``,
        ``freq_cutoff``, ``cuda``, ``batch_size``, ``clip_grad``,
        ``log_every``, ``save_to``, ``patience``, ``max_num_trial``,
        ``lr_decay`` and ``reset_optimizer``
    :return: never returns normally; the loop ends through ``sys.exit(0)``
    """
    all_data = load_code_dir(args.code_dir)
    np.random.shuffle(all_data)
    train_data = all_data[:-1000]
    dev_data = all_data[-1000:]
    print('train data size: %d, dev data size: %d' %
          (len(train_data), len(dev_data)),
          file=sys.stderr)

    vocab = VocabEntry.from_corpus([e['tokens'] for e in train_data],
                                   size=args.vocab_size,
                                   freq_cutoff=args.freq_cutoff)
    print('vocab size: %d' % len(vocab), file=sys.stderr)

    model = LSTMPrior(args, vocab)
    model.train()
    if args.cuda:
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters())

    def evaluate_ppl():
        """Return the per-token perplexity of ``model`` on the dev set."""
        model.eval()
        cum_loss = 0.
        cum_tgt_words = 0.
        for examples in nn_utils.batch_iter(dev_data, args.batch_size):
            batch_tokens = [e['tokens'] for e in examples]
            batch = nn_utils.to_input_variable(batch_tokens,
                                               vocab,
                                               cuda=args.cuda,
                                               append_boundary_sym=True)
            loss = model.forward(batch).sum()
            cum_loss += loss.data[0]
            cum_tgt_words += sum(len(tokens) + 1
                                 for tokens in batch_tokens)  # add ending </s>

        ppl = np.exp(cum_loss / cum_tgt_words)
        # switch back to training mode for the caller
        model.train()
        return ppl

    print('begin training decoder, %d training examples, %d dev examples' %
          (len(train_data), len(dev_data)),
          file=sys.stderr)

    epoch = num_trial = train_iter = patience = 0
    report_loss = report_examples = 0.
    history_dev_scores = []
    while True:
        epoch += 1
        epoch_begin = time.time()

        for examples in nn_utils.batch_iter(train_data,
                                            batch_size=args.batch_size,
                                            shuffle=True):
            train_iter += 1
            optimizer.zero_grad()

            batch_tokens = [e['tokens'] for e in examples]
            batch = nn_utils.to_input_variable(batch_tokens,
                                               vocab,
                                               cuda=args.cuda,
                                               append_boundary_sym=True)
            loss = model.forward(batch)
            # the summed loss feeds the running report, the mean is optimised
            loss_val = torch.sum(loss).data[0]
            report_loss += loss_val
            report_examples += len(examples)
            loss = torch.mean(loss)

            loss.backward()

            # clip gradient
            torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad)

            optimizer.step()

            if train_iter % args.log_every == 0:
                print('[Iter %d] encoder loss=%.5f' %
                      (train_iter, report_loss / report_examples),
                      file=sys.stderr)

                report_loss = report_examples = 0.

        print('[Epoch %d] epoch elapsed %ds' %
              (epoch, time.time() - epoch_begin),
              file=sys.stderr)

        # perform validation
        print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
        eval_start = time.time()
        # evaluate ppl
        ppl = evaluate_ppl()
        print('[Epoch %d] ppl=%.5f took %ds' %
              (epoch, ppl, time.time() - eval_start),
              file=sys.stderr)
        # lower perplexity is better, so negate it to get a "higher is
        # better" score
        dev_acc = -ppl
        is_better = not history_dev_scores or dev_acc > max(
            history_dev_scores)
        history_dev_scores.append(dev_acc)

        if is_better:
            patience = 0
            model_file = args.save_to + '.bin'
            print('save currently the best model ..', file=sys.stderr)
            print('save model to [%s]' % model_file, file=sys.stderr)
            model.save(model_file)
            # also save the optimizers' state
            torch.save(optimizer.state_dict(), args.save_to + '.optim.bin')
        elif patience < args.patience:
            patience += 1
            print('hit patience %d' % patience, file=sys.stderr)

        if patience == args.patience:
            num_trial += 1
            print('hit #%d trial' % num_trial, file=sys.stderr)
            if num_trial == args.max_num_trial:
                print('early stop!', file=sys.stderr)
                # sys.exit instead of the site-provided `exit` helper, which
                # is not guaranteed to exist (e.g. when run with `python -S`)
                sys.exit(0)

            # decay lr, and restore from previously best checkpoint
            lr = optimizer.param_groups[0]['lr'] * args.lr_decay
            print('load previously best model and decay learning rate to %f' %
                  lr,
                  file=sys.stderr)

            # load model
            params = torch.load(args.save_to + '.bin',
                                map_location=lambda storage, loc: storage)
            model.load_state_dict(params['state_dict'])
            if args.cuda:
                model = model.cuda()

            # load optimizers
            if args.reset_optimizer:
                print('reset optimizer', file=sys.stderr)
                # rebuild over the same parameter set the optimizer was
                # originally created with (model.parameters())
                optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            else:
                print('restore parameters of the optimizers', file=sys.stderr)
                optimizer.load_state_dict(
                    torch.load(args.save_to + '.optim.bin'))

            # set new lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # reset patience
            patience = 0
コード例 #20
0
ファイル: dataset.py プロジェクト: chubbymaggie/tranX
 def src_sents_var(self):
     """Return ``self.src_sents`` converted by ``nn_utils.to_input_variable``
     using the source vocabulary and this object's ``cuda`` setting."""
     sentences = self.src_sents
     source_vocab = self.vocab.source
     return nn_utils.to_input_variable(sentences, source_vocab, cuda=self.cuda)
コード例 #21
0
ファイル: parser.py プロジェクト: chubbymaggie/tranX
    def parse(self, question, context, beam_size=5):
        table = context
        args = self.args
        src_sent_var = nn_utils.to_input_variable([question], self.vocab.source,
                                                  cuda=self.args.cuda, training=False)

        utterance_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(question)])
        dec_init_vec = self.init_decoder_state(last_state, last_cell)

        column_word_encodings, table_header_encoding, table_header_mask = self.encode_table_header([table])

        h_tm1 = dec_init_vec
        # (batch_size, query_len, hidden_size)
        utterance_encodings_att_linear = self.att_src_linear(utterance_encodings)

        zero_action_embed = Variable(self.new_tensor(self.args.action_embed_size).zero_())

        t = 0
        hypotheses = [DecodeHypothesis()]
        hyp_states = [[]]
        completed_hypotheses = []

        while len(completed_hypotheses) < beam_size and t < self.args.decode_max_time_step:
            hyp_num = len(hypotheses)

            # (hyp_num, src_sent_len, hidden_size * 2)
            exp_src_encodings = utterance_encodings.expand(hyp_num, utterance_encodings.size(1), utterance_encodings.size(2))
            # (hyp_num, src_sent_len, hidden_size)
            exp_src_encodings_att_linear = utterance_encodings_att_linear.expand(hyp_num,
                                                                                 utterance_encodings_att_linear.size(1),
                                                                                 utterance_encodings_att_linear.size(2))

            # x: [prev_action, parent_production_embed, parent_field_embed, parent_field_type_embed, parent_action_state]
            if t == 0:
                x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)

                if args.no_parent_field_type_embed is False:
                    offset = args.action_embed_size  # prev_action
                    offset += args.hidden_size * (not args.no_input_feed)
                    offset += args.action_embed_size * (not args.no_parent_production_embed)
                    offset += args.field_embed_size * (not args.no_parent_field_embed)

                    x[0, offset: offset + args.type_embed_size] = \
                        self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
            else:
                a_tm1_embeds = []
                for e_id, hyp in enumerate(hypotheses):
                    action_tm1 = hyp.actions[-1]
                    if action_tm1:
                        if isinstance(action_tm1, ApplyRuleAction):
                            a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
                        elif isinstance(action_tm1, ReduceAction):
                            a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                        elif isinstance(action_tm1, WikiSqlSelectColumnAction):
                            a_tm1_embed = self.column_rnn_input(table_header_encoding[0, action_tm1.column_id])
                        elif isinstance(action_tm1, GenTokenAction):
                            a_tm1_embed = self.src_embed.weight[self.vocab.source[action_tm1.token]]
                        else:
                            raise ValueError('unknown action %s' % action_tm1)
                    else:
                        a_tm1_embed = zero_action_embed

                    a_tm1_embeds.append(a_tm1_embed)

                a_tm1_embeds = torch.stack(a_tm1_embeds)

                inputs = [a_tm1_embeds]
                if args.no_input_feed is False:
                    inputs.append(att_tm1)
                if args.no_parent_production_embed is False:
                    # frontier production
                    frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                    frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                        [self.grammar.prod2id[prod] for prod in frontier_prods])))
                    inputs.append(frontier_prod_embeds)
                if args.no_parent_field_embed is False:
                    # frontier field
                    frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                    frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([
                        self.grammar.field2id[field] for field in frontier_fields])))

                    inputs.append(frontier_field_embeds)
                if args.no_parent_field_type_embed is False:
                    # frontier field type
                    frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                    frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([
                        self.grammar.type2id[type] for type in frontier_field_types])))
                    inputs.append(frontier_field_type_embeds)

                # parent states
                if args.no_parent_state is False:
                    p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                    parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)])
                    parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)])

                    if args.lstm == 'parent_feed':
                        h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells)
                    else:
                        inputs.append(parent_states)

                x = torch.cat(inputs, dim=-1)

            (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
                                             exp_src_encodings_att_linear,
                                             src_token_mask=None)

            # ApplyRule action probability
            # (batch_size, grammar_size)
            apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

            # column attention
            # (batch_size, max_head_num)
            column_attention_weights = self.column_pointer_net(table_header_encoding, table_header_mask,
                                                               att_t.unsqueeze(0)).squeeze(0)
            column_selection_log_prob = torch.log(column_attention_weights)

            # (batch_size, 2)
            primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

            # primitive copy prob
            # (batch_size, src_token_num)
            primitive_copy_prob = self.src_pointer_net(utterance_encodings, None,
                                                       att_t.unsqueeze(0)).squeeze(0)

            # (batch_size, primitive_vocab_size)
            primitive_gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

            new_hyp_meta = []

            for hyp_id, hyp in enumerate(hypotheses):
                # generate new continuations
                action_types = self.transition_system.get_valid_continuation_types(hyp)

                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_log_prob[hyp_id, prod_id]
                            new_hyp_score = hyp.score + prod_score

                            meta_entry = {'action_type': 'apply_rule', 'prod_id': prod_id,
                                          'score': prod_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    elif action_type == ReduceAction:
                        action_score = apply_rule_log_prob[hyp_id, len(self.grammar)]
                        new_hyp_score = hyp.score + action_score

                        meta_entry = {'action_type': 'apply_rule', 'prod_id': len(self.grammar),
                                      'score': action_score, 'new_hyp_score': new_hyp_score,
                                      'prev_hyp_id': hyp_id}
                        new_hyp_meta.append(meta_entry)
                    elif action_type == WikiSqlSelectColumnAction:
                        for col_id, column in enumerate(table.header):
                            col_sel_score = column_selection_log_prob[hyp_id, col_id]
                            new_hyp_score = hyp.score + col_sel_score

                            meta_entry = {'action_type': 'sel_col', 'col_id': col_id,
                                          'score': col_sel_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    elif action_type == GenTokenAction:
                        # remember that we can only copy stuff from the input!
                        # we only copy tokens sequentially!!
                        prev_action = hyp.action_infos[-1].action

                        valid_token_pos_list = []
                        if type(prev_action) is GenTokenAction and \
                                not prev_action.is_stop_signal():
                            token_pos = hyp.action_infos[-1].src_token_position + 1
                            if token_pos < len(question):
                                valid_token_pos_list = [token_pos]
                        else:
                            valid_token_pos_list = list(range(len(question)))

                        col_id = hyp.frontier_node['col_idx'].value
                        if table.header[col_id].type == 'real':
                            valid_token_pos_list = [i for i in valid_token_pos_list
                                                    if any(c.isdigit() for c in question[i]) or
                                                    hyp._value_buffer and question[i] in (',', '.', '-', '%')]

                        p_copies = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id]
                        for token_pos in valid_token_pos_list:
                            token = question[token_pos]
                            p_copy = p_copies[token_pos]
                            score_copy = torch.log(p_copy)

                            meta_entry = {'action_type': 'gen_token',
                                          'token': token, 'token_pos': token_pos,
                                          'score': score_copy, 'new_hyp_score': score_copy + hyp.score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)

                        # add generation probability for </primitive>
                        if hyp._value_buffer:
                            eos_prob = primitive_predictor_prob[hyp_id, 0] * \
                                       primitive_gen_from_vocab_prob[hyp_id, self.vocab.primitive['</primitive>']]
                            eos_score = torch.log(eos_prob)

                            meta_entry = {'action_type': 'gen_token',
                                          'token': '</primitive>',
                                          'score': eos_score, 'new_hyp_score': eos_score + hyp.score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)

            if not new_hyp_meta: break

            new_hyp_scores = torch.cat([x['new_hyp_score'] for x in new_hyp_meta])
            top_new_hyp_scores, meta_ids = torch.topk(new_hyp_scores,
                                                      k=min(new_hyp_scores.size(0),
                                                            beam_size - len(completed_hypotheses)))

            live_hyp_ids = []
            new_hypotheses = []
            for new_hyp_score, meta_id in zip(top_new_hyp_scores.data.cpu(), meta_ids.data.cpu()):
                action_info = ActionInfo()
                hyp_meta_entry = new_hyp_meta[meta_id]
                prev_hyp_id = hyp_meta_entry['prev_hyp_id']
                prev_hyp = hypotheses[prev_hyp_id]

                action_type_str = hyp_meta_entry['action_type']
                if action_type_str == 'apply_rule':
                    # ApplyRule action
                    prod_id = hyp_meta_entry['prod_id']
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    # Reduce action
                    else:
                        action = ReduceAction()
                elif action_type_str == 'sel_col':
                    action = WikiSqlSelectColumnAction(hyp_meta_entry['col_id'])
                else:
                    action = GenTokenAction(hyp_meta_entry['token'])
                    if 'token_pos' in hyp_meta_entry:
                        action_info.copy_from_src = True
                        action_info.src_token_position = hyp_meta_entry['token_pos']

                action_info.action = action
                action_info.t = t

                if t > 0:
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score

                if new_hyp.completed:
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids]
                h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                att_tm1 = att_t[live_hyp_ids]
                hypotheses = new_hypotheses
                t += 1
            else: break

        completed_hypotheses.sort(key=lambda hyp: -hyp.score)

        return completed_hypotheses
コード例 #22
0
    def parse(self, src_sent, context=None, beam_size=5, debug=False):
        """Perform beam search to infer the target AST given a source utterance.

        Args:
            src_sent: list of source utterance tokens
            context: other context used for prediction (not referenced
                anywhere in this method's body)
            beam_size: beam size; the search also stops once this many
                hypotheses have been completed
            debug: when true, extra diagnostic fields (`gen_copy_switch`,
                `in_vocab`, `gen_token_prob`, `copy_token_prob`,
                `action_prob`) are attached to each `ActionInfo`

        Returns:
            A list of `DecodeHypothesis`, each representing an AST, sorted by
            descending score. The score of a completed hypothesis is divided
            by `t + 1`, where `t` is the step at which it was completed.
        """

        args = self.args
        primitive_vocab = self.vocab.primitive
        # namespace used to build index tensors on the right device
        T = torch.cuda if args.cuda else torch

        # the single utterance is wrapped in a list, i.e. a batch of size one
        src_sent_var = nn_utils.to_input_variable([src_sent],
                                                  self.vocab.source,
                                                  cuda=args.cuda,
                                                  training=False)

        # Variable(1, src_sent_len, hidden_size * 2)
        # NOTE: a disabled earlier variant unpacked `(last_state, last_cell)`
        # from `self.encode(...)`; here the second return value is kept as a
        # single object and handed to `init_decoder_state` unchanged.
        src_encodings, last_state = self.encode(src_sent_var, [len(src_sent)])
        # (1, src_sent_len, hidden_size)
        src_encodings_att_linear = self.att_src_linear(src_encodings)

        # (the disabled earlier variant called
        #  `self.init_decoder_state(last_state, last_cell)`)
        dec_init_vec = self.init_decoder_state(last_state)
        if args.lstm == 'parent_feed':
            # the parent-feeding decoder state has two extra slots, zero at
            # the first step; later steps fill them with the parent node's
            # hidden state and cell (see `parent_states` / `parent_cells`)
            h_tm1 = dec_init_vec[0], dec_init_vec[1], \
                    Variable(self.new_tensor(args.hidden_size).zero_()), \
                    Variable(self.new_tensor(args.hidden_size).zero_())
        else:
            h_tm1 = dec_init_vec

        # embedding used when a hypothesis' previous action is falsy
        zero_action_embed = Variable(
            self.new_tensor(args.action_embed_size).zero_())

        # running score of each live hypothesis, as a tensor; it is added to
        # the GenToken log-probabilities further below
        with torch.no_grad():
            hyp_scores = Variable(self.new_tensor([0.]))

        # For computing copy probabilities, we marginalize over tokens with the same surface form
        # `aggregated_primitive_tokens` stores the position of occurrence of each source token
        aggregated_primitive_tokens = OrderedDict()
        for token_pos, token in enumerate(src_sent):
            aggregated_primitive_tokens.setdefault(token, []).append(token_pos)

        t = 0
        hypotheses = [DecodeHypothesis()]
        # hyp_states[i] is the list of (h, cell) pairs of hypothesis i, one
        # entry per decoding step taken so far
        hyp_states = [[]]
        completed_hypotheses = []

        # decode until enough hypotheses are completed or the step limit is hit
        while len(completed_hypotheses
                  ) < beam_size and t < args.decode_max_time_step:
            hyp_num = len(hypotheses)

            # (hyp_num, src_sent_len, hidden_size * 2)
            exp_src_encodings = src_encodings.expand(hyp_num,
                                                     src_encodings.size(1),
                                                     src_encodings.size(2))
            # (hyp_num, src_sent_len, hidden_size)
            exp_src_encodings_att_linear = src_encodings_att_linear.expand(
                hyp_num, src_encodings_att_linear.size(1),
                src_encodings_att_linear.size(2))

            if t == 0:
                # first step: the decoder input is all zeros, except
                # (optionally) the slice holding the root type embedding
                with torch.no_grad():
                    x = Variable(
                        self.new_tensor(1,
                                        self.decoder_lstm.input_size).zero_())
                if args.no_parent_field_type_embed is False:
                    # skip the segments that precede the type embedding in
                    # the input layout used at later steps: previous action,
                    # input feed, parent production, parent field (each of
                    # the last three only when it is enabled)
                    offset = args.action_embed_size  # prev_action
                    offset += args.att_vec_size * (not args.no_input_feed)
                    offset += args.action_embed_size * (
                        not args.no_parent_production_embed)
                    offset += args.field_embed_size * (
                        not args.no_parent_field_embed)

                    x[0, offset: offset + args.type_embed_size] = \
                        self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
            else:
                # later steps: assemble the decoder input from the previous
                # action and the current frontier of every live hypothesis
                actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

                a_tm1_embeds = []
                for a_tm1 in actions_tm1:
                    if a_tm1:
                        if isinstance(a_tm1, ApplyRuleAction):
                            a_tm1_embed = self.production_embed.weight[
                                self.grammar.prod2id[a_tm1.production]]
                        elif isinstance(a_tm1, ReduceAction):
                            # Reduce uses the index right after the last
                            # production id
                            a_tm1_embed = self.production_embed.weight[len(
                                self.grammar)]
                        else:
                            # any other action is looked up by its `token`
                            a_tm1_embed = self.primitive_embed.weight[
                                self.vocab.primitive[a_tm1.token]]

                        a_tm1_embeds.append(a_tm1_embed)
                    else:
                        a_tm1_embeds.append(zero_action_embed)
                a_tm1_embeds = torch.stack(a_tm1_embeds)

                inputs = [a_tm1_embeds]
                if args.no_input_feed is False:
                    # attention vectors of the previous step (input feeding)
                    inputs.append(att_tm1)
                if args.no_parent_production_embed is False:
                    # frontier production
                    frontier_prods = [
                        hyp.frontier_node.production for hyp in hypotheses
                    ]
                    frontier_prod_embeds = self.production_embed(
                        Variable(
                            self.new_long_tensor([
                                self.grammar.prod2id[prod]
                                for prod in frontier_prods
                            ])))
                    inputs.append(frontier_prod_embeds)
                if args.no_parent_field_embed is False:
                    # frontier field
                    frontier_fields = [
                        hyp.frontier_field.field for hyp in hypotheses
                    ]
                    frontier_field_embeds = self.field_embed(
                        Variable(
                            self.new_long_tensor([
                                self.grammar.field2id[field]
                                for field in frontier_fields
                            ])))

                    inputs.append(frontier_field_embeds)
                if args.no_parent_field_type_embed is False:
                    # frontier field type
                    frontier_field_types = [
                        hyp.frontier_field.type for hyp in hypotheses
                    ]
                    frontier_field_type_embeds = self.type_embed(
                        Variable(
                            self.new_long_tensor([
                                self.grammar.type2id[type]
                                for type in frontier_field_types
                            ])))
                    inputs.append(frontier_field_type_embeds)

                # parent states
                if args.no_parent_state is False:
                    # decoding step at which each frontier node was created;
                    # used to index that hypothesis' state history
                    p_ts = [
                        hyp.frontier_node.created_time for hyp in hypotheses
                    ]
                    parent_states = torch.stack([
                        hyp_states[hyp_id][p_t][0]
                        for hyp_id, p_t in enumerate(p_ts)
                    ])
                    parent_cells = torch.stack([
                        hyp_states[hyp_id][p_t][1]
                        for hyp_id, p_t in enumerate(p_ts)
                    ])

                    if args.lstm == 'parent_feed':
                        # parent state/cell travel inside the decoder state
                        h_tm1 = (h_tm1[0], h_tm1[1], parent_states,
                                 parent_cells)
                    else:
                        # otherwise the parent hidden state is simply one
                        # more segment of the input vector
                        inputs.append(parent_states)

                x = torch.cat(inputs, dim=-1)

            # one decoder step for all live hypotheses at once
            (h_t, cell_t), att_t = self.step(x,
                                             h_tm1,
                                             exp_src_encodings,
                                             exp_src_encodings_att_linear,
                                             src_token_mask=None)

            # Variable(batch_size, grammar_size)
            # apply_rule_log_prob = torch.log(F.softmax(self.production_readout(att_t), dim=-1))
            apply_rule_log_prob = F.log_softmax(self.production_readout(att_t),
                                                dim=-1)

            # Variable(batch_size, primitive_vocab_size)
            gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t),
                                            dim=-1)

            if args.no_copy:
                primitive_prob = gen_from_vocab_prob
            else:
                # Variable(batch_size, src_sent_len)
                primitive_copy_prob = self.src_pointer_net(
                    src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

                # Variable(batch_size, 2)
                # column 0 weights generation from the vocabulary, column 1
                # weights copying from the source (see their uses below)
                primitive_predictor_prob = F.softmax(
                    self.primitive_predictor(att_t), dim=-1)

                # Variable(batch_size, primitive_vocab_size)
                primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(
                    1) * gen_from_vocab_prob

                # if src_unk_pos_list:
                #     primitive_prob[:, primitive_vocab.unk_id] = 1.e-10

            # Candidate bookkeeping. ApplyRule/Reduce candidates are kept in
            # three parallel lists; GenToken candidates are represented by
            # the ids of the hypotheses that may emit a token, and expanded
            # over the whole primitive vocabulary afterwards.
            gentoken_prev_hyp_ids = []
            gentoken_new_hyp_unks = []
            applyrule_new_hyp_scores = []
            applyrule_new_hyp_prod_ids = []
            applyrule_prev_hyp_ids = []

            for hyp_id, hyp in enumerate(hypotheses):
                # generate new continuations
                action_types = self.transition_system.get_valid_continuation_types(
                    hyp)

                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(
                            hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_log_prob[
                                hyp_id, prod_id].data.item()
                            new_hyp_score = hyp.score + prod_score

                            applyrule_new_hyp_scores.append(new_hyp_score)
                            applyrule_new_hyp_prod_ids.append(prod_id)
                            applyrule_prev_hyp_ids.append(hyp_id)
                    elif action_type == ReduceAction:
                        # Reduce is scored by the extra readout slot at index
                        # len(self.grammar) and recorded with that pseudo
                        # production id
                        action_score = apply_rule_log_prob[
                            hyp_id, len(self.grammar)].data.item()
                        new_hyp_score = hyp.score + action_score

                        applyrule_new_hyp_scores.append(new_hyp_score)
                        applyrule_new_hyp_prod_ids.append(len(self.grammar))
                        applyrule_prev_hyp_ids.append(hyp_id)
                    else:
                        # GenToken action
                        gentoken_prev_hyp_ids.append(hyp_id)
                        # NOTE(review): `hyp_copy_info` is filled below but
                        # never read again in this method
                        hyp_copy_info = dict()  # of (token_pos, copy_prob)
                        hyp_unk_copy_info = []

                        if args.no_copy is False:
                            for token, token_pos_list in aggregated_primitive_tokens.items(
                            ):
                                # total copy probability of this surface
                                # form, summed over all its source positions
                                sum_copy_prob = torch.gather(
                                    primitive_copy_prob[hyp_id], 0,
                                    Variable(
                                        T.LongTensor(token_pos_list))).sum()
                                gated_copy_prob = primitive_predictor_prob[
                                    hyp_id, 1] * sum_copy_prob

                                if token in primitive_vocab:
                                    # in-vocabulary token: add the copy mass
                                    # to its generation probability (this
                                    # writes into `primitive_prob` in place)
                                    token_id = primitive_vocab[token]
                                    primitive_prob[
                                        hyp_id, token_id] = primitive_prob[
                                            hyp_id, token_id] + gated_copy_prob

                                    hyp_copy_info[token] = (
                                        token_pos_list,
                                        gated_copy_prob.data.item())
                                else:
                                    # out-of-vocabulary token: remember it as
                                    # a candidate for the unk slot
                                    hyp_unk_copy_info.append({
                                        'token':
                                        token,
                                        'token_pos_list':
                                        token_pos_list,
                                        'copy_prob':
                                        gated_copy_prob.data.item()
                                    })

                        if args.no_copy is False and len(
                                hyp_unk_copy_info) > 0:
                            # the out-of-vocabulary source token with the
                            # highest copy probability takes over the unk
                            # slot: its probability overwrites (is not added
                            # to) the unk entry, and its surface form is
                            # recorded for this hypothesis
                            unk_i = np.array([
                                x['copy_prob'] for x in hyp_unk_copy_info
                            ]).argmax()
                            token = hyp_unk_copy_info[unk_i]['token']
                            primitive_prob[
                                hyp_id, primitive_vocab.
                                unk_id] = hyp_unk_copy_info[unk_i]['copy_prob']
                            gentoken_new_hyp_unks.append(token)

                            hyp_copy_info[token] = (
                                hyp_unk_copy_info[unk_i]['token_pos_list'],
                                hyp_unk_copy_info[unk_i]['copy_prob'])

            # Flat candidate score vector. Layout: first all ApplyRule/Reduce
            # candidates, then one block of `primitive_prob.size(1)` entries
            # per GenToken hypothesis. The decoding loop below relies on
            # exactly this layout.
            new_hyp_scores = None
            if applyrule_new_hyp_scores:
                new_hyp_scores = Variable(
                    self.new_tensor(applyrule_new_hyp_scores))
            if gentoken_prev_hyp_ids:
                primitive_log_prob = torch.log(primitive_prob)
                gen_token_new_hyp_scores = (
                    hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) +
                    primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

                if new_hyp_scores is None:
                    new_hyp_scores = gen_token_new_hyp_scores
                else:
                    new_hyp_scores = torch.cat(
                        [new_hyp_scores, gen_token_new_hyp_scores])
            # NOTE(review): if no candidate of either kind was produced,
            # `new_hyp_scores` is still None here and `.size(0)` would raise;
            # confirm whether the transition system rules this case out.
            # Only as many candidates as there are free beam slots are kept.
            top_new_hyp_scores, top_new_hyp_pos = torch.topk(
                new_hyp_scores,
                k=min(new_hyp_scores.size(0),
                      beam_size - len(completed_hypotheses)))

            live_hyp_ids = []
            new_hypotheses = []
            for new_hyp_score, new_hyp_pos in zip(
                    top_new_hyp_scores.data.cpu(), top_new_hyp_pos.data.cpu()):
                action_info = ActionInfo()
                if new_hyp_pos < len(applyrule_new_hyp_scores):
                    # it's an ApplyRule or Reduce action
                    prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                    prev_hyp = hypotheses[prev_hyp_id]

                    prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                    # ApplyRule action
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    # Reduce action
                    else:
                        action = ReduceAction()
                else:
                    # it's a GenToken action
                    # position inside the GenToken part of the flat vector:
                    # the remainder is the vocabulary id, the quotient is the
                    # index into `gentoken_prev_hyp_ids`
                    token_id = (new_hyp_pos - len(applyrule_new_hyp_scores)
                                ) % primitive_prob.size(1)

                    k = (new_hyp_pos - len(applyrule_new_hyp_scores)
                         ) // primitive_prob.size(1)
                    # (disabled debugging dump, kept as found)
                    # try:
                    # copy_info = gentoken_copy_infos[k]
                    prev_hyp_id = gentoken_prev_hyp_ids[k]
                    prev_hyp = hypotheses[prev_hyp_id]
                    # except:
                    #     print('k=%d' % k, file=sys.stderr)
                    #     print('primitive_prob.size(1)=%d' % primitive_prob.size(1), file=sys.stderr)
                    #     print('len copy_info=%d' % len(gentoken_copy_infos), file=sys.stderr)
                    #     print('prev_hyp_id=%s' % ', '.join(str(i) for i in gentoken_prev_hyp_ids), file=sys.stderr)
                    #     print('len applyrule_new_hyp_scores=%d' % len(applyrule_new_hyp_scores), file=sys.stderr)
                    #     print('len gentoken_prev_hyp_ids=%d' % len(gentoken_prev_hyp_ids), file=sys.stderr)
                    #     print('top_new_hyp_pos=%s' % top_new_hyp_pos, file=sys.stderr)
                    #     print('applyrule_new_hyp_scores=%s' % applyrule_new_hyp_scores, file=sys.stderr)
                    #     print('new_hyp_scores=%s' % new_hyp_scores, file=sys.stderr)
                    #     print('top_new_hyp_scores=%s' % top_new_hyp_scores, file=sys.stderr)
                    #
                    #     torch.save((applyrule_new_hyp_scores, primitive_prob), 'data.bin')
                    #
                    #     # exit(-1)
                    #     raise ValueError()

                    if token_id == primitive_vocab.unk_id:
                        if gentoken_new_hyp_unks:
                            # use the out-of-vocabulary source token that was
                            # recorded for this GenToken hypothesis
                            token = gentoken_new_hyp_unks[k]
                        else:
                            # nothing was recorded: fall back to the
                            # vocabulary's own word for the unk id
                            token = primitive_vocab.id2word(
                                primitive_vocab.unk_id)
                    else:
                        token = primitive_vocab.id2word(token_id.item())

                    action = GenTokenAction(token)

                    if token in aggregated_primitive_tokens:
                        # the token occurs in the source: mark it as copied
                        # and store the list of all its source positions
                        action_info.copy_from_src = True
                        action_info.src_token_position = aggregated_primitive_tokens[
                            token]

                    if debug:
                        # diagnostic log-probabilities; 'n/a' where the
                        # quantity does not apply
                        action_info.gen_copy_switch = 'n/a' if args.no_copy else primitive_predictor_prob[
                            prev_hyp_id, :].log().cpu().data.numpy()
                        action_info.in_vocab = token in primitive_vocab
                        action_info.gen_token_prob = gen_from_vocab_prob[prev_hyp_id, token_id].log().cpu().data.item() \
                            if token in primitive_vocab else 'n/a'
                        action_info.copy_token_prob = torch.gather(primitive_copy_prob[prev_hyp_id],
                                                                   0,
                                                                   Variable(T.LongTensor(action_info.src_token_position))).sum().log().cpu().data.item() \
                            if args.no_copy is False and action_info.copy_from_src else 'n/a'

                action_info.action = action
                action_info.t = t
                if t > 0:
                    # frontier information is only recorded after the first
                    # step
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                if debug:
                    # score contributed by this single action
                    action_info.action_prob = new_hyp_score - prev_hyp.score

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score

                if new_hyp.completed:
                    # add length normalization
                    new_hyp.score /= (t + 1)
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                # carry the decoder state of each surviving hypothesis'
                # parent forward and extend its state history by this step
                hyp_states = [
                    hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids
                ]
                h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                att_tm1 = att_t[live_hyp_ids]
                hypotheses = new_hypotheses
                hyp_scores = Variable(
                    self.new_tensor([hyp.score for hyp in hypotheses]))
                t += 1
            else:
                # every selected candidate completed a hypothesis; nothing
                # is left to expand
                break

        # best (highest score) first
        completed_hypotheses.sort(key=lambda hyp: -hyp.score)

        return completed_hypotheses
コード例 #23
0
ファイル: seq2seq.py プロジェクト: chubbymaggie/tranX
 def sample(self, src_sents, sample_size):
     """Sample outputs for a batch of source sentences.

     The sentences are measured, converted to an input variable with the
     source vocabulary (inference mode, `training=False`), and the actual
     sampling is delegated to `sample_from_variable`.

     Args:
         src_sents: batch of source sentences (each a sized sequence)
         sample_size: passed through unchanged to `sample_from_variable`

     Returns:
         Whatever `sample_from_variable` returns.
     """
     lengths = list(map(len, src_sents))
     # per the original note, the variable is laid out as
     # (src_sent_len, batch_size)
     batch_var = nn_utils.to_input_variable(
         src_sents,
         self.vocab.src,
         cuda=self.cuda,
         training=False,
     )
     return self.sample_from_variable(batch_var, lengths, sample_size)