def __init__(
    self,
    model_data,
    role,
    entry_point,
    image=None,
    py_version="py2",
    framework_version=None,
    predictor_cls=TensorFlowPredictor,
    model_server_workers=None,
    **kwargs
):
    """Initialize a TensorFlowModel.

    Args:
        model_data (str): S3 location of a SageMaker model data ``.tar.gz`` file.
        role (str): AWS IAM role, given as a name or a full ARN. SageMaker
            training jobs and the APIs that create SageMaker endpoints use this
            role to access training data and model artifacts. Once the endpoint
            exists, the inference code may also use this role if it needs to
            access an AWS resource.
        entry_point (str): Path (absolute or relative) to the Python source
            file executed as the entry point to model hosting. It should be
            compatible with either Python 2.7 or Python 3.5.
        image (str): Docker image URI (default: None). When omitted, a default
            TensorFlow image is used.
        py_version (str): Python version used to execute the model code
            (default: 'py2'). Passing 'py2' logs a deprecation warning.
        framework_version (str): TensorFlow version used to execute the model
            code. When omitted, a warning is logged and ``TF_VERSION`` is used.
        predictor_cls (callable[str, sagemaker.session.Session]): Function
            called with an endpoint name and a SageMaker ``Session`` to create
            a predictor. When given, ``deploy()`` returns the result of calling
            it on the newly created endpoint name.
        model_server_workers (int): Optional number of worker processes used
            by the inference server. When None, the server uses one worker per
            vCPU.
        **kwargs: Keyword arguments forwarded to the ``FrameworkModel``
            initializer.

    .. tip::

        Further initialization parameters are documented on
        :class:`~sagemaker.model.FrameworkModel` and
        :class:`~sagemaker.model.Model`.
    """
    parent = super(TensorFlowModel, self)
    parent.__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs)

    # Notices are logged in a fixed order: Python deprecation first, then version defaulting.
    if py_version == "py2":
        logger.warning(python_deprecation_warning(self.__framework_name__))

    version_missing = framework_version is None
    if version_missing:
        logger.warning(empty_framework_version_warning(TF_VERSION, LATEST_VERSION))

    self.model_server_workers = model_server_workers
    self.py_version = py_version
    self.framework_version = framework_version or TF_VERSION
def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version='py2',
             framework_version=None, image_name=None, distributions=None, **kwargs):
    """Build an estimator that runs an MXNet script in a managed MXNet environment.

    The script runs within a SageMaker Training Job. The managed MXNet
    environment is an Amazon-built Docker container that executes functions
    defined in the supplied ``entry_point`` Python script.

    Call :meth:`~sagemaker.amazon.estimator.Framework.fit` to start training.
    Afterwards, :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a
    hosted SageMaker endpoint. It returns an
    :class:`~sagemaker.amazon.mxnet.model.MXNetPredictor` that can run
    inference against the hosted model.

    Technical documentation on preparing MXNet scripts for SageMaker training
    and using the MXNet Estimator is available on the project home-page:
    https://github.com/aws/sagemaker-python-sdk

    Args:
        entry_point (str): Path (absolute or relative) to the Python source
            file executed as the entry point to training. It should be
            compatible with either Python 2.7 or Python 3.5.
        source_dir (str): Path (absolute or relative) to a directory holding
            any other training source code dependencies besides the entry
            point file (default: None). The directory structure is preserved
            when training on Amazon SageMaker.
        hyperparameters (dict): Hyperparameters used for training
            (default: None). They are exposed to the training code on SageMaker
            as a dict[str, str]. Keys and values of other types are accepted
            for convenience; ``str()`` is applied to them before training.
        py_version (str): Python version used to execute the training code
            (default: 'py2'). One of 'py2' or 'py3'.
        framework_version (str): MXNet version used to execute the training
            code. List of supported versions:
            https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators
            When omitted, a warning is logged and ``MXNET_VERSION`` is used.
        image_name (str): If given, this image is used for training and hosting
            instead of the SageMaker official image chosen from
            framework_version and py_version. It can be an ECR url or a
            dockerhub image and tag.
            Examples:
                123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
                custom-image:latest.
        distributions (dict): Describes how to run distributed training
            (default: None).
        **kwargs: Additional kwargs passed to the
            :class:`~sagemaker.estimator.Framework` constructor.
    """
    requested_version = framework_version
    if requested_version is None:
        logger.warning(
            empty_framework_version_warning(MXNET_VERSION, self.LATEST_VERSION))
    # Resolved before the parent constructor runs.
    self.framework_version = requested_version or MXNET_VERSION

    base_init = super(MXNet, self).__init__
    base_init(entry_point, source_dir, hyperparameters,
              image_name=image_name, **kwargs)

    self.py_version = py_version
    self._configure_distribution(distributions)
