import jax
import jax.numpy as np
from jax.experimental import optimizers  # jax.example_libraries.optimizers in newer JAX
from jax.nn import softmax
from jax.scipy.special import logsumexp

from fax.implicit import twophase
from fax.implicit.twophase import two_phase_solver


def solve_dare(A, B, Q, R):
    """Solve the discrete-time algebraic Riccati equation (DARE) for P.

    The Riccati recursion is run to a fixed point inside `two_phase_solver`,
    so the solution P is implicitly differentiable w.r.t. (A, B, Q, R).
    """
    def _make_riccati_operator(params):
        A, B, Q, R = params

        def _riccati_operator(i, P):
            del i
            X = R + B.T @ P.T @ B
            Y = B.T @ P @ A
            return (A.T @ P @ A) - ((A.T @ P @ B) @ np.linalg.solve(X, Y)) + Q

        return _riccati_operator

    implicit_function = two_phase_solver(_make_riccati_operator)
    solution = implicit_function(np.eye(A.shape[0]), (A, B, Q, R))
    return solution.value
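# A minimal usage sketch (the matrices are illustrative assumptions, not from
# the original source): solve the DARE for a double integrator and recover the
# standard LQR gain from the solution P.
def _demo_solve_dare():
    A = np.array([[1., 1.], [0., 1.]])
    B = np.array([[0.], [1.]])
    Q = np.eye(2)
    R = np.eye(1)
    P = solve_dare(A, B, Q, R)
    K = -np.linalg.solve(R + B.T @ P @ B, B.T @ P @ A)  # LQR gain
    # Because the fixed point is differentiated implicitly, gradients of the
    # solution flow back to the problem data, e.g. d tr(P) / dQ:
    dtrP_dQ = jax.grad(lambda q: np.trace(solve_dare(A, B, q, R)))(Q)
    return P, K, dtrP_dQ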
def lqr_evaluation(A, B, Q, R, K):
    def _make_lqr_eval_operator(params):
        A, B, Q, R, K = params

        def _lqr_eval_operator(i, P):
            del i
            return Q + (K.T @ R @ K) + ((A + B @ K).T @ P @ (A + B @ K))

        return _lqr_eval_operator

    implicit_function = two_phase_solver(_make_lqr_eval_operator)
    solution = implicit_function(np.eye(A.shape[0]), (A, B, Q, R, K))

    def _vf(x):
        return x.T @ solution.value @ x

    def _qf(x, u):
        return (x.T @ Q @ x) + (u.T @ R @ u) + _vf(A @ x + B @ u)

    return _vf, _qf
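# A minimal usage sketch (an assumption, not from the original source):
# evaluate a hand-picked gain u = K x on the same double-integrator system
# and query the resulting value and Q-functions at a state.
def _demo_lqr_evaluation():
    A = np.array([[1., 1.], [0., 1.]])
    B = np.array([[0.], [1.]])
    Q = np.eye(2)
    R = np.eye(1)
    K = np.array([[-0.4, -1.1]])  # an illustrative stabilizing gain
    vf, qf = lqr_evaluation(A, B, Q, R, K)
    x = np.array([1., 0.])
    u = K @ x
    return vf(x), qf(x, u)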
def make_differentiable_planner(true_discount, temperature):
    def param_func(params):
        transition, reward = params

        def _smooth_bellman_operator(i, qf):
            del i
            # Soft Bellman backup: a temperature-scaled log-sum-exp over
            # actions stands in for the hard max.
            return reward + true_discount * np.einsum(
                'ast,t->sa', transition,
                temperature * logsumexp((1. / temperature) * qf, axis=1))

        return _smooth_bellman_operator

    smooth_value_iteration = two_phase_solver(param_func)

    def _planner(params):
        transition_hat, reward_hat = params
        q0 = np.zeros_like(reward_hat)
        solution = smooth_value_iteration(q0, (transition_hat, reward_hat))
        return softmax((1. / temperature) * solution.value)

    return _planner
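# A minimal usage sketch (the MDP below is an illustrative assumption, not
# from the original source): a two-state, two-action MDP whose transition
# tensor is indexed (action, state, next_state) to match the 'ast,t->sa'
# einsum above.
def _demo_planner():
    transition = np.array([
        [[0.9, 0.1], [0.1, 0.9]],  # action 0
        [[0.5, 0.5], [0.5, 0.5]],  # action 1
    ])
    reward = np.array([[1., 0.],
                       [0., 1.]])  # shape (num_states, num_actions)
    planner = make_differentiable_planner(true_discount=0.9, temperature=0.1)
    policy = planner((transition, reward))  # softmax policy, shape (S, A)
    return policy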
# `to_inf` / `to_bounded_vec` below are the bounded <-> unbounded parameter
# transforms defined elsewhere in this package.
def get_solvers(model_constructor,
                pdf_transform=False,
                default_rtol=1e-10,
                default_atol=1e-10,
                default_max_iter=int(1e7),
                learning_rate=1e-6):
    '''
    Wraps a series of functions that perform maximum likelihood fitting in
    the `two_phase_solver` method found in the `fax` python module. This
    allows for the calculation of gradients of the best-fit parameters with
    respect to upstream parameters that control the underlying model, i.e.
    the event yields (which are then parameterized by weights or similar).

    Args:
        model_constructor: Function that takes in the parameters of the
            observable, and returns a model object (and background-only
            parameters).

    Returns:
        g_fitter, c_fitter: Callable functions that perform global and
            constrained fits respectively. Differentiable :)
    '''
    adam_init, adam_update, adam_get_params = optimizers.adam(learning_rate)

    def make_model(hyper_pars):
        constrained_mu, nn_pars = hyper_pars[0], hyper_pars[1]
        m, bonlypars = model_constructor(nn_pars)
        bounds = m.config.suggested_bounds()
        constrained_mu = (to_inf(constrained_mu, bounds[0])
                          if pdf_transform else constrained_mu)
        exp_bonly_data = m.expected_data(bonlypars, include_auxdata=True)

        def expected_logpdf(pars):
            # maps pars to bounded space if pdf_transform is True
            return (m.logpdf(to_bounded_vec(pars, bounds), exp_bonly_data)
                    if pdf_transform else m.logpdf(pars, exp_bonly_data))

        def global_fit_objective(pars):  # NLL
            return -expected_logpdf(pars)[0]

        def constrained_fit_objective(nuis_par):  # NLL with mu held fixed
            pars = jax.numpy.concatenate(
                [jax.numpy.asarray([constrained_mu]), nuis_par])
            return -expected_logpdf(pars)[0]

        return constrained_mu, global_fit_objective, constrained_fit_objective, bounds

    def global_bestfit_minimized(hyper_param):
        _, nll, _, _ = make_model(hyper_param)

        def bestfit_via_grad_descent(i, param):  # gradient descent
            g = jax.grad(nll)(param)
            # param = param - g * learning_rate
            # NB: `adam_init(param)` rebuilds the optimizer state from the
            # current iterate, so Adam's moment estimates are not carried
            # across fixed-point iterations.
            param = adam_get_params(adam_update(i, g, adam_init(param)))
            return param

        return bestfit_via_grad_descent

    def constrained_bestfit_minimized(hyper_param):
        mu, nll, cnll, bounds = make_model(hyper_param)

        def bestfit_via_grad_descent(i, param):  # gradient descent
            _, nuis = param[0], param[1:]
            g = jax.grad(cnll)(nuis)
            nuis = adam_get_params(adam_update(i, g, adam_init(nuis)))
            param = jax.numpy.concatenate([jax.numpy.asarray([mu]), nuis])
            return param

        return bestfit_via_grad_descent

    global_solve = twophase.two_phase_solver(
        param_func=global_bestfit_minimized,
        default_rtol=default_rtol,
        default_atol=default_atol,
        default_max_iter=default_max_iter)

    constrained_solver = twophase.two_phase_solver(
        param_func=constrained_bestfit_minimized,
        default_rtol=default_rtol,
        default_atol=default_atol,
        default_max_iter=default_max_iter,
    )

    def g_fitter(init, hyper_pars):
        solve = global_solve(init, hyper_pars)
        return solve.value

    def c_fitter(init, hyper_pars):
        solve = constrained_solver(init, hyper_pars)
        return solve.value

    return g_fitter, c_fitter
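# A hypothetical usage sketch (the constructor and parameter values are
# assumptions, not from the original source): `model_from_pars` stands in for
# any function mapping upstream parameters (e.g. neural-network outputs) to a
# pyhf-style model plus its background-only parameters. `hyper_pars` packs
# the test value of mu together with those upstream parameters, matching
# `make_model` above.
def _demo_get_solvers(model_from_pars):
    g_fitter, c_fitter = get_solvers(model_from_pars, pdf_transform=True)
    nn_pars = jax.numpy.asarray([5., 50.])  # illustrative upstream parameters
    init = jax.numpy.asarray([1., 1.])      # (mu, nuisance) starting point
    global_mle = g_fitter(init, [1.0, nn_pars])
    constrained_mle = c_fitter(init, [1.0, nn_pars])
    return global_mle, constrained_mle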
def global_fit(
        model_constructor,
        pdf_transform=False,
        default_rtol=1e-10,
        default_atol=1e-10,
        default_max_iter=int(1e7),
        learning_rate=1e-6,
):
    """
    Wraps a series of functions that perform maximum likelihood fitting in
    the `two_phase_solver` method found in the `fax` python module. This
    allows for the calculation of gradients of the best-fit parameters with
    respect to upstream parameters that control the underlying model, i.e.
    the event yields (which are then parameterized by weights or similar).

    Args:
        model_constructor: Function that takes in the parameters of the
            observable, and returns a model object (and background-only
            parameters).

    Returns:
        global_fitter: Callable function that performs global fits.
            Differentiable :)
    """
    adam_init, adam_update, adam_get_params = optimizers.adam(learning_rate)

    def make_model(model_pars):
        m, bonlypars = model_constructor(model_pars)
        bounds = m.config.suggested_bounds()
        exp_bonly_data = m.expected_data(bonlypars, include_auxdata=True)

        def expected_logpdf(pars):
            # maps pars to bounded space if pdf_transform is True
            return (m.logpdf(to_bounded_vec(pars, bounds), exp_bonly_data)
                    if pdf_transform else m.logpdf(pars, exp_bonly_data))

        def global_fit_objective(pars):  # NLL
            return -expected_logpdf(pars)[0]

        return global_fit_objective

    def global_bestfit_minimized(hyper_param):
        nll = make_model(hyper_param)

        def bestfit_via_grad_descent(i, param):  # gradient descent
            g = jax.grad(nll)(param)
            # param = param - g * learning_rate
            param = adam_get_params(adam_update(i, g, adam_init(param)))
            return param

        return bestfit_via_grad_descent

    global_solve = twophase.two_phase_solver(
        param_func=global_bestfit_minimized,
        default_rtol=default_rtol,
        default_atol=default_atol,
        default_max_iter=default_max_iter,
    )

    def global_fitter(init, hyper_pars):
        solve = global_solve(init, hyper_pars)
        return solve.value

    return global_fitter
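# A hypothetical end-to-end sketch (assumed names and values, not from the
# original source): because the fit is wrapped in fax's two-phase solver, any
# scalar function of the best-fit parameters can be differentiated w.r.t. the
# upstream model parameters.
def _demo_global_fit(model_from_pars):
    fitter = global_fit(model_from_pars, pdf_transform=True)

    def first_bestfit_par(model_pars):
        init = jax.numpy.asarray([1., 1.])
        return fitter(init, model_pars)[0]

    model_pars = jax.numpy.asarray([5., 50.])  # illustrative yields
    return jax.value_and_grad(first_bestfit_par)(model_pars)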