コード例 #1
0
 async def _create_worker_task_definition_arn(self):
     """Register the ECS task definition used to launch Dask workers.

     The definition has a single essential ``dask-worker`` container running
     ``self.image`` with ``self.environment``, sized by ``self._worker_cpu``
     (ECS CPU units) and ``self._worker_mem`` (MiB), and logging to the
     ``self.cloudwatch_logs_group`` CloudWatch group via the awslogs driver.
     When ``self._fargate`` is set the definition also requires Fargate
     compatibility.

     A finalizer is registered so that
     ``self._delete_worker_task_definition_arn`` is run through ``self.sync``
     when this object is finalized.

     Returns:
         The ARN of the newly registered task definition.
     """
     # ECS CPU is expressed in 1/1024ths of a vCPU. Sub-vCPU workers (e.g.
     # 256 or 512 units) would otherwise be started with ``--nthreads 0``;
     # always give the worker at least one thread.
     nthreads = max(1, int(self._worker_cpu / 1024))
     # ECS memory is expressed in MiB. Passing it through unchanged avoids
     # truncating to whole GB, which dropped fractional gigabytes and turned
     # workers below 1024 MiB into "0GB" (zero means "no limit" to dask).
     memory_limit = "{}MiB".format(int(self._worker_mem))
     response = await self._clients["ecs"].register_task_definition(
         family="{}-{}".format(self.cluster_name, "worker"),
         taskRoleArn=self._task_role_arn,
         executionRoleArn=self._execution_role_arn,
         networkMode="awsvpc",
         containerDefinitions=[{
             "name": "dask-worker",
             "image": self.image,
             "cpu": self._worker_cpu,
             "memory": self._worker_mem,
             "memoryReservation": self._worker_mem,
             "essential": True,
             "environment": dict_to_aws(self.environment, key_string="name"),
             "command": [
                 "dask-worker",
                 "--nthreads",
                 "{}".format(nthreads),
                 "--memory-limit",
                 memory_limit,
                 "--death-timeout",
                 "60",
             ],
             "logConfiguration": {
                 "logDriver": "awslogs",
                 "options": {
                     "awslogs-region":
                     self._clients["ecs"].meta.region_name,
                     "awslogs-group": self.cloudwatch_logs_group,
                     "awslogs-stream-prefix":
                     self._cloudwatch_logs_stream_prefix,
                     "awslogs-create-group": "true",
                 },
             },
         }],
         volumes=[],
         requiresCompatibilities=["FARGATE"] if self._fargate else [],
         cpu=str(self._worker_cpu),
         memory=str(self._worker_mem),
         tags=dict_to_aws(self.tags),
     )
     # NOTE: ``self.sync`` is a bound method, so the finalizer itself keeps
     # ``self`` alive; per the weakref.finalize docs it will therefore only
     # fire at interpreter exit (or when called explicitly), not on GC.
     weakref.finalize(self, self.sync,
                      self._delete_worker_task_definition_arn)
     return response["taskDefinition"]["taskDefinitionArn"]
コード例 #2
0
    async def _create_task_role(self):
        """Create the IAM role that Dask tasks assume while executing.

        The role's trust policy allows ``ecs-tasks.amazonaws.com`` to assume
        it, it is tagged with ``self.tags``, and every policy ARN in
        ``self._task_role_policies`` is attached to it. A finalizer is
        registered so that ``self._delete_role`` is run (through
        ``self.sync``) for the role when this object is finalized.

        Returns:
            The ARN of the created role.
        """
        response = await self._clients["iam"].create_role(
            RoleName=self._task_role_name,
            AssumeRolePolicyDocument="""{
            "Version": "2012-10-17",
            "Statement": [
                {
                "Effect": "Allow",
                "Principal": {
                    "Service": "ecs-tasks.amazonaws.com"
                },
                "Action": "sts:AssumeRole"
                }
            ]
            }""",
            Description="A role for dask tasks to use when executing",
            Tags=dict_to_aws(self.tags, upper=True),
        )
        # The role exists from this point on. Register its cleanup before
        # attaching policies so that a failing attach (for instance an invalid
        # ARN in the user-supplied ``self._task_role_policies``) does not leave
        # an orphaned IAM role behind.
        weakref.finalize(self, self.sync, self._delete_role,
                         self._task_role_name)

        for policy_arn in self._task_role_policies:
            await self._clients["iam"].attach_role_policy(
                RoleName=self._task_role_name, PolicyArn=policy_arn)

        return response["Role"]["Arn"]
