Пример #1
0
 def __init__(
     self,
     units,
     kernel_initializer=init.GlorotUniformInitializer(),
     bias_initializer=init.RandomNormalInitializer(1e-6),
 ):
   """Records the configuration of this layer.

   Args:
     units: Stored as `_units` (presumably the output width of the layer --
       TODO confirm against the code that consumes it).
     kernel_initializer: Stored as `_kernel_initializer`; defaults to
       `init.GlorotUniformInitializer()`.
     bias_initializer: Stored as `_bias_initializer`; defaults to
       `init.RandomNormalInitializer(1e-6)`.
   """
   super(Dense, self).__init__()
   # Initializers are only stored here; nothing is invoked at construction.
   self._bias_initializer = bias_initializer
   self._kernel_initializer = kernel_initializer
   self._units = units
Пример #2
0
 def __init__(self,
              filters,
              kernel_width=3,
              kernel_initializer=None,
              bias_initializer=init.RandomNormalInitializer(1e-6)):
     """Configures the parent layer for a single-spatial-axis kernel.

     Everything is delegated to the parent constructor, with the kernel size,
     strides, padding and layout fixed by this subclass.

     Args:
       filters: Forwarded unchanged as `filters`.
       kernel_width: Wrapped into a 1-tuple and forwarded as `kernel_size`.
       kernel_initializer: Forwarded unchanged; `None` is passed through for
         the parent to handle.
       bias_initializer: Forwarded unchanged; defaults to
         `init.RandomNormalInitializer(1e-6)`.
     """
     kernel_size = (kernel_width,)
     layout = ('NWC', 'WIO', 'NWC')  # (input, kernel, output) specs
     super(CausalConv, self).__init__(
         filters=filters,
         kernel_size=kernel_size,
         strides=None,
         padding='VALID',
         dimension_numbers=layout,
         kernel_initializer=kernel_initializer,
         bias_initializer=bias_initializer,
     )
Пример #3
0
 def __init__(self, filters, kernel_size, strides=None, padding='VALID',
              dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
              kernel_initializer=None,
              bias_initializer=init.RandomNormalInitializer(1e-6)):
   """Records the convolution configuration on the instance.

   Args:
     filters: Stored as `_filters`.
     kernel_size: Sized sequence; its length determines how many entries the
       default strides have.
     strides: Stored as `_strides`; any falsy value (including `None`) is
       replaced by a tuple of ones with `len(kernel_size)` entries.
     padding: Stored as `_padding`.
     dimension_numbers: Triple unpacked into the lhs, rhs and out specs.
     kernel_initializer: Stored as `_kernel_initializer`. When `None`, an
       `init.GlorotNormalInitializer` is built from the positions of 'O' and
       'I' in the rhs spec.
     bias_initializer: Stored as `_bias_initializer`.
   """
   super(Conv, self).__init__()

   # Derived values first, in the same order the checks can fail.
   lhs_spec, rhs_spec, out_spec = dimension_numbers
   unit_strides = (1,) * len(kernel_size)
   effective_strides = strides if strides else unit_strides
   if kernel_initializer is None:
     # Locate the output/input channel axes from the kernel layout string.
     kernel_initializer = init.GlorotNormalInitializer(
         rhs_spec.index('O'), rhs_spec.index('I'))

   self._filters = filters
   self._kernel_size = kernel_size
   self._padding = padding
   self._dimension_numbers = dimension_numbers
   self._lhs_spec = lhs_spec
   self._rhs_spec = rhs_spec
   self._out_spec = out_spec
   self._one = unit_strides
   self._strides = effective_strides
   self._bias_initializer = bias_initializer
   self._kernel_initializer = kernel_initializer
Пример #4
0
 def __init__(self, initializer=init.RandomNormalInitializer(0.01)):
     """Stores the initializer on the instance.

     Args:
       initializer: Kept as `_initializer`; defaults to
         `init.RandomNormalInitializer(0.01)`. It is only stored here, not
         invoked.
     """
     super(ShiftRightLearned, self).__init__()
     self._initializer = initializer
Пример #5
0
 def test_random_normal(self):
     """Checks that the default random-normal initializer keeps the shape."""
     make_values = initializers.RandomNormalInitializer()
     expected_shape = (29, 5, 7, 20)
     rng = random.get_prng(0)
     values = make_values(expected_shape, rng)
     # Compare as a tuple so the shape's container type does not matter.
     self.assertEqual(tuple(values.shape), expected_shape)