def test_estimator_keras_save_load(self):
    """Round-trip a Keras-backed Estimator through an HDF5 file.

    The test fits and evaluates an Estimator on the ``orca/learn/ncf.csv``
    resource, saves the Keras model as ``test.h5`` in a temporary directory,
    resets the default graph, reloads the model from that file, and checks
    that the predictions of the reloaded model have a second dimension of 2.
    The temporary directory is removed even when a step fails.
    """
    import zoo.orca.data.pandas

    tf.reset_default_graph()

    model = self.create_model()
    file_path = os.path.join(self.resource_path, "orca/learn/ncf.csv")
    data_shard = zoo.orca.data.pandas.read_csv(file_path)

    def transform(df):
        # Training shards carry both the (user, item) inputs and the label.
        result = {
            "x": (df['user'].to_numpy().reshape([-1, 1]),
                  df['item'].to_numpy().reshape([-1, 1])),
            "y": df['label'].to_numpy()
        }
        return result

    data_shard = data_shard.transform_shard(transform)

    est = Estimator.from_keras(keras_model=model)
    est.fit(data=data_shard,
            batch_size=8,
            epochs=10,
            validation_data=data_shard)

    eval_result = est.evaluate(data_shard)
    print(eval_result)

    temp = tempfile.mkdtemp()
    try:
        model_path = os.path.join(temp, 'test.h5')
        est.save_keras_model(model_path)

        # Start from a clean default graph before the saved model is reloaded.
        tf.reset_default_graph()

        from tensorflow.python.keras import models
        from zoo.common.utils import load_from_file

        def load_func(path):
            return models.load_model(path)

        model = load_from_file(load_func, model_path)
        est = Estimator.from_keras(keras_model=model)

        data_shard = zoo.orca.data.pandas.read_csv(file_path)

        def transform_features(df):
            # Prediction shards only need the inputs; no label is supplied.
            result = {
                "x": (df['user'].to_numpy().reshape([-1, 1]),
                      df['item'].to_numpy().reshape([-1, 1])),
            }
            return result

        data_shard = data_shard.transform_shard(transform_features)
        predictions = est.predict(data_shard).collect()
        assert predictions[0]['prediction'].shape[1] == 2
    finally:
        # Clean up the scratch directory even if saving, loading, predicting
        # or the assertion above raised.
        shutil.rmtree(temp)
def load_keras_model(path):
    """
    Build an Estimator from a keras model (with weights) stored in an HDF5 file.

    :param path: String. The path to the pre-defined model.
    :return: Orca TF Estimator.
    """
    from tensorflow.python.keras import models

    def _read_keras_model(model_file):
        # Loader callback handed to load_from_file.
        return models.load_model(model_file)

    keras_model = load_from_file(_read_keras_model, path)
    estimator = Estimator.from_keras(keras_model=keras_model)
    return estimator