def test_get_task_definition_arn_provided_task_definition_arn():
    """An explicit ``task_definition_arn`` on the run config is returned as-is."""
    config = ECSRun(task_definition_arn="my-taskdef-arn")
    flow = GraphQLResult({"id": "flow-id", "version": 1})
    flow_run = GraphQLResult({"flow": flow})
    agent = ECSAgent()
    assert agent.get_task_definition_arn(flow_run, config) == "my-taskdef-arn"
def test_boto_kwargs():
    """Check the boto client kwargs ECSAgent derives from its constructor arguments.

    NOTE(review): if this lives in the same module as the
    ``test_boto_kwargs(monkeypatch)`` variant, the later definition shadows
    this one and it never runs -- confirm and rename or delete one of them.
    """
    import os
    from unittest import mock

    keys = [
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
        "region_name",
    ]
    # The "standard" retry-mode default is only applied when AWS_RETRY_MODE
    # is unset, so isolate the test from the caller's environment.
    # patch.dict restores os.environ when the block exits.
    with mock.patch.dict(os.environ):
        os.environ.pop("AWS_RETRY_MODE", None)

        # Defaults to loaded from environment
        agent = ECSAgent()
        for k in keys:
            assert agent.boto_kwargs[k] is None
        assert agent.boto_kwargs["config"].retries == {"mode": "standard"}

        # Explicit parameters are passed on
        kwargs = dict(zip(keys, "abcd"))
        agent = ECSAgent(
            botocore_config={"retries": {"mode": "adaptive", "max_attempts": 2}},
            **kwargs
        )
        for k, v in kwargs.items():
            assert agent.boto_kwargs[k] == v
        assert agent.boto_kwargs["config"].retries == {
            "mode": "adaptive",
            "max_attempts": 2,
        }
def test_get_task_definition_arn(aws, kind):
    """Look up a flow's task definition ARN through the resource tagging API.

    ``kind`` selects the mocked outcome of ``get_resources``:
    ``"exists"`` (one matching ARN), ``"missing"`` (no matches), or anything
    else (the call raises ``ClientError``). Only ``"exists"`` yields an ARN.
    """
    get_resources = aws.resourcegroupstaggingapi.get_resources
    if kind == "exists":
        get_resources.return_value = {
            "ResourceTagMappingList": [{"ResourceARN": "my-taskdef-arn"}]
        }
        expected = "my-taskdef-arn"
    elif kind == "missing":
        get_resources.return_value = {"ResourceTagMappingList": []}
        expected = None
    else:
        from botocore.exceptions import ClientError

        get_resources.side_effect = ClientError({}, "GetResources")
        expected = None

    run_config = ECSRun()
    flow_run = GraphQLResult({"flow": GraphQLResult({"id": "flow-id", "version": 1})})
    result = ECSAgent().get_task_definition_arn(flow_run, run_config)
    assert result == expected

    # The lookup must filter on the flow id/version tags and the
    # task-definition resource type.
    _, call_kwargs = get_resources.call_args
    tag_filters = sorted(call_kwargs["TagFilters"], key=lambda f: f["Key"])
    assert tag_filters == [
        {"Key": "prefect:flow-id", "Values": ["flow-id"]},
        {"Key": "prefect:flow-version", "Values": ["1"]},
    ]
    assert call_kwargs["ResourceTypeFilters"] == ["ecs:task-definition"]
def test_boto_kwargs(monkeypatch):
    """Check the boto client kwargs ECSAgent derives from its arguments and env."""
    # The "standard" retry-mode default asserted first is only applied when
    # AWS_RETRY_MODE is unset (verified at the end of this test), so clear any
    # value inherited from the surrounding environment.
    monkeypatch.delenv("AWS_RETRY_MODE", raising=False)

    # Defaults to loaded from environment
    agent = ECSAgent()
    keys = [
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
        "region_name",
    ]
    for k in keys:
        assert agent.boto_kwargs[k] is None
    assert agent.boto_kwargs["config"].retries == {"mode": "standard"}

    # Explicit parameters are passed on
    kwargs = dict(zip(keys, "abcd"))
    agent = ECSAgent(
        botocore_config={"retries": {"mode": "adaptive", "max_attempts": 2}},
        **kwargs
    )
    for k, v in kwargs.items():
        assert agent.boto_kwargs[k] == v
    assert agent.boto_kwargs["config"].retries == {
        "mode": "adaptive",
        "max_attempts": 2,
    }

    # Does not set 'standard' if env variable is set
    monkeypatch.setenv("AWS_RETRY_MODE", "adaptive")
    agent = ECSAgent()
    assert (agent.boto_kwargs["config"].retries or {}).get("mode") is None
