def test_training_step_with_profiler_config(sagemaker_session):
    """Check the request built for a training step with profiling and caching.

    A generic ``Estimator`` is configured with an explicit ``ProfilerConfig``
    (500 ms system-monitor interval) and an empty rule list, wrapped in a
    ``TrainingStep`` with a one-hour cache, and the serialised request is
    compared against the full expected payload.

    NOTE: this test used to be called ``test_training_step``, the same name as
    a later definition in this module. The later definition rebound the name,
    so pytest never collected this one. The distinct name makes it run.

    Args:
        sagemaker_session: pytest fixture supplying the SageMaker session
            handed to the estimator.
    """
    estimator = Estimator(
        image_uri=IMAGE_URI,
        role=ROLE,
        instance_count=1,
        instance_type="c4.4xlarge",
        profiler_config=ProfilerConfig(system_monitor_interval_millis=500),
        rules=[],
        sagemaker_session=sagemaker_session,
    )
    inputs = TrainingInput(f"s3://{BUCKET}/train_manifest")
    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
    step = TrainingStep(
        name="MyTrainingStep",
        estimator=estimator,
        inputs=inputs,
        cache_config=cache_config,
    )

    assert step.to_request() == {
        "Name": "MyTrainingStep",
        "Type": "Training",
        "Arguments": {
            "AlgorithmSpecification": {
                "TrainingImage": IMAGE_URI,
                "TrainingInputMode": "File",
            },
            "InputDataConfig": [
                {
                    "ChannelName": "training",
                    "DataSource": {
                        "S3DataSource": {
                            "S3DataDistributionType": "FullyReplicated",
                            "S3DataType": "S3Prefix",
                            "S3Uri": f"s3://{BUCKET}/train_manifest",
                        }
                    },
                }
            ],
            "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
            "ResourceConfig": {
                "InstanceCount": 1,
                "InstanceType": "c4.4xlarge",
                "VolumeSizeInGB": 30,
            },
            "RoleArn": ROLE,
            "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
            "ProfilerConfig": {
                "ProfilingIntervalInMilliseconds": 500,
                "S3OutputPath": f"s3://{BUCKET}/",
            },
        },
        "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
    }
    # The step's properties must resolve to a pipeline "Get" expression.
    assert step.properties.TrainingJobName.expr == {
        "Get": "Steps.MyTrainingStep.TrainingJobName"
    }
def test_training_step(sagemaker_session):
    """Check the request built for a minimal training step.

    The estimator is created with only image, role, instance settings and the
    session; no profiler configuration, rules or cache configuration are
    passed. The serialised step and its ``TrainingJobName`` property
    expression are compared with the expected values.

    Args:
        sagemaker_session: pytest fixture supplying the SageMaker session
            handed to the estimator.
    """
    train_uri = f"s3://{BUCKET}/train_manifest"
    basic_estimator = Estimator(
        image_uri=IMAGE_URI,
        role=ROLE,
        instance_count=1,
        instance_type="c4.4xlarge",
        sagemaker_session=sagemaker_session,
    )
    training_step = TrainingStep(
        name="MyTrainingStep",
        estimator=basic_estimator,
        inputs=TrainingInput(train_uri),
    )

    actual_request = training_step.to_request()

    # Expected payload, assembled piece by piece for readability.
    expected_channel = {
        "ChannelName": "training",
        "DataSource": {
            "S3DataSource": {
                "S3DataDistributionType": "FullyReplicated",
                "S3DataType": "S3Prefix",
                "S3Uri": f"s3://{BUCKET}/train_manifest",
            }
        },
    }
    expected_arguments = {
        "AlgorithmSpecification": {
            "TrainingImage": IMAGE_URI,
            "TrainingInputMode": "File",
        },
        "InputDataConfig": [expected_channel],
        "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
        "ResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "c4.4xlarge",
            "VolumeSizeInGB": 30,
        },
        "RoleArn": ROLE,
        "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
    }
    expected_request = {
        "Name": "MyTrainingStep",
        "Type": "Training",
        "Arguments": expected_arguments,
    }

    assert actual_request == expected_request
    assert training_step.properties.TrainingJobName.expr == {
        "Get": "Steps.MyTrainingStep.TrainingJobName"
    }
