def test_full_hessian_xent_mnist_multilayer():
    """Test regular and diagonal hessian computation."""
    u.seed_random(1)

    data_width = 3
    batch_size = 2
    d = [data_width**2, 6, 10]
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=False, bias=True)
    autograd_lib.register(model)

    dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()

    hess = defaultdict(float)
    hess_diag = defaultdict(float)
    for train_step in range(train_steps):
        data, targets = next(train_iter)

        activations = {}

        def save_activations(layer, a, _):
            activations[layer] = a

        with autograd_lib.module_hook(save_activations):
            output = model(data)
            loss = loss_fn(output, targets)

        def compute_hess(layer, _, B):
            A = activations[layer]
            BA = torch.einsum("nl,ni->nli", B, A)
            hess[layer] += torch.einsum('nli,nkj->likj', BA, BA)
            hess_diag[layer] += torch.einsum("ni,nj->ij", B * B, A * A)

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(output, loss='CrossEntropy', retain_graph=True)

    # compute Hessian through autograd
    H_autograd = u.hessian(loss, model.layers[0].weight)
    u.check_close(hess[model.layers[0]] / batch_size, H_autograd)
    diag_autograd = torch.einsum('lili->li', H_autograd)
    u.check_close(diag_autograd, hess_diag[model.layers[0]] / batch_size)

    H_autograd = u.hessian(loss, model.layers[1].weight)
    u.check_close(hess[model.layers[1]] / batch_size, H_autograd)
    diag_autograd = torch.einsum('lili->li', H_autograd)
    u.check_close(diag_autograd, hess_diag[model.layers[1]] / batch_size)
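# Illustrative sketch (not part of the original test suite; the helper name is made up):
# the diagonal shortcut used in compute_hess above works because the diagonal of the
# per-example Hessian einsum('nli,nkj->likj', BA, BA) at index (l, i) is sum_n B[n,l]^2 * A[n,i]^2,
# which is exactly einsum('ni,nj->ij', B*B, A*A).
def _sketch_diagonal_shortcut():
    torch.manual_seed(0)
    n, do, di = 4, 3, 2
    A = torch.randn(n, di)
    B = torch.randn(n, do)
    BA = torch.einsum('nl,ni->nli', B, A)
    full = torch.einsum('nli,nkj->likj', BA, BA)          # full (do x di) x (do x di) Hessian block
    diag_full = torch.einsum('lili->li', full)            # extract its diagonal
    diag_fast = torch.einsum('ni,nj->ij', B * B, A * A)   # diagonal computed directly
    assert torch.allclose(diag_full, diag_fast, atol=1e-6)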
def test_kron_mnist():
    u.seed_random(1)
    data_width = 3
    batch_size = 3
    d = [data_width**2, 10]
    o = d[-1]
    n = batch_size
    train_steps = 1

    # torch.set_default_dtype(torch.float64)

    model: u.SimpleModel2 = u.SimpleFullyConnected2(d, nonlin=False, bias=True)
    autograd_lib.add_hooks(model)

    dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()

    gl.token_count = 0
    for train_step in range(train_steps):
        data, targets = next(train_iter)

        # get gradient values
        u.clear_backprops(model)
        autograd_lib.enable_hooks()
        output = model(data)
        autograd_lib.backprop_hess(output, hess_type='CrossEntropy')

        i = 0
        layer = model.layers[i]

        autograd_lib.compute_hess(model, method='kron')
        autograd_lib.compute_hess(model)
        autograd_lib.disable_hooks()

        # direct Hessian computation
        H = layer.weight.hess
        H_bias = layer.bias.hess

        # factored Hessian computation
        H2 = layer.weight.hess_factored
        H2_bias = layer.bias.hess_factored
        H2, H2_bias = u.expand_hess(H2, H2_bias)

        # autograd Hessian computation
        loss = loss_fn(output, targets)
        # TODO: change to d[i+1]*d[i]
        H_autograd = u.hessian(loss, layer.weight).reshape(d[i] * d[i + 1], d[i] * d[i + 1])
        H_bias_autograd = u.hessian(loss, layer.bias)

        # compare direct against autograd
        u.check_close(H, H_autograd)
        u.check_close(H_bias, H_bias_autograd)

        approx_error = u.symsqrt_dist(H, H2)
        assert approx_error < 1e-2, approx_error
def test_kron_nano():
    u.seed_random(1)
    d = [1, 2]
    n = 1

    # torch.set_default_dtype(torch.float32)
    loss_type = 'CrossEntropy'

    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=False, bias=True)
    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:
        loss_fn = nn.CrossEntropyLoss()

    data = torch.randn(n, d[0])
    data = torch.ones(n, d[0])
    if loss_type.endswith('LeastSquares'):
        target = torch.randn(n, d[-1])
    elif loss_type == 'CrossEntropy':
        target = torch.LongTensor(n).random_(0, d[-1])
        target = torch.tensor([0])

    # Hessian computation, saves regular and Kronecker factored versions
    # into .hess and .hess_factored attributes
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='kron')
    autograd_lib.compute_hess(model)
    autograd_lib.disable_hooks()

    for layer in model.layers:
        Hk: u.Kron = layer.weight.hess_factored
        Hk_bias: u.Kron = layer.bias.hess_factored
        Hk, Hk_bias = u.expand_hess(Hk, Hk_bias)  # kronecker multiply the factors

        # old approach, using direct computation
        H2, H_bias2 = layer.weight.hess, layer.bias.hess

        # compute Hessian through autograd
        model.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        H_autograd = u.hessian(loss, layer.weight)
        H_bias_autograd = u.hessian(loss, layer.bias)

        # compare autograd with direct approach
        u.check_close(H2, H_autograd.reshape(Hk.shape))
        u.check_close(H_bias2, H_bias_autograd)

        # compare factored with direct approach
        assert u.symsqrt_dist(Hk, H2) < 1e-6
def test_diagonal_hessian():
    u.seed_random(1)
    A, model = create_toy_model()

    activations = {}

    def save_activations(layer, a, _):
        if layer != model.layers[0]:
            return
        activations[layer] = a

    with autograd_lib.module_hook(save_activations):
        Y = model(A.t())
        loss = torch.sum(Y * Y) / 2

    hess = [0]

    def compute_hess(layer, _, B):
        if layer != model.layers[0]:
            return
        A = activations[layer]
        hess[0] += torch.einsum("ni,nj->ij", B * B, A * A).reshape(-1)

    with autograd_lib.module_hook(compute_hess):
        autograd_lib.backprop_identity(Y, retain_graph=True)

    # check against autograd
    hess0 = u.hessian(loss, model.layers[0].weight).reshape([4, 4])
    u.check_equal(hess[0], torch.diag(hess0))

    # check against manual solution
    u.check_equal(hess[0], [425., 225., 680., 360.])
def test_full_hessian_multibatch():
    A, model = create_toy_model()
    data = A.t()
    data = data.repeat(3, 1)
    n = float(len(data))

    activations = {}
    hess = defaultdict(float)

    def save_activations(layer, a, _):
        activations[layer] = a

    def compute_hessian(layer, _, B):
        A = activations[layer]
        BA = torch.einsum("nl,ni->nli", B, A)
        hess[layer] += torch.einsum('nli,nkj->likj', BA, BA)

    for x in data:
        with autograd_lib.module_hook(save_activations):
            y = model(x)
            loss = torch.sum(y * y) / 2
        with autograd_lib.module_hook(compute_hessian):
            autograd_lib.backprop_identity(y)

    result = hess[model.layers[0]]

    # check result against autograd
    loss = u.least_squares(model(data), aggregation='sum')
    hess0 = u.hessian(loss, model.layers[0].weight)
    u.check_equal(hess0, result)
def test_kfac_hessian():
    A, model = create_toy_model()
    data = A.t()
    data = data.repeat(7, 1)
    n = float(len(data))

    activations = {}
    hess = defaultdict(lambda: AttrDefault(float))

    def save_activations(layer, a, _):
        activations[layer] = a

    def compute_hessian(layer, _, B):
        A = activations[layer]
        hess[layer].AA += torch.einsum("ni,nj->ij", A, A)
        hess[layer].BB += torch.einsum("ni,nj->ij", B, B)

    for x in data:
        with autograd_lib.module_hook(save_activations):
            y = model(x)
            o = y.shape[1]
            loss = torch.sum(y * y) / 2
        with autograd_lib.module_hook(compute_hessian):
            autograd_lib.backprop_identity(y)

    hess0 = hess[model.layers[0]]
    result = u.kron(hess0.BB / n, hess0.AA / o)

    # check result against autograd
    loss = u.least_squares(model(data), aggregation='sum')
    hess0 = u.hessian(loss, model.layers[0].weight).reshape(4, 4)
    u.check_equal(hess0, result)
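# Illustrative sketch (hypothetical helper, not part of the tests above): the KFAC-style product
# kron(BB/n, AA/o) used in test_kfac_hessian can reproduce the exact Hessian because, for squared
# loss on a linear network, the backprop factor does not vary across examples, so the average of
# per-example Kronecker products equals the Kronecker product of the averaged factors.
def _sketch_kfac_exactness():
    torch.manual_seed(0)
    n, di, do = 5, 2, 3
    A = torch.randn(n, di)          # activations differ per example
    b = torch.randn(do)
    B = b.expand(n, do)             # backprop factor identical for every example
    exact = sum(torch.kron(torch.outer(B[i], B[i]), torch.outer(A[i], A[i])) for i in range(n)) / n
    kfac = torch.kron(B.t() @ B / n, A.t() @ A / n)
    assert torch.allclose(exact, kfac, atol=1e-6)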
def test_cross_entropy_hessian_mnist():
    u.seed_random(1)
    data_width = 3
    batch_size = 2
    d = [data_width**2, 10]
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=False, bias=True)
    dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()
    loss_hessian = u.HessianExactCrossEntropyLoss()

    gl.token_count = 0
    for train_step in range(train_steps):
        data, targets = next(train_iter)

        # get gradient values
        u.clear_backprops(model)
        model.skip_forward_hooks = False
        model.skip_backward_hooks = False
        output = model(data)
        for bval in loss_hessian(output):
            output.backward(bval, retain_graph=True)

        i = 0
        layer = model.layers[i]
        H, Hbias = u.hessian_from_backprops(layer.activations, layer.backprops_list, bias=True)
        model.skip_forward_hooks = True
        model.skip_backward_hooks = True

        # compute Hessian through autograd
        model.zero_grad()
        output = model(data)
        loss = loss_fn(output, targets)
        H_autograd = u.hessian(loss, layer.weight).reshape(d[i] * d[i + 1], d[i] * d[i + 1])
        u.check_close(H, H_autograd)
        Hbias_autograd = u.hessian(loss, layer.bias)
        u.check_close(Hbias, Hbias_autograd)
def _test_kfac_hessian_xent_mnist():
    u.seed_random(1)
    data_width = 3
    batch_size = 2
    d = [data_width**2, 10]
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=False, bias=True)
    autograd_lib.register(model)

    dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()

    activations = {}
    hess = defaultdict(lambda: AttrDefault(float))
    for train_step in range(train_steps):
        data, targets = next(train_iter)

        activations = {}

        def save_activations(layer, a, _):
            activations[layer] = a

        with autograd_lib.module_hook(save_activations):
            output = model(data)
            loss = loss_fn(output, targets)

        def compute_hess(layer, _, B):
            A = activations[layer]
            hess[layer].AA += torch.einsum("ni,nj->ij", A, A)
            hess[layer].BB += torch.einsum("ni,nj->ij", B, B)

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(output, loss='CrossEntropy', retain_graph=True)

    hess_factored = hess[model.layers[0]]
    hess0 = torch.einsum('kl,ij->kilj', hess_factored.BB / n, hess_factored.AA / o)  # hess for sum loss
    hess0 /= n  # hess for mean loss

    # compute Hessian through autograd
    H_autograd = u.hessian(loss, model.layers[0].weight)

    rel_error = torch.norm((hess0 - H_autograd).flatten()) / torch.norm(H_autograd.flatten())
    assert rel_error < 0.01  # 0.0057
def test_full_hessian_xent_kfac2():
    """Test with uneven layers."""
    u.seed_random(1)
    torch.set_default_dtype(torch.float64)

    batch_size = 1
    d = [3, 2]
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=True, bias=False)
    autograd_lib.register(model)
    loss_fn = torch.nn.CrossEntropyLoss()

    data = u.to_logits(torch.tensor([[0.7, 0.2, 0.1]]))
    targets = torch.tensor([0])
    data = data.repeat([3, 1])
    targets = targets.repeat([3])
    n = len(data)

    activations = {}
    hess = defaultdict(lambda: AttrDefault(float))

    for i in range(n):
        def save_activations(layer, A, _):
            activations[layer] = A
            hess[layer].AA += torch.einsum("ni,nj->ij", A, A)

        with autograd_lib.module_hook(save_activations):
            data_batch = data[i:i + 1]
            targets_batch = targets[i:i + 1]
            Y = model(data_batch)
            o = Y.shape[1]
            loss = loss_fn(Y, targets_batch)

        def compute_hess(layer, _, B):
            hess[layer].BB += torch.einsum("ni,nj->ij", B, B)

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(Y, loss='CrossEntropy')

    # expand
    hess_factored = hess[model.layers[0]]
    hess0 = torch.einsum('kl,ij->kilj', hess_factored.BB / n, hess_factored.AA / o)  # hess for sum loss
    hess0 /= n  # hess for mean loss

    # check against autograd
    # 0.1459
    Y = model(data)
    loss = loss_fn(Y, targets)
    hess_autograd = u.hessian(loss, model.layers[0].weight)
    u.check_equal(hess_autograd, hess0)
def test_cross_entropy_hessian_tiny():
    u.seed_random(1)
    batch_size = 1
    d = [2, 2]
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=True, bias=True)
    model.layers[0].weight.data.copy_(torch.eye(2))
    loss_fn = torch.nn.CrossEntropyLoss()
    loss_hessian = u.HessianExactCrossEntropyLoss()

    data = u.to_logits(torch.tensor([[0.7, 0.3]]))
    targets = torch.tensor([0])

    # get gradient values
    u.clear_backprops(model)
    model.skip_forward_hooks = False
    model.skip_backward_hooks = False
    output = model(data)
    for bval in loss_hessian(output):
        output.backward(bval, retain_graph=True)

    i = 0
    layer = model.layers[i]
    H, Hbias = u.hessian_from_backprops(layer.activations, layer.backprops_list, bias=True)
    model.skip_forward_hooks = True
    model.skip_backward_hooks = True

    # compute Hessian through autograd
    model.zero_grad()
    output = model(data)
    loss = loss_fn(output, targets)
    H_autograd = u.hessian(loss, layer.weight)
    u.check_close(H, H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1]))
    Hbias_autograd = u.hessian(loss, layer.bias)
    u.check_close(Hbias, Hbias_autograd)
def test_full_hessian_xent_multibatch():
    u.seed_random(1)
    torch.set_default_dtype(torch.float64)

    batch_size = 1
    d = [2, 2]
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=True, bias=True)
    model.layers[0].weight.data.copy_(torch.eye(2))
    autograd_lib.register(model)
    loss_fn = torch.nn.CrossEntropyLoss()

    data = u.to_logits(torch.tensor([[0.7, 0.3]]))
    targets = torch.tensor([0])
    data = data.repeat([3, 1])
    targets = targets.repeat([3])
    n = len(data)

    activations = {}
    hess = defaultdict(float)

    def save_activations(layer, a, _):
        activations[layer] = a

    for i in range(n):
        with autograd_lib.module_hook(save_activations):
            data_batch = data[i:i + 1]
            targets_batch = targets[i:i + 1]
            Y = model(data_batch)
            loss = loss_fn(Y, targets_batch)

        def compute_hess(layer, _, B):
            A = activations[layer]
            BA = torch.einsum("nl,ni->nli", B, A)
            hess[layer] += torch.einsum('nli,nkj->likj', BA, BA)

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(Y, loss='CrossEntropy')

    # check against autograd
    # 0.1459
    Y = model(data)
    loss = loss_fn(Y, targets)
    hess_autograd = u.hessian(loss, model.layers[0].weight)
    hess0 = hess[model.layers[0]] / n
    u.check_equal(hess_autograd, hess0)
def test_autoencoder_newton():
    """Use Newton's method to train autoencoder."""
    image_size = 3
    batch_size = 64
    dataset = u.TinyMNIST(data_width=image_size, targets_width=image_size, dataset_size=batch_size)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    d = image_size ** 2  # hidden layer size

    u.seed_random(1)
    model: u.SimpleModel = u.SimpleFullyConnected([d, d])
    model.disable_hooks()

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    def loss_fn(data, targets):
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == batch_size
        return torch.sum(err * err) / 2 / len(data)

    for i in range(10):
        data, targets = next(iter(trainloader))
        optimizer.zero_grad()
        loss = loss_fn(model(data), targets)
        if i > 0:
            assert loss < 1e-9
        loss.backward()

        W = model.layers[0].weight
        grad = u.tvec(W.grad)
        loss = loss_fn(model(data), targets)
        H = u.hessian(loss, W)
        # for col-major: H = H.transpose(0, 1).transpose(2, 3).reshape(d**2, d**2)
        H = H.reshape(d ** 2, d ** 2)

        # For col-major: W1 = u.unvec(u.vec(W) - u.pinv(H) @ grad, d)
        # W1 = u.untvec(u.tvec(W) - grad @ u.pinv(H), d)
        W1 = u.untvec(u.tvec(W) - grad @ H.pinverse(), d)
        W.data.copy_(W1)
def fit(self, x, y):
    """Run Newton's Method to minimize J(theta) for logistic regression.

    Args:
        x: Training example inputs. Shape (m, n).
        y: Training example labels. Shape (m,).
    """
    # *** START CODE HERE ***
    if self.theta is None:
        self.theta = np.zeros((x.shape[1], 1))
    y = y.reshape((y.shape[0], 1))

    error = 1e9
    numIters = 0
    while error > self.eps and numIters < self.max_iter:
        hess = util.hessian(x, self.theta)
        Jprime = util.gradient(x, self.theta, y)
        hessInv = np.linalg.inv(hess)
        theta_new = self.theta - hessInv.dot(Jprime)
        error = np.sum(np.abs(self.theta - theta_new))
        self.theta = theta_new.copy()
        numIters += 1
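# The util.gradient / util.hessian helpers called in fit() are not shown here. Below is a minimal
# sketch of what they plausibly compute for the average logistic loss
# J(theta) = -(1/m) sum_i [ y_i log h_i + (1 - y_i) log(1 - h_i) ],  h_i = sigmoid(x_i . theta).
# Names and conventions are assumptions for illustration, not the repository's definitions.
def _sketch_logistic_gradient(x, theta, y):
    """Gradient of J(theta): (1/m) * x.T @ (h - y), shape (n, 1)."""
    m = x.shape[0]
    h = 1.0 / (1.0 + np.exp(-(x @ theta)))
    return x.T @ (h - y) / m

def _sketch_logistic_hessian(x, theta):
    """Hessian of J(theta): (1/m) * x.T @ diag(h * (1 - h)) @ x, shape (n, n)."""
    m = x.shape[0]
    h = (1.0 / (1.0 + np.exp(-(x @ theta)))).reshape(-1)
    return (x.T * (h * (1 - h))) @ x / m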
def subtest_hess_type(hess_type):
    torch.manual_seed(1)
    model = TinyNet()

    def least_squares_loss(data_, targets_):
        assert len(data_) == len(targets_)
        err = data_ - targets_
        return torch.sum(err * err) / 2 / len(data_)

    n = 3
    data = torch.rand(n, 1, 28, 28)

    autograd_lib.add_hooks(model)
    output = model(data)

    if hess_type == 'LeastSquares':
        targets = torch.rand(output.shape)
        loss_fn = least_squares_loss
    else:  # hess_type == 'CrossEntropy'
        targets = torch.LongTensor(n).random_(0, 10)
        loss_fn = nn.CrossEntropyLoss()

    # Dummy backprop to make sure multiple backprops don't invalidate each other
    autograd_lib.backprop_hess(output, hess_type=hess_type)
    autograd_lib.clear_hess_backprops(model)
    autograd_lib.backprop_hess(output, hess_type=hess_type)
    autograd_lib.compute_hess(model)
    autograd_lib.disable_hooks()

    for layer in model.modules():
        if not autograd_lib.is_supported(layer):
            continue
        for param in layer.parameters():
            loss = loss_fn(output, targets)
            hess_autograd = u.hessian(loss, param)
            hess = param.hess
            assert torch.allclose(hess, hess_autograd.reshape(hess.shape))
def test_full_hessian():
    u.seed_random(1)
    A, model = create_toy_model()
    data = A.t()
    # data = data.repeat(3, 1)

    activations = {}
    hess = defaultdict(float)

    def save_activations(layer, a, _):
        activations[layer] = a

    with autograd_lib.module_hook(save_activations):
        Y = model(A.t())
        loss = torch.sum(Y * Y) / 2

    def compute_hess(layer, _, B):
        A = activations[layer]
        n = A.shape[0]
        di = A.shape[1]
        do = B.shape[1]

        BA = torch.einsum("nl,ni->nli", B, A)
        hess[layer] += torch.einsum('nli,nkj->likj', BA, BA)

    with autograd_lib.module_hook(compute_hess):
        autograd_lib.backprop_identity(Y, retain_graph=True)

    # check against autograd
    hess_autograd = u.hessian(loss, model.layers[0].weight)
    hess0 = hess[model.layers[0]]
    u.check_equal(hess_autograd, hess0)

    # check against manual solution
    u.check_equal(hess0.reshape(4, 4), [[425, -75, 170, -30],
                                        [-75, 225, -30, 90],
                                        [170, -30, 680, -120],
                                        [-30, 90, -120, 360]])
def test_cross_entropy_soft():
    torch.set_default_dtype(torch.float32)
    q = torch.tensor([0.4, 0.6]).unsqueeze(0).float()
    p = torch.tensor([0.7, 0.3]).unsqueeze(0).float()
    observed_logit = u.to_logits(p)

    # Compare against other loss functions
    # https://www.wolframcloud.com/obj/user-eac9ee2d-7714-42da-8f84-bec1603944d5/newton/logistic-hessian.nb
    loss1 = F.binary_cross_entropy(p[0], q[0])
    u.check_close(loss1, 0.865054)

    loss_fn = u.CrossEntropySoft()
    loss2 = loss_fn(observed_logit, q)
    u.check_close(loss2, loss1)

    loss3 = F.cross_entropy(observed_logit, torch.tensor([0]))
    u.check_close(loss3, loss_fn(observed_logit, torch.tensor([[1, 0.]])))

    # check gradient
    observed_logit.requires_grad = True
    grad = torch.autograd.grad(loss_fn(observed_logit, target=q), observed_logit)
    u.check_close(p - q, grad[0])

    # check Hessian
    observed_logit = u.to_logits(p)
    observed_logit.zero_()
    observed_logit.requires_grad = True
    hessian_autograd = u.hessian(loss_fn(observed_logit, target=q), observed_logit)
    hessian_autograd = hessian_autograd.reshape((p.numel(), p.numel()))
    p = F.softmax(observed_logit, dim=1)
    hessian_manual = torch.diag(p[0]) - p.t() @ p
    u.check_close(hessian_autograd, hessian_manual)
def test_explicit_hessian():
    """Check computation of hessian of loss(B'WA) from
    https://github.com/yaroslavvb/kfac_pytorch/blob/master/derivation.pdf
    """
    torch.set_default_dtype(torch.float64)
    A = torch.tensor([[-1., 4], [3, 0]])
    B = torch.tensor([[-4., 3], [2, 6]])
    X = torch.tensor([[-5., 0], [-2, -6]], requires_grad=True)

    Y = B.t() @ X @ A
    u.check_equal(Y, [[-52, 64], [-81, -108]])
    loss = torch.sum(Y * Y) / 2
    hess0 = u.hessian(loss, X).reshape([4, 4])
    hess1 = u.Kron(A @ A.t(), B @ B.t())
    u.check_equal(loss, 12512.5)

    # PyTorch autograd computes Hessian with respect to row-vectorized parameters, whereas
    # autograd_lib uses math convention and does column-vectorized.
    # Commuting order of Kronecker product switches between two representations
    u.check_equal(hess1.commute(), hess0)

    # Do a test using Linear layers instead of matrix multiplies
    model: u.SimpleFullyConnected2 = u.SimpleFullyConnected2([2, 2, 2], bias=False)
    model.layers[0].weight.data.copy_(X)

    # Transpose to match previous results, layers treat dim0 as batch dimension
    u.check_equal(model.layers[0](A.t()).t(), [[5, -20], [-16, -8]])  # XA = (A'X0)'

    model.layers[1].weight.data.copy_(B.t())
    u.check_equal(model(A.t()).t(), Y)

    Y = model(A.t()).t()  # transpose to data-dimension=columns
    loss = torch.sum(Y * Y) / 2
    loss.backward()

    u.check_equal(model.layers[0].weight.grad, [[-2285, -105], [-1490, -1770]])
    G = B @ Y @ A.t()
    u.check_equal(model.layers[0].weight.grad, G)

    u.check_equal(hess0, u.Kron(B @ B.t(), A @ A.t()))

    # compute newton step
    u.check_equal(u.Kron(A @ A.t(), B @ B.t()).pinv() @ u.vec(G), u.v2c([-5, -2, 0, -6]))

    # compute Newton step using factored representation
    autograd_lib.add_hooks(model)
    Y = model(A.t())
    n = 2
    loss = torch.sum(Y * Y) / 2
    autograd_lib.backprop_hess(Y, hess_type='LeastSquares')
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron', vecr_order=False, loss_aggregation='sum')
    param = model.layers[0].weight

    hess2 = param.hess_kron
    print(hess2)
    u.check_equal(hess2, [[425, 170, -75, -30], [170, 680, -30, -120], [-75, -30, 225, 90], [-30, -120, 90, 360]])

    # Gradient test
    model.zero_grad()
    loss.backward()
    u.check_close(u.vec(G).flatten(), u.Vec(param.grad))

    # Newton step test
    # Method 0: PyTorch native autograd
    newton_step0 = param.grad.flatten() @ torch.pinverse(hess0)
    newton_step0 = newton_step0.reshape(param.shape)
    u.check_equal(newton_step0, [[-5, 0], [-2, -6]])

    # Method 1: column major order
    ihess2 = hess2.pinv()
    u.check_equal(ihess2.LL, [[1 / 16, 1 / 48], [1 / 48, 17 / 144]])
    u.check_equal(ihess2.RR, [[2 / 45, -(1 / 90)], [-(1 / 90), 1 / 36]])
    u.check_equal(torch.flatten(hess2.pinv() @ u.vec(G)), [-5, -2, 0, -6])
    newton_step1 = (ihess2 @ u.Vec(param.grad)).matrix_form()

    # Method 2: row major order
    ihess2_rowmajor = ihess2.commute()
    newton_step2 = ihess2_rowmajor @ u.Vecr(param.grad)
    newton_step2 = newton_step2.matrix_form()

    u.check_equal(newton_step0, newton_step1)
    u.check_equal(newton_step0, newton_step2)
def main():
    attemp_count = 0
    while os.path.exists(f"{args.logdir}{attemp_count:02d}"):
        attemp_count += 1
    logdir = f"{args.logdir}{attemp_count:02d}"

    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")
    u.seed_random(1)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
    except Exception as e:
        print(f"wandb crash with {e}")

    # data_width = 4
    # targets_width = 2

    d1 = args.data_width**2
    d2 = 10
    d3 = args.targets_width**2
    o = d3
    n = args.stats_batch_size
    d = [d1, d2, d3]
    model = u.SimpleFullyConnected(d, nonlin=args.nonlin)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)

    def capture_activations(module, input, _output):
        if skip_forward_hooks:
            return
        assert gl.backward_idx == 0  # no need to forward-prop on Hessian computation
        assert not hasattr(module, 'activations'), "Seeing activations from previous forward, call util.zero_grad to clear"
        assert len(input) == 1, "this works for single input layers only"
        setattr(module, "activations", input[0].detach())

    def capture_backprops(module: nn.Module, _input, output):
        if skip_backward_hooks:
            return
        assert len(output) == 1, "this works for single variable layers only"
        if gl.backward_idx == 0:
            assert not hasattr(module, 'backprops'), "Seeing results of previous autograd, call util.zero_grad to clear"
            setattr(module, 'backprops', [])
        assert gl.backward_idx == len(module.backprops)
        module.backprops.append(output[0])

    def save_grad(param: nn.Parameter) -> Callable[[torch.Tensor], None]:
        """Hook to save gradient into 'param.saved_grad', so it can be accessed after model.zero_grad().
        Only stores gradient if the value has not been set, call util.zero_grad to clear it."""
        def save_grad_fn(grad):
            if not hasattr(param, 'saved_grad'):
                setattr(param, 'saved_grad', grad)
        return save_grad_fn

    for layer in model.layers:
        layer.register_forward_hook(capture_activations)
        layer.register_backward_hook(capture_backprops)
        layer.weight.register_hook(save_grad(layer.weight))

    def loss_fn(data, targets):
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == len(targets)
        return torch.sum(err * err) / 2 / len(data)

    gl.token_count = 0
    for step in range(args.stats_steps):
        data, targets = next(stats_iter)
        skip_forward_hooks = False
        skip_backward_hooks = False

        # get gradient values
        gl.backward_idx = 0
        u.zero_grad(model)
        output = model(data)
        loss = loss_fn(output, targets)
        loss.backward(retain_graph=True)
        print("loss", loss.item())

        # get Hessian values
        skip_forward_hooks = True
        id_mat = torch.eye(o)

        u.log_scalars({'loss': loss.item()})

        # o = 0
        for out_idx in range(o):
            model.zero_grad()
            # backprop to get section of batch output jacobian for output at position out_idx
            output = model(data)  # opt: using autograd.grad means I don't have to zero_grad
            ei = id_mat[out_idx]
            bval = torch.stack([ei] * n)
            gl.backward_idx = out_idx + 1
            output.backward(bval)
        skip_backward_hooks = True

        for (i, layer) in enumerate(model.layers):
            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops[0] * n
            assert B_t.shape == (n, d[i + 1])

            G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad)

            # empirical Fisher
            efisher = G.t() @ G / n
            sigma = efisher - g.t() @ g
            # u.dump(sigma, f'/tmp/sigmas/{step}-{i}')
            s.sigma_l2 = u.l2_norm(sigma)

            #############################
            # Hessian stats
            #############################
            A_t = layer.activations
            Bh_t = [layer.backprops[out_idx + 1] for out_idx in range(o)]
            Amat_t = torch.cat([A_t] * o, dim=0)
            Bmat_t = torch.cat(Bh_t, dim=0)
            assert Amat_t.shape == (n * o, d[i])
            assert Bmat_t.shape == (n * o, d[i + 1])

            Jb = u.khatri_rao_t(Bmat_t, Amat_t)  # batch Jacobian, in row-vec format
            H = Jb.t() @ Jb / n
            pinvH = u.pinv(H)

            s.hess_l2 = u.l2_norm(H)
            s.invhess_l2 = u.l2_norm(pinvH)
            s.hess_fro = H.flatten().norm()
            s.invhess_fro = pinvH.flatten().norm()
            s.jacobian_l2 = u.l2_norm(Jb)
            s.grad_fro = g.flatten().norm()
            s.param_fro = layer.weight.data.flatten().norm()
            u.nan_check(H)

            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1])
                u.check_close(H, H_autograd)
            # u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')

            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() / dd.flatten().norm()**2)

            s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)
            s.grad_curv = curv_direction(g)
            ndir = g @ u.pinv(H)  # newton direction
            s.newton_curv = curv_direction(ndir)
            setattr(layer.weight, 'pre', u.pinv(H))  # save Newton preconditioner
            s.step_openai = 1 / s.grad_curv if s.grad_curv else 999

            s.newton_fro = ndir.flatten().norm()  # frobenius norm of Newton update
            s.regret_gradient = loss_direction(g, s.step_openai)

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        for i in range(args.train_steps):
            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()

            u.log_scalar(train_loss=loss.item())

            if args.method != 'newton':
                optimizer.step()
            else:
                for (layer_idx, layer) in enumerate(model.layers):
                    param: torch.nn.Parameter = layer.weight
                    param_data: torch.Tensor = param.data
                    param_data.copy_(param_data - 0.1 * param.grad)
                    if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                        continue
                    u.nan_check(layer.weight.pre)
                    u.nan_check(param.grad.flatten())
                    u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                    param_new_flat = u.v2r(param_data.flatten()) - u.v2r(param.grad.flatten()) @ layer.weight.pre
                    u.nan_check(param_new_flat)
                    param_data.copy_(param_new_flat.reshape(param_data.shape))

            gl.token_count += data.shape[0]

    gl.event_writer.close()
def test_kron_conv_exact():
    """Test per-example gradient computation for conv layer.

    Kronecker factoring is exact for 1x1 convolutions and linear activations.
    """
    u.seed_random(1)

    n, Xh, Xw = 2, 2, 2
    Kh, Kw = 1, 1
    dd = [2, 2, 2]
    o = dd[-1]
    model: u.SimpleModel = u.PooledConvolutional2(dd, kernel_size=(Kh, Kw), nonlin=False, bias=True)
    data = torch.randn((n, dd[0], Xh, Xw))
    # print(model)
    # print(data)

    loss_type = 'CrossEntropy'
    # loss_type = 'LeastSquares'

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:  # CrossEntropy
        loss_fn = nn.CrossEntropyLoss()

    sample_output = model(data)
    if loss_type.endswith('LeastSquares'):
        targets = torch.randn(sample_output.shape)
    elif loss_type == 'CrossEntropy':
        targets = torch.LongTensor(n).random_(0, o)

    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='mean_kron')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    for i in range(len(model.layers)):
        layer = model.layers[i]

        # direct Hessian computation
        H = layer.weight.hess
        H_bias = layer.bias.hess

        # factored Hessian computation
        Hk = layer.weight.hess_factored
        Hk_bias = layer.bias.hess_factored
        Hk = Hk.expand()
        Hk_bias = Hk_bias.expand()

        # autograd Hessian computation
        loss = loss_fn(output, targets)
        Ha = u.hessian(loss, layer.weight).reshape(H.shape)
        Ha_bias = u.hessian(loss, layer.bias)

        # compare direct against autograd
        Ha = Ha.reshape(H.shape)
        # rel_error = torch.max((H-Ha)/Ha)
        u.check_close(H, Ha, rtol=1e-5, atol=1e-7)
        u.check_close(Ha_bias, H_bias, rtol=1e-5, atol=1e-7)

        u.check_close(H_bias, Hk_bias)
        u.check_close(H, Hk)
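# Illustrative sketch (hypothetical helper) of the claim in the docstring above: a 1x1 convolution
# is a linear map applied independently at every spatial position, which is why the Kronecker
# factoring can be exact in the 1x1 / linear case.
def _sketch_1x1_conv_is_linear():
    torch.manual_seed(0)
    n, c_in, c_out, h, w = 2, 3, 4, 5, 5
    x = torch.randn(n, c_in, h, w)
    weight = torch.randn(c_out, c_in, 1, 1)
    out_conv = F.conv2d(x, weight)
    out_linear = torch.einsum('oi,nihw->nohw', weight.squeeze(-1).squeeze(-1), x)
    assert torch.allclose(out_conv, out_linear, atol=1e-6)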
def compute_layer_stats(layer):
    stats = AttrDefault(str, {})
    n = stats_batch_size
    param = u.get_param(layer)
    d = len(param.flatten())
    layer_idx = model.layers.index(layer)
    assert layer_idx >= 0
    assert stats_data.shape[0] == n

    def backprop_loss():
        model.zero_grad()
        output = model(stats_data)  # use last saved data batch for backprop
        loss = compute_loss(output, stats_targets)
        loss.backward()
        return loss, output

    def backprop_output():
        model.zero_grad()
        output = model(stats_data)
        output.backward(gradient=torch.ones_like(output))
        return output

    # per-example gradients, n, d
    loss, output = backprop_loss()
    At = layer.data_input
    Bt = layer.grad_output * n
    G = u.khatri_rao_t(At, Bt)
    g = G.sum(dim=0, keepdim=True) / n
    u.check_close(g, u.vec(param.grad).t())

    stats.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2

    stats.gradient_norm = g.flatten().norm()
    stats.parameter_norm = param.data.flatten().norm()
    pos_activations = torch.sum(layer.data_output > 0)
    neg_activations = torch.sum(layer.data_output <= 0)
    stats.sparsity = pos_activations.float() / (pos_activations + neg_activations)

    output = backprop_output()
    At2 = layer.data_input
    u.check_close(At, At2)
    B2t = layer.grad_output
    J = u.khatri_rao_t(At, B2t)
    H = J.t() @ J / n

    model.zero_grad()
    output = model(stats_data)  # use last saved data batch for backprop
    loss = compute_loss(output, stats_targets)
    hess = u.hessian(loss, param)
    hess = hess.transpose(2, 3).transpose(0, 1).reshape(d, d)
    u.check_close(hess, H)

    stats.hessian_norm = u.l2_norm(H)
    stats.jacobian_norm = u.l2_norm(J)
    Joutput = J.sum(dim=0) / n
    stats.jacobian_sensitivity = Joutput.norm()

    # newton decrement
    stats.loss_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)
    u.check_close(stats.loss_newton, loss)

    # do line-search to find optimal step
    def line_search(directionv, start, end, steps=10):
        """Takes steps between start and end, returns steps+1 loss entries"""
        param0 = param.data.clone()
        param0v = u.vec(param0).t()
        losses = []
        for i in range(steps + 1):
            output = model(stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            losses.append(loss)
            offset = start + i * ((end - start) / steps)
            param1v = param0v + offset * directionv
            param1 = u.unvec(param1v.t(), param.data.shape[0])
            param.data.copy_(param1)

        output = model(stats_data)  # use last saved data batch for backprop
        loss = compute_loss(output, stats_targets)
        losses.append(loss)

        param.data.copy_(param0)
        return losses

    # try to take a newton step
    gradv = g
    line_losses = line_search(-gradv @ u.pinv(H), 0, 2, steps=10)
    u.check_equal(line_losses[0], loss)
    u.check_equal(line_losses[6], 0)
    assert line_losses[5] > line_losses[6]
    assert line_losses[7] > line_losses[6]

    return stats
def main():
    attemp_count = 0
    while os.path.exists(f"{args.logdir}{attemp_count:02d}"):
        attemp_count += 1
    logdir = f"{args.logdir}{attemp_count:02d}"

    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")
    u.seed_random(1)

    d1 = args.data_width**2
    d2 = 10
    d3 = args.targets_width**2
    o = d3
    n = args.stats_batch_size
    d = [d1, d2, d3]
    model = u.SimpleFullyConnected(d, nonlin=args.nonlin)
    model = model.to(gl.device)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['d1'] = d1
            wandb.config['d2'] = d2
            wandb.config['d3'] = d3
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)

    dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, dataset_size=args.dataset_size, train=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.stats_batch_size, shuffle=True, drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    skip_forward_hooks = False
    skip_backward_hooks = False

    def capture_activations(module: nn.Module, input: List[torch.Tensor], output: torch.Tensor):
        if skip_forward_hooks:
            return
        assert not hasattr(module, 'activations'), "Seeing results of previous autograd, call util.zero_grad to clear"
        assert len(input) == 1, "this was tested for single input layers only"
        setattr(module, "activations", input[0].detach())
        setattr(module, "output", output.detach())

    def capture_backprops(module: nn.Module, _input, output):
        if skip_backward_hooks:
            return
        assert len(output) == 1, "this works for single variable layers only"
        if gl.backward_idx == 0:
            assert not hasattr(module, 'backprops'), "Seeing results of previous autograd, call util.zero_grad to clear"
            setattr(module, 'backprops', [])
        assert gl.backward_idx == len(module.backprops)
        module.backprops.append(output[0])

    def save_grad(param: nn.Parameter) -> Callable[[torch.Tensor], None]:
        """Hook to save gradient into 'param.saved_grad', so it can be accessed after model.zero_grad().
        Only stores gradient if the value has not been set, call util.zero_grad to clear it."""
        def save_grad_fn(grad):
            if not hasattr(param, 'saved_grad'):
                setattr(param, 'saved_grad', grad)
        return save_grad_fn

    for layer in model.layers:
        layer.register_forward_hook(capture_activations)
        layer.register_backward_hook(capture_backprops)
        layer.weight.register_hook(save_grad(layer.weight))

    def loss_fn(data, targets):
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == len(targets)
        return torch.sum(err * err) / 2 / len(data)

    gl.token_count = 0
    last_outer = 0
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars({"time/outer": 1000 * (time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        # compute validation loss
        skip_forward_hooks = True
        skip_backward_hooks = True
        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
        print("val_loss", val_loss.item())
        u.log_scalar(val_loss=val_loss.item())

        # compute stats
        data, targets = next(stats_iter)
        skip_forward_hooks = False
        skip_backward_hooks = False

        # get gradient values
        with u.timeit("backprop_g"):
            gl.backward_idx = 0
            u.zero_grad(model)
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)

        # get Hessian values
        skip_forward_hooks = True
        id_mat = torch.eye(o).to(gl.device)

        u.log_scalar(loss=loss.item())

        with u.timeit("backprop_H"):
            # optionally use randomized low-rank approximation of Hessian
            hess_rank = args.hess_samples if args.hess_samples else o

            for out_idx in range(hess_rank):
                model.zero_grad()
                # backprop to get section of batch output jacobian for output at position out_idx
                output = model(data)  # opt: using autograd.grad means I don't have to zero_grad
                if args.hess_samples:
                    bval = torch.LongTensor(n, o).to(gl.device).random_(0, 2) * 2 - 1
                    bval = bval.float()
                else:
                    ei = id_mat[out_idx]
                    bval = torch.stack([ei] * n)
                gl.backward_idx = out_idx + 1
                output.backward(bval)
            skip_backward_hooks = True

        for (i, layer) in enumerate(model.layers):
            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad)

            s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel()
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g
                s.sigma_l2 = u.sym_l2_norm(sigma)
                s.sigma_erank = torch.trace(sigma) / s.sigma_l2

            #############################
            # Hessian stats
            #############################
            A_t = layer.activations
            Bh_t = [layer.backprops[out_idx + 1] for out_idx in range(hess_rank)]
            Amat_t = torch.cat([A_t] * hess_rank, dim=0)
            Bmat_t = torch.cat(Bh_t, dim=0)
            assert Amat_t.shape == (n * hess_rank, d[i])
            assert Bmat_t.shape == (n * hess_rank, d[i + 1])

            lambda_regularizer = args.lmb * torch.eye(d[i] * d[i + 1]).to(gl.device)
            with u.timeit(f"khatri_H-{i}"):
                Jb = u.khatri_rao_t(Bmat_t, Amat_t)  # batch Jacobian, in row-vec format

            with u.timeit(f"H-{i}"):
                H = Jb.t() @ Jb / n

            with u.timeit(f"invH-{i}"):
                invH = torch.cholesky_inverse(H + lambda_regularizer)

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.jacobian_fro = Jb.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1])
                u.check_close(H, H_autograd)
            # u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')

            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm()**2))

            with u.timeit("pinvH"):
                pinvH = u.pinv(H)

            with u.timeit(f'curv-{i}'):
                s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre', pinvH)  # save Newton preconditioner
                s.step_openai = 1 / s.grad_curv if s.grad_curv else 999
                s.step_max = 2 / u.sym_l2_norm(H)
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm()  # frobenius norm of Newton update
                s.regret_gradient = loss_direction(g, s.step_openai)

            with u.timeit(f'rho-{i}'):
                p_sigma = u.lyapunov_svd(H, sigma)
                if u.has_nan(p_sigma) and args.compute_rho:  # use expensive method
                    H0 = H.cpu().detach().numpy()
                    sigma0 = sigma.cpu().detach().numpy()
                    p_sigma = scipy.linalg.solve_lyapunov(H0, sigma0)
                    p_sigma = torch.tensor(p_sigma).to(gl.device)

                if u.has_nan(p_sigma):
                    s.psigma_erank = H.shape[0]
                    s.rho = 1
                else:
                    s.psigma_erank = u.sym_erank(p_sigma)
                    s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                print('openai batch', s.batch_openai)
                s.diversity = torch.norm(G, "fro")**2 / torch.norm(g)**2

                # s.noise_variance = torch.trace(H.inverse() @ sigma)
                # try:
                #     # this fails with singular sigma
                #     s.noise_variance = torch.trace(torch.solve(sigma, H)[0])
                #     # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0])
                #     pass
                # except RuntimeError as _:
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        last_inner = 0
        for i in range(args.train_steps):
            if last_inner:
                u.log_scalars({"time/inner": 1000 * (time.perf_counter() - last_inner)})
            last_inner = time.perf_counter()

            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()

            u.log_scalar(train_loss=loss.item())

            if args.method != 'newton':
                optimizer.step()
            else:
                for (layer_idx, layer) in enumerate(model.layers):
                    param: torch.nn.Parameter = layer.weight
                    param_data: torch.Tensor = param.data
                    param_data.copy_(param_data - 0.1 * param.grad)
                    if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                        continue
                    u.nan_check(layer.weight.pre)
                    u.nan_check(param.grad.flatten())
                    u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                    param_new_flat = u.v2r(param_data.flatten()) - u.v2r(param.grad.flatten()) @ layer.weight.pre
                    u.nan_check(param_new_flat)
                    param_data.copy_(param_new_flat.reshape(param_data.shape))

            gl.token_count += data.shape[0]

    gl.event_writer.close()
def test_conv_hessian():
    """Test per-example gradient computation for conv layer."""
    u.seed_random(1)

    n, Xc, Xh, Xw = 3, 2, 3, 7
    dd = [Xc, 2]
    Kh, Kw = 2, 3
    Oh, Ow = Xh - Kh + 1, Xw - Kw + 1
    model: u.SimpleModel = u.ReshapedConvolutional(dd, kernel_size=(Kh, Kw), bias=True)
    weight_buffer = model.layers[0].weight.data
    assert (Kh, Kw) == model.layers[0].kernel_size
    data = torch.randn((n, Xc, Xh, Xw))

    # output channels, input channels, height, width
    assert weight_buffer.shape == (dd[1], dd[0], Kh, Kw)

    def loss_fn(data):
        err = data.reshape(len(data), -1)
        return torch.sum(err * err) / 2 / len(data)

    loss_hessian = u.HessianExactSqrLoss()

    # o = Oh * Ow * dd[1]
    output = model(data)
    o = output.shape[1]
    for bval in loss_hessian(output):
        output.backward(bval, retain_graph=True)
    assert loss_hessian.num_samples == o

    i, layer = next(enumerate(model.layers))

    At = unfold(layer.activations, (Kh, Kw))  # -> n, Xc * Kh * Kw, Oh * Ow
    assert At.shape == (n, dd[0] * Kh * Kw, Oh * Ow)

    # o, n, dd[1], Oh, Ow -> o, n, dd[1], Oh*Ow
    Bh_t = torch.stack([Bt.reshape(n, dd[1], Oh * Ow) for Bt in layer.backprops_list])
    assert Bh_t.shape == (o, n, dd[1], Oh * Ow)
    Ah_t = torch.stack([At] * o)
    assert Ah_t.shape == (o, n, dd[0] * Kh * Kw, Oh * Ow)

    # sum out the output patch dimension
    Jb = torch.einsum('onij,onkj->onik', Bh_t, Ah_t)   # => o, n, dd[1], dd[0] * Kh * Kw
    Hi = torch.einsum('onij,onkl->nijkl', Jb, Jb)      # => n, dd[1], dd[0]*Kh*Kw, dd[1], dd[0]*Kh*Kw
    Jb_bias = torch.einsum('onij->oni', Bh_t)
    Hb_i = torch.einsum('oni,onj->nij', Jb_bias, Jb_bias)

    H = Hi.mean(dim=0)
    Hb = Hb_i.mean(dim=0)

    model.disable_hooks()

    loss = loss_fn(model(data))
    H_autograd = u.hessian(loss, layer.weight)
    assert H_autograd.shape == (dd[1], dd[0], Kh, Kw, dd[1], dd[0], Kh, Kw)
    assert H.shape == (dd[1], dd[0] * Kh * Kw, dd[1], dd[0] * Kh * Kw)
    u.check_close(H, H_autograd.reshape(H.shape), rtol=1e-4, atol=1e-7)

    Hb_autograd = u.hessian(loss, layer.bias)
    assert Hb_autograd.shape == (dd[1], dd[1])
    u.check_close(Hb, Hb_autograd)

    assert len(Bh_t) == loss_hessian.num_samples == o
    for xi in range(n):
        loss = loss_fn(model(data[xi:xi + 1, ...]))
        H_autograd = u.hessian(loss, layer.weight)
        u.check_close(Hi[xi], H_autograd.reshape(H.shape))
        Hb_autograd = u.hessian(loss, layer.bias)
        u.check_close(Hb_i[xi], Hb_autograd)
        assert Hb_i[xi, 0, 0] == Oh * Ow  # each output has curvature 1, bias term adds up Oh*Ow of them
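# Illustrative sketch (hypothetical helper) of the unfold identity test_conv_hessian relies on:
# a convolution equals a matrix product between the flattened kernel and the unfolded (im2col)
# patches of the input, which is what lets the per-example Jacobian be built from At and Bh_t.
def _sketch_conv_as_unfold_matmul():
    torch.manual_seed(0)
    n, c_in, c_out, Xh, Xw, Kh, Kw = 2, 2, 3, 4, 5, 2, 3
    x = torch.randn(n, c_in, Xh, Xw)
    w = torch.randn(c_out, c_in, Kh, Kw)
    out = F.conv2d(x, w)                              # n, c_out, Oh, Ow
    patches = F.unfold(x, (Kh, Kw))                   # n, c_in*Kh*Kw, Oh*Ow
    out2 = (w.reshape(c_out, -1) @ patches).reshape(out.shape)
    assert torch.allclose(out, out2, atol=1e-5)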
def test_factored_hessian():
    """Simple test to ensure Hessian computation is working.

    In a linear neural network with squared loss, Newton step will converge in one step.
    Compute stats after minimizing, pass sanity checks.
    """
    u.seed_random(1)
    loss_type = 'LeastSquares'

    data_width = 2
    n = 5
    d1 = data_width ** 2
    o = 10
    d = [d1, o]
    model = u.SimpleFullyConnected2(d, bias=False, nonlin=False)
    model = model.to(gl.device)
    print(model)

    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=n, shuffle=False)
    stats_iter = u.infinite_iter(stats_loader)
    stats_data, stats_targets = next(stats_iter)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:  # loss_type == 'CrossEntropy'
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()
    last_outer = 0

    data, targets = stats_data, stats_targets

    # Capture Hessian and gradient stats
    autograd_lib.enable_hooks()
    autograd_lib.clear_backprops(model)

    output = model(data)
    loss = loss_fn(output, targets)
    print(loss)
    loss.backward(retain_graph=True)
    layer = model.layers[0]

    autograd_lib.clear_hess_backprops(model)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.disable_hooks()

    # compute Hessian using direct method, compare against PyTorch autograd
    hess0 = u.hessian(loss, layer.weight)
    autograd_lib.compute_hess(model)
    hess1 = layer.weight.hess
    print(hess1)
    u.check_close(hess0.reshape(hess1.shape), hess1, atol=1e-9, rtol=1e-6)

    # compute Hessian using factored method
    autograd_lib.compute_hess(model, method='kron', attr_name='hess2', vecr_order=True)

    # s.regret_newton = vecG.t() @ pinvH.commute() @ vecG.t() / 2  # TODO(y): figure out why needed transposes
    hess2 = layer.weight.hess2
    u.check_close(hess1, hess2, atol=1e-9, rtol=1e-6)

    # Newton step in regular notation
    g1 = layer.weight.grad.flatten()
    newton1 = hess1 @ g1

    g2 = u.Vecr(layer.weight.grad)
    newton2 = g2 @ hess2
    u.check_close(newton1, newton2, atol=1e-9, rtol=1e-6)

    # compute regret in factored notation, compare against actual drop in loss
    regret1 = g1 @ hess1.pinverse() @ g1 / 2
    regret2 = g2 @ hess2.pinv() @ g2 / 2
    u.check_close(regret1, regret2)

    current_weight = layer.weight.detach().clone()
    param: torch.nn.Parameter = layer.weight

    # param.data.sub_((hess1.pinverse() @ g1).reshape(param.shape))
    # output = model(data)
    # loss = loss_fn(output, targets)
    # print("result 1", loss)

    # param.data.sub_((hess1.pinverse() @ u.vec(layer.weight.grad)).reshape(param.shape))
    # output = model(data)
    # loss = loss_fn(output, targets)
    # print("result 2", loss)

    # param.data.sub_((u.vec(layer.weight.grad).t() @ hess1.pinverse()).reshape(param.shape))
    # output = model(data)
    # loss = loss_fn(output, targets)
    # print("result 3", loss)

    # del layer.weight.grad
    output = model(data)
    loss = loss_fn(output, targets)
    loss.backward()

    param.data.sub_(u.unvec(hess1.pinverse() @ u.vec(layer.weight.grad), layer.weight.shape[0]))
    output = model(data)
    loss = loss_fn(output, targets)
    print("result 4", loss)

    # param.data.sub_((g1 @ hess1.pinverse() @ g1).reshape(param.shape))
    print(loss)
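# Standalone sketch (assumed toy setup, not the model used in test_factored_hessian) of the
# docstring's claim: with a linear model and squared error the loss is exactly quadratic in the
# weights, so a single Newton step reaches a point with zero gradient.
def _sketch_newton_one_step():
    torch.manual_seed(0)
    n, d_in, d_out = 6, 4, 3
    X = torch.randn(n, d_in, dtype=torch.float64)
    T = torch.randn(n, d_out, dtype=torch.float64)
    W = torch.zeros(d_out, d_in, dtype=torch.float64, requires_grad=True)

    loss = 0.5 * ((X @ W.t() - T) ** 2).sum() / n
    g = torch.autograd.grad(loss, W)[0]
    # Hessian w.r.t. row-major vec(W) for this loss is kron(I_{d_out}, X.T @ X / n)
    H = torch.kron(torch.eye(d_out, dtype=torch.float64), X.t() @ X / n)
    W_new = (W.detach().flatten() - torch.linalg.pinv(H) @ g.flatten()).reshape(W.shape)

    # gradient at the post-Newton point is (numerically) zero, i.e. the minimum was reached
    W_check = W_new.clone().requires_grad_(True)
    g_new = torch.autograd.grad(0.5 * ((X @ W_check.t() - T) ** 2).sum() / n, W_check)[0]
    assert g_new.abs().max() < 1e-8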
def test_hessian():
    """Tests of Hessian computation."""
    u.seed_random(1)

    batch_size = 500
    data_width = 4
    targets_width = 4

    d1 = data_width ** 2
    d2 = 10
    d3 = targets_width ** 2
    o = d3
    N = batch_size
    d = [d1, d2, d3]

    dataset = u.TinyMNIST(data_width=data_width, targets_width=targets_width, dataset_size=batch_size)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)
    data, targets = next(train_iter)

    def loss_fn(data, targets):
        assert len(data) == len(targets)
        err = data - targets.view(-1, data.shape[1])
        return torch.sum(err * err) / 2 / len(data)

    u.seed_random(1)
    model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=False, bias=True)

    # backprop hessian and compare against autograd
    hessian_backprop = u.HessianExactSqrLoss()
    output = model(data)
    for bval in hessian_backprop(output):
        output.backward(bval, retain_graph=True)

    i, layer = next(enumerate(model.layers))
    A_t = layer.activations
    Bh_t = layer.backprops_list
    H, Hb = u.hessian_from_backprops(A_t, Bh_t, bias=True)

    model.disable_hooks()
    H_autograd = u.hessian(loss_fn(model(data), targets), layer.weight)
    u.check_close(H, H_autograd.reshape(d[i + 1] * d[i], d[i + 1] * d[i]), rtol=1e-4, atol=1e-7)
    Hb_autograd = u.hessian(loss_fn(model(data), targets), layer.bias)
    u.check_close(Hb, Hb_autograd, rtol=1e-4, atol=1e-7)

    # check first few per-example Hessians
    Hi, Hb_i = u.per_example_hess(A_t, Bh_t, bias=True)
    u.check_close(H, Hi.mean(dim=0))
    u.check_close(Hb, Hb_i.mean(dim=0), atol=2e-6, rtol=1e-5)
    for xi in range(5):
        loss = loss_fn(model(data[xi:xi + 1, ...]), targets[xi:xi + 1])
        H_autograd = u.hessian(loss, layer.weight)
        u.check_close(Hi[xi], H_autograd.reshape(d[i + 1] * d[i], d[i + 1] * d[i]))
        Hbias_autograd = u.hessian(loss, layer.bias)
        u.check_close(Hb_i[xi], Hbias_autograd)

    # get subsampled Hessian
    u.seed_random(1)
    model = u.SimpleFullyConnected(d, nonlin=False)
    hessian_backprop = u.HessianSampledSqrLoss(num_samples=1)

    output = model(data)
    for bval in hessian_backprop(output):
        output.backward(bval, retain_graph=True)
    model.disable_hooks()

    i, layer = next(enumerate(model.layers))
    H_approx1 = u.hessian_from_backprops(layer.activations, layer.backprops_list)

    # get subsampled Hessian with more samples
    u.seed_random(1)
    model = u.SimpleFullyConnected(d, nonlin=False)
    hessian_backprop = u.HessianSampledSqrLoss(num_samples=o)
    output = model(data)
    for bval in hessian_backprop(output):
        output.backward(bval, retain_graph=True)
    model.disable_hooks()

    i, layer = next(enumerate(model.layers))
    H_approx2 = u.hessian_from_backprops(layer.activations, layer.backprops_list)

    assert abs(u.l2_norm(H) / u.l2_norm(H_approx1) - 1) < 0.08, abs(u.l2_norm(H) / u.l2_norm(H_approx1) - 1)  # 0.0612
    assert abs(u.l2_norm(H) / u.l2_norm(H_approx2) - 1) < 0.03, abs(u.l2_norm(H) / u.l2_norm(H_approx2) - 1)  # 0.0239
    assert u.kl_div_cov(H_approx1, H) < 0.3, u.kl_div_cov(H_approx1, H)  # 0.222
    assert u.kl_div_cov(H_approx2, H) < 0.2, u.kl_div_cov(H_approx2, H)  # 0.1233
def test_hessian_multibatch():
    """Test that Kronecker-factored computations still work when splitting work over batches."""
    u.seed_random(1)
    # torch.set_default_dtype(torch.float64)

    gl.project_name = 'test'
    gl.logdir_base = '/tmp/runs'
    run_name = 'test_hessian_multibatch'
    u.setup_logdir_and_event_writer(run_name=run_name)

    loss_type = 'CrossEntropy'

    data_width = 2
    n = 4
    d1 = data_width ** 2
    o = 10
    d = [d1, o]
    model = u.SimpleFullyConnected2(d, bias=False, nonlin=False)
    model = model.to(gl.device)

    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=n, shuffle=False)
    stats_iter = u.infinite_iter(stats_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:  # loss_type == 'CrossEntropy'
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()
    last_outer = 0

    stats_iter = u.infinite_iter(stats_loader)
    stats_data, stats_targets = next(stats_iter)
    data, targets = stats_data, stats_targets

    # Capture Hessian and gradient stats
    autograd_lib.enable_hooks()
    autograd_lib.clear_backprops(model)

    output = model(data)
    loss = loss_fn(output, targets)
    loss.backward(retain_graph=True)
    layer = model.layers[0]

    autograd_lib.clear_hess_backprops(model)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.disable_hooks()

    # compute Hessian using direct method, compare against PyTorch autograd
    hess0 = u.hessian(loss, layer.weight)
    autograd_lib.compute_hess(model)
    hess1 = layer.weight.hess
    u.check_close(hess0.reshape(hess1.shape), hess1, atol=1e-8, rtol=1e-6)

    # compute Hessian using factored method. Because Hessian depends on examples for cross entropy,
    # factoring is not exact, raise tolerance
    autograd_lib.compute_hess(model, method='kron', attr_name='hess2', vecr_order=True)
    hess2 = layer.weight.hess2
    u.check_close(hess1, hess2, atol=1e-3, rtol=1e-1)

    # compute Hessian using multibatch
    # restart iterators
    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    assert n % 2 == 0
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=n // 2, shuffle=False)
    stats_iter = u.infinite_iter(stats_loader)

    autograd_lib.compute_cov(model, loss_fn, stats_iter, batch_size=n // 2, steps=2)
    cov: autograd_lib.LayerCov = layer.cov

    hess2: u.Kron = hess2.commute()  # get back into AA x BB order
    u.check_close(cov.H.value(), hess2)
def _test_explicit_hessian_refactored():
    """Check computation of hessian of loss(B'WA) from
    https://github.com/yaroslavvb/kfac_pytorch/blob/master/derivation.pdf
    """
    torch.set_default_dtype(torch.float64)
    A = torch.tensor([[-1., 4], [3, 0]])
    B = torch.tensor([[-4., 3], [2, 6]])
    X = torch.tensor([[-5., 0], [-2, -6]], requires_grad=True)

    Y = B.t() @ X @ A
    u.check_equal(Y, [[-52, 64], [-81, -108]])
    loss = torch.sum(Y * Y) / 2
    hess0 = u.hessian(loss, X).reshape([4, 4])
    hess1 = u.Kron(A @ A.t(), B @ B.t())
    u.check_equal(loss, 12512.5)

    # Do a test using Linear layers instead of matrix multiplies
    model: u.SimpleFullyConnected2 = u.SimpleFullyConnected2([2, 2, 2], bias=False)
    model.layers[0].weight.data.copy_(X)

    # Transpose to match previous results, layers treat dim0 as batch dimension
    u.check_equal(model.layers[0](A.t()).t(), [[5, -20], [-16, -8]])  # XA = (A'X0)'

    model.layers[1].weight.data.copy_(B.t())
    u.check_equal(model(A.t()).t(), Y)

    Y = model(A.t()).t()  # transpose to data-dimension=columns
    loss = torch.sum(Y * Y) / 2
    loss.backward()

    u.check_equal(model.layers[0].weight.grad, [[-2285, -105], [-1490, -1770]])
    G = B @ Y @ A.t()
    u.check_equal(model.layers[0].weight.grad, G)

    autograd_lib.register(model)

    activations_dict = autograd_lib.ModuleDict()  # todo(y): make save_activations ctx manager automatically create A
    with autograd_lib.save_activations(activations_dict):
        Y = model(A.t())

    Acov = autograd_lib.ModuleDict(autograd_lib.SecondOrderCov)
    for layer, activations in activations_dict.items():
        print(layer, activations)
        Acov[layer].accumulate(activations, activations)
    autograd_lib.set_default_activations(activations_dict)
    autograd_lib.set_default_Acov(Acov)

    B = autograd_lib.ModuleDict(autograd_lib.SymmetricFourthOrderCov)
    autograd_lib.backward_accum(Y, "identity", B, retain_graph=False)

    print(B[model.layers[0]])

    autograd_lib.backprop_hess(Y, hess_type='LeastSquares')
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron', vecr_order=False, loss_aggregation='sum')
    param = model.layers[0].weight

    hess2 = param.hess_kron
    print(hess2)
    u.check_equal(hess2, [[425, 170, -75, -30], [170, 680, -30, -120], [-75, -30, 225, 90], [-30, -120, 90, 360]])

    # Gradient test
    model.zero_grad()
    loss.backward()
    u.check_close(u.vec(G).flatten(), u.Vec(param.grad))

    # Newton step test
    # Method 0: PyTorch native autograd
    newton_step0 = param.grad.flatten() @ torch.pinverse(hess0)
    newton_step0 = newton_step0.reshape(param.shape)
    u.check_equal(newton_step0, [[-5, 0], [-2, -6]])

    # Method 1: column major order
    ihess2 = hess2.pinv()
    u.check_equal(ihess2.LL, [[1 / 16, 1 / 48], [1 / 48, 17 / 144]])
    u.check_equal(ihess2.RR, [[2 / 45, -(1 / 90)], [-(1 / 90), 1 / 36]])
    u.check_equal(torch.flatten(hess2.pinv() @ u.vec(G)), [-5, -2, 0, -6])
    newton_step1 = (ihess2 @ u.Vec(param.grad)).matrix_form()

    # Method 2: row major order
    ihess2_rowmajor = ihess2.commute()
    newton_step2 = ihess2_rowmajor @ u.Vecr(param.grad)
    newton_step2 = newton_step2.matrix_form()

    u.check_equal(newton_step0, newton_step1)
    u.check_equal(newton_step0, newton_step2)
def test_main_autograd(): u.seed_random(1) log_wandb = False autograd_check = True use_double = False logdir = u.get_unique_logdir('/tmp/autoencoder_test/run') run_name = os.path.basename(logdir) gl.event_writer = SummaryWriter(logdir) batch_size = 5 try: if log_wandb: wandb.init(project='test-autograd_test', name=run_name) wandb.tensorboard.patch(tensorboardX=False) wandb.config['batch'] = batch_size except Exception as e: print(f"wandb crash with {e}") data_width = 4 targets_width = 2 d1 = data_width ** 2 d2 = 10 d3 = targets_width ** 2 o = d3 n = batch_size d = [d1, d2, d3] model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=True, bias=True) if use_double: model = model.double() train_steps = 3 dataset = u.TinyMNIST(data_width=data_width, targets_width=targets_width, dataset_size=batch_size * train_steps) trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) train_iter = iter(trainloader) def loss_fn(data, targets): err = data - targets.view(-1, data.shape[1]) assert len(data) == batch_size return torch.sum(err * err) / 2 / len(data) loss_hessian = u.HessianExactSqrLoss() gl.token_count = 0 for train_step in range(train_steps): data, targets = next(train_iter) if use_double: data, targets = data.double(), targets.double() # get gradient values model.skip_backward_hooks = False model.skip_forward_hooks = False u.clear_backprops(model) output = model(data) loss = loss_fn(output, targets) loss.backward(retain_graph=True) model.skip_forward_hooks = True output = model(data) for bval in loss_hessian(output): if use_double: bval = bval.double() output.backward(bval, retain_graph=True) model.skip_backward_hooks = True for (i, layer) in enumerate(model.layers): ############################# # Gradient stats ############################# A_t = layer.activations assert A_t.shape == (n, d[i]) # add factor of n because backprop takes loss averaged over batch, while we need per-example loss B_t = layer.backprops_list[0] * n assert B_t.shape == (n, d[i + 1]) # per example gradients G = u.khatri_rao_t(B_t, A_t) assert G.shape == (n, d[i+1] * d[i]) Gbias = B_t assert Gbias.shape == (n, d[i + 1]) # average gradient g = G.sum(dim=0, keepdim=True) / n gb = Gbias.sum(dim=0, keepdim=True) / n assert g.shape == (1, d[i] * d[i + 1]) assert gb.shape == (1, d[i + 1]) if autograd_check: u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad) u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad) u.check_close(torch.einsum('nj->j', B_t) / n, layer.bias.saved_grad) u.check_close(torch.mean(B_t, dim=0), layer.bias.saved_grad) u.check_close(torch.einsum('ni,nj->ij', B_t, A_t)/n, layer.weight.saved_grad) # empirical Fisher efisher = G.t() @ G / n _sigma = efisher - g.t() @ g ############################# # Hessian stats ############################# A_t = layer.activations Bh_t = [layer.backprops_list[out_idx + 1] for out_idx in range(o)] Amat_t = torch.cat([A_t] * o, dim=0) # todo: can instead replace with a khatri-rao loop Bmat_t = torch.cat(Bh_t, dim=0) Amat_t2 = torch.stack([A_t]*o, dim=0) # o, n, in_dim Bmat_t2 = torch.stack(Bh_t, dim=0) # o, n, out_dim assert Amat_t.shape == (n * o, d[i]) assert Bmat_t.shape == (n * o, d[i + 1]) Jb = u.khatri_rao_t(Bmat_t, Amat_t) # batch output Jacobian H = Jb.t() @ Jb / n Jb2 = torch.einsum('oni,onj->onij', Bmat_t2, Amat_t2) u.check_close(H.reshape(d[i+1], d[i], d[i+1], d[i]), torch.einsum('onij,onkl->ijkl', Jb2, Jb2)/n) Hbias = Bmat_t.t() @ Bmat_t / n u.check_close(Hbias, torch.einsum('ni,nj->ij', Bmat_t, Bmat_t) / n) if 
autograd_check: model.zero_grad() output = model(data) loss = loss_fn(output, targets) H_autograd = u.hessian(loss, layer.weight) Hbias_autograd = u.hessian(loss, layer.bias) u.check_close(H, H_autograd.reshape(d[i+1] * d[i], d[i+1] * d[i])) u.check_close(Hbias, Hbias_autograd)
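# A small plain-torch sketch of the per-example gradient identity used above: for a linear
# layer, each example's gradient is the outer product of its backprop row and its activation
# row, and the per-example gradients sum to the .grad that autograd stores. Names are
# illustrative only; nothing from this module is used.
def _sketch_per_example_gradients():
    torch.manual_seed(0)
    n, d_in, d_out = 4, 3, 2
    W = torch.randn(d_out, d_in, requires_grad=True)
    X = torch.randn(n, d_in)
    Y = X @ W.t()
    loss = torch.sum(Y * Y) / (2 * n)              # batch-averaged least-squares style loss
    loss.backward()
    Bp = (Y / n).detach()                          # dloss/dY, one backprop row per example
    per_example = torch.einsum('no,ni->noi', Bp, X)
    assert torch.allclose(per_example.sum(dim=0), W.grad, atol=1e-6)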
def main(): u.seed_random(1) logdir = u.create_local_logdir(args.logdir) run_name = os.path.basename(logdir) gl.event_writer = SummaryWriter(logdir) print(f"Logging to {run_name}") d1 = args.data_width ** 2 assert args.data_width == args.targets_width o = d1 n = args.stats_batch_size d = [d1, 30, 30, 30, 20, 30, 30, 30, d1] # small values for debugging # loss_type = 'LeastSquares' loss_type = 'CrossEntropy' args.wandb = 0 args.stats_steps = 10 args.train_steps = 10 args.stats_batch_size = 10 args.data_width = 2 args.targets_width = 2 args.nonlin = False d1 = args.data_width ** 2 d2 = 2 d3 = args.targets_width ** 2 if loss_type == 'CrossEntropy': d3 = 10 o = d3 n = args.stats_batch_size d = [d1, d2, d3] dsize = max(args.train_batch_size, args.stats_batch_size)+1 model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin) model = model.to(gl.device) try: # os.environ['WANDB_SILENT'] = 'true' if args.wandb: wandb.init(project='curv_train_tiny', name=run_name) wandb.tensorboard.patch(tensorboardX=False) wandb.config['train_batch'] = args.train_batch_size wandb.config['stats_batch'] = args.stats_batch_size wandb.config['method'] = args.method wandb.config['n'] = n except Exception as e: print(f"wandb crash with {e}") #optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9) optimizer = torch.optim.Adam(model.parameters(), lr=0.03) # make 10x smaller for least-squares loss dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, dataset_size=dsize, original_targets=True) train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True) train_iter = u.infinite_iter(train_loader) stats_iter = None if not args.full_batch: stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True) stats_iter = u.infinite_iter(stats_loader) test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, train=False, dataset_size=dsize, original_targets=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True) test_iter = u.infinite_iter(test_loader) if loss_type == 'LeastSquares': loss_fn = u.least_squares elif loss_type == 'CrossEntropy': loss_fn = nn.CrossEntropyLoss() autograd_lib.add_hooks(model) gl.token_count = 0 last_outer = 0 val_losses = [] for step in range(args.stats_steps): if last_outer: u.log_scalars({"time/outer": 1000*(time.perf_counter() - last_outer)}) last_outer = time.perf_counter() with u.timeit("val_loss"): test_data, test_targets = next(test_iter) test_output = model(test_data) val_loss = loss_fn(test_output, test_targets) print("val_loss", val_loss.item()) val_losses.append(val_loss.item()) u.log_scalar(val_loss=val_loss.item()) # compute stats if args.full_batch: data, targets = dataset.data, dataset.targets else: data, targets = next(stats_iter) # Capture Hessian and gradient stats autograd_lib.enable_hooks() autograd_lib.clear_backprops(model) autograd_lib.clear_hess_backprops(model) with u.timeit("backprop_g"): output = model(data) loss = loss_fn(output, targets) loss.backward(retain_graph=True) with u.timeit("backprop_H"): autograd_lib.backprop_hess(output, hess_type=loss_type) autograd_lib.disable_hooks() # TODO(y): use remove_hooks with u.timeit("compute_grad1"): autograd_lib.compute_grad1(model) with u.timeit("compute_hess"): autograd_lib.compute_hess(model) for (i, layer) in enumerate(model.layers): # input/output layers are unreasonably 
expensive if not using Kronecker factoring if d[i]>50 or d[i+1]>50: print(f'layer {i} is too big ({d[i],d[i+1]}), skipping stats') continue if args.skip_stats: continue s = AttrDefault(str, {}) # dictionary-like object for layer stats ############################# # Gradient stats ############################# A_t = layer.activations assert A_t.shape == (n, d[i]) # add factor of n because backprop takes loss averaged over batch, while we need per-example loss B_t = layer.backprops_list[0] * n assert B_t.shape == (n, d[i + 1]) with u.timeit(f"khatri_g-{i}"): G = u.khatri_rao_t(B_t, A_t) # batch loss Jacobian assert G.shape == (n, d[i] * d[i + 1]) g = G.sum(dim=0, keepdim=True) / n # average gradient assert g.shape == (1, d[i] * d[i + 1]) u.check_equal(G.reshape(layer.weight.grad1.shape), layer.weight.grad1) if args.autograd_check: u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad) u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad) s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel() # proportion of activations that are zero s.mean_activation = torch.mean(A_t) s.mean_backprop = torch.mean(B_t) # empirical Fisher with u.timeit(f'sigma-{i}'): efisher = G.t() @ G / n sigma = efisher - g.t() @ g s.sigma_l2 = u.sym_l2_norm(sigma) s.sigma_erank = torch.trace(sigma)/s.sigma_l2 lambda_regularizer = args.lmb * torch.eye(d[i + 1]*d[i]).to(gl.device) H = layer.weight.hess with u.timeit(f"invH-{i}"): invH = torch.cholesky_inverse(H+lambda_regularizer) with u.timeit(f"H_l2-{i}"): s.H_l2 = u.sym_l2_norm(H) s.iH_l2 = u.sym_l2_norm(invH) with u.timeit(f"norms-{i}"): s.H_fro = H.flatten().norm() s.iH_fro = invH.flatten().norm() s.grad_fro = g.flatten().norm() s.param_fro = layer.weight.data.flatten().norm() u.nan_check(H) if args.autograd_check: model.zero_grad() output = model(data) loss = loss_fn(output, targets) H_autograd = u.hessian(loss, layer.weight) H_autograd = H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1]) u.check_close(H, H_autograd) # u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}') def loss_direction(dd: torch.Tensor, eps): """loss improvement if we take step eps in direction dd""" return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps ** 2 * dd @ H @ dd.t()) def curv_direction(dd: torch.Tensor): """Curvature in direction dd""" return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm() ** 2)) with u.timeit(f"pinvH-{i}"): pinvH = u.pinv(H) with u.timeit(f'curv-{i}'): s.grad_curv = curv_direction(g) ndir = g @ pinvH # newton direction s.newton_curv = curv_direction(ndir) setattr(layer.weight, 'pre', pinvH) # save Newton preconditioner s.step_openai = s.grad_fro**2 / s.grad_curv if s.grad_curv else 999 s.step_max = 2 / s.H_l2 s.step_min = torch.tensor(2) / torch.trace(H) s.newton_fro = ndir.flatten().norm() # frobenius norm of Newton update s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2) # replace with "quadratic_form" s.regret_gradient = loss_direction(g, s.step_openai) with u.timeit(f'rho-{i}'): p_sigma = u.lyapunov_svd(H, sigma) if u.has_nan(p_sigma) and args.compute_rho: # use expensive method print('using expensive method') import pdb; pdb.set_trace() H0, sigma0 = u.to_numpys(H, sigma) p_sigma = scipy.linalg.solve_lyapunov(H0, sigma0) p_sigma = torch.tensor(p_sigma).to(gl.device) if u.has_nan(p_sigma): # import pdb; pdb.set_trace() s.psigma_erank = H.shape[0] s.rho = 1 else: s.psigma_erank = u.sym_erank(p_sigma) s.rho = H.shape[0] / s.psigma_erank with u.timeit(f"batch-{i}"): s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t()) 
s.diversity = torch.norm(G, "fro") ** 2 / torch.norm(g) ** 2 / n # Faster approaches for noise variance computation # s.noise_variance = torch.trace(H.inverse() @ sigma) # try: # # this fails with singular sigma # s.noise_variance = torch.trace(torch.solve(sigma, H)[0]) # # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0]) # pass # except RuntimeError as _: s.noise_variance_pinv = torch.trace(pinvH @ sigma) s.H_erank = torch.trace(H) / s.H_l2 s.batch_jain_simple = 1 + s.H_erank s.batch_jain_full = 1 + s.rho * s.H_erank u.log_scalars(u.nest_stats(layer.name, s)) # gradient steps with u.timeit('inner'): for i in range(args.train_steps): optimizer.zero_grad() data, targets = next(train_iter) model.zero_grad() output = model(data) loss = loss_fn(output, targets) loss.backward() # u.log_scalar(train_loss=loss.item()) if args.method != 'newton': optimizer.step() if args.weight_decay: for group in optimizer.param_groups: for param in group['params']: param.data.mul_(1-args.weight_decay) else: for (layer_idx, layer) in enumerate(model.layers): param: torch.nn.Parameter = layer.weight param_data: torch.Tensor = param.data param_data.copy_(param_data - 0.1 * param.grad) if layer_idx != 1: # only update 1 layer with Newton, unstable otherwise continue u.nan_check(layer.weight.pre) u.nan_check(param.grad.flatten()) u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre) param_new_flat = u.v2r(param_data.flatten()) - u.v2r(param.grad.flatten()) @ layer.weight.pre u.nan_check(param_new_flat) param_data.copy_(param_new_flat.reshape(param_data.shape)) gl.token_count += data.shape[0] gl.event_writer.close()
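# Toy check, under the assumption of an exact quadratic model, of the step-size bookkeeping
# above: along a gradient g with curvature H, the decrease eps*g'g - 0.5*eps**2*g'Hg is
# maximized at eps = g'g / g'Hg, which is the quantity recorded as s.step_openai
# (grad_fro**2 / grad_curv). Helper name and tensors are made up.
def _sketch_step_openai():
    torch.manual_seed(0)
    M = torch.randn(5, 5)
    H = M @ M.t() + torch.eye(5)                   # random SPD curvature
    g = torch.randn(1, 5)                          # row-vector gradient, as in the code above
    curv = (g @ H @ g.t()).item()
    eps_star = (g @ g.t()).item() / curv
    decrease = lambda eps: eps * (g @ g.t()).item() - 0.5 * eps ** 2 * curv
    assert decrease(eps_star) >= decrease(0.5 * eps_star)
    assert decrease(eps_star) >= decrease(1.5 * eps_star)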
def test_main(): parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', help='input batch size for testing (default: 1000)') parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs to train (default: 10)') parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate (default: 0.01)') parser.add_argument('--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)') parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument( '--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status') parser.add_argument('--save-model', action='store_true', default=False, help='For Saving the current Model') parser.add_argument('--wandb', type=int, default=1, help='log to weights and biases') parser.add_argument('--autograd_check', type=int, default=0, help='autograd correctness checks') parser.add_argument('--logdir', type=str, default='/temp/runs/curv_train_tiny/run') parser.add_argument('--train_batch_size', type=int, default=100) parser.add_argument('--stats_batch_size', type=int, default=60000) parser.add_argument('--dataset_size', type=int, default=60000) parser.add_argument('--train_steps', type=int, default=100, help="this many train steps between stat collection") parser.add_argument('--stats_steps', type=int, default=1000000, help="total number of curvature stats collections") parser.add_argument('--nonlin', type=int, default=1, help="whether to add ReLU nonlinearity between layers") parser.add_argument('--method', type=str, choices=['gradient', 'newton'], default='gradient', help="descent method, newton or gradient") parser.add_argument('--layer', type=int, default=-1, help="restrict updates to this layer") parser.add_argument('--data_width', type=int, default=28) parser.add_argument('--targets_width', type=int, default=28) parser.add_argument('--lmb', type=float, default=1e-3) parser.add_argument( '--hess_samples', type=int, default=1, help='number of samples when sub-sampling outputs, 0 for exact hessian' ) parser.add_argument('--hess_kfac', type=int, default=0, help='whether to use KFAC approximation for hessian') parser.add_argument('--compute_rho', type=int, default=1, help='use expensive method to compute rho') parser.add_argument('--skip_stats', type=int, default=0, help='skip all stats collection') parser.add_argument('--full_batch', type=int, default=0, help='do stats on the whole dataset') parser.add_argument('--weight_decay', type=float, default=1e-4) #args = parser.parse_args() args = AttrDict() args.lmb = 1e-3 args.compute_rho = 1 args.weight_decay = 1e-4 args.method = 'gradient' args.logdir = '/tmp' args.data_width = 2 args.targets_width = 2 args.train_batch_size = 10 args.full_batch = False args.skip_stats = False args.autograd_check = False u.seed_random(1) logdir = u.create_local_logdir(args.logdir) run_name = os.path.basename(logdir) #gl.event_writer = SummaryWriter(logdir) gl.event_writer = u.NoOp() # print(f"Logging to {run_name}") # small values for debugging # loss_type = 'LeastSquares' loss_type = 'CrossEntropy' args.wandb = 0 args.stats_steps = 10 args.train_steps = 10 args.stats_batch_size = 10 args.data_width = 2 args.targets_width = 2 args.nonlin = False d1 = args.data_width**2 d2 = 2 
d3 = args.targets_width**2 d1 = args.data_width**2 assert args.data_width == args.targets_width o = d1 n = args.stats_batch_size d = [d1, 30, 30, 30, 20, 30, 30, 30, d1] if loss_type == 'CrossEntropy': d3 = 10 o = d3 n = args.stats_batch_size d = [d1, d2, d3] dsize = max(args.train_batch_size, args.stats_batch_size) + 1 model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin) model = model.to(gl.device) try: # os.environ['WANDB_SILENT'] = 'true' if args.wandb: wandb.init(project='curv_train_tiny', name=run_name) wandb.tensorboard.patch(tensorboardX=False) wandb.config['train_batch'] = args.train_batch_size wandb.config['stats_batch'] = args.stats_batch_size wandb.config['method'] = args.method wandb.config['n'] = n except Exception as e: print(f"wandb crash with {e}") # optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9) optimizer = torch.optim.Adam( model.parameters(), lr=0.03) # make 10x smaller for least-squares loss dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, dataset_size=dsize, original_targets=True) train_loader = torch.utils.data.DataLoader( dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True) train_iter = u.infinite_iter(train_loader) stats_iter = None if not args.full_batch: stats_loader = torch.utils.data.DataLoader( dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True) stats_iter = u.infinite_iter(stats_loader) test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, train=False, dataset_size=dsize, original_targets=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True) test_iter = u.infinite_iter(test_loader) if loss_type == 'LeastSquares': loss_fn = u.least_squares elif loss_type == 'CrossEntropy': loss_fn = nn.CrossEntropyLoss() autograd_lib.add_hooks(model) gl.token_count = 0 last_outer = 0 val_losses = [] for step in range(args.stats_steps): if last_outer: u.log_scalars( {"time/outer": 1000 * (time.perf_counter() - last_outer)}) last_outer = time.perf_counter() with u.timeit("val_loss"): test_data, test_targets = next(test_iter) test_output = model(test_data) val_loss = loss_fn(test_output, test_targets) # print("val_loss", val_loss.item()) val_losses.append(val_loss.item()) u.log_scalar(val_loss=val_loss.item()) # compute stats if args.full_batch: data, targets = dataset.data, dataset.targets else: data, targets = next(stats_iter) # Capture Hessian and gradient stats autograd_lib.enable_hooks() autograd_lib.clear_backprops(model) autograd_lib.clear_hess_backprops(model) with u.timeit("backprop_g"): output = model(data) loss = loss_fn(output, targets) loss.backward(retain_graph=True) with u.timeit("backprop_H"): autograd_lib.backprop_hess(output, hess_type=loss_type) autograd_lib.disable_hooks() # TODO(y): use remove_hooks with u.timeit("compute_grad1"): autograd_lib.compute_grad1(model) with u.timeit("compute_hess"): autograd_lib.compute_hess(model) for (i, layer) in enumerate(model.layers): # input/output layers are unreasonably expensive if not using Kronecker factoring if d[i] > 50 or d[i + 1] > 50: print( f'layer {i} is too big ({d[i], d[i + 1]}), skipping stats') continue if args.skip_stats: continue s = AttrDefault(str, {}) # dictionary-like object for layer stats ############################# # Gradient stats ############################# A_t = layer.activations assert A_t.shape == (n, d[i]) # add factor of n because backprop takes loss averaged over batch, 
while we need per-example loss B_t = layer.backprops_list[0] * n assert B_t.shape == (n, d[i + 1]) with u.timeit(f"khatri_g-{i}"): G = u.khatri_rao_t(B_t, A_t) # batch loss Jacobian assert G.shape == (n, d[i] * d[i + 1]) g = G.sum(dim=0, keepdim=True) / n # average gradient assert g.shape == (1, d[i] * d[i + 1]) u.check_equal(G.reshape(layer.weight.grad1.shape), layer.weight.grad1) if args.autograd_check: u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad) u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad) s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel( ) # proportion of activations that are zero s.mean_activation = torch.mean(A_t) s.mean_backprop = torch.mean(B_t) # empirical Fisher with u.timeit(f'sigma-{i}'): efisher = G.t() @ G / n sigma = efisher - g.t() @ g s.sigma_l2 = u.sym_l2_norm(sigma) s.sigma_erank = torch.trace(sigma) / s.sigma_l2 lambda_regularizer = args.lmb * torch.eye(d[i + 1] * d[i]).to( gl.device) H = layer.weight.hess with u.timeit(f"invH-{i}"): invH = torch.cholesky_inverse(H + lambda_regularizer) with u.timeit(f"H_l2-{i}"): s.H_l2 = u.sym_l2_norm(H) s.iH_l2 = u.sym_l2_norm(invH) with u.timeit(f"norms-{i}"): s.H_fro = H.flatten().norm() s.iH_fro = invH.flatten().norm() s.grad_fro = g.flatten().norm() s.param_fro = layer.weight.data.flatten().norm() u.nan_check(H) if args.autograd_check: model.zero_grad() output = model(data) loss = loss_fn(output, targets) H_autograd = u.hessian(loss, layer.weight) H_autograd = H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1]) u.check_close(H, H_autograd) # u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}') def loss_direction(dd: torch.Tensor, eps): """loss improvement if we take step eps in direction dd""" return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps**2 * dd @ H @ dd.t()) def curv_direction(dd: torch.Tensor): """Curvature in direction dd""" return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm()**2)) with u.timeit(f"pinvH-{i}"): pinvH = H.pinverse() with u.timeit(f'curv-{i}'): s.grad_curv = curv_direction(g) ndir = g @ pinvH # newton direction s.newton_curv = curv_direction(ndir) setattr(layer.weight, 'pre', pinvH) # save Newton preconditioner s.step_openai = s.grad_fro**2 / s.grad_curv if s.grad_curv else 999 s.step_max = 2 / s.H_l2 s.step_min = torch.tensor(2) / torch.trace(H) s.newton_fro = ndir.flatten().norm( ) # frobenius norm of Newton update s.regret_newton = u.to_python_scalar( g @ pinvH @ g.t() / 2) # replace with "quadratic_form" s.regret_gradient = loss_direction(g, s.step_openai) with u.timeit(f'rho-{i}'): p_sigma = u.lyapunov_spectral(H, sigma) discrepancy = torch.max(abs(p_sigma - p_sigma.t()) / p_sigma) s.psigma_erank = u.sym_erank(p_sigma) s.rho = H.shape[0] / s.psigma_erank with u.timeit(f"batch-{i}"): s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t()) s.diversity = torch.norm(G, "fro")**2 / torch.norm(g)**2 / n # Faster approaches for noise variance computation # s.noise_variance = torch.trace(H.inverse() @ sigma) # try: # # this fails with singular sigma # s.noise_variance = torch.trace(torch.solve(sigma, H)[0]) # # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0]) # pass # except RuntimeError as _: s.noise_variance_pinv = torch.trace(pinvH @ sigma) s.H_erank = torch.trace(H) / s.H_l2 s.batch_jain_simple = 1 + s.H_erank s.batch_jain_full = 1 + s.rho * s.H_erank u.log_scalars(u.nest_stats(layer.name, s)) # gradient steps with u.timeit('inner'): for i in range(args.train_steps): optimizer.zero_grad() data, targets = next(train_iter) model.zero_grad() 
output = model(data) loss = loss_fn(output, targets) loss.backward() # u.log_scalar(train_loss=loss.item()) if args.method != 'newton': optimizer.step() if args.weight_decay: for group in optimizer.param_groups: for param in group['params']: param.data.mul_(1 - args.weight_decay) else: for (layer_idx, layer) in enumerate(model.layers): param: torch.nn.Parameter = layer.weight param_data: torch.Tensor = param.data param_data.copy_(param_data - 0.1 * param.grad) if layer_idx != 1: # only update 1 layer with Newton, unstable otherwise continue u.nan_check(layer.weight.pre) u.nan_check(param.grad.flatten()) u.nan_check( u.v2r(param.grad.flatten()) @ layer.weight.pre) param_new_flat = u.v2r(param_data.flatten()) - u.v2r( param.grad.flatten()) @ layer.weight.pre u.nan_check(param_new_flat) param_data.copy_( param_new_flat.reshape(param_data.shape)) gl.token_count += data.shape[0] gl.event_writer.close() assert val_losses[0] > 2.4 # 2.4828238487243652 assert val_losses[-1] < 2.25 # 2.20609712600708
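# Illustrative check, assuming a pure quadratic objective, of the Newton branch above: one
# update of the form w - g @ pinv(H) lands exactly at the minimizer of 0.5*w'Hw - b'w.
# All names are made up; only torch.pinverse is used.
def _sketch_newton_step():
    torch.manual_seed(0)
    M = torch.randn(4, 4)
    H = M @ M.t() + torch.eye(4)                   # SPD Hessian
    b = torch.randn(4)
    w = torch.randn(4)
    g = H @ w - b                                  # gradient of the quadratic at w
    w_new = w - g @ torch.pinverse(H)              # row-vector update, as in the loop above
    assert torch.allclose(H @ w_new, b, atol=1e-5)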
def _test_kron_conv_golden():
    """Hardcoded error values to detect unexpected numeric changes."""

    u.seed_random(1)
    n, Xh, Xw = 2, 8, 8
    Kh, Kw = 2, 2
    dd = [3, 3, 3, 3]
    o = dd[-1]
    model: u.SimpleModel = u.PooledConvolutional2(dd, kernel_size=(Kh, Kw), nonlin=False, bias=True)
    data = torch.randn((n, dd[0], Xh, Xw))
    # print(model)
    # print(data)

    loss_type = 'CrossEntropy'
    # loss_type = 'LeastSquares'

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:  # CrossEntropy
        loss_fn = nn.CrossEntropyLoss()

    sample_output = model(data)
    if loss_type.endswith('LeastSquares'):
        targets = torch.randn(sample_output.shape)
    elif loss_type == 'CrossEntropy':
        targets = torch.LongTensor(n).random_(0, o)

    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron')
    autograd_lib.compute_hess(model, method='mean_kron', attr_name='hess_mean_kron')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    errors1 = []
    errors2 = []
    for i in range(len(model.layers)):
        layer = model.layers[i]

        # direct Hessian computation
        H = layer.weight.hess
        H_bias = layer.bias.hess

        # factored Hessian computation
        Hk = layer.weight.hess_kron
        Hk_bias = layer.bias.hess_kron
        Hk = Hk.expand()
        Hk_bias = Hk_bias.expand()

        Hk2 = layer.weight.hess_mean_kron
        Hk2_bias = layer.bias.hess_mean_kron
        Hk2 = Hk2.expand()
        Hk2_bias = Hk2_bias.expand()

        # autograd Hessian computation
        loss = loss_fn(output, targets)
        Ha = u.hessian(loss, layer.weight).reshape(H.shape)
        Ha_bias = u.hessian(loss, layer.bias)

        # compare direct against autograd
        Ha = Ha.reshape(H.shape)
        # rel_error = torch.max((H-Ha)/Ha)
        u.check_close(H, Ha, rtol=1e-5, atol=1e-7)
        u.check_close(Ha_bias, H_bias, rtol=1e-5, atol=1e-7)

        errors1.extend([u.symsqrt_dist(H, Hk), u.symsqrt_dist(H_bias, Hk_bias)])
        errors2.extend([u.symsqrt_dist(H, Hk2), u.symsqrt_dist(H_bias, Hk2_bias)])

    errors1 = torch.tensor(errors1)
    errors2 = torch.tensor(errors2)
    golden_errors1 = torch.tensor([0.09458080679178238, 0.0, 0.13416489958763123, 0.0, 0.0003909761435352266, 0.0])
    golden_errors2 = torch.tensor([0.0945773795247078, 0.0, 0.13418318331241608, 0.0, 4.478318658129865e-07, 0.0])

    u.check_close(golden_errors1, errors1)
    u.check_close(golden_errors2, errors2)
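# Minimal sketch of what the .expand() calls above are assumed to do: materialize a dense
# matrix from two Kronecker factors. This does not touch u.Kron internals; it only shows the
# Kronecker-product identity, and assumes torch.kron is available (PyTorch >= 1.8).
def _sketch_kron_expand():
    LL = torch.tensor([[2., 1.], [1., 3.]])
    RR = torch.tensor([[4., 0.], [0., 5.]])
    full = torch.kron(LL, RR)                      # (2*2) x (2*2) dense matrix
    assert full.shape == (4, 4)
    assert torch.allclose(full[:2, :2], LL[0, 0] * RR)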