Ejemplo n.º 1
0
    def test_create_tf_dataset(self):
        """Verifies the first two batches yielded by util.create_tf_dataset.

        Three users with four (v, v, v) events each are batched two at a time
        with input randomization disabled. The assertions expect `items` to be
        the user's values shifted right by one position with a leading 0, and
        `queries`, `times` and `labels` to equal the full value sequence. The
        second batch is expected to pair user 3 with user 1 again.
        """
        # Users 1, 2 and 3 own the consecutive values 1-4, 5-8 and 9-12; every
        # event is the same value repeated three times.
        events_per_user = {
            user: [(value,) * 3 for value in range(first, first + 4)]
            for user, first in ((1, 1), (2, 5), (3, 9))
        }
        dataset, _, _ = util.create_tf_dataset(
            data_dict=events_per_user,
            batch_size=2,
            itemnum=1,
            query_map={1: 'test'},
            maxquerylen=1,
            maxseqlen=4,
            token_drop_prob=0,
            user_query_seed=collections.defaultdict(int),
            randomize_input=False,
        )
        batch = dataset.make_one_shot_iterator().get_next()

        # One entry per consecutive batch; each entry lists (feature, expected)
        # pairs in the order in which they are fetched and checked.
        expected_batches = (
            (
                ('user_ids', [1, 2]),
                ('items', [[0, 1, 2, 3], [0, 5, 6, 7]]),
                ('queries', [[1, 2, 3, 4], [5, 6, 7, 8]]),
                ('times', [[1, 2, 3, 4], [5, 6, 7, 8]]),
                ('labels', [[1, 2, 3, 4], [5, 6, 7, 8]]),
            ),
            (
                ('user_ids', [3, 1]),
                ('items', [[0, 9, 10, 11], [0, 1, 2, 3]]),
                ('queries', [[9, 10, 11, 12], [1, 2, 3, 4]]),
                ('times', [[9, 10, 11, 12], [1, 2, 3, 4]]),
                ('labels', [[9, 10, 11, 12], [1, 2, 3, 4]]),
            ),
        )
        for expected in expected_batches:
            # Each evaluate() call pulls the next batch from the iterator.
            fetched = self.evaluate([batch[name] for name, _ in expected])
            for (_, want), got in zip(expected, fetched):
                self.assertAllEqual(got, want)
