def get_qv_func_dict(self, pol: Optional[Policy]) -> QVFType:
    # If no policy is given, we do MC control (the policy improves after
    # each episode); otherwise we evaluate the given policy.
    control = pol is None
    this_pol = pol if pol is not None else self.get_init_policy()
    sa_dict = self.mdp_rep.state_action_dict
    counts_dict = {s: {a: 0 for a in v} for s, v in sa_dict.items()}
    qf_dict = {s: {a: 0.0 for a in v} for s, v in sa_dict.items()}
    episodes = 0

    while episodes < self.num_episodes:
        start_state, start_action = self.mdp_rep.init_state_action_gen()
        mc_path = self.get_mc_path(
            this_pol,
            start_state,
            start_action
        )

        # each step of mc_path is a (state, action, reward, first_visit) tuple
        rew_arr = np.array([x for _, _, x, _ in mc_path])
        if mc_path[-1][0] in self.mdp_rep.terminal_states:
            returns = get_returns_from_rewards_terminating(
                rew_arr,
                self.mdp_rep.gamma
            )
        else:
            returns = get_returns_from_rewards_non_terminating(
                rew_arr,
                self.mdp_rep.gamma,
                self.nt_return_eval_steps
            )

        for i, r in enumerate(returns):
            s, a, _, f = mc_path[i]
            if not self.first_visit or f:
                counts_dict[s][a] += 1
                c = counts_dict[s][a]
                # incremental mean of all returns observed for (s, a)
                qf_dict[s][a] = (qf_dict[s][a] * (c - 1) + r) / c

        if control:
            this_pol = get_soft_policy_from_qf_dict(
                qf_dict,
                self.softmax,
                self.epsilon
            )
        episodes += 1

    return qf_dict
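# The update `qf_dict[s][a] = (qf_dict[s][a] * (c - 1) + r) / c` above is an
# incremental mean. A standalone sanity check (illustrative only, not part of
# this class) showing that the recurrence reproduces the batch average of the
# sampled returns:
def _incremental_mean_demo() -> None:
    returns_sample = [4.0, 2.0, 7.0, 5.0]
    q, c = 0.0, 0
    for r in returns_sample:
        c += 1
        q = (q * (c - 1) + r) / c  # same recurrence as the Q-value update
    assert abs(q - sum(returns_sample) / len(returns_sample)) < 1e-12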
def get_qv_func_fa(self, polf: Optional[PolicyActDictType]) -> QFType:
    control = polf is None
    this_polf = polf if polf is not None else self.get_init_policy_func()
    episodes = 0

    while episodes < self.num_episodes:
        start_state, start_action = self.mdp_rep.init_state_action_gen()
        mc_path = self.get_mc_path(
            this_polf,
            start_state,
            start_action
        )

        rew_arr = np.array([x for _, _, x, _ in mc_path])
        if self.mdp_rep.terminal_state_func(mc_path[-1][0]):
            returns = get_returns_from_rewards_terminating(
                rew_arr,
                self.mdp_rep.gamma
            )
        else:
            returns = get_returns_from_rewards_non_terminating(
                rew_arr,
                self.mdp_rep.gamma,
                self.nt_return_eval_steps
            )

        # ((state, action), return) pairs, filtered to first visits
        # when self.first_visit is set
        sgd_pts = [((mc_path[i][0], mc_path[i][1]), r)
                   for i, r in enumerate(returns)
                   if not self.first_visit or mc_path[i][3]]
        # MC is an offline update, so the policy improves only after
        # each completed episode
        self.qvf_fa.update_params(*zip(*sgd_pts))
        if control:
            this_polf = get_soft_policy_func_from_qf(
                self.qvf_fa.get_func_eval,
                self.state_action_func,
                self.softmax,
                self.epsilon_func(episodes)
            )
        episodes += 1

    return lambda st: lambda act, st=st: self.qvf_fa.get_func_eval(
        (st, act))
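# `get_soft_policy_from_qf_dict` and `get_soft_policy_func_from_qf` are
# imported helpers whose bodies are not shown in this section. A minimal
# sketch of the epsilon-greedy case (softmax omitted), assuming qf_dict maps
# each state to an {action: Q-value} dict; this is an illustration of the
# idea, not the helpers' actual implementation:
def _epsilon_greedy_policy_sketch(qf_dict, epsilon: float):
    policy = {}
    for s, a_q in qf_dict.items():
        best_a = max(a_q, key=a_q.get)
        n = len(a_q)
        # every action gets epsilon / n probability; the greedy action
        # receives the remaining 1 - epsilon on top of that
        policy[s] = {a: epsilon / n + (1.0 - epsilon if a == best_a else 0.0)
                     for a in a_q}
    return policy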
def get_value_func_dict(self, pol: Policy) -> VFType:
    sa_dict = self.mdp_rep.state_action_dict
    counts_dict = {s: 0 for s in sa_dict.keys()}
    vf_dict = {s: 0.0 for s in sa_dict.keys()}
    episodes = 0

    while episodes < self.num_episodes:
        start_state = self.mdp_rep.init_state_gen()
        mc_path = self.get_mc_path(
            pol,
            start_state,
            start_action=None
        )

        rew_arr = np.array([x for _, _, x, _ in mc_path])
        if mc_path[-1][0] in self.mdp_rep.terminal_states:
            returns = get_returns_from_rewards_terminating(
                rew_arr,
                self.mdp_rep.gamma
            )
        else:
            returns = get_returns_from_rewards_non_terminating(
                rew_arr,
                self.mdp_rep.gamma,
                self.nt_return_eval_steps
            )

        for i, r in enumerate(returns):
            s, _, _, f = mc_path[i]
            if not self.first_visit or f:
                counts_dict[s] += 1
                c = counts_dict[s]
                # incremental mean of all returns observed for state s
                vf_dict[s] = (vf_dict[s] * (c - 1) + r) / c
        episodes += 1

    return vf_dict
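# `get_returns_from_rewards_terminating` is imported from elsewhere in the
# repo. One common way such a helper can be written, shown here as a hedged
# sketch (the repo's actual version may be vectorized differently): walk the
# reward array backwards, accumulating G_t = r_t + gamma * G_{t+1}.
import numpy as np

def _returns_from_rewards_sketch(rew_arr: np.ndarray, gamma: float) -> np.ndarray:
    returns = np.empty(len(rew_arr), dtype=float)
    g = 0.0
    for t in reversed(range(len(rew_arr))):
        g = rew_arr[t] + gamma * g  # discounted return from step t onward
        returns[t] = g
    return returns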
def get_value_func_fa(self, polf: PolicyActDictType) -> VFType:
    episodes = 0

    while episodes < self.num_episodes:
        start_state = self.mdp_rep.init_state_gen()
        mc_path = self.get_mc_path(
            polf,
            start_state,
            start_action=None
        )

        # mc_path steps are (state, action, reward, first_visit) tuples,
        # so the reward is the third element
        rew_arr = np.array([x for _, _, x, _ in mc_path])
        if self.mdp_rep.terminal_state_func(mc_path[-1][0]):
            returns = get_returns_from_rewards_terminating(
                rew_arr,
                self.mdp_rep.gamma
            )
        else:
            returns = get_returns_from_rewards_non_terminating(
                rew_arr,
                self.mdp_rep.gamma,
                self.nt_return_eval_steps
            )

        # (state, return) pairs, honoring first_visit as the other
        # methods do
        sgd_pts = [(mc_path[i][0], r) for i, r in enumerate(returns)
                   if not self.first_visit or mc_path[i][3]]
        self.vf_fa.update_params(*zip(*sgd_pts))
        episodes += 1

    return self.vf_fa.get_func_eval
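# The methods above only assume that self.vf_fa / self.qvf_fa expose
# update_params(xs, ys) and get_func_eval(x). A minimal stand-in satisfying
# that interface (a hypothetical linear model over user-supplied features,
# not the repo's actual function-approximation class), for illustration:
import numpy as np

class _LinearVFSketch:
    def __init__(self, feature_func, num_features: int, lr: float = 0.1) -> None:
        self.feature_func = feature_func  # maps a state to a feature vector
        self.w = np.zeros(num_features)
        self.lr = lr

    def get_func_eval(self, x) -> float:
        return float(np.dot(self.w, self.feature_func(x)))

    def update_params(self, xs, ys) -> None:
        # one SGD pass over an episode's (state, return) pairs
        for x, y in zip(xs, ys):
            phi = self.feature_func(x)
            self.w += self.lr * (y - self.get_func_eval(x)) * phi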