def single_vector_softmax_loss(
        query_emb,  # [B,D], normalized
        pos_emb,  # [B,D], normalized
        neg_emb,  # [Q,D], normalized
        query_msk=None,  # [B], tf.bool, mask if False
        neg_weights=None,  # [B,Q], 0.0<=weights<=1.0, sum(weights)>0, for weights*tf.exp(logits)
        temperature=0.07,  # temperature is critical if embeddings are l2-normalized
        margin=0.0,
        metric_ops=None,
        scope='loss'):
    """Softmax negative log-likelihood of the positive item against shared negatives.

    Column 0 of the logits is the query/positive dot product minus `margin`;
    the remaining Q columns are the query/negative dot products. All logits
    are divided by `temperature` before the (optionally weighted) softmax.

    Args:
        query_emb: query embeddings, [B,D].
        pos_emb: positive item embeddings, [B,D].
        neg_emb: negative item embeddings shared by every query, [Q,D].
        query_msk: optional [B] boolean mask; rows that are False contribute
            zero and the loss is averaged over the remaining rows.
        neg_weights: optional [B,Q] weights handed to `weighted_softmax`
            (the positive column always gets weight 1).
        temperature: divisor applied to the logits.
        margin: value subtracted from the positive logit.
        metric_ops: optional metric container; when given, the loss and the
            accuracy metrics are registered under `scope`.
        scope: prefix used for the metric names.

    Returns:
        Scalar loss tensor.
    """
    num_queries, dim = get_shape(query_emb)
    num_neg = get_shape(neg_emb)[0]

    pos_logits = tf.reduce_sum(tf.multiply(query_emb, pos_emb), -1, True)  # [B,1]
    query_3d = tf.reshape(query_emb, [num_queries, 1, dim])  # [B,1,D]
    neg_3d = tf.reshape(neg_emb, [1, num_neg, dim])  # [1,Q,D]
    neg_logits = tf.reduce_sum(tf.multiply(query_3d, neg_3d), -1)  # [B,Q]
    scaled_logits = tf.concat(
        [pos_logits - margin, neg_logits], 1) / temperature  # [B,1+Q]

    if neg_weights is not None:
        exp_weights = tf.concat(
            [tf.ones_like(pos_logits), neg_weights], 1)  # [B,1+Q]
        probs = weighted_softmax(scaled_logits, exp_weights, -1)  # [B,1+Q]
        log_p_pos = tf.log(probs[:, 0])  # [B]
    else:
        log_p_pos = tf.nn.log_softmax(scaled_logits, -1)[:, 0]  # [B]

    if query_msk is not None:
        log_p_pos = tf.where(query_msk, log_p_pos, tf.zeros_like(log_p_pos))
        # the epsilon keeps the division finite when every query is masked
        real_query_cnt = tf.reduce_sum(tf.to_float(query_msk)) + 1e-8
        nll_loss = -tf.reduce_sum(log_p_pos) / real_query_cnt
    else:
        nll_loss = -tf.reduce_mean(log_p_pos)

    if metric_ops is None:
        return nll_loss

    best_pos = tf.reduce_max(pos_logits, -1)  # [B]
    if neg_weights is not None:
        # fold the weights into the logits so that near-zero-weight negatives
        # cannot win the comparison
        best_neg = tf.reduce_max(
            tf.log(neg_weights + 1e-8) + neg_logits, -1)  # [B]
    else:
        best_neg = tf.reduce_max(neg_logits, -1)  # [B]
    add_or_fail_if_exist(metric_ops, scope + '/nll_loss', nll_loss)
    add_accuracy_metrics(metric_ops=metric_ops,
                         scope=scope,
                         max_pos_lgt=best_pos,
                         max_neg_lgt=best_neg,
                         query_msk=query_msk,
                         neg_weights=neg_weights)
    return nll_loss
def add_accuracy_metrics(metric_ops, scope, max_pos_lgt, max_neg_lgt,
                         query_msk, neg_weights):
    """Register mask-aware summary metrics about the positive/negative logits.

    Metrics added under `scope` (in this order): '/qw' (only when a mask is
    given), '/nw' (only when weights are given), '/min_cos', '/max_cos',
    '/avg_cos' and '/rk1_acc'.

    Args:
        metric_ops: container passed through to `add_or_fail_if_exist`.
        scope: prefix of every metric name.
        max_pos_lgt: [B] best positive logit per query.
        max_neg_lgt: [B] best negative logit per query.
        query_msk: optional [B] boolean mask; None means every row is real.
        neg_weights: optional negative weights; only their mean is reported.
    """
    num_rows = get_shape(max_pos_lgt)[0]
    if query_msk is not None:
        add_or_fail_if_exist(metric_ops, scope + '/qw',
                             tf.reduce_mean(tf.to_float(query_msk)))
    else:
        query_msk = tf.ones(shape=[num_rows], dtype=tf.bool)
    real_query_cnt = tf.reduce_sum(tf.to_float(query_msk)) + 1e-8
    if neg_weights is not None:
        add_or_fail_if_exist(metric_ops, scope + '/nw',
                             tf.reduce_mean(neg_weights))

    one_vec = tf.ones(shape=[num_rows], dtype=tf.float32)
    zero_vec = tf.zeros(shape=[num_rows], dtype=tf.float32)

    def _masked(values, fill):
        # masked rows take `fill`, chosen so they cannot affect the reduction
        return tf.where(query_msk, values, fill)

    pos_beats_neg = tf.to_float(max_pos_lgt > max_neg_lgt)
    named_metrics = [
        ('/min_cos', tf.reduce_min(_masked(max_pos_lgt, one_vec))),
        ('/max_cos', tf.reduce_max(_masked(max_pos_lgt, -one_vec))),
        ('/avg_cos',
         tf.reduce_sum(_masked(max_pos_lgt, zero_vec)) / real_query_cnt),
        ('/rk1_acc',
         tf.reduce_sum(_masked(pos_beats_neg, zero_vec)) / real_query_cnt),
    ]
    for suffix, value in named_metrics:
        add_or_fail_if_exist(metric_ops, scope + suffix, value)
