def _generate_dropout_mask(ones, rate, training=None, count=1):
    """Create one dropout mask, or a list of them, from `ones`.

    Args:
        ones: tensor the mask is derived from.
        rate: drop rate handed to `K.dropout`.
        training: phase flag forwarded to `K.in_train_phase`.
        count: how many masks to create.

    Returns:
        A list of `count` masks when `count > 1`, otherwise a single mask.
        Each mask is `K.dropout(ones, rate)` in the training phase and
        `ones` itself otherwise.
    """
    def _dropped():
        return K.dropout(ones, rate)

    def _single_mask():
        return K.in_train_phase(_dropped, ones, training=training)

    if not count > 1:
        return _single_mask()
    return [_single_mask() for _ in range(count)]
def call(self, inputs, training=None):
    """Dense forward pass with optional unit dropout and kernel dropout.

    Args:
        inputs: input tensor (dense, or sparse when its rank is 2 or less).
        training: phase flag; defaults to `K.learning_phase()` and is forced
            to `True` when `self.use_mc_dropout` is set.

    Returns:
        The activated (when `self.activation` is set) output tensor.
    """
    if training is None:
        training = K.learning_phase()
    if self.use_mc_dropout:
        # Monte-Carlo dropout: keep the dropout branches active regardless
        # of the phase.
        training = True

    # Unit dropout on the incoming activations.
    if 0. < self.unit_dropout < 1.:
        raw_inputs = inputs

        def _dropped_units():
            return K.dropout(raw_inputs, self.unit_dropout)

        inputs = K.in_train_phase(_dropped_units, raw_inputs,
                                  training=training)

    # Kernel dropout: a mask with the kernel's shape.
    all_ones = array_ops.ones_like(self.kernel)
    if 0. < self.kernel_dropout < 1.:
        def _dropped_connections():
            # The dropout output is multiplied by (1 - rate); presumably this
            # cancels the rescaling K.dropout applies -- TODO confirm.
            return (K.dropout(all_ones, self.kernel_dropout)
                    * (1 - self.kernel_dropout))

        kernel_mask = K.in_train_phase(_dropped_connections, all_ones,
                                       training=training)
    else:
        kernel_mask = all_ones
    masked_kernel = self.kernel * kernel_mask

    rank = len(inputs.shape)
    if rank > 2:
        # Contract the last input axis with the first kernel axis.
        outputs = standard_ops.tensordot(inputs, masked_kernel,
                                         [[rank - 1], [0]])
        if not context.executing_eagerly():
            # Restore the static shape lost by tensordot.
            static_shape = inputs.shape.as_list()
            outputs.set_shape(static_shape[:-1] + [self.units])
    else:
        inputs = math_ops.cast(inputs, self._compute_dtype)
        if K.is_sparse(inputs):
            outputs = sparse_ops.sparse_tensor_dense_matmul(inputs,
                                                            masked_kernel)
        else:
            outputs = gen_math_ops.mat_mul(inputs, masked_kernel)

    if self.use_bias:
        outputs = nn.bias_add(outputs, self.bias)
    if self.activation is None:
        return outputs
    return self.activation(outputs)  # pylint: disable=not-callable
def call(self, x, training=None):
    """Sample from the `[mean, logvar]` pair and optionally add the KL loss.

    Args:
        x: two-element sequence `[mean, logvar]`; both tensors must be rank 2.
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        `mean + exp(logvar / 2) * epsilon` in the training phase and
        `mean + 0 * logvar` otherwise. When the static batch dimension of
        either tensor is unknown, `mean + 0 * logvar` is returned at once and
        no loss is added.

    Raises:
        Exception: if `x` does not have two elements or either element is
            not rank 2.
    """
    if len(x) != 2:
        raise Exception('input layers must be a list: mean and logvar')
    if len(x[0].shape) != 2 or len(x[1].shape) != 2:
        raise Exception(
            'input shape is not a vector [batchSize, latentSize]')

    mean = x[0]
    logvar = x[1]

    def _static_batch_size(tensor):
        # shape[0] may be a Dimension object exposing `.value` or already a
        # plain int/None; handle both.
        dim = tensor.shape[0]
        return getattr(dim, 'value', dim)

    if _static_batch_size(mean) is None or _static_batch_size(logvar) is None:
        # The `0 * logvar` term keeps logvar connected to the output.
        return mean + 0 * logvar

    if self.reg is not None:
        # KL-divergence term: summed over the last axis, averaged over axis 0,
        # then weighted by beta.
        latent_loss = -0.5 * (1 + logvar - K.square(mean) - K.exp(logvar))
        latent_loss = K.sum(latent_loss, axis=-1)
        latent_loss = K.mean(latent_loss, axis=0)
        latent_loss = self.beta * latent_loss
        self.add_loss(latent_loss, x)

    def reparameterization_trick():
        epsilon = K.random_normal(shape=logvar.shape, mean=0., stddev=1.)
        stddev = K.exp(logvar * 0.5)
        return mean + stddev * epsilon

    return K.in_train_phase(reparameterization_trick, mean + 0 * logvar,
                            training=training)
def call(self, inputs, training=None):
    """Add zero-mean Gaussian noise to `inputs` in the training phase.

    Args:
        inputs: tensor to perturb.
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        `inputs` plus noise with standard deviation `self.stddev` while
        training; `inputs` unchanged otherwise.
    """
    def _with_noise():
        noise = K.random_normal(
            shape=array_ops.shape(inputs), mean=0., stddev=self.stddev)
        return inputs + noise

    return K.in_train_phase(_with_noise, inputs, training=training)
