def test_kfac_jacobian_mnist():
    """Check accumulated Jacobian statistics against ``u.jacobian``.

    Feeds ``stats_steps`` batches through a ``u.SimpleMLP`` built with
    ``nonlin=False`` while module hooks accumulate, per layer, the sums
    ``A^T A`` (``AA``), ``B^T B`` (``BB``) and ``(B*B)^T (A*A)`` (``diag``).
    The product of the normalized ``BB`` and ``AA`` factors and the normalized
    ``diag`` are then compared with values derived from the Jacobian returned
    by ``u.jacobian`` on the concatenated inputs.
    """
    u.seed_random(1)

    data_width = 3
    layer_sizes = [data_width**2, 8, 10]
    model: u.SimpleMLP = u.SimpleMLP(layer_sizes, nonlin=False)
    autograd_lib.register(model)

    batch_size = 4
    stats_steps = 2
    n = batch_size * stats_steps  # total number of examples processed

    dataset = u.TinyMNIST(dataset_size=n, data_width=data_width, original_targets=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    batches = iter(loader)
    loss_fn = torch.nn.CrossEntropyLoss()

    activations = {}
    jacobians = defaultdict(lambda: AttrDefault(float))
    total_data = []

    def on_forward(layer, A, _):
        # Remember the layer input and accumulate its second moment.
        activations[layer] = A
        jacobians[layer].AA += torch.einsum("ni,nj->ij", A, A)

    def on_backward(layer, _, B):
        # Pair the backprop value with the activation saved in the forward pass.
        A = activations[layer]
        jacobians[layer].BB += torch.einsum("ni,nj->ij", B, B)
        jacobians[layer].diag += torch.einsum("ni,nj->ij", B * B, A * A)

    # sum up statistics over n examples
    for _ in range(stats_steps):
        data, targets = next(batches)
        total_data.append(data)
        activations = {}  # start each step with a fresh activation table

        with autograd_lib.module_hook(on_forward):
            output = model(data)
            loss = loss_fn(output, targets)

        with autograd_lib.module_hook(on_backward):
            autograd_lib.backward_jacobian(output)

    for layer in model.layers:
        layer_stats = jacobians[layer]
        normalized_bb = layer_stats.BB / n
        normalized_aa = layer_stats.AA / n
        jacobian_full = torch.einsum('kl,ij->kilj', normalized_bb, normalized_aa)
        jacobian_diag = layer_stats.diag / n

        J = u.jacobian(model(torch.cat(total_data)), layer.weight)
        J_autograd = torch.einsum('noij,nokl->ijkl', J, J) / n
        u.check_equal(jacobian_full, J_autograd)
        u.check_equal(jacobian_diag, torch.einsum('ikik->ik', J_autograd))
def test_full_hessian_xent_mnist_multilayer():
    """Test regular and diagonal hessian computation.

    For a two-layer ``u.SimpleFullyConnected2`` model, module hooks accumulate
    a full per-layer Hessian estimate and a diagonal estimate during
    ``autograd_lib.backward_hessian``. Both are divided by the batch size and
    compared, for layers 0 and 1, with the result of ``u.hessian`` on the
    cross-entropy loss.
    """
    u.seed_random(1)

    data_width = 3
    batch_size = 2
    layer_sizes = [data_width**2, 6, 10]
    o = layer_sizes[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(layer_sizes, nonlin=False, bias=True)
    autograd_lib.register(model)

    dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    batches = iter(loader)
    loss_fn = torch.nn.CrossEntropyLoss()

    hess = defaultdict(float)
    hess_diag = defaultdict(float)

    for _ in range(train_steps):
        data, targets = next(batches)
        activations = {}

        def remember_activations(layer, a, _):
            activations[layer] = a

        with autograd_lib.module_hook(remember_activations):
            output = model(data)
            loss = loss_fn(output, targets)

        def accumulate_hessian(layer, _, B):
            A = activations[layer]
            # Per-example outer product of backprop and activation.
            BA = torch.einsum("nl,ni->nli", B, A)
            hess[layer] += torch.einsum('nli,nkj->likj', BA, BA)
            hess_diag[layer] += torch.einsum("ni,nj->ij", B * B, A * A)

        with autograd_lib.module_hook(accumulate_hessian):
            autograd_lib.backward_hessian(output, loss='CrossEntropy', retain_graph=True)

        # compute Hessian through autograd, one layer at a time
        for layer_idx in (0, 1):
            layer = model.layers[layer_idx]
            H_autograd = u.hessian(loss, layer.weight)
            u.check_close(hess[layer] / batch_size, H_autograd)
            diag_autograd = torch.einsum('lili->li', H_autograd)
            u.check_close(diag_autograd, hess_diag[layer] / batch_size)
def test_kfac_fisher_mnist():
    """Check that the factored statistic's diagonal matches the accumulated one.

    Accumulates per-layer ``AA``, ``BB`` and ``diag`` sums over ``stats_steps``
    batches using module hooks around the forward pass and
    ``autograd_lib.backward_jacobian``. The diagonal extracted from the product
    of the normalized ``BB`` and ``AA`` factors is compared with the normalized
    ``diag`` accumulator.
    """
    u.seed_random(1)

    data_width = 3
    layer_sizes = [data_width**2, 8, 10]
    model: u.SimpleMLP = u.SimpleMLP(layer_sizes, nonlin=False)
    autograd_lib.register(model)

    batch_size = 4
    stats_steps = 2
    n = batch_size * stats_steps  # total number of examples processed

    dataset = u.TinyMNIST(dataset_size=n, data_width=data_width, original_targets=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    batches = iter(loader)
    loss_fn = torch.nn.CrossEntropyLoss()

    activations = {}
    fishers = defaultdict(lambda: AttrDefault(float))
    total_data = []

    def on_forward(layer, A, _):
        activations[layer] = A
        fishers[layer].AA += torch.einsum("ni,nj->ij", A, A)

    def on_backward(layer, _, B):
        A = activations[layer]
        fishers[layer].BB += torch.einsum("ni,nj->ij", B, B)
        fishers[layer].diag += torch.einsum("ni,nj->ij", B * B, A * A)

    # sum up statistics over n examples
    for _ in range(stats_steps):
        data, targets = next(batches)
        total_data.append(data)
        activations = {}  # start each step with a fresh activation table

        with autograd_lib.module_hook(on_forward):
            output = model(data)
            # remove data normalization
            loss = loss_fn(output, targets) * len(data)

        with autograd_lib.module_hook(on_backward):
            autograd_lib.backward_jacobian(output)

    for layer in model.layers:
        layer_stats = fishers[layer]
        normalized_bb = layer_stats.BB / n
        normalized_aa = layer_stats.AA / n
        fisher_full = torch.einsum('kl,ij->kilj', normalized_bb, normalized_aa)
        fisher_diag = layer_stats.diag / n
        u.check_equal(torch.einsum('ikik->ik', fisher_full), fisher_diag)
def _test_kfac_hessian_xent_mnist():
    """Compare a factored Hessian estimate with ``u.hessian`` (relative error).

    The leading underscore keeps this function out of default pytest
    collection. It accumulates ``AA`` and ``BB`` sums for each layer during
    ``autograd_lib.backward_hessian`` and asserts that the estimate built from
    them for layer 0 is within 1% relative error of the autograd Hessian.
    """
    u.seed_random(1)

    data_width = 3
    batch_size = 2
    layer_sizes = [data_width**2, 10]
    o = layer_sizes[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(layer_sizes, nonlin=False, bias=True)
    autograd_lib.register(model)

    dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    batches = iter(loader)
    loss_fn = torch.nn.CrossEntropyLoss()

    activations = {}
    hess = defaultdict(lambda: AttrDefault(float))

    def remember_activations(layer, a, _):
        activations[layer] = a

    def accumulate_factors(layer, _, B):
        A = activations[layer]
        hess[layer].AA += torch.einsum("ni,nj->ij", A, A)
        hess[layer].BB += torch.einsum("ni,nj->ij", B, B)

    for _ in range(train_steps):
        data, targets = next(batches)
        activations = {}

        with autograd_lib.module_hook(remember_activations):
            output = model(data)
            loss = loss_fn(output, targets)

        with autograd_lib.module_hook(accumulate_factors):
            autograd_lib.backward_hessian(output, loss='CrossEntropy', retain_graph=True)

    factors = hess[model.layers[0]]
    # hess for sum loss
    hess0 = torch.einsum('kl,ij->kilj', factors.BB / n, factors.AA / o)
    hess0 /= n  # hess for mean loss

    # compute Hessian through autograd
    H_autograd = u.hessian(loss, model.layers[0].weight)
    difference = (hess0 - H_autograd).flatten()
    rel_error = torch.norm(difference) / torch.norm(H_autograd.flatten())
    assert rel_error < 0.01  # 0.0057
def test_full_hessian_xent_kfac2():
    """Test with uneven layers.

    Processes three copies of one example, one at a time, accumulating ``AA``
    in the forward hook and ``BB`` in the backward hook. The estimate built
    from these factors for layer 0 is compared with ``u.hessian`` on the
    full-batch loss.

    NOTE: sets the global default dtype to float64 and does not restore it.
    """
    u.seed_random(1)
    torch.set_default_dtype(torch.float64)

    batch_size = 1
    layer_sizes = [3, 2]
    o = layer_sizes[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(layer_sizes, nonlin=True, bias=False)
    autograd_lib.register(model)
    loss_fn = torch.nn.CrossEntropyLoss()

    # One example replicated three times.
    data = u.to_logits(torch.tensor([[0.7, 0.2, 0.1]])).repeat([3, 1])
    targets = torch.tensor([0]).repeat([3])
    n = len(data)

    activations = {}
    hess = defaultdict(lambda: AttrDefault(float))

    def on_forward(layer, A, _):
        activations[layer] = A
        hess[layer].AA += torch.einsum("ni,nj->ij", A, A)

    def on_backward(layer, _, B):
        hess[layer].BB += torch.einsum("ni,nj->ij", B, B)

    for start in range(n):
        stop = start + 1
        with autograd_lib.module_hook(on_forward):
            data_batch = data[start:stop]
            targets_batch = targets[start:stop]
            Y = model(data_batch)
            o = Y.shape[1]  # overwrite with the observed output width
            loss = loss_fn(Y, targets_batch)

        with autograd_lib.module_hook(on_backward):
            autograd_lib.backward_hessian(Y, loss='CrossEntropy')

    # expand
    factors = hess[model.layers[0]]
    # hess for sum loss
    hess0 = torch.einsum('kl,ij->kilj', factors.BB / n, factors.AA / o)
    hess0 /= n  # hess for mean loss

    # check against autograd
    # 0.1459
    Y = model(data)
    loss = loss_fn(Y, targets)
    hess_autograd = u.hessian(loss, model.layers[0].weight)
    u.check_equal(hess_autograd, hess0)
def _test_refactored_stats():
    """Smoke-run the covariance accumulation API of ``autograd_lib``.

    The leading underscore keeps this function out of default pytest
    collection. It builds a single-layer model on ``TinyMNIST``, captures
    activations in one forward pass, accumulates them per layer, and then
    calls ``autograd_lib.backward_accum`` three times with different backward
    specifications. It makes no assertions.
    """
    gl.project_name = 'test'
    gl.logdir_base = '/tmp/runs'
    run_name = 'test_hessian_multibatch'
    u.setup_logdir_and_event_writer(run_name=run_name)

    loss_type = 'CrossEntropy'
    data_width = 2
    n = 4
    d1 = data_width ** 2
    o = 10
    layer_sizes = [d1, o]

    model = u.SimpleFullyConnected2(layer_sizes, bias=False, nonlin=False)
    model = model.to(gl.device)

    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=n, shuffle=False)
    stats_iter = u.infinite_iter(stats_loader)

    # Anything other than 'LeastSquares' is treated as cross-entropy.
    loss_fn = u.least_squares if loss_type == 'LeastSquares' else nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()
    last_outer = 0

    stats_iter = u.infinite_iter(stats_loader)
    stats_data, stats_targets = next(stats_iter)
    data, targets = stats_data, stats_targets

    covG = autograd_lib.layer_cov_dict()
    covH = autograd_lib.layer_cov_dict()
    covJ = autograd_lib.layer_cov_dict()

    autograd_lib.register(model)

    A = {}
    with autograd_lib.save_activations(A):
        output = model(data)
        loss = loss_fn(output, targets)

    Acov = autograd_lib.ModuleDict(autograd_lib.SecondOrder)
    for layer, layer_activations in A.items():
        Acov[layer].accumulate(layer_activations)

    # set activations to use by default when constructing cov matrices
    autograd_lib.set_default_activations(A)
    autograd_lib.set_default_Acov(Acov)

    # saves backprop covariances
    autograd_lib.backward_accum(loss, 1, covG)
    autograd_lib.backward_accum(output, autograd_lib.xent_bwd, covH)
    autograd_lib.backward_accum(output, autograd_lib.identity_bwd, covJ)
def test_full_hessian_xent_multibatch():
    """Accumulate a full Hessian over single-example batches and check it.

    Three copies of one example are processed one at a time. The backward hook
    adds the outer product of per-example ``B``/``A`` products to a per-layer
    accumulator, and the mean over examples for layer 0 is compared with
    ``u.hessian`` on the full-batch loss.

    NOTE: sets the global default dtype to float64 and does not restore it.
    """
    u.seed_random(1)
    torch.set_default_dtype(torch.float64)

    batch_size = 1
    layer_sizes = [2, 2]
    o = layer_sizes[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(layer_sizes, nonlin=True, bias=True)
    model.layers[0].weight.data.copy_(torch.eye(2))
    autograd_lib.register(model)
    loss_fn = torch.nn.CrossEntropyLoss()

    # One example replicated three times.
    data = u.to_logits(torch.tensor([[0.7, 0.3]])).repeat([3, 1])
    targets = torch.tensor([0]).repeat([3])
    n = len(data)

    activations = {}
    hess = defaultdict(float)

    def remember_activations(layer, a, _):
        activations[layer] = a

    def accumulate_hessian(layer, _, B):
        A = activations[layer]
        BA = torch.einsum("nl,ni->nli", B, A)
        hess[layer] += torch.einsum('nli,nkj->likj', BA, BA)

    for start in range(n):
        stop = start + 1
        with autograd_lib.module_hook(remember_activations):
            data_batch = data[start:stop]
            targets_batch = targets[start:stop]
            Y = model(data_batch)
            loss = loss_fn(Y, targets_batch)

        with autograd_lib.module_hook(accumulate_hessian):
            autograd_lib.backward_hessian(Y, loss='CrossEntropy')

    # check against autograd
    # 0.1459
    Y = model(data)
    loss = loss_fn(Y, targets)
    first_layer = model.layers[0]
    hess_autograd = u.hessian(loss, first_layer.weight)
    hess0 = hess[first_layer] / n
    u.check_equal(hess_autograd, hess0)
def _test_activations_contextmanager():
    """Check that context ids recorded by ``save_activations`` increase.

    The leading underscore keeps this function out of default pytest
    collection. After running the whole model in one ``save_activations``
    context and only ``model[1]`` in another, the last captured context id of
    ``model[1]`` must be exactly one greater than that of ``model[0]``.
    """
    d = 5
    model = simple_model(d, num_layers=2)
    autograd_lib.register(model)

    A1, A2, A3 = {}, {}, {}
    x = torch.ones(1, d)

    # NOTE(review): source formatting was lost; the two contexts are assumed
    # to be sequential rather than nested. Confirm against the original file.
    with autograd_lib.save_activations(A1):
        y = model(x)

    second_layer = model[1]
    with autograd_lib.save_activations(A2):
        z = second_layer(x)

    context_ids = autograd_lib.global_settings.last_captured_activations_contextid
    expected_id = context_ids[model[0]] + 1
    assert context_ids[model[1]] == expected_id
def create_toy_model():
    """Build the two-layer bias-free toy model with fixed weights.

    Model from
    https://www.wolframcloud.com/obj/yaroslavvb/newton/linear-jacobians-and-hessians.nb

    PyTorch works on transposed representation, hence to obtain Y from
    notebook, do model(A.T).T

    Returns:
        Tuple ``(A, model)``: the fixed 2x2 tensor ``A`` and the registered
        model whose layer 0 weight is ``X`` and layer 1 weight is ``B.t()``.
    """
    toy: u.SimpleFullyConnected2 = u.SimpleFullyConnected2([2, 2, 2], bias=False)
    autograd_lib.register(toy)

    A = torch.tensor([[-1., 4], [3, 0]])
    B = torch.tensor([[-4., 3], [2, 6]])
    X = torch.tensor([[-5., 0], [-2, -6]], requires_grad=True)

    # Overwrite the initial weights in place with the fixed values.
    first_layer = toy.layers[0]
    second_layer = toy.layers[1]
    first_layer.weight.data.copy_(X)
    second_layer.weight.data.copy_(B.t())

    return A, toy
def test_hooks():
    """Check activation and backprop capture through ``autograd_lib`` contexts.

    Runs a five-layer ``simple_model`` of width 1 and verifies that:

    * ``save_activations`` stores each layer's input in the given dict,
      including when two contexts are nested (``A2`` / ``A3``);
    * ``extend_backprops`` stores, per layer, a list holding the backprop
      value of each backward pass made inside the context, including for a
      layer whose weight has ``requires_grad`` switched off.

    The expected values equal the model input / seed gradient, so
    ``simple_model`` presumably builds identity-like layers (not visible here).

    ``autograd_lib.unregister()`` runs in a ``finally`` clause so a failing
    assertion does not leave hooks registered for later tests.
    """
    d = 1
    model = simple_model(d, num_layers=5)
    autograd_lib.register(model)

    try:
        A1, A2, A3 = {}, {}, {}
        x = torch.ones(1, d)

        with autograd_lib.save_activations(A1):
            y = model(2 * x)

        # Nested contexts: both dictionaries should capture the same pass.
        with autograd_lib.save_activations(A2):
            with autograd_lib.save_activations(A3):
                y = model(x)

        B1 = {}
        B2 = {}
        with autograd_lib.extend_backprops(B1):
            y.backward(x, retain_graph=True)

        # Backprops must still be captured for a layer with a frozen weight.
        model[2].weight.requires_grad = False
        for layer in model:
            del layer.weight.grad

        with autograd_lib.extend_backprops(B2):
            y.backward(2 * x)

        print(B2.values())
        for layer in model:
            print(layer.weight.grad)

        # d == 1, so each comparison yields a one-element tensor that is
        # valid in a boolean context.
        for layer in model:
            assert A1[layer] == 2 * x
            assert A2[layer] == x
            assert A3[layer] == x
            assert B1[layer] == [x]
            assert B2[layer] == [2 * x]
    finally:
        # Always detach the hooks, even if an assertion above failed.
        autograd_lib.unregister()
def _test_explicit_hessian_refactored():
    """Check computation of hessian of loss(B'WA) from
    https://github.com/yaroslavvb/kfac_pytorch/blob/master/derivation.pdf

    The leading underscore keeps this function out of default pytest
    collection. The test proceeds in stages:

    1. Compute ``Y = B' X A`` with plain matrix products and take the Hessian
       of ``sum(Y*Y)/2`` with respect to ``X`` via ``u.hessian``.
    2. Rebuild the same computation with a two-layer bias-free
       ``u.SimpleFullyConnected2`` and check outputs and gradients.
    3. Build the Hessian through the ``autograd_lib`` accumulation API and
       compare it with hard-coded values.
    4. Compare Newton steps computed from the autograd Hessian and from the
       factored Hessian in column-major and row-major vectorization.

    NOTE: sets the global default dtype to float64 and does not restore it.
    """
    torch.set_default_dtype(torch.float64)

    A = torch.tensor([[-1., 4], [3, 0]])
    B = torch.tensor([[-4., 3], [2, 6]])
    X = torch.tensor([[-5., 0], [-2, -6]], requires_grad=True)

    Y = B.t() @ X @ A
    u.check_equal(Y, [[-52, 64], [-81, -108]])
    loss = torch.sum(Y * Y) / 2
    hess0 = u.hessian(loss, X).reshape([4, 4])
    hess1 = u.Kron(A @ A.t(), B @ B.t())  # not referenced again in this function
    u.check_equal(loss, 12512.5)

    # Do a test using Linear layers instead of matrix multiplies
    model: u.SimpleFullyConnected2 = u.SimpleFullyConnected2([2, 2, 2], bias=False)
    model.layers[0].weight.data.copy_(X)

    # Transpose to match previous results, layers treat dim0 as batch dimension
    u.check_equal(model.layers[0](A.t()).t(), [[5, -20], [-16, -8]])  # XA = (A'X0)'
    model.layers[1].weight.data.copy_(B.t())
    u.check_equal(model(A.t()).t(), Y)

    Y = model(A.t()).t()  # transpose to data-dimension=columns
    loss = torch.sum(Y * Y) / 2
    loss.backward()

    u.check_equal(model.layers[0].weight.grad, [[-2285, -105], [-1490, -1770]])
    G = B @ Y @ A.t()
    u.check_equal(model.layers[0].weight.grad, G)

    autograd_lib.register(model)
    activations_dict = autograd_lib.ModuleDict()  # todo(y): make save_activations ctx manager automatically create A
    with autograd_lib.save_activations(activations_dict):
        Y = model(A.t())

    Acov = autograd_lib.ModuleDict(autograd_lib.SecondOrderCov)
    for layer, activations in activations_dict.items():
        print(layer, activations)
        Acov[layer].accumulate(activations, activations)
    autograd_lib.set_default_activations(activations_dict)
    autograd_lib.set_default_Acov(Acov)

    # From here on ``B`` no longer refers to the input tensor defined above;
    # ``G`` was already computed from the tensor before this rebinding.
    B = autograd_lib.ModuleDict(autograd_lib.SymmetricFourthOrderCov)
    autograd_lib.backward_accum(Y, "identity", B, retain_graph=False)
    print(B[model.layers[0]])

    autograd_lib.backprop_hess(Y, hess_type='LeastSquares')
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron', vecr_order=False,
                              loss_aggregation='sum')
    param = model.layers[0].weight

    hess2 = param.hess_kron
    print(hess2)

    u.check_equal(hess2, [[425, 170, -75, -30], [170, 680, -30, -120], [-75, -30, 225, 90], [-30, -120, 90, 360]])

    # Gradient test
    # NOTE(review): ``loss`` was already backpropagated once above without
    # retain_graph; confirm this second backward is valid when re-enabling.
    model.zero_grad()
    loss.backward()
    u.check_close(u.vec(G).flatten(), u.Vec(param.grad))

    # Newton step test
    # Method 0: PyTorch native autograd
    newton_step0 = param.grad.flatten() @ torch.pinverse(hess0)
    newton_step0 = newton_step0.reshape(param.shape)
    u.check_equal(newton_step0, [[-5, 0], [-2, -6]])

    # Method 1: column major order
    ihess2 = hess2.pinv()
    u.check_equal(ihess2.LL, [[1/16, 1/48], [1/48, 17/144]])
    u.check_equal(ihess2.RR, [[2/45, -(1/90)], [-(1/90), 1/36]])
    u.check_equal(torch.flatten(hess2.pinv() @ u.vec(G)), [-5, -2, 0, -6])
    newton_step1 = (ihess2 @ u.Vec(param.grad)).matrix_form()

    # Method 2: row major order
    ihess2_rowmajor = ihess2.commute()
    newton_step2 = ihess2_rowmajor @ u.Vecr(param.grad)
    newton_step2 = newton_step2.matrix_form()

    # All three methods must agree.
    u.check_equal(newton_step0, newton_step1)
    u.check_equal(newton_step0, newton_step2)
def main(): parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser.add_argument('--batch-size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)') parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', help='input batch size for testing (default: 1000)') parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs to train (default: 10)') parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument('--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status') parser.add_argument('--save-model', action='store_true', default=False, help='For Saving the current Model') parser.add_argument('--wandb', type=int, default=0, help='log to weights and biases') parser.add_argument('--autograd_check', type=int, default=0, help='autograd correctness checks') parser.add_argument('--logdir', type=str, default='/tmp/runs/curv_train_tiny/run') parser.add_argument('--nonlin', type=int, default=1, help="whether to add ReLU nonlinearity between layers") parser.add_argument('--bias', type=int, default=1, help="whether to add bias between layers") parser.add_argument('--layer', type=int, default=-1, help="restrict updates to this layer") parser.add_argument('--data_width', type=int, default=28) parser.add_argument('--targets_width', type=int, default=28) parser.add_argument('--hess_samples', type=int, default=1, help='number of samples when sub-sampling outputs, 0 for exact hessian') parser.add_argument('--hess_kfac', type=int, default=0, help='whether to use KFAC approximation for hessian') parser.add_argument('--compute_rho', type=int, default=0, help='use expensive method to compute rho') parser.add_argument('--skip_stats', type=int, default=0, help='skip all stats 
collection') parser.add_argument('--dataset_size', type=int, default=60000) parser.add_argument('--train_steps', type=int, default=100, help="this many train steps between stat collection") parser.add_argument('--stats_steps', type=int, default=1000000, help="total number of curvature stats collections") parser.add_argument('--full_batch', type=int, default=0, help='do stats on the whole dataset') parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--weight_decay', type=float, default=0) parser.add_argument('--momentum', type=float, default=0.9) parser.add_argument('--dropout', type=int, default=0) parser.add_argument('--swa', type=int, default=0) parser.add_argument('--lmb', type=float, default=1e-3) parser.add_argument('--train_batch_size', type=int, default=64) parser.add_argument('--stats_batch_size', type=int, default=10000) parser.add_argument('--stats_num_batches', type=int, default=1) parser.add_argument('--run_name', type=str, default='noname') parser.add_argument('--launch_blocking', type=int, default=0) parser.add_argument('--sampled', type=int, default=0) parser.add_argument('--curv', type=str, default='kfac', help='decomposition to use for curvature estimates: zero_order, kfac, isserlis or full') parser.add_argument('--log_spectra', type=int, default=0) u.seed_random(1) gl.args = parser.parse_args() args = gl.args u.seed_random(1) gl.project_name = 'train_ciresan' u.setup_logdir_and_event_writer(args.run_name) print(f"Logging to {gl.logdir}") d1 = 28 * 28 d = [784, 2500, 2000, 1500, 1000, 500, 10] # number of samples per datapoint. 
Used to normalize kfac model = u.SimpleFullyConnected2(d, nonlin=args.nonlin, bias=args.bias, dropout=args.dropout) model = model.to(gl.device) autograd_lib.register(model) assert args.dataset_size >= args.stats_batch_size optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, original_targets=True, dataset_size=args.dataset_size) train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True, drop_last=True) train_iter = u.infinite_iter(train_loader) assert not args.full_batch, "fixme: validation still uses stats_iter" if not args.full_batch: stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=True, drop_last=True) stats_iter = u.infinite_iter(stats_loader) else: stats_iter = None test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, train=False, original_targets=True, dataset_size=args.dataset_size) test_eval_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=False) train_eval_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=False) loss_fn = torch.nn.CrossEntropyLoss() autograd_lib.add_hooks(model) autograd_lib.disable_hooks() gl.token_count = 0 last_outer = 0 for step in range(args.stats_steps): epoch = gl.token_count // 60000 lr = optimizer.param_groups[0]['lr'] print('token_count', gl.token_count) if last_outer: u.log_scalars({"time/outer": 1000 * (time.perf_counter() - last_outer)}) print(f'time: {time.perf_counter() - last_outer:.2f}') last_outer = time.perf_counter() with u.timeit("validate"): val_accuracy, val_loss = validate(model, test_eval_loader, f'test (epoch {epoch})') train_accuracy, train_loss = validate(model, train_eval_loader, f'train (epoch {epoch})') # save log metrics = {'epoch': epoch, 'val_accuracy': 
val_accuracy, 'val_loss': val_loss, 'train_loss': train_loss, 'train_accuracy': train_accuracy, 'lr': optimizer.param_groups[0]['lr'], 'momentum': optimizer.param_groups[0].get('momentum', 0)} u.log_scalars(metrics) def mom_update(buffer, val): buffer *= 0.9 buffer += val * 0.1 if not args.skip_stats: # number of samples passed through n = args.stats_batch_size * args.stats_num_batches # quanti forward_stats = defaultdict(lambda: AttrDefault(float)) hessians = defaultdict(lambda: AttrDefault(float)) jacobians = defaultdict(lambda: AttrDefault(float)) fishers = defaultdict(lambda: AttrDefault(float)) # empirical fisher/gradient quad_fishers = defaultdict(lambda: AttrDefault(float)) # gradient statistics that depend on fisher (4th order moments) train_regrets = defaultdict(list) test_regrets1 = defaultdict(list) test_regrets2 = defaultdict(list) train_regrets_opt = defaultdict(list) test_regrets_opt = defaultdict(list) cosines = defaultdict(list) dot_products = defaultdict(list) hessians_histograms = defaultdict(lambda: AttrDefault(u.MyList)) jacobians_histograms = defaultdict(lambda: AttrDefault(u.MyList)) fishers_histograms = defaultdict(lambda: AttrDefault(u.MyList)) quad_fishers_histograms = defaultdict(lambda: AttrDefault(u.MyList)) current = None current_histograms = None for i in range(args.stats_num_batches): activations = {} backprops = {} def save_activations(layer, A, _): activations[layer] = A forward_stats[layer].AA += torch.einsum("ni,nj->ij", A, A) print('forward') with u.timeit("stats_forward"): with autograd_lib.module_hook(save_activations): data, targets = next(stats_iter) output = model(data) loss = loss_fn(output, targets) * len(output) def compute_stats(layer, _, B): A = activations[layer] if current == fishers: backprops[layer] = B # about 27ms per layer with u.timeit('compute_stats'): current[layer].BB += torch.einsum("ni,nj->ij", B, B) # TODO(y): index consistency current[layer].diag += torch.einsum("ni,nj->ij", B * B, A * A) 
current[layer].BA += torch.einsum("ni,nj->ij", B, A) current[layer].a += torch.einsum("ni->i", A) current[layer].b += torch.einsum("nk->k", B) current[layer].norm2 += ((A * A).sum(dim=1) * (B * B).sum(dim=1)).sum() # compute curvatures in direction of all gradiennts if current is fishers: assert args.stats_num_batches == 1, "not tested on more than one stats step, currently reusing aggregated moments" hess = hessians[layer] jac = jacobians[layer] Bh, Ah = B @ hess.BB / n, A @ forward_stats[layer].AA / n Bj, Aj = B @ jac.BB / n, A @ forward_stats[layer].AA / n norms = ((A * A).sum(dim=1) * (B * B).sum(dim=1)) current[layer].min_norm2 = min(norms) current[layer].median_norm2 = torch.median(norms) current[layer].max_norm2 = max(norms) norms2_hess = ((Ah * A).sum(dim=1) * (Bh * B).sum(dim=1)) norms2_jac = ((Aj * A).sum(dim=1) * (Bj * B).sum(dim=1)) current[layer].norm += norms.sum() current_histograms[layer].norms.extend(torch.sqrt(norms)) current[layer].curv_hess += (skip_nans(norms2_hess / norms)).sum() current_histograms[layer].curv_hess.extend(skip_nans(norms2_hess / norms)) current[layer].curv_hess_max += (skip_nans(norms2_hess / norms)).max() current[layer].curv_hess_median += (skip_nans(norms2_hess / norms)).median() current_histograms[layer].curv_jac.extend(skip_nans(norms2_jac / norms)) current[layer].curv_jac += (skip_nans(norms2_jac / norms)).sum() current[layer].curv_jac_max += (skip_nans(norms2_jac / norms)).max() current[layer].curv_jac_median += (skip_nans(norms2_jac / norms)).median() current[layer].a_sparsity += torch.sum(A <= 0).float() / A.numel() current[layer].b_sparsity += torch.sum(B <= 0).float() / B.numel() current[layer].mean_activation += torch.mean(A) current[layer].mean_activation2 += torch.mean(A*A) current[layer].mean_backprop = torch.mean(B) current[layer].mean_backprop2 = torch.mean(B*B) current[layer].norms_hess += torch.sqrt(norms2_hess).sum() current_histograms[layer].norms_hess.extend(torch.sqrt(norms2_hess)) 
current[layer].norms_jac += norms2_jac.sum() current_histograms[layer].norms_jac.extend(torch.sqrt(norms2_jac)) normalized_moments = copy.copy(hessians[layer]) normalized_moments.AA = forward_stats[layer].AA normalized_moments = u.divide_attributes(normalized_moments, n) train_regrets_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=0, m=normalized_moments, approx=args.curv) test_regrets1_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=1, m=normalized_moments, approx=args.curv) test_regrets2_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=2, m=normalized_moments, approx=args.curv) test_regrets_opt_ = autograd_lib.offset_losses(A, B, alpha=None, offset=2, m=normalized_moments, approx=args.curv) train_regrets_opt_ = autograd_lib.offset_losses(A, B, alpha=None, offset=0, m=normalized_moments, approx=args.curv) cosines_ = autograd_lib.offset_cosines(A, B) train_regrets[layer].extend(train_regrets_) test_regrets1[layer].extend(test_regrets1_) test_regrets2[layer].extend(test_regrets2_) train_regrets_opt[layer].extend(train_regrets_opt_) test_regrets_opt[layer].extend(test_regrets_opt_) cosines[layer].extend(cosines_) dot_products[layer].extend(autograd_lib.offset_dotprod(A, B)) # statistics of the form g.Sigma.g elif current == quad_fishers: hess = hessians[layer] sigma = fishers[layer] jac = jacobians[layer] Bs, As = B @ sigma.BB / n, A @ forward_stats[layer].AA / n Bh, Ah = B @ hess.BB / n, A @ forward_stats[layer].AA / n Bj, Aj = B @ jac.BB / n, A @ forward_stats[layer].AA / n norms = ((A * A).sum(dim=1) * (B * B).sum(dim=1)) norms2_hess = ((Ah * A).sum(dim=1) * (Bh * B).sum(dim=1)) norms2_jac = ((Aj * A).sum(dim=1) * (Bj * B).sum(dim=1)) norms_sigma = ((As * A).sum(dim=1) * (Bs * B).sum(dim=1)) current[layer].norm += norms.sum() # TODO(y) remove, redundant with norm2 above current[layer].curv_sigma += (skip_nans(norms_sigma / norms)).sum() current[layer].curv_sigma_max = skip_nans(norms_sigma / norms).max() current[layer].curv_sigma_median = 
skip_nans(norms_sigma / norms).median() current[layer].curv_hess += skip_nans(norms2_hess / norms).sum() current[layer].curv_hess_max += skip_nans(norms2_hess / norms).max() current[layer].lyap_hess_mean += skip_nans(norms_sigma / norms2_hess).mean() current[layer].lyap_hess_max = max(skip_nans(norms_sigma/norms2_hess)) current[layer].lyap_jac_mean += skip_nans(norms_sigma / norms2_jac).mean() current[layer].lyap_jac_max = max(skip_nans(norms_sigma/norms2_jac)) print('backward') with u.timeit("backprop_H"): with autograd_lib.module_hook(compute_stats): current = hessians current_histograms = hessians_histograms autograd_lib.backward_hessian(output, loss='CrossEntropy', sampled=args.sampled, retain_graph=True) # 600 ms current = jacobians current_histograms = jacobians_histograms autograd_lib.backward_jacobian(output, sampled=args.sampled, retain_graph=True) # 600 ms current = fishers current_histograms = fishers_histograms model.zero_grad() loss.backward(retain_graph=True) # 60 ms current = quad_fishers current_histograms = quad_fishers_histograms model.zero_grad() loss.backward() # 60 ms print('summarize') for (i, layer) in enumerate(model.layers): stats_dict = {'hessian': hessians, 'jacobian': jacobians, 'fisher': fishers} # evaluate stats from # https://app.wandb.ai/yaroslavvb/train_ciresan/runs/425pu650?workspace=user-yaroslavvb for stats_name in stats_dict: s = AttrDict() stats = stats_dict[stats_name][layer] for key in forward_stats[layer]: # print(f'copying {key} in {stats_name}, {layer}') try: assert stats[key] == float() except: f"Trying to overwrite {key} in {stats_name}, {layer}" stats[key] = forward_stats[layer][key] diag: torch.Tensor = stats.diag / n # jacobian: # curv in direction of gradient goes down to roughly 0.3-1 # maximum curvature goes up to 1000-2000 # # Hessian: # max curv goes down to 1, in direction of gradient 0.0001 s.diag_l2 = torch.max(diag) # 40 - 3000 smaller than kfac l2 for jac s.diag_fro = torch.norm( diag) # jacobian grows to 
0.5-1.5, rest falls, layer-5 has phase transition, layer-4 also s.diag_trace = diag.sum() # jacobian grows 0-1000 (first), 0-150 (last). Almost same as kfac_trace (771 vs 810 kfac). Jacobian has up/down phase transition s.diag_average = diag.mean() # normalize for mean loss BB = stats.BB / n AA = stats.AA / n # A_evals, _ = torch.symeig(AA) # averaging 120ms per hit, 90 hits # B_evals, _ = torch.symeig(BB) # s.kfac_l2 = torch.max(A_evals) * torch.max(B_evals) # 60x larger than diag_l2. layer0/hess has down/up phase transition. layer5/jacobian has up/down phase transition s.kfac_trace = torch.trace(AA) * torch.trace(BB) # 0/hess down/up tr, 5/jac sharp phase transition s.kfac_fro = torch.norm(stats.AA) * torch.norm( stats.BB) # 0/hess has down/up tr, 5/jac up/down transition # s.kfac_erank = s.kfac_trace / s.kfac_l2 # first layer has 25, rest 15, all layers go down except last, last noisy # s.kfac_erank_fro = s.kfac_trace / s.kfac_fro / max(stats.BA.shape) s.diversity = (stats.norm2 / n) / u.norm_squared( stats.BA / n) # gradient diversity. Goes up 3x. Bottom layer has most diversity. 
Jacobian diversity much less noisy than everythingelse # discrepancy of KFAC based on exact values of diagonal approximation # average difference normalized by average diagonal magnitude diag_kfac = torch.einsum('ll,ii->li', BB, AA) s.kfac_error = (torch.abs(diag_kfac - diag)).mean() / torch.mean(diag.abs()) u.log_scalars(u.nest_stats(f'layer-{i}/{stats_name}', s)) # openai batch size stat s = AttrDict() hess = hessians[layer] jac = jacobians[layer] fish = fishers[layer] quad_fish = quad_fishers[layer] # the following check passes, but is expensive # if args.stats_num_batches == 1: # u.check_close(fisher[layer].BA, layer.weight.grad) def trsum(A, B): return (A * B).sum() # computes tr(AB') grad = fishers[layer].BA / n s.grad_fro = torch.norm(grad) # get norms s.lyap_hess_max = quad_fish.lyap_hess_max s.lyap_hess_ave = quad_fish.lyap_hess_sum / n s.lyap_jac_max = quad_fish.lyap_jac_max s.lyap_jac_ave = quad_fish.lyap_jac_sum / n s.hess_trace = hess.diag.sum() / n s.jac_trace = jac.diag.sum() / n # Version 1 of Jain stochastic rates, use Hessian for curvature b = args.train_batch_size s.hess_curv = trsum((hess.BB / n) @ grad @ (hess.AA / n), grad) / trsum(grad, grad) s.jac_curv = trsum((jac.BB / n) @ grad @ (jac.AA / n), grad) / trsum(grad, grad) # compute gradient noise statistics # fish.BB has /n factor twice, hence don't need extra /n on fish.AA # after sampling, hess_noise,jac_noise became 100x smaller, but normalized is unaffected s.hess_noise = (trsum(hess.AA / n, fish.AA / n) * trsum(hess.BB / n, fish.BB / n)) s.jac_noise = (trsum(jac.AA / n, fish.AA / n) * trsum(jac.BB / n, fish.BB / n)) s.hess_noise_centered = s.hess_noise - trsum(hess.BB / n @ grad, grad @ hess.AA / n) s.jac_noise_centered = s.jac_noise - trsum(jac.BB / n @ grad, grad @ jac.AA / n) s.openai_gradient_noise = (fish.norms_hess / n) / trsum(hess.BB / n @ grad, grad @ hess.AA / n) s.mean_norm = torch.sqrt(fish.norm2) / n s.min_norm = torch.sqrt(fish.min_norm2) s.median_norm = 
torch.sqrt(fish.median_norm2) s.max_norm = torch.sqrt(fish.max_norm2) s.enorms = u.norm_squared(grad) s.a_sparsity = fish.a_sparsity s.b_sparsity = fish.b_sparsity s.mean_activation = fish.mean_activation s.msr_activation = torch.sqrt(fish.mean_activation2) s.mean_backprop = fish.mean_backprop s.msr_backprop = torch.sqrt(fish.mean_backprop2) s.norms_centered = fish.norm2 / n - u.norm_squared(grad) s.norms_hess = fish.norms_hess / n s.norms_jac = fish.norms_jac / n s.hess_curv_grad = fish.curv_hess / n # phase transition, hits minimum loss in layer 1, then starts going up. Other layers take longer to reach minimum. Decreases with depth. s.hess_curv_grad_max = fish.curv_hess_max # phase transition, hits minimum loss in layer 1, then starts going up. Other layers take longer to reach minimum. Decreases with depth. s.hess_curv_grad_median = fish.curv_hess_median # phase transition, hits minimum loss in layer 1, then starts going up. Other layers take longer to reach minimum. Decreases with depth. s.sigma_curv_grad = quad_fish.curv_sigma / n s.sigma_curv_grad_max = quad_fish.curv_sigma_max s.sigma_curv_grad_median = quad_fish.curv_sigma_median s.band_bottou = 0.5 * lr * s.sigma_curv_grad / s.hess_curv_grad s.band_bottou_stoch = 0.5 * lr * quad_fish.curv_ratio / n s.band_yaida = 0.25 * lr * s.mean_norm**2 s.band_yaida_centered = 0.25 * lr * s.norms_centered s.jac_curv_grad = fish.curv_jac / n # this one has much lower variance than jac_curv. Reaches peak at 10k steps, also kfac error reaches peak there. Decreases with depth except for last layer. s.jac_curv_grad_max = fish.curv_jac_max # this one has much lower variance than jac_curv. Reaches peak at 10k steps, also kfac error reaches peak there. Decreases with depth except for last layer. s.jac_curv_grad_median = fish.curv_jac_median # this one has much lower variance than jac_curv. Reaches peak at 10k steps, also kfac error reaches peak there. Decreases with depth except for last layer. 
# OpenAI gradient noise statistics s.hess_noise_normalized = s.hess_noise_centered / (fish.norms_hess / n) s.jac_noise_normalized = s.jac_noise / (fish.norms_jac / n) train_regrets_, test_regrets1_, test_regrets2_, train_regrets_opt_, test_regrets_opt_, cosines_, dot_products_ = (torch.stack(r[layer]) for r in (train_regrets, test_regrets1, test_regrets2, train_regrets_opt, test_regrets_opt, cosines, dot_products)) s.train_regret = train_regrets_.median() # use median because outliers make it hard to see the trend s.test_regret1 = test_regrets1_.median() s.test_regret2 = test_regrets2_.median() s.test_regret_opt = test_regrets_opt_.median() s.train_regret_opt = train_regrets_opt_.median() s.mean_dot_product = torch.mean(dot_products_) s.median_dot_product = torch.median(dot_products_) a = [1, 2, 3] s.median_cosine = cosines_.median() s.mean_cosine = cosines_.mean() # get learning rates L1 = s.hess_curv_grad / n L2 = s.jac_curv_grad / n diversity = (fish.norm2 / n) / u.norm_squared(grad) robust_diversity = (fish.norm2 / n) / fish.median_norm2 dotprod_diversity = fish.median_norm2 / s.median_dot_product s.lr1 = 2 / (L1 * diversity) s.lr2 = 2 / (L2 * diversity) s.lr3 = 2 / (L2 * robust_diversity) s.lr4 = 2 / (L2 * dotprod_diversity) hess_A = u.symeig_pos_evals(hess.AA / n) hess_B = u.symeig_pos_evals(hess.BB / n) fish_A = u.symeig_pos_evals(fish.AA / n) fish_B = u.symeig_pos_evals(fish.BB / n) jac_A = u.symeig_pos_evals(jac.AA / n) jac_B = u.symeig_pos_evals(jac.BB / n) u.log_scalars({f'layer-{i}/hessA_erank': erank(hess_A)}) u.log_scalars({f'layer-{i}/hessB_erank': erank(hess_B)}) u.log_scalars({f'layer-{i}/fishA_erank': erank(fish_A)}) u.log_scalars({f'layer-{i}/fishB_erank': erank(fish_B)}) u.log_scalars({f'layer-{i}/jacA_erank': erank(jac_A)}) u.log_scalars({f'layer-{i}/jacB_erank': erank(jac_B)}) gl.event_writer.add_histogram(f'layer-{i}/hist_hess_eig', u.outer(hess_A, hess_B).flatten(), gl.get_global_step()) 
gl.event_writer.add_histogram(f'layer-{i}/hist_fish_eig', u.outer(hess_A, hess_B).flatten(), gl.get_global_step()) gl.event_writer.add_histogram(f'layer-{i}/hist_jac_eig', u.outer(hess_A, hess_B).flatten(), gl.get_global_step()) s.hess_l2 = max(hess_A) * max(hess_B) s.jac_l2 = max(jac_A) * max(jac_B) s.fish_l2 = max(fish_A) * max(fish_B) s.hess_trace = hess.diag.sum() / n s.jain1_sto = 1/(s.hess_trace + 2 * s.hess_l2) s.jain1_det = 1/s.hess_l2 s.jain1_lr = (1 / b) * (1/s.jain1_sto) + (b - 1) / b * (1/s.jain1_det) s.jain1_lr = 2 / s.jain1_lr s.regret_ratio = ( train_regrets_opt_ / test_regrets_opt_).median() # ratio between train and test regret, large means overfitting u.log_scalars(u.nest_stats(f'layer-{i}', s)) # compute stats that would let you bound rho if i == 0: # only compute this once, for output layer hhh = hessians[model.layers[-1]].BB / n fff = fishers[model.layers[-1]].BB / n d = fff.shape[0] L = u.lyapunov_spectral(hhh, 2 * fff, cond=1e-8) L_evals = u.symeig_pos_evals(L) Lcheap = fff @ u.pinv(hhh, cond=1e-8) Lcheap_evals = u.eig_real(Lcheap) u.log_scalars({f'mismatch/rho': d/erank(L_evals)}) u.log_scalars({f'mismatch/rho_cheap': d/erank(Lcheap_evals)}) u.log_scalars({f'mismatch/diagonalizability': erank(L_evals)/erank(Lcheap_evals)}) # 1 means diagonalizable u.log_spectrum(f'mismatch/sigma', u.symeig_pos_evals(fff), loglog=False) u.log_spectrum(f'mismatch/hess', u.symeig_pos_evals(hhh), loglog=False) u.log_spectrum(f'mismatch/lyapunov', L_evals, loglog=True) u.log_spectrum(f'mismatch/lyapunov_cheap', Lcheap_evals, loglog=True) gl.event_writer.add_histogram(f'layer-{i}/hist_grad_norms', u.to_numpy(fishers_histograms[layer].norms.value()), gl.get_global_step()) gl.event_writer.add_histogram(f'layer-{i}/hist_grad_norms_hess', u.to_numpy(fishers_histograms[layer].norms_hess.value()), gl.get_global_step()) gl.event_writer.add_histogram(f'layer-{i}/hist_curv_jac', u.to_numpy(fishers_histograms[layer].curv_jac.value()), gl.get_global_step()) 
gl.event_writer.add_histogram(f'layer-{i}/hist_curv_hess', u.to_numpy(fishers_histograms[layer].curv_hess.value()), gl.get_global_step()) gl.event_writer.add_histogram(f'layer-{i}/hist_cosines', u.to_numpy(cosines[layer]), gl.get_global_step()) if args.log_spectra: with u.timeit('spectrum'): # 2/alpha # s.jain1_lr = (1 / b) * s.jain1_sto + (b - 1) / b * s.jain1_det # s.jain1_lr = 1 / s.jain1_lr # hess.diag_trace, jac.diag_trace # Version 2 of Jain stochastic rates, use Jacobian squared for curvature s.jain2_sto = s.lyap_jac_max * s.jac_trace / s.lyap_jac_ave s.jain2_det = s.jac_l2 s.jain2_lr = (1 / b) * s.jain2_sto + (b - 1) / b * s.jain2_det s.jain2_lr = 1 / s.jain2_lr u.log_spectrum(f'layer-{i}/hess_A', hess_A) u.log_spectrum(f'layer-{i}/hess_B', hess_B) u.log_spectrum(f'layer-{i}/hess_AB', u.outer(hess_A, hess_B).flatten()) u.log_spectrum(f'layer-{i}/jac_A', jac_A) u.log_spectrum(f'layer-{i}/jac_B', jac_B) u.log_spectrum(f'layer-{i}/fish_A', fish_A) u.log_spectrum(f'layer-{i}/fish_B', fish_B) u.log_scalars({f'layer-{i}/trace_ratio': fish_B.sum()/hess_B.sum()}) L = torch.eig(u.lyapunov_spectral(hess.BB, 2*fish.BB, cond=1e-8))[0] L = L[:, 0] # extract real part L = L.sort()[0] L = torch.flip(L, [0]) L_cheap = torch.eig(fish.BB @ u.pinv(hess.BB, cond=1e-8))[0] L_cheap = L_cheap[:, 0] # extract real part L_cheap = L_cheap.sort()[0] L_cheap = torch.flip(L_cheap, [0]) d = len(hess_B) u.log_spectrum(f'layer-{i}/Lyap', L) u.log_spectrum(f'layer-{i}/Lyap_cheap', L_cheap) u.log_scalars({f'layer-{i}/dims': d}) u.log_scalars({f'layer-{i}/L_erank': erank(L)}) u.log_scalars({f'layer-{i}/L_cheap_erank': erank(L_cheap)}) u.log_scalars({f'layer-{i}/rho': d/erank(L)}) u.log_scalars({f'layer-{i}/rho_cheap': d/erank(L_cheap)}) model.train() with u.timeit('train'): for i in range(args.train_steps): optimizer.zero_grad() data, targets = next(train_iter) model.zero_grad() output = model(data) loss = loss_fn(output, targets) loss.backward() optimizer.step() if args.weight_decay: for 
group in optimizer.param_groups: for param in group['params']: param.data.mul_(1 - args.weight_decay) gl.token_count += data.shape[0] gl.event_writer.close()
def test_grad_norms():
    """Test computing gradient norms using various methods.

    The test runs in three stages:

    1. Iterate over ``stats_steps`` batches and accumulate per-layer moment
       sums (``a``, ``b``, ``AA``, ``BB``, ``BA``, ``BABA``) through forward
       and backward module hooks.
    2. Run all samples as a single batch and, inside a backward hook, collect
       per-example gradient norms from ``autograd_lib.grad_norms`` for each
       approximation kind ('zero_order', 'kfac', 'isserlis', 'full'), passing
       in the moments divided by the number of samples.
    3. For every layer, check the 'zero_order' norms against squared
       per-example gradient norms obtained from ``u.jacobian``, and check that
       the 'kfac' and 'isserlis' norms stay within 1e-4 of the 'full' norms.
    """
    u.seed_random(1)
    # torch.set_default_dtype(torch.float64)

    data_width = 3
    batch_size = 2
    d = [data_width**2, 6, 10]  # layer sizes handed to SimpleMLP
    stats_steps = 2
    num_samples = batch_size * stats_steps  # number of samples used in computation of curvature stats

    model: u.SimpleModel = u.SimpleMLP(d, nonlin=True, bias=True)
    loss_fn = torch.nn.CrossEntropyLoss()
    autograd_lib.register(model)

    dataset = u.TinyMNIST(dataset_size=num_samples, data_width=data_width, original_targets=True)
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    stats_iter = iter(stats_loader)

    # Per-layer accumulators; a missing attribute starts out as float() == 0.0
    # and becomes a tensor after the first `+=`.
    moments = defaultdict(lambda: AttrDefault(float))
    norms = defaultdict(lambda: AttrDefault(MyList))
    data_batches = []
    targets_batches = []

    for _ in range(stats_steps):
        data, targets = next(stats_iter)
        data_batches.append(data)
        targets_batches.append(targets)

        # Rebound on every batch; both hooks below read this same mapping.
        activations = {}

        def forward_aggregate(layer, A, _):
            # A is the per-example value the forward hook reports for `layer`
            # (presumably the layer input -- TODO confirm in autograd_lib).
            activations[layer] = A
            moments[layer].AA += torch.einsum('ni,nj->ij', A, A)
            moments[layer].a += torch.einsum("ni->i", A)

        with autograd_lib.module_hook(forward_aggregate):
            output = model(data)
            # NOTE(review): the returned loss is discarded here; the call is
            # kept exactly as in the original.
            loss_fn(output, targets)

        def backward_aggregate(layer, _, B):
            # B is the per-example value the backward hook reports for `layer`
            # (presumably the backpropagated signal -- TODO confirm).
            A = activations[layer]
            moments[layer].b += torch.einsum("nk->k", B)
            moments[layer].BA += torch.einsum("nl,ni->li", B, A)
            moments[layer].BB += torch.einsum("nk,nl->kl", B, B)
            moments[layer].BABA += torch.einsum('nl,ni,nk,nj->likj', B, A, B, A)

        with autograd_lib.module_hook(backward_aggregate):
            autograd_lib.backward_hessian(output, loss='CrossEntropy', retain_graph=True)

    # compare against results using autograd
    data = torch.cat(data_batches)
    targets = torch.cat(targets_batches)

    with autograd_lib.save_activations2() as activations:
        loss = loss_fn(model(data), targets)

    def normalize_moments(layer_moments, n):
        """Return an AttrDict holding every tensor-valued moment divided by n."""
        result = AttrDict()
        for key in layer_moments:
            value = layer_moments[key]
            # Entries that were never accumulated into are not tensors and
            # are left out of the result.
            if isinstance(value, torch.Tensor):
                result[key] = value / n
        return result

    def compute_norms(layer, _, B):
        A = activations[layer]
        # The normalized moments depend only on the layer, not on the
        # approximation kind, so compute them once per hook invocation.
        normalized_moments = normalize_moments(moments[layer], num_samples)
        for kind in ('zero_order', 'kfac', 'isserlis', 'full'):
            norms_list = getattr(norms[layer], kind)
            norms_list.extend(autograd_lib.grad_norms(A, B, normalized_moments, approx=kind))

    with autograd_lib.module_hook(compute_norms):
        model.zero_grad()
        # CrossEntropyLoss averages over the batch by default, so scaling by
        # the batch size backpropagates the summed loss instead of the mean.
        (len(data) * loss).backward(retain_graph=True)

    print(norms[model.layers[0]].zero_order.value())

    for layer in model.layers:
        output = model(data)
        # One scalar loss per example, so that the Jacobian of `losses` with
        # respect to the weight stacks the per-example gradients.
        losses = torch.stack([
            loss_fn(output[i:i + 1], targets[i:i + 1])
            for i in range(len(data))
        ])
        grads = u.jacobian(losses, layer.weight)
        grad_norms = torch.einsum('nij,nij->n', grads, grads)
        u.check_close(grad_norms, norms[layer].zero_order)

        # test gradient norms with custom metric
        kfac_norms, isserlis_norms, full_norms = [
            u.to_pytorch(getattr(norms[layer], kind))
            for kind in ('kfac', 'isserlis', 'full')
        ]
        error_kfac = max(abs(kfac_norms - full_norms))
        error_isserlis = max(abs(isserlis_norms - full_norms))
        assert error_isserlis < 1e-4
        assert error_kfac < 1e-4