Exemple #1
0
    def __init__(self, endpoints=None):
        """Remember which Neptune endpoints to use and start with no connections.

        Args:
            endpoints: ``Endpoints`` instance to use; when None a default
                ``Endpoints()`` is constructed instead.
        """
        self.endpoints = Endpoints() if endpoints is None else endpoints
        # No connections are tracked at construction time.
        self.connections = []
class GremlinUtils:
    """Convenience wrapper for opening Gremlin connections to a Neptune endpoint.

    The target endpoint comes from an ``Endpoints`` instance (a default
    ``Endpoints()`` is created when none is supplied).
    """

    # Names removed from the caller's namespace after loading the Gremlin
    # statics: each one would otherwise shadow the Python builtin of the
    # same name in that namespace.
    _BUILTIN_SHADOWING_STATICS = ('range', 'map', 'min', 'sum', 'property', 'max')

    @classmethod
    def init_statics(cls, globals):
        """Load the Gremlin statics into *globals*, minus builtin-shadowing names.

        Args:
            globals: Namespace dict to populate, typically the caller's
                ``globals()``.
        """
        statics.load_statics(globals)
        for name in cls._BUILTIN_SHADOWING_STATICS:
            # pop() with a default tolerates a name that was not injected,
            # where ``del`` would raise KeyError.
            globals.pop(name, None)

    def __init__(self, endpoints=None):
        """Store the endpoints to connect to.

        Args:
            endpoints: ``Endpoints`` to use; ``Endpoints()`` when None.
        """
        if endpoints is None:
            self.endpoints = Endpoints()
        else:
            self.endpoints = endpoints

    def remote_connection(self, show_endpoint=True, max_retries=3):
        """Open a ``DriverRemoteConnection`` (traversal source ``'g'``).

        Args:
            show_endpoint: Print the Gremlin endpoint before connecting.
            max_retries: Number of immediate retries after an ``HTTPError``
                (there is no delay between attempts).

        Returns:
            The new ``DriverRemoteConnection``.

        Raises:
            HTTPError: the last failure, once *max_retries* retries are used up.
        """
        neptune_gremlin_endpoint = self.endpoints.gremlin_endpoint()
        if show_endpoint:
            print('gremlin: {}'.format(neptune_gremlin_endpoint))
        retry_count = 0
        while True:
            try:
                return DriverRemoteConnection(neptune_gremlin_endpoint, 'g')
            except HTTPError:
                if retry_count >= max_retries:
                    # A bare ``raise`` re-raises the active exception with its
                    # original traceback intact.
                    raise
                retry_count += 1
                print('Connection timeout. Retrying...')

    def traversal_source(self, show_endpoint=True, connection=None):
        """Return ``traversal().withRemote(connection)``.

        Args:
            show_endpoint: Forwarded to :meth:`remote_connection` when a new
                connection has to be opened.
            connection: Existing remote connection to reuse; a new one is
                opened when None.
        """
        if connection is None:
            connection = self.remote_connection(show_endpoint)
        return traversal().withRemote(connection)

    def client(self, pool_size=None, max_workers=None):
        """Return a Gremlin ``Client`` bound to the endpoint and traversal source ``'g'``."""
        return Client(self.endpoints.gremlin_endpoint(),
                      'g',
                      pool_size=pool_size,
                      max_workers=max_workers)

    def sessioned_client(self,
                         session_id=None,
                         pool_size=None,
                         max_workers=None):
        """Return a ``SessionedClient``; a random hex session id is used when none is given."""
        return SessionedClient(
            self.endpoints.gremlin_endpoint(),
            'g',
            uuid.uuid4().hex if session_id is None else session_id,
            pool_size=pool_size,
            max_workers=max_workers)
Exemple #3
0
class BulkLoad:
    """Submit an S3 bulk-load request to a Neptune loader endpoint.

    Args:
        source: S3 URI of the data to load. A literal ``${AWS_REGION}`` in it
            is replaced with the region when the load is submitted.
        format: Loader data format; ``'csv'`` by default.
        role: IAM role ARN for the loader; falls back to the
            ``NEPTUNE_LOAD_FROM_S3_ROLE_ARN`` environment variable.
        region: AWS region; falls back to the ``AWS_REGION`` environment variable.
        endpoints: ``Endpoints`` used to build loader URLs; ``Endpoints()`` when None.

    Raises:
        AssertionError: if role or region is neither passed nor set in the
            environment.
    """

    def __init__(self, source, format='csv', role=None, region=None, endpoints=None):

        self.source = source
        self.format = format
        self.role = self._value_or_env(role, 'NEPTUNE_LOAD_FROM_S3_ROLE_ARN', 'role is missing.')
        self.region = self._value_or_env(region, 'AWS_REGION', 'region is missing.')

        if endpoints is None:
            self.endpoints = Endpoints()
        else:
            self.endpoints = endpoints

    @staticmethod
    def _value_or_env(value, env_name, message):
        """Return *value*, or the environment variable *env_name* when *value* is None."""
        if value is not None:
            return value
        if env_name not in os.environ:
            # Explicit raise rather than ``assert`` so the check still runs under
            # ``python -O``; AssertionError is kept for callers that catch it.
            raise AssertionError(message)
        return os.environ[env_name]

    def __load_from(self, source, format, role, region):
        """Build the JSON body for a loader request (failOnError is always 'FALSE')."""
        return {
              'source' : source,
              'format' : format,
              'iamRoleArn' : role,
              'region' : region,
              'failOnError' : 'FALSE'
            }

    def __load(self, loader_url, data):
        """POST *data* as JSON to *loader_url* and return ``payload.loadId`` from the reply."""
        jsondataasbytes = json.dumps(data).encode('utf8')
        req = urllib.request.Request(loader_url, data=jsondataasbytes, headers={'Content-Type': 'application/json'})
        # The context manager closes the HTTP response once the body is read.
        with urllib.request.urlopen(req) as response:
            jsonresponse = json.loads(response.read().decode('utf8'))
        return jsonresponse['payload']['loadId']

    def load_async(self):
        """Submit the load and return a ``BulkLoadStatus`` for its status endpoint.

        An equivalent ``curl`` command is printed before the request is sent.
        """
        localised_source = self.source.replace('${AWS_REGION}', self.region)
        loader_url = self.endpoints.loader_endpoint()
        json_payload = self.__load_from(localised_source, self.format, self.role, self.region)
        print('''curl -X POST \\
    -H 'Content-Type: application/json' \\
    {} -d \'{}\''''.format(loader_url, json.dumps(json_payload, indent=4)))
        load_id = self.__load(loader_url, json_payload)
        return BulkLoadStatus(self.endpoints.load_status_endpoint(load_id))

    def load(self, interval=2):
        """Submit the load, print its status URI and call ``wait(interval)`` on the status."""
        status = self.load_async()
        print('status_uri: {}'.format(status.uri()))
        status.wait(interval)
    def neptune_endpoints(self, connection_name):
        """Build an ``Endpoints`` for Neptune from an AWS Glue Data Catalog connection.

        You may need to install a Glue VPC Endpoint in your VPC for this method to work.

        The Glue connection may be of type 'JDBC' or 'NETWORK':

        * 'JDBC': store the Neptune endpoint in the 'JDBC_CONNECTION_URL' field, e.g.
          'jdbc:wss://my-neptune-cluster.us-east-1.neptune.amazonaws.com:8182/gremlin'
          (the first five characters, 'jdbc:', are stripped).
        * 'NETWORK': store the Neptune endpoint in the 'Description' field, e.g.
          'wss://my-neptune-cluster.us-east-1.neptune.amazonaws.com:8182/gremlin'.

        The host and port are parsed from that URI and returned as
        ``Endpoints(neptune_endpoint=host, neptune_port=port,
        region_name=self.region, role_arn=self.role_arn)``; the port is passed
        as the string found in the URI.

        Raises:
            ValueError: if the connection type is neither 'JDBC' nor 'NETWORK'.

        Example::

            endpoints = GlueNeptuneConnectionInfo(glueContext).neptune_endpoints('neptune')
        """
        # NOTE(review): a second ``neptune_endpoints`` is defined right after this
        # one in the same class body and replaces it when the class is created.
        glue = boto3.client('glue', region_name=self.region)
        connection = glue.get_connection(Name=connection_name)['Connection']
        connection_type = connection['ConnectionType']

        if connection_type == "JDBC":
            neptune_uri = connection['ConnectionProperties']['JDBC_CONNECTION_URL'][5:]
        elif connection_type == "NETWORK":
            neptune_uri = connection['Description']
        else:
            # Any other type leaves nothing to parse; fail with a clear message.
            raise ValueError(
                "Glue connection {!r} has unsupported type {!r}; expected 'JDBC' or 'NETWORK'".format(
                    connection_name, connection_type))

        parse_result = requests.utils.urlparse(neptune_uri)
        netloc_parts = parse_result.netloc.split(':')
        host = netloc_parts[0]
        port = netloc_parts[1]

        return Endpoints(neptune_endpoint=host, neptune_port=port, region_name=self.region, role_arn=self.role_arn)
    def neptune_endpoints(self, connection_name):
        """Build an ``Endpoints`` for Neptune from a Glue Data Catalog JDBC connection.

        You may need to install a Glue VPC Endpoint in your VPC for this method to work.

        Store the Neptune endpoint as a JDBC connection in the Glue Data Catalog.
        The connection string must begin with 'jdbc:' and use the format:

        'jdbc:<protocol>://<dns_name>:<port>/<endpoint>'

        For example:

        'jdbc:wss://my-neptune-cluster.us-east-1.neptune.amazonaws.com:8182/gremlin'

        The host and port are parsed from the part after 'jdbc:' and returned as
        ``Endpoints(neptune_endpoint=host, neptune_port=port,
        region_name=self.region, role_arn=self.role_arn)``. The method returns an
        ``Endpoints`` object, not the URL string.

        Raises:
            ValueError: if the stored URL does not start with 'jdbc:'.

        Example::

            endpoints = GlueNeptuneConnectionInfo(glueContext).neptune_endpoints('neptune')
        """
        glue = boto3.client('glue', region_name=self.region)

        connection = glue.get_connection(Name=connection_name)
        jdbc_url = connection['Connection']['ConnectionProperties']['JDBC_CONNECTION_URL']
        # The fixed-width slice below is only correct when the URL really starts
        # with 'jdbc:'; anything else would be silently mangled.
        if jdbc_url[:5].lower() != 'jdbc:':
            raise ValueError(
                "Glue connection {!r} must store a URL starting with 'jdbc:', got {!r}".format(
                    connection_name, jdbc_url))
        neptune_uri = jdbc_url[5:]
        parse_result = urlparse(neptune_uri)
        netloc_parts = parse_result.netloc.split(':')
        host = netloc_parts[0]
        port = netloc_parts[1]

        return Endpoints(neptune_endpoint=host,
                         neptune_port=port,
                         region_name=self.region,
                         role_arn=self.role_arn)
