예제 #1
0
def test_kron_1x2_conv():
    """Show a 1x2 convolution whose Hessian is not exactly Kronecker-factored.

    Builds two stacked 1x2 convolutional layers (no nonlinearity, biases zeroed,
    weights set to [1, 2]) on a single 1x4 input. It back-propagates a
    least-squares Hessian and computes both the exact per-layer Hessian and its
    Kronecker-factored approximation. For every layer it prints
    ``u.symsqrt_dist`` between the two.

    Values recorded by the original author (informational, not asserted):

        Outputs:
        0 tensor([[[[0., 1., 1., 1.]]]])
        1 tensor([[[[2., 3.]]]])
        2 tensor([[[[8.]]]])

        Activations/backprops:
        layerA 0 tensor([[[[0., 1.],
                           [1., 1.]]]])
        layerB 0 tensor([[[[1., 2.]]]])

        layerA 1 tensor([[[[2.],
                           [3.]]]])
        layerB 1 tensor([[[[1.]]]])

        layer 0 discrepancy: 0.6597963571548462
        layer 1 discrepancy: 0.0
    """
    u.seed_random(1)

    batch, height, width = 1, 1, 4
    kernel = (1, 2)
    dims = [1, 1, 1]  # channels per layer boundary

    model: u.SimpleModel = u.StridedConvolutional2(dims, kernel_size=kernel, nonlin=False, bias=True)
    inputs = torch.tensor([0, 1., 1, 1]).reshape((batch, dims[0], height, width))

    # Pin both layers to identical, deterministic parameters.
    for idx in (0, 1):
        conv = model.layers[idx]
        conv.bias.data.zero_()
        conv.weight.data.copy_(torch.tensor([1, 2]))

    # Forward pass before hooks are attached; its result is not used.
    model(inputs)

    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    prediction = model(inputs)
    autograd_lib.backprop_hess(prediction, hess_type='LeastSquares')
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    # Compare the exact Hessian with the expanded Kronecker approximation.
    for conv in model.layers:
        exact_hess = conv.weight.hess
        kron_hess = conv.weight.hess_kron.expand()
        print(u.symsqrt_dist(exact_hess, kron_hess))
예제 #2
0
def test_main():
    """Smoke-test a tiny training run that collects per-layer curvature stats.

    Trains a fully-connected network with layer widths ``[4, 2, 10]`` on
    ``u.TinyMNIST`` (2x2 inputs) with cross-entropy loss and Adam. Before every
    block of ``args.train_steps`` optimizer steps it does three things:
    records the validation loss, back-propagates gradients and Hessians on a
    stats batch, and logs per-layer statistics (gradient covariance, Hessian
    norms, Newton/gradient step diagnostics, rho, batch-size estimates).

    Asserts that validation loss starts above 2.4 and ends below 2.25.
    """
    # Fixed configuration for a tiny, fast run. An AttrDict gives the body the
    # same ``args.x`` access pattern as an argparse namespace.
    args = AttrDict()
    args.lmb = 1e-3  # ridge term added to H before inversion
    args.compute_rho = 1
    args.weight_decay = 1e-4
    args.method = 'gradient'
    args.logdir = '/tmp'
    args.data_width = 2
    args.targets_width = 2
    args.train_batch_size = 10
    args.stats_batch_size = 10
    args.full_batch = False
    args.skip_stats = False
    args.autograd_check = False
    args.wandb = 0
    args.stats_steps = 10
    args.train_steps = 10
    args.nonlin = False

    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    gl.event_writer = u.NoOp()  # discard TensorBoard events during the test

    # loss_type = 'LeastSquares'
    loss_type = 'CrossEntropy'

    # Layer widths: flattened input -> hidden -> output.
    assert args.data_width == args.targets_width
    d1 = args.data_width**2
    d2 = 2
    d3 = 10 if loss_type == 'CrossEntropy' else args.targets_width**2
    d = [d1, d2, d3]
    n = args.stats_batch_size
    dsize = max(args.train_batch_size, args.stats_batch_size) + 1

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin)
    model = model.to(gl.device)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['n'] = n
    except Exception as e:
        # wandb logging is best-effort; never fail the run because of it
        print(f"wandb crash with {e}")

    # optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.03)  # make 10x smaller for least-squares loss
    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=dsize,
                          original_targets=True)

    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=False,
        drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_iter = None
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.stats_batch_size,
            shuffle=False,
            drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width,
                               targets_width=args.targets_width,
                               train=False,
                               dataset_size=dsize,
                               original_targets=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.train_batch_size,
                                              shuffle=False,
                                              drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.token_count = 0
    last_outer = 0
    val_losses = []
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars(
                {"time/outer": 1000 * (time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            val_losses.append(val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        # compute stats
        if args.full_batch:
            data, targets = dataset.data, dataset.targets
        else:
            data, targets = next(stats_iter)

        # Capture Hessian and gradient stats
        autograd_lib.enable_hooks()
        autograd_lib.clear_backprops(model)
        autograd_lib.clear_hess_backprops(model)
        with u.timeit("backprop_g"):
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)
        with u.timeit("backprop_H"):
            autograd_lib.backprop_hess(output, hess_type=loss_type)
        autograd_lib.disable_hooks()  # TODO(y): use remove_hooks

        with u.timeit("compute_grad1"):
            autograd_lib.compute_grad1(model)
        with u.timeit("compute_hess"):
            autograd_lib.compute_hess(model)

        for (i, layer) in enumerate(model.layers):

            # input/output layers are unreasonably expensive if not using Kronecker factoring
            if d[i] > 50 or d[i + 1] > 50:
                print(
                    f'layer {i} is too big ({d[i], d[i + 1]}), skipping stats')
                continue

            if args.skip_stats:
                continue

            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops_list[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            u.check_equal(G.reshape(layer.weight.grad1.shape),
                          layer.weight.grad1)

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]),
                              layer.weight.saved_grad)

            s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel(
            )  # proportion of activations that are zero
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g
                s.sigma_l2 = u.sym_l2_norm(sigma)
                s.sigma_erank = torch.trace(sigma) / s.sigma_l2

            lambda_regularizer = args.lmb * torch.eye(d[i + 1] * d[i]).to(
                gl.device)
            H = layer.weight.hess

            with u.timeit(f"invH-{i}"):
                # cholesky_inverse takes the Cholesky factor L (A = L L^T),
                # not A itself, so factor the regularized Hessian first.
                # Assumes H is positive semi-definite, so that H + lmb*I is
                # positive definite.
                invH = torch.cholesky_inverse(
                    torch.linalg.cholesky(H + lambda_regularizer))

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1],
                                                d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) -
                                          0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() /
                                          (dd.flatten().norm()**2))

            with u.timeit(f"pinvH-{i}"):
                pinvH = H.pinverse()

            with u.timeit(f'curv-{i}'):
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre',
                        pinvH)  # save Newton preconditioner
                s.step_openai = s.grad_fro**2 / s.grad_curv if s.grad_curv else 999
                s.step_max = 2 / s.H_l2
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm(
                )  # frobenius norm of Newton update
                s.regret_newton = u.to_python_scalar(
                    g @ pinvH @ g.t() / 2)  # replace with "quadratic_form"
                s.regret_gradient = loss_direction(g, s.step_openai)

            with u.timeit(f'rho-{i}'):
                p_sigma = u.lyapunov_spectral(H, sigma)
                s.psigma_erank = u.sym_erank(p_sigma)
                s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                s.diversity = torch.norm(G, "fro")**2 / torch.norm(g)**2 / n

                # pinv is used because sigma/H may be singular
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        with u.timeit('inner'):
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                if args.method != 'newton':
                    optimizer.step()
                    if args.weight_decay:
                        for group in optimizer.param_groups:
                            for param in group['params']:
                                param.data.mul_(1 - args.weight_decay)
                else:
                    for (layer_idx, layer) in enumerate(model.layers):
                        param: torch.nn.Parameter = layer.weight
                        param_data: torch.Tensor = param.data
                        param_data.copy_(param_data - 0.1 * param.grad)
                        if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                            continue
                        u.nan_check(layer.weight.pre)
                        u.nan_check(param.grad.flatten())
                        u.nan_check(
                            u.v2r(param.grad.flatten()) @ layer.weight.pre)
                        param_new_flat = u.v2r(param_data.flatten()) - u.v2r(
                            param.grad.flatten()) @ layer.weight.pre
                        u.nan_check(param_new_flat)
                        param_data.copy_(
                            param_new_flat.reshape(param_data.shape))

                gl.token_count += data.shape[0]

    gl.event_writer.close()

    assert val_losses[0] > 2.4  # 2.4828238487243652
    assert val_losses[-1] < 2.25  # 2.20609712600708