def __init__(
    self,
    sagemaker_session: sagemaker.Session,
    role: str,
    image_name: str,
    base_job_name: str,
    train_instance_type: str = "ml.c5.xlarge",
    train_instance_count: int = 1,
    dependencies: Optional[List[str]] = None,
    output_path: Optional[str] = None,
    code_location: Optional[str] = None,
    framework_version: Optional[str] = GLUONTS_VERSION,
    hyperparameters: Optional[Dict] = None,
    entry_point: str = str(ENTRY_POINTS_FOLDER / TRAIN_SCRIPT),
    **kwargs,
):
    """Initialize the GluonTS framework estimator.

    Every argument except ``framework_version`` is forwarded by keyword to the
    parent class ``__init__``.

    Args:
        sagemaker_session: Session forwarded to the parent. After the parent
            constructor runs, ``self.sagemaker_session.boto_session`` is used
            to build the S3 filesystem stored in ``self._s3fs``.
        role: IAM role forwarded to the parent.
        image_name: Image name forwarded to the parent.
        base_job_name: Base job name forwarded to the parent.
        train_instance_type: Training instance type (default: "ml.c5.xlarge").
        train_instance_count: Number of training instances (default: 1).
        dependencies: Optional list of dependency paths forwarded to the parent.
        output_path: Optional output path forwarded to the parent.
        code_location: Optional code location forwarded to the parent.
        framework_version: Version string (default: ``GLUONTS_VERSION``).
            Passing None explicitly logs a warning and falls back to
            ``GLUONTS_VERSION``.
        hyperparameters: Optional hyperparameters forwarded to the parent.
        entry_point: Training entry point (default: ``TRAIN_SCRIPT`` inside
            ``ENTRY_POINTS_FOLDER``).
        **kwargs: Additional keyword arguments forwarded to the parent.
    """
    # Framework_version currently serves no purpose,
    # except for compatibility with the sagemaker framework.
    if framework_version is None:
        logger.warning(
            empty_framework_version_warning(
                GLUONTS_VERSION, self.LATEST_VERSION
            )
        )
    self.framework_version = framework_version or GLUONTS_VERSION

    super().__init__(
        dependencies=dependencies,
        output_path=output_path,
        code_location=code_location,
        sagemaker_session=sagemaker_session,
        role=role,
        train_instance_type=train_instance_type,
        train_instance_count=train_instance_count,
        base_job_name=base_job_name,
        entry_point=entry_point,
        hyperparameters=hyperparameters,
        image_name=image_name,
        **kwargs,
    )

    # must be set
    self.py_version = PYTHON_VERSION

    self._s3fs = s3fs.S3FileSystem(
        session=self.sagemaker_session.boto_session
    )
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None,
             py_version='py2', framework_version=None, requirements_file='',
             image_name=None, **kwargs):
    """Initialize a ``TensorFlow`` estimator.

    Args:
        training_steps (int): Number of training steps to perform. The default,
            `None`, means train forever.
        evaluation_steps (int): Number of evaluation steps to perform. The
            default, `None`, means evaluation runs until input from
            eval_input_fn is exhausted (or another exception is raised).
        checkpoint_path (str): S3 location where checkpoint data produced
            during training can be saved (default: None). Required for
            distributed model training.
        py_version (str): Python version used to execute the training code
            (default: 'py2').
        framework_version (str): TensorFlow version used to execute the
            training code. List of supported versions:
            https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators
            When omitted, a warning is logged and ``TF_VERSION`` is used.
        requirements_file (str): Path to a ``requirements.txt`` file
            (default: ''). The path should be inside and relative to
            ``source_dir``. Format details are in the `Pip User Guide
            <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_.
        image_name (str): If given, this image is used for training and hosting
            instead of the SageMaker official image chosen from
            framework_version and py_version. It can be an ECR url or a
            dockerhub image and tag.
            Examples:
                123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
                custom-image:latest.
        **kwargs: Additional kwargs passed to the Framework constructor.
    """
    missing_version = framework_version is None
    if missing_version:
        LOGGER.warning(
            empty_framework_version_warning(TF_VERSION, TF_VERSION))
    self.framework_version = framework_version or TF_VERSION

    super(TensorFlow, self).__init__(image_name=image_name, **kwargs)

    # Plain configuration attributes.
    self.training_steps = training_steps
    self.evaluation_steps = evaluation_steps
    self.checkpoint_path = checkpoint_path
    self.py_version = py_version

    # Validated before it is stored on the instance.
    self._validate_requirements_file(requirements_file)
    self.requirements_file = requirements_file
