def test_renamed_kwargs():
    kwargs, c = {"a": 1}, 2
    val = renamed_kwargs("b", new_name="c", value=c, kwargs=kwargs)
    assert val == 2

    kwargs, c = {"a": 1, "c": 2}, 2
    val = renamed_kwargs("b", new_name="c", value=c, kwargs=kwargs)
    assert val == 2

    with pytest.warns(DeprecationWarning):
        kwargs, c = {"a": 1, "b": 3}, 2
        val = renamed_kwargs("b", new_name="c", value=c, kwargs=kwargs)
        assert val == 3
        assert kwargs == {"a": 1, "b": 3, "c": 3}
예제 #2
0
    def __init__(
            self,
            endpoint_name,
            sagemaker_session=None,
            serializer=IdentitySerializer(),
            deserializer=BytesDeserializer(),
            **kwargs,
    ):
        """Initialize a ``Predictor``.

        Behavior for serialization of input data and deserialization of
        result data can be configured through initializer arguments. If not
        specified, a sequence of bytes is expected and the API sends it in the
        request body without modifications. In response, the API returns the
        sequence of bytes from the prediction result without any modifications.

        Args:
            endpoint_name (str): Name of the Amazon SageMaker endpoint to which
                requests are sent.
            sagemaker_session (sagemaker.session.Session): A SageMaker Session
                object, used for SageMaker interactions (default: None). If not
                specified, one is created using the default AWS configuration
                chain.
            serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
                serializer object, used to encode data for an inference endpoint
                (default: :class:`~sagemaker.serializers.IdentitySerializer`).
            deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
                deserializer object, used to decode data from an inference
                endpoint (default: :class:`~sagemaker.deserializers.BytesDeserializer`).
        """
        removed_kwargs("content_type", kwargs)
        removed_kwargs("accept", kwargs)
        endpoint_name = renamed_kwargs("endpoint", "endpoint_name",
                                       endpoint_name, kwargs)
        self.endpoint_name = endpoint_name
        self.sagemaker_session = sagemaker_session or Session()
        self.serializer = serializer
        self.deserializer = deserializer
        self._endpoint_config_name = self._get_endpoint_config_name()
        self._model_names = self._get_model_names()
        self._context = None
