import warnings

import chainer
import numpy as np
import torch
from tqdm import tqdm


def stochastic_variational_inference(joint_model, number_iterations, number_samples,
                                     optimizer=None, input_values=None):
    """Train the posterior model of `joint_model` by maximizing the ELBO (Chainer backend).

    Parameters
    ----------
    joint_model : ProbabilisticModel
        Joint model with an attached posterior_model and observed data.
    number_iterations : int
        Number of gradient updates.
    number_samples : int
        Number of Monte Carlo samples per ELBO estimate.
    optimizer : chainer.Optimizer, optional
        Optimizer shared by the joint and posterior parameters (default: Adam(0.001)).
    input_values : dict, optional
        Values fed to the input variables of the model.
    """
    # Avoid mutable/stateful defaults in the signature.
    if optimizer is None:
        optimizer = chainer.optimizers.Adam(0.001)
    if input_values is None:
        input_values = {}
    joint_model.update_observed_submodel()  # TODO: Probably not here
    posterior_model = joint_model.posterior_model
    joint_optimizer = ProbabilisticOptimizer(joint_model, optimizer)
    posterior_optimizer = ProbabilisticOptimizer(posterior_model, optimizer)  # TODO: These things should not be here, maybe they should be inherited
    loss_list = []
    for iteration in tqdm(range(number_iterations)):
        # Negative ELBO: minimizing this loss maximizes the evidence lower bound.
        loss = -joint_model.estimate_log_model_evidence(number_samples=number_samples,
                                                        method="ELBO",
                                                        input_values=input_values)
        if np.isfinite(loss.data).all():
            posterior_optimizer.chain.cleargrads()
            joint_optimizer.chain.cleargrads()
            loss.backward()
            joint_optimizer.update()
            posterior_optimizer.update()
            loss_list.append(loss.data)
        else:
            warnings.warn("Numerical error, skipping sample")
    joint_model.diagnostics.update({"loss curve": np.array(loss_list)})
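# Usage sketch (illustrative, not taken from the library's examples): assumes
# `joint_model` is a ProbabilisticModel built elsewhere, with observations and a
# posterior_model already attached. The loss curve stores the negative ELBO per
# iteration, so negating it recovers the ELBO estimates.
#
#     stochastic_variational_inference(joint_model,
#                                      number_iterations=1000,
#                                      number_samples=50)
#     elbo_curve = -joint_model.diagnostics["loss curve"]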
def perform_inference(joint_model, number_iterations, number_samples=1,
                      optimizer='Adam',
                      input_values=None,  # TODO: input values (currently unused)
                      inference_method=None,
                      posterior_model=None,
                      sampler_model=None,
                      pretraining_iterations=0,
                      VI_opt_params=None,
                      ML_opt_params=None,
                      sampler_opt_params=None,
                      RL_opt_params=None,
                      **opt_params):
    """Run approximate inference (and, if present, policy optimization) on a joint model.

    Parameters
    ----------
    joint_model : ProbabilisticModel, Variable or StochasticProcess
        Model to perform inference on; Variables and StochasticProcesses are first
        reduced to a ProbabilisticModel.
    number_iterations : int
        Number of gradient updates.
    number_samples : int
        Number of Monte Carlo samples per loss estimate.
    optimizer : str or optimizer class
        Optimizer used by all ProbabilisticOptimizer instances.
    inference_method : optional
        Inference algorithm; defaults to reverse KL variational inference.
    posterior_model : ProbabilisticModel, optional
        Approximate posterior; taken from the joint model or constructed by the
        inference method if not given.
    sampler_model : ProbabilisticModel, optional
        Auxiliary sampler model, if the inference method uses one.
    pretraining_iterations : int
        Number of initial iterations during which only the first optimizer is updated.
    VI_opt_params, ML_opt_params, sampler_opt_params, RL_opt_params : dict, optional
        Optimizer parameters for the posterior, model, sampler and policy optimizers;
        each defaults to **opt_params.
    """
    # Reduce a stochastic process to its active submodel (and posterior submodel).
    if isinstance(joint_model, StochasticProcess):
        posterior_submodel = joint_model.active_posterior_submodel
        joint_submodel = joint_model.active_submodel
        if joint_model.posterior_process is not None:
            joint_submodel.set_posterior_model(posterior_submodel)
        joint_model = joint_submodel
    if isinstance(joint_model, Variable):
        joint_model = ProbabilisticModel([joint_model])

    if not inference_method:
        warnings.warn("The inference method was not specified, using the default reverse KL variational inference")
        inference_method = ReverseKL()
    if posterior_model is None and joint_model.posterior_model is not None:
        posterior_model = joint_model.posterior_model
    if posterior_model is None:
        posterior_model = inference_method.construct_posterior_model(joint_model)
    if not sampler_model:
        try:
            sampler_model = inference_method.sampler_model
        except AttributeError:
            try:
                sampler_model = joint_model.posterior_sampler
            except AttributeError:
                sampler_model = None

    joint_model.update_observed_submodel()

    def append_prob_optimizer(model, optimizer, **opt_params):
        # TODO: this should be better! handling models with no params
        prob_opt = ProbabilisticOptimizer(model, optimizer, **opt_params)
        if prob_opt.optimizer:
            optimizers_list.append(prob_opt)

    if VI_opt_params is None:
        VI_opt_params = opt_params
    if ML_opt_params is None:
        ML_opt_params = opt_params
    if sampler_opt_params is None:
        sampler_opt_params = opt_params
    if RL_opt_params is None:
        RL_opt_params = opt_params

    # One ProbabilisticOptimizer per learnable component; policy and reward
    # variables are excluded here and handled by the policy optimizer below.
    optimizers_list = []
    if inference_method.learnable_posterior:
        append_prob_optimizer([var for var in posterior_model.variables
                               if not var.is_policy and not var.is_reward],
                              optimizer, **VI_opt_params)
    if inference_method.learnable_model:
        append_prob_optimizer([var for var in joint_model.variables
                               if not var.is_policy and not var.is_reward],
                              optimizer, **ML_opt_params)
    if inference_method.learnable_sampler:
        append_prob_optimizer([var for var in sampler_model.variables
                               if not var.is_policy and not var.is_reward],
                              optimizer, **sampler_opt_params)

    policy_variables = [var for var in joint_model.variables if var.is_policy]
    if policy_variables:
        policy_optimizer = ProbabilisticOptimizer(policy_variables, optimizer, **RL_opt_params)

    loss_list = []
    reward_list = []
    inference_method.check_model_compatibility(joint_model, posterior_model, sampler_model)

    for iteration in tqdm(range(number_iterations)):
        loss = inference_method.compute_loss(joint_model, posterior_model, sampler_model, number_samples)
        if torch.isfinite(loss.detach()).all().item():
            # Inference
            if optimizers_list:
                for opt in optimizers_list:
                    opt.zero_grad()
                loss.backward()
                inference_method.correct_gradient(joint_model, posterior_model, sampler_model, number_samples)
                optimizers_list[0].update()
                if iteration > pretraining_iterations:
                    for opt in optimizers_list[1:]:
                        opt.update()
            loss_list.append(loss.cpu().detach().numpy().flatten())

            # Control
            if policy_variables:
                for opt in optimizers_list:
                    opt.zero_grad()
                policy_optimizer.zero_grad()
                reward = joint_model.get_average_reward(number_samples)
                (-reward).backward()
                policy_optimizer.update()
                reward_list.append(reward.cpu().detach().numpy().flatten())
        else:
            warnings.warn("Numerical error, skipping sample")
            # Record the non-finite loss so the curve stays aligned with iterations.
            loss_list.append(loss.cpu().detach().numpy())

    joint_model.diagnostics.update({"loss curve": np.array(loss_list)})
    if policy_variables:
        joint_model.diagnostics.update({"reward curve": np.array(reward_list)})
    inference_method.post_process(joint_model)
    if joint_model.posterior_model is None and inference_method.learnable_posterior:
        inference_method.set_posterior_model_after_inference(joint_model, posterior_model, sampler_model)
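# Usage sketch (illustrative): assumes `joint_model` is an already-built
# ProbabilisticModel with observed data attached. Extra keyword arguments are
# forwarded to the optimizers via **opt_params; the `lr` keyword shown here is
# an assumption about what the chosen optimizer accepts.
#
#     perform_inference(joint_model,
#                       inference_method=ReverseKL(),
#                       number_iterations=500,
#                       number_samples=10,
#                       optimizer="Adam",
#                       lr=0.001)  # assumed optimizer keyword
#     loss_curve = joint_model.diagnostics["loss curve"]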