import random


class ValueIterationSolver(object):
    '''
    Vanilla value iteration for a tabular environment
    '''
    def __init__(self, task, vfunc=None, tol=1e-3):
        self.task = task
        self.num_states = task.get_num_states()
        self.gamma = task.gamma
        self.tol = tol
        if vfunc:
            self.vfunc = vfunc
        else:
            self.vfunc = TabularVfunc(self.num_states)

    def get_action(self, state):
        '''Returns the greedy action with respect to the current value function'''
        poss_actions = self.task.get_allowed_actions(state)

        # compute a^* = \argmax_{a} Q(s, a)
        # where Q(s, a) = \sum_{s'} T(s, a, s')[R(s, a, s') + \gamma V(s')]
        best_action = None
        best_val = -float('inf')
        for action in poss_actions:
            ns_dist = self.task.next_state_distribution(state, action)

            val = 0.
            for ns, prob in ns_dist:
                val += prob * (self.task.get_reward(state, action, ns) +
                               self.gamma * self.vfunc(ns))

            if val > best_val:
                best_action = action
                best_val = val
            elif val == best_val and random.random() < 0.5:
                # break ties randomly
                best_action = action
                best_val = val

        return best_action

    def learn(self):
        '''
        Performs value iteration on the MDP until convergence
        '''
        while True:
            # repeatedly perform the Bellman backup on each state
            # V_{i+1}(s) = \max_{a} \sum_{s' \in NS} T(s, a, s')[R(s, a, s') + \gamma V_i(s')]
            max_diff = 0.

            # TODO: Add prioritized sweeping
            for state in self.task.env.get_valid_states():
                poss_actions = self.task.get_allowed_actions(state)

                best_val = 0.
                for idx, action in enumerate(poss_actions):
                    val = 0.
                    ns_dist = self.task.next_state_distribution(state, action)
                    for ns, prob in ns_dist:
                        val += prob * (self.task.get_reward(state, action, ns) +
                                       self.gamma * self.vfunc(ns))

                    if idx == 0 or val > best_val:
                        best_val = val

                diff = abs(self.vfunc(state) - best_val)
                self.vfunc.update(state, best_val)

                if diff > max_diff:
                    max_diff = diff

            if max_diff < self.tol:
                break
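

# --- Hedged usage sketch ------------------------------------------------------
# A minimal sketch of how the solver might be driven, assuming a hypothetical
# two-state chain task and a dict-backed value function standing in for
# TabularVfunc. The task interface (get_num_states, gamma, get_allowed_actions,
# next_state_distribution, get_reward, env.get_valid_states) mirrors what the
# solver above calls; every name defined below is illustrative, not part of the
# original module.
if __name__ == '__main__':
    class _DictVfunc(object):
        '''Stand-in tabular value function: V(s) lookup plus in-place update.'''
        def __init__(self, num_states):
            self.values = {s: 0. for s in range(num_states)}

        def __call__(self, state):
            return self.values[state]

        def update(self, state, val):
            self.values[state] = val

    class _ChainEnv(object):
        def get_valid_states(self):
            return [0, 1]

    class _ChainTask(object):
        '''Two-state chain: taking action a moves to state a; state 1 pays reward 1.'''
        gamma = 0.9

        def __init__(self):
            self.env = _ChainEnv()

        def get_num_states(self):
            return 2

        def get_allowed_actions(self, state):
            return [0, 1]

        def next_state_distribution(self, state, action):
            # deterministic transition: the action index is the next state
            return [(action, 1.0)]

        def get_reward(self, state, action, ns):
            return 1.0 if ns == 1 else 0.0

    task = _ChainTask()
    solver = ValueIterationSolver(task, vfunc=_DictVfunc(task.get_num_states()))
    solver.learn()
    # both states should prefer action 1, which leads to the rewarding state
    print([solver.get_action(s) for s in task.env.get_valid_states()])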