def __init__(self, entry_point, framework_version=SKLEARN_VERSION, source_dir=None,
             hyperparameters=None, py_version='py3', image_name=None, **kwargs):
    """Build an estimator that runs a Scikit-learn script in a managed environment.

    The script runs within a SageMaker Training Job. The managed Scikit-learn
    environment is an Amazon-built Docker container that executes functions
    defined in the supplied ``entry_point`` Python script.

    Call :meth:`~sagemaker.amazon.estimator.Framework.fit` to start training.
    Afterwards, :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a
    hosted SageMaker endpoint. It returns an
    :class:`~sagemaker.amazon.sklearn.model.SKLearnPredictor` that can run
    inference against the hosted model.

    Technical documentation on preparing Scikit-learn scripts for SageMaker
    training and using the Scikit-learn Estimator is available on the project
    home-page: https://github.com/aws/sagemaker-python-sdk

    Args:
        entry_point (str): Path (absolute or relative) to the Python source
            file executed as the entry point to training. It should be
            compatible with either Python 2.7 or Python 3.5.
        framework_version (str): Scikit-learn version used to execute the
            training code (default: ``SKLEARN_VERSION``). List of supported
            versions:
            https://github.com/aws/sagemaker-python-sdk#sklearn-sagemaker-estimators
            Passing None explicitly logs a warning and falls back to
            ``SKLEARN_VERSION``.
        source_dir (str): Path (absolute or relative) to a directory holding
            any other training source code dependencies besides the entry
            point file (default: None). The directory structure is preserved
            when training on Amazon SageMaker.
        hyperparameters (dict): Hyperparameters used for training
            (default: None). They are exposed to the training code on SageMaker
            as a dict[str, str]. Keys and values of other types are accepted
            for convenience; ``str()`` is applied to them before training.
        py_version (str): Python version used to execute the training code
            (default: 'py3'). One of 'py2' or 'py3'.
        image_name (str): If given, this image is used for training and hosting
            instead of the SageMaker official image chosen from
            framework_version and py_version. It can be an ECR url or a
            dockerhub image and tag.
            Examples:
                123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
                custom-image:latest.
        **kwargs: Additional kwargs passed to the
            :class:`~sagemaker.estimator.Framework` constructor.
            ``train_instance_count`` is always forced to 1.

    Raises:
        AttributeError: If ``train_instance_count`` is given and is not 1.
    """
    # SciKit-Learn does not support distributed training or training on GPU instance types. Fail fast.
    train_instance_type = kwargs.get('train_instance_type')
    _validate_not_gpu_instance_type(train_instance_type)

    train_instance_count = kwargs.get('train_instance_count')
    if train_instance_count and train_instance_count != 1:
        raise AttributeError(
            "Scikit-Learn does not support distributed training. "
            "Please remove the 'train_instance_count' argument or set "
            "'train_instance_count=1' when initializing SKLearn.")

    super(SKLearn, self).__init__(entry_point, source_dir, hyperparameters,
                                  image_name=image_name,
                                  **dict(kwargs, train_instance_count=1))

    self.py_version = py_version

    if framework_version is None:
        logger.warning(
            empty_framework_version_warning(SKLEARN_VERSION, SKLEARN_VERSION))
    self.framework_version = framework_version or SKLEARN_VERSION

    if image_name is None:
        # Use the resolved version: the raw argument may be None, which would
        # otherwise produce an invalid tag such as "None-cpu-py3".
        image_tag = "{}-{}-{}".format(self.framework_version, "cpu", py_version)
        self.image_name = default_framework_uri(
            SKLearn.__framework_name__,
            self.sagemaker_session.boto_region_name,
            image_tag)
