Пример #1
0
def process_spark_xshards(spark_xshards, num_workers):
    """Convert SparkXShards into RayXShards with ``num_workers`` partitions.

    The input is repartitioned only when its current partition count differs
    from ``num_workers``; otherwise it is converted unchanged.

    :param spark_xshards: the SparkXShards to convert.
    :param num_workers: the partition count the result should have.
    :return: a RayXShards built from the (possibly repartitioned) shards.
    """
    from zoo.orca.data.shard import RayXShards
    if spark_xshards.num_partitions() == num_workers:
        balanced = spark_xshards
    else:
        balanced = spark_xshards.repartition(num_workers)
    return RayXShards.from_spark_xshards(balanced)
Пример #2
0
    def predict(self, data, batch_size=None, verbose=1,
                steps=None, callbacks=None, data_config=None,
                feature_cols=None):
        """Run distributed prediction on ``data`` using the remote workers.

        :param data: SparkXShards, or a Spark DataFrame that
               ``maybe_dataframe_to_xshards`` converts using ``feature_cols``.
        :param batch_size: forwarded to each worker's ``predict``.
        :param verbose: forwarded to each worker's ``predict``.
        :param steps: forwarded to each worker's ``predict``.
        :param callbacks: forwarded to each worker's ``predict``.
        :param data_config: forwarded to each worker's ``predict``.
        :param feature_cols: feature column names; presumably required when
               ``data`` is a Spark DataFrame (TODO confirm in
               ``maybe_dataframe_to_xshards``).
        :return: SparkXShards holding the prediction results.
        :raises ValueError: if ``data`` is not (convertible to) SparkXShards.
        """
        logger.info("Starting predict step.")
        params = dict(
            verbose=verbose,
            batch_size=batch_size,
            steps=steps,
            callbacks=callbacks,
            data_config=data_config,
        )
        from zoo.orca.data import SparkXShards
        data, _ = maybe_dataframe_to_xshards(data,
                                             validation_data=None,
                                             feature_cols=feature_cols,
                                             labels_cols=None,
                                             mode="predict")

        if not isinstance(data, SparkXShards):
            raise ValueError("Only xshards or Spark DataFrame is supported for predict")

        ray_xshards = RayXShards.from_spark_xshards(data)

        def transform_func(worker, shards_ref):
            # Build fresh kwargs per call instead of mutating the ``params``
            # dict shared by every invocation of this closure.
            worker_params = dict(params, data_creator=shards_ref_to_creator(shards_ref))
            return worker.predict.remote(**worker_params)

        pred_shards = ray_xshards.transform_shards_with_actors(self.remote_workers,
                                                               transform_func,
                                                               gang_scheduling=False)
        return pred_shards.to_spark_xshards()
Пример #3
0
    def predict(self,
                data,
                batch_size=32,
                feature_cols=None,
                profile=False):
        """Run distributed prediction on ``data`` using the remote workers.

        :param data: SparkXShards, or a Spark DataFrame that
               ``maybe_dataframe_to_xshards`` converts using ``feature_cols``.
        :param batch_size: batch size forwarded to each worker's ``predict``.
        :param feature_cols: feature column names used for the conversion.
        :param profile: forwarded to each worker's ``predict``.
        :return: SparkXShards holding the prediction results.
        :raises ValueError: if ``data`` is not (convertible to) SparkXShards.
        """
        from zoo.orca.data import SparkXShards
        data, _ = maybe_dataframe_to_xshards(data,
                                             validation_data=None,
                                             feature_cols=feature_cols,
                                             labels_cols=None,
                                             mode="predict")
        if not isinstance(data, SparkXShards):
            raise ValueError("Only xshards or Spark DataFrame is supported for predict")

        ray_xshards = RayXShards.from_spark_xshards(data)

        def transform_func(worker, shards_ref):
            # Presumably the worker calls this with its config; the config is
            # ignored and the shard reference is handed back unchanged.
            def data_creator(config):
                return shards_ref

            return worker.predict.remote(data_creator, batch_size, profile)

        pred_shards = ray_xshards.transform_shards_with_actors(self.remote_workers,
                                                               transform_func,
                                                               gang_scheduling=False)
        return pred_shards.to_spark_xshards()
