def training_phase():
    mean_batch = K.mean(mean_instance, axis=0, keepdims=True)
    variance_batch = K.mean(temp, axis=0, keepdims=True) - K.square(mean_batch)

    mean_batch_reshaped = K.flatten(mean_batch)
    variance_batch_reshaped = K.flatten(variance_batch)

    if K.backend() != 'cntk':
        sample_size = K.prod(
            [K.shape(inputs)[axis] for axis in reduction_axes])
        sample_size = K.cast(sample_size, dtype=K.dtype(inputs))

        # sample variance - unbiased estimator of population variance
        variance_batch_reshaped *= sample_size / (
            sample_size - (1.0 + self.epsilon))

    self.add_update([
        K.moving_average_update(self.moving_mean, mean_batch_reshaped,
                                self.momentum),
        K.moving_average_update(self.moving_variance,
                                variance_batch_reshaped, self.momentum)
    ])

    return normalize_func(mean_batch, variance_batch)
def call(self, inputs, training=None):
    x = inputs
    assert not isinstance(x, list)

    # Compute the minibatch statistics
    mean, var = self._moments(x)
    sigma = K.sqrt(var + self.epsilon)

    # Outside the training phase, set rmax, dmax large so that the moving
    # averages are used for the normalization
    rmax = K.in_train_phase(self.rmax, K.constant(1e5), training)
    dmax = K.in_train_phase(self.dmax, K.constant(1e5), training)

    # Compute the corrections based on rmax, dmax
    r = K.stop_gradient(
        self._clip(sigma / self.moving_sigma, 1. / rmax, rmax))
    d = K.stop_gradient(
        self._clip((mean - self.moving_mean) / self.moving_sigma,
                   -dmax, dmax))

    # Actually do the normalization and the rescaling
    xnorm = ((x - mean) / sigma) * r + d
    y = self.gamma * xnorm + self.beta

    # Add the moving average updates
    self.add_update([
        K.moving_average_update(self.moving_mean, mean, self.momentum),
        K.moving_average_update(self.moving_sigma, sigma, self.momentum)
    ], x)

    # Add the r, d updates
    rmax_prog = K.minimum(1., self.steps / self.rmax_dur)
    dmax_prog = K.minimum(1., self.steps / self.dmax_dur)
    self.add_update([
        K.update_add(self.steps, 1),
        K.update(self.rmax,
                 self.rmax_0 + rmax_prog * (self.rmax_inf - self.rmax_0)),
        K.update(self.dmax,
                 self.dmax_0 + dmax_prog * (self.dmax_inf - self.dmax_0))
    ])

    # Fix the output's uses-learning-phase flag
    y._uses_learning_phase = rmax._uses_learning_phase
    return y
def inject(self):
    """Add the moving average update ops to model.metrics_updates."""
    self.initialize()
    for w1, w2 in zip(self.ema_weights, self.model.weights):
        op = K.moving_average_update(w1, w2, self.momentum)
        self.model.metrics_updates.append(op)
def inject(self):
    """Add the moving average update ops to model.metrics_updates."""
    self.initialize()
    for w1, w2 in zip(self.ema_weights, self.model.weights):
        op = K.moving_average_update(w1, w2, self.momentum)
        # self.model.metrics_updates.append(op)  # works in Keras 2.2.4
        if not hasattr(self.model, '_other_metrics'):
            self.model._other_metrics = []
        self.model._other_metrics.append(op)
def update_branch():
    """Update the moving average when is_ema_training is True."""

    # Set the qnoise factor to 0 to update the EMA using the unquantized input
    prev_qnoise_factor = tf.identity(self.quantizer.qnoise_factor)
    self.quantizer.update_qnoise_factor(tf.constant(0.0))

    # Update the EMA. act_x is the input after the activation function, but
    # before the quantizer; this is achieved by using a qnoise_factor of 0.
    act_x = self.quantizer(x)
    new_min = tf.squeeze(K.min(act_x, axis=axis, keepdims=True))
    K.moving_average_update(self.ema_min, new_min, self.ema_decay)
    new_max = tf.squeeze(K.max(act_x, axis=axis, keepdims=True))
    K.moving_average_update(self.ema_max, new_max, self.ema_decay)

    # Reset the qnoise factor to the previous value
    self.quantizer.update_qnoise_factor(prev_qnoise_factor)
def set_model(self, model):
    """Bind the model and initialize the EMA weights and update ops."""
    super(ExponentialMovingAverage, self).set_model(model)
    self.ema_weights = [K.zeros(K.shape(w)) for w in model.weights]
    self.old_weights = K.batch_get_value(model.weights)
    K.batch_set_value(zip(self.ema_weights, self.old_weights))
    self.updates = []
    for w1, w2 in zip(self.ema_weights, model.weights):
        op = K.moving_average_update(w1, w2, self.momentum)
        self.updates.append(op)
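# Sketch (an assumption, not part of the original snippets) of how the EMA
# weights collected above could be swapped into the model for evaluation and
# restored afterwards, using the same K.batch_get_value / K.batch_set_value
# calls the callback already relies on. `ema` stands for an instance of the
# callback defined around set_model()/inject().
def apply_ema_weights(ema):
    # Back up the live weights, then load the EMA weights into the model.
    ema.old_weights = K.batch_get_value(ema.model.weights)
    ema_values = K.batch_get_value(ema.ema_weights)
    K.batch_set_value(zip(ema.model.weights, ema_values))

def reset_old_weights(ema):
    # Restore the weights that were active before apply_ema_weights().
    K.batch_set_value(zip(ema.model.weights, ema.old_weights))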
def call(self, inputs, training=None):
    x = inputs
    assert not isinstance(x, list)

    # Do the normalization and the rescaling
    xnorm = K.batch_normalization(x, self.moving_mean, self.moving_variance,
                                  self.beta, self.gamma,
                                  epsilon=self.epsilon)

    # Compute and update the minibatch statistics
    if self.update_stats:
        mean, var = self._moments(x, axes=range(len(K.int_shape(x)) - 1))
        self.add_update([
            K.moving_average_update(self.moving_mean, mean, self.momentum),
            K.moving_average_update(self.moving_variance, var, self.momentum)
        ], x)

    return xnorm
def training_phase():
    # Depthwise conv with soft ReLU
    dconvs = K.depthwise_conv2d(inputs, self.depthwise_kernel,
                                strides=self.strides, padding=self.padding,
                                data_format='channels_last')
    # dconvs = tf.where(dconvs <= 2**(-self.L_A[1] - 1), tf.zeros_like(dconvs), dconvs)
    # factor2 = 0.9 * self.max_activity
    # dconvs = K.minimum(dconvs, 0.1 * dconvs + factor2)
    factor2 = 0.9 * self.max_activity_signed
    dconvs = K.minimum(dconvs, 0.1 * dconvs + factor2)
    dconvs = K.maximum(dconvs, 0.1 * dconvs - factor2)

    # Pointwise conv
    convs = K.conv2d(dconvs, self.kernel, strides=(1, 1),
                     padding=self.padding, data_format='channels_last',
                     dilation_rate=self.dilation_rate)
    convs = K.bias_add(convs, self.bias, data_format='channels_last')

    # Scaling
    scale1 = K.abs(self.max_activity_x /
                   (K.max(K.abs(convs), axis=(0, 1, 2)) + 1e-6))
    indizes = K.greater(scale1, self.max_scale)
    scale1 = self.w_scale * tf.to_float(indizes) + tf.to_float(~indizes) * scale1
    scale2 = self.max_weight / (K.maximum(
        tf.abs(self.bias),
        tf.reduce_max(tf.abs(self.kernel), axis=(0, 1, 2))) + 1e-6)
    scale = K.minimum(scale1, scale2)
    self.add_update(
        K.moving_average_update(self.w_scale, scale, self.momentum), inputs)

    # Soft-clipped linear
    outputs = convs * self.w_scale
    # outputs = K.clip(outputs, min_value=-self.max_activity_signed, max_value=self.max_activity_signed)
    outputs = tf.where(outputs <= 2**(-self.L_A[1] - 1),
                       tf.zeros_like(outputs), outputs)
    outputs = K.minimum(outputs, 0.1 * outputs + factor2)

    return outputs
def training_phase():
    convs = K.conv2d(inputs, self.kernel, data_format='channels_last',
                     strides=self.strides, padding=self.padding,
                     dilation_rate=self.dilation_rate)
    if self.use_bias:
        if self.data_format == 'channels_last':
            convs = K.bias_add(convs, self.bias, data_format='channels_last')
        scale2 = self.max_weight / (K.maximum(
            tf.abs(self.bias),
            tf.reduce_max(tf.abs(self.kernel), axis=(0, 1, 2))) + 1e-6)
    else:
        scale2 = self.max_weight / (
            tf.reduce_max(tf.abs(self.kernel), axis=(0, 1, 2)) + 1e-6)

    indizes = K.greater(K.max(convs, axis=(0, 1, 2)), 0.01)
    scale1 = self.w_scale * tf.cast(~indizes, tf.float32) + tf.cast(
        indizes, tf.float32) * K.abs(
            self.max_activity_x / (K.max(convs, axis=(0, 1, 2)) + 1e-6))
    scale = K.minimum(K.minimum(scale1, scale2), self.max_scale)
    self.add_update(
        K.moving_average_update(self.w_scale, scale, self.momentum))

    outputs = convs * self.w_scale
    if self.data_format == 'channels_last':
        outputs = tf.transpose(outputs, [0, 3, 1, 2])
        outputs = tf.where(outputs <= 2**(-self.L_A[1] - 1),
                           tf.zeros_like(outputs), outputs)
        outputs = tf.transpose(outputs, [0, 2, 3, 1])
    else:
        outputs = tf.where(outputs <= 2**(-self.L_A[1] - 1),
                           tf.zeros_like(outputs), outputs)
    outputs = K.minimum(outputs, self.max_activity)  # 0.1 * outputs + self.factor
    return outputs
def call(self, inputs, training=False):
    x = inputs
    training = training and self.trainable
    self.will_ema_freeze = self.will_ema_freeze and self.trainable

    # Update the step count if the optimizer step count is unknown
    self.step.assign_add(
        K.switch(
            tf.math.logical_and(self.is_estimating_step_count, training),
            tf.constant(1, tf.int64), tf.constant(0, tf.int64)))

    # Perform the quantization
    if training:
        # Calculate the qnoise, a scalar from 0 to 1 that represents the level
        # of quantization noise to use. At training start, we want no
        # quantization, so qnoise_factor = 0.0. After quantization_delay steps,
        # we want normal quantization, so qnoise_factor = 1.0.
        qnoise_factor = K.switch(
            tf.greater_equal(self.step, self.quantization_delay),
            lambda: tf.constant(1.0), lambda: tf.constant(0.0))
        qx = self.quantizer(x, qnoise_factor=qnoise_factor)
    else:
        # If not training, we always want to use full quantization
        qx = self.quantizer(x, qnoise_factor=tf.constant(1.0))

    # Calculate the axes along which to find the min and max for the EMAs
    len_axis = len(x.shape)
    if len_axis > 1:
        if self.per_channel:
            if K.image_data_format() == "channels_last":
                axis = list(range(len_axis - 1))
            else:
                axis = list(range(1, len_axis))
        else:
            axis = list(range(len_axis))
    else:
        axis = [0]

    # Determine if freezing the EMA
    is_ema_training = tf.constant(training, dtype=tf.bool)
    if self.will_ema_freeze:
        is_ema_training = tf.cond(
            tf.greater(self.step, self.ema_freeze_delay),
            lambda: tf.constant(False), lambda: tf.constant(True))

    # Update the moving average
    if is_ema_training:
        new_min = tf.squeeze(K.min(qx, axis=axis, keepdims=True))
        K.moving_average_update(self.ema_min, new_min, self.ema_decay)
        new_max = tf.squeeze(K.max(qx, axis=axis, keepdims=True))
        K.moving_average_update(self.ema_max, new_max, self.ema_decay)

    # Set the integer bits for the quantizer
    integer_bits = _get_integer_bits(min_value=self.ema_min,
                                     max_value=self.ema_max,
                                     bits=self.total_bits,
                                     symmetric=self.symmetric,
                                     keep_negative=self.keep_negative,
                                     is_clipping=self.po2_rounding)
    self.quantizer.integer.assign(integer_bits)
    return qx
def call(self, inputs, training=None):
    input_shape = K.int_shape(inputs)
    # Prepare broadcasting shape.
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis]

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    def normalize_inference():
        if needs_broadcasting:
            # In this case we must explicitly broadcast all parameters.
            broadcast_moving_mean = K.reshape(self.moving_mean,
                                              broadcast_shape)
            broadcast_moving_variance = K.reshape(self.moving_variance,
                                                  broadcast_shape)
            if self.center:
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
            else:
                broadcast_beta = None
            if self.scale:
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            else:
                broadcast_gamma = None
            return tf.nn.batch_normalization(  # K.batch_normalization(
                inputs,
                broadcast_moving_mean,
                broadcast_moving_variance,
                broadcast_beta,
                broadcast_gamma,
                # axis=self.axis,
                self.epsilon)  # epsilon=self.epsilon)
        else:
            return tf.nn.batch_normalization(  # K.batch_normalization(
                inputs,
                self.moving_mean,
                self.moving_variance,
                self.beta,
                self.gamma,
                # axis=self.axis,
                self.epsilon)  # epsilon=self.epsilon)

    # If the learning phase is *static* and set to inference:
    if training in {0, False}:
        return normalize_inference()

    # If the learning is either dynamic, or set to training:
    normed_training, mean, variance = _regular_normalize_batch_in_training(
        # K.normalize_batch_in_training(
        inputs, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon)

    if K.backend() != 'cntk':
        sample_size = K.prod(
            [K.shape(inputs)[axis] for axis in reduction_axes])
        sample_size = K.cast(sample_size, dtype=K.dtype(inputs))

        # sample variance - unbiased estimator of population variance
        variance *= sample_size / (sample_size - (1.0 + self.epsilon))

    self.add_update([
        K.moving_average_update(self.moving_mean, mean, self.momentum),
        K.moving_average_update(self.moving_variance, variance,
                                self.momentum)
    ], inputs)

    # Pick the normalized form corresponding to the training phase.
    return K.in_train_phase(normed_training, normalize_inference,
                            training=training)
def call(self, inputs, training=None):
    if self.quant_mode not in [None, 'extrinsic', 'hybrid', 'intrinsic']:
        raise ValueError(
            'Invalid quantization mode. The \'quant_mode\' argument must be '
            'one of \'extrinsic\', \'intrinsic\', \'hybrid\' or None.')

    if isinstance(self.quantizer, list) and len(self.quantizer) == 3:
        quantizer_input = self.quantizer[0]
        quantizer_weight = self.quantizer[1]
        quantizer_output = self.quantizer[2]
    else:
        quantizer_input = self.quantizer
        quantizer_weight = self.quantizer
        quantizer_output = self.quantizer

    input_shape = K.int_shape(inputs)
    # Prepare broadcasting shape.
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis]

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    def normalize_inference():
        if needs_broadcasting:
            # In this case we must explicitly broadcast all parameters.
            broadcast_moving_mean = K.reshape(self.moving_mean,
                                              broadcast_shape)
            broadcast_moving_variance = K.reshape(self.moving_variance,
                                                  broadcast_shape)
            if self.center:
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
            else:
                broadcast_beta = None
            if self.scale:
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            else:
                broadcast_gamma = None

            if self.quant_mode in ['hybrid', 'intrinsic']:
                broadcast_moving_mean = quantizer_weight.quantize(
                    broadcast_moving_mean)
                broadcast_moving_variance = quantizer_weight.quantize(
                    broadcast_moving_variance)
                if self.center:
                    broadcast_beta = quantizer_weight.quantize(broadcast_beta)
                if self.scale:
                    broadcast_gamma = quantizer_weight.quantize(
                        broadcast_gamma)

            if self.quant_mode in ['hybrid', 'intrinsic']:
                quantized_inputs = quantizer_input.quantize(inputs)

            if self.quant_mode == 'intrinsic':
                return QuantizedBatchNormalizationCore(
                    quantized_inputs, broadcast_moving_mean,
                    broadcast_moving_variance, broadcast_beta,
                    broadcast_gamma, self.epsilon, quantizer_output)
            elif self.quant_mode == 'hybrid':
                output = K.batch_normalization(quantized_inputs,
                                               broadcast_moving_mean,
                                               broadcast_moving_variance,
                                               broadcast_beta,
                                               broadcast_gamma,
                                               axis=self.axis,
                                               epsilon=self.epsilon)
                return quantizer_output.quantize(output)
            elif self.quant_mode == 'extrinsic':
                output = K.batch_normalization(inputs,
                                               broadcast_moving_mean,
                                               broadcast_moving_variance,
                                               broadcast_beta,
                                               broadcast_gamma,
                                               axis=self.axis,
                                               epsilon=self.epsilon)
                return quantizer_output.quantize(output)
            elif self.quant_mode is None:
                return K.batch_normalization(inputs,
                                             broadcast_moving_mean,
                                             broadcast_moving_variance,
                                             broadcast_beta,
                                             broadcast_gamma,
                                             axis=self.axis,
                                             epsilon=self.epsilon)
        else:
            if self.quant_mode in ['hybrid', 'intrinsic']:
                moving_mean = quantizer_weight.quantize(self.moving_mean)
                moving_variance = quantizer_weight.quantize(
                    self.moving_variance)
                if self.center:
                    beta = quantizer_weight.quantize(self.beta)
                else:
                    beta = self.beta
                if self.scale:
                    gamma = quantizer_weight.quantize(self.gamma)
                else:
                    gamma = self.gamma

            if self.quant_mode in ['hybrid', 'intrinsic']:
                quantized_inputs = quantizer_input.quantize(inputs)

            if self.quant_mode == 'intrinsic':
                return QuantizedBatchNormalizationCore(
                    quantized_inputs, moving_mean, moving_variance, beta,
                    gamma, self.epsilon, quantizer_output)
            elif self.quant_mode == 'hybrid':
                output = K.batch_normalization(quantized_inputs,
                                               moving_mean,
                                               moving_variance,
                                               beta,
                                               gamma,
                                               axis=self.axis,
                                               epsilon=self.epsilon)
                return quantizer_output.quantize(output)
            elif self.quant_mode == 'extrinsic':
                output = K.batch_normalization(inputs,
                                               self.moving_mean,
                                               self.moving_variance,
                                               self.beta,
                                               self.gamma,
                                               axis=self.axis,
                                               epsilon=self.epsilon)
                return quantizer_output.quantize(output)
            elif self.quant_mode is None:
                return K.batch_normalization(inputs,
                                             self.moving_mean,
                                             self.moving_variance,
                                             self.beta,
                                             self.gamma,
                                             axis=self.axis,
                                             epsilon=self.epsilon)

    # If the learning phase is *static* and set to inference:
    if not training:
        return normalize_inference()

    # If the learning is either dynamic, or set to training:
    normed_training, mean, variance = K.normalize_batch_in_training(
        inputs, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon)

    if K.backend() != 'cntk':
        sample_size = K.prod(
            [K.shape(inputs)[axis] for axis in reduction_axes])
        sample_size = K.cast(sample_size, dtype=K.dtype(inputs))

        # sample variance - unbiased estimator of population variance
        variance *= sample_size / (sample_size - (1.0 + self.epsilon))

    self.add_update([
        K.moving_average_update(self.moving_mean, mean, self.momentum),
        K.moving_average_update(self.moving_variance, variance,
                                self.momentum)
    ], inputs)

    # Pick the normalized form corresponding to the training phase.
    return K.in_train_phase(normed_training, normalize_inference,
                            training=training)
def call(self, inputs, training=None):
    assert self.built, 'Layer must be built before being called'
    input_shape = K.int_shape(inputs)

    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis]

    mean_batch, var_batch = K.moments(inputs, reduction_axes,
                                      shift=None, keep_dims=False)
    std_batch = K.sqrt(var_batch + self.epsilon)

    r_max_value = K.get_value(self.r_max)
    r = std_batch / (K.sqrt(self.running_variance + self.epsilon))
    r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

    d_max_value = K.get_value(self.d_max)
    d = (mean_batch - self.running_mean) / K.sqrt(self.running_variance +
                                                  self.epsilon)
    d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

    if sorted(reduction_axes) == range(K.ndim(inputs))[:-1]:
        x_normed_batch = (inputs - mean_batch) / std_batch
        x_normed = (x_normed_batch * r + d) * self.gamma + self.beta
    else:
        # need broadcasting
        broadcast_mean = K.reshape(mean_batch, broadcast_shape)
        broadcast_std = K.reshape(std_batch, broadcast_shape)
        broadcast_r = K.reshape(r, broadcast_shape)
        broadcast_d = K.reshape(d, broadcast_shape)
        broadcast_beta = K.reshape(self.beta, broadcast_shape)
        broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
        x_normed_batch = (inputs - broadcast_mean) / broadcast_std
        x_normed = (x_normed_batch * broadcast_r +
                    broadcast_d) * broadcast_gamma + broadcast_beta

    # explicit update to moving mean and standard deviation
    self.add_update([
        K.moving_average_update(self.running_mean, mean_batch,
                                self.momentum),
        K.moving_average_update(self.running_variance, std_batch ** 2,
                                self.momentum)
    ], inputs)

    # update r_max and d_max
    t_val = K.get_value(self.t)
    r_val = self.r_max_value / (1 + (self.r_max_value - 1) * np.exp(-t_val))
    d_val = self.d_max_value / (1 + ((self.d_max_value / 1e-3) - 1) *
                                np.exp(-(2 * t_val)))
    t_val += float(self.t_delta)

    self.add_update([
        K.update(self.r_max, r_val),
        K.update(self.d_max, d_val),
        K.update(self.t, t_val)
    ], inputs)

    if training in {0, False}:
        return x_normed
    else:
        def normalize_inference():
            if sorted(reduction_axes) == range(K.ndim(inputs))[:-1]:
                x_normed_running = K.batch_normalization(
                    inputs, self.running_mean, self.running_variance,
                    self.beta, self.gamma, epsilon=self.epsilon)
                return x_normed_running
            else:
                # need broadcasting
                broadcast_running_mean = K.reshape(self.running_mean,
                                                   broadcast_shape)
                broadcast_running_std = K.reshape(self.running_variance,
                                                  broadcast_shape)
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                x_normed_running = K.batch_normalization(
                    inputs, broadcast_running_mean, broadcast_running_std,
                    broadcast_beta, broadcast_gamma, epsilon=self.epsilon)
                return x_normed_running

        # pick the normalized form of inputs corresponding to the training
        # phase; for batch renormalization, inference time remains the same
        # as batchnorm
        x_normed = K.in_train_phase(x_normed, normalize_inference,
                                    training=training)
        return x_normed
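# Quick numerical check (a standalone NumPy sketch, not part of the original
# code) of the batch-renormalization identity used above: with
# r = std_batch / running_std and d = (mean_batch - running_mean) / running_std,
# ((x - mean_batch) / std_batch) * r + d equals (x - running_mean) / running_std.
# The running statistics below are made-up numbers.
import numpy as np

x = np.random.randn(1000) * 3.0 + 2.0
mean_batch, std_batch = x.mean(), x.std()
running_mean, running_std = 1.5, 2.5  # hypothetical moving statistics

r = std_batch / running_std
d = (mean_batch - running_mean) / running_std

lhs = ((x - mean_batch) / std_batch) * r + d
rhs = (x - running_mean) / running_std
print(np.allclose(lhs, rhs))  # True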
def call(self, inputs, training=None):
    input_shape = K.int_shape(inputs)  # .shape
    ndim = len(input_shape)  # 4
    reduction_axes = list(range(ndim))  # if ndim == 4: [0, 1, 2, 3]
    del reduction_axes[self.axis]  # --> [0, 1, 2], self.axis == -1
    input_dim = input_shape[self.axis] // 2
    mu = K.mean(inputs, axis=reduction_axes)  # real mu, imag mu
    broadcast_mu_shape = [1] * len(input_shape)  # [1, 1, 1, 1]
    broadcast_mu_shape[self.axis] = input_shape[self.axis]  # [1, 1, 1, input_shape[self.axis]]
    broadcast_mu = K.reshape(mu, broadcast_mu_shape)  # mu shape is [1, 1, 1, 2]

    # Subtract the real mean from the real parts and the imaginary mean from
    # the imaginary parts; centred_squared == (x - E(x))^2
    if self.center:
        input_centred = inputs - broadcast_mu
    else:
        input_centred = inputs
    centred_squared = input_centred ** 2

    # for Conv2D
    centred_squared_real = centred_squared[:, :, :, :input_dim]  # real
    centred_squared_imag = centred_squared[:, :, :, input_dim:]  # imag
    centred_real = input_centred[:, :, :, :input_dim]  # real
    centred_imag = input_centred[:, :, :, input_dim:]  # imag

    if self.scale:
        Vrr = K.mean(centred_squared_real, axis=reduction_axes) + self.epsilon
        Vii = K.mean(centred_squared_imag, axis=reduction_axes) + self.epsilon
        # Vri contains the real and imaginary covariance for each feature map.
        Vri = K.mean(centred_real * centred_imag, axis=reduction_axes)
    elif self.center:
        Vrr = None
        Vii = None
        Vri = None
    else:
        raise ValueError(
            'Error. Both scale and center in batchnorm are set to False.')

    # 1. Calculate the batch normalization for the real and imaginary parts
    #    of the complex numbers.
    # 2. If training is True, update moving_mean, moving_Vrr, moving_Vii and
    #    moving_Vri according to self.center and self.scale.
    input_bn = complex_batchnorm(input_centred, Vrr, Vii, Vri,
                                 self.beta, self.gamma_rr, self.gamma_ri,
                                 self.gamma_ii, self.scale, self.center,
                                 axis=self.axis)
    if training in {0, False}:
        return input_bn
    else:
        # training is True
        update_list = []
        if self.center:
            update_list.append(
                K.moving_average_update(self.moving_mean, mu, self.momentum))
        if self.scale:
            update_list.append(
                K.moving_average_update(self.moving_Vrr, Vrr, self.momentum))
            update_list.append(
                K.moving_average_update(self.moving_Vii, Vii, self.momentum))
            update_list.append(
                K.moving_average_update(self.moving_Vri, Vri, self.momentum))
        self.add_update(update_list, inputs)

        def normalize_inference():
            if self.center:
                inference_centred = inputs - K.reshape(self.moving_mean,
                                                       broadcast_mu_shape)
            else:
                inference_centred = inputs
            return complex_batchnorm(inference_centred,
                                     self.moving_Vrr, self.moving_Vii,
                                     self.moving_Vri, self.beta,
                                     self.gamma_rr, self.gamma_ri,
                                     self.gamma_ii, self.scale, self.center,
                                     axis=self.axis)

    # Pick the normalized form corresponding to the training phase.
    return K.in_train_phase(input_bn, normalize_inference,
                            training=training)
def _moving_average(self, var, value, momentum):
    if self._tf1:
        return self._assign(var, var * momentum + value * (1 - momentum))
    result = K.moving_average_update(var, value, momentum)
    self._updates.append(result)
    return result
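# Minimal standalone sketch (assuming a TF2 backend with eager execution) of
# what K.moving_average_update does in the helpers above: it updates the
# variable in place as var <- momentum * var + (1 - momentum) * value.
# In older graph-mode Keras the returned op must instead be registered, e.g.
# via add_update, as the layer snippets in this file do.
import tensorflow as tf
from tensorflow.keras import backend as K

var = tf.Variable(0.0)
for value in [1.0, 2.0, 3.0]:
    K.moving_average_update(var, value, 0.9)
print(K.get_value(var))  # the EMA after three updates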
def average_op(itself, var, average_var):
    decay = tf.constant(self.hull.decay_fn(self.hull.step), dtype=tf.float32)
    return backend.moving_average_update(average_var, var, decay)
def call(self, x, mask=None):
    if self.mode == 0 or self.mode == 2:
        assert self.built, 'Layer must be built before being called'
        input_shape = K.int_shape(x)

        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # mean_batch, var_batch = K.moments(x, reduction_axes, shift=None, keep_dims=False)
        normed, mean_batch, var_batch = K.normalize_batch_in_training(
            x, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon)

        std_batch = K.sqrt(var_batch + self.epsilon)

        r_max_value = K.get_value(self.r_max)
        r = std_batch / (K.sqrt(self.running_std + self.epsilon))
        r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

        d_max_value = K.get_value(self.d_max)
        d = (mean_batch - self.running_mean) / K.sqrt(self.running_std +
                                                      self.epsilon)
        d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

        if sorted(reduction_axes) == range(K.ndim(x))[:-1]:
            x_normed_batch = (x - mean_batch) / std_batch
            x_normed = (x_normed_batch * r + d) * self.gamma + self.beta
        else:
            # need broadcasting
            broadcast_mean = K.reshape(mean_batch, broadcast_shape)
            broadcast_std = K.reshape(std_batch, broadcast_shape)
            broadcast_r = K.reshape(r, broadcast_shape)
            broadcast_d = K.reshape(d, broadcast_shape)
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            x_normed_batch = (x - broadcast_mean) / broadcast_std
            x_normed = (x_normed_batch * broadcast_r +
                        broadcast_d) * broadcast_gamma + broadcast_beta

        # explicit update to moving mean and standard deviation
        self.add_update([
            K.moving_average_update(self.running_mean, mean_batch,
                                    self.momentum),
            K.moving_average_update(self.running_std, std_batch ** 2,
                                    self.momentum)
        ], x)

        # update r_max and d_max
        t_val = K.get_value(self.t)
        r_val = self.r_max_value / (1 + (self.r_max_value - 1) *
                                    np.exp(-t_val))
        d_val = self.d_max_value / (1 + ((self.d_max_value / 1e-3) - 1) *
                                    np.exp(-(2 * t_val)))
        t_val += float(self.t_delta)

        self.add_update([
            K.update(self.r_max, r_val),
            K.update(self.d_max, d_val),
            K.update(self.t, t_val)
        ], x)

        if self.mode == 0:
            if sorted(reduction_axes) == range(K.ndim(x))[:-1]:
                x_normed_running = K.batch_normalization(
                    x, self.running_mean, self.running_std,
                    self.beta, self.gamma, epsilon=self.epsilon)
            else:
                # need broadcasting
                broadcast_running_mean = K.reshape(self.running_mean,
                                                   broadcast_shape)
                broadcast_running_std = K.reshape(self.running_std,
                                                  broadcast_shape)
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                x_normed_running = K.batch_normalization(
                    x, broadcast_running_mean, broadcast_running_std,
                    broadcast_beta, broadcast_gamma, epsilon=self.epsilon)

            # pick the normalized form of x corresponding to the training
            # phase; for batch renormalization, inference time remains the
            # same as batchnorm
            x_normed = K.in_train_phase(x_normed, x_normed_running)

    elif self.mode == 1:
        # sample-wise normalization
        m = K.mean(x, axis=self.axis, keepdims=True)
        std = K.sqrt(K.var(x, axis=self.axis, keepdims=True) + self.epsilon)
        x_normed_batch = (x - m) / (std + self.epsilon)

        r_max_value = K.get_value(self.r_max)
        r = std / (self.running_std + self.epsilon)
        r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

        d_max_value = K.get_value(self.d_max)
        d = (m - self.running_mean) / (self.running_std + self.epsilon)
        d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

        x_normed = ((x_normed_batch * r) + d) * self.gamma + self.beta

        # update r_max and d_max
        t_val = K.get_value(self.t)
        r_val = self.r_max_value / (1 + (self.r_max_value - 1) *
                                    np.exp(-t_val))
        d_val = self.d_max_value / (1 + ((self.d_max_value / 1e-3) - 1) *
                                    np.exp(-(2 * t_val)))
        t_val += float(self.t_delta)

        self.add_update([
            K.update(self.r_max, r_val),
            K.update(self.d_max, d_val),
            K.update(self.t, t_val)
        ], x)

    return x_normed
def call(self, inputs, training=None):
    # These were moved here from build() because tf2 eager was not
    # tracking gradients:
    repeated_gamma = K.reshape(
        K.tile(K.expand_dims(self.gamma, -1), [1, self.n]), [-1])
    repeated_beta = K.reshape(
        K.tile(K.expand_dims(self.beta, -1), [1, self.n]), [-1])
    repeated_moving_mean = K.reshape(
        K.tile(K.expand_dims(self.moving_mean, -1), [1, self.n]), [-1])
    repeated_moving_variance = K.reshape(
        K.tile(K.expand_dims(self.moving_variance, -1), [1, self.n]), [-1])

    def unrepeat(w):
        n = 1
        if self.h == 'C4':
            n *= 4
        elif self.h == 'D4':
            n *= 8
        elif self.h == 'Z2':
            n *= 1
        else:
            raise ValueError('Wrong h: %s' % self.h)
        return K.mean(K.reshape(w, (K.int_shape(w)[0] // n, n)), -1)

    input_shape = K.int_shape(inputs)
    # Prepare broadcasting shape.
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis]

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    def normalize_inference():
        if needs_broadcasting:
            # In this case we must explicitly broadcast all parameters.
            broadcast_moving_mean = K.reshape(repeated_moving_mean,
                                              broadcast_shape)
            broadcast_moving_variance = K.reshape(repeated_moving_variance,
                                                  broadcast_shape)
            broadcast_beta = K.reshape(repeated_beta, broadcast_shape)
            broadcast_gamma = K.reshape(repeated_gamma, broadcast_shape)
            return K.batch_normalization(inputs,
                                         broadcast_moving_mean,
                                         broadcast_moving_variance,
                                         broadcast_beta,
                                         broadcast_gamma,
                                         epsilon=self.epsilon)
        else:
            return K.batch_normalization(inputs,
                                         repeated_moving_mean,
                                         repeated_moving_variance,
                                         repeated_beta,
                                         repeated_gamma,
                                         epsilon=self.epsilon)

    def _get_training_value(training, trainable_flag):
        """Return a flag indicating whether a layer should be called in
        training or inference mode.

        Modified from https://git.io/JUGHX

        training: the setting used when the layer is called for inference.
        trainable_flag: flag indicating whether the layer is trainable.
        """
        if training is None:
            training = K.learning_phase()
        if isinstance(training, int):
            training = bool(training)
        # If layer not trainable, override value passed from model.
        if trainable_flag is False:
            training = False
        return training

    # If the learning phase is *static* and set to inference:
    training_val = _get_training_value(training, self.trainable)
    if training_val is False:
        return normalize_inference()

    # If the learning is either dynamic, or set to training:
    normed_training, mean, variance = K.normalize_batch_in_training(
        inputs, repeated_gamma, repeated_beta, reduction_axes,
        epsilon=self.epsilon)

    if K.backend() != 'cntk':
        sample_size = K.prod(
            [K.shape(inputs)[axis] for axis in reduction_axes])
        sample_size = K.cast(sample_size, dtype=K.dtype(inputs))

        # sample variance - unbiased estimator of population variance
        variance *= sample_size / (sample_size - (1.0 + self.epsilon))

    self.add_update([
        K.moving_average_update(self.moving_mean, unrepeat(mean),
                                self.momentum),
        K.moving_average_update(self.moving_variance, unrepeat(variance),
                                self.momentum)
    ], inputs)

    # Pick the normalized form corresponding to the training phase.
    return K.in_train_phase(normed_training, normalize_inference,
                            training=training)
def call(self, inputs, training=None):
    input_shape = K.int_shape(inputs)
    ndim = len(input_shape)
    reduction_axes = list(range(ndim))
    del reduction_axes[self.axis]
    input_dim = input_shape[self.axis] // 2
    mu = K.mean(inputs, axis=reduction_axes)
    broadcast_mu_shape = [1] * len(input_shape)
    broadcast_mu_shape[self.axis] = input_shape[self.axis]
    broadcast_mu = K.reshape(mu, broadcast_mu_shape)
    if self.center:
        input_centred = inputs - broadcast_mu
    else:
        input_centred = inputs
    centred_squared = input_centred ** 2
    if (self.axis == 1 and ndim != 3) or ndim == 2:
        centred_squared_real = centred_squared[:, :input_dim]
        centred_squared_imag = centred_squared[:, input_dim:]
        centred_real = input_centred[:, :input_dim]
        centred_imag = input_centred[:, input_dim:]
    elif ndim == 3:
        centred_squared_real = centred_squared[:, :, :input_dim]
        centred_squared_imag = centred_squared[:, :, input_dim:]
        centred_real = input_centred[:, :, :input_dim]
        centred_imag = input_centred[:, :, input_dim:]
    elif self.axis == -1 and ndim == 4:
        centred_squared_real = centred_squared[:, :, :, :input_dim]
        centred_squared_imag = centred_squared[:, :, :, input_dim:]
        centred_real = input_centred[:, :, :, :input_dim]
        centred_imag = input_centred[:, :, :, input_dim:]
    elif self.axis == -1 and ndim == 5:
        centred_squared_real = centred_squared[:, :, :, :, :input_dim]
        centred_squared_imag = centred_squared[:, :, :, :, input_dim:]
        centred_real = input_centred[:, :, :, :, :input_dim]
        centred_imag = input_centred[:, :, :, :, input_dim:]
    else:
        raise ValueError(
            'Incorrect Batchnorm combination of axis and dimensions. '
            'axis should be either 1 or -1. '
            'axis: ' + str(self.axis) + '; ndim: ' + str(ndim) + '.')
    if self.scale:
        Vrr = K.mean(centred_squared_real, axis=reduction_axes) + self.epsilon
        Vii = K.mean(centred_squared_imag, axis=reduction_axes) + self.epsilon
        # Vri contains the real and imaginary covariance for each feature map.
        Vri = K.mean(centred_real * centred_imag, axis=reduction_axes)
    elif self.center:
        Vrr = None
        Vii = None
        Vri = None
    else:
        raise ValueError(
            'Error. Both scale and center in batchnorm are set to False.')

    input_bn = ComplexBN(input_centred, Vrr, Vii, Vri,
                         self.beta, self.gamma_rr, self.gamma_ri,
                         self.gamma_ii, self.scale, self.center,
                         axis=self.axis)
    if training in {0, False}:
        return input_bn
    else:
        update_list = []
        if self.center:
            update_list.append(
                K.moving_average_update(self.moving_mean, mu, self.momentum))
        if self.scale:
            update_list.append(
                K.moving_average_update(self.moving_Vrr, Vrr, self.momentum))
            update_list.append(
                K.moving_average_update(self.moving_Vii, Vii, self.momentum))
            update_list.append(
                K.moving_average_update(self.moving_Vri, Vri, self.momentum))
        self.add_update(update_list)

        def normalize_inference():
            if self.center:
                inference_centred = inputs - K.reshape(self.moving_mean,
                                                       broadcast_mu_shape)
            else:
                inference_centred = inputs
            return ComplexBN(inference_centred,
                             self.moving_Vrr, self.moving_Vii,
                             self.moving_Vri, self.beta,
                             self.gamma_rr, self.gamma_ri, self.gamma_ii,
                             self.scale, self.center, axis=self.axis)

    # Pick the normalized form corresponding to the training phase.
    return K.in_train_phase(input_bn, normalize_inference,
                            training=training)
def _get_update_list(self, kernel):
    self.moving_heatmap.assign(
        self.heatmap_momentum * self.moving_heatmap +
        (1. - self.heatmap_momentum) * K.sign(kernel))
    update_list = [
        K.moving_average_update(self.moving_heatmap, K.sign(kernel),
                                self.heatmap_momentum),
    ]
    return update_list