def test_unfold():
    """Reproduce convolution as a special case of matrix multiplication with unfolded input tensors"""
    # reset global hook state so the model's forward/backward hooks fire
    gl.skip_backward_hooks = False
    gl.skip_forward_hooks = False
    gl.backward_idx = 0

    batch, in_ch, height, width = 1, 2, 3, 3
    model: u.SimpleModel = u.SimpleConvolutional([in_ch, 2])
    # fill the conv kernel with ones so the output is a plain patch sum
    w = model.layers[0].weight.data
    w.copy_(torch.ones_like(w))

    shape = batch, in_ch, height, width
    X = torch.arange(0, np.prod(shape)).reshape(*shape)

    def loss_fn(t):
        flat = t.reshape(len(t), -1)
        return torch.sum(flat * flat) / 2 / len(t)

    conv = model.layers[0]
    output = model(X)
    loss = loss_fn(output)
    loss.backward()

    u.check_close(conv.activations, X)
    assert conv.backprops_list[0].shape == conv.output.shape

    unfold = torch.nn.functional.unfold
    fold = torch.nn.functional.fold
    # convolution expressed as matmul: flattened kernel times unfolded patches
    out_unf = conv.weight.view(conv.weight.size(0), -1) @ unfold(conv.activations, (2, 2))
    u.check_close(fold(out_unf, conv.output.shape[2:], (1, 1)), output)
def test_to_logits():
    """softmax(to_logits(p)) must recover the probability vector p, both 1-D and batched."""
    torch.set_default_dtype(torch.float32)
    probs = torch.tensor([0.2, 0.5, 0.3])
    u.check_close(probs, F.softmax(u.to_logits(probs), dim=0))
    batched = probs.unsqueeze(0)
    u.check_close(batched, F.softmax(u.to_logits(batched), dim=1))
def test_kron_mnist():
    """Compare Kronecker-factored Hessian against direct and autograd Hessians on tiny MNIST.

    NOTE(review): relies on autograd_lib hook bookkeeping; statement order matters.
    """
    u.seed_random(1)

    data_width = 3
    batch_size = 3
    d = [data_width**2, 10]  # single linear layer: 9 inputs -> 10 classes
    o = d[-1]
    n = batch_size
    train_steps = 1

    # torch.set_default_dtype(torch.float64)

    model: u.SimpleModel2 = u.SimpleFullyConnected2(d, nonlin=False, bias=True)
    autograd_lib.add_hooks(model)

    dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()

    gl.token_count = 0
    for train_step in range(train_steps):
        data, targets = next(train_iter)

        # get gradient values
        u.clear_backprops(model)
        autograd_lib.enable_hooks()
        output = model(data)
        autograd_lib.backprop_hess(output, hess_type='CrossEntropy')
        i = 0
        layer = model.layers[i]
        # compute both the factored ('kron') and the regular Hessian before disabling hooks
        autograd_lib.compute_hess(model, method='kron')
        autograd_lib.compute_hess(model)
        autograd_lib.disable_hooks()

        # direct Hessian computation
        H = layer.weight.hess
        H_bias = layer.bias.hess

        # factored Hessian computation
        H2 = layer.weight.hess_factored
        H2_bias = layer.bias.hess_factored
        H2, H2_bias = u.expand_hess(H2, H2_bias)

        # autograd Hessian computation
        loss = loss_fn(output, targets)
        # TODO: change to d[i+1]*d[i]
        H_autograd = u.hessian(loss, layer.weight).reshape(d[i] * d[i + 1], d[i] * d[i + 1])
        H_bias_autograd = u.hessian(loss, layer.bias)

        # compare direct against autograd
        u.check_close(H, H_autograd)
        u.check_close(H_bias, H_bias_autograd)

        # factoring is approximate for cross-entropy, so only check a loose distance bound
        approx_error = u.symsqrt_dist(H, H2)
        assert approx_error < 1e-2, approx_error
def test_kron_nano():
    """Smallest-case (1x2 linear layer, single example) check that the Kron-factored
    Hessian matches both the direct and the autograd Hessian."""
    u.seed_random(1)
    d = [1, 2]
    n = 1

    # torch.set_default_dtype(torch.float32)
    loss_type = 'CrossEntropy'

    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=False, bias=True)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:
        loss_fn = nn.CrossEntropyLoss()

    data = torch.randn(n, d[0])
    data = torch.ones(n, d[0])  # NOTE(review): deliberately overwrites the random data with ones
    if loss_type.endswith('LeastSquares'):
        target = torch.randn(n, d[-1])
    elif loss_type == 'CrossEntropy':
        target = torch.LongTensor(n).random_(0, d[-1])
        target = torch.tensor([0])  # NOTE(review): deliberately overwrites the random target with class 0

    # Hessian computation, saves regular and Kronecker factored versions into .hess and .hess_factored attributes
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='kron')
    autograd_lib.compute_hess(model)
    autograd_lib.disable_hooks()

    for layer in model.layers:
        Hk: u.Kron = layer.weight.hess_factored
        Hk_bias: u.Kron = layer.bias.hess_factored
        Hk, Hk_bias = u.expand_hess(Hk, Hk_bias)  # kronecker multiply the factors

        # old approach, using direct computation
        H2, H_bias2 = layer.weight.hess, layer.bias.hess

        # compute Hessian through autograd
        model.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        H_autograd = u.hessian(loss, layer.weight)
        H_bias_autograd = u.hessian(loss, layer.bias)

        # compare autograd with direct approach
        u.check_close(H2, H_autograd.reshape(Hk.shape))
        u.check_close(H_bias2, H_bias_autograd)

        # compare factored with direct approach
        assert(u.symsqrt_dist(Hk, H2) < 1e-6)
def test_symsqrt():
    """Test symmetric matrix square root: full-rank, rank-deficient, and Kronecker-factored cases."""
    u.seed_random(1)
    torch.set_default_dtype(torch.float32)

    # full-rank case: sqrt must reproduce the matrix as smat@smat.t() and smat@smat
    mat = torch.reshape(torch.arange(9) + 1, (3, 3)).float() + torch.eye(3) * 5
    mat = mat + mat.t()  # make symmetric
    smat = u.symsqrt(mat)
    u.check_close(mat, smat @ smat.t())
    u.check_close(mat, smat @ smat)

    def randomly_rotate(X):
        """Randomly rotate d,n data matrix X"""
        d, n = X.shape
        z = torch.randn((d, d), dtype=X.dtype)
        q, r = torch.qr(z)
        d = torch.diag(r)
        ph = d / abs(d)  # sign-correct Q so the rotation is uniformly distributed
        rot_mat = q * ph
        return rot_mat @ X

    n = 20
    d = 10
    X = torch.randn((d, n))

    # embed in a larger space: covariance becomes rank-deficient (rank d in 2d dims)
    X = torch.cat([X, torch.zeros_like(X)])
    X = randomly_rotate(X)
    cov = X @ X.t() / n
    sqrt, rank = u.symsqrt(cov, return_rank=True)
    assert rank == d
    assert torch.allclose(sqrt @ sqrt, cov, atol=1e-5)

    # Kronecker-factored covariance: ranks multiply
    Y = torch.randn((d, n))
    # BUGFIX: originally padded with torch.zeros_like(X) and then rotated X,
    # which discarded Y entirely. Pad and rotate Y so the second Kronecker
    # factor is an independent rank-d covariance.
    Y = torch.cat([Y, torch.zeros_like(Y)])
    Y = randomly_rotate(Y)
    cov = u.Kron(X @ X.t(), Y @ Y.t())
    sqrt, rank = cov.symsqrt(return_rank=True)
    assert rank == d * d
    u.check_close(sqrt @ sqrt, cov, rtol=1e-4)

    # spectral norm of a rank-1 covariance equals the squared vector magnitude
    X = torch.tensor([[7., 0, 0, 0, 0]]).t()
    X = randomly_rotate(X)
    cov = X @ X.t()
    u.check_close(u.sym_l2_norm(cov), 7 * 7)

    Y = torch.tensor([[8., 0, 0, 0, 0]]).t()
    Y = randomly_rotate(Y)
    cov = u.Kron(X @ X.t(), Y @ Y.t())
    u.check_close(cov.sym_l2_norm(), 7 * 7 * 8 * 8)
def test_truncated_lyapunov():
    """Truncated Lyapunov solution should expose the dimensionality of the shared subspace."""
    d, n = 100, 1000
    shared_rank, independent_rank = 2, 1
    A, C = u.random_cov_pair(shared_rank=shared_rank, independent_rank=independent_rank,
                             strength=0.1, d=d, n=n)
    X = u.lyapunov_truncated(A, C)
    # effective rank of X captures dimensionality of shared subspace
    u.check_close(u.rank(X), shared_rank + independent_rank, rtol=1e-4)
    u.check_close(u.erank(X), shared_rank, rtol=1e-2)
def test_full_hessian_xent_mnist_multilayer():
    """Test regular and diagonal hessian computation."""
    u.seed_random(1)

    data_width = 3
    batch_size = 2
    d = [data_width**2, 6, 10]  # two linear layers: 9 -> 6 -> 10
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=False, bias=True)
    autograd_lib.register(model)

    dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()

    # accumulators keyed by layer; float 0 so `+=` works on first touch
    hess = defaultdict(float)
    hess_diag = defaultdict(float)

    for train_step in range(train_steps):
        data, targets = next(train_iter)

        activations = {}

        def save_activations(layer, a, _):
            # remember each layer's input so compute_hess can pair it with backprops
            activations[layer] = a

        with autograd_lib.module_hook(save_activations):
            output = model(data)
            loss = loss_fn(output, targets)

        def compute_hess(layer, _, B):
            A = activations[layer]
            BA = torch.einsum("nl,ni->nli", B, A)
            # full Hessian: sum over examples of outer products of Jacobian blocks
            hess[layer] += torch.einsum('nli,nkj->likj', BA, BA)
            # diagonal shortcut: elementwise squares of backprops and activations
            hess_diag[layer] += torch.einsum("ni,nj->ij", B * B, A * A)

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(output, loss='CrossEntropy', retain_graph=True)

    # compute Hessian through autograd; hooks accumulated a per-batch sum, hence / batch_size
    H_autograd = u.hessian(loss, model.layers[0].weight)
    u.check_close(hess[model.layers[0]] / batch_size, H_autograd)
    diag_autograd = torch.einsum('lili->li', H_autograd)
    u.check_close(diag_autograd, hess_diag[model.layers[0]] / batch_size)

    H_autograd = u.hessian(loss, model.layers[1].weight)
    u.check_close(hess[model.layers[1]] / batch_size, H_autograd)
    diag_autograd = torch.einsum('lili->li', H_autograd)
    u.check_close(diag_autograd, hess_diag[model.layers[1]] / batch_size)
def test_cross_entropy_hessian_mnist():
    """Cross-entropy Hessian from saved backprops must match autograd on tiny MNIST."""
    u.seed_random(1)

    data_width = 3
    batch_size = 2
    d = [data_width**2, 10]  # single linear layer: 9 inputs -> 10 classes
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=False, bias=True)

    dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()
    # yields one backprop vector per Hessian sample; each backward pass appends to backprops_list
    loss_hessian = u.HessianExactCrossEntropyLoss()

    gl.token_count = 0
    for train_step in range(train_steps):
        data, targets = next(train_iter)

        # get gradient values
        u.clear_backprops(model)
        model.skip_forward_hooks = False
        model.skip_backward_hooks = False
        output = model(data)
        for bval in loss_hessian(output):
            output.backward(bval, retain_graph=True)

        i = 0
        layer = model.layers[i]
        H, Hbias = u.hessian_from_backprops(layer.activations, layer.backprops_list, bias=True)
        model.skip_forward_hooks = True
        model.skip_backward_hooks = True

        # compute Hessian through autograd
        model.zero_grad()
        output = model(data)
        loss = loss_fn(output, targets)
        H_autograd = u.hessian(loss, layer.weight).reshape(d[i] * d[i + 1], d[i] * d[i + 1])
        u.check_close(H, H_autograd)
        Hbias_autograd = u.hessian(loss, layer.bias)
        u.check_close(Hbias, Hbias_autograd)
def test_cross_entropy_hessian_tiny():
    """Hand-sized (2-class, single-example) cross-entropy Hessian check against autograd."""
    u.seed_random(1)

    batch_size = 1
    d = [2, 2]
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=True, bias=True)
    # identity weights so the layer passes logits through unchanged
    model.layers[0].weight.data.copy_(torch.eye(2))
    loss_fn = torch.nn.CrossEntropyLoss()
    loss_hessian = u.HessianExactCrossEntropyLoss()

    # logits that softmax back to [0.7, 0.3]
    data = u.to_logits(torch.tensor([[0.7, 0.3]]))
    targets = torch.tensor([0])

    # get gradient values
    u.clear_backprops(model)
    model.skip_forward_hooks = False
    model.skip_backward_hooks = False
    output = model(data)
    for bval in loss_hessian(output):
        output.backward(bval, retain_graph=True)

    i = 0
    layer = model.layers[i]
    H, Hbias = u.hessian_from_backprops(layer.activations, layer.backprops_list, bias=True)
    model.skip_forward_hooks = True
    model.skip_backward_hooks = True

    # compute Hessian through autograd
    model.zero_grad()
    output = model(data)
    loss = loss_fn(output, targets)
    H_autograd = u.hessian(loss, layer.weight)
    u.check_close(H, H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1]))
    Hbias_autograd = u.hessian(loss, layer.bias)
    u.check_close(Hbias, Hbias_autograd)
