def on_train_batch_end(self, batch, logs=None):
    """Lookahead callback hook, run after every training batch.

    On the first call, a non-trainable "slow" copy of every trainable
    weight is created.  Afterwards, every ``self.k`` batches the slow
    weights are pulled towards the fast (trained) weights and the fast
    weights are reset to the slow ones:
        slow <- slow + alpha * (fast - slow);  fast <- slow
    """
    self.count += 1
    if self.slow_weights is None:
        # First batch: clone each trainable weight into a slow variable
        # with the same initial value.
        with tf.control_dependencies(self.model.trainable_weights):
            self.slow_weights = []
            for fast_param in self.model.trainable_weights:
                with ops.control_dependencies([fast_param]):
                    slow_param = tf.Variable(
                        fast_param.initialized_value(),
                        dtype=fast_param.dtype,
                        trainable=False,
                        name=fast_param.name.split(":")[0])
                self.slow_weights.append(slow_param)
                K.track_variable(slow_param)
    else:
        if self.count % self.k == 0:
            slow_ups, fast_ups = [], []
            for fast, slow in zip(self.model.trainable_weights,
                                  self.slow_weights):
                slow_ups.append(K.update(slow,
                                         slow + self.alpha * (fast - slow)))
            # Overwrite the fast weights only after the slow updates ran.
            with tf.control_dependencies(slow_ups):
                for fast, slow in zip(self.model.trainable_weights,
                                      self.slow_weights):
                    fast_ups.append(K.update(fast, slow))
            # batch_get_value forces the update ops to actually execute.
            K.batch_get_value(slow_ups)
            K.batch_get_value(fast_ups)
def get_updates(self, loss, params):
    """Adam/AMSGrad update rules with optional per-variable learning-rate
    multipliers.

    ``self.multipliers`` maps name substrings to scale factors; a
    parameter whose name contains such a substring has its bias-corrected
    learning rate scaled by the corresponding factor.
    """
    grads = self.get_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]
    lr = self.lr
    if self.initial_decay > 0:
        lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
                                                  K.dtype(self.decay))))
    t = K.cast(self.iterations, K.floatx()) + 1
    # Bias-corrected learning rate (Adam paper).
    lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                 (1. - K.pow(self.beta_1, t)))
    ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    if self.amsgrad:
        vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    else:
        # Dummy slots keep the weights-list layout fixed.
        vhats = [K.zeros(1) for _ in params]
    self.weights = [self.iterations] + ms + vs + vhats
    for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
        # Learning rate multipliers
        if self.multipliers:
            multiplier = [
                mult for mult in self.multipliers if mult in p.name
            ]
        else:
            multiplier = None
        if multiplier:
            new_lr_t = lr_t * self.multipliers[multiplier[0]]
            if self.debug_verbose:
                print('Setting {} to learning rate {}'.format(
                    multiplier[0], new_lr_t))
                print(K.get_value(new_lr_t))
        else:
            new_lr_t = lr_t
            if self.debug_verbose:
                print('No change in learning rate {}'.format(p.name))
                print(K.get_value(new_lr_t))
        m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
        v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
        if self.amsgrad:
            vhat_t = K.maximum(vhat, v_t)
            p_t = p - new_lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
            self.updates.append(K.update(vhat, vhat_t))
        else:
            p_t = p - new_lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
        self.updates.append(K.update(m, m_t))
        self.updates.append(K.update(v, v_t))
        new_p = p_t
        # Apply constraints.
        if getattr(p, 'constraint', None) is not None:
            new_p = p.constraint(new_p)
        self.updates.append(K.update(p, new_p))
    return self.updates
def get_updates(self, loss, params):
    """AdaBound update rules: Adam with dynamic per-weight learning-rate
    bounds that converge towards ``final_lr`` (transitioning to SGD-like
    behavior as t grows).
    """
    grads = self.get_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]
    lr = self.lr
    if self.initial_decay > 0:
        lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
                                                  K.dtype(self.decay))))
    t = K.cast(self.iterations, K.floatx()) + 1
    # Applies bounds on actual learning rate
    step_size = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                      (1. - K.pow(self.beta_1, t)))
    # Bounds tighten around final_lr as t grows; gamma sets the speed.
    final_lr = self.final_lr * lr / self.base_lr
    lower_bound = final_lr * (1. - 1. / (self.gamma * t + 1.))
    upper_bound = final_lr * (1. + 1. / (self.gamma * t))
    ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    if self.amsbound:
        vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    else:
        # Dummy slots keep the weights-list layout fixed.
        vhats = [K.zeros(1) for _ in params]
    self.weights = [self.iterations] + ms + vs + vhats
    for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
        # apply weight decay
        if self.weight_decay != 0.:
            g += self.weight_decay * K.stop_gradient(p)
        m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
        v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
        if self.amsbound:
            vhat_t = K.maximum(vhat, v_t)
            denom = (K.sqrt(vhat_t) + self.epsilon)
            self.updates.append(K.update(vhat, vhat_t))
        else:
            denom = (K.sqrt(v_t) + self.epsilon)
        # Compute the bounds
        step_size_p = step_size * K.ones_like(denom)
        step_size_p_bound = step_size_p / denom
        # Clip the per-weight step into [lower_bound, upper_bound].
        bounded_lr_t = m_t * K.minimum(
            K.maximum(step_size_p_bound, lower_bound), upper_bound)
        p_t = p - bounded_lr_t
        self.updates.append(K.update(m, m_t))
        self.updates.append(K.update(v, v_t))
        new_p = p_t
        # Apply constraints.
        if getattr(p, 'constraint', None) is not None:
            new_p = p.constraint(new_p)
        self.updates.append(K.update(p, new_p))
    return self.updates