def multi_vector_sequence_encoder(itm_emb,  # [B,T,D]
                                  itm_msk,  # [B,T], tf.int64, mask if equal zero
                                  ctx_emb,  # [B,T,D']
                                  num_heads,
                                  num_vectors,
                                  scope='mv_seq_enc',
                                  reuse=None,
                                  partitioner=None):
    """Encode an item sequence into `num_vectors` vectors via `num_heads` heads.

    Steps: (1) score every position with `compute_prob_item`; (2) build one
    embedding per head with `disentangle_layer`; (3) form a query from the
    score-weighted mean of the l2-normalized items plus a residual
    fully-connected transform; (4) split the heads into `num_vectors` groups
    of `num_heads // num_vectors` and aggregate each group with
    `multipath_head_aggregation_abae`.

    `num_heads` must be divisible by `num_vectors`, otherwise the reshape
    into groups cannot succeed.

    Returns:
        (multi_vec_emb, multi_head_emb) with shapes [B,V,D] and [B,H,D].
    """
    with tf.variable_scope(scope, reuse=reuse, partitioner=partitioner):
        batch_size, _, emb_dim = get_shape(itm_emb)
        prob_item = compute_prob_item(  # [B,T]
            itm_emb=itm_emb,
            itm_msk=itm_msk,
            ctx_emb=ctx_emb,
            num_hidden_units=emb_dim * 4,
            reuse=reuse,
            partitioner=partitioner)
        multi_head_emb = disentangle_layer(  # [B,H,D]
            itm_emb=itm_emb,
            itm_msk=itm_msk,
            prob_item=prob_item,
            num_heads=num_heads,
            add_head_prior=True)  # the added prior may lead to weird or over-popular recommendation
        # score-weighted mean of the normalized items, used as the attention query
        mean_head_emb = tf.matmul(tf.expand_dims(prob_item, 1),
                                  tf.nn.l2_normalize(itm_emb, -1))  # [B,1,T]x[B,T,D]->[B,1,D]
        mean_head_emb = tf.squeeze(mean_head_emb, [1])  # [B,1,D]->[B,D]
        # residual transform of the query
        mean_head_emb = mean_head_emb + fully_connected(mean_head_emb, emb_dim, None)
        # one copy of the query per output vector: [B,D]->[B,V,D]
        mean_head_emb = tf.tile(tf.expand_dims(mean_head_emb, 1), [1, num_vectors, 1])
        # group the heads: [B,H,D]->[B,V,H/V,D]
        multi_vec_emb = tf.reshape(
            multi_head_emb,
            [batch_size, num_vectors, num_heads // num_vectors, emb_dim])
        multi_vec_emb = multipath_head_aggregation_abae(
            multi_vec_emb,
            query=mean_head_emb,
            transform_query=False,
            scope='head2vec')  # [B,V,H/V,D]->[B,V,D]
        return multi_vec_emb, multi_head_emb  # [B,V,D], [B,H,D]
def compute_prob_item(itm_emb,  # [B,T,D]
                      itm_msk,  # [B,T], tf.int64, mask if equal zero
                      ctx_emb,  # [B,T,D']
                      num_hidden_units,
                      position_bias=False,  # better diversity when no position bias
                      partitioner=None,
                      scope='',
                      reuse=None):
    """Score every sequence position with a two-layer network and normalize over T.

    The item and context embeddings (plus a learned per-position embedding
    when `position_bias` is True) are concatenated, passed through a tanh
    hidden layer of `num_hidden_units` units and a 1-unit output layer, and
    then fed to `weighted_softmax` along the last axis with weight 1.0 where
    `itm_msk != 0` and 0.0 elsewhere.

    The variable scope is `scope` with 'prob_itm_psn_ctx' or 'prob_itm_ctx'
    appended, so the two variants never share variables.

    Returns:
        [B,T] tensor of per-position scores.
    """
    # the suffix keeps the position-biased variant in its own variable scope
    if position_bias:
        scope += 'prob_itm_psn_ctx'
    else:
        scope += 'prob_itm_ctx'
    with tf.variable_scope(scope, reuse=reuse, partitioner=partitioner):
        batch_size, seq_len, emb_dim = get_shape(itm_emb)
        if position_bias:
            position_emb = get_dnn_variable(name='position_emb',
                                            shape=[1, seq_len, emb_dim],
                                            initializer=get_unit_emb_initializer(emb_dim))
            position_emb = tf.tile(position_emb, [batch_size, 1, 1])
            itm_and_ctx_emb = tf.concat([itm_emb, position_emb, ctx_emb], 2)  # [B,T,D+D+D']
        else:
            itm_and_ctx_emb = tf.concat([itm_emb, ctx_emb], 2)  # [B,T,D+D']
        prob_item = fully_connected(itm_and_ctx_emb, num_hidden_units, None)  # [B,T,D'']
        prob_item = tf.nn.tanh(prob_item)
        prob_item = fully_connected(prob_item, 1, None)  # [B,T,1]
        prob_item = tf.squeeze(prob_item, [2])  # [B,T]
        # masked positions (itm_msk == 0) get weight 0.0 in the softmax
        prob_item = weighted_softmax(
            prob_item, tf.to_float(tf.not_equal(itm_msk, 0)), -1)  # [B,T]
        return prob_item
def predict(server_config, master, input_vars):
    """Run inference over the input and write user/item embeddings to the output table.

    Each written row is (user id, user embedding, item id, item embedding).
    A user embedding is serialized as ';'-joined vectors of ','-joined
    floats; an item embedding as ','-joined floats. When
    `FLAGS.predict_user` is false a [B,1,1] zero tensor is written in place
    of the user embedding.

    Args:
        server_config: session config for `tf.train.MonitoredTrainingSession`.
        master: address of the TensorFlow master to connect to.
        input_vars: inputs handed to `BaseModel`.

    Raises:
        AssertionError: if `FLAGS.checkpointDir` is unset or empty, or if a
            fetched batch has inconsistent lengths.
    """
    writer = OdpsTableWriter(FLAGS.outputs, slice_id=FLAGS.task_index)
    try:
        model = BaseModel(input_vars=input_vars)
        if FLAGS.predict_user:
            usr_emb_3d = model.inference_output_3d
        else:
            usr_emb_3d = tf.zeros(shape=(get_shape(model.usr_ids)[0], 1, 1),
                                  dtype=tf.float32)
        print('checkpointDir:', FLAGS.checkpointDir)
        sys.stdout.flush()
        assert (FLAGS.checkpointDir is not None) and (len(FLAGS.checkpointDir) > 0)
        fetches = [
            model.usr_ids, usr_emb_3d, model.pos_nid_ids,
            model.pos_itm_emb_normalized, model.inc_global_step_op
        ]
        with tf.train.MonitoredTrainingSession(
                master=master,
                config=server_config,
                is_chief=(FLAGS.task_index == 0),
                checkpoint_dir=FLAGS.checkpointDir,
                save_checkpoint_secs=None) as mon_sess:
            # NOTE(review): the source had `print(, "- ...")` with the leading
            # argument missing (probably a timestamp); it is dropped here so
            # the statement parses. Restore it if the prefix matters.
            print("- start mon_sess")
            sys.stdout.flush()
            local_step = 0
            while not mon_sess.should_stop():
                try:
                    # the source was missing the call itself (`=[...])`);
                    # running the fetch list on the session is the only
                    # reading consistent with the unpacking below
                    usr_ids, usr_emb, itm_ids, itm_emb, _ = mon_sess.run(fetches)
                    batch_size = usr_ids.shape[0]
                    usr_ids = [str(i) for i in usr_ids]
                    usr_emb = [
                        ';'.join(','.join(str(x) for x in e) for e in u)
                        for u in usr_emb
                    ]
                    assert len(usr_emb) == batch_size
                    itm_ids = [str(i) for i in itm_ids]
                    assert len(itm_ids) == batch_size
                    itm_emb = [','.join(str(x) for x in e) for e in itm_emb]
                    assert len(itm_emb) == batch_size
                    writer.write(list(zip(usr_ids, usr_emb, itm_ids, itm_emb)),
                                 indices=[0, 1, 2, 3])
                    local_step += 1
                    if local_step % FLAGS.print_every == 0:
                        print("- %dk cases saved" %
                              (local_step * batch_size // 1000))
                        sys.stdout.flush()
                except tf.errors.OutOfRangeError:
                    # input exhausted: normal termination
                    print('tf.errors.OutOfRangeError')
                    break
                except tf.python_io.OutOfRangeException:
                    print('tf.python_io.OutOfRangeException')
                    break
        sys.stdout.flush()
    finally:
        # close the table writer even when model building or the session fails
        writer.close()
def select_topk_vectors_by_scores(multi_vec_emb,  # [B,V,D]
                                  vec_scores,  # [B,V]
                                  topk):
    """Pick, for every row, the `topk` vectors with the highest scores.

    Args:
        multi_vec_emb: candidate vectors, [B,V,D].
        vec_scores: one score per candidate, [B,V].
        topk: number of vectors to keep per row (cast to int32).

    Returns:
        [B,K,D] tensor holding the selected vectors in the order produced by
        `tf.nn.top_k`.
    """
    num_rows, _, dim = get_shape(multi_vec_emb)
    k = tf.to_int32(topk)

    _, best_cols = tf.nn.top_k(vec_scores, k=k)  # [B,V] -> [B,K]
    flat_cols = tf.reshape(tf.to_int64(best_cols), [-1])  # [B*K]

    # row index repeated K times so that it pairs up with every selected column
    row_ids = tf.range(0, tf.to_int64(num_rows), dtype=tf.int64)  # [B]
    flat_rows = tf.reshape(
        tf.tile(tf.expand_dims(row_ids, -1), [1, k]), [-1])  # [B*K]

    gather_idx = tf.stack([flat_rows, flat_cols], 1)  # [B*K, 2]
    picked = tf.gather_nd(multi_vec_emb, gather_idx)  # [B*K,D]
    return tf.reshape(picked, [num_rows, k, dim])
def _build_encoder(self):
    """Build the user encoder and the inference outputs.

    Reads `self.usr_ids`, `self.inputs['user__clk_nid']`, `self.clk_itm_emb`,
    `self.usr_ctx_query_emb`, `self.clk_ctx_key_emb` and
    `self.pos_itm_emb_normalized`; sets `self.multi_vec_emb`,
    `self.multi_head_emb`, `self.multi_vec_emb_normalized`,
    `self.inference_output_3d` and `self.ctr_predictions`.

    NOTE(review): indentation was reconstructed from a flattened source. In
    particular the 'predictions' scope is assumed to be nested inside
    'encoder_block' -- confirm, because it determines the exported name of
    the 'user_emb' tensor.
    """
    with tf.variable_scope(name_or_scope='encoder_block'):
        batch_size = get_shape(self.usr_ids)[0]
        # the clicked item ids double as the sequence mask (0 means padding)
        seq_itm_msk = self.inputs['user__clk_nid'].var
        seq_itm_emb = self.clk_itm_emb
        with tf.variable_scope(name_or_scope='ctx_q_as_mlp'):
            # the user-side context query is turned into a per-user DxD matrix
            usr_ctx_q = fully_connected(self.usr_ctx_query_emb,
                                        FLAGS.dim * FLAGS.dim,
                                        None)  # [B,D'']->[B,DxD]
            usr_ctx_q = tf.reshape(
                usr_ctx_q, [batch_size, FLAGS.dim, FLAGS.dim])  # [B,D,D]
        with tf.variable_scope(name_or_scope='ctx_proj_k'):
            seq_ctx_k = fully_connected(self.clk_ctx_key_emb, FLAGS.dim,
                                        None)  # [B,T,D']->[B,T,D]

        def ctx_co_action(q, k):  # [B,D,D], [B,T,D]
            # apply the per-user matrix q to the keys, with a residual on k
            return tf.tanh(
                k + tf.matmul(tf.tanh(k), q))  # [B,T,D]x[B,D,D]->[B,T,D]

        seq_ctx_qk = ctx_co_action(q=usr_ctx_q, k=seq_ctx_k)  # [B,T,D]
        self.multi_vec_emb, self.multi_head_emb = multi_vector_sequence_encoder(
            itm_emb=seq_itm_emb,
            itm_msk=seq_itm_msk,
            ctx_emb=seq_ctx_qk,
            num_heads=FLAGS.num_heads,
            num_vectors=FLAGS.num_vectors,
            scope='enc_%dh%dv' % (FLAGS.num_heads, FLAGS.num_vectors))  # [B,V,D]
        self.multi_vec_emb_normalized = tf.nn.l2_normalize(
            self.multi_vec_emb, -1)  # [B,V,D]
        self.inference_output_3d = self.multi_vec_emb_normalized
        with tf.variable_scope("predictions"):
            output_name = 'user_emb'
            inference_output_2d = tf.reshape(self.inference_output_3d,
                                             [-1, FLAGS.dim])  # [B*V,D]
            # tf.identity only attaches the name `output_name` to the tensor
            inference_output_2d = tf.identity(inference_output_2d, output_name)
            print('inference output: name=%s, tensor=%s' %
                  (output_name, inference_output_2d))
            # not really ctr, just a score between 0.0 and 1.0
            self.ctr_predictions = tf.matmul(
                self.inference_output_3d,  # [B,V,D]x[B,D,1]->[B,V,1]
                tf.expand_dims(self.pos_itm_emb_normalized, 2))
            # best score over the V user vectors
            self.ctr_predictions = tf.reduce_max(
                tf.squeeze(self.ctr_predictions, [2]), -1)  # [B,V]->[B]
def compute_prob_item_given_queries(itm_emb,  # [B,T,D]
                                    itm_msk,  # [B,T], tf.int64, mask if equal zero
                                    query_emb,  # [B,Q,D]
                                    query_msk,  # [B,Q], tf.int64, mask if equal zero
                                    transform_query=True,
                                    temperature=0.07,
                                    partitioner=None,
                                    scope='prob_item_given_query',
                                    reuse=None):
    """Average, over the unmasked queries, each query's attention over the items.

    Every query attends over the T items through `weighted_softmax`
    (temperature-scaled dot products, weight 0.0 for items whose mask is 0).
    Rows belonging to masked queries are zeroed, the remaining rows are
    summed over Q and divided by the sum of `query_msk` (the number of real
    queries, assuming the mask holds 0/1 values).

    Args:
        itm_emb: item embeddings, [B,T,D].
        itm_msk: item mask, [B,T]; 0 marks padding.
        query_emb: query embeddings, [B,Q,D].
        query_msk: query mask, [B,Q]; 0 marks padding.
        transform_query: add a residual fully-connected transform to the
            queries before scoring.
        temperature: divisor applied to the dot products.
        partitioner, scope, reuse: forwarded to `tf.variable_scope`.

    Returns:
        [B,T] tensor.
    """
    with tf.variable_scope(scope, reuse=reuse, partitioner=partitioner):
        _, num_query, emb_dim = get_shape(query_emb)
        _, num_itm, __ = get_shape(itm_emb)

        queries = query_emb
        if transform_query:
            queries = queries + fully_connected(queries, emb_dim, None)  # [B,Q,D]

        scores = tf.matmul(queries, itm_emb,
                           transpose_b=True)  # [B,Q,D]x[B,T,D]^t->[B,Q,T]
        item_is_real = tf.to_float(tf.not_equal(itm_msk, 0))  # [B,T]
        item_weights = tf.tile(tf.expand_dims(item_is_real, axis=1),
                               [1, num_query, 1])  # [B,T]->[B,Q,T]
        attn = weighted_softmax(scores / temperature, item_weights, 2)  # [B,Q,T]

        real_query_cnt = tf.reduce_sum(query_msk, -1, True)  # [B,1]
        # the epsilon keeps the division finite when every query is masked
        real_query_cnt = tf.tile(tf.to_float(real_query_cnt) + 1e-8,
                                 [1, num_itm])  # [B,T]

        query_msk_3d = tf.tile(tf.expand_dims(query_msk, axis=2),
                               [1, 1, num_itm])  # [B,Q]->[B,Q,T]
        attn = tf.where(tf.equal(query_msk_3d, 0), tf.zeros_like(attn),
                        attn)  # [B,Q,T]

        summed = tf.reduce_sum(attn, 1)  # [B,Q,T]->[B,T]
        return tf.div(summed, real_query_cnt)  # sum(p(item)) = 1
def create_id_queue(queue_name, id_var):
    """Create a fixed-size int64 id queue and return its contents after enqueueing.

    The queue is a local variable of length `FLAGS.queue_size`, initialized
    to zeros. The returned tensor drops the first B entries and appends
    `id_var` at the end. The assign op that stores this back into the
    variable is appended to `self.queue_ops`; `self` is not a parameter and
    is expected to come from the enclosing scope.

    Args:
        queue_name: name of the local variable.
        id_var: [B] int64 ids of the current batch.

    Returns:
        [Q] tensor with the updated queue contents.
    """
    num_new = get_shape(id_var)[0]
    id_queue = tf.get_local_variable(
        queue_name,
        shape=[FLAGS.queue_size],
        dtype=tf.int64,
        initializer=tf.zeros_initializer(dtype=tf.int64),
        trainable=False,  # not trainable parameters
        collections=[tf.GraphKeys.LOCAL_VARIABLES])  # place it on the local worker
    # dequeue the earliest batch, and enqueue the current batch; the slicing
    # below will not work if the queue variable is partitioned
    refreshed = tf.concat([id_queue[num_new:], id_var],
                          axis=0)  # [Q-B],[B]->[Q]
    self.queue_ops.append(id_queue.assign(refreshed))
    return refreshed
def multipath_fully_connected(x,
                              path_out_dim=None,
                              use_bias=True,
                              scope='multi_fc',
                              reuse=None,
                              partitioner=None):
    """Apply a separate linear layer to each path of a [B,V,D] tensor.

    Each of the V paths has its own [D,D'] weight matrix (variable 'weight')
    and, when `use_bias` is True, its own [D'] bias (variable 'bias').

    Args:
        x: input tensor, [B,V,D].
        path_out_dim: output size D' per path; defaults to D.
        use_bias: whether to add the per-path bias.
        scope, reuse, partitioner: forwarded to `tf.variable_scope`.

    Returns:
        [B,V,D'] tensor.
    """
    with tf.variable_scope(scope, reuse=reuse, partitioner=partitioner):
        batch_size, num_path, emb_dim = get_shape(x)
        out_dim = emb_dim if path_out_dim is None else path_out_dim

        weight = get_dnn_variable('weight',
                                  shape=[1, num_path, emb_dim, out_dim],
                                  partitioner=partitioner)
        # the weights are replicated per example so a batched matmul applies
        tiled_weight = tf.tile(weight, [batch_size, 1, 1, 1])  # [B,V,D,D']
        row_vectors = tf.expand_dims(x, 2)  # [B,V,1,D]
        # [B,V,1,D]x[B,V,D,D']=[B,V,1,D']=[B,V,D']
        projected = tf.squeeze(tf.matmul(row_vectors, tiled_weight), [2])
        if not use_bias:
            return projected  # [B,V,D']

        bias = get_dnn_variable('bias',
                                shape=[1, num_path, out_dim],
                                partitioner=partitioner,
                                initializer=tf.zeros_initializer())
        return projected + bias  # [B,V,D'] + [1,V,D']
def head_aggregation_abae(x,
                          query=None,
                          scope='abae',
                          reuse=None,
                          transform_query=True,
                          temperature=None):
    """Aggregate H head vectors into one vector with query-based attention.

    The attention logits are the dot products between every head and the
    query (the mean over heads when `query` is None), optionally divided by
    `temperature`, and normalized with a softmax over the head axis.

    Args:
        x: head embeddings, [B,H,D].
        query: optional [B,D] query.
        scope, reuse: forwarded to `tf.variable_scope`.
        transform_query: add a residual fully-connected transform to the query.
        temperature: optional divisor applied to the logits.

    Returns:
        [B,D] tensor.

    NOTE(review): indentation was reconstructed from a flattened source; the
    `transform_query` step is assumed to apply to both the default and the
    caller-supplied query -- confirm.
    """
    batch_size, num_heads, emb_dim = get_shape(x)
    with tf.variable_scope(scope, reuse=reuse):
        if query is None:
            mu = tf.reduce_mean(x, axis=1)  # [B,H,D]->[B,D]
        else:
            mu = query  # [B,D]
        if transform_query:
            mu = mu + fully_connected(mu, emb_dim, None)
        wg = tf.matmul(x, tf.expand_dims(mu, axis=-1))  # [B,H,D] x [B,D,1]
        if temperature is not None:
            wg = tf.div(wg, temperature)
        wg = tf.nn.softmax(wg, 1)  # [B,H,1]
        # NOTE(review): the softmax weights already sum to 1 over H, so taking
        # the mean (not the sum) scales the result by an extra 1/H -- confirm
        # this is intended before relying on the output's magnitude.
        y = tf.reduce_mean(x * wg, axis=1)  # [B,H,D]->[B,D]
        return y
def compute_prob_proto(self, x, take_log=False, hard_value=False, soft_grad=True):
    """Softmax distribution of each row of `x` over the prototypes.

    The logits are the dot products between the l2-normalized rows of `x`
    and `self.proto_embs`, divided by `self.temperature`.

    Args:
        x: [B,D] tensor (only rank 2 is supported).
        take_log: return log-probabilities; requires `hard_value=False` and
            `soft_grad=True`.
        hard_value: return a one-hot of the arg-max prototype instead of the
            soft distribution.
        soft_grad: with `hard_value`, keep the gradient of the soft
            distribution (straight-through); must be True when `hard_value`
            is False.

    Returns:
        [B,H] tensor.
    """
    assert len(get_shape(x)) == 2  # only [B,D] is supported
    unit_x = tf.nn.l2_normalize(x, -1)
    similarity = tf.reduce_sum(
        tf.multiply(  # [B,1,D]*[1,H,D] -> [B,H,D] -> [B,H]
            tf.expand_dims(unit_x, 1), tf.expand_dims(self.proto_embs, 0)),
        -1)
    scaled = similarity / self.temperature

    if take_log:
        assert (not hard_value) and soft_grad
        return tf.nn.log_softmax(scaled, 1)  # [B,H]

    soft = tf.nn.softmax(scaled, 1)  # [B,H]
    if not hard_value:
        assert soft_grad
        return soft  # [B,H]

    one_hot = tf.one_hot(tf.argmax(soft, -1), self.num_proto, dtype=tf.float32)
    if soft_grad:
        # forward value is the one-hot, gradient flows through `soft`
        return tf.stop_gradient(one_hot - soft) + soft
    return tf.stop_gradient(one_hot)
def bloom_filter_emb(ids, hashes, zero_pad=True, mark_for_serving=True):
    """Embed `ids` by concatenating several hashed embedding lookups.

    Every entry of `hashes` supplies 'table_name', 'hash_fn', 'bucket_size'
    and 'emb_dim'; the ids are hashed with 'hash_fn' and looked up in the
    corresponding table. The concatenated lookups are projected to
    `FLAGS.dim` with a dense matrix, unless there is exactly one table whose
    'emb_dim' already equals `FLAGS.dim`.

    Args:
        ids: integer id tensor of any shape.
        hashes: list of hash specifications (see above).
        zero_pad: output an all-zero vector wherever the id is 0.
        mark_for_serving: forwarded to `get_emb_variable`.

    Returns:
        Tensor of shape `get_shape(ids) + [FLAGS.dim]`.

    NOTE(review): `self` is not a parameter; the function relies on it being
    available from the enclosing scope (it reads `self.ps_num`).
    """
    ids_flat = tf.reshape(ids, [-1])
    e = []
    for h in hashes:
        e.append(
            get_emb_variable(
                name=h['table_name'],
                ids=h['hash_fn'](ids_flat),
                shape=(h['bucket_size'], h['emb_dim']),
                mark_for_serving=mark_for_serving,
                initializer=get_unit_emb_initializer(FLAGS.dim)
            ))  # important: use normal, not uniform
    e = tf.concat(e, axis=1)
    if len(hashes) == 1 and hashes[0]['emb_dim'] == FLAGS.dim:
        # a single table of the right width needs no projection
        print('bloom filter w/o fc: [%s]' % hashes[0]['table_name'])
    else:
        dnn_name = 'dnn__' + '__'.join(h['table_name'] for h in hashes)
        dnn_in_dim = sum(h['emb_dim'] for h in hashes)
        dnn = get_dnn_variable(
            name=dnn_name,
            shape=[dnn_in_dim, FLAGS.dim],
            initializer=tf.glorot_normal_initializer(
            ),  # important: use normal, not uniform
            partitioner=tf.min_max_variable_partitioner(
                max_partitions=self.ps_num,
                min_slice_size=FLAGS.dnn_pt_size))
        e = tf.matmul(e, dnn)
    if zero_pad:
        # id 0 is treated as padding
        id_eq_zero = tf.tile(
            tf.expand_dims(tf.equal(ids_flat, 0), -1), [1, FLAGS.dim])
        e = tf.where(id_eq_zero, tf.zeros_like(e), e)
    # restore the original id shape, with the embedding axis appended
    e = tf.reshape(e, get_shape(ids) + [FLAGS.dim])
    return e
def layer_normalize(inputs,
                    epsilon=1e-8,
                    scale_to_unit_norm=False,
                    scope="ln",
                    reuse=None,
                    partitioner=None):
    """Layer normalization over the last axis with learned gain and offset.

    Args:
        inputs: tensor whose last axis is normalized.
        epsilon: added to the variance before taking the square root.
        scale_to_unit_norm: additionally divide the normalized values by
            sqrt(size of the last axis).
        scope, reuse, partitioner: forwarded to `tf.variable_scope`.

    Returns:
        Tensor with the same shape as `inputs`, equal to
        `gamma * normalized + beta`.
    """
    with tf.variable_scope(scope, reuse=reuse, partitioner=partitioner):
        full_shape = get_shape(inputs)
        feature_shape = full_shape[-1:]
        mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)

        # offset is created before gain, initialized to 0 and 1 respectively
        beta = get_dnn_variable(name="beta",
                                shape=feature_shape,
                                initializer=tf.zeros_initializer(),
                                partitioner=partitioner)
        gamma = get_dnn_variable(name="gamma",
                                 shape=feature_shape,
                                 initializer=tf.ones_initializer(),
                                 partitioner=partitioner)

        centered = inputs - mean
        normalized = centered / ((variance + epsilon)**.5)
        if scale_to_unit_norm:
            normalized = normalized / (full_shape[-1]**0.5)
        return gamma * normalized + beta
def disentangle_layer(itm_emb,  # [B,T,D]
                      itm_msk,  # [B,T], tf.int64, mask if equal zero
                      prob_item,  # [B,T]
                      num_heads,
                      proto_router=None,
                      add_head_prior=True,
                      equalize_heads=False,
                      scope='disentangle_layer',
                      reuse=None,
                      partitioner=None):
    """Route the items of a sequence to `num_heads` heads and embed every head.

    Each item is softly assigned to the heads by `proto_router` (a new
    `ProtoRouter` is created when none is given). A head's initial embedding
    is the sum of the l2-normalized item embeddings weighted by
    p(item) * p(head | item), or by p(item | head) when `equalize_heads` is
    True. The initial embedding then goes through a two-layer network with a
    residual connection.

    Args:
        itm_emb: item embeddings, [B,T,D].
        itm_msk: item mask, [B,T]; 0 marks padding.
        prob_item: per-item weights, [B,T].
        num_heads: number of heads H; must equal `proto_router.num_proto`.
        proto_router: optional router exposing `num_proto` and
            `compute_prob_proto`.
        add_head_prior: concatenate a learned per-head vector to the input of
            the network.
        equalize_heads: normalize the weights of every head to sum to one.
        scope, reuse, partitioner: forwarded to `tf.variable_scope`.

    Returns:
        [B,H,D] tensor.
    """
    with tf.variable_scope(scope, reuse=reuse, partitioner=partitioner):
        batch_size, max_seq_len, emb_dim = get_shape(itm_emb)
        if proto_router is None:
            proto_router = ProtoRouter(num_proto=num_heads, emb_dim=emb_dim)
        assert proto_router.num_proto == num_heads
        # the router only accepts rank-2 input, hence the flatten/unflatten
        prob_head_given_item = proto_router.compute_prob_proto(
            tf.reshape(itm_emb, [batch_size * max_seq_len, emb_dim]))  # [B*T,H]
        prob_head_given_item = tf.reshape(
            prob_head_given_item,
            [batch_size, max_seq_len, num_heads])  # [B,T,H]
        prob_head_given_item = tf.transpose(prob_head_given_item,
                                            [0, 2, 1])  # [B,H,T]
        # p(head, item) = p(item) * p(head | item)
        prob_item_and_head = tf.multiply(  # [B,1,T]*[B,H,T]->[B,H,T]
            tf.expand_dims(prob_item, axis=1), prob_head_given_item)
        itm_msk_expand = tf.tile(tf.expand_dims(itm_msk, 1),
                                 [1, num_heads, 1])  # [B,T]->[B,H,T]
        # padded items contribute nothing to any head
        prob_item_and_head = tf.where(
            tf.equal(itm_msk_expand, 0), tf.zeros_like(prob_item_and_head),
            prob_item_and_head)
        if equalize_heads:
            # p(item | head) = p(head, item) / p(head)
            # Would it be too sensitive/responsive to heads with ONE trigger?
            prob_item_given_head = tf.div(
                prob_item_and_head,
                tf.reduce_sum(prob_item_and_head, -1, True) + 1e-8)  # [B,H,T]
            init_multi_emb = tf.matmul(prob_item_given_head,
                                       tf.nn.l2_normalize(itm_emb, -1))  # [B,H,D]
        else:
            init_multi_emb = tf.matmul(prob_item_and_head,
                                       tf.nn.l2_normalize(itm_emb, -1))  # [B,H,D]
        #
        # Spill-over: If no items under this head's category is present in the
        # sequence, the head's vector will mainly be composed of items (with
        # small values of prob_head_given_item for this head) from other
        # categories. As a result, its kNNs will be items from other categories.
        # This effect is sometimes useful, though, since it reuses the empty
        # heads to retrieve relevant items from other categories.
        #
        # Adding a head-specific bias to avoid the spill-over effect when the
        # head is in fact empty. But then the retrieved kNNs may be too
        # irrelevant to the user and make the user unhappy about the result.
        #
        if add_head_prior:
            head_bias = get_dnn_variable(
                name='head_bias',
                shape=[1, num_heads, emb_dim],
                initializer=get_unit_emb_initializer(emb_dim))
            head_bias = tf.tile(head_bias, [batch_size, 1, 1])  # [B,H,D]
            out_multi_emb = tf.concat([init_multi_emb, head_bias], 2)  # [B,H,2*D]
        else:
            out_multi_emb = init_multi_emb
        # two-layer network shared by all heads, with a residual connection
        out_multi_emb = tf.nn.tanh(fully_connected(out_multi_emb, emb_dim, None))
        out_multi_emb = fully_connected(out_multi_emb, emb_dim, None)
        out_multi_emb = out_multi_emb + init_multi_emb
        #
        # Don't use multipath_fully_connected if num_heads is large, because it
        # is memory consuming.
        #
        # out_multi_emb = init_multi_emb
        # out_multi_emb = tf.nn.tanh(multipath_fully_connected(
        #     out_multi_emb, use_bias=add_head_prior, scope='multi_head_fc1'))
        # out_multi_emb = multipath_fully_connected(
        #     out_multi_emb, use_bias=add_head_prior, scope='multi_head_fc2')
        # out_multi_emb = out_multi_emb + init_multi_emb
        return out_multi_emb  # [B,H,D]
def _build_optimizer(self):
    """Build the training loss and the Adam update op.

    Negative weights ([B,Q]) start as 1.0 wherever the negative's id differs
    from the row's positive id. When `FLAGS.hard_queue_size > 0`, the first
    `hard_queue_size` negatives are kept for a row only if their arg-max
    prototype matches that of the row's positive. When `FLAGS.rm_dup_neg` is
    set, the weights of the remaining ("easy") negatives are divided by the
    number of times their id occurs among the easy negatives.

    Sets `self.loss` (NLL plus auxiliary loss) and `self.optim_op` (the
    train op). Gradients are clipped per tensor with `tf.clip_by_norm` when
    `FLAGS.grad_clip > 1e-3`.

    NOTE(review): indentation was reconstructed from a flattened source; the
    `rm_dup_neg` block is assumed to be independent of the
    `hard_queue_size` block -- confirm.
    """
    with tf.variable_scope(name_or_scope='optimizer_block'):
        batch_size = get_shape(self.pos_nid_ids)[0]
        cate_router = ProtoRouter(num_proto=FLAGS.num_heads, emb_dim=FLAGS.dim)
        pos_prob_h = cate_router.compute_prob_proto(
            self.pos_cat_emb)  # [B,H]
        # the user and item embeddings need to be l2-normalized when using the contrastive loss
        neg_itm_emb_normalized = tf.nn.l2_normalize(self.neg_itm_emb, -1)
        # weight 0.0 for negatives that are in fact the row's positive item
        neg_itm_weights = tf.to_float(
            tf.not_equal(tf.expand_dims(self.pos_nid_ids, -1),
                         tf.expand_dims(self.neg_itm_nid, 0)))  # [B,1]!=[1,Q] -> [B,Q]
        if FLAGS.hard_queue_size > 0:
            num_easy_neg = FLAGS.queue_size - FLAGS.hard_queue_size
            assert num_easy_neg >= FLAGS.batch_size
            neg_prob_h = cate_router.compute_prob_proto(
                self.neg_cat_emb[:FLAGS.hard_queue_size])  # [Q',H]
            # a hard negative is kept only if it shares the positive's arg-max head
            hard_neg_msk = tf.to_float(
                tf.equal(tf.expand_dims(tf.argmax(pos_prob_h, -1), -1),
                         tf.expand_dims(tf.argmax(neg_prob_h, -1),
                                        0)))  # [B,1]==[1,Q']->[B,Q']
            hard_neg_cnt = tf.reduce_sum(hard_neg_msk, -1)  # [B]
            add_or_fail_if_exist(self.metric_ops, 'neg_queue/max_hard_cnt',
                                 tf.reduce_max(hard_neg_cnt, -1))
            add_or_fail_if_exist(self.metric_ops, 'neg_queue/min_hard_cnt',
                                 tf.reduce_min(hard_neg_cnt, -1))
            hard_neg_msk = tf.concat(  # [B,Q]
                [
                    hard_neg_msk,
                    tf.ones(shape=(batch_size, num_easy_neg), dtype=tf.float32)
                ], 1)
            neg_itm_weights = tf.multiply(neg_itm_weights, hard_neg_msk)  # [B,Q]
        if FLAGS.rm_dup_neg:
            # This implementation only de-duplicates easy negative samples.
            # Hard ones are not de-duplicated.
            easy_neg_nid = self.neg_itm_nid[FLAGS.hard_queue_size:]
            neg_itm_appear_cnt = tf.reduce_sum(
                tf.to_float(  # [Q'',1]==[1,Q'']->[Q'',Q'']->[Q'']
                    tf.equal(tf.expand_dims(easy_neg_nid, 1),
                             tf.expand_dims(easy_neg_nid, 0))), -1)
            add_or_fail_if_exist(self.metric_ops, 'neg_queue/max_appear_cnt',
                                 tf.reduce_max(neg_itm_appear_cnt))
            add_or_fail_if_exist(self.metric_ops, 'neg_queue/min_appear_cnt',
                                 tf.reduce_min(neg_itm_appear_cnt))
            neg_itm_appear_cnt = tf.concat([
                tf.ones(shape=[FLAGS.hard_queue_size], dtype=tf.float32),
                neg_itm_appear_cnt
            ], 0)  # [Q'],[Q'']->[Q]
            neg_itm_weights = tf.div(neg_itm_weights,
                                     tf.expand_dims(neg_itm_appear_cnt,
                                                    0))  # [B,Q]/[1,Q]
        multi_vec_nll_loss, disentangle_aux_loss = disentangled_multi_vector_loss(
            multi_vec_emb=self.multi_vec_emb_normalized,
            pos_emb=self.pos_itm_emb_normalized,
            neg_emb=neg_itm_emb_normalized,
            multi_head_emb=tf.nn.l2_normalize(self.multi_head_emb, -1),
            prob_h=pos_prob_h,
            neg_weights=neg_itm_weights,
            metric_ops=self.metric_ops,
            scope='multi')
        self.loss = multi_vec_nll_loss
        self.loss += disentangle_aux_loss
        optimizer = tf.train.AdamOptimizer(use_locking=True)
        if FLAGS.grad_clip > 1e-3:
            # sources of NaN: (1) div-zero; (2) log(non-positive); (3) gradient explosion; ...
            grads_and_vars = optimizer.compute_gradients(self.loss)
            # compute_gradients yields (None, var) for variables that do not
            # influence the loss; clip_by_norm cannot take None, so skip them
            clipped = [(tf.clip_by_norm(g, FLAGS.grad_clip), v)
                       for g, v in grads_and_vars if g is not None]
            self.optim_op = optimizer.apply_gradients(
                clipped, global_step=self.global_step)
        else:
            self.optim_op = optimizer.minimize(
                self.loss, global_step=self.global_step)
def _build_embedding(self):
    """Build all user-, click-sequence-, positive- and negative-item embeddings.

    Reads the raw id tensors from `self.inputs[...]` (plus
    `self.neg_itm_nid`, `self.neg_itm_uid`, `self.neg_itm_cate`,
    `self.neg_itm_cat1` and `self.usr_ids`) and sets:
    `self.usr_mem_emb`, `self.is_recent_click`,
    `self.is_recent_click_expand`, `self.clk_itm_emb`,
    `self.clk_ctx_key_emb`, `self.usr_ctx_query_emb`, `self.pos_itm_emb`,
    `self.pos_itm_emb_normalized`, `self.pos_cat_emb`, `self.neg_itm_emb`
    and `self.neg_cat_emb`.

    NOTE(review): indentation was reconstructed from a flattened source;
    every statement is assumed to live inside the 'bloom_filter' scope --
    confirm, since it determines variable names.
    """
    with tf.variable_scope(name_or_scope='embedding_block',
                           partitioner=tf.min_max_variable_partitioner(
                               max_partitions=self.ps_num,
                               min_slice_size=FLAGS.emb_pt_size)):
        with tf.variable_scope(name_or_scope='bloom_filter'):
            #
            # We need to ensure that the final embeddings are uniformly distributed on a hyper-sphere,
            # after l2-normalization. Otherwise it will have a hard time converging at the beginning.
            # So we need to use random_normal initialization, rather than random_uniform.
            #
            def bloom_filter_emb(ids, hashes, zero_pad=True, mark_for_serving=True):
                # Embed `ids` as the concatenation of several hashed lookups,
                # projected to FLAGS.dim unless a single table already has
                # that width; id 0 maps to the zero vector when `zero_pad`.
                ids_flat = tf.reshape(ids, [-1])
                e = []
                for h in hashes:
                    e.append(
                        get_emb_variable(
                            name=h['table_name'],
                            ids=h['hash_fn'](ids_flat),
                            shape=(h['bucket_size'], h['emb_dim']),
                            mark_for_serving=mark_for_serving,
                            initializer=get_unit_emb_initializer(FLAGS.dim)
                        ))  # important: use normal, not uniform
                e = tf.concat(e, axis=1)
                if len(hashes) == 1 and hashes[0]['emb_dim'] == FLAGS.dim:
                    print('bloom filter w/o fc: [%s]' % hashes[0]['table_name'])
                else:
                    dnn_name = 'dnn__' + '__'.join(h['table_name'] for h in hashes)
                    dnn_in_dim = sum(h['emb_dim'] for h in hashes)
                    dnn = get_dnn_variable(
                        name=dnn_name,
                        shape=[dnn_in_dim, FLAGS.dim],
                        initializer=tf.glorot_normal_initializer(
                        ),  # important: use normal, not uniform
                        partitioner=tf.min_max_variable_partitioner(
                            max_partitions=self.ps_num,
                            min_slice_size=FLAGS.dnn_pt_size))
                    e = tf.matmul(e, dnn)
                if zero_pad:
                    id_eq_zero = tf.tile(
                        tf.expand_dims(tf.equal(ids_flat, 0), -1), [1, FLAGS.dim])
                    e = tf.where(id_eq_zero, tf.zeros_like(e), e)
                e = tf.reshape(e, get_shape(ids) + [FLAGS.dim])
                return e

            def combine_mean(emb_list):
                # element-wise average of at least two same-shaped embeddings
                assert len(emb_list) >= 2
                return sum(emb_list) * (1.0 / len(emb_list))

            # ---- user-side and click-sequence features ----
            self.usr_mem_emb = bloom_filter_emb(
                ids=self.inputs['user__uid'].var,
                hashes=self.inputs['user__uid'].spec['hashes'])
            # a click counts as "recent" when its clk_st value is positive
            self.is_recent_click = tf.greater(
                self.inputs['user__clk_st'].var, 0)  # [B,T], tf.bool
            self.is_recent_click_expand = tf.tile(
                tf.expand_dims(self.is_recent_click, -1), [1, 1, FLAGS.dim])
            clk_st_emb = bloom_filter_emb(
                ids=tf.abs(self.inputs['user__clk_st'].var),
                hashes=self.inputs['user__clk_st'].spec['hashes'])
            # request time minus click time, per position
            clk_rel_time_emb = bloom_filter_emb(
                ids=tf.tile(
                    tf.expand_dims(self.inputs['user__abs_time'].var, -1),
                    [1, FLAGS.max_len]) - self.inputs['user__clk_abs_time'].var,
                hashes=self.inputs['user__clk_abs_time'].
                spec['rel_time_hashes'])
            clk_nid_emb = bloom_filter_emb(
                ids=self.inputs['user__clk_nid'].var,
                hashes=self.inputs['user__clk_nid'].spec['hashes'])
            clk_uid_emb = bloom_filter_emb(
                ids=self.inputs['user__clk_uid'].var,
                hashes=self.inputs['user__clk_uid'].spec['hashes'])
            clk_cate_emb = bloom_filter_emb(
                ids=self.inputs['user__clk_cate'].var,
                hashes=self.inputs['user__clk_cate'].spec['hashes'])
            clk_cat1_emb = bloom_filter_emb(
                ids=self.inputs['user__clk_cat1'].var,
                hashes=self.inputs['user__clk_cat1'].spec['hashes'])
            self.clk_itm_emb = combine_mean(
                [clk_nid_emb, clk_uid_emb, clk_cate_emb,
                 clk_cat1_emb])  # [B,T,D]
            # the request time and the click times are embedded in one lookup
            # so that they share the same tables, then split apart again
            clk_ctx_time_key = bloom_filter_emb(
                ids=tf.concat([
                    tf.expand_dims(self.inputs['user__abs_time'].var, -1),
                    self.inputs['user__clk_abs_time'].var
                ], 1),  # [B,1] [B,T] -> [B,1+T]
                hashes=self.inputs['user__clk_abs_time'].
                spec['abs_time_hashes'])  # [B,1+T,D]
            usr_ctx_time_query, clk_ctx_time_key = tf.split(
                clk_ctx_time_key, [1, FLAGS.max_len], 1)  # [B,1,D], [B,T,D]
            usr_ctx_time_query = tf.squeeze(usr_ctx_time_query, [1])  # [B,D]
            # learned per-position logits, normalized over the recent clicks
            prob_psn = get_dnn_variable(
                name='position_w_%d' % FLAGS.max_len,
                shape=[1, FLAGS.max_len],
                initializer=get_unit_emb_initializer(FLAGS.dim))
            prob_psn = tf.tile(
                prob_psn, [get_shape(self.usr_ids)[0], 1])  # [1,T]->[B,T]
            prob_psn = weighted_softmax(prob_psn,
                                        tf.to_float(self.is_recent_click),
                                        axis=-1)  # [B,T] along T
            clk_ctx_cate_key = combine_mean([clk_cate_emb,
                                             clk_cat1_emb])  # [B,T,D]
            # position-weighted average of the category keys
            usr_ctx_cate_query = tf.squeeze(
                tf.matmul(tf.expand_dims(prob_psn, 1), clk_ctx_cate_key),
                [1])  # [B,1,T]x[B,T,D]->[B,1,D]->[B,D]
            self.clk_ctx_key_emb = tf.concat([
                combine_mean([clk_st_emb, clk_rel_time_emb]),
                clk_ctx_time_key, clk_ctx_cate_key
            ], 2)  # [B,T,3*D]
            self.usr_ctx_query_emb = tf.concat(
                [usr_ctx_time_query, usr_ctx_cate_query], 1)  # [B,2*D]

            # ---- positive (target) item features ----
            pos_nid_emb = bloom_filter_emb(
                ids=self.inputs['item__nid'].var,
                hashes=self.inputs['item__nid'].spec['hashes'],
                mark_for_serving=False)
            pos_uid_emb = bloom_filter_emb(
                ids=self.inputs['item__uid'].var,
                hashes=self.inputs['item__uid'].spec['hashes'],
                mark_for_serving=False)
            pos_cate_emb = bloom_filter_emb(
                ids=self.inputs['item__cate'].var,
                hashes=self.inputs['item__cate'].spec['hashes'],
                mark_for_serving=False)
            pos_cat1_emb = bloom_filter_emb(
                ids=self.inputs['item__cat1'].var,
                hashes=self.inputs['item__cat1'].spec['hashes'],
                mark_for_serving=False)
            self.pos_itm_emb = combine_mean(
                [pos_nid_emb, pos_uid_emb, pos_cate_emb,
                 pos_cat1_emb])  # [B,D]
            self.pos_itm_emb_normalized = tf.nn.l2_normalize(
                self.pos_itm_emb, -1)
            self.pos_cat_emb = combine_mean([pos_cate_emb,
                                             pos_cat1_emb])  # [B,D]

            # ---- negative item features (same tables as the positives) ----
            neg_nid_emb = bloom_filter_emb(
                ids=self.neg_itm_nid,
                hashes=self.inputs['item__nid'].spec['hashes'],
                mark_for_serving=False)
            neg_uid_emb = bloom_filter_emb(
                ids=self.neg_itm_uid,
                hashes=self.inputs['item__uid'].spec['hashes'],
                mark_for_serving=False)
            neg_cate_emb = bloom_filter_emb(
                ids=self.neg_itm_cate,
                hashes=self.inputs['item__cate'].spec['hashes'],
                mark_for_serving=False)
            neg_cat1_emb = bloom_filter_emb(
                ids=self.neg_itm_cat1,
                hashes=self.inputs['item__cat1'].spec['hashes'],
                mark_for_serving=False)
            self.neg_itm_emb = combine_mean(
                [neg_nid_emb, neg_uid_emb, neg_cate_emb,
                 neg_cat1_emb])  # [Q,D]
            self.neg_cat_emb = combine_mean([neg_cate_emb,
                                             neg_cat1_emb])  # [Q,D]
def disentangled_multi_vector_loss(multi_vec_emb,  # [B,V,D], normalized
                                   pos_emb,  # [B,D], normalized
                                   neg_emb,  # [Q,D], normalized
                                   multi_head_emb,  # [B,H,D], normalized
                                   prob_h,  # [B,H]
                                   query_msk=None,  # [B], tf.bool, mask if False
                                   neg_weights=None,  # [B,Q], w*exp(lgt)
                                   temperature=0.07,
                                   margin=0.0,
                                   model_prob_y_and_v=True,  # converge much faster and auc is much better when True
                                   metric_ops=None,
                                   scope='multi_loss',
                                   reuse=None,
                                   partitioner=None):
    """Multi-vector softmax loss plus three auxiliary regularizers.

    The main term is the expected log-likelihood of the positive item, where
    the expectation over the V user vectors uses `prob_v` (the head
    probabilities `prob_h` summed within each group of H/V heads).

    The auxiliary loss is the sum of:
      * `head_kl_loss`: penalizes heads whose batch-average probability falls
        below the uniform prior 1/H;
      * `sharpness_loss`: -log of the largest head probability per example,
        ignored once that probability exceeds 0.95;
      * `semantic_loss`: cross-entropy between `prob_h` and a softmax over the
        head/positive-item similarities.

    Raises:
        NotImplementedError: if `neg_weights` is None while
            `model_prob_y_and_v` is True (which are the defaults).

    Returns:
        (nll_loss, aux_loss), both scalar tensors.
    """
    with tf.variable_scope(scope, reuse=reuse, partitioner=partitioner):
        batch_size, num_vectors, emb_dim = get_shape(multi_vec_emb)
        num_negatives = get_shape(neg_emb)[0]
        _, num_heads, _ = get_shape(multi_head_emb)
        # It is Z = \sum_v \sum_i p(v,i), instead of Z_v = \sum_i p(i|v) for each v.
        # The datum is sampled from the support of all possible pairs of (v,i).
        pos_logits = tf.reduce_sum(
            multi_vec_emb * tf.expand_dims(pos_emb, 1), -1,
            True)  # [B,V,D]*[B,1,D]->[B,V]->[B,V,1]
        neg_logits = tf.reduce_sum(tf.multiply(
            tf.reshape(multi_vec_emb, [batch_size, num_vectors, 1, emb_dim]),
            tf.reshape(neg_emb, [1, 1, num_negatives, emb_dim])), -1)  # [B,V,Q]
        all_logits = tf.concat([pos_logits - margin, neg_logits], 2)  # [B,V,1+Q]
        if neg_weights is None:
            if model_prob_y_and_v:
                raise NotImplementedError
            log_prob_y_given_v = tf.nn.log_softmax(
                all_logits / temperature, -1)  # [B,V,1+Q]
            log_likelihood_v = log_prob_y_given_v[:, :, 0]  # [B,V]
        else:
            neg_weights_expand = tf.tile(
                tf.expand_dims(neg_weights, 1),
                [1, num_vectors, 1])  # [B,1,Q]->[B,V,Q]
            weights_of_exp = tf.concat(
                [tf.ones_like(pos_logits), neg_weights_expand],
                2)  # [B,V,1],[B,V,Q]->[B,V,1+Q]
            # prob_y_v = prob_y_and_v if model_prob_y_and_v else prob_y_given_v
            # (normalizing over axes [1, 2] makes it a joint over vectors and items)
            prob_y_v = weighted_softmax(
                all_logits / temperature, weights_of_exp,
                [1, 2] if model_prob_y_and_v else -1)  # [B,V,1+Q]
            log_likelihood_v = tf.log(prob_y_v[:, :, 0])  # [B,V]
        # probability of each vector = total probability of its group of heads
        prob_v = tf.reduce_sum(
            tf.reshape(prob_h,
                       [batch_size, num_vectors, num_heads // num_vectors]),
            -1)  # [B,V,H/V]->[B,V]
        log_likelihood = tf.reduce_sum(  # expected log_likelihood
            prob_v * log_likelihood_v, -1)  # [B]
        if query_msk is None:
            nll_loss = -tf.reduce_mean(log_likelihood)
        else:
            log_likelihood = tf.where(query_msk, log_likelihood,
                                      tf.zeros_like(log_likelihood))
            num_real_queries = tf.reduce_sum(tf.to_float(query_msk)) + 1e-8
            nll_loss = -tf.reduce_sum(log_likelihood) / num_real_queries
        # Issue 1: There are some dead heads that receive no categories.
        prior_h = 1.0 / tf.to_float(num_heads)
        posterior_h = tf.reduce_mean(prob_h, 0)  # [H]
        # version 1: (an okay version)
        # head_kl_loss = tf.reduce_sum(
        #     prior_h * (tf.log(prior_h) - tf.log(posterior_h)), -1)
        # version 2: (much worse than version 1, min(posterior_h) ~ 1e-5)
        # head_kl_loss = tf.reduce_sum(
        #     posterior_h * (tf.log(posterior_h) - tf.log(prior_h)), -1)
        # version 3: the relu only penalizes heads used less than the prior
        head_kl_loss = tf.reduce_sum(
            prior_h * tf.nn.relu(tf.log(prior_h) - tf.log(posterior_h)), -1)
        #
        # Issue 2: The same category is assigned to more than one head.
        #
        max_prob_h = tf.reduce_max(prob_h, -1)
        # version 1:
        # sharpness_loss = -tf.reduce_mean(tf.log(max_prob_h))
        # values above 0.95 are replaced by 1.0, so they add nothing to the loss
        max_prob_h_clip = tf.where(tf.greater(max_prob_h, 0.95),
                                   tf.ones_like(max_prob_h), max_prob_h)
        sharpness_loss = -tf.reduce_mean(
            tf.log(max_prob_h_clip))  # clip, don't be too aggressive
        # version 2:
        # sharpness_loss = -tf.reduce_mean(max_prob_h * tf.log(max_prob_h))
        # version 3: minimizes the entropy for a skewed distribution (too strong and may lead to NaN)
        # sharpness_loss = tf.reduce_sum(tf.multiply(prob_h, tf.log(prob_h)), -1)  # [B,H]->[B]
        # sharpness_loss = -tf.reduce_mean(sharpness_loss, -1)
        # version 4:
        # prob_h_clip = tf.where(  # max(prob_h) being too close to 1.0 will cause tf.log(0)=NaN
        #     tf.tile(tf.greater(tf.reduce_max(prob_h, -1, True), 0.95), [1, num_heads]),
        #     tf.zeros_like(prob_h), prob_h)
        # sharpness_loss = tf.reduce_sum(tf.multiply(prob_h_clip, tf.log(prob_h + 1e-8)), -1)  # [B,H]->[B]
        # sharpness_loss = -tf.reduce_mean(sharpness_loss, -1)
        #
        # Using -p*log(p) or -log(p) can be viewed as using one of the two different
        # directions of KL between the one-hot distribution and p. The gradient
        # (direction & steepness) of -p*log(p) seems to be nicer.
        #
        # Issue 3: The heads' output vectors are the same.
        semantic_loss = tf.reduce_sum(tf.multiply(  # [B,H]
            multi_head_emb, tf.expand_dims(pos_emb, 1)), -1)  # [B,H,D]*[B,1,D]->[B,H]
        semantic_loss = tf.nn.log_softmax(semantic_loss / temperature, -1)  # [B,H]
        # version 1:
        # one_hot_h = tf.one_hot(tf.argmax(prob_h, -1), num_heads, dtype=tf.float32)  # [B,H]
        # semantic_loss = -tf.reduce_mean(
        #     tf.reduce_sum(tf.multiply(semantic_loss, one_hot_h), -1), -1)
        # version 2:
        semantic_loss = -tf.reduce_mean(
            tf.reduce_sum(tf.multiply(semantic_loss, prob_h), -1), -1)
        # Here sharpness_loss and semantic_loss can in fact be unified into one
        # single regularization loss, by using prob_h to weight the latter.
        if metric_ops is not None:
            max_pos_lgt = tf.reduce_max(tf.squeeze(pos_logits, [2]), -1)
            if neg_weights is None:
                max_neg_lgt = tf.reduce_max(tf.reduce_max(neg_logits, -1),
                                            -1)  # [B,V,Q]->[B]
            else:
                # fold the weights into the logits so that near-zero-weight
                # negatives cannot win the comparison
                max_neg_lgt = tf.reduce_max(tf.reduce_max(
                    tf.log(neg_weights_expand + 1e-8) + neg_logits, -1),
                                            -1)  # [B,V,Q]->[B]
            add_or_fail_if_exist(metric_ops, scope + '/nll_loss', nll_loss)
            add_accuracy_metrics(metric_ops=metric_ops,
                                 scope=scope,
                                 max_pos_lgt=max_pos_lgt,
                                 max_neg_lgt=max_neg_lgt,
                                 query_msk=query_msk,
                                 neg_weights=neg_weights)
            add_or_fail_if_exist(metric_ops, scope + '_ex/kl_loss',
                                 head_kl_loss)
            add_or_fail_if_exist(metric_ops, scope + '_ex/sharp_loss',
                                 sharpness_loss)
            add_or_fail_if_exist(metric_ops, scope + '_ex/semantic_loss',
                                 semantic_loss)
            add_or_fail_if_exist(metric_ops, scope + '_ex/max_prob_h',
                                 tf.reduce_mean(tf.reduce_max(prob_h, -1)))
            add_or_fail_if_exist(metric_ops, scope + '_ex/max_post_h',
                                 tf.reduce_max(posterior_h))
            add_or_fail_if_exist(metric_ops, scope + '_ex/min_post_h',
                                 tf.reduce_min(posterior_h))
        aux_loss = head_kl_loss + sharpness_loss + semantic_loss
        return nll_loss, aux_loss