def test_partition(self):
    """Check that partition arguments are honoured or rejected per initializer.

    Every initializer in `partition_enabled_initializers` must accept
    `partition_shape`/`partition_offset` and return a tensor whose shape is the
    requested partition shape. Every initializer in
    `partition_forbidden_initializers` must raise ValueError whose message
    matches "initializer doesn't support partition-related arguments".
    """
    full_shape = (4, 2)
    partition_shape = (2, 2)
    partition_offset = (0, 0)
    with self.cached_session():
      partition_enabled_initializers = [
          initializers.ZerosV2(),
          initializers.OnesV2(),
          initializers.RandomUniformV2(),
          initializers.RandomNormalV2(),
          initializers.TruncatedNormalV2(),
          initializers.LecunUniformV2(),
          initializers.GlorotUniformV2(),
          initializers.HeUniformV2()
      ]
      for initializer in partition_enabled_initializers:
        # subTest reports which initializer failed and keeps checking the
        # rest, instead of aborting the loop at the first anonymous failure.
        with self.subTest(initializer=type(initializer).__name__):
          got = initializer(
              shape=full_shape,
              partition_shape=partition_shape,
              partition_offset=partition_offset)
          self.assertEqual(got.shape, partition_shape)

      partition_forbidden_initializers = [
          initializers.OrthogonalV2(),
          initializers.IdentityV2()
      ]
      for initializer in partition_forbidden_initializers:
        with self.subTest(initializer=type(initializer).__name__):
          with self.assertRaisesRegex(
              ValueError,
              "initializer doesn't support partition-related arguments"):
            initializer(
                shape=full_shape,
                partition_shape=partition_shape,
                partition_offset=partition_offset)
Example #2
0
 def test_he_uniform(self):
     """HeUniformV2(seed=123) should sample with mean 0 and std sqrt(2 / fan_in).

     The fan-in comes from `_compute_fans` on a 4-D shape, and the sampled
     statistics are checked by the shared `self._runner` helper.
     """
     shape = (5, 6, 4, 2)
     with self.cached_session():
         fan_in, _fan_out = _compute_fans(shape)
         # Target std asserted for He-uniform: sqrt(2 / fan_in).
         expected_std = np.sqrt(2. / fan_in)
         he_init = initializers.HeUniformV2(seed=123)
         self._runner(
             he_init,
             shape,
             target_mean=0.,
             target_std=expected_std,
         )
Example #3
0
 def test_he_uniform(self):
     """Run HeUniformV2(seed=123) on a 4-D shape through `self._runner`.

     `_runner` is not visible here; presumably it samples the initializer and
     checks the output. Confirm against the helper's definition.
     """
     shape = (5, 6, 4, 2)
     with self.cached_session():
         he_init = initializers.HeUniformV2(seed=123)
         self._runner(he_init, shape)