示例#1
0
def benchmark(method):
    """Time one linear-algebra routine on the module-level matrix ``H``.

    The routine is run once, its first result element is rendered to a string
    and written to ``/dev/null`` (presumably to force the result to be fully
    materialised before the clock is read), and the elapsed time in
    milliseconds is printed next to the method name.

    Args:
        method: one of ``'svd'``, ``'inv'``, ``'pinv'``, ``'pinverse'``,
            ``'eig'`` or ``'solve'``.

    Raises:
        AssertionError: if ``method`` is not one of the names above.

    Relies on names defined elsewhere in the file: ``H`` (and ``S`` for
    ``'solve'``), ``u``, ``torch`` and ``time``.
    """
    # Thunks keep evaluation lazy: only the requested routine touches its
    # globals, so e.g. ``S`` is needed only for 'solve'.
    routines = {
        'svd': lambda: torch.svd(H),
        'inv': lambda: torch.inverse(H),
        'pinv': lambda: u.pinv(H),
        'pinverse': lambda: torch.pinverse(H),
        'eig': lambda: torch.symeig(H, eigenvectors=True),
        'solve': lambda: torch.solve(S, H),
    }
    routine = routines.get(method)
    if routine is None:
        # Raised explicitly (rather than ``assert False``) so the check still
        # fires when Python runs with -O; the exception type is unchanged.
        raise AssertionError(f"unknown benchmark method: {method!r}")

    # perf_counter is monotonic, which is what interval timing needs.
    start_time = time.perf_counter()

    for _ in range(1):
        _result = routine()
        # Context manager closes the sink; the original leaked one handle per call.
        with open('/dev/null', 'w') as sink:
            sink.write(str(_result[0]))
        new_time = time.perf_counter()
        elapsed_time = 1000 * (new_time - start_time)
        print(f"{elapsed_time:8.2f}   {method}")
        start_time = new_time
示例#2
0
def atest_pinv():
    """Check Kronecker-product norms and pseudo-inverse against a reference.

    Builds ``u.Kron`` from two fixed 3x3 matrices, verifies its Frobenius norm
    two ways, then compares four pseudo-inverse implementations of the
    expanded matrix against a hard-coded 9x9 reference.

    NOTE(review): the name lacks the ``test_`` prefix, so test runners that
    collect by that prefix will presumably skip it -- confirm whether that is
    intentional.
    """
    left = torch.tensor([[2., 7, 9], [1, 9, 8], [2, 7, 5]])
    right = torch.tensor([[6., 6, 1], [10, 7, 7], [7, 10, 10]])
    kron = u.Kron(left, right)

    # Frobenius norm of the Kronecker product equals the product of the
    # factors' Frobenius norms.
    factor_norm_product = left.flatten().norm() * right.flatten().norm()
    u.check_close(factor_norm_product, kron.frobenius_norm())

    # Same norm against a closed-form value.
    u.check_close(kron.frobenius_norm(), 4 * math.sqrt(11635.))

    # Reference pseudo-inverse of the expanded 9x9 matrix, one row per line.
    expected_rows = [
        [0, 5 / 102, -(7 / 204), 0, -(70 / 561), 49 / 561, 0, 125 / 1122, -(175 / 2244)],
        [1 / 20, -(53 / 1020), 8 / 255, -(7 / 55), 371 / 2805, -(224 / 2805), 5 / 44, -(265 / 2244), 40 / 561],
        [-(1 / 20), 3 / 170, 3 / 170, 7 / 55, -(42 / 935), -(42 / 935), -(5 / 44), 15 / 374, 15 / 374],
        [0, -(5 / 102), 7 / 204, 0, 20 / 561, -(14 / 561), 0, 35 / 1122, -(49 / 2244)],
        [-(1 / 20), 53 / 1020, -(8 / 255), 2 / 55, -(106 / 2805), 64 / 2805, 7 / 220, -(371 / 11220), 56 / 2805],
        [1 / 20, -(3 / 170), -(3 / 170), -(2 / 55), 12 / 935, 12 / 935, -(7 / 220), 21 / 1870, 21 / 1870],
        [0, 5 / 102, -(7 / 204), 0, 0, 0, 0, -(5 / 102), 7 / 204],
        [1 / 20, -(53 / 1020), 8 / 255, 0, 0, 0, -(1 / 20), 53 / 1020, -(8 / 255)],
        [-(1 / 20), 3 / 170, 3 / 170, 0, 0, 0, 1 / 20, -(3 / 170), -(3 / 170)],
    ]

    dense = kron.expand_vec()
    dense_np = u.to_numpy(dense)
    expected = torch.tensor(expected_rows)

    # Defining property of a generalised inverse: C @ C+ @ C == C.
    u.check_close(dense @ expected @ dense, dense)

    # Each candidate is computed lazily, right before its own check, so the
    # checks run in the same order as separate statements would.
    tolerance = dict(rtol=1e-5, atol=1e-6)
    candidates = (
        lambda: linalg.pinv(dense_np),
        lambda: torch.pinverse(dense),
        lambda: u.pinv(dense),
        lambda: dense.pinv(),
    )
    for compute_pinv in candidates:
        u.check_close(compute_pinv(), expected, **tolerance)
