def test_partition(self):
    """Exercise the partition-related call arguments of the V2 initializers.

    Initializers in the first group are called with a (2, 2) partition of
    a (4, 2) variable and must return a value whose shape is (2, 2).
    Orthogonal and Identity are expected to reject the same arguments by
    raising ``ValueError``.
    """
    # Same call arguments are used for every initializer below.
    partition_kwargs = {
        "shape": (4, 2),
        "partition_shape": (2, 2),
        "partition_offset": (0, 0),
    }

    with self.cached_session():
        supported = (
            initializers.ZerosV2(),
            initializers.OnesV2(),
            initializers.RandomUniformV2(),
            initializers.RandomNormalV2(),
            initializers.TruncatedNormalV2(),
            initializers.LecunUniformV2(),
            initializers.GlorotUniformV2(),
            initializers.HeUniformV2(),
        )
        for init in supported:
            result = init(**partition_kwargs)
            # The result must have the partition's shape, not the full shape.
            self.assertEqual(result.shape, (2, 2))

        unsupported = (
            initializers.OrthogonalV2(),
            initializers.IdentityV2(),
        )
        for init in unsupported:
            with self.assertRaisesRegex(
                ValueError,
                "initializer doesn't support partition-related arguments",
            ):
                init(**partition_kwargs)
def test_orthogonal(self):
    """Run the shared `_runner` check on a seeded Orthogonal initializer.

    Uses a square (20, 20) shape and passes ``target_mean=0.``.
    NOTE(review): `_runner` is defined outside this block; presumably it
    compares statistics of the generated values to the targets — confirm.
    """
    with self.cached_session():
        orthogonal_init = initializers.OrthogonalV2(seed=123)
        square_shape = (20, 20)
        self._runner(orthogonal_init, square_shape, target_mean=0.)