예제 #3
0
def main():
    """Train a tiny network on TinyMNIST while logging per-layer curvature stats.

    Reads configuration from the module-level ``args`` and overwrites several
    fields with small debugging values before sizing the model. Events go to a
    TensorBoard ``SummaryWriter`` in a fresh local log directory. Before every
    block of ``args.train_steps`` optimizer steps it does three things:
    records validation loss, back-propagates gradients and Hessians on a stats
    batch, and logs per-layer statistics (gradient covariance, Hessian norms,
    Newton/gradient step diagnostics, rho, batch-size estimates).
    """
    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")

    # small values for debugging
    # loss_type = 'LeastSquares'
    loss_type = 'CrossEntropy'

    args.wandb = 0
    args.stats_steps = 10
    args.train_steps = 10
    args.stats_batch_size = 10
    args.data_width = 2
    args.targets_width = 2
    args.nonlin = False

    # Size the network from the overridden args (the values actually used).
    assert args.data_width == args.targets_width
    d1 = args.data_width ** 2
    d2 = 2
    d3 = 10 if loss_type == 'CrossEntropy' else args.targets_width ** 2
    d = [d1, d2, d3]
    n = args.stats_batch_size
    dsize = max(args.train_batch_size, args.stats_batch_size)+1

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin)
    model = model.to(gl.device)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['n'] = n
    except Exception as e:
        # wandb logging is best-effort; never abort training because of it
        print(f"wandb crash with {e}")

    #optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.03)  # make 10x smaller for least-squares loss
    dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, dataset_size=dsize, original_targets=True)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_iter = None
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, train=False, dataset_size=dsize, original_targets=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.token_count = 0
    last_outer = 0
    val_losses = []
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars({"time/outer": 1000*(time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            print("val_loss", val_loss.item())
            val_losses.append(val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        # compute stats
        if args.full_batch:
            data, targets = dataset.data, dataset.targets
        else:
            data, targets = next(stats_iter)

        # Capture Hessian and gradient stats
        autograd_lib.enable_hooks()
        autograd_lib.clear_backprops(model)
        autograd_lib.clear_hess_backprops(model)
        with u.timeit("backprop_g"):
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)
        with u.timeit("backprop_H"):
            autograd_lib.backprop_hess(output, hess_type=loss_type)
        autograd_lib.disable_hooks()   # TODO(y): use remove_hooks

        with u.timeit("compute_grad1"):
            autograd_lib.compute_grad1(model)
        with u.timeit("compute_hess"):
            autograd_lib.compute_hess(model)

        for (i, layer) in enumerate(model.layers):

            # input/output layers are unreasonably expensive if not using Kronecker factoring
            if d[i]>50 or d[i+1]>50:
                print(f'layer {i} is too big ({d[i],d[i+1]}), skipping stats')
                continue

            if args.skip_stats:
                continue

            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops_list[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            u.check_equal(G.reshape(layer.weight.grad1.shape), layer.weight.grad1)

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad)

            s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel()  # proportion of activations that are zero
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g
                s.sigma_l2 = u.sym_l2_norm(sigma)
                s.sigma_erank = torch.trace(sigma)/s.sigma_l2

            lambda_regularizer = args.lmb * torch.eye(d[i + 1]*d[i]).to(gl.device)
            H = layer.weight.hess

            with u.timeit(f"invH-{i}"):
                # cholesky_inverse takes the Cholesky factor L (A = L L^T), not A itself.
                # Assumes H is positive semi-definite, so that H + lmb*I is positive definite.
                invH = torch.cholesky_inverse(torch.linalg.cholesky(H+lambda_regularizer))

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            #  u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps ** 2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm() ** 2))

            with u.timeit(f"pinvH-{i}"):
                pinvH = u.pinv(H)

            with u.timeit(f'curv-{i}'):
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre', pinvH)  # save Newton preconditioner
                s.step_openai = s.grad_fro**2 / s.grad_curv if s.grad_curv else 999
                s.step_max = 2 / s.H_l2
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm()  # frobenius norm of Newton update
                s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)   # replace with "quadratic_form"
                s.regret_gradient = loss_direction(g, s.step_openai)

            with u.timeit(f'rho-{i}'):
                p_sigma = u.lyapunov_svd(H, sigma)
                if u.has_nan(p_sigma) and args.compute_rho:  # use expensive method
                    # Fall back to SciPy's dense Lyapunov solver (no interactive breakpoint).
                    print('using expensive method')
                    H0, sigma0 = u.to_numpys(H, sigma)
                    p_sigma = scipy.linalg.solve_lyapunov(H0, sigma0)
                    p_sigma = torch.tensor(p_sigma).to(gl.device)

                if u.has_nan(p_sigma):
                    # No usable solution: report full effective rank and rho = 1.
                    s.psigma_erank = H.shape[0]
                    s.rho = 1
                else:
                    s.psigma_erank = u.sym_erank(p_sigma)
                    s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                s.diversity = torch.norm(G, "fro") ** 2 / torch.norm(g) ** 2 / n

                # pinv is used because sigma/H may be singular
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        with u.timeit('inner'):
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                if args.method != 'newton':
                    optimizer.step()
                    if args.weight_decay:
                        for group in optimizer.param_groups:
                            for param in group['params']:
                                param.data.mul_(1-args.weight_decay)
                else:
                    for (layer_idx, layer) in enumerate(model.layers):
                        param: torch.nn.Parameter = layer.weight
                        param_data: torch.Tensor = param.data
                        param_data.copy_(param_data - 0.1 * param.grad)
                        if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                            continue
                        u.nan_check(layer.weight.pre)
                        u.nan_check(param.grad.flatten())
                        u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                        param_new_flat = u.v2r(param_data.flatten()) - u.v2r(param.grad.flatten()) @ layer.weight.pre
                        u.nan_check(param_new_flat)
                        param_data.copy_(param_new_flat.reshape(param_data.shape))

                gl.token_count += data.shape[0]

    gl.event_writer.close()
예제 #4
0
def main():
    """Train a fully-connected MNIST classifier, logging validation metrics and curvature stats.

    Every outer step (``--stats_steps`` in total):
      1. optionally runs 100 extra SGD steps under ``torchcontrib`` SWA and swaps the averaged
         weights into the model (``--swa``),
      2. evaluates the model with ``validate`` on the test loader and on the training-set
         stats loader and logs accuracy/loss,
      3. unless ``--skip_stats`` (on by default), runs gradient and Hessian backprops through
         ``autograd_lib`` on one stats batch,
      4. logs any ``stats`` attribute found on layer parameters,
      5. runs ``--train_steps`` SGD steps with optional multiplicative weight decay.
    """

    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')

    parser.add_argument('--wandb',
                        type=int,
                        default=0,
                        help='log to weights and biases')
    parser.add_argument('--autograd_check',
                        type=int,
                        default=0,
                        help='autograd correctness checks')
    parser.add_argument('--logdir',
                        type=str,
                        default='/tmp/runs/curv_train_tiny/run')

    parser.add_argument('--nonlin',
                        type=int,
                        default=1,
                        help="whether to add ReLU nonlinearity between layers")
    parser.add_argument('--bias',
                        type=int,
                        default=1,
                        help="whether to add bias between layers")

    parser.add_argument('--layer',
                        type=int,
                        default=-1,
                        help="restrict updates to this layer")
    parser.add_argument('--data_width', type=int, default=28)
    parser.add_argument('--targets_width', type=int, default=28)
    parser.add_argument(
        '--hess_samples',
        type=int,
        default=1,
        help='number of samples when sub-sampling outputs, 0 for exact hessian'
    )
    parser.add_argument('--hess_kfac',
                        type=int,
                        default=0,
                        help='whether to use KFAC approximation for hessian')
    parser.add_argument('--compute_rho',
                        type=int,
                        default=0,
                        help='use expensive method to compute rho')
    parser.add_argument('--skip_stats',
                        type=int,
                        default=1,
                        help='skip all stats collection')

    parser.add_argument('--dataset_size', type=int, default=60000)
    parser.add_argument('--train_steps',
                        type=int,
                        default=1000,
                        help="this many train steps between stat collection")
    parser.add_argument('--stats_steps',
                        type=int,
                        default=1000000,
                        help="total number of curvature stats collections")

    parser.add_argument('--full_batch',
                        type=int,
                        default=0,
                        help='do stats on the whole dataset')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=2e-5)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--dropout', type=int, default=0)
    parser.add_argument('--swa', type=int, default=0)
    parser.add_argument('--lmb', type=float, default=1e-3)

    parser.add_argument('--train_batch_size', type=int, default=64)
    parser.add_argument('--stats_batch_size', type=int, default=10000)
    parser.add_argument('--uniform',
                        type=int,
                        default=0,
                        help='use uniform architecture (all layers same size)')
    parser.add_argument('--run_name', type=str, default='noname')

    gl.args = parser.parse_args()
    args = gl.args
    u.seed_random(1)

    gl.project_name = 'train_ciresan'
    u.setup_logdir_and_event_writer(args.run_name)
    print(f"Logging to {gl.logdir}")

    # Layer widths: 784 flattened input pixels -> hidden layers -> 10 classes.
    if args.uniform:
        d = [784, 784, 784, 784, 784, 784, 10]
    else:
        d = [784, 2500, 2000, 1500, 1000, 500, 10]
    model = u.SimpleFullyConnected2(d,
                                    nonlin=args.nonlin,
                                    bias=args.bias,
                                    dropout=args.dropout)
    model = model.to(gl.device)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum)
    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          original_targets=True,
                          dataset_size=args.dataset_size)
    # Converts the running example count into an epoch number. Derived from the dataset rather
    # than hard-coded to 60000 so it stays correct when --dataset_size is smaller.
    examples_per_epoch = len(dataset)
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    assert not args.full_batch, "fixme: validation still uses stats_iter"
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.stats_batch_size,
            shuffle=False,
            drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)
    else:
        stats_iter = None

    test_dataset = u.TinyMNIST(data_width=args.data_width,
                               targets_width=args.targets_width,
                               train=False,
                               original_targets=True,
                               dataset_size=args.dataset_size)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.stats_batch_size,
                                              shuffle=False,
                                              drop_last=False)

    loss_fn = torch.nn.CrossEntropyLoss()
    # Install the autograd_lib hooks once; they are only switched on during stats collection.
    autograd_lib.add_hooks(model)
    autograd_lib.disable_hooks()

    gl.token_count = 0
    last_outer = 0
    for step in range(args.stats_steps):
        epoch = gl.token_count // examples_per_epoch
        print(gl.token_count)
        if last_outer:
            u.log_scalars(
                {"time/outer": 1000 * (time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        # optional stochastic weight averaging: 100 extra steps, then swap in averaged weights
        if args.swa:
            model.eval()
            with u.timeit('swa'):
                base_opt = torch.optim.SGD(model.parameters(),
                                           lr=args.lr,
                                           momentum=args.momentum)
                opt = torchcontrib.optim.SWA(base_opt,
                                             swa_start=0,
                                             swa_freq=1,
                                             swa_lr=args.lr)
                for _ in range(100):
                    optimizer.zero_grad()
                    data, targets = next(train_iter)
                    model.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, targets)
                    loss.backward()
                    opt.step()
                opt.swap_swa_sgd()

        # compute validation loss
        # NOTE(review): when --swa is off the model is still in train mode here (see
        # model.train() below); confirm validate() switches to eval mode if dropout is used.
        with u.timeit("validate"):
            val_accuracy, val_loss = validate(model, test_loader,
                                              f'test (epoch {epoch})')
            train_accuracy, train_loss = validate(model, stats_loader,
                                                  f'train (epoch {epoch})')

        # save log
        metrics = {
            'epoch': epoch,
            'val_accuracy': val_accuracy,
            'val_loss': val_loss,
            'train_loss': train_loss,
            'train_accuracy': train_accuracy,
            'lr': optimizer.param_groups[0]['lr'],
            'momentum': optimizer.param_groups[0].get('momentum', 0)
        }
        u.log_scalars(metrics)

        # compute stats
        if args.full_batch:
            data, targets = dataset.data, dataset.targets
        else:
            data, targets = next(stats_iter)

        if not args.skip_stats:
            autograd_lib.enable_hooks()
            autograd_lib.clear_backprops(model)
            autograd_lib.clear_hess_backprops(model)
            with u.timeit("backprop_g"):
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward(retain_graph=True)
            with u.timeit("backprop_H"):
                autograd_lib.backprop_hess(output, hess_type='CrossEntropy')
            autograd_lib.disable_hooks()  # TODO(y): use remove_hooks

            with u.timeit("compute_grad1"):
                autograd_lib.compute_grad1(model)
            with u.timeit("compute_hess"):
                autograd_lib.compute_hess(model,
                                          method='kron',
                                          attr_name='hess2')
            autograd_lib.compute_stats_factored(model)

        # Log whatever per-parameter stats are present (none unless stats were collected).
        for layer in model.layers:
            for param_name, param in (("weight", layer.weight), ("bias", layer.bias)):
                if param is None or not hasattr(param, 'stats'):
                    continue
                # Prefix with the layer name so stats from different layers don't overwrite
                # each other under a shared "weight"/"bias" key.
                u.log_scalars(u.nest_stats(f"{layer.name}={param_name}", param.stats))

        # gradient steps
        model.train()
        last_inner = 0
        for _ in range(args.train_steps):
            if last_inner:
                u.log_scalars(
                    {"time/inner": 1000 * (time.perf_counter() - last_inner)})
            last_inner = time.perf_counter()

            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()

            optimizer.step()
            if args.weight_decay:
                # multiplicative weight decay applied after the SGD step
                for group in optimizer.param_groups:
                    for param in group['params']:
                        param.data.mul_(1 - args.weight_decay)

            gl.token_count += data.shape[0]

    gl.event_writer.close()
예제 #5
0
def main():
    """Train a small fully-connected MNIST model and log per-layer curvature statistics.

    Reads its configuration from the module-level ``args``. Every outer step
    (``args.stats_steps`` in total):
      * logs the test loss and accuracy,
      * unless ``args.skip_stats``, computes gradient-noise and Hessian statistics for each
        eligible layer on a fixed stats batch and logs them along with eigen-spectra,
      * runs ``args.train_steps`` SGD steps with optional multiplicative weight decay.
    """

    u.install_pdb_handler()
    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {logdir}")

    loss_type = 'CrossEntropy'

    d1 = args.data_width ** 2
    args.stats_batch_size = min(args.stats_batch_size, args.dataset_size)
    args.train_batch_size = min(args.train_batch_size, args.dataset_size)
    n = args.stats_batch_size
    o = 10
    d = [d1, 60, 60, 60, o]

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin, last_layer_linear=True)
    model = model.to(gl.device)
    u.mark_expensive(model.layers[0])    # to stop grad1/hess calculations on this layer
    print(model)

    try:
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name, dir='/tmp/wandb.runs')
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['n'] = n
    except Exception as e:
        # wandb logging is best-effort; training proceeds without it
        print(f"wandb crash with {e}")

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    dataset = u.TinyMNIST(data_width=args.data_width, dataset_size=args.dataset_size, loss_type=loss_type)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    # Stats are always computed on the same first batch.
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)
    stats_data, stats_targets = next(stats_iter)

    test_dataset = u.TinyMNIST(data_width=args.data_width, train=False, dataset_size=args.dataset_size, loss_type=loss_type)
    test_batch_size = min(args.dataset_size, 1000)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:   # loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    # Keep hooks off outside the stats pass (which calls enable_hooks itself). add_hooks appears
    # to leave them enabled; without this, every training step would also run the hooks,
    # including when args.skip_stats is set.
    autograd_lib.disable_hooks()
    gl.reset_global_step()
    last_outer = 0
    val_losses = []
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars({"time/outer": 1000*(time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            with torch.no_grad():  # only the scalar value is needed; skip graph construction
                test_output = model(test_data)
                val_loss = loss_fn(test_output, test_targets)
            print("val_loss", val_loss.item())
            val_losses.append(val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        with u.timeit("validate"):
            if loss_type == 'CrossEntropy':
                val_accuracy, val_loss = validate(model, test_loader, f'test (stats_step {step})')

                metrics = {'stats_step': step, 'val_accuracy': val_accuracy, 'val_loss': val_loss}
                u.log_scalars(metrics)

        data, targets = stats_data, stats_targets

        if not args.skip_stats:
            # Capture Hessian and gradient stats
            autograd_lib.enable_hooks()
            autograd_lib.clear_backprops(model)
            autograd_lib.clear_hess_backprops(model)
            with u.timeit("backprop_g"):
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward(retain_graph=True)
            with u.timeit("backprop_H"):
                autograd_lib.backprop_hess(output, hess_type=loss_type)
            autograd_lib.disable_hooks()   # TODO(y): use remove_hooks

            with u.timeit("compute_grad1"):
                autograd_lib.compute_grad1(model)
            with u.timeit("compute_hess"):
                autograd_lib.compute_hess(model)

            for (i, layer) in enumerate(model.layers):

                if hasattr(layer, 'expensive'):
                    continue

                # input/output layers are unreasonably expensive if not using Kronecker factoring
                if d[i]*d[i+1] > 8000:
                    print(f'layer {i} is too big ({d[i],d[i+1]}), skipping stats')
                    continue

                for short_name, param in (("weight", layer.weight), ("bias", layer.bias)):
                    s = AttrDefault(str, {})  # dictionary-like object for layer stats

                    #############################
                    # Gradient stats
                    #############################
                    A_t = layer.activations
                    B_t = layer.backprops_list[0] * n
                    s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel()  # proportion of activations that are zero
                    s.mean_activation = torch.mean(A_t)
                    s.mean_backprop = torch.mean(B_t)

                    # empirical Fisher
                    G = param.grad1.reshape((n, -1))     # one flattened per-example gradient per row
                    g = G.mean(dim=0, keepdim=True)      # mean gradient as a row vector

                    u.nan_check(G)
                    with u.timeit(f'sigma-{i}'):
                        efisher = G.t() @ G / n
                        sigma = efisher - g.t() @ g     # gradient covariance
                        s.sigma_l2 = u.sym_l2_norm(sigma)
                        s.sigma_erank = torch.trace(sigma)/s.sigma_l2

                    H = param.hess
                    lambda_regularizer = args.lmb * torch.eye(H.shape[0], dtype=H.dtype, device=H.device)
                    u.nan_check(H)

                    with u.timeit(f"invH-{i}"):
                        # cholesky_inverse takes the Cholesky factor, not the matrix itself;
                        # factor the regularized Hessian first.
                        invH = torch.cholesky_inverse(torch.linalg.cholesky(H + lambda_regularizer))

                    with u.timeit(f"H_l2-{i}"):
                        s.H_l2 = u.sym_l2_norm(H)
                        s.iH_l2 = u.sym_l2_norm(invH)

                    with u.timeit(f"norms-{i}"):
                        s.H_fro = H.flatten().norm()
                        s.iH_fro = invH.flatten().norm()
                        s.grad_fro = g.flatten().norm()
                        s.param_fro = param.data.flatten().norm()

                    def loss_direction(dd: torch.Tensor, eps):
                        """loss improvement if we take step eps in direction dd"""
                        return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps ** 2 * dd @ H @ dd.t())

                    def curv_direction(dd: torch.Tensor):
                        """Curvature in direction dd"""
                        return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm() ** 2))

                    with u.timeit(f"pinvH-{i}"):
                        pinvH = u.pinv(H)

                    with u.timeit(f'curv-{i}'):
                        s.grad_curv = curv_direction(g)  # curvature (eigenvalue) in direction g
                        ndir = g @ pinvH  # newton direction
                        s.newton_curv = curv_direction(ndir)
                        # Save the Newton preconditioner on the parameter it was computed for
                        # (previously the bias pass overwrote the weight's preconditioner).
                        param.pre = pinvH
                        s.step_openai = 1 / s.grad_curv if s.grad_curv else 1234567
                        s.step_div_inf = 2 / s.H_l2         # divergent step size for batch_size=infinity
                        s.step_div_1 = torch.tensor(2) / torch.trace(H)   # divergent step for batch_size=1

                        s.newton_fro = ndir.flatten().norm()  # frobenius norm of Newton update
                        s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)   # replace with "quadratic_form"
                        s.regret_gradient = loss_direction(g, s.step_openai)

                    with u.timeit(f'rho-{i}'):
                        s.rho, s.lyap_erank, lyap_evals = u.truncated_lyapunov_rho(H, sigma)
                        s.step_div_1_adjusted = s.step_div_1/s.rho

                    with u.timeit(f"batch-{i}"):
                        s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                        s.diversity = torch.norm(G, "fro") ** 2 / torch.norm(g) ** 2 / n  # Gradient diversity / n
                        s.noise_variance_pinv = torch.trace(pinvH @ sigma)
                        s.H_erank = torch.trace(H) / s.H_l2
                        s.batch_jain_simple = 1 + s.H_erank
                        s.batch_jain_full = 1 + s.rho * s.H_erank

                    param_name = f"{layer.name}={short_name}"
                    u.log_scalars(u.nest_stats(f"{param_name}", s))

                    H_evals = u.symeig_pos_evals(H)
                    sigma_evals = u.symeig_pos_evals(sigma)
                    u.log_spectrum(f'{param_name}/hess', H_evals)
                    u.log_spectrum(f'{param_name}/sigma', sigma_evals)
                    u.log_spectrum(f'{param_name}/lyap', lyap_evals)

        # gradient steps
        with u.timeit('inner'):
            for _ in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                optimizer.step()
                if args.weight_decay:
                    # multiplicative weight decay applied after the SGD step
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            param.data.mul_(1-args.weight_decay)

                gl.increment_global_step(data.shape[0])

    gl.event_writer.close()