def get_updates(self, loss, params):
    """Momentum SGD with a derivative (Kd) correction term.

    Keeps three slots per parameter: velocity, a smoothed gradient
    difference, and the previous gradient.  The step applied is
    ``velocity + kd * smoothed_diff``.
    """
    grads = self.get_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]
    param_shapes = [K.int_shape(p) for p in params]
    prev_grads = [K.zeros(s, name='prev_grad_' + str(idx))
                  for idx, s in enumerate(param_shapes)]
    ds = [K.zeros(s, name='d_' + str(idx))
          for idx, s in enumerate(param_shapes)]
    vs = [K.zeros(s, name='v_' + str(idx))
          for idx, s in enumerate(param_shapes)]
    self.weights = [self.iterations] + ds + vs + prev_grads
    for param, grad, old_grad, vel, deriv in zip(params, grads,
                                                 prev_grads, vs, ds):
        # Classic momentum velocity.
        new_vel = self.momentum * vel - self.lr * grad
        # Exponentially smoothed gradient difference (the "D" term).
        new_deriv = self.momentum * deriv + (1 - self.momentum) * (grad -
                                                                   old_grad)
        stepped = param + new_vel + self.kd * new_deriv
        self.updates.append(K.update(vel, new_vel))
        self.updates.append(K.update(deriv, new_deriv))
        self.updates.append(K.update(old_grad, grad))
        self.updates.append(K.update(param, stepped))
    return self.updates
def _update_s_matrix_stats(self, num_of_complex_params_t, s):
    """Refresh the S-matrix diagnostics (rank and smallest kept eigenvalue).

    Args:
        num_of_complex_params_t: scalar tensor, number of complex
            parameters (used as the matrix dimension in the tolerance).
        s: the stochastic-reconfiguration (covariance) matrix, Hermitian.

    Returns:
        (min_eigval_update_op, rank_update_op); two ``tf.no_op()`` when
        stats collection is disabled.
    """
    if not self.add_s_matrix_stats:
        return tf.no_op(), tf.no_op()
    abs_eigvals = tf.math.abs(tf.linalg.eigvalsh(s))
    # Rank tolerance follows numpy.linalg.matrix_rank:
    # eps * dim * max(|s|).
    # see https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.matrix_rank.html
    tol = K.epsilon() * tf.cast(
        num_of_complex_params_t,
        abs_eigvals.dtype) * tf.math.reduce_max(tf.math.abs(s))
    filtered_eigvals = tf.boolean_mask(abs_eigvals, abs_eigvals > tol)
    # Fix: use tf.math.count_nonzero — the top-level tf.count_nonzero
    # alias is deprecated; this also matches the tf.math.* style used
    # throughout this function.
    updated_s_matrix_rank = K.update(
        self.s_matrix_rank, tf.math.count_nonzero(filtered_eigvals))
    updated_s_matrix_min_eigval = K.update(
        self.s_matrix_min_eigval, tf.math.reduce_min(filtered_eigvals))
    return updated_s_matrix_min_eigval, updated_s_matrix_rank
def get_updates(self, loss, params):
    """AdamW-style update with an lr multiplier for every variable whose
    name is NOT listed in ``self.excluded_vars``, plus normalized
    decoupled weight decay.
    """
    grads = self.get_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]
    lr = self.lr
    if self.initial_decay > 0:
        lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
                                                  K.dtype(self.decay))))
    t = K.cast(self.iterations, K.floatx()) + 1
    '''Bias corrections according to the Adam paper
    '''
    lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                 (1. - K.pow(self.beta_1, t)))
    ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    self.weights = [self.iterations] + ms + vs
    for p, g, m, v in zip(params, grads, ms, vs):
        ####################################################
        # Add a lr multiplier for vars outside excluded_vars
        if p.name in self.excluded_vars:
            multiplied_lr_t = lr_t
        else:
            multiplied_lr_t = lr_t * self.lr_mult
        ###################################################
        m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
        v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
        '''Schedule multiplier eta_t = 1 for simple AdamW
        According to the AdamW paper, eta_t can be fixed, decay,
        or also be used for warm restarts (AdamWR to come).
        '''
        eta_t = 1.
        p_t = p - eta_t * (multiplied_lr_t * m_t /
                           (K.sqrt(v_t) + self.epsilon))
        if self.weight_decay != 0:
            '''Normalized weight decay according to the AdamW paper
            '''
            w_d = self.weight_decay * K.sqrt(
                self.batch_size / (self.samples_per_epoch * self.epochs))
            p_t = p_t - eta_t * (w_d * p)
        self.updates.append(K.update(m, m_t))
        self.updates.append(K.update(v, v_t))
        new_p = p_t
        # Apply constraints.
        if getattr(p, 'constraint', None) is not None:
            new_p = p.constraint(new_p)
        self.updates.append(K.update(p, new_p))
    return self.updates
def get_updates_Padam(self, loss, params):
    """Padam (partially adaptive momentum) update rules.

    Like Adam/AMSGrad, but the denominator is raised to
    ``2 * self.partial`` instead of 1 (partial adaptivity), with an
    optional per-parameter multiplier from ``self._get_multiplier``.
    """
    grads = self.get_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]
    base_lr = self._optimizer.learning_rate
    if self.initial_decay > 0:
        base_lr = base_lr * (1. / (1. + self.decay * K.cast(
            self.iterations, K.dtype(self.decay))))
    t = K.cast(self.iterations, K.floatx()) + 1
    # Bias-corrected learning rate (Adam paper).
    lr_t = base_lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                      (1. - K.pow(self.beta_1, t)))
    ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    if self.amsgrad:
        vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    else:
        # Dummy slots keep the weights-list layout fixed.
        vhats = [K.zeros(1) for _ in params]
    self.weights = [self.iterations] + ms + vs + vhats
    for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
        # Fix: call _get_multiplier once per parameter (it was invoked
        # twice — once for the None check and once for the value).
        multiplier = self._get_multiplier(p)
        if multiplier is None:
            multiplier = 1.0
        m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
        v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
        if self.amsgrad:
            vhat_t = K.maximum(vhat, v_t)
            denom = (K.sqrt(vhat_t) + self.epsilon)
            self.updates.append(K.update(vhat, vhat_t))
        else:
            denom = (K.sqrt(v_t) + self.epsilon)
        self.updates.append(K.update(m, m_t))
        self.updates.append(K.update(v, v_t))
        # Partial momentum adaption: denom**(2*partial) ~ v_t**partial.
        new_p = p - (lr_t * multiplier * (m_t /
                                          (denom**(self.partial * 2))))
        # Apply constraints.
        if getattr(p, 'constraint', None) is not None:
            new_p = p.constraint(new_p)
        self.updates.append(K.update(p, new_p))
    return self.updates
