def split_inputs(ctx, features, labels):
  """Splits the dense and sparse tensors inside the features and labels.

  For every feature known to the embedding config, builds an `EnqueueData`
  (sparse path keeps optional weights; dense path forbids weights).

  Args:
    ctx: TPU context; `ctx.embedding_config` (if set) supplies the
      `TPUEmbedding` object and the feature -> weight-key mapping.
    features: Dict of feature tensors (read, not modified here).
    labels: Labels; passed through untouched.

  Returns:
    Tuple of (features, labels, enqueue_datas) where enqueue_datas is an
    OrderedDict mapping feature key -> EnqueueData (empty when there is no
    embedding config).

  Raises:
    ValueError: If weights are configured for a feature that is enqueued
      as a dense tensor.
  """
  enqueue_datas = collections.OrderedDict()
  if ctx.embedding_config:
    tpu_embedding_ = ctx.embedding_config.tpu_embedding
    feature_to_weight_key_name_dict = (
        ctx.embedding_config.feature_to_weight_key_name_dict)
    for feature_key in tpu_embedding_.feature_to_config_dict:
      sparse_feature = _get_sparse_feature_from_feature(feature_key, features)
      weight_key_name = feature_to_weight_key_name_dict[feature_key]
      if isinstance(sparse_feature, sparse_tensor.SparseTensor):
        weights = _get_weights_from_features(weight_key_name, features)
        enqueue_data = tpu_embedding.EnqueueData.from_sparse_tensor(
            sparse_feature, weights)
      else:
        # Dense tensors cannot carry per-id weights.
        if weight_key_name is not None:
          # Fixed: the two string literals previously concatenated to
          # "...is notcompatible..." (missing space).
          raise ValueError(
              'Found weights {} for weighted_categorical_column, which is '
              'not compatible with sparse feature {} enqueued as dense '
              'tensor.'.format(weight_key_name, feature_key))
        enqueue_data = tpu_embedding.EnqueueData(sparse_feature)
      enqueue_datas[feature_key] = enqueue_data
  return features, labels, enqueue_datas
