# Define variables to be traced

    trace_func = utils.construct_trace_func(generate_params,
                                            data,
                                            dim_u,
                                            dim_v=dim_y)

    # Run experiment

    final_states, traces, stats, summary_dict, sampler = utils.run_experiment(
        args=args,
        data=data,
        rng=rng,
        experiment_name="robust_gp_regression",
        dir_prefix=f"{args.dataset}_data_subsampled_by_{args.data_subsample}",
        var_names=list(prior_specifications.keys()),
        var_trace_func=trace_func,
        posterior_neg_log_dens=posterior_neg_log_dens,
        extended_prior_neg_log_dens=extended_prior_neg_log_dens,
        constrained_system_class=GeneralGaussianProcessModelSystem,
        constrained_system_kwargs={
            "covar_func": covar_func,
            "noise_scale_func": noise_scale_func,
            "noise_transform_func": noise_transform_func,
            "data": data,
            "dim_u": dim_u,
        },
        sample_initial_states=sample_initial_states,
    )
    # Define variables to be traced

    jitted_generate_params = api.jit(api.partial(generate_params, data=data))

    def trace_func(state):
        """Build the dictionary of values recorded for a chain state.

        The leading ``dim_u`` entries of ``state.pos`` form ``u``, which is
        passed through the jitted parameter generator (with ``data`` bound).
        The generated parameters are returned together with ``u`` itself.
        """
        latent_u = state.pos[:dim_u]
        generated = jitted_generate_params(latent_u)
        return dict(generated, u=latent_u)

    # Run experiment

    final_states, traces, stats, summary_dict, sampler = utils.run_experiment(
        args=args,
        data=data,
        rng=rng,
        experiment_name="fitzhugh_nagumo",
        dir_prefix=f"σ_{args.obs_noise_std:.0e}",
        var_names=list(prior_specifications.keys()),
        var_trace_func=trace_func,
        posterior_neg_log_dens=posterior_neg_log_dens,
        extended_prior_neg_log_dens=extended_prior_neg_log_dens,
        constrained_system_class=IndependentAdditiveNoiseModelSystem,
        constrained_system_kwargs={
            "generate_y": generate_y,
            "data": data,
            "dim_u": dim_u,
        },
        sample_initial_states=sample_initial_states,
    )
    data["y_obs"] = np.concatenate(data["obs_vals_g_Na"])
    dim_u = compute_dim_u(data)

    # Set up seeded random number generator

    rng = onp.random.default_rng(args.seed)

    # Define variables to be traced

    trace_func = utils.construct_trace_func(generate_params, data, dim_u)

    # Run experiment

    final_states, traces, stats, summary_dict, sampler = utils.run_experiment(
        args=args,
        data=data,
        rng=rng,
        experiment_name="hh_voltage_clamp_sodium",
        var_names=list(prior_specifications.keys()),
        var_trace_func=trace_func,
        posterior_neg_log_dens=posterior_neg_log_dens,
        extended_prior_neg_log_dens=extended_prior_neg_log_dens,
        constrained_system_class=IndependentAdditiveNoiseModelSystem,
        constrained_system_kwargs={
            "generate_y": generate_y,
            "data": data,
            "dim_u": dim_u,
        },
        sample_initial_states=sample_initial_states,
    )
Exemplo n.º 4
0
    # Define variables to be traced

    jitted_generate_from_model = api.jit(
        api.partial(generate_from_model, data=data))

    def trace_func(state):
        """Build the dictionary of values recorded for a chain state.

        The position vector is split into ``u`` (its first ``dim_u`` entries)
        and ``v`` (the next ``dim_y`` entries). Both are fed to the jitted
        model generator, which yields the parameters and ``x``. Everything is
        returned in one flat dictionary.
        """
        position = state.pos
        u = position[:dim_u]
        v = position[dim_u:dim_u + dim_y]
        model_params, generated_x = jitted_generate_from_model(u, v)
        return dict(model_params, x=generated_x, u=u, v=v)

    # Run experiment

    final_states, traces, stats, summary_dict, sampler = utils.run_experiment(
        args=args,
        data=data,
        rng=rng,
        experiment_name="eight_schools",
        var_names=list(prior_specifications.keys()) + ["x"],
        var_trace_func=trace_func,
        posterior_neg_log_dens=posterior_neg_log_dens,
        extended_prior_neg_log_dens=extended_prior_neg_log_dens,
        constrained_system_class=HierarchicalLatentVariableModelSystem,
        constrained_system_kwargs={
            "generate_y": generate_y,
            "data": data,
            "dim_u": dim_u,
        },
        sample_initial_states=sample_initial_states,
    )
