Beispiel #1
0
def test_get_task_definition_arn_provided_task_definition_arn():
    """An explicit ``task_definition_arn`` on the run config is returned as-is."""
    config_with_arn = ECSRun(task_definition_arn="my-taskdef-arn")
    flow_run = GraphQLResult(
        {"flow": GraphQLResult({"id": "flow-id", "version": 1})}
    )
    agent = ECSAgent()

    resolved_arn = agent.get_task_definition_arn(flow_run, config_with_arn)
    assert resolved_arn == "my-taskdef-arn"
Beispiel #2
0
def test_boto_kwargs():
    """Check how ``ECSAgent`` assembles ``boto_kwargs``.

    With no arguments, every credential/region key is ``None`` (left to be
    loaded from the environment) and retries use ``"standard"`` mode.
    Explicitly passed values, including a custom ``botocore_config``, are
    forwarded unchanged.
    """
    credential_keys = [
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
        "region_name",
    ]

    # Defaults: nothing explicit, so values are left to the environment.
    default_agent = ECSAgent()
    for key in credential_keys:
        assert default_agent.boto_kwargs[key] is None
    assert default_agent.boto_kwargs["config"].retries == {"mode": "standard"}

    # Explicit parameters are passed on unchanged.
    explicit = {key: letter for key, letter in zip(credential_keys, "abcd")}
    custom_agent = ECSAgent(
        botocore_config={"retries": {"mode": "adaptive", "max_attempts": 2}},
        **explicit,
    )
    for key, value in explicit.items():
        assert custom_agent.boto_kwargs[key] == value
    expected_retries = {"mode": "adaptive", "max_attempts": 2}
    assert custom_agent.boto_kwargs["config"].retries == expected_retries
Beispiel #3
0
def test_get_task_definition_arn(aws, kind):
    """Resolve a flow's task definition through the resource-groups tagging API.

    ``kind`` selects the mocked scenario:

    * ``"exists"``  - the tagging API returns one matching ARN, which is returned;
    * ``"missing"`` - the tagging API returns no matches, so ``None`` is returned;
    * anything else - the tagging API raises ``ClientError``, so ``None`` is returned.

    In every case the lookup must filter on the flow id/version tags and on
    the ``ecs:task-definition`` resource type.
    """
    get_resources = aws.resourcegroupstaggingapi.get_resources
    if kind == "exists":
        get_resources.return_value = {
            "ResourceTagMappingList": [{"ResourceARN": "my-taskdef-arn"}]
        }
    elif kind == "missing":
        get_resources.return_value = {"ResourceTagMappingList": []}
    else:
        from botocore.exceptions import ClientError

        get_resources.side_effect = ClientError({}, "GetResources")
    expected = "my-taskdef-arn" if kind == "exists" else None

    default_config = ECSRun()
    flow_run = GraphQLResult(
        {"flow": GraphQLResult({"id": "flow-id", "version": 1})}
    )
    agent = ECSAgent()

    assert agent.get_task_definition_arn(flow_run, default_config) == expected

    call_kwargs = get_resources.call_args[1]
    tag_filters = sorted(call_kwargs["TagFilters"], key=lambda entry: entry["Key"])
    assert tag_filters == [
        {"Key": "prefect:flow-id", "Values": ["flow-id"]},
        {"Key": "prefect:flow-version", "Values": ["1"]},
    ]
    assert call_kwargs["ResourceTypeFilters"] == ["ecs:task-definition"]
Beispiel #4
0
def test_boto_kwargs(monkeypatch):
    """Check how ``ECSAgent`` assembles ``boto_kwargs``.

    Covers three cases:

    * no arguments: credential/region keys are ``None`` (left to be loaded
      from the environment) and retries default to ``"standard"`` mode;
    * explicit arguments: passed through unchanged, including a custom
      ``botocore_config``;
    * ``AWS_RETRY_MODE`` set in the environment: the agent does not force a
      retry mode of its own.
    """
    # The "standard" default asserted below is only applied when
    # AWS_RETRY_MODE is unset (as the final section shows). Clear any value
    # inherited from the shell/CI so the test does not depend on the caller's
    # environment.
    monkeypatch.delenv("AWS_RETRY_MODE", raising=False)

    # Defaults to loaded from environment
    agent = ECSAgent()
    keys = [
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
        "region_name",
    ]
    for k in keys:
        assert agent.boto_kwargs[k] is None
    assert agent.boto_kwargs["config"].retries == {"mode": "standard"}

    # Explicit parameters are passed on
    kwargs = dict(zip(keys, "abcd"))
    agent = ECSAgent(
        botocore_config={"retries": {"mode": "adaptive", "max_attempts": 2}},
        **kwargs,
    )
    for k, v in kwargs.items():
        assert agent.boto_kwargs[k] == v
    assert agent.boto_kwargs["config"].retries == {
        "mode": "adaptive",
        "max_attempts": 2,
    }

    # Does not set 'standard' if env variable is set
    monkeypatch.setenv("AWS_RETRY_MODE", "adaptive")
    agent = ECSAgent()
    assert (agent.boto_kwargs["config"].retries or {}).get("mode") is None
