def test_update_pB_single_factor_no_actions(self):
    """
    Test for updating prior Dirichlet parameters over transition likelihood (pB)
    in the case that the one and only hidden state factor is updated, and there
    are no actions.
    """
    n_states = [3]
    n_control = [ 1 ]  # this is how we encode the fact that there aren't any actions
    qs_prev = Categorical(values=construct_init_qs(n_states))
    qs = Categorical(values=construct_init_qs(n_states))
    l_rate = 1.0

    B = Categorical(
        values=np.random.rand(n_states[0], n_states[0], n_control[0]))
    B.normalize()
    # flat (all-ones) Dirichlet prior, same shape as B
    pB = Dirichlet(values=np.ones_like(B.values))

    # with a single dummy control state, the sampled action is always 0
    action = np.array([np.random.randint(nc) for nc in n_control])

    pB_updated = learning.update_transition_dirichlet(pB,
                                                      B,
                                                      action,
                                                      qs,
                                                      qs_prev,
                                                      lr=l_rate,
                                                      factors="all",
                                                      return_numpy=True)

    # expected increment on the (single) action slice:
    # lr * outer(qs, qs_prev), masked by where B has support
    validation_pB = pB.copy()
    validation_pB[:, :, 0] += (l_rate *
                               maths.spm_cross(qs.values, qs_prev.values) *
                               (B[:, :, action[0]].values > 0))
    self.assertTrue(np.all(pB_updated == validation_pB.values))
def test_normalize_multi_factor(self):
    """ The mean of a multi-factor Dirichlet is a normalized Categorical. """
    factor_values = np.array(
        [np.random.rand(5), np.random.rand(4, 3)], dtype=object)
    dirichlet = Dirichlet(values=factor_values)
    mean_dist = Categorical(values=dirichlet.mean(return_numpy=True))
    self.assertTrue(mean_dist.is_normalized())
def test_copy(self):
    """ A copy matches the original's values, and mutating the copy leaves
    the original untouched. """
    original = Dirichlet(values=np.random.rand(3, 2))
    duplicate = original.copy()
    self.assertTrue(np.array_equal(duplicate.values, original.values))
    duplicate.values = duplicate.values * 2
    self.assertFalse(np.array_equal(duplicate.values, original.values))
def test_update_pB_multi_factor_with_actions_all_factors(self):
    """
    Test for updating prior Dirichlet parameters over transition likelihood (pB)
    in the case that there are multiple hidden state factors, and there are
    actions. All factors are updated.
    """
    n_states = [3, 4, 5]
    n_control = [3, 4, 5]
    qs_prev = Categorical(values=construct_init_qs(n_states))
    qs = Categorical(values=construct_init_qs(n_states))
    l_rate = 1.0

    B = Categorical(values=construct_generic_B(n_states, n_control))
    B.normalize()
    pB = Dirichlet(values=construct_pB(n_states, n_control))

    # sample one action per (controllable) factor
    action = np.array([np.random.randint(nc) for nc in n_control])

    pB_updated = core.update_transition_dirichlet(pB,
                                                  B,
                                                  action,
                                                  qs,
                                                  qs_prev,
                                                  lr=l_rate,
                                                  factors="all",
                                                  return_numpy=True)

    # Validate each factor independently: expected increment on the slice for
    # the taken action is lr * outer(qs, qs_prev), masked by B's support.
    # (Removed a dead `validation_pB = pB.copy()` that preceded the loop --
    # it was immediately shadowed by the per-factor copy below.)
    for factor, _ in enumerate(n_control):
        validation_pB = pB[factor].copy()
        validation_pB[:, :, action[factor]] += (
            l_rate *
            core.spm_cross(qs[factor].values, qs_prev[factor].values) *
            (B[factor][:, :, action[factor]].values > 0))
        self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
