def enqueue_ops_fn(idx):
  """Enqueue ops function for one host."""
  with tf.device(device):
    sharded_inputs = []
    start_idx = 0
    # Only even-numbered hosts below 2 * num_infeed_workers feed real data.
    if host_id in range(0, self.hparams.num_infeed_workers * 2, 2):
      core_id = tf.constant(
          host_id * self.hparams.num_shards_per_host,
          shape=[1],
          dtype=tf.int32)
      if self.hparams.use_synthetic_data:
        features = output
      else:

        def true_fn():
          return iterator.get_next()

        def false_fn():
          return {
              k: tf.zeros_like(self.feature_structure["features"][k])
              for k in self.feature_structure["features"]
          }

        # Round-robin real batches across infeed workers; on all other
        # steps this host enqueues zero-filled dummies.
        features = tf.cond(
            tf.equal(idx % self.hparams.num_infeed_workers, host_id // 2),
            true_fn, false_fn)
      sharded_inputs.append(
          data_nest.flatten({
              "features": features,
              "core_id": core_id
          }))
      start_idx = 1
    # The remaining cores on this host always receive zero-filled dummies,
    # tagged with their core_id so the TPU computation can tell them apart.
    for i in range(start_idx, self.hparams.num_shards_per_host):
      sharded_inputs.append(
          data_nest.flatten({
              "features": {
                  k: tf.zeros_like(self.feature_structure["features"][k])
                  for k in self.feature_structure["features"]
              },
              "core_id":
                  tf.constant(
                      host_id * self.hparams.num_shards_per_host + i,
                      shape=[1],
                      dtype=tf.int32)
          }))
    infeed = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(sharded_inputs[0]))
    self.infeed_queue.append(infeed)

    def tpu_ordinal_fn(shard_index_in_host):
      return shard_index_in_host % self.hparams.num_shards_per_host

    return infeed.generate_enqueue_ops(
        sharded_inputs, tpu_ordinal_function=tpu_ordinal_fn)
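# A minimal sketch (plain Python, no TF needed; all names illustrative) of
# the schedule the tf.cond above implements: only even-numbered hosts below
# 2 * num_infeed_workers own an iterator, and host 2 * w pulls a real batch
# exactly when idx % num_infeed_workers == w; every other (host, step) pair
# enqueues zero-filled dummies so all cores still receive a tuple.
num_infeed_workers = 2
num_hosts = 4

for idx in range(4):
  for host_id in range(num_hosts):
    is_infeed_host = host_id in range(0, num_infeed_workers * 2, 2)
    real_batch = is_infeed_host and idx % num_infeed_workers == host_id // 2
    print("idx=%d host=%d real_batch=%s" % (idx, host_id, real_batch))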
def CreatePerHostEnqueueOp(self, task_id):
  """Create TPU infeed enqueue tuple ops specific to a host.

  Note it's important this is a function to give the caller flexibility
  as to which context it is instantiated in. This is critical for
  tf.while_loop driven infeed for instance.

  Args:
    task_id: Which infeed host it is.

  Returns:
    List of enqueue ops for this host.
  """
  p = self.params
  # Only support this in MLPerf.
  assert p.use_per_host_infeed
  host_device = '/task:{}/device:CPU:0'.format(task_id)
  self._batch = self.GetPreprocessedInputBatch(task_id=task_id)
  if isinstance(self._batch, py_utils.NestedMap):
    # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
    # Note that when MultiTaskData is used, bucket_keys will be at the
    # second level of the dictionary.
    self._batch = self._batch.FilterKeyVal(
        lambda k, _: not k.endswith('bucket_keys'))
  tf.logging.info('host_device: %s, batch: %r', host_device, self._batch)

  for k, x in self._batch.FlattenItems():
    assert x.shape.is_fully_defined(), (
        'Shape must be fully defined: %s: %s' % (k, x))
    # TODO(cwhipkey): if it's a string (or other type not supported on
    # TPU), drop it from feeding and on the other end add in an op that
    # fails if used.
  shapes = self._batch.Transform(lambda x: x.shape).Flatten()
  dtypes = self._batch.Transform(lambda x: x.dtype).Flatten()

  q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
  assert self.shards is not None
  q.set_number_of_shards(self.shards)
  self._tpu_queues.append(q)

  def TPUOrdinalFunction(shard_index_in_host):
    device_assignment = py_utils.GetTpuDeviceAssignment()
    if device_assignment:
      # We put both enqueue/dequeue ops at core 0 in each replica.
      replica = device_assignment.lookup_replicas(
          task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
      return device_assignment.tpu_ordinal(replica=replica)
    else:
      return shard_index_in_host

  tf.logging.info('CreatePerHostEnqueueOp: %d', task_id)
  input_ops = q.split_inputs_and_generate_enqueue_ops(
      self._batch.Flatten(),
      placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
      tpu_ordinal_function=TPUOrdinalFunction)
  return input_ops
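# A minimal sketch (TF1 graph mode; `input_gen`, `num_steps`, and this
# wrapper are hypothetical, not part of the method above) of the
# tf.while_loop driven infeed the docstring refers to: because
# CreatePerHostEnqueueOp is a plain function, it can be traced inside the
# loop body, so the enqueue subgraph is instantiated in the loop's context.
import tensorflow.compat.v1 as tf


def build_infeed_loop(input_gen, task_id, num_steps):
  def cond(i):
    return i < num_steps

  def body(i):
    enqueue_ops = input_gen.CreatePerHostEnqueueOp(task_id)
    with tf.control_dependencies(enqueue_ops):
      return i + 1

  return tf.while_loop(cond, body, [tf.constant(0)], parallel_iterations=1)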
def enqueue_ops_fn():
  """Enqueue ops function for one host."""
  per_host_sharded_inputs = []
  control_deps = []
  for _ in range(self.hparams.num_shards_per_host):
    with tf.control_dependencies(control_deps):
      features = iterator.get_next()
    self.eval_feature_structure["features"] = features
    flattened_inputs = data_nest.flatten(self.eval_feature_structure)
    control_deps.extend(flattened_inputs)
    per_host_sharded_inputs.append(flattened_inputs)

  infeed = tpu_feed.InfeedQueue(
      number_of_tuple_elements=len(per_host_sharded_inputs[0]))
  self.eval_infeed_queue.append(infeed)

  def tpu_ordinal_fn(shard_index_in_host):
    return shard_index_in_host % self.hparams.num_shards_per_host

  return infeed.generate_enqueue_ops(
      per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_fn)
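# A minimal sketch (hypothetical helper; `sess`, `enqueue_ops`, and
# `num_steps` are illustrative) of how the ops returned by enqueue_ops_fn
# are typically driven: a dedicated thread repeatedly runs the enqueue ops
# so host-side input processing overlaps with TPU execution of the
# dequeued batches.
import threading


def start_infeed_thread(sess, enqueue_ops, num_steps):
  def _loop():
    for _ in range(num_steps):
      sess.run(enqueue_ops)

  thread = threading.Thread(target=_loop)
  thread.start()
  return thread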
def enqueue_ops_fn(idx):
  """Generate the infeed enqueue ops graph."""
  per_host_sharded_inputs = []
  control_deps = []
  for _ in range(FLAGS.replicas_per_host):
    with tf.control_dependencies(control_deps):
      self.feature_structure[is_training] = iterator.get_next()
    self.maybe_capture_embedding_inputs(
        self.feature_structure[is_training], is_training)
    flattened_inputs = tf.nest.flatten(self.feature_structure[is_training])
    control_deps.extend(flattened_inputs)
    if input_partition_dims:
      padded_inputs = []
      for inp in flattened_inputs:
        if inp.shape.ndims < len(input_partition_dims):
          padded_inputs.append(inp)
          continue
        paddings = []
        for i, j in enumerate(input_partition_dims):
          r = inp.shape.as_list()[i] % j
          if r > 0:
            paddings.append([0, j - r])
          else:
            paddings.append([0, 0])
        for i in range(inp.shape.ndims - len(input_partition_dims)):
          paddings.append([0, 0])
        padded_inputs.append(tf.pad(inp, paddings))
      per_host_sharded_inputs.append(padded_inputs)
    else:
      per_host_sharded_inputs.append(flattened_inputs)

  if input_partition_dims:
    flattened_input_dims = []
    for i in per_host_sharded_inputs[0]:
      if i.shape.ndims == len(input_partition_dims):
        flattened_input_dims.append(input_partition_dims)
      elif i.shape.ndims > len(input_partition_dims):
        flattened_input_dims.append(
            input_partition_dims + [1] *
            (i.shape.ndims - len(input_partition_dims)))
      else:
        flattened_input_dims.append([1] * i.shape.ndims)
    # pylint: disable=protected-access
    self.infeed_op[is_training] = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=len(per_host_sharded_inputs[0]),
        host_id=host_id,
        input_partition_dims=flattened_input_dims,
        device_assignment=self.device_assignment)
    with tf.control_dependencies(
        self.infeed_op[is_training].generate_enqueue_ops(
            per_host_sharded_inputs)):
      return idx + 1
  else:
    self.infeed_op[is_training] = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(per_host_sharded_inputs[0]))
    per_host_enqueue_ops = (
        self.infeed_op[is_training].generate_enqueue_ops(
            per_host_sharded_inputs,
            tpu_ordinal_function=_tpu_ordinal_fn))
    self.maybe_add_embedding_enqueue_ops_int(
        is_training, per_host_enqueue_ops)
    with tf.control_dependencies(per_host_enqueue_ops):
      return idx + 1
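# A minimal sketch (illustrative shapes and names) of the padding rule used
# above before handing tensors to _PartitionedInfeedQueue: each of the first
# len(input_partition_dims) dimensions is rounded up to the next multiple of
# its partition factor, and trailing dimensions are left untouched.
import tensorflow.compat.v1 as tf

inp = tf.zeros([5, 7, 3])       # hypothetical per-host tensor
input_partition_dims = [2, 4]   # split dim 0 across 2 cores, dim 1 across 4

paddings = []
for i, j in enumerate(input_partition_dims):
  r = inp.shape.as_list()[i] % j
  paddings.append([0, j - r] if r > 0 else [0, 0])
for _ in range(inp.shape.ndims - len(input_partition_dims)):
  paddings.append([0, 0])

padded = tf.pad(inp, paddings)  # shape becomes [6, 8, 3]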