def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_per_host=None,
             additional_mpi_options=None, source_dir=None, hyperparameters=None,
             py_version="py3", framework_version=None, image_name=None, **kwargs):
    """Build an estimator that runs a Chainer script in a managed Chainer environment.

    The script runs within a SageMaker Training Job. The managed Chainer
    environment is an Amazon-built Docker container that executes functions
    defined in the supplied ``entry_point`` Python script.

    Call :meth:`~sagemaker.amazon.estimator.Framework.fit` to start training.
    Afterwards, :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a
    hosted SageMaker endpoint. It returns an
    :class:`~sagemaker.amazon.chainer.model.ChainerPredictor` that can run
    inference against the hosted model.

    Technical documentation on preparing Chainer scripts for SageMaker training
    and using the Chainer Estimator is available on the project home-page:
    https://github.com/aws/sagemaker-python-sdk

    Args:
        entry_point (str): Path (absolute or relative) to the Python source
            file executed as the entry point to training. It should be
            compatible with either Python 2.7 or Python 3.5.
        use_mpi (bool): If true, the entry point is run as an MPI script. By
            default, the Chainer Framework runs the entry point with 'mpirun'
            if more than one instance is used.
        num_processes (int): Total number of processes to run the entry point
            with. By default, the Chainer Framework runs one process per GPU
            (on GPU instances), or one process per host (on CPU instances).
        process_slots_per_host (int): Number of processes that can run on each
            instance. By default, this is the number of GPUs on the instance
            (on GPU instances), or one (on CPU instances).
        additional_mpi_options (str): String of options to the 'mpirun'
            command used to run the entry point. For example,
            '-X NCCL_DEBUG=WARN' passes that option string to the mpirun
            command.
        source_dir (str): Path (absolute or relative) to a directory holding
            any other training source code dependencies besides the entry
            point file (default: None). The directory structure is preserved
            when training on Amazon SageMaker.
        hyperparameters (dict): Hyperparameters used for training
            (default: None). They are exposed to the training code on SageMaker
            as a dict[str, str]. Keys and values of other types are accepted
            for convenience; ``str()`` is applied to them before training.
        py_version (str): Python version used to execute the training code
            (default: 'py3'). One of 'py2' or 'py3'. Passing 'py2' logs a
            deprecation warning.
        framework_version (str): Chainer version used to execute the training
            code. List of supported versions:
            https://github.com/aws/sagemaker-python-sdk#chainer-sagemaker-estimators.
            If not specified, a warning is logged and this defaults to 4.1
            (``CHAINER_VERSION``).
        image_name (str): If given, this image is used for training and hosting
            instead of the SageMaker official image chosen from
            framework_version and py_version. It can be an ECR url or a
            dockerhub image and tag.

            Examples
                * ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
                * ``custom-image:latest``
        **kwargs: Additional kwargs passed to the
            :class:`~sagemaker.estimator.Framework` constructor.

    .. tip::

        You can find additional parameters for initializing this class at
        :class:`~sagemaker.estimator.Framework` and
        :class:`~sagemaker.estimator.EstimatorBase`.
    """
    if framework_version is None:
        logger.warning(
            empty_framework_version_warning(CHAINER_VERSION, self.LATEST_VERSION))
    # Resolved before the parent constructor runs.
    self.framework_version = framework_version or CHAINER_VERSION

    super(Chainer, self).__init__(entry_point, source_dir, hyperparameters,
                                  image_name=image_name, **kwargs)

    if py_version == "py2":
        logger.warning(python_deprecation_warning(self.__framework_name__))

    self.py_version = py_version
    self.use_mpi = use_mpi
    self.num_processes = num_processes
    self.process_slots_per_host = process_slots_per_host
    self.additional_mpi_options = additional_mpi_options
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None,
             py_version='py2', framework_version=None, model_dir=None,
             requirements_file='', image_name=None, script_mode=False,
             distributions=None, **kwargs):
    """Initialize a ``TensorFlow`` estimator.

    Args:
        training_steps (int): Number of training steps to perform. The default,
            `None`, means train forever.
        evaluation_steps (int): Number of evaluation steps to perform. The
            default, `None`, means evaluation runs until input from
            eval_input_fn is exhausted (or another exception is raised).
        checkpoint_path (str): S3 location where checkpoint data produced
            during training can be saved (default: None). Required for
            distributed model training.
        py_version (str): Python version used to execute the training code
            (default: 'py2').
        framework_version (str): TensorFlow version used to execute the
            training code. List of supported versions:
            https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators
        model_dir (str): S3 location where checkpoint data and models can be
            exported during training (default: None). If not specified, a
            default S3 URI is generated. It is passed to the training script as
            one of the command line arguments.
        requirements_file (str): Path to a ``requirements.txt`` file
            (default: ''). The path should be inside and relative to
            ``source_dir``. Format details are in the `Pip User Guide
            <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_.
        image_name (str): If given, this image is used for training and hosting
            instead of the SageMaker official image chosen from
            framework_version and py_version. It can be an ECR url or a
            dockerhub image and tag.
            Examples:
                123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
                custom-image:latest.
        script_mode (bool): If True, the estimator uses the Script Mode
            containers (default: False). Ignored if py_version is 'py3'.
        distributions (dict): Describes how to run distributed training
            (default: None). Only distributed training with parameter servers
            is currently supported. To enable it use:
                {
                    'parameter_server':
                    {
                        'enabled': True
                    }
                }
        **kwargs: Additional kwargs passed to the Framework constructor.
    """
    if framework_version is None:
        LOGGER.warning(
            fw.empty_framework_version_warning(TF_VERSION, TF_VERSION))
    self.framework_version = framework_version or TF_VERSION

    super(TensorFlow, self).__init__(image_name=image_name, **kwargs)

    self.py_version = py_version
    self.script_mode = script_mode
    self.distributions = distributions or {}
    self.training_steps = training_steps
    self.evaluation_steps = evaluation_steps
    self.checkpoint_path = checkpoint_path
    self.model_dir = model_dir

    # NOTE: validation receives the caller-supplied framework_version (possibly
    # None), not the defaulted self.framework_version.
    validation_args = dict(
        py_version=py_version,
        script_mode=script_mode,
        framework_version=framework_version,
        training_steps=training_steps,
        evaluation_steps=evaluation_steps,
        requirements_file=requirements_file,
        checkpoint_path=checkpoint_path,
    )
    self._validate_args(**validation_args)
    self._validate_requirements_file(requirements_file)
    self.requirements_file = requirements_file