# NOTE(review): "dactor" in the method name looks like a typo for "factor" --
# confirm before renaming (test runners discover tests by name).
def test_update_pB_single_dactor_with_actions(self):
    """
    Test for updating prior Dirichlet parameters over transition likelihood (pB)
    in the case that the one and only hidden state factor is updated, and there
    are actions.
    """
    n_states = [3]
    n_control = [3]
    qs_prev = Categorical(values=construct_init_qs(n_states))
    qs = Categorical(values=construct_init_qs(n_states))
    l_rate = 1.0

    # NOTE(review): unlike the no-actions variant, B.normalize() is not called
    # here -- presumably construct_generic_B already returns normalized
    # values; confirm.
    B = Categorical(values=construct_generic_B(n_states, n_control))
    pB = Dirichlet(values=np.ones_like(B.values))

    action = np.array([np.random.randint(nc) for nc in n_control])

    pB_updated = core.update_transition_dirichlet(pB,
                                                  B,
                                                  action,
                                                  qs,
                                                  qs_prev,
                                                  lr=l_rate,
                                                  factors="all",
                                                  return_numpy=True)

    # expected increment on the taken-action slice:
    # lr * outer(qs, qs_prev), masked by where B has support
    validation_pB = pB.copy()
    validation_pB[:, :, action[0]] += (
        l_rate * core.spm_cross(qs.values, qs_prev.values) *
        (B[:, :, action[0]].values > 0))
    self.assertTrue(np.all(pB_updated == validation_pB.values))
def test_pB_info_gain(self):
    """
    Test the pB_info_gain function. Demonstrates operation by manipulating the
    shape of the Dirichlet priors over likelihood parameters (pB), which
    affects information gain for different states.
    """
    n_states = [2]
    n_control = [2]
    # posterior concentrated on hidden state 0
    qs = Categorical(values=np.eye(n_states[0])[0])
    B = Categorical(values=construct_generic_B(n_states, n_control))
    pB_matrix = construct_pB(n_states, n_control)

    # create prior over dirichlets such that there is a skew
    # in the parameters about the likelihood mapping from the
    # hidden states to hidden states under the second action,
    # such that hidden state 0 is considered to be more likely than the other,
    # given the action in question
    # Therefore taking that action would yield an expected state that afford
    # high information gain about that part of the likelihood distribution.
    # NOTE(review): the skew described above is commented out below, yet the
    # final assertGreater relies on policy 1 having higher info gain --
    # confirm whether construct_pB already provides the asymmetry or whether
    # this line should be re-enabled.
    # pB_matrix[0, :, 1] = 2.0

    pB = Dirichlet(values=pB_matrix)

    # single timestep
    n_step = 1
    policies = core.construct_policies(n_states, n_control,
                                       policy_len=n_step)

    # accumulate the pB info gain for each one-step policy
    pB_info_gains = np.zeros(len(policies))
    for idx, policy in enumerate(policies):
        qs_pi = core.get_expected_states(qs, B, policy)
        pB_info_gains[idx] += core.calc_pB_info_gain(pB, qs_pi, qs, policy)
    self.assertGreater(pB_info_gains[1], pB_info_gains[0])
def test_update_pA_multi_factor_some_modalities(self):
    """
    Test for updating prior Dirichlet parameters over sensory likelihood (pA)
    in the case that SOME observation modalities are updated and the
    generative model has multiple hidden state factors.
    """
    n_states = [2, 6]
    qs = Categorical(values=construct_init_qs(n_states))
    l_rate = 1.0

    # multiple observation modalities
    num_obs = [3, 4, 5]
    modalities_to_update = [0, 2]
    A = Categorical(values=construct_generic_A(num_obs, n_states))
    pA = Dirichlet(values=construct_pA(num_obs, n_states))

    # sample a (multi-modality) observation from the generative model
    observation = A.dot(qs, return_numpy=False).sample()

    pA_updated = core.update_likelihood_dirichlet(
        pA,
        A,
        observation,
        qs,
        lr=l_rate,
        modalities=modalities_to_update,
        return_numpy=True)

    # only the selected modalities should receive the Dirichlet increment
    # lr * outer(one_hot(observation), qs); the others must be untouched
    for modality, no in enumerate(num_obs):
        if modality in modalities_to_update:
            update = core.spm_cross(
                np.eye(no)[observation[modality]], qs.values)
            validation_pA = pA[modality] + l_rate * update
        else:
            validation_pA = pA[modality]
        self.assertTrue(
            np.all(pA_updated[modality] == validation_pA.values))