def compile_discriminator_train_op(self):
    """Build the discriminator training function.

    Collects adversarial, gradient-penalty and additional losses, the
    update ops of both networks, the optimizer updates, and an lr-decay
    update, and compiles them into a single callable.

    Returns:
        A backend function taking discriminator + generator inputs plus
        the learning phase, returning [total_loss] + individual losses.
    """
    loss_list = []
    adversarial_loss = self.get_discriminator_adversarial_loss(
        self.generator_adversarial_objective)
    loss_list += adversarial_loss
    loss_list += self.get_gradient_penalty_loss()
    loss_list += self.additional_discriminator_losses()
    updates = []
    updates += self.collect_updates(self.discriminator)
    updates += self.collect_updates(self.generator)
    # Fix: removed a stray debug `print(updates)` left over from
    # development.
    updates += self.discriminator_optimizer.get_updates(
        params=self.discriminator.trainable_weights, loss=sum(loss_list))
    inputs = self.discriminator_input + self.additional_inputs_for_discriminator_train +\
        self.generator_input + self.additional_inputs_for_generator_train
    # Scheduled learning-rate decay applied to the current lr value.
    lr_update = (self.lr_decay_schedule_discriminator(
        self.discriminator_optimizer.iterations) *
        K.get_value(self.discriminator_optimizer.lr))
    updates.append(K.update(self.discriminator_optimizer.lr, lr_update))
    train_op = function(inputs + [K.learning_phase()],
                        [sum(loss_list)] + loss_list,
                        updates=updates)
    return train_op
def call(self, inputs, **kwargs):
    """Transformer-XL style memory layer.

    inputs: ``[sequence_output, memory_length]`` where ``memory_length``
    is a scalar wrapped in a (1, 1) tensor.  Updates the cached memory
    with the new sequence and returns the memory slice visible to the
    current step.
    """
    inputs, memory_length = inputs
    memory_length = K.cast(memory_length[0][0], 'int32')
    batch_size = K.cast(K.shape(inputs)[0], 'int32')
    seq_len = K.cast(K.shape(inputs)[1], 'int32')

    # Build new memory
    # Pad the batch up to the fixed self.batch_size by repeating row 0.
    pad = K.tile(inputs[0:1, ...], (self.batch_size - batch_size, 1, 1))
    padded = K.concatenate([inputs, pad], axis=0)              # (self.batch_size, seq_len, output_dim)
    new_memory = K.concatenate([self.memory, padded], axis=1)  # (self.batch_size, self.memory_len + self.target_len + seq_len, ...)
    # Drop the oldest seq_len steps so memory stays a fixed-size buffer.
    new_memory = tf.slice(                                     # (self.batch_size, self.memory_len + self.target_len, output_dim)
        new_memory,
        (0, seq_len, 0),
        (self.batch_size, self.memory_len + self.target_len,
         self.output_dim),
    )
    self.add_update(K.update(self.memory, new_memory), inputs)

    # Build output
    old_memory = tf.slice(                                     # (batch_size, memory_length, output_dim)
        new_memory,
        (0, K.maximum(0, self.memory_len + self.target_len - seq_len -
                      memory_length), 0),
        (batch_size, K.minimum(self.memory_len, memory_length),
         self.output_dim),
    )
    return old_memory
def call(self, inputs):
    """Dense layer with conditional spectral normalization.

    The kernel is divided by its estimated largest singular value before
    the parent Dense computation, then restored afterwards.
    """
    w = self.kernel
    kernel_shape = K.int_shape(self.kernel)
    if self.renormalize:
        # Single sigma over the whole flattened kernel.
        w = K.reshape(w, [-1, kernel_shape[-1]])
        sigma, u_bar = max_singular_val(
            w, self.u,
            fully_differentiable=self.fully_diff_spectral,
            ip=self.spectral_iterations)
    else:
        # Per-class sigma, reshaped to broadcast over the kernel.
        sigma, u_bar = max_singular_val(
            w, self.u,
            fully_differentiable=self.fully_diff_spectral,
            ip=self.spectral_iterations)
        sigma = K.reshape(sigma, (self.number_of_classes, 1, 1))
    self.add_update(K.update(self.u, u_bar))
    # Temporarily swap in the normalized kernel for the parent call.
    kernel = self.kernel
    self.kernel = self.kernel / sigma
    outputs = super(SNCondtionalDense, self).call(inputs)
    self.kernel = kernel
    return outputs
def call(self, inputs):
    """Conv2D with conditional spectral normalization of the kernel."""
    kernel_shape = K.int_shape(self.kernel)
    if not self.renormalize:
        # Per-class normalization: keep the leading (class) axis,
        # flatten the spatial/input axes.
        w = K.reshape(self.kernel,
                      (kernel_shape[0],
                       kernel_shape[1] * kernel_shape[2] * kernel_shape[3],
                       kernel_shape[-1]))
        sigma, u_bar = max_singular_val(
            w, self.u,
            fully_differentiable=self.fully_diff_spectral,
            ip=self.spectral_iterations)
        sigma = K.reshape(sigma, (self.number_of_classes, 1, 1, 1, 1))
    else:
        # Single sigma over the fully flattened kernel.
        w = K.reshape(self.kernel, (-1, kernel_shape[-1]))
        sigma, u_bar = max_singular_val(
            w, self.u,
            fully_differentiable=self.fully_diff_spectral,
            ip=self.spectral_iterations)
    self.add_update(K.update(self.u, u_bar))
    # Temporarily swap in the normalized kernel for the parent conv.
    kernel = self.kernel
    self.kernel = self.kernel / sigma
    outputs = super(SNConditionalConv2D, self).call(inputs)
    self.kernel = kernel
    return outputs
