Example #1
0
 def on_train_batch_end(self, batch, logs=None):
     """Maintain slow copies of the model weights and periodically sync them.

     On the first call (``self.slow_weights is None``) a non-trainable
     variable is created for every trainable weight of ``self.model``,
     initialised from that weight. No synchronisation happens on that call.
     On later calls, whenever ``self.count`` is a multiple of ``self.k``, each
     slow weight is moved towards its fast weight by the factor ``self.alpha``
     and the fast weight is then overwritten with the slow weight.

     Args:
         batch: Not read by this method.
         logs: Not read by this method.
     """
     self.count += 1
     if self.slow_weights is None:
         # First call: build the slow-weight variables from the current weights.
         with tf.control_dependencies(self.model.trainable_weights):
             self.slow_weights = []
             for fast_param in self.model.trainable_weights:
                 with ops.control_dependencies([fast_param]):
                     # Same dtype and base name (text before ":") as the fast weight.
                     slow_param = tf.Variable(fast_param.initialized_value(),
                                              dtype=fast_param.dtype,
                                              trainable=False,
                                              name=fast_param.name.split(":")[0])
                 self.slow_weights.append(slow_param)
                 K.track_variable(slow_param)
     else:
         if self.count % self.k == 0:
             slow_ups, fast_ups = [], []
             # slow <- slow + alpha * (fast - slow)
             for fast, slow in zip(self.model.trainable_weights,
                                   self.slow_weights):
                 slow_ups.append(K.update(slow, slow + self.alpha * (fast - slow)))
             # fast <- slow, built so that it depends on the slow updates above.
             with tf.control_dependencies(slow_ups):
                 for fast, slow in zip(self.model.trainable_weights,
                                       self.slow_weights):
                     fast_ups.append(K.update(fast, slow))
             # NOTE(review): fast_ups already carries a control dependency on
             # slow_ups; when these are graph ops, evaluating both lists may run
             # the slow update twice per sync - confirm against the backend in use.
             K.batch_get_value(slow_ups)
             K.batch_get_value(fast_ups)
    def get_updates(self, loss, params):
        """Build Adam-style update ops with optional per-parameter lr multipliers.

        A key of ``self.multipliers`` applies to a parameter when the key is
        contained in the parameter's name; the first matching key wins and its
        value scales the bias-corrected learning rate for that parameter.

        Args:
            loss: Loss passed to ``self.get_gradients``.
            params: Parameters to optimise.

        Returns:
            The list of update ops, also stored in ``self.updates``.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        if self.initial_decay > 0:
            decay_steps = K.cast(self.iterations, K.dtype(self.decay))
            lr = lr * (1. / (1. + self.decay * decay_steps))

        # Bias-corrected learning rate for the current (1-based) step.
        t = K.cast(self.iterations, K.floatx()) + 1
        bias_correction = (K.sqrt(1. - K.pow(self.beta_2, t)) /
                           (1. - K.pow(self.beta_1, t)))
        lr_t = lr * bias_correction

        def zero_slots():
            # One zero-initialised slot per parameter, same shape and dtype.
            return [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]

        ms = zero_slots()
        vs = zero_slots()
        # Without amsgrad the vhat slots are placeholders of shape (1,).
        vhats = zero_slots() if self.amsgrad else [K.zeros(1) for _ in params]
        self.weights = [self.iterations] + ms + vs + vhats

        for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
            # Learning rate multipliers
            matched = None
            if self.multipliers:
                matched = [key for key in self.multipliers if key in p.name]

            if matched:
                new_lr_t = lr_t * self.multipliers[matched[0]]
            else:
                new_lr_t = lr_t

            if self.debug_verbose:
                if matched:
                    print('Setting {} to learning rate {}'.format(
                        matched[0], new_lr_t))
                else:
                    print('No change in learning rate {}'.format(p.name))
                print(K.get_value(new_lr_t))

            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)

            if self.amsgrad:
                # Use the running maximum of the second moment.
                second_moment = K.maximum(vhat, v_t)
                self.updates.append(K.update(vhat, second_moment))
            else:
                second_moment = v_t
            new_p = p - new_lr_t * m_t / (K.sqrt(second_moment) + self.epsilon)

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates
    def get_updates(self, loss, params):
        """Build Adam-style update ops whose per-element step size is clipped.

        The element-wise step ``step_size / denom`` is clipped into
        ``[lower_bound, upper_bound]``; both bounds are derived from
        ``self.final_lr``, ``self.gamma`` and the step index, and then
        multiplied by the first-moment estimate.

        Args:
            loss: Loss passed to ``self.get_gradients``.
            params: Parameters to optimise.

        Returns:
            The list of update ops, also stored in ``self.updates``.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        if self.initial_decay > 0:
            decay_steps = K.cast(self.iterations, K.dtype(self.decay))
            lr = lr * (1. / (1. + self.decay * decay_steps))

        t = K.cast(self.iterations, K.floatx()) + 1

        # Applies bounds on actual learning rate
        bias_correction = (K.sqrt(1. - K.pow(self.beta_2, t)) /
                           (1. - K.pow(self.beta_1, t)))
        step_size = lr * bias_correction

        final_lr = self.final_lr * lr / self.base_lr
        lower_bound = final_lr * (1. - 1. / (self.gamma * t + 1.))
        upper_bound = final_lr * (1. + 1. / (self.gamma * t))

        def zero_slots():
            # One zero-initialised slot per parameter, same shape and dtype.
            return [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]

        ms = zero_slots()
        vs = zero_slots()
        # Without amsbound the vhat slots are placeholders of shape (1,).
        vhats = zero_slots() if self.amsbound else [K.zeros(1) for _ in params]
        self.weights = [self.iterations] + ms + vs + vhats

        for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
            # apply weight decay
            if self.weight_decay != 0.:
                g += self.weight_decay * K.stop_gradient(p)

            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)

            if self.amsbound:
                vhat_t = K.maximum(vhat, v_t)
                self.updates.append(K.update(vhat, vhat_t))
                denom = K.sqrt(vhat_t) + self.epsilon
            else:
                denom = K.sqrt(v_t) + self.epsilon

            # Compute the bounds
            unclipped_rate = step_size * K.ones_like(denom) / denom
            clipped_rate = K.minimum(
                K.maximum(unclipped_rate, lower_bound), upper_bound)
            new_p = p - m_t * clipped_rate

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates
Example #4
0
    def get_updates(self, loss, params):
        """Build momentum updates with an extra term on the gradient change.

        Per parameter three slots are kept: the velocity ``v``, a smoothed
        gradient difference ``d`` and the previous gradient.  The parameter is
        moved by the new velocity plus ``self.kd`` times the new ``d``.

        Args:
            loss: Loss passed to ``self.get_gradients``.
            params: Parameters to optimise.

        Returns:
            The list of update ops, also stored in ``self.updates``.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        shapes = [K.int_shape(p) for p in params]

        def make_slots(prefix):
            # One zero slot per parameter, named "<prefix><index>".
            return [K.zeros(shape, name=prefix + str(idx))
                    for idx, shape in enumerate(shapes)]

        prev_grads = make_slots('prev_grad_')
        ds = make_slots('d_')
        vs = make_slots('v_')
        self.weights = [self.iterations] + ds + vs + prev_grads

        for p, g, pg, v, d in zip(params, grads, prev_grads, vs, ds):
            # Velocity: momentum on the previous velocity minus the lr step.
            v_t = self.momentum * v - self.lr * g
            self.updates.append(K.update(v, v_t))

            # Smoothed difference between this gradient and the previous one.
            d_t = self.momentum * d + (1 - self.momentum) * (g - pg)
            self.updates.append(K.update(d, d_t))
            self.updates.append(K.update(pg, g))

            self.updates.append(K.update(p, p + v_t + self.kd * d_t))

        return self.updates
Example #5
0
 def _update_s_matrix_stats(self, num_of_complex_params_t, s):
     """Record the smallest significant eigenvalue magnitude and the rank of ``s``.

     Eigenvalue magnitudes at or below a tolerance are discarded; the
     tolerance is ``K.epsilon() * num_of_complex_params_t * max(|s|)``
     (see https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.matrix_rank.html).

     Args:
         num_of_complex_params_t: Scalar that scales the tolerance; it is cast
             to the eigenvalue dtype.
         s: Matrix passed to ``tf.linalg.eigvalsh``.

     Returns:
         ``(min_eigval_update, rank_update)``, or two ``tf.no_op()`` when
         ``self.add_s_matrix_stats`` is falsy.
     """
     if not self.add_s_matrix_stats:
         return tf.no_op(), tf.no_op()

     eig_magnitudes = tf.math.abs(tf.linalg.eigvalsh(s))
     size_factor = tf.cast(num_of_complex_params_t, eig_magnitudes.dtype)
     tol = K.epsilon() * size_factor * tf.math.reduce_max(tf.math.abs(s))

     significant = tf.boolean_mask(eig_magnitudes, eig_magnitudes > tol)
     rank_update = K.update(self.s_matrix_rank,
                            tf.count_nonzero(significant))
     min_eigval_update = K.update(self.s_matrix_min_eigval,
                                  tf.math.reduce_min(significant))
     return min_eigval_update, rank_update
Example #6
0
    def get_updates(self, loss, params):
        """Build Adam-style updates with an lr multiplier and normalized weight decay.

        Parameters whose name is in ``self.excluded_vars`` use the plain
        bias-corrected learning rate; all others have it scaled by
        ``self.lr_mult``.  When ``self.weight_decay`` is non-zero, a decay term
        normalized by batch size, samples per epoch and epoch count is
        subtracted from the parameter as well.

        Args:
            loss: Loss passed to ``self.get_gradients``.
            params: Parameters to optimise.

        Returns:
            The list of update ops, also stored in ``self.updates``.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        if self.initial_decay > 0:
            decay_steps = K.cast(self.iterations, K.dtype(self.decay))
            lr = lr * (1. / (1. + self.decay * decay_steps))

        t = K.cast(self.iterations, K.floatx()) + 1
        # Bias corrections according to the Adam paper
        bias_correction = (K.sqrt(1. - K.pow(self.beta_2, t)) /
                           (1. - K.pow(self.beta_1, t)))
        lr_t = lr * bias_correction

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        self.weights = [self.iterations] + ms + vs

        # Schedule multiplier eta_t = 1 for simple AdamW.
        # According to the AdamW paper, eta_t can be fixed, decay, or
        # also be used for warm restarts (AdamWR to come).
        eta_t = 1.

        for p, g, m, v in zip(params, grads, ms, vs):
            # Add a lr multiplier for vars outside excluded_vars
            multiplied_lr_t = (lr_t if p.name in self.excluded_vars
                               else lr_t * self.lr_mult)

            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)

            adam_step = multiplied_lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
            new_p = p - eta_t * adam_step
            if self.weight_decay != 0:
                # Normalized weight decay according to the AdamW paper
                w_d = self.weight_decay * K.sqrt(
                    self.batch_size / (self.samples_per_epoch * self.epochs))
                new_p = new_p - eta_t * (w_d * p)

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates
    def get_updates_Padam(self, loss, params):
        """Build partially-adaptive Adam-style updates with per-parameter multipliers.

        The denominator of the step is raised to ``self.partial * 2`` and the
        bias-corrected learning rate is scaled by the multiplier returned from
        ``self._get_multiplier(p)`` (1.0 when it returns ``None``).

        Args:
            loss: Loss passed to ``self.get_gradients``.
            params: Parameters to optimise.

        Returns:
            The list of update ops, also stored in ``self.updates``.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        base_lr = self._optimizer.learning_rate
        if self.initial_decay > 0:
            base_lr = base_lr * (1. / (1. + self.decay * K.cast(
                self.iterations, K.dtype(self.decay))))

        # Bias-corrected learning rate for the current (1-based) step.
        t = K.cast(self.iterations, K.floatx()) + 1
        lr_t = base_lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                          (1. - K.pow(self.beta_1, t)))

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        if self.amsgrad:
            vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        else:
            # Placeholders of shape (1,) so the weight layout stays uniform.
            vhats = [K.zeros(1) for _ in params]
        self.weights = [self.iterations] + ms + vs + vhats

        for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
            # Resolve the multiplier once per parameter; None means "no scaling".
            multiplier = self._get_multiplier(p)
            if multiplier is None:
                multiplier = 1.0

            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
            if self.amsgrad:
                vhat_t = K.maximum(vhat, v_t)
                denom = (K.sqrt(vhat_t) + self.epsilon)
                self.updates.append(K.update(vhat, vhat_t))
            else:
                denom = (K.sqrt(v_t) + self.epsilon)

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))

            # Partial momentum adaption.
            new_p = p - (lr_t * multiplier * (m_t /
                                              (denom**(self.partial * 2))))

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates
Example #8
0
    def compile_discriminator_train_op(self):
        """Assemble the backend function that performs one discriminator step.

        The losses are the adversarial loss, the gradient penalty and any
        additional discriminator losses.  The updates are those collected from
        the discriminator and the generator, the optimizer updates on the
        discriminator's trainable weights, and a learning-rate update driven
        by ``self.lr_decay_schedule_discriminator``.

        Returns:
            The function built by ``function(...)``; its outputs are the summed
            loss followed by the individual losses.
        """
        loss_list = []
        loss_list.extend(self.get_discriminator_adversarial_loss(
            self.generator_adversarial_objective))
        loss_list.extend(self.get_gradient_penalty_loss())
        loss_list.extend(self.additional_discriminator_losses())

        updates = []
        updates.extend(self.collect_updates(self.discriminator))
        updates.extend(self.collect_updates(self.generator))

        print(updates)
        updates.extend(self.discriminator_optimizer.get_updates(
            params=self.discriminator.trainable_weights, loss=sum(loss_list)))

        inputs = (self.discriminator_input +
                  self.additional_inputs_for_discriminator_train +
                  self.generator_input +
                  self.additional_inputs_for_generator_train)

        # New lr = schedule(iterations) * the lr value read right now.
        schedule_factor = self.lr_decay_schedule_discriminator(
            self.discriminator_optimizer.iterations)
        current_lr_value = K.get_value(self.discriminator_optimizer.lr)
        lr_update = schedule_factor * current_lr_value
        updates.append(K.update(self.discriminator_optimizer.lr, lr_update))

        return function(inputs + [K.learning_phase()],
                        [sum(loss_list)] + loss_list,
                        updates=updates)
Example #9
0
    def call(self, inputs, **kwargs):
        """Append the inputs to the stored memory and return a window of it.

        Args:
            inputs: A pair ``(sequence, memory_length)``; the requested length
                is read from ``memory_length[0][0]``.
            **kwargs: Not read by this method.

        Returns:
            A slice of the updated memory with up to
            ``min(self.memory_len, memory_length)`` steps for the first
            ``batch_size`` rows.
        """
        sequence, requested_length = inputs
        requested_length = K.cast(requested_length[0][0], 'int32')
        batch_size = K.cast(K.shape(sequence)[0], 'int32')
        seq_len = K.cast(K.shape(sequence)[1], 'int32')
        stored_len = self.memory_len + self.target_len

        # Build new memory
        # Repeat the first sample to fill the batch up to self.batch_size.
        filler = K.tile(sequence[0:1, ...], (self.batch_size - batch_size, 1, 1))
        # (self.batch_size, seq_len, output_dim)
        padded = K.concatenate([sequence, filler], axis=0)
        # (self.batch_size, self.memory_len + self.target_len + seq_len, ...)
        new_memory = K.concatenate([self.memory, padded], axis=1)
        # Drop the oldest seq_len steps:
        # (self.batch_size, self.memory_len + self.target_len, output_dim)
        new_memory = tf.slice(
            new_memory,
            (0, seq_len, 0),
            (self.batch_size, stored_len, self.output_dim),
        )
        self.add_update(K.update(self.memory, new_memory), sequence)

        # Build output: (batch_size, memory_length, output_dim)
        start = K.maximum(0, stored_len - seq_len - requested_length)
        length = K.minimum(self.memory_len, requested_length)
        return tf.slice(
            new_memory,
            (0, start, 0),
            (batch_size, length, self.output_dim),
        )
    def call(self, inputs):
        """Run the parent layer with the kernel divided by ``sigma``.

        ``sigma`` and the new ``u`` come from ``max_singular_val``.  When
        ``self.renormalize`` is set the kernel is first flattened to two
        dimensions; otherwise ``sigma`` is reshaped to
        ``(self.number_of_classes, 1, 1)``.

        Args:
            inputs: Passed unchanged to the parent ``call``.

        Returns:
            The output of the parent layer's ``call``.
        """
        w = self.kernel
        kernel_shape = K.int_shape(self.kernel)
        if self.renormalize:
            w = K.reshape(w, [-1, kernel_shape[-1]])

        sigma, u_bar = max_singular_val(
            w,
            self.u,
            fully_differentiable=self.fully_diff_spectral,
            ip=self.spectral_iterations)

        if not self.renormalize:
            sigma = K.reshape(sigma, (self.number_of_classes, 1, 1))

        self.add_update(K.update(self.u, u_bar))

        # Temporarily swap in the scaled kernel; always restore the original,
        # even if the parent call raises.
        kernel = self.kernel
        self.kernel = self.kernel / sigma
        try:
            outputs = super(SNCondtionalDense, self).call(inputs)
        finally:
            self.kernel = kernel

        return outputs
    def call(self, inputs):
        """Run the parent convolution with the kernel divided by ``sigma``.

        ``sigma`` and the new ``u`` come from ``max_singular_val``.  When
        ``self.renormalize`` is set the kernel is flattened to two dimensions;
        otherwise its middle three axes are merged and ``sigma`` is reshaped
        to ``(self.number_of_classes, 1, 1, 1, 1)``.

        Args:
            inputs: Passed unchanged to the parent ``call``.

        Returns:
            The output of the parent layer's ``call``.
        """
        kernel_shape = K.int_shape(self.kernel)
        if self.renormalize:
            w = K.reshape(self.kernel, (-1, kernel_shape[-1]))
        else:
            w = K.reshape(self.kernel,
                          (kernel_shape[0], kernel_shape[1] * kernel_shape[2] *
                           kernel_shape[3], kernel_shape[-1]))

        sigma, u_bar = max_singular_val(
            w,
            self.u,
            fully_differentiable=self.fully_diff_spectral,
            ip=self.spectral_iterations)

        if not self.renormalize:
            sigma = K.reshape(sigma, (self.number_of_classes, 1, 1, 1, 1))

        self.add_update(K.update(self.u, u_bar))

        # Temporarily swap in the scaled kernel; always restore the original,
        # even if the parent call raises.
        kernel = self.kernel
        self.kernel = self.kernel / sigma
        try:
            outputs = super(SNConditionalConv2D, self).call(inputs)
        finally:
            self.kernel = kernel

        return outputs
    def call(self, inputs):
        """Run the parent depthwise convolution with the kernel divided by ``sigma``.

        ``sigma`` and the new ``u`` come from ``max_singular_val``.  When
        ``self.renormalize`` is set the kernel is flattened to two dimensions;
        otherwise it is transposed to axis order ``(0, 3, 1, 2)``, reshaped
        and given a trailing axis, and ``sigma`` is reshaped to
        ``[kernel_shape[0], 1, 1, kernel_shape[-1]]``.

        Args:
            inputs: Passed unchanged to the parent ``call``.

        Returns:
            The output of the parent layer's ``call``.
        """
        kernel_shape = K.int_shape(self.kernel)

        if self.renormalize:
            w = K.reshape(self.kernel, (-1, kernel_shape[-1]))
        else:
            w = tf.transpose(self.kernel, (0, 3, 1, 2))
            w = K.reshape(w, [-1, kernel_shape[1] * kernel_shape[2]])
            w = K.expand_dims(w, axis=-1)

        sigma, u_bar = max_singular_val(
            w,
            self.u,
            fully_differentiable=self.fully_diff_spectral,
            ip=self.spectral_iterations)

        if not self.renormalize:
            sigma = K.reshape(sigma, [kernel_shape[0], 1, 1, kernel_shape[-1]])

        self.add_update(K.update(self.u, u_bar))

        # Temporarily swap in the scaled kernel; always restore the original,
        # even if the parent call raises.
        kernel = self.kernel
        self.kernel = self.kernel / sigma
        try:
            outputs = super(SNConditionalDepthwiseConv2D, self).call(inputs)
        finally:
            self.kernel = kernel

        return outputs
Example #13
0
    def get_updates(self, loss, params):
        """Build gradient-descent updates with a term on the gradient change.

        Each parameter is moved by ``-self.lr * g`` plus ``self.kd`` times the
        difference between the current gradient and the stored previous one;
        the stored gradient is then replaced by the current gradient.

        Args:
            loss: Loss passed to ``self.get_gradients``.
            params: Parameters to optimise.

        Returns:
            The list of update ops, also stored in ``self.updates``.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        # One zero slot per parameter holding the previous gradient.
        prev_grads = [
            K.zeros(K.int_shape(param), name='prev_grad_' + str(idx))
            for idx, param in enumerate(params)
        ]
        self.weights = [self.iterations] + prev_grads

        for param, grad, last_grad in zip(params, grads, prev_grads):
            updated = param - self.lr * grad + self.kd * (grad - last_grad)
            self.updates.extend([
                K.update(last_grad, grad),
                K.update(param, updated),
            ])

        return self.updates
