def test_invalid_diff_versions(self):
    graph_set_1 = GraphSet(
        name="graph-1",
        version="1",
        start_time=10,
        end_time=20,
        resources=[],
        errors=[],
    )
    graph_set_2 = GraphSet(
        name="graph-1",
        version="2",
        start_time=15,
        end_time=25,
        resources=[],
        errors=[],
    )
    with self.assertRaises(UnmergableGraphSetsException):
        GraphSet.from_graph_sets([graph_set_1, graph_set_2])
def scan_scan_unit(scan_unit: ScanUnit) -> Tuple[str, Dict[str, Any]]:
    logger = Logger()
    with logger.bind(
        account_id=scan_unit.account_id,
        region=scan_unit.region_name,
        service=scan_unit.service,
        resource_classes=sorted(
            [
                resource_spec_class.__name__
                for resource_spec_class in scan_unit.resource_spec_classes
            ]
        ),
    ):
        start_t = time.time()
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceStart)
        session = boto3.Session(
            aws_access_key_id=scan_unit.access_key,
            aws_secret_access_key=scan_unit.secret_key,
            aws_session_token=scan_unit.token,
            region_name=scan_unit.region_name,
        )
        scan_accessor = AWSAccessor(
            session=session,
            account_id=scan_unit.account_id,
            region_name=scan_unit.region_name,
        )
        graph_spec = GraphSpec(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            resource_spec_classes=scan_unit.resource_spec_classes,
            scan_accessor=scan_accessor,
        )
        start_time = int(time.time())
        resources: List[Resource] = []
        errors = []
        try:
            resources = graph_spec.scan()
        except Exception as ex:
            error_str = str(ex)
            trace_back = traceback.format_exc()
            logger.error(
                event=AWSLogEvents.ScanAWSAccountError,
                error=error_str,
                trace_back=trace_back,
            )
            error = f"{error_str}\n{trace_back}"
            errors.append(error)
        end_time = int(time.time())
        graph_set = GraphSet(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            start_time=start_time,
            end_time=end_time,
            resources=resources,
            errors=errors,
            stats=scan_accessor.api_call_stats,
        )
        end_t = time.time()
        elapsed_sec = end_t - start_t
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceEnd, elapsed_sec=elapsed_sec)
        return scan_unit.account_id, graph_set.to_dict()
def scan(self) -> GraphSet:
    """Perform a scan on all of the resource classes in this GraphSpec and return
    a GraphSet containing the scanned data.

    Returns:
        GraphSet representing results of scanning this GraphSpec's resource_spec_classes.
    """
    resources: List[Resource] = []
    errors: List[str] = []
    stats = MultilevelCounter()
    start_time = int(time.time())
    logger = Logger()
    for resource_spec_class in self.resource_spec_classes:
        with logger.bind(resource_type=str(resource_spec_class.type_name)):
            logger.debug(event=LogEvent.ScanResourceTypeStart)
            resource_scan_result = resource_spec_class.scan(scan_accessor=self.scan_accessor)
            resources += resource_scan_result.resources
            errors += resource_scan_result.errors
            stats.merge(resource_scan_result.stats)
            logger.debug(event=LogEvent.ScanResourceTypeEnd)
    end_time = int(time.time())
    return GraphSet(
        name=self.name,
        version=self.version,
        start_time=start_time,
        end_time=end_time,
        resources=resources,
        errors=errors,
        stats=stats,
    )
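# Minimal usage sketch of GraphSpec.scan (assumptions: ExampleResourceSpec is a
# hypothetical AWSResourceSpec subclass, and scan_accessor is an AWSAccessor built
# as in scan_scan_unit above; the keyword arguments mirror that call).
def example_graph_spec_scan(scan_accessor: AWSAccessor) -> GraphSet:
    graph_spec = GraphSpec(
        name="example-graph",
        version="1",
        resource_spec_classes=(ExampleResourceSpec,),  # hypothetical resource spec class
        scan_accessor=scan_accessor,
    )
    # scan() aggregates resources, errors and api call stats into a single GraphSet
    return graph_spec.scan()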
def write_graph_set(
    self, name: str, graph_set: GraphSet, compression: Optional[str] = None
) -> str:
    """Write a graph artifact to the local filesystem.

    Args:
        name: name of the artifact, used as the output file name stem
        graph_set: GraphSet object to write
        compression: None or GZIP

    Returns:
        path to written artifact
    """
    logger = Logger()
    os.makedirs(self.output_dir, exist_ok=True)
    if compression is None:
        artifact_path = os.path.join(self.output_dir, f"{name}.rdf")
    elif compression == GZIP:
        artifact_path = os.path.join(self.output_dir, f"{name}.rdf.gz")
    else:
        raise ValueError(f"Unknown compression arg {compression}")
    graph = graph_set.to_rdf()
    with logger.bind(artifact_path=artifact_path):
        logger.info(event=LogEvent.WriteToFSStart)
        with open(artifact_path, "wb") as fp:
            if compression is None:
                graph.serialize(fp)
            elif compression == GZIP:
                with gzip.GzipFile(fileobj=fp, mode="wb") as gz:
                    graph.serialize(gz)
            else:
                raise ValueError(f"Unknown compression arg {compression}")
        logger.info(event=LogEvent.WriteToFSEnd)
    return artifact_path
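# Minimal usage sketch (assumption: `writer` is an instance of the file-based artifact
# writer this method belongs to, constructed with an output_dir; GZIP is the module-level
# compression constant checked above).
def example_write_graph_artifact(writer, graph_set: GraphSet) -> str:
    # writes <output_dir>/master.rdf.gz and returns its path
    return writer.write_graph_set(name="master", graph_set=graph_set, compression=GZIP)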
class TestGraphSetWithValidDataMerging(TestCase):
    def setUp(self):
        resource_a1 = Resource(
            resource_id="123", type_name="test:a", links=[SimpleLink(pred="has-foo", obj="goo")]
        )
        resource_a2 = Resource(
            resource_id="123", type_name="test:a", links=[SimpleLink(pred="has-goo", obj="foo")]
        )
        resource_b1 = Resource(
            resource_id="abc", type_name="test:b", links=[ResourceLinkLink(pred="has-a", obj="123")]
        )
        resource_b2 = Resource(
            resource_id="def", type_name="test:b", links=[SimpleLink(pred="name", obj="sue")]
        )
        resources = [resource_a1, resource_a2, resource_b1, resource_b2]
        self.graph_set = GraphSet(
            name="test-name",
            version="1",
            start_time=1234,
            end_time=4567,
            resources=resources,
            errors=["test err 1", "test err 2"],
            stats=MultilevelCounter(),
        )

    def test_rdf_a_type(self):
        graph = self.graph_set.to_rdf()
        a_results = graph.query(
            "select ?p ?o where {?s a <test-name:test:a> ; ?p ?o} order by ?p ?o"
        )
        expected_a_result_tuples = [
            ("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", "test-name:test:a"),
            ("test-name:has-foo", "goo"),
            ("test-name:has-goo", "foo"),
            ("test-name:id", "123"),
        ]
        a_result_tuples = []
        for a_result in a_results:
            self.assertEqual(2, len(a_result))
            a_result_tuples.append((str(a_result[0]), str(a_result[1])))
        self.assertEqual(expected_a_result_tuples, a_result_tuples)

    def test_validate(self):
        self.graph_set.validate()
def test_orphaned_ref(self):
    resource_a1 = Resource(
        resource_id="123", type_name="test:a", links=[SimpleLink(pred="has-foo", obj="goo")]
    )
    resource_b1 = Resource(
        resource_id="abc", type_name="test:b", links=[ResourceLinkLink(pred="has-a", obj="456")]
    )
    resources = [resource_a1, resource_b1]
    graph_set = GraphSet(
        name="test-name",
        version="1",
        start_time=1234,
        end_time=4567,
        resources=resources,
        errors=["test err 1", "test err 2"],
        stats=MultilevelCounter(),
    )
    with self.assertRaises(GraphSetOrphanedReferencesException):
        graph_set.validate()
def test_invalid_diff_versions(self):
    graph_set_1 = GraphSet(
        name="graph-1",
        version="1",
        start_time=10,
        end_time=20,
        resources=[],
        errors=[],
        stats=MultilevelCounter(),
    )
    graph_set_2 = GraphSet(
        name="graph-1",
        version="2",
        start_time=15,
        end_time=25,
        resources=[],
        errors=[],
        stats=MultilevelCounter(),
    )
    with self.assertRaises(UnmergableGraphSetsException):
        graph_set_1.merge(graph_set_2)
def setUp(self):
    resource_a1 = Resource(
        resource_id="123", type_name="test:a", links=[SimpleLink(pred="has-foo", obj="goo")]
    )
    resource_a2 = Resource(resource_id="456", type_name="test:a")
    resource_b1 = Resource(
        resource_id="abc", type_name="test:b", links=[ResourceLinkLink(pred="has-a", obj="123")]
    )
    resource_b2 = Resource(
        resource_id="def", type_name="test:b", links=[SimpleLink(pred="name", obj="sue")]
    )
    resources = [resource_a1, resource_a2, resource_b1, resource_b2]
    self.graph_set = GraphSet(
        name="test-name",
        version="1",
        start_time=1234,
        end_time=4567,
        resources=resources,
        errors=["test err 1", "test err 2"],
        stats=MultilevelCounter(),
    )
def graph_set_from_s3(s3_client: BaseClient, json_bucket: str, json_key: str) -> GraphSet:
    """Load a GraphSet from json located in an s3 object."""
    logger = Logger()
    logger.info(event=LogEvent.ReadFromS3Start)
    with io.BytesIO() as json_bytes_buf:
        s3_client.download_fileobj(json_bucket, json_key, json_bytes_buf)
        json_bytes_buf.flush()
        json_bytes_buf.seek(0)
        graph_set_bytes = json_bytes_buf.read()
    logger.info(event=LogEvent.ReadFromS3End)
    graph_set_str = graph_set_bytes.decode("utf-8")
    graph_set_dict = json.loads(graph_set_str)
    return GraphSet.from_dict(graph_set_dict)
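# Minimal usage sketch (assumptions: the bucket and key names are placeholders and the
# caller has credentials allowing s3:GetObject on them).
def example_load_graph_set_from_s3() -> GraphSet:
    s3_client = boto3.client("s3")
    return graph_set_from_s3(
        s3_client=s3_client, json_bucket="example-bucket", json_key="scans/master.json"
    )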
def main(argv: Optional[List[str]] = None) -> int:
    if argv is None:
        argv = sys.argv[1:]
    parser = argparse.ArgumentParser()
    parser.add_argument("input_json_filepaths", type=Path, nargs="+")
    args_ns = parser.parse_args(argv)
    input_json_filepaths = args_ns.input_json_filepaths
    if len(input_json_filepaths) > 1:
        raise NotImplementedError("Only one input supported at this time")
    # create a dict of scan ids to GraphSets. This contains all of the data in the provided input.
    scan_ids_graph_sets: Dict[int, GraphSet] = {
        scan_id: GraphSet.from_json_file(filepath)
        for scan_id, filepath in enumerate(input_json_filepaths)
    }
    # discover tables which need to be created by iterating over resources and finding the
    # maximum set of predicates used for each type
    table_defns = build_table_defns(scan_ids_graph_sets.values())
    # build data
    table_names_datas = build_data(scan_ids_graph_sets.values(), table_defns)
    table_names_tables: Dict[str, tableauhyperapi.TableDefinition] = {}
    with tableauhyperapi.HyperProcess(
        telemetry=tableauhyperapi.Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU
    ) as hyper:
        with tableauhyperapi.Connection(
            endpoint=hyper.endpoint,
            database="altimeter.hyper",
            create_mode=tableauhyperapi.CreateMode.CREATE_AND_REPLACE,
        ) as connection:
            # create tables
            for table_name, columns in table_defns.items():
                table = tableauhyperapi.TableDefinition(
                    table_name=table_name, columns=[column.to_hyper() for column in columns]
                )
                connection.catalog.create_table(table)
                table_names_tables[table_name] = table
            for table_name, datas in table_names_datas.items():
                with tableauhyperapi.Inserter(
                    connection, table_names_tables[table_name]
                ) as inserter:
                    inserter.add_rows(datas)
                    inserter.execute()
    return 0
def test_unknown_type_name(self):
    resources = [
        Resource(resource_id="xyz", type_name="test:a"),
        Resource(resource_id="xyz", type_name="test:c"),
    ]
    with self.assertRaises(ResourceSpecClassNotFoundException):
        GraphSet(
            name="test-name",
            version="1",
            start_time=1234,
            end_time=4567,
            resources=resources,
            errors=[],
            stats=MultilevelCounter(),
        )
def test_invalid_resources_dupes_same_class_conflicting_types_no_allow_clobber(self):
    resources = [
        Resource(resource_id="123", type_name="test:a"),
        Resource(resource_id="123", type_name="test:b"),
    ]
    with self.assertRaises(UnmergableDuplicateResourceIdsFoundException):
        GraphSet(
            name="test-name",
            version="1",
            start_time=1234,
            end_time=4567,
            resources=resources,
            errors=[],
            stats=MultilevelCounter(),
        )
def test_from_dict(self):
    input_dict = {
        "name": "test-name",
        "version": "1",
        "start_time": 1234,
        "end_time": 4567,
        "resources": {
            "123": {
                "type": "test:a",
                "links": [{"pred": "has-foo", "obj": "goo", "type": "simple"}],
            },
            "456": {"type": "test:a"},
            "abc": {
                "type": "test:b",
                "links": [{"pred": "has-a", "obj": "123", "type": "resource_link"}],
            },
            "def": {
                "type": "test:b",
                "links": [{"pred": "name", "obj": "sue", "type": "simple"}],
            },
        },
        "errors": ["test err 1", "test err 2"],
        "stats": {"count": 0},
    }
    graph_set = GraphSet.from_dict(input_dict)
    self.assertEqual(graph_set.to_dict(), input_dict)
def test_orphaned_ref(self):
    resource_a1 = Resource(
        resource_id="123",
        type="test:a",
        link_collection=LinkCollection(simple_links=[SimpleLink(pred="has-foo", obj="goo")]),
    )
    resource_b1 = Resource(
        resource_id="abc",
        type="test:b",
        link_collection=LinkCollection(resource_links=[ResourceLink(pred="has-a", obj="456")]),
    )
    resources = [resource_a1, resource_b1]
    graph_set = GraphSet(
        name="test-name",
        version="1",
        start_time=1234,
        end_time=4567,
        resources=resources,
        errors=["test err 1", "test err 2"],
    )
    with self.assertRaises(GraphSetOrphanedReferencesException):
        ValidatedGraphSet.from_graph_set(graph_set)
def graph_set_from_filepath(filepath: str) -> GraphSet:
    with open(filepath, "r") as fp:
        graph_set_dict = json.load(fp)
    return GraphSet.from_dict(data=graph_set_dict)
def test_valid_merge(self):
    resource_a1 = Resource(
        resource_id="123", type_name="test:a", links=[SimpleLink(pred="has-foo", obj="goo")]
    )
    resource_a2 = Resource(resource_id="456", type_name="test:a")
    resource_b1 = Resource(
        resource_id="abc", type_name="test:b", links=[ResourceLinkLink(pred="has-a", obj="123")]
    )
    resource_b2 = Resource(
        resource_id="def", type_name="test:b", links=[SimpleLink(pred="name", obj="sue")]
    )
    graph_set_1 = GraphSet(
        name="graph-1",
        version="1",
        start_time=10,
        end_time=20,
        resources=[resource_a1, resource_a2],
        errors=["errora1", "errora2"],
        stats=MultilevelCounter(),
    )
    graph_set_2 = GraphSet(
        name="graph-1",
        version="1",
        start_time=15,
        end_time=25,
        resources=[resource_b1, resource_b2],
        errors=["errorb1", "errorb2"],
        stats=MultilevelCounter(),
    )
    graph_set_1.merge(graph_set_2)
    self.assertEqual(graph_set_1.name, "graph-1")
    self.assertEqual(graph_set_1.version, "1")
    self.assertEqual(graph_set_1.start_time, 10)
    self.assertEqual(graph_set_1.end_time, 25)
    self.assertCountEqual(graph_set_1.errors, ["errora1", "errora2", "errorb1", "errorb2"])
    expected_resource_dicts = [
        {"type": "test:a", "links": [{"pred": "has-foo", "obj": "goo", "type": "simple"}]},
        {"type": "test:a"},
        {"type": "test:b", "links": [{"pred": "has-a", "obj": "123", "type": "resource_link"}]},
        {"type": "test:b", "links": [{"pred": "name", "obj": "sue", "type": "simple"}]},
    ]
    resource_dicts = [resource.to_dict() for resource in graph_set_1.resources]
    self.assertCountEqual(expected_resource_dicts, resource_dicts)
class TestGraphSetWithValidDataNoMerging(TestCase): def setUp(self): resource_a1 = Resource(resource_id="123", type_name="test:a", links=[SimpleLink(pred="has-foo", obj="goo")]) resource_a2 = Resource(resource_id="456", type_name="test:a") resource_b1 = Resource( resource_id="abc", type_name="test:b", links=[ResourceLinkLink(pred="has-a", obj="123")]) resource_b2 = Resource(resource_id="def", type_name="test:b", links=[SimpleLink(pred="name", obj="sue")]) resources = [resource_a1, resource_a2, resource_b1, resource_b2] self.graph_set = GraphSet( name="test-name", version="1", start_time=1234, end_time=4567, resources=resources, errors=["test err 1", "test err 2"], stats=MultilevelCounter(), ) def test_rdf_a_type(self): graph = self.graph_set.to_rdf() a_results = graph.query( "select ?p ?o where {?s a <test-name:test:a> ; ?p ?o} order by ?p ?o" ) expected_a_result_tuples = [ ("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", "test-name:test:a"), ("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", "test-name:test:a"), ("test-name:has-foo", "goo"), ("test-name:id", "123"), ("test-name:id", "456"), ] a_result_tuples = [] for a_result in a_results: self.assertEqual(2, len(a_result)) a_result_tuples.append((str(a_result[0]), str(a_result[1]))) self.assertEqual(expected_a_result_tuples, a_result_tuples) def test_rdf_b_type(self): graph = self.graph_set.to_rdf() graph.serialize("/tmp/test.rdf") linked_a_node_results = graph.query( "select ?s where {?s a <test-name:test:a>; <test-name:id> '123' }") self.assertEqual(len(linked_a_node_results), 1) for linked_a_node_result in linked_a_node_results: linked_a_node = str(linked_a_node_result[0]) b_results = graph.query( "select ?p ?o where {?s a <test-name:test:b> ; ?p ?o} order by ?p ?o" ) expected_b_result_tuples = [ ("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", "test-name:test:b"), ("http://www.w3.org/1999/02/22-rdf-syntax-ns#type", "test-name:test:b"), ("test-name:has-a", str(linked_a_node)), ("test-name:id", "abc"), ("test-name:id", "def"), ("test-name:name", "sue"), ] b_result_tuples = [] for b_result in b_results: self.assertEqual(2, len(b_result)) b_result_tuples.append((str(b_result[0]), str(b_result[1]))) self.assertEqual(expected_b_result_tuples, b_result_tuples) def test_rdf_error_graphing(self): graph = self.graph_set.to_rdf() err_results = graph.query( "select ?o where { ?s <test-name:error> ?o } order by ?o") err_strs = [] expected_err_strs = ["test err 1", "test err 2"] for err_result in err_results: self.assertEqual(1, len(err_result)) err_strs.append(str(err_result[0])) self.assertEqual(err_strs, expected_err_strs) def test_to_dict(self): expected_dict = { "name": "test-name", "version": "1", "start_time": 1234, "end_time": 4567, "resources": { "123": { "type": "test:a", "links": [{ "pred": "has-foo", "obj": "goo", "type": "simple" }], }, "456": { "type": "test:a" }, "abc": { "type": "test:b", "links": [{ "pred": "has-a", "obj": "123", "type": "resource_link" }], }, "def": { "type": "test:b", "links": [{ "pred": "name", "obj": "sue", "type": "simple" }], }, }, "errors": ["test err 1", "test err 2"], "stats": { "count": 0 }, } self.assertDictEqual(expected_dict, self.graph_set.to_dict()) def test_from_dict(self): input_dict = { "name": "test-name", "version": "1", "start_time": 1234, "end_time": 4567, "resources": { "123": { "type": "test:a", "links": [{ "pred": "has-foo", "obj": "goo", "type": "simple" }], }, "456": { "type": "test:a" }, "abc": { "type": "test:b", "links": [{ "pred": "has-a", "obj": "123", "type": "resource_link" 
}], }, "def": { "type": "test:b", "links": [{ "pred": "name", "obj": "sue", "type": "simple" }], }, }, "errors": ["test err 1", "test err 2"], "stats": { "count": 0 }, } graph_set = GraphSet.from_dict(input_dict) self.assertEqual(graph_set.to_dict(), input_dict) def test_validate(self): self.graph_set.validate()
def write_graph_set(
    self, name: str, graph_set: GraphSet, compression: Optional[str] = None
) -> str:
    """Write a graph artifact to S3.

    Args:
        name: name of the artifact, used as the object key stem
        graph_set: GraphSet to write
        compression: None or GZIP

    Returns:
        s3 uri of the written artifact
    """
    logger = Logger()
    if compression is None:
        key = f"{name}.rdf"
    elif compression == GZIP:
        key = f"{name}.rdf.gz"
    else:
        raise ValueError(f"Unknown compression arg {compression}")
    output_key = "/".join((self.key_prefix, key))
    graph = graph_set.to_rdf()
    with logger.bind(bucket=self.bucket, key_prefix=self.key_prefix, key=key):
        logger.info(event=LogEvent.WriteToS3Start)
        with io.BytesIO() as rdf_bytes_buf:
            if compression is None:
                graph.serialize(rdf_bytes_buf)
            elif compression == GZIP:
                with gzip.GzipFile(fileobj=rdf_bytes_buf, mode="wb") as gz:
                    graph.serialize(gz)
            else:
                raise ValueError(f"Unknown compression arg {compression}")
            rdf_bytes_buf.flush()
            rdf_bytes_buf.seek(0)
            session = boto3.Session()
            s3_client = session.client("s3")
            s3_client.upload_fileobj(rdf_bytes_buf, self.bucket, output_key)
        s3_client.put_object_tagging(
            Bucket=self.bucket,
            Key=output_key,
            Tagging={
                "TagSet": [
                    {"Key": "name", "Value": graph_set.name},
                    {"Key": "version", "Value": graph_set.version},
                    {"Key": "start_time", "Value": str(graph_set.start_time)},
                    {"Key": "end_time", "Value": str(graph_set.end_time)},
                ]
            },
        )
        logger.info(event=LogEvent.WriteToS3End)
    return f"s3://{self.bucket}/{output_key}"
def scan( muxer: AWSScanMuxer, account_ids: List[str], regions: List[str], accessor: Accessor, artifact_writer: ArtifactWriter, artifact_reader: ArtifactReader, scan_sub_accounts: bool = False, ) -> ScanManifest: if scan_sub_accounts: account_ids = get_sub_account_ids(account_ids, accessor) account_scan_plans = build_account_scan_plans(accessor=accessor, account_ids=account_ids, regions=regions) logger = Logger() logger.info(event=AWSLogEvents.ScanAWSAccountsStart) account_scan_manifests = muxer.scan(account_scan_plans=account_scan_plans) # now combine account_scan_results and org_details to build a ScanManifest scanned_accounts: List[Dict[str, str]] = [] artifacts: List[str] = [] errors: Dict[str, List[str]] = {} unscanned_accounts: List[Dict[str, str]] = [] stats = MultilevelCounter() for account_scan_manifest in account_scan_manifests: account_id = account_scan_manifest.account_id if account_scan_manifest.artifacts: artifacts += account_scan_manifest.artifacts if account_scan_manifest.errors: errors[account_id] = account_scan_manifest.errors unscanned_accounts.append(account_id) else: scanned_accounts.append(account_id) else: unscanned_accounts.append(account_id) account_stats = MultilevelCounter.from_dict( account_scan_manifest.api_call_stats) stats.merge(account_stats) graph_set = None for artifact_path in artifacts: artifact_dict = artifact_reader.read_artifact(artifact_path) artifact_graph_set = GraphSet.from_dict(artifact_dict) if graph_set is None: graph_set = artifact_graph_set else: graph_set.merge(artifact_graph_set) master_artifact_path = None if graph_set: master_artifact_path = artifact_writer.write_artifact( name="master", data=graph_set.to_dict()) logger.info(event=AWSLogEvents.ScanAWSAccountsEnd) if graph_set: start_time = graph_set.start_time end_time = graph_set.end_time else: start_time, end_time = None, None return ScanManifest( scanned_accounts=scanned_accounts, master_artifact=master_artifact_path, artifacts=artifacts, errors=errors, unscanned_accounts=unscanned_accounts, api_call_stats=stats.to_dict(), start_time=start_time, end_time=end_time, )
def run_scan( muxer: AWSScanMuxer, config: Config, artifact_writer: ArtifactWriter, artifact_reader: ArtifactReader, ) -> Tuple[ScanManifest, GraphSet]: if config.scan.scan_sub_accounts: account_ids = get_sub_account_ids(config.scan.accounts, config.access.accessor) else: account_ids = config.scan.accounts account_scan_plan = AccountScanPlan(account_ids=account_ids, regions=config.scan.regions, accessor=config.access.accessor) logger = Logger() logger.info(event=AWSLogEvents.ScanAWSAccountsStart) # now combine account_scan_results and org_details to build a ScanManifest scanned_accounts: List[str] = [] artifacts: List[str] = [] errors: Dict[str, List[str]] = {} unscanned_accounts: List[str] = [] stats = MultilevelCounter() graph_set = None for account_scan_manifest in muxer.scan( account_scan_plan=account_scan_plan): account_id = account_scan_manifest.account_id if account_scan_manifest.artifacts: for account_scan_artifact in account_scan_manifest.artifacts: artifacts.append(account_scan_artifact) artifact_graph_set_dict = artifact_reader.read_json( account_scan_artifact) artifact_graph_set = GraphSet.from_dict( artifact_graph_set_dict) if graph_set is None: graph_set = artifact_graph_set else: graph_set.merge(artifact_graph_set) if account_scan_manifest.errors: errors[account_id] = account_scan_manifest.errors unscanned_accounts.append(account_id) else: scanned_accounts.append(account_id) else: unscanned_accounts.append(account_id) account_stats = MultilevelCounter.from_dict( account_scan_manifest.api_call_stats) stats.merge(account_stats) if graph_set is None: raise Exception("BUG: No graph_set generated.") master_artifact_path = artifact_writer.write_json(name="master", data=graph_set.to_dict()) logger.info(event=AWSLogEvents.ScanAWSAccountsEnd) start_time = graph_set.start_time end_time = graph_set.end_time scan_manifest = ScanManifest( scanned_accounts=scanned_accounts, master_artifact=master_artifact_path, artifacts=artifacts, errors=errors, unscanned_accounts=unscanned_accounts, api_call_stats=stats.to_dict(), start_time=start_time, end_time=end_time, ) artifact_writer.write_json("manifest", data=scan_manifest.to_dict()) return scan_manifest, graph_set
def run_scan( muxer: AWSScanMuxer, config: AWSConfig, aws_resource_region_mapping_repo: AWSResourceRegionMappingRepository, artifact_writer: ArtifactWriter, artifact_reader: ArtifactReader, ) -> Tuple[ScanManifest, ValidatedGraphSet]: if config.scan.accounts: scan_account_ids = config.scan.accounts else: sts_client = boto3.client("sts") scan_account_id = sts_client.get_caller_identity()["Account"] scan_account_ids = (scan_account_id,) if config.scan.scan_sub_accounts: account_ids = get_sub_account_ids(scan_account_ids, config.accessor) else: account_ids = scan_account_ids scan_plan = ScanPlan( account_ids=account_ids, regions=config.scan.regions, aws_resource_region_mapping_repo=aws_resource_region_mapping_repo, accessor=config.accessor, ) logger = Logger() logger.info(event=AWSLogEvents.ScanAWSAccountsStart) # now combine account_scan_results and org_details to build a ScanManifest scanned_accounts: List[str] = [] artifacts: List[str] = [] errors: Dict[str, List[str]] = {} unscanned_accounts: Set[str] = set() graph_sets: List[GraphSet] = [] for account_scan_manifest in muxer.scan(scan_plan=scan_plan): account_id = account_scan_manifest.account_id if account_scan_manifest.errors: errors[account_id] = account_scan_manifest.errors unscanned_accounts.add(account_id) if account_scan_manifest.artifacts: for account_scan_artifact in account_scan_manifest.artifacts: artifacts.append(account_scan_artifact) artifact_graph_set_dict = artifact_reader.read_json(account_scan_artifact) graph_sets.append(GraphSet.parse_obj(artifact_graph_set_dict)) scanned_accounts.append(account_id) else: unscanned_accounts.add(account_id) if not graph_sets: raise Exception("BUG: No graph_sets generated.") validated_graph_set = ValidatedGraphSet.from_graph_set(GraphSet.from_graph_sets(graph_sets)) master_artifact_path: Optional[str] = None if config.write_master_json: master_artifact_path = artifact_writer.write_json(name="master", data=validated_graph_set) logger.info(event=AWSLogEvents.ScanAWSAccountsEnd) start_time = validated_graph_set.start_time end_time = validated_graph_set.end_time scan_manifest = ScanManifest( scanned_accounts=scanned_accounts, master_artifact=master_artifact_path, artifacts=artifacts, errors=errors, unscanned_accounts=list(unscanned_accounts), start_time=start_time, end_time=end_time, ) artifact_writer.write_json("manifest", data=scan_manifest) return scan_manifest, validated_graph_set
def scan(self) -> AccountScanResult: logger = Logger() now = int(time.time()) prescan_errors: List[str] = [] futures: List[Future] = [] account_id = self.account_scan_plan.account_id with logger.bind(account_id=account_id): with ThreadPoolExecutor(max_workers=self.max_threads) as executor: logger.info(event=AWSLogEvents.ScanAWSAccountStart) try: session = self.account_scan_plan.accessor.get_session( account_id=account_id) # sanity check sts_client = session.client("sts") sts_account_id = sts_client.get_caller_identity( )["Account"] if sts_account_id != account_id: raise ValueError( f"BUG: sts detected account_id {sts_account_id} != {account_id}" ) if self.account_scan_plan.regions: account_scan_regions = tuple( self.account_scan_plan.regions) else: account_scan_regions = get_all_enabled_regions( session=session) # build a dict of regions -> services -> List[AWSResourceSpec] regions_services_resource_spec_classes: DefaultDict[ str, DefaultDict[ str, List[Type[AWSResourceSpec]]]] = defaultdict( lambda: defaultdict(list)) for resource_spec_class in self.resource_spec_classes: resource_regions = self.account_scan_plan.aws_resource_region_mapping_repo.get_regions( resource_spec_class=resource_spec_class, region_whitelist=account_scan_regions, ) for region in resource_regions: regions_services_resource_spec_classes[region][ resource_spec_class.service_name].append( resource_spec_class) # Build and submit ScanUnits shuffed_regions_services_resource_spec_classes = random.sample( regions_services_resource_spec_classes.items(), len(regions_services_resource_spec_classes), ) for ( region, services_resource_spec_classes, ) in shuffed_regions_services_resource_spec_classes: region_session = self.account_scan_plan.accessor.get_session( account_id=account_id, region_name=region) region_creds = region_session.get_credentials() shuffled_services_resource_spec_classes = random.sample( services_resource_spec_classes.items(), len(services_resource_spec_classes), ) for ( service, svc_resource_spec_classes, ) in shuffled_services_resource_spec_classes: parallel_svc_resource_spec_classes = [ svc_resource_spec_class for svc_resource_spec_class in svc_resource_spec_classes if svc_resource_spec_class.parallel_scan ] serial_svc_resource_spec_classes = [ svc_resource_spec_class for svc_resource_spec_class in svc_resource_spec_classes if not svc_resource_spec_class.parallel_scan ] for (parallel_svc_resource_spec_class ) in parallel_svc_resource_spec_classes: parallel_future = schedule_scan( executor=executor, graph_name=self.graph_name, graph_version=self.graph_version, account_id=account_id, region_name=region, service=service, access_key=region_creds.access_key, secret_key=region_creds.secret_key, token=region_creds.token, resource_spec_classes=( parallel_svc_resource_spec_class, ), ) futures.append(parallel_future) serial_future = schedule_scan( executor=executor, graph_name=self.graph_name, graph_version=self.graph_version, account_id=account_id, region_name=region, service=service, access_key=region_creds.access_key, secret_key=region_creds.secret_key, token=region_creds.token, resource_spec_classes=tuple( serial_svc_resource_spec_classes), ) futures.append(serial_future) except Exception as ex: error_str = str(ex) trace_back = traceback.format_exc() logger.error( event=AWSLogEvents.ScanAWSAccountError, error=error_str, trace_back=trace_back, ) prescan_errors.append(f"{error_str}\n{trace_back}") graph_sets: List[GraphSet] = [] for future in as_completed(futures): graph_set = future.result() 
graph_sets.append(graph_set) # if there was a prescan error graph it and return the result if prescan_errors: unscanned_account_resource = UnscannedAccountResourceSpec.create_resource( account_id=account_id, errors=prescan_errors) account_graph_set = GraphSet( name=self.graph_name, version=self.graph_version, start_time=now, end_time=now, resources=[unscanned_account_resource], errors=prescan_errors, ) output_artifact = self.artifact_writer.write_json( name=account_id, data=account_graph_set, ) logger.info(event=AWSLogEvents.ScanAWSAccountEnd) return AccountScanResult( account_id=account_id, artifacts=[output_artifact], errors=prescan_errors, ) # if there are any errors whatsoever we generate an empty graph with errors only errors = [] for graph_set in graph_sets: errors += graph_set.errors if errors: unscanned_account_resource = UnscannedAccountResourceSpec.create_resource( account_id=account_id, errors=errors) account_graph_set = GraphSet( name=self.graph_name, version=self.graph_version, start_time=now, end_time=now, resources=[unscanned_account_resource], errors=errors, ) else: account_graph_set = GraphSet.from_graph_sets(graph_sets) output_artifact = self.artifact_writer.write_json( name=account_id, data=account_graph_set, ) logger.info(event=AWSLogEvents.ScanAWSAccountEnd) return AccountScanResult( account_id=account_id, artifacts=[output_artifact], errors=errors, )
def test(self): with tempfile.TemporaryDirectory() as temp_dir: resource_region_name = "us-east-1" # get moto"s enabled regions ec2_client = boto3.client("ec2", region_name=resource_region_name) all_regions = ec2_client.describe_regions( Filters=[{ "Name": "opt-in-status", "Values": ["opt-in-not-required", "opted-in"] }])["Regions"] account_id = get_account_id() all_region_names = tuple(region["RegionName"] for region in all_regions) enabled_region_names = tuple( region["RegionName"] for region in all_regions if region["OptInStatus"] != "not-opted-in") delete_vpcs(all_region_names) # add a diverse set of resources which are supported by moto ## dynamodb # TODO moto is not returning TableId in list/describe # dynamodb_table_1_arn = create_dynamodb_table( # name="test_table_1", # attr_name="test_hash_key_attr_1", # attr_type="S", # key_type="HASH", # region_name=region_name, # ) ## s3 bucket_1_name = "test_bucket" bucket_1_arn, bucket_1_creation_date = create_bucket( name=bucket_1_name, account_id=account_id, region_name=resource_region_name) ## ec2 vpc_1_cidr = "10.0.0.0/16" vpc_1_id = create_vpc(cidr_block=vpc_1_cidr, region_name=resource_region_name) vpc_1_arn = VPCResourceSpec.generate_arn( resource_id=vpc_1_id, account_id=account_id, region=resource_region_name) subnet_1_cidr = "10.0.0.0/24" subnet_1_cidr_network = ipaddress.IPv4Network(subnet_1_cidr, strict=False) subnet_1_first_ip, subnet_1_last_ip = ( int(subnet_1_cidr_network[0]), int(subnet_1_cidr_network[-1]), ) subnet_1_id = create_subnet(cidr_block=subnet_1_cidr, vpc_id=vpc_1_id, region_name=resource_region_name) subnet_1_arn = SubnetResourceSpec.generate_arn( resource_id=subnet_1_id, account_id=account_id, region=resource_region_name) fixed_bucket_1_arn = f"arn:aws:s3:::{bucket_1_name}" flow_log_1_id, flow_log_1_creation_time = create_flow_log( vpc_id=vpc_1_id, dest_bucket_arn=fixed_bucket_1_arn, region_name=resource_region_name, ) flow_log_1_arn = FlowLogResourceSpec.generate_arn( resource_id=flow_log_1_id, account_id=account_id, region=resource_region_name) ebs_volume_1_size = 128 ebs_volume_1_az = f"{resource_region_name}a" ebs_volume_1_arn, ebs_volume_1_create_time = create_volume( size=ebs_volume_1_size, az=ebs_volume_1_az, region_name=resource_region_name) ## iam policy_1_name = "test_policy_1" policy_1_arn, policy_1_id = create_iam_policy( name=policy_1_name, policy_doc={ "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", "Action": "logs:CreateLogGroup", "Resource": "*" }, ], }, ) role_1_name = "test_role_1" role_1_assume_role_policy_doc = { "Version": "2012-10-17", "Statement": [{ "Action": "sts:AssumeRole", "Effect": "Allow", "Principal": { "Service": "lambda.amazonaws.com" }, "Sid": "", }], } role_1_description = "Test Role 1" role_1_max_session_duration = 3600 role_1_arn = create_iam_role( name=role_1_name, assume_role_policy_doc=role_1_assume_role_policy_doc, description=role_1_description, max_session_duration=role_1_max_session_duration, ) ## lambda lambda_function_1_name = "test_lambda_function_1" lambda_function_1_runtime = "python3.7" lambda_function_1_handler = "lambda_function.lambda_handler" lambda_function_1_description = "Test Lambda Function 1" lambda_function_1_timeout = 30 lambda_function_1_memory_size = 256 lambda_function_1_arn = create_lambda_function( name=lambda_function_1_name, runtime=lambda_function_1_runtime, role_name=role_1_arn, handler=lambda_function_1_handler, description=lambda_function_1_description, timeout=lambda_function_1_timeout, memory_size=lambda_function_1_memory_size, 
publish=False, region_name=resource_region_name, ) # scan test_scan_id = "test_scan_id" aws_config = AWSConfig( artifact_path=temp_dir, pruner_max_age_min=4320, graph_name="alti", concurrency=ConcurrencyConfig(max_account_scan_threads=1, max_svc_scan_threads=1, max_account_scan_tries=2), scan=ScanConfig( accounts=(), regions=(), scan_sub_accounts=False, preferred_account_scan_regions=( "us-west-1", "us-west-2", "us-east-1", "us-east-2", ), ), accessor=Accessor( credentials_cache=AWSCredentialsCache(cache={}), multi_hop_accessors=[], cache_creds=True, ), write_master_json=True, ) resource_spec_classes = ( # DynamoDbTableResourceSpec, TODO moto EBSVolumeResourceSpec, FlowLogResourceSpec, IAMPolicyResourceSpec, IAMRoleResourceSpec, LambdaFunctionResourceSpec, S3BucketResourceSpec, SubnetResourceSpec, VPCResourceSpec, ) muxer = LocalAWSScanMuxer( scan_id=test_scan_id, config=aws_config, resource_spec_classes=resource_spec_classes, ) with unittest.mock.patch( "altimeter.aws.scan.account_scanner.get_all_enabled_regions" ) as mock_get_all_enabled_regions: mock_get_all_enabled_regions.return_value = enabled_region_names aws2n_result = aws2n( scan_id=test_scan_id, config=aws_config, muxer=muxer, load_neptune=False, ) graph_set = GraphSet.from_json_file( Path(aws2n_result.json_path)) self.assertEqual(len(graph_set.errors), 0) self.assertEqual(graph_set.name, "alti") self.assertEqual(graph_set.version, "2") # now check each resource type self.maxDiff = None ## Accounts expected_account_resources = [ Resource( resource_id=f"arn:aws::::account/{account_id}", type="aws:account", link_collection=LinkCollection( simple_links=(SimpleLink(pred="account_id", obj=account_id), ), ), ) ] account_resources = [ resource for resource in graph_set.resources if resource.type == "aws:account" ] self.assertCountEqual(account_resources, expected_account_resources) ## Regions expected_region_resources = [ Resource( resource_id= f"arn:aws:::{account_id}:region/{region['RegionName']}", type="aws:region", link_collection=LinkCollection( simple_links=( SimpleLink(pred="name", obj=region["RegionName"]), SimpleLink(pred="opt_in_status", obj=region["OptInStatus"]), ), resource_links=(ResourceLink( pred="account", obj=f"arn:aws::::account/{account_id}"), ), ), ) for region in all_regions ] region_resources = [ resource for resource in graph_set.resources if resource.type == "aws:region" ] self.assertCountEqual(region_resources, expected_region_resources) ## IAM Policies expected_iam_policy_resources = [ Resource( resource_id=policy_1_arn, type="aws:iam:policy", link_collection=LinkCollection( simple_links=( SimpleLink(pred="name", obj=policy_1_name), SimpleLink(pred="policy_id", obj=policy_1_id), SimpleLink(pred="default_version_id", obj="v1"), SimpleLink( pred="default_version_policy_document_text", obj= '{"Statement": [{"Action": "logs:CreateLogGroup", "Effect": "Allow", "Resource": "*"}], "Version": "2012-10-17"}', ), ), resource_links=(ResourceLink( pred="account", obj=f"arn:aws::::account/{account_id}"), ), ), ) ] iam_policy_resources = [ resource for resource in graph_set.resources if resource.type == "aws:iam:policy" ] self.assertCountEqual(iam_policy_resources, expected_iam_policy_resources) ## IAM Roles expected_iam_role_resources = [ Resource( resource_id=role_1_arn, type="aws:iam:role", link_collection=LinkCollection( simple_links=( SimpleLink(pred="name", obj=role_1_name), SimpleLink(pred="max_session_duration", obj=role_1_max_session_duration), SimpleLink(pred="description", obj=role_1_description), SimpleLink( 
pred="assume_role_policy_document_text", obj=policy_doc_dict_to_sorted_str( role_1_assume_role_policy_doc), ), ), multi_links=(MultiLink( pred="assume_role_policy_document", obj=LinkCollection( simple_links=(SimpleLink( pred="version", obj="2012-10-17"), ), multi_links=(MultiLink( pred="statement", obj=LinkCollection( simple_links=( SimpleLink(pred="effect", obj="Allow"), SimpleLink( pred="action", obj="sts:AssumeRole"), ), multi_links=(MultiLink( pred="principal", obj=LinkCollection( simple_links=(SimpleLink( pred="service", obj= "lambda.amazonaws.com", ), )), ), ), ), ), ), ), ), ), resource_links=(ResourceLink( pred="account", obj="arn:aws::::account/123456789012"), ), ), ) ] iam_role_resources = [ resource for resource in graph_set.resources if resource.type == "aws:iam:role" ] self.assertCountEqual(iam_role_resources, expected_iam_role_resources) ## Lambda functions expected_lambda_function_resources = [ Resource( resource_id=lambda_function_1_arn, type="aws:lambda:function", link_collection=LinkCollection( simple_links=( SimpleLink(pred="function_name", obj=lambda_function_1_name), SimpleLink(pred="runtime", obj=lambda_function_1_runtime), ), resource_links=( ResourceLink( pred="account", obj=f"arn:aws::::account/{account_id}"), ResourceLink( pred="region", obj= f"arn:aws:::{account_id}:region/{resource_region_name}", ), ), transient_resource_links=(ResourceLink( pred="role", obj="arn:aws:iam::123456789012:role/test_role_1" ), ), ), ), ] lambda_function_resources = [ resource for resource in graph_set.resources if resource.type == "aws:lambda:function" ] self.assertCountEqual(lambda_function_resources, expected_lambda_function_resources) ## EC2 VPCs expected_ec2_vpc_resources = [ Resource( resource_id=vpc_1_arn, type="aws:ec2:vpc", link_collection=LinkCollection( simple_links=( SimpleLink(pred="is_default", obj=True), SimpleLink(pred="cidr_block", obj=vpc_1_cidr), SimpleLink(pred="state", obj="available"), ), resource_links=( ResourceLink( pred="account", obj=f"arn:aws::::account/{account_id}"), ResourceLink( pred="region", obj= f"arn:aws:::{account_id}:region/{resource_region_name}", ), ), ), ) ] ec2_vpc_resources = [ resource for resource in graph_set.resources if resource.type == "aws:ec2:vpc" ] self.assertCountEqual(ec2_vpc_resources, expected_ec2_vpc_resources) ## EC2 VPC Flow Logs expected_ec2_vpc_flow_log_resources = [ Resource( resource_id=flow_log_1_arn, type="aws:ec2:flow-log", link_collection=LinkCollection( simple_links=( SimpleLink( pred="creation_time", obj=flow_log_1_creation_time.replace( tzinfo=datetime.timezone.utc). 
isoformat(), ), SimpleLink(pred="deliver_logs_status", obj="SUCCESS"), SimpleLink(pred="flow_log_status", obj="ACTIVE"), SimpleLink(pred="traffic_type", obj="ALL"), SimpleLink(pred="log_destination_type", obj="s3"), SimpleLink(pred="log_destination", obj=fixed_bucket_1_arn), SimpleLink( pred="log_format", obj= "${version} ${account-id} ${interface-id} ${srcaddr} ${dstaddr} ${srcport} ${dstport} ${protocol} ${packets} ${bytes} ${start} ${end} ${action} ${log-status}", ), ), resource_links=( ResourceLink( pred="account", obj=f"arn:aws::::account/{account_id}"), ResourceLink( pred="region", obj= f"arn:aws:::{account_id}:region/{resource_region_name}", ), ), transient_resource_links=(TransientResourceLink( pred="vpc", obj=vpc_1_arn, ), ), ), ) ] ec2_vpc_flow_log_resources = [ resource for resource in graph_set.resources if resource.type == "aws:ec2:flow-log" ] self.assertCountEqual(ec2_vpc_flow_log_resources, expected_ec2_vpc_flow_log_resources) ## EC2 Subnets expected_ec2_subnet_resources = [ Resource( resource_id=subnet_1_arn, type="aws:ec2:subnet", link_collection=LinkCollection( simple_links=( SimpleLink(pred="cidr_block", obj=subnet_1_cidr), SimpleLink(pred="first_ip", obj=subnet_1_first_ip), SimpleLink(pred="last_ip", obj=subnet_1_last_ip), SimpleLink(pred="state", obj="available"), ), resource_links=( ResourceLink(pred="vpc", obj=vpc_1_arn), ResourceLink( pred="account", obj=f"arn:aws::::account/{account_id}"), ResourceLink( pred="region", obj= f"arn:aws:::{account_id}:region/{resource_region_name}", ), ), ), ) ] ec2_subnet_resources = [ resource for resource in graph_set.resources if resource.type == "aws:ec2:subnet" ] self.assertCountEqual(ec2_subnet_resources, expected_ec2_subnet_resources) ## EC2 EBS Volumes expected_ec2_ebs_volume_resources = [ Resource( resource_id=ebs_volume_1_arn, type="aws:ec2:volume", link_collection=LinkCollection( simple_links=( SimpleLink(pred="availability_zone", obj=ebs_volume_1_az), SimpleLink( pred="create_time", obj=ebs_volume_1_create_time.replace( tzinfo=datetime.timezone.utc). isoformat(), ), SimpleLink(pred="size", obj=ebs_volume_1_size), SimpleLink(pred="state", obj="available"), SimpleLink(pred="volume_type", obj="gp2"), SimpleLink(pred="encrypted", obj=False), ), resource_links=( ResourceLink( pred="account", obj=f"arn:aws::::account/{account_id}"), ResourceLink( pred="region", obj= f"arn:aws:::{account_id}:region/{resource_region_name}", ), ), ), ) ] ec2_ebs_volume_resources = [ resource for resource in graph_set.resources if resource.type == "aws:ec2:volume" ] self.assertCountEqual(ec2_ebs_volume_resources, expected_ec2_ebs_volume_resources) ## S3 Buckets expected_s3_bucket_resources = [ Resource( resource_id=bucket_1_arn, type="aws:s3:bucket", link_collection=LinkCollection( simple_links=( SimpleLink(pred="name", obj=bucket_1_name), SimpleLink( pred="creation_date", obj=bucket_1_creation_date.replace( tzinfo=datetime.timezone.utc). 
isoformat(), ), ), resource_links=( ResourceLink( pred="account", obj=f"arn:aws::::account/{account_id}"), ResourceLink( pred="region", obj= f"arn:aws:::{account_id}:region/{resource_region_name}", ), ), ), ) ] s3_bucket_resources = [ resource for resource in graph_set.resources if resource.type == "aws:s3:bucket" ] self.assertCountEqual(s3_bucket_resources, expected_s3_bucket_resources) expected_num_graph_set_resources = ( 0 + len(expected_account_resources) + len(expected_region_resources) + len(expected_iam_policy_resources) + len(expected_iam_role_resources) + len(expected_lambda_function_resources) + len(expected_ec2_ebs_volume_resources) + len(expected_ec2_subnet_resources) + len(expected_ec2_vpc_resources) + len(expected_ec2_vpc_flow_log_resources) + len(expected_s3_bucket_resources)) self.assertEqual(len(graph_set.resources), expected_num_graph_set_resources)
def scan(self) -> Dict[str, Any]: """Scan an account and return a dict containing keys: * account_id: str * output_artifact: str * api_call_stats: Dict[str, Any] * errors: List[str] If errors is non-empty the results are incomplete for this account. output_artifact is a pointer to the actual scan data - either on the local fs or in s3. To scan an account we create a set of GraphSpecs, one for each region. Any ACCOUNT level granularity resources are only scanned in a single region (e.g. IAM Users) Returns: Dict of scan result, see above for details. """ logger = Logger() with logger.bind(account_id=self.account_id): logger.info(event=AWSLogEvents.ScanAWSAccountStart) output_artifact = None stats = MultilevelCounter() errors: List[str] = [] now = int(time.time()) try: account_graph_set = GraphSet( name=self.graph_name, version=self.graph_version, start_time=now, end_time=now, resources=[], errors=[], stats=stats, ) # sanity check session = self.get_session() sts_client = session.client("sts") sts_account_id = sts_client.get_caller_identity()["Account"] if sts_account_id != self.account_id: raise ValueError( f"BUG: sts detected account_id {sts_account_id} != {self.account_id}" ) if self.regions: scan_regions = tuple(self.regions) else: scan_regions = get_all_enabled_regions(session=session) # build graph specs. # build a dict of regions -> services -> List[AWSResourceSpec] regions_services_resource_spec_classes: DefaultDict[ str, DefaultDict[str, List[Type[AWSResourceSpec]]]] = defaultdict( lambda: defaultdict(list)) resource_spec_class: Type[AWSResourceSpec] for resource_spec_class in self.resource_spec_classes: client_name = resource_spec_class.get_client_name() resource_class_scan_granularity = resource_spec_class.scan_granularity if resource_class_scan_granularity == ScanGranularity.ACCOUNT: regions_services_resource_spec_classes[scan_regions[ 0]][client_name].append(resource_spec_class) elif resource_class_scan_granularity == ScanGranularity.REGION: for region in scan_regions: regions_services_resource_spec_classes[region][ client_name].append(resource_spec_class) else: raise NotImplementedError( f"ScanGranularity {resource_class_scan_granularity} not implemented" ) with ThreadPoolExecutor( max_workers=self.max_svc_threads) as executor: futures = [] for ( region, services_resource_spec_classes, ) in regions_services_resource_spec_classes.items(): for ( service, resource_spec_classes, ) in services_resource_spec_classes.items(): region_session = self.get_session(region=region) region_creds = region_session.get_credentials() scan_future = schedule_scan_services( executor=executor, graph_name=self.graph_name, graph_version=self.graph_version, account_id=self.account_id, region=region, service=service, access_key=region_creds.access_key, secret_key=region_creds.secret_key, token=region_creds.token, resource_spec_classes=tuple( resource_spec_classes), ) futures.append(scan_future) for future in as_completed(futures): graph_set_dict = future.result() graph_set = GraphSet.from_dict(graph_set_dict) errors += graph_set.errors account_graph_set.merge(graph_set) account_graph_set.validate() except Exception as ex: error_str = str(ex) trace_back = traceback.format_exc() logger.error(event=AWSLogEvents.ScanAWSAccountError, error=error_str, trace_back=trace_back) errors.append(" : ".join((error_str, trace_back))) unscanned_account_resource = UnscannedAccountResourceSpec.create_resource( account_id=self.account_id, errors=errors) account_graph_set = GraphSet( name=self.graph_name, 
version=self.graph_version, start_time=now, end_time=now, resources=[unscanned_account_resource], errors=errors, stats=stats, ) account_graph_set.validate() output_artifact = self.artifact_writer.write_artifact( name=self.account_id, data=account_graph_set.to_dict()) logger.info(event=AWSLogEvents.ScanAWSAccountEnd) api_call_stats = account_graph_set.stats.to_dict() return { "account_id": self.account_id, "output_artifact": output_artifact, "errors": errors, "api_call_stats": api_call_stats, }
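# Minimal sketch of consuming the scan() result dict described in the docstring above
# (assumption: `scanner` is an instance of the account scanner class this method belongs to,
# and its artifact writer was configured to write local json files).
def example_consume_scan_result(scanner) -> GraphSet:
    result = scanner.scan()
    if result["errors"]:
        # results are incomplete for this account; errors are plain strings
        for error in result["errors"]:
            print(error)
    # output_artifact points at the written scan data (local path or s3 uri); here we
    # assume a local json file path and rebuild the GraphSet from it.
    return graph_set_from_filepath(result["output_artifact"])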
def test_valid_merge(self):
    resource_a1 = Resource(
        resource_id="123",
        type="test:a",
        link_collection=LinkCollection(simple_links=[SimpleLink(pred="has-foo", obj="goo")]),
    )
    resource_a2 = Resource(resource_id="456", type="test:a", link_collection=LinkCollection())
    resource_b1 = Resource(
        resource_id="abc",
        type="test:b",
        link_collection=LinkCollection(simple_links=[SimpleLink(pred="has-a", obj="123")]),
    )
    resource_b2 = Resource(
        resource_id="def",
        type="test:b",
        link_collection=LinkCollection(simple_links=[SimpleLink(pred="name", obj="sue")]),
    )
    graph_set_1 = GraphSet(
        name="graph-1",
        version="1",
        start_time=10,
        end_time=20,
        resources=[resource_a1, resource_a2],
        errors=["errora1", "errora2"],
    )
    graph_set_2 = GraphSet(
        name="graph-1",
        version="1",
        start_time=15,
        end_time=25,
        resources=[resource_b1, resource_b2],
        errors=["errorb1", "errorb2"],
    )
    merged_graph_set = ValidatedGraphSet.from_graph_sets([graph_set_1, graph_set_2])
    self.assertEqual(merged_graph_set.name, "graph-1")
    self.assertEqual(merged_graph_set.version, "1")
    self.assertEqual(merged_graph_set.start_time, 10)
    self.assertEqual(merged_graph_set.end_time, 25)
    self.assertCountEqual(merged_graph_set.errors, ["errora1", "errora2", "errorb1", "errorb2"])
    expected_resources = (
        Resource(
            resource_id="123",
            type="test:a",
            link_collection=LinkCollection(simple_links=(SimpleLink(pred="has-foo", obj="goo"),)),
        ),
        Resource(resource_id="456", type="test:a", link_collection=LinkCollection()),
        Resource(
            resource_id="abc",
            type="test:b",
            link_collection=LinkCollection(simple_links=(SimpleLink(pred="has-a", obj="123"),)),
        ),
        Resource(
            resource_id="def",
            type="test:b",
            link_collection=LinkCollection(simple_links=(SimpleLink(pred="name", obj="sue"),)),
        ),
    )
    expected_errors = ["errora1", "errora2", "errorb1", "errorb2"]
    self.assertCountEqual(merged_graph_set.resources, expected_resources)
    self.assertCountEqual(merged_graph_set.errors, expected_errors)
def scan(self) -> List[Dict[str, Any]]: logger = Logger() scan_result_dicts = [] now = int(time.time()) prescan_account_ids_errors: DefaultDict[str, List[str]] = defaultdict(list) futures = [] with ThreadPoolExecutor(max_workers=self.max_threads) as executor: shuffled_account_ids = random.sample( self.account_scan_plan.account_ids, k=len(self.account_scan_plan.account_ids)) for account_id in shuffled_account_ids: with logger.bind(account_id=account_id): logger.info(event=AWSLogEvents.ScanAWSAccountStart) try: session = self.account_scan_plan.accessor.get_session( account_id=account_id) # sanity check sts_client = session.client("sts") sts_account_id = sts_client.get_caller_identity( )["Account"] if sts_account_id != account_id: raise ValueError( f"BUG: sts detected account_id {sts_account_id} != {account_id}" ) if self.account_scan_plan.regions: scan_regions = tuple( self.account_scan_plan.regions) else: scan_regions = get_all_enabled_regions( session=session) account_gran_scan_region = random.choice( self.preferred_account_scan_regions) # build a dict of regions -> services -> List[AWSResourceSpec] regions_services_resource_spec_classes: DefaultDict[ str, DefaultDict[ str, List[Type[AWSResourceSpec]]]] = defaultdict( lambda: defaultdict(list)) resource_spec_class: Type[AWSResourceSpec] for resource_spec_class in self.resource_spec_classes: client_name = resource_spec_class.get_client_name() if resource_spec_class.scan_granularity == ScanGranularity.ACCOUNT: if resource_spec_class.region_whitelist: account_resource_scan_region = resource_spec_class.region_whitelist[ 0] else: account_resource_scan_region = account_gran_scan_region regions_services_resource_spec_classes[ account_resource_scan_region][ client_name].append( resource_spec_class) elif resource_spec_class.scan_granularity == ScanGranularity.REGION: if resource_spec_class.region_whitelist: resource_scan_regions = tuple( region for region in scan_regions if region in resource_spec_class.region_whitelist) if not resource_scan_regions: resource_scan_regions = resource_spec_class.region_whitelist else: resource_scan_regions = scan_regions for region in resource_scan_regions: regions_services_resource_spec_classes[ region][client_name].append( resource_spec_class) else: raise NotImplementedError( f"ScanGranularity {resource_spec_class.scan_granularity} unimplemented" ) # Build and submit ScanUnits shuffed_regions_services_resource_spec_classes = random.sample( regions_services_resource_spec_classes.items(), len(regions_services_resource_spec_classes), ) for ( region, services_resource_spec_classes, ) in shuffed_regions_services_resource_spec_classes: region_session = self.account_scan_plan.accessor.get_session( account_id=account_id, region_name=region) region_creds = region_session.get_credentials() shuffled_services_resource_spec_classes = random.sample( services_resource_spec_classes.items(), len(services_resource_spec_classes), ) for ( service, svc_resource_spec_classes, ) in shuffled_services_resource_spec_classes: future = schedule_scan( executor=executor, graph_name=self.graph_name, graph_version=self.graph_version, account_id=account_id, region_name=region, service=service, access_key=region_creds.access_key, secret_key=region_creds.secret_key, token=region_creds.token, resource_spec_classes=tuple( svc_resource_spec_classes), ) futures.append(future) except Exception as ex: error_str = str(ex) trace_back = traceback.format_exc() logger.error( event=AWSLogEvents.ScanAWSAccountError, error=error_str, trace_back=trace_back, ) 
prescan_account_ids_errors[account_id].append( f"{error_str}\n{trace_back}") account_ids_graph_set_dicts: Dict[str, List[Dict[str, Any]]] = defaultdict(list) for future in as_completed(futures): account_id, graph_set_dict = future.result() account_ids_graph_set_dicts[account_id].append(graph_set_dict) # first make sure no account id appears both in account_ids_graph_set_dicts # and prescan_account_ids_errors - this should never happen doubled_accounts = set( account_ids_graph_set_dicts.keys()).intersection( set(prescan_account_ids_errors.keys())) if doubled_accounts: raise Exception(( f"BUG: Account(s) {doubled_accounts} in both " "account_ids_graph_set_dicts and prescan_account_ids_errors.")) # graph prescan error accounts for account_id, errors in prescan_account_ids_errors.items(): with logger.bind(account_id=account_id): unscanned_account_resource = UnscannedAccountResourceSpec.create_resource( account_id=account_id, errors=errors) account_graph_set = GraphSet( name=self.graph_name, version=self.graph_version, start_time=now, end_time=now, resources=[unscanned_account_resource], errors=errors, stats=MultilevelCounter(), ) account_graph_set.validate() output_artifact = self.artifact_writer.write_json( name=account_id, data=account_graph_set.to_dict()) logger.info(event=AWSLogEvents.ScanAWSAccountEnd) api_call_stats = account_graph_set.stats.to_dict() scan_result_dicts.append({ "account_id": account_id, "output_artifact": output_artifact, "errors": errors, "api_call_stats": api_call_stats, }) # graph rest for account_id, graph_set_dicts in account_ids_graph_set_dicts.items(): with logger.bind(account_id=account_id): # if there are any errors whatsoever we generate an empty graph with # errors only errors = [] for graph_set_dict in graph_set_dicts: errors += graph_set_dict["errors"] if errors: unscanned_account_resource = UnscannedAccountResourceSpec.create_resource( account_id=account_id, errors=errors) account_graph_set = GraphSet( name=self.graph_name, version=self.graph_version, start_time=now, end_time=now, resources=[unscanned_account_resource], errors=errors, stats=MultilevelCounter( ), # ENHANCHMENT: could technically get partial stats. ) account_graph_set.validate() else: account_graph_set = GraphSet( name=self.graph_name, version=self.graph_version, start_time=now, end_time=now, resources=[], errors=[], stats=MultilevelCounter(), ) for graph_set_dict in graph_set_dicts: graph_set = GraphSet.from_dict(graph_set_dict) account_graph_set.merge(graph_set) output_artifact = self.artifact_writer.write_json( name=account_id, data=account_graph_set.to_dict()) logger.info(event=AWSLogEvents.ScanAWSAccountEnd) api_call_stats = account_graph_set.stats.to_dict() scan_result_dicts.append({ "account_id": account_id, "output_artifact": output_artifact, "errors": errors, "api_call_stats": api_call_stats, }) return scan_result_dicts