Example #1
 def _tfBatchNormV2(self, x, m, v, beta, gamma, epsilon,
                    scale_after_normalization, shift_after_normalization):
   """New implementation."""
   return nn_impl.batch_normalization(
       x, m, v,
       beta if shift_after_normalization else None,
       gamma if scale_after_normalization else None,
       epsilon)
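These snippets use TensorFlow's internal module layout rather than the public tf.* API. A sketch of the imports they assume (module paths as found in the TF 1.x source tree; gap_finetune in Example #8 is a project-specific helper, not part of TensorFlow):

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import moving_averages
from tensorflow.python.training import training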
Example #2
 def _inference_ref(self, x, scale, offset, mean, var, epsilon, data_format):
   if data_format not in ['NHWC', 'NCHW']:
     raise ValueError('data_format must be NCHW or NHWC, '
                      'got %s.' % data_format)
   if data_format == 'NCHW':
     x = array_ops.transpose(x, [0, 2, 3, 1])
   y = nn_impl.batch_normalization(x, mean, var, offset, scale, epsilon)
   if data_format == 'NCHW':
     y = array_ops.transpose(y, [0, 3, 1, 2])
   return y.eval()
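The transposes in _inference_ref are needed because the per-channel mean/var/offset/scale vectors broadcast against the last axis, so NCHW inputs are converted to NHWC before normalizing and converted back afterwards. nn_impl.batch_normalization itself is a thin composition of element-wise ops; a minimal NumPy sketch of the arithmetic (a reference formula for illustration, not the TensorFlow source; scale and offset may each be None, matching the conditionals in Example #1):

import numpy as np

def batch_normalization_ref(x, mean, variance, offset, scale, epsilon):
    # y = (x - mean) / sqrt(variance + epsilon) * scale + offset,
    # with the scale multiply and/or offset add skipped when None.
    inv = 1.0 / np.sqrt(variance + epsilon)
    if scale is not None:
        inv = inv * scale
    y = x * inv - mean * inv
    if offset is not None:
        y = y + offset
    return y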
Example #4
 def _training_ref(self, x, scale, offset, epsilon, data_format):
     if data_format not in ['NHWC', 'NCHW']:
         raise ValueError('data_format must be NCHW or NHWC, '
                          'got %s.' % data_format)
     if data_format == 'NCHW':
         x = array_ops.transpose(x, [0, 2, 3, 1])
     mean, var = nn_impl.moments(x, [0, 1, 2], keep_dims=False)
     y = nn_impl.batch_normalization(x, mean, var, offset, scale, epsilon)
     if data_format == 'NCHW':
         y = array_ops.transpose(y, [0, 3, 1, 2])
     return y.eval(), mean.eval(), var.eval()
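nn_impl.moments(x, [0, 1, 2]) in the training reference computes a per-channel mean and a biased variance over the batch and spatial axes of an NHWC tensor. The NumPy equivalent, with a made-up shape for illustration:

import numpy as np

x = np.random.randn(8, 4, 4, 3).astype(np.float32)  # NHWC: batch, height, width, channels
mean = x.mean(axis=(0, 1, 2))  # shape [3], one value per channel
var = x.var(axis=(0, 1, 2))    # biased variance (divides by N), like moments()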
Example #5
 def _inference_ref(self, x, scale, offset, mean, var, epsilon,
                    data_format):
   if data_format not in ['NHWC', 'NCHW']:
     raise ValueError('data_format must be NCHW or NHWC, '
                      'got %s.' % data_format)
   if data_format == 'NCHW':
     x = array_ops.transpose(x, [0, 2, 3, 1])
   y = nn_impl.batch_normalization(x, mean, var, offset, scale, epsilon)
   if data_format == 'NCHW':
     y = array_ops.transpose(y, [0, 3, 1, 2])
   return y.eval()
 def _training_ref(self, x, scale, offset, epsilon, data_format):
   if data_format not in ['NHWC', 'NCHW']:
     raise ValueError('data_format must be NCHW or NHWC, '
                      'got %s.' % data_format)
   if data_format == 'NCHW':
     x = array_ops.transpose(x, [0, 2, 3, 1])
   mean, var = nn_impl.moments(x, [0, 1, 2], keep_dims=False)
   y = nn_impl.batch_normalization(x, mean, var, offset, scale, epsilon)
   if data_format == 'NCHW':
     y = array_ops.transpose(y, [0, 3, 1, 2])
   return y.eval(), mean.eval(), var.eval()
Example #7
def batch_norm_py(tensor, mean, variance, beta, gamma, scale):
  """Python implementation of batch normalization."""
  return nn_impl.batch_normalization(tensor, mean, variance, beta,
                                     gamma if scale else None, 0.001)
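As in Example #1, passing None for gamma when scale is false skips the multiply entirely, which gives the same result as gamma = 1. A quick NumPy check of that equivalence (hypothetical values):

import numpy as np

x = np.array([1.0, 2.0, 3.0])
m, v, beta, eps = x.mean(), x.var(), 0.5, 0.001
without_gamma = (x - m) / np.sqrt(v + eps) + beta
with_unit_gamma = (x - m) / np.sqrt(v + eps) * 1.0 + beta
assert np.allclose(without_gamma, with_unit_gamma)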
Example #8
    def _batchnorm(self, input_x, scope,
                   gamma_value, beta_value,
                   moving_mean_value, moving_variance_value,
                   is_training):
        """Wrapper function for batch normalization."""
        with variable_scope.variable_scope(scope):
            gamma_initial = init_ops.constant_initializer(
                gamma_value, dtypes.float32)
            gamma = gap_finetune.get_variable(
                name='gamma',
                shape=gamma_value.shape,
                dtype=dtypes.float32,
                initializer=gamma_initial,
                gap=self.gap,
                gap_vars=self.gap_vars)
            beta_initial = init_ops.constant_initializer(
                beta_value, dtypes.float32)
            beta = gap_finetune.get_variable(
                name='beta',
                shape=beta_value.shape,
                dtype=dtypes.float32,
                initializer=beta_initial,
                gap=self.gap,
                gap_vars=self.gap_vars)
            moving_mean_initial = init_ops.constant_initializer(
                moving_mean_value, dtypes.float32)
            moving_mean = gap_finetune.get_variable(
                name='moving_mean',
                shape=moving_mean_value.shape,
                dtype=dtypes.float32,
                initializer=moving_mean_initial,
                gap=self.gap,
                gap_vars=self.gap_vars)
            moving_variance_initial = init_ops.constant_initializer(
                moving_variance_value, dtypes.float32)
            moving_variance = gap_finetune.get_variable(
                name='moving_variance',
                shape=moving_variance_value.shape,
                dtype=dtypes.float32,
                initializer=moving_variance_initial,
                gap=self.gap,
                gap_vars=self.gap_vars)

            def mean_var_with_update():
                # Compute batch statistics and update the moving averages
                # before returning them.
                mean, variance = nn_impl.moments(input_x, [0, 1, 2],
                                                 name='moments')
                with ops.control_dependencies([
                        moving_averages.assign_moving_average(
                            moving_mean, mean, 0.9),
                        moving_averages.assign_moving_average(
                            moving_variance, variance, 0.9)]):
                    return (array_ops.identity(mean),
                            array_ops.identity(variance))

            # Batch statistics while training, moving averages at inference.
            mean, variance = control_flow_ops.cond(
                is_training, mean_var_with_update,
                lambda: (moving_mean, moving_variance))

            out = nn_impl.batch_normalization(input_x, mean, variance, beta,
                                              gamma, 0.001)
            return out
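When is_training is true, the cond takes the mean_var_with_update branch, which returns the fresh batch statistics and, through the control dependency, also updates the moving ones; otherwise it returns the stored moving statistics. The update performed by moving_averages.assign_moving_average with decay 0.9 reduces to the usual exponential average; a pure-Python sketch of that rule (an illustration only, the TF op updates the variable in place):

def assign_moving_average_ref(moving, value, decay=0.9):
    # moving <- decay * moving + (1 - decay) * value
    return decay * moving + (1.0 - decay) * value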
Example #9
def _batch_norm(_input,
                trnFlag,
                eps=1e-3,
                name="batch_norm",
                ema_decay=0.5,
                dtype=dtypes.float32):
    """
    A wrapped BN operation used for 2D or 3D convolution, as described in:

        https://gist.github.com/tomokishii/0ce3bdac1588b5cca9fa5fbdf6e1c412
        https://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow?answertab=votes#tab-top

    :param _input: tensor, typically a convolution result, before the ReLU
    :param trnFlag: bool tensor, whether we are training or not
    :param eps: scalar, small constant added to the variance for numerical stability
    :param ema_decay: scalar, decay rate for the moving averages of the batch mean and variance
    :param dtype: tf.DType, data type
    :return:
        tensor, BN result with the same shape as _input
    """
    shape = _input.get_shape().as_list()
    with variable_scope.variable_scope(name) as scope:
        beta = variable_scope.get_variable(
            "beta", [shape[-1]],
            dtype=dtype,
            initializer=init_ops.constant_initializer(0., dtype=dtype),
            trainable=True)
        gamma = variable_scope.get_variable(
            "gamma", [shape[-1]],
            dtype=dtype,
            initializer=init_ops.random_normal_initializer(1.,
                                                           0.01,
                                                           dtype=dtype,
                                                           seed=20170705),
            trainable=True)

        # Normalize over every axis except the channel axis.
        if len(shape) == 2:  # fc, [batch_size, num_dim]
            batch_mean, batch_var = nn_impl.moments(_input, [0],
                                                    name="moments")
        elif len(shape) == 4:  # conv, [batch_size, width, height, channel]
            batch_mean, batch_var = nn_impl.moments(_input, [0, 1, 2],
                                                    name="moments")
        elif len(shape) == 5:  # conv, [batch_size, depth, width, height, channel]
            batch_mean, batch_var = nn_impl.moments(_input, [0, 1, 2, 3],
                                                    name="moments")
        else:
            raise ValueError('wrong _input shape; it must have rank 2, 4 or 5')

        ema = training.ExponentialMovingAverage(decay=ema_decay)

        def mean_var_with_update():
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with ops.control_dependencies([ema_apply_op]):
                return (array_ops.identity(batch_mean),
                        array_ops.identity(batch_var))

        mean, var = control_flow_ops.cond(
            trnFlag, mean_var_with_update,
            lambda: (ema.average(batch_mean), ema.average(batch_var)))

        bn_out = nn_impl.batch_normalization(_input, mean, var, beta, gamma,
                                             eps)
    return bn_out
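ema.apply([batch_mean, batch_var]) creates shadow variables updated with the same shadow = decay * shadow + (1 - decay) * value rule sketched under Example #8, and ema.average(v) reads a shadow back, so here too the cond yields batch statistics while training and the averaged ones otherwise. A hedged usage sketch, assuming a TF1-style graph; the tensor names and shapes are invented for illustration:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
images = tf.placeholder(tf.float32, [None, 28, 28, 16], name='images')
is_training = tf.placeholder(tf.bool, [], name='is_training')
normed = _batch_norm(images, is_training, name='bn1')  # rank-4 input -> moments over [0, 1, 2]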