def input(self, state: int, t: int) -> int:
    """
    Sample a control input for the given state at time t.

    For a 'trv' policy the input distribution conditioned on the state is
    stored directly; otherwise it is composed from the input-given-TRV and
    TRV-given-state channels.

    :param state: Index of the current state.
    :param t: The time step.
    :return: The index of the sampled input.
    :raises RuntimeError: If solve() has not been called yet.
    """
    if not self._solved:
        raise RuntimeError(
            'Need to call DiscretePolicy.solve() before asking for inputs.'
        )
    if self._policy_type == 'trv':
        input_pmf = self._input_given_trv[:, state, t]
    else:
        input_pmf = self._input_given_trv[:, :, t] @ self._trv_given_state[:, state, t]
    return FiniteDist(input_pmf).sample()
def mutual_info(self, chan_input: dists.FiniteDist, base: Union[float, str] = 'e') -> float:
    """
    Calculates the mutual information between X and Y.

    :param chan_input: A finite distribution over the channel's input
        alphabet X (one entry per column of the channel matrix),
        representing the assumed distribution over X.
    :param base: Logarithm base for the result: 'e' for nats (default)
        or 2 for bits.
    :return: The mutual information I(X; Y).
    :raises TypeError: If base is neither 2 nor 'e'.
    """
    # n = number of outputs (rows), m = number of inputs (columns).
    n, m = self._y_given_x.shape
    joint = self.joint(chan_input).pmf(shape=(n, m))
    marginal = self.marginal(chan_input).pmf()
    # Denominator of the pointwise log term: p(y) * p(x), shape (n, m).
    denom = marginal.reshape((-1, 1)) * chan_input.pmf()
    inside_log = np.zeros(denom.shape)
    denom_nonzeros = denom != 0
    joint_nonzeros = joint != 0
    # Only divide where the denominator is nonzero; entries with zero joint
    # probability are masked out below so they contribute nothing.
    inside_log[denom_nonzeros] = joint[denom_nonzeros] / denom[denom_nonzeros]
    pointwise = np.zeros((n, m))
    if base == 'e':
        pointwise[joint_nonzeros] = joint[joint_nonzeros] * np.log(
            inside_log[joint_nonzeros])
    elif base == 2:
        pointwise[joint_nonzeros] = joint[joint_nonzeros] * np.log2(
            inside_log[joint_nonzeros])
    else:
        raise TypeError("Currently only handles base=2 or base='e'.")
    # I(X; Y) = sum over (y, x) of p(y, x) * log(p(y, x) / (p(y) p(x))).
    return pointwise.sum()
def joint(self, chan_input: dists.FiniteDist) -> dists.FiniteDist:
    """
    Computes the joint distribution for (X, Y).

    With ``self._y_given_x`` of shape (n, m) — n outputs, m inputs — the
    joint is p(y, x) = p(y | x) p(x), flattened row-major.

    :param chan_input: A finite distribution over the m channel inputs.
    :return: A finite distribution over a set of size n * m. The probability
        of the event {Y = i, X = j} can be accessed with
        pmf(val=(i, j), shape=(n, m)), matching the usage in mutual_info.
    """
    weighted = self._y_given_x * chan_input.pmf()
    return dists.FiniteDist(weighted.flatten())
def marginal(self, prior: dists.FiniteDist) -> dists.FiniteDist:
    """
    Computes the distribution of Y resulting from a prior over X.

    :param prior: The assumed prior on X.
    :return: The marginal distribution of Y, i.e. p(y) = sum_x p(y|x) p(x).
    """
    output_pmf = self._y_given_x @ prior.pmf()
    return dists.FiniteDist(output_pmf)
def posterior(self, prior: dists.FiniteDist, output: int) -> dists.FiniteDist:
    """
    Computes the posterior distribution over X given Y = y via Bayes' rule.

    :param prior: A finite distribution representing the assumed prior
        distribution over X.
    :param output: The index of the observed value y.
    :return: A finite distribution representing the posterior over X,
        p(x | y) = p(y | x) p(x) / p(y).
    """
    likelihood = self._y_given_x[output, :]
    evidence = self.marginal(prior).pmf(output)
    return dists.FiniteDist(likelihood * prior.pmf() / evidence)
