Exemple #1
0
    def _ReadHandshakeResponse(self, decoder):
        """Read the server's handshake response and update protocol state.

        Args:
          decoder: Decoder the handshake response message is read from.
        Returns:
          True when the response's 'match' field is 'BOTH' or 'CLIENT',
          False when it is 'NONE' (described by the original docstring as
          "call-response exists").
        Raises:
          schema.AvroException: if 'match' holds any other value.
        """
        response = HANDSHAKE_REQUESTOR_READER.read(decoder)
        logger.info('Processing handshake response: %s', response)
        match = response['match']

        if match not in ('BOTH', 'CLIENT', 'NONE'):
            raise schema.AvroException('handshake_response.match=%r' % match)

        if match == 'BOTH':
            # Client and server protocol hashes both match: keep current state.
            self._send_protocol = False
            return True

        # 'CLIENT' or 'NONE': adopt the server protocol and hash it sent back.
        self._remote_protocol = protocol.Parse(response['serverProtocol'])
        self._remote_hash = response['serverHash']
        neither_matched = match == 'NONE'
        self._send_protocol = neither_matched
        return not neither_matched
Exemple #2
0
  def testEquivalenceAfterRoundTrip(self):
    """Every valid protocol must survive a parse -> str -> parse round trip.

    For each entry of VALID_EXAMPLES the source text is parsed into an
    "original" Protocol, that Protocol is rendered with str() and parsed
    again into a "round trip" Protocol, and the two must compare equal.
    """
    successes = 0
    for ex in VALID_EXAMPLES:
      source_text = ex.protocol_string
      parsed = protocol.Parse(source_text)
      reparsed = protocol.Parse(str(parsed))

      # Written as not (a == b) so only __eq__ is relied upon.
      if not (parsed == reparsed):
        self.fail(
            'Round-trip failure for protocol:\n%s\nOriginal protocol:\n%s'
            % (source_text, str(parsed)))

      successes += 1
      logging.debug(
          'Successful round-trip for protocol:\n%s',
          source_text)

    total = len(VALID_EXAMPLES)
    self.assertEqual(
        successes,
        total,
        'Round trip success on %d out of %d protocols.'
        % (successes, total))
Exemple #3
0
  def testParse(self):
    """Each entry in EXAMPLES must parse if and only if it is marked valid.

    Valid examples must parse without raising; invalid examples must raise.
    The final assertion reports how many examples behaved as expected.
    """
    correct = 0
    for iexample, example in enumerate(EXAMPLES):
      logging.debug(
          'Parsing protocol #%d:\n%s',
          iexample, example.protocol_string)
      # Keep only the Parse call inside the try body. A self.fail() placed in
      # the try would raise AssertionError, be caught by the broad handler
      # below, and an invalid-but-parsed protocol would be counted as correct.
      try:
        protocol.Parse(example.protocol_string)
      except Exception as exn:
        if example.valid:
          self.fail(
              'Valid protocol failed to parse: %s\n%s'
              % (example.protocol_string, traceback.format_exc()))
        # Level 5 is below logging.DEBUG (10): dump the full traceback only
        # when such a verbose level is enabled.
        if logging.getLogger().getEffectiveLevel() <= 5:
          logging.debug('Expected error:\n%s', traceback.format_exc())
        else:
          logging.debug('Expected error: %r', exn)
        correct += 1
      else:
        if not example.valid:
          self.fail(
              'Invalid protocol was parsed:\n%s' % example.protocol_string)
        correct += 1

    self.assertEqual(
      correct,
      len(EXAMPLES),
      'Parse behavior correct on %d out of %d protocols.'
      % (correct, len(EXAMPLES)))
    def testValidCastToStringAfterParse(self):
        """str() of a parsed Protocol must itself be a parseable protocol.

        Every entry in VALID_EXAMPLES is parsed, rendered with str(), and the
        rendered text is parsed again. Reparse failures are logged (with the
        traceback) and counted; the final assertion reports how many of the
        examples succeeded.
        """
        num_correct = 0
        for example in VALID_EXAMPLES:
            proto = protocol.Parse(example.protocol_string)
            try:
                protocol.Parse(str(proto))
            except Exception:
                # Narrowed from a bare ``except:`` so KeyboardInterrupt and
                # SystemExit propagate; exc_info keeps the failure cause.
                logging.debug('Failed to reparse protocol:\n%s',
                              example.protocol_string, exc_info=True)
            else:
                logging.debug('Successfully reparsed protocol:\n%s',
                              example.protocol_string)
                num_correct += 1

        fail_msg = ('Cast to string success on %d out of %d protocols' %
                    (num_correct, len(VALID_EXAMPLES)))
        self.assertEqual(num_correct, len(VALID_EXAMPLES), fail_msg)
