def test_proto_6():
    """ Tests the make_tensor_proto and make_nd_array function

        Builds a float32 tensor proto from the scalar 0 with shape [1, 0], once with the
        `make_tensor_proto` helper in scope and once with `tf.make_tensor_proto`. Each proto is then
        decoded by the *other* implementation, and the serialized protos, the raw array buffers
        and the dtypes are compared.
    """
    from diplomacy_research.utils.tensorflow import tf

    tensor_1 = make_tensor_proto(0, dtype=np.float32, shape=[1, 0])
    tensor_2 = tf.make_tensor_proto(0, dtype=tf.float32, shape=[1, 0])

    # Cross-decoding: TensorFlow reads our proto, our helper reads TensorFlow's proto
    array_1 = tf.make_ndarray(tensor_1)
    array_2 = make_ndarray(tensor_2)

    assert proto_to_bytes(tensor_1) == proto_to_bytes(tensor_2)
    # ndarray.tostring() is a deprecated alias of tobytes() (removed in NumPy 2.0); same bytes
    assert array_1.tobytes() == array_2.tobytes()
    assert array_1.dtype == array_2.dtype
def test_proto_7():
    """ Tests the make_tensor_proto and make_nd_array function

        Converts a random (15, 25) array to a float32 tensor proto, once with the `make_tensor_proto`
        helper in scope and once with `tf.make_tensor_proto`. Each proto is then decoded by the
        *other* implementation, and the serialized protos, the raw array buffers and the dtypes
        are compared.
    """
    from diplomacy_research.utils.tensorflow import tf

    random_tensor = np.random.rand(15, 25)
    tensor_1 = make_tensor_proto(random_tensor, dtype=np.float32)
    tensor_2 = tf.make_tensor_proto(random_tensor, dtype=tf.float32)

    # Cross-decoding: TensorFlow reads our proto, our helper reads TensorFlow's proto
    array_1 = tf.make_ndarray(tensor_1)
    array_2 = make_ndarray(tensor_2)

    assert proto_to_bytes(tensor_1) == proto_to_bytes(tensor_2)
    # ndarray.tostring() is a deprecated alias of tobytes() (removed in NumPy 2.0); same bytes
    assert array_1.tobytes() == array_2.tobytes()
    assert array_1.dtype == array_2.dtype
def test_proto_2():
    """ Tests the make_tensor_proto and make_nd_array function

        Builds a string tensor proto of shape [1] holding a single empty bytes value, once with the
        `make_tensor_proto` helper in scope (object dtype) and once with `tf.make_tensor_proto`
        (tf.string). Each proto is then decoded by the *other* implementation, and the serialized
        protos, the raw array buffers and the dtypes are compared.
    """
    from diplomacy_research.utils.tensorflow import tf

    # `np.object` was a plain alias of the builtin `object` and was removed in NumPy 1.24,
    # so the builtin is used directly (identical value on older NumPy versions).
    tensor_1 = make_tensor_proto([bytes('', 'utf-8')], dtype=object, shape=[1])
    tensor_2 = tf.make_tensor_proto([bytes('', 'utf-8')], dtype=tf.string, shape=[1])

    # Cross-decoding: TensorFlow reads our proto, our helper reads TensorFlow's proto
    array_1 = tf.make_ndarray(tensor_1)
    array_2 = make_ndarray(tensor_2)

    assert proto_to_bytes(tensor_1) == proto_to_bytes(tensor_2)
    # ndarray.tostring() is a deprecated alias of tobytes() (removed in NumPy 2.0); same bytes
    assert array_1.tobytes() == array_2.tobytes()
    assert array_1.dtype == array_2.dtype
def build(self):
    """ Builds the channel and the stub

        Opens an insecure gRPC channel to `self.hostname:self.port`, creates the prediction stub, and
        fills `self.default_features` with one single-item default tensor proto per feature listed in the
        dataset builder's output shapes:
            - object (string) features: an empty bytes value with shape [1]
            - features whose proto field is a VarProtoField: an empty tensor with shape [1, 0]
            - any other feature: a tensor built from the scalar 0 with shape [1] + the feature's shape

        :raises AssertionError: If 'request_id' is not in `self.proto_fields`, or if `is_port_opened`
                                reports that the port is not reachable.
    """
    import grpc
    from diplomacy_research.proto.tensorflow_serving.apis.prediction_service_pb2_grpc import PredictionServiceStub

    assert 'request_id' in self.proto_fields, 'You need to have a "request_id" field.'
    assert is_port_opened(self.port, self.hostname), 'Unable to connect to %s:%d.' % (self.hostname, self.port)

    # Creating insecure channel with corresponding stubs
    self.channel = grpc.insecure_channel('%s:%d' % (self.hostname, self.port))
    self.predict_stub = PredictionServiceStub(self.channel)

    # Padding output shapes with None
    output_types = self.dataset_builder.output_types
    output_shapes = self.dataset_builder.output_shapes
    output_shapes = {key: [None] + list(shape) for key, shape in output_shapes.items()}

    # Building a list of generic default values from the output types and output shapes
    # Note: `np.object` was a plain alias of the builtin `object` and was removed in NumPy 1.24,
    #       so the builtin is used directly (identical value on older NumPy versions).
    for feature_name, feature_shape in output_shapes.items():
        if output_types[feature_name] == object:
            self.default_features[feature_name] = make_tensor_proto(bytes('', 'utf-8'), dtype=object, shape=[1])
        elif isinstance(self.proto_fields[feature_name], VarProtoField):
            self.default_features[feature_name] = make_tensor_proto([],
                                                                    dtype=output_types[feature_name],
                                                                    shape=[1, 0])
        else:
            # feature_shape[1:] drops the leading None added above; the batch dimension is set to 1
            self.default_features[feature_name] = make_tensor_proto(0,
                                                                    dtype=output_types[feature_name],
                                                                    shape=[1] + feature_shape[1:])
