示例#1
0
def do_load(args):
    """Load a serialized BatchList from disk and submit it to a validator.

    Args:
        args: Parsed CLI args with ``filename`` (path to a serialized
            BatchList protobuf) and ``url`` (validator endpoint).
    """
    print("Do load")
    with open(args.filename, 'rb') as fd:
        batches = batch_pb2.BatchList()
        batches.ParseFromString(fd.read())

    stream = Stream(args.url)
    futures = []
    start = time.time()

    # Submit every chunk asynchronously first, so throughput is not
    # bound by per-request round-trip latency.
    for batch_list in _split_batch_list(batches):
        future = stream.send(message_type=Message.CLIENT_BATCH_SUBMIT_REQUEST,
                             content=batch_list.SerializeToString())
        futures.append(future)

    for future in futures:
        result = future.result()
        # Explicit check instead of `assert` so validation still runs
        # when Python is started with -O (asserts are stripped).
        if result.message_type != Message.CLIENT_BATCH_SUBMIT_RESPONSE:
            raise AssertionError(
                "unexpected message type: {}".format(result.message_type))

    stop = time.time()
    print("batches: {} batch/sec: {}".format(
        str(len(batches.batches)),
        len(batches.batches) / (stop - start)))

    stream.close()
示例#2
0
def do_load(args):
    """Read a serialized BatchList from disk and submit it to a validator."""
    with open(args.filename, mode='rb') as fd:
        batches = batch_pb2.BatchList()
        batches.ParseFromString(fd.read())

    stream = Stream(args.url)
    start = time.time()

    # Fire off every chunk asynchronously, keeping the futures so the
    # replies can be checked once everything has been submitted.
    futures = [
        stream.send(
            message_type='system/load',
            content=chunk.SerializeToString())
        for chunk in _split_batch_list(batches)
    ]

    for pending in futures:
        reply = pending.result()
        assert reply.message_type == 'system/load-response'

    elapsed = time.time() - start
    print("batches: {} batch/sec: {}".format(
        str(len(batches.batches)),
        len(batches.batches) / elapsed))

    stream.close()
示例#3
0
def do_load(args):
    """Read a serialized BatchList from disk and submit it to a validator.

    Args:
        args: Parsed CLI args with ``filename`` (path to a serialized
            BatchList protobuf) and ``url`` (validator endpoint).
    """
    with open(args.filename, mode='rb') as fd:
        batches = batch_pb2.BatchList()
        batches.ParseFromString(fd.read())

    stream = Stream(args.url)
    futures = []
    start = time.time()

    for batch_list in _split_batch_list(batches):
        future = stream.send(message_type=Message.CLIENT_BATCH_SUBMIT_REQUEST,
                             content=batch_list.SerializeToString())
        futures.append(future)

    for future in futures:
        # BUGFIX: future.result() is the call that raises
        # ValidatorConnectionError, so it must be inside the try block.
        # Previously only the assert (which can only raise AssertionError)
        # was guarded, so the except clause was unreachable.
        try:
            result = future.result()
        except ValidatorConnectionError as vce:
            LOGGER.warning("the future resolved to %s", vce)
            continue
        assert result.message_type == Message.CLIENT_BATCH_SUBMIT_RESPONSE

    stop = time.time()
    print("batches: {} batch/sec: {}".format(
        str(len(batches.batches)),
        len(batches.batches) / (stop - start)))

    stream.close()
示例#4
0
class TransactionProcessor(object):
    """Registers transaction family handlers with a validator and
    dispatches incoming process requests to a handler.
    """

    def __init__(self, url):
        self._stream = Stream(url)
        self._handlers = []

    def add_handler(self, handler):
        """Add a transaction family handler to dispatch to."""
        self._handlers.append(handler)

    def start(self):
        """Register every handler's (family, version, encoding) triple
        with the validator, then loop forever processing transactions.
        """
        self._stream.connect()

        futures = []
        for handler in self._handlers:
            for version in handler.family_versions:
                for encoding in handler.encodings:
                    future = self._stream.send(
                        message_type='tp/register',
                        content=TransactionProcessorRegisterRequest(
                            family=handler.family_name,
                            version=version,
                            encoding=encoding,
                            namespaces=handler.namespaces).SerializeToString())
                    futures.append(future)

        for future in futures:
            # BUGFIX: call result() -- logging `future.result` only printed
            # the repr of the bound method, never the actual response, and
            # did not wait for registration to resolve.
            LOGGER.info("future result: %s", repr(future.result()))

        while True:
            msg = self._stream.receive()
            LOGGER.info("received %s", msg.message_type)

            request = TransactionProcessRequest()
            request.ParseFromString(msg.content)
            state = State(self._stream, request.context_id)

            try:
                # NOTE(review): only the first registered handler is ever
                # applied, although all handlers are registered above --
                # confirm whether multi-handler dispatch is intended here.
                self._handlers[0].apply(request, state)
                self._stream.send_back(message_type=MessageType.TP_RESPONSE,
                                       correlation_id=msg.correlation_id,
                                       content=TransactionProcessResponse(
                                           status=TransactionProcessResponse.OK
                                       ).SerializeToString())
            except InvalidTransaction as it:
                LOGGER.warning("Invalid Transaction %s", it)
                self._stream.send_back(
                    message_type=MessageType.TP_RESPONSE,
                    correlation_id=msg.correlation_id,
                    content=TransactionProcessResponse(
                        status=TransactionProcessResponse.INVALID_TRANSACTION).
                    SerializeToString())
            except InternalError as ie:
                LOGGER.warning("State Error! %s", ie)
                self._stream.send_back(message_type=MessageType.TP_RESPONSE,
                                       correlation_id=msg.correlation_id,
                                       content=TransactionProcessResponse(
                                           status=TransactionProcessResponse.
                                           INTERNAL_ERROR).SerializeToString())
示例#5
0
def do_load(args):
    """Read a serialized BatchList from disk and submit it to a validator.

    Args:
        args: Parsed CLI args with ``filename`` (path to a serialized
            BatchList protobuf) and ``url`` (validator endpoint).
    """
    print("Do load")
    # BUGFIX: open in binary mode -- the file holds serialized protobuf
    # bytes; text mode corrupts them on platforms that translate line
    # endings and fails outright under Python 3.
    with open(args.filename, 'rb') as fd:
        batches = batch_pb2.BatchList()
        batches.ParseFromString(fd.read())

    stream = Stream(args.url)
    stream.connect()

    futures = []
    start = time.time()

    for batch_list in _split_batch_list(batches):
        future = stream.send(
            message_type='system/load',
            content=batch_list.SerializeToString())
        futures.append(future)

    for future in futures:
        result = future.result()
        assert result.message_type == 'system/load-response'

    stop = time.time()
    print("batches: {} batch/sec: {}".format(
        str(len(batches.batches)),
        len(batches.batches) / (stop - start)))

    stream.close()