def __init__(
    self,
    entry_point,
    source_dir=None,
    hyperparameters=None,
    py_version=defaults.PYTHON_VERSION,
    framework_version=None,
    image_name=None,
    **kwargs
):
    """Build an estimator that runs a PyTorch script in a managed PyTorch environment.

    The script runs within a SageMaker Training Job. The managed PyTorch
    environment is an Amazon-built Docker container that executes functions
    defined in the supplied ``entry_point`` Python script.

    Call :meth:`~sagemaker.amazon.estimator.Framework.fit` to start training.
    Afterwards, :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a
    hosted SageMaker endpoint. It returns a
    :class:`~sagemaker.amazon.pytorch.model.PyTorchPredictor` that can run
    inference against the hosted model.

    Technical documentation on preparing PyTorch scripts for SageMaker training
    and using the PyTorch Estimator is available on the project home-page:
    https://github.com/aws/sagemaker-python-sdk

    Args:
        entry_point (str): Path (absolute or relative) to the Python source
            file executed as the entry point to training. It should be
            compatible with either Python 2.7 or Python 3.5.
        source_dir (str): Path (absolute or relative) to a directory holding
            any other training source code dependencies besides the entry
            point file (default: None). The directory structure is preserved
            when training on Amazon SageMaker.
        hyperparameters (dict): Hyperparameters used for training
            (default: None). They are exposed to the training code on SageMaker
            as a dict[str, str]. Keys and values of other types are accepted
            for convenience; ``str()`` is applied to them before training.
        py_version (str): Python version used to execute the training code
            (default: 'py3'). One of 'py2' or 'py3'. Passing 'py2' logs a
            deprecation warning.
        framework_version (str): PyTorch version used to execute the training
            code. List of supported versions:
            https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
            If not specified, a warning is logged and this defaults to 0.4
            (``defaults.PYTORCH_VERSION``).
        image_name (str): If given, this image is used for training and hosting
            instead of the SageMaker official image chosen from
            framework_version and py_version. It can be an ECR url or a
            dockerhub image and tag.

            Examples:
                * ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
                * ``custom-image:latest``
        **kwargs: Additional kwargs passed to the
            :class:`~sagemaker.estimator.Framework` constructor. Unless
            ``enable_sagemaker_metrics`` is given, it is set to True for
            PyTorch 1.3 or newer.

    .. tip::

        You can find additional parameters for initializing this class at
        :class:`~sagemaker.estimator.Framework` and
        :class:`~sagemaker.estimator.EstimatorBase`.
    """
    if framework_version is None:
        default_notice = empty_framework_version_warning(
            defaults.PYTORCH_VERSION, self.LATEST_VERSION
        )
        logger.warning(default_notice)
    self.framework_version = framework_version or defaults.PYTORCH_VERSION

    # Turn SageMaker metrics on for PT v1.3+, unless the caller decided explicitly.
    if "enable_sagemaker_metrics" not in kwargs and is_version_equal_or_higher(
        [1, 3], self.framework_version
    ):
        kwargs["enable_sagemaker_metrics"] = True

    super(PyTorch, self).__init__(
        entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
    )

    if py_version == "py2":
        deprecation_notice = python_deprecation_warning(
            self.__framework_name__, defaults.LATEST_PY2_VERSION
        )
        logger.warning(deprecation_notice)

    self.py_version = py_version
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None,
             py_version=None, framework_version=None, model_dir=None,
             requirements_file="", image_name=None, script_mode=False,
             distributions=None, **kwargs):
    """Initialize a ``TensorFlow`` estimator.

    Args:
        training_steps (int): Number of training steps to perform. The default,
            `None`, means train forever.
        evaluation_steps (int): Number of evaluation steps to perform. The
            default, `None`, means evaluation runs until input from
            eval_input_fn is exhausted (or another exception is raised).
        checkpoint_path (str): S3 location where checkpoint data produced
            during training can be saved (default: None). Required for
            distributed model training.
        py_version (str): Python version used to execute the training code
            (default: None). When not given, 'py3' is chosen if
            ``self._only_python_3_supported()`` is true, otherwise 'py2'.
            A resolved value of 'py2' logs a deprecation warning.
        framework_version (str): TensorFlow version used to execute the
            training code. If not specified, a warning is logged and this
            defaults to 1.11 (``defaults.TF_VERSION``).
        model_dir (str): S3 location where checkpoint data and models can be
            exported during training (default: None). It is passed to the
            training script as one of the command line arguments. If not
            specified, one is provided based on your training configuration:

            * *distributed training with MPI* - ``/opt/ml/model``
            * *single-machine training or distributed training without MPI* -
              ``s3://{output_path}/model``
            * *Local Mode with local sources (file:// instead of s3://)* -
              ``/opt/ml/shared/model``

        requirements_file (str): Path to a ``requirements.txt`` file
            (default: ''). The path should be inside and relative to
            ``source_dir``. Format details are in the Pip User Guide:
            <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>
        image_name (str): If given, this image is used for training and hosting
            instead of the SageMaker official image chosen from
            framework_version and py_version. It can be an ECR url or a
            dockerhub image and tag.

            Examples:
                123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
                custom-image:latest.
        script_mode (bool): If True, the estimator uses the Script Mode
            containers (default: False). Ignored if py_version is 'py3'.
        distributions (dict): Describes how to run distributed training
            (default: None). Passing a value logs a v2 rename warning
            (``distributions`` -> ``distribution``). Distributed training with
            parameter servers and MPI is supported. To enable a parameter
            server:

            .. code:: python

                {
                    'parameter_server':
                    {
                        'enabled': True
                    }
                }

            To enable MPI:

            .. code:: python

                {
                    'mpi':
                    {
                        'enabled': True
                    }
                }

        **kwargs: Additional kwargs passed to the Framework constructor.
            Unless ``enable_sagemaker_metrics`` is given, it is set to True for
            TensorFlow 1.15 or newer.

    .. tip::

        You can find additional parameters for initializing this class at
        :class:`~sagemaker.estimator.Framework` and
        :class:`~sagemaker.estimator.EstimatorBase`.
    """
    if framework_version is None:
        logger.warning(
            fw.empty_framework_version_warning(defaults.TF_VERSION, self.LATEST_VERSION))
    # Resolved first: _only_python_3_supported() and the metrics check below use it.
    self.framework_version = framework_version or defaults.TF_VERSION

    if not py_version:
        py_version = "py3" if self._only_python_3_supported() else "py2"
    if py_version == "py2":
        logger.warning(
            fw.python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION))

    if distributions is not None:
        # Arguments are (old parameter name, new parameter name). Previously the user's
        # distributions dict was passed as the new name, which garbled the message.
        logger.warning(fw.parameter_v2_rename_warning("distributions", "distribution"))
        train_instance_type = kwargs.get("train_instance_type")
        fw.warn_if_parameter_server_with_multi_gpu(
            training_instance_type=train_instance_type, distributions=distributions)

    if "enable_sagemaker_metrics" not in kwargs:
        # enable sagemaker metrics for TF v1.15 or greater:
        if fw.is_version_equal_or_higher([1, 15], self.framework_version):
            kwargs["enable_sagemaker_metrics"] = True

    super(TensorFlow, self).__init__(image_name=image_name, **kwargs)

    self.checkpoint_path = checkpoint_path
    self.py_version = py_version
    self.training_steps = training_steps
    self.evaluation_steps = evaluation_steps
    self.model_dir = model_dir
    self.script_mode = script_mode
    self.distributions = distributions or {}

    self._validate_args(
        py_version=py_version,
        script_mode=script_mode,
        framework_version=self.framework_version,
        training_steps=training_steps,
        evaluation_steps=evaluation_steps,
        requirements_file=requirements_file,
        checkpoint_path=checkpoint_path,
    )
    self._validate_requirements_file(requirements_file)
    self.requirements_file = requirements_file