Пример #4
0
    def evaluate(self,
                 data_creator,
                 verbose=1,
                 sample_weight=None,
                 steps=None,
                 callbacks=None,
                 data_config=None,
                 feature_cols=None,
                 label_cols=None):
        """Evaluates the model on the validation data set.

        :param data_creator: a Spark DataFrame (converted to SparkXShards using
               ``feature_cols`` and ``label_cols``), a SparkXShards, or a data
               creator function (which should return an Iter or DataLoader).
        :param verbose: forwarded to each worker's ``validate``.
        :param sample_weight: forwarded to each worker's ``validate``.
        :param steps: forwarded to each worker's ``validate``.
        :param callbacks: forwarded to each worker's ``validate``.
        :param data_config: forwarded to each worker's ``validate``.
        :param feature_cols: feature column names; required for DataFrame input.
        :param label_cols: label column names; required for DataFrame input.
        :return: a copy of the stats reported by the first worker.
        """
        logger.info("Starting validation step.")
        params = dict(
            verbose=verbose,
            sample_weight=sample_weight,
            steps=steps,
            callbacks=callbacks,
            data_config=data_config,
        )
        from zoo.orca.data import SparkXShards
        from pyspark.sql import DataFrame

        if isinstance(data_creator, DataFrame):
            # NOTE(review): ``assert`` is stripped under ``python -O``; kept so
            # callers that rely on AssertionError keep working.
            assert feature_cols is not None,\
                "feature_cols must be provided if data_creator is a spark dataframe"
            assert label_cols is not None,\
                "label_cols must be provided if data_creator is a spark dataframe"
            schema = data_creator.schema
            numpy_rdd = data_creator.rdd.map(lambda row: convert_row_to_numpy(
                row, schema, feature_cols, label_cols))
            shard_rdd = numpy_rdd.mapPartitions(
                lambda x: arrays2dict(x, feature_cols, label_cols))
            data_creator = SparkXShards(shard_rdd)

        if isinstance(data_creator, SparkXShards):
            data = data_creator
            # Match the partition count to the configured number of workers.
            if data.num_partitions() != self.num_workers:
                data = data.repartition(self.num_workers)

            ray_xshards = RayXShards.from_spark_xshards(data)

            def transform_func(worker, shards_ref):
                # Fresh kwargs per call rather than mutating the shared dict.
                worker_params = dict(params, data_creator=shards_ref_to_creator(shards_ref))
                return worker.validate.remote(**worker_params)

            stats_shards = ray_xshards.transform_shards_with_actors(
                self.remote_workers, transform_func, gang_scheduling=True)
            worker_stats = stats_shards.collect()

        else:  # data_creator functions; should return Iter or DataLoader
            params["data_creator"] = data_creator
            # Every worker receives the same kwargs; iterate the workers
            # directly instead of indexing a list sized by ``num_workers``.
            worker_stats = ray.get([w.validate.remote(**params)
                                    for w in self.remote_workers])
            worker_stats = list(itertools.chain.from_iterable(worker_stats))
        # NOTE(review): only the first worker's stats are returned; presumably
        # all workers report the same aggregated metrics — confirm.
        return worker_stats[0].copy()
Пример #5
0
def process_spark_xshards(spark_xshards, num_workers):
    """Repartition SparkXShards to ``num_workers`` and convert to RayXShards.

    Also computes the largest per-partition total of ``data_length`` values.
    Per the TODO below, this is needed to pad short partitions.

    :param spark_xshards: the SparkXShards to convert.
    :param num_workers: the partition count the shards should have.
    :return: tuple ``(max_length, ray_xshards)``.
    """
    if spark_xshards.num_partitions() == num_workers:
        data = spark_xshards
    else:
        data = spark_xshards.repartition(num_workers)

    # todo currently we need this information to pad the short partitions
    # so that every model run exactly the same number of steps in one epoch
    def _partition_total(lengths):
        # One total per partition: the sum of that partition's lengths.
        return [sum(lengths)]

    partition_totals = data.rdd.map(data_length).mapPartitions(_partition_total)
    max_length = partition_totals.max()
    return max_length, RayXShards.from_spark_xshards(data)
