Exemple #1
0
 def test_copy(self):
     """A copy starts equal to its source, but changing it leaves the source alone."""
     original = Dirichlet(values=np.random.rand(3, 2))
     duplicate = original.copy()
     self.assertTrue(np.array_equal(duplicate.values, original.values))

     # Rebind the copy's parameters; the original must not follow.
     duplicate.values = duplicate.values * 2
     self.assertFalse(np.array_equal(duplicate.values, original.values))
Exemple #2
0
 def test_normalize_multi_factor(self):
     """The mean of a multi-factor Dirichlet forms a normalized Categorical."""
     # The factors have different shapes, so store them in an explicit object
     # array. Calling ``np.array`` on such ragged input raises ValueError on
     # NumPy >= 1.24 (it used to build an object array implicitly).
     values = np.empty(2, dtype=object)
     values[0] = np.random.rand(5)
     values[1] = np.random.rand(4, 3)
     d = Dirichlet(values=values)
     normed = Categorical(values=d.mean(return_numpy=True))
     self.assertTrue(normed.is_normalized())
    def test_update_pB_multiFactor_withActions_someFactors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and there
        are actions. Only the factors listed in ``factors_to_update`` are updated;
        the remaining factors must come back equal to their prior.
        """

        n_states = [3, 4, 5]
        n_control = [3, 4, 5]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        learning_rate = 1.0

        factors_to_update = [0, 1]

        B = Categorical(values=construct_generic_B(n_states, n_control))
        B.normalize()
        pB = Dirichlet(values=construct_pB(n_states, n_control))

        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = core.update_transition_dirichlet(
            pB, B, action, qs, qs_prev,
            lr=learning_rate, factors=factors_to_update, return_numpy=True,
        )

        for factor in range(len(n_control)):
            # Each factor's expectation starts from its own copy of the prior.
            validation_pB = pB[factor].copy()

            if factor in factors_to_update:
                a = action[factor]
                # Expected increment: lr * spm_cross(qs, qs_prev), applied only to
                # the slice of the chosen action and masked to the entries where
                # the generative B for that action is non-zero.
                validation_pB[:, :, a] += learning_rate * core.spm_cross(qs[factor].values, qs_prev[factor].values) * (B[factor][:, :, a].values > 0)

            self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
    def test_update_pB_multiFactor_noActions_allFactors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and there
        are no actions (a single control state per factor). All factors are updated.
        """

        n_states = [3, 4]
        n_control = [1, 1]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        learning_rate = 1.0

        # The factors have different shapes, so build explicit object arrays.
        # ``np.array`` on such ragged input raises ValueError on NumPy >= 1.24.
        B_values = np.empty(len(n_states), dtype=object)
        pB_values = np.empty(len(n_states), dtype=object)
        for factor, ns in enumerate(n_states):
            B_values[factor] = np.random.rand(ns, ns, n_control[factor])
            pB_values[factor] = np.ones_like(B_values[factor])

        B = Categorical(values=B_values)
        B.normalize()
        pB = Dirichlet(values=pB_values)

        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = core.update_transition_dirichlet(
            pB, B, action, qs, qs_prev,
            lr=learning_rate, factors="all", return_numpy=True,
        )

        for factor in range(len(n_control)):
            a = action[factor]
            validation_pB = pB[factor].copy()
            # Expected increment: lr * spm_cross(qs, qs_prev), masked to the
            # entries where the generative B for the chosen action is non-zero.
            validation_pB[:, :, a] += learning_rate * core.spm_cross(qs[factor].values, qs_prev[factor].values) * (B[factor][:, :, a].values > 0)
            self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
Exemple #5
0
    def test_expectation_single_factor(self):
        """Compare ``expectation_of_log`` with the MATLAB reference (single factor).

        The input ``A`` and the expected ``result`` are loaded from
        ``tests/data/wnorm_a.mat``, resolved against the current working directory.
        """

        mat_path = os.path.join(os.getcwd(), "tests/data/wnorm_a.mat")
        reference = loadmat(file_name=mat_path)
        expected = reference["result"]

        computed = Dirichlet(values=reference["A"]).expectation_of_log(return_numpy=True)
        self.assertTrue(np.isclose(expected, computed).all())
