Exemplo n.º 1
0
def main():
    """Train a tiny fully connected net while logging per-layer curvature stats.

    Steps performed:

    1. Pick a fresh numbered log directory derived from ``args.logdir`` and
       open a ``SummaryWriter`` on it (stored in ``gl.event_writer``).
    2. Optionally initialise wandb (best-effort; failures are only printed).
    3. Build a ``u.SimpleFullyConnected`` model over ``u.TinyMNIST`` and
       register hooks that capture each layer's input activations, its
       backpropagated output gradients and the weight gradient.
    4. For each of ``args.stats_steps`` iterations: compute and log, for
       every layer, gradient / empirical-Fisher / Hessian statistics, then
       run ``args.train_steps`` optimisation steps. With
       ``args.method == 'newton'`` every layer takes a plain gradient step
       and layer index 1 additionally takes a Newton step preconditioned by
       the pseudo-inverse of the Hessian estimate saved during the stats
       phase; otherwise the SGD optimizer is used.
    """
    # Use the first "<args.logdir>NN" path that does not exist yet, so
    # earlier runs are never overwritten.
    attempt_count = 0
    while os.path.exists(f"{args.logdir}{attempt_count:02d}"):
        attempt_count += 1
    logdir = f"{args.logdir}{attempt_count:02d}"

    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")
    u.seed_random(1)

    # wandb logging is best-effort: a failure is reported and the run goes on.
    try:
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method

    except Exception as e:
        print(f"wandb crash with {e}")

    d1 = args.data_width**2
    d2 = 10
    d3 = args.targets_width**2
    o = d3  # number of model outputs
    n = args.stats_batch_size
    d = [d1, d2, d3]  # layer widths: input, hidden, output
    model = u.SimpleFullyConnected(d, nonlin=args.nonlin)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=False,
        drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.stats_batch_size,
        shuffle=False,
        drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)

    # Flags read by the hook closures below; they are reset at the start of
    # every stats step and flipped once the corresponding capture is done.
    skip_forward_hooks = False
    skip_backward_hooks = False

    def capture_activations(module, input, _output):
        """Forward hook: store the layer's (detached) input as 'activations'."""
        if skip_forward_hooks:
            return
        assert gl.backward_idx == 0  # no need to forward-prop on Hessian computation
        assert not hasattr(
            module, 'activations'
        ), "Seeing activations from previous forward, call util.zero_grad to clear"
        assert len(input) == 1, "this works for single input layers only"
        setattr(module, "activations", input[0].detach())

    def capture_backprops(module: nn.Module, _input, output):
        """Backward hook: append the output gradient to 'module.backprops'.

        Entry 0 holds the loss backprop; entry ``k + 1`` holds the backprop
        for output position ``k`` (indexed through ``gl.backward_idx``).
        """
        if skip_backward_hooks:
            return
        assert len(output) == 1, "this works for single variable layers only"
        if gl.backward_idx == 0:
            assert not hasattr(
                module, 'backprops'
            ), "Seeing results of previous autograd, call util.zero_grad to clear"
            setattr(module, 'backprops', [])
        assert gl.backward_idx == len(module.backprops)
        module.backprops.append(output[0])

    def save_grad(param: nn.Parameter) -> Callable[[torch.Tensor], None]:
        """Hook to save gradient into 'param.saved_grad', so it can be accessed after model.zero_grad(). Only stores gradient
        if the value has not been set, call util.zero_grad to clear it."""
        def save_grad_fn(grad):
            if not hasattr(param, 'saved_grad'):
                setattr(param, 'saved_grad', grad)

        return save_grad_fn

    for layer in model.layers:
        layer.register_forward_hook(capture_activations)
        layer.register_backward_hook(capture_backprops)
        layer.weight.register_hook(save_grad(layer.weight))

    def loss_fn(data, targets):
        """Half the summed squared error, averaged over the batch."""
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == len(targets)
        return torch.sum(err * err) / 2 / len(data)

    # Rows of the identity select one output position at a time when
    # backpropagating for the Hessian statistics.
    id_mat = torch.eye(o)

    gl.token_count = 0
    for step in range(args.stats_steps):
        data, targets = next(stats_iter)
        skip_forward_hooks = False
        skip_backward_hooks = False

        # get gradient values
        gl.backward_idx = 0
        u.zero_grad(model)
        output = model(data)
        loss = loss_fn(output, targets)
        loss.backward(retain_graph=True)

        print("loss", loss.item())

        # get Hessian values; activations are already captured, so forward
        # hooks stay disabled from here on
        skip_forward_hooks = True

        u.log_scalars({'loss': loss.item()})

        for out_idx in range(o):
            model.zero_grad()
            # backprop to get section of batch output jacobian for output at position out_idx
            output = model(
                data
            )  # opt: using autograd.grad means I don't have to zero_grad
            ei = id_mat[out_idx]
            bval = torch.stack([ei] * n)
            gl.backward_idx = out_idx + 1
            output.backward(bval)
        skip_backward_hooks = True  # all backprops for this step are captured

        for (i, layer) in enumerate(model.layers):
            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops[0] * n
            assert B_t.shape == (n, d[i + 1])

            G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]),
                              layer.weight.saved_grad)

            # empirical Fisher
            efisher = G.t() @ G / n
            sigma = efisher - g.t() @ g
            s.sigma_l2 = u.l2_norm(sigma)

            #############################
            # Hessian stats
            #############################
            Bh_t = [layer.backprops[out_idx + 1] for out_idx in range(o)]
            Amat_t = torch.cat([A_t] * o, dim=0)
            Bmat_t = torch.cat(Bh_t, dim=0)

            assert Amat_t.shape == (n * o, d[i])
            assert Bmat_t.shape == (n * o, d[i + 1])

            Jb = u.khatri_rao_t(Bmat_t,
                                Amat_t)  # batch Jacobian, in row-vec format
            H = Jb.t() @ Jb / n
            # Computed once and reused below for the Newton regret, the
            # Newton direction and the saved preconditioner.
            pinvH = u.pinv(H)

            s.hess_l2 = u.l2_norm(H)
            s.invhess_l2 = u.l2_norm(pinvH)

            s.hess_fro = H.flatten().norm()
            s.invhess_fro = pinvH.flatten().norm()

            s.jacobian_l2 = u.l2_norm(Jb)
            s.grad_fro = g.flatten().norm()
            s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1],
                                                d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) -
                                          0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() /
                                          dd.flatten().norm()**2)

            s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)
            s.grad_curv = curv_direction(g)
            ndir = g @ pinvH  # newton direction
            s.newton_curv = curv_direction(ndir)
            setattr(layer.weight, 'pre', pinvH)  # save Newton preconditioner
            # 999 is the fallback step size when the curvature along g is zero
            s.step_openai = 1 / s.grad_curv if s.grad_curv else 999

            s.newton_fro = ndir.flatten().norm(
            )  # frobenius norm of Newton update
            s.regret_gradient = loss_direction(g, s.step_openai)

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        for _ in range(args.train_steps):
            # The optimizer was built from model.parameters(), so this clears
            # every model gradient.
            optimizer.zero_grad()
            data, targets = next(train_iter)
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()

            u.log_scalar(train_loss=loss.item())

            if args.method != 'newton':
                optimizer.step()
            else:
                for (layer_idx, layer) in enumerate(model.layers):
                    param: torch.nn.Parameter = layer.weight
                    param_data: torch.Tensor = param.data
                    # every layer first takes a plain gradient step (lr 0.1)
                    param_data.copy_(param_data - 0.1 * param.grad)
                    if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                        continue
                    u.nan_check(layer.weight.pre)
                    u.nan_check(param.grad.flatten())
                    u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                    param_new_flat = u.v2r(param_data.flatten()) - u.v2r(
                        param.grad.flatten()) @ layer.weight.pre
                    u.nan_check(param_new_flat)
                    param_data.copy_(param_new_flat.reshape(param_data.shape))

            gl.token_count += data.shape[0]

    gl.event_writer.close()
