コード例 #1
0
def test_model_xdeepfm(resource_path):
    """Smoke-test XDeepFMModel: evaluate, fit and predict on the sample file.

    Args:
        resource_path: Base directory; the xDeepFM files are looked up under
            ``../resources/deeprec/xdeepfm`` joined onto it.
    """
    xdeepfm_dir = os.path.join(resource_path, '../resources/deeprec/xdeepfm')
    config_path = os.path.join(xdeepfm_dir, r'xDeepFM.yaml')
    sample_path = os.path.join(xdeepfm_dir, r'sample_FFM_data.txt')
    predictions_path = os.path.join(xdeepfm_dir, r'output.txt')

    # The yaml config doubles as the marker for whether the resources exist.
    config_missing = not os.path.exists(config_path)
    if config_missing:
        download_deeprec_resources(
            r'https://recodatasets.blob.core.windows.net/deeprec/',
            xdeepfm_dir,
            'xdeepfmresources.zip',
        )

    hparams = prepare_hparams(config_path, learning_rate=0.01)
    assert hparams is not None

    model = XDeepFMModel(hparams, FFMTextIterator)

    eval_result = model.run_eval(sample_path)
    assert eval_result is not None

    # The same sample file is passed as both positional arguments to fit().
    fitted = model.fit(sample_path, sample_path)
    assert isinstance(fitted, BaseModel)

    prediction_result = model.predict(sample_path, predictions_path)
    assert prediction_result is not None
コード例 #2
0
def test_model_xdeepfm(deeprec_resource_path):
    """Exercise run_eval, fit and predict of XDeepFMModel on the sample data.

    Args:
        deeprec_resource_path: Directory that contains (or will receive) the
            ``xdeepfm`` resource folder.
    """
    resource_dir = os.path.join(deeprec_resource_path, "xdeepfm")
    # Build the three file paths in the same order as they are used below.
    config_path, sample_path, predictions_path = [
        os.path.join(resource_dir, file_name)
        for file_name in ("xDeepFM.yaml", "sample_FFM_data.txt", "output.txt")
    ]

    # Only call the download helper when the yaml config is not on disk yet.
    if not os.path.exists(config_path):
        base_url = "https://recodatasets.z20.web.core.windows.net/deeprec/"
        download_deeprec_resources(base_url, resource_dir, "xdeepfmresources.zip")

    hparams = prepare_hparams(config_path, learning_rate=0.01)
    assert hparams is not None

    model = XDeepFMModel(hparams, FFMTextIterator)

    metrics = model.run_eval(sample_path)
    assert metrics is not None

    # The sample file is supplied for both positional arguments of fit().
    trained = model.fit(sample_path, sample_path)
    assert isinstance(trained, BaseModel)

    predicted = model.predict(sample_path, predictions_path)
    assert predicted is not None
コード例 #3
0
def test_model_xdeepfm(resource_path):
    """Check that an XDeepFMModel can be evaluated, fitted and used to predict.

    Args:
        resource_path: Base directory; resources are resolved under
            ``../resources/deeprec/xdeepfm`` relative to it.
    """
    resource_dir = os.path.join(
        resource_path, "..", "resources", "deeprec", "xdeepfm"
    )

    def _resource(file_name):
        # Resolve a file name inside the xDeepFM resource directory.
        return os.path.join(resource_dir, file_name)

    config_path = _resource("xDeepFM.yaml")
    sample_path = _resource("sample_FFM_data.txt")
    predictions_path = _resource("output.txt")

    # A missing yaml config triggers the resource download call.
    if not os.path.exists(config_path):
        download_args = (
            "https://recodatasets.blob.core.windows.net/deeprec/",
            resource_dir,
            "xdeepfmresources.zip",
        )
        download_deeprec_resources(*download_args)

    hparams = prepare_hparams(config_path, learning_rate=0.01)
    assert hparams is not None

    model = XDeepFMModel(hparams, FFMTextIterator)

    evaluation = model.run_eval(sample_path)
    assert evaluation is not None

    # fit() receives the sample file for both of its positional arguments.
    fit_result = model.fit(sample_path, sample_path)
    assert isinstance(fit_result, BaseModel)

    predict_result = model.predict(sample_path, predictions_path)
    assert predict_result is not None
コード例 #4
0
                          FIELD_COUNT=10, 
                          cross_l2=0.0001, 
                          embed_l2=0.0001, 
                          learning_rate=0.001, 
                          epochs=EPOCHS,
                          batch_size=BATCH_SIZE)
print("Hyper-parameters: ")
print(hparams)

# Step 2: choose the data loader.
# FFMTextIterator is the iterator class handed to the xDeepFM model below.
input_creator = FFMTextIterator

# Step 3: build the model.
# Alternative noted by the original author: a pre-trained model can be loaded
# with model.load_model(r'model_path').
model = XDeepFMModel(
    hparams,
    input_creator,
    seed=RANDOM_SEED,
)

# Report the evaluation result before any training has happened.
print(
    "Untrained model's performance: {}".format(
        model.run_eval(test_file)
    )
)

# Step 4: train the model.
print("Begin model training...")
model.fit(train_file, valid_file)
print("End model training...")

# Step 5: evaluate the model (left disabled, as in the original script).
# res_syn = model.run_eval(test_file)
# print("Trained model's performance: {}".format(res_syn))
# pm.record("rest_syn", res_syn)
# Full prediction scores, rather than evaluation metrics, are available via
# model.predict(test_file, output_file).