Exemple #6
0
    def test_expectation_multi_factor(self):
        """Compare ``expectation_of_log`` with the MATLAB reference (multi factor).

        ``tests/data/wnorm_b.mat`` (resolved against the current working directory)
        provides the per-factor inputs in ``A[0]`` and the expected outputs for the
        first and second factors in ``result_1`` and ``result_2``.
        """

        mat_path = os.path.join(os.getcwd(), "tests/data/wnorm_b.mat")
        reference = loadmat(file_name=mat_path)
        expected_first = reference["result_1"]
        expected_second = reference["result_2"]

        computed = Dirichlet(values=reference["A"][0]).expectation_of_log(return_numpy=True)

        # Both factors must match; check them one after the other.
        self.assertTrue(np.isclose(expected_first, computed[0]).all())
        self.assertTrue(np.isclose(expected_second, computed[1]).all())
    def test_update_pA_multiFactor_somemodalities(self):
        """
        Check the update of the prior Dirichlet parameters over the sensory
        likelihood (pA) when only SOME observation modalities are updated and
        the generative model has multiple hidden state factors. Modalities that
        are not listed in ``modalities_to_update`` must come back unchanged.
        """

        n_states = [2, 6]
        qs = Categorical(values=construct_init_qs(n_states))
        learning_rate = 1.0

        # three observation modalities, of which only the first and last are learned
        num_obs = [3, 4, 5]
        modalities_to_update = [0, 2]

        A = Categorical(values=construct_generic_A(num_obs, n_states))
        pA = Dirichlet(values=construct_pA(num_obs, n_states))

        observation = A.dot(qs, return_numpy=False).sample()

        pA_updated = core.update_likelihood_dirichlet(
            pA, A, observation, qs,
            lr=learning_rate, modalities=modalities_to_update, return_numpy=True,
        )

        for modality, n_outcomes in enumerate(num_obs):
            expected = pA[modality]
            if modality in modalities_to_update:
                one_hot = np.eye(n_outcomes)[observation[modality]]
                expected = expected + learning_rate * core.spm_cross(one_hot, qs.values)
            self.assertTrue(np.all(pA_updated[modality] == expected.values))
Exemple #8
0
 def test_multi_factor_init_values_expand(self):
     """1-D factor values are expanded to column shapes ``(n, 1)``."""
     # The factors have different lengths, so store them in an explicit object
     # array. ``np.array`` on such ragged input raises ValueError on NumPy >= 1.24.
     values = np.empty(2, dtype=object)
     values[0] = np.random.rand(5)
     values[1] = np.random.rand(4)
     d = Dirichlet(values=values)
     self.assertEqual(d.shape, (2, ))
     self.assertEqual(d[0].shape, (5, 1))
     self.assertEqual(d[1].shape, (4, 1))