Exemple #5
0
    def _ProcessHandshake(self, decoder, encoder):
        """Read a client handshake request and write the handshake response.

        Args:
          decoder: Where the handshake request is read from.
          encoder: Where the handshake response is written to.
        Returns:
          The client's Protocol, taken from the protocol cache or parsed from
          the request's 'clientProtocol'; None when neither supplies one.
        """
        request = HANDSHAKE_RESPONDER_READER.read(decoder)
        logger.info('Processing handshake request: %s', request)

        # Resolve the client's protocol: cache first, then the request body.
        client_hash = request.get('clientHash')
        client_protocol_text = request.get('clientProtocol')
        remote_protocol = self.get_protocol_cache(client_hash)
        if remote_protocol is None and client_protocol_text is not None:
            remote_protocol = protocol.Parse(client_protocol_text)
            self.set_protocol_cache(client_hash, remote_protocol)

        # Compare the client's guess of our hash with the real one.
        guessed_server_hash = request.get('serverHash')
        hashes_agree = self._local_hash == guessed_server_hash

        if remote_protocol is None:
            verdict = 'NONE'
        elif hashes_agree:
            verdict = 'BOTH'
        else:
            verdict = 'CLIENT'

        response = {'match': verdict}
        if verdict != 'BOTH':
            # Anything short of a full match carries our protocol and hash.
            response['serverProtocol'] = str(self.local_protocol)
            response['serverHash'] = self._local_hash

        logger.info('Handshake response: %s', response)
        HANDSHAKE_RESPONDER_WRITER.write(response, encoder)
        return remote_protocol
Exemple #6
0
 def testInnerNamespaceNotRendered(self):
   """An inner type keeps its full name but its JSON omits 'namespace'.

   The first type of the parsed HELLO_WORLD protocol must report the
   fullname 'com.acme.Greeting' and name 'Greeting', while the dict that
   to_json() produces for it must not contain a 'namespace' key.
   """
   proto = protocol.Parse(HELLO_WORLD.protocol_string)
   greeting = proto.types[0]
   self.assertEqual('com.acme.Greeting', greeting.fullname)
   self.assertEqual('Greeting', greeting.name)
   # assertNotIn shows the offending container on failure, unlike assertFalse.
   self.assertNotIn('namespace', proto.to_json()['types'][0])
Exemple #7
0
 def testInnerNamespaceSet(self):
   """The protocol and its Greeting type both report namespace 'com.acme'."""
   parsed = protocol.Parse(HELLO_WORLD.protocol_string)
   expected_ns = 'com.acme'
   self.assertEqual(parsed.namespace, expected_ns)
   self.assertEqual(
       parsed.type_map['com.acme.Greeting'].namespace, expected_ns)
 def __init__(self, proto, msg, datum):
     """Parse the protocol file *proto* and store *msg* and *datum*.

     Args:
       proto: Path to a file containing the protocol JSON.
       msg: Stored on ``self.msg``.
       datum: Stored on ``self.datum``.
     """
     # ``file()`` is a Python 2-only builtin; ``open`` inside ``with`` works on
     # both and closes the handle as soon as the text has been read.
     with open(proto, 'r') as proto_file:
         proto_json = proto_file.read()
     ipc.Responder.__init__(self, protocol.Parse(proto_json))
     self.msg = msg
     self.datum = datum
