def main():
    """Train a ``model.DeepCF`` (DropoutNet-style) model and report cold-start recall.

    Configuration is read from the module-level ``args`` namespace
    (``data_dir``, ``checkpoint_path``, ``tb_log_path``, ``model_select``,
    ``rank``, ``dropout``, ``eval_every``, ``lr``, ``model_device``,
    ``inf_device``).

    Each step takes a batch of users and builds training targets. The targets
    are the ``n_scores_user`` values from ``dropout_net.tf_topk_vals`` /
    ``tf_topk_inds`` plus ``n_scores_user`` uniformly sampled items per user,
    scored by ``dropout_net.preds_random``. The model is then trained on those
    targets. A ``dropout`` fraction of user and item preference rows is
    redirected to an all-zero row.

    Every ``eval_every`` steps the cold-start recall is computed. When its sum
    over all k improves, the session is saved to
    ``<checkpoint_path><experiment>/tf_checkpoint``, if a checkpoint path was
    given.

    Raises:
        Exception: if the accumulated training loss of a step becomes NaN.
    """
    data_path = args.data_dir
    checkpoint_path = args.checkpoint_path
    tb_log_path = args.tb_log_path
    model_select = args.model_select

    rank_out = args.rank
    user_batch_size = 1000
    n_scores_user = 2500
    data_batch_size = 100
    dropout = args.dropout
    recall_at = range(10, 110, 10)
    eval_batch_size = 1000
    max_data_per_step = 2500000
    eval_every = args.eval_every
    num_epoch = 500

    _lr = args.lr
    _decay_lr_every = 100
    _lr_decay = 0.1

    experiment = '%s_%s' % (
        datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'), '-'.join(
            str(x / 100) for x in model_select) if model_select else 'simple')
    _tf_ckpt_file = (None if checkpoint_path is None else
                     checkpoint_path + experiment + '/tf_checkpoint')

    print('running: ' + experiment)

    dat = load_data(data_path)
    u_pref_scaled = dat['u_pref_scaled']
    v_pref_scaled = dat['v_pref_scaled']
    eval_cold = dat['eval_cold']
    item_content = dat['item_content']
    u_pref = dat['u_pref']
    v_pref = dat['v_pref']
    user_indices = dat['user_indices']

    timer = utils.timer(name='main').tic()

    # Append one all-zero row to each scaled preference matrix. Dropping a
    # user's or item's preference vector during training is then done by
    # pointing its row index at that extra row (index *_pref_last).
    v_pref_expanded = np.vstack(
        [v_pref_scaled, np.zeros_like(v_pref_scaled[0, :])])
    v_pref_last = v_pref_scaled.shape[0]
    u_pref_expanded = np.vstack(
        [u_pref_scaled, np.zeros_like(u_pref_scaled[0, :])])
    u_pref_last = u_pref_scaled.shape[0]
    timer.toc('initialized numpy data for tf')

    # prep eval
    timer.tic()
    eval_cold.init_tf(u_pref_scaled, v_pref_scaled, None, item_content,
                      eval_batch_size)
    timer.toc('initialized eval for tf').tic()

    dropout_net = model.DeepCF(latent_rank_in=u_pref.shape[1],
                               user_content_rank=0,
                               item_content_rank=item_content.shape[1],
                               model_select=model_select,
                               rank_out=rank_out)

    config = tf.ConfigProto(allow_soft_placement=True)

    with tf.device(args.model_device):
        dropout_net.build_model()

    with tf.device(args.inf_device):
        dropout_net.build_predictor(recall_at, n_scores_user)

    with tf.Session(config=config) as sess:
        tf_saver = None if _tf_ckpt_file is None else tf.train.Saver()
        train_writer = None if tb_log_path is None else tf.summary.FileWriter(
            tb_log_path + experiment, sess.graph)
        tf.global_variables_initializer().run()
        # Build the local-variable initializer once and re-run that same op
        # each step. Calling tf.local_variables_initializer() inside the loop
        # would add a new op to the graph on every training step.
        local_init_op = tf.local_variables_initializer()
        local_init_op.run()
        timer.toc('initialized tf')

        row_index = np.copy(user_indices)
        n_step = 0
        # Per-k recall of the best evaluation so far. It is an array rather
        # than a scalar 0 so the report below can iterate over it even if the
        # first evaluation does not improve on it.
        best_cold = np.zeros(len(recall_at))
        n_batch_trained = 0
        best_step = 0
        try:
            for _ in range(num_epoch):
                np.random.shuffle(row_index)
                for b in utils.batch(row_index, user_batch_size):
                    n_step += 1
                    # prep targets: n_scores_user uniformly sampled items per
                    # user in b, in addition to the top-k outputs below
                    target_users = np.repeat(b, n_scores_user)
                    target_users_rand = np.repeat(np.arange(len(b)),
                                                  n_scores_user)
                    target_items_rand = [
                        np.random.choice(v_pref.shape[0], n_scores_user)
                        for _ in b
                    ]
                    target_items_rand = np.array(target_items_rand).flatten()
                    target_ui_rand = np.transpose(
                        np.vstack([target_users_rand, target_items_rand]))
                    [target_scores, target_items, random_scores] = sess.run(
                        [
                            dropout_net.tf_topk_vals, dropout_net.tf_topk_inds,
                            dropout_net.preds_random
                        ],
                        feed_dict={
                            dropout_net.U_pref_tf: u_pref[b, :],
                            dropout_net.V_pref_tf: v_pref,
                            dropout_net.rand_target_ui: target_ui_rand
                        })
                    # Merge topN and randomN items per user. This presumes the
                    # top-k outputs are laid out per user in the same order as
                    # np.repeat(b, n_scores_user), so the user ids are appended
                    # twice.
                    target_scores = np.append(target_scores, random_scores)
                    target_items = np.append(target_items, target_items_rand)
                    target_users = np.append(target_users, target_users)

                    local_init_op.run()
                    n_targets = len(target_scores)
                    perm = np.random.permutation(n_targets)
                    # cap the number of training pairs used in one step
                    n_targets = min(n_targets, max_data_per_step)
                    data_batch = [(n, min(n + data_batch_size, n_targets))
                                  for n in range(0, n_targets, data_batch_size)]
                    f_batch = 0
                    for (start, stop) in tqdm(data_batch):
                        batch_perm = perm[start:stop]
                        batch_users = target_users[batch_perm]
                        batch_items = target_items[batch_perm]
                        if dropout != 0:
                            n_to_drop = int(np.floor(dropout * len(batch_perm)))
                            # Two independent draws: the first picks the item
                            # rows and the second the user rows to redirect to
                            # the all-zero row.
                            drop_item_rows = np.random.permutation(
                                len(batch_perm))[:n_to_drop]
                            drop_user_rows = np.random.permutation(
                                len(batch_perm))[:n_to_drop]
                            batch_v_pref = np.copy(batch_items)
                            batch_u_pref = np.copy(batch_users)
                            batch_v_pref[drop_item_rows] = v_pref_last
                            batch_u_pref[drop_user_rows] = u_pref_last
                        else:
                            batch_v_pref = batch_items
                            batch_u_pref = batch_users
                        item_content_batch = item_content[batch_items, :]
                        if sp.issparse(item_content):
                            item_content_batch = item_content_batch.todense()

                        _, _, loss_out = sess.run(
                            [
                                dropout_net.preds, dropout_net.updates,
                                dropout_net.loss
                            ],
                            feed_dict={
                                dropout_net.Uin: u_pref_expanded[batch_u_pref, :],
                                dropout_net.Vin: v_pref_expanded[batch_v_pref, :],
                                dropout_net.Vcontent: item_content_batch,
                                dropout_net.target: target_scores[batch_perm],
                                dropout_net.lr_placeholder: _lr,
                                dropout_net.phase: 1
                            })
                        f_batch += loss_out
                        if np.isnan(f_batch):
                            raise Exception('f is nan')

                    n_batch_trained += len(data_batch)
                    if n_step % _decay_lr_every == 0:
                        _lr = _lr_decay * _lr
                        print('decayed lr:' + str(_lr))
                    if n_step % eval_every == 0:
                        recall_cold = utils.batch_eval_recall(
                            sess,
                            dropout_net.eval_preds_cold,
                            eval_feed_dict=dropout_net.get_eval_dict,
                            recall_k=recall_at,
                            eval_data=eval_cold)

                        # checkpoint when the recall summed over all k improves
                        if np.sum(recall_cold) > np.sum(best_cold):
                            best_cold = recall_cold
                            best_step = n_step
                            if tf_saver is not None:
                                tf_saver.save(sess, _tf_ckpt_file)

                        timer.toc('%d [%d]b [%d]tot f=%.2f best[%d]' %
                                  (n_step, len(data_batch), n_batch_trained,
                                   f_batch, best_step)).tic()
                        print('\t\t\t' + ' '.join([('@' + str(i)).ljust(6)
                                                   for i in recall_at]))
                        print('cold start\t%s' %
                              (' '.join(['%.4f' % i for i in recall_cold]), ))
                        print('best epoch[%d]\t%s' % (
                            best_step,
                            ' '.join(['%.4f' % i for i in best_cold]),
                        ))
                        # only recall@k with k divisible by 100 goes to TensorBoard
                        summaries = [
                            tf.Summary.Value(tag="recall@" + str(k) + " cold",
                                             simple_value=recall_cold[i])
                            for i, k in enumerate(recall_at) if k % 100 == 0
                        ]
                        recall_summary = tf.Summary(value=summaries)
                        if train_writer is not None:
                            train_writer.add_summary(recall_summary, n_step)
        finally:
            if train_writer is not None:
                # Flush pending summaries and release the event file, even if
                # training aborted (e.g. on a NaN loss).
                train_writer.close()