示例#3
0
def main():
    """Run the tiny-network curvature experiment with a cross-entropy loss.

    Overrides several ``args`` fields with small debugging values, builds a
    ``u.SimpleFullyConnected2`` model with layer widths ``[d1, d2, d3]`` on
    ``u.TinyMNIST``, and then, for ``args.stats_steps`` outer steps, alternates
    between

    * a stats phase that logs the validation loss and, per layer, gradient,
      Hessian, noise-covariance and step-size statistics, and
    * ``args.train_steps`` optimisation steps (Adam, or a hand-written
      gradient/Newton update when ``args.method == 'newton'``).

    Reads the module-level ``args`` and ``gl`` objects.  Scalars are logged
    through the ``u.log_scalar(s)`` helpers and ``gl.event_writer`` is closed
    at the end.  Returns ``None``.
    """

    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")

    # Checked against the caller-supplied values, before the overrides below.
    assert args.data_width == args.targets_width

    # small values for debugging
    # loss_type = 'LeastSquares'
    loss_type = 'CrossEntropy'

    # These overrides replace whatever was passed in; note wandb is forced off.
    args.wandb = 0
    args.stats_steps = 10
    args.train_steps = 10
    args.stats_batch_size = 10
    args.data_width = 2
    args.targets_width = 2
    args.nonlin = False
    d1 = args.data_width ** 2
    d2 = 2
    d3 = args.targets_width ** 2

    if loss_type == 'CrossEntropy':
        d3 = 10   # output width is fixed at 10 for the cross-entropy setup
    o = d3
    n = args.stats_batch_size
    d = [d1, d2, d3]
    dsize = max(args.train_batch_size, args.stats_batch_size)+1

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin)
    model = model.to(gl.device)

    # wandb setup is best-effort: a failure is reported and the run continues.
    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    #optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.03)  # make 10x smaller for least-squares loss
    dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, dataset_size=dsize, original_targets=True)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    # In full-batch mode the stats are computed on the whole dataset instead.
    stats_iter = None
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, train=False, dataset_size=dsize, original_targets=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()
    else:
        # Fail here rather than with a NameError on first use of loss_fn.
        raise ValueError(f"unsupported loss_type: {loss_type!r}")

    autograd_lib.add_hooks(model)
    gl.token_count = 0
    last_outer = 0
    val_losses = []
    for step in range(args.stats_steps):
        # Log wall time of the previous outer iteration, in milliseconds.
        if last_outer:
            u.log_scalars({"time/outer": 1000*(time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            print("val_loss", val_loss.item())
            val_losses.append(val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        # compute stats
        if args.full_batch:
            data, targets = dataset.data, dataset.targets
        else:
            data, targets = next(stats_iter)

        # Capture Hessian and gradient stats
        autograd_lib.enable_hooks()
        autograd_lib.clear_backprops(model)
        autograd_lib.clear_hess_backprops(model)
        with u.timeit("backprop_g"):
            output = model(data)
            loss = loss_fn(output, targets)
            # retain_graph: the same graph is reused by backprop_hess below
            loss.backward(retain_graph=True)
        with u.timeit("backprop_H"):
            autograd_lib.backprop_hess(output, hess_type=loss_type)
        autograd_lib.disable_hooks()   # TODO(y): use remove_hooks

        with u.timeit("compute_grad1"):
            autograd_lib.compute_grad1(model)
        with u.timeit("compute_hess"):
            autograd_lib.compute_hess(model)

        for (i, layer) in enumerate(model.layers):

            # input/output layers are unreasonably expensive if not using Kronecker factoring
            if d[i]>50 or d[i+1]>50:
                print(f'layer {i} is too big ({d[i],d[i+1]}), skipping stats')
                continue

            if args.skip_stats:
                continue

            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops_list[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            # The hand-built per-example gradients must match autograd_lib's grad1.
            u.check_equal(G.reshape(layer.weight.grad1.shape), layer.weight.grad1)

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad)

            # Proportion of outputs that are <= 0.  Cast to float before
            # averaging: dividing the integer count by an int is floor
            # division on older PyTorch releases.
            s.sparsity = torch.mean((layer.output <= 0).float())
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g   # centred second moment of per-example gradients
                s.sigma_l2 = u.sym_l2_norm(sigma)
                s.sigma_erank = torch.trace(sigma)/s.sigma_l2

            lambda_regularizer = args.lmb * torch.eye(d[i + 1]*d[i]).to(gl.device)
            H = layer.weight.hess

            with u.timeit(f"invH-{i}"):
                invH = torch.cholesky_inverse(H+lambda_regularizer)

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                # Cross-check the hook-based Hessian against u.hessian.
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            #  u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps ** 2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm() ** 2))

            with u.timeit(f"pinvH-{i}"):
                pinvH = u.pinv(H)

            with u.timeit(f'curv-{i}'):
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre', pinvH)  # save Newton preconditioner
                # 999 is a placeholder used when the curvature along g is zero
                s.step_openai = s.grad_fro**2 / s.grad_curv if s.grad_curv else 999
                s.step_max = 2 / s.H_l2
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm()  # frobenius norm of Newton update
                s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)   # replace with "quadratic_form"
                s.regret_gradient = loss_direction(g, s.step_openai)

            with u.timeit(f'rho-{i}'):
                p_sigma = u.lyapunov_svd(H, sigma)
                if u.has_nan(p_sigma) and args.compute_rho:  # use expensive method
                    # The debugger breakpoint that used to sit here was removed
                    # so that unattended runs are not blocked on stdin.
                    print('using expensive method')
                    H0, sigma0 = u.to_numpys(H, sigma)
                    p_sigma = scipy.linalg.solve_lyapunov(H0, sigma0)
                    p_sigma = torch.tensor(p_sigma).to(gl.device)

                if u.has_nan(p_sigma):
                    # No usable Lyapunov solution: fall back to rho = 1.
                    s.psigma_erank = H.shape[0]
                    s.rho = 1
                else:
                    s.psigma_erank = u.sym_erank(p_sigma)
                    s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                s.diversity = torch.norm(G, "fro") ** 2 / torch.norm(g) ** 2 / n

                # Faster approaches for noise variance computation
                # s.noise_variance = torch.trace(H.inverse() @ sigma)
                # try:
                #     # this fails with singular sigma
                #     s.noise_variance = torch.trace(torch.solve(sigma, H)[0])
                #     # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0])
                #     pass
                # except RuntimeError as _:
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        with u.timeit('inner'):
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                #            u.log_scalar(train_loss=loss.item())

                if args.method != 'newton':
                    optimizer.step()
                    if args.weight_decay:
                        # Decay applied directly to the weights after the optimizer step.
                        for group in optimizer.param_groups:
                            for param in group['params']:
                                param.data.mul_(1-args.weight_decay)
                else:
                    for (layer_idx, layer) in enumerate(model.layers):
                        param: torch.nn.Parameter = layer.weight
                        param_data: torch.Tensor = param.data
                        # Every layer first takes a plain gradient step with lr 0.1 ...
                        param_data.copy_(param_data - 0.1 * param.grad)
                        if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                            continue
                        # ... and layer 1 additionally takes a step preconditioned
                        # by the pseudo-inverse Hessian saved in the stats phase.
                        u.nan_check(layer.weight.pre)
                        u.nan_check(param.grad.flatten())
                        u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                        param_new_flat = u.v2r(param_data.flatten()) - u.v2r(param.grad.flatten()) @ layer.weight.pre
                        u.nan_check(param_new_flat)
                        param_data.copy_(param_new_flat.reshape(param_data.shape))

                gl.token_count += data.shape[0]

    gl.event_writer.close()
示例#4
0
def main():
    """Run the tiny-network curvature experiment with a least-squares loss.

    Picks the first unused ``{args.logdir}NN`` directory, builds a
    ``u.SimpleFullyConnected`` model with layer widths ``[d1, 10, d3]`` on
    ``u.TinyMNIST``, and registers forward/backward hooks that record each
    layer's input activations and backpropagated values.  For
    ``args.stats_steps`` outer steps it then

    * backpropagates the loss once, and each output coordinate once, to build
      per-layer gradient and Hessian (Jacobian-product) statistics, and
    * runs ``args.train_steps`` optimisation steps (SGD with momentum, or a
      hand-written gradient/Newton update when ``args.method == 'newton'``).

    Reads the module-level ``args`` and ``gl`` objects; ``gl.event_writer`` is
    closed at the end.  Returns ``None``.
    """
    # Find the first two-digit suffix for which no log directory exists yet.
    attempt_count = 0
    while os.path.exists(f"{args.logdir}{attempt_count:02d}"):
        attempt_count += 1
    logdir = f"{args.logdir}{attempt_count:02d}"

    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")
    u.seed_random(1)

    # wandb setup is best-effort: a failure is reported and the run continues.
    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method

    except Exception as e:
        print(f"wandb crash with {e}")

    #    data_width = 4
    #    targets_width = 2

    d1 = args.data_width**2
    d2 = 10
    d3 = args.targets_width**2
    o = d3   # number of output coordinates
    n = args.stats_batch_size
    d = [d1, d2, d3]
    model = u.SimpleFullyConnected(d, nonlin=args.nonlin)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=False,
        drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.stats_batch_size,
        shuffle=False,
        drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)

    # The hooks below read ``skip_forward_hooks`` / ``skip_backward_hooks``
    # through the closure; both are assigned inside the stats loop before the
    # first forward pass, so they are bound by the time any hook fires.

    def capture_activations(module, input, _output):
        """Forward hook: store the layer's (detached) input as ``module.activations``."""
        if skip_forward_hooks:
            return
        assert gl.backward_idx == 0  # no need to forward-prop on Hessian computation
        assert not hasattr(
            module, 'activations'
        ), "Seeing activations from previous forward, call util.zero_grad to clear"
        assert len(input) == 1, "this works for single input layers only"
        setattr(module, "activations", input[0].detach())

    def capture_backprops(module: nn.Module, _input, output):
        """Backward hook: append the hook's third argument to ``module.backprops``.

        Slot 0 holds the value from the loss backward pass; slot ``k + 1``
        holds the value from backpropagating output coordinate ``k``.
        """
        if skip_backward_hooks:
            return
        assert len(output) == 1, "this works for single variable layers only"
        if gl.backward_idx == 0:
            assert not hasattr(
                module, 'backprops'
            ), "Seeing results of previous autograd, call util.zero_grad to clear"
            setattr(module, 'backprops', [])
        # backward passes must arrive in order, one list slot per pass
        assert gl.backward_idx == len(module.backprops)
        module.backprops.append(output[0])

    def save_grad(param: nn.Parameter) -> Callable[[torch.Tensor], None]:
        """Hook to save gradient into 'param.saved_grad', so it can be accessed after model.zero_grad(). Only stores gradient
        if the value has not been set, call util.zero_grad to clear it."""
        def save_grad_fn(grad):
            if not hasattr(param, 'saved_grad'):
                setattr(param, 'saved_grad', grad)

        return save_grad_fn

    for layer in model.layers:
        layer.register_forward_hook(capture_activations)
        layer.register_backward_hook(capture_backprops)
        layer.weight.register_hook(save_grad(layer.weight))

    def loss_fn(data, targets):
        """Half the summed squared error, averaged over the batch."""
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == len(targets)
        return torch.sum(err * err) / 2 / len(data)

    gl.token_count = 0
    for step in range(args.stats_steps):
        data, targets = next(stats_iter)
        skip_forward_hooks = False
        skip_backward_hooks = False

        # get gradient values
        gl.backward_idx = 0
        u.zero_grad(model)
        output = model(data)
        loss = loss_fn(output, targets)
        loss.backward(retain_graph=True)

        print("loss", loss.item())

        # get Hessian values
        skip_forward_hooks = True   # activations were already captured above
        id_mat = torch.eye(o)

        u.log_scalars({'loss': loss.item()})

        # o = 0
        for out_idx in range(o):
            model.zero_grad()
            # backprop to get section of batch output jacobian for output at position out_idx
            output = model(
                data
            )  # opt: using autograd.grad means I don't have to zero_grad
            ei = id_mat[out_idx]
            bval = torch.stack([ei] * n)   # one-hot row repeated for every example
            gl.backward_idx = out_idx + 1
            output.backward(bval)
        skip_backward_hooks = True  # keep the training steps below from recording

        for (i, layer) in enumerate(model.layers):
            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops[0] * n
            assert B_t.shape == (n, d[i + 1])

            G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]),
                              layer.weight.saved_grad)

            # empirical Fisher
            efisher = G.t() @ G / n
            sigma = efisher - g.t() @ g   # centred second moment of per-example gradients
            # u.dump(sigma, f'/tmp/sigmas/{step}-{i}')
            s.sigma_l2 = u.l2_norm(sigma)

            #############################
            # Hessian stats
            #############################
            Bh_t = [layer.backprops[out_idx + 1] for out_idx in range(o)]
            Amat_t = torch.cat([A_t] * o, dim=0)
            Bmat_t = torch.cat(Bh_t, dim=0)

            assert Amat_t.shape == (n * o, d[i])
            assert Bmat_t.shape == (n * o, d[i + 1])

            Jb = u.khatri_rao_t(Bmat_t,
                                Amat_t)  # batch Jacobian, in row-vec format
            H = Jb.t() @ Jb / n
            # Computed once and reused below; the original recomputed
            # u.pinv(H) three more times per layer.
            pinvH = u.pinv(H)

            s.hess_l2 = u.l2_norm(H)
            s.invhess_l2 = u.l2_norm(pinvH)

            s.hess_fro = H.flatten().norm()
            s.invhess_fro = pinvH.flatten().norm()

            s.jacobian_l2 = u.l2_norm(Jb)
            s.grad_fro = g.flatten().norm()
            s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                # Cross-check the Jacobian-based Hessian against u.hessian.
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1],
                                                d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            #  u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) -
                                          0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() /
                                          dd.flatten().norm()**2)

            s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)
            s.grad_curv = curv_direction(g)
            ndir = g @ pinvH  # newton direction
            s.newton_curv = curv_direction(ndir)
            setattr(layer.weight, 'pre', pinvH)  # save Newton preconditioner
            # 999 is a placeholder used when the curvature along g is zero
            s.step_openai = 1 / s.grad_curv if s.grad_curv else 999

            s.newton_fro = ndir.flatten().norm(
            )  # frobenius norm of Newton update
            s.regret_gradient = loss_direction(g, s.step_openai)

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        for i in range(args.train_steps):
            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()

            u.log_scalar(train_loss=loss.item())

            if args.method != 'newton':
                optimizer.step()
            else:
                for (layer_idx, layer) in enumerate(model.layers):
                    param: torch.nn.Parameter = layer.weight
                    param_data: torch.Tensor = param.data
                    # Every layer first takes a plain gradient step with lr 0.1 ...
                    param_data.copy_(param_data - 0.1 * param.grad)
                    if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                        continue
                    # ... and layer 1 additionally takes a step preconditioned
                    # by the pseudo-inverse Hessian saved in the stats phase.
                    u.nan_check(layer.weight.pre)
                    u.nan_check(param.grad.flatten())
                    u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                    param_new_flat = u.v2r(param_data.flatten()) - u.v2r(
                        param.grad.flatten()) @ layer.weight.pre
                    u.nan_check(param_new_flat)
                    param_data.copy_(param_new_flat.reshape(param_data.shape))

            gl.token_count += data.shape[0]

    gl.event_writer.close()
