Example #1
0
    def parse(self, src_sent, context=None, beam_size=5, debug=False):
        """Perform beam search to infer the target AST given a source utterance.

        The entire search runs under ``torch.no_grad()``: nothing returned
        from this method carries gradients, so tracking them would only keep
        the autograd graph of every decoding step alive through the stored
        per-hypothesis decoder states.

        Args:
            src_sent: list of source utterance tokens
            context: other context used for prediction (accepted for
                interface compatibility; not read by this implementation)
            beam_size: beam size
            debug: if True, additionally record per-action diagnostics on
                each `ActionInfo` (`gen_copy_switch`, `in_vocab`,
                `gen_token_prob`, `copy_token_prob`, `action_prob`)

        Returns:
            A list of completed `DecodeHypothesis`, each representing an AST,
            sorted by descending score. The score of a completed hypothesis
            is its accumulated log-probability divided by its number of
            decoding steps. The list may hold fewer than `beam_size` entries
            (possibly none) if the search stops early.
        """
        args = self.args
        primitive_vocab = self.vocab.primitive
        T = torch.cuda if args.cuda else torch

        # For computing copy probabilities, we marginalize over tokens with
        # the same surface form. `aggregated_primitive_tokens` stores the
        # positions of occurrence of each source token.
        aggregated_primitive_tokens = OrderedDict()
        for token_pos, token in enumerate(src_sent):
            aggregated_primitive_tokens.setdefault(token, []).append(token_pos)

        completed_hypotheses = []

        with torch.no_grad():
            # The gather indices depend only on the source sentence, so they
            # are built once instead of once per token/hypothesis/time step.
            token_pos_index = {}
            if args.no_copy is False:
                token_pos_index = {
                    token: T.LongTensor(token_pos_list)
                    for token, token_pos_list
                    in aggregated_primitive_tokens.items()
                }

            src_sent_var = nn_utils.to_input_variable([src_sent],
                                                      self.vocab.source,
                                                      cuda=args.cuda,
                                                      training=False)

            # (1, src_sent_len, hidden_size * 2)
            src_encodings, last_state = self.encode(src_sent_var,
                                                    [len(src_sent)])
            # (1, src_sent_len, hidden_size)
            src_encodings_att_linear = self.att_src_linear(src_encodings)

            dec_init_vec = self.init_decoder_state(last_state)
            if args.lstm == 'parent_feed':
                # the parent-feeding cell additionally takes the parent's
                # (state, cell); there is no parent yet, so feed zeros
                h_tm1 = (dec_init_vec[0], dec_init_vec[1],
                         self.new_tensor(args.hidden_size).zero_(),
                         self.new_tensor(args.hidden_size).zero_())
            else:
                h_tm1 = dec_init_vec

            zero_action_embed = self.new_tensor(args.action_embed_size).zero_()
            hyp_scores = self.new_tensor([0.])

            t = 0
            hypotheses = [DecodeHypothesis()]
            # hyp_states[i][step] is the decoder (state, cell) that live
            # hypothesis i had at time `step`; used to look up parent states
            hyp_states = [[]]

            while (len(completed_hypotheses) < beam_size
                   and t < args.decode_max_time_step):
                hyp_num = len(hypotheses)

                # (hyp_num, src_sent_len, hidden_size * 2)
                exp_src_encodings = src_encodings.expand(
                    hyp_num, src_encodings.size(1), src_encodings.size(2))
                # (hyp_num, src_sent_len, hidden_size)
                exp_src_encodings_att_linear = src_encodings_att_linear.expand(
                    hyp_num, src_encodings_att_linear.size(1),
                    src_encodings_att_linear.size(2))

                if t == 0:
                    # First step: the input is all zeros except for the
                    # embedding of the grammar's root type, written into the
                    # slice where the frontier field type embedding lives.
                    x = self.new_tensor(1,
                                        self.decoder_lstm.input_size).zero_()
                    if args.no_parent_field_type_embed is False:
                        offset = args.action_embed_size  # prev_action
                        offset += args.att_vec_size * (not args.no_input_feed)
                        offset += args.action_embed_size * (
                            not args.no_parent_production_embed)
                        offset += args.field_embed_size * (
                            not args.no_parent_field_embed)

                        root_type_id = self.grammar.type2id[
                            self.grammar.root_type]
                        x[0, offset:offset + args.type_embed_size] = \
                            self.type_embed.weight[root_type_id]
                else:
                    # embedding of the previous action of every hypothesis
                    a_tm1_embeds = []
                    for hyp in hypotheses:
                        a_tm1 = hyp.actions[-1]
                        if a_tm1:
                            if isinstance(a_tm1, ApplyRuleAction):
                                a_tm1_embed = self.production_embed.weight[
                                    self.grammar.prod2id[a_tm1.production]]
                            elif isinstance(a_tm1, ReduceAction):
                                # Reduce uses the extra row after the last
                                # production
                                a_tm1_embed = self.production_embed.weight[
                                    len(self.grammar)]
                            else:
                                a_tm1_embed = self.primitive_embed.weight[
                                    self.vocab.primitive[a_tm1.token]]
                        else:
                            a_tm1_embed = zero_action_embed
                        a_tm1_embeds.append(a_tm1_embed)

                    inputs = [torch.stack(a_tm1_embeds)]
                    if args.no_input_feed is False:
                        inputs.append(att_tm1)
                    if args.no_parent_production_embed is False:
                        # frontier production
                        frontier_prod_ids = [
                            self.grammar.prod2id[hyp.frontier_node.production]
                            for hyp in hypotheses
                        ]
                        inputs.append(
                            self.production_embed(
                                self.new_long_tensor(frontier_prod_ids)))
                    if args.no_parent_field_embed is False:
                        # frontier field
                        frontier_field_ids = [
                            self.grammar.field2id[hyp.frontier_field.field]
                            for hyp in hypotheses
                        ]
                        inputs.append(
                            self.field_embed(
                                self.new_long_tensor(frontier_field_ids)))
                    if args.no_parent_field_type_embed is False:
                        # frontier field type
                        frontier_type_ids = [
                            self.grammar.type2id[hyp.frontier_field.type]
                            for hyp in hypotheses
                        ]
                        inputs.append(
                            self.type_embed(
                                self.new_long_tensor(frontier_type_ids)))

                    # parent states
                    if args.no_parent_state is False:
                        p_ts = [
                            hyp.frontier_node.created_time
                            for hyp in hypotheses
                        ]
                        parent_states = torch.stack([
                            hyp_states[hyp_id][p_t][0]
                            for hyp_id, p_t in enumerate(p_ts)
                        ])
                        parent_cells = torch.stack([
                            hyp_states[hyp_id][p_t][1]
                            for hyp_id, p_t in enumerate(p_ts)
                        ])

                        if args.lstm == 'parent_feed':
                            h_tm1 = (h_tm1[0], h_tm1[1], parent_states,
                                     parent_cells)
                        else:
                            inputs.append(parent_states)

                    x = torch.cat(inputs, dim=-1)

                (h_t, cell_t), att_t = self.step(x,
                                                 h_tm1,
                                                 exp_src_encodings,
                                                 exp_src_encodings_att_linear,
                                                 src_token_mask=None)

                # (hyp_num, grammar_size)
                apply_rule_log_prob = F.log_softmax(
                    self.production_readout(att_t), dim=-1)

                # (hyp_num, primitive_vocab_size)
                gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t),
                                                dim=-1)

                if args.no_copy:
                    primitive_prob = gen_from_vocab_prob
                else:
                    # (hyp_num, src_sent_len)
                    primitive_copy_prob = self.src_pointer_net(
                        src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

                    # (hyp_num, 2): column 0 gates generation, column 1 copy
                    primitive_predictor_prob = F.softmax(
                        self.primitive_predictor(att_t), dim=-1)

                    # (hyp_num, primitive_vocab_size)
                    primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(
                        1) * gen_from_vocab_prob

                gentoken_prev_hyp_ids = []
                gentoken_new_hyp_unks = []
                applyrule_new_hyp_scores = []
                applyrule_new_hyp_prod_ids = []
                applyrule_prev_hyp_ids = []

                for hyp_id, hyp in enumerate(hypotheses):
                    # generate new continuations
                    action_types = \
                        self.transition_system.get_valid_continuation_types(
                            hyp)

                    for action_type in action_types:
                        if action_type == ApplyRuleAction:
                            productions = self.transition_system \
                                .get_valid_continuating_productions(hyp)
                            for production in productions:
                                prod_id = self.grammar.prod2id[production]
                                prod_score = apply_rule_log_prob[
                                    hyp_id, prod_id].item()

                                applyrule_new_hyp_scores.append(hyp.score +
                                                                prod_score)
                                applyrule_new_hyp_prod_ids.append(prod_id)
                                applyrule_prev_hyp_ids.append(hyp_id)
                        elif action_type == ReduceAction:
                            action_score = apply_rule_log_prob[
                                hyp_id, len(self.grammar)].item()

                            applyrule_new_hyp_scores.append(hyp.score +
                                                            action_score)
                            applyrule_new_hyp_prod_ids.append(
                                len(self.grammar))
                            applyrule_prev_hyp_ids.append(hyp_id)
                        else:
                            # GenToken action
                            gentoken_prev_hyp_ids.append(hyp_id)
                            # (token, gated copy probability) of the source
                            # tokens missing from the primitive vocabulary
                            unk_copy_candidates = []

                            if args.no_copy is False:
                                for token in aggregated_primitive_tokens:
                                    sum_copy_prob = torch.gather(
                                        primitive_copy_prob[hyp_id], 0,
                                        token_pos_index[token]).sum()
                                    gated_copy_prob = primitive_predictor_prob[
                                        hyp_id, 1] * sum_copy_prob

                                    if token in primitive_vocab:
                                        token_id = primitive_vocab[token]
                                        primitive_prob[hyp_id, token_id] = (
                                            primitive_prob[hyp_id, token_id]
                                            + gated_copy_prob)
                                    else:
                                        unk_copy_candidates.append(
                                            (token, gated_copy_prob.item()))

                            if args.no_copy is False and unk_copy_candidates:
                                # The unk slot stands for the single most
                                # probable out-of-vocabulary source token; its
                                # probability replaces the generated unk
                                # probability.
                                best = np.array([
                                    prob for _, prob in unk_copy_candidates
                                ]).argmax()
                                unk_token, unk_prob = unk_copy_candidates[best]
                                primitive_prob[
                                    hyp_id, primitive_vocab.unk_id] = unk_prob
                                gentoken_new_hyp_unks.append(unk_token)

                # Candidate scores are laid out as
                # [ApplyRule/Reduce candidates | GenToken candidates], where
                # the GenToken part is a flattened
                # (len(gentoken_prev_hyp_ids), primitive_vocab_size) matrix.
                new_hyp_scores = None
                if applyrule_new_hyp_scores:
                    new_hyp_scores = self.new_tensor(applyrule_new_hyp_scores)
                if gentoken_prev_hyp_ids:
                    primitive_log_prob = torch.log(primitive_prob)
                    gen_token_new_hyp_scores = (
                        hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) +
                        primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

                    if new_hyp_scores is None:
                        new_hyp_scores = gen_token_new_hyp_scores
                    else:
                        new_hyp_scores = torch.cat(
                            [new_hyp_scores, gen_token_new_hyp_scores])

                if new_hyp_scores is None:
                    # No live hypothesis has a valid continuation; stop
                    # instead of failing inside `torch.topk`.
                    break

                top_new_hyp_scores, top_new_hyp_pos = torch.topk(
                    new_hyp_scores,
                    k=min(new_hyp_scores.size(0),
                          beam_size - len(completed_hypotheses)))

                live_hyp_ids = []
                new_hypotheses = []
                # Positions are converted to plain ints so they can index
                # Python lists and take part in integer arithmetic directly.
                for new_hyp_score, new_hyp_pos in zip(
                        top_new_hyp_scores.cpu(),
                        top_new_hyp_pos.cpu().tolist()):
                    action_info = ActionInfo()
                    if new_hyp_pos < len(applyrule_new_hyp_scores):
                        # it's an ApplyRule or Reduce action
                        prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                        prev_hyp = hypotheses[prev_hyp_id]

                        prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                        if prod_id < len(self.grammar):
                            # ApplyRule action
                            production = self.grammar.id2prod[prod_id]
                            action = ApplyRuleAction(production)
                        else:
                            # Reduce action
                            action = ReduceAction()
                    else:
                        # it's a GenToken action: recover (row, token id)
                        # from the flattened GenToken score matrix
                        k, token_id = divmod(
                            new_hyp_pos - len(applyrule_new_hyp_scores),
                            primitive_prob.size(1))
                        prev_hyp_id = gentoken_prev_hyp_ids[k]
                        prev_hyp = hypotheses[prev_hyp_id]

                        if token_id == primitive_vocab.unk_id:
                            if gentoken_new_hyp_unks:
                                # emit the copied out-of-vocabulary token
                                token = gentoken_new_hyp_unks[k]
                            else:
                                token = primitive_vocab.id2word(
                                    primitive_vocab.unk_id)
                        else:
                            token = primitive_vocab.id2word(token_id)

                        action = GenTokenAction(token)

                        if token in aggregated_primitive_tokens:
                            action_info.copy_from_src = True
                            action_info.src_token_position = \
                                aggregated_primitive_tokens[token]

                        if debug:
                            if args.no_copy:
                                action_info.gen_copy_switch = 'n/a'
                            else:
                                action_info.gen_copy_switch = \
                                    primitive_predictor_prob[
                                        prev_hyp_id, :].log().cpu().numpy()

                            action_info.in_vocab = token in primitive_vocab
                            if token in primitive_vocab:
                                action_info.gen_token_prob = \
                                    gen_from_vocab_prob[
                                        prev_hyp_id,
                                        token_id].log().cpu().item()
                            else:
                                action_info.gen_token_prob = 'n/a'

                            if (args.no_copy is False
                                    and action_info.copy_from_src):
                                action_info.copy_token_prob = torch.gather(
                                    primitive_copy_prob[prev_hyp_id], 0,
                                    token_pos_index[token]
                                ).sum().log().cpu().item()
                            else:
                                action_info.copy_token_prob = 'n/a'

                    action_info.action = action
                    action_info.t = t
                    if t > 0:
                        action_info.parent_t = \
                            prev_hyp.frontier_node.created_time
                        action_info.frontier_prod = \
                            prev_hyp.frontier_node.production
                        action_info.frontier_field = \
                            prev_hyp.frontier_field.field

                    if debug:
                        action_info.action_prob = (new_hyp_score -
                                                   prev_hyp.score)

                    new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                    new_hyp.score = new_hyp_score

                    if new_hyp.completed:
                        # add length normalization
                        new_hyp.score /= (t + 1)
                        completed_hypotheses.append(new_hyp)
                    else:
                        new_hypotheses.append(new_hyp)
                        live_hyp_ids.append(prev_hyp_id)

                if not live_hyp_ids:
                    # every selected candidate completed a hypothesis
                    break

                # carry over the decoder state of the surviving hypotheses
                hyp_states = [
                    hyp_states[i] + [(h_t[i], cell_t[i])]
                    for i in live_hyp_ids
                ]
                h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                att_tm1 = att_t[live_hyp_ids]
                hypotheses = new_hypotheses
                hyp_scores = self.new_tensor(
                    [hyp.score for hyp in hypotheses])
                t += 1

        completed_hypotheses.sort(key=lambda hyp: -hyp.score)

        return completed_hypotheses
Example #2
0
    def parse(self, src_sent, context=None, beam_size=5):
        """Perform beam search to infer the target AST given a source utterance

        Args:
            src_sent: list of source utterance tokens
            context: other context used for prediction
            beam_size: beam size

        Returns:
            A list of `DecodeHypothesis`, each representing an AST
        """
        # print('!!!!!!')
        args = self.args
        primitive_vocab = self.vocab.primitive

        src_sent_var = nn_utils.to_input_variable([src_sent], self.vocab.source, cuda=args.cuda, training=False)

        # Variable(1, src_sent_len, hidden_size * 2)
        src_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(src_sent)])
        # (1, src_sent_len, hidden_size)
        src_encodings_att_linear = self.att_src_linear(src_encodings)

        dec_init_vec = self.init_decoder_state(last_state, last_cell)
        if args.lstm == 'parent_feed':
            h_tm1 = dec_init_vec[0], dec_init_vec[1], \
                    Variable(self.new_tensor(args.hidden_size).zero_()), \
                    Variable(self.new_tensor(args.hidden_size).zero_())
        else:
            h_tm1 = dec_init_vec

        zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_())

        hyp_scores = Variable(self.new_tensor([0.]), volatile=True)

        src_token_vocab_ids = [primitive_vocab[token] for token in src_sent]
        src_unk_pos_list = [pos for pos, token_id in enumerate(src_token_vocab_ids) if token_id == primitive_vocab.unk_id]
        # sometimes a word may appear multi-times in the source, in this case,
        # we just copy its first appearing position. Therefore we mask the words
        # appearing second and onwards to -1
        token_set = set()
        for i, tid in enumerate(src_token_vocab_ids):
            if tid in token_set:
                src_token_vocab_ids[i] = -1
            else: token_set.add(tid)

        t = 0
        hypotheses = [DecodeHypothesis()]
        hyp_states = [[]]
        completed_hypotheses = []

        while len(completed_hypotheses) < beam_size and t < args.decode_max_time_step:
            hyp_num = len(hypotheses)

            # (hyp_num, src_sent_len, hidden_size * 2)
            exp_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1), src_encodings.size(2))
            # (hyp_num, src_sent_len, hidden_size)
            exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num, src_encodings_att_linear.size(1), src_encodings_att_linear.size(2))

            if t == 0:
                x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)
                if args.no_parent_field_type_embed is False:
                    offset = args.action_embed_size  # prev_action
                    offset += args.att_vec_size * (not args.no_input_feed)
                    offset += args.action_embed_size * (not args.no_parent_production_embed)
                    offset += args.field_embed_size * (not args.no_parent_field_embed)

                    x[0, offset: offset + args.type_embed_size] = \
                        self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
            else:
                actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

                a_tm1_embeds = []
                for a_tm1 in actions_tm1:
                    if a_tm1:
                        if isinstance(a_tm1, ApplyRuleAction):
                            a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[a_tm1.production]]
                        elif isinstance(a_tm1, ReduceAction):
                            a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                        else:
                            a_tm1_embed = self.primitive_embed.weight[self.vocab.primitive[a_tm1.token]]

                        a_tm1_embeds.append(a_tm1_embed)
                    else:
                        a_tm1_embeds.append(zero_action_embed)
                a_tm1_embeds = torch.stack(a_tm1_embeds)

                inputs = [a_tm1_embeds]
                if args.no_input_feed is False:
                    inputs.append(att_tm1)
                if args.no_parent_production_embed is False:
                    # frontier production
                    frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                    frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                        [self.grammar.prod2id[prod] for prod in frontier_prods])))
                    inputs.append(frontier_prod_embeds)
                if args.no_parent_field_embed is False:
                    # frontier field
                    frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                    frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([
                        self.grammar.field2id[field] for field in frontier_fields])))

                    inputs.append(frontier_field_embeds)
                if args.no_parent_field_type_embed is False:
                    # frontier field type
                    frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                    frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([
                        self.grammar.type2id[type] for type in frontier_field_types])))
                    inputs.append(frontier_field_type_embeds)

                # parent states
                if args.no_parent_state is False:
                    p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                    parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)])
                    parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)])

                    if args.lstm == 'parent_feed':
                        h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells)
                    else:
                        inputs.append(parent_states)

                x = torch.cat(inputs, dim=-1)

            if args.lstm == 'lstm_with_dropout':
                self.decoder_lstm.set_dropout_masks(hyp_num)

            (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
                                             exp_src_encodings_att_linear,
                                             src_token_mask=None)

            # Variable(batch_size, grammar_size)
            # apply_rule_log_prob = torch.log(F.softmax(self.production_readout(att_t), dim=-1))
            apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

            # Variable(batch_size, src_sent_len)
            primitive_copy_prob = self.src_pointer_net(src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

            # Variable(batch_size, primitive_vocab_size)
            gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

            # Variable(batch_size, 2)
            primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

            # Variable(batch_size, primitive_vocab_size)
            primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(1) * gen_from_vocab_prob
            if src_unk_pos_list:
                primitive_prob[:, primitive_vocab.unk_id] = 1.e-10

            gentoken_prev_hyp_ids = []
            gentoken_new_hyp_unks = []
            gentoken_copy_infos = []
            applyrule_new_hyp_scores = []
            applyrule_new_hyp_prod_ids = []
            applyrule_prev_hyp_ids = []

            for hyp_id, hyp in enumerate(hypotheses):
                # generate new continuations
                action_types = self.transition_system.get_valid_continuation_types(hyp)

                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_log_prob[hyp_id, prod_id].data[0]
                            # print(type(hyp.score),type(prod_score))
                            new_hyp_score = hyp.score + prod_score

                            applyrule_new_hyp_scores.append(new_hyp_score)
                            applyrule_new_hyp_prod_ids.append(prod_id)
                            applyrule_prev_hyp_ids.append(hyp_id)
                    elif action_type == ReduceAction:
                        action_score = apply_rule_log_prob[hyp_id, len(self.grammar)].data[0]
                        new_hyp_score = hyp.score + action_score

                        applyrule_new_hyp_scores.append(new_hyp_score)
                        applyrule_new_hyp_prod_ids.append(len(self.grammar))
                        applyrule_prev_hyp_ids.append(hyp_id)
                    else:
                        # GenToken action
                        gentoken_prev_hyp_ids.append(hyp_id)
                        hyp_copy_info = dict()  # of (token_pos, copy_prob)
                        # first, we compute copy probabilities for tokens in the source sentence
                        for token_pos, token_vocab_id in enumerate(src_token_vocab_ids):
                            if args.no_copy is False and token_vocab_id != -1 and token_vocab_id != primitive_vocab.unk_id:
                                p_copy = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, token_pos]
                                primitive_prob[hyp_id, token_vocab_id] = primitive_prob[hyp_id, token_vocab_id] + p_copy

                                token = src_sent[token_pos]
                                hyp_copy_info[token] = (token_pos, p_copy.data[0])

                        # second, add the probability of copying the most probable unk word
                        if args.no_copy is False and src_unk_pos_list:
                            unk_pos = primitive_copy_prob[hyp_id][src_unk_pos_list].data.cpu().numpy().argmax()
                            unk_pos = src_unk_pos_list[unk_pos]
                            token = src_sent[unk_pos]
                            gentoken_new_hyp_unks.append(token)

                            unk_copy_score = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id, unk_pos]
                            primitive_prob[hyp_id, primitive_vocab.unk_id] = unk_copy_score

                            hyp_copy_info[token] = (unk_pos, unk_copy_score.data[0])

                        gentoken_copy_infos.append(hyp_copy_info)

            new_hyp_scores = None
            if applyrule_new_hyp_scores:
                new_hyp_scores = Variable(self.new_tensor(applyrule_new_hyp_scores))
            if gentoken_prev_hyp_ids:
                primitive_log_prob = torch.log(primitive_prob)
                gen_token_new_hyp_scores = (hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) + primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

                if new_hyp_scores is None: new_hyp_scores = gen_token_new_hyp_scores
                else: new_hyp_scores = torch.cat([new_hyp_scores, gen_token_new_hyp_scores])

            top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores,
                                                             k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses)))

            live_hyp_ids = []
            new_hypotheses = []
            for new_hyp_score, new_hyp_pos in zip(top_new_hyp_scores.data.cpu(), top_new_hyp_pos.data.cpu()):
            # for new_hyp_score, new_hyp_pos in zip(top_new_hyp_scores.data.cuda(), top_new_hyp_pos.data.cuda()):
                action_info = ActionInfo()
                if new_hyp_pos < len(applyrule_new_hyp_scores):
                    # it's an ApplyRule or Reduce action
                    prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                    prev_hyp = hypotheses[prev_hyp_id]

                    prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                    # ApplyRule action
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    # Reduce action
                    else:
                        action = ReduceAction()
                else:
                    # it's a GenToken action
                    token_id = (new_hyp_pos - len(applyrule_new_hyp_scores)) % primitive_prob.size(1)

                    k = (new_hyp_pos - len(applyrule_new_hyp_scores)) // primitive_prob.size(1)
                    # try:
                    copy_info = gentoken_copy_infos[k]
                    prev_hyp_id = gentoken_prev_hyp_ids[k]
                    prev_hyp = hypotheses[prev_hyp_id]
                    # except:
                    #     print('k=%d' % k, file=sys.stderr)
                    #     print('primitive_prob.size(1)=%d' % primitive_prob.size(1), file=sys.stderr)
                    #     print('len copy_info=%d' % len(gentoken_copy_infos), file=sys.stderr)
                    #     print('prev_hyp_id=%s' % ', '.join(str(i) for i in gentoken_prev_hyp_ids), file=sys.stderr)
                    #     print('len applyrule_new_hyp_scores=%d' % len(applyrule_new_hyp_scores), file=sys.stderr)
                    #     print('len gentoken_prev_hyp_ids=%d' % len(gentoken_prev_hyp_ids), file=sys.stderr)
                    #     print('top_new_hyp_pos=%s' % top_new_hyp_pos, file=sys.stderr)
                    #     print('applyrule_new_hyp_scores=%s' % applyrule_new_hyp_scores, file=sys.stderr)
                    #     print('new_hyp_scores=%s' % new_hyp_scores, file=sys.stderr)
                    #     print('top_new_hyp_scores=%s' % top_new_hyp_scores, file=sys.stderr)
                    #
                    #     torch.save((applyrule_new_hyp_scores, primitive_prob), 'data.bin')
                    #
                    #     # exit(-1)
                    #     raise ValueError()

                    if token_id == primitive_vocab.unk_id:
                        if gentoken_new_hyp_unks:
                            token = gentoken_new_hyp_unks[k]
                        else:
                            token = primitive_vocab.id2word[primitive_vocab.unk_id]
                    else:
                        token = primitive_vocab.id2word[token_id]

                    action = GenTokenAction(token)

                    if token in copy_info:
                        action_info.copy_from_src = True
                        action_info.src_token_position = copy_info[token][0]

                action_info.action = action
                action_info.t = t
                if t > 0:
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score
                #advoid none new hyp
                if new_hyp.completed:
                    # print('What happened!')
                    # print(new_hyp.actions)
                    # print(len(new_hyp.actions))
                    if len(new_hyp.actions) != 2:
                        completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids]
                h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                att_tm1 = att_t[live_hyp_ids]
                hypotheses = new_hypotheses
                hyp_scores = Variable(self.new_tensor([hyp.score for hyp in hypotheses]))
                t += 1
            else:
                break

        completed_hypotheses.sort(key=lambda hyp: -hyp.score)

        return completed_hypotheses
