Example #1
    def get_qv_func_dict(self, pol: Optional[Policy]) -> QVFType:
        control = pol is None  # no policy supplied => MC control (policy is improved each episode)
        this_pol = pol if pol is not None else self.get_init_policy()
        sa_dict = self.mdp_rep.state_action_dict
        counts_dict = {s: {a: 0 for a in v} for s, v in sa_dict.items()}
        qf_dict = {s: {a: 0.0 for a in v} for s, v in sa_dict.items()}
        episodes = 0

        while episodes < self.num_episodes:
            start_state, start_action = self.mdp_rep.init_state_action_gen()
            mc_path = self.get_mc_path(this_pol, start_state, start_action)
            rew_arr = np.array([x for _, _, x, _ in mc_path])
            if mc_path[-1][0] in self.mdp_rep.terminal_states:
                returns = get_returns_from_rewards_terminating(
                    rew_arr, self.mdp_rep.gamma)
            else:
                returns = get_returns_from_rewards_non_terminating(
                    rew_arr, self.mdp_rep.gamma, self.nt_return_eval_steps)
            # Incremental-mean update of Q(s, a); restricted to first visits when first_visit is set
            for i, r in enumerate(returns):
                s, a, _, f = mc_path[i]
                if not self.first_visit or f:
                    counts_dict[s][a] += 1
                    c = counts_dict[s][a]
                    qf_dict[s][a] = (qf_dict[s][a] * (c - 1) + r) / c
            if control:
                # Policy improvement: derive an epsilon-soft (or softmax) policy from the updated Q
                this_pol = get_soft_policy_from_qf_dict(
                    qf_dict, self.softmax, self.epsilon)
            episodes += 1

        return qf_dict
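
The return-computation helpers referenced above (get_returns_from_rewards_terminating and its non-terminating counterpart) are imported from elsewhere in the library and are not shown on this page. A minimal sketch of what the terminating variant is assumed to compute, namely the discounted return G_t = r_t + gamma * G_{t+1} for every step of the episode (the body below is illustrative, not the library's implementation):

import numpy as np

def get_returns_from_rewards_terminating(rewards: np.ndarray, gamma: float) -> np.ndarray:
    # Backward pass: G_t = r_t + gamma * G_{t+1}, starting from the last reward
    returns = np.empty(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# e.g. rewards [1., 0., 2.] with gamma = 0.5 yield returns [1.5, 1.0, 2.0]

The non-terminating variant is presumably the same computation over a truncated horizon (nt_return_eval_steps), since the tail of a non-terminating episode has to be cut off somewhere.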
Example #2
    def get_qv_func_fa(self, polf: Optional[PolicyActDictType]) -> QFType:
        control = polf is None
        this_polf = polf if polf is not None else self.get_init_policy_func()
        episodes = 0

        while episodes < self.num_episodes:
            start_state, start_action = self.mdp_rep.init_state_action_gen()
            mc_path = self.get_mc_path(this_polf, start_state, start_action)
            rew_arr = np.array([x for _, _, x, _ in mc_path])
            if self.mdp_rep.terminal_state_func(mc_path[-1][0]):
                returns = get_returns_from_rewards_terminating(
                    rew_arr, self.mdp_rep.gamma)
            else:
                returns = get_returns_from_rewards_non_terminating(
                    rew_arr, self.mdp_rep.gamma, self.nt_return_eval_steps)

            sgd_pts = [((mc_path[i][0], mc_path[i][1]), r)
                       for i, r in enumerate(returns)
                       if not self.first_visit or mc_path[i][3]]
            # MC is an offline update, so the policy improves only after each full episode
            self.qvf_fa.update_params(*zip(*sgd_pts))

            if control:
                this_polf = get_soft_policy_func_from_qf(
                    self.qvf_fa.get_func_eval, self.state_action_func,
                    self.softmax, self.epsilon_func(episodes))
            episodes += 1

        # Curried Q-function: state -> (action -> Q(state, action)); st is bound via a default argument
        return lambda st: lambda act, st=st: self.qvf_fa.get_func_eval(
            (st, act))
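
Both control examples rebuild the behaviour policy from the current Q estimates after each episode (get_soft_policy_from_qf_dict in Example #1, get_soft_policy_func_from_qf here). Assuming softmax=False means epsilon-greedy, a hypothetical sketch of how such a soft policy can be derived from a tabular Q dict (the helper name below is illustrative, not the library's):

from typing import Dict, Hashable

def epsilon_greedy_from_qf_dict(
    qf_dict: Dict[Hashable, Dict[Hashable, float]],
    epsilon: float
) -> Dict[Hashable, Dict[Hashable, float]]:
    # Each state gets probability epsilon spread uniformly over its actions,
    # plus probability (1 - epsilon) on the greedy (argmax-Q) action.
    policy = {}
    for s, q_a in qf_dict.items():
        n = len(q_a)
        greedy_a = max(q_a, key=q_a.get)
        policy[s] = {a: epsilon / n + (1.0 - epsilon if a == greedy_a else 0.0)
                     for a in q_a}
    return policy

# e.g. qf_dict = {"s": {"left": 1.0, "right": 3.0}} and epsilon = 0.2
# give policy {"s": {"left": 0.1, "right": 0.9}}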
Example #3
    def get_value_func_dict(self, pol: Policy) -> VFType:
        sa_dict = self.mdp_rep.state_action_dict
        counts_dict = {s: 0 for s in sa_dict.keys()}
        vf_dict = {s: 0.0 for s in sa_dict.keys()}
        episodes = 0

        while episodes < self.num_episodes:
            start_state = self.mdp_rep.init_state_gen()
            mc_path = self.get_mc_path(pol, start_state, start_action=None)

            rew_arr = np.array([x for _, _, x, _ in mc_path])
            if mc_path[-1][0] in self.mdp_rep.terminal_states:
                returns = get_returns_from_rewards_terminating(
                    rew_arr, self.mdp_rep.gamma)
            else:
                returns = get_returns_from_rewards_non_terminating(
                    rew_arr, self.mdp_rep.gamma, self.nt_return_eval_steps)
            # Incremental-mean update of V(s); restricted to first visits when first_visit is set
            for i, r in enumerate(returns):
                s, _, _, f = mc_path[i]
                if not self.first_visit or f:
                    counts_dict[s] += 1
                    c = counts_dict[s]
                    vf_dict[s] = (vf_dict[s] * (c - 1) + r) / c
            episodes += 1

        return vf_dict
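
The tabular updates in Examples #1 and #3 rely on the incremental-mean identity: after the c-th observed return r, the stored value (old * (c - 1) + r) / c equals the plain average of all c returns seen so far for that state (or state-action pair). A quick numeric check:

returns_seen = [2.0, 4.0, 9.0]
v, c = 0.0, 0
for r in returns_seen:
    c += 1
    v = (v * (c - 1) + r) / c          # same update as vf_dict[s] above
assert abs(v - sum(returns_seen) / len(returns_seen)) < 1e-12   # v == 5.0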
Example #4
    def get_value_func_fa(self, polf: PolicyActDictType) -> VFType:
        episodes = 0

        while episodes < self.num_episodes:
            start_state = self.mdp_rep.init_state_gen()
            mc_path = self.get_mc_path(polf, start_state, start_action=None)

            rew_arr = np.array([x for _, _, x, _ in mc_path])  # mc_path yields (state, action, reward, first_visit) 4-tuples
            if self.mdp_rep.terminal_state_func(mc_path[-1][0]):
                returns = get_returns_from_rewards_terminating(
                    rew_arr, self.mdp_rep.gamma)
            else:
                returns = get_returns_from_rewards_non_terminating(
                    rew_arr, self.mdp_rep.gamma, self.nt_return_eval_steps)

            # Same first-visit filtering as in the Q-value update above
            sgd_pts = [(mc_path[i][0], r) for i, r in enumerate(returns)
                       if not self.first_visit or mc_path[i][3]]
            self.vf_fa.update_params(*zip(*sgd_pts))

            episodes += 1

        return self.vf_fa.get_func_eval
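
In the function-approximation examples the per-episode update is self.vf_fa.update_params(*zip(*sgd_pts)) (and likewise for qvf_fa). The *zip(*...) idiom simply unzips the list of (input, target-return) pairs into a sequence of inputs and a sequence of targets, which is presumably the form the approximator's update_params expects:

sgd_pts = [("s1", 1.5), ("s2", 1.0), ("s3", 2.0)]   # (state, return) pairs from one episode
states, targets = zip(*sgd_pts)
print(states)    # ('s1', 's2', 's3')
print(targets)   # (1.5, 1.0, 2.0)
# so update_params(*zip(*sgd_pts)) is equivalent to update_params(states, targets)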