Beispiel #1
0
 def test_invalid_diff_versions(self):
     """Combining graph sets with different versions raises UnmergableGraphSetsException."""
     older = GraphSet(
         name="graph-1",
         version="1",
         start_time=10,
         end_time=20,
         resources=[],
         errors=[],
     )
     newer = GraphSet(
         name="graph-1",
         version="2",
         start_time=15,
         end_time=25,
         resources=[],
         errors=[],
     )
     # Same name, overlapping time windows; only the version differs.
     mismatched_versions = [older, newer]
     with self.assertRaises(UnmergableGraphSetsException):
         GraphSet.from_graph_sets(mismatched_versions)
Beispiel #2
0
def scan_scan_unit(scan_unit: ScanUnit) -> Tuple[str, Dict[str, Any]]:
    """Run the scan described by a ScanUnit and return its results as a dict.

    A boto3 session is created from the credentials on ``scan_unit`` and used to
    scan the unit's resource spec classes. An exception raised by the scan is
    logged and stored (message plus traceback) in the resulting GraphSet's errors
    instead of being propagated.

    Args:
        scan_unit: the account/region/service, credentials and resource spec
            classes to scan.

    Returns:
        Tuple of the scanned account id and the GraphSet's ``to_dict()`` output.
    """
    logger = Logger()
    resource_class_names = sorted(
        spec_class.__name__ for spec_class in scan_unit.resource_spec_classes
    )
    with logger.bind(
            account_id=scan_unit.account_id,
            region=scan_unit.region_name,
            service=scan_unit.service,
            resource_classes=resource_class_names,
    ):
        began_at = time.time()
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceStart)
        accessor = AWSAccessor(
            session=boto3.Session(
                aws_access_key_id=scan_unit.access_key,
                aws_secret_access_key=scan_unit.secret_key,
                aws_session_token=scan_unit.token,
                region_name=scan_unit.region_name,
            ),
            account_id=scan_unit.account_id,
            region_name=scan_unit.region_name,
        )
        spec = GraphSpec(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            resource_spec_classes=scan_unit.resource_spec_classes,
            scan_accessor=accessor,
        )
        scan_started = int(time.time())
        scanned: List[Resource] = []
        failures = []
        try:
            scanned = spec.scan()
        except Exception as ex:
            # A failed scan still yields a GraphSet; the failure travels in its errors.
            message = str(ex)
            stack = traceback.format_exc()
            logger.error(event=AWSLogEvents.ScanAWSAccountError,
                         error=message,
                         trace_back=stack)
            failures.append(f"{message}\n{stack}")
        scan_ended = int(time.time())
        result = GraphSet(
            name=scan_unit.graph_name,
            version=scan_unit.graph_version,
            start_time=scan_started,
            end_time=scan_ended,
            resources=scanned,
            errors=failures,
            stats=accessor.api_call_stats,
        )
        logger.info(event=AWSLogEvents.ScanAWSAccountServiceEnd,
                    elapsed_sec=time.time() - began_at)
        return scan_unit.account_id, result.to_dict()
Beispiel #3
0
    def scan(self) -> GraphSet:
        """Scan each resource spec class of this GraphSpec and aggregate the results.

        Returns:
            GraphSet combining the resources, errors and stats produced by scanning
            every class in this GraphSpec's resource_spec_classes.
        """
        all_resources: List[Resource] = []
        all_errors: List[str] = []
        combined_stats = MultilevelCounter()
        began = int(time.time())
        logger = Logger()
        for spec_class in self.resource_spec_classes:
            with logger.bind(resource_type=str(spec_class.type_name)):
                logger.debug(event=LogEvent.ScanResourceTypeStart)
                result = spec_class.scan(scan_accessor=self.scan_accessor)
                all_resources.extend(result.resources)
                all_errors.extend(result.errors)
                combined_stats.merge(result.stats)
                logger.debug(event=LogEvent.ScanResourceTypeEnd)
        return GraphSet(
            name=self.name,
            version=self.version,
            start_time=began,
            end_time=int(time.time()),
            resources=all_resources,
            errors=all_errors,
            stats=combined_stats,
        )
Beispiel #4
0
    def write_graph_set(self,
                        name: str,
                        graph_set: GraphSet,
                        compression: Optional[str] = None) -> str:
        """Write a graph artifact as RDF under ``self.output_dir``.

        Args:
            name: base name of the artifact; the file extension is appended.
            graph_set: GraphSet object to write
            compression: None to write plain RDF ("<name>.rdf") or GZIP to write
                gzip-compressed RDF ("<name>.rdf.gz").

        Returns:
            path to written artifact

        Raises:
            ValueError: if compression is neither None nor GZIP.
        """
        logger = Logger()
        os.makedirs(self.output_dir, exist_ok=True)
        # Validate compression exactly once; the write below only consults use_gzip.
        if compression is None:
            use_gzip = False
        elif compression == GZIP:
            use_gzip = True
        else:
            raise ValueError(f"Unknown compression arg {compression}")
        filename = f"{name}.rdf.gz" if use_gzip else f"{name}.rdf"
        artifact_path = os.path.join(self.output_dir, filename)
        graph = graph_set.to_rdf()
        with logger.bind(artifact_path=artifact_path):
            logger.info(event=LogEvent.WriteToFSStart)
            with open(artifact_path, "wb") as fp:
                if use_gzip:
                    with gzip.GzipFile(fileobj=fp, mode="wb") as gz:
                        graph.serialize(gz)
                else:
                    graph.serialize(fp)
            logger.info(event=LogEvent.WriteToFSEnd)
        return artifact_path
Beispiel #5
0
class TestGraphSetWithValidDataMerging(TestCase):
    """GraphSet behaviour when two resources share the same id and type."""

    def setUp(self):
        """Build a GraphSet in which two "test:a" resources both use id "123"."""
        first_a = Resource(
            resource_id="123",
            type_name="test:a",
            links=[SimpleLink(pred="has-foo", obj="goo")],
        )
        second_a = Resource(
            resource_id="123",
            type_name="test:a",
            links=[SimpleLink(pred="has-goo", obj="foo")],
        )
        linked_b = Resource(
            resource_id="abc",
            type_name="test:b",
            links=[ResourceLinkLink(pred="has-a", obj="123")],
        )
        named_b = Resource(
            resource_id="def",
            type_name="test:b",
            links=[SimpleLink(pred="name", obj="sue")],
        )
        self.graph_set = GraphSet(
            name="test-name",
            version="1",
            start_time=1234,
            end_time=4567,
            resources=[first_a, second_a, linked_b, named_b],
            errors=["test err 1", "test err 2"],
            stats=MultilevelCounter(),
        )

    def test_rdf_a_type(self):
        """The "test:a" query yields one type/id pair carrying both simple links."""
        rdf_graph = self.graph_set.to_rdf()
        rows = list(
            rdf_graph.query(
                "select ?p ?o where {?s a <test-name:test:a> ; ?p ?o} order by ?p ?o"
            ))
        # Every row must be a (predicate, object) pair before it is stringified.
        for row in rows:
            self.assertEqual(2, len(row))
        actual_pairs = [(str(row[0]), str(row[1])) for row in rows]
        expected_pairs = [
            ("http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
             "test-name:test:a"),
            ("test-name:has-foo", "goo"),
            ("test-name:has-goo", "foo"),
            ("test-name:id", "123"),
        ]
        self.assertEqual(expected_pairs, actual_pairs)

    def test_validate(self):
        """validate() completes without raising for this data."""
        self.graph_set.validate()
Beispiel #6
0
 def test_orphaned_ref(self):
     """validate() raises when a resource link targets an id absent from the set."""
     target = Resource(
         resource_id="123",
         type_name="test:a",
         links=[SimpleLink(pred="has-foo", obj="goo")],
     )
     # Links to "456", which is not the id of any resource in this graph set.
     dangling = Resource(
         resource_id="abc",
         type_name="test:b",
         links=[ResourceLinkLink(pred="has-a", obj="456")],
     )
     graph_set = GraphSet(
         name="test-name",
         version="1",
         start_time=1234,
         end_time=4567,
         resources=[target, dangling],
         errors=["test err 1", "test err 2"],
         stats=MultilevelCounter(),
     )
     with self.assertRaises(GraphSetOrphanedReferencesException):
         graph_set.validate()
Beispiel #7
0
 def test_invalid_diff_versions(self):
     """merge() raises UnmergableGraphSetsException when the versions differ."""
     base = GraphSet(
         name="graph-1",
         version="1",
         start_time=10,
         end_time=20,
         resources=[],
         errors=[],
         stats=MultilevelCounter(),
     )
     # Identical name, but version "2" instead of "1".
     other_version = GraphSet(
         name="graph-1",
         version="2",
         start_time=15,
         end_time=25,
         resources=[],
         errors=[],
         stats=MultilevelCounter(),
     )
     with self.assertRaises(UnmergableGraphSetsException):
         base.merge(other_version)
Beispiel #8
0
 def setUp(self):
     """Build a GraphSet of two "test:a" and two "test:b" resources with distinct ids."""
     linked_a = Resource(
         resource_id="123",
         type_name="test:a",
         links=[SimpleLink(pred="has-foo", obj="goo")],
     )
     bare_a = Resource(resource_id="456", type_name="test:a")
     # References resource "123" through a resource link.
     referencing_b = Resource(
         resource_id="abc",
         type_name="test:b",
         links=[ResourceLinkLink(pred="has-a", obj="123")],
     )
     named_b = Resource(
         resource_id="def",
         type_name="test:b",
         links=[SimpleLink(pred="name", obj="sue")],
     )
     self.graph_set = GraphSet(
         name="test-name",
         version="1",
         start_time=1234,
         end_time=4567,
         resources=[linked_a, bare_a, referencing_b, named_b],
         errors=["test err 1", "test err 2"],
         stats=MultilevelCounter(),
     )
