Example #1
0
def test_arg_order_modify_node():
    """Check that ModelConfigArgModifier rewrites argument order in place.

    Each case pairs a source-code call with the text that ``pasta.dump`` is
    expected to produce after ``modify_node`` has rewritten the parsed node.
    In every case, the leading positional ``instance_type`` argument ends up
    passed as an ``instance_type=`` keyword argument. The remaining arguments
    stay in their original relative order.
    """
    cases = [
        (
            "model_config(instance_type, model)",
            "model_config(model, instance_type=instance_type)",
        ),
        (
            "model_config('ml.m4.xlarge', 'my-model')",
            "model_config('my-model', instance_type='ml.m4.xlarge')",
        ),
        (
            "model_config('ml.m4.xlarge', model='my-model')",
            "model_config(instance_type='ml.m4.xlarge', model='my-model')",
        ),
        (
            "model_config_from_estimator(instance_type, estimator, task_id, task_type)",
            "model_config_from_estimator(estimator, task_id, task_type, instance_type=instance_type)",
        ),
        (
            "model_config_from_estimator(instance_type, estimator, task_id=task_id, task_type=task_type)",
            "model_config_from_estimator(estimator, instance_type=instance_type, task_id=task_id, task_type=task_type)",
        ),
    ]

    arg_modifier = airflow.ModelConfigArgModifier()

    for source, want in cases:
        call_node = ast_call(source)
        # The return value of modify_node is ignored: the test relies on the
        # node being mutated in place and then dumped back to source text.
        arg_modifier.modify_node(call_node)
        assert pasta.dump(call_node) == want
Example #2
0
def test_arg_order_node_should_be_modified_random_function_call():
    """An unrelated function call must not be selected for modification.

    ``node_should_be_modified`` is expected to return exactly ``False``
    (checked by identity) for a call that is not a model-config call.
    """
    arg_modifier = airflow.ModelConfigArgModifier()
    unrelated_call = ast_call(
        "sagemaker.workflow.airflow.prepare_framework_container_def()"
    )
    assert arg_modifier.node_should_be_modified(unrelated_call) is False
Example #3
0
def test_arg_order_node_should_be_modified_model_config_without_args():
    """Model-config calls with no arguments must not be selected for modification.

    Every entry of ``MODEL_CONFIG_CALL_TEMPLATES`` is formatted with an empty
    string. Presumably this fills the argument-list placeholder, but that
    should be verified against the template definitions. The parsed call must
    then make ``node_should_be_modified`` return exactly ``False``.
    """
    arg_modifier = airflow.ModelConfigArgModifier()

    for call_template in MODEL_CONFIG_CALL_TEMPLATES:
        empty_call_source = call_template.format("")
        empty_call = ast_call(empty_call_source)
        assert arg_modifier.node_should_be_modified(empty_call) is False