Example #14
0
    def inject(self, model):
        """Install a Lookahead-wrapped train function on the given model.

        The following code is modified from keras's _make_train_function method.
        See: https://github.com/keras-team/keras/blob/master/keras/engine/training.py#L497

        Nothing is changed when ``model.train_function`` is already set.
        Otherwise the model's normal training function is built and wrapped so
        that on every ``self.k``-th call the slow weights are moved towards the
        fast weights by ``self.alpha`` and then copied back into them.

        Raises:
            RuntimeError: If the model has no ``train_function`` attribute.
        """
        if not hasattr(model, 'train_function'):
            raise RuntimeError('You must compile your model before using it.')

        model._check_trainable_weights_consistency()

        # A train function already exists: leave the model untouched.
        if model.train_function is not None:
            return

        inputs = (model._feed_inputs + model._feed_targets +
                  model._feed_sample_weights)
        if model._uses_dynamic_learning_phase():
            inputs = inputs + [K.learning_phase()]
        fast_params = model._collected_trainable_weights

        with K.name_scope('training'):
            with K.name_scope(model.optimizer.__class__.__name__):
                training_updates = model.optimizer.get_updates(
                    params=fast_params, loss=model.total_loss)
                # One slow variable per fast weight, created from that weight.
                slow_params = [K.variable(weight) for weight in fast_params]
            fast_updates = (model.updates + training_updates +
                            model.metrics_updates)

            slow_updates = []
            copy_updates = []
            for fast, slow in zip(fast_params, slow_params):
                # slow <- slow + alpha * (fast - slow); fast <- slow
                slow_updates.append(
                    K.update(slow, slow + self.alpha * (fast - slow)))
                copy_updates.append(K.update(fast, slow))

            # Gets loss and metrics. Updates weights at each call.
            fast_train_function = K.function(
                inputs,
                [model.total_loss] + model.metrics_tensors,
                updates=fast_updates,
                name='fast_train_function',
                **model._function_kwargs)

            def F(inputs):
                # Fast step every call; slow/fast sync every k-th call.
                self.count += 1
                result = fast_train_function(inputs)
                if self.count % self.k == 0:
                    K.batch_get_value(slow_updates)
                    K.batch_get_value(copy_updates)
                return result

            model.train_function = F
