def test_model_naml(mind_resource_path):
    """Smoke-test NAML evaluation and training on the MIND demo resources.

    Each of the train / valid / utils archives is downloaded into its
    sub-directory of ``mind_resource_path`` when its marker file is missing.
    The test then checks that ``run_eval`` returns something other than
    ``None`` and that ``fit`` returns a ``BaseModel`` instance.
    """
    base_url = r"https://recodatasets.z20.web.core.windows.net/newsrec/"
    train_dir = os.path.join(mind_resource_path, "train")
    valid_dir = os.path.join(mind_resource_path, "valid")
    utils_dir = os.path.join(mind_resource_path, "utils")

    train_news = os.path.join(train_dir, "news.tsv")
    train_behaviors = os.path.join(train_dir, "behaviors.tsv")
    valid_news = os.path.join(valid_dir, "news.tsv")
    valid_behaviors = os.path.join(valid_dir, "behaviors.tsv")
    config_path = os.path.join(utils_dir, "naml.yaml")

    # Keyword arguments handed to prepare_hparams alongside the YAML config.
    hparam_kwargs = {
        "wordEmb_file": os.path.join(utils_dir, "embedding_all.npy"),
        "wordDict_file": os.path.join(utils_dir, "word_dict_all.pkl"),
        "userDict_file": os.path.join(utils_dir, "uid2index.pkl"),
        "vertDict_file": os.path.join(utils_dir, "vert_dict.pkl"),
        "subvertDict_file": os.path.join(utils_dir, "subvert_dict.pkl"),
        "epochs": 1,
    }

    # (marker file, download directory, archive name): an archive is fetched
    # only when its marker file is absent.
    required_archives = (
        (train_news, train_dir, "MINDdemo_train.zip"),
        (valid_news, valid_dir, "MINDdemo_dev.zip"),
        (config_path, utils_dir, "MINDdemo_utils.zip"),
    )
    for marker, target_dir, archive in required_archives:
        if not os.path.exists(marker):
            download_deeprec_resources(base_url, target_dir, archive)

    hparams = prepare_hparams(config_path, **hparam_kwargs)
    model = NAMLModel(hparams, MINDAllIterator)

    eval_result = model.run_eval(valid_news, valid_behaviors)
    assert eval_result is not None

    fitted = model.fit(train_news, train_behaviors, valid_news,
                       valid_behaviors)
    assert isinstance(fitted, BaseModel)
def test_naml_component_definition(mind_resource_path):
    """Check that a freshly built NAML model exposes its core components.

    The utils archive is downloaded into ``<mind_resource_path>/utils`` when
    the YAML config is missing. After constructing the model, its ``model``,
    ``scorer``, ``loss`` and ``train_optimizer`` attributes must all be set
    (not ``None``).
    """
    utils_dir = os.path.join(mind_resource_path, "utils")
    config_path = os.path.join(utils_dir, "naml.yaml")

    # Keyword arguments handed to prepare_hparams alongside the YAML config.
    hparam_kwargs = {
        "wordEmb_file": os.path.join(utils_dir, "embedding_all.npy"),
        "wordDict_file": os.path.join(utils_dir, "word_dict_all.pkl"),
        "userDict_file": os.path.join(utils_dir, "uid2index.pkl"),
        "vertDict_file": os.path.join(utils_dir, "vert_dict.pkl"),
        "subvertDict_file": os.path.join(utils_dir, "subvert_dict.pkl"),
        "epochs": 1,
    }

    config_missing = not os.path.exists(config_path)
    if config_missing:
        download_deeprec_resources(
            r"https://recodatasets.z20.web.core.windows.net/newsrec/",
            utils_dir,
            "MINDdemo_utils.zip",
        )

    hparams = prepare_hparams(config_path, **hparam_kwargs)
    model = NAMLModel(hparams, MINDAllIterator)

    # Checked in the same order as the individual assertions they replace.
    for component in ("model", "scorer", "loss", "train_optimizer"):
        assert getattr(model, component) is not None
Exemple #3
0
def test_model_naml(tmp):
    """Smoke-test NAML evaluation and training on the files under ``tmp``.

    The ``naml.zip`` archive is downloaded into ``tmp`` when the YAML config
    is missing. The test checks that hyper-parameters are produced, that
    ``run_eval`` returns something other than ``None`` and that ``fit``
    returns a ``BaseModel`` instance.
    """
    config_path = os.path.join(tmp, "naml.yaml")
    train_path = os.path.join(tmp, "train.txt")
    valid_path = os.path.join(tmp, "test.txt")
    embedding_path = os.path.join(tmp, "embedding.npy")

    config_missing = not os.path.exists(config_path)
    if config_missing:
        download_deeprec_resources(
            "https://recodatasets.blob.core.windows.net/newsrec/",
            tmp,
            "naml.zip",
        )

    hparams = prepare_hparams(
        config_path,
        wordEmb_file=embedding_path,
        epochs=1,
    )
    assert hparams is not None

    model = NAMLModel(hparams, NAMLIterator)

    eval_result = model.run_eval(valid_path)
    assert eval_result is not None

    fitted = model.fit(train_path, valid_path)
    assert isinstance(fitted, BaseModel)
