def process_sqs_records(ctx, event, function=None, function_arg=None):
    if event is None:
        return False
    if "Records" not in event:
        return False
    processed = 0
    for r in event["Records"]:
        if r["eventSource"] == "aws:sqs":
            misc.initialize_clients(["sqs"], ctx)
            if function is None or function(r, function_arg):
                log.debug("Deleting SQS record...")
                processed += 1
                try:
                    sqs_client = ctx["sqs.client"]
                    queue_arn  = r["eventSourceARN"]
                    queue_name = queue_arn.split(':')[-1]
                    account_id = queue_arn.split(':')[-2]
                    response   = sqs_client.get_queue_url(
                        QueueName=queue_name,
                        QueueOwnerAWSAccountId=account_id)
                    response   = sqs_client.delete_message(
                        QueueUrl=response["QueueUrl"],
                        ReceiptHandle=r["receiptHandle"])
                except Exception as e:
                    log.exception("[WARNING] Failed to delete SQS message %s : %s" % (r, e))
    return processed
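# A minimal usage sketch (not part of the code base): how process_sqs_records()
# might be wired into a Lambda entrypoint. 'ack_record' and 'sqs_entrypoint' are
# hypothetical names; 'ctx' and 'log' are assumed to be the module-level context
# dict and logger used throughout these excerpts.
def ack_record(record, arg):
    # Return True to acknowledge the record so that it gets deleted from the queue.
    return record.get("eventSource") == "aws:sqs"

def sqs_entrypoint(event, lambda_context):
    deleted = process_sqs_records(ctx, event, function=ack_record)
    log.debug("Deleted %s SQS record(s)." % deleted)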
def call_sns(self, arn, region, account_id, service_path, content, e):
    misc.initialize_clients(["sns"], self.context)
    client = self.context["sns.client"]
    log.info("Notifying SNS Topic '%s' for event '%s'..." % (arn, e))
    response = client.publish(
        TopicArn=arn,
        Message=content,
        Subject="CloneSquad/%s event notification" % (self.context["GroupName"])
    )
def call_lambda(self, arn, region, account_id, service_path, content, e):
    misc.initialize_clients(["lambda"], self.context)
    client = self.context["lambda.client"]
    log.info("Notifying asynchronously UserLambda '%s' for event '%s'..." % (arn, e))
    response = client.invoke(
        FunctionName=arn,
        InvocationType='Event',
        LogType='Tail',
        Payload=content
    )
def call_sqs(self, arn, region, account_id, service_path, content, e):
    misc.initialize_clients(["sqs"], self.context)
    client = self.context["sqs.client"]
    response = client.get_queue_url(
        QueueName=service_path,
        QueueOwnerAWSAccountId=account_id
    )
    log.info("Notifying SQS Queue '%s' for event '%s'..." % (arn, e))
    args = {
        "QueueUrl": response["QueueUrl"],
        "MessageBody": content
    }
    if service_path.endswith(".fifo"):
        # Create a message group id unique to this CloneSquad deployment so that
        # a single SQS FIFO queue can be shared concurrently by multiple deployments.
        args["MessageGroupId"]         = f"CS-notif-channel-{account_id}-{self.context['GroupName']}"
        args["MessageDeduplicationId"] = misc.sha256(content)
    response = client.send_message(**args)
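# Minimal sketch (an assumption, not the code base's implementation) of what the
# misc.sha256() helper used above for MessageDeduplicationId is expected to do:
# hash the message body so that identical notifications sent within the FIFO
# deduplication window collapse into a single delivery.
import hashlib

def sha256_hex(content):
    # Accept both str and bytes payloads and return the hex digest.
    data = content.encode("utf-8") if isinstance(content, str) else content
    return hashlib.sha256(data).hexdigest()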
def sqs_interact_processing(self, event, dummy):
    """ This function always returns True so that the SQS message is discarded in all cases.
    """
    event = json.loads(event["body"])
    if "OpType" not in event:
        log.warning("Can't understand SQS message! (Missing 'OpType' required member of body dict)")
        return True
    command = event["OpType"].lower()
    command = urllib.parse.unquote(command)
    cmd = self.find_command(command)
    if cmd is None:
        log.warning("Received unknown command '%s' through SQS message!" % command)
        return True
    cmd = self.commands[command]
    if "sqs" not in cmd["interface"]:
        log.warning("Command '%s' not available through SQS queue!" % command)
        return True

    cacheddata = None
    if "prepare" in cmd:
        log.log(log.NOTICE, "Loading cached data...")
        query_cache.load_cached_data()
        if command not in query_cache.interact_precomputed_data["data"]:
            # There is no HTTP response object in the SQS path: log the issue and discard the message.
            log.warning("Missing pre-computed data for command '%s'!" % command)
            return True
        cacheddata = query_cache.interact_precomputed_data["data"][command]

    log.info("Loading prerequisites %s..." % cmd["prerequisites"])
    misc.initialize_clients(cmd["clients"], self.context)
    misc.load_prerequisites(self.context, cmd["prerequisites"])
    cmd["func"](self.context, event, {"headers": {}}, cacheddata)
    log.info("Processed command '%s' through an SQS message." % command)
    return True
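# Illustrative only: minimal shape of an SQS message body accepted by
# sqs_interact_processing(). The 'notify/ackevent' command name is hypothetical
# and stands for any command whose 'interface' list contains "sqs".
example_sqs_message = {
    "body": json.dumps({
        "OpType": "notify/ackevent"
    })
}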
def get_prerequisites(self):
    ctx = self.context
    self.table = kvtable.KVTable.create(
        self.context, self.context["StateTable"],
        cache_max_age=Cfg.get_duration_secs("statemanager.cache.max_age"))
    for a in self.table_aggregates:
        self.table.register_aggregates(a)
    self.table.reread_table()

    # Retrieve all CloneSquad resources
    misc.initialize_clients(["resourcegroupstaggingapi"], self.context)
    tagging_client = self.context["resourcegroupstaggingapi.client"]
    paginator      = tagging_client.get_paginator('get_resources')
    tag_mappings   = itertools.chain.from_iterable(
        page['ResourceTagMappingList']
        for page in paginator.paginate(TagFilters=[{
            'Key': 'clonesquad:group-name',
            'Values': [self.context["GroupName"]]
        }]))
    self.clonesquad_resources = list(tag_mappings)
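# Illustrative only: each entry of self.clonesquad_resources follows the
# resourcegroupstaggingapi 'ResourceTagMappingList' item shape (the ARN, account
# id and group name below are made up).
example_resource_mapping = {
    "ResourceARN": "arn:aws:ssm:eu-west-1:111122223333:maintenancewindow/mw-0123456789abcdef0",
    "Tags": [
        {"Key": "clonesquad:group-name", "Value": "test"}
    ]
}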
def manage_rule_event(self, event):
    if Cfg.get_int("cron.disable"):
        return
    if "source" in event and event["source"] == "aws.events" and event["detail-type"] == "Scheduled Event":
        # Triggered by an AWS CloudWatch Scheduled event. We look for a ParameterSet
        # request based on the ARN.
        misc.initialize_clients(["events"], self.context)
        misc.load_prerequisites(self.context, ["o_scheduler"])
        for r in event["resources"]:
            log.debug("Processing Scheduled event '%s'..." % r)
            m = re.search("^arn:aws:events:[a-z-0-9]+:[0-9]+:rule/CS-Cron-%s-(.*)" % self.context["GroupName"], r)
            if m is not None and len(m.groups()) == 1:
                rule_num = m.group(1)
                log.info("Got event rule '%s'" % rule_num)
                self.load_event_definitions()
                rule_def = self.get_ruledef_by_name("CS-Cron-%s-%s" % (self.context["GroupName"], rule_num))
                log.debug(rule_def)
                ttl = None
                try:
                    ttl = misc.str2duration_seconds(rule_def["TTL"]) if rule_def is not None and "TTL" in rule_def else None
                except Exception as e:
                    log.exception("[WARNING] Failed to read 'TTL' value '%s'!" % (rule_def.get("TTL")))
                params = dict(rule_def["Data"][0])
                for k in params:
                    if k in ["TTL", "schedule"]:
                        continue
                    Cfg.set(k, params[k], ttl=ttl)
        return True
    return False
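# Illustrative only: minimal shape of the CloudWatch/EventBridge scheduled event
# this method reacts to (account id, region and group name are made up).
example_scheduled_event = {
    "source": "aws.events",
    "detail-type": "Scheduled Event",
    "resources": [
        "arn:aws:events:eu-west-1:111122223333:rule/CS-Cron-test-interact-01"
    ]
}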
def reread_table(self, force_reread=False):
    if not force_reread and self.table_cache is not None:
        return
    now = self.context["now"]
    misc.initialize_clients(["dynamodb"], self.context)
    client = self.context["dynamodb.client"]
    self.table_last_read_date = now

    # Get table schema
    response          = client.describe_table(TableName=self.table_name)
    self.table_schema = response["Table"]
    schema            = self.table_schema["KeySchema"]
    self.is_kv_table  = len(schema) == 1 and schema[0]["AttributeName"] == "Key"

    # Read the whole table into memory
    table_content = None
    try:
        table_content = misc.dynamodb_table_scan(client, self.table_name)
    except Exception as e:
        log.exception("Failed to scan '%s' table: %s" % (self.table_name, e))
        raise e

    table_cache = []
    # Extract aggregates when encountering them
    for record in table_content:
        if "Key" not in record:
            table_cache.append(record)
            continue
        key = record["Key"]
        if "Value" not in record:
            log.warning("Key '%s' specified but missing 'Value' column in configuration record: %s" % (key, record))
            continue
        value     = record["Value"]
        aggregate = next(filter(lambda a: key == a["Prefix"], self.aggregates), None)
        if aggregate is not None:
            agg = []
            try:
                agg = misc.decode_json(value)
            except Exception as e:
                log.debug("Failed to decode JSON aggregate for key '%s' : %s / %s " % (key, value, e))
                continue
            agg.append(record)
            self._safe_key_import(table_cache, aggregate, agg, exclude_aggregate_key=False)
        else:
            aggregate = next(filter(lambda a: self.is_aggregated_key(key), self.aggregates), None)
            if aggregate:
                log.debug("Found a record '%s' that should belong to an aggregate. Ignoring it!" % key)
                continue
            self._safe_key_import(table_cache, aggregate, [record])

    # Clean the table of outdated records (TTL based)
    self.table_cache = []
    for r in table_cache:
        if "ExpirationTime" not in r:
            self.table_cache.append(r)
            continue
        expiration_time = misc.seconds2utc(r["ExpirationTime"])
        if expiration_time is None or expiration_time > now:
            self.table_cache.append(r)
        else:
            if self.is_kv_table:
                log.debug("Wiping outdated item '%s'..." % r["Key"])
                client.delete_item(
                    Key={'Key': {'S': r["Key"]}},
                    TableName=self.table_name)

    # Build an easier to manipulate dict of all the data
    self._build_dict()
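# Illustrative only: shape of a key/value record held in self.table_cache. The
# key and value below are made up. Records past their 'ExpirationTime' (epoch
# seconds) are dropped from the cache and, for pure KV tables, deleted from
# DynamoDB by the loop above.
example_kv_record = {
    "Key": "config.some_parameter",
    "Value": "some-value",
    "ExpirationTime": 1700000000
}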
def get_prerequisites(self):
    """ Gather instance status by calling SSM APIs.
    """
    if not Cfg.get_int("ssm.enable"):
        log.log(log.NOTICE, "SSM support is currently disabled. Set ssm.enable to 1 to enable it.")
        return
    now       = self.context["now"]
    self.ttl  = Cfg.get_duration_secs("ssm.state.default_ttl")
    GroupName = self.context["GroupName"]
    misc.initialize_clients(["ssm"], self.context)
    client = self.context["ssm.client"]

    # Retrieve all SSM maintenance windows applicable to this CloneSquad deployment
    mw_names = {
        "__globaldefault__": {},
        "__default__": {},
        "__main__": {},
        "__all__": {}
    }
    fmt = self.context.copy()
    mw_names["__globaldefault__"]["Names"] = Cfg.get_list("ssm.feature.maintenance_window.global_defaults", fmt=fmt)
    mw_names["__default__"]["Names"]       = Cfg.get_list("ssm.feature.maintenance_window.defaults", fmt=fmt)
    mw_names["__main__"]["Names"]          = Cfg.get_list("ssm.feature.maintenance_window.mainfleet.defaults", fmt=fmt)
    mw_names["__all__"]["Names"]           = Cfg.get_list("ssm.feature.maintenance_window.subfleet.__all__.defaults", fmt=fmt)

    all_mw_names = mw_names["__globaldefault__"]["Names"]
    all_mw_names.extend([n for n in mw_names["__default__"]["Names"] if n not in all_mw_names])
    all_mw_names.extend([n for n in mw_names["__main__"]["Names"] if n not in all_mw_names])
    all_mw_names.extend([n for n in mw_names["__all__"]["Names"] if n not in all_mw_names])

    Cfg.register({
        "ssm.feature.maintenance_window.subfleet.__all__.force_running":
            Cfg.get("ssm.feature.maintenance_window.subfleet.{SubfleetName}.force_running"),
        "ssm.feature.events.ec2.scaling_state_changes.draining.__main__.connection_refused_tcp_ports":
            Cfg.get("ssm.feature.events.ec2.scaling_state_changes.draining.connection_refused_tcp_ports")
    })
    for SubfleetName in self.o_ec2.get_subfleet_names():
        fmt["SubfleetName"] = SubfleetName
        mw_names[f"Subfleet.{SubfleetName}"] = {}
        Cfg.register({
            f"ssm.feature.maintenance_window.subfleet.{SubfleetName}.defaults":
                Cfg.get("ssm.feature.maintenance_window.subfleet.{SubfleetName}.defaults"),
            f"ssm.feature.maintenance_window.subfleet.{SubfleetName}.ec2.schedule.min_instance_count":
                Cfg.get("ssm.feature.maintenance_window.subfleet.{SubfleetName}.ec2.schedule.min_instance_count"),
            f"ssm.feature.maintenance_window.subfleet.{SubfleetName}.force_running":
                Cfg.get("ssm.feature.maintenance_window.subfleet.{SubfleetName}.force_running"),
            f"ssm.feature.events.ec2.scaling_state_changes.draining.{SubfleetName}.connection_refused_tcp_ports":
                Cfg.get("ssm.feature.events.ec2.scaling_state_changes.draining.connection_refused_tcp_ports")
        })
        mw_names[f"Subfleet.{SubfleetName}"]["Names"] = Cfg.get_list(
            f"ssm.feature.maintenance_window.subfleet.{SubfleetName}.defaults", fmt=fmt)
        all_mw_names.extend([n for n in mw_names[f"Subfleet.{SubfleetName}"]["Names"] if n not in all_mw_names])

    # Query the maintenance windows by batches of 20 names (the slice below bounds each filter)
    names = all_mw_names
    mws   = []
    while len(names):
        paginator = client.get_paginator('describe_maintenance_windows')
        response_iterator = paginator.paginate(
            Filters=[{
                'Key': 'Name',
                'Values': names[:20]
            }])
        for r in response_iterator:
            for wi in r["WindowIdentities"]:
                if not wi["Enabled"]:
                    log.log(log.NOTICE, "SSM Maintenance Window '%s' not enabled. Ignored..." % wi["Name"])
                    continue
                if "NextExecutionTime" not in wi:
                    log.log(log.NOTICE, "/!\\ SSM Maintenance Window '%s' without 'NextExecutionTime'." % wi["Name"])
                if wi not in mws:
                    mws.append(wi)
        names = names[20:]

    # Convert string dates into datetime objects
    for d in mws:
        if "NextExecutionTime" in d:
            d["NextExecutionTime"] = misc.str2utc(d["NextExecutionTime"])

    # Retrieve Maintenance Window tags with the resourcegroup API
    tagged_mws = self.context["o_state"].get_resources(service="ssm", resource_name="maintenancewindow")
    for tmw in tagged_mws:
        mw_id = tmw["ResourceARN"].split("/")[1]
        mw    = next(filter(lambda w: w["WindowId"] == mw_id, mws), None)
        if mw:
            mw["Tags"] = tmw["Tags"]
    valid_mws = []
    for mw in mws:
        mw_id = mw["WindowId"]
        if "Tags" not in mw:
            try:
                response   = client.list_tags_for_resource(ResourceType='MaintenanceWindow', ResourceId=mw_id)
                mw["Tags"] = response['TagList'] if 'TagList' in response else []
            except Exception as e:
                log.error(f"Failed to fetch Tags for MaintenanceWindow '{mw_id}'")
        if ("Tags" not in mw or not len(mw["Tags"])) and mw["Name"] not in mw_names["__globaldefault__"]["Names"]:
            log.warning("Please tag SSM Maintenance Window '%s/%s' with 'clonesquad:group-name': '%s'!"
                % (mw["Name"], mw["WindowId"], self.context["GroupName"]))
            continue
        valid_mws.append(mw)

    self.maintenance_windows = {
        "Names": mw_names,
        "Windows": valid_mws
    }
    # Update asynchronous results from previously launched commands
    self.update_pending_command_statuses()
    # Perform maintenance window housekeeping
    self.manage_maintenance_windows()
    if len(mws):
        log.log(log.NOTICE, "Found matching SSM maintenance windows: %s" % self.maintenance_windows["Windows"])

    # Hard dependency toward the EC2 module: update the SSM instance initializing states
    self.o_ec2.update_ssm_initializing_states()
ctx["ScaleUp_SNSTopicArn"] = "arn:aws:sns:%s:%s:CloneSquad-CloudWatchAlarm-ScaleUp-%s" % (ctx["AWS_DEFAULT_REGION"], account_id, ctx["GroupName"]) ctx["InteractLambdaArn"] = "arn:aws:lambda:%s:%s:function:CloneSquad-Interact-%s" % (ctx["AWS_DEFAULT_REGION"], account_id, ctx["GroupName"]) ctx["AWS_LAMBDA_LOG_GROUP_NAME"] = "/aws/lambda/CloneSquad-Main-%s" % ctx["GroupName"] ctx["SSMLogGroup"] = "/aws/lambda/CloneSquad-SSM-%s" % ctx["GroupName"] ctx["CloneSquadVersion"] = "--Development--" # Special treatment while started from SMA invoke loval if misc.is_sam_local() or __name__ == '__main__': fix_sam_bugs() print("SAM Local Environment:") for env in os.environ: print("%s=%s" % (env, os.environ[env])) # Avoid client initialization time during event processsing misc.initialize_clients(["ec2", "cloudwatch", "events", "sqs", "sns", "dynamodb", "ssm", "lambda", "elbv2", "rds", "resourcegroupstaggingapi", "transfer"], ctx) log.debug("End of preambule.") @xray_recorder.capture(name="app.init") def init(with_kvtable=True, with_predefined_configuration=True): config.init(ctx, with_kvtable=with_kvtable, with_predefined_configuration=with_predefined_configuration) Cfg.register({ "app.run_period,Stable" : { "DefaultValue": "seconds=20", "Format" : "Duration", "Description" : """Period when the Main scheduling Lambda function is run. The smaller, the more accurate and reactive is CloneSquad. The bigger, the cheaper is CloneSquad to run itself (Lambda executions, Cloudwatch GetMetricData, DynamoDB queries...) """ },
def handler(self, event, context, response):
    global query_cache
    query_cache.check_invalidation()

    if "httpMethod" in event and "path" in event:
        response.update({
            "isBase64Encoded": False,
            "statusCode": 500,
            "headers": {
                "Content-Type": "application/json"
            },
            "body": ""
        })
        querystring = ""
        if "queryStringParameters" in event and event["queryStringParameters"] is not None:
            querystring = "&".join(["%s:%s" % (q, event["queryStringParameters"][q])
                for q in event["queryStringParameters"].keys()])
            event.update(event["queryStringParameters"])
        log.log(log.NOTICE, "Received API Gateway message for path '%s'" % event["path"])

        # Normalize command format
        arg       = event["OpType"] if "OpType" in event else event["path"]
        path      = arg.lower().split("/")
        path_list = list(filter(lambda x: x != "", path))
        command   = "/".join(path_list)
        command   = urllib.parse.unquote(command)

        cmd = self.find_command(command)
        if cmd is None:
            response["statusCode"] = 404
            response["body"]       = "Unknown command '%s'" % (command)
            return True
        event["OpType"] = command

        #log.log(log.NOTICE, "Processing API Gateway command '%s'" % (command))
        if "apigw" not in cmd["interface"]:
            response["statusCode"] = 404
            response["body"]       = "Command not available through API Gateway"
            return True

        is_cacheable = cmd["cache"] in ["global", "client"]
        if is_cacheable:
            cacheable_url = "%s?%s_%s" % (command, querystring,
                "" if cmd["cache"] == "global" else event["requestContext"]["identity"]["userArn"])
            #log.log(log.NOTICE, "Cacheable query '%s'" % cacheable_url)
            entry = query_cache.get(cacheable_url)
            if entry is not None:
                response.update(entry)
                log.log(log.NOTICE, "API Gateway query '%s' served from the cache..." % command)
                return True

        misc.initialize_clients(cmd["clients"], self.context)
        misc.load_prerequisites(self.context, cmd["prerequisites"])

        cacheddata = None
        if "prepare" in cmd:
            #log.log(log.NOTICE, "Loading cached data...")
            query_cache.load_cached_data()
            if command not in query_cache.interact_precomputed_data["data"]:
                response["statusCode"] = 500
                response["body"]       = "Missing data"
                return True
            cacheddata = query_cache.interact_precomputed_data["data"][command]

        if cmd["func"](self.context, event, response, cacheddata) and is_cacheable and response["statusCode"] == 200:
            query_cache.put(cacheable_url, response)
        return True

    elif self.context["o_scheduler"].manage_rule_event(event):
        log.log(log.NOTICE, "Processed Cloudwatch Scheduler event")

    elif sqs.process_sqs_records(self.context, event, function=self.sqs_interact_processing):
        log.info("Processed SQS records")

    else:
        log.warning("Failed to process the Interact message!")
    return False
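# Illustrative only: minimal shape of an API Gateway (proxy integration) event
# handled by the API Gateway branch above. The path, query parameter and user
# ARN below are made up.
example_apigw_event = {
    "httpMethod": "GET",
    "path": "/some/command",
    "queryStringParameters": {"format": "json"},
    "requestContext": {"identity": {"userArn": "arn:aws:iam::111122223333:user/example"}}
}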