class TrueOnlineSARSA:
    """True Online SARSA(lambda) with linear function approximation.

    Implements the "true online" update (dutch eligibility traces plus the
    (q - q_old) correction term), which for linear approximators makes the
    online algorithm exactly equivalent to the offline lambda-return
    algorithm (Sutton & Barto, 2nd ed., Sec. 12.8).
    """

    def __init__(self, nParams, nActions, alpha, gamma, lambd,
                 approximationFunctionArgs, actionSelectionMethod="egreedy",
                 epsilon=0.01):
        """
        Args:
            nParams: number of weight / eligibility-trace parameters.
            nActions: size of the discrete action space.
            alpha: step size.
            gamma: discount factor.
            lambd: trace-decay parameter (lambda).
            approximationFunctionArgs: dict providing "af" (value function:
                af(w, state, action, **kwargs) -> float) and "ftf" (feature
                transformation: ftf(state, action, **kwargs) -> feature vector).
            actionSelectionMethod: action-selection scheme for the policy.
            epsilon: exploration rate for epsilon-greedy selection.
        """
        self.name = "True Online SARSA"
        self.nParams = nParams
        self.nActions = nActions
        self.alpha = alpha
        self.gamma = gamma
        self.lambd = lambd
        self.af_kwargs = approximationFunctionArgs
        self.af = getValueFromDict(self.af_kwargs, "af")
        self.ftf = getValueFromDict(self.af_kwargs, "ftf")
        # BUGFIX: np.float was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin float (== np.float64 here) is the correct dtype.
        self.w = np.zeros([self.nParams], dtype=float)
        self.z = np.zeros([self.nParams], dtype=float)  # dutch eligibility trace
        self.q_old = 0.0  # previous step's q(S', A') for the true-online correction
        self.policy = FunctionApproximationPolicy(
            self.nParams, self.nActions, self.af_kwargs,
            actionSelectionMethod=actionSelectionMethod, epsilon=epsilon)

    def update(self, episode):
        """Apply one true-online SARSA(lambda) step using the last transition in episode.

        Args:
            episode: list of step dicts; the last two entries supply
                (state, action) and (reward, next_state, next_action, done).
        """
        t = len(episode) - 2
        state = episode[t]["state"]
        action = episode[t]["action"]
        reward = episode[t + 1]["reward"]
        next_state = episode[t + 1]["state"]
        next_action = episode[t + 1]["action"]
        done = episode[t + 1]["done"]
        x = self.ftf(state, action, **self.af_kwargs)
        q = self.getValue(state, action)
        q_next = self.getValue(next_state, next_action)
        td_error = reward + self.gamma * q_next - q
        # Dutch-trace update: decay the trace, then add x scaled by the
        # correction factor (1 - alpha*gamma*lambda * z.x).
        self.z = self.gamma * self.lambd * self.z + (
            1 - self.alpha * self.gamma * self.lambd * np.dot(self.z, x)) * x
        # Weight update with the true-online correction term (q - q_old).
        self.w += self.alpha * (td_error + q - self.q_old) * self.z \
            - self.alpha * (q - self.q_old) * x
        self.policy.update(self.w)
        self.q_old = q_next
        if done:
            # Episode boundary: reset trace and stored value.
            self.z *= 0.0
            self.q_old = 0.0

    def getValue(self, state, action=None):
        """Return q(state, action), or the vector of q-values over all actions
        when action is None."""
        if action is None:
            return np.array([
                self.af(self.w, state, action, **self.af_kwargs)
                for action in range(self.nActions)
            ])
        else:
            return self.af(self.w, state, action, **self.af_kwargs)

    def selectAction(self, state):
        """Delegate action selection to the policy."""
        return self.policy.selectAction(state)

    def reset(self):
        """Reset weights, traces and stored value to their initial state."""
        self.w = np.zeros([self.nParams], dtype=float)
        self.z = np.zeros([self.nParams], dtype=float)
        self.q_old = 0.0

    def getName(self):
        return self.name