Beispiel #5
0
 def get_run_task_kwargs(self, run_config, **kwargs):
     """Return ``ECSAgent.get_run_task_kwargs`` for a stub flow run.

     ``kwargs`` are forwarded to the ``ECSAgent`` constructor. The stub flow
     run uses ``Local`` storage, the serialized ``run_config`` and fixed
     id/version/name/core-version values.
     """
     agent = ECSAgent(**kwargs)
     flow_fields = {
         "storage": Local().serialize(),
         "run_config": run_config.serialize(),
         "id": "flow-id",
         "version": 1,
         "name": "Test Flow",
         "core_version": "0.13.0",
     }
     stub_run = GraphQLResult(
         {"flow": GraphQLResult(flow_fields), "id": "flow-run-id"}
     )
     return agent.get_run_task_kwargs(stub_run, run_config)
Beispiel #6
0
    def test_task_definition_path_local(self, tmpdir):
        """A local YAML ``task_definition_path`` populates ``agent.task_definition``."""
        expected = {"networkMode": "awsvpc", "cpu": 2048, "memory": 4096}
        yaml_path = str(tmpdir.join("task.yaml"))
        with open(yaml_path, "w") as stream:
            yaml.safe_dump(expected, stream)

        loaded = ECSAgent(task_definition_path=yaml_path).task_definition
        assert loaded == expected
Beispiel #7
0
def test_agent_defaults(default_task_definition):
    """A bare ``ECSAgent()`` has no config id, labels, cluster or task role,
    is named ``"agent"`` and uses the ``FARGATE`` launch type."""
    default_agent = ECSAgent()

    assert default_agent.agent_config_id is None
    assert not set(default_agent.labels)
    assert default_agent.name == "agent"
    assert default_agent.cluster is None
    assert default_agent.launch_type == "FARGATE"
    assert default_agent.task_role_arn is None
Beispiel #8
0
    def test_run_task_kwargs_path_local(self, tmpdir):
        """A local YAML ``run_task_kwargs_path`` populates ``agent.run_task_kwargs``."""
        expected = {"overrides": {"taskRoleArn": "my-task-role"}}
        yaml_path = str(tmpdir.join("kwargs.yaml"))
        with open(yaml_path, "w") as stream:
            yaml.safe_dump(expected, stream)

        ec2_agent = ECSAgent(launch_type="EC2", run_task_kwargs_path=yaml_path)
        assert ec2_agent.run_task_kwargs == expected
Beispiel #9
0
 def generate_task_definition(self, run_config, storage=None, **kwargs):
     """Return ``ECSAgent.generate_task_definition`` for a stub flow run.

     ``storage`` defaults to a fresh ``Local()`` when omitted; ``kwargs`` are
     forwarded to the ``ECSAgent`` constructor. The stub flow run carries the
     serialized storage and ``run_config`` plus fixed id/version/name values.
     """
     chosen_storage = Local() if storage is None else storage
     agent = ECSAgent(**kwargs)
     flow_fields = {
         "storage": chosen_storage.serialize(),
         "run_config": run_config.serialize(),
         "id": "flow-id",
         "version": 1,
         "name": "Test Flow",
         "core_version": "0.13.0",
     }
     stub_run = GraphQLResult(
         {"flow": GraphQLResult(flow_fields), "id": "flow-run-id"}
     )
     return agent.generate_task_definition(stub_run, run_config)
Beispiel #10
0
def start(
    token,
    api,
    agent_config_id,
    name,
    label,
    env,
    max_polls,
    agent_address,
    no_cloud_logs,
    log_level,
    cluster,
    launch_type,
    task_role_arn,
    task_definition,
    run_task_kwargs,
):
    """Start an ECS agent"""
    # NOTE: the docstring above is kept short and unchanged on purpose; it
    # presumably doubles as the CLI help text for this command.
    from prefect.agent.ecs.agent import ECSAgent

    # De-duplicate labels and give them a stable order.
    labels = sorted(set(label))
    # Each env entry is "KEY=VALUE". Split on the FIRST "=" only, so values
    # may themselves contain "=" (e.g. "OPTS=a=b" -> {"OPTS": "a=b"}). A
    # maxsplit of 2 would produce three parts for such input and make dict()
    # raise ValueError. Entries with no "=" still raise ValueError here.
    env_vars = dict(e.split("=", 1) for e in env)

    # CLI-provided values override the loaded config only while the agent runs.
    tmp_config = {
        "cloud.agent.auth_token": token or config.cloud.agent.auth_token,
        "cloud.agent.level": log_level or config.cloud.agent.level,
        "cloud.api": api or config.cloud.api,
    }
    with set_temporary_config(tmp_config):
        agent = ECSAgent(
            agent_config_id=agent_config_id,
            name=name,
            labels=labels,
            env_vars=env_vars,
            max_polls=max_polls,
            agent_address=agent_address,
            no_cloud_logs=no_cloud_logs,
            cluster=cluster,
            launch_type=launch_type,
            task_role_arn=task_role_arn,
            task_definition_path=task_definition,
            run_task_kwargs_path=run_task_kwargs,
        )
        agent.start()