def test_update_pA_single_factor_one_modality(self):
    """
    Test for updating prior Dirichlet parameters over sensory likelihood (pA)
    in the case that ONE observation modality is updated and the generative
    model has a single hidden state factor.
    """
    n_states = [3]
    qs = Categorical(values=construct_init_qs(n_states))
    l_rate = 1.0

    # multiple observation modalities; pick one at random to update
    num_obs = [3, 4]
    modality_to_update = [np.random.randint(len(num_obs))]
    A = Categorical(values=construct_generic_A(num_obs, n_states))
    pA = Dirichlet(values=construct_pA(num_obs, n_states))

    # sample a (multi-modality) observation from the generative model
    observation = A.dot(qs, return_numpy=False).sample()

    pA_updated = learning.update_likelihood_dirichlet(
        pA,
        A,
        observation,
        qs,
        lr=l_rate,
        modalities=modality_to_update,
        return_numpy=True)

    # only the chosen modality should receive the Dirichlet increment
    # lr * outer(one_hot(observation), qs); the other must be untouched
    for modality, no in enumerate(num_obs):
        if modality in modality_to_update:
            update = maths.spm_cross(
                np.eye(no)[observation[modality]], qs.values)
            validation_pA = pA[modality] + l_rate * update
        else:
            validation_pA = pA[modality]
        self.assertTrue(
            np.all(pA_updated[modality] == validation_pA.values))
def test_update_pB_multi_factor_no_actions_one_factor(self):
    """
    Test for updating prior Dirichlet parameters over transition likelihood (pB)
    in the case that there are multiple hidden state factors, and there are no
    actions. One (randomly chosen) factor is updated.
    """
    n_states = [3, 4]
    n_control = [1, 1]  # a single dummy control state per factor = no actions
    qs_prev = Categorical(values=construct_init_qs(n_states))
    qs = Categorical(values=construct_init_qs(n_states))
    l_rate = 1.0

    factors_to_update = [np.random.randint(len(n_states))]

    B = Categorical(values=np.array([
        np.random.rand(ns, ns, n_control[factor])
        for factor, ns in enumerate(n_states)
    ], dtype=object))
    B.normalize()
    # flat (all-ones) Dirichlet prior, matching B factor-by-factor
    pB = Dirichlet(values=np.array([
        np.ones_like(B[factor].values) for factor in range(len(n_states))
    ], dtype=object))

    action = np.array([np.random.randint(nc) for nc in n_control])

    pB_updated = learning.update_transition_dirichlet(
        pB,
        B,
        action,
        qs,
        qs_prev,
        lr=l_rate,
        factors=factors_to_update,
        return_numpy=True)

    # The chosen factor receives lr * outer(qs, qs_prev) masked by B's
    # support; the other factor must remain equal to the prior.
    # (Removed a dead `validation_pB = pB.copy()` that preceded the loop --
    # it was immediately shadowed by the per-factor copy below.)
    for factor, _ in enumerate(n_control):
        validation_pB = pB[factor].copy()
        if factor in factors_to_update:
            validation_pB[:, :, action[factor]] += (
                l_rate *
                maths.spm_cross(qs[factor].values, qs_prev[factor].values) *
                (B[factor][:, :, action[factor]].values > 0))
        self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