def test_full_fisher_multibatch():
    """Accumulate the full empirical Fisher one example at a time; compare to golden values."""
    torch.set_default_dtype(torch.float64)
    u.seed_random(1)
    A, model = create_toy_model()

    activations = {}

    def save_activations(layer, a, _):
        # only track the first layer
        if layer != model.layers[0]:
            return
        activations[layer] = a

    fisher = [0]  # one-element list so the closure can accumulate across calls

    def compute_fisher(layer, _, B):
        if layer != model.layers[0]:
            return
        A = activations[layer]
        n = A.shape[0]
        di = A.shape[1]
        do = B.shape[1]

        # per-example flattened Jacobian, then its Gram matrix
        Jo = torch.einsum("ni,nj->nij", B, A).reshape(n, -1)
        fisher[0] += torch.einsum('ni,nj->ij', Jo, Jo)

    # feed examples one at a time (columns of A); Fisher accumulates over batches
    for x in A.t():
        with autograd_lib.module_hook(save_activations):
            y = model(x)
            loss = torch.sum(y * y) / 2
        with autograd_lib.module_hook(compute_fisher):
            loss.backward()

    # result computed using single step forward prop
    result0 = torch.tensor(
        [[5.383625e+06, -3.675000e+03, 4.846250e+06, -6.195000e+04],
         [-3.675000e+03, 1.102500e+04, -6.195000e+04, 1.858500e+05],
         [4.846250e+06, -6.195000e+04, 4.674500e+06, -1.044300e+06],
         [-6.195000e+04, 1.858500e+05, -1.044300e+06, 3.132900e+06]])
    u.check_close(fisher[0], result0)
def test_full_fisher():
    """Compute the full empirical Fisher in a single batched pass; compare to golden values."""
    u.seed_random(1)
    A, model = create_toy_model()
    first_layer = model.layers[0]

    saved = {}

    def save_activations(layer, a, _):
        # only track the first layer's input
        if layer == first_layer:
            saved[layer] = a

    with autograd_lib.module_hook(save_activations):
        Y = model(A.t())
        loss = torch.sum(Y * Y) / 2

    fisher = [0]  # one-element list so the closure can accumulate

    def compute_fisher(layer, _, B):
        if layer != first_layer:
            return
        acts = saved[layer]
        n = acts.shape[0]
        di = acts.shape[1]
        do = B.shape[1]
        # flatten per-example Jacobians, then accumulate their Gram matrix
        Jo = torch.einsum("ni,nj->nij", B, acts).reshape(n, -1)
        fisher[0] += torch.einsum('ni,nj->ij', Jo, Jo)

    with autograd_lib.module_hook(compute_fisher):
        loss.backward()

    # golden values computed using single step forward prop
    result0 = torch.tensor(
        [[5.383625e+06, -3.675000e+03, 4.846250e+06, -6.195000e+04],
         [-3.675000e+03, 1.102500e+04, -6.195000e+04, 1.858500e+05],
         [4.846250e+06, -6.195000e+04, 4.674500e+06, -1.044300e+06],
         [-6.195000e+04, 1.858500e+05, -1.044300e+06, 3.132900e+06]])
    u.check_close(fisher[0], result0)
def test_cross_entropy_soft():
    """Soft-target cross entropy: check value, gradient, and Hessian against closed forms."""
    torch.set_default_dtype(torch.float32)
    soft_target = torch.tensor([0.4, 0.6]).unsqueeze(0).float()
    probs = torch.tensor([0.7, 0.3]).unsqueeze(0).float()
    observed_logit = u.to_logits(probs)

    # Compare against other loss functions
    # https://www.wolframcloud.com/obj/user-eac9ee2d-7714-42da-8f84-bec1603944d5/newton/logistic-hessian.nb
    bce_loss = F.binary_cross_entropy(probs[0], soft_target[0])
    u.check_close(bce_loss, 0.865054)

    loss_fn = u.CrossEntropySoft()
    u.check_close(loss_fn(observed_logit, soft_target), bce_loss)

    # one-hot soft target must agree with standard hard-label cross entropy
    hard_loss = F.cross_entropy(observed_logit, torch.tensor([0]))
    u.check_close(hard_loss, loss_fn(observed_logit, torch.tensor([[1, 0.]])))

    # check gradient: d(loss)/d(logits) = p - q
    observed_logit.requires_grad = True
    grad = torch.autograd.grad(loss_fn(observed_logit, target=soft_target), observed_logit)
    u.check_close(probs - soft_target, grad[0])

    # check Hessian at zero logits: diag(p) - p p^T with p uniform
    observed_logit = u.to_logits(probs)
    observed_logit.zero_()
    observed_logit.requires_grad = True
    H_auto = u.hessian(loss_fn(observed_logit, target=soft_target), observed_logit)
    H_auto = H_auto.reshape((probs.numel(), probs.numel()))
    probs = F.softmax(observed_logit, dim=1)
    H_manual = torch.diag(probs[0]) - probs.t() @ probs
    u.check_close(H_auto, H_manual)
def test_hessian_multibatch():
    """Test that Kronecker-factored computations still work when splitting work over batches."""
    u.seed_random(1)
    # torch.set_default_dtype(torch.float64)

    gl.project_name = 'test'
    gl.logdir_base = '/tmp/runs'
    run_name = 'test_hessian_multibatch'
    u.setup_logdir_and_event_writer(run_name=run_name)

    loss_type = 'CrossEntropy'

    data_width = 2
    n = 4
    d1 = data_width ** 2
    o = 10
    d = [d1, o]
    model = u.SimpleFullyConnected2(d, bias=False, nonlin=False)
    model = model.to(gl.device)

    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=n, shuffle=False)
    stats_iter = u.infinite_iter(stats_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:  # loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()
    last_outer = 0

    stats_iter = u.infinite_iter(stats_loader)
    stats_data, stats_targets = next(stats_iter)
    data, targets = stats_data, stats_targets

    # Capture Hessian and gradient stats
    autograd_lib.enable_hooks()
    autograd_lib.clear_backprops(model)
    output = model(data)
    loss = loss_fn(output, targets)
    loss.backward(retain_graph=True)
    layer = model.layers[0]

    autograd_lib.clear_hess_backprops(model)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.disable_hooks()

    # compute Hessian using direct method, compare against PyTorch autograd
    hess0 = u.hessian(loss, layer.weight)
    autograd_lib.compute_hess(model)
    hess1 = layer.weight.hess
    u.check_close(hess0.reshape(hess1.shape), hess1, atol=1e-8, rtol=1e-6)

    # compute Hessian using factored method. Because Hessian depends on examples
    # for cross entropy, factoring is not exact, raise tolerance
    autograd_lib.compute_hess(model, method='kron', attr_name='hess2', vecr_order=True)
    hess2 = layer.weight.hess2
    u.check_close(hess1, hess2, atol=1e-3, rtol=1e-1)

    # compute Hessian using multibatch
    # restart iterators
    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    assert n % 2 == 0
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=n//2, shuffle=False)
    stats_iter = u.infinite_iter(stats_loader)
    autograd_lib.compute_cov(model, loss_fn, stats_iter, batch_size=n//2, steps=2)
    cov: autograd_lib.LayerCov = layer.cov

    hess2: u.Kron = hess2.commute()  # get back into AA x BB order
    u.check_close(cov.H.value(), hess2)
def test_conv_hessian():
    """Test per-example gradient computation for conv layer."""
    u.seed_random(1)
    n, Xc, Xh, Xw = 3, 2, 3, 7
    dd = [Xc, 2]
    Kh, Kw = 2, 3
    Oh, Ow = Xh - Kh + 1, Xw - Kw + 1  # output spatial dims for stride-1 valid conv
    model: u.SimpleModel = u.ReshapedConvolutional(dd, kernel_size=(Kh, Kw), bias=True)
    weight_buffer = model.layers[0].weight.data
    assert (Kh, Kw) == model.layers[0].kernel_size
    data = torch.randn((n, Xc, Xh, Xw))

    # output channels, input channels, height, width
    assert weight_buffer.shape == (dd[1], dd[0], Kh, Kw)

    def loss_fn(data):
        err = data.reshape(len(data), -1)
        return torch.sum(err * err) / 2 / len(data)

    loss_hessian = u.HessianExactSqrLoss()

    # o = Oh * Ow * dd[1]
    output = model(data)
    o = output.shape[1]
    # one backward pass per output coordinate; each appends to layer.backprops_list
    for bval in loss_hessian(output):
        output.backward(bval, retain_graph=True)
    assert loss_hessian.num_samples == o

    i, layer = next(enumerate(model.layers))

    At = unfold(layer.activations, (Kh, Kw))  # -> n, Xc * Kh * Kw, Oh * Ow
    assert At.shape == (n, dd[0] * Kh * Kw, Oh*Ow)

    # o, n, dd[1], Oh, Ow -> o, n, dd[1], Oh*Ow
    Bh_t = torch.stack([Bt.reshape(n, dd[1], Oh*Ow) for Bt in layer.backprops_list])
    assert Bh_t.shape == (o, n, dd[1], Oh*Ow)
    Ah_t = torch.stack([At]*o)
    assert Ah_t.shape == (o, n, dd[0] * Kh * Kw, Oh*Ow)

    # sum out the output patch dimension
    Jb = torch.einsum('onij,onkj->onik', Bh_t, Ah_t)  # => o, n, dd[1], dd[0] * Kh * Kw
    Hi = torch.einsum('onij,onkl->nijkl', Jb, Jb)  # => n, dd[1], dd[0]*Kh*Kw, dd[1], dd[0]*Kh*Kw
    Jb_bias = torch.einsum('onij->oni', Bh_t)
    Hb_i = torch.einsum('oni,onj->nij', Jb_bias, Jb_bias)

    # batch Hessian = mean of per-example Hessians
    H = Hi.mean(dim=0)
    Hb = Hb_i.mean(dim=0)

    model.disable_hooks()

    loss = loss_fn(model(data))
    H_autograd = u.hessian(loss, layer.weight)
    assert H_autograd.shape == (dd[1], dd[0], Kh, Kw, dd[1], dd[0], Kh, Kw)
    assert H.shape == (dd[1], dd[0]*Kh*Kw, dd[1], dd[0]*Kh*Kw)
    u.check_close(H, H_autograd.reshape(H.shape), rtol=1e-4, atol=1e-7)

    Hb_autograd = u.hessian(loss, layer.bias)
    assert Hb_autograd.shape == (dd[1], dd[1])
    u.check_close(Hb, Hb_autograd)

    assert len(Bh_t) == loss_hessian.num_samples == o
    # validate each per-example Hessian against autograd on a single-example loss
    for xi in range(n):
        loss = loss_fn(model(data[xi:xi + 1, ...]))
        H_autograd = u.hessian(loss, layer.weight)
        u.check_close(Hi[xi], H_autograd.reshape(H.shape))
        Hb_autograd = u.hessian(loss, layer.bias)
        u.check_close(Hb_i[xi], Hb_autograd)
        assert Hb_i[xi, 0, 0] == Oh*Ow  # each output has curvature 1, bias term adds up Oh*Ow of them
def test_hessian():
    """Tests of Hessian computation: exact backprop-based, per-example, and subsampled."""
    u.seed_random(1)

    batch_size = 500
    data_width = 4
    targets_width = 4

    d1 = data_width ** 2
    d2 = 10
    d3 = targets_width ** 2
    o = d3
    N = batch_size
    d = [d1, d2, d3]

    dataset = u.TinyMNIST(data_width=data_width, targets_width=targets_width, dataset_size=batch_size)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)
    data, targets = next(train_iter)

    def loss_fn(data, targets):
        """Mean squared error over the batch, halved."""
        assert len(data) == len(targets)
        err = data - targets.view(-1, data.shape[1])
        return torch.sum(err * err) / 2 / len(data)

    u.seed_random(1)
    model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=False, bias=True)

    # backprop hessian and compare against autograd
    hessian_backprop = u.HessianExactSqrLoss()
    output = model(data)
    for bval in hessian_backprop(output):
        output.backward(bval, retain_graph=True)

    i, layer = next(enumerate(model.layers))
    A_t = layer.activations
    Bh_t = layer.backprops_list
    H, Hb = u.hessian_from_backprops(A_t, Bh_t, bias=True)

    model.disable_hooks()
    H_autograd = u.hessian(loss_fn(model(data), targets), layer.weight)
    u.check_close(H, H_autograd.reshape(d[i + 1] * d[i], d[i + 1] * d[i]), rtol=1e-4, atol=1e-7)
    Hb_autograd = u.hessian(loss_fn(model(data), targets), layer.bias)
    u.check_close(Hb, Hb_autograd, rtol=1e-4, atol=1e-7)

    # check first few per-example Hessians
    Hi, Hb_i = u.per_example_hess(A_t, Bh_t, bias=True)
    u.check_close(H, Hi.mean(dim=0))
    u.check_close(Hb, Hb_i.mean(dim=0), atol=2e-6, rtol=1e-5)
    for xi in range(5):
        loss = loss_fn(model(data[xi:xi + 1, ...]), targets[xi:xi + 1])
        H_autograd = u.hessian(loss, layer.weight)
        u.check_close(Hi[xi], H_autograd.reshape(d[i + 1] * d[i], d[i + 1] * d[i]))
        Hbias_autograd = u.hessian(loss, layer.bias)
        # BUGFIX: was Hb_i[i] — the layer index (always 0) instead of the example
        # index, so the same per-example bias Hessian was compared every iteration.
        u.check_close(Hb_i[xi], Hbias_autograd)

    # get subsampled Hessian
    u.seed_random(1)
    model = u.SimpleFullyConnected(d, nonlin=False)
    hessian_backprop = u.HessianSampledSqrLoss(num_samples=1)
    output = model(data)
    for bval in hessian_backprop(output):
        output.backward(bval, retain_graph=True)
    model.disable_hooks()
    i, layer = next(enumerate(model.layers))
    H_approx1 = u.hessian_from_backprops(layer.activations, layer.backprops_list)

    # get subsampled Hessian with more samples
    u.seed_random(1)
    model = u.SimpleFullyConnected(d, nonlin=False)
    hessian_backprop = u.HessianSampledSqrLoss(num_samples=o)
    output = model(data)
    for bval in hessian_backprop(output):
        output.backward(bval, retain_graph=True)
    model.disable_hooks()
    i, layer = next(enumerate(model.layers))
    H_approx2 = u.hessian_from_backprops(layer.activations, layer.backprops_list)

    # more samples -> better approximation; observed values noted in comments
    assert abs(u.l2_norm(H) / u.l2_norm(H_approx1) - 1) < 0.08, abs(u.l2_norm(H) / u.l2_norm(H_approx1) - 1)  # 0.0612
    assert abs(u.l2_norm(H) / u.l2_norm(H_approx2) - 1) < 0.03, abs(u.l2_norm(H) / u.l2_norm(H_approx2) - 1)  # 0.0239
    assert u.kl_div_cov(H_approx1, H) < 0.3, u.kl_div_cov(H_approx1, H)  # 0.222
    assert u.kl_div_cov(H_approx2, H) < 0.2, u.kl_div_cov(H_approx2, H)  # 0.1233
