コード例 #1
0
 def __write_edges(self, g: traversal, edges: List[Dict], scan_id: str) -> None:
     """
     Writes the edges to the labeled property graph, creating either endpoint
     vertex when it is not already present. Edges are chained onto one traversal
     which is submitted after every 100th edge and after the final edge.
     :param g: The graph traversal source
     :param edges: A list of dictionaries for each edge; the keys "~id", "~label",
         "~from" and "~to" are read from each dictionary
     :param scan_id: Suffix appended to the "~from" / "~to" values to form vertex ids
     :return: None
     :raises NeptuneLoadGraphException: if submitting a batch of edges fails
     """

     def upsert_vertex(vertex_id: str, arn: str):
         # Get-or-create: yield the vertex with this id, adding it when the lookup is empty.
         return (
             __.V(vertex_id)
             .fold()
             .coalesce(
                 __.unfold(),
                 __.addV(self.parse_arn(arn)["resource"])
                 .property(T.id, vertex_id)
                 .property("scan_id", scan_id)
                 .property("arn", arn),
             )
         )

     total = len(edges)
     pending = g
     for written, edge in enumerate(edges, start=1):
         target_id = f'{edge["~to"]}_{scan_id}'
         source_id = f'{edge["~from"]}_{scan_id}'

         # Append this edge (and its endpoint lookups) to the batch being built.
         edge_step = pending.addE(edge["~label"]).property(T.id, str(edge["~id"]))
         edge_step = edge_step.from_(upsert_vertex(source_id, edge["~from"]))
         pending = edge_step.to(upsert_vertex(target_id, edge["~to"]))

         # Only flush on a full batch of 100 or on the very last edge.
         if written % 100 != 0 and written != total:
             continue

         try:
             self.logger.info(
                 event=LogEvent.NeptunePeriodicWrite,
                 msg=f"Writing edges {written} of {total}",
             )
             pending.next()
             # Start the next batch from the bare traversal source.
             pending = g
         except Exception as err:
             self.logger.error(event=LogEvent.NeptuneLoadError, msg=str(err))
             raise NeptuneLoadGraphException(
                 f"Error loading edge {edge} " f"with {str(pending.bytecode)}"
             ) from err
コード例 #2
0
 def __write_vertices(self, g: traversal, vertices: List[Dict], scan_id: str) -> None:
     """
     Writes the vertices to the labeled property graph. Each vertex is looked up by
     id and added when missing, then every key other than "~id" and "~label" is set
     on it as a property. Vertices are chained onto one traversal which is submitted
     after every 100th vertex and after the final vertex.
     :param g: The graph traversal source
     :param vertices: A list of dictionaries for each vertex; the dictionaries are
         not modified
     :param scan_id: Suffix appended to each "~id" value to form the vertex id
     :return: None
     :raises NeptuneLoadGraphException: if submitting a batch of vertices fails
     """
     # Inclusive range of a Java Long (signed 64-bit).
     java_long_min = -(2**63)
     java_long_max = 2**63 - 1

     total = len(vertices)
     t = g
     for cnt, r in enumerate(vertices, start=1):
         vertex_id = f'{r["~id"]}_{scan_id}'
         # Get-or-create the vertex so re-writing an existing id does not fail.
         t = (
             t.V(vertex_id)
             .fold()
             .coalesce(
                 __.unfold(),
                 __.addV(self.parse_arn(r["~label"])["resource"]).property(T.id, vertex_id),
             )
         )
         for k, v in r.items():
             # "~id" and "~label" were consumed above; they are not properties.
             if k in ("~id", "~label"):
                 continue
             # Need to handle numbers that are bigger than a Long in Java, for now we
             # stringify it. A local value is used so the caller's dict is left untouched.
             if isinstance(v, int) and not java_long_min <= v <= java_long_max:
                 v = str(v)
             t = t.property(k, v)

         # Only flush on a full batch of 100 or on the very last vertex.
         if cnt % 100 == 0 or cnt == total:
             try:
                 self.logger.info(
                     event=LogEvent.NeptunePeriodicWrite,
                     msg=f"Writing vertices {cnt} of {total}",
                 )
                 t.next()
                 # Start the next batch from the bare traversal source.
                 t = g
             except Exception as err:
                 self.logger.error(event=LogEvent.NeptuneLoadError, msg=str(err))
                 raise NeptuneLoadGraphException(
                     f"Error loading vertex {r} " f"with {str(t.bytecode)}"
                 ) from err