Beispiel #2
0
def main():
    """Train a ``model.Heater`` model for cold-start users and report recall/precision/NDCG.

    Configuration is read from the module-level ``args`` namespace (``data``,
    ``model_select``, ``rank``, ``dropout``, ``eval_every``, ``neg``, ``lr``,
    ``reg``, ``alpha``, ``dim``).

    Each epoch:
      * draws a fresh training set with ``utils.negative_sampling`` and trains
        on shuffled mini-batches;
      * flags a random ``dropout`` fraction of each batch's rows through
        ``heater.dropout_user_indicator``;
      * every ``eval_every`` epochs, evaluates on the validation split.

    Whenever the validation recall summed over all k improves, the test split
    is evaluated as well, and that result is reported as the best.

    Raises:
        Exception: if the accumulated epoch loss becomes NaN.
    """
    # Local import: user_content may be a scipy sparse matrix or a dense array.
    import scipy.sparse as sp

    data_name = args.data
    model_select = args.model_select
    rank_out = args.rank
    data_batch_size = 1024
    dropout = args.dropout
    recall_at = [20, 50, 100]
    eval_batch_size = 5000  # the batch size when test
    eval_every = args.eval_every
    num_epoch = 100
    neg = args.neg

    _lr = args.lr
    _decay_lr_every = 2
    _lr_decay = 0.9

    dat = load_data(data_name)
    u_pref = dat['u_pref']
    v_pref = dat['v_pref']
    test_eval = dat['test_eval']
    vali_eval = dat['vali_eval']
    user_content = dat['user_content']
    user_list = dat['user_list']
    item_list = dat['item_list']
    item_warm = np.unique(item_list)
    timer = utils.timer(name='main').tic()

    # prep eval
    timer.tic()
    test_eval.init_tf(u_pref,
                      v_pref,
                      user_content,
                      None,
                      eval_batch_size,
                      cold_user=True)  # init data for evaluation
    vali_eval.init_tf(u_pref,
                      v_pref,
                      user_content,
                      None,
                      eval_batch_size,
                      cold_user=True)  # init data for evaluation
    timer.toc('initialized eval data').tic()

    heater = model.Heater(latent_rank_in=u_pref.shape[1],
                          user_content_rank=user_content.shape[1],
                          item_content_rank=0,
                          model_select=model_select,
                          rank_out=rank_out,
                          reg=args.reg,
                          alpha=args.alpha,
                          dim=args.dim)
    heater.build_model()
    heater.build_predictor(recall_at)

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        timer.toc('initialized tf')

        n_metrics = len(recall_at)
        best_epoch = 0
        best_recall = np.zeros(n_metrics)  # val
        best_test_recall = np.zeros(n_metrics)  # test
        # Test metrics at the best validation epoch. They stay zero until the
        # first improvement, so the per-epoch report never references an
        # undefined name.
        test_recall = np.zeros(n_metrics)
        test_precision = np.zeros(n_metrics)
        test_ndcg = np.zeros(n_metrics)
        for epoch in range(num_epoch):
            user_array, item_array, target_array = utils.negative_sampling(
                user_list, item_list, neg, item_warm)
            random_idx = np.random.permutation(user_array.shape[0])
            n_targets = len(random_idx)
            data_batch = [(n, min(n + data_batch_size, n_targets))
                          for n in range(0, n_targets, data_batch_size)]
            # divisor for the per-batch loss averages; avoids 0-division
            n_batches = max(len(data_batch), 1)
            loss_epoch = 0.
            reg_loss_epoch = 0.
            diff_loss_epoch = 0.
            rec_loss_epoch = 0.
            for (start, stop) in data_batch:

                batch_idx = random_idx[start:stop]
                batch_users = user_array[batch_idx]
                batch_items = item_array[batch_idx]
                batch_targets = target_array[batch_idx]

                # content: densify only when the user features are sparse
                user_content_batch = user_content[batch_users, :]
                if sp.issparse(user_content_batch):
                    user_content_batch = user_content_batch.todense()
                # dropout: mark n_to_drop distinct rows of this batch with 1
                # in dropout_indicator
                if dropout != 0:
                    n_to_drop = int(np.floor(
                        dropout *
                        len(batch_idx)))  # number of u-i pairs to be dropped
                    zero_index = np.random.choice(np.arange(len(batch_idx)),
                                                  n_to_drop,
                                                  replace=False)
                else:
                    zero_index = np.array([])
                dropout_indicator = np.zeros_like(batch_targets).reshape(
                    (-1, 1))
                if len(zero_index) > 0:
                    dropout_indicator[zero_index] = 1

                _, _, loss_out, rec_loss_out, reg_loss_out, diff_loss_out = sess.run(
                    [
                        heater.preds, heater.optimizer, heater.loss,
                        heater.rec_loss, heater.reg_loss, heater.diff_loss
                    ],
                    feed_dict={
                        heater.Uin: u_pref[batch_users, :],
                        heater.Vin: v_pref[batch_items, :],
                        heater.Ucontent: user_content_batch,
                        heater.dropout_user_indicator: dropout_indicator,
                        heater.target: batch_targets,
                        heater.lr_placeholder: _lr,
                        heater.is_training: True
                    })
                loss_epoch += loss_out
                rec_loss_epoch += rec_loss_out
                reg_loss_epoch += reg_loss_out
                diff_loss_epoch += diff_loss_out
                if np.isnan(loss_epoch):
                    raise Exception('f is nan')

            if (epoch + 1) % _decay_lr_every == 0:
                _lr = _lr_decay * _lr
                print('decayed lr:' + str(_lr))

            if epoch % eval_every == 0:
                recall, precision, ndcg = utils.batch_eval_recall(
                    sess,
                    heater.eval_preds_cold,
                    eval_feed_dict=heater.get_eval_dict,
                    recall_k=recall_at,
                    eval_data=vali_eval)

                # checkpoint: only a fresh validation result can beat the best,
                # and the test split is evaluated only when it does
                if np.sum(recall) > np.sum(best_recall):
                    best_recall = recall
                    test_recall, test_precision, test_ndcg = utils.batch_eval_recall(
                        sess,
                        heater.eval_preds_cold,
                        eval_feed_dict=heater.get_eval_dict,
                        recall_k=recall_at,
                        eval_data=test_eval)
                    best_test_recall = test_recall
                    best_epoch = epoch

            # Print results at every epoch. The validation metrics are only
            # refreshed on eval epochs; the test metrics are those at the
            # best validation epoch.
            timer.toc(
                '%d loss=%.4f reg_loss=%.4f diff_loss=%.4f rec_loss=%.4f' %
                (epoch, loss_epoch / n_batches, reg_loss_epoch / n_batches,
                 diff_loss_epoch / n_batches,
                 rec_loss_epoch / n_batches)).tic()
            print('\t\t\t' + '\t '.join([
                ('@' + str(i)).ljust(6) for i in recall_at
            ]))  # ljust: padding to fixed len
            print('Current recall\t\t%s' %
                  (' '.join(['%.6f' % i for i in recall])))
            print('Current precision\t%s' %
                  (' '.join(['%.6f' % i for i in precision])))
            print('Current ndcg\t\t%s' % (' '.join(['%.6f' % i
                                                    for i in ndcg])))
            print('Current test recall\t%s' %
                  (' '.join(['%.6f' % i for i in test_recall])))
            print('Current test precision\t%s' %
                  (' '.join(['%.6f' % i for i in test_precision])))
            print('Current test ndcg\t%s' %
                  (' '.join(['%.6f' % i for i in test_ndcg])))
            print('best[%d] vali recall:\t%s' %
                  (best_epoch, ' '.join(['%.6f' % i for i in best_recall])))
            print('best[%d] test recall:\t%s' %
                  (best_epoch, ' '.join(['%.6f' % i
                                         for i in best_test_recall])))