def compute_prob_item(itm_emb,  # [B,T,D]
                      itm_msk,  # [B,T], tf.int64, mask if equal zero
                      ctx_emb,  # [B,T,D']
                      num_hidden_units,
                      position_bias=False,  # better diversity when no position bias
                      partitioner=None,
                      scope='',
                      reuse=None):
    if position_bias:
        scope += 'prob_itm_psn_ctx'
    else:
        scope += 'prob_itm_ctx'
    with tf.variable_scope(scope, reuse=reuse, partitioner=partitioner):
        batch_size, seq_len, emb_dim = get_shape(itm_emb)
        if position_bias:
            position_emb = get_dnn_variable(
                name='position_emb', shape=[1, seq_len, emb_dim],
                initializer=get_unit_emb_initializer(emb_dim))
            position_emb = tf.tile(position_emb, [batch_size, 1, 1])
            itm_and_ctx_emb = tf.concat(
                [itm_emb, position_emb, ctx_emb], 2)  # [B,T,D+D+D']
        else:
            itm_and_ctx_emb = tf.concat([itm_emb, ctx_emb], 2)  # [B,T,D+D']
        prob_item = fully_connected(itm_and_ctx_emb, num_hidden_units, None)  # [B,T,D'']
        prob_item = tf.nn.tanh(prob_item)
        prob_item = fully_connected(prob_item, 1, None)  # [B,T,1]
        prob_item = tf.squeeze(prob_item, [2])  # [B,T]
        prob_item = weighted_softmax(
            prob_item, tf.to_float(tf.not_equal(itm_msk, 0)), -1)  # [B,T]
    return prob_item
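
# `get_shape` and `weighted_softmax` above are repo-local helpers that are not
# defined in this file. Below are minimal sketches of plausible
# implementations, inferred from the call sites (assumptions, not the repo's
# actual code).
import tensorflow as tf  # TF 1.x, matching the surrounding code


def get_shape(x):
    # Return a list mixing static dimensions (Python ints, when known) with
    # dynamic ones (scalar tensors), so the result is usable both in Python
    # arithmetic and in ops like tf.reshape / tf.tile.
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [static[i] if static[i] is not None else dynamic[i]
            for i in range(len(static))]


def weighted_softmax(logits, weights, axis=-1):
    # Softmax re-weighted by non-negative weights (here a {0,1} mask):
    # masked positions get probability 0; the rest renormalize along `axis`.
    exp = tf.exp(logits - tf.reduce_max(logits, axis, True)) * weights
    return exp / (tf.reduce_sum(exp, axis, True) + 1e-8)
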
def _build_encoder(self):
    with tf.variable_scope(name_or_scope='encoder_block'):
        batch_size = get_shape(self.usr_ids)[0]
        seq_itm_msk = self.inputs['user__clk_nid'].var
        seq_itm_emb = self.clk_itm_emb
        with tf.variable_scope(name_or_scope='ctx_q_as_mlp'):
            usr_ctx_q = fully_connected(
                self.usr_ctx_query_emb, FLAGS.dim * FLAGS.dim, None)  # [B,D'']->[B,DxD]
            usr_ctx_q = tf.reshape(
                usr_ctx_q, [batch_size, FLAGS.dim, FLAGS.dim])  # [B,D,D]
        with tf.variable_scope(name_or_scope='ctx_proj_k'):
            seq_ctx_k = fully_connected(
                self.clk_ctx_key_emb, FLAGS.dim, None)  # [B,T,D']->[B,T,D]

        def ctx_co_action(q, k):  # [B,D,D], [B,T,D]
            return tf.tanh(
                k + tf.matmul(tf.tanh(k), q))  # [B,T,D]x[B,D,D]->[B,T,D]

        seq_ctx_qk = ctx_co_action(q=usr_ctx_q, k=seq_ctx_k)  # [B,T,D]
        self.multi_vec_emb, self.multi_head_emb = multi_vector_sequence_encoder(
            itm_emb=seq_itm_emb, itm_msk=seq_itm_msk, ctx_emb=seq_ctx_qk,
            num_heads=FLAGS.num_heads, num_vectors=FLAGS.num_vectors,
            scope='enc_%dh%dv' % (FLAGS.num_heads, FLAGS.num_vectors))  # [B,V,D]
        self.multi_vec_emb_normalized = tf.nn.l2_normalize(
            self.multi_vec_emb, -1)  # [B,V,D]
        self.inference_output_3d = self.multi_vec_emb_normalized
    with tf.variable_scope("predictions"):
        output_name = 'user_emb'
        inference_output_2d = tf.reshape(
            self.inference_output_3d, [-1, FLAGS.dim])  # [B*V,D]
        inference_output_2d = tf.identity(inference_output_2d, output_name)
        print('inference output: name=%s, tensor=%s' % (
            output_name, inference_output_2d))
        # not really a CTR: both sides are l2-normalized, so this is a
        # cosine similarity score in [-1.0, 1.0]
        self.ctr_predictions = tf.matmul(
            self.inference_output_3d,  # [B,V,D]x[B,D,1]->[B,V,1]
            tf.expand_dims(self.pos_itm_emb_normalized, 2))
        self.ctr_predictions = tf.reduce_max(
            tf.squeeze(self.ctr_predictions, [2]), -1)  # [B,V]->[B]
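
# The `ctx_q_as_mlp` branch above reshapes the user context into a [D,D]
# matrix, and `ctx_co_action` applies it as a user-specific linear layer with
# a residual connection: out[b] = tanh(k[b] + tanh(k[b]) @ q[b]). The
# `fully_connected(inputs, num_outputs, activation_fn)` helper it relies on
# is repo-local and not shown in this file; a minimal sketch, assuming it
# simply delegates to tf.layers.dense (TF 1.x):
import tensorflow as tf  # TF 1.x


def fully_connected(inputs, num_outputs, activation_fn=None, scope=None):
    # Applies a dense layer to the last axis, so it works on both
    # [B,D_in] and [B,T,D_in] inputs, matching the call sites above.
    return tf.layers.dense(inputs, num_outputs,
                           activation=activation_fn, name=scope)
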
def multi_vector_sequence_encoder(itm_emb,  # [B,T,D]
                                  itm_msk,  # [B,T], tf.int64, mask if equal zero
                                  ctx_emb,  # [B,T,D']
                                  num_heads,
                                  num_vectors,
                                  scope='mv_seq_enc',
                                  reuse=None,
                                  partitioner=None):
    with tf.variable_scope(scope, reuse=reuse, partitioner=partitioner):
        batch_size, _, emb_dim = get_shape(itm_emb)
        prob_item = compute_prob_item(  # [B,T]
            itm_emb=itm_emb, itm_msk=itm_msk, ctx_emb=ctx_emb,
            num_hidden_units=emb_dim * 4,
            reuse=reuse, partitioner=partitioner)
        multi_head_emb = disentangle_layer(  # [B,H,D]
            itm_emb=itm_emb, itm_msk=itm_msk, prob_item=prob_item,
            num_heads=num_heads,
            add_head_prior=True)  # the added prior may lead to weird or over-popular recommendations
        mean_head_emb = tf.matmul(
            tf.expand_dims(prob_item, 1),
            tf.nn.l2_normalize(itm_emb, -1))  # [B,1,T]x[B,T,D]->[B,1,D]
        mean_head_emb = tf.squeeze(mean_head_emb, [1])  # [B,1,D]->[B,D]
        mean_head_emb = mean_head_emb + fully_connected(mean_head_emb, emb_dim, None)
        mean_head_emb = tf.tile(
            tf.expand_dims(mean_head_emb, 1), [1, num_vectors, 1])  # [B,V,D]
        multi_vec_emb = tf.reshape(
            multi_head_emb,
            [batch_size, num_vectors, num_heads // num_vectors, emb_dim])
        multi_vec_emb = multipath_head_aggregation_abae(
            multi_vec_emb, query=mean_head_emb, transform_query=False,
            scope='head2vec')  # [B,V,H/V,D]->[B,V,D]
    return multi_vec_emb, multi_head_emb  # [B,V,D], [B,H,D]
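
# Hypothetical usage sketch (names and sizes are illustrative, not from the
# repo): wire up placeholders and check the encoder's output shapes. Assumes
# the repo-local helpers referenced above are importable.
import tensorflow as tf  # TF 1.x

itm_emb = tf.placeholder(tf.float32, [None, 50, 64])  # [B,T,D]
itm_msk = tf.placeholder(tf.int64, [None, 50])        # [B,T], 0 = padding
ctx_emb = tf.placeholder(tf.float32, [None, 50, 64])  # [B,T,D']
vec_emb, head_emb = multi_vector_sequence_encoder(
    itm_emb=itm_emb, itm_msk=itm_msk, ctx_emb=ctx_emb,
    num_heads=8, num_vectors=4, scope='enc_8h4v')
# vec_emb: [B,4,64], head_emb: [B,8,64]; the 8 heads are grouped into
# 4 paths of 2 heads each and aggregated by multipath_head_aggregation_abae.
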
def simplified_multi_head_attention(
        itm_emb,  # [B,T,D]
        itm_msk,  # [B,T], tf.int64, mask if equal zero
        ctx_emb,  # [B,T,D]
        num_heads,
        num_hidden_units,
        scope='simple_multi_head_att',
        reuse=None):
    with tf.variable_scope(scope, reuse=reuse):
        itm_hidden = fully_connected(
            itm_emb + ctx_emb, num_hidden_units, tf.nn.tanh)
        itm_att = fully_connected(itm_hidden, num_heads, None)  # [B,T,H]
        itm_att = tf.transpose(itm_att, [0, 2, 1])  # [B,H,T]
        att_msk = tf.tile(
            tf.expand_dims(itm_msk, axis=1), [1, num_heads, 1])  # [B,H,T]
        att_pad = tf.to_float(tf.ones_like(att_msk) * (-2**32 + 1))
        itm_att = tf.where(tf.equal(att_msk, 0), att_pad, itm_att)
        itm_att = tf.nn.softmax(itm_att)  # softmax over T
        seq_multi_emb = tf.matmul(itm_att, itm_emb)  # [B,H,T]x[B,T,D]->[B,H,D]
    return seq_multi_emb  # [B,H,D]
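
# The (-2**32 + 1) constant above is the usual "large negative logit" masking
# trick: padded positions are pushed to ~0 probability by the softmax. A
# stand-alone toy check (illustrative only, not from the repo):
import tensorflow as tf  # TF 1.x

toy_logits = tf.constant([[1.0, 2.0, 3.0]])
toy_mask = tf.constant([[1, 1, 0]], dtype=tf.int64)  # last position is padding
toy_pad = tf.to_float(tf.ones_like(toy_mask) * (-2**32 + 1))
toy_masked = tf.where(tf.equal(toy_mask, 0), toy_pad, toy_logits)
toy_att = tf.nn.softmax(toy_masked)  # third weight underflows to 0.0
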
def head_aggregation_abae(x,  # [B,H,D]
                          query=None,
                          scope='abae',
                          reuse=None,
                          transform_query=True,
                          temperature=None):
    batch_size, num_heads, emb_dim = get_shape(x)
    with tf.variable_scope(scope, reuse=reuse):
        if query is None:
            mu = tf.reduce_mean(x, axis=1)  # [B,H,D]->[B,D]
        else:
            mu = query  # [B,D]
            if transform_query:
                mu = mu + fully_connected(mu, emb_dim, None)
        wg = tf.matmul(x, tf.expand_dims(mu, axis=-1))  # [B,H,D]x[B,D,1]->[B,H,1]
        if temperature is not None:
            wg = tf.div(wg, temperature)
        wg = tf.nn.softmax(wg, 1)  # [B,H,1]
        y = tf.reduce_mean(x * wg, axis=1)  # [B,H,D]->[B,D]
    return y
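
# `multipath_head_aggregation_abae`, called in multi_vector_sequence_encoder,
# is not shown in this file. A minimal sketch inferred from its
# [B,V,H,D] -> [B,V,D] call site: the same ABAE-style aggregation as above,
# applied independently along the extra "path" (V) axis. Details are
# assumptions, not the repo's actual code.
import tensorflow as tf  # TF 1.x


def multipath_head_aggregation_abae(x,            # [B,V,H,D]
                                    query=None,   # [B,V,D]
                                    transform_query=True,
                                    scope='multipath_abae',
                                    reuse=None):
    with tf.variable_scope(scope, reuse=reuse):
        if query is None:
            mu = tf.reduce_mean(x, axis=2)  # [B,V,H,D]->[B,V,D]
        else:
            mu = query  # [B,V,D]
            if transform_query:
                mu = mu + fully_connected(mu, get_shape(x)[-1], None)
        wg = tf.matmul(x, tf.expand_dims(mu, axis=-1))  # [B,V,H,1]
        wg = tf.nn.softmax(wg, 2)  # attention over the heads of each path
        return tf.reduce_mean(x * wg, axis=2)  # [B,V,H,D]->[B,V,D]
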
def compute_prob_item_given_queries(itm_emb,    # [B,T,D]
                                    itm_msk,    # [B,T], tf.int64, mask if equal zero
                                    query_emb,  # [B,Q,D]
                                    query_msk,  # [B,Q], tf.int64, mask if equal zero
                                    transform_query=True,
                                    temperature=0.07,
                                    partitioner=None,
                                    scope='prob_item_given_query',
                                    reuse=None):
    with tf.variable_scope(scope, reuse=reuse, partitioner=partitioner):
        _, num_query, emb_dim = get_shape(query_emb)
        _, num_itm, __ = get_shape(itm_emb)
        if transform_query:
            query_emb = query_emb + fully_connected(query_emb, emb_dim, None)  # [B,Q,D]
        prob_item = tf.matmul(
            query_emb, itm_emb, transpose_b=True)  # [B,Q,D]x[B,T,D]^T->[B,Q,T]
        attn_mask = tf.tile(
            tf.expand_dims(tf.to_float(tf.not_equal(itm_msk, 0)), axis=1),
            [1, num_query, 1])  # [B,T]->[B,Q,T]
        prob_item = weighted_softmax(prob_item / temperature, attn_mask, 2)  # [B,Q,T]
        query_cnt = tf.reduce_sum(query_msk, -1, True)  # [B,1]
        query_cnt = tf.tile(tf.to_float(query_cnt) + 1e-8, [1, num_itm])  # [B,T]
        query_msk = tf.tile(
            tf.expand_dims(query_msk, axis=2), [1, 1, num_itm])  # [B,Q]->[B,Q,T]
        prob_item = tf.where(
            tf.equal(query_msk, 0), tf.zeros_like(prob_item), prob_item)  # [B,Q,T]
        prob_item = tf.reduce_sum(prob_item, 1)  # [B,Q,T]->[B,T]
        prob_item = tf.div(prob_item, query_cnt)  # sum(p(item)) = 1
    return prob_item
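
# Sanity check of the normalization above (toy numbers, plain numpy, not from
# the repo): each valid query contributes a softmax row over items; zeroing
# masked queries and dividing by the valid-query count keeps the aggregated
# item distribution summing to 1.
import numpy as np

toy_prob = np.array([[0.2, 0.8],    # query 0 (valid)
                     [0.5, 0.5],    # query 1 (valid)
                     [0.3, 0.7]])   # query 2 (masked out below)
toy_query_msk = np.array([1, 1, 0])
toy_prob = toy_prob * toy_query_msk[:, None]            # zero masked queries
toy_prob_item = toy_prob.sum(0) / toy_query_msk.sum()   # [T] = [0.35, 0.65]
assert np.isclose(toy_prob_item.sum(), 1.0)
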
def disentangle_layer(itm_emb,    # [B,T,D]
                      itm_msk,    # [B,T], tf.int64, mask if equal zero
                      prob_item,  # [B,T]
                      num_heads,
                      proto_router=None,
                      add_head_prior=True,
                      equalize_heads=False,
                      scope='disentangle_layer',
                      reuse=None,
                      partitioner=None):
    with tf.variable_scope(scope, reuse=reuse, partitioner=partitioner):
        batch_size, max_seq_len, emb_dim = get_shape(itm_emb)
        if proto_router is None:
            proto_router = ProtoRouter(num_proto=num_heads, emb_dim=emb_dim)
        assert proto_router.num_proto == num_heads
        prob_head_given_item = proto_router.compute_prob_proto(
            tf.reshape(itm_emb, [batch_size * max_seq_len, emb_dim]))  # [B*T,H]
        prob_head_given_item = tf.reshape(
            prob_head_given_item, [batch_size, max_seq_len, num_heads])  # [B,T,H]
        prob_head_given_item = tf.transpose(prob_head_given_item, [0, 2, 1])  # [B,H,T]
        # p(head, item) = p(item) * p(head | item)
        prob_item_and_head = tf.multiply(  # [B,1,T]*[B,H,T]->[B,H,T]
            tf.expand_dims(prob_item, axis=1), prob_head_given_item)
        itm_msk_expand = tf.tile(
            tf.expand_dims(itm_msk, 1), [1, num_heads, 1])  # [B,T]->[B,H,T]
        prob_item_and_head = tf.where(
            tf.equal(itm_msk_expand, 0),
            tf.zeros_like(prob_item_and_head), prob_item_and_head)
        if equalize_heads:
            # p(item | head) = p(head, item) / p(head)
            # Might be too sensitive/responsive to heads with only ONE trigger item.
            prob_item_given_head = tf.div(
                prob_item_and_head,
                tf.reduce_sum(prob_item_and_head, -1, True) + 1e-8)  # [B,H,T]
            init_multi_emb = tf.matmul(prob_item_given_head,
                                       tf.nn.l2_normalize(itm_emb, -1))  # [B,H,D]
        else:
            init_multi_emb = tf.matmul(prob_item_and_head,
                                       tf.nn.l2_normalize(itm_emb, -1))  # [B,H,D]
        #
        # Spill-over: if no item from a head's category is present in the
        # sequence, that head's vector is composed mainly of items from other
        # categories (each with a small prob_head_given_item for this head).
        # As a result, its kNNs will be items from other categories. This
        # effect is sometimes useful, though, since it reuses the empty heads
        # to retrieve relevant items from other categories.
        #
        # Adding a head-specific bias avoids the spill-over effect when the
        # head is in fact empty, but then the retrieved kNNs may be too
        # irrelevant to the user and make the user unhappy with the result.
        #
        if add_head_prior:
            head_bias = get_dnn_variable(
                name='head_bias', shape=[1, num_heads, emb_dim],
                initializer=get_unit_emb_initializer(emb_dim))
            head_bias = tf.tile(head_bias, [batch_size, 1, 1])  # [B,H,D]
            out_multi_emb = tf.concat([init_multi_emb, head_bias], 2)  # [B,H,2*D]
        else:
            out_multi_emb = init_multi_emb
        out_multi_emb = tf.nn.tanh(fully_connected(out_multi_emb, emb_dim, None))
        out_multi_emb = fully_connected(out_multi_emb, emb_dim, None)
        out_multi_emb = out_multi_emb + init_multi_emb
        # Don't use multipath_fully_connected if num_heads is large, because
        # it is memory-consuming:
        #   out_multi_emb = init_multi_emb
        #   out_multi_emb = tf.nn.tanh(multipath_fully_connected(
        #       out_multi_emb, use_bias=add_head_prior, scope='multi_head_fc1'))
        #   out_multi_emb = multipath_fully_connected(
        #       out_multi_emb, use_bias=add_head_prior, scope='multi_head_fc2')
        #   out_multi_emb = out_multi_emb + init_multi_emb
    return out_multi_emb  # [B,H,D]
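
# `ProtoRouter` is defined elsewhere in the repo. A minimal sketch of the
# interface used above: learnable prototype embeddings, one per head, with
# p(head | item) computed as a softmax over item-prototype cosine
# similarities. The temperature and initializer are assumptions, not the
# repo's actual values.
import tensorflow as tf  # TF 1.x


class ProtoRouter(object):
    def __init__(self, num_proto, emb_dim, temperature=0.07,
                 scope='proto_router', reuse=None):
        self.num_proto = num_proto
        self.temperature = temperature
        with tf.variable_scope(scope, reuse=reuse):
            # One prototype embedding per head/interest.
            self.proto_emb = tf.get_variable(
                'proto_emb', [num_proto, emb_dim],
                initializer=tf.glorot_uniform_initializer())

    def compute_prob_proto(self, x):  # [N,D]->[N,H]
        # Cosine similarity between items and prototypes, sharpened by the
        # temperature and normalized into p(head | item).
        logits = tf.matmul(tf.nn.l2_normalize(x, -1),
                           tf.nn.l2_normalize(self.proto_emb, -1),
                           transpose_b=True)
        return tf.nn.softmax(logits / self.temperature, -1)
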