Example #1
 def test_Account(self):
     json_blob = {u'id': 111111111111, u'name': u'prod'}
     account = Account(None, json_blob)
     assert_equal(111111111111, account.local_id)
     assert_equal("prod", account.name)
     assert_equal("account", account.node_type)
     assert_equal("arn:aws:::111111111111:", account.arn)
     assert_false(account.isLeaf)
     assert_equal("prod", get_name(json_blob, "name"))
     assert_false(account.has_leaves)
     assert_equal([], account.leaves)
     assert_equal({'data': {'node_data': {u'id': 111111111111, u'name': u'prod'}, 'local_id': 111111111111,
                            'type': 'account', 'id': 'arn:aws:::111111111111:', 'name': u'prod'}}, account.cytoscape_data())
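The assertions above pin down the Account node's public surface. As a reading aid, here is a minimal sketch that satisfies the same assertions; the name SketchAccount and its internals are ours, not CloudMapper's, whose real Account class does more (parent/child bookkeeping among other things).

class SketchAccount:
    def __init__(self, parent, json_blob):
        self.node_data = json_blob           # raw account JSON, e.g. {'id': ..., 'name': ...}
        self.local_id = json_blob['id']
        self.name = json_blob['name']
        self.node_type = 'account'
        self.isLeaf = False                  # accounts contain regions, never leaves

    @property
    def arn(self):
        # Account ARNs leave the service and region fields empty
        return 'arn:aws:::{}:'.format(self.local_id)

    def cytoscape_data(self):
        # Shape matches the final assertion: everything under a 'data' wrapper
        return {'data': {'node_data': self.node_data,
                         'local_id': self.local_id,
                         'type': self.node_type,
                         'id': self.arn,
                         'name': self.name}}

assert SketchAccount(None, {'id': 111111111111, 'name': 'prod'}).arn == 'arn:aws:::111111111111:'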
Example #2
 def test_get_ec2s(self):
     # This actually uses the demo data files provided
     json_blob = {u'id': 111111111111, u'name': u'demo'}
     account = Account(None, json_blob)
     region = Region(account, {"Endpoint": "ec2.us-east-1.amazonaws.com", "RegionName": "us-east-1"})
     vpc = Vpc(region, get_vpcs(region, {})[0])
     subnet = Subnet(vpc, {"SubnetId": "subnet-00000001", "CidrBlock": "10.0.0.0/24", "Tags": [{"Value": "Public a1", "Key": "Name"}]})
Example #3
def get_cidrs_for_account(account, cidrs):
    account = Account(None, account)

    # TODO Need to use CloudMapper's prepare to identify trusted IPs that are actually in use.
    for region_json in get_regions(account):
        region = Region(account, region_json)
        sg_json = query_aws(account, "ec2-describe-security-groups", region)
        sgs = pyjq.all('.SecurityGroups[]', sg_json)
        for sg in sgs:
            cidrs_seen = set()
            cidr_and_name_list = pyjq.all('.IpPermissions[].IpRanges[]|[.CidrIp,.Description]', sg)
            for cidr, name in cidr_and_name_list:
                if not is_external_cidr(cidr):
                    continue

                if is_unneeded_cidr(cidr):
                    print('WARNING: Unneeded cidr used {} in {}'.format(cidr, sg['GroupId']))
                    continue

                for cidr_seen in cidrs_seen:
                    if (IPNetwork(cidr_seen) in IPNetwork(cidr) or
                        IPNetwork(cidr) in IPNetwork(cidr_seen)):
                        print('WARNING: Overlapping CIDRs in {}, {} and {}'.format(sg['GroupId'], cidr, cidr_seen))
                cidrs_seen.add(cidr)

                if cidr.startswith('0.0.0.0') and not cidr.endswith('/0'):
                    print('WARNING: Unexpected CIDR for attempted public access {} in {}'.format(cidr, sg['GroupId']))
                    continue

                if cidr == '0.0.0.0/0':
                    continue

                cidrs[cidr] = cidrs.get(cidr, set())
                if name is not None:
                    cidrs[cidr].add(name)
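The overlap warning above leans on netaddr's containment operator, checked in both directions because either CIDR may be the wider one. A quick illustration of those semantics, assuming the netaddr package that IPNetwork is imported from:

from netaddr import IPNetwork

assert IPNetwork('10.0.1.0/24') in IPNetwork('10.0.0.0/16')       # /24 sits inside the /16
assert IPNetwork('10.0.0.0/16') not in IPNetwork('10.0.1.0/24')   # but not the reverse
assert IPNetwork('10.0.0.0/16') in IPNetwork('10.0.0.0/16')       # a network contains itself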
Example #4
def get_account_stats(account):
    """Returns stats for an account"""
    account = Account(None, account)
    log_debug('Collecting stats in account {} ({})'.format(account.name, account.local_id))

    # Init stats to {}
    stats = OrderedDict()
    for k in resources:
        stats[k] = {}

    for region_json in get_regions(account):
        region = Region(account, region_json)

        for key, resource in resources.items():
            # Skip global services (just CloudFront)
            if ('region' in resource) and (resource['region'] != region.name):
                continue

            # Check exceptions that require special code to perform the count
            if key == 'route53_record':
                path = 'account-data/{}/{}/{}'.format(
                    account.name,
                    region.name,
                    'route53-list-resource-record-sets')
                if os.path.isdir(path):
                    stats[key][region.name] = 0
                    for f in listdir(path):
                        json_data = json.load(open(os.path.join(path, urllib.parse.quote_plus(f))))
                        stats[key][region.name] += sum(pyjq.all('.ResourceRecordSets|length', json_data))
            else:
                # Normal path
                stats[key][region.name] = sum(pyjq.all(resource['query'], 
                    query_aws(region.account, resource['source'], region)))

    return stats
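This version iterates a module-level `resources` mapping rather than a YAML file (compare Example #8). A hedged sketch of the shape the loop assumes -- a jq counting query, a collected-file name, and an optional region pin for global services; the entries below are illustrative, not CloudMapper's actual configuration:

resources = {
    'ec2': {
        'query': '.Reservations[].Instances|length',   # summed across pyjq.all results
        'source': 'ec2-describe-instances',
    },
    'cloudfront_distribution': {
        'query': '.DistributionList.Quantity',
        'source': 'cloudfront-list-distributions',
        'region': 'us-east-1',                         # global service, counted once
    },
}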
Example #5
def api_endpoints(accounts, config):
    for account in accounts:
        account = Account(None, account)
        for region_json in get_regions(account):
            region = Region(account, region_json)

            # Look for API Gateway
            json_blob = query_aws(region.account, 'apigateway-get-rest-apis',
                                  region)
            if json_blob is None:
                continue
            for api in json_blob.get('items', []):
                rest_id = api['id']
                deployments = get_parameter_file(region, 'apigateway',
                                                 'get-deployments', rest_id)
                if deployments is None:
                    continue
                for deployment in deployments['items']:
                    deployment_id = deployment['id']
                    stages = get_parameter_file(region, 'apigateway',
                                                'get-stages', rest_id)
                    if stages is None:
                        continue
                    for stage in stages['item']:
                        if stage['deploymentId'] == deployment_id:
                            resources = get_parameter_file(
                                region, 'apigateway', 'get-resources', rest_id)
                            if resources is None:
                                continue
                            for resource in resources['items']:
                                print('{}.execute-api.{}.amazonaws.com/{}{}'.
                                      format(api['id'], region.name,
                                             stage['stageName'],
                                             resource['path']))
Example #6
def audit(accounts, config):
    """Audit the accounts"""

    for account in accounts:
        account = Account(None, account)
        print('Finding resources in account {} ({})'.format(
            account.name, account.local_id))

        for region_json in get_regions(account):
            region = Region(account, region_json)
            if region.name == 'us-east-1':
                audit_s3_buckets(region)
                audit_cloudtrail(region)
                audit_password_policy(region)
                audit_root_user(region)
                audit_users(region)
                audit_route53(region)
                audit_cloudfront(region)
            audit_ebs_snapshots(region)
            audit_rds_snapshots(region)
            audit_rds(region)
            audit_amis(region)
            audit_ecr_repos(region)
            audit_redshift(region)
            audit_es(region)
            audit_ec2(region)
            audit_elb(region)
            audit_sg(region)
            audit_lambda(region)
            audit_glacier(region)
            audit_kms(region)
            audit_sqs(region)
            audit_sns(region)
            audit_lightsail(region)
Example #7
 def test_get_vpcs(self):
     # This actually uses the demo data files provided
     json_blob = {u'id': 111111111111, u'name': u'demo'}
     account = Account(None, json_blob)
     region = Region(account, {"Endpoint": "ec2.us-east-1.amazonaws.com", "RegionName": "us-east-1"})
     assert_equal([{"VpcId": "vpc-12345678", "Tags": [{"Value": "Prod", "Key": "Name"}], "InstanceTenancy": "default", "CidrBlockAssociationSet": [{"AssociationId": "vpc-cidr-assoc-12345678",
                                                                                                                                                    "CidrBlock": "10.0.0.0/16", "CidrBlockState": {"State": "associated"}}], "State": "available", "DhcpOptionsId": "dopt-12345678", "CidrBlock": "10.0.0.0/16", "IsDefault": True}], get_vpcs(region, {}))
Example #8
def get_account_stats(account, all_resources=False):
    """Returns stats for an account"""

    with open("stats_config.yaml", 'r') as f:
        resources = yaml.safe_load(f)

    account = Account(None, account)
    log_debug('Collecting stats in account {} ({})'.format(account.name, account.local_id))

    stats = {}
    stats['keys'] = []
    for resource in resources:
        # If the resource is marked as verbose, and we're not showing all resources, skip it.
        if resource.get('verbose',False) and not all_resources:
            continue
        stats['keys'].append(resource['name'])
        stats[resource['name']] = {}

    for region_json in get_regions(account):
        region = Region(account, region_json)

        for resource in resources:
            if resource.get('verbose',False) and not all_resources:
                continue

            # Skip global services (just CloudFront)
            if ('region' in resource) and (resource['region'] != region.name):
                continue

            # Normal path
            stats[resource['name']][region.name] = sum(pyjq.all(resource['query'],
                                                                query_aws(region.account, resource['source'], region)))

    return stats
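A hedged guess at the stats_config.yaml layout this loader expects: a list of entries with name/query/source plus optional verbose and region keys. Shown inline as a parsed string so it runs standalone; the actual file may differ.

import yaml

sample_config = yaml.safe_load("""
- name: EC2 instances
  query: '.Reservations[].Instances|length'
  source: ec2-describe-instances
- name: CloudFront distributions
  query: '.DistributionList.Quantity'
  source: cloudfront-list-distributions
  region: us-east-1     # global service, counted in one region only
  verbose: true         # skipped unless all_resources=True
""")

for resource in sample_config:
    print(resource['name'], resource.get('verbose', False))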
Example #9
def find_admins(accounts, findings):
    admins = []
    for account in accounts:
        account = Account(None, account)
        region = get_us_east_1(account)
        admins.extend(find_admins_in_account(region, findings))

    return admins
Example #10
def find_unused_resources(accounts):
    unused_resources = []
    for account in accounts:
        unused_resources_for_account = []
        for region_json in get_regions(Account(None, account)):
            region = Region(Account(None, account), region_json)

            unused_resources_for_region = {}

            add_if_exists(
                unused_resources_for_region,
                "security_groups",
                find_unused_security_groups(region),
            )
            add_if_exists(unused_resources_for_region, "volumes",
                          find_unused_volumes(region))
            add_if_exists(
                unused_resources_for_region,
                "elastic_ips",
                find_unused_elastic_ips(region),
            )
            add_if_exists(
                unused_resources_for_region,
                "network_interfaces",
                find_unused_network_interfaces(region),
            )
            add_if_exists(
                unused_resources_for_region,
                "elastic_load_balancers",
                find_unused_elastic_load_balancers(region),
            )

            unused_resources_for_account.append({
                "region":
                region_json["RegionName"],
                "unused_resources":
                unused_resources_for_region,
            })
        unused_resources.append({
            "account": {
                "id": account["id"],
                "name": account["name"]
            },
            "regions": unused_resources_for_account,
        })
    return unused_resources
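The add_if_exists helper keeps empty result sets out of the per-region report. A minimal sketch consistent with how it is called above; the real implementation may differ:

def add_if_exists(dictionary, key, value):
    # Only record the key when the finder returned something non-empty
    if value:
        dictionary[key] = value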
Example #11
def amis(args, accounts, config):
    # Loading the list of public images from disk takes a while, so we'll iterate by region

    regions_file = 'data/aws/us-east-1/ec2-describe-images.json'
    if not os.path.isfile(regions_file):
        raise Exception(
            "You need to download the set of public AMI images.  Run:\n"
            "  mkdir -p data/aws\n"
            "  cd data/aws\n"
            "  aws ec2 describe-regions | jq -r '.Regions[].RegionName' | xargs -I{} mkdir {}\n"
            "  aws ec2 describe-regions | jq -r '.Regions[].RegionName' | xargs -I{} sh -c 'aws --region {} ec2 describe-images --executable-users all > {}/ec2-describe-images.json'\n"
        )

    print("{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}".format(
        'Account Name', 'Region Name', 'Instance Id', 'Instance Name',
        'AMI ID', 'Is Public', 'AMI Description', 'AMI Owner'))

    for region_name in listdir('data/aws/'):
        # Get public images
        public_images_file = 'data/aws/{}/ec2-describe-images.json'.format(
            region_name)
        public_images = json.load(open(public_images_file))
        resource_filter = '.Images[]'
        public_images = pyjq.all(resource_filter, public_images)

        for account in accounts:
            account = Account(None, account)
            # Distinct name avoids clobbering the region-name string with the
            # Region object on later account iterations
            region = Region(account, {'RegionName': region_name})

            instances = query_aws(account, "ec2-describe-instances", region)
            resource_filter = '.Reservations[].Instances[] | select(.State.Name == "running")'
            if args.instance_filter != '':
                resource_filter += '|{}'.format(args.instance_filter)
            instances = pyjq.all(resource_filter, instances)

            account_images = query_aws(account, "ec2-describe-images", region)
            resource_filter = '.Images[]'
            account_images = pyjq.all(resource_filter, account_images)

            for instance in instances:
                image_id = instance['ImageId']
                image_description = ''
                owner = ''
                image, is_public_image = find_image(image_id, public_images,
                                                    account_images)
                if image:
                    # Many images don't have all fields, so try the Name, then Description, then ImageLocation
                    image_description = image.get('Name', '')
                    if image_description == '':
                        image_description = image.get('Description', '')
                        if image_description == '':
                            image_description = image.get('ImageLocation', '')
                    owner = image.get('OwnerId', '')

                print("{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}".format(
                    account.name, region.name, instance['InstanceId'],
                    get_instance_name(instance), image_id, is_public_image,
                    image_description, owner))
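find_image is assumed to resolve an ImageId against the account's own images and the public set, reporting which one matched. A plausible sketch; the real lookup order and return contract may differ:

def find_image(image_id, public_images, account_images):
    for image in account_images:
        if image.get('ImageId') == image_id:
            return image, False           # owned by the account, not flagged public
    for image in public_images:
        if image.get('ImageId') == image_id:
            return image, True            # only found among public AMIs
    return None, False                    # deregistered or never collected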
Example #12
def get_cidrs_for_account(account, cidrs):
    account = Account(None, account)

    for region_json in get_regions(account):
        region = Region(account, region_json)
        sg_json = query_aws(account, "ec2-describe-security-groups", region)
        sgs = pyjq.all(".SecurityGroups[]", sg_json)
        for sg in sgs:
            cidr_and_name_list = pyjq.all(
                ".IpPermissions[].IpRanges[]|[.CidrIp,.Description]", sg
            )
            for cidr, name in cidr_and_name_list:
                if not is_external_cidr(cidr):
                    continue

                if is_unblockable_cidr(cidr):
                    print(
                        "WARNING: Unneeded cidr used {} in {}".format(
                            cidr, sg["GroupId"]
                        )
                    )
                    continue

                if cidr.startswith("0.0.0.0") and not cidr.endswith("/0"):
                    print(
                        "WARNING: Unexpected CIDR for attempted public access {} in {}".format(
                            cidr, sg["GroupId"]
                        )
                    )
                    continue

                if cidr == "0.0.0.0/0":
                    continue

                cidrs[cidr] = cidrs.get(cidr, set())
                if name is not None:
                    cidrs[cidr].add(name)

            for ip_permissions in sg["IpPermissions"]:
                cidrs_seen = set()
                for ip_ranges in ip_permissions["IpRanges"]:
                    if "CidrIp" not in ip_ranges:
                        continue
                    cidr = ip_ranges["CidrIp"]
                    for cidr_seen in cidrs_seen:
                        if IPNetwork(cidr_seen) in IPNetwork(cidr) or IPNetwork(
                            cidr
                        ) in IPNetwork(cidr_seen):
                            print(
                                "WARNING: Overlapping CIDRs in {}, {} and {}".format(
                                    sg["GroupId"], cidr, cidr_seen
                                )
                            )
                    cidrs_seen.add(cidr)
Example #13
def get_collection_date(account):
    account_struct = Account(None, account)
    json_blob = query_aws(account_struct, "iam-get-credential-report",
                          get_us_east_1(account_struct))
    if not json_blob:
        raise Exception(
            "File iam-get-credential-report.json does not exist or is not well-formed. Likely cause is you did not run the collect command for this account."
        )

    # GeneratedTime looks like "2019-01-30T15:43:24+00:00"
    # so extract the date part "2019-01-30"
    return json_blob["GeneratedTime"][:10]
Example #14
def get_account_stats(account, all_resources=False):
    """Returns stats for an account"""

    with open("stats_config.yaml", 'r') as f:
        resources = yaml.safe_load(f)

    account = Account(None, account)
    log_debug('Collecting stats in account {} ({})'.format(account.name, account.local_id))

    stats = {}
    stats['keys'] = []
    for resource in resources:
        # If the resource is marked as verbose, and we're not showing all resources, skip it.
        if resource.get('verbose',False) and not all_resources:
            continue
        stats['keys'].append(resource['name'])
        stats[resource['name']] = {}

    for region_json in get_regions(account):
        region = Region(account, region_json)

        for resource in resources:
            if resource.get('verbose',False) and not all_resources:
                continue

            # Skip global services (just CloudFront)
            if ('region' in resource) and (resource['region'] != region.name):
                continue

            # S3 buckets require special code to identify their location
            if resource['name'] == 'S3 buckets':
                if region.name == 'us-east-1':
                    buckets = pyjq.all('.Buckets[].Name', query_aws(region.account, 's3-list-buckets', region))
                    for bucket in buckets:
                        # Get the bucket's location
                        bucket_region = get_parameter_file(region, 's3', 'get-bucket-location', bucket)['LocationConstraint']

                        # Convert the value to a name.
                        # See https://docs.aws.amazon.com/general/latest/gr/rande.html#s3_region
                        if bucket_region is None:
                            bucket_region = 'us-east-1'
                        elif bucket_region == 'EU':
                            bucket_region = 'eu-west-1'

                        # Increment the count
                        tmp = stats[resource['name']].get(bucket_region, 0)
                        stats[resource['name']][bucket_region] = tmp + 1
            else:
                # Normal path
                stats[resource['name']][region.name] = sum(pyjq.all(resource['query'],
                                                                    query_aws(region.account, resource['source'], region)))

    return stats
Example #15
 def test_Account(self):
     json_blob = {u"id": 111111111111, u"name": u"prod"}
     account = Account(None, json_blob)
     assert_equal(111111111111, account.local_id)
     assert_equal("prod", account.name)
     assert_equal("account", account.node_type)
     assert_equal("arn:aws:::111111111111:", account.arn)
     assert_false(account.isLeaf)
     assert_equal("prod", get_name(json_blob, "name"))
     assert_false(account.has_leaves)
     assert_equal([], account.leaves)
     assert_equal(
         {
             "data": {
                 "node_data": {"id": 111111111111, "name": "prod"},
                 "local_id": 111111111111,
                 "type": "account",
                 "id": "arn:aws:::111111111111:",
                 "name": u"prod",
             }
         },
         account.cytoscape_data(),
     )
Example #16
def get_collection_date(account):
    if type(account) is not Account:
        account = Account(None, account)
    account_struct = account
    json_blob = query_aws(
        account_struct, "iam-get-credential-report", get_us_east_1(account_struct)
    )
    if not json_blob:
        raise InvalidAccountData(
            "File iam-get-credential-report.json does not exist or is not well-formed. Likely cause is you did not run the collect command for account {}".format(
                account.name
            )
        )

    # GeneratedTime looks like "2019-01-30T15:43:24+00:00"
    return json_blob["GeneratedTime"]
Example #17
def audit(accounts):
    findings = Findings()

    for account in accounts:
        account = Account(None, account)

        for region_json in get_regions(account):
            region = Region(account, region_json)
            try:
                if region.name == "us-east-1":
                    audit_s3_buckets(findings, region)
                    audit_cloudtrail(findings, region)
                    audit_iam(findings, region)
                    audit_password_policy(findings, region)
                    audit_root_user(findings, region)
                    audit_users(findings, region)
                    audit_route53(findings, region)
                    audit_cloudfront(findings, region)
                    audit_s3_block_policy(findings, region)
                    audit_guardduty(findings, region)
                audit_ebs_snapshots(findings, region)
                audit_rds_snapshots(findings, region)
                audit_rds(findings, region)
                audit_amis(findings, region)
                audit_ecr_repos(findings, region)
                audit_redshift(findings, region)
                audit_es(findings, region)
                audit_ec2(findings, region)
                audit_sg(findings, region)
                audit_lambda(findings, region)
                audit_glacier(findings, region)
                audit_kms(findings, region)
                audit_sqs(findings, region)
                audit_sns(findings, region)
                audit_lightsail(findings, region)
            except Exception as e:
                findings.add(
                    Finding(
                        region,
                        "EXCEPTION",
                        str(e),
                        resource_details={
                            "exception": str(e),
                            "traceback": str(traceback.format_exc()),
                        },
                    ))
    return findings
Example #18
def find_admins(accounts, args, findings):
    privs_to_look_for = None
    if "privs" in args and args.privs:
        privs_to_look_for = args.privs.split(",")
    include_restricted = False
    if "include_restricted" in args:
        include_restricted = args.include_restricted

    admins = []
    for account in accounts:
        account = Account(None, account)
        region = get_us_east_1(account)
        admins.extend(
            find_admins_in_account(region, findings, privs_to_look_for,
                                   include_restricted))

    return admins
Example #19
def audit(accounts):
    findings = Findings()

    for account in accounts:
        account = Account(None, account)

        for region_json in get_regions(account):
            region = Region(account, region_json)
            try:
                if region.name == 'us-east-1':
                    audit_s3_buckets(findings, region)
                    audit_cloudtrail(findings, region)
                    audit_iam_policies(findings, region)
                    audit_password_policy(findings, region)
                    audit_root_user(findings, region)
                    audit_users(findings, region)
                    audit_route53(findings, region)
                    audit_cloudfront(findings, region)
                    audit_s3_block_policy(findings, region)
                    audit_guardduty(findings, region)
                audit_ebs_snapshots(findings, region)
                audit_rds_snapshots(findings, region)
                audit_rds(findings, region)
                audit_amis(findings, region)
                audit_ecr_repos(findings, region)
                audit_redshift(findings, region)
                audit_es(findings, region)
                audit_ec2(findings, region)
                audit_sg(findings, region)
                audit_lambda(findings, region)
                audit_glacier(findings, region)
                audit_kms(findings, region)
                audit_sqs(findings, region)
                audit_sns(findings, region)
                audit_lightsail(findings, region)
            except Exception as e:
                findings.add(
                    Finding(region,
                            'EXCEPTION',
                            e,
                            resource_details={
                                'exception': e,
                                'traceback': sys.exc_info()
                            }))
    return findings
Example #20
def audit(accounts, config):
    """Audit the accounts"""

    for account in accounts:
        account = Account(None, account)
        print('Finding resources in account {} ({})'.format(
            account.name, account.local_id))

        for region_json in get_regions(account):
            region = Region(account, region_json)
            try:
                if region.name == 'us-east-1':
                    audit_s3_buckets(region)
                    audit_cloudtrail(region)
                    audit_password_policy(region)
                    audit_root_user(region)
                    audit_users(region)
                    audit_route53(region)
                    audit_cloudfront(region)
                    audit_s3_block_policy(region)
                    audit_guardduty(region)
                audit_ebs_snapshots(region)
                audit_rds_snapshots(region)
                audit_rds(region)
                audit_amis(region)
                audit_ecr_repos(region)
                audit_redshift(region)
                audit_es(region)
                audit_ec2(region)
                audit_elb(region)
                audit_sg(region)
                audit_lambda(region)
                audit_glacier(region)
                audit_kms(region)
                audit_sqs(region)
                audit_sns(region)
                audit_lightsail(region)
            except Exception as e:
                print('Exception in {} in {}'.format(region.account.name,
                                                     region.name),
                      file=sys.stderr)
                traceback.print_exc()
Example #21
def get_cidrs_for_account(account, cidrs):
    account = Account(None, account)

    # TODO Need to use CloudMapper's prepare to identify trusted IPs that are actually in use.
    for region_json in get_regions(account):
        region = Region(account, region_json)
        sg_json = query_aws(account, "ec2-describe-security-groups", region)
        sgs = pyjq.all('.SecurityGroups[]', sg_json)
        for sg in sgs:
            cidr_and_name_list = pyjq.all(
                '.IpPermissions[].IpRanges[]|[.CidrIp,.Description]', sg)
            for cidr, name in cidr_and_name_list:
                if not is_external_cidr(cidr):
                    continue
                if is_unneeded_cidr(cidr):
                    print('WARNING: Unneeded cidr used {}'.format(cidr))
                    continue
                if cidr == '0.0.0.0/0':
                    continue
                cidrs[cidr] = cidrs.get(cidr, set())
                if name is not None:
                    cidrs[cidr].add(name)
Example #22
def api_endpoints(accounts, config):
    for account in accounts:
        account = Account(None, account)
        for region_json in get_regions(account):
            region = Region(account, region_json)

            # Look for API Gateway
            json_blob = query_aws(region.account, "apigateway-get-rest-apis",
                                  region)
            if json_blob is None:
                continue
            for api in json_blob.get("items", []):
                rest_id = api["id"]
                deployments = get_parameter_file(region, "apigateway",
                                                 "get-deployments", rest_id)
                if deployments is None:
                    continue
                for deployment in deployments["items"]:
                    deployment_id = deployment["id"]
                    stages = get_parameter_file(region, "apigateway",
                                                "get-stages", rest_id)
                    if stages is None:
                        continue
                    for stage in stages["item"]:
                        if stage["deploymentId"] == deployment_id:
                            resources = get_parameter_file(
                                region, "apigateway", "get-resources", rest_id)
                            if resources is None:
                                continue
                            for resource in resources["items"]:
                                print("{}.execute-api.{}.amazonaws.com/{}{}".
                                      format(
                                          api["id"],
                                          region.name,
                                          stage["stageName"],
                                          resource["path"],
                                      ))
Example #23
    def test_get_ec2s(self):
        # This actually uses the demo data files provided
        json_blob = {u'id': 111111111111, u'name': u'demo'}
        account = Account(None, json_blob)
        region = Region(account, {
            "Endpoint": "ec2.us-east-1.amazonaws.com",
            "RegionName": "us-east-1"
        })
        vpc = Vpc(region, get_vpcs(region, {})[0])
        subnet = Subnet(
            vpc, {
                "SubnetId": "subnet-00000001",
                "CidrBlock": "10.0.0.0/24",
                "Tags": [{
                    "Value": "Public a1",
                    "Key": "Name"
                }]
            })

        instances_passed = get_ec2s(subnet, {"tags": ["Name=Bastion"]})
        assert_equal(len(instances_passed), 1)
        instances_filtered = get_ec2s(
            subnet, {"tags": ["NonexistentTagName=NonexistentTagValue"]})
        assert_equal(len(instances_filtered), 0)
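The test implies a tag filter where each entry is a Key=Value string and an instance passes if one of its tags matches. A hedged sketch of that matching logic; get_ec2s may well combine multiple filters differently:

def matches_tag_filter(instance, tag_filters):
    tags = {t['Key']: t['Value'] for t in instance.get('Tags', [])}
    for tag_filter in tag_filters:
        key, _, value = tag_filter.partition('=')
        if tags.get(key) == value:
            return True
    return False

bastion = {'Tags': [{'Key': 'Name', 'Value': 'Bastion'}]}
assert matches_tag_filter(bastion, ['Name=Bastion'])
assert not matches_tag_filter(bastion, ['NonexistentTagName=NonexistentTagValue'])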
Example #24
def get_collection_date(account):
    account_struct = Account(None, account)
    json_blob = query_aws(account_struct, "iam-get-credential-report", get_us_east_1(account_struct))
    # GeneratedTime looks like "2019-01-30T15:43:24+00:00"
    return json_blob['GeneratedTime'][:10]
Example #25
def get_public_nodes(account, config, use_cache=False):
    # TODO Look for IPv6 also
    # TODO Look at more services from https://github.com/arkadiyt/aws_public_ips
    # TODO Integrate into something to more easily port scan and screenshot web services

    # Try reading from cache
    cache_file_path = "account-data/{}/public_nodes.json".format(account["name"])
    if use_cache and os.path.isfile(cache_file_path):
        with open(cache_file_path) as f:
            return json.load(f), []

    # Get the data from the `prepare` command
    outputfilter = {
        "internal_edges": False,
        "read_replicas": False,
        "inter_rds_edges": False,
        "azs": False,
        "collapse_by_tag": None,
        "collapse_asgs": True,
        "mute": True,
    }
    network = build_data_structure(account, config, outputfilter)

    public_nodes = []
    warnings = []

    # Look at all the edges for ones connected to the public Internet (0.0.0.0/0)
    for edge in pyjq.all(
        '.[].data|select(.type=="edge")|select(.source=="0.0.0.0/0")', network
    ):

        # Find the node at the other end of this edge
        target = {"arn": edge["target"], "account": account["name"]}
        target_node = pyjq.first(
            '.[].data|select(.id=="{}")'.format(target["arn"]), network, {}
        )

        # Depending on the type of node, identify what the IP or hostname is
        if target_node["type"] == "elb":
            target["type"] = "elb"
            target["hostname"] = target_node["node_data"]["DNSName"]
        elif target_node["type"] == "elbv2":
            target["type"] = "elbv2"
            target["hostname"] = target_node["node_data"]["DNSName"]
        elif target_node["type"] == "autoscaling":
            target["type"] = "autoscaling"
            target["hostname"] = target_node["node_data"].get("PublicIpAddress", "")
            if target["hostname"] == "":
                target["hostname"] = target_node["node_data"]["PublicDnsName"]
        elif target_node["type"] == "rds":
            target["type"] = "rds"
            target["hostname"] = target_node["node_data"]["Endpoint"]["Address"]
        elif target_node["type"] == "ec2":
            target["type"] = "ec2"
            dns_name = target_node["node_data"].get("PublicDnsName", "")
            target["hostname"] = target_node["node_data"].get(
                "PublicIpAddress", dns_name
            )
            target["tags"] = target_node["node_data"].get("Tags", [])
        elif target_node["type"] == "ecs":
            target["type"] = "ecs"
            target["hostname"] = ""
            for ip in target_node["node_data"]["ips"]:
                if is_public_ip(ip):
                    target["hostname"] = ip
        elif target_node["type"] == "redshift":
            target["type"] = "redshift"
            target["hostname"] = (
                target_node["node_data"].get("Endpoint", {}).get("Address", "")
            )
        else:
            # Unknown node
            raise Exception("Unknown type: {}".format(target_node["type"]))

        # Check if any protocol is allowed (indicated by IpProtocol == -1)
        ingress = pyjq.all(".[]", edge.get("node_data", {}))

        sg_group_allowing_all_protocols = pyjq.first(
            '.[]|select(.IpPermissions[]?|.IpProtocol=="-1")|.GroupId', ingress, None
        )
        public_sgs = {}
        if sg_group_allowing_all_protocols is not None:
            warnings.append(
                "All protocols allowed access to {} due to {}".format(
                    target, sg_group_allowing_all_protocols
                )
            )
            # I would need to redo this code in order to get the name of the security group
            public_sgs[sg_group_allowing_all_protocols] = {"public_ports": "0-65535"}

        # from_port and to_port mean the beginning and end of a port range
        # We only care about TCP (6) and UDP (17)
        # For more info see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/security-group-rules-reference.html
        port_ranges = []
        for sg in ingress:
            sg_port_ranges = []
            for ip_permission in sg.get("IpPermissions", []):
                selection = 'select((.IpProtocol=="tcp") or (.IpProtocol=="udp")) | select(.IpRanges[].CidrIp=="0.0.0.0/0")'
                sg_port_ranges.extend(
                    pyjq.all("{}| [.FromPort,.ToPort]".format(selection), ip_permission)
                )
                selection = 'select(.IpProtocol=="-1") | select(.IpRanges[].CidrIp=="0.0.0.0/0")'
                sg_port_ranges.extend(
                    pyjq.all("{}| [0,65535]".format(selection), ip_permission)
                )
            public_sgs[sg["GroupId"]] = {
                "GroupId": sg["GroupId"],
                "GroupName": sg["GroupName"],
                "public_ports": port_ranges_string(regroup_ranges(sg_port_ranges)),
            }
            port_ranges.extend(sg_port_ranges)
        range_string = port_ranges_string(regroup_ranges(port_ranges))

        target["ports"] = range_string
        target["public_sgs"] = public_sgs
        if target["ports"] == "":
            issue_msg = "No ports open for tcp or udp (probably can only be pinged). Rules that are not tcp or udp: {} -- {}"
            warnings.append(
                issue_msg.format(
                    json.dumps(
                        pyjq.all(
                            '.[]|select((.IpProtocol!="tcp") and (.IpProtocol!="udp"))'.format(
                                selection
                            ),
                            ingress,
                        )
                    ),
                    account,
                )
            )
        public_nodes.append(target)

    # For the network diagram, if an ELB has availability across 3 subnets, I put one node in each subnet.
    # We don't care about that when we want to know what is public and it makes it confusing when you
    # see 3 resources with the same hostname, when you view your environment as only having one ELB.
    # This same issue exists for RDS.
    # Reduce these to single nodes.

    reduced_nodes = {}

    for node in public_nodes:
        reduced_nodes[node["hostname"]] = node

    public_nodes = []
    for _, node in reduced_nodes.items():
        public_nodes.append(node)

    account = Account(None, account)
    for region_json in get_regions(account):
        region = Region(account, region_json)
        # Look for CloudFront
        if region.name == "us-east-1":
            json_blob = query_aws(
                region.account, "cloudfront-list-distributions", region
            )

            for distribution in json_blob.get("DistributionList", {}).get("Items", []):
                if not distribution["Enabled"]:
                    continue

                target = {"arn": distribution["ARN"], "account": account.name}
                target["type"] = "cloudfront"
                target["hostname"] = distribution["DomainName"]
                target["ports"] = "80,443"

                public_nodes.append(target)

        # Look for API Gateway
        json_blob = query_aws(region.account, "apigateway-get-rest-apis", region)
        if json_blob is not None:
            for api in json_blob.get("items", []):
                target = {"arn": api["id"], "account": account.name}
                target["type"] = "apigateway"
                target["hostname"] = "{}.execute-api.{}.amazonaws.com".format(
                    api["id"], region.name
                )
                target["ports"] = "80,443"

                public_nodes.append(target)

    # Write cache file
    with open(cache_file_path, "w") as f:
        f.write(json.dumps(public_nodes, indent=4, sort_keys=True))

    return public_nodes, warnings
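regroup_ranges and port_ranges_string are used above to turn raw [FromPort, ToPort] pairs into a compact display string. Hedged sketches consistent with that usage (the real helpers may treat edge cases differently):

def regroup_ranges(ranges):
    # Merge overlapping or adjacent [start, end] pairs
    merged = []
    for start, end in sorted(ranges):
        if merged and start <= merged[-1][1] + 1:
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])
    return merged

def port_ranges_string(ranges):
    # Render like "22,80-443"
    return ','.join(
        str(start) if start == end else '{}-{}'.format(start, end)
        for start, end in ranges
    )

assert port_ranges_string(regroup_ranges([[80, 443], [22, 22], [443, 8080]])) == '22,80-8080'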
Example #26
def audit(accounts):
    findings = Findings()

    custom_auditor = None
    commands_path = "private_commands"
    for importer, command_name, _ in pkgutil.iter_modules([commands_path]):
        if "custom_auditor" != command_name:
            continue

        full_package_name = "%s.%s" % (commands_path, command_name)
        custom_auditor = importlib.import_module(full_package_name)

        for name, method in inspect.getmembers(custom_auditor,
                                               inspect.isfunction):
            if name.startswith("custom_filter"):
                global custom_filter
                custom_filter = method

    for account in accounts:
        account = Account(None, account)

        for region_json in get_regions(account):
            region = Region(account, region_json)

            try:
                if region.name == "us-east-1":
                    audit_s3_buckets(findings, region)
                    audit_cloudtrail(findings, region)
                    audit_iam(findings, region)
                    audit_password_policy(findings, region)
                    audit_root_user(findings, region)
                    audit_users(findings, region)
                    audit_route53(findings, region)
                    audit_cloudfront(findings, region)
                    audit_s3_block_policy(findings, region)
                audit_guardduty(findings, region)
                audit_ebs_snapshots(findings, region)
                audit_rds_snapshots(findings, region)
                audit_rds(findings, region)
                audit_amis(findings, region)
                audit_ecr_repos(findings, region)
                audit_redshift(findings, region)
                audit_es(findings, region)
                audit_ec2(findings, region)
                audit_sg(findings, region)
                audit_lambda(findings, region)
                audit_glacier(findings, region)
                audit_kms(findings, region)
                audit_sqs(findings, region)
                audit_sns(findings, region)
                audit_lightsail(findings, region)
            except Exception as e:
                findings.add(
                    Finding(
                        region,
                        "EXCEPTION",
                        str(e),
                        resource_details={
                            "exception": str(e),
                            "traceback": str(traceback.format_exc()),
                        },
                    ))

            # Run custom auditor if it exists
            try:
                if custom_auditor is not None:
                    for name, method in inspect.getmembers(
                            custom_auditor, inspect.isfunction):
                        if name.startswith("custom_audit_"):
                            method(findings, region)
            except Exception as e:
                findings.add(
                    Finding(
                        region,
                        "EXCEPTION",
                        str(e),
                        resource_details={
                            "exception": str(e),
                            "traceback": str(traceback.format_exc()),
                        },
                    ))

    return findings
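The custom-auditor hook at the top of this function is a standard pkgutil/importlib plugin scan. The same pattern, stripped down and standalone (the directory and name prefixes here are illustrative):

import importlib
import inspect
import pkgutil

def load_plugin_hooks(commands_path='private_commands', module_name='custom_auditor'):
    hooks = []
    for _, name, _ in pkgutil.iter_modules([commands_path]):
        if name != module_name:
            continue
        module = importlib.import_module('{}.{}'.format(commands_path, name))
        # Collect every function whose name marks it as an audit hook
        for func_name, func in inspect.getmembers(module, inspect.isfunction):
            if func_name.startswith('custom_audit_'):
                hooks.append(func)
    return hooks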
Example #27
def public(accounts, config):
    for account in accounts:
        # Get the data from the `prepare` command
        outputfilter = {
            'internal_edges': False,
            'read_replicas': False,
            'inter_rds_edges': False,
            'azs': False,
            'collapse_by_tag': None,
            'collapse_asgs': True,
            'mute': True
        }
        network = build_data_structure(account, config, outputfilter)

        # Look at all the edges for ones connected to the public Internet (0.0.0.0/0)
        for edge in pyjq.all(
                '.[].data|select(.type=="edge")|select(.source=="0.0.0.0/0")',
                network):

            # Find the node at the other end of this edge
            target = {'arn': edge['target'], 'account': account['name']}
            target_node = pyjq.first(
                '.[].data|select(.id=="{}")'.format(target['arn']), network,
                {})

            # Depending on the type of node, identify what the IP or hostname is
            if target_node['type'] == 'elb':
                target['type'] = 'elb'
                target['hostname'] = target_node['node_data']['DNSName']
            elif target_node['type'] == 'autoscaling':
                target['type'] = 'autoscaling'
                target['hostname'] = target_node['node_data'].get(
                    'PublicIpAddress', '')
                if target['hostname'] == '':
                    target['hostname'] = target_node['node_data'][
                        'PublicDnsName']
            elif target_node['type'] == 'rds':
                target['type'] = 'rds'
                target['hostname'] = target_node['node_data']['Endpoint'][
                    'Address']
            elif target_node['type'] == 'ec2':
                target['type'] = 'ec2'
                dns_name = target_node['node_data'].get('PublicDnsName', '')
                target['hostname'] = target_node['node_data'].get(
                    'PublicIpAddress', dns_name)
            else:
                print(
                    pyjq.first(
                        '.[].data|select(.id=="{}")|[.type, (.node_data|keys)]'
                        .format(target['arn']), network, {}))

            # Check if any protocol is allowed (indicated by IpProtocol == -1)
            ingress = pyjq.all('.[]', edge.get('node_data', {}))

            sg_group_allowing_all_protocols = pyjq.first(
                'select(.IpPermissions[]|.IpProtocol=="-1")|.GroupId', ingress,
                None)
            public_sgs = set()
            if sg_group_allowing_all_protocols is not None:
                log_warning(
                    'All protocols allowed access to {} due to {}'.format(
                        target, sg_group_allowing_all_protocols))
                range_string = '0-65535'
                public_sgs.add(sg_group_allowing_all_protocols)
            else:
                # from_port and to_port mean the beginning and end of a port range
                # We only care about TCP (6) and UDP (17)
                # For more info see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/security-group-rules-reference.html
                port_ranges = []
                for sg in ingress:
                    for ip_permission in sg['IpPermissions']:
                        selection = 'select((.IpProtocol=="tcp") or (.IpProtocol=="udp")) | select(.IpRanges[].CidrIp=="0.0.0.0/0")'
                        port_ranges.extend(
                            pyjq.all(
                                '{}| [.FromPort,.ToPort]'.format(selection),
                                ip_permission))
                        public_sgs.add(sg['GroupId'])
                range_string = port_ranges_string(regroup_ranges(port_ranges))

            target['ports'] = range_string
            target['public_sgs'] = list(public_sgs)
            if target['ports'] == "":
                issue_msg = 'No ports open for tcp or udp (probably can only be pinged). Rules that are not tcp or udp: {} -- {}'
                log_warning(
                    issue_msg.format(
                        json.dumps(
                            pyjq.all(
                                '.[]|select((.IpProtocol!="tcp") and (.IpProtocol!="udp"))'
                                .format(selection), ingress)), account))
            print(json.dumps(target, indent=4, sort_keys=True))

        account = Account(None, account)
        for region_json in get_regions(account):
            region = Region(account, region_json)
            # Look for CloudFront
            if region.name == 'us-east-1':
                json_blob = query_aws(region.account,
                                      'cloudfront-list-distributions', region)

                for distribution in json_blob.get('DistributionList',
                                                  {}).get('Items', []):
                    if not distribution['Enabled']:
                        continue

                    target = {
                        'arn': distribution['ARN'],
                        'account': account.name
                    }
                    target['type'] = 'cloudfront'
                    target['hostname'] = distribution['DomainName']
                    target['ports'] = '80,443'

                    print(json.dumps(target, indent=4, sort_keys=True))

            # Look for API Gateway
            json_blob = query_aws(region.account, 'apigateway-get-rest-apis',
                                  region)
            for api in json_blob.get('items', []):
                target = {'arn': api['id'], 'account': account.name}
                target['type'] = 'apigateway'
                target['hostname'] = '{}.execute-api.{}.amazonaws.com'.format(
                    api['id'], region.name)
                target['ports'] = '80,443'

                print(json.dumps(target, indent=4, sort_keys=True))
Example #28
def report(accounts, config, args):
    """Create report"""

    # Create directory for output file if it doesn't already exist
    try:
        os.mkdir(os.path.dirname(REPORT_OUTPUT_FILE))
    except OSError:
        # Already exists
        pass

    # Read template
    with open(os.path.join("templates", "report.html"), "r") as report_template:
        template = Template(report_template.read())

    # Data to be passed to the template
    t = {}

    t["version"] = __version__

    # Get account names and id's
    t["accounts"] = []
    for account in accounts:
        t["accounts"].append(
            {
                "name": account["name"],
                "id": account["id"],
                "collection_date": get_collection_date(account)[:10],
            }
        )

    # Get resource count info
    # Collect counts
    account_stats = {}
    print("* Getting resource counts")
    for account in accounts:
        account_stats[account["name"]] = get_account_stats(
            account, args.stats_all_resources
        )
        print("  - {}".format(account["name"]))

    # Get names of resources
    # TODO: Change the structure passed through here to be a dict of dict's like I do for the regions
    t["resource_names"] = [""]
    # Just look at the resource names of the first account as they are all the same
    first_account = list(account_stats.keys())[0]
    for name in account_stats[first_account]["keys"]:
        t["resource_names"].append(name)

    # Create jinja data for the resource stats per account
    t["resource_stats"] = []
    for account in accounts:
        for resource_name in t["resource_names"]:
            if resource_name == "":
                resource_row = [account["name"]]
            else:
                count = sum(account_stats[account["name"]][resource_name].values())
                resource_row.append(count)

        t["resource_stats"].append(resource_row)

    t["resource_names"].pop(0)

    # Get region names
    t["region_names"] = []
    account = accounts[0]
    account = Account(None, account)
    for region in get_regions(account):
        region = Region(account, region)
        t["region_names"].append(region.name)

    # Get stats for the regions
    region_stats = {}
    region_stats_tooltip = {}
    for account in accounts:
        account = Account(None, account)
        region_stats[account.name] = {}
        region_stats_tooltip[account.name] = {}
        for region in get_regions(account):
            region = Region(account, region)
            count = 0
            for resource_name in t["resource_names"]:
                n = account_stats[account.name][resource_name].get(region.name, 0)
                count += n

                if n > 0:
                    if region.name not in region_stats_tooltip[account.name]:
                        region_stats_tooltip[account.name][region.name] = ""
                    region_stats_tooltip[account.name][
                        region.name
                    ] += "{}:{}<br>".format(resource_name, n)

            if count > 0:
                has_resources = "Y"
            else:
                has_resources = "N"
            region_stats[account.name][region.name] = has_resources

    t["region_stats"] = region_stats
    t["region_stats_tooltip"] = region_stats_tooltip

    # Pass the account names
    t["account_names"] = []
    for a in accounts:
        t["account_names"].append(a["name"])

    t["resource_data_set"] = []

    # Pass data for the resource chart
    color_index = 0
    for resource_name in t["resource_names"]:
        resource_counts = []
        for account_name in t["account_names"]:
            resource_counts.append(
                sum(account_stats[account_name][resource_name].values())
            )

        resource_data = {
            "label": resource_name,
            "data": resource_counts,
            "backgroundColor": COLOR_PALETTE[color_index],
            "borderWidth": 1,
        }
        t["resource_data_set"].append(resource_data)

        color_index = (color_index + 1) % len(COLOR_PALETTE)

    # Get IAM access data
    print("* Getting IAM data")
    t["iam_active_data_set"] = [
        {
            "label": "Active users",
            "stack": "users",
            "data": [],
            "backgroundColor": "rgb(162, 203, 249)",
            "borderWidth": 1,
        },
        {
            "label": "Inactive users",
            "stack": "users",
            "data": [],
            "backgroundColor": INACTIVE_COLOR,
            "borderWidth": 1,
        },
        {
            "label": "Active roles",
            "stack": "roles",
            "data": [],
            "backgroundColor": ACTIVE_COLOR,
            "borderWidth": 1,
        },
        {
            "label": "Inactive roles",
            "stack": "roles",
            "data": [],
            "backgroundColor": INACTIVE_COLOR,
            "borderWidth": 1,
        },
    ]

    for account in accounts:
        account = Account(None, account)
        print("  - {}".format(account.name))

        account_stats = get_access_advisor_active_counts(account, args.max_age)

        # Add to dataset
        t["iam_active_data_set"][0]["data"].append(account_stats["users"]["active"])
        t["iam_active_data_set"][1]["data"].append(account_stats["users"]["inactive"])
        t["iam_active_data_set"][2]["data"].append(account_stats["roles"]["active"])
        t["iam_active_data_set"][3]["data"].append(account_stats["roles"]["inactive"])

    print("* Getting public resource data")
    # TODO Need to cache this data as this can take a long time
    t["public_network_resource_type_names"] = [
        "ec2",
        "elb",
        "elbv2",
        "rds",
        "redshift",
        "ecs",
        "autoscaling",
        "cloudfront",
        "apigateway",
    ]
    t["public_network_resource_types"] = {}

    t["public_ports"] = []
    t["account_public_ports"] = {}

    for account in accounts:
        print("  - {}".format(account["name"]))

        t["public_network_resource_types"][account["name"]] = {}
        t["account_public_ports"][account["name"]] = {}

        for type_name in t["public_network_resource_type_names"]:
            t["public_network_resource_types"][account["name"]][type_name] = 0

        public_nodes, _ = get_public_nodes(account, config, use_cache=True)

        for public_node in public_nodes:
            if public_node["type"] in t["public_network_resource_type_names"]:
                t["public_network_resource_types"][account["name"]][
                    public_node["type"]
                ] += 1
            else:
                raise Exception(
                    "Unknown type {} of public node".format(public_node["type"])
                )

            if public_node["ports"] not in t["public_ports"]:
                t["public_ports"].append(public_node["ports"])

            t["account_public_ports"][account["name"]][public_node["ports"]] = (
                t["account_public_ports"][account["name"]].get(public_node["ports"], 0)
                + 1
            )

    # Pass data for the public port chart
    t["public_ports_data_set"] = []
    color_index = 0
    for ports in t["public_ports"]:
        port_counts = []
        for account_name in t["account_names"]:
            port_counts.append(t["account_public_ports"][account_name].get(ports, 0))

        # Fix the port range name for '' when ICMP is being allowed
        if ports == "":
            ports = "ICMP only"

        port_data = {
            "label": ports,
            "data": port_counts,
            "backgroundColor": COLOR_PALETTE[color_index],
            "borderWidth": 1,
        }
        t["public_ports_data_set"].append(port_data)

        color_index = (color_index + 1) % len(COLOR_PALETTE)

    print("* Auditing accounts")
    findings = audit(accounts)
    audit_config = load_audit_config()

    # Filter findings
    tmp_findings = []
    for finding in findings:
        conf = audit_config[finding.issue_id]
        if finding_is_filtered(finding, conf, minimum_severity=args.minimum_severity):
            continue
        tmp_findings.append(finding)
    findings = tmp_findings

    t["findings_severity_by_account_chart"] = []

    # Figure out the counts of findings for each account

    # Create chart for finding type counts
    findings_severity_by_account = {}
    for account in accounts:
        findings_severity_by_account[account["name"]] = {}
        for severity in SEVERITIES:
            findings_severity_by_account[account["name"]][severity["name"]] = {}

        # Filtering the list of findings down to the ones specific to the current account.
        for finding in [f for f in findings if f.account_name == account["name"]]:
            conf = audit_config[finding.issue_id]

            count = findings_severity_by_account[finding.account_name][
                conf["severity"]
            ].get(finding.issue_id, 0)
            findings_severity_by_account[finding.account_name][conf["severity"]][
                finding.issue_id
            ] = (count + 1)

    t["findings_severity_by_account_chart"] = []
    for severity in SEVERITIES:
        severity_counts_by_account = []
        for account in accounts:
            severity_counts_by_account.append(
                len(findings_severity_by_account[account["name"]][severity["name"]])
            )

        t["findings_severity_by_account_chart"].append(
            {
                "label": severity["name"],
                "data": severity_counts_by_account,
                "backgroundColor": severity["color"],
                "borderWidth": 1,
            }
        )

    # Create list by severity
    t["severities"] = {}
    for severity in SEVERITIES:
        t["severities"][severity["name"]] = {}
    for finding in findings:
        conf = audit_config[finding.issue_id]

        t["severities"][conf["severity"]][finding.issue_id] = {
            "title": conf["title"],
            "id": finding.issue_id,
        }

    # Create chart for finding counts
    finding_type_set = {}

    for f in findings:
        finding_type_set[f.issue_id] = 1

    t["finding_counts_by_account_chart"] = []
    for finding_type in finding_type_set:
        finding_counts = []
        for account in accounts:
            count = 0
            for severity in findings_severity_by_account[account["name"]]:
                count += findings_severity_by_account[account["name"]][severity].get(
                    finding_type, 0
                )
            finding_counts.append(count)

        t["finding_counts_by_account_chart"].append(
            {
                "label": finding_type,
                "data": finding_counts,
                "backgroundColor": COLOR_PALETTE[color_index],
                "borderWidth": 1,
            }
        )

        color_index = (color_index + 1) % len(COLOR_PALETTE)

    t["findings"] = {}
    for finding in findings:
        conf = audit_config[finding.issue_id]
        group = t["findings"].get(conf["group"], {})

        # Look up the severity struct matching this finding's configured severity
        severity = next(s for s in SEVERITIES if s["name"] == conf["severity"])

        issue = group.get(
            finding.issue_id,
            {
                "title": conf["title"],
                "description": conf.get("description", ""),
                "severity": conf["severity"],
                "severity_color": severity["color"],
                "is_global": conf.get("is_global", False),
                "accounts": {},
            },
        )

        account_hits = issue["accounts"].get(
            finding.region.account.local_id,
            {"account_name": finding.region.account.name, "regions": {}},
        )

        region_hits = account_hits["regions"].get(finding.region.name, {"hits": []})

        region_hits["hits"].append(
            {
                "resource": finding.resource_id,
                "details": json.dumps(finding.resource_details, indent=4),
            }
        )

        account_hits["regions"][finding.region.name] = region_hits
        issue["accounts"][finding.region.account.local_id] = account_hits

        group[finding.issue_id] = issue
        t["findings"][conf["group"]] = group

    # Generate report from template
    with open(REPORT_OUTPUT_FILE, "w") as f:
        f.write(template.render(t=t))

    print("Report written to {}".format(REPORT_OUTPUT_FILE))
Example #29
def build_data_structure(account_data, config, outputfilter):
    cytoscape_json = []

    account = Account(None, account_data)
    print("Building data for account {} ({})".format(account.name, account.local_id))

    cytoscape_json.append(account.cytoscape_data())
    for region_json in get_regions(account, outputfilter):
        node_count_per_region = 0
        region = Region(account, region_json)

        for vpc_json in get_vpcs(region, outputfilter):
            vpc = Vpc(region, vpc_json)

            for az_json in get_azs(vpc):
                # Availability zones are not a per-VPC construct, but VPCs can span AZs,
                # so the VPC is treated as the higher-level construct here
                az = Az(vpc, az_json)

                for subnet_json in get_subnets(az):
                    # If AZs are being ignored, make the VPC (rather than the AZ) the subnet's parent
                    if outputfilter["azs"]:
                        parent = az
                    else:
                        parent = vpc

                    subnet = Subnet(parent, subnet_json)

                    # Get EC2's
                    for ec2_json in get_ec2s(subnet):
                        ec2 = Ec2(subnet, ec2_json, outputfilter["collapse_by_tag"])
                        subnet.addChild(ec2)

                    # Get RDS's
                    for rds_json in get_rds_instances(subnet):
                        rds = Rds(subnet, rds_json)
                        if not outputfilter["read_replicas"] and rds.node_type == "rds_rr":
                            continue
                        subnet.addChild(rds)

                    # Get ELB's
                    for elb_json in get_elbs(subnet):
                        elb = Elb(subnet, elb_json)
                        subnet.addChild(elb)


                    # If there are leaves, then add this subnet to the final graph
                    if len(subnet.leaves) > 0:
                        node_count_per_region += len(subnet.leaves)
                        for leaf in subnet.leaves:
                            cytoscape_json.append(leaf.cytoscape_data())
                        cytoscape_json.append(subnet.cytoscape_data())
                        az.addChild(subnet)

                if az.has_leaves:
                    if outputfilter["azs"]:
                        cytoscape_json.append(az.cytoscape_data())
                    vpc.addChild(az)

            if vpc.has_leaves:
                cytoscape_json.append(vpc.cytoscape_data())
                region.addChild(vpc)

        if region.has_leaves:
            cytoscape_json.append(region.cytoscape_data())
            account.addChild(region)

        print("- {} nodes built in region {}".format(node_count_per_region, region.local_id))

    # Get VPC peerings
    for region in account.children:
        for vpc_peering in get_vpc_peerings(region):
            # For each peering, find the accepter and the requester
            accepter_id = vpc_peering["AccepterVpcInfo"]["VpcId"]
            requester_id = vpc_peering["RequesterVpcInfo"]["VpcId"]
            accepter = None
            requester = None
            for vpc in region.children:
                if accepter_id == vpc.local_id:
                    accepter = vpc
                if requester_id == vpc.local_id:
                    requester = vpc
            # If both have been found, add each as peers to one another
            if accepter and requester:
                accepter.addPeer(requester)
                requester.addPeer(accepter)

    # Get external cidr nodes
    cidrs = {}
    for cidr in get_external_cidrs(account, config):
        cidrs[cidr.arn] = cidr

    # Find connections between nodes
    # Only looking at Security Groups currently, which are a VPC level construct
    connections = {}
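    # The same edge can be produced by several security group rules, so merge their reasons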
    for region in account.children:
        for vpc in region.children:
            for c, reasons in get_connections(cidrs, vpc, outputfilter).items():
                r = connections.get(c, [])
                r.extend(reasons)
                connections[c] = r

    # Add external cidr nodes
    used_cidrs = 0
    for _, cidr in cidrs.items():
        if cidr.is_used:
            used_cidrs += 1
            cytoscape_json.append(cidr.cytoscape_data())
    print("- {} external CIDRs built".format(used_cidrs))

    total_number_of_nodes = len(cytoscape_json)

    # Add the mapping to our graph
    for c, reasons in connections.items():
        if c.source == c.target:
            # Ensure we don't add connections with the same nodes on either side
            continue
        c._json = reasons
        cytoscape_json.append(c.cytoscape_data())
    print("- {} connections built".format(len(connections)))

    # Check if we have a lot of data, and if so, show a warning
    # Numbers chosen here are arbitrary
    MAX_NODES_FOR_WARNING = 200
    MAX_EDGES_FOR_WARNING = 500
    if total_number_of_nodes > MAX_NODES_FOR_WARNING or len(connections) > MAX_EDGES_FOR_WARNING:
        print("WARNING: There are {} total nodes and {} total edges.".format(total_number_of_nodes, len(connections)))
        print("  This will be difficult to display and may be too complex to make sense of.")
        print("  Consider reducing the number of items in the diagram by viewing a single")
        print("   region, ignoring internal edges, or other filtering.")

    return cytoscape_json
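# Hypothetical sketch of the element format the returned list feeds into: Cytoscape.js
# expects a flat list of {"data": {...}} entries, where edges reference node ids via
# "source" and "target". The ids below are made up for illustration.
example_elements = [
    {"data": {"id": "arn:aws:ec2:us-east-1:111111111111:instance/i-0abc", "type": "ec2"}},
    {"data": {"id": "edge-1",
              "source": "arn:aws:ec2:us-east-1:111111111111:instance/i-0abc",
              "target": "arn:aws:rds:us-east-1:111111111111:db:mydb"}},
]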
Example #30
def build_data_structure(account_data, config, outputfilter):
    cytoscape_json = []

    if outputfilter.get('mute', False):
        global MUTE
        MUTE = True

    account = Account(None, account_data)
    log("Building data for account {} ({})".format(account.name, account.local_id))

    cytoscape_json.append(account.cytoscape_data())
    
    # Iterate through each region and add all the VPCs, AZs, and Subnets
    for region_json in get_regions(account, outputfilter):
        nodes = {}
        region = Region(account, region_json)

        for vpc_json in get_vpcs(region, outputfilter):
            vpc = Vpc(region, vpc_json)

            for az_json in get_azs(vpc):
                # Availability zones are not a per-VPC construct, but VPCs can span AZs,
                # so the VPC is treated as the higher-level construct here
                az = Az(vpc, az_json)

                for subnet_json in get_subnets(az):
                    # If AZs are being ignored, make the VPC (rather than the AZ) the subnet's parent
                    if outputfilter["azs"]:
                        parent = az
                    else:
                        parent = vpc

                    subnet = Subnet(parent, subnet_json)
                    az.addChild(subnet)
                vpc.addChild(az)
            region.addChild(vpc)
        account.addChild(region)

        #
        # In each region, iterate through all the resource types
        #

        # EC2 nodes
        for ec2_json in get_ec2s(region):
            node = Ec2(region, ec2_json, outputfilter["collapse_by_tag"], outputfilter["collapse_asgs"])
            nodes[node.arn] = node
        
        # RDS nodes
        for rds_json in get_rds_instances(region):
            node = Rds(region, rds_json)
            if not outputfilter["read_replicas"] and node.node_type == "rds_rr":
                continue
            nodes[node.arn] = node

        # ELB nodes
        for elb_json in get_elbs(region):
            node = Elb(region, elb_json)
            nodes[node.arn] = node
        
        for elb_json in get_elbv2s(region):
            node = Elbv2(region, elb_json)
            nodes[node.arn] = node

        # PrivateLink and VPC Endpoints
        for vpc_endpoint_json in get_vpc_endpoints(region):
            node = VpcEndpoint(region, vpc_endpoint_json)
            nodes[node.arn] = node

        # ECS tasks
        for ecs_json in get_ecs_tasks(region):
            node = Ecs(region, ecs_json)
            nodes[node.arn] = node
        
        # Lambda functions
        for lambda_json in get_lambda_functions(region):
            node = Lambda(region, lambda_json)
            nodes[node.arn] = node

        # Redshift clusters
        for node_json in get_redshift(region):
            node = Redshift(region, node_json)
            nodes[node.arn] = node

        # ElasticSearch clusters
        for node_json in get_elasticsearch(region):
            node = ElasticSearch(region, node_json)
            nodes[node.arn] = node

        # Filter out nodes based on tags
        if len(outputfilter.get("tags", [])) > 0:
            for node_id in list(nodes):
                has_match = False
                node = nodes[node_id]
                # For each node, look to see if its tags match one of the tag sets
                # Ex. --tags Env=Prod --tags Team=Dev,Name=Bastion
                # (this matching rule is sketched as a standalone helper after this example)
                for tag_set in outputfilter.get("tags", []):
                    conditions = [c.split("=") for c in tag_set.split(",")]
                    condition_matches = 0
                    # For a tag set, see if all conditions match, ex. [["Team","Dev"],["Name","Bastion"]]
                    for pair in conditions:
                        # Given ["Team","Dev"], see if it matches one of the tags in the node
                        for tag in node.tags:
                            if tag.get('Key','') == pair[0] and tag.get('Value','') == pair[1]:
                                condition_matches += 1
                    # We have a match if all of the conditions matched
                    if condition_matches == len(conditions):
                        has_match = True
                
                # If there were no matches, remove the node
                if not has_match:
                    del nodes[node_id]

        # Add the nodes to their respective subnets
        for node_arn in list(nodes):
            node = nodes[node_arn]
            add_node_to_subnets(region, node, nodes)

        # From the root of the tree (the account), add in the children if there are leaves
        # If not, mark the item for removal
        if region.has_leaves:
            cytoscape_json.append(region.cytoscape_data())

            region_children_to_remove = set()
            for vpc in region.children:
                if vpc.has_leaves:
                    cytoscape_json.append(vpc.cytoscape_data())
                
                    vpc_children_to_remove = set()
                    for vpc_child in vpc.children:
                        if vpc_child.has_leaves:
                            if outputfilter["azs"]:
                                cytoscape_json.append(vpc_child.cytoscape_data())
                            elif vpc_child.node_type != 'az':
                                # Add VPC children that are not AZs, such as Gateway endpoints
                                cytoscape_json.append(vpc_child.cytoscape_data())
                        
                            az_children_to_remove = set()
                            for subnet in vpc_child.children:
                                if subnet.has_leaves:
                                    cytoscape_json.append(subnet.cytoscape_data())

                                    for leaf in subnet.leaves:
                                        cytoscape_json.append(leaf.cytoscape_data(subnet.arn))
                                else:
                                    az_children_to_remove.add(subnet)
                            for subnet in az_children_to_remove:
                                vpc_child.removeChild(subnet)
                        else:
                            vpc_children_to_remove.add(vpc_child)
                    for az in vpc_children_to_remove:
                        vpc.removeChild(az)
                else:
                    region_children_to_remove.add(vpc)
            for vpc in region_children_to_remove:
                region.removeChild(vpc)

        log("- {} nodes built in region {}".format(len(nodes), region.local_id))

    # Get VPC peerings
    for region in account.children:
        for vpc_peering in get_vpc_peerings(region):
            # For each peering, find the accepter and the requester
            accepter_id = vpc_peering["AccepterVpcInfo"]["VpcId"]
            requester_id = vpc_peering["RequesterVpcInfo"]["VpcId"]
            accepter = None
            requester = None
            for vpc in region.children:
                if accepter_id == vpc.local_id:
                    accepter = vpc
                if requester_id == vpc.local_id:
                    requester = vpc
            # If both have been found, add each as peers to one another
            if accepter and requester:
                accepter.addPeer(requester)
                requester.addPeer(accepter)

    # Get external cidr nodes
    cidrs = {}
    for cidr in get_external_cidrs(account, config):
        cidrs[cidr.arn] = cidr

    # Find connections between nodes
    # Only looking at Security Groups currently, which are a VPC level construct
    connections = {}
    for region in account.children:
        for vpc in region.children:
            for c, reasons in get_connections(cidrs, vpc, outputfilter).items():
                r = connections.get(c, [])
                r.extend(reasons)
                connections[c] = r

    #
    # Collapse CIDRs
    #

    # Get a list of the current CIDRs
    current_cidrs = []
    for cidr_string in cidrs:
        current_cidrs.append(cidr_string)

    # Iterate through them
    for cidr_string in current_cidrs:
        # Find CIDRs in the config that our CIDR falls inside
        # It may fall inside multiple ranges
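        # e.g. (hypothetical) 203.0.113.8/32 falls inside both 203.0.113.0/24 (size 256)
        # and 203.0.113.0/28 (size 16); the /28 is kept as the smallest matching range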
        matching_known_cidrs = {}
        for named_cidr in config["cidrs"]:
            if IPNetwork(cidr_string) in IPNetwork(named_cidr):
                # Match found
                matching_known_cidrs[named_cidr] = IPNetwork(named_cidr).size

        if len(matching_known_cidrs) > 0:
            # A match was found. Find the smallest matching range.
            sorted_matches = sorted(matching_known_cidrs.items(), key=operator.itemgetter(1))
            # The first item is the smallest (CIDR, size) pair; its first element is the CIDR itself
            smallest_matched_cidr_string = sorted_matches[0][0]
            smallest_matched_cidr_name = config["cidrs"][smallest_matched_cidr_string]['name']

            # Check if we have a CIDR node that doesn't match the smallest one possible.
            if cidrs[cidr_string].name != config["cidrs"][smallest_matched_cidr_string]['name']:
                # See if we need to create the larger known range
                if smallest_matched_cidr_string not in cidrs:
                    cidrs[smallest_matched_cidr_string] = Cidr(smallest_matched_cidr_string, smallest_matched_cidr_name)

                # The existing CIDR node needs to be removed and rebuilt as the larger known range
                del cidrs[cidr_string]

                # Get the larger known range
                new_source = cidrs[smallest_matched_cidr_string]
                new_source.is_used = True

                # Find all the connections to the old node
                connections_to_remove = []
                for c in connections:
                    if c.source.node_type == 'ip' and c.source.arn == cidr_string:
                        connections_to_remove.append(c)
                
                # Create new connections to the new node
                for c in connections_to_remove:
                    r = connections[c]
                    del connections[c]
                    connections[Connection(new_source, c._target)] = r

    # Add external cidr nodes
    used_cidrs = 0
    for _, cidr in cidrs.items():
        if cidr.is_used:
            used_cidrs += 1
            cytoscape_json.append(cidr.cytoscape_data())
    log("- {} external CIDRs built".format(used_cidrs))

    total_number_of_nodes = len(cytoscape_json)

    # Add the mapping to our graph
    for c, reasons in connections.items():
        if c.source == c.target:
            # Ensure we don't add connections with the same nodes on either side
            continue
        c._json = reasons
        cytoscape_json.append(c.cytoscape_data())
    log("- {} connections built".format(len(connections)))

    # Check if we have a lot of data, and if so, show a warning
    # Numbers chosen here are arbitrary
    MAX_NODES_FOR_WARNING = 200
    MAX_EDGES_FOR_WARNING = 500
    if total_number_of_nodes > MAX_NODES_FOR_WARNING or len(connections) > MAX_EDGES_FOR_WARNING:
        log("WARNING: There are {} total nodes and {} total edges.".format(total_number_of_nodes, len(connections)))
        log("  This will be difficult to display and may be too complex to make sense of.")
        log("  Consider reducing the number of items in the diagram by viewing a single")
        log("   region, ignoring internal edges, or other filtering.")

    return cytoscape_json
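# Standalone sketch (assumed helper, not in the original) of the tag-set matching rule
# used by the filter above: a node matches if, for at least one tag set, every
# Key=Value pair in that set is present among the node's tags.
def matches_tag_filter(tags, tag_sets):
    for tag_set in tag_sets:
        pairs = [condition.split("=") for condition in tag_set.split(",")]
        if all(any(tag.get("Key") == k and tag.get("Value") == v for tag in tags)
               for k, v in pairs):
            return True
    return False

# Ex. --tags Env=Prod --tags Team=Dev,Name=Bastion
assert matches_tag_filter([{"Key": "Env", "Value": "Prod"}], ["Env=Prod", "Team=Dev,Name=Bastion"])
assert not matches_tag_filter([{"Key": "Team", "Value": "Dev"}], ["Team=Dev,Name=Bastion"])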