Example 1
    def test_update_pB_single_factor_no_actions(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that the one and only hidden state factor is updated, and there 
        are no actions.
        """

        n_states = [3]
        n_control = [
            1
        ]  # this is how we encode the fact that there aren't any actions
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        B = Categorical(
            values=np.random.rand(n_states[0], n_states[0], n_control[0]))
        B.normalize()
        pB = Dirichlet(values=np.ones_like(B.values))
        action = np.array([np.random.randint(nc) for nc in n_control])
        pB_updated = learning.update_transition_dirichlet(pB,
                                                          B,
                                                          action,
                                                          qs,
                                                          qs_prev,
                                                          lr=l_rate,
                                                          factors="all",
                                                          return_numpy=True)

        validation_pB = pB.copy()
        validation_pB[:, :, 0] += (l_rate *
                                   maths.spm_cross(qs.values, qs_prev.values) *
                                   (B[:, :, action[0]].values > 0))
        self.assertTrue(np.all(pB_updated == validation_pB.values))
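
The pB tests in this file all validate the same Dirichlet count update for the transition model. A hedged, plain-numpy restatement for a single hidden state factor (the helper name and signature are illustrative, not the library API): the slice of pB for the action actually taken grows by the learning-rate-scaled outer product of the current and previous state posteriors, masked to entries where B has support.

import numpy as np

def dirichlet_B_update_sketch(pB, B, action, qs, qs_prev, lr=1.0):
    # Plain-numpy sketch of the update the tests above and below validate
    # (single factor; spm_cross of two 1-D vectors is just an outer product).
    pB_new = pB.copy()
    outer = np.outer(qs, qs_prev)
    support = B[:, :, action] > 0          # only update entries B actually uses
    pB_new[:, :, action] += lr * outer * support
    return pB_new
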
Example 2
 def test_normalize_multi_factor(self):
     values_1 = np.random.rand(5)
     values_2 = np.random.rand(4, 3)
     values = np.array([values_1, values_2], dtype=object)
     d = Dirichlet(values=values)
     normed = Categorical(values=d.mean(return_numpy=True))
     self.assertTrue(normed.is_normalized())
Example 3
 def test_copy(self):
     values = np.random.rand(3, 2)
     d = Dirichlet(values=values)
     d_copy = d.copy()
     self.assertTrue(np.array_equal(d_copy.values, d.values))
     d_copy.values = d_copy.values * 2
     self.assertFalse(np.array_equal(d_copy.values, d.values))
Example 4
    def test_update_pB_multi_factor_with_actions_all_factors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and there
        are actions. All factors are updated.
        """

        n_states = [3, 4, 5]
        n_control = [3, 4, 5]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        B = Categorical(values=construct_generic_B(n_states, n_control))
        B.normalize()
        pB = Dirichlet(values=construct_pB(n_states, n_control))
        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = core.update_transition_dirichlet(pB,
                                                      B,
                                                      action,
                                                      qs,
                                                      qs_prev,
                                                      lr=l_rate,
                                                      factors="all",
                                                      return_numpy=True)

        validation_pB = pB.copy()
        for factor, _ in enumerate(n_control):
            validation_pB = pB[factor].copy()
            validation_pB[:, :, action[factor]] += (
                l_rate *
                core.spm_cross(qs[factor].values, qs_prev[factor].values) *
                (B[factor][:, :, action[factor]].values > 0))
            self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
Example 5
    def test_update_pB_single_factor_with_actions(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that the one and only hidden state factor is updated, and there 
        are actions.
        """

        n_states = [3]
        n_control = [3]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        B = Categorical(values=construct_generic_B(n_states, n_control))
        pB = Dirichlet(values=np.ones_like(B.values))
        action = np.array([np.random.randint(nc) for nc in n_control])
        pB_updated = core.update_transition_dirichlet(pB,
                                                      B,
                                                      action,
                                                      qs,
                                                      qs_prev,
                                                      lr=l_rate,
                                                      factors="all",
                                                      return_numpy=True)

        validation_pB = pB.copy()
        validation_pB[:, :, action[0]] += (
            l_rate * core.spm_cross(qs.values, qs_prev.values) *
            (B[:, :, action[0]].values > 0))
        self.assertTrue(np.all(pB_updated == validation_pB.values))
Example 6
    def test_pB_info_gain(self):
        """
        Test the pB_info_gain function. Demonstrates operation
        by manipulating shape of the Dirichlet priors over likelihood parameters
        (pB), which affects information gain for different states
        """
        n_states = [2]
        n_control = [2]
        qs = Categorical(values=np.eye(n_states[0])[0])
        B = Categorical(values=construct_generic_B(n_states, n_control))
        pB_matrix = construct_pB(n_states, n_control)

        # create prior over dirichlets such that there is a skew
        # in the parameters about the likelihood mapping from the
        # hidden states to hidden states under the second action,
        # such that hidden state 0 is considered to be more likely than the other,
        # given the action in question
        # Therefore taking that action would yield an expected state that affords
        # high information gain about that part of the likelihood distribution.
        #
        pB_matrix[0, :, 1] = 2.0
        pB = Dirichlet(values=pB_matrix)

        # single timestep
        n_step = 1
        policies = core.construct_policies(n_states,
                                           n_control,
                                           policy_len=n_step)

        pB_info_gains = np.zeros(len(policies))
        for idx, policy in enumerate(policies):
            qs_pi = core.get_expected_states(qs, B, policy)
            pB_info_gains[idx] += core.calc_pB_info_gain(pB, qs_pi, qs, policy)
        self.assertGreater(pB_info_gains[1], pB_info_gains[0])
Example 7
    def test_update_pA_multi_factor_some_modalities(self):
        """
        Test for updating prior Dirichlet parameters over sensory likelihood (pA)
        in the case that SOME observation modalities are updated and the generative model 
        has multiple hidden state factors
        """
        n_states = [2, 6]
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        # multiple observation modalities
        num_obs = [3, 4, 5]
        modalities_to_update = [0, 2]
        A = Categorical(values=construct_generic_A(num_obs, n_states))
        pA = Dirichlet(values=construct_pA(num_obs, n_states))
        observation = A.dot(qs, return_numpy=False).sample()
        pA_updated = core.update_likelihood_dirichlet(
            pA,
            A,
            observation,
            qs,
            lr=l_rate,
            modalities=modalities_to_update,
            return_numpy=True)

        for modality, no in enumerate(num_obs):
            if modality in modalities_to_update:
                update = core.spm_cross(
                    np.eye(no)[observation[modality]], qs.values)
                validation_pA = pA[modality] + l_rate * update
            else:
                validation_pA = pA[modality]
            self.assertTrue(
                np.all(pA_updated[modality] == validation_pA.values))
Example 8
    def test_update_pA_single_factor_one_modality(self):
        """
        Test for updating prior Dirichlet parameters over sensory likelihood (pA)
        in the case that ONE observation modality is updated and the generative model
        has a single hidden state factor
        """
        n_states = [3]
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        # multiple observation modalities
        num_obs = [3, 4]

        modality_to_update = [np.random.randint(len(num_obs))]
        A = Categorical(values=construct_generic_A(num_obs, n_states))
        pA = Dirichlet(values=construct_pA(num_obs, n_states))
        observation = A.dot(qs, return_numpy=False).sample()
        pA_updated = learning.update_likelihood_dirichlet(
            pA,
            A,
            observation,
            qs,
            lr=l_rate,
            modalities=modality_to_update,
            return_numpy=True)

        for modality, no in enumerate(num_obs):
            if modality in modality_to_update:
                update = maths.spm_cross(
                    np.eye(no)[observation[modality]], qs.values)
                validation_pA = pA[modality] + l_rate * update
            else:
                validation_pA = pA[modality]
            self.assertTrue(
                np.all(pA_updated[modality] == validation_pA.values))
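
The pA tests follow the same pattern for the observation model. A hedged, plain-numpy restatement for one modality and a single hidden state factor (helper name illustrative, not the library API): the sampled observation is one-hot encoded and crossed with the state posterior before being added to the Dirichlet counts.

import numpy as np

def dirichlet_A_update_sketch(pA_m, obs_index, qs, num_obs_m, lr=1.0):
    # One-hot encode the observation for this modality; spm_cross of a one-hot
    # vector with a 1-D posterior reduces to an outer product.
    obs_onehot = np.eye(num_obs_m)[obs_index]
    return pA_m + lr * np.outer(obs_onehot, qs)
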
Example 9
    def test_update_pB_multi_factor_no_actions_one_factor(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and there
        are no actions. One factor is updated.
        """
        n_states = [3, 4]
        n_control = [1, 1]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        factors_to_update = [np.random.randint(len(n_states))]

        B = Categorical(values=np.array([
            np.random.rand(ns, ns, n_control[factor])
            for factor, ns in enumerate(n_states)
        ],
                                        dtype=object))
        B.normalize()
        pB = Dirichlet(values=np.array([
            np.ones_like(B[factor].values) for factor in range(len(n_states))
        ],
                                       dtype=object))

        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = learning.update_transition_dirichlet(
            pB,
            B,
            action,
            qs,
            qs_prev,
            lr=l_rate,
            factors=factors_to_update,
            return_numpy=True)

        validation_pB = pB.copy()
        for factor, _ in enumerate(n_control):
            validation_pB = pB[factor].copy()
            if factor in factors_to_update:
                validation_pB[:, :, action[factor]] += (
                    l_rate * maths.spm_cross(qs[factor].values,
                                             qs_prev[factor].values) *
                    (B[factor][:, :, action[factor]].values > 0))

            self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
Example 10
 def test_multi_factor_init_values_expand(self):
     values_1 = np.random.rand(5)
     values_2 = np.random.rand(4)
     values = np.array([values_1, values_2], dtype=object)
     d = Dirichlet(values=values)
     self.assertEqual(d.shape, (2, ))
     self.assertEqual(d[0].shape, (5, 1))
     self.assertEqual(d[1].shape, (4, 1))
Example 11
    def test_update_pA_single_factor_all(self):
        """
        Test for updating prior Dirichlet parameters over sensory likelihood (pA)
        in the case that all observation modalities are updated and the generative model 
        has a single hidden state factor
        """
        n_states = [3]
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        # single observation modality
        num_obs = [4]
        A = Categorical(values=construct_generic_A(num_obs, n_states))
        pA = Dirichlet(values=construct_pA(num_obs, n_states))

        observation = A.dot(qs, return_numpy=False).sample()
        pA_updated = core.update_likelihood_dirichlet(pA,
                                                      A,
                                                      observation,
                                                      qs,
                                                      lr=l_rate,
                                                      modalities="all",
                                                      return_numpy=True)
        validation_pA = pA + l_rate * core.spm_cross(
            np.eye(*num_obs)[observation], qs.values)
        self.assertTrue(np.all(pA_updated == validation_pA.values))

        # multiple observation modalities
        num_obs = [3, 4]
        A = Categorical(values=construct_generic_A(num_obs, n_states))
        pA = Dirichlet(values=construct_pA(num_obs, n_states))
        observation = A.dot(qs, return_numpy=False).sample()
        pA_updated = core.update_likelihood_dirichlet(pA,
                                                      A,
                                                      observation,
                                                      qs,
                                                      lr=l_rate,
                                                      modalities="all",
                                                      return_numpy=True)

        for modality, no in enumerate(num_obs):
            update = core.spm_cross(
                np.eye(no)[observation[modality]], qs.values)
            validation_pA = pA[modality] + l_rate * update
            self.assertTrue(
                np.all(pA_updated[modality] == validation_pA.values))
Example 12
 def test_multi_factor_init_values(self):
     values_1 = np.random.rand(5, 4)
     values_2 = np.random.rand(4, 3)
     values = np.array([values_1, values_2], dtype=object)
     d = Dirichlet(values=values)
     self.assertEqual(d.shape, (2, ))
     self.assertEqual(d[0].shape, (5, 4))
     self.assertEqual(d[1].shape, (4, 3))
Example 13
    def test_update_pB_multi_factor_some_controllable_some_factors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and some of them
        are controllable. Some factors are updated.
        """

        n_states = [3, 4, 5]
        n_control = [1, 3, 1]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        factors_to_update = [0, 1]
        B_values = np.empty(len(n_states), dtype=object)
        pB_values = np.empty(len(n_states), dtype=object)
        for factor, ns in enumerate(n_states):
            B_values[factor] = np.random.rand(ns, ns, n_control[factor])
            pB_values[factor] = np.ones((ns, ns, n_control[factor]))

        B = Categorical(values=B_values)
        B.normalize()
        pB = Dirichlet(values=pB_values)

        action = np.array([np.random.randint(nc) for nc in n_control])
        pB_updated = core.update_transition_dirichlet(
            pB,
            B,
            action,
            qs,
            qs_prev,
            lr=l_rate,
            factors=factors_to_update,
            return_numpy=True)

        validation_pB = pB.copy()
        for factor, _ in enumerate(n_control):
            validation_pB = pB[factor].copy()
            if factor in factors_to_update:
                validation_pB[:, :, action[factor]] += (
                    l_rate *
                    core.spm_cross(qs[factor].values, qs_prev[factor].values) *
                    (B[factor][:, :, action[factor]].values > 0))
            self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
Example 14
 def test_contains_zeros(self):
     values = np.array([[1.0, 0.0], [1.0, 1.0]])
     d = Dirichlet(values=values)
     self.assertTrue(d.contains_zeros())
     values = np.array([[1.0, 1.0], [1.0, 1.0]])
     d = Dirichlet(values=values)
     self.assertFalse(d.contains_zeros())
Example 15
    def test_pA_info_gain(self):
        """
        Test the pA_info_gain function. Demonstrates operation
        by manipulating shape of the Dirichlet priors over likelihood parameters
        (pA), which affects information gain for different expected observations
        """
        n_states = [2]
        n_control = [2]

        qs = Categorical(values=np.eye(n_states[0])[0])

        B = Categorical(values=construct_generic_B(n_states, n_control))

        # single timestep
        n_step = 1
        policies = core.construct_policies(n_states,
                                           n_control,
                                           policy_len=n_step)

        # single observation modality
        num_obs = [2]

        # create noiseless identity A matrix
        A = Categorical(values=np.eye(num_obs[0]))

        # create prior over dirichlets such that there is a skew
        # in the parameters about the likelihood mapping from the
        # second hidden state (index 1) to observations, such that one
        # observation is considered to be more likely than the other conditioned on that state.
        # Therefore sampling that observation would afford high info gain
        # about parameters for that part of the likelhood distribution.

        pA_matrix = construct_pA(num_obs, n_states)
        pA_matrix[0, 1] = 2.0
        pA = Dirichlet(values=pA_matrix)

        pA_info_gains = np.zeros(len(policies))
        for idx, policy in enumerate(policies):
            qs_pi = core.get_expected_states(qs, B, policy)
            qo_pi = core.get_expected_obs(qs_pi, A)
            pA_info_gains[idx] += core.calc_pA_info_gain(pA, qo_pi, qs_pi)
        self.assertGreater(pA_info_gains[1], pA_info_gains[0])
Example 16
 def test_float_conversion(self):
     values = np.array([2, 3])
     self.assertTrue(np.issubdtype(values.dtype, np.integer))
     d = Dirichlet(values=values)
     self.assertEqual(d.values.dtype, np.float64)
Example 17
# Prior
We initialise the agent's prior over hidden states to be a flat distribution (i.e. the agent has no strong beliefs about which state it is starting in).

# Posterior (recognition density)
We initialise the posterior beliefs `qs` about hidden states (namely, beliefs about 'where I am') as a flat distribution over the possible states. This requires the
agent to first gather a proprioceptive observation from the environment (e.g. a bodily sensation of where it feels itself to be) before updating its posterior to be centered
on the true, evidence-supported location.
"""

likelihood_matrix = env.get_likelihood_dist()
A = Categorical(values=likelihood_matrix)
A.remove_zeros()
plot_likelihood(A, 'Observation likelihood')

b = Dirichlet(values=np.ones((n_states, n_states)))
B = b.mean()
plot_likelihood(B, 'Initial transition likelihood')

D = Categorical(values=np.ones(n_states))
D.normalize()

qs = Categorical(dims=[env.n_states])
qs.normalize()
"""
Run the dynamics of the environment and inference. Start by eliciting one observation from the environment
"""

# reset environment
first_state = env.reset(init_state=s[0])
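
# A hedged, plain-numpy illustration of the single perception step described in the
# docstring above (a stand-in for the library's own inference call, not its actual API):
# Bayes' rule combines one row of A (the likelihood of a hypothetical observation index
# `obs_index`, used only for illustration) with the flat prior over hidden states.
obs_index = 0
likelihood = A.values[obs_index, :]                    # P(o = obs_index | s): a row of A
prior = np.ones(env.n_states) / env.n_states           # the flat prior described above
posterior = likelihood * prior                         # Bayes' rule, up to normalisation
qs_bayes = Categorical(values=posterior / posterior.sum())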
Example 18
 def test_log(self):
     values = np.random.rand(3, 2)
     log_values = np.log(values)
     d = Dirichlet(values=values)
     self.assertTrue(np.array_equal(d.log(return_numpy=True), log_values))
Example 19
 def test_ndim(self):
     values = np.random.rand(3, 2)
     d = Dirichlet(values=values)
     self.assertEqual(d.ndim, d.values.ndim)
Example 20
    def test_multistep_multi_factor_posterior_policies(self):
        """
        Test for computing posterior over policies (and associated expected free energies)
        in the case of a posterior over hidden states with multiple hidden state factors. 
        This version tests using a policy horizon of 3 steps ahead
        """
        n_states = [3, 4]
        n_control = [3, 4]

        qs = Categorical(values=construct_init_qs(n_states))
        B = Categorical(values=construct_generic_B(n_states, n_control))
        pB = Dirichlet(values=construct_pB(n_states, n_control))

        # Policy horizon of 3 timesteps
        n_step = 3
        policies = core.construct_policies(n_states,
                                           n_control,
                                           policy_len=n_step)

        # Single observation modality
        num_obs = [4]

        A = Categorical(values=construct_generic_A(num_obs, n_states))
        pA = Dirichlet(values=construct_pA(num_obs, n_states))
        C = Categorical(values=construct_generic_C(num_obs))

        q_pi, efe = core.update_posterior_policies(
            qs,
            A,
            B,
            C,
            policies,
            use_utility=True,
            use_states_info_gain=True,
            use_param_info_gain=True,
            pA=pA,
            pB=pB,
            gamma=16.0,
            return_numpy=True,
        )

        self.assertEqual(len(q_pi), len(policies))  # type: ignore
        self.assertEqual(len(efe), len(policies))

        # Multiple observation modalities
        num_obs = [3, 2]

        A = Categorical(values=construct_generic_A(num_obs, n_states))
        pA = Dirichlet(values=construct_pA(num_obs, n_states))
        C = Categorical(values=construct_generic_C(num_obs))

        q_pi, efe = core.update_posterior_policies(
            qs,
            A,
            B,
            C,
            policies,
            use_utility=True,
            use_states_info_gain=True,
            use_param_info_gain=True,
            pA=pA,
            pB=pB,
            gamma=16.0,
            return_numpy=True,
        )

        self.assertEqual(len(q_pi), len(policies))  # type: ignore
        self.assertEqual(len(efe), len(policies))
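
`update_posterior_policies` returns one expected free energy value per policy, and the test above only checks the output lengths. For context, here is a hedged sketch of the usual active-inference convention for mapping those values to a policy posterior: a softmax of the negative EFE scaled by gamma (16.0 in the test). This mapping is an assumption for illustration, not something the test asserts.

import numpy as np

def softmax_policy_posterior_sketch(efe, gamma=16.0):
    # Softmax of -gamma * EFE, shifted by the max for numerical stability
    logits = -gamma * np.asarray(efe)
    logits = logits - logits.max()
    q_pi = np.exp(logits)
    return q_pi / q_pi.sum()
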
Example 21
 def test_normalize_two_dim(self):
     values = np.array([[1.0, 1.0], [1.0, 1.0]])
     d = Dirichlet(values=values)
     expected_values = np.array([[0.5, 0.5], [0.5, 0.5]])
     self.assertTrue(
         np.array_equal(d.mean(return_numpy=True), expected_values))
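
The expected values in this test are just the Dirichlet mean: the parameters normalised along the leading axis so that each column sums to one. In plain numpy (an assumption about the normalisation convention, consistent with the expected values above):

import numpy as np

values = np.array([[1.0, 1.0], [1.0, 1.0]])
mean = values / values.sum(axis=0, keepdims=True)   # -> [[0.5, 0.5], [0.5, 0.5]]
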
Example 22
 def test_remove_zeros(self):
     values = np.array([[1.0, 0.0], [1.0, 1.0]])
     d = Dirichlet(values=values)
     self.assertTrue((d.values == 0.0).any())
     d.remove_zeros()
     self.assertFalse((d.values == 0.0).any())
Example 23
 def test_shape(self):
     values = np.random.rand(3, 2)
     d = Dirichlet(values=values)
     self.assertEqual(d.shape, (3, 2))
Example 24
 def test_init_empty(self):
     d = Dirichlet()
     self.assertEqual(d.ndim, 2)
Example 25
 def test_multi_factor_init_dims(self):
     d = Dirichlet(dims=[[5, 4], [4, 3]])
     self.assertEqual(d.shape, (2, ))
     self.assertEqual(d[0].shape, (5, 4))
     self.assertEqual(d[1].shape, (4, 3))
Example 26
def to_dirichlet(values):
    return Dirichlet(values=values)
Example 27
 def test_init_overload(self):
     with self.assertRaises(ValueError):
         values = np.random.rand(3, 2)
         _ = Dirichlet(dims=2, values=values)
Example 28
 def test_init_dims_int_expand(self):
     d = Dirichlet(dims=5)
     self.assertEqual(d.shape, (5, 1))
Example 29
    def __init__(self,
                 A=None,
                 pA=None,
                 B=None,
                 pB=None,
                 C=None,
                 D=None,
                 n_states=None,
                 n_observations=None,
                 n_controls=None,
                 policy_len=1,
                 inference_horizon=1,
                 control_fac_idx=None,
                 policies=None,
                 gamma=16.0,
                 use_utility=True,
                 use_states_info_gain=True,
                 use_param_info_gain=False,
                 action_sampling="marginal_action",
                 inference_algo="VANILLA",
                 inference_params=None,
                 modalities_to_learn="all",
                 lr_pA=1.0,
                 factors_to_learn="all",
                 lr_pB=1.0,
                 use_BMA=True,
                 policy_sep_prior=False):

        ### Constant parameters ###

        # policy parameters
        self.policy_len = policy_len
        self.gamma = gamma
        self.action_sampling = action_sampling
        self.use_utility = use_utility
        self.use_states_info_gain = use_states_info_gain
        self.use_param_info_gain = use_param_info_gain

        # learning parameters
        self.modalities_to_learn = modalities_to_learn
        self.lr_pA = lr_pA
        self.factors_to_learn = factors_to_learn
        self.lr_pB = lr_pB
        """ Initialise observation model (A matrices) """
        if A is not None:
            # Create `Categorical`
            if not isinstance(A, Categorical):
                self.A = Categorical(values=A)
            else:
                self.A = A

            # Determine number of modalities and observations
            if self.A.IS_AOA:
                self.n_modalities = self.A.shape[0]
                self.n_observations = [
                    self.A[modality].shape[0]
                    for modality in range(self.n_modalities)
                ]
            else:
                self.n_modalities = 1
                self.n_observations = [self.A.shape[0]]
            construct_A_flag = False
        else:

            # If A is None, we randomly initialise the matrix. This requires `n_observations` to be provided
            if n_observations is None:
                raise ValueError(
                    "Must provide either `A` or `n_observations` to `Agent` constructor"
                )
            self.n_observations = n_observations
            self.n_modalities = len(self.n_observations)
            construct_A_flag = True
        """ Initialise prior Dirichlet parameters on observation model (pA matrices) """
        if pA is not None:
            if not isinstance(pA, Dirichlet):
                self.pA = Dirichlet(values=pA)
            else:
                self.pA = pA
        else:
            self.pA = None
        """ Initialise transition model (B matrices) """
        if B is not None:
            if not isinstance(B, Categorical):
                self.B = Categorical(values=B)
            else:
                self.B = B

            # Same logic as before, but here we need number of factors and states per factor
            if self.B.IS_AOA:
                self.n_factors = self.B.shape[0]
                self.n_states = [
                    self.B[f].shape[0] for f in range(self.n_factors)
                ]
            else:
                self.n_factors = 1
                self.n_states = [self.B.shape[0]]
            construct_B_flag = False
        else:
            if n_states is None:
                raise ValueError(
                    "Must provide either `B` or `n_states` to `Agent` constructor"
                )
            self.n_states = n_states
            self.n_factors = len(self.n_states)
            construct_B_flag = True
        """ Initialise prior Dirichlet parameters on transition model (pB matrices) """
        if pB is not None:
            if not isinstance(pB, Dirichlet):
                self.pB = Dirichlet(values=pB)
            else:
                self.pB = pB
        else:
            self.pB = None

        # Users have the option to make only certain factors controllable.
        # Default behaviour is to make all hidden state factors controllable
        # (i.e. self.n_states == self.n_controls)
        if control_fac_idx is None:
            self.control_fac_idx = list(range(self.n_factors))
        else:
            self.control_fac_idx = control_fac_idx

        # The user can specify the number of control states
        # However, given the controllable factors, this can be inferred
        if n_controls is None:
            _, self.n_controls = self._construct_n_controls()
        else:
            self.n_controls = n_controls

        # Again, the user can specify a set of possible policies; otherwise
        # all possible combinations of actions over the policy horizon will be considered
        if policies is None:
            self.policies, _ = self._construct_n_controls()
        else:
            self.policies = policies

        # Construct prior preferences (uniform if not specified)
        if C is not None:
            if isinstance(C, Categorical):
                self.C = C
            else:
                self.C = Categorical(values=C)
        else:
            self.C = self._construct_C_prior()

        # Construct initial beliefs (uniform if not specified)
        if D is not None:
            if isinstance(D, Categorical):
                self.D = D
            else:
                self.D = Categorical(values=D)
        else:
            self.D = self._construct_D_prior()

        # Build model
        if construct_A_flag:
            self.A = self._construct_A_distribution()
        if construct_B_flag:
            self.B = self._construct_B_distribution()

        self.edge_handling_params = {}
        self.edge_handling_params['use_BMA'] = use_BMA
        self.edge_handling_params['policy_sep_prior'] = policy_sep_prior

        if inference_algo is None:
            self.inference_algo = "VANILLA"
            self.inference_params = self._get_default_params()
            if inference_horizon > 1:
                print(
                    "WARNING: `inference_horizon` must be 1 when `inference_algo` is VANILLA.\n"
                    "Setting inference_horizon to default value of 1...\n")
            # VANILLA inference always uses a single-timestep inference horizon
            self.inference_horizon = 1
        else:
            self.inference_algo = inference_algo
            self.inference_params = self._get_default_params()
            self.inference_horizon = inference_horizon

        self.prev_obs = []
        self.reset()

        self.action = None
        self.prev_actions = None
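
A minimal usage sketch for the constructor above (hypothetical dimensions; assumes the `construct_generic_A` and `construct_generic_B` helpers used in the tests are importable alongside the `Agent` class). When full A and B arrays are supplied, the constructor infers the number of modalities, observations, factors and states from their shapes:

A = construct_generic_A([3, 4], [2, 6])    # two observation modalities, two hidden state factors
B = construct_generic_B([2, 6], [2, 1])    # only the first hidden state factor is controllable
agent = Agent(A=A, B=B, control_fac_idx=[0], policy_len=2, use_param_info_gain=True)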