Ejemplo n.º 1
0
  def test_partition(self):
    """Checks which V2 initializers accept partition-related arguments.

    Every initializer in the "enabled" list is called for the (2, 2) partition
    at offset (0, 0) of a (4, 2) shape, and the returned value must have shape
    (2, 2). Every initializer in the "forbidden" list must reject the same call
    with a ValueError whose message matches the expected text.
    """
    # One set of call arguments shared by both halves of the test: request the
    # (2, 2) slice at offset (0, 0) of a (4, 2) shape.
    partition_kwargs = dict(
        shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
    with self.cached_session():
      partition_enabled_initializers = [
          initializers.ZerosV2(),
          initializers.OnesV2(),
          initializers.RandomUniformV2(),
          initializers.RandomNormalV2(),
          initializers.TruncatedNormalV2(),
          initializers.LecunUniformV2(),
          initializers.GlorotUniformV2(),
          initializers.HeUniformV2()
      ]
      for initializer in partition_enabled_initializers:
        # subTest names the offending initializer in the failure report and
        # lets the remaining initializers still be checked after a failure.
        with self.subTest(initializer=type(initializer).__name__):
          got = initializer(**partition_kwargs)
          self.assertEqual(got.shape, (2, 2))

      partition_forbidden_initializers = [
          initializers.OrthogonalV2(),
          initializers.IdentityV2()
      ]
      for initializer in partition_forbidden_initializers:
        with self.subTest(initializer=type(initializer).__name__):
          with self.assertRaisesRegex(
              ValueError,
              "initializer doesn't support partition-related arguments"):
            initializer(**partition_kwargs)
Ejemplo n.º 2
0
 def test_truncated_normal(self):
     """Runs the shared `_runner` check on a seeded TruncatedNormalV2.

     The initializer is configured with mean 0, stddev 1 and seed 126, and is
     exercised on the rank-3 shape (12, 99, 7). What `_runner` verifies is
     defined outside this view.
     """
     shape = (12, 99, 7)
     with self.cached_session():
         # Construct the initializer inside the session context, as before.
         seeded_truncated = initializers.TruncatedNormalV2(
             mean=0, stddev=1, seed=126)
         self._runner(seeded_truncated, shape)