Ejemplo n.º 2
0
def main(_):
    """Trains the model and periodically evaluates it, logging the metrics.

    Loads the query map and the partitioned dataset, builds the input pipeline
    and the model, then runs FLAGS.num_epochs training epochs of `num_batch`
    steps each. The average training loss of every epoch is written as a
    summary. Every FLAGS.eval_frequency epochs the model is evaluated on the
    test and validation splits (and on the train split when
    FLAGS.save_train_eval is set); the two metrics returned by util.evaluate
    are written as the 'NDCG@10' and 'HIT@10' summaries, and the
    validation/test line is also written to FLAGS.train_dir/log.txt.

    The session, the text log and the summary writers are released in a
    `finally` clause, so an exception raised while training or evaluating does
    not leak them, and the summary writers are always closed (which flushes
    their pending events).

    Args:
      _: Unused positional argument (presumably the argv list passed by the
        application runner).
    """

    # Load raw data and organize training directory.
    with tf.gfile.Open(FLAGS.query_map_path, 'r') as query_map_file:
        query_map = json.load(query_map_file)
    tf.logging.info('Query map is loaded.')
    dataset = util.data_partition(FLAGS.dataset)
    tf.logging.info('Data is loaded')
    [user_train, _, _, usernum, _, itemnum, user_query_seed,
     item_popularity] = dataset
    # NOTE(review): num_batch is 0 when len(user_train) < FLAGS.batch_size, in
    # which case the per-epoch loss average below divides by zero.
    num_batch = int(len(user_train) / FLAGS.batch_size)
    tf.gfile.MakeDirs(FLAGS.train_dir)

    # Everything that must be released is tracked here so that the `finally`
    # clause can clean up no matter where a failure happens.
    summary_writers = []
    log_file = None
    sess = None

    def open_summary_writer(subdir):
        """Creates a summary writer under FLAGS.train_dir and tracks it."""
        writer = tf.summary.FileWriter(os.path.join(FLAGS.train_dir, subdir))
        summary_writers.append(writer)
        return writer

    try:
        # Create summary/log files and respective variables.
        train_summary_writer = open_summary_writer('train')
        valid_summary_writer = open_summary_writer('validation')
        test_summary_writer = open_summary_writer('test')
        value_loss = tf.placeholder(tf.float32, [])
        value_ndcg = tf.placeholder(tf.float32, [])
        value_hit = tf.placeholder(tf.float32, [])
        summary_loss = tf.summary.scalar('Loss', value_loss)
        summary_ndcg = tf.summary.scalar('NDCG@10', value_ndcg)
        summary_hit = tf.summary.scalar('HIT@10', value_hit)
        log_filename = os.path.join(FLAGS.train_dir, 'log.txt')
        log_file = tf.gfile.Open(log_filename, 'w')
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.Session(config=config)

        def write_scalar(writer, summary_op, placeholder, value, step):
            """Evaluates one scalar summary op and records it at `step`."""
            writer.add_summary(
                sess.run(summary_op, feed_dict={placeholder: value}), step)

        def write_ranking_metrics(writer, metrics, step):
            """Records metrics[0] as NDCG@10 and metrics[1] as HIT@10."""
            write_scalar(writer, summary_ndcg, value_ndcg, metrics[0], step)
            write_scalar(writer, summary_hit, value_hit, metrics[1], step)

        # Fetch dataset
        tf_dataset, vocab, query_word_ids = util.create_tf_dataset(
            user_train,
            FLAGS.batch_size,
            itemnum=itemnum,
            query_map=query_map,
            maxquerylen=FLAGS.maxquerylen,
            maxseqlen=FLAGS.maxseqlen,
            token_drop_prob=FLAGS.token_drop_prob,
            user_query_seed=user_query_seed,
            randomize_input=True,
            random_seed=0)

        # Create model
        model = model_lib.Model(
            usernum,
            itemnum,
            len(vocab),
            use_last_query=FLAGS.use_last_query,
            maxseqlen=FLAGS.maxseqlen,
            maxquerylen=FLAGS.maxquerylen,
            hidden_units=FLAGS.hidden_units,
            l2_emb=FLAGS.l2_emb,
            dropout_rate=FLAGS.dropout_rate,
            lr=FLAGS.lr,
            num_self_attn_heads=FLAGS.num_self_attn_heads,
            num_query_attn_heads=FLAGS.num_query_attn_heads,
            num_self_attn_layers=FLAGS.num_self_attn_layers,
            num_query_attn_layers=FLAGS.num_query_attn_layers,
            num_final_layers=FLAGS.num_final_layers,
            query_item_attention=FLAGS.query_item_attention,
            query_item_combine=FLAGS.query_item_combine,
            query_layer_norm=FLAGS.query_layer_norm,
            query_residual=FLAGS.query_residual,
            time_exp_base=FLAGS.time_exp_base,
            overlapping_chunks=FLAGS.overlapping_chunks)
        tf.logging.info('Model is created.')
        sess.run(tf.global_variables_initializer())
        # raw_time accumulates the time between evaluations only; t0 is reset
        # after each evaluation so evaluation itself is not counted.
        raw_time = 0.0
        t0 = time.time()

        iterator = tf_dataset.make_one_shot_iterator()
        batch_data = iterator.get_next()
        user_id = batch_data['user_ids']
        item_seq = batch_data['items']
        query_seq = batch_data['queries']
        query_words_seq = batch_data['query_words']
        time_seq = batch_data['times']
        label_seq = batch_data['labels']
        random_neg = batch_data['random_neg']

        # For popularity based negative sampling, we sample a large set of
        # lists in advance (each consisting of FLAGS.neg_sample_size_eval
        # negative samples) and later randomly select one of the pre-sampled
        # lists while evaluating each user. Since sampling a list of elements
        # with a given probability distribution is a rather slow operation,
        # this is a much faster approach compared to re-sampling for each user
        # every time we perform evaluation.
        presampled_negatives = []
        if FLAGS.sampling_strategy == 'popularity':
            tf.logging.info(
                'Presampling negatives for popularity based strategy.')
            presampled_negatives.extend(
                util.presample_popularity_negatives(
                    1,
                    itemnum + 1,
                    FLAGS.neg_sample_size_eval,
                    item_popularity,
                    NUM_PRESAMPLED_LISTS_OF_POPULARITY_NEGATIVES,
                ), )

        def evaluate_split(split):
            """Runs util.evaluate on the given split with the shared setup."""
            return util.evaluate(
                model,
                dataset,
                query_word_ids,
                FLAGS.maxseqlen,
                FLAGS.maxquerylen,
                sess,
                FLAGS.token_drop_prob,
                neg_sample_size=FLAGS.neg_sample_size_eval,
                presampled_negatives=presampled_negatives,
                eval_on=split)

        # Start training.
        for epoch in range(1, FLAGS.num_epochs + 1):
            tf.logging.info('Epoch %d' % epoch)
            epoch_loss = 0
            for _ in range(num_batch):
                # Pull the next batch out of the input pipeline, then feed it
                # to the model through its placeholders.
                u, x, q, q_w, t, y, ny = sess.run([
                    user_id, item_seq, query_seq, query_words_seq, time_seq,
                    label_seq, random_neg
                ])
                loss, _ = sess.run(
                    [model.loss, model.train_op], {
                        model.u: u,
                        model.item_seq: x,
                        model.query_seq: q,
                        model.query_words_seq: q_w,
                        model.time_seq: t,
                        model.pos: y,
                        model.neg: ny,
                        model.is_training: True
                    })
                epoch_loss += loss

            # Adding average epoch train loss summary.
            write_scalar(train_summary_writer, summary_loss, value_loss,
                         float(epoch_loss / num_batch), epoch)

            # Evaluate.
            if epoch % FLAGS.eval_frequency == 0:
                t1 = time.time() - t0
                raw_time += t1
                tf.logging.info('Evaluating')
                tf.logging.info('Sampling strategy is: {}'.format(
                    FLAGS.sampling_strategy))
                t_test = evaluate_split('test')
                t_valid = evaluate_split('valid')
                eval_str = (
                    'epoch:%d, time: %f(s), '
                    'valid (NDCG@10: %.4f, HR@10: %.4f),'
                    ' test (NDCG@10: %.4f, HR@10: %.4f)'
                ) % (epoch, raw_time, t_valid[0], t_valid[1], t_test[0],
                     t_test[1])
                tf.logging.info(eval_str)
                log_file.write(eval_str + '\n')
                log_file.flush()
                t0 = time.time()
                write_ranking_metrics(valid_summary_writer, t_valid, epoch)
                write_ranking_metrics(test_summary_writer, t_test, epoch)

                # Evaluate on train split.
                if FLAGS.save_train_eval:
                    t_train = evaluate_split('train')
                    train_str = (
                        'epoch:%d, time: %f(s), '
                        'train (NDCG@10: %.4f, HR@10: %.4f)'
                    ) % (epoch, raw_time, t_train[0], t_train[1])
                    tf.logging.info(train_str)
                    write_ranking_metrics(train_summary_writer, t_train,
                                          epoch)
        tf.logging.info('Done. Log written to %s' % log_filename)
    finally:
        # Release resources even when training or evaluation raised.
        if sess is not None:
            sess.close()
        if log_file is not None:
            log_file.close()
        for writer in summary_writers:
            writer.close()