def eval(kg_val, model, n_samples, device="cuda"):
    """Estimate raw mean rank and Hits@10 of a translational KG embedding model.

    ``n_samples`` validation triplets are drawn in random order from
    ``kg_val``. For each triplet the head is predicted, scoring every entity
    ``e`` by ``||e + r - t||``. The tail is then predicted, scoring every
    entity ``e`` by ``||h + r - e||``. A smaller distance means a more
    plausible entity. The 0-based position of the true entity in the
    ascending-distance ordering is its rank; ranks below 10 count as hits.

    Args:
        kg_val: validation dataset exposing ``n_ent``. Each item is indexed
            as ``(head, tail, relation)``.
        model: callable returning ``(entity_embeddings, relation_embeddings)``
            for a batch of triplets.
        n_samples: number of validation triplets to evaluate. It is clamped
            to the dataset size.
        device: device the generated triplets are moved to before the
            forward pass.

    Returns:
        ``(mean_rank, hits_10)``, each averaged over head and tail
        prediction. Both values are also printed.

    Raises:
        ValueError: if there is no triplet to evaluate.
    """
    import itertools

    n_ent = kg_val.n_ent
    dataloader = DataLoader(kg_val, 1, shuffle=True)
    # Only the first n_samples shuffled triplets are used, so take just those
    # instead of materialising the whole validation set.
    samples = list(itertools.islice(dataloader, n_samples))
    n = len(samples)
    if n == 0:
        raise ValueError("eval() needs at least one validation triplet")

    def rank_of(distances, target):
        # Ascending order: the smallest translation distance is the best
        # candidate, so it must come first.
        order = torch.argsort(distances).cpu().tolist()
        return order.index(target)

    model.eval()

    head_rank_sum = tail_rank_sum = 0
    head_hits_10 = tail_hits_10 = 0

    with torch.no_grad():
        for sample in samples:
            head, tail, rel = int(sample[0]), int(sample[1]), int(sample[2])

            # Head prediction: score every entity e as ||e + r - t||.
            triplets_h = generate_eval_triplets(sample, "head", n_ent)
            triplets_h, _ = negative_sampling(triplets_h, n_ent, 0)
            ent_emb, rel_emb = model(triplets_h.to(device), eval_=True)
            # Broadcasting (d,) against (num_entities, d) replaces the
            # repeat/view that hard-coded an embedding size of 100.
            dist_h = torch.norm(ent_emb + (rel_emb[rel] - ent_emb[tail]), dim=1)
            rank = rank_of(dist_h, head)
            head_rank_sum += rank
            head_hits_10 += int(rank < 10)

            # Tail prediction: score every entity e as ||h + r - e||.
            triplets_t = generate_eval_triplets(sample, "tail", n_ent)
            triplets_t, _ = negative_sampling(triplets_t, n_ent, 0)
            # Same evaluation-mode forward pass as the head branch.
            ent_emb, rel_emb = model(triplets_t.to(device), eval_=True)
            dist_t = torch.norm((ent_emb[head] + rel_emb[rel]) - ent_emb, dim=1)
            rank = rank_of(dist_t, tail)
            tail_rank_sum += rank
            tail_hits_10 += int(rank < 10)

    mean_rank = (head_rank_sum + tail_rank_sum) / (2 * n)
    hits_10 = (head_hits_10 + tail_hits_10) / (2 * n)

    print("mean rank: {}".format(mean_rank))
    print("hits@10: {}".format(hits_10))
    return mean_rank, hits_10
