def construct_resource(self, name: str, input_type: st.Type,
                       output_type: st.Type) -> st.Resource:
    """
    Return the MachineResource registered under ``name``, building it once.

    The first call for a given ``name`` creates a subclass of
    PulumiResourceMachine (reusing the base class's ``__name__``) whose
    ``UP`` state uses ``input_type``/``output_type``, instantiates it with
    ``(name, self)``, and stores it in ``self._resource_cache``. Later calls
    with the same ``name`` return the cached instance; their type arguments
    are not consulted.
    """
    cache = self._resource_cache
    if name not in cache:
        up_state = st.State("UP", input_type, output_type)
        # A dedicated subclass per resource so each one carries its own
        # UP state types.
        machine_cls = type(PulumiResourceMachine.__name__,
                           (PulumiResourceMachine,), {"UP": up_state})
        cache[name] = machine_cls(name, self)
    return cache[name]
class InstanceMachine(st.SimpleMachine, AWSMachine):
    """
    Machine for an EC2 Instance
    """

    service: str = "ec2"

    UP = st.State("UP", InstanceConfigType, InstanceType)

    @staticmethod
    async def convert_instance(instance: "Instance") -> Dict[str, Any]:
        """
        Convert a loaded EC2 ``Instance`` resource into an InstanceType dict.

        Every resource attribute (plus the ``disableApiTermination``
        attribute lookup) is awaited concurrently, each exactly once.
        """
        out = {"id": instance.id}
        (
            out["ami"],
            out["ebs_optimized"],
            out["instance_type"],
            placement,
            out["private_ip"],
            out["private_dns"],
            out["public_ip"],
            out["public_dns"],
            out["key_name"],
            security_groups,
            out["subnet_id"],
            tags,
            cpu_options,
            termination_protection,
            state_info,
        ) = await asyncio.gather(
            instance.image_id,
            instance.ebs_optimized,
            instance.instance_type,
            instance.placement,
            instance.private_ip_address,
            instance.private_dns_name,
            instance.public_ip_address,
            instance.public_dns_name,
            instance.key_name,
            instance.security_groups,
            instance.subnet_id,
            instance.tags,
            instance.cpu_options,
            instance.describe_attribute(Attribute="disableApiTermination"),
            instance.state,
        )
        out["vpc_security_group_ids"] = [
            group["GroupId"] for group in security_groups
        ]
        out["tags"] = {tag["Key"]: tag["Value"] for tag in tags or []}
        out["placement_group"] = placement["GroupName"]
        out["availability_zone"] = placement["AvailabilityZone"]
        out["tenancy"] = placement["Tenancy"]
        out["host_id"] = placement.get("HostId")
        out["disable_api_termination"] = termination_protection[
            "DisableApiTermination"]["Value"]
        out["state"] = state_info["Name"]
        if cpu_options:
            out["cpu_core_count"] = cpu_options.get("CoreCount")
            out["cpu_threads_per_core"] = cpu_options.get("ThreadsPerCore")
        else:
            out["cpu_core_count"] = None
            out["cpu_threads_per_core"] = None
        return out

    async def refresh_state(self, data: Any) -> Optional[Any]:
        """
        Reload the instance identified by ``data["id"]``.

        Returns None when the instance cannot be loaded (ClientError) or has
        already been terminated.
        """
        async with self.resource_ctx() as ec2:
            instance = await ec2.Instance(data["id"])
            try:
                await instance.load()
            except botocore.exceptions.ClientError:
                return None
            out = await self.convert_instance(instance)
            # Terminated instances can still be described for a while after
            # they are gone; report them as missing rather than as live.
            if out["state"] == "terminated":
                return None
            return out

    def get_diff(
        self,
        current: st.StateSnapshot,
        config: st.StateConfig,
        session: st.TaskSession,
    ) -> st.Diff:
        """
        Diff the current instance against the config, ignoring fields the
        config leaves unset (None) and treating None/empty tags and
        placement group as equal.
        """
        differ = session.ns.registry.get_differ(config.state.input_type)
        diffconfig = differ.config()

        def ignore_if_new_none(old, new):
            if new is None:
                return True
            # Pass it along to the element-level logic
            return NotImplemented

        diffconfig.set_comparison("subnet_id", ignore_if_new_none)
        diffconfig.set_comparison("vpc_security_group_ids", ignore_if_new_none)
        diffconfig.set_comparison("availability_zone", ignore_if_new_none)
        diffconfig.set_comparison("cpu_core_count", ignore_if_new_none)
        diffconfig.set_comparison("cpu_threads_per_core", ignore_if_new_none)

        def none_as_empty_dict(old, new):
            new = {} if new is None else new
            old = {} if old is None else old
            if new == old:
                return True
            return NotImplemented

        diffconfig.set_comparison("tags", none_as_empty_dict)

        def none_as_empty_string(old, new):
            new = "" if new is None else new
            old = "" if old is None else old
            if new == old:
                return True
            return NotImplemented

        diffconfig.set_comparison("placement_group", none_as_empty_string)

        current_as_config = st.filter_struct(current.obj, config.type)
        out_diff = differ.diff(current_as_config, config.obj, session,
                               diffconfig)
        return out_diff

    async def get_expected(self, current: st.StateSnapshot,
                           config: st.StateConfig) -> Any:
        """
        Build the expected output: config values, with unset fields taken
        from the current output (or Unknown when there is no current state).
        """
        output = st.Unknown[config.state.output_type]
        if not current.state.null:
            output = current.obj

        replaced = st.struct_replace(
            config.obj,
            False,
            subnet_id=st.ifnull(config.obj.subnet_id, output.subnet_id),
            vpc_security_group_ids=st.ifnull(config.obj.vpc_security_group_ids,
                                             output.vpc_security_group_ids),
            tags=st.ifnull(config.obj.tags, {}),
            availability_zone=st.ifnull(config.obj.availability_zone,
                                        output.availability_zone),
            placement_group=st.ifnull(config.obj.placement_group, ""),
        )
        return st.fill(replaced, config.state.output_type, output)

    async def create_task(self, config: InstanceConfigType) -> InstanceType:
        """
        Create a new EC2 Instance
        """
        async with self.resource_ctx() as ec2:
            kws = {
                "ImageId": config["ami"],
                "InstanceType": config["instance_type"],
                "KeyName": config["key_name"],
                "MinCount": 1,
                "MaxCount": 1,
                "DisableApiTermination": config["disable_api_termination"],
                "EbsOptimized": config["ebs_optimized"],
            }
            if config["vpc_security_group_ids"] is not None:
                kws["SecurityGroupIds"] = config["vpc_security_group_ids"]
            if config["subnet_id"] is not None:
                kws["SubnetId"] = config["subnet_id"]

            tags = config["tags"] or {}
            tags_list = [{
                "Key": key,
                "Value": value
            } for key, value in tags.items()]
            specs = []
            if tags_list:
                specs.append({"ResourceType": "instance", "Tags": tags_list})
                kws["TagSpecifications"] = specs

            placement = kws["Placement"] = {"Tenancy": config["tenancy"]}
            if config["availability_zone"] is not None:
                placement["AvailabilityZone"] = config["availability_zone"]
            if config["placement_group"] is not None:
                placement["GroupName"] = config["placement_group"]
            if config["host_id"] is not None:
                placement["HostId"] = config["host_id"]

            if config["cpu_core_count"] is not None:
                opts = kws["CpuOptions"] = {
                    "CoreCount": config["cpu_core_count"]
                }
                if config["cpu_threads_per_core"] is not None:
                    opts["ThreadsPerCore"] = config["cpu_threads_per_core"]

            (instance, ) = await ec2.create_instances(**kws)
            # Checkpoint after creation
            yield await self.convert_instance(instance)
            await instance.wait_until_running()
            await instance.load()
            yield await self.convert_instance(instance)

    async def delete_task(self, current: InstanceType) -> st.EmptyType:
        """
        Delete the EC2 Instance
        """
        async with self.resource_ctx() as ec2:
            instance = await ec2.Instance(current["id"])
            await instance.terminate()
            yield {}
            await instance.wait_until_terminated()

    def get_action(self, diff: st.Diff) -> st.ModificationAction:
        """
        Decide between no-op, in-place modification, and replacement.
        """
        if not diff:
            return st.ModificationAction.NONE
        # Changing any of these requires a new instance. ``subnet_id`` is
        # included because the subnet belongs to the primary network
        # interface, which cannot be detached from an instance.
        replace_fields = (
            "ami",
            "instance_type",
            "key_name",
            "subnet_id",
            "availability_zone",
            "placement_group",
            "tenancy",
            "host_id",
            "cpu_core_count",
            "cpu_threads_per_core",
        )
        if any(field in diff for field in replace_fields):
            return st.ModificationAction.DELETE_AND_RECREATE
        return st.ModificationAction.MODIFY

    @staticmethod
    async def _primary_network_interface(ec2, instance):
        """Return the NetworkInterface attached at device index 0."""
        interfaces = await instance.network_interfaces_attribute
        primary = next(
            (ni for ni in interfaces
             if (ni.get("Attachment") or {}).get("DeviceIndex") == 0),
            # Fall back to the first listed interface if no attachment info.
            interfaces[0],
        )
        return await ec2.NetworkInterface(primary["NetworkInterfaceId"])

    async def modify_task(
        self,
        diff: st.Diff,
        current: InstanceType,
        config: InstanceConfigType,
    ) -> InstanceType:
        """
        Modify the EC2 Instance
        """
        async with self.resource_ctx() as ec2:
            instance = await ec2.Instance(current["id"])
            await instance.load()

            # get_diff ignores a None new value for this field, so being in
            # the diff means the config specifies a (possibly empty) list.
            # subnet_id changes never reach here: get_action replaces the
            # instance instead.
            if "vpc_security_group_ids" in diff:
                current_ni = await self._primary_network_interface(
                    ec2, instance)
                group_ids = config["vpc_security_group_ids"]
                if not group_ids:
                    # NOTE(review): AWS does not allow detaching an instance's
                    # primary network interface, so this is expected to fail;
                    # confirm the intended handling of an empty group list.
                    await current_ni.detach()
                else:
                    await current_ni.modify_attribute(Groups=group_ids)
                await instance.load()
                yield await self.convert_instance(instance)

            if "tags" in diff:
                new_tags = config["tags"] or {}
                remove_tags = [
                    key for key in current["tags"] if key not in new_tags
                ]
                if remove_tags:
                    await instance.delete_tags(Tags=[{
                        "Key": key
                    } for key in remove_tags])
                set_tags = [{
                    "Key": key,
                    "Value": val
                } for key, val in new_tags.items()]
                if set_tags:
                    await instance.create_tags(Tags=set_tags)
                await instance.load()
                yield await self.convert_instance(instance)

            if "disable_api_termination" in diff:
                await instance.modify_attribute(
                    Attribute="disableApiTermination",
                    Value=str(config["disable_api_termination"]).lower(),
                )
                await instance.load()
                yield await self.convert_instance(instance)

            yield await self.convert_instance(instance)
