def create_initializer(init_type, scale=None, fillvalue=None):
    """Build an initializer object from its string name.

    Args:
        init_type (str): Name of the initializer. One of ``'identity'``,
            ``'constant'``, ``'zero'``, ``'one'``, ``'normal'``,
            ``'glorotNormal'``, ``'heNormal'``, ``'orthogonal'``,
            ``'uniform'``, ``'leCunUniform'``, ``'glorotUniform'`` or
            ``'heUniform'``.
        scale: Optional scale for the scaled initializers (every name above
            except ``'constant'``, ``'zero'`` and ``'one'``). When ``None``
            the initializer is constructed without a scale argument, so its
            own default applies.
        fillvalue: Fill value passed to ``initializers.Constant``; only used
            when ``init_type == 'constant'``.

    Returns:
        The constructed initializer instance.

    Raises:
        ValueError: If ``init_type`` is not a recognised name.
    """
    if init_type == 'constant':
        return initializers.Constant(fillvalue)
    if init_type == 'zero':
        return initializers.Zero()
    if init_type == 'one':
        return initializers.One()

    # Initializers that accept a ``scale`` argument, keyed by public name.
    # Class names are resolved lazily with getattr so only the requested
    # class has to exist on ``initializers``.
    scaled_initializers = {
        'identity': 'Identity',
        'normal': 'Normal',
        'glorotNormal': 'GlorotNormal',
        'heNormal': 'HeNormal',
        'orthogonal': 'Orthogonal',
        'uniform': 'Uniform',
        'leCunUniform': 'LeCunUniform',
        'glorotUniform': 'GlorotUniform',
        'heUniform': 'HeUniform',
    }
    class_name = scaled_initializers.get(init_type)
    if class_name is None:
        raise ValueError("Unknown initializer type: {0}".format(init_type))

    init_cls = getattr(initializers, class_name)
    # Forward ``scale`` only when given. Passing ``None`` explicitly (as
    # several branches used to) overrides the class default with an invalid
    # value instead of falling back to it.
    if scale is None:
        return init_cls()
    return init_cls(scale=scale)
def check_shaped_initializer(self, xp):
    """Check ``generate_array`` with a scaled, typed ``Identity`` initializer.

    Generates an array of ``self.shape`` and ``self.dtype`` with array module
    ``xp``. Then checks that the result belongs to ``xp``, has the requested
    shape and dtype, and equals ``self.scale`` times the identity matrix.

    Args:
        xp: Array module the array is generated with; the result must
            belong to it.
    """
    initializer = initializers.Identity(scale=self.scale, dtype=self.dtype)
    w = initializers.generate_array(initializer, self.shape, xp)
    self.assertIs(backend.get_array_module(w), xp)
    self.assertTupleEqual(w.shape, self.shape)
    self.assertEqual(w.dtype, self.dtype)
    # The expected identity must match the array's side length. The previous
    # ``len(self.shape)`` is the rank (always 2), which only coincides with
    # the side length for a (2, 2) shape.
    testing.assert_allclose(
        w, self.scale * numpy.identity(self.shape[0]),
        **self.check_options)
def setUp(self):
    """Create an ``Identity`` initializer with default arguments for the tests."""
    default_identity = initializers.Identity()
    self.initializer = default_identity
def setUp(self):
    """Prepare a scaled ``Identity`` initializer and a 2x2 float32 buffer."""
    scale = 0.1
    shape = (2, 2)
    self.scale = scale
    self.shape = shape
    self.initializer = initializers.Identity(scale=scale)
    # Uninitialised buffer: its contents are arbitrary until written.
    self.w = numpy.empty(shape, dtype=numpy.float32)
def check_initializer(self, w):
    """Apply a scaled ``Identity`` initializer to ``w`` and check the result.

    The initializer is called on ``w`` directly, so ``w`` is filled in place
    and then compared against ``self.scale`` times the identity matrix.

    Args:
        w: Pre-allocated square array that the initializer fills.
    """
    initializer = initializers.Identity(scale=self.scale)
    initializer(w)
    # Size the expected identity by the side length, not by the rank
    # (``len(self.shape)`` is always 2 and only matches a (2, 2) shape).
    testing.assert_allclose(
        w, self.scale * numpy.identity(self.shape[0]),
        **self.check_options)