# NOTE: test_transform_mxnet_tags is defined later in this module. An identical
# earlier definition here was removed: the later ``def`` rebinds the same
# module-level name, so this copy was dead code that pytest never collected
# (pyflakes/flake8 F811, "redefinition of unused name").
def test_stop_transform_job(sagemaker_session, mxnet_full_version, cpu_instance_type):
    """Start an MXNet batch transform job, stop it, and assert it ends ``Stopped``.

    The test trains the MNIST example on ``cpu_instance_type`` and launches a CSV
    batch transform from the trained estimator. It then calls
    ``stop_transform_job()`` and checks the ``TransformJobStatus`` reported by
    ``describe_transform_job``.

    The stop/describe/assert sequence runs inside
    ``timeout_and_delete_model_with_transformer``, the same guard the tags test
    in this module uses. Going by its name and that usage, it bounds the wait
    and deletes the model created by ``mx.transformer(...)`` on exit, even when
    the assertion fails. Without it, every run leaked that model.
    """
    data_path = os.path.join(DATA_DIR, "mxnet_mnist")
    script_path = os.path.join(data_path, "mnist.py")
    tags = [{"Key": "some-tag", "Value": "value-for-tag"}]

    mx = MXNet(
        entry_point=script_path,
        role="SageMakerRole",
        train_instance_count=1,
        train_instance_type=cpu_instance_type,
        sagemaker_session=sagemaker_session,
        framework_version=mxnet_full_version,
    )

    train_input = mx.sagemaker_session.upload_data(
        path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
    )
    test_input = mx.sagemaker_session.upload_data(
        path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
    )
    job_name = unique_name_from_base("test-mxnet-transform")

    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        mx.fit({"train": train_input, "test": test_input}, job_name=job_name)

    transform_input_path = os.path.join(data_path, "transform", "data.csv")
    transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
    transform_input = mx.sagemaker_session.upload_data(
        path=transform_input_path, key_prefix=transform_input_key_prefix
    )

    transformer = mx.transformer(1, cpu_instance_type, tags=tags)
    transformer.transform(transform_input, content_type="text/csv")

    with timeout_and_delete_model_with_transformer(
        transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
    ):
        # Presumably gives the freshly created job time to become stoppable
        # before the stop request is sent -- TODO confirm the minimum needed.
        time.sleep(15)

        latest_transform_job_name = transformer.latest_transform_job.name

        print("Attempting to stop {}".format(latest_transform_job_name))

        # The assertion below expects a terminal "Stopped" status right away,
        # i.e. it relies on stop_transform_job() blocking until the stop
        # completes -- NOTE(review): confirm against the SDK version in use.
        transformer.stop_transform_job()

        client = transformer.latest_transform_job.sagemaker_session.sagemaker_client
        desc = client.describe_transform_job(TransformJobName=latest_transform_job_name)
        assert desc["TransformJobStatus"] == "Stopped"
def test_local_transform_mxnet(sagemaker_local_session, tmpdir, mxnet_full_version):
    """Train and batch-transform the MXNet MNIST example through a local session.

    Both training and the single-instance ``"local"`` transform go through
    ``sagemaker_local_session``. The transform input is the whole ``transform``
    data directory, sent line by line (``SingleRecord``, ``split_type="Line"``).
    Output is written to ``tmpdir`` via a ``file://`` URI. The test passes when
    ``data.csv.out`` appears there.
    """
    data_path = os.path.join(DATA_DIR, "mxnet_mnist")
    script_path = os.path.join(data_path, "mnist.py")

    mx = MXNet(
        entry_point=script_path,
        role="SageMakerRole",
        train_instance_count=1,
        train_instance_type="ml.c4.xlarge",
        framework_version=mxnet_full_version,
        sagemaker_session=sagemaker_local_session,
    )

    train_input = mx.sagemaker_session.upload_data(
        path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
    )
    test_input = mx.sagemaker_session.upload_data(
        path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
    )

    with timeout(minutes=15):
        mx.fit({"train": train_input, "test": test_input})

    transform_input_path = os.path.join(data_path, "transform")
    transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
    transform_input = mx.sagemaker_session.upload_data(
        path=transform_input_path, key_prefix=transform_input_key_prefix
    )

    output_dir = str(tmpdir)
    transformer = mx.transformer(
        1,
        "local",
        assemble_with="Line",
        max_payload=1,
        strategy="SingleRecord",
        output_path="file://%s" % output_dir,
    )

    # Hold the local-mode lock for the whole job lifecycle, starting AND
    # waiting, so this transform is serialized with any other test that takes
    # the same lock. (Presumably the lock guards shared local Docker state;
    # confirm in local_mode_utils.)
    with local_mode_utils.lock():
        transformer.transform(transform_input, content_type="text/csv", split_type="Line")
        transformer.wait()

    assert os.path.exists(os.path.join(output_dir, "data.csv.out"))
def test_transform_mxnet_tags(sagemaker_session, mxnet_full_version):
    """Verify that tags given to ``MXNet.transformer`` are applied to the model.

    The test trains the MNIST example and runs a CSV batch transform on
    ``ml.m4.xlarge`` with ``expected_tags`` attached. While the model still
    exists (inside ``timeout_and_delete_model_with_transformer``), it looks up
    the model ARN and compares ``list_tags`` output with what was passed in.
    """
    mnist_dir = os.path.join(DATA_DIR, "mxnet_mnist")
    entry_point = os.path.join(mnist_dir, "mnist.py")
    expected_tags = [{"Key": "some-tag", "Value": "value-for-tag"}]

    estimator = MXNet(
        entry_point=entry_point,
        role="SageMakerRole",
        train_instance_count=1,
        train_instance_type="ml.c4.xlarge",
        sagemaker_session=sagemaker_session,
        framework_version=mxnet_full_version,
    )

    upload = estimator.sagemaker_session.upload_data
    # Dict values are evaluated in order: the train channel is uploaded first.
    channels = {
        "train": upload(
            path=os.path.join(mnist_dir, "train"),
            key_prefix="integ-test-data/mxnet_mnist/train",
        ),
        "test": upload(
            path=os.path.join(mnist_dir, "test"),
            key_prefix="integ-test-data/mxnet_mnist/test",
        ),
    }
    training_job_name = unique_name_from_base("test-mxnet-transform")

    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        estimator.fit(channels, job_name=training_job_name)

    batch_input = upload(
        path=os.path.join(mnist_dir, "transform", "data.csv"),
        key_prefix="integ-test-data/mxnet_mnist/transform",
    )

    transformer = estimator.transformer(1, "ml.m4.xlarge", tags=expected_tags)
    transformer.transform(batch_input, content_type="text/csv")

    with timeout_and_delete_model_with_transformer(
        transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
    ):
        transformer.wait()

        # Query the model before the context manager exits.
        client = sagemaker_session.sagemaker_client
        described = client.describe_model(ModelName=transformer.model_name)
        actual_tags = client.list_tags(ResourceArn=described["ModelArn"])["Tags"]
        assert expected_tags == actual_tags