示例#6
0
class RouteHandler(object):
    """aiohttp handlers that proxy REST requests to a validator.

    Each handler converts an aiohttp Request into a Protobuf message,
    sends it over the validator stream, parses the Protobuf response,
    and renders the result back to the HTTP client.
    """

    def __init__(self, stream_url):
        self._stream = Stream(stream_url)

    @asyncio.coroutine
    def hello(self, request):
        """Trivial liveness endpoint."""
        text = "Hello World \n"
        return web.Response(text=text)

    @asyncio.coroutine
    def batches(self, request):
        """
        Takes protobuf binary from HTTP POST, and sends it to the validator
        """
        mime_type = 'application/octet-stream'
        type_msg = 'Expected an octet-stream encoded Protobuf binary'
        type_error = web.HTTPBadRequest(reason=type_msg)

        if request.headers['Content-Type'] != mime_type:
            return type_error

        payload = yield from request.read()
        validator_response = self._try_validator_request(
            Message.CLIENT_BATCH_SUBMIT_REQUEST,
            payload
        )
        response = client.ClientBatchSubmitResponse()
        response.ParseFromString(validator_response)
        return RouteHandler._try_client_response(request.headers, response)

    @asyncio.coroutine
    def state_current(self, request):
        # CLIENT_STATE_CURRENT_REQUEST
        return self._generic_get(
            web_request=request,
            msg_type=Message.CLIENT_STATE_CURRENT_REQUEST,
            msg_content=client.ClientStateCurrentRequest(),
            resp_proto=client.ClientStateCurrentResponse,
        )

    @asyncio.coroutine
    def state_list(self, request):
        # CLIENT_STATE_LIST_REQUEST
        root = RouteHandler._safe_get(request.match_info, 'merkle_root')
        # if no prefix is defined return all
        prefix = RouteHandler._safe_get(request.rel_url.query, 'prefix')
        client_request = client.ClientStateListRequest(merkle_root=root,
                                                       prefix=prefix)
        return self._generic_get(
            web_request=request,
            msg_type=Message.CLIENT_STATE_LIST_REQUEST,
            msg_content=client_request,
            resp_proto=client.ClientStateListResponse,
        )

    @asyncio.coroutine
    def state_get(self, request):
        # CLIENT_STATE_GET_REQUEST
        nonleaf_msg = 'Expected a specific leaf address, ' \
                      'but received a prefix instead'

        root = RouteHandler._safe_get(request.match_info, 'merkle_root')
        addr = RouteHandler._safe_get(request.match_info, 'address')
        client_request = client.ClientStateGetRequest(merkle_root=root,
                                                      address=addr)

        validator_response = self._try_validator_request(
            Message.CLIENT_STATE_GET_REQUEST,
            client_request
        )

        parsed_response = RouteHandler._old_response_parse(
            client.ClientStateGetResponse,
            validator_response
        )

        # A prefix (non-leaf) address is a client error for this route.
        if parsed_response.status == client.ClientStateGetResponse.NONLEAF:
            raise web.HTTPBadRequest(reason=nonleaf_msg)

        return RouteHandler._try_client_response(
            request.headers,
            parsed_response
        )

    @asyncio.coroutine
    def block_list(self, request):
        """
        Fetch a list of blocks from the validator
        """
        response = self._query_validator(
            Message.CLIENT_BLOCK_LIST_REQUEST,
            client.ClientBlockListResponse,
            client.ClientBlockListRequest()
        )

        blocks = [RouteHandler._expand_block(b) for b in response['blocks']]
        return RouteHandler._wrap_response(data=blocks)

    @asyncio.coroutine
    def block_get(self, request):
        """
        Fetch a single block from the validator by block id
        """
        block_id = RouteHandler._safe_get(request.match_info, 'block_id')
        # BUGFIX: use a distinct name instead of shadowing the aiohttp
        # `request` parameter with the protobuf request object.
        client_request = client.ClientBlockGetRequest(block_id=block_id)

        response = self._query_validator(
            Message.CLIENT_BLOCK_GET_REQUEST,
            client.ClientBlockGetResponse,
            client_request
        )

        block = RouteHandler._expand_block(response['block'])
        return RouteHandler._wrap_response(data=block)

    @staticmethod
    def _safe_get(obj, key, default=''):
        """
        aiohttp very helpfully parses param strings to replace '+' with ' '
        This is very bad when your block ids contain meaningful +'s
        """
        return obj.get(key, default).replace(' ', '+')

    def _query_validator(self, request_type, response_proto, content):
        """
        Sends a request to the validator and parses the response
        """
        response = self._try_validator_request(request_type, content)
        return RouteHandler._try_response_parse(response_proto, response)

    def _try_validator_request(self, message_type, content):
        """
        Sends a protobuf message to the validator
        Handles a possible timeout if validator is unresponsive
        """
        timeout = 5
        timeout_msg = 'Could not reach validator, validator timed out'

        if isinstance(content, BaseMessage):
            content = content.SerializeToString()

        future = self._stream.send(message_type=message_type, content=content)

        try:
            response = future.result(timeout=timeout)
        except FutureTimeoutError:
            raise web.HTTPServiceUnavailable(reason=timeout_msg)

        return response.content

    @staticmethod
    def _check_status(proto, parsed):
        """
        Raises common validator error statuses as HTTP errors
        Shared by _try_response_parse and _old_response_parse
        """
        unknown_msg = 'An unknown error occurred with your request'
        notfound_msg = 'There is no resource at that root, address or prefix'

        try:
            if parsed.status == proto.ERROR:
                raise web.HTTPInternalServerError(reason=unknown_msg)
            if parsed.status == proto.NORESOURCE:
                raise web.HTTPNotFound(reason=notfound_msg)
        except AttributeError:
            # Not every protobuf has every status, so pass AttributeErrors
            pass

    @staticmethod
    def _try_response_parse(proto, response):
        """
        Parses a protobuf response from the validator into a dict
        Raises common validator error statuses as HTTP errors
        """
        parsed = proto()
        parsed.ParseFromString(response)
        RouteHandler._check_status(proto, parsed)
        return MessageToDict(parsed, preserving_proto_field_name=True)

    @staticmethod
    def _wrap_response(data=None, head=None, link=None):
        """
        Creates a JSON response envelope and sends it back to the client
        """
        envelope = {}

        if data:
            envelope['data'] = data
        if head:
            envelope['head'] = head
        if link:
            envelope['link'] = link

        return web.Response(
            content_type='application/json',
            text=json.dumps(
                envelope,
                indent=2,
                separators=(',', ': '),
                sort_keys=True
            )
        )

    @staticmethod
    def _expand_block(block):
        """Deserializes a block's header, and the headers of its batches."""
        RouteHandler._parse_header(BlockHeader, block)
        if 'batches' in block:
            block['batches'] = [RouteHandler._expand_batch(b)
                                for b in block['batches']]
        return block

    @staticmethod
    def _expand_batch(batch):
        """Deserializes a batch's header, and its transactions' headers."""
        RouteHandler._parse_header(BatchHeader, batch)
        if 'transactions' in batch:
            batch['transactions'] = [RouteHandler._expand_transaction(t)
                                     for t in batch['transactions']]
        return batch

    @staticmethod
    def _expand_transaction(transaction):
        """Deserializes a transaction's header."""
        return RouteHandler._parse_header(TransactionHeader, transaction)

    @staticmethod
    def _parse_header(header_proto, obj):
        """
        A helper method to parse a byte string encoded protobuf 'header'
        Args:
            header_proto: The protobuf class of the encoded header
            obj: The dict formatted object containing the 'header'
        """
        header = header_proto()
        header_bytes = base64.b64decode(obj['header'])
        header.ParseFromString(header_bytes)
        obj['header'] = MessageToDict(header, preserving_proto_field_name=True)
        return obj

    def _generic_get(self, web_request, msg_type, msg_content, resp_proto):
        """
        Used by pre-spec /state routes
        Should be removed when routes are updated to spec
        """
        response = self._try_validator_request(msg_type, msg_content)
        parsed = RouteHandler._old_response_parse(resp_proto, response)
        return RouteHandler._try_client_response(web_request.headers, parsed)

    @staticmethod
    def _old_response_parse(proto, response):
        """
        Used by pre-spec /state routes
        Should be removed when routes are updated to spec
        """
        parsed = proto()
        parsed.ParseFromString(response)
        RouteHandler._check_status(proto, parsed)
        return parsed

    @staticmethod
    def _try_client_response(headers, parsed):
        """
        Used by pre-spec /state and /batches routes
        Should be removed when routes are updated to spec
        """
        media_msg = 'The requested media type is unsupported'
        mime_type = None
        sub_type = None

        try:
            accept_types = headers['Accept']
            mime_type, sub_type, _, _ = parse_mimetype(accept_types)
        except KeyError:
            # No Accept header; fall through to the JSON default below
            pass

        if mime_type == 'application' and sub_type == 'octet-stream':
            return web.Response(
                content_type='application/octet-stream',
                body=parsed.SerializeToString()
            )

        if ((mime_type in ['application', '*'] or mime_type is None)
                and (sub_type in ['json', '*'] or sub_type is None)):
            return web.Response(
                content_type='application/json',
                text=MessageToJson(parsed)
            )

        raise web.HTTPUnsupportedMediaType(reason=media_msg)
示例#7
0
 def __init__(self, url):
     """Remember the validator url and prepare the handler registry.

     Args:
         url (str): Address of the validator to communicate with.
     """
     self._url = url
     self._handlers = []
     self._stream = Stream(url)