示例#5
0
def main():
    """Train a small MLP on TinyMNIST and log per-parameter curvature stats.

    Builds a ``u.SimpleFullyConnected2`` model with layer widths
    ``[d1, 60, 60, 60, 10]`` and a cross-entropy loss.  For
    ``args.stats_steps`` outer steps it

    * logs validation loss/accuracy,
    * unless ``args.skip_stats`` is set, computes gradient, Hessian and
      noise-covariance statistics plus spectra for the weight and bias of
      every layer not marked expensive, always on the same fixed stats batch,
    * runs ``args.train_steps`` SGD-with-momentum steps with optional
      multiplicative weight decay.

    Reads the module-level ``args`` and ``gl`` objects and the file-level
    ``validate`` helper.  Returns ``None``.
    """

    u.install_pdb_handler()
    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {logdir}")

    loss_type = 'CrossEntropy'

    d1 = args.data_width ** 2
    # Batch sizes cannot exceed the dataset size.
    args.stats_batch_size = min(args.stats_batch_size, args.dataset_size)
    args.train_batch_size = min(args.train_batch_size, args.dataset_size)
    n = args.stats_batch_size
    o = 10
    d = [d1, 60, 60, 60, o]
    # dataset_size = args.dataset_size

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin, last_layer_linear=True)
    model = model.to(gl.device)
    u.mark_expensive(model.layers[0])    # to stop grad1/hess calculations on this layer
    print(model)

    # wandb setup is best-effort: a failure is reported and the run continues.
    try:
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name, dir='/tmp/wandb.runs')
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    #  optimizer = torch.optim.Adam(model.parameters(), lr=0.03)  # make 10x smaller for least-squares loss
    dataset = u.TinyMNIST(data_width=args.data_width, dataset_size=args.dataset_size, loss_type=loss_type)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)
    # A single stats batch is drawn once and reused for every stats step.
    stats_data, stats_targets = next(stats_iter)

    test_dataset = u.TinyMNIST(data_width=args.data_width, train=False, dataset_size=args.dataset_size, loss_type=loss_type)
    test_batch_size = min(args.dataset_size, 1000)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:   # loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()
    last_outer = 0
    val_losses = []
    for step in range(args.stats_steps):
        # Log wall time of the previous outer iteration, in milliseconds.
        if last_outer:
            u.log_scalars({"time/outer": 1000*(time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            print("val_loss", val_loss.item())
            val_losses.append(val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        with u.timeit("validate"):
            if loss_type == 'CrossEntropy':
                val_accuracy, val_loss = validate(model, test_loader, f'test (stats_step {step})')
                # train_accuracy, train_loss = validate(model, train_loader, f'train (stats_step {step})')

                metrics = {'stats_step': step, 'val_accuracy': val_accuracy, 'val_loss': val_loss}
                u.log_scalars(metrics)

        data, targets = stats_data, stats_targets

        if not args.skip_stats:
            # Capture Hessian and gradient stats
            autograd_lib.enable_hooks()
            autograd_lib.clear_backprops(model)
            autograd_lib.clear_hess_backprops(model)
            with u.timeit("backprop_g"):
                output = model(data)
                loss = loss_fn(output, targets)
                # retain_graph: the same graph is reused by backprop_hess below
                loss.backward(retain_graph=True)
            with u.timeit("backprop_H"):
                autograd_lib.backprop_hess(output, hess_type=loss_type)
            autograd_lib.disable_hooks()   # TODO(y): use remove_hooks

            with u.timeit("compute_grad1"):
                autograd_lib.compute_grad1(model)
            with u.timeit("compute_hess"):
                autograd_lib.compute_hess(model)

            for (i, layer) in enumerate(model.layers):

                if hasattr(layer, 'expensive'):
                    continue

                # input/output layers are unreasonably expensive if not using Kronecker factoring.
                # The test depends only on the layer, so it runs (and reports)
                # once per layer instead of once per parameter.
                if d[i]*d[i+1] > 8000:
                    print(f'layer {i} is too big ({d[i],d[i+1]}), skipping stats')
                    continue

                for (param_label, param) in (("weight", layer.weight), ("bias", layer.bias)):
                    s = AttrDefault(str, {})  # dictionary-like object for layer stats

                    #############################
                    # Gradient stats
                    #############################
                    A_t = layer.activations
                    B_t = layer.backprops_list[0] * n
                    # Proportion of outputs that are <= 0.  Cast to float
                    # before averaging: dividing the integer count by an int
                    # is floor division on older PyTorch releases.
                    s.sparsity = torch.mean((layer.output <= 0).float())
                    s.mean_activation = torch.mean(A_t)
                    s.mean_backprop = torch.mean(B_t)

                    # empirical Fisher
                    G = param.grad1.reshape((n, -1))   # one row of per-example gradient per sample
                    g = G.mean(dim=0, keepdim=True)

                    u.nan_check(G)
                    with u.timeit(f'sigma-{i}'):
                        efisher = G.t() @ G / n
                        sigma = efisher - g.t() @ g   # centred second moment of per-example gradients
                        # sigma_spectrum =
                        s.sigma_l2 = u.sym_l2_norm(sigma)
                        s.sigma_erank = torch.trace(sigma)/s.sigma_l2

                    H = param.hess
                    lambda_regularizer = args.lmb * torch.eye(H.shape[0]).to(gl.device)
                    u.nan_check(H)

                    with u.timeit(f"invH-{i}"):
                        invH = torch.cholesky_inverse(H+lambda_regularizer)

                    with u.timeit(f"H_l2-{i}"):
                        s.H_l2 = u.sym_l2_norm(H)
                        s.iH_l2 = u.sym_l2_norm(invH)

                    with u.timeit(f"norms-{i}"):
                        s.H_fro = H.flatten().norm()
                        s.iH_fro = invH.flatten().norm()
                        s.grad_fro = g.flatten().norm()
                        s.param_fro = param.data.flatten().norm()

                    def loss_direction(dd: torch.Tensor, eps):
                        """loss improvement if we take step eps in direction dd"""
                        return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps ** 2 * dd @ H @ dd.t())

                    def curv_direction(dd: torch.Tensor):
                        """Curvature in direction dd"""
                        return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm() ** 2))

                    with u.timeit(f"pinvH-{i}"):
                        pinvH = u.pinv(H)

                    with u.timeit(f'curv-{i}'):
                        s.grad_curv = curv_direction(g)  # curvature (eigenvalue) in direction g
                        ndir = g @ pinvH  # newton direction
                        s.newton_curv = curv_direction(ndir)
                        # Save the Newton preconditioner on the parameter it
                        # belongs to.  Writing it to layer.weight on every
                        # iteration let the bias pass overwrite the weight's
                        # preconditioner with the bias's pseudo-inverse.
                        setattr(param, 'pre', pinvH)
                        # 1234567 is a placeholder used when the curvature along g is zero
                        s.step_openai = 1 / s.grad_curv if s.grad_curv else 1234567
                        s.step_div_inf = 2 / s.H_l2         # divergent step size for batch_size=infinity
                        s.step_div_1 = torch.tensor(2) / torch.trace(H)   # divergent step for batch_size=1

                        s.newton_fro = ndir.flatten().norm()  # frobenius norm of Newton update
                        s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)   # replace with "quadratic_form"
                        s.regret_gradient = loss_direction(g, s.step_openai)

                    with u.timeit(f'rho-{i}'):
                        s.rho, s.lyap_erank, lyap_evals = u.truncated_lyapunov_rho(H, sigma)
                        s.step_div_1_adjusted = s.step_div_1/s.rho

                    with u.timeit(f"batch-{i}"):
                        s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                        s.diversity = torch.norm(G, "fro") ** 2 / torch.norm(g) ** 2 / n  # Gradient diversity / n
                        s.noise_variance_pinv = torch.trace(pinvH @ sigma)
                        s.H_erank = torch.trace(H) / s.H_l2
                        s.batch_jain_simple = 1 + s.H_erank
                        s.batch_jain_full = 1 + s.rho * s.H_erank

                    # Tag under which this parameter's scalars and spectra are logged.
                    param_name = f"{layer.name}={param_label}"
                    u.log_scalars(u.nest_stats(f"{param_name}", s))

                    H_evals = u.symeig_pos_evals(H)
                    sigma_evals = u.symeig_pos_evals(sigma)
                    u.log_spectrum(f'{param_name}/hess', H_evals)
                    u.log_spectrum(f'{param_name}/sigma', sigma_evals)
                    u.log_spectrum(f'{param_name}/lyap', lyap_evals)

        # gradient steps
        with u.timeit('inner'):
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                optimizer.step()
                if args.weight_decay:
                    # Decay applied directly to the weights after the optimizer step.
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            param.data.mul_(1-args.weight_decay)

                gl.increment_global_step(data.shape[0])

    gl.event_writer.close()
