コード例 #1
0
def test_proto_6():
    """ Tests make_tensor_proto and make_ndarray on an empty float32 tensor of shape [1, 0]

        Checks that the project implementations and TensorFlow's produce the same serialized
        proto, and that each side's proto round-trips to an identical numpy array.
    """
    from diplomacy_research.utils.tensorflow import tf
    tensor_1 = make_tensor_proto(0, dtype=np.float32, shape=[1, 0])
    tensor_2 = tf.make_tensor_proto(0, dtype=tf.float32, shape=[1, 0])
    array_1 = tf.make_ndarray(tensor_1)
    array_2 = make_ndarray(tensor_2)
    assert proto_to_bytes(tensor_1) == proto_to_bytes(tensor_2)
    # ``tostring`` is a deprecated alias of ``tobytes`` (removed in NumPy 2.0)
    assert array_1.tobytes() == array_2.tobytes()
    assert array_1.dtype == array_2.dtype
コード例 #2
0
def test_proto_7():
    """ Tests make_tensor_proto and make_ndarray on a random (15, 25) float32 tensor

        Checks that the project implementations and TensorFlow's produce the same serialized
        proto, and that each side's proto round-trips to an identical numpy array.
    """
    from diplomacy_research.utils.tensorflow import tf
    random_tensor = np.random.rand(15, 25)
    tensor_1 = make_tensor_proto(random_tensor, dtype=np.float32)
    tensor_2 = tf.make_tensor_proto(random_tensor, dtype=tf.float32)
    array_1 = tf.make_ndarray(tensor_1)
    array_2 = make_ndarray(tensor_2)
    assert proto_to_bytes(tensor_1) == proto_to_bytes(tensor_2)
    # ``tostring`` is a deprecated alias of ``tobytes`` (removed in NumPy 2.0)
    assert array_1.tobytes() == array_2.tobytes()
    assert array_1.dtype == array_2.dtype
コード例 #3
0
def test_proto_2():
    """ Tests make_tensor_proto and make_ndarray on a one-element string (object) tensor

        Checks that the project implementations and TensorFlow's produce the same serialized
        proto, and that each side's proto round-trips to an equivalent numpy array.
    """
    from diplomacy_research.utils.tensorflow import tf
    # ``np.object`` was only an alias of the builtin ``object`` and was removed in NumPy 1.24
    tensor_1 = make_tensor_proto([bytes('', 'utf-8')],
                                 dtype=object,
                                 shape=[1])
    tensor_2 = tf.make_tensor_proto([bytes('', 'utf-8')],
                                    dtype=tf.string,
                                    shape=[1])
    array_1 = tf.make_ndarray(tensor_1)
    array_2 = make_ndarray(tensor_2)
    assert proto_to_bytes(tensor_1) == proto_to_bytes(tensor_2)
    # Object arrays store pointers to Python objects, so their raw buffers (tobytes/tostring)
    # compare memory addresses rather than contents. Compare shape and elements instead.
    assert array_1.shape == array_2.shape
    assert array_1.tolist() == array_2.tolist()
    assert array_1.dtype == array_2.dtype
コード例 #4
0
ファイル: grpc_dataset.py プロジェクト: zhanpengfang/research
    def build(self):
        """ Builds the channel and the stub

            - Asserts that a "request_id" proto field exists and that hostname:port is reachable.
            - Opens an insecure gRPC channel to hostname:port and creates a PredictionService stub.
            - Stores a default TensorProto for every output feature in self.default_features:
                * object (string) features -> a single empty bytes string, shape [1]
                * VarProtoField features    -> an empty tensor, shape [1, 0]
                * all other features        -> zeros, shape [1] + the feature's own shape
        """
        import grpc
        from diplomacy_research.proto.tensorflow_serving.apis.prediction_service_pb2_grpc import PredictionServiceStub

        assert 'request_id' in self.proto_fields, 'You need to have a "request_id" field.'
        assert is_port_opened(
            self.port,
            self.hostname), 'Unable to connect to %s:%d.' % (self.hostname,
                                                             self.port)

        # Creating insecure channel with corresponding stubs
        self.channel = grpc.insecure_channel('%s:%d' %
                                             (self.hostname, self.port))
        self.predict_stub = PredictionServiceStub(self.channel)

        # Padding output shapes with None (leading batch dimension)
        output_types = self.dataset_builder.output_types
        output_shapes = {
            key: [None] + list(shape)
            for key, shape in self.dataset_builder.output_shapes.items()
        }

        # Building a list of generic default values from the output types and output shapes
        # ``np.object`` was only an alias of the builtin ``object`` and was removed in NumPy 1.24
        for feature_name, feature_shape in output_shapes.items():
            if output_types[feature_name] == object:
                self.default_features[feature_name] = make_tensor_proto(
                    bytes('', 'utf-8'), dtype=object, shape=[1])
            elif isinstance(self.proto_fields[feature_name], VarProtoField):
                self.default_features[feature_name] = make_tensor_proto(
                    [], dtype=output_types[feature_name], shape=[1, 0])
            else:
                # Replacing the None batch dimension with 1
                self.default_features[feature_name] = make_tensor_proto(
                    0,
                    dtype=output_types[feature_name],
                    shape=[1] + feature_shape[1:])
