def test_full_hessian_xent_mnist_multilayer():
    """Compare hook-accumulated full and diagonal Hessians with autograd.

    Builds a linear (nonlin=False) network with biases on a tiny MNIST
    batch, accumulates per-layer Hessian estimates from the forward inputs
    and the values delivered by ``backward_hessian``, and checks both the
    full 4-index Hessian and its diagonal against ``u.hessian`` for the
    weights of layers 0 and 1.
    """
    u.seed_random(1)

    data_width = 3
    batch_size = 2
    layer_dims = [data_width**2, 6, 10]
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(layer_dims, nonlin=False, bias=True)
    autograd_lib.register(model)
    dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    batches = iter(loader)

    loss_fn = torch.nn.CrossEntropyLoss()

    full_hess = defaultdict(float)   # layer -> accumulated 4-index Hessian
    diag_hess = defaultdict(float)   # layer -> accumulated Hessian diagonal
    for _ in range(train_steps):
        data, targets = next(batches)

        layer_inputs = {}

        def save_activations(layer, a, _):
            # forward hook: remember the value handed to each layer
            layer_inputs[layer] = a

        with autograd_lib.module_hook(save_activations):
            output = model(data)
            loss = loss_fn(output, targets)

        def compute_hess(layer, _, B):
            # backward hook: combine saved forward values with backprops
            A = layer_inputs[layer]
            per_example = torch.einsum("nl,ni->nli", B, A)
            full_hess[layer] += torch.einsum('nli,nkj->likj', per_example, per_example)
            diag_hess[layer] += torch.einsum("ni,nj->ij", B * B, A * A)

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(output, loss='CrossEntropy', retain_graph=True)

    # reference Hessians come from autograd, one layer at a time
    for layer in (model.layers[0], model.layers[1]):
        H_autograd = u.hessian(loss, layer.weight)
        u.check_close(full_hess[layer] / batch_size, H_autograd)
        diag_autograd = torch.einsum('lili->li', H_autograd)
        u.check_close(diag_autograd, diag_hess[layer] / batch_size)
def _test_kfac_hessian_xent_mnist():
    """Check a Kronecker-factored Hessian estimate against autograd.

    Accumulates the AA (forward) and BB (backward) second moments of the
    single layer of a linear model, expands them into a 4-index tensor and
    requires the relative error to the autograd Hessian to be below 1%.
    The leading underscore keeps the function out of pytest's default
    ``test_*`` collection.
    """
    u.seed_random(1)

    data_width = 3
    batch_size = 2
    dims = [data_width**2, 10]
    num_outputs = dims[-1]
    num_examples = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(dims, nonlin=False, bias=True)
    autograd_lib.register(model)
    dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    batches = iter(loader)

    loss_fn = torch.nn.CrossEntropyLoss()

    saved = {}
    factors = defaultdict(lambda: AttrDefault(float))

    def save_activations(layer, a, _):
        saved[layer] = a

    def compute_hess(layer, _, B):
        A = saved[layer]
        factors[layer].AA += torch.einsum("ni,nj->ij", A, A)
        factors[layer].BB += torch.einsum("ni,nj->ij", B, B)

    for _ in range(train_steps):
        data, targets = next(batches)
        saved.clear()  # start every step with an empty activation map

        with autograd_lib.module_hook(save_activations):
            output = model(data)
            loss = loss_fn(output, targets)

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(output, loss='CrossEntropy', retain_graph=True)

    first = model.layers[0]
    layer_factors = factors[first]
    # Kronecker expansion gives the Hessian of the summed loss ...
    hess0 = torch.einsum('kl,ij->kilj',
                         layer_factors.BB / num_examples,
                         layer_factors.AA / num_outputs)
    # ... which is rescaled to the mean loss used by CrossEntropyLoss
    hess0 /= num_examples

    # compute Hessian through autograd
    H_autograd = u.hessian(loss, first.weight)
    diff_norm = torch.norm((hess0 - H_autograd).flatten())
    ref_norm = torch.norm(H_autograd.flatten())
    rel_error = diff_norm / ref_norm
    assert rel_error < 0.01  # 0.0057
def test_full_hessian_xent_kfac2():
    """Kronecker-factored Hessian check on a layer with unequal in/out sizes.

    Feeds three copies of one example through a 3->2 model one example at a
    time, accumulating AA in the forward hook and BB in the backward hook,
    then compares the expanded estimate with the autograd Hessian of the
    loss on the full batch.

    Note: switches the global default dtype to float64 and does not restore it.
    """
    u.seed_random(1)
    torch.set_default_dtype(torch.float64)

    dims = [3, 2]
    num_outputs = dims[-1]

    model: u.SimpleModel = u.SimpleFullyConnected2(dims, nonlin=True, bias=False)
    autograd_lib.register(model)
    loss_fn = torch.nn.CrossEntropyLoss()

    data = u.to_logits(torch.tensor([[0.7, 0.2, 0.1]])).repeat([3, 1])
    targets = torch.tensor([0]).repeat([3])
    num_examples = len(data)

    activations = {}
    moments = defaultdict(lambda: AttrDefault(float))

    def save_activations(layer, A, _):
        # forward hook: store the activation and add its outer product
        activations[layer] = A
        moments[layer].AA += torch.einsum("ni,nj->ij", A, A)

    def compute_hess(layer, _, B):
        # backward hook: add the outer product of the backprop values
        moments[layer].BB += torch.einsum("ni,nj->ij", B, B)

    for start in range(num_examples):
        stop = start + 1
        with autograd_lib.module_hook(save_activations):
            Y = model(data[start:stop])
            num_outputs = Y.shape[1]
            loss = loss_fn(Y, targets[start:stop])

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(Y, loss='CrossEntropy')

    # expand the two factors into the 4-index Hessian of the summed loss
    first = model.layers[0]
    layer_moments = moments[first]
    hess0 = torch.einsum('kl,ij->kilj',
                         layer_moments.BB / num_examples,
                         layer_moments.AA / num_outputs)
    hess0 /= num_examples  # rescale to the mean loss

    # check against autograd
    # 0.1459
    Y = model(data)
    loss = loss_fn(Y, targets)
    hess_autograd = u.hessian(loss, first.weight)
    u.check_equal(hess_autograd, hess0)
