    muy.append(NormalVariable(new_my,
                              new_sy,
                              "muy_{}".format(t + 1),
                              is_observed=True))

    mx.append(new_mx)
    my.append(new_my)
    sx.append(new_sx)
    sy.append(new_sy)

model = ProbabilisticModel(w + r + mux + muy)

# Variational model
variational_filter = ProbabilisticModel(Qmux + Qmuy)

print(model.get_average_reward(10))

# Train control
num_itr = 3000
inference.perform_inference(model,
                            posterior_model=variational_filter,
                            number_iterations=num_itr,
                            number_samples=9,
                            optimizer="Adam",
                            lr=0.01)
reward_list = model.diagnostics["reward curve"]
# TODO: Very important. Solve the trained determinant problem
# (it should be possible to specify which parameter is trainable).

print(model.get_sample(20)[["r"]])

plt.plot(reward_list)
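# The loss curve stored by perform_inference under model.diagnostics["loss curve"]
# (see the implementation in Example 2 below) can be inspected in the same way.
# A minimal sketch, assuming the run above recorded at least one finite loss:
plt.figure()
plt.plot(model.diagnostics["loss curve"])
plt.xlabel("iteration")
plt.ylabel("loss")
plt.show()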
Example 2
def perform_inference(joint_model,
                      number_iterations,
                      number_samples=1,
                      optimizer='Adam',
                      input_values={},
                      inference_method=None,
                      posterior_model=None,
                      sampler_model=None,
                      pretraining_iterations=0,
                      VI_opt_params=None,
                      ML_opt_params=None,
                      sampler_opt_params=None,
                      RL_opt_params=None,
                      **opt_params):  #TODO: input values
    """
    Summary

    Parameters
    ---------
    """
    if isinstance(joint_model, StochasticProcess):
        posterior_submodel = joint_model.active_posterior_submodel
        joint_submodel = joint_model.active_submodel
        if joint_model.posterior_process is not None:
            joint_submodel.set_posterior_model(posterior_submodel)
        joint_model = joint_submodel
    if isinstance(joint_model, Variable):
        joint_model = ProbabilisticModel([joint_model])
    if not inference_method:
        warnings.warn(
            "The inference method was not specified, using the default reverse KL variational inference"
        )
        inference_method = ReverseKL()
    if posterior_model is None and joint_model.posterior_model is not None:
        posterior_model = joint_model.posterior_model
    if posterior_model is None:
        posterior_model = inference_method.construct_posterior_model(
            joint_model)
    if not sampler_model:
        try:
            sampler_model = inference_method.sampler_model
        except AttributeError:
            try:
                sampler_model = joint_model.posterior_sampler
            except AttributeError:
                sampler_model = None

    joint_model.update_observed_submodel()

    def append_prob_optimizer(model, optimizer, **opt_params):
        # Register an optimizer for this group of variables only if it has
        # trainable parameters (TODO: handle models with no parameters more cleanly).
        prob_opt = ProbabilisticOptimizer(model, optimizer, **opt_params)
        if prob_opt.optimizer:
            optimizers_list.append(prob_opt)

    if VI_opt_params is None:
        VI_opt_params = opt_params
    if ML_opt_params is None:
        ML_opt_params = opt_params
    if sampler_opt_params is None:
        sampler_opt_params = opt_params
    if RL_opt_params is None:
        RL_opt_params = opt_params

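    # Create one optimizer group per learnable component (posterior, model and
    # sampler), excluding policy and reward variables; policy variables are
    # trained by the separate policy optimizer below.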
    optimizers_list = []
    if inference_method.learnable_posterior:
        append_prob_optimizer([
            var for var in posterior_model.variables
            if not var.is_policy and not var.is_reward
        ], optimizer, **VI_opt_params)
    if inference_method.learnable_model:
        append_prob_optimizer([
            var for var in joint_model.variables
            if not var.is_policy and not var.is_reward
        ], optimizer, **ML_opt_params)
    if inference_method.learnable_sampler:
        append_prob_optimizer([
            var for var in sampler_model.variables
            if not var.is_policy and not var.is_reward
        ], optimizer, **sampler_opt_params)

    policy_variables = [var for var in joint_model.variables if var.is_policy]
    if policy_variables:
        policy_optimizer = ProbabilisticOptimizer(policy_variables, optimizer,
                                                  **RL_opt_params)

    loss_list = []
    reward_list = []

    inference_method.check_model_compatibility(joint_model, posterior_model,
                                               sampler_model)

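    # Main optimization loop: an inference step minimizing the loss, followed by
    # a control step maximizing the average reward when policy variables are present.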
    for iteration in tqdm(range(number_iterations)):
        loss = inference_method.compute_loss(joint_model, posterior_model,
                                             sampler_model, number_samples)

        if torch.isfinite(loss.detach()).all().item():
            # Inference step: update the model, posterior and sampler parameters
            if optimizers_list:
                for opt in optimizers_list:
                    opt.zero_grad()
                loss.backward()
                inference_method.correct_gradient(joint_model, posterior_model,
                                                  sampler_model, number_samples)
                optimizers_list[0].update()
                if iteration > pretraining_iterations:
                    for opt in optimizers_list[1:]:
                        opt.update()

            # Control step: update the policy parameters to maximize the average reward
            if policy_variables:
                for opt in optimizers_list:
                    opt.zero_grad()
                policy_optimizer.zero_grad()
                reward = joint_model.get_average_reward(number_samples)
                (-reward).backward()
                policy_optimizer.update()
                reward_list.append(reward.cpu().detach().numpy().flatten())

            loss_list.append(loss.cpu().detach().numpy().flatten())
        else:
            warnings.warn("Numerical error, skipping sample")
    joint_model.diagnostics.update({"loss curve": np.array(loss_list)})
    if policy_variables:
        joint_model.diagnostics.update({"reward curve": np.array(reward_list)})

    inference_method.post_process(joint_model)

    if joint_model.posterior_model is None and inference_method.learnable_posterior:
        inference_method.set_posterior_model_after_inference(
            joint_model, posterior_model, sampler_model)
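
# A minimal usage sketch of the keyword-argument routing, assuming the `model`
# and `variational_filter` objects from Example 1 are in scope (hyperparameter
# values below are illustrative): a group-specific dictionary such as
# VI_opt_params overrides the generic keyword arguments, which otherwise act as
# the default settings for every optimizer group.
perform_inference(model,
                  posterior_model=variational_filter,
                  number_iterations=500,
                  number_samples=10,
                  optimizer="Adam",
                  VI_opt_params={"lr": 0.005},  # used for the posterior (VI) optimizer
                  lr=0.01)                      # default for the remaining optimizer groups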