def test_hyperparameter_optimization_happy_case():
    """hyperparameter_optimization() should build the estimator with the resolved
    image/role/session and fit the tuner once on the input S3 location."""
    with patch('boto3.Session'), \
            patch('sagemaker.Session') as mocked_sagemaker_session, \
            patch('sagemaker.get_execution_role', return_value='arn_role'), \
            patch('sagemaker.estimator.Estimator') as mocked_sagemaker_estimator, \
            patch(
                'sagify.sagemaker.sagemaker.SageMakerClient._construct_image_location',
                return_value='image-full-name'
            ), \
            patch('sagemaker.tuner.HyperparameterTuner') as mocked_sagemaker_tuner:
        session_instance = mocked_sagemaker_session.return_value

        client = sagemaker.SageMakerClient('sagemaker', 'us-east-1')
        client.hyperparameter_optimization(
            image_name='image',
            input_s3_data_location='s3://bucket/input',
            instance_count=1,
            instance_type='m1.xlarge',
            volume_size=30,
            max_run=60,
            max_jobs=3,
            max_parallel_jobs=2,
            output_path='s3://bucket/output',
            objective_type='Maximize',
            objective_metric_name='Precision',
            hyperparams_ranges_dict={
                'lr': ContinuousParameter(0.001, 0.1),
                'batch-size': CategoricalParameter([32, 64, 128, 256, 512]),
            },
            base_job_name="Some-job-name-prefix",
            job_name="some job name"
        )

        # The underlying estimator must receive the fully-qualified image name.
        mocked_sagemaker_estimator.assert_called_with(
            image_name='image-full-name',
            role='arn_role',
            train_instance_count=1,
            train_instance_type='m1.xlarge',
            train_volume_size=30,
            train_max_run=60,
            input_mode='File',
            output_path='s3://bucket/output',
            sagemaker_session=session_instance
        )

        # Exactly one tuning job is launched against the input location.
        tuner_instance = mocked_sagemaker_tuner.return_value
        assert tuner_instance.fit.call_count == 1
        tuner_instance.fit.assert_called_with('s3://bucket/input', job_name='some job name')
def test_deploy_xgboost_happy_case():
    """deploy_xgboost() should construct an XGBoostModel with the bundled
    inference entry point and deploy it exactly once."""
    with patch('boto3.Session'), \
            patch('sagemaker.Session') as mocked_sagemaker_session, \
            patch('sagemaker.get_execution_role', return_value='arn_role'), \
            patch('sagemaker.xgboost.model.XGBoostModel') as mocked_sagemaker_xgboost_model:
        session_instance = mocked_sagemaker_session.return_value

        client = sagemaker.SageMakerClient('sagemaker', 'us-east-1')
        client.deploy_xgboost(
            s3_model_location='s3://bucket/model_input/model.tar.gz',
            instance_count=1,
            instance_type='m1.xlarge',
            framework_version='0.90-2'
        )

        mocked_sagemaker_xgboost_model.assert_called_with(
            role='arn_role',
            model_data='s3://bucket/model_input/model.tar.gz',
            framework_version='0.90-2',
            py_version='py3',
            entry_point='xgboost_inference.py',
            source_dir=os.path.join(sagemaker._FILE_DIR_PATH, 'xgboost_code'),
            model_server_workers=None,
            sagemaker_session=session_instance
        )

        model_instance = mocked_sagemaker_xgboost_model.return_value
        assert model_instance.deploy.call_count == 1
        model_instance.deploy.assert_called_with(
            initial_instance_count=1,
            instance_type='m1.xlarge',
            tags=None,
            endpoint_name=None
        )
def train(dir, input_s3_dir, output_s3_dir, hyperparams_file, ec2_type, volume_size, time_out):
    """
    Command to train ML model(s) on SageMaker
    """
    logger.info(ASCII_LOGO)
    logger.info("Started training on SageMaker...\n")

    config = _read_config(dir)

    # Hyperparameters are optional; only parse the file when one was given.
    if hyperparams_file:
        hyperparams_dict = _read_hyperparams_config(hyperparams_file)
    else:
        hyperparams_dict = None

    client = sagemaker.SageMakerClient(config.aws_profile, config.aws_region)
    s3_model_location = client.train(
        image_name=config.image_name,
        input_s3_data_location=input_s3_dir,
        train_instance_count=1,
        train_instance_type=ec2_type,
        train_volume_size=volume_size,
        train_max_run=time_out,
        output_path=output_s3_dir,
        hyperparameters=hyperparams_dict
    )

    logger.info("Training on SageMaker succeeded")
    logger.info("Model S3 location: {}".format(s3_model_location))
