Example #1
0
def test_hpo_for_builtin():
    """Wrap a builtin XGBoost trainer in an HPO task and check its contract.

    Checks the HPO task's input/output names and its serialized custom payload.
    Also checks that invoking it locally raises ``NotImplementedError``.
    """
    resource_config = TrainingJobResourceConfig(
        instance_count=1,
        instance_type="ml-xlarge",
        volume_size_in_gb=1,
    )
    algo_spec = AlgorithmSpecification(algorithm_name=AlgorithmName.XGBOOST)
    trainer = SagemakerBuiltinAlgorithmsTask(
        name="builtin-trainer",
        task_config=SagemakerTrainingJobConfig(
            training_job_resource_config=resource_config,
            algorithm_specification=algo_spec,
        ),
    )

    # HPOJob(10, 10, ["x"]): the custom payload below expects both job limits
    # to be "10"; "x" shows up as an extra tunable input.
    hpo = SagemakerHPOTask(
        name="test",
        task_config=HPOJob(10, 10, ["x"]),
        training_task=trainer,
    )

    expected_input_names = {
        "static_hyperparameters",
        "train",
        "validation",
        "hyperparameter_tuning_job_config",
        "x",
    }
    assert hpo.python_interface.inputs.keys() == expected_input_names
    assert hpo.python_interface.outputs.keys() == {"model"}

    expected_custom = {
        "maxNumberOfTrainingJobs": "10",
        "maxParallelTrainingJobs": "10",
        "trainingJob": {
            "algorithmSpecification": {"algorithmName": "XGBOOST"},
            "trainingJobResourceConfig": {
                "instanceCount": "1",
                "instanceType": "ml-xlarge",
                "volumeSizeInGb": "1",
            },
        },
    }
    assert hpo.get_custom(_get_reg_settings()) == expected_custom

    with pytest.raises(NotImplementedError):
        # Built inside the raises-block so any error from construction is
        # treated exactly as it would be for the call itself.
        tuning_config = HyperparameterTuningJobConfig(
            tuning_strategy=1,
            tuning_objective=HyperparameterTuningObjective(
                objective_type=HyperparameterTuningObjectiveType.MINIMIZE,
                metric_name="x",
            ),
            training_job_early_stopping_type=TrainingJobEarlyStoppingType.OFF,
        )
        hpo(
            static_hyperparameters={},
            train="",
            validation="",
            hyperparameter_tuning_job_config=tuning_config,
            x=ParameterRangeOneOf(param=IntegerParameterRange(10, 1, 1)),
        )
def test_custom_training():
    """A ``@task`` with a custom-algorithm SageMaker config runs locally and serializes its config."""
    training_config = SagemakerTrainingJobConfig(
        training_job_resource_config=TrainingJobResourceConfig(
            instance_type="ml-xlarge",
            volume_size_in_gb=1,
        ),
        algorithm_specification=AlgorithmSpecification(algorithm_name=AlgorithmName.CUSTOM),
    )

    @task(task_config=training_config)
    def my_custom_trainer(x: int) -> int:
        return x

    interface = my_custom_trainer.python_interface
    assert interface.inputs == {"x": int}
    assert interface.outputs == {"o0": int}

    # Local execution simply runs the Python function.
    assert my_custom_trainer(x=10) == 10

    # instance_count is not passed above, yet "instanceCount": "1" is expected
    # in the serialized resource config.
    expected_custom = {
        "algorithmSpecification": {},
        "trainingJobResourceConfig": {
            "instanceCount": "1",
            "instanceType": "ml-xlarge",
            "volumeSizeInGb": "1",
        },
    }
    assert my_custom_trainer.get_custom(_get_reg_settings()) == expected_custom
    def test_if_wf_param_has_dist_context(self):
        """The task function must receive a non-empty ``wf_params.distributed_training_context``.

        Exercises ``CustomTrainingJobTask._execute_user_code()`` by running a
        two-instance custom training task inside a simulated SageMaker host env.
        """
        # Pretend to be host "algo-1" of a three-host cluster; clear=True hides
        # every other environment variable while the block runs.
        fake_sm_env = {
            _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-1",
            _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0", "algo-1", "algo-2"]',
            _sm_distribution.SM_ENV_VAR_NETWORK_INTERFACE_NAME: "eth0",
        }
        with mock.patch.dict(os.environ, fake_sm_env, clear=True):
            resource_config = TrainingJobResourceConfig(
                instance_type="ml.m4.xlarge",
                instance_count=2,
                volume_size_in_gb=25,
            )
            algo_spec = AlgorithmSpecification(
                input_mode=InputMode.FILE,
                input_content_type=InputContentType.TEXT_CSV,
                metric_definitions=[
                    MetricDefinition(name="Validation error", regex="validation:error"),
                ],
            )

            @inputs(input_1=Types.Integer)
            @outputs(model=Types.Blob)
            @custom_training_job_task(
                training_job_resource_config=resource_config,
                algorithm_specification=algo_spec,
            )
            def my_distributed_task_with_valid_dist_training_context(wf_params, input_1, model):
                # Signal a missing/empty context back to the test via ValueError.
                if not wf_params.distributed_training_context:
                    raise ValueError

            task_under_test = my_distributed_task_with_valid_dist_training_context
            try:
                task_under_test.execute(self._context, self._task_input)
            except ValueError:
                self.fail(
                    "The distributed_training_context is not passed into task function successfully"
                )
    def setUp(self):
        """Create the shared engine context, task inputs and default distributed task.

        ``self._context`` stores ``input_dir.name`` as its ``tmp_dir``. The
        temporary-directory context manager is therefore kept open for the whole
        test and exited only by a registered cleanup. Exiting it when ``setUp``
        returns (as a ``with`` block does) would leave the context pointing at a
        directory that the auto-delete cleanup may already have removed.
        """
        import contextlib

        stack = contextlib.ExitStack()
        # Registered first so the temp dir is released even if setUp fails below;
        # unittest runs cleanups after a failing setUp as well.
        self.addCleanup(stack.close)
        input_dir = stack.enter_context(_utils.AutoDeletingTempDir("input_dir"))

        self._task_input = _literals.LiteralMap({
            "input_1": _literals.Literal(
                scalar=_literals.Scalar(primitive=_literals.Primitive(integer=1))
            )
        })

        self._context = _common_engine.EngineContext(
            execution_id=WorkflowExecutionIdentifier(
                project="unit_test", domain="unit_test", name="unit_test"
            ),
            execution_date=_datetime.datetime.utcnow(),
            stats=MockStats(),
            logging=None,
            tmp_dir=input_dir.name,
        )

        # Distributed training task declared without an output-persist
        # predicate, so the default one is used.
        @inputs(input_1=Types.Integer)
        @outputs(model=Types.Blob)
        @custom_training_job_task(
            training_job_resource_config=TrainingJobResourceConfig(
                instance_type="ml.m4.xlarge",
                instance_count=2,
                volume_size_in_gb=25,
            ),
            algorithm_specification=AlgorithmSpecification(
                input_mode=InputMode.FILE,
                input_content_type=InputContentType.TEXT_CSV,
                metric_definitions=[
                    MetricDefinition(name="Validation error", regex="validation:error")
                ],
            ),
        )
        def my_distributed_task(wf_params, input_1, model):
            pass

        self._my_distributed_task = my_distributed_task
        # Exact-type check: the decorator stack must produce a CustomTrainingJobTask itself.
        self.assertIs(type(self._my_distributed_task), CustomTrainingJobTask)
def test_distributed_custom_training():
    """A two-instance MPI custom trainer is distributed and exposes a training context."""
    setup_envars_for_testing()

    resource_config = TrainingJobResourceConfig(
        instance_type="ml-xlarge",
        volume_size_in_gb=1,
        instance_count=2,  # more than one instance => distributed training
        distributed_protocol=DistributedProtocol.MPI,
    )
    training_config = SagemakerTrainingJobConfig(
        training_job_resource_config=resource_config,
        algorithm_specification=AlgorithmSpecification(algorithm_name=AlgorithmName.CUSTOM),
    )

    @task(task_config=training_config)
    def my_custom_trainer(x: int) -> int:
        ctx = flytekit.current_context()
        assert ctx.distributed_training_context is not None
        return x

    assert my_custom_trainer.python_interface.inputs == {"x": int}
    assert my_custom_trainer.python_interface.outputs == {"o0": int}

    assert my_custom_trainer(x=10) == 10
    assert my_custom_trainer._is_distributed() is True

    # pre_execute must return parameters that carry the distributed context.
    builder = ExecutionParameters.new_builder()
    builder.working_dir = "/tmp"
    updated_params = my_custom_trainer.pre_execute(builder.build())
    assert updated_params is not None
    assert updated_params.has_attr("distributed_training_context")

    expected_custom = {
        "algorithmSpecification": {},
        "trainingJobResourceConfig": {
            "distributedProtocol": "MPI",
            "instanceCount": "2",
            "instanceType": "ml-xlarge",
            "volumeSizeInGb": "1",
        },
    }
    assert my_custom_trainer.get_custom(_get_reg_settings()) == expected_custom
def test_custom_training_job():
    """The legacy ``custom_training_job_task`` decorator stack yields a ``CustomTrainingJobTask``."""
    validation_metric = MetricDefinition(name="Validation error", regex="validation:error")
    resource_config = TrainingJobResourceConfig(
        instance_type="ml.m4.xlarge",
        instance_count=1,
        volume_size_in_gb=25,
    )
    algo_spec = AlgorithmSpecification(
        input_mode=InputMode.FILE,
        input_content_type=InputContentType.TEXT_CSV,
        metric_definitions=[validation_metric],
    )

    @inputs(input_1=Types.Integer)
    @outputs(model=Types.Blob)
    @custom_training_job_task(
        training_job_resource_config=resource_config,
        algorithm_specification=algo_spec,
    )
    def my_task(wf_params, input_1, model):
        pass

    assert type(my_task) == CustomTrainingJobTask