Example #15
0
    def get_updates(self, loss, params):
        """Build Adam-style updates with a rectification factor on the adaptive step.

        ``flag`` is 1 where ``rho_t > 4`` and 0 otherwise.  Where it is 1 the
        bias-corrected momentum is scaled by ``r_t / (sqrt(v_hat) + epsilon)``;
        where it is 0 the bias-corrected momentum is used unscaled.

        Args:
            loss: Loss passed to ``self.get_gradients``.
            params: Parameters to optimise.

        Returns:
            The list of update ops, also stored in ``self.updates``.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        if self.initial_decay > 0:
            decay_steps = K.cast(self.iterations, K.dtype(self.decay))
            lr = lr * (1. / (1. + self.decay * decay_steps))

        t = K.cast(self.iterations, K.floatx()) + 1
        beta_1_t = K.pow(self.beta_1, t)
        beta_2_t = K.pow(self.beta_2, t)

        rho = 2 / (1 - self.beta_2) - 1
        rho_t = rho - 2 * t * beta_2_t / (1 - beta_2_t)
        rect_numerator = K.relu(rho_t - 4) * K.relu(rho_t - 2) * rho
        rect_denominator = (rho - 4) * (rho - 2) * rho_t
        r_t = K.sqrt(rect_numerator / rect_denominator)
        flag = K.cast(rho_t > 4, K.floatx())

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        self.weights = [self.iterations] + ms + vs

        for p, g, m, v in zip(params, grads, ms, vs):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)

            # Bias-corrected moment estimates.
            mhat_t = m_t / (1 - beta_1_t)
            vhat_t = K.sqrt(v_t / (1 - beta_2_t))

            # Adaptive scaling where flag == 1, plain momentum step otherwise.
            scale = flag * r_t / (vhat_t + self.epsilon) + (1 - flag)
            new_p = p - lr * mhat_t * scale

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates
Example #16
0
    def W_bar(self):
        """Return the kernel divided by an estimate of its largest singular value.

        The estimate ``sigma`` is obtained by ``self.Ip`` rounds of power
        iteration on the kernel flattened to ``(o, i * h * w)``, starting from
        ``self.u``.  The refreshed ``u`` is written back to ``self.u`` in the
        training phase only.

        Returns:
            ``self.kernel / sigma``.

        Raises:
            ValueError: If ``self.Ip`` is not at least 1.
        """
        # Validate before building any ops so a bad setting fails fast.
        if not self.Ip >= 1:
            raise ValueError(
                "The number of power iterations should be positive integer")

        # Spectrally Normalized Weight
        W_mat = K.permute_dimensions(
            self.kernel, (3, 2, 0, 1))  # (h, w, i, o) => (o, i, h, w)
        W_mat = K.reshape(W_mat, [K.shape(W_mat)[0], -1])  # (o, i * h * w)

        _u = self.u
        for _ in range(self.Ip):
            _v = _l2normalize(K.dot(_u, W_mat))
            _u = _l2normalize(K.dot(_v, K.transpose(W_mat)))

        sigma = K.sum(K.dot(_u, W_mat) * _v)

        # Register the update op: a bare K.update(...) whose result is
        # discarded is never executed when the backend builds a graph.
        self.add_update(K.update(self.u, K.in_train_phase(_u, self.u)))
        return self.kernel / sigma
Example #17
0
    def compute_wave_function_gradient_covariance_inverse_multiplication_with_iterative_solver(
            self, complex_vector, wave_function_jacobian_minus_mean=None):
        """Solve the covariance linear system for ``complex_vector`` by conjugate gradient.

        Args:
            complex_vector: Right-hand side; it is squeezed before use.
            wave_function_jacobian_minus_mean: Optional matrix ``J``.  When
                given, the operator applied is
                ``J^H (J v) / self.batch_size + self.diag_shift * v``;
                otherwise the product is delegated to
                ``self.get_stochastic_reconfiguration_matrix_vector_product_via_jvp``.

        Returns:
            The solution reshaped to a column, with gradients stopped.
            Evaluating it also stores the solver's iteration count and
            residual norm in the corresponding attributes.
        """
        rhs = tf.squeeze(complex_vector)
        num_of_complex_params_t = tf.shape(rhs)[:1]

        if wave_function_jacobian_minus_mean is None:

            def covariance_vector_product(vector):
                return self.get_stochastic_reconfiguration_matrix_vector_product_via_jvp(
                    vector)
        else:

            def covariance_vector_product(vector):
                jacobian_times_vector = tf.matmul(
                    wave_function_jacobian_minus_mean, vector)
                gram_product = tf.matmul(
                    wave_function_jacobian_minus_mean,
                    jacobian_times_vector,
                    adjoint_a=True)
                return gram_product / self.batch_size + self.diag_shift * vector

        # Square operator whose size is the number of complex parameters.
        operator = Operator(
            shape=tf.concat([num_of_complex_params_t] * 2, axis=0),
            dtype=self.predictions_keras_model.output.dtype,
            apply=covariance_vector_product)
        solution = conjugate_gradient(
            operator,
            rhs,
            tol=self.conjugate_gradient_tol,
            max_iter=self.iterative_solver_max_iterations)

        # Solver diagnostics, tied to the returned tensor below.
        diagnostics = [
            K.update(self.conjugate_gradient_iterations, solution.i),
            K.update(self.conjugate_gradient_residual_norm,
                     float_norm(solution.r)),
        ]
        with tf.control_dependencies(diagnostics):
            flat_gradient = tf.stop_gradient(tf.reshape(solution.x, (-1, 1)))
        return flat_gradient
Example #18
0
    def get_updates(self, loss, params):
        """Build Adam-style updates with decoupled weight decay.

        The weight-decay term ``eta_t * wd * p`` is subtracted from each
        parameter alongside the Adam step, where
        ``wd = self.wd * self.wd_normalizer`` and ``eta_t = lr / self.init_lr``.

        Args:
            loss: Loss passed to ``self.get_gradients``.
            params: Parameters to optimise.

        Returns:
            The list of update ops, also stored in ``self.updates``.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]
        wd = self.wd * self.wd_normalizer  # decoupled weight decay (4/6)

        lr = self.lr
        if self.initial_decay > 0:
            lr = lr * (1. /
                       (1. + self.decay *
                        math_ops.cast(self.iterations, K.dtype(self.decay))))
        eta_t = lr / self.init_lr  # decoupled weight decay (5/6)

        # `self.iterations` is already incremented once per step by the update
        # registered above.  Derive the 1-based step index from it directly; a
        # second `assign_add` here would advance the counter by two per step.
        t = math_ops.cast(self.iterations, K.floatx()) + 1
        # Bias corrections according to the Adam paper.
        lr_t = lr * (K.sqrt(1. - math_ops.pow(self.beta_2, t)) /
                     (1. - math_ops.pow(self.beta_1, t)))

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        self.weights = [self.iterations] + ms + vs

        for p, g, m, v in zip(params, grads, ms, vs):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * math_ops.square(g)
            p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
            p_t -= eta_t * wd * p  # decoupled weight decay (6/6)

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates
Example #19
0
 def apply_complex_gradient(self, flat_gradient):
     """Turn a flat complex gradient into update ops for the real and imaginary weights.

     The gradient is conjugated; its real part is split across
     ``self.model_real_weights`` and its imaginary part across
     ``self.model_imag_weights``.  Each weight is moved by ``self.lr`` times
     its share and passed through its ``constraint`` if it has one.

     Args:
         flat_gradient: Flat complex gradient.

     Returns:
         One update op per weight, real weights first.
     """
     conjugated = tf.conj(flat_gradient)
     real_gradients = column_to_tensors(self.model_real_weights,
                                        tf.math.real(conjugated))
     imag_gradients = column_to_tensors(self.model_imag_weights,
                                        tf.math.imag(conjugated))

     weights = self.model_real_weights + self.model_imag_weights
     gradients = real_gradients + imag_gradients

     updates = []
     for weight, gradient in zip(weights, gradients):
         candidate = weight + self.lr * gradient
         if getattr(weight, 'constraint', None) is not None:
             candidate = weight.constraint(candidate)
         updates.append(K.update(weight, candidate))
     return updates