def call(self, inputs, training=None):
    """Apply alpha dropout to `inputs` in the training phase.

    When `self.rate` lies strictly between 0 and 1, a random keep-mask is
    drawn, dropped positions are replaced by `alpha_p`, and the result is
    passed through the affine map `a * x + b`. For any other rate, and outside
    the training phase, `inputs` is returned unchanged.

    Args:
        inputs: tensor to apply dropout to.
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        The transformed tensor, or `inputs` itself.
    """
    if 0. < self.rate < 1.:
        noise_shape = self._get_noise_shape(inputs)

        def dropped_inputs(inputs=inputs, rate=self.rate, seed=self.seed):  # pylint: disable=missing-docstring
            alpha = 1.6732632423543772848170429916717
            scale = 1.0507009873554804934193349852946
            alpha_p = -alpha * scale

            kept_idx = math_ops.greater_equal(
                K.random_uniform(noise_shape, seed=seed), rate)
            # Cast the mask to the dtype of the inputs rather than
            # K.floatx(), so the products below do not mix dtypes when the
            # inputs are not floatx.
            kept_idx = math_ops.cast(kept_idx, inputs.dtype)

            # Get affine transformation params
            a = ((1 - rate) * (1 + rate * alpha_p**2))**-0.5
            b = -a * alpha_p * rate

            # Apply mask
            x = inputs * kept_idx + alpha_p * (1 - kept_idx)

            # Do affine transformation
            return a * x + b

        return K.in_train_phase(dropped_inputs, inputs, training=training)
    return inputs
def call(self, x, training=None):
    """Sample from the `[mean, logvar]` pair and add the KL loss.

    Args:
        x: two-element sequence `[mean, logvar]`; both tensors must be rank 2.
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        `mean + exp(logvar / 2) * epsilon` in the training phase and
        `mean + 0 * logvar` otherwise. When the static batch dimension of
        either tensor is `None`, `mean + 0 * logvar` is returned immediately
        and no loss is added.

    Raises:
        Exception: if `x` does not have two elements or either element is
            not rank 2.
    """
    if len(x) != 2:
        raise Exception('input layers must be a list: mean and logvar')
    mean, logvar = x[0], x[1]
    if not (len(mean.shape) == 2 and len(logvar.shape) == 2):
        raise Exception(
            'input shape is not a vector [batchSize, latentSize]')

    # Pass-through used when the batch size is only known at train/eval time.
    # Keras needs the `0 * logvar` term so the gradient is not None.
    if mean.shape[0] is None or logvar.shape[0] is None:
        return mean + 0 * logvar

    # KL divergence: sum over the latent dimension, average over the batch.
    kl_elementwise = -0.5 * (1 + logvar - K.square(mean) - K.exp(logvar))
    kl_per_sample = K.sum(kl_elementwise, axis=-1)
    kl = K.mean(kl_per_sample, axis=0)

    # beta weights the KL term (a larger value forces less usage of the
    # latent space); it is fixed to 1.0 here.
    beta = 1.0
    self.add_loss(beta * kl)

    def _sample():
        noise = K.random_normal(shape=logvar.shape, mean=0., stddev=1.)
        return mean + K.exp(logvar * 0.5) * noise

    # TODO figure out why this is not working in the specified tf version???
    return K.in_train_phase(_sample, mean + 0 * logvar, training=training)
def call(self, inputs, **kwargs):
    """Project the input and score it against every row of an embedding matrix.

    Args:
        inputs: pair `(main_input, embedding_matrix)`.
        **kwargs: only `training` is read; it is forwarded to
            `K.in_train_phase` for the projection dropout.

    Returns:
        `self.activation` applied to the attention scores, reshaped to
        `(shape(main_input)[0], shape(main_input)[1], rows of the embedding
        matrix)`.
    """
    main_input, embedding_matrix = inputs
    dynamic_shape = K.shape(main_input)
    feature_dim = K.int_shape(main_input)[-1]
    emb_rows, emb_cols = K.int_shape(embedding_matrix)

    # Flatten all leading axes so the projection is a single matrix product.
    flat_input = K.reshape(main_input, (-1, feature_dim))
    projected = K.dot(flat_input, self.embedding_weights['projection'])
    if self.add_biases:
        projected = K.bias_add(projected,
                               self.embedding_weights['biases'],
                               data_format='channels_last')

    if 0 < self.projection_dropout < 1:
        undropped = projected

        def _dropped_projection():
            return K.dropout(undropped, self.projection_dropout)

        projected = K.in_train_phase(_dropped_projection, undropped,
                                     training=kwargs.get('training'))

    attention = K.dot(projected, K.transpose(embedding_matrix))
    if self.scaled_attention:
        # scaled dot-product attention, described in
        # "Attention is all you need" (https://arxiv.org/abs/1706.03762)
        attention = attention / K.constant(math.sqrt(emb_cols),
                                           dtype=K.floatx())

    output_shape = (dynamic_shape[0], dynamic_shape[1], emb_rows)
    return K.reshape(self.activation(attention), output_shape)
def call(self, inputs, training=None):
    """Perturb `inputs` with additive Gaussian noise while training.

    Args:
        inputs: tensor to perturb.
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        `inputs + N(0, self.stddev)` in the training phase, `inputs`
        otherwise.
    """
    noisy = lambda: inputs + K.random_normal(  # noqa: E731
        shape=array_ops.shape(inputs), mean=0., stddev=self.stddev)
    return K.in_train_phase(noisy, inputs, training=training)
