async def submit(self, message, bindings=None):
    """
    **coroutine** Submit a script and bindings to the Gremlin Server.

    :param message: Can be an instance of
        `RequestMessage<gremlin_python.driver.request.RequestMessage>` or
        `Bytecode<gremlin_python.process.traversal.Bytecode>`
        or a `str` representing a raw Gremlin script
    :param dict bindings: Optional bindings used with raw Gremlin

    :returns: :py:class:`ResultSet<aiogremlin.driver.resultset.ResultSet>`
        object
    """
    if isinstance(message, traversal.Bytecode):
        message = request.RequestMessage(
            processor='traversal', op='bytecode',
            args={'gremlin': message,
                  'aliases': self._aliases})
    elif isinstance(message, str):
        message = request.RequestMessage(
            processor='', op='eval',
            args={'gremlin': message,
                  'aliases': self._aliases})
    if bindings:
        message.args.update({'bindings': bindings})
    conn = await self.cluster.get_connection(hostname=self._hostname)
    resp = await conn.write(message)
    # The event loop keeps only weak references to tasks, so keep a strong
    # reference to the release task until it finishes; otherwise it can be
    # garbage collected before the connection is released.
    release = asyncio.create_task(conn.release_task(resp))
    pending = getattr(self, '_pending_releases', None)
    if pending is None:
        pending = self._pending_releases = set()
    pending.add(release)
    release.add_done_callback(pending.discard)
    return resp
def submitAsync(self, message, bindings=None):
    """Write ``message`` on a pooled connection, wrapping it first if needed.

    ``Bytecode`` goes to the 'traversal' processor with op 'bytecode'. A raw
    ``str`` goes to the default ('') processor with op 'eval'. In both cases
    the alias 'g' maps to this client's traversal source. Any other value is
    sent unchanged (assumed to already be a request message). Truthy
    ``bindings`` are merged into the message args.

    :returns: whatever ``conn.write`` returns.
    """
    routing = None
    if isinstance(message, traversal.Bytecode):
        routing = ('traversal', 'bytecode')
    elif isinstance(message, str):
        routing = ('', 'eval')

    if routing is not None:
        processor, op = routing
        message = request.RequestMessage(
            processor=processor,
            op=op,
            args={
                'gremlin': message,
                'aliases': {'g': self._traversal_source},
            })

    if bindings:
        message.args.update({'bindings': bindings})

    return self._pool.get(True).write(message)
