def test_method_modify_node(methods, caplog):
    """Legacy ``get_image_uri`` calls are rewritten to ``image_uris.retrieve``.

    Each case pairs a legacy call with the expected rewrite. Cases cover the
    bare name, module-qualified forms, keyword arguments and the optional
    ``repo_version`` argument. Note the swapped (framework, region) order.
    """
    refactor = image_uris.ImageURIRetrieveRefactor()

    rewrites = (
        (
            "get_image_uri('us-west-2', 'xgboost')",
            "image_uris.retrieve('xgboost', 'us-west-2')",
        ),
        (
            "amazon_estimator.get_image_uri('us-west-2', 'xgboost')",
            "image_uris.retrieve('xgboost', 'us-west-2')",
        ),
        (
            "sagemaker.get_image_uri(repo_region='us-west-2', repo_name='xgboost')",
            "sagemaker.image_uris.retrieve('xgboost', 'us-west-2')",
        ),
        (
            "sagemaker.amazon_estimator.get_image_uri('us-west-2', repo_name='xgboost')",
            "sagemaker.image_uris.retrieve('xgboost', 'us-west-2')",
        ),
        (
            "sagemaker.amazon.amazon_estimator.get_image_uri('us-west-2', 'xgboost', repo_version='1')",
            "sagemaker.image_uris.retrieve('xgboost', 'us-west-2', '1')",
        ),
    )

    for legacy_call, rewritten in rewrites:
        tree = ast_call(legacy_call)
        refactor.modify_node(tree)
        assert rewritten == pasta.dump(tree)
def test_model_node_should_be_modified_no_distribution():
    """Model constructors called without arguments are not flagged for modification.

    Every entry of ``MODELS`` is checked both unqualified and under each of its
    namespaces.
    """
    renamer = renamed_params.ModelImageURIRenamer()

    for model_name, namespaces in MODELS.items():
        candidates = ["{}()".format(model_name)]
        candidates.extend("{}.{}()".format(prefix, model_name) for prefix in namespaces)

        for source in candidates:
            assert not renamer.node_should_be_modified(ast_call(source))
def test_model_node_should_be_modified():
    """Model constructors passing ``image=`` are flagged for the ``image_uri`` rename.

    Every entry of ``MODELS`` is checked both unqualified and under each of its
    namespaces.
    """
    renamer = renamed_params.ModelImageURIRenamer()

    for model_name, namespaces in MODELS.items():
        candidates = ["{}(image='my-image:latest')".format(model_name)]
        candidates.extend(
            "{}.{}(image='my-image:latest')".format(prefix, model_name) for prefix in namespaces
        )

        for source in candidates:
            assert renamer.node_should_be_modified(ast_call(source))
def test_estimator_node_should_be_modified_no_distribution():
    """Estimator constructors called without arguments are not flagged for modification.

    Every entry of ``ESTIMATORS`` is checked both unqualified and under each of
    its namespaces.
    """
    renamer = renamed_params.EstimatorImageURIRenamer()

    for estimator_name, namespaces in ESTIMATORS.items():
        candidates = ["{}()".format(estimator_name)]
        candidates.extend("{}.{}()".format(prefix, estimator_name) for prefix in namespaces)

        for source in candidates:
            assert not renamer.node_should_be_modified(ast_call(source))
def test_matches_name_or_namespaces():
    """``matches_name_or_namespaces`` accepts the bare name or any listed namespace.

    Calls to other names (with or without a namespace) must not match.
    """
    name = "KMeans"
    namespaces = ("sagemaker", "sagemaker.amazon.kmeans")

    # Exercise the bare name plus *every* listed namespace. Previously the
    # dotted ``sagemaker.amazon.kmeans`` namespace was never checked.
    matches = ("KMeans()", "sagemaker.KMeans()", "sagemaker.amazon.kmeans.KMeans()")
    for call in matches:
        assert matching.matches_name_or_namespaces(ast_call(call), name, namespaces), call

    non_matches = ("MXNet()", "sagemaker.mxnet.MXNet()")
    for call in non_matches:
        assert not matching.matches_name_or_namespaces(ast_call(call), name, namespaces), call
