def test_estimator_keras_save_load(self):
        """Round-trip an Orca Keras estimator through an HDF5 model file.

        Trains an estimator on the ``orca/learn/ncf.csv`` sample data, saves
        the Keras model as ``test.h5`` in a temporary directory, reloads it
        after resetting the default TF graph, and checks that predictions from
        the reloaded estimator have 2 columns. The temporary directory is
        removed in a ``finally`` so it does not leak when a step fails.
        """
        import zoo.orca.data.pandas

        tf.reset_default_graph()

        model = self.create_model()
        file_path = os.path.join(self.resource_path, "orca/learn/ncf.csv")
        data_shard = zoo.orca.data.pandas.read_csv(file_path)

        def transform(df):
            # Features are two inputs reshaped to one column each (user,
            # item); the label column is the training target.
            result = {
                "x": (df['user'].to_numpy().reshape([-1, 1]),
                      df['item'].to_numpy().reshape([-1, 1])),
                "y": df['label'].to_numpy()
            }
            return result

        data_shard = data_shard.transform_shard(transform)

        est = Estimator.from_keras(keras_model=model)
        est.fit(data=data_shard,
                batch_size=8,
                epochs=10,
                validation_data=data_shard)

        eval_result = est.evaluate(data_shard)
        print(eval_result)

        temp = tempfile.mkdtemp()
        try:
            model_path = os.path.join(temp, 'test.h5')
            est.save_keras_model(model_path)

            # Clear the default graph before reloading so the restored model
            # is built from the saved file rather than the graph used above.
            tf.reset_default_graph()

            from tensorflow.python.keras import models
            from zoo.common.utils import load_from_file

            def load_func(file_path):
                return models.load_model(file_path)

            model = load_from_file(load_func, model_path)
            est = Estimator.from_keras(keras_model=model)

            data_shard = zoo.orca.data.pandas.read_csv(file_path)

            def transform_features(df):
                # Prediction input only: same feature layout as training,
                # without the "y" label.
                result = {
                    "x": (df['user'].to_numpy().reshape([-1, 1]),
                          df['item'].to_numpy().reshape([-1, 1])),
                }
                return result

            data_shard = data_shard.transform_shard(transform_features)
            predictions = est.predict(data_shard).collect()
            assert predictions[0]['prediction'].shape[1] == 2
        finally:
            # mkdtemp() leaves deletion to the caller; always clean up, even
            # when saving, loading or the assertion above fails.
            shutil.rmtree(temp)
Ejemplo n.º 2
0
    def load_keras_model(path):
        """
        Build an Orca TF Estimator from a Keras model saved in an HDF5 file.

        The file is read with Keras ``load_model`` (via ``load_from_file``),
        and the resulting model is wrapped with ``Estimator.from_keras``.

        :param path: String. Location of the saved Keras model file.
        :return: Orca TF Estimator wrapping the loaded Keras model.
        """
        from tensorflow.python.keras import models as keras_models

        def _read_keras_file(file_path):
            # Callback handed to load_from_file; it receives the path to read.
            return keras_models.load_model(file_path)

        loaded = load_from_file(_read_keras_file, path)
        estimator = Estimator.from_keras(keras_model=loaded)
        return estimator