示例#8
0
class TransactionProcessor(object):
    """Registers transaction family handlers with a validator and
    dispatches incoming TP_PROCESS_REQUESTs to the matching handler.
    """

    def __init__(self, url):
        self._stream = Stream(url)
        self._url = url
        self._handlers = []

    def add_handler(self, handler):
        """Add a transaction family handler
        :param handler:
        """
        self._handlers.append(handler)

    def _find_handler(self, header):
        """Find a handler for a particular (family_name,
        family_versions, payload_encoding)
        :param header transaction_pb2.TransactionHeader:
        :return: handler
        """
        # Raises IndexError if nothing matches; the validator should only
        # route requests for families this processor registered for.
        return list(filter(lambda h: header.family_name == h.family_name and
                           header.family_version in h.family_versions and
                           header.payload_encoding in h.encodings,
                           self._handlers))[0]

    def _register_requests(self):
        """Returns all of the TpRegisterRequests for handlers

        :return (list): list of TpRegisterRequests
        """
        return itertools.chain.from_iterable(  # flattens the nested list
            [
                [TpRegisterRequest(
                    family=n,
                    version=v,
                    encoding=e,
                    namespaces=h.namespaces)
                 for n, v, e in itertools.product(
                    [h.family_name],
                     h.family_versions,
                     h.encodings)] for h in self._handlers])

    def _unregister_request(self):
        """Returns a single TP_UnregisterRequest that requests
        that the validator stop sending transactions for previously
        registered handlers.

        :return (processor_pb2.TpUnregisterRequest):
        """
        return TpUnregisterRequest()

    def _process(self, msg):
        """Parse a process request, apply the matching handler, and send
        the resulting status back to the validator.

        :param msg: a stream message holding a serialized TpProcessRequest
        """
        request = TpProcessRequest()
        request.ParseFromString(msg.content)
        state = State(self._stream, request.context_id)
        header = TransactionHeader()
        header.ParseFromString(request.header)
        try:
            if not self._stream.is_ready():
                raise ValidatorConnectionError()
            self._find_handler(header).apply(request, state)
            self._stream.send_back(
                message_type=Message.TP_PROCESS_RESPONSE,
                correlation_id=msg.correlation_id,
                content=TpProcessResponse(
                    status=TpProcessResponse.OK
                ).SerializeToString())
        except InvalidTransaction as it:
            LOGGER.warning("Invalid Transaction %s", it)
            try:
                self._stream.send_back(
                    message_type=Message.TP_PROCESS_RESPONSE,
                    correlation_id=msg.correlation_id,
                    content=TpProcessResponse(
                        status=TpProcessResponse.INVALID_TRANSACTION
                    ).SerializeToString())
            except ValidatorConnectionError as vce:
                # TP_PROCESS_REQUEST has made it through the
                # handler.apply and an INVALID_TRANSACTION would have been
                # sent back but the validator has disconnected and so it
                # doesn't care about the response.
                LOGGER.warning("during invalid transaction response: %s", vce)
        except InternalError as ie:
            LOGGER.warning("internal error: %s", ie)
            try:
                self._stream.send_back(
                    message_type=Message.TP_PROCESS_RESPONSE,
                    correlation_id=msg.correlation_id,
                    content=TpProcessResponse(
                        status=TpProcessResponse.INTERNAL_ERROR
                    ).SerializeToString())
            except ValidatorConnectionError as vce:
                # Same as the prior except block, but an internal error has
                # happened, but because of the disconnect the validator
                # probably doesn't care about the response.
                LOGGER.warning("during internal error response: %s", vce)
        except ValidatorConnectionError as vce:
            # Somewhere within handler.apply a future resolved with an
            # error status that the validator has disconnected. There is
            # nothing left to do but reconnect.
            LOGGER.warning("during handler.apply a future was resolved "
                           "with error status: %s", vce)

    def _process_future(self, future, timeout=None, sigint=False):
        """Resolve a future from the stream; re-register on reconnect
        events, otherwise dispatch the message to _process.
        """
        try:
            msg = future.result(timeout)
        except CancelledError:
            # This error is raised when Task.cancel is called on
            # disconnect from the validator in stream.py, for
            # this future.
            return
        if msg is RECONNECT_EVENT:
            if sigint is False:
                LOGGER.info("reregistering with validator")
                self._stream.wait_for_ready()
                self._register()
        else:
            LOGGER.debug(
                'received message of type: %s',
                Message.MessageType.Name(msg.message_type))
            self._process(msg)

    def _register(self):
        """Send a TP_REGISTER_REQUEST for every handler triple and log
        each registration result.
        """
        futures = []
        for message in self._register_requests():
            self._stream.wait_for_ready()
            future = self._stream.send(
                message_type=Message.TP_REGISTER_REQUEST,
                content=message.SerializeToString())
            futures.append(future)

        for future in futures:
            resp = TpRegisterResponse()
            try:
                resp.ParseFromString(future.result().content)
                LOGGER.info("register attempt: %s",
                            TpRegisterResponse.Status.Name(resp.status))
            except ValidatorConnectionError as vce:
                LOGGER.info("during waiting for response on registration: %s",
                            vce)

    def _unregister(self):
        """Ask the validator to stop routing transactions here."""
        message = self._unregister_request()
        self._stream.wait_for_ready()
        future = self._stream.send(
            message_type=Message.TP_UNREGISTER_REQUEST,
            content=message.SerializeToString())
        response = TpUnregisterResponse()
        try:
            response.ParseFromString(future.result(1).content)
            LOGGER.info("unregister attempt: %s",
                        TpUnregisterResponse.Status.Name(response.status))
        except ValidatorConnectionError as vce:
            LOGGER.info("during waiting for response on unregistration: %s",
                        vce)

    def start(self):
        """Register with the validator and process messages until a
        KeyboardInterrupt triggers an orderly unregister and drain.
        """
        fut = None
        try:
            self._register()
            while True:
                # During long running processing this
                # is where the transaction processor will
                # spend most of its time
                fut = self._stream.receive()
                self._process_future(fut)
        except KeyboardInterrupt:
            try:
                # tell the validator to not send any more messages
                self._unregister()
                while True:
                    if fut is not None:
                        # process futures as long as the tp has them,
                        # if the TP_PROCESS_REQUEST doesn't come from
                        # zeromq->asyncio in 1 second raise a
                        # concurrent.futures.TimeOutError and be done.
                        self._process_future(fut, 1, sigint=True)
                        fut = self._stream.receive()
                    else:
                        # BUGFIX: without this break, an interrupt arriving
                        # before the first receive() left fut as None and
                        # this loop spun forever doing nothing.
                        break
            except concurrent.futures.TimeoutError:
                # Where the tp will usually exit after
                # a KeyboardInterrupt. Caused by the 1 second
                # timeout in _process_future.
                pass
            except FutureTimeoutError:
                # If the validator is not able to respond to the
                # unregister request, exit.
                pass

    def stop(self):
        """Close the stream connection to the validator."""
        self._stream.close()
示例#9
0
class Routes(object):
    """aiohttp handlers that proxy REST requests to a validator stream."""

    def __init__(self, stream_url):
        self._stream = Stream(stream_url)

    def _try_validator_request(self, message_type, content):
        """
        Sends a protobuf message to the validator
        Handles a possible timeout if validator is unresponsive
        """
        timeout = 5
        timeout_msg = 'Could not reach validator, validator timed out'

        if isinstance(content, BaseMessage):
            content = content.SerializeToString()

        future = self._stream.send(message_type=message_type, content=content)

        try:
            response = future.result(timeout=timeout)
        except FutureTimeoutError as e:
            print(str(e))
            raise web.HTTPGatewayTimeout(reason=timeout_msg)

        return response.content

    def _try_response_parse(self, proto, response):
        """
        Parses a protobuf response from the validator
        Raises common validator error statuses as HTTP errors
        """
        unknown_msg = 'An unknown error occurred with your request'
        notfound_msg = 'There is no resource at that root, address or prefix'

        parsed = proto()
        parsed.ParseFromString(response)

        try:
            if parsed.status == proto.ERROR:
                raise web.HTTPInternalServerError(reason=unknown_msg)
            if parsed.status == proto.NORESOURCE:
                raise web.HTTPNotFound(reason=notfound_msg)
        except AttributeError:
            # Not every protobuf has every status, so pass AttributeErrors
            pass

        return parsed

    def _try_client_response(self, headers, parsed):
        """
        Sends a response back to the client based on Accept header
        Defaults to JSON
        """
        media_msg = 'The requested media type is unsupported'
        mime_type = None
        sub_type = None

        try:
            accept_types = headers['Accept']
            mime_type, sub_type, _, _ = parse_mimetype(accept_types)
        except KeyError:
            # No Accept header; fall through to the JSON default below
            pass

        if mime_type == 'application' and sub_type == 'octet-stream':
            return web.Response(
                content_type='application/octet-stream',
                body=parsed.SerializeToString()
            )

        if ((mime_type in ['application', '*'] or mime_type is None)
                and (sub_type in ['json', '*'] or sub_type is None)):
            return web.Response(
                content_type='application/json',
                text=MessageToJson(parsed)
            )

        raise web.HTTPUnsupportedMediaType(reason=media_msg)

    def _generic_get(self, web_request, msg_type, msg_content, resp_proto):
        """Shared request/parse/respond pipeline for the GET routes."""
        response = self._try_validator_request(msg_type, msg_content)
        parsed = self._try_response_parse(resp_proto, response)
        return self._try_client_response(web_request.headers, parsed)

    @asyncio.coroutine
    def hello(self, request):
        """Trivial liveness endpoint."""
        text = "Hello World \n"
        return web.Response(text=text)

    @asyncio.coroutine
    def batches(self, request):
        """
        Takes protobuf binary from HTTP POST, and sends it to the validator
        """
        mime_type = 'application/octet-stream'
        type_msg = 'Expected an octet-stream encoded Protobuf binary'
        type_error = web.HTTPBadRequest(reason=type_msg)

        if request.headers['Content-Type'] != mime_type:
            return type_error

        payload = yield from request.read()
        validator_response = self._try_validator_request(
            Message.CLIENT_BATCH_SUBMIT_REQUEST,
            payload
        )
        # BUGFIX: parse the reply as a batch-submit response; it was
        # previously parsed as ClientStateCurrentResponse (copy-paste
        # from the /state_current route).
        response = client.ClientBatchSubmitResponse()
        response.ParseFromString(validator_response)
        return self._try_client_response(request.headers, response)

    @asyncio.coroutine
    def state_current(self, request):
        # CLIENT_STATE_CURRENT_REQUEST
        return self._generic_get(
            web_request=request,
            msg_type=Message.CLIENT_STATE_CURRENT_REQUEST,
            msg_content=client.ClientStateCurrentRequest(),
            resp_proto=client.ClientStateCurrentResponse,
        )

    @asyncio.coroutine
    def state_list(self, request):
        # CLIENT_STATE_LIST_REQUEST
        root = request.match_info.get("merkle_root", "")
        params = request.rel_url.query
        # if no prefix is defined return all
        prefix = params.get("prefix", "")
        client_request = client.ClientStateListRequest(merkle_root=root,
                                                       prefix=prefix)
        return self._generic_get(
            web_request=request,
            msg_type=Message.CLIENT_STATE_LIST_REQUEST,
            msg_content=client_request,
            resp_proto=client.ClientStateListResponse,
        )

    @asyncio.coroutine
    def state_get(self, request):
        # CLIENT_STATE_GET_REQUEST
        nonleaf_msg = 'Expected a specific leaf address, ' \
                      'but received a prefix instead'

        root = request.match_info.get("merkle_root", "")
        addr = request.match_info.get("address", "")
        client_request = client.ClientStateGetRequest(merkle_root=root,
                                                      address=addr)

        validator_response = self._try_validator_request(
            Message.CLIENT_STATE_GET_REQUEST,
            client_request
        )

        parsed_response = self._try_response_parse(
            client.ClientStateGetResponse,
            validator_response
        )

        # A prefix (non-leaf) address is a client error for this route.
        if parsed_response.status == client.ClientStateGetResponse.NONLEAF:
            raise web.HTTPBadRequest(reason=nonleaf_msg)

        return self._try_client_response(
            request.headers,
            parsed_response
        )
