def test_kron_1x2_conv():
    """Minimal example of a 1x2 convolution whose Hessian/grad covariance doesn't factor as Kronecker.

    Two convolutional layers stacked on top of each other, followed by least squares loss.

    Outputs:
    0 tensor([[[[0., 1., 1., 1.]]]])
    1 tensor([[[[2., 3.]]]])
    2 tensor([[[[8.]]]])

    Activations/backprops:
    layerA 0 tensor([[[[0., 1.], [1., 1.]]]])
    layerB 0 tensor([[[[1., 2.]]]])
    layerA 1 tensor([[[[2.], [3.]]]])
    layerB 1 tensor([[[[1.]]]])

    layer 0 discrepancy: 0.6597963571548462
    layer 1 discrepancy: 0.0

    The test only prints the per-layer distance between the exact Hessian and
    the expanded Kronecker-factored one; it makes no assertions.
    """
    u.seed_random(1)

    batch_size, height, width = 1, 1, 4
    kernel_size = (1, 2)
    channels = [1, 1, 1]

    model: u.SimpleModel = u.StridedConvolutional2(channels,
                                                   kernel_size=kernel_size,
                                                   nonlin=False,
                                                   bias=True)
    data = torch.tensor([0, 1., 1, 1]).reshape(
        (batch_size, channels[0], height, width))

    # Both layers get the same deterministic parameters: zero bias, kernel [1, 2].
    for conv in (model.layers[0], model.layers[1]):
        conv.bias.data.zero_()
        conv.weight.data.copy_(torch.tensor([1, 2]))

    # Forward pass before hooks are installed; the result itself is not used.
    model(data)

    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type='LeastSquares')
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    for layer in model.layers:
        exact_hess = layer.weight.hess
        kron_hess = layer.weight.hess_kron.expand()
        print(u.symsqrt_dist(exact_hess, kron_hess))
def test_main():
    """Smoke-test the tiny curvature-stats training loop end to end.

    Builds a 3-layer fully-connected model (``u.SimpleFullyConnected2``) on
    ``u.TinyMNIST`` with a hard-coded configuration, alternates between
    collecting per-layer gradient/Hessian statistics and running
    ``args.train_steps`` optimizer steps, and records the validation loss at
    the start of every stats step.

    The test passes when the first recorded validation loss is above 2.4 and
    the last one is below 2.25.

    Raises:
        ValueError: if ``loss_type`` is not one of the supported names.
    """
    # Configuration is hard-coded; no command-line arguments are parsed here.
    args = AttrDict()
    args.lmb = 1e-3
    args.compute_rho = 1
    args.weight_decay = 1e-4
    args.method = 'gradient'
    args.logdir = '/tmp'
    args.full_batch = False
    args.skip_stats = False
    args.autograd_check = False

    # small values for debugging
    args.wandb = 0
    args.stats_steps = 10
    args.train_steps = 10
    args.train_batch_size = 10
    args.stats_batch_size = 10
    args.data_width = 2
    args.targets_width = 2
    args.nonlin = False

    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    # gl.event_writer = SummaryWriter(logdir)
    gl.event_writer = u.NoOp()
    # print(f"Logging to {run_name}")

    # loss_type = 'LeastSquares'
    loss_type = 'CrossEntropy'

    assert args.data_width == args.targets_width
    d1 = args.data_width**2
    d2 = 2
    d3 = args.targets_width**2
    if loss_type == 'CrossEntropy':
        d3 = 10
    n = args.stats_batch_size
    d = [d1, d2, d3]  # layer widths: input, hidden, output
    dsize = max(args.train_batch_size, args.stats_batch_size) + 1

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin)
    model = model.to(gl.device)

    # Best-effort wandb setup: a failure is reported but does not stop the run.
    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    # optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.03)  # make 10x smaller for least-squares loss

    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=dsize,
                          original_targets=True)
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=False,
        drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_iter = None
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.stats_batch_size,
            shuffle=False,
            drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width,
                               targets_width=args.targets_width,
                               train=False,
                               dataset_size=dsize,
                               original_targets=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.train_batch_size,
        shuffle=False,
        drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()
    else:
        raise ValueError(f"unknown loss_type: {loss_type}")

    autograd_lib.add_hooks(model)
    gl.token_count = 0
    last_outer = 0
    val_losses = []
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars(
                {"time/outer": 1000 * (time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            # print("val_loss", val_loss.item())
            val_losses.append(val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        # compute stats
        if args.full_batch:
            data, targets = dataset.data, dataset.targets
        else:
            data, targets = next(stats_iter)

        # Capture Hessian and gradient stats
        autograd_lib.enable_hooks()
        autograd_lib.clear_backprops(model)
        autograd_lib.clear_hess_backprops(model)
        with u.timeit("backprop_g"):
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)
        with u.timeit("backprop_H"):
            autograd_lib.backprop_hess(output, hess_type=loss_type)
        autograd_lib.disable_hooks()  # TODO(y): use remove_hooks

        with u.timeit("compute_grad1"):
            autograd_lib.compute_grad1(model)
        with u.timeit("compute_hess"):
            autograd_lib.compute_hess(model)

        for (i, layer) in enumerate(model.layers):
            # input/output layers are unreasonably expensive if not using Kronecker factoring
            if d[i] > 50 or d[i + 1] > 50:
                print(
                    f'layer {i} is too big ({d[i], d[i + 1]}), skipping stats')
                continue

            if args.skip_stats:
                continue

            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops_list[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            u.check_equal(G.reshape(layer.weight.grad1.shape),
                          layer.weight.grad1)

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]),
                              layer.weight.saved_grad)

            s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel(
            )  # proportion of activations that are zero
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g
                s.sigma_l2 = u.sym_l2_norm(sigma)
                s.sigma_erank = torch.trace(sigma) / s.sigma_l2

            lambda_regularizer = args.lmb * torch.eye(d[i + 1] * d[i]).to(
                gl.device)
            H = layer.weight.hess

            with u.timeit(f"invH-{i}"):
                # torch.cholesky_inverse expects a Cholesky *factor*, not the
                # matrix itself, so passing the regularized Hessian to it gave
                # a wrong inverse. Invert the matrix directly instead.
                invH = torch.inverse(H + lambda_regularizer)

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1],
                                                d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            #  u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) -
                                          0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() /
                                          (dd.flatten().norm()**2))

            with u.timeit(f"pinvH-{i}"):
                pinvH = H.pinverse()

            with u.timeit(f'curv-{i}'):
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre',
                        pinvH)  # save Newton preconditioner
                s.step_openai = s.grad_fro**2 / s.grad_curv if s.grad_curv else 999
                s.step_max = 2 / s.H_l2
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm(
                )  # frobenius norm of Newton update
                s.regret_newton = u.to_python_scalar(
                    g @ pinvH @ g.t() / 2)  # replace with "quadratic_form"
                s.regret_gradient = loss_direction(g, s.step_openai)

            with u.timeit(f'rho-{i}'):
                p_sigma = u.lyapunov_spectral(H, sigma)
                s.psigma_erank = u.sym_erank(p_sigma)
                s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                s.diversity = torch.norm(G, "fro")**2 / torch.norm(g)**2 / n

                # Faster approaches for noise variance computation
                # s.noise_variance = torch.trace(H.inverse() @ sigma)
                # try:
                #     # this fails with singular sigma
                #     s.noise_variance = torch.trace(torch.solve(sigma, H)[0])
                #     # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0])
                #     pass
                # except RuntimeError as _:
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        with u.timeit('inner'):
            for _ in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                # u.log_scalar(train_loss=loss.item())

                if args.method != 'newton':
                    optimizer.step()
                    if args.weight_decay:
                        # weight decay applied by shrinking parameters after the step
                        for group in optimizer.param_groups:
                            for param in group['params']:
                                param.data.mul_(1 - args.weight_decay)
                else:
                    for (layer_idx, layer) in enumerate(model.layers):
                        param: torch.nn.Parameter = layer.weight
                        param_data: torch.Tensor = param.data
                        param_data.copy_(param_data - 0.1 * param.grad)
                        if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                            continue
                        u.nan_check(layer.weight.pre)
                        u.nan_check(param.grad.flatten())
                        u.nan_check(
                            u.v2r(param.grad.flatten()) @ layer.weight.pre)
                        param_new_flat = u.v2r(param_data.flatten()) - u.v2r(
                            param.grad.flatten()) @ layer.weight.pre
                        u.nan_check(param_new_flat)
                        param_data.copy_(
                            param_new_flat.reshape(param_data.shape))

                gl.token_count += data.shape[0]

    gl.event_writer.close()

    assert val_losses[0] > 2.4  # 2.4828238487243652
    assert val_losses[-1] < 2.25  # 2.20609712600708
