コード例 #1
0
    def _score(self, ex):
        """Score the target action sequence of a single example.

        The example is wrapped in a one-element ``Batch``, encoded, and the
        result is handed to ``self._score_node`` starting at the grammar's
        root type.

        Args:
            ex: a single example; its ``tgt_actions`` attribute is scored.

        Returns:
            Whatever ``self._score_node`` returns for the root node.
        """
        single_batch = Batch([ex], self.grammar, self.vocab)
        context_vecs, encoder_outputs = self.encode(single_batch)

        # The encoder outputs are passed through unchanged as the initial
        # state argument of the node scorer.
        return self._score_node(
            self.grammar.root_type,
            encoder_outputs,
            ex.tgt_actions,
            context_vecs,
            single_batch.sent_masks,
        )
コード例 #2
0
    def naive_parse(self, ex):
        """Decode a single example with ``self._naive_parse`` and build an AST.

        Args:
            ex: a single example, wrapped in a one-element ``Batch`` built
                with ``train=False``.

        Returns:
            The result of ``self.transition_system.build_ast_from_actions``
            applied to the decoded action tree.
        """
        single_batch = Batch([ex], self.grammar, self.vocab, train=False)
        context_vecs, encoder_outputs = self.encode(single_batch)

        # The encoder outputs are passed through unchanged as the initial
        # state argument of the decoder.
        decoded_actions = self._naive_parse(
            self.grammar.root_type,
            encoder_outputs,
            context_vecs,
            single_batch.sent_masks,
            1,
        )
        return self.transition_system.build_ast_from_actions(decoded_actions)
コード例 #3
0
    def parse(self, ex):
        """Beam-search decode a single example.

        At every step each hypothesis on the beam is expanded with
        ``self.continuations_of_hyp``; the expansions are ranked by ``score``
        and the best ones fill the slots not yet taken by completed
        hypotheses. Decoding stops when the beam is empty or after
        ``self.args.max_decode_step`` steps.

        Args:
            ex: a single example, wrapped in a one-element ``Batch`` built
                with ``train=False``.

        Returns:
            list: completed hypotheses sorted by ``score``, best first. The
            list is empty if no hypothesis completed.
        """
        single_batch = Batch([ex], self.grammar, self.vocab, train=False)
        context_vecs, encoder_outputs = self.encode(single_batch)

        finished = []
        beam = [
            Hypothesis.init_hypothesis(self.grammar.root_type, encoder_outputs)
        ]

        for _ in range(self.args.max_decode_step):
            # Expand every live hypothesis and rank all expansions together.
            candidates = []
            for live_hyp in beam:
                candidates.extend(
                    self.continuations_of_hyp(live_hyp, context_vecs,
                                              single_batch.sent_masks))
            ranked = sorted(candidates, key=lambda h: h.score, reverse=True)

            # Completed hypotheses permanently occupy beam slots, so only the
            # remaining slots are available to this step's candidates.
            free_slots = self.args.beam_size - len(finished)

            beam = []
            for candidate in ranked[:free_slots]:
                destination = finished if candidate.is_complete() else beam
                destination.append(candidate)

            if not beam:
                break

        return sorted(finished, key=lambda h: h.score, reverse=True)