def test_multi_factor_init_values_expand(self):
    """ 1-D per-factor value arrays are expanded to column vectors on init. """
    factor_values = np.array(
        [np.random.rand(5), np.random.rand(4)], dtype=object)
    d = Dirichlet(values=factor_values)
    self.assertEqual(d.shape, (2, ))
    for factor, expected in enumerate([(5, 1), (4, 1)]):
        self.assertEqual(d[factor].shape, expected)
def test_update_pA_single_factor_all(self):
    """
    Test for updating prior Dirichlet parameters over sensory likelihood (pA)
    in the case that all observation modalities are updated and the generative
    model has a single hidden state factor.
    """
    n_states = [3]
    qs = Categorical(values=construct_init_qs(n_states))
    l_rate = 1.0

    # single observation modality
    num_obs = [4]
    A = Categorical(values=construct_generic_A(num_obs, n_states))
    pA = Dirichlet(values=construct_pA(num_obs, n_states))

    # sample an observation from the generative model
    observation = A.dot(qs, return_numpy=False).sample()

    pA_updated = core.update_likelihood_dirichlet(pA,
                                                  A,
                                                  observation,
                                                  qs,
                                                  lr=l_rate,
                                                  modalities="all",
                                                  return_numpy=True)

    # expected increment: lr * outer(one_hot(observation), qs)
    validation_pA = pA + l_rate * core.spm_cross(
        np.eye(*num_obs)[observation], qs.values)
    self.assertTrue(np.all(pA_updated == validation_pA.values))

    # multiple observation modalities
    num_obs = [3, 4]
    A = Categorical(values=construct_generic_A(num_obs, n_states))
    pA = Dirichlet(values=construct_pA(num_obs, n_states))

    observation = A.dot(qs, return_numpy=False).sample()

    pA_updated = core.update_likelihood_dirichlet(pA,
                                                  A,
                                                  observation,
                                                  qs,
                                                  lr=l_rate,
                                                  modalities="all",
                                                  return_numpy=True)

    # with modalities="all", every modality receives its own increment
    for modality, no in enumerate(num_obs):
        update = core.spm_cross(
            np.eye(no)[observation[modality]], qs.values)
        validation_pA = pA[modality] + l_rate * update
        self.assertTrue(
            np.all(pA_updated[modality] == validation_pA.values))
def test_multi_factor_init_values(self):
    """
    Initialising from per-factor value arrays preserves each factor's shape.
    """
    values_1 = np.random.rand(5, 4)
    values_2 = np.random.rand(4, 3)
    # dtype=object is required: the two factors have different shapes
    # (a "ragged" array), and NumPy >= 1.24 raises a ValueError when
    # building one implicitly. Sibling tests already pass dtype=object.
    values = np.array([values_1, values_2], dtype=object)
    d = Dirichlet(values=values)
    self.assertEqual(d.shape, (2, ))
    self.assertEqual(d[0].shape, (5, 4))
    self.assertEqual(d[1].shape, (4, 3))
def test_update_pB_multi_factor_some_controllable_some_factors(self):
    """
    Test for updating prior Dirichlet parameters over transition likelihood (pB)
    in the case that there are multiple hidden state factors, and some of them
    are controllable. Some factors are updated.
    """
    n_states = [3, 4, 5]
    n_control = [1, 3, 1]  # only factor 1 has real actions
    qs_prev = Categorical(values=construct_init_qs(n_states))
    qs = Categorical(values=construct_init_qs(n_states))
    l_rate = 1.0
    factors_to_update = [0, 1]

    # build per-factor B (random) and pB (flat) with matching shapes
    B_values = np.empty(len(n_states), dtype=object)
    pB_values = np.empty(len(n_states), dtype=object)
    for factor, ns in enumerate(n_states):
        B_values[factor] = np.random.rand(ns, ns, n_control[factor])
        pB_values[factor] = np.ones((ns, ns, n_control[factor]))

    B = Categorical(values=B_values)
    B.normalize()
    pB = Dirichlet(values=pB_values)

    action = np.array([np.random.randint(nc) for nc in n_control])

    pB_updated = core.update_transition_dirichlet(
        pB,
        B,
        action,
        qs,
        qs_prev,
        lr=l_rate,
        factors=factors_to_update,
        return_numpy=True)

    # Updated factors receive lr * outer(qs, qs_prev) masked by B's support;
    # the remaining factor must stay equal to the prior.
    # (Removed a dead `validation_pB = pB.copy()` that preceded the loop --
    # it was immediately shadowed by the per-factor copy below.)
    for factor, _ in enumerate(n_control):
        validation_pB = pB[factor].copy()
        if factor in factors_to_update:
            validation_pB[:, :, action[factor]] += (
                l_rate *
                core.spm_cross(qs[factor].values, qs_prev[factor].values) *
                (B[factor][:, :, action[factor]].values > 0))
        self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
