def testMerge(self):
  """Verifies merging of sharding policies, including frozen targets."""
  source = tpu_sharding.ShardingPolicy()
  source.set_number_of_shards(17)
  source.set_shard_dimension(23)
  target = tpu_sharding.ShardingPolicy()
  target.merge(source)
  self.assertEqual(target.number_of_shards, 17)
  self.assertEqual(target.shard_dimension, 23)
  # A policy that only sets the dimension overrides just that field.
  source = tpu_sharding.ShardingPolicy()
  source.set_shard_dimension(12)
  target.merge(source)
  self.assertEqual(target.number_of_shards, 17)
  self.assertEqual(target.shard_dimension, 12)
  # A frozen policy still accepts merges that agree with its settings.
  target.freeze()
  target.merge(source)
  self.assertEqual(target.number_of_shards, 17)
  self.assertEqual(target.shard_dimension, 12)
  # A conflicting shard count must be rejected after freezing.
  source.set_number_of_shards(1)
  with self.assertRaises(ValueError):
    target.merge(source)
  # Matching shard count merges cleanly; a conflicting dimension then fails.
  source = tpu_sharding.ShardingPolicy()
  source.set_number_of_shards(17)
  target.merge(source)
  source.set_shard_dimension(2)
  with self.assertRaises(ValueError):
    target.merge(source)
def testFreeze(self):
  """Checks that freeze() fills in defaults but keeps explicit values."""
  defaulted = tpu_sharding.ShardingPolicy()
  defaulted.freeze()
  # An unset policy picks up the module defaults on freeze.
  self.assertEqual(defaulted.number_of_shards,
                   tpu_sharding._DEFAULT_NUMBER_OF_SHARDS)
  self.assertEqual(defaulted.shard_dimension,
                   tpu_sharding._DEFAULT_SHARD_DIMENSION)
  # Explicitly-set values survive freezing untouched.
  explicit = tpu_sharding.ShardingPolicy()
  explicit.set_number_of_shards(17)
  explicit.set_shard_dimension(23)
  explicit.freeze()
  self.assertEqual(explicit.number_of_shards, 17)
  self.assertEqual(explicit.shard_dimension, 23)
def testStr(self):
  """Checks str() output for unset, partially-set, and fully-set policies."""
  policy = tpu_sharding.ShardingPolicy()
  self.assertEqual(str(policy), "ShardingPolicy(unset)")
  # Setting only the shard count still renders as unset.
  policy.set_number_of_shards(17)
  self.assertEqual(str(policy), "ShardingPolicy(unset)")
  policy.set_shard_dimension(8)
  self.assertEqual(str(policy), "ShardingPolicy(17 shards dimension 8)")
def testFrozen(self):
  """Checks that mutating a frozen policy raises ValueError."""
  policy = tpu_sharding.ShardingPolicy()
  policy.freeze()
  with self.assertRaises(ValueError):
    policy.set_number_of_shards(17)
  with self.assertRaises(ValueError):
    policy.set_shard_dimension(22)
def testGetUnpartitionedShape(self):
  """Checks unpartitioned-shape computation and its error case."""
  policy = tpu_sharding.ShardingPolicy()
  policy.set_number_of_shards(3)
  policy.set_shard_dimension(1)
  policy.set_number_of_partitions(4)
  # The sharded dimension is scaled by the partition count: 5 * 4 = 20.
  self.assertEqual(policy.get_unpartitioned_shape([3, 5]), [3, 20])
  policy.freeze()
  # A shape with an unknown dimension cannot be unpartitioned.
  with self.assertRaises(ValueError):
    _ = policy.get_unpartitioned_shape([3, None])
