def apply_dense_on_grasssmann(grad_clip, grad_on_grassmann, grad_on_oblique,
                              var, learning_rate, times, delta):
    """One descent step for `var` on the Grassmann manifold (TensorFlow).

    Blends the Grassmann gradient with a direction transported from the
    oblique gradient, optionally clips the step by norm, and retracts the
    result back onto the manifold.

    Args:
        grad_clip: max step norm, or None to disable clipping.
        grad_on_grassmann: Riemannian gradient on the Grassmann manifold.
        grad_on_oblique: Riemannian gradient on the oblique manifold.
        var: current point on the manifold.
        learning_rate: scalar step size.
        times: iteration counter (used to anneal the mixing coefficient).
        delta: lower bound for the mixing coefficient `a`.

    Returns:
        The retracted updated variable.
    """
    # Mixing coefficient: starts large and decays toward `delta`.
    # NOTE(review): the oblique twin uses a single log (1 / log(times + 2));
    # confirm the double log here is intentional.
    a = tf.maximum(delta, 1 / tf.log(tf.log(times + 2)))
    # Unit direction obtained by projecting the oblique gradient onto the
    # Grassmann tangent space, rescaled to the Grassmann gradient's norm.
    n = (gutils.unit(gutils.grassmann_project(var, grad_on_oblique))
         * gutils.norm(grad_on_grassmann))
    b_1 = 2 * (1 - a) * gutils.xTy(grad_on_grassmann, n)
    b_2 = gutils.norm(grad_on_grassmann)
    b = b_1 / (b_2 + 1e-5)  # epsilon guards against division by zero
    if grad_clip is not None:
        h = learning_rate * (a * grad_on_grassmann + b * n)
        # Clipping preserves direction, so negating after the clip is
        # equivalent to clipping the negated step.
        h = -1 * gutils.clip_by_norm(h, grad_clip)
    else:
        h = -1 * learning_rate * (a * grad_on_grassmann + b * n)
    # Retraction maps the tangent step back onto the manifold
    # (project helper name `grassmann_retrction` kept as-is).
    var_update = gutils.grassmann_retrction(var, h)
    return var_update
def _apply_dense_on_oblique(grad_clip, grad_on_grassmann, grad_on_oblique,
                            var, learning_rate, times, delta):
    """One descent step for `var` on the oblique manifold (PyTorch).

    Mirror image of the Grassmann update: blends the oblique gradient with a
    direction transported from the Grassmann gradient, optionally clips by
    norm, and retracts back onto the manifold.

    Args:
        grad_clip: max step norm, or None to disable clipping.
        grad_on_grassmann: Riemannian gradient on the Grassmann manifold.
        grad_on_oblique: Riemannian gradient on the oblique manifold.
        var: current point on the manifold.
        learning_rate: scalar step size.
        times: iteration counter (tensor) used to anneal the mix coefficient.
        delta: lower bound (tensor) for the mixing coefficient `a`.

    Returns:
        The retracted updated variable.
    """
    # Mixing coefficient, lower-bounded by delta.  a=0.5
    a = torch.max(delta, 1 / torch.log(times + 2))
    # Unit direction from the Grassmann gradient projected onto the oblique
    # tangent space, rescaled to the oblique gradient's norm.
    n = (gutils.unit(gutils.oblique_project(var, grad_on_grassmann))
         * gutils.norm(grad_on_oblique))
    b_1 = 2 * (1 - a) * gutils.xTy(grad_on_oblique, n)
    b_2 = gutils.norm(grad_on_oblique)
    b = b_1 / (b_2 + 1e-5)  # epsilon guards against division by zero
    # Both branches of the original built the identical step; only the clip
    # differs, so the common expression is hoisted out of the `if`.
    h = -1 * learning_rate * (a * grad_on_oblique + b * n)
    if grad_clip is not None:
        h = gutils.clip_by_norm(h, grad_clip)
    # Retraction maps the tangent step back onto the manifold
    # (project helper name `oblique_retrction` kept as-is).
    var_update = gutils.oblique_retrction(var, h)
    return var_update
