Example #1
def solve_dare(A, B, Q, R):
    def _make_riccati_operator(params):
        A, B, Q, R = params

        def _riccati_operator(i, P):
            del i
            # One application of the DARE map:
            #   P -> A^T P A - A^T P B (R + B^T P B)^{-1} B^T P A + Q
            X = R + B.T @ P @ B  # P is symmetric, so P.T == P
            Y = B.T @ P @ A
            return (A.T @ P @ A) - ((A.T @ P @ B) @ np.linalg.solve(X, Y)) + Q

        return _riccati_operator

    implicit_function = two_phase_solver(_make_riccati_operator)
    solution = implicit_function(np.eye(A.shape[0]), (A, B, Q, R))
    return solution.value
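A minimal usage sketch, under the assumption that `np` is `jax.numpy` and that `two_phase_solver` comes from `fax.implicit.twophase` (fax iterates the operator to a fixed point and makes the result differentiable); the import path and the system matrices below are illustrative:

import jax.numpy as np
from fax.implicit.twophase import two_phase_solver  # assumed import path

A = np.array([[0.9, 0.1],
              [0.0, 0.8]])  # stable open-loop dynamics
B = np.array([[0.0],
              [1.0]])       # single input
Q = np.eye(2)               # state cost
R = np.eye(1)               # control cost

P = solve_dare(A, B, Q, R)                         # DARE solution
K = np.linalg.solve(R + B.T @ P @ B, B.T @ P @ A)  # optimal gain, u = -K x

The result can be cross-checked against scipy.linalg.solve_discrete_are on the same matrices.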
Example #2
def lqr_evaluation(A, B, Q, R, K):
    def _make_lqr_eval_operator(params):
        A, B, Q, R, K = params

        def _lqr_eval_operator(i, P):
            del i
            # Policy-evaluation Bellman operator for the fixed policy u = K x:
            #   P -> Q + K^T R K + (A + B K)^T P (A + B K)
            return Q + (K.T @ R @ K) + ((A + B @ K).T @ P @ (A + B @ K))

        return _lqr_eval_operator

    implicit_function = two_phase_solver(_make_lqr_eval_operator)
    solution = implicit_function(np.eye(A.shape[0]), (A, B, Q, R, K))

    def _vf(x):
        return x.T @ solution.value @ x

    def _qf(x, u):
        return (x.T @ Q @ x) + (u.T @ R @ u) + _vf(A @ x + B @ u)

    return _vf, _qf
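As a quick sanity check on the fixed point, `_qf(x, K @ x)` should recover `_vf(x)`, since following the policy from x costs exactly the policy's value. A sketch with the same assumed imports as above and an illustrative stabilizing gain:

A = np.array([[0.9, 0.1],
              [0.0, 0.8]])
B = np.array([[0.0],
              [1.0]])
Q = np.eye(2)
R = np.eye(1)
K = np.array([[-0.2, -0.4]])  # some stabilizing feedback gain, u = K x

vf, qf = lqr_evaluation(A, B, Q, R, K)
x = np.array([1.0, -0.5])
print(vf(x))          # cost-to-go of x under the fixed policy
print(qf(x, K @ x))   # should match vf(x)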
Example #3
def make_differentiable_planner(true_discount, temperature):
    def param_func(params):
        transition, reward = params

        def _smooth_bellman_operator(i, qf):
            del i
            # Soft Bellman backup with a temperature-smoothed max over actions:
            #   q(s, a) <- r(s, a) + gamma * sum_t P(t | s, a) * tau * logsumexp(q(t, .) / tau)
            return reward + true_discount * np.einsum(
                'ast,t->sa', transition,
                temperature * logsumexp((1. / temperature) * qf, axis=1))

        return _smooth_bellman_operator

    smooth_value_iteration = two_phase_solver(param_func)

    def _planner(params):
        transition_hat, reward_hat = params
        q0 = np.zeros_like(reward_hat)
        solution = smooth_value_iteration(q0, (transition_hat, reward_hat))
        return softmax((1. / temperature) * solution.value)

    return _planner
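A sketch of differentiating through the planner, assuming `logsumexp` is `jax.scipy.special.logsumexp` and `softmax` is `jax.nn.softmax` (both consistent with the `axis` usage above); the random MDP and the scalar objective are purely illustrative:

import jax
import jax.numpy as np
from jax.scipy.special import logsumexp  # assumed source of `logsumexp`
from jax.nn import softmax               # assumed source of `softmax`

num_actions, num_states = 2, 4
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
# transition[a, s, t] = P(t | s, a), normalized over next states
transition = softmax(
    jax.random.normal(k1, (num_actions, num_states, num_states)), axis=-1)
reward = jax.random.normal(k2, (num_states, num_actions))

planner = make_differentiable_planner(true_discount=0.9, temperature=0.1)

def objective(params):
    policy = planner(params)        # (num_states, num_actions) soft-greedy policy
    return np.sum(policy * reward)  # illustrative scalar objective

# Gradients flow through the fixed point w.r.t. both model components.
grads = jax.grad(objective)((transition, reward))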
Example #4
def get_solvers(model_constructor,
                pdf_transform=False,
                default_rtol=1e-10,
                default_atol=1e-10,
                default_max_iter=int(1e7),
                learning_rate=0.01):

    # NB: `learning_rate` is accepted but unused here; Adam is built with a
    # fixed 1e-6 step size.
    adam_init, adam_update, adam_get_params = optimizers.adam(1e-6)

    def make_model(hyper_pars):
        constrained_mu, nn_pars = hyper_pars[0], hyper_pars[1]
        m, bonlypars = model_constructor(nn_pars)

        bounds = m.config.suggested_bounds()
        constrained_mu = to_inf(constrained_mu,
                                bounds[0]) if pdf_transform else constrained_mu

        exp_bonly_data = m.expected_data(bonlypars, include_auxdata=True)

        def expected_logpdf(pars):
            # maps pars to bounded space if pdf_transform is True
            return (m.logpdf(to_bounded_vec(pars, bounds), exp_bonly_data)
                    if pdf_transform else m.logpdf(pars, exp_bonly_data))

        def global_fit_objective(pars):  # NLL
            return -expected_logpdf(pars)[0]

        def constrained_fit_objective(nuis_par):  # NLL
            pars = jax.numpy.concatenate(
                [jax.numpy.asarray([constrained_mu]), nuis_par])
            return -expected_logpdf(pars)[0]

        return constrained_mu, global_fit_objective, constrained_fit_objective, bounds

    def global_bestfit_minimized(hyper_param):
        _, nll, _, _ = make_model(hyper_param)

        def bestfit_via_grad_descent(i, param):  # gradient descent
            g = jax.grad(nll)(param)
            # param = param - g * learning_rate
            # NB: the Adam state is re-created on every iteration, so no moment
            # history accumulates between steps.
            param = adam_get_params(adam_update(i, g, adam_init(param)))
            return param

        return bestfit_via_grad_descent

    def constrained_bestfit_minimized(hyper_param):
        mu, nll, cnll, bounds = make_model(hyper_param)

        def bestfit_via_grad_descent(i, param):  # gradient descent
            _, nuis = param[0], param[1:]
            g = jax.grad(cnll)(nuis)
            nuis = adam_get_params(adam_update(i, g, adam_init(nuis)))
            param = jax.numpy.concatenate([jax.numpy.asarray([mu]), nuis])
            return param

        return bestfit_via_grad_descent

    global_solve = twophase.two_phase_solver(
        param_func=global_bestfit_minimized,
        default_rtol=default_rtol,
        default_atol=default_atol,
        default_max_iter=default_max_iter)
    constrained_solver = twophase.two_phase_solver(
        param_func=constrained_bestfit_minimized,
        default_rtol=default_rtol,
        default_atol=default_atol,
        default_max_iter=default_max_iter,
    )

    def g_fitter(init, hyper_pars):
        solve = global_solve(init, hyper_pars)
        return solve.value

    def c_fitter(init, hyper_pars):
        solve = constrained_solver(init, hyper_pars)
        return solve.value

    return g_fitter, c_fitter
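A hypothetical end-to-end sketch, assuming pyhf with its jax backend; the toy `model_constructor` below (built on `pyhf.simplemodels.uncorrelated_background`) is illustrative only. `pdf_transform` stays False so the `to_inf`/`to_bounded_vec` helpers (defined elsewhere in the source project) are never invoked, and the `float(...)` cast means this toy is not differentiable end-to-end the way the real pipeline is:

import jax.numpy as jnp
import pyhf
pyhf.set_backend('jax')

def model_constructor(nn_pars):
    # Toy stand-in: the background yield is driven by the upstream parameter.
    m = pyhf.simplemodels.uncorrelated_background(
        signal=[5.0], bkg=[float(nn_pars[0])], bkg_uncertainty=[7.0])
    bonlypars = jnp.asarray(m.config.suggested_init())
    bonlypars = bonlypars.at[m.config.poi_index].set(0.0)  # background-only: mu = 0
    return m, bonlypars

g_fitter, c_fitter = get_solvers(model_constructor)

init = jnp.asarray([1.0, 1.0])           # starting point: [mu, gamma]
hyper_pars = [0.0, jnp.asarray([50.0])]  # [constrained mu, upstream parameters]
print(g_fitter(init, hyper_pars))        # unconditional best fit
print(c_fitter(init, hyper_pars))        # best fit with mu fixed at 0.0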
Example #5
def get_solvers(model_constructor,
                pdf_transform=False,
                default_rtol=1e-10,
                default_atol=1e-10,
                default_max_iter=int(1e7),
                learning_rate=0.01):
    '''
    Wraps a series of functions that perform maximum likelihood fits via the
    `two_phase_solver` method found in the `fax` python module. This allows for
    the calculation of gradients of the best-fit parameters with respect to
    upstream parameters that control the underlying model, e.g. the event
    yields (which are in turn parameterized by weights or similar).

    Args:
        model_constructor: Function that takes the parameters of the
            observable and returns a model object (and the background-only
            parameters).

    Returns:
        g_fitter, c_fitter: Callable functions that perform global and
            constrained fits respectively. Differentiable :)
    '''

    # NB: `learning_rate` is accepted but unused here; Adam is built with a
    # fixed 1e-6 step size.
    adam_init, adam_update, adam_get_params = optimizers.adam(1e-6)

    def make_model(hyper_pars):
        constrained_mu, nn_pars = hyper_pars[0], hyper_pars[1]
        m, bonlypars = model_constructor(nn_pars)

        bounds = m.config.suggested_bounds()
        constrained_mu = to_inf(constrained_mu,
                                bounds[0]) if pdf_transform else constrained_mu

        exp_bonly_data = m.expected_data(bonlypars, include_auxdata=True)

        def expected_logpdf(pars):
            # maps pars to bounded space if pdf_transform is True
            return (m.logpdf(to_bounded_vec(pars, bounds), exp_bonly_data)
                    if pdf_transform else m.logpdf(pars, exp_bonly_data))

        def global_fit_objective(pars):  # NLL
            return -expected_logpdf(pars)[0]

        def constrained_fit_objective(nuis_par):  # NLL
            pars = jax.numpy.concatenate(
                [jax.numpy.asarray([constrained_mu]), nuis_par])
            return -expected_logpdf(pars)[0]

        return constrained_mu, global_fit_objective, constrained_fit_objective, bounds

    def global_bestfit_minimized(hyper_param):
        _, nll, _, _ = make_model(hyper_param)

        def bestfit_via_grad_descent(i, param):  # gradient descent
            g = jax.grad(nll)(param)
            # param = param - g * learning_rate
            param = adam_get_params(adam_update(i, g, adam_init(param)))
            return param

        return bestfit_via_grad_descent

    def constrained_bestfit_minimized(hyper_param):
        mu, nll, cnll, bounds = make_model(hyper_param)

        def bestfit_via_grad_descent(i, param):  # gradient descent
            _, nuis = param[0], param[1:]
            g = jax.grad(cnll)(nuis)
            nuis = adam_get_params(adam_update(i, g, adam_init(nuis)))
            param = jax.numpy.concatenate([jax.numpy.asarray([mu]), nuis])
            return param

        return bestfit_via_grad_descent

    global_solve = twophase.two_phase_solver(
        param_func=global_bestfit_minimized,
        default_rtol=default_rtol,
        default_atol=default_atol,
        default_max_iter=default_max_iter)
    constrained_solver = twophase.two_phase_solver(
        param_func=constrained_bestfit_minimized,
        default_rtol=default_rtol,
        default_atol=default_atol,
        default_max_iter=default_max_iter,
    )

    def g_fitter(init, hyper_pars):
        solve = global_solve(init, hyper_pars)
        return solve.value

    def c_fitter(init, hyper_pars):
        solve = constrained_solver(init, hyper_pars)
        return solve.value

    return g_fitter, c_fitter
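Since the docstring's point is gradients of the best-fit parameters with respect to upstream parameters, a continuation of the toy setup from the Example #4 sketch might look like this. It is illustrative only: it needs a `model_constructor` whose yields stay traced by jax (the `float(...)` cast in the toy constructor would have to go, with the model spec built from traced arrays instead):

import jax
import jax.numpy as jnp

g_fitter, c_fitter = get_solvers(model_constructor)
init = jnp.asarray([1.0, 1.0])

def bestfit_mu(upstream):
    fitted = g_fitter(init, [0.0, upstream])
    return fitted[0]  # best-fit signal strength

# Gradient of the fitted mu w.r.t. the upstream parameter.
d_mu = jax.grad(bestfit_mu)(jnp.asarray([50.0]))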
Example #6
def global_fit(
        model_constructor,
        pdf_transform=False,
        default_rtol=1e-10,
        default_atol=1e-10,
        default_max_iter=int(1e7),
        learning_rate=1e-6,
):
    """
    Wraps a series of functions that perform maximum likelihood fitting in the
    `two_phase_solver` method found in the `fax` python module. This allows for
    the calculation of gradients of the best-fit parameters with respect to upstream
    parameters that control the underlying model, i.e. the event yields (which are
    then parameterized by weights or similar).

    Args:
            model_constructor: Function that takes in the parameters of the observable,
            and returns a model object (and background-only parameters)
    Returns:
            global_fitter: Callable function that performs global fits.
            Differentiable :)
    """

    adam_init, adam_update, adam_get_params = optimizers.adam(learning_rate)

    def make_model(model_pars):
        m, bonlypars = model_constructor(model_pars)

        bounds = m.config.suggested_bounds()

        exp_bonly_data = m.expected_data(bonlypars, include_auxdata=True)

        def expected_logpdf(pars):
            # maps pars to bounded space if pdf_transform is True
            return (m.logpdf(to_bounded_vec(pars, bounds), exp_bonly_data)
                    if pdf_transform else m.logpdf(pars, exp_bonly_data))

        def global_fit_objective(pars):  # NLL
            return -expected_logpdf(pars)[0]

        return global_fit_objective

    def global_bestfit_minimized(hyper_param):
        nll = make_model(hyper_param)

        def bestfit_via_grad_descent(i, param):  # gradient descent
            g = jax.grad(nll)(param)
            # param = param - g * learning_rate
            # NB: the Adam state is re-created on every iteration, so no moment
            # history accumulates between steps.
            param = adam_get_params(adam_update(i, g, adam_init(param)))
            return param

        return bestfit_via_grad_descent

    global_solve = twophase.two_phase_solver(
        param_func=global_bestfit_minimized,
        default_rtol=default_rtol,
        default_atol=default_atol,
        default_max_iter=default_max_iter,
    )

    def global_fitter(init, hyper_pars):
        solve = global_solve(init, hyper_pars)
        return solve.value

    return global_fitter
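A usage sketch for this global-fit-only variant, re-using the toy `model_constructor` from the Example #4 sketch; note that here the hyper parameters are passed straight through to `model_constructor` rather than being split into [mu, nn_pars]:

import jax.numpy as jnp

global_fitter = global_fit(model_constructor, learning_rate=1e-2)
best_pars = global_fitter(jnp.asarray([1.0, 1.0]), jnp.asarray([50.0]))
print(best_pars)  # [mu_hat, gamma_hat] fitted to the expected background-only data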