Ejemplo n.º 1
0
    def get_updates(self, loss, params):
        """Assemble the backend update ops for one optimisation step.

        The step is Adam-shaped: a first-moment accumulator ``m`` tracks the
        gradient, while the second-moment accumulator ``v`` tracks the
        squared difference between the gradient and the fresh first moment
        (``(g - m_t) ** 2``). When ``self.amsgrad`` is set, the running
        element-wise maximum of ``v`` is used in the denominator instead.

        Args:
            loss: Loss tensor, forwarded to ``self.get_gradients``.
            params: Sequence of variables to optimise.

        Returns:
            ``self.updates``: the iteration increment followed, for each
            parameter, by the updates of its slot variables and of the
            parameter itself.
        """
        gradients = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        # Optional inverse-time decay of the base learning rate.
        step_size = self.learning_rate
        if self.initial_decay > 0:
            elapsed = K.cast(self.iterations, K.dtype(self.decay))
            step_size = step_size * (1. / (1. + self.decay * elapsed))

        # Bias-corrected step size for the current (1-based) time step.
        t = K.cast(self.iterations, K.floatx()) + 1
        correction = (K.sqrt(1. - K.pow(self.beta_2, t)) /
                      (1. - K.pow(self.beta_1, t)))
        lr_t = step_size * correction

        def make_slots(prefix):
            # One zero-initialised variable per parameter, same shape/dtype.
            return [
                K.zeros(K.int_shape(var),
                        dtype=K.dtype(var),
                        name=prefix + str(idx))
                for idx, var in enumerate(params)
            ]

        first_moments = make_slots('m_')
        second_moments = make_slots('v_')
        if self.amsgrad:
            max_moments = make_slots('vhat_')
        else:
            # Without AMSGrad, length-1 placeholders are still created and
            # listed in self.weights.
            max_moments = [
                K.zeros(1, name='vhat_' + str(idx))
                for idx, _ in enumerate(params)
            ]
        self.weights = ([self.iterations] + first_moments + second_moments +
                        max_moments)

        slots = zip(params, gradients, first_moments, second_moments,
                    max_moments)
        for param, grad, m, v, vhat in slots:
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * grad
            v_t = ((self.beta_2 * v) +
                   (1. - self.beta_2) * K.square(grad - m_t))

            # Pick the second-moment estimate used to scale the step.
            scale_sq = K.maximum(vhat, v_t) if self.amsgrad else v_t
            new_p = param - lr_t * m_t / (K.sqrt(scale_sq) + self.epsilon)

            if self.amsgrad:
                self.updates.append(K.update(vhat, scale_sq))
            self.updates.extend((K.update(m, m_t), K.update(v, v_t)))

            # Apply the parameter's constraint, if it declares one.
            constraint = getattr(param, 'constraint', None)
            if constraint is not None:
                new_p = constraint(new_p)

            self.updates.append(K.update(param, new_p))
        return self.updates
Ejemplo n.º 2
0
        def get_updates(self, loss, params):
            """Build update ops that apply the parent step only periodically.

            A float flag ``cond`` is 1.0 when
            ``self.iterations % self.grad_accum_steps == 0`` and 0.0
            otherwise. While the parent optimizer's ``get_updates`` runs,
            ``K.update`` is temporarily replaced so that every update it
            creates keeps the old value when ``cond`` is 0 and takes the new
            value when ``cond`` is 1. Afterwards the gradient accumulators in
            ``self.accum_grads`` are updated: they restart from the current
            gradient when ``cond`` is 1 and add the current gradient to the
            running sum otherwise.

            Args:
                loss: Loss tensor, forwarded to ``self.get_gradients`` and to
                    the parent ``get_updates``.
                params: Sequence of variables to optimise.

            Returns:
                The list of accumulator update ops, created under a control
                dependency on the parent optimizer's updates.
            """
            # Whether the parent update should actually be applied this step.
            cond = K.equal(self.iterations % self.grad_accum_steps, 0)
            cond = K.cast(cond, K.floatx())
            # Fetch the gradients.
            grads = self.get_gradients(loss, params)
            self.accum_grads = [
                K.zeros(shape=K.int_shape(p),
                        dtype=K.dtype(p),
                        name='accum_grad_{}'.format(i))
                for i, p in enumerate(params)
            ]

            old_update = K.update

            def new_update(x, new_x):
                # Blend so that x is left unchanged whenever cond == 0.
                new_x = cond * new_x + (1 - cond) * x
                return old_update(x, new_x)

            K.update = new_update
            try:
                updates = super(NewOptimizer, self).get_updates(loss, params)
            finally:
                # Always restore the real K.update, even if the parent raises;
                # otherwise the backend module would stay patched globally.
                K.update = old_update

            # Accumulate gradients once the parent updates have run.
            with K.control_dependencies(updates):
                acc_updates = [
                    K.update(ag, g + (1 - cond) * ag)
                    for ag, g in zip(self.accum_grads, grads)
                ]

            return acc_updates