Example 1
0
def inference():
    """Run the trained DKN model over the test behaviors and write
    per-impression predictions to ./data/test/answer.json.

    Loads the test dataset, restores the latest checkpoint, scores each
    minibatch, then walks truth.json in order — asserting that the stored
    labels match the dataloader order — and writes one JSON line of
    predicted scores per user impression.
    """
    dataset = DKNDataset('data/test/behaviors_cleaned.tsv',
                         'data/test/news_with_entity.tsv')
    print(f"Load inference dataset with size {len(dataset)}.")
    dataloader = DataLoader(dataset,
                            batch_size=Config.batch_size,
                            shuffle=False,
                            num_workers=Config.num_workers,
                            drop_last=False)

    # Load trained embedding file
    # num_entity_tokens, entity_embedding_dim
    entity_embedding = np.load('data/train/entity_embedding.npy')
    # TODO: load a real context embedding; currently reuses the entity one
    context_embedding = np.load('data/train/entity_embedding.npy')

    dkn = DKN(Config, entity_embedding, context_embedding).to(device)
    checkpoint_path = latest_checkpoint('./checkpoint')
    print(f"Load saved parameters in {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path)
    dkn.load_state_dict(checkpoint['model_state_dict'])
    dkn.eval()

    y_pred = []
    y = []

    count = 0

    # Disable autograd during inference: saves memory and avoids building
    # a useless computation graph. (Fixes missing no_grad; also fixes the
    # "Inferering" typo in the progress-bar description.)
    with torch.no_grad(), \
            tqdm(total=len(dataloader), desc="Inferring") as pbar:
        for minibatch in dataloader:
            y_pred.extend(
                dkn(minibatch["candidate_news"],
                    minibatch["clicked_news"]).tolist())
            y.extend(minibatch["clicked"].float().tolist())
            pbar.update(1)

            count += 1
            # NOTE(review): debug cap — stops after 500 minibatches, so the
            # answer file may be incomplete (the StopIteration warning below
            # then fires). Remove to infer the full test set.
            if count == 500:
                break

    y_pred = iter(y_pred)
    y = iter(y)

    # Context managers guarantee both files are closed even on errors;
    # the original leaked the truth.json handle entirely.
    # truth.json: for locating and order validating.
    # answer.json: for writing inference results.
    with open('./data/test/truth.json', 'r') as truth_file, \
            open('./data/test/answer.json', 'w') as submission_answer_file:
        try:
            for line in truth_file:
                user_truth = json.loads(line)
                user_inference = copy.deepcopy(user_truth)
                # Order check: labels stored during inference must match
                # the ground-truth file's impression order.
                for k in user_truth['impression'].keys():
                    assert next(y) == user_truth['impression'][k]
                    user_inference['impression'][k] = next(y_pred)
                submission_answer_file.write(json.dumps(user_inference) + '\n')
        except StopIteration:
            print(
                'Warning: Behaviors not fully inferenced. You can still run evaluate.py, but the evaluation result would be inaccurate.'
            )
Example 2
0
def get_models_and_configs(args):
    """Load every ensemble model named in ``args.models`` with its config.

    For each model name, instantiates ``<name>Config`` from the ``config``
    module, applies datasize/device settings from ``args``, restores the
    latest checkpoint, and puts the model into eval mode.

    Args:
        args: parsed CLI arguments with ``models`` (list of model-class
            names), ``datasize`` and ``device`` attributes.

    Returns:
        dict mapping model_name -> [model, config].

    Exits the process if any model has no checkpoint file.
    """
    ensemble_model_names = args.models
    # isinstance is the idiomatic type check (was: type(...) == type([])).
    assert isinstance(ensemble_model_names, list)
    models_configs = {model_name: [] for model_name in ensemble_model_names}

    # (Removed unused enumerate index from the original loop.)
    for model_name, model_list in models_configs.items():
        Config_cls = getattr(importlib.import_module('config'), f"{model_name}Config")
        config = Config_cls()
        config.datasize = args.datasize
        config.model_name = model_name
        config.configure_datasize()
        config.device_str = args.device

        checkpoint_path = latest_checkpoint(os.path.join(config.checkpoint_dir,
                                                         config.datasize,
                                                         config.model_name))
        if checkpoint_path is None:
            print('No checkpoint file found!')
            exit()
        print(f"Load saved parameters in {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)

        # Model class lives in model/<name>.py under the same name.
        Model_cls = getattr(importlib.import_module(f"model.{model_name}"), model_name)
        model = Model_cls(config).to(config.device_str)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        model_list.append(model)
        model_list.append(config)

    return models_configs
Example 3
0
def evaluation():
    """Restore the newest checkpoint for `model_name` and print
    AUC / MRR / nDCG@5 / nDCG@10 measured on ./data/test."""
    print('Using device:', device)
    print(f'Evaluating model {model_name}')
    # Don't need to load pretrained word/entity/context embedding
    # since it will be loaded from checkpoint later
    model = Model(config).to(device)
    from train import latest_checkpoint  # Avoid circular imports
    ckpt_file = latest_checkpoint(path.join('./checkpoint', model_name))
    if ckpt_file is None:
        print('No checkpoint file found!')
        exit()
    print(f"Load saved parameters in {ckpt_file}")
    state = torch.load(ckpt_file)
    model.load_state_dict(state['model_state_dict'])
    model.eval()
    metrics = evaluate(model, './data/test')
    auc, mrr, ndcg5, ndcg10 = metrics
    print(
        f'AUC: {auc:.4f}\nMRR: {mrr:.4f}\nnDCG@5: {ndcg5:.4f}\nnDCG@10: {ndcg10:.4f}'
    )
Example 4
0
    print(f'saved {txt_path}')
    return


# Script entry point: restore the newest checkpoint, run evaluation on the
# real test set with prediction-file output, then zip the prediction file.
if __name__ == '__main__':
    print('Using device:', device)
    print(f'Evaluating model {model_name}')

    # Don't need to load pretrained word/entity/context embedding
    # since it will be loaded from checkpoint later
    model = Model(config).to(device)
    from train import latest_checkpoint  # Avoid circular imports
    # Checkpoints are grouped by run hyperparameters, e.g.
    # ../checkpoint/<model_name>/batch_size<B>_num<N>
    checkpoint_dir = os.path.join(
        '../checkpoint', model_name, 'batch_size' + str(config.batch_size) +
        '_num' + str(config.num_clicked_news_a_user))
    checkpoint_path = latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None:
        print('No checkpoint file found!')
        exit()
    print(f"Load saved parameters in {checkpoint_path}")

    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Third argument presumably enables writing predictions to the given
    # txt path (see evaluate's generate_txt handling) — TODO confirm.
    evaluate(model, '../data/real_test', True,
             '../data/real_test/prediction.txt')

    print('zippppp')
    # NOTE(review): the zip archive is opened here; writing the prediction
    # file into it and closing it appear to happen past the end of this
    # excerpt — confirm the archive is closed downstream.
    f = zipfile.ZipFile('../data/real_test/prediction.zip', 'w',
                        zipfile.ZIP_DEFLATED)
Example 5
0
                    f"{minibatch['impression_id'][0]} {str(list(value2rank(impression).values())).replace(' ','')}\n"
                )
            pbar.update(1)

    if generate_txt:
        answer_file.close()

    return np.mean(aucs), np.mean(mrrs), np.mean(ndcg5s), np.mean(ndcg10s)


# Script entry point: restore the most recent checkpoint, evaluate on the
# test split (also writing a prediction file), and print the four metrics.
if __name__ == '__main__':
    print('Using device:', device)
    print(f'Evaluating model {model_name}')
    # Don't need to load pretrained word/entity/context embedding
    # since it will be loaded from checkpoint later
    model = Model(config).to(device)
    from train import latest_checkpoint  # Avoid circular imports
    ckpt_path = latest_checkpoint(os.path.join('../checkpoint', model_name))
    if ckpt_path is None:
        print('No checkpoint file found!')
        exit()
    print(f"Load saved parameters in {ckpt_path}")
    state = torch.load(ckpt_path)
    model.load_state_dict(state['model_state_dict'])
    model.eval()
    auc, mrr, ndcg5, ndcg10 = evaluate(model, '../data/test', True,
                                       '../data/test/prediction.txt')
    print(
        f'AUC: {auc:.4f}\nMRR: {mrr:.4f}\nnDCG@5: {ndcg5:.4f}\nnDCG@10: {ndcg10:.4f}'
    )
Example 6
0
            # batch_size
            y = minibatch['label'].float()
            loss = criterion(y_pred, y.to(device))
            loss_full.append(loss.item())
            y_pred_full.extend((torch.sigmoid(y_pred) > 0.5).int().tolist())
            y_full.extend(y.tolist())

            pbar.update(1)

    return np.mean(loss_full), classification_report(y_full,
                                                     y_pred_full,
                                                     output_dict=True,
                                                     zero_division=0)


# Script entry point: evaluate the restored model on the parsed test news
# and print the weighted-average classification metrics.
if __name__ == '__main__':
    test_dataset = MyDataset('data/test/news_parsed.tsv')
    # Don't need to load pretrained word embedding
    # since it will be loaded from checkpoint later
    model = Model(Config).to(device)
    from train import latest_checkpoint  # Avoid circular imports
    ckpt_path = latest_checkpoint('./checkpoint')
    if ckpt_path is None:
        print('No checkpoint file found!')
        exit()
    state = torch.load(ckpt_path)
    model.load_state_dict(state['model_state_dict'])
    model.eval()
    _, report = evaluate(model, test_dataset)
    print(report['weighted avg'])