class SubnetMachine(st.SimpleMachine):
    """
    Machine representing an AWS subnet
    """

    UP = st.State("UP", SubnetConfigType, SubnetType)

    @contextlib.asynccontextmanager
    async def resource_ctx(self):
        """Yield an aioboto3 EC2 service resource."""
        async with aioboto3.resource("ec2") as ec2:
            yield ec2

    @contextlib.asynccontextmanager
    async def client_ctx(self):
        """Yield an aioboto3 EC2 client."""
        async with aioboto3.client("ec2") as client:
            yield client

    @staticmethod
    async def convert_instance(subnet: "Subnet") -> Dict[str, Any]:
        """
        Convert a loaded Subnet resource into a SubnetType dict.

        Only the first IPv6 CIDR association (if any) is reported.
        """
        out = {"id": subnet.id}
        (
            out["owner_id"],
            out["cidr_block"],
            ipv6_associations,
            out["map_public_ip_on_launch"],
            out["assign_ipv6_address_on_creation"],
            out["vpc_id"],
        ) = await asyncio.gather(
            subnet.owner_id,
            subnet.cidr_block,
            subnet.ipv6_cidr_block_association_set,
            subnet.map_public_ip_on_launch,
            subnet.assign_ipv6_address_on_creation,
            subnet.vpc_id,
        )
        # The association set may be None/empty when no IPv6 block exists.
        if ipv6_associations:
            association = ipv6_associations[0]
            out["ipv6_association_id"] = association["AssociationId"]
            out["ipv6_cidr_block"] = association["Ipv6CidrBlock"]
        else:
            out["ipv6_association_id"] = None
            out["ipv6_cidr_block"] = None
        return out

    async def refresh_state(self, data: Any) -> Optional[Any]:
        """
        Reload the subnet identified by ``data["id"]``; None if it cannot be
        loaded.
        """
        async with self.resource_ctx() as ec2:
            subnet = await ec2.Subnet(data["id"])
            try:
                await subnet.load()
            except botocore.exceptions.ClientError:
                return None
            return await self.convert_instance(subnet)

    async def create_task(self, config: SubnetConfigType) -> SubnetType:
        """
        Create a new subnet
        """
        async with self.resource_ctx() as ec2, self.client_ctx() as client:
            kws = {
                "CidrBlock": config["cidr_block"],
                "VpcId": config["vpc_id"]
            }
            if config["ipv6_cidr_block"] is not None:
                kws["Ipv6CidrBlock"] = config["ipv6_cidr_block"]

            subnet = await ec2.create_subnet(**kws)
            yield await self.convert_instance(subnet)

            # These attributes cannot be passed to create_subnet, so they are
            # set afterwards only when they differ from what AWS assigned.
            map_public_ip_on_launch = await subnet.map_public_ip_on_launch
            if map_public_ip_on_launch != config["map_public_ip_on_launch"]:
                await client.modify_subnet_attribute(
                    MapPublicIpOnLaunch={
                        "Value": config["map_public_ip_on_launch"]
                    },
                    SubnetId=subnet.id,
                )
                await subnet.load()
                yield await self.convert_instance(subnet)

            assign_ipv6_address_on_creation = (
                await subnet.assign_ipv6_address_on_creation)
            if (assign_ipv6_address_on_creation !=
                    config["assign_ipv6_address_on_creation"]):
                await client.modify_subnet_attribute(
                    AssignIpv6AddressOnCreation={
                        "Value": config["assign_ipv6_address_on_creation"]
                    },
                    SubnetId=subnet.id,
                )
                await subnet.load()
                yield await self.convert_instance(subnet)

    async def delete_task(self, current: SubnetType) -> st.EmptyType:
        """
        Delete the subnet
        """
        async with self.resource_ctx() as ec2:
            subnet = await ec2.Subnet(current["id"])
            await subnet.delete()
        # Return the empty state explicitly, as the sibling machines do.
        return {}