def call(self, inputs, training=None):
    """Apply alpha dropout to `inputs` in the training phase.

    Args:
        inputs: tensor to apply dropout to.
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        `inputs` unchanged when `self.rate` is not strictly between 0 and 1
        or outside the training phase; otherwise the masked and affinely
        transformed tensor.
    """
    if not 0. < self.rate < 1.:
        return inputs

    noise_shape = self._get_noise_shape(inputs)
    rate = self.rate
    seed = self.seed

    def _alpha_dropped():
        alpha = 1.6732632423543772848170429916717
        scale = 1.0507009873554804934193349852946
        alpha_p = -alpha * scale

        # 1 where the unit is kept, 0 where it is dropped, in the input dtype.
        uniform = K.random_uniform(noise_shape, seed=seed)
        keep_mask = math_ops.cast(
            math_ops.greater_equal(uniform, rate), inputs.dtype)

        # Affine transformation parameters.
        a = ((1 - rate) * (1 + rate * alpha_p**2))**-0.5
        b = -a * alpha_p * rate

        # Dropped positions take the value alpha_p.
        masked = inputs * keep_mask + alpha_p * (1 - keep_mask)
        return a * masked + b

    return K.in_train_phase(_alpha_dropped, inputs, training=training)
def update_boost_strength(self):
    """Register an update that rescales the boost strength while training.

    The multiplier is `self.boost_strength_factor` in the training phase and
    1.0 otherwise.
    """
    multiplier = K.in_train_phase(self.boost_strength_factor, 1.0)
    rescaled = self.boost_strength * multiplier
    assign_op = self.boost_strength.assign(rescaled, read_value=False)
    self.add_update(assign_op)
def call(self, inputs, training=None):
    """Add Gaussian noise whose scale is 5% of the RMS of `inputs`.

    Args:
        inputs: tensor to perturb.
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        `inputs` plus zero-mean Gaussian noise.
    """
    rms = K.sqrt(K.mean(K.square(inputs)))
    noise_std = rms * 0.05

    def _with_noise():
        noise = K.random_normal(
            shape=array_ops.shape(inputs), mean=0., stddev=noise_std)
        return inputs + noise

    # NOTE(review): the same noisy branch is used for both phases, so noise
    # is applied at inference time too -- confirm this is intended.
    return K.in_train_phase(_with_noise, _with_noise, training=training)
def apply_dropout_if_needed(self, attention_softmax, training=None):
    """Apply dropout to the attention softmax when a usable rate is set.

    Args:
        attention_softmax: tensor of attention weights.
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        The tensor with dropout applied in the training phase when
        `0 < self.dropout < 1`; otherwise `attention_softmax` unchanged.
    """
    if not 0.0 < self.dropout < 1.0:
        return attention_softmax
    return K.in_train_phase(
        lambda: K.dropout(attention_softmax, self.dropout),
        attention_softmax,
        training=training)
def call(self, x, mask=None):
    """Apply dropout, optionally in every phase.

    Args:
        x: tensor to apply dropout to.
        mask: unused.

    Returns:
        `x` unchanged when `self.rate` is not strictly between 0 and 1.
        Otherwise `x` with dropout applied: always when `self.permanent` is
        set, and only in the training phase when it is not.
    """
    if 0. < self.rate < 1.:
        # The noise shape was previously computed but never used; forward it
        # so a configured noise shape is honoured.
        noise_shape = self._get_noise_shape(x)
        if self.permanent:
            x = K.dropout(x, self.rate, noise_shape=noise_shape)
        else:
            x = K.in_train_phase(
                K.dropout(x, self.rate, noise_shape=noise_shape), x)
    return x
def call(self, inputs, training=None):
    """Add zero-mean Gaussian noise, in the input dtype, while training.

    Args:
        inputs: tensor to perturb.
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        `inputs` plus noise with standard deviation `self.stddev` in the
        training phase, `inputs` otherwise.
    """
    def _with_noise():
        noise = tf.random.normal(
            array_ops.shape(inputs),
            mean=0.0,
            stddev=self.stddev,
            dtype=inputs.dtype,
            seed=None)
        return inputs + noise

    return K.in_train_phase(_with_noise, inputs, training=training)
def call(self, inputs, training=None):
    """Add zero-mean Gaussian noise, in the input dtype, while training.

    Args:
        inputs: tensor to perturb.
        training: phase flag forwarded to `backend.in_train_phase`.

    Returns:
        `inputs` plus noise with standard deviation `self.stddev` in the
        training phase, `inputs` otherwise.
    """
    def _with_noise():
        noise = backend.random_normal(
            shape=array_ops.shape(inputs),
            mean=0.,
            stddev=self.stddev,
            dtype=inputs.dtype)
        return inputs + noise

    return backend.in_train_phase(_with_noise, inputs, training=training)
def call(self, inputs, training=None):
    """Multiply `inputs` by Gaussian noise centred on 1 while training.

    Args:
        inputs: tensor to perturb.
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        `inputs` unchanged when `self.rate` is not strictly between 0 and 1
        or outside the training phase; otherwise `inputs` times noise with
        mean 1 and standard deviation `sqrt(rate / (1 - rate))`.
    """
    if not 0 < self.rate < 1:
        return inputs

    def _multiplicative_noise():
        sigma = np.sqrt(self.rate / (1.0 - self.rate))
        gain = K.random_normal(
            shape=array_ops.shape(inputs), mean=1.0, stddev=sigma)
        return inputs * gain

    return K.in_train_phase(_multiplicative_noise, inputs, training=training)
def call(self, inputs, training=None):
    """Apply Gaussian (multiplicative) dropout in the training phase.

    Args:
        inputs: tensor to perturb.
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        `inputs * N(1, sqrt(rate / (1 - rate)))` while training when
        `0 < self.rate < 1`; `inputs` unchanged in every other case.
    """
    rate_is_active = 0 < self.rate < 1
    if rate_is_active:
        def _scaled_by_noise():
            spread = np.sqrt(self.rate / (1.0 - self.rate))
            return inputs * K.random_normal(
                shape=array_ops.shape(inputs), mean=1.0, stddev=spread)

        result = K.in_train_phase(_scaled_by_noise, inputs,
                                  training=training)
    else:
        result = inputs
    return result
