Example #1
def main():
    # for kernel_size=1, mean kron factoring works for any image size
    main_vals = AttrDict(n=2,
                         kernel_size=1,
                         image_size=625,
                         num_channels=5,
                         num_layers=4,
                         loss='CrossEntropy',
                         nonlin=False)

    hess_list1 = compute_hess(method='exact', **main_vals)
    hess_list2 = compute_hess(method='kron', **main_vals)
    value_error = max(
        [u.symsqrt_dist(h1, h2) for h1, h2 in zip(hess_list1, hess_list2)])
    magnitude_error = max([
        u.l2_norm(h2) / u.l2_norm(h1)
        for h1, h2 in zip(hess_list1, hess_list2)
    ])
    print(value_error)
    print(magnitude_error)

    dimension_vals = dict(image_size=[2, 3, 4, 5, 6],
                          num_channels=range(2, 12),
                          kernel_size=[1, 2, 3, 4, 5])
    for method in ['mean_kron', 'kron']:  # , 'experimental_kfac']:
        print()
        print('=' * 10, method, '=' * 40)
        for dimension in ['image_size', 'num_channels', 'kernel_size']:
            value_errors = []
            magnitude_errors = []
            for val in dimension_vals[dimension]:
                vals = AttrDict(main_vals.copy())
                vals.method = method
                vals[dimension] = val
                vals.image_size = max(vals.image_size,
                                      vals.kernel_size**vals.num_layers)
                # print(vals)
                vals_exact = AttrDict(vals.copy())
                vals_exact.method = 'exact'
                hess_list1 = compute_hess(**vals_exact)
                hess_list2 = compute_hess(**vals)
                magnitude_error = max([
                    u.l2_norm(h2) / u.l2_norm(h1)
                    for h1, h2 in zip(hess_list1, hess_list2)
                ])
                hess_list1 = [h / u.l2_norm(h) for h in hess_list1]
                hess_list2 = [h / u.l2_norm(h) for h in hess_list2]

                value_error = max([
                    u.symsqrt_dist(h1, h2)
                    for h1, h2 in zip(hess_list1, hess_list2)
                ])
                value_errors.append(value_error)
                magnitude_errors.append(magnitude_error.item())
            print(dimension)
            print('   value     :', value_errors)
            print('   magnitude :', u.format_list(magnitude_errors))
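
Every comparison above goes through u.symsqrt_dist, which measures how far two curvature matrices are from each other via their symmetric square roots. A minimal sketch of such a metric, assuming it is the Frobenius distance between eigendecomposition-based square roots (the actual util implementation may normalize or handle near-singular matrices differently):

import torch

def symsqrt(mat):
    """Symmetric square root of a PSD matrix via eigendecomposition."""
    evals, evecs = torch.linalg.eigh(mat)
    evals = evals.clamp(min=0)          # drop tiny negative eigenvalues from round-off
    return evecs @ torch.diag(evals.sqrt()) @ evecs.T

def symsqrt_dist(h1, h2):
    """Frobenius distance between the symmetric square roots of two PSD matrices."""
    return torch.norm(symsqrt(h1) - symsqrt(h2))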
Example #2
def test_kron_mnist():
    u.seed_random(1)

    data_width = 3
    batch_size = 3
    d = [data_width**2, 10]
    o = d[-1]
    n = batch_size
    train_steps = 1

    # torch.set_default_dtype(torch.float64)

    model: u.SimpleModel2 = u.SimpleFullyConnected2(d, nonlin=False, bias=True)
    autograd_lib.add_hooks(model)

    dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()

    gl.token_count = 0
    for train_step in range(train_steps):
        data, targets = next(train_iter)

        # get gradient values
        u.clear_backprops(model)
        autograd_lib.enable_hooks()
        output = model(data)
        autograd_lib.backprop_hess(output, hess_type='CrossEntropy')

        i = 0
        layer = model.layers[i]
        autograd_lib.compute_hess(model, method='kron')
        autograd_lib.compute_hess(model)
        autograd_lib.disable_hooks()

        # direct Hessian computation
        H = layer.weight.hess
        H_bias = layer.bias.hess

        # factored Hessian computation
        H2 = layer.weight.hess_factored
        H2_bias = layer.bias.hess_factored
        H2, H2_bias = u.expand_hess(H2, H2_bias)

        # autograd Hessian computation
        loss = loss_fn(output, targets) # TODO: change to d[i+1]*d[i]
        H_autograd = u.hessian(loss, layer.weight).reshape(d[i] * d[i + 1], d[i] * d[i + 1])
        H_bias_autograd = u.hessian(loss, layer.bias)

        # compare direct against autograd
        u.check_close(H, H_autograd)
        u.check_close(H_bias, H_bias_autograd)

        approx_error = u.symsqrt_dist(H, H2)
        assert approx_error < 1e-2, approx_error
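
In this test the model is a single linear layer (d = [data_width**2, 10]), so the autograd reference Hessian returned by u.hessian can also be spelled out directly with PyTorch's functional API. A minimal sketch, assuming the input batch is already flattened to shape (n, d[0]); the helper name is illustrative and not part of util:

import torch
import torch.nn.functional as F

def linear_weight_hessian(data, weight, bias, targets):
    """Hessian of the cross-entropy loss w.r.t. a single linear layer's weight."""
    loss_fn = torch.nn.CrossEntropyLoss()

    def loss_of_weight(w):
        return loss_fn(F.linear(data, w, bias), targets)

    H = torch.autograd.functional.hessian(loss_of_weight, weight.detach())
    # flatten the (out, in, out, in) result into an (out*in, out*in) matrix
    return H.reshape(weight.numel(), weight.numel())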
Example #3
def test_kron_1x2_conv():
    """Minimal example of a 1x2 convolution whose Hessian/grad covariance doesn't factor as Kronecker.

    Two convolutional layers stacked on top of each other, followed by least squares loss.

        Outputs:
        0 tensor([[[[0., 1., 1., 1.]]]])
        1 tensor([[[[2., 3.]]]])
        2 tensor([[[[8.]]]])

        Activations/backprops:
        layerA 0 tensor([[[[0., 1.],
                           [1., 1.]]]])
        layerB 0 tensor([[[[1., 2.]]]])

        layerA 1 tensor([[[[2.],
                           [3.]]]])
        layerB 1 tensor([[[[1.]]]])

        layer 0 discrepancy: 0.6597963571548462
        layer 1 discrepancy: 0.0

     """
    u.seed_random(1)

    n, Xh, Xw = 1, 1, 4
    Kh, Kw = 1, 2
    dd = [1, 1, 1]
    o = dd[-1]

    model: u.SimpleModel = u.StridedConvolutional2(dd, kernel_size=(Kh, Kw), nonlin=False, bias=True)
    data = torch.tensor([0, 1., 1, 1]).reshape((n, dd[0], Xh, Xw))

    model.layers[0].bias.data.zero_()
    model.layers[0].weight.data.copy_(torch.tensor([1, 2]))

    model.layers[1].bias.data.zero_()
    model.layers[1].weight.data.copy_(torch.tensor([1, 2]))

    sample_output = model(data)

    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type='LeastSquares')
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    for i in range(len(model.layers)):
        layer = model.layers[i]
        H = layer.weight.hess
        Hk = layer.weight.hess_kron
        Hk = Hk.expand()
        print(u.symsqrt_dist(H, Hk))
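
The intermediate outputs quoted in the docstring can be reproduced by hand with plain conv2d calls, assuming StridedConvolutional2 slides the kernel with stride equal to the kernel size and applies no nonlinearity (a sanity check, not part of the test):

import torch
import torch.nn.functional as F

x = torch.tensor([0., 1., 1., 1.]).reshape(1, 1, 1, 4)
w = torch.tensor([1., 2.]).reshape(1, 1, 1, 2)    # both layers use the same kernel

h = F.conv2d(x, w, stride=(1, 2))    # tensor([[[[2., 3.]]]])
y = F.conv2d(h, w, stride=(1, 2))    # tensor([[[[8.]]]])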
Example #4
def test_kron_nano():
    u.seed_random(1)

    d = [1, 2]
    n = 1
    # torch.set_default_dtype(torch.float32)

    loss_type = 'CrossEntropy'
    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=False, bias=True)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:
        loss_fn = nn.CrossEntropyLoss()

    data = torch.ones(n, d[0])  # fixed all-ones input keeps the example deterministic
    if loss_type.endswith('LeastSquares'):
        target = torch.randn(n, d[-1])
    elif loss_type == 'CrossEntropy':
        target = torch.tensor([0])  # fixed class-0 target keeps the example deterministic

    # Hessian computation, saves regular and Kronecker factored versions into .hess and .hess_factored attributes
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='kron')
    autograd_lib.compute_hess(model)
    autograd_lib.disable_hooks()

    for layer in model.layers:
        Hk: u.Kron = layer.weight.hess_factored
        Hk_bias: u.Kron = layer.bias.hess_factored
        Hk, Hk_bias = u.expand_hess(Hk, Hk_bias)   # kronecker multiply the factors

        # old approach, using direct computation
        H2, H_bias2 = layer.weight.hess, layer.bias.hess

        # compute Hessian through autograd
        model.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        H_autograd = u.hessian(loss, layer.weight)
        H_bias_autograd = u.hessian(loss, layer.bias)

        # compare autograd with direct approach
        u.check_close(H2, H_autograd.reshape(Hk.shape))
        u.check_close(H_bias2, H_bias_autograd)

        # compare factored with direct approach
        assert u.symsqrt_dist(Hk, H2) < 1e-6
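
u.Kron stores the layer Hessian as two small factors (one over layer inputs, one over output backprops); expand_hess and Kron.expand turn that back into a full matrix so it can be compared against the exact Hessian. A minimal sketch of that expansion, assuming the factors are combined with a plain Kronecker product (factor names and ordering are illustrative; the ordering depends on the weight-vectorization convention):

import torch

def expand_kron(AA, BB):
    """Expand a Kronecker-factored Hessian approximation into a full matrix.

    AA: (d_in, d_in) covariance of layer inputs (activations)
    BB: (d_out, d_out) covariance of output backprops
    Returns a (d_in*d_out, d_in*d_out) matrix approximating the full Hessian.
    """
    return torch.kron(AA, BB)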
Example #5
def _test_kron_conv_golden():
    """Hardcoded error values to detect unexpected numeric changes."""
    u.seed_random(1)

    n, Xh, Xw = 2, 8, 8
    Kh, Kw = 2, 2
    dd = [3, 3, 3, 3]
    o = dd[-1]

    model: u.SimpleModel = u.PooledConvolutional2(dd, kernel_size=(Kh, Kw), nonlin=False, bias=True)
    data = torch.randn((n, dd[0], Xh, Xw))

    # print(model)
    # print(data)

    loss_type = 'CrossEntropy'    #  loss_type = 'LeastSquares'
    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:    # CrossEntropy
        loss_fn = nn.CrossEntropyLoss()

    sample_output = model(data)

    if loss_type.endswith('LeastSquares'):
        targets = torch.randn(sample_output.shape)
    elif loss_type == 'CrossEntropy':
        targets = torch.LongTensor(n).random_(0, o)

    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron')
    autograd_lib.compute_hess(model, method='mean_kron', attr_name='hess_mean_kron')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    errors1 = []
    errors2 = []
    for i in range(len(model.layers)):
        layer = model.layers[i]

        # direct Hessian computation
        H = layer.weight.hess
        H_bias = layer.bias.hess

        # factored Hessian computation
        Hk = layer.weight.hess_kron
        Hk_bias = layer.bias.hess_kron
        Hk = Hk.expand()
        Hk_bias = Hk_bias.expand()

        Hk2 = layer.weight.hess_mean_kron
        Hk2_bias = layer.bias.hess_mean_kron
        Hk2 = Hk2.expand()
        Hk2_bias = Hk2_bias.expand()

        # autograd Hessian computation
        loss = loss_fn(output, targets)
        Ha = u.hessian(loss, layer.weight).reshape(H.shape)
        Ha_bias = u.hessian(loss, layer.bias)

        # compare direct against autograd
        # rel_error = torch.max((H-Ha)/Ha)

        u.check_close(H, Ha, rtol=1e-5, atol=1e-7)
        u.check_close(Ha_bias, H_bias, rtol=1e-5, atol=1e-7)

        errors1.extend([u.symsqrt_dist(H, Hk), u.symsqrt_dist(H_bias, Hk_bias)])
        errors2.extend([u.symsqrt_dist(H, Hk2), u.symsqrt_dist(H_bias, Hk2_bias)])

    errors1 = torch.tensor(errors1)
    errors2 = torch.tensor(errors2)
    golden_errors1 = torch.tensor([0.09458080679178238, 0.0, 0.13416489958763123, 0.0, 0.0003909761435352266, 0.0])
    golden_errors2 = torch.tensor([0.0945773795247078, 0.0, 0.13418318331241608, 0.0, 4.478318658129865e-07, 0.0])

    u.check_close(golden_errors1, errors1)
    u.check_close(golden_errors2, errors2)
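
If the numerics change intentionally (for example, a modified factoring scheme), the hard-coded golden tensors need to be refreshed. One way, not part of the original test, is to print the freshly computed errors as full-precision Python floats and paste them back into the golden tensors:

    # add at the end of _test_kron_conv_golden, then copy the printed values back in
    print('golden_errors1 =', errors1.tolist())
    print('golden_errors2 =', errors2.tolist())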