Beispiel #9
0
def graph_set_from_s3(s3_client: BaseClient, json_bucket: str,
                      json_key: str) -> GraphSet:
    """Download a json object from S3 and build a GraphSet from its contents.

    Args:
        s3_client: client used to download the object.
        json_bucket: bucket holding the json object.
        json_key: key of the json object.

    Returns:
        GraphSet created via ``GraphSet.from_dict`` from the decoded json.
    """
    logger = Logger()
    logger.info(event=LogEvent.ReadFromS3Start)
    with io.BytesIO() as download_buf:
        s3_client.download_fileobj(json_bucket, json_key, download_buf)
        # getvalue() returns the whole buffer regardless of the current position.
        raw_json = download_buf.getvalue()
        logger.info(event=LogEvent.ReadFromS3End)
    return GraphSet.from_dict(json.loads(raw_json.decode("utf-8")))
Beispiel #10
0
def main(argv: Optional[List[str]] = None) -> int:
    """Load a GraphSet json file and write its data as tables into "altimeter.hyper".

    Args:
        argv: command line arguments; ``sys.argv[1:]`` is used when None.

    Returns:
        0 once all tables have been created and populated.

    Raises:
        NotImplementedError: if more than one input file is given.
    """
    cli_args = sys.argv[1:] if argv is None else argv
    parser = argparse.ArgumentParser()
    parser.add_argument("input_json_filepaths", type=Path, nargs="+")
    json_paths = parser.parse_args(cli_args).input_json_filepaths
    if len(json_paths) > 1:
        raise NotImplementedError("Only one input supported at this time")

    # map each scan id to its GraphSet; together these hold all of the provided input data
    graph_sets_by_scan_id: Dict[int, GraphSet] = {}
    for scan_id, json_path in enumerate(json_paths):
        graph_sets_by_scan_id[scan_id] = GraphSet.from_json_file(json_path)

    # table definitions are discovered from the resources: the maximum set of predicates
    # used for each type
    table_defns = build_table_defns(graph_sets_by_scan_id.values())

    # rows to insert, keyed by table name
    rows_by_table = build_data(graph_sets_by_scan_id.values(), table_defns)

    tables_by_name: Dict[str, tableauhyperapi.TableDefinition] = {}
    no_telemetry = tableauhyperapi.Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU
    with tableauhyperapi.HyperProcess(telemetry=no_telemetry) as hyper, \
            tableauhyperapi.Connection(
                endpoint=hyper.endpoint,
                database="altimeter.hyper",
                create_mode=tableauhyperapi.CreateMode.CREATE_AND_REPLACE,
            ) as connection:
        # create every table before inserting any rows
        for table_name, columns in table_defns.items():
            hyper_columns = [column.to_hyper() for column in columns]
            table = tableauhyperapi.TableDefinition(table_name=table_name,
                                                    columns=hyper_columns)
            connection.catalog.create_table(table)
            tables_by_name[table_name] = table

        for table_name, rows in rows_by_table.items():
            with tableauhyperapi.Inserter(
                    connection, tables_by_name[table_name]) as inserter:
                inserter.add_rows(rows)
                inserter.execute()

    return 0
Beispiel #11
0
 def test_unknown_type_name(self):
     """Building a GraphSet from same-id "test:a" and "test:c" resources raises
     ResourceSpecClassNotFoundException."""
     same_id_resources = [
         Resource(resource_id="xyz", type_name=type_name)
         for type_name in ("test:a", "test:c")
     ]
     with self.assertRaises(ResourceSpecClassNotFoundException):
         GraphSet(
             name="test-name",
             version="1",
             start_time=1234,
             end_time=4567,
             resources=same_id_resources,
             errors=[],
             stats=MultilevelCounter(),
         )
Beispiel #12
0
 def test_invalid_resources_dupes_same_class_conflicting_types_no_allow_clobber(
         self):
     """Building a GraphSet from same-id "test:a" and "test:b" resources raises
     UnmergableDuplicateResourceIdsFoundException."""
     conflicting_resources = [
         Resource(resource_id="123", type_name=type_name)
         for type_name in ("test:a", "test:b")
     ]
     with self.assertRaises(UnmergableDuplicateResourceIdsFoundException):
         GraphSet(
             name="test-name",
             version="1",
             start_time=1234,
             end_time=4567,
             resources=conflicting_resources,
             errors=[],
             stats=MultilevelCounter(),
         )
Beispiel #13
0
 def test_from_dict(self):
     """A dict loaded with GraphSet.from_dict serializes back to an equal dict."""
     resources_by_id = {
         "123": {
             "type": "test:a",
             "links": [{"pred": "has-foo", "obj": "goo", "type": "simple"}],
         },
         "456": {"type": "test:a"},
         "abc": {
             "type": "test:b",
             "links": [{"pred": "has-a", "obj": "123", "type": "resource_link"}],
         },
         "def": {
             "type": "test:b",
             "links": [{"pred": "name", "obj": "sue", "type": "simple"}],
         },
     }
     input_dict = {
         "name": "test-name",
         "version": "1",
         "start_time": 1234,
         "end_time": 4567,
         "resources": resources_by_id,
         "errors": ["test err 1", "test err 2"],
         "stats": {"count": 0},
     }
     round_tripped = GraphSet.from_dict(input_dict).to_dict()
     self.assertEqual(round_tripped, input_dict)
Beispiel #14
0
 def test_orphaned_ref(self):
     """ValidatedGraphSet.from_graph_set raises when a resource link targets an id
     absent from the graph set."""
     simple_links = LinkCollection(simple_links=[SimpleLink(pred="has-foo", obj="goo")])
     # Points at "456", which is not the id of any resource below.
     dangling_links = LinkCollection(resource_links=[ResourceLink(pred="has-a", obj="456")])
     target = Resource(resource_id="123", type="test:a", link_collection=simple_links)
     dangling = Resource(resource_id="abc", type="test:b", link_collection=dangling_links)
     graph_set = GraphSet(
         name="test-name",
         version="1",
         start_time=1234,
         end_time=4567,
         resources=[target, dangling],
         errors=["test err 1", "test err 2"],
     )
     with self.assertRaises(GraphSetOrphanedReferencesException):
         ValidatedGraphSet.from_graph_set(graph_set)
Beispiel #15
0
def graph_set_from_filepath(filepath: str) -> GraphSet:
    """Load a GraphSet from a json file on disk.

    Args:
        filepath: path of the json file to read.

    Returns:
        GraphSet created via ``GraphSet.from_dict`` from the file's contents.
    """
    # Read as UTF-8 explicitly so the result does not depend on the platform's
    # default text encoding.
    with open(filepath, "r", encoding="utf-8") as fp:
        graph_set_dict = json.load(fp)
    return GraphSet.from_dict(data=graph_set_dict)
Beispiel #16
0
    def test_valid_merge(self):
        """Merging graph sets with equal name and version combines their contents."""
        linked_a = Resource(
            resource_id="123",
            type_name="test:a",
            links=[SimpleLink(pred="has-foo", obj="goo")],
        )
        bare_a = Resource(resource_id="456", type_name="test:a")
        referencing_b = Resource(
            resource_id="abc",
            type_name="test:b",
            links=[ResourceLinkLink(pred="has-a", obj="123")],
        )
        named_b = Resource(
            resource_id="def",
            type_name="test:b",
            links=[SimpleLink(pred="name", obj="sue")],
        )
        merged = GraphSet(
            name="graph-1",
            version="1",
            start_time=10,
            end_time=20,
            resources=[linked_a, bare_a],
            errors=["errora1", "errora2"],
            stats=MultilevelCounter(),
        )
        incoming = GraphSet(
            name="graph-1",
            version="1",
            start_time=15,
            end_time=25,
            resources=[referencing_b, named_b],
            errors=["errorb1", "errorb2"],
            stats=MultilevelCounter(),
        )
        merged.merge(incoming)

        self.assertEqual(merged.name, "graph-1")
        self.assertEqual(merged.version, "1")
        # The merged window runs from the earliest start to the latest end.
        self.assertEqual(merged.start_time, 10)
        self.assertEqual(merged.end_time, 25)
        self.assertCountEqual(merged.errors,
                              ["errora1", "errora2", "errorb1", "errorb2"])
        expected_resource_dicts = [
            {
                "type": "test:a",
                "links": [{"pred": "has-foo", "obj": "goo", "type": "simple"}],
            },
            {"type": "test:a"},
            {
                "type": "test:b",
                "links": [{"pred": "has-a", "obj": "123", "type": "resource_link"}],
            },
            {
                "type": "test:b",
                "links": [{"pred": "name", "obj": "sue", "type": "simple"}],
            },
        ]
        actual_resource_dicts = []
        for merged_resource in merged.resources:
            actual_resource_dicts.append(merged_resource.to_dict())
        self.assertCountEqual(expected_resource_dicts, actual_resource_dicts)
