from functools import partial

import numpy as np
import jax.numpy as jnp
from jax import jit, value_and_grad
from jax.scipy.stats import gamma, norm
from scipy.optimize import minimize

# flatten_and_summarise, reconstruct, apply_transformation,
# calculate_likelihood and calculate_kl are project-local helpers assumed to
# be imported from the surrounding package.


def fit(
    fg_covs,
    bg_covs,
    species_ids,
    quad_weights,
    counts,
    fg_covs_thin=None,
    bg_covs_thin=None,
):
    n_c = fg_covs.shape[1]
    n_s = len(np.unique(species_ids))
    n_c_thin = 0 if fg_covs_thin is None else fg_covs_thin.shape[1]
    n_fg = fg_covs.shape[0]
    n_bg = bg_covs.shape[0]

    # Use empty thinning covariates when none are provided.
    fg_covs_thin = fg_covs_thin if fg_covs_thin is not None else jnp.zeros((n_fg, 0))
    bg_covs_thin = bg_covs_thin if bg_covs_thin is not None else jnp.zeros((n_bg, 0))

    init_theta = {
        "w_means": jnp.zeros((n_c, n_s)),
        "log_w_vars": jnp.log(jnp.tile(1.0 / n_c, (n_c, n_s))) - 5,
        "intercept_means": jnp.zeros(n_s) - 5,
        "log_intercept_vars": jnp.zeros(n_s) - 5,
        "w_prior_mean": jnp.zeros((n_c, 1)),
        "log_w_prior_var": jnp.log(jnp.tile(1.0 / n_c, (n_c, 1))) - 5,
        # Thinning is assumed constant across species
        "w_means_thin": jnp.zeros((n_c_thin, 1)),
        "log_w_vars_thin": jnp.zeros((n_c_thin, 1)) - 10
        if n_c_thin == 0
        else jnp.log(jnp.tile(1.0 / n_c_thin, (n_c_thin, 1))) - 10,
        "log_w_prior_var_thin": jnp.array(0.0)
        if n_c_thin == 0
        else jnp.log(1.0 / n_c_thin) - 10,
    }

    flat_theta, summary = flatten_and_summarise(**init_theta)

    def to_minimize(flat_theta):
        theta = reconstruct(flat_theta, summary, jnp.reshape)
        # Map the unconstrained log_* parameters back to positive scale.
        theta = apply_transformation(theta, "log_", jnp.exp, "")

        lik = calculate_likelihood(
            theta,
            species_ids,
            fg_covs,
            bg_covs,
            fg_covs_thin,
            bg_covs_thin,
            quad_weights,
            counts,
            n_s,
            n_fg,
        )

        kl = calculate_kl(theta)

        prior = jnp.sum(gamma.logpdf(theta["w_prior_var"], 0.5, scale=1.0 / n_c))
        prior = prior + jnp.sum(
            norm.logpdf(theta["w_prior_mean"], 0.0, scale=jnp.sqrt(1.0 / n_c))
        )

        # Negative ELBO (plus hyperprior on the prior mean and variance),
        # to be minimised.
        return -(lik - kl + prior)

    with_grad = jit(value_and_grad(to_minimize))

    def annotated_with_grad(flat_theta, summary):
        flat_theta = jnp.array(flat_theta)
        obj, grad = with_grad(flat_theta)
        print(obj, jnp.linalg.norm(grad))

        if jnp.isnan(obj) or jnp.isinf(obj) or jnp.any(jnp.isnan(grad)):
            # Drop into the debugger on numerical problems; `problem` holds
            # the offending parameters for inspection.
            import ipdb

            problem = reconstruct(flat_theta, summary, jnp.reshape)
            ipdb.set_trace()

        # scipy's optimiser expects float64 numpy arrays.
        return np.array(obj).astype(np.float64), np.array(grad).astype(np.float64)

    result = minimize(
        partial(annotated_with_grad, summary=summary),
        flat_theta,
        method="L-BFGS-B",
        jac=True,
    )

    final_theta = reconstruct(result.x, summary, jnp.reshape)
    final_theta = apply_transformation(final_theta, "log_", jnp.exp, "")

    return final_theta
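
# Minimal, self-contained sketch of the optimisation pattern fit() uses:
# jit(value_and_grad(...)) yields the objective and its gradient in one pass,
# and scipy's L-BFGS-B consumes both via jac=True. The quadratic objective
# and the _demo_* names below are illustrative stand-ins, not part of this
# module.
def _demo_lbfgs_pattern():
    def objective(flat_theta):
        # Stand-in for the negative ELBO: a quadratic with its minimum at 1.
        return jnp.sum((flat_theta - 1.0) ** 2)

    obj_with_grad = jit(value_and_grad(objective))

    def scipy_fun(flat_theta):
        obj, grad = obj_with_grad(jnp.array(flat_theta))
        # scipy expects float64 numpy values, hence the casts (the same
        # conversion annotated_with_grad performs above).
        return np.array(obj, dtype=np.float64), np.array(grad, dtype=np.float64)

    result = minimize(scipy_fun, np.zeros(3), method="L-BFGS-B", jac=True)
    return result.x  # approximately [1., 1., 1.]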
# sv, find_starting_z, get_kernel_fun, constrain_positive, convert_decorator
# and expectation_1d are project-local helpers assumed to be imported from
# the surrounding package. Note this fit() targets a sparse variational GP
# rather than the hierarchical linear model above.
def fit(
    X,
    init_kernel_params,
    kernel_currier,
    likelihood_fun,
    prior_fun,
    transformation_fun=constrain_positive,
    n_inducing=100,
    verbose=False,
    Z=None,
):
    if Z is None:
        Z = find_starting_z(X, n_inducing)

    init_kern_fn = get_kernel_fun(
        kernel_currier, init_kernel_params, transformation_fun
    )
    init_spec = sv.initialise_using_kernel_fun(init_kern_fn, Z)

    theta = {
        "mu": init_spec.m,
        "L_elts": init_spec.L_elts,
        "Z": jnp.array(Z),
        **init_kernel_params,
    }

    flat_theta, summary = flatten_and_summarise(**theta)

    def to_minimize(flat_theta):
        theta = reconstruct(flat_theta, summary, jnp.reshape)
        kern_fn = get_kernel_fun(kernel_currier, theta, transformation_fun)

        spec = sv.SVGPSpec(
            m=theta["mu"], L_elts=theta["L_elts"], Z=theta["Z"], kern_fn=kern_fn
        )

        # Marginal predictive mean and variance of the GP at the inputs X.
        pred_mean, pred_var = sv.project_to_x(spec, X)
        kl = sv.calculate_kl(spec)

        # Expected log likelihood under the Gaussian marginals.
        lik = jnp.sum(expectation_1d(likelihood_fun, pred_mean, pred_var))
        prior = prior_fun(transformation_fun(theta))

        # Negative ELBO (plus hyperprior), to be minimised.
        return -(lik - kl + prior)

    with_grad = partial(convert_decorator, verbose=verbose)(
        jit(value_and_grad(to_minimize))
    )

    result = minimize(with_grad, flat_theta, method="L-BFGS-B", jac=True)

    final_theta = reconstruct(result.x, summary, jnp.reshape)
    kern = get_kernel_fun(kernel_currier, final_theta, transformation_fun)

    final_spec = sv.SVGPSpec(
        m=final_theta["mu"],
        L_elts=final_theta["L_elts"],
        Z=final_theta["Z"],
        kern_fn=kern,
    )

    return final_spec, transformation_fun(final_theta)
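
# Hedged sketch of the role expectation_1d is assumed to play above:
# approximating E_{f ~ N(mean, var)}[likelihood_fun(f)] for each Gaussian
# marginal, e.g. via Gauss-Hermite quadrature. This is an assumption about
# the helper's behaviour, not its actual implementation.
def _demo_gauss_hermite(fun, mean, var, n_points=20):
    # Nodes/weights for int e^{-x^2} g(x) dx; substitute f = mean + sqrt(2 var) x.
    x, w = np.polynomial.hermite.hermgauss(n_points)
    f = mean[..., None] + jnp.sqrt(2.0 * var)[..., None] * x
    return jnp.sum(w * fun(f), axis=-1) / jnp.sqrt(jnp.pi)


# Sanity check: E[f^2] under N(1, 4) is mean^2 + var = 5.
# _demo_gauss_hermite(lambda f: f ** 2, jnp.array([1.0]), jnp.array([4.0]))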