def set_model(self, model):
    """Attach ``model`` to this callback and create its summary ops.

    Stores ``model`` and the backend session. If ``self.histogram_freq`` is
    truthy and no merged summary op exists yet (``self.merged is None``),
    registers, for every layer of the model:

    * a histogram of each weight (``':'`` in the weight name is replaced by
      ``'_'`` to form the summary name);
    * if ``self.write_grads``, a histogram of the gradients of
      ``model.total_loss`` with respect to that weight (``IndexedSlices``
      gradients are reduced to their ``values``);
    * if ``self.write_images``, an image summary of the weight for 1-D, 2-D
      and 3-D weights (other ranks are skipped);
    * a histogram of the layer output, or one histogram per output tensor
      (suffixed ``_out_<i>``) when ``layer.output`` is a list/tuple.

    All summaries are merged into ``self.merged``. Finally ``self.writer`` is
    created as a ``FileWriter`` on ``self.log_dir``, including the session
    graph when ``self.write_graph`` is set.

    # Arguments
        model: the model whose weights and layer outputs are summarized.
    """
    self.model = model
    self.sess = K.get_session()

    def is_indexed_slices(grad):
        # Sparse gradients arrive as IndexedSlices; only their dense
        # ``values`` can be fed to a histogram summary.
        return type(grad).__name__ == 'IndexedSlices'

    if self.histogram_freq and self.merged is None:
        for layer in self.model.layers:
            for weight in layer.weights:
                mapped_weight_name = weight.name.replace(':', '_')
                tf_summary.histogram(mapped_weight_name, weight)
                if self.write_grads:
                    grads = model.optimizer.get_gradients(model.total_loss,
                                                          weight)
                    grads = [
                        grad.values if is_indexed_slices(grad) else grad
                        for grad in grads]
                    tf_summary.histogram(
                        '{}_grad'.format(mapped_weight_name), grads)
                if self.write_images:
                    w_img = array_ops.squeeze(weight)
                    shape = K.int_shape(w_img)
                    if len(shape) == 2:  # dense layer kernel case
                        if shape[0] > shape[1]:
                            w_img = array_ops.transpose(w_img)
                            shape = K.int_shape(w_img)
                        w_img = array_ops.reshape(w_img,
                                                  [1, shape[0], shape[1], 1])
                    elif len(shape) == 3:  # convnet case
                        if K.image_data_format() == 'channels_last':
                            # switch to channels_first to display
                            # every kernel as a separate image
                            w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
                            shape = K.int_shape(w_img)
                        w_img = array_ops.reshape(
                            w_img, [shape[0], shape[1], shape[2], 1])
                    elif len(shape) == 1:  # bias case
                        w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
                    else:
                        # not possible to handle 3D convnets etc.
                        continue

                    shape = K.int_shape(w_img)
                    assert len(shape) == 4 and shape[-1] in [1, 3, 4]
                    tf_summary.image(mapped_weight_name, w_img)

            if hasattr(layer, 'output'):
                layer_output = layer.output
                if isinstance(layer_output, (list, tuple)):
                    # Multi-output layers: a histogram needs a single tensor,
                    # so log each output separately.
                    for i, output in enumerate(layer_output):
                        tf_summary.histogram(
                            '{}_out_{}'.format(layer.name, i), output)
                else:
                    tf_summary.histogram('{}_out'.format(layer.name),
                                         layer_output)
        self.merged = tf_summary.merge_all()

    if self.write_graph:
        self.writer = tf_summary.FileWriter(self.log_dir, self.sess.graph)
    else:
        self.writer = tf_summary.FileWriter(self.log_dir)
def sampling(args):
    """Draw a latent vector with the reparameterization trick.

    A noise tensor ``epsilon`` is sampled from an isotropic unit Gaussian and
    the result is ``z_mean + exp(0.5 * z_log_var) * epsilon``.

    # Arguments
        args (tensor): mean and log of variance of Q(z|X), as a pair
            ``(z_mean, z_log_var)``.

    # Returns
        z (tensor): sampled latent vector
    """
    mean, log_var = args
    # Batch size comes from the dynamic shape (K.shape), the latent
    # dimension from the static shape (K.int_shape).
    noise_shape = (K.shape(mean)[0], K.int_shape(mean)[1])
    # K.random_normal defaults to mean 0 and standard deviation 1.0.
    noise = K.random_normal(shape=noise_shape)
    return mean + K.exp(0.5 * log_var) * noise