Exemple #9
0
 def test_contains_zeros(self):
     """A Dirichlet holding a 0.0 entry reports contains_zeros(); an all-ones one does not."""
     with_zero = Dirichlet(values=np.array([[1.0, 0.0], [1.0, 1.0]]))
     self.assertTrue(with_zero.contains_zeros())

     all_ones = Dirichlet(values=np.array([[1.0, 1.0], [1.0, 1.0]]))
     self.assertFalse(all_ones.contains_zeros())
    def test_update_pB_multiFactor_someControllable_someFactors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and some of them
        are controllable (more than one control state). Only the factors listed in
        ``factors_to_update`` are updated; the others must equal their prior.
        """

        n_states = [3, 4, 5]
        n_control = [1, 3, 1]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        learning_rate = 1.0

        factors_to_update = [0, 1]

        # Factors have different shapes, so they are stored in object arrays.
        B_values = np.empty(len(n_states), dtype=object)
        pB_values = np.empty(len(n_states), dtype=object)
        for factor, ns in enumerate(n_states):
            B_values[factor] = np.random.rand(ns, ns, n_control[factor])
            pB_values[factor] = np.ones((ns, ns, n_control[factor]))

        B = Categorical(values=B_values)
        B.normalize()
        pB = Dirichlet(values=pB_values)

        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = core.update_transition_dirichlet(
            pB, B, action, qs, qs_prev,
            lr=learning_rate, factors=factors_to_update, return_numpy=True,
        )

        for factor in range(len(n_control)):
            # Each factor's expectation starts from its own copy of the prior.
            validation_pB = pB[factor].copy()

            if factor in factors_to_update:
                a = action[factor]
                # Expected increment: lr * spm_cross(qs, qs_prev), masked to the
                # entries where the generative B for the chosen action is non-zero.
                validation_pB[:, :, a] += learning_rate * core.spm_cross(qs[factor].values, qs_prev[factor].values) * (B[factor][:, :, a].values > 0)

            self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
    def test_update_pA_singleFactor_all(self):
        """
        Check the update of the prior Dirichlet parameters over the sensory
        likelihood (pA) when the generative model has a single hidden state
        factor and every observation modality is updated. A single-modality and
        a two-modality observation model are both exercised.
        """

        n_states = [3]
        qs = Categorical(values=construct_init_qs(n_states))
        learning_rate = 1.0

        # --- single observation modality ---
        single_obs = [4]

        A = Categorical(values=construct_generic_A(single_obs, n_states))
        pA = Dirichlet(values=construct_pA(single_obs, n_states))
        observation = A.dot(qs, return_numpy=False).sample()

        pA_updated = core.update_likelihood_dirichlet(
            pA, A, observation, qs, lr=learning_rate, modalities="all", return_numpy=True
        )

        one_hot = np.eye(single_obs[0])[observation]
        expected = pA + learning_rate * core.spm_cross(one_hot, qs.values)
        self.assertTrue(np.all(pA_updated == expected.values))

        # --- multiple observation modalities ---
        multi_obs = [3, 4]

        A = Categorical(values=construct_generic_A(multi_obs, n_states))
        pA = Dirichlet(values=construct_pA(multi_obs, n_states))
        observation = A.dot(qs, return_numpy=False).sample()

        pA_updated = core.update_likelihood_dirichlet(
            pA, A, observation, qs, lr=learning_rate, modalities="all", return_numpy=True
        )

        for modality, n_outcomes in enumerate(multi_obs):
            one_hot = np.eye(n_outcomes)[observation[modality]]
            expected = pA[modality] + learning_rate * core.spm_cross(one_hot, qs.values)
            self.assertTrue(np.all(pA_updated[modality] == expected.values))
    def test_update_pB_singleFactor_withActions(self):
        """
        Check the update of the prior Dirichlet parameters over the transition
        likelihood (pB) when the model has exactly one hidden state factor, that
        factor is updated, and there are several actions to choose from.
        """

        n_states = [3]
        n_control = [3]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        learning_rate = 1.0

        B = Categorical(values=construct_generic_B(n_states, n_control))
        pB = Dirichlet(values=np.ones_like(B.values))

        action = np.array([np.random.randint(n_actions) for n_actions in n_control])

        pB_updated = core.update_transition_dirichlet(
            pB, B, action, qs, qs_prev, lr=learning_rate, factors="all", return_numpy=True
        )

        chosen = action[0]
        expected = pB.copy()
        # Only the chosen action's slice grows, and only where B is non-zero.
        expected[:, :, chosen] += learning_rate * core.spm_cross(qs.values, qs_prev.values) * (B[:, :, chosen].values > 0)
        self.assertTrue(np.all(pB_updated == expected.values))
    def test_pA_info_gain(self):
        """
        Exercise ``calc_pA_info_gain``: skew the Dirichlet prior over the
        likelihood parameters (pA) and check that the policy leading to the
        less-certain part of the likelihood yields more information gain.
        """

        n_states = [2]
        n_control = [2]

        # Beliefs fully concentrated on hidden state 0.
        qs = Categorical(values=np.eye(n_states[0])[0])
        B = Categorical(values=construct_generic_B(n_states, n_control))

        # One-step policies.
        policies = core.construct_policies(n_states, n_control, policy_len=1)

        # A single observation modality with a noiseless identity mapping.
        num_obs = [2]
        A = Categorical(values=np.eye(num_obs[0]))

        # Skew the prior on the mapping from hidden state 1 to observations so
        # that one observation is already considered more likely under that
        # state; sampling it should then be more informative about the
        # corresponding part of the likelihood.
        pA_matrix = construct_pA(num_obs, n_states)
        pA_matrix[0, 1] = 2.0
        pA = Dirichlet(values=pA_matrix)

        gains = np.zeros(len(policies))
        for i, policy in enumerate(policies):
            expected_states = core.get_expected_states(qs, B, policy)
            expected_obs = core.get_expected_obs(expected_states, A)
            gains[i] += core.calc_pA_info_gain(pA, expected_obs, expected_states)

        self.assertGreater(gains[1], gains[0])
    def test_pB_info_gain(self):
        """
        Exercise ``calc_pB_info_gain``: skew the Dirichlet prior over the
        transition parameters (pB) for the second action and check that the
        policy taking that action yields more information gain.
        """

        n_states = [2]
        n_control = [2]

        # Beliefs fully concentrated on hidden state 0.
        qs = Categorical(values=np.eye(n_states[0])[0])
        B = Categorical(values=construct_generic_B(n_states, n_control))

        # Under the second action (index 1), make transitions into hidden
        # state 0 look more likely a priori, so that taking this action leads
        # to an expected state that is informative about that part of pB.
        pB_matrix = construct_pB(n_states, n_control)
        pB_matrix[0, :, 1] = 2.0
        pB = Dirichlet(values=pB_matrix)

        # One-step policies.
        policies = core.construct_policies(n_states, n_control, policy_len=1)

        gains = np.zeros(len(policies))
        for i, policy in enumerate(policies):
            expected_states = core.get_expected_states(qs, B, policy)
            gains[i] += core.calc_pB_info_gain(pB, expected_states, qs, policy)

        self.assertGreater(gains[1], gains[0])
Exemple #15
0
 def test_ndim(self):
     """``ndim`` agrees with the dimensionality of the ``values`` array."""
     dist = Dirichlet(values=np.random.rand(3, 2))
     self.assertEqual(dist.ndim, dist.values.ndim)
Exemple #16
0
    def __init__(self,
                 A=None,
                 pA=None,
                 B=None,
                 pB=None,
                 C=None,
                 D=None,
                 n_states=None,
                 n_observations=None,
                 n_controls=None,
                 policy_len=1,
                 control_fac_idx=None,
                 policies=None,
                 gamma=16.0,
                 use_utility=True,
                 use_states_info_gain=True,
                 use_param_info_gain=False,
                 action_sampling="marginal_action",
                 inference_algo="FPI",
                 inference_params=None,
                 modalities_to_learn="all",
                 lr_pA=1.0,
                 factors_to_learn="all",
                 lr_pB=1.0):
        """Set up the agent's generative model, priors, and configuration.

        ``A`` (observation model) and ``B`` (transition model) may be passed as
        ``Categorical`` objects or as raw values, which are wrapped in
        ``Categorical``. When either is omitted, its dimensions must be supplied
        through ``n_observations`` / ``n_states``, and the matrix is built with
        ``_construct_A_distribution`` / ``_construct_B_distribution``.
        ``pA`` / ``pB`` are optional Dirichlet priors over ``A`` / ``B``; raw
        values are wrapped in ``Dirichlet``. ``C`` and ``D`` are built with
        ``_construct_C_prior`` / ``_construct_D_prior`` when not given.
        ``inference_params``, when given, is used as-is; otherwise it defaults
        to ``_get_default_params()``.

        Raises:
            ValueError: if neither ``A`` nor ``n_observations`` is given, or if
                neither ``B`` nor ``n_states`` is given.
        """

        ### Constant parameters ###

        # policy parameters
        self.policy_len = policy_len
        self.gamma = gamma
        self.action_sampling = action_sampling
        self.use_utility = use_utility
        self.use_states_info_gain = use_states_info_gain
        self.use_param_info_gain = use_param_info_gain

        # learning parameters
        self.modalities_to_learn = modalities_to_learn
        self.lr_pA = lr_pA
        self.factors_to_learn = factors_to_learn
        self.lr_pB = lr_pB

        # --- Observation model (A matrices) ---
        if A is not None:
            self.A = A if isinstance(A, Categorical) else Categorical(values=A)

            # Determine number of modalities and observations
            if self.A.IS_AOA:
                self.n_modalities = self.A.shape[0]
                self.n_observations = [
                    self.A[modality].shape[0]
                    for modality in range(self.n_modalities)
                ]
            else:
                self.n_modalities = 1
                self.n_observations = [self.A.shape[0]]
            construct_A_flag = False
        else:
            # Without A, the matrix is built later; this needs its dimensions.
            if n_observations is None:
                raise ValueError(
                    "Must provide either `A` or `n_observations` to `Agent` constructor"
                )
            self.n_observations = n_observations
            self.n_modalities = len(self.n_observations)
            construct_A_flag = True

        # --- Prior Dirichlet parameters on the observation model (pA) ---
        if pA is not None:
            self.pA = pA if isinstance(pA, Dirichlet) else Dirichlet(values=pA)
        else:
            self.pA = None

        # --- Transition model (B matrices) ---
        if B is not None:
            self.B = B if isinstance(B, Categorical) else Categorical(values=B)

            # Same logic as for A, but here we need number of factors and
            # states per factor.
            if self.B.IS_AOA:
                self.n_factors = self.B.shape[0]
                self.n_states = [
                    self.B[f].shape[0] for f in range(self.n_factors)
                ]
            else:
                self.n_factors = 1
                self.n_states = [self.B.shape[0]]
            construct_B_flag = False
        else:
            if n_states is None:
                raise ValueError(
                    "Must provide either `B` or `n_states` to `Agent` constructor"
                )
            self.n_states = n_states
            # Previously ``len(self.n_factors)``, which raised AttributeError
            # because ``n_factors`` had not been set yet.
            self.n_factors = len(self.n_states)
            construct_B_flag = True

        # --- Prior Dirichlet parameters on the transition model (pB) ---
        if pB is not None:
            # Previously wrapped ``pA`` here by mistake.
            self.pB = pB if isinstance(pB, Dirichlet) else Dirichlet(values=pB)
        else:
            self.pB = None

        # Users have the option to make only certain factors controllable.
        # Default behaviour is to make all hidden state factors controllable
        # (i.e. self.n_states == self.n_controls).
        if control_fac_idx is None:
            self.control_fac_idx = list(range(self.n_factors))
        else:
            self.control_fac_idx = control_fac_idx

        # The user can specify the number of control states; otherwise it is
        # inferred from the controllable factors.
        if n_controls is None:
            _, self.n_controls = self._construct_n_controls()
        else:
            self.n_controls = n_controls

        # The user can also specify the set of possible policies; otherwise the
        # policies come from ``_construct_n_controls``.
        if policies is None:
            self.policies, _ = self._construct_n_controls()
        else:
            self.policies = policies

        # Construct prior preferences (uniform if not specified)
        if C is not None:
            self.C = C if isinstance(C, Categorical) else Categorical(values=C)
        else:
            self.C = self._construct_C_prior()

        # Construct initial beliefs (uniform if not specified)
        if D is not None:
            self.D = D if isinstance(D, Categorical) else Categorical(values=D)
        else:
            self.D = self._construct_D_prior()

        # Build the model matrices that were not supplied.
        if construct_A_flag:
            self.A = self._construct_A_distribution()
        if construct_B_flag:
            self.B = self._construct_B_distribution()

        # Inference settings. ``self.inference_algo`` is set before the defaults
        # are computed, in case ``_get_default_params`` reads it (unverified).
        self.inference_algo = "FPI" if inference_algo is None else inference_algo
        default_params = self._get_default_params()
        # A caller-supplied ``inference_params`` used to be silently discarded.
        self.inference_params = default_params if inference_params is None else inference_params

        self.qs = self.D
        self.action = None
Exemple #17
0
 def test_multi_factor_init_dims(self):
     """``dims`` given as a list of shapes creates one factor per shape."""
     expected_shapes = [(5, 4), (4, 3)]
     d = Dirichlet(dims=[list(shape) for shape in expected_shapes])
     self.assertEqual(d.shape, (2, ))
     for idx, shape in enumerate(expected_shapes):
         self.assertEqual(d[idx].shape, shape)
Exemple #18
0
 def test_init_dims_int_expand(self):
     """An integer ``dims`` is expanded to the column shape ``(dims, 1)``."""
     self.assertEqual(Dirichlet(dims=5).shape, (5, 1))
Exemple #19
0
 def test_float_conversion(self):
     """Integer-valued input is stored as ``float64`` parameters."""
     values = np.array([2, 3])
     # ``np.int`` (an alias of the builtin ``int``) was removed in NumPy 1.24
     # and raised AttributeError here; accept any integer dtype instead.
     self.assertTrue(np.issubdtype(values.dtype, np.integer))
     d = Dirichlet(values=values)
     self.assertEqual(d.values.dtype, np.float64)
Exemple #20
0
 def test_init_overload(self):
     """Passing both ``dims`` and ``values`` to the constructor raises ValueError."""
     values = np.random.rand(3, 2)
     with self.assertRaises(ValueError):
         _ = Dirichlet(dims=2, values=values)
Exemple #21
0
 def test_init_empty(self):
     """A Dirichlet built with no arguments is two-dimensional."""
     self.assertEqual(Dirichlet().ndim, 2)
Exemple #22
0
 def test_log(self):
     """``log(return_numpy=True)`` matches ``np.log`` elementwise."""
     raw = np.random.rand(3, 2)
     expected = np.log(raw)
     dist = Dirichlet(values=raw)
     self.assertTrue(np.array_equal(dist.log(return_numpy=True), expected))
Exemple #23
0
 def test_normalize_two_dim(self):
     """The mean of an all-ones 2x2 Dirichlet has every entry equal to 0.5."""
     dist = Dirichlet(values=np.ones((2, 2)))
     expected = np.full((2, 2), 0.5)
     self.assertTrue(np.array_equal(dist.mean(return_numpy=True), expected))
    def test_multistep_multifac_posteriorPolicies(self):
        """
        Test for computing posterior over policies (and associated expected free energies)
        in the case of a posterior over hidden states with multiple hidden state factors.
        This version uses a policy horizon of 3 steps ahead and runs twice: once
        with a single observation modality and once with multiple modalities.
        """

        n_states = [3, 4]
        n_control = [3, 4]

        qs = Categorical(values=construct_init_qs(n_states))
        B = Categorical(values=construct_generic_B(n_states, n_control))
        pB = Dirichlet(values=construct_pB(n_states, n_control))

        # three-step policy horizon (the old comment wrongly said "single timestep")
        n_step = 3
        policies = core.construct_policies(n_states,
                                           n_control,
                                           policy_len=n_step)

        # First a single observation modality, then multiple observation modalities.
        for num_obs in ([4], [3, 2]):
            A = Categorical(values=construct_generic_A(num_obs, n_states))
            pA = Dirichlet(values=construct_pA(num_obs, n_states))
            C = Categorical(values=construct_generic_C(num_obs))

            q_pi, efe = core.update_posterior_policies(qs,
                                                       A,
                                                       B,
                                                       C,
                                                       policies,
                                                       use_utility=True,
                                                       use_states_info_gain=True,
                                                       use_param_info_gain=True,
                                                       pA=pA,
                                                       pB=pB,
                                                       gamma=16.0,
                                                       return_numpy=True)

            # One posterior probability and one EFE value per policy.
            self.assertEqual(len(q_pi), len(policies))
            self.assertEqual(len(efe), len(policies))
Exemple #25
0
 def test_remove_zeros(self):
     """After ``remove_zeros`` no parameter is exactly 0.0."""
     dist = Dirichlet(values=np.array([[1.0, 0.0], [1.0, 1.0]]))
     self.assertTrue(np.any(dist.values == 0.0))
     dist.remove_zeros()
     self.assertFalse(np.any(dist.values == 0.0))
Exemple #26
0
 def test_shape(self):
     """``shape`` reports the shape of the values it was built from."""
     dist = Dirichlet(values=np.random.rand(3, 2))
     self.assertEqual(dist.shape, (3, 2))
Exemple #27
0
def to_dirichlet(values):
    """Wrap ``values`` in a ``Dirichlet``.

    ``values`` is passed unchanged as the ``values`` keyword of the
    ``Dirichlet`` constructor, and the new instance is returned.
    """
    dist = Dirichlet(values=values)
    return dist
# Prior
We initialise the agent's prior over hidden states to be a flat (uniform) distribution, i.e. the agent has no strong beliefs about which state it starts in.

# Posterior (recognition density)
We initialise the posterior beliefs `qs` about hidden states (namely, beliefs about 'where I am') as a flat distribution over the possible states. Because this starting belief is uninformative, the
agent must first gather a proprioceptive observation from the environment (e.g. a bodily sensation of where it feels itself to be) before it can update its posterior to be centered
on the true, evidence-supported location.
"""

