def predict(self, data, batch_size=32, feature_cols=None, profile=False):
    """
    Run prediction on a Spark DataFrame or a SparkXShards.

    :param data: Input to predict on. Must be a Spark DataFrame or a
           SparkXShards; anything else raises ValueError.
    :param batch_size: Batch size forwarded to ``_predict_spark_xshards``.
           Default: 32.
    :param feature_cols: Feature column name(s). Only used when data is a
           Spark DataFrame, to convert it into XShards. Default: None.
    :param profile: Flag forwarded to ``_predict_spark_xshards`` alongside
           batch_size (presumably toggles profiling on the workers — confirm).
           Default: False.
    :return: For a Spark DataFrame input, the output of
             ``convert_predict_xshards_to_dataframe(data, predictions)``;
             for a SparkXShards input, the output of
             ``update_predict_xshards(data, predictions)``.
    :raises ValueError: if data is neither a Spark DataFrame nor a SparkXShards.
    """
    from zoo.orca.data import SparkXShards
    predict_config = {"batch_size": batch_size, "profile": profile}
    from pyspark.sql import DataFrame

    if isinstance(data, DataFrame):
        # DataFrames are first converted to XShards holding only the feature columns.
        input_shards, _ = dataframe_to_xshards(
            data,
            validation_data=None,
            feature_cols=feature_cols,
            label_cols=None,
            mode="predict",
        )
        predictions = self._predict_spark_xshards(input_shards, predict_config)
        return convert_predict_xshards_to_dataframe(data, predictions)

    if isinstance(data, SparkXShards):
        predictions = self._predict_spark_xshards(data, predict_config)
        return update_predict_xshards(data, predictions)

    raise ValueError("Only xshards or Spark DataFrame is supported for predict")
def test_update_predict_xshard_multi_output(self):
    """
    update_predict_xshards must attach a multi-output (list) prediction.

    Data shards hold ``{"x": array}`` where each row is a length-50 vector
    filled with its row index; the matching prediction shards hold the same
    rows split column-wise into two outputs (24 and 26 columns). Joining the
    two outputs back along axis 1 must reproduce ``x`` exactly.
    """
    def build_stacked_rdd():
        # 110 rows, each a length-50 vector of its own index, grouped per
        # partition by ``chunks`` (presumably into groups of at most 5 rows).
        rdd = self.sc.range(0, 110).map(lambda i: np.array([i] * 50))
        return rdd.mapPartitions(lambda it: chunks(it, 5)).map(
            lambda rows: np.stack(rows))

    def get_data_xshards(key):
        return SparkXShards(build_stacked_rdd().map(lambda arr: {key: arr}))

    def get_pred_xshards(key):
        # Two-output prediction: columns [0, 24) and [24, 50).
        return SparkXShards(build_stacked_rdd().map(
            lambda arr: {key: [arr[:, :24], arr[:, 24:]]}))

    data_shards = get_data_xshards("x")
    pred_shards = get_pred_xshards("prediction")
    result_shards = update_predict_xshards(data_shards, pred_shards)

    # Collect once: avoids recomputing the lineage and guarantees both sides
    # of the comparison come from the same materialized shards.
    collected = result_shards.collect()
    result = np.concatenate(
        [np.concatenate(shard["prediction"], axis=1) for shard in collected])
    expected_result = np.concatenate([shard["x"] for shard in collected])
    assert np.array_equal(result, expected_result)
def predict(self, data, batch_size=None, verbose=1, steps=None,
            callbacks=None, data_config=None, feature_cols=None):
    """
    Predict on the input data.

    :param data: Predict input data. Must be a Spark DataFrame or a
           SparkXShards; anything else raises ValueError.
    :param batch_size: Batch size used for inference. Default: None.
    :param verbose: Verbosity setting forwarded to the prediction workers.
           Default: 1.
    :param steps: Total number of steps (batches of samples) before declaring
           the prediction round finished. Default: None.
    :param callbacks: Callbacks forwarded to the prediction workers
           (presumably Keras-compatible — confirm). Default: None.
    :param data_config: Optional dictionary forwarded to the prediction
           workers (presumably passed to the data creator — confirm).
           Default: None.
    :param feature_cols: Feature column name(s) of data. Only used when data
           is a Spark DataFrame. Default: None.
    :return: For a Spark DataFrame input, the output of
             ``convert_predict_xshards_to_dataframe(data, predictions)``;
             for a SparkXShards input, the output of
             ``update_predict_xshards(data, predictions)``.
    :raises ValueError: if data is neither a Spark DataFrame nor a SparkXShards.
    """
    logger.info("Starting predict step.")
    params = dict(
        verbose=verbose,
        batch_size=batch_size,
        steps=steps,
        callbacks=callbacks,
        data_config=data_config,
    )
    from zoo.orca.data import SparkXShards
    from pyspark.sql import DataFrame

    if isinstance(data, DataFrame):
        # Convert the DataFrame to XShards holding only the feature columns.
        xshards, _ = dataframe_to_xshards(data,
                                          validation_data=None,
                                          feature_cols=feature_cols,
                                          label_cols=None,
                                          mode="predict")
        pred_shards = self._predict_spark_xshards(xshards, params)
        result = convert_predict_xshards_to_dataframe(data, pred_shards)
    elif isinstance(data, SparkXShards):
        pred_shards = self._predict_spark_xshards(data, params)
        result = update_predict_xshards(data, pred_shards)
    else:
        raise ValueError("Only xshards or Spark DataFrame is supported for predict")

    return result
def predict(self, data, batch_size=None, verbose=1, steps=None,
            callbacks=None, data_config=None, feature_cols=None):
    """
    Predict the input data.

    :param data: predict input data. It can be XShards or Spark DataFrame.
           If data is XShards, each partition can be a Pandas DataFrame or a
           dictionary of {'x': feature}, where feature is a numpy array or a
           tuple of numpy arrays.
    :param batch_size: Batch size used for inference. Default: None.
    :param verbose: Prints output of one model if true.
    :param steps: Total number of steps (batches of samples) before declaring
           the prediction round finished. Ignored with the default value of
           None.
    :param callbacks: List of Keras compatible callbacks to apply during
           prediction.
    :param data_config: An optional dictionary that can be passed to data
           creator function.
    :param feature_cols: Feature column name(s) of data. Only used when data
           is a Spark DataFrame or an XShards of Pandas DataFrame.
           Default: None.
    :return: If data is a Spark DataFrame, the output of
             ``convert_predict_xshards_to_dataframe(data, predictions)``
             (presumably the input DataFrame with a prediction column
             appended — confirm against that helper). If data is a
             SparkXShards, the XShards returned by
             ``update_predict_xshards``, which merges the predictions into
             each input partition. For XShards of Pandas DataFrames, the
             partitions are first converted by
             ``process_xshards_of_pandas_dataframe``, so the predictions are
             merged into the converted partitions rather than into the
             original Pandas DataFrames.
    :raises ValueError: if data is neither a Spark DataFrame nor a SparkXShards.
    """
    logger.info("Starting predict step.")
    params = dict(
        verbose=verbose,
        batch_size=batch_size,
        steps=steps,
        callbacks=callbacks,
        data_config=data_config,
    )
    from zoo.orca.data import SparkXShards
    from pyspark.sql import DataFrame

    if isinstance(data, DataFrame):
        xshards, _ = dataframe_to_xshards(data,
                                          validation_data=None,
                                          feature_cols=feature_cols,
                                          label_cols=None,
                                          mode="predict",
                                          accept_str_col=True)
        pred_shards = self._predict_spark_xshards(xshards, params)
        result = convert_predict_xshards_to_dataframe(data, pred_shards)
    elif isinstance(data, SparkXShards):
        if data._get_class_name() == 'pandas.core.frame.DataFrame':
            # NOTE(review): ``data`` is rebound here, so update_predict_xshards
            # below receives the converted shards, not the caller's Pandas
            # partitions. Confirm this is the intended output shape.
            data = process_xshards_of_pandas_dataframe(data, feature_cols)
        pred_shards = self._predict_spark_xshards(data, params)
        result = update_predict_xshards(data, pred_shards)
    else:
        raise ValueError(
            "Only xshards or Spark DataFrame is supported for predict")

    return result