def test_training_job_resource_config():
    """Round-trip TrainingJobResourceConfig through its flyte IDL form and check equality semantics."""
    config = training_job.TrainingJobResourceConfig(
        instance_count=1,
        instance_type="random.instance",
        volume_size_in_gb=25,
        distributed_protocol=training_job.DistributedProtocol.MPI,
    )

    # Serializing to protobuf and deserializing must preserve every field.
    restored = training_job.TrainingJobResourceConfig.from_flyte_idl(config.to_flyte_idl())
    assert restored == config
    assert restored.distributed_protocol == training_job.DistributedProtocol.MPI

    # Changing the distributed protocol must break equality.
    assert config != training_job.TrainingJobResourceConfig(
        instance_count=1,
        instance_type="random.instance",
        volume_size_in_gb=25,
        distributed_protocol=training_job.DistributedProtocol.UNSPECIFIED,
    )

    # Changing the instance type must break equality.
    assert config != training_job.TrainingJobResourceConfig(
        instance_count=1,
        instance_type="oops",
        volume_size_in_gb=25,
        distributed_protocol=training_job.DistributedProtocol.MPI,
    )
def test_training_job():
    """Verify that TrainingJob survives a to_flyte_idl/from_flyte_idl round trip field by field."""
    resources = training_job.TrainingJobResourceConfig(
        instance_type="test_type",
        instance_count=10,
        volume_size_in_gb=25,
        distributed_protocol=training_job.DistributedProtocol.MPI,
    )
    spec = training_job.AlgorithmSpecification(
        algorithm_name=training_job.AlgorithmName.CUSTOM,
        algorithm_version="",
        input_mode=training_job.InputMode.FILE,
        input_content_type=training_job.InputContentType.TEXT_CSV,
    )
    original = training_job.TrainingJob(
        training_job_resource_config=resources,
        algorithm_specification=spec,
    )
    restored = training_job.TrainingJob.from_flyte_idl(original.to_flyte_idl())

    # TrainingJob does not define __eq__, so compare the nested fields directly.
    for attr in ("instance_type", "instance_count", "distributed_protocol", "volume_size_in_gb"):
        assert getattr(original.training_job_resource_config, attr) == getattr(
            restored.training_job_resource_config, attr
        )
    for attr in ("algorithm_name", "algorithm_version", "input_mode", "input_content_type"):
        assert getattr(original.algorithm_specification, attr) == getattr(
            restored.algorithm_specification, attr
        )
def test_hyperparameter_tuning_job():
    """Verify that HyperparameterTuningJob and its embedded TrainingJob survive an IDL round trip."""
    resources = training_job.TrainingJobResourceConfig(
        instance_type="test_type",
        instance_count=10,
        volume_size_in_gb=25,
        distributed_protocol=training_job.DistributedProtocol.MPI,
    )
    spec = training_job.AlgorithmSpecification(
        algorithm_name=training_job.AlgorithmName.CUSTOM,
        algorithm_version="",
        input_mode=training_job.InputMode.FILE,
        input_content_type=training_job.InputContentType.TEXT_CSV,
    )
    inner_job = training_job.TrainingJob(
        training_job_resource_config=resources,
        algorithm_specification=spec,
    )
    original = hpo_job.HyperparameterTuningJob(
        max_number_of_training_jobs=10,
        max_parallel_training_jobs=5,
        training_job=inner_job,
    )
    restored = hpo_job.HyperparameterTuningJob.from_flyte_idl(original.to_flyte_idl())

    # Top-level tuning-job limits must be preserved.
    assert original.max_number_of_training_jobs == restored.max_number_of_training_jobs
    assert original.max_parallel_training_jobs == restored.max_parallel_training_jobs

    # The embedded TrainingJob has no __eq__; compare its nested fields one by one.
    for attr in ("instance_type", "instance_count", "distributed_protocol", "volume_size_in_gb"):
        assert getattr(restored.training_job.training_job_resource_config, attr) == getattr(
            original.training_job.training_job_resource_config, attr
        )
    for attr in ("algorithm_name", "algorithm_version", "input_mode", "input_content_type"):
        assert getattr(restored.training_job.algorithm_specification, attr) == getattr(
            original.training_job.algorithm_specification, attr
        )
@inputs(dummy_train_dataset=Types.Blob, dummy_validation_dataset=Types.Blob, my_input=Types.String) @outputs(out_model=Types.Blob, out=Types.Integer, out_extra_output_file=Types.Blob) @custom_training_job_task( algorithm_specification=training_job_models.AlgorithmSpecification( input_mode=training_job_models.InputMode.FILE, algorithm_name=training_job_models.AlgorithmName.CUSTOM, algorithm_version="", input_content_type=training_job_models.InputContentType.TEXT_CSV, ), training_job_resource_config=training_job_models.TrainingJobResourceConfig( instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=25, )) def custom_training_task(wf_params, dummy_train_dataset, dummy_validation_dataset, my_input, out_model, out, out_extra_output_file): with _common_utils.AutoDeletingTempDir("output_dir") as output_dir: wf_params.logging.info("My printed value: {}".format(my_input)) wf_params.logging.info( "My dummy train_dataset: {}".format(dummy_train_dataset)) wf_params.logging.info( "My dummy train_dataset: {}".format(dummy_validation_dataset)) mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data()