コード例 #4
0
    def score(self, examples, return_encode_state=False):
        """Compute the log-likelihood of generating each example's target AST.

        Args:
            examples: a batch of examples.
            return_encode_state: if true, append the encoder's last state to
                the returned list.

        Returns:
            list: ``[scores]`` where ``scores`` is the masked per-action
            log-probability summed over dim 0, followed by ``att_prob`` when
            ``self.args.sup_attention`` is set and by ``last_state`` when
            ``return_encode_state`` is true (in that order).
        """
        args = self.args
        copy_disabled = args.no_copy

        batch = Batch(examples, self.grammar, self.vocab,
                      copy=copy_disabled is False, cuda=args.cuda)

        # src_encodings: (batch_size, src_sent_len, hidden_size * 2)
        src_encodings, last_state = self.encode(batch.src_sents_var,
                                                batch.src_sents_len)
        dec_init_vec = self.init_decoder_state(last_state)

        # Query vectors are the statistics every action probability below is
        # read out from: (tgt_action_len, batch_size, hidden_size).
        # With supervised attention the decoder also returns attention
        # probabilities alongside them.
        decoded = self.decode(batch, src_encodings, dec_init_vec)
        if args.sup_attention:
            query_vectors, att_prob = decoded
        else:
            query_vectors = decoded

        # ApplyRule (ApplyConstructor) distribution:
        # (tgt_action_len, batch_size, grammar_size)
        rule_dist = F.softmax(self.production_readout(query_vectors), dim=-1)
        # Probability of the gold rule at each step:
        # (tgt_action_len, batch_size)
        gold_rule_prob = rule_dist.gather(
            2, batch.apply_rule_idx_matrix.unsqueeze(2)).squeeze(2)

        # Primitive-token generation distribution:
        # (tgt_action_len, batch_size, primitive_vocab_size)
        vocab_dist = F.softmax(self.tgt_token_readout(query_vectors), dim=-1)
        # Probability of the gold primitive token:
        # (tgt_action_len, batch_size)
        gold_vocab_prob = vocab_dist.gather(
            2, batch.primitive_idx_matrix.unsqueeze(2)).squeeze(2)

        if copy_disabled:
            if self.training and args.primitive_token_label_smoothing:
                # The label-smoothing module's output is negated so that it
                # can be used as a log-probability term below.
                gold_vocab_log_prob = -self.label_smoothing(
                    vocab_dist.log(), batch.primitive_idx_matrix)
            else:
                gold_vocab_log_prob = gold_vocab_prob.log()

            # (tgt_action_len, batch_size); the masks zero out unused steps.
            rule_term = gold_rule_prob.log() * batch.apply_rule_mask
            gen_term = gold_vocab_log_prob * batch.gen_token_mask
            action_log_prob = rule_term + gen_term
        else:
            # Gate between generating (index 0) and copying (index 1) a
            # primitive token: (tgt_action_len, batch_size, 2)
            gate = F.softmax(self.primitive_predictor(query_vectors), dim=-1)

            # Pointer-network copy distribution over source tokens:
            # (tgt_action_len, batch_size, src_sent_len)
            copy_dist = self.src_pointer_net(src_encodings,
                                             batch.src_token_mask,
                                             query_vectors)

            # Marginalise over all source positions holding the gold token:
            # (tgt_action_len, batch_size)
            gold_copy_prob = (
                copy_dist * batch.primitive_copy_token_idx_mask).sum(dim=-1)

            # Steps where no mask is active are padding:
            # (tgt_action_len, batch_size)
            mask_total = (batch.apply_rule_mask + batch.gen_token_mask +
                          batch.primitive_copy_mask)
            is_padding = mask_total.eq(0.)
            keep_mask = 1. - is_padding.float()

            # (tgt_action_len, batch_size)
            rule_term = gold_rule_prob * batch.apply_rule_mask
            gen_term = gate[:, :, 0] * gold_vocab_prob * batch.gen_token_mask
            copy_term = (gate[:, :, 1] * gold_copy_prob *
                         batch.primitive_copy_mask)
            mixture = rule_term + gen_term + copy_term

            # Padding entries are exactly zero; give them a tiny value so the
            # log below does not produce nan, then mask them out again.
            mixture.data.masked_fill_(is_padding.data, 1.e-7)
            action_log_prob = mixture.log() * keep_mask

        outputs = [action_log_prob.sum(dim=0)]
        if args.sup_attention:
            outputs.append(att_prob)
        if return_encode_state:
            outputs.append(last_state)

        return outputs