def build(self):
    """Build, compile and summarise the encoder-decoder model.

    The model is stored in `self.model`. An LSTM encoder consumes the
    embedded input sequence; its final states initialise a decoder LSTM that
    is unrolled for `self.max_decoder_length` steps. Each step is mapped to a
    softmax over the vocabulary and the per-step outputs are concatenated
    along axis 1. After compiling, the model summary is printed and a plot of
    the model is attempted.
    """
    # Embedding-layer to transform input into 3D-space.
    input_embedding = Embedding(self.data_sequence.vocab_size(), self.embedding_dim)

    # Inputs
    encoder_inputs = Input(shape=(None,))
    encoder_inputs_emb = input_embedding(encoder_inputs)

    # Encoder LSTM (only the returned states are used below)
    encoder = LSTM(self.lstm_hidden_dim, return_state=True)
    encoder_outputs, state_h, state_c = encoder(encoder_inputs_emb)
    state = [state_h, state_c]  # state will be used to initialize the decoder

    # Start vars (emulates a constant input)
    def constant(input_batch, size):
        # Tensor of ones with one row per batch element and `size` columns.
        batch_size = K.shape(input_batch)[0]
        return K.tile(K.ones((1, size)), (batch_size, 1))

    decoder_in = Lambda(constant, arguments={'size': self.embedding_dim})(encoder_inputs_emb)  # "start word"

    # Definition of further layers to be used in the model (decoder and mapping to vocab-sized vector)
    decoder_lstm = LSTM(self.lstm_hidden_dim, return_sequences=False, return_state=True)
    decoder_dense = Dense(self.data_sequence.vocab_size(), activation='softmax')

    chars = []  # Container for single results during the loop
    for i in range(self.max_decoder_length):
        # Reshape necessary to match LSTMs interface, cell state will be reintroduced in the next iteration
        # NOTE(review): outside the training phase the previous LSTM output is
        # fed back here, which appears to require
        # lstm_hidden_dim == embedding_dim -- confirm.
        decoder_in = Reshape((1, self.embedding_dim))(decoder_in)
        decoder_in, hidden_state, cell_state = decoder_lstm(decoder_in, initial_state=state)
        state = [hidden_state, cell_state]

        # Mapping
        decoder_out = decoder_dense(decoder_in)

        # Reshaping and storing for later concatenation
        char = Reshape((1, self.data_sequence.vocab_size()))(decoder_out)
        chars.append(char)

        # Teacher forcing. During training the original input will be used as input to the decoder
        # The slice x[:, -ii] with ii = i + 1 takes the embedded input
        # counted from the end of the sequence.
        decoder_in_train = Lambda(lambda x, ii: x[:, -ii], arguments={'ii': i+1})(encoder_inputs_emb)
        decoder_in = Lambda(lambda x, y: K.in_train_phase(y, x), arguments={'y': decoder_in_train})(decoder_in)

    # Single results are joined together (axis 1 vanishes)
    decoded_seq = Concatenate(axis=1)(chars)

    self.model = Model(encoder_inputs, decoded_seq, name="enc_dec")
    self.model.compile(optimizer='adam', loss='categorical_crossentropy')
    self.model.summary()

    # Plotting is best-effort: it is skipped when its dependencies are missing.
    try:
        file_name = 'enc_dec_model'
        plot_model(self.model, to_file=f'{file_name}.png', show_shapes=True)
        print(f"Model built. Saved {file_name}.png\n")
    except (ImportError, FileNotFoundError):
        print(f"Skipping plotting of model due to missing dependencies.")
def call(self, inputs, mask=None):
    """Zero out all but the largest entries of each row while training.

    Args:
        inputs: tensor to sparsify; the threshold is broadcast with
            `[:, None]`, so this assumes a rank-2 tensor -- TODO confirm.
        mask: unused.

    Returns:
        In the training phase, `inputs` with every entry that is not strictly
        greater than the selected threshold set to zero; `inputs` unchanged
        otherwise.
    """
    def _sparsified():
        width = K.shape(inputs)[-1]
        # The last dimension might be smaller than self.k; account for that.
        effective_k = tf.minimum(width - 1, self.k)
        # Threshold taken from the sorted values, effective_k places before
        # the final one.
        threshold = tf.sort(inputs)[..., width - 1 - effective_k]
        keep = K.greater(inputs, threshold[:, None])
        return inputs * K.cast(keep, K.floatx())

    return K.in_train_phase(_sparsified, inputs)
def call(self, inputs, training=None, **kwargs):
    """Run the parent layer, then keep only the k winning units.

    Args:
        inputs: input tensor, first passed through the parent `call`.
        training: phase flag; selects `self.k` (training) or
            `self.k_inference` (test) and gates the statistics updates.
        **kwargs: forwarded to the parent `call`.

    Returns:
        The output of `compute_kwinners`.
    """
    inputs = super().call(inputs, **kwargs)

    active_k = K.in_test_phase(x=self.k_inference, alt=self.k,
                               training=training)
    winners = compute_kwinners(
        x=inputs,
        k=active_k,
        duty_cycles=self.duty_cycles,
        boost_strength=self.boost_strength,
    )

    # Duty cycles are recomputed only in the training phase; otherwise the
    # current value is written back.
    def _fresh_duty_cycles():
        return self.compute_duty_cycle(winners)

    new_duty_cycles = K.in_train_phase(
        _fresh_duty_cycles, self.duty_cycles, training=training)
    self.add_update(self.duty_cycles.assign(new_duty_cycles,
                                            read_value=False))

    # The iteration counter grows by the batch size while training, by 0
    # otherwise.
    batch_increment = K.in_train_phase(K.shape(inputs)[0], 0,
                                       training=training)
    self.add_update(
        self.learning_iterations.assign_add(batch_increment,
                                            read_value=False))
    return winners
