def main():
    # for kernel_size=1, mean kron factoring works for any image size
    main_vals = AttrDict(n=2, kernel_size=1, image_size=625, num_channels=5,
                         num_layers=4, loss='CrossEntropy', nonlin=False)
    hess_list1 = compute_hess(method='exact', **main_vals)
    hess_list2 = compute_hess(method='kron', **main_vals)
    value_error = max([u.symsqrt_dist(h1, h2) for h1, h2 in zip(hess_list1, hess_list2)])
    magnitude_error = max([u.l2_norm(h2) / u.l2_norm(h1) for h1, h2 in zip(hess_list1, hess_list2)])
    print(value_error)
    print(magnitude_error)

    # sweep each dimension separately, keeping the remaining settings from main_vals
    dimension_vals = dict(image_size=[2, 3, 4, 5, 6],
                          num_channels=range(2, 12),
                          kernel_size=[1, 2, 3, 4, 5])
    for method in ['mean_kron', 'kron']:  # , 'experimental_kfac']:
        print()
        print('=' * 10, method, '=' * 40)
        for dimension in ['image_size', 'num_channels', 'kernel_size']:
            value_errors = []
            magnitude_errors = []
            for val in dimension_vals[dimension]:
                vals = AttrDict(main_vals.copy())
                vals.method = method
                vals[dimension] = val
                vals.image_size = max(vals.image_size, vals.kernel_size ** vals.num_layers)
                # print(vals)
                vals_exact = AttrDict(vals.copy())
                vals_exact.method = 'exact'
                hess_list1 = compute_hess(**vals_exact)
                hess_list2 = compute_hess(**vals)
                magnitude_error = max([u.l2_norm(h2) / u.l2_norm(h1)
                                       for h1, h2 in zip(hess_list1, hess_list2)])
                # normalize before comparing values so only the shape mismatch is measured
                hess_list1 = [h / u.l2_norm(h) for h in hess_list1]
                hess_list2 = [h / u.l2_norm(h) for h in hess_list2]
                value_error = max([u.symsqrt_dist(h1, h2)
                                   for h1, h2 in zip(hess_list1, hess_list2)])
                value_errors.append(value_error)
                magnitude_errors.append(magnitude_error.item())
            print(dimension)
            print(' value     :', value_errors)
            print(' magnitude :', u.format_list(magnitude_errors))

def test_kron_mnist():
    u.seed_random(1)

    data_width = 3
    batch_size = 3
    d = [data_width ** 2, 10]
    o = d[-1]
    n = batch_size
    train_steps = 1

    # torch.set_default_dtype(torch.float64)
    model: u.SimpleModel2 = u.SimpleFullyConnected2(d, nonlin=False, bias=True)
    autograd_lib.add_hooks(model)

    dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()

    gl.token_count = 0
    for train_step in range(train_steps):
        data, targets = next(train_iter)

        # get gradient values
        u.clear_backprops(model)
        autograd_lib.enable_hooks()
        output = model(data)
        autograd_lib.backprop_hess(output, hess_type='CrossEntropy')

        i = 0
        layer = model.layers[i]

        autograd_lib.compute_hess(model, method='kron')
        autograd_lib.compute_hess(model)
        autograd_lib.disable_hooks()

        # direct Hessian computation
        H = layer.weight.hess
        H_bias = layer.bias.hess

        # factored Hessian computation
        H2 = layer.weight.hess_factored
        H2_bias = layer.bias.hess_factored
        H2, H2_bias = u.expand_hess(H2, H2_bias)

        # autograd Hessian computation
        loss = loss_fn(output, targets)
        # TODO: change to d[i+1]*d[i]
        H_autograd = u.hessian(loss, layer.weight).reshape(d[i] * d[i + 1], d[i] * d[i + 1])
        H_bias_autograd = u.hessian(loss, layer.bias)

        # compare direct against autograd
        u.check_close(H, H_autograd)
        u.check_close(H_bias, H_bias_autograd)

        approx_error = u.symsqrt_dist(H, H2)
        assert approx_error < 1e-2, approx_error

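
# Standalone sketch, not part of test_kron_mnist and not using u.Kron/autograd_lib:
# the payoff of keeping the Hessian in Kronecker-factored form is the identity
# (A kron B)^{-1} = A^{-1} kron B^{-1}, so only the small factors ever need to be
# inverted. Relies only on torch (already imported by this file); all names are
# local to the sketch.
def _kron_inverse_sketch():
    torch.manual_seed(0)
    A = torch.randn(3, 3)
    A = A @ A.T + torch.eye(3)   # small SPD "input side" factor
    B = torch.randn(2, 2)
    B = B @ B.T + torch.eye(2)   # small SPD "output side" factor
    full_inverse = torch.inverse(torch.kron(A, B))
    factored_inverse = torch.kron(torch.inverse(A), torch.inverse(B))
    assert torch.allclose(full_inverse, factored_inverse, atol=1e-5)
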
def test_kron_1x2_conv():
    """Minimal example of a 1x2 convolution whose Hessian/grad covariance doesn't factor as Kronecker.

    Two convolutional layers stacked on top of each other, followed by least squares loss.

    Outputs:
        0 tensor([[[[0., 1., 1., 1.]]]])
        1 tensor([[[[2., 3.]]]])
        2 tensor([[[[8.]]]])

    Activations/backprops:
        layerA 0 tensor([[[[0., 1.],
                           [1., 1.]]]])
        layerB 0 tensor([[[[1., 2.]]]])
        layerA 1 tensor([[[[2.],
                           [3.]]]])
        layerB 1 tensor([[[[1.]]]])

    layer 0 discrepancy: 0.6597963571548462
    layer 1 discrepancy: 0.0
    """
    u.seed_random(1)

    n, Xh, Xw = 1, 1, 4
    Kh, Kw = 1, 2
    dd = [1, 1, 1]
    o = dd[-1]
    model: u.SimpleModel = u.StridedConvolutional2(dd, kernel_size=(Kh, Kw), nonlin=False, bias=True)
    data = torch.tensor([0, 1., 1, 1]).reshape((n, dd[0], Xh, Xw))

    model.layers[0].bias.data.zero_()
    model.layers[0].weight.data.copy_(torch.tensor([1, 2]))
    model.layers[1].bias.data.zero_()
    model.layers[1].weight.data.copy_(torch.tensor([1, 2]))

    sample_output = model(data)

    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type='LeastSquares')
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    for i in range(len(model.layers)):
        layer = model.layers[i]
        H = layer.weight.hess
        Hk = layer.weight.hess_kron
        Hk = Hk.expand()
        print(f'layer {i} discrepancy: {u.symsqrt_dist(H, Hk)}')

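
# Standalone sketch of the general fact behind the layer-0 discrepancy documented
# above: a sum of Kronecker products is usually not itself a Kronecker product.
# The Van Loan / Pitsianis rearrangement makes this checkable -- flatten each block
# of H into a row, and the rank of the stacked rows equals the minimal number of
# Kronecker terms needed to represent H. Uses only torch; names are local to the sketch.
def _kron_rank_sketch():
    def kron_rank(H, m, n, p, q):
        # H is (m*p) x (n*q); row (i, j) is the flattened (i, j) block of size p x q
        rows = [H[i * p:(i + 1) * p, j * q:(j + 1) * q].reshape(-1)
                for i in range(m) for j in range(n)]
        return torch.linalg.matrix_rank(torch.stack(rows)).item()

    torch.manual_seed(0)
    A1, B1 = torch.randn(2, 2), torch.randn(2, 2)
    A2, B2 = torch.randn(2, 2), torch.randn(2, 2)
    assert kron_rank(torch.kron(A1, B1), 2, 2, 2, 2) == 1
    # generically 2, i.e. not expressible as a single Kronecker product
    assert kron_rank(torch.kron(A1, B1) + torch.kron(A2, B2), 2, 2, 2, 2) == 2
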
def test_kron_nano():
    u.seed_random(1)

    d = [1, 2]
    n = 1
    # torch.set_default_dtype(torch.float32)

    loss_type = 'CrossEntropy'
    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=False, bias=True)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:
        loss_fn = nn.CrossEntropyLoss()

    data = torch.randn(n, d[0])
    data = torch.ones(n, d[0])      # fixed input overrides the random draw above
    if loss_type.endswith('LeastSquares'):
        target = torch.randn(n, d[-1])
    elif loss_type == 'CrossEntropy':
        target = torch.LongTensor(n).random_(0, d[-1])
        target = torch.tensor([0])  # fixed target overrides the random draw above

    # Hessian computation, saves regular and Kronecker factored versions
    # into .hess and .hess_factored attributes
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='kron')
    autograd_lib.compute_hess(model)
    autograd_lib.disable_hooks()

    for layer in model.layers:
        Hk: u.Kron = layer.weight.hess_factored
        Hk_bias: u.Kron = layer.bias.hess_factored
        Hk, Hk_bias = u.expand_hess(Hk, Hk_bias)   # kronecker multiply the factors

        # old approach, using direct computation
        H2, H_bias2 = layer.weight.hess, layer.bias.hess

        # compute Hessian through autograd
        model.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        H_autograd = u.hessian(loss, layer.weight)
        H_bias_autograd = u.hessian(loss, layer.bias)

        # compare autograd with direct approach
        u.check_close(H2, H_autograd.reshape(Hk.shape))
        u.check_close(H_bias2, H_bias_autograd)

        # compare factored with direct approach
        assert u.symsqrt_dist(Hk, H2) < 1e-6

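
# Standalone sketch of why exact agreement (as in the assert above) is possible for a
# single linear layer: for a least-squares loss, one of the loss_type options in
# test_kron_nano, the Hessian w.r.t. the flattened weight is exactly
# kron(I_out, x x^T), so Kronecker factoring introduces no approximation error here.
# Uses only torch; names below are local to the sketch, not library API.
def _single_layer_kron_exactness_sketch():
    torch.manual_seed(0)
    d_in, d_out = 3, 2
    x = torch.randn(d_in)
    t = torch.randn(d_out)

    def sketch_loss(W):
        # least-squares loss of a single linear layer y = W x
        return 0.5 * ((W @ x - t) ** 2).sum()

    W0 = torch.randn(d_out, d_in)
    H = torch.autograd.functional.hessian(sketch_loss, W0)
    H = H.reshape(d_out * d_in, d_out * d_in)
    H_kron = torch.kron(torch.eye(d_out), torch.outer(x, x))
    assert torch.allclose(H, H_kron, atol=1e-6)
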
def _test_kron_conv_golden():
    """Hardcoded error values to detect unexpected numeric changes."""
    u.seed_random(1)

    n, Xh, Xw = 2, 8, 8
    Kh, Kw = 2, 2
    dd = [3, 3, 3, 3]
    o = dd[-1]
    model: u.SimpleModel = u.PooledConvolutional2(dd, kernel_size=(Kh, Kw), nonlin=False, bias=True)
    data = torch.randn((n, dd[0], Xh, Xw))
    # print(model)
    # print(data)

    loss_type = 'CrossEntropy'
    # loss_type = 'LeastSquares'
    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:  # CrossEntropy
        loss_fn = nn.CrossEntropyLoss()

    sample_output = model(data)
    if loss_type.endswith('LeastSquares'):
        targets = torch.randn(sample_output.shape)
    elif loss_type == 'CrossEntropy':
        targets = torch.LongTensor(n).random_(0, o)

    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron')
    autograd_lib.compute_hess(model, method='mean_kron', attr_name='hess_mean_kron')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    errors1 = []
    errors2 = []
    for i in range(len(model.layers)):
        layer = model.layers[i]

        # direct Hessian computation
        H = layer.weight.hess
        H_bias = layer.bias.hess

        # factored Hessian computation
        Hk = layer.weight.hess_kron
        Hk_bias = layer.bias.hess_kron
        Hk = Hk.expand()
        Hk_bias = Hk_bias.expand()

        Hk2 = layer.weight.hess_mean_kron
        Hk2_bias = layer.bias.hess_mean_kron
        Hk2 = Hk2.expand()
        Hk2_bias = Hk2_bias.expand()

        # autograd Hessian computation
        loss = loss_fn(output, targets)
        Ha = u.hessian(loss, layer.weight).reshape(H.shape)
        Ha_bias = u.hessian(loss, layer.bias)

        # compare direct against autograd
        Ha = Ha.reshape(H.shape)
        # rel_error = torch.max((H-Ha)/Ha)
        u.check_close(H, Ha, rtol=1e-5, atol=1e-7)
        u.check_close(Ha_bias, H_bias, rtol=1e-5, atol=1e-7)

        errors1.extend([u.symsqrt_dist(H, Hk), u.symsqrt_dist(H_bias, Hk_bias)])
        errors2.extend([u.symsqrt_dist(H, Hk2), u.symsqrt_dist(H_bias, Hk2_bias)])

    errors1 = torch.tensor(errors1)
    errors2 = torch.tensor(errors2)
    golden_errors1 = torch.tensor([0.09458080679178238, 0.0, 0.13416489958763123, 0.0,
                                   0.0003909761435352266, 0.0])
    golden_errors2 = torch.tensor([0.0945773795247078, 0.0, 0.13418318331241608, 0.0,
                                   4.478318658129865e-07, 0.0])
    u.check_close(golden_errors1, errors1)
    u.check_close(golden_errors2, errors2)