def test_symsqrt_neg():
    """Test robustness to small negative eigenvalues."""
    u.seed_random(1)
    torch.set_default_dtype(torch.float32)
    # near-PSD matrix captured from a real run; in float32 it has a slightly
    # negative eigenvalue that a naive sqrt(eigenvalue) would turn into NaN
    mat = torch.tensor(
        [[ 1.704692840576171875e-05, -9.693153669044357601e-15, -4.637238930627063382e-07, -5.784777457051859528e-08,
           -7.958237541183521557e-11, -9.898678399622440338e-06, -2.152719247305867611e-07, -1.635662982835128787e-08,
           -6.400216989277396351e-06, -1.906904145698717912e-08 ],
         [ -9.693153669044357601e-15, 9.693318840469072190e-15, -4.495100185314538545e-21, -5.607465147510056466e-22,
           -7.714305304820877864e-25, -9.595268609986125189e-20, -2.086734953246380139e-21, -1.585527368218132061e-22,
           -6.204040063550678436e-20, -1.848454455492805625e-22 ],
         [ -4.637238930627063382e-07, -4.495100185314538545e-21, 4.637315669242525473e-07, -2.682631101578927119e-14,
           -3.690551006200058401e-17, -4.590410516980281130e-12, -9.983013764605641605e-14, -7.585219006177833234e-15,
           -2.968034672201635971e-12, -8.843071403179941781e-15 ],
         [ -5.784777457051859528e-08, -5.607465147510056466e-22, -2.682631101578927119e-14, 5.784875867220762302e-08,
           -4.603820523722603687e-18, -5.726360683376563454e-13, -1.245342652107404510e-14, -9.462269640040017228e-16,
           -3.702509490405986314e-13, -1.103139288087268827e-15 ],
         [ -7.958237541183521557e-11, -7.714305304820877864e-25, -3.690551006200058401e-17, -4.603820523722603687e-18,
           7.958373543504038139e-11, -7.877872806981575052e-16, -1.713243580702275093e-17, -1.301743849346897150e-18,
           -5.093618864028342207e-16, -1.517611416046059107e-18 ],
         [ -9.898678399622440338e-06, -9.595268609986125189e-20, -4.590410516980281130e-12, -5.726360683376563454e-13,
           -7.877872806981575052e-16, 9.898749340209178627e-06, -2.130980288062023220e-12, -1.619145491510778911e-13,
           -6.335585528427500890e-11, -1.887647481900456281e-13 ],
         [ -2.152719247305867611e-07, -2.086734953246380139e-21, -9.983013764605641605e-14, -1.245342652107404510e-14,
           -1.713243580702275093e-17, -2.130980288062023220e-12, 2.152755484985391377e-07, -3.521243228436451468e-15,
           -1.377834044774539635e-12, -4.105169319306963341e-15 ],
         [ -1.635662982835128787e-08, -1.585527368218132061e-22, -7.585219006177833234e-15, -9.462269640040017228e-16,
           -1.301743849346897150e-18, -1.619145491510778911e-13, -3.521243228436451468e-15, 1.635690871637507371e-08,
           -1.046895521119271810e-13, -3.119158858896762436e-16 ],
         [ -6.400216989277396351e-06, -6.204040063550678436e-20, -2.968034672201635971e-12, -3.702509490405986314e-13,
           -5.093618864028342207e-16, -6.335585528427500890e-11, -1.377834044774539635e-12, -1.046895521119271810e-13,
           6.400285201380029321e-06, -1.220501699922618699e-13 ],
         [ -1.906904145698717912e-08, -1.848454455492805625e-22, -8.843071403179941781e-15, -1.103139288087268827e-15,
           -1.517611416046059107e-18, -1.887647481900456281e-13, -4.105169319306963341e-15, -3.119158858896762436e-16,
           -1.220501699922618699e-13, 1.906936653028878936e-08 ]])

    # confirm the matrix genuinely has a negative eigenvalue
    evals = torch.eig(mat).eigenvalues
    assert torch.min(evals) < 0

    # symsqrt must still produce a valid square root despite the negative eigenvalue
    smat = u.symsqrt(mat)
    u.check_close(mat, smat @ smat.t())
    u.check_close(mat, smat @ smat)
def test_robust_svd():
    """robust_svd must reconstruct a matrix known to crash the default gesvd path."""
    raw = np.genfromtxt('test/gesvd_crash.txt', delimiter=",")
    mat = torch.tensor(raw).type(torch.get_default_dtype())
    U, S, V = u.robust_svd(mat)
    reconstructed = U @ torch.diag(S) @ V.T
    u.check_close(mat, reconstructed)
def test_kron():
    """Test kron, vec and vecr identities"""
    torch.set_default_dtype(torch.float64)

    # trace of a Kronecker product is the product of traces
    a = torch.tensor([1, 2, 3, 4]).reshape(2, 2)
    b = torch.tensor([5, 6, 7, 8]).reshape(2, 2)
    u.check_close(u.Kron(a, b).trace(), 65)

    a = torch.tensor([[2., 7, 9], [1, 9, 8], [2, 7, 5]])
    b = torch.tensor([[6., 6, 1], [10, 7, 7], [7, 10, 10]])
    Ck = u.Kron(a, b)
    # Frobenius norm of a Kronecker product factors into the product of norms
    u.check_close(a.flatten().norm() * b.flatten().norm(), Ck.frobenius_norm())
    u.check_close(Ck.frobenius_norm(), 4 * math.sqrt(11635.))

    # hard-coded pseudo-inverse of Ck.expand() (exact rationals)
    Ci = [[ 0, 5 / 102, -(7 / 204), 0, -(70 / 561), 49 / 561, 0, 125 / 1122, -(175 / 2244) ],
          [ 1 / 20, -(53 / 1020), 8 / 255, -(7 / 55), 371 / 2805, -(224 / 2805), 5 / 44, -(265 / 2244), 40 / 561 ],
          [ -(1 / 20), 3 / 170, 3 / 170, 7 / 55, -(42 / 935), -(42 / 935), -(5 / 44), 15 / 374, 15 / 374 ],
          [ 0, -(5 / 102), 7 / 204, 0, 20 / 561, -(14 / 561), 0, 35 / 1122, -(49 / 2244) ],
          [ -(1 / 20), 53 / 1020, -(8 / 255), 2 / 55, -(106 / 2805), 64 / 2805, 7 / 220, -(371 / 11220), 56 / 2805 ],
          [ 1 / 20, -(3 / 170), -(3 / 170), -(2 / 55), 12 / 935, 12 / 935, -(7 / 220), 21 / 1870, 21 / 1870 ],
          [ 0, 5 / 102, -(7 / 204), 0, 0, 0, 0, -(5 / 102), 7 / 204 ],
          [ 1 / 20, -(53 / 1020), 8 / 255, 0, 0, 0, -(1 / 20), 53 / 1020, -(8 / 255) ],
          [ -(1 / 20), 3 / 170, 3 / 170, 0, 0, 0, 1 / 20, -(3 / 170), -(3 / 170) ]]

    C = Ck.expand()
    C0 = u.to_numpy(C)
    Ci = torch.tensor(Ci)
    # Moore-Penrose property: C Ci C == C
    u.check_close(C @ Ci @ C, C)

    # inverse/pinverse of a Kron agrees with inverse/pinverse of its expansion
    u.check_close(Ck.inv().expand(), torch.inverse(Ck.expand()))
    u.check_close(Ck.inv().expand_vec(), torch.inverse(Ck.expand_vec()))
    u.check_close(Ck.pinv().expand(), torch.pinverse(Ck.expand()))
    u.check_close(linalg.pinv(C0), Ci, rtol=1e-5, atol=1e-6)
    u.check_close(torch.pinverse(C), Ci, rtol=1e-5, atol=1e-6)
    u.check_close(Ck.inv().expand(), Ci, rtol=1e-5, atol=1e-6)
    u.check_close(Ck.pinv().expand(), Ci, rtol=1e-5, atol=1e-6)

    # mixed-product property: (A x B)(C x D) in both vec and vecr orders
    Ck2 = u.Kron(b, 2 * a)
    u.check_close((Ck @ Ck2).expand(), Ck.expand() @ Ck2.expand())
    u.check_close((Ck @ Ck2).expand_vec(), Ck.expand_vec() @ Ck2.expand_vec())

    # NOTE(review): G, g, H, Gt, gt below appear unused by later assertions
    d2 = 3
    d1 = 2
    G = torch.randn(d2, d1)
    g = u.vec(G)
    H = u.Kron(u.random_cov(d1), u.random_cov(d2))
    Gt = G.t()
    gt = g.reshape(1, -1)

    # Vec/Kron matmul in both orders, plus quadratic forms
    vecX = u.Vec([1, 2, 3, 4], shape=(2, 2))
    K = u.Kron([[5, 6], [7, 8]], [[9, 10], [11, 12]])
    u.check_equal(vecX @ K, [644, 706, 748, 820])
    u.check_equal(K @ vecX, [543, 655, 737, 889])
    u.check_equal(u.matmul(vecX @ K, vecX), 7538)
    u.check_equal(vecX @ (vecX @ K), 7538)
    u.check_equal(vecX @ vecX, 30)

    vecX = u.Vec([1, 2], shape=(1, 2))
    K = u.Kron([[5]], [[9, 10], [11, 12]])
    u.check_equal(vecX.norm()**2, 5)

    # check kronecker rules
    X = torch.tensor([[1., 2], [3, 4]])
    A = torch.tensor([[5., 6], [7, 8]])
    B = torch.tensor([[9., 10], [11, 12]])
    x = u.Vec(X)

    # kron/vec/vecr identities
    u.check_equal(u.Vec(A @ X @ B), x @ u.Kron(B, A.t()))
    u.check_equal(u.Vec(A @ X @ B), u.Kron(B.t(), A) @ x)
    u.check_equal(u.Vecr(A @ X @ B), u.Kron(A, B.t()) @ u.Vecr(X))
    u.check_equal(u.Vecr(A @ X @ B), u.Vecr(X) @ u.Kron(A.t(), B))

    def extra_checks(A, X, B):
        """Re-run the vec/vecr identities, also exercising normal_form() variants."""
        x = u.Vec(X)
        u.check_equal(u.Vec(A @ X @ B), x @ u.Kron(B, A.t()))
        u.check_equal(u.Vec(A @ X @ B), u.Kron(B.t(), A) @ x)
        u.check_equal(u.Vecr(A @ X @ B), u.Kron(A, B.t()) @ u.Vecr(X))
        u.check_equal(u.Vecr(A @ X @ B), u.Vecr(X) @ u.Kron(A.t(), B))
        u.check_equal(u.Vecr(A @ X @ B), u.Vecr(X) @ u.Kron(A.t(), B).normal_form())
        u.check_equal(u.Vecr(A @ X @ B), u.matmul(u.Kron(A, B.t()).normal_form(), u.Vecr(X)))
        u.check_equal(u.Vec(A @ X @ B), u.matmul(u.Kron(B.t(), A).normal_form(), x))
        u.check_equal(u.Vec(A @ X @ B), x @ u.Kron(B, A.t()).normal_form())
        u.check_equal(u.Vec(A @ X @ B), x.normal_form() @ u.Kron(B, A.t()).normal_form())
        u.check_equal(u.Vec(A @ X @ B), u.Kron(B.t(), A).normal_form() @ x.normal_form())
        u.check_equal(u.Vecr(A @ X @ B), u.Kron(A, B.t()).normal_form() @ u.Vecr(X).normal_form())
        u.check_equal(u.Vecr(A @ X @ B), u.Vecr(X).normal_form() @ u.Kron(A.t(), B).normal_form())

    # shape checks
    d1, d2 = 3, 4
    extra_checks(torch.ones((d1, d1)), torch.ones((d1, d2)), torch.ones((d2, d2)))

    # NOTE(review): the remainder appears to be leftover setup with no assertions
    A = torch.rand(d1, d1)
    B = torch.rand(d2, d2)
    #x = torch.rand((d1*d2))
    #X = x.t().reshape(d1, d2)
    # X = torch.rand((d1, d2))
    # x = u.vec(X)
    x = torch.rand((d1 * d2))