def main():
    """Train a tiny fully-connected model while logging curvature statistics.

    Reads the module-level ``args`` and overwrites several of its fields with
    small debugging values (see below). Each outer step records a validation
    loss, captures per-layer gradient and Hessian statistics on one stats
    batch, logs them through ``u.log_scalars``, and then runs
    ``args.train_steps`` optimizer steps (or the hand-rolled Newton update
    when ``args.method == 'newton'``).

    Raises:
        ValueError: if ``loss_type`` is not one of the supported names.
    """
    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")

    # Checked against the incoming values, before the debugging overrides below.
    assert args.data_width == args.targets_width

    # small values for debugging
    # loss_type = 'LeastSquares'
    loss_type = 'CrossEntropy'

    args.wandb = 0
    args.stats_steps = 10
    args.train_steps = 10
    args.stats_batch_size = 10
    args.data_width = 2
    args.targets_width = 2
    args.nonlin = False

    d1 = args.data_width ** 2
    d2 = 2
    d3 = args.targets_width ** 2
    if loss_type == 'CrossEntropy':
        d3 = 10
    n = args.stats_batch_size
    d = [d1, d2, d3]  # layer widths: input, hidden, output
    dsize = max(args.train_batch_size, args.stats_batch_size) + 1

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin)
    model = model.to(gl.device)

    # Best-effort wandb setup: a failure is reported but does not stop the run.
    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    # optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.03)  # make 10x smaller for least-squares loss

    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=dsize,
                          original_targets=True)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=args.train_batch_size,
                                               shuffle=False,
                                               drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_iter = None
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(dataset,
                                                   batch_size=args.stats_batch_size,
                                                   shuffle=False,
                                                   drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width,
                               targets_width=args.targets_width,
                               train=False,
                               dataset_size=dsize,
                               original_targets=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.train_batch_size,
                                              shuffle=False,
                                              drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()
    else:
        raise ValueError(f"unknown loss_type: {loss_type}")

    autograd_lib.add_hooks(model)
    gl.token_count = 0
    last_outer = 0
    val_losses = []
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars({"time/outer": 1000 * (time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            print("val_loss", val_loss.item())
            val_losses.append(val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        # compute stats
        if args.full_batch:
            data, targets = dataset.data, dataset.targets
        else:
            data, targets = next(stats_iter)

        # Capture Hessian and gradient stats
        autograd_lib.enable_hooks()
        autograd_lib.clear_backprops(model)
        autograd_lib.clear_hess_backprops(model)
        with u.timeit("backprop_g"):
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)
        with u.timeit("backprop_H"):
            autograd_lib.backprop_hess(output, hess_type=loss_type)
        autograd_lib.disable_hooks()  # TODO(y): use remove_hooks

        with u.timeit("compute_grad1"):
            autograd_lib.compute_grad1(model)
        with u.timeit("compute_hess"):
            autograd_lib.compute_hess(model)

        for (i, layer) in enumerate(model.layers):
            # input/output layers are unreasonably expensive if not using Kronecker factoring
            if d[i] > 50 or d[i + 1] > 50:
                print(f'layer {i} is too big ({d[i],d[i+1]}), skipping stats')
                continue

            if args.skip_stats:
                continue

            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops_list[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            u.check_equal(G.reshape(layer.weight.grad1.shape), layer.weight.grad1)

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad)

            s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel()  # proportion of activations that are zero
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g
                s.sigma_l2 = u.sym_l2_norm(sigma)
                s.sigma_erank = torch.trace(sigma) / s.sigma_l2

            lambda_regularizer = args.lmb * torch.eye(d[i + 1] * d[i]).to(gl.device)
            H = layer.weight.hess

            with u.timeit(f"invH-{i}"):
                # torch.cholesky_inverse expects a Cholesky *factor*, not the
                # matrix itself, so passing the regularized Hessian to it gave
                # a wrong inverse. Invert the matrix directly instead.
                invH = torch.inverse(H + lambda_regularizer)

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            #  u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps ** 2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm() ** 2))

            with u.timeit(f"pinvH-{i}"):
                pinvH = u.pinv(H)

            with u.timeit(f'curv-{i}'):
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre', pinvH)  # save Newton preconditioner
                s.step_openai = s.grad_fro**2 / s.grad_curv if s.grad_curv else 999
                s.step_max = 2 / s.H_l2
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm()  # frobenius norm of Newton update
                s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)  # replace with "quadratic_form"
                s.regret_gradient = loss_direction(g, s.step_openai)

            with u.timeit(f'rho-{i}'):
                p_sigma = u.lyapunov_svd(H, sigma)
                if u.has_nan(p_sigma) and args.compute_rho:  # use expensive method
                    # A leftover pdb.set_trace() used to sit here; it blocked
                    # every non-interactive run that reached this fallback.
                    print('using expensive method')
                    H0, sigma0 = u.to_numpys(H, sigma)
                    p_sigma = scipy.linalg.solve_lyapunov(H0, sigma0)
                    p_sigma = torch.tensor(p_sigma).to(gl.device)

                if u.has_nan(p_sigma):
                    # Still NaN: fall back to placeholder values.
                    s.psigma_erank = H.shape[0]
                    s.rho = 1
                else:
                    s.psigma_erank = u.sym_erank(p_sigma)
                    s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                s.diversity = torch.norm(G, "fro") ** 2 / torch.norm(g) ** 2 / n

                # Faster approaches for noise variance computation
                # s.noise_variance = torch.trace(H.inverse() @ sigma)
                # try:
                #     # this fails with singular sigma
                #     s.noise_variance = torch.trace(torch.solve(sigma, H)[0])
                #     # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0])
                #     pass
                # except RuntimeError as _:
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        with u.timeit('inner'):
            for _ in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                # u.log_scalar(train_loss=loss.item())

                if args.method != 'newton':
                    optimizer.step()
                    if args.weight_decay:
                        # weight decay applied by shrinking parameters after the step
                        for group in optimizer.param_groups:
                            for param in group['params']:
                                param.data.mul_(1 - args.weight_decay)
                else:
                    for (layer_idx, layer) in enumerate(model.layers):
                        param: torch.nn.Parameter = layer.weight
                        param_data: torch.Tensor = param.data
                        param_data.copy_(param_data - 0.1 * param.grad)
                        if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                            continue
                        u.nan_check(layer.weight.pre)
                        u.nan_check(param.grad.flatten())
                        u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                        param_new_flat = u.v2r(param_data.flatten()) - u.v2r(param.grad.flatten()) @ layer.weight.pre
                        u.nan_check(param_new_flat)
                        param_data.copy_(param_new_flat.reshape(param_data.shape))

                gl.token_count += data.shape[0]

    gl.event_writer.close()
def main():
    """Train a fully-connected MNIST classifier, optionally logging curvature stats.

    Parses command-line options into ``gl.args``, builds a
    ``u.SimpleFullyConnected2`` model trained with SGD on ``u.TinyMNIST``, and
    for each of ``--stats_steps`` outer iterations:

    1. optionally runs 100 SWA steps (``--swa``),
    2. validates on the test loader and the stats loader and logs the metrics,
    3. unless ``--skip_stats`` is set, captures gradients and a
       Kronecker-factored Hessian and logs the per-parameter stats,
    4. runs ``--train_steps`` SGD steps with multiplicative weight decay.

    Raises:
        NotImplementedError: if ``--full_batch`` is set; validation relies on
            the stats loader, which is not built in that mode.
    """
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval', type=int, default=10, metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')

    parser.add_argument('--wandb', type=int, default=0,
                        help='log to weights and biases')
    parser.add_argument('--autograd_check', type=int, default=0,
                        help='autograd correctness checks')
    parser.add_argument('--logdir', type=str,
                        default='/tmp/runs/curv_train_tiny/run')

    parser.add_argument('--nonlin', type=int, default=1,
                        help="whether to add ReLU nonlinearity between layers")
    parser.add_argument('--bias', type=int, default=1,
                        help="whether to add bias between layers")

    parser.add_argument('--layer', type=int, default=-1,
                        help="restrict updates to this layer")
    parser.add_argument('--data_width', type=int, default=28)
    parser.add_argument('--targets_width', type=int, default=28)
    parser.add_argument(
        '--hess_samples', type=int, default=1,
        help='number of samples when sub-sampling outputs, 0 for exact hessian'
    )
    parser.add_argument('--hess_kfac', type=int, default=0,
                        help='whether to use KFAC approximation for hessian')
    parser.add_argument('--compute_rho', type=int, default=0,
                        help='use expensive method to compute rho')
    parser.add_argument('--skip_stats', type=int, default=1,
                        help='skip all stats collection')

    parser.add_argument('--dataset_size', type=int, default=60000)
    parser.add_argument('--train_steps', type=int, default=1000,
                        help="this many train steps between stat collection")
    parser.add_argument('--stats_steps', type=int, default=1000000,
                        help="total number of curvature stats collections")

    parser.add_argument('--full_batch', type=int, default=0,
                        help='do stats on the whole dataset')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=2e-5)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--dropout', type=int, default=0)
    parser.add_argument('--swa', type=int, default=0)
    parser.add_argument('--lmb', type=float, default=1e-3)

    parser.add_argument('--train_batch_size', type=int, default=64)
    parser.add_argument('--stats_batch_size', type=int, default=10000)

    parser.add_argument('--uniform', type=int, default=0,
                        help='use uniform architecture (all layers same size)')
    parser.add_argument('--run_name', type=str, default='noname')

    gl.args = parser.parse_args()
    args = gl.args

    # Validation below always reads from stats_loader, so full-batch mode is not
    # supported. Raise explicitly: an assert would be stripped under ``-O`` and
    # the run would later fail on the missing loader instead.
    if args.full_batch:
        raise NotImplementedError("fixme: validation still uses stats_iter")

    u.seed_random(1)
    gl.project_name = 'train_ciresan'
    u.setup_logdir_and_event_writer(args.run_name)
    print(f"Logging to {gl.logdir}")

    # Layer widths: 784 inputs, 10 outputs.
    if args.uniform:
        d = [784, 784, 784, 784, 784, 784, 10]
    else:
        d = [784, 2500, 2000, 1500, 1000, 500, 10]

    model = u.SimpleFullyConnected2(d, nonlin=args.nonlin, bias=args.bias,
                                    dropout=args.dropout)
    model = model.to(gl.device)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                momentum=args.momentum)
    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          original_targets=True,
                          dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.train_batch_size, shuffle=True,
        drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.stats_batch_size, shuffle=False,
        drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width,
                               targets_width=args.targets_width,
                               train=False,
                               original_targets=True,
                               dataset_size=args.dataset_size)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.stats_batch_size,
                                              shuffle=False,
                                              drop_last=False)

    loss_fn = torch.nn.CrossEntropyLoss()
    # Hooks are installed once and kept disabled outside of stats collection.
    autograd_lib.add_hooks(model)
    autograd_lib.disable_hooks()

    gl.token_count = 0
    last_outer = 0
    for step in range(args.stats_steps):
        # NOTE(review): 60000 is hard-coded here although --dataset_size is
        # configurable -- confirm whether epoch should follow that option.
        epoch = gl.token_count // 60000
        print(gl.token_count)
        if last_outer:
            u.log_scalars(
                {"time/outer": 1000 * (time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        # compute validation loss
        if args.swa:
            model.eval()
            with u.timeit('swa'):
                base_opt = torch.optim.SGD(model.parameters(), lr=args.lr,
                                           momentum=args.momentum)
                opt = torchcontrib.optim.SWA(base_opt, swa_start=0,
                                             swa_freq=1, swa_lr=args.lr)
                for _ in range(100):
                    optimizer.zero_grad()
                    data, targets = next(train_iter)
                    model.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, targets)
                    loss.backward()
                    opt.step()
                opt.swap_swa_sgd()

        with u.timeit("validate"):
            val_accuracy, val_loss = validate(model, test_loader,
                                              f'test (epoch {epoch})')
            train_accuracy, train_loss = validate(model, stats_loader,
                                                  f'train (epoch {epoch})')

        # save log
        metrics = {
            'epoch': epoch,
            'val_accuracy': val_accuracy,
            'val_loss': val_loss,
            'train_loss': train_loss,
            'train_accuracy': train_accuracy,
            'lr': optimizer.param_groups[0]['lr'],
            'momentum': optimizer.param_groups[0].get('momentum', 0)
        }
        u.log_scalars(metrics)

        # compute stats (the batch is drawn even when stats are skipped)
        data, targets = next(stats_iter)

        if not args.skip_stats:
            autograd_lib.enable_hooks()
            autograd_lib.clear_backprops(model)
            autograd_lib.clear_hess_backprops(model)
            with u.timeit("backprop_g"):
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward(retain_graph=True)
            with u.timeit("backprop_H"):
                autograd_lib.backprop_hess(output, hess_type='CrossEntropy')
            autograd_lib.disable_hooks()  # TODO(y): use remove_hooks

            with u.timeit("compute_grad1"):
                autograd_lib.compute_grad1(model)
            with u.timeit("compute_hess"):
                autograd_lib.compute_hess(model, method='kron',
                                          attr_name='hess2')

            autograd_lib.compute_stats_factored(model)

            # NOTE(review): stats are nested under "weight"/"bias" only, with no
            # layer identifier in the prefix -- confirm layers do not collide.
            for layer in model.layers:
                for param_name, param in (("weight", layer.weight),
                                          ("bias", layer.bias)):
                    if param is None or not hasattr(param, 'stats'):
                        continue
                    u.log_scalars(u.nest_stats(param_name, param.stats))

        # gradient steps
        model.train()
        last_inner = 0
        for _ in range(args.train_steps):
            if last_inner:
                u.log_scalars(
                    {"time/inner": 1000 * (time.perf_counter() - last_inner)})
            last_inner = time.perf_counter()

            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()

            optimizer.step()
            if args.weight_decay:
                # weight decay applied by shrinking parameters after the step
                for group in optimizer.param_groups:
                    for param in group['params']:
                        param.data.mul_(1 - args.weight_decay)

            gl.token_count += data.shape[0]

    gl.event_writer.close()
def main():
    """Train a small fully connected net and log exact per-layer curvature stats.

    Configuration comes from the module-level ``args``. A five-layer
    ``u.SimpleFullyConnected2`` model is trained with SGD on ``u.TinyMNIST``
    using cross-entropy loss. Each of the ``args.stats_steps`` outer
    iterations:

    1. evaluates the loss on one test batch and runs ``validate`` on the
       test loader;
    2. unless ``args.skip_stats`` is set, computes per-example gradients and
       Hessians on a fixed stats batch and logs statistics derived from them
       for every parameter of every layer that is neither marked expensive
       nor larger than 8000 weights;
    3. runs ``args.train_steps`` SGD updates, with optional multiplicative
       weight decay.

    Side effects: mutates ``args.stats_batch_size`` / ``args.train_batch_size``,
    sets ``gl.event_writer`` and closes it on completion.
    """
    u.install_pdb_handler()
    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {logdir}")

    loss_type = 'CrossEntropy'

    d1 = args.data_width ** 2
    # Batch sizes are capped at the dataset size.
    args.stats_batch_size = min(args.stats_batch_size, args.dataset_size)
    args.train_batch_size = min(args.train_batch_size, args.dataset_size)
    n = args.stats_batch_size
    o = 10
    d = [d1, 60, 60, 60, o]  # layer widths from input to output

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin, last_layer_linear=True)
    model = model.to(gl.device)
    u.mark_expensive(model.layers[0])  # to stop grad1/hess calculations on this layer
    print(model)

    # wandb is optional: report a failure and carry on without it.
    try:
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name, dir='/tmp/wandb.runs')
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.03)  # make 10x smaller for least-squares loss
    dataset = u.TinyMNIST(data_width=args.data_width, dataset_size=args.dataset_size, loss_type=loss_type)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False,
                                               drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False,
                                               drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)
    # The same stats batch is reused for every stats collection.
    stats_data, stats_targets = next(stats_iter)

    test_dataset = u.TinyMNIST(data_width=args.data_width, train=False, dataset_size=args.dataset_size,
                               loss_type=loss_type)
    test_batch_size = min(args.dataset_size, 1000)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False,
                                              drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:  # loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()
    last_outer = 0
    val_losses = []
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars({"time/outer": 1000*(time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            print("val_loss", val_loss.item())
            val_losses.append(val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        with u.timeit("validate"):
            if loss_type == 'CrossEntropy':
                val_accuracy, val_loss = validate(model, test_loader, f'test (stats_step {step})')
                # train_accuracy, train_loss = validate(model, train_loader, f'train (stats_step {step})')

                # Built inside the branch so val_accuracy is always bound when read.
                metrics = {'stats_step': step, 'val_accuracy': val_accuracy, 'val_loss': val_loss}
                u.log_scalars(metrics)

        data, targets = stats_data, stats_targets

        if not args.skip_stats:
            # Capture Hessian and gradient stats
            autograd_lib.enable_hooks()
            autograd_lib.clear_backprops(model)
            autograd_lib.clear_hess_backprops(model)
            with u.timeit("backprop_g"):
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward(retain_graph=True)
            with u.timeit("backprop_H"):
                autograd_lib.backprop_hess(output, hess_type=loss_type)
            autograd_lib.disable_hooks()  # TODO(y): use remove_hooks

            with u.timeit("compute_grad1"):
                autograd_lib.compute_grad1(model)
            with u.timeit("compute_hess"):
                autograd_lib.compute_hess(model)

            for (i, layer) in enumerate(model.layers):

                if hasattr(layer, 'expensive'):
                    continue

                param_names = {layer.weight: "weight", layer.bias: "bias"}
                for param in [layer.weight, layer.bias]:
                    # input/output layers are unreasonably expensive if not using Kronecker factoring
                    if d[i]*d[i+1] > 8000:
                        print(f'layer {i} is too big ({d[i],d[i+1]}), skipping stats')
                        continue

                    s = AttrDefault(str, {})  # dictionary-like object for layer stats

                    #############################
                    # Gradient stats
                    #############################
                    A_t = layer.activations
                    B_t = layer.backprops_list[0] * n

                    # Convert the count to float first: dividing an integer tensor by an int
                    # truncates to zero on older torch releases.
                    s.sparsity = torch.sum(layer.output <= 0).float() / layer.output.numel()  # proportion of activations that are zero
                    s.mean_activation = torch.mean(A_t)
                    s.mean_backprop = torch.mean(B_t)

                    # empirical Fisher
                    G = param.grad1.reshape((n, -1))
                    g = G.mean(dim=0, keepdim=True)

                    u.nan_check(G)
                    with u.timeit(f'sigma-{i}'):
                        efisher = G.t() @ G / n
                        sigma = efisher - g.t() @ g  # centered gradient covariance
                        s.sigma_l2 = u.sym_l2_norm(sigma)
                        s.sigma_erank = torch.trace(sigma)/s.sigma_l2

                    H = param.hess
                    lambda_regularizer = args.lmb * torch.eye(H.shape[0]).to(gl.device)
                    u.nan_check(H)

                    with u.timeit(f"invH-{i}"):
                        # torch.cholesky_inverse expects the Cholesky factor, not the matrix
                        # itself, so factor the regularized Hessian first. The factorization
                        # raises if H + lambda*I is not positive definite.
                        invH = torch.cholesky_inverse(torch.cholesky(H + lambda_regularizer))

                    with u.timeit(f"H_l2-{i}"):
                        s.H_l2 = u.sym_l2_norm(H)
                        s.iH_l2 = u.sym_l2_norm(invH)

                    with u.timeit(f"norms-{i}"):
                        s.H_fro = H.flatten().norm()
                        s.iH_fro = invH.flatten().norm()
                        s.grad_fro = g.flatten().norm()
                        s.param_fro = param.data.flatten().norm()

                    def loss_direction(dd: torch.Tensor, eps):
                        """loss improvement if we take step eps in direction dd"""
                        return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps ** 2 * dd @ H @ dd.t())

                    def curv_direction(dd: torch.Tensor):
                        """Curvature in direction dd"""
                        return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm() ** 2))

                    with u.timeit(f"pinvH-{i}"):
                        pinvH = u.pinv(H)

                    with u.timeit(f'curv-{i}'):
                        s.grad_curv = curv_direction(g)  # curvature (eigenvalue) in direction g
                        ndir = g @ pinvH  # newton direction
                        s.newton_curv = curv_direction(ndir)
                        # NOTE(review): this is stored on layer.weight even when `param` is the
                        # bias, so the bias iteration overwrites the weight's value — confirm.
                        setattr(layer.weight, 'pre', pinvH)  # save Newton preconditioner
                        s.step_openai = 1 / s.grad_curv if s.grad_curv else 1234567
                        s.step_div_inf = 2 / s.H_l2  # divergent step size for batch_size=infinity
                        s.step_div_1 = torch.tensor(2) / torch.trace(H)  # divergent step for batch_size=1

                        s.newton_fro = ndir.flatten().norm()  # frobenius norm of Newton update
                        s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)  # replace with "quadratic_form"
                        s.regret_gradient = loss_direction(g, s.step_openai)

                    with u.timeit(f'rho-{i}'):
                        s.rho, s.lyap_erank, lyap_evals = u.truncated_lyapunov_rho(H, sigma)
                        s.step_div_1_adjusted = s.step_div_1/s.rho

                    with u.timeit(f"batch-{i}"):
                        s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                        s.diversity = torch.norm(G, "fro") ** 2 / torch.norm(g) ** 2 / n  # Gradient diversity / n
                        s.noise_variance_pinv = torch.trace(pinvH @ sigma)
                        s.H_erank = torch.trace(H) / s.H_l2
                        s.batch_jain_simple = 1 + s.H_erank
                        s.batch_jain_full = 1 + s.rho * s.H_erank

                    param_name = f"{layer.name}={param_names[param]}"
                    u.log_scalars(u.nest_stats(f"{param_name}", s))

                    H_evals = u.symeig_pos_evals(H)
                    sigma_evals = u.symeig_pos_evals(sigma)
                    u.log_spectrum(f'{param_name}/hess', H_evals)
                    u.log_spectrum(f'{param_name}/sigma', sigma_evals)
                    u.log_spectrum(f'{param_name}/lyap', lyap_evals)

        # gradient steps
        with u.timeit('inner'):
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                optimizer.step()
                if args.weight_decay:
                    # Multiplicative decay applied directly to the weights after the step.
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            param.data.mul_(1-args.weight_decay)

                gl.increment_global_step(data.shape[0])

    gl.event_writer.close()