def test_full_hessian_xent_multibatch():
    """Full Hessian accumulated over several size-1 batches matches autograd.

    The first layer's weight is overwritten with the 2x2 identity, three
    copies of one example are processed one at a time, and the averaged
    accumulated Hessian is compared with the autograd Hessian of the loss
    on the full batch.

    Note: switches the global default dtype to float64 and does not restore it.
    """
    u.seed_random(1)
    torch.set_default_dtype(torch.float64)

    dims = [2, 2]

    model: u.SimpleModel = u.SimpleFullyConnected2(dims, nonlin=True, bias=True)
    model.layers[0].weight.data.copy_(torch.eye(2))
    autograd_lib.register(model)
    loss_fn = torch.nn.CrossEntropyLoss()

    data = u.to_logits(torch.tensor([[0.7, 0.3]])).repeat([3, 1])
    targets = torch.tensor([0]).repeat([3])
    num_examples = len(data)

    activations = {}
    hess = defaultdict(float)

    def save_activations(layer, a, _):
        activations[layer] = a

    def compute_hess(layer, _, B):
        A = activations[layer]
        BA = torch.einsum("nl,ni->nli", B, A)
        hess[layer] += torch.einsum('nli,nkj->likj', BA, BA)

    for start in range(num_examples):
        stop = start + 1
        with autograd_lib.module_hook(save_activations):
            Y = model(data[start:stop])
            loss = loss_fn(Y, targets[start:stop])

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(Y, loss='CrossEntropy')

    # check against autograd
    # 0.1459
    Y = model(data)
    loss = loss_fn(Y, targets)
    first = model.layers[0]
    hess_autograd = u.hessian(loss, first.weight)
    u.check_equal(hess_autograd, hess[first] / num_examples)