Example #20
0
    def get_updates(self, loss, params):
        """Build Adam-style updates with decoupled weight decay.

        The weight-decay term ``eta_t * self.wd * p`` is subtracted from each
        parameter alongside the Adam step, where ``eta_t = lr / self.init_lr``.

        Args:
            loss: Loss passed to ``self.get_gradients``.
            params: Parameters to optimise.

        Returns:
            The list of update ops, also stored in ``self.updates``.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]
        # decoupled weight decay (4/6)
        wd = self.wd

        lr = self.lr
        if self.initial_decay > 0:
            # Rebind instead of `lr *= ...`: `lr` aliases `self.lr`, and the
            # in-place operator is not supported on resource variables.
            lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
                                                      K.dtype(self.decay))))
        # decoupled weight decay (5/6)
        eta_t = lr / self.init_lr

        # Bias-corrected learning rate for the current (1-based) step.
        t = K.cast(self.iterations, K.floatx()) + 1
        lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                     (1. - K.pow(self.beta_1, t)))

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        self.weights = [self.iterations] + ms + vs

        for p, g, m, v in zip(params, grads, ms, vs):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
            # decoupled weight decay (6/6)
            p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon) - eta_t * wd * p

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates
    def call(self, inputs):
        """Run the parent convolution with the kernel divided by ``sigma``.

        With ``self.conv_singular`` set, ``sigma`` comes from
        ``max_singular_val_for_convolution`` applied to the kernel itself;
        otherwise the kernel is flattened to two dimensions, passed to
        ``max_singular_val``, scaled and reshaped back.  In both cases the
        returned ``u`` is registered as an update of ``self.u``.

        Args:
            inputs: Passed unchanged to the parent ``call``.

        Returns:
            The output of the parent layer's ``call``.
        """
        if self.conv_singular:
            sigma, u_bar = max_singular_val_for_convolution(
                self.kernel,
                self.u,
                fully_differentiable=self.fully_diff_spectral,
                ip=self.spectral_iterations,
                padding=self.padding,
                strides=self.strides,
                data_format=self.data_format)
            kernel_sn = self.kernel / sigma
        else:
            kernel_shape = K.int_shape(self.kernel)
            # Merge the first three kernel axes into one.
            w = K.reshape(self.kernel,
                          (kernel_shape[0] * kernel_shape[1] * kernel_shape[2],
                           kernel_shape[3]))

            sigma, u_bar = max_singular_val(
                w,
                self.u,
                fully_differentiable=self.fully_diff_spectral,
                ip=self.spectral_iterations)

            kernel_sn = K.reshape(w / sigma, kernel_shape)

        self.add_update(K.update(self.u, u_bar))

        # Temporarily swap in the scaled kernel; always restore the original,
        # even if the parent call raises.
        kernel = self.kernel
        self.kernel = kernel_sn
        try:
            outputs = super(SNConv2D, self).call(inputs)
        finally:
            self.kernel = kernel

        return outputs
Example #22
0
    def get_updates(self, loss, params):
        """Build updates that accumulate gradients around a wrapped optimizer.

        Each accumulator is reset to zero when ``self.cond`` holds and has the
        current gradient added otherwise.  The wrapped optimizer's updates,
        except its first one, are appended, and its weights are added to
        ``self.weights``.

        Args:
            loss: Loss passed to ``self.get_gradients``.
            params: Parameters to optimise.

        Returns:
            The list of update ops, also stored in ``self.updates``.
        """
        # Own step counter always advances; the wrapped optimizer's counter
        # advances by the value of self.cond.
        self.updates = [
            K.update_add(self.iterations, 1),
            K.update_add(self.optimizer.iterations, K.constant(self.cond, "int64"))
        ]

        # accumulate gradients
        self.accum_grads = [
            K.zeros(K.int_shape(param), dtype=K.dtype(param)) for param in params
        ]
        grads = self.get_gradients(loss, params)
        for grad, accumulator in zip(grads, self.accum_grads):
            next_value = K.switch(self.cond, accumulator * 0, accumulator + grad)
            self.updates.append(K.update(accumulator, next_value))

        # Skip the wrapped optimizer's first update.
        wrapped_updates = self.optimizer.get_updates()
        self.updates.extend(wrapped_updates[1:])
        self.weights.extend(self.optimizer.weights)

        return self.updates
    def call(self, inputs):
        """Run the parent embedding lookup with the embeddings divided by ``sigma``.

        ``sigma`` and the new ``u`` come from ``max_singular_val`` applied to
        ``self.embeddings``; the new ``u`` is registered as an update of
        ``self.u``.

        Args:
            inputs: Passed unchanged to the parent ``call``.

        Returns:
            The output of the parent layer's ``call``.
        """
        w = self.embeddings
        sigma, u_bar = max_singular_val(
            w,
            self.u,
            fully_differentiable=self.fully_diff_spectral,
            ip=self.spectral_iterations)
        kernel_sn = w / sigma
        self.add_update(K.update(self.u, u_bar))

        # Temporarily swap in the scaled embeddings; always restore the
        # original, even if the parent call raises.
        embeddings = self.embeddings
        self.embeddings = kernel_sn
        try:
            outputs = super(SNEmbeding, self).call(inputs)
        finally:
            self.embeddings = embeddings

        return outputs
    def call(self,
             inputs,
             mask=None,
             training=None,
             initial_state=None,
             constants=None):
        """Iterate the cell ``self.niter`` times through ``rim``.

        At every iteration the cell input is concatenated (last axis) with
        the gradient of ``self.likelihood_fn`` with respect to that input,
        evaluated at the gradient-stopped output of ``self.output_layer``
        applied to the first state.

        Args:
            inputs: layer input; if a list, only its first element is used.
            mask: if a list, reduced to its first element. It is not passed
                on to ``rim`` in this method.
            training: forwarded to the cell only when its ``call`` accepts a
                ``training`` argument.
            initial_state: optional starting states. When omitted,
                ``self.states`` is used for stateful layers and
                ``self.get_initial_state(inputs)`` otherwise.
            constants: optional constants; the cell's ``call`` must accept a
                ``constants`` argument.

        Returns:
            ``self.output_layer`` applied to the output sequence (when
            ``self.return_sequences``) or to the last output. With
            ``self.return_state`` a list of that output followed by the
            final states is returned.

        Raises:
            ValueError: if the number of initial states differs from
                ``len(self.states)``, or if constants are supplied to a cell
                that does not accept them.
        """
        if isinstance(inputs, list):
            inputs = inputs[0]

        if initial_state is None:
            if self.stateful:
                initial_state = self.states
            else:
                initial_state = self.get_initial_state(inputs)

        if isinstance(mask, list):
            mask = mask[0]

        n_given, n_expected = len(initial_state), len(self.states)
        if n_given != n_expected:
            raise ValueError('Layer has ' + str(n_expected) +
                             ' states but was passed ' +
                             str(n_given) + ' initial states.')

        n_iterations = self.niter

        cell_kwargs = {}
        if generic_utils.has_arg(self.cell.call, 'training'):
            cell_kwargs['training'] = training

        if not constants:

            def step(cell_inputs, cell_states):
                return self.cell.call(cell_inputs, cell_states, **cell_kwargs)
        elif generic_utils.has_arg(self.cell.call, 'constants'):

            def step(cell_inputs, packed_states):
                # The trailing ``_num_constants`` entries are the constants;
                # everything before them is the actual cell state.
                cell_constants = packed_states[-self._num_constants:]
                cell_states = packed_states[:-self._num_constants]
                return self.cell.call(cell_inputs,
                                      cell_states,
                                      constants=cell_constants,
                                      **cell_kwargs)
        else:
            raise ValueError('RNN cell does not support constants')

        # Augment the RNN cell with the likelihood gradient
        def augmented_step(x, states):
            current_estimate = tf.stop_gradient(
                self.output_layer.call(states[0]))
            likelihood = self.likelihood_fn(x, current_estimate)
            likelihood_grad = tf.gradients(likelihood, x)[0]
            return step(tf.concat([x, likelihood_grad], axis=-1), states)

        last_output, outputs, states = rim(augmented_step,
                                           inputs,
                                           initial_state,
                                           constants=constants,
                                           niter=n_iterations)
        if self.stateful:
            state_updates = [
                K.update(self.states[idx], new_state)
                for idx, new_state in enumerate(states)
            ]
            self.add_update(state_updates, inputs=True)

        if self.return_sequences:
            # The sequence must have exactly six dimensions; the first two
            # are merged so the output layer sees a 5-D tensor, then the
            # result is split back into 6-D.
            _, nb, a, b, c, d = outputs.shape
            flattened = tf.reshape(outputs, [-1, a, b, c, d])
            decoded = self.output_layer.call(flattened)
            output = tf.reshape(decoded,
                                [-1, nb, a, b, c, decoded.shape[-1]])
        else:
            output = self.output_layer.call(last_output)

        if not self.return_state:
            return output

        if isinstance(states, (list, tuple)):
            final_states = list(states)
        else:
            final_states = [states]
        return [output] + final_states
Example #25
0
    def get_updates(self, loss, params):
        """Build the moment and parameter update ops for one training step.

        For each parameter the first and second moment estimates are
        refreshed, an optional weight-decay shrink is applied, and one of two
        step rules is selected with ``K.switch`` depending on whether the
        ``n_sma`` term exceeds 5: a step scaled by the square root of the
        second moment, or a plain momentum step.

        Args:
            loss: forwarded to ``self.get_gradients``.
            params: parameters to update.

        Returns:
            ``self.updates``: the iteration increment followed by the moment
            and parameter updates for every entry of ``params``.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        if self.initial_decay > 0:
            lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
                                                      K.dtype(self.decay))))

        t = K.cast(self.iterations, K.floatx()) + 1

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        self.weights = [self.iterations] + ms + vs

        # These terms depend only on the step count and beta_2, never on the
        # individual parameter, so they are built once instead of once per
        # parameter.
        beta2_t = self.beta_2 ** t
        n_sma_max = 2 / (1 - self.beta_2) - 1
        n_sma = n_sma_max - 2 * t * beta2_t / (1 - beta2_t)
        use_variance_step = n_sma > 5

        for p, g, m, v in zip(params, grads, ms, vs):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)

            # apply weight decay: shrink the parameter by
            # weight_decay * lr * p before the gradient step, if enabled.
            if self.weight_decay != 0.:
                p_ = p - self.weight_decay * lr * p
            else:
                p_ = p

            def gt_path():
                # Step scaled by the n_sma-based factor and divided by the
                # square root of the second moment.
                step_size = lr * K.sqrt(
                    (1 - beta2_t) * (n_sma - 4) / (n_sma_max - 4) * (n_sma - 2) / n_sma * n_sma_max /
                    (n_sma_max - 2)) / (1 - self.beta_1 ** t)

                denom = K.sqrt(v_t) + self.epsilon
                return p_ - step_size * (m_t / denom)

            def lt_path():
                # Momentum-only step: the second moment is not used.
                step_size = lr / (1 - self.beta_1 ** t)
                return p_ - step_size * m_t

            new_p = K.switch(use_variance_step, gt_path, lt_path)

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates
Example #26
0
  def call(self,
           inputs,
           mask=None,
           training=None,
           initial_state=None,
           constants=None):
    """Run the wrapped cell over ``inputs`` with ``K.rnn``.

    Args:
      inputs: layer input; if a list, only its first element is used.
      mask: optional mask; if a list, only its first element is used.
      training: forwarded to the cell only when its ``call`` accepts a
        ``training`` argument.
      initial_state: optional starting states. When omitted, ``self.states``
        is used for stateful layers and ``self.get_initial_state(inputs)``
        otherwise.
      constants: optional constants; the cell's ``call`` must accept a
        ``constants`` argument.

    Returns:
      The output sequence when ``self.return_sequences`` is set, else the
      last output. With ``self.return_state`` a list of that output followed
      by the final states is returned.

    Raises:
      ValueError: if the number of initial states differs from
        ``len(self.states)``, or if constants are supplied to a cell that
        does not accept them.
    """
    # note that the .build() method of subclasses MUST define
    # self.input_spec and self.state_spec with complete input shapes.
    if isinstance(inputs, list):
      inputs = inputs[0]

    if initial_state is None:
      initial_state = (self.states if self.stateful
                       else self.get_initial_state(inputs))

    if isinstance(mask, list):
      mask = mask[0]

    n_passed = len(initial_state)
    n_expected = len(self.states)
    if n_passed != n_expected:
      message = ('Layer has ' + str(n_expected) +
                 ' states but was passed ' +
                 str(n_passed) +
                 ' initial states.')
      raise ValueError(message)
    timesteps = K.int_shape(inputs)[1]

    cell_kwargs = ({'training': training}
                   if generic_utils.has_arg(self.cell.call, 'training')
                   else {})

    if constants:
      if not generic_utils.has_arg(self.cell.call, 'constants'):
        raise ValueError('RNN cell does not support constants')

      def step(inputs, states):
        # The last ``_num_constants`` entries carry the constants; the
        # entries before them are the real cell states.
        split = -self._num_constants
        return self.cell.call(inputs, states[:split],
                              constants=states[split:],
                              **cell_kwargs)
    else:
      def step(inputs, states):
        return self.cell.call(inputs, states, **cell_kwargs)

    last_output, outputs, states = K.rnn(step,
                                         inputs,
                                         initial_state,
                                         constants=constants,
                                         go_backwards=self.go_backwards,
                                         mask=mask,
                                         input_length=timesteps)
    if self.stateful:
      # Write the final states back so the next call can start from them.
      state_updates = [K.update(self.states[idx], new_state)
                       for idx, new_state in enumerate(states)]
      self.add_update(state_updates)

    output = outputs if self.return_sequences else last_output

    if not self.return_state:
      return output

    final_states = (list(states) if isinstance(states, (list, tuple))
                    else [states])
    return [output] + final_states
