Example #1
def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
    # Build the TrainingJob model from the task's config and return its protobuf JSON form as a dict.
    training_job = _training_job_models.TrainingJob(
        algorithm_specification=self._task_config.algorithm_specification,
        training_job_resource_config=self._task_config.training_job_resource_config,
    )
    return MessageToDict(training_job.to_flyte_idl())
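
These snippets are excerpted from flytekit's SageMaker plugin and omit their imports. As a rough sketch, the names used above and in the later examples are assumed to resolve along the following lines (exact module paths vary between flytekit releases):

# Assumed imports for these examples; the module paths are an approximation and
# may differ depending on the flytekit version in use.
from typing import Any, Dict

from google.protobuf.json_format import MessageToDict

from flytekit.models.sagemaker import hpo_job as _hpo_job_model
from flytekit.models.sagemaker import training_job as _training_job_models
# SerializationSettings also comes from flytekit; its import location has moved between versions.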
Example #2
def test_training_job():
    rc = training_job.TrainingJobResourceConfig(
        instance_type="test_type",
        instance_count=10,
        volume_size_in_gb=25,
        distributed_protocol=training_job.DistributedProtocol.MPI,
    )
    alg = training_job.AlgorithmSpecification(
        algorithm_name=training_job.AlgorithmName.CUSTOM,
        algorithm_version="",
        input_mode=training_job.InputMode.FILE,
        input_content_type=training_job.InputContentType.TEXT_CSV,
    )
    tj = training_job.TrainingJob(
        training_job_resource_config=rc,
        algorithm_specification=alg,
    )

    tj2 = training_job.TrainingJob.from_flyte_idl(tj.to_flyte_idl())
    # Comparing tj == tj2 directly would be False because these model classes do not define __eq__,
    # so the fields are compared individually below.
    assert tj.training_job_resource_config.instance_type == tj2.training_job_resource_config.instance_type
    assert tj.training_job_resource_config.instance_count == tj2.training_job_resource_config.instance_count
    assert tj.training_job_resource_config.distributed_protocol == tj2.training_job_resource_config.distributed_protocol
    assert tj.training_job_resource_config.volume_size_in_gb == tj2.training_job_resource_config.volume_size_in_gb
    assert tj.algorithm_specification.algorithm_name == tj2.algorithm_specification.algorithm_name
    assert tj.algorithm_specification.algorithm_version == tj2.algorithm_specification.algorithm_version
    assert tj.algorithm_specification.input_mode == tj2.algorithm_specification.input_mode
    assert tj.algorithm_specification.input_content_type == tj2.algorithm_specification.input_content_type
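
As the comment in test_training_job notes, tj == tj2 is False because the model classes do not define __eq__. A minimal, hypothetical sketch of such a method for TrainingJobResourceConfig (not part of flytekit, shown only to illustrate why the test compares field by field):

# Hypothetical __eq__ for TrainingJobResourceConfig; flytekit's model class does not define one,
# which is why the tests above assert on individual fields instead.
def __eq__(self, other):
    if not isinstance(other, TrainingJobResourceConfig):
        return NotImplemented
    return (
        self.instance_type == other.instance_type
        and self.instance_count == other.instance_count
        and self.volume_size_in_gb == other.volume_size_in_gb
        and self.distributed_protocol == other.distributed_protocol
    )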
Example #3
def test_hyperparameter_tuning_job():
    rc = training_job.TrainingJobResourceConfig(
        instance_type="test_type",
        instance_count=10,
        volume_size_in_gb=25,
        distributed_protocol=training_job.DistributedProtocol.MPI,
    )
    alg = training_job.AlgorithmSpecification(
        algorithm_name=training_job.AlgorithmName.CUSTOM,
        algorithm_version="",
        input_mode=training_job.InputMode.FILE,
        input_content_type=training_job.InputContentType.TEXT_CSV,
    )
    tj = training_job.TrainingJob(training_job_resource_config=rc, algorithm_specification=alg)
    hpo = hpo_job.HyperparameterTuningJob(max_number_of_training_jobs=10, max_parallel_training_jobs=5, training_job=tj)

    hpo2 = hpo_job.HyperparameterTuningJob.from_flyte_idl(hpo.to_flyte_idl())

    assert hpo.max_number_of_training_jobs == hpo2.max_number_of_training_jobs
    assert hpo.max_parallel_training_jobs == hpo2.max_parallel_training_jobs
    assert (
        hpo2.training_job.training_job_resource_config.instance_type
        == hpo.training_job.training_job_resource_config.instance_type
    )
    assert (
        hpo2.training_job.training_job_resource_config.instance_count
        == hpo.training_job.training_job_resource_config.instance_count
    )
    assert (
        hpo2.training_job.training_job_resource_config.distributed_protocol
        == hpo.training_job.training_job_resource_config.distributed_protocol
    )
    assert (
        hpo2.training_job.training_job_resource_config.volume_size_in_gb
        == hpo.training_job.training_job_resource_config.volume_size_in_gb
    )
    assert (
        hpo2.training_job.algorithm_specification.algorithm_name
        == hpo.training_job.algorithm_specification.algorithm_name
    )
    assert (
        hpo2.training_job.algorithm_specification.algorithm_version
        == hpo.training_job.algorithm_specification.algorithm_version
    )
    assert hpo2.training_job.algorithm_specification.input_mode == hpo.training_job.algorithm_specification.input_mode
    assert (
        hpo2.training_job.algorithm_specification.input_content_type
        == hpo.training_job.algorithm_specification.input_content_type
    )
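
The field-by-field assertions in these two tests could be factored into a small helper; a sketch (the helper name is invented for illustration and is not part of the test suite):

# Hypothetical helper that compares two TrainingJob models field by field,
# mirroring the assertions used in the tests above.
def assert_training_job_equal(actual, expected):
    assert actual.training_job_resource_config.instance_type == expected.training_job_resource_config.instance_type
    assert actual.training_job_resource_config.instance_count == expected.training_job_resource_config.instance_count
    assert actual.training_job_resource_config.volume_size_in_gb == expected.training_job_resource_config.volume_size_in_gb
    assert actual.training_job_resource_config.distributed_protocol == expected.training_job_resource_config.distributed_protocol
    assert actual.algorithm_specification.algorithm_name == expected.algorithm_specification.algorithm_name
    assert actual.algorithm_specification.algorithm_version == expected.algorithm_specification.algorithm_version
    assert actual.algorithm_specification.input_mode == expected.algorithm_specification.input_mode
    assert actual.algorithm_specification.input_content_type == expected.algorithm_specification.input_content_type

With such a helper, test_hyperparameter_tuning_job could end with assert_training_job_equal(hpo2.training_job, hpo.training_job).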
Example #4
def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
    training_job = _training_job_model.TrainingJob(
        algorithm_specification=self._training_task.task_config.algorithm_specification,
        training_job_resource_config=self._training_task.task_config.training_job_resource_config,
    )
    return MessageToDict(
        _hpo_job_model.HyperparameterTuningJob(
            max_number_of_training_jobs=self.task_config.max_number_of_training_jobs,
            max_parallel_training_jobs=self.task_config.max_parallel_training_jobs,
            training_job=training_job,
        ).to_flyte_idl()
    )
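
MessageToDict serializes the protobuf message using its JSON mapping, so the dict returned by get_custom carries lowerCamelCase keys. A small illustration, assuming the model fields map directly to proto fields of the same name:

# Illustration only: serialize an HPO job model to its dict form. Key names follow
# protobuf's default JSON mapping (snake_case fields become lowerCamelCase).
hpo_model = _hpo_job_model.HyperparameterTuningJob(
    max_number_of_training_jobs=10,
    max_parallel_training_jobs=5,
    training_job=tj,  # a TrainingJob model, e.g. as built in the tests above
)
custom = MessageToDict(hpo_model.to_flyte_idl())
# Expect keys such as "maxNumberOfTrainingJobs" and "maxParallelTrainingJobs"
# (an assumption based on the default JSON name mapping, not verified output).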
Example #5
    def __init__(
        self,
        task_function,
        cache_version,
        retries,
        deprecated,
        storage_request,
        cpu_request,
        gpu_request,
        memory_request,
        storage_limit,
        cpu_limit,
        gpu_limit,
        memory_limit,
        cache,
        timeout,
        environment,
        algorithm_specification: _training_job_models.AlgorithmSpecification,
        training_job_resource_config: _training_job_models.TrainingJobResourceConfig,
        output_persist_predicate: _typing.Callable = DefaultOutputPersistPredicate(),
    ):
        """
        :param task_function: Function container user code.  This will be executed via the SDK's engine.
        :param Text cache_version: string describing the version for task discovery purposes
        :param int retries: Number of retries to attempt
        :param Text deprecated:
        :param Text storage_request:
        :param Text cpu_request:
        :param Text gpu_request:
        :param Text memory_request:
        :param Text storage_limit:
        :param Text cpu_limit:
        :param Text gpu_limit:
        :param Text memory_limit:
        :param bool cache:
        :param datetime.timedelta timeout:
        :param dict[Text, Text] environment:
        :param _training_job_models.AlgorithmSpecification algorithm_specification:
        :param _training_job_models.TrainingJobResourceConfig training_job_resource_config:
        :param _typing.Callable output_persist_predicate:
        """

        self._output_persist_predicate = output_persist_predicate

        # Build the training job model up front as a form of type checking on the inputs
        self._training_job_model = _training_job_models.TrainingJob(
            algorithm_specification=algorithm_specification,
            training_job_resource_config=training_job_resource_config,
        )

        super().__init__(
            task_function=task_function,
            task_type=SdkTaskType.SAGEMAKER_CUSTOM_TRAINING_JOB_TASK,
            discovery_version=cache_version,
            retries=retries,
            interruptible=False,
            deprecated=deprecated,
            storage_request=storage_request,
            cpu_request=cpu_request,
            gpu_request=gpu_request,
            memory_request=memory_request,
            storage_limit=storage_limit,
            cpu_limit=cpu_limit,
            gpu_limit=gpu_limit,
            memory_limit=memory_limit,
            discoverable=cache,
            timeout=timeout,
            environment=environment,
            custom=MessageToDict(self._training_job_model.to_flyte_idl()),
        )
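
Example #5 continues with a second constructor, taken from SdkBuiltinAlgorithmTrainingJobTask (the class named in the super() call below). Unlike the custom training task above, it builds the task's typed interface explicitly instead of deriving it from a user function: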
    def __init__(
        self,
        training_job_resource_config: _training_job_models.TrainingJobResourceConfig,
        algorithm_specification: _training_job_models.AlgorithmSpecification,
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
    ):
        """

        :param training_job_resource_config: The options to configure the training job
        :param algorithm_specification: The options to configure the target algorithm of the training
        :param retries: Number of retries to attempt
        :param cacheable: The flag to set if the user wants the output of the task execution to be cached
        :param cache_version: String describing the caching version for task discovery purposes
        """
        # Build the training job model up front as a form of type checking on the inputs
        self._training_job_model = _training_job_models.TrainingJob(
            algorithm_specification=algorithm_specification,
            training_job_resource_config=training_job_resource_config,
        )

        # Set the Flyte-level timeout to 0 and let SageMaker use the StoppingCondition to
        # terminate the training job gracefully
        timeout = _datetime.timedelta(seconds=0)

        super(SdkBuiltinAlgorithmTrainingJobTask, self).__init__(
            type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
            metadata=_task_models.TaskMetadata(
                runtime=_task_models.RuntimeMetadata(
                    type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    version=__version__,
                    flavor="sagemaker",
                ),
                discoverable=cacheable,
                timeout=timeout,
                retries=_literal_models.RetryStrategy(retries=retries),
                interruptible=False,
                discovery_version=cache_version,
                deprecated_error_message="",
            ),
            interface=_interface.TypedInterface(
                inputs={
                    "static_hyperparameters": _interface_model.Variable(
                        type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT),
                        description="",
                    ),
                    "train": _interface_model.Variable(
                        type=_idl_types.LiteralType(
                            blob=_core_types.BlobType(
                                format=_content_type_to_blob_format(algorithm_specification.input_content_type),
                                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
                            )
                        ),
                        description="",
                    ),
                    "validation": _interface_model.Variable(
                        type=_idl_types.LiteralType(
                            blob=_core_types.BlobType(
                                format=_content_type_to_blob_format(algorithm_specification.input_content_type),
                                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
                            )
                        ),
                        description="",
                    ),
                },
                outputs={
                    "model": _interface_model.Variable(
                        type=_idl_types.LiteralType(
                            blob=_core_types.BlobType(
                                format="",
                                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
                            )
                        ),
                        description="",
                    )
                },
            ),
            custom=MessageToDict(self._training_job_model.to_flyte_idl()),
        )
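
A hedged usage sketch for the builtin-algorithm constructor above, based only on the signature shown. The import path for SdkBuiltinAlgorithmTrainingJobTask is not included in the snippet, and the enum value and version string below are illustrative assumptions:

# Illustrative only: construct a builtin-algorithm training task using the signature above.
# AlgorithmName.XGBOOST and the version "0.90" are assumptions; check your flytekit release.
from flytekit.models.sagemaker import training_job as _training_job_models

xgboost_task = SdkBuiltinAlgorithmTrainingJobTask(
    training_job_resource_config=_training_job_models.TrainingJobResourceConfig(
        instance_type="ml.m4.xlarge",
        instance_count=1,
        volume_size_in_gb=25,
    ),
    algorithm_specification=_training_job_models.AlgorithmSpecification(
        algorithm_name=_training_job_models.AlgorithmName.XGBOOST,
        algorithm_version="0.90",
        input_mode=_training_job_models.InputMode.FILE,
        input_content_type=_training_job_models.InputContentType.TEXT_CSV,
    ),
    retries=1,
    cacheable=True,
    cache_version="1.0",
)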