def call(self, inputs, **kwargs):
    """Apply group normalization to ``inputs``.

    The channel axis ``self.axis`` (of static size ``C``) is split in place
    into ``(self.groups, C // self.groups)``:

    * channels_first ``(N, C, H, W)`` -> ``(N, G, C//G, H, W)``
    * channels_last  ``(N, H, W, C)`` -> ``(N, H, W, G, C//G)``

    Mean and variance are computed over every axis of that grouped tensor
    except the batch axis (0) and the group axis. The inputs are normalized
    as ``(x - mean) / sqrt(variance + epsilon)``. Then the optional
    per-channel ``gamma`` (``self.scale``) and ``beta`` (``self.center``) are
    applied, and the result is reshaped back to the input shape.

    # Arguments
        inputs: input tensor; the static size of ``self.axis`` must be known.
        **kwargs: unused.

    # Returns
        A tensor with the same shape as ``inputs``.
    """
    input_shape = K.int_shape(inputs)
    tensor_input_shape = K.shape(inputs)
    ndim = len(input_shape)

    # Indexing here also rejects an out-of-range axis with IndexError.
    channels_per_group = input_shape[self.axis] // self.groups
    # The group axis is inserted just before the channels-per-group axis,
    # so the reshape keeps every channel group contiguous. (Always inserting
    # at position 1 is only correct for channels_first inputs.)
    group_axis = self.axis % ndim

    group_shape = [tensor_input_shape[i] for i in range(ndim)]
    group_shape[group_axis] = channels_per_group
    group_shape.insert(group_axis, self.groups)
    group_shape = K.stack(group_shape)

    # gamma/beta of shape (C,) viewed with the same (G, C//G) split.
    broadcast_shape = [1] * ndim
    broadcast_shape[group_axis] = channels_per_group
    broadcast_shape.insert(group_axis, self.groups)

    grouped = K.reshape(inputs, group_shape)

    # Per-sample, per-group statistics: reduce everything but batch and group.
    group_reduction_axes = [i for i in range(ndim + 1)
                            if i not in (0, group_axis)]
    mean, variance = _moments(grouped, group_reduction_axes, keep_dims=True)
    outputs = (grouped - mean) / (K.sqrt(variance + self.epsilon))

    # In this case we must explicitly broadcast all parameters.
    if self.scale:
        broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
        outputs = outputs * broadcast_gamma

    if self.center:
        broadcast_beta = K.reshape(self.beta, broadcast_shape)
        outputs = outputs + broadcast_beta

    # finally we reshape the output back to the input shape
    return K.reshape(outputs, tensor_input_shape)
def call(self, inputs, training=None):
    """Instance-normalize ``inputs``.

    Axis 0 is never reduced, and neither is ``self.axis`` when it is set; the
    mean and standard deviation are taken over all remaining axes.
    ``self.epsilon`` is added to the standard deviation (not the variance).
    When enabled, ``gamma`` (``self.scale``) and ``beta`` (``self.center``)
    are reshaped to broadcast along ``self.axis``.

    # Arguments
        inputs: input tensor.
        training: unused.

    # Returns
        The normalized tensor, same shape as ``inputs``.
    """
    shape = K.int_shape(inputs)
    rank = len(shape)

    # Remove the parameter axis by position first, then the leading axis.
    axes = list(range(rank))
    if self.axis is not None:
        axes.pop(self.axis)
    axes.pop(0)

    mu = K.mean(inputs, axes, keepdims=True)
    sigma = K.std(inputs, axes, keepdims=True) + self.epsilon
    outputs = (inputs - mu) / sigma

    param_shape = [1] * rank
    if self.axis is not None:
        param_shape[self.axis] = shape[self.axis]

    if self.scale:
        outputs = outputs * K.reshape(self.gamma, param_shape)
    if self.center:
        outputs = outputs + K.reshape(self.beta, param_shape)
    return outputs