def _One_Hogwild_pass(q_in, q_out, frequence, K, Shape, *args):
    '''
    :param q: queue
    :param *args: arguments for the Hogwild pass
    '''

    M_in_update = np.zeros(shape=Shape, dtype=np.float32)
    M_out_update = np.zeros(shape=Shape, dtype=np.float32)
    STOP = False

    while True:
        if q_in.empty()==False:
            context_word, context_words, target_word, M_in, M_out, STOP = q_in.get()

            # stop condition
            if STOP:
                break
            # Sampling negative samples
            target_words = [target_word] + negative_sampling(frequence, context_words, K)

            # Executing hogwild pass
            M_in_update, M_out_update, loss = _One_Hogwild_pass_jitted(context_word, target_words,
                                                                       K, M_in, M_out, M_in_update, M_out_update, *args)

            q_in.task_done()
            q_out.put([M_in_update, M_out_update, loss, context_word, target_words])

        if STOP:
            break
# Example #3
# 0
    def iter_epoches(self, sess, epoch, user_pro_dict, pro_user_dict, user_df,
                     item_df):
        '''Run one training epoch over every batch of ``self.batch_loader``.

        sess: current TensorFlow session
        epoch: current epoch index, used for learning-rate decay and logging
        user_pro_dict: user -> product mapping passed to negative sampling
        pro_user_dict: product -> user mapping passed to negative sampling
        user_df, item_df: user and item feature tables passed to
            ``utils.feature_map``

        Raises ValueError if ``self.mode`` is neither ``'u_p'`` nor ``'p_u'``.
        '''
        # Fail fast: with any other mode no feed dict would be built.
        if self.mode not in ('u_p', 'p_u'):
            raise ValueError(
                "unsupported mode %r (expected 'u_p' or 'p_u')" % (self.mode,))

        # learning rate decay: lr = learning_rate * decay_rate ** epoch
        # NOTE(review): tf.assign adds a new op to the graph on every call;
        # consider a pre-built assign op fed from a placeholder.
        sess.run(
            tf.assign(self.model.lr,
                      self.args.learning_rate * (self.args.decay_rate**epoch)))
        # reset the batch pointer to 0
        self.batch_loader.reset_batch_pointer()
        for _batch in range(self.batch_loader.num_batches):
            train_X, _labels = self.batch_loader.next_batch()
            train_X['eval'] = 1
            # features of the positive samples
            train_X_fea = utils.feature_map(train_X, user_df, item_df)
            # Negative samples are drawn in the direction given by the mode:
            # 'u_p' (users to items) or 'p_u' (items to users).
            train_neg_X = utils.negative_sampling(train_X, user_pro_dict,
                                                  pro_user_dict, 1, self.mode)
            # features of the negative samples
            train_neg_X_fea = utils.feature_map(train_neg_X, user_df, item_df)

            if self.mode == 'u_p':
                # u_p: users to items
                # NOTE(review): unlike 'p_u', no _keep_prob is fed here;
                # confirm the u_p graph does not require it.
                feed = {
                    self.model._users: train_X['uid'],
                    self.model._items_pos: train_X['pid'],
                    self.model._items_neg: train_neg_X['pid'],
                    self.model._pos_fea: train_X_fea.iloc[:, 3:],
                    self.model._neg_fea: train_neg_X_fea.iloc[:, 3:],
                }
            else:
                # p_u: items to users
                feed = {
                    self.model._items: train_X['pid'],
                    self.model._users_pos: train_X['uid'],
                    self.model._users_neg: train_neg_X['uid'],
                    self.model._pos_fea: train_X_fea.iloc[:, 3:],
                    self.model._neg_fea: train_neg_X_fea.iloc[:, 3:],
                    self.model._keep_prob: self.args.keep_prob,
                }

            _pred, _, loss = sess.run(
                [self.model._pred, self.model._optimizer, self.model._loss],
                feed)
            print("epoches: %3d, train loss: %2.6f" % (epoch, loss))
# Example #4
# 0
    def run(self):
        """Train ``self.model`` for ``self.n_epochs`` epochs.

        Each batch from ``self.dataloader_train`` is processed as follows:
        stacked into one tensor, passed through ``negative_sampling`` with
        ``self.negative_rate``, moved to ``self.device``, and trained on via
        ``train_one_step`` in ``"tail"`` mode. The progress bar shows the
        summed (not averaged) loss of each epoch.
        """
        self.model.train()
        progress = tqdm(range(self.n_epochs))
        for epoch in progress:
            epoch_loss = 0
            for batch in self.dataloader_train:
                sampled, _ = negative_sampling(torch.stack(batch), self.n_ent,
                                               self.negative_rate)
                epoch_loss += self.train_one_step(sampled.to(self.device),
                                                  "tail")
            progress.set_description("Epoch %d | loss: %f" % (epoch, epoch_loss))
