# NOTE(review): whitespace-collapsed fragment of the protein-generation ODE
# script. It starts inside an if-branch whose header is not visible and ends
# mid-expression (an open np.concatenate), so it is documented, not rewritten.
# Visible flow:
#  - tail of the first branch: a VI fit via run_inference(..., method='VI',
#    lr=0.5, num_particles=1) returning sites "ode_params", "scale", "_RETURN";
#    the 'ode_params' samples (reshaped to 6 columns) and 'scale' (1 column)
#    are stacked into vb_params and plotted against mc_params.
#  - else-branch (prints 'Using VJP by Adjoint Sensitivity'): builds
#    AdjointSensManualJacobians(rhs_f, jac_x_f, jac_p_f, 5, 6, times, 1e-5,
#    1e-6, [1,0,1,0,0]), draws NUTS samples (iterations=args.num_samples,
#    warmup_steps=args.warmup_steps) into mc_params, then repeats the same VI
#    fit verbatim.
# The hard-coded reshape width 6 presumably has to match the 6 passed as the
# fifth constructor argument (number of ODE parameters) -- TODO confirm.
# The VI block is duplicated in both branches; it could be hoisted after the
# if/else.
method = 'VI' lr = 0.5 vb_samples = run_inference(Y, ProteinGenModel, pr_ode_model, method, \ iterations = args.iterations, num_samples = args.num_qsamples, \ lr = lr, num_particles = 1, return_sites = ("ode_params","scale","_RETURN")) vb_params = \ np.concatenate((vb_samples['ode_params'].detach().numpy().reshape((args.num_qsamples,6)), \ vb_samples['scale'].detach().numpy().reshape((args.num_qsamples,1))),axis=1) plot_marginals(vb_params, mc_params, param_names, real_params=real_params) else: print('Using VJP by Adjoint Sensitivity') pr_ode_model = AdjointSensManualJacobians(rhs_f, jac_x_f, jac_p_f, 5, 6, \ times, 1e-5, 1e-6, [1,0,1,0,0]) method = 'NUTS' NUTS_samples = run_inference(Y, ProteinGenModel, pr_ode_model, method, \ iterations = args.num_samples, warmup_steps = args.warmup_steps) mc_params = np.concatenate( (NUTS_samples['ode_params'], NUTS_samples['scale'][:, None]), axis=1) method = 'VI' lr = 0.5 vb_samples = run_inference(Y, ProteinGenModel, pr_ode_model, method, \ iterations = args.iterations, num_samples = args.num_qsamples, \ lr = lr, num_particles = 1, return_sites = ("ode_params","scale","_RETURN")) vb_params = \ np.concatenate((vb_samples['ode_params'].detach().numpy().reshape((args.num_qsamples,6)), \
# Fit the LNA model with VI twice: once with forward-sensitivity VJPs and once
# with adjoint-sensitivity VJPs. Then overlay the two sets of marginals so the
# gradient methods can be compared.
# Both runs are driven from one loop (they were copy-pasted), so the solver
# settings, the VI hyper-parameters and the sample stacking stay identical
# between the two configurations.
method = 'VI'
lr = 0.5
lna_sites = ("ode_params1", "ode_params2", "ode_params3")
lna_runs = (
    ('for', 'Using VJP by Forward Sensitivity', ForwardSensManualJacobians),
    ('adj', 'Using VJP by Adjoint Sensitivity', AdjointSensManualJacobians),
)
vb_params_by_sens = {}
for sens_key, sens_message, sens_cls in lna_runs:
    print(sens_message)
    # Positional args as in the original script: 6 states (matches the
    # 6-element initial condition) and 3 parameters (matches the 3 sites).
    # 1e-5 / 1e-6 are presumably solver tolerances -- TODO confirm.
    lna_ode_model = sens_cls(rhs_f, jac_x_f, jac_p_f, 6, 3,
                             times, 1e-5, 1e-6, [100, 100, 0, 0, 0, 0])
    vb_samples = run_inference(Y, LNAGenModel, lna_ode_model, method,
                               iterations=args.iterations, lr=lr,
                               num_particles=1,
                               num_samples=args.num_qsamples,
                               return_sites=lna_sites)
    # One column per scalar parameter site, in lna_sites order.
    vb_params_by_sens[sens_key] = np.concatenate(
        [vb_samples[site][:, None].detach().numpy() for site in lna_sites],
        axis=1)
vb_params_for = vb_params_by_sens['for']
vb_params_adj = vb_params_by_sens['adj']
plot_marginals(vb_params_for, vb_params_adj, param_names,
               real_params=real_params, rows=2)
# NOTE(review): whitespace-collapsed fragment of the SIR ODE script. It starts
# inside an if-branch whose header is not visible and ends mid-expression (an
# open np.concatenate), so it is documented, not rewritten.
# Visible flow:
#  - tail of the first branch: a VI fit of SIRGenModel (lr=0.5,
#    num_particles=1) returning sites ode_params1..3, stacked column-wise into
#    vb_params.
#  - plot_marginals(vb_params, vb_params, ...) plots the VI samples against
#    themselves. The NUTS comparison (mc_params) is commented out, so the
#    second set of marginals is a placeholder, not an MCMC reference.
#  - else-branch (prints 'Using VJP by Adjoint Sensitivity'): builds
#    AdjointSensManualJacobians(rhs_f, jac_x_f, jac_p_f, 3, 5, times, 1e-5,
#    1e-6, [0.9, 0.1, 0.0]) and calls set_unknown_y0(). The NUTS run is
#    commented out, and the same VI fit is repeated.
method = 'VI' lr = 0.5 vb_samples = run_inference(Y, SIRGenModel, sir_ode_model, method, \ iterations=args.iterations, num_samples=args.num_qsamples, \ lr=lr, num_particles=1, return_sites=("ode_params1", "ode_params2", "ode_params3")) vb_params = np.concatenate( (vb_samples['ode_params1'][:, None].detach().numpy(), vb_samples['ode_params2'][:, None].detach().numpy(), vb_samples['ode_params3'][:, None].detach().numpy()), axis=1) # plot_marginals(vb_params, mc_params, param_names, rows=2) plot_marginals(vb_params, vb_params, param_names, rows=2) else: print('Using VJP by Adjoint Sensitivity') sir_ode_model = AdjointSensManualJacobians(rhs_f, jac_x_f, jac_p_f, 3, 5, \ times, 1e-5, 1e-6, [0.9, 0.1, 0.0]) sir_ode_model.set_unknown_y0() # method = 'NUTS' # NUTS_samples = run_inference(Y, SIRGenModel, sir_ode_model, method, \ # iterations = args.num_samples, warmup_steps = args.warmup_steps) # mc_params=np.concatenate((NUTS_samples['ode_params1'][:,None], # NUTS_samples['ode_params2'][:,None], # NUTS_samples['ode_params3'][:,None] # ),axis=1) method = 'VI' lr = 0.5 vb_samples = run_inference(Y, SIRGenModel, sir_ode_model, method, \ iterations=args.iterations, num_samples=args.num_qsamples, \ lr=lr, num_particles=1, return_sites=("ode_params1", "ode_params2", "ode_params3")) vb_params = np.concatenate(
# NOTE(review): whitespace-collapsed fragment of the plant ODE script
# (plant_ode_model). It starts mid-call inside an if-branch and ends mid-call,
# so it is documented, not rewritten.
# Visible flow:
#  - tail of a VI fit returning sites ode_params1/ode_params2, stacked
#    column-wise into vb_params. It is plotted against itself because the
#    NUTS comparison is commented out.
#  - else-branch (prints 'Using VJP by Adjoint Sensitivity'): builds
#    AdjointSensManualJacobians(rhs_f, jac_x_f, jac_p_f, 2, 2, times, 1e-5,
#    1e-6, [0.237939, 0.021049]). set_unknown_y0() is commented out here,
#    unlike the SIR script.
#  - The commented-out NUTS block still refers to SIRGenModel, sir_ode_model
#    and three ode_params sites. It looks like a stale copy from the SIR
#    script; if re-enabled, it would need plant_ode_model and the two sites.
iterations=args.iterations, num_samples=args.num_qsamples, lr=lr, num_particles=num_particles, return_sites=("ode_params1", "ode_params2")) vb_params = np.concatenate( (vb_samples['ode_params1'][:, None].detach().numpy(), vb_samples['ode_params2'][:, None].detach().numpy()), axis=1) # plot_marginals(vb_params, mc_params, param_names, rows=2) plot_marginals(vb_params, vb_params, param_names, rows=2) else: print('Using VJP by Adjoint Sensitivity') plant_ode_model = AdjointSensManualJacobians(rhs_f, jac_x_f, jac_p_f, 2, 2, times, 1e-5, 1e-6, [0.237939, 0.021049]) # plant_ode_model.set_unknown_y0() # method = 'NUTS' # NUTS_samples = run_inference(Y, SIRGenModel, sir_ode_model, method, \ # iterations = args.num_samples, warmup_steps = args.warmup_steps) # mc_params=np.concatenate((NUTS_samples['ode_params1'][:,None], # NUTS_samples['ode_params2'][:,None], # NUTS_samples['ode_params3'][:,None] # ),axis=1) method = 'VI' lr = 0.5 num_particles = 1 vb_samples = run_inference(Y,