예제 #6
0
def _test_kron_conv_golden():
    """Hardcoded error values to detect unexpected numeric changes.

    Builds a small ``u.PooledConvolutional2`` network with a cross-entropy loss. For every layer
    it first checks the direct ("exact") Hessian from ``autograd_lib`` against PyTorch autograd
    (weights and biases). It then measures ``u.symsqrt_dist`` between the exact Hessian and two
    Kronecker-factored approximations ('kron' and 'mean_kron'). The resulting distance lists
    (one weight entry and one bias entry per layer) must match the hardcoded golden values.

    The golden values depend on everything drawn from the RNG after ``u.seed_random(1)``
    (model construction presumably, ``torch.randn`` data, random targets), so the order of
    these statements matters.

    NOTE(review): the leading underscore presumably keeps pytest from collecting this test;
    confirm whether it is intentionally disabled.
    """
    u.seed_random(1)

    n, Xh, Xw = 2, 8, 8
    Kh, Kw = 2, 2
    dd = [3, 3, 3, 3]
    o = dd[-1]   # number of target classes (targets are drawn from [0, o))

    model: u.SimpleModel = u.PooledConvolutional2(dd, kernel_size=(Kh, Kw), nonlin=False, bias=True)
    data = torch.randn((n, dd[0], Xh, Xw))

    # print(model)
    # print(data)

    loss_type = 'CrossEntropy'    #  loss_type = 'LeastSquares'
    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:    # CrossEntropy
        loss_fn = nn.CrossEntropyLoss()

    # Forward pass before hooks are installed, only to learn the output shape for targets.
    sample_output = model(data)

    if loss_type.endswith('LeastSquares'):
        targets = torch.randn(sample_output.shape)
    elif loss_type == 'CrossEntropy':
        targets = torch.LongTensor(n).random_(0, o)

    # Hessian backprop through the hooked model; each compute_hess call stores its result on the
    # parameters under attr_name ('exact' results are read back from `.hess` below).
    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.compute_hess(model, method='kron', attr_name='hess_kron')
    autograd_lib.compute_hess(model, method='mean_kron', attr_name='hess_mean_kron')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    errors1 = []   # distances exact vs 'kron', [weight, bias] per layer
    errors2 = []   # distances exact vs 'mean_kron', [weight, bias] per layer
    for i in range(len(model.layers)):
        layer = model.layers[i]

        # direct Hessian computation
        H = layer.weight.hess
        H_bias = layer.bias.hess

        # factored Hessian computation
        Hk = layer.weight.hess_kron
        Hk_bias = layer.bias.hess_kron
        Hk = Hk.expand()
        Hk_bias = Hk_bias.expand()

        Hk2 = layer.weight.hess_mean_kron
        Hk2_bias = layer.bias.hess_mean_kron
        Hk2 = Hk2.expand()
        Hk2_bias = Hk2_bias.expand()

        # autograd Hessian computation
        loss = loss_fn(output, targets)
        Ha = u.hessian(loss, layer.weight).reshape(H.shape)
        Ha_bias = u.hessian(loss, layer.bias)

        # compare direct against autograd
        Ha = Ha.reshape(H.shape)   # no-op: Ha was already reshaped above
        # rel_error = torch.max((H-Ha)/Ha)

        u.check_close(H, Ha, rtol=1e-5, atol=1e-7)
        u.check_close(Ha_bias, H_bias, rtol=1e-5, atol=1e-7)

        errors1.extend([u.symsqrt_dist(H, Hk), u.symsqrt_dist(H_bias, Hk_bias)])
        errors2.extend([u.symsqrt_dist(H, Hk2), u.symsqrt_dist(H_bias, Hk2_bias)])

    errors1 = torch.tensor(errors1)
    errors2 = torch.tensor(errors2)
    golden_errors1 = torch.tensor([0.09458080679178238, 0.0, 0.13416489958763123, 0.0, 0.0003909761435352266, 0.0])
    golden_errors2 = torch.tensor([0.0945773795247078, 0.0, 0.13418318331241608, 0.0, 4.478318658129865e-07, 0.0])

    u.check_close(golden_errors1, errors1)
    u.check_close(golden_errors2, errors2)