def send_message(uri, proto, msg, datum):
    """Send one request over HTTP and print the response.

    Args:
      uri: Server URI; only its hostname and port are used.
      proto: Path to a file containing the protocol JSON.
      msg: Message identifier passed to ``Requestor.request``.
      datum: Request datum passed to ``Requestor.request``.
    """
    url_obj = urllib.parse.urlparse(uri)
    client = ipc.HTTPTransceiver(url_obj.hostname, url_obj.port)
    # ``file()`` does not exist on Python 3 (this code already relies on
    # urllib.parse), so read with ``open`` and close the handle immediately.
    with open(proto, 'r') as proto_file:
        proto_json = proto_file.read()
    requestor = ipc.Requestor(protocol.Parse(proto_json), client)
    print(requestor.request(msg, datum))
Exemple #10
0
def generate_protocol(protocol_json,
                      use_logical_types=False,
                      custom_imports=None,
                      avro_json_converter=None):
    """
    Generate the source text of a module containing concrete classes for the record/enum
    schemas and the requests (messages) contained in an Avro protocol.

    The text is written in this order: the output of the preamble helpers, a
    ``SchemaClasses`` class, a ``RequestClasses`` class, a ``__SCHEMA_TYPES`` dict mapping
    full names to the generated classes, and a ``_json_converter`` assignment.

    :param str protocol_json: JSON containing avro protocol
    :param bool use_logical_types: Use logical types extensions if true
    :param list[str] custom_imports: Add additional import modules (None means no extra imports)
    :param str avro_json_converter: AvroJsonConverter type to use for default values.
        Defaults to ``'avrojson.AvroJsonConverter'``. If the string contains no ``(``, the call
        ``(use_logical_types=..., schema_types=__SCHEMA_TYPES)`` is appended to it.
    :return: tuple ``(source, schema_names, request_names)``: ``source`` is the generated
        module text, ``schema_names`` is the set of cleaned full names of the record/enum
        classes written into ``SchemaClasses``, and ``request_names`` is the set of full names
        of the request classes written into ``RequestClasses``.
    """

    if avro_json_converter is None:
        avro_json_converter = 'avrojson.AvroJsonConverter'

    # A value without '(' is treated as a bare class name and turned into a constructor
    # call; a value that already contains '(' is emitted verbatim.
    if '(' not in avro_json_converter:
        avro_json_converter += '(use_logical_types=%s, schema_types=__SCHEMA_TYPES)' % use_logical_types

    custom_imports = custom_imports or []

    if not hasattr(protocol, 'parse'):
        # Older versions of avro used a capital P in Parse.
        proto = protocol.Parse(protocol_json)
    else:
        proto = protocol.parse(protocol_json)

    schemas = []
    messages = []
    schema_names = set()
    request_names = set()

    # Collect record/enum types declared in the protocol's type list together with their
    # index, and remember their cleaned full names.
    known_types = set()
    for schema_idx, record_schema in enumerate(proto.types):
        if isinstance(record_schema, (schema.RecordSchema, schema.EnumSchema)):
            schemas.append((schema_idx, record_schema))
            known_types.add(clean_fullname(record_schema.fullname))

    # Store (message, request, response) per message. The third element is the response
    # schema only when it is a record/enum whose full name is not yet in known_types,
    # otherwise None, so each such response type is emitted at most once.
    # NOTE(review): the PY2 branch assumes proto.messages is a dict of Message objects and
    # the PY3 branch assumes iterating it yields Message objects directly -- confirm against
    # the installed avro package.
    for message in (six.itervalues(proto.messages)
                    if six.PY2 else proto.messages):
        messages.append(
            (message, message.request, message.response
             if isinstance(message.response,
                           (schema.EnumSchema, schema.RecordSchema))
             and clean_fullname(message.response.fullname) not in known_types
             else None))
        if isinstance(message.response,
                      (schema.EnumSchema, schema.RecordSchema)):
            known_types.add(clean_fullname(message.response.fullname))

    # Group everything by namespace: records by their own full name; requests (and
    # responses, stored as the owning message) by the message name qualified with the
    # protocol's namespace.
    namespaces = {}
    for schema_idx, record_schema in schemas:
        ns, name = ns_.split_fullname(clean_fullname(record_schema.fullname))
        if ns not in namespaces:
            namespaces[ns] = {'requests': [], 'records': [], 'responses': []}
        namespaces[ns]['records'].append((schema_idx, record_schema))

    for message, request, response in messages:
        fullname = ns_.make_fullname(proto.namespace,
                                     clean_fullname(message.name))
        ns, name = ns_.split_fullname(fullname)
        if ns not in namespaces:
            namespaces[ns] = {'requests': [], 'records': [], 'responses': []}
        namespaces[ns]['requests'].append(message)
        if response:
            namespaces[ns]['responses'].append(message)

    main_out = StringIO()
    writer = TabbedWriter(main_out)

    write_preamble(writer, use_logical_types, custom_imports)
    write_protocol_preamble(writer, use_logical_types, custom_imports)
    write_get_schema(writer)
    write_populate_schemas(writer)

    # SchemaClasses: one generated class per record/enum and per new response type.
    writer.write('\n\n\nclass SchemaClasses(object):')
    with writer.indent():
        writer.write('\n\n')

        current_namespace = tuple()
        all_ns = sorted(namespaces.keys())

        for ns in all_ns:
            if not (namespaces[ns]['responses'] or namespaces[ns]['records']):
                continue

            namespace = ns.split('.')
            # NOTE(review): ``namespace`` is a list and ``current_namespace`` is a tuple
            # that is never reassigned, so this comparison is always True and
            # start_namespace() runs for every namespace with () as its first argument --
            # confirm start_namespace is written to expect that.
            if namespace != current_namespace:
                start_namespace(current_namespace, namespace, writer)

            for idx, record in namespaces[ns]['records']:
                schema_names.add(clean_fullname(record.fullname))
                if isinstance(record, schema.RecordSchema):
                    write_schema_record(record, writer, use_logical_types)
                elif isinstance(record, schema.EnumSchema):
                    write_enum(record, writer)

            for message in namespaces[ns]['responses']:
                schema_names.add(clean_fullname(message.response.fullname))
                if isinstance(message.response, schema.RecordSchema):
                    write_schema_record(message.response, writer,
                                        use_logical_types)
                elif isinstance(message.response, schema.EnumSchema):
                    write_enum(message.response, writer)

        # Presumably keeps the generated class body valid even when nothing was
        # emitted above.
        writer.write('\n\npass')

    # RequestClasses: one generated request class per protocol message.
    writer.set_tab(0)
    writer.write('\n\n\nclass RequestClasses(object):')
    # ``indent`` is bound but not used below.
    with writer.indent() as indent:
        writer.write('\n\n')

        current_namespace = tuple()
        all_ns = sorted(namespaces.keys())

        for ns in all_ns:
            if not (namespaces[ns]['requests'] or namespaces[ns]['responses']):
                continue

            namespace = ns.split('.')
            # Same always-True list-vs-tuple comparison as in the SchemaClasses loop.
            if namespace != current_namespace:
                start_namespace(current_namespace, namespace, writer)

            for message in namespaces[ns]['requests']:
                request_names.add(
                    ns_.make_fullname(proto.namespace,
                                      clean_fullname(message.name)))
                write_protocol_request(message, proto.namespace, writer,
                                       use_logical_types)

        writer.write('\n\npass')

    # Module-level __SCHEMA_TYPES dict: full name -> generated class. The dotted full
    # name is used verbatim as an attribute path (e.g. SchemaClasses.<fullname>Class);
    # presumably start_namespace() emits the nested classes that make such paths resolve.
    writer.untab()
    writer.set_tab(0)
    writer.write('\n__SCHEMA_TYPES = {\n')
    writer.tab()

    all_ns = sorted(namespaces.keys())
    for ns in all_ns:
        for idx, record in (namespaces[ns]['records'] or []):
            writer.write("'%s': SchemaClasses.%sClass,\n" % (clean_fullname(
                record.fullname), clean_fullname(record.fullname)))

        for message in (namespaces[ns]['responses'] or []):
            writer.write("'%s': SchemaClasses.%sClass,\n" %
                         (clean_fullname(message.response.fullname),
                          clean_fullname(message.response.fullname)))

        for message in (namespaces[ns]['requests'] or []):
            name = ns_.make_fullname(proto.namespace,
                                     clean_fullname(message.name))
            writer.write("'%s': RequestClasses.%sRequestClass, \n" %
                         (name, name))

    writer.untab()
    writer.write('\n}\n')

    writer.write('_json_converter = %s\n\n' % avro_json_converter)
    value = main_out.getvalue()
    main_out.close()
    return value, schema_names, request_names
