def test_arg_order_modify_node():
    """Check that ModelConfigArgModifier.modify_node reorders legacy arguments.

    Each case pairs a call written with ``instance_type`` as the leading
    positional argument (the old signature) with the source text the
    modifier is expected to produce. In every expected form, ``instance_type``
    is passed by keyword.
    """
    # (input call source, expected source after modify_node)
    cases = (
        (
            "model_config(instance_type, model)",
            "model_config(model, instance_type=instance_type)",
        ),
        (
            "model_config('ml.m4.xlarge', 'my-model')",
            "model_config('my-model', instance_type='ml.m4.xlarge')",
        ),
        (
            "model_config('ml.m4.xlarge', model='my-model')",
            "model_config(instance_type='ml.m4.xlarge', model='my-model')",
        ),
        (
            "model_config_from_estimator(instance_type, estimator, task_id, task_type)",
            "model_config_from_estimator(estimator, task_id, task_type, instance_type=instance_type)",
        ),
        (
            "model_config_from_estimator(instance_type, estimator, task_id=task_id, task_type=task_type)",
            "model_config_from_estimator(estimator, instance_type=instance_type, task_id=task_id, task_type=task_type)",
        ),
    )

    arg_modifier = airflow.ModelConfigArgModifier()

    for source, want in cases:
        parsed = ast_call(source)
        # modify_node mutates the node in place; inspect it via pasta.dump.
        arg_modifier.modify_node(parsed)
        rendered = pasta.dump(parsed)
        assert want == rendered
def test_arg_order_node_should_be_modified_random_function_call():
    """An unrelated sagemaker.workflow.airflow call must not be selected for rewriting."""
    unrelated_call = ast_call(
        "sagemaker.workflow.airflow.prepare_framework_container_def()"
    )
    arg_modifier = airflow.ModelConfigArgModifier()

    # Strict identity check: the predicate must return exactly False.
    verdict = arg_modifier.node_should_be_modified(unrelated_call)
    assert verdict is False
def test_arg_order_node_should_be_modified_model_config_without_args():
    """Model-config calls with an empty argument list are not selected for rewriting.

    Each entry of MODEL_CONFIG_CALL_TEMPLATES is filled with an empty string.
    The resulting call must make node_should_be_modified return exactly False.
    """
    arg_modifier = airflow.ModelConfigArgModifier()

    for tpl in MODEL_CONFIG_CALL_TEMPLATES:
        empty_call = ast_call(tpl.format(""))
        verdict = arg_modifier.node_should_be_modified(empty_call)
        assert verdict is False