def test_mk_training_job_with_vpc_config(self, prj):
    """Project ``vpc_config`` is translated into the job's ``VpcConfig``.

    The project-level keys ``security_groups`` / ``subnets`` are expected to
    appear in the generated config as ``SecurityGroupIds`` / ``Subnets``.
    """
    prj.train["vpc_config"] = {
        "security_groups": ["sg-1"],
        "subnets": ["net-2"],
    }
    expected_vpc = {
        "SecurityGroupIds": ["sg-1"],
        "Subnets": ["net-2"],
    }

    job = cli_utils.mk_training_job(prj, "training-job-1", "dataset-1")

    assert job["VpcConfig"] == expected_vpc
def test_mk_training_job(self, prj):
    """Full training-job config for a job name and dataset, no extras.

    No model type or VPC config is supplied, so the expected dict contains
    neither ``ML2P_MODEL_CLS`` nor ``VpcConfig``.
    """
    # Every S3 location in the config lives under this project bucket.
    bucket = "s3://prodigyfinance-modelling-project-sagemaker-production/"
    image = (
        "123456789012.dkr.ecr.eu-west-1.amazonaws.com/"
        "modelling-project-sagemaker:latest"
    )

    expected = {
        "TrainingJobName": "modelling-project-training-job-1",
        "AlgorithmSpecification": {
            "TrainingImage": image,
            "TrainingInputMode": "File",
        },
        "EnableNetworkIsolation": True,
        "HyperParameters": {
            "ML2P_ENV.ML2P_PROJECT": '"modelling-project"',
            # Hyperparameter values carry embedded double quotes.
            "ML2P_ENV.ML2P_S3_URL": '"' + bucket + '"',
        },
        "InputDataConfig": [
            {
                "ChannelName": "training",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": bucket + "datasets/dataset-1",
                    },
                },
            },
        ],
        "OutputDataConfig": {"S3OutputPath": bucket + "models/"},
        "ResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "ml.m5.2xlarge",
            "VolumeSizeInGB": 20,
        },
        "RoleArn": "arn:aws:iam::111111111111:role/modelling-project",
        "StoppingCondition": {"MaxRuntimeInSeconds": 60 * 60},
        "Tags": [{"Key": "ml2p-project", "Value": "modelling-project"}],
    }

    training_job_cfg = cli_utils.mk_training_job(
        prj, "training-job-1", "dataset-1")

    assert training_job_cfg == expected
def test_mk_training_job_with_model_type(self, prj):
    """A model type registered on the project adds ``ML2P_MODEL_CLS``.

    The class path stored in ``prj.models`` is passed through as a quoted
    hyperparameter alongside the project name and S3 URL.
    """
    prj.models["model-type-1"] = "my.pkg.model"
    bucket = "s3://prodigyfinance-modelling-project-sagemaker-production/"

    cfg = cli_utils.mk_training_job(
        prj, "training-job-1", "dataset-1", "model-type-1")

    expected_params = {
        "ML2P_ENV.ML2P_MODEL_CLS": '"my.pkg.model"',
        "ML2P_ENV.ML2P_PROJECT": '"modelling-project"',
        "ML2P_ENV.ML2P_S3_URL": '"' + bucket + '"',
    }
    assert cfg["HyperParameters"] == expected_params
def test_mk_training_job_with_missing_model_type(self, prj):
    """Requesting a model type the project does not define raises KeyError.

    The exception's string form is the quoted missing key.
    """
    missing_type = "model-type-1"

    with pytest.raises(KeyError) as excinfo:
        cli_utils.mk_training_job(
            prj, "training-job-1", "dataset-1", missing_type)

    # Checked after the ``with`` block; a statement after the raising call
    # inside the block would never execute.
    assert str(excinfo.value) == "'model-type-1'"