Exemplo n.º 1
0
    def test_partition(self):
        """Exercise the partition-related call arguments of the V2 initializers.

        Each initializer in the first group is called with a full shape of
        (4, 2) and a (2, 2) partition at offset (0, 0); the result is expected
        to have the partition's shape. Each initializer in the second group is
        expected to raise ``ValueError`` when given the same arguments.
        """
        # The same call arguments are used for every initializer under test.
        partition_kwargs = dict(shape=(4, 2),
                                partition_shape=(2, 2),
                                partition_offset=(0, 0))
        expected_shape = (2, 2)
        unsupported_msg = (
            "initializer doesn't support partition-related arguments")

        with self.cached_session():
            supported = (
                initializers.ZerosV2(),
                initializers.OnesV2(),
                initializers.RandomUniformV2(),
                initializers.RandomNormalV2(),
                initializers.TruncatedNormalV2(),
                initializers.LecunUniformV2(),
                initializers.GlorotUniformV2(),
                initializers.HeUniformV2(),
            )
            for init in supported:
                result = init(**partition_kwargs)
                self.assertEqual(result.shape, expected_shape)

            unsupported = (
                initializers.OrthogonalV2(),
                initializers.IdentityV2(),
            )
            for init in unsupported:
                with self.assertRaisesRegex(ValueError, unsupported_msg):
                    init(**partition_kwargs)
Exemplo n.º 2
0
 def test_zero(self):
     """Run the shared ``_runner`` check on ``ZerosV2`` for a (4, 5) tensor.

     Both the target mean and the target max passed to ``_runner`` are 0.
     """
     shape = (4, 5)
     with self.cached_session():
         zeros_init = initializers.ZerosV2()
         self._runner(zeros_init, shape, target_mean=0., target_max=0.)