Beispiel #17
0
class TestGraphSetWithValidDataNoMerging(TestCase):
    """GraphSet behaviour when every resource has a distinct id."""

    def setUp(self):
        """Build a GraphSet of two "test:a" and two "test:b" resources."""
        resource_a1 = Resource(resource_id="123",
                               type_name="test:a",
                               links=[SimpleLink(pred="has-foo", obj="goo")])
        resource_a2 = Resource(resource_id="456", type_name="test:a")
        resource_b1 = Resource(
            resource_id="abc",
            type_name="test:b",
            links=[ResourceLinkLink(pred="has-a", obj="123")])
        resource_b2 = Resource(resource_id="def",
                               type_name="test:b",
                               links=[SimpleLink(pred="name", obj="sue")])
        resources = [resource_a1, resource_a2, resource_b1, resource_b2]
        self.graph_set = GraphSet(
            name="test-name",
            version="1",
            start_time=1234,
            end_time=4567,
            resources=resources,
            errors=["test err 1", "test err 2"],
            stats=MultilevelCounter(),
        )

    @staticmethod
    def _graph_set_dict():
        """Return a fresh dict form of the GraphSet built in ``setUp``."""
        return {
            "name": "test-name",
            "version": "1",
            "start_time": 1234,
            "end_time": 4567,
            "resources": {
                "123": {
                    "type": "test:a",
                    "links": [{
                        "pred": "has-foo",
                        "obj": "goo",
                        "type": "simple"
                    }],
                },
                "456": {
                    "type": "test:a"
                },
                "abc": {
                    "type":
                    "test:b",
                    "links": [{
                        "pred": "has-a",
                        "obj": "123",
                        "type": "resource_link"
                    }],
                },
                "def": {
                    "type": "test:b",
                    "links": [{
                        "pred": "name",
                        "obj": "sue",
                        "type": "simple"
                    }],
                },
            },
            "errors": ["test err 1", "test err 2"],
            "stats": {
                "count": 0
            },
        }

    def test_rdf_a_type(self):
        """Both "test:a" resources show up in the RDF with their own type and id."""
        graph = self.graph_set.to_rdf()

        a_results = graph.query(
            "select ?p ?o where {?s a <test-name:test:a> ; ?p ?o} order by ?p ?o"
        )
        expected_a_result_tuples = [
            ("http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
             "test-name:test:a"),
            ("http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
             "test-name:test:a"),
            ("test-name:has-foo", "goo"),
            ("test-name:id", "123"),
            ("test-name:id", "456"),
        ]
        a_result_tuples = []
        for a_result in a_results:
            self.assertEqual(2, len(a_result))
            a_result_tuples.append((str(a_result[0]), str(a_result[1])))
        self.assertEqual(expected_a_result_tuples, a_result_tuples)

    def test_rdf_b_type(self):
        """The "test:b" resource link resolves to the node of resource "123"."""
        graph = self.graph_set.to_rdf()
        # Serialize in memory only: writing to a fixed /tmp path made the test
        # non-portable and left a stray file behind.
        graph.serialize()
        linked_a_node_results = graph.query(
            "select ?s where {?s a <test-name:test:a>; <test-name:id> '123' }")
        self.assertEqual(len(linked_a_node_results), 1)
        # Exactly one row (asserted above); pick up its subject node.
        for linked_a_node_result in linked_a_node_results:
            linked_a_node = str(linked_a_node_result[0])
        b_results = graph.query(
            "select ?p ?o where {?s a <test-name:test:b> ; ?p ?o} order by ?p ?o"
        )
        expected_b_result_tuples = [
            ("http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
             "test-name:test:b"),
            ("http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
             "test-name:test:b"),
            ("test-name:has-a", str(linked_a_node)),
            ("test-name:id", "abc"),
            ("test-name:id", "def"),
            ("test-name:name", "sue"),
        ]
        b_result_tuples = []
        for b_result in b_results:
            self.assertEqual(2, len(b_result))
            b_result_tuples.append((str(b_result[0]), str(b_result[1])))
        self.assertEqual(expected_b_result_tuples, b_result_tuples)

    def test_rdf_error_graphing(self):
        """The GraphSet's errors are present in the RDF under the error predicate."""
        graph = self.graph_set.to_rdf()

        err_results = graph.query(
            "select ?o where { ?s <test-name:error> ?o } order by ?o")
        err_strs = []
        expected_err_strs = ["test err 1", "test err 2"]
        for err_result in err_results:
            self.assertEqual(1, len(err_result))
            err_strs.append(str(err_result[0]))
        self.assertEqual(err_strs, expected_err_strs)

    def test_to_dict(self):
        """to_dict() produces the expected dict form."""
        self.assertDictEqual(self._graph_set_dict(), self.graph_set.to_dict())

    def test_from_dict(self):
        """A dict loaded with GraphSet.from_dict serializes back to an equal dict."""
        input_dict = self._graph_set_dict()
        graph_set = GraphSet.from_dict(input_dict)
        self.assertEqual(graph_set.to_dict(), input_dict)

    def test_validate(self):
        """validate() completes without raising for this data."""
        self.graph_set.validate()
Beispiel #18
0
    def write_graph_set(self,
                        name: str,
                        graph_set: GraphSet,
                        compression: Optional[str] = None) -> str:
        """Write a graph artifact to S3 as RDF and tag the object with graph metadata.

        The object is uploaded to ``self.bucket`` under ``self.key_prefix`` and then
        tagged with the graph set's name, version, start_time and end_time.

        Args:
            name: base name of the artifact; the extension is appended to form the key.
            graph_set: GraphSet to write
            compression: None to write plain RDF ("<name>.rdf") or GZIP to write
                gzip-compressed RDF ("<name>.rdf.gz").

        Returns:
            s3:// uri of the written artifact

        Raises:
            ValueError: if compression is neither None nor GZIP.
        """
        logger = Logger()
        # Validate compression exactly once; the upload below only consults use_gzip.
        if compression is None:
            use_gzip = False
        elif compression == GZIP:
            use_gzip = True
        else:
            raise ValueError(f"Unknown compression arg {compression}")
        key = f"{name}.rdf.gz" if use_gzip else f"{name}.rdf"
        output_key = "/".join((self.key_prefix, key))
        graph = graph_set.to_rdf()
        with logger.bind(bucket=self.bucket,
                         key_prefix=self.key_prefix,
                         key=key):
            logger.info(event=LogEvent.WriteToS3Start)
            with io.BytesIO() as rdf_bytes_buf:
                if use_gzip:
                    # Leaving the GzipFile context finishes the gzip stream while
                    # keeping the underlying buffer open for the upload.
                    with gzip.GzipFile(fileobj=rdf_bytes_buf, mode="wb") as gz:
                        graph.serialize(gz)
                else:
                    graph.serialize(rdf_bytes_buf)
                rdf_bytes_buf.flush()
                rdf_bytes_buf.seek(0)
                session = boto3.Session()
                s3_client = session.client("s3")
                s3_client.upload_fileobj(rdf_bytes_buf, self.bucket,
                                         output_key)
            s3_client.put_object_tagging(
                Bucket=self.bucket,
                Key=output_key,
                Tagging={
                    "TagSet": [
                        {
                            "Key": "name",
                            "Value": graph_set.name
                        },
                        {
                            "Key": "version",
                            "Value": graph_set.version
                        },
                        {
                            "Key": "start_time",
                            "Value": str(graph_set.start_time)
                        },
                        {
                            "Key": "end_time",
                            "Value": str(graph_set.end_time)
                        },
                    ]
                },
            )
            logger.info(event=LogEvent.WriteToS3End)
        return f"s3://{self.bucket}/{output_key}"
Beispiel #19
0
def scan(
    muxer: AWSScanMuxer,
    account_ids: List[str],
    regions: List[str],
    accessor: Accessor,
    artifact_writer: ArtifactWriter,
    artifact_reader: ArtifactReader,
    scan_sub_accounts: bool = False,
) -> ScanManifest:
    """Scan the given accounts and regions and summarize the outcome as a ScanManifest.

    Per-account artifacts produced by the muxer are read back, merged into a single
    GraphSet and written as the "master" artifact.

    Args:
        muxer: runs the account scan plans and yields per-account scan manifests.
        account_ids: ids of the accounts to scan.
        regions: regions to scan.
        accessor: used to look up sub-accounts and to build the account scan plans.
        artifact_writer: used to write the merged "master" artifact.
        artifact_reader: used to read each per-account artifact.
        scan_sub_accounts: when True, account_ids is replaced by the result of
            get_sub_account_ids(account_ids, accessor).

    Returns:
        ScanManifest for the run. master_artifact, start_time and end_time are None
        when no artifacts were produced.
    """
    if scan_sub_accounts:
        account_ids = get_sub_account_ids(account_ids, accessor)
    account_scan_plans = build_account_scan_plans(accessor=accessor,
                                                  account_ids=account_ids,
                                                  regions=regions)
    logger = Logger()
    logger.info(event=AWSLogEvents.ScanAWSAccountsStart)
    account_scan_manifests = muxer.scan(account_scan_plans=account_scan_plans)
    # now combine the per-account scan manifests into a single ScanManifest
    scanned_accounts: List[str] = []
    artifacts: List[str] = []
    errors: Dict[str, List[str]] = {}
    unscanned_accounts: List[str] = []
    stats = MultilevelCounter()

    for account_scan_manifest in account_scan_manifests:
        account_id = account_scan_manifest.account_id
        if account_scan_manifest.artifacts:
            artifacts += account_scan_manifest.artifacts
            if account_scan_manifest.errors:
                # Artifacts were produced but with errors: keep the artifacts and
                # still report the account as unscanned.
                errors[account_id] = account_scan_manifest.errors
                unscanned_accounts.append(account_id)
            else:
                scanned_accounts.append(account_id)
        else:
            unscanned_accounts.append(account_id)
        account_stats = MultilevelCounter.from_dict(
            account_scan_manifest.api_call_stats)
        stats.merge(account_stats)
    graph_set = None
    for artifact_path in artifacts:
        artifact_dict = artifact_reader.read_artifact(artifact_path)
        artifact_graph_set = GraphSet.from_dict(artifact_dict)
        if graph_set is None:
            graph_set = artifact_graph_set
        else:
            graph_set.merge(artifact_graph_set)
    master_artifact_path = None
    start_time, end_time = None, None
    # Test for None explicitly: only the absence of any artifact should skip the
    # master write, not the truthiness of the merged GraphSet.
    if graph_set is not None:
        master_artifact_path = artifact_writer.write_artifact(
            name="master", data=graph_set.to_dict())
        start_time = graph_set.start_time
        end_time = graph_set.end_time
    logger.info(event=AWSLogEvents.ScanAWSAccountsEnd)
    return ScanManifest(
        scanned_accounts=scanned_accounts,
        master_artifact=master_artifact_path,
        artifacts=artifacts,
        errors=errors,
        unscanned_accounts=unscanned_accounts,
        api_call_stats=stats.to_dict(),
        start_time=start_time,
        end_time=end_time,
    )