示例#10
0
 def __init__(self, loop, stream_url, timeout=DEFAULT_TIMEOUT):
     """Set up the event loop executor, validator stream, and timeout.

     Args:
         loop: The asyncio event loop used to run blocking work.
         stream_url (str): Address of the validator to connect to.
         timeout (int, optional): Seconds before a validator request is
             abandoned. Defaults to DEFAULT_TIMEOUT.
     """
     # Blocking stream calls run on this executor so they do not
     # stall the event loop while waiting on the validator.
     loop.set_default_executor(ThreadPoolExecutor())
     self._loop = loop
     self._stream = Stream(stream_url)
     self._timeout = timeout
示例#11
0
class RouteHandler(object):
    """Contains a number of aiohttp handlers for endpoints in the Rest Api.

    Each handler takes an aiohttp Request object, and uses the data in
    that request to send Protobuf message to a validator. The Protobuf response
    is then parsed, and finally an aiohttp Response object is sent back
    to the client with JSON formatted data and metadata.

    If something goes wrong, an aiohttp HTTP exception is raised or returned
    instead.

    Args:
        stream_url (str): The TCP url to communicate with the validator
        timeout (int, optional): The time in seconds before the Api should
            cancel a request and report that the validator is unavailable.
    """
    def __init__(self, loop, stream_url, timeout=DEFAULT_TIMEOUT):
        # The default executor is used to await blocking Future results
        # off the event loop (see _try_validator_request)
        loop.set_default_executor(ThreadPoolExecutor())
        self._loop = loop
        self._stream = Stream(stream_url)
        self._timeout = timeout

    async def submit_batches(self, request):
        """Accepts a binary encoded BatchList and submits it to the validator.

        Request:
            body: octet-stream BatchList of one or more Batches
            query:
                - wait: Request should not return until all batches committed

        Response:
            status:
                 - 200: Batches submitted, but wait timed out before committed
                 - 201: All batches submitted and committed
                 - 202: Batches submitted and pending (not told to wait)
            data: Status of uncommitted batches (if any, when told to wait)
            link: /batches or /batch_status link for submitted batches

        """
        # Parse request
        if request.headers['Content-Type'] != 'application/octet-stream':
            return errors.WrongBodyType()

        payload = await request.read()
        if not payload:
            return errors.EmptyProtobuf()

        try:
            batch_list = BatchList()
            batch_list.ParseFromString(payload)
        except DecodeError:
            return errors.BadProtobuf()

        # Query validator
        error_traps = [error_handlers.InvalidBatch()]
        validator_query = client_pb2.ClientBatchSubmitRequest(
            batches=batch_list.batches)
        self._set_wait(request, validator_query)

        response = await self._query_validator(
            Message.CLIENT_BATCH_SUBMIT_REQUEST,
            client_pb2.ClientBatchSubmitResponse, validator_query, error_traps)

        # Build response envelope
        data = response['batch_statuses'] or None
        link = '{}://{}/batch_status?id={}'.format(
            request.scheme, request.host,
            ','.join(b.header_signature for b in batch_list.batches))

        if data is None:
            status = 202  # not told to wait: accepted, statuses unknown
        elif any(s != 'COMMITTED' for _, s in data.items()):
            status = 200  # waited, but not everything committed in time
        else:
            status = 201  # all committed; point the link at /batches instead
            data = None
            link = link.replace('batch_status', 'batches')

        return self._wrap_response(data=data,
                                   metadata={'link': link},
                                   status=status)

    async def list_statuses(self, request):
        """Fetches the committed status of batches by either a POST or GET.

        Request:
            body: A JSON array of one or more id strings (if POST)
            query:
                - id: A comma separated list of up to 15 ids (if GET)
                - wait: Request should not return until all batches committed

        Response:
            data: A JSON object, with batch ids as keys, and statuses as values
            link: The /batch_status link queried (if GET)
        """
        error_traps = [error_handlers.StatusesNotReturned()]

        # Parse batch ids from POST body, or query parameters
        if request.method == 'POST':
            if request.headers['Content-Type'] != 'application/json':
                return errors.BadStatusBody()

            ids = await request.json()

            if not isinstance(ids, list):
                return errors.BadStatusBody()
            if len(ids) == 0:
                return errors.MissingStatusId()
            if not isinstance(ids[0], str):
                return errors.BadStatusBody()

        else:
            try:
                ids = request.url.query['id'].split(',')
            except KeyError:
                return errors.MissingStatusId()

        # Query validator
        validator_query = client_pb2.ClientBatchStatusRequest(batch_ids=ids)
        self._set_wait(request, validator_query)

        response = await self._query_validator(
            Message.CLIENT_BATCH_STATUS_REQUEST,
            client_pb2.ClientBatchStatusResponse, validator_query, error_traps)

        # Send response
        if request.method != 'POST':
            metadata = self._get_metadata(request, response)
        else:
            metadata = None

        return self._wrap_response(data=response.get('batch_statuses'),
                                   metadata=metadata)

    async def list_state(self, request):
        """Fetches list of data leaves, optionally filtered by address prefix.

        Request:
            query:
                - head: The id of the block to use as the head of the chain
                - address: Return leaves whose addresses begin with this prefix

        Response:
            data: An array of leaf objects with address and data keys
            head: The head used for this query (most recent if unspecified)
            link: The link to this exact query, including head block
            paging: Paging info and nav, like total resources and a next link
        """
        paging_controls = self._get_paging_controls(request)
        validator_query = client_pb2.ClientStateListRequest(
            head_id=request.url.query.get('head', None),
            address=request.url.query.get('address', None),
            paging=self._make_paging_message(paging_controls))

        response = await self._query_validator(
            Message.CLIENT_STATE_LIST_REQUEST,
            client_pb2.ClientStateListResponse, validator_query)

        return self._wrap_paginated_response(request=request,
                                             response=response,
                                             controls=paging_controls,
                                             data=response.get('leaves', []))

    async def fetch_state(self, request):
        """Fetches data from a specific address in the validator's state tree.

        Request:
            query:
                - head: The id of the block to use as the head of the chain
                - address: The 70 character address of the data to be fetched

        Response:
            data: The base64 encoded binary data stored at that address
            head: The head used for this query (most recent if unspecified)
            link: The link to this exact query, including head block
        """
        error_traps = [
            error_handlers.MissingLeaf(),
            error_handlers.BadAddress()
        ]

        address = request.match_info.get('address', '')
        head = request.url.query.get('head', None)

        response = await self._query_validator(
            Message.CLIENT_STATE_GET_REQUEST,
            client_pb2.ClientStateGetResponse,
            client_pb2.ClientStateGetRequest(head_id=head,
                                             address=address), error_traps)

        return self._wrap_response(data=response['value'],
                                   metadata=self._get_metadata(
                                       request, response))

    async def list_blocks(self, request):
        """Fetches list of blocks from validator, optionally filtered by id.

        Request:
            query:
                - head: The id of the block to use as the head of the chain
                - id: Comma separated list of block ids to include in results

        Response:
            data: JSON array of fully expanded Block objects
            head: The head used for this query (most recent if unspecified)
            link: The link to this exact query, including head block
            paging: Paging info and nav, like total resources and a next link
        """
        paging_controls = self._get_paging_controls(request)
        validator_query = client_pb2.ClientBlockListRequest(
            head_id=request.url.query.get('head', None),
            block_ids=self._get_filter_ids(request),
            paging=self._make_paging_message(paging_controls))

        response = await self._query_validator(
            Message.CLIENT_BLOCK_LIST_REQUEST,
            client_pb2.ClientBlockListResponse, validator_query)

        return self._wrap_paginated_response(
            request=request,
            response=response,
            controls=paging_controls,
            data=[self._expand_block(b) for b in response['blocks']])

    async def fetch_block(self, request):
        """Fetches a specific block from the validator, specified by id.
        Request:
            path:
                - block_id: The 128-character id of the block to be fetched

        Response:
            data: A JSON object with the data from the fully expanded Block
            link: The link to this exact query
        """
        error_traps = [
            error_handlers.MissingBlock(),
            error_handlers.InvalidBlockId()
        ]

        block_id = request.match_info.get('block_id', '')

        response = await self._query_validator(
            Message.CLIENT_BLOCK_GET_REQUEST,
            client_pb2.ClientBlockGetResponse,
            client_pb2.ClientBlockGetRequest(block_id=block_id), error_traps)

        return self._wrap_response(data=self._expand_block(response['block']),
                                   metadata=self._get_metadata(
                                       request, response))

    async def list_batches(self, request):
        """Fetches list of batches from validator, optionally filtered by id.

        Request:
            query:
                - head: The id of the block to use as the head of the chain
                - id: Comma separated list of batch ids to include in results

        Response:
            data: JSON array of fully expanded Batch objects
            head: The head used for this query (most recent if unspecified)
            link: The link to this exact query, including head block
            paging: Paging info and nav, like total resources and a next link
        """
        paging_controls = self._get_paging_controls(request)
        validator_query = client_pb2.ClientBatchListRequest(
            head_id=request.url.query.get('head', None),
            batch_ids=self._get_filter_ids(request),
            paging=self._make_paging_message(paging_controls))

        response = await self._query_validator(
            Message.CLIENT_BATCH_LIST_REQUEST,
            client_pb2.ClientBatchListResponse, validator_query)

        return self._wrap_paginated_response(
            request=request,
            response=response,
            controls=paging_controls,
            data=[self._expand_batch(b) for b in response['batches']])

    async def fetch_batch(self, request):
        """Fetches a specific batch from the validator, specified by id.
        Request:
            path:
                - batch_id: The 128-character id of the batch to be fetched

        Response:
            data: A JSON object with the data from the fully expanded Batch
            link: The link to this exact query
        """
        error_traps = [
            error_handlers.MissingBatch(),
            error_handlers.InvalidBatchId()
        ]

        batch_id = request.match_info.get('batch_id', '')

        response = await self._query_validator(
            Message.CLIENT_BATCH_GET_REQUEST,
            client_pb2.ClientBatchGetResponse,
            client_pb2.ClientBatchGetRequest(batch_id=batch_id), error_traps)

        return self._wrap_response(data=self._expand_batch(response['batch']),
                                   metadata=self._get_metadata(
                                       request, response))

    async def _query_validator(self,
                               request_type,
                               response_proto,
                               content,
                               traps=None):
        """Sends a request to the validator and parses the response.
        """
        response = await self._try_validator_request(request_type, content)
        return self._try_response_parse(response_proto, response, traps)

    async def _try_validator_request(self, message_type, content):
        """Serializes and sends a Protobuf message to the validator.
        Handles timeout errors as needed.
        """
        if isinstance(content, BaseMessage):
            content = content.SerializeToString()

        future = self._stream.send(message_type=message_type, content=content)

        try:
            # future.result blocks, so run it in the loop's default executor
            response = await self._loop.run_in_executor(
                None, future.result, self._timeout)
        except FutureTimeoutError:
            raise errors.ValidatorUnavailable()

        try:
            return response.content
        # Caused by resolving a FutureError on validator disconnect
        except ValidatorConnectionError:
            raise errors.ValidatorUnavailable()

    @classmethod
    def _try_response_parse(cls, proto, response, traps=None):
        """Parses the Protobuf response from the validator.
        Uses "error traps" to send back any HTTP error triggered by a Protobuf
        status, both those common to many handlers, and specified individually.
        """
        parsed = proto()
        parsed.ParseFromString(response)
        traps = traps or []

        try:
            traps.append(error_handlers.Unknown(proto.INTERNAL_ERROR))
        except AttributeError:
            # Not every protobuf has every status enum, so pass AttributeErrors
            pass

        try:
            traps.append(error_handlers.NotReady(proto.NOT_READY))
        except AttributeError:
            pass

        try:
            traps.append(error_handlers.MissingHead(proto.NO_ROOT))
        except AttributeError:
            pass

        try:
            traps.append(error_handlers.InvalidPaging(proto.INVALID_PAGING))
        except AttributeError:
            pass

        for trap in traps:
            trap.check(parsed.status)

        return cls.message_to_dict(parsed)

    @staticmethod
    def _wrap_response(data=None, metadata=None, status=200):
        """Creates the JSON response envelope to be sent back to the client.
        """
        envelope = metadata or {}

        if data is not None:
            envelope['data'] = data

        return web.Response(status=status,
                            content_type='application/json',
                            text=json.dumps(envelope,
                                            indent=2,
                                            separators=(',', ': '),
                                            sort_keys=True))

    @classmethod
    def _wrap_paginated_response(cls, request, response, controls, data):
        """Builds the metadata for a paginated response and wraps everything in
        a JSON encoded web.Response
        """
        head = response['head_id']
        link = cls._build_url(request, head)

        paging_response = response['paging']
        total = paging_response['total_resources']
        paging = {'total_count': total}

        # If there are no resources, there should be nothing else in paging
        if total == 0:
            return cls._wrap_response(data=data,
                                      metadata={
                                          'head': head,
                                          'link': link,
                                          'paging': paging
                                      })

        count = controls.get('count', len(data))
        start = paging_response['start_index']
        paging['start_index'] = start

        # Builds paging urls specific to this response
        def build_pg_url(min_pos=None, max_pos=None):
            return cls._build_url(request, head, count, min_pos, max_pos)

        # Build paging urls based on ids
        if 'start_id' in controls or 'end_id' in controls:
            if paging_response['next_id']:
                paging['next'] = build_pg_url(paging_response['next_id'])
            if paging_response['previous_id']:
                paging['previous'] = build_pg_url(
                    max_pos=paging_response['previous_id'])

        # Build paging urls based on indexes
        else:
            end_index = controls.get('end_index', None)
            if end_index is None and start + count < total:
                paging['next'] = build_pg_url(start + count)
            elif end_index is not None and end_index + 1 < total:
                paging['next'] = build_pg_url(end_index + 1)
            if start - count >= 0:
                paging['previous'] = build_pg_url(start - count)

        return cls._wrap_response(data=data,
                                  metadata={
                                      'head': head,
                                      'link': link,
                                      'paging': paging
                                  })

    @classmethod
    def _get_metadata(cls, request, response):
        """Parses out the head and link properties based on the HTTP Request
        from the client, and the Protobuf response from the validator.
        """
        head = response.get('head_id', None)
        metadata = {'link': cls._build_url(request, head)}

        if head is not None:
            metadata['head'] = head
        return metadata

    @classmethod
    def _build_url(cls,
                   request,
                   head=None,
                   count=None,
                   min_pos=None,
                   max_pos=None):
        """Builds a response URL to send back in response envelope.
        """
        query = request.url.query.copy()

        if head is not None:
            url = '{}://{}{}?head={}'.format(request.scheme, request.host,
                                             request.path, head)
            query.pop('head', None)
        else:
            # With no head, the original request url can be echoed back as-is
            return str(request.url)

        if min_pos is not None:
            url += '&{}={}'.format('min', min_pos)
        elif max_pos is not None:
            url += '&{}={}'.format('max', max_pos)
        else:
            # No paging positions specified; re-attach any remaining queries
            queries = ['{}={}'.format(k, v) for k, v in query.items()]
            return url + '&' + '&'.join(queries) if queries else url

        url += '&{}={}'.format('count', count)
        # These queries were rebuilt above, so drop any stale originals
        query.pop('min', None)
        query.pop('max', None)
        query.pop('count', None)

        queries = ['{}={}'.format(k, v) for k, v in query.items()]
        return url + '&' + '&'.join(queries) if queries else url

    @classmethod
    def _expand_block(cls, block):
        """Deserializes a Block's header, and the header of its Batches.
        """
        cls._parse_header(BlockHeader, block)
        if 'batches' in block:
            block['batches'] = [cls._expand_batch(b) for b in block['batches']]
        return block

    @classmethod
    def _expand_batch(cls, batch):
        """Deserializes a Batch's header, and the header of its Transactions.
        """
        cls._parse_header(BatchHeader, batch)
        if 'transactions' in batch:
            batch['transactions'] = [
                cls._expand_transaction(t) for t in batch['transactions']
            ]
        return batch

    @classmethod
    def _expand_transaction(cls, transaction):
        """Deserializes a Transaction's header.
        """
        return cls._parse_header(TransactionHeader, transaction)

    @classmethod
    def _parse_header(cls, header_proto, obj):
        """Deserializes a base64 encoded Protobuf header.
        """
        header = header_proto()
        header_bytes = base64.b64decode(obj['header'])
        header.ParseFromString(header_bytes)
        obj['header'] = cls.message_to_dict(header)
        return obj

    @staticmethod
    def _get_paging_controls(request):
        """Parses min, max, and/or count queries into a paging controls dict.
        """
        min_pos = request.url.query.get('min', None)
        max_pos = request.url.query.get('max', None)
        count = request.url.query.get('count', None)
        controls = {}

        if count == '0':
            raise errors.BadCount()
        elif count is not None:
            try:
                controls['count'] = int(count)
            except ValueError:
                raise errors.BadCount()

        if min_pos is not None:
            try:
                controls['start_index'] = int(min_pos)
            except ValueError:
                # Non-numeric positions are treated as resource ids
                controls['start_id'] = min_pos

        elif max_pos is not None:
            try:
                controls['end_index'] = int(max_pos)
            except ValueError:
                controls['end_id'] = max_pos

        return controls

    @staticmethod
    def _make_paging_message(controls):
        """Turns a raw paging controls dict into Protobuf PagingControls.
        """
        count = controls.get('count', None)
        end_index = controls.get('end_index', None)

        # an end_index must be changed to start_index, possibly modifying count
        if end_index is not None:
            if count is None:
                start_index = 0
                count = end_index
            elif count > end_index + 1:
                start_index = 0
                count = end_index + 1
            else:
                start_index = end_index + 1 - count
        else:
            start_index = controls.get('start_index', None)

        return client_pb2.PagingControls(start_id=controls.get(
            'start_id', None),
                                         end_id=controls.get('end_id', None),
                                         start_index=start_index,
                                         count=count)

    def _set_wait(self, request, validator_query):
        """Parses the `wait` query parameter, and sets the corresponding
        `wait_for_commit` and `timeout` properties in the validator query.
        """
        wait = request.url.query.get('wait', 'false')
        if wait.lower() != 'false':
            validator_query.wait_for_commit = True
            try:
                validator_query.timeout = int(wait)
            except ValueError:
                # By default, waits for 95% of REST API's configured timeout
                validator_query.timeout = int(self._timeout * 0.95)

    @staticmethod
    def _get_filter_ids(request):
        """Parses the `id` filter parameter from the url query.
        """
        filter_ids = request.url.query.get('id', None)
        # Returns None when no `id` parameter was provided
        return filter_ids and filter_ids.split(',')

    @staticmethod
    def message_to_dict(message):
        """Converts a Protobuf object to a python dict with desired settings.
        """
        return MessageToDict(message,
                             including_default_value_fields=True,
                             preserving_proto_field_name=True)