Пример #6
0
    def evaluate(self, data, batch_size=32, num_steps=None, verbose=1,
                 sample_weight=None, callbacks=None, data_config=None,
                 feature_cols=None, labels_cols=None):
        """Evaluates the model on the validation data set.

        :param data: SparkXShards, a Spark DataFrame (converted through
               ``maybe_dataframe_to_xshards`` using ``feature_cols`` and
               ``labels_cols``), or a data creator function (which should
               return an Iter or DataLoader).
        :param batch_size: forwarded to each worker's ``validate``.
        :param num_steps: forwarded to each worker's ``validate`` as ``steps``.
        :param verbose: forwarded to each worker's ``validate``.
        :param sample_weight: forwarded to each worker's ``validate``.
        :param callbacks: forwarded to each worker's ``validate``.
        :param data_config: forwarded to each worker's ``validate``.
        :param feature_cols: feature column names used for DataFrame input.
        :param labels_cols: label column names used for DataFrame input.
        :return: a copy of the stats reported by the first worker.
        """
        logger.info("Starting validation step.")
        params = dict(
            batch_size=batch_size,
            verbose=verbose,
            sample_weight=sample_weight,
            steps=num_steps,
            callbacks=callbacks,
            data_config=data_config,
        )
        from zoo.orca.data import SparkXShards

        data, _ = maybe_dataframe_to_xshards(data,
                                             validation_data=None,
                                             feature_cols=feature_cols,
                                             labels_cols=labels_cols,
                                             mode="evaluate")

        if isinstance(data, SparkXShards):
            # Match the partition count to the configured number of workers.
            if data.num_partitions() != self.num_workers:
                data = data.repartition(self.num_workers)

            ray_xshards = RayXShards.from_spark_xshards(data)

            def transform_func(worker, shards_ref):
                # Fresh kwargs per call rather than mutating the shared dict.
                worker_params = dict(params, data_creator=shards_ref_to_creator(shards_ref))
                return worker.validate.remote(**worker_params)

            stats_shards = ray_xshards.transform_shards_with_actors(self.remote_workers,
                                                                    transform_func,
                                                                    gang_scheduling=True)
            worker_stats = stats_shards.collect()

        else:  # data_creator functions; should return Iter or DataLoader
            params["data_creator"] = data
            # Every worker receives the same kwargs; iterate the workers
            # directly instead of indexing a list sized by ``num_workers``.
            worker_stats = ray.get([w.validate.remote(**params)
                                    for w in self.remote_workers])
            worker_stats = list(itertools.chain.from_iterable(worker_stats))
        # NOTE(review): only the first worker's stats are returned; presumably
        # all workers report the same aggregated metrics — confirm.
        return worker_stats[0].copy()
Пример #7
0
    def predict(self,
                data_creator,
                batch_size=None,
                verbose=1,
                steps=None,
                callbacks=None,
                data_config=None,
                feature_cols=None):
        """Run distributed prediction on ``data_creator`` using the remote workers.

        :param data_creator: SparkXShards, or a Spark DataFrame (converted to
               SparkXShards using ``feature_cols``).
        :param batch_size: forwarded to each worker's ``predict``.
        :param verbose: forwarded to each worker's ``predict``.
        :param steps: forwarded to each worker's ``predict``.
        :param callbacks: forwarded to each worker's ``predict``.
        :param data_config: forwarded to each worker's ``predict``.
        :param feature_cols: feature column names; required when
               ``data_creator`` is a Spark DataFrame.
        :return: SparkXShards holding the prediction results.
        :raises ValueError: if ``data_creator`` is neither SparkXShards nor a
               Spark DataFrame.
        """
        logger.info("Starting predict step.")
        params = dict(
            verbose=verbose,
            batch_size=batch_size,
            steps=steps,
            callbacks=callbacks,
            data_config=data_config,
        )
        from zoo.orca.data import SparkXShards
        from pyspark.sql import DataFrame
        if isinstance(data_creator, DataFrame):
            # NOTE(review): ``assert`` is stripped under ``python -O``; kept so
            # callers that rely on AssertionError keep working.
            assert feature_cols is not None,\
                "feature_cols must be provided if data_creator is a spark dataframe"
            schema = data_creator.schema
            numpy_rdd = data_creator.rdd.map(lambda row: convert_row_to_numpy(
                row, schema, feature_cols, None))
            shard_rdd = numpy_rdd.mapPartitions(
                lambda x: arrays2dict(x, feature_cols, None))
            data_creator = SparkXShards(shard_rdd)

        if not isinstance(data_creator, SparkXShards):
            raise ValueError("Only xshards or Spark DataFrame is supported for predict")

        ray_xshards = RayXShards.from_spark_xshards(data_creator)

        def transform_func(worker, shards_ref):
            # Fresh kwargs per call rather than mutating the shared dict.
            worker_params = dict(params, data_creator=shards_ref_to_creator(shards_ref))
            return worker.predict.remote(**worker_params)

        pred_shards = ray_xshards.transform_shards_with_actors(
            self.remote_workers, transform_func, gang_scheduling=False)
        return pred_shards.to_spark_xshards()