コード例 #3
0
 async def _create_execution_role(self):
     """Create the IAM execution role ECS uses to launch the tasks.

     The role's trust policy allows ``ecs-tasks.amazonaws.com`` to assume it,
     it is tagged with ``self.tags``, and three AWS-managed policies are
     attached (ECR read-only, CloudWatch Logs full access and the EC2
     container service role). A finalizer is registered so that
     ``self._delete_role`` is run (through ``self.sync``) for the role when
     this object is finalized.

     Returns:
         The ARN of the created role.
     """
     response = await self._clients["iam"].create_role(
         RoleName=self._execution_role_name,
         AssumeRolePolicyDocument="""{
             "Version": "2012-10-17",
             "Statement": [
                 {
                 "Effect": "Allow",
                 "Principal": {
                     "Service": "ecs-tasks.amazonaws.com"
                 },
                 "Action": "sts:AssumeRole"
                 }
             ]
             }""",
         Description="A role for ECS to use when executing",
         Tags=dict_to_aws(self.tags, upper=True),
     )
     # The role exists from this point on. Register its cleanup before
     # attaching policies so a failing attach (e.g. missing IAM permissions)
     # does not leave an orphaned IAM role behind.
     weakref.finalize(self, self.sync, self._delete_role,
                      self._execution_role_name)
     # Attached in the same order as before.
     for policy_arn in (
             "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryReadOnly",
             "arn:aws:iam::aws:policy/CloudWatchLogsFullAccess",
             "arn:aws:iam::aws:policy/service-role/AmazonEC2ContainerServiceRole",
     ):
         await self._clients["iam"].attach_role_policy(
             RoleName=self._execution_role_name,
             PolicyArn=policy_arn,
         )
     return response["Role"]["Arn"]
コード例 #4
0
    async def start(self):
        """Run this task on ECS and wait until it is up.

        ``run_task`` is retried once per second for as long as
        ``timeout.run()`` (a 60 second ``Timeout``) allows; each failure is
        recorded on the timeout via ``set_exception``. Once launched, the task
        is polled while its ``lastStatus`` is PENDING or PROVISIONING. The
        task's elastic network interface is then described to fill in
        ``public_ip`` (``None`` when the interface has no public address) and
        ``private_ip``, ``address`` is read from the task logs, and ``status``
        becomes ``"running"``.

        Raises:
            RuntimeError: if the task does not reach a running state.
        """
        timeout = Timeout(
            60, "Unable to start %s after 60 seconds" % self.task_type)
        while timeout.run():
            try:
                # Tags are only supported if you opt into long arn format so
                # we need to check for that
                kwargs = (
                    {"tags": dict_to_aws(self.tags)}
                    if await self._is_long_arn_format_enabled() else {}
                )
                [self.task] = (await self._clients["ecs"].run_task(
                    cluster=self.cluster_arn,
                    taskDefinition=self.task_definition_arn,
                    overrides=self.overrides,
                    count=1,
                    launchType="FARGATE" if self.fargate else "EC2",
                    networkConfiguration={
                        "awsvpcConfiguration": {
                            "subnets": self._vpc_subnets,
                            "securityGroups": self._security_groups,
                            "assignPublicIp":
                            "ENABLED" if self._use_public_ip else "DISABLED",
                        }
                    },
                    **kwargs))["tasks"]
                break
            except Exception as e:
                # Keep the most recent failure for the timeout, then retry.
                timeout.set_exception(e)
                await asyncio.sleep(1)

        self.task_arn = self.task["taskArn"]
        while self.task["lastStatus"] in ["PENDING", "PROVISIONING"]:
            await asyncio.sleep(1)
            await self._update_task()
        if not await self._task_is_running():
            raise RuntimeError("%s failed to start" % type(self).__name__)
        [eni] = [
            attachment for attachment in self.task["attachments"]
            if attachment["type"] == "ElasticNetworkInterface"
        ]
        [network_interface_id] = [
            detail["value"] for detail in eni["details"]
            if detail["name"] == "networkInterfaceId"
        ]
        eni = await self._clients["ec2"].describe_network_interfaces(
            NetworkInterfaceIds=[network_interface_id])
        [interface] = eni["NetworkInterfaces"]
        # EC2 only includes "Association" when a public IPv4 address is
        # attached. With assignPublicIp DISABLED (``self._use_public_ip``
        # false) indexing it directly raised KeyError, so start() always
        # failed for tasks launched without a public IP.
        self.public_ip = interface.get("Association", {}).get("PublicIp")
        self.private_ip = interface["PrivateIpAddresses"][0][
            "PrivateIpAddress"]
        self.address = await self._get_address_from_logs()
        self.status = "running"