예제 #3
0
    def __init__(
        self,
        py_version,
        entry_point,
        transformers_version=None,
        tensorflow_version=None,
        pytorch_version=None,
        source_dir=None,
        hyperparameters=None,
        image_uri=None,
        distribution=None,
        **kwargs
    ):
        """This ``Estimator`` executes a HuggingFace script in a managed execution environment.

        The managed HuggingFace environment is an Amazon-built Docker container that executes
        functions defined in the supplied ``entry_point`` Python script within a SageMaker
        Training Job.

        Training is started by calling
        :meth:`~sagemaker.amazon.estimator.Framework.fit` on this Estimator.

        Args:
            py_version (str): Python version you want to use for executing your model training
                code. Defaults to ``None``. Required unless ``image_uri`` is provided.  List
                of supported versions:
                https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators
            entry_point (str): Path (absolute or relative) to the Python source
                file which should be executed as the entry point to training.
                If ``source_dir`` is specified, then ``entry_point``
                must point to a file located at the root of ``source_dir``.
            transformers_version (str): Transformers version you want to use for
                executing your model training code. Defaults to ``None``. Required unless
                ``image_uri`` is provided. List of supported versions:
                https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
            tensorflow_version (str): TensorFlow version you want to use for
                executing your model training code. Defaults to ``None``. Required unless
                ``pytorch_version`` is provided. List of supported versions:
                https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
            pytorch_version (str): PyTorch version you want to use for
                executing your model training code. Defaults to ``None``. Required unless
                ``tensorflow_version`` is provided. List of supported versions:
                https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
                with any other training source code dependencies aside from the entry
                point file (default: None). If ``source_dir`` is an S3 URI, it must
                point to a tar.gz file. Structure within this directory are preserved
                when training on Amazon SageMaker.
            hyperparameters (dict): Hyperparameters that will be used for
                training (default: None). The hyperparameters are made
                accessible as a dict[str, str] to the training code on
                SageMaker. For convenience, this accepts other types for keys
                and values, but ``str()`` will be called to convert them before
                training.
            image_uri (str): If specified, the estimator will use this image
                for training and hosting, instead of selecting the appropriate
                SageMaker official image based on framework_version and
                py_version. It can be an ECR url or dockerhub image and tag.
                Examples:
                    * ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
                    * ``custom-image:latest``

                If ``framework_version`` or ``py_version`` are ``None``, then
                ``image_uri`` is required. If also ``None``, then a ``ValueError``
                will be raised.
            distribution (dict): A dictionary with information on how to run distributed training
                (default: None).  Currently, the following are supported:
                distributed training with parameter servers, SageMaker Distributed (SMD) Data
                and Model Parallelism, and MPI. SMD Model Parallelism can only be used with MPI.
                To enable parameter server use the following setup:

                .. code:: python

                    {
                        "parameter_server": {
                            "enabled": True
                        }
                    }

                To enable MPI:

                .. code:: python

                    {
                        "mpi": {
                            "enabled": True
                        }
                    }

                To enable SMDistributed Data Parallel or Model Parallel:

                .. code:: python

                    {
                        "smdistributed": {
                            "dataparallel": {
                                "enabled": True
                            },
                            "modelparallel": {
                                "enabled": True,
                                "parameters": {}
                            }
                        }
                    }

            **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
                constructor.

        .. tip::

            You can find additional parameters for initializing this class at
            :class:`~sagemaker.estimator.Framework` and
            :class:`~sagemaker.estimator.EstimatorBase`.
        """
        self.framework_version = transformers_version
        self.py_version = py_version
        self.tensorflow_version = tensorflow_version
        self.pytorch_version = pytorch_version

        self._validate_args(image_uri=image_uri)

        if distribution is not None:
            instance_type = renamed_kwargs(
                "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
            )

            base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
            base_framework_version = (
                tensorflow_version if tensorflow_version is not None else pytorch_version
            )

            validate_smdistributed(
                instance_type=instance_type,
                framework_name=base_framework_name,
                framework_version=base_framework_version,
                py_version=self.py_version,
                distribution=distribution,
                image_uri=image_uri,
            )

            warn_if_parameter_server_with_multi_gpu(
                training_instance_type=instance_type, distribution=distribution
            )

        if "enable_sagemaker_metrics" not in kwargs:
            kwargs["enable_sagemaker_metrics"] = True

        super(HuggingFace, self).__init__(
            entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
        )
        self.distribution = distribution or {}