def call(self, inputs):
    """DepthwiseConv2D with conditional spectral normalization."""
    kernel_shape = K.int_shape(self.kernel)
    if self.renormalize:
        # Single sigma over the fully flattened kernel.
        w = K.reshape(self.kernel, (-1, kernel_shape[-1]))
        sigma, u_bar = max_singular_val(
            w, self.u,
            fully_differentiable=self.fully_diff_spectral,
            ip=self.spectral_iterations)
    else:
        # One sigma per spatial filter: each flattened (h*w) filter
        # becomes a column vector for the singular-value estimate.
        w = tf.transpose(self.kernel, (0, 3, 1, 2))
        w = K.reshape(w, [-1, kernel_shape[1] * kernel_shape[2]])
        w = K.expand_dims(w, axis=-1)
        sigma, u_bar = max_singular_val(
            w, self.u,
            fully_differentiable=self.fully_diff_spectral,
            ip=self.spectral_iterations)
        sigma = K.reshape(sigma,
                          [kernel_shape[0], 1, 1, kernel_shape[-1]])
    self.add_update(K.update(self.u, u_bar))
    # Temporarily swap in the normalized kernel for the parent conv.
    kernel = self.kernel
    self.kernel = self.kernel / sigma
    outputs = super(SNConditionalDepthwiseConv2D, self).call(inputs)
    self.kernel = kernel
    return outputs
def get_updates(self, loss, params):
    """Plain SGD with a derivative (Kd) correction term.

    Each parameter keeps one slot holding the previous gradient; the
    step adds ``kd * (grad - prev_grad)`` on top of the usual
    ``-lr * grad``.
    """
    grads = self.get_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]
    prev_grads = []
    for idx, param in enumerate(params):
        prev_grads.append(
            K.zeros(K.int_shape(param), name='prev_grad_' + str(idx)))
    self.weights = [self.iterations] + prev_grads
    for param, grad, old_grad in zip(params, grads, prev_grads):
        stepped = param - self.lr * grad + self.kd * (grad - old_grad)
        self.updates.append(K.update(old_grad, grad))
        self.updates.append(K.update(param, stepped))
    return self.updates
def inject(self, model):
    """Inject the Lookahead algorithm for the given model.

    The following code is modified from keras's _make_train_function
    method.  See:
    https://github.com/keras-team/keras/blob/master/keras/engine/training.py#L497

    Wraps the compiled train function so that every ``self.k`` calls the
    slow weights are pulled towards the fast weights and the fast
    weights are reset to them.
    """
    if not hasattr(model, 'train_function'):
        raise RuntimeError('You must compile your model before using it.')
    model._check_trainable_weights_consistency()
    if model.train_function is None:
        inputs = (model._feed_inputs + model._feed_targets +
                  model._feed_sample_weights)
        if model._uses_dynamic_learning_phase():
            inputs += [K.learning_phase()]
        fast_params = model._collected_trainable_weights
        with K.name_scope('training'):
            with K.name_scope(model.optimizer.__class__.__name__):
                training_updates = model.optimizer.get_updates(
                    params=fast_params, loss=model.total_loss)
                # One slow copy per trainable weight.
                slow_params = [K.variable(p) for p in fast_params]
            fast_updates = (model.updates + training_updates +
                            model.metrics_updates)
            slow_updates, copy_updates = [], []
            for p, q in zip(fast_params, slow_params):
                # slow <- slow + alpha * (fast - slow)
                slow_updates.append(K.update(q, q + self.alpha * (p - q)))
                # fast <- slow
                copy_updates.append(K.update(p, q))
            # Gets loss and metrics. Updates weights at each call.
            fast_train_function = K.function(
                inputs, [model.total_loss] + model.metrics_tensors,
                updates=fast_updates,
                name='fast_train_function',
                **model._function_kwargs)

            def F(inputs):
                # One fast step per call; sync every k-th call.
                self.count += 1
                R = fast_train_function(inputs)
                if self.count % self.k == 0:
                    K.batch_get_value(slow_updates)
                    K.batch_get_value(copy_updates)
                return R

            model.train_function = F
def get_updates(self, loss, params):
    """RAdam (Rectified Adam) update rules, flag formulation.

    When the approximated SMA length rho_t exceeds 4 (flag == 1) the
    variance-rectified adaptive step is used; otherwise the update falls
    back to plain (un-adapted) momentum.
    """
    grads = self.get_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]
    lr = self.lr
    if self.initial_decay > 0:
        lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
                                                  K.dtype(self.decay))))
    t = K.cast(self.iterations, K.floatx()) + 1
    beta_1_t = K.pow(self.beta_1, t)
    beta_2_t = K.pow(self.beta_2, t)
    # rho: maximum SMA length; rho_t: its step-t approximation.
    rho = 2 / (1 - self.beta_2) - 1
    rho_t = rho - 2 * t * beta_2_t / (1 - beta_2_t)
    # Variance rectification term; K.relu guards rho_t <= 4.
    r_t = K.sqrt(
        K.relu(rho_t - 4) * K.relu(rho_t - 2) * rho /
        ((rho - 4) * (rho - 2) * rho_t))
    flag = K.cast(rho_t > 4, K.floatx())
    ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    self.weights = [self.iterations] + ms + vs
    for p, g, m, v in zip(params, grads, ms, vs):
        m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
        v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
        mhat_t = m_t / (1 - beta_1_t)          # bias-corrected first moment
        vhat_t = K.sqrt(v_t / (1 - beta_2_t))  # bias-corrected RMS
        # Rectified adaptive step when flag == 1, plain momentum else.
        p_t = p - lr * mhat_t * (flag * r_t / (vhat_t + self.epsilon) +
                                 (1 - flag))
        self.updates.append(K.update(m, m_t))
        self.updates.append(K.update(v, v_t))
        new_p = p_t
        # Apply constraints.
        if getattr(p, 'constraint', None) is not None:
            new_p = p.constraint(new_p)
        self.updates.append(K.update(p, new_p))
    return self.updates