Example #3
0
    def parse(self, question, context, beam_size=5):
        """Beam-search decode `question` against the table passed as `context`.

        Starting from a single empty `DecodeHypothesis`, every decoding step
        scores all valid continuations of every live hypothesis (ApplyRule,
        Reduce, column selection, token copy and `</primitive>`), keeps the
        best-scoring ones, and moves finished hypotheses to the result list.

        Args:
            question: sequence of source tokens; it is indexed by position and
                its length bounds the copy positions.
            context: the table. Only `table.header` is read here; each header
                entry is expected to expose a `type` attribute.
            beam_size: number of completed hypotheses to collect. The number
                of continuations kept per step is capped by the number of
                beam slots still free.

        Returns:
            The completed hypotheses, sorted by descending `score`. The list
            may hold fewer than `beam_size` entries, because decoding also
            stops at `args.decode_max_time_step`, when no continuation is
            valid, or when no live hypothesis remains.
        """
        table = context
        args = self.args
        # the question is encoded as a batch of size one
        src_sent_var = nn_utils.to_input_variable([question], self.vocab.source,
                                                  cuda=self.args.cuda, training=False)

        utterance_encodings, (last_state, last_cell) = self.encode(src_sent_var, [len(question)])
        dec_init_vec = self.init_decoder_state(last_state, last_cell)

        # NOTE(review): `column_word_encodings` is not used again in this method
        column_word_encodings, table_header_encoding, table_header_mask = self.encode_table_header([table])

        h_tm1 = dec_init_vec
        # (batch_size, query_len, hidden_size)
        utterance_encodings_att_linear = self.att_src_linear(utterance_encodings)

        # stand-in embedding used when a hypothesis has a falsy last action
        zero_action_embed = Variable(self.new_tensor(self.args.action_embed_size).zero_())

        t = 0
        hypotheses = [DecodeHypothesis()]
        # hyp_states[i] is the list of (h, cell) pairs of hypothesis i, one per time step
        hyp_states = [[]]
        completed_hypotheses = []

        while len(completed_hypotheses) < beam_size and t < self.args.decode_max_time_step:
            hyp_num = len(hypotheses)

            # replicate the single encoded question once per live hypothesis
            # (hyp_num, src_sent_len, hidden_size * 2)
            exp_src_encodings = utterance_encodings.expand(hyp_num, utterance_encodings.size(1), utterance_encodings.size(2))
            # (hyp_num, src_sent_len, hidden_size)
            exp_src_encodings_att_linear = utterance_encodings_att_linear.expand(hyp_num,
                                                                                 utterance_encodings_att_linear.size(1),
                                                                                 utterance_encodings_att_linear.size(2))

            # x: [prev_action, parent_production_embed, parent_field_embed, parent_field_type_embed, parent_action_state]
            if t == 0:
                # first step: all-zero input, except for the root type embedding below
                # NOTE(review): `volatile` is a legacy autograd flag; newer PyTorch
                # releases reportedly ignore it with a warning -- confirm target version
                x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)

                if args.no_parent_field_type_embed is False:
                    # skip the segments that precede the field-type slot; a disabled
                    # segment contributes width 0 because the bool multiplies as 0
                    offset = args.action_embed_size  # prev_action
                    offset += args.hidden_size * (not args.no_input_feed)
                    offset += args.action_embed_size * (not args.no_parent_production_embed)
                    offset += args.field_embed_size * (not args.no_parent_field_embed)

                    x[0, offset: offset + args.type_embed_size] = \
                        self.type_embed.weight[self.grammar.type2id[self.grammar.root_type]]
            else:
                # embed the previous action of every live hypothesis
                a_tm1_embeds = []
                for e_id, hyp in enumerate(hypotheses):
                    action_tm1 = hyp.actions[-1]
                    if action_tm1:
                        # NOTE(review): the column-selection check comes before the
                        # GenTokenAction check; the order matters if
                        # WikiSqlSelectColumnAction subclasses GenTokenAction -- confirm
                        if isinstance(action_tm1, ApplyRuleAction):
                            a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
                        elif isinstance(action_tm1, ReduceAction):
                            # Reduce uses the extra row right after the grammar productions
                            a_tm1_embed = self.production_embed.weight[len(self.grammar)]
                        elif isinstance(action_tm1, WikiSqlSelectColumnAction):
                            a_tm1_embed = self.column_rnn_input(table_header_encoding[0, action_tm1.column_id])
                        elif isinstance(action_tm1, GenTokenAction):
                            a_tm1_embed = self.src_embed.weight[self.vocab.source[action_tm1.token]]
                        else:
                            raise ValueError('unknown action %s' % action_tm1)
                    else:
                        a_tm1_embed = zero_action_embed

                    a_tm1_embeds.append(a_tm1_embed)

                a_tm1_embeds = torch.stack(a_tm1_embeds)

                # optional input segments, appended in the same order as the t == 0 offsets
                inputs = [a_tm1_embeds]
                if args.no_input_feed is False:
                    # att_tm1 was set at the end of the previous iteration
                    inputs.append(att_tm1)
                if args.no_parent_production_embed is False:
                    # frontier production
                    frontier_prods = [hyp.frontier_node.production for hyp in hypotheses]
                    frontier_prod_embeds = self.production_embed(Variable(self.new_long_tensor(
                        [self.grammar.prod2id[prod] for prod in frontier_prods])))
                    inputs.append(frontier_prod_embeds)
                if args.no_parent_field_embed is False:
                    # frontier field
                    frontier_fields = [hyp.frontier_field.field for hyp in hypotheses]
                    frontier_field_embeds = self.field_embed(Variable(self.new_long_tensor([
                        self.grammar.field2id[field] for field in frontier_fields])))

                    inputs.append(frontier_field_embeds)
                if args.no_parent_field_type_embed is False:
                    # frontier field type
                    frontier_field_types = [hyp.frontier_field.type for hyp in hypotheses]
                    frontier_field_type_embeds = self.type_embed(Variable(self.new_long_tensor([
                        self.grammar.type2id[type] for type in frontier_field_types])))
                    inputs.append(frontier_field_type_embeds)

                # parent states
                if args.no_parent_state is False:
                    # look up the (h, cell) pair recorded at the step that created
                    # each hypothesis' frontier node
                    p_ts = [hyp.frontier_node.created_time for hyp in hypotheses]
                    parent_states = torch.stack([hyp_states[hyp_id][p_t][0] for hyp_id, p_t in enumerate(p_ts)])
                    parent_cells = torch.stack([hyp_states[hyp_id][p_t][1] for hyp_id, p_t in enumerate(p_ts)])

                    if args.lstm == 'parent_feed':
                        # parent state/cell are handed over via the recurrent state tuple
                        h_tm1 = (h_tm1[0], h_tm1[1], parent_states, parent_cells)
                    else:
                        # otherwise only the parent hidden state is fed, as an input segment
                        inputs.append(parent_states)

                x = torch.cat(inputs, dim=-1)

            # one decoder step for all live hypotheses at once
            (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
                                             exp_src_encodings_att_linear,
                                             src_token_mask=None)

            # ApplyRule action probability
            # (batch_size, grammar_size)
            apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)

            # column attention
            # (batch_size, max_head_num)
            column_attention_weights = self.column_pointer_net(table_header_encoding, table_header_mask,
                                                               att_t.unsqueeze(0)).squeeze(0)
            column_selection_log_prob = torch.log(column_attention_weights)

            # column 0 is used below for generating from the vocabulary,
            # column 1 for copying from the question
            # (batch_size, 2)
            primitive_predictor_prob = F.softmax(self.primitive_predictor(att_t), dim=-1)

            # primitive copy prob
            # (batch_size, src_token_num)
            primitive_copy_prob = self.src_pointer_net(utterance_encodings, None,
                                                       att_t.unsqueeze(0)).squeeze(0)

            # (batch_size, primitive_vocab_size)
            primitive_gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t), dim=-1)

            # one dict per candidate continuation; every entry carries 'action_type',
            # 'score', 'new_hyp_score' and 'prev_hyp_id' plus action-specific keys
            new_hyp_meta = []

            for hyp_id, hyp in enumerate(hypotheses):
                # generate new continuations
                action_types = self.transition_system.get_valid_continuation_types(hyp)

                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_log_prob[hyp_id, prod_id]
                            new_hyp_score = hyp.score + prod_score

                            meta_entry = {'action_type': 'apply_rule', 'prod_id': prod_id,
                                          'score': prod_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    elif action_type == ReduceAction:
                        # Reduce is stored as an 'apply_rule' entry whose prod_id equals
                        # len(self.grammar); it is told apart again when actions are built
                        action_score = apply_rule_log_prob[hyp_id, len(self.grammar)]
                        new_hyp_score = hyp.score + action_score

                        meta_entry = {'action_type': 'apply_rule', 'prod_id': len(self.grammar),
                                      'score': action_score, 'new_hyp_score': new_hyp_score,
                                      'prev_hyp_id': hyp_id}
                        new_hyp_meta.append(meta_entry)
                    elif action_type == WikiSqlSelectColumnAction:
                        # one candidate per table column
                        for col_id, column in enumerate(table.header):
                            col_sel_score = column_selection_log_prob[hyp_id, col_id]
                            new_hyp_score = hyp.score + col_sel_score

                            meta_entry = {'action_type': 'sel_col', 'col_id': col_id,
                                          'score': col_sel_score, 'new_hyp_score': new_hyp_score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)
                    elif action_type == GenTokenAction:
                        # remember that we can only copy stuff from the input!
                        # we only copy tokens sequentially!!
                        prev_action = hyp.action_infos[-1].action

                        # mid-value (previous action was a non-stop GenToken): only the
                        # next question position may be copied, if there is one.
                        # otherwise every question position is a candidate.
                        valid_token_pos_list = []
                        if type(prev_action) is GenTokenAction and \
                                not prev_action.is_stop_signal():
                            token_pos = hyp.action_infos[-1].src_token_position + 1
                            if token_pos < len(question):
                                valid_token_pos_list = [token_pos]
                        else:
                            valid_token_pos_list = list(range(len(question)))

                        col_id = hyp.frontier_node['col_idx'].value
                        if table.header[col_id].type == 'real':
                            # for 'real' columns keep a position only if its token contains
                            # a digit, or -- once the value buffer is non-empty -- is one of
                            # , . - %  (`and` binds tighter than `or` here)
                            valid_token_pos_list = [i for i in valid_token_pos_list
                                                    if any(c.isdigit() for c in question[i]) or
                                                    hyp._value_buffer and question[i] in (',', '.', '-', '%')]

                        # copy probability of every question position for this hypothesis
                        p_copies = primitive_predictor_prob[hyp_id, 1] * primitive_copy_prob[hyp_id]
                        for token_pos in valid_token_pos_list:
                            token = question[token_pos]
                            p_copy = p_copies[token_pos]
                            score_copy = torch.log(p_copy)

                            meta_entry = {'action_type': 'gen_token',
                                          'token': token, 'token_pos': token_pos,
                                          'score': score_copy, 'new_hyp_score': score_copy + hyp.score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)

                        # add generation probability for </primitive>
                        # (only offered once the value buffer is non-empty; this entry
                        # deliberately has no 'token_pos' key)
                        if hyp._value_buffer:
                            eos_prob = primitive_predictor_prob[hyp_id, 0] * \
                                       primitive_gen_from_vocab_prob[hyp_id, self.vocab.primitive['</primitive>']]
                            eos_score = torch.log(eos_prob)

                            meta_entry = {'action_type': 'gen_token',
                                          'token': '</primitive>',
                                          'score': eos_score, 'new_hyp_score': eos_score + hyp.score,
                                          'prev_hyp_id': hyp_id}
                            new_hyp_meta.append(meta_entry)

            # no hypothesis has a valid continuation: stop decoding
            if not new_hyp_meta: break

            # NOTE(review): torch.cat assumes every 'new_hyp_score' is at least 1-D;
            # confirm this holds for the targeted PyTorch version
            new_hyp_scores = torch.cat([x['new_hyp_score'] for x in new_hyp_meta])
            # keep at most as many candidates as there are free beam slots
            top_new_hyp_scores, meta_ids = torch.topk(new_hyp_scores,
                                                      k=min(new_hyp_scores.size(0),
                                                            beam_size - len(completed_hypotheses)))

            live_hyp_ids = []
            new_hypotheses = []
            for new_hyp_score, meta_id in zip(top_new_hyp_scores.data.cpu(), meta_ids.data.cpu()):
                # rebuild the concrete action from the candidate's metadata
                action_info = ActionInfo()
                hyp_meta_entry = new_hyp_meta[meta_id]
                prev_hyp_id = hyp_meta_entry['prev_hyp_id']
                prev_hyp = hypotheses[prev_hyp_id]

                action_type_str = hyp_meta_entry['action_type']
                if action_type_str == 'apply_rule':
                    # ApplyRule action
                    prod_id = hyp_meta_entry['prod_id']
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    # Reduce action
                    else:
                        action = ReduceAction()
                elif action_type_str == 'sel_col':
                    action = WikiSqlSelectColumnAction(hyp_meta_entry['col_id'])
                else:
                    action = GenTokenAction(hyp_meta_entry['token'])
                    # only copied tokens carry a source position; '</primitive>' does not
                    if 'token_pos' in hyp_meta_entry:
                        action_info.copy_from_src = True
                        action_info.src_token_position = hyp_meta_entry['token_pos']

                action_info.action = action
                action_info.t = t

                # frontier information is read from the parent hypothesis only after
                # the first step
                if t > 0:
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score

                if new_hyp.completed:
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                # carry over the decoder state of each surviving hypothesis' parent and
                # extend its state history with the state of this step
                hyp_states = [hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids]
                h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                att_tm1 = att_t[live_hyp_ids]
                hypotheses = new_hypotheses
                t += 1
            else: break

        # best (highest score) first
        completed_hypotheses.sort(key=lambda hyp: -hyp.score)

        return completed_hypotheses
Example #4
0
    def parse_sample(self, src_sent):
        # get samples of q_{\phi}(z|x) by sampling
        args = self.args
        sample_size = self.sample_size
        primitive_vocab = self.vocab.primitive

        src_sent_var = nn_utils.to_input_variable([src_sent],
                                                  self.vocab.source,
                                                  cuda=args.cuda,
                                                  training=False)

        # Variable(1, src_sent_len, hidden_size * 2)
        src_encodings, (last_state,
                        last_cell) = self.encode(src_sent_var, [len(src_sent)])
        # (1, src_sent_len, hidden_size)
        src_encodings_att_linear = self.att_src_linear(src_encodings)

        h_tm1 = self.init_decoder_state(last_state, last_cell)
        zero_action_embed = Variable(
            self.new_tensor(args.action_embed_size).zero_())

        hyp_scores = Variable(self.new_tensor([0.]), volatile=True)

        src_token_vocab_ids = [primitive_vocab[token] for token in src_sent]
        src_unk_pos_list = [
            pos for pos, token_id in enumerate(src_token_vocab_ids)
            if token_id == primitive_vocab.unk_id
        ]
        # sometimes a word may appear multi-times in the source, in this case,
        # we just copy its first appearing position. Therefore we mask the words
        # appearing second and onwards to -1
        token_set = set()
        for i, tid in enumerate(src_token_vocab_ids):
            if tid in token_set:
                src_token_vocab_ids[i] = -1
            else:
                token_set.add(tid)

        completed_hypotheses = []
        while len(completed_hypotheses) < sample_size:
            t = 0
            hypotheses = [DecodeHypothesis()]
            hyp_states = [[]]
            while t < args.decode_max_time_step:
                hyp_num = len(hypotheses)

                # (hyp_num, src_sent_len, hidden_size * 2)
                exp_src_encodings = src_encodings.expand(
                    hyp_num, src_encodings.size(1), src_encodings.size(2))
                # (hyp_num, src_sent_len, hidden_size)
                exp_src_encodings_att_linear = src_encodings_att_linear.expand(
                    hyp_num, src_encodings_att_linear.size(1),
                    src_encodings_att_linear.size(2))

                if t == 0:
                    x = Variable(self.new_tensor(
                        1, self.decoder_lstm.input_size).zero_(),
                                 volatile=True)
                    offset = args.action_embed_size * 2 + args.field_embed_size
                    x[0, offset:offset +
                      args.type_embed_size] = self.type_embed.weight[
                          self.grammar.type2id[self.grammar.root_type]]
                else:
                    actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

                    a_tm1_embeds = []
                    for a_tm1 in actions_tm1:
                        if a_tm1:
                            if isinstance(a_tm1, ApplyRuleAction):
                                a_tm1_embed = self.production_embed.weight[
                                    self.grammar.prod2id[a_tm1.production]]
                            elif isinstance(a_tm1, ReduceAction):
                                a_tm1_embed = self.production_embed.weight[len(
                                    self.grammar)]
                            else:
                                a_tm1_embed = self.primitive_embed.weight[
                                    self.vocab.primitive[a_tm1.token]]

                            a_tm1_embeds.append(a_tm1_embed)
                        else:
                            a_tm1_embeds.append(zero_action_embed)
                    a_tm1_embeds = torch.stack(a_tm1_embeds)

                    # frontier production
                    frontier_prods = [
                        hyp.frontier_node.production for hyp in hypotheses
                    ]
                    frontier_prod_embeds = self.production_embed(
                        Variable(
                            self.new_long_tensor([
                                self.grammar.prod2id[prod]
                                for prod in frontier_prods
                            ])))

                    # frontier field
                    frontier_fields = [
                        hyp.frontier_field.field for hyp in hypotheses
                    ]
                    frontier_field_embeds = self.field_embed(
                        Variable(
                            self.new_long_tensor([
                                self.grammar.field2id[field]
                                for field in frontier_fields
                            ])))

                    # frontier field type
                    frontier_field_types = [
                        hyp.frontier_field.type for hyp in hypotheses
                    ]
                    frontier_field_type_embeds = self.type_embed(
                        Variable(
                            self.new_long_tensor([
                                self.grammar.type2id[type]
                                for type in frontier_field_types
                            ])))

                    # parent states
                    p_ts = [
                        hyp.frontier_node.created_time for hyp in hypotheses
                    ]
                    hist_states = torch.stack([
                        hyp_states[hyp_id][p_t]
                        for hyp_id, p_t in enumerate(p_ts)
                    ])

                    x = torch.cat([
                        a_tm1_embeds, att_tm1, frontier_prod_embeds,
                        frontier_field_embeds, frontier_field_type_embeds,
                        hist_states
                    ],
                                  dim=-1)

                if args.lstm == 'lstm_with_dropout':
                    self.decoder_lstm.set_dropout_masks(hyp_num)

                (h_t, cell_t), att_t = self.step(x,
                                                 h_tm1,
                                                 exp_src_encodings,
                                                 exp_src_encodings_att_linear,
                                                 src_token_mask=None)

                # Variable(batch_size, grammar_size)
                apply_rule_log_prob = F.log_softmax(
                    self.production_readout(att_t), dim=-1)

                # Variable(batch_size, src_sent_len)
                primitive_copy_prob = self.src_pointer_net(
                    src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

                # Variable(batch_size, primitive_vocab_size)
                gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t),
                                                dim=-1)

                # Variable(batch_size, 2)
                primitive_predictor_prob = F.softmax(
                    self.primitive_predictor(att_t), dim=-1)

                # Variable(batch_size, primitive_vocab_size)
                primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(
                    1) * gen_from_vocab_prob
                if src_unk_pos_list:
                    primitive_prob[:, primitive_vocab.unk_id] = 1.e-10

                gentoken_prev_hyp_ids = []
                gentoken_new_hyp_unks = []
                gentoken_copy_infos = []
                applyrule_new_hyp_scores = []
                applyrule_new_hyp_prod_ids = []
                applyrule_prev_hyp_ids = []

                for hyp_id, hyp in enumerate(hypotheses):
                    # generate new continuations
                    action_types = self.transition_system.get_valid_continuation_types(
                        hyp)

                    for action_type in action_types:
                        if action_type == ApplyRuleAction:
                            productions = self.transition_system.get_valid_continuating_productions(
                                hyp)
                            for production in productions:
                                prod_id = self.grammar.prod2id[production]
                                prod_score = apply_rule_log_prob[
                                    hyp_id, prod_id].data[0]
                                new_hyp_score = hyp.score + prod_score

                                applyrule_new_hyp_scores.append(new_hyp_score)
                                applyrule_new_hyp_prod_ids.append(prod_id)
                                applyrule_prev_hyp_ids.append(hyp_id)
                        elif action_type == ReduceAction:
                            action_score = apply_rule_log_prob[
                                hyp_id, len(self.grammar)].data[0]
                            new_hyp_score = hyp.score + action_score

                            applyrule_new_hyp_scores.append(new_hyp_score)
                            applyrule_new_hyp_prod_ids.append(len(
                                self.grammar))
                            applyrule_prev_hyp_ids.append(hyp_id)
                        else:
                            # GenToken action
                            gentoken_prev_hyp_ids.append(hyp_id)
                            hyp_copy_info = dict()  # of (token_pos, copy_prob)
                            # first, we compute copy probabilities for tokens in the source sentence
                            for token_pos, token_vocab_id in enumerate(
                                    src_token_vocab_ids):
                                if token_vocab_id != -1 and token_vocab_id != primitive_vocab.unk_id:
                                    p_copy = primitive_predictor_prob[
                                        hyp_id,
                                        1] * primitive_copy_prob[hyp_id,
                                                                 token_pos]
                                    primitive_prob[
                                        hyp_id,
                                        token_vocab_id] = primitive_prob[
                                            hyp_id, token_vocab_id] + p_copy

                                    token = src_sent[token_pos]
                                    hyp_copy_info[token] = (token_pos,
                                                            p_copy.data[0])

                            # second, add the probability of copying the most probable unk word
                            if src_unk_pos_list:
                                unk_pos = primitive_copy_prob[
                                    hyp_id][src_unk_pos_list].data.cpu().numpy(
                                    ).argmax()
                                unk_pos = src_unk_pos_list[unk_pos]
                                token = src_sent[unk_pos]
                                gentoken_new_hyp_unks.append(token)

                                unk_copy_score = primitive_predictor_prob[
                                    hyp_id, 1] * primitive_copy_prob[hyp_id,
                                                                     unk_pos]
                                primitive_prob[
                                    hyp_id,
                                    primitive_vocab.unk_id] = unk_copy_score

                                hyp_copy_info[token] = (unk_pos,
                                                        unk_copy_score.data[0])

                            gentoken_copy_infos.append(hyp_copy_info)

                new_hyp_scores = None
                if applyrule_new_hyp_scores:
                    new_hyp_scores = Variable(
                        self.new_tensor(applyrule_new_hyp_scores))
                if gentoken_prev_hyp_ids:
                    primitive_log_prob = torch.log(primitive_prob)
                    gen_token_new_hyp_scores = (
                        hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) +
                        primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

                    if new_hyp_scores is None:
                        new_hyp_scores = gen_token_new_hyp_scores
                    else:
                        new_hyp_scores = torch.cat(
                            [new_hyp_scores, gen_token_new_hyp_scores])

                # top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores,
                #                                                  k=min(new_hyp_scores.size(0), beam_size - len(completed_hypotheses)))

                pro = F.softmax(new_hyp_scores, 0)
                new_hyp_pos = torch.multinomial(pro, 1).data[0]
                new_hyp_score = new_hyp_scores[new_hyp_pos].data[0]

                live_hyp_ids = []
                new_hypotheses = []

                action_info = ActionInfo()
                if new_hyp_pos < len(applyrule_new_hyp_scores):
                    # it's an ApplyRule or Reduce action
                    prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                    prev_hyp = hypotheses[prev_hyp_id]

                    prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                    # ApplyRule action
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    # Reduce action
                    else:
                        action = ReduceAction()
                else:
                    # it's a GenToken action
                    token_id = (new_hyp_pos - len(applyrule_new_hyp_scores)
                                ) % primitive_prob.size(1)

                    k = (new_hyp_pos - len(applyrule_new_hyp_scores)
                         ) // primitive_prob.size(1)
                    # try:
                    copy_info = gentoken_copy_infos[k]
                    prev_hyp_id = gentoken_prev_hyp_ids[k]
                    prev_hyp = hypotheses[prev_hyp_id]

                    if token_id == primitive_vocab.unk_id:
                        if gentoken_new_hyp_unks:
                            token = gentoken_new_hyp_unks[k]
                        else:
                            token = primitive_vocab.id2word[
                                primitive_vocab.unk_id]
                    else:
                        token = primitive_vocab.id2word[token_id]

                    action = GenTokenAction(token)

                    if token in copy_info:
                        action_info.copy_from_src = True
                        action_info.src_token_position = copy_info[token][0]

                action_info.action = action
                action_info.t = t
                if t > 0:
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                new_hyp.score = new_hyp_score

                if new_hyp.completed:
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

                if live_hyp_ids:
                    hyp_states = [
                        hyp_states[i] + [h_t[i]] for i in live_hyp_ids
                    ]
                    h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                    att_tm1 = att_t[live_hyp_ids]
                    hypotheses = new_hypotheses
                    hyp_scores = Variable(
                        self.new_tensor([hyp.score for hyp in hypotheses]))
                    t += 1
                else:
                    break
        return completed_hypotheses