def _kerberos_received(self, message):
    """Perform one step of the SASL GSSAPI (Kerberos) handshake.

    The step is chosen from state stored on ``self``:

    1. No ``self._kerberos_context`` yet: create a GSS client context for
       ``self._kerberized_service`` and return the initial token.
    2. No ``self._username`` yet: feed the server's SASL token into the
       context. Once the step reports completion, record the authenticated
       user name. Either way, reply with an empty token.
    3. Otherwise: answer the SASL quality-of-protection (qop) negotiation.

    :param message: deserialized server response. In steps 2 and 3,
        ``message['status']['attributes']['sasl']`` carries the server token.
    :returns: the ``RequestMessage`` to send back to the server.
    :raises ImportError: if the optional ``kerberos`` package is missing.
    :raises ConfigurationError: if the GSS client context cannot be created.
    :raises AssertionError: if the server's qop handshake is malformed.
    """
    # Inspired by: https://github.com/thobbs/pure-sasl/blob/0.6.2/puresasl/mechanisms.py
    #              https://github.com/thobbs/pure-sasl/blob/0.6.2/LICENSE
    try:
        import kerberos
    except ImportError as err:
        raise ImportError('Please install gremlinpython[kerberos].') from err

    # First pass: get service granting ticket and return it to gremlin-server
    if not self._kerberos_context:
        try:
            _, kerberos_context = kerberos.authGSSClientInit(
                self._kerberized_service, gssflags=kerberos.GSS_C_MUTUAL_FLAG)
            kerberos.authGSSClientStep(kerberos_context, '')
            auth = kerberos.authGSSClientResponse(kerberos_context)
            self._kerberos_context = kerberos_context
        except kerberos.KrbError as e:
            raise ConfigurationError(
                'Kerberos authentication requires a valid service name in DriverRemoteConnection, '
                'as well as a valid tgt (export KRB5CCNAME) or keytab (export KRB5_KTNAME): ' + str(e)) from e
        return request.RequestMessage('', 'authentication', {'sasl': auth})

    # Second pass: completion of authentication
    sasl_response = message['status']['attributes']['sasl']
    if not self._username:
        result_code = kerberos.authGSSClientStep(self._kerberos_context, sasl_response)
        if result_code == kerberos.AUTH_GSS_COMPLETE:
            self._username = kerberos.authGSSClientUserName(self._kerberos_context)
        return request.RequestMessage('', 'authentication', {'sasl': ''})

    # Third pass: sasl quality of protection (qop) handshake
    # Gremlin-server Krb5Authenticator only supports qop=QOP_AUTH; use ssl for confidentiality.
    # Handshake content format:
    # byte 0: the selected qop. 1==auth, 2==auth-int, 4==auth-conf
    # byte 1-3: the max length for any buffer sent back and forth on this connection. (big endian)
    # the rest of the buffer: the authorization user name in UTF-8 - not null terminated.
    kerberos.authGSSClientUnwrap(self._kerberos_context, sasl_response)
    data = kerberos.authGSSClientResponse(self._kerberos_context)
    plaintext_data = base64.b64decode(data)
    # Explicit raises instead of `assert`: these checks validate server input
    # and must still run under `python -O`. The type stays AssertionError so
    # existing callers see the same exception.
    if len(plaintext_data) != 4:
        raise AssertionError("Unexpected response from gremlin server sasl handshake")
    word, = struct.unpack('!I', plaintext_data)
    qop_bits = word >> 24
    if not self.QOP_AUTH_BIT & qop_bits:
        raise AssertionError("Unexpected sasl qop level received from gremlin server")

    # Encode before measuring. A non-ASCII name has more UTF-8 bytes than
    # characters, and the previous '!I<len(str)>s' format truncated it.
    username_bytes = self._username.encode('utf-8')
    word = self.QOP_AUTH_BIT << 24 | self.MAX_CONTENT_LENGTH
    out = struct.pack('!I', word) + username_bytes
    encoded = base64.b64encode(out).decode('ascii')
    kerberos.authGSSClientWrap(self._kerberos_context, encoded)
    auth = kerberos.authGSSClientResponse(self._kerberos_context)
    return request.RequestMessage('', 'authentication', {'sasl': auth})
def submit_async(self, message, bindings=None, request_options=None):
    """Build a request from ``message`` if needed and write it on a pooled connection.

    A ``Bytecode`` message is sent with op 'bytecode' to the 'traversal'
    processor. A ``str`` is sent with op 'eval' and, when given,
    ``bindings``. If sessions are enabled, the processor becomes 'session'
    and the session id is added to the args. Any other ``message`` is sent
    unchanged (assumed to already be a request object). Truthy
    ``request_options`` are merged into the message args in every case.

    :returns: whatever ``conn.write`` returns.
    """
    # Pass objects as logging args rather than calling str() up front: the
    # formatting (which can be costly for large bytecode) then only happens
    # when debug logging is actually enabled. The output is identical.
    log.debug("message '%s'", message)
    args = {'gremlin': message, 'aliases': {'g': self._traversal_source}}
    processor = ''
    op = 'eval'
    if isinstance(message, traversal.Bytecode):
        op = 'bytecode'
        processor = 'traversal'

    if isinstance(message, str) and bindings:
        args['bindings'] = bindings

    if self._session_enabled:
        args['session'] = str(self._session)
        processor = 'session'

    if isinstance(message, (traversal.Bytecode, str)):
        log.debug("processor='%s', op='%s', args='%s'", processor, op, args)
        message = request.RequestMessage(processor=processor, op=op, args=args)

    conn = self._pool.get(True)
    if request_options:
        message.args.update(request_options)
    return conn.write(message)
async def _get(self, key):
    """Request side effect ``key`` from the server via a 'gather' request.

    The resulting result set is passed to ``self._aggregate_results`` and
    that coroutine's value is returned.
    """
    gather_args = {
        'sideEffect': self._side_effect,
        'sideEffectKey': key,
        'aliases': self._client.aliases,
    }
    gather_request = request.RequestMessage('traversal', 'gather', gather_args)
    return await self._aggregate_results(await self._client.submit(gather_request))