예제 #4
0
    def __init__(self,
                 py_version=None,
                 framework_version=None,
                 model_dir=None,
                 image_uri=None,
                 distribution=None,
                 **kwargs):
        """Initialize a ``TensorFlow`` estimator.

        Args:
            py_version (str): Python version you want to use for executing your model training
                code. Defaults to ``None``. Required unless ``image_uri`` is provided.
            framework_version (str): TensorFlow version you want to use for executing your model
                training code. Defaults to ``None``. Required unless ``image_uri`` is provided.
                List of supported versions:
                https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators.
            model_dir (str): S3 location where the checkpoint data and models can be exported to
                during training (default: None). It will be passed in the training script as one of
                the command line arguments. If not specified, one is provided based on
                your training configuration:

                * *distributed training with SMDistributed or MPI with Horovod* - ``/opt/ml/model``
                * *single-machine training or distributed training without MPI* - \
                    ``s3://{output_path}/model``
                * *Local Mode with local sources (file:// instead of s3://)* - \
                    ``/opt/ml/shared/model``

                To disable having ``model_dir`` passed to your training script,
                set ``model_dir=False``.
            image_uri (str): If specified, the estimator will use this image for training and
                hosting, instead of selecting the appropriate SageMaker official image based on
                framework_version and py_version. It can be an ECR url or dockerhub image and tag.

                Examples:
                    123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
                    custom-image:latest.

                If ``framework_version`` or ``py_version`` are ``None``, then
                ``image_uri`` is required. If also ``None``, then a ``ValueError``
                will be raised.
            distribution (dict): A dictionary with information on how to run distributed training
                (default: None). Currently, the following are supported:
                distributed training with parameter servers, SageMaker Distributed (SMD) Data
                and Model Parallelism, and MPI. SMD Model Parallelism can only be used with MPI.
                To enable parameter server use the following setup:

                .. code:: python

                    {
                        "parameter_server": {
                            "enabled": True
                        }
                    }

                To enable MPI:

                .. code:: python

                    {
                        "mpi": {
                            "enabled": True
                        }
                    }

                To enable SMDistributed Data Parallel or Model Parallel:

                .. code:: python

                    {
                        "smdistributed": {
                            "dataparallel": {
                                "enabled": True
                            },
                            "modelparallel": {
                                "enabled": True,
                                "parameters": {}
                            }
                        }
                    }

            **kwargs: Additional kwargs passed to the Framework constructor.

        .. tip::

            You can find additional parameters for initializing this class at
            :class:`~sagemaker.estimator.Framework` and
            :class:`~sagemaker.estimator.EstimatorBase`.
        """
        distribution = renamed_kwargs("distributions", "distribution",
                                      distribution, kwargs)
        instance_type = renamed_kwargs("train_instance_type", "instance_type",
                                       kwargs.get("instance_type"), kwargs)
        fw.validate_version_or_image_args(framework_version, py_version,
                                          image_uri)
        if py_version == "py2":
            logger.warning(
                fw.python_deprecation_warning(self._framework_name,
                                              defaults.LATEST_PY2_VERSION))
        self.framework_version = framework_version
        self.py_version = py_version
        self.instance_type = instance_type

        if distribution is not None:
            fw.warn_if_parameter_server_with_multi_gpu(
                training_instance_type=instance_type,
                distribution=distribution)
            fw.validate_smdistributed(
                instance_type=instance_type,
                framework_name=self._framework_name,
                framework_version=framework_version,
                py_version=py_version,
                distribution=distribution,
                image_uri=image_uri,
            )

        if "enable_sagemaker_metrics" not in kwargs:
            # enable sagemaker metrics for TF v1.15 or greater:
            if framework_version and version.Version(
                    framework_version) >= version.Version("1.15"):
                kwargs["enable_sagemaker_metrics"] = True

        super(TensorFlow, self).__init__(image_uri=image_uri, **kwargs)
        self.model_dir = model_dir
        self.distribution = distribution or {}

        self._validate_args(py_version=py_version)
