Example #1
def test_simple_hpo_job_task():
    assert isinstance(simple_xgboost_hpo_job_task,
                      SdkSimpleHyperparameterTuningJobTask)
    assert isinstance(simple_xgboost_hpo_job_task, _sdk_task.SdkTask)
    # Check that the inputs of the underlying SdkTrainingJobTask have been embedded
    assert simple_xgboost_hpo_job_task.interface.inputs[
        "train"].description == ""
    assert (simple_xgboost_hpo_job_task.interface.inputs["train"].type ==
            _sdk_types.Types.MultiPartCSV.to_flyte_literal_type())
    assert simple_xgboost_hpo_job_task.interface.inputs[
        "train"].type == _idl_types.LiteralType(blob=_core_types.BlobType(
            format="csv",
            dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
        ))
    assert simple_xgboost_hpo_job_task.interface.inputs[
        "validation"].description == ""
    assert (simple_xgboost_hpo_job_task.interface.inputs["validation"].type ==
            _sdk_types.Types.MultiPartCSV.to_flyte_literal_type())
    assert simple_xgboost_hpo_job_task.interface.inputs[
        "validation"].type == _idl_types.LiteralType(blob=_core_types.BlobType(
            format="csv",
            dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
        ))
    assert simple_xgboost_hpo_job_task.interface.inputs[
        "static_hyperparameters"].description == ""
    assert (
        simple_xgboost_hpo_job_task.interface.inputs["static_hyperparameters"].
        type == _sdk_types.Types.Generic.to_flyte_literal_type())

    # Checking if the hpo-specific input is defined
    assert simple_xgboost_hpo_job_task.interface.inputs[
        "hyperparameter_tuning_job_config"].description == ""
    assert (simple_xgboost_hpo_job_task.interface.
            inputs["hyperparameter_tuning_job_config"].type ==
            _sdk_types.Types.Proto(_pb2_HPOJobConfig).to_flyte_literal_type())
    assert simple_xgboost_hpo_job_task.interface.outputs[
        "model"].description == ""
    assert simple_xgboost_hpo_job_task.interface.outputs[
        "model"].type == _sdk_types.Types.Blob.to_flyte_literal_type()
    assert simple_xgboost_hpo_job_task.type == _common_constants.SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK

    # Checking if the spec of the TrainingJob is embedded into the custom field
    # of this SdkSimpleHyperparameterTuningJobTask
    assert simple_xgboost_hpo_job_task.to_flyte_idl(
    ).custom["trainingJob"] == (
        builtin_algorithm_training_job_task2.to_flyte_idl().custom)

    assert simple_xgboost_hpo_job_task.metadata.timeout == _datetime.timedelta(
        seconds=0)
    assert simple_xgboost_hpo_job_task.metadata.discoverable is True
    assert simple_xgboost_hpo_job_task.metadata.discovery_version == "1"
    assert simple_xgboost_hpo_job_task.metadata.retries.retries == 2

    assert simple_xgboost_hpo_job_task.metadata.deprecated_error_message == ""
    assert "metricDefinitions" in simple_xgboost_hpo_job_task.custom[
        "trainingJob"]["algorithmSpecification"].keys()
    assert len(simple_xgboost_hpo_job_task.custom["trainingJob"]
               ["algorithmSpecification"]["metricDefinitions"]) == 1
    """
Beispiel #2
0
 def to_flyte_literal_type(cls):
     """
     :rtype: flytekit.models.types.LiteralType
     """
     return _idl_types.LiteralType(blob=_core_types.BlobType(
         format="",
         dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE))
Example #3
def test_engine_file_output():
    basic_blob_type = _core_types.BlobType(
        format="",
        dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
    )

    fs = FileAccessProvider(local_sandbox_dir="/tmp/flytetesting")
    with context_manager.FlyteContext.current_context(
    ).new_file_access_context(file_access_provider=fs) as ctx:
        # Write some text to a file not in that directory above
        test_file_location = "/tmp/sample.txt"
        with open(test_file_location, "w") as fh:
            fh.write("Hello World\n")

        lit = TypeEngine.to_literal(ctx, test_file_location, os.PathLike,
                                    LiteralType(blob=basic_blob_type))

        # Since we're using local as remote, we should be able to just read the file from the 'remote' location.
        with open(lit.scalar.blob.uri, "r") as fh:
            assert fh.readline() == "Hello World\n"

        # We should also be able to turn it back into a regular Python native value.
        redownloaded_local_file_location = TypeEngine.to_python_value(
            ctx, lit, os.PathLike)
        with open(redownloaded_local_file_location, "r") as fh:
            assert fh.readline() == "Hello World\n"
Example #4
def test_blob_promote_from_model():
    m = _literal_models.Literal(scalar=_literal_models.Scalar(
        blob=_literal_models.Blob(
            _literal_models.BlobMetadata(
                _core_types.BlobType(format="f",
                                     dimensionality=_core_types.BlobType.
                                     BlobDimensionality.SINGLE)),
            "some/path")))
    b = blobs.Blob.promote_from_model(m)
    assert b.value.blob.uri == "some/path"
    assert b.value.blob.metadata.type.format == "f"
    assert b.value.blob.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
Example #5
def test_blob_type():
    o = _types.BlobType(
        format="csv",
        dimensionality=_types.BlobType.BlobDimensionality.SINGLE,
    )
    assert o.format == "csv"
    assert o.dimensionality == _types.BlobType.BlobDimensionality.SINGLE

    o2 = _types.BlobType.from_flyte_idl(o.to_flyte_idl())
    assert o == o2
    assert o2.format == "csv"
    assert o2.dimensionality == _types.BlobType.BlobDimensionality.SINGLE
Example #6
 def __init__(self, remote_path, mode='rb', format=None):
     """
     :param Text remote_path: Path to location where the Blob should be synced to.
     :param Text mode: File access mode.  'a' and '+' are forbidden; a blob can only be read or written, not both at once.
     :param Text format: Format of the underlying blob.
     """
     if '+' in mode or 'a' in mode or ('w' in mode and 'r' in mode):
         raise _user_exceptions.FlyteAssertion(
             "A blob cannot be read and written at the same time")
     self._mode = mode
     self._local_path = None
     self._file = None
     super(Blob, self).__init__(
         _literal_models.BlobMetadata(type=_core_types.BlobType(
             format or "", _core_types.BlobType.BlobDimensionality.SINGLE)),
         remote_path)
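# A hedged sketch (not from the flytekit source) of the mode check above;
# Blob and _user_exceptions are assumed to be imported as in the snippet.
Blob("s3://bucket/key")             # default 'rb' (read-only) is allowed
Blob("s3://bucket/key", mode="wb")  # write-only is allowed
try:
    Blob("s3://bucket/key", mode="r+")  # '+' means read+write, which is rejected
except _user_exceptions.FlyteAssertion:
    pass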
Example #7
 def __init__(self, remote_path, mode='rb', format=None):
     """
     :param Text remote_path: Path to location where the Blob should be synced to.
     :param Text mode: File access mode.  'a' and '+' are forbidden; a blob can only be read or written, not both at once.
     :param Text format: Format of underlying blob pieces.
     """
     remote_path = remote_path.strip().rstrip('/') + '/'
     super(MultiPartBlob, self).__init__(
         _literal_models.BlobMetadata(type=_core_types.BlobType(
             format or "",
             _core_types.BlobType.BlobDimensionality.MULTIPART)),
         remote_path)
     self._is_managed = False
     self._blobs = []
     self._directory = None
     self._mode = mode
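# Illustration (not from the flytekit source): the constructor above
# normalizes remote_path so it ends in exactly one trailing slash.
assert "s3://bucket/dir".strip().rstrip('/') + '/' == "s3://bucket/dir/"
assert "s3://bucket/dir///".strip().rstrip('/') + '/' == "s3://bucket/dir/"
assert " s3://bucket/dir ".strip().rstrip('/') + '/' == "s3://bucket/dir/"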
Example #8
    def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        meta = BlobMetadata(
            type=_core_types.BlobType(
                format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
            )
        )
        # Dump the task output into pickle
        local_dir = ctx.file_access.get_random_local_directory()
        os.makedirs(local_dir, exist_ok=True)
        local_path = ctx.file_access.get_random_local_path()
        uri = os.path.join(local_dir, local_path)
        with open(uri, "w+b") as outfile:
            cloudpickle.dump(python_val, outfile)

        remote_path = ctx.file_access.get_random_remote_path(uri)
        ctx.file_access.put_data(uri, remote_path, is_multipart=False)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))
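# A minimal sketch (not the real transformer) of the matching read path,
# assuming ctx.file_access.get_data downloads a blob URI to a local file:
def pickle_to_python_value_sketch(ctx: FlyteContext, lv: Literal):
    local_path = ctx.file_access.get_random_local_path()
    ctx.file_access.get_data(lv.scalar.blob.uri, local_path, is_multipart=False)
    with open(local_path, "rb") as infile:
        return cloudpickle.load(infile)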
Example #9
    def to_literal(self, ctx: FlyteContext, python_val: np.ndarray,
                   python_type: Type[np.ndarray],
                   expected: LiteralType) -> Literal:
        meta = BlobMetadata(type=_core_types.BlobType(
            format=self.NUMPY_ARRAY_FORMAT,
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE))

        local_path = ctx.file_access.get_random_local_path() + ".npy"
        pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

        # save numpy array to a file
        # allow_pickle=False prevents numpy from trying to save object arrays (dtype=object) using pickle
        np.save(file=local_path, arr=python_val, allow_pickle=False)

        remote_path = ctx.file_access.get_random_remote_path(local_path)
        ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
        return Literal(scalar=Scalar(
            blob=Blob(metadata=meta, uri=remote_path)))
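# The mirror-image read path, sketched (not the real transformer) under the
# same ctx.file_access.get_data assumption as the pickle sketch above:
def ndarray_to_python_value_sketch(ctx: FlyteContext, lv: Literal) -> np.ndarray:
    local_path = ctx.file_access.get_random_local_path() + ".npy"
    pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
    ctx.file_access.get_data(lv.scalar.blob.uri, local_path, is_multipart=False)
    # allow_pickle=False mirrors the save call in to_literal above
    return np.load(local_path, allow_pickle=False)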
Example #10
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: PyTorchCheckpoint,
        python_type: Type[PyTorchCheckpoint],
        expected: LiteralType,
    ) -> Literal:
        meta = BlobMetadata(
            type=_core_types.BlobType(
                format=self.PYTORCH_CHECKPOINT_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
            )
        )

        local_path = ctx.file_access.get_random_local_path() + ".pt"
        pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

        to_save = {}
        for field in fields(python_val):
            value = getattr(python_val, field.name)

            if value and field.name in ["module", "optimizer"]:
                to_save[field.name + "_state_dict"] = getattr(value, "state_dict")()
            elif value and field.name == "hyperparameters":
                if isinstance(value, dict):
                    to_save.update(value)
                elif isinstance(value, tuple):
                    to_save.update(value._asdict())
                elif is_dataclass(value):
                    to_save.update(asdict(value))

        if not to_save:
            raise TypeTransformerFailedError(f"Cannot save empty {python_val}")

        # save checkpoint to a file
        torch.save(to_save, local_path)

        remote_path = ctx.file_access.get_random_remote_path(local_path)
        ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))
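# A hedged usage sketch: what the field loop above would serialize for a small
# module/optimizer pair. The PyTorchCheckpoint fields (module, optimizer,
# hyperparameters) are inferred from the checks in to_literal itself.
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
opt = optim.SGD(model.parameters(), lr=0.01)
ckpt = PyTorchCheckpoint(module=model, optimizer=opt, hyperparameters={"lr": 0.01})
# to_save would then hold "module_state_dict", "optimizer_state_dict", and "lr".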
Example #11
    def __init__(
        self,
        max_number_of_training_jobs: int,
        max_parallel_training_jobs: int,
        training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask,
                                   CustomTrainingJobTask],
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
        tunable_parameters: typing.List[str] = None,
    ):
        """
        :param int max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
        hyperparameter tuning job
        :param int max_parallel_training_jobs: The maximum number of training jobs that can be launched by this hyperparameter
        tuning job in parallel
        :param typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask] training_job: The reference to the training job definition
        :param int retries: Number of retries to attempt
        :param bool cacheable: The flag to set if the user wants the output of the task execution to be cached
        :param str cache_version: String describing the caching version for task discovery purposes
        :param typing.List[str] tunable_parameters: A list of parameters to tune. If you are tuning a built-in
                algorithm, refer to the algorithm's documentation to understand the possible values for the tunable
                parameters. E.g. Refer to https://docs.aws.amazon.com/sagemaker/latest/dg/IC-Hyperparameter.html for the
                list of hyperparameters for Image Classification built-in algorithm. If you are passing a custom
                training job, the list of tunable parameters must be a strict subset of the list of inputs defined on
                that job. Refer to https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html
                for the list of supported hyperparameter types.
        """
        # Use the training job model as a means of type checking
        hpo_job = _hpo_job_model.HyperparameterTuningJob(
            max_number_of_training_jobs=max_number_of_training_jobs,
            max_parallel_training_jobs=max_parallel_training_jobs,
            training_job=training_job.training_job_model,
        ).to_flyte_idl()

        # Setting the Flyte-level timeout to 0 and letting SageMaker respect the StoppingCondition of
        #   the underlying training job
        # TODO: Discuss whether this is a viable interface or contract
        timeout = _datetime.timedelta(seconds=0)

        inputs = {}
        inputs.update(training_job.interface.inputs)
        inputs.update({
            "hyperparameter_tuning_job_config":
            _interface_model.Variable(
                HyperparameterTuningJobConfig.to_flyte_literal_type(),
                "",
            ),
        })

        if tunable_parameters:
            inputs.update({
                param: _interface_model.Variable(
                    ParameterRange.to_flyte_literal_type(), "")
                for param in tunable_parameters
            })

        super().__init__(
            type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
            metadata=_task_models.TaskMetadata(
                runtime=_task_models.RuntimeMetadata(
                    type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    version=__version__,
                    flavor="sagemaker",
                ),
                discoverable=cacheable,
                timeout=timeout,
                retries=_literal_models.RetryStrategy(retries=retries),
                interruptible=False,
                discovery_version=cache_version,
                deprecated_error_message="",
            ),
            interface=_interface.TypedInterface(
                inputs=inputs,
                outputs={
                    "model":
                    _interface_model.Variable(
                        type=_types_models.LiteralType(
                            blob=_core_types.BlobType(
                                format="",
                                dimensionality=_core_types.BlobType.
                                BlobDimensionality.SINGLE,
                            )),
                        description="",
                    )
                },
            ),
            custom=MessageToDict(hpo_job),
        )
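# A minimal usage sketch (not from the flytekit source), reusing the
# SdkBuiltinAlgorithmTrainingJobTask built in Example #20 below; the tunable
# parameter names are illustrative XGBoost hyperparameters, not taken from
# this file.
hpo_task = SdkSimpleHyperparameterTuningJobTask(
    max_number_of_training_jobs=10,
    max_parallel_training_jobs=5,
    training_job=builtin_algorithm_training_job_task,
    retries=2,
    cacheable=True,
    cache_version="1",
    tunable_parameters=["num_round", "max_depth"],
)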
Example #12
 def _blob_type(self, format: str) -> _core_types.BlobType:
     return _core_types.BlobType(
         format=format,
         dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE)
Example #13
 def get_literal_type(self, t: Type[np.ndarray]) -> LiteralType:
     return LiteralType(blob=_core_types.BlobType(
         format=self.NUMPY_ARRAY_FORMAT,
         dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE))
Example #14
 def _blob_type(self) -> _core_types.BlobType:
     return _core_types.BlobType(
         format=mimetypes.types_map[".bin"],
         dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
     )
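# Note: on a standard Python install, mimetypes.types_map[".bin"] resolves to
# "application/octet-stream", so the blob format above is the generic binary
# MIME type.
import mimetypes
assert mimetypes.types_map[".bin"] == "application/octet-stream"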
Example #15
    types.LiteralType(schema=types.SchemaType([
        types.SchemaType.SchemaColumn(
            "a", types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER),
        types.SchemaType.SchemaColumn(
            "b", types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN),
        types.SchemaType.SchemaColumn(
            "c", types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME),
        types.SchemaType.SchemaColumn(
            "d", types.SchemaType.SchemaColumn.SchemaColumnType.DURATION),
        types.SchemaType.SchemaColumn(
            "e", types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT),
        types.SchemaType.SchemaColumn(
            "f", types.SchemaType.SchemaColumn.SchemaColumnType.STRING),
    ])),
    types.LiteralType(blob=_core_types.BlobType(
        format="",
        dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
    )),
    types.LiteralType(blob=_core_types.BlobType(
        format="csv",
        dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
    )),
    types.LiteralType(blob=_core_types.BlobType(
        format="",
        dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
    )),
    types.LiteralType(blob=_core_types.BlobType(
        format="csv",
        dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
    ))
]
Example #16
    def __init__(
        self,
        training_job_resource_config: _training_job_models.
        TrainingJobResourceConfig,
        algorithm_specification: _training_job_models.AlgorithmSpecification,
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
    ):
        """

        :param training_job_resource_config: The options to configure the training job
        :param algorithm_specification: The options to configure the target algorithm of the training
        :param retries: Number of retries to attempt
        :param cacheable: The flag to set if the user wants the output of the task execution to be cached
        :param cache_version: String describing the caching version for task discovery purposes
        """
        # Use the training job model as a means of type checking
        self._training_job_model = _training_job_models.TrainingJob(
            algorithm_specification=algorithm_specification,
            training_job_resource_config=training_job_resource_config,
        )

        # Setting the Flyte-level timeout to 0 and letting SageMaker take the StoppingCondition and terminate the
        # training job gracefully
        timeout = _datetime.timedelta(seconds=0)

        super(SdkBuiltinAlgorithmTrainingJobTask, self).__init__(
            type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
            metadata=_task_models.TaskMetadata(
                runtime=_task_models.RuntimeMetadata(
                    type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    version=__version__,
                    flavor="sagemaker",
                ),
                discoverable=cacheable,
                timeout=timeout,
                retries=_literal_models.RetryStrategy(retries=retries),
                interruptible=False,
                discovery_version=cache_version,
                deprecated_error_message="",
            ),
            interface=_interface.TypedInterface(
                inputs={
                    "static_hyperparameters":
                    _interface_model.Variable(
                        type=_idl_types.LiteralType(
                            simple=_idl_types.SimpleType.STRUCT),
                        description="",
                    ),
                    "train":
                    _interface_model.Variable(
                        type=_idl_types.LiteralType(blob=_core_types.BlobType(
                            format=_content_type_to_blob_format(
                                algorithm_specification.input_content_type),
                            dimensionality=_core_types.BlobType.
                            BlobDimensionality.MULTIPART,
                        ), ),
                        description="",
                    ),
                    "validation":
                    _interface_model.Variable(
                        type=_idl_types.LiteralType(blob=_core_types.BlobType(
                            format=_content_type_to_blob_format(
                                algorithm_specification.input_content_type),
                            dimensionality=_core_types.BlobType.
                            BlobDimensionality.MULTIPART,
                        ), ),
                        description="",
                    ),
                },
                outputs={
                    "model":
                    _interface_model.Variable(
                        type=_idl_types.LiteralType(blob=_core_types.BlobType(
                            format="",
                            dimensionality=_core_types.BlobType.
                            BlobDimensionality.SINGLE,
                        )),
                        description="",
                    )
                },
            ),
            custom=MessageToDict(self._training_job_model.to_flyte_idl()),
        )
Example #17
 def _blob_type(format: str) -> _core_types.BlobType:
     return _core_types.BlobType(
         format=format,
         dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART)
Example #18
 def get_literal_type(self, t: Type[T]) -> LiteralType:
     return LiteralType(
         blob=_core_types.BlobType(
             format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
         )
     )
Example #19
    def __init__(
        self,
        max_number_of_training_jobs: int,
        max_parallel_training_jobs: int,
        training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask,
                                   CustomTrainingJobTask],
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
    ):
        """

        :param max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
        hyperparameter tuning job
        :param max_parallel_training_jobs: The maximum number of training jobs that can be launched by this hyperparameter
        tuning job in parallel
        :param training_job: The reference to the training job definition
        :param retries: Number of retries to attempt
        :param cacheable: The flag to set if the user wants the output of the task execution to be cached
        :param cache_version: String describing the caching version for task discovery purposes
        """
        # Use the training job model as a means of type checking
        hpo_job = _hpo_job_model.HyperparameterTuningJob(
            max_number_of_training_jobs=max_number_of_training_jobs,
            max_parallel_training_jobs=max_parallel_training_jobs,
            training_job=training_job.training_job_model,
        ).to_flyte_idl()

        # Setting the Flyte-level timeout to 0 and letting SageMaker respect the StoppingCondition of
        #   the underlying training job
        # TODO: Discuss whether this is a viable interface or contract
        timeout = _datetime.timedelta(seconds=0)

        inputs = {
            "hyperparameter_tuning_job_config":
            _interface_model.Variable(
                _sdk_types.Types.Proto(
                    _pb2_hpo_job.HyperparameterTuningJobConfig).
                to_flyte_literal_type(),
                "",
            ),
        }
        inputs.update(training_job.interface.inputs)

        super(SdkSimpleHyperparameterTuningJobTask, self).__init__(
            type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
            metadata=_task_models.TaskMetadata(
                runtime=_task_models.RuntimeMetadata(
                    type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                    version=__version__,
                    flavor="sagemaker",
                ),
                discoverable=cacheable,
                timeout=timeout,
                retries=_literal_models.RetryStrategy(retries=retries),
                interruptible=False,
                discovery_version=cache_version,
                deprecated_error_message="",
            ),
            interface=_interface.TypedInterface(
                inputs=inputs,
                outputs={
                    "model":
                    _interface_model.Variable(
                        type=_types_models.LiteralType(
                            blob=_core_types.BlobType(
                                format="",
                                dimensionality=_core_types.BlobType.
                                BlobDimensionality.SINGLE,
                            )),
                        description="",
                    )
                },
            ),
            custom=MessageToDict(hpo_job),
        )
Example #20
def test_builtin_algorithm_training_job_task():
    builtin_algorithm_training_job_task = SdkBuiltinAlgorithmTrainingJobTask(
        training_job_resource_config=TrainingJobResourceConfig(
            instance_type="ml.m4.xlarge",
            instance_count=1,
            volume_size_in_gb=25,
        ),
        algorithm_specification=AlgorithmSpecification(
            input_mode=InputMode.FILE,
            input_content_type=InputContentType.TEXT_CSV,
            algorithm_name=AlgorithmName.XGBOOST,
            algorithm_version="0.72",
        ),
    )

    builtin_algorithm_training_job_task._id = _identifier.Identifier(
        _identifier.ResourceType.TASK, "my_project", "my_domain", "my_name",
        "my_version")
    assert isinstance(builtin_algorithm_training_job_task,
                      SdkBuiltinAlgorithmTrainingJobTask)
    assert isinstance(builtin_algorithm_training_job_task, _sdk_task.SdkTask)
    assert builtin_algorithm_training_job_task.interface.inputs[
        "train"].description == ""
    assert builtin_algorithm_training_job_task.interface.inputs[
        "train"].type == _idl_types.LiteralType(blob=_core_types.BlobType(
            format="csv",
            dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
        ))
    assert (builtin_algorithm_training_job_task.interface.inputs["train"].type
            == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type())
    assert builtin_algorithm_training_job_task.interface.inputs[
        "validation"].description == ""
    assert (builtin_algorithm_training_job_task.interface.inputs["validation"].
            type == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type())
    assert builtin_algorithm_training_job_task.interface.inputs[
        "validation"].type == _idl_types.LiteralType(blob=_core_types.BlobType(
            format="csv",
            dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
        ))
    assert builtin_algorithm_training_job_task.interface.inputs[
        "static_hyperparameters"].description == ""
    assert (builtin_algorithm_training_job_task.interface.
            inputs["static_hyperparameters"].type ==
            _sdk_types.Types.Generic.to_flyte_literal_type())
    assert builtin_algorithm_training_job_task.interface.outputs[
        "model"].description == ""
    assert (builtin_algorithm_training_job_task.interface.outputs["model"].type
            == _sdk_types.Types.Blob.to_flyte_literal_type())
    assert builtin_algorithm_training_job_task.type == _common_constants.SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK
    assert builtin_algorithm_training_job_task.metadata.timeout == _datetime.timedelta(
        seconds=0)
    assert builtin_algorithm_training_job_task.metadata.deprecated_error_message == ""
    assert builtin_algorithm_training_job_task.metadata.discoverable is False
    assert builtin_algorithm_training_job_task.metadata.discovery_version == ""
    assert builtin_algorithm_training_job_task.metadata.retries.retries == 0
    assert "metricDefinitions" not in builtin_algorithm_training_job_task.custom[
        "algorithmSpecification"].keys()

    ParseDict(
        builtin_algorithm_training_job_task.
        custom["trainingJobResourceConfig"],
        _pb2_TrainingJobResourceConfig(),
    )  # fails the test if it cannot be parsed
Example #21
 def get_literal_type(self, t: Type[PyTorchCheckpoint]) -> LiteralType:
     return LiteralType(
         blob=_core_types.BlobType(
             format=self.PYTORCH_CHECKPOINT_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
         )
     )