コード例 #5
0
ファイル: parser.py プロジェクト: thu-spmi/seq2seq-JAE
    def score(self, examples, return_enc_state=False, copy=True, force=0):
        """Compute a log-likelihood score for every example in a batch.

        Args:
            examples: a batch of examples.
            return_enc_state: if true, also return the encoder's last state.
            copy: forwarded positionally to ``Batch``.
            force: forwarded positionally to ``Batch``.

        Returns:
            ``scores`` (masked per-action log-probabilities summed over
            dim 0), or the tuple ``(scores, last_state)`` when
            ``return_enc_state`` is true.
        """
        batch = Batch(examples, self.grammar, self.vocab, self.args.cuda,
                      copy, force)

        src_encodings, (last_state, last_cell) = self.encode(
            batch.src_sents_var, batch.src_sents_len)
        dec_init_vec = self.init_decoder_state(last_state, last_cell)

        # (tgt_action_len, batch_size, hidden_size)
        query_vectors = self.decode(batch, src_encodings, dec_init_vec)

        # ApplyRule distribution: (tgt_action_len, batch_size, grammar_size)
        rule_dist = F.softmax(self.production_readout(query_vectors), dim=-1)

        # Pointer-network distribution over source tokens:
        # (tgt_action_len, batch_size, src_sent_len)
        copy_dist = self.src_pointer_net(src_encodings, batch.src_token_mask,
                                         query_vectors)

        # Generation distribution:
        # (tgt_action_len, batch_size, primitive_vocab_size)
        vocab_dist = F.softmax(self.tgt_token_readout(query_vectors), dim=-1)

        # Gate between generating (index 0) and copying (index 1):
        # (tgt_action_len, batch_size, 2)
        gate = F.softmax(self.primitive_predictor(query_vectors), dim=-1)

        # Pick out the gold entry of each distribution; every result is
        # (tgt_action_len, batch_size).
        gold_rule_prob = rule_dist.gather(
            2, batch.apply_rule_idx_matrix.unsqueeze(2)).squeeze(2)
        gold_copy_prob = copy_dist.gather(
            2, batch.primitive_copy_pos_matrix.unsqueeze(2)).squeeze(2)
        gold_vocab_prob = vocab_dist.gather(
            2, batch.primitive_idx_matrix.unsqueeze(2)).squeeze(2)

        # 1 where any action mask is active, 0 on padding steps:
        # (tgt_action_len, batch_size)
        mask_total = (batch.apply_rule_mask + batch.gen_token_mask +
                      batch.primitive_copy_mask)
        keep_mask = 1. - mask_total.eq(0.).float()

        rule_term = gold_rule_prob * batch.apply_rule_mask
        gen_term = gate[:, :, 0] * gold_vocab_prob * batch.gen_token_mask
        copy_term = gate[:, :, 1] * gold_copy_prob * batch.primitive_copy_mask
        mixture = rule_term + gen_term + copy_term

        # Add a tiny constant on padding steps only, so the log is finite
        # there, then zero those steps out again.
        padding_eps = 1.e-7 * (1. - keep_mask)
        action_log_prob = (mixture + padding_eps).log() * keep_mask

        scores = action_log_prob.sum(dim=0)

        return (scores, last_state) if return_enc_state else scores
