示例#1
0
    def build(self, input_shape):
        """Create the optional scale (`gamma`) and offset (`beta`) weights.

        Each weight's shape holds one entry per axis in `self.axis`, taken
        from the corresponding dimension of `input_shape`. Also records
        whether the fused path can be used for this input rank, and marks
        the layer as built.

        Args:
          input_shape: shape of the layer input; anything that
            `tf.TensorShape` accepts.
        """
        # Resolve `self.axis` against the input shape via the project helper
        # (presumably normalizes negative axes into a list — confirm).
        self.axis = tf_utils.validate_axis(self.axis, input_shape)
        shape = tf.TensorShape(input_shape)

        # Weight shape: the size of every axis being normalized over.
        weight_shape = [shape[d] for d in self.axis]

        def _param(label, init, reg, cons):
            # All learnable parameters here are trainable and opt out of
            # autocasting.
            return self.add_weight(name=label,
                                   shape=weight_shape,
                                   initializer=init,
                                   regularizer=reg,
                                   constraint=cons,
                                   trainable=True,
                                   experimental_autocast=False)

        # gamma is created before beta so the layer's weight order is stable.
        self.gamma = (_param('gamma', self.gamma_initializer,
                             self.gamma_regularizer, self.gamma_constraint)
                      if self.scale else None)
        self.beta = (_param('beta', self.beta_initializer,
                            self.beta_regularizer, self.beta_constraint)
                     if self.center else None)

        self._fused = self._fused_can_be_used(shape.rank)
        self.built = True
示例#2
0
 def build(self, input_shape):
   self.axis = tf_utils.validate_axis(self.axis, input_shape)