def test_contains_zeros(self):
    """ contains_zeros() is True iff any parameter entry equals zero. """
    with_zero = Dirichlet(values=np.array([[1.0, 0.0], [1.0, 1.0]]))
    self.assertTrue(with_zero.contains_zeros())

    without_zero = Dirichlet(values=np.array([[1.0, 1.0], [1.0, 1.0]]))
    self.assertFalse(without_zero.contains_zeros())
def test_pA_info_gain(self):
    """
    Test the pA_info_gain function. Demonstrates operation by manipulating the
    shape of the Dirichlet priors over likelihood parameters (pA), which
    affects information gain for different expected observations.
    """
    n_states = [2]
    n_control = [2]
    # posterior concentrated on hidden state 0
    qs = Categorical(values=np.eye(n_states[0])[0])
    B = Categorical(values=construct_generic_B(n_states, n_control))

    # single timestep
    n_step = 1
    policies = core.construct_policies(n_states, n_control,
                                       policy_len=n_step)

    # single observation modality
    num_obs = [2]

    # create noiseless identity A matrix
    A = Categorical(values=np.eye(num_obs[0]))

    # create prior over dirichlets such that there is a skew
    # in the parameters about the likelihood mapping from the
    # second hidden state (index 1) to observations, such that one
    # observation is considered to be more likely than the other conditioned on that state.
    # Therefore sampling that observation would afford high info gain
    # about parameters for that part of the likelhood distribution.
    pA_matrix = construct_pA(num_obs, n_states)
    pA_matrix[0, 1] = 2.0
    pA = Dirichlet(values=pA_matrix)

    # accumulate the pA info gain for each one-step policy; the policy
    # leading to the skewed column should score higher
    pA_info_gains = np.zeros(len(policies))
    for idx, policy in enumerate(policies):
        qs_pi = core.get_expected_states(qs, B, policy)
        qo_pi = core.get_expected_obs(qs_pi, A)
        pA_info_gains[idx] += core.calc_pA_info_gain(pA, qo_pi, qs_pi)
    self.assertGreater(pA_info_gains[1], pA_info_gains[0])
def test_float_conversion(self):
    """
    Integer-valued parameters are converted to float64 on construction.
    """
    values = np.array([2, 3])
    # `np.int` (a deprecated alias of the builtin int) was removed in
    # NumPy 1.24; compare against the platform default integer dtype instead.
    self.assertEqual(values.dtype, np.dtype(int))
    d = Dirichlet(values=values)
    self.assertEqual(d.values.dtype, np.float64)