예제 #7
0
def test_kron_conv_exact():
    """Test Kronecker-factored Hessian computation for conv layers.

    Kronecker factoring is exact for 1x1 convolutions and linear activations, so for every
    layer the 'mean_kron' factored Hessian must equal the direct ("exact") Hessian, which in
    turn must equal the Hessian computed by PyTorch autograd (weights and biases).
    """
    u.seed_random(1)

    n, Xh, Xw = 2, 2, 2
    Kh, Kw = 1, 1
    dd = [2, 2, 2]
    o = dd[-1]   # number of target classes (targets are drawn from [0, o))

    model: u.SimpleModel = u.PooledConvolutional2(dd, kernel_size=(Kh, Kw), nonlin=False, bias=True)
    data = torch.randn((n, dd[0], Xh, Xw))

    loss_type = 'CrossEntropy'    #  loss_type = 'LeastSquares'
    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'DebugLeastSquares':
        loss_fn = u.debug_least_squares
    else:    # CrossEntropy
        loss_fn = nn.CrossEntropyLoss()

    # Forward pass before hooks are installed, only to learn the output shape for targets.
    sample_output = model(data)

    if loss_type.endswith('LeastSquares'):
        targets = torch.randn(sample_output.shape)
    elif loss_type == 'CrossEntropy':
        targets = torch.LongTensor(n).random_(0, o)

    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    # Store the factored result under an explicit name. With the default attr_name it could
    # land in the same attribute as the 'exact' result computed next (read back from `.hess`)
    # and be overwritten before the comparison.
    autograd_lib.compute_hess(model, method='mean_kron', attr_name='hess_factored')
    autograd_lib.compute_hess(model, method='exact')
    autograd_lib.disable_hooks()

    for layer in model.layers:
        # direct Hessian computation
        H = layer.weight.hess
        H_bias = layer.bias.hess

        # factored Hessian computation, expanded to a dense matrix
        Hk = layer.weight.hess_factored.expand()
        Hk_bias = layer.bias.hess_factored.expand()

        # autograd Hessian computation
        loss = loss_fn(output, targets)
        Ha = u.hessian(loss, layer.weight).reshape(H.shape)
        Ha_bias = u.hessian(loss, layer.bias)

        # compare direct against autograd
        u.check_close(H, Ha, rtol=1e-5, atol=1e-7)
        u.check_close(Ha_bias, H_bias, rtol=1e-5, atol=1e-7)

        # factoring is exact in this setting
        u.check_close(H_bias, Hk_bias)
        u.check_close(H, Hk)