示例#12
0
 def __init__(self, stream_url, timeout=300):
     """Opens a Stream to the validator and records the request timeout."""
     self._timeout = timeout
     self._stream = Stream(stream_url)
示例#13
0
 def __init__(self, stream_url):
     """Connects a Stream to the validator at *stream_url*."""
     validator_stream = Stream(stream_url)
     self._stream = validator_stream
示例#14
0
class RouteHandler(object):
    """Pre-spec collection of aiohttp handlers that proxy REST requests
    to a validator over a Stream connection.

    Args:
        stream_url (str): The TCP url to communicate with the validator
        timeout (int, optional): Seconds to wait on a validator response
            before reporting the validator unavailable.
    """
    def __init__(self, stream_url, timeout=300):
        self._stream = Stream(stream_url)
        self._timeout = timeout

    @asyncio.coroutine
    def batches_post(self, request):
        """
        Takes protobuf binary from HTTP POST, and sends it to the validator
        """
        if request.headers['Content-Type'] != 'application/octet-stream':
            return errors.WrongBodyType()

        payload = yield from request.read()
        validator_response = self._try_validator_request(
            Message.CLIENT_BATCH_SUBMIT_REQUEST, payload)

        response = client.ClientBatchSubmitResponse()
        response.ParseFromString(validator_response)
        return RouteHandler._try_client_response(request.headers, response)

    @asyncio.coroutine
    def status_list(self, request):
        """
        Fetch the committed statuses for a comma-separated list of batch ids
        """
        error_traps = [error_handlers.MissingStatus()]

        try:
            batch_ids = request.url.query['id'].split(',')
        except KeyError:
            return errors.MissingStatusId()

        response = self._query_validator(
            Message.CLIENT_BATCH_STATUS_REQUEST,
            client.ClientBatchStatusResponse,
            client.ClientBatchStatusRequest(batch_ids=batch_ids), error_traps)

        return RouteHandler._wrap_response(data=response.get('batch_statuses'),
                                           metadata=RouteHandler._get_metadata(
                                               request, response))

    @asyncio.coroutine
    def state_list(self, request):
        """
        Fetch a list of data leaves from the validator's state merkle-tree
        """
        head = request.url.query.get('head', '')
        address = request.url.query.get('address', '')

        response = self._query_validator(
            Message.CLIENT_STATE_LIST_REQUEST, client.ClientStateListResponse,
            client.ClientStateListRequest(head_id=head, address=address))

        return RouteHandler._wrap_response(data=response.get('leaves', []),
                                           metadata=RouteHandler._get_metadata(
                                               request, response))

    @asyncio.coroutine
    def state_get(self, request):
        """
        Fetch a specific data leaf from the validator's state merkle-tree
        """
        error_traps = [
            error_handlers.MissingLeaf(),
            error_handlers.BadAddress()
        ]

        address = request.match_info.get('address', '')
        head = request.url.query.get('head', '')

        response = self._query_validator(
            Message.CLIENT_STATE_GET_REQUEST, client.ClientStateGetResponse,
            client.ClientStateGetRequest(head_id=head, address=address),
            error_traps)

        return RouteHandler._wrap_response(data=response['value'],
                                           metadata=RouteHandler._get_metadata(
                                               request, response))

    @asyncio.coroutine
    def block_list(self, request):
        """
        Fetch a list of blocks from the validator
        """
        head = request.url.query.get('head', '')

        response = self._query_validator(
            Message.CLIENT_BLOCK_LIST_REQUEST, client.ClientBlockListResponse,
            client.ClientBlockListRequest(head_id=head))

        blocks = [RouteHandler._expand_block(b) for b in response['blocks']]
        return RouteHandler._wrap_response(data=blocks,
                                           metadata=RouteHandler._get_metadata(
                                               request, response))

    @asyncio.coroutine
    def block_get(self, request):
        """
        Fetch a specific block from the validator, specified by id
        """
        error_traps = [error_handlers.MissingBlock()]
        block_id = request.match_info.get('block_id', '')

        response = self._query_validator(
            Message.CLIENT_BLOCK_GET_REQUEST, client.ClientBlockGetResponse,
            client.ClientBlockGetRequest(block_id=block_id), error_traps)

        return RouteHandler._wrap_response(
            data=RouteHandler._expand_block(response['block']),
            metadata=RouteHandler._get_metadata(request, response))

    def _query_validator(self, req_type, resp_proto, content, traps=None):
        """
        Sends a request to the validator and parses the response
        """
        response = self._try_validator_request(req_type, content)
        return RouteHandler._try_response_parse(resp_proto, response, traps)

    def _try_validator_request(self, message_type, content):
        """
        Sends a protobuf message to the validator
        Handles a possible timeout if validator is unresponsive
        """
        if isinstance(content, BaseMessage):
            content = content.SerializeToString()

        future = self._stream.send(message_type=message_type, content=content)

        try:
            response = future.result(timeout=self._timeout)
        except FutureTimeoutError:
            raise errors.ValidatorUnavailable()

        try:
            return response.content
            # the error is caused by resolving a FutureError
            # on validator disconnect.
        except ValidatorConnectionError:
            raise errors.ValidatorUnavailable()

    @staticmethod
    def _try_response_parse(proto, response, traps=None):
        """
        Parses a protobuf response from the validator
        Raises common validator error statuses as HTTP errors
        """
        parsed = proto()
        parsed.ParseFromString(response)
        traps = traps or []

        try:
            traps.append(error_handlers.Unknown(proto.INTERNAL_ERROR))
        except AttributeError:
            # Not every protobuf has every status enum, so pass AttributeErrors
            pass
        try:
            traps.append(error_handlers.NotReady(proto.NOT_READY))
        except AttributeError:
            pass
        try:
            traps.append(error_handlers.MissingHead(proto.NO_ROOT))
        except AttributeError:
            pass

        for trap in traps:
            trap.check(parsed.status)

        return MessageToDict(parsed, preserving_proto_field_name=True)

    @staticmethod
    def _wrap_response(data=None, metadata=None):
        """
        Creates a JSON response envelope and sends it back to the client
        """
        envelope = metadata or {}

        if data is not None:
            envelope['data'] = data

        return web.Response(content_type='application/json',
                            text=json.dumps(envelope,
                                            indent=2,
                                            separators=(',', ': '),
                                            sort_keys=True))

    @staticmethod
    def _get_metadata(request, response):
        # Builds the 'head' and 'link' metadata from the request url and the
        # head_id the validator reported (link echoes the url when no head)
        head = response.get('head_id', None)
        if not head:
            return {'link': str(request.url)}

        link = '{}://{}{}?head={}'.format(request.scheme, request.host,
                                          request.path, head)

        headless = filter(lambda i: i[0] != 'head', request.url.query.items())
        queries = ['{}={}'.format(k, v) for k, v in headless]
        if len(queries) > 0:
            link += '&' + '&'.join(queries)

        return {'head': head, 'link': link}

    @staticmethod
    def _expand_block(block):
        # Deserializes a Block's header, and the headers of its Batches
        RouteHandler._parse_header(BlockHeader, block)
        if 'batches' in block:
            block['batches'] = [
                RouteHandler._expand_batch(b) for b in block['batches']
            ]
        return block

    @staticmethod
    def _expand_batch(batch):
        # Deserializes a Batch's header, and the headers of its Transactions
        RouteHandler._parse_header(BatchHeader, batch)
        if 'transactions' in batch:
            batch['transactions'] = [
                RouteHandler._expand_transaction(t)
                for t in batch['transactions']
            ]
        return batch

    @staticmethod
    def _expand_transaction(transaction):
        # Deserializes a Transaction's header
        return RouteHandler._parse_header(TransactionHeader, transaction)

    @staticmethod
    def _parse_header(header_proto, obj):
        """
        A helper method to parse a byte string encoded protobuf 'header'
        Args:
            header_proto: The protobuf class of the encoded header
            obj: The dict formatted object containing the 'header'
        """
        header = header_proto()
        header_bytes = base64.b64decode(obj['header'])
        header.ParseFromString(header_bytes)
        obj['header'] = MessageToDict(header, preserving_proto_field_name=True)
        return obj

    @staticmethod
    def _try_client_response(headers, parsed):
        """
        Used by pre-spec /batches route,
        Should be removed when updated to spec
        """
        media_msg = 'The requested media type is unsupported'
        mime_type = None
        sub_type = None

        try:
            accept_types = headers['Accept']
            mime_type, sub_type, _, _ = parse_mimetype(accept_types)
        except KeyError:
            pass

        if mime_type == 'application' and sub_type == 'octet-stream':
            return web.Response(content_type='application/octet-stream',
                                body=parsed.SerializeToString())

        if ((mime_type in ['application', '*'] or mime_type is None)
                and (sub_type in ['json', '*'] or sub_type is None)):
            return web.Response(content_type='application/json',
                                text=MessageToJson(parsed))

        raise web.HTTPUnsupportedMediaType(reason=media_msg)