Example #27
0
    def call(self,
             inputs,
             mask=None,
             training=None,
             initial_state=None,
             constants=None):
        """Run the wrapped cell over ``inputs`` with ``K.rnn``.

        ``inputs``, ``initial_state`` and ``constants`` are first passed
        through ``self._process_inputs``; the values it returns are the ones
        used for the rest of the call.

        Args:
            inputs: layer input.
            mask: optional mask; if a list, only its first element is used.
            training: forwarded to the cell only when its ``call`` accepts a
                ``training`` argument.
            initial_state: optional starting states.
            constants: optional constants; the cell's ``call`` must accept a
                ``constants`` argument.

        Returns:
            The output sequence when ``self.return_sequences`` is set, else
            the last output. With ``self.return_state`` a list of that output
            followed by the final states is returned.

        Raises:
            ValueError: if constants are supplied to a cell whose ``call``
                does not accept them.
        """
        # note that the .build() method of subclasses MUST define
        # self.input_spec and self.state_spec with complete input shapes.
        processed = self._process_inputs(inputs, initial_state, constants)
        inputs, initial_state, constants = processed

        if isinstance(mask, list):
            mask = mask[0]
        timesteps = K.int_shape(inputs)[1]

        cell_kwargs = {}
        if generic_utils.has_arg(self.cell.call, 'training'):
            cell_kwargs.update(training=training)

        if not constants:

            def step(inputs, states):
                return self.cell.call(inputs, states, **cell_kwargs)
        elif generic_utils.has_arg(self.cell.call, 'constants'):

            def step(inputs, states):
                # Trailing entries are the constants, leading ones the states.
                cell_constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
                cell_states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
                return self.cell.call(inputs,
                                      cell_states,
                                      constants=cell_constants,
                                      **cell_kwargs)
        else:
            raise ValueError('RNN cell does not support constants')

        last_output, outputs, states = K.rnn(step,
                                             inputs,
                                             initial_state,
                                             constants=constants,
                                             go_backwards=self.go_backwards,
                                             mask=mask,
                                             input_length=timesteps)
        if self.stateful:
            # Persist the final states into the layer's state variables.
            state_updates = []
            for stored_state, new_state in zip(self.states, states):
                state_updates.append(K.update(stored_state, new_state))
            self.add_update(state_updates)

        output = outputs if self.return_sequences else last_output

        if not self.return_state:
            return output

        if isinstance(states, (list, tuple)):
            final_states = list(states)
        else:
            final_states = [states]
        return [output] + final_states
  def call(self,
           inputs,
           mask=None,
           training=None,
           initial_state=None,
           constants=None):
    """Run the wrapped cell over ``inputs`` with ``K.rnn``.

    Args:
      inputs: layer input; if a list, only its first element is used.
      mask: optional mask; if a list, only its first element is used.
      training: forwarded to the cell only when its ``call`` accepts a
        ``training`` argument.
      initial_state: optional starting states. When omitted, ``self.states``
        is used for stateful layers and ``self.get_initial_state(inputs)``
        otherwise.
      constants: optional constants; the cell's ``call`` must accept a
        ``constants`` argument.

    Returns:
      The output sequence when ``self.return_sequences`` is set, else the
      last output; it is flagged with ``_uses_learning_phase`` when the last
      output carries that flag. With ``self.return_state`` a list of that
      output followed by the final states is returned.

    Raises:
      ValueError: if the number of initial states differs from
        ``len(self.states)``, or if constants are supplied to a cell that
        does not accept them.
    """
    # note that the .build() method of subclasses MUST define
    # self.input_spec and self.state_spec with complete input shapes.
    if isinstance(inputs, list):
      inputs = inputs[0]

    if initial_state is None:
      if self.stateful:
        initial_state = self.states
      else:
        initial_state = self.get_initial_state(inputs)

    if isinstance(mask, list):
      mask = mask[0]

    received, expected = len(initial_state), len(self.states)
    if received != expected:
      raise ValueError('Layer has ' + str(expected) +
                       ' states but was passed ' +
                       str(received) +
                       ' initial states.')
    timesteps = K.int_shape(inputs)[1]

    extra_args = {}
    if generic_utils.has_arg(self.cell.call, 'training'):
      extra_args['training'] = training

    if constants:
      if not generic_utils.has_arg(self.cell.call, 'constants'):
        raise ValueError('RNN cell does not support constants')

      def step(inputs, states):
        # ``states`` arrives as real states followed by the constants.
        n_constants = self._num_constants
        return self.cell.call(inputs,
                              states[:-n_constants],
                              constants=states[-n_constants:],
                              **extra_args)
    else:
      def step(inputs, states):
        return self.cell.call(inputs, states, **extra_args)

    last_output, outputs, states = K.rnn(step,
                                         inputs,
                                         initial_state,
                                         constants=constants,
                                         go_backwards=self.go_backwards,
                                         mask=mask,
                                         input_length=timesteps)
    if self.stateful:
      # Store the final states on the layer for the next call.
      state_updates = [
          K.update(self.states[position], final_state)
          for position, final_state in enumerate(states)
      ]
      self.add_update(state_updates, inputs=True)

    output = outputs if self.return_sequences else last_output

    # Properly set learning phase
    if getattr(last_output, '_uses_learning_phase', False):
      output._uses_learning_phase = True

    if not self.return_state:
      return output

    if isinstance(states, (list, tuple)):
      state_list = list(states)
    else:
      state_list = [states]
    return [output] + state_list
