def test_training_job():
    """Round-trip a ``TrainingJob`` through its Flyte IDL form and compare fields.

    The job is built from a resource config and an algorithm specification,
    converted with ``to_flyte_idl``, rebuilt with ``from_flyte_idl``, and the
    rebuilt object's fields are checked against the original's.
    """
    resource_config = training_job.TrainingJobResourceConfig(
        instance_type="test_type",
        instance_count=10,
        volume_size_in_gb=25,
        distributed_protocol=training_job.DistributedProtocol.MPI,
    )
    algorithm_spec = training_job.AlgorithmSpecification(
        algorithm_name=training_job.AlgorithmName.CUSTOM,
        algorithm_version="",
        input_mode=training_job.InputMode.FILE,
        input_content_type=training_job.InputContentType.TEXT_CSV,
    )
    original = training_job.TrainingJob(
        training_job_resource_config=resource_config,
        algorithm_specification=algorithm_spec,
    )

    idl_message = original.to_flyte_idl()
    restored = training_job.TrainingJob.from_flyte_idl(idl_message)

    # Compared field by field: per the original author's note, `original == restored`
    # would be False because no __eq__ is defined for the job model.
    resource_fields = (
        "instance_type",
        "instance_count",
        "distributed_protocol",
        "volume_size_in_gb",
    )
    for field in resource_fields:
        assert getattr(original.training_job_resource_config, field) == getattr(
            restored.training_job_resource_config, field
        )

    algorithm_fields = (
        "algorithm_name",
        "algorithm_version",
        "input_mode",
        "input_content_type",
    )
    for field in algorithm_fields:
        assert getattr(original.algorithm_specification, field) == getattr(
            restored.algorithm_specification, field
        )
Exemple #2
0
def test_hyperparameter_tuning_job():
    """Round-trip a ``HyperparameterTuningJob`` through Flyte IDL and compare fields.

    Checks the tuning job's own limits as well as every field of the nested
    training job's resource config and algorithm specification.
    """
    resource_config = training_job.TrainingJobResourceConfig(
        instance_type="test_type",
        instance_count=10,
        volume_size_in_gb=25,
        distributed_protocol=training_job.DistributedProtocol.MPI,
    )
    algorithm_spec = training_job.AlgorithmSpecification(
        algorithm_name=training_job.AlgorithmName.CUSTOM,
        algorithm_version="",
        input_mode=training_job.InputMode.FILE,
        input_content_type=training_job.InputContentType.TEXT_CSV,
    )
    inner_job = training_job.TrainingJob(
        training_job_resource_config=resource_config,
        algorithm_specification=algorithm_spec,
    )
    original = hpo_job.HyperparameterTuningJob(
        max_number_of_training_jobs=10,
        max_parallel_training_jobs=5,
        training_job=inner_job,
    )

    restored = hpo_job.HyperparameterTuningJob.from_flyte_idl(original.to_flyte_idl())

    # Top-level tuning limits.
    for field in ("max_number_of_training_jobs", "max_parallel_training_jobs"):
        assert getattr(original, field) == getattr(restored, field)

    # Nested training job: resource configuration.
    resource_fields = (
        "instance_type",
        "instance_count",
        "distributed_protocol",
        "volume_size_in_gb",
    )
    for field in resource_fields:
        assert getattr(restored.training_job.training_job_resource_config, field) == getattr(
            original.training_job.training_job_resource_config, field
        )

    # Nested training job: algorithm specification.
    algorithm_fields = (
        "algorithm_name",
        "algorithm_version",
        "input_mode",
        "input_content_type",
    )
    for field in algorithm_fields:
        assert getattr(restored.training_job.algorithm_specification, field) == getattr(
            original.training_job.algorithm_specification, field
        )
def test_metric_definition():
    """Round-trip a ``MetricDefinition`` through Flyte IDL and verify it is unchanged."""
    original = training_job.MetricDefinition(name="test-metric", regex="[a-zA-Z]*")

    idl_message = original.to_flyte_idl()
    restored = training_job.MetricDefinition.from_flyte_idl(idl_message)

    # Whole-object comparison first, then the individual fields against the literals.
    assert original == restored
    assert restored.name == "test-metric"
    assert restored.regex == "[a-zA-Z]*"


def test_algorithm_specification():
    """Round-trip an ``AlgorithmSpecification`` (with metric definitions) through Flyte IDL."""
    # A bare TestCase instance is used only to borrow assertCountEqual below.
    checker = unittest.TestCase()
    metrics = [training_job.MetricDefinition(name="a", regex="b")]
    original = training_job.AlgorithmSpecification(
        algorithm_name=training_job.AlgorithmName.CUSTOM,
        algorithm_version="v100",
        input_mode=training_job.InputMode.FILE,
        metric_definitions=metrics,
        input_content_type=training_job.InputContentType.TEXT_CSV,
    )

    idl_message = original.to_flyte_idl()
    restored = training_job.AlgorithmSpecification.from_flyte_idl(idl_message)

    assert restored.algorithm_name == training_job.AlgorithmName.CUSTOM
    assert restored.algorithm_version == "v100"
    assert restored.input_mode == training_job.InputMode.FILE
    # Order-insensitive comparison of the metric definition lists.
    checker.assertCountEqual(original.metric_definitions, restored.metric_definitions)
    assert original == restored


def test_training_job():
    rc = training_job.TrainingJobResourceConfig(
Exemple #4
0
from flytekit.models.sagemaker import training_job as training_job_models

import tensorflow as tf
from flytekit.common import utils as _common_utils


@inputs(dummy_train_dataset=Types.Blob,
        dummy_validation_dataset=Types.Blob,
        my_input=Types.String)
@outputs(out_model=Types.Blob,
         out=Types.Integer,
         out_extra_output_file=Types.Blob)
@custom_training_job_task(
    algorithm_specification=training_job_models.AlgorithmSpecification(
        input_mode=training_job_models.InputMode.FILE,
        algorithm_name=training_job_models.AlgorithmName.CUSTOM,
        algorithm_version="",
        input_content_type=training_job_models.InputContentType.TEXT_CSV,
    ),
    training_job_resource_config=training_job_models.TrainingJobResourceConfig(
        instance_type="ml.m4.xlarge",
        instance_count=1,
        volume_size_in_gb=25,
    ))
def custom_training_task(wf_params, dummy_train_dataset,
                         dummy_validation_dataset, my_input, out_model, out,
                         out_extra_output_file):
    with _common_utils.AutoDeletingTempDir("output_dir") as output_dir:
        wf_params.logging.info("My printed value: {}".format(my_input))
        wf_params.logging.info(
            "My dummy train_dataset: {}".format(dummy_train_dataset))
        wf_params.logging.info(