Exemplo n.º 1
0
  def test_partition(self):
    """Checks how V2 initializers react to partition-related call arguments.

    Every initializer in the first group is called with a (4, 2) shape and a
    (2, 2) partition at offset (0, 0), and the returned value must have shape
    (2, 2). Every initializer in the second group must raise a ValueError for
    the very same call.
    """
    # The same arguments are used for every initializer call below.
    partition_kwargs = dict(
        shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))

    with self.cached_session():
      supported = (
          initializers.ZerosV2(),
          initializers.OnesV2(),
          initializers.RandomUniformV2(),
          initializers.RandomNormalV2(),
          initializers.TruncatedNormalV2(),
          initializers.LecunUniformV2(),
          initializers.GlorotUniformV2(),
          initializers.HeUniformV2(),
      )
      for init_fn in supported:
        result = init_fn(**partition_kwargs)
        self.assertEqual(result.shape, (2, 2))

      unsupported = (
          initializers.OrthogonalV2(),
          initializers.IdentityV2(),
      )
      error_regex = "initializer doesn't support partition-related arguments"
      for init_fn in unsupported:
        with self.assertRaisesRegex(ValueError, error_regex):
          init_fn(**partition_kwargs)
Exemplo n.º 2
0
 def test_uniform(self):
     """Passes a seeded RandomUniformV2 (minval=-1, maxval=1) to `_runner`.

     The initializer is built inside the cached session and handed to
     `self._runner` together with a (3, 2, 3) shape.
     """
     shape = (3, 2, 3)
     with self.cached_session():
         uniform_init = initializers.RandomUniformV2(
             minval=-1, maxval=1, seed=124)
         self._runner(uniform_init, shape)
Exemplo n.º 3
0
 def test_uniform(self):
     """Passes a seeded RandomUniformV2 (minval=-1, maxval=1) to `_runner`.

     The initializer is built inside the cached session and handed to
     `self._runner` together with a (9, 6, 7) shape and the target
     mean/max/min keyword arguments (0., 1 and -1).
     """
     shape = (9, 6, 7)
     # Keyword arguments forwarded to `_runner`, kept in the original order.
     targets = {"target_mean": 0., "target_max": 1, "target_min": -1}
     with self.cached_session():
         uniform_init = initializers.RandomUniformV2(
             minval=-1, maxval=1, seed=124)
         self._runner(uniform_init, shape, **targets)