Ejemplo n.º 1
0
def sample_action(q_pi, policies, n_control, sampling_type="marginal_action"):
    """
    Sample an action from the posterior over policies.

    Parameters
    ----------
    q_pi [1D numpy.ndarray or Categorical]:
        Posterior beliefs about the (possibly multi-step) policies.
    policies [list of numpy ndarrays]:
        Candidate policies. Each element is a matrix of shape
        (n_step x n_control_factor) holding, for every timestep, the index of
        the action taken on each control factor.
    n_control [list of integers]:
        Number of actions available on each (controllable) hidden state factor.
    sampling_type [string, 'marginal_action' or 'posterior_sample']:
        'marginal_action': every action index used by a policy (at any timestep)
        accumulates that policy's posterior probability; the resulting per-factor
        marginals are normalized and sampled.
        'posterior_sample': a single policy is drawn from q_pi.

    Returns
    ----------
    selected_policy [numpy ndarray]:
        For 'marginal_action', the sampled action indices (presumably one per
        control factor). For 'posterior_sample', the sampled policy matrix exactly
        as stored in `policies` (NOTE: a different shape from the other mode).

    Raises
    ------
    ValueError
        If `sampling_type` is not one of the supported values.
    """
    n_factors = len(n_control)

    if sampling_type == "posterior_sample":
        # Wrap raw arrays so both input kinds expose `.sample()`.
        posterior = q_pi if utils.is_distribution(q_pi) else Categorical(values=q_pi)
        return policies[posterior.sample()]

    if sampling_type != "marginal_action":
        raise ValueError(f"{sampling_type} not supported")

    if utils.is_distribution(q_pi):
        q_pi = utils.to_numpy(q_pi)

    # One zero-initialised evidence vector per control factor (object array so
    # factors may have different numbers of actions).
    marginals = np.empty(n_factors, dtype=object)
    for factor_idx in range(n_factors):
        marginals[factor_idx] = np.zeros(n_control[factor_idx])

    # Each action used by a policy, at any timestep, gains that policy's probability.
    for policy_idx, policy in enumerate(policies):
        for step in range(policy.shape[0]):
            for factor_idx, action_idx in enumerate(policy[step, :]):
                marginals[factor_idx][action_idx] += q_pi[policy_idx]

    action_dist = Categorical(values=marginals)
    action_dist.normalize()
    return np.array(action_dist.sample())
Ejemplo n.º 2
0
def process_observations(obs, n_modalities, n_observations):
    """
    Convert an observation into the array format used downstream.

    Accepted inputs:
      - a distribution object (e.g. `Categorical`): converted with
        `utils.to_numpy`, then singleton dimensions are squeezed away
        (per modality when `n_modalities != 1`);
      - an `int` / `np.integer`: turned into a one-hot vector of length
        `n_observations[0]`;
      - a `tuple` of per-modality indices: turned into an object array holding
        one one-hot vector (length `n_observations[m]`) per modality.

    Inputs of any other form are returned as given.

    @TODO maybe provide error messaging about observation format
    """
    if utils.is_distribution(obs):
        obs = utils.to_numpy(obs)
        if n_modalities != 1:
            for modality in range(n_modalities):
                obs[modality] = obs[modality].squeeze()
        else:
            obs = obs.squeeze()

    if isinstance(obs, (int, np.integer)):
        identity = np.eye(n_observations[0])
        obs = identity[obs]

    if isinstance(obs, tuple):
        one_hots = np.empty(n_modalities, dtype=object)
        for modality in range(n_modalities):
            identity = np.eye(n_observations[modality])
            one_hots[modality] = identity[obs[modality]]
        obs = one_hots

    return obs
