def enqueue_ops_fn(idx):
                """Enqueue ops function for one host.

                Builds one flattened {"features", "core_id"} input per shard on
                this host and returns the infeed enqueue ops for them. A host whose
                `host_id` is an even number below `2 * num_infeed_workers` puts
                real (or synthetic) features on its first shard; every other shard
                gets all-zero features shaped like
                `self.feature_structure["features"]`. Each shard's `core_id` is
                `host_id * num_shards_per_host + shard_index`.

                Args:
                  idx: Selector for which feeding host reads real data
                    (presumably a step counter -- confirm with caller). With
                    non-synthetic data, the feeding host reads from `iterator`
                    only when `idx % num_infeed_workers == host_id // 2` and
                    feeds zeros otherwise.

                Returns:
                  The ops from `InfeedQueue.generate_enqueue_ops`. The new
                  queue is also appended to `self.infeed_queue`.
                """
                num_shards = self.hparams.num_shards_per_host
                feature_spec = self.feature_structure["features"]

                def zero_features():
                    # All-zero stand-ins with the same structure as real features.
                    return {k: tf.zeros_like(feature_spec[k]) for k in feature_spec}

                with tf.device(device):
                    sharded_inputs = []
                    start_idx = 0
                    # Even host ids below 2 * num_infeed_workers are the hosts that
                    # feed real data; host_id // 2 is their worker index.
                    if host_id in range(0, self.hparams.num_infeed_workers * 2,
                                        2):
                        core_id = tf.constant(host_id * num_shards,
                                              shape=[1],
                                              dtype=tf.int32)
                        if self.hparams.use_synthetic_data:
                            features = output
                        else:
                            # Only the worker selected by idx pulls from the
                            # iterator; the others feed zeros for this call.
                            features = tf.cond(
                                tf.equal(idx % self.hparams.num_infeed_workers,
                                         host_id // 2),
                                lambda: iterator.get_next(), zero_features)
                        sharded_inputs.append(
                            data_nest.flatten({
                                "features": features,
                                "core_id": core_id
                            }))
                        start_idx = 1
                    # The remaining shards on this host carry no real data.
                    for i in range(start_idx, num_shards):
                        sharded_inputs.append(
                            data_nest.flatten({
                                "features": zero_features(),
                                "core_id": tf.constant(host_id * num_shards + i,
                                                       shape=[1],
                                                       dtype=tf.int32)
                            }))
                infeed = tpu_feed.InfeedQueue(
                    number_of_tuple_elements=len(sharded_inputs[0]))
                self.infeed_queue.append(infeed)

                def tpu_ordinal_fn(shard_index_in_host):
                    return shard_index_in_host % num_shards

                return infeed.generate_enqueue_ops(
                    sharded_inputs, tpu_ordinal_function=tpu_ordinal_fn)
    def CreatePerHostEnqueueOp(self, task_id):
        """Create TPU infeed enqueue tuple ops specific to a host.

        Note it's important this is a function to give the caller flexibility
        as to which context it is instantiated in.
        This is critical for tf.while_loop driven infeed for instance.

        Args:
          task_id: Which infeed host it is.

        Returns:
          List of enqueue ops for this host.

        Side effects:
          Sets `self._batch` to the preprocessed input batch for `task_id`
          (with `*bucket_keys` entries removed when it is a NestedMap) and
          appends the new `tpu_feed.InfeedQueue` to `self._tpu_queues`.
        """
        p = self.params
        # Only support this in MLPerf.
        assert p.use_per_host_infeed, (
            'CreatePerHostEnqueueOp requires params.use_per_host_infeed.')
        host_device = '/task:{}/device:CPU:0'.format(task_id)

        self._batch = self.GetPreprocessedInputBatch(task_id=task_id)
        if isinstance(self._batch, py_utils.NestedMap):
            # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
            # Note that when MultiTaskData is used, bucket_keys will be at the
            # second level of the dictionary.
            self._batch = self._batch.FilterKeyVal(
                lambda k, _: not k.endswith('bucket_keys'))
        tf.logging.info('host_device: %s, batch: %r', host_device, self._batch)

        # Infeed needs static shapes for every tensor.
        # TODO(cwhipkey): if it's a string (or other type not supported on
        # TPU), drop it from feeding and on the other end add in an op that
        # fails if used.
        for k, x in self._batch.FlattenItems():
            assert x.shape.is_fully_defined(), (
                'Shape must be fully defined: %s: %s' % (k, x))
        shapes = self._batch.Transform(lambda x: x.shape).Flatten()
        dtypes = self._batch.Transform(lambda x: x.dtype).Flatten()

        q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
        assert self.shards is not None, (
            'self.shards must be set before creating per-host enqueue ops.')
        q.set_number_of_shards(self.shards)
        self._tpu_queues.append(q)

        def TPUOrdinalFunction(shard_index_in_host):
            """Maps a host-local shard index to the TPU ordinal to enqueue on."""
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
                # We put both enqueue/dequeue ops at core 0 in each replica.
                replica = device_assignment.lookup_replicas(
                    task_id, 0)[shard_index_in_host]
                return device_assignment.tpu_ordinal(replica=replica)
            else:
                # Without a device assignment the shard index is the ordinal.
                return shard_index_in_host

        tf.logging.info('CreatePerHostEnqueueOp: %d', task_id)
        input_ops = q.split_inputs_and_generate_enqueue_ops(
            self._batch.Flatten(),
            # Every split input is placed on this host's CPU.
            placement_function=lambda _: host_device,
            tpu_ordinal_function=TPUOrdinalFunction)
        return input_ops
                def enqueue_ops_fn():
                    """Build the eval infeed enqueue ops for a single host.

                    Pulls one element from `iterator` per shard on this host.
                    Each pull is sequenced (via control dependencies) after all
                    tensors produced by earlier pulls. Each element is stored
                    under `self.eval_feature_structure["features"]`, and the
                    whole structure is flattened into that shard's input tuple.
                    A new `tpu_feed.InfeedQueue` is appended to
                    `self.eval_infeed_queue`.

                    Returns:
                      The ops returned by `InfeedQueue.generate_enqueue_ops`.
                    """
                    shard_inputs = []
                    prior_tensors = []
                    for _ in range(self.hparams.num_shards_per_host):
                        # Order this get_next() after everything pulled so far.
                        with tf.control_dependencies(prior_tensors):
                            batch = iterator.get_next()
                        self.eval_feature_structure["features"] = batch
                        flat = data_nest.flatten(self.eval_feature_structure)
                        prior_tensors.extend(flat)
                        shard_inputs.append(flat)

                    queue = tpu_feed.InfeedQueue(
                        number_of_tuple_elements=len(shard_inputs[0]))
                    self.eval_infeed_queue.append(queue)

                    # Shard index within the host maps onto a TPU ordinal.
                    return queue.generate_enqueue_ops(
                        shard_inputs,
                        tpu_ordinal_function=lambda shard: (
                            shard % self.hparams.num_shards_per_host))
Example #4
0
                def enqueue_ops_fn(idx):
                    """Generate the infeed enqueue ops graph.

                    For each of `FLAGS.replicas_per_host` replicas on this host:
                    - Pull the next element from `iterator`. The pull is
                      sequenced via control dependencies after the tensors of
                      earlier replicas.
                    - Store it in `self.feature_structure[is_training]` and
                      offer it to `self.maybe_capture_embedding_inputs`.
                    - Flatten it into that replica's input tuple.

                    When `input_partition_dims` is set:
                    - Each input with at least `len(input_partition_dims)` dims
                      is zero-padded at the end of its leading partitioned dims
                      to a multiple of the partition count.
                    - The inputs are fed through a
                      `tpu_feed._PartitionedInfeedQueue`.

                    Otherwise a plain `tpu_feed.InfeedQueue` is used, and its
                    ops are passed to `self.maybe_add_embedding_enqueue_ops_int`.

                    Args:
                      idx: Counter returned incremented by one (presumably a
                        tf.while_loop loop variable -- confirm with caller).

                    Returns:
                      `idx + 1`, created under a control dependency on the
                      enqueue ops.
                    """

                    def _pad_to_partitions(inp):
                        # Zero-pads each partitioned dim of `inp` up to a
                        # multiple of its partition count; inputs of lower rank
                        # than the partition spec are passed through untouched.
                        if inp.shape.ndims < len(input_partition_dims):
                            return inp
                        dims = inp.shape.as_list()
                        paddings = []
                        for axis, parts in enumerate(input_partition_dims):
                            remainder = dims[axis] % parts
                            paddings.append(
                                [0, parts - remainder if remainder > 0 else 0])
                        # Dims beyond the partitioned ones are never padded.
                        paddings.extend(
                            [0, 0] for _ in range(inp.shape.ndims -
                                                  len(input_partition_dims)))
                        return tf.pad(inp, paddings)

                    def _partition_dims_for(inp):
                        # Per-input partition spec for _PartitionedInfeedQueue:
                        # extend with 1s for extra dims, all 1s for lower rank.
                        ndims = inp.shape.ndims
                        num_part_dims = len(input_partition_dims)
                        if ndims == num_part_dims:
                            return input_partition_dims
                        if ndims > num_part_dims:
                            return (list(input_partition_dims) +
                                    [1] * (ndims - num_part_dims))
                        return [1] * ndims

                    per_host_sharded_inputs = []
                    control_deps = []
                    for _ in range(FLAGS.replicas_per_host):
                        # Order this get_next() after everything pulled so far.
                        with tf.control_dependencies(control_deps):
                            self.feature_structure[
                                is_training] = iterator.get_next()
                        self.maybe_capture_embedding_inputs(
                            self.feature_structure[is_training], is_training)
                        flattened_inputs = tf.nest.flatten(
                            self.feature_structure[is_training])
                        control_deps.extend(flattened_inputs)
                        if input_partition_dims:
                            per_host_sharded_inputs.append(
                                [_pad_to_partitions(inp)
                                 for inp in flattened_inputs])
                        else:
                            per_host_sharded_inputs.append(flattened_inputs)

                    if input_partition_dims:
                        flattened_input_dims = [
                            _partition_dims_for(inp)
                            for inp in per_host_sharded_inputs[0]
                        ]
                        # NOTE(review): unlike the unpartitioned path below, this
                        # path never calls maybe_add_embedding_enqueue_ops_int;
                        # confirm that skipping embedding enqueue is intended.
                        # pylint: disable=protected-access
                        self.infeed_op[
                            is_training] = tpu_feed._PartitionedInfeedQueue(
                                number_of_tuple_elements=len(
                                    per_host_sharded_inputs[0]),
                                host_id=host_id,
                                input_partition_dims=flattened_input_dims,
                                device_assignment=self.device_assignment)
                        # pylint: enable=protected-access
                        with tf.control_dependencies(
                                self.infeed_op[is_training].
                                generate_enqueue_ops(per_host_sharded_inputs)):
                            return idx + 1

                    self.infeed_op[is_training] = tpu_feed.InfeedQueue(
                        number_of_tuple_elements=len(per_host_sharded_inputs[0]))
                    per_host_enqueue_ops = (
                        self.infeed_op[is_training].generate_enqueue_ops(
                            per_host_sharded_inputs,
                            tpu_ordinal_function=_tpu_ordinal_fn))
                    self.maybe_add_embedding_enqueue_ops_int(
                        is_training, per_host_enqueue_ops)
                    with tf.control_dependencies(per_host_enqueue_ops):
                        return idx + 1