def test_zero_one_loss(checkpoint):
    """The model at its stored weights has zero zero-one loss on the training set.

    ``np.array_equal`` also checks the shape, so the per-example losses must
    be a 1-D array with one entry per training label.
    """
    isi_model = IsiCifar(checkpoint, "../../data")
    weights_at_opt = isi_model.get_model_weights()
    observed = isi_model.loss_val_zero_one(
        isi_model.train_set, isi_model.train_labels, weights_at_opt
    )
    expected = np.zeros(isi_model.train_labels.size)
    assert np.array_equal(expected, observed)
def test_hess_diag(checkpoint):
    """``g_diag`` on the first two examples gives one non-negative value per weight."""
    isi_model = IsiCifar(checkpoint, "../../data")
    weights_at_opt = isi_model.get_model_weights()
    first_two = np.array([0, 1])
    subset = Subset(isi_model.dataset, first_two)
    subset_labels = isi_model.labels[:2]
    diag = isi_model.g_diag(subset, subset_labels, weights_at_opt)
    assert diag.size == weights_at_opt.size
    assert np.all(diag >= 0)
def test_grad_model(checkpoint):
    """``grad_model_out_weights`` returns outputs matching ``model_eval``.

    The gradient array must have one row per example (two here) and one
    column per model weight.
    """
    isi_model = IsiCifar(checkpoint, "../../data")
    weights_at_opt = isi_model.get_model_weights()
    subset = Subset(isi_model.dataset, np.array([0, 1]))
    out_grads, outputs = isi_model.grad_model_out_weights(subset, weights_at_opt)
    reference = isi_model.model_eval(weights_at_opt, subset)
    max_abs_diff = np.max(np.abs(outputs - reference))
    assert max_abs_diff < 1e-4
    assert out_grads.shape == (2, weights_at_opt.size)
def test_model_eval(checkpoint):
    """The arg-max of ``model_eval`` matches every training label."""
    isi_model = IsiCifar(checkpoint, "../../data")
    scores = isi_model.model_eval(isi_model.get_model_weights(), isi_model.train_set)
    predicted = np.argmax(scores, axis=1)
    assert np.array_equal(predicted, isi_model.train_labels)
def test_set_weights(checkpoint):
    """Weights written with ``set_model_weights`` are read back unchanged."""
    model = IsiCifar(checkpoint, "../../data")
    # Size (and dtype) the replacement from the model's own weights rather than
    # hard-coding the parameter count, so this test checks only the set/get
    # round trip and does not also depend on the architecture's size.
    new_weights = np.ones_like(model.get_model_weights())
    model.set_model_weights(new_weights)
    assert np.max(np.abs(model.get_model_weights() - new_weights)) < 1e-4
def test_get_weights(checkpoint):
    """``get_model_weights`` returns the expected total number of weights."""
    n_expected = 3131364
    isi_model = IsiCifar(checkpoint, "../../data")
    assert isi_model.get_model_weights().size == n_expected
