def build(self, input_shape):
    channels_in = input_shape[-1]
    self.deconv_1 = ConvBlockCondDown(channels_in // 4, self.use_bias, kernel=1,
                                      momentum=self.momentum, stride=1)
    self.deconv_2 = ConvBlockCondDown(channels_in // 4, self.use_bias,
                                      momentum=self.momentum, stride=1)
    self.deconv_3 = ConvBlockCondDown(channels_in // 4, self.use_bias,
                                      momentum=self.momentum, stride=1)
    self.deconv_4 = ConvBlockCondDown(self.channels, self.use_bias, kernel=1,
                                      momentum=self.momentum, stride=1,
                                      down=self.down)
    self.extra_conv = tf_utils.constant_value(channels_in < self.channels)
    if self.extra_conv:
        self.conv_extra = Conv2D(self.channels - channels_in, 1,
                                 use_bias=self.use_bias)
def call(self, inputs, training=None):
    W_shape = self.kernel.shape.as_list()
    # Flatten the kernel into a 2-D matrix.
    W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
    _u, _v = power_iteration(W_reshaped, self.u)
    # Calculate sigma, the largest singular value.
    sigma = K.dot(_v, W_reshaped)
    sigma = K.dot(sigma, K.transpose(_u))
    # Normalize the kernel by sigma.
    w_bar = W_reshaped / sigma
    training_value = tf_utils.constant_value(training)
    if training_value is False:
        w_bar = K.reshape(w_bar, W_shape)
    else:
        # Persist the updated singular vector estimate during training.
        with tf.control_dependencies([self.u.assign(_u)]):
            w_bar = K.reshape(w_bar, W_shape)
    output = K.dot(inputs, w_bar)
    if self.use_bias:
        output = K.bias_add(output, self.bias, data_format='channels_last')
    if self.activation is not None:
        output = self.activation(output)
    return output
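# The spectral-norm Dense snippet above calls `power_iteration` and (indirectly)
# `_l2normalize` without defining them. The inline definitions that appear in
# later snippets in this file suggest the following helpers; the epsilon value
# varies between 1e-4 and 1e-12 across snippets, so treat it as an assumption:
def _l2normalize(v, eps=1e-12):
    # Normalize v to unit L2 norm, guarding against division by zero.
    return v / (K.sum(v ** 2) ** 0.5 + eps)


def power_iteration(W, u):
    # One power-iteration step: _u and _v approximate the dominant left/right
    # singular vectors of the 2-D matrix W, so K.dot(K.dot(_v, W), _u^T)
    # approximates the largest singular value of W.
    _u = u
    _v = _l2normalize(K.dot(_u, K.transpose(W)))
    _u = _l2normalize(K.dot(_v, W))
    return _u, _v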
def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    def _fused_batch_norm_training():
        return nn.fused_batch_norm(inputs, gamma, beta,
                                   epsilon=self.epsilon,
                                   data_format=self._data_format)

    def _fused_batch_norm_inference():
        return nn.fused_batch_norm(inputs, gamma, beta,
                                   mean=self.moving_mean,
                                   variance=self.moving_variance,
                                   epsilon=self.epsilon,
                                   is_training=False,
                                   data_format=self._data_format)

    output, mean, variance = tf_utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)

    if not self._bessels_correction_test_only:
        # Remove Bessel's correction to be consistent with non-fused batch
        # norm. Note that the variance computed by fused batch norm is with
        # Bessel's correction.
        sample_size = math_ops.cast(
            array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
        factor = (sample_size -
                  math_ops.cast(1.0, variance.dtype)) / sample_size
        variance *= factor

    training_value = tf_utils.constant_value(training)
    if training_value is None:
        momentum = tf_utils.smart_cond(training,
                                       lambda: self.momentum,
                                       lambda: 1.0)
    else:
        momentum = ops.convert_to_tensor(self.momentum)

    if training_value or training_value is None:
        if distribution_strategy_context.in_cross_replica_context():
            strategy = distribution_strategy_context.get_strategy()
            mean_update = strategy.extended.update(
                self.moving_mean, self._assign_moving_average,
                (mean, self.momentum))
            variance_update = strategy.extended.update(
                self.moving_variance, self._assign_moving_average,
                (variance, self.momentum))
        else:
            mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                      momentum)
            variance_update = self._assign_moving_average(self.moving_variance,
                                                          variance, momentum)
        self.add_update(mean_update, inputs=True)
        self.add_update(variance_update, inputs=True)
    return output
def call(self, inputs, training=None):
    training_value = tf_utils.constant_value(training)
    if len(inputs) != 2:
        raise ValueError('CondBatchNorm layer requires a list of two inputs')
    norm_t, noise_t = inputs
    # Condition beta and gamma on the noise tensor.
    beta_t = self.beta(noise_t)
    beta_t = self.resphape_beta(beta_t)
    gamma_t = self.gamma(noise_t)
    gamma_t = self.resphape_gamma(gamma_t)
    if training_value:
        batch_mean, batch_var = tf.nn.moments(norm_t, [0, 1, 2],
                                              name='batchMoments')
        test_mean = batch_mean * (1 - self.momentum) + \
            self.moving_mean * self.momentum
        test_var = batch_var * (1 - self.momentum) + \
            self.moving_variance * self.momentum
        with tf.control_dependencies([
                self.moving_mean.assign(test_mean),
                self.moving_variance.assign(test_var)
        ]):
            return tf.nn.batch_normalization(norm_t, batch_mean, batch_var,
                                             beta_t, gamma_t, self.epsilon)
    else:
        return tf.nn.batch_normalization(norm_t, self.moving_mean,
                                         self.moving_variance, beta_t,
                                         gamma_t, self.epsilon)
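# The conditional sub-layers used above (self.beta, self.gamma, and the two
# reshape layers) are built elsewhere and not shown. A hypothetical build()
# consistent with the call site; the layer types and the channel count are
# assumptions, and the misspelled attribute names are kept only to match the
# call() above:
def build(self, input_shape):
    channels = input_shape[0][-1]  # channels of the tensor being normalized
    self.beta = tf.keras.layers.Dense(channels)   # per-sample offset from noise
    self.resphape_beta = tf.keras.layers.Reshape((1, 1, channels))
    self.gamma = tf.keras.layers.Dense(channels)  # per-sample scale from noise
    self.resphape_gamma = tf.keras.layers.Reshape((1, 1, channels))
    self.moving_mean = self.add_weight(
        'moving_mean', shape=(channels,), initializer='zeros', trainable=False)
    self.moving_variance = self.add_weight(
        'moving_variance', shape=(channels,), initializer='ones',
        trainable=False)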
def call(self, inputs, training=None):
    training_value = tf_utils.constant_value(training)

    def _l2normalize(v):
        return v / (K.sum(v ** 2) ** 0.5 + 1e-4)

    def power_iteration(W, u):
        _u = u
        _v = _l2normalize(K.dot(_u, K.transpose(W)))
        _u = _l2normalize(K.dot(_v, W))
        return _u, _v

    W_shape = self.kernel.shape.as_list()
    W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
    _u, _v = power_iteration(W_reshaped, self.u)
    sigma = K.dot(_v, W_reshaped)
    sigma = K.dot(sigma, K.transpose(_u))
    W_bar = W_reshaped / sigma
    if training_value is False:
        W_bar = K.reshape(W_bar, W_shape)
    else:
        with tf.control_dependencies([self.u.assign(_u)]):
            W_bar = K.reshape(W_bar, W_shape)
    output = K.dot(inputs, W_bar)
    if self.use_bias:
        output = K.bias_add(output, self.bias, data_format='channels_last')
    if self.activation is not None:
        output = self.activation(output)
    return output
def _resolve_training(layer, training):
    if training is None:
        training = K.learning_phase()
    if isinstance(training, int):
        training = bool(training)
    if not layer.trainable:
        # When the layer is frozen, always run in inference mode.
        training = False
    return tf_utils.constant_value(training)
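# A quick illustration of what _resolve_training returns. The toy layer object
# below is illustrative only, not from the source:
class _ToyLayer:
    trainable = True

_layer = _ToyLayer()
assert _resolve_training(_layer, True) is True   # explicit flag wins
assert _resolve_training(_layer, 1) is True      # ints are coerced to bool
_layer.trainable = False
assert _resolve_training(_layer, True) is False  # frozen layers never train
# With training=None, the Keras learning phase is consulted; the result can be
# None when the phase is a symbolic tensor, which is why the snippets in this
# file compare with `is False` / `is not False` rather than plain truthiness.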
def call(self, inputs, training=None):
    if self.lr_mul == 1.0:
        W = self.coeff * self.kernel
    else:
        @custom_gradient
        def lr_multiplier(x):
            y = array_ops.identity(x)

            def grad(dy):
                return dy * self.lr_mul

            return y, grad

        W = lr_multiplier(self.coeff * self.kernel)

    training = self._get_training_value(training)

    # Update the singular vector estimate by power iteration.
    W_T = array_ops.transpose(W)
    u = array_ops.identity(self.u)
    for i in range(self.power_iter):
        v = nn_impl.l2_normalize(math_ops.matmul(u, W))  # 1 x filters
        u = nn_impl.l2_normalize(math_ops.matmul(v, W_T))

    # Spectral normalization: divide by the largest singular value.
    sigma_W = math_ops.matmul(math_ops.matmul(u, W), array_ops.transpose(v))
    # No gradient should flow through the power iteration.
    sigma_W = array_ops.stop_gradient(sigma_W)
    W_bar = W / array_ops.squeeze(sigma_W)

    # Assign the new singular vector estimate.
    training_value = tf_utils.constant_value(training)
    if training_value is not False:
        def u_update():
            def true_branch():
                return self._assign_singular_vector(self.u, u)

            def false_branch():
                return self.u

            return tf_utils.smart_cond(training, true_branch, false_branch)

        self.add_update(u_update)

    # Standard Dense forward pass using W_bar.
    inputs = ops.convert_to_tensor(inputs)
    rank = common_shapes.rank(inputs)
    if rank > 2:
        # Broadcasting is required for the inputs.
        outputs = standard_ops.tensordot(inputs, W_bar, [[rank - 1], [0]])
        # Reshape the output back to the original ndim of the input.
        if not context.executing_eagerly():
            shape = inputs.shape.as_list()
            output_shape = shape[:-1] + [self.units]
            outputs.set_shape(output_shape)
    else:
        inputs = math_ops.cast(inputs, self._compute_dtype)
        outputs = math_ops.mat_mul(inputs, W_bar)
    if self.use_bias:
        outputs = nn.bias_add(outputs, self.bias)
    if self.activation is not None:
        return self.activation(outputs)  # pylint: disable=not-callable
    return outputs
def call(self, inputs, training=None):
    training = self._get_training_value(training)
    latent1, latent2, lod = inputs
    training_value = tf_utils.constant_value(training)

    latent_avg_new = math_ops.reduce_mean(latent1[:, 0], axis=0)
    if training_value is not False and self.update_latent_avg:
        latent_avg_new = self._interpolate(latent_avg_new, self.latent_avg,
                                           self.latent_avg_beta)

        def update_op():
            def true_branch():
                return self._assign_latent_avg(self.latent_avg, latent_avg_new)

            def false_branch():
                return self.latent_avg

            return tf_utils.smart_cond(training, true_branch, false_branch)

        self.add_update(update_op)

    if training_value is not False and self.mix_latents:
        def true_branch():
            cur_layer = 2 * (1 + math_ops.cast(
                array_ops.reshape(lod, [-1])[0], dtypes.int32))
            cutoff = tf_utils.smart_cond(
                random_ops.random_uniform([], 0.0, 1.0) < self.mixing_prob,
                lambda: random_ops.random_uniform([], 1, cur_layer,
                                                  dtypes.int32),
                lambda: cur_layer)
            return array_ops.where(
                array_ops.broadcast_to(self.layer_idx < cutoff,
                                       array_ops.shape(latent1)),
                latent1, latent2)

        def false_branch():
            return latent1

        latent1 = tf_utils.smart_cond(training, true_branch, false_branch)

    if training_value is not True and self.truncate_latent:
        def true_branch():
            return latent1

        def false_branch():
            return self._interpolate(latent_avg_new, latent1, self.coeff)

        latent1 = tf_utils.smart_cond(training, true_branch, false_branch)

    return latent1
def call(self, inputs, training=None):
    """Called by __call__.

    Arguments:
        inputs: activations into the layer
        training: boolean to set training or inference mode

    Returns:
        Normalized activations with multiplicative scale and additive bias
        corrections.
    """
    if training is None:
        training = K.learning_phase()
    # Determine a boolean value for `training`: could be True, False, or None.
    training = tf_utils.constant_value(training)
    input_shape = inputs.get_shape()

    def _bcast(inputs):
        """Broadcasts a tensor for operations with a tensor of larger rank."""
        if inputs is None:
            return None
        bcast_shape = [1] * len(input_shape)
        for a in self.axis:
            bcast_shape[a] = input_shape[a]
        return tf.reshape(inputs, bcast_shape)

    # Cast to the mixed-precision compute type.
    precise_inputs = tf.cast(inputs, self.mp_type)

    # Streaming / control normalization.
    if training is not False:
        outputs = self.normalization(precise_inputs)
    else:
        mu = tf.cast(_bcast(self.mu), self.mp_type)
        denom = tf.cast(tf.math.sqrt(self.var + self.epsilon), self.mp_type)
        outputs = (inputs - mu) / _bcast(denom)
    outputs = tf.cast(outputs, self.mp_type)
    return outputs
def compute_spectral_normal(self, training):
    # Spectrally normalized weight.
    if self.spectral_normalization:
        # Flatten the kernel to [out_channels, N].
        W_mat = KB.reshape(self.kernel, [self.out_dim, -1])
        sigma, u, _ = power_iteration(W_mat, self.u)

        def true_fn():
            self.u.assign(u)

        def false_fn():
            pass

        training_value = tf_utils.constant_value(training)
        if training_value is not False:
            tf_utils.smart_cond(training, true_fn, false_fn)
        return self.kernel / sigma
    else:
        return self.kernel
def call(self, inputs, training=None):
    def _l2normalize(v, eps=1e-12):
        return v / (K.sum(v ** 2) ** 0.5 + eps)

    def power_iteration(W, u):
        # According to the paper, a single power-iteration step is sufficient.
        _u = u
        _v = _l2normalize(K.dot(_u, K.transpose(W)))
        _u = _l2normalize(K.dot(_v, W))
        return _u, _v

    # Spectral normalization.
    W_shape = self.kernel.shape.as_list()
    # Flatten the kernel into a 2-D matrix.
    W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
    _u, _v = power_iteration(W_reshaped, self.u)
    # Calculate sigma, the largest singular value.
    sigma = K.dot(_v, W_reshaped)
    sigma = K.dot(sigma, K.transpose(_u))
    # Normalize the kernel and reshape it back.
    W_bar = W_reshaped / sigma
    training_value = tf_utils.constant_value(training)
    if training_value is False:
        W_bar = K.reshape(W_bar, W_shape)
    else:
        with tf.control_dependencies([self.u.assign(_u)]):
            W_bar = K.reshape(W_bar, W_shape)
    outputs = K.conv2d(inputs, W_bar,
                       strides=self.strides,
                       padding=self.padding,
                       data_format=self.data_format,
                       dilation_rate=self.dilation_rate)
    if self.use_bias:
        outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
    if self.activation is not None:
        return self.activation(outputs)
    return outputs
def call(self, inputs, training=None):
    x, y = inputs
    x = tf.nn.l2_normalize(x, axis=1)
    w = tf.nn.l2_normalize(self._w, axis=0)
    logits = tf.matmul(x, w)

    is_dynamic = tf_utils.constant_value(self._is_dynamic)
    if not is_dynamic:
        # Fixed-scale variant.
        output = tf.multiply(self._init_s, logits)
        return output

    training = _resolve_training(self, training)
    if not training:
        # Inference uses the current scale without updating it.
        return self._s * logits
    else:
        theta = tf.math.acos(
            K.clip(logits, -1.0 + K.epsilon(), 1.0 - K.epsilon()))
        b_avg = tf.where(y < 1.0,
                         tf.exp(self._s * logits),
                         tf.zeros_like(logits))
        b_avg = tf.reduce_mean(tf.reduce_sum(b_avg, axis=1))
        theta_class = tf.gather(theta, tf.cast(y, tf.int32))
        theta_med = tfp.stats.percentile(theta_class, q=50)
        # Dynamically adapt the scale from the batch statistics.
        self._s.assign(
            tf.math.log(b_avg) /
            tf.math.cos(tf.minimum(math.pi / 4, theta_med)))
        self._s.assign_add(-0.5)
        logits *= self._s
        out = tf.nn.sigmoid(logits)
        return out
def compute_spectral_normal(self, training):
    # Spectrally normalized weight.
    if self.spectral_normalization:
        # For a transposed conv, the kernel shape is [H, W, out, in], so the
        # output dimension is W_shape[-2].
        W_mat = KB.reshape(self.kernel, [self.out_dim, -1])  # [out_c, N]
        sigma, u, _ = power_iteration(W_mat, self.u)

        def true_fn():
            self.u.assign(u)

        def false_fn():
            pass

        training_value = tf_utils.constant_value(training)
        if training_value is not False:
            tf_utils.smart_cond(training, true_fn, false_fn)
        return self.kernel / sigma
    else:
        return self.kernel
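# Both compute_spectral_normal variants above assume a module-level
# power_iteration(W, u) that returns (sigma, u, v), unlike the (u, v) version
# used by the Dense/Conv snippets. A sketch inferred from the call sites; the
# iteration count and the use of plain tf ops are assumptions:
def power_iteration(W, u, rounds=1):
    # W: [out_dim, N] flattened kernel; u: [1, out_dim] persistent estimate of
    # the dominant left singular vector.
    for _ in range(rounds):
        v = tf.math.l2_normalize(tf.matmul(u, W))                    # [1, N]
        u = tf.math.l2_normalize(tf.matmul(v, W, transpose_b=True))  # [1, out_dim]
    # Estimate of the largest singular value, shape [1, 1]; broadcasting
    # divides the whole kernel by this scalar at the call site.
    sigma = tf.matmul(tf.matmul(u, W), v, transpose_b=True)
    return sigma, u, v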
def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
    # code as well.
    if self._support_zero_size_input():
        inputs_size = array_ops.size(inputs)
    else:
        inputs_size = None

    def _fused_batch_norm_training():
        return nn.fused_batch_norm(inputs, gamma, beta,
                                   epsilon=self.epsilon,
                                   data_format=self._data_format)

    def _fused_batch_norm_inference():
        return nn.fused_batch_norm(inputs, gamma, beta,
                                   mean=self.moving_mean,
                                   variance=self.moving_variance,
                                   epsilon=self.epsilon,
                                   is_training=False,
                                   data_format=self._data_format)

    output, mean, variance = tf_utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)

    if not self._bessels_correction_test_only:
        # Remove Bessel's correction to be consistent with non-fused batch
        # norm. Note that the variance computed by fused batch norm is with
        # Bessel's correction.
        sample_size = math_ops.cast(
            array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
        factor = (sample_size -
                  math_ops.cast(1.0, variance.dtype)) / sample_size
        variance *= factor

    training_value = tf_utils.constant_value(training)
    if training_value is None:
        momentum = tf_utils.smart_cond(training,
                                       lambda: self.momentum,
                                       lambda: 1.0)
    else:
        momentum = ops.convert_to_tensor(self.momentum)

    if training_value or training_value is None:

        def mean_update():
            return self._assign_moving_average(self.moving_mean, mean,
                                               momentum, inputs_size)

        def variance_update():
            """Update self.moving_variance with the most recent data point."""
            if self.renorm:
                # We apply epsilon as part of the moving_stddev to mirror the
                # training code path.
                moving_stddev = self._assign_moving_average(
                    self.moving_stddev,
                    math_ops.sqrt(variance + self.epsilon),
                    momentum, inputs_size)
                return self._assign_new_value(
                    self.moving_variance,
                    # Apply relu in case floating point rounding causes it to
                    # go negative.
                    K.relu(moving_stddev * moving_stddev - self.epsilon))
            else:
                return self._assign_moving_average(self.moving_variance,
                                                   variance, momentum,
                                                   inputs_size)

        self.add_update(mean_update)
        self.add_update(variance_update)

    return output
def call(self, inputs, training=None):
    """Called by __call__.

    Arguments:
        inputs: activations into the layer
        training: boolean to set training or inference mode

    Returns:
        Normalized activations with multiplicative scale and additive bias
        corrections.
    """
    if training is None:
        training = K.learning_phase()
    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    input_shape = inputs.get_shape()

    def _bcast(inputs):
        """Broadcasts a tensor for operations with a tensor of larger rank."""
        if inputs is None:
            return None
        bcast_shape = [1] * len(input_shape)
        for a in self.axis:
            bcast_shape[a] = input_shape[a]
        return tf.reshape(inputs, bcast_shape)

    mixed_precision = (inputs.dtype == dtypes.float16 or
                       inputs.dtype == dtypes.bfloat16)

    # Cast fp16 inputs to fp32 for the normalization math.
    precise_inputs = inputs
    if mixed_precision:
        precise_inputs = math_ops.cast(inputs, dtypes.float32)

    # Streaming / control normalization.
    if training_value is not False:
        x_norm = tf_utils.smart_cond(
            training,
            lambda: self.control_normalization(precise_inputs),
            lambda: tf.nn.batch_normalization(
                precise_inputs,
                tf.reshape(self.mu[-1], self.broadcast_shape),
                tf.reshape(self.var[-1], self.broadcast_shape),
                None, None, self.epsilon))
    else:
        x_norm = tf.nn.batch_normalization(
            precise_inputs,
            tf.reshape(self.mu[-1], self.broadcast_shape),
            tf.reshape(self.var[-1], self.broadcast_shape),
            None, None, self.epsilon)

    # Scale and bias.
    x_scaled = x_norm * _bcast(self.gamma) if self.scale else x_norm
    x_bias = x_scaled + _bcast(self.beta) if self.center else x_scaled
    outputs = self.layer_scaling(x_bias) if self.ls else x_bias

    # If needed, cast back to fp16.
    if mixed_precision:
        outputs = math_ops.cast(outputs, inputs.dtype)
    return outputs
def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
    # code as well.
    if self._support_zero_size_input():
        inputs_size = array_ops.size(inputs)
    else:
        inputs_size = None

    # TODO(rmlarsen): Support using fused avg updates for non-eager execution
    # after fixing graph pattern matching and enabling fused_batch_norm to
    # take exponential_avg_factor as a tensor input.
    use_fused_avg_updates = (
        compat.forward_compatible(2020, 3, 6) and
        ops.executing_eagerly_outside_functions())
    if use_fused_avg_updates:
        exponential_avg_factor = 1.0 - self.momentum
    else:
        exponential_avg_factor = None

    def _maybe_add_or_remove_bessels_correction(variance, remove=True):
        r"""Add or remove Bessel's correction."""
        # Removes Bessel's correction if remove == True, adds it otherwise.
        # This is to be consistent with non-fused batch norm. Note that the
        # variance computed by fused batch norm is with Bessel's correction.
        # This is only used in legacy V1 batch norm tests.
        if self._bessels_correction_test_only:
            return variance
        sample_size = math_ops.cast(
            array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
        if remove:
            factor = (sample_size -
                      math_ops.cast(1.0, variance.dtype)) / sample_size
        else:
            factor = sample_size / (
                sample_size - math_ops.cast(1.0, variance.dtype))
        return variance * factor

    def _fused_batch_norm_training():
        return nn.fused_batch_norm(
            inputs,
            gamma,
            beta,
            mean=self.moving_mean,
            variance=_maybe_add_or_remove_bessels_correction(
                self.moving_variance, remove=False),
            epsilon=self.epsilon,
            is_training=True,
            data_format=self._data_format,
            exponential_avg_factor=exponential_avg_factor)

    def _fused_batch_norm_training_empty():
        return inputs, self.moving_mean, self.moving_variance

    def _fused_batch_norm_inference():
        return nn.fused_batch_norm(
            inputs,
            gamma,
            beta,
            mean=self.moving_mean,
            variance=self.moving_variance,
            epsilon=self.epsilon,
            is_training=False,
            data_format=self._data_format)

    train_op = _fused_batch_norm_training
    if use_fused_avg_updates and inputs_size is not None:
        train_op = lambda: tf_utils.smart_cond(
            inputs_size > 0,
            _fused_batch_norm_training,
            _fused_batch_norm_training_empty)

    output, mean, variance = tf_utils.smart_cond(training, train_op,
                                                 _fused_batch_norm_inference)
    variance = _maybe_add_or_remove_bessels_correction(variance, remove=True)

    training_value = tf_utils.constant_value(training)
    if training_value or training_value is None:
        if not use_fused_avg_updates:
            if training_value is None:
                momentum = tf_utils.smart_cond(training,
                                               lambda: self.momentum,
                                               lambda: 1.0)
            else:
                momentum = ops.convert_to_tensor_v2(self.momentum)

        def mean_update():
            """Update self.moving_mean with the most recent data point."""
            if use_fused_avg_updates:
                return self._assign_new_value(self.moving_mean, mean)
            else:
                return self._assign_moving_average(self.moving_mean, mean,
                                                   momentum, inputs_size)

        def variance_update():
            """Update self.moving_variance with the most recent data point."""
            if use_fused_avg_updates:
                return self._assign_new_value(self.moving_variance, variance)
            else:
                return self._assign_moving_average(self.moving_variance,
                                                   variance, momentum,
                                                   inputs_size)

        self.add_update(mean_update)
        self.add_update(variance_update)

    return output
def call(self, inputs, training=None):
    training = self._get_training_value(training)

    # Compute the axes along which to reduce the mean / variance.
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]

    # Broadcasting only necessary for single-axis batch norm where the axis
    # is not the last dimension.
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

    def _broadcast(v):
        if (v is not None and len(v.shape) != ndims and
                reduction_axes != list(range(ndims - 1))):
            return tf.reshape(v, broadcast_shape)
        return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
        if then_scale is not None:
            scale *= then_scale
            offset *= then_scale
        if then_offset is not None:
            offset += then_offset
        return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value is False:
        mean, variance = self.moving_mean, self.moving_variance
    else:
        # Some of the computations here are not necessary when training==False
        # but not a constant. However, this makes the code simpler.
        keep_dims = len(self.axis) > 1
        mean, variance = self._moments(
            tf.cast(inputs, self._param_dtype),
            reduction_axes,
            keep_dims=keep_dims)

        moving_mean = self.moving_mean
        moving_variance = self.moving_variance

        mean = tf_utils.smart_cond(
            training, lambda: mean,
            lambda: tf.convert_to_tensor(moving_mean))
        variance = tf_utils.smart_cond(
            training, lambda: variance,
            lambda: tf.convert_to_tensor(moving_variance))

        new_mean, new_variance = mean, variance
        inputs_size = None

        def _do_update(var, value):
            """Compute the updates for mean and variance."""
            return self._assign_moving_average(var, value, self.momentum,
                                               inputs_size)

        def mean_update():
            def true_branch():
                return _do_update(self.moving_mean, new_mean)

            def false_branch():
                return self.moving_mean

            return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
            """Update the moving variance."""
            def true_branch():
                return _do_update(self.moving_variance, new_variance)

            def false_branch():
                return self.moving_variance

            return tf_utils.smart_cond(training, true_branch, false_branch)

        self.add_update(mean_update)
        self.add_update(variance_update)

    mean = tf.cast(mean, inputs.dtype)
    variance = tf.cast(variance, inputs.dtype)
    if offset is not None:
        offset = tf.cast(offset, inputs.dtype)
    if scale is not None:
        scale = tf.cast(scale, inputs.dtype)
    outputs = tf.nn.batch_normalization(inputs, _broadcast(mean),
                                        _broadcast(variance), offset, scale,
                                        self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)
    return outputs
def call(self, inputs, params=None, training=None):
    if params.get(self.name + '/gamma:0') is None:
        return super(layers.BatchNormalization, self).call(inputs)
    else:
        gamma = params.get(self.name + '/gamma:0')
        beta = params.get(self.name + '/beta:0')
        original_training_value = training
        if training is None:
            training = backend.learning_phase()

        # Compute the axes along which to reduce the mean / variance.
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]

        # Broadcasting only necessary for single-axis batch norm where the
        # axis is not the last dimension.
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims and
                    reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(gamma), _broadcast(beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or
        # None.
        training_value = tf_utils.constant_value(training)
        if training_value is not False:
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = tf_utils.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = tf_utils.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias,
                                                    scale, offset)

            # Some of the computations here are not necessary when
            # training==False but not a constant. However, this makes the code
            # simpler.
            keep_dims = (self.virtual_batch_size is not None or
                         len(self.axis) > 1)
            mean, variance = nn.moments(inputs, reduction_axes,
                                        keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(training, lambda: mean,
                                       lambda: moving_mean)
            variance = tf_utils.smart_cond(training, lambda: variance,
                                           lambda: moving_variance)

            if self.virtual_batch_size is not None:
                # This isn't strictly correct since in ghost batch norm, you
                # are supposed to sequentially update the moving_mean and
                # moving_variance with each sub-batch. However, since the
                # moving statistics are only used during evaluation, it is
                # more efficient to just update in one step and should not
                # make a significant difference in the result.
                new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
                new_variance = math_ops.reduce_mean(variance, axis=1,
                                                    keepdims=True)
            else:
                new_mean, new_variance = mean, variance

            if self.renorm:
                r, d, new_mean, new_variance = (
                    self._renorm_correction_and_moments(
                        new_mean, new_variance, training))
                # When training, the normalized values (say, x) will be
                # transformed as x * gamma + beta without renorm, and
                # (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
                d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)

            def _do_update(var, value):
                return self._assign_moving_average(var, value, self.momentum)

            mean_update = tf_utils.smart_cond(
                training, lambda: _do_update(self.moving_mean, new_mean),
                lambda: self.moving_mean)
            variance_update = tf_utils.smart_cond(
                training,
                lambda: _do_update(self.moving_variance, new_variance),
                lambda: self.moving_variance)
            self.add_update(mean_update, inputs=True)
            self.add_update(variance_update, inputs=True)
        else:
            mean, variance = self.moving_mean, self.moving_variance

        mean = math_ops.cast(mean, inputs.dtype)
        variance = math_ops.cast(variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        # If some components of the shape got lost due to adjustments, fix
        # that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)

        if original_training_value is None:
            outputs._uses_learning_phase = True  # pylint: disable=protected-access
        return outputs
def call(self, inputs, training=None):
    training = self._get_training_value(training)
    assert self.virtual_batch_size is None, "Disabled"
    assert self.fused is False, "Disabled"
    assert self.adjustment is None, "Disabled"
    assert self.renorm is False, "Disabled"

    # Compute the axes along which to reduce the mean / variance.
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]

    # Broadcasting only necessary for single-axis batch norm where the axis
    # is not the last dimension.
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

    def _broadcast(v):
        cond = (v is not None and len(v.shape) != ndims and
                reduction_axes != list(range(ndims - 1)))
        if cond:
            return array_ops.reshape(v, broadcast_shape)
        return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
        if then_scale is not None:
            scale *= then_scale
            offset *= then_scale
        if then_offset is not None:
            offset += then_offset
        return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value is False:
        mean, variance = self.moving_mean, self.moving_variance
    else:
        # Some of the computations here are not necessary when training==False
        # but not a constant. However, this makes the code simpler.
        keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
        mean, variance = self._moments(
            math_ops.cast(inputs, self._param_dtype),
            reduction_axes,
            keep_dims=keep_dims)

        moving_mean = self.moving_mean
        moving_variance = self.moving_variance

        mean = tf_utils.smart_cond(
            training, lambda: mean,
            lambda: ops.convert_to_tensor_v2(moving_mean))
        variance = tf_utils.smart_cond(
            training, lambda: variance,
            lambda: ops.convert_to_tensor_v2(moving_variance))

        new_mean, new_variance = mean, variance

        if self._support_zero_size_input():
            # Keras assumes that batch dimension is the first dimension for
            # Batch Normalization.
            input_batch_size = array_ops.shape(inputs)[0]
        else:
            input_batch_size = None

        def _do_update(var, value):
            """Compute the updates for mean and variance."""
            return self._assign_moving_average(var, value, self.momentum,
                                               input_batch_size)

        def mean_update():
            true_branch = lambda: _do_update(self.moving_mean, new_mean)
            false_branch = lambda: self.moving_mean
            return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
            """Update the moving variance."""

            def true_branch_renorm():
                # We apply epsilon as part of the moving_stddev to mirror the
                # training code path.
                moving_stddev = _do_update(
                    self.moving_stddev,
                    math_ops.sqrt(new_variance + self.epsilon))
                return self._assign_new_value(
                    self.moving_variance,
                    # Apply relu in case floating point rounding causes it to
                    # go negative.
                    K.relu(moving_stddev * moving_stddev - self.epsilon))

            if self.renorm:
                true_branch = true_branch_renorm
            else:
                true_branch = lambda: _do_update(self.moving_variance,
                                                 new_variance)
            false_branch = lambda: self.moving_variance
            return tf_utils.smart_cond(training, true_branch, false_branch)

        self.add_update(mean_update)
        self.add_update(variance_update)

    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
        offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
        scale = math_ops.cast(scale, inputs.dtype)
    # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
    # math in float16 hurts validation accuracy of popular models like resnet.
    outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                     _broadcast(variance), offset, scale,
                                     self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    # bkeng: Flow loss/metric
    self.add_flow_loss(variance, scale)
    return outputs
def call(self, inputs, training=None):
    if self.scale and self.gamma_quantizer:
        quantized_gamma = self.gamma_quantizer_internal(self.gamma)
    else:
        quantized_gamma = self.gamma

    if self.center and self.beta_quantizer:
        quantized_beta = self.beta_quantizer_internal(self.beta)
    else:
        quantized_beta = self.beta

    if self.mean_quantizer:
        quantized_moving_mean = self.mean_quantizer_internal(self.moving_mean)
    else:
        quantized_moving_mean = self.moving_mean

    if self.variance_quantizer:
        quantized_moving_variance = self.variance_quantizer_internal(
            self.moving_variance)
    else:
        quantized_moving_variance = self.moving_variance

    training = self._get_training_value(training)

    # Compute the axes along which to reduce the mean / variance.
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]

    # Broadcasting only necessary for single-axis batch norm where the axis
    # is not the last dimension.
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

    def _broadcast(v):
        if (v is not None and len(v.shape) != ndims and
                reduction_axes != list(range(ndims - 1))):
            return array_ops.reshape(v, broadcast_shape)
        return v

    scale, offset = _broadcast(quantized_gamma), _broadcast(quantized_beta)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
        quantized_mean, quantized_variance = (quantized_moving_mean,
                                              quantized_moving_variance)
    else:
        # Some of the computations here are not necessary when training==False
        # but not a constant. However, this makes the code simpler.
        keep_dims = len(self.axis) > 1
        mean, variance = self._moments(
            math_ops.cast(inputs, self._param_dtype),
            reduction_axes,
            keep_dims=keep_dims)

        moving_mean = self.moving_mean
        moving_variance = self.moving_variance

        mean = tf_utils.smart_cond(
            training, lambda: mean,
            lambda: ops.convert_to_tensor(moving_mean))
        variance = tf_utils.smart_cond(
            training, lambda: variance,
            lambda: ops.convert_to_tensor(moving_variance))

        new_mean, new_variance = mean, variance

        if self.mean_quantizer:
            quantized_mean = self.mean_quantizer_internal(mean)
        else:
            quantized_mean = mean

        if self.variance_quantizer:
            quantized_variance = self.variance_quantizer_internal(variance)
        else:
            quantized_variance = variance

        if self._support_zero_size_input():
            inputs_size = array_ops.size(inputs)
        else:
            inputs_size = None

        def _do_update(var, value):
            """Compute the updates for mean and variance."""
            return self._assign_moving_average(var, value, self.momentum,
                                               inputs_size)

        def mean_update():
            true_branch = lambda: _do_update(self.moving_mean, new_mean)
            false_branch = lambda: self.moving_mean
            return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
            """Update the moving variance."""
            true_branch = lambda: _do_update(self.moving_variance,
                                             new_variance)
            false_branch = lambda: self.moving_variance
            return tf_utils.smart_cond(training, true_branch, false_branch)

        self.add_update(mean_update)
        self.add_update(variance_update)

    quantized_mean = math_ops.cast(quantized_mean, inputs.dtype)
    quantized_variance = math_ops.cast(quantized_variance, inputs.dtype)
    if offset is not None:
        offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
        scale = math_ops.cast(scale, inputs.dtype)
    # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
    # math in float16 hurts validation accuracy of popular models like resnet.
    outputs = nn.batch_normalization(inputs, _broadcast(quantized_mean),
                                     _broadcast(quantized_variance), offset,
                                     scale, self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)
    return outputs
def call(self, inputs, training=None, mask=None, **kwargs):
    training = self._get_training_value(training)

    x = inputs
    if mask is not None:
        x = tf.where(tf.expand_dims(mask, axis=-1), x, tf.zeros_like(x))
    orig_dtype = x.dtype
    x = tf.cast(x, tf.float32)

    inputs_size = array_ops.size(inputs)
    axes = list(range(len(shape_list(x))))[:-1]

    training_value = tf_utils.constant_value(training)
    if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
        mean, variance = self.moving_mean, self.moving_variance
    else:
        mean, variance = masked_moments(x, mask=mask, axes=axes,
                                        keepdims=False)
        mean = tf.squeeze(mean)
        variance = tf.squeeze(variance)

        moving_mean = self.moving_mean
        moving_variance = self.moving_variance

        mean = tf_utils.smart_cond(
            training, lambda: mean,
            lambda: ops.convert_to_tensor(moving_mean))
        variance = tf_utils.smart_cond(
            training, lambda: variance,
            lambda: tf.convert_to_tensor(moving_variance))

        def _do_update(var, value):
            """Compute the updates for mean and variance."""
            return self._assign_moving_average(var, value, self.momentum,
                                               inputs_size)

        def mean_update():
            true_branch = lambda: _do_update(self.moving_mean, mean)
            false_branch = lambda: self.moving_mean
            return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
            """Update the moving variance."""
            true_branch = lambda: _do_update(self.moving_variance, variance)
            false_branch = lambda: self.moving_variance
            return tf_utils.smart_cond(training, true_branch, false_branch)

        self.add_update(mean_update)
        self.add_update(variance_update)

    gamma = self.get_weight('gamma', training=training) if self.scale else None
    beta = self.get_weight('beta', training=training) if self.center else None

    x = tf.nn.batch_normalization(x,
                                  mean=mean,
                                  variance=variance,
                                  scale=gamma,
                                  offset=beta,
                                  variance_epsilon=self.epsilon)
    x = tf.cast(x, orig_dtype)
    return x
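# The masked batch-norm snippet above depends on masked_moments and shape_list
# helpers that are not included here. A plausible sketch of masked_moments
# (statistics computed over unmasked positions only); the broadcasting layout
# and the clamped denominator are assumptions:
def masked_moments(x, mask, axes, keepdims=False):
    # x: [batch, time, features]; mask: boolean [batch, time] or None.
    if mask is None:
        return tf.nn.moments(x, axes=axes, keepdims=keepdims)
    weights = tf.cast(tf.expand_dims(mask, axis=-1), x.dtype)
    # Avoid division by zero when everything is masked out.
    count = tf.maximum(tf.reduce_sum(weights, axis=axes, keepdims=True), 1.0)
    mean = tf.reduce_sum(x * weights, axis=axes, keepdims=True) / count
    variance = tf.reduce_sum(
        tf.math.squared_difference(x, mean) * weights,
        axis=axes, keepdims=True) / count
    if not keepdims:
        mean = tf.squeeze(mean, axis=axes)
        variance = tf.squeeze(variance, axis=axes)
    return mean, variance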
def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
    # code as well.
    # TODO(b/130185866): Support zero batch input in graph mode.
    if (ops.executing_eagerly_outside_functions() and
            distribution_strategy_context.has_strategy()):
        inputs_size = array_ops.size(inputs)
    else:
        inputs_size = None

    def _fused_batch_norm_training():
        return nn.fused_batch_norm(inputs, gamma, beta,
                                   epsilon=self.epsilon,
                                   data_format=self._data_format)

    def _fused_batch_norm_inference():
        return nn.fused_batch_norm(inputs, gamma, beta,
                                   mean=self.moving_mean,
                                   variance=self.moving_variance,
                                   epsilon=self.epsilon,
                                   is_training=False,
                                   data_format=self._data_format)

    output, mean, variance = tf_utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)

    if not self._bessels_correction_test_only:
        # Remove Bessel's correction to be consistent with non-fused batch
        # norm. Note that the variance computed by fused batch norm is with
        # Bessel's correction.
        sample_size = math_ops.cast(
            array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
        factor = (sample_size -
                  math_ops.cast(1.0, variance.dtype)) / sample_size
        variance *= factor

    training_value = tf_utils.constant_value(training)
    if training_value is None:
        momentum = tf_utils.smart_cond(training,
                                       lambda: self.momentum,
                                       lambda: 1.0)
    else:
        momentum = ops.convert_to_tensor(self.momentum)

    if training_value or training_value is None:

        def mean_update():
            return self._assign_moving_average(self.moving_mean, mean,
                                               momentum, inputs_size)

        def variance_update():
            return self._assign_moving_average(self.moving_variance, variance,
                                               momentum, inputs_size)

        self.add_update(mean_update)
        self.add_update(variance_update)

    return output
def call(self, inputs, training=None):
    training = self._get_training_value(training)

    # Compute the axes along which to reduce the mean / variance.
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in [self.axis]]

    # Broadcasting only necessary for single-axis batch norm where the axis
    # is not the last dimension.
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis] = input_shape.dims[self.axis].value

    def _broadcast(v):
        if (v is not None and len(v.get_shape()) != ndims and
                reduction_axes != list(range(ndims - 1))):
            return array_ops.reshape(v, broadcast_shape)
        return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    if self.axis != ndims - 1:
        trans = reduction_axes + [self.axis]
        transpose_recover = [i for i in range(self.axis)] + [ndims - 1] + [
            j for j in range(self.axis, ndims - 1)
        ]
        inputs = array_ops.transpose(inputs, perm=trans)
        transposed_shape = [-1] + inputs.get_shape().as_list()[1:]

    # Flatten the spatial dimensions into the batch dimension: [-1, C].
    inputs = array_ops.reshape(
        inputs, shape=[-1, input_shape.dims[self.axis].value])

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    outputs = []
    height, width = input_shape.as_list()[1], input_shape.as_list()[2]
    for i in range(self.groups):
        start_index = i * self.m_per_group
        end_index = np.min(((i + 1) * self.m_per_group,
                            input_shape.dims[self.axis].value))
        group_input = inputs[:, start_index:end_index]
        if training_value is not False:
            mean = tf.reduce_mean(group_input, 0, keepdims=True)
            centered = group_input - mean
            # Covariance of the centered activations within the group.
            sigma = tf.matmul(tf.linalg.matrix_transpose(centered), centered)
            sigma /= (cfg.training.batch_size * height * width)
            projection = self.get_projection(sigma, group_input)

            moving_mean = self.moving_means[i]
            moving_projection = self.moving_projections[i]

            mean = tf_utils.smart_cond(
                training, lambda: mean,
                lambda: ops.convert_to_tensor(moving_mean))
            projection = tf_utils.smart_cond(
                training, lambda: projection,
                lambda: ops.convert_to_tensor(moving_projection))

            new_mean, new_projection = mean, projection

            def _do_update(var, value):
                return self._assign_moving_average(var, value, self.momentum,
                                                   None)

            def mean_update():
                true_branch = lambda: _do_update(self.moving_means[i],
                                                 new_mean)
                false_branch = lambda: self.moving_means[i]
                return tf_utils.smart_cond(training, true_branch,
                                           false_branch)

            def projection_update():
                true_branch = lambda: _do_update(self.moving_projections[i],
                                                 new_projection)
                false_branch = lambda: self.moving_projections[i]
                return tf_utils.smart_cond(training, true_branch,
                                           false_branch)

            self.add_update(mean_update)
            self.add_update(projection_update)
        else:
            mean, projection = (self.moving_means[i],
                                self.moving_projections[i])
            centered = group_input - mean

        mean = math_ops.cast(mean, inputs.dtype)
        projection = math_ops.cast(projection, inputs.dtype)
        output = tf.matmul(centered, projection)
        outputs.append(output)

    outputs = tf.concat(outputs, 1)
    if self.axis != ndims - 1:
        outputs = tf.reshape(outputs, shape=transposed_shape)
        outputs = tf.transpose(outputs, perm=transpose_recover)
    else:
        outputs = tf.reshape(outputs, shape=[-1] + input_shape.as_list()[1:])

    if scale is not None:
        scale = math_ops.cast(scale, inputs.dtype)
        outputs = outputs * scale
    if offset is not None:
        offset = math_ops.cast(offset, inputs.dtype)
        outputs += offset

    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)
    return outputs
def call(self, inputs, training=None):
    # Extract weights of the previous layer to compute the proper scale.
    previous_weights = self.previous_layer.weights[0].value()

    original_training_value = training
    if training is None:
        training = K.learning_phase()
    in_eager_mode = tf.executing_eagerly()

    # Compute the axes along which to reduce the mean / variance.
    input_shape = inputs.get_shape()
    ndims = len(input_shape)
    # For dense layers, require a full reduction.
    if self.binary_dense:
        reduction_axes = [i for i in range(ndims)]
    # Otherwise, reduce all but the feature axis.
    else:
        reduction_axes = [i for i in range(ndims) if i not in self.axis]

    # Broadcasting only necessary for single-axis batch norm where the axis
    # is not the last dimension.
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape[self.axis[0]]

    def _broadcast(v):
        if (v is not None and len(v.get_shape()) != ndims and
                reduction_axes != list(range(ndims - 1))):
            return tf.reshape(v, broadcast_shape)
        return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
        if then_scale is not None:
            scale *= then_scale
            offset *= then_scale
        if then_offset is not None:
            offset += then_offset
        return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value is not False:
        # Some of the computations here are not necessary when training==False
        # but not a constant. However, this makes the code simpler.
        keep_dims = len(self.axis) > 1
        mean, variance = tf.compat.v1.nn.moments(inputs, reduction_axes,
                                                 keep_dims=keep_dims)
        # When norming the output of a binary dense layer, make sure the
        # shape is maintained.
        if self.binary_dense:
            mean = tf.reshape(mean, [1])
            variance = tf.reshape(variance, [1])

        moving_mean = self.moving_mean
        moving_variance = self.moving_variance

        mean = tf_utils.smart_cond(training, lambda: mean,
                                   lambda: moving_mean)
        variance = tf_utils.smart_cond(training, lambda: variance,
                                       lambda: moving_variance)

        new_mean, new_variance = mean, variance

        if self.renorm:
            r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                new_mean, new_variance, training)
            # When training, the normalized values (say, x) will be
            # transformed as x * gamma + beta without renorm, and
            # (x * r + d) * gamma + beta = x * (r * gamma) + (d * gamma + beta)
            # with renorm.
            r = _broadcast(tf.stop_gradient(r, name='renorm_r'))
            d = _broadcast(tf.stop_gradient(d, name='renorm_d'))
            scale, offset = _compose_transforms(r, d, scale, offset)

        def _do_update(var, value):
            if in_eager_mode and not self.trainable:
                return
            return self._assign_moving_average(var, value, self.momentum)

        mean_update = tf_utils.smart_cond(
            training, lambda: _do_update(self.moving_mean, new_mean),
            lambda: self.moving_mean)
        variance_update = tf_utils.smart_cond(
            training, lambda: _do_update(self.moving_variance, new_variance),
            lambda: self.moving_variance)
        if not tf.executing_eagerly():
            self.add_update(mean_update, inputs=True)
            self.add_update(variance_update, inputs=True)
    else:
        mean, variance = self.moving_mean, self.moving_variance

    mean = tf.cast(mean, inputs.dtype)
    variance = tf.cast(variance, inputs.dtype)
    if offset is not None:
        offset = tf.cast(offset, inputs.dtype)

    # Shift normalization: use a quantized approximation of the standard
    # nn.batch_normalization transform.
    approximate_std, quantized_means = compute_quantized_shiftnorm(
        variance, mean, self.epsilon, previous_weights, self.extra_scale,
        self.bits, rescale=True)

    outputs = inputs - quantized_means
    outputs = outputs * approximate_std

    if scale is not None:
        outputs = scale * outputs
    if offset is not None:
        outputs = outputs + offset

    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if not tf.executing_eagerly() and original_training_value is None:
        outputs._uses_learning_phase = True  # pylint: disable=protected-access
    return outputs
def call(self, inputs, training=None):
    if training is None:
        training = K.learning_phase()

    # Compute the axes along which to reduce the mean / variance.
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]

    scale, offset = self.gamma, self.beta

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value is not False:
        # Some of the computations here are not necessary when training==False
        # but not a constant. However, this makes the code simpler.
        keep_dims = len(self.axis) > 1
        mean, variance = self._moments(
            math_ops.cast(inputs, inputs.dtype),
            reduction_axes,
            keep_dims=keep_dims)

        moving_mean = self.moving_mean
        moving_variance = self.moving_variance

        mean = tf_utils.smart_cond(
            training, lambda: mean,
            lambda: ops.convert_to_tensor(moving_mean))
        variance = tf_utils.smart_cond(
            training, lambda: variance,
            lambda: ops.convert_to_tensor(moving_variance))

        new_mean, new_variance = mean, variance

        if (ops.executing_eagerly_outside_functions() and
                distribution_strategy_context.has_strategy()):
            inputs_size = array_ops.size(inputs)
        else:
            inputs_size = None

        if distribution_strategy_context.in_cross_replica_context():
            strategy = distribution_strategy_context.get_strategy()

            def _do_update(var, value):
                """Compute the updates for mean and variance."""
                return strategy.extended.update(
                    var, self._assign_moving_average,
                    (value, self.momentum, inputs_size),
                    group=False)

            # We need to unwrap the moving_mean or moving_variance in the
            # case of training being false to match the output of true_fn and
            # false_fn in the smart cond.
            def mean_update():
                true_branch = lambda: _do_update(self.moving_mean, new_mean)
                false_branch = lambda: strategy.unwrap(self.moving_mean)
                return tf_utils.smart_cond(training, true_branch,
                                           false_branch)

            def variance_update():
                return tf_utils.smart_cond(
                    training,
                    lambda: _do_update(self.moving_variance, new_variance),
                    lambda: strategy.unwrap(self.moving_variance))
        else:

            def _do_update(var, value):
                """Compute the updates for mean and variance."""
                return self._assign_moving_average(var, value, self.momentum,
                                                   inputs_size)

            def mean_update():
                true_branch = lambda: _do_update(self.moving_mean, new_mean)
                false_branch = lambda: self.moving_mean
                return tf_utils.smart_cond(training, true_branch,
                                           false_branch)

            def variance_update():
                true_branch = lambda: _do_update(self.moving_variance,
                                                 new_variance)
                false_branch = lambda: self.moving_variance
                return tf_utils.smart_cond(training, true_branch,
                                           false_branch)

        self.add_update(mean_update, inputs=True)
        self.add_update(variance_update, inputs=True)
    else:
        mean, variance = self.moving_mean, self.moving_variance

    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
        offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
        scale = math_ops.cast(scale, inputs.dtype)
    outputs = nn.batch_normalization(inputs, mean, variance, offset, scale,
                                     self.epsilon)
    return outputs
def call(self, inputs, training=None):
    training = self._get_training_value(training)

    if self.virtual_batch_size is not None:
        # Virtual batches (aka ghost batches) can be simulated by reshaping
        # the Tensor and reusing the existing batch norm implementation.
        original_shape = [-1] + inputs.shape.as_list()[1:]
        expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

        # Will cause errors if virtual_batch_size does not divide the batch
        # size.
        inputs = array_ops.reshape(inputs, expanded_shape)

        def undo_virtual_batching(outputs):
            outputs = array_ops.reshape(outputs, original_shape)
            return outputs

    if self.fused:
        outputs = self._fused_batch_norm(inputs, training=training)
        if self.virtual_batch_size is not None:
            # Currently never reaches here since fused_batch_norm does not
            # support virtual batching.
            outputs = undo_virtual_batching(outputs)
        return outputs

    # Compute the axes along which to reduce the mean / variance.
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
        del reduction_axes[1]  # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis
    # is not the last dimension.
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

    def _broadcast(v):
        if (v is not None and len(v.shape) != ndims and
                reduction_axes != list(range(ndims - 1))):
            return array_ops.reshape(v, broadcast_shape)
        return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
        if then_scale is not None:
            scale *= then_scale
            offset *= then_scale
        if then_offset is not None:
            offset += then_offset
        return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
        mean, variance = self.moving_mean, self.moving_variance
    else:
        if self.adjustment:
            adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
            # Adjust only during training.
            adj_scale = tf_utils.smart_cond(
                training, lambda: adj_scale,
                lambda: array_ops.ones_like(adj_scale))
            adj_bias = tf_utils.smart_cond(
                training, lambda: adj_bias,
                lambda: array_ops.zeros_like(adj_bias))
            scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                offset)

        # Some of the computations here are not necessary when training==False
        # but not a constant. However, this makes the code simpler.
        keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
        mean, variance = self._moments(
            math_ops.cast(inputs, self._param_dtype),
            reduction_axes,
            keep_dims=keep_dims)

        moving_mean = self.moving_mean
        moving_variance = self.moving_variance

        mean = tf_utils.smart_cond(
            training, lambda: mean,
            lambda: ops.convert_to_tensor(moving_mean))
        variance = tf_utils.smart_cond(
            training, lambda: variance,
            lambda: ops.convert_to_tensor(moving_variance))

        if self.virtual_batch_size is not None:
            # This isn't strictly correct since in ghost batch norm, you are
            # supposed to sequentially update the moving_mean and
            # moving_variance with each sub-batch. However, since the moving
            # statistics are only used during evaluation, it is more efficient
            # to just update in one step and should not make a significant
            # difference in the result.
            new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
            new_variance = math_ops.reduce_mean(variance, axis=1,
                                                keepdims=True)
        else:
            new_mean, new_variance = mean, variance

        if self._support_zero_size_input():
            inputs_size = array_ops.size(inputs)
        else:
            inputs_size = None

        if self.renorm:
            r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                new_mean, new_variance, training, inputs_size)
            # When training, the normalized values (say, x) will be
            # transformed as x * gamma + beta without renorm, and
            # (x * r + d) * gamma + beta = x * (r * gamma) + (d * gamma + beta)
            # with renorm.
            r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
            d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
            scale, offset = _compose_transforms(r, d, scale, offset)

        def _do_update(var, value):
            """Compute the updates for mean and variance."""
            return self._assign_moving_average(var, value, self.momentum,
                                               inputs_size)

        def mean_update():
            true_branch = lambda: _do_update(self.moving_mean, new_mean)
            false_branch = lambda: self.moving_mean
            return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
            """Update the moving variance."""

            def true_branch_renorm():
                # We apply epsilon as part of the moving_stddev to mirror the
                # training code path.
                moving_stddev = _do_update(
                    self.moving_stddev,
                    math_ops.sqrt(new_variance + self.epsilon))
                return self._assign_new_value(
                    self.moving_variance,
                    # Apply relu in case floating point rounding causes it to
                    # go negative.
                    K.relu(moving_stddev * moving_stddev - self.epsilon))

            if self.renorm:
                true_branch = true_branch_renorm
            else:
                true_branch = lambda: _do_update(self.moving_variance,
                                                 new_variance)
            false_branch = lambda: self.moving_variance
            return tf_utils.smart_cond(training, true_branch, false_branch)

        self.add_update(mean_update)
        self.add_update(variance_update)

    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
        offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
        scale = math_ops.cast(scale, inputs.dtype)
    # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
    # math in float16 hurts validation accuracy of popular models like resnet.
    outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                     _broadcast(variance), offset, scale,
                                     self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
        outputs = undo_virtual_batching(outputs)
    return outputs
def call(self, inputs, training=None):
  if self.lr_mul == 1.0:
    kernel = self.coeff * self.kernel
  else:
    @custom_gradient
    def lr_multiplier(x):
      y = array_ops.identity(x)
      def grad(dy):
        return dy * self.lr_mul
      return y, grad
    kernel = lr_multiplier(self.coeff * self.kernel)

  training = self._get_training_value(training)

  # Update the singular vector by power iteration.
  if self.data_format == 'channels_first':
    W_T = array_ops.reshape(kernel, (self.filters, -1))
    W = array_ops.transpose(W_T)
  else:
    W = array_ops.reshape(kernel, (-1, self.filters))
    W_T = array_ops.transpose(W)
  u = array_ops.identity(self.u)
  for _ in range(self.power_iter):
    v = nn_impl.l2_normalize(math_ops.matmul(u, W))  # 1 x filters
    u = nn_impl.l2_normalize(math_ops.matmul(v, W_T))

  # Spectral Normalization: sigma approximates the largest singular value.
  sigma_W = math_ops.matmul(math_ops.matmul(u, W), array_ops.transpose(v))
  # No gradient should flow through the power iteration.
  sigma_W = array_ops.stop_gradient(sigma_W)
  W_bar = kernel / array_ops.squeeze(sigma_W)

  # Assign the new singular vector, but only while training.
  training_value = tf_utils.constant_value(training)
  if training_value is not False:

    def u_update():
      def true_branch():
        return self._assign_singular_vector(self.u, u)
      def false_branch():
        return self.u
      return tf_utils.smart_cond(training, true_branch, false_branch)

    self.add_update(u_update)

  # Standard convolution using the spectrally normalized kernel W_bar.
  outputs = self._convolution_op(inputs, W_bar)

  if self.use_bias:
    if self.data_format == 'channels_first':
      if self.rank == 1:
        # nn.bias_add does not accept a 1D input tensor.
        bias = array_ops.reshape(self.bias, (1, self.filters, 1))
        outputs += bias
      else:
        outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
    else:
      outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')

  if self.activation is not None:
    return self.activation(outputs)
  return outputs
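# For intuition, a small self-contained NumPy sketch of the power iteration
# used above (shapes and iteration count are illustrative assumptions):
# repeatedly mapping u through W and W^T drives u, v toward the top singular
# vectors, so u @ W @ v^T approximates the spectral norm sigma(W) that the
# kernel is divided by.
import numpy as np

def spectral_norm_sketch(W, num_iters=5):
  u = np.random.normal(size=(1, W.shape[0]))       # 1 x rows
  for _ in range(num_iters):
    v = u @ W
    v /= np.linalg.norm(v)                         # 1 x cols
    u = v @ W.T
    u /= np.linalg.norm(u)                         # 1 x rows
  return (u @ W @ v.T).item()                      # ~ largest singular value

# Sanity check against an SVD:
# W = np.random.randn(64, 32)
# print(spectral_norm_sketch(W, 20), np.linalg.svd(W, compute_uv=False)[0])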
def call(self, inputs, training=None):
  if training is None:
    training = K.learning_phase()

  if self.virtual_batch_size is not None:
    # Virtual batches (aka ghost batches) can be simulated by reshaping the
    # Tensor and reusing the existing batch norm implementation.
    original_shape = [-1] + inputs.shape.as_list()[1:]
    expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

    # Will cause errors if virtual_batch_size does not divide the batch size.
    inputs = array_ops.reshape(inputs, expanded_shape)

    def undo_virtual_batching(outputs):
      outputs = array_ops.reshape(outputs, original_shape)
      return outputs

  if self.fused:
    outputs = self._fused_batch_norm(inputs, training=training)
    if self.virtual_batch_size is not None:
      # Currently never reaches here since fused_batch_norm does not support
      # virtual batching.
      outputs = undo_virtual_batching(outputs)
    return outputs

  # Compute the axes along which to reduce the mean / variance.
  input_shape = inputs.shape
  ndims = len(input_shape)
  reduction_axes = [i for i in range(ndims) if i not in self.axis]
  if self.virtual_batch_size is not None:
    del reduction_axes[1]  # Do not reduce along virtual batch dim.

  # Broadcasting is only necessary for single-axis batch norm where the axis
  # is not the last dimension.
  broadcast_shape = [1] * ndims
  broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

  def _broadcast(v):
    if (v is not None and len(v.shape) != ndims and
        reduction_axes != list(range(ndims - 1))):
      return array_ops.reshape(v, broadcast_shape)
    return v

  scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

  def _compose_transforms(scale, offset, then_scale, then_offset):
    if then_scale is not None:
      scale *= then_scale
      offset *= then_scale
    if then_offset is not None:
      offset += then_offset
    return (scale, offset)

  # Determine a boolean value for `training`: could be True, False, or None.
  training_value = tf_utils.constant_value(training)
  if training_value is not False:
    if self.adjustment:
      adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
      # Adjust only during training.
      adj_scale = tf_utils.smart_cond(
          training, lambda: adj_scale, lambda: array_ops.ones_like(adj_scale))
      adj_bias = tf_utils.smart_cond(
          training, lambda: adj_bias, lambda: array_ops.zeros_like(adj_bias))
      scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

    # Some of the computations here are not necessary when training==False
    # but not a constant. However, this makes the code simpler.
    keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
    mean, variance = self._moments(
        math_ops.cast(inputs, self._param_dtype),
        reduction_axes,
        keep_dims=keep_dims)

    moving_mean = self.moving_mean
    moving_variance = self.moving_variance

    mean = tf_utils.smart_cond(training, lambda: mean, lambda: moving_mean)
    variance = tf_utils.smart_cond(training, lambda: variance,
                                   lambda: moving_variance)

    if self.virtual_batch_size is not None:
      # This isn't strictly correct since in ghost batch norm, you are
      # supposed to sequentially update the moving_mean and moving_variance
      # with each sub-batch. However, since the moving statistics are only
      # used during evaluation, it is more efficient to just update in one
      # step and should not make a significant difference in the result.
      new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
      new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
    else:
      new_mean, new_variance = mean, variance

    if self.renorm:
      r, d, new_mean, new_variance = self._renorm_correction_and_moments(
          new_mean, new_variance, training)
      # When training, the normalized values (say, x) will be transformed as
      # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
      # = x * (r * gamma) + (d * gamma + beta) with renorm.
      r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
      d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
      scale, offset = _compose_transforms(r, d, scale, offset)

    if distribution_strategy_context.in_cross_replica_context():
      strategy = distribution_strategy_context.get_strategy()

      def _do_update(var, value):
        """Compute the updates for mean and variance."""
        return strategy.extended.update(
            var, self._assign_moving_average, (value, self.momentum),
            group=False)

      # We need to unwrap the moving_mean or moving_variance in the case of
      # training being false to match the output of true_fn and false_fn
      # in the smart cond.
      def mean_update():
        true_branch = lambda: _do_update(self.moving_mean, new_mean)
        false_branch = lambda: strategy.unwrap(self.moving_mean)
        return tf_utils.smart_cond(training, true_branch, false_branch)

      def variance_update():
        return tf_utils.smart_cond(
            training,
            lambda: _do_update(self.moving_variance, new_variance),
            lambda: strategy.unwrap(self.moving_variance))
    else:

      def _do_update(var, value):
        """Compute the updates for mean and variance."""
        return self._assign_moving_average(var, value, self.momentum)

      def mean_update():
        true_branch = lambda: _do_update(self.moving_mean, new_mean)
        false_branch = lambda: self.moving_mean
        return tf_utils.smart_cond(training, true_branch, false_branch)

      def variance_update():
        true_branch = lambda: _do_update(self.moving_variance, new_variance)
        false_branch = lambda: self.moving_variance
        return tf_utils.smart_cond(training, true_branch, false_branch)

    self.add_update(mean_update, inputs=True)
    self.add_update(variance_update, inputs=True)
  else:
    mean, variance = self.moving_mean, self.moving_variance

  mean = math_ops.cast(mean, inputs.dtype)
  variance = math_ops.cast(variance, inputs.dtype)
  if offset is not None:
    offset = math_ops.cast(offset, inputs.dtype)
  if scale is not None:
    scale = math_ops.cast(scale, inputs.dtype)
  # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
  # math in float16 hurts validation accuracy of popular models like resnet.
  outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                   _broadcast(variance), offset, scale,
                                   self.epsilon)
  # If some components of the shape got lost due to adjustments, fix that.
  outputs.set_shape(input_shape)

  if self.virtual_batch_size is not None:
    outputs = undo_virtual_batching(outputs)
  return outputs
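# How the virtual-batching reshape above carves a batch into ghost batches;
# a minimal sketch with illustrative (assumed) shapes:
import tensorflow as tf

x = tf.zeros([8, 32, 32, 16])                      # batch of 8, NHWC
virtual_batch_size = 2
original_shape = [-1] + x.shape.as_list()[1:]      # [-1, 32, 32, 16]
expanded_shape = [virtual_batch_size, -1] + original_shape[1:]
ghosts = tf.reshape(x, expanded_shape)             # [2, 4, 32, 32, 16]
# Moments are then reduced over every axis except the ghost-batch index
# (axis 1) and the channel axis, so each of the 4 ghost batches of size 2
# is normalized with its own statistics, before reshaping back:
restored = tf.reshape(ghosts, original_shape)      # [8, 32, 32, 16]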
def _fused_batch_norm(self, inputs, training):
  """Returns the output of fused batch norm."""
  beta = self.beta if self.center else self._beta_const
  gamma = self.gamma if self.scale else self._gamma_const

  def _cross_replica_non_fused_batch_norm_training():
    # TODO(panos): assert the data format to be NHWC
    # TODO(panos): make a moments function with distributed synchronization

    def _merge_fn(strategy, per_replica_mean, per_replica_square_mean):
      # per_replica_mean: PerDevice
      # global_mean: Mirrored
      global_mean = strategy.reduce(
          tf.VariableAggregation.SUM, per_replica_mean, per_replica_mean)
      global_squared_mean = strategy.reduce(
          tf.VariableAggregation.SUM, per_replica_square_mean,
          per_replica_square_mean)
      return global_mean, global_squared_mean

    # Dispatch as much of the computation as possible to each replica.
    per_replica_mean = tf.reduce_mean(inputs, axis=(0, 1, 2))
    per_replica_square_mean = tf.reduce_mean(tf.square(inputs), axis=(0, 1, 2))
    replica_context = tf.contrib.distribute.get_tower_context()
    global_mean, global_squared_mean = replica_context.merge_call(
        _merge_fn,
        per_replica_mean / replica_context.num_towers,
        per_replica_square_mean / replica_context.num_towers)
    # Var[x] = E[x^2] - (E[x])^2
    global_variance = global_squared_mean - tf.square(global_mean)
    inputs_normalized = tf.nn.batch_normalization(
        inputs, global_mean, global_variance, beta, gamma, self.epsilon)
    return inputs_normalized, global_mean, global_variance

  def _fused_batch_norm_training():
    # Single-replica fused path; currently unused because the cross-replica
    # variant above is dispatched for training instead.
    return nn.fused_batch_norm(
        inputs, gamma, beta, epsilon=self.epsilon,
        data_format=self._data_format)

  def _fused_batch_norm_inference():
    return nn.fused_batch_norm(
        inputs,
        gamma,
        beta,
        mean=self.moving_mean,
        variance=self.moving_variance,
        epsilon=self.epsilon,
        is_training=False,
        data_format=self._data_format)

  output, mean, variance = tf_utils.smart_cond(
      training, _cross_replica_non_fused_batch_norm_training,
      _fused_batch_norm_inference)

  if not self._bessels_correction_test_only:
    # Remove Bessel's correction to be consistent with non-fused batch norm.
    # Note that the variance computed by fused batch norm is with Bessel's
    # correction.
    sample_size = math_ops.cast(
        array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
    factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
    variance *= factor

  training_value = tf_utils.constant_value(training)
  if training_value is None:
    momentum = tf_utils.smart_cond(training, lambda: self.momentum,
                                   lambda: 1.0)
  else:
    momentum = ops.convert_to_tensor(self.momentum)
  if training_value or training_value is None:
    mean_update = self._assign_moving_average(self.moving_mean, mean,
                                              momentum)
    variance_update = self._assign_moving_average(self.moving_variance,
                                                  variance, momentum)
    self.add_update(mean_update, inputs=True)
    self.add_update(variance_update, inputs=True)

  return output
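# Why the merge above only needs per-replica first and second moments: for
# equally sized replica shards, the global variance follows from
# Var[x] = E[x^2] - (E[x])^2. A tiny NumPy check (shard count and shapes
# are illustrative assumptions):
import numpy as np

x = np.random.randn(8, 4)
shards = np.split(x, 2, axis=0)                    # two equal "replicas"
mean = np.mean([s.mean(axis=0) for s in shards], axis=0)
sq_mean = np.mean([(s ** 2).mean(axis=0) for s in shards], axis=0)
assert np.allclose(mean, x.mean(axis=0))
assert np.allclose(sq_mean - mean ** 2, x.var(axis=0))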
def call(self, inputs, training=None):
  if training is None:
    training = K.learning_phase()
  in_eager_mode = context.executing_eagerly()

  if self.virtual_batch_size is not None:
    # Virtual batches (aka ghost batches) can be simulated by reshaping the
    # Tensor and reusing the existing batch norm implementation.
    original_shape = [-1] + inputs.shape.as_list()[1:]
    expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

    # Will cause errors if virtual_batch_size does not divide the batch size.
    inputs = array_ops.reshape(inputs, expanded_shape)

    def undo_virtual_batching(outputs):
      outputs = array_ops.reshape(outputs, original_shape)
      return outputs

  if self.fused:
    outputs = self._fused_batch_norm(inputs, training=training)
    if self.virtual_batch_size is not None:
      # Currently never reaches here since fused_batch_norm does not support
      # virtual batching.
      outputs = undo_virtual_batching(outputs)
    return outputs

  # Compute the axes along which to reduce the mean / variance.
  input_shape = inputs.get_shape()
  ndims = len(input_shape)
  reduction_axes = [i for i in range(ndims) if i not in self.axis]
  if self.virtual_batch_size is not None:
    del reduction_axes[1]  # Do not reduce along virtual batch dim.

  # Broadcasting is only necessary for single-axis batch norm where the axis
  # is not the last dimension.
  broadcast_shape = [1] * ndims
  broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

  def _broadcast(v):
    if (v is not None and len(v.get_shape()) != ndims and
        reduction_axes != list(range(ndims - 1))):
      return array_ops.reshape(v, broadcast_shape)
    return v

  scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

  def _compose_transforms(scale, offset, then_scale, then_offset):
    if then_scale is not None:
      scale *= then_scale
      offset *= then_scale
    if then_offset is not None:
      offset += then_offset
    return (scale, offset)

  # Determine a boolean value for `training`: could be True, False, or None.
  training_value = tf_utils.constant_value(training)
  if training_value is not False:
    if self.adjustment:
      adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
      # Adjust only during training.
      adj_scale = tf_utils.smart_cond(
          training, lambda: adj_scale, lambda: array_ops.ones_like(adj_scale))
      adj_bias = tf_utils.smart_cond(
          training, lambda: adj_bias, lambda: array_ops.zeros_like(adj_bias))
      scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

    # Some of the computations here are not necessary when training==False
    # but not a constant. However, this makes the code simpler.
    keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
    mean, variance = self._moments(inputs, reduction_axes,
                                   keep_dims=keep_dims)

    moving_mean = self.moving_mean
    moving_variance = self.moving_variance

    mean = tf_utils.smart_cond(training, lambda: mean, lambda: moving_mean)
    variance = tf_utils.smart_cond(training, lambda: variance,
                                   lambda: moving_variance)

    if self.virtual_batch_size is not None:
      # This isn't strictly correct since in ghost batch norm, you are
      # supposed to sequentially update the moving_mean and moving_variance
      # with each sub-batch. However, since the moving statistics are only
      # used during evaluation, it is more efficient to just update in one
      # step and should not make a significant difference in the result.
      new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
      new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
    else:
      new_mean, new_variance = mean, variance

    if self.renorm:
      r, d, new_mean, new_variance = self._renorm_correction_and_moments(
          new_mean, new_variance, training)
      # When training, the normalized values (say, x) will be transformed as
      # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
      # = x * (r * gamma) + (d * gamma + beta) with renorm.
      r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
      d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
      scale, offset = _compose_transforms(r, d, scale, offset)

    if distribution_strategy_context.in_cross_replica_context():
      strategy = distribution_strategy_context.get_strategy()

      def _do_update(var, value):
        """Compute the updates for mean and variance."""
        if in_eager_mode and not self.trainable:
          return
        return strategy.extended.update(
            var, self._assign_moving_average, (value, self.momentum),
            group=False)

      # We need to unwrap the moving_mean or moving_variance in the case of
      # training being false to match the output of true_fn and false_fn
      # in the smart cond.
      mean_update = tf_utils.smart_cond(
          training, lambda: _do_update(self.moving_mean, new_mean),
          lambda: strategy.unwrap(self.moving_mean))
      variance_update = tf_utils.smart_cond(
          training, lambda: _do_update(self.moving_variance, new_variance),
          lambda: strategy.unwrap(self.moving_variance))
    else:

      def _do_update(var, value):
        """Compute the updates for mean and variance."""
        if in_eager_mode and not self.trainable:
          return
        return self._assign_moving_average(var, value, self.momentum)

      mean_update = tf_utils.smart_cond(
          training, lambda: _do_update(self.moving_mean, new_mean),
          lambda: self.moving_mean)
      variance_update = tf_utils.smart_cond(
          training, lambda: _do_update(self.moving_variance, new_variance),
          lambda: self.moving_variance)

    if not context.executing_eagerly():
      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)
  else:
    mean, variance = self.moving_mean, self.moving_variance

  mean = math_ops.cast(mean, inputs.dtype)
  variance = math_ops.cast(variance, inputs.dtype)
  if offset is not None:
    offset = math_ops.cast(offset, inputs.dtype)
  outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                   _broadcast(variance), offset, scale,
                                   self.epsilon)
  # If some components of the shape got lost due to adjustments, fix that.
  outputs.set_shape(input_shape)

  if self.virtual_batch_size is not None:
    outputs = undo_virtual_batching(outputs)
  return outputs
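# All of the training/inference branching in these call methods rests on
# tf_utils.smart_cond: when `training` is a Python constant the branch is
# chosen at trace time, otherwise a runtime cond is emitted. A minimal
# sketch of that behavior (simplified; the real helper also folds constant
# tensors via constant_value):
import tensorflow as tf

def smart_cond_sketch(pred, true_fn, false_fn):
  if isinstance(pred, bool):                 # statically known: no cond op
    return true_fn() if pred else false_fn()
  return tf.cond(pred, true_fn, false_fn)    # decided at run time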
def call(self, inputs, training=None):
  original_training_value = training
  if training is None:
    training = K.learning_phase()
  in_eager_mode = context.executing_eagerly()

  if self.virtual_batch_size is not None:
    # Virtual batches (aka ghost batches) can be simulated by reshaping the
    # Tensor and reusing the existing batch norm implementation.
    original_shape = [-1] + inputs.shape.as_list()[1:]
    expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

    # Will cause errors if virtual_batch_size does not divide the batch size.
    inputs = array_ops.reshape(inputs, expanded_shape)

    def undo_virtual_batching(outputs):
      outputs = array_ops.reshape(outputs, original_shape)
      return outputs

  if self.fused:
    outputs = self._fused_switch_norm(inputs, training=training)
    if self.virtual_batch_size is not None:
      # Currently never reaches here since fused_batch_norm does not support
      # virtual batching.
      outputs = undo_virtual_batching(outputs)
    if not context.executing_eagerly() and original_training_value is None:
      outputs._uses_learning_phase = True  # pylint: disable=protected-access
    return outputs

  # Compute the axes along which to reduce the mean / variance.
  input_shape = inputs.get_shape()
  ndims = len(input_shape)
  reduction_axes = [i for i in range(ndims) if i not in self.axis]
  if self.virtual_batch_size is not None:
    del reduction_axes[1]  # Do not reduce along virtual batch dim.

  # Broadcasting is only necessary for single-axis batch norm where the axis
  # is not the last dimension.
  broadcast_shape = [1] * ndims
  broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value

  def _broadcast(v):
    if (v is not None and len(v.get_shape()) != ndims and
        reduction_axes != list(range(ndims - 1))):
      return array_ops.reshape(v, broadcast_shape)
    return v

  scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

  def _compose_transforms(scale, offset, then_scale, then_offset):
    if then_scale is not None:
      scale *= then_scale
      offset *= then_scale
    if then_offset is not None:
      offset += then_offset
    return (scale, offset)

  # Determine a boolean value for `training`: could be True, False, or None.
  training_value = tf_utils.constant_value(training)
  if training_value is not False:
    if self.adjustment:
      adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
      # Adjust only during training.
      adj_scale = tf_utils.smart_cond(
          training, lambda: adj_scale, lambda: array_ops.ones_like(adj_scale))
      adj_bias = tf_utils.smart_cond(
          training, lambda: adj_bias, lambda: array_ops.zeros_like(adj_bias))
      scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

    # Some of the computations here are not necessary when training==False
    # but not a constant. However, this makes the code simpler.
    keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
    # Alternative ways of obtaining the batch statistics, kept for reference:
    # _, mean_bn, variance_bn = nn.fused_batch_norm(
    #     inputs, scale=tf.ones(shape=[inputs.shape[self.axis[0]]]),
    #     offset=tf.zeros(shape=[inputs.shape[self.axis[0]]]),
    #     epsilon=self.epsilon, data_format=self._data_format)
    # mean_bn, variance_bn = nn.moments(inputs, reduction_axes,
    #                                   keep_dims=keep_dims)
    mean, variance, mean_bn, variance_bn = compute_stats(
        inputs, self.mean_weights, self.var_weights, self.hparams, training,
        axis=self.axis)

    moving_mean = self.moving_mean
    moving_variance = self.moving_variance

    mean_bn = tf_utils.smart_cond(training, lambda: mean_bn,
                                  lambda: moving_mean)
    variance_bn = tf_utils.smart_cond(training, lambda: variance_bn,
                                      lambda: moving_variance)

    if self.virtual_batch_size is not None:
      # This isn't strictly correct since in ghost batch norm, you are
      # supposed to sequentially update the moving_mean and moving_variance
      # with each sub-batch. However, since the moving statistics are only
      # used during evaluation, it is more efficient to just update in one
      # step and should not make a significant difference in the result.
      new_mean = math_ops.reduce_mean(mean_bn, axis=1, keepdims=True)
      new_variance = math_ops.reduce_mean(variance_bn, axis=1, keepdims=True)
    else:
      new_mean, new_variance = mean_bn, variance_bn

    if self.renorm:
      r, d, new_mean, new_variance = self._renorm_correction_and_moments(
          new_mean, new_variance, training)
      # When training, the normalized values (say, x) will be transformed as
      # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
      # = x * (r * gamma) + (d * gamma + beta) with renorm.
      r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
      d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
      scale, offset = _compose_transforms(r, d, scale, offset)

    def _do_update(var, value):
      if in_eager_mode and not self.trainable:
        return
      return self._assign_moving_average(var, value, self.momentum)

    mean_update = tf_utils.smart_cond(
        training, lambda: _do_update(self.moving_mean, new_mean),
        lambda: self.moving_mean)
    variance_update = tf_utils.smart_cond(
        training, lambda: _do_update(self.moving_variance, new_variance),
        lambda: self.moving_variance)
    if not context.executing_eagerly():
      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

    print('input shape', inputs.shape)
    print('mean shape', mean.shape)
    print('variance shape', variance.shape)
  else:
    mean_bn, variance_bn = self.moving_mean, self.moving_variance
    mean, variance, _, _ = compute_stats(
        inputs, self.mean_weights, self.var_weights, self.hparams, training,
        mean_bn=mean_bn, variance_bn=variance_bn, axis=self.axis)

  # Summaries tracking the first channel of the switch-norm statistics.
  tf.summary.scalar('running_mean_0_', tf.squeeze(self.moving_mean[0]))
  tf.summary.scalar('running_var_0_', tf.squeeze(self.moving_variance[0]))
  tf.summary.scalar('batch_mean_0_', tf.squeeze(mean_bn[0]))

  outputs = nn.batch_normalization(inputs, mean, variance, offset, scale,
                                   self.epsilon)

  if self.virtual_batch_size is not None:
    outputs = undo_virtual_batching(outputs)
  if not context.executing_eagerly() and original_training_value is None:
    outputs._uses_learning_phase = True  # pylint: disable=protected-access
  return outputs
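# compute_stats is defined elsewhere in this codebase and is not shown in
# this excerpt. Under a switchable-normalization reading of this layer, it
# presumably blends instance-, layer-, and batch-wise moments with learned
# softmax weights; a hypothetical sketch of that blending (names, shapes,
# and the blending itself are assumptions, not the author's API):
import tensorflow as tf

def blended_stats_sketch(x, mean_weights, var_weights):
  # x: NHWC. Moments at three granularities:
  mean_in, var_in = tf.nn.moments(x, axes=[1, 2], keepdims=True)     # instance
  mean_ln, var_ln = tf.nn.moments(x, axes=[1, 2, 3], keepdims=True)  # layer
  mean_bn, var_bn = tf.nn.moments(x, axes=[0, 1, 2], keepdims=True)  # batch
  w_mean = tf.nn.softmax(mean_weights)  # 3 learned mixing weights each
  w_var = tf.nn.softmax(var_weights)
  mean = w_mean[0] * mean_in + w_mean[1] * mean_ln + w_mean[2] * mean_bn
  var = w_var[0] * var_in + w_var[1] * var_ln + w_var[2] * var_bn
  return mean, var, mean_bn, var_bn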