Пример #1
0
    def evaluate(self,
                 data_creator,
                 verbose=1,
                 sample_weight=None,
                 steps=None,
                 callbacks=None,
                 data_config=None,
                 feature_cols=None,
                 label_cols=None):
        """Run validation on the remote workers and return the stats.

        :param data_creator: a Spark DataFrame, a SparkXShards, or any other
               object, which is handed to every worker unchanged as its
               ``data_creator`` (expected to be a function returning an
               iterator or DataLoader).
        :param verbose: forwarded to each worker's ``validate`` call.
        :param sample_weight: forwarded to each worker's ``validate`` call.
        :param steps: forwarded to each worker's ``validate`` call.
        :param callbacks: forwarded to each worker's ``validate`` call.
        :param data_config: forwarded to each worker's ``validate`` call.
        :param feature_cols: feature column names; required when
               ``data_creator`` is a Spark DataFrame.
        :param label_cols: label column names; required when
               ``data_creator`` is a Spark DataFrame.
        :return: a copy of the first entry of the collected worker stats.
        """
        logger.info("Starting validation step.")
        params = dict(
            verbose=verbose,
            sample_weight=sample_weight,
            steps=steps,
            callbacks=callbacks,
            data_config=data_config,
        )
        from zoo.orca.data import SparkXShards
        from pyspark.sql import DataFrame

        if isinstance(data_creator, DataFrame):
            # A DataFrame is turned into SparkXShards first so that it can
            # share the sharded code path below.
            assert feature_cols is not None,\
                "feature_cols must be provided if data_creator is a spark dataframe"
            assert label_cols is not None,\
                "label_cols must be provided if data_creator is a spark dataframe"
            schema = data_creator.schema
            numpy_rdd = data_creator.rdd.map(lambda row: convert_row_to_numpy(
                row, schema, feature_cols, label_cols))
            shard_rdd = numpy_rdd.mapPartitions(
                lambda x: arrays2dict(x, feature_cols, label_cols))
            data_creator = SparkXShards(shard_rdd)

        if isinstance(data_creator, SparkXShards):
            data = data_creator
            # Match the number of partitions to the number of workers.
            if data.num_partitions() != self.num_workers:
                data = data.repartition(self.num_workers)

            ray_xshards = RayXShards.from_spark_xshards(data)

            def transform_func(worker, shards_ref):
                # Pass the per-worker creator as a keyword instead of writing
                # it into the shared ``params`` dict.
                return worker.validate.remote(
                    data_creator=shards_ref_to_creator(shards_ref), **params)

            stats_shards = ray_xshards.transform_shards_with_actors(
                self.remote_workers, transform_func, gang_scheduling=True)
            worker_stats = stats_shards.collect()

        else:  # data_creator functions; should return Iter or DataLoader
            params["data_creator"] = data_creator

            # Every worker receives the same keyword arguments.
            worker_stats = ray.get([
                w.validate.remote(**params) for w in self.remote_workers
            ])
            # Each worker returns an iterable of stats; flatten them.
            worker_stats = list(itertools.chain.from_iterable(worker_stats))
        stats = worker_stats[0].copy()
        return stats
Пример #2
0
    def predict(self,
                data_creator,
                batch_size=None,
                verbose=1,
                steps=None,
                callbacks=None,
                data_config=None,
                feature_cols=None):
        """Run prediction on the remote workers.

        :param data_creator: a Spark DataFrame or a SparkXShards. A DataFrame
               is converted to SparkXShards using ``feature_cols``.
        :param batch_size: forwarded to each worker's ``predict`` call.
        :param verbose: forwarded to each worker's ``predict`` call.
        :param steps: forwarded to each worker's ``predict`` call.
        :param callbacks: forwarded to each worker's ``predict`` call.
        :param data_config: forwarded to each worker's ``predict`` call.
        :param feature_cols: feature column names; required when
               ``data_creator`` is a Spark DataFrame.
        :return: the worker results converted back with
                 ``to_spark_xshards()``.
        :raises ValueError: if ``data_creator`` is neither a Spark DataFrame
                nor a SparkXShards.
        """
        logger.info("Starting predict step.")
        params = dict(
            verbose=verbose,
            batch_size=batch_size,
            steps=steps,
            callbacks=callbacks,
            data_config=data_config,
        )
        from zoo.orca.data import SparkXShards
        from pyspark.sql import DataFrame

        if isinstance(data_creator, DataFrame):
            assert feature_cols is not None,\
                "feature_cols must be provided if data_creator is a spark dataframe"
            schema = data_creator.schema
            # No label columns are passed for prediction.
            numpy_rdd = data_creator.rdd.map(lambda row: convert_row_to_numpy(
                row, schema, feature_cols, None))
            shard_rdd = numpy_rdd.mapPartitions(
                lambda x: arrays2dict(x, feature_cols, None))
            data_creator = SparkXShards(shard_rdd)

        if not isinstance(data_creator, SparkXShards):
            raise ValueError("Only xshards is supported for predict")

        ray_xshards = RayXShards.from_spark_xshards(data_creator)

        def transform_func(worker, shards_ref):
            # Pass the per-worker creator as a keyword instead of writing it
            # into the shared ``params`` dict.
            return worker.predict.remote(
                data_creator=shards_ref_to_creator(shards_ref), **params)

        stats_shards = ray_xshards.transform_shards_with_actors(
            self.remote_workers, transform_func, gang_scheduling=False)
        return stats_shards.to_spark_xshards()
