Example #1
0
def test_train(ray_start_2_cpus, num_replicas):  # noqa: F811
    """Smoke-test two rounds of TFTrainer training, each followed by validation.

    The stats returned by ``validate()`` are merged into the dict returned by
    ``train()`` for each round. The trainer is always shut down afterwards so
    its replicas do not outlive the test, even if a step raises.
    """
    trainer = TFTrainer(model_creator=simple_model,
                        data_creator=simple_dataset,
                        num_replicas=num_replicas,
                        config=SIMPLE_CONFIG)
    try:
        train_stats1 = trainer.train()
        train_stats1.update(trainer.validate())

        train_stats2 = trainer.train()
        train_stats2.update(trainer.validate())
    finally:
        # Release the trainer even when training/validation fails, so a
        # failing case does not leak replicas into the following tests.
        trainer.shutdown()
Example #2
0
class TFEstimator(EstimatorInterface, SparkEstimatorInterface):
    """Scikit-learn style estimator that trains a Keras model with TFTrainer.

    The model architecture, optimizer, loss and metrics are serialized in
    ``__init__`` and rebuilt inside every TFTrainer replica during ``fit``.
    """

    def __init__(self,
                 num_workers: int = 1,
                 model: keras.Model = None,
                 optimizer: Union[keras.optimizers.Optimizer, str] = None,
                 loss: Union[keras.losses.Loss, str] = None,
                 metrics: Union[List[keras.metrics.Metric], List[str]] = None,
                 feature_columns: Union[str, List[str]] = None,
                 feature_types: Optional[Union[DType, List[DType]]] = None,
                 feature_shapes: Optional[Union[TensorShape,
                                                List[TensorShape]]] = None,
                 label_column: str = None,
                 label_type: Optional[tf.DType] = None,
                 label_shape: Optional[tf.TensorShape] = None,
                 batch_size: int = 128,
                 num_epochs: int = 1,
                 shuffle: bool = True,
                 **extra_config):
        """A scikit-learn like API to distributed training Tensorflow Keras model.

        Under the hood it uses ray.sgd's TFTrainer.
        :param num_workers: the number of workers for distributed model training
        :param model: the model, it should be instance of tensorflow.keras.Model. We do not support
                      multiple output models. Only its architecture is serialized (via to_json);
                      NOTE(review): weights already set on the model are presumably not carried
                      over to the replicas — confirm.
        :param optimizer: the optimizer, it should be keras.optimizers.Optimizer instance or str.
                          We do not support multiple optimizers currently.
        :param loss: the loss, it should be keras.losses.Loss instance or str. We do not support
                     multiple losses.
        :param metrics: the metrics list. It could be None, or a list whose elements are
                        keras.metrics.Metric instances and/or str.
        :param feature_columns: the feature columns name.
               The inputs of the model will match the feature columns.
               .. code-block:: python
                   feature_columns = ["x", "y", "z"]
                   # the input to the model will be (x_batch_tensor, y_batch_tensor, z_batch_tensor)
        :param feature_types: the type for each feature input. It must match the length of the
                              feature_columns if provided. It will be tf.float32 by default.
        :param feature_shapes: the shape for each feature input. It must match the length of the
                               feature_columns
        :param label_column: the label column name.
        :param label_type: the label type, it will be tf.float32 by default.
        :param label_shape: the label shape.
        :param batch_size: the batch size
        :param num_epochs: the number of epochs
        :param shuffle: whether the input dataset should be shuffled, True by default.
        :param extra_config: extra keyword arguments passed through to TFTrainer. The get_shard
               options are read from the trainer ``config`` entry, e.g.
               config={"get_shard": {"batch_ms": 0, "num_async": 5,
                                     "shuffle_buffer_size": 2, "seed": 0}}.
               You can refer to the MLDataset.get_repeatable_shard for the parameters.
               "batch_size" and "shuffle" in that config are always overridden by this
               estimator's batch_size and shuffle (and "shuffle" also overrides any
               "shuffle" inside "get_shard").
        :raises AssertionError: if model, optimizer or loss is None, or metrics is not a list.
        :raises TypeError: if model, optimizer, loss or a metric has an unsupported type.
        """
        self._num_workers: int = num_workers

        # model: shipped to the replicas as its JSON description.
        assert model is not None, "model must not be None"
        if not isinstance(model, keras.Model):
            raise TypeError(
                "Unsupported parameter, we only support tensorflow.keras.Model"
            )
        self._serialized_model = model.to_json()

        # optimizer
        # TODO: we should support multiple optimizers for multiple outputs model
        assert optimizer is not None, "optimizer must not be None"
        if isinstance(optimizer, str):
            # a str names the optimizer; it is resolved by keras.optimizers.get in fit()
            self._serialized_optimizer = optimizer
        elif isinstance(optimizer, keras.optimizers.Optimizer):
            self._serialized_optimizer = keras.optimizers.serialize(optimizer)
        else:
            raise TypeError(
                "Unsupported parameter, we only support keras.optimizers.Optimizer subclass "
                "instance or a str to represent the optimizer")

        # loss
        # TODO: we should support multiple losses for multiple outputs model
        assert loss is not None, "loss must not be None"
        if isinstance(loss, str):
            self._serialized_loss = loss
        elif isinstance(loss, keras.losses.Loss):
            self._serialized_loss = keras.losses.serialize(loss)
        else:
            raise TypeError(
                "Unsupported parameter, we only support keras.losses.Loss subclass "
                "instance or a str to represent the loss")

        # metrics: None stays None; otherwise every element is checked and
        # serialized individually (previously only metrics[0] was inspected,
        # and an empty list raised IndexError).
        if metrics is None:
            serialized_metrics = None
        else:
            assert isinstance(metrics, list), "metrics must be a list"
            serialized_metrics = []
            for metric in metrics:
                if isinstance(metric, str):
                    serialized_metrics.append(metric)
                elif isinstance(metric, keras.metrics.Metric):
                    serialized_metrics.append(keras.metrics.serialize(metric))
                else:
                    raise TypeError(
                        "Unsupported parameter, we only support a list of "
                        "keras.metrics.Metric instances or a list of str to "
                        "represent the metrics")
        self._serialized_metrics = serialized_metrics

        self._feature_columns = feature_columns
        self._feature_types = feature_types
        self._feature_shapes = feature_shapes
        self._label_column = label_column
        self._label_type = label_type
        self._label_shape = label_shape
        self._batch_size = batch_size

        # The estimator's batch_size/shuffle take precedence over any user
        # supplied trainer "config". Build a fresh dict rather than updating
        # the caller's "config" dict in place.
        estimator_config = {"batch_size": self._batch_size, "shuffle": shuffle}
        self._extra_config = dict(extra_config)
        user_config = self._extra_config.get("config") or {}
        self._extra_config["config"] = {**user_config, **estimator_config}

        self._num_epochs: int = num_epochs

        self._trainer: TFTrainer = None

    def _create_tf_ds(self, ds: MLDataset) -> TFMLDataset:
        """Convert ``ds`` to a TFMLDataset using the configured feature/label layout."""
        return ds.to_tf(self._feature_columns, self._feature_shapes,
                        self._feature_types, self._label_column,
                        self._label_shape, self._label_type)

    def fit(self,
            train_ds: MLDataset,
            evaluate_ds: Optional[MLDataset] = None) -> None:
        """Train the model on ``train_ds`` for ``num_epochs`` epochs.

        The stats of every epoch are printed. When ``evaluate_ds`` is given,
        the model is validated once after training and the result is printed.
        Calling ``fit`` again shuts down the trainer of the previous call and
        trains a freshly built model.

        :param train_ds: the training MLDataset.
        :param evaluate_ds: optional MLDataset used for validation.
        """
        super(TFEstimator, self).fit(train_ds, evaluate_ds)

        # Bind the serialized pieces to locals so the creator closures below
        # do not capture ``self`` (and with it any previously created trainer).
        serialized_model = self._serialized_model
        serialized_optimizer = self._serialized_optimizer
        serialized_loss = self._serialized_loss
        serialized_metrics = self._serialized_metrics

        def model_creator(config):
            # https://github.com/ray-project/ray/issues/5914
            import tensorflow.keras as keras

            model: keras.Model = keras.models.model_from_json(serialized_model)
            optimizer = keras.optimizers.get(serialized_optimizer)
            loss = keras.losses.get(serialized_loss)
            # metrics is optional: iterating over None used to raise TypeError.
            metrics = None
            if serialized_metrics is not None:
                metrics = [keras.metrics.get(m) for m in serialized_metrics]
            model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
            return model

        train_ds = train_ds.batch(self._batch_size)
        train_tf_ds = self._create_tf_ds(train_ds)

        if evaluate_ds is not None:
            evaluate_ds = evaluate_ds.batch(self._batch_size)
            evaluate_tf_ds = self._create_tf_ds(evaluate_ds)
        else:
            evaluate_tf_ds = None

        def data_creator(config):
            # The replica's index comes from TF_CONFIG when it is set; -1 otherwise.
            if "TF_CONFIG" in os.environ:
                tf_config = json.loads(os.environ["TF_CONFIG"])
                world_rank = tf_config["task"]["index"]
            else:
                world_rank = -1
            batch_size = config["batch_size"]
            # Copy so the trainer config's nested "get_shard" dict is not mutated.
            get_shard_config = dict(config.get("get_shard", {}))
            if "shuffle" in config:
                get_shard_config["shuffle"] = config["shuffle"]
            train_data = train_tf_ds.get_shard(
                world_rank, **get_shard_config).repeat().batch(batch_size)
            evaluate_data = None
            if evaluate_tf_ds is not None:
                evaluate_data = evaluate_tf_ds.get_shard(
                    world_rank, **get_shard_config).batch(batch_size)
            return train_data, evaluate_data

        if self._trainer is not None:
            # A previous fit() left a trainer running; release it before it
            # is replaced so its replicas are not leaked.
            self._trainer.shutdown()
            self._trainer = None

        self._trainer = TFTrainer(model_creator=model_creator,
                                  data_creator=data_creator,
                                  num_replicas=self._num_workers,
                                  **self._extra_config)
        for i in range(self._num_epochs):
            stats = self._trainer.train()
            print(f"Epoch-{i}: {stats}")

        if evaluate_tf_ds is not None:
            print(self._trainer.validate())

    def fit_on_spark(self,
                     train_df: DF,
                     evaluate_df: OPTIONAL_DF = None,
                     fs_directory: Optional[str] = None,
                     compression: Optional[str] = None) -> None:
        """Convert Spark DataFrames to MLDatasets and delegate to ``fit``.

        :param train_df: the training DataFrame.
        :param evaluate_df: optional DataFrame used for validation.
        :param fs_directory: forwarded to create_ml_dataset_from_spark.
        :param compression: forwarded to create_ml_dataset_from_spark.
        """
        super(TFEstimator, self).fit_on_spark(train_df, evaluate_df)
        train_df = self._check_and_convert(train_df)
        if evaluate_df is not None:
            evaluate_df = self._check_and_convert(evaluate_df)
        train_ds = create_ml_dataset_from_spark(train_df, self._num_workers,
                                                self._batch_size, fs_directory,
                                                compression)
        evaluate_ds = None
        if evaluate_df is not None:
            evaluate_ds = create_ml_dataset_from_spark(evaluate_df,
                                                       self._num_workers,
                                                       self._batch_size,
                                                       fs_directory,
                                                       compression)
        return self.fit(train_ds, evaluate_ds)

    def get_model(self) -> Any:
        """Return the trained model from the underlying TFTrainer."""
        assert self._trainer is not None, "Trainer has not been created"
        return self._trainer.get_model()

    def save(self, file_path) -> None:
        """Save the trainer state to ``file_path`` via TFTrainer.save."""
        assert self._trainer is not None, "Trainer has not been created"
        self._trainer.save(file_path)

    def restore(self, file_path) -> None:
        """Restore the trainer state from ``file_path`` via TFTrainer.restore."""
        assert self._trainer is not None, "Trainer has not been created"
        self._trainer.restore(file_path)

    def shutdown(self) -> None:
        """Shut down the underlying trainer, if any. Safe to call repeatedly."""
        if self._trainer is not None:
            self._trainer.shutdown()
            # Reset instead of ``del`` so later calls see None rather than
            # raising AttributeError.
            self._trainer = None