def test_partition(self):
    """Check which V2 initializers accept partition-related arguments.

    Every initializer is asked for a full (4, 2) variable, restricted to a
    (2, 2) partition at offset (0, 0). Initializers that support partitioning
    must return a tensor of the partition's shape. Orthogonal and Identity
    must instead raise a ValueError whose message matches the expected text.
    """
    # The same partition request is sent to every initializer under test.
    partition_kwargs = dict(
        shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
    with self.cached_session():
        supported = (
            initializers.ZerosV2(),
            initializers.OnesV2(),
            initializers.RandomUniformV2(),
            initializers.RandomNormalV2(),
            initializers.TruncatedNormalV2(),
            initializers.LecunUniformV2(),
            initializers.GlorotUniformV2(),
            initializers.HeUniformV2(),
        )
        for init_fn in supported:
            # Only the requested partition is materialized, not the full shape.
            partition_value = init_fn(**partition_kwargs)
            self.assertEqual(partition_value.shape, (2, 2))

        unsupported = (
            initializers.OrthogonalV2(),
            initializers.IdentityV2(),
        )
        for init_fn in unsupported:
            with self.assertRaisesRegex(
                ValueError,
                "initializer doesn't support partition-related arguments"):
                init_fn(**partition_kwargs)
def test_zero(self):
    """Run the shared `_runner` check on a ZerosV2 initializer.

    A (4, 5) tensor is requested with target_mean=0. and target_max=0.
    NOTE(review): `_runner` is defined elsewhere in the class. It presumably
    compares statistics of the generated tensor against these targets;
    confirm against its definition.
    """
    shape = (4, 5)
    with self.cached_session():
        zeros_init = initializers.ZerosV2()
        self._runner(zeros_init, shape, target_mean=0., target_max=0.)