def test_training_step_tensorflow(sagemaker_session):
    """Check the request built for a parameterised TensorFlow training step.

    Instance type, instance count, data source URI and two hyperparameters are
    supplied as pipeline parameters. The test asserts that those parameter
    objects appear unchanged in the serialised request, alongside the
    hyperparameters for the ``smdistributed`` data-parallel distribution.

    The expected image, role and output bucket are written in terms of the
    module constants ``IMAGE_URI``, ``ROLE`` and ``BUCKET`` (the same ones used
    to build the estimator and asserted by the sibling tests) rather than
    duplicated string literals, so the expectation follows the fixtures.

    Args:
        sagemaker_session: pytest fixture supplying the SageMaker session
            handed to the estimator.
    """
    instance_type_parameter = ParameterString(
        name="InstanceType", default_value="ml.p3.16xlarge"
    )
    instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1)
    data_source_uri_parameter = ParameterString(
        name="DataSourceS3Uri", default_value=f"s3://{BUCKET}/train_manifest"
    )
    training_epochs_parameter = ParameterInteger(name="TrainingEpochs", default_value=5)
    training_batch_size_parameter = ParameterInteger(
        name="TrainingBatchSize", default_value=500
    )

    estimator = TensorFlow(
        entry_point=os.path.join(DATA_DIR, SCRIPT_FILE),
        role=ROLE,
        model_dir=False,
        image_uri=IMAGE_URI,
        source_dir="s3://mybucket/source",
        framework_version="2.4.1",
        py_version="py37",
        instance_count=instance_count_parameter,
        instance_type=instance_type_parameter,
        sagemaker_session=sagemaker_session,
        hyperparameters={
            "batch-size": training_batch_size_parameter,
            "epochs": training_epochs_parameter,
        },
        debugger_hook_config=False,
        # Training using SMDataParallel Distributed Training Framework
        distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    )

    inputs = TrainingInput(s3_data=data_source_uri_parameter)
    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
    step = TrainingStep(
        name="MyTrainingStep",
        estimator=estimator,
        inputs=inputs,
        cache_config=cache_config,
    )

    step_request = step.to_request()
    # These entries are dropped before comparison and are not part of the
    # expectation below (presumably because their values are not stable
    # between runs -- TODO confirm).
    step_request["Arguments"]["HyperParameters"].pop("sagemaker_job_name", None)
    step_request["Arguments"]["HyperParameters"].pop("sagemaker_program", None)
    step_request["Arguments"].pop("ProfilerRuleConfigurations", None)

    assert step_request == {
        "Name": "MyTrainingStep",
        "Type": "Training",
        "Arguments": {
            "AlgorithmSpecification": {
                "TrainingInputMode": "File",
                "TrainingImage": IMAGE_URI,
                "EnableSageMakerMetricsTimeSeries": True,
            },
            "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
            "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
            "ResourceConfig": {
                "InstanceCount": instance_count_parameter,
                "InstanceType": instance_type_parameter,
                "VolumeSizeInGB": 30,
            },
            "RoleArn": ROLE,
            "InputDataConfig": [
                {
                    "DataSource": {
                        "S3DataSource": {
                            "S3DataType": "S3Prefix",
                            "S3Uri": data_source_uri_parameter,
                            "S3DataDistributionType": "FullyReplicated",
                        }
                    },
                    "ChannelName": "training",
                }
            ],
            "HyperParameters": {
                "batch-size": training_batch_size_parameter,
                "epochs": training_epochs_parameter,
                "sagemaker_submit_directory": '"s3://mybucket/source"',
                "sagemaker_container_log_level": "20",
                "sagemaker_region": '"us-west-2"',
                "sagemaker_distributed_dataparallel_enabled": "true",
                "sagemaker_instance_type": instance_type_parameter,
                "sagemaker_distributed_dataparallel_custom_mpi_options": '""',
            },
            "ProfilerConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
        },
        "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
    }
    assert step.properties.TrainingJobName.expr == {
        "Get": "Steps.MyTrainingStep.TrainingJobName"
    }
def test_training_step_base_estimator(sagemaker_session):
    """Check a parameterised generic-estimator step with dependencies.

    The step is created with one dependency through the constructor and a
    second through ``add_depends_on``; both must appear, in that order, under
    ``DependsOn``. Pipeline parameters used for instance settings, data URI
    and hyperparameters must appear unchanged in the serialised request.

    Args:
        sagemaker_session: pytest fixture supplying the SageMaker session
            handed to the estimator.
    """
    # Pipeline parameters that stand in for concrete values.
    instance_type = ParameterString(name="InstanceType", default_value="c4.4xlarge")
    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
    data_uri = ParameterString(
        name="DataSourceS3Uri",
        default_value=f"s3://{BUCKET}/train_manifest",
    )
    epochs = ParameterInteger(name="TrainingEpochs", default_value=5)
    batch_size = ParameterInteger(name="TrainingBatchSize", default_value=500)

    base_estimator = Estimator(
        image_uri=IMAGE_URI,
        role=ROLE,
        instance_count=instance_count,
        instance_type=instance_type,
        profiler_config=ProfilerConfig(system_monitor_interval_millis=500),
        hyperparameters={"batch-size": batch_size, "epochs": epochs},
        rules=[],
        sagemaker_session=sagemaker_session,
    )
    training_step = TrainingStep(
        name="MyTrainingStep",
        depends_on=["TestStep"],
        estimator=base_estimator,
        inputs=TrainingInput(s3_data=data_uri),
        cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
    )
    # The second dependency is added after construction.
    training_step.add_depends_on(["AnotherTestStep"])

    actual_request = training_step.to_request()

    expected_arguments = {
        "AlgorithmSpecification": {
            "TrainingImage": IMAGE_URI,
            "TrainingInputMode": "File",
        },
        "HyperParameters": {"batch-size": batch_size, "epochs": epochs},
        "InputDataConfig": [
            {
                "ChannelName": "training",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataDistributionType": "FullyReplicated",
                        "S3DataType": "S3Prefix",
                        "S3Uri": data_uri,
                    }
                },
            }
        ],
        "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
        "ResourceConfig": {
            "InstanceCount": instance_count,
            "InstanceType": instance_type,
            "VolumeSizeInGB": 30,
        },
        "RoleArn": ROLE,
        "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
        "ProfilerConfig": {
            "ProfilingIntervalInMilliseconds": 500,
            "S3OutputPath": f"s3://{BUCKET}/",
        },
    }
    expected_request = {
        "Name": "MyTrainingStep",
        "Type": "Training",
        "DependsOn": ["TestStep", "AnotherTestStep"],
        "Arguments": expected_arguments,
        "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
    }

    assert actual_request == expected_request
    assert training_step.properties.TrainingJobName.expr == {
        "Get": "Steps.MyTrainingStep.TrainingJobName"
    }