def _test_kron_conv_golden():
    """Guard against numeric drift in the Kronecker-factored conv Hessians.

    Builds a three-layer pooled convolutional model, verifies the exact
    Hessians against autograd, then compares the distances between the exact
    and the factored ('kron' and 'mean_kron') Hessians with hard-coded values.
    """
    u.seed_random(1)

    batch, height, width = 2, 8, 8
    kernel = (2, 2)
    channels = [3, 3, 3, 3]
    num_classes = channels[-1]

    model: u.SimpleModel = u.PooledConvolutional2(channels, kernel_size=kernel, nonlin=False, bias=True)
    data = torch.randn((batch, channels[0], height, width))

    loss_type = 'CrossEntropy'  # alternative: 'LeastSquares'
    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:  # CrossEntropy
        loss_fn = nn.CrossEntropyLoss()

    # Forward pass before hooks are attached; its shape sizes regression targets.
    sample_output = model(data)

    if loss_type.endswith('LeastSquares'):
        targets = torch.randn(sample_output.shape)
    elif loss_type == 'CrossEntropy':
        targets = torch.LongTensor(batch).random_(0, num_classes)

    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron')
    autograd_lib.compute_hess(model, method='mean_kron', attr_name='hess_mean_kron')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    kron_errors = []
    mean_kron_errors = []
    for layer in model.layers:
        weight, bias = layer.weight, layer.bias

        # Hessians computed directly by autograd_lib
        exact_w = weight.hess
        exact_b = bias.hess

        # factored Hessians, expanded to dense form
        kron_w = weight.hess_kron.expand()
        kron_b = bias.hess_kron.expand()
        mean_kron_w = weight.hess_mean_kron.expand()
        mean_kron_b = bias.hess_mean_kron.expand()

        # reference Hessians from autograd
        loss = loss_fn(output, targets)
        auto_w = u.hessian(loss, weight).reshape(exact_w.shape)
        auto_b = u.hessian(loss, bias)

        # the direct computation must agree with autograd
        u.check_close(exact_w, auto_w, rtol=1e-5, atol=1e-7)
        u.check_close(auto_b, exact_b, rtol=1e-5, atol=1e-7)

        kron_errors += [u.symsqrt_dist(exact_w, kron_w), u.symsqrt_dist(exact_b, kron_b)]
        mean_kron_errors += [u.symsqrt_dist(exact_w, mean_kron_w), u.symsqrt_dist(exact_b, mean_kron_b)]

    kron_errors = torch.tensor(kron_errors)
    mean_kron_errors = torch.tensor(mean_kron_errors)
    golden_kron = torch.tensor([0.09458080679178238, 0.0, 0.13416489958763123, 0.0, 0.0003909761435352266, 0.0])
    golden_mean_kron = torch.tensor([0.0945773795247078, 0.0, 0.13418318331241608, 0.0, 4.478318658129865e-07, 0.0])

    u.check_close(golden_kron, kron_errors)
    u.check_close(golden_mean_kron, mean_kron_errors)
