def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
    """Serialize this task's SageMaker training job configuration.

    :param settings: Serialization settings (part of the interface; not used here).
    :return: The TrainingJob IDL message rendered as a plain dict.
    """
    cfg = self._task_config
    job = _training_job_models.TrainingJob(
        algorithm_specification=cfg.algorithm_specification,
        training_job_resource_config=cfg.training_job_resource_config,
    )
    return MessageToDict(job.to_flyte_idl())
def test_training_job():
    """Round-trip a TrainingJob through its flyte IDL form and verify every field survives."""
    resource_config = training_job.TrainingJobResourceConfig(
        instance_type="test_type",
        instance_count=10,
        volume_size_in_gb=25,
        distributed_protocol=training_job.DistributedProtocol.MPI,
    )
    algorithm_spec = training_job.AlgorithmSpecification(
        algorithm_name=training_job.AlgorithmName.CUSTOM,
        algorithm_version="",
        input_mode=training_job.InputMode.FILE,
        input_content_type=training_job.InputContentType.TEXT_CSV,
    )
    original = training_job.TrainingJob(
        training_job_resource_config=resource_config,
        algorithm_specification=algorithm_spec,
    )
    restored = training_job.TrainingJob.from_flyte_idl(original.to_flyte_idl())

    # The model classes do not define __eq__, so `original == restored` would be
    # identity comparison; compare the nested fields one by one instead.
    for attr in ("instance_type", "instance_count", "distributed_protocol", "volume_size_in_gb"):
        assert getattr(original.training_job_resource_config, attr) == getattr(
            restored.training_job_resource_config, attr
        )
    for attr in ("algorithm_name", "algorithm_version", "input_mode", "input_content_type"):
        assert getattr(original.algorithm_specification, attr) == getattr(
            restored.algorithm_specification, attr
        )
def test_hyperparameter_tuning_job():
    """Round-trip a HyperparameterTuningJob through its flyte IDL form and verify all fields."""
    rc = training_job.TrainingJobResourceConfig(
        instance_type="test_type",
        instance_count=10,
        volume_size_in_gb=25,
        distributed_protocol=training_job.DistributedProtocol.MPI,
    )
    alg = training_job.AlgorithmSpecification(
        algorithm_name=training_job.AlgorithmName.CUSTOM,
        algorithm_version="",
        input_mode=training_job.InputMode.FILE,
        input_content_type=training_job.InputContentType.TEXT_CSV,
    )
    tj = training_job.TrainingJob(
        training_job_resource_config=rc,
        algorithm_specification=alg,
    )
    hpo = hpo_job.HyperparameterTuningJob(
        max_number_of_training_jobs=10,
        max_parallel_training_jobs=5,
        training_job=tj,
    )
    hpo2 = hpo_job.HyperparameterTuningJob.from_flyte_idl(hpo.to_flyte_idl())

    assert hpo.max_number_of_training_jobs == hpo2.max_number_of_training_jobs
    assert hpo.max_parallel_training_jobs == hpo2.max_parallel_training_jobs
    # No __eq__ on the model classes, so compare the nested training job field by field.
    for attr in ("instance_type", "instance_count", "distributed_protocol", "volume_size_in_gb"):
        assert getattr(hpo2.training_job.training_job_resource_config, attr) == getattr(
            hpo.training_job.training_job_resource_config, attr
        )
    for attr in ("algorithm_name", "algorithm_version", "input_mode", "input_content_type"):
        assert getattr(hpo2.training_job.algorithm_specification, attr) == getattr(
            hpo.training_job.algorithm_specification, attr
        )
def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
    """Serialize this HPO task's configuration, wrapping the underlying training job.

    :param settings: Serialization settings (part of the interface; not used here).
    :return: The HyperparameterTuningJob IDL message rendered as a plain dict.
    """
    training_cfg = self._training_task.task_config
    inner_job = _training_job_model.TrainingJob(
        algorithm_specification=training_cfg.algorithm_specification,
        training_job_resource_config=training_cfg.training_job_resource_config,
    )
    hpo_cfg = self.task_config
    tuning_job = _hpo_job_model.HyperparameterTuningJob(
        max_number_of_training_jobs=hpo_cfg.max_number_of_training_jobs,
        max_parallel_training_jobs=hpo_cfg.max_parallel_training_jobs,
        training_job=inner_job,
    )
    return MessageToDict(tuning_job.to_flyte_idl())