def process_update(self, belief: dists.FiniteDist, t: int) -> dists.FiniteDist:
    """
    Propagates a belief over the state one step forward through the dynamics.

    For each input i, computes the next-state distribution the dynamics
    would produce, then averages those columns under the policy's input
    distribution at time t.

    :param belief: The current belief (finite distribution) over the n states.
    :param t: The time step, used to look up the policy's input channel.
    :return: The predicted belief over states at the next time step.
    """
    (n, _, m) = self._dynamics.shape
    belief_pmf = belief.pmf()
    # BUG FIX: np.zeros(n, m) passed m as the dtype argument; the shape
    # must be given as a tuple.
    next_belief_given_input = np.zeros((n, m))
    for i in range(m):
        # BUG FIX: the original sliced self._dynamics.shape (a tuple)
        # instead of the dynamics tensor itself.
        next_belief_given_input[:, i] = self._dynamics[:, :, i] @ belief_pmf
    input_dist = self._policy.input_channel(t).marginal(
        dists.FiniteDist(belief_pmf))
    return dists.FiniteDist(next_belief_given_input @ input_dist.pmf())
import numpy as np

from trcontrol.framework.control.discrete_policies import DiscretePolicy, DiscreteTRVPolicy
from trcontrol.framework.prob.dists import FiniteDist
from trcontrol.scenarios.lava import Lava

# Fix the RNG seed so repeated runs produce identical policies.
np.random.seed(0)

initial_state_dist = FiniteDist(np.array([0.3, 0.4, 0, 0.3, 0]))
scenario = Lava(5, 2, initial_state_dist, 5)

# Solve the baseline policy first so RNG consumption order is unchanged.
baseline_policy = DiscretePolicy(scenario)
baseline_policy.solve()

trv_policy = DiscreteTRVPolicy(scenario, 5, 1)
trv_policy.solve(5, 20, True)

# NOTE(review): reads a private attribute for inspection; acceptable in a
# demo script, but a public accessor would be cleaner.
for t in range(5):
    print(trv_policy._trv_given_state[:, :, t].round(decimals=3))