class FileMachine(Machine):
    """
    Simple file state machine
    """

    UP = st.State("UP", FileConfigType, FileType)
    DOWN = st.NullState("DOWN")

    async def refresh(self, current: StateSnapshot) -> StateSnapshot:
        """
        Re-read the tracked file; switch to the null state if it is no
        longer a regular file.
        """
        if current.state.state == self.null_state.state:
            return current
        if os.path.isfile(current.data["location"]):
            info = self.get_file_info(current.data["location"])
            return StateSnapshot(info, current.state)
        return StateSnapshot({}, self.null_state)

    async def finalize(self, current: StateSnapshot) -> StateSnapshot:
        """Return the snapshot with its ``data`` (file contents) blanked."""
        return StateSnapshot(dict(current.data, data=""), current.state)

    @staticmethod
    def get_file_info(path: str) -> FileType:
        """
        Read the file at ``path`` and return its resolved location, text
        contents, and selected ``os.stat`` fields.
        """
        location = os.path.realpath(path)
        with open(location) as handle:
            contents = handle.read()
        stats = os.stat(path)
        stat_fields = ("mode", "ino", "dev", "nlink", "uid", "gid", "size",
                       "atime", "mtime", "ctime")
        return {
            "location": location,
            "data": contents,
            "stat": {
                field: getattr(stats, "st_" + field)
                for field in stat_fields
            },
        }

    @staticmethod
    def get_file_expected(config: st.StateConfig) -> Dict[str, Any]:
        """
        Expected output for ``config``: its location passed through
        ``realpath``, with remaining FileType fields left unknown.
        """
        resolved = st.struct_replace(config.obj,
                                     location=realpath(config.obj.location))
        return st.fill_unknowns(resolved, FileType)

    @task.new
    async def remove_file(self, path: str) -> types.EmptyType:
        """
        Delete the file
        """
        os.remove(path)
        return {}

    @task.new
    async def set_file(self, data: FileConfigType) -> FileType:
        """
        Set the file's contents
        """
        target = os.path.realpath(data["location"])
        with open(target, "w+") as handle:
            handle.write(data["data"])
        return self.get_file_info(target)

    @task.new
    async def rename_file(self, from_path: str, to_path: str) -> FileType:
        source = os.path.realpath(from_path)
        destination = os.path.realpath(to_path)
        os.rename(source, destination)
        return self.get_file_info(destination)

    @transition("UP", "UP")
    async def modify(self, current: StateSnapshot, config: StateConfig,
                     session: TaskSession) -> Object:
        """
        Update an existing file: rename when only the location changed,
        rewrite when only the data changed, and require a delete/recreate
        (NullRequired) when both changed.
        """
        differ = session.ns.registry.get_differ(current.state.input_type)
        diffconfig = differ.config()
        expected = self.get_file_expected(config)

        # Locations that resolve to the same real path count as equal.
        diffconfig.set_comparison(
            "location",
            lambda left, right: os.path.realpath(left) == os.path.realpath(
                right),
        )

        current_as_config = st.filter_struct(current.obj, config.type)
        diff = differ.diff(current_as_config, config.obj, session, diffconfig)
        if not diff:
            return current.obj

        if "location" in diff:
            if "data" in diff:
                raise st.exc.NullRequired
            renamed = self.rename_file(current.obj.location,
                                       config.obj.location)
            return session["rename_file"] << (renamed >> expected)

        written = self.set_file(config.obj)
        return session["update_file"] << (written >> expected)

    @transition("DOWN", "UP")
    async def create(self, current: StateSnapshot, config: StateConfig,
                     session: TaskSession) -> Object:
        """Write a new file from ``config``."""
        expected = self.get_file_expected(config)
        written = self.set_file(config.obj)
        return session["create_file"] << (written >> expected)

    @transition("UP", "DOWN")
    async def delete(self, current: StateSnapshot, config: StateConfig,
                     session: TaskSession) -> Object:
        """Schedule removal of the current file; the result is empty."""
        session["delete_file"] << self.remove_file(current.obj.location)
        return st.Object({})