Example #29
0
    def get_updates(self, loss, params):
        """Build the update ops for one step of this Adam-style optimizer.

        The learning rate ramps up linearly over ``self.warmup_steps`` and
        afterwards is interpolated linearly from ``self.lr`` towards
        ``self.min_lr`` over ``self.decay_steps``. Weight decay, when enabled,
        is added to the update direction either for every parameter or only
        for parameters whose name contains one of
        ``self.weight_decay_pattern``.

        Args:
            loss: forwarded to ``self.get_gradients``.
            params: parameters to update.

        Returns:
            ``self.updates``: the iteration increment followed by the moment
            and parameter updates for every entry of ``params``.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        step = K.cast(self.iterations, K.floatx()) + 1

        # Learning-rate schedule: warmup branch vs. decay branch.
        warmup_lr = self.lr * (step / self.warmup_steps)
        decay_progress = K.minimum(step, self.decay_steps) / self.decay_steps
        decayed_lr = self.min_lr + (self.lr - self.min_lr) * (1.0 - decay_progress)
        lr = K.switch(step <= self.warmup_steps, warmup_lr, decayed_lr)

        bias_correction = (K.sqrt(1. - K.pow(self.beta_2, step)) /
                           (1. - K.pow(self.beta_1, step)))
        lr_t = lr * bias_correction

        def make_slots(name_template, placeholder=False):
            # One slot variable per parameter; a placeholder slot has shape 1
            # instead of the parameter's shape.
            return [
                K.zeros(1 if placeholder else K.int_shape(param),
                        dtype=K.dtype(param),
                        name=name_template.format(index))
                for index, param in enumerate(params)
            ]

        ms = make_slots('m_{}')
        vs = make_slots('v_{}')
        vhats = make_slots('vh_{}', placeholder=not self.amsgrad)
        self.weights = [self.iterations] + ms + vs + vhats

        for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)

            # With amsgrad the running maximum of the second moment is used
            # in the denominator and written back to its slot.
            second_moment = K.maximum(vhat, v_t) if self.amsgrad else v_t
            direction = m_t / (K.sqrt(second_moment) + self.epsilon)
            if self.amsgrad:
                self.updates.append(K.update(vhat, second_moment))

            if self.initial_weight_decay > 0.0:
                patterns = self.weight_decay_pattern
                # No pattern list means every parameter is decayed; otherwise
                # only parameters whose name contains one of the patterns.
                if patterns is None or any(pat in p.name for pat in patterns):
                    direction = direction + self.weight_decay * p

            new_p = p - lr_t * direction

            self.updates.extend([K.update(m, m_t), K.update(v, v_t)])

            constraint = getattr(p, 'constraint', None)
            if constraint is not None:
                new_p = constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates