Example #1
    def predict(self, data, batch_size=32, feature_cols=None, profile=False):
        """
        Using this PyTorch model to make predictions on the data.

        :param data: An instance of SparkXShards or a Spark DataFrame
        :param batch_size: The number of samples per batch for each worker. Default is 32.
        :param profile: Boolean. Whether to return time stats for the training procedure.
               Default is False.
        :param feature_cols: feature column names if data is a Spark DataFrame.
        :return: A SparkXShards that contains the predictions with key "prediction" in each shard
        """
        from bigdl.orca.data import SparkXShards
        from pyspark.sql import DataFrame

        param = dict(batch_size=batch_size, profile=profile)
        if isinstance(data, DataFrame):
            xshards, _ = dataframe_to_xshards(data,
                                              validation_data=None,
                                              feature_cols=feature_cols,
                                              label_cols=None,
                                              mode="predict")
            pred_shards = self._predict_spark_xshards(xshards, param)
            result = convert_predict_xshards_to_dataframe(data, pred_shards)
        elif isinstance(data, SparkXShards):
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                data = process_xshards_of_pandas_dataframe(data, feature_cols)
            pred_shards = self._predict_spark_xshards(data, param)
            result = update_predict_xshards(data, pred_shards)
        else:
            raise ValueError(
                "Only xshards or Spark DataFrame is supported for predict")

        return result
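A hedged usage sketch for this predict API. The estimator est, the inputs df and xshards, and the column name "feature" are illustrative assumptions, not taken from the snippet above:

# Usage sketch (assumptions: "est" is an already-fitted Orca PyTorch Estimator,
# "df" is a Spark DataFrame whose "feature" column holds the model inputs).
pred_df = est.predict(df, batch_size=64, feature_cols=["feature"])
pred_df.select("feature", "prediction").show(5, truncate=False)

# With a SparkXShards input, the result is again a SparkXShards whose shards
# carry a "prediction" entry.
pred_shards = est.predict(xshards, batch_size=64)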
Example #2
    def test_convert_predict_xshards_to_dataframe_multi_output(self):
        # Build a DataFrame whose "feature" column is a 50-element array
        # and whose "label" column is a random 0/1 integer.
        rdd = self.sc.range(0, 100)
        df = rdd.map(lambda x: ([float(x)] * 50,
                                [int(np.random.randint(0, 2, size=()))])).toDF(
                                    ["feature", "label"])
        # Fake a multi-output model: each "prediction" is the feature split into two halves.
        pred_shards = _dataframe_to_xshards(
            df, feature_cols=["feature"]).transform_shard(
                lambda x: {"prediction": [x["x"][:, :25], x["x"][:, 25:]]})
        result_df = convert_predict_xshards_to_dataframe(df, pred_shards)
        # Flattening the two prediction halves must reproduce the original feature exactly.
        expr = "sum(cast(feature <> flatten(prediction) as int)) as error"
        assert result_df.selectExpr(expr).first()["error"] == 0
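A standalone sketch (numpy only) of the invariant this test asserts: when the multi-output "prediction" is just the feature split into two 25-column halves, flattening the two outputs reproduces the original 50-element feature row, so the error count is 0.

import numpy as np

x = np.arange(50, dtype=float).reshape(1, 50)    # one "feature" row of length 50
prediction = [x[:, :25], x[:, 25:]]              # two model outputs: the two halves
flattened = np.concatenate(prediction, axis=1)   # what flatten(prediction) yields per row
assert (flattened == x).all()                    # i.e. the test's error sum is 0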
Example #3
    def predict(self,
                data,
                batch_size=None,
                verbose=1,
                steps=None,
                callbacks=None,
                data_config=None,
                feature_cols=None):
        """
        Predict the input data

        :param data: predict input data.  It can be XShards or Spark DataFrame.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature}, where feature is a numpy array or a tuple of numpy arrays.
        :param batch_size: Batch size used for inference. Default: None.
        :param verbose: Prints output of one model if true.
        :param steps: Total number of steps (batches of samples) before declaring the prediction
               round finished. Ignored with the default value of None.
        :param callbacks: List of Keras compatible callbacks to apply during prediction.
        :param data_config: An optional dictionary that can be passed to data creator function.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
               DataFrame or an XShards of Pandas DataFrame. Default: None.
        :return:
        """
        logger.info("Starting predict step.")
        params = dict(
            verbose=verbose,
            batch_size=batch_size,
            steps=steps,
            callbacks=callbacks,
            data_config=data_config,
        )
        from bigdl.orca.data import SparkXShards
        from pyspark.sql import DataFrame

        if isinstance(data, DataFrame):
            xshards, _ = dataframe_to_xshards(data,
                                              validation_data=None,
                                              feature_cols=feature_cols,
                                              label_cols=None,
                                              mode="predict",
                                              accept_str_col=True)
            pred_shards = self._predict_spark_xshards(xshards, params)
            result = convert_predict_xshards_to_dataframe(data, pred_shards)
        elif isinstance(data, SparkXShards):
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                data = process_xshards_of_pandas_dataframe(data, feature_cols)
            pred_shards = self._predict_spark_xshards(data, params)
            result = update_predict_xshards(data, pred_shards)
        else:
            raise ValueError(
                "Only xshards or Spark DataFrame is supported for predict")

        return result
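A hedged usage sketch for this Keras-style predict. The estimator est, the DataFrame df, and the column name are assumptions for illustration, not part of the snippet:

# Usage sketch (assumptions: "est" is an Orca TF2/Keras estimator created elsewhere,
# "df" is a Spark DataFrame with an array-typed "feature" column).
pred_df = est.predict(df,
                      batch_size=32,
                      verbose=0,
                      feature_cols=["feature"])
pred_df.select("prediction").show(3, truncate=False)

Example #4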
    def predict(self,
                data,
                batch_size=None,
                verbose=1,
                steps=None,
                callbacks=None,
                data_config=None,
                feature_cols=None):
        """
        Predict the input data
        :param data: predict input data.  It can be XShards or Spark DataFrame.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature}, where feature is a numpy array or a tuple of numpy arrays.
        :param batch_size: Batch size used for inference. Default: None.
        :param verbose: Prints output of one model if true.
        :param steps: Total number of steps (batches of samples) before declaring the prediction
               round finished. Ignored with the default value of None.
        :param callbacks: List of Keras compatible callbacks to apply during prediction.
        :param data_config: An optional dictionary that can be passed to data creator function.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
               DataFrame or an XShards of Pandas DataFrame. Default: None.
        :return:
        """
        logger.info("Starting predict step.")
        sc = OrcaContext.get_spark_context()
        if self.model_weights:
            weights = sc.broadcast(self.model_weights)
        else:
            weights = None

        init_params = dict(model_creator=self.model_creator,
                           compile_args_creator=self.compile_args_creator,
                           config=self.config,
                           verbose=self.verbose,
                           size=self.num_workers,
                           model_weights=weights,
                           mode="predict",
                           cluster_info=None)

        params = dict(verbose=verbose,
                      batch_size=batch_size,
                      steps=steps,
                      callbacks=callbacks,
                      data_config=data_config)

        if isinstance(data, DataFrame):
            data = data.repartition(self.num_workers)
            xshards, _ = dataframe_to_xshards(data,
                                              validation_data=None,
                                              feature_cols=feature_cols,
                                              label_cols=None,
                                              mode="predict",
                                              accept_str_col=True)

            def transform_func(iter, init_param, param):
                # Collect this partition's records and wrap them in a data creator
                # so the runner can consume them like any other data source.
                partition_data = list(iter)
                # res = combine_in_partition(partition_data)
                param["data_creator"] = make_data_creator(partition_data)
                return SparkRunner(**init_param).predict(**param)

            pred_shards = SparkXShards(
                xshards.rdd.repartition(self.num_workers).mapPartitions(
                    lambda iter: transform_func(iter, init_params, params)))
            result = convert_predict_xshards_to_dataframe(data, pred_shards)
        else:
            raise ValueError(
                "Only xshards or Spark DataFrame is supported for predict")

        return result
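The core of this backend is the mapPartitions pattern: repartition the data to one partition per worker, wrap each partition's records in a data creator, and run a worker (SparkRunner) in-process on that partition. Below is a toy, BigDL-free sketch of the same pattern where the "runner" just doubles every value; all names are illustrative:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
rdd = spark.sparkContext.parallelize(range(8), numSlices=4)

def make_partition_data_creator(records):
    # Mirrors the idea of make_data_creator: hand the partition's rows to the worker lazily.
    def creator():
        return list(records)
    return creator

def toy_runner_predict(data_creator):
    # Stand-in for SparkRunner(**init_param).predict(**param): per-partition "inference".
    return [{"prediction": v * 2} for v in data_creator()]

num_workers = 2

def transform_func(it):
    partition_data = list(it)
    return toy_runner_predict(make_partition_data_creator(partition_data))

pred_rdd = rdd.repartition(num_workers).mapPartitions(transform_func)
print(pred_rdd.collect())   # e.g. [{'prediction': 0}, {'prediction': 2}, ...]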