Example #5
0
    def parse(self, encodings, mask, h0, batch, beam_size=5):
        """Beam-search decode a single example into a ranked list of hypotheses.

        Exactly one example is decoded per call: the first dimension of both
        `encodings` and `mask` must be 1 (asserted below). The example is
        replicated `beam_size` times so that every live hypothesis owns one row.

        Args:
            encodings: encoder output for one example. It is split along dim 1
                into question / table / column segments of lengths
                `batch.max_question_len`, `batch.max_table_len` and
                `batch.max_column_len`.
                NOTE(review): presumably shaped (1, total_len, hidden) -- confirm.
            mask: mask matching `encodings`, split along dim 1 the same way.
                The number of selectable tables / columns of a hypothesis is
                taken as the sum of its table / column mask row.
            h0: initial decoder hidden state. Used as the query for the initial
                attention and repeated `args.lstm_num_layers` times to seed the
                LSTM; the initial cell state is all zeros.
            batch: object providing `max_question_len`, `max_table_len` and
                `max_column_len`.
            beam_size: maximum number of candidates kept at each decoding step.

        Returns:
            A list of completed hypotheses sorted by descending `score`. If no
            hypothesis completed, a list holding one empty `DecodeHypothesis()`.

        Raises:
            ValueError: if a previous action or a continuation action type is
                not one of ApplyRule / Reduce / SelectColumn / SelectTable.
        """
        args = self.args
        # NOTE(review): `zero_action_embed` is never read again in this method.
        zero_action_embed = encodings.new_zeros(args.action_embed_size)
        assert encodings.size(0) == 1 and mask.size(0) == 1
        # One row per potential beam entry; rows are sliced with [:hyp_num] below.
        encodings, mask = encodings.repeat(beam_size, 1,
                                           1), mask.repeat(beam_size, 1)
        split_len = [
            batch.max_question_len, batch.max_table_len, batch.max_column_len
        ]
        q, tab, col = encodings.split(split_len, dim=1)
        q_mask, tab_mask, col_mask = mask.split(split_len, dim=1)
        if args.sep_cxt:
            # Separate contexts: the main attention sees only the question,
            # while tables + columns get their own schema attention.
            encodings, mask = q, q_mask
            schema, schema_mask = torch.cat([tab, col], dim=1), torch.cat(
                [tab_mask, col_mask], dim=1)
            schema_context, _ = self.schema_attn(schema[:1], h0,
                                                 schema_mask[:1])
        # Initial attention context, queried with h0 (only one hypothesis yet).
        context, _ = self.context_attn(encodings[:1], h0, mask[:1])
        h0 = h0.unsqueeze(0).repeat(args.lstm_num_layers, 1, 1)
        # h0 = h0.new_zeros(h0.size()) # init decoder with 0-vector
        h_c = (h0, h0.new_zeros(h0.size()))
        hypotheses = [DecodeHypothesis()]
        # hyp_states[i] collects, step by step, the last-layer hidden state of
        # hypothesis i; it is indexed by the parent node's creation time below.
        hyp_states, completed_hypotheses, t = [[]], [], 0
        while t < args.decode_max_step:
            hyp_num = len(hypotheses)
            # Restrict the replicated encoder tensors to the live hypotheses.
            cur_encodings, cur_mask = encodings[:hyp_num], mask[:hyp_num]
            if args.sep_cxt:
                cur_schema, cur_schema_mask = schema[:
                                                     hyp_num], schema_mask[:
                                                                           hyp_num]
            cur_tab, cur_tab_mask, cur_col, cur_col_mask = tab[:
                                                               hyp_num], tab_mask[:
                                                                                  hyp_num], col[:
                                                                                                hyp_num], col_mask[:
                                                                                                                   hyp_num]
            # Decoder input layout, in concatenation order:
            # x: [prev_action_embed, context (unless no_context_feeding),
            #     schema_context (only with sep_cxt), parent_production_embed,
            #     parent_field_embed, parent_field_type_embed, parent_state]
            if t == 0:
                # First step: start from an all-zero input and fill only the
                # context slot(s) and the root-type embedding. The previous
                # action, parent production, parent field and parent state
                # slots stay zero.
                x = encodings.new_zeros(hyp_num, self.decoder_lstm.input_size)
                offset = args.action_embed_size
                if not args.no_context_feeding:
                    x[:, offset:offset + args.gnn_hidden_size] = context
                    offset += args.gnn_hidden_size
                    if args.sep_cxt:
                        x[:, offset:offset +
                          args.gnn_hidden_size] = schema_context
                        offset += args.gnn_hidden_size
                # Skip over the (zero) parent production / field slots when
                # those features are enabled.
                offset += args.action_embed_size * int(
                    not args.no_parent_production_embed)
                offset += args.field_embed_size * int(
                    not args.no_parent_field_embed)
                if not args.no_parent_field_type_embed:
                    start_rule = torch.tensor(
                        [self.grammar.type2id[self.grammar.root_type]] *
                        hyp_num,
                        dtype=torch.long,
                        device=x.device)
                    x[:, offset:offset +
                      args.type_embed_size] = self.type_embed(start_rule)
            else:
                # Later steps: embed the last action of every hypothesis.
                prev_action_embed = []
                for e_id, hyp in enumerate(hypotheses):
                    prev_action = hyp.actions[-1]
                    if isinstance(prev_action, ApplyRuleAction):
                        prev_action_embed.append(self.production_embed.weight[
                            self.grammar.prod2id[prev_action.production]])
                    elif isinstance(prev_action, ReduceAction):
                        # Reduce uses the extra embedding row at index
                        # len(self.grammar).
                        prev_action_embed.append(
                            self.production_embed.weight[len(self.grammar)])
                    elif isinstance(
                            prev_action, SelectColumnAction
                    ):  # need to first map schema item into prod embed space
                        col_embed = self.column_lstm_input(
                            cur_col[e_id, prev_action.column_id])
                        prev_action_embed.append(col_embed)
                    elif isinstance(prev_action, SelectTableAction):
                        # Same projection idea as for columns, using the
                        # selected table's encoding.
                        tab_embed = self.table_lstm_input(
                            cur_tab[e_id, prev_action.table_id])
                        prev_action_embed.append(tab_embed)
                    else:
                        raise ValueError(
                            'Unrecognized previous action object!')
                inputs = [torch.stack(prev_action_embed)]
                if not args.no_context_feeding:
                    # Contexts computed at the previous step, already
                    # re-indexed to the surviving hypotheses.
                    inputs.append(context)
                    if args.sep_cxt:
                        inputs.append(schema_context)
                if not args.no_parent_production_embed:
                    frontier_prods = [
                        hyp.frontier_node.production for hyp in hypotheses
                    ]
                    frontier_prod_embeds = self.production_embed(
                        torch.tensor([
                            self.grammar.prod2id[prod]
                            for prod in frontier_prods
                        ],
                                     dtype=torch.long,
                                     device=encodings.device))
                    inputs.append(frontier_prod_embeds)
                if not args.no_parent_field_embed:
                    frontier_fields = [
                        hyp.frontier_field.field for hyp in hypotheses
                    ]
                    frontier_field_embeds = self.field_embed(
                        torch.tensor([
                            self.grammar.field2id[field]
                            for field in frontier_fields
                        ],
                                     dtype=torch.long,
                                     device=encodings.device))
                    inputs.append(frontier_field_embeds)
                if not args.no_parent_field_type_embed:
                    frontier_field_types = [
                        hyp.frontier_field.type for hyp in hypotheses
                    ]
                    frontier_field_type_embeds = self.type_embed(
                        torch.tensor([
                            self.grammar.type2id[tp]
                            for tp in frontier_field_types
                        ],
                                     dtype=torch.long,
                                     device=encodings.device))
                    inputs.append(frontier_field_type_embeds)
                if not args.no_parent_state:
                    # Parent feeding: look up the decoder state stored for
                    # each hypothesis at its frontier node's creation time.
                    # NOTE(review): presumably `created_time` is the step
                    # index assigned through `action_info.t` -- confirm.
                    p_ts = [
                        hyp.frontier_node.created_time for hyp in hypotheses
                    ]
                    parent_states = torch.stack([
                        hyp_states[hyp_id][p_t]
                        for hyp_id, p_t in enumerate(p_ts)
                    ])
                    inputs.append(parent_states)
                x = torch.cat(inputs, dim=-1)  # hyp_num x input_size

            # advance decoder lstm and attention calculation
            # (a length-1 sequence dimension is added for the LSTM call and
            # squeezed away again afterwards)
            out, (h_t, c_t) = self.decoder_lstm(x.unsqueeze(1), h_c)
            out = out.squeeze(1)
            context, _ = self.context_attn(cur_encodings, out, cur_mask)
            if args.sep_cxt:
                schema_context, _ = self.schema_attn(cur_schema, out,
                                                     cur_schema_mask)
                att_vec = self.att_vec_linear(
                    torch.cat([out, context, schema_context],
                              dim=-1))  # hyp_num x args.att_vec_size
            else:
                att_vec = self.att_vec_linear(torch.cat([out, context],
                                                        dim=-1))

            # action logprobs
            # One row per live hypothesis; column len(self.grammar) is read
            # as the Reduce score further below.
            apply_rule_logprob = F.log_softmax(
                self.apply_rule(self.apply_rule_affine(att_vec)),
                dim=-1)
            # The second return value of select_table / select_column is
            # treated as a probability; 1e-32 keeps log() away from log(0).
            _, select_tab_prob = self.select_table(cur_tab, att_vec,
                                                   cur_tab_mask)
            select_tab_logprob = torch.log(select_tab_prob + 1e-32)
            _, select_col_prob = self.select_column(cur_col, att_vec,
                                                    cur_col_mask)
            select_col_logprob = torch.log(select_col_prob + 1e-32)

            # Enumerate every valid continuation of every live hypothesis.
            # Each entry records the action, its own log-prob ('score'), the
            # cumulative score ('new_hyp_score') and the source hypothesis.
            new_hyp_meta = []
            for hyp_id, hyp in enumerate(hypotheses):
                action_types = self.transition_system.get_valid_continuation_types(
                    hyp)
                for action_type in action_types:
                    if action_type == ApplyRuleAction:
                        productions = self.transition_system.get_valid_continuating_productions(
                            hyp)
                        for production in productions:
                            prod_id = self.grammar.prod2id[production]
                            prod_score = apply_rule_logprob[hyp_id, prod_id]
                            new_hyp_score = hyp.score + prod_score
                            meta_entry = {
                                'action_type': 'apply_rule',
                                'prod_id': prod_id,
                                'score': prod_score,
                                'new_hyp_score': new_hyp_score,
                                'prev_hyp_id': hyp_id
                            }
                            new_hyp_meta.append(meta_entry)
                    elif action_type == ReduceAction:
                        # Reduce is stored as an 'apply_rule' entry whose
                        # prod_id equals len(self.grammar); it is told apart
                        # from real productions when the action is rebuilt.
                        action_score = apply_rule_logprob[hyp_id,
                                                          len(self.grammar)]
                        new_hyp_score = hyp.score + action_score
                        meta_entry = {
                            'action_type': 'apply_rule',
                            'prod_id': len(self.grammar),
                            'score': action_score,
                            'new_hyp_score': new_hyp_score,
                            'prev_hyp_id': hyp_id
                        }
                        new_hyp_meta.append(meta_entry)
                    elif action_type == SelectColumnAction:
                        # Candidate columns are ids 0 .. (mask sum - 1).
                        # NOTE(review): assumes valid columns occupy the
                        # leading positions of the mask -- confirm.
                        col_num = cur_col_mask[hyp_id].int().sum().item()
                        for col_id in range(col_num):
                            col_sel_score = select_col_logprob[hyp_id, col_id]
                            new_hyp_score = hyp.score + col_sel_score
                            meta_entry = {
                                'action_type': 'sel_col',
                                'col_id': col_id,
                                'score': col_sel_score,
                                'new_hyp_score': new_hyp_score,
                                'prev_hyp_id': hyp_id
                            }
                            new_hyp_meta.append(meta_entry)
                    elif action_type == SelectTableAction:
                        # Candidate tables are ids 0 .. (mask sum - 1), same
                        # leading-positions assumption as for columns.
                        tab_num = cur_tab_mask[hyp_id].int().sum().item()
                        for tab_id in range(tab_num):
                            tab_sel_score = select_tab_logprob[hyp_id, tab_id]
                            new_hyp_score = hyp.score + tab_sel_score
                            meta_entry = {
                                'action_type': 'sel_tab',
                                'tab_id': tab_id,
                                'score': tab_sel_score,
                                'new_hyp_score': new_hyp_score,
                                'prev_hyp_id': hyp_id
                            }
                            new_hyp_meta.append(meta_entry)
                    else:
                        raise ValueError(
                            'Unrecognized action type while decoding!')
            # No valid continuation for any hypothesis: stop decoding.
            if not new_hyp_meta: break

            # rank and pick
            # Keep at most beam_size candidates with the best cumulative score.
            new_hyp_scores = torch.stack(
                [x['new_hyp_score'] for x in new_hyp_meta])
            top_new_hyp_scores, meta_ids = new_hyp_scores.topk(
                min(beam_size, new_hyp_scores.size(0)))

            # live_hyp_ids[j] is the index (into the current `hypotheses`) of
            # the parent of new_hypotheses[j]; it is used to gather state below.
            live_hyp_ids, new_hypotheses = [], []
            for new_hyp_score, meta_id in zip(top_new_hyp_scores.tolist(),
                                              meta_ids.tolist()):
                action_info = ActionInfo()
                hyp_meta_entry = new_hyp_meta[meta_id]
                prev_hyp_id = hyp_meta_entry['prev_hyp_id']
                prev_hyp = hypotheses[prev_hyp_id]
                action_type_str = hyp_meta_entry['action_type']
                # Rebuild the concrete action object from the meta entry.
                if action_type_str == 'apply_rule':
                    prod_id = hyp_meta_entry['prod_id']
                    if prod_id < len(self.grammar):
                        production = self.grammar.id2prod[prod_id]
                        action = ApplyRuleAction(production)
                    else:
                        action = ReduceAction()
                elif action_type_str == 'sel_col':
                    action = SelectColumnAction(hyp_meta_entry['col_id'])
                elif action_type_str == 'sel_tab':
                    action = SelectTableAction(hyp_meta_entry['tab_id'])
                else:
                    raise ValueError('Unrecognized action type str!')

                action_info.action = action
                action_info.t = t
                # Frontier information is only recorded after the first step.
                if t > 0:
                    action_info.parent_t = prev_hyp.frontier_node.created_time
                    action_info.frontier_prod = prev_hyp.frontier_node.production
                    action_info.frontier_field = prev_hyp.frontier_field.field

                new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                # Plain Python float (from .tolist()), not a tensor.
                new_hyp.score = new_hyp_score

                if new_hyp.completed:
                    completed_hypotheses.append(new_hyp)
                else:
                    new_hypotheses.append(new_hyp)
                    live_hyp_ids.append(prev_hyp_id)

            if live_hyp_ids:
                # Carry per-hypothesis state over to the surviving
                # hypotheses, re-indexed by their parent's position.
                if not args.no_parent_state:
                    # Extend each parent's state history with this step's
                    # last-layer hidden state.
                    hyp_states = [
                        hyp_states[i] + [h_t[-1, i]] for i in live_hyp_ids
                    ]
                h_c = (h_t[:, live_hyp_ids], c_t[:, live_hyp_ids])
                if not args.no_context_feeding:
                    context = context[live_hyp_ids]
                    if args.sep_cxt:
                        schema_context = schema_context[live_hyp_ids]
                hypotheses = new_hypotheses
                t += 1
            else:
                # Every selected candidate completed: nothing left to expand.
                break

            # Debugging aid (disabled): dump each new hypothesis per step.
            # for idx, hyp in enumerate(new_hypotheses):
            # print('Idx %d' % (idx))
            # print('Tree:', hyp.tree.to_string())
            # print('Action:', hyp.action_infos[-1])
            # print('============================================')
            # input('********************Next parse step ....************************')

        if len(completed_hypotheses) == 0:  # no completed sql
            # Fall back to a single empty hypothesis so the result is never
            # an empty list.
            completed_hypotheses.append(DecodeHypothesis())
        else:
            # Best score first; length normalisation by hyp.tree.size is
            # left disabled.
            completed_hypotheses.sort(
                key=lambda hyp: -hyp.score)  # / hyp.tree.size
        return completed_hypotheses
    def parse(self, src_sent, context=None, beam_size=5, debug=False):
        """Perform beam search to infer the target AST given a source utterance

        Args:
            src_sent: list of source utterance tokens
            context: other context used for prediction
            beam_size: beam size

        Returns:
            A list of `DecodeHypothesis`, each representing an AST
        """

        with torch.no_grad():
            args = self.args
            primitive_vocab = self.vocab.primitive

            src_sent_var = nn_utils.to_input_variable([src_sent],
                                                      self.vocab.source,
                                                      cuda=args.cuda,
                                                      training=False)

            # Variable(1, src_sent_len, hidden_size)
            src_encodings = self.encode(src_sent_var)

            zero_action_embed = torch.zeros(args.action_embed_size)

            hyp_scores = torch.tensor([0.0])

            # For computing copy probabilities, we marginalize over tokens with the same surface form
            # `aggregated_primitive_tokens` stores the position of occurrence of each source token
            aggregated_primitive_tokens = OrderedDict()
            for token_pos, token in enumerate(src_sent):
                aggregated_primitive_tokens.setdefault(token,
                                                       []).append(token_pos)

            t = 0
            hypotheses = [DecodeHypothesis()]
            completed_hypotheses = []

            while len(completed_hypotheses
                      ) < beam_size and t < args.decode_max_time_step:
                hyp_num = len(hypotheses)

                # (hyp_num, src_sent_len, hidden_size)
                exp_src_encodings = src_encodings.expand(
                    hyp_num, src_encodings.size(1), src_encodings.size(2))

                if t == 0:
                    x = torch.zeros(1, self.d_model)
                    parent_ids = np.array([[0]])
                    if args.no_parent_field_type_embed is False:
                        offset = self.args.action_embed_size  # prev_action
                        offset += self.args.action_embed_size * (
                            not self.args.no_parent_production_embed)
                        offset += self.args.field_embed_size * (
                            not self.args.no_parent_field_embed)

                        x[0, offset:offset +
                          self.type_embed_size] = self.type_embed(
                              torch.tensor(self.grammar.type2id[
                                  self.grammar.root_type]))
                        x = x.unsqueeze(-2)
                else:
                    actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

                    a_tm1_embeds = []
                    for a_tm1 in actions_tm1:
                        if a_tm1:
                            if isinstance(a_tm1, ApplyRuleAction):
                                a_tm1_embed = self.production_embed(
                                    torch.tensor(self.grammar.prod2id[
                                        a_tm1.production]))
                            elif isinstance(a_tm1, ReduceAction):
                                a_tm1_embed = self.production_embed(
                                    torch.tensor(len(self.grammar)))
                            else:
                                a_tm1_embed = self.primitive_embed(
                                    torch.tensor(
                                        self.vocab.primitive[a_tm1.token]))

                            a_tm1_embeds.append(a_tm1_embed)
                        else:
                            a_tm1_embeds.append(zero_action_embed)
                    a_tm1_embeds = torch.stack(a_tm1_embeds)

                    inputs = [a_tm1_embeds]
                    if args.no_parent_production_embed is False:
                        # frontier production
                        frontier_prods = [
                            hyp.frontier_node.production for hyp in hypotheses
                        ]
                        frontier_prod_embeds = self.production_embed(
                            torch.tensor([
                                self.grammar.prod2id[prod]
                                for prod in frontier_prods
                            ],
                                         dtype=torch.long))
                        inputs.append(frontier_prod_embeds)
                    if args.no_parent_field_embed is False:
                        # frontier field
                        frontier_fields = [
                            hyp.frontier_field.field for hyp in hypotheses
                        ]
                        frontier_field_embeds = self.field_embed(
                            torch.tensor([
                                self.grammar.field2id[field]
                                for field in frontier_fields
                            ],
                                         dtype=torch.long))

                        inputs.append(frontier_field_embeds)
                    if args.no_parent_field_type_embed is False:
                        # frontier field type
                        frontier_field_types = [
                            hyp.frontier_field.type for hyp in hypotheses
                        ]
                        frontier_field_type_embeds = self.type_embed(
                            torch.tensor([
                                self.grammar.type2id[type]
                                for type in frontier_field_types
                            ],
                                         dtype=torch.long))
                        inputs.append(frontier_field_type_embeds)

                    x = torch.cat(
                        [x, torch.cat(inputs, dim=-1).unsqueeze(-2)], dim=1)
                    recent_parents = np.array(
                        [[hyp.frontier_node.created_time]
                         if hyp.frontier_node else 0 for hyp in hypotheses])
                    parent_ids = np.hstack([parent_ids, recent_parents])

                src_mask = torch.ones(
                    exp_src_encodings.shape[:-1],
                    dtype=torch.uint8,
                    device=exp_src_encodings.device).unsqueeze(-2)
                tgt_mask = subsequent_mask(x.shape[-2])

                att_t = self.decoder(x, exp_src_encodings, [parent_ids],
                                     src_mask, tgt_mask)[:, -1]

                # Variable(batch_size, grammar_size)
                # apply_rule_log_prob = torch.log(F.softmax(self.production_readout(att_t), dim=-1))
                apply_rule_log_prob = F.log_softmax(
                    self.production_readout(att_t), dim=-1)

                # Variable(batch_size, primitive_vocab_size)
                gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t),
                                                dim=-1)

                if args.no_copy:
                    primitive_prob = gen_from_vocab_prob
                else:
                    # Variable(batch_size, src_sent_len)
                    primitive_copy_prob = self.src_pointer_net(
                        src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

                    # Variable(batch_size, 2)
                    primitive_predictor_prob = F.softmax(
                        self.primitive_predictor(att_t), dim=-1)

                    # Variable(batch_size, primitive_vocab_size)
                    primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(
                        1) * gen_from_vocab_prob

                    # if src_unk_pos_list:
                    #     primitive_prob[:, primitive_vocab.unk_id] = 1.e-10

                gentoken_prev_hyp_ids = []
                gentoken_new_hyp_unks = []
                applyrule_new_hyp_scores = []
                applyrule_new_hyp_prod_ids = []
                applyrule_prev_hyp_ids = []

                for hyp_id, hyp in enumerate(hypotheses):
                    # generate new continuations
                    action_types = self.transition_system.get_valid_continuation_types(
                        hyp)

                    for action_type in action_types:
                        if action_type == ApplyRuleAction:
                            productions = self.transition_system.get_valid_continuating_productions(
                                hyp)
                            for production in productions:
                                prod_id = self.grammar.prod2id[production]
                                prod_score = apply_rule_log_prob[
                                    hyp_id, prod_id].item()
                                new_hyp_score = hyp.score + prod_score

                                applyrule_new_hyp_scores.append(new_hyp_score)
                                applyrule_new_hyp_prod_ids.append(prod_id)
                                applyrule_prev_hyp_ids.append(hyp_id)
                        elif action_type == ReduceAction:
                            action_score = apply_rule_log_prob[
                                hyp_id, len(self.grammar)].item()
                            new_hyp_score = hyp.score + action_score

                            applyrule_new_hyp_scores.append(new_hyp_score)
                            applyrule_new_hyp_prod_ids.append(len(
                                self.grammar))
                            applyrule_prev_hyp_ids.append(hyp_id)
                        else:
                            # GenToken action
                            gentoken_prev_hyp_ids.append(hyp_id)
                            hyp_copy_info = dict()  # of (token_pos, copy_prob)
                            hyp_unk_copy_info = []

                            if args.no_copy is False:
                                for (token, token_pos_list
                                     ) in aggregated_primitive_tokens.items():
                                    sum_copy_prob = torch.gather(
                                        primitive_copy_prob[hyp_id], 0,
                                        torch.tensor(token_pos_list,
                                                     dtype=torch.long)).sum()
                                    gated_copy_prob = primitive_predictor_prob[
                                        hyp_id, 1] * sum_copy_prob

                                    if token in primitive_vocab:
                                        token_id = primitive_vocab[token]
                                        primitive_prob[hyp_id, token_id] = (
                                            primitive_prob[hyp_id, token_id] +
                                            gated_copy_prob)

                                        hyp_copy_info[token] = (
                                            token_pos_list,
                                            gated_copy_prob.item())
                                    else:
                                        hyp_unk_copy_info.append({
                                            "token":
                                            token,
                                            "token_pos_list":
                                            token_pos_list,
                                            "copy_prob":
                                            gated_copy_prob.item(),
                                        })

                            if args.no_copy is False and len(
                                    hyp_unk_copy_info) > 0:
                                unk_i = np.array([
                                    x["copy_prob"] for x in hyp_unk_copy_info
                                ]).argmax()
                                token = hyp_unk_copy_info[unk_i]["token"]
                                primitive_prob[hyp_id, primitive_vocab.
                                               unk_id] = hyp_unk_copy_info[
                                                   unk_i]["copy_prob"]
                                gentoken_new_hyp_unks.append(token)

                                hyp_copy_info[token] = (
                                    hyp_unk_copy_info[unk_i]["token_pos_list"],
                                    hyp_unk_copy_info[unk_i]["copy_prob"],
                                )

                new_hyp_scores = None
                if applyrule_new_hyp_scores:
                    new_hyp_scores = torch.tensor(applyrule_new_hyp_scores)
                if gentoken_prev_hyp_ids:
                    primitive_log_prob = torch.log(primitive_prob)
                    gen_token_new_hyp_scores = (
                        hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) +
                        primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

                    if new_hyp_scores is None:
                        new_hyp_scores = gen_token_new_hyp_scores
                    else:
                        new_hyp_scores = torch.cat(
                            [new_hyp_scores, gen_token_new_hyp_scores])

                top_new_hyp_scores, top_new_hyp_pos = torch.topk(
                    new_hyp_scores,
                    k=min(new_hyp_scores.size(0),
                          beam_size - len(completed_hypotheses)))

                live_hyp_ids = []
                new_hypotheses = []
                for new_hyp_score, new_hyp_pos in zip(
                        top_new_hyp_scores.data.cpu(),
                        top_new_hyp_pos.data.cpu()):
                    action_info = ActionInfo()
                    if new_hyp_pos < len(applyrule_new_hyp_scores):
                        # it's an ApplyRule or Reduce action
                        prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                        prev_hyp = hypotheses[prev_hyp_id]

                        prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                        # ApplyRule action
                        if prod_id < len(self.grammar):
                            production = self.grammar.id2prod[prod_id]
                            action = ApplyRuleAction(production)
                        # Reduce action
                        else:
                            action = ReduceAction()
                    else:
                        # it's a GenToken action
                        token_id = int(
                            (new_hyp_pos - len(applyrule_new_hyp_scores)) %
                            primitive_prob.size(1))

                        k = (new_hyp_pos - len(applyrule_new_hyp_scores)
                             ) // primitive_prob.size(1)
                        # try:
                        # copy_info = gentoken_copy_infos[k]
                        prev_hyp_id = gentoken_prev_hyp_ids[k]
                        prev_hyp = hypotheses[prev_hyp_id]
                        # except:
                        #     print('k=%d' % k, file=sys.stderr)
                        #     print('primitive_prob.size(1)=%d' % primitive_prob.size(1), file=sys.stderr)
                        #     print('len copy_info=%d' % len(gentoken_copy_infos), file=sys.stderr)
                        #     print('prev_hyp_id=%s' % ', '.join(str(i) for i in gentoken_prev_hyp_ids), file=sys.stderr)
                        #     print('len applyrule_new_hyp_scores=%d' % len(applyrule_new_hyp_scores), file=sys.stderr)
                        #     print('len gentoken_prev_hyp_ids=%d' % len(gentoken_prev_hyp_ids), file=sys.stderr)
                        #     print('top_new_hyp_pos=%s' % top_new_hyp_pos, file=sys.stderr)
                        #     print('applyrule_new_hyp_scores=%s' % applyrule_new_hyp_scores, file=sys.stderr)
                        #     print('new_hyp_scores=%s' % new_hyp_scores, file=sys.stderr)
                        #     print('top_new_hyp_scores=%s' % top_new_hyp_scores, file=sys.stderr)
                        #
                        #     torch.save((applyrule_new_hyp_scores, primitive_prob), 'data.bin')
                        #
                        #     # exit(-1)
                        #     raise ValueError()

                        if token_id == int(primitive_vocab.unk_id):
                            if gentoken_new_hyp_unks:
                                token = gentoken_new_hyp_unks[k]
                            else:
                                token = primitive_vocab.id2word[
                                    primitive_vocab.unk_id]
                        else:
                            token = primitive_vocab.id2word[token_id]

                        action = GenTokenAction(token)

                        if token in aggregated_primitive_tokens:
                            action_info.copy_from_src = True
                            action_info.src_token_position = aggregated_primitive_tokens[
                                token]

                        if debug:
                            action_info.gen_copy_switch = (
                                "n/a"
                                if args.no_copy else primitive_predictor_prob[
                                    prev_hyp_id, :].log().cpu().data.numpy())
                            action_info.in_vocab = token in primitive_vocab
                            action_info.gen_token_prob = (
                                gen_from_vocab_prob[prev_hyp_id,
                                                    token_id].log().cpu().
                                item() if token in primitive_vocab else "n/a")
                            action_info.copy_token_prob = (
                                torch.gather(
                                    primitive_copy_prob[prev_hyp_id],
                                    0,
                                    torch.tensor(
                                        action_info.src_token_position,
                                        dtype=torch.long,
                                        device=self.device),
                                ).sum().log().cpu().item()
                                if args.no_copy is False
                                and action_info.copy_from_src else "n/a")

                    action_info.action = action
                    action_info.t = t
                    if t > 0:
                        action_info.parent_t = prev_hyp.frontier_node.created_time
                        action_info.frontier_prod = prev_hyp.frontier_node.production
                        action_info.frontier_field = prev_hyp.frontier_field.field

                    if debug:
                        action_info.action_prob = new_hyp_score - prev_hyp.score

                    new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                    new_hyp.score = new_hyp_score

                    if new_hyp.completed:
                        completed_hypotheses.append(new_hyp)
                    else:
                        new_hypotheses.append(new_hyp)
                        live_hyp_ids.append(prev_hyp_id)

                if live_hyp_ids:
                    x = x[live_hyp_ids]
                    parent_ids = parent_ids[live_hyp_ids]
                    hypotheses = new_hypotheses
                    hyp_scores = torch.tensor(
                        [hyp.score for hyp in hypotheses])
                    t += 1
                else:
                    break

            completed_hypotheses.sort(key=lambda hyp: -hyp.score)

            return completed_hypotheses