Exemple #1
0
    def get_enqueue_ops_fn():
      """Create the input iterator and return a builder for the enqueue ops.

      The dataset produced by `input_fn(params)` is created under the device
      that `utils.device_for_host(self._get_host(0))` names, and the
      iterator's initializer is recorded in `self.dataset_initializer`.

      Returns:
        A zero-argument callable that, when invoked, builds the infeed
        enqueue ops for `FLAGS.num_shards_per_host` shards.
      """
      host_device = utils.device_for_host(self._get_host(0))
      with tf.device(host_device):
        data_iter = input_fn(params).make_initializable_iterator()
        self.dataset_initializer.append(data_iter.initializer)

        def enqueue_ops_fn():
          """Read one batch per shard and hand the batches to an InfeedQueue.

          Side effects: overwrites `self.feature_structure["features"]` for
          every shard and appends the new queue to `self.infeed_queue`.

          Returns:
            Whatever `InfeedQueue.generate_enqueue_ops` returns for the
            collected per-shard inputs.
          """
          shard_inputs = []
          # Tensors of all earlier shards. Each new `get_next()` is created
          # under a control dependency on them, which orders the reads.
          earlier_tensors = []
          for _ in range(FLAGS.num_shards_per_host):
            with tf.control_dependencies(earlier_tensors):
              batch = data_iter.get_next()
            self.feature_structure["features"] = batch
            flat_batch = data_nest.flatten(self.feature_structure)
            earlier_tensors += flat_batch
            shard_inputs.append(flat_batch)

          # Every shard flattens the same structure, so the first shard
          # gives the tuple width.
          tuple_size = len(shard_inputs[0])
          queue = tpu.InfeedQueue(number_of_tuple_elements=tuple_size)
          self.infeed_queue.append(queue)
          return queue.generate_enqueue_ops(
              shard_inputs, tpu_ordinal_function=utils.tpu_ordinal_fn)

      return enqueue_ops_fn
 def create_dequeue_ops(host_id):
     """Build the outfeed dequeue ops for every shard of one host.

     One `outfeed_dequeue_tuple` is created per shard, using the dtypes and
     shapes of `self.outfeed_tensors`. When a tuple has exactly two tensors,
     the rank-3 one is treated as the detections and the other as `is_pad`;
     the detections are then cut down to their first
     `rows(is_pad) - sum(is_pad)` rows. Any other tuple is kept as returned.

     Args:
         host_id: Host index passed to `self._get_host` for device placement.

     Returns:
         The per-shard results concatenated along axis 0.
     """
     out_dtypes = [t.dtype for t in self.outfeed_tensors]
     out_shapes = [t.shape for t in self.outfeed_tensors]
     per_shard = []
     with tf.device(utils.device_for_host(self._get_host(host_id))):
         for ordinal in range(FLAGS.num_shards_per_host):
             outfeed = tpu.outfeed_dequeue_tuple(
                 dtypes=out_dtypes,
                 shapes=out_shapes,
                 device_ordinal=ordinal)
             if len(outfeed) != 2:
                 # Not the (detections, is_pad) pair: keep the tuple as is.
                 per_shard.append(outfeed)
                 continue
             # The pair may arrive in either order; rank 3 marks detections.
             if outfeed[0].shape.ndims == 3:
                 detections, is_pad = outfeed
             else:
                 is_pad, detections = outfeed
             total_rows = tf.shape(is_pad)[0]
             padded_rows = tf.reduce_sum(tf.cast(is_pad, tf.int32))
             num_non_pad = total_rows - padded_rows
             # Keeps only the leading rows -- presumably padding sits at the
             # end of the batch; confirm against the producer.
             kept = tf.slice(detections, [0, 0, 0], [num_non_pad, -1, -1])
             per_shard.append(kept)
         return tf.concat(per_shard, axis=0)
        def get_eval_enqueue_ops_fn(host_id):
            """Create one host's eval iterator and return its enqueue-op builder.

            Writes `dataset_num_shards` (from `self.num_hosts`) and
            `dataset_index` (from `host_id`) into `params` before calling
            `input_fn`, and records the iterator initializer in
            `self.eval_dataset_initializer`.

            Args:
                host_id: Host index; used as the dataset index, for device
                    placement and for the partitioned infeed queue.

            Returns:
                A zero-argument callable that builds the eval infeed enqueue
                ops for `self.replicas_per_worker` replicas.
            """
            params["dataset_num_shards"] = self.num_hosts
            params["dataset_index"] = host_id
            host_device = utils.device_for_host(self._get_host(host_id))
            with tf.device(host_device):
                eval_iter = input_fn(params).make_initializable_iterator()
                self.eval_dataset_initializer.append(eval_iter.initializer)

                def eval_enqueue_ops_fn():
                    """Read one batch per replica and enqueue all of them.

                    Side effects: overwrites
                    `self.eval_feature_structure["features"]` for every replica
                    and appends the created queue to `self.eval_infeed_queue`.

                    Returns:
                        The enqueue ops produced by the infeed queue.
                    """
                    replica_inputs = []
                    # Tensors of earlier replicas; each new read depends on
                    # them so the `get_next()` calls are ordered.
                    earlier_tensors = []
                    for _ in range(self.replicas_per_worker):
                        with tf.control_dependencies(earlier_tensors):
                            batch = eval_iter.get_next()
                        if self.use_spatial_partition:
                            flattener = self.eval_input_dims_flattener
                            flattener.validate_and_flatten_input_dims(
                                batch, None)
                        self.eval_feature_structure["features"] = batch
                        flat_batch = data_nest.flatten(
                            self.eval_feature_structure)
                        earlier_tensors += flat_batch
                        replica_inputs.append(flat_batch)

                    tuple_size = len(replica_inputs[0])
                    if self.use_spatial_partition:
                        partition_dims = (
                            self.eval_input_dims_flattener.flattened_input_dims)
                        # pylint: disable=protected-access
                        queue = tpu_feed._PartitionedInfeedQueue(
                            number_of_tuple_elements=tuple_size,
                            host_id=host_id,
                            input_partition_dims=partition_dims,
                            device_assignment=self.device_assignment)
                        enqueue_kwargs = {}
                    else:
                        queue = tpu_feed.InfeedQueue(
                            number_of_tuple_elements=tuple_size)
                        # Only the plain queue is given an ordinal function.
                        enqueue_kwargs = {
                            "tpu_ordinal_function": utils.tpu_ordinal_fn
                        }
                    self.eval_infeed_queue.append(queue)
                    return queue.generate_enqueue_ops(
                        replica_inputs, **enqueue_kwargs)

            return eval_enqueue_ops_fn