Exemple #4
0
def test_naml_component_definition(tmp):
    """Check that a freshly built NAML model exposes its core components.

    The ``naml.zip`` archive is downloaded into ``tmp`` when the YAML config
    is missing. After constructing the model, its ``model``, ``scorer``,
    ``loss`` and ``train_optimizer`` attributes must all be set (not
    ``None``).
    """
    config_path = os.path.join(tmp, "naml.yaml")
    embedding_path = os.path.join(tmp, "embedding.npy")

    config_missing = not os.path.exists(config_path)
    if config_missing:
        download_deeprec_resources(
            "https://recodatasets.blob.core.windows.net/newsrec/",
            tmp,
            "naml.zip",
        )

    hparams = prepare_hparams(
        config_path,
        wordEmb_file=embedding_path,
        epochs=1,
    )
    model = NAMLModel(hparams, NAMLIterator)

    # Checked in the same order as the individual assertions they replace.
    for component in ("model", "scorer", "loss", "train_optimizer"):
        assert getattr(model, component) is not None
Exemple #5
0
# Resource files under <data_path>/utils; the YAML config is selected by the
# model name given in `opt`.
wordDict_file = os.path.join(data_path, "utils", "word_dict_all.pkl")
subvertDict_file = os.path.join(data_path, "utils", "subvert_dict.pkl")
vertDict_file = os.path.join(data_path, "utils", "vert_dict.pkl")
yaml_file = os.path.join(data_path, "utils", '{}.yaml'.format(opt.model_name))

# Build the hyper-parameters from the YAML config plus the overrides below.
# wordEmb_file, userDict_file, batch_size and epochs are defined earlier in
# the script (outside this excerpt).
hparams = prepare_hparams(yaml_file,
                          wordEmb_file=wordEmb_file,
                          wordDict_file=wordDict_file,
                          userDict_file=userDict_file,
                          vertDict_file=vertDict_file,
                          subvertDict_file=subvertDict_file,
                          batch_size=batch_size,
                          epochs=epochs)
print(hparams)

# A single assignment is enough; the original chained `iterator = iterator =`
# was a redundant duplicate target.
iterator = MINDAllIterator
model = NAMLModel(hparams, iterator, seed=seed)

# print(model.run_slow_eval(news_file, valid_behaviors_file))

# NOTE(review): the same `news_file` is passed as both the training and the
# validation news file -- confirm this is intended.
model.fit(news_file, train_behaviors_file, news_file, valid_behaviors_file)

# model_path = os.path.join(model_path, "model")
# os.makedirs(model_path, exist_ok=True)

# model.model.save_weights(os.path.join(model_path, "nrms_ckpt"))

# group_impr_indexes, group_labels, group_preds = model.run_slow_eval(test_news_file, test_behaviors_file)

# res = cal_metric(group_labels, group_preds, hparams.metrics)
Exemple #6
0
                          subvertDict_file=subvertDict_file,
                          batch_size=batch_size,
                          epochs=epochs,
                          show_step=10)
logging.info(hparams)


# ## Build the selected model and train it
#
# `model_type` picks both the model class and the iterator class it is
# constructed with.

if model_type == 'nrms':
    iterator = MINDIterator
    model = NRMSModel(hparams, iterator, seed=seed)
elif model_type == 'naml':
    iterator = MINDAllIterator
    model = NAMLModel(hparams, iterator, seed=seed)
elif model_type == 'npa':
    iterator = MINDIterator
    model = NPAModel(hparams, iterator, seed=seed)
elif model_type == 'nrmma':
    iterator = MINDAllIterator
    model = NRMMAModel(hparams, iterator, seed=seed)
else:
    # Report the value that was actually dispatched on; the original message
    # interpolated `exp_name`, which is not what selects the model.
    raise NotImplementedError(f"{model_type} is not implemented")

# In[8]:
# model_path and model_name are derived from the model type and passed
# through to fit().
model_path = os.path.join(exp_path, model_type)
model_name = model_type + '_ckpt'
model.fit(train_news_file, train_behaviors_file, valid_news_file, valid_behaviors_file,
          model_path=model_path, model_name=model_name)