def _generate_dropout_mask(self, inputs, training=None):
    """Store four input-dropout masks in `self._dropout_mask`.

    Args:
        inputs: tensor indexed as `inputs[:, 0:1, :]`; the mask takes the
            shape of that slice with axis 1 squeezed out.
        training: phase flag forwarded to `K.in_train_phase`.

    `self._dropout_mask` is set to `None` when `self.dropout` is not
    strictly between 0 and 1.
    """
    if not 0 < self.dropout < 1:
        self._dropout_mask = None
        return

    first_step = K.squeeze(inputs[:, 0:1, :], axis=1)
    ones = K.ones_like(first_step)

    def _dropped():
        return K.dropout(ones, self.dropout)

    masks = []
    for _ in range(4):
        masks.append(K.in_train_phase(_dropped, ones, training=training))
    self._dropout_mask = masks
def get_constants(self, inputs, training=None):
    """Build the input and recurrent dropout masks.

    Args:
        inputs: input tensor; summed over axis 1 to derive the mask shapes.
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        A two-element list `[input_masks, recurrent_masks]`. Each element is
        a list of four masks, or four scalar ones when the corresponding
        dropout is inactive.
    """
    def _four_masks(ones, rate_getter):
        def _dropped():
            return K.dropout(ones, rate_getter())

        return [
            K.in_train_phase(_dropped, ones, training=training)
            for _ in range(4)
        ]

    def _four_scalar_ones():
        return [K.cast_to_floatx(1.) for _ in range(4)]

    constants = []

    # Input dropout masks (only for implementation 0).
    if self.implementation == 0 and 0 < self.dropout < 1:
        ones = K.sum(K.zeros_like(inputs), axis=1)
        ones = ones + 1
        constants.append(_four_masks(ones, lambda: self.dropout))
    else:
        constants.append(_four_scalar_ones())

    # Recurrent dropout masks, shaped like the output of the input
    # convolution.
    if 0 < self.recurrent_dropout < 1:
        depthwise_shape = list(self.depthwise_kernel_shape)
        pointwise_shape = list(self.pointwise_kernel_shape)
        ones = K.sum(K.zeros_like(inputs), axis=1)
        ones = self.input_conv(ones,
                               K.zeros(depthwise_shape),
                               K.zeros(pointwise_shape),
                               padding=self.padding)
        ones = ones + 1.
        constants.append(_four_masks(ones, lambda: self.recurrent_dropout))
    else:
        constants.append(_four_scalar_ones())

    return constants
def call(self, inputs, training=None):
    """Randomly zero whole samples of the batch in the training phase.

    Args:
        inputs: tensor to transform; the random mask has shape
            `[batch, 1, 1, 1]`.
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        While training, `inputs / keep_prob` multiplied by a per-sample 0/1
        mask; `inputs` unchanged otherwise.
    """
    def _per_sample_drop():
        keep_prob = 1.0 - self.drop_connect_rate
        batch = tf.shape(inputs)[0]
        # floor(keep_prob + uniform) yields the 0/1 mask.
        uniform = tf.random.uniform([batch, 1, 1, 1], dtype=inputs.dtype)
        keep_mask = tf.floor(keep_prob + uniform)
        return tf.divide(inputs, keep_prob) * keep_mask

    return K.in_train_phase(_per_sample_drop, inputs, training=training)
def call(self, inputs, **kwargs):
    """Dense forward pass using a DropConnect-ed kernel while training.

    Args:
        inputs: input tensor (dense, or sparse when its rank is 2 or less).
        **kwargs: only `training` is read (default `False`).

    Returns:
        The output tensor after bias, optional scaling and activation.

    Side effects:
        Sets `self.kernelDC` to the kernel used for this call.
    """
    is_training = kwargs.get('training', False)

    # The effective kernel lives in a separate attribute because overwriting
    # `self.kernel` fails, hence the "DC" suffix.
    if self.dropconnect_prob > 0.0:
        self.kernelDC = K.in_train_phase(
            lambda: dropconnect(self.kernel, self.dropconnect_prob),
            self.kernel,
            training=is_training)
    else:
        self.kernelDC = self.kernel

    # Kernel application; this part came from Dense().
    rank = len(inputs.shape)
    if rank > 2:
        # Contract the last input axis with the first kernel axis.
        outputs = standard_ops.tensordot(inputs, self.kernelDC,
                                         [[rank - 1], [0]])
        if not context.executing_eagerly():
            # Restore the static shape lost by tensordot.
            static_shape = inputs.shape.as_list()
            outputs.set_shape(static_shape[:-1] + [self.units])
    else:
        inputs = math_ops.cast(inputs, self._compute_dtype)
        if K.is_sparse(inputs):
            outputs = sparse_ops.sparse_tensor_dense_matmul(inputs,
                                                            self.kernelDC)
        else:
            outputs = gen_math_ops.mat_mul(inputs, self.kernelDC)

    if self.use_bias:
        outputs = nn.bias_add(outputs, self.bias)
    if self.scale:
        outputs = self.scaler(outputs)
    if self.activation is not None:
        outputs = self.activation(outputs)  # pylint: disable=not-callable
    return outputs