def test_deploy_happy_case():
    """deploy() should create a Model from the S3 artifact and resolved image,
    then deploy it once with default tags/endpoint settings."""
    with patch('boto3.Session'), \
            patch('sagemaker.Session') as mocked_sagemaker_session, \
            patch('sagemaker.get_execution_role', return_value='arn_role'), \
            patch('sagemaker.Model') as mocked_sagemaker_model, \
            patch(
                'sagify.sagemaker.sagemaker.SageMakerClient._construct_image_location',
                return_value='image-full-name'
            ):
        session_instance = mocked_sagemaker_session.return_value

        client = sagemaker.SageMakerClient('sagemaker', 'us-east-1')
        client.deploy(
            image_name='image',
            s3_model_location='s3://bucket/model_input/model.tar.gz',
            train_instance_count=1,
            train_instance_type='m1.xlarge'
        )

        mocked_sagemaker_model.assert_called_with(
            model_data='s3://bucket/model_input/model.tar.gz',
            image='image-full-name',
            role='arn_role',
            sagemaker_session=session_instance
        )

        model_instance = mocked_sagemaker_model.return_value
        assert model_instance.deploy.call_count == 1
        model_instance.deploy.assert_called_with(
            initial_instance_count=1,
            instance_type='m1.xlarge',
            tags=None,
            endpoint_name=None,
            update_endpoint=True
        )
def test_batch_transform_with_tags():
    """batch_transform() should forward tags into transformer() and run a
    single transform job against the input location."""
    with patch('boto3.Session'), \
            patch('sagemaker.Session') as mocked_sagemaker_session, \
            patch('sagemaker.get_execution_role', return_value='arn_role'), \
            patch('sagemaker.Model') as mocked_sagemaker_model, \
            patch(
                'sagify.sagemaker.sagemaker.SageMakerClient._construct_image_location',
                return_value='image-full-name'
            ):
        session_instance = mocked_sagemaker_session.return_value

        client = sagemaker.SageMakerClient('sagemaker', 'us-east-1')
        tags = [
            {
                'Key': 'key_name_1',
                'Value': 1,
            },
            {
                'Key': 'key_name_2',
                'Value': '2',
            },
        ]
        client.batch_transform(
            image_name='image',
            s3_model_location='s3://bucket/model_input/model.tar.gz',
            s3_input_location='s3://bucket/input_data',
            s3_output_location='s3://bucket/output_data',
            transform_instance_count=1,
            transform_instance_type='m1.xlarge',
            tags=tags
        )

        mocked_sagemaker_model.assert_called_with(
            model_data='s3://bucket/model_input/model.tar.gz',
            image='image-full-name',
            role='arn_role',
            sagemaker_session=session_instance
        )

        model_instance = mocked_sagemaker_model.return_value
        assert model_instance.transformer.call_count == 1
        model_instance.transformer.assert_called_with(
            instance_type='m1.xlarge',
            instance_count=1,
            assemble_with='Line',
            output_path='s3://bucket/output_data',
            tags=tags,
            accept='application/json',
            strategy="SingleRecord"
        )

        transformer = model_instance.transformer.return_value
        assert transformer.transform.call_count == 1
        transformer.transform.assert_called_with(
            data='s3://bucket/input_data',
            split_type='Line',
            content_type='application/json',
            job_name=None
        )
