def test_identity(self):
        """Exercise IdentityV2 on a rank-3 shape and on a square 2-D shape.

        The rank-3 shape is expected to raise ValueError. For the square
        shape, the runner is asked to check a mean of 1/rows and a maximum
        of 1.
        """

        def run_identity(shape):
            # Build a fresh initializer for every call, one per scenario.
            self._runner(initializers.IdentityV2(),
                         shape,
                         target_mean=1. / shape[0],
                         target_max=1.)

        with self.cached_session():
            # A non-2-D shape is expected to be rejected by IdentityV2.
            with self.assertRaises(ValueError):
                run_identity((3, 4, 5))

            run_identity((3, 3))

    def test_partition(self):
        """Check which initializers accept partition-related arguments.

        Initializers in the "enabled" list must return only the requested
        partition: a (2, 2) slice of the full (4, 2) shape. Initializers in
        the "forbidden" list must raise ValueError whose message mentions
        that partition-related arguments are unsupported.
        """
        with self.cached_session():
            partition_enabled_initializers = [
                initializers.ZerosV2(),
                initializers.OnesV2(),
                initializers.RandomUniformV2(),
                initializers.RandomNormalV2(),
                initializers.TruncatedNormalV2(),
                initializers.LecunUniformV2(),
                initializers.GlorotUniformV2(),
                initializers.HeUniformV2(),
            ]
            for initializer in partition_enabled_initializers:
                # subTest reports which initializer failed and keeps checking
                # the remaining ones instead of stopping at the first failure.
                with self.subTest(initializer=type(initializer).__name__):
                    got = initializer(shape=(4, 2),
                                      partition_shape=(2, 2),
                                      partition_offset=(0, 0))
                    self.assertEqual(got.shape, (2, 2))

            partition_forbidden_initializers = [
                initializers.OrthogonalV2(),
                initializers.IdentityV2(),
            ]
            for initializer in partition_forbidden_initializers:
                with self.subTest(initializer=type(initializer).__name__):
                    with self.assertRaisesRegex(
                            ValueError,
                            "initializer doesn't support partition-related arguments"
                    ):
                        initializer(shape=(4, 2),
                                    partition_shape=(2, 2),
                                    partition_offset=(0, 0))