Exemple #4
0
 def create_dequeue_ops():
   """Build outfeed dequeue ops for all shards, grouped per outfeed tensor.

   One `outfeed_dequeue_tuple` is created per shard (device ordinals
   `0 .. FLAGS.num_shards - 1`) under the device that
   `utils.device_for_host(self._get_host(0))` names. The i-th tensor of
   every shard's tuple is collected into the i-th bucket, and each bucket
   is then concatenated along axis 0.

   Returns:
     A list with one concatenated tensor per entry of
     `self.outfeed_tensors`, in the same order.
   """
   tensor_dtypes = [v.dtype for v in self.outfeed_tensors]
   tensor_shapes = [v.shape for v in self.outfeed_tensors]
   # One bucket per outfeed tensor; each gathers that tensor from all shards.
   per_tensor_shards = [[] for _ in tensor_dtypes]
   # The device is the same for every shard, so enter its scope only once.
   with tf.device(utils.device_for_host(self._get_host(0))):
     for ordinal in range(FLAGS.num_shards):
       shard_tensors = tpu.outfeed_dequeue_tuple(
           dtypes=tensor_dtypes,
           shapes=tensor_shapes,
           device_ordinal=ordinal)
       for bucket, item in zip(per_tensor_shards, shard_tensors):
         bucket.append(item)
   # Iterate the buckets themselves rather than a variable bound inside the
   # shard loop, which would be undefined if the loop never ran.
   return [tf.concat(bucket, axis=0) for bucket in per_tensor_shards]
 def create_dequeue_ops(host_id):
     """Build the outfeed dequeue ops for every replica of one host.

     One `outfeed_dequeue_tuple` is created per replica, using the dtypes
     and shapes of `self.outfeed_tensors`. With spatial partitioning the
     device ordinal is looked up through `self.device_assignment`;
     otherwise the loop index is used directly.

     Args:
         host_id: Host index used for device placement and for the replica
             lookup.

     Returns:
         The per-replica results concatenated along axis 0.
     """
     out_dtypes = [t.dtype for t in self.outfeed_tensors]
     out_shapes = [t.shape for t in self.outfeed_tensors]
     per_replica = []
     with tf.device(utils.device_for_host(self._get_host(host_id))):
         for index in range(self.replicas_per_worker):
             ordinal = index
             if self.use_spatial_partition:
                 # Map the host-local replica to the TPU ordinal of its
                 # logical core 0.
                 replica_id = self.device_assignment.lookup_replicas(
                     host_id, 0)[index]
                 ordinal = self.device_assignment.tpu_ordinal(
                     replica=replica_id, logical_core=0)
             outfeed = tpu_ops.outfeed_dequeue_tuple(
                 dtypes=out_dtypes,
                 shapes=out_shapes,
                 device_ordinal=ordinal)
             if len(outfeed) != 2:
                 # No padding: only detections are in the outfeed.
                 per_replica.append(outfeed)
                 continue
             # Two outfeed tensors, in either order:
             #   is_pad: [batch]
             #   detections: [batch, 200, 7]
             if outfeed[0].shape.ndims == 3:
                 detections, is_pad = outfeed
             else:
                 is_pad, detections = outfeed
             total_rows = tf.shape(is_pad)[0]
             padded_rows = tf.reduce_sum(tf.cast(is_pad, tf.int32))
             num_non_pad = total_rows - padded_rows
             # Keep only the leading `num_non_pad` rows of the detections.
             kept = tf.slice(detections, [0, 0, 0], [num_non_pad, -1, -1])
             per_replica.append(kept)
         return tf.concat(per_replica, axis=0)
