def test_model_naml(mind_resource_path):
    """Smoke-test NAMLModel on the MIND demo dataset.

    For each of the train, valid and utils directories under
    ``mind_resource_path``, the matching MINDdemo archive is downloaded when
    that directory's sentinel file is missing. Hyper-parameters are then built
    from ``naml.yaml`` with a one-epoch override. The test checks that
    ``run_eval`` returns something other than ``None`` and that ``fit``
    returns a ``BaseModel`` instance.
    """
    base_url = r"https://recodatasets.z20.web.core.windows.net/newsrec/"
    train_dir = os.path.join(mind_resource_path, "train")
    valid_dir = os.path.join(mind_resource_path, "valid")
    utils_dir = os.path.join(mind_resource_path, "utils")

    train_news_file = os.path.join(train_dir, r"news.tsv")
    train_behaviors_file = os.path.join(train_dir, r"behaviors.tsv")
    valid_news_file = os.path.join(valid_dir, r"news.tsv")
    valid_behaviors_file = os.path.join(valid_dir, r"behaviors.tsv")

    embedding_path = os.path.join(utils_dir, "embedding_all.npy")
    user_dict_path = os.path.join(utils_dir, "uid2index.pkl")
    word_dict_path = os.path.join(utils_dir, "word_dict_all.pkl")
    vert_dict_path = os.path.join(utils_dir, "vert_dict.pkl")
    subvert_dict_path = os.path.join(utils_dir, "subvert_dict.pkl")
    yaml_file = os.path.join(utils_dir, r"naml.yaml")

    # (file whose absence triggers a download, target dir, archive name);
    # processed in the same order as the original sequential checks.
    archives = (
        (train_news_file, train_dir, "MINDdemo_train.zip"),
        (valid_news_file, valid_dir, "MINDdemo_dev.zip"),
        (yaml_file, utils_dir, "MINDdemo_utils.zip"),
    )
    for sentinel, target_dir, archive in archives:
        if not os.path.exists(sentinel):
            download_deeprec_resources(base_url, target_dir, archive)

    hparams = prepare_hparams(
        yaml_file,
        wordEmb_file=embedding_path,
        wordDict_file=word_dict_path,
        userDict_file=user_dict_path,
        vertDict_file=vert_dict_path,
        subvertDict_file=subvert_dict_path,
        epochs=1,
    )

    model = NAMLModel(hparams, MINDAllIterator)
    assert model.run_eval(valid_news_file, valid_behaviors_file) is not None

    trained = model.fit(
        train_news_file,
        train_behaviors_file,
        valid_news_file,
        valid_behaviors_file,
    )
    assert isinstance(trained, BaseModel)
# Example #2
# 0
def test_model_naml(tmp):
    """Smoke-test NAMLModel using the ``naml.zip`` resource bundle.

    ``naml.zip`` is downloaded into ``tmp`` when ``naml.yaml`` is absent.
    Hyper-parameters are then built with the embedding file and a one-epoch
    override. The test checks that ``run_eval`` returns something other than
    ``None`` and that ``fit`` returns a ``BaseModel`` instance.
    """
    def resource(name):
        # All fixtures for this test live directly inside ``tmp``.
        return os.path.join(tmp, name)

    yaml_file = resource("naml.yaml")
    train_file = resource("train.txt")
    valid_file = resource("test.txt")
    embedding_path = resource("embedding.npy")

    if not os.path.exists(yaml_file):
        download_deeprec_resources(
            "https://recodatasets.blob.core.windows.net/newsrec/",
            tmp,
            "naml.zip",
        )

    hparams = prepare_hparams(yaml_file, wordEmb_file=embedding_path, epochs=1)
    assert hparams is not None

    model = NAMLModel(hparams, NAMLIterator)
    assert model.run_eval(valid_file) is not None

    fitted = model.fit(train_file, valid_file)
    assert isinstance(fitted, BaseModel)