Example #1
 def build(self, input_shape):
     channels_in = input_shape[-1]
     self.deconv_1 = ConvBlockCondDown(channels_in // 4,
                                       self.use_bias,
                                       kernel=1,
                                       momentum=self.momentum,
                                       stride=1)
     self.deconv_2 = ConvBlockCondDown(channels_in // 4,
                                       self.use_bias,
                                       momentum=self.momentum,
                                       stride=1)
     self.deconv_3 = ConvBlockCondDown(channels_in // 4,
                                       self.use_bias,
                                       momentum=self.momentum,
                                       stride=1)
     self.deconv_4 = ConvBlockCondDown(self.channels,
                                       self.use_bias,
                                       kernel=1,
                                       momentum=self.momentum,
                                       stride=1,
                                       down=self.down)
     self.extra_conv = tf_utils.constant_value(channels_in < self.channels)
     if self.extra_conv:
         self.conv_extra = Conv2D(self.channels - channels_in,
                                  1,
                                  use_bias=self.use_bias)
Example #2
    def call(self, inputs, training=None):
        W_shape = self.kernel.shape.as_list()
        # flatten
        W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
        _u, _v = power_iteration(W_reshaped, self.u)

        # calculate sigma
        sigma = K.dot(_v, W_reshaped)
        sigma = K.dot(sigma, K.transpose(_u))
        # normalize it

        w_bar = W_reshaped / sigma

        training_val = tf_utils.constant_value(training)
        if training_val == False:
            w_bar = K.reshape(w_bar, W_shape)

        else:
            with tf.control_dependencies([self.u.assign(_u)]):
                w_bar = K.reshape(w_bar, W_shape)

        output = K.dot(inputs, w_bar)

        if self.use_bias:
            output = K.bias_add(output, self.bias, data_format='channels_last')

        if self.activation is not None:
            output = self.activation(output)

        print("DENSE: ", output.shape)
        return output
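
Example #2 uses power_iteration and _l2normalize without showing their definitions; identical helpers appear in full in Examples #6 and #12 below. For reference, a minimal standalone version (assuming K is the Keras backend and self.u holds a persistent 1 x out_dim singular-vector estimate):

from tensorflow.keras import backend as K

def _l2normalize(v, eps=1e-12):
    # Normalize a vector by its L2 norm, with eps guarding against division by zero.
    return v / (K.sum(v ** 2) ** 0.5 + eps)

def power_iteration(W, u):
    # One step of power iteration: refine the estimate of the leading
    # singular vectors of the flattened weight matrix W.
    _u = u
    _v = _l2normalize(K.dot(_u, K.transpose(W)))
    _u = _l2normalize(K.dot(_v, W))
    return _u, _v
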
Example #3
  def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    def _fused_batch_norm_training():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          epsilon=self.epsilon,
          data_format=self._data_format)

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    output, mean, variance = tf_utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)
    if not self._bessels_correction_test_only:
      # Remove Bessel's correction to be consistent with non-fused batch norm.
      # Note that the variance computed by fused batch norm is
      # with Bessel's correction.
      sample_size = math_ops.cast(
          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
      factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
      variance *= factor

    training_value = tf_utils.constant_value(training)
    if training_value is None:
      momentum = tf_utils.smart_cond(training,
                                     lambda: self.momentum,
                                     lambda: 1.0)
    else:
      momentum = ops.convert_to_tensor(self.momentum)
    if training_value or training_value is None:
      if distribution_strategy_context.in_cross_replica_context():
        strategy = distribution_strategy_context.get_strategy()
        mean_update = strategy.extended.update(
            self.moving_mean, self._assign_moving_average,
            (mean, self.momentum))
        variance_update = strategy.extended.update(
            self.moving_variance, self._assign_moving_average,
            (variance, self.momentum))
      else:
        mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                  momentum)
        variance_update = self._assign_moving_average(self.moving_variance,
                                                      variance, momentum)
      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

    return output
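
The momentum selection above is the core tf_utils.constant_value / smart_cond pattern used throughout these examples: a static Python-bool training value is resolved up front, while a symbolic tensor falls back to a graph-level switch. A condensed sketch of just that step (assuming the internal tensorflow.python.keras.utils.tf_utils module is available, as in the examples):

from tensorflow.python.framework import ops
from tensorflow.python.keras.utils import tf_utils

def select_momentum(training, momentum=0.99):
    # constant_value returns True/False for a static `training` and None for a
    # symbolic one; in the symbolic case smart_cond freezes the moving averages
    # (momentum == 1.0) whenever the inference branch is taken at run time.
    training_value = tf_utils.constant_value(training)
    if training_value is None:
        return tf_utils.smart_cond(training, lambda: momentum, lambda: 1.0)
    return ops.convert_to_tensor(momentum)
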
Example #4
    def call(self, inputs, training=None):
        training_value = tf_utils.constant_value(training)
        if len(inputs) != 2:
            raise ValueError('CondBatchNorm layer requires a list of inputs')
        norm_t, noise_t = inputs
        beta_t = self.beta(noise_t)
        beta_t = self.resphape_beta(beta_t)

        gamma_t = self.gamma(noise_t)
        gamma_t = self.resphape_gamma(gamma_t)

        if training_value:
            batch_mean, batch_var = tf.nn.moments(norm_t, [0, 1, 2],
                                                  name='batchMoments')
            test_mean = batch_mean * (1 - self.momentum) + (self.moving_mean *
                                                            self.momentum)
            test_var = batch_var * (1 - self.momentum) + (
                self.moving_variance * self.momentum)
            with tf.control_dependencies([
                    self.moving_mean.assign(test_mean),
                    self.moving_variance.assign(test_var)
            ]):
                return tf.nn.batch_normalization(norm_t, batch_mean, batch_var,
                                                 beta_t, gamma_t, self.epsilon)
        else:
            return tf.nn.batch_normalization(norm_t, self.moving_mean,
                                             self.moving_variance, beta_t,
                                             gamma_t, self.epsilon)
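
Example #4 derives per-sample gamma/beta from the second input (noise_t) through small dense layers before normalizing. A minimal sketch of that conditioning step using public TF ops (layer names and shapes here are illustrative, not the project's own):

import tensorflow as tf

channels = 64
noise = tf.random.normal([8, 128])            # one conditioning vector per sample
beta_fc = tf.keras.layers.Dense(channels)
gamma_fc = tf.keras.layers.Dense(channels)

# Reshape to [batch, 1, 1, channels] so the parameters broadcast over H and W.
beta_t = tf.reshape(beta_fc(noise), [-1, 1, 1, channels])
gamma_t = tf.reshape(gamma_fc(noise), [-1, 1, 1, channels])

x = tf.random.normal([8, 32, 32, channels])
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
y = tf.nn.batch_normalization(x, batch_mean, batch_var, beta_t, gamma_t, 1e-5)
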
Example #5
  def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    def _fused_batch_norm_training():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          epsilon=self.epsilon,
          data_format=self._data_format)

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    output, mean, variance = tf_utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)
    if not self._bessels_correction_test_only:
      # Remove Bessel's correction to be consistent with non-fused batch norm.
      # Note that the variance computed by fused batch norm is
      # with Bessel's correction.
      sample_size = math_ops.cast(
          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
      factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
      variance *= factor

    training_value = tf_utils.constant_value(training)
    if training_value is None:
      momentum = tf_utils.smart_cond(training,
                                     lambda: self.momentum,
                                     lambda: 1.0)
    else:
      momentum = ops.convert_to_tensor(self.momentum)
    if training_value or training_value is None:
      if distribution_strategy_context.in_cross_replica_context():
        strategy = distribution_strategy_context.get_strategy()
        mean_update = strategy.extended.update(
            self.moving_mean, self._assign_moving_average,
            (mean, self.momentum))
        variance_update = strategy.extended.update(
            self.moving_variance, self._assign_moving_average,
            (variance, self.momentum))
      else:
        mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                  momentum)
        variance_update = self._assign_moving_average(self.moving_variance,
                                                      variance, momentum)
      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

    return output
Example #6
    def call(self, inputs, training=None):
        training_value = tf_utils.constant_value(training)

        def _l2normalize(v):
            return v / (K.sum(v**2)**.5 + 1e-4)

        def power_iteration(W, u):
            _u = u
            _v = _l2normalize(K.dot(_u, K.transpose(W)))
            _u = _l2normalize(K.dot(_v, W))
            return _u, _v

        W_shape = self.kernel.shape.as_list()
        W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
        _u, _v = power_iteration(W_reshaped, self.u)

        sigma = K.dot(_v, W_reshaped)
        sigma = K.dot(sigma, K.transpose(_u))
        W_bar = W_reshaped / sigma

        if not training_value:
            W_bar = K.reshape(W_bar, W_shape)
        else:
            with tf.control_dependencies([self.u.assign(_u)]):
                W_bar = K.reshape(W_bar, W_shape)

        output = K.dot(inputs, W_bar)
        if self.use_bias:
            output = K.bias_add(output, self.bias, data_format='channels_last')
        if self.activation is not None:
            output = self.activation(output)
        return output
Example #7
File: metrics.py Project: skaurl/w2v-loss
def _resolve_training(layer, training):
    if training is None:
        training = K.learning_phase()
    if isinstance(training, int):
        training = bool(training)
    if not layer.trainable:
        training = False
    return tf_utils.constant_value(training)
Example #8
    def call(self, inputs, training=None):
        if self.lr_mul == 1.0:
            W = self.coeff * self.kernel
        else:
            @custom_gradient
            def lr_multiplier(x):
                y = array_ops.identity(x)
                def grad(dy):
                    return dy * self.lr_mul
                return y, grad
            W = lr_multiplier(self.coeff * self.kernel)

        training = self._get_training_value(training)

        # Update singular vector by power iteration
        W_T = array_ops.transpose(W)
        u = array_ops.identity(self.u)
        for i in range(self.power_iter):
            v = nn_impl.l2_normalize(math_ops.matmul(u, W))  # 1 x filters
            u = nn_impl.l2_normalize(math_ops.matmul(v, W_T))
        # Spectral Normalization
        sigma_W = math_ops.matmul(math_ops.matmul(u, W), array_ops.transpose(v))
        # Gradients should not flow through the power iteration
        sigma_W = array_ops.stop_gradient(sigma_W)
        W_bar = W / array_ops.squeeze(sigma_W)

        # Assign new singular vector
        training_value = tf_utils.constant_value(training)
        if training_value is not False:
            def u_update():
                def true_branch():
                    return self._assign_singular_vector(self.u, u)
                def false_branch():
                    return self.u
                return tf_utils.smart_cond(training, true_branch, false_branch)
            self.add_update(u_update)

        # normal Dense using W_bar
        inputs = ops.convert_to_tensor(inputs)
        rank = common_shapes.rank(inputs)
        if rank > 2:
            # Broadcasting is required for the inputs.
            outputs = standard_ops.tensordot(inputs, W_bar, [[rank - 1], [0]])
            # Reshape the output back to the original ndim of the input.
            if not context.executing_eagerly():
                shape = inputs.shape.as_list()
                output_shape = shape[:-1] + [self.units]
                outputs.set_shape(output_shape)
        else:
            inputs = math_ops.cast(inputs, self._compute_dtype)
            outputs = math_ops.mat_mul(inputs, W_bar)
        if self.use_bias:
            outputs = nn.bias_add(outputs, self.bias)
        if self.activation is not None:
            return self.activation(outputs)  # pylint: disable=not-callable
        return outputs
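
The lr_multiplier closure in Example #8 is an identity in the forward pass that scales only the gradient. The same trick with the public tf.custom_gradient API, as a quick self-contained check (names are illustrative):

import tensorflow as tf

def lr_multiplier(x, lr_mul):
    @tf.custom_gradient
    def _inner(x):
        def grad(dy):
            # Scale the incoming gradient; the forward value is untouched.
            return dy * lr_mul
        return tf.identity(x), grad
    return _inner(x)

x = tf.constant(2.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = lr_multiplier(x, 0.1) ** 2
print(tape.gradient(y, x))  # 2 * x * lr_mul = 0.4
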
Example #9
    def call(self, inputs, training=None):
        training = self._get_training_value(training)
        latent1, latent2, lod = inputs

        training_value = tf_utils.constant_value(training)
        latent_avg_new = math_ops.reduce_mean(latent1[:, 0], axis=0)
        if training_value != False and self.update_latent_avg:
            latent_avg_new = self._interpolate(latent_avg_new, self.latent_avg,
                                               self.latent_avg_beta)

            def update_op():
                def true_branch():
                    return self._assign_latent_avg(self.latent_avg,
                                                   latent_avg_new)

                def false_branch():
                    return self.latent_avg

                return tf_utils.smart_cond(training, true_branch, false_branch)

            self.add_update(update_op)

        if training_value != False and self.mix_latents:

            def true_branch():
                cur_layer = 2 * (1 + math_ops.cast(
                    array_ops.reshape(lod, [-1])[0], dtypes.int32))
                cutoff = tf_utils.smart_cond(
                    random_ops.random_uniform([], 0.0, 1.0) < self.mixing_prob,
                    lambda: random_ops.random_uniform([], 1, cur_layer, dtypes.int32),
                    lambda: cur_layer)
                return array_ops.where(
                    array_ops.broadcast_to(self.layer_idx < cutoff,
                                           array_ops.shape(latent1)), latent1,
                    latent2)

            def false_branch():
                return latent1

            latent1 = tf_utils.smart_cond(training, true_branch, false_branch)

        if training_value != True and self.truncate_latent:

            def true_branch():
                return latent1

            def false_branch():
                return self._interpolate(latent_avg_new, latent1, self.coeff)

            latent1 = tf_utils.smart_cond(training, true_branch, false_branch)

        return latent1
Example #10
    def call(self, inputs, training=None):
        """
        Call function will be called by __call__
        Arguments:
            inputs: activations into the layer
            training: Boolean to set training or inference mode
        Returns:
            normalized activations with multiplicative scale and additive bias
            corrections
        """
        if training is None:
            training = K.learning_phase()

        # Determine a boolean value for `training`: could be True, False, or None.
        training = tf_utils.constant_value(training)

        input_shape = inputs.get_shape()

        def _bcast(inputs):
            """
            broadcasts tensor for tensor operations with tensor of larger rank
            """
            if inputs is None:
                return None

            bcast_shape = [1] * len(input_shape)
            for a in self.axis:
                bcast_shape[a] = input_shape[a]
            return tf.reshape(inputs, bcast_shape)

        # cast fp16 to fp32
        precise_inputs = tf.cast(inputs, self.mp_type)

        # streaming / control normalization
        if training is not False:
            outputs = self.normalization(precise_inputs)
        else:
            mu = tf.cast(_bcast(self.mu), self.mp_type)
            denom = tf.cast(
                tf.math.sqrt(
                    self.var + self.epsilon
                ),
                self.mp_type
            )
            outputs = (inputs - mu) / _bcast(denom)
        outputs = tf.cast(outputs, self.mp_type)

        return outputs
Example #11
    def compute_spectral_normal(self, training):
        # Spectrally Normalized Weight
        if self.spectral_normalization:
            W_mat = KB.reshape(self.kernel, [self.out_dim, -1])  # [out_channels, N]
            W_sn, u, _ = power_iteration(W_mat, self.u)

            def true_fn():
                self.u.assign(u)
                pass

            def false_fn():
                pass

            training_value = tf_utils.constant_value(training)
            if training_value is not None:
                tf_utils.smart_cond(training, true_fn, false_fn)
            return self.kernel / W_sn
        else:
            return self.kernel
Example #12
    def call(self, inputs, training=None):
        def _l2normalize(v, eps=1e-12):
            return v / (K.sum(v**2)**0.5 + eps)

        def power_iteration(W, u):
            # According to the paper, one step of power iteration is enough.
            _u = u
            _v = _l2normalize(K.dot(_u, K.transpose(W)))
            _u = _l2normalize(K.dot(_v, W))
            return _u, _v

        # Spectral Normalization
        W_shape = self.kernel.shape.as_list()
        # Flatten the Tensor
        W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
        _u, _v = power_iteration(W_reshaped, self.u)
        # Calculate Sigma
        sigma = K.dot(_v, W_reshaped)
        sigma = K.dot(sigma, K.transpose(_u))
        # normalize it
        W_bar = W_reshaped / sigma
        # reshape weight tensor
        training_val = tf_utils.constant_value(training)
        if training_val == False:
            W_bar = K.reshape(W_bar, W_shape)
        else:
            with tf.control_dependencies([self.u.assign(_u)]):
                W_bar = K.reshape(W_bar, W_shape)

        outputs = K.conv2d(inputs,
                           W_bar,
                           strides=self.strides,
                           padding=self.padding,
                           data_format=self.data_format,
                           dilation_rate=self.dilation_rate)
        if self.use_bias:
            outputs = K.bias_add(outputs,
                                 self.bias,
                                 data_format=self.data_format)
        if self.activation is not None:
            return self.activation(outputs)
        return outputs
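
Examples #2, #6, #8, #11, #12 and #14 all implement the same spectral-normalization idea: divide the flattened kernel by its largest singular value, estimated by power iteration. A quick standalone check of that property (illustrative only; extra iterations are used for accuracy):

import tensorflow as tf

tf.random.set_seed(0)
W = tf.random.normal([64, 128])      # flattened kernel, shape [k*k*in, out]
u = tf.random.normal([1, 128])       # persistent singular-vector estimate, shape [1, out]

for _ in range(20):
    v = tf.math.l2_normalize(tf.matmul(u, W, transpose_b=True))
    u = tf.math.l2_normalize(tf.matmul(v, W))

sigma = tf.matmul(tf.matmul(v, W), u, transpose_b=True)   # approx. top singular value
W_bar = W / sigma

# The largest singular value of W_bar should now be close to 1.
print(tf.linalg.svd(W_bar, compute_uv=False)[0])
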
Example #13
File: metrics.py Project: skaurl/w2v-loss
 def call(self, inputs, training=None):
     x, y = inputs
     x = tf.nn.l2_normalize(x, axis=1)
     w = tf.nn.l2_normalize(self._w, axis=0)
     logits = tf.matmul(x, w)
     is_dynamic = tf_utils.constant_value(self._is_dynamic)
     if not is_dynamic:
         output = tf.multiply(self._init_s, logits)
         return output
     training = _resolve_training(self, training)
     if not training:
         return self._s * logits
     else:
         theta = tf.math.acos(K.clip(logits, -1.0 + K.epsilon(), 1.0 - K.epsilon()))
         b_avg = tf.where(y < 1.0, tf.exp(self._s * logits), tf.zeros_like(logits))
         b_avg = tf.reduce_mean(tf.reduce_sum(b_avg, axis=1))
         theta_class = tf.gather(theta, tf.cast(y, tf.int32))
         theta_med = tfp.stats.percentile(theta_class, q=50)
         self._s.assign(tf.math.log(b_avg) / tf.math.cos(tf.minimum(math.pi / 4, theta_med)))
         self._s.assign_add(-0.5)
         logits *= self._s
         out = tf.nn.sigmoid(logits)
         return out
Example #14
    def compute_spectral_normal(self, training):
        # Spectrally Normalized Weight
        if self.spectral_normalization:
            # Get the kernel tensor shape
            #  W_shape = self.kernel.shape.as_list()
            # Flatten the Tensor
            # For transpose conv, the kernel shape is [H,W,Out,In]
            # out_dim=W_shape[-2]
            W_mat = KB.reshape(self.kernel, [self.out_dim, -1])  # [out_c, N]
            sigma, u, _ = power_iteration(W_mat, self.u)

            def true_fn():
                self.u.assign(u)
                pass

            def false_fn():
                pass

            training_value = tf_utils.constant_value(training)
            if training_value is not False:
                tf_utils.smart_cond(training, true_fn, false_fn)
            return self.kernel / sigma
        else:
            return self.kernel
Example #15
    def _fused_batch_norm(self, inputs, training):
        """Returns the output of fused batch norm."""
        beta = self.beta if self.center else self._beta_const
        gamma = self.gamma if self.scale else self._gamma_const

        # TODO(b/129279393): Support zero batch input in non DistributionStrategy
        # code as well.
        if self._support_zero_size_input():
            inputs_size = array_ops.size(inputs)
        else:
            inputs_size = None

        def _fused_batch_norm_training():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       epsilon=self.epsilon,
                                       data_format=self._data_format)

        def _fused_batch_norm_inference():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       mean=self.moving_mean,
                                       variance=self.moving_variance,
                                       epsilon=self.epsilon,
                                       is_training=False,
                                       data_format=self._data_format)

        output, mean, variance = tf_utils.smart_cond(
            training, _fused_batch_norm_training, _fused_batch_norm_inference)
        if not self._bessels_correction_test_only:
            # Remove Bessel's correction to be consistent with non-fused batch norm.
            # Note that the variance computed by fused batch norm is
            # with Bessel's correction.
            sample_size = math_ops.cast(
                array_ops.size(inputs) / array_ops.size(variance),
                variance.dtype)
            factor = (sample_size -
                      math_ops.cast(1.0, variance.dtype)) / sample_size
            variance *= factor

        training_value = tf_utils.constant_value(training)
        if training_value is None:
            momentum = tf_utils.smart_cond(training, lambda: self.momentum,
                                           lambda: 1.0)
        else:
            momentum = ops.convert_to_tensor(self.momentum)
        if training_value or training_value is None:

            def mean_update():
                return self._assign_moving_average(self.moving_mean, mean,
                                                   momentum, inputs_size)

            def variance_update():
                """Update self.moving_variance with the most recent data point."""
                if self.renorm:
                    # We apply epsilon as part of the moving_stddev to mirror the training
                    # code path.
                    moving_stddev = self._assign_moving_average(
                        self.moving_stddev,
                        math_ops.sqrt(variance + self.epsilon), momentum,
                        inputs_size)
                    return self._assign_new_value(
                        self.moving_variance,
                        # Apply relu in case floating point rounding causes it to go
                        # negative.
                        K.relu(moving_stddev * moving_stddev - self.epsilon))
                else:
                    return self._assign_moving_average(self.moving_variance,
                                                       variance, momentum,
                                                       inputs_size)

            self.add_update(mean_update)
            self.add_update(variance_update)

        return output
Example #16
    def call(self, inputs, training=None):
        """
        Call function will be called by __call__

        Arguments:
            inputs: activations into the layer
            training: Boolean to set training or inference mode

        Returns:
            normalized activations with multiplicative scale and additive bias
            corrections
        """
        original_training_value = training
        if training is None:
            training = K.learning_phase()

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)

        input_shape = inputs.get_shape()

        def _bcast(inputs):
            """
            broadcasts tensor for tensor operations with tensor of larger rank
            """
            if inputs is None:
                return None

            bcast_shape = [1] * len(input_shape)
            for a in self.axis:
                bcast_shape[a] = input_shape[a]
            return tf.reshape(inputs, bcast_shape)

        mixed_precision = (inputs.dtype == dtypes.float16
                           or inputs.dtype == dtypes.bfloat16)

        # cast fp16 to fp32
        precise_inputs = inputs
        if mixed_precision:
            precise_inputs = math_ops.cast(inputs, dtypes.float32)

        # streaming / control normalization
        if training_value is not False:
            x_norm = tf_utils.smart_cond(
                training, lambda: self.control_normalization(precise_inputs),
                lambda: tf.nn.batch_normalization(
                    precise_inputs,
                    tf.reshape(self.mu[-1], self.broadcast_shape),
                    tf.reshape(self.var[-1], self.broadcast_shape), None, None,
                    self.epsilon))
        else:
            x_norm = tf.nn.batch_normalization(
                precise_inputs, tf.reshape(self.mu[-1], self.broadcast_shape),
                tf.reshape(self.var[-1], self.broadcast_shape), None, None,
                self.epsilon)

        # scale and bias
        x_scaled = x_norm * _bcast(self.gamma) if self.scale else x_norm
        x_bias = x_scaled + _bcast(self.beta) if self.center else x_scaled

        outputs = self.layer_scaling(x_bias) if self.ls else x_bias

        # if needed, cast back to fp16
        if mixed_precision:
            outputs = math_ops.cast(outputs, inputs.dtype)

        return outputs
Example #17
  def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
    # code as well.
    if self._support_zero_size_input():
      inputs_size = array_ops.size(inputs)
    else:
      inputs_size = None

    # TODO(rmlarsen): Support using fused avg updates for non-eager execution
    # after fixing graph pattern matching and enabling fused_batch_norm to
    # take exponential_avg_factor as a tensor input.
    use_fused_avg_updates = (
        compat.forward_compatible(2020, 3, 6) and
        ops.executing_eagerly_outside_functions())
    if use_fused_avg_updates:
      exponential_avg_factor = 1.0 - self.momentum
    else:
      exponential_avg_factor = None

    def _maybe_add_or_remove_bessels_correction(variance, remove=True):
      r"""Add or remove Bessel's correction."""
      # Removes Bessel's correction if remove == True, adds it otherwise.
      # This is to be consistent with non-fused batch norm. Note that the
      # variance computed by fused batch norm is with Bessel's correction.
      # This is only used in legacy V1 batch norm tests.
      if self._bessels_correction_test_only:
        return variance
      sample_size = math_ops.cast(
          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
      if remove:
        factor = (sample_size -
                  math_ops.cast(1.0, variance.dtype)) / sample_size
      else:
        factor = sample_size / (
            sample_size - math_ops.cast(1.0, variance.dtype))
      return variance * factor

    def _fused_batch_norm_training():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=_maybe_add_or_remove_bessels_correction(
              self.moving_variance, remove=False),
          epsilon=self.epsilon,
          is_training=True,
          data_format=self._data_format,
          exponential_avg_factor=exponential_avg_factor)

    def _fused_batch_norm_training_empty():
      return inputs, self.moving_mean, self.moving_variance

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    train_op = _fused_batch_norm_training
    if use_fused_avg_updates and inputs_size is not None:
      train_op = lambda: tf_utils.smart_cond(inputs_size > 0,
                                             _fused_batch_norm_training,
                                             _fused_batch_norm_training_empty)

    output, mean, variance = tf_utils.smart_cond(training, train_op,
                                                 _fused_batch_norm_inference)
    variance = _maybe_add_or_remove_bessels_correction(variance, remove=True)

    training_value = tf_utils.constant_value(training)
    if training_value or training_value is None:
      if not use_fused_avg_updates:
        if training_value is None:
          momentum = tf_utils.smart_cond(training, lambda: self.momentum,
                                         lambda: 1.0)
        else:
          momentum = ops.convert_to_tensor_v2(self.momentum)

      def mean_update():
        """Update self.moving_mean with the most recent data point."""
        if use_fused_avg_updates:
          return self._assign_new_value(self.moving_mean, mean)
        else:
          return self._assign_moving_average(self.moving_mean, mean, momentum,
                                             inputs_size)

      def variance_update():
        """Update self.moving_variance with the most recent data point."""
        if use_fused_avg_updates:
          return self._assign_new_value(self.moving_variance, variance)
        else:
          return self._assign_moving_average(self.moving_variance, variance,
                                             momentum, inputs_size)

      self.add_update(mean_update)
      self.add_update(variance_update)

    return output
Example #18
    def call(self, inputs, training=None):
        training = self._get_training_value(training)

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.shape) != ndims and
                    reduction_axes != list(range(ndims - 1))):
                return tf.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)
        if training_value == False:
            mean, variance = self.moving_mean, self.moving_variance
        else:
            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = len(self.axis) > 1
            mean, variance = self._moments(
                tf.cast(inputs, self._param_dtype),
                reduction_axes,
                keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(training,
                                       lambda: mean,
                                       lambda: tf.convert_to_tensor(moving_mean))
            variance = tf_utils.smart_cond(
                training,
                lambda: variance,
                lambda: tf.convert_to_tensor(moving_variance))
            new_mean, new_variance = mean, variance

            inputs_size = None

            def _do_update(var, value):
                """Compute the updates for mean and variance."""
                return self._assign_moving_average(var, value, self.momentum,
                                                   inputs_size)

            def mean_update():
                def true_branch(): return _do_update(self.moving_mean, new_mean)
                def false_branch(): return self.moving_mean
                return tf_utils.smart_cond(training, true_branch, false_branch)

            def variance_update():
                """Update the moving variance."""

                def true_branch(): return _do_update(self.moving_variance, new_variance)
                def false_branch(): return self.moving_variance
                return tf_utils.smart_cond(training, true_branch, false_branch)

            self.add_update(mean_update)
            self.add_update(variance_update)

        mean = tf.cast(mean, inputs.dtype)
        variance = tf.cast(variance, inputs.dtype)
        if offset is not None:
            offset = tf.cast(offset, inputs.dtype)
        if scale is not None:
            scale = tf.cast(scale, inputs.dtype)
        outputs = tf.nn.batch_normalization(inputs, _broadcast(mean),
                                            _broadcast(variance), offset,
                                            scale, self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        return outputs
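
The _assign_moving_average helper used throughout the batch-norm examples boils down to an exponential moving average, moving <- moving - (moving - value) * (1 - momentum), plus a guard for zero-size batches when inputs_size is passed. A minimal sketch of the core update:

import tensorflow as tf

def assign_moving_average(variable, value, momentum):
    # variable <- variable - (variable - value) * (1 - momentum)
    return variable.assign_sub((variable - value) * (1.0 - momentum))

moving_mean = tf.Variable(0.0)
for batch_mean in [1.0, 1.0, 1.0]:
    assign_moving_average(moving_mean, batch_mean, momentum=0.9)
print(moving_mean.numpy())  # 0.271: converges toward 1.0 as updates accumulate
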
Example #19
    def call(self, inputs, params=None, training=None):

        if params[self.name + '/gamma:0'] is None:
            return super(layers.BatchNormalization, self).call(inputs)
        else:
            gamma = params.get(self.name + '/gamma:0')
            beta = params.get(self.name + '/beta:0')

        original_training_value = training
        if training is None:
            training = backend.learning_phase()

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(gamma), _broadcast(beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)
        if training_value is not False:
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = tf_utils.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = tf_utils.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

        # Some of the computations here are not necessary when training==False
        # but not a constant. However, this makes the code simpler.
        keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
        mean, variance = nn.moments(inputs,
                                    reduction_axes,
                                    keep_dims=keep_dims)

        moving_mean = self.moving_mean
        moving_variance = self.moving_variance

        mean = tf_utils.smart_cond(training, lambda: mean, lambda: moving_mean)
        variance = tf_utils.smart_cond(training, lambda: variance,
                                       lambda: moving_variance)

        if self.virtual_batch_size is not None:
            # This isn't strictly correct since in ghost batch norm, you are
            # supposed to sequentially update the moving_mean and moving_variance
            # with each sub-batch. However, since the moving statistics are only
            # used during evaluation, it is more efficient to just update in one
            # step and should not make a significant difference in the result.
            new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
            new_variance = math_ops.reduce_mean(variance,
                                                axis=1,
                                                keepdims=True)
        else:
            new_mean, new_variance = mean, variance

        if self.renorm:
            r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                new_mean, new_variance, training)
            # When training, the normalized values (say, x) will be transformed as
            # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
            # = x * (r * gamma) + (d * gamma + beta) with renorm.
            r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
            d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
            scale, offset = _compose_transforms(r, d, scale, offset)

        def _do_update(var, value):
            return self._assign_moving_average(var, value, self.momentum)

        mean_update = tf_utils.smart_cond(
            training, lambda: _do_update(self.moving_mean, new_mean),
            lambda: self.moving_mean)
        variance_update = tf_utils.smart_cond(
            training, lambda: _do_update(self.moving_variance, new_variance),
            lambda: self.moving_variance)
        self.add_update(mean_update, inputs=True)
        self.add_update(variance_update, inputs=True)
        # mean, variance = self.moving_mean, self.moving_variance

        mean = math_ops.cast(mean, inputs.dtype)
        variance = math_ops.cast(variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)
        if original_training_value is None:
            outputs._uses_learning_phase = True  # pylint: disable=protected-access
        return outputs
Example #20
    def call(self, inputs, training=None):
        training = self._get_training_value(training)

        assert self.virtual_batch_size is None, "Disabled"
        assert self.fused is False, "Disabled"
        assert self.adjustment is None, "Disabled"
        assert self.renorm is False, "Disabled"

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

        def _broadcast(v):
            cond = (v is not None and len(v.shape) != ndims
                    and reduction_axes != list(range(ndims - 1)))
            if cond:
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)
        if training_value is False:
            mean, variance = self.moving_mean, self.moving_variance
        else:
            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1
            mean, variance = self._moments(math_ops.cast(
                inputs, self._param_dtype),
                                           reduction_axes,
                                           keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(
                training, lambda: mean,
                lambda: ops.convert_to_tensor_v2(moving_mean))
            variance = tf_utils.smart_cond(
                training, lambda: variance,
                lambda: ops.convert_to_tensor_v2(moving_variance))

            new_mean, new_variance = mean, variance

            if self._support_zero_size_input():
                # Keras assumes that batch dimension is the first dimension for Batch
                # Normalization.
                input_batch_size = array_ops.shape(inputs)[0]
            else:
                input_batch_size = None

            def _do_update(var, value):
                """Compute the updates for mean and variance."""
                return self._assign_moving_average(var, value, self.momentum,
                                                   input_batch_size)

            def mean_update():
                true_branch = lambda: _do_update(self.moving_mean, new_mean)
                false_branch = lambda: self.moving_mean
                return tf_utils.smart_cond(training, true_branch, false_branch)

            def variance_update():
                """Update the moving variance."""
                def true_branch_renorm():
                    # We apply epsilon as part of the moving_stddev to mirror the training
                    # code path.
                    moving_stddev = _do_update(
                        self.moving_stddev,
                        math_ops.sqrt(new_variance + self.epsilon))
                    return self._assign_new_value(
                        self.moving_variance,
                        # Apply relu in case floating point rounding causes it to go
                        # negative.
                        K.relu(moving_stddev * moving_stddev - self.epsilon))

                if self.renorm:
                    true_branch = true_branch_renorm
                else:
                    true_branch = lambda: _do_update(self.moving_variance,
                                                     new_variance)

                false_branch = lambda: self.moving_variance
                return tf_utils.smart_cond(training, true_branch, false_branch)

            self.add_update(mean_update)
            self.add_update(variance_update)

        mean = math_ops.cast(mean, inputs.dtype)
        variance = math_ops.cast(variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        if scale is not None:
            scale = math_ops.cast(scale, inputs.dtype)
        # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
        # math in float16 hurts validation accuracy of popular models like resnet.
        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        # bkeng: Flow loss/metric
        self.add_flow_loss(variance, scale)

        return outputs
Example #21
    def call(self, inputs, training=None):
        if self.scale and self.gamma_quantizer:
            quantized_gamma = self.gamma_quantizer_internal(self.gamma)
        else:
            quantized_gamma = self.gamma

        if self.center and self.beta_quantizer:
            quantized_beta = self.beta_quantizer_internal(self.beta)
        else:
            quantized_beta = self.beta

        if self.mean_quantizer:
            quantized_moving_mean = self.mean_quantizer_internal(
                self.moving_mean)
        else:
            quantized_moving_mean = self.moving_mean

        if self.variance_quantizer:
            quantized_moving_variance = self.variance_quantizer_internal(
                self.moving_variance)
        else:
            quantized_moving_variance = self.moving_variance

        training = self._get_training_value(training)

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.shape) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(quantized_gamma), _broadcast(quantized_beta)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)
        if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
            quantized_mean, quantized_variance = (quantized_moving_mean,
                                                  quantized_moving_variance)
        else:
            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = len(self.axis) > 1
            mean, variance = self._moments(math_ops.cast(
                inputs, self._param_dtype),
                                           reduction_axes,
                                           keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(
                training, lambda: mean,
                lambda: ops.convert_to_tensor(moving_mean))
            variance = tf_utils.smart_cond(
                training, lambda: variance,
                lambda: ops.convert_to_tensor(moving_variance))

            new_mean, new_variance = mean, variance

            if self.mean_quantizer:
                quantized_mean = self.mean_quantizer_internal(mean)
            else:
                quantized_mean = mean

            if self.variance_quantizer:
                quantized_variance = self.variance_quantizer_internal(variance)
            else:
                quantized_variance = variance

            if self._support_zero_size_input():
                inputs_size = array_ops.size(inputs)
            else:
                inputs_size = None

            def _do_update(var, value):
                """Compute the updates for mean and variance."""
                return self._assign_moving_average(var, value, self.momentum,
                                                   inputs_size)

            def mean_update():
                true_branch = lambda: _do_update(self.moving_mean, new_mean)
                false_branch = lambda: self.moving_mean
                return tf_utils.smart_cond(training, true_branch, false_branch)

            def variance_update():
                """Update the moving variance."""
                true_branch = lambda: _do_update(self.moving_variance,
                                                 new_variance)
                false_branch = lambda: self.moving_variance
                return tf_utils.smart_cond(training, true_branch, false_branch)

            self.add_update(mean_update)
            self.add_update(variance_update)

        quantized_mean = math_ops.cast(quantized_mean, inputs.dtype)
        quantized_variance = math_ops.cast(quantized_variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        if scale is not None:
            scale = math_ops.cast(scale, inputs.dtype)
        # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
        # math in float16 hurts validation accuracy of popular models like resnet.
        outputs = nn.batch_normalization(inputs, _broadcast(quantized_mean),
                                         _broadcast(quantized_variance),
                                         offset, scale, self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        return outputs
Example #22
    def call(self, inputs, training=None, mask=None, **kwargs):
        training = self._get_training_value(training)
        x = inputs
        if mask is not None:
            x = tf.where(tf.expand_dims(mask, axis=-1), x, tf.zeros_like(x))
        orig_dtype = x.dtype
        x = tf.cast(x, tf.float32)
        inputs_size = array_ops.size(inputs)
        axes = list(range(len(shape_list(x))))[:-1]
        training_value = tf_utils.constant_value(training)
        if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
            mean, variance = self.moving_mean, self.moving_variance
        else:
            mean, variance = masked_moments(x,
                                            mask=mask,
                                            axes=axes,
                                            keepdims=False)
            mean = tf.squeeze(mean)
            variance = tf.squeeze(variance)
            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(
                training, lambda: mean,
                lambda: ops.convert_to_tensor(moving_mean))
            variance = tf_utils.smart_cond(
                training, lambda: variance,
                lambda: tf.convert_to_tensor(moving_variance))

            def _do_update(var, value):
                """Compute the updates for mean and variance."""
                return self._assign_moving_average(var, value, self.momentum,
                                                   inputs_size)

            def mean_update():
                true_branch = lambda: _do_update(self.moving_mean, mean)
                false_branch = lambda: self.moving_mean
                return tf_utils.smart_cond(training, true_branch, false_branch)

            def variance_update():
                """Update the moving variance."""

                true_branch = lambda: _do_update(self.moving_variance, variance)

                false_branch = lambda: self.moving_variance
                return tf_utils.smart_cond(training, true_branch, false_branch)

            self.add_update(mean_update)
            self.add_update(variance_update)

        if self.scale:
            gamma = self.get_weight('gamma', training=training)
        else:
            gamma = None
        if self.center:
            beta = self.get_weight('beta', training=training)
        else:
            beta = None

        x = tf.nn.batch_normalization(x,
                                      mean=mean,
                                      variance=variance,
                                      scale=gamma,
                                      offset=beta,
                                      variance_epsilon=self.epsilon)
        x = tf.cast(x, orig_dtype)
        return x
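
Example #22 depends on masked_moments and shape_list helpers that are not shown here. As a rough, hypothetical reimplementation of what masked_moments appears to compute (mean and variance over the non-masked positions only, with x of shape [batch, time, features] and a boolean mask of shape [batch, time]):

import tensorflow as tf

def masked_moments(x, mask, axes, keepdims=False):
    # Mean/variance over `axes`, counting only positions where mask is True.
    if mask is None:
        return tf.nn.moments(x, axes, keepdims=keepdims)
    m = tf.cast(tf.expand_dims(mask, -1), x.dtype)
    count = tf.reduce_sum(m, axis=axes, keepdims=True)
    mean = tf.reduce_sum(x * m, axis=axes, keepdims=True) / count
    var = tf.reduce_sum(tf.square(x - mean) * m, axis=axes, keepdims=True) / count
    if not keepdims:
        mean = tf.squeeze(mean, axis=axes)
        var = tf.squeeze(var, axis=axes)
    return mean, var
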
Example #23
    def _fused_batch_norm(self, inputs, training):
        """Returns the output of fused batch norm."""
        beta = self.beta if self.center else self._beta_const
        gamma = self.gamma if self.scale else self._gamma_const

        # TODO(b/129279393): Support zero batch input in non DistributionStrategy
        # code as well.
        # TODO(b/130185866): Support zero batch input in graph mode.
        if ops.executing_eagerly_outside_functions(
        ) and distribution_strategy_context.has_strategy():
            inputs_size = array_ops.size(inputs)
        else:
            inputs_size = None

        def _fused_batch_norm_training():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       epsilon=self.epsilon,
                                       data_format=self._data_format)

        def _fused_batch_norm_inference():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       mean=self.moving_mean,
                                       variance=self.moving_variance,
                                       epsilon=self.epsilon,
                                       is_training=False,
                                       data_format=self._data_format)

        output, mean, variance = tf_utils.smart_cond(
            training, _fused_batch_norm_training, _fused_batch_norm_inference)
        if not self._bessels_correction_test_only:
            # Remove Bessel's correction to be consistent with non-fused batch norm.
            # Note that the variance computed by fused batch norm is
            # with Bessel's correction.
            sample_size = math_ops.cast(
                array_ops.size(inputs) / array_ops.size(variance),
                variance.dtype)
            factor = (sample_size -
                      math_ops.cast(1.0, variance.dtype)) / sample_size
            variance *= factor

        training_value = tf_utils.constant_value(training)
        if training_value is None:
            momentum = tf_utils.smart_cond(training, lambda: self.momentum,
                                           lambda: 1.0)
        else:
            momentum = ops.convert_to_tensor(self.momentum)
        if training_value or training_value is None:

            def mean_update():
                return self._assign_moving_average(self.moving_mean, mean,
                                                   momentum, inputs_size)

            def variance_update():
                return self._assign_moving_average(self.moving_variance,
                                                   variance, momentum,
                                                   inputs_size)

            self.add_update(mean_update)
            self.add_update(variance_update)

        return output
Example #24
    def call(self, inputs, training=None):
        training = self._get_training_value(training)

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in [self.axis]]

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis] = input_shape.dims[self.axis].value

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
        if self.axis != ndims - 1:
            trans = reduction_axes + [self.axis]
            transpose_recover = [i for i in range(self.axis)] + [ndims - 1] + [
                j for j in range(self.axis, ndims - 1)
            ]
            inputs = array_ops.transpose(inputs, perm=trans)
            transposed_shape = [-1] + inputs.get_shape().as_list()[1:]

        inputs = array_ops.reshape(
            inputs, shape=[-1, input_shape.dims[self.axis].value])

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)
        outputs = []
        height, width = input_shape.as_list()[1], input_shape.as_list()[2]
        for i in range(self.groups):
            start_index = i * self.m_per_group
            end_index = min((i + 1) * self.m_per_group,
                            input_shape.dims[self.axis].value)
            group_input = inputs[:, start_index:end_index]

            if training_value is not False:
                mean = tf.reduce_mean(group_input, 0, keepdims=True)
                centered = group_input - mean

                # centered_ = tf.expand_dims(centered, -1)
                # sigma = tf.matmul(centered_, tf.linalg.matrix_transpose(centered_))
                # sigma = tf.reduce_mean(sigma, 0)

                sigma = tf.matmul(tf.linalg.matrix_transpose(centered),
                                  centered)
                sigma /= (cfg.training.batch_size * height * width)

                projection = self.get_projection(sigma, group_input)

                moving_mean = self.moving_means[i]
                moving_projection = self.moving_projections[i]

                mean = tf_utils.smart_cond(
                    training, lambda: mean,
                    lambda: ops.convert_to_tensor(moving_mean))
                projection = tf_utils.smart_cond(
                    training, lambda: projection,
                    lambda: ops.convert_to_tensor(moving_projection))

                new_mean, new_projection = mean, projection

                def _do_update(var, value):
                    return self._assign_moving_average(var, value,
                                                       self.momentum, None)

                def mean_update(i=i, new_mean=new_mean):
                    # Bind the loop variables as defaults so this update uses
                    # the current group's values even if it runs after the loop.
                    true_branch = lambda: _do_update(self.moving_means[i],
                                                     new_mean)
                    false_branch = lambda: self.moving_means[i]
                    return tf_utils.smart_cond(training, true_branch,
                                               false_branch)

                def projection_update(i=i, new_projection=new_projection):
                    true_branch = lambda: _do_update(
                        self.moving_projections[i], new_projection)
                    false_branch = lambda: self.moving_projections[i]
                    return tf_utils.smart_cond(training, true_branch,
                                               false_branch)

                self.add_update(mean_update)
                self.add_update(projection_update)

            else:
                mean, projection = self.moving_means[
                    i], self.moving_projections[i]
                centered = group_input - mean

            mean = math_ops.cast(mean, inputs.dtype)
            projection = math_ops.cast(projection, inputs.dtype)

            output = tf.matmul(centered, projection)
            outputs.append(output)

        outputs = tf.concat(outputs, 1)
        if self.axis != ndims - 1:
            outputs = tf.reshape(outputs, shape=transposed_shape)
            outputs = tf.transpose(outputs, perm=transpose_recover)
        else:
            outputs = tf.reshape(outputs,
                                 shape=[-1] + input_shape.as_list()[1:])

        if scale is not None:
            scale = math_ops.cast(scale, inputs.dtype)
            outputs = outputs * scale
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
            outputs += offset

        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)
        return outputs
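
The snippet above whitens each channel group with a projection returned by get_projection, which is not shown here. One common choice for such a projection is ZCA whitening of the group covariance; the numpy sketch below illustrates that idea under that assumption (the group size, sample count and eigenvalue floor are made up for the example):

import numpy as np

def zca_projection(sigma, eps=1e-5):
    """Projection P such that centered @ P has (approximately) identity covariance."""
    eigvals, eigvecs = np.linalg.eigh(sigma)
    return eigvecs @ np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T

group = np.random.randn(1024, 16)                     # flattened N*H*W rows, 16 channels per group
centered = group - group.mean(axis=0, keepdims=True)
sigma = centered.T @ centered / centered.shape[0]

whitened = centered @ zca_projection(sigma)
print(np.allclose(np.cov(whitened, rowvar=False), np.eye(16), atol=1e-2))  # True
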
예제 #25
0
    def call(self, inputs, training=None):
        # Extract weights of previous layer to compute proper scale.
        previous_weights = self.previous_layer.weights[0].value()
        original_training_value = training
        if training is None:
            training = K.learning_phase()

        in_eager_mode = tf.executing_eagerly()

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        # For dense layers, require a full reduction.
        if self.binary_dense:
            reduction_axes = [i for i in range(ndims)]
        # Otherwise, reduce all but the feature axis.
        else:
            reduction_axes = [i for i in range(ndims) if i not in self.axis]

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape[self.axis[0]]

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return tf.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)
        if training_value is not False:
            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = len(self.axis) > 1

            mean, variance = tf.compat.v1.nn.moments(
                inputs, reduction_axes, keep_dims=keep_dims)

            # When norming the output of a binary dense layer,
            # need to make sure shape is maintained.
            if self.binary_dense:
                mean = tf.reshape(mean, [1])
                variance = tf.reshape(variance, [1])

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(training, lambda: mean,
                                       lambda: moving_mean)
            variance = tf_utils.smart_cond(training, lambda: variance,
                                           lambda: moving_variance)

            new_mean, new_variance = mean, variance

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    new_mean, new_variance, training)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(tf.stop_gradient(r, name='renorm_r'))
                d = _broadcast(tf.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)

            def _do_update(var, value):
                if in_eager_mode and not self.trainable:
                    return

                return self._assign_moving_average(var, value, self.momentum)

            mean_update = tf_utils.smart_cond(
                training, lambda: _do_update(self.moving_mean, new_mean),
                lambda: self.moving_mean)
            variance_update = tf_utils.smart_cond(
                training,
                lambda: _do_update(self.moving_variance, new_variance),
                lambda: self.moving_variance)
            if not tf.executing_eagerly():
                self.add_update(mean_update, inputs=True)
                self.add_update(variance_update, inputs=True)

        else:
            mean, variance = self.moving_mean, self.moving_variance

        mean = tf.cast(mean, inputs.dtype)
        variance = tf.cast(variance, inputs.dtype)
        if offset is not None:
            offset = tf.cast(offset, inputs.dtype)
        #outputs = nn.batch_normalization(inputs, _broadcast(mean),
        #                                 _broadcast(variance), offset, scale,
        #                                 self.epsilon)

        approximate_std, quantized_means = compute_quantized_shiftnorm(
            variance,
            mean,
            self.epsilon,
            previous_weights,
            self.extra_scale,
            self.bits,
            rescale=True)

        outputs = inputs - quantized_means
        outputs = outputs * approximate_std

        if scale is not None:
            scale = tf.cast(scale, inputs.dtype)
            outputs = scale * outputs
        if offset is not None:
            outputs = outputs + offset

        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if not tf.executing_eagerly() and original_training_value is None:
            outputs._uses_learning_phase = True  # pylint: disable=protected-access
        return outputs
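
The renorm branch above folds the correction (x * r + d) into the existing scale and offset, relying on the identity (x * r + d) * gamma + beta = x * (r * gamma) + (d * gamma + beta). A quick numpy sanity check of that algebra:

import numpy as np

x = np.random.randn(32, 8)
gamma, beta = np.random.randn(8), np.random.randn(8)
r, d = np.random.rand(8) + 0.5, np.random.randn(8)

lhs = (x * r + d) * gamma + beta             # renorm correction, then the usual affine
rhs = x * (r * gamma) + (d * gamma + beta)   # folded into a single scale/offset pair
np.testing.assert_allclose(lhs, rhs, rtol=1e-9, atol=1e-9)
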
예제 #26
0
  def call(self, inputs, training=None):
    if training is None:
      training = K.learning_phase()

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]

    scale, offset = self.gamma, self.beta


    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value is not False:

      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      keep_dims = len(self.axis) > 1
      mean, variance = self._moments(
          math_ops.cast(inputs, inputs.dtype),
          reduction_axes,
          keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      mean = tf_utils.smart_cond(training,
                                 lambda: mean,
                                 lambda: ops.convert_to_tensor(moving_mean))
      variance = tf_utils.smart_cond(
          training,
          lambda: variance,
          lambda: ops.convert_to_tensor(moving_variance))

      new_mean, new_variance = mean, variance

      if ops.executing_eagerly_outside_functions(
      ) and distribution_strategy_context.has_strategy():
        inputs_size = array_ops.size(inputs)
      else:
        inputs_size = None

      if distribution_strategy_context.in_cross_replica_context():
        strategy = distribution_strategy_context.get_strategy()

        def _do_update(var, value):
          """Compute the updates for mean and variance."""
          return strategy.extended.update(
              var,
              self._assign_moving_average, (value, self.momentum, inputs_size),
              group=False)
        # We need to unwrap the moving_mean or moving_variance in the case of
        # training being false to match the output of true_fn and false_fn
        # in the smart cond.
        def mean_update():
          true_branch = lambda: _do_update(self.moving_mean, new_mean)
          false_branch = lambda: strategy.unwrap(self.moving_mean)
          return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
          return tf_utils.smart_cond(
              training, lambda: _do_update(self.moving_variance, new_variance),
              lambda: strategy.unwrap(self.moving_variance))
      else:
        def _do_update(var, value):
          """Compute the updates for mean and variance."""
          return self._assign_moving_average(var, value, self.momentum,
                                             inputs_size)


        def mean_update():
          true_branch = lambda: _do_update(self.moving_mean, new_mean)
          false_branch = lambda: self.moving_mean
          return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
          true_branch = lambda: _do_update(self.moving_variance, new_variance)
          false_branch = lambda: self.moving_variance
          return tf_utils.smart_cond(training, true_branch, false_branch)

      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
      offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
      scale = math_ops.cast(scale, inputs.dtype)

    outputs = nn.batch_normalization(inputs,mean,variance,
                                     offset,
                                     scale,
                                     self.epsilon)

    return outputs
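
For intuition, the training-mode path of the layer above can be reproduced with the public TensorFlow 2 ops it ultimately calls into. This is a hedged, self-contained sketch, not the layer itself; batch_norm_train_step and all shapes are made up for the example:

import tensorflow as tf

def batch_norm_train_step(x, gamma, beta, epsilon=1e-3):
    """Normalize x with its own batch statistics (training-mode behaviour)."""
    # Reduce over every axis except the last (channel) axis.
    reduction_axes = list(range(x.shape.rank - 1))
    mean, variance = tf.nn.moments(x, axes=reduction_axes)
    return tf.nn.batch_normalization(x, mean, variance, beta, gamma, epsilon)

x = tf.random.normal([8, 16, 16, 32])
y = batch_norm_train_step(x, gamma=tf.ones([32]), beta=tf.zeros([32]))
print(y.shape)  # (8, 16, 16, 32)
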
예제 #27
0
    def call(self, inputs, training=None):
        training = self._get_training_value(training)

        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by reshaping the
            # Tensor and reusing the existing batch norm implementation
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

            # Will cause errors if virtual_batch_size does not divide the batch size
            inputs = array_ops.reshape(inputs, expanded_shape)

            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        if self.fused:
            outputs = self._fused_batch_norm(inputs, training=training)
            if self.virtual_batch_size is not None:
                # Currently never reaches here since fused_batch_norm does not support
                # virtual batching
                outputs = undo_virtual_batching(outputs)
            return outputs

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.shape) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)
        if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
            mean, variance = self.moving_mean, self.moving_variance
        else:
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = tf_utils.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = tf_utils.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1
            mean, variance = self._moments(math_ops.cast(
                inputs, self._param_dtype),
                                           reduction_axes,
                                           keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(
                training, lambda: mean,
                lambda: ops.convert_to_tensor(moving_mean))
            variance = tf_utils.smart_cond(
                training, lambda: variance,
                lambda: ops.convert_to_tensor(moving_variance))

            if self.virtual_batch_size is not None:
                # This isn't strictly correct since in ghost batch norm, you are
                # supposed to sequentially update the moving_mean and moving_variance
                # with each sub-batch. However, since the moving statistics are only
                # used during evaluation, it is more efficient to just update in one
                # step and should not make a significant difference in the result.
                new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
                new_variance = math_ops.reduce_mean(variance,
                                                    axis=1,
                                                    keepdims=True)
            else:
                new_mean, new_variance = mean, variance

            if self._support_zero_size_input():
                inputs_size = array_ops.size(inputs)
            else:
                inputs_size = None
            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    new_mean, new_variance, training, inputs_size)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
                d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)

            def _do_update(var, value):
                """Compute the updates for mean and variance."""
                return self._assign_moving_average(var, value, self.momentum,
                                                   inputs_size)

            def mean_update():
                true_branch = lambda: _do_update(self.moving_mean, new_mean)
                false_branch = lambda: self.moving_mean
                return tf_utils.smart_cond(training, true_branch, false_branch)

            def variance_update():
                """Update the moving variance."""
                def true_branch_renorm():
                    # We apply epsilon as part of the moving_stddev to mirror the training
                    # code path.
                    moving_stddev = _do_update(
                        self.moving_stddev,
                        math_ops.sqrt(new_variance + self.epsilon))
                    return self._assign_new_value(
                        self.moving_variance,
                        # Apply relu in case floating point rounding causes it to go
                        # negative.
                        K.relu(moving_stddev * moving_stddev - self.epsilon))

                if self.renorm:
                    true_branch = true_branch_renorm
                else:
                    true_branch = lambda: _do_update(self.moving_variance,
                                                     new_variance)

                false_branch = lambda: self.moving_variance
                return tf_utils.smart_cond(training, true_branch, false_branch)

            self.add_update(mean_update)
            self.add_update(variance_update)

        mean = math_ops.cast(mean, inputs.dtype)
        variance = math_ops.cast(variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        if scale is not None:
            scale = math_ops.cast(scale, inputs.dtype)
        # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
        # math in float16 hurts validation accuracy of popular models like resnet.
        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)
        return outputs
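
The virtual_batch_size branch above folds ghost batches into an extra leading dimension so that statistics are reduced over each sub-batch independently. A small numpy illustration of that reshape-and-reduce pattern (batch size, ghost size and feature count are arbitrary):

import numpy as np

batch, virtual_batch_size, features = 12, 4, 8
x = np.random.randn(batch, features)

# Mirror the reshape above: [batch, features] -> [virtual_batch_size, -1, features].
# Axis 0 holds the samples of one ghost batch, axis 1 indexes the ghost batches.
ghost = x.reshape(virtual_batch_size, -1, features)

# Reduce over axis 0 only, so every ghost batch keeps its own mean and variance.
ghost_mean = ghost.mean(axis=0, keepdims=True)
ghost_var = ghost.var(axis=0, keepdims=True)
normalized = (ghost - ghost_mean) / np.sqrt(ghost_var + 1e-3)

# Undo the virtual batching, as undo_virtual_batching does above.
normalized = normalized.reshape(batch, features)
print(normalized.shape)  # (12, 8)
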
예제 #28
0
    def call(self, inputs, training=None):
        if self.lr_mul == 1.0:
            kernel = self.coeff * self.kernel
        else:

            @custom_gradient
            def lr_multiplier(x):
                y = array_ops.identity(x)

                def grad(dy):
                    return dy * self.lr_mul

                return y, grad

            kernel = lr_multiplier(self.coeff * self.kernel)

        training = self._get_training_value(training)

        # Update singular vector by power iteration
        if self.data_format == 'channels_first':
            W_T = array_ops.reshape(kernel, (self.filters, -1))
            W = array_ops.transpose(W_T)
        else:
            W = array_ops.reshape(kernel, (-1, self.filters))
            W_T = array_ops.transpose(W)
        u = array_ops.identity(self.u)
        for _ in range(self.power_iter):
            v = nn_impl.l2_normalize(math_ops.matmul(u, W))  # 1 x filters
            u = nn_impl.l2_normalize(math_ops.matmul(v, W_T))
        # Spectral normalization: sigma_W estimates the largest singular value.
        sigma_W = math_ops.matmul(math_ops.matmul(u, W),
                                  array_ops.transpose(v))
        # Gradients do not need to flow through the power iteration.
        sigma_W = array_ops.stop_gradient(sigma_W)
        W_bar = kernel / array_ops.squeeze(sigma_W)

        # Assign new singular vector
        training_value = tf_utils.constant_value(training)
        if training_value is not False:

            def u_update():
                def true_branch():
                    return self._assign_singular_vector(self.u, u)

                def false_branch():
                    return self.u

                return tf_utils.smart_cond(training, true_branch, false_branch)

            self.add_update(u_update)

        # Standard convolution using the spectrally normalized kernel W_bar
        outputs = self._convolution_op(inputs, W_bar)

        if self.use_bias:
            if self.data_format == 'channels_first':
                if self.rank == 1:
                    # nn.bias_add does not accept a 1D input tensor.
                    bias = array_ops.reshape(self.bias, (1, self.filters, 1))
                    outputs += bias
                else:
                    outputs = nn.bias_add(outputs,
                                          self.bias,
                                          data_format='NCHW')
            else:
                outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')

        if self.activation is not None:
            return self.activation(outputs)
        return outputs
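
The power iteration above estimates the largest singular value of the reshaped kernel so the weights can be divided by it. A standalone numpy sketch comparing such an estimate against a full SVD (iteration count, seed and shapes are illustrative, not taken from the layer):

import numpy as np

def spectral_norm_estimate(W, power_iter=20, seed=0):
    """Estimate the largest singular value of W by power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal((1, W.shape[0]))
    for _ in range(power_iter):
        v = u @ W
        v /= np.linalg.norm(v)
        u = v @ W.T
        u /= np.linalg.norm(u)
    return (u @ W @ v.T).item()

W = np.random.randn(27, 64)   # e.g. a 3x3x3 kernel reshaped to (-1, filters)
approx = spectral_norm_estimate(W)
exact = np.linalg.svd(W, compute_uv=False)[0]
print(approx, exact)          # the two values should agree closely
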
예제 #29
0
  def call(self, inputs, training=None):
    if training is None:
      training = K.learning_phase()

    if self.virtual_batch_size is not None:
      # Virtual batches (aka ghost batches) can be simulated by reshaping the
      # Tensor and reusing the existing batch norm implementation
      original_shape = [-1] + inputs.shape.as_list()[1:]
      expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

      # Will cause errors if virtual_batch_size does not divide the batch size
      inputs = array_ops.reshape(inputs, expanded_shape)

      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs

    if self.fused:
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.virtual_batch_size is not None:
        # Currently never reaches here since fused_batch_norm does not support
        # virtual batching
        outputs = undo_virtual_batching(outputs)
      return outputs

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
      del reduction_axes[1]     # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value
    def _broadcast(v):
      if (v is not None and len(v.shape) != ndims and
          reduction_axes != list(range(ndims - 1))):
        return array_ops.reshape(v, broadcast_shape)
      return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
      if then_scale is not None:
        scale *= then_scale
        offset *= then_scale
      if then_offset is not None:
        offset += then_offset
      return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value is not False:
      if self.adjustment:
        adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
        # Adjust only during training.
        adj_scale = tf_utils.smart_cond(training,
                                        lambda: adj_scale,
                                        lambda: array_ops.ones_like(adj_scale))
        adj_bias = tf_utils.smart_cond(training,
                                       lambda: adj_bias,
                                       lambda: array_ops.zeros_like(adj_bias))
        scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
      mean, variance = self._moments(
          math_ops.cast(inputs, self._param_dtype),
          reduction_axes,
          keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      mean = tf_utils.smart_cond(training,
                                 lambda: mean,
                                 lambda: moving_mean)
      variance = tf_utils.smart_cond(training,
                                     lambda: variance,
                                     lambda: moving_variance)

      if self.virtual_batch_size is not None:
        # This isn't strictly correct since in ghost batch norm, you are
        # supposed to sequentially update the moving_mean and moving_variance
        # with each sub-batch. However, since the moving statistics are only
        # used during evaluation, it is more efficient to just update in one
        # step and should not make a significant difference in the result.
        new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
        new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
      else:
        new_mean, new_variance = mean, variance

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            new_mean, new_variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
        d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
        scale, offset = _compose_transforms(r, d, scale, offset)

      if distribution_strategy_context.in_cross_replica_context():
        strategy = distribution_strategy_context.get_strategy()

        def _do_update(var, value):
          """Compute the updates for mean and variance."""
          return strategy.extended.update(
              var, self._assign_moving_average, (value, self.momentum),
              group=False)
        # We need to unwrap the moving_mean or moving_variance in the case of
        # training being false to match the output of true_fn and false_fn
        # in the smart cond.
        def mean_update():
          true_branch = lambda: _do_update(self.moving_mean, new_mean)
          false_branch = lambda: strategy.unwrap(self.moving_mean)
          return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
          return tf_utils.smart_cond(
              training, lambda: _do_update(self.moving_variance, new_variance),
              lambda: strategy.unwrap(self.moving_variance))
      else:
        def _do_update(var, value):
          """Compute the updates for mean and variance."""
          return self._assign_moving_average(var, value, self.momentum)

        def mean_update():
          true_branch = lambda: _do_update(self.moving_mean, new_mean)
          false_branch = lambda: self.moving_mean
          return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
          true_branch = lambda: _do_update(self.moving_variance, new_variance)
          false_branch = lambda: self.moving_variance
          return tf_utils.smart_cond(training, true_branch, false_branch)

      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
      offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
      scale = math_ops.cast(scale, inputs.dtype)
    # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
    # math in float16 hurts validation accuracy of popular models like resnet.
    outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     offset,
                                     scale,
                                     self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
      outputs = undo_virtual_batching(outputs)
    return outputs
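
tf_utils.smart_cond, used throughout these snippets, is an internal Keras helper: it picks a branch at trace time when the predicate is a Python boolean or a constant tensor, and falls back to tf.cond otherwise. A hedged reimplementation of that behaviour with public APIs, for intuition only (smart_cond_sketch is not the real helper):

import tensorflow as tf

def smart_cond_sketch(pred, true_fn, false_fn):
    """Pick a branch statically when possible, otherwise defer to tf.cond."""
    if isinstance(pred, bool):                  # plain Python boolean: static choice
        return true_fn() if pred else false_fn()
    static = tf.get_static_value(pred)          # constant tensor: still static
    if static is not None:
        return true_fn() if static else false_fn()
    return tf.cond(pred, true_fn, false_fn)     # genuinely dynamic predicate

print(smart_cond_sketch(True, lambda: "train", lambda: "infer"))                # train
print(smart_cond_sketch(tf.constant(False), lambda: "train", lambda: "infer"))  # infer
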
예제 #30
0
    def _fused_batch_norm(self, inputs, training):
        """Returns the output of fused batch norm."""
        beta = self.beta if self.center else self._beta_const
        gamma = self.gamma if self.scale else self._gamma_const

        def _cross_replica_non_fused_batch_norm_training():
            # TODO(panos): assert the data format to be NHWC
            # TODO(panos): make a moments function with distributed synchronization

            def _merge_fn(strategy, per_replica_mean, per_replica_square_mean):
                # per_replica_mean: PerDevice
                # global_mean: Mirrored
                global_mean = strategy.reduce(tf.VariableAggregation.SUM,
                                              per_replica_mean,
                                              per_replica_mean)
                global_squared_mean = strategy.reduce(
                    tf.VariableAggregation.SUM, per_replica_square_mean,
                    per_replica_square_mean)
                return global_mean, global_squared_mean

            # Dispatch as much computation as possible to each replica
            per_replica_mean = tf.reduce_mean(inputs, axis=(0, 1, 2))
            per_replica_square_mean = tf.reduce_mean(tf.square(inputs),
                                                     axis=(0, 1, 2))

            replica_context = tf.contrib.distribute.get_tower_context()
            global_mean, global_squared_mean = replica_context.merge_call(
                _merge_fn, per_replica_mean / replica_context.num_towers,
                per_replica_square_mean / replica_context.num_towers)
            global_variance = global_squared_mean - tf.square(global_mean)

            inputs_normalized = tf.nn.batch_normalization(
                inputs, global_mean, global_variance, beta, gamma,
                self.epsilon)

            return inputs_normalized, global_mean, global_variance

        def _fused_batch_norm_training():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       epsilon=self.epsilon,
                                       data_format=self._data_format)

        def _fused_batch_norm_inference():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       mean=self.moving_mean,
                                       variance=self.moving_variance,
                                       epsilon=self.epsilon,
                                       is_training=False,
                                       data_format=self._data_format)

        output, mean, variance = tf_utils.smart_cond(
            training, _cross_replica_non_fused_batch_norm_training,
            _fused_batch_norm_inference)
        if not self._bessels_correction_test_only:
            # Remove Bessel's correction to be consistent with non-fused batch norm.
            # Note that the variance computed by fused batch norm is
            # with Bessel's correction.
            sample_size = math_ops.cast(
                array_ops.size(inputs) / array_ops.size(variance),
                variance.dtype)
            factor = (sample_size -
                      math_ops.cast(1.0, variance.dtype)) / sample_size
            variance *= factor

        training_value = tf_utils.constant_value(training)
        if training_value is None:
            momentum = tf_utils.smart_cond(training, lambda: self.momentum,
                                           lambda: 1.0)
        else:
            momentum = ops.convert_to_tensor(self.momentum)
        if training_value or training_value is None:
            mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                      momentum)
            variance_update = self._assign_moving_average(
                self.moving_variance, variance, momentum)
            self.add_update(mean_update, inputs=True)
            self.add_update(variance_update, inputs=True)

        return output
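
The cross-replica branch above rebuilds global statistics from per-replica means of x and x^2 using Var[x] = E[x^2] - E[x]^2. A small numpy check of that reduction across simulated, equally sized replicas:

import numpy as np

replicas = [np.random.randn(16, 8) for _ in range(4)]   # equal-sized shards of one batch

per_replica_mean = np.stack([r.mean(axis=0) for r in replicas])
per_replica_sq_mean = np.stack([(r ** 2).mean(axis=0) for r in replicas])

# Averaging per-replica statistics matches the full-batch statistics only because
# every replica holds the same number of samples (hence the / num_towers above).
global_mean = per_replica_mean.mean(axis=0)
global_sq_mean = per_replica_sq_mean.mean(axis=0)
global_var = global_sq_mean - global_mean ** 2

full_batch = np.concatenate(replicas, axis=0)
np.testing.assert_allclose(global_mean, full_batch.mean(axis=0), rtol=1e-6, atol=1e-8)
np.testing.assert_allclose(global_var, full_batch.var(axis=0), rtol=1e-6, atol=1e-8)
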
예제 #31
0
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()

        in_eager_mode = context.executing_eagerly()
        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by reshaping the
            # Tensor and reusing the existing batch norm implementation
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

            # Will cause errors if virtual_batch_size does not divide the batch size
            inputs = array_ops.reshape(inputs, expanded_shape)

            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        if self.fused:
            outputs = self._fused_batch_norm(inputs, training=training)
            if self.virtual_batch_size is not None:
                # Currently never reaches here since fused_batch_norm does not support
                # virtual batching
                outputs = undo_virtual_batching(outputs)
            return outputs

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)
        if training_value is not False:
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = tf_utils.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = tf_utils.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1
            mean, variance = self._moments(inputs,
                                           reduction_axes,
                                           keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(training, lambda: mean,
                                       lambda: moving_mean)
            variance = tf_utils.smart_cond(training, lambda: variance,
                                           lambda: moving_variance)

            if self.virtual_batch_size is not None:
                # This isn't strictly correct since in ghost batch norm, you are
                # supposed to sequentially update the moving_mean and moving_variance
                # with each sub-batch. However, since the moving statistics are only
                # used during evaluation, it is more efficient to just update in one
                # step and should not make a significant difference in the result.
                new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
                new_variance = math_ops.reduce_mean(variance,
                                                    axis=1,
                                                    keepdims=True)
            else:
                new_mean, new_variance = mean, variance

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    new_mean, new_variance, training)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
                d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)

            if distribution_strategy_context.in_cross_replica_context():
                strategy = distribution_strategy_context.get_strategy()

                def _do_update(var, value):
                    """Compute the updates for mean and variance."""
                    if in_eager_mode and not self.trainable:
                        return
                    return strategy.extended.update(
                        var,
                        self._assign_moving_average, (value, self.momentum),
                        group=False)

                # We need to unwrap the moving_mean or moving_variance in the case of
                # training being false to match the output of true_fn and false_fn
                # in the smart cond.
                mean_update = tf_utils.smart_cond(
                    training, lambda: _do_update(self.moving_mean, new_mean),
                    lambda: strategy.unwrap(self.moving_mean))
                variance_update = tf_utils.smart_cond(
                    training,
                    lambda: _do_update(self.moving_variance, new_variance),
                    lambda: strategy.unwrap(self.moving_variance))
            else:

                def _do_update(var, value):
                    """Compute the updates for mean and variance."""
                    if in_eager_mode and not self.trainable:
                        return
                    return self._assign_moving_average(var, value,
                                                       self.momentum)

                mean_update = tf_utils.smart_cond(
                    training, lambda: _do_update(self.moving_mean, new_mean),
                    lambda: self.moving_mean)
                variance_update = tf_utils.smart_cond(
                    training,
                    lambda: _do_update(self.moving_variance, new_variance),
                    lambda: self.moving_variance)
            if not context.executing_eagerly():
                self.add_update(mean_update, inputs=True)
                self.add_update(variance_update, inputs=True)

        else:
            mean, variance = self.moving_mean, self.moving_variance

        mean = math_ops.cast(mean, inputs.dtype)
        variance = math_ops.cast(variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)
        return outputs
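
_assign_moving_average, used by every variant above, implements the usual exponential moving average: the stored statistic moves toward the batch value by a fraction (1 - momentum) per update. A plain-Python sketch of that recurrence (the momentum value and the fake batch means are illustrative):

def assign_moving_average_sketch(moving, value, momentum=0.99):
    """variable -= (variable - value) * (1 - momentum), written functionally."""
    return moving - (moving - value) * (1.0 - momentum)

moving_mean = 0.0
for batch_mean in [1.0, 1.2, 0.9, 1.1]:      # pretend per-batch means
    moving_mean = assign_moving_average_sketch(moving_mean, batch_mean)
print(round(moving_mean, 4))                 # 0.0414 here; drifts toward ~1.0 over many batches
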
예제 #32
0
    def call(self, inputs, training=None):
        original_training_value = training
        if training is None:
            training = K.learning_phase()

        in_eager_mode = context.executing_eagerly()
        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by reshaping the
            # Tensor and reusing the existing batch norm implementation
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

            # Will cause errors if virtual_batch_size does not divide the batch size
            inputs = array_ops.reshape(inputs, expanded_shape)

            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        if self.fused:
            outputs = self._fused_switch_norm(inputs, training=training)
            if self.virtual_batch_size is not None:
                # Currently never reaches here since fused_batch_norm does not support
                # virtual batching
                outputs = undo_virtual_batching(outputs)
            if not context.executing_eagerly(
            ) and original_training_value is None:
                outputs._uses_learning_phase = True  # pylint: disable=protected-access
            return outputs

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.

        training_value = tf_utils.constant_value(training)
        if training_value is not False:
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = tf_utils.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = tf_utils.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1
            # _, mean_bn, variance_bn = nn.fused_batch_norm(inputs, scale=tf.ones(shape=[inputs.shape[self.axis[0]]]),
            #                                               offset=tf.zeros(shape=[inputs.shape[self.axis[0]]]),
            #                                               epsilon=self.epsilon, data_format=self._data_format)
            # mean_bn, variance_bn = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
            mean, variance, mean_bn, variance_bn = compute_stats(
                inputs,
                self.mean_weights,
                self.var_weights,
                self.hparams,
                training,
                axis=self.axis)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean_bn = tf_utils.smart_cond(training, lambda: mean_bn,
                                          lambda: moving_mean)
            variance_bn = tf_utils.smart_cond(training, lambda: variance_bn,
                                              lambda: moving_variance)

            if self.virtual_batch_size is not None:
                # This isn't strictly correct since in ghost batch norm, you are
                # supposed to sequentially update the moving_mean and moving_variance
                # with each sub-batch. However, since the moving statistics are only
                # used during evaluation, it is more efficient to just update in one
                # step and should not make a significant difference in the result.
                new_mean = math_ops.reduce_mean(mean_bn, axis=1, keepdims=True)
                new_variance = math_ops.reduce_mean(variance_bn,
                                                    axis=1,
                                                    keepdims=True)
            else:
                new_mean, new_variance = mean_bn, variance_bn

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    new_mean, new_variance, training)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
                d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)

            def _do_update(var, value):
                if in_eager_mode and not self.trainable:
                    return

                return self._assign_moving_average(var, value, self.momentum)

            mean_update = tf_utils.smart_cond(
                training, lambda: _do_update(self.moving_mean, new_mean),
                lambda: self.moving_mean)
            variance_update = tf_utils.smart_cond(
                training,
                lambda: _do_update(self.moving_variance, new_variance),
                lambda: self.moving_variance)
            if not context.executing_eagerly():
                self.add_update(mean_update, inputs=True)
                self.add_update(variance_update, inputs=True)
            print('input shape', inputs.shape)
            print('mean shape', mean.shape)
            print('variance shape', variance.shape)

        else:
            mean_bn, variance_bn = self.moving_mean, self.moving_variance
            mean, variance, _, _ = compute_stats(inputs,
                                                 self.mean_weights,
                                                 self.var_weights,
                                                 self.hparams,
                                                 training,
                                                 mean_bn=mean_bn,
                                                 variance_bn=variance_bn,
                                                 axis=self.axis)
        # Switchable-norm debug summaries for the first channel's statistics.
        tf.summary.scalar('running_mean_0_', tf.squeeze(self.moving_mean[0]))
        tf.summary.scalar('running_var_0_',
                          tf.squeeze(self.moving_variance[0]))
        tf.summary.scalar('batch_mean_0_', tf.squeeze(mean_bn[0]))

        outputs = nn.batch_normalization(inputs, mean, variance, offset, scale,
                                         self.epsilon)

        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)
        if not context.executing_eagerly() and original_training_value is None:
            outputs._uses_learning_phase = True  # pylint: disable=protected-access
        return outputs
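
compute_stats is not shown in this snippet; in Switchable Normalization the per-channel statistics are a softmax-weighted blend of batch, instance and layer statistics, which is presumably what mean_weights and var_weights control here. A hedged numpy sketch of that blending under that assumption (shapes, weights and the simplified IN/LN variances are illustrative):

import numpy as np

def softmax(w):
    e = np.exp(w - w.max())
    return e / e.sum()

x = np.random.randn(8, 16, 16, 32)                        # NHWC
mean_weights = softmax(np.zeros(3))                       # learnable in the real layer
var_weights = softmax(np.zeros(3))

# Candidate statistics: batch-norm, instance-norm and layer-norm flavours.
mean_bn = x.mean(axis=(0, 1, 2), keepdims=True)           # (1, 1, 1, C)
mean_in = x.mean(axis=(1, 2), keepdims=True)              # (N, 1, 1, C)
mean_ln = x.mean(axis=(1, 2, 3), keepdims=True)           # (N, 1, 1, 1)

var_bn = x.var(axis=(0, 1, 2), keepdims=True)
var_in = x.var(axis=(1, 2), keepdims=True)
var_ln = x.var(axis=(1, 2, 3), keepdims=True)

mean = mean_weights[0] * mean_bn + mean_weights[1] * mean_in + mean_weights[2] * mean_ln
var = var_weights[0] * var_bn + var_weights[1] * var_in + var_weights[2] * var_ln

normalized = (x - mean) / np.sqrt(var + 1e-5)
print(normalized.shape)                                   # (8, 16, 16, 32)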