Exemple #11
0
     }
 ],

 "messages": {
     "send": {
         "request": [{"name": "message", "type": "Message"}],
         "response": "string"
     },
     "replay": {
         "request": [],
         "response": "string"
     }
 }
}
"""
# Parsed form of MAIL_PROTOCOL_JSON; MailResponder serves this protocol.
MAIL_PROTOCOL = protocol.Parse(MAIL_PROTOCOL_JSON)
# Host/port pair; presumably the address the sample server binds to -- TODO confirm.
SERVER_ADDRESS = ('localhost', 9090)

class MailResponder(ipc.Responder):
  """Responder answering MAIL_PROTOCOL's 'send' and 'replay' messages."""

  def __init__(self):
    ipc.Responder.__init__(self, MAIL_PROTOCOL)

  def invoke(self, message, request):
    """Build the string reply for *message*.

    'send' formats the request's 'message' record (keys 'to', 'from' and
    'body') into a confirmation sentence; 'replay' answers with the literal
    'replay'. Any other message name yields None.
    """
    name = message.name
    if name == 'send':
      fields = request['message']
      return ("Sent message to %(to)s from %(from)s with body %(body)s"
              % fields)
    if name == 'replay':
      return 'replay'
    return None
Exemple #12
0
    } ]
  } ],
  "messages" : {
    "ping" : {
      "request" : [ {
        "name" : "ping",
        "type" : "Ping"
      } ],
      "response" : "Pong"
    }
  }
}
"""


# Parsed form of ECHO_PROTOCOL_JSON; EchoResponder uses it as its local protocol.
ECHO_PROTOCOL = protocol.Parse(ECHO_PROTOCOL_JSON)


class EchoResponder(ipc.Responder):
  """Responder for ECHO_PROTOCOL that echoes the incoming ping back."""

  def __init__(self):
    super(EchoResponder, self).__init__(local_protocol=ECHO_PROTOCOL)

  def Invoke(self, message, request):
    """Log the call and reply with the request's 'ping' plus a timestamp."""
    logging.info('Message: %s', message)
    logging.info('Request: %s', request)
    # Read the ping before taking the timestamp, as the reply pairs them.
    echoed = request['ping']
    reply = dict(timestamp=NowMS(), ping=echoed)
    return reply