コード例 #6
0
ファイル: parser.py プロジェクト: zorrock/tranX
    def score(self, examples, return_enc_state=False):
        """Compute a log-likelihood score for every example in a batch.

        Args:
            examples: a batch of examples.
            return_enc_state: if true, append the encoder's last state to the
                returned list.

        Returns:
            list: ``[scores]`` where ``scores`` is the masked per-action
            log-probability summed over dim 0, followed by ``att_prob`` when
            ``self.args.sup_attention`` is set and by ``last_state`` when
            ``return_enc_state`` is true (in that order).
        """
        args = self.args
        batch = Batch(examples, self.grammar, self.vocab, args.cuda)

        src_encodings, (last_state, last_cell) = self.encode(
            batch.src_sents_var, batch.src_sents_len)
        dec_init_vec = self.init_decoder_state(last_state, last_cell)

        # query_vectors: (tgt_action_len, batch_size, hidden_size).
        # With supervised attention the decoder also returns attention
        # probabilities alongside them.
        decoded = self.decode(batch, src_encodings, dec_init_vec)
        if args.sup_attention:
            query_vectors, att_prob = decoded
        else:
            query_vectors = decoded

        # ApplyRule distribution: (tgt_action_len, batch_size, grammar_size)
        rule_dist = F.softmax(self.production_readout(query_vectors), dim=-1)
        # (tgt_action_len, batch_size)
        gold_rule_prob = rule_dist.gather(
            2, batch.apply_rule_idx_matrix.unsqueeze(2)).squeeze(2)

        # Gate between generating (index 0) and copying (index 1):
        # (tgt_action_len, batch_size, 2)
        gate = F.softmax(self.primitive_predictor(query_vectors), dim=-1)

        # Pointer-network distribution over source tokens:
        # (tgt_action_len, batch_size, src_sent_len)
        copy_dist = self.src_pointer_net(src_encodings, batch.src_token_mask,
                                         query_vectors)
        # (tgt_action_len, batch_size)
        gold_copy_prob = copy_dist.gather(
            2, batch.primitive_copy_pos_matrix.unsqueeze(2)).squeeze(2)

        # Generation distribution:
        # (tgt_action_len, batch_size, primitive_vocab_size)
        vocab_dist = F.softmax(self.tgt_token_readout(query_vectors), dim=-1)
        # (tgt_action_len, batch_size)
        gold_vocab_prob = vocab_dist.gather(
            2, batch.primitive_idx_matrix.unsqueeze(2)).squeeze(2)

        # Steps where no mask is active are padding:
        # (tgt_action_len, batch_size)
        mask_total = (batch.apply_rule_mask + batch.gen_token_mask +
                      batch.primitive_copy_mask)
        is_padding = mask_total.eq(0.)
        keep_mask = 1. - is_padding.float()

        # (tgt_action_len, batch_size)
        rule_term = gold_rule_prob * batch.apply_rule_mask
        gen_term = gate[:, :, 0] * gold_vocab_prob * batch.gen_token_mask
        copy_term = gate[:, :, 1] * gold_copy_prob * batch.primitive_copy_mask
        mixture = rule_term + gen_term + copy_term

        # Padding entries are exactly zero; give them a tiny value so the log
        # below does not produce nan, then mask them out again.
        mixture.data.masked_fill_(is_padding.data, 1.e-7)
        action_log_prob = mixture.log() * keep_mask

        outputs = [action_log_prob.sum(dim=0)]
        if args.sup_attention:
            outputs.append(att_prob)
        if return_enc_state:
            outputs.append(last_state)

        return outputs