Exemplo n.º 5
0
        generate_σ=generate_σ,
        grad_generate_σ=grad_generate_σ,
        hessian_generate_σ=hessian_generate_σ,
    )

    euclidean_system_kwargs = pde.construct_euclidean_system_kwargs(
        y_obs=data["y_obs"],
        forward_func=forward_func,
        vjp_forward_func=vjp_forward_func,
        dim_y=dim_y,
        dim_z=dim_z,
        generate_σ=generate_σ,
        grad_generate_σ=grad_generate_σ,
    )

    final_states, traces, stats, summary_dict, sampler = utils.run_experiment(
        args=args,
        data=data,
        rng=rng,
        experiment_name="poisson",
        dir_prefix=f"σ_{args.obs_noise_std:.0e}",
        var_names=["σ", "z_mean", "z_std"],
        var_trace_func=trace_func,
        constrained_system_class=mici.systems.DenseConstrainedEuclideanMetricSystem,
        constrained_system_kwargs=constrained_system_kwargs,
        euclidean_system_class=mici.systems.EuclideanMetricSystem,
        euclidean_system_kwargs=euclidean_system_kwargs,
        sample_initial_states=sample_initial_states,
        precompile_jax_functions=False
    )
    # Run experiment

    (
        constrained_system_class,
        constrained_system_kwargs,
    ) = utils.get_ssm_constrained_system_class_and_kwargs(
        args.use_manual_constraint_and_jacobian,
        generate_params,
        generate_x_0,
        forward_func,
        inverse_observation_func,
        constr_split,
        jacob_constr_split_blocks,
    )
    constrained_system_kwargs.update(data=data, dim_u=dim_u)

    final_states, traces, stats, summary_dict, sampler = utils.run_experiment(
        args=args,
        data=data,
        rng=rng,
        experiment_name="stochastic_volatility",
        var_names=list(prior_specifications.keys()),
        var_trace_func=trace_func,
        posterior_neg_log_dens=posterior_neg_log_dens,
        extended_prior_neg_log_dens=extended_prior_neg_log_dens,
        constrained_system_class=constrained_system_class,
        constrained_system_kwargs=constrained_system_kwargs,
        sample_initial_states=sample_initial_states,
    )

Exemplo n.º 7
0
    # Run experiment

    (
        constrained_system_class,
        constrained_system_kwargs,
    ) = utils.get_ssm_constrained_system_class_and_kwargs(
        args.use_manual_constraint_and_jacobian,
        generate_params,
        generate_x_0,
        forward_func,
        inverse_observation_func,
        constr_split,
        jacob_constr_split_blocks,
    )
    constrained_system_kwargs.update(data=data, dim_u=dim_u)

    final_states, traces, stats, summary_dict, sampler = utils.run_experiment(
        args=args,
        data=data,
        rng=rng,
        experiment_name="nonlinear_ssm",
        dir_prefix=f"σ_{args.obs_noise_std:.0e}",
        var_names=list(prior_specifications.keys()),
        var_trace_func=trace_func,
        posterior_neg_log_dens=posterior_neg_log_dens,
        extended_prior_neg_log_dens=extended_prior_neg_log_dens,
        constrained_system_class=constrained_system_class,
        constrained_system_kwargs=constrained_system_kwargs,
        sample_initial_states=sample_initial_states,
    )
Exemplo n.º 8
0
    # Define variables to be traced

    jitted_generate_from_model = api.jit(
        api.partial(generate_from_model, data=data))

    def trace_func(state):
        """Build the dictionary of values recorded for a chain state.

        The leading ``dim_u`` entries of ``state.pos`` form ``u``. The jitted
        model generator maps ``u`` to the parameters and ``x``, and all three
        are returned together in one flat dictionary.
        """
        latent_u = state.pos[:dim_u]
        model_params, generated_x = jitted_generate_from_model(latent_u)
        return dict(model_params, x=generated_x, u=latent_u)

    # Run experiment

    final_states, traces, stats, summary_dict, sampler = utils.run_experiment(
        args=args,
        data=data,
        rng=rng,
        experiment_name="garch",
        var_names=["μ", "α_0", "α_1", "β_1"],
        var_trace_func=trace_func,
        posterior_neg_log_dens=posterior_neg_log_dens,
        extended_prior_neg_log_dens=extended_prior_neg_log_dens,
        constrained_system_class=IndependentAdditiveNoiseModelSystem,
        constrained_system_kwargs={
            "generate_y": generate_y,
            "data": data,
            "dim_u": dim_u,
        },
        sample_initial_states=sample_initial_states,
    )