def call(self, x, mask=None):
    """Normalize ``x`` with batch renormalization.

    * ``mode`` 0 or 2:
        * Normalize with batch statistics taken over every axis except
          ``self.axis``.
        * Correct the result with the renormalization factors ``r`` and
          ``d``. These are clipped to ``[1/r_max, r_max]`` and
          ``[-d_max, d_max]`` and wrapped in ``K.stop_gradient``.
        * Scale by ``gamma`` and shift by ``beta``.
        * Register updates for ``running_mean``, for ``running_std`` (which
          receives the batch *variance*, ``std_batch**2``), and for the
          ``r_max`` / ``d_max`` / ``t`` schedule.
        * In mode 0 only, the training output is paired with a
          running-statistics output through ``K.in_train_phase``.
    * ``mode`` 1: sample-wise normalization along ``self.axis``.

    # Arguments
        x: input tensor.
        mask: unused.

    # Returns
        The normalized tensor.
    """
    if self.mode == 0 or self.mode == 2:
        assert self.built, 'Layer must be built before being called'
        input_shape = K.int_shape(x)

        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # True when only the last axis is kept, so 1-D statistics and
        # parameters broadcast without explicit reshapes. The range must be
        # materialized: on Python 3 a list never equals a range object, which
        # silently disabled both fast paths below.
        features_last = sorted(reduction_axes) == list(range(K.ndim(x)))[:-1]

        mean_batch, var_batch = _moments(x, reduction_axes, shift=None,
                                         keep_dims=False)
        std_batch = (K.sqrt(var_batch + self.epsilon))

        # NOTE(review): K.get_value reads the variable's current value while
        # the graph is built, so these clip bounds appear to be fixed
        # constants; the r_max/d_max updates below would not affect them.
        r_max_value = K.get_value(self.r_max)
        r = std_batch / (K.sqrt(self.running_std + self.epsilon))
        r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

        d_max_value = K.get_value(self.d_max)
        d = (mean_batch - self.running_mean) / K.sqrt(
            self.running_std + self.epsilon)
        d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

        if features_last:
            x_normed_batch = (x - mean_batch) / std_batch
            x_normed = (x_normed_batch * r + d) * self.gamma + self.beta
        else:
            # need broadcasting
            broadcast_mean = K.reshape(mean_batch, broadcast_shape)
            broadcast_std = K.reshape(std_batch, broadcast_shape)
            broadcast_r = K.reshape(r, broadcast_shape)
            broadcast_d = K.reshape(d, broadcast_shape)
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)

            x_normed_batch = (x - broadcast_mean) / broadcast_std
            x_normed = (x_normed_batch * broadcast_r + broadcast_d) \
                * broadcast_gamma + broadcast_beta

        # explicit update to moving mean and standard deviation
        # (running_std receives the variance, std_batch**2)
        self.add_update([
            K.moving_average_update(self.running_mean, mean_batch,
                                    self.momentum),
            K.moving_average_update(self.running_std, std_batch**2,
                                    self.momentum)
        ], x)

        # update r_max and d_max
        r_val = self.r_max_value / (
            1 + (self.r_max_value - 1) * K.exp(-self.t))
        d_val = self.d_max_value / (1 + (
            (self.d_max_value / 1e-3) - 1) * K.exp(-(2 * self.t)))

        # NOTE(review): K.variable(...) creates a new variable on every call;
        # a constant increment would presumably suffice.
        self.add_update([
            K.update(self.r_max, r_val),
            K.update(self.d_max, d_val),
            K.update_add(self.t, K.variable(np.array([self.t_delta])))
        ], x)

        if self.mode == 0:
            if features_last:
                x_normed_running = K.batch_normalization(
                    x, self.running_mean, self.running_std,
                    self.beta, self.gamma,
                    epsilon=self.epsilon)
            else:
                # need broadcasting
                broadcast_running_mean = K.reshape(self.running_mean,
                                                   broadcast_shape)
                broadcast_running_std = K.reshape(self.running_std,
                                                  broadcast_shape)
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                x_normed_running = K.batch_normalization(
                    x, broadcast_running_mean, broadcast_running_std,
                    broadcast_beta, broadcast_gamma,
                    epsilon=self.epsilon)

            # pick the normalized form of x corresponding to the training
            # phase; for batch renormalization, inference time remains the
            # same as batchnorm
            x_normed = K.in_train_phase(x_normed, x_normed_running)

    elif self.mode == 1:
        # sample-wise normalization
        m = K.mean(x, axis=self.axis, keepdims=True)
        std = K.sqrt(K.var(x, axis=self.axis, keepdims=True) + self.epsilon)
        x_normed_batch = (x - m) / (std + self.epsilon)

        # NOTE(review): unlike modes 0/2, running_std is used here as a
        # standard deviation (no sqrt), although modes 0/2 store the
        # variance in it; this mode also never updates the running stats.
        r_max_value = K.get_value(self.r_max)
        r = std / (self.running_std + self.epsilon)
        r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

        d_max_value = K.get_value(self.d_max)
        d = (m - self.running_mean) / (self.running_std + self.epsilon)
        d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

        x_normed = ((x_normed_batch * r) + d) * self.gamma + self.beta

        # update r_max and d_max
        # NOTE(review): t is read with K.get_value at graph-build time, so
        # the updates below assign values computed once, not per step.
        t_val = K.get_value(self.t)
        r_val = self.r_max_value / (
            1 + (self.r_max_value - 1) * np.exp(-t_val))
        d_val = self.d_max_value / (1 + (
            (self.d_max_value / 1e-3) - 1) * np.exp(-(2 * t_val)))
        t_val += float(self.t_delta)

        self.add_update([
            K.update(self.r_max, r_val),
            K.update(self.d_max, d_val),
            K.update(self.t, t_val)
        ], x)

    return x_normed
def __int_shape(self, x):
    """Return the shape of ``x`` through the backend-specific ``KC`` helper.

    ``KC.int_shape`` is used when ``self.backend`` is ``'tensorflow'``;
    ``KC.shape`` is used for every other backend.
    """
    if self.backend != 'tensorflow':
        return KC.shape(x)
    return KC.int_shape(x)