コード例 #5
0
 async def _create_cluster(self):
     """Create a new ECS cluster for this deployment and return its ARN.

     Only permitted on Fargate. The name comes from
     ``self._cluster_name_template``: environment variables are expanded
     first, then the ``{uuid}`` placeholder is filled with the first ten
     characters of a random UUID. The result is stored on
     ``self.cluster_name``, the cluster is tagged with ``self.tags``, and a
     finalizer is registered to run ``self._delete_cluster`` through
     ``self.sync``.

     Raises:
         RuntimeError: if ``self._fargate`` is false.
     """
     if not self._fargate:
         raise RuntimeError(
             "You must specify a cluster when not using Fargate.")

     # Two-step assignment: the expanded template is stored before the
     # placeholder substitution, exactly as the cluster name evolves.
     self.cluster_name = dask.config.expand_environment_variables(
         self._cluster_name_template)
     self.cluster_name = self.cluster_name.format(
         uuid=str(uuid.uuid4())[:10])

     ecs = self._clients["ecs"]
     created = await ecs.create_cluster(clusterName=self.cluster_name,
                                        tags=dict_to_aws(self.tags))
     weakref.finalize(self, self.sync, self._delete_cluster)
     cluster_info = created["cluster"]
     return cluster_info["clusterArn"]
コード例 #6
0
 async def _create_security_groups(self):
     """Create the dask-ecs security group and return its ID in a list.

     The group is named ``self.cluster_name`` inside VPC ``self._vpc`` and
     receives two ingress rules:

     * TCP ports 8786-8787 from any IPv4 (``0.0.0.0/0``) or IPv6 (``::/0``)
       address;
     * TCP ports 0-65535 from members of the same group.

     It is tagged with ``self.tags`` and a finalizer is registered to run
     ``self._delete_security_groups`` through ``self.sync``.
     """
     ec2 = self._clients["ec2"]
     created = await ec2.create_security_group(
         Description="A security group for dask-ecs",
         GroupName=self.cluster_name,
         VpcId=self._vpc,
         DryRun=False,
     )
     group_id = created["GroupId"]

     open_dask_ports = {
         "IpProtocol": "TCP",
         "FromPort": 8786,
         "ToPort": 8787,
         "IpRanges": [{"CidrIp": "0.0.0.0/0", "Description": "Anywhere"}],
         "Ipv6Ranges": [{"CidrIpv6": "::/0", "Description": "Anywhere"}],
     }
     same_group_traffic = {
         "IpProtocol": "TCP",
         "FromPort": 0,
         "ToPort": 65535,
         "UserIdGroupPairs": [{"GroupName": self.cluster_name}],
     }
     await ec2.authorize_security_group_ingress(
         GroupId=group_id,
         IpPermissions=[open_dask_ports, same_group_traffic],
         DryRun=False,
     )

     await ec2.create_tags(
         Resources=[group_id],
         Tags=dict_to_aws(self.tags, upper=True),
     )
     weakref.finalize(self, self.sync, self._delete_security_groups)
     return [group_id]
コード例 #7
0
ファイル: test_helper.py プロジェクト: RPrudden/dask-cloud
def test_aws_to_dict_and_back():
    """Round-trip a tag mapping through ``dict_to_aws`` and ``aws_to_dict``.

    ``dict_to_aws`` emits lowercase ``key``/``value`` entries by default and
    capitalised ``Key``/``Value`` entries with ``upper=True``; both forms must
    convert back to the same plain dict.
    """
    from dask_cloud.providers.aws.helper import aws_to_dict, dict_to_aws

    plain = {"hello": "world"}
    lower_pairs = [{"key": "hello", "value": "world"}]
    upper_pairs = [{"Key": "hello", "Value": "world"}]

    # Plain dict -> AWS tag list (both casings), and back again.
    assert dict_to_aws(plain) == lower_pairs
    assert dict_to_aws(plain, upper=True) == upper_pairs
    assert aws_to_dict(lower_pairs) == plain

    # Full round trips that start from the plain dict ...
    assert aws_to_dict(dict_to_aws(plain, upper=True)) == plain
    assert aws_to_dict(dict_to_aws(plain)) == plain
    # ... and full round trips that start from each AWS representation.
    assert dict_to_aws(aws_to_dict(lower_pairs)) == lower_pairs
    assert dict_to_aws(aws_to_dict(upper_pairs),
                       upper=True) == upper_pairs