def train(dir, job_name, input_s3_dir, output_s3_dir, hyperparams_file, ec2_type, volume_size, time_out, docker_tag, tags=None):
    """
    Trains ML model(s) on SageMaker

    :param dir: [str], source root directory
    :param job_name: [str], training job name
    :param input_s3_dir: [str], S3 location to input data
    :param output_s3_dir: [str], S3 location to save output (models, etc)
    :param hyperparams_file: [str], path to hyperparams json file
    :param ec2_type: [str], ec2 instance type. Refer to:
    https://aws.amazon.com/sagemaker/pricing/instance-types/
    :param volume_size: [int], size in GB of the EBS volume
    :param time_out: [int], time-out in seconds
    :param docker_tag: [str], the Docker tag appended to the configured image name
    :param tags: [optional[list[dict]], default: None], list of `{'Key': ..., 'Value': ...}`
    dicts used to label the training job. See
    https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html

    :return: [str], S3 model location
    """
    config = _read_config(dir)

    # Hyperparameters are optional; parse the config only when a file was given.
    if hyperparams_file:
        hyperparams_dict = _read_hyperparams_config(hyperparams_file)
    else:
        hyperparams_dict = None

    client = sagemaker.SageMakerClient(config.aws_profile, config.aws_region)
    return client.train(
        image_name='{}:{}'.format(config.image_name, docker_tag),
        job_name=job_name,
        input_s3_data_location=input_s3_dir,
        train_instance_count=1,
        train_instance_type=ec2_type,
        train_volume_size=volume_size,
        train_max_run=time_out,
        output_path=output_s3_dir,
        hyperparameters=hyperparams_dict,
        tags=tags
    )
def test_deploy_with_tags():
    """deploy() should pass the caller-supplied tags through to Model.deploy."""
    with patch('boto3.Session'), \
            patch('sagemaker.Session') as mocked_sagemaker_session, \
            patch('sagemaker.get_execution_role', return_value='arn_role'), \
            patch('sagemaker.Model') as mocked_sagemaker_model, \
            patch(
                'sagify.sagemaker.sagemaker.SageMakerClient._construct_image_location',
                return_value='image-full-name'
            ):
        session_instance = mocked_sagemaker_session.return_value

        client = sagemaker.SageMakerClient('sagemaker', 'us-east-1')
        tags = [
            {
                'Key': 'key_name_1',
                'Value': 1,
            },
            {
                'Key': 'key_name_2',
                'Value': '2',
            },
        ]
        client.deploy(
            image_name='image',
            s3_model_location='s3://bucket/model_input/model.tar.gz',
            train_instance_count=1,
            train_instance_type='m1.xlarge',
            tags=tags
        )

        mocked_sagemaker_model.assert_called_with(
            model_data='s3://bucket/model_input/model.tar.gz',
            image='image-full-name',
            role='arn_role',
            sagemaker_session=session_instance
        )

        model_instance = mocked_sagemaker_model.return_value
        assert model_instance.deploy.call_count == 1
        model_instance.deploy.assert_called_with(
            initial_instance_count=1,
            instance_type='m1.xlarge',
            tags=tags
        )
def deploy(
        dir,
        s3_model_location,
        num_instances,
        ec2_type,
        docker_tag,
        aws_role=None,
        external_id=None,
        tags=None,
        endpoint_name=None
):
    """
    Deploys ML model(s) on SageMaker

    :param dir: [str], source root directory
    :param s3_model_location: [str], S3 model location
    :param num_instances: [int], number of ec2 instances
    :param ec2_type: [str], ec2 instance type. Refer to:
    https://aws.amazon.com/sagemaker/pricing/instance-types/
    :param docker_tag: [str], the Docker tag for the image
    :param aws_role: [str], the AWS role assumed by SageMaker while deploying
    :param external_id: [str], Optional external id used when using an IAM role
    :param tags: [optional[list[dict]], default: None], list of `{'Key': ..., 'Value': ...}`
    dicts used to label the job. See
    https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html
    :param endpoint_name: [optional[str]], Optional name for the SageMaker endpoint

    :return: [str], endpoint name
    """
    config = _read_config(dir)

    client = sagemaker.SageMakerClient(
        config.aws_profile, config.aws_region, aws_role, external_id
    )
    return client.deploy(
        image_name='{}:{}'.format(config.image_name, docker_tag),
        s3_model_location=s3_model_location,
        train_instance_count=num_instances,
        train_instance_type=ec2_type,
        tags=tags,
        endpoint_name=endpoint_name
    )