def get_run_task_kwargs(self, run_config, **kwargs):
    """Return ``ECSAgent(**kwargs).get_run_task_kwargs`` for a fake flow run.

    The fake flow uses ``Local`` storage and the serialized ``run_config``.
    """
    agent = ECSAgent(**kwargs)
    flow = GraphQLResult(
        {
            "storage": Local().serialize(),
            "run_config": run_config.serialize(),
            "id": "flow-id",
            "version": 1,
            "name": "Test Flow",
            "core_version": "0.13.0",
        }
    )
    flow_run = GraphQLResult({"flow": flow, "id": "flow-run-id"})
    return agent.get_run_task_kwargs(flow_run, run_config)
def test_task_definition_path_local(self, tmpdir):
    """A task definition YAML file on local disk is loaded verbatim."""
    expected = {"networkMode": "awsvpc", "cpu": 2048, "memory": 4096}
    yaml_path = str(tmpdir.join("task.yaml"))
    with open(yaml_path, "w") as stream:
        yaml.safe_dump(expected, stream)
    loaded = ECSAgent(task_definition_path=yaml_path).task_definition
    assert loaded == expected
def test_agent_defaults(default_task_definition):
    """A bare ``ECSAgent()`` has no config id, labels, cluster or task role,
    is named "agent", and uses the FARGATE launch type."""
    agent = ECSAgent()
    assert agent.agent_config_id is None
    assert not set(agent.labels)
    assert agent.name == "agent"
    assert agent.cluster is None
    assert agent.launch_type == "FARGATE"
    assert agent.task_role_arn is None
def test_run_task_kwargs_path_local(self, tmpdir):
    """run_task kwargs stored in a local YAML file are loaded verbatim."""
    expected = {"overrides": {"taskRoleArn": "my-task-role"}}
    yaml_path = str(tmpdir.join("kwargs.yaml"))
    with open(yaml_path, "w") as stream:
        yaml.safe_dump(expected, stream)
    agent = ECSAgent(launch_type="EC2", run_task_kwargs_path=yaml_path)
    assert agent.run_task_kwargs == expected
def generate_task_definition(self, run_config, storage=None, **kwargs):
    """Return ``ECSAgent(**kwargs).generate_task_definition`` for a fake flow run.

    ``storage`` defaults to a fresh ``Local()`` when not given.
    """
    flow_storage = Local() if storage is None else storage
    agent = ECSAgent(**kwargs)
    flow = GraphQLResult(
        {
            "storage": flow_storage.serialize(),
            "run_config": run_config.serialize(),
            "id": "flow-id",
            "version": 1,
            "name": "Test Flow",
            "core_version": "0.13.0",
        }
    )
    flow_run = GraphQLResult({"flow": flow, "id": "flow-run-id"})
    return agent.generate_task_definition(flow_run, run_config)
def start(
    token,
    api,
    agent_config_id,
    name,
    label,
    env,
    max_polls,
    agent_address,
    no_cloud_logs,
    log_level,
    cluster,
    launch_type,
    task_role_arn,
    task_definition,
    run_task_kwargs,
):
    """Start an ECS agent"""
    # NOTE: the docstring above is kept verbatim; it may be used as CLI help text.
    from prefect.agent.ecs.agent import ECSAgent

    # De-duplicate and sort the repeated label options.
    labels = sorted(set(label))
    # Each env entry is "KEY=VALUE". Split on the first "=" only, so values
    # that themselves contain "=" (e.g. "KEY=a=b") stay intact. A maxsplit
    # of 2 would yield three parts, and dict() would raise ValueError.
    env_vars = dict(e.split("=", 1) for e in env)

    # Explicit CLI values override the loaded config for this agent's lifetime.
    tmp_config = {
        "cloud.agent.auth_token": token or config.cloud.agent.auth_token,
        "cloud.agent.level": log_level or config.cloud.agent.level,
        "cloud.api": api or config.cloud.api,
    }
    with set_temporary_config(tmp_config):
        agent = ECSAgent(
            agent_config_id=agent_config_id,
            name=name,
            labels=labels,
            env_vars=env_vars,
            max_polls=max_polls,
            agent_address=agent_address,
            no_cloud_logs=no_cloud_logs,
            cluster=cluster,
            launch_type=launch_type,
            task_role_arn=task_role_arn,
            task_definition_path=task_definition,
            run_task_kwargs_path=run_task_kwargs,
        )
        agent.start()