Ejemplo n.º 3
0
def process_priors(prior, n_factors):
    """
    Helper function for formatting priors over hidden state factors.

    Parameters
    ----------
    prior : distribution object, numpy.ndarray, or array-of-arrays
        The prior to normalize into array-of-arrays form.
    n_factors : int
        Number of hidden state factors.

    Returns
    -------
    Array-of-arrays with one entry per factor:
      - a distribution input is unpacked into an object array of length
        `n_factors`, each entry being that factor's `.values` squeezed;
      - an input that is not already an array-of-arrays is wrapped with
        `utils.to_arr_of_arr`;
      - an input that is already an array-of-arrays is returned unchanged.
    """
    if utils.is_distribution(prior):
        prior_arr = np.empty(n_factors, dtype=object)
        if n_factors == 1:
            # Single-factor distribution: its values are the lone factor's prior.
            prior_arr[0] = prior.values.squeeze()
        else:
            for factor in range(n_factors):
                prior_arr[factor] = prior[factor].values.squeeze()
        prior = prior_arr

    elif not utils.is_arr_of_arr(prior):
        prior = utils.to_arr_of_arr(prior)

    return prior
Ejemplo n.º 4
0
def softmax(dist, return_numpy=True):
    """ Computes the softmax function on a set of values

    The softmax is taken along axis 0 (for 2D input, each column sums to one).

    Parameters
    ----------
    dist : numpy.ndarray, array-like, or distribution object
        Values to normalize. For a distribution object its `.values` are used;
        for an array-of-arrays distribution (`dist.IS_AOA`) the softmax is
        applied to each entry independently.
    return_numpy : bool, default True
        If True, return numpy output; otherwise wrap the result with
        `utils.to_categorical`.

    Returns
    -------
    numpy.ndarray (an object array of per-entry results for array-of-arrays
    input) or the `utils.to_categorical` wrapping of that result.
    """
    if utils.is_distribution(dist):
        if dist.IS_AOA:
            # Handle each entry separately and return here; falling through
            # would call `.max` on the distribution object itself.
            output = np.empty(len(dist.values), dtype=object)
            for i, entry in enumerate(dist.values):
                output[i] = softmax(entry, return_numpy=True)
            return output if return_numpy else utils.to_categorical(output)
        dist = np.copy(dist.values)

    # Accept array-likes (e.g. lists); subclasses of ndarray are preserved.
    dist = np.asanyarray(dist)
    # Subtract the column-wise max before exponentiating for numerical stability.
    output = np.exp(dist - dist.max(axis=0))
    output = output / np.sum(output, axis=0)
    if return_numpy:
        return output
    return utils.to_categorical(output)
Ejemplo n.º 5
0
    def cross(self, x=None, return_numpy=False, *args):
        """ Multi-dimensional outer product

            If no `x` argument is passed, the function returns the "auto-outer product"
            of self (via `maths.spm_cross(self.values)`); any `args` are ignored in
            that case. Otherwise the outer product of `self.values`, `x` and every
            entry of `args` is taken with `maths.spm_cross`.

            NOTE: because `*args` follows `return_numpy` in the signature, a third
            positional argument is bound to `return_numpy`. To pass extra operands,
            give `return_numpy` positionally: `q.cross(x, False, a, b)`.

        Parameters
        ----------
        - `x` [np.ndarray || Categorical] (optional)
            The values to perform the outer-product with
        - `return_numpy` [bool] (optional)
            If True return the raw array, otherwise wrap it in a `Categorical`
        - `args` [np.ndarray || Categorical] (optional)
            Further operands of the outer product; each may independently be an
            array or a distribution (distributions are unwrapped via `.values`)

        Returns
        -------
        - `y` [np.ndarray || Categorical]
            The result of the outer-product
        """
        if x is None:
            y = maths.spm_cross(self.values)
        else:
            x = utils.to_numpy(x)
            # Unwrap each extra operand on its own so mixed array/distribution
            # arguments work (previously only args[0] decided for all of them).
            extra = [arg.values if utils.is_distribution(arg) else arg for arg in args]
            y = maths.spm_cross(self.values, x, *extra)

        if return_numpy:
            return y
        return Categorical(values=y)