예제 #5
0
    def __init__(self,
                 entry_point,
                 framework_version=None,
                 py_version=None,
                 source_dir=None,
                 hyperparameters=None,
                 image_uri=None,
                 distribution=None,
                 **kwargs):
        """This ``Estimator`` executes an PyTorch script in a managed PyTorch
        execution environment, within a SageMaker Training Job. The managed
        PyTorch environment is an Amazon-built Docker container that executes
        functions defined in the supplied ``entry_point`` Python script.

        Training is started by calling
        :meth:`~sagemaker.amazon.estimator.Framework.fit` on this Estimator.
        After training is complete, calling
        :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a hosted
        SageMaker endpoint and returns an
        :class:`~sagemaker.amazon.pytorch.model.PyTorchPredictor` instance that
        can be used to perform inference against the hosted model.

        Technical documentation on preparing PyTorch scripts for SageMaker
        training and using the PyTorch Estimator is available on the project
        home-page: https://github.com/aws/sagemaker-python-sdk

        Args:
            entry_point (str): Path (absolute or relative) to the Python source
                file which should be executed as the entry point to training.
                If ``source_dir`` is specified, then ``entry_point``
                must point to a file located at the root of ``source_dir``.
            framework_version (str): PyTorch version you want to use for
                executing your model training code. Defaults to ``None``. Required unless
                ``image_uri`` is provided. List of supported versions:
                https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
            py_version (str): Python version you want to use for executing your
                model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
                unless ``image_uri`` is provided.
            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
                with any other training source code dependencies aside from the entry
                point file (default: None). If ``source_dir`` is an S3 URI, it must
                point to a tar.gz file. Structure within this directory are preserved
                when training on Amazon SageMaker.
            hyperparameters (dict): Hyperparameters that will be used for
                training (default: None). The hyperparameters are made
                accessible as a dict[str, str] to the training code on
                SageMaker. For convenience, this accepts other types for keys
                and values, but ``str()`` will be called to convert them before
                training.
            image_uri (str): If specified, the estimator will use this image
                for training and hosting, instead of selecting the appropriate
                SageMaker official image based on framework_version and
                py_version. It can be an ECR url or dockerhub image and tag.
                Examples:
                    * ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
                    * ``custom-image:latest``

                If ``framework_version`` or ``py_version`` are ``None``, then
                ``image_uri`` is required. If also ``None``, then a ``ValueError``
                will be raised.
            distribution (dict): A dictionary with information on how to run distributed training
                (default: None).  Currently, the following are supported:
                distributed training with parameter servers, SageMaker Distributed (SMD) Data
                and Model Parallelism, and MPI. SMD Model Parallelism can only be used with MPI.
                To enable parameter server use the following setup:

                .. code:: python

                    {
                        "parameter_server": {
                            "enabled": True
                        }
                    }

                To enable MPI:

                .. code:: python

                    {
                        "mpi": {
                            "enabled": True
                        }
                    }

                To enable SMDistributed Data Parallel or Model Parallel:

                .. code:: python

                    {
                        "smdistributed": {
                            "dataparallel": {
                                "enabled": True
                            },
                            "modelparallel": {
                                "enabled": True,
                                "parameters": {}
                            }
                        }
                    }

            **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
                constructor.

        .. tip::

            You can find additional parameters for initializing this class at
            :class:`~sagemaker.estimator.Framework` and
            :class:`~sagemaker.estimator.EstimatorBase`.
        """
        validate_version_or_image_args(framework_version, py_version,
                                       image_uri)
        if py_version == "py2":
            logger.warning(
                python_deprecation_warning(self._framework_name,
                                           defaults.LATEST_PY2_VERSION))
        self.framework_version = framework_version
        self.py_version = py_version

        if distribution is not None:
            instance_type = renamed_kwargs("train_instance_type",
                                           "instance_type",
                                           kwargs.get("instance_type"), kwargs)

            validate_smdistributed(
                instance_type=instance_type,
                framework_name=self._framework_name,
                framework_version=framework_version,
                py_version=py_version,
                distribution=distribution,
                image_uri=image_uri,
            )

            warn_if_parameter_server_with_multi_gpu(
                training_instance_type=instance_type,
                distribution=distribution)

        if "enable_sagemaker_metrics" not in kwargs:
            # enable sagemaker metrics for PT v1.3 or greater:
            if self.framework_version and Version(
                    self.framework_version) >= Version("1.3"):
                kwargs["enable_sagemaker_metrics"] = True

        super(PyTorch, self).__init__(entry_point,
                                      source_dir,
                                      hyperparameters,
                                      image_uri=image_uri,
                                      **kwargs)
        self.distribution = distribution or {}
