Example #1
def test_full_hessian_xent_mnist_multilayer():
    """Test regular and diagonal hessian computation."""
    u.seed_random(1)

    data_width = 3
    batch_size = 2
    d = [data_width**2, 6, 10]
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=False, bias=True)
    autograd_lib.register(model)
    dataset = u.TinyMNIST(dataset_size=batch_size,
                          data_width=data_width,
                          original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()

    hess = defaultdict(float)
    hess_diag = defaultdict(float)
    for train_step in range(train_steps):
        data, targets = next(train_iter)

        activations = {}

        def save_activations(layer, a, _):
            activations[layer] = a

        with autograd_lib.module_hook(save_activations):
            output = model(data)
            loss = loss_fn(output, targets)

        def compute_hess(layer, _, B):
            A = activations[layer]
            BA = torch.einsum("nl,ni->nli", B, A)
            hess[layer] += torch.einsum('nli,nkj->likj', BA, BA)
            hess_diag[layer] += torch.einsum("ni,nj->ij", B * B, A * A)

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(output,
                                          loss='CrossEntropy',
                                          retain_graph=True)

        # compute Hessian through autograd
        H_autograd = u.hessian(loss, model.layers[0].weight)
        u.check_close(hess[model.layers[0]] / batch_size, H_autograd)
        diag_autograd = torch.einsum('lili->li', H_autograd)
        u.check_close(diag_autograd, hess_diag[model.layers[0]] / batch_size)

        H_autograd = u.hessian(loss, model.layers[1].weight)
        u.check_close(hess[model.layers[1]] / batch_size, H_autograd)
        diag_autograd = torch.einsum('lili->li', H_autograd)
        u.check_close(diag_autograd, hess_diag[model.layers[1]] / batch_size)
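
The two einsum calls in compute_hess above implement H[l,i,k,j] = sum_n sum_m B_m[n,l] A[n,i] B_m[n,k] A[n,j] and its diagonal. Below is a minimal, self-contained sketch of the full-Hessian identity in plain PyTorch (no autograd_lib; the single linear layer, least-squares loss and sizes are illustrative assumptions), using the fact that for loss = sum(Y*Y)/2 the Hessian backprops are the rows of the identity matrix:

import torch

torch.manual_seed(0)
n, d_in, d_out = 3, 4, 2
X = torch.randn(n, d_in)
W = torch.randn(d_out, d_in)

def loss_fn(W):
    Y = X @ W.t()
    return torch.sum(Y * Y) / 2

# autograd Hessian, shape (d_out, d_in, d_out, d_in)
H_autograd = torch.autograd.functional.hessian(loss_fn, W)

# manual Hessian: for this loss the output Hessian is the identity, so the
# "Hessian backprops" B are the unit vectors e_k, repeated for every example
A = X
hess = torch.zeros(d_out, d_in, d_out, d_in)
for k in range(d_out):
    B = torch.zeros(n, d_out)
    B[:, k] = 1.0
    BA = torch.einsum("nl,ni->nli", B, A)
    hess += torch.einsum("nli,nkj->likj", BA, BA)

assert torch.allclose(hess, H_autograd, atol=1e-5)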
Example #2
def test_kron_mnist():
    u.seed_random(1)

    data_width = 3
    batch_size = 3
    d = [data_width**2, 10]
    o = d[-1]
    n = batch_size
    train_steps = 1

    # torch.set_default_dtype(torch.float64)

    model: u.SimpleModel2 = u.SimpleFullyConnected2(d, nonlin=False, bias=True)
    autograd_lib.add_hooks(model)

    dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()

    gl.token_count = 0
    for train_step in range(train_steps):
        data, targets = next(train_iter)

        # get gradient values
        u.clear_backprops(model)
        autograd_lib.enable_hooks()
        output = model(data)
        autograd_lib.backprop_hess(output, hess_type='CrossEntropy')

        i = 0
        layer = model.layers[i]
        autograd_lib.compute_hess(model, method='kron')
        autograd_lib.compute_hess(model)
        autograd_lib.disable_hooks()

        # direct Hessian computation
        H = layer.weight.hess
        H_bias = layer.bias.hess

        # factored Hessian computation
        H2 = layer.weight.hess_factored
        H2_bias = layer.bias.hess_factored
        H2, H2_bias = u.expand_hess(H2, H2_bias)

        # autograd Hessian computation
        loss = loss_fn(output, targets) # TODO: change to d[i+1]*d[i]
        H_autograd = u.hessian(loss, layer.weight).reshape(d[i] * d[i + 1], d[i] * d[i + 1])
        H_bias_autograd = u.hessian(loss, layer.bias)

        # compare direct against autograd
        u.check_close(H, H_autograd)
        u.check_close(H_bias, H_bias_autograd)

        approx_error = u.symsqrt_dist(H, H2)
        assert approx_error < 1e-2, approx_error
Example #3
def test_kron_nano():
    u.seed_random(1)

    d = [1, 2]
    n = 1
    # torch.set_default_dtype(torch.float32)

    loss_type = 'CrossEntropy'
    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=False, bias=True)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:
        loss_fn = nn.CrossEntropyLoss()

    data = torch.randn(n, d[0])
    data = torch.ones(n, d[0])
    if loss_type.endswith('LeastSquares'):
        target = torch.randn(n, d[-1])
    elif loss_type == 'CrossEntropy':
        target = torch.LongTensor(n).random_(0, d[-1])
        target = torch.tensor([0])

    # Hessian computation, saves regular and Kronecker factored versions into .hess and .hess_factored attributes
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='kron')
    autograd_lib.compute_hess(model)
    autograd_lib.disable_hooks()

    for layer in model.layers:
        Hk: u.Kron = layer.weight.hess_factored
        Hk_bias: u.Kron = layer.bias.hess_factored
        Hk, Hk_bias = u.expand_hess(Hk, Hk_bias)   # kronecker multiply the factors

        # old approach, using direct computation
        H2, H_bias2 = layer.weight.hess, layer.bias.hess

        # compute Hessian through autograd
        model.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        H_autograd = u.hessian(loss, layer.weight)
        H_bias_autograd = u.hessian(loss, layer.bias)

        # compare autograd with direct approach
        u.check_close(H2, H_autograd.reshape(Hk.shape))
        u.check_close(H_bias2, H_bias_autograd)

        # compare factored with direct approach
        assert(u.symsqrt_dist(Hk, H2) < 1e-6)
Example #4
def test_diagonal_hessian():
    u.seed_random(1)
    A, model = create_toy_model()

    activations = {}

    def save_activations(layer, a, _):
        if layer != model.layers[0]:
            return
        activations[layer] = a

    with autograd_lib.module_hook(save_activations):
        Y = model(A.t())
        loss = torch.sum(Y * Y) / 2

    hess = [0]

    def compute_hess(layer, _, B):
        if layer != model.layers[0]:
            return
        A = activations[layer]
        hess[0] += torch.einsum("ni,nj->ij", B * B, A * A).reshape(-1)

    with autograd_lib.module_hook(compute_hess):
        autograd_lib.backprop_identity(Y, retain_graph=True)

    # check against autograd
    hess0 = u.hessian(loss, model.layers[0].weight).reshape([4, 4])
    u.check_equal(hess[0], torch.diag(hess0))

    # check against manual solution
    u.check_equal(hess[0], [425., 225., 680., 360.])
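
The diagonal shortcut used here follows from restricting the full Hessian of the earlier examples to its (l,i),(l,i) entries: H[l,i,l,i] = sum_n B[n,l]^2 A[n,i]^2, which is exactly einsum("ni,nj->ij", B*B, A*A). A quick standalone consistency check with random stand-in tensors (names are illustrative only):

import torch

torch.manual_seed(0)
n, d_out, d_in = 5, 3, 4
A = torch.randn(n, d_in)
B = torch.randn(n, d_out)

BA = torch.einsum("nl,ni->nli", B, A)
H_full = torch.einsum("nli,nkj->likj", BA, BA)
H_diag = torch.einsum("ni,nj->ij", B * B, A * A)
assert torch.allclose(torch.einsum("lili->li", H_full), H_diag, atol=1e-6)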
Example #5
def test_full_hessian_multibatch():
    A, model = create_toy_model()
    data = A.t()
    data = data.repeat(3, 1)
    n = float(len(data))

    activations = {}
    hess = defaultdict(float)

    def save_activations(layer, a, _):
        activations[layer] = a

    def compute_hessian(layer, _, B):
        A = activations[layer]
        BA = torch.einsum("nl,ni->nli", B, A)
        hess[layer] += torch.einsum('nli,nkj->likj', BA, BA)

    for x in data:
        with autograd_lib.module_hook(save_activations):
            y = model(x)
            loss = torch.sum(y * y) / 2

        with autograd_lib.module_hook(compute_hessian):
            autograd_lib.backprop_identity(y)

    result = hess[model.layers[0]]

    # check result against autograd
    loss = u.least_squares(model(data), aggregation='sum')
    hess0 = u.hessian(loss, model.layers[0].weight)
    u.check_equal(hess0, result)
Example #6
def test_kfac_hessian():
    A, model = create_toy_model()
    data = A.t()
    data = data.repeat(7, 1)
    n = float(len(data))

    activations = {}
    hess = defaultdict(lambda: AttrDefault(float))

    def save_activations(layer, a, _):
        activations[layer] = a

    def compute_hessian(layer, _, B):
        A = activations[layer]
        hess[layer].AA += torch.einsum("ni,nj->ij", A, A)
        hess[layer].BB += torch.einsum("ni,nj->ij", B, B)

    for x in data:
        with autograd_lib.module_hook(save_activations):
            y = model(x)
            o = y.shape[1]
            loss = torch.sum(y * y) / 2

        with autograd_lib.module_hook(compute_hessian):
            autograd_lib.backprop_identity(y)

    hess0 = hess[model.layers[0]]
    result = u.kron(hess0.BB / n, hess0.AA / o)

    # check result against autograd
    loss = u.least_squares(model(data), aggregation='sum')
    hess0 = u.hessian(loss, model.layers[0].weight).reshape(4, 4)
    u.check_equal(hess0, result)
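
The Kronecker factorization checked here is exact because the batch repeats a single example, so the activation second-moment matrix is rank one. A standalone sketch of that fact for one linear layer with loss = sum(Y*Y)/2 (the single-layer setup and sizes are illustrative assumptions, not the toy model used above):

import torch

torch.manual_seed(0)
n, d_in, d_out = 7, 3, 2
x = torch.randn(1, d_in)
X = x.repeat(n, 1)                      # identical rows, as in the repeated batch above

# For loss = sum(Y*Y)/2, backpropagating each output unit e_m gives B = e_m.
AA = torch.zeros(d_in, d_in)
BB = torch.zeros(d_out, d_out)
H_full = torch.zeros(d_out, d_in, d_out, d_in)
for i in range(n):
    a = X[i:i + 1]                      # 1 x d_in
    for m in range(d_out):
        b = torch.zeros(1, d_out)
        b[0, m] = 1.0
        AA += a.t() @ a
        BB += b.t() @ b
        BA = torch.einsum("nl,ni->nli", b, a)
        H_full += torch.einsum("nli,nkj->likj", BA, BA)

H_kfac = torch.einsum("kl,ij->kilj", BB / n, AA / d_out)   # kron in (out, in, out, in) layout
assert torch.allclose(H_full, H_kfac, atol=1e-4)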
Example #7
def test_cross_entropy_hessian_mnist():
    u.seed_random(1)

    data_width = 3
    batch_size = 2
    d = [data_width**2, 10]
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=False, bias=True)

    dataset = u.TinyMNIST(dataset_size=batch_size, data_width=data_width, original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()
    loss_hessian = u.HessianExactCrossEntropyLoss()

    gl.token_count = 0
    for train_step in range(train_steps):
        data, targets = next(train_iter)

        # get gradient values
        u.clear_backprops(model)
        model.skip_forward_hooks = False
        model.skip_backward_hooks = False
        output = model(data)
        for bval in loss_hessian(output):
            output.backward(bval, retain_graph=True)
        i = 0
        layer = model.layers[i]
        H, Hbias = u.hessian_from_backprops(layer.activations,
                                            layer.backprops_list,
                                            bias=True)
        model.skip_forward_hooks = True
        model.skip_backward_hooks = True

        # compute Hessian through autograd
        model.zero_grad()
        output = model(data)
        loss = loss_fn(output, targets)
        H_autograd = u.hessian(loss, layer.weight).reshape(d[i] * d[i + 1], d[i] * d[i + 1])
        u.check_close(H, H_autograd)

        Hbias_autograd = u.hessian(loss, layer.bias)
        u.check_close(Hbias, Hbias_autograd)
Example #8
def _test_kfac_hessian_xent_mnist():
    u.seed_random(1)

    data_width = 3
    batch_size = 2
    d = [data_width**2, 10]
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=False, bias=True)
    autograd_lib.register(model)
    dataset = u.TinyMNIST(dataset_size=batch_size,
                          data_width=data_width,
                          original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()

    activations = {}
    hess = defaultdict(lambda: AttrDefault(float))
    for train_step in range(train_steps):
        data, targets = next(train_iter)

        activations = {}

        def save_activations(layer, a, _):
            activations[layer] = a

        with autograd_lib.module_hook(save_activations):
            output = model(data)
            loss = loss_fn(output, targets)

        def compute_hess(layer, _, B):
            A = activations[layer]
            hess[layer].AA += torch.einsum("ni,nj->ij", A, A)
            hess[layer].BB += torch.einsum("ni,nj->ij", B, B)

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(output,
                                          loss='CrossEntropy',
                                          retain_graph=True)

        hess_factored = hess[model.layers[0]]
        hess0 = torch.einsum('kl,ij->kilj', hess_factored.BB / n,
                             hess_factored.AA / o)  # hess for sum loss
        hess0 /= n  # hess for mean loss

        # compute Hessian through autograd
        H_autograd = u.hessian(loss, model.layers[0].weight)
        rel_error = torch.norm(
            (hess0 - H_autograd).flatten()) / torch.norm(H_autograd.flatten())
        assert rel_error < 0.01  # 0.0057
Example #9
def test_full_hessian_xent_kfac2():
    """Test with uneven layers."""
    u.seed_random(1)
    torch.set_default_dtype(torch.float64)

    batch_size = 1
    d = [3, 2]
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=True, bias=False)
    autograd_lib.register(model)
    loss_fn = torch.nn.CrossEntropyLoss()

    data = u.to_logits(torch.tensor([[0.7, 0.2, 0.1]]))
    targets = torch.tensor([0])

    data = data.repeat([3, 1])
    targets = targets.repeat([3])
    n = len(data)

    activations = {}
    hess = defaultdict(lambda: AttrDefault(float))

    for i in range(n):

        def save_activations(layer, A, _):
            activations[layer] = A
            hess[layer].AA += torch.einsum("ni,nj->ij", A, A)

        with autograd_lib.module_hook(save_activations):
            data_batch = data[i:i + 1]
            targets_batch = targets[i:i + 1]
            Y = model(data_batch)
            o = Y.shape[1]
            loss = loss_fn(Y, targets_batch)

        def compute_hess(layer, _, B):
            hess[layer].BB += torch.einsum("ni,nj->ij", B, B)

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(Y, loss='CrossEntropy')

    # expand
    hess_factored = hess[model.layers[0]]
    hess0 = torch.einsum('kl,ij->kilj', hess_factored.BB / n,
                         hess_factored.AA / o)  # hess for sum loss
    hess0 /= n  # hess for mean loss

    # check against autograd
    # 0.1459
    Y = model(data)
    loss = loss_fn(Y, targets)
    hess_autograd = u.hessian(loss, model.layers[0].weight)
    u.check_equal(hess_autograd, hess0)
Example #10
def test_cross_entropy_hessian_tiny():
    u.seed_random(1)

    batch_size = 1
    d = [2, 2]
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=True, bias=True)
    model.layers[0].weight.data.copy_(torch.eye(2))

    loss_fn = torch.nn.CrossEntropyLoss()
    loss_hessian = u.HessianExactCrossEntropyLoss()

    data = u.to_logits(torch.tensor([[0.7, 0.3]]))
    targets = torch.tensor([0])

    # get gradient values
    u.clear_backprops(model)
    model.skip_forward_hooks = False
    model.skip_backward_hooks = False
    output = model(data)

    for bval in loss_hessian(output):
        output.backward(bval, retain_graph=True)
    i = 0
    layer = model.layers[i]
    H, Hbias = u.hessian_from_backprops(layer.activations,
                                        layer.backprops_list,
                                        bias=True)
    model.skip_forward_hooks = True
    model.skip_backward_hooks = True

    # compute Hessian through autograd
    model.zero_grad()
    output = model(data)
    loss = loss_fn(output, targets)
    H_autograd = u.hessian(loss, layer.weight)
    u.check_close(H, H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1]))
    Hbias_autograd = u.hessian(loss, layer.bias)
    u.check_close(Hbias, Hbias_autograd)
Example #11
def test_full_hessian_xent_multibatch():
    u.seed_random(1)
    torch.set_default_dtype(torch.float64)

    batch_size = 1
    d = [2, 2]
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=True, bias=True)
    model.layers[0].weight.data.copy_(torch.eye(2))
    autograd_lib.register(model)
    loss_fn = torch.nn.CrossEntropyLoss()

    data = u.to_logits(torch.tensor([[0.7, 0.3]]))
    targets = torch.tensor([0])

    data = data.repeat([3, 1])
    targets = targets.repeat([3])
    n = len(data)

    activations = {}
    hess = defaultdict(float)

    def save_activations(layer, a, _):
        activations[layer] = a

    for i in range(n):
        with autograd_lib.module_hook(save_activations):
            data_batch = data[i:i + 1]
            targets_batch = targets[i:i + 1]
            Y = model(data_batch)
            loss = loss_fn(Y, targets_batch)

        def compute_hess(layer, _, B):
            A = activations[layer]
            BA = torch.einsum("nl,ni->nli", B, A)
            hess[layer] += torch.einsum('nli,nkj->likj', BA, BA)

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(Y, loss='CrossEntropy')

    # check against autograd
    # 0.1459
    Y = model(data)
    loss = loss_fn(Y, targets)
    hess_autograd = u.hessian(loss, model.layers[0].weight)
    hess0 = hess[model.layers[0]] / n
    u.check_equal(hess_autograd, hess0)
Example #12
def test_autoencoder_newton():
    """Use Newton's method to train autoencoder."""

    image_size = 3
    batch_size = 64
    dataset = u.TinyMNIST(data_width=image_size, targets_width=image_size,
                          dataset_size=batch_size)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    d = image_size ** 2  # hidden layer size
    u.seed_random(1)
    model: u.SimpleModel = u.SimpleFullyConnected([d, d])
    model.disable_hooks()

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    def loss_fn(data, targets):
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == batch_size
        return torch.sum(err * err) / 2 / len(data)

    for i in range(10):
        data, targets = next(iter(trainloader))
        optimizer.zero_grad()
        loss = loss_fn(model(data), targets)
        if i > 0:
            assert loss < 1e-9

        loss.backward()
        W = model.layers[0].weight
        grad = u.tvec(W.grad)

        loss = loss_fn(model(data), targets)
        H = u.hessian(loss, W)

        #  for col-major: H = H.transpose(0, 1).transpose(2, 3).reshape(d**2, d**2)
        H = H.reshape(d ** 2, d ** 2)

        #  For col-major: W1 = u.unvec(u.vec(W) - u.pinv(H) @ grad, d)
        # W1 = u.untvec(u.tvec(W) - grad @ u.pinv(H), d)
        W1 = u.untvec(u.tvec(W) - grad @ H.pinverse(), d)
        W.data.copy_(W1)
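
The `assert loss < 1e-9` above relies on the loss being an exact quadratic in the weights, so a full Newton step minimizes it in one update. A compact standalone version of that one-step convergence in plain PyTorch (double precision, hypothetical sizes, and targets constructed to be exactly fittable):

import torch

torch.manual_seed(0)
n, d = 8, 5
X = torch.randn(n, d, dtype=torch.float64)
M = torch.randn(d, d, dtype=torch.float64)
T = X @ M.t()                            # targets the model can fit exactly
W = torch.randn(d, d, dtype=torch.float64, requires_grad=True)

def loss_fn(W):
    err = X @ W.t() - T
    return torch.sum(err * err) / 2 / n

grad = torch.autograd.grad(loss_fn(W), W)[0]
H = torch.autograd.functional.hessian(loss_fn, W.detach()).reshape(d * d, d * d)

# Newton step on the row-major flattened weight
W1 = (W.detach().flatten() - torch.linalg.pinv(H) @ grad.flatten()).reshape(d, d)
assert loss_fn(W1) < 1e-9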
Example #13
    def fit(self, x, y):
        """Run Newton's Method to minimize J(theta) for logistic regression.

        Args:
            x: Training example inputs. Shape (m, n).
            y: Training example labels. Shape (m,).
        """
        # *** START CODE HERE ***
        if self.theta is None:
            self.theta = np.zeros((x.shape[1], 1))
        y = y.reshape((y.shape[0], 1))
        error = 1e9
        numIters = 0
        while error > self.eps and numIters < self.max_iter:
            hess = util.hessian(x, self.theta)
            Jprime = util.gradient(x, self.theta, y)
            hessInv = np.linalg.inv(hess)
            theta_new = self.theta - hessInv.dot(Jprime)
            error = np.sum(np.abs(self.theta - theta_new))
            self.theta = theta_new.copy()
            numIters += 1
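
For reference, the standard Newton quantities for the averaged logistic-regression loss J(theta) = -(1/m) sum_i [y_i log h_i + (1 - y_i) log(1 - h_i)], with h = sigmoid(x @ theta), are sketched below in plain numpy. The util.gradient and util.hessian helpers used above are assumed, not confirmed here, to implement these formulas:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def logistic_gradient(x, theta, y):
    """(1/m) * x^T (h - y), shape (n_features, 1)."""
    m = x.shape[0]
    h = sigmoid(x @ theta)
    return x.T @ (h - y) / m

def logistic_hessian(x, theta):
    """(1/m) * x^T diag(h * (1 - h)) x, shape (n_features, n_features)."""
    m = x.shape[0]
    h = sigmoid(x @ theta).reshape(-1)
    return (x * (h * (1 - h))[:, None]).T @ x / m

# One Newton update, mirroring the loop in fit() above:
#   theta_new = theta - np.linalg.inv(logistic_hessian(x, theta)) @ logistic_gradient(x, theta, y)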
Example #14
def subtest_hess_type(hess_type):
    torch.manual_seed(1)
    model = TinyNet()

    def least_squares_loss(data_, targets_):
        assert len(data_) == len(targets_)
        err = data_ - targets_
        return torch.sum(err * err) / 2 / len(data_)

    n = 3
    data = torch.rand(n, 1, 28, 28)

    autograd_lib.add_hooks(model)
    output = model(data)

    if hess_type == 'LeastSquares':
        targets = torch.rand(output.shape)
        loss_fn = least_squares_loss
    else:  # hess_type == 'CrossEntropy':
        targets = torch.LongTensor(n).random_(0, 10)
        loss_fn = nn.CrossEntropyLoss()

    # Dummy backprop to make sure multiple backprops don't invalidate each other
    autograd_lib.backprop_hess(output, hess_type=hess_type)
    autograd_lib.clear_hess_backprops(model)

    autograd_lib.backprop_hess(output, hess_type=hess_type)

    autograd_lib.compute_hess(model)
    autograd_lib.disable_hooks()

    for layer in model.modules():
        if not autograd_lib.is_supported(layer):
            continue
        for param in layer.parameters():
            loss = loss_fn(output, targets)
            hess_autograd = u.hessian(loss, param)
            hess = param.hess
            assert torch.allclose(hess, hess_autograd.reshape(hess.shape))
Example #15
def test_full_hessian():
    u.seed_random(1)
    A, model = create_toy_model()
    data = A.t()
    #    data = data.repeat(3, 1)
    activations = {}

    hess = defaultdict(float)

    def save_activations(layer, a, _):
        activations[layer] = a

    with autograd_lib.module_hook(save_activations):
        Y = model(A.t())
        loss = torch.sum(Y * Y) / 2

    def compute_hess(layer, _, B):
        A = activations[layer]
        n = A.shape[0]

        di = A.shape[1]
        do = B.shape[1]

        BA = torch.einsum("nl,ni->nli", B, A)
        hess[layer] += torch.einsum('nli,nkj->likj', BA, BA)

    with autograd_lib.module_hook(compute_hess):
        autograd_lib.backprop_identity(Y, retain_graph=True)

    # check against autograd
    hess_autograd = u.hessian(loss, model.layers[0].weight)
    hess0 = hess[model.layers[0]]
    u.check_equal(hess_autograd, hess0)

    # check against manual solution
    u.check_equal(hess0.reshape(4, 4),
                  [[425, -75, 170, -30], [-75, 225, -30, 90],
                   [170, -30, 680, -120], [-30, 90, -120, 360]])
Example #16
def test_cross_entropy_soft():
    torch.set_default_dtype(torch.float32)

    q = torch.tensor([0.4, 0.6]).unsqueeze(0).float()
    p = torch.tensor([0.7, 0.3]).unsqueeze(0).float()
    observed_logit = u.to_logits(p)

    # Compare against other loss functions
    # https://www.wolframcloud.com/obj/user-eac9ee2d-7714-42da-8f84-bec1603944d5/newton/logistic-hessian.nb

    loss1 = F.binary_cross_entropy(p[0], q[0])
    u.check_close(loss1, 0.865054)

    loss_fn = u.CrossEntropySoft()
    loss2 = loss_fn(observed_logit, q)
    u.check_close(loss2, loss1)

    loss3 = F.cross_entropy(observed_logit, torch.tensor([0]))
    u.check_close(loss3, loss_fn(observed_logit, torch.tensor([[1, 0.]])))

    # check gradient
    observed_logit.requires_grad = True
    grad = torch.autograd.grad(loss_fn(observed_logit, target=q),
                               observed_logit)
    u.check_close(p - q, grad[0])

    # check Hessian
    observed_logit = u.to_logits(p)
    observed_logit.zero_()
    observed_logit.requires_grad = True
    hessian_autograd = u.hessian(loss_fn(observed_logit, target=q),
                                 observed_logit)
    hessian_autograd = hessian_autograd.reshape((p.numel(), p.numel()))
    p = F.softmax(observed_logit, dim=1)
    hessian_manual = torch.diag(p[0]) - p.t() @ p
    u.check_close(hessian_autograd, hessian_manual)
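
The manual Hessian above uses the standard fact that the Hessian of softmax cross-entropy with respect to the logits is diag(p) - p p^T, independent of the (soft) target. A short standalone check of the same identity against plain F.cross_entropy and autograd (the logits and target below are arbitrary):

import torch
import torch.nn.functional as F

logits = torch.tensor([[0.3, -1.2, 0.5]])
target = torch.tensor([1])

H = torch.autograd.functional.hessian(
    lambda z: F.cross_entropy(z, target), logits).reshape(3, 3)

p = F.softmax(logits, dim=1)
H_manual = torch.diag(p[0]) - p.t() @ p
assert torch.allclose(H, H_manual, atol=1e-6)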
Example #17
def test_explicit_hessian():
    """Check computation of hessian of loss(B'WA) from https://github.com/yaroslavvb/kfac_pytorch/blob/master/derivation.pdf


    """

    torch.set_default_dtype(torch.float64)
    A = torch.tensor([[-1., 4], [3, 0]])
    B = torch.tensor([[-4., 3], [2, 6]])
    X = torch.tensor([[-5., 0], [-2, -6]], requires_grad=True)

    Y = B.t() @ X @ A
    u.check_equal(Y, [[-52, 64], [-81, -108]])
    loss = torch.sum(Y * Y) / 2
    hess0 = u.hessian(loss, X).reshape([4, 4])
    hess1 = u.Kron(A @ A.t(), B @ B.t())

    u.check_equal(loss, 12512.5)

    # PyTorch autograd computes the Hessian with respect to row-vectorized parameters, whereas
    # autograd_lib uses the math convention and column-vectorizes.
    # Commuting the order of the Kronecker product switches between the two representations.
    u.check_equal(hess1.commute(), hess0)

    # Do a test using Linear layers instead of matrix multiplies
    model: u.SimpleFullyConnected2 = u.SimpleFullyConnected2([2, 2, 2], bias=False)
    model.layers[0].weight.data.copy_(X)

    # Transpose to match previous results, layers treat dim0 as batch dimension
    u.check_equal(model.layers[0](A.t()).t(), [[5, -20], [-16, -8]])  # XA = (A'X0)'

    model.layers[1].weight.data.copy_(B.t())
    u.check_equal(model(A.t()).t(), Y)

    Y = model(A.t()).t()    # transpose to data-dimension=columns
    loss = torch.sum(Y * Y) / 2
    loss.backward()

    u.check_equal(model.layers[0].weight.grad, [[-2285, -105], [-1490, -1770]])
    G = B @ Y @ A.t()
    u.check_equal(model.layers[0].weight.grad, G)

    u.check_equal(hess0, u.Kron(B @ B.t(), A @ A.t()))

    # compute newton step
    u.check_equal(u.Kron(A @ A.t(), B @ B.t()).pinv() @ u.vec(G), u.v2c([-5, -2, 0, -6]))

    # compute Newton step using factored representation
    autograd_lib.add_hooks(model)

    Y = model(A.t())
    n = 2
    loss = torch.sum(Y * Y) / 2
    autograd_lib.backprop_hess(Y, hess_type='LeastSquares')
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron', vecr_order=False, loss_aggregation='sum')
    param = model.layers[0].weight

    hess2 = param.hess_kron
    print(hess2)

    u.check_equal(hess2, [[425, 170, -75, -30], [170, 680, -30, -120], [-75, -30, 225, 90], [-30, -120, 90, 360]])

    # Gradient test
    model.zero_grad()
    loss.backward()
    u.check_close(u.vec(G).flatten(), u.Vec(param.grad))

    # Newton step test
    # Method 0: PyTorch native autograd
    newton_step0 = param.grad.flatten() @ torch.pinverse(hess0)
    newton_step0 = newton_step0.reshape(param.shape)
    u.check_equal(newton_step0, [[-5, 0], [-2, -6]])

    # Method 1: column major order
    ihess2 = hess2.pinv()
    u.check_equal(ihess2.LL, [[1/16, 1/48], [1/48, 17/144]])
    u.check_equal(ihess2.RR, [[2/45, -(1/90)], [-(1/90), 1/36]])
    u.check_equal(torch.flatten(hess2.pinv() @ u.vec(G)), [-5, -2, 0, -6])
    newton_step1 = (ihess2 @ u.Vec(param.grad)).matrix_form()

    # Method 2: row major order
    ihess2_rowmajor = ihess2.commute()
    newton_step2 = ihess2_rowmajor @ u.Vecr(param.grad)
    newton_step2 = newton_step2.matrix_form()

    u.check_equal(newton_step0, newton_step1)
    u.check_equal(newton_step0, newton_step2)
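
A standalone check of the identity this test builds on: with Y = B'XA and loss = sum(Y*Y)/2, the gradient is (BB')X(AA'), so under PyTorch's row-major flattening the Hessian is kron(BB', AA'); commuting the two factors gives the column-major (math-convention) form.

import torch

A = torch.tensor([[-1., 4], [3, 0]], dtype=torch.float64)
B = torch.tensor([[-4., 3], [2, 6]], dtype=torch.float64)
X0 = torch.tensor([[-5., 0], [-2, -6]], dtype=torch.float64)

def loss_fn(X):
    Y = B.t() @ X @ A
    return torch.sum(Y * Y) / 2

H = torch.autograd.functional.hessian(loss_fn, X0).reshape(4, 4)
assert torch.allclose(H, torch.kron(B @ B.t(), A @ A.t()))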
Example #18
def main():
    attempt_count = 0
    while os.path.exists(f"{args.logdir}{attempt_count:02d}"):
        attempt_count += 1
    logdir = f"{args.logdir}{attempt_count:02d}"

    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")
    u.seed_random(1)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method

    except Exception as e:
        print(f"wandb crash with {e}")

    #    data_width = 4
    #    targets_width = 2

    d1 = args.data_width**2
    d2 = 10
    d3 = args.targets_width**2
    o = d3
    n = args.stats_batch_size
    d = [d1, d2, d3]
    model = u.SimpleFullyConnected(d, nonlin=args.nonlin)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=False,
        drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.stats_batch_size,
        shuffle=False,
        drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)

    def capture_activations(module, input, _output):
        if skip_forward_hooks:
            return
        assert gl.backward_idx == 0  # no need to forward-prop on Hessian computation
        assert not hasattr(
            module, 'activations'
        ), "Seeing activations from previous forward, call util.zero_grad to clear"
        assert len(input) == 1, "this works for single input layers only"
        setattr(module, "activations", input[0].detach())

    def capture_backprops(module: nn.Module, _input, output):
        if skip_backward_hooks:
            return
        assert len(output) == 1, "this works for single variable layers only"
        if gl.backward_idx == 0:
            assert not hasattr(
                module, 'backprops'
            ), "Seeing results of previous autograd, call util.zero_grad to clear"
            setattr(module, 'backprops', [])
        assert gl.backward_idx == len(module.backprops)
        module.backprops.append(output[0])

    def save_grad(param: nn.Parameter) -> Callable[[torch.Tensor], None]:
        """Hook to save gradient into 'param.saved_grad', so it can be accessed after model.zero_grad(). Only stores gradient
        if the value has not been set, call util.zero_grad to clear it."""
        def save_grad_fn(grad):
            if not hasattr(param, 'saved_grad'):
                setattr(param, 'saved_grad', grad)

        return save_grad_fn

    for layer in model.layers:
        layer.register_forward_hook(capture_activations)
        layer.register_backward_hook(capture_backprops)
        layer.weight.register_hook(save_grad(layer.weight))

    def loss_fn(data, targets):
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == len(targets)
        return torch.sum(err * err) / 2 / len(data)

    gl.token_count = 0
    for step in range(args.stats_steps):
        data, targets = next(stats_iter)
        skip_forward_hooks = False
        skip_backward_hooks = False

        # get gradient values
        gl.backward_idx = 0
        u.zero_grad(model)
        output = model(data)
        loss = loss_fn(output, targets)
        loss.backward(retain_graph=True)

        print("loss", loss.item())

        # get Hessian values
        skip_forward_hooks = True
        id_mat = torch.eye(o)

        u.log_scalars({'loss': loss.item()})

        # o = 0
        for out_idx in range(o):
            model.zero_grad()
            # backprop to get section of batch output jacobian for output at position out_idx
            output = model(
                data
            )  # opt: using autograd.grad means I don't have to zero_grad
            ei = id_mat[out_idx]
            bval = torch.stack([ei] * n)
            gl.backward_idx = out_idx + 1
            output.backward(bval)
        skip_backward_hooks = True  #

        for (i, layer) in enumerate(model.layers):
            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops[0] * n
            assert B_t.shape == (n, d[i + 1])

            G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]),
                              layer.weight.saved_grad)

            # empirical Fisher
            efisher = G.t() @ G / n
            sigma = efisher - g.t() @ g
            # u.dump(sigma, f'/tmp/sigmas/{step}-{i}')
            s.sigma_l2 = u.l2_norm(sigma)

            #############################
            # Hessian stats
            #############################
            A_t = layer.activations
            Bh_t = [layer.backprops[out_idx + 1] for out_idx in range(o)]
            Amat_t = torch.cat([A_t] * o, dim=0)
            Bmat_t = torch.cat(Bh_t, dim=0)

            assert Amat_t.shape == (n * o, d[i])
            assert Bmat_t.shape == (n * o, d[i + 1])

            Jb = u.khatri_rao_t(Bmat_t,
                                Amat_t)  # batch Jacobian, in row-vec format
            H = Jb.t() @ Jb / n
            pinvH = u.pinv(H)

            s.hess_l2 = u.l2_norm(H)
            s.invhess_l2 = u.l2_norm(pinvH)

            s.hess_fro = H.flatten().norm()
            s.invhess_fro = pinvH.flatten().norm()

            s.jacobian_l2 = u.l2_norm(Jb)
            s.grad_fro = g.flatten().norm()
            s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1],
                                                d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            #  u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) -
                                          0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() /
                                          dd.flatten().norm()**2)

            s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)
            s.grad_curv = curv_direction(g)
            ndir = g @ u.pinv(H)  # newton direction
            s.newton_curv = curv_direction(ndir)
            setattr(layer.weight, 'pre',
                    u.pinv(H))  # save Newton preconditioner
            s.step_openai = 1 / s.grad_curv if s.grad_curv else 999

            s.newton_fro = ndir.flatten().norm(
            )  # frobenius norm of Newton update
            s.regret_gradient = loss_direction(g, s.step_openai)

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        for i in range(args.train_steps):
            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()

            u.log_scalar(train_loss=loss.item())

            if args.method != 'newton':
                optimizer.step()
            else:
                for (layer_idx, layer) in enumerate(model.layers):
                    param: torch.nn.Parameter = layer.weight
                    param_data: torch.Tensor = param.data
                    param_data.copy_(param_data - 0.1 * param.grad)
                    if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                        continue
                    u.nan_check(layer.weight.pre)
                    u.nan_check(param.grad.flatten())
                    u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                    param_new_flat = u.v2r(param_data.flatten()) - u.v2r(
                        param.grad.flatten()) @ layer.weight.pre
                    u.nan_check(param_new_flat)
                    param_data.copy_(param_new_flat.reshape(param_data.shape))

            gl.token_count += data.shape[0]

    gl.event_writer.close()
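
The gradient statistics above rest on the identity that, for a linear layer, the per-example gradient of the loss with respect to the weight equals the outer product of that example's backprop and activation, so the rows of khatri_rao_t(B_t, A_t) are per-example gradients. A standalone sanity check of that identity in plain PyTorch (least-squares loss and sizes are illustrative assumptions):

import torch

torch.manual_seed(0)
n, d_in, d_out = 4, 3, 2
X = torch.randn(n, d_in)
T = torch.randn(n, d_out)
W = torch.randn(d_out, d_in, requires_grad=True)

def example_loss(i):
    err = X[i:i + 1] @ W.t() - T[i:i + 1]
    return torch.sum(err * err) / 2

B = (X @ W.t() - T).detach()            # per-example backprops at the layer output
for i in range(n):
    g_i = torch.autograd.grad(example_loss(i), W)[0]
    assert torch.allclose(g_i, torch.outer(B[i], X[i]), atol=1e-6)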
Example #19
def test_kron_conv_exact():
    """Test per-example gradient computation for conv layer.


    Kronecker factoring is exact for 1x1 convolutions and linear activations.

    """
    u.seed_random(1)

    n, Xh, Xw = 2, 2, 2
    Kh, Kw = 1, 1
    dd = [2, 2, 2]
    o = dd[-1]

    model: u.SimpleModel = u.PooledConvolutional2(dd, kernel_size=(Kh, Kw), nonlin=False, bias=True)
    data = torch.randn((n, dd[0], Xh, Xw))

    #print(model)
    #print(data)

    loss_type = 'CrossEntropy'    #  loss_type = 'LeastSquares'
    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:    # CrossEntropy
        loss_fn = nn.CrossEntropyLoss()

    sample_output = model(data)

    if loss_type.endswith('LeastSquares'):
        targets = torch.randn(sample_output.shape)
    elif loss_type == 'CrossEntropy':
        targets = torch.LongTensor(n).random_(0, o)

    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='mean_kron')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    for i in range(len(model.layers)):
        layer = model.layers[i]

        # direct Hessian computation
        H = layer.weight.hess
        H_bias = layer.bias.hess

        # factored Hessian computation
        Hk = layer.weight.hess_factored
        Hk_bias = layer.bias.hess_factored
        Hk = Hk.expand()
        Hk_bias = Hk_bias.expand()

        # autograd Hessian computation
        loss = loss_fn(output, targets)
        Ha = u.hessian(loss, layer.weight).reshape(H.shape)
        Ha_bias = u.hessian(loss, layer.bias)

        # compare direct against autograd
        Ha = Ha.reshape(H.shape)
        # rel_error = torch.max((H-Ha)/Ha)

        u.check_close(H, Ha, rtol=1e-5, atol=1e-7)
        u.check_close(Ha_bias, H_bias, rtol=1e-5, atol=1e-7)

        u.check_close(H_bias, Hk_bias)
        u.check_close(H, Hk)
Example #20
    def compute_layer_stats(layer):
        stats = AttrDefault(str, {})
        n = stats_batch_size
        param = u.get_param(layer)
        d = len(param.flatten())
        layer_idx = model.layers.index(layer)
        assert layer_idx >= 0
        assert stats_data.shape[0] == n

        def backprop_loss():
            model.zero_grad()
            output = model(
                stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            loss.backward()
            return loss, output

        def backprop_output():
            model.zero_grad()
            output = model(stats_data)
            output.backward(gradient=torch.ones_like(output))
            return output

        # per-example gradients, n, d
        loss, output = backprop_loss()
        At = layer.data_input
        Bt = layer.grad_output * n
        G = u.khatri_rao_t(At, Bt)
        g = G.sum(dim=0, keepdim=True) / n
        u.check_close(g, u.vec(param.grad).t())

        stats.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2

        stats.gradient_norm = g.flatten().norm()
        stats.parameter_norm = param.data.flatten().norm()
        pos_activations = torch.sum(layer.data_output > 0)
        neg_activations = torch.sum(layer.data_output <= 0)
        stats.sparsity = pos_activations.float() / (pos_activations +
                                                    neg_activations)

        output = backprop_output()
        At2 = layer.data_input
        u.check_close(At, At2)
        B2t = layer.grad_output
        J = u.khatri_rao_t(At, B2t)
        H = J.t() @ J / n

        model.zero_grad()
        output = model(stats_data)  # use last saved data batch for backprop
        loss = compute_loss(output, stats_targets)
        hess = u.hessian(loss, param)

        hess = hess.transpose(2, 3).transpose(0, 1).reshape(d, d)
        u.check_close(hess, H)

        stats.hessian_norm = u.l2_norm(H)
        stats.jacobian_norm = u.l2_norm(J)
        Joutput = J.sum(dim=0) / n
        stats.jacobian_sensitivity = Joutput.norm()

        # newton decrement
        stats.loss_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)
        u.check_close(stats.loss_newton, loss)

        # do line-search to find optimal step
        def line_search(directionv, start, end, steps=10):
            """Takes steps between start and end, returns steps+1 loss entries"""
            param0 = param.data.clone()
            param0v = u.vec(param0).t()
            losses = []
            for i in range(steps + 1):
                output = model(
                    stats_data)  # use last saved data batch for backprop
                loss = compute_loss(output, stats_targets)
                losses.append(loss)
                offset = start + i * ((end - start) / steps)
                param1v = param0v + offset * directionv

                param1 = u.unvec(param1v.t(), param.data.shape[0])
                param.data.copy_(param1)

            output = model(
                stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            losses.append(loss)

            param.data.copy_(param0)
            return losses

        # try to take a newton step
        gradv = g
        line_losses = line_search(-gradv @ u.pinv(H), 0, 2, steps=10)
        u.check_equal(line_losses[0], loss)
        u.check_equal(line_losses[6], 0)
        assert line_losses[5] > line_losses[6]
        assert line_losses[7] > line_losses[6]
        return stats
Example #21
def main():
    attempt_count = 0
    while os.path.exists(f"{args.logdir}{attempt_count:02d}"):
        attempt_count += 1
    logdir = f"{args.logdir}{attempt_count:02d}"

    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")
    u.seed_random(1)

    d1 = args.data_width**2
    d2 = 10
    d3 = args.targets_width**2
    o = d3
    n = args.stats_batch_size
    d = [d1, d2, d3]
    model = u.SimpleFullyConnected(d, nonlin=args.nonlin)
    model = model.to(gl.device)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['d1'] = d1
            wandb.config['d2'] = d2
            wandb.config['d3'] = d3
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)

    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=args.dataset_size)

    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=False,
        drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.stats_batch_size,
        shuffle=False,
        drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width,
                               targets_width=args.targets_width,
                               dataset_size=args.dataset_size,
                               train=False)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.stats_batch_size,
                                              shuffle=True,
                                              drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    skip_forward_hooks = False
    skip_backward_hooks = False

    def capture_activations(module: nn.Module, input: List[torch.Tensor],
                            output: torch.Tensor):
        if skip_forward_hooks:
            return
        assert not hasattr(
            module, 'activations'
        ), "Seeing results of previous autograd, call util.zero_grad to clear"
        assert len(input) == 1, "this was tested for single input layers only"
        setattr(module, "activations", input[0].detach())
        setattr(module, "output", output.detach())

    def capture_backprops(module: nn.Module, _input, output):
        if skip_backward_hooks:
            return
        assert len(output) == 1, "this works for single variable layers only"
        if gl.backward_idx == 0:
            assert not hasattr(
                module, 'backprops'
            ), "Seeing results of previous autograd, call util.zero_grad to clear"
            setattr(module, 'backprops', [])
        assert gl.backward_idx == len(module.backprops)
        module.backprops.append(output[0])

    def save_grad(param: nn.Parameter) -> Callable[[torch.Tensor], None]:
        """Hook to save gradient into 'param.saved_grad', so it can be accessed after model.zero_grad(). Only stores gradient
        if the value has not been set, call util.zero_grad to clear it."""
        def save_grad_fn(grad):
            if not hasattr(param, 'saved_grad'):
                setattr(param, 'saved_grad', grad)

        return save_grad_fn

    for layer in model.layers:
        layer.register_forward_hook(capture_activations)
        layer.register_backward_hook(capture_backprops)
        layer.weight.register_hook(save_grad(layer.weight))

    def loss_fn(data, targets):
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == len(targets)
        return torch.sum(err * err) / 2 / len(data)

    gl.token_count = 0
    last_outer = 0
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars(
                {"time/outer": 1000 * (time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()
        # compute validation loss
        skip_forward_hooks = True
        skip_backward_hooks = True
        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            print("val_loss", val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        # compute stats
        data, targets = next(stats_iter)
        skip_forward_hooks = False
        skip_backward_hooks = False

        # get gradient values
        with u.timeit("backprop_g"):
            gl.backward_idx = 0
            u.zero_grad(model)
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)

        # get Hessian values
        skip_forward_hooks = True
        id_mat = torch.eye(o).to(gl.device)

        u.log_scalar(loss=loss.item())

        with u.timeit("backprop_H"):
            # optionally use randomized low-rank approximation of Hessian
            hess_rank = args.hess_samples if args.hess_samples else o

            for out_idx in range(hess_rank):
                model.zero_grad()
                # backprop to get section of batch output jacobian for output at position out_idx
                output = model(
                    data
                )  # opt: using autograd.grad means I don't have to zero_grad
                if args.hess_samples:
                    bval = torch.LongTensor(n, o).to(gl.device).random_(
                        0, 2) * 2 - 1
                    bval = bval.float()
                else:
                    ei = id_mat[out_idx]
                    bval = torch.stack([ei] * n)
                gl.backward_idx = out_idx + 1
                output.backward(bval)
            skip_backward_hooks = True  #

        for (i, layer) in enumerate(model.layers):
            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]),
                              layer.weight.saved_grad)

            s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel()
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g
                s.sigma_l2 = u.sym_l2_norm(sigma)
                s.sigma_erank = torch.trace(sigma) / s.sigma_l2

            #############################
            # Hessian stats
            #############################
            A_t = layer.activations
            Bh_t = [
                layer.backprops[out_idx + 1] for out_idx in range(hess_rank)
            ]
            Amat_t = torch.cat([A_t] * hess_rank, dim=0)
            Bmat_t = torch.cat(Bh_t, dim=0)

            assert Amat_t.shape == (n * hess_rank, d[i])
            assert Bmat_t.shape == (n * hess_rank, d[i + 1])

            lambda_regularizer = args.lmb * torch.eye(d[i] * d[i + 1]).to(
                gl.device)
            with u.timeit(f"khatri_H-{i}"):
                Jb = u.khatri_rao_t(
                    Bmat_t, Amat_t)  # batch Jacobian, in row-vec format

            with u.timeit(f"H-{i}"):
                H = Jb.t() @ Jb / n

            with u.timeit(f"invH-{i}"):
                invH = torch.cholesky_inverse(H + lambda_regularizer)

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.jacobian_fro = Jb.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1],
                                                d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            #  u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) -
                                          0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() /
                                          (dd.flatten().norm()**2))

            with u.timeit("pinvH"):
                pinvH = u.pinv(H)

            with u.timeit(f'curv-{i}'):
                s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre',
                        pinvH)  # save Newton preconditioner
                s.step_openai = 1 / s.grad_curv if s.grad_curv else 999
                s.step_max = 2 / u.sym_l2_norm(H)
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm(
                )  # frobenius norm of Newton update
                s.regret_gradient = loss_direction(g, s.step_openai)

            with u.timeit(f'rho-{i}'):
                p_sigma = u.lyapunov_svd(H, sigma)
                if u.has_nan(
                        p_sigma) and args.compute_rho:  # use expensive method
                    H0 = H.cpu().detach().numpy()
                    sigma0 = sigma.cpu().detach().numpy()
                    p_sigma = scipy.linalg.solve_lyapunov(H0, sigma0)
                    p_sigma = torch.tensor(p_sigma).to(gl.device)

                if u.has_nan(p_sigma):
                    s.psigma_erank = H.shape[0]
                    s.rho = 1
                else:
                    s.psigma_erank = u.sym_erank(p_sigma)
                    s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                print('openai batch', s.batch_openai)
                s.diversity = torch.norm(G, "fro")**2 / torch.norm(g)**2

                # s.noise_variance = torch.trace(H.inverse() @ sigma)
                # try:
                #     # this fails with singular sigma
                #     s.noise_variance = torch.trace(torch.solve(sigma, H)[0])
                #     # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0])
                #     pass
                # except RuntimeError as _:
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        last_inner = 0
        for i in range(args.train_steps):
            if last_inner:
                u.log_scalars(
                    {"time/inner": 1000 * (time.perf_counter() - last_inner)})
            last_inner = time.perf_counter()

            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()

            u.log_scalar(train_loss=loss.item())

            if args.method != 'newton':
                optimizer.step()
            else:
                for (layer_idx, layer) in enumerate(model.layers):
                    param: torch.nn.Parameter = layer.weight
                    param_data: torch.Tensor = param.data
                    param_data.copy_(param_data - 0.1 * param.grad)
                    if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                        continue
                    u.nan_check(layer.weight.pre)
                    u.nan_check(param.grad.flatten())
                    u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                    param_new_flat = u.v2r(param_data.flatten()) - u.v2r(
                        param.grad.flatten()) @ layer.weight.pre
                    u.nan_check(param_new_flat)
                    param_data.copy_(param_new_flat.reshape(param_data.shape))

            gl.token_count += data.shape[0]

    gl.event_writer.close()
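
The batch_openai statistic computed above matches the Hessian-weighted gradient noise scale tr(H Sigma) / (g' H g) from McCandlish et al., "An Empirical Model of Large-Batch Training", with Sigma the empirical covariance of per-example gradients. A tiny standalone computation of the same quantity from stand-in values (the identity curvature matrix is only a placeholder):

import torch

torch.manual_seed(0)
n, p = 16, 6
G = torch.randn(n, p)                   # per-example gradients, one row per example
H = torch.eye(p)                        # placeholder curvature matrix (assumed SPD)

g = G.mean(dim=0, keepdim=True)         # average gradient, 1 x p
sigma = G.t() @ G / n - g.t() @ g       # empirical gradient covariance
noise_scale = torch.trace(H @ sigma) / (g @ H @ g.t())
print(noise_scale.item())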
Example #22
def test_conv_hessian():
    """Test per-example gradient computation for conv layer."""
    u.seed_random(1)
    n, Xc, Xh, Xw = 3, 2, 3, 7
    dd = [Xc, 2]

    Kh, Kw = 2, 3
    Oh, Ow = Xh - Kh + 1, Xw - Kw + 1
    model: u.SimpleModel = u.ReshapedConvolutional(dd, kernel_size=(Kh, Kw), bias=True)
    weight_buffer = model.layers[0].weight.data

    assert (Kh, Kw) == model.layers[0].kernel_size

    data = torch.randn((n, Xc, Xh, Xw))

    # output channels, input channels, height, width
    assert weight_buffer.shape == (dd[1], dd[0], Kh, Kw)

    def loss_fn(data):
        err = data.reshape(len(data), -1)
        return torch.sum(err * err) / 2 / len(data)

    loss_hessian = u.HessianExactSqrLoss()
    # o = Oh * Ow * dd[1]

    output = model(data)
    o = output.shape[1]
    for bval in loss_hessian(output):
        output.backward(bval, retain_graph=True)
    assert loss_hessian.num_samples == o

    i, layer = next(enumerate(model.layers))

    At = unfold(layer.activations, (Kh, Kw))    # -> n, Xc * Kh * Kw, Oh * Ow
    assert At.shape == (n, dd[0] * Kh * Kw, Oh*Ow)

    #  o, n, dd[1], Oh, Ow -> o, n, dd[1], Oh*Ow
    Bh_t = torch.stack([Bt.reshape(n, dd[1], Oh*Ow) for Bt in layer.backprops_list])
    assert Bh_t.shape == (o, n, dd[1], Oh*Ow)
    Ah_t = torch.stack([At]*o)
    assert Ah_t.shape == (o, n, dd[0] * Kh * Kw, Oh*Ow)

    # sum out the output patch dimension
    Jb = torch.einsum('onij,onkj->onik', Bh_t, Ah_t)  # => o, n, dd[1], dd[0] * Kh * Kw
    Hi = torch.einsum('onij,onkl->nijkl', Jb, Jb)     # => n, dd[1], dd[0]*Kh*Kw, dd[1], dd[0]*Kh*Kw
    Jb_bias = torch.einsum('onij->oni', Bh_t)
    Hb_i = torch.einsum('oni,onj->nij', Jb_bias, Jb_bias)
    H = Hi.mean(dim=0)
    Hb = Hb_i.mean(dim=0)

    model.disable_hooks()
    loss = loss_fn(model(data))
    H_autograd = u.hessian(loss, layer.weight)
    assert H_autograd.shape == (dd[1], dd[0], Kh, Kw, dd[1], dd[0], Kh, Kw)
    assert H.shape == (dd[1], dd[0]*Kh*Kw, dd[1], dd[0]*Kh*Kw)
    u.check_close(H, H_autograd.reshape(H.shape), rtol=1e-4, atol=1e-7)

    Hb_autograd = u.hessian(loss, layer.bias)
    assert Hb_autograd.shape == (dd[1], dd[1])
    u.check_close(Hb, Hb_autograd)

    assert len(Bh_t) == loss_hessian.num_samples == o
    for xi in range(n):
        loss = loss_fn(model(data[xi:xi + 1, ...]))
        H_autograd = u.hessian(loss, layer.weight)
        u.check_close(Hi[xi], H_autograd.reshape(H.shape))
        Hb_autograd = u.hessian(loss, layer.bias)
        u.check_close(Hb_i[xi], Hb_autograd)
        assert Hb_i[xi, 0, 0] == Oh*Ow   # each output has curvature 1, bias term adds up Oh*Ow of them
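# --- Editor's sketch (not part of the original example) ----------------------
# The test above relies on unfold() turning a convolution into a matrix product
# with the flattened kernel, which is what lets conv activations be handled
# like a linear layer.  A stand-alone check of that identity (shapes ad hoc,
# no padding or stride):
import torch
import torch.nn.functional as F

n, Cin, Cout, Xh, Xw, Kh, Kw = 3, 2, 4, 5, 6, 2, 3
x = torch.randn(n, Cin, Xh, Xw)
w = torch.randn(Cout, Cin, Kh, Kw)

out_conv = F.conv2d(x, w)                      # n, Cout, Oh, Ow
cols = F.unfold(x, (Kh, Kw))                   # n, Cin*Kh*Kw, Oh*Ow
out_mat = w.reshape(Cout, -1) @ cols           # n, Cout, Oh*Ow
assert torch.allclose(out_conv.reshape(n, Cout, -1), out_mat, atol=1e-4)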
Example #23
def test_factored_hessian():
    """"Simple test to ensure Hessian computation is working.

    In a linear neural network with squared loss, Newton step will converge in one step.
    Compute stats after minimizing, pass sanity checks.
    """

    u.seed_random(1)
    loss_type = 'LeastSquares'

    data_width = 2
    n = 5
    d1 = data_width ** 2
    o = 10
    d = [d1, o]

    model = u.SimpleFullyConnected2(d, bias=False, nonlin=False)
    model = model.to(gl.device)
    print(model)

    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=n, shuffle=False)
    stats_iter = u.infinite_iter(stats_loader)
    stats_data, stats_targets = next(stats_iter)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:  # loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()
    last_outer = 0

    data, targets = stats_data, stats_targets

    # Capture Hessian and gradient stats
    autograd_lib.enable_hooks()
    autograd_lib.clear_backprops(model)

    output = model(data)
    loss = loss_fn(output, targets)
    print(loss)
    loss.backward(retain_graph=True)
    layer = model.layers[0]

    autograd_lib.clear_hess_backprops(model)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.disable_hooks()

    # compute Hessian using direct method, compare against PyTorch autograd
    hess0 = u.hessian(loss, layer.weight)
    autograd_lib.compute_hess(model)
    hess1 = layer.weight.hess
    print(hess1)
    u.check_close(hess0.reshape(hess1.shape), hess1, atol=1e-9, rtol=1e-6)

    # compute Hessian using factored method
    autograd_lib.compute_hess(model, method='kron', attr_name='hess2', vecr_order=True)
    # s.regret_newton = vecG.t() @ pinvH.commute() @ vecG.t() / 2  # TODO(y): figure out why needed transposes

    hess2 = layer.weight.hess2
    u.check_close(hess1, hess2, atol=1e-9, rtol=1e-6)

    # Newton step in regular notation
    g1 = layer.weight.grad.flatten()
    newton1 = hess1 @ g1

    g2 = u.Vecr(layer.weight.grad)
    newton2 = g2 @ hess2

    u.check_close(newton1, newton2, atol=1e-9, rtol=1e-6)

    # compute regret in factored notation, compare against actual drop in loss
    regret1 = g1 @ hess1.pinverse() @ g1 / 2
    regret2 = g2 @ hess2.pinv() @ g2 / 2
    u.check_close(regret1, regret2)

    current_weight = layer.weight.detach().clone()
    param: torch.nn.Parameter = layer.weight
    # param.data.sub_((hess1.pinverse() @ g1).reshape(param.shape))
    # output = model(data)
    # loss = loss_fn(output, targets)
    # print("result 1", loss)

    # param.data.sub_((hess1.pinverse() @ u.vec(layer.weight.grad)).reshape(param.shape))
    # output = model(data)
    # loss = loss_fn(output, targets)
    # print("result 2", loss)

    # param.data.sub_((u.vec(layer.weight.grad).t() @ hess1.pinverse()).reshape(param.shape))
    # output = model(data)
    # loss = loss_fn(output, targets)
    # print("result 3", loss)
    #

    del layer.weight.grad
    output = model(data)
    loss = loss_fn(output, targets)
    loss.backward()
    param.data.sub_(u.unvec(hess1.pinverse() @ u.vec(layer.weight.grad), layer.weight.shape[0]))
    output = model(data)
    loss = loss_fn(output, targets)
    print("result 4", loss)

    # param.data.sub_((g1 @ hess1.pinverse() @ g1).reshape(param.shape))

    print(loss)
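# --- Editor's sketch (not part of the original example) ----------------------
# The factored objects above (u.Kron, u.vec, .pinv()) lean on two standard
# Kronecker identities; here is a self-contained check with plain torch
# (matrices chosen ad hoc, column-major vec):
#   (A kron B) vec(X) = vec(B X A^T)
#   pinv(A kron B)    = pinv(A) kron pinv(B)
import torch

torch.manual_seed(0)
A = torch.randn(3, 3); A = A @ A.t() + torch.eye(3)   # well-conditioned
B = torch.randn(2, 2); B = B @ B.t() + torch.eye(2)
X = torch.randn(2, 3)

def vec(M):
    return M.t().reshape(-1)                           # column-major vectorization

assert torch.allclose(torch.kron(A, B) @ vec(X), vec(B @ X @ A.t()), atol=1e-4)
assert torch.allclose(torch.kron(A, B).pinverse(),
                      torch.kron(A.pinverse(), B.pinverse()), atol=1e-4)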
Example #24
def test_hessian():
    """Tests of Hessian computation."""
    u.seed_random(1)
    batch_size = 500

    data_width = 4
    targets_width = 4

    d1 = data_width ** 2
    d2 = 10
    d3 = targets_width ** 2
    o = d3
    N = batch_size
    d = [d1, d2, d3]

    dataset = u.TinyMNIST(data_width=data_width, targets_width=targets_width, dataset_size=batch_size)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)
    data, targets = next(train_iter)

    def loss_fn(data, targets):
        assert len(data) == len(targets)
        err = data - targets.view(-1, data.shape[1])
        return torch.sum(err * err) / 2 / len(data)

    u.seed_random(1)
    model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=False, bias=True)

    # backprop hessian and compare against autograd
    hessian_backprop = u.HessianExactSqrLoss()
    output = model(data)
    for bval in hessian_backprop(output):
        output.backward(bval, retain_graph=True)

    i, layer = next(enumerate(model.layers))
    A_t = layer.activations
    Bh_t = layer.backprops_list
    H, Hb = u.hessian_from_backprops(A_t, Bh_t, bias=True)

    model.disable_hooks()
    H_autograd = u.hessian(loss_fn(model(data), targets), layer.weight)
    u.check_close(H, H_autograd.reshape(d[i + 1] * d[i], d[i + 1] * d[i]),
                  rtol=1e-4, atol=1e-7)
    Hb_autograd = u.hessian(loss_fn(model(data), targets), layer.bias)
    u.check_close(Hb, Hb_autograd, rtol=1e-4, atol=1e-7)

    # check first few per-example Hessians
    Hi, Hb_i = u.per_example_hess(A_t, Bh_t, bias=True)
    u.check_close(H, Hi.mean(dim=0))
    u.check_close(Hb, Hb_i.mean(dim=0), atol=2e-6, rtol=1e-5)

    for xi in range(5):
        loss = loss_fn(model(data[xi:xi + 1, ...]), targets[xi:xi + 1])
        H_autograd = u.hessian(loss, layer.weight)
        u.check_close(Hi[xi], H_autograd.reshape(d[i + 1] * d[i], d[i + 1] * d[i]))
        Hbias_autograd = u.hessian(loss, layer.bias)
        u.check_close(Hb_i[xi], Hbias_autograd)

    # get subsampled Hessian
    u.seed_random(1)
    model = u.SimpleFullyConnected(d, nonlin=False)
    hessian_backprop = u.HessianSampledSqrLoss(num_samples=1)

    output = model(data)
    for bval in hessian_backprop(output):
        output.backward(bval, retain_graph=True)
    model.disable_hooks()
    i, layer = next(enumerate(model.layers))
    H_approx1 = u.hessian_from_backprops(layer.activations, layer.backprops_list)

    # get subsampled Hessian with more samples
    u.seed_random(1)
    model = u.SimpleFullyConnected(d, nonlin=False)

    hessian_backprop = u.HessianSampledSqrLoss(num_samples=o)
    output = model(data)
    for bval in hessian_backprop(output):
        output.backward(bval, retain_graph=True)
    model.disable_hooks()
    i, layer = next(enumerate(model.layers))
    H_approx2 = u.hessian_from_backprops(layer.activations, layer.backprops_list)

    assert abs(u.l2_norm(H) / u.l2_norm(H_approx1) - 1) < 0.08, abs(u.l2_norm(H) / u.l2_norm(H_approx1) - 1)  # 0.0612
    assert abs(u.l2_norm(H) / u.l2_norm(H_approx2) - 1) < 0.03, abs(u.l2_norm(H) / u.l2_norm(H_approx2) - 1)  # 0.0239
    assert u.kl_div_cov(H_approx1, H) < 0.3, u.kl_div_cov(H_approx1, H)  # 0.222
    assert u.kl_div_cov(H_approx2, H) < 0.2, u.kl_div_cov(H_approx2, H)  # 0.1233
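# --- Editor's sketch (not part of the original example) ----------------------
# For squared loss and a single linear layer, with PyTorch's row-major
# flattening the weight Hessian equals kron(I_out, X^T X / n) and does not
# depend on the targets.  Minimal autograd check (sizes ad hoc):
import torch

torch.manual_seed(0)
n, d_in, d_out = 5, 3, 2
X = torch.randn(n, d_in)
T = torch.randn(n, d_out)

def loss_of(W):
    err = X @ W.t() - T
    return (err * err).sum() / 2 / n

W0 = torch.randn(d_out, d_in)
H = torch.autograd.functional.hessian(loss_of, W0).reshape(d_out * d_in, d_out * d_in)
assert torch.allclose(H, torch.kron(torch.eye(d_out), X.t() @ X / n), atol=1e-5)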
Example #25
def test_hessian_multibatch():
    """Test that Kronecker-factored computations still work when splitting work over batches."""

    u.seed_random(1)

    # torch.set_default_dtype(torch.float64)

    gl.project_name = 'test'
    gl.logdir_base = '/tmp/runs'
    run_name = 'test_hessian_multibatch'
    u.setup_logdir_and_event_writer(run_name=run_name)

    loss_type = 'CrossEntropy'
    data_width = 2
    n = 4
    d1 = data_width ** 2
    o = 10
    d = [d1, o]

    model = u.SimpleFullyConnected2(d, bias=False, nonlin=False)
    model = model.to(gl.device)

    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=n, shuffle=False)
    stats_iter = u.infinite_iter(stats_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:  # loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()
    last_outer = 0

    stats_iter = u.infinite_iter(stats_loader)
    stats_data, stats_targets = next(stats_iter)
    data, targets = stats_data, stats_targets

    # Capture Hessian and gradient stats
    autograd_lib.enable_hooks()
    autograd_lib.clear_backprops(model)

    output = model(data)
    loss = loss_fn(output, targets)
    loss.backward(retain_graph=True)
    layer = model.layers[0]

    autograd_lib.clear_hess_backprops(model)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.disable_hooks()

    # compute Hessian using direct method, compare against PyTorch autograd
    hess0 = u.hessian(loss, layer.weight)
    autograd_lib.compute_hess(model)
    hess1 = layer.weight.hess
    u.check_close(hess0.reshape(hess1.shape), hess1, atol=1e-8, rtol=1e-6)

    # compute Hessian using factored method. For cross-entropy the Hessian depends on the examples, so the factored approximation is not exact; raise tolerances
    autograd_lib.compute_hess(model, method='kron', attr_name='hess2', vecr_order=True)
    hess2 = layer.weight.hess2
    u.check_close(hess1, hess2, atol=1e-3, rtol=1e-1)

    # compute Hessian using multibatch
    # restart iterators
    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    assert n % 2 == 0
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=n//2, shuffle=False)
    stats_iter = u.infinite_iter(stats_loader)
    autograd_lib.compute_cov(model, loss_fn, stats_iter, batch_size=n//2, steps=2)

    cov: autograd_lib.LayerCov = layer.cov
    hess2: u.Kron = hess2.commute()    # get back into AA x BB order
    u.check_close(cov.H.value(), hess2)
Example #26
def _test_explicit_hessian_refactored():
    """Check computation of hessian of loss(B'WA) from
    https://github.com/yaroslavvb/kfac_pytorch/blob/master/derivation.pdf
    """

    torch.set_default_dtype(torch.float64)
    A = torch.tensor([[-1., 4], [3, 0]])
    B = torch.tensor([[-4., 3], [2, 6]])
    X = torch.tensor([[-5., 0], [-2, -6]], requires_grad=True)

    Y = B.t() @ X @ A
    u.check_equal(Y, [[-52, 64], [-81, -108]])
    loss = torch.sum(Y * Y) / 2
    hess0 = u.hessian(loss, X).reshape([4, 4])
    hess1 = u.Kron(A @ A.t(), B @ B.t())

    u.check_equal(loss, 12512.5)

    # Do a test using Linear layers instead of matrix multiplies
    model: u.SimpleFullyConnected2 = u.SimpleFullyConnected2([2, 2, 2], bias=False)
    model.layers[0].weight.data.copy_(X)

    # Transpose to match previous results, layers treat dim0 as batch dimension
    u.check_equal(model.layers[0](A.t()).t(), [[5, -20], [-16, -8]])  # XA = (A'X0)'

    model.layers[1].weight.data.copy_(B.t())
    u.check_equal(model(A.t()).t(), Y)

    Y = model(A.t()).t()    # transpose to data-dimension=columns
    loss = torch.sum(Y * Y) / 2
    loss.backward()

    u.check_equal(model.layers[0].weight.grad, [[-2285, -105], [-1490, -1770]])
    G = B @ Y @ A.t()
    u.check_equal(model.layers[0].weight.grad, G)

    autograd_lib.register(model)
    activations_dict = autograd_lib.ModuleDict()  # todo(y): make save_activations ctx manager automatically create A
    with autograd_lib.save_activations(activations_dict):
        Y = model(A.t())

    Acov = autograd_lib.ModuleDict(autograd_lib.SecondOrderCov)
    for layer, activations in activations_dict.items():
        print(layer, activations)
        Acov[layer].accumulate(activations, activations)
    autograd_lib.set_default_activations(activations_dict)
    autograd_lib.set_default_Acov(Acov)

    B = autograd_lib.ModuleDict(autograd_lib.SymmetricFourthOrderCov)
    autograd_lib.backward_accum(Y, "identity", B, retain_graph=False)

    print(B[model.layers[0]])

    autograd_lib.backprop_hess(Y, hess_type='LeastSquares')
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron', vecr_order=False, loss_aggregation='sum')
    param = model.layers[0].weight

    hess2 = param.hess_kron
    print(hess2)

    u.check_equal(hess2, [[425, 170, -75, -30], [170, 680, -30, -120], [-75, -30, 225, 90], [-30, -120, 90, 360]])

    # Gradient test
    model.zero_grad()
    loss.backward()
    u.check_close(u.vec(G).flatten(), u.Vec(param.grad))

    # Newton step test
    # Method 0: PyTorch native autograd
    newton_step0 = param.grad.flatten() @ torch.pinverse(hess0)
    newton_step0 = newton_step0.reshape(param.shape)
    u.check_equal(newton_step0, [[-5, 0], [-2, -6]])

    # Method 1: column major order
    ihess2 = hess2.pinv()
    u.check_equal(ihess2.LL, [[1/16, 1/48], [1/48, 17/144]])
    u.check_equal(ihess2.RR, [[2/45, -(1/90)], [-(1/90), 1/36]])
    u.check_equal(torch.flatten(hess2.pinv() @ u.vec(G)), [-5, -2, 0, -6])
    newton_step1 = (ihess2 @ u.Vec(param.grad)).matrix_form()

    # Method 2: row major order
    ihess2_rowmajor = ihess2.commute()
    newton_step2 = ihess2_rowmajor @ u.Vecr(param.grad)
    newton_step2 = newton_step2.matrix_form()

    u.check_equal(newton_step0, newton_step1)
    u.check_equal(newton_step0, newton_step2)
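# --- Editor's sketch (not part of the original example) ----------------------
# The hardcoded values above follow from: for loss(X) = ||B^T X A||^2 / 2, the
# Hessian w.r.t. the row-major flattening of X is kron(B B^T, A A^T); the
# matrix checked against hess2 above is the commuted ordering kron(A A^T, B B^T).
# Direct autograd check with the same matrices:
import torch

A = torch.tensor([[-1., 4.], [3., 0.]])
B = torch.tensor([[-4., 3.], [2., 6.]])
X0 = torch.tensor([[-5., 0.], [-2., -6.]])

def loss_of(X):
    Y = B.t() @ X @ A
    return torch.sum(Y * Y) / 2

H = torch.autograd.functional.hessian(loss_of, X0).reshape(4, 4)
assert torch.allclose(H, torch.kron(B @ B.t(), A @ A.t()), atol=1e-4)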
Example #27
def test_main_autograd():
    u.seed_random(1)
    log_wandb = False
    autograd_check = True
    use_double = False

    logdir = u.get_unique_logdir('/tmp/autoencoder_test/run')

    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)

    batch_size = 5

    try:
        if log_wandb:
            wandb.init(project='test-autograd_test', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['batch'] = batch_size
    except Exception as e:
        print(f"wandb crash with {e}")

    data_width = 4
    targets_width = 2

    d1 = data_width ** 2
    d2 = 10
    d3 = targets_width ** 2
    o = d3
    n = batch_size
    d = [d1, d2, d3]
    model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=True, bias=True)
    if use_double:
        model = model.double()
    train_steps = 3

    dataset = u.TinyMNIST(data_width=data_width, targets_width=targets_width,
                          dataset_size=batch_size * train_steps)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)

    def loss_fn(data, targets):
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == batch_size
        return torch.sum(err * err) / 2 / len(data)

    loss_hessian = u.HessianExactSqrLoss()

    gl.token_count = 0
    for train_step in range(train_steps):
        data, targets = next(train_iter)
        if use_double:
            data, targets = data.double(), targets.double()

        # get gradient values
        model.skip_backward_hooks = False
        model.skip_forward_hooks = False
        u.clear_backprops(model)
        output = model(data)
        loss = loss_fn(output, targets)
        loss.backward(retain_graph=True)
        model.skip_forward_hooks = True

        output = model(data)
        for bval in loss_hessian(output):
            if use_double:
                bval = bval.double()
            output.backward(bval, retain_graph=True)

        model.skip_backward_hooks = True

        for (i, layer) in enumerate(model.layers):

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops_list[0] * n
            assert B_t.shape == (n, d[i + 1])

            # per example gradients
            G = u.khatri_rao_t(B_t, A_t)
            assert G.shape == (n, d[i+1] * d[i])
            Gbias = B_t
            assert Gbias.shape == (n, d[i + 1])

            # average gradient
            g = G.sum(dim=0, keepdim=True) / n
            gb = Gbias.sum(dim=0, keepdim=True) / n
            assert g.shape == (1, d[i] * d[i + 1])
            assert gb.shape == (1, d[i + 1])

            if autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad)
                u.check_close(torch.einsum('nj->j', B_t) / n, layer.bias.saved_grad)
                u.check_close(torch.mean(B_t, dim=0), layer.bias.saved_grad)
                u.check_close(torch.einsum('ni,nj->ij', B_t, A_t)/n, layer.weight.saved_grad)

            # empirical Fisher
            efisher = G.t() @ G / n
            _sigma = efisher - g.t() @ g

            #############################
            # Hessian stats
            #############################
            A_t = layer.activations
            Bh_t = [layer.backprops_list[out_idx + 1] for out_idx in range(o)]
            Amat_t = torch.cat([A_t] * o, dim=0)  # todo: can instead replace with a khatri-rao loop
            Bmat_t = torch.cat(Bh_t, dim=0)
            Amat_t2 = torch.stack([A_t]*o, dim=0)  # o, n, in_dim
            Bmat_t2 = torch.stack(Bh_t, dim=0)  # o, n, out_dim

            assert Amat_t.shape == (n * o, d[i])
            assert Bmat_t.shape == (n * o, d[i + 1])

            Jb = u.khatri_rao_t(Bmat_t, Amat_t)  # batch output Jacobian
            H = Jb.t() @ Jb / n
            Jb2 = torch.einsum('oni,onj->onij', Bmat_t2, Amat_t2)
            u.check_close(H.reshape(d[i+1], d[i], d[i+1], d[i]), torch.einsum('onij,onkl->ijkl', Jb2, Jb2)/n)

            Hbias = Bmat_t.t() @ Bmat_t / n
            u.check_close(Hbias, torch.einsum('ni,nj->ij', Bmat_t, Bmat_t) / n)

            if autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                Hbias_autograd = u.hessian(loss, layer.bias)
                u.check_close(H, H_autograd.reshape(d[i+1] * d[i], d[i+1] * d[i]))
                u.check_close(Hbias, Hbias_autograd)
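# --- Editor's sketch (not part of the original example) ----------------------
# The khatri_rao_t construction above builds per-example gradients as row-wise
# outer products of backprops and activations.  Stand-alone check for a
# bias-free linear layer with squared loss (names ad hoc):
import torch

torch.manual_seed(0)
n, d_in, d_out = 4, 3, 2
W = torch.randn(d_out, d_in, requires_grad=True)
X = torch.randn(n, d_in)
Y = X @ W.t()
loss = (Y * Y).sum() / 2 / n

# per-example backprops; multiply by n to undo the batch averaging, as above
B = torch.autograd.grad(loss, Y, retain_graph=True)[0] * n
G = torch.einsum('ni,nj->nij', B, X)     # per-example gradients, (n, d_out, d_in)

loss.backward()
assert torch.allclose(G.mean(dim=0), W.grad, atol=1e-6)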
Example #28
def main():

    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")

    d1 = args.data_width ** 2
    assert args.data_width == args.targets_width
    o = d1
    n = args.stats_batch_size
    d = [d1, 30, 30, 30, 20, 30, 30, 30, d1]

    # small values for debugging
    # loss_type = 'LeastSquares'
    loss_type = 'CrossEntropy'

    args.wandb = 0
    args.stats_steps = 10
    args.train_steps = 10
    args.stats_batch_size = 10
    args.data_width = 2
    args.targets_width = 2
    args.nonlin = False
    d1 = args.data_width ** 2
    d2 = 2
    d3 = args.targets_width ** 2

    if loss_type == 'CrossEntropy':
        d3 = 10
    o = d3
    n = args.stats_batch_size
    d = [d1, d2, d3]
    dsize = max(args.train_batch_size, args.stats_batch_size)+1

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin)
    model = model.to(gl.device)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    #optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.03)  # make 10x smaller for least-squares loss
    dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, dataset_size=dsize, original_targets=True)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_iter = None
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, train=False, dataset_size=dsize, original_targets=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.token_count = 0
    last_outer = 0
    val_losses = []
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars({"time/outer": 1000*(time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            print("val_loss", val_loss.item())
            val_losses.append(val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        # compute stats
        if args.full_batch:
            data, targets = dataset.data, dataset.targets
        else:
            data, targets = next(stats_iter)

        # Capture Hessian and gradient stats
        autograd_lib.enable_hooks()
        autograd_lib.clear_backprops(model)
        autograd_lib.clear_hess_backprops(model)
        with u.timeit("backprop_g"):
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)
        with u.timeit("backprop_H"):
            autograd_lib.backprop_hess(output, hess_type=loss_type)
        autograd_lib.disable_hooks()   # TODO(y): use remove_hooks

        with u.timeit("compute_grad1"):
            autograd_lib.compute_grad1(model)
        with u.timeit("compute_hess"):
            autograd_lib.compute_hess(model)

        for (i, layer) in enumerate(model.layers):

            # input/output layers are unreasonably expensive if not using Kronecker factoring
            if d[i]>50 or d[i+1]>50:
                print(f'layer {i} is too big ({d[i],d[i+1]}), skipping stats')
                continue

            if args.skip_stats:
                continue

            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops_list[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            u.check_equal(G.reshape(layer.weight.grad1.shape), layer.weight.grad1)

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad)

            s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel()  # proportion of activations that are zero
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g
                s.sigma_l2 = u.sym_l2_norm(sigma)
                s.sigma_erank = torch.trace(sigma)/s.sigma_l2

            lambda_regularizer = args.lmb * torch.eye(d[i + 1]*d[i]).to(gl.device)
            H = layer.weight.hess

            with u.timeit(f"invH-{i}"):
                invH = torch.cholesky_inverse(H+lambda_regularizer)

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            #  u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps ** 2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm() ** 2))

            with u.timeit(f"pinvH-{i}"):
                pinvH = u.pinv(H)

            with u.timeit(f'curv-{i}'):
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre', pinvH)  # save Newton preconditioner
                s.step_openai = s.grad_fro**2 / s.grad_curv if s.grad_curv else 999
                s.step_max = 2 / s.H_l2
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm()  # frobenius norm of Newton update
                s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)   # replace with "quadratic_form"
                s.regret_gradient = loss_direction(g, s.step_openai)

            with u.timeit(f'rho-{i}'):
                p_sigma = u.lyapunov_svd(H, sigma)
                if u.has_nan(p_sigma) and args.compute_rho:  # use expensive method
                    print('using expensive method')
                    # import pdb; pdb.set_trace()
                    H0, sigma0 = u.to_numpys(H, sigma)
                    p_sigma = scipy.linalg.solve_lyapunov(H0, sigma0)
                    p_sigma = torch.tensor(p_sigma).to(gl.device)

                if u.has_nan(p_sigma):
                    # import pdb; pdb.set_trace()
                    s.psigma_erank = H.shape[0]
                    s.rho = 1
                else:
                    s.psigma_erank = u.sym_erank(p_sigma)
                    s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                s.diversity = torch.norm(G, "fro") ** 2 / torch.norm(g) ** 2 / n

                # Faster approaches for noise variance computation
                # s.noise_variance = torch.trace(H.inverse() @ sigma)
                # try:
                #     # this fails with singular sigma
                #     s.noise_variance = torch.trace(torch.solve(sigma, H)[0])
                #     # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0])
                #     pass
                # except RuntimeError as _:
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        with u.timeit('inner'):
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                #            u.log_scalar(train_loss=loss.item())

                if args.method != 'newton':
                    optimizer.step()
                    if args.weight_decay:
                        for group in optimizer.param_groups:
                            for param in group['params']:
                                param.data.mul_(1-args.weight_decay)
                else:
                    for (layer_idx, layer) in enumerate(model.layers):
                        param: torch.nn.Parameter = layer.weight
                        param_data: torch.Tensor = param.data
                        param_data.copy_(param_data - 0.1 * param.grad)
                        if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                            continue
                        u.nan_check(layer.weight.pre)
                        u.nan_check(param.grad.flatten())
                        u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                        param_new_flat = u.v2r(param_data.flatten()) - u.v2r(param.grad.flatten()) @ layer.weight.pre
                        u.nan_check(param_new_flat)
                        param_data.copy_(param_new_flat.reshape(param_data.shape))

                gl.token_count += data.shape[0]

    gl.event_writer.close()
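# --- Editor's sketch (not part of the original example) ----------------------
# Two of the per-layer statistics logged above, in isolation:
#   H_erank  = trace(H) / ||H||_2   (effective rank of the curvature)
#   step_max = 2 / ||H||_2          (largest stable gradient-descent step on the quadratic)
# Tiny demonstration that step_max really is the stability threshold:
import torch

H = torch.diag(torch.tensor([4.0, 1.0, 1.0]))
l2 = torch.linalg.eigvalsh(H).max()      # largest eigenvalue, here 4
erank = torch.trace(H) / l2              # 6 / 4 = 1.5
step_max = 2 / l2                        # 0.5

for lr in (0.45, 0.55):                  # just below / just above step_max
    x = torch.ones(3)
    for _ in range(200):
        x = x - lr * (H @ x)             # gradient step on f(x) = x^T H x / 2
    print(f"lr={lr}: |x| = {x.norm().item():.3e}")   # 0.45 shrinks, 0.55 blows up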
Example #29
def test_main():

    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')

    parser.add_argument('--wandb',
                        type=int,
                        default=1,
                        help='log to weights and biases')
    parser.add_argument('--autograd_check',
                        type=int,
                        default=0,
                        help='autograd correctness checks')
    parser.add_argument('--logdir',
                        type=str,
                        default='/temp/runs/curv_train_tiny/run')

    parser.add_argument('--train_batch_size', type=int, default=100)
    parser.add_argument('--stats_batch_size', type=int, default=60000)
    parser.add_argument('--dataset_size', type=int, default=60000)
    parser.add_argument('--train_steps',
                        type=int,
                        default=100,
                        help="this many train steps between stat collection")
    parser.add_argument('--stats_steps',
                        type=int,
                        default=1000000,
                        help="total number of curvature stats collections")
    parser.add_argument('--nonlin',
                        type=int,
                        default=1,
                        help="whether to add ReLU nonlinearity between layers")
    parser.add_argument('--method',
                        type=str,
                        choices=['gradient', 'newton'],
                        default='gradient',
                        help="descent method, newton or gradient")
    parser.add_argument('--layer',
                        type=int,
                        default=-1,
                        help="restrict updates to this layer")
    parser.add_argument('--data_width', type=int, default=28)
    parser.add_argument('--targets_width', type=int, default=28)
    parser.add_argument('--lmb', type=float, default=1e-3)
    parser.add_argument(
        '--hess_samples',
        type=int,
        default=1,
        help='number of samples when sub-sampling outputs, 0 for exact hessian'
    )
    parser.add_argument('--hess_kfac',
                        type=int,
                        default=0,
                        help='whether to use KFAC approximation for hessian')
    parser.add_argument('--compute_rho',
                        type=int,
                        default=1,
                        help='use expensive method to compute rho')
    parser.add_argument('--skip_stats',
                        type=int,
                        default=0,
                        help='skip all stats collection')
    parser.add_argument('--full_batch',
                        type=int,
                        default=0,
                        help='do stats on the whole dataset')
    parser.add_argument('--weight_decay', type=float, default=1e-4)

    #args = parser.parse_args()
    args = AttrDict()
    args.lmb = 1e-3
    args.compute_rho = 1
    args.weight_decay = 1e-4
    args.method = 'gradient'
    args.logdir = '/tmp'
    args.data_width = 2
    args.targets_width = 2
    args.train_batch_size = 10
    args.full_batch = False
    args.skip_stats = False
    args.autograd_check = False

    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    #gl.event_writer = SummaryWriter(logdir)
    gl.event_writer = u.NoOp()
    # print(f"Logging to {run_name}")

    # small values for debugging
    # loss_type = 'LeastSquares'
    loss_type = 'CrossEntropy'

    args.wandb = 0
    args.stats_steps = 10
    args.train_steps = 10
    args.stats_batch_size = 10
    args.data_width = 2
    args.targets_width = 2
    args.nonlin = False
    d1 = args.data_width**2
    d2 = 2
    d3 = args.targets_width**2

    d1 = args.data_width**2
    assert args.data_width == args.targets_width
    o = d1
    n = args.stats_batch_size
    d = [d1, 30, 30, 30, 20, 30, 30, 30, d1]

    if loss_type == 'CrossEntropy':
        d3 = 10
    o = d3
    n = args.stats_batch_size
    d = [d1, d2, d3]
    dsize = max(args.train_batch_size, args.stats_batch_size) + 1

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin)
    model = model.to(gl.device)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    # optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.03)  # make 10x smaller for least-squares loss
    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=dsize,
                          original_targets=True)

    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=False,
        drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_iter = None
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.stats_batch_size,
            shuffle=False,
            drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width,
                               targets_width=args.targets_width,
                               train=False,
                               dataset_size=dsize,
                               original_targets=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.train_batch_size,
                                              shuffle=False,
                                              drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.token_count = 0
    last_outer = 0
    val_losses = []
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars(
                {"time/outer": 1000 * (time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            # print("val_loss", val_loss.item())
            val_losses.append(val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        # compute stats
        if args.full_batch:
            data, targets = dataset.data, dataset.targets
        else:
            data, targets = next(stats_iter)

        # Capture Hessian and gradient stats
        autograd_lib.enable_hooks()
        autograd_lib.clear_backprops(model)
        autograd_lib.clear_hess_backprops(model)
        with u.timeit("backprop_g"):
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)
        with u.timeit("backprop_H"):
            autograd_lib.backprop_hess(output, hess_type=loss_type)
        autograd_lib.disable_hooks()  # TODO(y): use remove_hooks

        with u.timeit("compute_grad1"):
            autograd_lib.compute_grad1(model)
        with u.timeit("compute_hess"):
            autograd_lib.compute_hess(model)

        for (i, layer) in enumerate(model.layers):

            # input/output layers are unreasonably expensive if not using Kronecker factoring
            if d[i] > 50 or d[i + 1] > 50:
                print(
                    f'layer {i} is too big ({d[i], d[i + 1]}), skipping stats')
                continue

            if args.skip_stats:
                continue

            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops_list[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            u.check_equal(G.reshape(layer.weight.grad1.shape),
                          layer.weight.grad1)

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]),
                              layer.weight.saved_grad)

            # proportion of activations that are zero
            s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel()
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g
                s.sigma_l2 = u.sym_l2_norm(sigma)
                s.sigma_erank = torch.trace(sigma) / s.sigma_l2

            lambda_regularizer = args.lmb * torch.eye(d[i + 1] * d[i]).to(
                gl.device)
            H = layer.weight.hess

            with u.timeit(f"invH-{i}"):
                invH = torch.cholesky_inverse(H + lambda_regularizer)

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1],
                                                d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            #  u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) -
                                          0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() /
                                          (dd.flatten().norm()**2))

            with u.timeit(f"pinvH-{i}"):
                pinvH = H.pinverse()

            with u.timeit(f'curv-{i}'):
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre', pinvH)  # save Newton preconditioner
                s.step_openai = s.grad_fro**2 / s.grad_curv if s.grad_curv else 999
                s.step_max = 2 / s.H_l2
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm()  # frobenius norm of Newton update
                s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)  # replace with "quadratic_form"
                s.regret_gradient = loss_direction(g, s.step_openai)

            with u.timeit(f'rho-{i}'):
                p_sigma = u.lyapunov_spectral(H, sigma)

                discrepancy = torch.max(abs(p_sigma - p_sigma.t()) / p_sigma)

                s.psigma_erank = u.sym_erank(p_sigma)
                s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                s.diversity = torch.norm(G, "fro")**2 / torch.norm(g)**2 / n

                # Faster approaches for noise variance computation
                # s.noise_variance = torch.trace(H.inverse() @ sigma)
                # try:
                #     # this fails with singular sigma
                #     s.noise_variance = torch.trace(torch.solve(sigma, H)[0])
                #     # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0])
                #     pass
                # except RuntimeError as _:
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        with u.timeit('inner'):
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                #            u.log_scalar(train_loss=loss.item())

                if args.method != 'newton':
                    optimizer.step()
                    if args.weight_decay:
                        for group in optimizer.param_groups:
                            for param in group['params']:
                                param.data.mul_(1 - args.weight_decay)
                else:
                    for (layer_idx, layer) in enumerate(model.layers):
                        param: torch.nn.Parameter = layer.weight
                        param_data: torch.Tensor = param.data
                        param_data.copy_(param_data - 0.1 * param.grad)
                        if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                            continue
                        u.nan_check(layer.weight.pre)
                        u.nan_check(param.grad.flatten())
                        u.nan_check(
                            u.v2r(param.grad.flatten()) @ layer.weight.pre)
                        param_new_flat = u.v2r(param_data.flatten()) - u.v2r(
                            param.grad.flatten()) @ layer.weight.pre
                        u.nan_check(param_new_flat)
                        param_data.copy_(
                            param_new_flat.reshape(param_data.shape))

                gl.token_count += data.shape[0]

    gl.event_writer.close()

    assert val_losses[0] > 2.4  # 2.4828238487243652
    assert val_losses[-1] < 2.25  # 2.20609712600708
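# --- Editor's sketch (not part of the original example) ----------------------
# The rho statistic above needs P solving the continuous Lyapunov equation
# H P + P H = Sigma, the equation the scipy.linalg.solve_lyapunov fallback in
# main() handles.  For symmetric positive-definite H one spectral solution is
#   P = V ((V^T Sigma V)_ij / (lam_i + lam_j)) V^T,   with H = V diag(lam) V^T.
# Stand-alone check; whether u.lyapunov_spectral works exactly this way is an
# assumption.
import torch

torch.manual_seed(0)
d = 4
M = torch.randn(d, d); H = M @ M.t() + torch.eye(d)       # SPD
S = torch.randn(d, d); Sigma = S @ S.t()                   # PSD

lam, V = torch.linalg.eigh(H)
P = V @ ((V.t() @ Sigma @ V) / (lam[:, None] + lam[None, :])) @ V.t()
assert torch.allclose(H @ P + P @ H, Sigma, atol=1e-4)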
Example #30
def _test_kron_conv_golden():
    """Hardcoded error values to detect unexpected numeric changes."""
    u.seed_random(1)

    n, Xh, Xw = 2, 8, 8
    Kh, Kw = 2, 2
    dd = [3, 3, 3, 3]
    o = dd[-1]

    model: u.SimpleModel = u.PooledConvolutional2(dd, kernel_size=(Kh, Kw), nonlin=False, bias=True)
    data = torch.randn((n, dd[0], Xh, Xw))

    # print(model)
    # print(data)

    loss_type = 'CrossEntropy'    #  loss_type = 'LeastSquares'
    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:    # CrossEntropy
        loss_fn = nn.CrossEntropyLoss()

    sample_output = model(data)

    if loss_type.endswith('LeastSquares'):
        targets = torch.randn(sample_output.shape)
    elif loss_type == 'CrossEntropy':
        targets = torch.LongTensor(n).random_(0, o)

    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron')
    autograd_lib.compute_hess(model, method='mean_kron', attr_name='hess_mean_kron')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    errors1 = []
    errors2 = []
    for i in range(len(model.layers)):
        layer = model.layers[i]

        # direct Hessian computation
        H = layer.weight.hess
        H_bias = layer.bias.hess

        # factored Hessian computation
        Hk = layer.weight.hess_kron
        Hk_bias = layer.bias.hess_kron
        Hk = Hk.expand()
        Hk_bias = Hk_bias.expand()

        Hk2 = layer.weight.hess_mean_kron
        Hk2_bias = layer.bias.hess_mean_kron
        Hk2 = Hk2.expand()
        Hk2_bias = Hk2_bias.expand()

        # autograd Hessian computation
        loss = loss_fn(output, targets)
        Ha = u.hessian(loss, layer.weight).reshape(H.shape)
        Ha_bias = u.hessian(loss, layer.bias)

        # compare direct against autograd
        Ha = Ha.reshape(H.shape)
        # rel_error = torch.max((H-Ha)/Ha)

        u.check_close(H, Ha, rtol=1e-5, atol=1e-7)
        u.check_close(Ha_bias, H_bias, rtol=1e-5, atol=1e-7)

        errors1.extend([u.symsqrt_dist(H, Hk), u.symsqrt_dist(H_bias, Hk_bias)])
        errors2.extend([u.symsqrt_dist(H, Hk2), u.symsqrt_dist(H_bias, Hk2_bias)])

    errors1 = torch.tensor(errors1)
    errors2 = torch.tensor(errors2)
    golden_errors1 = torch.tensor([0.09458080679178238, 0.0, 0.13416489958763123, 0.0, 0.0003909761435352266, 0.0])
    golden_errors2 = torch.tensor([0.0945773795247078, 0.0, 0.13418318331241608, 0.0, 4.478318658129865e-07, 0.0])

    u.check_close(golden_errors1, errors1)
    u.check_close(golden_errors2, errors2)