def test_partition(self):
        """Exercise the partition-related arguments of the V2 initializers.

        Every initializer in the first group is called with a full shape of
        (4, 2) and a (2, 2) partition at offset (0, 0); the shape of the
        result must equal the partition shape. Every initializer in the
        second group must raise a ValueError whose message says partition
        arguments are not supported.
        """
        # The same partition request is sent to every initializer under test.
        partition_kwargs = {
            "shape": (4, 2),
            "partition_shape": (2, 2),
            "partition_offset": (0, 0),
        }
        unsupported_msg = (
            "initializer doesn't support partition-related arguments")

        with self.cached_session():
            supports_partition = (
                initializers.ZerosV2(),
                initializers.OnesV2(),
                initializers.RandomUniformV2(),
                initializers.RandomNormalV2(),
                initializers.TruncatedNormalV2(),
                initializers.LecunUniformV2(),
                initializers.GlorotUniformV2(),
                initializers.HeUniformV2(),
            )
            for init in supports_partition:
                result = init(**partition_kwargs)
                self.assertEqual(result.shape, (2, 2))

            rejects_partition = (
                initializers.OrthogonalV2(),
                initializers.IdentityV2(),
            )
            for init in rejects_partition:
                with self.assertRaisesRegex(ValueError, unsupported_msg):
                    init(**partition_kwargs)
Example #2
0
 def test_uniform(self):
   """Run the shared `_runner` check on a seeded RandomUniformV2.

   The initializer is built with minval=-1, maxval=1 and seed=124 and is
   handed to `self._runner` together with a (9, 6, 7) shape and the
   targets mean=0, max=1 and min=-1.
   """
   shape = (9, 6, 7)
   with self.cached_session():
     uniform_init = initializers.RandomUniformV2(
         minval=-1, maxval=1, seed=124)
     self._runner(
         uniform_init,
         shape,
         target_mean=0.,
         target_max=1,
         target_min=-1)