def W_bar(self):
    """Return the spectrally normalized kernel (kernel / sigma).

    sigma is estimated with ``self.Ip`` rounds of power iteration on the
    kernel reshaped to (out_channels, in * h * w).
    """
    # Spectrally Normalized Weight
    W_mat = K.permute_dimensions(
        self.kernel, (3, 2, 0, 1))  # (h, w, i, o) => (o, i, h, w)
    W_mat = K.reshape(W_mat, [K.shape(W_mat)[0], -1])  # (o, i * h * w)
    if not self.Ip >= 1:
        raise ValueError(
            "The number of power iterations should be positive integer")
    _u = self.u
    _v = None
    # Power iteration for the dominant singular vectors.
    for _ in range(self.Ip):
        _v = _l2normalize(K.dot(_u, W_mat))
        _u = _l2normalize(K.dot(_v, K.transpose(W_mat)))
    # Estimated largest singular value.
    sigma = K.sum(K.dot(_u, W_mat) * _v)
    # NOTE(review): the op returned by K.update is discarded here, so in
    # graph mode the persisted u may never be refreshed unless a caller
    # registers this op — confirm against how W_bar is consumed.
    K.update(self.u, K.in_train_phase(_u, self.u))
    return self.kernel / sigma
def compute_wave_function_gradient_covariance_inverse_multiplication_with_iterative_solver(
        self, complex_vector, wave_function_jacobian_minus_mean=None):
    """Solve S x = complex_vector for the stochastic-reconfiguration
    matrix S via conjugate gradient, recording solver statistics.

    If the centered Jacobian is not given, the matrix-vector product is
    computed matrix-free through a JVP; otherwise it is formed as
    (J^H J / batch_size + diag_shift * I) v.
    Returns the solution reshaped to a (-1, 1) column, detached from the
    gradient tape.
    """
    complex_vector = tf.squeeze(complex_vector)
    num_of_complex_params_t = tf.shape(complex_vector)[:1]
    if wave_function_jacobian_minus_mean is None:
        # Matrix-free product via a JVP when the Jacobian is unavailable.
        def wave_function_gradient_covariance_vector_product(
                complex_vector):
            return self.get_stochastic_reconfiguration_matrix_vector_product_via_jvp(
                complex_vector)
    else:
        # Explicit product: (J^H J / batch_size + diag_shift * I) v.
        def wave_function_gradient_covariance_vector_product(v):
            return tf.matmul(
                wave_function_jacobian_minus_mean,
                tf.matmul(wave_function_jacobian_minus_mean, v),
                adjoint_a=True) / self.batch_size + self.diag_shift * v
    operator = Operator(
        shape=tf.concat([num_of_complex_params_t] * 2, axis=0),
        dtype=self.predictions_keras_model.output.dtype,
        apply=wave_function_gradient_covariance_vector_product)
    conjugate_gradient_res = conjugate_gradient(
        operator, complex_vector,
        tol=self.conjugate_gradient_tol,
        max_iter=self.iterative_solver_max_iterations)
    updated_conjugate_gradient_iterations = K.update(
        self.conjugate_gradient_iterations, conjugate_gradient_res.i)
    updated_conjugate_gradient_residual_norm = K.update(
        self.conjugate_gradient_residual_norm,
        float_norm(conjugate_gradient_res.r))
    # Ensure the stats are written before the result is consumed.
    with tf.control_dependencies([
            updated_conjugate_gradient_iterations,
            updated_conjugate_gradient_residual_norm
    ]):
        flat_gradient = tf.stop_gradient(
            tf.reshape(conjugate_gradient_res.x, (-1, 1)))
    return flat_gradient
def get_updates(self, loss, params):
    """Adam with decoupled, normalized weight decay (AdamW).

    Fix: ``self.iterations`` was incremented twice per step — once via
    ``K.update_add`` in ``self.updates`` and again via
    ``state_ops.assign_add`` inside a control-dependency block — which
    doubled the effective timestep used for lr decay and bias
    correction.  Now incremented exactly once, matching the sibling
    AdamW implementation in this file (t = iterations + 1).
    """
    grads = self.get_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]
    wd = self.wd * self.wd_normalizer  # decoupled weight decay (4/6)
    lr = self.lr
    if self.initial_decay > 0:
        lr = lr * (1. / (1. + self.decay * math_ops.cast(
            self.iterations, K.dtype(self.decay))))
    eta_t = lr / self.init_lr  # decoupled weight decay (5/6)
    # 1-based timestep; iterations advances once via self.updates above.
    t = math_ops.cast(self.iterations, K.floatx()) + 1
    """Bias corrections according to the Adam paper."""
    lr_t = lr * (K.sqrt(1. - math_ops.pow(self.beta_2, t)) /
                 (1. - math_ops.pow(self.beta_1, t)))
    ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    self.weights = [self.iterations] + ms + vs
    for p, g, m, v in zip(params, grads, ms, vs):
        m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
        v_t = (self.beta_2 * v) + (1. - self.beta_2) * math_ops.square(g)
        p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
        p_t -= eta_t * wd * p  # decoupled weight decay (6/6)
        self.updates.append(K.update(m, m_t))
        self.updates.append(K.update(v, v_t))
        new_p = p_t
        # Apply constraints.
        if getattr(p, 'constraint', None) is not None:
            new_p = p.constraint(new_p)
        self.updates.append(K.update(p, new_p))
    return self.updates