示例#15
0
 def __init__(self, url):
     """Connects to the validator at *url* with no handlers registered."""
     self._stop = False
     self._handlers = []
     self._stream = Stream(url)
示例#16
0
class TransactionProcessor(object):
    """Registers transaction handlers with a validator and processes
    incoming transaction-processing requests until stopped.
    """

    def __init__(self, url):
        """
        Args:
            url: Address of the validator to connect the Stream to.
        """
        self._stream = Stream(url)
        self._handlers = []
        self._stop = False

    def add_handler(self, handler):
        """Adds a transaction handler to be registered by start()."""
        self._handlers.append(handler)

    def start(self):
        """Registers every handler's (family, version, encoding) with the
        validator, then loops receiving process requests and replying with
        a TpProcessResponse status until stop() is called.
        """
        futures = []
        for handler in self._handlers:
            for version in handler.family_versions:
                for encoding in handler.encodings:
                    future = self._stream.send(
                        message_type=Message.TP_REGISTER_REQUEST,
                        content=TpRegisterRequest(
                            family=handler.family_name,
                            version=version,
                            encoding=encoding,
                            namespaces=handler.namespaces).SerializeToString())
                    futures.append(future)

        for future in futures:
            # Fix: was repr(future.result), which logged the bound method
            # object rather than the actual registration result.
            LOGGER.debug("future result: %s", repr(future.result()))

        while True:
            if self._stop:
                break
            msg = self._stream.receive()
            LOGGER.debug(
                'received message of type: %s',
                Message.MessageType.Name(msg.message_type))

            request = TpProcessRequest()
            request.ParseFromString(msg.content)
            state = State(self._stream, request.context_id)

            try:
                # NOTE(review): only the first registered handler is ever
                # applied; dispatching by the request's family is presumably
                # intended — confirm before registering multiple handlers.
                self._handlers[0].apply(request, state)
                self._send_process_response(
                    TpProcessResponse.OK, msg.correlation_id)
            except InvalidTransaction as it:
                LOGGER.warning("Invalid Transaction %s", it)
                self._send_process_response(
                    TpProcessResponse.INVALID_TRANSACTION, msg.correlation_id)
            except InternalError as ie:
                LOGGER.warning("State Error! %s", ie)
                self._send_process_response(
                    TpProcessResponse.INTERNAL_ERROR, msg.correlation_id)

    def _send_process_response(self, status, correlation_id):
        """Sends a TP_PROCESS_RESPONSE with the given status back over the
        stream, correlated to the originating request."""
        self._stream.send_back(
            message_type=Message.TP_PROCESS_RESPONSE,
            correlation_id=correlation_id,
            content=TpProcessResponse(status=status).SerializeToString())

    def stop(self):
        """Signals the start() loop to exit after the current message."""
        self._stop = True