コード例 #7
0
    def score(self, examples, return_encode_state=False):
        """Compute the log-likelihood of generating each example's target AST.

        Args:
            examples: a batch of examples.
            return_encode_state: accepted for signature compatibility; this
                implementation does not use it and never returns an encoder
                state.

        Returns:
            list: a one-element list ``[scores]`` where ``scores`` is the
            masked per-action log-probability summed over dim 0.
        """

        batch = Batch(examples,
                      self.grammar,
                      self.vocab,
                      copy=self.args.no_copy is False,
                      cuda=self.args.cuda)

        # src_encodings: (batch_size, src_sent_len, hidden_size)
        src_encodings = self.encode(batch.src_sents_var)

        # tgt vector: (batch_size, src_sent_len, hidden_size)
        tgt_vector = self.prepare_tgt(batch)

        # Parent time-step of every target action; missing (falsy) actions
        # map to 0.
        parent_indxs = [[a_t.parent_t if a_t else 0 for a_t in e.tgt_actions]
                        for e in batch.examples]
        # Zero-padded (batch_size, max_action_len) matrix. np.int64 is used
        # because the np.long alias was removed in NumPy 1.24.
        parent_indxs_np = np.zeros(
            (len(parent_indxs), max(len(ind) for ind in parent_indxs)),
            dtype=np.int64)
        for i, indices in enumerate(parent_indxs):
            parent_indxs_np[i, :len(indices)] = indices
            # The first action's parent index is always forced to 0.
            parent_indxs_np[i, 0] = 0

        # query vectors are sufficient statistics used to compute action probabilities
        query_vectors = self.decoder(tgt_vector, src_encodings,
                                     [parent_indxs_np],
                                     batch.src_token_mask_usual.unsqueeze(-2),
                                     batch.tgt_mask)
        # query_vectors: (tgt_action_len, batch_size, hidden_size)
        query_vectors = query_vectors.transpose(0, 1)

        # ApplyRule (i.e., ApplyConstructor) action probabilities
        # (tgt_action_len, batch_size, grammar_size)
        apply_rule_prob = F.softmax(self.production_readout(query_vectors),
                                    dim=-1)

        # probabilities of target (gold-standard) ApplyRule actions
        # (tgt_action_len, batch_size)
        tgt_apply_rule_prob = torch.gather(
            apply_rule_prob,
            dim=2,
            index=batch.apply_rule_idx_matrix.unsqueeze(2)).squeeze(2)

        # compute generation and copying probabilities

        # (tgt_action_len, batch_size, primitive_vocab_size)
        gen_from_vocab_prob = F.softmax(self.tgt_token_readout(query_vectors),
                                        dim=-1)

        # (tgt_action_len, batch_size)
        tgt_primitive_gen_from_vocab_prob = torch.gather(
            gen_from_vocab_prob,
            dim=2,
            index=batch.primitive_idx_matrix.unsqueeze(2)).squeeze(2)

        if self.args.no_copy:
            if self.training and self.args.primitive_token_label_smoothing:
                # (tgt_action_len, batch_size)
                # The label-smoothing module's output is negated so that it
                # can be used as a log-probability term below.
                tgt_primitive_gen_from_vocab_log_prob = -self.label_smoothing(
                    gen_from_vocab_prob.log(), batch.primitive_idx_matrix)
            else:
                tgt_primitive_gen_from_vocab_log_prob = (
                    tgt_primitive_gen_from_vocab_prob.log())

            # (tgt_action_len, batch_size); the masks zero out unused steps
            action_prob = (
                tgt_apply_rule_prob.log() * batch.apply_rule_mask +
                tgt_primitive_gen_from_vocab_log_prob * batch.gen_token_mask)
        else:
            # binary gating probabilities between generating or copying a primitive token
            # (tgt_action_len, batch_size, 2)
            primitive_predictor = F.softmax(
                self.primitive_predictor(query_vectors), dim=-1)

            # pointer network copying scores over source tokens
            # (tgt_action_len, batch_size, src_sent_len)
            primitive_copy_prob = self.src_pointer_net(src_encodings,
                                                       batch.src_token_mask,
                                                       query_vectors)

            # marginalize over the copy probabilities of tokens that are same
            # (tgt_action_len, batch_size)
            tgt_primitive_copy_prob = torch.sum(
                primitive_copy_prob * batch.primitive_copy_token_idx_mask,
                dim=-1)

            # steps where no mask is active are padding
            # (tgt_action_len, batch_size)
            action_mask_pad = torch.eq(
                batch.apply_rule_mask + batch.gen_token_mask +
                batch.primitive_copy_mask, 0.0)
            action_mask = 1.0 - action_mask_pad.float()

            # (tgt_action_len, batch_size)
            action_prob = (
                tgt_apply_rule_prob * batch.apply_rule_mask +
                primitive_predictor[:, :, 0] *
                tgt_primitive_gen_from_vocab_prob * batch.gen_token_mask +
                primitive_predictor[:, :, 1] * tgt_primitive_copy_prob *
                batch.primitive_copy_mask)

            # avoid nan in log: padding entries get a tiny value, and a
            # further epsilon is added to every entry before the log
            action_prob.data.masked_fill_(action_mask_pad.data, 1.0e-7)
            eps = 1.0e-18
            action_prob += eps

            action_prob = action_prob.log() * action_mask

        scores = torch.sum(action_prob, dim=0)

        returns = [scores]

        return returns