async def data_received(self, data, results_dict):
    """Route one raw response frame from Gremlin Server to its result set.

    ``data`` is decoded as UTF-8 JSON and passed through the message
    serializer. If its ``requestId`` is not a key of ``results_dict``, the
    frame is ignored after decoding. Otherwise:

    * 407: reply with PLAIN SASL credentials built from ``self._username``
      and ``self._password``;
    * 204: queue ``None`` on the result set;
    * anything else: queue one ``Message`` per deserialized item of
      ``result.data``, or a single ``Message`` wrapping the data when it is
      empty/falsy, then queue ``None`` unless the status is 206.
    """
    payload = self._message_serializer.deserialize_message(
        json.loads(data.decode('utf-8')))
    request_id = payload['requestId']
    status_code = payload['status']['code']
    items = payload['result']['data']
    status_message = payload['status']['message']

    if request_id not in results_dict:
        return
    result_set = results_dict[request_id]
    result_set.aggregate_to = payload['result']['meta'].get('aggregateTo', 'list')

    if status_code == 407:
        credentials = (b'\x00' + self._username.encode('utf-8')
                       + b'\x00' + self._password.encode('utf-8'))
        sasl_reply = request.RequestMessage(
            'traversal', 'authentication',
            {'sasl': base64.b64encode(credentials).decode()})
        await self.write(request_id, sasl_reply)
        return

    if status_code == 204:
        result_set.queue_result(None)
        return

    if items:
        for raw_item in items:
            decoded = self._message_serializer.deserialize_message(raw_item)
            result_set.queue_result(Message(status_code, decoded, status_message))
    else:
        result_set.queue_result(Message(status_code, items, status_message))

    # 206 marks a partial response: more frames follow, so don't terminate.
    if status_code != 206:
        result_set.queue_result(None)
def data_received(self, message, results_dict):
    """Handle one raw JSON response frame from Gremlin Server.

    :param bytes message: UTF-8 encoded JSON frame read from the transport.
    :param dict results_dict: pending result sets keyed by request id.
    :returns: the status code of the handled frame (204, 200 or 206). After a
        407 authentication challenge, the status of the frame read after
        authenticating.
    :raises GremlinServerError: for any other status code.
    """
    message = self._message_serializer.deserialize_message(
        json.loads(message.decode('utf-8')))
    request_id = message['requestId']
    result_set = results_dict[request_id]
    status_code = message['status']['code']
    aggregate_to = message['result']['meta'].get('aggregateTo', 'list')
    data = message['result']['data']
    result_set.aggregate_to = aggregate_to

    if status_code == 407:
        auth = b''.join([
            b'\x00', self._username.encode('utf-8'),
            b'\x00', self._password.encode('utf-8')
        ])
        request_message = request.RequestMessage(
            'traversal', 'authentication',
            {'sasl': base64.b64encode(auth).decode()})
        self.write(request_id, request_message)
        data = self._transport.read()
        # Allow recursive call for auth. Return its status code: previously
        # it was dropped and the caller received None instead.
        return self.data_received(data, results_dict)
    elif status_code == 204:
        result_set.stream.put_nowait([])
        del results_dict[request_id]
        return status_code
    elif status_code in [200, 206]:
        result_set.stream.put_nowait(data)
        if status_code == 200:
            result_set.status_attributes = message['status']['attributes']
            del results_dict[request_id]
        return status_code
    else:
        del results_dict[request_id]
        raise GremlinServerError(message["status"])
