def test_kfac_jacobian_mnist(): u.seed_random(1) data_width = 3 d = [data_width**2, 8, 10] model: u.SimpleMLP = u.SimpleMLP(d, nonlin=False) autograd_lib.register(model) batch_size = 4 stats_steps = 2 n = batch_size * stats_steps dataset = u.TinyMNIST(dataset_size=n, data_width=data_width, original_targets=True) trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) train_iter = iter(trainloader) loss_fn = torch.nn.CrossEntropyLoss() activations = {} jacobians = defaultdict(lambda: AttrDefault(float)) total_data = [] # sum up statistics over n examples for train_step in range(stats_steps): data, targets = next(train_iter) total_data.append(data) activations = {} def save_activations(layer, A, _): activations[layer] = A jacobians[layer].AA += torch.einsum("ni,nj->ij", A, A) with autograd_lib.module_hook(save_activations): output = model(data) loss = loss_fn(output, targets) def compute_jacobian(layer, _, B): A = activations[layer] jacobians[layer].BB += torch.einsum("ni,nj->ij", B, B) jacobians[layer].diag += torch.einsum("ni,nj->ij", B * B, A * A) with autograd_lib.module_hook(compute_jacobian): autograd_lib.backward_jacobian(output) for layer in model.layers: jacobian0 = jacobians[layer] jacobian_full = torch.einsum('kl,ij->kilj', jacobian0.BB / n, jacobian0.AA / n) jacobian_diag = jacobian0.diag / n J = u.jacobian(model(torch.cat(total_data)), layer.weight) J_autograd = torch.einsum('noij,nokl->ijkl', J, J) / n u.check_equal(jacobian_full, J_autograd) u.check_equal(jacobian_diag, torch.einsum('ikik->ik', J_autograd))
def test_diagonal_hessian(): u.seed_random(1) A, model = create_toy_model() activations = {} def save_activations(layer, a, _): if layer != model.layers[0]: return activations[layer] = a with autograd_lib.module_hook(save_activations): Y = model(A.t()) loss = torch.sum(Y * Y) / 2 hess = [0] def compute_hess(layer, _, B): if layer != model.layers[0]: return A = activations[layer] hess[0] += torch.einsum("ni,nj->ij", B * B, A * A).reshape(-1) with autograd_lib.module_hook(compute_hess): autograd_lib.backprop_identity(Y, retain_graph=True) # check against autograd hess0 = u.hessian(loss, model.layers[0].weight).reshape([4, 4]) u.check_equal(hess[0], torch.diag(hess0)) # check against manual solution u.check_equal(hess[0], [425., 225., 680., 360.])
def test_gradient_norms(): """Per-example gradient norms.""" u.seed_random(1) A, model = create_toy_model() activations = {} def save_activations(layer, a, _): if layer != model.layers[0]: return activations[layer] = a with autograd_lib.module_hook(save_activations): Y = model(A.t()) loss = torch.sum(Y * Y) / 2 norms = {} def compute_norms(layer, _, b): if layer != model.layers[0]: return a = activations[layer] del activations[layer] # release memory kept by activations norms[layer] = (a * a).sum(dim=1) * (b * b).sum(dim=1) with autograd_lib.module_hook(compute_norms): loss.backward() u.check_equal(norms[model.layers[0]], [3493250, 9708800])
def test_full_hessian_xent_mnist_multilayer(): """Test regular and diagonal hessian computation.""" u.seed_random(1) data_width = 3 batch_size = 2 d = [data_width**2, 6, 10] o = d[-1] n = batch_size train_steps = 1 model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=False, bias=True) autograd_lib.register(model) dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True) trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) train_iter = iter(trainloader) loss_fn = torch.nn.CrossEntropyLoss() hess = defaultdict(float) hess_diag = defaultdict(float) for train_step in range(train_steps): data, targets = next(train_iter) activations = {} def save_activations(layer, a, _): activations[layer] = a with autograd_lib.module_hook(save_activations): output = model(data) loss = loss_fn(output, targets) def compute_hess(layer, _, B): A = activations[layer] BA = torch.einsum("nl,ni->nli", B, A) hess[layer] += torch.einsum('nli,nkj->likj', BA, BA) hess_diag[layer] += torch.einsum("ni,nj->ij", B * B, A * A) with autograd_lib.module_hook(compute_hess): autograd_lib.backward_hessian(output, loss='CrossEntropy', retain_graph=True) # compute Hessian through autograd H_autograd = u.hessian(loss, model.layers[0].weight) u.check_close(hess[model.layers[0]] / batch_size, H_autograd) diag_autograd = torch.einsum('lili->li', H_autograd) u.check_close(diag_autograd, hess_diag[model.layers[0]] / batch_size) H_autograd = u.hessian(loss, model.layers[1].weight) u.check_close(hess[model.layers[1]] / batch_size, H_autograd) diag_autograd = torch.einsum('lili->li', H_autograd) u.check_close(diag_autograd, hess_diag[model.layers[1]] / batch_size)
def test_kfac_fisher_mnist(): u.seed_random(1) data_width = 3 d = [data_width**2, 8, 10] model: u.SimpleMLP = u.SimpleMLP(d, nonlin=False) autograd_lib.register(model) batch_size = 4 stats_steps = 2 n = batch_size * stats_steps dataset = u.TinyMNIST(dataset_size=n, data_width=data_width, original_targets=True) trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) train_iter = iter(trainloader) loss_fn = torch.nn.CrossEntropyLoss() activations = {} fishers = defaultdict(lambda: AttrDefault(float)) total_data = [] # sum up statistics over n examples for train_step in range(stats_steps): data, targets = next(train_iter) total_data.append(data) activations = {} def save_activations(layer, A, _): activations[layer] = A fishers[layer].AA += torch.einsum("ni,nj->ij", A, A) with autograd_lib.module_hook(save_activations): output = model(data) loss = loss_fn(output, targets) * len( data) # remove data normalization def compute_fisher(layer, _, B): A = activations[layer] fishers[layer].BB += torch.einsum("ni,nj->ij", B, B) fishers[layer].diag += torch.einsum("ni,nj->ij", B * B, A * A) with autograd_lib.module_hook(compute_fisher): autograd_lib.backward_jacobian(output) for layer in model.layers: fisher0 = fishers[layer] fisher_full = torch.einsum('kl,ij->kilj', fisher0.BB / n, fisher0.AA / n) fisher_diag = fisher0.diag / n u.check_equal(torch.einsum('ikik->ik', fisher_full), fisher_diag)
def test_kron_mnist(): u.seed_random(1) data_width = 3 batch_size = 3 d = [data_width**2, 10] o = d[-1] n = batch_size train_steps = 1 # torch.set_default_dtype(torch.float64) model: u.SimpleModel2 = u.SimpleFullyConnected2(d, nonlin=False, bias=True) autograd_lib.add_hooks(model) dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True) trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) train_iter = iter(trainloader) loss_fn = torch.nn.CrossEntropyLoss() gl.token_count = 0 for train_step in range(train_steps): data, targets = next(train_iter) # get gradient values u.clear_backprops(model) autograd_lib.enable_hooks() output = model(data) autograd_lib.backprop_hess(output, hess_type='CrossEntropy') i = 0 layer = model.layers[i] autograd_lib.compute_hess(model, method='kron') autograd_lib.compute_hess(model) autograd_lib.disable_hooks() # direct Hessian computation H = layer.weight.hess H_bias = layer.bias.hess # factored Hessian computation H2 = layer.weight.hess_factored H2_bias = layer.bias.hess_factored H2, H2_bias = u.expand_hess(H2, H2_bias) # autograd Hessian computation loss = loss_fn(output, targets) # TODO: change to d[i+1]*d[i] H_autograd = u.hessian(loss, layer.weight).reshape(d[i] * d[i + 1], d[i] * d[i + 1]) H_bias_autograd = u.hessian(loss, layer.bias) # compare direct against autograd u.check_close(H, H_autograd) u.check_close(H_bias, H_bias_autograd) approx_error = u.symsqrt_dist(H, H2) assert approx_error < 1e-2, approx_error
def _test_kfac_hessian_xent_mnist(): u.seed_random(1) data_width = 3 batch_size = 2 d = [data_width**2, 10] o = d[-1] n = batch_size train_steps = 1 model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=False, bias=True) autograd_lib.register(model) dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True) trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) train_iter = iter(trainloader) loss_fn = torch.nn.CrossEntropyLoss() activations = {} hess = defaultdict(lambda: AttrDefault(float)) for train_step in range(train_steps): data, targets = next(train_iter) activations = {} def save_activations(layer, a, _): activations[layer] = a with autograd_lib.module_hook(save_activations): output = model(data) loss = loss_fn(output, targets) def compute_hess(layer, _, B): A = activations[layer] hess[layer].AA += torch.einsum("ni,nj->ij", A, A) hess[layer].BB += torch.einsum("ni,nj->ij", B, B) with autograd_lib.module_hook(compute_hess): autograd_lib.backward_hessian(output, loss='CrossEntropy', retain_graph=True) hess_factored = hess[model.layers[0]] hess0 = torch.einsum('kl,ij->kilj', hess_factored.BB / n, hess_factored.AA / o) # hess for sum loss hess0 /= n # hess for mean loss # compute Hessian through autograd H_autograd = u.hessian(loss, model.layers[0].weight) rel_error = torch.norm( (hess0 - H_autograd).flatten()) / torch.norm(H_autograd.flatten()) assert rel_error < 0.01 # 0.0057
def test_full_hessian_xent_kfac2(): """Test with uneven layers.""" u.seed_random(1) torch.set_default_dtype(torch.float64) batch_size = 1 d = [3, 2] o = d[-1] n = batch_size train_steps = 1 model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=True, bias=False) autograd_lib.register(model) loss_fn = torch.nn.CrossEntropyLoss() data = u.to_logits(torch.tensor([[0.7, 0.2, 0.1]])) targets = torch.tensor([0]) data = data.repeat([3, 1]) targets = targets.repeat([3]) n = len(data) activations = {} hess = defaultdict(lambda: AttrDefault(float)) for i in range(n): def save_activations(layer, A, _): activations[layer] = A hess[layer].AA += torch.einsum("ni,nj->ij", A, A) with autograd_lib.module_hook(save_activations): data_batch = data[i:i + 1] targets_batch = targets[i:i + 1] Y = model(data_batch) o = Y.shape[1] loss = loss_fn(Y, targets_batch) def compute_hess(layer, _, B): hess[layer].BB += torch.einsum("ni,nj->ij", B, B) with autograd_lib.module_hook(compute_hess): autograd_lib.backward_hessian(Y, loss='CrossEntropy') # expand hess_factored = hess[model.layers[0]] hess0 = torch.einsum('kl,ij->kilj', hess_factored.BB / n, hess_factored.AA / o) # hess for sum loss hess0 /= n # hess for mean loss # check against autograd # 0.1459 Y = model(data) loss = loss_fn(Y, targets) hess_autograd = u.hessian(loss, model.layers[0].weight) u.check_equal(hess_autograd, hess0)
def test_kron_1x2_conv(): """Minimal example of a 1x2 convolution whose Hessian/grad covariance doesn't factor as Kronecker. Two convolutional layers stacked on top of each other, followed by least squares loss. Outputs: 0 tensor([[[[0., 1., 1., 1.]]]]) 1 tensor([[[[2., 3.]]]]) 2 tensor([[[[8.]]]]) Activations/backprops: layerA 0 tensor([[[[0., 1.], [1., 1.]]]]) layerB 0 tensor([[[[1., 2.]]]]) layerA 1 tensor([[[[2.], [3.]]]]) layerB 1 tensor([[[[1.]]]]) layer 0 discrepancy: 0.6597963571548462 layer 1 discrepancy: 0.0 """ u.seed_random(1) n, Xh, Xw = 1, 1, 4 Kh, Kw = 1, 2 dd = [1, 1, 1] o = dd[-1] model: u.SimpleModel = u.StridedConvolutional2(dd, kernel_size=(Kh, Kw), nonlin=False, bias=True) data = torch.tensor([0, 1., 1, 1]).reshape((n, dd[0], Xh, Xw)) model.layers[0].bias.data.zero_() model.layers[0].weight.data.copy_(torch.tensor([1, 2])) model.layers[1].bias.data.zero_() model.layers[1].weight.data.copy_(torch.tensor([1, 2])) sample_output = model(data) autograd_lib.clear_backprops(model) autograd_lib.add_hooks(model) output = model(data) autograd_lib.backprop_hess(output, hess_type='LeastSquares') autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron') autograd_lib.compute_hess(model, method='exact') autograd_lib.disable_hooks() for i in range(len(model.layers)): layer = model.layers[i] H = layer.weight.hess Hk = layer.weight.hess_kron Hk = Hk.expand() print(u.symsqrt_dist(H, Hk))
def test_kron_nano(): u.seed_random(1) d = [1, 2] n = 1 # torch.set_default_dtype(torch.float32) loss_type = 'CrossEntropy' model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=False, bias=True) if loss_type == 'LeastSquares': loss_fn = u.least_squares elif loss_type == 'DebugLeastSquares': loss_fn = u.debug_least_squares else: loss_fn = nn.CrossEntropyLoss() data = torch.randn(n, d[0]) data = torch.ones(n, d[0]) if loss_type.endswith('LeastSquares'): target = torch.randn(n, d[-1]) elif loss_type == 'CrossEntropy': target = torch.LongTensor(n).random_(0, d[-1]) target = torch.tensor([0]) # Hessian computation, saves regular and Kronecker factored versions into .hess and .hess_factored attributes autograd_lib.add_hooks(model) output = model(data) autograd_lib.backprop_hess(output, hess_type=loss_type) autograd_lib.compute_hess(model, method='kron') autograd_lib.compute_hess(model) autograd_lib.disable_hooks() for layer in model.layers: Hk: u.Kron = layer.weight.hess_factored Hk_bias: u.Kron = layer.bias.hess_factored Hk, Hk_bias = u.expand_hess(Hk, Hk_bias) # kronecker multiply the factors # old approach, using direct computation H2, H_bias2 = layer.weight.hess, layer.bias.hess # compute Hessian through autograd model.zero_grad() output = model(data) loss = loss_fn(output, target) H_autograd = u.hessian(loss, layer.weight) H_bias_autograd = u.hessian(loss, layer.bias) # compare autograd with direct approach u.check_close(H2, H_autograd.reshape(Hk.shape)) u.check_close(H_bias2, H_bias_autograd) # compare factored with direct approach assert(u.symsqrt_dist(Hk, H2) < 1e-6)
def test_full_hessian_xent_multibatch(): u.seed_random(1) torch.set_default_dtype(torch.float64) batch_size = 1 d = [2, 2] o = d[-1] n = batch_size train_steps = 1 model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=True, bias=True) model.layers[0].weight.data.copy_(torch.eye(2)) autograd_lib.register(model) loss_fn = torch.nn.CrossEntropyLoss() data = u.to_logits(torch.tensor([[0.7, 0.3]])) targets = torch.tensor([0]) data = data.repeat([3, 1]) targets = targets.repeat([3]) n = len(data) activations = {} hess = defaultdict(float) def save_activations(layer, a, _): activations[layer] = a for i in range(n): with autograd_lib.module_hook(save_activations): data_batch = data[i:i + 1] targets_batch = targets[i:i + 1] Y = model(data_batch) loss = loss_fn(Y, targets_batch) def compute_hess(layer, _, B): A = activations[layer] BA = torch.einsum("nl,ni->nli", B, A) hess[layer] += torch.einsum('nli,nkj->likj', BA, BA) with autograd_lib.module_hook(compute_hess): autograd_lib.backward_hessian(Y, loss='CrossEntropy') # check against autograd # 0.1459 Y = model(data) loss = loss_fn(Y, targets) hess_autograd = u.hessian(loss, model.layers[0].weight) hess0 = hess[model.layers[0]] / n u.check_equal(hess_autograd, hess0)
def test_symsqrt(): u.seed_random(1) torch.set_default_dtype(torch.float32) mat = torch.reshape(torch.arange(9) + 1, (3, 3)).float() + torch.eye(3) * 5 mat = mat + mat.t() # make symmetric smat = u.symsqrt(mat) u.check_close(mat, smat @ smat.t()) u.check_close(mat, smat @ smat) def randomly_rotate(X): """Randomly rotate d,n data matrix X""" d, n = X.shape z = torch.randn((d, d), dtype=X.dtype) q, r = torch.qr(z) d = torch.diag(r) ph = d / abs(d) rot_mat = q * ph return rot_mat @ X n = 20 d = 10 X = torch.randn((d, n)) # embed in a larger space X = torch.cat([X, torch.zeros_like(X)]) X = randomly_rotate(X) cov = X @ X.t() / n sqrt, rank = u.symsqrt(cov, return_rank=True) assert rank == d assert torch.allclose(sqrt @ sqrt, cov, atol=1e-5) Y = torch.randn((d, n)) Y = torch.cat([Y, torch.zeros_like(X)]) Y = randomly_rotate(X) cov = u.Kron(X @ X.t(), Y @ Y.t()) sqrt, rank = cov.symsqrt(return_rank=True) assert rank == d * d u.check_close(sqrt @ sqrt, cov, rtol=1e-4) X = torch.tensor([[7., 0, 0, 0, 0]]).t() X = randomly_rotate(X) cov = X @ X.t() u.check_close(u.sym_l2_norm(cov), 7 * 7) Y = torch.tensor([[8., 0, 0, 0, 0]]).t() Y = randomly_rotate(Y) cov = u.Kron(X @ X.t(), Y @ Y.t()) u.check_close(cov.sym_l2_norm(), 7 * 7 * 8 * 8)
def test_factored_vs_regular(): """Take simple network, compute values in two different ways, compare.""" u.seed_random(1) gl.project_name = 'test' gl.logdir_base = '/tmp/runs' u.setup_logdir_and_event_writer(run_name=sys._getframe().f_code.co_name) d = 3 n = 3 model: u.SimpleFullyConnected2 = u.SimpleFullyConnected2([d, d], bias=False, nonlin=False) param = model.layers[0].weight # param.data.copy_(torch.eye(d)) #param.data.copy_(torch.arange(9).reshape(3, 3)) # param.data.copy_(torch.zeros(d, d)) # create simple matrix which is not quite symmetric source = 2 * torch.eye(d) source[0, 0] = 3 source[0, 1] = 4 source[1, 0] = -2 data = source.repeat([n, 1]) noise = source.repeat_interleave(n, dim=0) autograd_lib.add_hooks(model) output = model(data) output.backward(retain_graph=True, gradient=noise) loss = u.least_squares(output) autograd_lib.backprop_hess(output, hess_type='LeastSquares', model=model) autograd_lib.compute_grad1(model) autograd_lib.compute_hess(model) autograd_lib.compute_hess(model, method='kron', attr_name='hess2') autograd_lib.compute_stats(model, attr_name='stats_regular', sigma_centering=True) autograd_lib.compute_stats_factored(model, attr_name='stats_factored', sigma_centering=False) stats = param.stats_regular stats_factored = param.stats_factored for name in stats: print(name, stats[name], stats_factored[name])
def test_cross_entropy_hessian_mnist(): u.seed_random(1) data_width = 3 batch_size = 2 d = [data_width**2, 10] o = d[-1] n = batch_size train_steps = 1 model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=False, bias=True) dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True) trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) train_iter = iter(trainloader) loss_fn = torch.nn.CrossEntropyLoss() loss_hessian = u.HessianExactCrossEntropyLoss() gl.token_count = 0 for train_step in range(train_steps): data, targets = next(train_iter) # get gradient values u.clear_backprops(model) model.skip_forward_hooks = False model.skip_backward_hooks = False output = model(data) for bval in loss_hessian(output): output.backward(bval, retain_graph=True) i = 0 layer = model.layers[i] H, Hbias = u.hessian_from_backprops(layer.activations, layer.backprops_list, bias=True) model.skip_forward_hooks = True model.skip_backward_hooks = True # compute Hessian through autograd model.zero_grad() output = model(data) loss = loss_fn(output, targets) H_autograd = u.hessian(loss, layer.weight).reshape(d[i] * d[i + 1], d[i] * d[i + 1]) u.check_close(H, H_autograd) Hbias_autograd = u.hessian(loss, layer.bias) u.check_close(Hbias, Hbias_autograd)
def test_cross_entropy_hessian_tiny(): u.seed_random(1) batch_size = 1 d = [2, 2] o = d[-1] n = batch_size train_steps = 1 model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=True, bias=True) model.layers[0].weight.data.copy_(torch.eye(2)) loss_fn = torch.nn.CrossEntropyLoss() loss_hessian = u.HessianExactCrossEntropyLoss() data = u.to_logits(torch.tensor([[0.7, 0.3]])) targets = torch.tensor([0]) # get gradient values u.clear_backprops(model) model.skip_forward_hooks = False model.skip_backward_hooks = False output = model(data) for bval in loss_hessian(output): output.backward(bval, retain_graph=True) i = 0 layer = model.layers[i] H, Hbias = u.hessian_from_backprops(layer.activations, layer.backprops_list, bias=True) model.skip_forward_hooks = True model.skip_backward_hooks = True # compute Hessian through autograd model.zero_grad() output = model(data) loss = loss_fn(output, targets) H_autograd = u.hessian(loss, layer.weight) u.check_close(H, H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1])) Hbias_autograd = u.hessian(loss, layer.bias) u.check_close(Hbias, Hbias_autograd)
def test_autoencoder_newton(): """Use Newton's method to train autoencoder.""" image_size = 3 batch_size = 64 dataset = u.TinyMNIST(data_width=image_size, targets_width=image_size, dataset_size=batch_size) trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) d = image_size ** 2 # hidden layer size u.seed_random(1) model: u.SimpleModel = u.SimpleFullyConnected([d, d]) model.disable_hooks() optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9) def loss_fn(data, targets): err = data - targets.view(-1, data.shape[1]) assert len(data) == batch_size return torch.sum(err * err) / 2 / len(data) for i in range(10): data, targets = next(iter(trainloader)) optimizer.zero_grad() loss = loss_fn(model(data), targets) if i > 0: assert loss < 1e-9 loss.backward() W = model.layers[0].weight grad = u.tvec(W.grad) loss = loss_fn(model(data), targets) H = u.hessian(loss, W) # for col-major: H = H.transpose(0, 1).transpose(2, 3).reshape(d**2, d**2) H = H.reshape(d ** 2, d ** 2) # For col-major: W1 = u.unvec(u.vec(W) - u.pinv(H) @ grad, d) # W1 = u.untvec(u.tvec(W) - grad @ u.pinv(H), d) W1 = u.untvec(u.tvec(W) - grad @ H.pinverse(), d) W.data.copy_(W1)
def test_full_fisher_multibatch(): torch.set_default_dtype(torch.float64) u.seed_random(1) A, model = create_toy_model() activations = {} def save_activations(layer, a, _): if layer != model.layers[0]: return activations[layer] = a fisher = [0] def compute_fisher(layer, _, B): if layer != model.layers[0]: return A = activations[layer] n = A.shape[0] di = A.shape[1] do = B.shape[1] Jo = torch.einsum("ni,nj->nij", B, A).reshape(n, -1) fisher[0] += torch.einsum('ni,nj->ij', Jo, Jo) for x in A.t(): with autograd_lib.module_hook(save_activations): y = model(x) loss = torch.sum(y * y) / 2 with autograd_lib.module_hook(compute_fisher): loss.backward() # result computed using single step forward prop result0 = torch.tensor( [[5.383625e+06, -3.675000e+03, 4.846250e+06, -6.195000e+04], [-3.675000e+03, 1.102500e+04, -6.195000e+04, 1.858500e+05], [4.846250e+06, -6.195000e+04, 4.674500e+06, -1.044300e+06], [-6.195000e+04, 1.858500e+05, -1.044300e+06, 3.132900e+06]]) u.check_close(fisher[0], result0)
def test_full_fisher(): u.seed_random(1) A, model = create_toy_model() activations = {} def save_activations(layer, a, _): if layer != model.layers[0]: return activations[layer] = a with autograd_lib.module_hook(save_activations): Y = model(A.t()) loss = torch.sum(Y * Y) / 2 fisher = [0] def compute_fisher(layer, _, B): if layer != model.layers[0]: return A = activations[layer] n = A.shape[0] di = A.shape[1] do = B.shape[1] Jo = torch.einsum("ni,nj->nij", B, A).reshape(n, -1) fisher[0] += torch.einsum('ni,nj->ij', Jo, Jo) with autograd_lib.module_hook(compute_fisher): loss.backward() result0 = torch.tensor( [[5.383625e+06, -3.675000e+03, 4.846250e+06, -6.195000e+04], [-3.675000e+03, 1.102500e+04, -6.195000e+04, 1.858500e+05], [4.846250e+06, -6.195000e+04, 4.674500e+06, -1.044300e+06], [-6.195000e+04, 1.858500e+05, -1.044300e+06, 3.132900e+06]]) u.check_close(fisher[0], result0)
def test_full_hessian(): u.seed_random(1) A, model = create_toy_model() data = A.t() # data = data.repeat(3, 1) activations = {} hess = defaultdict(float) def save_activations(layer, a, _): activations[layer] = a with autograd_lib.module_hook(save_activations): Y = model(A.t()) loss = torch.sum(Y * Y) / 2 def compute_hess(layer, _, B): A = activations[layer] n = A.shape[0] di = A.shape[1] do = B.shape[1] BA = torch.einsum("nl,ni->nli", B, A) hess[layer] += torch.einsum('nli,nkj->likj', BA, BA) with autograd_lib.module_hook(compute_hess): autograd_lib.backprop_identity(Y, retain_graph=True) # check against autograd hess_autograd = u.hessian(loss, model.layers[0].weight) hess0 = hess[model.layers[0]] u.check_equal(hess_autograd, hess0) # check against manual solution u.check_equal(hess0.reshape(4, 4), [[425, -75, 170, -30], [-75, 225, -30, 90], [170, -30, 680, -120], [-30, 90, -120, 360]])
def test_autoencoder_minimize(): """Minimize autoencoder for a few steps.""" u.seed_random(1) torch.set_default_dtype(torch.float32) data_width = 4 targets_width = 2 batch_size = 64 dataset = u.TinyMNIST(data_width=data_width, targets_width=targets_width, dataset_size=batch_size) trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) d1 = data_width ** 2 d2 = 10 d3 = targets_width ** 2 model: u.SimpleModel = u.SimpleFullyConnected([d1, d2, d3], nonlin=True) model.disable_hooks() optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9) def loss_fn(data, targets): err = data - targets.view(-1, data.shape[1]) assert len(data) == batch_size return torch.sum(err * err) / 2 / len(data) loss = 0 for i in range(10): data, targets = next(iter(trainloader)) optimizer.zero_grad() loss = loss_fn(model(data), targets) if i == 0: assert loss > 0.054 pass loss.backward() optimizer.step() assert loss < 0.0398
def test_main(): parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', help='input batch size for testing (default: 1000)') parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs to train (default: 10)') parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate (default: 0.01)') parser.add_argument('--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)') parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument( '--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status') parser.add_argument('--save-model', action='store_true', default=False, help='For Saving the current Model') parser.add_argument('--wandb', type=int, default=1, help='log to weights and biases') parser.add_argument('--autograd_check', type=int, default=0, help='autograd correctness checks') parser.add_argument('--logdir', type=str, default='/temp/runs/curv_train_tiny/run') parser.add_argument('--train_batch_size', type=int, default=100) parser.add_argument('--stats_batch_size', type=int, default=60000) parser.add_argument('--dataset_size', type=int, default=60000) parser.add_argument('--train_steps', type=int, default=100, help="this many train steps between stat collection") parser.add_argument('--stats_steps', type=int, default=1000000, help="total number of curvature stats collections") parser.add_argument('--nonlin', type=int, default=1, help="whether to add ReLU nonlinearity between layers") parser.add_argument('--method', type=str, choices=['gradient', 'newton'], default='gradient', help="descent method, newton or gradient") parser.add_argument('--layer', type=int, default=-1, help="restrict updates to this layer") parser.add_argument('--data_width', type=int, default=28) parser.add_argument('--targets_width', type=int, default=28) parser.add_argument('--lmb', type=float, default=1e-3) parser.add_argument( '--hess_samples', type=int, default=1, help='number of samples when sub-sampling outputs, 0 for exact hessian' ) parser.add_argument('--hess_kfac', type=int, default=0, help='whether to use KFAC approximation for hessian') parser.add_argument('--compute_rho', type=int, default=1, help='use expensive method to compute rho') parser.add_argument('--skip_stats', type=int, default=0, help='skip all stats collection') parser.add_argument('--full_batch', type=int, default=0, help='do stats on the whole dataset') parser.add_argument('--weight_decay', type=float, default=1e-4) #args = parser.parse_args() args = AttrDict() args.lmb = 1e-3 args.compute_rho = 1 args.weight_decay = 1e-4 args.method = 'gradient' args.logdir = '/tmp' args.data_width = 2 args.targets_width = 2 args.train_batch_size = 10 args.full_batch = False args.skip_stats = False args.autograd_check = False u.seed_random(1) logdir = u.create_local_logdir(args.logdir) run_name = os.path.basename(logdir) #gl.event_writer = SummaryWriter(logdir) gl.event_writer = u.NoOp() # print(f"Logging to {run_name}") # small values for debugging # loss_type = 'LeastSquares' loss_type = 'CrossEntropy' args.wandb = 0 args.stats_steps = 10 args.train_steps = 10 args.stats_batch_size = 10 args.data_width = 2 args.targets_width = 2 args.nonlin = False d1 = args.data_width**2 d2 = 2 
d3 = args.targets_width**2 d1 = args.data_width**2 assert args.data_width == args.targets_width o = d1 n = args.stats_batch_size d = [d1, 30, 30, 30, 20, 30, 30, 30, d1] if loss_type == 'CrossEntropy': d3 = 10 o = d3 n = args.stats_batch_size d = [d1, d2, d3] dsize = max(args.train_batch_size, args.stats_batch_size) + 1 model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin) model = model.to(gl.device) try: # os.environ['WANDB_SILENT'] = 'true' if args.wandb: wandb.init(project='curv_train_tiny', name=run_name) wandb.tensorboard.patch(tensorboardX=False) wandb.config['train_batch'] = args.train_batch_size wandb.config['stats_batch'] = args.stats_batch_size wandb.config['method'] = args.method wandb.config['n'] = n except Exception as e: print(f"wandb crash with {e}") # optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9) optimizer = torch.optim.Adam( model.parameters(), lr=0.03) # make 10x smaller for least-squares loss dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, dataset_size=dsize, original_targets=True) train_loader = torch.utils.data.DataLoader( dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True) train_iter = u.infinite_iter(train_loader) stats_iter = None if not args.full_batch: stats_loader = torch.utils.data.DataLoader( dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True) stats_iter = u.infinite_iter(stats_loader) test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, train=False, dataset_size=dsize, original_targets=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True) test_iter = u.infinite_iter(test_loader) if loss_type == 'LeastSquares': loss_fn = u.least_squares elif loss_type == 'CrossEntropy': loss_fn = nn.CrossEntropyLoss() autograd_lib.add_hooks(model) gl.token_count = 0 last_outer = 0 val_losses = [] for step in range(args.stats_steps): if last_outer: u.log_scalars( {"time/outer": 1000 * (time.perf_counter() - last_outer)}) last_outer = time.perf_counter() with u.timeit("val_loss"): test_data, test_targets = next(test_iter) test_output = model(test_data) val_loss = loss_fn(test_output, test_targets) # print("val_loss", val_loss.item()) val_losses.append(val_loss.item()) u.log_scalar(val_loss=val_loss.item()) # compute stats if args.full_batch: data, targets = dataset.data, dataset.targets else: data, targets = next(stats_iter) # Capture Hessian and gradient stats autograd_lib.enable_hooks() autograd_lib.clear_backprops(model) autograd_lib.clear_hess_backprops(model) with u.timeit("backprop_g"): output = model(data) loss = loss_fn(output, targets) loss.backward(retain_graph=True) with u.timeit("backprop_H"): autograd_lib.backprop_hess(output, hess_type=loss_type) autograd_lib.disable_hooks() # TODO(y): use remove_hooks with u.timeit("compute_grad1"): autograd_lib.compute_grad1(model) with u.timeit("compute_hess"): autograd_lib.compute_hess(model) for (i, layer) in enumerate(model.layers): # input/output layers are unreasonably expensive if not using Kronecker factoring if d[i] > 50 or d[i + 1] > 50: print( f'layer {i} is too big ({d[i], d[i + 1]}), skipping stats') continue if args.skip_stats: continue s = AttrDefault(str, {}) # dictionary-like object for layer stats ############################# # Gradient stats ############################# A_t = layer.activations assert A_t.shape == (n, d[i]) # add factor of n because backprop takes loss averaged over batch, 
while we need per-example loss B_t = layer.backprops_list[0] * n assert B_t.shape == (n, d[i + 1]) with u.timeit(f"khatri_g-{i}"): G = u.khatri_rao_t(B_t, A_t) # batch loss Jacobian assert G.shape == (n, d[i] * d[i + 1]) g = G.sum(dim=0, keepdim=True) / n # average gradient assert g.shape == (1, d[i] * d[i + 1]) u.check_equal(G.reshape(layer.weight.grad1.shape), layer.weight.grad1) if args.autograd_check: u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad) u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad) s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel( ) # proportion of activations that are zero s.mean_activation = torch.mean(A_t) s.mean_backprop = torch.mean(B_t) # empirical Fisher with u.timeit(f'sigma-{i}'): efisher = G.t() @ G / n sigma = efisher - g.t() @ g s.sigma_l2 = u.sym_l2_norm(sigma) s.sigma_erank = torch.trace(sigma) / s.sigma_l2 lambda_regularizer = args.lmb * torch.eye(d[i + 1] * d[i]).to( gl.device) H = layer.weight.hess with u.timeit(f"invH-{i}"): invH = torch.cholesky_inverse(H + lambda_regularizer) with u.timeit(f"H_l2-{i}"): s.H_l2 = u.sym_l2_norm(H) s.iH_l2 = u.sym_l2_norm(invH) with u.timeit(f"norms-{i}"): s.H_fro = H.flatten().norm() s.iH_fro = invH.flatten().norm() s.grad_fro = g.flatten().norm() s.param_fro = layer.weight.data.flatten().norm() u.nan_check(H) if args.autograd_check: model.zero_grad() output = model(data) loss = loss_fn(output, targets) H_autograd = u.hessian(loss, layer.weight) H_autograd = H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1]) u.check_close(H, H_autograd) # u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}') def loss_direction(dd: torch.Tensor, eps): """loss improvement if we take step eps in direction dd""" return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps**2 * dd @ H @ dd.t()) def curv_direction(dd: torch.Tensor): """Curvature in direction dd""" return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm()**2)) with u.timeit(f"pinvH-{i}"): pinvH = H.pinverse() with u.timeit(f'curv-{i}'): s.grad_curv = curv_direction(g) ndir = g @ pinvH # newton direction s.newton_curv = curv_direction(ndir) setattr(layer.weight, 'pre', pinvH) # save Newton preconditioner s.step_openai = s.grad_fro**2 / s.grad_curv if s.grad_curv else 999 s.step_max = 2 / s.H_l2 s.step_min = torch.tensor(2) / torch.trace(H) s.newton_fro = ndir.flatten().norm( ) # frobenius norm of Newton update s.regret_newton = u.to_python_scalar( g @ pinvH @ g.t() / 2) # replace with "quadratic_form" s.regret_gradient = loss_direction(g, s.step_openai) with u.timeit(f'rho-{i}'): p_sigma = u.lyapunov_spectral(H, sigma) discrepancy = torch.max(abs(p_sigma - p_sigma.t()) / p_sigma) s.psigma_erank = u.sym_erank(p_sigma) s.rho = H.shape[0] / s.psigma_erank with u.timeit(f"batch-{i}"): s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t()) s.diversity = torch.norm(G, "fro")**2 / torch.norm(g)**2 / n # Faster approaches for noise variance computation # s.noise_variance = torch.trace(H.inverse() @ sigma) # try: # # this fails with singular sigma # s.noise_variance = torch.trace(torch.solve(sigma, H)[0]) # # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0]) # pass # except RuntimeError as _: s.noise_variance_pinv = torch.trace(pinvH @ sigma) s.H_erank = torch.trace(H) / s.H_l2 s.batch_jain_simple = 1 + s.H_erank s.batch_jain_full = 1 + s.rho * s.H_erank u.log_scalars(u.nest_stats(layer.name, s)) # gradient steps with u.timeit('inner'): for i in range(args.train_steps): optimizer.zero_grad() data, targets = next(train_iter) model.zero_grad() 
output = model(data) loss = loss_fn(output, targets) loss.backward() # u.log_scalar(train_loss=loss.item()) if args.method != 'newton': optimizer.step() if args.weight_decay: for group in optimizer.param_groups: for param in group['params']: param.data.mul_(1 - args.weight_decay) else: for (layer_idx, layer) in enumerate(model.layers): param: torch.nn.Parameter = layer.weight param_data: torch.Tensor = param.data param_data.copy_(param_data - 0.1 * param.grad) if layer_idx != 1: # only update 1 layer with Newton, unstable otherwise continue u.nan_check(layer.weight.pre) u.nan_check(param.grad.flatten()) u.nan_check( u.v2r(param.grad.flatten()) @ layer.weight.pre) param_new_flat = u.v2r(param_data.flatten()) - u.v2r( param.grad.flatten()) @ layer.weight.pre u.nan_check(param_new_flat) param_data.copy_( param_new_flat.reshape(param_data.shape)) gl.token_count += data.shape[0] gl.event_writer.close() assert val_losses[0] > 2.4 # 2.4828238487243652 assert val_losses[-1] < 2.25 # 2.20609712600708
def main(): u.seed_random(1) logdir = u.create_local_logdir(args.logdir) run_name = os.path.basename(logdir) gl.event_writer = SummaryWriter(logdir) print(f"Logging to {run_name}") d1 = args.data_width ** 2 assert args.data_width == args.targets_width o = d1 n = args.stats_batch_size d = [d1, 30, 30, 30, 20, 30, 30, 30, d1] # small values for debugging # loss_type = 'LeastSquares' loss_type = 'CrossEntropy' args.wandb = 0 args.stats_steps = 10 args.train_steps = 10 args.stats_batch_size = 10 args.data_width = 2 args.targets_width = 2 args.nonlin = False d1 = args.data_width ** 2 d2 = 2 d3 = args.targets_width ** 2 if loss_type == 'CrossEntropy': d3 = 10 o = d3 n = args.stats_batch_size d = [d1, d2, d3] dsize = max(args.train_batch_size, args.stats_batch_size)+1 model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin) model = model.to(gl.device) try: # os.environ['WANDB_SILENT'] = 'true' if args.wandb: wandb.init(project='curv_train_tiny', name=run_name) wandb.tensorboard.patch(tensorboardX=False) wandb.config['train_batch'] = args.train_batch_size wandb.config['stats_batch'] = args.stats_batch_size wandb.config['method'] = args.method wandb.config['n'] = n except Exception as e: print(f"wandb crash with {e}") #optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9) optimizer = torch.optim.Adam(model.parameters(), lr=0.03) # make 10x smaller for least-squares loss dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, dataset_size=dsize, original_targets=True) train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True) train_iter = u.infinite_iter(train_loader) stats_iter = None if not args.full_batch: stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True) stats_iter = u.infinite_iter(stats_loader) test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, train=False, dataset_size=dsize, original_targets=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True) test_iter = u.infinite_iter(test_loader) if loss_type == 'LeastSquares': loss_fn = u.least_squares elif loss_type == 'CrossEntropy': loss_fn = nn.CrossEntropyLoss() autograd_lib.add_hooks(model) gl.token_count = 0 last_outer = 0 val_losses = [] for step in range(args.stats_steps): if last_outer: u.log_scalars({"time/outer": 1000*(time.perf_counter() - last_outer)}) last_outer = time.perf_counter() with u.timeit("val_loss"): test_data, test_targets = next(test_iter) test_output = model(test_data) val_loss = loss_fn(test_output, test_targets) print("val_loss", val_loss.item()) val_losses.append(val_loss.item()) u.log_scalar(val_loss=val_loss.item()) # compute stats if args.full_batch: data, targets = dataset.data, dataset.targets else: data, targets = next(stats_iter) # Capture Hessian and gradient stats autograd_lib.enable_hooks() autograd_lib.clear_backprops(model) autograd_lib.clear_hess_backprops(model) with u.timeit("backprop_g"): output = model(data) loss = loss_fn(output, targets) loss.backward(retain_graph=True) with u.timeit("backprop_H"): autograd_lib.backprop_hess(output, hess_type=loss_type) autograd_lib.disable_hooks() # TODO(y): use remove_hooks with u.timeit("compute_grad1"): autograd_lib.compute_grad1(model) with u.timeit("compute_hess"): autograd_lib.compute_hess(model) for (i, layer) in enumerate(model.layers): # input/output layers are unreasonably 
expensive if not using Kronecker factoring if d[i]>50 or d[i+1]>50: print(f'layer {i} is too big ({d[i],d[i+1]}), skipping stats') continue if args.skip_stats: continue s = AttrDefault(str, {}) # dictionary-like object for layer stats ############################# # Gradient stats ############################# A_t = layer.activations assert A_t.shape == (n, d[i]) # add factor of n because backprop takes loss averaged over batch, while we need per-example loss B_t = layer.backprops_list[0] * n assert B_t.shape == (n, d[i + 1]) with u.timeit(f"khatri_g-{i}"): G = u.khatri_rao_t(B_t, A_t) # batch loss Jacobian assert G.shape == (n, d[i] * d[i + 1]) g = G.sum(dim=0, keepdim=True) / n # average gradient assert g.shape == (1, d[i] * d[i + 1]) u.check_equal(G.reshape(layer.weight.grad1.shape), layer.weight.grad1) if args.autograd_check: u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad) u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad) s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel() # proportion of activations that are zero s.mean_activation = torch.mean(A_t) s.mean_backprop = torch.mean(B_t) # empirical Fisher with u.timeit(f'sigma-{i}'): efisher = G.t() @ G / n sigma = efisher - g.t() @ g s.sigma_l2 = u.sym_l2_norm(sigma) s.sigma_erank = torch.trace(sigma)/s.sigma_l2 lambda_regularizer = args.lmb * torch.eye(d[i + 1]*d[i]).to(gl.device) H = layer.weight.hess with u.timeit(f"invH-{i}"): invH = torch.cholesky_inverse(H+lambda_regularizer) with u.timeit(f"H_l2-{i}"): s.H_l2 = u.sym_l2_norm(H) s.iH_l2 = u.sym_l2_norm(invH) with u.timeit(f"norms-{i}"): s.H_fro = H.flatten().norm() s.iH_fro = invH.flatten().norm() s.grad_fro = g.flatten().norm() s.param_fro = layer.weight.data.flatten().norm() u.nan_check(H) if args.autograd_check: model.zero_grad() output = model(data) loss = loss_fn(output, targets) H_autograd = u.hessian(loss, layer.weight) H_autograd = H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1]) u.check_close(H, H_autograd) # u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}') def loss_direction(dd: torch.Tensor, eps): """loss improvement if we take step eps in direction dd""" return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps ** 2 * dd @ H @ dd.t()) def curv_direction(dd: torch.Tensor): """Curvature in direction dd""" return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm() ** 2)) with u.timeit(f"pinvH-{i}"): pinvH = u.pinv(H) with u.timeit(f'curv-{i}'): s.grad_curv = curv_direction(g) ndir = g @ pinvH # newton direction s.newton_curv = curv_direction(ndir) setattr(layer.weight, 'pre', pinvH) # save Newton preconditioner s.step_openai = s.grad_fro**2 / s.grad_curv if s.grad_curv else 999 s.step_max = 2 / s.H_l2 s.step_min = torch.tensor(2) / torch.trace(H) s.newton_fro = ndir.flatten().norm() # frobenius norm of Newton update s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2) # replace with "quadratic_form" s.regret_gradient = loss_direction(g, s.step_openai) with u.timeit(f'rho-{i}'): p_sigma = u.lyapunov_svd(H, sigma) if u.has_nan(p_sigma) and args.compute_rho: # use expensive method print('using expensive method') import pdb; pdb.set_trace() H0, sigma0 = u.to_numpys(H, sigma) p_sigma = scipy.linalg.solve_lyapunov(H0, sigma0) p_sigma = torch.tensor(p_sigma).to(gl.device) if u.has_nan(p_sigma): # import pdb; pdb.set_trace() s.psigma_erank = H.shape[0] s.rho = 1 else: s.psigma_erank = u.sym_erank(p_sigma) s.rho = H.shape[0] / s.psigma_erank with u.timeit(f"batch-{i}"): s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t()) 
s.diversity = torch.norm(G, "fro") ** 2 / torch.norm(g) ** 2 / n # Faster approaches for noise variance computation # s.noise_variance = torch.trace(H.inverse() @ sigma) # try: # # this fails with singular sigma # s.noise_variance = torch.trace(torch.solve(sigma, H)[0]) # # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0]) # pass # except RuntimeError as _: s.noise_variance_pinv = torch.trace(pinvH @ sigma) s.H_erank = torch.trace(H) / s.H_l2 s.batch_jain_simple = 1 + s.H_erank s.batch_jain_full = 1 + s.rho * s.H_erank u.log_scalars(u.nest_stats(layer.name, s)) # gradient steps with u.timeit('inner'): for i in range(args.train_steps): optimizer.zero_grad() data, targets = next(train_iter) model.zero_grad() output = model(data) loss = loss_fn(output, targets) loss.backward() # u.log_scalar(train_loss=loss.item()) if args.method != 'newton': optimizer.step() if args.weight_decay: for group in optimizer.param_groups: for param in group['params']: param.data.mul_(1-args.weight_decay) else: for (layer_idx, layer) in enumerate(model.layers): param: torch.nn.Parameter = layer.weight param_data: torch.Tensor = param.data param_data.copy_(param_data - 0.1 * param.grad) if layer_idx != 1: # only update 1 layer with Newton, unstable otherwise continue u.nan_check(layer.weight.pre) u.nan_check(param.grad.flatten()) u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre) param_new_flat = u.v2r(param_data.flatten()) - u.v2r(param.grad.flatten()) @ layer.weight.pre u.nan_check(param_new_flat) param_data.copy_(param_new_flat.reshape(param_data.shape)) gl.token_count += data.shape[0] gl.event_writer.close()
def main(): attemp_count = 0 while os.path.exists(f"{args.logdir}{attemp_count:02d}"): attemp_count += 1 logdir = f"{args.logdir}{attemp_count:02d}" run_name = os.path.basename(logdir) gl.event_writer = SummaryWriter(logdir) print(f"Logging to {run_name}") u.seed_random(1) try: # os.environ['WANDB_SILENT'] = 'true' if args.wandb: wandb.init(project='curv_train_tiny', name=run_name) wandb.tensorboard.patch(tensorboardX=False) wandb.config['train_batch'] = args.train_batch_size wandb.config['stats_batch'] = args.stats_batch_size wandb.config['method'] = args.method except Exception as e: print(f"wandb crash with {e}") # data_width = 4 # targets_width = 2 d1 = args.data_width**2 d2 = 10 d3 = args.targets_width**2 o = d3 n = args.stats_batch_size d = [d1, d2, d3] model = u.SimpleFullyConnected(d, nonlin=args.nonlin) optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, dataset_size=args.dataset_size) train_loader = torch.utils.data.DataLoader( dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True) train_iter = u.infinite_iter(train_loader) stats_loader = torch.utils.data.DataLoader( dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True) stats_iter = u.infinite_iter(stats_loader) def capture_activations(module, input, _output): if skip_forward_hooks: return assert gl.backward_idx == 0 # no need to forward-prop on Hessian computation assert not hasattr( module, 'activations' ), "Seeing activations from previous forward, call util.zero_grad to clear" assert len(input) == 1, "this works for single input layers only" setattr(module, "activations", input[0].detach()) def capture_backprops(module: nn.Module, _input, output): if skip_backward_hooks: return assert len(output) == 1, "this works for single variable layers only" if gl.backward_idx == 0: assert not hasattr( module, 'backprops' ), "Seeing results of previous autograd, call util.zero_grad to clear" setattr(module, 'backprops', []) assert gl.backward_idx == len(module.backprops) module.backprops.append(output[0]) def save_grad(param: nn.Parameter) -> Callable[[torch.Tensor], None]: """Hook to save gradient into 'param.saved_grad', so it can be accessed after model.zero_grad(). 
Only stores gradient if the value has not been set, call util.zero_grad to clear it.""" def save_grad_fn(grad): if not hasattr(param, 'saved_grad'): setattr(param, 'saved_grad', grad) return save_grad_fn for layer in model.layers: layer.register_forward_hook(capture_activations) layer.register_backward_hook(capture_backprops) layer.weight.register_hook(save_grad(layer.weight)) def loss_fn(data, targets): err = data - targets.view(-1, data.shape[1]) assert len(data) == len(targets) return torch.sum(err * err) / 2 / len(data) gl.token_count = 0 for step in range(args.stats_steps): data, targets = next(stats_iter) skip_forward_hooks = False skip_backward_hooks = False # get gradient values gl.backward_idx = 0 u.zero_grad(model) output = model(data) loss = loss_fn(output, targets) loss.backward(retain_graph=True) print("loss", loss.item()) # get Hessian values skip_forward_hooks = True id_mat = torch.eye(o) u.log_scalars({'loss': loss.item()}) # o = 0 for out_idx in range(o): model.zero_grad() # backprop to get section of batch output jacobian for output at position out_idx output = model( data ) # opt: using autograd.grad means I don't have to zero_grad ei = id_mat[out_idx] bval = torch.stack([ei] * n) gl.backward_idx = out_idx + 1 output.backward(bval) skip_backward_hooks = True # for (i, layer) in enumerate(model.layers): s = AttrDefault(str, {}) # dictionary-like object for layer stats ############################# # Gradient stats ############################# A_t = layer.activations assert A_t.shape == (n, d[i]) # add factor of n because backprop takes loss averaged over batch, while we need per-example loss B_t = layer.backprops[0] * n assert B_t.shape == (n, d[i + 1]) G = u.khatri_rao_t(B_t, A_t) # batch loss Jacobian assert G.shape == (n, d[i] * d[i + 1]) g = G.sum(dim=0, keepdim=True) / n # average gradient assert g.shape == (1, d[i] * d[i + 1]) if args.autograd_check: u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad) u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad) # empirical Fisher efisher = G.t() @ G / n sigma = efisher - g.t() @ g # u.dump(sigma, f'/tmp/sigmas/{step}-{i}') s.sigma_l2 = u.l2_norm(sigma) ############################# # Hessian stats ############################# A_t = layer.activations Bh_t = [layer.backprops[out_idx + 1] for out_idx in range(o)] Amat_t = torch.cat([A_t] * o, dim=0) Bmat_t = torch.cat(Bh_t, dim=0) assert Amat_t.shape == (n * o, d[i]) assert Bmat_t.shape == (n * o, d[i + 1]) Jb = u.khatri_rao_t(Bmat_t, Amat_t) # batch Jacobian, in row-vec format H = Jb.t() @ Jb / n pinvH = u.pinv(H) s.hess_l2 = u.l2_norm(H) s.invhess_l2 = u.l2_norm(pinvH) s.hess_fro = H.flatten().norm() s.invhess_fro = pinvH.flatten().norm() s.jacobian_l2 = u.l2_norm(Jb) s.grad_fro = g.flatten().norm() s.param_fro = layer.weight.data.flatten().norm() u.nan_check(H) if args.autograd_check: model.zero_grad() output = model(data) loss = loss_fn(output, targets) H_autograd = u.hessian(loss, layer.weight) H_autograd = H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1]) u.check_close(H, H_autograd) # u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}') def loss_direction(dd: torch.Tensor, eps): """loss improvement if we take step eps in direction dd""" return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps**2 * dd @ H @ dd.t()) def curv_direction(dd: torch.Tensor): """Curvature in direction dd""" return u.to_python_scalar(dd @ H @ dd.t() / dd.flatten().norm()**2) s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2) s.grad_curv = curv_direction(g) ndir = g @ 
u.pinv(H) # newton direction s.newton_curv = curv_direction(ndir) setattr(layer.weight, 'pre', u.pinv(H)) # save Newton preconditioner s.step_openai = 1 / s.grad_curv if s.grad_curv else 999 s.newton_fro = ndir.flatten().norm( ) # frobenius norm of Newton update s.regret_gradient = loss_direction(g, s.step_openai) u.log_scalars(u.nest_stats(layer.name, s)) # gradient steps for i in range(args.train_steps): optimizer.zero_grad() data, targets = next(train_iter) model.zero_grad() output = model(data) loss = loss_fn(output, targets) loss.backward() u.log_scalar(train_loss=loss.item()) if args.method != 'newton': optimizer.step() else: for (layer_idx, layer) in enumerate(model.layers): param: torch.nn.Parameter = layer.weight param_data: torch.Tensor = param.data param_data.copy_(param_data - 0.1 * param.grad) if layer_idx != 1: # only update 1 layer with Newton, unstable otherwise continue u.nan_check(layer.weight.pre) u.nan_check(param.grad.flatten()) u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre) param_new_flat = u.v2r(param_data.flatten()) - u.v2r( param.grad.flatten()) @ layer.weight.pre u.nan_check(param_new_flat) param_data.copy_(param_new_flat.reshape(param_data.shape)) gl.token_count += data.shape[0] gl.event_writer.close()
def main(): parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser.add_argument('--batch-size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)') parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', help='input batch size for testing (default: 1000)') parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs to train (default: 10)') parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument( '--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status') parser.add_argument('--save-model', action='store_true', default=False, help='For Saving the current Model') parser.add_argument('--wandb', type=int, default=0, help='log to weights and biases') parser.add_argument('--autograd_check', type=int, default=0, help='autograd correctness checks') parser.add_argument('--logdir', type=str, default='/tmp/runs/curv_train_tiny/run') parser.add_argument('--nonlin', type=int, default=1, help="whether to add ReLU nonlinearity between layers") parser.add_argument('--bias', type=int, default=1, help="whether to add bias between layers") parser.add_argument('--layer', type=int, default=-1, help="restrict updates to this layer") parser.add_argument('--data_width', type=int, default=28) parser.add_argument('--targets_width', type=int, default=28) parser.add_argument( '--hess_samples', type=int, default=1, help='number of samples when sub-sampling outputs, 0 for exact hessian' ) parser.add_argument('--hess_kfac', type=int, default=0, help='whether to use KFAC approximation for hessian') parser.add_argument('--compute_rho', type=int, default=0, help='use expensive method to compute rho') parser.add_argument('--skip_stats', type=int, default=1, help='skip all stats collection') parser.add_argument('--dataset_size', type=int, default=60000) parser.add_argument('--train_steps', type=int, default=1000, help="this many train steps between stat collection") parser.add_argument('--stats_steps', type=int, default=1000000, help="total number of curvature stats collections") parser.add_argument('--full_batch', type=int, default=0, help='do stats on the whole dataset') parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--weight_decay', type=float, default=2e-5) parser.add_argument('--momentum', type=float, default=0.9) parser.add_argument('--dropout', type=int, default=0) parser.add_argument('--swa', type=int, default=0) parser.add_argument('--lmb', type=float, default=1e-3) parser.add_argument('--train_batch_size', type=int, default=64) parser.add_argument('--stats_batch_size', type=int, default=10000) parser.add_argument('--uniform', type=int, default=0, help='use uniform architecture (all layers same size)') parser.add_argument('--run_name', type=str, default='noname') gl.args = parser.parse_args() args = gl.args u.seed_random(1) gl.project_name = 'train_ciresan' u.setup_logdir_and_event_writer(args.run_name) print(f"Logging to {gl.logdir}") d1 = 28 * 28 if args.uniform: d = [784, 784, 784, 784, 784, 784, 10] else: d = [784, 2500, 2000, 1500, 1000, 500, 10] o = 10 n = args.stats_batch_size model = u.SimpleFullyConnected2(d, nonlin=args.nonlin, bias=args.bias, dropout=args.dropout) model = model.to(gl.device) optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, 
momentum=args.momentum) dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, original_targets=True, dataset_size=args.dataset_size) train_loader = torch.utils.data.DataLoader( dataset, batch_size=args.train_batch_size, shuffle=True, drop_last=True) train_iter = u.infinite_iter(train_loader) assert not args.full_batch, "fixme: validation still uses stats_iter" if not args.full_batch: stats_loader = torch.utils.data.DataLoader( dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True) stats_iter = u.infinite_iter(stats_loader) else: stats_iter = None test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, train=False, original_targets=True, dataset_size=args.dataset_size) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=False) loss_fn = torch.nn.CrossEntropyLoss() autograd_lib.add_hooks(model) autograd_lib.disable_hooks() gl.token_count = 0 last_outer = 0 for step in range(args.stats_steps): epoch = gl.token_count // 60000 print(gl.token_count) if last_outer: u.log_scalars( {"time/outer": 1000 * (time.perf_counter() - last_outer)}) last_outer = time.perf_counter() # compute validation loss if args.swa: model.eval() with u.timeit('swa'): base_opt = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) opt = torchcontrib.optim.SWA(base_opt, swa_start=0, swa_freq=1, swa_lr=args.lr) for _ in range(100): optimizer.zero_grad() data, targets = next(train_iter) model.zero_grad() output = model(data) loss = loss_fn(output, targets) loss.backward() opt.step() opt.swap_swa_sgd() with u.timeit("validate"): val_accuracy, val_loss = validate(model, test_loader, f'test (epoch {epoch})') train_accuracy, train_loss = validate(model, stats_loader, f'train (epoch {epoch})') # save log metrics = { 'epoch': epoch, 'val_accuracy': val_accuracy, 'val_loss': val_loss, 'train_loss': train_loss, 'train_accuracy': train_accuracy, 'lr': optimizer.param_groups[0]['lr'], 'momentum': optimizer.param_groups[0].get('momentum', 0) } u.log_scalars(metrics) # compute stats if args.full_batch: data, targets = dataset.data, dataset.targets else: data, targets = next(stats_iter) if not args.skip_stats: autograd_lib.enable_hooks() autograd_lib.clear_backprops(model) autograd_lib.clear_hess_backprops(model) with u.timeit("backprop_g"): output = model(data) loss = loss_fn(output, targets) loss.backward(retain_graph=True) with u.timeit("backprop_H"): autograd_lib.backprop_hess(output, hess_type='CrossEntropy') autograd_lib.disable_hooks() # TODO(y): use remove_hooks with u.timeit("compute_grad1"): autograd_lib.compute_grad1(model) with u.timeit("compute_hess"): autograd_lib.compute_hess(model, method='kron', attr_name='hess2') autograd_lib.compute_stats_factored(model) for (i, layer) in enumerate(model.layers): param_names = {layer.weight: "weight", layer.bias: "bias"} for param in [layer.weight, layer.bias]: if param is None: continue if not hasattr(param, 'stats'): continue s = param.stats param_name = param_names[param] u.log_scalars(u.nest_stats(f"{param_name}", s)) # gradient steps model.train() last_inner = 0 for i in range(args.train_steps): if last_inner: u.log_scalars( {"time/inner": 1000 * (time.perf_counter() - last_inner)}) last_inner = time.perf_counter() optimizer.zero_grad() data, targets = next(train_iter) model.zero_grad() output = model(data) loss = loss_fn(output, targets) loss.backward() optimizer.step() if args.weight_decay: for group in 
optimizer.param_groups: for param in group['params']: param.data.mul_(1 - args.weight_decay) gl.token_count += data.shape[0] gl.event_writer.close()
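# Illustrative sketch (not part of the original scripts): the training loop above applies
# weight decay by shrinking parameters in place after the SGD step
# (param.data.mul_(1 - args.weight_decay)) instead of folding it into the gradient.
# The check below shows this is plain multiplicative shrinkage p <- p - wd * p, applied
# independently of the learning rate; the names here are made up for the example.
def sketch_decoupled_weight_decay():
    import torch
    wd = 2e-5
    p = torch.randn(3)
    shrunk = p * (1.0 - wd)            # what param.data.mul_(1 - wd) does
    assert torch.allclose(shrunk, p - wd * p)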
def main(): u.install_pdb_handler() u.seed_random(1) logdir = u.create_local_logdir(args.logdir) run_name = os.path.basename(logdir) gl.event_writer = SummaryWriter(logdir) print(f"Logging to {logdir}") loss_type = 'CrossEntropy' d1 = args.data_width ** 2 args.stats_batch_size = min(args.stats_batch_size, args.dataset_size) args.train_batch_size = min(args.train_batch_size, args.dataset_size) n = args.stats_batch_size o = 10 d = [d1, 60, 60, 60, o] # dataset_size = args.dataset_size model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin, last_layer_linear=True) model = model.to(gl.device) u.mark_expensive(model.layers[0]) # to stop grad1/hess calculations on this layer print(model) try: if args.wandb: wandb.init(project='curv_train_tiny', name=run_name, dir='/tmp/wandb.runs') wandb.tensorboard.patch(tensorboardX=False) wandb.config['train_batch'] = args.train_batch_size wandb.config['stats_batch'] = args.stats_batch_size wandb.config['n'] = n except Exception as e: print(f"wandb crash with {e}") optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9) # optimizer = torch.optim.Adam(model.parameters(), lr=0.03) # make 10x smaller for least-squares loss dataset = u.TinyMNIST(data_width=args.data_width, dataset_size=args.dataset_size, loss_type=loss_type) train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True) train_iter = u.infinite_iter(train_loader) stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True) stats_iter = u.infinite_iter(stats_loader) stats_data, stats_targets = next(stats_iter) test_dataset = u.TinyMNIST(data_width=args.data_width, train=False, dataset_size=args.dataset_size, loss_type=loss_type) test_batch_size = min(args.dataset_size, 1000) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, drop_last=True) test_iter = u.infinite_iter(test_loader) if loss_type == 'LeastSquares': loss_fn = u.least_squares else: # loss_type == 'CrossEntropy': loss_fn = nn.CrossEntropyLoss() autograd_lib.add_hooks(model) gl.reset_global_step() last_outer = 0 val_losses = [] for step in range(args.stats_steps): if last_outer: u.log_scalars({"time/outer": 1000*(time.perf_counter() - last_outer)}) last_outer = time.perf_counter() with u.timeit("val_loss"): test_data, test_targets = next(test_iter) test_output = model(test_data) val_loss = loss_fn(test_output, test_targets) print("val_loss", val_loss.item()) val_losses.append(val_loss.item()) u.log_scalar(val_loss=val_loss.item()) with u.timeit("validate"): if loss_type == 'CrossEntropy': val_accuracy, val_loss = validate(model, test_loader, f'test (stats_step {step})') # train_accuracy, train_loss = validate(model, train_loader, f'train (stats_step {step})') metrics = {'stats_step': step, 'val_accuracy': val_accuracy, 'val_loss': val_loss} u.log_scalars(metrics) data, targets = stats_data, stats_targets if not args.skip_stats: # Capture Hessian and gradient stats autograd_lib.enable_hooks() autograd_lib.clear_backprops(model) autograd_lib.clear_hess_backprops(model) with u.timeit("backprop_g"): output = model(data) loss = loss_fn(output, targets) loss.backward(retain_graph=True) with u.timeit("backprop_H"): autograd_lib.backprop_hess(output, hess_type=loss_type) autograd_lib.disable_hooks() # TODO(y): use remove_hooks with u.timeit("compute_grad1"): autograd_lib.compute_grad1(model) with u.timeit("compute_hess"): autograd_lib.compute_hess(model) for (i, 
layer) in enumerate(model.layers): if hasattr(layer, 'expensive'): continue param_names = {layer.weight: "weight", layer.bias: "bias"} for param in [layer.weight, layer.bias]: # input/output layers are unreasonably expensive if not using Kronecker factoring if d[i]*d[i+1] > 8000: print(f'layer {i} is too big ({d[i],d[i+1]}), skipping stats') continue s = AttrDefault(str, {}) # dictionary-like object for layer stats ############################# # Gradient stats ############################# A_t = layer.activations B_t = layer.backprops_list[0] * n s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel() # proportion of activations that are zero s.mean_activation = torch.mean(A_t) s.mean_backprop = torch.mean(B_t) # empirical Fisher G = param.grad1.reshape((n, -1)) g = G.mean(dim=0, keepdim=True) u.nan_check(G) with u.timeit(f'sigma-{i}'): efisher = G.t() @ G / n sigma = efisher - g.t() @ g # sigma_spectrum = s.sigma_l2 = u.sym_l2_norm(sigma) s.sigma_erank = torch.trace(sigma)/s.sigma_l2 H = param.hess lambda_regularizer = args.lmb * torch.eye(H.shape[0]).to(gl.device) u.nan_check(H) with u.timeit(f"invH-{i}"): invH = torch.cholesky_inverse(torch.cholesky(H + lambda_regularizer)) with u.timeit(f"H_l2-{i}"): s.H_l2 = u.sym_l2_norm(H) s.iH_l2 = u.sym_l2_norm(invH) with u.timeit(f"norms-{i}"): s.H_fro = H.flatten().norm() s.iH_fro = invH.flatten().norm() s.grad_fro = g.flatten().norm() s.param_fro = param.data.flatten().norm() def loss_direction(dd: torch.Tensor, eps): """loss improvement if we take step eps in direction dd""" return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps ** 2 * dd @ H @ dd.t()) def curv_direction(dd: torch.Tensor): """Curvature in direction dd""" return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm() ** 2)) with u.timeit(f"pinvH-{i}"): pinvH = u.pinv(H) with u.timeit(f'curv-{i}'): s.grad_curv = curv_direction(g) # curvature (eigenvalue) in direction g ndir = g @ pinvH # Newton direction s.newton_curv = curv_direction(ndir) setattr(layer.weight, 'pre', pinvH) # save Newton preconditioner s.step_openai = 1 / s.grad_curv if s.grad_curv else 1234567 s.step_div_inf = 2 / s.H_l2 # divergent step size for batch_size=infinity s.step_div_1 = torch.tensor(2) / torch.trace(H) # divergent step for batch_size=1 s.newton_fro = ndir.flatten().norm() # Frobenius norm of Newton update s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2) # replace with "quadratic_form" s.regret_gradient = loss_direction(g, s.step_openai) with u.timeit(f'rho-{i}'): s.rho, s.lyap_erank, lyap_evals = u.truncated_lyapunov_rho(H, sigma) s.step_div_1_adjusted = s.step_div_1/s.rho with u.timeit(f"batch-{i}"): s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t()) s.diversity = torch.norm(G, "fro") ** 2 / torch.norm(g) ** 2 / n # Gradient diversity / n s.noise_variance_pinv = torch.trace(pinvH @ sigma) s.H_erank = torch.trace(H) / s.H_l2 s.batch_jain_simple = 1 + s.H_erank s.batch_jain_full = 1 + s.rho * s.H_erank param_name = f"{layer.name}={param_names[param]}" u.log_scalars(u.nest_stats(f"{param_name}", s)) H_evals = u.symeig_pos_evals(H) sigma_evals = u.symeig_pos_evals(sigma) u.log_spectrum(f'{param_name}/hess', H_evals) u.log_spectrum(f'{param_name}/sigma', sigma_evals) u.log_spectrum(f'{param_name}/lyap', lyap_evals) # gradient steps with u.timeit('inner'): for i in range(args.train_steps): optimizer.zero_grad() data, targets = next(train_iter) model.zero_grad() output = model(data) loss = loss_fn(output, targets) loss.backward() optimizer.step() if args.weight_decay: for group in 
optimizer.param_groups: for param in group['params']: param.data.mul_(1-args.weight_decay) gl.increment_global_step(data.shape[0]) gl.event_writer.close()
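# Illustrative sketch (assumed reading of the stats above): batch_openai is computed as
# trace(H @ sigma) / (g @ H @ g.T), a critical-batch-size style estimate built from the
# Hessian H, gradient covariance sigma, and mean gradient g (a row vector). The toy
# inputs and the function name below are invented for the example.
def sketch_critical_batch_size():
    import torch
    d = 4
    A = torch.randn(d, d)
    H = A @ A.t() + torch.eye(d)       # symmetric positive definite "Hessian"
    sigma = 0.1 * torch.eye(d)         # toy gradient covariance
    g = torch.randn(1, d)              # mean gradient as a row vector
    batch = torch.trace(H @ sigma) / (g @ H @ g.t())
    return batch.item()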
def test_kron_conv_exact(): """Test per-example gradient computation for conv layer. Kronecker factoring is exact for 1x1 convolutions and linear activations. """ u.seed_random(1) n, Xh, Xw = 2, 2, 2 Kh, Kw = 1, 1 dd = [2, 2, 2] o = dd[-1] model: u.SimpleModel = u.PooledConvolutional2(dd, kernel_size=(Kh, Kw), nonlin=False, bias=True) data = torch.randn((n, dd[0], Xh, Xw)) #print(model) #print(data) loss_type = 'CrossEntropy' # loss_type = 'LeastSquares' if loss_type == 'LeastSquares': loss_fn = u.least_squares elif loss_type == 'DebugLeastSquares': loss_fn = u.debug_least_squares else: # CrossEntropy loss_fn = nn.CrossEntropyLoss() sample_output = model(data) if loss_type.endswith('LeastSquares'): targets = torch.randn(sample_output.shape) elif loss_type == 'CrossEntropy': targets = torch.LongTensor(n).random_(0, o) autograd_lib.clear_backprops(model) autograd_lib.add_hooks(model) output = model(data) autograd_lib.backprop_hess(output, hess_type=loss_type) autograd_lib.compute_hess(model, method='mean_kron') autograd_lib.compute_hess(model, method='exact') autograd_lib.disable_hooks() for i in range(len(model.layers)): layer = model.layers[i] # direct Hessian computation H = layer.weight.hess H_bias = layer.bias.hess # factored Hessian computation Hk = layer.weight.hess_factored Hk_bias = layer.bias.hess_factored Hk = Hk.expand() Hk_bias = Hk_bias.expand() # autograd Hessian computation loss = loss_fn(output, targets) Ha = u.hessian(loss, layer.weight).reshape(H.shape) Ha_bias = u.hessian(loss, layer.bias) # compare direct against autograd Ha = Ha.reshape(H.shape) # rel_error = torch.max((H-Ha)/Ha) u.check_close(H, Ha, rtol=1e-5, atol=1e-7) u.check_close(Ha_bias, H_bias, rtol=1e-5, atol=1e-7) u.check_close(H_bias, Hk_bias) u.check_close(H, Hk)
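# Illustrative sketch of the assumption behind test_kron_conv_exact: a 1x1 convolution is
# a linear layer applied independently at every spatial position, which is why Kronecker
# factoring of its Hessian is exact when activations are linear. Shapes below are
# arbitrary example values, not taken from the test.
def sketch_1x1_conv_is_linear():
    import torch
    n, c_in, c_out, h, w = 2, 3, 5, 4, 4
    conv = torch.nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)
    x = torch.randn(n, c_in, h, w)
    W = conv.weight.reshape(c_out, c_in)                 # 1x1 kernel -> plain matrix
    out_lin = torch.einsum('oi,nihw->nohw', W, x)        # per-pixel linear map
    assert torch.allclose(conv(x), out_lin, atol=1e-5)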
def test_main_autograd(): u.seed_random(1) log_wandb = False autograd_check = True use_double = False logdir = u.get_unique_logdir('/tmp/autoencoder_test/run') run_name = os.path.basename(logdir) gl.event_writer = SummaryWriter(logdir) batch_size = 5 try: if log_wandb: wandb.init(project='test-autograd_test', name=run_name) wandb.tensorboard.patch(tensorboardX=False) wandb.config['batch'] = batch_size except Exception as e: print(f"wandb crash with {e}") data_width = 4 targets_width = 2 d1 = data_width ** 2 d2 = 10 d3 = targets_width ** 2 o = d3 n = batch_size d = [d1, d2, d3] model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=True, bias=True) if use_double: model = model.double() train_steps = 3 dataset = u.TinyMNIST(data_width=data_width, targets_width=targets_width, dataset_size=batch_size * train_steps) trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) train_iter = iter(trainloader) def loss_fn(data, targets): err = data - targets.view(-1, data.shape[1]) assert len(data) == batch_size return torch.sum(err * err) / 2 / len(data) loss_hessian = u.HessianExactSqrLoss() gl.token_count = 0 for train_step in range(train_steps): data, targets = next(train_iter) if use_double: data, targets = data.double(), targets.double() # get gradient values model.skip_backward_hooks = False model.skip_forward_hooks = False u.clear_backprops(model) output = model(data) loss = loss_fn(output, targets) loss.backward(retain_graph=True) model.skip_forward_hooks = True output = model(data) for bval in loss_hessian(output): if use_double: bval = bval.double() output.backward(bval, retain_graph=True) model.skip_backward_hooks = True for (i, layer) in enumerate(model.layers): ############################# # Gradient stats ############################# A_t = layer.activations assert A_t.shape == (n, d[i]) # add factor of n because backprop takes loss averaged over batch, while we need per-example loss B_t = layer.backprops_list[0] * n assert B_t.shape == (n, d[i + 1]) # per example gradients G = u.khatri_rao_t(B_t, A_t) assert G.shape == (n, d[i+1] * d[i]) Gbias = B_t assert Gbias.shape == (n, d[i + 1]) # average gradient g = G.sum(dim=0, keepdim=True) / n gb = Gbias.sum(dim=0, keepdim=True) / n assert g.shape == (1, d[i] * d[i + 1]) assert gb.shape == (1, d[i + 1]) if autograd_check: u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad) u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad) u.check_close(torch.einsum('nj->j', B_t) / n, layer.bias.saved_grad) u.check_close(torch.mean(B_t, dim=0), layer.bias.saved_grad) u.check_close(torch.einsum('ni,nj->ij', B_t, A_t)/n, layer.weight.saved_grad) # empirical Fisher efisher = G.t() @ G / n _sigma = efisher - g.t() @ g ############################# # Hessian stats ############################# A_t = layer.activations Bh_t = [layer.backprops_list[out_idx + 1] for out_idx in range(o)] Amat_t = torch.cat([A_t] * o, dim=0) # todo: can instead replace with a khatri-rao loop Bmat_t = torch.cat(Bh_t, dim=0) Amat_t2 = torch.stack([A_t]*o, dim=0) # o, n, in_dim Bmat_t2 = torch.stack(Bh_t, dim=0) # o, n, out_dim assert Amat_t.shape == (n * o, d[i]) assert Bmat_t.shape == (n * o, d[i + 1]) Jb = u.khatri_rao_t(Bmat_t, Amat_t) # batch output Jacobian H = Jb.t() @ Jb / n Jb2 = torch.einsum('oni,onj->onij', Bmat_t2, Amat_t2) u.check_close(H.reshape(d[i+1], d[i], d[i+1], d[i]), torch.einsum('onij,onkl->ijkl', Jb2, Jb2)/n) Hbias = Bmat_t.t() @ Bmat_t / n u.check_close(Hbias, torch.einsum('ni,nj->ij', Bmat_t, Bmat_t) / n) if 
autograd_check: model.zero_grad() output = model(data) loss = loss_fn(output, targets) H_autograd = u.hessian(loss, layer.weight) Hbias_autograd = u.hessian(loss, layer.bias) u.check_close(H, H_autograd.reshape(d[i+1] * d[i], d[i+1] * d[i])) u.check_close(Hbias, Hbias_autograd)
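# Illustrative sketch of the per-example gradient identity used in test_main_autograd:
# for a linear layer, the per-example gradient is the outer product of backprop B[i] and
# activation A[i], and u.khatri_rao_t(B, A) is assumed to stack the flattened outer
# products row by row. Dimensions below are invented for the example.
def sketch_per_example_gradients():
    import torch
    n, d_in, d_out = 3, 4, 2
    A = torch.randn(n, d_in)                             # activations
    B = torch.randn(n, d_out)                            # backprops (already scaled by n)
    G = torch.einsum('no,ni->noi', B, A).reshape(n, d_out * d_in)   # per-example grads
    mean_grad = G.mean(dim=0).reshape(d_out, d_in)
    assert torch.allclose(mean_grad, B.t() @ A / n, atol=1e-6)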
def test_conv_hessian(): """Test per-example gradient computation for conv layer.""" u.seed_random(1) n, Xc, Xh, Xw = 3, 2, 3, 7 dd = [Xc, 2] Kh, Kw = 2, 3 Oh, Ow = Xh - Kh + 1, Xw - Kw + 1 model: u.SimpleModel = u.ReshapedConvolutional(dd, kernel_size=(Kh, Kw), bias=True) weight_buffer = model.layers[0].weight.data assert (Kh, Kw) == model.layers[0].kernel_size data = torch.randn((n, Xc, Xh, Xw)) # output channels, input channels, height, width assert weight_buffer.shape == (dd[1], dd[0], Kh, Kw) def loss_fn(data): err = data.reshape(len(data), -1) return torch.sum(err * err) / 2 / len(data) loss_hessian = u.HessianExactSqrLoss() # o = Oh * Ow * dd[1] output = model(data) o = output.shape[1] for bval in loss_hessian(output): output.backward(bval, retain_graph=True) assert loss_hessian.num_samples == o i, layer = next(enumerate(model.layers)) At = unfold(layer.activations, (Kh, Kw)) # -> n, Xc * Kh * Kw, Oh * Ow assert At.shape == (n, dd[0] * Kh * Kw, Oh*Ow) # o, n, dd[1], Oh, Ow -> o, n, dd[1], Oh*Ow Bh_t = torch.stack([Bt.reshape(n, dd[1], Oh*Ow) for Bt in layer.backprops_list]) assert Bh_t.shape == (o, n, dd[1], Oh*Ow) Ah_t = torch.stack([At]*o) assert Ah_t.shape == (o, n, dd[0] * Kh * Kw, Oh*Ow) # sum out the output patch dimension Jb = torch.einsum('onij,onkj->onik', Bh_t, Ah_t) # => o, n, dd[1], dd[0] * Kh * Kw Hi = torch.einsum('onij,onkl->nijkl', Jb, Jb) # => n, dd[1], dd[0]*Kh*Kw, dd[1], dd[0]*Kh*Kw Jb_bias = torch.einsum('onij->oni', Bh_t) Hb_i = torch.einsum('oni,onj->nij', Jb_bias, Jb_bias) H = Hi.mean(dim=0) Hb = Hb_i.mean(dim=0) model.disable_hooks() loss = loss_fn(model(data)) H_autograd = u.hessian(loss, layer.weight) assert H_autograd.shape == (dd[1], dd[0], Kh, Kw, dd[1], dd[0], Kh, Kw) assert H.shape == (dd[1], dd[0]*Kh*Kw, dd[1], dd[0]*Kh*Kw) u.check_close(H, H_autograd.reshape(H.shape), rtol=1e-4, atol=1e-7) Hb_autograd = u.hessian(loss, layer.bias) assert Hb_autograd.shape == (dd[1], dd[1]) u.check_close(Hb, Hb_autograd) assert len(Bh_t) == loss_hessian.num_samples == o for xi in range(n): loss = loss_fn(model(data[xi:xi + 1, ...])) H_autograd = u.hessian(loss, layer.weight) u.check_close(Hi[xi], H_autograd.reshape(H.shape)) Hb_autograd = u.hessian(loss, layer.bias) u.check_close(Hb_i[xi], Hb_autograd) assert Hb_i[xi, 0, 0] == Oh*Ow # each output has curvature 1, bias term adds up Oh*Ow of them
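# Illustrative sketch of the Gauss-Newton identity behind test_conv_hessian: for
# loss = 0.5 * ||f(w)||^2 with a linear map f, the Hessian equals J^T J, which is what
# the per-output backprops assemble patch by patch. The dense linear map below stands in
# for the convolution; all names and sizes are invented for the example.
def sketch_gauss_newton_identity():
    import torch
    d, o = 3, 5
    M = torch.randn(o, d)                                # Jacobian of f(w) = M @ w
    w = torch.randn(d)
    H = torch.autograd.functional.hessian(lambda v: 0.5 * (M @ v).pow(2).sum(), w)
    assert torch.allclose(H, M.t() @ M, atol=1e-5)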
def test_conv_grad(): """Test per-example gradient computation for conv layer. """ u.seed_random(1) N, Xc, Xh, Xw = 3, 2, 3, 7 dd = [Xc, 2] Kh, Kw = 2, 3 Oh, Ow = Xh - Kh + 1, Xw - Kw + 1 model = u.SimpleConvolutional(dd, kernel_size=(Kh, Kw), bias=True).double() weight_buffer = model.layers[0].weight.data # output channels, input channels, height, width assert weight_buffer.shape == (dd[1], dd[0], Kh, Kw) input_dims = N, Xc, Xh, Xw size = int(np.prod(input_dims)) X = torch.arange(0, size).reshape(*input_dims).double() def loss_fn(data): err = data.reshape(len(data), -1) return torch.sum(err * err) / 2 / len(data) layer = model.layers[0] output = model(X) loss = loss_fn(output) loss.backward() u.check_equal(layer.activations, X) assert layer.backprops_list[0].shape == layer.output.shape assert layer.output.shape == (N, dd[1], Oh, Ow) out_unf = layer.weight.view(layer.weight.size(0), -1) @ unfold(layer.activations, (Kh, Kw)) assert out_unf.shape == (N, dd[1], Oh * Ow) reshaped_bias = layer.bias.reshape(1, dd[1], 1) # (Co,) -> (1, Co, 1) out_unf = out_unf + reshaped_bias u.check_equal(fold(out_unf, (Oh, Ow), (1, 1)), output) # two alternative ways of reshaping u.check_equal(out_unf.view(N, dd[1], Oh, Ow), output) # Unfold produces patches with output dimension merged, while in backprop they are not merged # Hence merge the output (width/height) dimension assert unfold(layer.activations, (Kh, Kw)).shape == (N, Xc * Kh * Kw, Oh * Ow) assert layer.backprops_list[0].shape == (N, dd[1], Oh, Ow) grads_bias = layer.backprops_list[0].sum(dim=(2, 3)) * N mean_grad_bias = grads_bias.sum(dim=0) / N u.check_equal(mean_grad_bias, layer.bias.grad) Bt = layer.backprops_list[0] * N # remove factor of N applied during loss batch averaging assert Bt.shape == (N, dd[1], Oh, Ow) Bt = Bt.reshape(N, dd[1], Oh*Ow) At = unfold(layer.activations, (Kh, Kw)) assert At.shape == (N, dd[0] * Kh * Kw, Oh*Ow) grad_unf = torch.einsum('ijk,ilk->ijl', Bt, At) assert grad_unf.shape == (N, dd[1], dd[0] * Kh * Kw) grads = grad_unf.reshape((N, dd[1], dd[0], Kh, Kw)) u.check_equal(grads.mean(dim=0), layer.weight.grad) # compute per-example gradients using autograd, compare against manual computation for i in range(N): u.clear_backprops(model) output = model(X[i:i + 1, ...]) loss = loss_fn(output) loss.backward() u.check_equal(grads[i], layer.weight.grad) u.check_equal(grads_bias[i], layer.bias.grad)
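# Illustrative sketch of the im2col identity that test_conv_grad relies on:
# conv(x, W) equals W reshaped to a matrix, multiplied with unfold(x), then folded back
# to (N, C_out, Oh, Ow). Shapes are arbitrary example values.
def sketch_conv_as_unfold():
    import torch
    import torch.nn.functional as F
    N, C_in, C_out, H, W, Kh, Kw = 2, 3, 4, 5, 6, 2, 3
    x = torch.randn(N, C_in, H, W)
    weight = torch.randn(C_out, C_in, Kh, Kw)
    cols = F.unfold(x, (Kh, Kw))                         # (N, C_in*Kh*Kw, Oh*Ow)
    out = weight.reshape(C_out, -1) @ cols               # batched matmul
    out = out.reshape(N, C_out, H - Kh + 1, W - Kw + 1)
    assert torch.allclose(out, F.conv2d(x, weight), atol=1e-5)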
def test_hessian(): """Tests of Hessian computation.""" u.seed_random(1) batch_size = 500 data_width = 4 targets_width = 4 d1 = data_width ** 2 d2 = 10 d3 = targets_width ** 2 o = d3 N = batch_size d = [d1, d2, d3] dataset = u.TinyMNIST(data_width=data_width, targets_width=targets_width, dataset_size=batch_size) trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) train_iter = iter(trainloader) data, targets = next(train_iter) def loss_fn(data, targets): assert len(data) == len(targets) err = data - targets.view(-1, data.shape[1]) return torch.sum(err * err) / 2 / len(data) u.seed_random(1) model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=False, bias=True) # backprop hessian and compare against autograd hessian_backprop = u.HessianExactSqrLoss() output = model(data) for bval in hessian_backprop(output): output.backward(bval, retain_graph=True) i, layer = next(enumerate(model.layers)) A_t = layer.activations Bh_t = layer.backprops_list H, Hb = u.hessian_from_backprops(A_t, Bh_t, bias=True) model.disable_hooks() H_autograd = u.hessian(loss_fn(model(data), targets), layer.weight) u.check_close(H, H_autograd.reshape(d[i + 1] * d[i], d[i + 1] * d[i]), rtol=1e-4, atol=1e-7) Hb_autograd = u.hessian(loss_fn(model(data), targets), layer.bias) u.check_close(Hb, Hb_autograd, rtol=1e-4, atol=1e-7) # check first few per-example Hessians Hi, Hb_i = u.per_example_hess(A_t, Bh_t, bias=True) u.check_close(H, Hi.mean(dim=0)) u.check_close(Hb, Hb_i.mean(dim=0), atol=2e-6, rtol=1e-5) for xi in range(5): loss = loss_fn(model(data[xi:xi + 1, ...]), targets[xi:xi + 1]) H_autograd = u.hessian(loss, layer.weight) u.check_close(Hi[xi], H_autograd.reshape(d[i + 1] * d[i], d[i + 1] * d[i])) Hbias_autograd = u.hessian(loss, layer.bias) u.check_close(Hb_i[xi], Hbias_autograd) # get subsampled Hessian u.seed_random(1) model = u.SimpleFullyConnected(d, nonlin=False) hessian_backprop = u.HessianSampledSqrLoss(num_samples=1) output = model(data) for bval in hessian_backprop(output): output.backward(bval, retain_graph=True) model.disable_hooks() i, layer = next(enumerate(model.layers)) H_approx1 = u.hessian_from_backprops(layer.activations, layer.backprops_list) # get subsampled Hessian with more samples u.seed_random(1) model = u.SimpleFullyConnected(d, nonlin=False) hessian_backprop = u.HessianSampledSqrLoss(num_samples=o) output = model(data) for bval in hessian_backprop(output): output.backward(bval, retain_graph=True) model.disable_hooks() i, layer = next(enumerate(model.layers)) H_approx2 = u.hessian_from_backprops(layer.activations, layer.backprops_list) assert abs(u.l2_norm(H) / u.l2_norm(H_approx1) - 1) < 0.08, abs(u.l2_norm(H) / u.l2_norm(H_approx1) - 1) # 0.0612 assert abs(u.l2_norm(H) / u.l2_norm(H_approx2) - 1) < 0.03, abs(u.l2_norm(H) / u.l2_norm(H_approx2) - 1) # 0.0239 assert u.kl_div_cov(H_approx1, H) < 0.3, u.kl_div_cov(H_approx1, H) # 0.222 assert u.kl_div_cov(H_approx2, H) < 0.2, u.kl_div_cov(H_approx2, H) # 0.1233
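# Illustrative sketch of what u.hessian_from_backprops is assumed to compute in
# test_hessian: given activations A (n, d_in) and per-output backprops Bh[o] (n, d_out),
# the layer Hessian of the squared loss is (1/n) * sum over o and n of
# vec(outer(B, A)) vec(outer(B, A))^T. Sizes below are invented for the example.
def sketch_hessian_from_backprops():
    import torch
    n, d_in, d_out, o = 4, 3, 2, 5
    A = torch.randn(n, d_in)
    Bh = [torch.randn(n, d_out) for _ in range(o)]
    J = torch.stack([torch.einsum('nk,ni->nki', B, A) for B in Bh])   # (o, n, d_out, d_in)
    H = torch.einsum('onki,onlj->kilj', J, J).reshape(d_out * d_in, d_out * d_in) / n
    assert torch.allclose(H, H.t(), atol=1e-6)           # Hessian is symmetric
    return H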