예제 #6
0
    def __init__(self,
                 entry_point,
                 framework_version=None,
                 py_version="py3",
                 source_dir=None,
                 hyperparameters=None,
                 image_uri=None,
                 **kwargs):
        """This ``Estimator`` executes an Scikit-learn script in a managed
        Scikit-learn execution environment, within a SageMaker Training Job. The
        managed Scikit-learn environment is an Amazon-built Docker container
        that executes functions defined in the supplied ``entry_point`` Python
        script.

        Training is started by calling
        :meth:`~sagemaker.amazon.estimator.Framework.fit` on this Estimator.
        After training is complete, calling
        :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a hosted
        SageMaker endpoint and returns an
        :class:`~sagemaker.amazon.sklearn.model.SKLearnPredictor` instance that
        can be used to perform inference against the hosted model.

        Technical documentation on preparing Scikit-learn scripts for
        SageMaker training and using the Scikit-learn Estimator is available on
        the project home-page: https://github.com/aws/sagemaker-python-sdk

        Args:
            entry_point (str): Path (absolute or relative) to the Python source
                file which should be executed as the entry point to training.
                If ``source_dir`` is specified, then ``entry_point``
                must point to a file located at the root of ``source_dir``.
            framework_version (str): Scikit-learn version you want to use for
                executing your model training code. Defaults to ``None``. Required
                unless ``image_uri`` is provided. List of supported versions:
                https://github.com/aws/sagemaker-python-sdk#sklearn-sagemaker-estimators
            py_version (str): Python version you want to use for executing your
                model training code (default: 'py3'). Currently, 'py3' is the only
                supported version. If ``None`` is passed in, ``image_uri`` must be
                provided.
            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
                with any other training source code dependencies aside from the entry
                point file (default: None). If ``source_dir`` is an S3 URI, it must
                point to a tar.gz file. Structure within this directory are preserved
                when training on Amazon SageMaker.
            hyperparameters (dict): Hyperparameters that will be used for
                training (default: None). The hyperparameters are made
                accessible as a dict[str, str] to the training code on
                SageMaker. For convenience, this accepts other types for keys
                and values, but ``str()`` will be called to convert them before
                training.
            image_uri (str): If specified, the estimator will use this image
                for training and hosting, instead of selecting the appropriate
                SageMaker official image based on framework_version and
                py_version. It can be an ECR url or dockerhub image and tag.

                Examples:
                    123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
                    custom-image:latest.

                If ``framework_version`` or ``py_version`` are ``None``, then
                ``image_uri`` is required. If also ``None``, then a ``ValueError``
                will be raised.
            **kwargs: Additional kwargs passed to the
                :class:`~sagemaker.estimator.Framework` constructor.

        .. tip::

            You can find additional parameters for initializing this class at
            :class:`~sagemaker.estimator.Framework` and
            :class:`~sagemaker.estimator.EstimatorBase`.
        """
        instance_type = renamed_kwargs("train_instance_type", "instance_type",
                                       kwargs.get("instance_type"), kwargs)
        instance_count = renamed_kwargs("train_instance_count",
                                        "instance_count",
                                        kwargs.get("instance_count"), kwargs)
        validate_version_or_image_args(framework_version, py_version,
                                       image_uri)
        if py_version and py_version != "py3":
            raise AttributeError(
                "Scikit-learn image only supports Python 3. Please use 'py3' for py_version."
            )
        self.framework_version = framework_version
        self.py_version = py_version

        # SciKit-Learn does not support distributed training or training on GPU instance types.
        # Fail fast.
        _validate_not_gpu_instance_type(instance_type)

        if instance_count:
            if instance_count != 1:
                raise AttributeError(
                    "Scikit-Learn does not support distributed training. Please remove the "
                    "'instance_count' argument or set 'instance_count=1' when initializing SKLearn."
                )

        super(SKLearn, self).__init__(entry_point,
                                      source_dir,
                                      hyperparameters,
                                      image_uri=image_uri,
                                      **dict(kwargs, instance_count=1))

        if image_uri is None:
            self.image_uri = image_uris.retrieve(
                SKLearn._framework_name,
                self.sagemaker_session.boto_region_name,
                version=self.framework_version,
                py_version=self.py_version,
                instance_type=instance_type,
            )
