Example #1
    def __call__(self, global_feat, labels):

        global_feat = l2_norm(global_feat)
        dist_mat = euclidean_dist(global_feat, global_feat)
        dist_ap, dist_an = hard_example_mining(dist_mat, labels)

        y = dist_an.new().resize_as_(dist_an).fill_(1)

        if self.margin is not None:
            loss = self.ranking_loss(dist_an, dist_ap, y)
        else:
            loss = self.ranking_loss(dist_an - dist_ap, y)

        return loss
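This __call__ belongs to a triplet-style metric-learning loss; the enclosing class and the l2_norm, euclidean_dist and hard_example_mining helpers are not shown. A minimal sketch of the wiring the call above implies (the class name, the margin default, and the l2_norm helper are assumptions, not the original source):

import torch.nn as nn


def l2_norm(x, axis=-1):
    # assumed helper: scale each feature vector to unit L2 norm
    return x / (x.norm(p=2, dim=axis, keepdim=True) + 1e-12)


class TripletLoss:
    # hypothetical wrapper: with a margin it uses MarginRankingLoss on (dist_an, dist_ap),
    # without one it falls back to SoftMarginLoss on (dist_an - dist_ap)
    def __init__(self, margin=0.3):
        self.margin = margin
        if margin is not None:
            self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        else:
            self.ranking_loss = nn.SoftMarginLoss()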
Example #2
def main():
    # for kernel_size=1, mean kron factoring works for any image size
    main_vals = AttrDict(n=2,
                         kernel_size=1,
                         image_size=625,
                         num_channels=5,
                         num_layers=4,
                         loss='CrossEntropy',
                         nonlin=False)

    hess_list1 = compute_hess(method='exact', **main_vals)
    hess_list2 = compute_hess(method='kron', **main_vals)
    value_error = max(
        [u.symsqrt_dist(h1, h2) for h1, h2 in zip(hess_list1, hess_list2)])
    magnitude_error = max([
        u.l2_norm(h2) / u.l2_norm(h1)
        for h1, h2 in zip(hess_list1, hess_list2)
    ])
    print(value_error)
    print(magnitude_error)

    dimension_vals = dict(image_size=[2, 3, 4, 5, 6],
                          num_channels=range(2, 12),
                          kernel_size=[1, 2, 3, 4, 5])
    for method in ['mean_kron', 'kron']:  # , 'experimental_kfac']:
        print()
        print('=' * 10, method, '=' * 40)
        for dimension in ['image_size', 'num_channels', 'kernel_size']:
            value_errors = []
            magnitude_errors = []
            for val in dimension_vals[dimension]:
                vals = AttrDict(main_vals.copy())
                vals.method = method
                vals[dimension] = val
                vals.image_size = max(vals.image_size,
                                      vals.kernel_size**vals.num_layers)
                # print(vals)
                vals_exact = AttrDict(vals.copy())
                vals_exact.method = 'exact'
                hess_list1 = compute_hess(**vals_exact)
                hess_list2 = compute_hess(**vals)
                magnitude_error = max([
                    u.l2_norm(h2) / u.l2_norm(h1)
                    for h1, h2 in zip(hess_list1, hess_list2)
                ])
                hess_list1 = [h / u.l2_norm(h) for h in hess_list1]
                hess_list2 = [h / u.l2_norm(h) for h in hess_list2]

                value_error = max([
                    u.symsqrt_dist(h1, h2)
                    for h1, h2 in zip(hess_list1, hess_list2)
                ])
                value_errors.append(value_error)
                magnitude_errors.append(magnitude_error.item())
            print(dimension)
            print('   value :', value_errors)
            print('   magnitude :', u.format_list(magnitude_errors))
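AttrDict here (and AttrDefault in the later examples) is a dict whose keys are also reachable as attributes; the actual class is not included in the snippet. A minimal stand-in matching how it is used above (an assumption, not the original implementation):

class AttrDict(dict):
    # dict whose items can be read and written as attributes
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err

    def __setattr__(self, name, value):
        self[name] = value


vals = AttrDict(n=2, kernel_size=1)
vals.image_size = 8                    # attribute writes land in the dict
assert vals['image_size'] == 8 and vals.n == 2
vals = AttrDict(vals.copy())           # dict.copy() returns a plain dict, hence the re-wrap in the loop above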
Example #3
    def forward(self, img):
        results = []
        for model, w in zip(self.models, self.weights):
            if self.tta:
                tta_preds = []
                for trans_img in self.get_TTA(img):
                    feat, _ = model(trans_img)
                    tta_preds.append(feat)

                mean_pred = torch.mean(torch.stack(tta_preds), dim=0)
                results.append(l2_norm(mean_pred) * w)
            else:
                feat, _ = model(img)
                results.append(feat * w)

        if len(results) == 1:
            return results[0]

        if self.reduction in {'cat', 'concat'}:
            return l2_norm(torch.cat(results, dim=-1))
        elif self.reduction == 'mean':
            return l2_norm(torch.mean(torch.stack(results), dim=0))

        return torch.stack(results)
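The get_TTA helper used in the TTA branch is not shown. A common choice (assumed here, not taken from the source) is to yield the batch unchanged plus a horizontally flipped copy:

    def get_TTA(self, img):
        # assumed test-time augmentation: identity, then a horizontal flip
        # (flip along the width axis of an NCHW batch)
        yield img
        yield torch.flip(img, dims=[-1])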
Example #4
def inference(model, test_loader, tqdm=tqdm, normalize=False):
    embeds = []
    is_ensemble = isinstance(model, EnsembleModels)
    if not is_ensemble:
        model.eval()
    with torch.no_grad():
        for i, img in enumerate(tqdm(test_loader)):
            img = img.cuda()
            feat = model(img)
            if not is_ensemble:
                feat = feat[1]

            if normalize:
                feat = l2_norm(feat)
            image_embeddings = feat.detach().cpu().numpy()
            embeds.append(image_embeddings)

    gc.collect()
    return np.concatenate(embeds)
Example #5
def val_epoch(model, valid_loader, criterion, valid_df, args):

    model.eval()
    embeds = []
    bar = tqdm(valid_loader)

    with torch.no_grad():
        for image, target in bar:
            image, target = image.cuda(), target.cuda()
            feat, _ = model(image)
            feat = l2_norm(feat)

            embeds.append(feat.detach().cpu().numpy())

    embeds = np.concatenate(embeds)
    preds, _ = search_similiar_images(embeds, valid_df)
    _, val_f1_score = row_wise_f1_score(valid_df.target, preds)

    return val_f1_score
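search_similiar_images and row_wise_f1_score are external helpers that are not shown. In this kind of matching validation, row-wise F1 is usually the set F1 per row, 2*|pred ∩ true| / (|pred| + |true|), averaged over rows; a sketch under that assumption, with the return order mirroring the call above (per-row scores first, then the mean):

import numpy as np


def row_wise_f1_score(true_matches, pred_matches):
    # assumed metric: set F1 per row, then its mean over the validation set
    scores = []
    for t, p in zip(true_matches, pred_matches):
        t, p = set(t), set(p)
        scores.append(2 * len(t & p) / (len(t) + len(p)))
    scores = np.array(scores)
    return scores, float(scores.mean())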
Example #6
def main():
    attempt_count = 0
    while os.path.exists(f"{args.logdir}{attempt_count:02d}"):
        attempt_count += 1
    logdir = f"{args.logdir}{attempt_count:02d}"

    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")
    u.seed_random(1)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method

    except Exception as e:
        print(f"wandb crash with {e}")

    #    data_width = 4
    #    targets_width = 2

    d1 = args.data_width**2
    d2 = 10
    d3 = args.targets_width**2
    o = d3
    n = args.stats_batch_size
    d = [d1, d2, d3]
    model = u.SimpleFullyConnected(d, nonlin=args.nonlin)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=False,
        drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.stats_batch_size,
        shuffle=False,
        drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)

    def capture_activations(module, input, _output):
        if skip_forward_hooks:
            return
        assert gl.backward_idx == 0  # no need to forward-prop on Hessian computation
        assert not hasattr(
            module, 'activations'
        ), "Seeing activations from previous forward, call util.zero_grad to clear"
        assert len(input) == 1, "this works for single input layers only"
        setattr(module, "activations", input[0].detach())

    def capture_backprops(module: nn.Module, _input, output):
        if skip_backward_hooks:
            return
        assert len(output) == 1, "this works for single variable layers only"
        if gl.backward_idx == 0:
            assert not hasattr(
                module, 'backprops'
            ), "Seeing results of previous autograd, call util.zero_grad to clear"
            setattr(module, 'backprops', [])
        assert gl.backward_idx == len(module.backprops)
        module.backprops.append(output[0])

    def save_grad(param: nn.Parameter) -> Callable[[torch.Tensor], None]:
        """Hook to save gradient into 'param.saved_grad', so it can be accessed after model.zero_grad(). Only stores gradient
        if the value has not been set, call util.zero_grad to clear it."""
        def save_grad_fn(grad):
            if not hasattr(param, 'saved_grad'):
                setattr(param, 'saved_grad', grad)

        return save_grad_fn

    for layer in model.layers:
        layer.register_forward_hook(capture_activations)
        layer.register_backward_hook(capture_backprops)
        layer.weight.register_hook(save_grad(layer.weight))

    def loss_fn(data, targets):
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == len(targets)
        return torch.sum(err * err) / 2 / len(data)

    gl.token_count = 0
    for step in range(args.stats_steps):
        data, targets = next(stats_iter)
        skip_forward_hooks = False
        skip_backward_hooks = False

        # get gradient values
        gl.backward_idx = 0
        u.zero_grad(model)
        output = model(data)
        loss = loss_fn(output, targets)
        loss.backward(retain_graph=True)

        print("loss", loss.item())

        # get Hessian values
        skip_forward_hooks = True
        id_mat = torch.eye(o)

        u.log_scalars({'loss': loss.item()})

        # o = 0
        for out_idx in range(o):
            model.zero_grad()
            # backprop to get section of batch output jacobian for output at position out_idx
            output = model(
                data
            )  # opt: using autograd.grad means I don't have to zero_grad
            ei = id_mat[out_idx]
            bval = torch.stack([ei] * n)
            gl.backward_idx = out_idx + 1
            output.backward(bval)
        skip_backward_hooks = True

        for (i, layer) in enumerate(model.layers):
            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops[0] * n
            assert B_t.shape == (n, d[i + 1])

            G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]),
                              layer.weight.saved_grad)

            # empirical Fisher
            efisher = G.t() @ G / n
            sigma = efisher - g.t() @ g
            # u.dump(sigma, f'/tmp/sigmas/{step}-{i}')
            s.sigma_l2 = u.l2_norm(sigma)

            #############################
            # Hessian stats
            #############################
            A_t = layer.activations
            Bh_t = [layer.backprops[out_idx + 1] for out_idx in range(o)]
            Amat_t = torch.cat([A_t] * o, dim=0)
            Bmat_t = torch.cat(Bh_t, dim=0)

            assert Amat_t.shape == (n * o, d[i])
            assert Bmat_t.shape == (n * o, d[i + 1])

            Jb = u.khatri_rao_t(Bmat_t,
                                Amat_t)  # batch Jacobian, in row-vec format
            H = Jb.t() @ Jb / n
            pinvH = u.pinv(H)

            s.hess_l2 = u.l2_norm(H)
            s.invhess_l2 = u.l2_norm(pinvH)

            s.hess_fro = H.flatten().norm()
            s.invhess_fro = pinvH.flatten().norm()

            s.jacobian_l2 = u.l2_norm(Jb)
            s.grad_fro = g.flatten().norm()
            s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1],
                                                d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            #  u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) -
                                          0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() /
                                          dd.flatten().norm()**2)

            s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)
            s.grad_curv = curv_direction(g)
            ndir = g @ u.pinv(H)  # newton direction
            s.newton_curv = curv_direction(ndir)
            setattr(layer.weight, 'pre',
                    u.pinv(H))  # save Newton preconditioner
            s.step_openai = 1 / s.grad_curv if s.grad_curv else 999

            s.newton_fro = ndir.flatten().norm(
            )  # frobenius norm of Newton update
            s.regret_gradient = loss_direction(g, s.step_openai)

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        for i in range(args.train_steps):
            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()

            u.log_scalar(train_loss=loss.item())

            if args.method != 'newton':
                optimizer.step()
            else:
                for (layer_idx, layer) in enumerate(model.layers):
                    param: torch.nn.Parameter = layer.weight
                    param_data: torch.Tensor = param.data
                    param_data.copy_(param_data - 0.1 * param.grad)
                    if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                        continue
                    u.nan_check(layer.weight.pre)
                    u.nan_check(param.grad.flatten())
                    u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                    param_new_flat = u.v2r(param_data.flatten()) - u.v2r(
                        param.grad.flatten()) @ layer.weight.pre
                    u.nan_check(param_new_flat)
                    param_data.copy_(param_new_flat.reshape(param_data.shape))

            gl.token_count += data.shape[0]

    gl.event_writer.close()
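The central primitive in this example is u.khatri_rao_t, a row-wise (transposed Khatri-Rao) Kronecker product used to form per-example gradients and Jacobians. A minimal stand-in that is consistent with the shape asserts and the B_t.t() @ A_t / n gradient check above (an assumption about the utility, not its actual source):

import torch


def khatri_rao_t(B, A):
    # row-wise Kronecker product: result[k] = kron(B[k], A[k]),
    # shapes (n, d_out) and (n, d_in) -> (n, d_out * d_in)
    n = B.shape[0]
    return torch.einsum('ni,nj->nij', B, A).reshape(n, -1)

# with this definition, khatri_rao_t(B_t, A_t).sum(dim=0) / n reshaped to
# (d_out, d_in) equals B_t.t() @ A_t / n, matching the autograd_check above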
Example #7
def test_hessian():
    """Tests of Hessian computation."""
    u.seed_random(1)
    batch_size = 500

    data_width = 4
    targets_width = 4

    d1 = data_width ** 2
    d2 = 10
    d3 = targets_width ** 2
    o = d3
    N = batch_size
    d = [d1, d2, d3]

    dataset = u.TinyMNIST(data_width=data_width, targets_width=targets_width, dataset_size=batch_size)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    train_iter = iter(trainloader)
    data, targets = next(train_iter)

    def loss_fn(data, targets):
        assert len(data) == len(targets)
        err = data - targets.view(-1, data.shape[1])
        return torch.sum(err * err) / 2 / len(data)

    u.seed_random(1)
    model: u.SimpleModel = u.SimpleFullyConnected(d, nonlin=False, bias=True)

    # backprop hessian and compare against autograd
    hessian_backprop = u.HessianExactSqrLoss()
    output = model(data)
    for bval in hessian_backprop(output):
        output.backward(bval, retain_graph=True)

    i, layer = next(enumerate(model.layers))
    A_t = layer.activations
    Bh_t = layer.backprops_list
    H, Hb = u.hessian_from_backprops(A_t, Bh_t, bias=True)

    model.disable_hooks()
    H_autograd = u.hessian(loss_fn(model(data), targets), layer.weight)
    u.check_close(H, H_autograd.reshape(d[i + 1] * d[i], d[i + 1] * d[i]),
                  rtol=1e-4, atol=1e-7)
    Hb_autograd = u.hessian(loss_fn(model(data), targets), layer.bias)
    u.check_close(Hb, Hb_autograd, rtol=1e-4, atol=1e-7)

    # check first few per-example Hessians
    Hi, Hb_i = u.per_example_hess(A_t, Bh_t, bias=True)
    u.check_close(H, Hi.mean(dim=0))
    u.check_close(Hb, Hb_i.mean(dim=0), atol=2e-6, rtol=1e-5)

    for xi in range(5):
        loss = loss_fn(model(data[xi:xi + 1, ...]), targets[xi:xi + 1])
        H_autograd = u.hessian(loss, layer.weight)
        u.check_close(Hi[xi], H_autograd.reshape(d[i + 1] * d[i], d[i + 1] * d[i]))
        Hbias_autograd = u.hessian(loss, layer.bias)
        u.check_close(Hb_i[xi], Hbias_autograd)

    # get subsampled Hessian
    u.seed_random(1)
    model = u.SimpleFullyConnected(d, nonlin=False)
    hessian_backprop = u.HessianSampledSqrLoss(num_samples=1)

    output = model(data)
    for bval in hessian_backprop(output):
        output.backward(bval, retain_graph=True)
    model.disable_hooks()
    i, layer = next(enumerate(model.layers))
    H_approx1 = u.hessian_from_backprops(layer.activations, layer.backprops_list)

    # get subsampled Hessian with more samples
    u.seed_random(1)
    model = u.SimpleFullyConnected(d, nonlin=False)

    hessian_backprop = u.HessianSampledSqrLoss(num_samples=o)
    output = model(data)
    for bval in hessian_backprop(output):
        output.backward(bval, retain_graph=True)
    model.disable_hooks()
    i, layer = next(enumerate(model.layers))
    H_approx2 = u.hessian_from_backprops(layer.activations, layer.backprops_list)

    assert abs(u.l2_norm(H) / u.l2_norm(H_approx1) - 1) < 0.08, abs(u.l2_norm(H) / u.l2_norm(H_approx1) - 1)  # 0.0612
    assert abs(u.l2_norm(H) / u.l2_norm(H_approx2) - 1) < 0.03, abs(u.l2_norm(H) / u.l2_norm(H_approx2) - 1)  # 0.0239
    assert u.kl_div_cov(H_approx1, H) < 0.3, u.kl_div_cov(H_approx1, H)  # 0.222
    assert u.kl_div_cov(H_approx2, H) < 0.2, u.kl_div_cov(H_approx2, H)  # 0.1233
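u.HessianExactSqrLoss supplies the vectors that get backpropagated to recover the exact Hessian of the squared-error loss; its code is not shown. Judging from the explicit loop in Example #6 (one basis direction per output coordinate, replicated across the batch), a plausible stand-in would be (an assumption, not the library's implementation):

import torch


class HessianExactSqrLoss:
    # yield one backprop vector per output coordinate: basis vector e_i for every example
    def __call__(self, output):
        n, o = output.shape
        id_mat = torch.eye(o)
        for out_idx in range(o):
            yield torch.stack([id_mat[out_idx]] * n)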
Example #8
    def compute_layer_stats(layer):
        refreeze = False
        if hasattr(layer, 'frozen') and layer.frozen:
            u.unfreeze(layer)
            refreeze = True

        s = AttrDefault(str, {})
        n = args.stats_batch_size
        param = u.get_param(layer)
        _d = len(param.flatten())  # dimensionality of parameters
        layer_idx = model.layers.index(layer)
        # TODO: get layer type, include it in name
        assert layer_idx >= 0
        assert stats_data.shape[0] == n

        def backprop_loss():
            model.zero_grad()
            output = model(
                stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            loss.backward()
            return loss, output

        def backprop_output():
            model.zero_grad()
            output = model(stats_data)
            output.backward(gradient=torch.ones_like(output))
            return output

        # per-example gradients, n, d
        _loss, _output = backprop_loss()
        At = layer.data_input
        Bt = layer.grad_output * n
        G = u.khatri_rao_t(At, Bt)
        g = G.sum(dim=0, keepdim=True) / n
        u.check_close(g, u.vec(param.grad).t())

        s.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2
        s.grad_fro = g.flatten().norm()
        s.param_fro = param.data.flatten().norm()
        pos_activations = torch.sum(layer.data_output > 0)
        neg_activations = torch.sum(layer.data_output <= 0)
        s.a_sparsity = neg_activations.float() / (
            pos_activations + neg_activations)  # 1 sparsity means all 0's
        activation_size = len(layer.data_output.flatten())
        s.a_magnitude = torch.sum(layer.data_output) / activation_size

        _output = backprop_output()
        B2t = layer.grad_output
        J = u.khatri_rao_t(At, B2t)  # batch output Jacobian
        H = J.t() @ J / n

        s.hessian_l2 = u.l2_norm(H)
        s.jacobian_l2 = u.l2_norm(J)
        J1 = J.sum(dim=0) / n  # single output Jacobian
        s.J1_l2 = J1.norm()

        # newton decrement
        def loss_direction(direction, eps):
            """loss improvement if we take step eps in direction dir"""
            return u.to_python_scalar(eps * (direction @ g.t()) - 0.5 *
                                      eps**2 * direction @ H @ direction.t())

        s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)

        # TODO: gradient diversity is stuck at 1
        # TODO: newton/gradient angle
        # TODO: newton step magnitude
        s.grad_curvature = u.to_python_scalar(
            g @ H @ g.t())  # curvature in direction of g
        s.step_openai = u.to_python_scalar(
            s.grad_fro**2 / s.grad_curvature) if s.grad_curvature else 999

        s.regret_gradient = loss_direction(g, s.step_openai)

        if refreeze:
            u.freeze(layer)
        return s
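The layer.data_input / layer.data_output / layer.grad_output fields read above are populated by hooks that this snippet does not include. A minimal sketch of such hooks, modeled on the capture hooks in Example #6 (the function name and the use of full backward hooks are assumptions):

import torch.nn as nn


def add_capture_hooks(layer: nn.Module):
    # stash the layer's input, output and output gradient on the module itself
    def forward_hook(module, inputs, output):
        module.data_input = inputs[0].detach()
        module.data_output = output.detach()

    def backward_hook(module, grad_input, grad_output):
        module.grad_output = grad_output[0].detach()

    layer.register_forward_hook(forward_hook)
    layer.register_full_backward_hook(backward_hook)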
Example #9
def test_l2_norm():
    mat = torch.tensor([[1, 1], [0, 1]]).float()
    u.check_equal(u.l2_norm(mat), 0.5 * (1 + math.sqrt(5)))
    ii = torch.eye(5)
    u.check_equal(u.l2_norm(ii), 1)
    def compute_layer_stats(layer):
        stats = AttrDefault(str, {})
        n = stats_batch_size
        param = u.get_param(layer)
        d = len(param.flatten())
        layer_idx = model.layers.index(layer)
        assert layer_idx >= 0
        assert stats_data.shape[0] == n

        def backprop_loss():
            model.zero_grad()
            output = model(
                stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            loss.backward()
            return loss, output

        def backprop_output():
            model.zero_grad()
            output = model(stats_data)
            output.backward(gradient=torch.ones_like(output))
            return output

        # per-example gradients, n, d
        loss, output = backprop_loss()
        At = layer.data_input
        Bt = layer.grad_output * n
        G = u.khatri_rao_t(At, Bt)
        g = G.sum(dim=0, keepdim=True) / n
        u.check_close(g, u.vec(param.grad).t())

        stats.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2

        stats.gradient_norm = g.flatten().norm()
        stats.parameter_norm = param.data.flatten().norm()
        pos_activations = torch.sum(layer.data_output > 0)
        neg_activations = torch.sum(layer.data_output <= 0)
        stats.sparsity = pos_activations.float() / (pos_activations +
                                                    neg_activations)

        output = backprop_output()
        At2 = layer.data_input
        u.check_close(At, At2)
        B2t = layer.grad_output
        J = u.khatri_rao_t(At, B2t)
        H = J.t() @ J / n

        model.zero_grad()
        output = model(stats_data)  # use last saved data batch for backprop
        loss = compute_loss(output, stats_targets)
        hess = u.hessian(loss, param)

        hess = hess.transpose(2, 3).transpose(0, 1).reshape(d, d)
        u.check_close(hess, H)

        stats.hessian_norm = u.l2_norm(H)
        stats.jacobian_norm = u.l2_norm(J)
        Joutput = J.sum(dim=0) / n
        stats.jacobian_sensitivity = Joutput.norm()

        # newton decrement
        stats.loss_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)
        u.check_close(stats.loss_newton, loss)

        # do line-search to find optimal step
        def line_search(directionv, start, end, steps=10):
            """Takes steps between start and end, returns steps+1 loss entries"""
            param0 = param.data.clone()
            param0v = u.vec(param0).t()
            losses = []
            for i in range(steps + 1):
                output = model(
                    stats_data)  # use last saved data batch for backprop
                loss = compute_loss(output, stats_targets)
                losses.append(loss)
                offset = start + i * ((end - start) / steps)
                param1v = param0v + offset * directionv

                param1 = u.unvec(param1v.t(), param.data.shape[0])
                param.data.copy_(param1)

            output = model(
                stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            losses.append(loss)

            param.data.copy_(param0)
            return losses

        # try to take a newton step
        gradv = g
        line_losses = line_search(-gradv @ u.pinv(H), 0, 2, steps=10)
        u.check_equal(line_losses[0], loss)
        u.check_equal(line_losses[6], 0)
        assert line_losses[5] > line_losses[6]
        assert line_losses[7] > line_losses[6]
        return stats
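The two checks at the top of test_l2_norm pin down what u.l2_norm computes: the spectral norm (largest singular value), since the largest singular value of [[1, 1], [0, 1]] is (1 + sqrt(5)) / 2 and that of the identity is 1. A minimal stand-in consistent with those values (an assumption; note it is distinct from the feature-normalizing l2_norm used in Examples #1 and #3-#5):

import torch


def l2_norm(mat):
    # spectral norm = largest singular value of the matrix (flattened to 2-D if needed)
    return torch.linalg.svdvals(mat.reshape(mat.shape[0], -1)).max()

assert torch.isclose(l2_norm(torch.eye(5)), torch.tensor(1.0))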