def test_proto_6():
    """ Tests that make_tensor_proto / make_ndarray match TensorFlow for an empty float32 tensor of shape [1, 0] """
    from diplomacy_research.utils.tensorflow import tf
    tensor_1 = make_tensor_proto(0, dtype=np.float32, shape=[1, 0])
    tensor_2 = tf.make_tensor_proto(0, dtype=tf.float32, shape=[1, 0])
    array_1 = tf.make_ndarray(tensor_1)
    array_2 = make_ndarray(tensor_2)
    assert proto_to_bytes(tensor_1) == proto_to_bytes(tensor_2)
    # The raw bytes of any zero-sized array are b'', so the shape must be checked explicitly
    assert array_1.shape == array_2.shape
    # tobytes() replaces the deprecated (removed in NumPy 2.0) tostring() alias
    assert array_1.tobytes() == array_2.tobytes()
    assert array_1.dtype == array_2.dtype
def test_proto_7():
    """ Tests that make_tensor_proto / make_ndarray match TensorFlow for a random (15, 25) float32 tensor """
    from diplomacy_research.utils.tensorflow import tf
    random_tensor = np.random.rand(15, 25)
    tensor_1 = make_tensor_proto(random_tensor, dtype=np.float32)
    tensor_2 = tf.make_tensor_proto(random_tensor, dtype=tf.float32)
    array_1 = tf.make_ndarray(tensor_1)
    array_2 = make_ndarray(tensor_2)
    assert proto_to_bytes(tensor_1) == proto_to_bytes(tensor_2)
    # Raw bytes alone would not detect a reshape (e.g. (25, 15) vs (15, 25))
    assert array_1.shape == array_2.shape
    # tobytes() replaces the deprecated (removed in NumPy 2.0) tostring() alias
    assert array_1.tobytes() == array_2.tobytes()
    assert array_1.dtype == array_2.dtype
def test_proto_2():
    """ Tests that make_tensor_proto / make_ndarray match TensorFlow for a 1-element string tensor containing b'' """
    from diplomacy_research.utils.tensorflow import tf
    # 'object' is the builtin that the removed np.object alias (NumPy >= 1.24) pointed to
    tensor_1 = make_tensor_proto([bytes('', 'utf-8')], dtype=object, shape=[1])
    tensor_2 = tf.make_tensor_proto([bytes('', 'utf-8')], dtype=tf.string, shape=[1])
    array_1 = tf.make_ndarray(tensor_1)
    array_2 = make_ndarray(tensor_2)
    assert proto_to_bytes(tensor_1) == proto_to_bytes(tensor_2)
    # For object arrays, the raw buffer holds object pointers, not string contents,
    # so compare the shape and the actual element values instead
    assert array_1.shape == array_2.shape
    assert array_1.tolist() == array_2.tolist()
    assert array_1.dtype == array_2.dtype
def get_results(self, queue_name, item, retry_on_failure=True, **kwargs):
    """ Computes the outputs of a queue using item as input
        :param queue_name: The name of the queue where to put the item (or model_name/queue_name)
        :param item: A dictionary with the fields required for that queue. It is not modified.
        :param retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
        :param kwargs: Unused; accepted for interface compatibility.
        :return: A tornado.concurrent.Future that will be set with the results when they become available
                 (None if the queue is unknown or if item is not a dictionary).
                 The results are one array per output key (sorted by key), with the batch dimension removed.
        :raises grpc.RpcError: If the request fails and retry_on_failure is False.
        :raises RuntimeError: If all self.nb_retries attempts fail.

        NOTE(review): this body is a generator (it yields futures and gen.sleep); it only returns a Future
        if it is wrapped as a tornado coroutine where it is defined — confirm.
    """
    import grpc

    if not self.has_queue(queue_name):
        LOGGER.warning('The method "%s" could not be found.', queue_name)
        return None
    if not isinstance(item, dict):
        LOGGER.warning('The item object passed to get_results must be a dictionary.')
        return None

    # Trying to infer the model_name from the queue_name ('model_name/queue_name')
    model_name = self.model_name
    if '/' in queue_name:
        model_name, queue_name = queue_name.split('/')

    # Preparing the item
    # Working on a shallow copy, so the caller's dictionary is not mutated by the 'request_id' insertion
    item = dict(item)
    item['request_id'] = bytes('', 'utf-8')
    item = self.prepare_item(item)

    # Building the request
    request = PredictRequest()
    request.model_spec.name = model_name                                        # pylint: disable=no-member
    request.model_spec.signature_name = queue_name                              # pylint: disable=no-member

    # Setting the keys in items
    # Adding a leading batch dimension, so that TF Serving can batch items properly
    # i.e. (dim_1, dim_2) --> (1, dim_1, dim_2); non-array values are wrapped in a 1-element list
    for key, value in item.items():
        batched_value = value[None, ...] if isinstance(value, np.ndarray) else [value]
        request.inputs[key].CopyFrom(                                           # pylint: disable=no-member
            make_tensor_proto(batched_value, dtype=self.dataset_builder.output_types[key]))

    # Setting the placeholders defined in the signature
    placeholders = self.signature[queue_name].get('placeholders', {})
    for ph_name, (ph_value, ph_dtype) in placeholders.items():
        request.inputs[ph_name].CopyFrom(                                       # pylint: disable=no-member
            make_tensor_proto(ph_value, dtype=ph_dtype))

    # Adding generic default values for features not provided by the prepared item
    for default_key, default_val in self.default_features.items():
        if default_key not in item:
            request.inputs[default_key].CopyFrom(default_val)                   # pylint: disable=no-member

    # Sending the request and processing the response
    # Trying for a maximum of self.nb_retries
    for attempt_ix in range(self.nb_retries):
        try:
            grpc_response = yield wrap_grpc_call(self.predict_stub.Predict.future(request, timeout=self.timeout))
        except grpc.RpcError:
            if not retry_on_failure:
                raise
            if (attempt_ix + 1) % 10 == 0:
                LOGGER.warning('Unable to get results. Attempt %d/%d', attempt_ix + 1, self.nb_retries)
            # Waiting before the next attempt (no point in waiting after the last one, we are about to raise)
            if attempt_ix + 1 < self.nb_retries:
                yield gen.sleep(30)
        else:
            # Removing the leading batch dimension added when building the request
            return [make_ndarray(grpc_response.outputs[key])[0, ...] for key in sorted(grpc_response.outputs)]

    # Raising fatal exception
    raise RuntimeError('Unable to get results after %d attempts. Aborting' % self.nb_retries)