def _apply_dense_on_obilique_with_noise(self, grad, var, seed):
    g = gutils.obilique_project(var, grad)
    g_norm = gutils.norm(g)
    # Mixing coefficients: once the projected gradient is large enough,
    # shift weight from the noise term onto the gradient term. The branch
    # depends on a tensor, so it must go through tf.cond rather than a
    # Python `if`.
    a = tf.cond(
        g_norm >= 1 / self._times,
        lambda: 1 - 1 / (tf.square(self._times) * tf.square(g_norm)),
        lambda: 1 / tf.square(self._times))
    b = 1 / tf.square(self._times)
    dim = grad.get_shape()[0]
    noise = tf.truncated_normal([dim, dim],
                                mean=0.0,
                                stddev=1.0,
                                dtype=tf.float32,
                                seed=seed,
                                name="random_noise")
    h = -self._learning_rate_t * (a * g + b * noise)
    if self._grad_clip is not None:
        h = gutils.clip_by_norm(h, self._grad_clip_t)
    var_new = gutils.grassmann_retrction(var, h)
    return var_new
def _apply_dense_on_oblique_with_noise(grad_clip, grad, var, seed,
                                       learning_rate, times):
    g = gutils.oblique_project(var, grad)
    g_norm = gutils.norm(g)
    #a = tf.minimum(1 - 1 / (tf.square(times + 1) * tf.square(g_norm) + 1e-5), 1 / tf.square(times + 1))
    a = 1.0
    # Noise weight decays as 1 / (t + 1)^2.
    b = 1 / tf.square(times + 1)
    dim = tf.convert_to_tensor(grad.get_shape()[0], dtype=tf.int32)
    noise = tf.truncated_normal([dim, 1],
                                mean=0.0,
                                stddev=0.0001,
                                dtype=tf.float32,
                                seed=seed,
                                name="random_noise")
    h = -learning_rate * (a * g + b * noise)
    if grad_clip is not None:
        h = gutils.clip_by_norm(h, grad_clip)
    var_new = gutils.grassmann_retrction(var, h)
    return var_new
def _apply_dense_on_oblique_o(grad_clip, grad_on_oblique, var,
                              learning_rate, times, delta):
    a = tf.maximum(delta, 1)  #/(tf.log(times+2))
    h = -learning_rate * (a * grad_on_oblique)
    if grad_clip is not None:
        h = gutils.clip_by_norm(h, grad_clip)
    var_update = gutils.oblique_retrction(var, h)
    return var_update
def apply_dense_on_grasssmann_g(grad_clip, grad_on_grassmann, var,
                                learning_rate, times, delta):
    a = tf.maximum(delta, 1)  #/ (tf.log(times+2))
    h = -learning_rate * (a * grad_on_grassmann)
    if grad_clip is not None:
        h = gutils.clip_by_norm(h, grad_clip)
    var_update = gutils.grassmann_retrction(var, h)
    return var_update
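# ---------------------------------------------------------------------------
# The routines above lean on a handful of gutils primitives that are not
# defined in this file. The following is a minimal sketch of what they are
# assumed to compute, written against PyTorch, treating each Grassmann
# parameter as a single unit-Frobenius-norm matrix and each Oblique parameter
# as a matrix with unit-norm rows. The real gutils modules may differ in
# layout and edge-case handling (the TF and PyTorch variants also differ in
# small ways, e.g. whether unit returns the norm alongside). Names, including
# the `retrction` spelling, match the call sites in this file.
# ---------------------------------------------------------------------------
import torch

def norm(x):
    # Frobenius norm, returned as a 0-dim tensor.
    return x.norm()

def unit(x, eps=1e-8):
    # Normalize x and also return its norm, matching call sites of the
    # form `unity, _ = gutils.unit(...)`.
    n = x.norm()
    return x / (n + eps), n

def xTy(x, y):
    # Euclidean inner product <x, y>.
    return (x * y).sum()

def grassmann_project(x, g):
    # Tangent projection: remove the component of g along x.
    return g - x * xTy(x, g)

def grassmann_retrction(x, h):
    # Step, then renormalize back onto the manifold.
    y = x + h
    return y / y.norm()

def oblique_project(x, g):
    # Oblique manifold: every row of x has unit norm, so project row-wise.
    return g - x * (x * g).sum(dim=1, keepdim=True)

def oblique_retrction(x, h):
    # Row-wise renormalization after the step.
    y = x + h
    return y / y.norm(dim=1, keepdim=True).clamp(min=1e-8)

def clip_by_norm(x, clip):
    # Rescale x so its norm is at most clip.
    n = x.norm()
    return x * (clip / n) if n > clip else x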
def _apply_dense(self, grad, var):
    mom = self.get_slot(var, "momentum")
    unity, _ = unit(var)  # for numerical stability
    h = gproj(unity, grad)
    if self._grad_clip_t is not None:
        h_hat = clip_by_norm(h, self._grad_clip_t)
    else:
        h_hat = h
    mom_new = self._momentum_t * mom - self._learning_rate_t * h_hat
    var_update = tf.assign(var, gexp(unity, mom_new))
    mom_update = tf.assign(mom, gpt(unity, mom_new))
    return tf.group(*[var_update, mom_update])
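# _apply_dense above only makes sense inside a tf.train.Optimizer subclass
# that creates the "momentum" slot and materializes the *_t tensors. A
# minimal sketch of that plumbing, assuming the TF1 optimizer API; the class
# name and constructor signature here are hypothetical.
import tensorflow as tf  # TF1-style tf.train.Optimizer API assumed

class GrassmannMomentumOptimizer(tf.train.Optimizer):
    # Hypothetical wrapper: just the plumbing _apply_dense relies on.
    def __init__(self, learning_rate, momentum, grad_clip=None,
                 use_locking=False, name="GrassmannMomentum"):
        super(GrassmannMomentumOptimizer, self).__init__(use_locking, name)
        self._learning_rate = learning_rate
        self._momentum = momentum
        self._grad_clip = grad_clip

    def _create_slots(self, var_list):
        # One tangent-space momentum buffer per variable.
        for v in var_list:
            self._zeros_slot(v, "momentum", self._name)

    def _prepare(self):
        # Materialize the *_t tensors used by _apply_dense.
        self._learning_rate_t = tf.convert_to_tensor(self._learning_rate)
        self._momentum_t = tf.convert_to_tensor(self._momentum)
        self._grad_clip_t = (None if self._grad_clip is None
                             else tf.convert_to_tensor(self._grad_clip))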
def apply_dense_on_grasssmann(grad_clip, grad_on_grassmann, grad_on_oblique,
                              var, learning_rate, times, delta):
    a = tf.maximum(delta, 1 / tf.log(tf.log(times + 2)))
    n = gutils.unit(gutils.grassmann_project(
        var, grad_on_oblique)) * gutils.norm(grad_on_grassmann)
    b_1 = 2 * (1 - a) * gutils.xTy(grad_on_grassmann, n)
    b_2 = gutils.norm(grad_on_grassmann)
    b = b_1 / (b_2 + 1e-5)
    h = -learning_rate * (a * grad_on_grassmann + b * n)
    if grad_clip is not None:
        h = gutils.clip_by_norm(h, grad_clip)
    var_update = gutils.grassmann_retrction(var, h)
    return var_update
def _apply_dense_on_oblique(grad_clip, grad_on_grassmann, grad_on_oblique,
                            var, learning_rate, times, delta):
    a = torch.max(delta, 1 / torch.log(times + 2))
    # a=0.5
    n = gutils.unit(gutils.oblique_project(
        var, grad_on_grassmann)) * gutils.norm(grad_on_oblique)
    b_1 = 2 * (1 - a) * gutils.xTy(grad_on_oblique, n)
    b_2 = gutils.norm(grad_on_oblique)
    b = b_1 / (b_2 + 1e-5)
    h = -learning_rate * (a * grad_on_oblique + b * n)
    if grad_clip is not None:
        h = gutils.clip_by_norm(h, grad_clip)
    var_update = gutils.oblique_retrction(var, h)
    return var_update
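# apply_dense_on_grasssmann and _apply_dense_on_oblique are TF/PyTorch twins
# of the same blended step, with the roles of the two manifolds swapped.
# Writing g for the gradient projected onto the manifold being updated and
# n for the unit direction of the other manifold's projected gradient
# rescaled to |g|, both compute
#     h = -lr*(a*g + b*n),    b = 2*(1-a)*<g,n> / (|g| + 1e-5)
# with the mixing weight a annealed toward its floor delta (the exact log
# schedule differs slightly between the two versions), then retract h onto
# the manifold. Since |n| = |g| by construction, b is essentially the
# normalized alignment between the two gradients.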
def _apply_dense_on_oblique_with_noise(grad_clip, grad, var,
                                       learning_rate, times, variance):
    g = gutils.oblique_project(var, grad)
    #g_norm = gutils.norm(g)
    #a = tf.minimum(1 - 1 / (tf.square(times + 1) * tf.square(g_norm) + 1e-5), 1 / tf.square(times + 1))
    a = 1.0
    b = 1 / torch.square(times + 1)
    # Isotropic noise projected onto the tangent space at var, drawn as a
    # column vector so it broadcasts against g.
    noise = variance * gutils.oblique_project(
        var, torch.randn(var.size()[0], 1, device=var.device))
    h = -learning_rate * (a * g + b * noise)
    if grad_clip is not None:
        h = gutils.clip_by_norm(h, grad_clip)
    var_new = gutils.grassmann_retrction(var, h)
    return var_new
def _apply_dense_on_obilique(self, grad_on_grassmann, grad_on_obilique, var):
    a = tf.maximum(self._delta_t, 1 / tf.square(self._times))
    n = gutils.obilique_project(var, grad_on_grassmann)
    b_1 = 2 * (1 - a) * tf.matmul(tf.transpose(grad_on_obilique), n)
    b_2 = gutils.norm(n)
    # Small epsilon guards against a vanishing projected gradient.
    b = b_1 / (b_2 + 1e-5)
    h = -self._learning_rate_t * (
        a * grad_on_obilique + b * n)
    if self._grad_clip is not None:
        h = gutils.clip_by_norm(h, self._grad_clip_t)
    var_update = gutils.obilique_retrction(var, h)
    return var_update
def _apply_dense(self, grad, var):
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    unity, _ = unit(var)  # for numerical stability
    h = gproj(unity, grad)
    if self._grad_clip_t is not None:
        h_hat = clip_by_norm(h, self._grad_clip_t)
    else:
        h_hat = h
    mnew = self._beta1_t * m + (1.0 - self._beta1_t) * h_hat
    vnew = self._beta2_t * v + (1.0 - self._beta2_t) * xTy(h_hat, h_hat)
    alpha = tf.sqrt(1 - self._beta2_power) / (1. - self._beta1_power)
    deltas = (-alpha * self._lr_t) * mnew / tf.sqrt(vnew + self._epsilon_t)
    var_update = tf.assign(var, gexp(unity, deltas))
    m_update = tf.assign(m, gpt2(unity, mnew, deltas))
    v_update = tf.assign(v, vnew)
    return tf.group(*[var_update, m_update, v_update])
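# The slots above implement Adam transplanted onto the Grassmann manifold:
# the moments accumulate the clipped Riemannian gradient h_hat, the step
# carries Adam's bias correction, and gpt2 parallel-transports the first
# moment into the tangent space at the new point:
#     m_new = beta1*m + (1-beta1)*h_hat
#     v_new = beta2*v + (1-beta2)*<h_hat,h_hat>
#     delta = -lr * sqrt(1-beta2^t)/(1-beta1^t) * m_new/sqrt(v_new+eps)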
def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        #momentum = group['momentum']
        manifold = group['manifold']
        if manifold != "None":
            grad_clip = group['grad_clip']
            # The first half of the group holds the Grassmann parameters,
            # the second half their Oblique counterparts, in matching order.
            half = len(group['params']) // 2
            for i in range(half):
                p_grassmann = group['params'][i]
                p_oblique = group['params'][i + half]
                if p_grassmann.grad is None or p_oblique.grad is None:
                    continue
                unity_grassmann, _ = gutils.unit(
                    p_grassmann.data.view(p_grassmann.size()[0], -1))
                unity_oblique, _ = gutils.unit(
                    p_oblique.data.view(p_oblique.size()[0], -1))
                grad_grassmann = p_grassmann.grad.data.view(
                    p_grassmann.size()[0], -1)
                grad_oblique = p_oblique.grad.data.view(
                    p_oblique.size()[0], -1)
                # if omega != 0:
                #     L=|Y'Y-I|^2/2=|YY'-I|^2/2+c
                #     dL/dY=2(YY'Y-Y)
                #     g.add_(2*omega, torch.mm(torch.mm(unity, unity.t()), unity) - unity)
                h_grassmann = gutils.grassmann_project(
                    unity_grassmann, grad_grassmann)
                h_oblique = gutils.oblique_project(unity_oblique, grad_oblique)
                if grad_clip is not None:
                    h_hat_grassmann = gutils.clip_by_norm(h_grassmann, grad_clip)
                    h_hat_oblique = gutils.clip_by_norm(h_oblique, grad_clip)
                else:
                    h_hat_grassmann = h_grassmann
                    h_hat_oblique = h_oblique
                # param_state = self.state[p]
                # if 'momentum_buffer' not in param_state:
                #     param_state['momentum_buffer'] = torch.zeros(h_hat.size())
                #     if p.is_cuda:
                #         param_state['momentum_buffer'] = param_state['momentum_buffer'].cuda()
                # mom = param_state['momentum_buffer']
                # mom_new = momentum*mom - group['lr']*h_hat
                # Descend: retract along the negative Riemannian gradient.
                p_grassmann.data.copy_(
                    gutils.grassmann_retrction(
                        unity_grassmann,
                        -group['lr'] * h_hat_grassmann).view(p_grassmann.size()))
                p_oblique.data.copy_(
                    gutils.oblique_retrction(
                        unity_oblique,
                        -group['lr'] * h_hat_oblique).view(p_oblique.size()))
        else:
            # This routine is from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
            weight_decay = group['weight_decay']
            #dampening = group['dampening']
            #nesterov = group['nesterov']
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)
                #if momentum != 0:
                #    param_state = self.state[p]
                #    if 'momentum_buffer' not in param_state:
                #        buf = param_state['momentum_buffer'] = d_p.clone()
                #    else:
                #        buf = param_state['momentum_buffer']
                #        buf.mul_(momentum).add_(1 - dampening, d_p)
                #    if nesterov:
                #        d_p = d_p.add(momentum, buf)
                #    else:
                #        d_p = buf
                p.data.add_(-group['lr'], d_p)
    return loss
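# A hypothetical construction of the param groups for the combined optimizer
# above; the class name SGDGO is an assumption, and only the group keys read
# by step() are grounded in the code. The manifold group must list the
# Grassmann parameters first and their Oblique partners second, in matching
# order, since step() pairs params[i] with params[i + half]; any 'manifold'
# value other than the string "None" selects the manifold path.
optimizer = SGDGO([
    {'params': grassmann_params + oblique_params,
     'manifold': 'grassmann_oblique',   # anything but "None"
     'lr': 0.1, 'grad_clip': 0.5},
    {'params': euclidean_params,
     'manifold': 'None',
     'lr': 0.01, 'weight_decay': 1e-4},
])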
def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        momentum = group['momentum']
        grassmann = group['grassmann']
        if grassmann:
            grad_clip = group['grad_clip']
            omega = group['omega']
            for p in group['params']:
                if p.grad is None:
                    continue
                unity, _ = unit(p.data.view(p.size()[0], -1))
                g = p.grad.data.view(p.size()[0], -1)
                if omega != 0:
                    # L=|Y'Y-I|^2/2=|YY'-I|^2/2+c
                    # dL/dY=2(YY'Y-Y)
                    g.add_(2*omega, torch.mm(torch.mm(unity, unity.t()), unity) - unity)
                h = gproj(unity, g)
                if grad_clip is not None:
                    h_hat = clip_by_norm(h, grad_clip)
                else:
                    h_hat = h
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    param_state['momentum_buffer'] = torch.zeros(h_hat.size())
                    if p.is_cuda:
                        param_state['momentum_buffer'] = param_state['momentum_buffer'].cuda()
                mom = param_state['momentum_buffer']
                mom_new = momentum*mom - group['lr']*h_hat
                p.data.copy_(gexp(unity, mom_new).view(p.size()))
                mom.copy_(gpt(unity, mom_new))
        else:
            # This routine is from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
            weight_decay = group['weight_decay']
            dampening = group['dampening']
            nesterov = group['nesterov']
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = d_p.clone()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(1 - dampening, d_p)
                    if nesterov:
                        d_p = d_p.add(momentum, buf)
                    else:
                        d_p = buf
                p.data.add_(-group['lr'], d_p)
    return loss
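# The step() above doubles as plain SGD when 'grassmann' is False, so one
# optimizer can mix both regimes. A hypothetical instantiation; the class
# name SGDG is an assumption, and the keys are the ones step() reads.
optimizer = SGDG([
    {'params': orthogonal_params, 'grassmann': True,
     'lr': 0.2, 'momentum': 0.9, 'grad_clip': 0.1, 'omega': 0.1},
    {'params': other_params, 'grassmann': False,
     'lr': 0.01, 'momentum': 0.9, 'weight_decay': 1e-4,
     'dampening': 0, 'nesterov': False},
])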
def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        grassmann = group['grassmann']
        if grassmann:
            beta1 = group['momentum']
            beta2 = group['beta2']
            epsilon = group['epsilon']
            grad_clip = group['grad_clip']
            omega = group['omega']
            for p in group['params']:
                if p.grad is None:
                    continue
                unity, _ = unit(p.data.view(p.size()[0], -1))
                g = p.grad.data.view(p.size()[0], -1)
                if omega != 0:
                    # L=|Y'Y-I|^2/2=|YY'-I|^2/2+c
                    # dL/dY=2(YY'Y-Y)
                    g.add_(2*omega, torch.mm(torch.mm(unity, unity.t()), unity) - unity)
                h = gproj(unity, g)
                if grad_clip is not None:
                    h_hat = clip_by_norm(h, grad_clip)
                else:
                    h_hat = h
                param_state = self.state[p]
                if 'm_buffer' not in param_state:
                    size = p.size()
                    param_state['m_buffer'] = torch.zeros([size[0], int(np.prod(size[1:]))])
                    param_state['v_buffer'] = torch.zeros([size[0], 1])
                    if p.is_cuda:
                        param_state['m_buffer'] = param_state['m_buffer'].cuda()
                        param_state['v_buffer'] = param_state['v_buffer'].cuda()
                    param_state['beta1_power'] = beta1
                    param_state['beta2_power'] = beta2
                m = param_state['m_buffer']
                v = param_state['v_buffer']
                beta1_power = param_state['beta1_power']
                beta2_power = param_state['beta2_power']

                mnew = beta1*m + (1.0-beta1)*h_hat
                vnew = beta2*v + (1.0-beta2)*xTy(h_hat, h_hat)

                alpha = np.sqrt(1. - beta2_power) / (1. - beta1_power)
                deltas = mnew / vnew.add(epsilon).sqrt()
                deltas.mul_(-alpha*group['lr'])

                p.data.copy_(gexp(unity, deltas).view(p.size()))
                m.copy_(gpt2(unity, mnew, deltas))
                v.copy_(vnew)

                param_state['beta1_power'] *= beta1
                param_state['beta2_power'] *= beta2
        else:
            momentum = group['momentum']
            weight_decay = group['weight_decay']
            dampening = group['dampening']
            nesterov = group['nesterov']
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = d_p.clone()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(1 - dampening, d_p)
                    if nesterov:
                        d_p = d_p.add(momentum, buf)
                    else:
                        d_p = buf
                p.data.add_(-group['lr'], d_p)
    return loss
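# The Adam variant plugs into a standard training loop unchanged. A
# hypothetical sketch; AdamG, model, criterion, and loader are assumed names,
# and 'momentum' doubles as beta1 in the Grassmann branch above.
optimizer = AdamG([
    {'params': orthogonal_params, 'grassmann': True, 'lr': 0.2,
     'momentum': 0.9, 'beta2': 0.99, 'epsilon': 1e-8,
     'grad_clip': 0.1, 'omega': 0.1},
    {'params': other_params, 'grassmann': False, 'lr': 0.01,
     'momentum': 0.9, 'weight_decay': 1e-4, 'dampening': 0,
     'nesterov': False},
])
for inputs, targets in loader:
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()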