# Prior We initialise the agent's prior over hidden states at the start to be a flat distribution (i.e. the agent has no strong beliefs about what state it is starting in) # Posterior (recognition density) We initialise the posterior beliefs `qs` about hidden states (namely, beliefs about 'where I am') as a flat distribution over the possible states. This requires the agent to first gather a proprioceptive observation from the environment (e.g. a bodily sensation of where it feels itself to be) before updating its posterior to be centered on the true, evidence-supported location. """ likelihood_matrix = env.get_likelihood_dist() A = Categorical(values=likelihood_matrix) A.remove_zeros() plot_likelihood(A, 'Observation likelihood') b = Dirichlet(values=np.ones((n_states, n_states))) B = b.mean() plot_likelihood(B, 'Initial transition likelihood') D = Categorical(values=np.ones(n_states)) D.normalize() qs = Categorical(dims=[env.n_states]) qs.normalize() """ Run the dynamics of the environment and inference. Start by eliciting one observation from the environment """ # reset environment first_state = env.reset(init_state=s[0])
def test_log(self):
    """ log() returns the elementwise natural log of the parameters. """
    raw = np.random.rand(3, 2)
    expected = np.log(raw)
    d = Dirichlet(values=raw)
    self.assertTrue(np.array_equal(d.log(return_numpy=True), expected))
def test_ndim(self):
    """ The ndim property mirrors the underlying array's ndim. """
    dist = Dirichlet(values=np.random.rand(3, 2))
    self.assertEqual(dist.ndim, dist.values.ndim)
def test_multistep_multi_factor_posterior_policies(self):
    """
    Test for computing posterior over policies (and associated expected free
    energies) in the case of a posterior over hidden states with multiple
    hidden state factors. This version tests using a policy horizon of 3 steps
    ahead.
    """
    n_states = [3, 4]
    n_control = [3, 4]
    qs = Categorical(values=construct_init_qs(n_states))
    B = Categorical(values=construct_generic_B(n_states, n_control))
    pB = Dirichlet(values=construct_pB(n_states, n_control))

    # Single timestep
    n_step = 3
    policies = core.construct_policies(n_states, n_control,
                                       policy_len=n_step)

    # Single observation modality
    num_obs = [4]

    A = Categorical(values=construct_generic_A(num_obs, n_states))
    pA = Dirichlet(values=construct_pA(num_obs, n_states))
    C = Categorical(values=construct_generic_C(num_obs))

    # smoke test: one posterior entry and one EFE value per policy
    q_pi, efe = core.update_posterior_policies(
        qs,
        A,
        B,
        C,
        policies,
        use_utility=True,
        use_states_info_gain=True,
        use_param_info_gain=True,
        pA=pA,
        pB=pB,
        gamma=16.0,
        return_numpy=True,
    )

    self.assertEqual(len(q_pi), len(policies))  # type: ignore
    self.assertEqual(len(efe), len(policies))

    # Multiple observation modalities
    num_obs = [3, 2]

    A = Categorical(values=construct_generic_A(num_obs, n_states))
    pA = Dirichlet(values=construct_pA(num_obs, n_states))
    C = Categorical(values=construct_generic_C(num_obs))

    q_pi, efe = core.update_posterior_policies(
        qs,
        A,
        B,
        C,
        policies,
        use_utility=True,
        use_states_info_gain=True,
        use_param_info_gain=True,
        pA=pA,
        pB=pB,
        gamma=16.0,
        return_numpy=True,
    )

    self.assertEqual(len(q_pi), len(policies))  # type: ignore
    self.assertEqual(len(efe), len(policies))
def test_normalize_two_dim(self):
    """ The mean of a uniform 2x2 Dirichlet is 0.5 everywhere. """
    d = Dirichlet(values=np.ones((2, 2)))
    expected_values = np.full((2, 2), 0.5)
    self.assertTrue(
        np.array_equal(d.mean(return_numpy=True), expected_values))
def test_remove_zeros(self):
    """ remove_zeros() eliminates every zero entry from the parameters. """
    d = Dirichlet(values=np.array([[1.0, 0.0], [1.0, 1.0]]))
    self.assertTrue(np.any(d.values == 0.0))
    d.remove_zeros()
    self.assertFalse(np.any(d.values == 0.0))