Пример #3
0
    def fit(
        self,
        data_creator,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_data_creator=None,
        class_weight=None,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        data_config=None,
        feature_cols=None,
        label_cols=None,
    ):
        """Run training on the remote workers and return the stats.

        :param data_creator: a Spark DataFrame, a SparkXShards, or any other
               object, which is handed to every worker unchanged as its
               ``data_creator``.
        :param epochs: forwarded to each worker's ``step`` call.
        :param verbose: forwarded to each worker's ``step`` call.
        :param callbacks: forwarded to each worker's ``step`` call.
        :param validation_data_creator: optional validation data. When the
               training data is a DataFrame or SparkXShards, this may be a
               SparkXShards or a Spark DataFrame (converted with the same
               ``feature_cols`` / ``label_cols``).
        :param class_weight: forwarded to each worker's ``step`` call.
        :param steps_per_epoch: forwarded to each worker's ``step`` call.
        :param validation_steps: forwarded to each worker's ``step`` call.
        :param validation_freq: forwarded to each worker's ``step`` call.
        :param data_config: forwarded to each worker's ``step`` call.
        :param feature_cols: feature column names; required when a Spark
               DataFrame is supplied.
        :param label_cols: label column names; required when a Spark
               DataFrame is supplied.
        :return: a copy of the first entry of the collected worker stats.
        """
        params = dict(epochs=epochs,
                      verbose=verbose,
                      callbacks=callbacks,
                      class_weight=class_weight,
                      steps_per_epoch=steps_per_epoch,
                      validation_steps=validation_steps,
                      validation_freq=validation_freq,
                      data_config=data_config)

        from zoo.orca.data import SparkXShards
        from pyspark.sql import DataFrame

        def dataframe_to_xshards(df, arg_name):
            # Convert a Spark DataFrame into SparkXShards of numpy data.
            assert feature_cols is not None,\
                "feature_cols must be provided if {} is a spark dataframe"\
                .format(arg_name)
            assert label_cols is not None,\
                "label_cols must be provided if {} is a spark dataframe"\
                .format(arg_name)
            schema = df.schema
            numpy_rdd = df.rdd.map(lambda row: convert_row_to_numpy(
                row, schema, feature_cols, label_cols))
            shard_rdd = numpy_rdd.mapPartitions(
                lambda x: arrays2dict(x, feature_cols, label_cols))
            return SparkXShards(shard_rdd)

        if isinstance(data_creator, DataFrame):
            data_creator = dataframe_to_xshards(data_creator, "data_creator")

        if isinstance(data_creator, SparkXShards):
            # Only the RayXShards half of the returned pair is used here.
            _, ray_xshards = process_spark_xshards(
                data_creator, self.num_workers)

            if validation_data_creator is None:

                def transform_func(worker, shards_ref):
                    # Pass the per-worker creator as a keyword instead of
                    # writing it into the shared ``params`` dict.
                    return worker.step.remote(
                        data_creator=shards_ref_to_creator(shards_ref),
                        **params)

                stats_shards = ray_xshards.transform_shards_with_actors(
                    self.remote_workers, transform_func, gang_scheduling=True)
            else:
                # NOTE(review): process_spark_xshards is not shown here and
                # presumably expects SparkXShards, so a validation DataFrame
                # is converted the same way as the training DataFrame.
                if isinstance(validation_data_creator, DataFrame):
                    validation_data_creator = dataframe_to_xshards(
                        validation_data_creator, "validation_data_creator")
                _, val_ray_xshards = process_spark_xshards(
                    validation_data_creator, self.num_workers)

                def zip_func(worker, this_shards_ref, that_shards_ref):
                    return worker.step.remote(
                        data_creator=shards_ref_to_creator(this_shards_ref),
                        validation_data_creator=shards_ref_to_creator(
                            that_shards_ref),
                        **params)

                stats_shards = ray_xshards.zip_shards_with_actors(
                    val_ray_xshards,
                    self.remote_workers,
                    zip_func,
                    gang_scheduling=True)
            worker_stats = stats_shards.collect()
        else:
            params["data_creator"] = data_creator
            params["validation_data_creator"] = validation_data_creator

            # Every worker receives the same keyword arguments.
            worker_stats = ray.get([
                self.remote_workers[i].step.remote(**params)
                for i in range(self.num_workers)
            ])
            # Each worker returns an iterable of stats; flatten them.
            worker_stats = list(itertools.chain.from_iterable(worker_stats))
        stats = worker_stats[0].copy()
        return stats