Exemplo n.º 2
0
    def compute_layer_stats(layer):
        """Compute gradient and curvature statistics for one layer.

        Uses ``stats_data`` / ``stats_targets``, ``model``, ``args`` and
        ``compute_loss`` from the enclosing scope. If the layer is marked
        ``frozen`` it is unfrozen for the duration of the computation and
        frozen again afterwards, including when an error is raised.

        Args:
            layer: a member of ``model.layers``. It is expected to expose
                ``data_input``, ``data_output`` and ``grad_output`` after a
                forward/backward pass (presumably populated by hooks
                registered elsewhere -- TODO confirm).

        Returns:
            An ``AttrDefault`` holding the computed statistics
            (``diversity``, ``grad_fro``, ``param_fro``, ``a_sparsity``,
            ``a_magnitude``, ``hessian_l2``, ``jacobian_l2``, ``J1_l2``,
            ``regret_newton``, ``grad_curvature``, ``step_openai``,
            ``regret_gradient``).
        """
        refreeze = False
        if hasattr(layer, 'frozen') and layer.frozen:
            u.unfreeze(layer)
            refreeze = True

        # try/finally guarantees the layer is frozen again even if one of
        # the checks or computations below raises.
        try:
            s = AttrDefault(str, {})
            n = args.stats_batch_size
            param = u.get_param(layer)
            layer_idx = model.layers.index(layer)
            # TODO: get layer type, include it in name
            assert layer_idx >= 0
            assert stats_data.shape[0] == n

            def backprop_loss():
                """Forward ``stats_data`` and backprop the loss."""
                model.zero_grad()
                output = model(
                    stats_data)  # use last saved data batch for backprop
                loss = compute_loss(output, stats_targets)
                loss.backward()
                return loss, output

            def backprop_output():
                """Forward ``stats_data`` and backprop a tensor of ones from the output."""
                model.zero_grad()
                output = model(stats_data)
                output.backward(gradient=torch.ones_like(output))
                return output

            # per-example gradients, n, d
            _loss, _output = backprop_loss()
            At = layer.data_input
            # scaled by n -- presumably because compute_loss averages over
            # the batch (TODO confirm)
            Bt = layer.grad_output * n
            G = u.khatri_rao_t(At, Bt)
            g = G.sum(dim=0, keepdim=True) / n
            # sanity check: the averaged per-example gradient must match autograd
            u.check_close(g, u.vec(param.grad).t())

            s.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2
            s.grad_fro = g.flatten().norm()
            s.param_fro = param.data.flatten().norm()
            pos_activations = torch.sum(layer.data_output > 0)
            neg_activations = torch.sum(layer.data_output <= 0)
            s.a_sparsity = neg_activations.float() / (
                pos_activations + neg_activations)  # 1 sparsity means all 0's
            activation_size = len(layer.data_output.flatten())
            s.a_magnitude = torch.sum(layer.data_output) / activation_size

            _output = backprop_output()
            B2t = layer.grad_output
            J = u.khatri_rao_t(At, B2t)  # batch output Jacobian
            H = J.t() @ J / n

            s.hessian_l2 = u.l2_norm(H)
            s.jacobian_l2 = u.l2_norm(J)
            J1 = J.sum(dim=0) / n  # single output Jacobian
            s.J1_l2 = J1.norm()

            # newton decrement
            def loss_direction(direction, eps):
                """loss improvement if we take step eps in direction ``direction``"""
                return u.to_python_scalar(eps * (direction @ g.t()) - 0.5 *
                                          eps**2 * direction @ H @ direction.t())

            s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)

            # TODO: gradient diversity is stuck at 1
            # TODO: newton/gradient angle
            # TODO: newton step magnitude
            s.grad_curvature = u.to_python_scalar(
                g @ H @ g.t())  # curvature in direction of g
            # 999 is the fallback step size when the curvature along g is zero
            s.step_openai = u.to_python_scalar(
                s.grad_fro**2 / s.grad_curvature) if s.grad_curvature else 999

            s.regret_gradient = loss_direction(g, s.step_openai)

            return s
        finally:
            if refreeze:
                u.freeze(layer)