def test_estimator_keras_save_load(self):
    """End-to-end check of ``save_keras_model`` / ``load_keras_model``.

    Trains a Keras NCF model through the Orca Estimator on an XShards of
    pandas data, saves it to an HDF5 file, reloads it into a *fresh*
    Estimator, and asserts the reloaded model still predicts with the
    expected output width (2 classes).
    """
    import zoo.orca.data.pandas
    # Start from a clean TF graph so earlier tests can't interfere.
    tf.reset_default_graph()
    model = self.create_model()
    file_path = os.path.join(self.resource_path, "orca/learn/ncf.csv")
    data_shard = zoo.orca.data.pandas.read_csv(file_path)

    def transform(df):
        # Model takes two inputs (user id, item id), each shaped (n, 1),
        # and a label vector as the target.
        result = {
            "x": (df['user'].to_numpy().reshape([-1, 1]),
                  df['item'].to_numpy().reshape([-1, 1])),
            "y": df['label'].to_numpy()
        }
        return result

    data_shard = data_shard.transform_shard(transform)

    est = Estimator.from_keras(keras_model=model)
    est.fit(data=data_shard,
            batch_size=8,
            epochs=10,
            validation_data=data_shard)
    eval_result = est.evaluate(data_shard)
    print(eval_result)

    temp = tempfile.mkdtemp()
    try:
        model_path = os.path.join(temp, 'test.h5')
        est.save_keras_model(model_path)

        # Reset the graph again, then reload into a brand-new Estimator to
        # prove the round-trip does not depend on the original session.
        tf.reset_default_graph()
        est = Estimator.load_keras_model(model_path)

        data_shard = zoo.orca.data.pandas.read_csv(file_path)

        def transform(df):
            # Inference shards carry features only — no "y" key.
            result = {
                "x": (df['user'].to_numpy().reshape([-1, 1]),
                      df['item'].to_numpy().reshape([-1, 1])),
            }
            return result

        data_shard = data_shard.transform_shard(transform)
        predictions = est.predict(data_shard).collect()
        # NCF head emits a 2-way score per row.
        assert predictions[0]['prediction'].shape[1] == 2
    finally:
        # Always remove the temp dir, even when fit/predict/assert fails;
        # the original left it behind on any failure.
        shutil.rmtree(temp)
# Build the NCF Keras model and train it through the Orca Estimator.
# NOTE(review): u_limit/m_limit/u_output/m_output and the DataFrames
# (trainingDF, validationDF, inferenceDF) are defined earlier in the
# script — assumed to be user/item vocab sizes, embedding dims, and
# Spark DataFrames; confirm against the surrounding code.
model = ncf_model.getKerasModel(u_limit, m_limit, u_output, m_output, args.log_dir)
est = Estimator.from_keras(model, model_dir=args.log_dir)
est.fit(data=trainingDF,
        batch_size=batch_size,
        epochs=max_epoch,
        feature_cols=['features'],
        label_cols=['labels'],
        validation_data=validationDF)
# Save the trained model to disk (Keras HDF5 format).
est.save_keras_model(save_model_dir)
# Print the metric names tracked by the Keras model.
print(model.metrics_names)
# Orca's predict supports a native Spark DataFrame directly — only
# batch_size and feature_cols need to be supplied.
# Use a new Estimator instance to validate the load-model API.
pre_est = Estimator.load_keras_model(save_model_dir)
prediction_df = pre_est.predict(inferenceDF, batch_size=batch_size, feature_cols=['features'])
prediction_df.show(5)
# Collapse the 2-element score vector to a binary label:
# 1.0 when the second score wins, else 0.0.
score_udf = udf(lambda pred: 0.0 if pred[0] > pred[1] else 1.0, FloatType())
prediction_df = prediction_df.withColumn('prediction2', score_udf('prediction'))
prediction_df.show(10)
# Persist the (uid, mid, prediction2) table as Parquet.
#prediction_final_df.write.mode('overwrite').parquet(predict_output_path)
prediction_df.select(
    'uid',
    'mid',
    'prediction2').write.mode('overwrite').parquet(predict_output_path)
# Alternative CSV output, kept for reference:
#prediction_df.select('uid','mid','prediction2').write.mode('overwrite').format("csv").save(predict_output_path)