Ejemplo n.º 1
0
    def predict(self,
                data,
                batch_size=32,
                feature_cols=None,
                profile=False):
        """
        Run prediction on a Spark DataFrame or a SparkXShards.

        :param data: input to predict on; must be a Spark DataFrame or a SparkXShards.
        :param batch_size: batch size forwarded to the prediction step. Default: 32.
        :param feature_cols: feature column name(s); only consulted when data is a
               Spark DataFrame. Default: None.
        :param profile: profiling flag forwarded to the prediction step. Default: False.
        :return: for a Spark DataFrame, the output of convert_predict_xshards_to_dataframe;
                 for a SparkXShards, the output of update_predict_xshards.
        :raises ValueError: if data is neither a Spark DataFrame nor a SparkXShards.
        """
        from zoo.orca.data import SparkXShards
        from pyspark.sql import DataFrame

        # Options handed unchanged to the per-shard prediction routine.
        predict_options = {"batch_size": batch_size, "profile": profile}

        if isinstance(data, DataFrame):
            # Convert the DataFrame into feature-only shards (no labels, no validation set).
            feature_shards, _ = dataframe_to_xshards(data,
                                                     validation_data=None,
                                                     feature_cols=feature_cols,
                                                     label_cols=None,
                                                     mode="predict")
            predictions = self._predict_spark_xshards(feature_shards, predict_options)
            return convert_predict_xshards_to_dataframe(data, predictions)

        if isinstance(data, SparkXShards):
            predictions = self._predict_spark_xshards(data, predict_options)
            return update_predict_xshards(data, predictions)

        raise ValueError("Only xshards or Spark DataFrame is supported for predict")
Ejemplo n.º 2
0
    def test_update_predict_xshard_multi_output(self):
        """
        Check update_predict_xshards with predictions that consist of two outputs.

        Data shards hold stacked arrays under "x"; prediction shards hold the same
        values split column-wise into a two-element list under "prediction".
        After the update, re-joining the two prediction outputs along axis 1 must
        reproduce the "x" arrays of the resulting shards.
        """
        def build_stacked_rdd(key):
            # 110 rows; each row is a length-50 array filled with the row's own index.
            rows = self.sc.range(0, 110).map(lambda value: np.array([value] * 50))
            # chunks(..., 5) presumably groups each partition into batches of 5 rows
            # -- helper is defined outside this block.
            grouped = rows.mapPartitions(lambda part: chunks(part, 5))
            return grouped.map(lambda group: {key: np.stack(group)})

        def get_data_xshards(key):
            return SparkXShards(build_stacked_rdd(key))

        def get_pred_xshards(key):
            # Split every stacked block into two outputs: the first 24 columns and the rest.
            two_outputs = build_stacked_rdd(key).map(
                lambda shard: {key: [shard[key][:, :24], shard[key][:, 24:]]})
            return SparkXShards(two_outputs)

        data_shards = get_data_xshards("x")
        pred_shards = get_pred_xshards("prediction")
        result_shards = update_predict_xshards(data_shards, pred_shards)

        # Glue the two outputs of each shard back together, then stack all shards.
        rejoined = [np.concatenate(shard["prediction"], axis=1)
                    for shard in result_shards.collect()]
        result = np.concatenate(rejoined)

        originals = [shard["x"] for shard in result_shards.collect()]
        expected_result = np.concatenate(originals)

        assert np.array_equal(result, expected_result)
