Example #1
def _create_infeed_enqueue_ops_and_dequeue_fn(inputs_holder, run_config):
  """Utility to convert input_fn to enqueue and dequeue fns for TPU.

  Args:
    inputs_holder: An `_InputsHolder` holding features and labels.
    run_config: A `RunConfig` instance.

  Returns:
    A tuple of (dequeue_fn, enqueue_fn)
  """
  if inputs_holder.sharded:
    sharded_inputs = inputs_holder.as_sharded_flattened_inputs()

    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(sharded_inputs[0]))
    infeed_queue.set_configuration_from_sharded_input_tensors(sharded_inputs)
  else:
    unsharded_inputs = inputs_holder.as_flattened_inputs()
    infeed_queue = tpu_feed.InfeedQueue(
        tuple_types=[t.dtype for t in unsharded_inputs],
        tuple_shapes=[t.shape for t in unsharded_inputs])
    infeed_queue.set_number_of_shards(inputs_holder.num_shards)

  def dequeue_fn():
    """dequeue_fn is used by the train_step in TPU to retrieve the tensors."""
    values = infeed_queue.generate_dequeue_op()
    return inputs_holder.unflatten_features_and_labels(values)

  def tpu_ordinal_function(index):
    """Return the TPU ordinal associated with a shard.

    Required because the enqueue ops are placed on CPU.

    Args:
      index: the shard index

    Returns:
      The ordinal of the TPU device the shard's infeed should be placed on.
    """
    return index % 8

  def enqueue_fn():
    """enqueue_fn is used to add ops to the graph to send tensors."""
    if inputs_holder.sharded:
      return infeed_queue.generate_enqueue_ops(
          sharded_inputs, tpu_ordinal_function=tpu_ordinal_function)
    else:
      job = _tpu_job(run_config)
      def placement_function(index):
        if job is None:
          return '/replica:0/task:0/device:CPU:0'
        else:
          return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index // 8)
      return infeed_queue.split_inputs_and_generate_enqueue_ops(
          unsharded_inputs, placement_function=placement_function)

  return (dequeue_fn, enqueue_fn)
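The utility above only builds the two closures; nothing is enqueued or executed yet. A minimal sketch of how the returned pair might be wired up (not part of the example; assumes the TF 1.x contrib TPU API and a hypothetical user-supplied model_fn, with inputs_holder and run_config being the same objects passed in above):

import tensorflow as tf
from tensorflow.contrib import tpu

dequeue_fn, enqueue_fn = _create_infeed_enqueue_ops_and_dequeue_fn(
    inputs_holder, run_config)

def train_step():
  """Runs on the TPU: pulls one batch from infeed and computes a loss."""
  features, labels = dequeue_fn()
  return model_fn(features, labels)  # hypothetical model function

enqueue_ops = enqueue_fn()  # enqueue ops live on the host CPU(s)
train_op = tpu.shard(train_step, num_shards=inputs_holder.num_shards)

with tf.Session() as sess:
  sess.run(tpu.initialize_system())
  sess.run(enqueue_ops)  # send one batch per shard into the infeed queue
  sess.run(train_op)     # run one sharded step on the TPU
  sess.run(tpu.shutdown_system())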
Example #2
    def testUsingInfeedQueueWithRegularizer(self):
        """Test that Layer regularizers can reference data created in loops."""
        def make_regularizer(scale):
            return lambda inputs: scale * math_ops.reduce_sum(
                math_ops.square(inputs))

        def training_step(inputs, scale):
            outputs = convolutional.conv2d(
                inputs,
                filters=16,
                kernel_size=(3, 3),
                data_format="channels_first",
                kernel_regularizer=make_regularizer(scale))
            loss = math_ops.reduce_mean(math_ops.square(outputs))
            return loss.op

        inputs = array_ops.zeros(shape=(128, 32, 32, 16))
        scale = array_ops.ones(shape=())
        infeed = tpu_feed.InfeedQueue(
            tuple_types=[dtypes.float32, dtypes.float32],
            tuple_shapes=[inputs.shape, scale.shape])

        def loop():
            return training_loop.repeat(5, training_step, infeed_queue=infeed)

        # This should not throw an error.
        tpu.rewrite(loop)
Example #3
    def testVarArgsAndDefaults(self):
        """Tests that arg checker works for a function with varargs and defaults."""
        def func(x, y, z=17, *q):  # pylint: disable=keyword-arg-before-vararg
            return x + y + z + len(q)

        self.assertEqual(None,
                         xla.check_function_argument_count(func, 2, None))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 3, None))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 4, None))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 5, None))
        self.assertEqual('at least 2 arguments',
                         xla.check_function_argument_count(func, 1, None))
        queue = tpu_feed.InfeedQueue(1)
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 1, queue))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 2, queue))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 3, queue))
        self.assertEqual(None,
                         xla.check_function_argument_count(func, 4, queue))
        self.assertEqual('at least 2 arguments',
                         xla.check_function_argument_count(func, 0, queue))
Example #4
    def testVarArgsAndDefaults(self):
        """Tests that arg checker works for a function with varargs and defaults."""
        def func(x, y, z=17, *q):
            return x + y + z + len(q)

        self.assertEqual(
            None, tpu_function.check_function_argument_count(func, 2, None))
        self.assertEqual(
            None, tpu_function.check_function_argument_count(func, 3, None))
        self.assertEqual(
            None, tpu_function.check_function_argument_count(func, 4, None))
        self.assertEqual(
            None, tpu_function.check_function_argument_count(func, 5, None))
        self.assertEqual(
            "at least 2 arguments",
            tpu_function.check_function_argument_count(func, 1, None))
        queue = tpu_feed.InfeedQueue(1)
        self.assertEqual(
            None, tpu_function.check_function_argument_count(func, 1, queue))
        self.assertEqual(
            None, tpu_function.check_function_argument_count(func, 2, queue))
        self.assertEqual(
            None, tpu_function.check_function_argument_count(func, 3, queue))
        self.assertEqual(
            None, tpu_function.check_function_argument_count(func, 4, queue))
        self.assertEqual(
            "at least 2 arguments",
            tpu_function.check_function_argument_count(func, 0, queue))
Example #5
    def testDefaultArgs(self):
        """Tests that arg checker works for a function with no varargs."""
        def func(x, y, z=17):
            return x + y + z

        self.assertEqual(
            None, tpu_function.check_function_argument_count(func, 3, None))
        self.assertEqual(
            None, tpu_function.check_function_argument_count(func, 2, None))
        self.assertEqual(
            "at least 2 arguments",
            tpu_function.check_function_argument_count(func, 1, None))
        self.assertEqual(
            "at most 3 arguments",
            tpu_function.check_function_argument_count(func, 4, None))
        queue = tpu_feed.InfeedQueue(1)
        self.assertEqual(
            None, tpu_function.check_function_argument_count(func, 2, queue))
        self.assertEqual(
            None, tpu_function.check_function_argument_count(func, 1, queue))
        self.assertEqual(
            "at least 2 arguments",
            tpu_function.check_function_argument_count(func, 0, queue))
        self.assertEqual(
            "at most 3 arguments",
            tpu_function.check_function_argument_count(func, 4, queue))
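The tests above pin down the contract of the argument-count checker: it returns None when the supplied arity, plus one argument per tuple element of an optional infeed queue, satisfies the function signature, and otherwise returns a fragment such as 'at least 2 arguments'. An illustrative re-implementation of that behavior (a sketch, not the library's code):

import inspect

def check_function_argument_count(func, input_arity, infeed_queue):
    """Sketch of the arity check exercised by the tests above."""
    def format_error(complaint, quantity):
        return '%s %d argument%s' % (complaint, quantity,
                                     '' if quantity == 1 else 's')

    num_args_supplied = input_arity
    if infeed_queue is not None:
        # The infeed queue contributes one argument per tuple element.
        num_args_supplied += infeed_queue.number_of_tuple_elements
    arg_spec = inspect.getfullargspec(func)
    num_func_args = len(arg_spec.args)
    num_defaults = len(arg_spec.defaults) if arg_spec.defaults else 0
    min_required = num_func_args - num_defaults
    if num_args_supplied < min_required:
        if num_defaults == 0 and arg_spec.varargs is None:
            return format_error('exactly', num_func_args)
        return format_error('at least', min_required)
    if arg_spec.varargs is None and num_args_supplied > num_func_args:
        if num_defaults == 0:
            return format_error('exactly', num_func_args)
        return format_error('at most', num_func_args)
    return None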
Example #6
def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
    """Utility to convert input_fn to enqueue and dequeue fns for TPU.

  Mainly, three things need to be done here.
  1. Call the input_fn many times (`num_shards`) to infeed the data into the TPU.
  2. Create a dequeue_fn used by the train_step inside TPU execution to
     dequeue the tensors.
  3. Set up the input thread for the infeed.

  Args:
    run_config: run_config
    features: features
    labels: labels

  Returns:
    A tuple of (dequeue_fn, enqueue_fn)
  """
    infeed_names = None
    infeed_tuple = []
    if isinstance(features, dict):
        # We need a fixed ordering for enqueueing and dequeueing.
        infeed_names = [name for name in features]
        infeed_tuple.extend([features[name] for name in infeed_names])
    else:
        infeed_tuple.append(features)
    # TODO(jhseu): Handle multi-head and None labels
    infeed_tuple.append(labels)
    # TODO(jhseu): Update when b/36470756 is settled.
    infeed_queue = tpu_feed.InfeedQueue(
        tuple_types=[t.dtype for t in infeed_tuple],
        tuple_shapes=[t.shape for t in infeed_tuple])
    infeed_queue.set_number_of_shards(run_config.tpu_config.num_shards)

    def dequeue_fn():
        """dequeue_fn is used by the train_step in TPU to retrieve the tensors."""
        values = infeed_queue.generate_dequeue_op()
        if infeed_names is None:
            return values
        # Restore the feature dictionary and label.
        dequeued_features = {}
        for i in range(len(values) - 1):
            dequeued_features[infeed_names[i]] = values[i]
        label = values[-1]
        return dequeued_features, label

    def enqueue_fn():
        """enqueue_fn is used to add ops to the graph to send tensors."""
        job = _tpu_job(run_config)

        def placement_function(index):
            if job is None:
                return '/replica:0/task:0/device:CPU:0'
            else:
                return '/job:%s/replica:0/task:%d/device:CPU:0' % (job,
                                                                   index // 8)

        return infeed_queue.split_inputs_and_generate_enqueue_ops(
            infeed_tuple, placement_function=placement_function)

    return (dequeue_fn, enqueue_fn)
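The fixed infeed_names ordering is what lets dequeue_fn rebuild the feature dictionary on the TPU side. A tiny self-contained illustration of that round trip using plain Python values (not tensors; not part of the example):

features = {'pixels': 1.0, 'ids': 2.0}   # hypothetical feature dict
labels = 3.0

infeed_names = [name for name in features]                # fixed ordering
infeed_tuple = [features[name] for name in infeed_names]  # features first...
infeed_tuple.append(labels)                               # ...label last

# After dequeue, the values arrive in the same order, so the dict is restored.
values = list(infeed_tuple)
restored = {infeed_names[i]: values[i] for i in range(len(values) - 1)}
assert restored == features and values[-1] == labels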
Example #7
 def testModification(self):
     """Tests modification of the queue post-construction."""
     i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
     i.set_tuple_types([dtypes.float32, dtypes.int32])
     self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
     i.set_tuple_types([dtypes.float32, dtypes.float32])
     self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.float32])
     with self.assertRaises(ValueError):
         i.set_tuple_types([dtypes.float32])
     i.set_tuple_shapes([[1], [2, 3]])
     self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
     i.set_tuple_shapes([[1, 2], [3, 4]])
     self.assertEqual(i.tuple_shapes, [[1, 2], [3, 4]])
     with self.assertRaises(ValueError):
         i.set_tuple_shapes([[1, 2]])
     i.set_number_of_shards(2)
     self.assertEqual(i.number_of_shards, 2)
     i.set_number_of_shards(3)
     self.assertEqual(i.number_of_shards, 3)
     t1 = constant_op.constant(1, dtypes.int32, shape=[6])
     t2 = constant_op.constant(2.0, dtypes.float32, shape=[3, 18])
     i.set_configuration_from_input_tensors([t1, t2])
     self.assertEqual(i.tuple_shapes, [[6], [3, 18]])
     self.assertEqual(i.tuple_types, [dtypes.int32, dtypes.float32])
     i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
     self.assertEqual(i.number_of_shards, 2)
     self.assertEqual(i.tuple_shapes, [[6, 18], [12]])
     self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
     i.set_shard_dimensions([1, 0])
     i.set_number_of_shards(3)
     with self.assertRaises(ValueError):
         i.set_number_of_shards(4)
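An aside on the last three lines (an illustration, not part of the test): after set_shard_dimensions([1, 0]) the sharded dimensions of the tuple shapes [[6, 18], [12]] are 18 and 12. Both are divisible by 3, but 18 is not divisible by 4, which is presumably why set_number_of_shards(4) raises ValueError:

sharded_dims = (18, 12)  # dim 1 of [6, 18] and dim 0 of [12]
for shards in (3, 4):
    print(shards, all(dim % shards == 0 for dim in sharded_dims))
# prints: 3 True, then 4 False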
Example #8
 def testFreezing(self):
     """Tests freezing the queue."""
     i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
     t1 = constant_op.constant(1, dtypes.int32, shape=[2])
     t2 = constant_op.constant(2.0, dtypes.float32, shape=[2, 4])
     i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
     self.assertEqual(i.number_of_shards, 2)
     self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
     self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
     self.assertEqual(i.shard_dimensions, [0, 0])
     i.freeze()
     i.set_number_of_shards(2)
     i.set_tuple_shapes([[4, 4], [4]])
     i.set_tuple_types([dtypes.float32, dtypes.int32])
     i.set_shard_dimensions([0, 0])
     with self.assertRaises(ValueError):
         i.set_number_of_shards(1)
     with self.assertRaises(ValueError):
         i.set_tuple_shapes([[8, 8], [8]])
     with self.assertRaises(ValueError):
         i.set_tuple_types([dtypes.int32, dtypes.float32])
     with self.assertRaises(ValueError):
         i.set_shard_dimensions([1, 0])
     self.assertEqual(i.number_of_shards, 2)
     self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
     self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
     self.assertEqual(i.shard_dimensions, [0, 0])
Example #9
  def testSimple(self):
    """Tests that arg checker works for functions with no varargs or defaults.
    """

    def func(x, y, z):
      return x + y + z

    self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
    self.assertEqual('exactly 3 arguments',
                     xla.check_function_argument_count(func, 2, None))
    queue = tpu_feed.InfeedQueue(2)
    self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
    self.assertEqual('exactly 3 arguments',
                     xla.check_function_argument_count(func, 2, queue))
Example #10
 def testConstructor(self):
     """Tests that the constructor can be called with different arguments."""
     i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
     self.assertEqual(i.number_of_tuple_elements, 2)
     self.assertEqual(i.tuple_types, None)
     self.assertEqual(i.tuple_shapes, None)
     self.assertEqual(i.number_of_shards, None)
     i = tpu_feed.InfeedQueue(
         tuple_types=[dtypes.float32, dtypes.int32, dtypes.int32])
     self.assertEqual(i.number_of_tuple_elements, 3)
     self.assertEqual(i.tuple_types,
                      [dtypes.float32, dtypes.int32, dtypes.int32])
     self.assertEqual(i.tuple_shapes, None)
     self.assertEqual(i.number_of_shards, None)
     i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]])
     self.assertEqual(i.number_of_tuple_elements, 2)
     self.assertEqual(i.tuple_types, None)
     self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
     self.assertEqual(i.number_of_shards, None)
     i = tpu_feed.InfeedQueue(shard_dimensions=[1, 0, 7])
     self.assertEqual(i.number_of_tuple_elements, 3)
     self.assertEqual(i.tuple_types, None)
     self.assertEqual(i.tuple_shapes, None)
     self.assertEqual([p.shard_dimension for p in i.sharding_policies],
                      [1, 0, 7])
     with self.assertRaises(ValueError):
         i = tpu_feed.InfeedQueue()
     with self.assertRaises(ValueError):
         i = tpu_feed.InfeedQueue(number_of_tuple_elements=2,
                                  tuple_types=[dtypes.float32])
     with self.assertRaises(ValueError):
         i = tpu_feed.InfeedQueue(number_of_tuple_elements=2,
                                  tuple_shapes=[[1]])
     with self.assertRaises(ValueError):
         i = tpu_feed.InfeedQueue(number_of_tuple_elements=2,
                                  shard_dimensions=[1])
     with self.assertRaises(ValueError):
         i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]],
                                  shard_dimensions=[1])
Example #11
  def testVarArgs(self):
    """Tests that arg checker works for a function with varargs."""

    def func(x, y, *z):
      return x + y + len(z)

    self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
    self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
    self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
    self.assertEqual('at least 2 arguments',
                     xla.check_function_argument_count(func, 1, None))
    queue = tpu_feed.InfeedQueue(1)
    self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
    self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
    self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
    self.assertEqual('at least 2 arguments',
                     xla.check_function_argument_count(func, 0, queue))
Example #12
def _create_infeed_enqueue_ops_and_dequeue_fn(inputs_holder):
    """Utility to convert input_fn to enqueue and dequeue fns for TPU.

  Args:
    inputs_holder: An `_InputsHolder` holding features and labels.

  Returns:
    A tuple of (dequeue_fn, enqueue_fn)
  """
    sharded_inputs = inputs_holder.as_sharded_flattened_inputs()

    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(sharded_inputs[0]))
    infeed_queue.set_configuration_from_sharded_input_tensors(sharded_inputs)

    def dequeue_fn():
        """dequeue_fn is used by the train_step in TPU to retrieve the tensors."""
        values = infeed_queue.generate_dequeue_op()
        return inputs_holder.unflatten_features_and_labels(values)

    def tpu_ordinal_function(index):
        """Return the TPU ordinal associated with a shard.

    Required because the enqueue ops are placed on CPU.

    Args:
      index: the shard index

    Returns:
      The ordinal of the TPU device the shard's infeed should be placed on.
    """
        return index % 8

    def enqueue_fn():
        """enqueue_fn is used to add ops to the graph to send tensors."""
        return infeed_queue.generate_enqueue_ops(
            sharded_inputs, tpu_ordinal_function=tpu_ordinal_function)

    return (dequeue_fn, enqueue_fn)
Example #13
                def eval_enqueue_ops_fn():
                    """Enqueue ops function for one host."""
                    per_host_sharded_inputs = []
                    control_deps = []
                    for _ in range(self.replicas_per_worker):
                        with tf.control_dependencies(control_deps):
                            features = iterator.get_next()
                        if self.use_spatial_partition:
                            self.eval_input_dims_flattener.validate_and_flatten_input_dims(
                                features, None)
                        self.eval_feature_structure["features"] = features
                        flattened_inputs = data_nest.flatten(
                            self.eval_feature_structure)
                        control_deps.extend(flattened_inputs)
                        per_host_sharded_inputs.append(flattened_inputs)

                    if self.use_spatial_partition:
                        flattened_input_dims = (self.eval_input_dims_flattener.
                                                flattened_input_dims)
                        # pylint: disable=protected-access
                        infeed = tpu_feed._PartitionedInfeedQueue(
                            number_of_tuple_elements=len(
                                per_host_sharded_inputs[0]),
                            host_id=host_id,
                            input_partition_dims=flattened_input_dims,
                            device_assignment=self.device_assignment)
                        self.eval_infeed_queue.append(infeed)
                        return infeed.generate_enqueue_ops(
                            per_host_sharded_inputs)

                    infeed = tpu_feed.InfeedQueue(number_of_tuple_elements=len(
                        per_host_sharded_inputs[0]))
                    self.eval_infeed_queue.append(infeed)
                    return infeed.generate_enqueue_ops(
                        per_host_sharded_inputs,
                        tpu_ordinal_function=utils.tpu_ordinal_fn)
Example #14
def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
    """Utility to convert input_fn to enqueue and dequeue fns for TPU.

  Mainly, three things need to be done here.
  1. Call the input_fn many times (`num_shards`) to infeed the data into the TPU.
  2. Create a dequeue_fn used by the train_step inside TPU execution to
     dequeue the tensors.
  3. Set up the input thread for the infeed.

  Args:
    run_config: run_config
    features: features
    labels: labels

  Returns:
    A tuple of (dequeue_fn, enqueue_fn)
  """
    infeed_names = None
    sharded_inputs = []
    if isinstance(features[0], dict):
        # We need a fixed ordering for enqueueing and dequeueing.
        infeed_names = [name for name in features[0]]

    for shard in range(run_config.tpu_config.num_shards):
        inputs = []
        if infeed_names is None:
            inputs.append(features[shard])
        else:
            for name in infeed_names:
                inputs.append(features[shard][name])
        if labels is not None:
            inputs.append(labels[shard])
        sharded_inputs.append(inputs)

    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(sharded_inputs[0]))
    infeed_queue.set_configuration_from_sharded_input_tensors(sharded_inputs)

    def dequeue_fn():
        """dequeue_fn is used by the train_step in TPU to retrieve the tensors."""
        values = infeed_queue.generate_dequeue_op()

        expected_num_tensors = 0
        if labels is not None:
            expected_num_tensors += 1
        if infeed_names is None:
            expected_num_tensors += 1
        else:
            expected_num_tensors += len(infeed_names)
        assert len(values) == expected_num_tensors

        dequeue_label = None
        if labels is not None:
            dequeue_label = values[-1]
        if infeed_names is None:
            return values[0], dequeue_label
        # Restore the feature dictionary and label.
        dequeued_features = {}
        for i in range(len(infeed_names)):
            dequeued_features[infeed_names[i]] = values[i]
        return dequeued_features, dequeue_label

    def tpu_ordinal_function(index):
        """Return the TPU ordinal associated with a shard.

    Required because the enqueue ops are placed on CPU.

    Args:
      index: the shard index

    Returns:
      The ordinal of the TPU device the shard's infeed should be placed on.
    """
        return index % 8

    def enqueue_fn():
        """enqueue_fn is used to add ops to the graph to send tensors."""
        return infeed_queue.generate_enqueue_ops(
            sharded_inputs, tpu_ordinal_function=tpu_ordinal_function)

    return (dequeue_fn, enqueue_fn)
Example #15
                def enqueue_ops_fn():
                    """Enqueue ops function for one host."""
                    per_host_sharded_inputs = []
                    control_deps = []
                    for _ in range(self.replicas_per_worker):
                        with tf.control_dependencies(control_deps):
                            features, labels = iterator.get_next()
                        if self.use_spatial_partition:
                            num_elements = []
                            for i, d in enumerate(ssd_constants.FEATURE_SIZES):
                                num_elements.append(
                                    d * d * ssd_constants.NUM_DEFAULTS[i])
                            gt_boxes = tf.split(labels[ssd_constants.BOXES],
                                                num_elements, 1)
                            gt_classes = tf.split(
                                labels[ssd_constants.CLASSES], num_elements, 1)

                            def transpose_gt_box(gt_box, i):
                                return tf.transpose(
                                    tf.reshape(gt_box, [
                                        -1, ssd_constants.NUM_DEFAULTS[i],
                                        ssd_constants.FEATURE_SIZES[i],
                                        ssd_constants.FEATURE_SIZES[i], 4
                                    ]), [0, 2, 3, 1, 4])

                            def transpose_gt_class(gt_class, i):
                                return tf.transpose(
                                    tf.reshape(gt_class, [
                                        -1, ssd_constants.NUM_DEFAULTS[i],
                                        ssd_constants.FEATURE_SIZES[i],
                                        ssd_constants.FEATURE_SIZES[i]
                                    ]), [0, 2, 3, 1])

                            # TODO(dehao): This causes 3s overhead in startup, fix it.
                            labels[ssd_constants.BOXES] = {
                                i: transpose_gt_box(gt_boxes[i], i)
                                for i in range(len(ssd_constants.NUM_DEFAULTS))
                            }
                            labels[ssd_constants.CLASSES] = {
                                i: transpose_gt_class(gt_classes[i], i)
                                for i in range(len(ssd_constants.NUM_DEFAULTS))
                            }
                        self.feature_structure["features"] = features
                        self.feature_structure["labels"] = labels
                        flattened_inputs = data_nest.flatten(
                            self.feature_structure)
                        control_deps.extend(flattened_inputs)
                        per_host_sharded_inputs.append(flattened_inputs)

                    if self.use_spatial_partition:
                        flattened_input_dims = []
                        for i in per_host_sharded_inputs[0]:
                            if i.shape.ndims >= len(
                                    self.input_partition_dims[0]):
                                if i.shape.as_list() == self.feature_structure[
                                        "features"].shape.as_list():
                                    flattened_input_dims.append(
                                        self.input_partition_dims[0])
                                else:
                                    flattened_input_dims.append(
                                        FLAGS.input_partition_dims + [1] *
                                        (i.shape.ndims -
                                         len(self.input_partition_dims[0])))
                            else:
                                flattened_input_dims.append([1] *
                                                            i.shape.ndims)
                        # pylint: disable=protected-access
                        infeed = tpu_feed._PartitionedInfeedQueue(
                            number_of_tuple_elements=len(
                                per_host_sharded_inputs[0]),
                            host_id=host_id,
                            input_partition_dims=flattened_input_dims,
                            device_assignment=self.device_assignment)
                        self.infeed_queue.append(infeed)
                        return infeed.generate_enqueue_ops(
                            per_host_sharded_inputs)

                    infeed = tpu_feed.InfeedQueue(number_of_tuple_elements=len(
                        per_host_sharded_inputs[0]))
                    self.infeed_queue.append(infeed)
                    return infeed.generate_enqueue_ops(
                        per_host_sharded_inputs,
                        tpu_ordinal_function=utils.tpu_ordinal_fn)