Пример #1
0
 def testMerge(self):
     """Checks that merge() combines settings and rejects conflicts."""
     donor = tpu_sharding.ShardingPolicy()
     donor.set_number_of_shards(17)
     donor.set_shard_dimension(23)
     merged = tpu_sharding.ShardingPolicy()
     merged.merge(donor)
     self.assertEqual(merged.number_of_shards, 17)
     self.assertEqual(merged.shard_dimension, 23)
     # Merging a policy that only sets the dimension updates just that field.
     donor = tpu_sharding.ShardingPolicy()
     donor.set_shard_dimension(12)
     merged.merge(donor)
     self.assertEqual(merged.number_of_shards, 17)
     self.assertEqual(merged.shard_dimension, 12)
     # Merging a compatible policy into a frozen one leaves it unchanged.
     merged.freeze()
     merged.merge(donor)
     self.assertEqual(merged.number_of_shards, 17)
     self.assertEqual(merged.shard_dimension, 12)
     # A conflicting shard count must be rejected after freezing.
     donor.set_number_of_shards(1)
     with self.assertRaises(ValueError):
         merged.merge(donor)
     # A conflicting shard dimension must also be rejected.
     donor = tpu_sharding.ShardingPolicy()
     donor.set_number_of_shards(17)
     merged.merge(donor)
     donor.set_shard_dimension(2)
     with self.assertRaises(ValueError):
         merged.merge(donor)
Пример #2
0
 def testFreeze(self):
     """Checks that freeze() fills defaults but keeps explicit settings."""
     defaulted = tpu_sharding.ShardingPolicy()
     defaulted.freeze()
     # An unset policy picks up the module defaults when frozen.
     self.assertEqual(defaulted.number_of_shards,
                      tpu_sharding._DEFAULT_NUMBER_OF_SHARDS)
     self.assertEqual(defaulted.shard_dimension,
                      tpu_sharding._DEFAULT_SHARD_DIMENSION)
     # Explicitly configured values survive freezing untouched.
     explicit = tpu_sharding.ShardingPolicy()
     explicit.set_number_of_shards(17)
     explicit.set_shard_dimension(23)
     explicit.freeze()
     self.assertEqual(explicit.number_of_shards, 17)
     self.assertEqual(explicit.shard_dimension, 23)
Пример #3
0
 def testStr(self):
     """Checks str() output before and after the policy is fully set."""
     policy = tpu_sharding.ShardingPolicy()
     self.assertEqual(str(policy), "ShardingPolicy(unset)")
     # Setting only the shard count still reports the policy as unset.
     policy.set_number_of_shards(17)
     self.assertEqual(str(policy), "ShardingPolicy(unset)")
     policy.set_shard_dimension(8)
     self.assertEqual(str(policy), "ShardingPolicy(17 shards dimension 8)")
Пример #4
0
 def testFrozen(self):
     """Checks that setters raise once the policy is frozen."""
     policy = tpu_sharding.ShardingPolicy()
     policy.freeze()
     with self.assertRaises(ValueError):
         policy.set_number_of_shards(17)
     with self.assertRaises(ValueError):
         policy.set_shard_dimension(22)
Пример #5
0
 def testGetUnpartitionedShape(self):
     """Tests getting the unpartitioned shape of a partitioned tensor."""
     p = tpu_sharding.ShardingPolicy()
     p.set_number_of_shards(3)
     p.set_shard_dimension(1)
     p.set_number_of_partitions(4)
     # Dimension 1 is multiplied by the 4 partitions: 5 * 4 = 20.
     self.assertEqual(p.get_unpartitioned_shape([3, 5]), [3, 20])
     p.freeze()
     # A shape containing an unknown dimension is rejected.
     with self.assertRaises(ValueError):
         _ = p.get_unpartitioned_shape([3, None])