Beispiel #20
0
def run_scan(
    muxer: AWSScanMuxer,
    config: Config,
    artifact_writer: ArtifactWriter,
    artifact_reader: ArtifactReader,
) -> Tuple[ScanManifest, GraphSet]:
    """Scan the configured accounts and write the merged graph and its manifest.

    Each per-account artifact yielded by the muxer is read back and merged into one
    GraphSet, which is written as the "master" artifact; the resulting ScanManifest
    is written as "manifest".

    Args:
        muxer: runs the account scan plan and yields per-account scan manifests.
        config: supplies the accounts, regions, accessor and sub-account setting.
        artifact_writer: used to write the "master" and "manifest" json artifacts.
        artifact_reader: used to read each per-account artifact.

    Returns:
        Tuple of the ScanManifest and the merged GraphSet.

    Raises:
        Exception: if no per-account artifact produced a GraphSet.
    """
    account_ids = (get_sub_account_ids(config.scan.accounts,
                                       config.access.accessor)
                   if config.scan.scan_sub_accounts else config.scan.accounts)
    plan = AccountScanPlan(account_ids=account_ids,
                           regions=config.scan.regions,
                           accessor=config.access.accessor)
    logger = Logger()
    logger.info(event=AWSLogEvents.ScanAWSAccountsStart)
    # accumulators for the ScanManifest built from the per-account manifests
    scanned_accounts: List[str] = []
    artifacts: List[str] = []
    errors: Dict[str, List[str]] = {}
    unscanned_accounts: List[str] = []
    stats = MultilevelCounter()
    graph_set = None

    for manifest in muxer.scan(account_scan_plan=plan):
        account_id = manifest.account_id
        if not manifest.artifacts:
            unscanned_accounts.append(account_id)
        else:
            for artifact in manifest.artifacts:
                artifacts.append(artifact)
                loaded = GraphSet.from_dict(artifact_reader.read_json(artifact))
                if graph_set is None:
                    graph_set = loaded
                else:
                    graph_set.merge(loaded)
            if manifest.errors:
                # artifacts exist but the account reported errors: count it as unscanned
                errors[account_id] = manifest.errors
                unscanned_accounts.append(account_id)
            else:
                scanned_accounts.append(account_id)
        stats.merge(MultilevelCounter.from_dict(manifest.api_call_stats))
    if graph_set is None:
        raise Exception("BUG: No graph_set generated.")
    master_artifact_path = artifact_writer.write_json(name="master",
                                                      data=graph_set.to_dict())
    logger.info(event=AWSLogEvents.ScanAWSAccountsEnd)
    scan_manifest = ScanManifest(
        scanned_accounts=scanned_accounts,
        master_artifact=master_artifact_path,
        artifacts=artifacts,
        errors=errors,
        unscanned_accounts=unscanned_accounts,
        api_call_stats=stats.to_dict(),
        start_time=graph_set.start_time,
        end_time=graph_set.end_time,
    )
    artifact_writer.write_json("manifest", data=scan_manifest.to_dict())
    return scan_manifest, graph_set
def run_scan(
    muxer: AWSScanMuxer,
    config: AWSConfig,
    aws_resource_region_mapping_repo: AWSResourceRegionMappingRepository,
    artifact_writer: ArtifactWriter,
    artifact_reader: ArtifactReader,
) -> Tuple[ScanManifest, ValidatedGraphSet]:
    """Scan the configured AWS accounts and merge the per-account artifacts.

    The accounts to scan come from ``config.scan.accounts``; when that is
    empty the account of the current STS caller identity is used instead.
    If ``config.scan.scan_sub_accounts`` is set, the list is replaced by the
    result of ``get_sub_account_ids``.

    Every artifact reported by ``muxer.scan`` is read back through
    ``artifact_reader`` and parsed into a ``GraphSet``; all of them are then
    combined and validated. A "manifest" artifact is always written, and a
    "master" artifact only when ``config.write_master_json`` is truthy.

    Bookkeeping rules, per account result:
      * any errors     -> recorded in ``errors`` and listed as unscanned
      * any artifacts  -> listed as scanned
      * no artifacts   -> listed as unscanned
    An account reporting both errors and artifacts therefore appears in both
    the scanned and the unscanned lists.

    Returns:
        The scan manifest and the validated, merged graph set.

    Raises:
        Exception: if no account produced any artifact.
    """
    # Fall back to the caller's own account when none are configured.
    root_account_ids = config.scan.accounts
    if not root_account_ids:
        caller_identity = boto3.client("sts").get_caller_identity()
        root_account_ids = (caller_identity["Account"],)
    account_ids = (
        get_sub_account_ids(root_account_ids, config.accessor)
        if config.scan.scan_sub_accounts
        else root_account_ids
    )
    scan_plan = ScanPlan(
        account_ids=account_ids,
        regions=config.scan.regions,
        aws_resource_region_mapping_repo=aws_resource_region_mapping_repo,
        accessor=config.accessor,
    )
    logger = Logger()
    logger.info(event=AWSLogEvents.ScanAWSAccountsStart)

    # Accumulators used to build the ScanManifest once all accounts report.
    scanned_accounts: List[str] = []
    artifacts: List[str] = []
    errors: Dict[str, List[str]] = {}
    unscanned_accounts: Set[str] = set()
    graph_sets: List[GraphSet] = []

    for account_result in muxer.scan(scan_plan=scan_plan):
        account_id = account_result.account_id
        account_errors = account_result.errors
        if account_errors:
            errors[account_id] = account_errors
            unscanned_accounts.add(account_id)
        account_artifacts = account_result.artifacts
        if not account_artifacts:
            unscanned_accounts.add(account_id)
            continue
        for artifact_path in account_artifacts:
            artifacts.append(artifact_path)
            graph_sets.append(GraphSet.parse_obj(artifact_reader.read_json(artifact_path)))
        scanned_accounts.append(account_id)

    if not graph_sets:
        raise Exception("BUG: No graph_sets generated.")
    merged_graph_set = GraphSet.from_graph_sets(graph_sets)
    validated_graph_set = ValidatedGraphSet.from_graph_set(merged_graph_set)
    master_artifact_path: Optional[str] = (
        artifact_writer.write_json(name="master", data=validated_graph_set)
        if config.write_master_json
        else None
    )
    logger.info(event=AWSLogEvents.ScanAWSAccountsEnd)
    scan_manifest = ScanManifest(
        scanned_accounts=scanned_accounts,
        master_artifact=master_artifact_path,
        artifacts=artifacts,
        errors=errors,
        unscanned_accounts=list(unscanned_accounts),
        start_time=validated_graph_set.start_time,
        end_time=validated_graph_set.end_time,
    )
    artifact_writer.write_json("manifest", data=scan_manifest)
    return scan_manifest, validated_graph_set