def test_estimator_node_should_be_modified():
    """Estimator constructors passing ``image_name=`` are flagged for the ``image_uri`` rename.

    Every entry of ``ESTIMATORS`` is checked both unqualified and under each of
    its namespaces.
    """
    renamer = renamed_params.EstimatorImageURIRenamer()

    for estimator_name, namespaces in ESTIMATORS.items():
        candidates = ["{}(image_name='my-image:latest')".format(estimator_name)]
        candidates.extend(
            "{}.{}(image_name='my-image:latest')".format(prefix, estimator_name)
            for prefix in namespaces
        )

        for source in candidates:
            assert renamer.node_should_be_modified(ast_call(source))
def test_constructor_modify_node():
    """TensorFlow Serving ``Model``/``Predictor`` constructors are renamed.

    They become ``TensorFlowModel``/``TensorFlowPredictor`` under the
    ``sagemaker.tensorflow`` namespace, or keep no namespace when called bare.
    """
    renamer = tfs.TensorFlowServingConstructorRenamer()

    expectations = (
        ("sagemaker.tensorflow.serving.Model()", "sagemaker.tensorflow.TensorFlowModel()"),
        ("sagemaker.tensorflow.serving.Predictor()", "sagemaker.tensorflow.TensorFlowPredictor()"),
        ("Predictor()", "TensorFlowPredictor()"),
    )

    for before, after in expectations:
        tree = ast_call(before)
        renamer.modify_node(tree)
        assert after == pasta.dump(tree)
def test_arg_value():
    """``parsing.arg_value`` returns the keyword argument's value.

    Literal values (string, int, bool, None) are returned as Python values.
    A bare variable reference is returned as its identifier string.
    """
    equality_cases = (
        ("MXNet(framework_version='1.6.0')", "framework_version", "1.6.0"),
        ("MXNet(framework_version=mxnet_version)", "framework_version", "mxnet_version"),
        ("MXNet(instance_count=1)", "instance_count", 1),
    )
    for source, keyword, expected in equality_cases:
        tree = ast_call(source)
        assert expected == parsing.arg_value(tree, keyword)

    # Singletons are compared by identity, not equality.
    tree = ast_call("MXNet(enable_network_isolation=True)")
    assert parsing.arg_value(tree, "enable_network_isolation") is True

    tree = ast_call("MXNet(source_dir=None)")
    assert parsing.arg_value(tree, "source_dir") is None
def test_modify_node():
    """The ``distributions`` keyword is renamed to ``distribution``; its value is untouched."""
    renamer = renamed_params.DistributionParameterRenamer()
    tree = ast_call("TensorFlow(distributions={'parameter_server': {'enabled': True}})")

    renamer.modify_node(tree)

    assert "TensorFlow(distribution={'parameter_server': {'enabled': True}})" == pasta.dump(tree)
def test_arg_value_absent_keyword():
    """Asking for a keyword the call does not pass raises ``KeyError``.

    The error message names both the missing argument and the call source.
    """
    source = "MXNet(entry_point='run')"
    expected_message = "arg 'framework_version' not found in call: {}".format(source)

    with pytest.raises(KeyError) as excinfo:
        parsing.arg_value(ast_call(source), "framework_version")

    # Checked outside the ``raises`` block so the assertion is actually reached.
    assert expected_message in str(excinfo.value)
def test_estimator_modify_node():
    """An estimator's ``image_name`` keyword is renamed to ``image_uri``."""
    renamer = renamed_params.EstimatorImageURIRenamer()
    tree = ast_call("TensorFlow(image_name=my_image)")

    renamer.modify_node(tree)

    assert "TensorFlow(image_uri=my_image)" == pasta.dump(tree)