Exemple #6
0
        def get_enqueue_ops_fn(host_id):
            """Create one host's input iterator and return its enqueue-op builder.

            Writes `dataset_num_shards` (from `self.num_hosts`) and
            `dataset_index` (from `host_id`) into `params` before calling
            `input_fn`, and records the iterator initializer in
            `self.dataset_initializer`.

            Args:
                host_id: Host index; used as the dataset index, for device
                    placement and for the partitioned infeed queue.

            Returns:
                A zero-argument callable that builds the infeed enqueue ops
                for `self.replicas_per_worker` replicas.
            """
            params["dataset_num_shards"] = self.num_hosts
            params["dataset_index"] = host_id
            host_device = utils.device_for_host(self._get_host(host_id))
            with tf.device(host_device):
                data_iter = input_fn(params).make_initializable_iterator()
                self.dataset_initializer.append(data_iter.initializer)

                def enqueue_ops_fn():
                    """Read one (features, labels) batch per replica and enqueue.

                    Side effects: overwrites the "features" and "labels"
                    entries of `self.feature_structure` for every replica and
                    appends the created queue to `self.infeed_queue`.

                    Returns:
                        The enqueue ops produced by the infeed queue.
                    """

                    def regroup_labels(labels):
                        """Replace the box/class labels by per-level dicts, in place.

                        Both label tensors are split along axis 1 into one
                        chunk per feature level (chunk size: feature size
                        squared times the number of defaults), then each chunk
                        is reshaped and transposed so the defaults axis moves
                        behind the two spatial axes.
                        """
                        sizes = ssd_constants.FEATURE_SIZES
                        defaults = ssd_constants.NUM_DEFAULTS
                        num_elements = [
                            size * size * defaults[level]
                            for level, size in enumerate(sizes)
                        ]
                        gt_boxes = tf.split(labels[ssd_constants.BOXES],
                                            num_elements, 1)
                        gt_classes = tf.split(labels[ssd_constants.CLASSES],
                                              num_elements, 1)

                        boxes_by_level = {}
                        for level in range(len(defaults)):
                            box_grid = tf.reshape(gt_boxes[level], [
                                -1, defaults[level], sizes[level],
                                sizes[level], 4
                            ])
                            boxes_by_level[level] = tf.transpose(
                                box_grid, [0, 2, 3, 1, 4])
                        labels[ssd_constants.BOXES] = boxes_by_level

                        classes_by_level = {}
                        for level in range(len(defaults)):
                            class_grid = tf.reshape(gt_classes[level], [
                                -1, defaults[level], sizes[level],
                                sizes[level]
                            ])
                            classes_by_level[level] = tf.transpose(
                                class_grid, [0, 2, 3, 1])
                        labels[ssd_constants.CLASSES] = classes_by_level

                    def partition_dims_for(tensor):
                        """Choose the partition dims for one flattened input."""
                        rank = tensor.shape.ndims
                        base_rank = len(self.input_partition_dims[0])
                        if rank < base_rank:
                            # Too few dimensions to partition: replicate.
                            return [1] * rank
                        features_shape = (
                            self.feature_structure["features"].shape.as_list())
                        if tensor.shape.as_list() == features_shape:
                            return self.input_partition_dims[0]
                        # NOTE(review): this branch starts from
                        # FLAGS.input_partition_dims while the others use
                        # self.input_partition_dims[0] -- confirm both hold
                        # the same value.
                        return (FLAGS.input_partition_dims + [1] *
                                (rank - len(self.input_partition_dims[0])))

                    replica_inputs = []
                    # Tensors of earlier replicas; each new read depends on
                    # them so the `get_next()` calls are ordered.
                    earlier_tensors = []
                    for _ in range(self.replicas_per_worker):
                        with tf.control_dependencies(earlier_tensors):
                            features, labels = data_iter.get_next()
                        if self.use_spatial_partition:
                            regroup_labels(labels)
                        self.feature_structure["features"] = features
                        self.feature_structure["labels"] = labels
                        flat_batch = data_nest.flatten(self.feature_structure)
                        earlier_tensors += flat_batch
                        replica_inputs.append(flat_batch)

                    tuple_size = len(replica_inputs[0])
                    if not self.use_spatial_partition:
                        queue = tpu.InfeedQueue(
                            number_of_tuple_elements=tuple_size)
                        self.infeed_queue.append(queue)
                        return queue.generate_enqueue_ops(
                            replica_inputs,
                            tpu_ordinal_function=utils.tpu_ordinal_fn)

                    flattened_input_dims = [
                        partition_dims_for(tensor)
                        for tensor in replica_inputs[0]
                    ]
                    # pylint: disable=protected-access
                    queue = tpu_feed._PartitionedInfeedQueue(
                        number_of_tuple_elements=tuple_size,
                        host_id=host_id,
                        input_partition_dims=flattened_input_dims,
                        device_assignment=self.device_assignment)
                    self.infeed_queue.append(queue)
                    return queue.generate_enqueue_ops(replica_inputs)

            return enqueue_ops_fn