def batch_transform(dir, s3_model_location, s3_input_location, s3_output_location,
                    num_instances, ec2_type, docker_tag, aws_role=None,
                    external_id=None, tags=None):
    """
    Executes a batch transform job given a trained ML model on SageMaker

    :param dir: [str], source root directory
    :param s3_model_location: [str], S3 model location
    :param s3_input_location: [str], S3 input data location
    :param s3_output_location: [str], S3 location to save predictions
    :param num_instances: [int], number of ec2 instances
    :param ec2_type: [str], ec2 instance type. Refer to:
    https://aws.amazon.com/sagemaker/pricing/instance-types/
    :param docker_tag: [str], the Docker tag for the image
    :param aws_role: [str], the AWS role assumed by SageMaker while deploying
    :param external_id: [str], Optional external id used when using an IAM role
    :param tags: [optional[list[dict]], default: None], list of `{'Key': ..., 'Value': ...}`
    dicts used to label the job. See
    https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html
    """
    config = _read_config(dir)

    client = sagemaker.SageMakerClient(
        config.aws_profile, config.aws_region, aws_role, external_id
    )
    client.batch_transform(
        image_name='{}:{}'.format(config.image_name, docker_tag),
        s3_model_location=s3_model_location,
        s3_input_location=s3_input_location,
        s3_output_location=s3_output_location,
        transform_instance_count=num_instances,
        transform_instance_type=ec2_type,
        tags=tags
    )
def upload_data(dir, input_dir, s3_dir):
    """
    Uploads data to S3

    :param dir: [str], source root directory
    :param input_dir: [str], path to local data input directory
    :param s3_dir: [str], S3 location to upload data

    :return: [str], S3 location to upload data
    """
    config = _read_config(dir)
    client = sagemaker.SageMakerClient(config.aws_profile, config.aws_region)
    return client.upload_data(input_dir, s3_dir)
def test_train_happy_case():
    """train() should construct the estimator with the resolved image, role and
    session, then fit it exactly once on the input S3 location."""
    with patch('boto3.Session'), \
            patch('sagemaker.Session') as mocked_sagemaker_session, \
            patch('sagemaker.get_execution_role', return_value='arn_role'), \
            patch('sagemaker.estimator.Estimator') as mocked_sagemaker_estimator, \
            patch(
                'sagify.sagemaker.sagemaker.SageMakerClient._construct_image_location',
                return_value='image-full-name'
            ):
        session_instance = mocked_sagemaker_session.return_value

        client = sagemaker.SageMakerClient('sagemaker', 'us-east-1')
        client.train(
            image_name='image',
            input_s3_data_location='s3://bucket/input',
            train_instance_count=1,
            train_instance_type='m1.xlarge',
            train_volume_size=30,
            train_max_run=60,
            output_path='s3://bucket/output',
            hyperparameters={'n_estimator': 3},
            base_job_name="Some-job-name-prefix",
            job_name="some job name"
        )

        mocked_sagemaker_estimator.assert_called_with(
            image_name='image-full-name',
            role='arn_role',
            train_instance_count=1,
            train_instance_type='m1.xlarge',
            train_volume_size=30,
            train_max_run=60,
            input_mode='File',
            base_job_name="Some-job-name-prefix",
            output_path='s3://bucket/output',
            hyperparameters={'n_estimator': 3},
            sagemaker_session=session_instance,
            metric_definitions=None
        )

        estimator_instance = mocked_sagemaker_estimator.return_value
        assert estimator_instance.fit.call_count == 1
        estimator_instance.fit.assert_called_with('s3://bucket/input', job_name='some job name')
def upload_data(dir, input_dir, s3_dir):
    """
    Command to upload data to S3
    """
    logger.info(ASCII_LOGO)
    logger.info("Started uploading data to S3...\n")

    config = _read_config(dir)
    client = sagemaker.SageMakerClient(config.aws_profile, config.aws_region)
    destination = client.upload_data(input_dir, s3_dir)

    logger.info("Data uploaded to {} successfully".format(destination))
def test_upload_data_with_s3_path_that_contains_only_bucket_name():
    """A bucket-only S3 URI should fall back to the input directory's basename
    as the key prefix."""
    with patch('boto3.Session'), \
            patch('sagemaker.Session') as mocked_sagemaker_session, \
            patch('sagemaker.get_execution_role', return_value='arn_role'):
        session_instance = mocked_sagemaker_session.return_value

        client = sagemaker.SageMakerClient('sagemaker', 'us-east-1')
        client.upload_data(input_dir='/input/data', s3_dir='s3://bucket/')

        assert session_instance.upload_data.call_count == 1
        session_instance.upload_data.assert_called_with(
            path='/input/data', bucket='bucket', key_prefix='data'
        )