def test_model_modify_node():
    """A model's ``image`` keyword is renamed to ``image_uri``."""
    renamer = renamed_params.ModelImageURIRenamer()
    tree = ast_call("TensorFlowModel(image=my_image)")

    renamer.modify_node(tree)

    assert "TensorFlowModel(image_uri=my_image)" == pasta.dump(tree)
def test_modify_node():
    """``create_model(image=...)`` on an estimator is rewritten to ``image_uri=...``."""
    renamer = renamed_params.EstimatorCreateModelImageURIRenamer()
    tree = ast_call("estimator.create_model(image=my_image)")

    renamer.modify_node(tree)

    assert "estimator.create_model(image_uri=my_image)" == pasta.dump(tree)
def test_modify_node_set_model_dir_and_image_name(retrieve_image_uri, boto_session):
    """Legacy-mode TensorFlow constructors get an explicit image URI and ``model_dir=False``.

    After each rewrite, the image-URI lookup mock must have been called for the
    TensorFlow 1.11.0 / py2 training image in the mocked session's region.
    """
    boto_session.return_value.region_name = REGION_NAME
    upgrader = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
    expected_source = "TensorFlow(image_uri='{}', model_dir=False)".format(IMAGE_URI)

    legacy_constructors = (
        "TensorFlow()",
        "TensorFlow(script_mode=False)",
        "TensorFlow(model_dir='s3//bucket/model')",
    )

    for legacy in legacy_constructors:
        tree = ast_call(legacy)
        upgrader.modify_node(tree)

        assert expected_source == pasta.dump(tree)
        retrieve_image_uri.assert_called_with(
            "tensorflow",
            REGION_NAME,
            instance_type="ml.m4.xlarge",
            version="1.11.0",
            py_version="py2",
            image_scope="training",
        )
def test_node_should_be_modified():
    """Every estimator paired with every ``train_``-prefixed parameter is flagged for modification."""
    remover = training_params.TrainPrefixRemover()

    pairs = ((est, arg) for est in _estimators() for arg in PARAMS_WITH_VALUES)
    for estimator_name, parameter in pairs:
        tree = ast_call("{}({})".format(estimator_name, parameter))
        assert remover.node_should_be_modified(tree)
def test_node_should_be_modified_tf_constructor_script_mode():
    """Script-mode TensorFlow constructors are not upgraded by the legacy-mode upgrader.

    Script mode is signalled by ``script_mode=True`` or a py3 ``py_version``.
    Cases where either value is a variable are also treated as script mode.
    Each argument list is tried unqualified and under both public namespaces.
    """
    prefixes = ("", "sagemaker.tensorflow.", "sagemaker.tensorflow.estimator.")
    argument_lists = (
        "script_mode=True",
        "py_version='py3'",
        "py_version='py37'",
        "py_version='py3', script_mode=False",
        "py_version=py_version, script_mode=False",
        "py_version='py3', script_mode=script_mode",
    )
    # Same 18 constructors, in the same prefix-major order as an explicit list.
    tf_script_mode_constructors = tuple(
        "{}TensorFlow({})".format(prefix, args) for prefix in prefixes for args in argument_lists
    )

    upgrader = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
    for source in tf_script_mode_constructors:
        assert upgrader.node_should_be_modified(ast_call(source)) is False
def test_node_should_be_modified_no_session():
    """S3 helper calls without a ``session`` argument are not flagged for modification."""
    renamer = renamed_params.S3SessionRenamer()

    sources = ("{}{}()".format(prefix, fn) for fn in FUNCTIONS for prefix in NAMESPACES)
    for source in sources:
        assert not renamer.node_should_be_modified(ast_call(source))
