def get_gradients(model, data_x):
    """Return, for each sample, the gradient of the model output w.r.t. the input."""
    gs = []
    for i in range(len(data_x)):
        xi = data_x[i].clone().detach().requires_grad_(True)  # keeps device and dtype
        g = gradient(model(xi)[0], xi).detach()
        gs.append(g)
    return torch.stack(gs)
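# A minimal usage sketch for get_gradients (assumption: the stand-in model and
# shapes below are illustrative; any model whose per-sample output is indexable
# with [0] works). Each row of the result is the input-space gradient
# d model(x)[0] / dx of one sample.
def _example_get_gradients():  # illustrative only, not used by the training code
    model = torch.nn.Linear(3, 1)
    xs = torch.randn(4, 3)
    return get_gradients(model, xs)  # shape (4, 3): one input gradient per sample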
def n_effective(f, x, n_derive=1):
    """Count the number of independent directions spanned, in parameter space,
    by the gradients of f(x) and of its random directional derivatives up to
    order `n_derive`, evaluated on the points of `x`."""
    assert x.dtype == torch.float64

    # 0th order: parameter-gradients of each output component
    basis = expand_basis(None, (gradient(o, f.parameters(), retain_graph=True) for o in f(x).view(-1)))
    if n_derive <= 0:
        return basis.size(0)

    def it():
        for i in x:
            a = i.clone().detach().requires_grad_(True)
            fx = f(a)
            for fxo in fx.view(-1):
                for k in range(n_derive):
                    u = i.clone().normal_()  # fresh random direction
                    fxo = gradient(fxo, a, create_graph=True) @ u
                    if fxo.grad_fn is None:
                        break  # the derivative is strictly zero
                    yield gradient(fxo, f.parameters(), retain_graph=(k < n_derive - 1))

    # keep adding derivative directions until the span stops growing
    while True:
        ws = expand_basis(basis, it())
        if basis.size(0) == ws.size(0):
            return basis.size(0)
        basis = ws
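# A minimal sketch of n_effective on a model where the answer is predictable
# (assumption: the linear model below is a stand-in). For a linear map every
# second derivative vanishes, so the inner loop breaks at k = 1 and the count
# saturates at the number of independent first-derivative directions,
# regardless of n_derive.
def _example_n_effective():  # illustrative only, not used by the training code
    f = torch.nn.Linear(5, 1).double()
    x = torch.randn(10, 5, dtype=torch.float64)
    return n_effective(f, x, n_derive=2)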
def compute_h0(model, deltas, out=None):
    '''
    Compute extensive H0
    '''
    parameters = [p for p in model.parameters() if p.requires_grad]
    Ntot = sum(p.numel() for p in parameters)

    if out is None:
        out = deltas.new_zeros(Ntot, Ntot)

    # da Delta_i db Delta_i
    for delta in deltas:
        g = gradient(delta, parameters, retain_graph=True)
        out.add_(g.view(-1, 1) * g.view(1, -1))

    return out
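# Sketch relating compute_h0 to the Gram matrix of the same gradients
# (assumption: `deltas` holds one scalar per sample, each depending on the
# model parameters). H0 = sum_i g_i g_i^T is Ntot x Ntot, while the N x N Gram
# matrix G_ij = g_i . g_j has the same nonzero spectrum and is much cheaper
# when N << Ntot.
def _gram_of_gradients(model, deltas):  # illustrative helper, not used elsewhere
    parameters = [p for p in model.parameters() if p.requires_grad]
    gs = torch.stack([gradient(d, parameters, retain_graph=True) for d in deltas])
    return gs @ gs.t()  # (N, N)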
def compute_kernels(f, xtr, xte):
    from hessian import gradient

    ktrtr = xtr.new_zeros(len(xtr), len(xtr))
    ktetr = xtr.new_zeros(len(xte), len(xtr))
    ktete = xtr.new_zeros(len(xte), len(xte))

    # Group the parameters into chunks small enough that each Jacobian block
    # fits in roughly 2 GB (8 bytes per entry, one row per train + test point).
    params = []
    current = []
    for p in sorted(f.parameters(), key=lambda p: p.numel(), reverse=True):
        current.append(p)
        if sum(q.numel() for q in current) > 2e9 // (8 * (len(xtr) + len(xte))):
            if len(current) > 1:
                params.append(current[:-1])
                current = current[-1:]
            else:
                params.append(current)
                current = []
    if len(current) > 0:
        params.append(current)

    for i, p in enumerate(params):
        print("[{}/{}] [len={} numel={}]".format(i, len(params), len(p), sum(x.numel() for x in p)), flush=True)

        jtr = xtr.new_empty(len(xtr), sum(u.numel() for u in p))  # (P, N~)
        jte = xte.new_empty(len(xte), sum(u.numel() for u in p))  # (P, N~)

        for j, x in enumerate(xtr):
            jtr[j] = gradient(f(x[None]), p)  # (N~)
        for j, x in enumerate(xte):
            jte[j] = gradient(f(x[None]), p)  # (N~)

        # accumulate this chunk's contribution to each kernel block
        ktrtr.add_(jtr @ jtr.t())
        ktetr.add_(jte @ jtr.t())
        ktete.add_(jte @ jte.t())
        del jtr, jte

    return ktrtr, ktetr, ktete
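# A minimal sketch of compute_kernels on a toy model (assumption: the model
# below is a stand-in; any module returning a single element per one-sample
# batch works). The three blocks are Gram matrices of the tangent kernel
# Theta(x, x') = sum_p df(x)/dp . df(x')/dp; for a linear model with bias,
# Theta(x, x') = x . x' + 1.
def _example_compute_kernels():  # illustrative only, not used by the training code
    f = torch.nn.Linear(3, 1)
    xtr, xte = torch.randn(5, 3), torch.randn(2, 3)
    ktrtr, ktetr, ktete = compute_kernels(f, xtr, xte)
    assert torch.allclose(ktrtr, xtr @ xtr.t() + 1, atol=1e-5)
    return ktrtr, ktetr, ktete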
def train_regular(f0, x, y, tau, max_walltime, alpha, loss, subf0, max_dgrad=math.inf, max_dout=math.inf):
    f = copy.deepcopy(f0)

    with torch.no_grad():
        out0 = f0(x) if subf0 else 0

    dt = 1
    step_change_dt = 0
    optimizer = ContinuousMomentum(f.parameters(), dt=dt, tau=tau)

    checkpoint_generator = loglinspace(0.01, 100)
    checkpoint = next(checkpoint_generator)

    wall = perf_counter()
    t = 0
    converged = False

    out = f(x)
    grad = gradient(loss((out - out0) * y).mean(), f.parameters())

    for step in itertools.count():

        state = copy.deepcopy((f.state_dict(), optimizer.state_dict(), t))

        while True:
            # Try a step of size dt; keep it only if the relative change of the
            # gradient and of the outputs stays below the given thresholds.
            make_step(f, optimizer, dt, grad)
            t += dt
            current_dt = dt

            new_out = f(x)
            new_grad = gradient(loss((new_out - out0) * y).mean(), f.parameters())

            dout = (out - new_out).mul(alpha).abs().max().item()
            if grad.norm() == 0 or new_grad.norm() == 0:
                dgrad = 0
            else:
                dgrad = (grad - new_grad).norm().pow(2).div(grad.norm() * new_grad.norm()).item()

            if dgrad < max_dgrad and dout < max_dout:
                if dgrad < 0.5 * max_dgrad and dout < 0.5 * max_dout:
                    dt *= 1.1
                break

            # The step was too large: shrink dt and restore the saved state.
            dt /= 10
            print("[{} +{}] [dt={:.1e} dgrad={:.1e} dout={:.1e}]".format(step, step - step_change_dt, dt, dgrad, dout), flush=True)
            step_change_dt = step
            f.load_state_dict(state[0])
            optimizer.load_state_dict(state[1])
            t = state[2]

        out = new_out
        grad = new_grad

        save = False
        if step == checkpoint:
            checkpoint = next(checkpoint_generator)
            assert checkpoint > step
            save = True
        if (alpha * (out - out0) * y >= 1).all() and not converged:
            converged = True
            save = True

        if save:
            state = {
                'step': step,
                'wall': perf_counter() - wall,
                't': t,
                'dt': current_dt,
                'dgrad': dgrad,
                'dout': dout,
                'norm': sum(p.norm().pow(2) for p in f.parameters()).sqrt().item(),
                'dnorm': sum((p0 - p).norm().pow(2) for p0, p in zip(f0.parameters(), f.parameters())).sqrt().item(),
                'grad_norm': grad.norm().item(),
            }
            yield f, state, converged

        if converged:
            break
        if perf_counter() > wall + max_walltime:
            break
        if torch.isnan(out).any():
            break
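# A minimal sketch of driving the train_regular generator (assumptions: the
# stand-in model, the soft-hinge loss, and the budgets below are placeholders;
# tau = 0 corresponds to plain gradient flow with the adaptive step size). The
# generator yields checkpoints and stops once every margin satisfies
# alpha * (f - f0) * y >= 1, the walltime is exhausted, or the outputs go NaN.
def _example_train_regular():  # illustrative only, not used by the training code
    class Net(torch.nn.Module):  # returns one scalar per sample
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(3, 1)

        def forward(self, z):
            return self.lin(z).view(-1)

    alpha = 1.0
    f0 = Net()
    x, y = torch.randn(8, 3), torch.randn(8).sign()

    def hinge(z):  # placeholder loss, consistent with the margin test alpha*z >= 1
        return (1 - alpha * z).relu()

    for f, state, converged in train_regular(f0, x, y, tau=0.0, max_walltime=10.0,
                                             alpha=alpha, loss=hinge, subf0=True):
        print(state['step'], state['t'], state['grad_norm'], converged)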
def run_kernel(args, ktrtr, ktetr, ktete, f, xtr, ytr, xte, yte):
    assert args.f0 == 1

    dynamics = []

    tau = args.tau_over_h * args.h
    if args.tau_alpha_crit is not None:
        tau *= min(1, args.tau_alpha_crit / args.alpha)

    for otr, _velo, _grad, state, _converged in train_kernel(ktrtr, ytr, tau, args.train_time, args.alpha, partial(loss_func_prime, args), args.max_dgrad, args.max_dout):
        state['train'] = {
            'loss': loss_func(args, otr * ytr).mean().item(),
            'aloss': args.alpha * loss_func(args, otr * ytr).mean().item(),
            'err': (otr * ytr <= 0).double().mean().item(),
            'nd': (args.alpha * otr * ytr < 1).long().sum().item(),
            'dfnorm': otr.pow(2).mean().sqrt(),
            'outputs': otr if args.save_outputs else None,
            'labels': ytr if args.save_outputs else None,
        }
        print("[i={d[step]:d} t={d[t]:.2e} wall={d[wall]:.0f}] [dt={d[dt]:.1e} dgrad={d[dgrad]:.1e} dout={d[dout]:.1e}] [train aL={d[train][aloss]:.2e} err={d[train][err]:.2f} nd={d[train][nd]}]".format(d=state), flush=True)
        dynamics.append(state)

    # Express the trained outputs in the kernel basis: solve ktrtr @ c = otr.
    c = torch.lstsq(otr.view(-1, 1), ktrtr).solution.flatten()

    if len(xte) > len(xtr):
        # Evaluate the kernel predictor on the test set through the parameter
        # gradients: ote_j = df(x_j)/dp . (sum_i c_i df(x_i)/dp) = (ktetr @ c)_j.
        from hessian import gradient
        a = gradient(f(xtr) @ c, f.parameters())
        ote = torch.stack([gradient(f(x[None]), f.parameters()) @ a for x in xte])
    else:
        ote = ktetr @ c

    out = {
        'dynamics': dynamics,
        'train': {
            'outputs': otr,
            'labels': ytr,
        },
        'test': {
            'outputs': ote,
            'labels': yte,
        },
        'kernel': {
            'train': {
                'value': ktrtr.cpu() if args.store_kernel == 1 else None,
                'diag': ktrtr.diag(),
                'mean': ktrtr.mean(),
                'std': ktrtr.std(),
                'norm': ktrtr.norm(),
            },
            'test': {
                'value': ktete.cpu() if args.store_kernel == 1 else None,
                'diag': ktete.diag(),
                'mean': ktete.mean(),
                'std': ktete.std(),
                'norm': ktete.norm(),
            },
        },
    }
    return out
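# A compact restatement of the test-prediction step above (assumption: ktrtr is
# well conditioned). The trained outputs define coefficients c via
# ktrtr @ c = otr, and the kernel predictor on the test set is ote = ktetr @ c;
# the gradient-based branch in run_kernel computes the same quantity directly
# from parameter gradients.
def _kernel_predict(ktrtr, ktetr, otr):  # illustrative helper, not used elsewhere
    c = torch.lstsq(otr.view(-1, 1), ktrtr).solution.flatten()
    return ktetr @ c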