Example #1
def split_inputs(ctx, features, labels):
    """Splits the dense and sparse tensors inside the features and labels."""
    enqueue_datas = collections.OrderedDict()
    if ctx.embedding_config:
        tpu_embedding_ = ctx.embedding_config.tpu_embedding
        feature_to_weight_key_name_dict = (
            ctx.embedding_config.feature_to_weight_key_name_dict)
        for feature_key in tpu_embedding_.feature_to_config_dict:
            sparse_feature = _get_sparse_feature_from_feature(
                feature_key, features)
            weight_key_name = feature_to_weight_key_name_dict[feature_key]
            if isinstance(sparse_feature, sparse_tensor.SparseTensor):
                weights = _get_weights_from_features(weight_key_name, features)
                enqueue_data = tpu_embedding.EnqueueData.from_sparse_tensor(
                    sparse_feature, weights)
            else:
                if weight_key_name is not None:
                    raise ValueError(
                        'Found weights {} for weighted_categorical_column, which is not '
                        'compatible with sparse feature {} enqueued as dense tensor.'
                        .format(weight_key_name, feature_key))
                enqueue_data = tpu_embedding.EnqueueData(sparse_feature)
            enqueue_datas[feature_key] = enqueue_data

    return features, labels, enqueue_datas
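For context, the two branches above build the payload that the TPU embedding enqueue expects: a SparseTensor goes through EnqueueData.from_sparse_tensor (optionally with weights), while a dense tensor of one id per sample is wrapped directly in EnqueueData. Below is a minimal stand-alone sketch of both constructions, assuming the TF1.x-era tensorflow.python.tpu.tpu_embedding module (a private path) and toy id/weight values invented for illustration.

import tensorflow as tf
from tensorflow.python.tpu import tpu_embedding

# Toy sparse feature: two samples with ids [3, 7] and [5].
ids = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                      values=tf.constant([3, 7, 5], dtype=tf.int64),
                      dense_shape=[2, 2])
weights = tf.SparseTensor(indices=ids.indices,
                          values=tf.constant([1.0, 0.5, 2.0]),
                          dense_shape=ids.dense_shape)

# Sparse branch: ids plus optional per-id weights.
sparse_enqueue = tpu_embedding.EnqueueData.from_sparse_tensor(ids, weights)

# Dense branch: one id per sample, weights are not allowed here.
dense_ids = tf.constant([3, 5], dtype=tf.int64)
dense_enqueue = tpu_embedding.EnqueueData(dense_ids)
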
Example #2
def split_inputs(ctx, features, labels, num_cores_per_batch=1):
  """Splits the dense and sparse tensors inside the features and labels."""
  enqueue_datas = collections.OrderedDict()

  if ctx.embedding_config:
    tpu_embedding_ = ctx.embedding_config.tpu_embedding
    for feature_key in tpu_embedding_.feature_to_config_dict:
      sparse_feature = _get_sparse_feature_from_feature(feature_key, features)
      max_sequence_length = tpu_embedding_.feature_to_config_dict[
          feature_key].max_sequence_length
      combiner = tpu_embedding_._table_to_config_dict[
          tpu_embedding_._feature_to_config_dict[feature_key].table_id].combiner
      if max_sequence_length > 0:
        length_feature_name = (
            tpu_fc.get_sequence_length_feature_key_name_from_feature_key_name(
                feature_key))
        length_feature = tf.math.minimum(
            fc_utils.sequence_length_from_sparse_tensor(sparse_feature),
            max_sequence_length)
        length_feature.set_shape(ctx.batch_size_for_input_fn)
        features[length_feature_name] = length_feature
      weight_key = tpu_embedding_.feature_to_config_dict[feature_key].weight_key
      sparse_feature_split = _split_tensor(
          sparse_feature, num_cores_per_batch)
      if combiner is None and not isinstance(sparse_feature,
                                             tf.sparse.SparseTensor):
        # A dense tensor with no combiner was provided so we assume that each
        # of the embedding_indices belongs to a different sample (setting
        # sample_indices to None).
        if weight_key is not None:
          raise ValueError(
              'Found weights {} for weighted_categorical_column, which is not '
              'compatible with sparse feature {} enqueued as dense tensor.'
              .format(weight_key, feature_key))
        enqueue_data = []
        for i in range(num_cores_per_batch):
          enqueue_data.append(tpu_embedding.EnqueueData(
              sparse_feature_split[i]))
      else:
        weights = None
        if isinstance(sparse_feature, tf.sparse.SparseTensor):
          weights = _get_weights_from_features(weight_key, features)
          weights_split = _split_tensor(weights, num_cores_per_batch)
        enqueue_data = []
        for i in range(num_cores_per_batch):
          split_weights = weights_split[i] if weights else None
          enqueue_data.append(
              tpu_embedding.EnqueueData.from_sparse_tensor(
                  _maybe_dense_to_sparse(sparse_feature_split[i]),
                  weights=split_weights))
      enqueue_datas[feature_key] = enqueue_data

  # Transpose the enqueue_datas dict into a list of dicts
  enqueue_datas_list = []
  for i in range(num_cores_per_batch):
    enqueue_data = {}
    for key, value in enqueue_datas.items():
      enqueue_data[key] = value[i]
    enqueue_datas_list.append(enqueue_data)
  return features, labels, enqueue_datas_list
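The final loop in this example is just a dict-of-lists to list-of-dicts transpose, so that enqueue_datas_list[i] holds every feature's EnqueueData for core i. A minimal stand-alone illustration of that reshaping, using plain strings in place of EnqueueData objects:

num_cores_per_batch = 2
enqueue_datas = {'feature_a': ['a0', 'a1'], 'feature_b': ['b0', 'b1']}

enqueue_datas_list = [
    {key: value[i] for key, value in enqueue_datas.items()}
    for i in range(num_cores_per_batch)
]
# -> [{'feature_a': 'a0', 'feature_b': 'b0'},
#     {'feature_a': 'a1', 'feature_b': 'b1'}]
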
Example #3
    def CreateTpuEmbeddingEnqueueOps(self):
        """Creates the TpuEmbedding enqueue ops on the host.

        Note that this must be called after the instantiation of the
        monolithic TPUEmbeddingLayer.
        """
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
        tpu_embedding = (tpu_embedding_collection[0]
                         if tpu_embedding_collection else None)

        enqueue_ops = []

        if num_tpu_hosts > 1 and tpu_embedding is not None:
            if not p.use_per_host_infeed:
                tf.logging.fatal(
                    'TPU Embedding must be used with per_host_infeed with multiple '
                    'TPU host topologies.')
        tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                              if tpu_embedding is not None else [])
        tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
        if not tpu_embedding:
            return

        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                if isinstance(self._batch, py_utils.NestedMap):
                    # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                    # Note that when MultiTaskData is used, bucket_keys will be at the
                    # second level of the dictionary.
                    self._batch = self._batch.FilterKeyVal(
                        lambda k, _: not k.endswith('bucket_keys'))
                tf.logging.info('host_device: %s, batch: %r', host_device,
                                self._batch)

                enqueue_dict_per_core = [
                    {} for _ in range(tpu_embedding.num_cores_per_host)
                ]
                num_cores_per_host = tpu_embedding.num_cores_per_host
                for key in tpu_emb_input_keys:
                    feat = self._batch[key]
                    tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
                    for core, split in enumerate(tpu_emb_feat_splitted):
                        # Dense to sparse. Note the assumption of a padding id.
                        sample_indices = tf.where(tf.not_equal(split, -1))
                        embedding_indices = tf.gather_nd(split, sample_indices)
                        enqueue_data = tpu_embedding_lib.EnqueueData(
                            embedding_indices, sample_indices)
                        enqueue_dict_per_core[core][key] = enqueue_data
                enqueue_ops += tpu_embedding.generate_enqueue_ops(
                    enqueue_dict_per_core)
        self._tpu_infeed_op.append(tf.group(*enqueue_ops))
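The 'dense to sparse' step in this example assumes -1 is the padding id: positions equal to -1 are dropped, and the coordinates of the remaining ids become the sample_indices. A minimal sketch of that conversion with public TF ops, on made-up values:

import tensorflow as tf

# One per-core split of shape [batch, max_ids_per_sample], padded with -1.
split = tf.constant([[3, 7, -1],
                     [5, -1, -1]], dtype=tf.int64)

sample_indices = tf.where(tf.not_equal(split, -1))       # [[0, 0], [0, 1], [1, 0]]
embedding_indices = tf.gather_nd(split, sample_indices)  # [3, 7, 5]
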
Example #4
def split_inputs(ctx, features, labels, num_cores_per_batch=1):
    """Splits the dense and sparse tensors inside the features and labels."""
    enqueue_datas = collections.OrderedDict()

    if ctx.embedding_config:
        tpu_embedding_ = ctx.embedding_config.tpu_embedding
        for feature_key in tpu_embedding_.feature_to_config_dict:
            sparse_feature = _get_sparse_feature_from_feature(
                feature_key, features)
            max_sequence_length = tpu_embedding_.feature_to_config_dict[
                feature_key].max_sequence_length
            if max_sequence_length > 0:
                length_feature_name = (
                    tpu_fc.
                    get_sequence_length_feature_key_name_from_feature_key_name(
                        feature_key))
                length_feature = math_ops.minimum(
                    fc_utils.sequence_length_from_sparse_tensor(
                        sparse_feature), max_sequence_length)
                length_feature.set_shape(ctx.batch_size_for_input_fn)
                features[length_feature_name] = length_feature
            weight_key = tpu_embedding_.feature_to_config_dict[
                feature_key].weight_key
            sparse_feature_split = _split_tensor(sparse_feature,
                                                 num_cores_per_batch)
            if isinstance(sparse_feature, sparse_tensor.SparseTensor):
                weights = _get_weights_from_features(weight_key, features)
                weights_split = _split_tensor(weights, num_cores_per_batch)
                enqueue_data = []
                for i in range(num_cores_per_batch):
                    enqueue_data.append(
                        tpu_embedding.EnqueueData.from_sparse_tensor(
                            sparse_feature_split[i], weights_split[i]))
            else:
                if weight_key is not None:
                    raise ValueError(
                        'Found weights {} for weighted_categorical_column, which is not '
                        'compatible with sparse feature {} enqueued as dense tensor.'
                        .format(weight_key, feature_key))
                enqueue_data = []
                for i in range(num_cores_per_batch):
                    enqueue_data.append(
                        tpu_embedding.EnqueueData(sparse_feature_split[i]))
            enqueue_datas[feature_key] = enqueue_data

    # Transpose the enqueue_datas dict into a list of dicts
    enqueue_datas_list = []
    for i in range(num_cores_per_batch):
        enqueue_data = {}
        for key, value in enqueue_datas.items():
            enqueue_data[key] = value[i]
        enqueue_datas_list.append(enqueue_data)
    return features, labels, enqueue_datas_list
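The max_sequence_length branch above clamps the per-example number of ids in the sparse feature; fc_utils.sequence_length_from_sparse_tensor is an internal feature-column helper that computes those lengths. An equivalent sketch with public ops, on a made-up SparseTensor with batch size 2:

import tensorflow as tf

max_sequence_length = 2
sp = tf.SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [1, 0]],
                     values=tf.constant([3, 7, 9, 5], dtype=tf.int64),
                     dense_shape=[2, 3])

row_ids = tf.cast(sp.indices[:, 0], tf.int32)
lengths = tf.math.bincount(row_ids, minlength=2)           # [3, 1]
length_feature = tf.minimum(lengths, max_sequence_length)  # [2, 1]
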
Example #5
def split_inputs(ctx, features, labels):
    """Splits the dense and sparse tensors inside the features and labels."""
    enqueue_datas = collections.OrderedDict()
    if ctx.embedding_config:
        tpu_embedding_ = ctx.embedding_config.tpu_embedding
        feature_to_weight_key_name_dict = (
            ctx.embedding_config.feature_to_weight_key_name_dict)
        for feature_key in tpu_embedding_.feature_to_config_dict:
            sparse_feature = _get_sparse_feature_from_feature(
                feature_key, features)
            max_sequence_length = tpu_embedding_.feature_to_config_dict[
                feature_key].max_sequence_length
            if max_sequence_length > 0:
                length_feature_name = (
                    tpu_fc.
                    get_sequence_length_feature_key_name_from_feature_key_name(
                        feature_key))
                length_feature = math_ops.minimum(
                    fc_utils.sequence_length_from_sparse_tensor(
                        sparse_feature), max_sequence_length)
                length_feature.set_shape(ctx.batch_size_for_input_fn)
                features[length_feature_name] = length_feature
            weight_key_name = feature_to_weight_key_name_dict[feature_key]
            if isinstance(sparse_feature, sparse_tensor.SparseTensor):
                weights = _get_weights_from_features(weight_key_name, features)
                enqueue_data = tpu_embedding.EnqueueData.from_sparse_tensor(
                    sparse_feature, weights)
            else:
                if weight_key_name is not None:
                    raise ValueError(
                        'Found weights {} for weighted_categorical_column, which is not '
                        'compatible with sparse feature {} enqueued as dense tensor.'
                        .format(weight_key_name, feature_key))
                enqueue_data = tpu_embedding.EnqueueData(sparse_feature)
            enqueue_datas[feature_key] = enqueue_data

    return features, labels, enqueue_datas
Example #6
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTPUFeeds num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            tf.logging.info('shards {}'.format(shards))

            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            if num_tpu_hosts > 1 and tpu_embedding is not None:
                if not p.use_per_host_infeed:
                    tf.logging.fatal(
                        'TPU Embedding must be used with per_host_infeed with multiple '
                        'TPU host topologies.')
            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if isinstance(batch, py_utils.NestedMap):
                        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                        # Note that when MultiTaskData is used, bucket_keys will be at the
                        # second level of the dictionary.
                        batch = batch.FilterKeyVal(
                            lambda k, _: not k.endswith('bucket_keys'))
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    if p.use_partitioned_infeed_queue:
                        device_assignment = py_utils.GetTpuDeviceAssignment()

                        host_device = device_assignment.host_device(
                            replica=0, job=tf.flags.FLAGS.tf_master)
                        host_id = int(
                            host_device.split('/task:')[1].split('/device:')
                            [0])
                        tf.logging.info('host_id: {} host_device: {}'.format(
                            host_id, host_device))
                        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                            number_of_tuple_elements=len(dtypes),
                            device_assignment=device_assignment,
                            host_id=host_id,
                            input_partition_dims=[[p.num_partitions, 1]
                                                  for _ in dtypes],
                            tuple_types=dtypes,
                            tuple_shapes=shapes)
                    else:
                        q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                                 tuple_shapes=shapes)
                        assert shards is not None
                        q.set_number_of_shards(shards)

                    queues.append(q)
                    tf.logging.info('q=%r', q)

                    if p.use_partitioned_infeed_queue:
                        input_ops = q.generate_enqueue_ops([batch.Flatten()])
                    elif p.use_per_host_infeed:
                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        self._tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
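py_utils.NestedMap.FilterKeyVal is a lingvo helper; in this example it simply drops every entry whose key ends with 'bucket_keys' before infeed. With a plain dict (hypothetical batch contents for illustration) the same filter looks like this:

batch = {'ids': [1, 2], 'bucket_keys': [8], 'src.bucket_keys': [8]}
batch = {k: v for k, v in batch.items() if not k.endswith('bucket_keys')}
# -> {'ids': [1, 2]}
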
Example #7
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
        tf.logging.info('num_devices_per_split {}'.format(
            cluster.num_devices_per_split))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            input_ops_list = []
            queues = []
            first_batch = None
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            tpu_embedding_input_keys = (
                tpu_embedding.feature_to_config_dict.keys()
                if tpu_embedding is not None else [])

            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    tpu_embedding_features = []
                    for tpu_embedding_input_key in tpu_embedding_input_keys:
                        tpu_embedding_feature = batch.pop(
                            tpu_embedding_input_key)
                        tpu_embedding_features.append(
                            (tpu_embedding_input_key, tpu_embedding_feature))

                    if first_batch is None:
                        first_batch = batch
                    flat_batch = batch.FlattenItems()

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for tpu_embedding_input_key, tpu_embedding_feature in tpu_embedding_features:
                            tpu_embedding_feature_splitted = tf.split(
                                tpu_embedding_feature, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_embedding_feature_splitted):
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    tf.squeeze(split, axis=[1]))
                                enqueue_dict_per_core[core][
                                    tpu_embedding_input_key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    shapes, types = [], []
                    for k, x in flat_batch:
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                        shapes.append(x.shape)
                        types.append(x.dtype)
                    q = tf.contrib.tpu.InfeedQueue(tuple_types=types,
                                                   tuple_shapes=shapes)
                    queues.append(q)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                    if p.use_per_host_infeed:

                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def _tpu_ordinal_function(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            [v for _, v in flat_batch],
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=_tpu_ordinal_function)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            [v for _, v in flat_batch],
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        with tf.device(tf.compat.v1.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return first_batch.Pack(tensors)
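One detail worth flagging in this example: the per-core enqueue dicts must be distinct objects. Building them as [{}] * num_cores would alias a single dict across every core, so the list-comprehension form used above (and in the other examples) is required. A minimal illustration of the difference:

aliased = [{}] * 2
aliased[0]['key'] = 'core0'
print(aliased[1])   # {'key': 'core0'} -- both slots reference the same dict

distinct = [{} for _ in range(2)]
distinct[0]['key'] = 'core0'
print(distinct[1])  # {} -- independent dicts
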
Example #8
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
        tf.logging.info('num_devices_per_split {}'.format(
            cluster.num_devices_per_split))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if 'bucket_keys' in batch:
                        # Hack: bucket_keys are not needed on TPU.
                        del batch['bucket_keys']
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                             tuple_shapes=shapes)
                    queues.append(q)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                    if p.use_per_host_infeed:

                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        # For executor-driven multiple programs, we need more fine-grained
        # access rather than using a single global graph collection.
        self.tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
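All of the embedding examples above rely on tf.split to slice a per-host feature batch evenly across that host's cores before building the per-core EnqueueData dicts. A minimal sketch of that step, with toy shapes and assuming the batch dimension divides evenly by the core count:

import tensorflow as tf

num_cores_per_host = 2
feat = tf.constant([[3], [7], [5], [9]], dtype=tf.int64)  # [host_batch, 1]

splits = tf.split(feat, num_cores_per_host)  # two tensors of shape [2, 1]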