Exemple #6
0
 def remoteConnection(self,
                      neptune_endpoint=None,
                      neptune_port=None,
                      show_endpoint=True):
     """Open a Gremlin remote connection and track it on this object.

     Args:
         neptune_endpoint: Neptune host passed to ``Endpoints``.
         neptune_port: Neptune port passed to ``Endpoints``.
         show_endpoint: Forwarded to ``GremlinUtils.remote_connection``.

     Returns:
         The new connection, which is also appended to ``self.connections``.
     """
     utils = GremlinUtils(Endpoints(neptune_endpoint, neptune_port))
     new_connection = utils.remote_connection(show_endpoint)
     self.connections.append(new_connection)
     return new_connection
Exemple #7
0
 def __init__(self, neptune_endpoint, elasticache_endpoint):
     """Connect to Neptune and set up vertex metrics.

     Loads the Gremlin statics into this module's globals, creates
     ``VertexMetrics`` from *elasticache_endpoint*, opens a remote connection to
     *neptune_endpoint*, and builds the traversal source ``self.g`` on it.
     """
     GremlinUtils.init_statics(globals())
     neptune = GremlinUtils(Endpoints(neptune_endpoint=neptune_endpoint))
     # Attribute name (including its spelling) is kept as-is for existing callers.
     self.vertext_metrics = VertexMetrics(elasticache_endpoint)
     conn = neptune.remote_connection()
     self.neptune_connection = conn
     self.g = neptune.traversal_source(connection=conn)
