Example #1
    def get_qv_func_fa(self, polf: Optional[PolicyActDictType]) -> QFType:
        ffs = self.qvf_fa.feature_funcs
        features = len(ffs)
        a_mat = np.zeros((features, features))
        b_vec = np.zeros(features)
        control = polf is None
        this_polf = polf if polf is not None else self.get_init_policy_func()

        for episode in range(self.num_episodes):
            if self.exploring_start:
                state, action = self.mdp_rep.init_state_action_gen()
            else:
                state = self.mdp_rep.init_state_gen()
                action = get_rv_gen_func_single(this_polf(state))()

            # print((episodes, max(self.qvf_fa.get_func_eval((state, a)) for a in
            #        self.mdp_rep.state_action_func(state))))
            # print(self.qvf_fa.params)

            steps = 0
            terminate = False

            while not terminate:
                next_state, reward = \
                    self.mdp_rep.state_reward_gen_func(state, action)
                phi_s = np.array([f((state, action)) for f in ffs])
                next_action = get_rv_gen_func_single(this_polf(next_state))()
                if control:
                    next_act = max(
                        [(a, self.qvf_fa.get_func_eval((next_state, a)))
                         for a in self.state_action_func(next_state)],
                        key=itemgetter(1))[0]
                else:
                    next_act = next_action
                phi_sp = np.array([f((next_state, next_act)) for f in ffs])
                a_mat += np.outer(phi_s, phi_s - self.mdp_rep.gamma * phi_sp)
                b_vec += reward * phi_s

                steps += 1
                terminate = steps >= self.max_steps or \
                    self.mdp_rep.terminal_state_func(state)
                state = next_state
                action = next_action

            if control and (episode + 1) % self.batch_size == 0:
                self.qvf_fa.params = [np.linalg.inv(a_mat).dot(b_vec)]
                # print(self.qvf_fa.params)
                this_polf = get_soft_policy_func_from_qf(
                    self.qvf_fa.get_func_eval, self.state_action_func,
                    self.softmax, self.epsilon_func(episode))
                a_mat = np.zeros((features, features))
                b_vec = np.zeros(features)

        if not control:
            self.qvf_fa.params = [np.linalg.inv(a_mat).dot(b_vec)]

        return lambda st: lambda act, st=st: self.qvf_fa.get_func_eval(
            (st, act))
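
Example #1 above accumulates the LSTD-Q system (A ← A + φ(s,a)·(φ(s,a) − γ·φ(s',a'))ᵀ, b ← b + r·φ(s,a)) and, in control mode (polf is None), re-solves A·w = b every batch_size episodes, LSPI-style. The Q-function it returns is curried: state first, then action, with the st=st default capturing the state in the inner lambda. A minimal self-contained illustration of that return shape (the dict-backed stand-in below is hypothetical, not the qvf_fa above):

def make_curried_qf(q_table):
    # q_table: dict[(state, action)] -> float, standing in for qvf_fa.get_func_eval
    return lambda st: lambda act, st=st: q_table[(st, act)]

qf = make_curried_qf({("s0", "a0"): 1.5})
assert qf("s0")("a0") == 1.5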
Example #2
    def get_qv_func_fa(self, polf: Optional[PolicyActDictType]) -> QFType:
        control = polf is None
        this_polf = polf if polf is not None else self.get_init_policy_func()
        episodes = 0

        while episodes < self.num_episodes:
            if self.exploring_start:
                state, action = self.mdp_rep.init_state_action_gen()
            else:
                state = self.mdp_rep.init_state_gen()
                action = get_rv_gen_func_single(this_polf(state))()

            # print((episodes, max(self.qvf_fa.get_func_eval((state, a)) for a in
            #        self.mdp_rep.state_action_func(state))))
            # print(self.qvf_fa.params)

            steps = 0
            terminate = False

            while not terminate:
                next_state, reward = \
                    self.mdp_rep.state_reward_gen_func(state, action)
                next_action = get_rv_gen_func_single(this_polf(next_state))()
                if self.algorithm == TDAlgorithm.QLearning and control:
                    next_qv = max(
                        self.qvf_fa.get_func_eval((next_state, a))
                        for a in self.state_action_func(next_state))
                elif self.algorithm == TDAlgorithm.ExpectedSARSA and control:
                    # next_qv = sum(this_polf(next_state).get(a, 0.) *
                    #               self.qvf_fa.get_func_eval((next_state, a))
                    #               for a in self.state_action_func(next_state))
                    next_qv = get_expected_action_value(
                        {
                            a: self.qvf_fa.get_func_eval((next_state, a))
                            for a in self.state_action_func(next_state)
                        }, self.softmax, self.epsilon_func(episodes))
                else:
                    next_qv = self.qvf_fa.get_func_eval(
                        (next_state, next_action))

                target = reward + self.mdp_rep.gamma * next_qv
                # TD updates online, so the policy improves at every time step
                self.qvf_fa.update_params([(state, action)], [target])
                if control:
                    this_polf = get_soft_policy_func_from_qf(
                        self.qvf_fa.get_func_eval, self.state_action_func,
                        self.softmax, self.epsilon_func(episodes))
                steps += 1
                terminate = steps >= self.max_steps or \
                    self.mdp_rep.terminal_state_func(state)
                state = next_state
                action = next_action

            episodes += 1

        return lambda st: lambda act, st=st: self.qvf_fa.get_func_eval(
            (st, act))
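
Example #2 above is TD(0) with function approximation, where the bootstrapped target depends on the algorithm: SARSA uses Q(s', a') for the sampled next action, Q-learning uses max_a Q(s', a), and Expected SARSA uses a policy-weighted average (here via get_expected_action_value). A standalone sketch of the three targets using plain dicts instead of the function approximator (all names below are illustrative):

def sarsa_target(reward, gamma, q_next, next_action):
    # q_next: dict action -> Q(next_state, action)
    return reward + gamma * q_next[next_action]

def q_learning_target(reward, gamma, q_next):
    return reward + gamma * max(q_next.values())

def expected_sarsa_target(reward, gamma, q_next, probs):
    # probs: dict action -> pi(action | next_state)
    return reward + gamma * sum(probs[a] * q_next[a] for a in q_next)

q_next = {"left": 1.0, "right": 2.0}
probs = {"left": 0.5, "right": 0.5}
assert q_learning_target(0.0, 0.9, q_next) == 0.9 * 2.0
assert expected_sarsa_target(0.0, 0.9, q_next, probs) == 0.9 * 1.5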
Example #3
    def get_mc_path(
        self,
        pol: Policy,
        start_state: S,
        start_action: Optional[A] = None,
    ) -> Sequence[Tuple[S, A, float, bool]]:

        res = []
        next_state = start_state
        steps = 0
        terminate = False
        occ_states = set()
        act_gen_dict = {
            s: get_rv_gen_func_single(pol.get_state_probabilities(s))
            for s in self.mdp_rep.state_action_dict.keys()
        }

        while not terminate:
            state = next_state
            first = state not in occ_states
            occ_states.add(state)
            action = act_gen_dict[state]()\
                if (steps > 0 or start_action is None) else start_action
            next_state, reward =\
                self.mdp_rep.state_reward_gen_dict[state][action]()
            res.append((state, action, reward, first))
            steps += 1
            terminate = steps >= self.max_steps or\
                state in self.mdp_rep.terminal_states
        return res
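
Example #3 above rolls out a single Monte Carlo episode under a tabular Policy and tags each step with a first-visit flag. A hedged sketch of how such a path is typically consumed to get first-visit returns (the helper below is illustrative, not part of the class above):

def first_visit_returns(path, gamma):
    # path: Sequence[Tuple[state, action, reward, first_visit_flag]]
    returns = {}
    g = 0.0
    # accumulate discounted returns from the back of the episode
    for state, action, reward, first in reversed(path):
        g = reward + gamma * g
        if first:
            returns[state] = g  # return from the state's first occurrence onward
    return returns

path = [("s", "a", 1.0, True), ("s", "a", 1.0, False)]
assert first_visit_returns(path, 0.5) == {"s": 1.5}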
Example #4
    def get_value_func_dict(self, pol: Policy) -> VFType:
        sa_dict = self.mdp_rep.state_action_dict
        vf_dict = {s: 0.0 for s in sa_dict.keys()}
        act_gen_dict = {
            s: get_rv_gen_func_single(pol.get_state_probabilities(s))
            for s in sa_dict.keys()
        }
        episodes = 0

        while episodes < self.num_episodes:
            state = self.mdp_rep.init_state_gen()
            steps = 0
            terminate = False

            while not terminate:
                action = act_gen_dict[state]()
                next_state, reward = \
                    self.mdp_rep.state_reward_gen_dict[state][action]()
                vf_dict[state] += self.alpha * \
                    (reward + self.mdp_rep.gamma * vf_dict[next_state] -
                     vf_dict[state])
                state = next_state
                steps += 1
                terminate = steps >= self.max_steps or \
                    state in self.mdp_rep.terminal_states

            episodes += 1

        return vf_dict
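
Example #4 above is tabular TD(0) policy evaluation with a constant step size; the whole algorithm reduces to one update per transition, isolated here for reference (illustrative helper):

def td0_update(v_s, v_next, reward, gamma, alpha):
    # V(s) <- V(s) + alpha * (r + gamma * V(s') - V(s))
    return v_s + alpha * (reward + gamma * v_next - v_s)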
Example #5
    def get_value_func_dict(self, pol: Policy) -> VFType:
        sa_dict = self.mdp_rep.state_action_dict
        vf_dict = {s: 0. for s in sa_dict.keys()}
        act_gen_dict = {s: get_rv_gen_func_single(pol.get_state_probabilities(s))
                        for s in sa_dict.keys()}
        episodes = 0
        updates = 0

        while episodes < self.num_episodes:
            et_dict = {s: 0. for s in sa_dict.keys()}
            state = self.mdp_rep.init_state_gen()
            steps = 0
            terminate = False

            while not terminate:
                action = act_gen_dict[state]()
                next_state, reward =\
                    self.mdp_rep.state_reward_gen_dict[state][action]()
                delta = reward + self.mdp_rep.gamma * vf_dict[next_state] -\
                    vf_dict[state]
                et_dict[state] += 1
                alpha = self.learning_rate * (updates / self.learning_rate_decay
                                              + 1) ** -0.5
                for s in sa_dict.keys():
                    vf_dict[s] += alpha * delta * et_dict[s]
                    et_dict[s] *= self.gamma_lambda
                updates += 1
                steps += 1
                terminate = steps >= self.max_steps or\
                    state in self.mdp_rep.terminal_states
                state = next_state

            episodes += 1

        return vf_dict
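
Example #5 above extends the TD(0) evaluation of Example #4 to TD(λ) with accumulating eligibility traces and a decaying step size. The per-transition sweep over all states, restated compactly (names illustrative):

def td_lambda_sweep(vf, et, state, delta, alpha, gamma_lambda):
    # vf: dict state -> value; et: dict state -> eligibility trace
    et[state] += 1.0  # accumulating trace for the visited state
    for s in vf:
        vf[s] += alpha * delta * et[s]
        et[s] *= gamma_lambda  # decay every trace by gamma * lambda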
Example #6
 def init_sa(init_state_gen=init_state_gen,
             state_action_func=state_action_func) -> Tuple[S, A]:
     s = init_state_gen()
     actions = state_action_func(s)
     a = get_rv_gen_func_single({a: 1. / len(actions)
                                 for a in actions})()
     return s, a
Example #7
    def get_value_func_fa(self, polf: PolicyActDictType) -> VFType:
        ffs = self.vf_fa.feature_funcs
        features = len(ffs)
        a_mat = np.zeros((features, features))
        b_vec = np.zeros(features)

        for _ in range(self.num_episodes):
            state = self.mdp_rep.init_state_gen()
            steps = 0
            terminate = False

            while not terminate:
                action = get_rv_gen_func_single(polf(state))()
                next_state, reward = \
                    self.mdp_rep.state_reward_gen_func(state, action)
                phi_s = np.array([f(state) for f in ffs])
                phi_sp = np.array([f(next_state) for f in ffs])
                a_mat += np.outer(phi_s, phi_s - self.mdp_rep.gamma * phi_sp)
                b_vec += reward * phi_s
                steps += 1
                terminate = steps >= self.max_steps or \
                    self.mdp_rep.terminal_state_func(state)
                state = next_state

        self.vf_fa.params = [np.linalg.inv(a_mat).dot(b_vec)]

        return self.vf_fa.get_func_eval
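
Example #7 above is LSTD for the state-value function: it accumulates A and b over sampled transitions and sets the weights to A⁻¹·b at the end. The final solve, isolated below; np.linalg.solve avoids forming the explicit inverse and is numerically steadier (a sketch, not a drop-in change to the class above):

import numpy as np

def lstd_weights(a_mat: np.ndarray, b_vec: np.ndarray) -> np.ndarray:
    # Solve A w = b directly instead of computing inv(A).dot(b)
    return np.linalg.solve(a_mat, b_vec)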
Example #8
    def get_mc_path(
        self,
        polf: PolicyActDictType,
        start_state: S,
        start_action: Optional[A] = None
    ) -> Sequence[Tuple[S, A, float, bool]]:

        res = []
        state = start_state
        steps = 0
        terminate = False
        occ_states = set()

        while not terminate:
            first = state not in occ_states
            occ_states.add(state)
            action = get_rv_gen_func_single(polf(state))()\
                if (steps > 0 or start_action is None) else start_action
            next_state, reward =\
                self.mdp_rep.state_reward_gen_func(state, action)
            res.append((state, action, reward, first))
            steps += 1
            terminate = steps >= self.max_steps or\
                self.mdp_rep.terminal_state_func(state)
            state = next_state
        return res
Example #9
    def get_qv_func_dict(self, pol: Optional[Policy]) -> QFDictType:
        control = pol is None
        this_pol = pol if pol is not None else self.get_init_policy()
        sa_dict = self.mdp_rep.state_action_dict
        qf_dict = {s: {a: 0.0 for a in v} for s, v in sa_dict.items()}
        episodes = 0
        updates = 0

        while episodes < self.num_episodes:
            et_dict = {s: {a: 0.0 for a in v} for s, v in sa_dict.items()}
            state, action = self.mdp_rep.init_state_action_gen()
            steps = 0
            terminate = False

            while not terminate:
                next_state, reward = \
                    self.mdp_rep.state_reward_gen_dict[state][action]()
                next_action = get_rv_gen_func_single(
                    this_pol.get_state_probabilities(next_state))()
                if self.algorithm == TDAlgorithm.QLearning and control:
                    next_qv = max(qf_dict[next_state][a]
                                  for a in qf_dict[next_state])
                elif self.algorithm == TDAlgorithm.ExpectedSARSA and control:
                    # next_qv = sum(this_pol.get_state_action_probability(
                    #     next_state,
                    #     a
                    # ) * qf_dict[next_state][a] for a in qf_dict[next_state])
                    next_qv = get_expected_action_value(
                        qf_dict[next_state], self.softmax,
                        self.epsilon_func(episodes))
                else:
                    next_qv = qf_dict[next_state][next_action]

                delta = reward + self.mdp_rep.gamma * next_qv -\
                    qf_dict[state][action]
                et_dict[state][action] += 1
                alpha = self.learning_rate * (
                    updates / self.learning_rate_decay + 1)**-0.5
                for s, a_set in sa_dict.items():
                    for a in a_set:
                        qf_dict[s][a] += alpha * delta * et_dict[s][a]
                        et_dict[s][a] *= self.gamma_lambda
                updates += 1
                if control:
                    if self.softmax:
                        this_pol.edit_state_action_to_softmax(
                            state, qf_dict[state])
                    else:
                        this_pol.edit_state_action_to_epsilon_greedy(
                            state, qf_dict[state], self.epsilon_func(episodes))
                steps += 1
                terminate = steps >= self.max_steps or \
                    state in self.mdp_rep.terminal_states
                state = next_state
                action = next_action

            episodes += 1

        return qf_dict
Example #10
 def __init__(self, mdp_ref_obj: MDPRefined) -> None:
     self.state_action_dict: Mapping[S,
                                     Set[A]] = mdp_ref_obj.state_action_dict
     self.terminal_states: Set[S] = mdp_ref_obj.terminal_states
     self.state_reward_gen_dict: Type1 = get_state_reward_gen_dict(
         mdp_ref_obj.rewards_refined, mdp_ref_obj.transitions)
     super().__init__(
         state_action_func=lambda x: self.state_action_dict[x],
         gamma=mdp_ref_obj.gamma,
         terminal_state_func=lambda x: x in self.terminal_states,
         state_reward_gen_func=lambda x, y: self.state_reward_gen_dict[x][y]
         (),
         init_state_gen=get_rv_gen_func_single({
             s: 1. / len(self.state_action_dict)
             for s in self.state_action_dict.keys()
         }),
         init_state_action_gen=get_rv_gen_func_single({
             (s, a):
             1. / sum(len(v) for v in self.state_action_dict.values())
             for s, v1 in self.state_action_dict.items() for a in v1
         }))
Example #11
 def __init__(self, state_action_dict: Mapping[S, Set[A]],
              terminal_states: Set[S], state_reward_gen_dict: Type1,
              gamma: float) -> None:
     self.state_action_dict: Mapping[S, Set[A]] = state_action_dict
     self.terminal_states: Set[S] = terminal_states
     self.state_reward_gen_dict: Type1 = state_reward_gen_dict
     super().__init__(
         state_action_func=lambda x: self.state_action_dict[x],
         gamma=gamma,
         terminal_state_func=lambda x: x in self.terminal_states,
         state_reward_gen_func=lambda x, y: self.state_reward_gen_dict[x][y]
         (),
         init_state_gen=get_rv_gen_func_single({
             s: 1. / len(self.state_action_dict)
             for s in self.state_action_dict.keys()
         }),
         init_state_action_gen=get_rv_gen_func_single({
             (s, a):
             1. / sum(len(v) for v in self.state_action_dict.values())
             for s, v1 in self.state_action_dict.items() for a in v1
         }))
Example #12
 def get_mdp_rep_for_rl_pg(self) -> MDPRepForRLPG:
     return MDPRepForRLPG(
         gamma=self.gamma,
         init_state_gen_func=get_rv_gen_func_single({
             s: 1. / len(self.state_action_dict)
             for s in self.state_action_dict.keys()
         }),
         state_reward_gen_func=lambda s, a: get_state_reward_gen_func(
             self.transitions[s][a],
             self.rewards_refined[s][a],
         )(),
         terminal_state_func=lambda s: s in self.terminal_states,
     )
Example #13
    def get_value_func_fa(self, polf: PolicyActDictType) -> VFType:
        episodes = 0

        while episodes < self.num_episodes:
            et = [np.zeros_like(p) for p in self.vf_fa.params]
            state = self.mdp_rep.init_state_gen()
            steps = 0
            terminate = False

            states = []
            targets = []
            while not terminate:
                action = get_rv_gen_func_single(polf(state))()
                next_state, reward =\
                    self.mdp_rep.state_reward_gen_func(state, action)
                target = reward + self.mdp_rep.gamma *\
                    self.vf_fa.get_func_eval(next_state)
                delta = target - self.vf_fa.get_func_eval(state)
                if self.offline:
                    states.append(state)
                    targets.append(target)
                else:
                    grad = self.vf_fa.get_sum_objective_gradient(
                        [state], np.ones(1))
                    et = [e * self.gamma_lambda + g
                          for e, g in zip(et, grad)]
                    self.vf_fa.update_params_from_gradient(
                        [-e * delta for e in et]
                    )
                steps += 1
                terminate = steps >= self.max_steps or\
                    self.mdp_rep.terminal_state_func(state)
                state = next_state

            if self.offline:
                avg_grad = [g / len(states) for g in
                            self.vf_fa.get_el_tr_sum_loss_gradient(
                                states,
                                targets,
                                self.gamma_lambda
                            )]
                self.vf_fa.update_params_from_gradient(avg_grad)
            episodes += 1

        return self.vf_fa.get_func_eval
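
Example #13 above is TD(λ) with function approximation: in the online branch the per-parameter trace is updated as e ← γλ·e + ∇_w V(s), and -δ·e is handed to update_params_from_gradient (which presumably applies the learning rate internally); the offline branch instead collects states and targets and applies one averaged gradient per episode via get_el_tr_sum_loss_gradient. A tiny numpy restatement of the online trace recursion for a single weight vector (illustrative):

import numpy as np

def online_trace_step(et, grad_v, delta, gamma_lambda):
    # et, grad_v: 1-D arrays; decay the trace, add the current gradient,
    # and return the trace with the (unscaled) parameter-update direction
    et = gamma_lambda * et + grad_v
    return et, delta * et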
Example #14
    def get_qv_func_dict(self, pol: Optional[Policy]) -> QVFType:
        control = pol is None
        this_pol = pol if pol is not None else self.get_init_policy()
        sa_dict = self.mdp_rep.state_action_dict
        qf_dict = {s: {a: 0.0 for a in v} for s, v in sa_dict.items()}
        episodes = 0

        while episodes < self.num_episodes:
            state, action = self.mdp_rep.init_state_action_gen()
            steps = 0
            terminate = False

            while not terminate:
                next_state, reward = \
                    self.mdp_rep.state_reward_gen_dict[state][action]()
                next_action = get_rv_gen_func_single(
                    this_pol.get_state_probabilities(next_state))()
                if self.algorithm == TDAlgorithm.QLearning and control:
                    next_qv = max(qf_dict[next_state][a]
                                  for a in qf_dict[next_state])
                elif self.algorithm == TDAlgorithm.ExpectedSARSA and control:
                    next_qv = sum(
                        this_pol.get_state_action_probability(next_state, a) *
                        qf_dict[next_state][a] for a in qf_dict[next_state])
                else:
                    next_qv = qf_dict[next_state][next_action]

                qf_dict[state][action] += self.alpha * \
                    (reward + self.mdp_rep.gamma * next_qv -
                     qf_dict[state][action])
                if control:
                    if self.softmax:
                        this_pol.edit_state_action_to_softmax(
                            state, qf_dict[state])
                    else:
                        this_pol.edit_state_action_to_epsilon_greedy(
                            state, qf_dict[state], self.epsilon)
                state = next_state
                action = next_action
                steps += 1
                terminate = steps >= self.max_steps or \
                    state in self.mdp_rep.terminal_states

            episodes += 1

        return qf_dict
Example #15
    def get_value_func_fa(self, polf: PolicyActDictType) -> VFType:
        episodes = 0
        updates = 0

        while episodes < self.num_episodes:
            et = np.zeros(self.vf_fa.num_features)
            state = self.mdp_rep.init_state_gen()
            features = self.vf_fa.get_feature_vals(state)
            old_vf_fa = 0.
            steps = 0
            terminate = False

            while not terminate:
                action = get_rv_gen_func_single(polf(state))()
                next_state, reward =\
                    self.mdp_rep.state_reward_gen_func(state, action)
                next_features = self.vf_fa.get_feature_vals(next_state)
                vf_fa = features.dot(self.vf_w)
                next_vf_fa = next_features.dot(self.vf_w)
                target = reward + self.mdp_rep.gamma * next_vf_fa
                delta = target - vf_fa
                alpha = self.vf_fa.learning_rate *\
                    (updates / self.learning_rate_decay + 1) ** -0.5
                et = et * self.gamma_lambda + features *\
                    (1 - alpha * self.gamma_lambda * et.dot(features))
                self.vf_w += alpha * (et *
                                      (delta + vf_fa - old_vf_fa) - features *
                                      (vf_fa - old_vf_fa))
                updates += 1
                steps += 1
                terminate = steps >= self.max_steps or\
                    self.mdp_rep.terminal_state_func(state)
                old_vf_fa = next_vf_fa
                state = next_state
                features = next_features

            episodes += 1

        return lambda x: self.vf_fa.get_feature_vals(x).dot(self.vf_w)
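
Example #15 above is true online TD(λ) with a dutch-style eligibility trace over the linear weights vf_w. The per-step arithmetic, isolated as a pure function over numpy arrays (names illustrative; gamma_lambda is the product γ·λ, as in the code above):

import numpy as np

def true_online_td_step(w, e, phi, phi_next, reward,
                        gamma, gamma_lambda, alpha, old_v):
    # One step of true online TD(lambda), mirroring the inner loop above
    v = phi.dot(w)
    v_next = phi_next.dot(w)
    delta = reward + gamma * v_next - v
    e = gamma_lambda * e + phi * (1.0 - alpha * gamma_lambda * e.dot(phi))
    w = w + alpha * (e * (delta + v - old_v) - phi * (v - old_v))
    return w, e, v_next  # v_next becomes old_v on the next step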
Example #16
    def get_value_func_fa(self, polf: PolicyType) -> Type1:
        episodes = 0

        while episodes < self.num_episodes:
            state = self.mdp_rep.init_state_gen()
            steps = 0
            terminate = False

            while not terminate:
                action = get_rv_gen_func_single(polf(state))()
                next_state, reward = \
                    self.mdp_rep.state_reward_gen_func(state, action)
                target = reward + self.mdp_rep.gamma *\
                    self.vf_fa.get_func_eval(next_state)
                self.vf_fa.update_params([state], [target])
                steps += 1
                terminate = steps >= self.max_steps or \
                    self.mdp_rep.terminal_state_func(state)
                state = next_state

            episodes += 1

        return self.vf_fa.get_func_eval
Example #17
    def get_qv_func_fa(self, polf: Optional[PolicyActDictType]) -> QFType:
        control = polf is None
        this_polf = polf if polf is not None else self.get_init_policy_func()
        episodes = 0

        while episodes < self.num_episodes:
            et = [np.zeros_like(p) for p in self.qvf_fa.params]
            if self.exploring_start:
                state, action = self.mdp_rep.init_state_action_gen()
            else:
                state = self.mdp_rep.init_state_gen()
                action = get_rv_gen_func_single(this_polf(state))()

            # print((episodes, max(self.qvf_fa.get_func_eval((state, a)) for a in
            #        self.mdp_rep.state_action_func(state))))
            # print(self.qvf_fa.params)

            steps = 0
            terminate = False

            states_actions = []
            targets = []
            while not terminate:
                next_state, reward = \
                    self.mdp_rep.state_reward_gen_func(state, action)
                next_action = get_rv_gen_func_single(this_polf(next_state))()
                if self.algorithm == TDAlgorithm.QLearning and control:
                    next_qv = max(
                        self.qvf_fa.get_func_eval((next_state, a))
                        for a in self.state_action_func(next_state))
                elif self.algorithm == TDAlgorithm.ExpectedSARSA and control:
                    # next_qv = sum(this_polf(next_state).get(a, 0.) *
                    #               self.qvf_fa.get_func_eval((next_state, a))
                    #               for a in self.state_action_func(next_state))
                    next_qv = get_expected_action_value(
                        {a: self.qvf_fa.get_func_eval((next_state, a)) for a in
                         self.state_action_func(next_state)},
                        self.softmax,
                        self.epsilon_func(episodes)
                    )
                else:
                    next_qv = self.qvf_fa.get_func_eval((next_state, next_action))

                target = reward + self.mdp_rep.gamma * next_qv
                delta = target - self.qvf_fa.get_func_eval((state, action))

                if self.offline:
                    states_actions.append((state, action))
                    targets.append(target)
                else:
                    grad = self.qvf_fa.get_sum_objective_gradient(
                        [(state, action)], np.ones(1))
                    et = [e * self.gamma_lambda + g
                          for e, g in zip(et, grad)]
                    self.qvf_fa.update_params_from_gradient(
                        [-e * delta for e in et]
                    )
                if control and self.batch_size == 0:
                    this_polf = get_soft_policy_func_from_qf(
                        self.qvf_fa.get_func_eval,
                        self.state_action_func,
                        self.softmax,
                        self.epsilon_func(episodes)
                    )
                steps += 1
                terminate = steps >= self.max_steps or \
                    self.mdp_rep.terminal_state_func(state)

                state = next_state
                action = next_action

            if self.offline:
                avg_grad = [g / len(states_actions) for g in
                            self.qvf_fa.get_el_tr_sum_loss_gradient(
                                states_actions,
                                targets,
                                self.gamma_lambda
                            )]
                self.qvf_fa.update_params_from_gradient(avg_grad)

            episodes += 1

            if control and self.batch_size != 0 and\
                    episodes % self.batch_size == 0:
                this_polf = get_soft_policy_func_from_qf(
                    self.qvf_fa.get_func_eval,
                    self.state_action_func,
                    self.softmax,
                    self.epsilon_func(episodes - 1)
                )

        return lambda st: lambda act, st=st: self.qvf_fa.get_func_eval(
            (st, act))
Example #18
    def get_qv_func_fa(self, polf: Optional[PolicyActDictType]) -> QFType:
        control = polf is None
        this_polf = polf if polf is not None else self.get_init_policy_func()
        episodes = 0
        updates = 0

        while episodes < self.num_episodes:
            et = np.zeros(self.qvf_fa.num_features)
            if self.exploring_start:
                state, action = self.mdp_rep.init_state_action_gen()
            else:
                state = self.mdp_rep.init_state_gen()
                action = get_rv_gen_func_single(this_polf(state))()
            features = self.qvf_fa.get_feature_vals((state, action))

            # print((episodes, max(self.qvf_fa.get_feature_vals((state, a)).dot(self.qvf_w)
            #                      for a in self.mdp_rep.state_action_func(state))))
            # print(self.qvf_w)

            old_qvf_fa = 0.
            steps = 0
            terminate = False

            while not terminate:
                next_state, reward = \
                    self.mdp_rep.state_reward_gen_func(state, action)
                next_action = get_rv_gen_func_single(this_polf(next_state))()
                next_features = self.qvf_fa.get_feature_vals(
                    (next_state, next_action))
                qvf_fa = features.dot(self.qvf_w)
                if self.algorithm == TDAlgorithm.QLearning and control:
                    next_qvf_fa = max(
                        self.qvf_fa.get_feature_vals((next_state,
                                                      a)).dot(self.qvf_w)
                        for a in self.state_action_func(next_state))
                elif self.algorithm == TDAlgorithm.ExpectedSARSA and control:
                    # next_qvf_fa = sum(this_polf(next_state).get(a, 0.) *
                    #               self.qvf_fa.get_feature_vals((next_state, a)).dot(self.qvf_w)
                    #               for a in self.state_action_func(next_state))
                    next_qvf_fa = get_expected_action_value(
                        {
                            a: self.qvf_fa.get_feature_vals(
                                (next_state, a)).dot(self.qvf_w)
                            for a in self.state_action_func(next_state)
                        }, self.softmax, self.epsilon_func(episodes))
                else:
                    next_qvf_fa = next_features.dot(self.qvf_w)

                target = reward + self.mdp_rep.gamma * next_qvf_fa
                delta = target - qvf_fa
                alpha = self.qvf_fa.learning_rate * \
                    (updates / self.learning_rate_decay + 1) ** -0.5
                et = et * self.gamma_lambda + features * \
                    (1 - alpha * self.gamma_lambda * et.dot(features))
                self.qvf_w += alpha * (et * (delta + qvf_fa - old_qvf_fa) -
                                       features * (qvf_fa - old_qvf_fa))

                if control and self.batch_size == 0:
                    this_polf = get_soft_policy_func_from_qf(
                        lambda sa: self.qvf_fa.get_feature_vals(sa).dot(
                            self.qvf_w), self.state_action_func, self.softmax,
                        self.epsilon_func(episodes))
                updates += 1
                steps += 1
                terminate = steps >= self.max_steps or \
                    self.mdp_rep.terminal_state_func(state)
                old_qvf_fa = next_qvf_fa
                state = next_state
                action = next_action
                features = next_features

            episodes += 1

            if control and self.batch_size != 0 and\
                    episodes % self.batch_size == 0:
                this_polf = get_soft_policy_func_from_qf(
                    # evaluate with the weights learned here (self.qvf_w),
                    # consistent with the within-episode policy update above
                    lambda sa: self.qvf_fa.get_feature_vals(sa).dot(
                        self.qvf_w), self.state_action_func, self.softmax,
                    self.epsilon_func(episodes - 1))

        return lambda st: lambda act, st=st: self.qvf_fa.get_feature_vals(
            (st, act)).dot(self.qvf_w)