def split_inputs(ctx, features, labels, num_cores_per_batch=1):
  """Splits the dense and sparse tensors inside the features and labels.

  Also splits each embedding feature into `num_cores_per_batch` per-core
  slices and, for sequence features, records a clipped sequence-length
  feature back into `features`.

  Args:
    ctx: TPU context; `ctx.embedding_config` (if set) supplies the
      `TPUEmbedding` object whose feature configs drive the split.
    features: Dict of feature tensors; mutated in place (sequence-length
      features may be added).
    labels: Labels; passed through untouched.
    num_cores_per_batch: Number of cores each host batch is split across.

  Returns:
    Tuple of (features, labels, enqueue_datas_list): enqueue_datas_list has
    one dict per core mapping feature key -> EnqueueData.

  Raises:
    ValueError: If weights are configured for a feature enqueued as a
      dense tensor without a combiner.
  """
  enqueue_datas = collections.OrderedDict()
  if ctx.embedding_config:
    tpu_embedding_ = ctx.embedding_config.tpu_embedding
    for feature_key in tpu_embedding_.feature_to_config_dict:
      sparse_feature = _get_sparse_feature_from_feature(feature_key, features)
      max_sequence_length = tpu_embedding_.feature_to_config_dict[
          feature_key].max_sequence_length
      # Look up the combiner configured for this feature's table.
      # NOTE(review): reaches into TPUEmbedding private members.
      combiner = tpu_embedding_._table_to_config_dict[  # pylint: disable=protected-access
          tpu_embedding_._feature_to_config_dict[  # pylint: disable=protected-access
              feature_key].table_id].combiner
      if max_sequence_length > 0:
        # Sequence feature: also feed the (clipped) true sequence length.
        length_feature_name = (
            tpu_fc.get_sequence_length_feature_key_name_from_feature_key_name(
                feature_key))
        length_feature = tf.math.minimum(
            fc_utils.sequence_length_from_sparse_tensor(sparse_feature),
            max_sequence_length)
        length_feature.set_shape(ctx.batch_size_for_input_fn)
        features[length_feature_name] = length_feature
      weight_key = tpu_embedding_.feature_to_config_dict[
          feature_key].weight_key
      sparse_feature_split = _split_tensor(sparse_feature,
                                           num_cores_per_batch)
      if combiner is None and not isinstance(sparse_feature,
                                             tf.sparse.SparseTensor):
        # A dense tensor with no combiner was provided so we assume that each
        # of the embedding_indices belongs to a different sample (setting
        # sample_indices to None).
        if weight_key is not None:
          # Fixed: the two string literals previously concatenated to
          # "...is notcompatible..." (missing space).
          raise ValueError(
              'Found weights {} for weighted_categorical_column, which is '
              'not compatible with sparse feature {} enqueued as dense '
              'tensor.'.format(weight_key, feature_key))
        enqueue_data = []
        for i in range(num_cores_per_batch):
          enqueue_data.append(
              tpu_embedding.EnqueueData(sparse_feature_split[i]))
      else:
        weights = None
        if isinstance(sparse_feature, tf.sparse.SparseTensor):
          weights = _get_weights_from_features(weight_key, features)
        # NOTE(review): assumes _split_tensor(None, n) is benign when
        # weights is None — confirm against its definition.
        weights_split = _split_tensor(weights, num_cores_per_batch)
        enqueue_data = []
        for i in range(num_cores_per_batch):
          # Fixed: use an explicit `is not None` test instead of relying on
          # the truthiness of a (Sparse)Tensor object.
          split_weights = (
              weights_split[i] if weights is not None else None)
          enqueue_data.append(
              tpu_embedding.EnqueueData.from_sparse_tensor(
                  _maybe_dense_to_sparse(sparse_feature_split[i]),
                  weights=split_weights))
      enqueue_datas[feature_key] = enqueue_data
  # Transpose the enqueue_datas dict into a list of dicts (one per core).
  enqueue_datas_list = []
  for i in range(num_cores_per_batch):
    enqueue_data = {}
    for key, value in enqueue_datas.items():
      enqueue_data[key] = value[i]
    enqueue_datas_list.append(enqueue_data)
  return features, labels, enqueue_datas_list
