Exemplo n.º 1
0
 async def _create_execution_role(self):
     """Create the IAM role that ECS uses when launching this cluster's tasks.

     The role is named ``self._execution_role_name``, is trusted by the
     ``ecs-tasks.amazonaws.com`` service principal and is tagged with
     ``self.tags``. Three AWS managed policies are attached to it: ECR
     read-only access, CloudWatch Logs full access and
     ``service-role/AmazonEC2ContainerServiceRole``.

     A finalizer that deletes the role (via ``self._delete_role``) is
     registered as soon as the role exists. The role therefore still gets
     cleaned up if attaching one of the policies fails afterwards.

     Returns:
         str: The ARN of the newly created role.
     """
     # NOTE(review): AWS documents AmazonECSTaskExecutionRolePolicy as the
     # managed policy intended for task execution roles; confirm that the
     # ECS *service* role policy below is really what is wanted here.
     policy_arns = (
         "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryReadOnly",
         "arn:aws:iam::aws:policy/CloudWatchLogsFullAccess",
         "arn:aws:iam::aws:policy/service-role/AmazonEC2ContainerServiceRole",
     )
     async with self._client("iam") as iam:
         response = await iam.create_role(
             RoleName=self._execution_role_name,
             AssumeRolePolicyDocument="""{
                 "Version": "2012-10-17",
                 "Statement": [
                     {
                     "Effect": "Allow",
                     "Principal": {
                         "Service": "ecs-tasks.amazonaws.com"
                     },
                     "Action": "sts:AssumeRole"
                     }
                 ]
                 }""",
             Description="A role for ECS to use when executing",
             Tags=dict_to_aws(self.tags, upper=True),
         )
         # The role now exists in AWS: register its cleanup before doing
         # anything else that can fail, otherwise an exception from one of
         # the attach_role_policy calls below would orphan the role.
         # NOTE(review): assumes _delete_role copes with a role that has only
         # some (or none) of the policies attached - confirm.
         # NOTE(review): self.sync is a bound method, so the finalizer holds a
         # strong reference to self. Per the weakref.finalize docs, it then
         # cannot fire on garbage collection and only runs at interpreter
         # exit.
         weakref.finalize(self, self.sync, self._delete_role,
                          self._execution_role_name)
         for policy_arn in policy_arns:
             await iam.attach_role_policy(
                 RoleName=self._execution_role_name,
                 PolicyArn=policy_arn,
             )
         return response["Role"]["Arn"]
Exemplo n.º 2
0
    async def start(self):
        """Launch this task on ECS, wait for it to run and record its addresses.

        ``run_task`` is retried once per second for as long as
        ``timeout.run()`` (a 60 second ``Timeout``) keeps the loop going.
        Every exception raised while launching is handed to
        ``timeout.set_exception``. Tags are only sent when
        ``_is_long_arn_format_enabled()`` reports that the account has opted
        into the long ARN format.

        Once launched, the task is polled (via ``_update_task``) until it
        leaves the pre-running states. The task's elastic network interface
        is then described to fill in ``self.private_ip``, and
        ``self.public_ip`` as well when ``self._use_public_ip`` is set.
        Finally ``_set_address_from_logs()`` is awaited and ``self.status``
        is set to ``"running"``.

        Raises:
            RuntimeError: If ``_task_is_running()`` reports that the task is
                not running after it left the pre-running states.
        """
        # ECS task lifecycle states that precede RUNNING. ACTIVATING is
        # included so a task that is still finishing activation is not
        # mistaken for one that failed to start.
        pre_running_states = ("PROVISIONING", "PENDING", "ACTIVATING")

        timeout = Timeout(60, "Unable to start %s after 60 seconds" % self.task_type)
        while timeout.run():
            try:
                # Tags are only supported if you opt into long arn format so
                # we need to check for that.
                kwargs = (
                    {"tags": dict_to_aws(self.tags)}
                    if await self._is_long_arn_format_enabled()
                    else {}
                )
                response = await self._clients["ecs"].run_task(
                    cluster=self.cluster_arn,
                    taskDefinition=self.task_definition_arn,
                    overrides={
                        "containerOverrides": [
                            {
                                "name": "dask-{}".format(self.task_type),
                                "environment": dict_to_aws(
                                    self.environment, key_string="name"
                                ),
                                **self._overrides,
                            }
                        ]
                    },
                    count=1,
                    launchType="FARGATE" if self.fargate else "EC2",
                    networkConfiguration={
                        "awsvpcConfiguration": {
                            "subnets": self._vpc_subnets,
                            "securityGroups": self._security_groups,
                            "assignPublicIp": "ENABLED"
                            if self._use_public_ip
                            else "DISABLED",
                        }
                    },
                    **kwargs
                )

                if not response.get("tasks"):
                    raise RuntimeError(response)  # print entire response

                [self.task] = response["tasks"]
                break
            except Exception as e:
                # Remember the failure and retry after a short pause.
                timeout.set_exception(e)
                await asyncio.sleep(1)

        # NOTE(review): if timeout.run() can return False without raising,
        # self.task is never assigned and the next line fails with
        # AttributeError - confirm that Timeout raises on expiry.
        self.task_arn = self.task["taskArn"]

        # NOTE(review): this wait has no upper bound of its own.
        while self.task["lastStatus"] in pre_running_states:
            await asyncio.sleep(1)
            await self._update_task()
        if not await self._task_is_running():
            raise RuntimeError("%s failed to start" % type(self).__name__)

        # Exactly one ENI attachment and one networkInterfaceId detail are
        # expected; the single-element unpacking raises ValueError otherwise.
        [eni_attachment] = [
            attachment
            for attachment in self.task["attachments"]
            if attachment["type"] == "ElasticNetworkInterface"
        ]
        [network_interface_id] = [
            detail["value"]
            for detail in eni_attachment["details"]
            if detail["name"] == "networkInterfaceId"
        ]
        described = await self._clients["ec2"].describe_network_interfaces(
            NetworkInterfaceIds=[network_interface_id]
        )
        [interface] = described["NetworkInterfaces"]
        if self._use_public_ip:
            self.public_ip = interface["Association"]["PublicIp"]
        self.private_ip = interface["PrivateIpAddresses"][0]["PrivateIpAddress"]
        await self._set_address_from_logs()
        self.status = "running"
