示例#1
0
def test_training_job_resource_config():
    """Round-trip a TrainingJobResourceConfig through its Flyte IDL form.

    Also checks that a config differing from the original in only
    ``distributed_protocol``, or only ``instance_type``, compares unequal.
    """
    base_kwargs = dict(
        instance_count=1,
        instance_type="random.instance",
        volume_size_in_gb=25,
        distributed_protocol=training_job.DistributedProtocol.MPI,
    )
    original = training_job.TrainingJobResourceConfig(**base_kwargs)

    round_tripped = training_job.TrainingJobResourceConfig.from_flyte_idl(original.to_flyte_idl())
    assert round_tripped == original
    assert round_tripped.distributed_protocol == training_job.DistributedProtocol.MPI

    # Each variant overrides exactly one field of the base config; none may equal the original.
    single_field_overrides = (
        {"distributed_protocol": training_job.DistributedProtocol.UNSPECIFIED},
        {"instance_type": "oops"},
    )
    for override in single_field_overrides:
        variant = training_job.TrainingJobResourceConfig(**dict(base_kwargs, **override))
        assert original != variant
示例#2
0
def test_training_job():
    """Round-trip a TrainingJob through its Flyte IDL form and compare its fields.

    TrainingJob does not define ``__eq__``, so comparing the two objects directly
    would return False. Instead, every field of the nested resource config and
    algorithm specification is compared individually.
    """
    resource_config = training_job.TrainingJobResourceConfig(
        instance_type="test_type",
        instance_count=10,
        volume_size_in_gb=25,
        distributed_protocol=training_job.DistributedProtocol.MPI,
    )
    algorithm = training_job.AlgorithmSpecification(
        algorithm_name=training_job.AlgorithmName.CUSTOM,
        algorithm_version="",
        input_mode=training_job.InputMode.FILE,
        input_content_type=training_job.InputContentType.TEXT_CSV,
    )
    original = training_job.TrainingJob(
        training_job_resource_config=resource_config,
        algorithm_specification=algorithm,
    )

    restored = training_job.TrainingJob.from_flyte_idl(original.to_flyte_idl())

    # (sub-object attribute, fields to compare on it), checked in this order.
    fields_by_section = (
        (
            "training_job_resource_config",
            ("instance_type", "instance_count", "distributed_protocol", "volume_size_in_gb"),
        ),
        (
            "algorithm_specification",
            ("algorithm_name", "algorithm_version", "input_mode", "input_content_type"),
        ),
    )
    for section, field_names in fields_by_section:
        for field_name in field_names:
            expected = getattr(getattr(original, section), field_name)
            actual = getattr(getattr(restored, section), field_name)
            assert expected == actual, (section, field_name)
示例#3
0
def test_hyperparameter_tuning_job():
    """Round-trip a HyperparameterTuningJob through its Flyte IDL form.

    Checks that both tuning limits survive the conversion. It also checks every
    field of the nested TrainingJob's resource config and algorithm specification.
    """
    resource_config = training_job.TrainingJobResourceConfig(
        instance_type="test_type",
        instance_count=10,
        volume_size_in_gb=25,
        distributed_protocol=training_job.DistributedProtocol.MPI,
    )
    algorithm = training_job.AlgorithmSpecification(
        algorithm_name=training_job.AlgorithmName.CUSTOM,
        algorithm_version="",
        input_mode=training_job.InputMode.FILE,
        input_content_type=training_job.InputContentType.TEXT_CSV,
    )
    nested_job = training_job.TrainingJob(
        training_job_resource_config=resource_config,
        algorithm_specification=algorithm,
    )
    original = hpo_job.HyperparameterTuningJob(
        max_number_of_training_jobs=10,
        max_parallel_training_jobs=5,
        training_job=nested_job,
    )

    restored = hpo_job.HyperparameterTuningJob.from_flyte_idl(original.to_flyte_idl())

    # Top-level tuning limits.
    for limit_name in ("max_number_of_training_jobs", "max_parallel_training_jobs"):
        assert getattr(original, limit_name) == getattr(restored, limit_name), limit_name

    # Fields of the nested training job, grouped by the sub-object that holds them.
    nested_fields = (
        (
            "training_job_resource_config",
            ("instance_type", "instance_count", "distributed_protocol", "volume_size_in_gb"),
        ),
        (
            "algorithm_specification",
            ("algorithm_name", "algorithm_version", "input_mode", "input_content_type"),
        ),
    )
    for section, field_names in nested_fields:
        restored_part = getattr(restored.training_job, section)
        original_part = getattr(original.training_job, section)
        for field_name in field_names:
            assert getattr(restored_part, field_name) == getattr(original_part, field_name), (section, field_name)
示例#4
0
@inputs(dummy_train_dataset=Types.Blob,
        dummy_validation_dataset=Types.Blob,
        my_input=Types.String)
@outputs(out_model=Types.Blob,
         out=Types.Integer,
         out_extra_output_file=Types.Blob)
@custom_training_job_task(
    algorithm_specification=training_job_models.AlgorithmSpecification(
        input_mode=training_job_models.InputMode.FILE,
        algorithm_name=training_job_models.AlgorithmName.CUSTOM,
        algorithm_version="",
        input_content_type=training_job_models.InputContentType.TEXT_CSV,
    ),
    training_job_resource_config=training_job_models.TrainingJobResourceConfig(
        instance_type="ml.m4.xlarge",
        instance_count=1,
        volume_size_in_gb=25,
    ))
def custom_training_task(wf_params, dummy_train_dataset,
                         dummy_validation_dataset, my_input, out_model, out,
                         out_extra_output_file):
    with _common_utils.AutoDeletingTempDir("output_dir") as output_dir:
        wf_params.logging.info("My printed value: {}".format(my_input))
        wf_params.logging.info(
            "My dummy train_dataset: {}".format(dummy_train_dataset))
        wf_params.logging.info(
            "My dummy train_dataset: {}".format(dummy_validation_dataset))

        mnist = tf.keras.datasets.mnist

        (x_train, y_train), (x_test, y_test) = mnist.load_data()