Exemple #8
0
    def __init__(
            self,
            source,
            format='csv',
            role=None,
            mode='AUTO',
            region=None,
            fail_on_error=False,
            parallelism='OVERSUBSCRIBE',
            base_uri='http://aws.amazon.com/neptune/default',
            named_graph_uri='http://aws.amazon.com/neptune/vocab/v01/DefaultNamedGraph',
            update_single_cardinality_properties=False,
            endpoints=None):
        """Configure a Neptune bulk-load request.

        Args:
            source: Source location of the data to load.
            format: Loader data format; ``'csv'`` by default.
            role: IAM role ARN; falls back to ``NEPTUNE_LOAD_FROM_S3_ROLE_ARN``.
            mode: Loader mode, stored as given (``'AUTO'`` by default).
            region: AWS region; falls back to ``AWS_REGION``.
            fail_on_error: Stored as the string ``'TRUE'`` or ``'FALSE'``.
            parallelism: Stored as given (``'OVERSUBSCRIBE'`` by default).
            base_uri: Stored as given.
            named_graph_uri: Stored as given.
            update_single_cardinality_properties: Stored as ``'TRUE'``/``'FALSE'``.
            endpoints: ``Endpoints`` to use; ``Endpoints()`` when None.

        Raises:
            AssertionError: if role or region is neither passed nor set in the
                environment.
        """

        self.source = source
        self.format = format

        if role is None:
            # Explicit raise instead of ``assert`` so the check still runs under
            # ``python -O``; AssertionError is kept for existing callers.
            if 'NEPTUNE_LOAD_FROM_S3_ROLE_ARN' not in os.environ:
                raise AssertionError('role is missing.')
            self.role = os.environ['NEPTUNE_LOAD_FROM_S3_ROLE_ARN']
        else:
            self.role = role

        self.mode = mode

        if region is None:
            if 'AWS_REGION' not in os.environ:
                raise AssertionError('region is missing.')
            self.region = os.environ['AWS_REGION']
        else:
            self.region = region

        if endpoints is None:
            self.endpoints = Endpoints()
        else:
            self.endpoints = endpoints

        # Boolean options are stored as the string literals the loader payload uses.
        self.fail_on_error = 'TRUE' if fail_on_error else 'FALSE'
        self.parallelism = parallelism
        self.base_uri = base_uri
        self.named_graph_uri = named_graph_uri
        self.update_single_cardinality_properties = 'TRUE' if update_single_cardinality_properties else 'FALSE'
    def handle_records(self, stream_log):
        """Log newly added vertices/edges from a stream batch and report progress.

        Connection details are read from the JSON in the ``AdditionalParams``
        environment variable (``neptune_cluster_endpoint`` and ``neptune_port``).
        For each record whose operation equals ``ADD_OPERATION``, the
        ``valueMap(True)`` of ``g.V(id)`` (record type ``'vl'``) or ``g.E(id)``
        (record type ``'e'``) is logged.

        This is a generator that yields one ``HandlerResponse`` holding the
        op/commit numbers of the last processed record and the number of records
        processed. The response is yielded from ``finally``, so partial progress
        is reported even when processing fails part-way; in that case the
        original exception continues to propagate once the generator is resumed.

        Args:
            stream_log: Mapping whose ``RECORDS_STR`` entry holds the records.
        """

        params = json.loads(os.environ['AdditionalParams'])

        neptune_endpoint = params['neptune_cluster_endpoint']
        neptune_port = params['neptune_port']

        # Read the records before opening a connection, so a malformed batch
        # cannot leave an unclosed connection behind.
        records = stream_log[RECORDS_STR]

        GremlinUtils.init_statics(globals())

        endpoints = Endpoints(neptune_endpoint=neptune_endpoint,
                              neptune_port=neptune_port)
        gremlin_utils = GremlinUtils(endpoints)

        conn = gremlin_utils.remote_connection()
        g = gremlin_utils.traversal_source(connection=conn)

        last_op_num = None
        last_commit_num = None
        count = 0

        try:
            for record in records:

                # Process record
                op = record[OPERATION_STR]
                data = record[DATA_STR]
                record_type = data['type']
                record_id = data['id']

                if op == ADD_OPERATION:
                    if record_type == 'vl':
                        logger.info(g.V(record_id).valueMap(True).toList())
                    elif record_type == 'e':
                        logger.info(g.E(record_id).valueMap(True).toList())

                # Update local checkpoint info
                last_op_num = record[EVENT_ID_STR][OP_NUM_STR]
                last_commit_num = record[EVENT_ID_STR][COMMIT_NUM_STR]
                count += 1

        except Exception as e:
            logger.error('Error occurred - {}'.format(str(e)))
            raise
        finally:
            # Close the connection exactly once, before handing the checkpoint
            # back, so it is not held open while the caller handles it.
            try:
                conn.close()
                response = HandlerResponse(last_op_num, last_commit_num, count)
            except Exception as e:
                logger.error('Error occurred - {}'.format(str(e)))
                raise
            yield response
Exemple #10
0
 def graphTraversal(self,
                    neptune_endpoint=None,
                    neptune_port=None,
                    show_endpoint=True,
                    connection=None):
     """Return a Gremlin traversal source bound to a remote connection.

     Args:
         neptune_endpoint: Neptune host passed to ``Endpoints``.
         neptune_port: Neptune port passed to ``Endpoints``.
         show_endpoint: Forwarded when opening a connection / traversal source.
         connection: Connection to reuse; a new one is opened via
             ``self.remoteConnection`` when None.

     The connection ends up in ``self.connections`` exactly once.
     """
     if connection is None:
         connection = self.remoteConnection(neptune_endpoint, neptune_port,
                                            show_endpoint)
     # remoteConnection() already appends the connection it opens, and a caller
     # may pass one that is tracked already; compare by identity so the same
     # object is never recorded (and later closed) twice.
     if not any(tracked is connection for tracked in self.connections):
         self.connections.append(connection)
     return GremlinUtils(Endpoints(neptune_endpoint,
                                   neptune_port)).traversal_source(
                                       show_endpoint, connection)
Exemple #11
0
 def __init__(self, source, format='csv', role=None, region=None, endpoints=None):
     """Configure a bulk-load request.

     Args:
         source: Source location of the data to load.
         format: Loader data format; ``'csv'`` by default.
         role: IAM role ARN; falls back to ``NEPTUNE_LOAD_FROM_S3_ROLE_ARN``.
         region: AWS region; falls back to ``AWS_REGION``.
         endpoints: ``Endpoints`` to use; ``Endpoints()`` when None.

     Raises:
         AssertionError: if role or region is neither passed nor set in the
             environment.
     """

     self.source = source
     self.format = format

     if role is None:
         # Explicit raise instead of ``assert`` so the check still runs under
         # ``python -O``; AssertionError is kept for existing callers.
         if 'NEPTUNE_LOAD_FROM_S3_ROLE_ARN' not in os.environ:
             raise AssertionError('role is missing.')
         self.role = os.environ['NEPTUNE_LOAD_FROM_S3_ROLE_ARN']
     else:
         self.role = role

     if region is None:
         if 'AWS_REGION' not in os.environ:
             raise AssertionError('region is missing.')
         self.region = os.environ['AWS_REGION']
     else:
         self.region = region

     if endpoints is None:
         self.endpoints = Endpoints()
     else:
         self.endpoints = endpoints
Exemple #12
0
 def bulkLoadAsync(self,
                   source,
                   format='csv',
                   role=None,
                   region=None,
                   neptune_endpoint=None,
                   neptune_port=None):
     """Start a Neptune bulk load and return without waiting for it.

     Builds ``Endpoints(neptune_endpoint, neptune_port)``, creates a ``BulkLoad``
     for *source* with the given *format*, *role* and *region*, and returns the
     result of its ``load_async()``.
     """
     target_endpoints = Endpoints(neptune_endpoint, neptune_port)
     job = BulkLoad(source, format, role,
                    region=region,
                    endpoints=target_endpoints)
     return job.load_async()
Exemple #13
0
def get_neptune_graph_traversal_source_factory(
        *, neptune_url: Union[str, Mapping[str, Any]],
        session: boto3.session.Session) -> Callable[..., GraphTraversalSource]:
    """Return a factory that opens a new Neptune ``GraphTraversalSource`` per call.

    Args:
        neptune_url: Either a ``ws``/``wss`` URL with path ``/gremlin`` and no
            query or fragment, or a mapping with ``neptune_endpoint``,
            ``neptune_port`` and ``uri`` keys. A non-None ``uri`` is passed as
            ``override_uri`` when each request is prepared.
        session: boto3 session supplying the region and credentials.

    Returns:
        A callable that takes ``DriverRemoteConnection`` keyword arguments
        (except ``url`` and ``traversal_source``, which it sets itself) and
        returns ``traversal().withRemote(...)`` over a new connection.

    Raises:
        AssertionError: for a malformed URL, a non-string ``uri`` override, an
            unsupported ``neptune_url`` type, or (from the factory) a forbidden
            keyword argument. Raised explicitly so the checks survive ``-O``.
    """

    endpoints: Endpoints
    override_uri: Optional[str]
    if isinstance(neptune_url, str):
        uri = urlsplit(neptune_url)
        if not (uri.scheme in ('wss', 'ws') and uri.path == '/gremlin'
                and not uri.query and not uri.fragment):
            raise AssertionError(f'expected Neptune URL not {neptune_url}')
        endpoints = Endpoints(neptune_endpoint=uri.hostname,
                              neptune_port=uri.port,
                              region_name=session.region_name,
                              credentials=session.get_credentials())
        override_uri = None
    elif isinstance(neptune_url, Mapping):
        endpoints = Endpoints(neptune_endpoint=neptune_url['neptune_endpoint'],
                              neptune_port=neptune_url['neptune_port'],
                              region_name=session.region_name,
                              credentials=session.get_credentials())
        override_uri = neptune_url['uri']
        if not (override_uri is None or isinstance(override_uri, str)):
            raise AssertionError(f'expected a str or None for uri, not {override_uri!r}')
    else:
        raise AssertionError(f'what is NEPTUNE_URL? {neptune_url}')

    def create_graph_traversal_source(**kwargs: Any) -> GraphTraversalSource:
        """Open a new remote connection and return a traversal source over it."""
        if any(e in kwargs for e in ('url', 'traversal_source')):
            raise AssertionError(f'do not pass traversal_source or url in {kwargs}')
        # A fresh request is prepared on every call.
        prepared_request = override_prepared_request_parameters(
            endpoints.gremlin_endpoint().prepare_request(),
            override_uri=override_uri)
        kwargs['traversal_source'] = 'g'
        remote_connection = DriverRemoteConnection(url=prepared_request,
                                                   **kwargs)
        return traversal().withRemote(remote_connection)

    return create_graph_traversal_source