예제 #8
0
def test_hessian_multibatch():
    """Test that Kronecker-factored computations still work when splitting work over batches.

    For a single bias-free linear layer with cross-entropy loss on ``n`` examples:
      1. the direct Hessian from ``autograd_lib.compute_hess`` must match PyTorch autograd,
      2. the 'kron' factored Hessian (vecr order) must approximately match it (loose tolerance,
         since the cross-entropy Hessian depends on the examples and does not factor exactly),
      3. ``autograd_lib.compute_cov`` accumulated over two half-size batches must reproduce the
         same factored Hessian (after ``commute()`` back to AA x BB order).
    """

    u.seed_random(1)

    # torch.set_default_dtype(torch.float64)

    gl.project_name = 'test'
    gl.logdir_base = '/tmp/runs'
    run_name = 'test_hessian_multibatch'
    u.setup_logdir_and_event_writer(run_name=run_name)

    loss_type = 'CrossEntropy'
    data_width = 2
    n = 4
    d1 = data_width ** 2
    o = 10
    d = [d1, o]

    model = u.SimpleFullyConnected2(d, bias=False, nonlin=False)
    model = model.to(gl.device)

    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=n, shuffle=False)
    # NOTE(review): this iterator is never used; it is recreated below before the first next().
    stats_iter = u.infinite_iter(stats_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:  # loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()
    last_outer = 0   # NOTE(review): unused in this test

    stats_iter = u.infinite_iter(stats_loader)
    stats_data, stats_targets = next(stats_iter)
    data, targets = stats_data, stats_targets

    # Capture Hessian and gradient stats
    autograd_lib.enable_hooks()
    autograd_lib.clear_backprops(model)

    output = model(data)
    loss = loss_fn(output, targets)
    loss.backward(retain_graph=True)   # graph kept alive for u.hessian(loss, ...) below
    layer = model.layers[0]

    autograd_lib.clear_hess_backprops(model)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.disable_hooks()

    # compute Hessian using direct method, compare against PyTorch autograd
    hess0 = u.hessian(loss, layer.weight)
    autograd_lib.compute_hess(model)
    hess1 = layer.weight.hess
    u.check_close(hess0.reshape(hess1.shape), hess1, atol=1e-8, rtol=1e-6)

    # compute Hessian using factored method. Because Hessian depends on examples for cross entropy, factoring is not exact, raise tolerance
    autograd_lib.compute_hess(model, method='kron', attr_name='hess2', vecr_order=True)
    hess2 = layer.weight.hess2
    u.check_close(hess1, hess2, atol=1e-3, rtol=1e-1)

    # compute Hessian using multibatch
    # restart iterators
    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    assert n % 2 == 0
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=n//2, shuffle=False)
    stats_iter = u.infinite_iter(stats_loader)
    # accumulate over two half-size batches covering the same n examples
    autograd_lib.compute_cov(model, loss_fn, stats_iter, batch_size=n//2, steps=2)

    cov: autograd_lib.LayerCov = layer.cov
    hess2: u.Kron = hess2.commute()    # get back into AA x BB order
    u.check_close(cov.H.value(), hess2)
예제 #9
0
def test_factored_hessian():
    """Simple test to ensure Hessian computation is working.

    In a linear neural network with squared loss, Newton step will converge in one step.
    The test checks that:
      * the direct Hessian from ``autograd_lib`` matches PyTorch autograd,
      * the Kronecker-factored Hessian (vecr / row-major order) matches the direct one,
      * Hessian-gradient products and Newton regrets agree between the regular and the
        factored notation,
      * one Newton step on the (quadratic) least-squares loss does not increase the loss.
    """

    u.seed_random(1)
    loss_type = 'LeastSquares'

    data_width = 2
    n = 5
    d1 = data_width ** 2
    o = 10
    d = [d1, o]

    model = u.SimpleFullyConnected2(d, bias=False, nonlin=False)
    model = model.to(gl.device)
    print(model)

    dataset = u.TinyMNIST(data_width=data_width, dataset_size=n, loss_type=loss_type)
    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=n, shuffle=False)
    stats_iter = u.infinite_iter(stats_loader)
    stats_data, stats_targets = next(stats_iter)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:  # loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()

    data, targets = stats_data, stats_targets

    # Capture Hessian and gradient stats
    autograd_lib.enable_hooks()
    autograd_lib.clear_backprops(model)

    output = model(data)
    loss = loss_fn(output, targets)
    print(loss)
    loss.backward(retain_graph=True)   # graph kept alive for u.hessian(loss, ...) below
    layer = model.layers[0]

    autograd_lib.clear_hess_backprops(model)
    autograd_lib.backprop_hess(output, hess_type=loss_type)
    autograd_lib.disable_hooks()

    # compute Hessian using direct method, compare against PyTorch autograd
    hess0 = u.hessian(loss, layer.weight)
    autograd_lib.compute_hess(model)
    hess1 = layer.weight.hess
    print(hess1)
    u.check_close(hess0.reshape(hess1.shape), hess1, atol=1e-9, rtol=1e-6)

    # compute Hessian using factored method
    autograd_lib.compute_hess(model, method='kron', attr_name='hess2', vecr_order=True)
    hess2 = layer.weight.hess2
    u.check_close(hess1, hess2, atol=1e-9, rtol=1e-6)

    # Hessian-gradient product in regular vs factored (vecr) notation
    g1 = layer.weight.grad.flatten()
    newton1 = hess1 @ g1

    g2 = u.Vecr(layer.weight.grad)
    newton2 = g2 @ hess2

    u.check_close(newton1, newton2, atol=1e-9, rtol=1e-6)

    # compute regret in factored notation, compare against regular notation
    regret1 = g1 @ hess1.pinverse() @ g1 / 2
    regret2 = g2 @ hess2.pinv() @ g2 / 2
    u.check_close(regret1, regret2)

    # Take one Newton step. hess1 is indexed like layer.weight.flatten() (it was checked against
    # autograd's Hessian reshaped that way), so the gradient is flattened in the same row-major
    # order. The previous version paired it with u.vec(grad), presumably column-stacked, which
    # would apply the inverse Hessian to permuted coordinates.
    param: torch.nn.Parameter = layer.weight
    loss_before = loss.item()
    newton_step = hess1.pinverse() @ g1
    param.data.sub_(newton_step.reshape(param.shape))
    output = model(data)
    loss = loss_fn(output, targets)
    print("result 4", loss)

    # Least squares is quadratic in the weights of a linear model, so a full Newton step lands
    # on a minimizer and cannot increase the loss.
    assert loss.item() <= loss_before + 1e-5, (loss_before, loss.item())
예제 #10
0
def test_factored_stats_golden_values():
    """Check factored curvature stats against golden values of the non-factored version.

    Runs one stats step of a bias-free fully connected model (built with
    ``nonlin=0``) on a 5-example TinyMNIST subset with least-squares loss.
    It computes per-parameter stats with ``autograd_lib.compute_stats_factored``
    and compares every stat against the values stored in ``test/factored.pt``.
    Some stats are compared with loose tolerances or skipped; the reason is
    noted next to each name in the comparison loop.
    """
    u.seed_random(1)
    u.install_pdb_handler()
    torch.set_default_dtype(torch.float32)

    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    # Parse an explicit empty list. Under pytest, sys.argv holds pytest's own
    # options, which this parser does not define and would reject with SystemExit.
    args = parser.parse_args([])

    logdir = u.create_local_logdir('/temp/runs/factored_test')
    gl.event_writer = SummaryWriter(logdir)
    print('logging to ', logdir)

    # Close the writer even when a check below fails, so no events are lost.
    try:
        loss_type = 'LeastSquares'

        args.data_width = 2
        args.dataset_size = 5
        args.stats_batch_size = args.dataset_size  # one stats batch covers the whole dataset
        args.stats_steps = 1
        d1 = args.data_width**2

        o = 10
        d = [d1, o]

        model = u.SimpleFullyConnected2(d, bias=False, nonlin=0)
        model = model.to(gl.device)
        print(model)

        dataset = u.TinyMNIST(data_width=args.data_width,
                              dataset_size=args.dataset_size,
                              loss_type=loss_type)
        stats_loader = torch.utils.data.DataLoader(
            dataset, batch_size=args.stats_batch_size, shuffle=False)
        stats_iter = u.infinite_iter(stats_loader)
        stats_data, stats_targets = next(stats_iter)

        if loss_type == 'LeastSquares':
            loss_fn = u.least_squares
        else:  # loss_type == 'CrossEntropy':
            loss_fn = nn.CrossEntropyLoss()

        # Golden values do not change between steps, so load them once.
        golden_values = torch.load('test/factored.pt')

        autograd_lib.add_hooks(model)
        gl.reset_global_step()
        last_outer = 0
        for step in range(args.stats_steps):
            if last_outer:
                u.log_scalars(
                    {"time/outer": 1000 * (time.perf_counter() - last_outer)})
            last_outer = time.perf_counter()

            data, targets = stats_data, stats_targets

            # Capture Hessian and gradient stats
            autograd_lib.enable_hooks()
            autograd_lib.clear_backprops(model)
            with u.timeit("backprop_g"):
                output = model(data)
                loss = loss_fn(output, targets)
                # Keep the graph: backprop_hess below reuses `output`.
                loss.backward(retain_graph=True)

            autograd_lib.clear_hess_backprops(model)
            with u.timeit("backprop_H"):
                autograd_lib.backprop_hess(output, hess_type=loss_type)
            autograd_lib.disable_hooks()  # TODO(y): use remove_hooks

            with u.timeit("compute_grad1"):
                autograd_lib.compute_grad1(model)
            with u.timeit("compute_hess"):
                autograd_lib.compute_hess(model)
                autograd_lib.compute_hess(model, method='kron', attr_name='hess2')

            autograd_lib.compute_stats_factored(model)

            params = list(model.parameters())
            assert len(params) == 1
            new_values = params[0].stats

            for valname in new_values:
                print("Checking ", valname)
                if valname == 'sigma_l2':
                    u.check_close(new_values[valname],
                                  golden_values[valname],
                                  atol=1e-2,
                                  label=valname)  # sigma is approximate
                elif valname == 'sigma_erank':
                    u.check_close(new_values[valname],
                                  golden_values[valname],
                                  atol=0.11,
                                  label=valname)  # 1.0 vs 1.1
                elif valname in ['rho', 'step_div_1_adjusted', 'batch_jain_full']:
                    continue  # lyapunov stats weren't computed correctly in golden set
                elif valname in ['batch_openai']:
                    continue  # batch sizes depend on sigma which is approximate
                elif valname in ['noise_variance_pinv']:
                    continue  # went from 0.22 to 0.014 after kron factoring (0.01 with full centering, 0.3 with no centering)
                elif valname in ['sparsity']:
                    continue  # had a bug in old calc (using integer arithmetic)
                else:
                    u.check_close(new_values[valname],
                                  golden_values[valname],
                                  rtol=1e-4,
                                  atol=1e-6,
                                  label=valname)
    finally:
        gl.event_writer.close()
# Example #11
# 0
def compute_hess(n: int = 1,
                 image_size: int = 1,
                 kernel_size: int = 1,
                 num_channels: int = 1,
                 num_layers: int = 1,
                 nonlin: bool = False,
                 loss: str = 'CrossEntropy',
                 method='exact',
                 param_name='weight') -> List[torch.Tensor]:
    """Return one Hessian per layer for a small pooled-convolutional model.

    Builds a ``u.PooledConvolutional2`` network with ``num_layers`` layers.
    The input and every layer have ``num_channels`` channels, and the kernels
    are 1 x ``kernel_size``. The network is fed a random batch of shape
    (n, num_channels, 1, image_size). The Hessian of ``loss`` is then
    backpropagated through it, and the result for one parameter kind of
    every layer is collected.

    Args:
        n: number of examples in the random batch
        image_size: width of the input (its height is fixed at 1)
        kernel_size: width of the convolution kernel (its height is fixed at 1)
        num_channels: channel count used for the input and every layer
        num_layers: number of layers in the network
        nonlin: forwarded unchanged to ``u.PooledConvolutional2``
        loss: Hessian/loss type passed to ``autograd_lib.backprop_hess``;
            must be in ``autograd_lib._supported_losses``
        method: Hessian method, must be in ``autograd_lib._supported_methods``.
            For 'exact' and 'autograd' the result is read from ``param.hess``;
            for any other method it is read from ``param.hess_factored.expand()``.
        param_name: which parameter to collect, 'weight' or 'bias'

    Returns:
        list with one Hessian matrix per layer, in layer order.
    """
    assert param_name in ['weight', 'bias']
    assert loss in autograd_lib._supported_losses
    assert method in autograd_lib._supported_methods

    u.seed_random(1)

    # Same channel count at every boundary: input plus one entry per layer.
    channels = [num_channels] * (num_layers + 1)
    model: u.SimpleModel2 = u.PooledConvolutional2(channels,
                                                   kernel_size=(1, kernel_size),
                                                   nonlin=nonlin,
                                                   bias=True)
    data = torch.randn((n, channels[0], 1, image_size))

    # Hooks record activations/backprops needed by compute_hess.
    autograd_lib.clear_backprops(model)
    autograd_lib.add_hooks(model)
    output = model(data)
    autograd_lib.backprop_hess(output, hess_type=loss)
    autograd_lib.compute_hess(model, method=method)
    autograd_lib.disable_hooks()

    params = [getattr(layer, param_name) for layer in model.layers]
    if method in ('exact', 'autograd'):
        return [p.hess for p in params]
    # Factored methods store a factored representation; expand it to a dense matrix.
    return [p.hess_factored.expand() for p in params]
def main():
    """Train a small fully connected net on TinyMNIST while logging curvature stats.

    Configuration is read from the module-level ``args`` namespace. Each of the
    ``args.stats_steps`` outer iterations does three things:

    1. Computes the loss on one test batch. For cross-entropy it also runs
       ``validate`` over the test loader.
    2. Unless ``args.skip_stats`` is set, backpropagates gradients and the loss
       Hessian on a fixed stats batch, then logs the factored per-parameter
       stats of every layer under a layer-specific tag.
    3. Takes ``args.train_steps`` SGD steps, optionally shrinking the weights
       by ``args.weight_decay`` after each step.
    """
    u.install_pdb_handler()
    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {logdir}")

    # Close the event writer even if training crashes, so buffered events are flushed.
    try:
        loss_type = 'CrossEntropy'

        d1 = args.data_width ** 2
        # Clamp batch sizes to the dataset size: with drop_last=True a larger
        # batch size would make the loaders below yield no batches at all.
        args.stats_batch_size = min(args.stats_batch_size, args.dataset_size)
        args.train_batch_size = min(args.train_batch_size, args.dataset_size)
        n = args.stats_batch_size
        o = 10
        d = [d1, 60, 60, 60, o]

        model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin, last_layer_linear=True)
        model = model.to(gl.device)
        u.mark_expensive(model.layers[0])    # to stop grad1/hess calculations on this layer
        print(model)

        # wandb logging is best-effort: a failure here is reported but does not stop training.
        try:
            if args.wandb:
                wandb.init(project='curv_train_tiny', name=run_name, dir='/tmp/wandb.runs')
                wandb.tensorboard.patch(tensorboardX=False)
                wandb.config['train_batch'] = args.train_batch_size
                wandb.config['stats_batch'] = args.stats_batch_size
                wandb.config['n'] = n
        except Exception as e:
            print(f"wandb crash with {e}")

        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
        dataset = u.TinyMNIST(data_width=args.data_width, dataset_size=args.dataset_size, loss_type=loss_type)

        train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
        train_iter = u.infinite_iter(train_loader)

        # The stats batch is drawn once and reused at every outer step.
        stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)
        stats_data, stats_targets = next(stats_iter)

        test_dataset = u.TinyMNIST(data_width=args.data_width, train=False, dataset_size=args.dataset_size, loss_type=loss_type)
        test_batch_size = min(args.dataset_size, 1000)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, drop_last=True)
        test_iter = u.infinite_iter(test_loader)

        if loss_type == 'LeastSquares':
            loss_fn = u.least_squares
        else:   # loss_type == 'CrossEntropy':
            loss_fn = nn.CrossEntropyLoss()

        autograd_lib.add_hooks(model)
        gl.reset_global_step()
        last_outer = 0
        val_losses = []
        for step in range(args.stats_steps):
            if last_outer:
                u.log_scalars({"time/outer": 1000*(time.perf_counter() - last_outer)})
            last_outer = time.perf_counter()

            with u.timeit("val_loss"):
                test_data, test_targets = next(test_iter)
                test_output = model(test_data)
                val_loss = loss_fn(test_output, test_targets)
                print("val_loss", val_loss.item())
                val_losses.append(val_loss.item())
                u.log_scalar(val_loss=val_loss.item())

            with u.timeit("validate"):
                if loss_type == 'CrossEntropy':
                    # Separate name so the single-batch val_loss above is not shadowed.
                    val_accuracy, full_val_loss = validate(model, test_loader, f'test (stats_step {step})')
                    metrics = {'stats_step': step, 'val_accuracy': val_accuracy, 'val_loss': full_val_loss}
                    u.log_scalars(metrics)

            data, targets = stats_data, stats_targets

            if not args.skip_stats:
                # Capture Hessian and gradient stats
                autograd_lib.enable_hooks()
                autograd_lib.clear_backprops(model)
                autograd_lib.clear_hess_backprops(model)
                with u.timeit("backprop_g"):
                    output = model(data)
                    loss = loss_fn(output, targets)
                    # Keep the graph: backprop_hess below reuses `output`.
                    loss.backward(retain_graph=True)
                with u.timeit("backprop_H"):
                    autograd_lib.backprop_hess(output, hess_type=loss_type)
                autograd_lib.disable_hooks()   # TODO(y): use remove_hooks

                with u.timeit("compute_grad1"):
                    autograd_lib.compute_grad1(model)
                with u.timeit("compute_hess"):
                    autograd_lib.compute_hess(model, method='kron', attr_name='hess2')

                autograd_lib.compute_stats_factored(model)

                for i, layer in enumerate(model.layers):
                    for param_name in ('weight', 'bias'):
                        param = getattr(layer, param_name)
                        # Skip missing biases and params without stats (e.g. expensive layers).
                        if param is None or not hasattr(param, 'stats'):
                            continue
                        # Include the layer index so each layer gets its own tags
                        # instead of all layers overwriting "weight/..." and "bias/...".
                        u.log_scalars(u.nest_stats(f"layer-{i}/{param_name}", param.stats))

            # gradient steps
            with u.timeit('inner'):
                for _ in range(args.train_steps):
                    optimizer.zero_grad()
                    data, targets = next(train_iter)
                    model.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, targets)
                    loss.backward()

                    optimizer.step()
                    if args.weight_decay:
                        # Weight decay applied directly to the parameters after the SGD step.
                        for group in optimizer.param_groups:
                            for param in group['params']:
                                param.data.mul_(1-args.weight_decay)

                    gl.increment_global_step(data.shape[0])
    finally:
        gl.event_writer.close()