        def get_enqueue_ops_fn(host_id):
            """Generate the enqueue ops graph function."""
            with tf.device(
                    low_level_utils.device_for_host(self._get_host(host_id))):
                dataset = input_fn(params, config)
                iterator = dataset.make_initializable_iterator()
                dataset_initializer.append(iterator.initializer)

                def enqueue_ops_fn():
                    """Enqueue ops function for one host."""
                    per_host_sharded_inputs = []
                    control_deps = []
                    # Chain each shard's get_next() behind the tensors already
                    # flattened for earlier shards so the per-shard dataset
                    # reads stay in shard order.
                    for _ in range(FLAGS.tpu_num_shards_per_host):
                        if "eval" in task:
                            with tf.control_dependencies(control_deps):
                                features = iterator.get_next()
                            feature_structure["features"] = features
                        else:
                            with tf.control_dependencies(control_deps):
                                features, labels = iterator.get_next()
                            feature_structure["features"] = features
                            feature_structure["labels"] = labels
                        flattened_inputs = data_nest.flatten(feature_structure)
                        control_deps.extend(flattened_inputs)
                        per_host_sharded_inputs.append(flattened_inputs)

                    infeed = tpu_feed.InfeedQueue(number_of_tuple_elements=len(
                        per_host_sharded_inputs[0]))
                    infeed_queue.append(infeed)
                    return infeed.generate_enqueue_ops(
                        per_host_sharded_inputs,
                        tpu_ordinal_function=low_level_utils.tpu_ordinal_fn)

                return enqueue_ops_fn
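A hedged usage sketch (not part of the original example) of how the per-host factory above is typically consumed: one enqueue_ops_fn is built per host, and the resulting ops are grouped so a single session.run() feeds every host's infeed queue for one step. The helper name build_enqueue_ops and the names num_hosts and sess are assumptions about the surrounding runner.

def build_enqueue_ops(num_hosts):
  """Sketch (hypothetical helper): collect one list of enqueue ops per host."""
  enqueue_ops = []
  for host_id in range(num_hosts):
    enqueue_ops_fn = get_enqueue_ops_fn(host_id)
    enqueue_ops.append(enqueue_ops_fn())  # builds this host's enqueue ops
  return enqueue_ops

# Typical use (sess is an assumed tf.Session):
#   sess.run(dataset_initializer)   # iterator initializers collected above
#   sess.run(build_enqueue_ops(num_hosts))  # one step of infeed for all hosts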
    def get_enqueue_ops_fn():
      """Generate the enqueue ops graph function."""

      with tf.device(low_level_utils.device_for_host(self._get_host(0))):
        dataset = input_fn(params, config)
        iterator = dataset.make_initializable_iterator()
        self.dataset_initializer.append(iterator.initializer)

        def enqueue_ops_fn():
          """Enqueue ops function for one host."""
          per_host_sharded_inputs = []
          control_deps = []
          # Currently working only on a donut, change this later to support
          # distributed eval.
          for _ in range(FLAGS.tpu_num_shards_per_host):
            with tf.control_dependencies(control_deps):
              features = iterator.get_next()
            self.feature_structure["features"] = features
            flattened_inputs = data_nest.flatten(self.feature_structure)
            control_deps.extend(flattened_inputs)
            per_host_sharded_inputs.append(flattened_inputs)

          infeed = tpu.InfeedQueue(
              number_of_tuple_elements=len(per_host_sharded_inputs[0]))
          self.infeed_queue.append(infeed)
          return infeed.generate_enqueue_ops(
              per_host_sharded_inputs,
              tpu_ordinal_function=low_level_utils.tpu_ordinal_fn)

        return enqueue_ops_fn
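Another hedged sketch: in low-level TPU runners the enqueue ops built above are usually driven from a background thread, so the host keeps the infeed queue full while the main thread launches the TPU computation. The names sess, enqueue_ops and train_steps are assumptions, not part of the original example.

import threading

def start_infeed_thread(sess, enqueue_ops, train_steps):
  """Sketch: push one batch onto the infeed queue per training step."""
  def infeed_loop():
    for _ in range(train_steps):
      sess.run(enqueue_ops)
  thread = threading.Thread(target=infeed_loop)
  thread.start()
  return thread  # join() it after the TPU loop on the main thread finishes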
Example #3
 def create_dequeue_ops(host_id):
   """Create outfeed dequeue ops."""
   dequeue_ops = []
   tensor_dtypes = []
   tensor_shapes = []
   for v in self.outfeed_tensors:
     dequeue_ops.append([])
     tensor_dtypes.append(v.dtype)
     tensor_shapes.append(v.shape)
   for i in range(FLAGS.tpu_num_shards_per_host):
     with tf.device(
         low_level_utils.device_for_host(self._get_host(host_id))):
       outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(
           dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=i)
       for j, item in enumerate(outfeed_tensors):
         dequeue_ops[j].append(item)
   for j in range(len(dequeue_ops)):
     dequeue_ops[j] = tf.concat(dequeue_ops[j], axis=0)
   return dequeue_ops
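A hedged sketch of how the per-host dequeue factory above could be fanned out across hosts for distributed eval; num_hosts is an assumed attribute of the surrounding runner, and the merging step is optional.

all_dequeue_ops = []
for host_id in range(num_hosts):
  # Each entry holds one concatenated tensor per outfeed element for that host.
  all_dequeue_ops.append(create_dequeue_ops(host_id))
# Optionally merge across hosts so each outfeed element becomes a single
# tensor spanning every shard on every host.
merged_dequeue_ops = [
    tf.concat([host_ops[j] for host_ops in all_dequeue_ops], axis=0)
    for j in range(len(all_dequeue_ops[0]))
]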
 def create_dequeue_ops():
   """Create outfeed dequeue ops."""
   dequeue_ops = []
   tensor_dtypes = []
   tensor_shapes = []
   for v in self.outfeed_tensors:
     dequeue_ops.append([])
     tensor_dtypes.append(v.dtype)
     tensor_shapes.append(v.shape)
   # Currently working only on a donut, change this later to support
   # distributed eval.
   for i in range(FLAGS.tpu_num_shards_per_host):
     with tf.device(low_level_utils.device_for_host(self._get_host(0))):
       outfeed_tensors = tpu.outfeed_dequeue_tuple(
           dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=i)
       for j, item in enumerate(outfeed_tensors):
         dequeue_ops[j].append(item)
   for j in range(len(dequeue_ops)):
     dequeue_ops[j] = tf.concat(dequeue_ops[j], axis=0)
   return dequeue_ops
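Finally, a hedged sketch of consuming the dequeue ops: each session.run() returns one step's worth of outfeed values, already concatenated across the shards on the host and ordered like self.outfeed_tensors. The names sess and eval_steps are assumptions.

dequeue_ops = create_dequeue_ops()
step_outputs = []
for _ in range(eval_steps):  # eval_steps: assumed number of outfeed steps
  # Each fetched element lines up with the corresponding self.outfeed_tensors
  # entry, with the shard dimension concatenated along axis 0.
  step_outputs.append(sess.run(dequeue_ops))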