def _load_example(xs, ys, index):
    """Return ``(xs[index], ys[index], negatives)`` for one sample index.

    ``negatives`` is ``negative_sampling(ys[index], words, prob)``, using the
    module-level ``words`` and ``prob``.
    """
    return xs[index], ys[index], negative_sampling(ys[index], words, prob)


def load_valid_func(index):
    """Load validation sample ``index`` together with its negative samples."""
    return _load_example(x_valid, y_valid, index)


def load_train_func(index):
    """Load training sample ``index`` together with its negative samples."""
    return _load_example(x_train, y_train, index)
# Example #7
# 0
def main():
    """Train a Heater model for cold-start users and report ranking metrics.

    Hyper-parameters not hard-coded below come from the module-level ``args``.
    Each epoch does the following:

    * draws fresh training samples with ``utils.negative_sampling``;
    * trains on shuffled mini-batches;
    * decays the learning rate every ``_decay_lr_every`` epochs.

    Every ``eval_every`` epochs, recall/precision/NDCG at ``recall_at`` are
    computed on the validation set. When the summed validation recall
    improves, the test set is evaluated and kept as the best result.
    """
    def fmt(values):
        # Render a metric vector as space-separated fixed-precision numbers.
        return ' '.join(['%.6f' % i for i in values])

    data_name = args.data
    model_select = args.model_select
    rank_out = args.rank
    data_batch_size = 1024
    dropout = args.dropout
    recall_at = [20, 50, 100]
    eval_batch_size = 5000  # the batch size when test
    eval_every = args.eval_every
    num_epoch = 100
    neg = args.neg

    _lr = args.lr
    _decay_lr_every = 2
    _lr_decay = 0.9

    dat = load_data(data_name)
    u_pref = dat['u_pref']
    v_pref = dat['v_pref']
    test_eval = dat['test_eval']
    vali_eval = dat['vali_eval']
    user_content = dat['user_content']
    user_list = dat['user_list']
    item_list = dat['item_list']
    item_warm = np.unique(item_list)
    timer = utils.timer(name='main').tic()

    # prep eval: init data for evaluation (test set first, then validation)
    timer.tic()
    for eval_data in (test_eval, vali_eval):
        eval_data.init_tf(u_pref,
                          v_pref,
                          user_content,
                          None,
                          eval_batch_size,
                          cold_user=True)
    timer.toc('initialized eval data').tic()

    heater = model.Heater(latent_rank_in=u_pref.shape[1],
                          user_content_rank=user_content.shape[1],
                          item_content_rank=0,
                          model_select=model_select,
                          rank_out=rank_out,
                          reg=args.reg,
                          alpha=args.alpha,
                          dim=args.dim)
    heater.build_model()
    heater.build_predictor(recall_at)

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        timer.toc('initialized tf')

        n_k = len(recall_at)
        best_epoch = 0
        best_recall = np.zeros(n_k)  # val
        best_test_recall = np.zeros(n_k)  # test
        # Zero placeholders so the per-epoch report works even before the
        # first improvement of the validation recall.
        recall = precision = ndcg = np.zeros(n_k)
        test_recall = test_precision = test_ndcg = np.zeros(n_k)
        for epoch in range(num_epoch):
            user_array, item_array, target_array = utils.negative_sampling(
                user_list, item_list, neg, item_warm)
            random_idx = np.random.permutation(user_array.shape[0])
            n_targets = len(random_idx)
            data_batch = [(n, min(n + data_batch_size, n_targets))
                          for n in range(0, n_targets, data_batch_size)]
            loss_epoch = 0.
            reg_loss_epoch = 0.
            diff_loss_epoch = 0.
            rec_loss_epoch = 0.
            for (start, stop) in data_batch:

                batch_idx = random_idx[start:stop]
                batch_users = user_array[batch_idx]
                batch_items = item_array[batch_idx]
                batch_targets = target_array[batch_idx]

                # content
                user_content_batch = user_content[batch_users, :].todense()
                # dropout: mark a random subset of the batch's u-i pairs
                if dropout != 0:
                    n_to_drop = int(np.floor(
                        dropout *
                        len(batch_idx)))  # number of u-i pairs to be dropped
                    zero_index = np.random.choice(np.arange(len(batch_idx)),
                                                  n_to_drop,
                                                  replace=False)
                else:
                    zero_index = np.array([])
                dropout_indicator = np.zeros_like(batch_targets).reshape(
                    (-1, 1))
                if len(zero_index) > 0:
                    dropout_indicator[zero_index] = 1

                _, _, loss_out, rec_loss_out, reg_loss_out, diff_loss_out = sess.run(
                    [
                        heater.preds, heater.optimizer, heater.loss,
                        heater.rec_loss, heater.reg_loss, heater.diff_loss
                    ],
                    feed_dict={
                        heater.Uin: u_pref[batch_users, :],
                        heater.Vin: v_pref[batch_items, :],
                        heater.Ucontent: user_content_batch,
                        heater.dropout_user_indicator: dropout_indicator,
                        heater.target: batch_targets,
                        heater.lr_placeholder: _lr,
                        heater.is_training: True
                    })
                loss_epoch += loss_out
                # Accumulate into the epoch total (previously this line
                # doubled rec_loss_out, so rec_loss_epoch stayed 0).
                rec_loss_epoch += rec_loss_out
                reg_loss_epoch += reg_loss_out
                diff_loss_epoch += diff_loss_out
                if np.isnan(loss_epoch):
                    raise Exception('f is nan')

            if (epoch + 1) % _decay_lr_every == 0:
                _lr = _lr_decay * _lr
                print('decayed lr:' + str(_lr))

            if epoch % eval_every == 0:
                recall, precision, ndcg = utils.batch_eval_recall(
                    sess,
                    heater.eval_preds_cold,
                    eval_feed_dict=heater.get_eval_dict,
                    recall_k=recall_at,
                    eval_data=vali_eval)

                # checkpoint: only meaningful right after a fresh evaluation
                if np.sum(recall) > np.sum(best_recall):
                    best_recall = recall
                    test_recall, test_precision, test_ndcg = utils.batch_eval_recall(
                        sess,
                        heater.eval_preds_cold,
                        eval_feed_dict=heater.get_eval_dict,
                        recall_k=recall_at,
                        eval_data=test_eval)
                    best_test_recall = test_recall
                    best_epoch = epoch

            # print results at every epoch (metrics from the latest evaluation)
            timer.toc(
                '%d loss=%.4f reg_loss=%.4f diff_loss=%.4f rec_loss=%.4f' %
                (epoch, loss_epoch / len(data_batch), reg_loss_epoch /
                 len(data_batch), diff_loss_epoch / len(data_batch),
                 rec_loss_epoch / len(data_batch))).tic()
            print('\t\t\t' + '\t '.join([
                ('@' + str(i)).ljust(6) for i in recall_at
            ]))  # ljust: padding to fixed len
            print('Current recall\t\t%s' % fmt(recall))
            print('Current precision\t%s' % fmt(precision))
            print('Current ndcg\t\t%s' % fmt(ndcg))
            print('Current test recall\t%s' % fmt(test_recall))
            print('Current test precision\t%s' % fmt(test_precision))
            print('Current test ndcg\t%s' % fmt(test_ndcg))
            print('best[%d] vali recall:\t%s' % (best_epoch, fmt(best_recall)))
            print('best[%d] test recall:\t%s' %
                  (best_epoch, fmt(best_test_recall)))
