コード例 #1
0
def do_validate(epoch):
    """Run one validation pass of ``model_search`` and log SRS ranking metrics.

    For every validation batch the model's scores at the last sequence
    position are restricted (via ``gather``) to that row's candidate item ids.
    Each row is then submitted to the metric with ground-truth index 0, i.e.
    the first candidate column is treated as the target (presumably the
    loader places the positive item first -- confirm against the loader).

    Args:
        epoch: Epoch number. When it is not ``None``, a checkpoint holding
            ``create_state_dict()`` and ``epoch`` is written to
            ``folder / "checkpoint.pth"``.

    Returns:
        ``(mrr@5, hit@5)``, each rounded to 4 decimal places.
    """
    logging.info("[VALIDATE]")
    model_search.eval()

    metric = SRSMetric(k_list=args.aux_eval_ks)
    metric.setup_and_clean()

    with torch.no_grad():
        for raw_batch in iter(val_loader):
            seqs, candidates = [t.to(device_student) for t in raw_batch]

            last_step = model_search.predict(seqs)[:, -1, :]  # [B, num_item]
            picked = last_step.gather(1, candidates)  # [B, 1+num_sample]

            # Ground truth is column 0 of the gathered candidate scores.
            metric.submit(picked.cpu(), [[0]] * len(picked))

    metric.calc()
    metric.output_to_logger()

    if epoch is not None:
        torch.save(
            {"state_dict": create_state_dict(), "epoch": epoch},
            folder.joinpath("checkpoint.pth"),
        )

    return round(metric.mrr[5], 4), round(metric.hit[5], 4)
コード例 #2
0
ファイル: teacher_nin.py プロジェクト: Anonymous-2611/AdaRec
def do_validate(epoch):
    """Run one validation pass of ``model`` and log SRS ranking metrics.

    Two evaluation modes are selected by the global ``EVAL_NEG_SAMPLE``:

    * sampled: scores are gathered down to each row's candidate ids and the
      target is submitted as index 0 of those candidates;
    * full: all item scores are submitted and the ground truth is column 0
      of ``candidates`` (presumably the positive item id -- confirm).

    A checkpoint with ``create_state_dict()`` and ``epoch`` is always written
    to ``folder / "checkpoint.pth"``.

    Args:
        epoch: Epoch number stored in the checkpoint.

    Returns:
        ``(mrr@5, hit@5)``, each rounded to 4 decimal places.
    """
    logging.info("[VALIDATE]")
    model.eval()

    metric = SRSMetric(k_list=args.aux_eval_ks)
    metric.setup_and_clean()

    def _score_batch(logits, other):
        # Submit one batch of last-position scores to the metric.
        if EVAL_NEG_SAMPLE:
            sampled = logits.gather(1, other)  # [B, 1+num_neg_sample]
            metric.submit(sampled.cpu(), [[0]] * len(sampled))
            return
        # EVAL_ALL_SAMPLE: rank against every item.
        truth = np.array(other[:, 0:1].cpu())  # [B, 1]
        metric.submit(np.array(logits.cpu()), truth)

    with torch.no_grad():
        for raw_batch in iter(val_loader):
            seqs, candidates = [t.cuda() for t in raw_batch]
            last_step = model(seqs)[:, -1, :]  # [B, #item]
            _score_batch(last_step, candidates)

    metric.calc()
    metric.output_to_logger()

    torch.save(
        {"state_dict": create_state_dict(), "epoch": epoch},
        folder.joinpath("checkpoint.pth"),
    )

    return round(metric.mrr[5], 4), round(metric.hit[5], 4)
コード例 #3
0
ファイル: policy_rl_rec.py プロジェクト: Anonymous-1561/UAF
def evaluate():
    """Evaluate the target model on ``test_set`` using sampled negatives.

    Each test row is split into a context (all but the last column) and a
    positive target (the last column). ``n_neg`` random negatives per row are
    placed before the positive, so the positive sits at column ``n_neg``,
    which is the ground-truth index submitted to the metric. Actions
    produced by the policy model are collected and summarized afterwards.

    Only ``int(len(test_set) / batch_size)`` full batches are evaluated; any
    trailing partial batch is skipped (the feeds use a fixed ``batch_size``).

    Returns:
        MRR@5 as computed by ``SRSMetric``.
    """
    batch_size = configs["batch_size"]
    n_neg = configs["n_neg"]

    total_steps = int(test_set.shape[0] / batch_size)
    action_nums = len(configs["dilations"])

    meter = SRSMetric(k_list=[5, 20])
    meter.setup_and_clean()

    usage = []
    for step in range(total_steps):
        start = step * batch_size
        rows = test_set[start:start + batch_size, :]  # [B, L+1]

        context = rows[:, :-1]
        positives = rows[:, -1:]
        negatives = [
            random_negs(l=1, r=data_loader.target_nums, size=n_neg, pos=row[0])
            for row in positives
        ]
        # Columns: n_neg negatives followed by the single positive.
        target = np.concatenate([negatives, positives], 1)

        feeds = {
            source_model.input_source_test: context,
            policy_model.input: context,
            policy_model.method: np.array(1),
            policy_model.sample_action: np.ones((batch_size, action_nums)),
            target_model.input_test: target,
        }
        test_probs, action = sess.run(
            [target_model.test_probs, policy_model.test_action],
            feed_dict=feeds,
        )

        # The positive is the last candidate, i.e. index n_neg.
        meter.submit(test_probs, [[n_neg]] * batch_size)

        usage.extend(np.array(action).tolist())

    summary_block(usage, len(configs["dilations"]), "Test")

    meter.calc()
    meter.output_to_logger()

    return meter.mrr[5]
コード例 #4
0
def evaluate():
    """Evaluate ``model`` on ``test_set`` in fixed-size batches.

    Each batch of test rows is fed to ``model.input_test``. The resulting
    ``model.probs_test`` is submitted to the metric together with the last
    column of the batch as ground truth. Only whole batches are evaluated;
    a trailing partial batch is skipped.

    Returns:
        MRR@5 as computed by ``SRSMetric``.
    """
    batch_size = configs["batch_size"]
    total_steps = int(test_set.shape[0] / batch_size)

    meter = SRSMetric(k_list=[5, 20])
    meter.setup_and_clean()

    for step in range(total_steps):
        lo = step * batch_size
        hi = lo + batch_size
        chunk = test_set[lo:hi, :]

        probs = sess.run(model.probs_test, feed_dict={model.input_test: chunk})

        # Last column of each row is passed as the ground-truth item.
        meter.submit(probs, chunk[:, -1:])

    meter.calc()
    meter.output_to_logger()

    return meter.mrr[5]