def apply_complex_gradient(self, flat_gradient):
    """Step every real/imaginary weight along the conjugated flat
    gradient, scaled by ``self.lr``, honoring per-weight constraints.

    Returns the list of K.update ops.
    """
    conjugated = tf.conj(flat_gradient)
    grads_re = column_to_tensors(self.model_real_weights,
                                 tf.math.real(conjugated))
    grads_im = column_to_tensors(self.model_imag_weights,
                                 tf.math.imag(conjugated))
    all_weights = self.model_real_weights + self.model_imag_weights
    all_grads = grads_re + grads_im
    updates = []
    for weight, grad in zip(all_weights, all_grads):
        stepped = weight + self.lr * grad
        constraint = getattr(weight, 'constraint', None)
        if constraint is not None:
            stepped = constraint(stepped)
        updates.append(K.update(weight, stepped))
    return updates
def get_updates(self, loss, params):
    """Adam with decoupled weight decay (AdamW)."""
    grads = self.get_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]
    # decoupled weight decay (4/6)
    wd = self.wd
    lr = self.lr
    if self.initial_decay > 0:
        lr *= (1. / (1. + self.decay * K.cast(self.iterations,
                                              K.dtype(self.decay))))
    # decoupled weight decay (5/6)
    eta_t = lr / self.init_lr
    t = K.cast(self.iterations, K.floatx()) + 1
    # Bias-corrected step size (Adam paper).
    lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                 (1. - K.pow(self.beta_1, t)))
    ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    self.weights = [self.iterations] + ms + vs
    for p, g, m, v in zip(params, grads, ms, vs):
        m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
        v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
        # decoupled weight decay (6/6)
        p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon) - eta_t * wd * p
        self.updates.append(K.update(m, m_t))
        self.updates.append(K.update(v, v_t))
        new_p = p_t
        # Apply constraints.
        if getattr(p, 'constraint', None) is not None:
            new_p = p.constraint(new_p)
        self.updates.append(K.update(p, new_p))
    return self.updates
def call(self, inputs):
    """Conv2D with spectral normalization.

    With ``conv_singular=True`` sigma is estimated for the convolution
    operator directly; otherwise the kernel is flattened to a matrix
    first.
    """
    if self.conv_singular:
        sigma, u_bar = max_singular_val_for_convolution(
            self.kernel, self.u,
            fully_differentiable=self.fully_diff_spectral,
            ip=self.spectral_iterations,
            padding=self.padding,
            strides=self.strides,
            data_format=self.data_format)
        kernel_sn = self.kernel / sigma
        self.add_update(K.update(self.u, u_bar))
    else:
        kernel_shape = K.int_shape(self.kernel)
        # Flatten (h, w, in) into rows; columns are output channels.
        w = K.reshape(self.kernel,
                      (kernel_shape[0] * kernel_shape[1] * kernel_shape[2],
                       kernel_shape[3]))
        sigma, u_bar = max_singular_val(
            w, self.u,
            fully_differentiable=self.fully_diff_spectral,
            ip=self.spectral_iterations)
        w_sn = w / sigma
        kernel_sn = K.reshape(w_sn, kernel_shape)
        self.add_update(K.update(self.u, u_bar))
    # Temporarily swap in the normalized kernel for the parent conv.
    kernel = self.kernel
    self.kernel = kernel_sn
    outputs = super(SNConv2D, self).call(inputs)
    self.kernel = kernel
    return outputs
def get_updates(self, loss, params): self.updates = [ K.update_add(self.iterations, 1), K.update_add(self.optimizer.iterations, K.constant(self.cond, "int64")) ] # accumulate gradients self.accum_grads = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] grads = self.get_gradients(loss, params) for g, ag in zip(grads, self.accum_grads): self.updates.append(K.update(ag, K.switch(self.cond, ag * 0, ag + g))) self.updates.extend(self.optimizer.get_updates()[1:]) self.weights.extend(self.optimizer.weights) return self.updates
def call(self, inputs):
    """Embedding lookup with a spectrally normalized embedding matrix.

    The embedding matrix is divided by its estimated largest singular
    value for the parent lookup, then restored; the power-iteration
    vector u is refreshed as a layer update.
    """
    sigma, u_new = max_singular_val(
        self.embeddings, self.u,
        fully_differentiable=self.fully_diff_spectral,
        ip=self.spectral_iterations)
    self.add_update(K.update(self.u, u_new))
    # Temporarily swap in the normalized matrix for the parent lookup.
    original_embeddings = self.embeddings
    self.embeddings = original_embeddings / sigma
    outputs = super(SNEmbeding, self).call(inputs)
    self.embeddings = original_embeddings
    return outputs
def call(self, inputs, mask=None, training=None, initial_state=None,
         constants=None):
    """Recurrent inference machine (RIM) forward pass.

    Runs ``self.niter`` cell steps; at each step the cell input is the
    current estimate concatenated with the gradient of the likelihood
    with respect to it.
    """
    if isinstance(inputs, list):
        inputs = inputs[0]
    if initial_state is not None:
        pass
    elif self.stateful:
        initial_state = self.states
    else:
        initial_state = self.get_initial_state(inputs)
    if isinstance(mask, list):
        mask = mask[0]
    if len(initial_state) != len(self.states):
        raise ValueError('Layer has ' + str(len(self.states)) +
                         ' states but was passed ' +
                         str(len(initial_state)) + ' initial states.')
    # The "time" dimension is the number of inference iterations.
    timesteps = self.niter
    kwargs = {}
    if generic_utils.has_arg(self.cell.call, 'training'):
        kwargs['training'] = training
    if constants:
        if not generic_utils.has_arg(self.cell.call, 'constants'):
            raise ValueError('RNN cell does not support constants')

        def step(inputs, states):
            constants = states[-self._num_constants:]
            states = states[:-self._num_constants]
            return self.cell.call(inputs, states, constants=constants,
                                  **kwargs)
    else:

        def step(inputs, states):
            return self.cell.call(inputs, states, **kwargs)

    # Augment the RNN cell with the likelihood gradient
    def augmented_step(x, states):
        # The prediction is detached so gradients flow only through x.
        prediction = tf.stop_gradient(self.output_layer.call(states[0]))
        grad = tf.gradients(self.likelihood_fn(x, prediction), x)[0]
        return step(tf.concat([x, grad], axis=-1), states)

    last_output, outputs, states = rim(augmented_step, inputs,
                                       initial_state,
                                       constants=constants,
                                       niter=timesteps)
    if self.stateful:
        updates = []
        for i in range(len(states)):
            updates.append(K.update(self.states[i], states[i]))
        self.add_update(updates, inputs=True)
    if self.return_sequences:
        output = outputs
        # Fold the iteration axis into the batch so the output layer
        # maps every iterate, then unfold.
        ni, nb, a, b, c, d = output.shape
        shape = [-1, a, b, c, d]
        output = self.output_layer.call(tf.reshape(output, shape))
        output = tf.reshape(output,
                            [-1, nb, a, b, c, output.shape[-1]])
    else:
        output = last_output
        output = self.output_layer.call(output)
    if self.return_state:
        if not isinstance(states, (list, tuple)):
            states = [states]
        else:
            states = list(states)
        return [output] + states
    else:
        return output
