def test_model_naml(mind_resource_path):
    """Smoke-test NAMLModel end to end on the MIND demo data.

    Any missing MIND demo archive (train / valid / utils) is fetched first.
    Hyper-parameters are then built from ``utils/naml.yaml`` with the
    embedding and dictionary files overridden and ``epochs=1``. The test
    asserts that ``run_eval`` yields a non-None result and that ``fit``
    returns a ``BaseModel`` instance.

    Args:
        mind_resource_path: root directory holding the ``train``, ``valid``
            and ``utils`` sub-directories (presumably a pytest fixture).
    """
    train_dir = os.path.join(mind_resource_path, "train")
    valid_dir = os.path.join(mind_resource_path, "valid")
    utils_dir = os.path.join(mind_resource_path, "utils")

    train_news = os.path.join(train_dir, r"news.tsv")
    train_behaviors = os.path.join(train_dir, r"behaviors.tsv")
    valid_news = os.path.join(valid_dir, r"news.tsv")
    valid_behaviors = os.path.join(valid_dir, r"behaviors.tsv")

    embedding_path = os.path.join(utils_dir, "embedding_all.npy")
    user_dict_path = os.path.join(utils_dir, "uid2index.pkl")
    word_dict_path = os.path.join(utils_dir, "word_dict_all.pkl")
    vert_dict_path = os.path.join(utils_dir, "vert_dict.pkl")
    subvert_dict_path = os.path.join(utils_dir, "subvert_dict.pkl")
    config_path = os.path.join(utils_dir, r"naml.yaml")

    # Each archive is fetched only when its sentinel file is absent; checks
    # and downloads happen in the same train -> valid -> utils order.
    base_url = r"https://recodatasets.z20.web.core.windows.net/newsrec/"
    for sentinel, target_dir, archive in (
        (train_news, train_dir, "MINDdemo_train.zip"),
        (valid_news, valid_dir, "MINDdemo_dev.zip"),
        (config_path, utils_dir, "MINDdemo_utils.zip"),
    ):
        if not os.path.exists(sentinel):
            download_deeprec_resources(base_url, target_dir, archive)

    hparams = prepare_hparams(
        config_path,
        wordEmb_file=embedding_path,
        wordDict_file=word_dict_path,
        userDict_file=user_dict_path,
        vertDict_file=vert_dict_path,
        subvertDict_file=subvert_dict_path,
        epochs=1,
    )

    model = NAMLModel(hparams, MINDAllIterator)

    assert model.run_eval(valid_news, valid_behaviors) is not None
    fitted = model.fit(train_news, train_behaviors, valid_news, valid_behaviors)
    assert isinstance(fitted, BaseModel)
def test_model_naml(tmp):
    """Smoke-test NAMLModel with NAMLIterator on the ``naml.zip`` resources.

    Fetches ``naml.zip`` into *tmp* when ``naml.yaml`` is not already
    there. It then builds hyper-parameters with the embedding file
    overridden and ``epochs=1``. The test asserts that the hparams,
    ``run_eval`` result and ``fit`` result are sane (``fit`` must return a
    ``BaseModel``).

    Args:
        tmp: scratch directory for the downloaded resources (presumably a
            pytest fixture).
    """
    # NOTE(review): another `test_model_naml` (taking `mind_resource_path`)
    # appears in this source. If both live in the same module, this later
    # definition rebinds the name and pytest collects only this one —
    # confirm and give one of them a distinct name.
    config_path = os.path.join(tmp, "naml.yaml")
    train_path = os.path.join(tmp, "train.txt")
    valid_path = os.path.join(tmp, "test.txt")
    embedding_path = os.path.join(tmp, "embedding.npy")

    if not os.path.exists(config_path):
        download_deeprec_resources(
            "https://recodatasets.blob.core.windows.net/newsrec/",
            tmp,
            "naml.zip",
        )

    hparams = prepare_hparams(config_path, wordEmb_file=embedding_path, epochs=1)
    assert hparams is not None

    model = NAMLModel(hparams, NAMLIterator)

    assert model.run_eval(valid_path) is not None
    fitted = model.fit(train_path, valid_path)
    assert isinstance(fitted, BaseModel)