def test_NFM():
    """Smoke-test the NFM model end to end.

    Generates 64 random samples (one integer column per sparse feature, one
    uniform float column per dense feature, binary labels), fits for one
    epoch with a 50% validation split, then round-trips the weights and the
    whole model through ``.h5`` files written to the current directory.
    Progress is reported with ``print`` after each stage.
    """
    name = "NFM"
    sample_size = 64
    feature_dim_dict = {
        'sparse': {'sparse_1': 2, 'sparse_2': 5, 'sparse_3': 10},
        'dense': ['dense_1', 'dense_2', 'dense_3'],
    }

    # Model inputs: sparse columns first (values in [0, vocab_size)), then
    # dense columns (uniform in [0, 1)). Labels are drawn last, matching the
    # original random-draw order.
    x = []
    for vocab_size in feature_dim_dict['sparse'].values():
        x.append(np.random.randint(0, vocab_size, sample_size))
    for _ in feature_dim_dict['dense']:
        x.append(np.random.random(sample_size))
    y = np.random.randint(0, 2, sample_size)

    model = NFM(feature_dim_dict, embedding_size=8, hidden_size=[32, 32], keep_prob=0.5)
    model.compile('adam', 'binary_crossentropy', metrics=['binary_crossentropy'])
    model.fit(x, y, batch_size=100, epochs=1, validation_split=0.5)
    print(f"{name} test train valid pass!")

    weights_path = f"{name}_weights.h5"
    model.save_weights(weights_path)
    model.load_weights(weights_path)
    print(f"{name} test save load weight pass!")

    model_path = f"{name}.h5"
    save_model(model, model_path)
    # custom_objects is passed positionally, as in the original call.
    model = load_model(model_path, custom_objects)
    print(f"{name} test save load model pass!")

    print(f"{name} test pass!")
# Split the dataset into a training set and a test set (80% / 20%).
train, test = train_test_split(data, test_size=0.2)
# Build the model input dicts: feature name -> column values.
train_model_input = {name: train[name].values for name in feature_names}
test_model_input = {name: test[name].values for name in feature_names}

# Train an NFM model for regression.
model = NFM(linear_feature_columns, dnn_feature_columns, task='regression')
model.compile(
    "adam",
    "mse",
    metrics=['mse'],
)
history = model.fit(
    train_model_input,
    train[target].values,
    batch_size=256,
    epochs=1,
    verbose=True,
    validation_split=0.2,
)

# Predict on the held-out test set with the trained NFM model.
pred_ans = model.predict(test_model_input, batch_size=256)

# Report test RMSE. It is derived from the unrounded MSE and rounded only for
# display; rounding the MSE first would feed a truncated value into the square
# root and still print a long, spuriously precise result.
mse = mean_squared_error(test[target].values, pred_ans)
rmse = mse ** 0.5
print("test RMSE", round(rmse, 4))

# Example output from an earlier run (that version printed the unrounded RMSE;
# exact numbers vary between runs):
#
#   Train on 128 samples, validate on 32 samples
#   2019-11-10 01:25:40.830752: I tensorflow/core/platform/cpu_feature_guard.cc:142]
#       Your CPU supports instructions that this TensorFlow binary was not
#       compiled to use: AVX2 FMA
#   128/128 [==============================] - 1s 6ms/sample - loss: 14.5472
#       - mean_squared_error: 14.5472 - val_loss: 13.8782
#       - val_mean_squared_error: 13.8782
#   test RMSE 3.5901532000737797