def get_results(self, queue_name, item, retry_on_failure=True, **kwargs):
    """ Queries the served model for the outputs of a queue, using item as input

        :param queue_name: The name of the queue where to put the item (or model_name/queue_name)
        :param item: A dictionary with the fields required for that queue.
                     Note: a 'request_id' key (empty bytes) is written into this dictionary.
        :param retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
        :param kwargs: Accepted but not used by this method.
        :return: None if the queue is unknown or if item is not a dictionary.
                 Otherwise, the list of output arrays (one per output key, in sorted key order, with the leading
                 batch dimension removed).
                 NOTE(review): the body yields, so this is presumably wrapped as a tornado coroutine and callers
                 receive a Future - the decorator is not visible here, to be confirmed.
        :raises grpc.RpcError: If the request fails and retry_on_failure is False.
        :raises RuntimeError: If the request still fails after self.nb_retries attempts.
    """
    import grpc

    # Validating arguments
    if not self.has_queue(queue_name):
        LOGGER.warning('The method "%s" could not be found.', queue_name)
        return None
    if not isinstance(item, dict):
        LOGGER.warning('The item object passed to get_results must be a dictionary.')
        return None

    # A queue name of the form "model_name/queue_name" overrides the default model name
    model_name = self.model_name
    if '/' in queue_name:
        model_name, queue_name = queue_name.split('/')

    # Preparing the item (with an empty request id)
    item['request_id'] = bytes('', 'utf-8')
    item = self.prepare_item(item)

    # Building the request
    predict_request = PredictRequest()
    predict_request.model_spec.name = model_name                                # pylint: disable=no-member
    predict_request.model_spec.signature_name = queue_name                      # pylint: disable=no-member

    # Copying every field of the item in the request
    # A leading batch dimension is added, so that TF Serving can batch items properly
    # i.e. (dim_1, dim_2) --> (1, dim_1, dim_2)
    feature_types = self.dataset_builder.output_types
    for key in item:
        value = item[key]
        if isinstance(value, np.ndarray):
            batched_value = value[None, ...]
        else:
            batched_value = [value]
        tensor_proto = make_tensor_proto(batched_value, dtype=feature_types[key])
        predict_request.inputs[key].CopyFrom(tensor_proto)                      # pylint: disable=no-member

    # Copying the placeholders defined in the signature of this queue
    placeholders = self.signature[queue_name].get('placeholders', {})
    for ph_name in placeholders:
        ph_value, ph_dtype = placeholders[ph_name]
        tensor_proto = make_tensor_proto(ph_value, dtype=ph_dtype)
        predict_request.inputs[ph_name].CopyFrom(tensor_proto)                  # pylint: disable=no-member

    # Falling back to the generic default values for features missing from the item
    for default_key, default_val in self.default_features.items():
        if default_key in item:
            continue
        predict_request.inputs[default_key].CopyFrom(default_val)               # pylint: disable=no-member

    # Sending the request and processing the response
    # Trying for a maximum of self.nb_retries
    for attempt_nb in range(1, self.nb_retries + 1):
        try:
            grpc_future = self.predict_stub.Predict.future(predict_request, timeout=self.timeout)
            grpc_response = yield wrap_grpc_call(grpc_future)
            # Removing the batch dimension from each output
            return [make_ndarray(grpc_response.outputs[output_key])[0, ...]
                    for output_key in sorted(grpc_response.outputs)]
        except grpc.RpcError as grpc_error:
            if not retry_on_failure:
                raise grpc_error
            yield gen.sleep(30)

        # Only reached after a failed attempt - Logging every 10 attempts
        if attempt_nb % 10 == 0:
            LOGGER.warning('Unable to get results. Attempt %d/%d', attempt_nb, self.nb_retries)

    # Raising fatal exception
    raise RuntimeError('Unable to get results after %d attempts. Aborting' % self.nb_retries)