class SARSALambda:
    """SARSA(lambda) with linear function approximation and eligibility traces.

    NOTE: update() accumulates values over the nonzero indices of the feature
    vector, which assumes binary features (e.g. tile coding).
    """

    def __init__(self, nParams, nActions, alpha, gamma, lambd,
                 approximationFunctionArgs, doAccumulateTraces=False,
                 doClearTraces=False, actionSelectionMethod="egreedy",
                 epsilon=0.01):
        """
        Args:
            nParams: number of weight / eligibility-trace parameters.
            nActions: size of the discrete action space.
            alpha: step size.
            gamma: discount factor.
            lambd: trace-decay parameter (lambda).
            approximationFunctionArgs: dict providing "af" (value function:
                af(w, state, action, **kwargs) -> float) and "ftf" (feature
                transformation: ftf(state, action, **kwargs) -> feature vector).
            doAccumulateTraces: accumulating traces if True, replacing otherwise.
            doClearTraces: if True, zero traces of features not active in the
                current or next feature vector after each update.
            actionSelectionMethod: action-selection scheme for the policy.
            epsilon: exploration rate for epsilon-greedy selection.
        """
        self.name = "SARSA(Lambda)"
        self.nParams = nParams
        self.nActions = nActions
        self.alpha = alpha
        self.gamma = gamma
        self.lambd = lambd
        self.af_kwargs = approximationFunctionArgs
        self.af = getValueFromDict(self.af_kwargs, "af")
        self.ftf = getValueFromDict(self.af_kwargs, "ftf")
        self.doAccumulateTraces = doAccumulateTraces
        self.doClearTraces = doClearTraces
        # BUGFIX: np.float was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin float (== np.float64 here) is the correct dtype.
        self.w = np.zeros([self.nParams], dtype=float)
        self.z = np.zeros([self.nParams], dtype=float)  # eligibility trace
        self.policy = FunctionApproximationPolicy(
            self.nParams, self.nActions, self.af_kwargs,
            actionSelectionMethod=actionSelectionMethod, epsilon=epsilon)

    def update(self, episode):
        """Apply one SARSA(lambda) step using the last transition in episode.

        Args:
            episode: list of step dicts; the last two entries supply
                (state, action) and (reward, next_state, next_action, done).
        """
        t = len(episode) - 2
        state = episode[t]["state"]
        action = episode[t]["action"]
        reward = episode[t + 1]["reward"]
        next_state = episode[t + 1]["state"]
        next_action = episode[t + 1]["action"]
        done = episode[t + 1]["done"]
        x = self.ftf(state, action, **self.af_kwargs)
        xx = self.ftf(next_state, next_action, **self.af_kwargs)
        td_error = reward
        # Subtract q(s,a) feature-by-feature and mark active traces
        # (assumes binary features).
        for i in np.nonzero(x)[0]:
            td_error -= self.w[i]
            if self.doAccumulateTraces:
                self.z[i] += 1
            else:
                self.z[i] = 1  # replacing traces
        if done:
            # Terminal transition: target is just the reward.
            self.w += self.alpha * td_error * self.z
            self.policy.update(self.w)
            self.z *= 0.0
        else:
            # Add gamma * q(s',a') feature-by-feature.
            for i in np.nonzero(xx)[0]:
                td_error += self.gamma * self.w[i]
            self.w += self.alpha * td_error * self.z
            self.policy.update(self.w)
            self.z = self.gamma * self.lambd * self.z
            if self.doClearTraces:
                # BUGFIX: the original built an int array of 0s/1s and used it
                # as a fancy index, which zeroes z[0]/z[1] instead of masking.
                # A boolean mask clears exactly the traces of features that are
                # inactive in both x and xx.
                idxToClear = np.ones(self.nParams, dtype=bool)
                idxToClear[np.nonzero(x)[0]] = False
                idxToClear[np.nonzero(xx)[0]] = False
                self.z[idxToClear] = 0.0

    def getValue(self, state, action=None):
        """Return q(state, action), or the vector of q-values over all actions
        when action is None."""
        if action is None:
            return np.array([
                self.af(self.w, state, action, **self.af_kwargs)
                for action in range(self.nActions)
            ])
        else:
            return self.af(self.w, state, action, **self.af_kwargs)

    def selectAction(self, state):
        """Delegate action selection to the policy."""
        return self.policy.selectAction(state)

    def reset(self):
        """Reset weights and traces to their initial state."""
        self.w = np.zeros([self.nParams], dtype=float)
        self.z = np.zeros([self.nParams], dtype=float)

    def getName(self):
        return self.name