async def _create_execution_role(self):
    """Create the IAM role ECS uses when launching this cluster's tasks.

    The role's trust policy allows ``ecs-tasks.amazonaws.com`` to assume it.
    It is tagged with ``self.tags`` (keys/values upper-cased via
    ``dict_to_aws(..., upper=True)``), and three AWS-managed policies are then
    attached: ECR read-only, CloudWatch Logs full access, and the
    ``service-role/AmazonEC2ContainerServiceRole`` policy.

    A ``weakref.finalize`` hook calling ``self.sync(self._delete_role, name)``
    is registered so the role is removed later.

    Returns:
        The ARN of the newly created role (``response["Role"]["Arn"]``).
    """
    # NOTE(review): AWS documents ``AmazonECSTaskExecutionRolePolicy`` as the
    # managed policy intended for task execution roles; confirm whether the
    # ECS *service* role policy below is really needed here.
    policy_arns = (
        "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryReadOnly",
        "arn:aws:iam::aws:policy/CloudWatchLogsFullAccess",
        "arn:aws:iam::aws:policy/service-role/AmazonEC2ContainerServiceRole",
    )
    async with self._client("iam") as iam:
        response = await iam.create_role(
            RoleName=self._execution_role_name,
            AssumeRolePolicyDocument="""{ "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", "Principal": { "Service": "ecs-tasks.amazonaws.com" }, "Action": "sts:AssumeRole" } ] }""",
            Description="A role for ECS to use when executing",
            Tags=dict_to_aws(self.tags, upper=True),
        )
        # Register cleanup as soon as the role exists. Registering it only
        # after the policy attachments meant a failing attach_role_policy call
        # left the role behind with nothing scheduled to delete it.
        # NOTE(review): the cleanup must cope with a role that has only some
        # of its policies attached; presumably _delete_role lists and detaches
        # whatever is attached — verify.
        # NOTE(review): ``self.sync`` is a bound method, so the finalizer holds
        # a strong reference to ``self``; per the weakref docs the object is
        # then never garbage collected and the hook effectively runs at exit.
        weakref.finalize(
            self, self.sync, self._delete_role, self._execution_role_name
        )
        for policy_arn in policy_arns:
            await iam.attach_role_policy(
                RoleName=self._execution_role_name,
                PolicyArn=policy_arn,
            )
        return response["Role"]["Arn"]
async def start(self):
    """Launch this task on ECS and wait until it is running with a known address.

    Steps:
      1. Call ECS ``run_task`` (one task, FARGATE or EC2 launch type, awsvpc
         networking). Failures are recorded via ``timeout.set_exception`` and
         retried once per second while ``Timeout(60, ...)`` allows. Presumably
         ``Timeout.run()`` raises the recorded error once 60 s elapse — confirm
         in the Timeout helper.
      2. Poll ``_update_task`` once per second while ``lastStatus`` is
         PENDING or PROVISIONING.
      3. Resolve the task's ElasticNetworkInterface attachment to an EC2
         network interface and record ``private_ip`` (and ``public_ip`` when
         ``self._use_public_ip`` is set).
      4. Read the scheduler/worker address from the logs, then set
         ``self.status = "running"``.

    Raises:
        RuntimeError: if ``_task_is_running()`` is false once the task has
            left the PENDING/PROVISIONING states.
        ValueError: from the single-element unpacking below if the task does
            not have exactly one ENI attachment, one ``networkInterfaceId``
            detail, or one described network interface.
    """
    timeout = Timeout(60, "Unable to start %s after 60 seconds" % self.task_type)
    while timeout.run():
        try:
            # Tags are only supported if you opt into long arn format so we need to check for that
            kwargs = (
                {"tags": dict_to_aws(self.tags)}
                if await self._is_long_arn_format_enabled()
                else {}
            )
            response = await self._clients["ecs"].run_task(
                cluster=self.cluster_arn,
                taskDefinition=self.task_definition_arn,
                overrides={
                    "containerOverrides": [
                        {
                            "name": "dask-{}".format(self.task_type),
                            "environment": dict_to_aws(
                                self.environment, key_string="name"
                            ),
                            # Caller-supplied container overrides are merged last,
                            # so they win over "name"/"environment" on key clashes.
                            **self._overrides,
                        }
                    ]
                },
                count=1,
                launchType="FARGATE" if self.fargate else "EC2",
                networkConfiguration={
                    "awsvpcConfiguration": {
                        "subnets": self._vpc_subnets,
                        "securityGroups": self._security_groups,
                        "assignPublicIp": "ENABLED"
                        if self._use_public_ip
                        else "DISABLED",
                    }
                },
                **kwargs
            )
            # An empty "tasks" list means ECS refused the launch; raise with the
            # whole response so the failure details are preserved for retry/reporting.
            if not response.get("tasks"):
                raise RuntimeError(response)  # print entire response
            [self.task] = response["tasks"]
            break
        except Exception as e:
            # Remember the latest failure for the Timeout and retry after 1 s.
            timeout.set_exception(e)
            await asyncio.sleep(1)
    self.task_arn = self.task["taskArn"]
    # NOTE(review): this poll has no upper bound, so a task stuck in
    # PENDING/PROVISIONING blocks start() forever. ECS also documents an
    # ACTIVATING transitional state, which is not waited on here — confirm
    # that _task_is_running() does not treat it as a failure.
    while self.task["lastStatus"] in ["PENDING", "PROVISIONING"]:
        await asyncio.sleep(1)
        await self._update_task()
    if not await self._task_is_running():
        raise RuntimeError("%s failed to start" % type(self).__name__)
    # Find the (single) ENI attachment of the awsvpc task ...
    [eni] = [
        attachment
        for attachment in self.task["attachments"]
        if attachment["type"] == "ElasticNetworkInterface"
    ]
    # ... and its EC2 network interface id.
    [network_interface_id] = [
        detail["value"]
        for detail in eni["details"]
        if detail["name"] == "networkInterfaceId"
    ]
    # Note: `eni` is rebound here from the attachment dict to the EC2
    # describe_network_interfaces response.
    eni = await self._clients["ec2"].describe_network_interfaces(
        NetworkInterfaceIds=[network_interface_id]
    )
    [interface] = eni["NetworkInterfaces"]
    if self._use_public_ip:
        self.public_ip = interface["Association"]["PublicIp"]
    self.private_ip = interface["PrivateIpAddresses"][0]["PrivateIpAddress"]
    await self._set_address_from_logs()
    self.status = "running"
