def __init__(self, tensor_dict, batch_size, batch_queue_capacity,
             num_batch_queue_threads, prefetch_queue_capacity):
  """Constructs a batch queue holding tensor_dict.

  The caller's `tensor_dict` is left untouched: the runtime-shape entries
  needed for batching are added to an internal copy only.

  Args:
    tensor_dict: dictionary of tensors to batch.
    batch_size: batch size.
    batch_queue_capacity: max capacity of the queue from which the tensors are
      batched.
    num_batch_queue_threads: number of threads to use for batching.
    prefetch_queue_capacity: max capacity of the queue used to prefetch
      assembled batches.
  """
  # Remember static shapes to set shapes of batched tensors.
  static_shapes = collections.OrderedDict(
      {key: tensor.get_shape() for key, tensor in tensor_dict.items()})
  # Remember runtime shapes to unpad tensors after batching. Each entry is
  # stored under the original key suffixed with `rt_shape_str`.
  runtime_shapes = collections.OrderedDict(
      {(key + rt_shape_str): tf.shape(tensor)
       for key, tensor in tensor_dict.items()})

  # Work on a shallow copy: updating `tensor_dict` itself would leak the
  # runtime-shape entries into the dictionary owned by the caller.
  all_tensors = dict(tensor_dict)
  all_tensors.update(runtime_shapes)

  batched_tensors = tf.train.batch(
      all_tensors,
      capacity=batch_queue_capacity,
      batch_size=batch_size,
      dynamic_pad=True,
      num_threads=num_batch_queue_threads)

  self._queue = prefetcher.prefetch(batched_tensors,
                                    prefetch_queue_capacity)
  self._static_shapes = static_shapes
  self._batch_size = batch_size
def test_prefetch_tensors_with_partially_defined_shapes(self):
  """Checks prefetching of tensors whose static shapes are partly unknown.

  The dequeued tensors must keep the partially defined static shapes, the
  evaluated results must have the concrete runtime shapes, and the queue must
  raise OutOfRangeError once `num_batches` elements have been consumed.
  """
  with self.test_session() as sess:
    batch_size = 10
    image_size = 32
    num_batches = 5
    # The counter limits how many elements can be produced before the
    # pipeline signals OutOfRangeError (asserted at the end of the test).
    examples = tf.Variable(tf.constant(0, dtype=tf.int64))
    counter = examples.count_up_to(num_batches)
    # Spatial dimensions come from variables and are then explicitly marked
    # unknown in the static shape.
    image = tf.random_normal([
        batch_size,
        tf.Variable(image_size),
        tf.Variable(image_size), 3
    ], dtype=tf.float32, name='image')
    image.set_shape([batch_size, None, None, 3])
    label = tf.random_uniform([batch_size, tf.Variable(1)], 0, 10,
                              dtype=tf.int32, name='label')
    label.set_shape([batch_size, None])

    prefetch_queue = prefetcher.prefetch(tensor_dict={
        'counter': counter,
        'image': image,
        'label': label
    }, capacity=100)
    tensor_dict = prefetch_queue.dequeue()

    # Static shapes must survive the round trip through the queue.
    self.assertAllEqual(tensor_dict['image'].get_shape().as_list(),
                        [batch_size, None, None, 3])
    self.assertAllEqual(tensor_dict['label'].get_shape().as_list(),
                        [batch_size, None])

    # `tf.initialize_all_variables` is deprecated in favour of
    # `tf.global_variables_initializer`.
    tf.global_variables_initializer().run()
    with slim.queues.QueueRunners(sess):
      for _ in range(num_batches):
        results = sess.run(tensor_dict)
        # `assertEquals` is a deprecated alias removed in Python 3.12.
        self.assertEqual(results['image'].shape,
                         (batch_size, image_size, image_size, 3))
        self.assertEqual(results['label'].shape, (batch_size, 1))
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(tensor_dict)
def _extract_predictions_and_losses(model,
                                    create_input_dict_fn,
                                    ignore_groundtruth=False):
  """Builds the single-example detection graph and returns its outputs.

  Args:
    model: model to perform predictions with. Its `preprocess`, `predict`,
      `postprocess`, `provide_groundtruth` and `loss` methods and its
      `num_classes` attribute are used.
    create_input_dict_fn: function to create input tensor dictionaries.
    ignore_groundtruth: whether groundtruth should be ignored. When true, no
      groundtruth is handed to the model and no losses are computed.

  Returns:
    result_dict: the dictionary produced by
      `eval_util.result_dict_for_single_example` from the original image, the
      source id, the postprocessed detections and (unless ignored) the
      groundtruth tensors.
    losses_dict: A dictionary containing detection losses. This is empty when
      ignore_groundtruth is true.
  """
  input_fields = fields.InputDataFields
  detection_fields = fields.DetectionResultFields

  prefetch_queue = prefetcher.prefetch(create_input_dict_fn(), capacity=500)
  input_dict = prefetch_queue.dequeue()

  # Add a leading dimension of size one so the model sees a batch.
  original_image = tf.expand_dims(input_dict[input_fields.image], 0)
  float_image = tf.to_float(original_image)
  preprocessed_image, true_image_shapes = model.preprocess(float_image)
  prediction_dict = model.predict(preprocessed_image, true_image_shapes)
  detections = model.postprocess(prediction_dict, true_image_shapes)

  groundtruth = None
  losses_dict = {}
  if not ignore_groundtruth:
    # Entries that are always copied from the input dictionary.
    mandatory_keys = (
        input_fields.groundtruth_boxes,
        input_fields.groundtruth_classes,
        input_fields.groundtruth_area,
        input_fields.groundtruth_is_crowd,
        input_fields.groundtruth_difficult,
    )
    groundtruth = {key: input_dict[key] for key in mandatory_keys}

    group_of_key = input_fields.groundtruth_group_of
    if group_of_key in input_dict:
      groundtruth[group_of_key] = input_dict[group_of_key]

    # Masks and keypoints are only forwarded when the model emits the
    # corresponding detections.
    masks_list = None
    if detection_fields.detection_masks in detections:
      masks_key = input_fields.groundtruth_instance_masks
      groundtruth[masks_key] = input_dict[masks_key]
      masks_list = [input_dict[masks_key]]

    keypoints_list = None
    if detection_fields.detection_keypoints in detections:
      keypoints_key = input_fields.groundtruth_keypoints
      groundtruth[keypoints_key] = input_dict[keypoints_key]
      keypoints_list = [input_dict[keypoints_key]]

    # Class labels are shifted down by this offset before one-hot encoding.
    label_id_offset = 1
    boxes_list = [input_dict[input_fields.groundtruth_boxes]]
    shifted_classes = (
        input_dict[input_fields.groundtruth_classes] - label_id_offset)
    one_hot_classes_list = [
        tf.one_hot(shifted_classes, depth=model.num_classes)
    ]
    model.provide_groundtruth(boxes_list, one_hot_classes_list, masks_list,
                              keypoints_list)
    losses_dict.update(model.loss(prediction_dict, true_image_shapes))

  source_id = input_dict[input_fields.source_id]
  # Detections without a class field are treated as class agnostic.
  is_class_agnostic = detection_fields.detection_classes not in detections
  result_dict = eval_util.result_dict_for_single_example(
      original_image,
      source_id,
      detections,
      groundtruth,
      class_agnostic=is_class_agnostic,
      scale_to_absolute=True)
  return result_dict, losses_dict