Пример #6
0
 def testGetUnshardedShape(self):
     """Checks unsharded-shape reconstruction and its error cases."""
     policy = tpu_sharding.ShardingPolicy()
     policy.set_number_of_shards(2)
     policy.set_shard_dimension(1)
     # Two [4, 3] shards recombine along dimension 1 into [4, 6].
     self.assertEqual(policy.get_unsharded_shape([[4, 3], [4, 3]]), [4, 6])
     # Too few shards.
     with self.assertRaises(ValueError):
         _ = policy.get_unsharded_shape([[4, 3]])
     # Too many shards.
     with self.assertRaises(ValueError):
         _ = policy.get_unsharded_shape([[4, 3], [4, 3], [4, 3]])
     # Shards whose shapes disagree.
     with self.assertRaises(ValueError):
         _ = policy.get_unsharded_shape([[4, 3], [4, 2]])
     # An entry that is not convertible to a shape raises TypeError.
     with self.assertRaises(TypeError):
         _ = policy.get_unsharded_shape([[4, 3], "not_a_shape"])
     # An unknown shard shape.
     with self.assertRaises(ValueError):
         _ = policy.get_unsharded_shape([None, [4, 3]])
     # A shard with the wrong rank.
     with self.assertRaises(ValueError):
         _ = policy.get_unsharded_shape([[2], [4, 3]])
Пример #7
0
 def testGetShardedShape(self):
     """Checks per-shard shape computation and its error cases."""
     policy = tpu_sharding.ShardingPolicy()
     policy.set_number_of_shards(3)
     policy.set_shard_dimension(1)
     # Dimension 1 is split three ways: 9 / 3 = 3.
     self.assertEqual(policy.get_sharded_shape([4, 9]), [4, 3])
     policy.freeze()
     # A frozen policy rejects further changes.
     with self.assertRaises(ValueError):
         policy.set_shard_dimension(0)
     # Shard index out of range: too large, then negative.
     with self.assertRaises(ValueError):
         _ = policy.get_sharded_shape([4, 9], shard_index=4)
     with self.assertRaises(ValueError):
         _ = policy.get_sharded_shape([4, 9], shard_index=-1)
     # Input that cannot be converted to a shape raises TypeError.
     with self.assertRaises(TypeError):
         _ = policy.get_sharded_shape("not_a_shape")
     # A fully unknown shape is rejected.
     with self.assertRaises(ValueError):
         _ = policy.get_sharded_shape(tensor_shape.TensorShape(None))
     # A dimension not evenly divisible by the shard count is rejected.
     with self.assertRaises(ValueError):
         _ = policy.get_sharded_shape([4, 10], shard_index=-1)
Пример #8
0
  def __init__(self,
               number_of_tuple_elements=None,
               tuple_types=None,
               tuple_shapes=None,
               shard_dimensions=None,
               name=None):
    """Creates a new InfeedQueue with the given configuration.

    The configuration need not be fully specified at creation since it
    can be modified subsequently by methods that set the values
    explicitly or infer them from the shapes of inputs.

    Args:
      number_of_tuple_elements: the number of Tensors fed atomically through the
        queue, must be present unless it can be inferred from other arguments.
      tuple_types: if not None, a list of types of the elements of the queue.
      tuple_shapes: if not None, a list of shapes of the elements of the queue.
      shard_dimensions: if not None, a list of dimensions on which the
        elements of the queue should be sharded during automatic
        parallelization.
      name: the name of the queue.

    Raises:
      ValueError: if number_of_tuple_elements <= 0; or
        number_of_tuple_arguments, tuple_types, tuple_shapes, and
        shard_dimensions are all None; or the length of tuple_types,
        tuple_shapes, or shard_dimensions is not equal to
        number_of_tuple_elements; or any element of shard_dimensions
        can't be converted to a Dimension.
      TypeError: if any element of tuple_types or tuple_shapes can't
        be converted to a dtype or TensorShape, respectively.
    """
    self._frozen = False
    self._generated_enqueue_ops = False
    self._generated_dequeue_op = False
    self._name = "InfeedQueue" if name is None else name
    if number_of_tuple_elements is None:
      # Infer the element count from the first configuration list supplied.
      if tuple_types is not None:
        number_of_tuple_elements = len(tuple_types)
      elif tuple_shapes is not None:
        number_of_tuple_elements = len(tuple_shapes)
      elif shard_dimensions is not None:
        number_of_tuple_elements = len(shard_dimensions)
      else:
        raise ValueError(
            "number of tuple elements cannot be inferred from InfeedQueue "
            "constructor")
    if number_of_tuple_elements <= 0:
      raise ValueError("number_of_tuple_elements %d must be > 0" %
                       number_of_tuple_elements)
    # Make an empty sharding policy for each tuple element.
    # Use the builtin range (not the Python-2-only xrange) so this also runs
    # under Python 3, consistent with the other InfeedQueue constructor.
    self._sharding_policies = [
        tpu_sharding.ShardingPolicy()
        for _ in range(number_of_tuple_elements)
    ]
    if tuple_types is not None:
      self.set_tuple_types(tuple_types)
    else:
      self._tuple_types = None
    if tuple_shapes is not None:
      self.set_tuple_shapes(tuple_shapes)
    else:
      self._tuple_shapes = None
    if shard_dimensions is not None:
      self.set_shard_dimensions(shard_dimensions)
    self._validate()