def test_modify_node():
    """``session=`` on an S3 helper call is renamed to ``sagemaker_session=``."""
    renamer = renamed_params.S3SessionRenamer()
    tree = ast_call("S3Downloader.download(session=sess)")

    renamer.modify_node(tree)

    assert "S3Downloader.download(sagemaker_session=sess)" == pasta.dump(tree)
def test_constructor_modify_node(constructors, modified_constructors):
    """Each ShuffleConfig constructor is rewritten to its counterpart in ``modified_constructors``.

    The two fixtures are paired positionally.
    """
    renamer = training_input.ShuffleConfigModuleRenamer()

    for original_source, rewritten_source in zip(constructors, modified_constructors):
        tree = ast_call(original_source)
        renamer.modify_node(tree)
        assert rewritten_source == pasta.dump(tree)
def test_arg_order_modify_node():
    """Airflow ``model_config*`` calls have their positional ``instance_type`` moved to a keyword.

    Calls that already pass later arguments by keyword switch to keywords for
    ``instance_type`` as well.
    """
    arg_modifier = airflow.ModelConfigArgModifier()

    before_and_after = [
        ("model_config(instance_type, model)", "model_config(model, instance_type=instance_type)"),
        (
            "model_config('ml.m4.xlarge', 'my-model')",
            "model_config('my-model', instance_type='ml.m4.xlarge')",
        ),
        (
            "model_config('ml.m4.xlarge', model='my-model')",
            "model_config(instance_type='ml.m4.xlarge', model='my-model')",
        ),
        (
            "model_config_from_estimator(instance_type, estimator, task_id, task_type)",
            "model_config_from_estimator(estimator, task_id, task_type, instance_type=instance_type)",
        ),
        (
            "model_config_from_estimator(instance_type, estimator, task_id=task_id, task_type=task_type)",
            "model_config_from_estimator(estimator, instance_type=instance_type, task_id=task_id, task_type=task_type)",
        ),
    ]

    for source, rewritten in before_and_after:
        tree = ast_call(source)
        arg_modifier.modify_node(tree)
        assert rewritten == pasta.dump(tree)
def test_node_should_be_modified_fit_without_tensorboard():
    """``fit()`` calls that do not pass ``run_tensorboard_locally`` are left alone."""
    remover = tf_legacy_mode.TensorBoardParameterRemover()

    for source in ("estimator.fit()", "tensorflow.fit()"):
        assert remover.node_should_be_modified(ast_call(source)) is False
def _test_node_should_be_modified(ctrs, should_modify=True):
    """Assert that ``FrameworkVersionEnforcer`` flags every constructor in *ctrs*, or flags none.

    Args:
        ctrs: Iterable of constructor-call source strings.
        should_modify: When true, each constructor must be flagged.
            When false, none may be.
    """
    enforcer = framework_version.FrameworkVersionEnforcer()

    for ctr in ctrs:
        flagged = enforcer.node_should_be_modified(ast_call(ctr))
        if not should_modify:
            assert not flagged, "{} was modified.".format(ctr)
        else:
            assert flagged, "{} wasn't modified.".format(ctr)
def test_create_endpoint_modify_node():
    """``deployment_image=`` is renamed to ``image_uri=`` in every create-endpoint call template."""
    renamer = renamed_params.SessionCreateEndpointImageURIRenamer()

    for template in CREATE_ENDPOINT_TEMPLATES:
        tree = ast_call(template.format("deployment_image=my_image"))
        renamer.modify_node(tree)

        assert template.format("image_uri=my_image") == pasta.dump(tree)
def test_create_model_modify_node():
    """``primary_container_image=`` is renamed to ``image_uri=`` in every create-model call template."""
    renamer = renamed_params.SessionCreateModelImageURIRenamer()

    for template in CREATE_MODEL_TEMPLATES:
        tree = ast_call(template.format("primary_container_image=my_image"))
        renamer.modify_node(tree)

        assert template.format("image_uri=my_image") == pasta.dump(tree)