def data_received(self, message, results_dict):
    """Handle one response message from Gremlin Server.

    :param message: the frame read from the transport, or ``None`` when the
        server has disconnected.
    :param dict results_dict: pending result sets keyed by request id. A
        frame for an unknown id is processed against a throw-away
        ``ResultSet``.
    :returns: the status code of the handled frame (204, 200 or 206). After a
        407 authentication challenge, the status of the frame that follows.
    :raises GremlinServerError: if the server disconnected or returned an
        error status.
    :raises ConfigurationError: on a 407 challenge when neither
        username/password nor a kerberized service is configured.
    """
    # if Gremlin Server cuts off then we get a None for the message
    if message is None:
        logging.error("Received empty message from server.")
        raise GremlinServerError({
            'code': 500,
            'message': 'Server disconnected - please try to reconnect',
            'attributes': {}
        })

    message = self._message_serializer.deserialize_message(message)
    request_id = message['requestId']
    result_set = results_dict[
        request_id] if request_id in results_dict else ResultSet(
            None, None)
    status_code = message['status']['code']
    aggregate_to = message['result']['meta'].get('aggregateTo', 'list')
    data = message['result']['data']
    result_set.aggregate_to = aggregate_to

    if status_code == 407:
        if self._username and self._password:
            auth_bytes = b''.join([
                b'\x00', self._username.encode('utf-8'),
                b'\x00', self._password.encode('utf-8')
            ])
            auth = base64.b64encode(auth_bytes)
            request_message = request.RequestMessage(
                'traversal', 'authentication', {'sasl': auth.decode()})
        elif self._kerberized_service:
            request_message = self._kerberos_received(message)
        else:
            error_message = (
                'Gremlin server requires authentication credentials in DriverRemoteConnection. '
                'For basic authentication provide username and password. '
                'For kerberos authentication provide the kerberized_service parameter.')
            logging.error(error_message)
            raise ConfigurationError(error_message)
        self.write(request_id, request_message)
        data = self._transport.read()
        # Allow for auth handshake with multiple steps
        return self.data_received(data, results_dict)
    elif status_code == 204:
        result_set.stream.put_nowait([])
        # pop() rather than del: the request id may be unknown (see the
        # throw-away ResultSet above), and del would raise KeyError.
        results_dict.pop(request_id, None)
        return status_code
    elif status_code in [200, 206]:
        result_set.stream.put_nowait(data)
        if status_code == 200:
            result_set.status_attributes = message['status']['attributes']
            results_dict.pop(request_id, None)
        return status_code
    else:
        # This message is going to be huge and kind of hard to read, but in the event of an error,
        # it can provide invaluable info, so space it out appropriately.
        logging.error(
            "\r\nReceived error message '%s'\r\n\r\nWith results dictionary '%s'",
            str(message), str(results_dict))
        # pop() so an unknown request id still surfaces the server error
        # below, instead of being masked by a KeyError.
        results_dict.pop(request_id, None)
        raise GremlinServerError(message['status'])
async def test_stream_done(connection):
    """After an eval stream has been drained, ``stream.done`` is truthy."""
    async with connection:
        eval_request = request.RequestMessage(
            processor='', op='eval', args={'gremlin': '1 + 1'})
        stream = await connection.write(eval_request)
        async for _ in stream:
            pass
        assert stream.done
async def test_server_error(connection):
    """A syntactically invalid script surfaces as GremlinServerError."""
    async with connection:
        bad_script = request.RequestMessage(
            processor='', op='eval', args={'gremlin': 'g. V jla;sdf'})
        with pytest.raises(exception.GremlinServerError):
            stream = await connection.write(bad_script)
            async for _ in stream:
                pass
async def data_received(self, data, results_dict=None):
    """Route one response frame from Gremlin Server to its result set.

    :param data: the frame to deserialize, or ``None`` when the server has
        disconnected.
    :param dict results_dict: pending result sets keyed by request id. A
        frame for an unknown id (or when this is ``None``) is processed
        against a throw-away ``ResultSet``.
    :raises protocol.GremlinServerError: if ``data`` is ``None``.

    Status handling: 407 replies with PLAIN SASL credentials; 204 queues
    ``None``. Any other status queues one ``Message`` per item of
    ``result.data``, or a single ``Message`` wrapping the empty result data,
    then queues ``None`` unless the status is 206 (partial response).
    """
    if data is None:
        raise protocol.GremlinServerError({
            'code': 500,
            'message': 'Server disconnected',
            'attributes': {}
        })
    if results_dict is None:
        results_dict = {}

    # Lazy %r formatting produces the same text as f"{name=}" but skips the
    # repr() work when debug logging is disabled (this runs per result item).
    self.__log.debug("data=%r", data)
    message = self._message_serializer.deserialize_message(data)
    self.__log.debug("message=%r", message)
    request_id = message['requestId']
    status_code = message['status']['code']
    result_data = message['result']['data']
    msg = message['status']['message']

    if request_id in results_dict:
        self.__log.debug("request_id=%r is in results_dict=%r", request_id, results_dict)
        result_set = results_dict[request_id]
    else:
        result_set = ResultSet(None, None)
    aggregate_to = message['result']['meta'].get('aggregateTo', 'list')
    result_set.aggregate_to = aggregate_to

    if status_code == 407:
        auth = b''.join([
            b'\x00', self._username.encode('utf-8'),
            b'\x00', self._password.encode('utf-8')
        ])
        request_message = request.RequestMessage(
            'traversal', 'authentication',
            {'sasl': base64.b64encode(auth).decode()})
        await self.write(request_id, request_message)
    elif status_code == 204:
        self.__log.debug("status_code=%r Queuing None to ResultSet", status_code)
        result_set.queue_result(None)
    else:
        self.__log.debug("status_code=%r", status_code)
        if result_data:
            self.__log.debug("result_data=%r", result_data)
            for result in result_data:
                self.__log.debug("result=%r", result)
                result_message = Message(status_code, result, msg)
                self.__log.debug("Queuing message=%r", result_message)
                result_set.queue_result(result_message)
        else:
            # Wrap the (empty) result data. This previously passed `data`,
            # the raw incoming frame, by mistake.
            result_set.queue_result(Message(status_code, result_data, msg))
        if status_code != 206:
            result_set.queue_result(None)