Пример #9
0
    def __init__(self,
                 number_of_tuple_elements=None,
                 tuple_types=None,
                 tuple_shapes=None,
                 shard_dimensions=None,
                 number_of_partitions=None,
                 name=None):
        """Creates a new InfeedQueue with the given configuration.

        The configuration does not have to be complete at construction
        time; remaining values can later be set explicitly or inferred
        from the shapes of enqueued inputs.

        Args:
          number_of_tuple_elements: the number of Tensors fed atomically
            through the queue; required unless it can be inferred from
            another argument.
          tuple_types: if not None, a list of types of the queue elements.
          tuple_shapes: if not None, a list of shapes of the queue elements.
          shard_dimensions: if not None, a list of dimensions on which the
            queue elements should be sharded during automatic
            parallelization.
          number_of_partitions: if > 1, the infeed dequeue shape will
            contain the full shape that includes all partitions and add
            corresponding XLA annotation on the infeed dequeue op. In this
            case, the infeed is still data parallel that feeds per-core
            batch size to each core while the XLA computation may be
            partitioned. As XLA requires infeed dequeue shape to be
            per-replica shape, thus we need number_of_partitions here to
            calculate the per-replica unpartitioned shape.
          name: the name of the queue.

        Raises:
          ValueError: if number_of_tuple_elements <= 0; or
            number_of_tuple_arguments, tuple_types, tuple_shapes, and
            shard_dimensions are all None; or the length of tuple_types,
            tuple_shapes, or shard_dimensions is not equal to
            number_of_tuple_elements; or any element of shard_dimensions
            can't be converted to a Dimension.
          TypeError: if any element of tuple_types or tuple_shapes can't
            be converted to a dtype or TensorShape, respectively.
        """
        self._frozen = False
        self._generated_enqueue_ops = False
        self._generated_dequeue_op = False
        self._name = "InfeedQueue" if name is None else name
        # Default to a single partition when none is specified.
        self._number_of_partitions = (
            1 if number_of_partitions is None else number_of_partitions)
        if number_of_tuple_elements is None:
            # Fall back to the length of the first configuration list given.
            for config in (tuple_types, tuple_shapes, shard_dimensions):
                if config is not None:
                    number_of_tuple_elements = len(config)
                    break
            else:
                raise ValueError(
                    "number of tuple elements cannot be inferred from InfeedQueue "
                    "constructor")
        if number_of_tuple_elements <= 0:
            raise ValueError("number_of_tuple_elements %d must be > 0" %
                             number_of_tuple_elements)
        # One initially-unconstrained sharding policy per tuple element.
        self._sharding_policies = [
            tpu_sharding.ShardingPolicy()
            for _ in range(number_of_tuple_elements)
        ]
        if tuple_types is None:
            self._tuple_types = None
        else:
            self.set_tuple_types(tuple_types)
        if tuple_shapes is None:
            self._tuple_shapes = None
        else:
            self.set_tuple_shapes(tuple_shapes)
        if shard_dimensions is not None:
            self.set_shard_dimensions(shard_dimensions)
        self._validate()
Пример #10
0
 def testScalar(self):
     """Checks that scalar shapes pass through sharding unchanged."""
     policy = tpu_sharding.ShardingPolicy()
     policy.freeze()
     # A rank-0 shape shards to itself, and a single scalar shard
     # unshards back to a scalar.
     self.assertEqual(policy.get_sharded_shape([]), [])
     self.assertEqual(policy.get_unsharded_shape([[]]), [])