def testGetUnshardedShape(self):
  """Checks unsharded-shape computation and its error cases."""
  policy = tpu_sharding.ShardingPolicy()
  policy.set_number_of_shards(2)
  policy.set_shard_dimension(1)
  # Two [4, 3] shards concatenated along dimension 1 yield [4, 6].
  self.assertEqual(policy.get_unsharded_shape([[4, 3], [4, 3]]), [4, 6])
  # Wrong number of shard shapes (too few, too many).
  with self.assertRaises(ValueError):
    _ = policy.get_unsharded_shape([[4, 3]])
  with self.assertRaises(ValueError):
    _ = policy.get_unsharded_shape([[4, 3], [4, 3], [4, 3]])
  # Shard shapes that disagree with each other.
  with self.assertRaises(ValueError):
    _ = policy.get_unsharded_shape([[4, 3], [4, 2]])
  # Entries that are not valid shapes.
  with self.assertRaises(TypeError):
    _ = policy.get_unsharded_shape([[4, 3], "not_a_shape"])
  with self.assertRaises(ValueError):
    _ = policy.get_unsharded_shape([None, [4, 3]])
  # A shard of the wrong rank.
  with self.assertRaises(ValueError):
    _ = policy.get_unsharded_shape([[2], [4, 3]])
def testGetShardedShape(self):
  """Checks sharded-shape computation and its error cases."""
  policy = tpu_sharding.ShardingPolicy()
  policy.set_number_of_shards(3)
  policy.set_shard_dimension(1)
  # Dimension 1 is split evenly across the 3 shards: 9 / 3 = 3.
  self.assertEqual(policy.get_sharded_shape([4, 9]), [4, 3])
  policy.freeze()
  # A frozen policy rejects changes to the shard dimension.
  with self.assertRaises(ValueError):
    policy.set_shard_dimension(0)
  # Shard indices outside [0, number_of_shards) are rejected.
  with self.assertRaises(ValueError):
    _ = policy.get_sharded_shape([4, 9], shard_index=4)
  with self.assertRaises(ValueError):
    _ = policy.get_sharded_shape([4, 9], shard_index=-1)
  # Inputs that are not shapes.
  with self.assertRaises(TypeError):
    _ = policy.get_sharded_shape("not_a_shape")
  # A shape with unknown rank cannot be sharded.
  with self.assertRaises(ValueError):
    _ = policy.get_sharded_shape(tensor_shape.TensorShape(None))
  # A dimension not evenly divisible by the shard count.
  with self.assertRaises(ValueError):
    _ = policy.get_sharded_shape([4, 10], shard_index=-1)
def __init__(self,
             number_of_tuple_elements=None,
             tuple_types=None,
             tuple_shapes=None,
             shard_dimensions=None,
             name=None):
  """Creates a new InfeedQueue with the given configuration.

  The configuration need not be fully specified at creation since it
  can be modified subsequently by methods that set the values explicitly
  or infer them from the shapes of inputs.

  Args:
    number_of_tuple_elements: the number of Tensors fed atomically through
      the queue, must be present unless it can be inferred from other
      arguments.
    tuple_types: if not None, a list of types of the elements of the queue.
    tuple_shapes: if not None, a list of shapes of the elements of the
      queue.
    shard_dimensions: if not None, a list of dimensions on which the
      elements of the queue should be sharded during automatic
      parallelization.
    name: the name of the queue.

  Raises:
    ValueError: if number_of_tuple_elements <= 0; or
      number_of_tuple_arguments, tuple_types, tuple_shapes, and
      shard_dimensions are all None; or the length of tuple_types,
      tuple_shapes, or shard_dimensions is not equal to
      number_of_tuple_elements; or any element of shard_dimensions can't
      be converted to a Dimension.
    TypeError: if any element of tuple_types or tuple_shapes can't be
      converted to a dtype or TensorShape, respectively.
  """
  self._frozen = False
  self._generated_enqueue_ops = False
  self._generated_dequeue_op = False
  self._name = "InfeedQueue" if name is None else name
  if number_of_tuple_elements is None:
    # Infer the tuple arity from whichever configuration list was supplied.
    if tuple_types is not None:
      number_of_tuple_elements = len(tuple_types)
    elif tuple_shapes is not None:
      number_of_tuple_elements = len(tuple_shapes)
    elif shard_dimensions is not None:
      number_of_tuple_elements = len(shard_dimensions)
    else:
      raise ValueError(
          "number of tuple elements cannot be inferred from InfeedQueue "
          "constructor")
  if number_of_tuple_elements <= 0:
    raise ValueError("number_of_tuple_elements %d must be > 0" %
                     number_of_tuple_elements)
  # Make an empty sharding policy for each tuple element.
  # Fixed: use the builtin `range` instead of the Python-2-only `xrange`,
  # matching the rest of the file and removing the Py3 incompatibility.
  self._sharding_policies = [
      tpu_sharding.ShardingPolicy() for _ in range(number_of_tuple_elements)
  ]
  if tuple_types is not None:
    self.set_tuple_types(tuple_types)
  else:
    self._tuple_types = None
  if tuple_shapes is not None:
    self.set_tuple_shapes(tuple_shapes)
  else:
    self._tuple_shapes = None
  if shard_dimensions is not None:
    self.set_shard_dimensions(shard_dimensions)
  self._validate()