def deploy(dir, s3_model_location, model_name, vpc_configs, num_instances, ec2_type, docker_tag, tags=None):
    """
    Deploys ML model(s) on SageMaker

    :param dir: [str], source root directory
    :param s3_model_location: [str], S3 model location
    :param model_name: [str], SageMaker model name
    :param vpc_configs: VPC configuration forwarded to the SageMaker client
        (NOTE(review): exact shape depends on SageMakerClient — confirm)
    :param num_instances: [int], number of ec2 instances
    :param ec2_type: [str], ec2 instance type. Refer to:
    https://aws.amazon.com/sagemaker/pricing/instance-types/
    :param docker_tag: [str], the Docker tag for the image
    :param tags: [optional[list[dict]], default: None], List of tags for labeling a training
        job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
        Example:

        [
            {
                'Key': 'key_name_1',
                'Value': key_value_1,
            },
            {
                'Key': 'key_name_2',
                'Value': key_value_2,
            },
            ...
        ]

    :return: [str], endpoint name
    """
    config = _read_config(dir)

    # Bug fix: docker_tag was previously accepted but ignored. Append it to the
    # configured image name, matching the other train/deploy/batch-transform
    # commands in this module. An empty/None tag keeps the bare image name so
    # existing callers that omitted a tag are unaffected.
    if docker_tag:
        image_name = config.image_name + ':' + docker_tag
    else:
        image_name = config.image_name

    sage_maker_client = sagemaker.SageMakerClient(config.aws_profile, config.aws_region, vpc_configs)
    return sage_maker_client.deploy(image_name=image_name,
                                    s3_model_location=s3_model_location,
                                    model_name=model_name,
                                    train_instance_count=num_instances,
                                    train_instance_type=ec2_type,
                                    tags=tags)
def deploy(dir, s3_model_location, num_instances, ec2_type):
    """
    Command to deploy ML model(s) on SageMaker
    """
    logger.info(ASCII_LOGO)
    logger.info("Started deployment on SageMaker ...\n")

    config = _read_config(dir)
    client = sagemaker.SageMakerClient(config.aws_profile, config.aws_region)
    endpoint_name = client.deploy(
        image_name=config.image_name,
        s3_model_location=s3_model_location,
        train_instance_count=num_instances,
        train_instance_type=ec2_type
    )

    logger.info("Model deployed to SageMaker successfully")
    logger.info("Endpoint name: {}".format(endpoint_name))
def test_deploy_hugging_face_with_hub():
    """deploy_hugging_face() with a hub config (no S3 artifact) should pass the
    hub dict through as env and deploy once."""
    with patch('boto3.Session'), \
            patch('sagemaker.Session') as mocked_sagemaker_session, \
            patch('sagemaker.get_execution_role', return_value='arn_role'), \
            patch('sagemaker.huggingface.HuggingFaceModel') as mocked_sagemaker_hf_model:
        session_instance = mocked_sagemaker_session.return_value

        client = sagemaker.SageMakerClient('sagemaker', 'us-east-1')
        hub = {'HF_MODEL_ID': 'gpt2', 'HF_TASK': 'text-generation'}
        client.deploy_hugging_face(
            instance_count=1,
            instance_type='m1.xlarge',
            transformers_version='4.6.1',
            pytorch_version='1.7.1',
            hub=hub
        )

        mocked_sagemaker_hf_model.assert_called_with(
            role='arn_role',
            model_data=None,
            transformers_version='4.6.1',
            pytorch_version='1.7.1',
            tensorflow_version=None,
            model_server_workers=None,
            py_version='py36',
            env=hub,
            sagemaker_session=session_instance
        )

        model_instance = mocked_sagemaker_hf_model.return_value
        assert model_instance.deploy.call_count == 1
        model_instance.deploy.assert_called_with(
            initial_instance_count=1,
            instance_type='m1.xlarge',
            tags=None,
            endpoint_name=None
        )
def test_deploy_hugging_face_happy_case():
    """deploy_hugging_face() with an S3 artifact should build a HuggingFaceModel
    from it (env=None) and deploy once."""
    with patch('boto3.Session'), \
            patch('sagemaker.Session') as mocked_sagemaker_session, \
            patch('sagemaker.get_execution_role', return_value='arn_role'), \
            patch('sagemaker.huggingface.HuggingFaceModel') as mocked_sagemaker_hf_model:
        session_instance = mocked_sagemaker_session.return_value

        client = sagemaker.SageMakerClient('sagemaker', 'us-east-1')
        client.deploy_hugging_face(
            s3_model_location='s3://bucket/model_input/model.tar.gz',
            instance_count=1,
            instance_type='m1.xlarge',
            transformers_version='4.6.1',
            pytorch_version='1.7.1'
        )

        mocked_sagemaker_hf_model.assert_called_with(
            role='arn_role',
            model_data='s3://bucket/model_input/model.tar.gz',
            transformers_version='4.6.1',
            pytorch_version='1.7.1',
            tensorflow_version=None,
            model_server_workers=None,
            py_version='py36',
            env=None,
            sagemaker_session=session_instance
        )

        model_instance = mocked_sagemaker_hf_model.return_value
        assert model_instance.deploy.call_count == 1
        model_instance.deploy.assert_called_with(
            initial_instance_count=1,
            instance_type='m1.xlarge',
            tags=None,
            endpoint_name=None
        )