def test_main_autograd():
    """End-to-end check of gradient/Fisher/Hessian stats captured by forward/backward hooks.

    NOTE(review): hook enable/disable ordering is load-bearing throughout.
    """
    u.seed_random(1)
    log_wandb = False
    autograd_check = True
    use_double = False

    logdir = u.get_unique_logdir('/tmp/autoencoder_test/run')
    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)

    batch_size = 5
    try:
        if log_wandb:
            wandb.init(project='test-autograd_test', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['batch'] = batch_size
    except Exception as e:
        # wandb is best-effort telemetry; never fail the test on it
        print(f"wandb crash with {e}")

    data_width = 4
    targets_width = 2

    d1 = data_width ** 2
    d2 = 10
    d3 = targets_width ** 2
    o = d3
    n = batch_size
    d = [d1, d2, d3]

    model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=True, bias=True)
    if use_double:
        model = model.double()
    train_steps = 3

    dataset = u.TinyMNIST(data_width=data_width, targets_width=targets_width, dataset_size=batch_size * train_steps)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)

    def loss_fn(data, targets):
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == batch_size
        return torch.sum(err * err) / 2 / len(data)

    loss_hessian = u.HessianExactSqrLoss()

    gl.token_count = 0
    for train_step in range(train_steps):
        data, targets = next(train_iter)
        if use_double:
            data, targets = data.double(), targets.double()

        # get gradient values
        model.skip_backward_hooks = False
        model.skip_forward_hooks = False
        u.clear_backprops(model)
        output = model(data)
        loss = loss_fn(output, targets)
        loss.backward(retain_graph=True)
        model.skip_forward_hooks = True

        # extra backward passes (one per Hessian sample) append to backprops_list
        output = model(data)
        for bval in loss_hessian(output):
            if use_double:
                bval = bval.double()
            output.backward(bval, retain_graph=True)
        model.skip_backward_hooks = True

        for (i, layer) in enumerate(model.layers):

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops_list[0] * n
            assert B_t.shape == (n, d[i + 1])

            # per example gradients
            G = u.khatri_rao_t(B_t, A_t)
            assert G.shape == (n, d[i+1] * d[i])
            Gbias = B_t
            assert Gbias.shape == (n, d[i + 1])

            # average gradient
            g = G.sum(dim=0, keepdim=True) / n
            gb = Gbias.sum(dim=0, keepdim=True) / n
            assert g.shape == (1, d[i] * d[i + 1])
            assert gb.shape == (1, d[i + 1])

            if autograd_check:
                # several equivalent formulas for the mean gradient must agree with .saved_grad
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad)
                u.check_close(torch.einsum('nj->j', B_t) / n, layer.bias.saved_grad)
                u.check_close(torch.mean(B_t, dim=0), layer.bias.saved_grad)
                u.check_close(torch.einsum('ni,nj->ij', B_t, A_t)/n, layer.weight.saved_grad)

            # empirical Fisher
            efisher = G.t() @ G / n
            _sigma = efisher - g.t() @ g

            #############################
            # Hessian stats
            #############################
            A_t = layer.activations
            # backprops_list[0] was the loss gradient; entries 1..o are the Hessian samples
            Bh_t = [layer.backprops_list[out_idx + 1] for out_idx in range(o)]
            Amat_t = torch.cat([A_t] * o, dim=0)  # todo: can instead replace with a khatri-rao loop
            Bmat_t = torch.cat(Bh_t, dim=0)
            Amat_t2 = torch.stack([A_t]*o, dim=0)  # o, n, in_dim
            Bmat_t2 = torch.stack(Bh_t, dim=0)  # o, n, out_dim
            assert Amat_t.shape == (n * o, d[i])
            assert Bmat_t.shape == (n * o, d[i + 1])

            Jb = u.khatri_rao_t(Bmat_t, Amat_t)  # batch output Jacobian
            H = Jb.t() @ Jb / n
            Jb2 = torch.einsum('oni,onj->onij', Bmat_t2, Amat_t2)
            u.check_close(H.reshape(d[i+1], d[i], d[i+1], d[i]), torch.einsum('onij,onkl->ijkl', Jb2, Jb2)/n)

            Hbias = Bmat_t.t() @ Bmat_t / n
            u.check_close(Hbias, torch.einsum('ni,nj->ij', Bmat_t, Bmat_t) / n)

            if autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                Hbias_autograd = u.hessian(loss, layer.bias)
                u.check_close(H, H_autograd.reshape(d[i+1] * d[i], d[i+1] * d[i]))
                u.check_close(Hbias, Hbias_autograd)
def _test_kron_conv_golden():
    """Hardcoded error values to detect unexpected numeric changes.

    Builds a small pooled-convolutional model, computes exact, Kronecker-factored
    ('kron') and mean-Kronecker ('mean_kron') Hessians through autograd_lib, and
    compares the factored-vs-exact symsqrt distances against golden constants.
    Note: results are RNG-order sensitive — seeded by u.seed_random(1); any change
    to the sequence of random draws invalidates the golden values.
    """
    u.seed_random(1)
    n, Xh, Xw = 2, 8, 8      # batch size, input height/width
    Kh, Kw = 2, 2            # kernel size
    dd = [3, 3, 3, 3]        # channel counts per layer
    o = dd[-1]
    model: u.SimpleModel = u.PooledConvolutional2(dd, kernel_size=(Kh, Kw), nonlin=False, bias=True)
    data = torch.randn((n, dd[0], Xh, Xw))
    # print(model)
    # print(data)

    loss_type = 'CrossEntropy'
    # loss_type = 'LeastSquares'
    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:  # CrossEntropy
        loss_fn = nn.CrossEntropyLoss()

    sample_output = model(data)
    if loss_type.endswith('LeastSquares'):
        targets = torch.randn(sample_output.shape)
    elif loss_type == 'CrossEntropy':
        targets = torch.LongTensor(n).random_(0, o)

    # Capture backprops and compute Hessians three ways.
    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron')
    autograd_lib.compute_hess(model, method='mean_kron', attr_name='hess_mean_kron')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    errors1 = []  # symsqrt distances: exact vs kron
    errors2 = []  # symsqrt distances: exact vs mean_kron
    for i in range(len(model.layers)):
        layer = model.layers[i]

        # direct Hessian computation
        H = layer.weight.hess
        H_bias = layer.bias.hess

        # factored Hessian computation (expand() densifies the Kron factors)
        Hk = layer.weight.hess_kron
        Hk_bias = layer.bias.hess_kron
        Hk = Hk.expand()
        Hk_bias = Hk_bias.expand()
        Hk2 = layer.weight.hess_mean_kron
        Hk2_bias = layer.bias.hess_mean_kron
        Hk2 = Hk2.expand()
        Hk2_bias = Hk2_bias.expand()

        # autograd Hessian computation
        loss = loss_fn(output, targets)
        Ha = u.hessian(loss, layer.weight).reshape(H.shape)
        Ha_bias = u.hessian(loss, layer.bias)

        # compare direct against autograd
        Ha = Ha.reshape(H.shape)
        # rel_error = torch.max((H-Ha)/Ha)
        u.check_close(H, Ha, rtol=1e-5, atol=1e-7)
        u.check_close(Ha_bias, H_bias, rtol=1e-5, atol=1e-7)

        errors1.extend([u.symsqrt_dist(H, Hk), u.symsqrt_dist(H_bias, Hk_bias)])
        errors2.extend([u.symsqrt_dist(H, Hk2), u.symsqrt_dist(H_bias, Hk2_bias)])

    errors1 = torch.tensor(errors1)
    errors2 = torch.tensor(errors2)

    # Golden values: two entries (weight, bias) per layer; bias factorization is exact (0.0).
    golden_errors1 = torch.tensor([0.09458080679178238, 0.0, 0.13416489958763123, 0.0, 0.0003909761435352266, 0.0])
    golden_errors2 = torch.tensor([0.0945773795247078, 0.0, 0.13418318331241608, 0.0, 4.478318658129865e-07, 0.0])

    u.check_close(golden_errors1, errors1)
    u.check_close(golden_errors2, errors2)