# Example #8
# 0
    # original result
    # org_warm_test = utils.batch_eval(sess, heater.eval_preds_warm,
    #                                  eval_feed_dict=heater.get_eval_dict,
    #                                  eval_data=warm_test_eval,
    #                                  U_pref=u_pref, V_pref=v_pref,
    #                                  excluded_dict=dat['pos_nb'],
    #                                  V_content=item_content,
    #                                  metric=dat['metric']['warm_test'],
    #                                  )

    best_epoch = 0
    patience = 0
    val_auc, best_val_auc = 0., 0.
    for epoch in range(num_epoch):
        user_array, item_array, target_array = utils.negative_sampling(user_list, item_list, neg, item_warm)
        random_idx = np.random.permutation(user_array.shape[0])  # 生成一个打乱的 range 序列作为下标
        data_batch = [(n, min(n + data_batch_size, len(random_idx))) for n in
                      range(0, len(random_idx), data_batch_size)]
        loss_epoch = 0.
        reg_loss_epoch = 0.
        diff_loss_epoch = 0.
        rec_loss_epoch = 0.
        for (start, stop) in data_batch:

            batch_idx = random_idx[start:stop]
            batch_users = user_array[batch_idx]
            batch_items = item_array[batch_idx]
            batch_targets = target_array[batch_idx]

            # content
