def test_create_tf_dataset(self):
  data_dict = {
      1: [(1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)],
      2: [(5, 5, 5), (6, 6, 6), (7, 7, 7), (8, 8, 8)],
      3: [(9, 9, 9), (10, 10, 10), (11, 11, 11), (12, 12, 12)]
  }
  dataset, _, _ = util.create_tf_dataset(
      data_dict=data_dict,
      batch_size=2,
      itemnum=1,
      query_map={1: 'test'},
      maxquerylen=1,
      maxseqlen=4,
      token_drop_prob=0,
      user_query_seed=collections.defaultdict(lambda: 0),
      randomize_input=False,
  )
  batch = dataset.make_one_shot_iterator().get_next()
  user_ids, items, queries, times, labels = self.evaluate([
      batch['user_ids'], batch['items'], batch['queries'], batch['times'],
      batch['labels']
  ])
  self.assertAllEqual(user_ids, [1, 2])
  self.assertAllEqual(items, [[0, 1, 2, 3], [0, 5, 6, 7]])
  self.assertAllEqual(queries, [[1, 2, 3, 4], [5, 6, 7, 8]])
  self.assertAllEqual(times, [[1, 2, 3, 4], [5, 6, 7, 8]])
  self.assertAllEqual(labels, [[1, 2, 3, 4], [5, 6, 7, 8]])

  # Evaluating the same ops again advances the one-shot iterator, so this
  # fetches the next batch: user 3 plus user 1 again (the 3 users wrap around
  # at batch_size=2 with randomize_input=False).
  user_ids, items, queries, times, labels = self.evaluate([
      batch['user_ids'], batch['items'], batch['queries'], batch['times'],
      batch['labels']
  ])
  self.assertAllEqual(user_ids, [3, 1])
  self.assertAllEqual(items, [[0, 9, 10, 11], [0, 1, 2, 3]])
  self.assertAllEqual(queries, [[9, 10, 11, 12], [1, 2, 3, 4]])
  self.assertAllEqual(times, [[9, 10, 11, 12], [1, 2, 3, 4]])
  self.assertAllEqual(labels, [[9, 10, 11, 12], [1, 2, 3, 4]])
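
# A minimal illustrative sketch, inferred from the expected tensors in the
# test above rather than from any documented contract of
# util.create_tf_dataset: each `items` row looks like the label sequence
# shifted right by one step with a leading 0 pad, so that labels[t] is the
# "next item" target for items[t]. The helper name below is hypothetical.
def _shift_right_with_pad(seq, pad=0):
  """Illustrative only; not used by the test above."""
  return [pad] + list(seq[:-1])

# For example, _shift_right_with_pad([1, 2, 3, 4]) returns [0, 1, 2, 3],
# matching the first `items` row asserted in the test.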
def main(_):
  # Load raw data and organize training directory.
  with tf.gfile.Open(FLAGS.query_map_path, 'r') as f:
    query_map = json.load(f)
  tf.logging.info('Query map is loaded.')
  dataset = util.data_partition(FLAGS.dataset)
  tf.logging.info('Data is loaded.')
  [user_train, _, _, usernum, _, itemnum, user_query_seed,
   item_popularity] = dataset
  num_batch = int(len(user_train) / FLAGS.batch_size)
  tf.gfile.MakeDirs(FLAGS.train_dir)

  # Create summary/log files and respective variables.
  train_summary_writer = tf.summary.FileWriter(
      os.path.join(FLAGS.train_dir, 'train'))
  valid_summary_writer = tf.summary.FileWriter(
      os.path.join(FLAGS.train_dir, 'validation'))
  test_summary_writer = tf.summary.FileWriter(
      os.path.join(FLAGS.train_dir, 'test'))
  value_loss = tf.placeholder(tf.float32, [])
  value_ndcg = tf.placeholder(tf.float32, [])
  value_hit = tf.placeholder(tf.float32, [])
  summary_loss = tf.summary.scalar('Loss', value_loss)
  summary_ndcg = tf.summary.scalar('NDCG@10', value_ndcg)
  summary_hit = tf.summary.scalar('HIT@10', value_hit)
  log_filename = os.path.join(FLAGS.train_dir, 'log.txt')
  f = tf.gfile.Open(log_filename, 'w')

  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  config.allow_soft_placement = True
  sess = tf.Session(config=config)

  # Fetch dataset.
  tf_dataset, vocab, query_word_ids = util.create_tf_dataset(
      user_train,
      FLAGS.batch_size,
      itemnum=itemnum,
      query_map=query_map,
      maxquerylen=FLAGS.maxquerylen,
      maxseqlen=FLAGS.maxseqlen,
      token_drop_prob=FLAGS.token_drop_prob,
      user_query_seed=user_query_seed,
      randomize_input=True,
      random_seed=0)

  # Create model.
  model = model_lib.Model(
      usernum,
      itemnum,
      len(vocab),
      use_last_query=FLAGS.use_last_query,
      maxseqlen=FLAGS.maxseqlen,
      maxquerylen=FLAGS.maxquerylen,
      hidden_units=FLAGS.hidden_units,
      l2_emb=FLAGS.l2_emb,
      dropout_rate=FLAGS.dropout_rate,
      lr=FLAGS.lr,
      num_self_attn_heads=FLAGS.num_self_attn_heads,
      num_query_attn_heads=FLAGS.num_query_attn_heads,
      num_self_attn_layers=FLAGS.num_self_attn_layers,
      num_query_attn_layers=FLAGS.num_query_attn_layers,
      num_final_layers=FLAGS.num_final_layers,
      query_item_attention=FLAGS.query_item_attention,
      query_item_combine=FLAGS.query_item_combine,
      query_layer_norm=FLAGS.query_layer_norm,
      query_residual=FLAGS.query_residual,
      time_exp_base=FLAGS.time_exp_base,
      overlapping_chunks=FLAGS.overlapping_chunks)
  tf.logging.info('Model is created.')
  sess.run(tf.global_variables_initializer())

  raw_time = 0.0
  t0 = time.time()
  iterator = tf_dataset.make_one_shot_iterator()
  batch_data = iterator.get_next()
  user_id = batch_data['user_ids']
  item_seq = batch_data['items']
  query_seq = batch_data['queries']
  query_words_seq = batch_data['query_words']
  time_seq = batch_data['times']
  label_seq = batch_data['labels']
  random_neg = batch_data['random_neg']

  # For popularity-based negative sampling, we pre-sample a large set of
  # lists (each consisting of FLAGS.neg_sample_size_eval negative samples)
  # and later randomly select one of the pre-sampled lists while evaluating
  # each user. Since sampling a list of elements under a given probability
  # distribution is a rather slow operation, this is much faster than
  # re-sampling for each user every time we perform evaluation. (An
  # illustrative sketch of this idea appears after main() below.)
  presampled_negatives = []
  if FLAGS.sampling_strategy == 'popularity':
    tf.logging.info('Presampling negatives for popularity based strategy.')
    presampled_negatives.extend(
        util.presample_popularity_negatives(
            1,
            itemnum + 1,
            FLAGS.neg_sample_size_eval,
            item_popularity,
            NUM_PRESAMPLED_LISTS_OF_POPULARITY_NEGATIVES,
        ),
    )

  # Start training.
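  # Each training step below is two session calls: the first materializes a
  # batch from the tf.data iterator tensors into numpy arrays, and the second
  # feeds those arrays into the model's placeholders and runs one optimizer
  # update (is_training=True, which typically enables dropout).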
  for epoch in range(1, FLAGS.num_epochs + 1):
    tf.logging.info('Epoch %d' % epoch)
    epoch_loss = 0
    for _ in range(num_batch):
      u, x, q, q_w, t, y, ny = sess.run([
          user_id, item_seq, query_seq, query_words_seq, time_seq, label_seq,
          random_neg
      ])
      loss, _ = sess.run(
          [model.loss, model.train_op], {
              model.u: u,
              model.item_seq: x,
              model.query_seq: q,
              model.query_words_seq: q_w,
              model.time_seq: t,
              model.pos: y,
              model.neg: ny,
              model.is_training: True
          })
      epoch_loss += loss

    # Adding average epoch train loss summary.
    train_summary_writer.add_summary(
        sess.run(
            summary_loss,
            feed_dict={value_loss: float(epoch_loss / num_batch)}), epoch)

    # Evaluate.
    if epoch % FLAGS.eval_frequency == 0:
      t1 = time.time() - t0
      raw_time += t1
      tf.logging.info('Evaluating')
      tf.logging.info('Sampling strategy is: {}'.format(
          FLAGS.sampling_strategy))
      t_test = util.evaluate(
          model,
          dataset,
          query_word_ids,
          FLAGS.maxseqlen,
          FLAGS.maxquerylen,
          sess,
          FLAGS.token_drop_prob,
          neg_sample_size=FLAGS.neg_sample_size_eval,
          presampled_negatives=presampled_negatives,
          eval_on='test')
      t_valid = util.evaluate(
          model,
          dataset,
          query_word_ids,
          FLAGS.maxseqlen,
          FLAGS.maxquerylen,
          sess,
          FLAGS.token_drop_prob,
          neg_sample_size=FLAGS.neg_sample_size_eval,
          presampled_negatives=presampled_negatives,
          eval_on='valid')
      eval_str = (
          'epoch:%d, time: %f(s), valid (NDCG@10: %.4f, HR@10: %.4f),'
          ' test (NDCG@10: %.4f, HR@10: %.4f)'
      ) % (epoch, raw_time, t_valid[0], t_valid[1], t_test[0], t_test[1])
      tf.logging.info(eval_str)
      f.write(eval_str + '\n')
      f.flush()
      t0 = time.time()
      valid_summary_writer.add_summary(
          sess.run(summary_ndcg, feed_dict={value_ndcg: t_valid[0]}), epoch)
      valid_summary_writer.add_summary(
          sess.run(summary_hit, feed_dict={value_hit: t_valid[1]}), epoch)
      test_summary_writer.add_summary(
          sess.run(summary_ndcg, feed_dict={value_ndcg: t_test[0]}), epoch)
      test_summary_writer.add_summary(
          sess.run(summary_hit, feed_dict={value_hit: t_test[1]}), epoch)

      # Evaluate on train split.
      if FLAGS.save_train_eval:
        t_train = util.evaluate(
            model,
            dataset,
            query_word_ids,
            FLAGS.maxseqlen,
            FLAGS.maxquerylen,
            sess,
            FLAGS.token_drop_prob,
            neg_sample_size=FLAGS.neg_sample_size_eval,
            presampled_negatives=presampled_negatives,
            eval_on='train')
        train_str = (
            'epoch:%d, time: %f(s), train (NDCG@10: %.4f, HR@10: %.4f)'
        ) % (epoch, raw_time, t_train[0], t_train[1])
        tf.logging.info(train_str)
        train_summary_writer.add_summary(
            sess.run(summary_ndcg, feed_dict={value_ndcg: t_train[0]}), epoch)
        train_summary_writer.add_summary(
            sess.run(summary_hit, feed_dict={value_hit: t_train[1]}), epoch)

  tf.logging.info('Done. Log written to %s' % log_filename)
  sess.close()
  f.close()
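
# A minimal sketch of the presampling idea referenced in main() above. This is
# an assumption about what util.presample_popularity_negatives does, written
# for illustration only: draw `num_lists` lists of `sample_size` negatives up
# front, weighted by item popularity, so evaluation can pick a ready-made list
# instead of re-running the (slow) weighted sampling once per user. The
# function name and signature below are hypothetical.
def _presample_popularity_negatives_sketch(low, high, sample_size,
                                           item_popularity, num_lists):
  """Illustrative only; item_popularity maps item id -> unnormalized weight."""
  import numpy as np  # Assumed available; used only by this sketch.
  items = np.arange(low, high)
  probs = np.array([item_popularity[i] for i in items], dtype=np.float64)
  probs /= probs.sum()  # np.random.choice requires probabilities summing to 1.
  return [
      np.random.choice(items, size=sample_size, replace=False, p=probs)
      for _ in range(num_lists)
  ]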