def main():
    """Training loop that interleaves optimizer steps with per-layer curvature stats.

    Each outer iteration: logs validation loss, captures gradient/Hessian stats via
    autograd_lib hooks, computes per-layer curvature quantities (Fisher, Hessian
    norms, Newton step sizes, rho/batch-size estimates) and logs them, then runs
    `args.train_steps` inner SGD/Adam (or Newton) steps.

    NOTE(review): the block of assignments below deliberately overwrites several
    `args` fields and re-derives d/o/n with small debugging values — the initial
    d = [d1, 30, ...] computation is dead code once that happens.
    """
    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")

    d1 = args.data_width ** 2
    assert args.data_width == args.targets_width
    o = d1
    n = args.stats_batch_size
    d = [d1, 30, 30, 30, 20, 30, 30, 30, d1]

    # small values for debugging
    # loss_type = 'LeastSquares'
    loss_type = 'CrossEntropy'
    args.wandb = 0
    args.stats_steps = 10
    args.train_steps = 10
    args.stats_batch_size = 10
    args.data_width = 2
    args.targets_width = 2
    args.nonlin = False
    d1 = args.data_width ** 2
    d2 = 2
    d3 = args.targets_width ** 2
    if loss_type == 'CrossEntropy':
        d3 = 10  # 10 classes for cross-entropy
    o = d3
    n = args.stats_batch_size
    d = [d1, d2, d3]
    # +1 so the dataset is strictly larger than the biggest batch — TODO confirm why needed
    dsize = max(args.train_batch_size, args.stats_batch_size)+1

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin)
    model = model.to(gl.device)

    # wandb setup is best-effort; failures are logged and ignored.
    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    #optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.03)  # make 10x smaller for least-squares loss
    dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, dataset_size=dsize, original_targets=True)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    # Separate loader for stats batches unless stats are taken over the full dataset.
    stats_iter = None
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width,
                               train=False, dataset_size=dsize, original_targets=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.token_count = 0
    last_outer = 0
    val_losses = []
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars({"time/outer": 1000*(time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
        print("val_loss", val_loss.item())
        val_losses.append(val_loss.item())
        u.log_scalar(val_loss=val_loss.item())

        # compute stats
        if args.full_batch:
            data, targets = dataset.data, dataset.targets
        else:
            data, targets = next(stats_iter)

        # Capture Hessian and gradient stats
        autograd_lib.enable_hooks()
        autograd_lib.clear_backprops(model)
        autograd_lib.clear_hess_backprops(model)
        with u.timeit("backprop_g"):
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)
        with u.timeit("backprop_H"):
            autograd_lib.backprop_hess(output, hess_type=loss_type)
        autograd_lib.disable_hooks()   # TODO(y): use remove_hooks
        with u.timeit("compute_grad1"):
            autograd_lib.compute_grad1(model)
        with u.timeit("compute_hess"):
            autograd_lib.compute_hess(model)

        for (i, layer) in enumerate(model.layers):

            # input/output layers are unreasonably expensive if not using Kronecker factoring
            if d[i]>50 or d[i+1]>50:
                print(f'layer {i} is too big ({d[i],d[i+1]}), skipping stats')
                continue

            if args.skip_stats:
                continue

            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops_list[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            u.check_equal(G.reshape(layer.weight.grad1.shape), layer.weight.grad1)

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad)

            s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel()   # proportion of activations that are zero
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g
                s.sigma_l2 = u.sym_l2_norm(sigma)
                s.sigma_erank = torch.trace(sigma)/s.sigma_l2

            # Tikhonov-style damping so cholesky_inverse succeeds on singular H.
            lambda_regularizer = args.lmb * torch.eye(d[i + 1]*d[i]).to(gl.device)
            H = layer.weight.hess

            with u.timeit(f"invH-{i}"):
                invH = torch.cholesky_inverse(H+lambda_regularizer)

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            # u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps ** 2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm() ** 2))

            with u.timeit(f"pinvH-{i}"):
                pinvH = u.pinv(H)

            with u.timeit(f'curv-{i}'):
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre', pinvH)  # save Newton preconditioner
                s.step_openai = s.grad_fro**2 / s.grad_curv if s.grad_curv else 999  # 999 sentinel for zero curvature
                s.step_max = 2 / s.H_l2
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm()  # frobenius norm of Newton update
                s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)   # replace with "quadratic_form"
                s.regret_gradient = loss_direction(g, s.step_openai)

            with u.timeit(f'rho-{i}'):
                # Solve H P + P H = sigma (Lyapunov) to estimate effective rank.
                p_sigma = u.lyapunov_svd(H, sigma)
                if u.has_nan(p_sigma) and args.compute_rho:  # use expensive method
                    print('using expensive method')
                    import pdb; pdb.set_trace()
                    H0, sigma0 = u.to_numpys(H, sigma)
                    p_sigma = scipy.linalg.solve_lyapunov(H0, sigma0)
                    p_sigma = torch.tensor(p_sigma).to(gl.device)

                if u.has_nan(p_sigma):
                    # import pdb; pdb.set_trace()
                    # fall back to trivial values when the solver failed
                    s.psigma_erank = H.shape[0]
                    s.rho = 1
                else:
                    s.psigma_erank = u.sym_erank(p_sigma)
                    s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                s.diversity = torch.norm(G, "fro") ** 2 / torch.norm(g) ** 2 / n

                # Faster approaches for noise variance computation
                # s.noise_variance = torch.trace(H.inverse() @ sigma)
                # try:
                #     # this fails with singular sigma
                #     s.noise_variance = torch.trace(torch.solve(sigma, H)[0])
                #     # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0])
                #     pass
                # except RuntimeError as _:
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        with u.timeit('inner'):
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()
                # u.log_scalar(train_loss=loss.item())

                if args.method != 'newton':
                    optimizer.step()
                    if args.weight_decay:
                        for group in optimizer.param_groups:
                            for param in group['params']:
                                param.data.mul_(1-args.weight_decay)
                else:
                    # Newton update using the preconditioner saved in the stats pass.
                    for (layer_idx, layer) in enumerate(model.layers):
                        param: torch.nn.Parameter = layer.weight
                        param_data: torch.Tensor = param.data
                        param_data.copy_(param_data - 0.1 * param.grad)
                        if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                            continue
                        u.nan_check(layer.weight.pre)
                        u.nan_check(param.grad.flatten())
                        u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                        param_new_flat = u.v2r(param_data.flatten()) - u.v2r(param.grad.flatten()) @ layer.weight.pre
                        u.nan_check(param_new_flat)
                        param_data.copy_(param_new_flat.reshape(param_data.shape))

                gl.token_count += data.shape[0]

    gl.event_writer.close()
def test_lyapunov():
    """Test that scipy lyapunov solver works correctly.

    Sets up a tiny linear regression (d=2 inputs, n=3 examples) with a fixed
    weight, forms H and sigma from manual gradient computations, then checks that
    four different Lyapunov solvers (least-squares, scipy, SVD, spectral) agree.
    """
    d = 2
    n = 3
    torch.set_default_dtype(torch.float32)
    model = Net(d)

    w0 = torch.tensor([[1, 2]]).float()
    assert w0.shape[1] == d
    model.w.weight.data.copy_(w0)

    X = torch.tensor([[-2, 0, 2], [-1, 1, 3]]).float()
    assert X.shape[0] == d
    assert X.shape[1] == n

    Y = torch.tensor([[0, 1, 2]]).float()
    assert Y.shape[1] == X.shape[1]

    data = X.t()    # PyTorch expects batch dimension first
    target = Y.t()
    assert data.shape[0] == n

    output = model(data)
    # residuals, aka e
    residuals = output - Y.t()

    def compute_loss(residuals_):
        return torch.sum(residuals_ * residuals_) / (2 * n)

    loss = compute_loss(residuals)
    assert loss - 8.83333 < 1e-5, torch.norm(loss) - 8.83333

    # use learning rate 0 to avoid changing parameter vector
    optim_kwargs = dict(
        lr=0, momentum=0, weight_decay=0, l2_reg=0, bias_correction=False,
        acc_steps=1, curv_type="Cov", curv_shapes={"Linear": "Kron"},
        momentum_type="preconditioned",
    )
    curv_args = dict(damping=1, ema_decay=1)  # todo: damping
    optimizer = SecondOrderOptimizer(model, **optim_kwargs, curv_kwargs=curv_args)

    def backward(last_layer: str) -> Callable:
        """Creates closure that backpropagates either from output layer or from loss layer"""
        def closure() -> Tuple[Optional[torch.Tensor], torch.Tensor]:
            optimizer.zero_grad()
            output = model(data)
            if last_layer == "output":
                output.backward(torch.ones_like(target))
                return None, output
            elif last_layer == 'loss':
                loss = compute_loss(output - target)
                loss.backward()
                return loss, output
            else:
                assert False, 'last layer must be "output" or "loss"'
        return closure

    # loss = compute_loss(output - Y.t())
    # loss.backward()
    # Step with lr=0: populates curvature state without moving the weights.
    loss, output = optimizer.step(closure=backward('loss'))

    J = X.t()                          # per-example Jacobian rows
    A = model.w.data_input             # captured by optimizer — TODO confirm
    B = model.w.grad_output * n
    G = residuals.repeat(1, d) * J     # per-example gradients
    losses = torch.stack([compute_loss(r) for r in residuals])  # NOTE(review): unused
    g = G.sum(dim=0) / n               # mean gradient
    efisher = G.t() @ G / n            # empirical Fisher
    sigma = efisher - u.outer(g, g)    # gradient covariance
    loss2 = (residuals * residuals).sum() / (2 * n)  # NOTE(review): unused
    H = J.t() @ J / n
    noise_variance = torch.trace(H.inverse() @ sigma)  # NOTE(review): unused

    # H is not quite symmetric, make it so
    H = H + H.t()

    # Slow way
    p_sigma = u.lyapunov_lstsq(H, sigma)

    sigma0 = u.to_numpy(sigma)
    H0 = u.to_numpy(H)

    # Alternative faster way
    p_sigma2 = scipy.linalg.solve_lyapunov(H0, sigma0)
    print(f"Error 1: {np.max(abs(H0 @ p_sigma2 + p_sigma2 @ H0 - sigma0))}")
    u.check_close(p_sigma, p_sigma2)

    # alternative through SVD
    # NOTE(review): bare `lyapunov_svd` here vs `u.lyapunov_svd` elsewhere in this
    # file — presumably imported directly at module top; confirm.
    p_sigma3 = lyapunov_svd(torch.tensor(H0), torch.tensor(sigma0))
    u.check_close(p_sigma2, p_sigma3)

    # alternative through evals
    p_sigma4 = u.lyapunov_spectral(torch.tensor(H0), torch.tensor(sigma0))
    u.check_close(p_sigma2, p_sigma4)
def main():
    """Training loop using hand-rolled forward/backward hooks for curvature stats.

    Unlike the autograd_lib-based variant, this registers its own hooks that store
    `module.activations`, `module.backprops` and `param.saved_grad`, drives one
    backward pass per output coordinate (indexed by gl.backward_idx) to assemble
    the batch output Jacobian, then computes per-layer curvature statistics and
    runs `args.train_steps` SGD (or Newton) steps per outer iteration.
    """
    # Pick the first non-existing logdir suffix (00, 01, ...).
    attemp_count = 0
    while os.path.exists(f"{args.logdir}{attemp_count:02d}"):
        attemp_count += 1
    logdir = f"{args.logdir}{attemp_count:02d}"

    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")
    u.seed_random(1)

    d1 = args.data_width**2
    d2 = 10
    d3 = args.targets_width**2
    o = d3  # number of output coordinates = number of Hessian backprops
    n = args.stats_batch_size
    d = [d1, d2, d3]
    model = u.SimpleFullyConnected(d, nonlin=args.nonlin)
    model = model.to(gl.device)

    # wandb setup is best-effort; failures are logged and ignored.
    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['d1'] = d1
            wandb.config['d2'] = d2
            wandb.config['d3'] = d3
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)
    dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, dataset_size=args.dataset_size, train=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.stats_batch_size, shuffle=True, drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    # These flags are rebound below inside the loop; the hook closures read them
    # through the enclosing function's cell, so rebinding is visible to the hooks.
    skip_forward_hooks = False
    skip_backward_hooks = False

    def capture_activations(module: nn.Module, input: List[torch.Tensor], output: torch.Tensor):
        # Forward hook: stash detached input/output on the module.
        if skip_forward_hooks:
            return
        assert not hasattr(module, 'activations'), "Seeing results of previous autograd, call util.zero_grad to clear"
        assert len(input) == 1, "this was tested for single input layers only"
        setattr(module, "activations", input[0].detach())
        setattr(module, "output", output.detach())

    def capture_backprops(module: nn.Module, _input, output):
        # Backward hook: append backprop for the current gl.backward_idx pass.
        if skip_backward_hooks:
            return
        assert len(output) == 1, "this works for single variable layers only"
        if gl.backward_idx == 0:
            assert not hasattr(module, 'backprops'), "Seeing results of previous autograd, call util.zero_grad to clear"
            setattr(module, 'backprops', [])
        assert gl.backward_idx == len(module.backprops)
        module.backprops.append(output[0])

    def save_grad(param: nn.Parameter) -> Callable[[torch.Tensor], None]:
        """Hook to save gradient into 'param.saved_grad', so it can be accessed after
        model.zero_grad(). Only stores gradient if the value has not been set, call
        util.zero_grad to clear it."""
        def save_grad_fn(grad):
            if not hasattr(param, 'saved_grad'):
                setattr(param, 'saved_grad', grad)
        return save_grad_fn

    for layer in model.layers:
        layer.register_forward_hook(capture_activations)
        layer.register_backward_hook(capture_backprops)
        layer.weight.register_hook(save_grad(layer.weight))

    def loss_fn(data, targets):
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == len(targets)
        return torch.sum(err * err) / 2 / len(data)

    gl.token_count = 0
    last_outer = 0
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars(
                {"time/outer": 1000 * (time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        # compute validation loss (hooks disabled)
        skip_forward_hooks = True
        skip_backward_hooks = True
        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
        print("val_loss", val_loss.item())
        u.log_scalar(val_loss=val_loss.item())

        # compute stats
        data, targets = next(stats_iter)
        skip_forward_hooks = False
        skip_backward_hooks = False

        # get gradient values: pass 0 fills module.backprops[0]
        with u.timeit("backprop_g"):
            gl.backward_idx = 0
            u.zero_grad(model)
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)

        # get Hessian values
        skip_forward_hooks = True
        id_mat = torch.eye(o).to(gl.device)

        u.log_scalar(loss=loss.item())

        with u.timeit("backprop_H"):
            # optionally use randomized low-rank approximation of Hessian
            hess_rank = args.hess_samples if args.hess_samples else o

            for out_idx in range(hess_rank):
                model.zero_grad()
                # backprop to get section of batch output jacobian for output at position out_idx
                output = model(
                    data
                )  # opt: using autograd.grad means I don't have to zero_grad
                if args.hess_samples:
                    # Rademacher (+/-1) probe vectors for the randomized variant.
                    bval = torch.LongTensor(n, o).to(gl.device).random_(
                        0, 2) * 2 - 1
                    bval = bval.float()
                else:
                    ei = id_mat[out_idx]
                    bval = torch.stack([ei] * n)
                gl.backward_idx = out_idx + 1
                output.backward(bval)
            skip_backward_hooks = True  #

        for (i, layer) in enumerate(model.layers):
            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad)

            # proportion of activations that are zero
            s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel()
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g
                s.sigma_l2 = u.sym_l2_norm(sigma)
                s.sigma_erank = torch.trace(sigma) / s.sigma_l2

            #############################
            # Hessian stats
            #############################
            A_t = layer.activations
            Bh_t = [
                layer.backprops[out_idx + 1] for out_idx in range(hess_rank)
            ]
            Amat_t = torch.cat([A_t] * hess_rank, dim=0)
            Bmat_t = torch.cat(Bh_t, dim=0)

            assert Amat_t.shape == (n * hess_rank, d[i])
            assert Bmat_t.shape == (n * hess_rank, d[i + 1])

            # Damping so cholesky_inverse succeeds on singular H.
            lambda_regularizer = args.lmb * torch.eye(d[i] * d[i + 1]).to(
                gl.device)

            with u.timeit(f"khatri_H-{i}"):
                Jb = u.khatri_rao_t(
                    Bmat_t, Amat_t)  # batch Jacobian, in row-vec format

            with u.timeit(f"H-{i}"):
                H = Jb.t() @ Jb / n

            with u.timeit(f"invH-{i}"):
                invH = torch.cholesky_inverse(H + lambda_regularizer)

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.jacobian_fro = Jb.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1],
                                                d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            # u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) -
                                          0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() /
                                          (dd.flatten().norm()**2))

            with u.timeit("pinvH"):
                pinvH = u.pinv(H)

            with u.timeit(f'curv-{i}'):
                s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre', pinvH)  # save Newton preconditioner
                s.step_openai = 1 / s.grad_curv if s.grad_curv else 999  # 999 sentinel for zero curvature
                s.step_max = 2 / u.sym_l2_norm(H)
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm(
                )  # frobenius norm of Newton update
                s.regret_gradient = loss_direction(g, s.step_openai)

            with u.timeit(f'rho-{i}'):
                # Solve H P + P H = sigma (Lyapunov) to estimate effective rank.
                p_sigma = u.lyapunov_svd(H, sigma)
                if u.has_nan(
                        p_sigma) and args.compute_rho:  # use expensive method
                    H0 = H.cpu().detach().numpy()
                    sigma0 = sigma.cpu().detach().numpy()
                    p_sigma = scipy.linalg.solve_lyapunov(H0, sigma0)
                    p_sigma = torch.tensor(p_sigma).to(gl.device)

                if u.has_nan(p_sigma):
                    # fall back to trivial values when the solver failed
                    s.psigma_erank = H.shape[0]
                    s.rho = 1
                else:
                    s.psigma_erank = u.sym_erank(p_sigma)
                    s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                print('openai batch', s.batch_openai)
                s.diversity = torch.norm(G, "fro")**2 / torch.norm(g)**2

                # s.noise_variance = torch.trace(H.inverse() @ sigma)
                # try:
                #     # this fails with singular sigma
                #     s.noise_variance = torch.trace(torch.solve(sigma, H)[0])
                #     # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0])
                #     pass
                # except RuntimeError as _:
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        last_inner = 0
        for i in range(args.train_steps):
            if last_inner:
                u.log_scalars(
                    {"time/inner": 1000 * (time.perf_counter() - last_inner)})
            last_inner = time.perf_counter()

            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()
            u.log_scalar(train_loss=loss.item())

            if args.method != 'newton':
                optimizer.step()
            else:
                # Newton update using the preconditioner saved in the stats pass.
                for (layer_idx, layer) in enumerate(model.layers):
                    param: torch.nn.Parameter = layer.weight
                    param_data: torch.Tensor = param.data
                    param_data.copy_(param_data - 0.1 * param.grad)
                    if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                        continue
                    u.nan_check(layer.weight.pre)
                    u.nan_check(param.grad.flatten())
                    u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                    param_new_flat = u.v2r(param_data.flatten()) - u.v2r(
                        param.grad.flatten()) @ layer.weight.pre
                    u.nan_check(param_new_flat)
                    param_data.copy_(param_new_flat.reshape(param_data.shape))

            gl.token_count += data.shape[0]

    gl.event_writer.close()
