def __init__(self, endpoints=None):
    if endpoints is None:
        self.endpoints = Endpoints()
    else:
        self.endpoints = endpoints
    self.connections = []
class GremlinUtils:

    @classmethod
    def init_statics(cls, globals):
        statics.load_statics(globals)

        del globals['range']
        del globals['map']
        del globals['min']
        del globals['sum']
        del globals['property']
        del globals['max']

    def __init__(self, endpoints=None):
        if endpoints is None:
            self.endpoints = Endpoints()
        else:
            self.endpoints = endpoints

    def remote_connection(self, show_endpoint=True):
        neptune_gremlin_endpoint = self.endpoints.gremlin_endpoint()
        if show_endpoint:
            print('gremlin: {}'.format(neptune_gremlin_endpoint))
        retry_count = 0
        while True:
            try:
                return DriverRemoteConnection(neptune_gremlin_endpoint, 'g')
            except HTTPError as e:
                exc_info = sys.exc_info()
                if retry_count < 3:
                    retry_count += 1
                    print('Connection timeout. Retrying...')
                else:
                    raise exc_info[0].with_traceback(exc_info[1], exc_info[2])

    def traversal_source(self, show_endpoint=True, connection=None):
        if connection is None:
            connection = self.remote_connection(show_endpoint)
        return traversal().withRemote(connection)

    def client(self, pool_size=None, max_workers=None):
        return Client(self.endpoints.gremlin_endpoint(), 'g', pool_size=pool_size, max_workers=max_workers)

    def sessioned_client(self, session_id=None, pool_size=None, max_workers=None):
        return SessionedClient(
            self.endpoints.gremlin_endpoint(),
            'g',
            uuid.uuid4().hex if session_id is None else session_id,
            pool_size=pool_size,
            max_workers=max_workers)
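# Usage sketch for the class above (illustrative, not part of the library;
# the cluster endpoint below is a placeholder):
#
# >>> GremlinUtils.init_statics(globals())
# >>> gremlin_utils = GremlinUtils(Endpoints(neptune_endpoint='my-neptune-cluster.us-east-1.neptune.amazonaws.com'))
# >>> conn = gremlin_utils.remote_connection()
# >>> g = gremlin_utils.traversal_source(connection=conn)
# >>> g.V().limit(1).valueMap(True).toList()
# >>> conn.close()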
class BulkLoad:

    def __init__(self, source, format='csv', role=None, region=None, endpoints=None):
        self.source = source
        self.format = format
        if role is None:
            assert ('NEPTUNE_LOAD_FROM_S3_ROLE_ARN' in os.environ), 'role is missing.'
            self.role = os.environ['NEPTUNE_LOAD_FROM_S3_ROLE_ARN']
        else:
            self.role = role
        if region is None:
            assert ('AWS_REGION' in os.environ), 'region is missing.'
            self.region = os.environ['AWS_REGION']
        else:
            self.region = region
        if endpoints is None:
            self.endpoints = Endpoints()
        else:
            self.endpoints = endpoints

    def __load_from(self, source, format, role, region):
        return {
            'source': source,
            'format': format,
            'iamRoleArn': role,
            'region': region,
            'failOnError': 'FALSE'
        }

    def __load(self, loader_url, data):
        jsondataasbytes = json.dumps(data).encode('utf8')
        req = urllib.request.Request(loader_url, data=jsondataasbytes, headers={'Content-Type': 'application/json'})
        response = urllib.request.urlopen(req)
        jsonresponse = json.loads(response.read().decode('utf8'))
        return jsonresponse['payload']['loadId']

    def load_async(self):
        localised_source = self.source.replace('${AWS_REGION}', self.region)
        loader_url = self.endpoints.loader_endpoint()
        json_payload = self.__load_from(localised_source, self.format, self.role, self.region)
        print('''curl -X POST \\
    -H 'Content-Type: application/json' \\
    {} -d \'{}\''''.format(loader_url, json.dumps(json_payload, indent=4)))
        load_id = self.__load(loader_url, json_payload)
        return BulkLoadStatus(self.endpoints.load_status_endpoint(load_id))

    def load(self, interval=2):
        status = self.load_async()
        print('status_uri: {}'.format(status.uri()))
        status.wait(interval)
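# Usage sketch for the class above (illustrative; the S3 URI and role ARN are
# placeholders — load() blocks and polls, load_async() returns a BulkLoadStatus):
#
# >>> bulkload = BulkLoad('s3://my-bucket/neptune-data/',
# ...                     role='arn:aws:iam::111122223333:role/NeptuneLoadFromS3')
# >>> bulkload.load(interval=5)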
def neptune_endpoints(self, connection_name):
    """Gets Neptune endpoint information from the AWS Glue Data Catalog.

    You may need to install a Glue VPC Endpoint in your VPC for this method to work.

    You can create a Glue Connection of type 'JDBC' or 'NETWORK'. For a 'JDBC' connection,
    store the Amazon Neptune endpoint in the 'JDBC_CONNECTION_URL' field, e.g.
    'jdbc:wss://my-neptune-cluster.us-east-1.neptune.amazonaws.com:8182/gremlin'.
    For a 'NETWORK' connection, store the Amazon Neptune endpoint in the 'Description' field, e.g.
    'wss://my-neptune-cluster.us-east-1.neptune.amazonaws.com:8182/gremlin'.

    This method returns the Neptune endpoint, e.g.
    'wss://my-neptune-cluster.us-east-1.neptune.amazonaws.com:8182/gremlin'.

    Example:
    >>> gremlin_endpoint = GlueNeptuneConnectionInfo(glueContext).neptune_endpoints('neptune')
    """
    glue = boto3.client('glue', region_name=self.region)
    connection = glue.get_connection(Name=connection_name)['Connection']
    if connection['ConnectionType'] == "JDBC":
        # Strip the leading 'jdbc:' prefix
        neptune_uri = connection['ConnectionProperties']['JDBC_CONNECTION_URL'][5:]
    elif connection['ConnectionType'] == "NETWORK":
        neptune_uri = connection['Description']
    parse_result = requests.utils.urlparse(neptune_uri)
    netloc_parts = parse_result.netloc.split(':')
    host = netloc_parts[0]
    port = netloc_parts[1]
    return Endpoints(neptune_endpoint=host, neptune_port=port, region_name=self.region, role_arn=self.role_arn)
def neptune_endpoints(self, connection_name):
    """Gets Neptune endpoint information from the Glue Data Catalog.

    You may need to install a Glue VPC Endpoint in your VPC for this method to work.

    You can store Neptune endpoint information as JDBC connections in the Glue Data Catalog.
    JDBC connection strings must begin with 'jdbc:'. To store a Neptune endpoint, use the
    following format:

    'jdbc:<protocol>://<dns_name>:<port>/<endpoint>'

    For example, if you store:
    'jdbc:wss://my-neptune-cluster.us-east-1.neptune.amazonaws.com:8182/gremlin'
    this method will return:
    'wss://my-neptune-cluster.us-east-1.neptune.amazonaws.com:8182/gremlin'

    Example:
    >>> gremlin_endpoint = GlueNeptuneConnectionInfo(glueContext).neptune_endpoints('neptune')
    """
    glue = boto3.client('glue', region_name=self.region)
    connection = glue.get_connection(Name=connection_name)
    # Strip the leading 'jdbc:' prefix
    neptune_uri = connection['Connection']['ConnectionProperties']['JDBC_CONNECTION_URL'][5:]
    parse_result = urlparse(neptune_uri)
    netloc_parts = parse_result.netloc.split(':')
    host = netloc_parts[0]
    port = netloc_parts[1]
    return Endpoints(neptune_endpoint=host, neptune_port=port, region_name=self.region, role_arn=self.role_arn)
def remoteConnection(self, neptune_endpoint=None, neptune_port=None, show_endpoint=True):
    connection = GremlinUtils(Endpoints(neptune_endpoint, neptune_port)).remote_connection(show_endpoint)
    self.connections.append(connection)
    return connection
def __init__(self, neptune_endpoint, elasticache_endpoint):
    GremlinUtils.init_statics(globals())
    gremlin_utils = GremlinUtils(Endpoints(neptune_endpoint=neptune_endpoint))
    self.vertext_metrics = VertexMetrics(elasticache_endpoint)
    self.neptune_connection = gremlin_utils.remote_connection()
    self.g = gremlin_utils.traversal_source(connection=self.neptune_connection)
def __init__(
        self,
        source,
        format='csv',
        role=None,
        mode='AUTO',
        region=None,
        fail_on_error=False,
        parallelism='OVERSUBSCRIBE',
        base_uri='http://aws.amazon.com/neptune/default',
        named_graph_uri='http://aws.amazon.com/neptune/vocab/v01/DefaultNamedGraph',
        update_single_cardinality_properties=False,
        endpoints=None):
    self.source = source
    self.format = format
    if role is None:
        assert ('NEPTUNE_LOAD_FROM_S3_ROLE_ARN' in os.environ), 'role is missing.'
        self.role = os.environ['NEPTUNE_LOAD_FROM_S3_ROLE_ARN']
    else:
        self.role = role
    self.mode = mode
    if region is None:
        assert ('AWS_REGION' in os.environ), 'region is missing.'
        self.region = os.environ['AWS_REGION']
    else:
        self.region = region
    if endpoints is None:
        self.endpoints = Endpoints()
    else:
        self.endpoints = endpoints
    self.fail_on_error = 'TRUE' if fail_on_error else 'FALSE'
    self.parallelism = parallelism
    self.base_uri = base_uri
    self.named_graph_uri = named_graph_uri
    self.update_single_cardinality_properties = 'TRUE' if update_single_cardinality_properties else 'FALSE'
def handle_records(self, stream_log):
    params = json.loads(os.environ['AdditionalParams'])
    neptune_endpoint = params['neptune_cluster_endpoint']
    neptune_port = params['neptune_port']

    GremlinUtils.init_statics(globals())

    endpoints = Endpoints(neptune_endpoint=neptune_endpoint, neptune_port=neptune_port)
    gremlin_utils = GremlinUtils(endpoints)
    conn = gremlin_utils.remote_connection()
    g = gremlin_utils.traversal_source(connection=conn)

    records = stream_log[RECORDS_STR]

    last_op_num = None
    last_commit_num = None
    count = 0

    try:
        for record in records:
            # Process record
            op = record[OPERATION_STR]
            data = record[DATA_STR]
            record_type = data['type']  # avoid shadowing the built-in type()
            record_id = data['id']      # avoid shadowing the built-in id()
            if op == ADD_OPERATION:
                if record_type == 'vl':
                    logger.info(g.V(record_id).valueMap(True).toList())
                if record_type == 'e':
                    logger.info(g.E(record_id).valueMap(True).toList())

            # Update local checkpoint info
            last_op_num = record[EVENT_ID_STR][OP_NUM_STR]
            last_commit_num = record[EVENT_ID_STR][COMMIT_NUM_STR]
            count += 1
    except Exception as e:
        logger.error('Error occurred - {}'.format(str(e)))
        raise e
    finally:
        try:
            # Close the connection once before reporting checkpoint progress
            conn.close()
            yield HandlerResponse(last_op_num, last_commit_num, count)
        except Exception as e:
            logger.error('Error occurred - {}'.format(str(e)))
            raise e
def graphTraversal(self, neptune_endpoint=None, neptune_port=None, show_endpoint=True, connection=None):
    if connection is None:
        # remoteConnection() already registers the new connection in self.connections
        connection = self.remoteConnection(neptune_endpoint, neptune_port, show_endpoint)
    else:
        # Track externally supplied connections so close() can clean them up
        self.connections.append(connection)
    return GremlinUtils(Endpoints(neptune_endpoint, neptune_port)).traversal_source(show_endpoint, connection)
def __init__(self, source, format='csv', role=None, region=None, endpoints=None):
    self.source = source
    self.format = format
    if role is None:
        assert ('NEPTUNE_LOAD_FROM_S3_ROLE_ARN' in os.environ), 'role is missing.'
        self.role = os.environ['NEPTUNE_LOAD_FROM_S3_ROLE_ARN']
    else:
        self.role = role
    if region is None:
        assert ('AWS_REGION' in os.environ), 'region is missing.'
        self.region = os.environ['AWS_REGION']
    else:
        self.region = region
    if endpoints is None:
        self.endpoints = Endpoints()
    else:
        self.endpoints = endpoints
def bulkLoadAsync(self, source, format='csv', role=None, region=None, neptune_endpoint=None, neptune_port=None):
    bulkload = BulkLoad(source, format, role, region=region, endpoints=Endpoints(neptune_endpoint, neptune_port))
    return bulkload.load_async()
def get_neptune_graph_traversal_source_factory(
        *, neptune_url: Union[str, Mapping[str, Any]],
        session: boto3.session.Session) -> Callable[[], GraphTraversalSource]:
    endpoints: Endpoints
    override_uri: Optional[str]
    if isinstance(neptune_url, str):
        uri = urlsplit(neptune_url)
        assert uri.scheme in ('wss', 'ws') and uri.path == '/gremlin' and not uri.query and not uri.fragment, \
            f'expected Neptune URL not {neptune_url}'
        endpoints = Endpoints(neptune_endpoint=uri.hostname, neptune_port=uri.port,
                              region_name=session.region_name, credentials=session.get_credentials())
        override_uri = None
    elif isinstance(neptune_url, Mapping):
        endpoints = Endpoints(neptune_endpoint=neptune_url['neptune_endpoint'],
                              neptune_port=neptune_url['neptune_port'],
                              region_name=session.region_name, credentials=session.get_credentials())
        override_uri = neptune_url['uri']
        assert override_uri is None or isinstance(override_uri, str)
    else:
        raise AssertionError(f'what is NEPTUNE_URL? {neptune_url}')

    def create_graph_traversal_source(**kwargs: Any) -> GraphTraversalSource:
        assert all(e not in kwargs for e in ('url', 'traversal_source')), \
            f'do not pass traversal_source or url in {kwargs}'
        prepared_request = override_prepared_request_parameters(
            endpoints.gremlin_endpoint().prepare_request(), override_uri=override_uri)
        kwargs['traversal_source'] = 'g'
        remote_connection = DriverRemoteConnection(url=prepared_request, **kwargs)
        return traversal().withRemote(remote_connection)

    return create_graph_traversal_source
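# Usage sketch for the factory above (illustrative; the URL and region are placeholders):
#
# >>> factory = get_neptune_graph_traversal_source_factory(
# ...     neptune_url='wss://my-neptune-cluster.us-east-1.neptune.amazonaws.com:8182/gremlin',
# ...     session=boto3.session.Session(region_name='us-east-1'))
# >>> g = factory()
# >>> g.V().limit(1).toList()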
def bulkLoad(self, source, format='csv', role=None, region=None, neptune_endpoint=None, neptune_port=None, interval=2):
    bulkload = BulkLoad(source, format, role, region=region, endpoints=Endpoints(neptune_endpoint, neptune_port))
    bulkload.load(interval)
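# Usage sketch for the bulkLoad/bulkLoadAsync helpers above (illustrative;
# 'utils' stands in for an instance of the enclosing class, and the S3 URI,
# role ARN, and endpoint are placeholders):
#
# >>> utils.bulkLoad('s3://my-bucket/neptune-data/',
# ...                role='arn:aws:iam::111122223333:role/NeptuneLoadFromS3',
# ...                neptune_endpoint='my-neptune-cluster.us-east-1.neptune.amazonaws.com')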
def neptune_endpoints(self, connection_name):
    """Gets Neptune endpoint information from the Glue Data Catalog.

    You may need to install a Glue VPC Endpoint in your VPC for this method to work.

    You can store Neptune endpoint information as JDBC connections in the Glue Data Catalog.
    JDBC connection strings must begin with 'jdbc:'. To store a Neptune endpoint, use the
    following format:

    'jdbc:<protocol>://<dns_name>:<port>/<endpoint>'

    For example, if you store:
    'jdbc:wss://my-neptune-cluster.us-east-1.neptune.amazonaws.com:8182/gremlin'
    this method will return:
    'wss://my-neptune-cluster.us-east-1.neptune.amazonaws.com:8182/gremlin'

    Example:
    >>> gremlin_endpoint = GlueNeptuneConnectionInfo(glueContext).neptune_endpoints('neptune')
    """
    glue = boto3.client('glue', region_name=self.region)
    connection = glue.get_connection(Name=connection_name)
    # Strip the leading 'jdbc:' prefix
    neptune_uri = connection['Connection']['ConnectionProperties']['JDBC_CONNECTION_URL'][5:]
    parse_result = urlparse(neptune_uri)
    netloc_parts = parse_result.netloc.split(':')
    host = netloc_parts[0]
    port = netloc_parts[1]

    sts = boto3.client('sts', region_name=self.region)
    role = sts.assume_role(
        RoleArn=self.role_arn,
        RoleSessionName=uuid.uuid4().hex,
        DurationSeconds=3600
    )
    credentials = Credentials(
        access_key=role['Credentials']['AccessKeyId'],
        secret_key=role['Credentials']['SecretAccessKey'],
        token=role['Credentials']['SessionToken'])

    return Endpoints(neptune_endpoint=host, neptune_port=port, region_name=self.region, credentials=credentials)
modelS3Url = urlparse(environ['MODEL_PACKAGE'], allow_fragments=False)
originModelArtifact = f's3:/{modelS3Url.path}'
graphDataUrl = urlparse(environ['GRAPH_DATA_PATH'], allow_fragments=False)
graphDataPath = f's3:/{graphDataUrl.path}/graph/'
targetDataPath = f"{args.data_prefix}/{environ['JOB_NAME']}"
tempFolder = args.temp_folder

dataArgs = (originModelArtifact, graphDataPath, targetDataPath, tempFolder)
prepareDataCmd = Path(os.path.abspath(__file__)).parent.joinpath('prepare-data.sh')
logger.info(f"| {prepareDataCmd} {' '.join(dataArgs)}")
subprocess.check_call([prepareDataCmd] + list(dataArgs))
logger.info('Prepared graph data for bulk load...')

endpoints = Endpoints(neptune_endpoint=args.neptune_endpoint,
                      neptune_port=args.neptune_port,
                      region_name=args.region)
logger.info(f'Created Neptune endpoint {endpoints.gremlin_endpoint()}.')
bulkload = BulkLoad(
    source=targetDataPath,
    endpoints=endpoints,
    role=args.neptune_iam_role_arn,
    region=args.region,
    update_single_cardinality_properties=True,
    fail_on_error=True)
load_status = bulkload.load_async()
logger.info(f'Bulk load request from {targetDataPath} is submitted.')
status, json = load_status.status(details=True, errors=True)
sc.setLogLevel("INFO") glueContext = GlueContext(sc) logger = glueContext.get_logger() logger.info(f'Before resolving options...') args = getResolvedOptions(sys.argv, [ 'database', 'transaction_table', 'identity_table', 'id_cols', 'cat_cols', 'output_prefix', 'region', 'neptune_endpoint', 'neptune_port' ]) logger.info(f'Resolved options are: {args}') GremlinUtils.init_statics(globals()) endpoints = Endpoints(neptune_endpoint=args['neptune_endpoint'], neptune_port=args['neptune_port'], region_name=args['region']) logger.info( f'Initializing gremlin client to Neptune ${endpoints.gremlin_endpoint()}.') gremlin_client = GlueGremlinClient(endpoints) TRANSACTION_ID = 'TransactionID' transactions = glueContext.create_dynamic_frame.from_catalog( database=args['database'], table_name=args['transaction_table']) identities = glueContext.create_dynamic_frame.from_catalog( database=args['database'], table_name=args['identity_table']) s3 = boto3.resource('s3', region_name=args['region']) train_data_ratio = 0.8
class GremlinUtils:

    @classmethod
    def init_statics(cls, globals):
        statics.load_statics(globals)

        del globals['range']
        del globals['map']
        del globals['min']
        del globals['sum']
        del globals['property']
        del globals['max']

    def __init__(self, endpoints=None):
        if endpoints is None:
            self.endpoints = Endpoints()
        else:
            self.endpoints = endpoints
        self.connections = []

    def close(self):
        for connection in self.connections:
            connection.close()

    def remote_connection(self, show_endpoint=False, protocol_factory=None,
                          transport_factory=lambda: TornadoTransportProxy(),
                          pool_size=None, max_workers=None,
                          message_serializer=None, graphson_reader=None, graphson_writer=None):
        gremlin_endpoint = self.endpoints.gremlin_endpoint()
        if show_endpoint:
            print('gremlin: {}'.format(gremlin_endpoint))
        retry_count = 0
        while True:
            try:
                request_parameters = gremlin_endpoint.prepare_request()
                signed_ws_request = httpclient.HTTPRequest(
                    request_parameters.uri, headers=request_parameters.headers)
                connection = DriverRemoteConnection(
                    signed_ws_request,
                    'g',
                    protocol_factory=protocol_factory,
                    transport_factory=transport_factory,
                    pool_size=pool_size,
                    max_workers=max_workers,
                    message_serializer=message_serializer,
                    graphson_reader=graphson_reader,
                    graphson_writer=graphson_writer)
                self.connections.append(connection)
                return connection
            except HTTPError as e:
                exc_info = sys.exc_info()
                if retry_count < 3:
                    retry_count += 1
                    print('Connection timeout. Retrying...')
                else:
                    raise exc_info[0].with_traceback(exc_info[1], exc_info[2])

    def traversal_source(self, show_endpoint=True, connection=None):
        if connection is None:
            connection = self.remote_connection(show_endpoint)
        return traversal().withRemote(connection)

    def client(self, pool_size=None, max_workers=None):
        gremlin_endpoint = self.endpoints.gremlin_endpoint()
        request_parameters = gremlin_endpoint.prepare_request()
        signed_ws_request = httpclient.HTTPRequest(
            request_parameters.uri, headers=request_parameters.headers)
        return Client(signed_ws_request, 'g', pool_size=pool_size, max_workers=max_workers)

    def sessioned_client(self, session_id=None, pool_size=None, max_workers=None):
        gremlin_endpoint = self.endpoints.gremlin_endpoint()
        request_parameters = gremlin_endpoint.prepare_request()
        signed_ws_request = httpclient.HTTPRequest(
            request_parameters.uri, headers=request_parameters.headers)
        return SessionedClient(
            signed_ws_request,
            'g',
            uuid.uuid4().hex if session_id is None else session_id,
            pool_size=pool_size,
            max_workers=max_workers)
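# Usage sketch for the signing variant above (illustrative; unlike the simpler
# GremlinUtils, this one SigV4-signs the websocket request, so Endpoints should
# carry region and credentials — the endpoint and region are placeholders):
#
# >>> endpoints = Endpoints(neptune_endpoint='my-neptune-cluster.us-east-1.neptune.amazonaws.com',
# ...                       neptune_port=8182, region_name='us-east-1')
# >>> gremlin_utils = GremlinUtils(endpoints)
# >>> g = gremlin_utils.traversal_source(show_endpoint=False)
# >>> g.V().count().next()
# >>> gremlin_utils.close()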
class NeptuneGremlinProxy(AbstractGremlinProxy):
    """
    A proxy to Neptune using the Gremlin protocol.

    See also https://docs.aws.amazon.com/neptune/latest/userguide/access-graph-gremlin-differences.html
    See also https://docs.aws.amazon.com/neptune/latest/userguide/access-graph-gremlin-sessions.html

    TODO: HTTP proxy support. This does *NOT* support HTTP proxies as-is. Why? The default
    transport factory in gremlin_python is tornado.websocket, which is hardcoded to use
    simple_httpclient (look at WebSocketClientConnection). But even if that could be made to use
    curl_httpclient, curl_httpclient requires pycurl, which requires libcurl and other native
    libraries that are a pain to install.
    """

    def __init__(self, *, host: str, port: Optional[int] = None, user: str = None,
                 password: Optional[Union[str, boto3.session.Session]] = None,
                 driver_remote_connection_options: Mapping[str, Any] = {},
                 client_kwargs: Dict = dict(),
                 **kwargs: dict) -> None:
        driver_remote_connection_options = dict(driver_remote_connection_options)

        # port should be part of that url
        if port is not None:
            raise NotImplementedError(f'port is not allowed! port={port}')

        # for IAM auth, we need the triplet or a Session which is more general
        if isinstance(password, boto3.session.Session):
            session = password
            self.aws_auth = AssumeRoleAWS4Auth(session.get_credentials(), session.region_name, 'neptune-db')
        else:
            raise NotImplementedError('to use authentication, pass a boto3.session.Session!')

        if isinstance(host, str):
            # usually a wss URI
            url = urlsplit(host)
            assert url.scheme in ('wss', 'ws') and url.path == '/gremlin' and not url.query and not url.fragment, \
                f'url is not a Neptune ws url?: {host}'
            self.endpoints = Endpoints(
                neptune_endpoint=url.hostname, neptune_port=url.port,
                region_name=session.region_name, credentials=session.get_credentials())
            self.override_uri = None
        elif isinstance(host, Mapping):
            # ...but development is a little complicated
            assert all(k in host for k in ('neptune_endpoint', 'neptune_port', 'uri')), \
                f'pass a dict with neptune_endpoint, neptune_port, and uri not: {host}'
            self.endpoints = Endpoints(
                neptune_endpoint=host['neptune_endpoint'], neptune_port=int(host['neptune_port']),
                region_name=session.region_name, credentials=session.get_credentials())
            uri = urlsplit(host['uri'])
            assert uri.scheme in ('wss', 'ws') and uri.path == '/gremlin' and not uri.query and not uri.fragment, \
                f'''url is not a Neptune ws url?: {host['uri']}'''
            self.override_uri = uri
        else:
            raise NotImplementedError('to use authentication, pass a Mapping with aws_access_key_id, '
                                      'aws_secret_access_key, service_region!')

        # always g for Neptune
        driver_remote_connection_options.update(traversal_source='g')

        try:
            s3_bucket_name = client_kwargs['neptune_bulk_loader_s3_bucket_name']
        except Exception:
            raise NotImplementedError('Cannot find s3 bucket name!')

        # Instantiate bulk loader and graph traversal factory
        bulk_loader_config: Dict[str, Any] = dict(
            NEPTUNE_SESSION=password, NEPTUNE_URL=host,
            NEPTUNE_BULK_LOADER_S3_BUCKET_NAME=s3_bucket_name)
        self.neptune_bulk_loader_api = NeptuneBulkLoaderApi.create_from_config(bulk_loader_config)
        self.neptune_graph_traversal_source_factory = get_neptune_graph_traversal_source_factory(
            session=password, neptune_url=host)

        AbstractGremlinProxy.__init__(self, key_property_name='key',
                                      driver_remote_connection_options=driver_remote_connection_options)

    @classmethod
    @overrides
    def script_translator(cls) -> Type[ScriptTranslatorTargetNeptune]:
        return ScriptTranslatorTargetNeptune

    def override_prepared_request_parameters(
            self,
            request_parameters: RequestParameters, *,
            method: Optional[str] = None,
            data: Optional[str] = None) -> httpclient.HTTPRequest:
        http_request_param: Dict[str, Any] = dict(url=request_parameters.uri,
                                                  headers=request_parameters.headers)
        if method is not None:
            http_request_param['method'] = method
        if data is not None:
            http_request_param['body'] = data
        if self.override_uri:
            # we override the URI slightly (because the instance thinks it's a different host
            # than we're connecting to)
            uri = urlsplit(request_parameters.uri)
            http_request_param['headers'] = dict(request_parameters.headers)
            http_request_param['headers']['Host'] = uri.netloc
            http_request_param['ssl_options'] = OverrideServerHostnameSSLContext(server_hostname=uri.hostname)
            http_request_param['url'] = urlunsplit(
                (uri.scheme, self.override_uri.netloc, uri.path, uri.query, uri.fragment))
        return httpclient.HTTPRequest(**http_request_param)

    @overrides
    def possibly_signed_ws_client_request_or_url(self) -> Union[httpclient.HTTPRequest, str]:
        return self.override_prepared_request_parameters(self.endpoints.gremlin_endpoint().prepare_request())

    @classmethod
    @overrides
    def _is_retryable_exception(cls, *, method_name: str, exception: Exception) -> bool:
        # any method
        return _is_neptune_retryable_exception(exception) or isinstance(exception, ConnectionError)

    def is_healthy(self) -> None:
        signed_request = self.override_prepared_request_parameters(
            self.endpoints.status_endpoint().prepare_request())
        http_client = httpclient.HTTPClient()
        # this will throw if the instance is really borked, or we can't connect, or we're not
        # allowed (see https://docs.aws.amazon.com/neptune/latest/userguide/access-graph-status.html )
        response = http_client.fetch(signed_request)
        status = json.loads(response.body)
        if status.get('status') == 'healthy' and status.get('role') == 'writer':
            LOGGER.debug(f'status is healthy: {status}')
        else:
            # we'll log in healthcheck
            raise RuntimeError(f'status is unhealthy: {status}')

    def _non_standard_endpoint(self, scheme: str, path: str) -> Endpoint:
        return self.endpoints._Endpoints__endpoint(
            scheme, self.endpoints.neptune_endpoint, self.endpoints.neptune_port, path)

    def _gremlin_status(self, query_id: Optional[str] = None, include_waiting: bool = False) -> str:
        """
        https://docs.aws.amazon.com/neptune/latest/userguide/gremlin-api-status.html
        """
        endpoint = self._non_standard_endpoint('https', 'gremlin/status')
        query_parameters = {}
        if query_id is not None:
            query_parameters['queryId'] = query_id
        if include_waiting:
            query_parameters['includeWaiting'] = 'true'
        signed_request = self.override_prepared_request_parameters(
            endpoint.prepare_request(querystring=query_parameters))
        http_client = httpclient.HTTPClient()
        response = http_client.fetch(signed_request)
        return json.loads(response.body)

    def _sparql_status(self, query_id: Optional[str] = None) -> str:
        """
        https://docs.aws.amazon.com/neptune/latest/userguide/sparql-api-status.html
        """
        endpoint = self._non_standard_endpoint('https', 'sparql/status')
        query_parameters = {}
        if query_id is not None:
            query_parameters['queryId'] = query_id
        signed_request = self.override_prepared_request_parameters(
            endpoint.prepare_request(querystring=query_parameters))
        http_client = httpclient.HTTPClient()
        response = http_client.fetch(signed_request)
        return json.loads(response.body)

    def _explain(self, gremlin_query: str) -> str:
        """
        Return the Neptune-specific explanation of the query.

        See https://docs.aws.amazon.com/neptune/latest/userguide/gremlin-explain-api.html
        See https://docs.aws.amazon.com/neptune/latest/userguide/gremlin-explain-background.html
        """
        # why not use endpoints? Despite the fact that it accepts a method and payload, it doesn't
        # *actually* generate sufficient headers, so we'll use requests for these since we can
        url = urlsplit(self.endpoints.gremlin_endpoint().prepare_request().uri)
        assert url.scheme in ('wss', 'ws') and url.path == '/gremlin' and not url.query and not url.fragment, \
            f'url is not a Neptune ws url?: {url}'
        _explain_url = urlunsplit(
            ('https' if url.scheme == 'wss' else 'http', url.netloc, url.path + '/explain', '', ''))
        host = to_aws4_request_compatible_host(_explain_url)
        if self.override_uri:
            _explain_url = urlunsplit(
                ('https' if url.scheme == 'wss' else 'http', self.override_uri.netloc, url.path + '/explain', '', ''))
        s = requests.Session()
        s.mount('https://', HostHeaderSSLAdapter())
        response = s.post(_explain_url, auth=self.aws_auth,
                          data=json.dumps(dict(gremlin=gremlin_query)).encode('utf-8'),
                          # include Host header
                          headers=dict(Host=host))
        return response.content.decode('utf-8')

    def _profile(self, gremlin_query: str) -> str:
        """
        Return the Neptune-specific explanation of the *running* query. Note that it can't return
        the result set, so its utility is limited to cases where you can re-run the query, are
        running it as a one-off from the console, or as a last resort.

        See https://docs.aws.amazon.com/neptune/latest/userguide/gremlin-profile-api.html
        See https://docs.aws.amazon.com/neptune/latest/userguide/gremlin-explain-background.html
        """
        # why not use endpoints? Despite the fact that it accepts a method and payload, it doesn't
        # *actually* generate sufficient headers, so we'll use requests for these since we can
        url = urlsplit(self.endpoints.gremlin_endpoint().prepare_request().uri)
        assert url.scheme in ('wss', 'ws') and url.path == '/gremlin' and not url.query and not url.fragment, \
            f'url is not a Neptune ws url?: {url}'
        _profile_url = urlunsplit(
            ('https' if url.scheme == 'wss' else 'http', url.netloc, url.path + '/profile', '', ''))
        host = to_aws4_request_compatible_host(_profile_url)
        if self.override_uri:
            _profile_url = urlunsplit(
                ('https' if url.scheme == 'wss' else 'http', self.override_uri.netloc, url.path + '/profile', '', ''))
        s = requests.Session()
        s.mount('https://', HostHeaderSSLAdapter())
        response = s.post(_profile_url, auth=self.aws_auth,
                          data=json.dumps(dict(gremlin=gremlin_query)).encode('utf-8'),
                          # include Host header
                          headers=dict(Host=host))
        return response.content.decode('utf-8')

    @overrides
    def drop(self) -> None:
        test_shard = get_shard()
        g = self.g.V()
        if test_shard:
            g = g.has(WellKnownProperties.TestShard.value.name, test_shard)
        g = g.drop()
        LOGGER.warning('DROPPING ALL NODES')
        self.query_executor()(query=g, get=FromResultSet.iterate)
        # we seem to mess this up easily
        leftover = self.query_executor()(query=self.g.V().hasId(TextP.startingWith(test_shard)).id(),
                                         get=FromResultSet.toList)
        self.query_executor()(query=self.g.V().hasId(TextP.startingWith(test_shard)).drop(),
                              get=FromResultSet.iterate)
        assert not leftover, f'we have some leftover: {leftover}'
        LOGGER.warning('COMPLETED DROP OF ALL NODES')
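# Usage sketch for the proxy's status/diagnostics helpers above (illustrative;
# 'proxy' stands in for an already-constructed NeptuneGremlinProxy, and the
# caller needs IAM permissions for the status APIs):
#
# >>> proxy.is_healthy()
# >>> proxy._gremlin_status(include_waiting=True)
# >>> proxy._explain('g.V().limit(1)')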
class BulkLoad:

    def __init__(
            self,
            source,
            format='csv',
            role=None,
            mode='AUTO',
            region=None,
            fail_on_error=False,
            parallelism='OVERSUBSCRIBE',
            base_uri='http://aws.amazon.com/neptune/default',
            named_graph_uri='http://aws.amazon.com/neptune/vocab/v01/DefaultNamedGraph',
            update_single_cardinality_properties=False,
            endpoints=None):
        self.source = source
        self.format = format
        if role is None:
            assert ('NEPTUNE_LOAD_FROM_S3_ROLE_ARN' in os.environ), 'role is missing.'
            self.role = os.environ['NEPTUNE_LOAD_FROM_S3_ROLE_ARN']
        else:
            self.role = role
        self.mode = mode
        if region is None:
            assert ('AWS_REGION' in os.environ), 'region is missing.'
            self.region = os.environ['AWS_REGION']
        else:
            self.region = region
        if endpoints is None:
            self.endpoints = Endpoints()
        else:
            self.endpoints = endpoints
        self.fail_on_error = 'TRUE' if fail_on_error else 'FALSE'
        self.parallelism = parallelism
        self.base_uri = base_uri
        self.named_graph_uri = named_graph_uri
        self.update_single_cardinality_properties = 'TRUE' if update_single_cardinality_properties else 'FALSE'

    def __load_from(self, source):
        return {
            'source': source,
            'format': self.format,
            'iamRoleArn': self.role,
            'mode': self.mode,
            'region': self.region,
            'failOnError': self.fail_on_error,
            'parallelism': self.parallelism,
            'parserConfiguration': {
                'baseUri': self.base_uri,
                'namedGraphUri': self.named_graph_uri
            },
            'updateSingleCardinalityProperties': self.update_single_cardinality_properties
        }

    def __load(self, loader_endpoint, data):
        json_string = json.dumps(data)
        json_bytes = json_string.encode('utf8')
        request_parameters = loader_endpoint.prepare_request('POST', json_string)
        request_parameters.headers['Content-Type'] = 'application/json'
        req = urllib.request.Request(request_parameters.uri, data=json_bytes, headers=request_parameters.headers)
        try:
            response = urllib.request.urlopen(req)
            json_response = json.loads(response.read().decode('utf8'))
            return json_response['payload']['loadId']
        except HTTPError as e:
            exc_info = sys.exc_info()
            if e.code == 500:
                raise Exception(json.loads(e.read().decode('utf8'))) from None
            else:
                raise exc_info[0].with_traceback(exc_info[1], exc_info[2])

    def load_async(self):
        localised_source = self.source.replace('${AWS_REGION}', self.region)
        loader_endpoint = self.endpoints.loader_endpoint()
        json_payload = self.__load_from(localised_source)
        print('''curl -X POST \\
    -H 'Content-Type: application/json' \\
    {} -d \'{}\''''.format(loader_endpoint, json.dumps(json_payload, indent=4)))
        load_id = self.__load(loader_endpoint, json_payload)
        return BulkLoadStatus(self.endpoints.load_status_endpoint(load_id))

    def load(self, interval=2):
        status = self.load_async()
        print('status_uri: {}'.format(status.load_status_endpoint))
        status.wait(interval)
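# Usage sketch for the signed BulkLoad variant above (illustrative; the S3 URI
# and role ARN are placeholders — the extra loader options map straight onto the
# Neptune loader request payload):
#
# >>> bulkload = BulkLoad(source='s3://my-bucket/neptune-data/',
# ...                     role='arn:aws:iam::111122223333:role/NeptuneLoadFromS3',
# ...                     update_single_cardinality_properties=True,
# ...                     fail_on_error=True)
# >>> status = bulkload.load_async()
# >>> status.wait(5)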
def sparql_endpoint(self, neptune_endpoint=None, neptune_port=None):
    return Endpoints(neptune_endpoint, neptune_port).sparql_endpoint()
def __init__(self, endpoints=None):
    if endpoints is None:
        self.endpoints = Endpoints()
    else:
        self.endpoints = endpoints
CLUSTER_ENDPOINT = os.environ['CLUSTER_ENDPOINT']
CLUSTER_PORT = os.environ['CLUSTER_PORT']
CLUSTER_REGION = os.environ['CLUSTER_REGION']
ENDPOINT_NAME = os.environ['ENDPOINT_NAME']
MODEL_BTW = float(os.environ['MODEL_BTW'])
QUEUE_URL = os.environ['QUEUE_URL']
transactions_id_cols = os.environ['TRANSACTION_ID_COLS']
transactions_cat_cols = os.environ['TRANSACTION_CAT_COLS']
dummied_col = os.environ['DUMMIED_COL']

sqs = boto3.client('sqs')
runtime = boto3.client('runtime.sagemaker')

endpoints = Endpoints(neptune_endpoint=CLUSTER_ENDPOINT, neptune_port=CLUSTER_PORT, region_name=CLUSTER_REGION)


def load_data_from_event(input_event, transactions_id_cols, transactions_cat_cols, dummied_col):
    """Load and transform event data into the correct format for the next steps:
    subgraph loading and model inference input. Input event keys should come from
    the related dataset.

    Example:
    >>> load_data_from_event(input_event={"transaction_data": [{"TransactionID": "3163166", "V1": 1, ...}]},
    ...                      transactions_id_cols='card1,card2,...',
    ...                      transactions_cat_cols='M2_T,M3_F,M3_T,...')
    """
    TRANSACTION_ID = 'TransactionID'
    transactions_id_cols = transactions_id_cols.split(',')
    transactions_cat_cols = transactions_cat_cols.split(',')