class SecurityGroupMachine(st.SimpleMachine):
    """
    Machine for a security group
    """

    UP = st.State("UP", SecurityGroupConfigType, SecurityGroupType)

    @contextlib.asynccontextmanager
    async def resource_ctx(self):
        """Yield an aioboto3 EC2 service resource."""
        async with aioboto3.resource("ec2") as ec2:
            yield ec2

    @contextlib.asynccontextmanager
    async def client_ctx(self):
        """Yield an aioboto3 EC2 client."""
        async with aioboto3.client("ec2") as client:
            yield client

    def get_diff(
        self,
        current: st.StateSnapshot,
        config: st.StateConfig,
        session: st.TaskSession,
    ) -> st.Diff:
        """
        Diff current state against config, comparing ingress/egress rule
        lists without regard to order.
        """
        differ = session.ns.registry.get_differ(config.state.input_type)
        diffconfig = differ.config()

        def compare_unordered(arr1, arr2):
            if arr1 is None or arr2 is None:
                return arr1 == arr2
            return all(el1 in arr2 for el1 in arr1) and all(el2 in arr1
                                                            for el2 in arr2)

        diffconfig.set_comparison("ingress", compare_unordered)
        diffconfig.set_comparison("egress", compare_unordered)

        current_as_config = st.filter_struct(current.obj, config.type)
        out_diff = differ.diff(current_as_config, config.obj, session,
                               diffconfig)
        return out_diff

    def get_action(self, diff: st.Diff) -> st.ModificationAction:
        """Name, VPC, or description changes require a new group."""
        if not diff:
            return st.ModificationAction.NONE
        if "name" in diff or "vpc_id" in diff or "description" in diff:
            return st.ModificationAction.DELETE_AND_RECREATE
        return st.ModificationAction.MODIFY

    async def convert_instance(self,
                               instance: "SecurityGroup") -> Dict[str, Any]:
        """Convert a loaded SecurityGroup resource into a SecurityGroupType."""
        out = {"id": instance.id}
        (
            out["name"],
            out["description"],
            ingress_perms,
            egress_perms,
            out["vpc_id"],
            out["owner_id"],
        ) = await asyncio.gather(
            instance.group_name,
            instance.description,
            instance.ip_permissions,
            instance.ip_permissions_egress,
            instance.vpc_id,
            instance.owner_id,
        )
        out["ingress"] = list(map(self.convert_rule, ingress_perms))
        out["egress"] = list(map(self.convert_rule, egress_perms))
        return out

    @staticmethod
    def convert_rule(rule: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one EC2 IpPermission dict into this machine's rule shape."""
        return {
            "cidr_blocks": [rang["CidrIp"] for rang in rule["IpRanges"]],
            "ipv6_cidr_blocks":
            [rang["CidrIpv6"] for rang in rule["Ipv6Ranges"]],
            "from_port": rule.get("FromPort"),
            "to_port": rule.get("ToPort"),
            "protocol": rule["IpProtocol"],
        }

    async def refresh_state(self, data: Any) -> Optional[Any]:
        """Reload the group ``data["id"]``; None if it cannot be loaded."""
        async with self.resource_ctx() as ec2:
            sg = await ec2.SecurityGroup(data["id"])
            try:
                await sg.load()
            except botocore.exceptions.ClientError:
                return None
            return await self.convert_instance(sg)

    async def refresh_config(self, config: st.Object) -> st.Object:
        """Default ``vpc_id`` to the account's default VPC when unset."""
        async with self.client_ctx() as client:
            default_vpc = await utils.get_default_vpc(client)
            return st.struct_replace(config,
                                     vpc_id=st.ifnull(config.vpc_id,
                                                      default_vpc["VpcId"]))

    async def create_task(
            self, config: SecurityGroupConfigType) -> SecurityGroupType:
        """
        Create a security group resource
        """
        async with self.resource_ctx() as ec2:
            kws = {
                "Description": config["description"],
                "GroupName": config["name"]
            }
            if config["vpc_id"] is not None:
                kws["VpcId"] = config["vpc_id"]
            group = await ec2.create_security_group(**kws)
            current = await self.convert_instance(group)
            yield current
            yield await self.update_rules(current, config)

    async def modify_task(
            self, diff: st.Diff, current: SecurityGroupType,
            config: SecurityGroupConfigType) -> SecurityGroupType:
        """
        Modify the security group
        """
        async with self.resource_ctx() as ec2:
            sg = await ec2.SecurityGroup(current["id"])
            if "ingress" in diff or "egress" in diff:
                await self.update_rules(current, config)
            await sg.load()
            yield await self.convert_instance(sg)

    @staticmethod
    def _normalize_rule(rule: Dict[str, Any]) -> Dict[str, Any]:
        """Copy ``rule`` with None CIDR lists replaced by empty lists."""
        out = dict(rule)
        out["cidr_blocks"] = rule["cidr_blocks"] or []
        out["ipv6_cidr_blocks"] = rule["ipv6_cidr_blocks"] or []
        return out

    async def update_rules(
            self, current: SecurityGroupType,
            config: SecurityGroupConfigType) -> SecurityGroupType:
        """
        Update security group rules
        """
        ingress_to_add = []
        ingress_to_remove = []
        egress_to_add = []
        egress_to_remove = []

        for name, to_add, to_remove in [
            ("ingress", ingress_to_add, ingress_to_remove),
            ("egress", egress_to_add, egress_to_remove),
        ]:
            # None and [] CIDR lists are sent to AWS identically (see
            # rule_params), so compare rules in normalized form.
            desired = [self._normalize_rule(rule) for rule in config[name]]
            existing = [self._normalize_rule(rule) for rule in current[name]]
            to_add.extend(rule for rule in desired if rule not in existing)
            to_remove.extend(rule for rule in existing if rule not in desired)

        def rule_params(rule):
            kws = {
                "IpRanges": [{
                    "CidrIp": block
                } for block in (rule["cidr_blocks"] or [])],
                "Ipv6Ranges": [{
                    "CidrIpv6": block
                } for block in (rule["ipv6_cidr_blocks"] or [])],
                "IpProtocol": rule["protocol"],
            }
            if rule["from_port"] is not None:
                kws["FromPort"] = rule["from_port"]
            if rule["to_port"] is not None:
                kws["ToPort"] = rule["to_port"]
            return {"IpPermissions": [kws]}

        async with self.resource_ctx() as ec2:
            instance = await ec2.SecurityGroup(current["id"])
            # Revoke before authorizing: an old and a new rule may share
            # individual permissions (e.g. the same port with a shrunken CIDR
            # list). Running both concurrently could fail with a duplicate
            # permission or revoke a permission that was just authorized.
            await asyncio.gather(
                *(instance.revoke_ingress(**rule_params(rule))
                  for rule in ingress_to_remove),
                *(instance.revoke_egress(**rule_params(rule))
                  for rule in egress_to_remove),
            )
            await asyncio.gather(
                *(instance.authorize_ingress(**rule_params(rule))
                  for rule in ingress_to_add),
                *(instance.authorize_egress(**rule_params(rule))
                  for rule in egress_to_add),
            )

        out = current.copy()
        out["ingress"] = config["ingress"]
        out["egress"] = config["egress"]
        return out

    async def delete_task(self, current: SecurityGroupType) -> st.EmptyType:
        """
        Delete a security group
        """
        async with self.resource_ctx() as ec2:
            group = await ec2.SecurityGroup(current["id"])
            await group.delete()
            return {}
class VpcMachine(st.SimpleMachine):
    """
    AWS VPC resource
    """

    UP = st.State("UP", VpcConfigType, VpcType)

    @contextlib.asynccontextmanager
    async def resource_ctx(self):
        """Yield an aioboto3 EC2 service resource."""
        async with aioboto3.resource("ec2") as resource:
            yield resource

    @contextlib.asynccontextmanager
    async def client_ctx(self):
        """Yield an aioboto3 EC2 client."""
        async with aioboto3.client("ec2") as ec2_client:
            yield ec2_client

    async def convert_instance(self, data: Dict[str, Any],
                               vpc: "Vpc") -> Dict[str, Any]:
        """
        Build a VpcType dict from a loaded Vpc resource.

        ``assign_generated_ipv6_cidr_block`` is copied from ``data`` rather
        than read from AWS. Only the first IPv6 association is reported, and
        the main route table / default network ACL ids are looked up with
        filtered describe calls (None when none is found).
        """
        result = {
            "id": vpc.id,
            "assign_generated_ipv6_cidr_block":
            data["assign_generated_ipv6_cidr_block"],
        }
        fetched = await asyncio.gather(
            vpc.cidr_block,
            vpc.instance_tenancy,
            vpc.ipv6_cidr_block_association_set,
            vpc.owner_id,
        )
        result["cidr_block"], result["instance_tenancy"] = fetched[0], fetched[1]
        associations = fetched[2]
        result["owner_id"] = fetched[3]

        if not associations:
            result["ipv6_association_id"] = result["ipv6_cidr_block"] = None
        else:
            first = associations[0]
            result["ipv6_association_id"] = first["AssociationId"]
            result["ipv6_cidr_block"] = first["Ipv6CidrBlock"]

        def vpc_filters(flag_name):
            # Restrict to this VPC and to entries where ``flag_name`` is true.
            return [
                {
                    "Name": "vpc-id",
                    "Values": [vpc.id]
                },
                {
                    "Name": flag_name,
                    "Values": ["true"]
                },
            ]

        async with self.client_ctx() as ec2_client:
            route_resp = await ec2_client.describe_route_tables(
                Filters=vpc_filters("association.main"))
            tables = route_resp["RouteTables"]
            result["main_route_table_id"] = (tables[0]["RouteTableId"]
                                             if tables else None)

            acl_resp = await ec2_client.describe_network_acls(
                Filters=vpc_filters("default"))
            acls = acl_resp["NetworkAcls"]
            result["default_network_acl_id"] = (acls[0]["NetworkAclId"]
                                                if acls else None)

            return result

    async def refresh_state(self, data: Any) -> Optional[Any]:
        """Reload the VPC ``data["id"]``; None if it cannot be loaded."""
        async with self.resource_ctx() as resource:
            vpc = await resource.Vpc(data["id"])
            try:
                await vpc.load()
            except botocore.exceptions.ClientError:
                return None
            return await self.convert_instance(data, vpc)

    async def create_task(self, config: VpcConfigType) -> VpcType:
        """
        Create a new VPC
        """
        async with self.resource_ctx() as resource:
            vpc = await resource.create_vpc(
                CidrBlock=config["cidr_block"],
                InstanceTenancy=config["instance_tenancy"],
                AmazonProvidedIpv6CidrBlock=config[
                    "assign_generated_ipv6_cidr_block"],
            )
            # Checkpoint immediately, then again once AWS reports it available.
            yield await self.convert_instance(config, vpc)
            await vpc.wait_until_available()
            await vpc.load()
            yield await self.convert_instance(config, vpc)

    async def delete_task(self, current: VpcType) -> st.EmptyType:
        """
        Delete the VPC
        """
        async with self.resource_ctx() as resource:
            vpc = await resource.Vpc(current["id"])
            await vpc.delete()
            return {}