async def test_connection_response_timeout(connection):
    """A near-zero response timeout makes writing/consuming raise ResponseTimeoutError."""
    async with connection:
        eval_request = request.RequestMessage(
            processor='', op='eval', args={'gremlin': '1 + 1'})
        connection._response_timeout = 0.0000001
        with pytest.raises(exception.ResponseTimeoutError):
            stream = await connection.write(eval_request)
            async for _ in stream:
                pass
async def close(self):
    """Release side effects.

    On the first call, sends a 'close' request for ``self._side_effect``,
    marks this object closed and returns the first result of that request.
    Later calls send nothing and return ``None``.
    """
    if not self._closed:
        message = request.RequestMessage(
            'traversal', 'close',
            {'sideEffect': self._side_effect,
             # Pass the client's alias mapping directly, as the gather and
             # keys requests do. Wrapping it in {'g': ...} nested the dict.
             'aliases': self._client.aliases})
        result_set = await self._client.submit(message)
        self._closed = True
        return await result_set.one()
async def test_204_empty_stream(connection, aliases):
    """A traversal that matches nothing yields an empty stream."""
    async with connection:
        no_match = request.RequestMessage(
            processor='', op='eval',
            args={'gremlin': 'g.V().has("unlikely", "even less likely")'})
        stream = await connection.write(no_match)
        received = [msg async for msg in stream]
        assert not received
async def test_resp_queue_removed_from_conn(connection):
    """A fully consumed stream is no longer tracked in ``connection._result_sets``."""
    async with connection:
        eval_request = request.RequestMessage(
            processor='', op='eval', args={'gremlin': '1 + 1'})
        stream = await connection.submit(eval_request)
        async for _ in stream:
            pass
        # Yield to the event loop once, presumably so any pending cleanup
        # callback can run before the check.
        await asyncio.sleep(0)
        assert stream not in connection._result_sets.values()
async def test_submit(connection):
    """Evaluating '1 + 1' yields exactly one result, equal to 2."""
    async with connection:
        eval_request = request.RequestMessage(
            processor='', op='eval', args={'gremlin': '1 + 1'})
        stream = await connection.write(eval_request)
        results = [msg async for msg in stream]
        assert len(results) == 1
        assert results[0] == 2
def close(self):
    """Send a session 'close' request, wait for it, then close the client.

    The request targets ``self._session_id`` with ``manageTransaction`` and
    ``force`` both ``False``.
    """
    close_args = {
        'session': self._session_id,
        'manageTransaction': False,
        'force': False,
    }
    close_request = request.RequestMessage(
        processor='session', op='close', args=close_args)
    pending = self._pool.get(True).write(close_request)
    # Wait on the write's result before tearing down the client.
    pending.result()
    super(SessionedClient, self).close()
async def keys(self):
    """Get side effect keys associated with Traversal.

    While side effects are still open, the keys are re-fetched from the
    server with a 'keys' request and cached on ``self._keys``. Once closed,
    the cached ``self._keys`` is returned as-is.
    """
    if self._closed:
        return self._keys
    keys_request = request.RequestMessage(
        'traversal', 'keys',
        {'sideEffect': self._side_effect,
         'aliases': self._client.aliases})
    result_set = await self._client.submit(keys_request)
    self._keys = set(await result_set.all())
    return self._keys
def keys(self):
    """Return the side effect keys, re-fetching them while still open.

    When not closed, a 'keys' request is submitted and its results (waited
    on synchronously) replace ``self._keys``. The cached set is returned.
    """
    if self._closed:
        return self._keys
    keys_args = {
        'sideEffect': self._side_effect,
        'aliases': {'g': self._client.traversal_source},
    }
    keys_request = request.RequestMessage('traversal', 'keys', keys_args)
    fetched = self._client.submit(keys_request).all().result()
    self._keys = set(fetched)
    return self._keys
