def get_enqueue_ops_fn():
  """Generate the enqueue ops graph function."""
  with tf.device(utils.device_for_host(self._get_host(0))):
    dataset = input_fn(params)
    iterator = dataset.make_initializable_iterator()
    self.dataset_initializer.append(iterator.initializer)

    def enqueue_ops_fn():
      """Enqueue ops function for one host."""
      per_host_sharded_inputs = []
      control_deps = []
      for _ in range(FLAGS.num_shards_per_host):
        with tf.control_dependencies(control_deps):
          # Chain each get_next() on the tensors produced for the previous
          # shard so the per-shard dataset reads happen in order.
          features = iterator.get_next()
        self.feature_structure["features"] = features
        flattened_inputs = data_nest.flatten(self.feature_structure)
        control_deps.extend(flattened_inputs)
        per_host_sharded_inputs.append(flattened_inputs)

      # A single InfeedQueue feeds all shards on this host.
      infeed = tpu.InfeedQueue(
          number_of_tuple_elements=len(per_host_sharded_inputs[0]))
      self.infeed_queue.append(infeed)
      return infeed.generate_enqueue_ops(
          per_host_sharded_inputs, tpu_ordinal_function=utils.tpu_ordinal_fn)

    return enqueue_ops_fn
def create_dequeue_ops(host_id):
  """Create outfeed dequeue ops."""
  dequeue_ops = []
  tensor_dtypes = []
  tensor_shapes = []
  for v in self.outfeed_tensors:
    tensor_dtypes.append(v.dtype)
    tensor_shapes.append(v.shape)
  with tf.device(utils.device_for_host(self._get_host(host_id))):
    for i in range(FLAGS.num_shards_per_host):
      outfeed = tpu.outfeed_dequeue_tuple(
          dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=i)
      if len(outfeed) == 2:
        # Two outfeed tensors: detections ([batch, boxes, fields], rank 3)
        # and is_pad ([batch]); their order is not fixed, so use the rank to
        # tell them apart.
        if outfeed[0].shape.ndims == 3:
          detections, is_pad = outfeed
        else:
          is_pad, detections = outfeed
        # Keep only the non-padded examples.
        num_non_pad = tf.shape(is_pad)[0] - tf.reduce_sum(
            tf.cast(is_pad, tf.int32))
        dequeue_ops.append(
            tf.slice(detections, [0, 0, 0], [num_non_pad, -1, -1]))
      else:
        dequeue_ops.append(outfeed)
    dequeue_ops = tf.concat(dequeue_ops, axis=0)
  return dequeue_ops
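# Illustrative sketch (not part of the original runner) of the is_pad /
# detections handling above: padded examples are counted and dropped with a
# dynamic tf.slice. Shapes and values here are made up for the example.
import tensorflow as tf

detections = tf.reshape(tf.range(4 * 2 * 7, dtype=tf.float32), [4, 2, 7])
is_pad = tf.constant([False, False, True, True])  # last two examples are padding
num_non_pad = tf.shape(is_pad)[0] - tf.reduce_sum(tf.cast(is_pad, tf.int32))
valid = tf.slice(detections, [0, 0, 0], [num_non_pad, -1, -1])
# valid has shape [2, 2, 7]: only the non-padded examples remain.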
def get_eval_enqueue_ops_fn(host_id):
  """Generate the eval enqueue ops graph function."""
  params["dataset_num_shards"] = self.num_hosts
  params["dataset_index"] = host_id
  with tf.device(utils.device_for_host(self._get_host(host_id))):
    dataset = input_fn(params)
    iterator = dataset.make_initializable_iterator()
    self.eval_dataset_initializer.append(iterator.initializer)

    def eval_enqueue_ops_fn():
      """Enqueue ops function for one host."""
      per_host_sharded_inputs = []
      control_deps = []
      for _ in range(self.replicas_per_worker):
        with tf.control_dependencies(control_deps):
          features = iterator.get_next()
        if self.use_spatial_partition:
          self.eval_input_dims_flattener.validate_and_flatten_input_dims(
              features, None)
        self.eval_feature_structure["features"] = features
        flattened_inputs = data_nest.flatten(self.eval_feature_structure)
        control_deps.extend(flattened_inputs)
        per_host_sharded_inputs.append(flattened_inputs)

      if self.use_spatial_partition:
        # With spatial partitioning, each input is split across cores, so a
        # partitioned infeed queue is required.
        flattened_input_dims = (
            self.eval_input_dims_flattener.flattened_input_dims)
        # pylint: disable=protected-access
        infeed = tpu_feed._PartitionedInfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]),
            host_id=host_id,
            input_partition_dims=flattened_input_dims,
            device_assignment=self.device_assignment)
        self.eval_infeed_queue.append(infeed)
        return infeed.generate_enqueue_ops(per_host_sharded_inputs)

      infeed = tpu_feed.InfeedQueue(
          number_of_tuple_elements=len(per_host_sharded_inputs[0]))
      self.eval_infeed_queue.append(infeed)
      return infeed.generate_enqueue_ops(
          per_host_sharded_inputs, tpu_ordinal_function=utils.tpu_ordinal_fn)

    return eval_enqueue_ops_fn
def create_dequeue_ops():
  """Create outfeed dequeue ops."""
  dequeue_ops = []
  tensor_dtypes = []
  tensor_shapes = []
  for v in self.outfeed_tensors:
    dequeue_ops.append([])
    tensor_dtypes.append(v.dtype)
    tensor_shapes.append(v.shape)
  for i in range(FLAGS.num_shards):
    with tf.device(utils.device_for_host(self._get_host(0))):
      outfeed_tensors = tpu.outfeed_dequeue_tuple(
          dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=i)
      # dequeue_ops[j] collects the j-th outfeed tensor from every shard.
      for j, item in enumerate(outfeed_tensors):
        dequeue_ops[j].append(item)
  # Concatenate the per-shard pieces of each output along the batch dimension.
  for j in range(len(outfeed_tensors)):
    dequeue_ops[j] = tf.concat(dequeue_ops[j], axis=0)
  return dequeue_ops
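# Minimal sketch (synthetic shapes, not from the original code) of the gather
# pattern above: one list per outfeed tensor, with the per-shard pieces
# concatenated along the batch dimension at the end.
import tensorflow as tf

num_shards = 4
pieces = [[], []]  # two outfeed tensors, e.g. detections and image ids
for _ in range(num_shards):
  pieces[0].append(tf.zeros([8, 200, 7]))          # per-shard detections
  pieces[1].append(tf.zeros([8], dtype=tf.int32))  # per-shard image ids
gathered = [tf.concat(p, axis=0) for p in pieces]
# gathered[0] has shape [32, 200, 7]; gathered[1] has shape [32].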
def create_dequeue_ops(host_id):
  """Create outfeed dequeue ops."""
  dequeue_ops = []
  tensor_dtypes = []
  tensor_shapes = []
  for v in self.outfeed_tensors:
    tensor_dtypes.append(v.dtype)
    tensor_shapes.append(v.shape)
  with tf.device(utils.device_for_host(self._get_host(host_id))):
    for i in range(self.replicas_per_worker):
      if self.use_spatial_partition:
        replica_id = self.device_assignment.lookup_replicas(host_id, 0)[i]
        ordinal = self.device_assignment.tpu_ordinal(
            replica=replica_id, logical_core=0)
      else:
        ordinal = i
      outfeed = tpu_ops.outfeed_dequeue_tuple(
          dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=ordinal)
      if len(outfeed) == 2:
        # 2 outfeed tensors
        #   is_pad: [batch]
        #   detections: [batch, 200, 7]
        if outfeed[0].shape.ndims == 3:
          detections, is_pad = outfeed
        else:
          is_pad, detections = outfeed
        num_non_pad = tf.shape(is_pad)[0] - tf.reduce_sum(
            tf.cast(is_pad, tf.int32))
        dequeue_ops.append(
            tf.slice(detections, [0, 0, 0], [num_non_pad, -1, -1]))
      else:
        # no padding, only detections are in the outfeed
        dequeue_ops.append(outfeed)
    dequeue_ops = tf.concat(dequeue_ops, axis=0)
  return dequeue_ops
def get_enqueue_ops_fn(host_id):
  """Generate the enqueue ops graph function."""
  params["dataset_num_shards"] = self.num_hosts
  params["dataset_index"] = host_id
  with tf.device(utils.device_for_host(self._get_host(host_id))):
    dataset = input_fn(params)
    iterator = dataset.make_initializable_iterator()
    self.dataset_initializer.append(iterator.initializer)

    def enqueue_ops_fn():
      """Enqueue ops function for one host."""
      per_host_sharded_inputs = []
      control_deps = []
      for _ in range(self.replicas_per_worker):
        with tf.control_dependencies(control_deps):
          features, labels = iterator.get_next()
        if self.use_spatial_partition:
          # Reshape the flat per-level ground-truth boxes and classes to
          # [batch, feature_size, feature_size, num_defaults, ...] so the
          # spatial dimensions can be partitioned across cores.
          num_elements = []
          for i, d in enumerate(ssd_constants.FEATURE_SIZES):
            num_elements.append(d * d * ssd_constants.NUM_DEFAULTS[i])
          gt_boxes = tf.split(labels[ssd_constants.BOXES], num_elements, 1)
          gt_classes = tf.split(
              labels[ssd_constants.CLASSES], num_elements, 1)

          def transpose_gt_box(gt_box, i):
            return tf.transpose(
                tf.reshape(gt_box, [
                    -1, ssd_constants.NUM_DEFAULTS[i],
                    ssd_constants.FEATURE_SIZES[i],
                    ssd_constants.FEATURE_SIZES[i], 4
                ]), [0, 2, 3, 1, 4])

          def transpose_gt_class(gt_class, i):
            return tf.transpose(
                tf.reshape(gt_class, [
                    -1, ssd_constants.NUM_DEFAULTS[i],
                    ssd_constants.FEATURE_SIZES[i],
                    ssd_constants.FEATURE_SIZES[i]
                ]), [0, 2, 3, 1])

          labels[ssd_constants.BOXES] = {
              i: transpose_gt_box(gt_boxes[i], i)
              for i in range(len(ssd_constants.NUM_DEFAULTS))
          }
          labels[ssd_constants.CLASSES] = {
              i: transpose_gt_class(gt_classes[i], i)
              for i in range(len(ssd_constants.NUM_DEFAULTS))
          }
        self.feature_structure["features"] = features
        self.feature_structure["labels"] = labels
        flattened_inputs = data_nest.flatten(self.feature_structure)
        control_deps.extend(flattened_inputs)
        per_host_sharded_inputs.append(flattened_inputs)

      if self.use_spatial_partition:
        # Tensors matching the features shape get the configured partition
        # dims; other tensors get the flag dims padded with 1s to their rank,
        # and lower-rank tensors are left unpartitioned.
        flattened_input_dims = []
        for i in per_host_sharded_inputs[0]:
          if i.shape.ndims >= len(self.input_partition_dims[0]):
            if i.shape.as_list() == self.feature_structure[
                "features"].shape.as_list():
              flattened_input_dims.append(self.input_partition_dims[0])
            else:
              flattened_input_dims.append(
                  FLAGS.input_partition_dims + [1] *
                  (i.shape.ndims - len(self.input_partition_dims[0])))
          else:
            flattened_input_dims.append([1] * i.shape.ndims)
        # pylint: disable=protected-access
        infeed = tpu_feed._PartitionedInfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]),
            host_id=host_id,
            input_partition_dims=flattened_input_dims,
            device_assignment=self.device_assignment)
        self.infeed_queue.append(infeed)
        return infeed.generate_enqueue_ops(per_host_sharded_inputs)

      infeed = tpu.InfeedQueue(
          number_of_tuple_elements=len(per_host_sharded_inputs[0]))
      self.infeed_queue.append(infeed)
      return infeed.generate_enqueue_ops(
          per_host_sharded_inputs, tpu_ordinal_function=utils.tpu_ordinal_fn)

    return enqueue_ops_fn
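# Minimal sketch (assumed sizes: feature_size=38, num_defaults=4, roughly the
# first SSD feature level) of the per-level reshape/transpose above. Boxes
# stored as [batch, feature_size**2 * num_defaults, 4] become
# [batch, feature_size, feature_size, num_defaults, 4], so the two spatial
# axes can be split across cores by the partitioned infeed.
import tensorflow as tf

batch, feature_size, num_defaults = 2, 38, 4
gt_box = tf.zeros([batch, feature_size * feature_size * num_defaults, 4])
reshaped = tf.reshape(
    gt_box, [-1, num_defaults, feature_size, feature_size, 4])
transposed = tf.transpose(reshaped, [0, 2, 3, 1, 4])
# transposed has shape [2, 38, 38, 4, 4].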