def test_kron_conv_exact():
    """Check that the factored Hessian matches the exact one for 1x1 convolutions.

    Kronecker factoring is exact for 1x1 convolutions and linear activations,
    so the expanded 'mean_kron' Hessian must equal the exact Hessian, which in
    turn must equal the autograd Hessian.
    """
    u.seed_random(1)

    batch, height, width = 2, 2, 2
    kernel = (1, 1)
    channels = [2, 2, 2]
    num_classes = channels[-1]

    model: u.SimpleModel = u.PooledConvolutional2(channels, kernel_size=kernel, nonlin=False, bias=True)
    data = torch.randn((batch, channels[0], height, width))

    loss_type = 'CrossEntropy'  # alternative: 'LeastSquares'
    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:  # CrossEntropy
        loss_fn = nn.CrossEntropyLoss()

    # Forward pass before hooks are attached; its shape sizes regression targets.
    sample_output = model(data)

    if loss_type.endswith('LeastSquares'):
        targets = torch.randn(sample_output.shape)
    elif loss_type == 'CrossEntropy':
        targets = torch.LongTensor(batch).random_(0, num_classes)

    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='mean_kron')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    for layer in model.layers:
        weight, bias = layer.weight, layer.bias

        # Hessians computed directly by autograd_lib
        exact_w = weight.hess
        exact_b = bias.hess

        # factored Hessians, expanded to dense form
        factored_w = weight.hess_factored.expand()
        factored_b = bias.hess_factored.expand()

        # reference Hessians from autograd
        loss = loss_fn(output, targets)
        auto_w = u.hessian(loss, weight).reshape(exact_w.shape)
        auto_b = u.hessian(loss, bias)

        # direct vs autograd
        u.check_close(exact_w, auto_w, rtol=1e-5, atol=1e-7)
        u.check_close(auto_b, exact_b, rtol=1e-5, atol=1e-7)

        # direct vs factored
        u.check_close(exact_b, factored_b)
        u.check_close(exact_w, factored_w)