async def _create_worker_task_definition_arn(self):
    """Register the ECS task definition used to launch Dask workers.

    The task definition family is ``"<cluster_name>-worker"``. It has a single
    essential ``dask-worker`` container that runs ``dask-cuda-worker`` when
    ``self._worker_gpu`` is set and ``dask-worker`` otherwise. The container's
    CPU and memory come from ``self._worker_cpu`` and ``self._worker_mem``, and
    it logs to CloudWatch through the ``awslogs`` driver. When
    ``self._worker_gpu`` is set, a GPU resource requirement is added.

    A ``weakref.finalize`` hook is registered that calls
    ``self.sync(self._delete_worker_task_definition_arn)``.

    Returns:
        The ``taskDefinitionArn`` of the registered task definition.
    """
    resource_requirements = []
    if self._worker_gpu:
        resource_requirements.append(
            {"type": "GPU", "value": str(self._worker_gpu)}
        )

    # One worker thread per full vCPU (ECS CPU units: 1024 == 1 vCPU), minimum 1.
    nthreads = max(int(self._worker_cpu / 1024), 1)
    command = [
        "dask-cuda-worker" if self._worker_gpu else "dask-worker",
        "--nthreads",
        "{}".format(nthreads),
        "--memory-limit",
        "{}MB".format(int(self._worker_mem)),
        "--death-timeout",
        "60",
        # Unpacking accepts any sequence of extra CLI args (e.g. a tuple);
        # the previous `list + self._worker_extra_args` raised TypeError for
        # anything that was not a list. Expected to be a sequence of strings.
        *(self._worker_extra_args or ()),
    ]

    async with self._client("ecs") as ecs:
        response = await ecs.register_task_definition(
            family="{}-{}".format(self.cluster_name, "worker"),
            taskRoleArn=self._task_role_arn,
            executionRoleArn=self._execution_role_arn,
            networkMode="awsvpc",
            containerDefinitions=[
                {
                    "name": "dask-worker",
                    "image": self.image,
                    "cpu": self._worker_cpu,
                    "memory": self._worker_mem,
                    "memoryReservation": self._worker_mem,
                    "resourceRequirements": resource_requirements,
                    "essential": True,
                    "command": command,
                    "logConfiguration": {
                        "logDriver": "awslogs",
                        "options": {
                            "awslogs-region": ecs.meta.region_name,
                            "awslogs-group": self.cloudwatch_logs_group,
                            "awslogs-stream-prefix": self._cloudwatch_logs_stream_prefix,
                            "awslogs-create-group": "true",
                        },
                    },
                    "mountPoints": self._mount_points or [],
                }
            ],
            volumes=self._volumes or [],
            requiresCompatibilities=["FARGATE"] if self._fargate_workers else [],
            cpu=str(self._worker_cpu),
            memory=str(self._worker_mem),
            tags=dict_to_aws(self.tags),
        )
        # NOTE(review): ``self.sync`` is a bound method, so this finalizer keeps
        # ``self`` alive; the deregistration effectively runs at interpreter exit.
        weakref.finalize(self, self.sync, self._delete_worker_task_definition_arn)
        return response["taskDefinition"]["taskDefinitionArn"]