def main(
    checkpoint,
    hessian_loc,
    sample_loc,
    iters,
    checkpoint_loc,
    sigma_prior,
    projection_dim,
    gpus,
    results_loc,
):
    """Sample posterior losses, build a coreset, and save the results.

    Pipeline:
      1. Load (or compute with ``model.g_diag`` and cache) the hessian at the
         model's optimal weights, then derive the posterior via
         ``find_posterior``.
      2. Sample zero-one losses from the posterior.
      3. Run the Frank-Wolfe coreset calculation.
      4. Compute bound constants.
      5. Store everything in one ``np.savez`` archive.

    Args:
        checkpoint: Model checkpoint passed to ``Model``.
        hessian_loc: Path of the cached hessian. ``np.save`` appends ``.npy``
            when the path lacks it, so the cache is written to (and looked up
            under) that normalized name. An existing file at the exact given
            path is still loaded.
        sample_loc: Directory for posterior samples. It is created if
            missing, and a ``temp`` sub-directory is created inside it for
            the coreset step.
        iters: Iteration count forwarded to ``calc_coreset`` and stored in
            the results.
        checkpoint_loc: Directory passed to ``calc_coreset`` for its
            checkpoints (created if missing).
        sigma_prior: Prior sigma forwarded to ``find_posterior``.
        projection_dim: Forwarded to ``sample_posterior``.
        gpus: Forwarded to ``sample_posterior``.
        results_loc: Destination passed to ``np.savez`` (which appends
            ``.npz`` if missing).
    """
    # makedirs also creates missing parents; exist_ok avoids a race between
    # the existence check and the creation.
    for location, label in ((sample_loc, "sample"), (checkpoint_loc, "checkpoint")):
        if not os.path.exists(location):
            log.info(f"Creating {label} save location at {location}")
            os.makedirs(location, exist_ok=True)

    log.info(
        f"Creating model for checkpoint {checkpoint} and getting optimal weights"
    )
    model = Model(checkpoint)
    opt_weights = model.get_model_weights()
    n_points = model.labels.size

    # np.save silently appends ".npy" to paths without it. Look for the cache
    # under that name too, otherwise it is never found and the hessian is
    # recomputed on every run.
    hessian_path = os.fspath(hessian_loc)
    npy_path = hessian_path if hessian_path.endswith(".npy") else f"{hessian_path}.npy"
    cached_path = next(
        (path for path in (hessian_path, npy_path) if os.path.exists(path)), None
    )
    if cached_path is not None:
        log.info(f"Loading the saved hessian at {cached_path}")
        hess = np.load(cached_path)
    else:
        log.info("No hessian found, calculating the hessian")
        hess = model.g_diag(model.train_set, model.train_labels, opt_weights)
        np.save(npy_path, hess)
        log.info(f"Hessian saved to {npy_path}")

    sigma_post = find_posterior(sigma_prior, hess)

    def loss(prms):
        """Zero-one losses on ``model.dataset``, one column per row of ``prms``.

        Returns an array of shape (n_examples, prms.shape[0], 1), assuming
        ``loss_val_zero_one`` returns a 1-D per-example loss vector.
        """
        log.debug(f"Params shape {prms.shape}")
        losses = np.hstack([
            model.loss_val_zero_one(model.dataset, model.labels, row)[:, np.newaxis]
            for row in prms
        ])
        losses = losses[:, :, np.newaxis]
        log.debug(f"Loss shape {losses.shape}")
        return losses

    log.info("Starting to sample loss from the posterior")
    post_sample = sample_posterior(
        opt_weights,
        sigma_post,
        projection_dim,
        loss,
        sample_loc,
        gpus,
        n_points,
    )
    log.info("Finished sampling loss from the posterior")

    # Uniform weight over every data point.
    prob = np.full(n_points, 1 / n_points)
    temp_loc = os.path.join(sample_loc, "temp")
    os.makedirs(temp_loc, exist_ok=True)

    log.info("Starting coreset calculation with Frank-Wolfe algorithm")
    checkpoint_its = 50
    (
        coreset_sizes,
        norm_difference_full,
        norm_coreset_full,
        norm_loss_full,
        w_minus_p_norm_full,
        w_norm_full,
        gen_err_full,
    ) = calc_coreset(iters, post_sample, prob, checkpoint_its, checkpoint_loc,
                     temp_loc)

    log.info("Starting computing constants for bound")
    params = get_bound_params(post_sample, prob)

    log.info(f"Coreset calculation completed, saving results to {results_loc}")
    np.savez(
        results_loc,
        iters=iters,
        coresize=coreset_sizes,
        genErr=gen_err_full,
        wminuspnorm_full=w_minus_p_norm_full,
        lnorm_prior=norm_loss_full,
        llwnorm=norm_difference_full,
        lwnorm=norm_coreset_full,
        wnorm=w_norm_full,
        params=params,
    )