コード例 #5
0
ファイル: grpc_dataset.py プロジェクト: zhanpengfang/research
    def get_results(self, queue_name, item, retry_on_failure=True, **kwargs):
        """ Computes the outputs of a queue using item as input
            :param queue_name: The name of the queue where to put the item (or model_name/queue_name)
            :param item: A dictionary with the fields required for that queue. The caller's
                         dictionary is not modified (a shallow copy is used).
            :param retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
            :param kwargs: Not used by this method.
            :return: A tornado.concurrent.Future that will be set with the results when they become available.
                     The results are one numpy array per output key (sorted by key) with the leading batch
                     dimension removed. Returns None if the queue is unknown or item is not a dictionary.
            :raises grpc.RpcError: If the call fails and retry_on_failure is False.
            :raises RuntimeError: If all self.nb_retries attempts fail.
        """
        import grpc

        if not self.has_queue(queue_name):
            LOGGER.warning('The method "%s" could not be found.', queue_name)
            return None
        if not isinstance(item, dict):
            LOGGER.warning(
                'The item object passed to get_results must be a dictionary.')
            return None

        # Trying to infer the model_name from the queue_name
        model_name = self.model_name
        if '/' in queue_name:
            model_name, queue_name = queue_name.split('/')

        # Preparing the item
        # Working on a shallow copy so the caller's dictionary is not mutated
        item = dict(item)
        item['request_id'] = bytes('', 'utf-8')
        item = self.prepare_item(item)

        # Building the request
        request = PredictRequest()
        request.model_spec.name = model_name  # pylint: disable=no-member
        request.model_spec.signature_name = queue_name  # pylint: disable=no-member

        # Setting the keys in items
        # Adding a leading batch dimension, so that TF Serving can batch items properly
        # i.e. (dim_1, dim_2) --> (1, dim_1, dim_2)
        for key in item:
            value = item[key]
            batched_value = value[None, ...] if isinstance(value, np.ndarray) else [value]
            request.inputs[key].CopyFrom(  # pylint: disable=no-member
                make_tensor_proto(batched_value,
                                  dtype=self.dataset_builder.output_types[key]))

        # Setting the placeholders defined in the signature
        placeholders = self.signature[queue_name].get('placeholders', {})
        for ph_name in placeholders:
            ph_value, ph_dtype = placeholders[ph_name]
            request.inputs[ph_name].CopyFrom(
                make_tensor_proto(ph_value, dtype=ph_dtype))  # pylint: disable=no-member

        # Adding generic default values for every feature missing from the item
        for default_key, default_val in self.default_features.items():
            if default_key not in item:
                request.inputs[default_key].CopyFrom(default_val)  # pylint: disable=no-member

        # Sending the request and processing the response
        # Trying for a maximum of self.nb_retries
        last_error = None
        for attempt_ix in range(self.nb_retries):
            try:
                grpc_response = yield wrap_grpc_call(
                    self.predict_stub.Predict.future(request,
                                                     timeout=self.timeout))
            except grpc.RpcError as grpc_error:
                if not retry_on_failure:
                    raise
                last_error = grpc_error
            else:
                # Removing the leading batch dimension added above
                return [
                    make_ndarray(grpc_response.outputs[key])[0, ...]
                    for key in sorted(grpc_response.outputs)
                ]

            if (attempt_ix + 1) % 10 == 0:
                LOGGER.warning('Unable to get results. Attempt %d/%d',
                               attempt_ix + 1, self.nb_retries)

            # Waiting before retrying. There is no point waiting after the final attempt.
            if attempt_ix + 1 < self.nb_retries:
                yield gen.sleep(30)

        # Raising fatal exception, chained to the last gRPC error (if any)
        raise RuntimeError(
            'Unable to get results after %d attempts. Aborting' %
            self.nb_retries) from last_error