Exemple #14
0
 def bulkLoad(self,
              source,
              format='csv',
              role=None,
              region=None,
              neptune_endpoint=None,
              neptune_port=None,
              interval=2):
     """Run a Neptune bulk load via ``BulkLoad.load(interval)``.

     Builds ``Endpoints(neptune_endpoint, neptune_port)`` and a ``BulkLoad`` for
     *source* with the given *format*, *role* and *region*, then calls its
     ``load`` with *interval*. Returns None.
     """
     target_endpoints = Endpoints(neptune_endpoint, neptune_port)
     job = BulkLoad(source, format, role,
                    region=region,
                    endpoints=target_endpoints)
     job.load(interval)
Exemple #15
0
 def neptune_endpoints(self, connection_name):
     """Build an ``Endpoints`` for Neptune from a Glue Data Catalog JDBC connection.

     You may need to install a Glue VPC Endpoint in your VPC for this method to work.

     Store the Neptune endpoint as a JDBC connection in the Glue Data Catalog.
     The connection string must begin with 'jdbc:' and use the format:

     'jdbc:<protocol>://<dns_name>:<port>/<endpoint>'

     For example:

     'jdbc:wss://my-neptune-cluster.us-east-1.neptune.amazonaws.com:8182/gremlin'

     The host and port are parsed from the part after 'jdbc:'. The method then
     assumes ``self.role_arn`` through STS with ``DurationSeconds=3600`` and
     returns ``Endpoints(neptune_endpoint=host, neptune_port=port,
     region_name=self.region, credentials=...)`` carrying those temporary
     credentials. They are not refreshed here, so they expire after one hour.

     Raises:
         ValueError: if the stored URL does not start with 'jdbc:'.

     Example::

         endpoints = GlueNeptuneConnectionInfo(glueContext).neptune_endpoints('neptune')
     """
     glue = boto3.client('glue', region_name=self.region)

     connection = glue.get_connection(Name=connection_name)
     jdbc_url = connection['Connection']['ConnectionProperties']['JDBC_CONNECTION_URL']
     # The fixed-width slice below is only correct when the URL really starts
     # with 'jdbc:'; anything else would be silently mangled.
     if jdbc_url[:5].lower() != 'jdbc:':
         raise ValueError(
             "Glue connection {!r} must store a URL starting with 'jdbc:', got {!r}".format(
                 connection_name, jdbc_url))
     neptune_uri = jdbc_url[5:]
     parse_result = urlparse(neptune_uri)
     netloc_parts = parse_result.netloc.split(':')
     host = netloc_parts[0]
     port = netloc_parts[1]

     sts = boto3.client('sts', region_name=self.region)

     role = sts.assume_role(
         RoleArn=self.role_arn,
         RoleSessionName=uuid.uuid4().hex,
         DurationSeconds=3600
     )

     credentials = Credentials(
         access_key=role['Credentials']['AccessKeyId'],
         secret_key=role['Credentials']['SecretAccessKey'],
         token=role['Credentials']['SessionToken'])

     return Endpoints(neptune_endpoint=host, neptune_port=port, region_name=self.region, credentials=credentials)
     
# Derive the S3 locations for the model artifact, the exported graph data and
# the job-specific bulk-load target prefix.
modelS3Url = urlparse(environ['MODEL_PACKAGE'], allow_fragments=False)
# NOTE(review): only the URL path is kept, so this assumes a path-style S3 URL
# (https://<s3-host>/<bucket>/<key>); an s3://bucket/key URL would lose the
# bucket here — confirm the format of MODEL_PACKAGE / GRAPH_DATA_PATH.
originModelArtifact = f's3:/{modelS3Url.path}'
graphDataUrl = urlparse(environ['GRAPH_DATA_PATH'], allow_fragments=False)
graphDataPath = f's3:/{graphDataUrl.path}/graph/'
targetDataPath = f"{args.data_prefix}/{environ['JOB_NAME']}"
tempFolder = args.temp_folder

dataArgs = (originModelArtifact, graphDataPath, targetDataPath, tempFolder)

# Stage the graph data with prepare-data.sh, which lives next to this file.
prepareDataCmd = Path(os.path.abspath(__file__)).parent.joinpath('prepare-data.sh')
logger.info(f"| {prepareDataCmd} {' '.join(dataArgs)}")
subprocess.check_call([prepareDataCmd] + list(dataArgs))
logger.info('Prepared graph data for bulk load...')

endpoints = Endpoints(neptune_endpoint=args.neptune_endpoint, neptune_port=args.neptune_port, region_name=args.region)

# The f-string interpolates by itself; a leading '$' would be logged literally.
logger.info(f'Created Neptune endpoint {endpoints.gremlin_endpoint()}.')

bulkload = BulkLoad(
        source=targetDataPath,
        endpoints=endpoints,
        role=args.neptune_iam_role_arn,
        region=args.region,
        update_single_cardinality_properties=True,
        fail_on_error=True)

load_status = bulkload.load_async()
logger.info(f'Bulk load request from {targetDataPath} is submitted.')

# NOTE(review): this binds the name ``json`` and shadows the json module if it
# is imported in this script; kept because later code may read this variable.
status, json = load_status.status(details=True, errors=True)
Exemple #17
0
sc.setLogLevel("INFO")
glueContext = GlueContext(sc)
logger = glueContext.get_logger()

logger.info('Before resolving options...')

args = getResolvedOptions(sys.argv, [
    'database', 'transaction_table', 'identity_table', 'id_cols', 'cat_cols',
    'output_prefix', 'region', 'neptune_endpoint', 'neptune_port'
])

logger.info(f'Resolved options are: {args}')

# Load the Gremlin statics into this script's globals.
GremlinUtils.init_statics(globals())
endpoints = Endpoints(neptune_endpoint=args['neptune_endpoint'],
                      neptune_port=args['neptune_port'],
                      region_name=args['region'])
# The f-string interpolates by itself; a leading '$' would be logged literally.
logger.info(
    f'Initializing gremlin client to Neptune {endpoints.gremlin_endpoint()}.')
gremlin_client = GlueGremlinClient(endpoints)

TRANSACTION_ID = 'TransactionID'

