Example #1
0
    def test_identity(self):
        """Exercise ``IdentityV2`` through the shared ``_runner`` helper.

        A rank-3 shape is expected to make the runner raise ``ValueError``;
        a square 2-D shape is expected to run without error.
        """
        with self.cached_session():
            # The test expects ValueError for this non-2-D shape.
            rank3_shape = (3, 4, 5)
            with self.assertRaises(ValueError):
                self._runner(initializers.IdentityV2(), rank3_shape)

            # A square 2-D shape is expected to be accepted.
            square_shape = (3, 3)
            self._runner(initializers.IdentityV2(), square_shape)
  def test_partition(self):
    """Check how initializers handle partition-related keyword arguments.

    Initializers in the "supported" group must return a result whose shape
    equals ``partition_shape`` (2, 2). Initializers in the "rejected" group
    must raise ``ValueError`` with a message about partition arguments.
    """
    with self.cached_session():
      # Every initializer below is called with these same keyword arguments.
      call_kwargs = {
          'shape': (4, 2),
          'partition_shape': (2, 2),
          'partition_offset': (0, 0),
      }

      supports_partition = (
          initializers.ZerosV2(),
          initializers.OnesV2(),
          initializers.RandomUniformV2(),
          initializers.RandomNormalV2(),
          initializers.TruncatedNormalV2(),
          initializers.LecunUniformV2(),
          initializers.GlorotUniformV2(),
          initializers.HeUniformV2(),
      )
      for init_fn in supports_partition:
        result = init_fn(**call_kwargs)
        self.assertEqual(result.shape, (2, 2))

      rejects_partition = (
          initializers.OrthogonalV2(),
          initializers.IdentityV2(),
      )
      expected_msg = "initializer doesn't support partition-related arguments"
      for init_fn in rejects_partition:
        with self.assertRaisesRegex(ValueError, expected_msg):
          init_fn(**call_kwargs)
Example #3
0
    def test_identity(self):
        """Run ``IdentityV2`` through ``_runner`` on an invalid and a valid shape.

        A rank-3 shape is expected to raise ``ValueError``. A square 2-D shape
        is expected to pass the runner's mean/max statistics checks.
        """

        def run_identity(shape):
            # For a square n x n identity matrix there are n ones among n*n
            # entries, so the expected mean is 1/n and the expected max is 1.
            self._runner(
                initializers.IdentityV2(),
                shape,
                target_mean=1.0 / shape[0],
                target_max=1.0,
            )

        with self.cached_session():
            with self.assertRaises(ValueError):
                run_identity((3, 4, 5))

            run_identity((3, 3))