Example #1
0
    def testPartitionConfig(self):
        """Each argument pair below must make PartitionConfig raise ValueError."""
        # Negative values in either position, plus the (2, 3) pair.
        # NOTE(review): (2, 3) is presumably rejected because the second
        # argument exceeds the first — confirm against shampoo.PartitionConfig.
        rejected_arg_pairs = ((-1, 2), (2, -1), (2, 3))
        for first, second in rejected_arg_pairs:
            with self.assertRaises(ValueError):
                shampoo.PartitionConfig(first, second)
Example #2
0
 def testPartitionTensor(self):
     """partition_tensor splits a 255x255 variable into the expected pieces."""
     config = shampoo.PartitionConfig(200, 128)
     variable = tf.Variable(np.ones((255, 255)), dtype=tf.float32)
     pieces = shampoo.partition_tensor(variable, config)
     # Both dims (255) are split into a 128 chunk and a 127 remainder.
     expected_shapes = [[128, 128], [127, 128], [128, 127], [127, 127]]
     self.assertEqual([piece.get_shape() for piece in pieces],
                      expected_shapes)
Example #3
0
 def testPartitionMetadata(self):
     """partition_metadata reports split sizes and counts for each dimension."""
     config = shampoo.PartitionConfig(200, 128)
     variable = tf.Variable(np.ones((255, 255)), dtype=tf.float32)
     info = shampoo.partition_metadata(variable, config)
     # Each of the two 255-long dims is split as 128 + 127.
     expected_sizes = [[128, 127], [128, 127]]
     expected_counts = [2, 2]
     self.assertAllEqual(info.split_sizes_per_dim, expected_sizes)
     self.assertAllEqual(info.num_splits_per_dim, expected_counts)
Example #4
0
 def testTensorPartitioner(self):
     """Partitioning and then reforming a tensor round-trips its values."""
     ones = np.ones((255, 255))
     config = shampoo.PartitionConfig(200, 128)
     variable = tf.Variable(ones, dtype=tf.float32)
     # Built from the same numpy array (float64), so it holds the values
     # the reformed float32 tensor should match.
     expected = tf.constant(ones)
     info = shampoo.partition_metadata(variable, config)
     pieces = shampoo.partition_tensor(variable, config)
     rebuilt = shampoo.reform_tensor(pieces, info.num_splits_per_dim)
     # Tolerance chosen by dtype, since rebuilt and expected differ in dtype.
     self.assertAllCloseAccordingToType(rebuilt, expected)