# Assumed context for the snippets below (keras-contrib conventions): `K` is the
# Keras backend, `KC` the keras-contrib backend (provides KC.moments), `KCNP` its
# NumPy reference implementation, `np` is numpy and `tf` is tensorflow.

def call(self, inputs, **kwargs):
    input_shape = K.int_shape(inputs)
    tensor_input_shape = K.shape(inputs)

    # Prepare broadcasting shape.
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
    broadcast_shape.insert(1, self.groups)

    reshape_group_shape = K.shape(inputs)
    group_axes = [reshape_group_shape[i] for i in range(len(input_shape))]
    group_axes[self.axis] = input_shape[self.axis] // self.groups
    group_axes.insert(1, self.groups)

    # Reshape inputs to the new group shape.
    group_shape = [group_axes[0], self.groups] + group_axes[2:]
    group_shape = K.stack(group_shape)
    inputs = K.reshape(inputs, group_shape)

    # Normalize over everything but the batch and group axes.
    group_reduction_axes = list(range(len(group_axes)))
    mean, variance = KC.moments(inputs, group_reduction_axes[2:],
                                keep_dims=True)
    inputs = (inputs - mean) / (K.sqrt(variance + self.epsilon))

    outputs = inputs

    # gamma and beta are stored per-channel, so broadcast them explicitly.
    if self.scale:
        broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
        outputs = outputs * broadcast_gamma

    if self.center:
        broadcast_beta = K.reshape(self.beta, broadcast_shape)
        outputs = outputs + broadcast_beta

    # Finally, reshape the output back to the input shape.
    outputs = K.reshape(outputs, tensor_input_shape)

    return outputs
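
# A minimal usage sketch for the group normalization layer above, assuming it
# is exposed as keras_contrib.layers.GroupNormalization; `groups` must evenly
# divide the size of the normalized axis.
from keras.models import Sequential
from keras.layers import Conv2D
from keras_contrib.layers import GroupNormalization

model = Sequential([
    Conv2D(32, 3, padding='same', input_shape=(32, 32, 3)),
    GroupNormalization(groups=4, axis=-1),  # 32 channels -> 8 per group
])
model.compile(optimizer='adam', loss='mse')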
def test_moments(self, keep_dims):
    input_shape = (10, 10, 10, 10)
    x_0 = np.zeros(input_shape)
    x_1 = np.ones(input_shape)
    x_random = np.random.random(input_shape)

    th_axes = [0, 2, 3]  # Theano-style (channels_first) reduction axes
    tf_axes = [0, 1, 2]  # TensorFlow-style (channels_last) reduction axes

    for ip in [x_0, x_1, x_random]:
        for axes in [th_axes, tf_axes]:
            K_mean, K_var = KC.moments(K.variable(ip), axes,
                                       keep_dims=keep_dims)
            np_mean, np_var = KCNP.moments(ip, axes, keep_dims=keep_dims)

            K_mean_val = K.eval(K_mean)
            K_var_val = K.eval(K_var)

            # An absolute tolerance is needed when working with zeros.
            assert_allclose(K_mean_val, np_mean, rtol=1e-4, atol=1e-10)
            assert_allclose(K_var_val, np_var, rtol=1e-4, atol=1e-10)
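
# A minimal NumPy sketch of what the KCNP.moments reference is expected to
# compute (population mean/variance over `axes`, mirroring tf.nn.moments);
# the name `np_moments` is hypothetical, for illustration only.
import numpy as np

def np_moments(x, axes, keep_dims=False):
    mean = np.mean(x, axis=tuple(axes), keepdims=keep_dims)
    variance = np.var(x, axis=tuple(axes), keepdims=keep_dims)
    return mean, variance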
def call(self, inputs, training=None):
    assert self.built, 'Layer must be built before being called'
    input_shape = K.int_shape(inputs)

    # Prepare broadcasting shape.
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis]

    mean_batch, var_batch = KC.moments(inputs, reduction_axes,
                                       shift=None, keep_dims=False)
    std_batch = K.sqrt(var_batch + self.epsilon)

    # Batch renormalization corrections r and d (Ioffe, 2017), clipped and
    # treated as constants during backprop.
    r = std_batch / (K.sqrt(self.running_variance + self.epsilon))
    r = K.stop_gradient(K.clip(r, 1 / self.r_max, self.r_max))

    d = ((mean_batch - self.running_mean) /
         K.sqrt(self.running_variance + self.epsilon))
    d = K.stop_gradient(K.clip(d, -self.d_max, self.d_max))

    # In Python 3, `range` must be materialized as a list for this comparison.
    if sorted(reduction_axes) == list(range(K.ndim(inputs)))[:-1]:
        x_normed_batch = (inputs - mean_batch) / std_batch
        x_normed = (x_normed_batch * r + d) * self.gamma + self.beta
    else:
        # Need broadcasting.
        broadcast_mean = K.reshape(mean_batch, broadcast_shape)
        broadcast_std = K.reshape(std_batch, broadcast_shape)
        broadcast_r = K.reshape(r, broadcast_shape)
        broadcast_d = K.reshape(d, broadcast_shape)
        broadcast_beta = K.reshape(self.beta, broadcast_shape)
        broadcast_gamma = K.reshape(self.gamma, broadcast_shape)

        x_normed_batch = (inputs - broadcast_mean) / broadcast_std
        x_normed = ((x_normed_batch * broadcast_r + broadcast_d) *
                    broadcast_gamma + broadcast_beta)

    # Explicit update to moving mean and variance.
    mean_update = K.moving_average_update(self.running_mean, mean_batch,
                                          self.momentum)
    variance_update = K.moving_average_update(self.running_variance,
                                              std_batch ** 2, self.momentum)
    self.add_update([mean_update, variance_update], inputs)

    # Update r_max and d_max following their annealing schedules.
    r_val = self.r_max_value / (1 + (self.r_max_value - 1) * K.exp(-self.t))
    d_val = (self.d_max_value /
             (1 + ((self.d_max_value / 1e-3) - 1) * K.exp(-(2 * self.t))))

    self.add_update([K.update(self.r_max, r_val),
                     K.update(self.d_max, d_val),
                     K.update_add(self.t, self.t_delta_tensor)],
                    inputs)

    if training in {0, False}:
        return x_normed
    else:
        def normalize_inference():
            if sorted(reduction_axes) == list(range(K.ndim(inputs)))[:-1]:
                x_normed_running = K.batch_normalization(
                    inputs, self.running_mean, self.running_variance,
                    self.beta, self.gamma,
                    epsilon=self.epsilon)
                return x_normed_running
            else:
                # Need broadcasting.
                broadcast_running_mean = K.reshape(self.running_mean,
                                                   broadcast_shape)
                broadcast_running_variance = K.reshape(self.running_variance,
                                                       broadcast_shape)
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                x_normed_running = K.batch_normalization(
                    inputs, broadcast_running_mean,
                    broadcast_running_variance,
                    broadcast_beta, broadcast_gamma,
                    epsilon=self.epsilon)
                return x_normed_running

        # Pick the normalized form of inputs corresponding to the training
        # phase; for batch renormalization, inference is the same as batchnorm.
        x_normed = K.in_train_phase(x_normed, normalize_inference,
                                    training=training)
        return x_normed
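
# A minimal usage sketch, assuming the layer above is keras-contrib's
# BatchRenormalization; r_max_value and d_max_value bound the corrections
# r and d once their schedules have fully relaxed.
from keras.models import Sequential
from keras.layers import Dense
from keras_contrib.layers import BatchRenormalization

model = Sequential([
    Dense(64, activation='relu', input_shape=(16,)),
    BatchRenormalization(r_max_value=3.0, d_max_value=5.0),
    Dense(1),
])
model.compile(optimizer='sgd', loss='mse')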
def call(self, inputs, training=None):
    old_shape = (tf.shape(inputs)[0],) + K.int_shape(inputs)[1:]
    if self.group:
        # Split channels into groups: (N, H, W, C) -> (N, H, W, group,
        # C // group), then permute so the group axis is last:
        # (N, H, W, C // group, group).
        new_shape = (tf.shape(inputs)[0],) + old_shape[1:3] + (
            self.group, old_shape[self.axis] // self.group)
        inputs_reshape = K.reshape(inputs, shape=new_shape)
        inputs_reshape = K.permute_dimensions(inputs_reshape,
                                              pattern=(0, 1, 2, 4, 3))
        input_shape = K.int_shape(inputs_reshape)
    else:
        inputs_reshape = inputs
        input_shape = old_shape

    # Prepare broadcasting shape.
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis]

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    def normalization(mean, variance):
        normed = (inputs_reshape - mean) / (K.sqrt(variance + self.epsilon))
        # In this case we must explicitly broadcast all parameters.
        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def normalize_inference():
        # NOTE: the permute/reshape below assumes the grouped 5-D layout,
        # i.e. that self.group is set.
        if needs_broadcasting:
            # In this case we must explicitly broadcast all parameters.
            broadcast_moving_mean = K.reshape(self.moving_mean,
                                              broadcast_shape)
            broadcast_moving_variance = K.reshape(self.moving_variance,
                                                  broadcast_shape)
            outputs = normalization(broadcast_moving_mean,
                                    broadcast_moving_variance)
        else:
            outputs = normalization(self.moving_mean, self.moving_variance)
        return K.reshape(
            K.permute_dimensions(outputs, (0, 1, 2, 4, 3)), old_shape)

    # If the learning phase is *static* and set to inference:
    if training in {0, False}:
        return normalize_inference()

    # If the learning phase is either dynamic, or set to training:
    mean, variance = KC.moments(inputs_reshape, [1, 2, 3], keep_dims=True)
    normed_training = normalization(mean, variance)

    # Collapse the per-sample, per-group statistics for the moving averages.
    mean = K.mean(mean, [0, 1, 2, 3])
    variance = K.mean(variance, [0, 1, 2, 3])

    if K.backend() != 'cntk':
        sample_size = K.prod(
            [K.shape(inputs_reshape)[axis] for axis in reduction_axes])
        sample_size = K.cast(sample_size, dtype=K.dtype(inputs_reshape))
        if K.backend() == 'tensorflow' and sample_size.dtype != 'float32':
            sample_size = K.cast(sample_size, dtype='float32')

        # Sample variance: unbiased estimator of the population variance.
        variance *= sample_size / (sample_size - (1.0 + self.epsilon))

    self.add_update([
        K.moving_average_update(self.moving_mean, mean, self.momentum),
        K.moving_average_update(self.moving_variance, variance, self.momentum)
    ], inputs)

    normed_training = K.reshape(
        K.permute_dimensions(normed_training, (0, 1, 2, 4, 3)), old_shape)

    # Pick the normalized form corresponding to the training phase.
    return K.in_train_phase(normed_training, normalize_inference,
                            training=training)
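
# A minimal NumPy sketch of the reshape/permute used above (illustrative only):
# an NHWC tensor is split into (N, H, W, group, C // group), permuted to
# (N, H, W, C // group, group), and moments over axes [1, 2, 3] then reduce
# over H, W and the channels within each group.
import numpy as np

N, H, W, C, group = 2, 4, 4, 8, 2
x = np.random.random((N, H, W, C)).astype('float32')

x_grouped = x.reshape(N, H, W, group, C // group)
x_grouped = x_grouped.transpose(0, 1, 2, 4, 3)        # (N, H, W, C//group, group)
mean = x_grouped.mean(axis=(1, 2, 3), keepdims=True)  # one mean per (sample, group)
var = x_grouped.var(axis=(1, 2, 3), keepdims=True)
x_normed = (x_grouped - mean) / np.sqrt(var + 1e-5)

# Undo the permute/reshape to recover the original NHWC layout.
x_out = x_normed.transpose(0, 1, 2, 4, 3).reshape(N, H, W, C)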