def eval_enqueue_ops_fn():
  """Enqueue ops function for one host."""
  per_host_sharded_inputs = []
  control_deps = []
  for _ in range(self.train_params["replicas_per_worker"]):
    # Serialize the iterator reads so each replica's get_next() runs after
    # the previous replica's inputs have been materialized.
    with tf.control_dependencies(control_deps):
      features = iterator.get_next()
    if self.use_spatial_partition:
      self.eval_input_dims_flattener.validate_and_flatten_input_dims(
          features, None)
    flattened_inputs = (
        self.eval_input_flattener.flatten_features_and_labels(
            features, None))
    control_deps.extend(flattened_inputs)
    per_host_sharded_inputs.append(flattened_inputs)

  if self.use_spatial_partition:
    # Spatial partitioning needs per-tensor partition dims, so use the
    # partitioned infeed queue.
    flattened_input_dims = (
        self.eval_input_dims_flattener.flattened_input_dims)
    # pylint: disable=protected-access
    infeed = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=len(per_host_sharded_inputs[0]),
        host_id=host_id,
        input_partition_dims=flattened_input_dims,
        device_assignment=self.device_assignment)
    self.eval_infeed_queue.append(infeed)
    return infeed.generate_enqueue_ops(per_host_sharded_inputs)

  infeed = tf.contrib.tpu.InfeedQueue(
      number_of_tuple_elements=len(per_host_sharded_inputs[0]))
  self.eval_infeed_queue.append(infeed)
  return infeed.generate_enqueue_ops(
      per_host_sharded_inputs,
      tpu_ordinal_function=functools.partial(
          runner_utils.tpu_ordinal_fn,
          replicas_per_worker=self.train_params["replicas_per_worker"]))
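# --- Hedged usage sketch, not part of the original file. It shows how a
# per-host enqueue fn like the one above is typically wired up: one iterator
# and one group of enqueue ops per host, pinned to that host's CPU. The names
# `build_per_host_enqueue_ops`, `num_hosts`, `host_device_fn`, and
# `make_enqueue_fn` are hypothetical placeholders, not identifiers from this
# codebase.
def build_per_host_enqueue_ops(num_hosts, host_device_fn, make_enqueue_fn):
  """Returns a flat list of infeed enqueue ops covering all hosts."""
  enqueue_ops = []
  for host_id in range(num_hosts):
    # Pin the input pipeline to the host CPU so dataset ops never land on
    # a TPU core.
    with tf.device(host_device_fn(host_id)):
      enqueue_ops.extend(make_enqueue_fn(host_id))
  return enqueue_ops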
def enqueue_ops_fn():
  """Enqueue ops function for one host."""
  per_host_sharded_inputs = []
  control_deps = []
  for _ in range(self.replicas_per_worker):
    # Serialize the iterator reads across replicas.
    with tf.control_dependencies(control_deps):
      features, labels = iterator.get_next()
    if self.use_spatial_partition:
      num_elements = []
      for i, d in enumerate(ssd_constants.FEATURE_SIZES):
        num_elements.append(d * d * ssd_constants.NUM_DEFAULTS[i])
      gt_boxes = tf.split(labels[ssd_constants.BOXES], num_elements, 1)
      gt_classes = tf.split(
          labels[ssd_constants.CLASSES], num_elements, 1)

      # Reshape the ground truth so each feature map gets explicit H/W
      # dimensions, making the labels spatially partitionable alongside
      # the image features.
      def transpose_gt_box(gt_box, i):
        return tf.transpose(
            tf.reshape(gt_box, [
                -1, ssd_constants.NUM_DEFAULTS[i],
                ssd_constants.FEATURE_SIZES[i],
                ssd_constants.FEATURE_SIZES[i], 4
            ]), [0, 2, 3, 1, 4])

      def transpose_gt_class(gt_class, i):
        return tf.transpose(
            tf.reshape(gt_class, [
                -1, ssd_constants.NUM_DEFAULTS[i],
                ssd_constants.FEATURE_SIZES[i],
                ssd_constants.FEATURE_SIZES[i]
            ]), [0, 2, 3, 1])

      labels[ssd_constants.BOXES] = {
          i: transpose_gt_box(gt_boxes[i], i)
          for i in range(len(ssd_constants.NUM_DEFAULTS))
      }
      labels[ssd_constants.CLASSES] = {
          i: transpose_gt_class(gt_classes[i], i)
          for i in range(len(ssd_constants.NUM_DEFAULTS))
      }
    self.feature_structure["features"] = features
    self.feature_structure["labels"] = labels
    flattened_inputs = data_nest.flatten(self.feature_structure)
    control_deps.extend(flattened_inputs)
    per_host_sharded_inputs.append(flattened_inputs)

  if self.use_spatial_partition:
    # Derive partition dims for every flattened tensor: the image features
    # get the full partition spec, other tensors of sufficient rank get the
    # flag spec padded with 1s, and low-rank tensors are left unpartitioned.
    flattened_input_dims = []
    for i in per_host_sharded_inputs[0]:
      if i.shape.ndims >= len(self.input_partition_dims[0]):
        if i.shape.as_list() == self.feature_structure[
            "features"].shape.as_list():
          flattened_input_dims.append(self.input_partition_dims[0])
        else:
          flattened_input_dims.append(
              FLAGS.input_partition_dims + [1] *
              (i.shape.ndims - len(self.input_partition_dims[0])))
      else:
        flattened_input_dims.append([1] * i.shape.ndims)
    # pylint: disable=protected-access
    infeed = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=len(per_host_sharded_inputs[0]),
        host_id=host_id,
        input_partition_dims=flattened_input_dims,
        device_assignment=self.device_assignment)
    self.infeed_queue.append(infeed)
    return infeed.generate_enqueue_ops(per_host_sharded_inputs)

  infeed = tpu.InfeedQueue(
      number_of_tuple_elements=len(per_host_sharded_inputs[0]))
  self.infeed_queue.append(infeed)
  return infeed.generate_enqueue_ops(
      per_host_sharded_inputs, tpu_ordinal_function=utils.tpu_ordinal_fn)
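# --- Hedged sketch, not part of the original file. `utils.tpu_ordinal_fn`
# (and the `runner_utils.tpu_ordinal_fn` used by the eval path, which binds
# `replicas_per_worker` via functools.partial) are defined elsewhere in the
# repo; this is an assumed minimal equivalent showing the standard mapping
# from a host-local shard index to a TPU core ordinal, not their verbatim
# source.
def tpu_ordinal_fn(shard_index_in_host, replicas_per_worker):
  """Maps a host-local shard index to the TPU core ordinal on that host."""
  # Each worker owns `replicas_per_worker` cores, so the per-host shard
  # index simply wraps around the local core count.
  return shard_index_in_host % replicas_per_worker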