Ejemplo n.º 6
0
def _add_transition_counts(pB_factor, B_factor, action, qs_factor, qs_prev_factor, lr):
    """
    Add one factor's learning-rate-scaled transition counts to `pB_factor` in place.

    The increment for slice `[:, :, action]` is `qs_factor.cross(qs_prev_factor)`,
    masked to the entries where `B_factor[:, :, action]` is non-zero, times `lr`.
    """
    dfdb = qs_factor.cross(qs_prev_factor, return_numpy=True)
    # Only reinforce transitions the current model allows (non-zero in B).
    dfdb = dfdb * (B_factor[:, :, action] > 0).astype("float")
    pB_factor[:, :, action] = pB_factor[:, :, action] + (lr * dfdb)


def update_transition_dirichlet(pB,
                                B,
                                actions,
                                qs,
                                qs_prev,
                                lr=1.0,
                                factors="all",
                                return_numpy=True):
    """
    Update Dirichlet parameters that parameterize the transition model of the generative model
    (describing the probabilistic mapping between hidden states over time).

    Parameters
    -----------
    - pB [numpy nd.array, array-of-arrays (with np.ndarray entries), or Dirichlet
      (either single-factor or AoA)]:
        The prior Dirichlet parameters of the generative model, parameterizing the agent's
        beliefs about the transition likelihood.
    - B [numpy nd.array, object-like array of arrays, or Categorical (either single-factor or AoA)]:
        The transition likelihood of the generative model.
    - actions [numpy 1D array]:
        A 1D numpy array of shape (num_control_factors,) containing the action(s) performed at
        a given timestep. Entries are cast to `int` before being used as indices.
    - qs [numpy 1D array, array-of-arrays (where each entry is a numpy 1D array), or Categorical
      (either single-factor or AoA)]:
        Current marginal posterior beliefs about hidden state factors
    - qs_prev [numpy 1D array, array-of-arrays (where each entry is a numpy 1D array), or
      Categorical (either single-factor or AoA)]:
        Past marginal posterior beliefs about hidden state factors
    - lr [float, optional]:
        Learning rate.
    - factors [list, optional]:
        Indices (in terms of range(Nf)) of the hidden state factors to include in learning.
        Defaults to 'all', meaning that transition likelihood matrices for all hidden state factors
        are updated as a function of transitions in the different control factors (i.e. actions).
        For a single-factor model the update is applied iff 0 is among the indices.
    - return_numpy [bool, optional]:
        Logical flag to determine whether output is a numpy array or a Dirichlet

    Returns
    -------
    A deep copy of `pB` (wrapped with `utils.to_dirichlet` when `return_numpy` is False)
    with the selected factors' `[:, :, action]` slices incremented.
    """
    pB = utils.to_numpy(pB)
    B = utils.to_numpy(B)

    n_factors = len(pB) if utils.is_arr_of_arr(pB) else 1

    # Work on a copy so the caller's parameters are never mutated.
    if return_numpy:
        pB_updated = copy.deepcopy(pB)
    else:
        pB_updated = utils.to_dirichlet(copy.deepcopy(pB))

    if not utils.is_distribution(qs):
        qs = utils.to_categorical(qs)

    # isinstance guard: comparing an array of indices to "all" would be elementwise.
    if isinstance(factors, str) and factors == "all":
        factors = range(n_factors)

    if n_factors == 1:
        # Single-factor model: pB, B and qs are not indexed by factor.
        if 0 in factors:
            _add_transition_counts(pB_updated, B, int(actions[0]), qs, qs_prev, lr)
    else:
        for factor in factors:
            _add_transition_counts(
                pB_updated[factor],
                B[factor],
                int(actions[factor]),
                qs[factor],
                qs_prev[factor],
                lr,
            )

    return pB_updated