示例#6
0
    def compute_layer_stats(layer):
        """Compute gradient and curvature statistics for a single layer.

        Two backward passes are run over the saved stats batch (``stats_data`` /
        ``stats_targets`` from the enclosing scope): one through the loss, used
        to form per-example gradients, and one through the raw model output,
        used to build the curvature matrix ``H = J' J / n``.

        A layer marked ``frozen`` is temporarily unfrozen for the computation
        and is always re-frozen before this function exits, including when a
        consistency check or a backward pass raises.

        Args:
            layer: a member of ``model.layers``. Its ``data_input``,
                ``data_output`` and ``grad_output`` attributes are read after
                the passes (presumably populated by hooks registered
                elsewhere -- TODO confirm).

        Returns:
            An ``AttrDefault`` holding the statistics: ``diversity``,
            ``grad_fro``, ``param_fro``, ``a_sparsity``, ``a_magnitude``,
            ``hessian_l2``, ``jacobian_l2``, ``J1_l2``, ``regret_newton``,
            ``grad_curvature``, ``step_openai`` and ``regret_gradient``.

        Raises:
            ValueError: if ``layer`` is not in ``model.layers``.
            AssertionError: if the saved batch size differs from
                ``args.stats_batch_size``.
        """
        refreeze = False
        if getattr(layer, 'frozen', False):
            u.unfreeze(layer)
            refreeze = True

        # try/finally guarantees the layer is restored to its frozen state even
        # if one of the checks or backward passes below raises.
        try:
            s = AttrDefault(str, {})
            n = args.stats_batch_size
            param = u.get_param(layer)
            layer_idx = model.layers.index(layer)
            # TODO: get layer type, include it in name
            assert layer_idx >= 0
            assert stats_data.shape[0] == n

            def backprop_loss():
                """Backprop the loss on the saved stats batch; return (loss, output)."""
                model.zero_grad()
                output = model(stats_data)  # use last saved data batch for backprop
                loss = compute_loss(output, stats_targets)
                loss.backward()
                return loss, output

            def backprop_output():
                """Backprop an all-ones gradient from the model output; return the output."""
                model.zero_grad()
                output = model(stats_data)
                output.backward(gradient=torch.ones_like(output))
                return output

            # per-example gradients, n, d
            backprop_loss()
            At = layer.data_input
            # NOTE(review): the factor n presumably undoes a 1/n from a
            # mean-reduced loss -- confirm against compute_loss.
            Bt = layer.grad_output * n
            G = u.khatri_rao_t(At, Bt)
            g = G.sum(dim=0, keepdim=True) / n
            # the averaged per-example gradient must agree with autograd's param.grad
            u.check_close(g, u.vec(param.grad).t())

            s.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2
            s.grad_fro = g.flatten().norm()
            s.param_fro = param.data.flatten().norm()
            pos_activations = torch.sum(layer.data_output > 0)
            neg_activations = torch.sum(layer.data_output <= 0)
            s.a_sparsity = neg_activations.float() / (
                pos_activations + neg_activations)  # 1 sparsity means all 0's
            activation_size = len(layer.data_output.flatten())
            s.a_magnitude = torch.sum(layer.data_output) / activation_size

            # second pass: layer.grad_output now holds backprops of the raw output
            backprop_output()
            B2t = layer.grad_output
            J = u.khatri_rao_t(At, B2t)  # batch output Jacobian
            H = J.t() @ J / n

            s.hessian_l2 = u.l2_norm(H)
            s.jacobian_l2 = u.l2_norm(J)
            J1 = J.sum(dim=0) / n  # single output Jacobian
            s.J1_l2 = J1.norm()

            # newton decrement
            def loss_direction(direction, eps):
                """Loss improvement under the quadratic model for step eps along direction."""
                return u.to_python_scalar(eps * (direction @ g.t()) - 0.5 *
                                          eps**2 * direction @ H @ direction.t())

            s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)

            # TODO: gradient diversity is stuck at 1
            # TODO: newton/gradient angle
            # TODO: newton step magnitude
            s.grad_curvature = u.to_python_scalar(
                g @ H @ g.t())  # curvature in direction of g
            # fall back to a sentinel step size when the curvature is exactly zero
            s.step_openai = u.to_python_scalar(
                s.grad_fro**2 / s.grad_curvature) if s.grad_curvature else 999

            s.regret_gradient = loss_direction(g, s.step_openai)
            return s
        finally:
            if refreeze:
                u.freeze(layer)
