def test_var_args_and_defaults(self):
  """Tests that arg checker works for a function with varargs and defaults."""

  def func(x, y, z=17, *q):  # pylint: disable=keyword-arg-before-vararg
    return x + y + z + len(q)

  self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
  self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
  self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
  self.assertEqual(None, xla.check_function_argument_count(func, 5, None))
  self.assertEqual('at least 2 arguments',
                   xla.check_function_argument_count(func, 1, None))
  queue = tpu_feed.InfeedQueue(1)
  self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
  self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
  self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
  self.assertEqual(None, xla.check_function_argument_count(func, 4, queue))
  self.assertEqual('at least 2 arguments',
                   xla.check_function_argument_count(func, 0, queue))
def testModification(self):
  """Tests modification of the queue post-construction."""
  i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
  i.set_tuple_types([dtypes.float32, dtypes.int32])
  self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
  i.set_tuple_types([dtypes.float32, dtypes.float32])
  self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.float32])
  with self.assertRaises(ValueError):
    i.set_tuple_types([dtypes.float32])
  i.set_tuple_shapes([[1], [2, 3]])
  self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
  i.set_tuple_shapes([[1, 2], [3, 4]])
  self.assertEqual(i.tuple_shapes, [[1, 2], [3, 4]])
  with self.assertRaises(ValueError):
    i.set_tuple_shapes([[1, 2]])
  i.set_number_of_shards(2)
  self.assertEqual(i.number_of_shards, 2)
  i.set_number_of_shards(3)
  self.assertEqual(i.number_of_shards, 3)
  t1 = constant_op.constant(1, dtypes.int32, shape=[6])
  t2 = constant_op.constant(2.0, dtypes.float32, shape=[3, 18])
  i.set_configuration_from_input_tensors([t1, t2])
  self.assertEqual(i.tuple_shapes, [[6], [3, 18]])
  self.assertEqual(i.tuple_types, [dtypes.int32, dtypes.float32])
  i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
  self.assertEqual(i.number_of_shards, 2)
  self.assertEqual(i.tuple_shapes, [[6, 18], [12]])
  self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
  i.set_shard_dimensions([1, 0])
  i.set_number_of_shards(3)
  with self.assertRaises(ValueError):
    i.set_number_of_shards(4)
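# Note on the sharded-shape arithmetic above (added commentary, not original
# test code): set_configuration_from_sharded_input_tensors concatenates shard
# shapes along the shard dimension. With 2 shards and the default shard
# dimension 0, per-shard t2 of shape [3, 18] implies an unsharded tuple shape
# of [6, 18], and per-shard t1 of shape [6] implies [12], which is exactly
# what the assertions above check.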
def testUsingInfeedQueueWithRegularizer(self):
  """Test that Layer regularizers can reference data created in loops."""
  with ops.Graph().as_default():

    def make_regularizer(scale):
      def regularizer(inputs):
        return scale * math_ops.reduce_sum(math_ops.square(inputs))
      return regularizer

    def training_step(inputs, scale):
      outputs = convolutional.conv2d(
          inputs,
          filters=16,
          kernel_size=(3, 3),
          data_format="channels_first",
          kernel_regularizer=make_regularizer(scale))
      loss = math_ops.reduce_mean(math_ops.square(outputs))
      return loss.op

    inputs = array_ops.zeros(shape=(128, 32, 32, 16))
    scale = array_ops.ones(shape=())
    infeed = tpu_feed.InfeedQueue(
        tuple_types=[dtypes.float32, dtypes.float32],
        tuple_shapes=[inputs.shape, scale.shape])

    def loop():
      return training_loop.repeat(5, training_step, infeed_queue=infeed)

    # This should not throw an error.
    tpu.rewrite(loop)
def testFreezing(self):
  """Tests freezing the queue."""
  i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
  t1 = constant_op.constant(1, dtypes.int32, shape=[2])
  t2 = constant_op.constant(2.0, dtypes.float32, shape=[2, 4])
  i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
  self.assertEqual(i.number_of_shards, 2)
  self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
  self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
  self.assertEqual(i.shard_dimensions, [0, 0])
  i.freeze()
  # Re-setting the existing values after freezing is allowed...
  i.set_number_of_shards(2)
  i.set_tuple_shapes([[4, 4], [4]])
  i.set_tuple_types([dtypes.float32, dtypes.int32])
  i.set_shard_dimensions([0, 0])
  # ...but changing any of them raises.
  with self.assertRaises(ValueError):
    i.set_number_of_shards(1)
  with self.assertRaises(ValueError):
    i.set_tuple_shapes([[8, 8], [8]])
  with self.assertRaises(ValueError):
    i.set_tuple_types([dtypes.int32, dtypes.float32])
  with self.assertRaises(ValueError):
    i.set_shard_dimensions([1, 0])
  self.assertEqual(i.number_of_shards, 2)
  self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
  self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
  self.assertEqual(i.shard_dimensions, [0, 0])
def enqueue_ops_fn(idx):
  """Enqueue ops function for one host."""
  with tf.device(device):
    sharded_inputs = []
    start_idx = 0
    # Even-numbered hosts below 2 * num_infeed_workers carry a real infeed
    # worker; every other core on the host receives zero tensors.
    if host_id in range(0, self.hparams.num_infeed_workers * 2, 2):
      core_id = tf.constant(
          host_id * self.hparams.num_shards_per_host,
          shape=[1],
          dtype=tf.int32)
      if self.hparams.use_synthetic_data:
        features = output
      else:

        def true_fn():
          return iterator.get_next()

        def false_fn():
          return {
              k: tf.zeros_like(self.feature_structure["features"][k])
              for k in self.feature_structure["features"]
          }

        features = tf.cond(
            tf.equal(idx % self.hparams.num_infeed_workers, host_id // 2),
            true_fn, false_fn)
      sharded_inputs.append(
          data_nest.flatten({
              "features": features,
              "core_id": core_id
          }))
      start_idx = 1
    for i in range(start_idx, self.hparams.num_shards_per_host):
      sharded_inputs.append(
          data_nest.flatten({
              "features": {
                  k: tf.zeros_like(self.feature_structure["features"][k])
                  for k in self.feature_structure["features"]
              },
              "core_id":
                  tf.constant(
                      host_id * self.hparams.num_shards_per_host + i,
                      shape=[1],
                      dtype=tf.int32)
          }))
    infeed = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(sharded_inputs[0]))
    self.infeed_queue.append(infeed)

    def tpu_ordinal_fn(shard_index_in_host):
      return shard_index_in_host % self.hparams.num_shards_per_host

    return infeed.generate_enqueue_ops(
        sharded_inputs, tpu_ordinal_function=tpu_ordinal_fn)
def test_simple(self):
  """Tests that arg checker works for functions with no varargs or defaults.
  """

  def func(x, y, z):
    return x + y + z

  self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
  self.assertEqual('exactly 3 arguments',
                   xla.check_function_argument_count(func, 2, None))
  queue = tpu_feed.InfeedQueue(2)
  self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
  self.assertEqual('exactly 3 arguments',
                   xla.check_function_argument_count(func, 2, queue))
def testConstructor(self):
  """Tests that the constructor can be called with different arguments."""
  i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
  self.assertEqual(i.number_of_tuple_elements, 2)
  self.assertEqual(i.tuple_types, None)
  self.assertEqual(i.tuple_shapes, None)
  self.assertEqual(i.number_of_shards, None)
  i = tpu_feed.InfeedQueue(
      tuple_types=[dtypes.float32, dtypes.int32, dtypes.int32])
  self.assertEqual(i.number_of_tuple_elements, 3)
  self.assertEqual(i.tuple_types,
                   [dtypes.float32, dtypes.int32, dtypes.int32])
  self.assertEqual(i.tuple_shapes, None)
  self.assertEqual(i.number_of_shards, None)
  i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]])
  self.assertEqual(i.number_of_tuple_elements, 2)
  self.assertEqual(i.tuple_types, None)
  self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
  self.assertEqual(i.number_of_shards, None)
  i = tpu_feed.InfeedQueue(shard_dimensions=[1, 0, 7])
  self.assertEqual(i.number_of_tuple_elements, 3)
  self.assertEqual(i.tuple_types, None)
  self.assertEqual(i.tuple_shapes, None)
  self.assertEqual([p.shard_dimension for p in i.sharding_policies],
                   [1, 0, 7])
  # The constructor requires at least one argument, and all length-bearing
  # arguments must agree on the number of tuple elements.
  with self.assertRaises(ValueError):
    i = tpu_feed.InfeedQueue()
  with self.assertRaises(ValueError):
    i = tpu_feed.InfeedQueue(
        number_of_tuple_elements=2, tuple_types=[dtypes.float32])
  with self.assertRaises(ValueError):
    i = tpu_feed.InfeedQueue(number_of_tuple_elements=2, tuple_shapes=[[1]])
  with self.assertRaises(ValueError):
    i = tpu_feed.InfeedQueue(number_of_tuple_elements=2, shard_dimensions=[1])
  with self.assertRaises(ValueError):
    i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]], shard_dimensions=[1])
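# A minimal end-to-end sketch of the lifecycle the tests above exercise
# (illustrative only -- the helper name and shapes are made up, but every
# method call appears in the surrounding tests): configure the queue from
# per-shard tensors, build host-side enqueue ops, and dequeue the same tuple
# on the device.
def _example_infeed_round_trip():
  q = tpu_feed.InfeedQueue(number_of_tuple_elements=1)
  shard0 = constant_op.constant(0.0, dtypes.float32, shape=[4, 128])
  shard1 = constant_op.constant(0.0, dtypes.float32, shape=[4, 128])
  # Infer types, shapes, and shard count (2) from the per-shard tensors.
  q.set_configuration_from_sharded_input_tensors([[shard0], [shard1]])
  # Host side: one list of tensors per shard.
  enqueue_ops = q.generate_enqueue_ops([[shard0], [shard1]])
  # Device side: meant to run inside the TPU computation.
  dequeued = q.generate_dequeue_op()
  return enqueue_ops, dequeued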
def _enqueue_laidout_tensors(self, all_laidout_tensors):
  """Generate enqueue ops to enqueue all_laidout_tensors."""

  def _tpu_ordinal_function_impl(pnum):
    return self._p_dev.ordered_ordinals[pnum]

  def _placement_function_impl(pnum):
    return self._p_dev.ordered_hosts[pnum]

  laidout_tensors0 = all_laidout_tensors[0]
  infeed_queue = tpu_feed.InfeedQueue(
      number_of_tuple_elements=len(laidout_tensors0),
      tuple_types=[x.dtype for x in laidout_tensors0],
      tuple_shapes=[x.shape for x in laidout_tensors0])
  enqueue_ops = infeed_queue.generate_enqueue_ops(
      all_laidout_tensors,
      tpu_ordinal_function=_tpu_ordinal_function_impl,
      placement_function=_placement_function_impl)
  return infeed_queue, enqueue_ops
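# Hedged usage sketch for the pair returned above (the names `sess`,
# `laidout_tensors`, `computation`, and `num_steps` are placeholders, not
# from the original code): the host repeatedly runs the enqueue ops while the
# TPU program dequeues the same tuple through the queue.
#
#   infeed_queue, enqueue_ops = self._enqueue_laidout_tensors(laidout_tensors)
#
#   def tpu_step():
#     xs = infeed_queue.generate_dequeue_op()  # same dtypes/shapes as enqueued
#     return computation(xs)
#
#   for _ in range(num_steps):
#     sess.run(enqueue_ops)  # feeds one step's worth of data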
def enqueue_ops_fn():
  """Enqueue ops function for one host."""
  per_host_sharded_inputs = []
  control_deps = []
  for _ in range(self.hparams.num_shards_per_host):
    with tf.control_dependencies(control_deps):
      features = iterator.get_next()
    self.eval_feature_structure["features"] = features
    flattened_inputs = data_nest.flatten(self.eval_feature_structure)
    control_deps.extend(flattened_inputs)
    per_host_sharded_inputs.append(flattened_inputs)

  infeed = tpu_feed.InfeedQueue(
      number_of_tuple_elements=len(per_host_sharded_inputs[0]))
  self.eval_infeed_queue.append(infeed)

  def tpu_ordinal_fn(shard_index_in_host):
    return shard_index_in_host % self.hparams.num_shards_per_host

  return infeed.generate_enqueue_ops(
      per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_fn)
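# Why the control_deps chain above matters (added commentary): extending
# control_deps with each shard's flattened tensors forces the per-shard
# iterator.get_next() calls to execute in order, so shard i always receives
# the i-th consecutive batch from the shared iterator instead of racing the
# other shards for it.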
def test_var_args(self):
  """Tests that arg checker works for a function with varargs."""

  def func(x, y, *z):
    return x + y + len(z)

  self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
  self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
  self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
  self.assertEqual('at least 2 arguments',
                   xla.check_function_argument_count(func, 1, None))
  queue = tpu_feed.InfeedQueue(1)
  self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
  self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
  self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
  self.assertEqual('at least 2 arguments',
                   xla.check_function_argument_count(func, 0, queue))
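# Contract restated (commentary on the three checker tests above): an
# InfeedQueue passed to check_function_argument_count supplies
# number_of_tuple_elements of the function's arguments, so only the
# remainder must be passed explicitly. For func(x, y, *z) and a 1-element
# queue, a single explicit argument already satisfies the 'at least 2' rule:
#
#   queue = tpu_feed.InfeedQueue(1)
#   assert xla.check_function_argument_count(func, 1, queue) is None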
def _config_infeed(self,
                   num_partitions,
                   device_assignment,
                   batch_size,
                   key_size=2,
                   return_tgt_mask=False,
                   use_partitioned_infeed_queue=False):
  """Configures the infeed ops and args."""
  zero_batch = get_zero_batch(
      batch_size=batch_size,
      max_len=self._prefix_max_len,
      key_size=key_size,
      return_tgt_mask=return_tgt_mask)

  host_device = device_assignment.host_device(replica=0, job=self._tpu)
  host_id = int(host_device.split('/task:')[1].split('/device:')[0])
  input_partition_dims = [
      [num_partitions] + [1] * (len(x.shape) - 1) for x in zero_batch
  ]

  if use_partitioned_infeed_queue:
    infeed = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
        number_of_tuple_elements=len(zero_batch),
        host_id=host_id,
        input_partition_dims=input_partition_dims,
        device_assignment=device_assignment)
  else:
    infeed = tpu_feed.InfeedQueue(number_of_tuple_elements=len(zero_batch))

  self.infeed_args = []
  for x in zero_batch:
    p = tf.placeholder(tf.as_dtype(x.dtype), shape=x.shape)
    self.infeed_args += [p]
  if use_partitioned_infeed_queue:
    self.infeed_op = infeed.generate_enqueue_ops([self.infeed_args])
  else:
    self.infeed_op = infeed.split_inputs_and_generate_enqueue_ops(
        self.infeed_args, device_assignment=device_assignment)
  return infeed
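# Hedged usage sketch (the `sess` name and batch values are assumptions, not
# from the original code): the placeholders stored in self.infeed_args are
# fed with concrete arrays each time the enqueue op is run.
#
#   infeed = self._config_infeed(num_partitions, device_assignment, batch_size)
#   batch_arrays = ...  # numpy arrays matching the placeholder shapes/dtypes
#   sess.run(self.infeed_op,
#            feed_dict=dict(zip(self.infeed_args, batch_arrays)))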
def enqueue_ops_fn():
  """Enqueue ops function for one host."""
  per_host_sharded_inputs = []
  control_deps = []
  for _ in range(FLAGS.tpu_num_shards_per_host):
    if "eval" in task:
      with tf.control_dependencies(control_deps):
        features = iterator.get_next()
      feature_structure["features"] = features
    else:
      with tf.control_dependencies(control_deps):
        features, labels = iterator.get_next()
      feature_structure["features"] = features
      feature_structure["labels"] = labels
    flattened_inputs = data_nest.flatten(feature_structure)
    control_deps.extend(flattened_inputs)
    per_host_sharded_inputs.append(flattened_inputs)

  infeed = tpu_feed.InfeedQueue(
      number_of_tuple_elements=len(per_host_sharded_inputs[0]))
  infeed_queue.append(infeed)
  return infeed.generate_enqueue_ops(
      per_host_sharded_inputs,
      tpu_ordinal_function=low_level_utils.tpu_ordinal_fn)
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch."""
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info(
      'CreateTPUFeeds num_splits_per_client={} '
      'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'
      .format(cluster.num_splits_per_client, cluster.num_devices_per_split,
              num_tpu_hosts, p.use_per_host_infeed))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal('Doesn\'t support per host infeed mode when '
                     'num_devices_per_split({}) > num_cores_per_host({})'
                     .format(cluster.num_devices_per_split,
                             num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    shards = (tpu_function.get_tpu_context().number_of_shards //
              num_infeed_hosts)
    tf.logging.info('shards {}'.format(shards))

    input_ops_list = []
    queues = []
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (tpu_embedding_collection[0]
                     if tpu_embedding_collection else None)

    if num_tpu_hosts > 1 and tpu_embedding is not None:
      if not p.use_per_host_infeed:
        tf.logging.fatal(
            'TPU Embedding must be used with per_host_infeed with multiple '
            'TPU host topologies.')
    tpu_emb_input_keys = (
        list(tpu_embedding.feature_to_config_dict.keys())
        if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if isinstance(batch, py_utils.NestedMap):
          # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
          # Note that when MultiTaskData is used, bucket_keys will be at the
          # second level of the dictionary.
          batch = batch.FilterKeyVal(
              lambda k, _: not k.endswith('bucket_keys'))
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)

        if tpu_embedding is not None:
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
            for core, split in enumerate(tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id.
              sample_indices = tf.where(tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        for k, x in batch.FlattenItems():
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.

        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()
        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)

        if p.use_partitioned_infeed_queue:
          device_assignment = py_utils.GetTpuDeviceAssignment()
          host_device = device_assignment.host_device(
              replica=0, job=tf.flags.FLAGS.tf_master)
          host_id = int(host_device.split('/task:')[1].split('/device:')[0])
          tf.logging.info('host_id: {} host_device: {}'.format(
              host_id, host_device))
          q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
              number_of_tuple_elements=len(dtypes),
              device_assignment=device_assignment,
              host_id=host_id,
              input_partition_dims=[[p.num_partitions, 1] for _ in dtypes],
              tuple_types=dtypes,
              tuple_shapes=shapes)
        else:
          q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
          assert shards is not None
          q.set_number_of_shards(shards)
        queues.append(q)
        tf.logging.info('q=%r', q)

        if p.use_partitioned_infeed_queue:
          input_ops = q.generate_enqueue_ops([batch.Flatten()])
        elif p.use_per_host_infeed:

          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops

    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallelism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)
    self._tpu_infeed_op = tpu_infeed_op

  with tf.device(tf.tpu.core(0)):
    tensors = queues[0].generate_dequeue_op()
  return batch.Pack(tensors)
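# Worked example of the dense-to-sparse conversion above (added commentary;
# the numbers are illustrative): with padding id -1 and a per-core split of
#   split = [[3, -1], [5, 7]]
# tf.where(tf.not_equal(split, -1)) yields sample_indices
# [[0, 0], [1, 0], [1, 1]], and tf.gather_nd(split, sample_indices) yields
# embedding_indices [3, 5, 7] -- the coordinates and values of the
# non-padding entries that EnqueueData carries to the embedding lookup.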
def CreateTpuEnqueueOps(self):
  """Create the host-side enqueue ops.

  This should be called in an outer non-TPU context.
  """
  assert not self._tpu_queues, ('CreateTpuEnqueueOps should only be called '
                                'once.')
  self._tpu_queues = []
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info(
      'CreateTpuEnqueueOps num_splits_per_client={} '
      'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'
      .format(cluster.num_splits_per_client, cluster.num_devices_per_split,
              num_tpu_hosts, p.use_per_host_infeed))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal('Doesn\'t support per host infeed mode when '
                     'num_devices_per_split({}) > num_cores_per_host({})'
                     .format(cluster.num_devices_per_split,
                             num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  shards = (cluster.total_worker_devices //
            num_infeed_hosts) // cluster.num_devices_per_split
  tf.logging.info('shards {}'.format(shards))

  input_ops_list = []
  tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
  tpu_embedding = (tpu_embedding_collection[0]
                   if tpu_embedding_collection else None)

  if num_tpu_hosts > 1 and tpu_embedding is not None:
    if not p.use_per_host_infeed:
      tf.logging.fatal(
          'TPU Embedding must be used with per_host_infeed with multiple '
          'TPU host topologies.')
  tpu_emb_input_keys = (
      list(tpu_embedding.feature_to_config_dict.keys())
      if tpu_embedding is not None else [])
  tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
  tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)

  for task_id in range(num_infeed_hosts):
    host_device = '/task:{}/device:CPU:0'.format(task_id)
    with tf.device(host_device):
      self._batch = self.GetPreprocessedInputBatch()
      if isinstance(self._batch, py_utils.NestedMap):
        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
        # Note that when MultiTaskData is used, bucket_keys will be at the
        # second level of the dictionary.
        self._batch = self._batch.FilterKeyVal(
            lambda k, _: not k.endswith('bucket_keys'))
      tf.logging.info('host_device: %s, batch: %r', host_device, self._batch)

      for k, x in self._batch.FlattenItems():
        assert x.shape.is_fully_defined(), (
            'Shape must be fully defined: %s: %s' % (k, x))
        # TODO(cwhipkey): if it's a string (or other type not supported on
        # TPU), drop it from feeding and on the other end add in an op that
        # fails if used.

      shapes = self._batch.Transform(lambda x: x.shape).Flatten()
      dtypes = self._batch.Transform(lambda x: x.dtype).Flatten()
      tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                      shapes)
      tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                      dtypes)

      if p.use_partitioned_infeed_queue:
        device_assignment = py_utils.GetTpuDeviceAssignment()
        host_device = device_assignment.host_device(
            replica=0, job=tf.flags.FLAGS.tf_master)
        host_id = int(host_device.split('/task:')[1].split('/device:')[0])
        tf.logging.info('host_id: {} host_device: {}'.format(
            host_id, host_device))
        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
            number_of_tuple_elements=len(dtypes),
            device_assignment=device_assignment,
            host_id=host_id,
            input_partition_dims=[
                [p.num_partitions] + [1] * (len(s) - 1) for s in shapes
            ],
            tuple_types=dtypes,
            tuple_shapes=shapes)
      else:
        q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
        assert shards is not None
        q.set_number_of_shards(shards)
      self._tpu_queues.append(q)

      if p.use_partitioned_infeed_queue:
        input_ops = q.generate_enqueue_ops([self._batch.Flatten()])
      elif p.use_per_host_infeed:

        # TODO(ylc/zhifengc): Add this to a policy module and test it.
        def TPUOrdinalFunction(shard_index_in_host):
          device_assignment = py_utils.GetTpuDeviceAssignment()
          if device_assignment:
            # We put both enqueue/dequeue ops at core 0 in each replica.
            replica = device_assignment.lookup_replicas(
                task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
            return device_assignment.tpu_ordinal(replica=replica)
          else:
            return shard_index_in_host

        input_ops = q.split_inputs_and_generate_enqueue_ops(
            self._batch.Flatten(),
            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
            tpu_ordinal_function=TPUOrdinalFunction)
      else:
        input_ops = q.split_inputs_and_generate_enqueue_ops(
            self._batch.Flatten(),
            device_assignment=py_utils.GetTpuDeviceAssignment())
      input_ops_list += input_ops

  tf.logging.info('input_ops_list %s', input_ops_list)
  grouped_infeed_op = tf.group(*input_ops_list)
  self._tpu_infeed_op = []
  for _ in range(p.tpu_infeed_parallelism):
    self._tpu_infeed_op.append(grouped_infeed_op)
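# Hedged driver sketch (the `coord`, `sess`, and thread names are
# assumptions, not from the original code): every entry of
# self._tpu_infeed_op is the same grouped enqueue op, so
# p.tpu_infeed_parallelism host threads can each loop over session.run to
# keep the infeed ahead of the TPU computation.
#
#   def _infeed_thread(op):
#     while not coord.should_stop():
#       sess.run(op)
#
#   threads = [threading.Thread(target=_infeed_thread, args=(op,))
#              for op in self._tpu_infeed_op]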
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch."""
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
  tf.logging.info('num_devices_per_split {}'.format(
      cluster.num_devices_per_split))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal('Doesn\'t support per host infeed mode when '
                     'num_devices_per_split({}) > num_cores_per_host({})'
                     .format(cluster.num_devices_per_split,
                             num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    shards = (tpu_function.get_tpu_context().number_of_shards //
              num_infeed_hosts)
    input_ops_list = []
    queues = []
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (tpu_embedding_collection[0]
                     if tpu_embedding_collection else None)
    tpu_emb_input_keys = (
        list(tpu_embedding.feature_to_config_dict.keys())
        if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if 'bucket_keys' in batch:
          # Hack: bucket_keys are not needed on TPU.
          del batch['bucket_keys']
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)

        if tpu_embedding is not None:
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
            for core, split in enumerate(tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id.
              sample_indices = tf.where(tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        for k, x in batch.FlattenItems():
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.

        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()
        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)

        q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
        queues.append(q)
        assert shards is not None
        q.set_number_of_shards(shards)

        if p.use_per_host_infeed:

          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops

    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallelism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)
    # For executor-driven multiple programs, we need more fine-grained
    # access rather than using a single global graph collection.
    self.tpu_infeed_op = tpu_infeed_op

  with tf.device(tf.tpu.core(0)):
    tensors = queues[0].generate_dequeue_op()
  return batch.Pack(tensors)
def enqueue_ops_fn(idx):
  """Generate the infeed enqueue ops graph."""
  per_host_sharded_inputs = []
  control_deps = []
  for _ in range(FLAGS.replicas_per_host):
    with tf.control_dependencies(control_deps):
      self.feature_structure[is_training] = iterator.get_next()
    self.maybe_capture_embedding_inputs(
        self.feature_structure[is_training], is_training)
    flattened_inputs = tf.nest.flatten(self.feature_structure[is_training])
    control_deps.extend(flattened_inputs)
    if input_partition_dims:
      padded_inputs = []
      for inp in flattened_inputs:
        if inp.shape.ndims < len(input_partition_dims):
          padded_inputs.append(inp)
          continue
        paddings = []
        for i, j in enumerate(input_partition_dims):
          r = inp.shape.as_list()[i] % j
          if r > 0:
            paddings.append([0, j - r])
          else:
            paddings.append([0, 0])
        for i in range(inp.shape.ndims - len(input_partition_dims)):
          paddings.append([0, 0])
        padded_inputs.append(tf.pad(inp, paddings))
      per_host_sharded_inputs.append(padded_inputs)
    else:
      per_host_sharded_inputs.append(flattened_inputs)

  if input_partition_dims:
    flattened_input_dims = []
    for i in per_host_sharded_inputs[0]:
      if i.shape.ndims == len(input_partition_dims):
        flattened_input_dims.append(input_partition_dims)
      elif i.shape.ndims > len(input_partition_dims):
        flattened_input_dims.append(
            input_partition_dims + [1] *
            (i.shape.ndims - len(input_partition_dims)))
      else:
        flattened_input_dims.append([1] * i.shape.ndims)
    # pylint: disable=protected-access
    self.infeed_op[is_training] = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=len(per_host_sharded_inputs[0]),
        host_id=host_id,
        input_partition_dims=flattened_input_dims,
        device_assignment=self.device_assignment)
    with tf.control_dependencies(
        self.infeed_op[is_training].generate_enqueue_ops(
            per_host_sharded_inputs)):
      return idx + 1
  else:
    self.infeed_op[is_training] = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(per_host_sharded_inputs[0]))
    per_host_enqueue_ops = (
        self.infeed_op[is_training].generate_enqueue_ops(
            per_host_sharded_inputs, tpu_ordinal_function=_tpu_ordinal_fn))
    self.maybe_add_embedding_enqueue_ops_int(is_training,
                                             per_host_enqueue_ops)
    with tf.control_dependencies(per_host_enqueue_ops):
      return idx + 1
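# Hedged driver sketch (the `num_iters` bound is an assumption, not from the
# original code): because enqueue_ops_fn(idx) returns idx + 1 under a control
# dependency on the enqueue ops, it slots directly into a host-side
# tf.while_loop that enqueues one step's data per iteration.
#
#   enqueue_loop = tf.while_loop(
#       lambda step: step < num_iters, enqueue_ops_fn, [tf.constant(0)])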