예제 #7
0
    def __init__(self,
                 entry_point,
                 framework_version,
                 source_dir=None,
                 hyperparameters=None,
                 py_version="py3",
                 image_uri=None,
                 **kwargs):
        """An estimator that executes an XGBoost-based SageMaker Training Job.

        The managed XGBoost environment is an Amazon-built Docker container thatexecutes functions
        defined in the supplied ``entry_point`` Python script.

        Training is started by calling :meth:`~sagemaker.amazon.estimator.Framework.fit` on this
        Estimator. After training is complete, calling
        :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a hosted SageMaker endpoint
        and returns an :class:`~sagemaker.amazon.xgboost.model.XGBoostPredictor` instance that
        can be used to perform inference against the hosted model.

        Technical documentation on preparing XGBoost scripts for SageMaker training and using the
        XGBoost Estimator is available on the project home-page:
        https://github.com/aws/sagemaker-python-sdk

        Args:
            entry_point (str): Path (absolute or relative) to the Python source file which should
                be executed as the entry point to training.  If ``source_dir`` is specified,
                then ``entry_point`` must point to a file located at the root of ``source_dir``.
            framework_version (str): XGBoost version you want to use for executing your model
                training code.
            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
                with any other training source code dependencies aside from the entry
                point file (default: None). If ``source_dir`` is an S3 URI, it must
                point to a tar.gz file. Structure within this directory are preserved
                when training on Amazon SageMaker.
            hyperparameters (dict): Hyperparameters that will be used for training (default: None).
                The hyperparameters are made accessible as a dict[str, str] to the training code
                on SageMaker. For convenience, this accepts other types for keys and values, but
                ``str()`` will be called to convert them before training.
            py_version (str): Python version you want to use for executing your model
                training code (default: 'py3').
            image_uri (str): If specified, the estimator will use this image for training and
                hosting, instead of selecting the appropriate SageMaker official image
                based on framework_version and py_version. It can be an ECR url or
                dockerhub image and tag.
                Examples:
                    123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
                    custom-image:latest.
            **kwargs: Additional kwargs passed to the
                :class:`~sagemaker.estimator.Framework` constructor.

        .. tip::

            You can find additional parameters for initializing this class at
            :class:`~sagemaker.estimator.Framework` and
            :class:`~sagemaker.estimator.EstimatorBase`.
        """
        instance_type = renamed_kwargs("train_instance_type", "instance_type",
                                       kwargs.get("instance_type"), kwargs)
        super(XGBoost, self).__init__(entry_point,
                                      source_dir,
                                      hyperparameters,
                                      image_uri=image_uri,
                                      **kwargs)

        self.py_version = py_version
        self.framework_version = framework_version

        validate_py_version(py_version)
        validate_framework_version(framework_version)

        if image_uri is None:
            self.image_uri = image_uris.retrieve(
                self._framework_name,
                self.sagemaker_session.boto_region_name,
                version=framework_version,
                py_version=self.py_version,
                instance_type=instance_type,
                image_scope="training",
            )