def _time_distributed_dense(x, w, b=None, dropout=None, input_dim=None, output_dim=None, timesteps=None, training=None):
    """Apply `y . w + b` to every temporal slice `y` of `x`.

    # Arguments
        x: input tensor.
        w: weight matrix.
        b: optional bias vector.
        dropout: optional dropout rate; the same dropout mask is used for
            every temporal slice of the input.
        input_dim: integer; optional dimensionality of the input.
        output_dim: integer; optional dimensionality of the output.
        timesteps: integer; optional number of timesteps.
        training: training phase tensor or boolean.

    # Returns
        Output tensor.
    """
    # Fall back to shapes read from the tensors for anything not supplied.
    input_dim = input_dim or K.shape(x)[2]
    timesteps = timesteps or K.shape(x)[1]
    output_dim = output_dim or K.int_shape(w)[1]

    if dropout is not None and 0. < dropout < 1.:
        # One mask, built from the first timestep, repeated over time.
        first_slice = K.reshape(x[:, 0, :], (-1, input_dim))
        mask = K.dropout(K.ones_like(first_slice), dropout)
        mask_over_time = K.repeat(mask, timesteps)
        x = K.in_train_phase(x * mask_over_time, x, training=training)

    # Merge batch and time axes, apply the dense map, then split them again.
    flat = K.dot(K.reshape(x, (-1, input_dim)), w)
    if b is not None:
        flat = K.bias_add(flat, b)

    if K.backend() != 'tensorflow':
        return K.reshape(flat, (-1, timesteps, output_dim))
    result = K.reshape(flat, K.stack([-1, timesteps, output_dim]))
    result.set_shape([None, None, output_dim])
    return result
def _generate_recurrent_dropout_mask(self, inputs, training=None):
    """Store four recurrent-dropout masks in `self._recurrent_dropout_mask`.

    Args:
        inputs: tensor indexed as `inputs[:, 0, 0]`; it supplies the number
            of rows of the mask, which has `self.units` columns.
        training: phase flag forwarded to `K.in_train_phase`.

    `self._recurrent_dropout_mask` is set to `None` when
    `self.recurrent_dropout` is not strictly between 0 and 1.
    """
    if 0 < self.recurrent_dropout < 1:
        ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
        ones = K.tile(ones, (1, self.units))

        def dropped_inputs():
            # Use the recurrent rate here: the guard above checks
            # `recurrent_dropout`, and `self.dropout` is the input rate.
            return K.dropout(ones, self.recurrent_dropout)

        self._recurrent_dropout_mask = [
            K.in_train_phase(dropped_inputs, ones, training=training)
            for _ in range(4)
        ]
    else:
        self._recurrent_dropout_mask = None
def dot_product_attention(self, x, mask=None, dropout=0.1, training=None):
    """Compute dot-product attention with dropout on the attention weights.

    Args:
        x: triple `(q, k, v)` of query, key and value tensors.
        mask: optional mask; two singleton axes are inserted at position 1
            before it is handed to `self.mask_logits`.
        dropout: drop rate applied to the attention weights while training.
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        The product of the (possibly dropped) attention weights with `v`.
    """
    q, k, v = x
    logits = tf.matmul(q, k, transpose_b=True)
    if self.bias:
        logits = logits + self.b

    if mask is not None:
        # Insert two axes after the first so the mask can broadcast against
        # the logits.
        expanded = tf.expand_dims(tf.expand_dims(mask, axis=1), axis=1)
        logits = self.mask_logits(logits, expanded)

    attention = tf.nn.softmax(logits, name="attention_weights")
    dropped = K.dropout(attention, dropout)
    attention = K.in_train_phase(dropped, attention, training=training)
    return tf.matmul(attention, v)
def center(inputs, moving_mean, w, h, c, instance_norm=False):
    """Flatten `inputs` channel-first and subtract a mean.

    Args:
        inputs: rank-4 tensor whose axis 3 is moved to the front (or to
            position 1 for instance normalisation).
        moving_mean: mean used instead of the batch mean outside the training
            phase (ignored when `instance_norm` is set).
        w, h, c: sizes used to reshape the transposed tensor.
        instance_norm: when true, the mean is taken per sample and channel.

    Returns:
        Tuple `(m, f)`: the mean that was subtracted and the centred,
        flattened tensor.
    """
    if instance_norm:
        per_sample = tf.transpose(inputs, (0, 3, 1, 2))
        flat = tf.reshape(per_sample, (-1, c, w * h))  # (bs, c, w*h)
        m = tf.reduce_mean(flat, axis=2, keepdims=True)  # (bs, c, 1)
        return m, flat - m

    per_channel = tf.transpose(inputs, (3, 0, 1, 2))
    flat = tf.reshape(per_channel, (c, -1))  # (c, bs*w*h)
    batch_mean = tf.reduce_mean(flat, axis=1, keepdims=True)
    # Batch mean while training, moving mean otherwise.
    m = K.in_train_phase(batch_mean, moving_mean)  # (c, 1)
    return m, flat - m
def apply_dropout_if_needed(self, attention_softmax, training=None):
    """Apply dropout after the attention softmax if a usable rate is set.

    :param attention_softmax: tensor of attention weights
    :param training: phase flag forwarded to `K.in_train_phase`
    :return: the tensor with dropout applied in the training phase when
        `0 < self.dropout < 1`, otherwise the unchanged input
    """
    dropout_is_active = 0.0 < self.dropout < 1.0
    if not dropout_is_active:
        return attention_softmax

    def _dropped():
        return K.dropout(attention_softmax, self.dropout)

    return K.in_train_phase(_dropped, attention_softmax, training=training)
def W_bar(self):
    """Return the kernel divided by its estimated spectral norm.

    The estimate comes from `self.Ip` rounds of power iteration started from
    `self.u`.

    Returns:
        `self.kernel / sigma`.

    Raises:
        ValueError: if `self.Ip` is not at least 1.
    """
    # Flatten the kernel to a matrix:
    # (h, w, i, o) => (o, i, h, w) => (o, i * h * w)
    permuted = K.permute_dimensions(self.kernel, (3, 2, 0, 1))
    W_mat = K.reshape(permuted, [K.shape(permuted)[0], -1])

    if not self.Ip >= 1:
        raise ValueError(
            "The number of power iterations should be positive integer")

    # Power iteration.
    u_est, v_est = self.u, None
    for _ in range(self.Ip):
        v_est = _l2normalize(K.dot(u_est, W_mat))
        u_est = _l2normalize(K.dot(v_est, K.transpose(W_mat)))

    sigma = K.sum(K.dot(u_est, W_mat) * v_est)

    # NOTE(review): the value returned by K.update is discarded; whether the
    # update to self.u actually takes effect depends on the backend/execution
    # mode -- confirm.
    K.update(self.u, K.in_train_phase(u_est, self.u))
    return self.kernel / sigma
