def test_estimator_graph_fit_dataset(self):
        """Train a graph-based Estimator on an orca Dataset, then evaluate it.

        ncf.csv is loaded into shards and mapped to ``{"x": (user, item),
        "y": label}``. It is then wrapped with ``Dataset.from_tensor_slices``
        and used for training, validation and evaluation alike. The only
        assertion is that the evaluation result exposes a ``loss`` entry.
        """
        import zoo.orca.data.pandas
        tf.reset_default_graph()
        model = SimpleModel()

        csv_path = os.path.join(resource_path, "orca/learn/ncf.csv")
        shards = zoo.orca.data.pandas.read_csv(csv_path)

        def to_features_and_labels(frame):
            # Feature tuple of (user ids, item ids) plus the label column as target.
            features = (frame['user'].to_numpy(), frame['item'].to_numpy())
            labels = frame['label'].to_numpy()
            return {"x": features, "y": labels}

        shards = shards.transform_shard(to_features_and_labels)
        dataset = Dataset.from_tensor_slices(shards)

        optimizer = tf.train.AdamOptimizer()
        est = Estimator.from_graph(inputs=[model.user, model.item],
                                   labels=[model.label],
                                   loss=model.loss,
                                   optimizer=optimizer,
                                   metrics={"loss": model.loss})

        # Validation reuses the training dataset; this is a smoke test, not a quality check.
        est.fit(data=dataset, batch_size=8, epochs=10, validation_data=dataset)

        eval_metrics = est.evaluate(dataset, batch_size=4)
        assert 'loss' in eval_metrics
 def read_as_tf(path):
     """
     Read the data at ``path`` and wrap it as an orca TF ``Dataset``.

     The data is first loaded as XShards through
     ``ParquetDataset._read_as_xshards``. Those shards are then sliced into a
     ``zoo.orca.data.tf.data.Dataset`` with ``Dataset.from_tensor_slices``.

     :param path: location handed unchanged to
         ``ParquetDataset._read_as_xshards`` (presumably a Parquet dataset
         directory -- confirm against that helper).
     :return: a ``zoo.orca.data.tf.data.Dataset`` built from the loaded shards.
     """
     from zoo.orca.data.tf.data import Dataset as TFDataset
     return TFDataset.from_tensor_slices(ParquetDataset._read_as_xshards(path))
Beispiel #3
0
    def test_estimator_graph_predict_dataset(self):
        """Run prediction through an inference-only graph Estimator on an orca Dataset.

        The Estimator is built from inputs/outputs only (no labels, loss or
        optimizer). It is fed a Dataset whose records contain just the
        ``"x"`` feature tuple (user ids, item ids) from ncf.csv, and the test
        checks the number of collected predictions.
        """
        # Import locally so the ``zoo`` name used below is bound in this scope
        # even when the test module does not import it at the top level.
        import zoo.orca.data.pandas
        tf.reset_default_graph()

        model = SimpleModel()
        file_path = os.path.join(resource_path, "orca/learn/ncf.csv")
        data_shard = zoo.orca.data.pandas.read_csv(file_path)

        est = Estimator.from_graph(inputs=[model.user, model.item],
                                   outputs=[model.logits])

        def transform(df):
            # Prediction needs features only, so no "y" entry is produced.
            result = {
                "x": (df['user'].to_numpy(), df['item'].to_numpy()),
            }
            return result

        data_shard = data_shard.transform_shard(transform)
        dataset = Dataset.from_tensor_slices(data_shard)
        predictions = est.predict(dataset).collect()
        # NOTE(review): 10 is the expected number of collected prediction
        # elements; confirm whether it counts rows or shards of ncf.csv.
        assert len(predictions) == 10