예제 #8
0
    def __init__(
        self,
        py_version,
        entry_point,
        transformers_version=None,
        tensorflow_version=None,
        pytorch_version=None,
        source_dir=None,
        hyperparameters=None,
        image_uri=None,
        distribution=None,
        compiler_config=None,
        **kwargs,
    ):
        """This estimator runs a Hugging Face training script in a SageMaker training environment.

        The estimator initiates the SageMaker-managed Hugging Face environment
        by using the pre-built Hugging Face Docker container and runs
        the Hugging Face training script that user provides through
        the ``entry_point`` argument.

        After configuring the estimator class, use the class method
        :meth:`~sagemaker.amazon.estimator.Framework.fit()` to start a training job.

        Args:
            py_version (str): Python version you want to use for executing your model training
                code. Defaults to ``None``. Required unless ``image_uri`` is provided.  If
                using PyTorch, the current supported version is ``py36``. If using TensorFlow,
                the current supported version is ``py37``.
            entry_point (str): Path (absolute or relative) to the Python source
                file which should be executed as the entry point to training.
                If ``source_dir`` is specified, then ``entry_point``
                must point to a file located at the root of ``source_dir``.
            transformers_version (str): Transformers version you want to use for
                executing your model training code. Defaults to ``None``. Required unless
                ``image_uri`` is provided. The current supported version is ``4.6.1``.
            tensorflow_version (str): TensorFlow version you want to use for
                executing your model training code. Defaults to ``None``. Required unless
                ``pytorch_version`` is provided. The current supported version is ``2.4.1``.
            pytorch_version (str): PyTorch version you want to use for
                executing your model training code. Defaults to ``None``. Required unless
                ``tensorflow_version`` is provided. The current supported versions are ``1.7.1`` and ``1.6.0``.
            source_dir (str): Path (absolute, relative or an S3 URI) to a directory
                with any other training source code dependencies aside from the entry
                point file (default: None). If ``source_dir`` is an S3 URI, it must
                point to a tar.gz file. Structure within this directory are preserved
                when training on Amazon SageMaker.
            hyperparameters (dict): Hyperparameters that will be used for
                training (default: None). The hyperparameters are made
                accessible as a dict[str, str] to the training code on
                SageMaker. For convenience, this accepts other types for keys
                and values, but ``str()`` will be called to convert them before
                training.
            image_uri (str): If specified, the estimator will use this image
                for training and hosting, instead of selecting the appropriate
                SageMaker official image based on framework_version and
                py_version. It can be an ECR url or dockerhub image and tag.
                Examples:
                    * ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
                    * ``custom-image:latest``

                If ``framework_version`` or ``py_version`` are ``None``, then
                ``image_uri`` is required. If also ``None``, then a ``ValueError``
                will be raised.
            distribution (dict): A dictionary with information on how to run distributed training
                (default: None).  Currently, the following are supported:
                distributed training with parameter servers, SageMaker Distributed (SMD) Data
                and Model Parallelism, and MPI. SMD Model Parallelism can only be used with MPI.
                To enable parameter server use the following setup:

                .. code:: python

                    {
                        "parameter_server": {
                            "enabled": True
                        }
                    }

                To enable MPI:

                .. code:: python

                    {
                        "mpi": {
                            "enabled": True
                        }
                    }

                To enable SMDistributed Data Parallel or Model Parallel:

                .. code:: python

                    {
                        "smdistributed": {
                            "dataparallel": {
                                "enabled": True
                            },
                            "modelparallel": {
                                "enabled": True,
                                "parameters": {}
                            }
                        }
                    }
            compiler_config (:class:`~sagemaker.huggingface.TrainingCompilerConfig`):
                Configures SageMaker Training Compiler to accelerate training.

            **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
                constructor.

        .. tip::

            You can find additional parameters for initializing this class at
            :class:`~sagemaker.estimator.Framework` and
            :class:`~sagemaker.estimator.EstimatorBase`.
        """
        self.framework_version = transformers_version
        self.py_version = py_version
        self.tensorflow_version = tensorflow_version
        self.pytorch_version = pytorch_version

        self._validate_args(image_uri=image_uri)

        instance_type = renamed_kwargs(
            "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
        )

        base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
        base_framework_version = (
            tensorflow_version if tensorflow_version is not None else pytorch_version
        )

        if distribution is not None:
            validate_smdistributed(
                instance_type=instance_type,
                framework_name=base_framework_name,
                framework_version=base_framework_version,
                py_version=self.py_version,
                distribution=distribution,
                image_uri=image_uri,
            )

            warn_if_parameter_server_with_multi_gpu(
                training_instance_type=instance_type, distribution=distribution
            )

        if "enable_sagemaker_metrics" not in kwargs:
            kwargs["enable_sagemaker_metrics"] = True

        kwargs["py_version"] = self.py_version

        super(HuggingFace, self).__init__(
            entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
        )

        if compiler_config is not None:
            if not isinstance(compiler_config, TrainingCompilerConfig):
                error_string = (
                    f"Expected instance of type {TrainingCompilerConfig}"
                    f"for argument compiler_config. "
                    f"Instead got {type(compiler_config)}"
                )
                raise ValueError(error_string)
            if compiler_config:
                compiler_config.validate(
                    image_uri=image_uri,
                    instance_type=instance_type,
                    distribution=distribution,
                )

        self.distribution = distribution or {}
        self.compiler_config = compiler_config