def call(self, inputs, training=None):
    """Add Gaussian noise in normalised space while training.

    The input is normalised with its own mean and standard deviation over
    `reduction_axes`, perturbed in the training phase, then mapped back with
    the same statistics. Outside the training phase the round trip leaves the
    input unchanged up to the effect of `self.epsilon`.

    Args:
        inputs: tensor to perturb.
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        The de-normalised tensor.
    """
    input_shape = K.int_shape(inputs)
    reduction_axes = list(range(0, len(input_shape)))

    # Statistics are computed over every axis except `self.axis` (when set)
    # and the first remaining axis.
    if self.axis is not None:
        del reduction_axes[self.axis]
    del reduction_axes[0]

    mean = K.mean(inputs, reduction_axes, keepdims=True)
    stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
    normed = (inputs - mean) / stddev

    def noised():
        # The noise scale is itself random, drawn uniformly up to self.alpha.
        eps = K.random_uniform(shape=[1], maxval=self.alpha)
        # Perturb the *normalised* tensor: the result is de-normalised below,
        # so adding noise to the raw inputs would apply the scale and shift a
        # second time in the training phase.
        return normed + K.random_normal(
            shape=K.shape(inputs), mean=0., stddev=eps)

    get_noised = K.in_train_phase(noised, normed, training=training)
    retrived = stddev * get_noised + mean
    return retrived
def call(self, x, training=None):
    """Apply the truncation trick around a moving mean.

    In the training phase the moving mean is updated from `x[:, 0]` and the
    output is built from the updated value; otherwise `outputs_inference`
    interpolates `x` towards the stored moving mean.

    Args:
        x: input tensor; axis 1 is treated as the layer axis (its static size
            is read as `num_layers`).
        training: phase flag forwarded to `K.in_train_phase`.

    Returns:
        The tensor selected by `K.in_train_phase`.
    """
    def outputs_inference():
        # Apply truncation trick according to cutoff.
        num_layers = K.int_shape(x)[1]
        if self.cutoff is not None:
            # beta is self.psi for layer indices below the cutoff, 1 elsewhere.
            beta = Ke.where(
                np.arange(num_layers)[np.newaxis, :, np.newaxis] < self.cutoff,
                self.psi * np.ones(shape=(1, num_layers, 1), dtype=np.float32),
                np.ones(shape=(1, num_layers, 1), dtype=np.float32))  #?
        else:
            beta = np.ones(shape=(1, num_layers, 1), dtype=np.float32)
        return self.moving_mean + (x - self.moving_mean) * beta  #?

    # Update moving average.
    mean = K.mean(x[:, 0], axis=0)  #?
    # NOTE(review): the update is not registered through add_update here (the
    # original author left the same open question) -- confirm it is applied.
    x_moving_mean = K.moving_average_update(self.moving_mean, mean, self.momentum)  #? add_update?

    # Apply truncation trick according to cutoff.
    num_layers = K.int_shape(x)[1]
    if self.cutoff is not None:
        beta = Ke.where(
            np.arange(num_layers)[np.newaxis, :, np.newaxis] < self.cutoff,
            self.psi * np.ones(shape=(1, num_layers, 1), dtype=np.float32),
            np.ones(shape=(1, num_layers, 1), dtype=np.float32))  #?
    else:
        beta = np.ones(shape=(1, num_layers, 1), dtype=np.float32)
    # The training output adds the result of the moving-average update, while
    # the offset is still measured from self.moving_mean.
    outputs = x_moving_mean + (x - self.moving_mean) * beta  #?

    return K.in_train_phase(outputs, outputs_inference, training=training)