def test_factored_hessian():
    """Simple test to ensure Hessian computation is working.

    In a linear neural network with squared loss, Newton step will converge
    in one step. Compute stats after minimizing, pass sanity checks.
    """
    u.seed_random(1)
    loss_type = 'LeastSquares'

    data_width = 2
    n = 5                    # batch size (entire dataset fits in one batch)
    d1 = data_width ** 2
    o = 10
    d = [d1, o]              # single linear layer: d1 -> o
    model = u.SimpleFullyConnected2(d, bias=False, nonlin=False)
    model = model.to(gl.device)
    print(model)

    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=n, shuffle=False)
    stats_iter = u.infinite_iter(stats_loader)
    stats_data, stats_targets = next(stats_iter)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:  # loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()
    last_outer = 0

    data, targets = stats_data, stats_targets

    # Capture Hessian and gradient stats
    autograd_lib.enable_hooks()
    autograd_lib.clear_backprops(model)

    output = model(data)
    loss = loss_fn(output, targets)
    print(loss)
    loss.backward(retain_graph=True)
    layer = model.layers[0]

    autograd_lib.clear_hess_backprops(model)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.disable_hooks()

    # compute Hessian using direct method, compare against PyTorch autograd
    hess0 = u.hessian(loss, layer.weight)
    autograd_lib.compute_hess(model)
    hess1 = layer.weight.hess
    print(hess1)
    u.check_close(hess0.reshape(hess1.shape), hess1, atol=1e-9, rtol=1e-6)

    # compute Hessian using factored method (row-vectorized ordering)
    autograd_lib.compute_hess(model, method='kron', attr_name='hess2', vecr_order=True)

    # s.regret_newton = vecG.t() @ pinvH.commute() @ vecG.t() / 2  # TODO(y): figure out why needed transposes
    hess2 = layer.weight.hess2
    u.check_close(hess1, hess2, atol=1e-9, rtol=1e-6)

    # Newton step in regular notation; compare H @ g in dense vs factored form
    g1 = layer.weight.grad.flatten()
    newton1 = hess1 @ g1
    g2 = u.Vecr(layer.weight.grad)
    newton2 = g2 @ hess2
    u.check_close(newton1, newton2, atol=1e-9, rtol=1e-6)

    # compute regret in factored notation, compare against actual drop in loss
    regret1 = g1 @ hess1.pinverse() @ g1 / 2
    regret2 = g2 @ hess2.pinv() @ g2 / 2
    u.check_close(regret1, regret2)

    current_weight = layer.weight.detach().clone()  # NOTE(review): unused below
    param: torch.nn.Parameter = layer.weight

    # Earlier attempts at the Newton step, kept for reference:
    # param.data.sub_((hess1.pinverse() @ g1).reshape(param.shape))
    # output = model(data)
    # loss = loss_fn(output, targets)
    # print("result 1", loss)

    # param.data.sub_((hess1.pinverse() @ u.vec(layer.weight.grad)).reshape(param.shape))
    # output = model(data)
    # loss = loss_fn(output, targets)
    # print("result 2", loss)

    # param.data.sub_((u.vec(layer.weight.grad).t() @ hess1.pinverse()).reshape(param.shape))
    # output = model(data)
    # loss = loss_fn(output, targets)
    # print("result 3", loss)

    # del layer.weight.grad
    # Apply the Newton step in vec (column-major) notation; for a linear net with
    # squared loss this should minimize the loss in a single step.
    output = model(data)
    loss = loss_fn(output, targets)
    loss.backward()
    param.data.sub_(u.unvec(hess1.pinverse() @ u.vec(layer.weight.grad), layer.weight.shape[0]))
    output = model(data)
    loss = loss_fn(output, targets)
    print("result 4", loss)
    # param.data.sub_((g1 @ hess1.pinverse() @ g1).reshape(param.shape))
    print(loss)
def test_kron_conv_exact():
    """Test per-example gradient computation for conv layer.

    Kronecker factoring is exact for 1x1 convolutions and linear activations,
    so the factored Hessian must match the exact one to full precision here.
    """
    u.seed_random(1)

    # Tiny problem: 2 examples, 2x2 inputs, 1x1 kernels, 2 channels throughout.
    n, Xh, Xw = 2, 2, 2
    Kh, Kw = 1, 1
    dd = [2, 2, 2]
    o = dd[-1]

    model: u.SimpleModel = u.PooledConvolutional2(dd, kernel_size=(Kh, Kw), nonlin=False, bias=True)
    data = torch.randn((n, dd[0], Xh, Xw))
    #print(model)
    #print(data)

    # Select the loss function (CrossEntropy is the active configuration).
    loss_type = 'CrossEntropy'
    # loss_type = 'LeastSquares'
    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:  # CrossEntropy
        loss_fn = nn.CrossEntropyLoss()

    # Targets depend on the loss family: dense for least-squares, labels otherwise.
    sample_output = model(data)
    if loss_type.endswith('LeastSquares'):
        targets = torch.randn(sample_output.shape)
    elif loss_type == 'CrossEntropy':
        targets = torch.LongTensor(n).random_(0, o)

    # Run one hooked forward/backward and compute Hessians two ways.
    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='mean_kron')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    for layer in model.layers:
        # Hessians from the exact (direct) computation.
        exact_w = layer.weight.hess
        exact_b = layer.bias.hess

        # Kronecker-factored Hessians, expanded to dense form for comparison.
        factored_w = layer.weight.hess_factored.expand()
        factored_b = layer.bias.hess_factored.expand()

        # Reference Hessians via PyTorch autograd.
        loss = loss_fn(output, targets)
        autograd_w = u.hessian(loss, layer.weight).reshape(exact_w.shape)
        autograd_b = u.hessian(loss, layer.bias)

        # Direct computation must agree with autograd...
        # rel_error = torch.max((H-Ha)/Ha)
        u.check_close(exact_w, autograd_w, rtol=1e-5, atol=1e-7)
        u.check_close(autograd_b, exact_b, rtol=1e-5, atol=1e-7)

        # ...and the factored form must be exact in this 1x1/linear setting.
        u.check_close(exact_b, factored_b)
        u.check_close(exact_w, factored_w)
def _test_explicit_hessian_refactored():
    """Check computation of hessian of loss(B'WA) from
    https://github.com/yaroslavvb/kfac_pytorch/blob/master/derivation.pdf

    Uses small hand-computed matrices so every intermediate (output, loss,
    gradient, Hessian, Newton step) can be checked against exact constants.
    """
    torch.set_default_dtype(torch.float64)
    A = torch.tensor([[-1., 4], [3, 0]])
    B = torch.tensor([[-4., 3], [2, 6]])
    X = torch.tensor([[-5., 0], [-2, -6]], requires_grad=True)

    Y = B.t() @ X @ A
    u.check_equal(Y, [[-52, 64], [-81, -108]])
    loss = torch.sum(Y * Y) / 2
    hess0 = u.hessian(loss, X).reshape([4, 4])
    # Kronecker form of the same Hessian — NOTE(review): unused below; hess0 is
    # the reference actually compared against.
    hess1 = u.Kron(A @ A.t(), B @ B.t())
    u.check_equal(loss, 12512.5)

    # Do a test using Linear layers instead of matrix multiplies
    model: u.SimpleFullyConnected2 = u.SimpleFullyConnected2([2, 2, 2], bias=False)
    model.layers[0].weight.data.copy_(X)

    # Transpose to match previous results, layers treat dim0 as batch dimension
    u.check_equal(model.layers[0](A.t()).t(), [[5, -20], [-16, -8]])  # XA = (A'X0)'
    model.layers[1].weight.data.copy_(B.t())
    u.check_equal(model(A.t()).t(), Y)

    Y = model(A.t()).t()    # transpose to data-dimension=columns
    loss = torch.sum(Y * Y) / 2
    loss.backward()

    u.check_equal(model.layers[0].weight.grad, [[-2285, -105], [-1490, -1770]])
    G = B @ Y @ A.t()       # analytic gradient of first-layer weight
    u.check_equal(model.layers[0].weight.grad, G)

    # Collect activation second-order statistics through autograd_lib.
    autograd_lib.register(model)
    activations_dict = autograd_lib.ModuleDict()  # todo(y): make save_activations ctx manager automatically create A
    with autograd_lib.save_activations(activations_dict):
        Y = model(A.t())

    Acov = autograd_lib.ModuleDict(autograd_lib.SecondOrderCov)
    for layer, activations in activations_dict.items():
        print(layer, activations)
        Acov[layer].accumulate(activations, activations)
    autograd_lib.set_default_activations(activations_dict)
    autograd_lib.set_default_Acov(Acov)

    # NOTE(review): rebinds B from the data matrix above to a covariance dict;
    # the original matrix B is no longer needed past this point.
    B = autograd_lib.ModuleDict(autograd_lib.SymmetricFourthOrderCov)
    autograd_lib.backward_accum(Y, "identity", B, retain_graph=False)
    print(B[model.layers[0]])

    autograd_lib.backprop_hess(Y, hess_type='LeastSquares')
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron', vecr_order=False,
                              loss_aggregation='sum')
    param = model.layers[0].weight

    hess2 = param.hess_kron
    print(hess2)
    u.check_equal(hess2, [[425, 170, -75, -30], [170, 680, -30, -120], [-75, -30, 225, 90], [-30, -120, 90, 360]])

    # Gradient test
    model.zero_grad()
    loss.backward()
    u.check_close(u.vec(G).flatten(), u.Vec(param.grad))

    # Newton step test
    # Method 0: PyTorch native autograd
    newton_step0 = param.grad.flatten() @ torch.pinverse(hess0)
    newton_step0 = newton_step0.reshape(param.shape)
    u.check_equal(newton_step0, [[-5, 0], [-2, -6]])

    # Method 1: colummn major order
    ihess2 = hess2.pinv()
    u.check_equal(ihess2.LL, [[1/16, 1/48], [1/48, 17/144]])
    u.check_equal(ihess2.RR, [[2/45, -(1/90)], [-(1/90), 1/36]])
    u.check_equal(torch.flatten(hess2.pinv() @ u.vec(G)), [-5, -2, 0, -6])
    newton_step1 = (ihess2 @ u.Vec(param.grad)).matrix_form()

    # Method2: row major order
    ihess2_rowmajor = ihess2.commute()
    newton_step2 = ihess2_rowmajor @ u.Vecr(param.grad)
    newton_step2 = newton_step2.matrix_form()

    # All three Newton-step formulations must agree.
    u.check_equal(newton_step0, newton_step1)
    u.check_equal(newton_step0, newton_step2)
