Пример #1
0
    def parse(self, src_sent, context=None, beam_size=5, debug=False):
        """Perform beam search to infer the target AST given a source utterance.

        The whole search runs under ``torch.no_grad()``. The scores stored on
        the returned hypotheses are detached from the graph anyway, so tracking
        gradients here would only retain every step's decoder state (kept in
        ``hyp_states``) together with its autograd graph.

        Args:
            src_sent: list of source utterance tokens
            context: other context used for prediction (not read by this
                method; kept for interface compatibility)
            beam_size: beam size; the search also stops once this many
                hypotheses have been completed
            debug: if True, record extra diagnostics (gen/copy switch,
                per-token log-probabilities, action score delta) on every
                ``ActionInfo``

        Returns:
            A list of `DecodeHypothesis`, each representing an AST, sorted by
            descending length-normalized score. The list may hold fewer than
            ``beam_size`` entries (possibly none) when the search hits
            ``args.decode_max_time_step`` or runs out of live hypotheses.
        """

        args = self.args
        primitive_vocab = self.vocab.primitive
        T = torch.cuda if args.cuda else torch
        # The Reduce action is scored/embedded at the index right after the
        # last grammar production.
        reduce_prod_id = len(self.grammar)

        # For computing copy probabilities, we marginalize over tokens with the same surface form
        # `aggregated_primitive_tokens` stores the position of occurrence of each source token
        aggregated_primitive_tokens = OrderedDict()
        for token_pos, token in enumerate(src_sent):
            aggregated_primitive_tokens.setdefault(token, []).append(token_pos)

        completed_hypotheses = []

        with torch.no_grad():
            # NOTE: tensor shape comments below are carried over from the
            # original implementation.
            src_sent_var = nn_utils.to_input_variable([src_sent],
                                                      self.vocab.source,
                                                      cuda=args.cuda,
                                                      training=False)

            # (1, src_sent_len, hidden_size * 2)
            src_encodings, last_state = self.encode(src_sent_var,
                                                    [len(src_sent)])
            # (1, src_sent_len, hidden_size)
            src_encodings_att_linear = self.att_src_linear(src_encodings)

            dec_init_vec = self.init_decoder_state(last_state)
            if args.lstm == 'parent_feed':
                # The parent-feeding cell additionally takes the parent's
                # state and cell; there is no parent at the first step.
                h_tm1 = (dec_init_vec[0], dec_init_vec[1],
                         Variable(self.new_tensor(args.hidden_size).zero_()),
                         Variable(self.new_tensor(args.hidden_size).zero_()))
            else:
                h_tm1 = dec_init_vec

            zero_action_embed = Variable(
                self.new_tensor(args.action_embed_size).zero_())
            hyp_scores = Variable(self.new_tensor([0.]))

            # The gather index of each distinct source token is the same for
            # every hypothesis and time step, so build it only once.
            copy_candidates = []
            if args.no_copy is False:
                copy_candidates = [
                    (token, Variable(T.LongTensor(token_pos_list)))
                    for token, token_pos_list in
                    aggregated_primitive_tokens.items()
                ]

            t = 0
            hypotheses = [DecodeHypothesis()]
            # hyp_states[i][s] is the decoder (state, cell) of hypothesis i at
            # step s; used to look up the state of the frontier node's parent.
            hyp_states = [[]]
            att_tm1 = None  # attentional vector of the previous step

            while (len(completed_hypotheses) < beam_size
                   and t < args.decode_max_time_step):
                hyp_num = len(hypotheses)

                # (hyp_num, src_sent_len, hidden_size * 2)
                exp_src_encodings = src_encodings.expand(
                    hyp_num, src_encodings.size(1), src_encodings.size(2))
                # (hyp_num, src_sent_len, hidden_size)
                exp_src_encodings_att_linear = src_encodings_att_linear.expand(
                    hyp_num, src_encodings_att_linear.size(1),
                    src_encodings_att_linear.size(2))

                if t == 0:
                    # First step: all-zero input, except for the embedding of
                    # the grammar's root type in the field-type slot.
                    x = Variable(
                        self.new_tensor(1,
                                        self.decoder_lstm.input_size).zero_())
                    if args.no_parent_field_type_embed is False:
                        # Skip the slots that precede the type embedding.
                        offset = args.action_embed_size  # prev_action
                        offset += args.att_vec_size * (not args.no_input_feed)
                        offset += args.action_embed_size * (
                            not args.no_parent_production_embed)
                        offset += args.field_embed_size * (
                            not args.no_parent_field_embed)

                        x[0, offset: offset + args.type_embed_size] = \
                            self.type_embed.weight[
                                self.grammar.type2id[self.grammar.root_type]]
                else:
                    # Embedding of the previous action of every hypothesis.
                    a_tm1_embeds = []
                    for hyp in hypotheses:
                        a_tm1 = hyp.actions[-1]
                        if a_tm1:
                            if isinstance(a_tm1, ApplyRuleAction):
                                a_tm1_embed = self.production_embed.weight[
                                    self.grammar.prod2id[a_tm1.production]]
                            elif isinstance(a_tm1, ReduceAction):
                                a_tm1_embed = self.production_embed.weight[
                                    reduce_prod_id]
                            else:
                                a_tm1_embed = self.primitive_embed.weight[
                                    self.vocab.primitive[a_tm1.token]]
                        else:
                            a_tm1_embed = zero_action_embed
                        a_tm1_embeds.append(a_tm1_embed)
                    a_tm1_embeds = torch.stack(a_tm1_embeds)

                    inputs = [a_tm1_embeds]
                    if args.no_input_feed is False:
                        inputs.append(att_tm1)
                    if args.no_parent_production_embed is False:
                        # frontier production
                        frontier_prod_embeds = self.production_embed(
                            Variable(
                                self.new_long_tensor([
                                    self.grammar.prod2id[
                                        hyp.frontier_node.production]
                                    for hyp in hypotheses
                                ])))
                        inputs.append(frontier_prod_embeds)
                    if args.no_parent_field_embed is False:
                        # frontier field
                        frontier_field_embeds = self.field_embed(
                            Variable(
                                self.new_long_tensor([
                                    self.grammar.field2id[
                                        hyp.frontier_field.field]
                                    for hyp in hypotheses
                                ])))
                        inputs.append(frontier_field_embeds)
                    if args.no_parent_field_type_embed is False:
                        # frontier field type
                        frontier_field_type_embeds = self.type_embed(
                            Variable(
                                self.new_long_tensor([
                                    self.grammar.type2id[
                                        hyp.frontier_field.type]
                                    for hyp in hypotheses
                                ])))
                        inputs.append(frontier_field_type_embeds)

                    # parent states
                    if args.no_parent_state is False:
                        p_ts = [
                            hyp.frontier_node.created_time
                            for hyp in hypotheses
                        ]
                        parent_states = torch.stack([
                            hyp_states[hyp_id][p_t][0]
                            for hyp_id, p_t in enumerate(p_ts)
                        ])
                        parent_cells = torch.stack([
                            hyp_states[hyp_id][p_t][1]
                            for hyp_id, p_t in enumerate(p_ts)
                        ])

                        if args.lstm == 'parent_feed':
                            h_tm1 = (h_tm1[0], h_tm1[1], parent_states,
                                     parent_cells)
                        else:
                            inputs.append(parent_states)

                    x = torch.cat(inputs, dim=-1)

                (h_t, cell_t), att_t = self.step(x,
                                                 h_tm1,
                                                 exp_src_encodings,
                                                 exp_src_encodings_att_linear,
                                                 src_token_mask=None)

                # Variable(batch_size, grammar_size)
                apply_rule_log_prob = F.log_softmax(
                    self.production_readout(att_t), dim=-1)

                # Variable(batch_size, primitive_vocab_size)
                gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t),
                                                dim=-1)

                if args.no_copy:
                    primitive_prob = gen_from_vocab_prob
                else:
                    # Variable(batch_size, src_sent_len)
                    primitive_copy_prob = self.src_pointer_net(
                        src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

                    # Variable(batch_size, 2)
                    primitive_predictor_prob = F.softmax(
                        self.primitive_predictor(att_t), dim=-1)

                    # Variable(batch_size, primitive_vocab_size)
                    # Generation part only; the copy part is added per
                    # hypothesis below.
                    primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(
                        1) * gen_from_vocab_prob

                # Candidate continuations. ApplyRule/Reduce candidates are
                # listed one by one; GenToken candidates are kept as one full
                # row of `primitive_prob` per generating hypothesis.
                gentoken_prev_hyp_ids = []
                gentoken_new_hyp_unks = []
                applyrule_new_hyp_scores = []
                applyrule_new_hyp_prod_ids = []
                applyrule_prev_hyp_ids = []

                for hyp_id, hyp in enumerate(hypotheses):
                    # generate new continuations
                    action_types = \
                        self.transition_system.get_valid_continuation_types(
                            hyp)

                    for action_type in action_types:
                        if action_type == ApplyRuleAction:
                            productions = self.transition_system \
                                .get_valid_continuating_productions(hyp)
                            for production in productions:
                                prod_id = self.grammar.prod2id[production]
                                prod_score = apply_rule_log_prob[
                                    hyp_id, prod_id].item()

                                applyrule_new_hyp_scores.append(
                                    hyp.score + prod_score)
                                applyrule_new_hyp_prod_ids.append(prod_id)
                                applyrule_prev_hyp_ids.append(hyp_id)
                        elif action_type == ReduceAction:
                            action_score = apply_rule_log_prob[
                                hyp_id, reduce_prod_id].item()

                            applyrule_new_hyp_scores.append(
                                hyp.score + action_score)
                            applyrule_new_hyp_prod_ids.append(reduce_prod_id)
                            applyrule_prev_hyp_ids.append(hyp_id)
                        else:
                            # GenToken action
                            gentoken_prev_hyp_ids.append(hyp_id)

                            if args.no_copy is False:
                                unk_tokens = []
                                unk_copy_probs = []
                                for token, pos_index in copy_candidates:
                                    # Total copy mass over all occurrences of
                                    # this surface form, gated by p(copy).
                                    sum_copy_prob = torch.gather(
                                        primitive_copy_prob[hyp_id], 0,
                                        pos_index).sum()
                                    gated_copy_prob = primitive_predictor_prob[
                                        hyp_id, 1] * sum_copy_prob

                                    if token in primitive_vocab:
                                        token_id = primitive_vocab[token]
                                        primitive_prob[hyp_id, token_id] = (
                                            primitive_prob[hyp_id, token_id] +
                                            gated_copy_prob)
                                    else:
                                        unk_tokens.append(token)
                                        unk_copy_probs.append(
                                            gated_copy_prob.item())

                                if unk_copy_probs:
                                    # Only the most probable out-of-vocabulary
                                    # source token competes, through the UNK
                                    # slot; its copy probability overwrites
                                    # the generation probability of UNK.
                                    unk_i = np.array(unk_copy_probs).argmax()
                                    primitive_prob[
                                        hyp_id, primitive_vocab.
                                        unk_id] = unk_copy_probs[unk_i]
                                    gentoken_new_hyp_unks.append(
                                        unk_tokens[unk_i])

                new_hyp_scores = None
                if applyrule_new_hyp_scores:
                    new_hyp_scores = Variable(
                        self.new_tensor(applyrule_new_hyp_scores))
                if gentoken_prev_hyp_ids:
                    primitive_log_prob = torch.log(primitive_prob)
                    gen_token_new_hyp_scores = (
                        hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) +
                        primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

                    if new_hyp_scores is None:
                        new_hyp_scores = gen_token_new_hyp_scores
                    else:
                        new_hyp_scores = torch.cat(
                            [new_hyp_scores, gen_token_new_hyp_scores])

                if new_hyp_scores is None:
                    # No live hypothesis has a valid continuation: stop here
                    # rather than calling topk on nothing.
                    break

                top_new_hyp_scores, top_new_hyp_pos = torch.topk(
                    new_hyp_scores,
                    k=min(new_hyp_scores.size(0),
                          beam_size - len(completed_hypotheses)))

                num_applyrule = len(applyrule_new_hyp_scores)
                vocab_size = primitive_prob.size(1)

                live_hyp_ids = []
                new_hypotheses = []
                for new_hyp_score, new_hyp_pos in zip(
                        top_new_hyp_scores.cpu(), top_new_hyp_pos.cpu()):
                    action_info = ActionInfo()
                    if new_hyp_pos < num_applyrule:
                        # it's an ApplyRule or Reduce action
                        prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                        prev_hyp = hypotheses[prev_hyp_id]

                        prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                        if prod_id < reduce_prod_id:
                            # ApplyRule action
                            production = self.grammar.id2prod[prod_id]
                            action = ApplyRuleAction(production)
                        else:
                            # Reduce action
                            action = ReduceAction()
                    else:
                        # it's a GenToken action: decode the flat position
                        # into (generating hypothesis slot, token id)
                        gen_pos = new_hyp_pos - num_applyrule
                        token_id = gen_pos % vocab_size
                        gen_slot = gen_pos // vocab_size

                        prev_hyp_id = gentoken_prev_hyp_ids[gen_slot]
                        prev_hyp = hypotheses[prev_hyp_id]

                        if token_id == primitive_vocab.unk_id:
                            if gentoken_new_hyp_unks:
                                # UNK stands for the copied OOV source token
                                token = gentoken_new_hyp_unks[gen_slot]
                            else:
                                token = primitive_vocab.id2word(
                                    primitive_vocab.unk_id)
                        else:
                            token = primitive_vocab.id2word(token_id.item())

                        action = GenTokenAction(token)

                        if token in aggregated_primitive_tokens:
                            action_info.copy_from_src = True
                            action_info.src_token_position = \
                                aggregated_primitive_tokens[token]

                        if debug:
                            action_info.gen_copy_switch = 'n/a' if args.no_copy else primitive_predictor_prob[
                                prev_hyp_id, :].log().cpu().data.numpy()
                            action_info.in_vocab = token in primitive_vocab
                            action_info.gen_token_prob = gen_from_vocab_prob[prev_hyp_id, token_id].log().cpu().data.item() \
                                if token in primitive_vocab else 'n/a'
                            action_info.copy_token_prob = torch.gather(primitive_copy_prob[prev_hyp_id],
                                                                       0,
                                                                       Variable(T.LongTensor(action_info.src_token_position))).sum().log().cpu().data.item() \
                                if args.no_copy is False and action_info.copy_from_src else 'n/a'

                    action_info.action = action
                    action_info.t = t
                    if t > 0:
                        action_info.parent_t = \
                            prev_hyp.frontier_node.created_time
                        action_info.frontier_prod = \
                            prev_hyp.frontier_node.production
                        action_info.frontier_field = \
                            prev_hyp.frontier_field.field

                    if debug:
                        action_info.action_prob = new_hyp_score - prev_hyp.score

                    new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                    new_hyp.score = new_hyp_score

                    if new_hyp.completed:
                        # add length normalization
                        new_hyp.score /= (t + 1)
                        completed_hypotheses.append(new_hyp)
                    else:
                        new_hypotheses.append(new_hyp)
                        live_hyp_ids.append(prev_hyp_id)

                if not live_hyp_ids:
                    break

                # Carry over the decoder state of each surviving hypothesis'
                # parent beam entry.
                hyp_states = [
                    hyp_states[i] + [(h_t[i], cell_t[i])] for i in live_hyp_ids
                ]
                h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
                att_tm1 = att_t[live_hyp_ids]
                hypotheses = new_hypotheses
                hyp_scores = Variable(
                    self.new_tensor([hyp.score for hyp in hypotheses]))
                t += 1

        completed_hypotheses.sort(key=lambda hyp: -hyp.score)

        return completed_hypotheses
Пример #2
0
    def parse(self, src_sent, context=None, beam_size=5, debug=False):
        """Perform beam search to infer the target AST given a source utterance

        Args:
            src_sent: list of source utterance tokens
            context: other context used for prediction
            beam_size: beam size

        Returns:
            A list of `DecodeHypothesis`, each representing an AST
        """

        with torch.no_grad():
            args = self.args
            primitive_vocab = self.vocab.primitive

            src_sent_var = nn_utils.to_input_variable([src_sent],
                                                      self.vocab.source,
                                                      cuda=args.cuda,
                                                      training=False)

            # Variable(1, src_sent_len, hidden_size)
            src_encodings = self.encode(src_sent_var)

            zero_action_embed = torch.zeros(args.action_embed_size)

            hyp_scores = torch.tensor([0.0])

            # For computing copy probabilities, we marginalize over tokens with the same surface form
            # `aggregated_primitive_tokens` stores the position of occurrence of each source token
            aggregated_primitive_tokens = OrderedDict()
            for token_pos, token in enumerate(src_sent):
                aggregated_primitive_tokens.setdefault(token,
                                                       []).append(token_pos)

            t = 0
            hypotheses = [DecodeHypothesis()]
            completed_hypotheses = []

            while len(completed_hypotheses
                      ) < beam_size and t < args.decode_max_time_step:
                hyp_num = len(hypotheses)

                # (hyp_num, src_sent_len, hidden_size)
                exp_src_encodings = src_encodings.expand(
                    hyp_num, src_encodings.size(1), src_encodings.size(2))

                if t == 0:
                    x = torch.zeros(1, self.d_model)
                    parent_ids = np.array([[0]])
                    if args.no_parent_field_type_embed is False:
                        offset = self.args.action_embed_size  # prev_action
                        offset += self.args.action_embed_size * (
                            not self.args.no_parent_production_embed)
                        offset += self.args.field_embed_size * (
                            not self.args.no_parent_field_embed)

                        x[0, offset:offset +
                          self.type_embed_size] = self.type_embed(
                              torch.tensor(self.grammar.type2id[
                                  self.grammar.root_type]))
                        x = x.unsqueeze(-2)
                else:
                    actions_tm1 = [hyp.actions[-1] for hyp in hypotheses]

                    a_tm1_embeds = []
                    for a_tm1 in actions_tm1:
                        if a_tm1:
                            if isinstance(a_tm1, ApplyRuleAction):
                                a_tm1_embed = self.production_embed(
                                    torch.tensor(self.grammar.prod2id[
                                        a_tm1.production]))
                            elif isinstance(a_tm1, ReduceAction):
                                a_tm1_embed = self.production_embed(
                                    torch.tensor(len(self.grammar)))
                            else:
                                a_tm1_embed = self.primitive_embed(
                                    torch.tensor(
                                        self.vocab.primitive[a_tm1.token]))

                            a_tm1_embeds.append(a_tm1_embed)
                        else:
                            a_tm1_embeds.append(zero_action_embed)
                    a_tm1_embeds = torch.stack(a_tm1_embeds)

                    inputs = [a_tm1_embeds]
                    if args.no_parent_production_embed is False:
                        # frontier production
                        frontier_prods = [
                            hyp.frontier_node.production for hyp in hypotheses
                        ]
                        frontier_prod_embeds = self.production_embed(
                            torch.tensor([
                                self.grammar.prod2id[prod]
                                for prod in frontier_prods
                            ],
                                         dtype=torch.long))
                        inputs.append(frontier_prod_embeds)
                    if args.no_parent_field_embed is False:
                        # frontier field
                        frontier_fields = [
                            hyp.frontier_field.field for hyp in hypotheses
                        ]
                        frontier_field_embeds = self.field_embed(
                            torch.tensor([
                                self.grammar.field2id[field]
                                for field in frontier_fields
                            ],
                                         dtype=torch.long))

                        inputs.append(frontier_field_embeds)
                    if args.no_parent_field_type_embed is False:
                        # frontier field type
                        frontier_field_types = [
                            hyp.frontier_field.type for hyp in hypotheses
                        ]
                        frontier_field_type_embeds = self.type_embed(
                            torch.tensor([
                                self.grammar.type2id[type]
                                for type in frontier_field_types
                            ],
                                         dtype=torch.long))
                        inputs.append(frontier_field_type_embeds)

                    x = torch.cat(
                        [x, torch.cat(inputs, dim=-1).unsqueeze(-2)], dim=1)
                    recent_parents = np.array(
                        [[hyp.frontier_node.created_time]
                         if hyp.frontier_node else 0 for hyp in hypotheses])
                    parent_ids = np.hstack([parent_ids, recent_parents])

                src_mask = torch.ones(
                    exp_src_encodings.shape[:-1],
                    dtype=torch.uint8,
                    device=exp_src_encodings.device).unsqueeze(-2)
                tgt_mask = subsequent_mask(x.shape[-2])

                att_t = self.decoder(x, exp_src_encodings, [parent_ids],
                                     src_mask, tgt_mask)[:, -1]

                # Variable(batch_size, grammar_size)
                # apply_rule_log_prob = torch.log(F.softmax(self.production_readout(att_t), dim=-1))
                apply_rule_log_prob = F.log_softmax(
                    self.production_readout(att_t), dim=-1)

                # Variable(batch_size, primitive_vocab_size)
                gen_from_vocab_prob = F.softmax(self.tgt_token_readout(att_t),
                                                dim=-1)

                if args.no_copy:
                    primitive_prob = gen_from_vocab_prob
                else:
                    # Variable(batch_size, src_sent_len)
                    primitive_copy_prob = self.src_pointer_net(
                        src_encodings, None, att_t.unsqueeze(0)).squeeze(0)

                    # Variable(batch_size, 2)
                    primitive_predictor_prob = F.softmax(
                        self.primitive_predictor(att_t), dim=-1)

                    # Variable(batch_size, primitive_vocab_size)
                    primitive_prob = primitive_predictor_prob[:, 0].unsqueeze(
                        1) * gen_from_vocab_prob

                    # if src_unk_pos_list:
                    #     primitive_prob[:, primitive_vocab.unk_id] = 1.e-10

                gentoken_prev_hyp_ids = []
                gentoken_new_hyp_unks = []
                applyrule_new_hyp_scores = []
                applyrule_new_hyp_prod_ids = []
                applyrule_prev_hyp_ids = []

                for hyp_id, hyp in enumerate(hypotheses):
                    # generate new continuations
                    action_types = self.transition_system.get_valid_continuation_types(
                        hyp)

                    for action_type in action_types:
                        if action_type == ApplyRuleAction:
                            productions = self.transition_system.get_valid_continuating_productions(
                                hyp)
                            for production in productions:
                                prod_id = self.grammar.prod2id[production]
                                prod_score = apply_rule_log_prob[
                                    hyp_id, prod_id].item()
                                new_hyp_score = hyp.score + prod_score

                                applyrule_new_hyp_scores.append(new_hyp_score)
                                applyrule_new_hyp_prod_ids.append(prod_id)
                                applyrule_prev_hyp_ids.append(hyp_id)
                        elif action_type == ReduceAction:
                            action_score = apply_rule_log_prob[
                                hyp_id, len(self.grammar)].item()
                            new_hyp_score = hyp.score + action_score

                            applyrule_new_hyp_scores.append(new_hyp_score)
                            applyrule_new_hyp_prod_ids.append(len(
                                self.grammar))
                            applyrule_prev_hyp_ids.append(hyp_id)
                        else:
                            # GenToken action
                            gentoken_prev_hyp_ids.append(hyp_id)
                            hyp_copy_info = dict()  # of (token_pos, copy_prob)
                            hyp_unk_copy_info = []

                            if args.no_copy is False:
                                for (token, token_pos_list
                                     ) in aggregated_primitive_tokens.items():
                                    sum_copy_prob = torch.gather(
                                        primitive_copy_prob[hyp_id], 0,
                                        torch.tensor(token_pos_list,
                                                     dtype=torch.long)).sum()
                                    gated_copy_prob = primitive_predictor_prob[
                                        hyp_id, 1] * sum_copy_prob

                                    if token in primitive_vocab:
                                        token_id = primitive_vocab[token]
                                        primitive_prob[hyp_id, token_id] = (
                                            primitive_prob[hyp_id, token_id] +
                                            gated_copy_prob)

                                        hyp_copy_info[token] = (
                                            token_pos_list,
                                            gated_copy_prob.item())
                                    else:
                                        hyp_unk_copy_info.append({
                                            "token":
                                            token,
                                            "token_pos_list":
                                            token_pos_list,
                                            "copy_prob":
                                            gated_copy_prob.item(),
                                        })

                            if args.no_copy is False and len(
                                    hyp_unk_copy_info) > 0:
                                unk_i = np.array([
                                    x["copy_prob"] for x in hyp_unk_copy_info
                                ]).argmax()
                                token = hyp_unk_copy_info[unk_i]["token"]
                                primitive_prob[hyp_id, primitive_vocab.
                                               unk_id] = hyp_unk_copy_info[
                                                   unk_i]["copy_prob"]
                                gentoken_new_hyp_unks.append(token)

                                hyp_copy_info[token] = (
                                    hyp_unk_copy_info[unk_i]["token_pos_list"],
                                    hyp_unk_copy_info[unk_i]["copy_prob"],
                                )

                new_hyp_scores = None
                if applyrule_new_hyp_scores:
                    new_hyp_scores = torch.tensor(applyrule_new_hyp_scores)
                if gentoken_prev_hyp_ids:
                    primitive_log_prob = torch.log(primitive_prob)
                    gen_token_new_hyp_scores = (
                        hyp_scores[gentoken_prev_hyp_ids].unsqueeze(1) +
                        primitive_log_prob[gentoken_prev_hyp_ids, :]).view(-1)

                    if new_hyp_scores is None:
                        new_hyp_scores = gen_token_new_hyp_scores
                    else:
                        new_hyp_scores = torch.cat(
                            [new_hyp_scores, gen_token_new_hyp_scores])

                top_new_hyp_scores, top_new_hyp_pos = torch.topk(
                    new_hyp_scores,
                    k=min(new_hyp_scores.size(0),
                          beam_size - len(completed_hypotheses)))

                live_hyp_ids = []
                new_hypotheses = []
                for new_hyp_score, new_hyp_pos in zip(
                        top_new_hyp_scores.data.cpu(),
                        top_new_hyp_pos.data.cpu()):
                    action_info = ActionInfo()
                    if new_hyp_pos < len(applyrule_new_hyp_scores):
                        # it's an ApplyRule or Reduce action
                        prev_hyp_id = applyrule_prev_hyp_ids[new_hyp_pos]
                        prev_hyp = hypotheses[prev_hyp_id]

                        prod_id = applyrule_new_hyp_prod_ids[new_hyp_pos]
                        # ApplyRule action
                        if prod_id < len(self.grammar):
                            production = self.grammar.id2prod[prod_id]
                            action = ApplyRuleAction(production)
                        # Reduce action
                        else:
                            action = ReduceAction()
                    else:
                        # it's a GenToken action
                        token_id = int(
                            (new_hyp_pos - len(applyrule_new_hyp_scores)) %
                            primitive_prob.size(1))

                        k = (new_hyp_pos - len(applyrule_new_hyp_scores)
                             ) // primitive_prob.size(1)
                        # try:
                        # copy_info = gentoken_copy_infos[k]
                        prev_hyp_id = gentoken_prev_hyp_ids[k]
                        prev_hyp = hypotheses[prev_hyp_id]
                        # except:
                        #     print('k=%d' % k, file=sys.stderr)
                        #     print('primitive_prob.size(1)=%d' % primitive_prob.size(1), file=sys.stderr)
                        #     print('len copy_info=%d' % len(gentoken_copy_infos), file=sys.stderr)
                        #     print('prev_hyp_id=%s' % ', '.join(str(i) for i in gentoken_prev_hyp_ids), file=sys.stderr)
                        #     print('len applyrule_new_hyp_scores=%d' % len(applyrule_new_hyp_scores), file=sys.stderr)
                        #     print('len gentoken_prev_hyp_ids=%d' % len(gentoken_prev_hyp_ids), file=sys.stderr)
                        #     print('top_new_hyp_pos=%s' % top_new_hyp_pos, file=sys.stderr)
                        #     print('applyrule_new_hyp_scores=%s' % applyrule_new_hyp_scores, file=sys.stderr)
                        #     print('new_hyp_scores=%s' % new_hyp_scores, file=sys.stderr)
                        #     print('top_new_hyp_scores=%s' % top_new_hyp_scores, file=sys.stderr)
                        #
                        #     torch.save((applyrule_new_hyp_scores, primitive_prob), 'data.bin')
                        #
                        #     # exit(-1)
                        #     raise ValueError()

                        if token_id == int(primitive_vocab.unk_id):
                            if gentoken_new_hyp_unks:
                                token = gentoken_new_hyp_unks[k]
                            else:
                                token = primitive_vocab.id2word[
                                    primitive_vocab.unk_id]
                        else:
                            token = primitive_vocab.id2word[token_id]

                        action = GenTokenAction(token)

                        if token in aggregated_primitive_tokens:
                            action_info.copy_from_src = True
                            action_info.src_token_position = aggregated_primitive_tokens[
                                token]

                        if debug:
                            action_info.gen_copy_switch = (
                                "n/a"
                                if args.no_copy else primitive_predictor_prob[
                                    prev_hyp_id, :].log().cpu().data.numpy())
                            action_info.in_vocab = token in primitive_vocab
                            action_info.gen_token_prob = (
                                gen_from_vocab_prob[prev_hyp_id,
                                                    token_id].log().cpu().
                                item() if token in primitive_vocab else "n/a")
                            action_info.copy_token_prob = (
                                torch.gather(
                                    primitive_copy_prob[prev_hyp_id],
                                    0,
                                    torch.tensor(
                                        action_info.src_token_position,
                                        dtype=torch.long,
                                        device=self.device),
                                ).sum().log().cpu().item()
                                if args.no_copy is False
                                and action_info.copy_from_src else "n/a")

                    action_info.action = action
                    action_info.t = t
                    if t > 0:
                        action_info.parent_t = prev_hyp.frontier_node.created_time
                        action_info.frontier_prod = prev_hyp.frontier_node.production
                        action_info.frontier_field = prev_hyp.frontier_field.field

                    if debug:
                        action_info.action_prob = new_hyp_score - prev_hyp.score

                    new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
                    new_hyp.score = new_hyp_score

                    if new_hyp.completed:
                        completed_hypotheses.append(new_hyp)
                    else:
                        new_hypotheses.append(new_hyp)
                        live_hyp_ids.append(prev_hyp_id)

                if live_hyp_ids:
                    x = x[live_hyp_ids]
                    parent_ids = parent_ids[live_hyp_ids]
                    hypotheses = new_hypotheses
                    hyp_scores = torch.tensor(
                        [hyp.score for hyp in hypotheses])
                    t += 1
                else:
                    break

            completed_hypotheses.sort(key=lambda hyp: -hyp.score)

            return completed_hypotheses