Code example #1
def create_dequeue_ops(host_id):
    """Create outfeed dequeue ops."""
    dequeue_ops = []
    tensor_dtypes = []
    tensor_shapes = []
    # Record the dtype/shape signature of the outfeed tuple and prepare one
    # result list per outfeed tensor.
    for v in self.outfeed_tensors:
        dequeue_ops.append([])
        tensor_dtypes.append(v.dtype)
        tensor_shapes.append(v.shape)
    # Dequeue the outfeed tuple from every replica that lives on this host.
    for i in range(self.eval_params["replicas_per_worker"]):
        with tf.device(
                runner_utils.device_for_host(self._get_host(host_id))):
            if self.use_spatial_partition:
                # With spatial partitioning, map the (host, replica) pair to
                # the device ordinal of the replica's logical core 0.
                replica_id = self.device_assignment.lookup_replicas(
                    host_id, 0)[i]
                ordinal = self.device_assignment.tpu_ordinal(
                    replica=replica_id, logical_core=0)
            else:
                ordinal = i
            outfeed_tensors = tf.contrib.tpu.outfeed_dequeue_tuple(
                dtypes=tensor_dtypes,
                shapes=tensor_shapes,
                device_ordinal=ordinal)
            for j, item in enumerate(outfeed_tensors):
                dequeue_ops[j].append(item)
    # Concatenate each tensor's per-replica pieces along the batch dimension.
    for j in range(len(outfeed_tensors)):
        dequeue_ops[j] = tf.concat(dequeue_ops[j], axis=0)
    return dequeue_ops
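
A minimal consumption sketch for the dequeue ops above, not part of the original snippet: `num_hosts`, `eval_steps`, and `master` are illustrative assumptions, and the closure is called here as if it were available outside its enclosing runner method.

import tensorflow as tf

# num_hosts, eval_steps and master are assumed to come from the enclosing runner.
# Build one list of concatenated outfeed tensors per host.
all_dequeue_ops = [create_dequeue_ops(host_id) for host_id in range(num_hosts)]
with tf.Session(master) as sess:
  for _ in range(eval_steps):
    for host_dequeue_ops in all_dequeue_ops:
      # Blocks until the TPU program has pushed this step's outfeed data.
      outputs = sess.run(host_dequeue_ops)
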
Code example #2
    def get_enqueue_ops_fn():
      """Generate the enqueue ops graph function."""

      with tf.device(runner_utils.device_for_host(self._get_host(0))):
        dataset = input_fn(params)
        iterator = dataset.make_initializable_iterator()
        self.dataset_initializer.append(iterator.initializer)

        def enqueue_ops_fn():
          """Enqueue ops function for one host."""
          per_worker_sharded_inputs = []
          control_deps = []
          for _ in range(self.replicas_per_worker):
            # Chain control dependencies so the per-replica get_next() calls
            # are issued in a fixed order on the host.
            with tf.control_dependencies(control_deps):
              features = iterator.get_next()
            flattened_inputs = self.input_flattener.flatten_features_and_labels(
                features, None)
            control_deps.extend(flattened_inputs)
            per_worker_sharded_inputs.append(flattened_inputs)

          infeed = tf.contrib.tpu.InfeedQueue(
              number_of_tuple_elements=len(per_worker_sharded_inputs[0]))
          self.infeed_queue.append(infeed)
          return infeed.generate_enqueue_ops(
              per_worker_sharded_inputs,
              tpu_ordinal_function=functools.partial(
                  runner_utils.tpu_ordinal_fn,
                  replicas_per_worker=self.replicas_per_worker))

        return enqueue_ops_fn
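
A minimal driver sketch for the enqueue side, not part of the original snippet: `master` and `train_steps` are illustrative assumptions, and `self.dataset_initializer` is used as if the surrounding runner object were in scope. The enqueue ops are built once, the iterator initializer is run, and the ops are then fed step by step from a background thread while the TPU program executes.

import threading

import tensorflow as tf

# master, train_steps and the surrounding runner object are assumptions.
enqueue_ops = get_enqueue_ops_fn()()  # build the per-replica enqueue ops once
with tf.Session(master) as sess:
  sess.run(self.dataset_initializer)  # initialize the host-side input pipeline

  def infeed_loop():
    # Push one batch per replica per step, concurrently with the TPU program.
    for _ in range(train_steps):
      sess.run(enqueue_ops)

  threading.Thread(target=infeed_loop).start()
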
Code example #3
        def get_enqueue_ops_fn(host_id):
            """Generate the enqueue ops graph function for training."""
            #  TODO(b/129084726): make dataset sharding also work for TPU Estimator.
            params["dataset_num_shards"] = num_hosts
            params["dataset_shard_id"] = host_id
            with tf.device(
                    runner_utils.device_for_host(self._get_host(host_id))):
                dataset = input_fn(params)
                iterator = dataset.make_initializable_iterator()
                if is_training:
                    self.dataset_initializer.append(iterator.initializer)
                else:
                    self.eval_dataset_initializer.append(iterator.initializer)

                def enqueue_ops_fn():
                    """Enqueue ops function for one host."""
                    per_host_sharded_inputs = []
                    control_deps = []
                    for _ in range(self.train_params["replicas_per_worker"]):
                        with tf.control_dependencies(control_deps):
                            features, labels = iterator.get_next()
                        if self.use_spatial_partition:
                            self.input_dims_flattener.validate_and_flatten_input_dims(
                                features, labels)
                        flattened_inputs = (
                            self.input_flattener.flatten_features_and_labels(
                                features, labels))
                        control_deps.extend(flattened_inputs)
                        per_host_sharded_inputs.append(flattened_inputs)

                    if self.use_spatial_partition:
                        flattened_input_dims = (
                            self.input_dims_flattener.flattened_input_dims)
                        # Spatially partitioned inputs go through the
                        # partitioned infeed queue, which splits each input
                        # across the logical cores of a replica.
                        # pylint: disable=protected-access
                        infeed = tpu_feed._PartitionedInfeedQueue(
                            number_of_tuple_elements=len(
                                per_host_sharded_inputs[0]),
                            host_id=host_id,
                            input_partition_dims=flattened_input_dims,
                            device_assignment=self.device_assignment)
                        self.infeed_queue.append(infeed)
                        return infeed.generate_enqueue_ops(
                            per_host_sharded_inputs)

                    infeed = tf.contrib.tpu.InfeedQueue(
                        number_of_tuple_elements=len(
                            per_host_sharded_inputs[0]))
                    self.infeed_queue.append(infeed)
                    return infeed.generate_enqueue_ops(
                        per_host_sharded_inputs,
                        tpu_ordinal_function=functools.partial(
                            runner_utils.tpu_ordinal_fn,
                            replicas_per_worker=self.train_params[
                                "replicas_per_worker"]))

                return enqueue_ops_fn
Code example #4
def create_dequeue_ops():
  """Create outfeed dequeue ops."""
  dequeue_ops = []
  tensor_dtypes = []
  tensor_shapes = []
  for v in self.outfeed_tensors:
    dequeue_ops.append([])
    tensor_dtypes.append(v.dtype)
    tensor_shapes.append(v.shape)
  # Single-host variant: dequeue every shard's outfeed tuple from host 0,
  # using the shard index directly as the device ordinal.
  for i in range(params['num_shards']):
    with tf.device(runner_utils.device_for_host(
        self._get_host(0))):
      outfeed_tensors = tf.contrib.tpu.outfeed_dequeue_tuple(
          dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=i)
      for j, item in enumerate(outfeed_tensors):
        dequeue_ops[j].append(item)
  # Concatenate per-shard pieces for each tensor along the batch dimension.
  for j in range(len(outfeed_tensors)):
    dequeue_ops[j] = tf.concat(dequeue_ops[j], axis=0)
  return dequeue_ops