def _apply_dense(self, grad, var):
    """Riemannian Adam dense update on the Grassmann manifold (TF1 optimizer).

    Maintains first/second moment slots `m` and `v`, projects the gradient
    onto the tangent space at `var`, takes a bias-corrected Adam step along
    the geodesic, and parallel-transports the momentum to the new point.

    Args:
        grad: Euclidean gradient tensor for `var`.
        var: the variable being optimized (a point on the manifold).

    Returns:
        A grouped op applying the variable, `m`, and `v` updates.
    """
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    unity, _ = unit(var)  # renormalize for numerical stability
    h = gproj(unity, grad)  # project gradient onto the tangent space
    if self._grad_clip_t is not None:
        h_hat = clip_by_norm(h, self._grad_clip_t)
    else:
        h_hat = h
    # Standard Adam moment updates on the (clipped) tangent gradient.
    mnew = self._beta1_t * m + (1.0 - self._beta1_t) * h_hat
    vnew = self._beta2_t * v + (1.0 - self._beta2_t) * xTy(h_hat, h_hat)
    # Bias correction folded into the step size.
    alpha = tf.sqrt(1 - self._beta2_power) / (1. - self._beta1_power)
    deltas = (-alpha * self._lr_t) * mnew / tf.sqrt(vnew + self._epsilon_t)
    # Move along the geodesic (exponential map), then parallel-transport the
    # momentum so it stays in the tangent space of the new point.
    var_update = tf.assign(var, gexp(unity, deltas))
    m_update = tf.assign(m, gpt2(unity, mnew, deltas))
    v_update = tf.assign(v, vnew)
    return tf.group(*[var_update, m_update, v_update])
def step(self, closure=None):
    """Performs a single optimization step.

    Parameter groups flagged with ``grassmann=True`` get a Riemannian Adam
    update on the Grassmann manifold; all other groups get standard SGD with
    momentum/weight decay (same semantics as ``torch.optim.SGD``).

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.

    Returns:
        The loss returned by ``closure``, or None.
    """
    loss = None
    if closure is not None:
        loss = closure()
    for group in self.param_groups:
        if group['grassmann']:
            self._grassmann_adam_step(group)
        else:
            self._euclidean_sgd_step(group)
    return loss

def _grassmann_adam_step(self, group):
    """Riemannian Adam update for one parameter group on the Grassmann manifold."""
    beta1 = group['momentum']
    beta2 = group['beta2']
    epsilon = group['epsilon']
    grad_clip = group['grad_clip']
    omega = group['omega']
    for p in group['params']:
        if p.grad is None:
            continue
        # Each parameter is flattened to (rows, -1); rows are manifold points.
        unity, _ = unit(p.data.view(p.size()[0], -1))
        g = p.grad.data.view(p.size()[0], -1)
        if omega != 0:
            # Orthogonality regularizer:
            # L = |Y'Y - I|^2 / 2 = |YY' - I|^2 / 2 + c, dL/dY = 2(YY'Y - Y).
            # (modern add_ signature: tensor first, scalar via alpha=)
            g.add_(torch.mm(torch.mm(unity, unity.t()), unity) - unity,
                   alpha=2 * omega)
        h = gproj(unity, g)  # project gradient onto the tangent space
        if grad_clip is not None:
            h_hat = clip_by_norm(h, grad_clip)
        else:
            h_hat = h
        param_state = self.state[p]
        if 'm_buffer' not in param_state:
            size = p.size()
            # Allocate moment buffers directly on p's device and dtype
            # (the original allocated CPU float32 and then .cuda()-copied).
            param_state['m_buffer'] = torch.zeros(
                size[0], int(np.prod(size[1:])),
                dtype=p.dtype, device=p.device)
            param_state['v_buffer'] = torch.zeros(
                size[0], 1, dtype=p.dtype, device=p.device)
            param_state['beta1_power'] = beta1
            param_state['beta2_power'] = beta2
        m = param_state['m_buffer']
        v = param_state['v_buffer']
        beta1_power = param_state['beta1_power']
        beta2_power = param_state['beta2_power']
        # Adam moment updates on the (clipped) tangent gradient.
        mnew = beta1 * m + (1.0 - beta1) * h_hat
        vnew = beta2 * v + (1.0 - beta2) * xTy(h_hat, h_hat)
        # Bias correction folded into the step size.
        alpha = np.sqrt(1. - beta2_power) / (1. - beta1_power)
        deltas = mnew / vnew.add(epsilon).sqrt()
        deltas.mul_(-alpha * group['lr'])
        # Retract onto the manifold, then parallel-transport the momentum so
        # it lives in the tangent space at the new point.
        p.data.copy_(gexp(unity, deltas).view(p.size()))
        m.copy_(gpt2(unity, mnew, deltas))
        v.copy_(vnew)
        param_state['beta1_power'] *= beta1
        param_state['beta2_power'] *= beta2

def _euclidean_sgd_step(self, group):
    """Plain SGD with momentum/dampening/Nesterov, as in torch.optim.SGD."""
    momentum = group['momentum']
    weight_decay = group['weight_decay']
    dampening = group['dampening']
    nesterov = group['nesterov']
    for p in group['params']:
        if p.grad is None:
            continue
        d_p = p.grad.data
        if weight_decay != 0:
            d_p.add_(p.data, alpha=weight_decay)
        if momentum != 0:
            param_state = self.state[p]
            if 'momentum_buffer' not in param_state:
                buf = param_state['momentum_buffer'] = d_p.clone()
            else:
                buf = param_state['momentum_buffer']
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf
        p.data.add_(d_p, alpha=-group['lr'])