def close(self):
    """Release the side effects held on the server for this traversal.

    On the first call, submits a 'close' request for ``self._side_effect``,
    waits for all of its results, marks this object closed and returns those
    results. Later calls send nothing and return ``None``.
    """
    if self._closed:
        return None
    message = request.RequestMessage(
        'traversal', 'close', {
            'sideEffect': self._side_effect,
            # Use the client's public accessor rather than its private attribute.
            'aliases': {'g': self._client.traversal_source},
        })
    results = self._client.submit(message).all().result()
    self._closed = True
    return results
def submit(gremlin, user_ns, aliases, conn):
    """
    Submit a script to the Gremlin Server.

    Bindings are taken from the IPython namespace, aliases from the Magics
    configuration, and the request is sent on a connection registered with
    :py:class:`ConnectionRegistry<gremlin.registry.ConnectionRegistry>`.

    :param str gremlin: the Gremlin script to evaluate.
    :param user_ns: the IPython namespace. It is passed through
        ``_sanitize_namespace`` and sent as the request ``bindings``.
    :param aliases: sent as the request ``aliases`` and also passed to
        ``_submit``.
    :param conn: the connection the request is submitted on.
    :returns: the result of the ``_submit`` coroutine, scheduled on
        ``registry.LOOP`` and waited for from the calling thread.
    """
    bindings = _sanitize_namespace(user_ns)
    message = request.RequestMessage(
        processor='', op='eval',
        args={'gremlin': gremlin, 'aliases': aliases, 'bindings': bindings})
    future = asyncio.run_coroutine_threadsafe(
        _submit(conn, message, aliases), registry.LOOP)
    return future.result()
async def test_alias_serialization(event_loop):
    """Client aliases survive GraphSON serialization of an eval request."""
    expected_aliases = {'g': 'g'}
    script = '1 + 1'
    cluster = await driver.Cluster.open(event_loop, aliases=expected_aliases)
    client = await cluster.connect()
    # This is the code client/conn uses on submit
    eval_message = request.RequestMessage(
        processor='', op='eval',
        args={'gremlin': script, 'aliases': client._aliases})
    payload = serializer.GraphSONMessageSerializer().serialize_message(
        str(uuid.uuid4()), eval_message)
    # Skip the 34-character header in front of the JSON body. Presumably a
    # length byte plus the 33-character GraphSON mime type -- confirm.
    body = payload.decode('utf-8')[34:]
    assert json.loads(body)['args']['aliases'] == expected_aliases
    await cluster.close()
def get(self, key):
    """Return side effect ``key``, fetching it from the server on first use.

    A fetched value is cached in ``self._side_effects`` and ``key`` is added
    to ``self._keys``. If the key is not cached and side effects are already
    closed, ``None`` is returned.
    """
    # Membership test rather than truthiness: a cached falsy result (e.g. an
    # empty list or 0) used to be re-fetched on every call, and was reported
    # as None once the side effects were closed.
    if key not in self._side_effects:
        if self._closed:
            return None
        message = request.RequestMessage(
            'traversal', 'gather', {
                'sideEffect': self._side_effect,
                'sideEffectKey': key,
                'aliases': {
                    'g': self._client.traversal_source
                }
            })
        results = self._aggregate_results(self._client.submit(message))
        self._side_effects[key] = results
        self._keys.add(key)
    return self._side_effects[key]