def get_updates(self, loss, params):
    """RAdam update rules with optional (coupled) weight decay.

    When the approximated SMA length exceeds 5, the variance-rectified
    adaptive step is taken; otherwise a plain bias-corrected momentum
    step is used.  The branch is selected per step with K.switch.
    """
    grads = self.get_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]
    lr = self.lr
    if self.initial_decay > 0:
        decay_factor = 1. + self.decay * K.cast(self.iterations,
                                                K.dtype(self.decay))
        lr = lr * (1. / decay_factor)
    t = K.cast(self.iterations, K.floatx()) + 1
    ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    self.weights = [self.iterations] + ms + vs
    for param, grad, m, v in zip(params, grads, ms, vs):
        m_new = (self.beta_1 * m) + (1. - self.beta_1) * grad
        v_new = (self.beta_2 * v) + (1. - self.beta_2) * K.square(grad)
        beta2_pow_t = self.beta_2 ** t
        # rho_inf: maximum SMA length; rho_t: its step-t approximation.
        rho_inf = 2 / (1 - self.beta_2) - 1
        rho_t = rho_inf - 2 * t * beta2_pow_t / (1 - beta2_pow_t)
        # Coupled weight decay folded into the base point.
        if self.weight_decay != 0.:
            base = param - self.weight_decay * lr * param
        else:
            base = param

        def rectified_step():
            rect = K.sqrt(
                (1 - beta2_pow_t) * (rho_t - 4) / (rho_inf - 4) *
                (rho_t - 2) / rho_t * rho_inf / (rho_inf - 2))
            step = lr * rect / (1 - self.beta_1 ** t)
            return base - step * (m_new / (K.sqrt(v_new) + self.epsilon))

        def plain_step():
            step = lr / (1 - self.beta_1 ** t)
            return base - step * m_new

        candidate = K.switch(rho_t > 5, rectified_step, plain_step)
        self.updates.append(K.update(m, m_new))
        self.updates.append(K.update(v, v_new))
        new_param = candidate
        # Apply constraints.
        if getattr(param, 'constraint', None) is not None:
            new_param = param.constraint(new_param)
        self.updates.append(K.update(param, new_param))
    return self.updates
def call(self, inputs, mask=None, training=None, initial_state=None,
         constants=None):
    """Custom RNN forward pass over the time dimension.

    Mirrors keras.layers.RNN.call: resolve the initial state, wrap the
    cell in a per-step closure (handling constants/training kwargs), and
    delegate iteration to K.rnn.
    """
    # note that the .build() method of subclasses MUST define
    # self.input_spec and self.state_spec with complete input shapes.
    if isinstance(inputs, list):
        inputs = inputs[0]
    if initial_state is not None:
        pass
    elif self.stateful:
        initial_state = self.states
    else:
        initial_state = self.get_initial_state(inputs)
    if isinstance(mask, list):
        mask = mask[0]
    if len(initial_state) != len(self.states):
        raise ValueError('Layer has ' + str(len(self.states)) +
                         ' states but was passed ' +
                         str(len(initial_state)) + ' initial states.')
    timesteps = K.int_shape(inputs)[1]
    kwargs = {}
    if generic_utils.has_arg(self.cell.call, 'training'):
        kwargs['training'] = training
    if constants:
        if not generic_utils.has_arg(self.cell.call, 'constants'):
            raise ValueError('RNN cell does not support constants')

        def step(inputs, states):
            # Constants ride along at the tail of the state list.
            constants = states[-self._num_constants:]
            states = states[:-self._num_constants]
            return self.cell.call(inputs, states, constants=constants,
                                  **kwargs)
    else:

        def step(inputs, states):
            return self.cell.call(inputs, states, **kwargs)

    last_output, outputs, states = K.rnn(step, inputs,
                                         initial_state,
                                         constants=constants,
                                         go_backwards=self.go_backwards,
                                         mask=mask,
                                         input_length=timesteps)
    if self.stateful:
        # Persist the final states for the next batch.
        updates = []
        for i in range(len(states)):
            updates.append(K.update(self.states[i], states[i]))
        self.add_update(updates)
    if self.return_sequences:
        output = outputs
    else:
        output = last_output
    if self.return_state:
        if not isinstance(states, (list, tuple)):
            states = [states]
        else:
            states = list(states)
        return [output] + states
    else:
        return output