Beispiel #22
0
 def scan(self) -> AccountScanResult:
     """Scan this account's resource spec classes and write the result as one artifact.

     Resource spec classes are grouped by region and service. Within each
     (region, service) group, every class with ``parallel_scan`` set is
     handed to ``schedule_scan`` on its own, and the remaining classes are
     handed over together as a single unit. Regions and services are
     visited in random order.

     Any exception raised while preparing or scheduling the scans is logged
     and recorded as a "prescan" error instead of propagating. If there are
     prescan errors, or any scanned graph set reports errors, the artifact
     written for the account contains only an unscanned-account resource
     carrying those errors; otherwise it is the merge of all graph sets.

     Exceptions raised by ``future.result()`` are not caught here.

     Returns:
         An ``AccountScanResult`` with the account id, the single artifact
         returned by ``artifact_writer.write_json`` and the list of errors
         (empty when the scan was clean).
     """
     logger = Logger()
     now = int(time.time())
     prescan_errors: List[str] = []
     futures: List[Future] = []
     account_id = self.account_scan_plan.account_id
     with logger.bind(account_id=account_id):
         with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
             logger.info(event=AWSLogEvents.ScanAWSAccountStart)
             try:
                 session = self.account_scan_plan.accessor.get_session(account_id=account_id)
                 # sanity check: the session must belong to the account we were asked to scan
                 sts_client = session.client("sts")
                 sts_account_id = sts_client.get_caller_identity()["Account"]
                 if sts_account_id != account_id:
                     raise ValueError(
                         f"BUG: sts detected account_id {sts_account_id} != {account_id}"
                     )
                 if self.account_scan_plan.regions:
                     account_scan_regions = tuple(self.account_scan_plan.regions)
                 else:
                     account_scan_regions = get_all_enabled_regions(session=session)
                 # build a dict of regions -> services -> List[AWSResourceSpec]
                 regions_services_resource_spec_classes: DefaultDict[
                     str, DefaultDict[str, List[Type[AWSResourceSpec]]]
                 ] = defaultdict(lambda: defaultdict(list))
                 for resource_spec_class in self.resource_spec_classes:
                     resource_regions = self.account_scan_plan.aws_resource_region_mapping_repo.get_regions(
                         resource_spec_class=resource_spec_class,
                         region_whitelist=account_scan_regions,
                     )
                     for region in resource_regions:
                         regions_services_resource_spec_classes[region][
                             resource_spec_class.service_name
                         ].append(resource_spec_class)
                 # Build and submit ScanUnits.
                 # random.sample() needs a real sequence: sampling straight from a
                 # dict view is deprecated since Python 3.9 and raises TypeError on
                 # 3.11+, so the items are materialised into a list first.
                 shuffled_regions_services_resource_spec_classes = random.sample(
                     list(regions_services_resource_spec_classes.items()),
                     len(regions_services_resource_spec_classes),
                 )
                 for (
                     region,
                     services_resource_spec_classes,
                 ) in shuffled_regions_services_resource_spec_classes:
                     region_session = self.account_scan_plan.accessor.get_session(
                         account_id=account_id, region_name=region
                     )
                     region_creds = region_session.get_credentials()
                     shuffled_services_resource_spec_classes = random.sample(
                         list(services_resource_spec_classes.items()),
                         len(services_resource_spec_classes),
                     )
                     for (
                         service,
                         svc_resource_spec_classes,
                     ) in shuffled_services_resource_spec_classes:
                         parallel_svc_resource_spec_classes = [
                             svc_resource_spec_class
                             for svc_resource_spec_class in svc_resource_spec_classes
                             if svc_resource_spec_class.parallel_scan
                         ]
                         serial_svc_resource_spec_classes = [
                             svc_resource_spec_class
                             for svc_resource_spec_class in svc_resource_spec_classes
                             if not svc_resource_spec_class.parallel_scan
                         ]
                         # one scan unit per parallel-capable class
                         for parallel_svc_resource_spec_class in parallel_svc_resource_spec_classes:
                             parallel_future = schedule_scan(
                                 executor=executor,
                                 graph_name=self.graph_name,
                                 graph_version=self.graph_version,
                                 account_id=account_id,
                                 region_name=region,
                                 service=service,
                                 access_key=region_creds.access_key,
                                 secret_key=region_creds.secret_key,
                                 token=region_creds.token,
                                 resource_spec_classes=(parallel_svc_resource_spec_class,),
                             )
                             futures.append(parallel_future)
                         # one scan unit for all remaining classes of this service
                         # NOTE(review): this is scheduled even when the tuple is empty;
                         # confirm whether an empty scan unit is intended.
                         serial_future = schedule_scan(
                             executor=executor,
                             graph_name=self.graph_name,
                             graph_version=self.graph_version,
                             account_id=account_id,
                             region_name=region,
                             service=service,
                             access_key=region_creds.access_key,
                             secret_key=region_creds.secret_key,
                             token=region_creds.token,
                             resource_spec_classes=tuple(serial_svc_resource_spec_classes),
                         )
                         futures.append(serial_future)
             except Exception as ex:
                 # Setup failures are reported through the result, not raised.
                 error_str = str(ex)
                 trace_back = traceback.format_exc()
                 logger.error(
                     event=AWSLogEvents.ScanAWSAccountError,
                     error=error_str,
                     trace_back=trace_back,
                 )
                 prescan_errors.append(f"{error_str}\n{trace_back}")
         # Collect every scheduled result, including those submitted before a
         # prescan failure.
         graph_sets: List[GraphSet] = []
         for future in as_completed(futures):
             graph_sets.append(future.result())
         # Prescan errors take precedence; otherwise gather the errors reported
         # by the individual graph sets.
         errors: List[str] = prescan_errors or [
             error for graph_set in graph_sets for error in graph_set.errors
         ]
         if errors:
             # if there are any errors whatsoever we generate an empty graph with errors only
             unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                 account_id=account_id, errors=errors
             )
             account_graph_set = GraphSet(
                 name=self.graph_name,
                 version=self.graph_version,
                 start_time=now,
                 end_time=now,
                 resources=[unscanned_account_resource],
                 errors=errors,
             )
         else:
             account_graph_set = GraphSet.from_graph_sets(graph_sets)
         output_artifact = self.artifact_writer.write_json(
             name=account_id,
             data=account_graph_set,
         )
         logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
         return AccountScanResult(
             account_id=account_id,
             artifacts=[output_artifact],
             errors=errors,
         )