示例#7
0
def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')

    parser.add_argument('--wandb', type=int, default=0, help='log to weights and biases')
    parser.add_argument('--autograd_check', type=int, default=0, help='autograd correctness checks')
    parser.add_argument('--logdir', type=str, default='/tmp/runs/curv_train_tiny/run')

    parser.add_argument('--nonlin', type=int, default=1, help="whether to add ReLU nonlinearity between layers")
    parser.add_argument('--bias', type=int, default=1, help="whether to add bias between layers")

    parser.add_argument('--layer', type=int, default=-1, help="restrict updates to this layer")
    parser.add_argument('--data_width', type=int, default=28)
    parser.add_argument('--targets_width', type=int, default=28)
    parser.add_argument('--hess_samples', type=int, default=1,
                        help='number of samples when sub-sampling outputs, 0 for exact hessian')
    parser.add_argument('--hess_kfac', type=int, default=0, help='whether to use KFAC approximation for hessian')
    parser.add_argument('--compute_rho', type=int, default=0, help='use expensive method to compute rho')
    parser.add_argument('--skip_stats', type=int, default=0, help='skip all stats collection')

    parser.add_argument('--dataset_size', type=int, default=60000)
    parser.add_argument('--train_steps', type=int, default=100, help="this many train steps between stat collection")
    parser.add_argument('--stats_steps', type=int, default=1000000, help="total number of curvature stats collections")

    parser.add_argument('--full_batch', type=int, default=0, help='do stats on the whole dataset')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--dropout', type=int, default=0)
    parser.add_argument('--swa', type=int, default=0)
    parser.add_argument('--lmb', type=float, default=1e-3)

    parser.add_argument('--train_batch_size', type=int, default=64)
    parser.add_argument('--stats_batch_size', type=int, default=10000)
    parser.add_argument('--stats_num_batches', type=int, default=1)
    parser.add_argument('--run_name', type=str, default='noname')
    parser.add_argument('--launch_blocking', type=int, default=0)
    parser.add_argument('--sampled', type=int, default=0)
    parser.add_argument('--curv', type=str, default='kfac',
                        help='decomposition to use for curvature estimates: zero_order, kfac, isserlis or full')
    parser.add_argument('--log_spectra', type=int, default=0)

    u.seed_random(1)
    gl.args = parser.parse_args()
    args = gl.args
    u.seed_random(1)

    gl.project_name = 'train_ciresan'
    u.setup_logdir_and_event_writer(args.run_name)
    print(f"Logging to {gl.logdir}")

    d1 = 28 * 28
    d = [784, 2500, 2000, 1500, 1000, 500, 10]

    # number of samples per datapoint. Used to normalize kfac
    model = u.SimpleFullyConnected2(d, nonlin=args.nonlin, bias=args.bias, dropout=args.dropout)
    model = model.to(gl.device)
    autograd_lib.register(model)

    assert args.dataset_size >= args.stats_batch_size
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, original_targets=True,
                          dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    assert not args.full_batch, "fixme: validation still uses stats_iter"
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=True,
                                                   drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)
    else:
        stats_iter = None

    test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, train=False,
                               original_targets=True,
                               dataset_size=args.dataset_size)
    test_eval_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.stats_batch_size, shuffle=False,
                                                   drop_last=False)
    train_eval_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False,
                                                    drop_last=False)

    loss_fn = torch.nn.CrossEntropyLoss()
    autograd_lib.add_hooks(model)
    autograd_lib.disable_hooks()

    gl.token_count = 0
    last_outer = 0

    for step in range(args.stats_steps):
        epoch = gl.token_count // 60000
        lr = optimizer.param_groups[0]['lr']
        print('token_count', gl.token_count)
        if last_outer:
            u.log_scalars({"time/outer": 1000 * (time.perf_counter() - last_outer)})
            print(f'time: {time.perf_counter() - last_outer:.2f}')
        last_outer = time.perf_counter()

        with u.timeit("validate"):
            val_accuracy, val_loss = validate(model, test_eval_loader, f'test (epoch {epoch})')
            train_accuracy, train_loss = validate(model, train_eval_loader, f'train (epoch {epoch})')

        # save log
        metrics = {'epoch': epoch, 'val_accuracy': val_accuracy, 'val_loss': val_loss,
                   'train_loss': train_loss, 'train_accuracy': train_accuracy,
                   'lr': optimizer.param_groups[0]['lr'],
                   'momentum': optimizer.param_groups[0].get('momentum', 0)}
        u.log_scalars(metrics)

        def mom_update(buffer, val):
            buffer *= 0.9
            buffer += val * 0.1

        if not args.skip_stats:
            # number of samples passed through
            n = args.stats_batch_size * args.stats_num_batches

            # quanti
            forward_stats = defaultdict(lambda: AttrDefault(float))

            hessians = defaultdict(lambda: AttrDefault(float))
            jacobians = defaultdict(lambda: AttrDefault(float))
            fishers = defaultdict(lambda: AttrDefault(float))  # empirical fisher/gradient
            quad_fishers = defaultdict(lambda: AttrDefault(float))  # gradient statistics that depend on fisher (4th order moments)
            train_regrets = defaultdict(list)
            test_regrets1 = defaultdict(list)
            test_regrets2 = defaultdict(list)
            train_regrets_opt = defaultdict(list)
            test_regrets_opt = defaultdict(list)
            cosines = defaultdict(list)
            dot_products = defaultdict(list)
            hessians_histograms = defaultdict(lambda: AttrDefault(u.MyList))
            jacobians_histograms = defaultdict(lambda: AttrDefault(u.MyList))
            fishers_histograms = defaultdict(lambda: AttrDefault(u.MyList))
            quad_fishers_histograms = defaultdict(lambda: AttrDefault(u.MyList))

            current = None
            current_histograms = None

            for i in range(args.stats_num_batches):
                activations = {}
                backprops = {}

                def save_activations(layer, A, _):
                    activations[layer] = A
                    forward_stats[layer].AA += torch.einsum("ni,nj->ij", A, A)

                print('forward')
                with u.timeit("stats_forward"):
                    with autograd_lib.module_hook(save_activations):
                        data, targets = next(stats_iter)
                        output = model(data)
                        loss = loss_fn(output, targets) * len(output)

                def compute_stats(layer, _, B):
                    A = activations[layer]
                    if current == fishers:
                        backprops[layer] = B

                    # about 27ms per layer
                    with u.timeit('compute_stats'):
                        current[layer].BB += torch.einsum("ni,nj->ij", B, B)  # TODO(y): index consistency
                        current[layer].diag += torch.einsum("ni,nj->ij", B * B, A * A)
                        current[layer].BA += torch.einsum("ni,nj->ij", B, A)
                        current[layer].a += torch.einsum("ni->i", A)
                        current[layer].b += torch.einsum("nk->k", B)
                        current[layer].norm2 += ((A * A).sum(dim=1) * (B * B).sum(dim=1)).sum()

                        # compute curvatures in direction of all gradiennts
                        if current is fishers:
                            assert args.stats_num_batches == 1, "not tested on more than one stats step, currently reusing aggregated moments"
                            hess = hessians[layer]
                            jac = jacobians[layer]
                            Bh, Ah = B @ hess.BB / n, A @ forward_stats[layer].AA / n
                            Bj, Aj = B @ jac.BB / n, A @ forward_stats[layer].AA / n
                            norms = ((A * A).sum(dim=1) * (B * B).sum(dim=1))

                            current[layer].min_norm2 = min(norms)
                            current[layer].median_norm2 = torch.median(norms)
                            current[layer].max_norm2 = max(norms)

                            norms2_hess = ((Ah * A).sum(dim=1) * (Bh * B).sum(dim=1))
                            norms2_jac = ((Aj * A).sum(dim=1) * (Bj * B).sum(dim=1))

                            current[layer].norm += norms.sum()
                            current_histograms[layer].norms.extend(torch.sqrt(norms))
                            current[layer].curv_hess += (skip_nans(norms2_hess / norms)).sum()
                            current_histograms[layer].curv_hess.extend(skip_nans(norms2_hess / norms))
                            current[layer].curv_hess_max += (skip_nans(norms2_hess / norms)).max()
                            current[layer].curv_hess_median += (skip_nans(norms2_hess / norms)).median()

                            current_histograms[layer].curv_jac.extend(skip_nans(norms2_jac / norms))
                            current[layer].curv_jac += (skip_nans(norms2_jac / norms)).sum()
                            current[layer].curv_jac_max += (skip_nans(norms2_jac / norms)).max()
                            current[layer].curv_jac_median += (skip_nans(norms2_jac / norms)).median()

                            current[layer].a_sparsity += torch.sum(A <= 0).float() / A.numel()
                            current[layer].b_sparsity += torch.sum(B <= 0).float() / B.numel()

                            current[layer].mean_activation += torch.mean(A)
                            current[layer].mean_activation2 += torch.mean(A*A)
                            current[layer].mean_backprop = torch.mean(B)
                            current[layer].mean_backprop2 = torch.mean(B*B)

                            current[layer].norms_hess += torch.sqrt(norms2_hess).sum()
                            current_histograms[layer].norms_hess.extend(torch.sqrt(norms2_hess))
                            current[layer].norms_jac += norms2_jac.sum()
                            current_histograms[layer].norms_jac.extend(torch.sqrt(norms2_jac))

                            normalized_moments = copy.copy(hessians[layer])
                            normalized_moments.AA = forward_stats[layer].AA
                            normalized_moments = u.divide_attributes(normalized_moments, n)

                            train_regrets_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=0, m=normalized_moments,
                                                                        approx=args.curv)
                            test_regrets1_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=1, m=normalized_moments,
                                                                        approx=args.curv)
                            test_regrets2_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=2, m=normalized_moments,
                                                                        approx=args.curv)
                            test_regrets_opt_ = autograd_lib.offset_losses(A, B, alpha=None, offset=2,
                                                                           m=normalized_moments, approx=args.curv)
                            train_regrets_opt_ = autograd_lib.offset_losses(A, B, alpha=None, offset=0,
                                                                            m=normalized_moments, approx=args.curv)
                            cosines_ = autograd_lib.offset_cosines(A, B)
                            train_regrets[layer].extend(train_regrets_)
                            test_regrets1[layer].extend(test_regrets1_)
                            test_regrets2[layer].extend(test_regrets2_)
                            train_regrets_opt[layer].extend(train_regrets_opt_)
                            test_regrets_opt[layer].extend(test_regrets_opt_)
                            cosines[layer].extend(cosines_)
                            dot_products[layer].extend(autograd_lib.offset_dotprod(A, B))

                        # statistics of the form g.Sigma.g
                        elif current == quad_fishers:
                            hess = hessians[layer]
                            sigma = fishers[layer]
                            jac = jacobians[layer]
                            Bs, As = B @ sigma.BB / n, A @ forward_stats[layer].AA / n
                            Bh, Ah = B @ hess.BB / n, A @ forward_stats[layer].AA / n
                            Bj, Aj = B @ jac.BB / n, A @ forward_stats[layer].AA / n

                            norms = ((A * A).sum(dim=1) * (B * B).sum(dim=1))
                            norms2_hess = ((Ah * A).sum(dim=1) * (Bh * B).sum(dim=1))
                            norms2_jac = ((Aj * A).sum(dim=1) * (Bj * B).sum(dim=1))
                            norms_sigma = ((As * A).sum(dim=1) * (Bs * B).sum(dim=1))

                            current[layer].norm += norms.sum()  # TODO(y) remove, redundant with norm2 above
                            current[layer].curv_sigma += (skip_nans(norms_sigma / norms)).sum()
                            current[layer].curv_sigma_max = skip_nans(norms_sigma / norms).max()
                            current[layer].curv_sigma_median = skip_nans(norms_sigma / norms).median()
                            current[layer].curv_hess += skip_nans(norms2_hess / norms).sum()
                            current[layer].curv_hess_max += skip_nans(norms2_hess / norms).max()
                            current[layer].lyap_hess_mean += skip_nans(norms_sigma / norms2_hess).mean()
                            current[layer].lyap_hess_max = max(skip_nans(norms_sigma/norms2_hess))
                            current[layer].lyap_jac_mean += skip_nans(norms_sigma / norms2_jac).mean()
                            current[layer].lyap_jac_max = max(skip_nans(norms_sigma/norms2_jac))

                print('backward')
                with u.timeit("backprop_H"):
                    with autograd_lib.module_hook(compute_stats):
                        current = hessians
                        current_histograms = hessians_histograms
                        autograd_lib.backward_hessian(output, loss='CrossEntropy', sampled=args.sampled,
                                                      retain_graph=True)  # 600 ms
                        current = jacobians
                        current_histograms = jacobians_histograms
                        autograd_lib.backward_jacobian(output, sampled=args.sampled, retain_graph=True)  # 600 ms
                        current = fishers
                        current_histograms = fishers_histograms
                        model.zero_grad()
                        loss.backward(retain_graph=True)  # 60 ms
                        current = quad_fishers
                        current_histograms = quad_fishers_histograms
                        model.zero_grad()
                        loss.backward()  # 60 ms

            print('summarize')
            for (i, layer) in enumerate(model.layers):
                stats_dict = {'hessian': hessians, 'jacobian': jacobians, 'fisher': fishers}

                # evaluate stats from
                # https://app.wandb.ai/yaroslavvb/train_ciresan/runs/425pu650?workspace=user-yaroslavvb
                for stats_name in stats_dict:
                    s = AttrDict()
                    stats = stats_dict[stats_name][layer]

                    for key in forward_stats[layer]:
                        # print(f'copying {key} in {stats_name}, {layer}')
                        try:
                            assert stats[key] == float()
                        except:
                            f"Trying to overwrite {key} in {stats_name}, {layer}"
                        stats[key] = forward_stats[layer][key]

                    diag: torch.Tensor = stats.diag / n

                    # jacobian:
                    # curv in direction of gradient goes down to roughly 0.3-1
                    # maximum curvature goes up to 1000-2000
                    #
                    # Hessian:
                    # max curv goes down to 1, in direction of gradient 0.0001

                    s.diag_l2 = torch.max(diag)  # 40 - 3000 smaller than kfac l2 for jac
                    s.diag_fro = torch.norm(
                        diag)  # jacobian grows to 0.5-1.5, rest falls, layer-5 has phase transition, layer-4 also
                    s.diag_trace = diag.sum()  # jacobian grows 0-1000 (first), 0-150 (last). Almost same as kfac_trace (771 vs 810 kfac). Jacobian has up/down phase transition
                    s.diag_average = diag.mean()

                    # normalize for mean loss
                    BB = stats.BB / n
                    AA = stats.AA / n
                    # A_evals, _ = torch.symeig(AA)   # averaging 120ms per hit, 90 hits
                    # B_evals, _ = torch.symeig(BB)

                    # s.kfac_l2 = torch.max(A_evals) * torch.max(B_evals)    # 60x larger than diag_l2. layer0/hess has down/up phase transition. layer5/jacobian has up/down phase transition
                    s.kfac_trace = torch.trace(AA) * torch.trace(BB)  # 0/hess down/up tr, 5/jac sharp phase transition
                    s.kfac_fro = torch.norm(stats.AA) * torch.norm(
                        stats.BB)  # 0/hess has down/up tr, 5/jac up/down transition
                    # s.kfac_erank = s.kfac_trace / s.kfac_l2   # first layer has 25, rest 15, all layers go down except last, last noisy
                    # s.kfac_erank_fro = s.kfac_trace / s.kfac_fro / max(stats.BA.shape)

                    s.diversity = (stats.norm2 / n) / u.norm_squared(
                        stats.BA / n)  # gradient diversity. Goes up 3x. Bottom layer has most diversity. Jacobian diversity much less noisy than everythingelse

                    # discrepancy of KFAC based on exact values of diagonal approximation
                    # average difference normalized by average diagonal magnitude
                    diag_kfac = torch.einsum('ll,ii->li', BB, AA)
                    s.kfac_error = (torch.abs(diag_kfac - diag)).mean() / torch.mean(diag.abs())
                    u.log_scalars(u.nest_stats(f'layer-{i}/{stats_name}', s))

                # openai batch size stat
                s = AttrDict()
                hess = hessians[layer]
                jac = jacobians[layer]
                fish = fishers[layer]
                quad_fish = quad_fishers[layer]

                # the following check passes, but is expensive
                # if args.stats_num_batches == 1:
                #    u.check_close(fisher[layer].BA, layer.weight.grad)

                def trsum(A, B):
                    return (A * B).sum()  # computes tr(AB')

                grad = fishers[layer].BA / n
                s.grad_fro = torch.norm(grad)

                # get norms
                s.lyap_hess_max = quad_fish.lyap_hess_max
                s.lyap_hess_ave = quad_fish.lyap_hess_sum / n
                s.lyap_jac_max = quad_fish.lyap_jac_max
                s.lyap_jac_ave = quad_fish.lyap_jac_sum / n
                s.hess_trace = hess.diag.sum() / n
                s.jac_trace = jac.diag.sum() / n

                # Version 1 of Jain stochastic rates, use Hessian for curvature
                b = args.train_batch_size

                s.hess_curv = trsum((hess.BB / n) @ grad @ (hess.AA / n), grad) / trsum(grad, grad)
                s.jac_curv = trsum((jac.BB / n) @ grad @ (jac.AA / n), grad) / trsum(grad, grad)

                # compute gradient noise statistics
                # fish.BB has /n factor twice, hence don't need extra /n on fish.AA
                # after sampling, hess_noise,jac_noise became 100x smaller, but normalized is unaffected
                s.hess_noise = (trsum(hess.AA / n, fish.AA / n) * trsum(hess.BB / n, fish.BB / n))
                s.jac_noise = (trsum(jac.AA / n, fish.AA / n) * trsum(jac.BB / n, fish.BB / n))
                s.hess_noise_centered = s.hess_noise - trsum(hess.BB / n @ grad, grad @ hess.AA / n)
                s.jac_noise_centered = s.jac_noise - trsum(jac.BB / n @ grad, grad @ jac.AA / n)
                s.openai_gradient_noise = (fish.norms_hess / n) / trsum(hess.BB / n @ grad, grad @ hess.AA / n)

                s.mean_norm = torch.sqrt(fish.norm2) / n
                s.min_norm = torch.sqrt(fish.min_norm2)
                s.median_norm = torch.sqrt(fish.median_norm2)
                s.max_norm = torch.sqrt(fish.max_norm2)
                s.enorms = u.norm_squared(grad)
                s.a_sparsity = fish.a_sparsity
                s.b_sparsity = fish.b_sparsity
                s.mean_activation = fish.mean_activation
                s.msr_activation = torch.sqrt(fish.mean_activation2)
                s.mean_backprop = fish.mean_backprop
                s.msr_backprop = torch.sqrt(fish.mean_backprop2)

                s.norms_centered = fish.norm2 / n - u.norm_squared(grad)
                s.norms_hess = fish.norms_hess / n
                s.norms_jac = fish.norms_jac / n

                s.hess_curv_grad = fish.curv_hess / n  # phase transition, hits minimum loss in layer 1, then starts going up. Other layers take longer to reach minimum. Decreases with depth.
                s.hess_curv_grad_max = fish.curv_hess_max   # phase transition, hits minimum loss in layer 1, then starts going up. Other layers take longer to reach minimum. Decreases with depth.
                s.hess_curv_grad_median = fish.curv_hess_median   # phase transition, hits minimum loss in layer 1, then starts going up. Other layers take longer to reach minimum. Decreases with depth.
                s.sigma_curv_grad = quad_fish.curv_sigma / n
                s.sigma_curv_grad_max = quad_fish.curv_sigma_max
                s.sigma_curv_grad_median = quad_fish.curv_sigma_median
                s.band_bottou = 0.5 * lr * s.sigma_curv_grad / s.hess_curv_grad
                s.band_bottou_stoch = 0.5 * lr * quad_fish.curv_ratio / n
                s.band_yaida = 0.25 * lr * s.mean_norm**2
                s.band_yaida_centered = 0.25 * lr * s.norms_centered

                s.jac_curv_grad = fish.curv_jac / n  # this one has much lower variance than jac_curv. Reaches peak at 10k steps, also kfac error reaches peak there. Decreases with depth except for last layer.
                s.jac_curv_grad_max = fish.curv_jac_max  # this one has much lower variance than jac_curv. Reaches peak at 10k steps, also kfac error reaches peak there. Decreases with depth except for last layer.
                s.jac_curv_grad_median = fish.curv_jac_median  # this one has much lower variance than jac_curv. Reaches peak at 10k steps, also kfac error reaches peak there. Decreases with depth except for last layer.

                # OpenAI gradient noise statistics
                s.hess_noise_normalized = s.hess_noise_centered / (fish.norms_hess / n)
                s.jac_noise_normalized = s.jac_noise / (fish.norms_jac / n)

                train_regrets_, test_regrets1_, test_regrets2_, train_regrets_opt_, test_regrets_opt_, cosines_, dot_products_ = (torch.stack(r[layer]) for r in (train_regrets, test_regrets1, test_regrets2, train_regrets_opt, test_regrets_opt, cosines, dot_products))
                s.train_regret = train_regrets_.median()  # use median because outliers make it hard to see the trend
                s.test_regret1 = test_regrets1_.median()
                s.test_regret2 = test_regrets2_.median()
                s.test_regret_opt = test_regrets_opt_.median()
                s.train_regret_opt = train_regrets_opt_.median()
                s.mean_dot_product = torch.mean(dot_products_)
                s.median_dot_product = torch.median(dot_products_)
                a = [1, 2, 3]

                s.median_cosine = cosines_.median()
                s.mean_cosine = cosines_.mean()

                # get learning rates
                L1 = s.hess_curv_grad / n
                L2 = s.jac_curv_grad / n
                diversity = (fish.norm2 / n) / u.norm_squared(grad)
                robust_diversity = (fish.norm2 / n) / fish.median_norm2
                dotprod_diversity = fish.median_norm2 / s.median_dot_product
                s.lr1 = 2 / (L1 * diversity)
                s.lr2 = 2 / (L2 * diversity)
                s.lr3 = 2 / (L2 * robust_diversity)
                s.lr4 = 2 / (L2 * dotprod_diversity)

                hess_A = u.symeig_pos_evals(hess.AA / n)
                hess_B = u.symeig_pos_evals(hess.BB / n)
                fish_A = u.symeig_pos_evals(fish.AA / n)
                fish_B = u.symeig_pos_evals(fish.BB / n)
                jac_A = u.symeig_pos_evals(jac.AA / n)
                jac_B = u.symeig_pos_evals(jac.BB / n)
                u.log_scalars({f'layer-{i}/hessA_erank': erank(hess_A)})
                u.log_scalars({f'layer-{i}/hessB_erank': erank(hess_B)})
                u.log_scalars({f'layer-{i}/fishA_erank': erank(fish_A)})
                u.log_scalars({f'layer-{i}/fishB_erank': erank(fish_B)})
                u.log_scalars({f'layer-{i}/jacA_erank': erank(jac_A)})
                u.log_scalars({f'layer-{i}/jacB_erank': erank(jac_B)})
                gl.event_writer.add_histogram(f'layer-{i}/hist_hess_eig', u.outer(hess_A, hess_B).flatten(), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_fish_eig', u.outer(hess_A, hess_B).flatten(), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_jac_eig', u.outer(hess_A, hess_B).flatten(), gl.get_global_step())

                s.hess_l2 = max(hess_A) * max(hess_B)
                s.jac_l2 = max(jac_A) * max(jac_B)
                s.fish_l2 = max(fish_A) * max(fish_B)
                s.hess_trace = hess.diag.sum() / n

                s.jain1_sto = 1/(s.hess_trace + 2 * s.hess_l2)
                s.jain1_det = 1/s.hess_l2

                s.jain1_lr = (1 / b) * (1/s.jain1_sto) + (b - 1) / b * (1/s.jain1_det)
                s.jain1_lr = 2 / s.jain1_lr

                s.regret_ratio = (
                            train_regrets_opt_ / test_regrets_opt_).median()  # ratio between train and test regret, large means overfitting
                u.log_scalars(u.nest_stats(f'layer-{i}', s))

                # compute stats that would let you bound rho
                if i == 0:  # only compute this once, for output layer
                    hhh = hessians[model.layers[-1]].BB / n
                    fff = fishers[model.layers[-1]].BB / n
                    d = fff.shape[0]
                    L = u.lyapunov_spectral(hhh, 2 * fff, cond=1e-8)
                    L_evals = u.symeig_pos_evals(L)
                    Lcheap = fff @ u.pinv(hhh, cond=1e-8)
                    Lcheap_evals = u.eig_real(Lcheap)

                    u.log_scalars({f'mismatch/rho': d/erank(L_evals)})
                    u.log_scalars({f'mismatch/rho_cheap': d/erank(Lcheap_evals)})
                    u.log_scalars({f'mismatch/diagonalizability': erank(L_evals)/erank(Lcheap_evals)})  # 1 means diagonalizable
                    u.log_spectrum(f'mismatch/sigma', u.symeig_pos_evals(fff), loglog=False)
                    u.log_spectrum(f'mismatch/hess', u.symeig_pos_evals(hhh), loglog=False)
                    u.log_spectrum(f'mismatch/lyapunov', L_evals, loglog=True)
                    u.log_spectrum(f'mismatch/lyapunov_cheap', Lcheap_evals, loglog=True)

                gl.event_writer.add_histogram(f'layer-{i}/hist_grad_norms', u.to_numpy(fishers_histograms[layer].norms.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_grad_norms_hess', u.to_numpy(fishers_histograms[layer].norms_hess.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_curv_jac', u.to_numpy(fishers_histograms[layer].curv_jac.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_curv_hess', u.to_numpy(fishers_histograms[layer].curv_hess.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_cosines', u.to_numpy(cosines[layer]), gl.get_global_step())

                if args.log_spectra:
                    with u.timeit('spectrum'):
                        # 2/alpha
                        # s.jain1_lr = (1 / b) * s.jain1_sto + (b - 1) / b * s.jain1_det
                        # s.jain1_lr = 1 / s.jain1_lr

                        # hess.diag_trace, jac.diag_trace

                        # Version 2 of Jain stochastic rates, use Jacobian squared for curvature
                        s.jain2_sto = s.lyap_jac_max * s.jac_trace / s.lyap_jac_ave
                        s.jain2_det = s.jac_l2
                        s.jain2_lr = (1 / b) * s.jain2_sto + (b - 1) / b * s.jain2_det
                        s.jain2_lr = 1 / s.jain2_lr

                        u.log_spectrum(f'layer-{i}/hess_A', hess_A)
                        u.log_spectrum(f'layer-{i}/hess_B', hess_B)
                        u.log_spectrum(f'layer-{i}/hess_AB', u.outer(hess_A, hess_B).flatten())
                        u.log_spectrum(f'layer-{i}/jac_A', jac_A)
                        u.log_spectrum(f'layer-{i}/jac_B', jac_B)
                        u.log_spectrum(f'layer-{i}/fish_A', fish_A)
                        u.log_spectrum(f'layer-{i}/fish_B', fish_B)

                        u.log_scalars({f'layer-{i}/trace_ratio': fish_B.sum()/hess_B.sum()})

                        L = torch.eig(u.lyapunov_spectral(hess.BB, 2*fish.BB, cond=1e-8))[0]
                        L = L[:, 0]  # extract real part
                        L = L.sort()[0]
                        L = torch.flip(L, [0])

                        L_cheap = torch.eig(fish.BB @ u.pinv(hess.BB, cond=1e-8))[0]
                        L_cheap = L_cheap[:, 0]  # extract real part
                        L_cheap = L_cheap.sort()[0]
                        L_cheap = torch.flip(L_cheap, [0])

                        d = len(hess_B)
                        u.log_spectrum(f'layer-{i}/Lyap', L)
                        u.log_spectrum(f'layer-{i}/Lyap_cheap', L_cheap)

                        u.log_scalars({f'layer-{i}/dims': d})
                        u.log_scalars({f'layer-{i}/L_erank': erank(L)})
                        u.log_scalars({f'layer-{i}/L_cheap_erank': erank(L_cheap)})

                        u.log_scalars({f'layer-{i}/rho': d/erank(L)})
                        u.log_scalars({f'layer-{i}/rho_cheap': d/erank(L_cheap)})

        model.train()
        with u.timeit('train'):
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                optimizer.step()
                if args.weight_decay:
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            param.data.mul_(1 - args.weight_decay)

                gl.token_count += data.shape[0]

    gl.event_writer.close()
示例#8
0
def main():
    """Train a tiny fully connected net while logging per-layer curvature stats.

    Picks the first unused ``{args.logdir}NN`` directory for the event writer,
    builds a ``u.SimpleFullyConnected`` model on ``u.TinyMNIST`` and then
    repeats the following ``args.stats_steps`` times:

    1. evaluate and log the loss on one test batch;
    2. run a statistics pass on one stats batch: one backward pass of the
       loss plus ``hess_rank`` backward passes of the raw output, from which
       per-layer matrices ``G`` (per-example gradients), ``sigma`` and ``H``
       are assembled and a set of scalars is logged;
    3. take ``args.train_steps`` optimisation steps, either with SGD or, when
       ``args.method == 'newton'``, with the hand-written update below.

    All configuration is read from the module-level ``args``; output goes to
    ``gl.event_writer`` (and stdout). Returns nothing.
    """
    # find the first run directory suffix that is not taken yet
    attempt_count = 0
    while os.path.exists(f"{args.logdir}{attempt_count:02d}"):
        attempt_count += 1
    logdir = f"{args.logdir}{attempt_count:02d}"

    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")
    u.seed_random(1)

    # layer widths: input, hidden, output
    d1 = args.data_width**2
    d2 = 10
    d3 = args.targets_width**2
    o = d3
    n = args.stats_batch_size
    d = [d1, d2, d3]
    model = u.SimpleFullyConnected(d, nonlin=args.nonlin)
    model = model.to(gl.device)

    # wandb is best-effort: a failure here is reported but must not stop the run
    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['d1'] = d1
            wandb.config['d2'] = d2
            wandb.config['d3'] = d3
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)

    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=args.dataset_size)

    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=False,
        drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.stats_batch_size,
        shuffle=False,
        drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width,
                               targets_width=args.targets_width,
                               dataset_size=args.dataset_size,
                               train=False)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.stats_batch_size,
                                              shuffle=True,
                                              drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    # The hooks below read these flags through the closure, so re-assigning
    # them later in this function switches capturing on and off.
    skip_forward_hooks = False
    skip_backward_hooks = False

    def capture_activations(module: nn.Module, input: List[torch.Tensor],
                            output: torch.Tensor):
        """Forward hook: store the layer input/output on the module."""
        if skip_forward_hooks:
            return
        assert not hasattr(
            module, 'activations'
        ), "Seeing results of previous autograd, call util.zero_grad to clear"
        assert len(input) == 1, "this was tested for single input layers only"
        setattr(module, "activations", input[0].detach())
        setattr(module, "output", output.detach())

    def capture_backprops(module: nn.Module, _input, output):
        """Backward hook: append the output gradient to ``module.backprops``.

        ``gl.backward_idx`` must equal the number of values stored so far;
        index 0 starts a fresh list.
        """
        if skip_backward_hooks:
            return
        assert len(output) == 1, "this works for single variable layers only"
        if gl.backward_idx == 0:
            assert not hasattr(
                module, 'backprops'
            ), "Seeing results of previous autograd, call util.zero_grad to clear"
            setattr(module, 'backprops', [])
        assert gl.backward_idx == len(module.backprops)
        module.backprops.append(output[0])

    def save_grad(param: nn.Parameter) -> Callable[[torch.Tensor], None]:
        """Hook to save gradient into 'param.saved_grad', so it can be accessed after model.zero_grad(). Only stores gradient
        if the value has not been set, call util.zero_grad to clear it."""
        def save_grad_fn(grad):
            if not hasattr(param, 'saved_grad'):
                setattr(param, 'saved_grad', grad)

        return save_grad_fn

    for layer in model.layers:
        layer.register_forward_hook(capture_activations)
        layer.register_backward_hook(capture_backprops)
        layer.weight.register_hook(save_grad(layer.weight))

    def loss_fn(data, targets):
        """Half the summed squared error, averaged over the batch."""
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == len(targets)
        return torch.sum(err * err) / 2 / len(data)

    gl.token_count = 0
    last_outer = 0
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars(
                {"time/outer": 1000 * (time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()
        # compute validation loss
        skip_forward_hooks = True
        skip_backward_hooks = True
        # only the scalar value of the loss is used, so no graph is needed
        with u.timeit("val_loss"), torch.no_grad():
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            print("val_loss", val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        # compute stats
        data, targets = next(stats_iter)
        skip_forward_hooks = False
        skip_backward_hooks = False

        # get gradient values
        with u.timeit("backprop_g"):
            gl.backward_idx = 0
            u.zero_grad(model)
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)

        # get Hessian values
        skip_forward_hooks = True
        id_mat = torch.eye(o).to(gl.device)

        u.log_scalar(loss=loss.item())

        with u.timeit("backprop_H"):
            # optionally use randomized low-rank approximation of Hessian
            hess_rank = args.hess_samples if args.hess_samples else o

            for out_idx in range(hess_rank):
                model.zero_grad()
                # backprop to get section of batch output jacobian for output at position out_idx
                output = model(
                    data
                )  # opt: using autograd.grad means I don't have to zero_grad
                if args.hess_samples:
                    # random +1/-1 entries
                    bval = torch.LongTensor(n, o).to(gl.device).random_(
                        0, 2) * 2 - 1
                    bval = bval.float()
                else:
                    ei = id_mat[out_idx]
                    bval = torch.stack([ei] * n)
                # slot 0 of layer.backprops holds the loss backprop
                gl.backward_idx = out_idx + 1
                output.backward(bval)
            skip_backward_hooks = True  #

        for (i, layer) in enumerate(model.layers):
            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]),
                              layer.weight.saved_grad)

            # convert the count to float first: dividing an integer tensor
            # by an int floor-divides or raises on some PyTorch releases
            s.sparsity = torch.sum(
                layer.output <= 0).float() / layer.output.numel()
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g
                s.sigma_l2 = u.sym_l2_norm(sigma)
                s.sigma_erank = torch.trace(sigma) / s.sigma_l2

            #############################
            # Hessian stats
            #############################
            Bh_t = [
                layer.backprops[out_idx + 1] for out_idx in range(hess_rank)
            ]
            Amat_t = torch.cat([A_t] * hess_rank, dim=0)
            Bmat_t = torch.cat(Bh_t, dim=0)

            assert Amat_t.shape == (n * hess_rank, d[i])
            assert Bmat_t.shape == (n * hess_rank, d[i + 1])

            lambda_regularizer = args.lmb * torch.eye(d[i] * d[i + 1]).to(
                gl.device)
            with u.timeit(f"khatri_H-{i}"):
                Jb = u.khatri_rao_t(
                    Bmat_t, Amat_t)  # batch Jacobian, in row-vec format

            with u.timeit(f"H-{i}"):
                H = Jb.t() @ Jb / n

            with u.timeit(f"invH-{i}"):
                invH = torch.cholesky_inverse(H + lambda_regularizer)

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.jacobian_fro = Jb.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1],
                                                d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            #  u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) -
                                          0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() /
                                          (dd.flatten().norm()**2))

            with u.timeit("pinvH"):
                pinvH = u.pinv(H)

            with u.timeit(f'curv-{i}'):
                s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre',
                        pinvH)  # save Newton preconditioner
                # 999 is a placeholder step when the curvature is exactly zero
                s.step_openai = 1 / s.grad_curv if s.grad_curv else 999
                s.step_max = 2 / u.sym_l2_norm(H)
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm(
                )  # frobenius norm of Newton update
                s.regret_gradient = loss_direction(g, s.step_openai)

            with u.timeit(f'rho-{i}'):
                p_sigma = u.lyapunov_svd(H, sigma)
                if u.has_nan(
                        p_sigma) and args.compute_rho:  # use expensive method
                    H0 = H.cpu().detach().numpy()
                    sigma0 = sigma.cpu().detach().numpy()
                    p_sigma = scipy.linalg.solve_lyapunov(H0, sigma0)
                    p_sigma = torch.tensor(p_sigma).to(gl.device)

                if u.has_nan(p_sigma):
                    # no usable solution: fall back to full rank, rho = 1
                    s.psigma_erank = H.shape[0]
                    s.rho = 1
                else:
                    s.psigma_erank = u.sym_erank(p_sigma)
                    s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                print('openai batch', s.batch_openai)
                s.diversity = torch.norm(G, "fro")**2 / torch.norm(g)**2

                # s.noise_variance = torch.trace(H.inverse() @ sigma)
                # try:
                #     # this fails with singular sigma
                #     s.noise_variance = torch.trace(torch.solve(sigma, H)[0])
                #     # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0])
                #     pass
                # except RuntimeError as _:
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        last_inner = 0
        for i in range(args.train_steps):
            if last_inner:
                u.log_scalars(
                    {"time/inner": 1000 * (time.perf_counter() - last_inner)})
            last_inner = time.perf_counter()

            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()

            u.log_scalar(train_loss=loss.item())

            if args.method != 'newton':
                optimizer.step()
            else:
                for (layer_idx, layer) in enumerate(model.layers):
                    param: torch.nn.Parameter = layer.weight
                    param_data: torch.Tensor = param.data
                    # every layer first takes a plain gradient step of size 0.1
                    param_data.copy_(param_data - 0.1 * param.grad)
                    if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                        continue
                    u.nan_check(layer.weight.pre)
                    u.nan_check(param.grad.flatten())
                    u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                    # preconditioned step using the matrix saved during the stats pass
                    param_new_flat = u.v2r(param_data.flatten()) - u.v2r(
                        param.grad.flatten()) @ layer.weight.pre
                    u.nan_check(param_new_flat)
                    param_data.copy_(param_new_flat.reshape(param_data.shape))

            gl.token_count += data.shape[0]

    gl.event_writer.close()
    def compute_layer_stats(layer):
        """Compute gradient and curvature statistics for ``layer`` and self-check them.

        Uses names from the enclosing scope: ``model``, ``stats_data``,
        ``stats_targets``, ``stats_batch_size`` and ``compute_loss``. Along the
        way it verifies, via ``u.check_close`` / ``u.check_equal`` / ``assert``:

        * the averaged per-example gradient equals ``param.grad``;
        * ``J.t() @ J / n`` equals the autograd Hessian of the loss;
        * the Newton decrement equals the current loss;
        * a line search along the Newton direction reaches loss 0 at the full
          step, with larger loss on either side.

        Args:
            layer: a member of ``model.layers``; its ``data_input``,
                ``data_output`` and ``grad_output`` attributes are read after
                each backward pass (presumably filled in by hooks registered
                elsewhere — confirm against the caller).

        Returns:
            ``AttrDefault`` with ``diversity``, ``gradient_norm``,
            ``parameter_norm``, ``sparsity``, ``hessian_norm``,
            ``jacobian_norm``, ``jacobian_sensitivity`` and ``loss_newton``.
        """
        stats = AttrDefault(str, {})
        n = stats_batch_size
        param = u.get_param(layer)
        d = len(param.flatten())
        layer_idx = model.layers.index(layer)
        assert layer_idx >= 0
        assert stats_data.shape[0] == n

        def backprop_loss():
            """Forward/backward pass of the loss on the stats batch."""
            model.zero_grad()
            output = model(
                stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            loss.backward()
            return loss, output

        def backprop_output():
            """Forward pass, then backprop a vector of ones through the output."""
            model.zero_grad()
            output = model(stats_data)
            output.backward(gradient=torch.ones_like(output))
            return output

        # per-example gradients, n, d
        loss, output = backprop_loss()
        At = layer.data_input
        # multiply by n to undo the batch averaging applied to the backprops
        Bt = layer.grad_output * n
        G = u.khatri_rao_t(At, Bt)
        g = G.sum(dim=0, keepdim=True) / n
        u.check_close(g, u.vec(param.grad).t())

        stats.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2

        stats.gradient_norm = g.flatten().norm()
        stats.parameter_norm = param.data.flatten().norm()
        pos_activations = torch.sum(layer.data_output > 0)
        neg_activations = torch.sum(layer.data_output <= 0)
        # NOTE(review): this is the fraction of strictly positive outputs,
        # which is the opposite of what "sparsity" usually means — confirm
        # the intended definition before relying on this value.
        stats.sparsity = pos_activations.float() / (pos_activations +
                                                    neg_activations)

        output = backprop_output()
        At2 = layer.data_input
        u.check_close(At, At2)
        B2t = layer.grad_output
        J = u.khatri_rao_t(At, B2t)
        H = J.t() @ J / n

        model.zero_grad()
        output = model(stats_data)  # use last saved data batch for backprop
        loss = compute_loss(output, stats_targets)
        hess = u.hessian(loss, param)

        # rearrange the autograd result into a (d, d) matrix comparable to H
        hess = hess.transpose(2, 3).transpose(0, 1).reshape(d, d)
        u.check_close(hess, H)

        stats.hessian_norm = u.l2_norm(H)
        stats.jacobian_norm = u.l2_norm(J)
        Joutput = J.sum(dim=0) / n
        stats.jacobian_sensitivity = Joutput.norm()

        # newton decrement
        stats.loss_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)
        u.check_close(stats.loss_newton, loss)

        # do line-search to find optimal step
        def line_search(directionv, start, end, steps=10):
            """Evaluate the loss along ``directionv``.

            Returns ``steps + 2`` losses: entry 0 is the loss at the current
            (unmodified) parameters, entry ``k + 1`` is the loss at offset
            ``start + k * (end - start) / steps`` for ``k = 0 .. steps``.
            The parameter values are restored before returning.
            """
            param0 = param.data.clone()
            param0v = u.vec(param0).t()

            def current_loss():
                output = model(
                    stats_data)  # use last saved data batch for backprop
                return compute_loss(output, stats_targets)

            losses = [current_loss()]
            for k in range(steps + 1):
                offset = start + k * ((end - start) / steps)
                param1v = param0v + offset * directionv

                param1 = u.unvec(param1v.t(), param.data.shape[0])
                param.data.copy_(param1)
                losses.append(current_loss())

            param.data.copy_(param0)
            return losses

        # try to take a newton step
        gradv = g
        line_losses = line_search(-gradv @ u.pinv(H), 0, 2, steps=10)
        u.check_equal(line_losses[0], loss)
        # with start=0, end=2, steps=10 entry 6 sits at offset 1.0, i.e. the
        # full Newton step; entries 5 and 7 are its neighbours at 0.8 and 1.2
        u.check_equal(line_losses[6], 0)
        assert line_losses[5] > line_losses[6]
        assert line_losses[7] > line_losses[6]
        return stats