コード例 #3
0
    def load_graph(self, bucket: str, key: str, load_iam_role_arn: str) -> GraphMetadata:
        """Load a graph into Neptune.

        Reads the name, version, start_time and end_time tags from the S3 object,
        submits a load request for it to the Neptune loader endpoint, polls the load
        status every 10 seconds until it reports LOAD_COMPLETED, and then passes the
        graph metadata to self._register_graph.

        Args:
             bucket: s3 bucket of graph rdf
             key: s3 key of graph rdf
             load_iam_role_arn: arn of iam role used to load the graph

        Returns:
            GraphMetadata object describing loaded graph

        Raises:
            NeptuneLoadGraphException if errors occur during graph load
        """
        # Seconds to wait for a connection / for response data on each HTTP call to the
        # loader, so that a stalled connection cannot block this method forever.
        request_timeout_secs = 60

        region = self._neptune_endpoint.region
        session = boto3.Session(region_name=region)
        s3_client = session.client("s3")
        rdf_object_tagging = s3_client.get_object_tagging(Bucket=bucket, Key=key)
        tag_set = rdf_object_tagging["TagSet"]
        graph_name = get_required_tag_value(tag_set, "name")
        graph_version = get_required_tag_value(tag_set, "version")
        graph_start_time = int(get_required_tag_value(tag_set, "start_time"))
        graph_end_time = int(get_required_tag_value(tag_set, "end_time"))
        graph_metadata = GraphMetadata(
            uri=f"{GRAPH_BASE_URI}/{graph_name}/{graph_version}/{graph_end_time}",
            name=graph_name,
            version=graph_version,
            start_time=graph_start_time,
            end_time=graph_end_time,
        )
        logger = self.logger
        endpoint_str = self._neptune_endpoint.get_endpoint_str()
        with logger.bind(
            rdf_bucket=bucket,
            rdf_key=key,
            graph_uri=graph_metadata.uri,
            neptune_endpoint=endpoint_str,
        ):
            # Reuse the session created above rather than building an identical one.
            credentials = session.get_credentials()
            auth = AWSRequestsAuth(
                aws_access_key=credentials.access_key,
                aws_secret_access_key=credentials.secret_key,
                aws_token=credentials.token,
                aws_host=endpoint_str,
                aws_region=region,
                aws_service="neptune-db",
            )
            post_body = {
                "source": f"s3://{bucket}/{key}",
                "format": "rdfxml",
                "iamRoleArn": load_iam_role_arn,
                "region": region,
                "failOnError": "TRUE",
                "parallelism": "MEDIUM",
                "parserConfiguration": {
                    "baseUri": GRAPH_BASE_URI,
                    "namedGraphUri": graph_metadata.uri,
                },
            }
            loader_endpoint = self._neptune_endpoint.get_loader_endpoint()
            logger.info(event=LogEvent.NeptuneLoadStart, post_body=post_body)
            submit_resp = requests.post(
                loader_endpoint,
                json=post_body,
                auth=auth,
                timeout=request_timeout_secs,
            )
            if submit_resp.status_code != 200:
                raise NeptuneLoadGraphException(
                    f"Non 200 from Neptune: {submit_resp.status_code} : {submit_resp.text}"
                )
            submit_resp_json = submit_resp.json()
            load_id = submit_resp_json["payload"]["loadId"]
            with logger.bind(load_id=load_id):
                logger.info(event=LogEvent.NeptuneLoadPolling)
                # Poll until the loader reports completion or a status that is neither
                # "not started" nor "in progress" (treated as a failure).
                while True:
                    time.sleep(10)
                    status_resp = requests.get(
                        f"{loader_endpoint}/{load_id}",
                        params={"details": "true", "errors": "true"},
                        auth=auth,
                        timeout=request_timeout_secs,
                    )
                    if status_resp.status_code != 200:
                        raise NeptuneLoadGraphException(
                            f"Non 200 from Neptune: {status_resp.status_code} : {status_resp.text}"
                        )
                    status_resp_json = status_resp.json()
                    status = status_resp_json["payload"]["overallStatus"]["status"]
                    logger.info(event=LogEvent.NeptuneLoadPolling, status=status)
                    if status == "LOAD_COMPLETED":
                        break
                    if status not in ("LOAD_NOT_STARTED", "LOAD_IN_PROGRESS"):
                        logger.error(event=LogEvent.NeptuneLoadError, status=status)
                        raise NeptuneLoadGraphException(f"Error loading graph: {status_resp_json}")
                logger.info(event=LogEvent.NeptuneLoadEnd)

                logger.info(event=LogEvent.MetadataGraphUpdateStart)
                self._register_graph(graph_metadata=graph_metadata)
                logger.info(event=LogEvent.MetadataGraphUpdateEnd)

                return graph_metadata