def test_hessian_multibatch():
    """Check that the Kronecker-factored Hessian survives splitting the data into batches.

    The Hessian of a single linear layer is computed three ways: by autograd,
    by autograd_lib on the full batch (exact and 'kron'), and by
    ``autograd_lib.compute_cov`` over two half-size batches.
    """
    u.seed_random(1)

    # torch.set_default_dtype(torch.float64)

    gl.project_name = 'test'
    gl.logdir_base = '/tmp/runs'
    u.setup_logdir_and_event_writer(run_name='test_hessian_multibatch')

    loss_type = 'CrossEntropy'
    data_width = 2
    n = 4
    layer_sizes = [data_width ** 2, 10]

    model = u.SimpleFullyConnected2(layer_sizes, bias=False, nonlin=False).to(gl.device)

    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    full_loader = torch.utils.data.DataLoader(dataset, batch_size=n, shuffle=False)
    full_iter = u.infinite_iter(full_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:  # loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()

    # The iterator is re-created before the single full batch is drawn.
    full_iter = u.infinite_iter(full_loader)
    data, targets = next(full_iter)

    # Record activations/backprops for the gradient and Hessian passes.
    autograd_lib.enable_hooks()
    autograd_lib.clear_backprops(model)

    output = model(data)
    loss = loss_fn(output, targets)
    loss.backward(retain_graph=True)
    layer = model.layers[0]

    autograd_lib.clear_hess_backprops(model)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.disable_hooks()

    # Exact Hessian from autograd_lib must match PyTorch autograd.
    autograd_hess = u.hessian(loss, layer.weight)
    autograd_lib.compute_hess(model)
    exact_hess = layer.weight.hess
    u.check_close(autograd_hess.reshape(exact_hess.shape), exact_hess, atol=1e-8, rtol=1e-6)

    # Factored Hessian: for cross entropy the Hessian depends on the examples,
    # so the factoring is not exact and the tolerance is raised.
    autograd_lib.compute_hess(model, method='kron', attr_name='hess2', vecr_order=True)
    kron_hess = layer.weight.hess2
    u.check_close(exact_hess, kron_hess, atol=1e-3, rtol=1e-1)

    # Multibatch Hessian: restart the iterators with half-size batches.
    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    assert n % 2 == 0
    half = n//2
    half_loader = torch.utils.data.DataLoader(dataset, batch_size=half, shuffle=False)
    half_iter = u.infinite_iter(half_loader)
    autograd_lib.compute_cov(model, loss_fn, half_iter, batch_size=half, steps=2)

    cov: autograd_lib.LayerCov = layer.cov
    commuted_hess: u.Kron = kron_hess.commute()  # get back into AA x BB order
    u.check_close(cov.H.value(), commuted_hess)
def test_factored_hessian():
    """Simple test to ensure Hessian computation is working.

    In a linear neural network with squared loss, Newton step will converge in
    one step. Computes the Hessian of a single linear layer exactly and in
    Kronecker-factored form, checks that both agree (including Hessian-vector
    products and the Newton regret), then applies a pseudo-inverse step and
    prints the resulting loss.
    """
    u.seed_random(1)
    loss_type = 'LeastSquares'
    data_width = 2
    n = 5
    layer_sizes = [data_width ** 2, 10]

    model = u.SimpleFullyConnected2(layer_sizes, bias=False, nonlin=False).to(gl.device)
    print(model)

    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    loader = torch.utils.data.DataLoader(dataset, batch_size=n, shuffle=False)
    data, targets = next(u.infinite_iter(loader))

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:  # loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()

    # Record activations/backprops for the gradient and Hessian passes.
    autograd_lib.enable_hooks()
    autograd_lib.clear_backprops(model)

    output = model(data)
    loss = loss_fn(output, targets)
    print(loss)
    loss.backward(retain_graph=True)
    layer = model.layers[0]

    autograd_lib.clear_hess_backprops(model)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.disable_hooks()

    # Exact Hessian from autograd_lib must match PyTorch autograd.
    autograd_hess = u.hessian(loss, layer.weight)
    autograd_lib.compute_hess(model)
    exact_hess = layer.weight.hess
    print(exact_hess)
    u.check_close(autograd_hess.reshape(exact_hess.shape), exact_hess, atol=1e-9, rtol=1e-6)

    # Factored Hessian must match the exact one.
    autograd_lib.compute_hess(model, method='kron', attr_name='hess2', vecr_order=True)
    # s.regret_newton = vecG.t() @ pinvH.commute() @ vecG.t() / 2  # TODO(y): figure out why needed transposes
    kron_hess = layer.weight.hess2
    u.check_close(exact_hess, kron_hess, atol=1e-9, rtol=1e-6)

    # Hessian-gradient product in regular and in factored notation.
    flat_grad = layer.weight.grad.flatten()
    vecr_grad = u.Vecr(layer.weight.grad)
    u.check_close(exact_hess @ flat_grad, vecr_grad @ kron_hess, atol=1e-9, rtol=1e-6)

    # Newton regret in regular and in factored notation.
    regret_dense = flat_grad @ exact_hess.pinverse() @ flat_grad / 2
    regret_kron = vecr_grad @ kron_hess.pinv() @ vecr_grad / 2
    u.check_close(regret_dense, regret_kron)

    param: torch.nn.Parameter = layer.weight

    # NOTE(review): no zero_grad precedes this backward, so weight.grad is
    # accumulated on top of the earlier passes — confirm this is intended.
    output = model(data)
    loss = loss_fn(output, targets)
    loss.backward()
    step = u.unvec(exact_hess.pinverse() @ u.vec(layer.weight.grad), layer.weight.shape[0])
    param.data.sub_(step)

    output = model(data)
    loss = loss_fn(output, targets)
    print("result 4", loss)

    print(loss)
def test_factored_stats_golden_values():
    """Test stats from values generated by non-factored version.

    Runs one stats collection on a single linear layer with least-squares
    loss, computes the factored statistics with
    ``autograd_lib.compute_stats_factored`` and compares them against the
    values stored in ``test/factored.pt``. Statistics that are known to differ
    from the golden set are skipped or checked with a looser tolerance.
    """
    u.seed_random(1)
    u.install_pdb_handler()
    torch.set_default_dtype(torch.float32)

    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    # Parse an empty argument list: the parser defines no options, so reading
    # sys.argv would abort with SystemExit whenever the test runner was given
    # any command-line arguments of its own.
    args = parser.parse_args([])

    logdir = u.create_local_logdir('/temp/runs/factored_test')
    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print('logging to ', logdir)

    loss_type = 'LeastSquares'

    args.data_width = 2
    args.dataset_size = 5
    d1 = args.data_width**2
    args.stats_batch_size = args.dataset_size  # one batch covers the whole dataset
    args.stats_steps = 1

    n = args.stats_batch_size
    o = 10
    d = [d1, o]

    model = u.SimpleFullyConnected2(d, bias=False, nonlin=0)
    model = model.to(gl.device)
    print(model)

    dataset = u.TinyMNIST(data_width=args.data_width, dataset_size=args.dataset_size, loss_type=loss_type)
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False)
    stats_iter = u.infinite_iter(stats_loader)
    stats_data, stats_targets = next(stats_iter)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:  # loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()
    last_outer = 0
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars({"time/outer": 1000 * (time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        data, targets = stats_data, stats_targets

        # Capture Hessian and gradient stats
        autograd_lib.enable_hooks()
        autograd_lib.clear_backprops(model)
        with u.timeit("backprop_g"):
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)

        autograd_lib.clear_hess_backprops(model)
        with u.timeit("backprop_H"):
            autograd_lib.backprop_hess(output, hess_type=loss_type)
        autograd_lib.disable_hooks()  # TODO(y): use remove_hooks

        with u.timeit("compute_grad1"):
            autograd_lib.compute_grad1(model)
        with u.timeit("compute_hess"):
            autograd_lib.compute_hess(model)
            autograd_lib.compute_hess(model, method='kron', attr_name='hess2')
        autograd_lib.compute_stats_factored(model)

    # The model has no bias, so its only parameter is the weight matrix.
    params = list(model.parameters())
    assert len(params) == 1
    new_values = params[0].stats
    golden_values = torch.load('test/factored.pt')

    for valname in new_values:
        print("Checking ", valname)
        if valname == 'sigma_l2':
            u.check_close(new_values[valname], golden_values[valname], atol=1e-2)  # sigma is approximate
        elif valname == 'sigma_erank':
            u.check_close(new_values[valname], golden_values[valname], atol=0.11)  # 1.0 vs 1.1
        elif valname in ['rho', 'step_div_1_adjusted', 'batch_jain_full']:
            continue  # lyapunov stats weren't computed correctly in golden set
        elif valname in ['batch_openai']:
            continue  # batch sizes depend on sigma which is approximate
        elif valname in ['noise_variance_pinv']:
            pass  # went from 0.22 to 0.014 after kron factoring (0.01 with full centering, 0.3 with no centering)
        elif valname in ['sparsity']:
            pass  # had a bug in old calc (using integer arithmetic)
        else:
            u.check_close(new_values[valname], golden_values[valname], rtol=1e-4, atol=1e-6, label=valname)

    gl.event_writer.close()
def compute_hess(n: int = 1, image_size: int = 1, kernel_size: int = 1, num_channels: int = 1, num_layers: int = 1,
                 nonlin: bool = False,
                 loss: str = 'CrossEntropy', method='exact', param_name='weight') -> List[torch.Tensor]:
    """Compute the Hessian of every layer of a pooled convolutional network.

    The network is built with ``u.PooledConvolutional2`` (bias enabled) and
    evaluated on random normal input after seeding the RNG with 1.

    Args:
        n: number of examples
        image_size: image width; inputs have shape (n, num_channels, 1, image_size)
        kernel_size: kernel width; kernels have shape (1, kernel_size)
        num_channels: number of channels used for the input and for every layer
        num_layers: number of layers in the network
        nonlin: forwarded to the model constructor as ``nonlin``
        loss: loss type, must be in ``autograd_lib._supported_losses``
        method: Hessian method, must be in ``autograd_lib._supported_methods``
        param_name: which parameter to compute ('weight' or 'bias')

    Returns:
        list of num_layers Hessian matrices; for methods other than 'exact' and
        'autograd' the factored Hessian is expanded before being returned.
    """
    assert param_name in ['weight', 'bias']
    assert loss in autograd_lib._supported_losses
    assert method in autograd_lib._supported_methods

    u.seed_random(1)
    channels = [num_channels] * (num_layers + 1)

    model: u.SimpleModel2 = u.PooledConvolutional2(channels, kernel_size=(1, kernel_size), nonlin=nonlin, bias=True)
    # model: u.SimpleModel2 = u.StridedConvolutional2(dd, kernel_size=(Kh, Kw), nonlin=nonlin, bias=True)
    data = torch.randn((n, channels[0], 1, image_size))

    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss)
    autograd_lib.compute_hess(model, method=method)
    autograd_lib.disable_hooks()

    stored_dense = method in ('exact', 'autograd')

    def layer_hessian(layer):
        # These methods store a dense Hessian; the others store a factored one.
        param = getattr(layer, param_name)
        if stored_dense:
            return param.hess
        return param.hess_factored.expand()

    return [layer_hessian(layer) for layer in model.layers]
def main():
    """Train a small fully connected net and log Kronecker-factored curvature stats.

    Configuration comes from the module-level ``args``. A five-layer
    ``u.SimpleFullyConnected2`` model is trained with SGD on ``u.TinyMNIST``
    using cross-entropy loss. Each of the ``args.stats_steps`` outer
    iterations:

    1. evaluates the loss on one test batch and runs ``validate`` on the
       test loader;
    2. unless ``args.skip_stats`` is set, computes per-example gradients and
       a Kronecker-factored Hessian on a fixed stats batch, lets
       ``autograd_lib.compute_stats_factored`` attach ``stats`` to the
       parameters and logs them;
    3. runs ``args.train_steps`` SGD updates, with optional multiplicative
       weight decay.

    Side effects: mutates ``args.stats_batch_size`` / ``args.train_batch_size``,
    sets ``gl.event_writer`` and closes it on completion.
    """
    u.install_pdb_handler()
    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {logdir}")

    loss_type = 'CrossEntropy'

    d1 = args.data_width ** 2
    # Batch sizes are capped at the dataset size.
    args.stats_batch_size = min(args.stats_batch_size, args.dataset_size)
    args.train_batch_size = min(args.train_batch_size, args.dataset_size)
    n = args.stats_batch_size
    o = 10
    d = [d1, 60, 60, 60, o]  # layer widths from input to output

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin, last_layer_linear=True)
    model = model.to(gl.device)
    u.mark_expensive(model.layers[0])  # to stop grad1/hess calculations on this layer
    print(model)

    # wandb is optional: report a failure and carry on without it.
    try:
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name, dir='/tmp/wandb.runs')
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.03)  # make 10x smaller for least-squares loss
    dataset = u.TinyMNIST(data_width=args.data_width, dataset_size=args.dataset_size, loss_type=loss_type)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False,
                                               drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False,
                                               drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)
    # The same stats batch is reused for every stats collection.
    stats_data, stats_targets = next(stats_iter)

    test_dataset = u.TinyMNIST(data_width=args.data_width, train=False, dataset_size=args.dataset_size,
                               loss_type=loss_type)
    test_batch_size = min(args.dataset_size, 1000)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False,
                                              drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:  # loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()
    last_outer = 0
    val_losses = []
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars({"time/outer": 1000*(time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            print("val_loss", val_loss.item())
            val_losses.append(val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        with u.timeit("validate"):
            if loss_type == 'CrossEntropy':
                val_accuracy, val_loss = validate(model, test_loader, f'test (stats_step {step})')
                # train_accuracy, train_loss = validate(model, train_loader, f'train (stats_step {step})')

                # Built inside the branch so val_accuracy is always bound when read.
                metrics = {'stats_step': step, 'val_accuracy': val_accuracy, 'val_loss': val_loss}
                u.log_scalars(metrics)

        data, targets = stats_data, stats_targets

        if not args.skip_stats:
            # Capture Hessian and gradient stats
            autograd_lib.enable_hooks()
            autograd_lib.clear_backprops(model)
            autograd_lib.clear_hess_backprops(model)
            with u.timeit("backprop_g"):
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward(retain_graph=True)
            with u.timeit("backprop_H"):
                # Follow loss_type rather than a hard-coded string so the Hessian
                # backprop always matches the loss selected above.
                autograd_lib.backprop_hess(output, hess_type=loss_type)
            autograd_lib.disable_hooks()  # TODO(y): use remove_hooks

            with u.timeit("compute_grad1"):
                autograd_lib.compute_grad1(model)
            with u.timeit("compute_hess"):
                autograd_lib.compute_hess(model, method='kron', attr_name='hess2')

            autograd_lib.compute_stats_factored(model)

            for (i, layer) in enumerate(model.layers):
                param_names = {layer.weight: "weight", layer.bias: "bias"}
                for param in [layer.weight, layer.bias]:

                    if param is None:
                        continue

                    # Parameters without a `stats` attribute have nothing to log.
                    if not hasattr(param, 'stats'):
                        continue
                    s = param.stats
                    # NOTE(review): the tag does not include the layer, so every layer
                    # logs under the same "weight"/"bias" names — confirm intended.
                    param_name = param_names[param]
                    u.log_scalars(u.nest_stats(f"{param_name}", s))

        # gradient steps
        with u.timeit('inner'):
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                optimizer.step()
                if args.weight_decay:
                    # Multiplicative decay applied directly to the weights after the step.
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            param.data.mul_(1-args.weight_decay)

                gl.increment_global_step(data.shape[0])

    gl.event_writer.close()