Beispiel #11
0
    def test_task_definition_path_remote(self, monkeypatch):
        """An ``s3://`` ``task_definition_path`` is read via ``read_bytes_from_path``."""
        expected = {"networkMode": "awsvpc", "cpu": 2048, "memory": 4096}
        remote_path = "s3://bucket/test.yaml"

        # return_value takes precedence over wraps, so no real S3 read happens.
        reader = MagicMock(
            wraps=read_bytes_from_path, return_value=yaml.safe_dump(expected)
        )
        monkeypatch.setattr("prefect.agent.ecs.agent.read_bytes_from_path", reader)

        remote_agent = ECSAgent(task_definition_path=remote_path)
        assert remote_agent.task_definition == expected
        assert reader.call_args[0] == (remote_path,)
Beispiel #12
0
    def test_run_task_kwargs_path_remote(self, monkeypatch):
        """An ``s3://`` ``run_task_kwargs_path`` is read via ``read_bytes_from_path``."""
        expected = {"overrides": {"taskRoleArn": "my-task-role"}}
        serialized = yaml.safe_dump(expected)
        s3_path = "s3://bucket/kwargs.yaml"

        def fake_read(path):
            # Serve the S3 path from memory; anything else uses the real reader.
            if path == s3_path:
                return serialized
            return read_bytes_from_path(path)

        monkeypatch.setattr("prefect.agent.ecs.agent.read_bytes_from_path", fake_read)
        ec2_agent = ECSAgent(launch_type="EC2", run_task_kwargs_path=s3_path)
        assert ec2_agent.run_task_kwargs == expected
Beispiel #13
0
    def test_infer_network_configuration_not_called_if_configured(self, aws, tmpdir):
        """A user-supplied ``networkConfiguration`` is kept and EC2 is never queried."""
        subnets = ["one", "two"]
        expected = {"networkConfiguration": {"awsvpcConfiguration": {"subnets": subnets}}}
        yaml_path = str(tmpdir.join("kwargs.yaml"))
        with open(yaml_path, "w") as stream:
            yaml.safe_dump(expected, stream)

        configured_agent = ECSAgent(run_task_kwargs_path=yaml_path)
        assert configured_agent.run_task_kwargs == expected
        # No EC2 calls means network settings were not inferred.
        assert not aws.ec2.mock_calls
Beispiel #14
0
def ecs_fargate_agent_definition(mocker):
    """Return the ``"ecs_fargate"`` agent definition backed by a stub ``ECSAgent``.

    ``_get_ecs_fargate_agent_definition`` is patched to return a pre-built
    agent (fixed region, cluster, label and run-task kwargs file). The result
    of ``start_agent.get_agent_definition("ecs_fargate")`` is then returned.
    """
    kwargs_file = "tests/unit/prefect_setup/prefect_agent/agent_conf_mock.yaml"
    stub_agent = ECSAgent(
        region_name="ap-southeast-2",
        cluster="unit-test-cluster",
        labels=["test_dataflow_automation"],
        run_task_kwargs_path=kwargs_file,
    )
    mocker.patch(
        "prefect_setup.prefect_agent.start_agent._get_ecs_fargate_agent_definition",
        return_value=stub_agent,
    )
    return start_agent.get_agent_definition("ecs_fargate")
Beispiel #15
0
 def test_run_task_kwargs_path_read_errors(self, tmpdir):
     """A ``run_task_kwargs_path`` pointing at a missing file makes construction fail."""
     missing_file = str(tmpdir.join("missing.yaml"))
     with pytest.raises(Exception):
         ECSAgent(run_task_kwargs_path=missing_file)
Beispiel #16
0
 def test_task_definition_path_read_errors(self, tmpdir):
     """A ``task_definition_path`` pointing at a missing file makes construction fail."""
     missing_file = str(tmpdir.join("missing.yaml"))
     with pytest.raises(Exception):
         ECSAgent(task_definition_path=missing_file)
Beispiel #17
0
 def test_run_task_kwargs_path_default(self):
     """Without ``run_task_kwargs_path``, an EC2 agent has empty ``run_task_kwargs``."""
     ec2_agent = ECSAgent(launch_type="EC2")
     assert not ec2_agent.run_task_kwargs and ec2_agent.run_task_kwargs == {}
Beispiel #18
0
 def test_task_definition_path_default(self, default_task_definition):
     """Without ``task_definition_path``, the agent matches the ``default_task_definition`` fixture."""
     assert ECSAgent().task_definition == default_task_definition