method = 'VI'
        lr = 0.5
        vb_samples = run_inference(Y, ProteinGenModel, pr_ode_model, method, \
            iterations = args.iterations, num_samples = args.num_qsamples, \
            lr = lr, num_particles = 1, return_sites = ("ode_params","scale","_RETURN"))
        vb_params = \
        np.concatenate((vb_samples['ode_params'].detach().numpy().reshape((args.num_qsamples,6)), \
            vb_samples['scale'].detach().numpy().reshape((args.num_qsamples,1))),axis=1)

        plot_marginals(vb_params,
                       mc_params,
                       param_names,
                       real_params=real_params)
    else:
        print('Using VJP by Adjoint Sensitivity')
        pr_ode_model = AdjointSensManualJacobians(rhs_f, jac_x_f, jac_p_f, 5, 6, \
            times, 1e-5, 1e-6, [1,0,1,0,0])

        method = 'NUTS'
        NUTS_samples = run_inference(Y, ProteinGenModel, pr_ode_model, method, \
            iterations = args.num_samples, warmup_steps = args.warmup_steps)
        mc_params = np.concatenate(
            (NUTS_samples['ode_params'], NUTS_samples['scale'][:, None]),
            axis=1)

        method = 'VI'
        lr = 0.5
        vb_samples = run_inference(Y, ProteinGenModel, pr_ode_model, method, \
            iterations = args.iterations, num_samples = args.num_qsamples, \
            lr = lr, num_particles = 1, return_sites = ("ode_params","scale","_RETURN"))
        vb_params = \
        np.concatenate((vb_samples['ode_params'].detach().numpy().reshape((args.num_qsamples,6)), \
Beispiel #2
0
        method = 'VI'
        lr = 0.5
        vb_samples = run_inference(Y, SIRGenModel, sir_ode_model, method, \
                                   iterations=args.iterations, num_samples=args.num_qsamples, \
                                   lr=lr, num_particles=1, return_sites=("ode_params1", "ode_params2", "ode_params3"))
        vb_params = np.concatenate(
            (vb_samples['ode_params1'][:, None].detach().numpy(),
             vb_samples['ode_params2'][:, None].detach().numpy(),
             vb_samples['ode_params3'][:, None].detach().numpy()),
            axis=1)

        # plot_marginals(vb_params, mc_params, param_names, rows=2)
        plot_marginals(vb_params, vb_params, param_names, rows=2)
    else:
        print('Using VJP by Adjoint Sensitivity')
        sir_ode_model = AdjointSensManualJacobians(rhs_f, jac_x_f, jac_p_f, 3, 5, \
                                                   times, 1e-5, 1e-6, [0.9, 0.1, 0.0])
        sir_ode_model.set_unknown_y0()
        # method = 'NUTS'
        # NUTS_samples = run_inference(Y, SIRGenModel, sir_ode_model, method, \
        #     iterations = args.num_samples, warmup_steps = args.warmup_steps)
        # mc_params=np.concatenate((NUTS_samples['ode_params1'][:,None],
        #                     NUTS_samples['ode_params2'][:,None],
        #                     NUTS_samples['ode_params3'][:,None]
        #                     ),axis=1)

        method = 'VI'
        lr = 0.5
        vb_samples = run_inference(Y, SIRGenModel, sir_ode_model, method, \
                                   iterations=args.iterations, num_samples=args.num_qsamples, \
                                   lr=lr, num_particles=1, return_sites=("ode_params1", "ode_params2", "ode_params3"))
        vb_params = np.concatenate(
Beispiel #3
0
    print('Using VJP by Forward Sensitivity')
    lna_ode_model = ForwardSensManualJacobians(rhs_f, jac_x_f, jac_p_f, 6, 3, \
        times, 1e-5, 1e-6, [100, 100, 0, 0 ,0 ,0])
    method = 'VI'
    lr = 0.5
    vb_samples = run_inference(Y, LNAGenModel, lna_ode_model, method, iterations=args.iterations, \
        lr = lr, num_particles = 1, num_samples = args.num_qsamples, \
            return_sites = ("ode_params1","ode_params2","ode_params3"))
    vb_params_for = np.concatenate(
        (vb_samples['ode_params1'][:, None].detach().numpy(),
         vb_samples['ode_params2'][:, None].detach().numpy(),
         vb_samples['ode_params3'][:, None].detach().numpy()),
        axis=1)

    print('Using VJP by Adjoint Sensitivity')
    lna_ode_model = AdjointSensManualJacobians(rhs_f, jac_x_f, jac_p_f, 6, 3, \
    times, 1e-5, 1e-6, [100, 100, 0, 0 ,0 ,0])

    vb_samples = run_inference(Y, LNAGenModel, lna_ode_model, method, iterations=args.iterations, \
        lr = lr, num_particles = 1, num_samples = args.num_qsamples, \
            return_sites = ("ode_params1","ode_params2","ode_params3"))
    vb_params_adj = np.concatenate(
        (vb_samples['ode_params1'][:, None].detach().numpy(),
         vb_samples['ode_params2'][:, None].detach().numpy(),
         vb_samples['ode_params3'][:, None].detach().numpy()),
        axis=1)

    plot_marginals(vb_params_for,
                   vb_params_adj,
                   param_names,
                   real_params=real_params,
                   rows=2)
Beispiel #4
0

# Labels for the SIR posterior plots; sample-array columns are ordered
# (beta, gamma, s_0).
param_names = [r"$\beta$", r"$\gamma$", r"$s_0$"]
plot_marginals(vb_for, mc_for, param_names, './figures/ppc_sir/sir_marginals_for')
plot_marginals(vb_adj, mc_adj, param_names, './figures/ppc_sir/sir_marginals_adj')


pairwise(vb_for, parameter_names=param_names, saveto='./figures/ppc_sir/sir_pairwise_vb_for.png', nbins=100)
# NOTE: pairwise plots for the adjoint-sensitivity runs are currently disabled.
#pairwise(vb_adj, parameter_names=param_names, saveto='./figures/ppc_sir/sir_pairwise_vb_adj.png', nbins=100)
pairwise(mc_for, parameter_names=param_names, saveto='./figures/ppc_sir/sir_pairwise_mc_for.png', nbins=100)
#pairwise(mc_adj, parameter_names=param_names, saveto='./figures/ppc_sir/sir_pairwise_mc_adj.png', nbins=100)


sir_ode_model_for = ForwardSensManualJacobians(rhs_f, jac_x_f, jac_p_f, 3, 5, \
        times, 1e-5, 1e-6, [0.9, 0.1, 0.0])
sir_ode_model_adj = AdjointSensManualJacobians(rhs_f, jac_x_f, jac_p_f, 3, 5, \
        times, 1e-5, 1e-6, [0.9, 0.1, 0.0])
sir_ode_model_adj.set_checkpointed()

# Multiplier applied to solution column 1 for the posterior predictive draws.
# NOTE(review): presumably the population size — confirm.
PPC_SCALE = 300
# Upper bound on the number of posterior samples pushed through the solver.
N_PPC = 1000


def _sir_ppc_trajectory(ode_model, theta, scale=PPC_SCALE):
    """Solve ``ode_model`` for one posterior sample and return scaled column 1.

    ``theta`` is one row of a sample array laid out as (beta, gamma, s_0).
    The initial state is set to ``[s_0, 1 - s_0, 0]``. The first two entries
    are passed to ``solve``. Column 1 of the solution is multiplied by
    ``scale``.
    """
    s0 = theta[2]
    ode_model.set_y0([s0, 1 - s0, 0])
    return ode_model.solve(theta[:2])[:, 1] * scale


# Cap at the number of available samples so shorter sample arrays do not
# raise IndexError.
n_ppc = min(N_PPC, len(mc_for), len(vb_for), len(mc_adj), len(vb_adj))
mc_for_ppc = []
vb_for_ppc = []
mc_adj_ppc = []
vb_adj_ppc = []
for i in range(n_ppc):
    # Order of set_y0/solve calls per model matches MC-then-VB, forward-then-adjoint.
    mc_for_ppc.append(_sir_ppc_trajectory(sir_ode_model_for, mc_for[i]))
    vb_for_ppc.append(_sir_ppc_trajectory(sir_ode_model_for, vb_for[i]))

    mc_adj_ppc.append(_sir_ppc_trajectory(sir_ode_model_adj, mc_adj[i]))
    vb_adj_ppc.append(_sir_ppc_trajectory(sir_ode_model_adj, vb_adj[i]))
Beispiel #5
0
                       ncol=2,
                       fontsize=18)
    plt.subplots_adjust(hspace=0.7)
    plt.tight_layout()
    plt.savefig(plot_name + '.eps')
    plt.close()


_rhs = r
_y, _p = sym.symbols('y:5'), sym.symbols('p:6')
rhs_f, jac_x_f, jac_p_f = prepare_symbolic(_rhs, _y, _p)
times = np.array([0, 1, 2, 4, 5, 7, 10, 15, 20, 30, 40, 50, 60, 80, 100])

# Initial state for both sensitivity solvers. Each solver receives its own
# copy so neither can observe in-place changes made by the other.
_y0 = [1, 0, 1, 0, 0]
# State/parameter counts are taken from the declared symbol tuples so they
# cannot drift out of sync with the symbolic system.
pr_ode_model_for = ForwardSensManualJacobians(rhs_f, jac_x_f, jac_p_f, len(_y), len(_p), \
    times, 1e-5, 1e-6, list(_y0))
pr_ode_model_adj = AdjointSensManualJacobians(rhs_f, jac_x_f, jac_p_f, len(_y), len(_p), \
    times, 1e-5, 1e-6, list(_y0))
pr_ode_model_adj.set_checkpointed()

# Synthetic observations: the forward solution at the true parameters plus
# Gaussian noise with standard deviation sigma (fixed seed for reproducibility).
sigma = 0.01
real_params = [0.07, 0.6, 0.05, 0.3, 0.017, 0.3]
sol = pr_ode_model_for.solve(real_params)
np.random.seed(121)
Y = sol + np.random.randn(len(times), len(_y)) * sigma


def _load_pickle(path):
    """Return the object unpickled from ``path``, closing the file afterwards.

    NOTE(review): unpickling can execute arbitrary code; only use this on
    result files produced by this project.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


param_filename = './results/pr_vi_for.p'
vb_for = _load_pickle(param_filename)
param_filename = './results/pr_vi_adj.p'
vb_adj = _load_pickle(param_filename)
param_filename = './results/pr_hmc_for.p'
# Keep every second sample row.
mc_for = _load_pickle(param_filename)[::2, :]
# Rebound last so code after this point sees the adjoint HMC results path.
param_filename = './results/pr_hmc_adj.p'
Beispiel #6
0
                                   iterations=args.iterations,
                                   num_samples=args.num_qsamples,
                                   lr=lr,
                                   num_particles=num_particles,
                                   return_sites=("ode_params1", "ode_params2"))
        vb_params = np.concatenate(
            (vb_samples['ode_params1'][:, None].detach().numpy(),
             vb_samples['ode_params2'][:, None].detach().numpy()),
            axis=1)

        # plot_marginals(vb_params, mc_params, param_names, rows=2)
        plot_marginals(vb_params, vb_params, param_names, rows=2)
    else:
        print('Using VJP by Adjoint Sensitivity')
        plant_ode_model = AdjointSensManualJacobians(rhs_f, jac_x_f, jac_p_f,
                                                     2, 2, times, 1e-5, 1e-6,
                                                     [0.237939, 0.021049])

        # plant_ode_model.set_unknown_y0()
        # method = 'NUTS'
        # NUTS_samples = run_inference(Y, SIRGenModel, sir_ode_model, method, \
        #     iterations = args.num_samples, warmup_steps = args.warmup_steps)
        # mc_params=np.concatenate((NUTS_samples['ode_params1'][:,None],
        #                     NUTS_samples['ode_params2'][:,None],
        #                     NUTS_samples['ode_params3'][:,None]
        #                     ),axis=1)

        method = 'VI'
        lr = 0.5
        num_particles = 1
        vb_samples = run_inference(Y,