def test_image_arg_modify_node():
    """Check that ``modify_node`` rewrites an ``image=`` keyword to ``image_uri=``.

    Each case pairs the source text of a call with the source text that
    ``pasta.dump`` is expected to produce after the renamer has modified the
    node built by ``ast_call``. The argument value itself is left untouched.
    """
    cases = [
        ("model_config(image='image:latest')", "model_config(image_uri='image:latest')"),
        (
            "model_config_from_estimator(image=my_image)",
            "model_config_from_estimator(image_uri=my_image)",
        ),
    ]

    # A single renamer instance is reused across all cases.
    renamer = airflow.ModelConfigImageURIRenamer()

    for source, wanted in cases:
        call_node = ast_call(source)
        renamer.modify_node(call_node)
        rendered = pasta.dump(call_node)
        assert wanted == rendered
# Example No. 2
def test_image_arg_node_should_be_modified_model_config_without_arg():
    """Check that argument-less model-config calls are not flagged for renaming.

    Every template in ``MODEL_CONFIG_CALL_TEMPLATES`` has its placeholder filled
    with an empty string. Per the test name, the resulting call therefore
    carries no argument to rename, and ``node_should_be_modified`` must
    return exactly ``False``.
    """
    renamer = airflow.ModelConfigImageURIRenamer()

    for call_template in MODEL_CONFIG_CALL_TEMPLATES:
        call_source = call_template.format("")
        call_node = ast_call(call_source)
        flagged = renamer.node_should_be_modified(call_node)
        assert flagged is False
# Example No. 3
def test_image_arg_node_should_be_modified_random_function_call():
    """Check that a call to an unrelated function is not flagged by the renamer."""
    unrelated_source = "sagemaker.workflow.airflow.prepare_framework_container_def()"
    call_node = ast_call(unrelated_source)

    renamer = airflow.ModelConfigImageURIRenamer()
    flagged = renamer.node_should_be_modified(call_node)
    assert flagged is False