def solve(self, horizon: int, iters: int = 100, verbose: bool = False,
          init_trv_given_state: Union[np.ndarray, None] = None,
          init_input_given_trv: Union[np.ndarray, None] = None):
    """
    Computes a task-relevant-variable (TRV) policy by coordinate ascent.

    Each iteration alternates forward equations (state and TRV marginals),
    backward equations (TRV-given-state update and one small LP per TRV
    value for the input-given-TRV channel), and a value-function recursion.
    The best iterate seen (lowest objective) is kept.

    :param horizon: Number of time steps to plan over.
    :param iters: Number of coordinate-ascent iterations.
    :param verbose: If True, prints the objective each iteration.
    :param init_trv_given_state: Optional (p, n, horizon) initialization of
        the TRV-given-state channel; random column-normalized values are
        used when None.
    :param init_input_given_trv: Optional (m, p, horizon) initialization of
        the input-given-TRV channel; random column-normalized values are
        used when None.
    :return: The best (lowest) objective value found.
    """
    costs = self._problem.costs_tensor
    dynamics = self._problem.dynamics_tensor
    terminal_costs = self._problem.terminal_costs_tensor
    (n, _, m) = dynamics.shape
    p = self._trv_size  # to be consistent with paper notation

    values = np.zeros((n, horizon + 1))
    values[:, -1] = terminal_costs  # reuse the tensor fetched above

    # Point-mass placeholders; overwritten by the forward equations.
    state_dist = [
        FiniteDist(np.concatenate((np.array([1]), np.zeros(n - 1))))
        for t in range(horizon + 1)
    ]
    state_dist[0] = self._problem.init_dist
    trv_dist = [
        FiniteDist(np.concatenate((np.array([1]), np.zeros(p - 1))))
        for t in range(horizon)
    ]

    if init_trv_given_state is None:
        trv_given_state = np.random.rand(p, n, horizon)
        # BUG FIX: was x / (x / x.sum(axis=0)), which evaluates to the
        # broadcast column sums rather than a normalized channel. Divide
        # by the column sums so each column is a probability distribution.
        trv_given_state = trv_given_state / trv_given_state.sum(axis=0)
    else:
        trv_given_state = init_trv_given_state.copy()

    if init_input_given_trv is None:
        input_given_trv = np.random.rand(m, p, horizon)
        # BUG FIX: same normalization error as above.
        input_given_trv = input_given_trv / input_given_trv.sum(axis=0)
    else:
        input_given_trv = init_input_given_trv.copy()

    obj_hist = np.zeros(iters + 1)
    obj_hist[0] = _objective(dynamics, costs, terminal_costs, self._tradeoff,
                             trv_given_state, input_given_trv,
                             self._problem.init_dist)
    obj_val = obj_hist[0]
    # BUG FIX: snapshot with .copy(). The working arrays are mutated in
    # place below, so storing bare references would silently corrupt the
    # saved best policy on later iterations.
    self._trv_given_state = trv_given_state.copy()
    self._input_given_trv = input_given_trv.copy()

    for it in range(iters):  # renamed from `iter` to avoid shadowing builtin
        if verbose:
            print(f'\t[{it}] Objective:\t{obj_hist[it]:.3}')

        # Forward Equations
        for t in range(horizon):
            input_given_state = input_given_trv[:, :, t] @ trv_given_state[:, :, t]
            transitions = _forward_eq(dynamics, input_given_state)
            state_dist[t + 1] = channels.DiscreteChannel(transitions).marginal(
                state_dist[t])
            trv_dist[t] = channels.DiscreteChannel(
                trv_given_state[:, :, t]).marginal(state_dist[t])

        # Backward Equations
        for t in range(horizon - 1, -1, -1):
            # TRV Given State: exponential-family update, then column
            # normalization so each state's TRV distribution sums to one.
            for i in range(n):
                for j in range(p):
                    exponent = -self._tradeoff * (
                        (values[:, t + 1] @ dynamics[:, i, :] @ input_given_trv[:, j, t])
                        + (costs[i, :] @ input_given_trv[:, j, t]))
                    trv_given_state[j, i, t] = trv_dist[t].pmf(j) * np.exp(exponent)
            trv_given_state[:, :, t] = (trv_given_state[:, :, t]
                                        / trv_given_state[:, :, t].sum(axis=0))

            # Input Given TRV: one small LP per TRV value, sharing one
            # parameterized problem so it is constructed only once.
            policy = cvx.Variable(m, nonneg=True)
            c = cvx.Parameter(m)
            obj = cvx.Minimize(c @ policy)
            cstrs = [cvx.sum(policy) == 1]
            prob = cvx.Problem(obj, cstrs)
            for i in range(p):
                coeffs = np.zeros(m)
                for j in range(m):
                    coeffs[j] = trv_given_state[i, :, t] @ (costs[:, j] * state_dist[t].pmf()) \
                        + (values[:, t + 1] @ dynamics[:, :, j]) @ (trv_given_state[i, :, t] * state_dist[t].pmf())
                # Assign the whole vector (not element-wise into c.value) so
                # cvxpy reliably registers the updated parameter.
                c.value = coeffs
                prob.solve()
                input_given_trv[:, i, t] = policy.value

            # Value Function
            # Hoisted out of the state loop: trv_dist[t] does not depend on i.
            trv_dist[t] = channels.DiscreteChannel(
                trv_given_state[:, :, t]).marginal(state_dist[t])
            for i in range(n):
                input_given_state = input_given_trv[:, :, t] @ trv_given_state[:, i, t]
                values[i, t] = costs[i, :] @ input_given_state \
                    + values[:, t + 1] @ (dynamics[:, i, :] @ input_given_state) \
                    + (1 / self._tradeoff) * kl(FiniteDist(trv_given_state[:, i, t]), trv_dist[t])

        obj_hist[it + 1] = _objective(dynamics, costs, terminal_costs,
                                      self._tradeoff, trv_given_state,
                                      input_given_trv, self._problem.init_dist)
        if obj_hist[it + 1] <= obj_val:
            obj_val = obj_hist[it + 1]
            self._trv_given_state = trv_given_state.copy()
            self._input_given_trv = input_given_trv.copy()

    if verbose:
        # BUG FIX: the final report indexed obj_hist with `horizon`; the
        # last entry written is obj_hist[iters].
        print(f'\t[{iters}] Objective:\t{obj_hist[iters]:.3}')

    self._solved = True
    return obj_val