def main():
    """Training driver: collects per-layer gradient/Hessian statistics on a tiny
    MNIST-derived problem, logging to TensorBoard (and optionally wandb), then
    takes gradient or Newton steps depending on args.method.

    Relies on module-level `args` and `gl` (globals) — not visible here, TODO confirm.
    """
    # Find the first unused numbered logdir, e.g. <logdir>00, <logdir>01, ...
    attemp_count = 0
    while os.path.exists(f"{args.logdir}{attemp_count:02d}"):
        attemp_count += 1
    logdir = f"{args.logdir}{attemp_count:02d}"

    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")
    u.seed_random(1)

    # wandb setup is best-effort: failures are logged and ignored.
    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
    except Exception as e:
        print(f"wandb crash with {e}")

    # data_width = 4
    # targets_width = 2

    # Layer dimensions: input d1, hidden d2, output d3 (= o, the output count).
    d1 = args.data_width**2
    d2 = 10
    d3 = args.targets_width**2
    o = d3
    n = args.stats_batch_size  # batch size used for statistics collection
    d = [d1, d2, d3]
    model = u.SimpleFullyConnected(d, nonlin=args.nonlin)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)

    # NOTE: the two hooks below close over main()'s locals skip_forward_hooks /
    # skip_backward_hooks, which are first assigned inside the stats loop before
    # any forward pass runs.
    def capture_activations(module, input, _output):
        # Forward hook: stash the layer input as `module.activations`.
        if skip_forward_hooks:
            return
        assert gl.backward_idx == 0   # no need to forward-prop on Hessian computation
        assert not hasattr(
            module, 'activations'
        ), "Seeing activations from previous forward, call util.zero_grad to clear"
        assert len(input) == 1, "this works for single input layers only"
        setattr(module, "activations", input[0].detach())

    def capture_backprops(module: nn.Module, _input, output):
        # Backward hook: append output gradient to `module.backprops`, indexed
        # by gl.backward_idx (0 = loss backprop, 1..o = per-output backprops).
        if skip_backward_hooks:
            return
        assert len(output) == 1, "this works for single variable layers only"
        if gl.backward_idx == 0:
            assert not hasattr(
                module, 'backprops'
            ), "Seeing results of previous autograd, call util.zero_grad to clear"
            setattr(module, 'backprops', [])
        assert gl.backward_idx == len(module.backprops)
        module.backprops.append(output[0])

    def save_grad(param: nn.Parameter) -> Callable[[torch.Tensor], None]:
        """Hook to save gradient into 'param.saved_grad', so it can be accessed after
        model.zero_grad(). Only stores gradient if the value has not been set, call
        util.zero_grad to clear it."""
        def save_grad_fn(grad):
            if not hasattr(param, 'saved_grad'):
                setattr(param, 'saved_grad', grad)
        return save_grad_fn

    for layer in model.layers:
        layer.register_forward_hook(capture_activations)
        # NOTE(review): register_backward_hook is deprecated in newer PyTorch
        # in favor of register_full_backward_hook — verify against pinned version.
        layer.register_backward_hook(capture_backprops)
        layer.weight.register_hook(save_grad(layer.weight))

    def loss_fn(data, targets):
        # Mean (over batch) of half squared error.
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == len(targets)
        return torch.sum(err * err) / 2 / len(data)

    gl.token_count = 0
    for step in range(args.stats_steps):
        data, targets = next(stats_iter)
        skip_forward_hooks = False
        skip_backward_hooks = False

        # get gradient values
        gl.backward_idx = 0
        u.zero_grad(model)
        output = model(data)
        loss = loss_fn(output, targets)
        loss.backward(retain_graph=True)
        print("loss", loss.item())

        # get Hessian values
        skip_forward_hooks = True
        id_mat = torch.eye(o)

        u.log_scalars({'loss': loss.item()})

        # One backward pass per output coordinate to assemble the batch output Jacobian.
        for out_idx in range(o):
            model.zero_grad()
            # backprop to get section of batch output jacobian for output at position out_idx
            output = model(data)  # opt: using autograd.grad means I don't have to zero_grad
            ei = id_mat[out_idx]
            bval = torch.stack([ei] * n)  # same unit vector for every example in the batch
            gl.backward_idx = out_idx + 1
            output.backward(bval)
        skip_backward_hooks = True

        for (i, layer) in enumerate(model.layers):
            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops[0] * n
            assert B_t.shape == (n, d[i + 1])

            G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad)

            # empirical Fisher
            efisher = G.t() @ G / n
            sigma = efisher - g.t() @ g  # gradient covariance (noise)
            # u.dump(sigma, f'/tmp/sigmas/{step}-{i}')
            s.sigma_l2 = u.l2_norm(sigma)

            #############################
            # Hessian stats
            #############################
            # Stack the o per-output backprops into one (n*o, .) Jacobian factor.
            A_t = layer.activations
            Bh_t = [layer.backprops[out_idx + 1] for out_idx in range(o)]
            Amat_t = torch.cat([A_t] * o, dim=0)
            Bmat_t = torch.cat(Bh_t, dim=0)

            assert Amat_t.shape == (n * o, d[i])
            assert Bmat_t.shape == (n * o, d[i + 1])

            Jb = u.khatri_rao_t(Bmat_t, Amat_t)  # batch Jacobian, in row-vec format

            # Gauss-Newton Hessian: J'J / n (exact for this least-squares loss).
            H = Jb.t() @ Jb / n
            pinvH = u.pinv(H)

            s.hess_l2 = u.l2_norm(H)
            s.invhess_l2 = u.l2_norm(pinvH)

            s.hess_fro = H.flatten().norm()
            s.invhess_fro = pinvH.flatten().norm()

            s.jacobian_l2 = u.l2_norm(Jb)
            s.grad_fro = g.flatten().norm()
            s.param_fro = layer.weight.data.flatten().norm()
            u.nan_check(H)

            if args.autograd_check:
                # Cross-check H against full autograd Hessian (expensive).
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            # u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() / dd.flatten().norm()**2)

            s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)
            s.grad_curv = curv_direction(g)
            ndir = g @ u.pinv(H)  # newton direction
            s.newton_curv = curv_direction(ndir)
            setattr(layer.weight, 'pre', u.pinv(H))  # save Newton preconditioner
            s.step_openai = 1 / s.grad_curv if s.grad_curv else 999  # 999 sentinel for zero curvature

            s.newton_fro = ndir.flatten().norm()  # frobenius norm of Newton update
            s.regret_gradient = loss_direction(g, s.step_openai)

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        for i in range(args.train_steps):
            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()
            u.log_scalar(train_loss=loss.item())
            if args.method != 'newton':
                optimizer.step()
            else:
                for (layer_idx, layer) in enumerate(model.layers):
                    param: torch.nn.Parameter = layer.weight
                    param_data: torch.Tensor = param.data
                    # NOTE(review): the SGD-style step below runs for every layer,
                    # and for layer 1 is then overwritten by the Newton step —
                    # confirm this double-update is intentional.
                    param_data.copy_(param_data - 0.1 * param.grad)
                    if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                        continue
                    u.nan_check(layer.weight.pre)
                    u.nan_check(param.grad.flatten())
                    u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                    param_new_flat = u.v2r(param_data.flatten()) - u.v2r(
                        param.grad.flatten()) @ layer.weight.pre
                    u.nan_check(param_new_flat)
                    param_data.copy_(param_new_flat.reshape(param_data.shape))
            gl.token_count += data.shape[0]

    gl.event_writer.close()
def test_explicit_hessian():
    """Check computation of hessian of loss(B'WA) from
    https://github.com/yaroslavvb/kfac_pytorch/blob/master/derivation.pdf

    Verifies, on a hand-computed 2x2 example, that the autograd Hessian,
    the Kronecker-factored Hessian, and the hook-based hess_kron all agree,
    and that all three Newton-step formulations recover the same step.
    """
    torch.set_default_dtype(torch.float64)
    # Small fixed matrices so every expected value is hand-computable.
    mat_a = torch.tensor([[-1., 4], [3, 0]])
    mat_b = torch.tensor([[-4., 3], [2, 6]])
    w = torch.tensor([[-5., 0], [-2, -6]], requires_grad=True)

    out = mat_b.t() @ w @ mat_a
    u.check_equal(out, [[-52, 64], [-81, -108]])
    loss = (out * out).sum() / 2
    # Ground-truth Hessian via autograd, flattened row-major.
    hess_rowmajor = u.hessian(loss, w).reshape([4, 4])
    hess_factored = u.Kron(mat_a @ mat_a.t(), mat_b @ mat_b.t())
    u.check_equal(loss, 12512.5)

    # PyTorch autograd computes Hessian with respect to row-vectorized parameters, whereas
    # autograd_lib uses math convention and does column-vectorized.
    # Commuting order of Kronecker product switches between two representations
    u.check_equal(hess_factored.commute(), hess_rowmajor)

    # Repeat the computation with nn.Linear layers standing in for the matmuls.
    net: u.SimpleFullyConnected2 = u.SimpleFullyConnected2([2, 2, 2], bias=False)
    net.layers[0].weight.data.copy_(w)

    # Layers treat dim0 as batch, so transpose to line up with matrix results.
    u.check_equal(net.layers[0](mat_a.t()).t(), [[5, -20], [-16, -8]])  # XA = (A'X0)'

    net.layers[1].weight.data.copy_(mat_b.t())
    u.check_equal(net(mat_a.t()).t(), out)

    out = net(mat_a.t()).t()  # transpose back to data-dimension=columns
    loss = (out * out).sum() / 2
    loss.backward()

    u.check_equal(net.layers[0].weight.grad, [[-2285, -105], [-1490, -1770]])
    # Closed-form gradient must match what autograd produced.
    grad_w = mat_b @ out @ mat_a.t()
    u.check_equal(net.layers[0].weight.grad, grad_w)

    u.check_equal(hess_rowmajor, u.Kron(mat_b @ mat_b.t(), mat_a @ mat_a.t()))

    # Newton step from the factored Hessian, column-vectorized.
    u.check_equal(u.Kron(mat_b @ mat_b.t(), mat_a @ mat_a.t()).pinv() @ u.vec(grad_w),
                  u.v2c([-5, -2, 0, -6]))

    # Now via the hook-based factored representation.
    autograd_lib.add_hooks(net)

    out = net(mat_a.t())
    batch_size = 2
    loss = (out * out).sum() / 2
    autograd_lib.backprop_hess(out, hess_type='LeastSquares')

    autograd_lib.compute_hess(net, method='kron', attr_name='hess_kron', vecr_order=False,
                              loss_aggregation='sum')
    weight0 = net.layers[0].weight

    kfac_hess = weight0.hess_kron
    print(kfac_hess)
    u.check_equal(kfac_hess, [[425, 170, -75, -30], [170, 680, -30, -120],
                              [-75, -30, 225, 90], [-30, -120, 90, 360]])

    # Gradient check after clearing accumulated grads.
    net.zero_grad()
    loss.backward()
    u.check_close(u.vec(grad_w).flatten(), u.Vec(weight0.grad))

    # --- Newton step, three ways ---
    # Method 0: PyTorch native autograd (row-major flattening).
    step_autograd = weight0.grad.flatten() @ hess_rowmajor.pinverse()
    step_autograd = step_autograd.reshape(weight0.shape)
    u.check_equal(step_autograd, [[-5, 0], [-2, -6]])

    # Method 1: column-major order through the factored pseudo-inverse.
    kfac_ihess = kfac_hess.pinv()
    u.check_equal(kfac_ihess.LL, [[1/16, 1/48], [1/48, 17/144]])
    u.check_equal(kfac_ihess.RR, [[2/45, -(1/90)], [-(1/90), 1/36]])
    u.check_equal(torch.flatten(kfac_hess.pinv() @ u.vec(grad_w)), [-5, -2, 0, -6])
    step_vec = (kfac_ihess @ u.Vec(weight0.grad)).matrix_form()

    # Method 2: row-major order (commuted factors).
    kfac_ihess_rowmajor = kfac_ihess.commute()
    step_vecr = kfac_ihess_rowmajor @ u.Vecr(weight0.grad)
    step_vecr = step_vecr.matrix_form()

    u.check_equal(step_autograd, step_vec)
    u.check_equal(step_autograd, step_vecr)
