Example #1
obs = env.reset()

msg = """ === Starting experiment === \n True scene: {} Initial observation {} """
print(msg.format(env.true_scene, obs))
prior = env.get_uniform_posterior()

for t in range(T):

    # Infer the posterior over hidden states, given the current observation and the prior
    qs = core.update_posterior_states(A, obs, prior, return_numpy=False)

    msg = """[{}] Inference [location {} / scene {}] 
     Observation [location {} / feature {}] """
    print(
        msg.format(t, np.argmax(qs[0].values), np.argmax(qs[1].values), obs[0],
                   obs[1]))

    # Score each policy: posterior over policies and their expected free energies
    q_pi, efe = core.update_posterior_policies(qs, A, B, C, policies)

    # Sample an action from the marginal distribution over first actions
    action = core.sample_action(q_pi,
                                policies,
                                env.n_control,
                                sampling_type="marginal_action")

    # Step the environment with the sampled action and receive the next observation
    obs = env.step(action)

    msg = """[{}] Action [Saccade to location {}]"""
    print(msg.format(t, action[0]))

    # Expected states under the chosen action become the empirical prior for the next timestep
    prior = core.get_expected_states(qs, B.log(), action.reshape(1, -1))
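
For intuition about sampling_type="marginal_action": it is commonly understood as marginalising q_pi over policies into a per-factor distribution over the first action, and then selecting from that marginal. The sketch below illustrates that idea in plain numpy; the helper name marginal_action_probs and the assumed policy layout (policy_len x n_factors arrays) are illustrative assumptions, not the library's internals.

# Illustrative sketch of the assumed "marginal_action" semantics (not library internals)
import numpy as np

def marginal_action_probs(q_pi, policies, n_control_states):
    # Accumulate each policy's probability onto the first action it takes for every control factor
    marginals = [np.zeros(n) for n in n_control_states]
    for pi_prob, policy in zip(q_pi, policies):
        first_step = np.array(policy)[0]            # actions per control factor at t = 0
        for factor, a in enumerate(first_step):
            marginals[factor][int(a)] += pi_prob
    return [m / m.sum() for m in marginals]

# Toy usage: 3 one-step policies over a single control factor with 3 possible actions
q_pi_demo = np.array([0.1, 0.7, 0.2])
policies_demo = [[[0]], [[1]], [[2]]]
print(marginal_action_probs(q_pi_demo, policies_demo, [3]))   # [array([0.1, 0.7, 0.2])]
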
Example #2
# Build the set of candidate policies for this task
_, possible_policies = core.construct_policies(env.n_states, env.n_factors, [0], 1)

obs = env.reset()

msg = """ === Starting experiment === \n True scene: {} Initial observation {} """
print(msg.format(env.true_scene, obs))
prior = env.get_uniform_posterior()

for t in range(T):

    Qs = core.update_posterior_states(A, obs, prior, return_numpy=False)

    msg = """[{}] Inference [location {} / scene {}] 
     Observation [location {} / feature {}] """
    # Unlike Example #1, report a sample from each factor's posterior rather than the argmax
    print(msg.format(t, Qs[0].sample(), Qs[1].sample(), obs[0], obs[1]))

    Q_pi, _ = core.update_posterior_policies(Qs, A, B, C, possible_policies)

    action = core.sample_action(
        Q_pi, possible_policies, env.n_control, sampling_type="marginal_action"
    )

    obs = env.step(action)

    msg = """[{}] Action [Saccade to location {}]"""
    print(msg.format(t, action[0]))

    prior = core.get_expected_states(Qs, B.log(), action)
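
The final line again seeds the next timestep's prior with the expected states under the chosen action (note that, unlike Example #1, the action is passed without reshaping). As a conceptual reference only, the standard per-factor expected-state update is sketched below in plain numpy; the helper name expected_states is hypothetical, and the library call used here (which takes B.log()) may implement this differently.

# Conceptual sketch of the per-factor expected-state update (assumed, not the library call)
import numpy as np

def expected_states(qs_factors, B_factors, action):
    # For each hidden-state factor f: prior_f = B_f[:, :, action_f] @ qs_f
    return [B_f[:, :, int(a)] @ qs_f
            for qs_f, B_f, a in zip(qs_factors, B_factors, action)]

# Toy usage: one factor with 3 states; action 0 = stay, action 1 = move to the next state
qs_demo = [np.array([1.0, 0.0, 0.0])]
B_demo = [np.zeros((3, 3, 2))]
B_demo[0][:, :, 0] = np.eye(3)
B_demo[0][:, :, 1] = np.roll(np.eye(3), 1, axis=0)
print(expected_states(qs_demo, B_demo, [1]))                  # [array([0., 1., 0.])]
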

Example #3
    def test_multistep_multifac_posteriorPolicies(self):
        """
        Test for computing posterior over policies (and associated expected free energies)
        in the case of a posterior over hidden states with multiple hidden state factors. 
        This version uses a policy horizon of 3 timesteps.
        """

        n_states = [3, 4]
        n_control = [3, 4]

        qs = Categorical(values=construct_init_qs(n_states))
        B = Categorical(values=construct_generic_B(n_states, n_control))
        pB = Dirichlet(values=construct_pB(n_states, n_control))

        # policy horizon of 3 timesteps
        n_step = 3
        policies = core.construct_policies(n_states,
                                           n_control,
                                           policy_len=n_step)

        # single observation modality
        num_obs = [4]

        A = Categorical(values=construct_generic_A(num_obs, n_states))
        pA = Dirichlet(values=construct_pA(num_obs, n_states))
        C = Categorical(values=construct_generic_C(num_obs))

        q_pi, efe = core.update_posterior_policies(qs,
                                                   A,
                                                   B,
                                                   C,
                                                   policies,
                                                   use_utility=True,
                                                   use_states_info_gain=True,
                                                   use_param_info_gain=True,
                                                   pA=pA,
                                                   pB=pB,
                                                   gamma=16.0,
                                                   return_numpy=True)

        self.assertEqual(len(q_pi), len(policies))
        self.assertEqual(len(efe), len(policies))

        # multiple observation modalities
        num_obs = [3, 2]

        A = Categorical(values=construct_generic_A(num_obs, n_states))
        pA = Dirichlet(values=construct_pA(num_obs, n_states))
        C = Categorical(values=construct_generic_C(num_obs))

        q_pi, efe = core.update_posterior_policies(qs,
                                                   A,
                                                   B,
                                                   C,
                                                   policies,
                                                   use_utility=True,
                                                   use_states_info_gain=True,
                                                   use_param_info_gain=True,
                                                   pA=pA,
                                                   pB=pB,
                                                   gamma=16.0,
                                                   return_numpy=True)

        self.assertEqual(len(q_pi), len(policies))
        self.assertEqual(len(efe), len(policies))
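
The assertions above only check that one score per policy comes back. Under the usual active-inference convention, the posterior over policies is a precision-weighted softmax of the negative expected free energies, with gamma acting as the precision; the sketch below shows that relationship on dummy values. Whether this library follows exactly this sign and normalisation convention is an assumption here.

# Hedged sanity-check sketch: q_pi is assumed to behave like softmax(-gamma * efe)
import numpy as np

def softmax(x):
    x = x - x.max()                       # stabilise before exponentiating
    e = np.exp(x)
    return e / e.sum()

gamma = 16.0
efe_demo = np.array([1.0, 0.5, 2.0])      # dummy expected free energies for 3 policies
q_pi_demo = softmax(-gamma * efe_demo)    # lower expected free energy -> higher probability

assert np.isclose(q_pi_demo.sum(), 1.0)   # a proper distribution over policies
assert q_pi_demo.argmax() == efe_demo.argmin()
print(q_pi_demo)
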