Code example #1
 def normalize_in_training():
     if needs_broadcasting:
         return nn.batch_normalization(inputs, broadcast_mean,
                                       broadcast_variance,
                                       broadcast_beta, broadcast_gamma,
                                       self.epsilon)
     else:
         return nn.batch_normalization(
             inputs, mean, variance, self.beta if self.center else None,
             self.gamma if self.scale else None, self.epsilon)
Code example #2
 def normalize_in_test():
     if needs_broadcasting:
         broadcast_moving_mean = array_ops.reshape(
             self.moving_mean, broadcast_shape)
         broadcast_moving_variance = array_ops.reshape(
             self.moving_variance, broadcast_shape)
         return nn.batch_normalization(inputs, broadcast_moving_mean,
                                       broadcast_moving_variance,
                                       broadcast_beta, broadcast_gamma,
                                       self.epsilon)
     else:
         return nn.batch_normalization(
             inputs, self.moving_mean, self.moving_variance,
             self.beta if self.center else None,
             self.gamma if self.scale else None, self.epsilon)
Code example #3
 def normalize_in_training():
   if needs_broadcasting:
     return nn.batch_normalization(inputs,
                                   broadcast_mean,
                                   broadcast_variance,
                                   broadcast_beta,
                                   broadcast_gamma,
                                   self.epsilon)
   else:
     return nn.batch_normalization(inputs,
                                   mean,
                                   variance,
                                   self.beta if self.center else None,
                                   self.gamma if self.scale else None,
                                   self.epsilon)
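Examples #1 through #3 are closures lifted from the `BatchNormalization` layer in `tf.layers`, so names such as `inputs`, `needs_broadcasting`, and `broadcast_mean` come from the enclosing `call` method. For reference, here is a minimal self-contained sketch (assuming TensorFlow 2.x, where `tf.nn.batch_normalization` keeps the same signature) that checks the formula the op implements, y = (x - mean) / sqrt(var + eps) * gamma + beta:

import tensorflow as tf

x = tf.random.normal([4, 3])
mean, var = tf.nn.moments(x, axes=[0])   # per-feature batch statistics
beta = tf.zeros([3])                     # offset
gamma = tf.ones([3])                     # scale
eps = 1e-3

y = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
manual = (x - mean) / tf.sqrt(var + eps) * gamma + beta
# y and manual agree to floating-point precision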
Code example #4
    def normalize(self, inputs):
        """Apply normalization to input.

        The shape must match the declared shape in the constructor.
        [This is copied from tf.contrib.rnn.LayerNormBasicLSTMCell.]

        Args:
          inputs: Input tensor.

        Returns:
          Normalized version of input tensor.

        Raises:
          ValueError: if inputs has undefined rank.
        """
        inputs_shape = inputs.get_shape()
        inputs_rank = inputs_shape.ndims
        if inputs_rank is None:
            raise ValueError('Inputs %s has undefined rank.' % inputs.name)
        axis = list(range(1, inputs_rank))  # all non-batch axes

        beta = self._component.get_variable('beta_%s' % self._name)
        gamma = self._component.get_variable('gamma_%s' % self._name)

        with tf.variable_scope('layer_norm_%s' % self._name):
            # Calculate the moments on the last axis (layer activations).
            mean, variance = nn.moments(inputs, axis, keep_dims=True)

            # Compute layer normalization using the batch_normalization function.
            variance_epsilon = 1E-12
            outputs = nn.batch_normalization(inputs, mean, variance, beta,
                                             gamma, variance_epsilon)
            outputs.set_shape(inputs_shape)
            return outputs
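Example #4 computes layer normalization by taking moments over every non-batch axis and feeding them to `nn.batch_normalization`. A condensed, self-contained sketch of the same pattern (assuming TensorFlow 2.x, where the keyword is `keepdims` rather than `keep_dims`):

import tensorflow as tf

def layer_norm(x, gamma, beta, epsilon=1e-12):
    # Normalize over every axis except the batch axis.
    axes = list(range(1, x.shape.rank))
    mean, variance = tf.nn.moments(x, axes, keepdims=True)
    return tf.nn.batch_normalization(x, mean, variance, beta, gamma, epsilon)

x = tf.random.normal([2, 5])
out = layer_norm(x, gamma=tf.ones([5]), beta=tf.zeros([5]))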
Code example #5
  def __call__(self, inputs):
    """Run virtual batch normalization on inputs.

    Args:
      inputs: Tensor input.

    Returns:
       A virtual batch normalized version of `inputs`.

    Raises:
       ValueError: If `inputs` shape isn't compatible with the reference batch.
    """
    _validate_call_input([inputs, self._reference_batch], self._batch_axis)

    with ops.name_scope(self._vs.name, values=[inputs, self._reference_batch]):
      # Calculate the statistics on the current input on a per-example basis.
      vb_mean, vb_mean_sq = self._virtual_statistics(
          inputs, self._example_reduction_axes)
      vb_variance = vb_mean_sq - math_ops.square(vb_mean)

      # The exact broadcast shape of the input statistic Tensors depends on the
      # current batch, not the reference batch. The parameter broadcast shape
      # is independent of the shape of the input statistic Tensor dimensions.
      b_shape = self._broadcast_shape[:]  # deep copy
      b_shape[self._batch_axis] = _static_or_dynamic_batch_size(
          inputs, self._batch_axis)
      return nn.batch_normalization(
          inputs,
          self._broadcast(vb_mean, b_shape),
          self._broadcast(vb_variance, b_shape),
          self._broadcast(self._beta, self._broadcast_shape),
          self._broadcast(self._gamma, self._broadcast_shape),
          self._epsilon)
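`_virtual_statistics` returns a mean and a mean of squares, so the variance is recovered via the identity Var[x] = E[x^2] - (E[x])^2. A quick standalone check of that step, independent of the class above (TensorFlow 2.x assumed):

import tensorflow as tf

x = tf.random.normal([6, 4])
mean = tf.reduce_mean(x, axis=0)
mean_sq = tf.reduce_mean(tf.square(x), axis=0)
variance = mean_sq - tf.square(mean)      # E[x^2] - (E[x])^2
_, direct = tf.nn.moments(x, axes=[0])    # same value, up to rounding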
Code example #6
    def call(self, inputs):
        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)

        # Calculate the moments on the last axis (layer activations).
        mean, variance = nn.moments(inputs, self.norm_axis, keep_dims=True)

        # Broadcasting is only necessary when the params axes aren't just
        # the last dimension.
        broadcast_shape = [1] * ndims
        for dim in self.params_axis:
            broadcast_shape[dim] = input_shape.dims[dim].value

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and self.params_axis != [ndims - 1]):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        # Compute layer normalization using the batch_normalization function.
        outputs = nn.batch_normalization(inputs,
                                         mean,
                                         variance,
                                         offset=offset,
                                         scale=scale,
                                         variance_epsilon=self.epsilon)

        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        return outputs
Code example #7
File: tf_lstm.py Project: zhoukangg/TF_LSTM_seq_bn
def batch_norm(x, deterministic, alpha=0.9, shift=True, scope='bn'):
    with vs.variable_scope(scope):
        dtype = x.dtype
        input_shape = x.get_shape().as_list()
        feat_dim = input_shape[-1]
        axes = list(range(len(input_shape) - 1))

        if shift:
            beta = vs.get_variable(scope + "_beta",
                                   shape=[feat_dim],
                                   initializer=init_ops.zeros_initializer,
                                   dtype=dtype)
        else:
            beta = vs.get_variable(scope + "_beta",
                                   shape=[feat_dim],
                                   initializer=init_ops.zeros_initializer,
                                   dtype=dtype,
                                   trainable=False)

        gamma = vs.get_variable(scope + "_gamma",
                                shape=[feat_dim],
                                initializer=init_ops.constant_initializer(0.1),
                                dtype=dtype)

        mean = vs.get_variable(scope + "_mean",
                               shape=[feat_dim],
                               initializer=init_ops.zeros_initializer,
                               dtype=dtype,
                               trainable=False)

        var = vs.get_variable(scope + "_var",
                              shape=[feat_dim],
                              initializer=init_ops.ones_initializer,
                              dtype=dtype,
                              trainable=False)

        counter = vs.get_variable(scope + "_counter",
                                  shape=[],
                                  initializer=init_ops.constant_initializer(0),
                                  dtype=tf.int64,
                                  trainable=False)

        zero_cnt = vs.get_variable(
            scope + "_zero_cnt",
            shape=[],
            initializer=init_ops.constant_initializer(0),
            dtype=tf.int64,
            trainable=False)

        batch_mean, batch_var = moments(x, axes, name=scope + '_moments')

        mean, var = cond(math_ops.equal(counter, zero_cnt), lambda:
                         (batch_mean, batch_var), lambda: (mean, var))

        mean, var, counter = cond(
            deterministic, lambda: (mean, var, counter), lambda:
            ((1 - alpha) * batch_mean + alpha * mean,
             (1 - alpha) * batch_var + alpha * var, counter + 1))
        normed = batch_normalization(x, mean, var, beta, gamma, 1e-8)
    return normed
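Here the batch statistics are blended into persistent `mean`/`var` variables whenever `deterministic` is false, so later calls see a smoothed estimate. The update is a plain exponential moving average; a minimal one-step sketch with hypothetical values (TensorFlow 2.x):

import tensorflow as tf

alpha = 0.9                       # smoothing factor, as in batch_norm above
moving_mean = tf.Variable(0.0)
batch_mean = tf.constant(2.0)
# One training-time update, mirroring (1 - alpha) * batch_mean + alpha * mean:
moving_mean.assign((1.0 - alpha) * batch_mean + alpha * moving_mean)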
Code example #8
    def __call__(self, inputs):
        """Run virtual batch normalization on inputs.

    Args:
      inputs: Tensor input.

    Returns:
       A virtual batch normalized version of `inputs`.

    Raises:
       ValueError: If `inputs` shape isn't compatible with the reference batch.
    """
        _validate_call_input([inputs, self._reference_batch], self._batch_axis)

        with ops.name_scope(self._vs.name,
                            values=[inputs, self._reference_batch]):
            # Calculate the statistics on the current input on a per-example basis.
            vb_mean, vb_mean_sq = self._virtual_statistics(
                inputs, self._example_reduction_axes)
            vb_variance = vb_mean_sq - math_ops.square(vb_mean)

            # The exact broadcast shape of the input statistic Tensors depends on the
            # current batch, not the reference batch. The parameter broadcast shape
            # is independent of the shape of the input statistic Tensor dimensions.
            b_shape = self._broadcast_shape[:]  # deep copy
            b_shape[self._batch_axis] = _static_or_dynamic_batch_size(
                inputs, self._batch_axis)
            return nn.batch_normalization(
                inputs, self._broadcast(vb_mean, b_shape),
                self._broadcast(vb_variance, b_shape),
                self._broadcast(self._beta, self._broadcast_shape),
                self._broadcast(self._gamma, self._broadcast_shape),
                self._epsilon)
Code example #9
  def call(self, inputs):
    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)

    # Calculate the moments on the last axis (layer activations).
    mean, variance = nn.moments(inputs, self.axis, keep_dims=True)

    # Broadcasting is only necessary when the norm axis is not just
    # the last dimension.
    broadcast_shape = [1] * ndims
    for dim in self.axis:
      broadcast_shape[dim] = input_shape.dims[dim].value
    def _broadcast(v):
      if (v is not None and len(v.shape) != ndims and
          self.axis != [ndims - 1]):
        return array_ops.reshape(v, broadcast_shape)
      return v
    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    # Compute layer normalization using the batch_normalization function.
    outputs = nn.batch_normalization(
        inputs,
        mean,
        variance,
        offset=offset,
        scale=scale,
        variance_epsilon=self.epsilon)

    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    return outputs
Code example #10
 def reference_batch_normalization(self):
   """Return the reference batch, but batch normalized."""
   with ops.name_scope(self._vs.name):
     return nn.batch_normalization(self._reference_batch,
                                   self._broadcast(self._ref_mean),
                                   self._broadcast(self._ref_variance),
                                   self._broadcast(self._beta),
                                   self._broadcast(self._gamma), self._epsilon)
Code example #11
 def reference_batch_normalization(self):
     """Return the reference batch, but batch normalized."""
     with ops.name_scope(self._vs.name):
         return nn.batch_normalization(self._reference_batch,
                                       self._broadcast(self._ref_mean),
                                       self._broadcast(self._ref_variance),
                                       self._broadcast(self._beta),
                                       self._broadcast(self._gamma),
                                       self._epsilon)
Code example #12
def my_batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
    """Applies batch normalization on x given mean, var, beta and gamma.
    I.e. returns:
    `output = (x - mean) / sqrt(var + epsilon) * gamma + beta`
    Arguments:
        x: Input tensor or variable.
        mean: Mean of batch.
        var: Variance of batch.
        beta: Tensor with which to center the input.
        gamma: Tensor by which to scale the input.
        axis: Integer, the axis that should be normalized.
            (typically the features axis).
        epsilon: Fuzz factor.
    Returns:
        A tensor.
    """

    if K.ndim(x) == 4:
        # The CPU implementation of `fused_batch_norm` only supports NHWC.
        if axis == 1 or axis == -3:
            tf_data_format = 'NCHW'
        elif axis == 3 or axis == -1:
            tf_data_format = 'NHWC'
        else:
            tf_data_format = None

        if (tf_data_format == 'NHWC'
                or (tf_data_format == 'NCHW' and _has_nchw_support())):
            # The mean / var / beta / gamma tensors may have been broadcast,
            # so they may carry extra axes of size 1, which should be squeezed.
            if K.ndim(mean) > 1:
                mean = array_ops.reshape(mean, [-1])
            if K.ndim(var) > 1:
                var = array_ops.reshape(var, [-1])
            if beta is None:
                beta = zeros_like(mean)
            elif K.ndim(beta) > 1:
                beta = array_ops.reshape(beta, [-1])
            if gamma is None:
                gamma = ones_like(mean)
            elif K.ndim(gamma) > 1:
                gamma = array_ops.reshape(gamma, [-1])
            y, _, _ = nn.fused_batch_norm(x,
                                          gamma,
                                          beta,
                                          epsilon=epsilon,
                                          mean=mean,
                                          variance=var,
                                          data_format=tf_data_format,
                                          is_training=False)
            return y
        # Other layouts fall through to the per-example path below.

    return tf.map_fn(
        lambda xx: nn.batch_normalization(xx, mean, var, beta, gamma, epsilon),
        x)
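In the fallback branch, `tf.map_fn` normalizes one example at a time with the same shared statistics, which matches normalizing the whole tensor in one call. A small sketch of that equivalence (TensorFlow 2.x assumed):

import tensorflow as tf

x = tf.random.normal([8, 10])
mean, var = tf.nn.moments(x, axes=[0])
eps = 1e-3

per_example = tf.map_fn(
    lambda xx: tf.nn.batch_normalization(xx, mean, var, None, None, eps), x)
whole_batch = tf.nn.batch_normalization(x, mean, var, None, None, eps)
# per_example and whole_batch are equal up to floating-point error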
Code example #13
 def normalize_in_test():
   if needs_broadcasting:
     broadcast_moving_mean = array_ops.reshape(self.moving_mean,
                                               broadcast_shape)
     broadcast_moving_variance = array_ops.reshape(self.moving_variance,
                                                   broadcast_shape)
     return nn.batch_normalization(inputs,
                                   broadcast_moving_mean,
                                   broadcast_moving_variance,
                                   broadcast_beta,
                                   broadcast_gamma,
                                   self.epsilon)
   else:
     return nn.batch_normalization(inputs,
                                   self.moving_mean,
                                   self.moving_variance,
                                   self.beta if self.center else None,
                                   self.gamma if self.scale else None,
                                   self.epsilon)
Code example #14
 def instance_norm(self, inputs, inputs_latent, name):
     inputs_rank = inputs.shape.ndims
     n_outputs = int(inputs.shape[-1])
     n_batch = int(inputs.shape[0])
     inputs_latent_flatten = tf.layers.flatten(inputs_latent)
     gamma = self.MLP(inputs_latent_flatten, n_outputs, name+"_gamma")
     beta = self.MLP(inputs_latent_flatten, n_outputs, name+"_beta")
     gamma = tf.reshape(gamma, [n_batch, 1, 1, n_outputs])
     beta = tf.reshape(beta, [n_batch, 1, 1, n_outputs])
     moments_axes = [1, 2]  # spatial axes only: per-example, per-channel stats
     mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
     outputs = nn.batch_normalization(
         inputs, mean, variance, beta, gamma, 1e-6, name=name)
     return outputs
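Example #14 predicts `gamma` and `beta` from a latent code, in the spirit of adaptive instance normalization. A self-contained sketch of the idea, with hypothetical shapes and `tf.keras.layers.Dense` standing in for the `MLP` helper (NHWC, rank-4 input assumed):

import tensorflow as tf

x = tf.random.normal([2, 8, 8, 16])         # features to normalize
latent = tf.random.normal([2, 32])          # per-example conditioning code
gamma = tf.keras.layers.Dense(16)(latent)   # predicted per-channel scale
beta = tf.keras.layers.Dense(16)(latent)    # predicted per-channel offset
gamma = tf.reshape(gamma, [-1, 1, 1, 16])
beta = tf.reshape(beta, [-1, 1, 1, 16])
mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)  # per example/channel
y = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-6)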
Code example #15
File: tf_lstm.py Project: ScartleRoy/TF_LSTM_seq_bn
def batch_norm(x, deterministic, alpha=0.9, shift=True, scope='bn'):
    with vs.variable_scope(scope):
        dtype = x.dtype
        input_shape = x.get_shape().as_list()
        feat_dim = input_shape[-1]
        axes = list(range(len(input_shape) - 1))

        if shift:
            beta = vs.get_variable(scope + "_beta",
                                   shape=[feat_dim],
                                   initializer=init_ops.zeros_initializer,
                                   dtype=dtype)
        else:
            beta = vs.get_variable(scope + "_beta",
                                   shape=[feat_dim],
                                   initializer=init_ops.zeros_initializer,
                                   dtype=dtype,
                                   trainable=False)

        gamma = vs.get_variable(scope + "_gamma",
                                shape=[feat_dim],
                                initializer=init_ops.constant_initializer(0.1),
                                dtype=dtype)

        mean = vs.get_variable(scope + "_mean",
                               shape=[feat_dim],
                               initializer=init_ops.zeros_initializer,
                               dtype=dtype,
                               trainable=False)

        var = vs.get_variable(scope + "_var",
                              shape=[feat_dim],
                              initializer=init_ops.ones_initializer,
                              dtype=dtype,
                              trainable=False)

        counter = vs.get_variable(scope + "_counter",
                                  shape=[],
                                  initializer=init_ops.constant_initializer(0),
                                  dtype=tf.int64,
                                  trainable=False)

        zero_cnt = vs.get_variable(scope + "_zero_cnt",
                                   shape=[],
                                   initializer=init_ops.constant_initializer(0),
                                   dtype=tf.int64,
                                   trainable=False)

        batch_mean, batch_var = moments(x, axes, name=scope + '_moments')

        mean, var = cond(math_ops.equal(counter, zero_cnt),
                         lambda: (batch_mean, batch_var),
                         lambda: (mean, var))

        mean, var, counter = cond(
            deterministic,
            lambda: (mean, var, counter),
            lambda: ((1 - alpha) * batch_mean + alpha * mean,
                     (1 - alpha) * batch_var + alpha * var,
                     counter + 1))
        normed = batch_normalization(x, mean, var, beta, gamma, 1e-8)
    return normed
Code example #16
 def normalize_in_training():
     arg_mean = broadcast_mean if needs_broadcasting else mean
     arg_variance = broadcast_variance if needs_broadcasting else variance
     arg_beta = broadcast_beta if needs_broadcasting else (
         self.beta if self.center else None)
     arg_gamma = broadcast_gamma if needs_broadcasting else (
         self.gamma if self.scale else None)
     if self.quantizer is None:
         return nn.batch_normalization(inputs, arg_mean, arg_variance,
                                       arg_beta, arg_gamma,
                                       self.epsilon)
     else:
         return qbatch_normalization(inputs, arg_mean, arg_variance,
                                     arg_beta, arg_gamma, self.epsilon,
                                     self.quantizer)
Code example #17
 def normalize_in_test():
     if needs_broadcasting:
         broadcast_moving_mean = array_ops.reshape(
             self.moving_mean, broadcast_shape)
         broadcast_moving_variance = array_ops.reshape(
             self.moving_variance, broadcast_shape)
     arg_mean = broadcast_moving_mean if needs_broadcasting else self.moving_mean
     arg_variance = broadcast_moving_variance if needs_broadcasting else self.moving_variance
     arg_beta = broadcast_beta if needs_broadcasting else (
         self.beta if self.center else None)
     arg_gamma = broadcast_gamma if needs_broadcasting else (
         self.gamma if self.scale else None)
     if self.quantizer is None:
         return nn.batch_normalization(inputs, arg_mean, arg_variance,
                                       arg_beta, arg_gamma,
                                       self.epsilon)
     else:
         return qbatch_normalization(inputs, arg_mean, arg_variance,
                                     arg_beta, arg_gamma, self.epsilon,
                                     self.quantizer)
Code example #18
    def call(self, inputs, training=True):
        layer_inputs = inputs[0]
        mix_weights = inputs[1]
        self.assign_mixture_value(name="template_beta",
                                  mixture_weights=mix_weights)
        self.assign_mixture_value(name="template_gamma",
                                  mixture_weights=mix_weights)
        norm_input = BatchNormalization.call(self, layer_inputs, training)

        input_shape = layer_inputs.shape
        reduction_axes = [
            i for i in range(len(input_shape)) if i not in self.axis
        ]
        mean, var = nn.moments(norm_input, reduction_axes, keep_dims=True)
        scale = self._broadcast(self.template_gamma, input_shape)
        offset = self._broadcast(self.template_beta, input_shape)
        output = nn.batch_normalization(norm_input, mean, var, offset, scale,
                                        self.epsilon)
        self.reset_all_values()
        return output
Code example #19
            def my_graph(a):
                with ops.device("/device:IPU:0"):
                    with variable_scope.variable_scope("", use_resource=True):

                        beta = variable_scope.get_variable(
                            "x",
                            dtype=np.float16,
                            shape=[4],
                            initializer=init_ops.constant_initializer(0.0))
                        gamma = variable_scope.get_variable(
                            "y",
                            dtype=np.float16,
                            shape=[4],
                            initializer=init_ops.constant_initializer(1.0))

                        b_mean, b_var = nn.moments(a, [0, 1, 2],
                                                   name='moments')

                        normed = nn.batch_normalization(
                            a, b_mean, b_var, beta, gamma, 1e-3)
                        return normed
Code example #20
    def testBatchNormalizeFp16(self):
        x = array_ops.placeholder(np.float16, [4, 64, 64, 4], name="a")

        with ops.device("/device:IPU:0"):
            with variable_scope.variable_scope("", use_resource=True):

                beta = variable_scope.get_variable(
                    "x",
                    dtype=np.float16,
                    shape=[4],
                    initializer=init_ops.constant_initializer(0.0))
                gamma = variable_scope.get_variable(
                    "y",
                    dtype=np.float16,
                    shape=[4],
                    initializer=init_ops.constant_initializer(1.0))

                b_mean, b_var = nn.moments(x, [0, 1, 2], name='moments')

                normed = nn.batch_normalization(x, b_mean, b_var, beta, gamma,
                                                1e-3)

        with ops.device('cpu'):
            report = gen_ipu_ops.ipu_event_trace()

        tu.configure_ipu_system()

        with tu.ipu_session() as sess:
            sess.run(report)

            sess.run(variables.global_variables_initializer())
            result = sess.run(normed, {x: np.zeros([4, 64, 64, 4])})
            self.assertAllClose(result, np.zeros([4, 64, 64, 4]))

            rep = sess.run(report)
            s = tu.extract_all_strings_from_event_trace(rep)
            cs = tu.get_compute_sets_from_report(s)

            bl = ['*convert*/Cast*']
            self.assertTrue(tu.check_compute_sets_not_in_blacklist(cs, bl))
Code example #21
def instance_norm(inputs,
                  center=True,
                  scale=True,
                  epsilon=1e-6,
                  activation_fn=None,
                  param_initializers=None,
                  reuse=None,
                  variables_collections=None,
                  outputs_collections=None,
                  trainable=True,
                  data_format=DATA_FORMAT_NHWC,
                  scope=None):
  """Functional interface for the instance normalization layer.

  Reference: https://arxiv.org/abs/1607.08022.

    "Instance Normalization: The Missing Ingredient for Fast Stylization"
    Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky

  Args:
    inputs: A tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC` and the second dimension if `data_format` is
      `NCHW`.
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    epsilon: Small float added to variance to avoid dividing by zero.
    activation_fn: Activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: Optional initializers for beta, gamma, moving mean and
      moving variance.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional collections for the variables.
    outputs_collections: Collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    data_format: A string. `NHWC` (default) and `NCHW` are supported.
    scope: Optional scope for `variable_scope`.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
    ValueError: If the rank of `inputs` is undefined.
    ValueError: If rank or channels dimension of `inputs` is undefined.
  """
  inputs = ops.convert_to_tensor(inputs)
  inputs_shape = inputs.shape
  inputs_rank = inputs.shape.ndims

  if inputs_rank is None:
    raise ValueError('Inputs %s has undefined rank.' % inputs.name)
  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    raise ValueError('data_format has to be either NCHW or NHWC.')

  with variable_scope.variable_scope(
      scope, 'InstanceNorm', [inputs], reuse=reuse) as sc:
    if data_format == DATA_FORMAT_NCHW:
      reduction_axis = 1
      # For NCHW format, rather than relying on implicit broadcasting, we
      # explicitly reshape the params to params_shape_broadcast when computing
      # the moments and the batch normalization.
      params_shape_broadcast = list(
          [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
    else:
      reduction_axis = inputs_rank - 1
      params_shape_broadcast = None
    moments_axes = list(range(inputs_rank))
    del moments_axes[reduction_axis]
    del moments_axes[0]
    params_shape = inputs_shape[reduction_axis:reduction_axis + 1]
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined channels dimension %s.' % (
          inputs.name, params_shape))

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    dtype = inputs.dtype.base_dtype
    if param_initializers is None:
      param_initializers = {}
    if center:
      beta_collections = utils.get_variable_collections(
          variables_collections, 'beta')
      beta_initializer = param_initializers.get(
          'beta', init_ops.zeros_initializer())
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=beta_initializer,
                                      collections=beta_collections,
                                      trainable=trainable)
      if params_shape_broadcast:
        beta = array_ops.reshape(beta, params_shape_broadcast)
    if scale:
      gamma_collections = utils.get_variable_collections(
          variables_collections, 'gamma')
      gamma_initializer = param_initializers.get(
          'gamma', init_ops.ones_initializer())
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=gamma_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)
      if params_shape_broadcast:
        gamma = array_ops.reshape(gamma, params_shape_broadcast)

    # Calculate the moments (instance activations).
    mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)

    # Compute instance normalization.
    outputs = nn.batch_normalization(
        inputs, mean, variance, beta, gamma, epsilon, name='instancenorm')
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
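For NCHW the function avoids implicit broadcasting by reshaping the `[C]` parameters to `params_shape_broadcast`. A minimal sketch of that reshape (TensorFlow 2.x, hypothetical shapes):

import tensorflow as tf

x = tf.random.normal([2, 3, 8, 8])              # NCHW input
gamma = tf.reshape(tf.ones([3]), [1, 3, 1, 1])  # [C] -> [1, C, 1, 1]
beta = tf.reshape(tf.zeros([3]), [1, 3, 1, 1])
# Spatial axes for NCHW are 2 and 3; statistics are per example, per channel.
mean, var = tf.nn.moments(x, axes=[2, 3], keepdims=True)
out = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-6)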
Code example #22
File: normalization.py Project: zys-123/tensorflow
  def call(self, inputs):
    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)

    # Broadcasting is only necessary when the norm axis is not just
    # the last dimension.
    broadcast_shape = [1] * ndims
    for dim in self.axis:
      broadcast_shape[dim] = input_shape.dims[dim].value
    def _broadcast(v):
      if (v is not None and len(v.shape) != ndims and
          self.axis != [ndims - 1]):
        return array_ops.reshape(v, broadcast_shape)
      return v

    if not self._fused:
      input_dtype = inputs.dtype
      if input_dtype in ('float16', 'bfloat16') and self.dtype == 'float32':
        # If mixed precision is used, cast inputs to float32 so that this is at
        # least as numerically stable as the fused version.
        inputs = math_ops.cast(inputs, 'float32')

      # Calculate the moments on the last axis (layer activations).
      mean, variance = nn.moments(inputs, self.axis, keep_dims=True)

      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

      # Compute layer normalization using the batch_normalization function.
      outputs = nn.batch_normalization(
          inputs,
          mean,
          variance,
          offset=offset,
          scale=scale,
          variance_epsilon=self.epsilon)
      outputs = math_ops.cast(outputs, input_dtype)
    else:
      # Collapse dims before self.axis, and dims in self.axis
      pre_dim, in_dim = (1, 1)
      axis = sorted(self.axis)
      tensor_shape = array_ops.shape(inputs)
      for dim in range(0, ndims):
        dim_tensor = tensor_shape[dim]
        if dim < axis[0]:
          pre_dim = pre_dim * dim_tensor
        else:
          assert dim in axis
          in_dim = in_dim * dim_tensor

      squeezed_shape = [1, pre_dim, in_dim, 1]
      # This fused operation requires reshaped inputs to be NCHW.
      data_format = 'NCHW'

      inputs = array_ops.reshape(inputs, squeezed_shape)

      def _set_const_tensor(val, dtype, shape):
        return array_ops.fill(shape, constant_op.constant(val, dtype=dtype))

      # self.gamma and self.beta have the wrong shape for fused_batch_norm, so
      # we cannot pass them as the scale and offset parameters. Therefore, we
      # create two constant tensors in correct shapes for fused_batch_norm and
      # later construct a separate calculation on the scale and offset.
      scale = _set_const_tensor(1.0, self.dtype, [pre_dim])
      offset = _set_const_tensor(0.0, self.dtype, [pre_dim])

      # Compute layer normalization using the fused_batch_norm function.
      outputs, _, _ = nn.fused_batch_norm(
          inputs,
          scale=scale,
          offset=offset,
          epsilon=self.epsilon,
          data_format=data_format)

      outputs = array_ops.reshape(outputs, tensor_shape)

      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

      if scale is not None:
        outputs = outputs * math_ops.cast(scale, outputs.dtype)
      if offset is not None:
        outputs = outputs + math_ops.cast(offset, outputs.dtype)

    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    return outputs
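In the fused branch, `fused_batch_norm` runs with a constant scale of 1 and offset of 0, and the real `gamma`/`beta` are applied afterwards as a separate multiply-add. That split is exact, as a quick check shows (TensorFlow 2.x):

import tensorflow as tf

x = tf.random.normal([4, 10])
gamma = tf.random.normal([10])
beta = tf.random.normal([10])
mean, var = tf.nn.moments(x, axes=[1], keepdims=True)

direct = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-12)
plain = tf.nn.batch_normalization(x, mean, var, None, None, 1e-12)
recombined = plain * gamma + beta   # affine applied after the fact
# direct and recombined agree to float precision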
Code example #23
File: ops.py Project: carpedm20/NAF-tensorflow
def batch_norm(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               updates_collections=ops.GraphKeys.UPDATE_OPS,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               scope=None):
  """Code modification of tensorflow/contrib/layers/python/layers/layers.py
  """
  with variable_scope.variable_op_scope([inputs],
                                        scope, 'BatchNorm', reuse=reuse) as sc:
    inputs = ops.convert_to_tensor(inputs)
    inputs_shape = inputs.get_shape()
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    axis = list(range(inputs_rank - 1))
    params_shape = inputs_shape[-1:]
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined last dimension %s.' % (
          inputs.name, params_shape))
    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    if center:
      beta_collections = utils.get_variable_collections(variables_collections,
                                                        'beta')
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=init_ops.zeros_initializer,
                                      collections=beta_collections,
                                      trainable=trainable)
    if scale:
      gamma_collections = utils.get_variable_collections(variables_collections,
                                                         'gamma')
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=init_ops.ones_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)
    # Create moving_mean and moving_variance variables and add them to the
    # appropriate collections.
    moving_mean_collections = utils.get_variable_collections(
        variables_collections, 'moving_mean')
    moving_mean = variables.model_variable(
        'moving_mean',
        shape=params_shape,
        dtype=dtype,
        initializer=init_ops.zeros_initializer,
        trainable=False,
        collections=moving_mean_collections)
    moving_variance_collections = utils.get_variable_collections(
        variables_collections, 'moving_variance')
    moving_variance = variables.model_variable(
        'moving_variance',
        shape=params_shape,
        dtype=dtype,
        initializer=init_ops.ones_initializer,
        trainable=False,
        collections=moving_variance_collections)

    # Calculate the moments based on the individual batch.
    mean, variance = nn.moments(inputs, axis, shift=moving_mean)
    # Update the moving_mean and moving_variance moments.
    update_moving_mean = moving_averages.assign_moving_average(
        moving_mean, mean, decay)
    update_moving_variance = moving_averages.assign_moving_average(
        moving_variance, variance, decay)
    if updates_collections is None:
      # Make sure the updates are computed here.
      with ops.control_dependencies([update_moving_mean,
                                      update_moving_variance]):
        outputs = nn.batch_normalization(
            inputs, mean, variance, beta, gamma, epsilon)
    else:
      # Collect the updates to be computed later.
      ops.add_to_collections(updates_collections, update_moving_mean)
      ops.add_to_collections(updates_collections, update_moving_variance)
      outputs = nn.batch_normalization(
          inputs, mean, variance, beta, gamma, epsilon)

    test_outputs = nn.batch_normalization(
        inputs, moving_mean, moving_variance, beta, gamma, epsilon)

    outputs = tf.cond(is_training, lambda: outputs, lambda: test_outputs)
    outputs.set_shape(inputs_shape)

    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
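This version builds both the training output (batch statistics) and the test output (moving statistics) and picks between them with `tf.cond`, which means `is_training` must be a boolean tensor rather than a plain Python bool. A minimal sketch of that selection pattern (TensorFlow 2.x, hypothetical statistics):

import tensorflow as tf

x = tf.random.normal([4, 3])
batch_mean, batch_var = tf.nn.moments(x, axes=[0])
moving_mean, moving_var = tf.zeros([3]), tf.ones([3])
is_training = tf.constant(True)   # a tensor, as tf.cond expects

out = tf.cond(
    is_training,
    lambda: tf.nn.batch_normalization(x, batch_mean, batch_var, None, None, 1e-3),
    lambda: tf.nn.batch_normalization(x, moving_mean, moving_var, None, None, 1e-3))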
Code example #24
  def call(self, inputs, training=False):
    if self.num_virtual_batches > 1:
      # Virtual batches (aka ghost batches) can be simulated by using some
      # reshape/transpose tricks on top of base batch normalization.
      original_shape = [-1] + inputs.shape.as_list()[1:]
      expanded_shape = [-1, self.num_virtual_batches] + original_shape[1:]

      # Will cause errors if num_virtual_batches does not divide the batch size
      inputs = array_ops.reshape(inputs, expanded_shape)

      ndims = len(expanded_shape)
      if self.axis < 0:
        axis = ndims + self.axis
      else:
        axis = self.axis + 1      # Account for the added dimension

      # Permute the num_virtual_batch dimension (dim 1) to be adjacent to axis
      # TODO(b/66257056): when multi-axis batch normalization is implemented,
      # this permutation trick and the combined_dim reshape are no longer
      # necessary and can be reworked to simply use broadcasting.
      permutation = ([0] + list(range(2, axis)) + [1, axis] +
                     list(range(axis + 1, ndims)))
      inverse_permutation = [x[1] for x in
                             sorted(zip(permutation, range(ndims)))]
      inputs = array_ops.transpose(inputs, perm=permutation)

      # Combine the axis and num_virtual_batch dimension in order to take
      # advantage of fused batch normalization
      combined_dim = expanded_shape[1] * expanded_shape[axis]
      perm_shape = [-1] + inputs.shape.as_list()[1:]
      combined_shape = (perm_shape[:axis - 1] +
                        [combined_dim] +
                        perm_shape[axis + 1:])
      inputs = array_ops.reshape(inputs, combined_shape)
      # After the above reshape, the batch norm axis is the original self.axis

      # Undoes the reshaping and transposing tricks done above
      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, perm_shape)
        outputs = array_ops.transpose(outputs, perm=inverse_permutation)
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs

    if self.fused:
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.num_virtual_batches > 1:
        return undo_virtual_batching(outputs)
      return outputs

    # First, compute the axes along which to reduce the mean / variance,
    # as well as the broadcast shape to be used for all parameters.
    input_shape = inputs.get_shape()
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis].value

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    scale, offset = self.gamma, self.beta

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)
    if training_value is not False:
      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      mean, variance = nn.moments(inputs, reduction_axes)
      mean = _smart_select(training,
                           lambda: mean,
                           lambda: self.moving_mean)
      variance = _smart_select(training,
                               lambda: variance,
                               lambda: self.moving_variance)

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            mean, variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        scale = array_ops.stop_gradient(r, name='renorm_r')
        offset = array_ops.stop_gradient(d, name='renorm_d')
        if self.gamma is not None:
          scale *= self.gamma
          offset *= self.gamma
        if self.beta is not None:
          offset += self.beta
      else:
        new_mean, new_variance = mean, variance

      # Update moving averages when training, and prevent updates otherwise.
      decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
      mean_update = moving_averages.assign_moving_average(
          self.moving_mean, new_mean, decay, zero_debias=False)
      variance_update = moving_averages.assign_moving_average(
          self.moving_variance, new_variance, decay, zero_debias=False)
      if context.in_graph_mode():
        self.add_update(mean_update, inputs=inputs)
        self.add_update(variance_update, inputs=inputs)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    def _broadcast(v):
      if needs_broadcasting and v is not None:
        # In this case we must explicitly broadcast all parameters.
        return array_ops.reshape(v, broadcast_shape)
      return v

    outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     _broadcast(offset),
                                     _broadcast(scale),
                                     self.epsilon)

    if self.num_virtual_batches > 1:
      return undo_virtual_batching(outputs)

    return outputs
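The reshape/transpose gymnastics above exist mainly to reuse the fused kernel; the core idea of virtual (ghost) batches is simply computing statistics per sub-batch. A simplified sketch without the permutation trick (TensorFlow 2.x; `num_virtual` must divide the batch size):

import tensorflow as tf

x = tf.random.normal([8, 6])                # full batch of 8
num_virtual = 2
xv = tf.reshape(x, [num_virtual, -1, 6])    # split into virtual batches
mean, var = tf.nn.moments(xv, axes=[1], keepdims=True)  # stats per virtual batch
y = tf.nn.batch_normalization(xv, mean, var, None, None, 1e-3)
y = tf.reshape(y, [8, 6])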
Code example #25
File: batch_norm.py Project: mkabra/poseTF
def batch_norm_mine_old(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               activation_fn=None,
               param_initializers=None,
               param_regularizers=None,
               updates_collections=ops.GraphKeys.UPDATE_OPS,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               batch_weights=None,
               fused=False,
               data_format=DATA_FORMAT_NHWC,
               zero_debias_moving_mean=False,
               scope=None,
               renorm=False,
               renorm_clipping=None,
               renorm_decay=0.99):
  """
  This earlier version of my modification to batch norm uses
current_mean and current_variance if is_training is True and
moving_mean and moving_variance otherwise. This was leading a large divergence between
the results depending upon whether the is_training set to True or not.

I think ideally it should always use moving_mean and moving_variance. batch_norm_mine
does this.

  Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
copy of tensorflow.contrib.layers
  Args:
    inputs: A tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC` and the second dimension if `data_format` is
      `NCHW`.
    decay: Decay for the moving average. Reasonable values for `decay` are close
      to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc.
      Lower `decay` value (recommend trying `decay`=0.9) if model experiences
      reasonably good training performance but poor validation and/or test
      performance. Try zero_debias_moving_mean=True for improved stability.
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    epsilon: Small float added to variance to avoid dividing by zero.
    activation_fn: Activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: Optional initializers for beta, gamma, moving mean and
      moving variance.
    param_regularizers: Optional regularizer for beta and gamma.
    updates_collections: Collections to collect the update ops for computation.
      The updates_ops need to be executed with the train_op.
      If None, a control dependency would be added to make sure the updates are
      computed in place.
    is_training: Whether or not the layer is in training mode. In training mode
      it would accumulate the statistics of the moments into `moving_mean` and
      `moving_variance` using an exponential moving average with the given
      `decay`. When it is not in training mode then it would use the values of
      the `moving_mean` and the `moving_variance`.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional collections for the variables.
    outputs_collections: Collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    batch_weights: An optional tensor of shape `[batch_size]`,
      containing a frequency weight for each batch item. If present,
      then the batch normalization uses weighted mean and
      variance. (This can be used to correct for bias in training
      example selection.)
    fused:  Use nn.fused_batch_norm if True, nn.batch_normalization otherwise.
    data_format: A string. `NHWC` (default) and `NCHW` are supported.
    zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
      pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
    scope: Optional scope for `variable_scope`.
    renorm: Whether to use Batch Renormalization
      (https://arxiv.org/abs/1702.03275). This adds extra variables during
      training. The inference is the same for either value of this parameter.
    renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
      scalar `Tensors` used to clip the renorm correction. The correction
      `(r, d)` is used as `corrected_value = normalized_value * r + d`, with
      `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
      dmax are set to inf, 0, inf, respectively.
    renorm_decay: Momentum used to update the moving means and standard
      deviations with renorm. Unlike `momentum`, this affects training
      and should be neither too small (which would add noise) nor too large
      (which would give stale estimates). Note that `decay` is still applied
      to get the means and variances for inference.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: If `batch_weights` is not None and `fused` is True.
    ValueError: If `param_regularizers` is not None and `fused` is True.
    ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
    ValueError: If the rank of `inputs` is undefined.
    ValueError: If rank or channels dimension of `inputs` is undefined.
  """
  if fused:
    if batch_weights is not None:
      raise ValueError('Weighted mean and variance is not currently '
                       'supported for fused batch norm.')
    if param_regularizers is not None:
      raise ValueError('Regularizers are not currently '
                       'supported for fused batch norm.')
    if renorm:
      raise ValueError('Renorm is not supported for fused batch norm.')
    return _fused_batch_norm(
        inputs,
        decay=decay,
        center=center,
        scale=scale,
        epsilon=epsilon,
        activation_fn=activation_fn,
        param_initializers=param_initializers,
        updates_collections=updates_collections,
        is_training=is_training,
        reuse=reuse,
        variables_collections=variables_collections,
        outputs_collections=outputs_collections,
        trainable=trainable,
        data_format=data_format,
        zero_debias_moving_mean=zero_debias_moving_mean,
        scope=scope)

  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    raise ValueError('data_format has to be either NCHW or NHWC.')

  layer_variable_getter = _build_variable_getter()
  with variable_scope.variable_scope(
      scope, 'BatchNorm', [inputs], reuse=reuse,
      custom_getter=layer_variable_getter) as sc:
    inputs = ops.convert_to_tensor(inputs)

    # Determine whether we can use the core layer class.
    if (batch_weights is None and
        updates_collections is ops.GraphKeys.UPDATE_OPS and
        not zero_debias_moving_mean):
      # Use the core layer class.
      axis = 1 if data_format == DATA_FORMAT_NCHW else -1
      if not param_initializers:
        param_initializers = {}
      beta_initializer = param_initializers.get('beta',
                                                init_ops.zeros_initializer())
      gamma_initializer = param_initializers.get('gamma',
                                                 init_ops.ones_initializer())
      moving_mean_initializer = param_initializers.get(
          'moving_mean', init_ops.zeros_initializer())
      moving_variance_initializer = param_initializers.get(
          'moving_variance', init_ops.ones_initializer())
      if not param_regularizers:
        param_regularizers = {}
      beta_regularizer = param_regularizers.get('beta')
      gamma_regularizer = param_regularizers.get('gamma')
      layer = normalization_layers.BatchNormalization(
          axis=axis,
          momentum=decay,
          epsilon=epsilon,
          center=center,
          scale=scale,
          beta_initializer=beta_initializer,
          gamma_initializer=gamma_initializer,
          moving_mean_initializer=moving_mean_initializer,
          moving_variance_initializer=moving_variance_initializer,
          beta_regularizer=beta_regularizer,
          gamma_regularizer=gamma_regularizer,
          trainable=trainable,
          renorm=renorm,
          renorm_clipping=renorm_clipping,
          renorm_momentum=renorm_decay,
          name=sc.name,
          _scope=sc,
          _reuse=reuse)
      outputs = layer.apply(inputs, training=is_training)

      # Add variables to collections.
      _add_variable_to_collections(
          layer.moving_mean, variables_collections, 'moving_mean')
      _add_variable_to_collections(
          layer.moving_variance, variables_collections, 'moving_variance')
      if layer.beta:
        _add_variable_to_collections(layer.beta, variables_collections, 'beta')
      if layer.gamma:
        _add_variable_to_collections(
            layer.gamma, variables_collections, 'gamma')

      if activation_fn is not None:
        outputs = activation_fn(outputs)
      return utils.collect_named_outputs(outputs_collections,
                                         sc.original_name_scope, outputs)

    # Not supported by layer class: batch_weights argument,
    # and custom updates_collections. In that case, use the legacy BN
    # implementation.
    # Custom updates collections are not supported because the update logic
    # is different in this case, in particular w.r.t. "forced updates" and
    # update op reuse.
    if renorm:
      raise ValueError('renorm is not supported with batch_weights, '
                       'updates_collections or zero_debias_moving_mean')
    inputs_shape = inputs.get_shape()
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    if batch_weights is not None:
      batch_weights = ops.convert_to_tensor(batch_weights)
      inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape())
      # Reshape batch weight values so they broadcast across inputs.
      nshape = [-1] + [1 for _ in range(inputs_rank - 1)]
      batch_weights = array_ops.reshape(batch_weights, nshape)

    if data_format == DATA_FORMAT_NCHW:
      moments_axes = [0] + list(range(2, inputs_rank))
      params_shape = inputs_shape[1:2]
      # For NCHW format, rather than relying on implicit broadcasting, we
      # explicitly reshape the params to params_shape_broadcast when computing
      # the moments and the batch normalization.
      params_shape_broadcast = list(
          [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
    else:
      moments_axes = list(range(inputs_rank - 1))
      params_shape = inputs_shape[-1:]
      params_shape_broadcast = None
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined channels dimension %s.' % (
          inputs.name, params_shape))

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    if not param_initializers:
      param_initializers = {}
    if center:
      beta_collections = utils.get_variable_collections(variables_collections,
                                                        'beta')
      beta_initializer = param_initializers.get('beta',
                                                init_ops.zeros_initializer())
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=beta_initializer,
                                      collections=beta_collections,
                                      trainable=trainable)
    if scale:
      gamma_collections = utils.get_variable_collections(variables_collections,
                                                         'gamma')
      gamma_initializer = param_initializers.get('gamma',
                                                 init_ops.ones_initializer())
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=gamma_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)

    # Create moving_mean and moving_variance variables and add them to the
    # appropriate collections. We disable variable partitioning while creating
    # them, because assign_moving_average is not yet supported for partitioned
    # variables.
    partitioner = variable_scope.get_variable_scope().partitioner
    try:
      variable_scope.get_variable_scope().set_partitioner(None)
      moving_mean_collections = utils.get_variable_collections(
          variables_collections, 'moving_mean')
      moving_mean_initializer = param_initializers.get(
          'moving_mean', init_ops.zeros_initializer())
      moving_mean = variables.model_variable(
          'moving_mean',
          shape=params_shape,
          dtype=dtype,
          initializer=moving_mean_initializer,
          trainable=False,
          collections=moving_mean_collections)
      moving_variance_collections = utils.get_variable_collections(
          variables_collections, 'moving_variance')
      moving_variance_initializer = param_initializers.get(
          'moving_variance', init_ops.ones_initializer())
      moving_variance = variables.model_variable(
          'moving_variance',
          shape=params_shape,
          dtype=dtype,
          initializer=moving_variance_initializer,
          trainable=False,
          collections=moving_variance_collections)
    finally:
      variable_scope.get_variable_scope().set_partitioner(partitioner)

    # If `is_training` doesn't have a constant value, because it is a `Tensor`,
    # a `Variable` or a `Placeholder`, then `is_training_value` will be None
    # and `need_moments` will be True.
    is_training_value = utils.constant_value(is_training)
    need_moments = is_training_value is None or is_training_value
    if need_moments:
      # Calculate the moments based on the individual batch.
      if batch_weights is None:
        if data_format == DATA_FORMAT_NCHW:
          mean, _ = nn.moments(inputs, moments_axes, keep_dims=True)
          variance, _ = nn.moments((inputs - moving_mean)**2, moments_axes,
                                   keep_dims=True)
          mean = array_ops.reshape(mean, [-1])
          variance = array_ops.reshape(variance, [-1])
        else:
          mean, _ = nn.moments(inputs, moments_axes)
          variance, _ = nn.moments((inputs - moving_mean)**2, moments_axes)
      else:
        if data_format == DATA_FORMAT_NCHW:
          mean, _ = nn.weighted_moments(inputs, moments_axes, batch_weights,
                                        keep_dims=True)
          variance, _ = nn.weighted_moments((inputs - moving_mean)**2,
                                            moments_axes, batch_weights,
                                            keep_dims=True)
          mean = array_ops.reshape(mean, [-1])
          variance = array_ops.reshape(variance, [-1])
        else:
          mean, _ = nn.weighted_moments(inputs, moments_axes, batch_weights)
          variance, _ = nn.weighted_moments((inputs - moving_mean)**2,
                                            moments_axes, batch_weights)

      moving_vars_fn = lambda: (moving_mean, moving_variance)
      if updates_collections is None:
        def _force_updates():
          """Internal function forces updates moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay, zero_debias=False)
          with ops.control_dependencies([update_moving_mean,
                                         update_moving_variance]):
            return array_ops.identity(mean), array_ops.identity(variance)
        mean, variance = utils.smart_cond(is_training,
                                          _force_updates,
                                          moving_vars_fn)
      else:
        def _delay_updates():
          """Internal function that delay updates moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay, zero_debias=False)
          return update_moving_mean, update_moving_variance

        update_mean, update_variance = utils.smart_cond(is_training,
                                                        _delay_updates,
                                                        moving_vars_fn)
        ops.add_to_collections(updates_collections, update_mean)
        ops.add_to_collections(updates_collections, update_variance)
        # Use computed moments during training and moving_vars otherwise.
        vars_fn = lambda: (mean, variance)
        mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn)
    else:
      mean, variance = moving_mean, moving_variance
    if data_format == DATA_FORMAT_NCHW:
      mean = array_ops.reshape(mean, params_shape_broadcast)
      variance = array_ops.reshape(variance, params_shape_broadcast)
      if beta is not None:
        beta = array_ops.reshape(beta, params_shape_broadcast)
      if gamma is not None:
        gamma = array_ops.reshape(gamma, params_shape_broadcast)

    # Compute batch_normalization.
    outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma,
                                     epsilon)
    outputs.set_shape(inputs_shape)
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections,
                                       sc.original_name_scope, outputs)
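The moving-average updates above all reduce to the same exponential decay rule, moving = decay * moving + (1 - decay) * value. A minimal NumPy sketch of that rule (the helper name is illustrative, not TensorFlow's):

import numpy as np

def assign_moving_average_sketch(moving, value, decay):
    # Keep `decay` of the old statistic and blend in (1 - decay) of the fresh
    # batch statistic; this mirrors moving_averages.assign_moving_average
    # without the optional zero debiasing.
    return decay * moving + (1.0 - decay) * value

moving_mean = np.zeros(3)
batch_mean = np.array([1.0, 2.0, 3.0])
for _ in range(10):
    moving_mean = assign_moving_average_sketch(moving_mean, batch_mean, 0.999)
# At decay=0.999 the moving mean creeps toward the batch mean very slowly,
# which is why the snippet exposes zero_debias_moving_mean for the mean.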
Code example #26
File: normalization.py Project: zxs666/tensorflow
    def call(self, inputs, training=False):
        if self.num_virtual_batches > 1:
            # Virtual batches (aka ghost batches) can be simulated by using some
            # reshape/transpose tricks on top of base batch normalization.
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = ([-1, self.num_virtual_batches] +
                              original_shape[1:])

            # Will cause errors if num_virtual_batches does not divide the batch size
            inputs = array_ops.reshape(inputs, expanded_shape)

            ndims = len(expanded_shape)
            if self.axis < 0:
                axis = ndims + self.axis
            else:
                axis = self.axis + 1  # Account for the added dimension

            # Permute the num_virtual_batch dimension (dim 1) to be adjacent to axis
            # TODO(b/66257056): when multi-axis batch normalization is implemented,
            # this permutation trick and the combined_dim reshape are no longer
            # necessary and can be reworked to simply use broadcasting.
            permutation = ([0] + list(range(2, axis)) + [1, axis] +
                           list(range(axis + 1, ndims)))
            inverse_permutation = [
                x[1] for x in sorted(zip(permutation, range(ndims)))
            ]
            inputs = array_ops.transpose(inputs, perm=permutation)

            # Combine the axis and num_virtual_batch dimension in order to take
            # advantage of fused batch normalization
            combined_dim = expanded_shape[1] * expanded_shape[axis]
            perm_shape = [-1] + inputs.shape.as_list()[1:]
            combined_shape = (perm_shape[:axis - 1] + [combined_dim] +
                              perm_shape[axis + 1:])
            inputs = array_ops.reshape(inputs, combined_shape)

            # After the above reshape, the batch norm axis is the original self.axis

            # Undoes the reshaping and transposing tricks done above
            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, perm_shape)
                outputs = array_ops.transpose(outputs,
                                              perm=inverse_permutation)
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        if self.fused:
            outputs = self._fused_batch_norm(inputs, training=training)
            if self.num_virtual_batches > 1:
                return undo_virtual_batching(outputs)
            return outputs

        # First, compute the axes along which to reduce the mean / variance,
        # as well as the broadcast shape to be used for all parameters.
        input_shape = inputs.get_shape()
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis].value

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        scale, offset = self.gamma, self.beta

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = utils.constant_value(training)
        if training_value is not False:
            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            mean, variance = nn.moments(inputs, reduction_axes)
            mean = _smart_select(training, lambda: mean,
                                 lambda: self.moving_mean)
            variance = _smart_select(training, lambda: variance,
                                     lambda: self.moving_variance)

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    mean, variance, training)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                scale = array_ops.stop_gradient(r, name='renorm_r')
                offset = array_ops.stop_gradient(d, name='renorm_d')
                if self.gamma is not None:
                    scale *= self.gamma
                    offset *= self.gamma
                if self.beta is not None:
                    offset += self.beta
            else:
                new_mean, new_variance = mean, variance

            # Update moving averages when training, and prevent updates otherwise.
            decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
            mean_update = moving_averages.assign_moving_average(
                self.moving_mean, new_mean, decay, zero_debias=False)
            variance_update = moving_averages.assign_moving_average(
                self.moving_variance, new_variance, decay, zero_debias=False)
            if context.in_graph_mode():
                self.add_update(mean_update, inputs=inputs)
                self.add_update(variance_update, inputs=inputs)

        else:
            mean, variance = self.moving_mean, self.moving_variance

        def _broadcast(v):
            if needs_broadcasting and v is not None:
                # In this case we must explicitly broadcast all parameters.
                return array_ops.reshape(v, broadcast_shape)
            return v

        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance),
                                         _broadcast(offset), _broadcast(scale),
                                         self.epsilon)

        if self.num_virtual_batches > 1:
            return undo_virtual_batching(outputs)

        return outputs
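The inverse_permutation computed above relies on a small trick: sorting (destination, source) pairs by destination recovers the inverse mapping. A standalone sketch with a toy permutation:

permutation = [0, 2, 1, 3]   # e.g. swap dims 1 and 2 of a rank-4 tensor
ndims = len(permutation)
inverse_permutation = [x[1] for x in sorted(zip(permutation, range(ndims)))]
# Transposing by `permutation` and then by `inverse_permutation` restores
# the original dimension order.
assert [permutation[i] for i in inverse_permutation] == list(range(ndims))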
Code example #27
  def call(self, inputs, training=False):
    if self.fused:
      return self._fused_batch_norm(inputs, training=training)

    # First, compute the axes along which to reduce the mean / variance,
    # as well as the broadcast shape to be used for all parameters.
    input_shape = inputs.get_shape()
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis].value

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    scale, offset = self.gamma, self.beta

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)
    if training_value is not False:
      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      mean, variance = nn.moments(inputs, reduction_axes)
      mean = _smart_select(training,
                           lambda: mean,
                           lambda: self.moving_mean)
      variance = _smart_select(training,
                               lambda: variance,
                               lambda: self.moving_variance)

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            mean, variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        scale = array_ops.stop_gradient(r, name='renorm_r')
        offset = array_ops.stop_gradient(d, name='renorm_d')
        if self.gamma is not None:
          scale *= self.gamma
          offset *= self.gamma
        if self.beta is not None:
          offset += self.beta
      else:
        new_mean, new_variance = mean, variance

      # Update moving averages when training, and prevent updates otherwise.
      decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
      mean_update = moving_averages.assign_moving_average(
          self.moving_mean, new_mean, decay, zero_debias=False)
      variance_update = moving_averages.assign_moving_average(
          self.moving_variance, new_variance, decay, zero_debias=False)

      self.add_update(mean_update, inputs=inputs)
      self.add_update(variance_update, inputs=inputs)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    def _broadcast(v):
      if needs_broadcasting and v is not None:
        # In this case we must explicitly broadcast all parameters.
        return array_ops.reshape(v, broadcast_shape)
      return v

    return nn.batch_normalization(inputs,
                                  _broadcast(mean),
                                  _broadcast(variance),
                                  _broadcast(offset),
                                  _broadcast(scale),
                                  self.epsilon)
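The renorm comment above asserts the identity (x * r + d) * gamma + beta == x * (r * gamma) + (d * gamma + beta), which is what lets the code keep a single scale/offset pair. A quick numeric check with arbitrary values:

import numpy as np

x = np.random.randn(3)
r, d, gamma, beta = 1.3, -0.2, 0.9, 0.1
lhs = (x * r + d) * gamma + beta            # renorm applied, then gamma/beta
rhs = x * (r * gamma) + (d * gamma + beta)  # folded into one affine transform
assert np.allclose(lhs, rhs)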
Code example #28
  def call(self, inputs, training=False):
    if self.virtual_batch_size is not None:
      # Virtual batches (aka ghost batches) can be simulated by reshaping the
      # Tensor and reusing the existing batch norm implementation
      original_shape = [-1] + inputs.shape.as_list()[1:]
      expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

      # Will cause errors if virtual_batch_size does not divide the batch size
      inputs = array_ops.reshape(inputs, expanded_shape)

      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs

    if self.fused:
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.virtual_batch_size is not None:
        # Currently never reaches here since fused_batch_norm does not support
        # virtual batching
        return undo_virtual_batching(outputs)
      return outputs

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.get_shape()
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
      del reduction_axes[1]     # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value
    def _broadcast(v):
      if (v is not None and
          len(v.get_shape()) != ndims and
          reduction_axes != list(range(ndims - 1))):
        return array_ops.reshape(v, broadcast_shape)
      return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
      if then_scale is not None:
        scale *= then_scale
        offset *= then_scale
      if then_offset is not None:
        offset += then_offset
      return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)
    if training_value is not False:
      if self.adjustment:
        adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
        # Adjust only during training.
        adj_scale = utils.smart_cond(training,
                                     lambda: adj_scale,
                                     lambda: array_ops.ones_like(adj_scale))
        adj_bias = utils.smart_cond(training,
                                    lambda: adj_bias,
                                    lambda: array_ops.zeros_like(adj_bias))
        scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
      mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      mean = utils.smart_cond(training,
                              lambda: mean,
                              lambda: moving_mean)
      variance = utils.smart_cond(training,
                                  lambda: variance,
                                  lambda: moving_variance)

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            mean, variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
        d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
        scale, offset = _compose_transforms(r, d, scale, offset)
      else:
        new_mean, new_variance = mean, variance

      if self.virtual_batch_size is not None:
        # This isn't strictly correct since in ghost batch norm, you are
        # supposed to sequentially update the moving_mean and moving_variance
        # with each sub-batch. However, since the moving statistics are only
        # used during evaluation, it is more efficient to just update in one
        # step and should not make a significant difference in the result.
        new_mean = math_ops.reduce_mean(new_mean,
                                        axis=1, keep_dims=True)
        new_variance = math_ops.reduce_mean(new_variance,
                                            axis=1, keep_dims=True)

      def _do_update(var, value):
        return moving_averages.assign_moving_average(
            var, value, self.momentum, zero_debias=False)

      mean_update = utils.smart_cond(
          training,
          lambda: _do_update(self.moving_mean, new_mean),
          lambda: self.moving_mean)
      variance_update = utils.smart_cond(
          training,
          lambda: _do_update(self.moving_variance, new_variance),
          lambda: self.moving_variance)
      if context.in_graph_mode():
        self.add_update(mean_update, inputs=inputs)
        self.add_update(variance_update, inputs=inputs)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     offset,
                                     scale,
                                     self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
      return undo_virtual_batching(outputs)

    return outputs
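The _compose_transforms helper above folds an affine transform applied first into one applied after it: then_scale * (scale * x + offset) + then_offset == (scale * then_scale) * x + (offset * then_scale + then_offset). A NumPy sketch (standalone reimplementation for illustration):

import numpy as np

def compose(scale, offset, then_scale, then_offset):
    # (scale, offset) is applied first, (then_scale, then_offset) second.
    if then_scale is not None:
        scale = scale * then_scale
        offset = offset * then_scale
    if then_offset is not None:
        offset = offset + then_offset
    return scale, offset

x = np.random.randn(4)
s, o = compose(2.0, 1.0, 3.0, -1.0)
assert np.allclose(s * x + o, 3.0 * (2.0 * x + 1.0) - 1.0)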
Code example #29
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()

        in_eager_mode = context.executing_eagerly()
        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by reshaping the
            # Tensor and reusing the existing batch norm implementation
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

            # Will cause errors if virtual_batch_size does not divide the batch size
            inputs = array_ops.reshape(inputs, expanded_shape)

            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        if self.fused:
            outputs = self._fused_batch_norm(inputs, training=training)
            if self.virtual_batch_size is not None:
                # Currently never reaches here since fused_batch_norm does not support
                # virtual batching
                outputs = undo_virtual_batching(outputs)
            return outputs

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)
        if training_value is not False:
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = tf_utils.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = tf_utils.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1
            mean, variance = self._moments(inputs,
                                           reduction_axes,
                                           keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(training, lambda: mean,
                                       lambda: moving_mean)
            variance = tf_utils.smart_cond(training, lambda: variance,
                                           lambda: moving_variance)

            if self.virtual_batch_size is not None:
                # This isn't strictly correct since in ghost batch norm, you are
                # supposed to sequentially update the moving_mean and moving_variance
                # with each sub-batch. However, since the moving statistics are only
                # used during evaluation, it is more efficient to just update in one
                # step and should not make a significant difference in the result.
                new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
                new_variance = math_ops.reduce_mean(variance,
                                                    axis=1,
                                                    keepdims=True)
            else:
                new_mean, new_variance = mean, variance

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    new_mean, new_variance, training)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
                d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)

            if distribution_strategy_context.in_cross_replica_context():
                strategy = distribution_strategy_context.get_strategy()

                def _do_update(var, value):
                    """Compute the updates for mean and variance."""
                    if in_eager_mode and not self.trainable:
                        return
                    return strategy.extended.update(
                        var,
                        self._assign_moving_average, (value, self.momentum),
                        group=False)

                # We need to unwrap the moving_mean or moving_variance in the case of
                # training being false to match the output of true_fn and false_fn
                # in the smart cond.
                mean_update = tf_utils.smart_cond(
                    training, lambda: _do_update(self.moving_mean, new_mean),
                    lambda: strategy.unwrap(self.moving_mean))
                variance_update = tf_utils.smart_cond(
                    training,
                    lambda: _do_update(self.moving_variance, new_variance),
                    lambda: strategy.unwrap(self.moving_variance))
            else:

                def _do_update(var, value):
                    """Compute the updates for mean and variance."""
                    if in_eager_mode and not self.trainable:
                        return
                    return self._assign_moving_average(var, value,
                                                       self.momentum)

                mean_update = tf_utils.smart_cond(
                    training, lambda: _do_update(self.moving_mean, new_mean),
                    lambda: self.moving_mean)
                variance_update = tf_utils.smart_cond(
                    training,
                    lambda: _do_update(self.moving_variance, new_variance),
                    lambda: self.moving_variance)
            if not context.executing_eagerly():
                self.add_update(mean_update, inputs=True)
                self.add_update(variance_update, inputs=True)

        else:
            mean, variance = self.moving_mean, self.moving_variance

        mean = math_ops.cast(mean, inputs.dtype)
        variance = math_ops.cast(variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)
        return outputs
Code example #30
def batch_norm(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               activation_fn=None,
               updates_collections=ops.GraphKeys.UPDATE_OPS,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               scope=None):
  """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
    "Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift"
    Sergey Ioffe, Christian Szegedy
  Can be used as a normalizer function for conv2d and fully_connected.
  Args:
    inputs: a tensor of size `[batch_size, height, width, channels]`
      or `[batch_size, channels]`.
    decay: decay for the moving average.
    center: If True, subtract `beta`. If False, `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
      the next layer is linear (also e.g. `nn.relu`), this can be disabled
      since the scaling can be done by the next layer.
    epsilon: small float added to variance to avoid dividing by zero.
    activation_fn: Optional activation function.
    updates_collections: collections to collect the update ops for computation.
      If None, a control dependency would be added to make sure the updates
      are computed.
    is_training: whether or not the layer is in training mode. In training
      mode it would accumulate the statistics of the moments into
      `moving_mean` and `moving_variance` using an exponential moving average
      with the given `decay`. When it is not in training mode then it would
      use the values of the `moving_mean` and the `moving_variance`.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: optional collections for the variables.
    outputs_collections: collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for `variable_op_scope`.
  Returns:
    a tensor representing the output of the operation.
  """
  with variable_scope.variable_op_scope([inputs], scope, 'BatchNorm',
                                        reuse=reuse) as sc:
    inputs_shape = inputs.get_shape()
    dtype = inputs.dtype.base_dtype
    axis = list(range(len(inputs_shape) - 1))
    params_shape = inputs_shape[-1:]
    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    if center:
      beta_collections = utils.get_variable_collections(variables_collections,
                                                        'beta')
      beta = variables.model_variable(
          'beta', shape=params_shape, dtype=dtype,
          initializer=init_ops.zeros_initializer,
          collections=beta_collections, trainable=trainable)
    if scale:
      gamma_collections = utils.get_variable_collections(variables_collections,
                                                         'gamma')
      gamma = variables.model_variable(
          'gamma', shape=params_shape, dtype=dtype,
          initializer=init_ops.ones_initializer,
          collections=gamma_collections, trainable=trainable)
    # Create moving_mean and moving_variance variables and add them to the
    # appropriate collections.
    moving_mean_collections = utils.get_variable_collections(
        variables_collections, 'moving_mean')
    moving_mean = variables.model_variable(
        'moving_mean', shape=params_shape, dtype=dtype,
        initializer=init_ops.zeros_initializer, trainable=False,
        collections=moving_mean_collections)
    moving_variance_collections = utils.get_variable_collections(
        variables_collections, 'moving_variance')
    moving_variance = variables.model_variable(
        'moving_variance', shape=params_shape, dtype=dtype,
        initializer=init_ops.ones_initializer, trainable=False,
        collections=moving_variance_collections)
    if is_training:
      # Calculate the moments based on the individual batch.
      mean, variance = nn.moments(inputs, axis, shift=moving_mean)
      # Update the moving_mean and moving_variance moments.
      update_moving_mean = moving_averages.assign_moving_average(
          moving_mean, mean, decay)
      update_moving_variance = moving_averages.assign_moving_average(
          moving_variance, variance, decay)
      if updates_collections is None:
        # Make sure the updates are computed here.
        with ops.control_dependencies([update_moving_mean,
                                       update_moving_variance]):
          outputs = nn.batch_normalization(inputs, mean, variance, beta,
                                           gamma, epsilon)
      else:
        # Collect the updates to be computed later.
        ops.add_to_collections(updates_collections, update_moving_mean)
        ops.add_to_collections(updates_collections, update_moving_variance)
        outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma,
                                         epsilon)
    else:
      outputs = nn.batch_normalization(
          inputs, moving_mean, moving_variance, beta, gamma, epsilon)
    outputs.set_shape(inputs.get_shape())
    if activation_fn:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
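As the docstring says, this batch_norm can be passed to conv2d as a normalizer function. A hedged usage sketch against the TF 1.x contrib API (shapes and hyperparameters here are illustrative):

import tensorflow as tf

inputs = tf.placeholder(tf.float32, [None, 32, 32, 3])
is_training = tf.placeholder(tf.bool, [])
net = tf.contrib.layers.conv2d(
    inputs, num_outputs=64, kernel_size=3,
    normalizer_fn=tf.contrib.layers.batch_norm,
    normalizer_params={'decay': 0.997, 'is_training': is_training})
# With updates_collections at its default, the update ops land in
# tf.GraphKeys.UPDATE_OPS and must be run alongside the train op.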
Code example #31
File: ops.py Project: zhexiaozhe/NAF-tensorflow
def batch_norm(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               updates_collections=ops.GraphKeys.UPDATE_OPS,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               scope=None):
    """Code modification of tensorflow/contrib/layers/python/layers/layers.py
  """
    with variable_scope.variable_op_scope([inputs],
                                          scope,
                                          'BatchNorm',
                                          reuse=reuse) as sc:
        inputs = ops.convert_to_tensor(inputs)
        inputs_shape = inputs.get_shape()
        inputs_rank = inputs_shape.ndims
        if inputs_rank is None:
            raise ValueError('Inputs %s has undefined rank.' % inputs.name)
        dtype = inputs.dtype.base_dtype
        axis = list(range(inputs_rank - 1))
        params_shape = inputs_shape[-1:]
        if not params_shape.is_fully_defined():
            raise ValueError('Inputs %s has undefined last dimension %s.' %
                             (inputs.name, params_shape))
        # Allocate parameters for the beta and gamma of the normalization.
        beta, gamma = None, None
        if center:
            beta_collections = utils.get_variable_collections(
                variables_collections, 'beta')
            beta = variables.model_variable(
                'beta',
                shape=params_shape,
                dtype=dtype,
                initializer=init_ops.zeros_initializer,
                collections=beta_collections,
                trainable=trainable)
        if scale:
            gamma_collections = utils.get_variable_collections(
                variables_collections, 'gamma')
            gamma = variables.model_variable(
                'gamma',
                shape=params_shape,
                dtype=dtype,
                initializer=init_ops.ones_initializer,
                collections=gamma_collections,
                trainable=trainable)
        # Create moving_mean and moving_variance variables and add them to the
        # appropriate collections.
        moving_mean_collections = utils.get_variable_collections(
            variables_collections, 'moving_mean')
        moving_mean = variables.model_variable(
            'moving_mean',
            shape=params_shape,
            dtype=dtype,
            initializer=init_ops.zeros_initializer,
            trainable=False,
            collections=moving_mean_collections)
        moving_variance_collections = utils.get_variable_collections(
            variables_collections, 'moving_variance')
        moving_variance = variables.model_variable(
            'moving_variance',
            shape=params_shape,
            dtype=dtype,
            initializer=init_ops.ones_initializer,
            trainable=False,
            collections=moving_variance_collections)

        # Calculate the moments based on the individual batch.
        mean, variance = nn.moments(inputs, axis, shift=moving_mean)
        # Update the moving_mean and moving_variance moments.
        update_moving_mean = moving_averages.assign_moving_average(
            moving_mean, mean, decay)
        update_moving_variance = moving_averages.assign_moving_average(
            moving_variance, variance, decay)
        if updates_collections is None:
            # Make sure the updates are computed here.
            with ops.control_dependencies(
                [update_moving_mean, update_moving_variance]):
                outputs = nn.batch_normalization(inputs, mean, variance, beta,
                                                 gamma, epsilon)
        else:
            # Collect the updates to be computed later.
            ops.add_to_collections(updates_collections, update_moving_mean)
            ops.add_to_collections(updates_collections, update_moving_variance)
            outputs = nn.batch_normalization(inputs, mean, variance, beta,
                                             gamma, epsilon)

        test_outputs = nn.batch_normalization(inputs, moving_mean,
                                              moving_variance, beta, gamma,
                                              epsilon)

        outputs = tf.cond(is_training, lambda: outputs, lambda: test_outputs)
        outputs.set_shape(inputs_shape)

        return utils.collect_named_outputs(outputs_collections, sc.name,
                                           outputs)
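Note that the tf.cond at the end requires is_training to be a boolean Tensor; the Python-bool default of True would fail there, so callers are expected to pass something like a placeholder. A minimal sketch of that pattern:

import tensorflow as tf

is_training = tf.placeholder(tf.bool, [], name='is_training')
train_out = tf.constant(1.0)  # stands in for the batch-statistics branch
test_out = tf.constant(0.0)   # stands in for the moving-statistics branch
outputs = tf.cond(is_training, lambda: train_out, lambda: test_out)

with tf.Session() as sess:
    print(sess.run(outputs, feed_dict={is_training: True}))   # 1.0
    print(sess.run(outputs, feed_dict={is_training: False}))  # 0.0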
Code example #32
    def call(self, inputs, training=None):
        training = self._get_training_value(training)

        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by reshaping the
            # Tensor and reusing the existing batch norm implementation
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

            # Will cause errors if virtual_batch_size does not divide the batch size
            inputs = array_ops.reshape(inputs, expanded_shape)

            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        if self.fused:
            outputs = self._fused_batch_norm(inputs, training=training)
            if self.virtual_batch_size is not None:
                # Currently never reaches here since fused_batch_norm does not support
                # virtual batching
                outputs = undo_virtual_batching(outputs)
            return outputs

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.shape) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)
        if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
            mean, variance = self.moving_mean, self.moving_variance
        else:
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = tf_utils.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = tf_utils.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1
            mean, variance = self._moments(
                math_ops.cast(inputs, self._param_dtype),
                reduction_axes,
                keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(
                training, lambda: mean,
                lambda: ops.convert_to_tensor(moving_mean))
            variance = tf_utils.smart_cond(
                training, lambda: variance,
                lambda: ops.convert_to_tensor(moving_variance))

            if self.virtual_batch_size is not None:
                # This isn't strictly correct since in ghost batch norm, you are
                # supposed to sequentially update the moving_mean and moving_variance
                # with each sub-batch. However, since the moving statistics are only
                # used during evaluation, it is more efficient to just update in one
                # step and should not make a significant difference in the result.
                new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
                new_variance = math_ops.reduce_mean(variance,
                                                    axis=1,
                                                    keepdims=True)
            else:
                new_mean, new_variance = mean, variance

            if self._support_zero_size_input():
                inputs_size = array_ops.size(inputs)
            else:
                inputs_size = None
            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    new_mean, new_variance, training, inputs_size)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
                d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)

            def _do_update(var, value):
                """Compute the updates for mean and variance."""
                return self._assign_moving_average(var, value, self.momentum,
                                                   inputs_size)

            def mean_update():
                true_branch = lambda: _do_update(self.moving_mean, new_mean)
                false_branch = lambda: self.moving_mean
                return tf_utils.smart_cond(training, true_branch, false_branch)

            def variance_update():
                """Update the moving variance."""
                def true_branch_renorm():
                    # We apply epsilon as part of the moving_stddev to mirror the training
                    # code path.
                    moving_stddev = _do_update(
                        self.moving_stddev,
                        math_ops.sqrt(new_variance + self.epsilon))
                    return self._assign_new_value(
                        self.moving_variance,
                        # Apply relu in case floating point rounding causes it to go
                        # negative.
                        K.relu(moving_stddev * moving_stddev - self.epsilon))

                if self.renorm:
                    true_branch = true_branch_renorm
                else:
                    true_branch = lambda: _do_update(self.moving_variance,
                                                     new_variance)

                false_branch = lambda: self.moving_variance
                return tf_utils.smart_cond(training, true_branch, false_branch)

            self.add_update(mean_update)
            self.add_update(variance_update)

        mean = math_ops.cast(mean, inputs.dtype)
        variance = math_ops.cast(variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        if scale is not None:
            scale = math_ops.cast(scale, inputs.dtype)
        # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
        # math in float16 hurts validation accuracy of popular models like resnet.
        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)
        return outputs
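The variance_update above tracks a moving standard deviation with epsilon folded in, then converts it back to a variance, clamping at zero in case floating point rounding drives it negative. A NumPy sketch of that round trip:

import numpy as np

epsilon = 1e-3
new_variance = 0.25
moving_stddev = np.sqrt(new_variance + epsilon)  # the quantity being averaged
recovered = max(moving_stddev * moving_stddev - epsilon, 0.0)  # the relu clamp
assert np.isclose(recovered, new_variance)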
Code example #33
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import nn

# Quantizers and QBN are project-local modules (quantized batch norm helpers)
# assumed to be importable here; they are not part of TensorFlow itself.

batch_size = 4    # assumed; the snippet does not define the batch dimensions
input_width = 8   # assumed
input_height = 8  # assumed
input_channels = 8
fixed_size = 8
fixed_prec = 4
testdata_scale = 10

inputs_vals = np.random.normal(size=(batch_size, input_width, input_height,
                                     input_channels)) * testdata_scale // 1

inputs = tf.constant(inputs_vals, dtype=tf.float64)

means, variances = nn.moments(inputs, [0, 1, 2, 3])

quantizer = Quantizers.NoQuantizer()
output = QBN.qbatch_normalization(inputs, means, variances, None, None, 0.0001,
                                  quantizer)
gold_output = nn.batch_normalization(inputs, means, variances, None, None,
                                     0.0001)

with tf.Session() as sess:
    gold_result = gold_output.eval().flatten()
    result = output.eval().flatten()
    #print(sess.run(output))
    #print('------------')
    #print(sess.run(gold_output))
    print('mean: %f' % (sess.run(means)))
    print('variance: %f' % (sess.run(variances)))
    pass

failed = False
for i in range(len(result)):
    if result[i] != gold_result[i]:
        failed = True
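Exact elementwise equality, as in the loop above, is brittle for floating point outputs; a tolerance-based check over the same result/gold_result arrays is the safer pattern (shown as an alternative, not part of the original test):

import numpy as np

failed = not np.allclose(result, gold_result, rtol=1e-6, atol=1e-9)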
Code example #34
  def call(self, inputs, training=None):
    if training is None:
      training = K.learning_phase()

    if self.virtual_batch_size is not None:
      # Virtual batches (aka ghost batches) can be simulated by reshaping the
      # Tensor and reusing the existing batch norm implementation
      original_shape = [-1] + inputs.shape.as_list()[1:]
      expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

      # Will cause errors if virtual_batch_size does not divide the batch size
      inputs = array_ops.reshape(inputs, expanded_shape)

      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs

    if self.fused:
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.virtual_batch_size is not None:
        # Currently never reaches here since fused_batch_norm does not support
        # virtual batching
        outputs = undo_virtual_batching(outputs)
      return outputs

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
      del reduction_axes[1]     # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value
    def _broadcast(v):
      if (v is not None and len(v.shape) != ndims and
          reduction_axes != list(range(ndims - 1))):
        return array_ops.reshape(v, broadcast_shape)
      return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
      if then_scale is not None:
        scale *= then_scale
        offset *= then_scale
      if then_offset is not None:
        offset += then_offset
      return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value is not False:
      if self.adjustment:
        adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
        # Adjust only during training.
        adj_scale = tf_utils.smart_cond(training,
                                        lambda: adj_scale,
                                        lambda: array_ops.ones_like(adj_scale))
        adj_bias = tf_utils.smart_cond(training,
                                       lambda: adj_bias,
                                       lambda: array_ops.zeros_like(adj_bias))
        scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
      mean, variance = self._moments(
          math_ops.cast(inputs, self._param_dtype),
          reduction_axes,
          keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      mean = tf_utils.smart_cond(training,
                                 lambda: mean,
                                 lambda: moving_mean)
      variance = tf_utils.smart_cond(training,
                                     lambda: variance,
                                     lambda: moving_variance)

      if self.virtual_batch_size is not None:
        # This isn't strictly correct since in ghost batch norm, you are
        # supposed to sequentially update the moving_mean and moving_variance
        # with each sub-batch. However, since the moving statistics are only
        # used during evaluation, it is more efficient to just update in one
        # step and should not make a significant difference in the result.
        new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
        new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
      else:
        new_mean, new_variance = mean, variance

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            new_mean, new_variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
        d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
        scale, offset = _compose_transforms(r, d, scale, offset)

      if distribution_strategy_context.in_cross_replica_context():
        strategy = distribution_strategy_context.get_strategy()

        def _do_update(var, value):
          """Compute the updates for mean and variance."""
          return strategy.extended.update(
              var, self._assign_moving_average, (value, self.momentum),
              group=False)
        # We need to unwrap the moving_mean or moving_variance in the case of
        # training being false to match the output of true_fn and false_fn
        # in the smart cond.
        def mean_update():
          true_branch = lambda: _do_update(self.moving_mean, new_mean)
          false_branch = lambda: strategy.unwrap(self.moving_mean)
          return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
          return tf_utils.smart_cond(
              training, lambda: _do_update(self.moving_variance, new_variance),
              lambda: strategy.unwrap(self.moving_variance))
      else:
        def _do_update(var, value):
          """Compute the updates for mean and variance."""
          return self._assign_moving_average(var, value, self.momentum)

        def mean_update():
          true_branch = lambda: _do_update(self.moving_mean, new_mean)
          false_branch = lambda: self.moving_mean
          return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
          true_branch = lambda: _do_update(self.moving_variance, new_variance)
          false_branch = lambda: self.moving_variance
          return tf_utils.smart_cond(training, true_branch, false_branch)

      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
      offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
      scale = math_ops.cast(scale, inputs.dtype)
    # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
    # math in float16 hurts validation accuracy of popular models like resnet.
    outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     offset,
                                     scale,
                                     self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
      outputs = undo_virtual_batching(outputs)
    return outputs
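The casts before nn.batch_normalization exist because the moments are computed in self._param_dtype (typically float32, even for float16 inputs) and must be brought back to the input dtype. A NumPy sketch of why the higher-precision accumulation matters (sizes and values are illustrative):

import numpy as np

x = np.full(1 << 12, 1.01, dtype=np.float16)
acc = np.float16(0.0)
for v in x:                      # naive sequential float16 accumulation
    acc = np.float16(acc + v)
naive = acc / np.float16(x.size)
accurate = np.float16(x.astype(np.float32).mean())
# Once the running sum grows large, adding 1.01 in float16 rounds badly, so
# `naive` drifts from `accurate`; hence the float32 moments and casts above.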
Code example #35
def _RedoRestBatchnorms(graph, is_training):
    """Finds fused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, true if training.

  Raises:
    ValueError: When batch norm folding fails.
  """
    matches = _FindRestBatchNorms(graph)
    print("Replacing", len(matches), "BatchNorms (without a preceding Conv2D)")
    for match in matches:
        scope, sep, _ = match.bn_op.name.rpartition('/')
        # Make sure new ops are added to `graph` and put on the same device as
        # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope
        # named `scope`. Otherwise, TF creates a unique scope whose name starts with
        # `scope`.
        with graph.as_default(), graph.name_scope(scope + sep):
            with graph.name_scope(scope + sep + '_psb' + sep):

                mean = match.mean_tensor
                variance = match.variance_tensor
                beta = match.beta_tensor
                gamma = match.gamma_tensor
                eps = match.batch_epsilon

                # new gamma = gamma / sqrt(variance + epsilon)
                # new biases = -mean * gamma / sqrt(variance + epsilon) + beta
                multfac = gamma / math_ops.sqrt(variance + eps)
                gamma = multfac
                beta = -multfac * mean + beta
                mean = array_ops.zeros_like(mean)
                variance = array_ops.ones_like(variance)
                eps = array_ops.zeros_like(eps)

                gamma = variableFromSettings([], hiddenVar=gamma)[0]
                # gamma = fixed_point(gamma,S("util.variable.fixed_point.bits"),max=S("util.variable.fixed_point.max"),min=S("util.variable.fixed_point.min"))
                # gamma = next_base2(gamma,strict_positive=False)
                # gamma = 1/variableFromSettings([],hiddenVar=1/gamma)[0]
                # variance = variableFromSettings([],hiddenVar=math_ops.sqrt(variance+eps))[0]**2
                # beta = variableFromSettings([],hiddenVar=beta)[0]
                if S("util.variable.fixed_point.use"):
                    beta = fixed_point(beta,
                                       S("util.variable.fixed_point.bits"),
                                       max=S("util.variable.fixed_point.max"),
                                       min=S("util.variable.fixed_point.min"))
                    # gamma = fixed_point(gamma,S("util.variable.fixed_point.bits"),max=S("util.variable.fixed_point.max"),min=S("util.variable.fixed_point.min"))
                    # mean = fixed_point(mean,S("util.variable.fixed_point.bits"),max=S("util.variable.fixed_point.max"),min=S("util.variable.fixed_point.min"))
                    # variance = fixed_point(variance,S("util.variable.fixed_point.bits"),max=S("util.variable.fixed_point.max"),min=S("util.variable.fixed_point.min"))

                # fixed_point division could be ok
                # silly silly_idiv(silly x, silly y) {
                #     uint64_t sign_bit = 1UL<<63;
                #     // unsetting the sign bit to ignore it
                #     silly res = ((x & ~sign_bit) / (y & ~sign_bit)) << 32;

                #     // setting the sign bit iff only one of sign bits is set
                #     res |= (x & sign_bit) ^ (y & sign_bit);
                #     return res;
                # }

            new_layer_tensor = nn.batch_normalization(
                match.input_tensor,
                mean,
                variance,
                beta,
                gamma,
                eps,
                name=match.bn_op.name.split("/")[-1] + "_psb")
            if S("util.variable.fixed_point.use"):
                new_layer_tensor = fixed_point(
                    new_layer_tensor,
                    S("util.variable.fixed_point.bits"),
                    max=S("util.variable.fixed_point.max"),
                    min=S("util.variable.fixed_point.min"))
            nodes_modified_count = common.RerouteTensor(
                new_layer_tensor, match.output_tensor)
            if nodes_modified_count == 0:
                raise ValueError(
                    'Folding batch norms failed, %s had no outputs.' %
                    match.output_tensor.name)
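
The rewrite above relies on a folding identity: with new_gamma = gamma / sqrt(variance + eps) and new_beta = beta - mean * new_gamma, normalizing against zero mean, unit variance, and zero epsilon produces the same output. A numeric check of that identity (a sketch in plain TensorFlow; variableFromSettings, S, and fixed_point are project-specific helpers and are not reproduced here):

import tensorflow as tf

x = tf.random.normal([8, 4])
mean, variance = tf.nn.moments(x, axes=[0])
gamma, beta, eps = tf.fill([4], 1.5), tf.fill([4], 0.25), 1e-3

original = tf.nn.batch_normalization(x, mean, variance, beta, gamma, eps)

# Fold the statistics into the affine parameters.
new_gamma = gamma / tf.sqrt(variance + eps)
new_beta = beta - mean * new_gamma
folded = tf.nn.batch_normalization(x, tf.zeros_like(mean),
                                   tf.ones_like(variance), new_beta,
                                   new_gamma, 0.0)

assert float(tf.reduce_max(tf.abs(original - folded))) < 1e-5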
Code example #36
File: normalization.py Project: wmpscc/tensorflow-1
    def call(self, inputs, training=False):
        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by reshaping the
            # Tensor and reusing the existing batch norm implementation
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

            # Will cause errors if virtual_batch_size does not divide the batch size
            inputs = array_ops.reshape(inputs, expanded_shape)

            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        if self.fused:
            outputs = self._fused_batch_norm(inputs, training=training)
            if self.virtual_batch_size is not None:
                # Currently never reaches here since fused_batch_norm does not support
                # virtual batching
                return undo_virtual_batching(outputs)
            return outputs

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        scale, offset = self.gamma, self.beta

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = utils.constant_value(training)
        if training_value is not False:
            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1
            mean, variance = nn.moments(inputs,
                                        reduction_axes,
                                        keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = utils.smart_cond(training, lambda: mean,
                                    lambda: moving_mean)
            variance = utils.smart_cond(training, lambda: variance,
                                        lambda: moving_variance)

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    mean, variance, training)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                scale = array_ops.stop_gradient(r, name='renorm_r')
                offset = array_ops.stop_gradient(d, name='renorm_d')
                if self.gamma is not None:
                    scale *= self.gamma
                    offset *= self.gamma
                if self.beta is not None:
                    offset += self.beta
            else:
                new_mean, new_variance = mean, variance

            # Update moving averages when training, and prevent updates otherwise.
            decay = utils.smart_cond(training, lambda: self.momentum,
                                     lambda: 1.)
            if self.virtual_batch_size is not None:
                # This isn't strictly correct since in ghost batch norm, you are
                # supposed to sequentially update the moving_mean and moving_variance
                # with each sub-batch. However, since the moving statistics are only
                # used during evaluation, it is more efficient to just update in one
                # step and should not make a significant difference in the result.
                new_mean = math_ops.reduce_mean(new_mean,
                                                axis=1,
                                                keep_dims=True)
                new_variance = math_ops.reduce_mean(new_variance,
                                                    axis=1,
                                                    keep_dims=True)

            mean_update = moving_averages.assign_moving_average(
                self.moving_mean, new_mean, decay, zero_debias=False)
            variance_update = moving_averages.assign_moving_average(
                self.moving_variance, new_variance, decay, zero_debias=False)
            if context.in_graph_mode():
                self.add_update(mean_update, inputs=inputs)
                self.add_update(variance_update, inputs=inputs)

        else:
            mean, variance = self.moving_mean, self.moving_variance

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value
        rank = len(inputs.get_shape())

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != rank
                    and reduction_axes != list(range(ndims))[:-1]):
                return array_ops.reshape(v, broadcast_shape)
            return v

        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance),
                                         _broadcast(offset), _broadcast(scale),
                                         self.epsilon)

        if self.virtual_batch_size is not None:
            return undo_virtual_batching(outputs)

        return outputs
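
The virtual batch ("ghost batch") path above is just a reshape: the batch is split into sub-batches along a new axis, that axis is excluded from the moment reduction, and the reshape is undone at the end. A standalone sketch of the same trick (illustrative, eager TensorFlow 2.x):

import tensorflow as tf

virtual_batch_size = 2
x = tf.random.normal([8, 4])  # real batch of 8 samples, 4 features

# [virtual_batch_size, num_ghost_batches, features]; axis 1 indexes the
# independent ghost batches.
expanded = tf.reshape(x, [virtual_batch_size, -1, 4])

# Reduce over axis 0 only, so each ghost batch gets its own statistics.
mean, variance = tf.nn.moments(expanded, axes=[0], keepdims=True)

normed = tf.nn.batch_normalization(expanded, mean, variance, None, None, 1e-3)
outputs = tf.reshape(normed, [-1, 4])  # undo the virtual batching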
Code example #37
File: modules.py Project: athicha/metapoison
    def call(self, inputs, params=None, training=None):

        if params is None or params.get(self.name + '/gamma:0') is None:
            return super(layers.BatchNormalization, self).call(inputs)
        else:
            gamma = params.get(self.name + '/gamma:0')
            beta = params.get(self.name + '/beta:0')

        original_training_value = training
        if training is None:
            training = backend.learning_phase()

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(gamma), _broadcast(beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)
        if training_value is not False:
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = tf_utils.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = tf_utils.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

        # Some of the computations here are not necessary when training==False
        # but not a constant. However, this makes the code simpler.
        keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
        mean, variance = nn.moments(inputs,
                                    reduction_axes,
                                    keep_dims=keep_dims)

        moving_mean = self.moving_mean
        moving_variance = self.moving_variance

        mean = tf_utils.smart_cond(training, lambda: mean, lambda: moving_mean)
        variance = tf_utils.smart_cond(training, lambda: variance,
                                       lambda: moving_variance)

        if self.virtual_batch_size is not None:
            # This isn't strictly correct since in ghost batch norm, you are
            # supposed to sequentially update the moving_mean and moving_variance
            # with each sub-batch. However, since the moving statistics are only
            # used during evaluation, it is more efficient to just update in one
            # step and should not make a significant difference in the result.
            new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
            new_variance = math_ops.reduce_mean(variance,
                                                axis=1,
                                                keepdims=True)
        else:
            new_mean, new_variance = mean, variance

        if self.renorm:
            r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                new_mean, new_variance, training)
            # When training, the normalized values (say, x) will be transformed as
            # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
            # = x * (r * gamma) + (d * gamma + beta) with renorm.
            r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
            d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
            scale, offset = _compose_transforms(r, d, scale, offset)

        def _do_update(var, value):
            return self._assign_moving_average(var, value, self.momentum)

        mean_update = tf_utils.smart_cond(
            training, lambda: _do_update(self.moving_mean, new_mean),
            lambda: self.moving_mean)
        variance_update = tf_utils.smart_cond(
            training, lambda: _do_update(self.moving_variance, new_variance),
            lambda: self.moving_variance)
        self.add_update(mean_update, inputs=True)
        self.add_update(variance_update, inputs=True)
        # mean, variance = self.moving_mean, self.moving_variance

        mean = math_ops.cast(mean, inputs.dtype)
        variance = math_ops.cast(variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
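            # NOTE: `undo_virtual_batching` is never defined in this trimmed
            # override, so reaching this branch would raise a NameError;
            # virtual batching is effectively unsupported here.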
            outputs = undo_virtual_batching(outputs)
        if original_training_value is None:
            outputs._uses_learning_phase = True  # pylint: disable=protected-access
        return outputs
Code example #38
    def _subdiv_batch_norm(self, inputs, training=None):
        # tf.print('bn', self.local_count)
        training = self._get_training_value(training)

        inputs_dtype = inputs.dtype.base_dtype
        if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
            # Do all math in float32 if given 16-bit inputs for numeric stability.
            # In particular, it's very easy for variance to overflow in float16 and
            # for safety we also choose to cast bfloat16 to float32.
            inputs = math_ops.cast(inputs, dtypes.float32)

        params_dtype = self._param_dtype

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.shape) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        # Compose a second (scale, offset) affine transform on top of an
        # existing one.
        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = control_flow_util.constant_value(training)
        update_value = (self.local_count + 1) % self.subdivisions == 0
        if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
            mean, variance = self.moving_mean, self.moving_variance
        else:
            # training_value could be True or None -> None means determine at runtime
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = control_flow_util.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = control_flow_util.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1

            # Moments of the current sub-batch; `net_sum` and `squared_mean`
            # are the running-sum ingredients aggregated below.
            mean, net_sum, variance, squared_mean, input_batch_size = self.subdiv_moments(
                math_ops.cast(inputs, self._param_dtype),
                reduction_axes,
                keep_dims=keep_dims)

            # Accumulate the running sums across sub-batches.
            def _update_aggregate_sum():
                return self._assign_subdiv_rotating_sum(
                    self.aggregated_sum_batch, net_sum, self.subdivisions,
                    self.local_count, input_batch_size)

            def _update_aggregate_squared_sum():
                return self._assign_subdiv_rotating_sum(
                    self.aggregated_square_sum_batch, squared_mean,
                    self.subdivisions, self.local_count, input_batch_size)

            def _update_aggregate_batch_size():
                return self._assign_subdiv_rotating_sum(
                    self.aggregated_batch_size, input_batch_size,
                    self.subdivisions, self.local_count, input_batch_size)

            self.add_update(_update_aggregate_sum)
            self.add_update(_update_aggregate_squared_sum)
            self.add_update(_update_aggregate_batch_size)

            aggregated_mean = self.aggregated_sum_batch / math_ops.cast(
                self.aggregated_batch_size, params_dtype)
            aggregated_squared_mean = self.aggregated_square_sum_batch / math_ops.cast(
                self.aggregated_batch_size, params_dtype)
            aggregated_variance = aggregated_squared_mean - math_ops.square(
                aggregated_mean)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            # When training, normalize with the statistics of the current
            # batch; otherwise fall back to the moving averages. The moving
            # values themselves are only refreshed on update steps.
            mean = control_flow_util.smart_cond(
                training,
                true_fn=lambda: mean,
                false_fn=lambda: ops.convert_to_tensor_v2_with_dispatch(
                    moving_mean))
            variance = control_flow_util.smart_cond(
                training,
                true_fn=lambda: variance,
                false_fn=lambda: ops.convert_to_tensor_v2_with_dispatch(
                    moving_variance))

            # Adopt the aggregated statistics only on steps where a full
            # set of subdivisions has been accumulated.
            new_mean = control_flow_util.smart_cond(
                update_value,
                true_fn=lambda: ops.convert_to_tensor_v2_with_dispatch(
                    aggregated_mean),
                false_fn=lambda: moving_mean)

            new_variance = control_flow_util.smart_cond(
                update_value,
                true_fn=lambda: ops.convert_to_tensor_v2_with_dispatch(
                    aggregated_variance),
                false_fn=lambda: moving_variance)

            # # should only be done when the moving mean is updated
            # tf.print(new_variance, self.local_count, update_value, self.aggregated_batch_size, self.aggregated_sum_batch)

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    new_mean, new_variance, training, input_batch_size)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
                d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)

            def _do_update(var, value):
                """Compute the updates for mean and variance."""
                return self._assign_moving_average(var, value, self.momentum,
                                                   self.aggregated_batch_size)

            def mean_update():
                true_branch = lambda: _do_update(self.moving_mean, new_mean)
                false_branch = lambda: self.moving_mean
                return control_flow_util.smart_cond(training, true_branch,
                                                    false_branch)

            def variance_update():
                """Update the moving variance."""
                def true_branch_renorm():
                    # We apply epsilon as part of the moving_stddev to mirror the training
                    # code path.
                    moving_stddev = _do_update(
                        self.moving_stddev,
                        math_ops.sqrt(new_variance + self.epsilon))
                    return self._assign_new_value(
                        self.moving_variance,
                        # Apply relu in case floating point rounding causes it to go
                        # negative.
                        K.relu(moving_stddev * moving_stddev - self.epsilon))

                if self.renorm:
                    true_branch = true_branch_renorm
                else:
                    true_branch = lambda: _do_update(self.moving_variance,
                                                     new_variance)

                false_branch = lambda: self.moving_variance
                return control_flow_util.smart_cond(training, true_branch,
                                                    false_branch)

            def update_count():
                with K.name_scope('update_count') as scope:
                    # update the local count
                    return state_ops.assign_add(self.local_count,
                                                math_ops.cast(
                                                    1, self.local_count.dtype),
                                                name=scope)

            self.add_update(mean_update)
            self.add_update(variance_update)
            self.add_update(update_count)

        mean = math_ops.cast(mean, inputs.dtype)
        variance = math_ops.cast(variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        if scale is not None:
            scale = math_ops.cast(scale, inputs.dtype)
        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
            outputs = math_ops.cast(outputs, inputs_dtype)

        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)
        return outputs
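
The subdivision variant above accumulates raw sums instead of moments; the aggregated variance is then recovered as E[x^2] - E[x]^2. A numeric check of that reconstruction over two sub-batches (a sketch; _assign_subdiv_rotating_sum and the aggregated_* variables are class internals not reproduced here):

import tensorflow as tf

sub_batches = [tf.random.normal([4, 3]) for _ in range(2)]
full = tf.concat(sub_batches, axis=0)

# Per-sub-batch running sums, as accumulated across subdivisions.
total_sum = tf.add_n([tf.reduce_sum(b, axis=0) for b in sub_batches])
total_sq_sum = tf.add_n(
    [tf.reduce_sum(tf.square(b), axis=0) for b in sub_batches])
count = tf.cast(tf.shape(full)[0], tf.float32)

agg_mean = total_sum / count
agg_variance = total_sq_sum / count - tf.square(agg_mean)

direct_mean, direct_variance = tf.nn.moments(full, axes=[0])
assert float(tf.reduce_max(tf.abs(agg_mean - direct_mean))) < 1e-5
assert float(tf.reduce_max(tf.abs(agg_variance - direct_variance))) < 1e-4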
Code example #39
def fused_layer_norm(inputs,
                     center=True,
                     scale=True,
                     activation_fn=None,
                     reuse=None,
                     variables_collections=None,
                     outputs_collections=None,
                     trainable=True,
                     begin_norm_axis=1,
                     begin_params_axis=-1,
                     scope=None,
                     use_fused_batch_norm=False):
    with tf.compat.v1.variable_scope(scope, 'LayerNorm', [inputs],
                                     reuse=reuse) as sc:
        inputs = ops.convert_to_tensor(inputs)
        inputs_shape = inputs.shape
        inputs_rank = inputs_shape.ndims
        if inputs_rank is None:
            raise ValueError('Inputs %s has undefined rank.' % inputs.name)
        dtype = inputs.dtype.base_dtype
        if begin_norm_axis < 0:
            begin_norm_axis = inputs_rank + begin_norm_axis
        if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank:
            raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) '
                             'must be < rank(inputs) (%d)' %
                             (begin_params_axis, begin_norm_axis, inputs_rank))
        params_shape = inputs_shape[begin_params_axis:]
        if not params_shape.is_fully_defined():
            raise ValueError(
                'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' %
                (inputs.name, begin_params_axis, inputs_shape))
        # Allocate parameters for the beta and gamma of the normalization.
        beta, gamma = None, None
        if center:
            beta_collections = utils.get_variable_collections(
                variables_collections, 'beta')
            beta = variables.model_variable(
                'beta',
                shape=params_shape,
                dtype=dtype,
                initializer=init_ops.zeros_initializer(),
                collections=beta_collections,
                trainable=trainable)
        if scale:
            gamma_collections = utils.get_variable_collections(
                variables_collections, 'gamma')
            gamma = variables.model_variable(
                'gamma',
                shape=params_shape,
                dtype=dtype,
                initializer=init_ops.ones_initializer(),
                collections=gamma_collections,
                trainable=trainable)
        if use_fused_batch_norm:
            # get static TensorShape if fully defined,
            # otherwise retrieve shape tensor
            norm_shape = inputs.shape[begin_norm_axis:]
            if norm_shape.is_fully_defined():
                bn_shape = [1, -1, 1, numpy.prod(norm_shape.as_list())]
            else:
                norm_shape = tf.shape(input=inputs)[begin_norm_axis:]
                bn_shape = [1, -1, 1, tf.reduce_prod(input_tensor=norm_shape)]
            if inputs.get_shape().is_fully_defined():
                outputs_shape = inputs.get_shape()
            else:
                outputs_shape = tf.shape(input=inputs)
            inputs = array_ops.reshape(inputs, bn_shape)
            if inputs.get_shape().is_fully_defined():
                # static inputs TensorShape fully defined after reshape.
                ones = array_ops.ones(inputs.get_shape()[1],
                                      dtype=dtypes.float32)
                zeros = array_ops.zeros(inputs.get_shape()[1],
                                        dtype=dtypes.float32)
            else:
                # static inputs TensorShape NOT fully defined after reshape.
                # must use dynamic shape, which means these input tensors
                # have to be created at runtime, which causes a slowdown.
                scale_shape = tf.shape(input=inputs)[1]
                ones = array_ops.ones(scale_shape, dtype=dtypes.float32)
                zeros = array_ops.zeros(scale_shape, dtype=dtypes.float32)
            outputs, mean, variance = nn.fused_batch_norm(inputs,
                                                          ones,
                                                          zeros,
                                                          epsilon=1e-4,
                                                          data_format="NCHW")
            outputs = array_ops.reshape(outputs, outputs_shape)
            if center and scale:
                outputs = outputs * gamma + beta
            elif center:
                outputs = outputs + beta
            elif scale:
                outputs = outputs * gamma
        else:
            # Calculate the moments on the last axis (layer activations).
            norm_axes = list(range(begin_norm_axis, inputs_rank))
            mean, variance = nn.moments(inputs, norm_axes, keep_dims=True)
            # Compute layer normalization using the batch_normalization function.
            variance_epsilon = 1e-4
            outputs = nn.batch_normalization(inputs,
                                             mean,
                                             variance,
                                             offset=beta,
                                             scale=gamma,
                                             variance_epsilon=variance_epsilon)
            outputs.set_shape(inputs_shape)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return utils.collect_named_outputs(outputs_collections, sc.name,
                                           outputs)
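
The fused path above works because reshaping to [1, groups, 1, group_size] in NCHW makes each normalization group one channel, so fused_batch_norm's per-channel moments are exactly the layer-norm moments. A small equivalence check (a sketch; fused_batch_norm with NCHW may require GPU support on older TensorFlow builds):

import tensorflow as tf

x = tf.random.normal([2, 6])  # layer-normalize each row

# Reference: moments-based layer normalization.
mean, variance = tf.nn.moments(x, axes=[1], keepdims=True)
reference = tf.nn.batch_normalization(x, mean, variance, None, None, 1e-4)

# Fused trick: one NCHW channel per row, so the per-channel statistics over
# (N, H, W) cover exactly the elements of one row.
reshaped = tf.reshape(x, [1, 2, 1, 6])
ones = tf.ones([2])
zeros = tf.zeros([2])
fused, _, _ = tf.compat.v1.nn.fused_batch_norm(
    reshaped, ones, zeros, epsilon=1e-4, data_format='NCHW')
fused = tf.reshape(fused, [2, 6])

assert float(tf.reduce_max(tf.abs(reference - fused))) < 1e-4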
Code example #40
  def call(self, inputs, training=False):
    # First, compute the axes along which to reduce the mean / variance,
    # as well as the broadcast shape to be used for all parameters.
    input_shape = inputs.get_shape()
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis].value

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    scale, offset = self.gamma, self.beta

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = utils.constant_value(training)
    if training_value is not False:
      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      mean, variance = nn.moments(inputs, reduction_axes)
      mean = _smart_select(training,
                           lambda: mean,
                           lambda: self.moving_mean)
      variance = _smart_select(training,
                               lambda: variance,
                               lambda: self.moving_variance)

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            mean, variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        scale = array_ops.stop_gradient(r, name='renorm_r')
        offset = array_ops.stop_gradient(d, name='renorm_d')
        if self.gamma is not None:
          scale *= self.gamma
          offset *= self.gamma
        if self.beta is not None:
          offset += self.beta
      else:
        new_mean, new_variance = mean, variance

      # Update moving averages when training, and prevent updates otherwise.
      decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
      mean_update = moving_averages.assign_moving_average(
          self.moving_mean, new_mean, decay, zero_debias=False)
      variance_update = moving_averages.assign_moving_average(
          self.moving_variance, new_variance, decay, zero_debias=False)

      if not self.updates:
        self.add_update(mean_update)
        self.add_update(variance_update)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    def _broadcast(v):
      if needs_broadcasting and v is not None:
        # In this case we must explicitly broadcast all parameters.
        return array_ops.reshape(v, broadcast_shape)
      return v

    return nn.batch_normalization(inputs,
                                  _broadcast(mean),
                                  _broadcast(variance),
                                  _broadcast(offset),
                                  _broadcast(scale),
                                  self.epsilon)
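
The renorm comment above (normalizing with x * r + d and then gamma/beta equals scaling by r * gamma and shifting by d * gamma + beta) is plain distributivity; a quick numeric confirmation (illustrative):

import tensorflow as tf

x = tf.random.normal([8, 4])
r, d = tf.fill([4], 0.9), tf.fill([4], 0.1)
gamma, beta = tf.fill([4], 1.5), tf.fill([4], 0.25)

with_renorm = (x * r + d) * gamma + beta
folded = x * (r * gamma) + (d * gamma + beta)
assert float(tf.reduce_max(tf.abs(with_renorm - folded))) < 1e-5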
Code example #41
def layer_norm(inputs,
               center=True,
               scale=True,
               activation_fn=None,
               reuse=None,
               trainable=True,
               begin_norm_axis=1,
               begin_params_axis=-1,
               scope=None):
    """Adds a Layer Normalization layer.
  Based on the paper:
    "Layer Normalization"
    Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
    https://arxiv.org/abs/1607.06450.
  Can be used as a normalizer function for conv2d and fully_connected.
  Given a tensor `inputs` of rank `R`, moments are calculated and normalization
  is performed over axes `begin_norm_axis ... R - 1`.  Scaling and centering,
  if requested, is performed over axes `begin_params_axis .. R - 1`.
  By default, `begin_norm_axis = 1` and `begin_params_axis = -1`,
  meaning that normalization is performed over all but the first axis
  (the `HWC` if `inputs` is `NHWC`), while the `beta` and `gamma` trainable
  parameters are calculated for the rightmost axis (the `C` if `inputs` is
  `NHWC`).  Scaling and recentering is performed via broadcast of the
  `beta` and `gamma` parameters with the normalized tensor.
  The shapes of `beta` and `gamma` are `inputs.shape[begin_params_axis:]`,
  and this part of the inputs' shape must be fully defined.
  Args:
    inputs: A tensor having rank `R`. The normalization is performed over
      axes `begin_norm_axis ... R - 1` and centering and scaling parameters
      are calculated over `begin_params_axis ... R - 1`.
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    activation_fn: Activation function, default set to None to skip it and
      maintain a linear activation.
    reuse: Whether or not the layer and its variables should be reused. To
      reuse the layer, `scope` must be given.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    begin_norm_axis: The first normalization dimension: normalization will be
      performed along dimensions `begin_norm_axis : rank(inputs)`
    begin_params_axis: The first parameter (beta, gamma) dimension: scale
      and centering parameters will have dimensions
      `begin_params_axis : rank(inputs)` and will be broadcast with the
      normalized inputs accordingly.
    scope: Optional scope for `variable_scope`.
  Returns:
    A `Tensor` representing the output of the operation, having the same
    shape and dtype as `inputs`.
  Raises:
    ValueError: If the rank of `inputs` is not known at graph build time,
      or if `inputs.shape[begin_params_axis:]` is not fully defined at
      graph build time.
  """
    with variable_scope.variable_scope(scope,
                                       'LayerNorm', [inputs],
                                       reuse=reuse) as sc:
        inputs = ops.convert_to_tensor(inputs)
        inputs_shape = inputs.shape
        inputs_rank = inputs_shape.ndims
        if inputs_rank is None:
            raise ValueError('Inputs %s has undefined rank.' % inputs.name)
        dtype = inputs.dtype.base_dtype
        if begin_norm_axis < 0:
            begin_norm_axis = inputs_rank + begin_norm_axis
        if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank:
            raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) '
                             'must be < rank(inputs) (%d)' %
                             (begin_params_axis, begin_norm_axis, inputs_rank))
        params_shape = inputs_shape[begin_params_axis:]
        if not params_shape.is_fully_defined():
            raise ValueError(
                'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' %
                (inputs.name, begin_params_axis, inputs_shape))
        # Allocate parameters for the beta and gamma of the normalization.
        beta, gamma = None, None
        if center:
            beta = tf.get_variable(name='beta',
                                   shape=params_shape,
                                   dtype=dtype,
                                   initializer=tf.zeros_initializer(),
                                   trainable=trainable)
        if scale:
            gamma = tf.get_variable(name='gamma',
                                    shape=params_shape,
                                    dtype=dtype,
                                    initializer=tf.ones_initializer(),
                                    trainable=trainable)
        # Calculate the moments on the last axis (layer activations).
        norm_axes = list(range(begin_norm_axis, inputs_rank))
        mean, variance = nn.moments(inputs, norm_axes, keep_dims=True)
        # Compute layer normalization using the batch_normalization function.
        variance_epsilon = 1e-12
        outputs = nn.batch_normalization(inputs,
                                         mean,
                                         variance,
                                         offset=beta,
                                         scale=gamma,
                                         variance_epsilon=variance_epsilon)
        outputs.set_shape(inputs_shape)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return collect_named_outputs(None, sc.name, outputs)
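
Usage mirrors tf.contrib.layers.layer_norm. A hedged sketch calling the function defined above (graph-mode, assuming the TF1-style imports used by the module; the placeholder shape and scope name are illustrative):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
inputs = tf.compat.v1.placeholder(tf.float32, [None, 10, 64])

# Normalize over axes 1..R-1; beta/gamma are shaped like the last axis.
outputs = layer_norm(inputs, center=True, scale=True,
                     begin_norm_axis=1, begin_params_axis=-1,
                     scope='encoder_ln')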
Code example #42
File: qnormalization.py Project: google/qkeras
    def call(self, inputs, training=None):
        if self.scale and self.gamma_quantizer:
            quantized_gamma = self.gamma_quantizer_internal(self.gamma)
        else:
            quantized_gamma = self.gamma

        if self.center and self.beta_quantizer:
            quantized_beta = self.beta_quantizer_internal(self.beta)
        else:
            quantized_beta = self.beta

        if self.mean_quantizer:
            quantized_moving_mean = self.mean_quantizer_internal(
                self.moving_mean)
        else:
            quantized_moving_mean = self.moving_mean

        if self.variance_quantizer:
            quantized_moving_variance = self.variance_quantizer_internal(
                self.moving_variance)
        else:
            quantized_moving_variance = self.moving_variance

        training = self._get_training_value(training)

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.shape) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(quantized_gamma), _broadcast(quantized_beta)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.smart_constant_value(training)
        if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
            quantized_mean, quantized_variance = (quantized_moving_mean,
                                                  quantized_moving_variance)
        else:
            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = len(self.axis) > 1
            mean, variance = self._moments(math_ops.cast(
                inputs, self._param_dtype),
                                           reduction_axes,
                                           keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(
                training, lambda: mean,
                lambda: ops.convert_to_tensor(moving_mean))
            variance = tf_utils.smart_cond(
                training, lambda: variance,
                lambda: ops.convert_to_tensor(moving_variance))

            new_mean, new_variance = mean, variance

            if self.mean_quantizer:
                quantized_mean = self.mean_quantizer_internal(mean)
            else:
                quantized_mean = mean

            if self.variance_quantizer:
                quantized_variance = self.variance_quantizer_internal(variance)
            else:
                quantized_variance = variance

            if self._support_zero_size_input():
                inputs_size = array_ops.size(inputs)
            else:
                inputs_size = None

            def _do_update(var, value):
                """Compute the updates for mean and variance."""
                return self._assign_moving_average(var, value, self.momentum,
                                                   inputs_size)

            def mean_update():
                true_branch = lambda: _do_update(self.moving_mean, new_mean)
                false_branch = lambda: self.moving_mean
                return tf_utils.smart_cond(training, true_branch, false_branch)

            def variance_update():
                """Update the moving variance."""
                true_branch = lambda: _do_update(self.moving_variance,
                                                 new_variance)
                false_branch = lambda: self.moving_variance
                return tf_utils.smart_cond(training, true_branch, false_branch)

            self.add_update(mean_update)
            self.add_update(variance_update)

        quantized_mean = math_ops.cast(quantized_mean, inputs.dtype)
        quantized_variance = math_ops.cast(quantized_variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        if scale is not None:
            scale = math_ops.cast(scale, inputs.dtype)
        # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
        # math in float16 hurts validation accuracy of popular models like resnet.
        outputs = nn.batch_normalization(inputs, _broadcast(quantized_mean),
                                         _broadcast(quantized_variance),
                                         offset, scale, self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        return outputs
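
QKeras internals aside, the pattern above is simply "quantize the parameters, then normalize with the quantized copies". A minimal sketch with a stand-in uniform quantizer (illustrative; this is not the QKeras quantizer API):

import tensorflow as tf

def fake_quantize(x, bits=8, max_value=2.0):
    # Stand-in uniform quantizer; QKeras quantizers are more elaborate.
    step = max_value / float(2 ** (bits - 1))
    return tf.round(x / step) * step

x = tf.random.normal([8, 4])
mean, variance = tf.nn.moments(x, axes=[0])
gamma = fake_quantize(tf.fill([4], 1.5))
beta = fake_quantize(tf.fill([4], 0.25))

outputs = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-3)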
Code example #43
File: normalization.py Project: BloodD/tensorflow
  def call(self, inputs, training=False):
    # First, compute the axes along which to reduce the mean / variance,
    # as well as the broadcast shape to be used for all parameters.
    input_shape = inputs.get_shape()
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis].value

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    # Determine a boolean value for `training`. May be False, True, or None.
    # If None, it is assumed that `training` is a variable to be used in `cond`.
    if isinstance(training, bool):
      training_bool = training
    else:
      try:
        training_bool = tensor_util.constant_value(training)
      except TypeError:
        training_bool = None

    # Obtain the current batch mean and variance, if necessary.
    if training_bool is not False:
      # Use a copy of moving_mean as a shift to compute more reliable moments.
      shift = math_ops.add(self.moving_mean, 0)
      if needs_broadcasting:
        shift = array_ops.reshape(shift, broadcast_shape)
        broadcast_mean, broadcast_variance = nn.moments(
            inputs, reduction_axes, shift=shift, keep_dims=True)
        mean = array_ops.reshape(broadcast_mean, [-1])
        variance = array_ops.reshape(broadcast_variance, [-1])
      else:
        mean, variance = nn.moments(inputs, reduction_axes, shift=shift)

    # Prepare updates if necessary.
    if training_bool is not False and not self.updates:
      mean_update = moving_averages.assign_moving_average(
          self.moving_mean, mean, self.momentum, zero_debias=False)
      variance_update = moving_averages.assign_moving_average(
          self.moving_variance, variance, self.momentum, zero_debias=False)
      # In the future this should be refactored into a self.add_update
      # methods in order to allow for instance-based BN layer sharing
      # across unrelated input streams (e.g. like in Keras).
      self.updates.append(mean_update)
      self.updates.append(variance_update)

    # Normalize batch.
    if needs_broadcasting:
      # In this case we must explicitly broadcast all parameters.
      broadcast_moving_mean = array_ops.reshape(self.moving_mean,
                                                broadcast_shape)
      broadcast_moving_variance = array_ops.reshape(self.moving_variance,
                                                    broadcast_shape)
      if self.center:
        broadcast_beta = array_ops.reshape(self.beta, broadcast_shape)
      else:
        broadcast_beta = None
      if self.scale:
        broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape)
      else:
        broadcast_gamma = None

      if training_bool is not False:
        normed_inputs_training = nn.batch_normalization(inputs,
                                                        broadcast_mean,
                                                        broadcast_variance,
                                                        broadcast_beta,
                                                        broadcast_gamma,
                                                        self.epsilon)
      normed_inputs = nn.batch_normalization(inputs,
                                             broadcast_moving_mean,
                                             broadcast_moving_variance,
                                             broadcast_beta,
                                             broadcast_gamma,
                                             self.epsilon)
    else:
      # No need for broadcasting.
      if training_bool is not False:
        normed_inputs_training = nn.batch_normalization(
            inputs,
            mean,
            variance,
            self.beta if self.center else None,
            self.gamma if self.scale else None,
            self.epsilon)
      normed_inputs = nn.batch_normalization(inputs,
                                             self.moving_mean,
                                             self.moving_variance,
                                             self.beta if self.center else None,
                                             self.gamma if self.scale else None,
                                             self.epsilon)

    # Return the proper output depending on the boolean training phase.
    if training_bool is True:
      return normed_inputs_training
    if training_bool is False:
      return normed_inputs
    return control_flow_ops.cond(training,
                                 lambda: normed_inputs_training,
                                 lambda: normed_inputs)
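
When `training` arrives as a tensor rather than a Python bool, both branches above are built and tf.cond selects one at run time, which is why the example computes normed_inputs_training and normed_inputs unconditionally. A minimal illustration of that selection pattern (graph-mode sketch with stand-in branches):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
training = tf.compat.v1.placeholder(tf.bool, [])
x = tf.constant([1.0, 2.0, 3.0])

train_out = x * 2.0   # stands in for normed_inputs_training
infer_out = x * 0.5   # stands in for normed_inputs
out = tf.cond(training, lambda: train_out, lambda: infer_out)

with tf.compat.v1.Session() as sess:
    print(sess.run(out, feed_dict={training: True}))   # [2. 4. 6.]
    print(sess.run(out, feed_dict={training: False}))  # [0.5 1.  1.5]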