def lightning_deploy(framework, num_instances, ec2_type, aws_region,
                     s3_model_location=None, model_server_workers=None,
                     aws_profile=None, aws_role=None, external_id=None,
                     tags=None, endpoint_name=None, extra_config_file=None):
    """
    Deploys ML model(s) on SageMaker without code

    :param framework: [str], The name of the ML framework.
    Valid values: sklearn, huggingface, xgboost
    :param num_instances: [int], number of ec2 instances
    :param ec2_type: [str], ec2 instance type. Refer to:
    https://aws.amazon.com/sagemaker/pricing/instance-types/
    :param aws_region: [str], the AWS region
    :param s3_model_location: [str], S3 model location
    :param model_server_workers: [int], Optional number of worker processes used by the
    inference server. If None, server will use one worker per vCPU.
    :param aws_profile: [optional[str]], Optional AWS profile
    :param aws_role: [optional[str]], Optional AWS role assumed by SageMaker while deploying
    :param external_id: [optional[str]], Optional external id used when using an IAM role
    :param tags: [optional[list[dict]], default: None], List of tags for labeling a training
    job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
    Example:

    [
        {
            'Key': 'key_name_1',
            'Value': key_value_1,
        },
        {
            'Key': 'key_name_2',
            'Value': key_value_2,
        },
        ...
    ]
    :param endpoint_name: [optional[str]], Optional name for the SageMaker endpoint
    :param extra_config_file: [optional[str]], Optional Json file with ML framework
    specific arguments

    :return: [str], endpoint name

    :raises ValueError: if `extra_config_file` is given but does not exist, or if
    `framework` is not one of the supported values
    """
    sage_maker_client = sagemaker.SageMakerClient(aws_profile, aws_region, aws_role, external_id)

    # Bug fix: extra_config_file defaults to None, but the original code passed
    # it straight to os.path.isfile, which raises TypeError on None. Treat a
    # missing file argument as "no extra configuration".
    if extra_config_file is None:
        extra_config_dict = {}
    else:
        if not os.path.isfile(extra_config_file):
            raise ValueError("The given extra config file {} doesn't exist".format(
                extra_config_file))
        with open(extra_config_file) as _in_file:
            extra_config_dict = json.load(_in_file)

    if framework == 'sklearn':
        return sage_maker_client.deploy_sklearn(
            s3_model_location=s3_model_location,
            instance_count=num_instances,
            instance_type=ec2_type,
            model_server_workers=model_server_workers,
            tags=tags,
            endpoint_name=endpoint_name,
            **extra_config_dict)
    elif framework == 'huggingface':
        return sage_maker_client.deploy_hugging_face(
            s3_model_location=s3_model_location,
            instance_count=num_instances,
            instance_type=ec2_type,
            model_server_workers=model_server_workers,
            tags=tags,
            endpoint_name=endpoint_name,
            **extra_config_dict)
    elif framework == 'xgboost':
        return sage_maker_client.deploy_xgboost(
            s3_model_location=s3_model_location,
            instance_count=num_instances,
            instance_type=ec2_type,
            model_server_workers=model_server_workers,
            tags=tags,
            endpoint_name=endpoint_name,
            **extra_config_dict)

    raise ValueError("Invalid framework value")
