def mutual_info(self,
                    chan_input: dists.FiniteDist,
                    base: 'float | str' = 'e') -> float:
        """
        Calculates the mutual information between X and Y.

        I(X; Y) = sum over (x, y) with p(x, y) > 0 of
        p(x, y) * log(p(x, y) / (p(x) * p(y))); pairs with p(x, y) = 0
        contribute nothing (the 0 * log 0 = 0 convention).

        :param chan_input: A finite distribution over n elements representing the assumed distribution over X.
        :param base: Logarithm base. 'e' (nats, the default), 2 (bits), or any
            other positive real number other than 1.
        :return: The mutual information, in units determined by ``base``.
        :raises TypeError: If ``base`` is not 'e' or a positive real number
            other than 1.
        """
        import numbers

        # Resolve the logarithm first so an invalid base fails before any work.
        if base == 'e':
            log = np.log
        elif base == 2:
            log = np.log2
        elif isinstance(base, numbers.Real) and base > 0 and base != 1:
            log_of_base = np.log(base)

            def log(values):
                # Change of base: log_b(v) = ln(v) / ln(b).
                return np.log(values) / log_of_base
        else:
            raise TypeError(
                "base must be 'e' or a positive real number other than 1.")

        # Rows of the channel matrix index Y, columns index X.
        num_y, num_x = self._y_given_x.shape
        joint = self.joint(chan_input).pmf(shape=(num_y, num_x))
        marginal_y = self.marginal(chan_input).pmf()

        # Product of the marginals p(y) * p(x), laid out like ``joint``.
        independent = marginal_y.reshape((-1, 1)) * chan_input.pmf()

        # p(x, y) / (p(x) * p(y)), left at 0 wherever the denominator is 0.
        ratio = np.zeros(independent.shape)
        np.divide(joint, independent, out=ratio, where=independent != 0)

        # Only the support of p(x, y) contributes to the sum.
        support = joint != 0
        return (joint[support] * log(ratio[support])).sum()
    def joint(self, chan_input: dists.FiniteDist) -> dists.FiniteDist:
        """
        Builds the joint distribution of (X, Y) induced by an input distribution.

        Each channel entry p(Y = i | X = j) is weighted by p(X = j), and the
        resulting matrix is flattened in row-major order.

        :param chan_input: A finite distribution over a set of size n.
        :return: A finite distribution over a set of size n * m. The probability of the event {Y = i, X = j} can be
        accessed using a pmf method call with val=(i, j), shape=(m, n).
        """
        input_pmf = chan_input.pmf()
        # Broadcasting scales column j of the channel matrix by p(X = j).
        weighted = self._y_given_x * input_pmf
        return dists.FiniteDist(weighted.reshape(-1))
    def marginal(self, prior: dists.FiniteDist) -> dists.FiniteDist:
        """
        Pushes a prior over X through the channel to obtain the law of Y.

        Computes p(Y = i) = sum_j p(Y = i | X = j) * p(X = j) as a
        matrix-vector product.

        :param prior: The assumed prior on X.
        :return: The marginal distribution of Y.
        """
        prior_pmf = prior.pmf()
        y_pmf = self._y_given_x.dot(prior_pmf)
        return dists.FiniteDist(y_pmf)
    def posterior(self, prior: dists.FiniteDist,
                  output: int) -> dists.FiniteDist:
        """
        Computes the posterior distribution over X given Y = y.

        Applies Bayes' rule, p(x | y) = p(y | x) * p(x) / p(y), where the
        evidence p(y) = sum_x p(y | x) * p(x) is the total mass of the
        unnormalized posterior.

        :param prior: A finite distribution over n elements representing assumed prior distribution over X.
        :param output: The index of the observed value y.
        :return: A finite distribution over n elements representing the posterior distribution over X.
        :raises ValueError: If Y = y has zero probability under ``prior``, in
            which case the posterior is undefined.
        """
        # Unnormalized posterior p(y | x) * p(x) for every x.
        unnormalized = self._y_given_x[output, :] * prior.pmf()

        # The evidence p(y): mathematically the same quantity as
        # marginal(prior).pmf(output), without building the whole marginal.
        evidence = unnormalized.sum()
        if evidence == 0:
            # Dividing here would silently produce NaNs (0 / 0).
            raise ValueError(
                f"Output {output} has zero probability under the given "
                "prior; the posterior is undefined.")

        return dists.FiniteDist(unnormalized / evidence)
    def process_update(self, belief: dists.FiniteDist,
                       t: int) -> dists.FiniteDist:
        """
        Propagates a belief one step through the dynamics (prediction step).

        For every input index k, the belief is mapped through the matrix
        ``self._dynamics[:, :, k]`` (i.e. ``self._dynamics[:, :, k] @ belief``).
        Those per-input predictions are then mixed using the input
        distribution obtained by applying the policy's input channel at time
        ``t`` to ``belief``.

        :param belief: Current belief; its length must match axis 1 of
            ``self._dynamics``.
        :param t: Time index used to select the policy's input channel.
        :return: The predicted belief, over axis 0 of ``self._dynamics``.
        """
        belief_pmf = belief.pmf()

        # Column k is self._dynamics[:, :, k] @ belief_pmf: the predicted
        # belief if input k were applied. Result shape: (axis 0, axis 2).
        next_belief_given_input = np.einsum('ijk,j->ik', self._dynamics,
                                            belief_pmf)

        # NOTE(review): mixing with the input channel's output *marginal*
        # treats the applied input as independent of the current state given
        # the belief. If the policy is meant to act on the true state, the
        # state/input joint would be needed instead -- confirm intended model.
        input_dist = self._policy.input_channel(t).marginal(
            dists.FiniteDist(belief_pmf))

        return dists.FiniteDist(next_belief_given_input @ input_dist.pmf())