#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import http.client

import avro.ipc as ipc
import avro.protocol as protocol

# Parse the mail protocol once at import time. The path is resolved relative to
# the current working directory; the ``with`` block closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open("../avro/mail.avpr") as _mail_avpr_file:
    PROTOCOL = protocol.Parse(_mail_avpr_file.read())
del _mail_avpr_file

server_addr = ('localhost', 9090)


class UsageError(Exception):
    """Raised when the script is invoked with the wrong arguments.

    The message is kept on ``value``; ``str()`` renders its repr().
    """

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return '%r' % (self.value,)


if __name__ == '__main__':
    # Require exactly three arguments after the script name: <to> <from> <body>.
    if len(sys.argv) != 4:
        raise UsageError("Usage: <to> <from> <body>")
def __get_protocol(file_name):
    """Parse the Avro protocol stored in *file_name*.

    Uses ``avro_protocol.Parse`` on Python 3 and ``avro_protocol.parse``
    otherwise, feeding it the text returned by ``__read_file``.
    """
    parse = avro_protocol.Parse if six.PY3 else avro_protocol.parse
    return parse(__read_file(file_name))
Exemple #15
0
 def test_Parse_is_deprecated(self, warn):
   """Calling capital-P ``protocol.Parse`` triggers *warn* exactly once.

   *warn* is presumably a mock patched in by a decorator -- confirm.
   """
   source = HELLO_WORLD.protocol_string
   protocol.Parse(source)
   self.assertEqual(warn.call_count, 1)