def test_main():
    """End-to-end smoke test of the curvature-stats training loop on a tiny
    2x2 MNIST problem with CrossEntropy loss; asserts that validation loss
    starts above 2.4 and drops below 2.25 after 10 outer steps.
    """
    # NOTE(review): the parser is built but parse_args() is commented out below;
    # all values actually come from the AttrDict. Kept, presumably, as documentation
    # of the full option set — confirm before deleting.
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval', type=int, default=10, metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--wandb', type=int, default=1,
                        help='log to weights and biases')
    parser.add_argument('--autograd_check', type=int, default=0,
                        help='autograd correctness checks')
    parser.add_argument('--logdir', type=str, default='/temp/runs/curv_train_tiny/run')
    parser.add_argument('--train_batch_size', type=int, default=100)
    parser.add_argument('--stats_batch_size', type=int, default=60000)
    parser.add_argument('--dataset_size', type=int, default=60000)
    parser.add_argument('--train_steps', type=int, default=100,
                        help="this many train steps between stat collection")
    parser.add_argument('--stats_steps', type=int, default=1000000,
                        help="total number of curvature stats collections")
    parser.add_argument('--nonlin', type=int, default=1,
                        help="whether to add ReLU nonlinearity between layers")
    parser.add_argument('--method', type=str, choices=['gradient', 'newton'],
                        default='gradient', help="descent method, newton or gradient")
    parser.add_argument('--layer', type=int, default=-1,
                        help="restrict updates to this layer")
    parser.add_argument('--data_width', type=int, default=28)
    parser.add_argument('--targets_width', type=int, default=28)
    parser.add_argument('--lmb', type=float, default=1e-3)
    parser.add_argument(
        '--hess_samples', type=int, default=1,
        help='number of samples when sub-sampling outputs, 0 for exact hessian')
    parser.add_argument('--hess_kfac', type=int, default=0,
                        help='whether to use KFAC approximation for hessian')
    parser.add_argument('--compute_rho', type=int, default=1,
                        help='use expensive method to compute rho')
    parser.add_argument('--skip_stats', type=int, default=0,
                        help='skip all stats collection')
    parser.add_argument('--full_batch', type=int, default=0,
                        help='do stats on the whole dataset')
    parser.add_argument('--weight_decay', type=float, default=1e-4)

    #args = parser.parse_args()
    # Hard-coded test configuration (small values for a fast deterministic run).
    args = AttrDict()
    args.lmb = 1e-3
    args.compute_rho = 1
    args.weight_decay = 1e-4
    args.method = 'gradient'
    args.logdir = '/tmp'

    args.data_width = 2
    args.targets_width = 2
    args.train_batch_size = 10
    args.full_batch = False
    args.skip_stats = False
    args.autograd_check = False

    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    #gl.event_writer = SummaryWriter(logdir)
    gl.event_writer = u.NoOp()  # disable TensorBoard logging in the test
    # print(f"Logging to {run_name}")

    # small values for debugging
    # loss_type = 'LeastSquares'
    loss_type = 'CrossEntropy'

    args.wandb = 0
    args.stats_steps = 10
    args.train_steps = 10
    args.stats_batch_size = 10
    args.data_width = 2
    args.targets_width = 2
    args.nonlin = False
    d1 = args.data_width**2
    d2 = 2
    d3 = args.targets_width**2

    d1 = args.data_width**2
    assert args.data_width == args.targets_width
    o = d1
    n = args.stats_batch_size
    # Deep-network layer sizes for the LeastSquares path; overridden just below
    # for CrossEntropy, which is what this test actually runs.
    d = [d1, 30, 30, 30, 20, 30, 30, 30, d1]

    if loss_type == 'CrossEntropy':
        d3 = 10  # 10 classes
        o = d3
        n = args.stats_batch_size
        d = [d1, d2, d3]
    dsize = max(args.train_batch_size, args.stats_batch_size) + 1

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin)
    model = model.to(gl.device)

    # wandb setup is best-effort: failures are logged and ignored (disabled here).
    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    # optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.03)  # make 10x smaller for least-squares loss
    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=dsize, original_targets=True)

    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_iter = None
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(
            dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width,
                               targets_width=args.targets_width, train=False,
                               dataset_size=dsize, original_targets=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.train_batch_size,
                                              shuffle=False, drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.token_count = 0
    last_outer = 0
    val_losses = []  # tracked so the final assertions can check training progress
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars(
                {"time/outer": 1000 * (time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
        # print("val_loss", val_loss.item())
        val_losses.append(val_loss.item())
        u.log_scalar(val_loss=val_loss.item())

        # compute stats
        if args.full_batch:
            data, targets = dataset.data, dataset.targets
        else:
            data, targets = next(stats_iter)

        # Capture Hessian and gradient stats; hooks are only live in this window.
        autograd_lib.enable_hooks()
        autograd_lib.clear_backprops(model)
        autograd_lib.clear_hess_backprops(model)
        with u.timeit("backprop_g"):
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)
        with u.timeit("backprop_H"):
            autograd_lib.backprop_hess(output, hess_type=loss_type)
        autograd_lib.disable_hooks()   # TODO(y): use remove_hooks

        with u.timeit("compute_grad1"):
            autograd_lib.compute_grad1(model)
        with u.timeit("compute_hess"):
            autograd_lib.compute_hess(model)

        for (i, layer) in enumerate(model.layers):

            # input/output layers are unreasonably expensive if not using Kronecker factoring
            if d[i] > 50 or d[i + 1] > 50:
                print(
                    f'layer {i} is too big ({d[i], d[i + 1]}), skipping stats')
                continue

            if args.skip_stats:
                continue

            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops_list[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            # Cross-check against per-example gradients from compute_grad1.
            u.check_equal(G.reshape(layer.weight.grad1.shape), layer.weight.grad1)

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad)

            s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel()  # proportion of activations that are zero
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g  # gradient covariance (noise)
                s.sigma_l2 = u.sym_l2_norm(sigma)
                s.sigma_erank = torch.trace(sigma) / s.sigma_l2

            lambda_regularizer = args.lmb * torch.eye(d[i + 1] * d[i]).to(
                gl.device)
            H = layer.weight.hess

            with u.timeit(f"invH-{i}"):
                # NOTE(review): cholesky_inverse documents its input as a Cholesky
                # factor, but H + lambda*I here is the (regularized) matrix itself —
                # confirm this is intended for the pinned torch version.
                invH = torch.cholesky_inverse(H + lambda_regularizer)

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                # Expensive exact-Hessian cross-check (disabled in this test).
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1],
                                                d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            # u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() /
                                          (dd.flatten().norm()**2))

            with u.timeit(f"pinvH-{i}"):
                pinvH = H.pinverse()

            with u.timeit(f'curv-{i}'):
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre', pinvH)  # save Newton preconditioner
                s.step_openai = s.grad_fro**2 / s.grad_curv if s.grad_curv else 999  # 999 sentinel for zero curvature
                s.step_max = 2 / s.H_l2
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm()  # frobenius norm of Newton update
                s.regret_newton = u.to_python_scalar(
                    g @ pinvH @ g.t() / 2)  # replace with "quadratic_form"
                s.regret_gradient = loss_direction(g, s.step_openai)

            with u.timeit(f'rho-{i}'):
                p_sigma = u.lyapunov_spectral(H, sigma)
                # NOTE(review): discrepancy (asymmetry of p_sigma) is computed but
                # never used or logged — confirm whether it was meant to be asserted.
                discrepancy = torch.max(abs(p_sigma - p_sigma.t()) / p_sigma)
                s.psigma_erank = u.sym_erank(p_sigma)
                s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                s.diversity = torch.norm(G, "fro")**2 / torch.norm(g)**2 / n

                # Faster approaches for noise variance computation
                # s.noise_variance = torch.trace(H.inverse() @ sigma)
                # try:
                #     # this fails with singular sigma
                #     s.noise_variance = torch.trace(torch.solve(sigma, H)[0])
                #     # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0])
                #     pass
                # except RuntimeError as _:
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        with u.timeit('inner'):
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()
                # u.log_scalar(train_loss=loss.item())

                if args.method != 'newton':
                    optimizer.step()
                    if args.weight_decay:
                        # Decoupled weight decay applied after the Adam step.
                        for group in optimizer.param_groups:
                            for param in group['params']:
                                param.data.mul_(1 - args.weight_decay)
                else:
                    for (layer_idx, layer) in enumerate(model.layers):
                        param: torch.nn.Parameter = layer.weight
                        param_data: torch.Tensor = param.data
                        # NOTE(review): SGD-style step runs for every layer, and for
                        # layer 1 is then overwritten by the Newton step below.
                        param_data.copy_(param_data - 0.1 * param.grad)
                        if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                            continue
                        u.nan_check(layer.weight.pre)
                        u.nan_check(param.grad.flatten())
                        u.nan_check(
                            u.v2r(param.grad.flatten()) @ layer.weight.pre)
                        param_new_flat = u.v2r(param_data.flatten()) - u.v2r(
                            param.grad.flatten()) @ layer.weight.pre
                        u.nan_check(param_new_flat)
                        param_data.copy_(
                            param_new_flat.reshape(param_data.shape))

                gl.token_count += data.shape[0]

    gl.event_writer.close()

    # Regression bounds on validation loss (reference values from a known-good run).
    assert val_losses[0] > 2.4  # 2.4828238487243652
    assert val_losses[-1] < 2.25  # 2.20609712600708
def compute_layer_stats(layer):
    """Compute gradient/curvature statistics for a single layer and return them
    as an AttrDefault of scalars/tensors.

    Performs two full backward passes through the (module-level) `model`:
    one from the loss and one from the summed outputs, so it must not be called
    while other gradient state needs to be preserved.

    Relies on enclosing-scope names `args`, `model`, `stats_data`,
    `stats_targets`, `compute_loss` — not visible here, TODO confirm they are
    module globals.

    :param layer: layer of `model` to compute stats for (must be in model.layers)
    :return: AttrDefault with the collected statistics
    """
    # Temporarily unfreeze a frozen layer so gradients flow; restore at the end.
    refreeze = False
    if hasattr(layer, 'frozen') and layer.frozen:
        u.unfreeze(layer)
        refreeze = True

    s = AttrDefault(str, {})
    n = args.stats_batch_size
    param = u.get_param(layer)
    _d = len(param.flatten())   # dimensionality of parameters
    layer_idx = model.layers.index(layer)
    # TODO: get layer type, include it in name
    assert layer_idx >= 0
    assert stats_data.shape[0] == n

    def backprop_loss():
        # Backprop from the scalar loss; populates .grad and layer.grad_output.
        model.zero_grad()
        output = model(stats_data)  # use last saved data batch for backprop
        loss = compute_loss(output, stats_targets)
        loss.backward()
        return loss, output

    def backprop_output():
        # Backprop from summed outputs (gradient of all-ones), giving the
        # output-Jacobian backprops instead of loss backprops.
        model.zero_grad()
        output = model(stats_data)
        output.backward(gradient=torch.ones_like(output))
        return output

    # per-example gradients, n, d
    _loss, _output = backprop_loss()
    At = layer.data_input
    # factor of n: backprop carries batch-averaged loss, we want per-example
    Bt = layer.grad_output * n
    # NOTE(review): argument order (At, Bt) is the reverse of the (B_t, A_t)
    # order used by the khatri_rao_t calls in main/test_main — confirm intended.
    G = u.khatri_rao_t(At, Bt)
    g = G.sum(dim=0, keepdim=True) / n  # average gradient, row vector
    u.check_close(g, u.vec(param.grad).t())

    s.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2

    s.grad_fro = g.flatten().norm()
    s.param_fro = param.data.flatten().norm()
    pos_activations = torch.sum(layer.data_output > 0)
    neg_activations = torch.sum(layer.data_output <= 0)
    s.a_sparsity = neg_activations.float() / (
        pos_activations + neg_activations)  # 1 sparsity means all 0's
    activation_size = len(layer.data_output.flatten())
    s.a_magnitude = torch.sum(layer.data_output) / activation_size

    # Second backprop: batch output Jacobian and Gauss-Newton style H = J'J/n.
    _output = backprop_output()
    B2t = layer.grad_output
    J = u.khatri_rao_t(At, B2t)  # batch output Jacobian
    H = J.t() @ J / n

    s.hessian_l2 = u.l2_norm(H)
    s.jacobian_l2 = u.l2_norm(J)
    J1 = J.sum(dim=0) / n   # single output Jacobian
    s.J1_l2 = J1.norm()

    # newton decrement
    def loss_direction(direction, eps):
        """loss improvement if we take step eps in direction dir"""
        return u.to_python_scalar(eps * (direction @ g.t()) -
                                  0.5 * eps**2 * direction @ H @ direction.t())

    s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)

    # TODO: gradient diversity is stuck at 1
    # TODO: newton/gradient angle
    # TODO: newton step magnitude
    s.grad_curvature = u.to_python_scalar(
        g @ H @ g.t())  # curvature in direction of g
    s.step_openai = u.to_python_scalar(
        s.grad_fro**2 / s.grad_curvature) if s.grad_curvature else 999  # 999 sentinel for zero curvature
    s.regret_gradient = loss_direction(g, s.step_openai)

    if refreeze:
        u.freeze(layer)

    return s