def data_received(self, message, results_dict):
    """Handle one response message from Gremlin Server.

    :param message: the frame read from the transport, or ``None`` when the
        server has disconnected.
    :param dict results_dict: pending result sets keyed by request id. A
        frame for an unknown id is processed against a throw-away
        ``ResultSet``.
    :returns: the status code of the handled frame (204, 200 or 206). After a
        407 authentication challenge, the status of the frame that follows.
    :raises GremlinServerError: if the server disconnected or returned an
        error status.
    :raises ConfigurationError: on a 407 challenge when neither
        username/password nor a kerberized service is configured.
    """
    # if Gremlin Server cuts off then we get a None for the message
    if message is None:
        raise GremlinServerError({'code': 500,
                                  'message': 'Server disconnected - please try to reconnect',
                                  'attributes': {}})

    message = self._message_serializer.deserialize_message(message)
    request_id = message['requestId']
    result_set = results_dict[request_id] if request_id in results_dict else ResultSet(None, None)
    status_code = message['status']['code']
    aggregate_to = message['result']['meta'].get('aggregateTo', 'list')
    data = message['result']['data']
    result_set.aggregate_to = aggregate_to

    if status_code == 407:
        if self._username and self._password:
            auth_bytes = b''.join([b'\x00', self._username.encode('utf-8'),
                                   b'\x00', self._password.encode('utf-8')])
            auth = base64.b64encode(auth_bytes)
            request_message = request.RequestMessage(
                'traversal', 'authentication', {'sasl': auth.decode()})
        elif self._kerberized_service:
            request_message = self._kerberos_received(message)
        else:
            # Note the trailing space after "DriverRemoteConnection." -- it
            # was missing, which ran the two sentences together.
            raise ConfigurationError(
                'Gremlin server requires authentication credentials in DriverRemoteConnection. '
                'For basic authentication provide username and password. '
                'For kerberos authentication provide the kerberized_service parameter.')
        self.write(request_id, request_message)
        data = self._transport.read()
        # Allow for auth handshake with multiple steps
        return self.data_received(data, results_dict)
    elif status_code == 204:
        result_set.stream.put_nowait([])
        # pop() rather than del: the request id may be unknown (see the
        # throw-away ResultSet above), and del would raise KeyError.
        results_dict.pop(request_id, None)
        return status_code
    elif status_code in [200, 206]:
        result_set.stream.put_nowait(data)
        if status_code == 200:
            result_set.status_attributes = message['status']['attributes']
            results_dict.pop(request_id, None)
        return status_code
    else:
        results_dict.pop(request_id, None)
        raise GremlinServerError(message['status'])
def data_received(self, message, results_dict):
    """Handle one response message from Gremlin Server.

    :param message: the frame read from the transport, or ``None`` when the
        server has disconnected.
    :param dict results_dict: pending result sets keyed by request id. A
        frame for an unknown id is processed against a throw-away
        ``ResultSet``.
    :returns: the status code of the handled frame (204, 200 or 206). After a
        407 authentication challenge, the status of the frame that follows.
    :raises GremlinServerError: if the server disconnected or returned an
        error status.
    """
    # if Gremlin Server cuts off then we get a None for the message
    if message is None:
        raise GremlinServerError({
            'code': 500,
            'message': 'Server disconnected - please try to reconnect',
            'attributes': {}
        })

    message = self._message_serializer.deserialize_message(message)
    request_id = message['requestId']
    result_set = results_dict[
        request_id] if request_id in results_dict else ResultSet(
            None, None)
    status_code = message['status']['code']
    aggregate_to = message['result']['meta'].get('aggregateTo', 'list')
    data = message['result']['data']
    result_set.aggregate_to = aggregate_to

    if status_code == 407:
        auth = b''.join([
            b'\x00', self._username.encode('utf-8'),
            b'\x00', self._password.encode('utf-8')
        ])
        request_message = request.RequestMessage(
            'traversal', 'authentication',
            {'sasl': base64.b64encode(auth).decode()})
        self.write(request_id, request_message)
        data = self._transport.read()
        # Allow recursive call for auth
        return self.data_received(data, results_dict)
    elif status_code == 204:
        result_set.stream.put_nowait([])
        # pop() rather than del: the request id may be unknown (see the
        # throw-away ResultSet above), and del would raise KeyError.
        results_dict.pop(request_id, None)
        return status_code
    elif status_code in [200, 206]:
        result_set.stream.put_nowait(data)
        if status_code == 200:
            result_set.status_attributes = message['status']['attributes']
            results_dict.pop(request_id, None)
        return status_code
    else:
        results_dict.pop(request_id, None)
        raise GremlinServerError(message["status"])