def __init__(self,
             number_of_tuple_elements=None,
             tuple_types=None,
             tuple_shapes=None,
             shard_dimensions=None,
             number_of_partitions=None,
             name=None):
  """Creates a new InfeedQueue with the given configuration.

  The configuration need not be fully specified at creation since it
  can be modified subsequently by methods that set the values explicitly
  or infer them from the shapes of inputs.

  Args:
    number_of_tuple_elements: the number of Tensors fed atomically through
      the queue, must be present unless it can be inferred from other
      arguments.
    tuple_types: if not None, a list of types of the elements of the queue.
    tuple_shapes: if not None, a list of shapes of the elements of the
      queue.
    shard_dimensions: if not None, a list of dimensions on which the
      elements of the queue should be sharded during automatic
      parallelization.
    number_of_partitions: if > 1, the infeed dequeue shape will contain
      the full shape that includes all partitions and add corresponding
      XLA annotation on the infeed dequeue op. In this case, the infeed
      is still data parallel that feeds per-core batch size to each core
      while the XLA computation may be partitioned. As XLA requires infeed
      dequeue shape to be per-replica shape, thus we need
      number_of_partitions here to calculate the per-replica unpartitioned
      shape.
    name: the name of the queue.

  Raises:
    ValueError: if number_of_tuple_elements <= 0; or
      number_of_tuple_arguments, tuple_types, tuple_shapes, and
      shard_dimensions are all None; or the length of tuple_types,
      tuple_shapes, or shard_dimensions is not equal to
      number_of_tuple_elements; or any element of shard_dimensions can't
      be converted to a Dimension.
    TypeError: if any element of tuple_types or tuple_shapes can't be
      converted to a dtype or TensorShape, respectively.
  """
  self._frozen = False
  self._generated_enqueue_ops = False
  self._generated_dequeue_op = False
  self._name = "InfeedQueue" if name is None else name
  # Default to a single partition when none is requested.
  self._number_of_partitions = (
      1 if number_of_partitions is None else number_of_partitions)
  if number_of_tuple_elements is None:
    # Infer the tuple arity from the first configuration list supplied.
    for config in (tuple_types, tuple_shapes, shard_dimensions):
      if config is not None:
        number_of_tuple_elements = len(config)
        break
    else:
      raise ValueError(
          "number of tuple elements cannot be inferred from InfeedQueue "
          "constructor")
  if number_of_tuple_elements <= 0:
    raise ValueError("number_of_tuple_elements %d must be > 0" %
                     number_of_tuple_elements)
  # Start every tuple element with an unconstrained sharding policy.
  self._sharding_policies = [
      tpu_sharding.ShardingPolicy()
      for _ in range(number_of_tuple_elements)
  ]
  if tuple_types is None:
    self._tuple_types = None
  else:
    self.set_tuple_types(tuple_types)
  if tuple_shapes is None:
    self._tuple_shapes = None
  else:
    self.set_tuple_shapes(tuple_shapes)
  if shard_dimensions is not None:
    self.set_shard_dimensions(shard_dimensions)
  self._validate()
def testScalar(self):
  """Checks that scalar shapes pass through sharding unchanged."""
  policy = tpu_sharding.ShardingPolicy()
  policy.freeze()
  # A rank-0 shape is the same sharded and unsharded.
  self.assertEqual(policy.get_sharded_shape([]), [])
  self.assertEqual(policy.get_unsharded_shape([[]]), [])