Beispiel #23
0
    def test(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            resource_region_name = "us-east-1"
            # get moto"s enabled regions
            ec2_client = boto3.client("ec2", region_name=resource_region_name)
            all_regions = ec2_client.describe_regions(
                Filters=[{
                    "Name": "opt-in-status",
                    "Values": ["opt-in-not-required", "opted-in"]
                }])["Regions"]
            account_id = get_account_id()
            all_region_names = tuple(region["RegionName"]
                                     for region in all_regions)
            enabled_region_names = tuple(
                region["RegionName"] for region in all_regions
                if region["OptInStatus"] != "not-opted-in")
            delete_vpcs(all_region_names)
            # add a diverse set of resources which are supported by moto
            ## dynamodb
            # TODO moto is not returning TableId in list/describe
            #            dynamodb_table_1_arn = create_dynamodb_table(
            #                name="test_table_1",
            #                attr_name="test_hash_key_attr_1",
            #                attr_type="S",
            #                key_type="HASH",
            #                region_name=region_name,
            #            )
            ## s3
            bucket_1_name = "test_bucket"
            bucket_1_arn, bucket_1_creation_date = create_bucket(
                name=bucket_1_name,
                account_id=account_id,
                region_name=resource_region_name)
            ## ec2
            vpc_1_cidr = "10.0.0.0/16"
            vpc_1_id = create_vpc(cidr_block=vpc_1_cidr,
                                  region_name=resource_region_name)
            vpc_1_arn = VPCResourceSpec.generate_arn(
                resource_id=vpc_1_id,
                account_id=account_id,
                region=resource_region_name)
            subnet_1_cidr = "10.0.0.0/24"
            subnet_1_cidr_network = ipaddress.IPv4Network(subnet_1_cidr,
                                                          strict=False)
            subnet_1_first_ip, subnet_1_last_ip = (
                int(subnet_1_cidr_network[0]),
                int(subnet_1_cidr_network[-1]),
            )
            subnet_1_id = create_subnet(cidr_block=subnet_1_cidr,
                                        vpc_id=vpc_1_id,
                                        region_name=resource_region_name)
            subnet_1_arn = SubnetResourceSpec.generate_arn(
                resource_id=subnet_1_id,
                account_id=account_id,
                region=resource_region_name)
            fixed_bucket_1_arn = f"arn:aws:s3:::{bucket_1_name}"
            flow_log_1_id, flow_log_1_creation_time = create_flow_log(
                vpc_id=vpc_1_id,
                dest_bucket_arn=fixed_bucket_1_arn,
                region_name=resource_region_name,
            )
            flow_log_1_arn = FlowLogResourceSpec.generate_arn(
                resource_id=flow_log_1_id,
                account_id=account_id,
                region=resource_region_name)
            ebs_volume_1_size = 128
            ebs_volume_1_az = f"{resource_region_name}a"
            ebs_volume_1_arn, ebs_volume_1_create_time = create_volume(
                size=ebs_volume_1_size,
                az=ebs_volume_1_az,
                region_name=resource_region_name)
            ## iam
            policy_1_name = "test_policy_1"
            policy_1_arn, policy_1_id = create_iam_policy(
                name=policy_1_name,
                policy_doc={
                    "Version":
                    "2012-10-17",
                    "Statement": [
                        {
                            "Effect": "Allow",
                            "Action": "logs:CreateLogGroup",
                            "Resource": "*"
                        },
                    ],
                },
            )
            role_1_name = "test_role_1"
            role_1_assume_role_policy_doc = {
                "Version":
                "2012-10-17",
                "Statement": [{
                    "Action": "sts:AssumeRole",
                    "Effect": "Allow",
                    "Principal": {
                        "Service": "lambda.amazonaws.com"
                    },
                    "Sid": "",
                }],
            }
            role_1_description = "Test Role 1"
            role_1_max_session_duration = 3600
            role_1_arn = create_iam_role(
                name=role_1_name,
                assume_role_policy_doc=role_1_assume_role_policy_doc,
                description=role_1_description,
                max_session_duration=role_1_max_session_duration,
            )
            ## lambda
            lambda_function_1_name = "test_lambda_function_1"
            lambda_function_1_runtime = "python3.7"
            lambda_function_1_handler = "lambda_function.lambda_handler"
            lambda_function_1_description = "Test Lambda Function 1"
            lambda_function_1_timeout = 30
            lambda_function_1_memory_size = 256
            lambda_function_1_arn = create_lambda_function(
                name=lambda_function_1_name,
                runtime=lambda_function_1_runtime,
                role_name=role_1_arn,
                handler=lambda_function_1_handler,
                description=lambda_function_1_description,
                timeout=lambda_function_1_timeout,
                memory_size=lambda_function_1_memory_size,
                publish=False,
                region_name=resource_region_name,
            )
            # scan
            test_scan_id = "test_scan_id"
            aws_config = AWSConfig(
                artifact_path=temp_dir,
                pruner_max_age_min=4320,
                graph_name="alti",
                concurrency=ConcurrencyConfig(max_account_scan_threads=1,
                                              max_svc_scan_threads=1,
                                              max_account_scan_tries=2),
                scan=ScanConfig(
                    accounts=(),
                    regions=(),
                    scan_sub_accounts=False,
                    preferred_account_scan_regions=(
                        "us-west-1",
                        "us-west-2",
                        "us-east-1",
                        "us-east-2",
                    ),
                ),
                accessor=Accessor(
                    credentials_cache=AWSCredentialsCache(cache={}),
                    multi_hop_accessors=[],
                    cache_creds=True,
                ),
                write_master_json=True,
            )
            resource_spec_classes = (
                # DynamoDbTableResourceSpec, TODO moto
                EBSVolumeResourceSpec,
                FlowLogResourceSpec,
                IAMPolicyResourceSpec,
                IAMRoleResourceSpec,
                LambdaFunctionResourceSpec,
                S3BucketResourceSpec,
                SubnetResourceSpec,
                VPCResourceSpec,
            )
            muxer = LocalAWSScanMuxer(
                scan_id=test_scan_id,
                config=aws_config,
                resource_spec_classes=resource_spec_classes,
            )
            with unittest.mock.patch(
                    "altimeter.aws.scan.account_scanner.get_all_enabled_regions"
            ) as mock_get_all_enabled_regions:
                mock_get_all_enabled_regions.return_value = enabled_region_names
                aws2n_result = aws2n(
                    scan_id=test_scan_id,
                    config=aws_config,
                    muxer=muxer,
                    load_neptune=False,
                )
                graph_set = GraphSet.from_json_file(
                    Path(aws2n_result.json_path))
                self.assertEqual(len(graph_set.errors), 0)
                self.assertEqual(graph_set.name, "alti")
                self.assertEqual(graph_set.version, "2")
                # now check each resource type
                self.maxDiff = None
                ## Accounts
                expected_account_resources = [
                    Resource(
                        resource_id=f"arn:aws::::account/{account_id}",
                        type="aws:account",
                        link_collection=LinkCollection(
                            simple_links=(SimpleLink(pred="account_id",
                                                     obj=account_id), ), ),
                    )
                ]
                account_resources = [
                    resource for resource in graph_set.resources
                    if resource.type == "aws:account"
                ]
                self.assertCountEqual(account_resources,
                                      expected_account_resources)
                ## Regions
                expected_region_resources = [
                    Resource(
                        resource_id=
                        f"arn:aws:::{account_id}:region/{region['RegionName']}",
                        type="aws:region",
                        link_collection=LinkCollection(
                            simple_links=(
                                SimpleLink(pred="name",
                                           obj=region["RegionName"]),
                                SimpleLink(pred="opt_in_status",
                                           obj=region["OptInStatus"]),
                            ),
                            resource_links=(ResourceLink(
                                pred="account",
                                obj=f"arn:aws::::account/{account_id}"), ),
                        ),
                    ) for region in all_regions
                ]
                region_resources = [
                    resource for resource in graph_set.resources
                    if resource.type == "aws:region"
                ]
                self.assertCountEqual(region_resources,
                                      expected_region_resources)
                ## IAM Policies
                expected_iam_policy_resources = [
                    Resource(
                        resource_id=policy_1_arn,
                        type="aws:iam:policy",
                        link_collection=LinkCollection(
                            simple_links=(
                                SimpleLink(pred="name", obj=policy_1_name),
                                SimpleLink(pred="policy_id", obj=policy_1_id),
                                SimpleLink(pred="default_version_id",
                                           obj="v1"),
                                SimpleLink(
                                    pred="default_version_policy_document_text",
                                    obj=
                                    '{"Statement": [{"Action": "logs:CreateLogGroup", "Effect": "Allow", "Resource": "*"}], "Version": "2012-10-17"}',
                                ),
                            ),
                            resource_links=(ResourceLink(
                                pred="account",
                                obj=f"arn:aws::::account/{account_id}"), ),
                        ),
                    )
                ]
                iam_policy_resources = [
                    resource for resource in graph_set.resources
                    if resource.type == "aws:iam:policy"
                ]
                self.assertCountEqual(iam_policy_resources,
                                      expected_iam_policy_resources)
                ## IAM Roles
                expected_iam_role_resources = [
                    Resource(
                        resource_id=role_1_arn,
                        type="aws:iam:role",
                        link_collection=LinkCollection(
                            simple_links=(
                                SimpleLink(pred="name", obj=role_1_name),
                                SimpleLink(pred="max_session_duration",
                                           obj=role_1_max_session_duration),
                                SimpleLink(pred="description",
                                           obj=role_1_description),
                                SimpleLink(
                                    pred="assume_role_policy_document_text",
                                    obj=policy_doc_dict_to_sorted_str(
                                        role_1_assume_role_policy_doc),
                                ),
                            ),
                            multi_links=(MultiLink(
                                pred="assume_role_policy_document",
                                obj=LinkCollection(
                                    simple_links=(SimpleLink(
                                        pred="version", obj="2012-10-17"), ),
                                    multi_links=(MultiLink(
                                        pred="statement",
                                        obj=LinkCollection(
                                            simple_links=(
                                                SimpleLink(pred="effect",
                                                           obj="Allow"),
                                                SimpleLink(
                                                    pred="action",
                                                    obj="sts:AssumeRole"),
                                            ),
                                            multi_links=(MultiLink(
                                                pred="principal",
                                                obj=LinkCollection(
                                                    simple_links=(SimpleLink(
                                                        pred="service",
                                                        obj=
                                                        "lambda.amazonaws.com",
                                                    ), )),
                                            ), ),
                                        ),
                                    ), ),
                                ),
                            ), ),
                            resource_links=(ResourceLink(
                                pred="account",
                                obj="arn:aws::::account/123456789012"), ),
                        ),
                    )
                ]
                iam_role_resources = [
                    resource for resource in graph_set.resources
                    if resource.type == "aws:iam:role"
                ]
                self.assertCountEqual(iam_role_resources,
                                      expected_iam_role_resources)

                ## Lambda functions
                expected_lambda_function_resources = [
                    Resource(
                        resource_id=lambda_function_1_arn,
                        type="aws:lambda:function",
                        link_collection=LinkCollection(
                            simple_links=(
                                SimpleLink(pred="function_name",
                                           obj=lambda_function_1_name),
                                SimpleLink(pred="runtime",
                                           obj=lambda_function_1_runtime),
                            ),
                            resource_links=(
                                ResourceLink(
                                    pred="account",
                                    obj=f"arn:aws::::account/{account_id}"),
                                ResourceLink(
                                    pred="region",
                                    obj=
                                    f"arn:aws:::{account_id}:region/{resource_region_name}",
                                ),
                            ),
                            transient_resource_links=(ResourceLink(
                                pred="role",
                                obj="arn:aws:iam::123456789012:role/test_role_1"
                            ), ),
                        ),
                    ),
                ]
                lambda_function_resources = [
                    resource for resource in graph_set.resources
                    if resource.type == "aws:lambda:function"
                ]
                self.assertCountEqual(lambda_function_resources,
                                      expected_lambda_function_resources)
                ## EC2 VPCs
                expected_ec2_vpc_resources = [
                    Resource(
                        resource_id=vpc_1_arn,
                        type="aws:ec2:vpc",
                        link_collection=LinkCollection(
                            simple_links=(
                                SimpleLink(pred="is_default", obj=True),
                                SimpleLink(pred="cidr_block", obj=vpc_1_cidr),
                                SimpleLink(pred="state", obj="available"),
                            ),
                            resource_links=(
                                ResourceLink(
                                    pred="account",
                                    obj=f"arn:aws::::account/{account_id}"),
                                ResourceLink(
                                    pred="region",
                                    obj=
                                    f"arn:aws:::{account_id}:region/{resource_region_name}",
                                ),
                            ),
                        ),
                    )
                ]
                ec2_vpc_resources = [
                    resource for resource in graph_set.resources
                    if resource.type == "aws:ec2:vpc"
                ]
                self.assertCountEqual(ec2_vpc_resources,
                                      expected_ec2_vpc_resources)
                ## EC2 VPC Flow Logs
                expected_ec2_vpc_flow_log_resources = [
                    Resource(
                        resource_id=flow_log_1_arn,
                        type="aws:ec2:flow-log",
                        link_collection=LinkCollection(
                            simple_links=(
                                SimpleLink(
                                    pred="creation_time",
                                    obj=flow_log_1_creation_time.replace(
                                        tzinfo=datetime.timezone.utc).
                                    isoformat(),
                                ),
                                SimpleLink(pred="deliver_logs_status",
                                           obj="SUCCESS"),
                                SimpleLink(pred="flow_log_status",
                                           obj="ACTIVE"),
                                SimpleLink(pred="traffic_type", obj="ALL"),
                                SimpleLink(pred="log_destination_type",
                                           obj="s3"),
                                SimpleLink(pred="log_destination",
                                           obj=fixed_bucket_1_arn),
                                SimpleLink(
                                    pred="log_format",
                                    obj=
                                    "${version} ${account-id} ${interface-id} ${srcaddr} ${dstaddr} ${srcport} ${dstport} ${protocol} ${packets} ${bytes} ${start} ${end} ${action} ${log-status}",
                                ),
                            ),
                            resource_links=(
                                ResourceLink(
                                    pred="account",
                                    obj=f"arn:aws::::account/{account_id}"),
                                ResourceLink(
                                    pred="region",
                                    obj=
                                    f"arn:aws:::{account_id}:region/{resource_region_name}",
                                ),
                            ),
                            transient_resource_links=(TransientResourceLink(
                                pred="vpc",
                                obj=vpc_1_arn,
                            ), ),
                        ),
                    )
                ]
                ec2_vpc_flow_log_resources = [
                    resource for resource in graph_set.resources
                    if resource.type == "aws:ec2:flow-log"
                ]
                self.assertCountEqual(ec2_vpc_flow_log_resources,
                                      expected_ec2_vpc_flow_log_resources)
                ## EC2 Subnets
                expected_ec2_subnet_resources = [
                    Resource(
                        resource_id=subnet_1_arn,
                        type="aws:ec2:subnet",
                        link_collection=LinkCollection(
                            simple_links=(
                                SimpleLink(pred="cidr_block",
                                           obj=subnet_1_cidr),
                                SimpleLink(pred="first_ip",
                                           obj=subnet_1_first_ip),
                                SimpleLink(pred="last_ip",
                                           obj=subnet_1_last_ip),
                                SimpleLink(pred="state", obj="available"),
                            ),
                            resource_links=(
                                ResourceLink(pred="vpc", obj=vpc_1_arn),
                                ResourceLink(
                                    pred="account",
                                    obj=f"arn:aws::::account/{account_id}"),
                                ResourceLink(
                                    pred="region",
                                    obj=
                                    f"arn:aws:::{account_id}:region/{resource_region_name}",
                                ),
                            ),
                        ),
                    )
                ]
                ec2_subnet_resources = [
                    resource for resource in graph_set.resources
                    if resource.type == "aws:ec2:subnet"
                ]
                self.assertCountEqual(ec2_subnet_resources,
                                      expected_ec2_subnet_resources)
                ## EC2 EBS Volumes
                expected_ec2_ebs_volume_resources = [
                    Resource(
                        resource_id=ebs_volume_1_arn,
                        type="aws:ec2:volume",
                        link_collection=LinkCollection(
                            simple_links=(
                                SimpleLink(pred="availability_zone",
                                           obj=ebs_volume_1_az),
                                SimpleLink(
                                    pred="create_time",
                                    obj=ebs_volume_1_create_time.replace(
                                        tzinfo=datetime.timezone.utc).
                                    isoformat(),
                                ),
                                SimpleLink(pred="size", obj=ebs_volume_1_size),
                                SimpleLink(pred="state", obj="available"),
                                SimpleLink(pred="volume_type", obj="gp2"),
                                SimpleLink(pred="encrypted", obj=False),
                            ),
                            resource_links=(
                                ResourceLink(
                                    pred="account",
                                    obj=f"arn:aws::::account/{account_id}"),
                                ResourceLink(
                                    pred="region",
                                    obj=
                                    f"arn:aws:::{account_id}:region/{resource_region_name}",
                                ),
                            ),
                        ),
                    )
                ]
                ec2_ebs_volume_resources = [
                    resource for resource in graph_set.resources
                    if resource.type == "aws:ec2:volume"
                ]
                self.assertCountEqual(ec2_ebs_volume_resources,
                                      expected_ec2_ebs_volume_resources)
                ## S3 Buckets
                expected_s3_bucket_resources = [
                    Resource(
                        resource_id=bucket_1_arn,
                        type="aws:s3:bucket",
                        link_collection=LinkCollection(
                            simple_links=(
                                SimpleLink(pred="name", obj=bucket_1_name),
                                SimpleLink(
                                    pred="creation_date",
                                    obj=bucket_1_creation_date.replace(
                                        tzinfo=datetime.timezone.utc).
                                    isoformat(),
                                ),
                            ),
                            resource_links=(
                                ResourceLink(
                                    pred="account",
                                    obj=f"arn:aws::::account/{account_id}"),
                                ResourceLink(
                                    pred="region",
                                    obj=
                                    f"arn:aws:::{account_id}:region/{resource_region_name}",
                                ),
                            ),
                        ),
                    )
                ]
                s3_bucket_resources = [
                    resource for resource in graph_set.resources
                    if resource.type == "aws:s3:bucket"
                ]
                self.assertCountEqual(s3_bucket_resources,
                                      expected_s3_bucket_resources)

                expected_num_graph_set_resources = (
                    0 + len(expected_account_resources) +
                    len(expected_region_resources) +
                    len(expected_iam_policy_resources) +
                    len(expected_iam_role_resources) +
                    len(expected_lambda_function_resources) +
                    len(expected_ec2_ebs_volume_resources) +
                    len(expected_ec2_subnet_resources) +
                    len(expected_ec2_vpc_resources) +
                    len(expected_ec2_vpc_flow_log_resources) +
                    len(expected_s3_bucket_resources))
                self.assertEqual(len(graph_set.resources),
                                 expected_num_graph_set_resources)
Beispiel #24
0
    def scan(self) -> Dict[str, Any]:
        """Scan an account and return a dict containing keys:

            * account_id: str
            * output_artifact: str
            * api_call_stats: Dict[str, Any]
            * errors: List[str]

        If errors is non-empty the results are incomplete for this account.
        output_artifact is a pointer to the actual scan data - either on the local fs or in s3.

        To scan an account we create a set of GraphSpecs, one for each region.  Any ACCOUNT
        level granularity resources are only scanned in a single region (e.g. IAM Users)

        Any exception raised while preparing, running or validating the scan is logged and
        appended to errors; the written graph set then contains a single unscanned-account
        resource instead of the scanned resources.

        Returns:
            Dict of scan result, see above for details.
        """
        logger = Logger()
        # Everything, including the artifact write and the end event, runs inside
        # the bind so that every log entry for this scan carries the account id.
        with logger.bind(account_id=self.account_id):
            logger.info(event=AWSLogEvents.ScanAWSAccountStart)
            stats = MultilevelCounter()
            errors: List[str] = []
            now = int(time.time())
            try:
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[],
                    errors=[],
                    stats=stats,
                )
                # sanity check: the session must belong to the account we were asked to scan
                session = self.get_session()
                sts_client = session.client("sts")
                sts_account_id = sts_client.get_caller_identity()["Account"]
                if sts_account_id != self.account_id:
                    raise ValueError(
                        f"BUG: sts detected account_id {sts_account_id} != {self.account_id}"
                    )
                if self.regions:
                    scan_regions = tuple(self.regions)
                else:
                    scan_regions = get_all_enabled_regions(session=session)
                # build graph specs.
                # build a dict of regions -> services -> List[AWSResourceSpec]
                regions_services_resource_spec_classes: DefaultDict[
                    str, DefaultDict[str, List[Type[AWSResourceSpec]]]
                ] = defaultdict(lambda: defaultdict(list))
                resource_spec_class: Type[AWSResourceSpec]
                for resource_spec_class in self.resource_spec_classes:
                    client_name = resource_spec_class.get_client_name()
                    resource_class_scan_granularity = resource_spec_class.scan_granularity
                    if resource_class_scan_granularity == ScanGranularity.ACCOUNT:
                        # account-level resources are scanned once, in the first region
                        regions_services_resource_spec_classes[scan_regions[0]][
                            client_name
                        ].append(resource_spec_class)
                    elif resource_class_scan_granularity == ScanGranularity.REGION:
                        for region in scan_regions:
                            regions_services_resource_spec_classes[region][
                                client_name
                            ].append(resource_spec_class)
                    else:
                        raise NotImplementedError(
                            f"ScanGranularity {resource_class_scan_granularity} not implemented"
                        )
                with ThreadPoolExecutor(max_workers=self.max_svc_threads) as executor:
                    futures = []
                    # one scheduled scan per (region, service) pair
                    for (
                        region,
                        services_resource_spec_classes,
                    ) in regions_services_resource_spec_classes.items():
                        for (
                            service,
                            resource_spec_classes,
                        ) in services_resource_spec_classes.items():
                            region_session = self.get_session(region=region)
                            region_creds = region_session.get_credentials()
                            scan_future = schedule_scan_services(
                                executor=executor,
                                graph_name=self.graph_name,
                                graph_version=self.graph_version,
                                account_id=self.account_id,
                                region=region,
                                service=service,
                                access_key=region_creds.access_key,
                                secret_key=region_creds.secret_key,
                                token=region_creds.token,
                                resource_spec_classes=tuple(resource_spec_classes),
                            )
                            futures.append(scan_future)
                    # fold each per-service graph set into the account-wide one
                    for future in as_completed(futures):
                        graph_set_dict = future.result()
                        graph_set = GraphSet.from_dict(graph_set_dict)
                        errors += graph_set.errors
                        account_graph_set.merge(graph_set)
                    account_graph_set.validate()
            except Exception as ex:
                error_str = str(ex)
                trace_back = traceback.format_exc()
                logger.error(event=AWSLogEvents.ScanAWSAccountError,
                             error=error_str,
                             trace_back=trace_back)
                errors.append(" : ".join((error_str, trace_back)))
                # replace whatever was gathered with a single placeholder resource
                # that carries the accumulated errors
                unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                    account_id=self.account_id, errors=errors)
                account_graph_set = GraphSet(
                    name=self.graph_name,
                    version=self.graph_version,
                    start_time=now,
                    end_time=now,
                    resources=[unscanned_account_resource],
                    errors=errors,
                    stats=stats,
                )
                account_graph_set.validate()
            output_artifact = self.artifact_writer.write_artifact(
                name=self.account_id, data=account_graph_set.to_dict())
            logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
            api_call_stats = account_graph_set.stats.to_dict()
            return {
                "account_id": self.account_id,
                "output_artifact": output_artifact,
                "errors": errors,
                "api_call_stats": api_call_stats,
            }