def data_received(self, data, results_dict):
    """Consume a response frame and any follow-up frames for the same exchange.

    Each frame is decoded as UTF-8 JSON and routed to
    ``results_dict[requestId]``:

    * 407: send PLAIN SASL credentials, read the next frame and continue;
    * 204: push an empty list and drop the request from ``results_dict``;
    * 200/206: deserialize ``result.data`` and push it. On 206, read the next
      frame and continue; on 200, drop the request;
    * anything else: drop the request and raise ``GremlinServerError``.

    Follow-up frames are handled in a loop rather than by recursion, so a
    long run of 206 partial responses cannot exhaust the recursion limit.
    """
    while True:
        data = json.loads(data.decode('utf-8'))
        request_id = data['requestId']
        result_set = results_dict[request_id]
        status_code = data['status']['code']
        aggregate_to = data['result']['meta'].get('aggregateTo', 'list')
        result_set.aggregate_to = aggregate_to

        if status_code == 407:
            auth = b''.join([
                b'\x00', self._username.encode('utf-8'),
                b'\x00', self._password.encode('utf-8')
            ])
            request_message = request.RequestMessage(
                'traversal', 'authentication',
                {'sasl': base64.b64encode(auth).decode()})
            self.write(request_id, request_message)
            data = self._transport.read()
        elif status_code == 204:
            result_set.stream.put_nowait([])
            del results_dict[request_id]
            return
        elif status_code in [200, 206]:
            # this is a bit of a hack for now. basically the protocol.py picks the json apart and doesn't
            # account for types too well right now.
            if self._message_serializer.version == b"application/vnd.gremlin-v2.0+json":
                results = [self._message_serializer.deserialize_message(msg)
                           for msg in data["result"]["data"]]
            else:
                results = self._message_serializer.deserialize_message(
                    data["result"]["data"]["@value"])
            result_set.stream.put_nowait(results)
            if status_code == 206:
                data = self._transport.read()
            else:
                del results_dict[request_id]
                return
        else:
            del results_dict[request_id]
            raise GremlinServerError("{0}: {1}".format(
                status_code, data["status"]["message"]))
def submitAsync(self, message, bindings=None):
    """Evaluate a raw Gremlin string within this client's session.

    Only ``str`` messages are supported. They are sent to the 'session'
    processor with op 'eval', this client's session id, 'g' aliased to the
    traversal source and ``manageTransaction`` set to ``False``. Truthy
    ``bindings`` are merged into the args.

    :raises Exception: if ``message`` is not a ``str``.
    :returns: whatever ``conn.write`` returns.
    """
    if not isinstance(message, str):
        raise Exception('Unsupported message type: {}'.format(
            type(message)))
    session_args = {
        'gremlin': message,
        'aliases': {'g': self._traversal_source},
        'session': self._session_id,
        'manageTransaction': False,
    }
    session_request = request.RequestMessage(
        processor='session', op='eval', args=session_args)
    if bindings:
        session_request.args.update({'bindings': bindings})
    return self._pool.get(True).write(session_request)
def data_received(self, data, results_dict):
    """Consume a response frame and any follow-up frames for the same exchange.

    Each frame is decoded as UTF-8 JSON and routed to
    ``results_dict[requestId]``:

    * 407: send PLAIN SASL credentials, read the next frame and continue;
    * 204: push an empty list and drop the request from ``results_dict``;
    * 200/206: deserialize every item of ``result.data`` and push the list.
      On 206, read the next frame and continue; on 200, drop the request;
    * anything else: drop the request and raise ``GremlinServerError``.

    Follow-up frames are handled in a loop rather than by recursion, so a
    long run of 206 partial responses cannot exhaust the recursion limit.
    """
    while True:
        data = json.loads(data.decode('utf-8'))
        request_id = data['requestId']
        result_set = results_dict[request_id]
        status_code = data['status']['code']
        aggregate_to = data['result']['meta'].get('aggregateTo', 'list')
        result_set.aggregate_to = aggregate_to

        if status_code == 407:
            auth = b''.join([
                b'\x00', self._username.encode('utf-8'),
                b'\x00', self._password.encode('utf-8')
            ])
            request_message = request.RequestMessage(
                'traversal', 'authentication',
                {'sasl': base64.b64encode(auth).decode()})
            self.write(request_id, request_message)
            data = self._transport.read()
        elif status_code == 204:
            result_set.stream.put_nowait([])
            del results_dict[request_id]
            return
        elif status_code in [200, 206]:
            results = [self._message_serializer.deserialize_message(msg)
                       for msg in data["result"]["data"]]
            result_set.stream.put_nowait(results)
            if status_code == 206:
                data = self._transport.read()
            else:
                del results_dict[request_id]
                return
        else:
            del results_dict[request_id]
            raise GremlinServerError("{0}: {1}".format(
                status_code, data["status"]["message"]))
def _close_session(self):
    """Send a 'close' request for ``self._session`` and return the write's result.

    Blocks on ``.result()`` of the value returned by ``conn.write``.
    """
    close_request = request.RequestMessage(
        processor='session', op='close', args={'session': self._session})
    return self._pool.get(True).write(close_request).result()