def test_arg_from_keywords():
    """``parsing.arg_from_keywords`` returns the keyword node for the requested argument.

    The returned node's ``arg`` is the keyword name. Its ``value`` node
    evaluates to the literal passed in the call.
    """
    import ast

    kw_name = "framework_version"
    kw_value = "1.6.0"
    call = ast_call("MXNet({}='{}', py_version='py3', entry_point='run')".format(kw_name, kw_value))

    returned_kw = parsing.arg_from_keywords(call, kw_name)

    assert kw_name == returned_kw.arg
    # ``node.s`` is a deprecated alias (ast.Str / Constant.s, removed in
    # Python 3.14). ``literal_eval`` reads the literal from either node type.
    assert kw_value == ast.literal_eval(returned_kw.value)
def test_modify_node():
    """Every ``train_`` prefix is stripped from the parameter combinations passed to ``Estimator``."""
    remover = training_params.TrainPrefixRemover()

    for params in _parameter_combinations():
        source = "Estimator({})".format(params)
        tree = ast_call(source)
        remover.modify_node(tree)

        assert source.replace("train_", "") == pasta.dump(tree)
def test_constructor_modify_node():
    """Predictor constructors get ``endpoint=`` renamed to ``endpoint_name=``.

    ``RealTimePredictor`` is additionally renamed to ``Predictor``.
    Other predictor classes such as ``KMeansPredictor`` keep their names.
    """
    refactor = predictors.PredictorConstructorRefactor()

    expectations = (
        ("sagemaker.RealTimePredictor(endpoint='a')", "sagemaker.Predictor(endpoint_name='a')"),
        ("RealTimePredictor(endpoint='a')", "Predictor(endpoint_name='a')"),
        (
            "sagemaker.amazon.kmeans.KMeansPredictor(endpoint='a')",
            "sagemaker.amazon.kmeans.KMeansPredictor(endpoint_name='a')",
        ),
        ("KMeansPredictor(endpoint='a')", "KMeansPredictor(endpoint_name='a')"),
    )

    for before, after in expectations:
        tree = ast_call(before)
        refactor.modify_node(tree)
        assert after == pasta.dump(tree)
def test_modify_node():
    """``run_tensorboard_locally`` is dropped from ``fit()`` whatever its value."""
    remover = tf_legacy_mode.TensorBoardParameterRemover()

    for flag in ("True", "False"):
        tree = ast_call("estimator.fit(run_tensorboard_locally={})".format(flag))
        remover.modify_node(tree)

        assert "estimator.fit()" == pasta.dump(tree)
def test_node_should_be_modified_tf_constructor_script_mode():
    """TensorFlow constructors passing ``script_mode=True`` are flagged by the script-mode remover.

    Both the bare and ``sagemaker.tensorflow``-qualified forms are checked.
    """
    remover = deprecated_params.TensorFlowScriptModeParameterRemover()

    for prefix in ("", "sagemaker.tensorflow."):
        tree = ast_call("{}TensorFlow(script_mode=True)".format(prefix))
        assert remover.node_should_be_modified(tree) is True
def test_matches_any():
    """``matching.matches_any`` accepts any (name, namespace) pair from the mapping.

    Calls to unrelated names are rejected.
    """
    lookup = {
        "KMeansPredictor": ("sagemaker", "sagemaker.amazon.kmeans"),
        "Predictor": ("sagemaker.tensorflow.serving",),
    }

    expected_hits = (
        "KMeansPredictor()",
        "sagemaker.KMeansPredictor()",
        "sagemaker.amazon.kmeans.KMeansPredictor()",
        "Predictor()",
        "sagemaker.tensorflow.serving.Predictor()",
    )
    expected_misses = ("MXNet()", "sagemaker.mxnet.MXNet()")

    for source in expected_hits:
        assert matching.matches_any(ast_call(source), lookup)

    for source in expected_misses:
        assert not matching.matches_any(ast_call(source), lookup)