def call(self, inputs, mask=None, training=None, initial_state=None,
         constants=None):
    """Custom RNN forward pass (tf.keras variant).

    Uses the base layer's _process_inputs to resolve inputs, initial
    state and constants, then delegates the time iteration to K.rnn.
    """
    # note that the .build() method of subclasses MUST define
    # self.input_spec and self.state_spec with complete input shapes.
    inputs, initial_state, constants = self._process_inputs(
        inputs, initial_state, constants)
    if isinstance(mask, list):
        mask = mask[0]
    timesteps = K.int_shape(inputs)[1]
    kwargs = {}
    if generic_utils.has_arg(self.cell.call, 'training'):
        kwargs['training'] = training
    if constants:
        if not generic_utils.has_arg(self.cell.call, 'constants'):
            raise ValueError('RNN cell does not support constants')

        def step(inputs, states):
            # Constants ride along at the tail of the state list.
            constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
            states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
            return self.cell.call(inputs, states, constants=constants,
                                  **kwargs)
    else:

        def step(inputs, states):
            return self.cell.call(inputs, states, **kwargs)

    last_output, outputs, states = K.rnn(step, inputs,
                                         initial_state,
                                         constants=constants,
                                         go_backwards=self.go_backwards,
                                         mask=mask,
                                         input_length=timesteps)
    if self.stateful:
        # Persist the final states for the next batch.
        updates = [
            K.update(self_state, state)
            for self_state, state in zip(self.states, states)
        ]
        self.add_update(updates)
    if self.return_sequences:
        output = outputs
    else:
        output = last_output
    if self.return_state:
        if not isinstance(states, (list, tuple)):
            states = [states]
        else:
            states = list(states)
        return [output] + states
    else:
        return output
def call(self, inputs, mask=None, training=None, initial_state=None,
         constants=None):
    """Custom RNN forward pass (Keras 2 variant with learning-phase
    propagation).

    Same structure as keras.layers.RNN.call; additionally forwards the
    cell's _uses_learning_phase flag onto the output tensor.
    """
    # note that the .build() method of subclasses MUST define
    # self.input_spec and self.state_spec with complete input shapes.
    if isinstance(inputs, list):
        inputs = inputs[0]
    if initial_state is not None:
        pass
    elif self.stateful:
        initial_state = self.states
    else:
        initial_state = self.get_initial_state(inputs)
    if isinstance(mask, list):
        mask = mask[0]
    if len(initial_state) != len(self.states):
        raise ValueError('Layer has ' + str(len(self.states)) +
                         ' states but was passed ' +
                         str(len(initial_state)) + ' initial states.')
    timesteps = K.int_shape(inputs)[1]
    kwargs = {}
    if generic_utils.has_arg(self.cell.call, 'training'):
        kwargs['training'] = training
    if constants:
        if not generic_utils.has_arg(self.cell.call, 'constants'):
            raise ValueError('RNN cell does not support constants')

        def step(inputs, states):
            # Constants ride along at the tail of the state list.
            constants = states[-self._num_constants:]
            states = states[:-self._num_constants]
            return self.cell.call(inputs, states, constants=constants,
                                  **kwargs)
    else:

        def step(inputs, states):
            return self.cell.call(inputs, states, **kwargs)

    last_output, outputs, states = K.rnn(step, inputs,
                                         initial_state,
                                         constants=constants,
                                         go_backwards=self.go_backwards,
                                         mask=mask,
                                         input_length=timesteps)
    if self.stateful:
        # Persist the final states for the next batch.
        updates = []
        for i in range(len(states)):
            updates.append(K.update(self.states[i], states[i]))
        self.add_update(updates, inputs=True)
    if self.return_sequences:
        output = outputs
    else:
        output = last_output
    # Properly set learning phase
    if getattr(last_output, '_uses_learning_phase', False):
        output._uses_learning_phase = True
    if self.return_state:
        if not isinstance(states, (list, tuple)):
            states = [states]
        else:
            states = list(states)
        return [output] + states
    else:
        return output
def get_updates(self, loss, params):
    """Adam with linear learning-rate warmup and linear decay, optional
    AMSGrad and optional (name-pattern-filtered) weight decay.
    """
    grads = self.get_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]
    t = K.cast(self.iterations, K.floatx()) + 1
    # Warmup to self.lr over warmup_steps, then linear decay to min_lr
    # over decay_steps.
    lr = K.switch(
        t <= self.warmup_steps,
        self.lr * (t / self.warmup_steps),
        self.min_lr + (self.lr - self.min_lr) *
        (1.0 - K.minimum(t, self.decay_steps) / self.decay_steps),
    )
    # Bias-corrected learning rate (Adam paper).
    lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                 (1. - K.pow(self.beta_1, t)))
    ms = [
        K.zeros(K.int_shape(p), dtype=K.dtype(p), name='m_{}'.format(i))
        for i, p in enumerate(params)
    ]
    vs = [
        K.zeros(K.int_shape(p), dtype=K.dtype(p), name='v_{}'.format(i))
        for i, p in enumerate(params)
    ]
    if self.amsgrad:
        vhats = [
            K.zeros(K.int_shape(p), dtype=K.dtype(p),
                    name='vh_{}'.format(i))
            for i, p in enumerate(params)
        ]
    else:
        # Scalar placeholders keep the weights-list layout fixed.
        vhats = [
            K.zeros(1, dtype=K.dtype(p), name='vh_{}'.format(i))
            for i, p in enumerate(params)
        ]
    self.weights = [self.iterations] + ms + vs + vhats
    for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
        m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
        v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
        if self.amsgrad:
            vhat_t = K.maximum(vhat, v_t)
            p_t = m_t / (K.sqrt(vhat_t) + self.epsilon)
            self.updates.append(K.update(vhat, vhat_t))
        else:
            p_t = m_t / (K.sqrt(v_t) + self.epsilon)
        if self.initial_weight_decay > 0.0:
            # Decay every weight, or only those matching a name pattern.
            if self.weight_decay_pattern is None:
                p_t += self.weight_decay * p
            else:
                for pattern in self.weight_decay_pattern:
                    if pattern in p.name:
                        p_t += self.weight_decay * p
                        break
        p_t = p - lr_t * p_t
        self.updates.append(K.update(m, m_t))
        self.updates.append(K.update(v, v_t))
        new_p = p_t
        if getattr(p, 'constraint', None) is not None:
            new_p = p.constraint(new_p)
        self.updates.append(K.update(p, new_p))
    return self.updates