# Example #9
# 0
def main():
    """Train a Heater model for cold-start users and report AUC / HR / NDCG.

    Hyper-parameters not hard-coded below come from the module-level ``args``.

    Before training, the warm test set is scored with the untrained warm
    predictor ("origin warm test"). After every epoch the validation AUC is
    computed. When it improves, the warm and cold test sets are re-evaluated
    and kept as the best result. Training stops early once the validation AUC
    has not improved for more than 10 epochs.
    """
    def fmt(values):
        # Render a metric vector as space-separated fixed-precision numbers.
        return ' '.join(['%.6f' % i for i in values])

    data_name = args.data
    model_select = args.model_select
    rank_out = args.rank
    data_batch_size = 1024
    dropout = args.dropout
    eval_batch_size = 5000  # the batch size when test
    num_epoch = 100
    neg = args.neg
    _lr = args.lr
    _decay_lr_every = 2
    _lr_decay = 0.9

    dat = load_data(data_name)
    u_pref = dat['u_pref']  # all user pre embedding
    v_pref = dat['v_pref']  # all item pre embedding
    user_content = dat['user_content']  # all user content matrix (rows indexed by user id)
    test_eval = dat['test_eval']  # EvalData
    val_eval = dat['val_eval']  # EvalData
    warm_test_eval = dat['warm_test']  # EvalData
    user_list = dat['user_list']  # users of train interactions
    item_list = dat['item_list']  # items of train interactions
    item_warm = np.unique(item_list)  # train item set
    timer = utils.timer(name='main')

    # prep eval: init data for evaluation (cold test, validation, warm test)
    timer.tic()
    for eval_data in (test_eval, val_eval, warm_test_eval):
        eval_data.init_tf(u_pref, v_pref, user_content, None, eval_batch_size, cold_user=True)
    timer.toc('initialized eval data').tic()

    heater = model.Heater(latent_rank_in=u_pref.shape[1],
                          user_content_rank=user_content.shape[1],
                          item_content_rank=0,
                          model_select=model_select,
                          rank_out=rank_out, reg=args.reg, alpha=args.alpha, dim=args.dim)
    heater.build_model()
    heater.build_predictor()

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        timer.toc('initialized tf')

        # original result: warm predictor on the warm test set before training
        org_warm_test = utils.batch_eval(sess, heater.eval_preds_warm,
                                         eval_feed_dict=heater.get_eval_dict,
                                         eval_data=warm_test_eval,
                                         metric=dat['metric']['warm_test'],
                                         warm=True)

        best_epoch = 0
        patience = 0
        val_auc, best_val_auc = 0., 0.
        best_warm_test = np.zeros(3)
        best_cold_test = np.zeros(3)
        for epoch in range(num_epoch):
            user_array, item_array, target_array = utils.negative_sampling(user_list, item_list, neg, item_warm)
            random_idx = np.random.permutation(user_array.shape[0])
            n_targets = len(random_idx)
            data_batch = [(n, min(n + data_batch_size, n_targets)) for n in
                          range(0, n_targets, data_batch_size)]
            loss_epoch = 0.
            reg_loss_epoch = 0.
            diff_loss_epoch = 0.
            rec_loss_epoch = 0.
            for (start, stop) in data_batch:

                batch_idx = random_idx[start:stop]
                batch_users = user_array[batch_idx]
                batch_items = item_array[batch_idx]
                batch_targets = target_array[batch_idx]

                # dropout: pick a random subset of the batch's u-i pairs
                if dropout != 0:
                    n_to_drop = int(np.floor(dropout * len(batch_idx)))  # number of u-i pairs to be dropped
                    zero_index = np.random.choice(np.arange(len(batch_idx)), n_to_drop, replace=False)
                else:
                    zero_index = np.array([])

                user_content_batch = user_content[batch_users, :].todense()
                dropout_indicator = np.zeros_like(batch_targets).reshape((-1, 1))
                if len(zero_index) > 0:
                    dropout_indicator[zero_index] = 1

                _, _, loss_out, rec_loss_out, reg_loss_out, diff_loss_out = sess.run(
                    [heater.preds, heater.optimizer, heater.loss,
                     heater.rec_loss, heater.reg_loss, heater.diff_loss],
                    feed_dict={
                        heater.Uin: u_pref[batch_users, :],
                        heater.Vin: v_pref[batch_items, :],
                        heater.Ucontent: user_content_batch,
                        heater.dropout_user_indicator: dropout_indicator,
                        heater.target: batch_targets,
                        heater.lr_placeholder: _lr,
                        heater.is_training: True
                    }
                )
                loss_epoch += loss_out
                rec_loss_epoch += rec_loss_out
                reg_loss_epoch += reg_loss_out
                diff_loss_epoch += diff_loss_out
                if np.isnan(loss_epoch):
                    raise Exception('f is nan')

            if (epoch + 1) % _decay_lr_every == 0:
                _lr = _lr_decay * _lr
                print('decayed lr:' + str(_lr))

            val_auc = utils.batch_eval_auc(sess, heater.eval_preds_cold,
                                           eval_feed_dict=heater.get_eval_dict,
                                           eval_data=val_eval)

            # checkpoint
            if val_auc > best_val_auc:
                patience = 0
                best_val_auc = val_auc
                # NOTE(review): the warm test set is scored with the *cold*
                # predictor here (the baseline above used eval_preds_warm);
                # confirm this is intended.
                best_warm_test = utils.batch_eval(sess, heater.eval_preds_cold,
                                                  eval_feed_dict=heater.get_eval_dict,
                                                  eval_data=warm_test_eval,
                                                  metric=dat['metric']['warm_test'],
                                                  warm=True)
                best_cold_test = utils.batch_eval(sess, heater.eval_preds_cold,
                                                  eval_feed_dict=heater.get_eval_dict,
                                                  eval_data=test_eval,
                                                  metric=dat['metric']['cold_test'])
                best_epoch = epoch

            # print results at every epoch
            timer.toc('%d loss=%.4f reg_loss=%.4f diff_loss=%.4f rec_loss=%.4f' % (
                epoch, loss_epoch / len(data_batch), reg_loss_epoch / len(data_batch),
                diff_loss_epoch / len(data_batch), rec_loss_epoch / len(data_batch)
            )).tic()
            print('Current val auc:%.4f\tbest:%.4f' % (val_auc, best_val_auc))
            print('\t\t\t\t\t' + '\t '.join([str(i).ljust(6) for i in ['auc', 'hr', 'ndcg']]))  # padding to fixed len
            print('origin warm test:\t%s' % fmt(org_warm_test))
            print('best[%d] warm test:\t%s' % (best_epoch, fmt(best_warm_test)))
            print('best[%d] cold test:\t%s' % (best_epoch, fmt(best_cold_test)))

            # early stop: patience counts epochs since the last improvement
            patience += 1
            if patience > 10:
                print(f"Early stop at epoch {epoch}")
                break