def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version="py2",
             framework_version=None, image_name=None, distributions=None, **kwargs):
    """Build an estimator that runs an MXNet script in a managed MXNet environment.

    The script runs within a SageMaker Training Job. The managed MXNet
    environment is an Amazon-built Docker container that executes functions
    defined in the supplied ``entry_point`` Python script.

    Call :meth:`~sagemaker.amazon.estimator.Framework.fit` to start training.
    Afterwards, :meth:`~sagemaker.amazon.estimator.Framework.deploy` creates a
    hosted SageMaker endpoint. It returns an
    :class:`~sagemaker.amazon.mxnet.model.MXNetPredictor` that can run
    inference against the hosted model.

    Technical documentation on preparing MXNet scripts for SageMaker training
    and using the MXNet Estimator is available on the project home-page:
    https://github.com/aws/sagemaker-python-sdk

    Args:
        entry_point (str): Path (absolute or relative) to the Python source
            file executed as the entry point to training. If ``source_dir`` is
            given, ``entry_point`` must point to a file at the root of
            ``source_dir``.
        source_dir (str): Path (absolute, relative or an S3 URI) to a directory
            holding any other training source code dependencies besides the
            entry point file (default: None). An S3 URI must point to a
            tar.gz file. The directory structure is preserved when training on
            Amazon SageMaker.
        hyperparameters (dict): Hyperparameters used for training
            (default: None). They are exposed to the training code on SageMaker
            as a dict[str, str]. Keys and values of other types are accepted
            for convenience; ``str()`` is applied to them before training.
        py_version (str): Python version used to execute the training code
            (default: 'py2'). One of 'py2' or 'py3'. Passing 'py2' logs a
            deprecation warning.
        framework_version (str): MXNet version used to execute the training
            code. List of supported versions:
            https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
            If not specified, a warning is logged and this defaults to 1.2.1
            (``defaults.MXNET_VERSION``).
        image_name (str): If given, this image is used for training and hosting
            instead of the SageMaker official image chosen from
            framework_version and py_version. It can be an ECR url or a
            dockerhub image and tag.

            Examples:
                * ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
                * ``custom-image:latest``
        distributions (dict): Describes how to run distributed training
            (default: None). To launch parameter servers for training, set it
            to ``{'parameter_server': {'enabled': True}}``. Passing a value
            logs a v2 rename warning.
        **kwargs: Additional kwargs passed to the
            :class:`~sagemaker.estimator.Framework` constructor. Unless
            ``enable_sagemaker_metrics`` is given, it is set to True for
            MXNet 1.6 or newer.

    .. tip::

        You can find additional parameters for initializing this class at
        :class:`~sagemaker.estimator.Framework` and
        :class:`~sagemaker.estimator.EstimatorBase`.
    """
    if framework_version is None:
        default_notice = empty_framework_version_warning(
            defaults.MXNET_VERSION, self.LATEST_VERSION)
        logger.warning(default_notice)
    self.framework_version = framework_version or defaults.MXNET_VERSION

    # Turn SageMaker metrics on for MXNet v1.6+, unless the caller decided explicitly.
    if ("enable_sagemaker_metrics" not in kwargs
            and is_version_equal_or_higher([1, 6], self.framework_version)):
        kwargs["enable_sagemaker_metrics"] = True

    super(MXNet, self).__init__(entry_point, source_dir, hyperparameters,
                                image_name=image_name, **kwargs)

    if py_version == "py2":
        logger.warning(
            python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION))

    if distributions is not None:
        logger.warning(parameter_v2_rename_warning("distributions", "distribution"))
        warn_if_parameter_server_with_multi_gpu(
            training_instance_type=kwargs.get("train_instance_type"),
            distributions=distributions)

    self.py_version = py_version
    self._configure_distribution(distributions)