def test_builtin_training():
    """Builtin XGBoost trainer: fixed interface, local call raises, serialized config."""
    algo_spec = AlgorithmSpecification(algorithm_name=AlgorithmName.XGBOOST)
    trainer = SagemakerBuiltinAlgorithmsTask(
        name="builtin-trainer",
        task_config=SagemakerTrainingJobConfig(
            training_job_resource_config=TrainingJobResourceConfig(
                instance_count=1,
                instance_type="ml-xlarge",
                volume_size_in_gb=1,
            ),
            algorithm_specification=algo_spec,
        ),
    )

    expected_inputs = {"static_hyperparameters", "train", "validation"}
    assert trainer.python_interface.inputs.keys() == expected_inputs
    assert trainer.python_interface.outputs.keys() == {"model"}

    with tempfile.TemporaryDirectory() as tmp:
        # Two small placeholder files used as the train/validation inputs.
        train_path = os.path.join(tmp, "x")
        validation_path = os.path.join(tmp, "y")
        for path in (train_path, validation_path):
            with open(path, "w") as fh:
                fh.write("test")
        # Calling the trainer locally is expected to raise NotImplementedError.
        with pytest.raises(NotImplementedError):
            trainer(static_hyperparameters={}, train=train_path, validation=validation_path)

    expected_custom = {
        "algorithmSpecification": {"algorithmName": "XGBOOST"},
        "trainingJobResourceConfig": {
            "instanceCount": "1",
            "instanceType": "ml-xlarge",
            "volumeSizeInGb": "1",
        },
    }
    assert trainer.get_custom(_get_reg_settings()) == expected_custom
def test_builtin_algorithm_training_job_task():
    """Check the interface, metadata and custom payload of a legacy builtin-algorithm task.

    Checks three things:
    - Every declared input/output has an empty description and the expected
      Flyte literal type.
    - ``custom`` has no ``metricDefinitions`` (none were configured).
    - ``custom["trainingJobResourceConfig"]`` parses as the protobuf message.
    """
    training_task = SdkBuiltinAlgorithmTrainingJobTask(
        training_job_resource_config=TrainingJobResourceConfig(
            instance_type="ml.m4.xlarge",
            instance_count=1,
            volume_size_in_gb=25,
        ),
        algorithm_specification=AlgorithmSpecification(
            input_mode=InputMode.FILE,
            input_content_type=InputContentType.TEXT_CSV,
            algorithm_name=AlgorithmName.XGBOOST,
            algorithm_version="0.72",
        ),
    )

    training_task._id = _identifier.Identifier(
        _identifier.ResourceType.TASK, "my_project", "my_domain", "my_name", "my_version"
    )
    assert isinstance(training_task, SdkBuiltinAlgorithmTrainingJobTask)
    assert isinstance(training_task, _sdk_task.SdkTask)

    task_inputs = training_task.interface.inputs
    task_outputs = training_task.interface.outputs
    multipart_csv_type = _idl_types.LiteralType(blob=_core_types.BlobType(
        format="csv",
        dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
    ))

    # Both dataset inputs are multipart CSV blobs: check the structural literal
    # type and the SDK-type form for each of them.
    for input_name in ("train", "validation"):
        assert task_inputs[input_name].description == ""
        assert task_inputs[input_name].type == multipart_csv_type
        assert (task_inputs[input_name].type
                == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type())

    assert task_inputs["static_hyperparameters"].description == ""
    assert (task_inputs["static_hyperparameters"].type
            == _sdk_types.Types.Generic.to_flyte_literal_type())

    assert task_outputs["model"].description == ""
    assert task_outputs["model"].type == _sdk_types.Types.Blob.to_flyte_literal_type()

    assert training_task.type == _common_constants.SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK

    metadata = training_task.metadata
    assert metadata.timeout == _datetime.timedelta(seconds=0)
    assert metadata.deprecated_error_message == ""
    assert metadata.discoverable is False
    assert metadata.discovery_version == ""
    assert metadata.retries.retries == 0

    custom = training_task.custom
    # No metric_definitions were configured, so none may be serialized.
    assert "metricDefinitions" not in custom["algorithmSpecification"]

    # ParseDict raises (failing the test) if the dict is not a valid
    # TrainingJobResourceConfig protobuf message.
    ParseDict(custom["trainingJobResourceConfig"], _pb2_TrainingJobResourceConfig())


# Builtin XGBoost 0.72 training task with a "Validation error" metric definition.
# It is the training_job handed to SdkSimpleHyperparameterTuningJobTask right after.
builtin_algorithm_training_job_task2 = SdkBuiltinAlgorithmTrainingJobTask(
    algorithm_specification=AlgorithmSpecification(
        algorithm_name=AlgorithmName.XGBOOST,
        algorithm_version="0.72",
        input_mode=InputMode.FILE,
        input_content_type=InputContentType.TEXT_CSV,
        metric_definitions=[
            MetricDefinition(name="Validation error", regex="validation:error"),
        ],
    ),
    training_job_resource_config=TrainingJobResourceConfig(
        instance_count=1,
        instance_type="ml.m4.xlarge",
        volume_size_in_gb=25,
    ),
)

simple_xgboost_hpo_job_task = hpo_job_task.SdkSimpleHyperparameterTuningJobTask(
    training_job=builtin_algorithm_training_job_task2,