def test_grad_model(checkpoint):
    """Check that the gradient routine is consistent with plain evaluation.

    Builds an ``IsiCifar`` from ``checkpoint``, selects dataset entries 0
    and 1, and verifies two things:

    * the outputs returned by ``grad_model_out_weights`` differ from those
      of ``model_eval`` by less than 1e-4 in absolute value, and
    * the returned gradient has shape ``(2, weights.size)``.

    Args:
        checkpoint: Forwarded unchanged to ``IsiCifar``; presumably supplied
            by a pytest fixture -- confirm against the test configuration.
    """
    net = IsiCifar(checkpoint, "../../data")
    weights = net.get_model_weights()

    # Restrict the check to the first two dataset entries.
    indices = np.array([0, 1])
    samples = Subset(net.dataset, indices)

    grads, grad_outputs = net.grad_model_out_weights(samples, weights)
    eval_outputs = net.model_eval(weights, samples)

    # Both code paths must yield the same outputs up to a small tolerance.
    max_deviation = np.max(np.abs(grad_outputs - eval_outputs))
    assert max_deviation < 1e-4

    expected_shape = (2, weights.size)
    assert grads.shape == expected_shape
def test_model_eval(checkpoint):
    """Check that evaluating the stored weights reproduces the train labels.

    Evaluates the model on ``train_set`` with the weights returned by
    ``get_model_weights`` and requires the arg-max along axis 1 of the
    output to equal ``train_labels`` exactly.

    Args:
        checkpoint: Forwarded unchanged to ``IsiCifar``; presumably supplied
            by a pytest fixture -- confirm against the test configuration.
    """
    net = IsiCifar(checkpoint, "../../data")
    weights = net.get_model_weights()
    scores = net.model_eval(weights, net.train_set)

    # Take the index of the largest output along axis 1 as the prediction.
    predicted = np.argmax(scores, axis=1)
    labels_match = np.array_equal(predicted, net.train_labels)
    assert labels_match