def hyperparameter_optimization(
        dir,
        input_s3_dir,
        output_s3_dir,
        hyperparams_config_file,
        ec2_type,
        max_jobs,
        max_parallel_jobs,
        volume_size,
        time_out,
        docker_tag,
        aws_role,
        external_id,
        base_job_name,
        job_name,
        wait,
        tags=None
):
    """
    Hyperparameter Optimization on SageMaker

    :param dir: [str], source root directory
    :param input_s3_dir: [str], S3 location to input data
    :param output_s3_dir: [str], S3 location to save the multiple trained models
    :param hyperparams_config_file: [str], path to hyperparameters config json file
    :param ec2_type: [str], ec2 instance type. Refer to:
    https://aws.amazon.com/sagemaker/pricing/instance-types/
    :param max_jobs: [int], Maximum total number of training jobs to start for the
    hyperparameter tuning job
    :param max_parallel_jobs: [int], Maximum number of parallel training jobs to start
    :param volume_size: [int], size in GB of the EBS volume
    :param time_out: [int], time-out in seconds
    :param docker_tag: [str], the Docker tag for the image
    :param aws_role: [str], the AWS role assumed by SageMaker while training
    :param external_id: [str], Optional external id used when using an IAM role
    :param base_job_name: [str], Optional prefix for the SageMaker training job
    :param job_name: [str], Optional name for the SageMaker tuning job.
    Overrides `base_job_name`
    :param wait: [bool, default=False], Wait until hyperparameter tuning is done
    :param tags: [optional[list[dict]], default: None], list of `{'Key': ..., 'Value': ...}`
    dicts used to label the job. See
    https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html

    :return: [str], S3 model location
    """
    config = _read_config(dir)

    # The config file bundles the objective metric, its direction, and the ranges.
    metric_name, objective, ranges = _read_hyperparams_ranges_config(hyperparams_config_file)

    client = sagemaker.SageMakerClient(
        config.aws_profile, config.aws_region, aws_role, external_id
    )
    return client.hyperparameter_optimization(
        image_name='{}:{}'.format(config.image_name, docker_tag),
        input_s3_data_location=input_s3_dir,
        instance_count=1,
        instance_type=ec2_type,
        volume_size=volume_size,
        objective_type=objective,
        objective_metric_name=metric_name,
        max_jobs=max_jobs,
        max_parallel_jobs=max_parallel_jobs,
        max_run=time_out,
        output_path=output_s3_dir,
        hyperparams_ranges_dict=ranges,
        base_job_name=base_job_name,
        job_name=job_name,
        tags=tags,
        wait=wait
    )
def train(
        dir,
        input_s3_dir,
        output_s3_dir,
        hyperparams_file,
        ec2_type,
        volume_size,
        time_out,
        docker_tag,
        aws_role,
        external_id,
        base_job_name,
        job_name,
        metric_names=None,
        tags=None
):
    """
    Trains ML model(s) on SageMaker

    :param dir: [str], source root directory
    :param input_s3_dir: [str], S3 location to input data
    :param output_s3_dir: [str], S3 location to save output (models, etc)
    :param hyperparams_file: [str], path to hyperparams json file
    :param ec2_type: [str], ec2 instance type. Refer to:
    https://aws.amazon.com/sagemaker/pricing/instance-types/
    :param volume_size: [int], size in GB of the EBS volume
    :param time_out: [int], time-out in seconds
    :param docker_tag: [str], the Docker tag for the image
    :param aws_role: [str], the AWS role assumed by SageMaker while training
    :param external_id: [str], Optional external id used when using an IAM role
    :param base_job_name: [str], Optional prefix for the SageMaker training job
    :param job_name: [str], Optional name for the SageMaker training job.
    Overrides `base_job_name`
    :param metric_names: [list[str], default=None], Optional list of string metric names
    :param tags: [optional[list[dict]], default: None], list of `{'Key': ..., 'Value': ...}`
    dicts used to label the training job. See
    https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html

    :return: [str], S3 model location
    """
    config = _read_config(dir)

    # Hyperparameters are optional; parse the config only when a file was given.
    if hyperparams_file:
        hyperparams_dict = _read_hyperparams_config(hyperparams_file)
    else:
        hyperparams_dict = None

    client = sagemaker.SageMakerClient(
        config.aws_profile, config.aws_region, aws_role, external_id
    )
    return client.train(
        image_name='{}:{}'.format(config.image_name, docker_tag),
        input_s3_data_location=input_s3_dir,
        train_instance_count=1,
        train_instance_type=ec2_type,
        train_volume_size=volume_size,
        train_max_run=time_out,
        output_path=output_s3_dir,
        hyperparameters=hyperparams_dict,
        base_job_name=base_job_name,
        job_name=job_name,
        tags=tags,
        metric_names=metric_names
    )