def test_shape(self):
    """ The shape property mirrors the wrapped array's shape. """
    dist = Dirichlet(values=np.random.rand(3, 2))
    self.assertEqual(dist.shape, (3, 2))
def test_init_empty(self):
    """ A default-constructed Dirichlet is two-dimensional. """
    self.assertEqual(Dirichlet().ndim, 2)
def test_multi_factor_init_dims(self):
    """ Initialising from per-factor dims creates one sub-array per factor. """
    d = Dirichlet(dims=[[5, 4], [4, 3]])
    self.assertEqual(d.shape, (2, ))
    for factor, expected in enumerate([(5, 4), (4, 3)]):
        self.assertEqual(d[factor].shape, expected)
def to_dirichlet(values):
    """ Wrap a raw parameter array in a Dirichlet distribution. """
    wrapped = Dirichlet(values=values)
    return wrapped
def test_init_overload(self):
    """ Supplying both `dims` and `values` must raise a ValueError. """
    random_values = np.random.rand(3, 2)
    with self.assertRaises(ValueError):
        _ = Dirichlet(dims=2, values=random_values)
def test_init_dims_int_expand(self):
    """ An integer `dims` argument is expanded to a column-vector shape. """
    self.assertEqual(Dirichlet(dims=5).shape, (5, 1))
