def test_create_model_step_with_model_pipeline(tfo, time, sagemaker_session):
    """A CreateModelStep built from a PipelineModel renders both member
    containers, honors display name/description, and merges constructor
    plus added `depends_on` entries into the request."""
    framework_model = DummyFrameworkModel(sagemaker_session)
    sparkml_model = SparkMLModel(
        model_data="s3://bucket/model_2.tar.gz",
        role=ROLE,
        sagemaker_session=sagemaker_session,
        env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
    )
    pipeline_model = PipelineModel(
        models=[framework_model, sparkml_model],
        role=ROLE,
        sagemaker_session=sagemaker_session,
    )
    inputs = CreateModelInput(
        instance_type="c4.4xlarge",
        accelerator_type="ml.eia1.medium",
    )
    step = CreateModelStep(
        name="MyCreateModelStep",
        depends_on=["TestStep"],
        display_name="MyCreateModelStep",
        description="TestDescription",
        model=pipeline_model,
        inputs=inputs,
    )
    step.add_depends_on(["SecondTestStep"])

    # Expected container definitions, one per model in the pipeline.
    framework_container = {
        "Environment": {
            "SAGEMAKER_PROGRAM": "dummy_script.py",
            "SAGEMAKER_SUBMIT_DIRECTORY": "s3://my-bucket/mi-1-2017-10-10-14-14-15/sourcedir.tar.gz",
            "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
            "SAGEMAKER_REGION": "us-west-2",
        },
        "Image": "mi-1",
        "ModelDataUrl": "s3://bucket/model_1.tar.gz",
    }
    sparkml_container = {
        "Environment": {"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
        "Image": "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4",
        "ModelDataUrl": "s3://bucket/model_2.tar.gz",
    }

    assert step.to_request() == {
        "Name": "MyCreateModelStep",
        "Type": "Model",
        "Description": "TestDescription",
        "DisplayName": "MyCreateModelStep",
        # add_depends_on appends to the constructor-supplied list.
        "DependsOn": ["TestStep", "SecondTestStep"],
        "Arguments": {
            "Containers": [framework_container, sparkml_container],
            "ExecutionRoleArn": "DummyRole",
        },
    }
    assert step.properties.ModelName.expr == {"Get": "Steps.MyCreateModelStep.ModelName"}
def test_create_model_step(sagemaker_session):
    """A CreateModelStep built from a single Model renders a
    PrimaryContainer request and merges constructor plus added
    `depends_on` entries."""
    model = Model(
        image_uri=IMAGE_URI,
        role=ROLE,
        sagemaker_session=sagemaker_session,
    )
    inputs = CreateModelInput(
        instance_type="c4.4xlarge",
        accelerator_type="ml.eia1.medium",
    )
    step = CreateModelStep(
        name="MyCreateModelStep",
        depends_on=["TestStep"],
        model=model,
        inputs=inputs,
    )
    step.add_depends_on(["SecondTestStep"])

    expected_request = {
        "Name": "MyCreateModelStep",
        "Type": "Model",
        # add_depends_on appends to the constructor-supplied list.
        "DependsOn": ["TestStep", "SecondTestStep"],
        "Arguments": {
            "ExecutionRoleArn": "DummyRole",
            "PrimaryContainer": {"Environment": {}, "Image": "fakeimage"},
        },
    }
    assert step.to_request() == expected_request
    assert step.properties.ModelName.expr == {"Get": "Steps.MyCreateModelStep.ModelName"}
def _append_create_model_step(self): """Create and append a `CreateModelStep`""" create_model_step = CreateModelStep( name="{}-{}".format(self.name, _CREATE_MODEL_NAME_BASE), step_args=self._create_model_args, retry_policies=self._create_model_retry_policies, display_name=self.display_name, description=self.description, ) if not self._need_runtime_repack: create_model_step.add_depends_on(self.depends_on) self.steps.append(create_model_step)
def __init__( self, name: str, estimator: EstimatorBase, model_data, model_inputs, instance_count, instance_type, transform_inputs, # model arguments image_uri=None, predictor_cls=None, env=None, # transformer arguments strategy=None, assemble_with=None, output_path=None, output_kms_key=None, accept=None, max_concurrent_transforms=None, max_payload=None, tags=None, volume_kms_key=None, depends_on: List[str] = None, **kwargs, ): """Construct steps required for a Transformer step collection: An estimator-centric step collection. It models what happens in workflows when invoking the `transform()` method on an estimator instance: First, if custom model artifacts are required, a `_RepackModelStep` is included. Second, a `CreateModelStep` with the model data passed in from a training step or other training job output. Finally, a `TransformerStep`. If repacking the model artifacts is not necessary, only the CreateModelStep and TransformerStep are in the step collection. Args: name (str): The name of the Transform Step. estimator: The estimator instance. instance_count (int): The number of EC2 instances to use. instance_type (str): The type of EC2 instance to use. strategy (str): The strategy used to decide how to batch records in a single request (default: None). Valid values: 'MultiRecord' and 'SingleRecord'. assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'. output_path (str): The S3 location for saving the transform result. If not specified, results are stored to a default bucket. output_kms_key (str): Optional. A KMS key ID for encrypting the transform output (default: None). accept (str): The accept header passed by the client to the inference endpoint. If it is supported by the endpoint, it will be the format of the batch transform output. env (dict): The Environment variables to be set for use during the transform job (default: None). 
depends_on (List[str]): The list of step names the first step in the collection depends on """ steps = [] if "entry_point" in kwargs: entry_point = kwargs["entry_point"] source_dir = kwargs.get("source_dir") dependencies = kwargs.get("dependencies") repack_model_step = _RepackModelStep( name=f"{name}RepackModel", depends_on=depends_on, estimator=estimator, model_data=model_data, entry_point=entry_point, source_dir=source_dir, dependencies=dependencies, ) steps.append(repack_model_step) model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts def predict_wrapper(endpoint, session): return Predictor(endpoint, session) predictor_cls = predictor_cls or predict_wrapper model = Model( image_uri=image_uri or estimator.training_image_uri(), model_data=model_data, predictor_cls=predictor_cls, vpc_config=None, sagemaker_session=estimator.sagemaker_session, role=estimator.role, **kwargs, ) model_step = CreateModelStep( name=f"{name}CreateModelStep", model=model, inputs=model_inputs, ) if "entry_point" not in kwargs and depends_on: # if the CreateModelStep is the first step in the collection model_step.add_depends_on(depends_on) steps.append(model_step) transformer = Transformer( model_name=model_step.properties.ModelName, instance_count=instance_count, instance_type=instance_type, strategy=strategy, assemble_with=assemble_with, output_path=output_path, output_kms_key=output_kms_key, accept=accept, max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, env=env, tags=tags, base_transform_job_name=name, volume_kms_key=volume_kms_key, sagemaker_session=estimator.sagemaker_session, ) transform_step = TransformStep( name=f"{name}TransformStep", transformer=transformer, inputs=transform_inputs, ) steps.append(transform_step) self.steps = steps