def testPartitionConfig(self):
    """Check that PartitionConfig rejects invalid argument pairs.

    Each of the pairs (-1, 2), (2, -1) and (2, 3) is expected to raise
    ValueError on construction.
    """
    # NOTE(review): the meaning of the two positional arguments is not
    # visible here; confirm against the PartitionConfig definition.
    invalid_arg_pairs = ((-1, 2), (2, -1), (2, 3))
    for first_arg, second_arg in invalid_arg_pairs:
        with self.assertRaises(ValueError):
            shampoo.PartitionConfig(first_arg, second_arg)
def testPartitionTensor(self):
    """Check the shapes of the pieces partition_tensor yields for a 255x255 variable.

    With PartitionConfig(200, 128) the test expects four pieces, in this
    order: 128x128, 127x128, 128x127 and 127x127.
    """
    variable = tf.Variable(np.ones((255, 255)), dtype=tf.float32)
    config = shampoo.PartitionConfig(200, 128)
    pieces = shampoo.partition_tensor(variable, config)

    expected_shapes = [[128, 128], [127, 128], [128, 127], [127, 127]]
    self.assertEqual([piece.get_shape() for piece in pieces], expected_shapes)
def testPartitionMetadata(self):
    """Check the metadata partition_metadata reports for a 255x255 variable.

    With PartitionConfig(200, 128) the test expects split sizes of
    [128, 127] for each of the two dimensions, and two splits per dimension.
    """
    variable = tf.Variable(np.ones((255, 255)), dtype=tf.float32)
    config = shampoo.PartitionConfig(200, 128)
    meta = shampoo.partition_metadata(variable, config)

    expected_split_sizes = [[128, 127], [128, 127]]
    expected_num_splits = [2, 2]
    self.assertAllEqual(meta.split_sizes_per_dim, expected_split_sizes)
    self.assertAllEqual(meta.num_splits_per_dim, expected_num_splits)
def testTensorPartitioner(self):
    """Check that partitioning a variable and reforming it recovers the input.

    A 255x255 all-ones variable is partitioned with PartitionConfig(200, 128),
    reassembled with reform_tensor, and compared against a constant built
    from the same all-ones array.
    """
    ones = np.ones((255, 255))
    variable = tf.Variable(ones, dtype=tf.float32)
    config = shampoo.PartitionConfig(200, 128)
    expected = tf.constant(ones)

    meta = shampoo.partition_metadata(variable, config)
    pieces = shampoo.partition_tensor(variable, config)
    reassembled = shampoo.reform_tensor(pieces, meta.num_splits_per_dim)

    # NOTE(review): every entry is 1.0, so this comparison verifies shape and
    # values but would not notice pieces being reassembled in the wrong
    # order; consider distinct values if that coverage is wanted.
    self.assertAllCloseAccordingToType(reassembled, expected)