# Example #10
# 0
def main():
    """Train a Heater model for cold-start items with checkpointing and early stop.

    Hyper-parameters not hard-coded below come from the module-level ``args``.

    After every epoch, ``utils.batch_eval(..., val=True)`` scores the
    validation set (stored as ``val_auc``). Each improvement saves a
    checkpoint, and training stops once no improvement has happened for more
    than 10 epochs. The best checkpoint is then restored and evaluated on the
    warm and cold test sets.
    """
    def fmt(values):
        # Render a metric vector as space-separated fixed-precision numbers.
        return ' '.join(['%.6f' % i for i in values])

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    data_name = args.data
    data_batch_size = 1024  # train batch size
    dropout = args.dropout
    num_epoch = 1000
    neg = args.neg  # negative sampling rate
    _lr = args.lr
    _decay_lr_every = 10
    _lr_decay = 0.8

    dat = load_data(data_name)
    u_pref = dat['u_pref']  # all user pre embedding
    v_pref = dat['v_pref']  # all item pre embedding
    item_content = dat['item_content']  # all item context matrix
    item_fake_pref = dat['item_fake_pref']
    test_eval = dat['cold_eval']  # EvalData
    val_eval = dat['val_eval']  # EvalData
    warm_test_eval = dat['warm_eval']  # EvalData
    user_list = dat['user_list']  # users of train interactions
    item_list = dat['item_list']  # items of train interactions
    item_warm = np.unique(item_list)  # train item set

    timer = utils.timer(name='main').tic()
    # build model
    heater = model.Heater(latent_rank_in=u_pref.shape[-1],
                          user_content_rank=0,
                          item_content_rank=item_content.shape[1],
                          args=args)
    heater.build_model()
    heater.build_predictor()

    saver = tf.train.Saver()
    save_path = './model_save/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    # Same checkpoint path as before: save_path + data name + warm model name.
    ckpt_path = save_path + args.data + args.warm_model
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        timer.toc('initialized tf').tic()

        best_epoch = 0
        patience = 0
        val_auc, best_val_auc = 0., 0.
        for epoch in range(num_epoch):
            user_array, item_array, target_array = utils.negative_sampling(
                user_list, item_list, neg, item_warm)
            # shuffled range of indices into this epoch's training samples
            random_idx = np.random.permutation(user_array.shape[0])
            data_batch = [(n, min(n + data_batch_size, len(random_idx)))
                          for n in range(0, len(random_idx), data_batch_size)]
            loss_epoch = 0.
            reg_loss_epoch = 0.
            diff_loss_epoch = 0.
            rec_loss_epoch = 0.
            for (start, stop) in data_batch:

                batch_idx = random_idx[start:stop]
                batch_users = user_array[batch_idx]
                batch_items = item_array[batch_idx]
                batch_targets = target_array[batch_idx]

                # content
                item_content_batch = item_content[batch_items, :].todense()
                # dropout: used in randomized training
                # indicator's target is the CF pretrain rep
                # set the dropped rows' position in indicator to be 1
                if dropout != 0:
                    n_to_drop = int(np.floor(
                        dropout *
                        len(batch_idx)))  # number of u-i pairs to be dropped
                    zero_index = np.random.choice(np.arange(len(batch_idx)),
                                                  n_to_drop,
                                                  replace=False)
                else:
                    zero_index = np.array([])
                dropout_indicator = np.zeros_like(batch_targets).reshape(
                    (-1, 1))
                # Guard: an empty np.array([]) is float64, and numpy rejects
                # float arrays as indices, so dropout == 0 used to crash here.
                if len(zero_index) > 0:
                    dropout_indicator[zero_index] = 1

                _, _, loss_out, rec_loss_out, reg_loss_out, diff_loss_out = sess.run(
                    [
                        heater.preds, heater.optimizer, heater.loss,
                        heater.rec_loss, heater.reg_loss, heater.diff_loss
                    ],
                    feed_dict={
                        heater.Uin: u_pref[batch_users, :],
                        heater.Vin: v_pref[batch_items, :],
                        heater.Vcontent: item_content_batch,
                        heater.fake_v_nei_1: item_fake_pref[batch_items, 1, :],
                        heater.fake_v_nei_2: item_fake_pref[batch_items, 2, :],
                        heater.dropout_item_indicator: dropout_indicator,
                        heater.target: batch_targets,
                        heater.lr_placeholder: _lr,
                        heater.is_training: True
                    })
                loss_epoch += loss_out
                rec_loss_epoch += rec_loss_out
                reg_loss_epoch += reg_loss_out
                diff_loss_epoch += diff_loss_out
                if np.isnan(loss_epoch):
                    raise Exception('f is nan')

            timer.toc(
                '%d loss=%.4f reg_loss=%.4f diff_loss=%.4f rec_loss=%.4f' %
                (epoch, loss_epoch / len(data_batch), reg_loss_epoch /
                 len(data_batch), diff_loss_epoch / len(data_batch),
                 rec_loss_epoch / len(data_batch))).tic()
            if (epoch + 1) % _decay_lr_every == 0:
                _lr = _lr_decay * _lr
                print('decayed lr:' + str(_lr))

            # eval on val
            val_auc = utils.batch_eval(sess,
                                       heater.eval_preds_cold,
                                       eval_feed_dict=heater.get_eval_dict,
                                       eval_data=val_eval,
                                       U_pref=u_pref,
                                       V_pref=v_pref,
                                       excluded_dict=dat['pos_nb'],
                                       V_content=item_content,
                                       v_fake_pref=item_fake_pref,
                                       val=True)
            # on a better validation score, checkpoint the model; the test
            # sets are only evaluated once, after training, on the restored
            # best checkpoint
            if val_auc > best_val_auc:
                saver.save(sess, ckpt_path)
                patience = 0
                best_val_auc = val_auc
                best_epoch = epoch
            # print val results at every epoch
            timer.toc('[%d/10] Current val auc:%.4f\tbest:%.4f' %
                      (patience, val_auc, best_val_auc)).tic()

            # early stop: patience counts epochs since the last improvement
            patience += 1
            if patience > 10:
                print(f"Early stop at epoch {epoch}")
                break

        saver.restore(sess, ckpt_path)
        best_warm_test = utils.batch_eval(sess,
                                          heater.eval_preds_cold,
                                          eval_feed_dict=heater.get_eval_dict,
                                          eval_data=warm_test_eval,
                                          U_pref=u_pref,
                                          V_pref=v_pref,
                                          excluded_dict=dat['pos_nb'],
                                          V_content=item_content,
                                          v_fake_pref=item_fake_pref,
                                          metric=dat['metric']['warm_test'],
                                          warm=True)
        best_cold_test = utils.batch_eval(sess,
                                          heater.eval_preds_cold,
                                          eval_feed_dict=heater.get_eval_dict,
                                          eval_data=test_eval,
                                          U_pref=u_pref,
                                          V_pref=v_pref,
                                          excluded_dict=dat['pos_nb'],
                                          V_content=item_content,
                                          v_fake_pref=item_fake_pref,
                                          metric=dat['metric']['cold_test'])
        timer.toc('Test').tic()
        print('\t\t\t\t\t' + '\t '.join(
            [str(i).ljust(6)
             for i in ['auc', 'hr', 'ndcg']]))  # padding to fixed len
        print('best[%d] warm test:\t%s' % (best_epoch, fmt(best_warm_test)))
        print('best[%d] cold test:\t%s' % (best_epoch, fmt(best_cold_test)))