Example #1
class PyTorchRayEstimator(OrcaRayEstimator):
    def __init__(self,
                 *,
                 model_creator,
                 optimizer_creator,
                 loss_creator=None,
                 metrics=None,
                 scheduler_creator=None,
                 training_operator_cls=TrainingOperator,
                 initialization_hook=None,
                 config=None,
                 scheduler_step_freq="batch",
                 use_tqdm=False,
                 backend="torch_distributed",
                 workers_per_node=1):

        if config is not None and "batch_size" in config:
            raise Exception(
                "Please do not specify batch_size in config. Input batch_size in the"
                " fit/evaluate/predict function of the estimator instead.")

        from zoo.orca.learn.pytorch.pytorch_ray_estimator import PyTorchRayEstimator
        self.estimator = PyTorchRayEstimator(
            model_creator=model_creator,
            optimizer_creator=optimizer_creator,
            loss_creator=loss_creator,
            metrics=metrics,
            scheduler_creator=scheduler_creator,
            training_operator_cls=training_operator_cls,
            initialization_hook=initialization_hook,
            config=config,
            scheduler_step_freq=scheduler_step_freq,
            use_tqdm=use_tqdm,
            backend=backend,
            workers_per_node=workers_per_node)

    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            profile=False,
            reduce_results=True,
            info=None,
            feature_cols=None,
            label_cols=None):
        """
        Trains a PyTorch model given training data for several epochs.

        Calls `TrainingOperator.train_epoch()` on N workers in parallel under the hood.

        :param data: An instance of SparkXShards, a Spark DataFrame or a function that
               takes config and batch_size as arguments and returns a PyTorch DataLoader for
               training.
        :param epochs: The number of epochs to train the model. Default is 1.
        :param batch_size: The number of samples per batch for each worker. Default is 32.
               The total batch size would be batch_size * workers_per_node * num_nodes.
               If your training data is a function, you can set batch_size to be the input
               batch_size of the function for the PyTorch DataLoader.
        :param profile: Boolean. Whether to return time stats for the training procedure.
               Default is False.
        :param reduce_results: Boolean. Whether to average all metrics across all workers into
               one dict. If a metric is a non-numerical value (or nested dictionaries), one value
               will be randomly selected among the workers. If False, returns a list of dicts for
               all workers.
               Default is True.
        :param info: An optional dictionary that can be passed to the TrainingOperator for
               train_epoch and train_batch.
        :param feature_cols: feature column names if data is a Spark DataFrame.
        :param label_cols: label column names if data is a Spark DataFrame.

        :return: A list of dictionaries of metrics for every training epoch. If reduce_results is
                False, this will return a nested list of metric dictionaries whose length will be
                equal to the total number of workers.
                You can also provide custom metrics by passing in a custom training_operator_cls
                when creating the Estimator.
        """
        return self.estimator.train(data=data,
                                    epochs=epochs,
                                    batch_size=batch_size,
                                    profile=profile,
                                    reduce_results=reduce_results,
                                    info=info,
                                    feature_cols=feature_cols,
                                    label_cols=label_cols)

    def predict(self, data, batch_size=32, feature_cols=None, profile=False):
        """
        Using this PyTorch model to make predictions on the data.

        :param data: An instance of SparkXShards or a Spark DataFrame.
        :param batch_size: The number of samples per batch for each worker. Default is 32.
        :param feature_cols: feature column names if data is a Spark DataFrame.
        :param profile: Boolean. Whether to return time stats for the prediction procedure.
               Default is False.
        :return: A SparkXShards that contains the predictions with key "prediction" in each shard.
        """
        return self.estimator.predict(data,
                                      batch_size=batch_size,
                                      feature_cols=feature_cols,
                                      profile=profile)

    def evaluate(self,
                 data,
                 batch_size=32,
                 num_steps=None,
                 profile=False,
                 info=None,
                 feature_cols=None,
                 label_cols=None):
        """
        Evaluates a PyTorch model given validation data.
        Note that only accuracy for classification with zero-based labels is supported by
        default. You can override validate_batch in TrainingOperator for other metrics.

        Calls `TrainingOperator.validate()` on N workers in parallel under the hood.

        :param data: An instance of SparkXShards, a Spark DataFrame or a function that
               takes config and batch_size as arguments and returns a PyTorch DataLoader for
               validation.
        :param batch_size: The number of samples per batch for each worker. Default is 32.
               The total batch size would be batch_size * workers_per_node * num_nodes.
               If your validation data is a function, you can set batch_size to be the input
               batch_size of the function for the PyTorch DataLoader.
        :param num_steps: The number of batches to compute the validation results on. This
               corresponds to the number of times `TrainingOperator.validate_batch` is called.
        :param profile: Boolean. Whether to return time stats for the evaluation procedure.
               Default is False.
        :param info: An optional dictionary that can be passed to the TrainingOperator
               for validate.
        :param feature_cols: feature column names if data is a Spark DataFrame.
        :param label_cols: label column names if data is a Spark DataFrame.

        :return: A dictionary of metrics for the given data, including validation accuracy and loss.
                You can also provide custom metrics by passing in a custom training_operator_cls
                when creating the Estimator.
        """
        return self.estimator.validate(data=data,
                                       batch_size=batch_size,
                                       num_steps=num_steps,
                                       profile=profile,
                                       info=info,
                                       feature_cols=feature_cols,
                                       label_cols=label_cols)

    def get_model(self):
        """
        Returns the learned PyTorch model.

        :return: The learned PyTorch model.
        """
        return self.estimator.get_model()

    def save(self, checkpoint):
        """
        Saves the Estimator state (including model and optimizer) to the provided checkpoint path.

        :param checkpoint: (str) Path to target checkpoint file.
        """
        return self.estimator.save(checkpoint=checkpoint)

    def load(self, checkpoint):
        """
        Loads the Estimator state (including model and optimizer) from the provided checkpoint.

        :param checkpoint: (str) Path to target checkpoint file.
        """
        return self.estimator.load(checkpoint=checkpoint)

    def shutdown(self, force=False):
        """
        Shuts down workers and releases resources.
        """
        return self.estimator.shutdown(force=force)
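
A minimal usage sketch for Example #1, assuming Ray and the Orca context have already been initialized. The two-layer model, the "lr" config key, and the DataFrame column names are illustrative placeholders rather than part of the estimator API:

import torch
import torch.nn as nn

def model_creator(config):
    # Called on each remote worker to build a fresh model replica.
    return nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 1))

def optimizer_creator(model, config):
    # Receives the freshly built model plus the estimator's config dict.
    return torch.optim.SGD(model.parameters(), lr=config["lr"])

est = PyTorchRayEstimator(
    model_creator=model_creator,
    optimizer_creator=optimizer_creator,
    loss_creator=lambda config: nn.MSELoss(),
    config={"lr": 0.01},  # batch_size must NOT go in config (see the guard above)
    workers_per_node=2)

# df is assumed to be an existing Spark DataFrame with "features" and "label" columns.
stats = est.fit(df, epochs=2, batch_size=32,
                feature_cols=["features"], label_cols=["label"])
predictions = est.predict(df, feature_cols=["features"])
est.save("/tmp/estimator.ckpt")
est.shutdown()
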
Example #2
class PyTorchRayEstimatorWrapper(Estimator):
    def __init__(self,
                 *,
                 model_creator,
                 optimizer_creator,
                 loss_creator=None,
                 scheduler_creator=None,
                 training_operator_cls=TrainingOperator,
                 initialization_hook=None,
                 config=None,
                 scheduler_step_freq="batch",
                 use_tqdm=False,
                 backend="pytorch",
                 workers_per_node=1):
        from zoo.orca.learn.pytorch.pytorch_ray_estimator import PyTorchRayEstimator
        self.estimator = PyTorchRayEstimator(
            model_creator=model_creator,
            optimizer_creator=optimizer_creator,
            loss_creator=loss_creator,
            scheduler_creator=scheduler_creator,
            training_operator_cls=training_operator_cls,
            initialization_hook=initialization_hook,
            config=config,
            scheduler_step_freq=scheduler_step_freq,
            use_tqdm=use_tqdm,
            backend=backend,
            workers_per_node=workers_per_node)

    def fit(self,
            data,
            epochs=1,
            profile=False,
            reduce_results=True,
            info=None):
        """

        :param data: (callable) a funtion that takes a config dict as input and return a data
            loader containing the training data.
        :param epochs: (int) Number of epochs to train the model
        :param profile: (bool) Returns time stats for the training procedure.
        :param reduce_results: (bool) Whether to average all metrics across all workers into one
            dict. If a metric is a non-numerical value (or nested dictionaries), one value will be
            randomly selected among the workers. If False, returns a list of dicts.
        :param info: (dict) Optional dictionary passed to the training operator for ``train_epoch``
            and ``train_batch``.
        :return: (list) A list of stats whose length will be equal to ``epochs``.
            Each entry is a dictionary of training metrics. You can provide custom
            metrics by passing in a custom ``training_operator_cls``. If
            ``reduce_results=False``, this will return a list of metric dictionaries
            whose length will be equal to ``num_workers``.
        """
        return self.estimator.train(data_creator=data,
                                    epochs=epochs,
                                    profile=profile,
                                    reduce_results=reduce_results,
                                    info=info)

    def predict(self, data, **kwargs):
        pass

    def evaluate(self, data, num_steps=None, profile=False, info=None):
        """

        :param data: (callable) a funtion that takes a config dict as input and return
            a data loader containing the validation data.
        :param num_steps: (int) The number of batches to compute the validation results on.
               This also corresponds to the number of times ``TrainingOperator.validate_batch``
               is called.
        :param profile: (bool) Returns time stats for the evaluation procedure.
        :param info: (dict) Optional dictionary passed to the training operator for `validate`
            and `validate_batch`.
        :return: A dictionary of metrics for validation.
            You can provide custom metrics by passing in a custom ``training_operator_cls``.
        """
        return self.estimator.validate(data_creator=data,
                                       num_steps=num_steps,
                                       profile=profile,
                                       info=info)

    def get_model(self):
        """Returns the learned model(s)."""
        return self.estimator.get_model()

    def save(self, checkpoint):
        """Saves the Estimator state to the provided checkpoint path.

        :param checkpoint: (str) Path to target checkpoint file.
        """
        return self.estimator.save(checkpoint=checkpoint)

    def load(self, checkpoint):
        """Loads the Estimator and all workers from the provided checkpoint.

        :param checkpoint: (str) Path to target checkpoint file.
        """
        return self.estimator.load(checkpoint=checkpoint)

    def shutdown(self, force=False):
        """Shuts down workers and releases resources."""
        return self.estimator.shutdown(force=force)
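
A minimal usage sketch for Example #2. Here data is a creator function rather than a DataFrame, matching the fit signature above; the synthetic tensors and the config keys are illustrative placeholders:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def model_creator(config):
    return nn.Linear(10, 1)

def optimizer_creator(model, config):
    return torch.optim.Adam(model.parameters(), lr=config["lr"])

def train_data_creator(config):
    # Called on each worker; returns the DataLoader that fit() iterates over.
    x = torch.randn(256, 10)
    y = torch.randn(256, 1)
    return DataLoader(TensorDataset(x, y), batch_size=config["batch_size"])

est = PyTorchRayEstimatorWrapper(
    model_creator=model_creator,
    optimizer_creator=optimizer_creator,
    loss_creator=lambda config: nn.MSELoss(),
    config={"lr": 1e-3, "batch_size": 32})

stats = est.fit(train_data_creator, epochs=2)  # one metrics dict per epoch
metrics = est.evaluate(train_data_creator)
model = est.get_model()
est.shutdown()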