Example #1
def multipath_fully_connected(x, path_out_dim=None, use_bias=True, scope='multi_fc', reuse=None, partitioner=None):
    with tf.variable_scope(scope, reuse=reuse, partitioner=partitioner):
        batch_size, num_path, emb_dim = get_shape(x)
        if path_out_dim is None:
            path_out_dim = emb_dim
        w = get_dnn_variable('weight', shape=[1, num_path, emb_dim, path_out_dim],
                             partitioner=partitioner)
        w = tf.tile(w, [batch_size, 1, 1, 1])  # [B,V,D,D']
        # [B,V,1,D] x [B,V,D,D'] = [B,V,1,D'] -> squeeze -> [B,V,D']
        y = tf.squeeze(tf.matmul(tf.expand_dims(x, 2), w), [2])  # [B,V,D']
        if use_bias:
            b = get_dnn_variable('bias', shape=[1, num_path, path_out_dim],
                                 partitioner=partitioner,
                                 initializer=tf.zeros_initializer())
            y = y + b  # [B,V,D'] + [1,V,D']
    return y  # [B,V,D']
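A minimal, self-contained sketch of the same per-path projection with plain TensorFlow 1.x ops; tf.get_variable stands in for the project's get_dnn_variable helper and the shapes are assumed:

import tensorflow as tf  # TensorFlow 1.x

def multipath_fc_sketch(x, path_out_dim):
    # x: [B,V,D] -- one embedding per path; each path gets its own weight matrix
    num_path, emb_dim = x.shape[1].value, x.shape[2].value
    w = tf.get_variable('multi_fc_w', [1, num_path, emb_dim, path_out_dim])  # [1,V,D,D']
    w = tf.tile(w, [tf.shape(x)[0], 1, 1, 1])  # [B,V,D,D']
    # [B,V,1,D] x [B,V,D,D'] -> [B,V,1,D'] -> squeeze -> [B,V,D']
    return tf.squeeze(tf.matmul(tf.expand_dims(x, 2), w), [2])

x = tf.random_normal([8, 4, 16])             # B=8, V=4, D=16
y = multipath_fc_sketch(x, path_out_dim=16)  # [8, 4, 16]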
Example #2
def compute_prob_item(itm_emb,  # [B,T,D]
                      itm_msk,  # [B,T], tf.int64, masked where equal to zero
                      ctx_emb,  # [B,T,D']
                      num_hidden_units,
                      position_bias=False,  # better diversity when no position bias
                      partitioner=None, scope='', reuse=None):
    if position_bias:
        scope += 'prob_itm_psn_ctx'
    else:
        scope += 'prob_itm_ctx'
    with tf.variable_scope(scope, reuse=reuse, partitioner=partitioner):
        batch_size, seq_len, emb_dim = get_shape(itm_emb)
        if position_bias:
            position_emb = get_dnn_variable(name='position_emb', shape=[1, seq_len, emb_dim],
                                            initializer=get_unit_emb_initializer(emb_dim))
            position_emb = tf.tile(position_emb, [batch_size, 1, 1])
            itm_and_ctx_emb = tf.concat([itm_emb, position_emb, ctx_emb], 2)  # [B,T,D+D+D']
        else:
            itm_and_ctx_emb = tf.concat([itm_emb, ctx_emb], 2)  # [B,T,D+D']
        prob_item = fully_connected(itm_and_ctx_emb, num_hidden_units, None)  # [B,T,D'']
        prob_item = tf.nn.tanh(prob_item)
        prob_item = fully_connected(prob_item, 1, None)  # [B,T,1]
        prob_item = tf.squeeze(prob_item, [2])  # [B,T]
        prob_item = weighted_softmax(
            prob_item, tf.to_float(tf.not_equal(itm_msk, 0)), -1)  # [B,T]
    return prob_item
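The weighted_softmax helper is not part of this excerpt; judging from how it is called (logits, a 0/1 float mask as weights, and an axis), a plausible masked-softmax sketch, not the original implementation, would be:

def weighted_softmax(logits, weights, axis=-1):
    # softmax over `axis` in which positions with weight 0 contribute nothing
    exp = tf.exp(logits - tf.reduce_max(logits, axis, keepdims=True)) * weights
    return exp / (tf.reduce_sum(exp, axis, keepdims=True) + 1e-8)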
Example #3
def __init__(self, num_proto, emb_dim, temperature=0.07,
             scope='proto_router', reuse=None, partitioner=None):
    self.num_proto = num_proto
    self.emb_dim = emb_dim
    self.temperature = temperature
    self.proto_embs = get_dnn_variable(
        name='proto_embs_%d' % self.num_proto, shape=[self.num_proto, self.emb_dim],
        partitioner=partitioner, initializer=get_unit_emb_initializer(self.emb_dim),
        scope=scope, reuse=reuse)
    self.proto_embs = tf.nn.l2_normalize(self.proto_embs, -1)  # [H,D]
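The compute_prob_proto method used in Example #6 is not shown here; given the l2-normalized prototype table and the temperature above, a plausible sketch (an assumption, not the original code) is:

def compute_prob_proto(self, emb):  # emb: [N,D]
    # cosine similarity between each input embedding and each prototype, softened by the temperature
    emb = tf.nn.l2_normalize(emb, -1)
    logits = tf.matmul(emb, self.proto_embs, transpose_b=True)  # [N,H]
    return tf.nn.softmax(logits / self.temperature, -1)  # [N,H]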
Example #4
def bloom_filter_emb(ids,
                     hashes,
                     zero_pad=True,
                     mark_for_serving=True):
    # Note: FLAGS and self.ps_num refer to the enclosing model class;
    # see Example #5 for the full context.
    ids_flat = tf.reshape(ids, [-1])
    e = []
    for h in hashes:
        e.append(
            get_emb_variable(
                name=h['table_name'],
                ids=h['hash_fn'](ids_flat),
                shape=(h['bucket_size'], h['emb_dim']),
                mark_for_serving=mark_for_serving,
                initializer=get_unit_emb_initializer(FLAGS.dim)
            ))  # important: use normal, not uniform
    e = tf.concat(e, axis=1)
    if len(hashes) == 1 and hashes[0]['emb_dim'] == FLAGS.dim:
        print('bloom filter w/o fc: [%s]' %
              hashes[0]['table_name'])
    else:
        dnn_name = 'dnn__' + '__'.join(h['table_name']
                                       for h in hashes)
        dnn_in_dim = sum(h['emb_dim'] for h in hashes)
        dnn = get_dnn_variable(
            name=dnn_name,
            shape=[dnn_in_dim, FLAGS.dim],
            initializer=tf.glorot_normal_initializer(),  # important: use normal, not uniform
            partitioner=tf.min_max_variable_partitioner(
                max_partitions=self.ps_num,
                min_slice_size=FLAGS.dnn_pt_size))
        e = tf.matmul(e, dnn)
    if zero_pad:
        id_eq_zero = tf.tile(
            tf.expand_dims(tf.equal(ids_flat, 0), -1),
            [1, FLAGS.dim])
        e = tf.where(id_eq_zero, tf.zeros_like(e), e)
    e = tf.reshape(e, get_shape(ids) + [FLAGS.dim])
    return e
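The hashes argument is a list of per-hash-function specs; a hypothetical spec for a single id feature with two hash functions might look as follows (table names, bucket sizes, and the hashing scheme are illustrative assumptions, not taken from the original code):

hashes = [
    {'table_name': 'uid_hash_a', 'bucket_size': 1000003, 'emb_dim': 8,
     'hash_fn': lambda ids: tf.string_to_hash_bucket_fast(tf.as_string(ids), 1000003)},
    {'table_name': 'uid_hash_b', 'bucket_size': 999983, 'emb_dim': 8,
     'hash_fn': lambda ids: tf.string_to_hash_bucket_fast(tf.as_string(ids + 17), 999983)},
]
# uid_emb = bloom_filter_emb(ids=uid_tensor, hashes=hashes)  # -> [..., FLAGS.dim]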
Example #5
    def _build_embedding(self):
        with tf.variable_scope(name_or_scope='embedding_block',
                               partitioner=tf.min_max_variable_partitioner(
                                   max_partitions=self.ps_num,
                                   min_slice_size=FLAGS.emb_pt_size)):

            with tf.variable_scope(name_or_scope='bloom_filter'):
                #
                # We need to ensure that the final embeddings are uniformly distributed on a hyper-sphere,
                # after l2-normalization. Otherwise it will have a hard time converging at the beginning.
                # So we need to use random_normal initialization, rather than random_uniform.
                #
                def bloom_filter_emb(ids,
                                     hashes,
                                     zero_pad=True,
                                     mark_for_serving=True):
                    ids_flat = tf.reshape(ids, [-1])
                    e = []
                    for h in hashes:
                        e.append(
                            get_emb_variable(
                                name=h['table_name'],
                                ids=h['hash_fn'](ids_flat),
                                shape=(h['bucket_size'], h['emb_dim']),
                                mark_for_serving=mark_for_serving,
                                initializer=get_unit_emb_initializer(FLAGS.dim)
                            ))  # important: use normal, not uniform
                    e = tf.concat(e, axis=1)
                    if len(hashes) == 1 and hashes[0]['emb_dim'] == FLAGS.dim:
                        print('bloom filter w/o fc: [%s]' %
                              hashes[0]['table_name'])
                    else:
                        dnn_name = 'dnn__' + '__'.join(h['table_name']
                                                       for h in hashes)
                        dnn_in_dim = sum(h['emb_dim'] for h in hashes)
                        dnn = get_dnn_variable(
                            name=dnn_name,
                            shape=[dnn_in_dim, FLAGS.dim],
                            initializer=tf.glorot_normal_initializer(
                            ),  # important: use normal, not uniform
                            partitioner=tf.min_max_variable_partitioner(
                                max_partitions=self.ps_num,
                                min_slice_size=FLAGS.dnn_pt_size))
                        e = tf.matmul(e, dnn)
                    if zero_pad:
                        id_eq_zero = tf.tile(
                            tf.expand_dims(tf.equal(ids_flat, 0), -1),
                            [1, FLAGS.dim])
                        e = tf.where(id_eq_zero, tf.zeros_like(e), e)
                    e = tf.reshape(e, get_shape(ids) + [FLAGS.dim])
                    return e

                def combine_mean(emb_list):
                    assert len(emb_list) >= 2
                    return sum(emb_list) * (1.0 / len(emb_list))

                self.usr_mem_emb = bloom_filter_emb(
                    ids=self.inputs['user__uid'].var,
                    hashes=self.inputs['user__uid'].spec['hashes'])
                self.is_recent_click = tf.greater(
                    self.inputs['user__clk_st'].var, 0)  # [B,T], tf.bool
                self.is_recent_click_expand = tf.tile(
                    tf.expand_dims(self.is_recent_click, -1),
                    [1, 1, FLAGS.dim])

                clk_st_emb = bloom_filter_emb(
                    ids=tf.abs(self.inputs['user__clk_st'].var),
                    hashes=self.inputs['user__clk_st'].spec['hashes'])
                clk_rel_time_emb = bloom_filter_emb(
                    ids=tf.tile(
                        tf.expand_dims(self.inputs['user__abs_time'].var, -1),
                        [1, FLAGS.max_len]) -
                    self.inputs['user__clk_abs_time'].var,
                    hashes=self.inputs['user__clk_abs_time'].
                    spec['rel_time_hashes'])

                clk_nid_emb = bloom_filter_emb(
                    ids=self.inputs['user__clk_nid'].var,
                    hashes=self.inputs['user__clk_nid'].spec['hashes'])
                clk_uid_emb = bloom_filter_emb(
                    ids=self.inputs['user__clk_uid'].var,
                    hashes=self.inputs['user__clk_uid'].spec['hashes'])
                clk_cate_emb = bloom_filter_emb(
                    ids=self.inputs['user__clk_cate'].var,
                    hashes=self.inputs['user__clk_cate'].spec['hashes'])
                clk_cat1_emb = bloom_filter_emb(
                    ids=self.inputs['user__clk_cat1'].var,
                    hashes=self.inputs['user__clk_cat1'].spec['hashes'])
                self.clk_itm_emb = combine_mean(
                    [clk_nid_emb, clk_uid_emb, clk_cate_emb,
                     clk_cat1_emb])  # [B,T,D]

                clk_ctx_time_key = bloom_filter_emb(
                    ids=tf.concat([
                        tf.expand_dims(self.inputs['user__abs_time'].var, -1),
                        self.inputs['user__clk_abs_time'].var
                    ], 1),  # [B,1] [B,T] -> [B,1+T]
                    hashes=self.inputs['user__clk_abs_time'].
                    spec['abs_time_hashes'])  # [B,1+T,D]
                usr_ctx_time_query, clk_ctx_time_key = tf.split(
                    clk_ctx_time_key, [1, FLAGS.max_len],
                    1)  # [B,1,D], [B,T,D]
                usr_ctx_time_query = tf.squeeze(usr_ctx_time_query,
                                                [1])  # [B,D]

                prob_psn = get_dnn_variable(
                    name='position_w_%d' % FLAGS.max_len,
                    shape=[1, FLAGS.max_len],
                    initializer=get_unit_emb_initializer(FLAGS.dim))
                prob_psn = tf.tile(
                    prob_psn, [get_shape(self.usr_ids)[0], 1])  # [1,T]->[B,T]
                prob_psn = weighted_softmax(prob_psn,
                                            tf.to_float(self.is_recent_click),
                                            axis=-1)  # [B,T] along T

                clk_ctx_cate_key = combine_mean([clk_cate_emb,
                                                 clk_cat1_emb])  # [B,T,D]
                usr_ctx_cate_query = tf.squeeze(
                    tf.matmul(tf.expand_dims(prob_psn, 1), clk_ctx_cate_key),
                    [1])  # [B,1,T]x[B,T,D]->[B,1,D]->[B,D]
                self.clk_ctx_key_emb = tf.concat([
                    combine_mean([clk_st_emb, clk_rel_time_emb]),
                    clk_ctx_time_key, clk_ctx_cate_key
                ], 2)  # [B,T,3*D]
                self.usr_ctx_query_emb = tf.concat(
                    [usr_ctx_time_query, usr_ctx_cate_query], 1)  # [B,2*D]

                pos_nid_emb = bloom_filter_emb(
                    ids=self.inputs['item__nid'].var,
                    hashes=self.inputs['item__nid'].spec['hashes'],
                    mark_for_serving=False)
                pos_uid_emb = bloom_filter_emb(
                    ids=self.inputs['item__uid'].var,
                    hashes=self.inputs['item__uid'].spec['hashes'],
                    mark_for_serving=False)
                pos_cate_emb = bloom_filter_emb(
                    ids=self.inputs['item__cate'].var,
                    hashes=self.inputs['item__cate'].spec['hashes'],
                    mark_for_serving=False)
                pos_cat1_emb = bloom_filter_emb(
                    ids=self.inputs['item__cat1'].var,
                    hashes=self.inputs['item__cat1'].spec['hashes'],
                    mark_for_serving=False)
                self.pos_itm_emb = combine_mean(
                    [pos_nid_emb, pos_uid_emb, pos_cate_emb,
                     pos_cat1_emb])  # [B,D]
                self.pos_itm_emb_normalized = tf.nn.l2_normalize(
                    self.pos_itm_emb, -1)
                self.pos_cat_emb = combine_mean([pos_cate_emb,
                                                 pos_cat1_emb])  # [B,D]

                neg_nid_emb = bloom_filter_emb(
                    ids=self.neg_itm_nid,
                    hashes=self.inputs['item__nid'].spec['hashes'],
                    mark_for_serving=False)
                neg_uid_emb = bloom_filter_emb(
                    ids=self.neg_itm_uid,
                    hashes=self.inputs['item__uid'].spec['hashes'],
                    mark_for_serving=False)
                neg_cate_emb = bloom_filter_emb(
                    ids=self.neg_itm_cate,
                    hashes=self.inputs['item__cate'].spec['hashes'],
                    mark_for_serving=False)
                neg_cat1_emb = bloom_filter_emb(
                    ids=self.neg_itm_cat1,
                    hashes=self.inputs['item__cat1'].spec['hashes'],
                    mark_for_serving=False)
                self.neg_itm_emb = combine_mean(
                    [neg_nid_emb, neg_uid_emb, neg_cate_emb,
                     neg_cat1_emb])  # [Q,D]
                self.neg_cat_emb = combine_mean([neg_cate_emb,
                                                 neg_cat1_emb])  # [Q,D]
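get_unit_emb_initializer is not shown in this excerpt; consistent with the comment above about preferring normal over uniform initialization so that l2-normalized embeddings start out roughly uniform on the hyper-sphere, a plausible sketch is:

def get_unit_emb_initializer(dim):
    # assumption: Gaussian init with stddev 1/sqrt(dim), so each dim-dimensional
    # row has L2 norm close to 1 and a direction uniform on the unit hyper-sphere
    return tf.random_normal_initializer(mean=0.0, stddev=dim ** -0.5)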
Example #6
def disentangle_layer(itm_emb,  # [B,T,D]
                      itm_msk,  # [B,T], tf.int64, masked where equal to zero
                      prob_item,  # [B,T]
                      num_heads,
                      proto_router=None,
                      add_head_prior=True,
                      equalize_heads=False,
                      scope='disentangle_layer', reuse=None, partitioner=None):
    with tf.variable_scope(scope, reuse=reuse, partitioner=partitioner):
        batch_size, max_seq_len, emb_dim = get_shape(itm_emb)

        if proto_router is None:
            proto_router = ProtoRouter(num_proto=num_heads, emb_dim=emb_dim)
        assert proto_router.num_proto == num_heads

        prob_head_given_item = proto_router.compute_prob_proto(
            tf.reshape(itm_emb, [batch_size * max_seq_len, emb_dim]))  # [B*T,H]
        prob_head_given_item = tf.reshape(
            prob_head_given_item, [batch_size, max_seq_len, num_heads])  # [B,T,H]
        prob_head_given_item = tf.transpose(prob_head_given_item, [0, 2, 1])  # [B,H,T]

        # p(head, item) = p(item) * p(head | item)
        prob_item_and_head = tf.multiply(  # [B,1,T]*[B,H,T]->[B,H,T]
            tf.expand_dims(prob_item, axis=1), prob_head_given_item)
        itm_msk_expand = tf.tile(tf.expand_dims(itm_msk, 1), [1, num_heads, 1])  # [B,T]->[B,H,T]
        prob_item_and_head = tf.where(
            tf.equal(itm_msk_expand, 0), tf.zeros_like(prob_item_and_head), prob_item_and_head)

        if equalize_heads:
            # p(item | head) = p(head, item) / p(head)
            # Would it be too sensitive/responsive to heads with ONE trigger?
            prob_item_given_head = tf.div(
                prob_item_and_head, tf.reduce_sum(prob_item_and_head, -1, True) + 1e-8)  # [B,H,T]
            init_multi_emb = tf.matmul(prob_item_given_head, tf.nn.l2_normalize(itm_emb, -1))  # [B,H,D]
        else:
            init_multi_emb = tf.matmul(prob_item_and_head, tf.nn.l2_normalize(itm_emb, -1))  # [B,H,D]

        #
        # Spill-over: if no item under this head's category is present in the
        # sequence, the head's vector will mainly be composed of items (with
        # small values of prob_head_given_item for this head) from other
        # categories. As a result, its kNNs will be items from other categories.
        # This effect is sometimes useful, though, since it reuses the empty
        # heads to retrieve relevant items from other categories.
        #
        # Adding a head-specific bias avoids the spill-over effect when the
        # head is in fact empty, but then the retrieved kNNs may be too
        # irrelevant to the user and make the user unhappy with the result.
        #

        if add_head_prior:
            head_bias = get_dnn_variable(
                name='head_bias', shape=[1, num_heads, emb_dim],
                initializer=get_unit_emb_initializer(emb_dim))
            head_bias = tf.tile(head_bias, [batch_size, 1, 1])  # [B,H,D]
            out_multi_emb = tf.concat([init_multi_emb, head_bias], 2)  # [B,H,2*D]
        else:
            out_multi_emb = init_multi_emb
        out_multi_emb = tf.nn.tanh(fully_connected(out_multi_emb, emb_dim, None))
        out_multi_emb = fully_connected(out_multi_emb, emb_dim, None)
        out_multi_emb = out_multi_emb + init_multi_emb

        #
        # Don't use multipath_fully_connected if num_heads is large, because it is memory-intensive.
        #
        # out_multi_emb = init_multi_emb
        # out_multi_emb = tf.nn.tanh(multipath_fully_connected(
        #     out_multi_emb, use_bias=add_head_prior, scope='multi_head_fc1'))
        # out_multi_emb = multipath_fully_connected(
        #     out_multi_emb, use_bias=add_head_prior, scope='multi_head_fc2')
        # out_multi_emb = out_multi_emb + init_multi_emb
    return out_multi_emb  # [B,H,D]
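A hedged sketch of how Examples #2, #3 and #6 could be wired together; the tensor names, dimensions, and hyper-parameters below are assumed from the shape comments, not taken from the original call site:

# itm_emb: [B,T,D], itm_msk: [B,T] (0 = padding), ctx_emb: [B,T,D']
prob_item = compute_prob_item(itm_emb, itm_msk, ctx_emb, num_hidden_units=64)  # [B,T]
router = ProtoRouter(num_proto=8, emb_dim=itm_emb.shape[-1].value)
multi_emb = disentangle_layer(itm_emb, itm_msk, prob_item,
                              num_heads=8, proto_router=router)  # [B,8,D]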