# Observation model: wrap the environment's likelihood matrix in a Categorical.
# remove_zeros() presumably replaces exact zero entries (e.g. so that logs stay
# finite) — confirm against the Categorical implementation.
likelihood_matrix = env.get_likelihood_dist()
A = Categorical(values=likelihood_matrix)
A.remove_zeros()
plot_likelihood(A, 'Observation likelihood')

# Transition model: an all-ones Dirichlet over an (n_states, n_states) matrix;
# the initial B is its mean, i.e. an uninformative starting transition matrix.
# NOTE(review): this uses a bare `n_states`, while `qs` below uses
# `env.n_states` — presumably the same value; confirm.
b = Dirichlet(values=np.ones((n_states, n_states)))
B = b.mean()
plot_likelihood(B, 'Initial transition likelihood')

# Prior over the initial hidden state: all-ones, then normalized (flat).
D = Categorical(values=np.ones(n_states))
D.normalize()

# Posterior beliefs about hidden states, created from dims and normalized.
qs = Categorical(dims=[env.n_states])
qs.normalize()
"""
Run the environment dynamics together with the agent's inference, starting by eliciting a single observation from the environment.
"""

# Reset the environment to the first state of `s`.
# NOTE(review): `s` is defined outside this view — presumably a pre-generated
# sequence of true states; confirm.
first_state = env.reset(init_state=s[0])