def test_task_definition_path_remote(self, monkeypatch):
    """A remote (S3) task definition path is read through ``read_bytes_from_path``."""
    expected = {"networkMode": "awsvpc", "cpu": 2048, "memory": 4096}
    payload = yaml.safe_dump(expected)
    reader = MagicMock(wraps=read_bytes_from_path, return_value=payload)
    monkeypatch.setattr("prefect.agent.ecs.agent.read_bytes_from_path", reader)

    agent = ECSAgent(task_definition_path="s3://bucket/test.yaml")
    assert agent.task_definition == expected

    positional_args, _ = reader.call_args
    assert positional_args == ("s3://bucket/test.yaml",)
def test_run_task_kwargs_path_remote(self, monkeypatch):
    """run_task kwargs at a remote (S3) path are fetched and parsed as YAML."""
    expected = {"overrides": {"taskRoleArn": "my-task-role"}}
    payload = yaml.safe_dump(expected)
    s3_path = "s3://bucket/kwargs.yaml"

    def fake_read(path):
        # Only the S3 path is faked; any other path falls through to the
        # real reader.
        if path != s3_path:
            return read_bytes_from_path(path)
        return payload

    monkeypatch.setattr("prefect.agent.ecs.agent.read_bytes_from_path", fake_read)
    agent = ECSAgent(launch_type="EC2", run_task_kwargs_path=s3_path)
    assert agent.run_task_kwargs == expected
def test_infer_network_configuration_not_called_if_configured(self, aws, tmpdir):
    """With networkConfiguration supplied in run_task kwargs, no EC2 calls are made."""
    subnets = ["one", "two"]
    expected = {"networkConfiguration": {"awsvpcConfiguration": {"subnets": subnets}}}
    yaml_path = str(tmpdir.join("kwargs.yaml"))
    with open(yaml_path, "w") as stream:
        yaml.safe_dump(expected, stream)

    agent = ECSAgent(run_task_kwargs_path=yaml_path)
    assert agent.run_task_kwargs == expected
    # The mocked EC2 client must be untouched.
    assert not aws.ec2.mock_calls
def ecs_fargate_agent_definition(mocker):
    """Patch the ECS Fargate agent factory to return a test agent, then
    resolve the "ecs_fargate" agent definition through ``start_agent``."""
    conf_path = "tests/unit/prefect_setup/prefect_agent/agent_conf_mock.yaml"
    fake_agent = ECSAgent(
        region_name="ap-southeast-2",
        cluster="unit-test-cluster",
        labels=["test_dataflow_automation"],
        run_task_kwargs_path=conf_path,
    )
    mocker.patch(
        "prefect_setup.prefect_agent.start_agent._get_ecs_fargate_agent_definition",
        return_value=fake_agent,
    )
    return start_agent.get_agent_definition("ecs_fargate")
def test_run_task_kwargs_path_read_errors(self, tmpdir):
    """Pointing ``run_task_kwargs_path`` at a nonexistent file fails construction."""
    missing_file = str(tmpdir.join("missing.yaml"))
    with pytest.raises(Exception):
        ECSAgent(run_task_kwargs_path=missing_file)
def test_task_definition_path_read_errors(self, tmpdir):
    """Pointing ``task_definition_path`` at a nonexistent file fails construction."""
    missing_file = str(tmpdir.join("missing.yaml"))
    with pytest.raises(Exception):
        ECSAgent(task_definition_path=missing_file)
def test_run_task_kwargs_path_default(self):
    """Without ``run_task_kwargs_path`` the agent's run_task kwargs are empty."""
    kwargs = ECSAgent(launch_type="EC2").run_task_kwargs
    assert kwargs == {}
def test_task_definition_path_default(self, default_task_definition):
    """Without ``task_definition_path`` the agent uses the default task definition."""
    loaded = ECSAgent().task_definition
    assert loaded == default_task_definition