# Imports reconstructed from usage below. `util` (u) and `globals` (gl) are
# assumed to be repo-local modules; `args` is assumed to be a module-level
# argparse namespace (see the hypothetical CLI sketch at the end of the file).
import os
from typing import Callable

import torch
import torch.nn as nn
import wandb
from attrdict import AttrDefault
from torch.utils.tensorboard import SummaryWriter  # assumed; wandb is patched with tensorboardX=False below

import util as u      # assumed repo-local helper module
import globals as gl  # assumed repo-local module holding shared state (event_writer, token_count, backward_idx)


def main():
    """Train a small fully-connected net on TinyMNIST and log per-layer gradient/Hessian statistics."""
    # pick the first unused logdir suffix, e.g. <logdir>00, <logdir>01, ...
    attempt_count = 0
    while os.path.exists(f"{args.logdir}{attempt_count:02d}"):
        attempt_count += 1
    logdir = f"{args.logdir}{attempt_count:02d}"

    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")
    u.seed_random(1)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
    except Exception as e:
        print(f"wandb crashed with {e}")

    # data_width = 4
    # targets_width = 2

    d1 = args.data_width**2
    d2 = 10
    d3 = args.targets_width**2
    o = d3                      # output dimension
    n = args.stats_batch_size   # batch size used for statistics
    d = [d1, d2, d3]
    model = u.SimpleFullyConnected(d, nonlin=args.nonlin)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.train_batch_size,
                                               shuffle=False,
                                               drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.stats_batch_size,
                                               shuffle=False,
                                               drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)

    def capture_activations(module, input, _output):
        if skip_forward_hooks:
            return
        assert gl.backward_idx == 0   # no need to forward-prop on Hessian computation
        assert not hasattr(module, 'activations'), \
            "Seeing activations from previous forward, call util.zero_grad to clear"
        assert len(input) == 1, "this works for single input layers only"
        setattr(module, "activations", input[0].detach())

    def capture_backprops(module: nn.Module, _input, output):
        if skip_backward_hooks:
            return
        assert len(output) == 1, "this works for single variable layers only"
        if gl.backward_idx == 0:
            assert not hasattr(module, 'backprops'), \
                "Seeing results of previous autograd, call util.zero_grad to clear"
            setattr(module, 'backprops', [])
        assert gl.backward_idx == len(module.backprops)
        module.backprops.append(output[0])

    def save_grad(param: nn.Parameter) -> Callable[[torch.Tensor], None]:
        """Hook to save gradient into 'param.saved_grad', so it can be accessed after
        model.zero_grad(). Only stores the gradient if the value has not been set;
        call util.zero_grad to clear it."""
        def save_grad_fn(grad):
            if not hasattr(param, 'saved_grad'):
                setattr(param, 'saved_grad', grad)
        return save_grad_fn

    for layer in model.layers:
        layer.register_forward_hook(capture_activations)
        layer.register_backward_hook(capture_backprops)
        layer.weight.register_hook(save_grad(layer.weight))

    def loss_fn(data, targets):
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == len(targets)
        return torch.sum(err * err) / 2 / len(data)

    gl.token_count = 0
    for step in range(args.stats_steps):
        data, targets = next(stats_iter)
        skip_forward_hooks = False
        skip_backward_hooks = False

        # get gradient values
        gl.backward_idx = 0
        u.zero_grad(model)
        output = model(data)
        loss = loss_fn(output, targets)
        loss.backward(retain_graph=True)
        print("loss", loss.item())

        # get Hessian values
        skip_forward_hooks = True
        id_mat = torch.eye(o)

        u.log_scalars({'loss': loss.item()})

        # o = 0
        for out_idx in range(o):
            model.zero_grad()
            # backprop to get section of batch output Jacobian for output at position out_idx
            output = model(data)   # opt: using autograd.grad means I don't have to zero_grad
            ei = id_mat[out_idx]
            bval = torch.stack([ei] * n)
            gl.backward_idx = out_idx + 1
            output.backward(bval)
        skip_backward_hooks = True

        for (i, layer) in enumerate(model.layers):
            s = AttrDefault(str, {})   # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch,
            # while we need per-example loss
            B_t = layer.backprops[0] * n
            assert B_t.shape == (n, d[i + 1])

            G = u.khatri_rao_t(B_t, A_t)   # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n   # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad)

            # empirical Fisher
            efisher = G.t() @ G / n
            sigma = efisher - g.t() @ g
            # u.dump(sigma, f'/tmp/sigmas/{step}-{i}')
            s.sigma_l2 = u.l2_norm(sigma)

            #############################
            # Hessian stats
            #############################
            A_t = layer.activations
            Bh_t = [layer.backprops[out_idx + 1] for out_idx in range(o)]
            Amat_t = torch.cat([A_t] * o, dim=0)
            Bmat_t = torch.cat(Bh_t, dim=0)

            assert Amat_t.shape == (n * o, d[i])
            assert Bmat_t.shape == (n * o, d[i + 1])

            Jb = u.khatri_rao_t(Bmat_t, Amat_t)   # batch Jacobian, in row-vec format
            H = Jb.t() @ Jb / n
            pinvH = u.pinv(H)

            s.hess_l2 = u.l2_norm(H)
            s.invhess_l2 = u.l2_norm(pinvH)
            s.hess_fro = H.flatten().norm()
            s.invhess_fro = pinvH.flatten().norm()
            s.jacobian_l2 = u.l2_norm(Jb)
            s.grad_fro = g.flatten().norm()
            s.param_fro = layer.weight.data.flatten().norm()
            u.nan_check(H)

            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            # u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')

            def loss_direction(dd: torch.Tensor, eps):
                """Loss improvement if we take step eps in direction dd."""
                return u.to_python_scalar(eps * (dd @ g.t()) -
                                          0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd."""
                return u.to_python_scalar(dd @ H @ dd.t() / dd.flatten().norm()**2)

            s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)
            s.grad_curv = curv_direction(g)
            ndir = g @ u.pinv(H)                      # Newton direction
            s.newton_curv = curv_direction(ndir)
            setattr(layer.weight, 'pre', u.pinv(H))   # save Newton preconditioner
            s.step_openai = 1 / s.grad_curv if s.grad_curv else 999
            s.newton_fro = ndir.flatten().norm()      # Frobenius norm of Newton update
            s.regret_gradient = loss_direction(g, s.step_openai)

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        for i in range(args.train_steps):
            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()
            u.log_scalar(train_loss=loss.item())

            if args.method != 'newton':
                optimizer.step()
            else:
                for (layer_idx, layer) in enumerate(model.layers):
                    param: torch.nn.Parameter = layer.weight
                    param_data: torch.Tensor = param.data
                    param_data.copy_(param_data - 0.1 * param.grad)
                    if layer_idx != 1:   # only update 1 layer with Newton, unstable otherwise
                        continue
                    u.nan_check(layer.weight.pre)
                    u.nan_check(param.grad.flatten())
                    u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                    param_new_flat = (u.v2r(param_data.flatten()) -
                                      u.v2r(param.grad.flatten()) @ layer.weight.pre)
                    u.nan_check(param_new_flat)
                    param_data.copy_(param_new_flat.reshape(param_data.shape))

            gl.token_count += data.shape[0]

    gl.event_writer.close()
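# ---------------------------------------------------------------------------
# The statistics code above leans on repo-local helpers from `util` whose
# implementations are not shown here. The sketches below illustrate their
# assumed semantics (underscore-prefixed so they don't shadow the real
# helpers); they are not the repo's actual implementations.
# ---------------------------------------------------------------------------

def _khatri_rao_t_sketch(B: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    """Assumed behavior of u.khatri_rao_t: row-wise (transposed) Khatri-Rao product.

    Row k of the result is kron(B[k], A[k]); with B the backprops (n, d_out) and
    A the activations (n, d_in), row k is the flattened per-example gradient
    B[k] A[k]^T, giving a result of shape (n, d_in * d_out).
    """
    assert B.shape[0] == A.shape[0]
    return torch.einsum('ni,nj->nij', B, A).reshape(B.shape[0], -1)


def _infinite_iter_sketch(loader):
    """Assumed behavior of u.infinite_iter: cycle through a DataLoader forever."""
    while True:
        for batch in loader:
            yield batch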
def compute_layer_stats(layer):
    """Compute gradient/Hessian statistics for a single layer.

    Note: relies on module-level `model`, `args`, `stats_data`, `stats_targets`
    and `compute_loss`, which are set up elsewhere in this script.
    """
    refreeze = False
    if hasattr(layer, 'frozen') and layer.frozen:
        u.unfreeze(layer)
        refreeze = True

    s = AttrDefault(str, {})
    n = args.stats_batch_size
    param = u.get_param(layer)
    _d = len(param.flatten())               # dimensionality of parameters
    layer_idx = model.layers.index(layer)
    # TODO: get layer type, include it in name
    assert layer_idx >= 0
    assert stats_data.shape[0] == n

    def backprop_loss():
        model.zero_grad()
        output = model(stats_data)   # use last saved data batch for backprop
        loss = compute_loss(output, stats_targets)
        loss.backward()
        return loss, output

    def backprop_output():
        model.zero_grad()
        output = model(stats_data)
        output.backward(gradient=torch.ones_like(output))
        return output

    # per-example gradients, shape (n, d)
    _loss, _output = backprop_loss()
    At = layer.data_input
    Bt = layer.grad_output * n
    G = u.khatri_rao_t(At, Bt)
    g = G.sum(dim=0, keepdim=True) / n
    u.check_close(g, u.vec(param.grad).t())

    s.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2
    s.grad_fro = g.flatten().norm()
    s.param_fro = param.data.flatten().norm()

    pos_activations = torch.sum(layer.data_output > 0)
    neg_activations = torch.sum(layer.data_output <= 0)
    s.a_sparsity = neg_activations.float() / (pos_activations + neg_activations)   # sparsity of 1 means all activations are 0
    activation_size = len(layer.data_output.flatten())
    s.a_magnitude = torch.sum(layer.data_output) / activation_size

    _output = backprop_output()
    B2t = layer.grad_output
    J = u.khatri_rao_t(At, B2t)   # batch output Jacobian
    H = J.t() @ J / n

    s.hessian_l2 = u.l2_norm(H)
    s.jacobian_l2 = u.l2_norm(J)
    J1 = J.sum(dim=0) / n         # single output Jacobian
    s.J1_l2 = J1.norm()

    # newton decrement
    def loss_direction(direction, eps):
        """Loss improvement if we take step eps in direction `direction`."""
        return u.to_python_scalar(eps * (direction @ g.t()) -
                                  0.5 * eps**2 * direction @ H @ direction.t())

    s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)

    # TODO: gradient diversity is stuck at 1
    # TODO: newton/gradient angle
    # TODO: newton step magnitude

    s.grad_curvature = u.to_python_scalar(g @ H @ g.t())   # curvature in direction of g
    s.step_openai = u.to_python_scalar(s.grad_fro**2 / s.grad_curvature) if s.grad_curvature else 999
    s.regret_gradient = loss_direction(g, s.step_openai)

    if refreeze:
        u.freeze(layer)
    return s
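# ---------------------------------------------------------------------------
# Hypothetical CLI wiring (an assumption, not part of the original code): the
# flag names below match the `args.<name>` attributes read above, but the
# defaults and flag types are guesses. The original script is assumed to define
# `args` at module level before calling main().
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='tiny MNIST curvature experiments')
    parser.add_argument('--logdir', type=str, default='/tmp/curv_train_tiny/run')
    parser.add_argument('--wandb', action='store_true', help='log to Weights & Biases')
    parser.add_argument('--method', type=str, default='sgd', choices=['sgd', 'newton'])
    parser.add_argument('--nonlin', action='store_true', help='use nonlinearities in the model')
    parser.add_argument('--data_width', type=int, default=4)
    parser.add_argument('--targets_width', type=int, default=2)
    parser.add_argument('--dataset_size', type=int, default=1000)
    parser.add_argument('--train_batch_size', type=int, default=64)
    parser.add_argument('--stats_batch_size', type=int, default=64)
    parser.add_argument('--train_steps', type=int, default=10)
    parser.add_argument('--stats_steps', type=int, default=100)
    parser.add_argument('--autograd_check', action='store_true',
                        help='cross-check manual gradient/Hessian against autograd')
    args = parser.parse_args()
    main()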