Exemplo n.º 3
0
 async def _create_worker_task_definition_arn(self):
     """Register the ECS task definition used to launch Dask workers.

     The definition belongs to the ``<cluster_name>-worker`` family and uses
     the ``awsvpc`` network mode with the cluster's task and execution
     roles. It holds a single essential ``dask-worker`` container.

     That container runs ``dask-cuda-worker`` when ``self._worker_gpu`` is
     truthy (also requesting that many GPUs), otherwise ``dask-worker``. It
     is started with:

     * one thread per 1024 CPU units (minimum one);
     * a memory limit of ``int(self._worker_mem)`` MB;
     * a 60 second death timeout;
     * any ``self._worker_extra_args`` appended.

     Container logs are shipped with the ``awslogs`` driver.

     A finalizer calling ``self._delete_worker_task_definition_arn`` is
     registered once the definition exists.

     Returns:
         str: ARN of the registered task definition.
     """
     gpu_count = self._worker_gpu
     if gpu_count:
         resource_requirements = [{"type": "GPU", "value": str(gpu_count)}]
     else:
         resource_requirements = []

     async with self._client("ecs") as ecs:
         thread_count = max(int(self._worker_cpu / 1024), 1)
         base_command = [
             "dask-cuda-worker" if gpu_count else "dask-worker",
             "--nthreads",
             "{}".format(thread_count),
             "--memory-limit",
             "{}MB".format(int(self._worker_mem)),
             "--death-timeout",
             "60",
         ]
         # Plain concatenation (not +=) on purpose: extra args are appended
         # to a fresh list exactly as before.
         worker_command = base_command + (self._worker_extra_args or [])

         log_configuration = {
             "logDriver": "awslogs",
             "options": {
                 "awslogs-region": ecs.meta.region_name,
                 "awslogs-group": self.cloudwatch_logs_group,
                 "awslogs-stream-prefix": self._cloudwatch_logs_stream_prefix,
                 "awslogs-create-group": "true",
             },
         }
         worker_container = {
             "name": "dask-worker",
             "image": self.image,
             "cpu": self._worker_cpu,
             "memory": self._worker_mem,
             "memoryReservation": self._worker_mem,
             "resourceRequirements": resource_requirements,
             "essential": True,
             "command": worker_command,
             "logConfiguration": log_configuration,
             "mountPoints": self._mount_points or [],
         }
         compatibilities = ["FARGATE"] if self._fargate_workers else []

         response = await ecs.register_task_definition(
             family="{}-{}".format(self.cluster_name, "worker"),
             taskRoleArn=self._task_role_arn,
             executionRoleArn=self._execution_role_arn,
             networkMode="awsvpc",
             containerDefinitions=[worker_container],
             volumes=self._volumes or [],
             requiresCompatibilities=compatibilities,
             cpu=str(self._worker_cpu),
             memory=str(self._worker_mem),
             tags=dict_to_aws(self.tags),
         )
     weakref.finalize(self, self.sync,
                      self._delete_worker_task_definition_arn)
     return response["taskDefinition"]["taskDefinitionArn"]