transactions = glueContext.create_dynamic_frame.from_catalog(
    database=args['database'], table_name=args['transaction_table'])
identities = glueContext.create_dynamic_frame.from_catalog(
    database=args['database'], table_name=args['identity_table'])

s3 = boto3.resource('s3', region_name=args['region'])

train_data_ratio = 0.8
Exemple #18
0
class GremlinUtils:
    """Open signed Gremlin connections and clients to an Amazon Neptune endpoint.

    Connections opened through :meth:`remote_connection` are tracked in
    ``self.connections`` and can be closed together with :meth:`close`.
    """

    # Names removed from the caller's namespace after loading the Gremlin
    # statics: each one would otherwise shadow the Python builtin of the
    # same name in that namespace.
    _BUILTIN_SHADOWING_STATICS = ('range', 'map', 'min', 'sum', 'property', 'max')

    @classmethod
    def init_statics(cls, globals):
        """Load the Gremlin statics into *globals*, minus builtin-shadowing names.

        Args:
            globals: Namespace dict to populate, typically the caller's
                ``globals()``.
        """
        statics.load_statics(globals)
        for name in cls._BUILTIN_SHADOWING_STATICS:
            # pop() with a default tolerates a name that was not injected,
            # where ``del`` would raise KeyError.
            globals.pop(name, None)

    def __init__(self, endpoints=None):
        """Store the endpoints to use (``Endpoints()`` when None); track no connections yet."""
        if endpoints is None:
            self.endpoints = Endpoints()
        else:
            self.endpoints = endpoints

        self.connections = []

    def close(self):
        """Close every tracked connection and stop tracking them.

        Every connection gets a ``close()`` call even if an earlier one fails;
        the first failure is re-raised afterwards. The tracking list is emptied
        first, so calling ``close()`` again does not re-close anything.
        """
        pending, self.connections = self.connections, []
        first_error = None
        for connection in pending:
            try:
                connection.close()
            except Exception as err:
                if first_error is None:
                    first_error = err
        if first_error is not None:
            raise first_error

    @staticmethod
    def _http_request_for(gremlin_endpoint):
        """Build the tornado ``HTTPRequest`` (URI and headers) from ``gremlin_endpoint.prepare_request()``."""
        request_parameters = gremlin_endpoint.prepare_request()
        return httpclient.HTTPRequest(
            request_parameters.uri, headers=request_parameters.headers)

    def remote_connection(self,
                          show_endpoint=False,
                          protocol_factory=None,
                          transport_factory=lambda: TornadoTransportProxy(),
                          pool_size=None,
                          max_workers=None,
                          message_serializer=None,
                          graphson_reader=None,
                          graphson_writer=None,
                          max_retries=3):
        """Open and track a ``DriverRemoteConnection`` for traversal source ``'g'``.

        A new request is prepared on every attempt. After an ``HTTPError`` the
        attempt is retried immediately, up to *max_retries* times.

        Returns:
            The new connection, also appended to ``self.connections``.

        Raises:
            HTTPError: the last failure once the retries are used up.
        """
        gremlin_endpoint = self.endpoints.gremlin_endpoint()
        if show_endpoint:
            print('gremlin: {}'.format(gremlin_endpoint))

        retry_count = 0

        while True:
            try:
                connection = DriverRemoteConnection(
                    self._http_request_for(gremlin_endpoint),
                    'g',
                    protocol_factory=protocol_factory,
                    transport_factory=transport_factory,
                    pool_size=pool_size,
                    max_workers=max_workers,
                    message_serializer=message_serializer,
                    graphson_reader=graphson_reader,
                    graphson_writer=graphson_writer)
            except HTTPError:
                if retry_count >= max_retries:
                    # A bare ``raise`` keeps the original traceback.
                    raise
                retry_count += 1
                print('Connection timeout. Retrying...')
                continue
            self.connections.append(connection)
            return connection

    def traversal_source(self, show_endpoint=True, connection=None):
        """Return ``traversal().withRemote(connection)``, opening a connection when None."""
        if connection is None:
            connection = self.remote_connection(show_endpoint)
        return traversal().withRemote(connection)

    def client(self, pool_size=None, max_workers=None):
        """Return a Gremlin ``Client`` for traversal source ``'g'`` (not tracked)."""
        signed_ws_request = self._http_request_for(self.endpoints.gremlin_endpoint())
        return Client(signed_ws_request,
                      'g',
                      pool_size=pool_size,
                      max_workers=max_workers)

    def sessioned_client(self,
                         session_id=None,
                         pool_size=None,
                         max_workers=None):
        """Return a ``SessionedClient``; a random hex session id is used when none is given."""
        signed_ws_request = self._http_request_for(self.endpoints.gremlin_endpoint())
        return SessionedClient(
            signed_ws_request,
            'g',
            uuid.uuid4().hex if session_id is None else session_id,
            pool_size=pool_size,
            max_workers=max_workers)
Exemple #19
0
    def __init__(self, *, host: str, port: Optional[int] = None, user: Optional[str] = None,
                 password: Optional[Union[str, boto3.session.Session]] = None,
                 driver_remote_connection_options: Optional[Mapping[str, Any]] = None,
                 client_kwargs: Optional[Dict] = None,
                 **kwargs: dict) -> None:
        """Configure a Gremlin proxy to Amazon Neptune using a boto3 session for auth.

        Args:
            host: Either a ``ws``/``wss`` URL with path ``/gremlin``, or a mapping
                with ``neptune_endpoint``, ``neptune_port`` and ``uri`` keys, where
                ``uri`` is a ``ws``/``wss`` ``/gremlin`` URL stored as
                ``self.override_uri``.
            port: Must be None; the port belongs in *host*.
            user: Not used here.
            password: A ``boto3.session.Session`` providing credentials and region.
            driver_remote_connection_options: Options copied into a new dict, with
                ``traversal_source`` forced to ``'g'``.
            client_kwargs: Must contain ``neptune_bulk_loader_s3_bucket_name``.
            **kwargs: Not used here.

        Raises:
            NotImplementedError: for a non-None port, a non-Session password, an
                unsupported *host* type, or a missing S3 bucket name.
            AssertionError: for a URL that is not a Neptune ws/wss ``/gremlin`` URL
                or a *host* mapping missing required keys (raised explicitly so the
                checks survive ``python -O``).
        """
        # Copy into a fresh dict so neither the caller's mapping nor a shared
        # default is mutated by the ``update`` below.
        driver_remote_connection_options = dict(driver_remote_connection_options or {})
        if client_kwargs is None:
            client_kwargs = {}
        # port should be part of that url
        if port is not None:
            raise NotImplementedError(f'port is not allowed! port={port}')

        # for IAM auth, we need the triplet or a Session which is more general
        if isinstance(password, boto3.session.Session):
            session = password
            self.aws_auth = AssumeRoleAWS4Auth(session.get_credentials(), session.region_name, 'neptune-db')
        else:
            raise NotImplementedError('to use authentication, pass a boto3.session.Session!)')

        if isinstance(host, str):
            # usually a wss URI
            url = urlsplit(host)
            if not (url.scheme in ('wss', 'ws') and url.path == '/gremlin'
                    and not url.query and not url.fragment):
                raise AssertionError(f'url is not a Neptune ws url?: {host}')

            self.endpoints = Endpoints(
                neptune_endpoint=url.hostname, neptune_port=url.port,
                region_name=session.region_name, credentials=session.get_credentials())
            self.override_uri = None
        elif isinstance(host, Mapping):
            # ...but development is a little complicated
            if not all(k in host for k in ('neptune_endpoint', 'neptune_port', 'uri')):
                raise AssertionError(f'pass a dict with neptune_endpoint, neptune_port, and uri not: {host}')

            self.endpoints = Endpoints(
                neptune_endpoint=host['neptune_endpoint'], neptune_port=int(host['neptune_port']),
                region_name=session.region_name, credentials=session.get_credentials())
            uri = urlsplit(host['uri'])
            if not (uri.scheme in ('wss', 'ws') and uri.path == '/gremlin'
                    and not uri.query and not uri.fragment):
                raise AssertionError(f'''url is not a Neptune ws url?: {host['uri']}''')
            self.override_uri = uri
        else:
            raise NotImplementedError('to use authentication, pass a Mapping with aws_access_key_id, '
                                      'aws_secret_access_key, service_region!')

        # always g for Neptune
        driver_remote_connection_options.update(traversal_source='g')

        try:
            s3_bucket_name = client_kwargs['neptune_bulk_loader_s3_bucket_name']
        except (KeyError, TypeError) as err:
            raise NotImplementedError('Cannot find s3 bucket name!') from err

        # Instantiate bulk loader and graph traversal factory
        bulk_loader_config: Dict[str, Any] = dict(NEPTUNE_SESSION=password, NEPTUNE_URL=host,
                                                  NEPTUNE_BULK_LOADER_S3_BUCKET_NAME=s3_bucket_name)
        self.neptune_bulk_loader_api = NeptuneBulkLoaderApi.create_from_config(bulk_loader_config)
        self.neptune_graph_traversal_source_factory = get_neptune_graph_traversal_source_factory(session=password,
                                                                                                 neptune_url=host)

        AbstractGremlinProxy.__init__(self, key_property_name='key',
                                      driver_remote_connection_options=driver_remote_connection_options)