示例#17
0
 def __init__(self, url):
     """Opens a Stream to the validator and initializes internal state.

     Args:
         url: Address of the validator to connect the Stream to.
     """
     self._stream = Stream(url)
     # Handlers registered so far (empty until added).
     self._handlers = []
     # Flag checked by the processing loop to exit cleanly.
     self._stop = False
示例#18
0
 def on_validator_discovered(self, url):
     """Opens a Stream to the newly discovered validator at *url* and
     keeps it in the tracked stream list.
     """
     stream = Stream(url)
     self._streams.append(stream)
示例#19
0
class RouteHandler(object):
    """HTTP route handlers that query a validator over a Stream and wrap
    the protobuf responses in JSON envelopes.
    """

    def __init__(self, stream_url, timeout=300):
        """
        Args:
            stream_url: Address of the validator to connect to.
            timeout: Seconds to wait for a validator response (default 300).
        """
        self._stream = Stream(stream_url)
        self._timeout = timeout

    @asyncio.coroutine
    def batches_post(self, request):
        """
        Takes protobuf binary from HTTP POST, and sends it to the validator
        """
        # Reject anything that isn't raw protobuf binary
        if request.headers['Content-Type'] != 'application/octet-stream':
            return errors.WrongBodyType()

        body = yield from request.read()
        if not body:
            return errors.EmptyProtobuf()

        batch_list = BatchList()
        try:
            batch_list.ParseFromString(body)
        except DecodeError:
            return errors.BadProtobuf()

        # Forward the batches to the validator, honoring any `wait` param
        validator_query = client.ClientBatchSubmitRequest(
            batches=batch_list.batches)
        self._set_wait(request, validator_query)

        response = self._query_validator(
            Message.CLIENT_BATCH_SUBMIT_REQUEST,
            client.ClientBatchSubmitResponse,
            validator_query,
            [error_handlers.InvalidBatch()])

        # Build the response envelope
        data = response.get('batch_statuses', None)
        id_query = ','.join(b.header_signature for b in batch_list.batches)
        metadata = {
            'link':
            '{}://{}/batch_status?id={}'.format(
                request.scheme, request.host, id_query)
        }

        if data is None:
            # No statuses returned: accepted, not yet processed
            status = 202
        elif any(s != 'COMMITTED' for _, s in data.items()):
            # Some batches are still pending
            status = 200
        else:
            # Everything committed
            status = 201
            data = None
            # Replace with /batches link when implemented
            metadata = None

        return RouteHandler._wrap_response(data=data,
                                           metadata=metadata,
                                           status=status)

    @asyncio.coroutine
    def status_list(self, request):
        """
        Fetches the status of a set of batches submitted to the validator
        Will wait for batches to commit if the `wait` parameter is set
        """
        try:
            id_param = request.url.query['id']
        except KeyError:
            return errors.MissingStatusId()

        validator_query = client.ClientBatchStatusRequest(
            batch_ids=id_param.split(','))
        self._set_wait(request, validator_query)

        response = self._query_validator(
            Message.CLIENT_BATCH_STATUS_REQUEST,
            client.ClientBatchStatusResponse,
            validator_query,
            [error_handlers.MissingStatus()])

        return RouteHandler._wrap_response(
            data=response.get('batch_statuses'),
            metadata=RouteHandler._get_metadata(request, response))

    @asyncio.coroutine
    def state_list(self, request):
        """
        Fetch a list of data leaves from the validator's state merkle-tree
        """
        query = client.ClientStateListRequest(
            head_id=request.url.query.get('head', ''),
            address=request.url.query.get('address', ''))

        response = self._query_validator(
            Message.CLIENT_STATE_LIST_REQUEST,
            client.ClientStateListResponse,
            query)

        return RouteHandler._wrap_response(
            data=response.get('leaves', []),
            metadata=RouteHandler._get_metadata(request, response))

    @asyncio.coroutine
    def state_get(self, request):
        """
        Fetch a specific data leaf from the validator's state merkle-tree
        """
        query = client.ClientStateGetRequest(
            head_id=request.url.query.get('head', ''),
            address=request.match_info.get('address', ''))

        response = self._query_validator(
            Message.CLIENT_STATE_GET_REQUEST,
            client.ClientStateGetResponse,
            query,
            [error_handlers.MissingLeaf(), error_handlers.BadAddress()])

        return RouteHandler._wrap_response(
            data=response['value'],
            metadata=RouteHandler._get_metadata(request, response))

    @asyncio.coroutine
    def block_list(self, request):
        """
        Fetch a list of blocks from the validator
        """
        response = self._query_validator(
            Message.CLIENT_BLOCK_LIST_REQUEST,
            client.ClientBlockListResponse,
            client.ClientBlockListRequest(
                head_id=request.url.query.get('head', '')))

        expanded = [
            RouteHandler._expand_block(block)
            for block in response['blocks']
        ]
        return RouteHandler._wrap_response(
            data=expanded,
            metadata=RouteHandler._get_metadata(request, response))

    @asyncio.coroutine
    def block_get(self, request):
        """
        Fetch a specific block from the validator by its block id

        (Docstring previously said "a list of blocks", copy-pasted
        from block_list; this handler returns exactly one block.)
        """
        error_traps = [error_handlers.MissingBlock()]
        block_id = request.match_info.get('block_id', '')

        response = self._query_validator(
            Message.CLIENT_BLOCK_GET_REQUEST, client.ClientBlockGetResponse,
            client.ClientBlockGetRequest(block_id=block_id), error_traps)

        return RouteHandler._wrap_response(
            data=RouteHandler._expand_block(response['block']),
            metadata=RouteHandler._get_metadata(request, response))

    def _query_validator(self, req_type, resp_proto, content, traps=None):
        """Sends a request to the validator and parses the protobuf reply.

        Args:
            req_type: Message type enum of the outgoing request.
            resp_proto: Protobuf class expected in the response.
            content: Request payload (protobuf message or raw bytes).
            traps: Optional error handlers checked against the status.
        """
        raw_response = self._try_validator_request(req_type, content)
        return RouteHandler._try_response_parse(
            resp_proto, raw_response, traps)

    def _try_validator_request(self, message_type, content):
        """
        Sends a protobuf message to the validator
        Handles a possible timeout if validator is unresponsive

        Args:
            message_type: Message type enum of the request.
            content: A protobuf message (serialized here) or raw bytes.

        Returns:
            The content bytes of the validator's response.

        Raises:
            errors.ValidatorUnavailable: If the response times out, or the
                validator connection drops while the future resolves.
        """
        if isinstance(content, BaseMessage):
            content = content.SerializeToString()

        future = self._stream.send(message_type=message_type, content=content)

        try:
            response = future.result(timeout=self._timeout)
        except FutureTimeoutError:
            raise errors.ValidatorUnavailable()

        try:
            # Accessing .content can itself raise: the try must wrap this
            # attribute access, not just the send above.
            return response.content
            # the error is caused by resolving a FutureError
            # on validator disconnect.
        except ValidatorConnectionError:
            raise errors.ValidatorUnavailable()

    @staticmethod
    def _try_response_parse(proto, response, traps=None):
        """
        Parses a protobuf response from the validator
        Raises common validator error statuses as HTTP errors
        """
        parsed = proto()
        parsed.ParseFromString(response)

        all_traps = list(traps) if traps else []

        # Not every protobuf has every status enum, so AttributeErrors
        # from missing enum values are deliberately swallowed.
        common_statuses = (
            ('INTERNAL_ERROR', error_handlers.Unknown),
            ('NOT_READY', error_handlers.NotReady),
            ('NO_ROOT', error_handlers.MissingHead),
        )
        for enum_name, trap_cls in common_statuses:
            try:
                all_traps.append(trap_cls(getattr(proto, enum_name)))
            except AttributeError:
                pass

        for trap in all_traps:
            trap.check(parsed.status)

        return MessageToDict(
            parsed,
            including_default_value_fields=True,
            preserving_proto_field_name=True,
        )

    @staticmethod
    def _wrap_response(data=None, metadata=None, status=200):
        """
        Creates a JSON response envelope and sends it back to the client
        """
        envelope = metadata if metadata else {}

        if data is not None:
            envelope['data'] = data

        body = json.dumps(
            envelope, indent=2, separators=(',', ': '), sort_keys=True)
        return web.Response(status=status,
                            content_type='application/json',
                            text=body)

    @staticmethod
    def _get_metadata(request, response):
        """Builds response metadata: the head id (when present) and a link
        back to this query pinned to that head, preserving other params.
        """
        head = response.get('head_id', None)
        if not head:
            return {'link': str(request.url)}

        link = '{}://{}{}?head={}'.format(request.scheme, request.host,
                                          request.path, head)

        extra = [
            '{}={}'.format(key, value)
            for key, value in request.url.query.items()
            if key != 'head'
        ]
        if extra:
            link += '&' + '&'.join(extra)

        return {'head': head, 'link': link}

    def _set_wait(self, request, validator_query):
        """
        Parses the `wait` query parameter, and sets the corresponding
        fields in a validator query
        """
        wait = request.url.query.get('wait', 'false')
        if wait.lower() == 'false':
            return

        validator_query.wait_for_commit = True
        try:
            validator_query.timeout = int(wait)
        except ValueError:
            # By default, waits for 95% of REST API's configured timeout
            validator_query.timeout = int(self._timeout * 0.95)

    @staticmethod
    def _expand_block(block):
        """Deserializes a block's 'header' in place, and the headers of
        any batches the block contains."""
        RouteHandler._parse_header(BlockHeader, block)
        if 'batches' in block:
            block['batches'] = list(
                map(RouteHandler._expand_batch, block['batches']))
        return block

    @staticmethod
    def _expand_batch(batch):
        """Deserializes a batch's 'header' in place, and the headers of
        any transactions the batch contains."""
        RouteHandler._parse_header(BatchHeader, batch)
        if 'transactions' in batch:
            batch['transactions'] = list(
                map(RouteHandler._expand_transaction, batch['transactions']))
        return batch

    @staticmethod
    def _expand_transaction(transaction):
        """Deserializes a transaction's base64 'header' field in place."""
        expanded = RouteHandler._parse_header(TransactionHeader, transaction)
        return expanded

    @staticmethod
    def _parse_header(header_proto, obj):
        """Decodes and expands a byte string encoded protobuf 'header'.

        Args:
            header_proto: The protobuf class of the encoded header
            obj: The dict formatted object containing the 'header'

        Returns:
            The same dict, with obj['header'] replaced by a dict
            representation of the decoded protobuf header.
        """
        raw_header = base64.b64decode(obj['header'])
        parsed = header_proto()
        parsed.ParseFromString(raw_header)
        obj['header'] = MessageToDict(parsed, preserving_proto_field_name=True)
        return obj