def main():
    """Train a fully connected MNIST classifier while logging curvature statistics.

    Steps, as implemented below:
      1. Parse command-line flags into ``gl.args`` and set up the log directory.
      2. Build the model, SGD optimizer and train/stats/eval data loaders.
      3. For each of ``args.stats_steps`` outer steps:
         a. evaluate on the test and train loaders and log the metrics;
         b. unless ``--skip_stats`` is set, run one forward pass and four
            hooked backward passes (Hessian, Jacobian, Fisher, "quad Fisher")
            on ``args.stats_num_batches`` batches, then summarize and log
            per-layer statistics;
         c. run ``args.train_steps`` SGD updates.
      4. Close the event writer.

    The function takes no arguments and returns nothing; all results are
    emitted through ``u.log_scalars`` / ``gl.event_writer`` and stdout.
    """
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')

    parser.add_argument('--wandb', type=int, default=0, help='log to weights and biases')
    parser.add_argument('--autograd_check', type=int, default=0, help='autograd correctness checks')
    parser.add_argument('--logdir', type=str, default='/tmp/runs/curv_train_tiny/run')

    parser.add_argument('--nonlin', type=int, default=1, help="whether to add ReLU nonlinearity between layers")
    parser.add_argument('--bias', type=int, default=1, help="whether to add bias between layers")

    parser.add_argument('--layer', type=int, default=-1, help="restrict updates to this layer")
    parser.add_argument('--data_width', type=int, default=28)
    parser.add_argument('--targets_width', type=int, default=28)
    parser.add_argument('--hess_samples', type=int, default=1,
                        help='number of samples when sub-sampling outputs, 0 for exact hessian')
    parser.add_argument('--hess_kfac', type=int, default=0, help='whether to use KFAC approximation for hessian')
    parser.add_argument('--compute_rho', type=int, default=0, help='use expensive method to compute rho')
    parser.add_argument('--skip_stats', type=int, default=0, help='skip all stats collection')

    parser.add_argument('--dataset_size', type=int, default=60000)
    parser.add_argument('--train_steps', type=int, default=100, help="this many train steps between stat collection")
    parser.add_argument('--stats_steps', type=int, default=1000000, help="total number of curvature stats collections")

    parser.add_argument('--full_batch', type=int, default=0, help='do stats on the whole dataset')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--dropout', type=int, default=0)
    parser.add_argument('--swa', type=int, default=0)
    parser.add_argument('--lmb', type=float, default=1e-3)

    parser.add_argument('--train_batch_size', type=int, default=64)
    parser.add_argument('--stats_batch_size', type=int, default=10000)
    parser.add_argument('--stats_num_batches', type=int, default=1)
    parser.add_argument('--run_name', type=str, default='noname')
    parser.add_argument('--launch_blocking', type=int, default=0)
    parser.add_argument('--sampled', type=int, default=0)
    parser.add_argument('--curv', type=str, default='kfac',
                        help='decomposition to use for curvature estimates: zero_order, kfac, isserlis or full')
    parser.add_argument('--log_spectra', type=int, default=0)

    u.seed_random(1)
    gl.args = parser.parse_args()
    args = gl.args
    u.seed_random(1)

    gl.project_name = 'train_ciresan'
    u.setup_logdir_and_event_writer(args.run_name)
    print(f"Logging to {gl.logdir}")

    d = [784, 2500, 2000, 1500, 1000, 500, 10]

    # number of samples per datapoint. Used to normalize kfac
    model = u.SimpleFullyConnected2(d, nonlin=args.nonlin, bias=args.bias, dropout=args.dropout)
    model = model.to(gl.device)
    autograd_lib.register(model)

    assert args.dataset_size >= args.stats_batch_size
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width,
                          original_targets=True, dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size,
                                               shuffle=True, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    assert not args.full_batch, "fixme: validation still uses stats_iter"
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size,
                                                   shuffle=True, drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)
    else:
        stats_iter = None

    test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, train=False,
                               original_targets=True, dataset_size=args.dataset_size)
    test_eval_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.stats_batch_size,
                                                   shuffle=False, drop_last=False)
    train_eval_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size,
                                                    shuffle=False, drop_last=False)

    loss_fn = torch.nn.CrossEntropyLoss()
    autograd_lib.add_hooks(model)
    autograd_lib.disable_hooks()

    gl.token_count = 0
    last_outer = 0
    for step in range(args.stats_steps):
        epoch = gl.token_count // 60000
        lr = optimizer.param_groups[0]['lr']
        print('token_count', gl.token_count)
        if last_outer:
            u.log_scalars({"time/outer": 1000 * (time.perf_counter() - last_outer)})
            print(f'time: {time.perf_counter() - last_outer:.2f}')
        last_outer = time.perf_counter()

        with u.timeit("validate"):
            val_accuracy, val_loss = validate(model, test_eval_loader, f'test (epoch {epoch})')
            train_accuracy, train_loss = validate(model, train_eval_loader, f'train (epoch {epoch})')

        # save log
        metrics = {'epoch': epoch, 'val_accuracy': val_accuracy, 'val_loss': val_loss,
                   'train_loss': train_loss, 'train_accuracy': train_accuracy,
                   'lr': optimizer.param_groups[0]['lr'],
                   'momentum': optimizer.param_groups[0].get('momentum', 0)}
        u.log_scalars(metrics)

        if not args.skip_stats:
            # number of samples passed through
            n = args.stats_batch_size * args.stats_num_batches

            # per-layer accumulators, one family per backward pass
            forward_stats = defaultdict(lambda: AttrDefault(float))
            hessians = defaultdict(lambda: AttrDefault(float))
            jacobians = defaultdict(lambda: AttrDefault(float))
            fishers = defaultdict(lambda: AttrDefault(float))  # empirical fisher/gradient
            quad_fishers = defaultdict(lambda: AttrDefault(float))  # gradient statistics that depend on fisher (4th order moments)
            train_regrets = defaultdict(list)
            test_regrets1 = defaultdict(list)
            test_regrets2 = defaultdict(list)
            train_regrets_opt = defaultdict(list)
            test_regrets_opt = defaultdict(list)
            cosines = defaultdict(list)
            dot_products = defaultdict(list)
            hessians_histograms = defaultdict(lambda: AttrDefault(u.MyList))
            jacobians_histograms = defaultdict(lambda: AttrDefault(u.MyList))
            fishers_histograms = defaultdict(lambda: AttrDefault(u.MyList))
            quad_fishers_histograms = defaultdict(lambda: AttrDefault(u.MyList))

            # `current` / `current_histograms` select which accumulator family the
            # backward hook writes to; they are rebound before each backward pass
            # and read by compute_stats through its closure.
            current = None
            current_histograms = None

            for _ in range(args.stats_num_batches):
                activations = {}
                backprops = {}

                def save_activations(layer, A, _):
                    activations[layer] = A
                    forward_stats[layer].AA += torch.einsum("ni,nj->ij", A, A)

                print('forward')
                with u.timeit("stats_forward"):
                    with autograd_lib.module_hook(save_activations):
                        data, targets = next(stats_iter)
                        output = model(data)
                        loss = loss_fn(output, targets) * len(output)

                def compute_stats(layer, _, B):
                    A = activations[layer]
                    # Identity, not equality: `==` on these defaultdicts compares
                    # contents, so two still-empty accumulators would compare equal.
                    if current is fishers:
                        backprops[layer] = B

                    # about 27ms per layer
                    with u.timeit('compute_stats'):
                        current[layer].BB += torch.einsum("ni,nj->ij", B, B)  # TODO(y): index consistency
                        current[layer].diag += torch.einsum("ni,nj->ij", B * B, A * A)
                        current[layer].BA += torch.einsum("ni,nj->ij", B, A)
                        current[layer].a += torch.einsum("ni->i", A)
                        current[layer].b += torch.einsum("nk->k", B)
                        current[layer].norm2 += ((A * A).sum(dim=1) * (B * B).sum(dim=1)).sum()

                        # compute curvatures in direction of all gradients
                        if current is fishers:
                            assert args.stats_num_batches == 1, "not tested on more than one stats step, currently reusing aggregated moments"

                            hess = hessians[layer]
                            jac = jacobians[layer]
                            Bh, Ah = B @ hess.BB / n, A @ forward_stats[layer].AA / n
                            Bj, Aj = B @ jac.BB / n, A @ forward_stats[layer].AA / n
                            norms = ((A * A).sum(dim=1) * (B * B).sum(dim=1))
                            current[layer].min_norm2 = min(norms)
                            current[layer].median_norm2 = torch.median(norms)
                            current[layer].max_norm2 = max(norms)

                            norms2_hess = ((Ah * A).sum(dim=1) * (Bh * B).sum(dim=1))
                            norms2_jac = ((Aj * A).sum(dim=1) * (Bj * B).sum(dim=1))

                            current[layer].norm += norms.sum()
                            current_histograms[layer].norms.extend(torch.sqrt(norms))
                            current[layer].curv_hess += (skip_nans(norms2_hess / norms)).sum()
                            current_histograms[layer].curv_hess.extend(skip_nans(norms2_hess / norms))
                            current[layer].curv_hess_max += (skip_nans(norms2_hess / norms)).max()
                            current[layer].curv_hess_median += (skip_nans(norms2_hess / norms)).median()
                            current_histograms[layer].curv_jac.extend(skip_nans(norms2_jac / norms))
                            current[layer].curv_jac += (skip_nans(norms2_jac / norms)).sum()
                            current[layer].curv_jac_max += (skip_nans(norms2_jac / norms)).max()
                            current[layer].curv_jac_median += (skip_nans(norms2_jac / norms)).median()

                            current[layer].a_sparsity += torch.sum(A <= 0).float() / A.numel()
                            current[layer].b_sparsity += torch.sum(B <= 0).float() / B.numel()

                            current[layer].mean_activation += torch.mean(A)
                            current[layer].mean_activation2 += torch.mean(A*A)
                            current[layer].mean_backprop = torch.mean(B)
                            current[layer].mean_backprop2 = torch.mean(B*B)

                            current[layer].norms_hess += torch.sqrt(norms2_hess).sum()
                            current_histograms[layer].norms_hess.extend(torch.sqrt(norms2_hess))
                            # NOTE(review): unlike norms_hess above, no sqrt is applied
                            # here before summing -- confirm this is intended.
                            current[layer].norms_jac += norms2_jac.sum()
                            current_histograms[layer].norms_jac.extend(torch.sqrt(norms2_jac))

                            normalized_moments = copy.copy(hessians[layer])
                            normalized_moments.AA = forward_stats[layer].AA
                            normalized_moments = u.divide_attributes(normalized_moments, n)

                            train_regrets_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=0, m=normalized_moments, approx=args.curv)
                            test_regrets1_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=1, m=normalized_moments, approx=args.curv)
                            test_regrets2_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=2, m=normalized_moments, approx=args.curv)
                            test_regrets_opt_ = autograd_lib.offset_losses(A, B, alpha=None, offset=2, m=normalized_moments, approx=args.curv)
                            train_regrets_opt_ = autograd_lib.offset_losses(A, B, alpha=None, offset=0, m=normalized_moments, approx=args.curv)
                            cosines_ = autograd_lib.offset_cosines(A, B)
                            train_regrets[layer].extend(train_regrets_)
                            test_regrets1[layer].extend(test_regrets1_)
                            test_regrets2[layer].extend(test_regrets2_)
                            train_regrets_opt[layer].extend(train_regrets_opt_)
                            test_regrets_opt[layer].extend(test_regrets_opt_)
                            cosines[layer].extend(cosines_)
                            dot_products[layer].extend(autograd_lib.offset_dotprod(A, B))

                        # statistics of the form g.Sigma.g
                        elif current is quad_fishers:
                            hess = hessians[layer]
                            sigma = fishers[layer]
                            jac = jacobians[layer]
                            Bs, As = B @ sigma.BB / n, A @ forward_stats[layer].AA / n
                            Bh, Ah = B @ hess.BB / n, A @ forward_stats[layer].AA / n
                            Bj, Aj = B @ jac.BB / n, A @ forward_stats[layer].AA / n

                            norms = ((A * A).sum(dim=1) * (B * B).sum(dim=1))
                            norms2_hess = ((Ah * A).sum(dim=1) * (Bh * B).sum(dim=1))
                            norms2_jac = ((Aj * A).sum(dim=1) * (Bj * B).sum(dim=1))
                            norms_sigma = ((As * A).sum(dim=1) * (Bs * B).sum(dim=1))

                            current[layer].norm += norms.sum()  # TODO(y) remove, redundant with norm2 above
                            current[layer].curv_sigma += (skip_nans(norms_sigma / norms)).sum()
                            current[layer].curv_sigma_max = skip_nans(norms_sigma / norms).max()
                            current[layer].curv_sigma_median = skip_nans(norms_sigma / norms).median()
                            current[layer].curv_hess += skip_nans(norms2_hess / norms).sum()
                            current[layer].curv_hess_max += skip_nans(norms2_hess / norms).max()
                            current[layer].lyap_hess_mean += skip_nans(norms_sigma / norms2_hess).mean()
                            current[layer].lyap_hess_max = max(skip_nans(norms_sigma/norms2_hess))
                            current[layer].lyap_jac_mean += skip_nans(norms_sigma / norms2_jac).mean()
                            current[layer].lyap_jac_max = max(skip_nans(norms_sigma/norms2_jac))

                print('backward')
                with u.timeit("backprop_H"):
                    with autograd_lib.module_hook(compute_stats):
                        current = hessians
                        current_histograms = hessians_histograms
                        autograd_lib.backward_hessian(output, loss='CrossEntropy', sampled=args.sampled, retain_graph=True)  # 600 ms
                        current = jacobians
                        current_histograms = jacobians_histograms
                        autograd_lib.backward_jacobian(output, sampled=args.sampled, retain_graph=True)  # 600 ms
                        current = fishers
                        current_histograms = fishers_histograms
                        model.zero_grad()
                        loss.backward(retain_graph=True)  # 60 ms
                        current = quad_fishers
                        current_histograms = quad_fishers_histograms
                        model.zero_grad()
                        loss.backward()  # 60 ms

            print('summarize')
            for (i, layer) in enumerate(model.layers):
                stats_dict = {'hessian': hessians, 'jacobian': jacobians, 'fisher': fishers}

                # evaluate stats from
                # https://app.wandb.ai/yaroslavvb/train_ciresan/runs/425pu650?workspace=user-yaroslavvb
                for stats_name in stats_dict:
                    s = AttrDict()
                    stats = stats_dict[stats_name][layer]

                    # Copy forward-pass moments into each backward statistic. Warn
                    # when a slot already holds something other than the default
                    # 0.0 or the very object being copied in.
                    for key in forward_stats[layer]:
                        existing = stats[key]
                        forward_value = forward_stats[layer][key]
                        is_default = isinstance(existing, float) and existing == float()
                        if existing is not forward_value and not is_default:
                            print(f"Trying to overwrite {key} in {stats_name}, {layer}")
                        stats[key] = forward_value

                    diag: torch.Tensor = stats.diag / n

                    # jacobian:
                    # curv in direction of gradient goes down to roughly 0.3-1
                    # maximum curvature goes up to 1000-2000
                    #
                    # Hessian:
                    # max curv goes down to 1, in direction of gradient 0.0001
                    s.diag_l2 = torch.max(diag)  # 40 - 3000 smaller than kfac l2 for jac
                    s.diag_fro = torch.norm(diag)  # jacobian grows to 0.5-1.5, rest falls, layer-5 has phase transition, layer-4 also
                    s.diag_trace = diag.sum()  # jacobian grows 0-1000 (first), 0-150 (last). Almost same as kfac_trace (771 vs 810 kfac). Jacobian has up/down phase transition
                    s.diag_average = diag.mean()

                    # normalize for mean loss
                    BB = stats.BB / n
                    AA = stats.AA / n
                    # A_evals, _ = torch.symeig(AA)   # averaging 120ms per hit, 90 hits
                    # B_evals, _ = torch.symeig(BB)
                    # s.kfac_l2 = torch.max(A_evals) * torch.max(B_evals)  # 60x larger than diag_l2. layer0/hess has down/up phase transition. layer5/jacobian has up/down phase transition
                    s.kfac_trace = torch.trace(AA) * torch.trace(BB)  # 0/hess down/up tr, 5/jac sharp phase transition
                    s.kfac_fro = torch.norm(stats.AA) * torch.norm(stats.BB)  # 0/hess has down/up tr, 5/jac up/down transition
                    # s.kfac_erank = s.kfac_trace / s.kfac_l2  # first layer has 25, rest 15, all layers go down except last, last noisy
                    # s.kfac_erank_fro = s.kfac_trace / s.kfac_fro / max(stats.BA.shape)

                    s.diversity = (stats.norm2 / n) / u.norm_squared(stats.BA / n)  # gradient diversity. Goes up 3x. Bottom layer has most diversity. Jacobian diversity much less noisy than everything else

                    # discrepancy of KFAC based on exact values of diagonal approximation
                    # average difference normalized by average diagonal magnitude
                    diag_kfac = torch.einsum('ll,ii->li', BB, AA)
                    s.kfac_error = (torch.abs(diag_kfac - diag)).mean() / torch.mean(diag.abs())
                    u.log_scalars(u.nest_stats(f'layer-{i}/{stats_name}', s))

                # openai batch size stat
                s = AttrDict()
                hess = hessians[layer]
                jac = jacobians[layer]
                fish = fishers[layer]
                quad_fish = quad_fishers[layer]

                # the following check passes, but is expensive
                # if args.stats_num_batches == 1:
                #    u.check_close(fisher[layer].BA, layer.weight.grad)

                def trsum(A, B):
                    return (A * B).sum()  # computes tr(AB')

                grad = fishers[layer].BA / n
                s.grad_fro = torch.norm(grad)

                # get norms
                # NOTE(review): compute_stats accumulates lyap_hess_mean / lyap_jac_mean
                # but never lyap_hess_sum / lyap_jac_sum, so the two *_ave values below
                # presumably read the AttrDefault default -- confirm the intended field.
                s.lyap_hess_max = quad_fish.lyap_hess_max
                s.lyap_hess_ave = quad_fish.lyap_hess_sum / n
                s.lyap_jac_max = quad_fish.lyap_jac_max
                s.lyap_jac_ave = quad_fish.lyap_jac_sum / n
                s.hess_trace = hess.diag.sum() / n
                s.jac_trace = jac.diag.sum() / n

                # Version 1 of Jain stochastic rates, use Hessian for curvature
                b = args.train_batch_size

                s.hess_curv = trsum((hess.BB / n) @ grad @ (hess.AA / n), grad) / trsum(grad, grad)
                s.jac_curv = trsum((jac.BB / n) @ grad @ (jac.AA / n), grad) / trsum(grad, grad)

                # compute gradient noise statistics
                # fish.BB has /n factor twice, hence don't need extra /n on fish.AA
                # after sampling, hess_noise,jac_noise became 100x smaller, but normalized is unaffected
                s.hess_noise = (trsum(hess.AA / n, fish.AA / n) * trsum(hess.BB / n, fish.BB / n))
                s.jac_noise = (trsum(jac.AA / n, fish.AA / n) * trsum(jac.BB / n, fish.BB / n))
                s.hess_noise_centered = s.hess_noise - trsum(hess.BB / n @ grad, grad @ hess.AA / n)
                s.jac_noise_centered = s.jac_noise - trsum(jac.BB / n @ grad, grad @ jac.AA / n)
                s.openai_gradient_noise = (fish.norms_hess / n) / trsum(hess.BB / n @ grad, grad @ hess.AA / n)

                s.mean_norm = torch.sqrt(fish.norm2) / n
                s.min_norm = torch.sqrt(fish.min_norm2)
                s.median_norm = torch.sqrt(fish.median_norm2)
                s.max_norm = torch.sqrt(fish.max_norm2)
                s.enorms = u.norm_squared(grad)
                s.a_sparsity = fish.a_sparsity
                s.b_sparsity = fish.b_sparsity
                s.mean_activation = fish.mean_activation
                s.msr_activation = torch.sqrt(fish.mean_activation2)
                s.mean_backprop = fish.mean_backprop
                s.msr_backprop = torch.sqrt(fish.mean_backprop2)

                s.norms_centered = fish.norm2 / n - u.norm_squared(grad)
                s.norms_hess = fish.norms_hess / n
                s.norms_jac = fish.norms_jac / n

                # hess_curv_grad*: phase transition, hits minimum loss in layer 1, then starts
                # going up. Other layers take longer to reach minimum. Decreases with depth.
                s.hess_curv_grad = fish.curv_hess / n
                s.hess_curv_grad_max = fish.curv_hess_max
                s.hess_curv_grad_median = fish.curv_hess_median
                s.sigma_curv_grad = quad_fish.curv_sigma / n
                s.sigma_curv_grad_max = quad_fish.curv_sigma_max
                s.sigma_curv_grad_median = quad_fish.curv_sigma_median
                s.band_bottou = 0.5 * lr * s.sigma_curv_grad / s.hess_curv_grad
                # NOTE(review): curv_ratio is never written in this function -- confirm.
                s.band_bottou_stoch = 0.5 * lr * quad_fish.curv_ratio / n
                s.band_yaida = 0.25 * lr * s.mean_norm**2
                s.band_yaida_centered = 0.25 * lr * s.norms_centered

                # jac_curv_grad*: much lower variance than jac_curv. Reaches peak at 10k steps,
                # also kfac error reaches peak there. Decreases with depth except for last layer.
                s.jac_curv_grad = fish.curv_jac / n
                s.jac_curv_grad_max = fish.curv_jac_max
                s.jac_curv_grad_median = fish.curv_jac_median

                # OpenAI gradient noise statistics
                s.hess_noise_normalized = s.hess_noise_centered / (fish.norms_hess / n)
                s.jac_noise_normalized = s.jac_noise / (fish.norms_jac / n)

                train_regrets_, test_regrets1_, test_regrets2_, train_regrets_opt_, test_regrets_opt_, cosines_, dot_products_ = (
                    torch.stack(r[layer]) for r in (train_regrets, test_regrets1, test_regrets2,
                                                    train_regrets_opt, test_regrets_opt, cosines, dot_products))
                s.train_regret = train_regrets_.median()  # use median because outliers make it hard to see the trend
                s.test_regret1 = test_regrets1_.median()
                s.test_regret2 = test_regrets2_.median()
                s.test_regret_opt = test_regrets_opt_.median()
                s.train_regret_opt = train_regrets_opt_.median()
                s.mean_dot_product = torch.mean(dot_products_)
                s.median_dot_product = torch.median(dot_products_)

                s.median_cosine = cosines_.median()
                s.mean_cosine = cosines_.mean()

                # get learning rates
                L1 = s.hess_curv_grad / n
                L2 = s.jac_curv_grad / n
                diversity = (fish.norm2 / n) / u.norm_squared(grad)
                robust_diversity = (fish.norm2 / n) / fish.median_norm2
                dotprod_diversity = fish.median_norm2 / s.median_dot_product
                s.lr1 = 2 / (L1 * diversity)
                s.lr2 = 2 / (L2 * diversity)
                s.lr3 = 2 / (L2 * robust_diversity)
                s.lr4 = 2 / (L2 * dotprod_diversity)

                hess_A = u.symeig_pos_evals(hess.AA / n)
                hess_B = u.symeig_pos_evals(hess.BB / n)
                fish_A = u.symeig_pos_evals(fish.AA / n)
                fish_B = u.symeig_pos_evals(fish.BB / n)
                jac_A = u.symeig_pos_evals(jac.AA / n)
                jac_B = u.symeig_pos_evals(jac.BB / n)
                u.log_scalars({f'layer-{i}/hessA_erank': erank(hess_A)})
                u.log_scalars({f'layer-{i}/hessB_erank': erank(hess_B)})
                u.log_scalars({f'layer-{i}/fishA_erank': erank(fish_A)})
                u.log_scalars({f'layer-{i}/fishB_erank': erank(fish_B)})
                u.log_scalars({f'layer-{i}/jacA_erank': erank(jac_A)})
                u.log_scalars({f'layer-{i}/jacB_erank': erank(jac_B)})
                # Each histogram is built from the eigenvalues of its own factor pair.
                gl.event_writer.add_histogram(f'layer-{i}/hist_hess_eig', u.outer(hess_A, hess_B).flatten(), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_fish_eig', u.outer(fish_A, fish_B).flatten(), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_jac_eig', u.outer(jac_A, jac_B).flatten(), gl.get_global_step())

                s.hess_l2 = max(hess_A) * max(hess_B)
                s.jac_l2 = max(jac_A) * max(jac_B)
                s.fish_l2 = max(fish_A) * max(fish_B)
                s.hess_trace = hess.diag.sum() / n

                s.jain1_sto = 1/(s.hess_trace + 2 * s.hess_l2)
                s.jain1_det = 1/s.hess_l2
                s.jain1_lr = (1 / b) * (1/s.jain1_sto) + (b - 1) / b * (1/s.jain1_det)
                s.jain1_lr = 2 / s.jain1_lr

                s.regret_ratio = (train_regrets_opt_ / test_regrets_opt_).median()  # ratio between train and test regret, large means overfitting
                u.log_scalars(u.nest_stats(f'layer-{i}', s))

                # compute stats that would let you bound rho
                if i == 0:  # only compute this once, for output layer
                    hhh = hessians[model.layers[-1]].BB / n
                    fff = fishers[model.layers[-1]].BB / n
                    d = fff.shape[0]
                    L = u.lyapunov_spectral(hhh, 2 * fff, cond=1e-8)
                    L_evals = u.symeig_pos_evals(L)
                    Lcheap = fff @ u.pinv(hhh, cond=1e-8)
                    Lcheap_evals = u.eig_real(Lcheap)

                    u.log_scalars({f'mismatch/rho': d/erank(L_evals)})
                    u.log_scalars({f'mismatch/rho_cheap': d/erank(Lcheap_evals)})
                    u.log_scalars({f'mismatch/diagonalizability': erank(L_evals)/erank(Lcheap_evals)})  # 1 means diagonalizable
                    u.log_spectrum(f'mismatch/sigma', u.symeig_pos_evals(fff), loglog=False)
                    u.log_spectrum(f'mismatch/hess', u.symeig_pos_evals(hhh), loglog=False)
                    u.log_spectrum(f'mismatch/lyapunov', L_evals, loglog=True)
                    u.log_spectrum(f'mismatch/lyapunov_cheap', Lcheap_evals, loglog=True)

                gl.event_writer.add_histogram(f'layer-{i}/hist_grad_norms', u.to_numpy(fishers_histograms[layer].norms.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_grad_norms_hess', u.to_numpy(fishers_histograms[layer].norms_hess.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_curv_jac', u.to_numpy(fishers_histograms[layer].curv_jac.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_curv_hess', u.to_numpy(fishers_histograms[layer].curv_hess.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_cosines', u.to_numpy(cosines[layer]), gl.get_global_step())

                if args.log_spectra:
                    with u.timeit('spectrum'):
                        # 2/alpha
                        # s.jain1_lr = (1 / b) * s.jain1_sto + (b - 1) / b * s.jain1_det
                        # s.jain1_lr = 1 / s.jain1_lr

                        # hess.diag_trace, jac.diag_trace

                        # Version 2 of Jain stochastic rates, use Jacobian squared for curvature
                        s.jain2_sto = s.lyap_jac_max * s.jac_trace / s.lyap_jac_ave
                        s.jain2_det = s.jac_l2
                        s.jain2_lr = (1 / b) * s.jain2_sto + (b - 1) / b * s.jain2_det
                        s.jain2_lr = 1 / s.jain2_lr

                        u.log_spectrum(f'layer-{i}/hess_A', hess_A)
                        u.log_spectrum(f'layer-{i}/hess_B', hess_B)
                        u.log_spectrum(f'layer-{i}/hess_AB', u.outer(hess_A, hess_B).flatten())
                        u.log_spectrum(f'layer-{i}/jac_A', jac_A)
                        u.log_spectrum(f'layer-{i}/jac_B', jac_B)
                        u.log_spectrum(f'layer-{i}/fish_A', fish_A)
                        u.log_spectrum(f'layer-{i}/fish_B', fish_B)

                        u.log_scalars({f'layer-{i}/trace_ratio': fish_B.sum()/hess_B.sum()})

                        # NOTE(review): torch.eig is deprecated in newer PyTorch releases;
                        # migrate to torch.linalg.eig when upgrading.
                        L = torch.eig(u.lyapunov_spectral(hess.BB, 2*fish.BB, cond=1e-8))[0]
                        L = L[:, 0]  # extract real part
                        L = L.sort()[0]
                        L = torch.flip(L, [0])

                        L_cheap = torch.eig(fish.BB @ u.pinv(hess.BB, cond=1e-8))[0]
                        L_cheap = L_cheap[:, 0]  # extract real part
                        L_cheap = L_cheap.sort()[0]
                        L_cheap = torch.flip(L_cheap, [0])

                        d = len(hess_B)
                        u.log_spectrum(f'layer-{i}/Lyap', L)
                        u.log_spectrum(f'layer-{i}/Lyap_cheap', L_cheap)

                        u.log_scalars({f'layer-{i}/dims': d})
                        u.log_scalars({f'layer-{i}/L_erank': erank(L)})
                        u.log_scalars({f'layer-{i}/L_cheap_erank': erank(L_cheap)})

                        u.log_scalars({f'layer-{i}/rho': d/erank(L)})
                        u.log_scalars({f'layer-{i}/rho_cheap': d/erank(L_cheap)})

        model.train()
        with u.timeit('train'):
            for _ in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                optimizer.step()
                if args.weight_decay:
                    # multiplicative weight decay applied after the optimizer step
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            param.data.mul_(1 - args.weight_decay)

                gl.token_count += data.shape[0]

    gl.event_writer.close()
def test_grad_norms():
    """Per-example gradient norms from hooks agree with autograd.

    First aggregates activation/backprop moments over ``stats_steps`` small
    batches, then runs the concatenated data once more and collects
    ``autograd_lib.grad_norms`` for each approximation kind. For every
    layer the 'zero_order' norms are checked against squared norms of
    per-example autograd gradients, and the 'kfac' and 'isserlis' norms
    must match the 'full' ones within 1e-4.
    """
    u.seed_random(1)
    # torch.set_default_dtype(torch.float64)

    data_width = 3
    batch_size = 2
    dims = [data_width**2, 6, 10]
    stats_steps = 2
    num_samples = batch_size * stats_steps  # number of samples used in computation of curvature stats

    model: u.SimpleModel = u.SimpleMLP(dims, nonlin=True, bias=True)
    loss_fn = torch.nn.CrossEntropyLoss()
    autograd_lib.register(model)

    dataset = u.TinyMNIST(dataset_size=num_samples, data_width=data_width, original_targets=True)
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    stats_iter = iter(stats_loader)

    moments = defaultdict(lambda: AttrDefault(float))
    norms = defaultdict(lambda: AttrDefault(MyList))
    seen_data = []
    seen_targets = []
    for _ in range(stats_steps):
        data, targets = next(stats_iter)
        seen_data.append(data)
        seen_targets.append(targets)

        batch_activations = {}

        def forward_aggregate(layer, A, _):
            batch_activations[layer] = A
            moments[layer].AA += torch.einsum('ni,nj->ij', A, A)
            moments[layer].a += torch.einsum("ni->i", A)

        with autograd_lib.module_hook(forward_aggregate):
            output = model(data)
            loss_fn(output, targets)

        def backward_aggregate(layer, _, B):
            A = batch_activations[layer]
            moments[layer].b += torch.einsum("nk->k", B)
            moments[layer].BA += torch.einsum("nl,ni->li", B, A)
            moments[layer].BB += torch.einsum("nk,nl->kl", B, B)
            moments[layer].BABA += torch.einsum('nl,ni,nk,nj->likj', B, A, B, A)

        with autograd_lib.module_hook(backward_aggregate):
            autograd_lib.backward_hessian(output, loss='CrossEntropy', retain_graph=True)

    # compare against results using autograd
    data = torch.cat(seen_data)
    targets = torch.cat(seen_targets)

    with autograd_lib.save_activations2() as full_activations:
        loss = loss_fn(model(data), targets)

    def normalize_moments(d, n):
        # keep only entries whose type is exactly torch.Tensor, divided by n
        scaled = AttrDict()
        for name in d:
            value = d[name]
            if type(value) is torch.Tensor:
                scaled[name] = value / n
        return scaled

    approximations = ('zero_order', 'kfac', 'isserlis', 'full')

    def compute_norms(layer, _, B):
        A = full_activations[layer]
        for kind in approximations:
            normalized_moments = normalize_moments(moments[layer], num_samples)
            per_example = autograd_lib.grad_norms(A, B, normalized_moments, approx=kind)
            getattr(norms[layer], kind).extend(per_example)

    with autograd_lib.module_hook(compute_norms):
        model.zero_grad()
        (len(data) * loss).backward(retain_graph=True)

    print(norms[model.layers[0]].zero_order.value())

    for layer in model.layers:
        output = model(data)
        example_losses = []
        for i in range(len(data)):
            example_losses.append(loss_fn(output[i:i + 1], targets[i:i + 1]))
        losses = torch.stack(example_losses)
        grads = u.jacobian(losses, layer.weight)
        grad_norms = torch.einsum('nij,nij->n', grads, grads)
        u.check_close(grad_norms, norms[layer].zero_order)

        # test gradient norms with custom metric
        kfac_norms = u.to_pytorch(getattr(norms[layer], 'kfac'))
        isserlis_norms = u.to_pytorch(getattr(norms[layer], 'isserlis'))
        full_norms = u.to_pytorch(getattr(norms[layer], 'full'))

        error_kfac = max(abs(kfac_norms - full_norms))
        error_isserlis = max(abs(isserlis_norms - full_norms))
        assert error_isserlis < 1e-4
        assert error_kfac < 1e-4