def __init__(self,
             A=None,
             pA=None,
             B=None,
             pB=None,
             C=None,
             D=None,
             n_states=None,
             n_observations=None,
             n_controls=None,
             policy_len=1,
             inference_horizon=1,
             control_fac_idx=None,
             policies=None,
             gamma=16.0,
             use_utility=True,
             use_states_info_gain=True,
             use_param_info_gain=False,
             action_sampling="marginal_action",
             inference_algo="VANILLA",
             inference_params=None,
             modalities_to_learn="all",
             lr_pA=1.0,
             factors_to_learn="all",
             lr_pB=1.0,
             use_BMA=True,
             policy_sep_prior=False):
    """
    Initialise an active-inference agent.

    Either the distributions themselves (`A`, `B`) or their dimensionalities
    (`n_observations`, `n_states`) must be supplied; missing distributions
    are constructed by the `_construct_*_distribution` helpers. `pA`/`pB`
    are optional Dirichlet priors used for learning; `C`/`D` default to
    priors built by `_construct_C_prior`/`_construct_D_prior`.

    Raises
    ------
    ValueError
        If neither `A` nor `n_observations` is given, or neither `B` nor
        `n_states` is given.
    """
    ### Constant parameters ###

    # policy parameters
    self.policy_len = policy_len
    self.gamma = gamma
    self.action_sampling = action_sampling
    self.use_utility = use_utility
    self.use_states_info_gain = use_states_info_gain
    self.use_param_info_gain = use_param_info_gain

    # learning parameters
    self.modalities_to_learn = modalities_to_learn
    self.lr_pA = lr_pA
    self.factors_to_learn = factors_to_learn
    self.lr_pB = lr_pB

    """ Initialise observation model (A matrices) """
    if A is not None:
        # Create `Categorical` if not already one
        if not isinstance(A, Categorical):
            self.A = Categorical(values=A)
        else:
            self.A = A

        # Determine number of modalities and observations per modality
        if self.A.IS_AOA:
            self.n_modalities = self.A.shape[0]
            self.n_observations = [
                self.A[modality].shape[0]
                for modality in range(self.n_modalities)
            ]
        else:
            self.n_modalities = 1
            self.n_observations = [self.A.shape[0]]
        construct_A_flag = False
    else:
        # If A is None, we randomly initialise the matrix; this requires
        # knowing the observation dimensionalities.
        if n_observations is None:
            raise ValueError(
                "Must provide either `A` or `n_observations` to `Agent` constructor"
            )
        self.n_observations = n_observations
        self.n_modalities = len(self.n_observations)
        construct_A_flag = True

    """ Initialise prior Dirichlet parameters on observation model (pA matrices) """
    if pA is not None:
        if not isinstance(pA, Dirichlet):
            self.pA = Dirichlet(values=pA)
        else:
            self.pA = pA
    else:
        self.pA = None

    """ Initialise transition model (B matrices) """
    if B is not None:
        if not isinstance(B, Categorical):
            self.B = Categorical(values=B)
        else:
            self.B = B

        # Same logic as before, but here we need number of factors and states per factor
        if self.B.IS_AOA:
            self.n_factors = self.B.shape[0]
            self.n_states = [
                self.B[f].shape[0] for f in range(self.n_factors)
            ]
        else:
            self.n_factors = 1
            self.n_states = [self.B.shape[0]]
        construct_B_flag = False
    else:
        if n_states is None:
            raise ValueError(
                "Must provide either `B` or `n_states` to `Agent` constructor"
            )
        self.n_states = n_states
        # BUGFIX: was `len(self.n_factors)`, which read the attribute being
        # assigned before it existed; the factor count comes from n_states.
        self.n_factors = len(self.n_states)
        construct_B_flag = True

    """ Initialise prior Dirichlet parameters on transition model (pB matrices) """
    if pB is not None:
        if not isinstance(pB, Dirichlet):
            # BUGFIX: was `Dirichlet(values=pA)` -- a copy-paste error that
            # silently wrapped the observation prior instead of `pB`.
            self.pB = Dirichlet(values=pB)
        else:
            self.pB = pB
    else:
        self.pB = None

    # Users have the option to make only certain factors controllable.
    # default behaviour is to make all hidden state factors controllable
    # (i.e. self.n_states == self.n_controls)
    if control_fac_idx is None:
        self.control_fac_idx = list(range(self.n_factors))
    else:
        self.control_fac_idx = control_fac_idx

    # The user can specify the number of control states
    # However, given the controllable factors, this can be inferred
    if n_controls is None:
        _, self.n_controls = self._construct_n_controls()
    else:
        self.n_controls = n_controls

    # Again, the user can specify a set of possible policies, or
    # all possible combinations of actions and timesteps will be considered
    if policies is None:
        self.policies, _ = self._construct_n_controls()
    else:
        self.policies = policies

    # Construct prior preferences (uniform if not specified)
    if C is not None:
        if isinstance(C, Categorical):
            self.C = C
        else:
            self.C = Categorical(values=C)
    else:
        self.C = self._construct_C_prior()

    # Construct initial beliefs (uniform if not specified)
    if D is not None:
        if isinstance(D, Categorical):
            self.D = D
        else:
            self.D = Categorical(values=D)
    else:
        self.D = self._construct_D_prior()

    # Build model
    if construct_A_flag:
        self.A = self._construct_A_distribution()
    if construct_B_flag:
        self.B = self._construct_B_distribution()

    self.edge_handling_params = {}
    self.edge_handling_params['use_BMA'] = use_BMA
    self.edge_handling_params['policy_sep_prior'] = policy_sep_prior

    if inference_algo is None:
        self.inference_algo = "VANILLA"
        self.inference_params = self._get_default_params()
        if inference_horizon > 1:
            print(
                "WARNING: if `inference_algo` is VANILLA, then inference_horizon must be 1\n. \
Setting inference_horizon to default value of 1...\n")
        # BUGFIX: the horizon was previously set only in the else-branch of
        # the warning, leaving self.inference_horizon undefined whenever the
        # warning fired; VANILLA always uses a horizon of 1.
        self.inference_horizon = 1
    else:
        self.inference_algo = inference_algo
        # NOTE(review): the `inference_params` argument is ignored here and
        # defaults are always used -- confirm whether this is intentional.
        self.inference_params = self._get_default_params()
        self.inference_horizon = inference_horizon

    self.prev_obs = []
    self.reset()
    self.action = None
    self.prev_actions = None