Exemple #20
0
class NeptuneGremlinProxy(AbstractGremlinProxy):
    """
    A proxy to a Neptune cluster using the Gremlin protocol.

    See also https://docs.aws.amazon.com/neptune/latest/userguide/access-graph-gremlin-differences.html
    See also https://docs.aws.amazon.com/neptune/latest/userguide/access-graph-gremlin-sessions.html

    TODO: HTTP proxy support.  This does *NOT* support HTTP proxies as-is. Why? The default transport factory in
    gremlin_python is tornado.websocket, which is hardcoded to use simple_httpclient (look at
    WebSocketClientConnection).  But, even if that could be made to use curl_httpclient, curl_httpclient requires pycurl
    which requires libcurl and other native libraries which is a pain to install.
    """

    def __init__(self, *, host: Union[str, Mapping[str, Any]], port: Optional[int] = None,
                 user: Optional[str] = None,
                 password: Optional[Union[str, boto3.session.Session]] = None,
                 driver_remote_connection_options: Optional[Mapping[str, Any]] = None,
                 client_kwargs: Optional[Dict[str, Any]] = None,
                 **kwargs: dict) -> None:
        """
        Configure IAM-signed access to Neptune, the bulk loader and the traversal source factory.

        :param host: either a ``ws``/``wss`` URI whose path is ``/gremlin``, or a Mapping with ``neptune_endpoint``,
            ``neptune_port`` and ``uri`` keys.  In the Mapping form, requests are signed for
            ``neptune_endpoint:neptune_port`` but sent to the netloc of ``uri``.
        :param port: must be None; the port is taken from ``host``.
        :param user: not used by this constructor.
        :param password: a ``boto3.session.Session``; its credentials and region are used for request signing.
        :param driver_remote_connection_options: copied (never mutated); ``traversal_source`` is forced to ``'g'``.
        :param client_kwargs: must contain ``neptune_bulk_loader_s3_bucket_name``.
        :raises NotImplementedError: if ``port`` is given, ``password`` is not a Session, ``host`` is neither a str
            nor a Mapping, or the S3 bucket name is missing.
        """
        # copy so that neither the caller's mapping nor any shared object is mutated by the update() below
        driver_remote_connection_options = dict(driver_remote_connection_options or {})
        client_kwargs = client_kwargs or {}

        # port should be part of that url
        if port is not None:
            raise NotImplementedError(f'port is not allowed! port={port}')

        # for IAM auth, we need the triplet or a Session which is more general
        if isinstance(password, boto3.session.Session):
            session = password
            self.aws_auth = AssumeRoleAWS4Auth(session.get_credentials(), session.region_name, 'neptune-db')
        else:
            raise NotImplementedError('to use authentication, pass a boto3.session.Session!')

        if isinstance(host, str):
            # usually a wss URI
            url = urlsplit(host)
            assert url.scheme in ('wss', 'ws') and url.path == '/gremlin' and not url.query and not url.fragment, \
                f'url is not a Neptune ws url?: {host}'

            self.endpoints = Endpoints(
                neptune_endpoint=url.hostname, neptune_port=url.port,
                region_name=session.region_name, credentials=session.get_credentials())
            self.override_uri = None
        elif isinstance(host, Mapping):
            # ...but development is a little complicated
            assert all(k in host for k in ('neptune_endpoint', 'neptune_port', 'uri')), \
                f'pass a dict with neptune_endpoint, neptune_port, and uri not: {host}'

            self.endpoints = Endpoints(
                neptune_endpoint=host['neptune_endpoint'], neptune_port=int(host['neptune_port']),
                region_name=session.region_name, credentials=session.get_credentials())
            uri = urlsplit(host['uri'])
            assert uri.scheme in ('wss', 'ws') and uri.path == '/gremlin' and not uri.query and not uri.fragment, \
                f'''url is not a Neptune ws url?: {host['uri']}'''
            self.override_uri = uri
        else:
            raise NotImplementedError(f'host must be a ws(s) URI str or a Mapping with neptune_endpoint, '
                                      f'neptune_port, and uri, not: {host!r}')

        # always g for Neptune
        driver_remote_connection_options.update(traversal_source='g')

        try:
            s3_bucket_name = client_kwargs['neptune_bulk_loader_s3_bucket_name']
        except (KeyError, TypeError) as e:
            raise NotImplementedError('Cannot find s3 bucket name!') from e

        # Instantiate bulk loader and graph traversal factory
        bulk_loader_config: Dict[str, Any] = dict(NEPTUNE_SESSION=password, NEPTUNE_URL=host,
                                                  NEPTUNE_BULK_LOADER_S3_BUCKET_NAME=s3_bucket_name)
        self.neptune_bulk_loader_api = NeptuneBulkLoaderApi.create_from_config(bulk_loader_config)
        self.neptune_graph_traversal_source_factory = get_neptune_graph_traversal_source_factory(session=password,
                                                                                                 neptune_url=host)

        AbstractGremlinProxy.__init__(self, key_property_name='key',
                                      driver_remote_connection_options=driver_remote_connection_options)

    @classmethod
    @overrides
    def script_translator(cls) -> Type[ScriptTranslatorTargetNeptune]:
        return ScriptTranslatorTargetNeptune

    def override_prepared_request_parameters(
            self, request_parameters: RequestParameters, *, method: Optional[str] = None,
            data: Optional[str] = None) -> httpclient.HTTPRequest:
        """
        Turn signed ``request_parameters`` into a tornado ``HTTPRequest``.

        When ``self.override_uri`` is set, the request is sent to that netloc.  The original ``Host`` header and the
        TLS server hostname are kept, so the signature still matches.
        """
        http_request_param: Dict[str, Any] = dict(url=request_parameters.uri, headers=request_parameters.headers)
        if method is not None:
            http_request_param['method'] = method
        if data is not None:
            http_request_param['body'] = data
        if self.override_uri:
            # we override the URI slightly (because the instance thinks it's a different host than we're connecting to)
            uri = urlsplit(request_parameters.uri)
            http_request_param['headers'] = dict(request_parameters.headers)
            http_request_param['headers']['Host'] = uri.netloc
            http_request_param['ssl_options'] = OverrideServerHostnameSSLContext(server_hostname=uri.hostname)
            http_request_param['url'] = urlunsplit(
                (uri.scheme, self.override_uri.netloc, uri.path, uri.query, uri.fragment))
        return httpclient.HTTPRequest(**http_request_param)

    @overrides
    def possibly_signed_ws_client_request_or_url(self) -> Union[httpclient.HTTPRequest, str]:
        return self.override_prepared_request_parameters(self.endpoints.gremlin_endpoint().prepare_request())

    @classmethod
    @overrides
    def _is_retryable_exception(cls, *, method_name: str, exception: Exception) -> bool:
        # any method
        return _is_neptune_retryable_exception(exception) or isinstance(exception, ConnectionError)

    @staticmethod
    def _fetch_json(request: httpclient.HTTPRequest) -> Any:
        """Fetch ``request`` synchronously and return the JSON-decoded body, always closing the client."""
        http_client = httpclient.HTTPClient()
        try:
            response = http_client.fetch(request)
        finally:
            # HTTPClient owns resources (its own IOLoop) that are only released by close()
            http_client.close()
        # json.loads decodes bytes input itself; its ``encoding=`` argument was removed in Python 3.9
        return json.loads(response.body)

    def is_healthy(self) -> None:
        """
        Raise unless the status endpoint reports ``status == 'healthy'`` and ``role == 'writer'``.

        :raises RuntimeError: if the reported status or role differs.
        """
        signed_request = self.override_prepared_request_parameters(self.endpoints.status_endpoint().prepare_request())
        # this will throw if the instance is really borked or we can't connect or we're not allowed (see
        # https://docs.aws.amazon.com/neptune/latest/userguide/access-graph-status.html )
        status = self._fetch_json(signed_request)

        if status.get('status') == 'healthy' and status.get('role') == 'writer':
            LOGGER.debug(f'status is healthy: {status}')
        else:
            # we'll log in healthcheck
            raise RuntimeError(f'status is unhealthy: {status}')

    def _non_standard_endpoint(self, scheme: str, path: str) -> Endpoint:
        # Endpoints has no public factory for arbitrary paths, so this reaches into its name-mangled __endpoint
        return self.endpoints._Endpoints__endpoint(
            scheme, self.endpoints.neptune_endpoint, self.endpoints.neptune_port, path)

    def _gremlin_status(self, query_id: Optional[str] = None, include_waiting: bool = False) -> Dict[str, Any]:
        """
        Return the decoded JSON from Neptune's Gremlin query status API.

        https://docs.aws.amazon.com/neptune/latest/userguide/gremlin-api-status.html

        :param query_id: when given, sent as ``queryId`` to ask about a single query.
        :param include_waiting: when true, sent as ``includeWaiting=true``.
        """
        endpoint = self._non_standard_endpoint('https', 'gremlin/status')

        query_parameters: Dict[str, str] = {}
        if query_id is not None:
            query_parameters['queryId'] = query_id
        if include_waiting:
            query_parameters['includeWaiting'] = 'true'

        signed_request = self.override_prepared_request_parameters(
            endpoint.prepare_request(querystring=query_parameters))
        return self._fetch_json(signed_request)

    def _sparql_status(self, query_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Return the decoded JSON from Neptune's SPARQL query status API.

        https://docs.aws.amazon.com/neptune/latest/userguide/sparql-api-status.html

        :param query_id: when given, sent as ``queryId`` to ask about a single query.
        """
        endpoint = self._non_standard_endpoint('https', 'sparql/status')

        query_parameters: Dict[str, str] = {}
        if query_id is not None:
            query_parameters['queryId'] = query_id

        signed_request = self.override_prepared_request_parameters(
            endpoint.prepare_request(querystring=query_parameters))
        return self._fetch_json(signed_request)

    def _post_gremlin_api(self, api: str, gremlin_query: str) -> str:
        """
        POST ``gremlin_query`` to ``/gremlin/<api>`` over http(s) and return the decoded response body.

        The http(s) URL is derived from the signed ws(s) gremlin endpoint.  The Host header is taken from the
        non-overridden URL, so SigV4 signing matches the host the instance expects.
        """
        # why not use endpoints? Despite the fact that it accepts a method and payload, it doesn't *actually* generate
        # sufficient headers so we'll use requests for these since we can
        url = urlsplit(self.endpoints.gremlin_endpoint().prepare_request().uri)
        assert url.scheme in ('wss', 'ws') and url.path == '/gremlin' and not url.query and not url.fragment, \
            f'url is not a Neptune ws url?: {url}'
        scheme = 'https' if url.scheme == 'wss' else 'http'
        api_path = f'{url.path}/{api}'
        api_url = urlunsplit((scheme, url.netloc, api_path, '', ''))
        host = to_aws4_request_compatible_host(api_url)
        if self.override_uri:
            api_url = urlunsplit((scheme, self.override_uri.netloc, api_path, '', ''))
        with requests.Session() as session:
            session.mount('https://', HostHeaderSSLAdapter())
            response = session.post(api_url, auth=self.aws_auth,
                                    data=json.dumps(dict(gremlin=gremlin_query)).encode('utf-8'),
                                    # include Host header
                                    headers=dict(Host=host))
        return response.content.decode('utf-8')

    def _explain(self, gremlin_query: str) -> str:
        """
        Return the Neptune-specific explanation of the query.

        see https://docs.aws.amazon.com/neptune/latest/userguide/gremlin-explain-api.html
        see https://docs.aws.amazon.com/neptune/latest/userguide/gremlin-explain-background.html
        """
        return self._post_gremlin_api('explain', gremlin_query)

    def _profile(self, gremlin_query: str) -> str:
        """
        Return the Neptune-specific profile of the query, produced by running it.

        This does not return the result set, so it is mainly useful when the query can be re-run, from a console,
        or as a last resort.
        see https://docs.aws.amazon.com/neptune/latest/userguide/gremlin-profile-api.html
        see https://docs.aws.amazon.com/neptune/latest/userguide/gremlin-explain-background.html
        """
        return self._post_gremlin_api('profile', gremlin_query)

    @overrides
    def drop(self) -> None:
        """
        Drop all vertices, or only those in the current test shard when ``get_shard()`` returns one.

        Afterwards, vertices whose id starts with the shard are looked up, dropped again, and asserted to have
        been absent.
        """
        test_shard = get_shard()
        g = self.g.V()
        if test_shard:
            g = g.has(WellKnownProperties.TestShard.value.name, test_shard)
        g = g.drop()
        LOGGER.warning('DROPPING ALL NODES')
        self.query_executor()(query=g, get=FromResultSet.iterate)
        # we seem to mess this up easily
        # NOTE(review): this check also runs when test_shard is falsy — confirm startingWith(None/'') is intended
        leftover = self.query_executor()(query=self.g.V().hasId(TextP.startingWith(test_shard)).id(),
                                         get=FromResultSet.toList)
        self.query_executor()(query=self.g.V().hasId(TextP.startingWith(test_shard)).drop(),
                              get=FromResultSet.iterate)
        assert not leftover, f'we have some leftover: {leftover}'
        LOGGER.warning('COMPLETED DROP OF ALL NODES')
Exemple #21
0
class BulkLoad:
    """
    Submit a Neptune bulk-load request to the loader endpoint and optionally wait for it.

    Boolean options are stored as the strings ``'TRUE'``/``'FALSE'`` that go into the request payload.
    ``role`` and ``region`` fall back to the ``NEPTUNE_LOAD_FROM_S3_ROLE_ARN`` and ``AWS_REGION`` environment
    variables.
    """

    def __init__(
            self,
            source,
            format='csv',
            role=None,
            mode='AUTO',
            region=None,
            fail_on_error=False,
            parallelism='OVERSUBSCRIBE',
            base_uri='http://aws.amazon.com/neptune/default',
            named_graph_uri='http://aws.amazon.com/neptune/vocab/v01/DefaultNamedGraph',
            update_single_cardinality_properties=False,
            endpoints=None):

        self.source = source
        self.format = format

        if role is None:
            assert ('NEPTUNE_LOAD_FROM_S3_ROLE_ARN'
                    in os.environ), 'role is missing.'
            self.role = os.environ['NEPTUNE_LOAD_FROM_S3_ROLE_ARN']
        else:
            self.role = role

        self.mode = mode

        if region is None:
            assert ('AWS_REGION' in os.environ), 'region is missing.'
            self.region = os.environ['AWS_REGION']
        else:
            self.region = region

        if endpoints is None:
            self.endpoints = Endpoints()
        else:
            self.endpoints = endpoints

        self.fail_on_error = 'TRUE' if fail_on_error else 'FALSE'
        self.parallelism = parallelism
        self.base_uri = base_uri
        self.named_graph_uri = named_graph_uri
        self.update_single_cardinality_properties = 'TRUE' if update_single_cardinality_properties else 'FALSE'

    def __load_from(self, source):
        """Build the JSON-serialisable loader request body for ``source``."""
        return {
            'source': source,
            'format': self.format,
            'iamRoleArn': self.role,
            'mode': self.mode,
            'region': self.region,
            'failOnError': self.fail_on_error,
            'parallelism': self.parallelism,
            'parserConfiguration': {
                'baseUri': self.base_uri,
                'namedGraphUri': self.named_graph_uri
            },
            'updateSingleCardinalityProperties': self.update_single_cardinality_properties
        }

    def __load(self, loader_endpoint, data):
        """
        POST ``data`` as JSON to ``loader_endpoint`` and return ``payload.loadId`` from the response.

        :raises Exception: wrapping the decoded JSON error body when the loader answers HTTP 500.
        :raises HTTPError: re-raised unchanged for any other HTTP error status.
        """
        json_string = json.dumps(data)
        request_parameters = loader_endpoint.prepare_request(
            'POST', json_string)
        request_parameters.headers['Content-Type'] = 'application/json'
        req = urllib.request.Request(request_parameters.uri,
                                     data=json_string.encode('utf8'),
                                     headers=request_parameters.headers)
        try:
            # the context manager guarantees the HTTP response is closed
            with urllib.request.urlopen(req) as response:
                json_response = json.loads(response.read().decode('utf8'))
        except HTTPError as e:
            if e.code == 500:
                # surface the loader's JSON error body rather than a bare status code
                raise Exception(json.loads(e.read().decode('utf8'))) from None
            raise
        return json_response['payload']['loadId']

    def load_async(self):
        """Start the load (after substituting ``${AWS_REGION}`` in the source) and return its BulkLoadStatus."""
        localised_source = self.source.replace('${AWS_REGION}', self.region)
        loader_endpoint = self.endpoints.loader_endpoint()
        json_payload = self.__load_from(localised_source)
        # print an equivalent curl command for debugging
        print('''curl -X POST \\
    -H 'Content-Type: application/json' \\
    {} -d \'{}\''''.format(loader_endpoint, json.dumps(json_payload,
                                                       indent=4)))
        load_id = self.__load(loader_endpoint, json_payload)
        return BulkLoadStatus(self.endpoints.load_status_endpoint(load_id))

    def load(self, interval=2):
        """Start the load, print its status URI and block on ``status.wait(interval)``."""
        status = self.load_async()
        print('status_uri: {}'.format(status.load_status_endpoint))
        status.wait(interval)
Exemple #22
0
 def sparql_endpoint(self, neptune_endpoint=None, neptune_port=None):
     """Return the SPARQL endpoint built by ``Endpoints(neptune_endpoint, neptune_port)``.

     Both arguments are passed positionally to ``Endpoints``; ``self`` is not used.
     """
     endpoints = Endpoints(neptune_endpoint, neptune_port)
     return endpoints.sparql_endpoint()
    def __init__(self, endpoints=None):
        # Use the supplied Endpoints; only build a default one when the caller passed nothing.
        self.endpoints = endpoints if endpoints is not None else Endpoints()
Exemple #24
0
# Configuration read from the environment at import time; a missing variable raises KeyError immediately.
CLUSTER_ENDPOINT = os.environ['CLUSTER_ENDPOINT']
# Kept as the raw string from the environment and passed unchanged to Endpoints below.
CLUSTER_PORT = os.environ['CLUSTER_PORT']
CLUSTER_REGION = os.environ['CLUSTER_REGION']
# NOTE(review): presumably the name of the SageMaker endpoint invoked through `runtime` — confirm.
ENDPOINT_NAME = os.environ['ENDPOINT_NAME']
# Converted to float here, so a malformed value fails at import with ValueError.
# NOTE(review): the meaning of MODEL_BTW is not visible in this chunk — TODO document.
MODEL_BTW = float(os.environ['MODEL_BTW'])
# NOTE(review): presumably the URL of the queue used with the `sqs` client — confirm.
QUEUE_URL = os.environ['QUEUE_URL']

# Raw comma-separated strings; presumably passed to load_data_from_event, which splits them on ','.
transactions_id_cols = os.environ['TRANSACTION_ID_COLS']
transactions_cat_cols = os.environ['TRANSACTION_CAT_COLS']
dummied_col = os.environ['DUMMIED_COL']

# boto3 clients created once, when the module is imported.
sqs = boto3.client('sqs')
runtime = boto3.client('runtime.sagemaker')

# Neptune endpoint helper for the configured cluster.
endpoints = Endpoints(neptune_endpoint=CLUSTER_ENDPOINT,
                      neptune_port=CLUSTER_PORT,
                      region_name=CLUSTER_REGION)


def load_data_from_event(input_event, transactions_id_cols,
                         transactions_cat_cols, dummied_col):
    """Load and transform event data into correct format for next step subgraph loading and model inference input. 
        input event keys should come from related dataset.]
    
    Example:
    >>> load_data_from_event(event = {"transaction_data":[{"TransactionID":"3163166", "V1":1, ...]}, 'card1,card2,,...', 'M2_T,M3_F,M3_T,...')
    """
    TRANSACTION_ID = 'TransactionID'

    transactions_id_cols = transactions_id_cols.split(',')
    transactions_cat_cols = transactions_cat_cols.split(',')