Beispiel #25
0
    def test_valid_merge(self):
        """Two graph sets sharing a name and version merge into one validated set.

        The merged set is expected to report start time 10 and end time 25 and
        to carry every resource and every error from both inputs.
        """
        type_a_resources = [
            Resource(
                resource_id="123",
                type="test:a",
                link_collection=LinkCollection(
                    simple_links=[SimpleLink(pred="has-foo", obj="goo")]
                ),
            ),
            Resource(resource_id="456", type="test:a", link_collection=LinkCollection()),
        ]
        type_b_resources = [
            Resource(
                resource_id="abc",
                type="test:b",
                link_collection=LinkCollection(
                    simple_links=[SimpleLink(pred="has-a", obj="123")]
                ),
            ),
            Resource(
                resource_id="def",
                type="test:b",
                link_collection=LinkCollection(
                    simple_links=[SimpleLink(pred="name", obj="sue")]
                ),
            ),
        ]
        earlier_set = GraphSet(
            name="graph-1",
            version="1",
            start_time=10,
            end_time=20,
            resources=type_a_resources,
            errors=["errora1", "errora2"],
        )
        later_set = GraphSet(
            name="graph-1",
            version="1",
            start_time=15,
            end_time=25,
            resources=type_b_resources,
            errors=["errorb1", "errorb2"],
        )

        merged = ValidatedGraphSet.from_graph_sets([earlier_set, later_set])

        # Scalar attributes of the merged set.
        for attr_name, wanted in (
            ("name", "graph-1"),
            ("version", "1"),
            ("start_time", 10),
            ("end_time", 25),
        ):
            self.assertEqual(getattr(merged, attr_name), wanted)
        self.assertCountEqual(merged.errors, ["errora1", "errora2", "errorb1", "errorb2"])

        # Expected resources are spelled with tuple-valued simple_links.
        wanted_resources = (
            Resource(
                resource_id="123",
                type="test:a",
                link_collection=LinkCollection(
                    simple_links=(SimpleLink(pred="has-foo", obj="goo"),),
                ),
            ),
            Resource(resource_id="456", type="test:a", link_collection=LinkCollection()),
            Resource(
                resource_id="abc",
                type="test:b",
                link_collection=LinkCollection(
                    simple_links=(SimpleLink(pred="has-a", obj="123"),),
                ),
            ),
            Resource(
                resource_id="def",
                type="test:b",
                link_collection=LinkCollection(
                    simple_links=(SimpleLink(pred="name", obj="sue"),),
                ),
            ),
        )
        wanted_errors = ["errora1", "errora2", "errorb1", "errorb2"]

        self.assertCountEqual(merged.resources, wanted_resources)
        self.assertCountEqual(merged.errors, wanted_errors)