def __init__(
    self,
    task_function,
    cache_version,
    retries,
    deprecated,
    storage_request,
    cpu_request,
    gpu_request,
    memory_request,
    storage_limit,
    cpu_limit,
    gpu_limit,
    memory_limit,
    cache,
    timeout,
    environment,
    algorithm_specification: _training_job_models.AlgorithmSpecification,
    training_job_resource_config: _training_job_models.TrainingJobResourceConfig,
    # NOTE(review): this default is evaluated once at definition time, so every
    # call without an explicit predicate shares the same instance — presumably
    # the predicate is stateless, but confirm.
    output_persist_predicate: _typing.Callable = DefaultOutputPersistPredicate(),
):
    """
    Construct a SageMaker custom-training-job task wrapping a user function.

    :param task_function: Function containing user code. This will be executed via the SDK's engine.
    :param Text cache_version: String describing the version for task discovery purposes.
    :param int retries: Number of retries to attempt.
    :param Text deprecated: Deprecation message, if any.
    :param Text storage_request: Requested storage for the task container.
    :param Text cpu_request: Requested CPU for the task container.
    :param Text gpu_request: Requested GPU for the task container.
    :param Text memory_request: Requested memory for the task container.
    :param Text storage_limit: Storage limit for the task container.
    :param Text cpu_limit: CPU limit for the task container.
    :param Text gpu_limit: GPU limit for the task container.
    :param Text memory_limit: Memory limit for the task container.
    :param bool cache: Whether task outputs are discoverable (cached).
    :param datetime.timedelta timeout: Execution timeout.
    :param dict[Text, Text] environment: Environment variables for the container.
    :param _training_job_models.AlgorithmSpecification algorithm_specification:
        Options describing the target training algorithm.
    :param _training_job_models.TrainingJobResourceConfig training_job_resource_config:
        Options configuring the SageMaker training job resources.
    :param _typing.Callable output_persist_predicate: Predicate deciding whether
        outputs should be persisted.
    """
    self._output_persist_predicate = output_persist_predicate

    # Use the training job model as a measure of type checking: constructing it
    # validates the two SageMaker config objects before the task is registered.
    self._training_job_model = _training_job_models.TrainingJob(
        algorithm_specification=algorithm_specification,
        training_job_resource_config=training_job_resource_config,
    )

    # Delegate everything else to the base task; the SageMaker-specific config
    # is carried in `custom` as the dict form of the TrainingJob IDL message.
    super().__init__(
        task_function=task_function,
        task_type=SdkTaskType.SAGEMAKER_CUSTOM_TRAINING_JOB_TASK,
        discovery_version=cache_version,
        retries=retries,
        interruptible=False,
        deprecated=deprecated,
        storage_request=storage_request,
        cpu_request=cpu_request,
        gpu_request=gpu_request,
        memory_request=memory_request,
        storage_limit=storage_limit,
        cpu_limit=cpu_limit,
        gpu_limit=gpu_limit,
        memory_limit=memory_limit,
        discoverable=cache,
        timeout=timeout,
        environment=environment,
        custom=MessageToDict(self._training_job_model.to_flyte_idl()),
    )
def __init__(
    self,
    training_job_resource_config: _training_job_models.TrainingJobResourceConfig,
    algorithm_specification: _training_job_models.AlgorithmSpecification,
    retries: int = 0,
    cacheable: bool = False,
    cache_version: str = "",
):
    """
    Construct a SageMaker built-in-algorithm training job task.

    The task interface is fixed: inputs are `static_hyperparameters` (a struct)
    plus `train` and `validation` multipart blobs in the algorithm's input
    content type; the single output is the trained `model` blob.

    :param training_job_resource_config: The options to configure the training job.
    :param algorithm_specification: The options to configure the target algorithm of the training.
    :param retries: Number of retries to attempt.
    :param cacheable: The flag to set if the user wants the output of the task execution to be cached.
    :param cache_version: String describing the caching version for task discovery purposes.
    """
    # Use the training job model as a measure of type checking: constructing it
    # validates both config objects up front.
    self._training_job_model = _training_job_models.TrainingJob(
        algorithm_specification=algorithm_specification,
        training_job_resource_config=training_job_resource_config,
    )

    # Setting flyte-level timeout to 0, and let SageMaker take the StoppingCondition
    # and terminate the training job gracefully.
    timeout = _datetime.timedelta(seconds=0)

    super(SdkBuiltinAlgorithmTrainingJobTask, self).__init__(
        type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
        metadata=_task_models.TaskMetadata(
            runtime=_task_models.RuntimeMetadata(
                type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                version=__version__,
                flavor="sagemaker",
            ),
            discoverable=cacheable,
            timeout=timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=False,
            discovery_version=cache_version,
            deprecated_error_message="",
        ),
        interface=_interface.TypedInterface(
            inputs={
                # Algorithm hyperparameters are passed as a generic struct.
                "static_hyperparameters": _interface_model.Variable(
                    type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT),
                    description="",
                ),
                # Training dataset: a multipart blob whose format mirrors the
                # algorithm's declared input content type.
                "train": _interface_model.Variable(
                    type=_idl_types.LiteralType(
                        blob=_core_types.BlobType(
                            format=_content_type_to_blob_format(
                                algorithm_specification.input_content_type
                            ),
                            dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
                        ),
                    ),
                    description="",
                ),
                # Validation dataset: same shape and format as `train`.
                "validation": _interface_model.Variable(
                    type=_idl_types.LiteralType(
                        blob=_core_types.BlobType(
                            format=_content_type_to_blob_format(
                                algorithm_specification.input_content_type
                            ),
                            dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
                        ),
                    ),
                    description="",
                ),
            },
            outputs={
                # The trained model artifact, produced as a single blob with no
                # specific format declared.
                "model": _interface_model.Variable(
                    type=_idl_types.LiteralType(
                        blob=_core_types.BlobType(
                            format="",
                            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
                        )
                    ),
                    description="",
                )
            },
        ),
        custom=MessageToDict(self._training_job_model.to_flyte_idl()),
    )