Example No. 1
    def forward(self, embedding, targets):
        if isinstance(embedding, dict):
            embedding = embedding['features']
        # Normalize embedding features
        embedding = F.normalize(embedding, axis=1)
        dist_mat = paddle.matmul(embedding, embedding, transpose_y=True)

        N = dist_mat.shape[0]
        # Pairwise label masks: is_pos[i, j] = 1 when samples i and j share a
        # label, is_neg[i, j] = 1 otherwise.
        is_pos = targets.reshape([N, 1]).expand([N, N]).equal(
            paddle.t(targets.reshape([N, 1]).expand([N, N]))).astype('float32')
        is_neg = targets.reshape([N, 1]).expand([N, N]).not_equal(
            paddle.t(targets.reshape([N, 1]).expand([N, N]))).astype('float32')

        # Mask scores related to itself
        is_pos = is_pos - paddle.eye(N, N)

        s_p = dist_mat * is_pos
        s_n = dist_mat * is_neg

        logit_p = -self.gamma * s_p + (-99999999.) * (1 - is_pos)
        logit_n = self.gamma * (s_n + self.margin) + (-99999999.) * (1 - is_neg)

        loss = F.softplus(
            paddle.logsumexp(logit_p, axis=1) +
            paddle.logsumexp(logit_n, axis=1)).mean()

        return {"PairwiseCosface": loss}
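The class around this forward is not shown; a minimal usage sketch, assuming the class is named PairwiseCosface and takes the margin and gamma used above as constructor arguments (both the name and the defaults are assumptions):

import paddle

# Hypothetical construction: class name and default hyper-parameters are assumed.
loss_fn = PairwiseCosface(margin=0.35, gamma=64)

embedding = paddle.randn([8, 128])          # batch of 8 feature vectors
targets = paddle.randint(0, 4, shape=[8])   # integer identity labels
loss = loss_fn(embedding, targets)["PairwiseCosface"]
print(float(loss))                          # scalar loss value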
Example No. 2
    def forward(self, inputs, lengths):
        """
        Computes the normalization in a linear-chain CRF. See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
        $$ F = logZ(x) = log\\sum_y exp(score(x,y)) $$
        $$ score(x,y) = \\sum_i Emit(x_i,y_i) + Trans(y_{i-1}, y_i) $$
        Writing $$ p(y_i) = Emit(x_i,y_i), T(y_{i-1}, y_i) = Trans(y_{i-1}, y_i) $$ and letting
        $$ F_n(y_n) $$ denote the log-sum-exp of the scores of all length-n prefixes ending in tag $$ y_n $$, we get
        $$ F_1(y_1) = p(y_1) + T([START], y_1) $$
        $$ F_n(y_n) = log\\sum_{y_{n-1}} exp(F_{n-1}(y_{n-1}) + T(y_{n-1}, y_n)) + p(y_n) $$
        a recursive formula. The normalizer is then $$ logZ(x) = log\\sum_{y_L} exp(F_L(y_L) + T(y_L, [STOP])) $$ for a length-L sequence.

        Args:
            inputs (Tensor): The input tensor with shape `[batch_size, sequence_length, num_tags]`.
            lengths (Tensor): The input length with shape `[batch_size]`.

        Returns:
            Tensor: The normalizers tensor, with shape `[batch_size]`.
        """
        batch_size, seq_len, n_labels = inputs.shape
        # inputs_t_exp: seq_len, batch_size, num_tags, num_tags
        inputs_t_exp = inputs.transpose([1, 0, 2]).unsqueeze(-1).expand(
            [seq_len, batch_size, n_labels, n_labels])
        # trans_exp: batch_size, num_tags, num_tags
        trans_exp = self.transitions.unsqueeze(0).expand(
            [batch_size, n_labels, n_labels])

        all_alpha = []
        if self.with_start_stop_tag:
            alpha = self._initialize_alpha(batch_size)

        for i, input_exp in enumerate(inputs_t_exp):
            # input_exp: batch_size, num_tags, num_tags
            # alpha_exp: batch_size, num_tags, num_tags
            if i == 0 and not self.with_start_stop_tag:
                mat = input_exp
            else:
                alpha_exp = alpha.unsqueeze(1).expand(
                    [batch_size, n_labels, n_labels])
                # F(n) = logsumexp(F(n-1) + p(y_n) + T(y_{n-1}, y_n))
                mat = input_exp + trans_exp + alpha_exp
            alpha = paddle.logsumexp(mat, 2)
            all_alpha.append(alpha)

        # Get the valid alpha
        all_alpha = paddle.stack(all_alpha).transpose([1, 0, 2])
        batch_index = self._get_batch_index(batch_size)
        last_index = lengths - 1
        idxs = paddle.stack([batch_index, last_index], axis=1)
        alpha = paddle.gather_nd(all_alpha, idxs)

        if self.with_start_stop_tag:
            # The last one step
            alpha += self.transitions[self.stop_idx].unsqueeze(0)
        norm_score = paddle.logsumexp(alpha, 1)
        return norm_score
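The loop above is the standard forward-algorithm recursion evaluated with paddle.logsumexp. A standalone sanity check of that recursion on a made-up toy problem (two steps, three tags, no START/STOP handling), compared against brute-force enumeration of all tag sequences:

import itertools
import paddle

emit = paddle.randn([2, 3])    # p(y_i): emission scores, seq_len x num_tags
trans = paddle.randn([3, 3])   # T(y_{i-1}, y_i): transition scores

# Recursion: F_1(y_1) = p(y_1); F_2(y_2) = logsumexp_{y_1}(F_1(y_1) + T(y_1, y_2)) + p(y_2)
alpha = emit[0]
alpha = paddle.logsumexp(alpha.unsqueeze(1) + trans + emit[1].unsqueeze(0), axis=0)
norm = paddle.logsumexp(alpha, axis=0)

# Brute force: logsumexp over the scores of all 3 * 3 tag sequences.
scores = [emit[0][a] + trans[a][b] + emit[1][b]
          for a, b in itertools.product(range(3), range(3))]
brute = paddle.logsumexp(paddle.stack(scores), axis=0)
print(float(norm), float(brute))  # the two numbers should agree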
Example No. 3
    def api_case(self, axis=None, keepdim=False):
        out_ref = ref_logsumexp(self.x, axis, keepdim)
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.fluid.data('X', self.shape)
            out = paddle.logsumexp(x, axis, keepdim)
            exe = paddle.static.Executor(self.place)
            res = exe.run(feed={'X': self.x}, fetch_list=[out])
        self.assertTrue(np.allclose(res[0], out_ref))

        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.x)
        out = paddle.logsumexp(x, axis, keepdim)
        self.assertTrue(np.allclose(out.numpy(), out_ref))
        paddle.enable_static()
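ref_logsumexp is defined elsewhere in the test file; a minimal numpy reference with the same call signature (an assumption, not the verbatim helper) could look like:

import numpy as np

def ref_logsumexp(x, axis=None, keepdim=False):
    # log(sum(exp(x))) reduced over `axis`; adequate for the small positive
    # test inputs, no max-subtraction for numerical stability.
    if isinstance(axis, list):
        axis = tuple(axis)
    return np.log(np.exp(x).sum(axis=axis, keepdims=keepdim))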
Example No. 4
    def forward(self, inputs):
        """
        forward
        """
        x = paddle.logsumexp(inputs,
                             axis=self.config["axis"],
                             keepdim=self.config["keepdim"])
        return x
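A usage sketch for this wrapper; only the forward() is given above, so the host class name, its base class, and the config values below are assumptions:

import paddle

class LogsumexpLayer(paddle.nn.Layer):
    """Hypothetical host class for the forward() shown above."""

    def __init__(self, config):
        super().__init__()
        self.config = config  # e.g. {"axis": -1, "keepdim": True}

    def forward(self, inputs):
        return paddle.logsumexp(inputs,
                                axis=self.config["axis"],
                                keepdim=self.config["keepdim"])

net = LogsumexpLayer({"axis": -1, "keepdim": True})
print(net(paddle.randn([2, 3, 4])).shape)  # [2, 3, 1]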
Example No. 5
def accumulate_logprobs(d, keys_and_logprobs):
    for key, logprob in keys_and_logprobs:
        existing = d.get(key)
        if existing is None:
            d[key] = logprob
        else:
            d[key] = paddle.logsumexp(paddle.stack((logprob, existing),
                                                   axis=0),
                                      axis=0)
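A small usage example: when a key repeats, logsumexp of the stacked log-probabilities yields the log of the summed probabilities.

import math
import paddle

d = {}
accumulate_logprobs(d, [("a", paddle.to_tensor(math.log(0.2))),
                        ("b", paddle.to_tensor(math.log(0.5))),
                        ("a", paddle.to_tensor(math.log(0.3)))])
print(float(d["a"]))  # ~log(0.5), since 0.2 + 0.3 = 0.5
print(float(d["b"]))  # log(0.5) as well, from the single "b" entry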
Example No. 6
    def test_alias(self):
        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.x)
        out1 = paddle.logsumexp(x)
        out2 = paddle.tensor.logsumexp(x)
        out3 = paddle.tensor.math.logsumexp(x)
        out_ref = ref_logsumexp(self.x)
        for out in [out1, out2, out3]:
            self.assertTrue(np.allclose(out.numpy(), out_ref))
        paddle.enable_static()
Example No. 7
    def compute_loss_from_all_ordering(self, enc_input, example, desc_enc,
                                       debug):
        """compute loss from all ordering"""
        def get_permutations(node):
            """get permutations"""
            def traverse_tree(node):
                """traverse tree"""
                nonlocal permutations
                if isinstance(node, (list, tuple)):
                    p = itertools.permutations(range(len(node)))
                    permutations.append(list(p))
                    for child in node:
                        traverse_tree(child)
                elif isinstance(node, dict):
                    for node_name in node:
                        traverse_tree(node[node_name])

            permutations = []
            traverse_tree(node)
            return permutations

        def get_perturbed_tree(node, permutation):
            """get perturbed tree"""
            def traverse_tree(node, parent_type, parent_node):
                """traverse tree"""
                if isinstance(node, (list, tuple)):
                    nonlocal permutation
                    p_node = [node[i] for i in permutation[0]]
                    parent_node[parent_type] = p_node
                    permutation = permutation[1:]
                    for child in node:
                        traverse_tree(child, None, None)
                elif isinstance(node, dict):
                    for node_name in node:
                        traverse_tree(node[node_name], node_name, node)

            node = copy.deepcopy(node)
            traverse_tree(node, None, None)
            return node

        orig_tree = example.tree
        permutations = get_permutations(orig_tree)
        products = itertools.product(*permutations)
        loss_list = []
        for product in products:
            tree = get_perturbed_tree(orig_tree, product)
            example.tree = tree
            loss = self.compute_mle_loss(enc_input, example, desc_enc)
            loss_list.append(loss)
        example.tree = orig_tree
        loss_v = paddle.stack(loss_list, 0)
        return paddle.logsumexp(loss_v, 0)
Example No. 8
    def pointer_choice(self, node_type, logits, attention_logits):
        """pointer_choice"""
        # Group them based on pointer map
        pointer_logprobs = self.model.pointer_infer(node_type, logits)
        pointer_map = self.desc_enc.pointer_maps.get(node_type)
        if not pointer_map:
            return pointer_logprobs

        pointer_logprobs = dict(pointer_logprobs)
        return [
            (orig_index, paddle.logsumexp(
                            paddle.stack(tuple(pointer_logprobs[i] for i in mapped_indices), axis=0),
                            axis=0))
            for orig_index, mapped_indices in pointer_map.items()
        ]
Example No. 9
    def gen_token_loss(self, output, gen_logodds, token, desc_enc):
        """gen token loss"""
        # token_idx shape: batch (=1), LongTensor
        token_idx = self._index(self.terminal_vocab, token)
        # action_emb shape: batch (=1) x emb_size
        action_emb = self.terminal_embedding(token_idx)

        # +unk, +in desc: copy
        # +unk, -in desc: gen (an unk token)
        # -unk, +in desc: copy, gen
        # -unk, -in desc: gen
        # gen_logodds shape: batch (=1)
        desc_locs = desc_enc.find_word_occurrences(token)
        if desc_locs:
            # copy: if the token appears in the description at least once
            # copy_loc_logits shape: batch (=1) x desc length
            copy_loc_logits = self.copy_pointer(output, desc_enc.memory)
            copy_logprob = (
                # log p(copy | output)
                # shape: batch (=1)
                paddle.nn.functional.log_sigmoid(-gen_logodds) -
                # xent_loss: -log p(location | output)
                # TODO: sum the probability of all occurrences
                # shape: batch (=1)
                self.xent_loss(copy_loc_logits, self._tensor(desc_locs[0:1])))
        else:
            copy_logprob = None

        # gen: ~(unk & in desc), equivalent to  ~unk | ~in desc
        if token in self.terminal_vocab or copy_logprob is None:
            token_logits = self.terminal_logits(output)
            # shape: batch (=1)
            gen_logprob = (
                # log p(gen | output)
                # shape: batch (=1)
                paddle.nn.functional.log_sigmoid(gen_logodds) -
                # xent_loss: -log p(token | output)
                # shape: batch (=1)
                self.xent_loss(token_logits, token_idx))
        else:
            gen_logprob = None

        # loss should be -log p(...), so negate
        loss_piece = -paddle.logsumexp(
            maybe_stack([copy_logprob, gen_logprob], axis=1), axis=1)
        return loss_piece
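maybe_stack is defined elsewhere in this codebase; judging from the call above, it stacks whichever of its inputs are not None so that logsumexp can marginalize over the available options. A hedged sketch of that assumed behaviour:

import paddle

def maybe_stack(items, axis=0):
    # Assumed behaviour: drop None entries and stack the remaining tensors,
    # so the logsumexp above sees one column per available log-probability.
    present = [item for item in items if item is not None]
    return paddle.stack(present, axis=axis)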
Example No. 10
    def test_dygraph(self):
        with fluid.dygraph.guard():
            np_x = np.random.uniform(0.1, 1, [123]).astype(np.float32)
            x = fluid.dygraph.to_variable(np_x)
            self.assertTrue(
                np.allclose(
                    paddle.logsumexp(x).numpy(), np.log(np.sum(np.exp(np_x)))))

            np_x = np.random.uniform(0.1, 1, [2, 3, 4]).astype(np.float32)
            x = fluid.dygraph.to_variable(np_x)
            self.assertTrue(
                np.allclose(
                    paddle.logsumexp(x, axis=[1, 2]).numpy(),
                    np.log(np.sum(np.exp(np_x), axis=(1, 2)))))

            np_x = np.random.uniform(0.1, 1, [2, 3, 4]).astype(np.float32)
            x = fluid.dygraph.to_variable(np_x)
            self.assertTrue(
                np.allclose(
                    paddle.logsumexp(x, axis=[2]).numpy(),
                    np.log(np.sum(np.exp(np_x), axis=(2)))))

            np_x = np.random.uniform(0.1, 1, [2, 3, 4]).astype(np.float32)
            x = fluid.dygraph.to_variable(np_x)
            self.assertTrue(
                np.allclose(
                    paddle.logsumexp(x, keepdim=True).numpy(),
                    np.log(np.sum(np.exp(np_x), keepdims=True))))

            np_x = np.random.uniform(0.1, 1, [2, 3, 4]).astype(np.float32)
            x = fluid.dygraph.to_variable(np_x)
            # Full reduction over all axes (the default), compared with numpy.
            out = paddle.logsumexp(x)
            self.assertTrue(
                np.allclose(out.numpy(), np.log(np.sum(np.exp(np_x)))))
Example No. 11
    def __init__(
            self,
            preproc,
            #
            rule_emb_size=128,
            node_embed_size=64,
            # TODO: This should be automatically inferred from encoder
            enc_recurrent_size=768,
            recurrent_size=512,
            dropout=0.,
            desc_attn='bahdanau',
            copy_pointer=None,
            multi_loss_type='logsumexp',
            sup_att=None,
            use_align_mat=False,
            use_align_loss=False,
            enumerate_order=False,
            loss_type="softmax"):
        """init"""
        super().__init__()
        self.preproc = preproc
        self.ast_wrapper = preproc.ast_wrapper
        self.terminal_vocab = preproc.vocab

        self.rule_emb_size = rule_emb_size
        self.node_emb_size = node_embed_size
        self.enc_recurrent_size = enc_recurrent_size
        self.recurrent_size = recurrent_size

        self.rules_index = {
            v: idx
            for idx, v in enumerate(self.preproc.all_rules)
        }
        self.use_align_mat = use_align_mat
        self.use_align_loss = use_align_loss
        self.enumerate_order = enumerate_order

        if use_align_mat:
            self.compute_align_loss = lambda *args: align_dec_func.compute_align_loss(
                self, *args)
            self.compute_pointer_with_align = lambda *args: align_dec_func.compute_pointer_with_align(
                self, *args)

        if self.preproc.use_seq_elem_rules:
            self.node_type_vocab = vocab.Vocab(
                sorted(self.preproc.primitive_types) +
                sorted(self.ast_wrapper.custom_primitive_types) +
                sorted(self.preproc.sum_type_constructors.keys()) +
                sorted(self.preproc.field_presence_infos.keys()) +
                sorted(self.preproc.seq_lengths.keys()),
                special_elems=())
        else:
            self.node_type_vocab = vocab.Vocab(
                sorted(self.preproc.primitive_types) +
                sorted(self.ast_wrapper.custom_primitive_types) +
                sorted(self.ast_wrapper.sum_types.keys()) +
                sorted(self.ast_wrapper.singular_types.keys()) +
                sorted(self.preproc.seq_lengths.keys()),
                special_elems=())

        self.state_update = paddle.nn.LSTMCell(
            input_size=self.rule_emb_size * 2 + self.enc_recurrent_size +
            self.recurrent_size + self.node_emb_size,
            hidden_size=self.recurrent_size)
        #dropout=dropout)

        self.attn_type = desc_attn
        if desc_attn == 'bahdanau':
            self.desc_attn = attention.BahdanauAttention(
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size,
                proj_size=50)
        elif desc_attn == 'mha':
            self.desc_attn = attention.MultiHeadedAttention(
                h=8,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
        elif desc_attn == 'mha-1h':
            self.desc_attn = attention.MultiHeadedAttention(
                h=1,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
        elif desc_attn == 'sep':
            self.question_attn = attention.MultiHeadedAttention(
                h=1,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
            self.schema_attn = attention.MultiHeadedAttention(
                h=1,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
        else:
            # TODO: Figure out how to get right sizes (query, value) to module
            self.desc_attn = desc_attn
        self.sup_att = sup_att

        self.rule_logits = paddle.nn.Sequential(
            paddle.nn.Linear(self.recurrent_size, self.rule_emb_size),
            paddle.nn.Tanh(),
            paddle.nn.Linear(self.rule_emb_size, len(self.rules_index)))
        self.rule_embedding = paddle.nn.Embedding(
            num_embeddings=len(self.rules_index),
            embedding_dim=self.rule_emb_size)

        self.gen_logodds = paddle.nn.Linear(self.recurrent_size, 1)
        self.terminal_logits = paddle.nn.Sequential(
            paddle.nn.Linear(self.recurrent_size, self.rule_emb_size),
            paddle.nn.Tanh(),
            paddle.nn.Linear(self.rule_emb_size, len(self.terminal_vocab)))
        self.terminal_embedding = paddle.nn.Embedding(
            num_embeddings=len(self.terminal_vocab),
            embedding_dim=self.rule_emb_size)
        if copy_pointer is None:
            self.copy_pointer = attention.BahdanauPointer(
                query_size=self.recurrent_size,
                key_size=self.enc_recurrent_size,
                proj_size=50)
        else:
            # TODO: Figure out how to get right sizes (query, key) to module
            self.copy_pointer = copy_pointer
        if multi_loss_type == 'logsumexp':
            self.multi_loss_reduction = lambda logprobs: -paddle.logsumexp(
                logprobs, axis=1)
        elif multi_loss_type == 'mean':
            self.multi_loss_reduction = lambda logprobs: -paddle.mean(
                logprobs, axis=1)

        self.pointers = {}
        self.pointer_action_emb_proj = {}
        for pointer_type in self.preproc.grammar.pointers:
            self.pointers[pointer_type] = attention.ScaledDotProductPointer(
                query_size=self.recurrent_size,
                key_size=self.enc_recurrent_size)
            self.pointer_action_emb_proj[pointer_type] = paddle.nn.Linear(
                self.enc_recurrent_size, self.rule_emb_size)
            setattr(self, pointer_type + '_pointer',
                    self.pointers[pointer_type])
            setattr(self, pointer_type + '_action_emb_proj',
                    self.pointer_action_emb_proj[pointer_type])

        self.node_type_embedding = paddle.nn.Embedding(
            num_embeddings=len(self.node_type_vocab),
            embedding_dim=self.node_emb_size)

        # TODO batching
        self.zero_rule_emb = paddle.zeros([1, self.rule_emb_size])
        self.zero_recurrent_emb = paddle.zeros([1, self.recurrent_size])
        if loss_type == "softmax":
            self.xent_loss = paddle.nn.CrossEntropyLoss(reduction='none')
        elif loss_type == "entmax":
            raise ValueError("entmax is not supported")
            #self.xent_loss = entmax.entmax15_loss
        elif loss_type == "sparsemax":
            raise ValueError("sparsemax is not supported")
            #self.xent_loss = entmax.sparsemax_loss
        elif loss_type == "label_smooth":
            self.xent_loss = self.label_smooth_loss
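For reference, a standalone toy comparison of the two multi_loss_type reductions configured above (the probabilities are made up for illustration):

import paddle

# One training example with two candidate log-probabilities.
logprobs = paddle.log(paddle.to_tensor([[0.1, 0.3]]))

# 'logsumexp' reduction: -log(0.1 + 0.3), the candidates' probabilities are summed.
print(float(-paddle.logsumexp(logprobs, axis=1)))  # ~0.916
# 'mean' reduction: the negative average log-probability.
print(float(-paddle.mean(logprobs, axis=1)))       # ~1.753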
Example No. 12
    def forward(self, inputs):
        """
        forward
        """
        x = paddle.logsumexp(inputs, keepdim=self.keepdim, axis=self.axis)
        return x
Example No. 13
def logsumexp_wrapper(x, axis=None, keepdim=False, allreduce=False):
    if allreduce:
        return paddle.logsumexp(x, None, keepdim)
    return paddle.logsumexp(x, axis, keepdim)
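A quick usage check of the wrapper against a numpy reference:

import numpy as np
import paddle

x_np = np.random.rand(2, 3).astype("float32")
x = paddle.to_tensor(x_np)

out = logsumexp_wrapper(x, axis=1)                      # reduce along axis 1
print(np.allclose(out.numpy(), np.log(np.exp(x_np).sum(axis=1))))  # True

out_all = logsumexp_wrapper(x, axis=1, allreduce=True)  # axis is ignored, full reduction
print(np.allclose(out_all.numpy(), np.log(np.exp(x_np).sum())))    # True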