Ejemplo n.º 3
0
    def predict(self, data, batch_size=None, verbose=1,
                steps=None, callbacks=None, data_config=None,
                feature_cols=None):
        """
        Run prediction on a Spark DataFrame or a SparkXShards.

        :param data: input to predict on; must be a Spark DataFrame or a SparkXShards.
        :param batch_size: batch size forwarded to the prediction step. Default: None.
        :param verbose: verbosity value forwarded to the prediction step. Default: 1.
        :param steps: steps value forwarded to the prediction step. Default: None.
        :param callbacks: callbacks forwarded to the prediction step. Default: None.
        :param data_config: data configuration forwarded to the prediction step.
               Default: None.
        :param feature_cols: feature column name(s); only consulted when data is a
               Spark DataFrame. Default: None.
        :return: for a Spark DataFrame, the output of convert_predict_xshards_to_dataframe;
                 for a SparkXShards, the output of update_predict_xshards.
        :raises ValueError: if data is neither a Spark DataFrame nor a SparkXShards.
        """
        logger.info("Starting predict step.")

        from zoo.orca.data import SparkXShards
        from pyspark.sql import DataFrame

        # Options handed unchanged to the per-shard prediction routine.
        predict_options = {
            "verbose": verbose,
            "batch_size": batch_size,
            "steps": steps,
            "callbacks": callbacks,
            "data_config": data_config,
        }

        if isinstance(data, DataFrame):
            # Convert the DataFrame into feature-only shards (no labels, no validation set).
            feature_shards, _ = dataframe_to_xshards(data,
                                                     validation_data=None,
                                                     feature_cols=feature_cols,
                                                     label_cols=None,
                                                     mode="predict")
            predictions = self._predict_spark_xshards(feature_shards, predict_options)
            return convert_predict_xshards_to_dataframe(data, predictions)

        if isinstance(data, SparkXShards):
            predictions = self._predict_spark_xshards(data, predict_options)
            return update_predict_xshards(data, predictions)

        raise ValueError("Only xshards or Spark DataFrame is supported for predict")
Ejemplo n.º 4
0
    def predict(self,
                data,
                batch_size=None,
                verbose=1,
                steps=None,
                callbacks=None,
                data_config=None,
                feature_cols=None):
        """
        Run prediction on the input data.

        :param data: predict input data. It can be XShards or Spark DataFrame.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature}, where feature is a numpy array or a tuple of numpy arrays.
        :param batch_size: Batch size used for inference. Default: None.
        :param verbose: Verbosity value forwarded to the prediction step. Default: 1.
        :param steps: Total number of steps (batches of samples) before declaring the prediction
               round finished. Ignored with the default value of None.
        :param callbacks: List of Keras compatible callbacks to apply during prediction.
        :param data_config: An optional dictionary that can be passed to data creator function.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
               DataFrame or an XShards of Pandas DataFrame. Default: None.
        :return: for a Spark DataFrame, the output of convert_predict_xshards_to_dataframe;
                 for XShards, the output of update_predict_xshards.
        :raises ValueError: if data is neither a Spark DataFrame nor a SparkXShards.
        """
        logger.info("Starting predict step.")

        from zoo.orca.data import SparkXShards
        from pyspark.sql import DataFrame

        # Options handed unchanged to the per-shard prediction routine.
        predict_options = {
            "verbose": verbose,
            "batch_size": batch_size,
            "steps": steps,
            "callbacks": callbacks,
            "data_config": data_config,
        }

        if isinstance(data, DataFrame):
            # Convert the DataFrame into feature-only shards (no labels, no validation set).
            feature_shards, _ = dataframe_to_xshards(data,
                                                     validation_data=None,
                                                     feature_cols=feature_cols,
                                                     label_cols=None,
                                                     mode="predict",
                                                     accept_str_col=True)
            predictions = self._predict_spark_xshards(feature_shards, predict_options)
            return convert_predict_xshards_to_dataframe(data, predictions)

        if isinstance(data, SparkXShards):
            shards = data
            if shards._get_class_name() == 'pandas.core.frame.DataFrame':
                # NOTE(review): the converted shards, not the caller's original pandas
                # shards, are what predictions get merged into below -- confirm intended.
                shards = process_xshards_of_pandas_dataframe(shards, feature_cols)
            predictions = self._predict_spark_xshards(shards, predict_options)
            return update_predict_xshards(shards, predictions)

        raise ValueError(
            "Only xshards or Spark DataFrame is supported for predict")