def CreateTpuEmbeddingEnqueueOps(self):
  """Creates the TpuEmbedding enqueue ops on the host.

  Note that this must be called after the instantiation of the
  monolithic TPUEmbeddingLayer.

  Side effects: may filter `bucket_keys` entries out of `self._batch`,
  and appends the grouped enqueue ops to `self._tpu_infeed_op`.
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1
  # The monolithic TPUEmbeddingLayer (if any) registers itself in this
  # graph collection; absence means there is nothing to enqueue.
  tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
  tpu_embedding = (tpu_embedding_collection[0]
                   if tpu_embedding_collection else None)
  enqueue_ops = []
  if num_tpu_hosts > 1 and tpu_embedding is not None:
    if not p.use_per_host_infeed:
      tf.logging.fatal(
          'TPU Embedding must be used with per_host_infeed with multiple '
          'TPU host topologies.')
  tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                        if tpu_embedding is not None else [])
  tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
  if not tpu_embedding:
    return
  # Build per-host enqueue ops, one set per infeed host.
  for task_id in range(num_infeed_hosts):
    host_device = '/task:{}/device:CPU:0'.format(task_id)
    with tf.device(host_device):
      if isinstance(self._batch, py_utils.NestedMap):
        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
        # Note that when MultiTaskData is used, bucket_keys will be at the
        # second level of the dictionary.
        self._batch = self._batch.FilterKeyVal(
            lambda k, _: not k.endswith('bucket_keys'))
      tf.logging.info('host_device: %s, batch: %r', host_device, self._batch)
      # One enqueue dict per core on this host.
      enqueue_dict_per_core = [
          {} for _ in range(tpu_embedding.num_cores_per_host)
      ]
      num_cores_per_host = tpu_embedding.num_cores_per_host
      for key in tpu_emb_input_keys:
        feat = self._batch[key]
        # Split the host batch evenly across this host's cores.
        tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
        for core, split in enumerate(tpu_emb_feat_splitted):
          # Dense to sparse. Note the assumption of a padding id.
          # NOTE(review): padding id is assumed to be -1 — confirm this
          # matches the input pipeline's padding convention.
          sample_indices = tf.where(tf.not_equal(split, -1))
          embedding_indices = tf.gather_nd(split, sample_indices)
          enqueue_data = tpu_embedding_lib.EnqueueData(embedding_indices,
                                                       sample_indices)
          enqueue_dict_per_core[core][key] = enqueue_data
      enqueue_ops += tpu_embedding.generate_enqueue_ops(enqueue_dict_per_core)
  self._tpu_infeed_op.append(tf.group(*enqueue_ops))
def split_inputs(ctx, features, labels, num_cores_per_batch=1):
  """Splits the dense and sparse tensors inside the features and labels.

  Splits each embedding feature into `num_cores_per_batch` per-core slices
  and, for sequence features, records a clipped sequence-length feature
  back into `features`.

  Args:
    ctx: TPU context; `ctx.embedding_config` (if set) supplies the
      `TPUEmbedding` object whose feature configs drive the split.
    features: Dict of feature tensors; mutated in place (sequence-length
      features may be added).
    labels: Labels; passed through untouched.
    num_cores_per_batch: Number of cores each host batch is split across.

  Returns:
    Tuple of (features, labels, enqueue_datas_list): enqueue_datas_list has
    one dict per core mapping feature key -> EnqueueData.

  Raises:
    ValueError: If weights are configured for a feature that is enqueued
      as a dense tensor.
  """
  enqueue_datas = collections.OrderedDict()
  if ctx.embedding_config:
    tpu_embedding_ = ctx.embedding_config.tpu_embedding
    for feature_key in tpu_embedding_.feature_to_config_dict:
      sparse_feature = _get_sparse_feature_from_feature(feature_key, features)
      max_sequence_length = tpu_embedding_.feature_to_config_dict[
          feature_key].max_sequence_length
      if max_sequence_length > 0:
        # Sequence feature: also feed the (clipped) true sequence length.
        length_feature_name = (
            tpu_fc.get_sequence_length_feature_key_name_from_feature_key_name(
                feature_key))
        length_feature = math_ops.minimum(
            fc_utils.sequence_length_from_sparse_tensor(sparse_feature),
            max_sequence_length)
        length_feature.set_shape(ctx.batch_size_for_input_fn)
        features[length_feature_name] = length_feature
      weight_key = tpu_embedding_.feature_to_config_dict[
          feature_key].weight_key
      sparse_feature_split = _split_tensor(sparse_feature,
                                           num_cores_per_batch)
      if isinstance(sparse_feature, sparse_tensor.SparseTensor):
        weights = _get_weights_from_features(weight_key, features)
        # NOTE(review): assumes _split_tensor handles weights=None —
        # confirm against its definition.
        weights_split = _split_tensor(weights, num_cores_per_batch)
        enqueue_data = []
        for i in range(num_cores_per_batch):
          enqueue_data.append(
              tpu_embedding.EnqueueData.from_sparse_tensor(
                  sparse_feature_split[i], weights_split[i]))
      else:
        # Dense tensors cannot carry per-id weights.
        if weight_key is not None:
          # Fixed: the two string literals previously concatenated to
          # "...is notcompatible..." (missing space).
          raise ValueError(
              'Found weights {} for weighted_categorical_column, which is '
              'not compatible with sparse feature {} enqueued as dense '
              'tensor.'.format(weight_key, feature_key))
        enqueue_data = []
        for i in range(num_cores_per_batch):
          enqueue_data.append(
              tpu_embedding.EnqueueData(sparse_feature_split[i]))
      enqueue_datas[feature_key] = enqueue_data
  # Transpose the enqueue_datas dict into a list of dicts (one per core).
  enqueue_datas_list = []
  for i in range(num_cores_per_batch):
    enqueue_data = {}
    for key, value in enqueue_datas.items():
      enqueue_data[key] = value[i]
    enqueue_datas_list.append(enqueue_data)
  return features, labels, enqueue_datas_list
def split_inputs(ctx, features, labels):
  """Splits the dense and sparse tensors inside the features and labels.

  For every feature known to the embedding config, builds an `EnqueueData`
  and, for sequence features, records a clipped sequence-length feature
  back into `features`.

  Args:
    ctx: TPU context; `ctx.embedding_config` (if set) supplies the
      `TPUEmbedding` object and the feature -> weight-key mapping.
    features: Dict of feature tensors; mutated in place (sequence-length
      features may be added).
    labels: Labels; passed through untouched.

  Returns:
    Tuple of (features, labels, enqueue_datas) where enqueue_datas is an
    OrderedDict mapping feature key -> EnqueueData.

  Raises:
    ValueError: If weights are configured for a feature that is enqueued
      as a dense tensor.
  """
  enqueue_datas = collections.OrderedDict()
  if ctx.embedding_config:
    tpu_embedding_ = ctx.embedding_config.tpu_embedding
    feature_to_weight_key_name_dict = (
        ctx.embedding_config.feature_to_weight_key_name_dict)
    for feature_key in tpu_embedding_.feature_to_config_dict:
      sparse_feature = _get_sparse_feature_from_feature(feature_key, features)
      max_sequence_length = tpu_embedding_.feature_to_config_dict[
          feature_key].max_sequence_length
      if max_sequence_length > 0:
        # Sequence feature: also feed the (clipped) true sequence length.
        length_feature_name = (
            tpu_fc.get_sequence_length_feature_key_name_from_feature_key_name(
                feature_key))
        length_feature = math_ops.minimum(
            fc_utils.sequence_length_from_sparse_tensor(sparse_feature),
            max_sequence_length)
        length_feature.set_shape(ctx.batch_size_for_input_fn)
        features[length_feature_name] = length_feature
      weight_key_name = feature_to_weight_key_name_dict[feature_key]
      if isinstance(sparse_feature, sparse_tensor.SparseTensor):
        weights = _get_weights_from_features(weight_key_name, features)
        enqueue_data = tpu_embedding.EnqueueData.from_sparse_tensor(
            sparse_feature, weights)
      else:
        # Dense tensors cannot carry per-id weights.
        if weight_key_name is not None:
          # Fixed: the two string literals previously concatenated to
          # "...is notcompatible..." (missing space).
          raise ValueError(
              'Found weights {} for weighted_categorical_column, which is '
              'not compatible with sparse feature {} enqueued as dense '
              'tensor.'.format(weight_key_name, feature_key))
        enqueue_data = tpu_embedding.EnqueueData(sparse_feature)
      enqueue_datas[feature_key] = enqueue_data
  return features, labels, enqueue_datas
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch.

  Builds one infeed queue per infeed host (optionally a partitioned
  queue), wires up TPU-embedding enqueue ops when a TPUEmbedding layer is
  registered, and records the grouped infeed op in `self._tpu_infeed_op`
  and the ENQUEUE_OPS collection.

  Returns:
    The dequeued batch (a NestedMap packed from the infeed tensors).
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info(
      'CreateTPUFeeds num_splits_per_client={} '
      'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
      format(cluster.num_splits_per_client, cluster.num_devices_per_split,
             num_tpu_hosts, p.use_per_host_infeed))
  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  # Per-host infeed cannot span a model split across multiple hosts.
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal('Doesn\'t support per host infeed mode when '
                     'num_devices_per_split({}) > num_cores_per_host({})'.
                     format(cluster.num_devices_per_split,
                            num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1
  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed
    # Number of shards each infeed host feeds.
    shards = tpu_function.get_tpu_context(
    ).number_of_shards // num_infeed_hosts
    tf.logging.info('shards {}'.format(shards))
    input_ops_list = []
    queues = []
    # A registered monolithic TPUEmbeddingLayer (if any).
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (tpu_embedding_collection[0]
                     if tpu_embedding_collection else None)
    if num_tpu_hosts > 1 and tpu_embedding is not None:
      if not p.use_per_host_infeed:
        tf.logging.fatal(
            'TPU Embedding must be used with per_host_infeed with multiple '
            'TPU host topologies.')
    tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                          if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if isinstance(batch, py_utils.NestedMap):
          # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
          # Note that when MultiTaskData is used, bucket_keys will be at the
          # second level of the dictionary.
          batch = batch.FilterKeyVal(
              lambda k, _: not k.endswith('bucket_keys'))
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)
        if tpu_embedding is not None:
          # One enqueue dict per core on this host.
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
            for core, split in enumerate(tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id.
              # NOTE(review): padding id assumed to be -1 — confirm this
              # matches the input pipeline's convention.
              sample_indices = tf.where(tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)
        for k, x in batch.FlattenItems():
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()
        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)
        if p.use_partitioned_infeed_queue:
          device_assignment = py_utils.GetTpuDeviceAssignment()
          host_device = device_assignment.host_device(
              replica=0, job=tf.flags.FLAGS.tf_master)
          host_id = int(
              host_device.split('/task:')[1].split('/device:')[0])
          tf.logging.info('host_id: {} host_device: {}'.format(
              host_id, host_device))
          q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
              number_of_tuple_elements=len(dtypes),
              device_assignment=device_assignment,
              host_id=host_id,
              input_partition_dims=[[p.num_partitions, 1] for _ in dtypes],
              tuple_types=dtypes,
              tuple_shapes=shapes)
        else:
          q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
          assert shards is not None
          q.set_number_of_shards(shards)
        queues.append(q)
        tf.logging.info('q=%r', q)
        if p.use_partitioned_infeed_queue:
          input_ops = q.generate_enqueue_ops([batch.Flatten()])
        elif p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment())
        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallelism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)
    self._tpu_infeed_op = tpu_infeed_op
    with tf.device(tf.tpu.core(0)):
      tensors = queues[0].generate_dequeue_op()
  return batch.Pack(tensors)
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch.

  Pops TPU-embedding features out of each host batch and enqueues them via
  the registered TPUEmbedding layer; the remaining tensors are fed through
  a regular infeed queue, one per infeed host.

  Returns:
    The dequeued batch (packed using the structure of the first host's
    batch, minus the popped embedding features).
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
  tf.logging.info('num_devices_per_split {}'.format(
      cluster.num_devices_per_split))
  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  # Per-host infeed cannot span a model split across multiple hosts.
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal('Doesn\'t support per host infeed mode when '
                     'num_devices_per_split({}) > num_cores_per_host({})'.
                     format(cluster.num_devices_per_split,
                            num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1
  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed
    # Number of shards each infeed host feeds.
    shards = tpu_function.get_tpu_context(
    ).number_of_shards // num_infeed_hosts
    input_ops_list = []
    queues = []
    first_batch = None
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (tpu_embedding_collection[0]
                     if tpu_embedding_collection else None)
    tpu_embedding_input_keys = (
        tpu_embedding.feature_to_config_dict.keys()
        if tpu_embedding is not None else [])
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        # Embedding features go through the TPUEmbedding enqueue path, not
        # the regular infeed, so pop them out of the batch.
        tpu_embedding_features = []
        for tpu_embedding_input_key in tpu_embedding_input_keys:
          tpu_embedding_feature = batch.pop(tpu_embedding_input_key)
          tpu_embedding_features.append(
              (tpu_embedding_input_key, tpu_embedding_feature))
        if first_batch is None:
          first_batch = batch
        flat_batch = batch.FlattenItems()
        if tpu_embedding is not None:
          # Fixed: was `[{}] * num_cores_per_host`, which replicates ONE
          # shared dict so every core saw the same (last-written) enqueue
          # data. Each core needs its own dict.
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for tpu_embedding_input_key, tpu_embedding_feature in (
              tpu_embedding_features):
            # Split the host batch evenly across this host's cores.
            tpu_embedding_feature_splitted = tf.split(
                tpu_embedding_feature, num_cores_per_host)
            for core, split in enumerate(tpu_embedding_feature_splitted):
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  tf.squeeze(split, axis=[1]))
              enqueue_dict_per_core[core][
                  tpu_embedding_input_key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)
        shapes, types = [], []
        for k, x in flat_batch:
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
          shapes.append(x.shape)
          types.append(x.dtype)
        q = tf.contrib.tpu.InfeedQueue(tuple_types=types,
                                       tuple_shapes=shapes)
        queues.append(q)
        assert shards is not None
        q.set_number_of_shards(shards)
        if p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def _tpu_ordinal_function(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              [v for _, v in flat_batch],
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=_tpu_ordinal_function)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              [v for _, v in flat_batch],
              device_assignment=py_utils.GetTpuDeviceAssignment())
        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    # NOTE(review): 'tpu_infeed_parallism' [sic] is the param name declared
    # elsewhere; kept as-is to avoid breaking the param definition.
    for _ in range(p.tpu_infeed_parallism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)
    with tf.device(tf.compat.v1.tpu.core(0)):
      tensors = queues[0].generate_dequeue_op()
  return first_batch.Pack(tensors)
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch.

  Builds one infeed queue per infeed host, wires up TPU-embedding enqueue
  ops when a TPUEmbedding layer is registered, and records the grouped
  infeed op in `self.tpu_infeed_op` and the ENQUEUE_OPS collection.

  Returns:
    The dequeued batch (a NestedMap packed from the infeed tensors).
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
  tf.logging.info('num_devices_per_split {}'.format(
      cluster.num_devices_per_split))
  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  # Per-host infeed cannot span a model split across multiple hosts.
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal('Doesn\'t support per host infeed mode when '
                     'num_devices_per_split({}) > num_cores_per_host({})'.
                     format(cluster.num_devices_per_split,
                            num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1
  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed
    # Number of shards each infeed host feeds.
    shards = tpu_function.get_tpu_context(
    ).number_of_shards // num_infeed_hosts
    input_ops_list = []
    queues = []
    # A registered monolithic TPUEmbeddingLayer (if any).
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (tpu_embedding_collection[0]
                     if tpu_embedding_collection else None)
    tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                          if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if 'bucket_keys' in batch:
          # Hack: bucket_keys are not needed on TPU.
          del batch['bucket_keys']
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)
        if tpu_embedding is not None:
          # One enqueue dict per core on this host.
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
            for core, split in enumerate(tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id.
              # NOTE(review): padding id assumed to be -1 — confirm this
              # matches the input pipeline's convention.
              sample_indices = tf.where(tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)
        for k, x in batch.FlattenItems():
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()
        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)
        q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
        queues.append(q)
        assert shards is not None
        q.set_number_of_shards(shards)
        if p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment())
        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallelism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)
    # For executor-driven multiple programs, we need more fine-grained
    # access rather than using a single global graph collection.
    self.tpu_infeed_op = tpu_infeed_op
    with tf.device(tf.tpu.core(0)):
      tensors = queues[0].generate_dequeue_op()
  return batch.Pack(tensors)