def call(self, inputs, training=None):
    if not self.trainable:
        training = False
    else:
        # The learning phase flag is a bool tensor (0 = test, 1 = train)
        training = K.learning_phase()
    if training is not False:
        K.update_add(self.iterations, 1)
        # compute current mini-batch mean & variance
        mini_mean, mini_variance = tf.nn.moments(inputs, axes=[0, 1, 2])
        # affine-transform the inputs
        x = (inputs - self.steps_mean) / K.sqrt(self.steps_variance + self.epsilon)
        x = self.gamma * x + self.beta
        # update the moving statistics
        K.moving_average_update(self.moving_mean, mini_mean, self.momentum)
        K.moving_average_update(self.moving_variance, mini_variance, self.momentum)
        # reset or accumulate the short-term statistics every `steps_per_update` iterations
        cond = K.equal(self.iterations % self.steps_per_update, 0)
        K.switch(cond, lambda: self.steps_mean * 0,
                 K.update_add(self.steps_mean, mini_mean))
        K.switch(cond, lambda: self.steps_variance * 0,
                 K.update_add(self.steps_variance, mini_variance))
    else:
        # inference: affine transform with the moving statistics
        scale = self.gamma / K.sqrt(self.moving_variance + self.epsilon)
        x = inputs * scale + (self.beta - self.moving_mean * scale)
    return x
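Every snippet in this section relies on the same backend primitive: `K.moving_average_update(var, value, momentum)` returns an op that assigns `var <- momentum * var + (1 - momentum) * value`. A minimal sketch of its semantics, assuming a TF 1.x-style Keras backend where the returned op must be run explicitly:

import numpy as np
from keras import backend as K

var = K.variable(np.zeros(3))                  # the moving statistic
value = K.constant(np.array([1.0, 2.0, 3.0]))  # the fresh mini-batch statistic
update_op = K.moving_average_update(var, value, 0.9)

# In graph mode the op only takes effect when it is run, which is why most
# layers in this section register these ops via `self.add_update(...)`.
K.get_session().run(update_op)
print(K.get_value(var))  # moves from 0 toward `value`; the exact number
                         # depends on whether the backend zero-debiases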
def call(self, x):
    mean, var = tf.nn.moments(x, [0])
    self.add_update([
        K.moving_average_update(self.mu, mean, self._mu_l),
        K.moving_average_update(self.sigma, tf.sqrt(var), self._sigma_l)
    ], x)
    return (x - self.mu) / (self.sigma + self._eps)
def call(self, x_cat, mask=None):
    # For some reason, we have to concatenate vectors to feed them using "merge" in Keras
    x_cat = self.epsilon + (1 - 2. * self.epsilon) * x_cat  # K.clip(x_cat, self.epsilon, 1 - self.epsilon)  # avoid NaNs
    z = x_cat[:, :self.size]
    x = x_cat[:, self.size:]
    batch_size = K.cast(K.shape(x)[0], x.dtype)  # a symbolic tensor, so we can't treat it as an integer
    div_n = Lambda(lambda v: v / batch_size)  # dividing by the batch size is an op on an unknown-shape tensor

    # batch statistics
    px = K.expand_dims(K.mean(x, axis=0), 0)  # p(x_i = 1)
    py = K.expand_dims(K.mean(z, axis=0), 1)  # mean of z_j
    V = div_n(K.dot(K.transpose(z), x))  # indexed [j, i]

    self.add_update([
        K.moving_average_update(self.Vr, V, self.momentum),
        K.moving_average_update(self.pxr, px, self.momentum),
        K.moving_average_update(self.pyr, py, self.momentum)
    ], x_cat)

    V = K.in_train_phase(V, self.Vr)
    px = K.in_train_phase(px, self.pxr)
    py = K.in_train_phase(py, self.pyr)

    eta1 = V / px
    eta0 = (py - V) / (1 - px)
    W = K.log(eta1) - K.log(1 - eta1) + K.log(1 - eta0) - K.log(eta0)
    out = K.log(px) - K.log(1. - px) + K.dot(z, W) + K.sum(
        K.log(1. - eta1) - K.log(1. - eta0), 0, keepdims=True)
    return K.sigmoid(out)
def training_phase():
    mean_batch = K.mean(mean_instance, axis=0, keepdims=True)
    variance_batch = K.mean(temp, axis=0, keepdims=True) - K.square(mean_batch)

    mean_batch_reshaped = K.flatten(mean_batch)
    variance_batch_reshaped = K.flatten(variance_batch)

    if K.backend() != 'cntk':
        sample_size = K.prod(
            [K.shape(inputs)[axis] for axis in reduction_axes])
        sample_size = K.cast(sample_size, dtype=K.dtype(inputs))
        # sample variance - unbiased estimator of population variance
        variance_batch_reshaped *= sample_size / (sample_size - (1.0 + self.epsilon))

    self.add_update([
        K.moving_average_update(self.moving_mean, mean_batch_reshaped, self.momentum),
        K.moving_average_update(self.moving_variance, variance_batch_reshaped, self.momentum)
    ], inputs)

    return normalize_func(mean_batch, variance_batch)
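The multiplicative factor in `training_phase` above (and in several later snippets) is the usual Bessel correction, with the layer's epsilon folded into the denominator as numerical padding:

\[
\sigma^2_{\text{unbiased}} = \frac{N}{N-1}\,\sigma^2_{\text{batch}},
\quad\text{approximated here as}\quad
\sigma^2 \;\leftarrow\; \sigma^2 \cdot \frac{N}{N - (1 + \epsilon)} .
\]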
def call(self, inputs, training=None):
    inputs, spk_id = inputs
    spk_id = K.cast(K.flatten(spk_id)[0], 'int32')

    def normalize_inference():
        return K.normalize_batch_in_training(
            inputs, self.gamma[spk_id], self.beta[spk_id], [0, 1],
            epsilon=self.epsilon)[0]

    normed_training, mean, variance = K.normalize_batch_in_training(
        inputs, self.gamma[spk_id], self.beta[spk_id], [0, 1],
        epsilon=self.epsilon)

    sample_size = K.shape(inputs)[1]
    sample_size = K.cast(sample_size, dtype=K.dtype(inputs))
    variance *= sample_size / (sample_size - (1.0 + self.epsilon))

    self.add_update([
        K.moving_average_update(self.moving_mean, mean, self.momentum),
        K.moving_average_update(self.moving_variance, variance, self.momentum)
    ], inputs)

    # Pick the normalized form corresponding to the training phase.
    return K.in_train_phase(normed_training, normalize_inference,
                            training=training)
def call(self, x, mask=None):
    if self.mode == 0 or self.mode == 2:
        assert self.built, 'Layer must be built before being called'
        input_shape = self.input_spec[0].shape

        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        if self.mode == 2:
            x_normed, mean, std = K.normalize_batch_in_training(
                x, self.gamma, self.beta, reduction_axes,
                epsilon=self.epsilon)
        else:
            # mode 0
            # NOTE: the trailing `and False` disables this sharing check.
            if self.called_with not in {None, x} and False:
                raise Exception('You are attempting to share a '
                                'same `BatchNormalization` layer across '
                                'different data flows. '
                                'This is not possible. '
                                'You should use `mode=2` in '
                                '`BatchNormalization`, which has '
                                'a similar behavior but is shareable '
                                '(see docs for a description of '
                                'the behavior).')
            self.called_with = x
            x_normed, mean, std = K.normalize_batch_in_training(
                x, self.gamma, self.beta, reduction_axes,
                epsilon=self.epsilon)

            self.updates = [K.moving_average_update(self.running_mean, mean, self.momentum),
                            K.moving_average_update(self.running_std, std, self.momentum)]

            if sorted(reduction_axes) == list(range(K.ndim(x)))[:-1]:
                x_normed_running = K.batch_normalization(
                    x, self.running_mean, self.running_std,
                    self.beta, self.gamma,
                    epsilon=self.epsilon)
            else:
                # need broadcasting
                broadcast_running_mean = K.reshape(self.running_mean, broadcast_shape)
                broadcast_running_std = K.reshape(self.running_std, broadcast_shape)
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                x_normed_running = K.batch_normalization(
                    x, broadcast_running_mean, broadcast_running_std,
                    broadcast_beta, broadcast_gamma,
                    epsilon=self.epsilon)

            # pick the normalized form of x corresponding to the training phase
            x_normed = K.in_train_phase(x_normed, x_normed_running)

    elif self.mode == 1:
        # sample-wise normalization
        m = K.mean(x, axis=-1, keepdims=True)
        std = K.sqrt(K.var(x, axis=-1, keepdims=True) + self.epsilon)
        x_normed = (x - m) / (std + self.epsilon)
        x_normed = self.gamma * x_normed + self.beta
    return x_normed
def call(self, x, mask=None):
    if self.mode == 0 or self.mode == 2:
        assert self.built, 'Layer must be built before being called'
        input_shape = K.int_shape(x)

        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        x_normed, mean, std = K.normalize_batch_in_training(
            x, self.gamma, self.beta, reduction_axes,
            epsilon=self.epsilon)

        if self.mode == 0:
            self.add_update([
                K.moving_average_update(self.running_mean, mean, self.momentum),
                K.moving_average_update(self.running_std, std, self.momentum)
            ], x)

            if sorted(reduction_axes) == list(range(K.ndim(x)))[:-1]:
                x_normed_running = K.batch_normalization(
                    x, self.running_mean, self.running_std,
                    self.beta, self.gamma,
                    epsilon=self.epsilon)
            else:
                # need broadcasting
                broadcast_running_mean = K.reshape(self.running_mean, broadcast_shape)
                broadcast_running_std = K.reshape(self.running_std, broadcast_shape)
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                x_normed_running = K.batch_normalization(
                    x, broadcast_running_mean, broadcast_running_std,
                    broadcast_beta, broadcast_gamma,
                    epsilon=self.epsilon)

            # pick the normalized form of x corresponding to the training phase
            x_normed = K.in_train_phase(x_normed, x_normed_running)
    elif self.mode == 1:
        # sample-wise normalization
        m = K.mean(x, axis=-1, keepdims=True)
        std = K.sqrt(K.var(x, axis=-1, keepdims=True) + self.epsilon)
        x_normed = (x - m) / (std + self.epsilon)
        x_normed = self.gamma * x_normed + self.beta
    else:
        return None
    return x_normed
def call(self, inputs, training=None, **kwargs):
    G = self.groups
    # transpose: [bs, h, w, c] -> [bs, c, h, w]
    if self.axis in {-1, 3}:
        inputs = K.permute_dimensions(inputs, (0, 3, 1, 2))
    input_shape = K.int_shape(inputs)
    N, C, H, W = input_shape
    inputs = K.reshape(inputs, (-1, G, C // G, H, W))

    # compute per-group mean & variance
    gn_mean = K.mean(inputs, axis=[2, 3, 4], keepdims=True)
    gn_variance = K.var(inputs, axis=[2, 3, 4], keepdims=True)

    def gn_inference():
        # in the test phase, just use moving_mean & moving_variance
        mean, variance = self.moving_mean, self.moving_variance
        outputs = (inputs - mean) / (K.sqrt(variance + self.epsilon))
        outputs = K.reshape(outputs, [-1, C, H, W]) * self.gamma + self.beta
        # transpose: [bs, c, h, w] -> [bs, h, w, c]
        if self.axis in {-1, 3}:
            outputs = K.permute_dimensions(outputs, (0, 2, 3, 1))
        return outputs

    if training in {0, False}:
        return gn_inference()

    outputs = (inputs - gn_mean) / (K.sqrt(gn_variance + self.epsilon))
    outputs = K.reshape(outputs, [-1, C, H, W]) * self.gamma + self.beta
    # transpose: [bs, c, h, w] -> [bs, h, w, c]
    if self.axis in {-1, 3}:
        outputs = K.permute_dimensions(outputs, (0, 2, 3, 1))

    self.add_update([K.moving_average_update(self.moving_mean, gn_mean, self.momentum),
                     K.moving_average_update(self.moving_variance, gn_variance, self.momentum)],
                    inputs)
    return K.in_train_phase(outputs, gn_inference, training=training)
def call(self, x, mask=None):
    output = K.conv2d(x, self.W, strides=self.subsample,
                      border_mode=self.border_mode,
                      dim_ordering=self.dim_ordering,
                      filter_shape=self.W_shape)

    # added for batch normalization
    input_shape = K.int_shape(output)
    axis = 1
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[axis] = input_shape[axis]

    output_normed, mean, std = K.normalize_batch_in_training(
        output, self.gamma, self.beta, reduction_axes,
        epsilon=self.epsilon)

    self.add_update([K.moving_average_update(self.running_mean, mean, self.momentum),
                     K.moving_average_update(self.running_std, std, self.momentum)],
                    output)

    if sorted(reduction_axes) == list(range(K.ndim(output)))[:-1]:
        output_normed_running = K.batch_normalization(
            output, self.running_mean, self.running_std,
            self.beta, self.gamma,
            epsilon=self.epsilon)
    else:
        # need broadcasting
        broadcast_running_mean = K.reshape(self.running_mean, broadcast_shape)
        broadcast_running_std = K.reshape(self.running_std, broadcast_shape)
        broadcast_beta = K.reshape(self.beta, broadcast_shape)
        broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
        output_normed_running = K.batch_normalization(
            output, broadcast_running_mean, broadcast_running_std,
            broadcast_beta, broadcast_gamma,
            epsilon=self.epsilon)

    # pick the normalized form of output corresponding to the training phase
    output_normed = K.in_train_phase(output_normed, output_normed_running)

    if self.bias:
        if self.dim_ordering == 'th':
            output_normed += K.reshape(self.b, (1, self.nb_filter, 1, 1))
        elif self.dim_ordering == 'tf':
            output_normed += K.reshape(self.b, (1, 1, 1, self.nb_filter))
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)

    output = self.activation(output_normed)
    return output
def normed_training():
    mean_bn = K.mean(inputs, axis=reduction_axes_bn, keepdims=True)
    variance_bn = K.var(inputs, axis=reduction_axes_bn, keepdims=True)

    mean = [mean_in, mean_ln, mean_bn]
    variance = [variance_in, variance_ln, variance_bn]

    # If the learning phase is either dynamic, or set to training:
    self.add_update([K.moving_average_update(self.moving_mean,
                                             K.reshape(mean_bn, (input_shape[self.axis],)),
                                             self.momentum),
                     K.moving_average_update(self.moving_variance,
                                             K.reshape(variance_bn, (input_shape[self.axis],)),
                                             self.momentum)],
                    inputs)
    return norm(mean, variance)
def call(self, x_cat, mask=None):
    # For some reason, we have to concatenate vectors to feed them using "merge" in Keras
    z = x_cat[:, :self.size]
    x = x_cat[:, self.size:]
    batch_size = K.cast(K.shape(x)[0], x.dtype)  # a symbolic tensor, so we can't treat it as an integer
    div_n = Lambda(lambda v: v / batch_size)  # dividing by the batch size is an op on an unknown-shape tensor

    # batch statistics
    self.mi = K.expand_dims(K.mean(x, axis=0), 0)  # mean of x_i
    self.mj = K.expand_dims(K.mean(z, axis=0), 1)  # mean of z_j
    self.vj = K.expand_dims(K.var(z, axis=0) + self.epsilon, 1)  # sigma_j^2
    self.vi = K.expand_dims(K.var(x, axis=0) + self.epsilon, 0)  # sigma_i^2
    # CHANGE BACK
    # self.V = div_n(K.dot(K.transpose(z), x))
    self.V = div_n(K.dot(K.transpose(z - K.transpose(self.mj)),
                         x - self.mi))  # indexed [j, i]

    self.add_update([
        K.moving_average_update(self.Vr, self.V, self.momentum),
        K.moving_average_update(self.mir, self.mi, self.momentum),
        K.moving_average_update(self.mjr, self.mj, self.momentum),
        K.moving_average_update(self.vjr, self.vj, self.momentum),
        K.moving_average_update(self.vir, self.vi, self.momentum)
    ], x_cat)

    V = K.in_train_phase(self.V, self.Vr)
    mi = K.in_train_phase(self.mi, self.mir)
    mj = K.in_train_phase(self.mj, self.mjr)
    vj = K.in_train_phase(self.vj, self.vjr)
    vi = K.in_train_phase(self.vi, self.vir)

    # CHANGE BACK
    # rho = (V - mi * mj) / K.sqrt(vi * vj)
    rho = V / K.sqrt(vi * vj)
    Q = rho / (1 - K.square(rho))
    self.R = K.sum(rho * Q, axis=0, keepdims=True)
    Q = Q / (1 + self.R)

    if self.return_r:
        return self.R
    else:
        return mi + K.sqrt(vi) * K.dot(
            K.transpose((K.transpose(z) - mj) / K.sqrt(vj)), Q)
def inject(self):
    """Append the EMA update ops to model.metrics_updates."""
    self.initialize()
    for w1, w2 in zip(self.ema_weights, self.model.weights):
        op = K.moving_average_update(w1, w2, self.momentum)
        self.model.metrics_updates.append(op)
def update_erm():
    normed_training, mean, variance = K.normalize_batch_in_training(
        x=inputs, beta=None, gamma=None,
        reduction_axes=reduction_axes)
    self.add_update(
        [K.moving_average_update(self.values, mean, self.momentum)],
        inputs=inputs)
    return self.values
def call(self, inputs, training=None):
    if training is None:
        training = bk.learning_phase()
    training = bk.get_value(training)
    if training:
        bk.moving_average_update(self.moving_min, bk.min(inputs, axis=0), self.momentum)
        bk.moving_average_update(self.moving_max, bk.max(inputs, axis=0), self.momentum)
    scale = (self.max_val - self.min_val) / (
        self.moving_max - self.moving_min + self.epsilon)
    output = bk.clip((inputs - self.moving_min) * scale + self.min_val,
                     self.min_val, self.max_val)
    return output
def call(self, inputs, training=None):
    x = inputs
    assert not isinstance(x, list)

    # Compute the minibatch statistics
    mean, var = self._moments(x)
    sigma = K.sqrt(var + self.epsilon)

    # In the test phase set rmax, dmax large so that the correction terms
    # effectively switch the normalization over to the moving averages
    rmax = K.in_train_phase(self.rmax, K.constant(1e5), training)
    dmax = K.in_train_phase(self.dmax, K.constant(1e5), training)

    # Compute the corrections based on rmax, dmax
    r = K.stop_gradient(
        self._clip(sigma / self.moving_sigma, 1. / rmax, rmax))
    d = K.stop_gradient(
        self._clip((mean - self.moving_mean) / self.moving_sigma,
                   -dmax, dmax))

    # Actually do the normalization and the rescaling
    xnorm = ((x - mean) / sigma) * r + d
    y = self.gamma * xnorm + self.beta

    # Add the moving average updates
    self.add_update([
        K.moving_average_update(self.moving_mean, mean, self.momentum),
        K.moving_average_update(self.moving_sigma, sigma, self.momentum)
    ], x)

    # Add the r, d updates
    rmax_prog = K.minimum(1., self.steps / self.rmax_dur)
    dmax_prog = K.minimum(1., self.steps / self.dmax_dur)
    self.add_update([
        K.update_add(self.steps, 1),
        K.update(self.rmax,
                 self.rmax_0 + rmax_prog * (self.rmax_inf - self.rmax_0)),
        K.update(self.dmax,
                 self.dmax_0 + dmax_prog * (self.dmax_inf - self.dmax_0))
    ])

    # Fix the output's uses-learning-phase flag
    y._uses_learning_phase = rmax._uses_learning_phase
    return y
def inject(self):
    """Append the EMA update ops to the model's metric updates."""
    self.initialize()
    for w1, w2 in zip(self.ema_weights, self.model.weights):
        op = K.moving_average_update(w1, w2, self.momentum)
        # self.model.metrics_updates.append(op)  # works in Keras 2.2.4
        if not hasattr(self.model, '_other_metrics'):
            self.model._other_metrics = []
        self.model._other_metrics.append(op)
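The two `inject` variants above come from an exponential-moving-average (EMA) weight helper whose surrounding class is not shown. A minimal sketch of how such a helper might be wired up; the class name, `initialize`, and `apply_ema_weights` are assumed names for illustration, not a published API:

from keras import backend as K

class ExponentialMovingAverage(object):
    """Shadow a model's weights with an EMA (hypothetical helper)."""

    def __init__(self, model, momentum=0.999):
        self.model = model
        self.momentum = momentum

    def initialize(self):
        # One shadow variable per model weight, initialized to its value.
        self.ema_weights = [K.variable(K.get_value(w))
                            for w in self.model.weights]

    def inject(self):
        # As above: register one moving-average op per weight; here we
        # assume hooking in via the layer-level add_update mechanism.
        self.initialize()
        for w1, w2 in zip(self.ema_weights, self.model.weights):
            self.model.add_update(
                K.moving_average_update(w1, w2, self.momentum))

    def apply_ema_weights(self):
        # Overwrite the live weights with their EMA shadows (e.g. before eval).
        K.batch_set_value(list(zip(self.model.weights,
                                   [K.get_value(w) for w in self.ema_weights])))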
def call(self, inputs, training=None):
    if len(inputs) == 3:
        params, trainable_params, x = inputs
        params = self.merge_params(params, trainable_params)
    elif len(inputs) == 2:
        params, x = inputs
    else:
        raise ValueError("Wrong number of inputs")

    offset = 0
    for layer in self.layers:
        layer_params = params[:, offset:offset + layer["num_params"]]
        offset += layer["num_params"]

        if layer["type"] in ["standard-batchnorm", "batch-renorm"]:
            x = K.stack(x, 0)
            self.mean, self.variance = tf.nn.moments(x, [0, 1, 2])
            if training:
                sample_size = K.prod(
                    [K.shape(x)[axis] for axis in [0, 1, 2]])
                sample_size = K.cast(sample_size, dtype='float32')
                unbiased_variance = self.variance * sample_size / (
                    sample_size - (1.0 + layer["epsilon"]))
                self.add_update([
                    K.moving_average_update(
                        self.moving_means[layer["name"]],
                        self.mean, layer["momentum"]),
                    K.moving_average_update(
                        self.moving_vars[layer["name"]],
                        unbiased_variance, layer["momentum"]),
                ], inputs)

        x = [
            self.evaluate_layer(layer, layer_params[i], x[i], training)
            for i in range(self.batch_size)
        ]

    output = K.stack(x, 0)
    output._uses_learning_phase = True
    return output
def call(self, x_cat, mask=None):
    # For some reason, we have to concatenate vectors to feed them using "merge" in Keras
    z = x_cat[:, :self.size]
    x = K.clip(x_cat[:, self.size:], self.epsilon, 1. - self.epsilon)
    batch_size = K.cast(K.shape(x)[0], x.dtype)  # a symbolic tensor, so we can't treat it as an integer
    div_n = Lambda(lambda v: v / batch_size)  # dividing by the batch size is an op on an unknown-shape tensor

    # batch statistics
    pi = K.expand_dims(K.mean(x, axis=0), 0)  # p(x_i = 1)
    mj = K.expand_dims(K.mean(z, axis=0), 1)  # mean of z_j
    vj = K.expand_dims(K.mean(K.square(z), axis=0), 1)  # expectation of z^2
    V = div_n(K.dot(K.transpose(z), x))  # indexed [j, i]
    S = div_n(K.dot(K.transpose(K.square(z)), x))  # indexed [j, i]

    self.add_update([
        K.moving_average_update(self.Vr, V, self.momentum),
        K.moving_average_update(self.Sr, S, self.momentum),
        K.moving_average_update(self.pir, pi, self.momentum),
        K.moving_average_update(self.mjr, mj, self.momentum),
        K.moving_average_update(self.vjr, vj, self.momentum)
    ], x_cat)

    V = K.in_train_phase(V, self.Vr)
    S = K.in_train_phase(S, self.Sr)
    pi = K.in_train_phase(pi, self.pir)
    mj = K.in_train_phase(mj, self.mjr)
    vj = K.in_train_phase(vj, self.vjr)

    mu0, mu1, sig0, sig1 = self.get_mean_sig(mj, vj, pi, V, S)
    out = (K.log(pi) - K.log(1. - pi)
           - 0.5 * K.sum(K.log(sig1) - K.log(sig0), 0)
           + 0.5 * K.sum(K.square(mu0) / sig0 - K.square(mu1) / sig1, 0)
           + K.dot(z, mu1 / sig1 - mu0 / sig0)
           + 0.5 * K.dot(K.square(z), 1. / sig0 - 1. / sig1))
    return K.sigmoid(out)
def _update_embedding(self, x, y, seg_indices, seg_embeddings):
    dtype = self.embedding.dtype
    delta_embeddings = (1 - self.target_momentum) * (y - seg_embeddings)
    tmp_embedding, tmp_cnt = self._sum_seg_embeddings(
        seg_indices, delta_embeddings)
    bk.update_add(
        self.embedding,
        tmp_embedding / (tmp_cnt + bk.cast(0 == tmp_cnt, dtype=dtype)))
    bk.update_add(self.update_cnt, tmp_cnt)
    if self.mask_zero:
        min_val = bk.min(x + bk.constant(self.val_inf, dtype=dtype) *
                         bk.cast(0 == x, dtype), axis=0)
        max_val = bk.max(x + bk.constant(-self.val_inf, dtype=dtype) *
                         bk.cast(0 == x, dtype), axis=0)
    else:
        min_val, max_val = bk.min(x, axis=0), bk.max(x, axis=0)
    bk.moving_average_update(self.moving_min, min_val, self.val_momentum)
    bk.moving_average_update(self.moving_max, max_val, self.val_momentum)
def call(self, inputs, training=None):
    x = inputs
    assert not isinstance(x, list)

    # Do the normalization and the rescaling
    xnorm = K.batch_normalization(x, self.moving_mean, self.moving_variance,
                                  self.beta, self.gamma,
                                  epsilon=self.epsilon)

    # Compute and update the minibatch statistics
    if self.update_stats:
        mean, var = self._moments(x, axes=range(len(K.int_shape(x)) - 1))
        self.add_update([
            K.moving_average_update(self.moving_mean, mean, self.momentum),
            K.moving_average_update(self.moving_variance, var, self.momentum)
        ], x)

    return xnorm
def batch_norm(inputs, gamma, beta, dims, ind):
    """Normalize batch and update moving averages for mean and std.

    Input:
        inputs: (batchsize, n_points, k, n_features * 2) - edge_features
        gamma: weight - gamma for batch normalization
        beta: weight - beta for batch normalization
        dims: list - dimensions along which to normalize
        ind: int - indicating which weights to use
    Returns:
        During training:
            normed: (batchsize, n_points, k, n_features * 2) - normalized
                batch of data using the actual batch for normalization
        Else:
            normed_moving: same, but using the updated average values
    """
    # Calculate normalized data, mean and std for batch
    normed, batch_mean, batch_var = K.normalize_batch_in_training(
        x=inputs, gamma=gamma, beta=beta, reduction_axes=dims)

    # Update the moving averages
    self.add_update([
        K.moving_average_update(self.moving_mean[ind], batch_mean, 0.9),
        K.moving_average_update(self.moving_var[ind], batch_var, 0.9)])

    # Calculate normalization using the averages
    normed_moving = K.batch_normalization(
        x=inputs, mean=self.moving_mean[ind], var=self.moving_var[ind],
        beta=beta, gamma=gamma)

    # If training return normed, else normed_moving
    return K.in_train_phase(normed, normed_moving)
def call(self, x_cat, mask=None):
    # For some reason, we have to concatenate vectors to feed them using "merge" in Keras
    z = x_cat[:, :self.size]
    x = x_cat[:, self.size:]
    batch_size = K.cast(K.shape(x)[0], x.dtype)  # a symbolic tensor, so we can't treat it as an integer
    div_n = Lambda(lambda v: v / batch_size)  # dividing by the batch size is an op on an unknown-shape tensor

    # batch statistics
    pi = K.expand_dims(
        K.clip(K.mean(x, axis=0), self.epsilon, 1. - self.epsilon), 0)  # p(x_i = 1)
    mj = K.expand_dims(K.mean(z, axis=0), 1)  # mean of z_j
    vj = K.expand_dims(K.var(z, axis=0) + self.epsilon, 1)  # sigma_j^2
    V = div_n(K.dot(K.transpose(z), x))  # indexed [j, i]

    self.add_update([
        K.moving_average_update(self.Vr, V, self.momentum),
        K.moving_average_update(self.pir, pi, self.momentum),
        K.moving_average_update(self.mjr, mj, self.momentum),
        K.moving_average_update(self.vjr, vj, self.momentum)
    ], x_cat)

    V = K.in_train_phase(V, self.Vr)
    pi = K.in_train_phase(pi, self.pir)
    mj = K.in_train_phase(mj, self.mjr)
    vj = K.in_train_phase(vj, self.vjr)

    mu_diff = (V - mj * pi) / (pi * (1 - pi))  # difference mu_{x_i=1}^j - mu_{x_i=0}^j
    mu_mean = 0.5 * (V / pi + (mj - V) / (1 - pi))  # average of the two means
    out = K.log(pi) - K.log(1. - pi) + K.dot(z, mu_diff / vj) - K.sum(
        mu_diff * mu_mean / vj, 0, keepdims=True)
    return K.sigmoid(out)
def train():
    ff_apr = ktf.matmul(f, f, transpose_b=True) / (
        ktf.cast(bs * w * h, ktf.float32) - 1.)
    if self.decomposition in ['pca-cor', 'zca-cor']:
        dinv = ktf.diag(ktf.sqrt(ktf.diag_part(ff_apr)))
        ff_apr = ktf.matmul(ktf.matmul(dinv, ff_apr),
                            ktf.matrix_inverse(dinv), transpose_b=True)
    self.add_update([
        K.moving_average_update(self.moving_mean, m, self.momentum),
        K.moving_average_update(self.moving_cov, ff_apr, self.momentum)
    ], inputs)
    ff_apr_shrinked = (1 - self.epsilon) * ff_apr + ktf.eye(c) * self.epsilon

    if self.renorm:
        l, l_inv = get_inv_sqrt(ff_apr_shrinked)
        ff_mov = (1 - self.epsilon) * self.moving_cov + ktf.eye(c) * self.epsilon
        _, l_mov_inverse = get_inv_sqrt(ff_mov)
        l_ndiff = K.stop_gradient(l)
        return ktf.matmul(ktf.matmul(l_mov_inverse, l_ndiff), l_inv)

    return get_inv_sqrt(ff_apr_shrinked)[1]
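The whitening snippet above keeps a moving average of the full covariance matrix and needs its inverse square root; `get_inv_sqrt` itself is not shown. A plausible TF 1.x implementation via eigendecomposition, offered as an assumption rather than the author's code:

import tensorflow as ktf

def get_inv_sqrt(cov):
    # cov: symmetric positive-definite (c, c) covariance matrix
    s, u = ktf.self_adjoint_eig(cov)  # eigenvalues s, eigenvectors in columns of u
    # U diag(sqrt(s)) U^T and U diag(1/sqrt(s)) U^T
    l = ktf.matmul(u * ktf.sqrt(s), u, transpose_b=True)        # cov^{1/2}
    l_inv = ktf.matmul(u * ktf.rsqrt(s), u, transpose_b=True)   # cov^{-1/2}
    return l, l_inv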
def call(self, inputs, training=None):
    input_shape = K.int_shape(inputs)
    # Prepare broadcasting shape.
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]

    # inference
    def normalize_inference():
        return inputs - self.moving_mean

    if training in {0, False}:
        return normalize_inference()

    mean = K.mean(inputs, axis=reduction_axes)
    normed_training = inputs - mean
    self.add_update(
        K.moving_average_update(self.moving_mean, mean, self.momentum),
        inputs)
    return K.in_train_phase(normed_training, normalize_inference,
                            training=training)
def call(self, inputs, training=None):
    input_shape = K.int_shape(inputs)
    ndim = len(input_shape)
    reduction_axes = list(range(ndim))
    del reduction_axes[self.axis]
    input_dim = input_shape[self.axis] // 4
    mu = K.mean(inputs, axis=reduction_axes)
    broadcast_mu_shape = [1] * len(input_shape)
    broadcast_mu_shape[self.axis] = input_shape[self.axis]
    broadcast_mu = K.reshape(mu, broadcast_mu_shape)
    if self.center:
        input_centred = inputs - broadcast_mu
    else:
        input_centred = inputs
    centred_squared = input_centred ** 2
    if (self.axis == 1 and ndim != 3) or ndim == 2:
        centred_squared_r = centred_squared[:, :input_dim]
        centred_squared_i = centred_squared[:, input_dim:input_dim*2]
        centred_squared_j = centred_squared[:, input_dim*2:input_dim*3]
        centred_squared_k = centred_squared[:, input_dim*3:]
        centred_r = input_centred[:, :input_dim]
        centred_i = input_centred[:, input_dim:input_dim*2]
        centred_j = input_centred[:, input_dim*2:input_dim*3]
        centred_k = input_centred[:, input_dim*3:]
    elif ndim == 3:
        centred_squared_r = centred_squared[:, :, :input_dim]
        centred_squared_i = centred_squared[:, :, input_dim:input_dim*2]
        centred_squared_j = centred_squared[:, :, input_dim*2:input_dim*3]
        centred_squared_k = centred_squared[:, :, input_dim*3:]
        centred_r = input_centred[:, :, :input_dim]
        centred_i = input_centred[:, :, input_dim:input_dim*2]
        centred_j = input_centred[:, :, input_dim*2:input_dim*3]
        centred_k = input_centred[:, :, input_dim*3:]
    elif self.axis == -1 and ndim == 4:
        centred_squared_r = centred_squared[:, :, :, :input_dim]
        centred_squared_i = centred_squared[:, :, :, input_dim:input_dim*2]
        centred_squared_j = centred_squared[:, :, :, input_dim*2:input_dim*3]
        centred_squared_k = centred_squared[:, :, :, input_dim*3:]
        centred_r = input_centred[:, :, :, :input_dim]
        centred_i = input_centred[:, :, :, input_dim:input_dim*2]
        centred_j = input_centred[:, :, :, input_dim*2:input_dim*3]
        centred_k = input_centred[:, :, :, input_dim*3:]
    elif self.axis == -1 and ndim == 5:
        centred_squared_r = centred_squared[:, :, :, :, :input_dim]
        centred_squared_i = centred_squared[:, :, :, :, input_dim:input_dim*2]
        centred_squared_j = centred_squared[:, :, :, :, input_dim*2:input_dim*3]
        centred_squared_k = centred_squared[:, :, :, :, input_dim*3:]
        centred_r = input_centred[:, :, :, :, :input_dim]
        centred_i = input_centred[:, :, :, :, input_dim:input_dim*2]
        centred_j = input_centred[:, :, :, :, input_dim*2:input_dim*3]
        centred_k = input_centred[:, :, :, :, input_dim*3:]
    else:
        raise ValueError(
            'Incorrect Batchnorm combination of axis and dimensions. '
            'axis should be either 1 or -1. '
            'axis: ' + str(self.axis) + '; ndim: ' + str(ndim) + '.')
    if self.scale:
        Vrr = K.mean(centred_squared_r, axis=reduction_axes) + self.epsilon
        Vii = K.mean(centred_squared_i, axis=reduction_axes) + self.epsilon
        Vjj = K.mean(centred_squared_j, axis=reduction_axes) + self.epsilon
        Vkk = K.mean(centred_squared_k, axis=reduction_axes) + self.epsilon
        Vri = K.mean(centred_r * centred_i, axis=reduction_axes)
        Vrj = K.mean(centred_r * centred_j, axis=reduction_axes)
        Vrk = K.mean(centred_r * centred_k, axis=reduction_axes)
        Vij = K.mean(centred_i * centred_j, axis=reduction_axes)
        Vik = K.mean(centred_i * centred_k, axis=reduction_axes)
        Vjk = K.mean(centred_j * centred_k, axis=reduction_axes)
    elif self.center:
        Vrr = Vii = Vjj = Vkk = None
        Vri = Vrj = Vrk = Vij = Vik = Vjk = None
    else:
        raise ValueError('Error. Both scale and center in batchnorm are set to False.')

    input_bn = QuaternionBN(
        input_centred, Vrr, Vri, Vrj, Vrk, Vii, Vij, Vik, Vjj, Vjk, Vkk,
        self.beta,
        self.gamma_rr, self.gamma_ri, self.gamma_rj, self.gamma_rk,
        self.gamma_ii, self.gamma_ij, self.gamma_ik,
        self.gamma_jj, self.gamma_jk, self.gamma_kk,
        self.scale, self.center, axis=self.axis)
    if training in {0, False}:
        return input_bn
    else:
        update_list = []
        if self.center:
            update_list.append(K.moving_average_update(self.moving_mean, mu, self.momentum))
        if self.scale:
            update_list.append(K.moving_average_update(self.moving_Vrr, Vrr, self.momentum))
            update_list.append(K.moving_average_update(self.moving_Vii, Vii, self.momentum))
            update_list.append(K.moving_average_update(self.moving_Vjj, Vjj, self.momentum))
            update_list.append(K.moving_average_update(self.moving_Vkk, Vkk, self.momentum))
            update_list.append(K.moving_average_update(self.moving_Vri, Vri, self.momentum))
            update_list.append(K.moving_average_update(self.moving_Vrj, Vrj, self.momentum))
            update_list.append(K.moving_average_update(self.moving_Vrk, Vrk, self.momentum))
            update_list.append(K.moving_average_update(self.moving_Vij, Vij, self.momentum))
            update_list.append(K.moving_average_update(self.moving_Vik, Vik, self.momentum))
            update_list.append(K.moving_average_update(self.moving_Vjk, Vjk, self.momentum))
        self.add_update(update_list, inputs)

        def normalize_inference():
            if self.center:
                inference_centred = inputs - K.reshape(self.moving_mean,
                                                       broadcast_mu_shape)
            else:
                inference_centred = inputs
            return QuaternionBN(
                inference_centred,
                self.moving_Vrr, self.moving_Vri, self.moving_Vrj, self.moving_Vrk,
                self.moving_Vii, self.moving_Vij, self.moving_Vik,
                self.moving_Vjj, self.moving_Vjk, self.moving_Vkk,
                self.beta,
                self.gamma_rr, self.gamma_ri, self.gamma_rj, self.gamma_rk,
                self.gamma_ii, self.gamma_ij, self.gamma_ik,
                self.gamma_jj, self.gamma_jk, self.gamma_kk,
                self.scale, self.center, axis=self.axis)

    # Pick the normalized form corresponding to the training phase.
    return K.in_train_phase(input_bn, normalize_inference, training=training)
def call(self, inputs, training=None):
    input_shape = K.int_shape(inputs)
    # Prepare broadcasting shape.
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis]

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    def normalize_inference():
        if needs_broadcasting:
            # In this case we must explicitly broadcast all parameters.
            broadcast_moving_mean = K.reshape(self.moving_mean,
                                              broadcast_shape)
            broadcast_moving_variance = K.reshape(self.moving_variance,
                                                  broadcast_shape)
            if self.center:
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
            else:
                broadcast_beta = None
            if self.scale:
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            else:
                broadcast_gamma = None
            return K.batch_normalization(
                inputs,
                broadcast_moving_mean,
                broadcast_moving_variance,
                broadcast_beta,
                broadcast_gamma,
                epsilon=self.epsilon)
        else:
            return K.batch_normalization(
                inputs,
                self.moving_mean,
                self.moving_variance,
                self.beta,
                self.gamma,
                epsilon=self.epsilon)

    # If the learning phase is *static* and set to inference:
    if training in {0, False}:
        return normalize_inference()

    # If the learning is either dynamic, or set to training:
    normed_training, mean, variance = K.normalize_batch_in_training(
        inputs, self.gamma, self.beta, reduction_axes,
        epsilon=self.epsilon)

    self.add_update([K.moving_average_update(self.moving_mean,
                                             mean,
                                             self.momentum),
                     K.moving_average_update(self.moving_variance,
                                             variance,
                                             self.momentum)],
                    inputs)

    # Pick the normalized form corresponding to the training phase.
    return K.in_train_phase(normed_training,
                            normalize_inference,
                            training=training)
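Most of these layers hinge on the same switch: `K.in_train_phase(train_expr, test_expr, training)` picks the first expression when the learning phase is 1 and the second when it is 0, and either argument may be a zero-argument callable. A small sketch of the classic Keras 1.x/2.x pattern for exercising both branches; the tensors here are illustrative stand-ins for the batch-stats / moving-stats pair used above:

import numpy as np
from keras import backend as K

inp = K.placeholder(shape=(None, 4))
out = K.in_train_phase(inp * 2.0, inp)  # doubled in training, identity at test

# Expose the learning phase as an extra input so both branches can be run.
f = K.function([inp, K.learning_phase()], [out])

data = np.ones((2, 4), dtype='float32')
print(f([data, 1])[0])  # training branch -> all 2.0
print(f([data, 0])[0])  # inference branch -> all 1.0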
def call(self, inputs, training=None):
    input_shape = K.int_shape(inputs)
    # Prepare broadcasting shape.
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis]

    # Determines whether broadcasting is needed.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    def normalize_inference():
        if needs_broadcasting:
            # In this case we must explicitly broadcast all parameters.
            broadcast_moving_mean = K.reshape(self.moving_mean,
                                              broadcast_shape)
            broadcast_moving_variance = K.reshape(self.moving_variance,
                                                  broadcast_shape)
            if self.center:
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
            else:
                broadcast_beta = None
            if self.scale:
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            else:
                broadcast_gamma = None
            # tf.nn.batch_normalization used in place of K.batch_normalization
            # (note its epsilon is positional, and it takes no `axis` argument)
            return tf.nn.batch_normalization(
                inputs,
                broadcast_moving_mean,
                broadcast_moving_variance,
                broadcast_beta,
                broadcast_gamma,
                self.epsilon)
        else:
            return tf.nn.batch_normalization(
                inputs,
                self.moving_mean,
                self.moving_variance,
                self.beta,
                self.gamma,
                self.epsilon)

    # If the learning phase is *static* and set to inference:
    if training in {0, False}:
        return normalize_inference()

    # If the learning is either dynamic, or set to training:
    # (_regular_normalize_batch_in_training replaces K.normalize_batch_in_training)
    normed_training, mean, variance = _regular_normalize_batch_in_training(
        inputs, self.gamma, self.beta, reduction_axes,
        epsilon=self.epsilon)

    if K.backend() != 'cntk':
        sample_size = K.prod([K.shape(inputs)[axis]
                              for axis in reduction_axes])
        sample_size = K.cast(sample_size, dtype=K.dtype(inputs))
        # sample variance - unbiased estimator of population variance
        variance *= sample_size / (sample_size - (1.0 + self.epsilon))

    self.add_update([K.moving_average_update(self.moving_mean,
                                             mean,
                                             self.momentum),
                     K.moving_average_update(self.moving_variance,
                                             variance,
                                             self.momentum)],
                    inputs)

    # Pick the normalized form corresponding to the training phase.
    return K.in_train_phase(normed_training,
                            normalize_inference,
                            training=training)
def call(self, x, mask=None):
    if self.mode == 0 or self.mode == 2:
        assert self.built, 'Layer must be built before being called'
        input_shape = K.int_shape(x)

        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        mean_batch, var_batch = K.moments(x, reduction_axes,
                                          shift=None, keep_dims=False)
        std_batch = K.sqrt(var_batch + self.epsilon)

        r_max_value = K.get_value(self.r_max)
        r = std_batch / (K.sqrt(self.running_std + self.epsilon))
        r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

        d_max_value = K.get_value(self.d_max)
        d = (mean_batch - self.running_mean) / K.sqrt(self.running_std +
                                                      self.epsilon)
        d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

        if sorted(reduction_axes) == list(range(K.ndim(x)))[:-1]:
            x_normed_batch = (x - mean_batch) / std_batch
            x_normed = (x_normed_batch * r + d) * self.gamma + self.beta
        else:
            # need broadcasting
            broadcast_mean = K.reshape(mean_batch, broadcast_shape)
            broadcast_std = K.reshape(std_batch, broadcast_shape)
            broadcast_r = K.reshape(r, broadcast_shape)
            broadcast_d = K.reshape(d, broadcast_shape)
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)

            x_normed_batch = (x - broadcast_mean) / broadcast_std
            x_normed = (x_normed_batch * broadcast_r + broadcast_d) * \
                broadcast_gamma + broadcast_beta

        # explicit update to moving mean and standard deviation
        self.add_update([
            K.moving_average_update(self.running_mean, mean_batch, self.momentum),
            K.moving_average_update(self.running_std, std_batch ** 2, self.momentum)
        ], x)

        # update r_max and d_max
        t_val = K.get_value(self.t)
        r_val = self.r_max_value / (1 + (self.r_max_value - 1) * np.exp(-t_val))
        d_val = self.d_max_value / (1 + ((self.d_max_value / 1e-3) - 1) *
                                    np.exp(-(2 * t_val)))
        t_val += float(self.t_delta)

        self.add_update([
            K.update(self.r_max, r_val),
            K.update(self.d_max, d_val),
            K.update(self.t, t_val)
        ], x)

        if self.mode == 0:
            if sorted(reduction_axes) == list(range(K.ndim(x)))[:-1]:
                x_normed_running = K.batch_normalization(
                    x, self.running_mean, self.running_std,
                    self.beta, self.gamma,
                    epsilon=self.epsilon)
            else:
                # need broadcasting
                broadcast_running_mean = K.reshape(self.running_mean,
                                                   broadcast_shape)
                broadcast_running_std = K.reshape(self.running_std,
                                                  broadcast_shape)
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                x_normed_running = K.batch_normalization(
                    x, broadcast_running_mean, broadcast_running_std,
                    broadcast_beta, broadcast_gamma,
                    epsilon=self.epsilon)

            # pick the normalized form of x corresponding to the training phase;
            # for batch renormalization, inference time remains same as batchnorm
            x_normed = K.in_train_phase(x_normed, x_normed_running)

    elif self.mode == 1:
        # sample-wise normalization
        m = K.mean(x, axis=self.axis, keepdims=True)
        std = K.sqrt(K.var(x, axis=self.axis, keepdims=True) + self.epsilon)
        x_normed_batch = (x - m) / (std + self.epsilon)

        r_max_value = K.get_value(self.r_max)
        r = std / (self.running_std + self.epsilon)
        r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

        d_max_value = K.get_value(self.d_max)
        d = (m - self.running_mean) / (self.running_std + self.epsilon)
        d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

        x_normed = ((x_normed_batch * r) + d) * self.gamma + self.beta

        # update r_max and d_max
        t_val = K.get_value(self.t)
        r_val = self.r_max_value / (1 + (self.r_max_value - 1) * np.exp(-t_val))
        d_val = self.d_max_value / (1 + ((self.d_max_value / 1e-3) - 1) *
                                    np.exp(-(2 * t_val)))
        t_val += float(self.t_delta)

        self.add_update([
            K.update(self.r_max, r_val),
            K.update(self.d_max, d_val),
            K.update(self.t, t_val)
        ], x)

    return x_normed
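For reference, the correction that this snippet and the batch-renormalization layer at the end of the section implement is the one from Ioffe's Batch Renormalization paper (arXiv:1702.03275): the batch-normalized activation is rescaled by r and shifted by d, with both treated as constants under backpropagation (hence the `K.stop_gradient` calls),

\[
\hat{x} = \frac{x - \mu_B}{\sigma_B}\, r + d,
\qquad
r = \operatorname{clip}_{[1/r_{\max},\, r_{\max}]}\!\left(\frac{\sigma_B}{\sigma_{\text{moving}}}\right),
\qquad
d = \operatorname{clip}_{[-d_{\max},\, d_{\max}]}\!\left(\frac{\mu_B - \mu_{\text{moving}}}{\sigma_{\text{moving}}}\right),
\]

with \(r_{\max}\) and \(d_{\max}\) relaxed from 1 and 0 on a schedule as training progresses, which is what the `r_val`/`d_val` updates above compute.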
def call(self, inputs, training=None):
    input_shape = K.int_shape(inputs)
    ndim = len(input_shape)
    reduction_axes = list(range(ndim))
    del reduction_axes[self.axis]
    input_dim = input_shape[self.axis] // 2
    mu = K.mean(inputs, axis=reduction_axes)
    broadcast_mu_shape = [1] * len(input_shape)
    broadcast_mu_shape[self.axis] = input_shape[self.axis]
    broadcast_mu = K.reshape(mu, broadcast_mu_shape)
    if self.center:
        input_centred = inputs - broadcast_mu
    else:
        input_centred = inputs
    centred_squared = input_centred ** 2
    if (self.axis == 1 and ndim != 3) or ndim == 2:
        centred_squared_real = centred_squared[:, :input_dim]
        centred_squared_imag = centred_squared[:, input_dim:]
        centred_real = input_centred[:, :input_dim]
        centred_imag = input_centred[:, input_dim:]
    elif ndim == 3:
        centred_squared_real = centred_squared[:, :, :input_dim]
        centred_squared_imag = centred_squared[:, :, input_dim:]
        centred_real = input_centred[:, :, :input_dim]
        centred_imag = input_centred[:, :, input_dim:]
    elif self.axis == -1 and ndim == 4:
        centred_squared_real = centred_squared[:, :, :, :input_dim]
        centred_squared_imag = centred_squared[:, :, :, input_dim:]
        centred_real = input_centred[:, :, :, :input_dim]
        centred_imag = input_centred[:, :, :, input_dim:]
    elif self.axis == -1 and ndim == 5:
        centred_squared_real = centred_squared[:, :, :, :, :input_dim]
        centred_squared_imag = centred_squared[:, :, :, :, input_dim:]
        centred_real = input_centred[:, :, :, :, :input_dim]
        centred_imag = input_centred[:, :, :, :, input_dim:]
    else:
        raise ValueError(
            'Incorrect Batchnorm combination of axis and dimensions. '
            'axis should be either 1 or -1. '
            'axis: ' + str(self.axis) + '; ndim: ' + str(ndim) + '.')
    if self.scale:
        Vrr = K.mean(centred_squared_real, axis=reduction_axes) + self.epsilon
        Vii = K.mean(centred_squared_imag, axis=reduction_axes) + self.epsilon
        # Vri contains the real and imaginary covariance for each feature map.
        Vri = K.mean(centred_real * centred_imag, axis=reduction_axes)
    elif self.center:
        Vrr = None
        Vii = None
        Vri = None
    else:
        raise ValueError(
            'Error. Both scale and center in batchnorm are set to False.')

    input_bn = ComplexBN(input_centred, Vrr, Vii, Vri,
                         self.beta, self.gamma_rr, self.gamma_ri,
                         self.gamma_ii, self.scale, self.center,
                         axis=self.axis)
    if training in {0, False}:
        return input_bn
    else:
        update_list = []
        if self.center:
            update_list.append(
                K.moving_average_update(self.moving_mean, mu, self.momentum))
        if self.scale:
            update_list.append(
                K.moving_average_update(self.moving_Vrr, Vrr, self.momentum))
            update_list.append(
                K.moving_average_update(self.moving_Vii, Vii, self.momentum))
            update_list.append(
                K.moving_average_update(self.moving_Vri, Vri, self.momentum))
        self.add_update(update_list, inputs)

        def normalize_inference():
            if self.center:
                inference_centred = inputs - K.reshape(self.moving_mean,
                                                       broadcast_mu_shape)
            else:
                inference_centred = inputs
            return ComplexBN(inference_centred,
                             self.moving_Vrr, self.moving_Vii, self.moving_Vri,
                             self.beta, self.gamma_rr, self.gamma_ri,
                             self.gamma_ii, self.scale, self.center,
                             axis=self.axis)

    # Pick the normalized form corresponding to the training phase.
    return K.in_train_phase(input_bn, normalize_inference, training=training)
def call(self, inputs, training=None):
    input_shape = K.int_shape(inputs)
    # Prepare broadcasting shape.
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis]
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    def normalize_inference():

        def apply_mode_normalization_inference(moving_mean, moving_variance,
                                               beta, gamma):
            inputs_mul_gates_ = self.apply_gates(inputs, input_shape,
                                                 reduction_axes[1:])
            outputs = []
            for k_ in range(self.k):
                outputs.append(
                    K.batch_normalization(inputs_mul_gates_[:, k_],
                                          moving_mean[k_],
                                          moving_variance[k_],
                                          beta / self.k,
                                          gamma,
                                          axis=self.axis,
                                          epsilon=self.epsilon))
            return K.sum(K.stack(outputs, axis=0), axis=0)

        if needs_broadcasting:
            # In this case we must explicitly broadcast all parameters.
            broadcast_moving_mean = K.reshape(self.moving_mean,
                                              broadcast_shape)
            broadcast_moving_variance = K.reshape(self.moving_variance,
                                                  broadcast_shape)
            if self.center:
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
            else:
                broadcast_beta = None
            if self.scale:
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            else:
                broadcast_gamma = None
            return apply_mode_normalization_inference(
                broadcast_moving_mean, broadcast_moving_variance,
                broadcast_beta, broadcast_gamma)
        else:
            return apply_mode_normalization_inference(
                self.moving_mean, self.moving_variance,
                self.beta, self.gamma)

    # If the learning phase is *static* and set to inference:
    if training in {0, False}:
        return normalize_inference()

    inputs_mul_gates = self.apply_gates(inputs, input_shape,
                                        reduction_axes[1:])

    # training
    mean_list, variance_list, normed_training_list = [], [], []
    norm_func = K.normalize_batch_in_training
    for k in range(self.k):
        normed_training, mean, variance = norm_func(inputs_mul_gates[:, k],
                                                    self.gamma,
                                                    self.beta / self.k,
                                                    reduction_axes,
                                                    epsilon=self.epsilon)
        normed_training_list.append(normed_training)
        mean_list.append(mean)
        variance_list.append(variance)
    mean = K.stack(mean_list, axis=0)
    variance = K.stack(variance_list, axis=0)
    normed_training = K.sum(K.stack(normed_training_list, axis=0), axis=0)

    if K.backend() != 'cntk':
        sample_size = K.prod(
            [K.shape(inputs)[axis] for axis in reduction_axes])
        sample_size = K.cast(sample_size, dtype=K.dtype(inputs))
        # sample variance - unbiased estimator of population variance
        variance *= sample_size / (sample_size - (1.0 + self.epsilon))

    self.add_update([
        K.moving_average_update(self.moving_mean, mean, self.momentum),
        K.moving_average_update(self.moving_variance, variance, self.momentum)
    ], inputs)

    # Pick the normalized form corresponding to the training phase.
    return K.in_train_phase(normed_training, normalize_inference,
                            training=training)
def call(self, inputs, training=None):
    assert self.built, 'Layer must be built before being called'
    input_shape = K.int_shape(inputs)

    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis]

    mean_batch, var_batch = _moments(inputs, reduction_axes,
                                     shift=None, keep_dims=False)
    std_batch = K.sqrt(var_batch + self.epsilon)

    r = std_batch / (K.sqrt(self.running_variance + self.epsilon))
    r = K.stop_gradient(K.clip(r, 1 / self.r_max, self.r_max))

    d = (mean_batch - self.running_mean) / K.sqrt(self.running_variance +
                                                  self.epsilon)
    d = K.stop_gradient(K.clip(d, -self.d_max, self.d_max))

    if sorted(reduction_axes) == list(range(K.ndim(inputs)))[:-1]:
        x_normed_batch = (inputs - mean_batch) / std_batch
        x_normed = (x_normed_batch * r + d) * self.gamma + self.beta
    else:
        # need broadcasting
        broadcast_mean = K.reshape(mean_batch, broadcast_shape)
        broadcast_std = K.reshape(std_batch, broadcast_shape)
        broadcast_r = K.reshape(r, broadcast_shape)
        broadcast_d = K.reshape(d, broadcast_shape)
        broadcast_beta = K.reshape(self.beta, broadcast_shape)
        broadcast_gamma = K.reshape(self.gamma, broadcast_shape)

        x_normed_batch = (inputs - broadcast_mean) / broadcast_std
        x_normed = (x_normed_batch * broadcast_r + broadcast_d) * \
            broadcast_gamma + broadcast_beta

    # explicit update to moving mean and standard deviation
    mean_update = K.moving_average_update(self.running_mean, mean_batch,
                                          self.momentum)
    variance_update = K.moving_average_update(self.running_variance,
                                              std_batch ** 2, self.momentum)
    self.add_update([mean_update, variance_update], inputs)

    # update r_max and d_max
    r_val = self.r_max_value / (1 + (self.r_max_value - 1) * K.exp(-self.t))
    d_val = (self.d_max_value / (1 + ((self.d_max_value / 1e-3) - 1) *
                                 K.exp(-(2 * self.t))))
    self.add_update([
        K.update(self.r_max, r_val),
        K.update(self.d_max, d_val),
        K.update_add(self.t, self.t_delta_tensor)
    ], inputs)

    if training in {0, False}:
        return x_normed
    else:
        def normalize_inference():
            if sorted(reduction_axes) == list(range(K.ndim(inputs)))[:-1]:
                x_normed_running = K.batch_normalization(
                    inputs, self.running_mean, self.running_variance,
                    self.beta, self.gamma,
                    epsilon=self.epsilon)
                return x_normed_running
            else:
                # need broadcasting
                broadcast_running_mean = K.reshape(self.running_mean,
                                                   broadcast_shape)
                broadcast_running_std = K.reshape(self.running_variance,
                                                  broadcast_shape)
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                x_normed_running = K.batch_normalization(
                    inputs, broadcast_running_mean, broadcast_running_std,
                    broadcast_beta, broadcast_gamma,
                    epsilon=self.epsilon)
                return x_normed_running

        # pick the normalized form of inputs corresponding to the training phase;
        # for batch renormalization, inference time remains same as batchnorm
        x_normed = K.in_train_phase(x_normed, normalize_inference,
                                    training=training)
        return x_normed