Example 1
	def _algorithm(head, rel, tail):
		""" graph embedding similarity algorithm method """
		head = layers.l2_normalize(head, axis=-1)
		rel = layers.l2_normalize(rel, axis=-1)
		tail = layers.l2_normalize(tail, axis=-1)
		score = head + rel - tail
		return score
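For comparison, here is a minimal NumPy sketch of the same TransE-style scoring on toy shapes; the l2_normalize helper below is an illustrative stand-in for fluid.layers.l2_normalize, not the Paddle API itself.

import numpy as np

def l2_normalize(x, axis=-1, eps=1e-12):
    # divide each slice along `axis` by its L2 norm
    return x / (np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True)) + eps)

head = l2_normalize(np.random.randn(4, 16))
rel = l2_normalize(np.random.randn(4, 16))
tail = l2_normalize(np.random.randn(4, 16))
score = head + rel - tail  # smaller norm of `score` => more plausible (head, rel, tail) triple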
Example 2
    def forward(self, weight):
        weight_mat = L.reshape(weight, (weight.shape[0], -1))
        with fluid.dygraph.no_grad():
            for i in range(self.power_iters):
                self.weight_v.set_value(
                    L.l2_normalize(
                        L.matmul(weight_mat,
                                 self.weight_u,
                                 transpose_x=True,
                                 transpose_y=False),
                        axis=0,
                        epsilon=self.eps,
                    ))

                self.weight_u.set_value(
                    L.l2_normalize(
                        L.matmul(weight_mat,
                                 self.weight_v,
                                 transpose_x=False,
                                 transpose_y=False),
                        axis=0,
                        epsilon=self.eps,
                    ))
        sigma = L.matmul(self.weight_u, L.matmul(weight_mat, self.weight_v))
        norm_weight = L.elementwise_div(weight, sigma)
        return norm_weight
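As a sanity check on the power iteration above, this NumPy sketch (a toy standalone version, not the original class) shows the estimated sigma converging to the largest singular value of the reshaped weight:

import numpy as np

W = np.random.randn(8, 5)                  # stands in for weight_mat
u = np.random.randn(8)
for _ in range(50):                        # power iteration
    v = W.T @ u
    v /= np.linalg.norm(v) + 1e-12
    u = W @ v
    u /= np.linalg.norm(u) + 1e-12
sigma = u @ W @ v                          # estimate of the top singular value
print(sigma, np.linalg.svd(W, compute_uv=False)[0])  # the two should agree closely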
Example 3
def compute_l2_normalized_weight(v, g, dim):
    shape = v.shape
    ndim = len(shape)

    if dim is None:
        v_normalized = v / (F.sqrt(F.reduce_sum(F.square(v))) + 1e-12)
    elif dim == 0:
        param_matrix = F.reshape(v, (shape[0], np.prod(shape[1:])))
        v_normalized = F.l2_normalize(param_matrix, axis=1)
    elif dim == -1 or dim == ndim - 1:
        param_matrix = F.reshape(v, (np.prod(shape[:-1]), shape[-1]))
        v_normalized = F.l2_normalize(param_matrix, axis=0)
    else:
        perm = list(range(ndim))
        perm[0] = dim
        perm[dim] = 0
        transposed_param = F.transpose(v, perm)
        param_matrix = F.reshape(
            transposed_param,
            (transposed_param.shape[0], np.prod(transposed_param.shape[1:])))
        v_normalized = F.l2_normalize(param_matrix, axis=1)
        v_normalized = F.reshape(v_normalized, transposed_param.shape)
        v_normalized = F.transpose(v_normalized, perm)
    v_normalized = F.reshape(v_normalized, shape)
    weight = F.elementwise_mul(v_normalized,
                               g,
                               axis=dim if dim is not None else -1)
    return weight
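To see what the decomposition computes, here is a small NumPy check of the weight-norm identity w = g * v / ||v|| along dim=0 (an illustrative sketch, not the Paddle implementation):

import numpy as np

v = np.random.randn(6, 3, 3)                # direction tensor
g = np.random.randn(6)                      # per-row magnitude, broadcast along dim=0
flat = v.reshape(v.shape[0], -1)
v_hat = (flat / np.linalg.norm(flat, axis=1, keepdims=True)).reshape(v.shape)
w = v_hat * g[:, None, None]                # same effect as elementwise_mul(..., axis=0)
print(np.linalg.norm(w.reshape(6, -1), axis=1))  # equals abs(g) row by row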
Example 4
    def take_final_feature(self, feature, index, name):
        """take final feature"""
        feat = L.gather(feature, index, overwrite=False)

        ernie_config = self.config.ernie_config
        ernie = ErnieGraphModel(src_ids=feat,
                                config=ernie_config,
                                slot_seqlen=self.config.max_seqlen,
                                name="student_")
        feat = ernie.get_pooled_output()
        fc_lr = self.config.lr / 0.001  # ParamAttr.learning_rate acts as a multiplier on the global lr
        feat = L.fc(
            feat,
            self.config.hidden_size,
            act="relu",
            param_attr=F.ParamAttr(name=name + "_l", learning_rate=fc_lr),
        )
        feat = L.l2_normalize(feat, axis=1)

        if self.config.final_fc:
            feat = L.fc(feat,
                        self.config.hidden_size,
                        param_attr=F.ParamAttr(name=name + '_w'),
                        bias_attr=F.ParamAttr(name=name + '_b'))

        if self.config.final_l2_norm:
            feat = L.l2_normalize(feat, axis=1)
        return feat
Example 5
    def _update_u(self):
        w = self.weight
        u = self.weight_u

        if len(w.shape) == 4:
            _w = layers.transpose(w, [2, 3, 1, 0])
            _w = layers.reshape(_w, [-1, _w.shape[-1]])
        else:
            _w = layers.reshape(w, [-1, w.shape[-1]])
        singular_value = "left" if _w.shape[0] <= _w.shape[1] else "right"
        norm_dim = 0 if _w.shape[0] <= _w.shape[1] else 1
        for _ in range(self.power_iterations):
            if singular_value == "left":
                v = layers.l2_normalize(layers.matmul(_w, u, transpose_x=True),
                                        axis=norm_dim)
                u = layers.l2_normalize(layers.matmul(_w, v), axis=norm_dim)
            else:
                v = layers.l2_normalize(layers.matmul(u, _w, transpose_y=True),
                                        axis=norm_dim)
                u = layers.l2_normalize(layers.matmul(v, _w), axis=norm_dim)

        if singular_value == "left":
            sigma = layers.matmul(layers.matmul(u, _w, transpose_x=True), v)
        else:
            sigma = layers.matmul(layers.matmul(v, _w), u, transpose_y=True)
        _w = w / sigma.detach()
        setattr(self.module, self.name,
                _w.detach())  # setattr(self.module, self.name, _w)
Example 6
def _weight_norm(v, g, dim):
    shape = v.shape
    ndims = len(shape)

    if dim is None:
        v_normalized = v / (F.sqrt(F.reduce_sum(F.square(v))) + 1e-12)
    elif dim == 0:
        p_matrix = F.reshape(v, (shape[0], -1))
        v_normalized = F.l2_normalize(p_matrix, axis=1)
        v_normalized = F.reshape(v_normalized, shape)
    elif dim == -1 or dim == ndims - 1:
        p_matrix = F.reshape(v, (-1, shape[-1]))
        v_normalized = F.l2_normalize(p_matrix, axis=0)
        v_normalized = F.reshape(v_normalized, shape)
    else:
        perm = list(range(ndims))
        perm[0] = dim
        perm[dim] = 0
        p_transposed = F.transpose(v, perm)
        transposed_shape = p_transposed.shape
        p_matrix = F.reshape(p_transposed, (p_transposed.shape[0], -1))
        v_normalized = F.l2_normalize(p_matrix, axis=1)
        v_normalized = F.reshape(v_normalized, transposed_shape)
        v_normalized = F.transpose(v_normalized, perm)
    weight = F.elementwise_mul(v_normalized,
                               g,
                               axis=dim if dim is not None else -1)
    return weight
Example 7
    def safe_cosine_sim(self, x, y):
        """
        fluid.layers.cos_sim can produce NaN (division by a zero norm);
        normalizing with an epsilon first avoids that.
        """
        l2x = L.l2_normalize(x, axis=-1)
        l2y = L.l2_normalize(y, axis=-1)
        cos = L.reduce_sum(l2x * l2y, dim=1, keep_dim=True)
        return cos
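The point of normalize-then-dot is that the epsilon inside l2_normalize keeps a zero vector from producing 0/0. A NumPy sketch of the difference (the helper below is an assumed stand-in for L.l2_normalize):

import numpy as np

def l2_normalize(x, axis=-1, eps=1e-12):
    return x / (np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True)) + eps)

x = np.array([[1.0, 0.0], [0.0, 0.0]])      # second row is all zeros
y = np.array([[1.0, 0.0], [1.0, 1.0]])
with np.errstate(invalid="ignore"):
    naive = np.sum(x * y, -1) / (np.linalg.norm(x, axis=-1) * np.linalg.norm(y, axis=-1))
safe = np.sum(l2_normalize(x) * l2_normalize(y), -1)
print(naive)  # [1. nan]
print(safe)   # [~1. 0.]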
Example 8
    def listwise_hinge_loss(self):
        """Listwise hinge loss: in-batch negatives with margin 0.3."""
        self.poi_repr = L.l2_normalize(self.poi_repr, -1)
        self.query_repr = L.l2_normalize(self.query_repr, -1)
        pos_logits = L.reduce_sum(self.query_repr * self.poi_repr, -1, keep_dim=True)
        neg_logits = L.matmul(self.query_repr, self.poi_repr, transpose_y=True)
        self.loss = L.reduce_mean(L.relu(neg_logits - pos_logits + 0.3))
        self.acc = L.accuracy(L.softmax(neg_logits), self.labels)
        self.metrics = [self.loss, self.acc]
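The in-batch construction above uses every other poi in the batch as a negative for each query; the diagonal of neg_logits coincides with pos_logits. A compact NumPy sketch of the same hinge objective with margin 0.3 (toy shapes, aligned query/poi rows assumed):

import numpy as np

q = np.random.randn(5, 8)
p = np.random.randn(5, 8)
q /= np.linalg.norm(q, axis=1, keepdims=True)
p /= np.linalg.norm(p, axis=1, keepdims=True)
pos = np.sum(q * p, axis=-1, keepdims=True)  # (5, 1) matched pairs
neg = q @ p.T                                # (5, 5) all query-poi pairs
loss = np.maximum(neg - pos + 0.3, 0.0).mean()
print(loss)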
Example 9
	def test_forward(self):
		entity_embedding, relation_embedding = self.create_share_variables()
		entity = layers.l2_normalize(entity_embedding, axis=-1)
		relation = layers.l2_normalize(relation_embedding, axis=-1)
		head_vec = self.lookup_table(self.test_input[0], entity)
		rel_vec = self.lookup_table(self.test_input[1], relation)
		tail_vec = self.lookup_table(self.test_input[2], entity)
		id_replace_head = layers.reduce_sum(layers.abs(entity + rel_vec - tail_vec), dim=1)
		id_replace_tail = layers.reduce_sum(layers.abs(entity - rel_vec - head_vec), dim=1)
		return [id_replace_head, id_replace_tail]
Example 10
        def erniesage_v2_aggregator(gw, feature, hidden_size, act, initializer,
                                    learning_rate, name):
            feature = L.unsqueeze(feature, [-1])
            msg = gw.send(ernie_send, nfeat_list=[("term_ids", feature)])
            neigh_feature = gw.recv(
                msg,
                lambda feat: F.layers.sequence_pool(feat, pool_type="sum"))

            term_ids = feature
            cls = L.fill_constant_batch_size_like(term_ids, [-1, 1, 1],
                                                  "int64", 1)
            term_ids = L.concat([cls, term_ids], 1)
            term_ids.stop_gradient = True
            ernie = ErnieModel(term_ids,
                               L.zeros_like(term_ids),
                               config=self.config.ernie_config)
            self_feature = ernie.get_pooled_output()

            self_feature = L.fc(
                self_feature,
                hidden_size,
                act=act,
                param_attr=F.ParamAttr(name=name + "_l",
                                       learning_rate=learning_rate),
            )
            neigh_feature = L.fc(
                neigh_feature,
                hidden_size,
                act=act,
                param_attr=F.ParamAttr(name=name + "_r",
                                       learning_rate=learning_rate),
            )
            output = L.concat([self_feature, neigh_feature], axis=1)
            output = L.l2_normalize(output, axis=1)
            return output
Example 11
def msg_norm(x, msg, name):
    """Implementation of message normalization, see more information in the paper
    "DeeperGCN: All You Need to Train Deeper GCNs"
    (https://arxiv.org/pdf/2006.07739.pdf)

    Args:
        x: centre node feature (num_nodes, feature_size)
        msg: neighbor node feature (num_nodes, feature_size)
        name: name for s

    Return:
        An output tensor with shape (num_nodes, feature_size)
    """
    s = L.create_parameter(
            shape=[1],
            dtype='float32',
            default_initializer=
                fluid.initializer.ConstantInitializer(value=1.0),
            name=name + '_s_msg_norm')

    msg = L.l2_normalize(msg, axis=1)
    x_norm = L.reduce_sum(x * x, dim=1, keep_dim=True)
    x_norm = L.sqrt(x_norm)
    msg = msg * x_norm * s
    return msg
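In formula form, MsgNorm rescales the unit-normalized message by the centre node's L2 norm and the learnable scalar s: msg' = s * ||x||_2 * msg / ||msg||_2. A NumPy sketch with s fixed at its 1.0 initialization:

import numpy as np

x = np.random.randn(4, 16)                   # centre node features
msg = np.random.randn(4, 16)                 # aggregated neighbour features
s = 1.0                                      # learnable scale, initialised to 1.0 above
msg_hat = msg / (np.linalg.norm(msg, axis=1, keepdims=True) + 1e-12)
out = msg_hat * np.linalg.norm(x, axis=1, keepdims=True) * s
print(out.shape)                             # (4, 16), same as the input features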
Example 12
    def take_final_feature(self, feature, index, name):
        """take final feature"""
        feat = L.gather(feature, index, overwrite=False)

        if self.config.final_fc:
            feat = linear(feat, self.config.hidden_size, name)

        if self.config.final_l2_norm:
            feat = L.l2_normalize(feat, axis=1)
        return feat
Example 13
def graphsage_sum(feature, gw, hidden_size, name, act):
    msg = gw.send(lambda s, d, e: s["h"], nfeat_list=[("h", feature)])
    neigh_feature = gw.recv(
        msg, lambda feat: L.sequence_pool(feat, pool_type="sum"))

    self_feature = linear(feature, hidden_size, name + "_l", act)
    neigh_feature = linear(neigh_feature, hidden_size, name + "_r", act)
    output = L.concat([self_feature, neigh_feature], axis=1)
    output = L.l2_normalize(output, axis=1)
    return output
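The concat-then-normalize pattern is the GraphSAGE recipe: transform self and pooled neighbour features separately, concatenate, and project the result onto the unit sphere. A NumPy sketch of just that final step:

import numpy as np

self_feature = np.random.randn(4, 8)
neigh_feature = np.random.randn(4, 8)
output = np.concatenate([self_feature, neigh_feature], axis=1)
output /= np.linalg.norm(output, axis=1, keepdims=True) + 1e-12
print(np.linalg.norm(output, axis=1))        # every row has norm ~1.0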
Example 14
    def take_final_feature(self, feature, index, name):
        """take final feature"""
        feat = L.gather(feature, index, overwrite=False)

        if self.config.final_fc:
            feat = L.fc(feat,
                        self.config.hidden_size,
                        param_attr=F.ParamAttr(name=name + '_w'),
                        bias_attr=F.ParamAttr(name=name + '_b'))

        if self.config.final_l2_norm:
            feat = L.l2_normalize(feat, axis=1)
        return feat
Example 15
    def forward(self, pred, target):
        target = 1 - target[:, 0]
        batch_size, vector_size = pred.shape[0], pred.shape[1]

        pred = L.l2_normalize(pred, axis=1, epsilon=1e-10)

        square_norm = L.reduce_sum(L.square(pred), dim=1)
        dist = L.elementwise_add(-2.0 * L.matmul(pred, pred, transpose_y=True),
                                 square_norm,
                                 axis=0)
        dist = L.elementwise_add(dist, square_norm, axis=1)
        dist = L.elementwise_max(dist, L.zeros_like(dist))
        dist = L.sqrt(dist)

        # in fluid.layers.reshape, 0 keeps the corresponding input dim, -1 infers it
        ap_dist = L.reshape(dist, (0, 0, 1))
        an_dist = L.reshape(dist, (0, 1, -1))

        loss = L.expand(ap_dist, (1, 1, batch_size)) - L.expand(
            an_dist, (1, batch_size, 1)) + self.margin

        indice_equal = L.diag(
            L.fill_constant((batch_size, ), dtype='float32', value=1.0))
        indice_not_equal = 1.0 - indice_equal

        broad_matrix = L.expand(L.reshape(target, (-1, 1)),
                                (1, batch_size)) + L.expand(
                                    L.reshape(target, (1, -1)),
                                    (batch_size, 1))

        pp = L.cast(L.equal(broad_matrix, L.zeros_like(broad_matrix)),
                    dtype='float32')
        pp = L.reshape(indice_not_equal * pp, (0, 0, 1))

        pn = L.cast(L.equal(broad_matrix,
                            L.zeros_like(broad_matrix) + 1),
                    dtype='float32')
        pn = L.reshape(indice_not_equal * pn, (1, 0, -1))

        apn = L.expand(pp,
                       (1, 1, batch_size)) * L.expand(pn, (batch_size, 1, 1))

        loss = loss * L.cast(apn, dtype='float32')
        loss = L.elementwise_max(loss, L.zeros_like(loss))

        num_tri = L.reduce_sum(
            L.cast(L.greater_than(loss, L.zeros_like(loss)), dtype='float32'))

        loss = L.reduce_sum(loss) * self.loss_weight / (num_tri + 1e-16)

        return loss
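The dist block above relies on the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2*a.b, with the two elementwise_add calls broadcasting square_norm over rows and then columns. A NumPy check of that identity (illustrative only):

import numpy as np

pred = np.random.randn(6, 4)
pred /= np.linalg.norm(pred, axis=1, keepdims=True)
sq = np.sum(pred ** 2, axis=1)
dist2 = -2.0 * pred @ pred.T + sq[:, None] + sq[None, :]
dist = np.sqrt(np.maximum(dist2, 0.0))
ref = np.linalg.norm(pred[:, None, :] - pred[None, :, :], axis=-1)
print(np.allclose(dist, ref, atol=1e-6))     # True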
Example 16
    def test_forward(self):
        entity_embedding, relation_embedding, transfer_matrix = \
            self.create_share_variables()

        rel_matrix = layers.reshape(
            self.lookup_table(self.test_input[1], transfer_matrix),
            [self.hidden_size, self.hidden_size])
        entity_embedding_trans = layers.matmul(entity_embedding, rel_matrix,
                                               False, False)
        rel_vec = self.lookup_table(self.test_input[1], relation_embedding)
        entity_embedding_trans = layers.l2_normalize(entity_embedding_trans,
                                                     axis=-1)
        rel_vec = layers.l2_normalize(rel_vec, axis=-1)
        head_vec = self.lookup_table(self.test_input[0],
                                     entity_embedding_trans)
        tail_vec = self.lookup_table(self.test_input[2],
                                     entity_embedding_trans)
        id_replace_head = layers.reduce_sum(layers.abs(entity_embedding_trans +
                                                       rel_vec - tail_vec),
                                            dim=1)
        id_replace_tail = layers.reduce_sum(layers.abs(entity_embedding_trans -
                                                       rel_vec - head_vec),
                                            dim=1)
        return [id_replace_head, id_replace_tail]
Example 17
    def take_final_feature(self, feature, index, name):
        """take final feature"""
        term_ids = L.gather(feature, index, overwrite=False)

        ernie_config = self.config.ernie_config
        self.slot_seqlen = self.config.max_seqlen
        position_ids = self._build_position_ids(term_ids)
        sent_ids = self._build_sentence_ids(term_ids)

        ernie_model = ErnieModel(self.config.ernie_config, "")
        feature, _ = ernie_model(term_ids, sent_ids, position_ids)

        if self.config.final_fc:
            feature = linear(feature, self.config.hidden_size, name)

        if self.config.final_l2_norm:
            feature = L.l2_normalize(feature, axis=1)
        return feature
Example 18
    def ernie_send_aggregate(self, gw, feature, act, name):
        def ernie_send(src_feat, dst_feat, edge_feat):
            """Concatenate src and dst token ids and encode the pair with ERNIE."""

            def build_position_ids(term_ids):
                input_mask = L.cast(term_ids > 0, "int64")
                position_ids = L.cumsum(input_mask, axis=1) - 1
                return position_ids

            # input_ids
            cls = L.fill_constant_batch_size_like(src_feat["term_ids"],
                                                  [-1, 1], "int64",
                                                  self.config.cls_id)
            src_ids = L.concat([cls, src_feat["term_ids"]], 1)
            dst_ids = dst_feat["term_ids"]

            # sent_ids
            sent_ids = L.concat([L.zeros_like(src_ids),
                                 L.ones_like(dst_ids)], 1)
            term_ids = L.concat([src_ids, dst_ids], 1)

            # position_ids
            position_ids = build_position_ids(term_ids)
            ernie_model = ErnieModel(self.config.ernie_config, "")
            feature, _ = ernie_model(term_ids, sent_ids, position_ids)
            return feature

        term_ids = feature
        msg = gw.send(ernie_send, nfeat_list=[("term_ids", term_ids)])
        neigh_feature = gw.recv(
            msg, lambda feat: F.layers.sequence_pool(feat, pool_type="sum"))

        cls = L.fill_constant_batch_size_like(term_ids, [-1, 1], "int64",
                                              self.config.cls_id)
        term_ids = L.concat([cls, term_ids], 1)
        ernie_model = ErnieModel(self.config.ernie_config, "")
        self_feature, _ = ernie_model(term_ids)

        hidden_size = self.config.hidden_size
        self_feature = linear(self_feature, hidden_size, name + "_l", act)
        neigh_feature = linear(neigh_feature, hidden_size, name + "_r", act)
        output = L.concat([self_feature, neigh_feature], axis=1)
        output = L.l2_normalize(output, axis=1)
        return output
Example 19
    def _make_params(self):
        # Paddle linear weights use TF-style layout; conv weights use PyTorch-style layout.
        w = getattr(self.module, self.name)

        if len(w.shape) == 4:
            _w = layers.transpose(w, [2, 3, 1, 0])
            _w = layers.reshape(_w, [-1, _w.shape[-1]])
        else:
            _w = layers.reshape(w, [-1, w.shape[-1]])
        singular_value = "left" if _w.shape[0] <= _w.shape[1] else "right"
        norm_dim = 0 if _w.shape[0] <= _w.shape[1] else 1
        u_shape = (_w.shape[0],
                   1) if singular_value == "left" else (1, _w.shape[-1])

        u = self.create_parameter(shape=u_shape,
                                  default_initializer=Normal(0, 1))
        u.stop_gradient = True
        u.set_value(layers.l2_normalize(u, axis=norm_dim))

        del self.module._parameters[self.name]
        self.add_parameter("weight", w)
        self.add_parameter("weight_u", u)
Example 20
def graphsage_mean(gw, feature, hidden_size, act, initializer, learning_rate, name):
    """GraphSAGE mean aggregator: mean-pool neighbours, transform self and
    neighbour features, concat, then L2-normalize."""
    msg = gw.send(copy_send, nfeat_list=[("h", feature)])
    neigh_feature = gw.recv(msg, mean_recv)
    self_feature = feature
    self_feature = L.fc(self_feature,
                        hidden_size,
                        act=act,
                        param_attr=fluid.ParamAttr(name=name + "_l.w_0",
                                                   initializer=initializer,
                                                   learning_rate=learning_rate),
                        bias_attr=name + "_l.b_0")
    neigh_feature = L.fc(neigh_feature,
                         hidden_size,
                         act=act,
                         param_attr=fluid.ParamAttr(name=name + "_r.w_0",
                                                    initializer=initializer,
                                                    learning_rate=learning_rate),
                         bias_attr=name + "_r.b_0")
    output = L.concat([self_feature, neigh_feature], axis=1)
    output = L.l2_normalize(output, axis=1)
    return output
Example 21
    def test_l2_normalize(self):
        program = Program()
        with program_guard(program):
            x = layers.data(name='x', shape=[8, 7, 10], dtype="float32")
            output = layers.l2_normalize(x, axis=1)
Example 22
    def forward(self):
        """Build the GATNE net.
        """
        param_attr_init = fluid.initializer.Uniform(
            low=-1.0, high=1.0, seed=np.random.randint(100))
        embed_param_attrs = fluid.ParamAttr(name='Base_node_embed',
                                            initializer=param_attr_init)

        # node_embeddings
        base_node_embed = fl.embedding(
            input=fl.reshape(self.train_inputs, shape=[-1, 1]),
            size=[self.num_nodes, self.embedding_size],
            param_attr=embed_param_attrs)

        node_features = []
        for edge_type in self.edge_types:
            param_attr_init = fluid.initializer.Uniform(
                low=-1.0, high=1.0, seed=np.random.randint(100))
            embed_param_attrs = fluid.ParamAttr(name='%s_node_embed' %
                                                edge_type,
                                                initializer=param_attr_init)

            features = fl.embedding(
                input=self.gw[edge_type].node_feat['index'],
                size=[self.num_nodes, self.embedding_u_size],
                param_attr=embed_param_attrs)

            node_features.append(features)

        # mp_output: list of embedding(self.num_nodes, dim)
        mp_output = self.message_passing(self.gw, self.edge_types,
                                         node_features)

        # U : (num_type[m], num_nodes, dim[s])
        node_type_embed = fl.stack(mp_output, axis=0)

        # U : (num_nodes, num_type[m], dim[s])
        node_type_embed = fl.transpose(node_type_embed, perm=[1, 0, 2])

        # gather node_type_embed from train_inputs
        node_type_embed = fl.gather(node_type_embed, self.train_inputs)

        # M_r
        trans_weights = fl.create_parameter(
            shape=[
                self.edge_type_count, self.embedding_u_size,
                self.embedding_size // self.att_head
            ],
            attr=fluid.initializer.TruncatedNormalInitializer(
                loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)),
            dtype='float32',
            name='trans_w')

        # W_r
        trans_weights_s1 = fl.create_parameter(
            shape=[self.edge_type_count, self.embedding_u_size, self.dim_a],
            attr=fluid.initializer.TruncatedNormalInitializer(
                loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)),
            dtype='float32',
            name='trans_w_s1')

        # w_r
        trans_weights_s2 = fl.create_parameter(
            shape=[self.edge_type_count, self.dim_a, self.att_head],
            attr=fluid.initializer.TruncatedNormalInitializer(
                loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)),
            dtype='float32',
            name='trans_w_s2')

        trans_w = fl.gather(trans_weights, self.train_types)
        trans_w_s1 = fl.gather(trans_weights_s1, self.train_types)
        trans_w_s2 = fl.gather(trans_weights_s2, self.train_types)

        attention = self.attention(node_type_embed, trans_w_s1, trans_w_s2)
        node_type_embed = fl.matmul(attention, node_type_embed)
        node_embed = base_node_embed + fl.reshape(
            fl.matmul(node_type_embed, trans_w), [-1, self.embedding_size])

        self.last_node_embed = fl.l2_normalize(node_embed, axis=1)

        nce_weight_initializer = fluid.initializer.TruncatedNormalInitializer(
            loc=0.0, scale=1.0 / math.sqrt(self.embedding_size))
        nce_weight_attrs = fluid.ParamAttr(name='nce_weight',
                                           initializer=nce_weight_initializer)

        weight_pos = fl.embedding(input=self.train_labels,
                                  size=[self.num_nodes, self.embedding_size],
                                  param_attr=nce_weight_attrs)
        weight_neg = fl.embedding(input=self.train_negs,
                                  size=[self.num_nodes, self.embedding_size],
                                  param_attr=nce_weight_attrs)
        tmp_node_embed = fl.unsqueeze(self.last_node_embed, axes=[1])
        pos_logits = fl.matmul(tmp_node_embed, weight_pos,
                               transpose_y=True)  # [B, 1, 1]

        neg_logits = fl.matmul(tmp_node_embed, weight_neg,
                               transpose_y=True)  # [B, 1, neg_num]

        pos_score = fl.squeeze(pos_logits, axes=[1])
        pos_score = fl.clip(pos_score, min=-10, max=10)
        pos_score = -1.0 * fl.logsigmoid(pos_score)

        neg_score = fl.squeeze(neg_logits, axes=[1])
        neg_score = fl.clip(neg_score, min=-10, max=10)
        neg_score = -1.0 * fl.logsigmoid(-1.0 * neg_score)

        neg_score = fl.reduce_sum(neg_score, dim=1, keep_dim=True)
        self.loss = fl.reduce_mean(pos_score + neg_score)
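The final loss is the standard negative-sampling objective, -log sigmoid(pos) - sum log sigmoid(-neg), with logits clipped to [-10, 10] for numerical stability. A NumPy sketch of that objective on toy logits:

import numpy as np

def logsigmoid(z):
    # numerically stable log(sigmoid(z))
    return -np.logaddexp(0.0, -z)

pos = np.clip(np.random.randn(8, 1), -10, 10)   # [B, 1] positive logits
neg = np.clip(np.random.randn(8, 5), -10, 10)   # [B, neg_num] negative logits
loss = np.mean(-logsigmoid(pos) + np.sum(-logsigmoid(-neg), axis=1, keepdims=True))
print(loss)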
Example 23
    def build_model(self, args, task):
        """ build graph model"""
        self.query_geo = L.data(name="query_geo", shape=[-1, 80], dtype="float32")
        self.holder_list.append(self.query_geo)
        self.poi_geo = L.data(name="poi_geo", shape=[-1, 40], dtype="float32")
        self.holder_list.append(self.poi_geo)

        if task != "predict_query":
            self.city_id = L.data(name="city_id", shape=[-1], dtype="int64")
            self.holder_list.append(self.city_id)

            poi_city_embed = self.city_embedding(self.city_id)

            self.poi_index = L.data(name="poi_index", shape=[-1], dtype="int64")
            self.holder_list.append(self.poi_index)

        if task != "predict_poi":
            self.query_city = L.data(name="query_city", shape=[-1], dtype="int64")
            self.holder_list.append(self.query_city)
            query_city_embed = self.city_embedding(self.query_city)

            self.query_index = L.data(
                name="query_index", shape=[-1], dtype="int64")
            self.holder_list.append(self.query_index)


        if task == 'pointwise':
            self.labels = L.data(name="labels", shape=[-1], dtype="float32")
            self.holder_list.append(self.labels)
        elif task == "pairwise":
            self.labels = L.data(name="labels", shape=[-1], dtype="float32")
            self.holder_list.append(self.labels)
            self.labels = L.reshape(self.labels, [-1, 1])
            self.labels.stop_gradient = True
        elif task == "listwise" or task == "listwise_hinge":
            self.labels = L.data(name="labels", shape=[-1], dtype="int64")
            self.holder_list.append(self.labels)
            self.labels = L.reshape(self.labels, [-1, 1])
            self.labels.stop_gradient = True
        elif task == "predict_query":
            pass
        elif task == "predict_poi":
            pass

        src_ids = self.graph_wrapper.node_feat["src_ids"]
        pos_ids = self.graph_wrapper.node_feat["pos_ids"]
        sent_ids = self.graph_wrapper.node_feat["sent_ids"]
        input_mask = self.graph_wrapper.node_feat["input_mask"]

        src_ids = ernie_unsqueeze(src_ids)
        pos_ids = ernie_unsqueeze(pos_ids)
        sent_ids = ernie_unsqueeze(sent_ids)
        input_mask = ernie_unsqueeze(input_mask)
        task_ids = L.zeros_like(sent_ids)
        task_ids = L.cast(task_ids, dtype="int64")

        if args.model_type == "cnn":
            encoder_model = CnnModel
        elif args.model_type == "ernie":
            encoder_model = ErnieModel
        else:
            raise ValueError("model type %s does not exist." % args.model_type)

        ernie = encoder_model(
                src_ids=src_ids,
                position_ids=pos_ids,
                sentence_ids=sent_ids,
                input_mask=input_mask,
                config=self.ernie_config,
                task_ids=task_ids, )

        if task != "predict_query":
            args.max_addr_len = args.max_seq_len

            addr_src_ids = L.data(
                name='addr_src_ids',
                shape=[None, args.max_addr_len],
                dtype="int64")
            self.holder_list.append(addr_src_ids)

            addr_pos_ids = L.data(
                name='addr_pos_ids',
                shape=[None, args.max_addr_len],
                dtype="int64")
            self.holder_list.append(addr_pos_ids)

            addr_sent_ids = L.data(
                name='addr_sent_ids',
                shape=[None, args.max_addr_len],
                dtype="int64")
            self.holder_list.append(addr_sent_ids)

            addr_input_mask = L.data(
                name='addr_input_mask',
                shape=[None, args.max_addr_len],
                dtype="float32")
            self.holder_list.append(addr_input_mask)

            addr_src_ids = ernie_unsqueeze(addr_src_ids)
            addr_pos_ids = ernie_unsqueeze(addr_pos_ids)
            addr_sent_ids = ernie_unsqueeze(addr_sent_ids)
            addr_input_mask = ernie_unsqueeze(addr_input_mask)
            addr_task_ids = L.zeros_like(addr_sent_ids)
            addr_task_ids = L.cast(addr_task_ids, dtype="int64")

            addr_ernie = encoder_model(
                src_ids=addr_src_ids,
                position_ids=addr_pos_ids,
                sentence_ids=addr_sent_ids,
                input_mask=addr_input_mask,
                config=self.ernie_config,
                task_ids=addr_task_ids, )

            addr_repr = addr_ernie.get_pooled_output()

        # get first token as sentence repr
        sent_repr = ernie.get_pooled_output()

        if task != "predict_poi":
            self.query_repr = L.gather(
                sent_repr, self.query_index, overwrite=False)

            self.query_city_embed = query_city_embed
            for_concat = []
            if args.with_city:
                for_concat.append(query_city_embed)
            if args.with_geo_id:
                for_concat.append(self.query_geo)

            if len(for_concat) > 0:
                self.query_repr = L.concat(
                    [self.query_repr] + for_concat, axis=-1)

            self.query_repr = L.fc(self.query_repr,
                                   self.hidden_size,
                                   act="tanh",
                                   name="query_fc")
            self.query_city_score = L.reduce_sum(
                L.l2_normalize(self.query_city_embed, -1) *
                L.l2_normalize(self.query_repr, -1), -1)

        if task != "predict_query":
            neigh_repr = self.neighbor_aggregator(sent_repr)

            self.poi_repr = L.gather(sent_repr, self.poi_index, overwrite=False)
            for_concat = [self.poi_repr, addr_repr]
            if args.with_city:
                for_concat.append(poi_city_embed)

            if args.with_geo_id:
                for_concat.append(self.poi_geo)

            if neigh_repr is not None:
                poi_neigh_repr = L.gather(
                    neigh_repr, self.poi_index, overwrite=False)
                for_concat.append(poi_neigh_repr)

            self.poi_repr = L.concat(for_concat, axis=-1)

            self.poi_repr = L.fc(self.poi_repr,
                                 self.hidden_size,
                                 act="tanh",
                                 name="pos_fc")

        if task == "pointwise":
            self.pointwise_loss()
        elif task == "pairwise":
            self.pairwise_loss()
        elif task == "listwise":
            self.listwise_loss(args)
        elif task == "listwise_hinge":
            self.listwise_hinge_loss()