Exemplo n.º 1
0
def test_proto_6():
    """ Tests make_tensor_proto / make_ndarray against TensorFlow for an empty float32 tensor of shape [1, 0] """
    from diplomacy_research.utils.tensorflow import tf
    tensor_1 = make_tensor_proto(0, dtype=np.float32, shape=[1, 0])
    tensor_2 = tf.make_tensor_proto(0, dtype=tf.float32, shape=[1, 0])
    array_1 = tf.make_ndarray(tensor_1)
    array_2 = make_ndarray(tensor_2)
    assert proto_to_bytes(tensor_1) == proto_to_bytes(tensor_2)
    # tobytes() is the non-deprecated equivalent of tostring() (byte-identical output)
    assert array_1.tobytes() == array_2.tobytes()
    assert array_1.dtype == array_2.dtype
    assert array_1.dtype == array_2.dtype
Exemplo n.º 2
0
def test_proto_7():
    """ Tests make_tensor_proto / make_ndarray against TensorFlow for a random (15, 25) float32 tensor """
    from diplomacy_research.utils.tensorflow import tf
    random_tensor = np.random.rand(15, 25)
    tensor_1 = make_tensor_proto(random_tensor, dtype=np.float32)
    tensor_2 = tf.make_tensor_proto(random_tensor, dtype=tf.float32)
    array_1 = tf.make_ndarray(tensor_1)
    array_2 = make_ndarray(tensor_2)
    assert proto_to_bytes(tensor_1) == proto_to_bytes(tensor_2)
    # tobytes() is the non-deprecated equivalent of tostring() (byte-identical output)
    assert array_1.tobytes() == array_2.tobytes()
    assert array_1.dtype == array_2.dtype
Exemplo n.º 3
0
def test_proto_2():
    """ Tests make_tensor_proto / make_ndarray against TensorFlow for a string tensor holding one empty bytes value """
    from diplomacy_research.utils.tensorflow import tf
    # The builtin `object` is what the removed `np.object` alias pointed to
    tensor_1 = make_tensor_proto([bytes('', 'utf-8')],
                                 dtype=object,
                                 shape=[1])
    tensor_2 = tf.make_tensor_proto([bytes('', 'utf-8')],
                                    dtype=tf.string,
                                    shape=[1])
    array_1 = tf.make_ndarray(tensor_1)
    array_2 = make_ndarray(tensor_2)
    assert proto_to_bytes(tensor_1) == proto_to_bytes(tensor_2)
    # Object arrays store pointers, so comparing raw bytes would test object identity,
    # not value equality -- compare element values (and shape) instead.
    assert np.array_equal(array_1, array_2)
    assert array_1.dtype == array_2.dtype
Exemplo n.º 4
0
    def get_results(self, queue_name, item, retry_on_failure=True, **kwargs):
        """ Computes the outputs of a queue (signature) using item as input
            :param queue_name: The name of the queue where to put the item (or model_name/queue_name)
            :param item: A dictionary with the fields required for that queue. The caller's dictionary is
                         not modified; a shallow copy is made before the request id is added.
            :param retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
            :param kwargs: Not used by this method.
            :return: A tornado.concurrent.Future that will be set with the results when they become available.
                     The results are a list of output arrays, ordered by sorted output name, with the leading
                     batch dimension removed. Returns None if the queue is unknown or item is not a dictionary.
            :raises grpc.RpcError: If the gRPC call fails and retry_on_failure is False.
            :raises RuntimeError: If all self.nb_retries attempts fail.
        """
        import grpc

        if not self.has_queue(queue_name):
            LOGGER.warning('The method "%s" could not be found.', queue_name)
            return None
        if not isinstance(item, dict):
            LOGGER.warning(
                'The item object passed to get_results must be a dictionary.')
            return None

        # Trying to infer the model_name from the queue_name
        model_name = self.model_name
        if '/' in queue_name:
            model_name, queue_name = queue_name.split('/')

        # Preparing the item (on a shallow copy, so the caller's dict is left untouched)
        item = dict(item)
        item['request_id'] = bytes('', 'utf-8')
        item = self.prepare_item(item)

        # Building the request
        request = PredictRequest()
        request.model_spec.name = model_name  # pylint: disable=no-member
        request.model_spec.signature_name = queue_name  # pylint: disable=no-member

        # Setting the keys in items
        # Adding a leading batch dimension, so that TF Serving can batch items properly
        # i.e. (dim_1, dim_2) --> (1, dim_1, dim_2)
        for key in item:
            batched_item_key = item[key][None, ...] if isinstance(
                item[key], np.ndarray) else [item[key]]
            request.inputs[key].CopyFrom(  # pylint: disable=no-member
                make_tensor_proto(
                    batched_item_key,
                    dtype=self.dataset_builder.output_types[key]))

        # Setting the placeholders defined in the signature
        placeholders = self.signature[queue_name].get('placeholders', {})
        for ph_name in placeholders:
            ph_value, ph_dtype = placeholders[ph_name]
            request.inputs[ph_name].CopyFrom(
                make_tensor_proto(ph_value, dtype=ph_dtype))  # pylint: disable=no-member

        # Adding generic default values (all zeros)
        for default_key, default_val in self.default_features.items():
            if default_key not in item:
                request.inputs[default_key].CopyFrom(default_val)  # pylint: disable=no-member

        # Sending the request and processing the response
        # Trying for a maximum of self.nb_retries
        last_error = None
        for attempt_ix in range(self.nb_retries):
            try:
                grpc_response = yield wrap_grpc_call(
                    self.predict_stub.Predict.future(request,
                                                     timeout=self.timeout))
            except grpc.RpcError as grpc_error:
                if not retry_on_failure:
                    raise
                last_error = grpc_error
            else:
                # Stripping the leading batch dimension added above
                return [
                    make_ndarray(grpc_response.outputs[key])[0, ...]
                    for key in sorted(grpc_response.outputs)
                ]

            if (attempt_ix + 1) % 10 == 0:
                LOGGER.warning('Unable to get results. Attempt %d/%d',
                               attempt_ix + 1, self.nb_retries)

            # Only wait if there is another attempt left; no point delaying the final error
            if attempt_ix + 1 < self.nb_retries:
                yield gen.sleep(30)

        # Raising fatal exception (chained to the last gRPC error, if any)
        raise RuntimeError(
            'Unable to get results after %d attempts. Aborting' %
            self.nb_retries) from last_error