Beispiel #26
0
 def scan(self) -> List[Dict[str, Any]]:
     """Scan every account in the account scan plan and write one artifact per account.

     For each account (visited in random order) the resource spec classes are
     grouped by region and client name, and one scan is scheduled on the thread
     pool per (region, service) pair. Once all scheduled scans have finished,
     a graph set is written for each account:

         * accounts that failed before/while scheduling get a graph set holding
           a single unscanned-account resource and the recorded errors
         * accounts where any scheduled scan reported errors get the same kind
           of error-only graph set
         * all other accounts get the merge of their per-service graph sets

     Returns:
         A list with one dict per account containing the keys account_id,
         output_artifact, errors and api_call_stats.

     Raises:
         Exception: if an account id ends up with both prescan errors and
             scheduled scan results. Exceptions raised by a scheduled scan
             are re-raised by future.result().
     """
     logger = Logger()
     scan_result_dicts = []
     now = int(time.time())
     prescan_account_ids_errors: DefaultDict[str,
                                             List[str]] = defaultdict(list)
     futures = []
     with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
         # random.sample() needs a real sequence (enforced since Python 3.11),
         # so materialize whatever collection the plan holds.
         account_ids = list(self.account_scan_plan.account_ids)
         shuffled_account_ids = random.sample(account_ids,
                                              k=len(account_ids))
         for account_id in shuffled_account_ids:
             with logger.bind(account_id=account_id):
                 logger.info(event=AWSLogEvents.ScanAWSAccountStart)
                 # NOTE(review): if this try block fails after some of the
                 # account's scans were already appended to futures, the
                 # account lands in both collections and trips the "BUG"
                 # exception further down - confirm that is intended.
                 try:
                     session = self.account_scan_plan.accessor.get_session(
                         account_id=account_id)
                     # sanity check
                     sts_client = session.client("sts")
                     sts_account_id = sts_client.get_caller_identity(
                     )["Account"]
                     if sts_account_id != account_id:
                         raise ValueError(
                             f"BUG: sts detected account_id {sts_account_id} != {account_id}"
                         )
                     if self.account_scan_plan.regions:
                         scan_regions = tuple(
                             self.account_scan_plan.regions)
                     else:
                         scan_regions = get_all_enabled_regions(
                             session=session)
                     # region used for ACCOUNT granularity resources that
                     # have no region whitelist of their own
                     account_gran_scan_region = random.choice(
                         self.preferred_account_scan_regions)
                     # build a dict of regions -> services -> List[AWSResourceSpec]
                     regions_services_resource_spec_classes: DefaultDict[
                         str, DefaultDict[
                             str,
                             List[Type[AWSResourceSpec]]]] = defaultdict(
                                 lambda: defaultdict(list))
                     resource_spec_class: Type[AWSResourceSpec]
                     for resource_spec_class in self.resource_spec_classes:
                         client_name = resource_spec_class.get_client_name()
                         if resource_spec_class.scan_granularity == ScanGranularity.ACCOUNT:
                             if resource_spec_class.region_whitelist:
                                 account_resource_scan_region = resource_spec_class.region_whitelist[
                                     0]
                             else:
                                 account_resource_scan_region = account_gran_scan_region
                             regions_services_resource_spec_classes[
                                 account_resource_scan_region][
                                     client_name].append(
                                         resource_spec_class)
                         elif resource_spec_class.scan_granularity == ScanGranularity.REGION:
                             if resource_spec_class.region_whitelist:
                                 resource_scan_regions = tuple(
                                     region for region in scan_regions
                                     if region in
                                     resource_spec_class.region_whitelist)
                                 # no overlap with the requested regions:
                                 # fall back to the whole whitelist
                                 if not resource_scan_regions:
                                     resource_scan_regions = resource_spec_class.region_whitelist
                             else:
                                 resource_scan_regions = scan_regions
                             for region in resource_scan_regions:
                                 regions_services_resource_spec_classes[
                                     region][client_name].append(
                                         resource_spec_class)
                         else:
                             raise NotImplementedError(
                                 f"ScanGranularity {resource_spec_class.scan_granularity} unimplemented"
                             )
                     # Build and submit ScanUnits
                     # dict views are not sequences; random.sample() rejects
                     # them on Python 3.11+, so convert to lists first.
                     shuffled_regions_services_resource_spec_classes = random.sample(
                         list(regions_services_resource_spec_classes.items()),
                         len(regions_services_resource_spec_classes),
                     )
                     for (
                             region,
                             services_resource_spec_classes,
                     ) in shuffled_regions_services_resource_spec_classes:
                         region_session = self.account_scan_plan.accessor.get_session(
                             account_id=account_id, region_name=region)
                         region_creds = region_session.get_credentials()
                         shuffled_services_resource_spec_classes = random.sample(
                             list(services_resource_spec_classes.items()),
                             len(services_resource_spec_classes),
                         )
                         for (
                                 service,
                                 svc_resource_spec_classes,
                         ) in shuffled_services_resource_spec_classes:
                             future = schedule_scan(
                                 executor=executor,
                                 graph_name=self.graph_name,
                                 graph_version=self.graph_version,
                                 account_id=account_id,
                                 region_name=region,
                                 service=service,
                                 access_key=region_creds.access_key,
                                 secret_key=region_creds.secret_key,
                                 token=region_creds.token,
                                 resource_spec_classes=tuple(
                                     svc_resource_spec_classes),
                             )
                             futures.append(future)
                 except Exception as ex:
                     error_str = str(ex)
                     trace_back = traceback.format_exc()
                     logger.error(
                         event=AWSLogEvents.ScanAWSAccountError,
                         error=error_str,
                         trace_back=trace_back,
                     )
                     prescan_account_ids_errors[account_id].append(
                         f"{error_str}\n{trace_back}")
     # group the per-service results by the account they belong to
     account_ids_graph_set_dicts: Dict[str,
                                       List[Dict[str,
                                                 Any]]] = defaultdict(list)
     for future in as_completed(futures):
         account_id, graph_set_dict = future.result()
         account_ids_graph_set_dicts[account_id].append(graph_set_dict)
     # first make sure no account id appears both in account_ids_graph_set_dicts
     # and prescan_account_ids_errors - this should never happen
     doubled_accounts = set(
         account_ids_graph_set_dicts.keys()).intersection(
             set(prescan_account_ids_errors.keys()))
     if doubled_accounts:
         raise Exception((
             f"BUG: Account(s) {doubled_accounts} in both "
             "account_ids_graph_set_dicts and prescan_account_ids_errors."))
     # graph prescan error accounts
     for account_id, errors in prescan_account_ids_errors.items():
         with logger.bind(account_id=account_id):
             unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                 account_id=account_id, errors=errors)
             account_graph_set = GraphSet(
                 name=self.graph_name,
                 version=self.graph_version,
                 start_time=now,
                 end_time=now,
                 resources=[unscanned_account_resource],
                 errors=errors,
                 stats=MultilevelCounter(),
             )
             account_graph_set.validate()
             output_artifact = self.artifact_writer.write_json(
                 name=account_id, data=account_graph_set.to_dict())
             logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
             api_call_stats = account_graph_set.stats.to_dict()
             scan_result_dicts.append({
                 "account_id": account_id,
                 "output_artifact": output_artifact,
                 "errors": errors,
                 "api_call_stats": api_call_stats,
             })
     # graph rest
     for account_id, graph_set_dicts in account_ids_graph_set_dicts.items():
         with logger.bind(account_id=account_id):
             # if there are any errors whatsoever we generate an empty graph with
             # errors only
             errors = []
             for graph_set_dict in graph_set_dicts:
                 errors += graph_set_dict["errors"]
             if errors:
                 unscanned_account_resource = UnscannedAccountResourceSpec.create_resource(
                     account_id=account_id, errors=errors)
                 account_graph_set = GraphSet(
                     name=self.graph_name,
                     version=self.graph_version,
                     start_time=now,
                     end_time=now,
                     resources=[unscanned_account_resource],
                     errors=errors,
                     stats=MultilevelCounter(
                     ),  # ENHANCEMENT: could technically get partial stats.
                 )
                 account_graph_set.validate()
             else:
                 account_graph_set = GraphSet(
                     name=self.graph_name,
                     version=self.graph_version,
                     start_time=now,
                     end_time=now,
                     resources=[],
                     errors=[],
                     stats=MultilevelCounter(),
                 )
                 for graph_set_dict in graph_set_dicts:
                     graph_set = GraphSet.from_dict(graph_set_dict)
                     account_graph_set.merge(graph_set)
             output_artifact = self.artifact_writer.write_json(
                 name=account_id, data=account_graph_set.to_dict())
             logger.info(event=AWSLogEvents.ScanAWSAccountEnd)
             api_call_stats = account_graph_set.stats.to_dict()
             scan_result_dicts.append({
                 "account_id": account_id,
                 "output_artifact": output_artifact,
                 "errors": errors,
                 "api_call_stats": api_call_stats,
             })
     return scan_result_dicts