def call(self, inputs, training=None):
    """Batch-normalise the input, with scale and offset gathered by class.

    Args:
        inputs: pair `(tensor, class_labels)`; the labels are squeezed on
            axis 1 and used to gather rows of `self.gamma` / `self.beta`.
        training: phase flag. A static `0`/`False` returns the batch-normalised
            tensor immediately.

    Returns:
        The normalised tensor, multiplied by the gathered gamma when
        `self.scale` is set and shifted by the gathered beta when
        `self.center` is set (except on the static-inference early return).
    """
    class_labels = K.squeeze(inputs[1], axis=1)
    inputs = inputs[0]
    input_shape = K.int_shape(inputs)

    # Prepare broadcasting shape.
    ndim = len(input_shape)
    reduction_axes = list(range(len(input_shape)))
    del reduction_axes[self.axis]
    broadcast_shape = [1] * len(input_shape)
    broadcast_shape[self.axis] = input_shape[self.axis]

    # Determines whether broadcasting is needed. The range is materialised as
    # a list: on Python 3 a list never compares equal to a range object, which
    # made this flag always True.
    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

    # Per-sample parameters: one gathered row per batch element, broadcast
    # over every axis except `self.axis`.
    param_broadcast = [1] * len(input_shape)
    param_broadcast[self.axis] = input_shape[self.axis]
    param_broadcast[0] = K.shape(inputs)[0]
    if self.scale:
        broadcast_gamma = K.reshape(K.gather(self.gamma, class_labels),
                                    param_broadcast)
    else:
        broadcast_gamma = None
    if self.center:
        broadcast_beta = K.reshape(K.gather(self.beta, class_labels),
                                   param_broadcast)
    else:
        broadcast_beta = None

    normed, mean, variance = K.normalize_batch_in_training(
        inputs, gamma=None, beta=None,
        reduction_axes=reduction_axes, epsilon=self.epsilon)

    if training in {0, False}:
        # NOTE(review): this returns the batch-statistics result without
        # gamma/beta and without using the moving statistics -- confirm this
        # is intended.
        return normed

    self.add_update([
        K.moving_average_update(self.moving_mean, mean, self.momentum),
        K.moving_average_update(self.moving_variance, variance,
                                self.momentum)
    ], inputs)

    def normalize_inference():
        if needs_broadcasting:
            # In this case we must explicitly broadcast all parameters.
            broadcast_moving_mean = K.reshape(self.moving_mean,
                                              broadcast_shape)
            broadcast_moving_variance = K.reshape(self.moving_variance,
                                                  broadcast_shape)
            return K.batch_normalization(inputs,
                                         broadcast_moving_mean,
                                         broadcast_moving_variance,
                                         beta=None,
                                         gamma=None,
                                         epsilon=self.epsilon)
        return K.batch_normalization(inputs,
                                     self.moving_mean,
                                     self.moving_variance,
                                     beta=None,
                                     gamma=None,
                                     epsilon=self.epsilon)

    # Pick the normalized form corresponding to the training phase.
    out = K.in_train_phase(normed, normalize_inference, training=training)

    # Apply the class-conditional parameters only when they exist; with
    # scale=False or center=False the operand is None.
    if broadcast_gamma is not None:
        out = out * broadcast_gamma
    if broadcast_beta is not None:
        out = out + broadcast_beta
    return out
def call(self, inputs, training=None):
    """Whiten (decorrelate) the input group-wise.

    The input is centred, a whitening matrix is obtained either from the
    current batch (`train`) or from the stored moving matrix (`test`), and
    the centred features are multiplied by it and reshaped back to
    `(bs, w, h, c)`.

    Args:
        inputs: rank-4 tensor; its static shape is unpacked as `_, w, h, c`.
        training: accepted but not used in the body. NOTE(review): the phase
            is taken from `K.in_train_phase`'s default instead -- confirm.

    Returns:
        The decorrelated tensor.
    """
    _, w, h, c = K.int_shape(inputs)
    bs = K.shape(inputs)[0]
    # NOTE(review): three arguments are passed here; if `utils.center` has the
    # signature (inputs, moving_mean, w, h, c, instance_norm) this call does
    # not match it -- verify against utils.
    m, f = utils.center(inputs, self.moving_mean, self.instance_norm)
    get_inv_sqrt = utils.get_decomposition(self.decomposition, bs, self.group,
                                           self.instance_norm, self.iter_num,
                                           self.epsilon, self.device)

    def train():
        # Group covariance of the centred features, blended with the identity
        # using weight epsilon.
        ff_aprs = utils.get_group_cov(f, self.group, self.m_per_group,
                                      self.instance_norm, bs, w, h, c)
        if self.instance_norm:
            ff_aprs = tf.transpose(ff_aprs, (1, 0, 2, 3))
            ff_aprs = (1 - self.epsilon) * ff_aprs + tf.expand_dims(
                tf.expand_dims(tf.eye(self.m_per_group) * self.epsilon, 0), 0)
        else:
            ff_aprs = (1 - self.epsilon) * ff_aprs + tf.expand_dims(
                tf.eye(self.m_per_group) * self.epsilon, 0)
        whitten_matrix = get_inv_sqrt(ff_aprs, self.m_per_group)[1]
        # Track the moving mean, and either the whitening matrix or the
        # covariance depending on whether the decomposition name contains
        # '_wm'.
        self.add_update([
            K.moving_average_update(self.moving_mean, m, self.momentum),
            K.moving_average_update(
                self.moving_matrix,
                whitten_matrix if '_wm' in self.decomposition else ff_aprs,
                self.momentum)
        ], inputs)
        if self.renorm:
            # Renormalisation: combine the moving-matrix result with the batch
            # result, with no gradient flowing through `l`.
            l, l_inv = get_inv_sqrt(ff_aprs, self.m_per_group)
            ff_mov = (1 - self.epsilon) * self.moving_matrix + tf.eye(
                self.m_per_group) * self.epsilon
            _, l_mov_inverse = get_inv_sqrt(ff_mov, self.m_per_group)
            l_ndiff = K.stop_gradient(l)
            return tf.matmul(tf.matmul(l_mov_inverse, l_ndiff), l_inv)
        return whitten_matrix

    def test():
        # Same identity blend, applied to the stored moving matrix.
        moving_matrix = (1 - self.epsilon) * self.moving_matrix + tf.eye(
            self.m_per_group) * self.epsilon
        if '_wm' in self.decomposition:
            return moving_matrix
        else:
            return get_inv_sqrt(moving_matrix, self.m_per_group)[1]

    if self.instance_norm == 1:
        # Instance normalisation always uses the batch path.
        inv_sqrt = train()
        f = tf.reshape(f, [-1, self.group, self.m_per_group, w * h])
        f_hat = tf.matmul(inv_sqrt, f)
        decorelated = K.reshape(f_hat, [bs, c, w, h])
        decorelated = tf.transpose(decorelated, [0, 2, 3, 1])
    else:
        inv_sqrt = K.in_train_phase(train, test)
        f = tf.reshape(f, [self.group, self.m_per_group, -1])
        f_hat = tf.matmul(inv_sqrt, f)
        decorelated = K.reshape(f_hat, [c, bs, w, h])
        decorelated = tf.transpose(decorelated, [1, 2, 3, 0])

    return decorelated