import numpy as np

# Local helpers assumed to be provided by the surrounding module:
# getValueFromDict, FunctionApproximationPolicy, selectAction_greedy.


class SemiGradientTDControl:
    def __init__(self,
                 nParams,
                 nActions,
                 alpha,
                 approximationFunctionArgs,
                 actionSelectionMethod="egreedy",
                 epsilon=0.01):
        self.name = "Generic SemiGradient TD Control Class"
        self.nParams = nParams
        self.nActions = nActions
        self.alpha = alpha
        self.af_kwargs = approximationFunctionArgs
        self.af = getValueFromDict(self.af_kwargs, "af")
        self.afd = getValueFromDict(self.af_kwargs, "afd")
        self.w = np.zeros([self.nParams], dtype=float)
        self.policy = FunctionApproximationPolicy(
            self.nParams,
            self.nActions,
            self.af_kwargs,
            actionSelectionMethod=actionSelectionMethod,
            epsilon=epsilon)

    def selectAction(self, state):
        return self.policy.selectAction(state)

    def getValue(self, state, action=None):
        if action is None:
            return np.array([
                self.af(self.w, state, a, **self.af_kwargs)
                for a in range(self.nActions)
            ])
        else:
            return self.af(self.w, state, action, **self.af_kwargs)

    def getName(self):
        return self.name

    def reset(self):
        self.w = np.zeros([self.nParams], dtype=float)
        self.policy.reset()

    def getGreedyAction(self, state, actionsAvailable=None):
        q = self.getValue(state)
        if actionsAvailable is None:
            actionValues = q
            actionList = np.arange(self.nActions)
        else:
            actionValues = q[actionsAvailable]
            actionList = np.array(actionsAvailable)
        actionIdx = selectAction_greedy(actionValues)
        return actionList[actionIdx]
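

# A minimal usage sketch (an assumption, not part of the original module): it
# wires SemiGradientTDControl to a linear approximator over one-hot
# (state, action) features, matching the af/afd signatures used above. The
# names _demo_semigradient_td_control, onehot_features, linear_q and
# linear_q_grad are hypothetical illustrations.
def _demo_semigradient_td_control(nStates=5, nActions=2):
    def onehot_features(state, action, **kwargs):
        x = np.zeros(nStates * nActions)
        x[state * nActions + action] = 1.0
        return x

    def linear_q(w, state, action, **kwargs):
        # Linear action value: dot product of weights and features.
        return float(np.dot(w, onehot_features(state, action)))

    def linear_q_grad(w, state, action, **kwargs):
        # Gradient of a linear value function is the feature vector itself.
        return onehot_features(state, action)

    af_kwargs = {"af": linear_q, "afd": linear_q_grad, "ftf": onehot_features}
    agent = SemiGradientTDControl(nParams=nStates * nActions,
                                  nActions=nActions,
                                  alpha=0.1,
                                  approximationFunctionArgs=af_kwargs)
    print(agent.getValue(0))         # all action values are zero initially
    print(agent.getGreedyAction(0))  # greedy action under the zero weights
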
class TrueOnlineSARSA:
    def __init__(self,
                 nParams,
                 nActions,
                 alpha,
                 gamma,
                 lambd,
                 approximationFunctionArgs,
                 actionSelectionMethod="egreedy",
                 epsilon=0.01):
        self.name = "True Online SARSA"
        self.nParams = nParams
        self.nActions = nActions
        self.alpha = alpha
        self.gamma = gamma
        self.lambd = lambd
        self.af_kwargs = approximationFunctionArgs
        self.af = getValueFromDict(self.af_kwargs, "af")
        self.ftf = getValueFromDict(self.af_kwargs, "ftf")
        self.w = np.zeros([self.nParams], dtype=float)
        self.z = np.zeros([self.nParams], dtype=float)
        self.q_old = 0.0
        self.policy = FunctionApproximationPolicy(
            self.nParams,
            self.nActions,
            self.af_kwargs,
            actionSelectionMethod=actionSelectionMethod,
            epsilon=epsilon)

    def update(self, episode):
        t = len(episode) - 2
        state = episode[t]["state"]
        action = episode[t]["action"]
        reward = episode[t + 1]["reward"]
        next_state = episode[t + 1]["state"]
        next_action = episode[t + 1]["action"]
        done = episode[t + 1]["done"]
        x = self.ftf(state, action, **self.af_kwargs)
        q = self.getValue(state, action)
        # The value of a terminal successor state is zero by definition.
        q_next = 0.0 if done else self.getValue(next_state, next_action)
        td_error = reward + self.gamma * q_next - q
        # Dutch-trace update of the eligibility vector.
        self.z = self.gamma * self.lambd * self.z + (
            1.0 - self.alpha * self.gamma * self.lambd * np.dot(self.z, x)) * x
        # True online weight update, using the stored value q_old from the
        # previous step.
        self.w += (self.alpha * (td_error + q - self.q_old) * self.z
                   - self.alpha * (q - self.q_old) * x)
        self.policy.update(self.w)
        self.q_old = q_next
        if done:
            self.z *= 0.0
            self.q_old = 0.0

    def getValue(self, state, action=None):
        if action is None:
            return np.array([
                self.af(self.w, state, a, **self.af_kwargs)
                for a in range(self.nActions)
            ])
        else:
            return self.af(self.w, state, action, **self.af_kwargs)

    def selectAction(self, state):
        return self.policy.selectAction(state)

    def reset(self):
        self.w = np.zeros([self.nParams], dtype=float)
        self.z = np.zeros([self.nParams], dtype=float)
        self.q_old = 0.0

    def getName(self):
        return self.name
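

# A minimal episode-loop sketch (an assumption, not part of the original
# module): update() reads the last two entries of the episode buffer
# (t = len(episode) - 2), so it is intended to be called once per step, right
# after the newest transition is appended. The env object and its
# reset()/step() signature are hypothetical placeholders.
def _run_true_online_sarsa_episode(agent, env):
    state = env.reset()
    action = agent.selectAction(state)
    episode = [{"state": state, "action": action, "reward": None, "done": False}]
    done = False
    while not done:
        next_state, reward, done = env.step(action)
        next_action = agent.selectAction(next_state)
        episode.append({"state": next_state, "action": next_action,
                        "reward": reward, "done": done})
        agent.update(episode)  # consumes episode[-2] and episode[-1]
        state, action = next_state, next_action
    return episode
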
class SARSALambda:
    def __init__(self,
                 nParams,
                 nActions,
                 alpha,
                 gamma,
                 lambd,
                 approximationFunctionArgs,
                 doAccumulateTraces=False,
                 doClearTraces=False,
                 actionSelectionMethod="egreedy",
                 epsilon=0.01):
        self.name = "SARSA(Lambda)"
        self.nParams = nParams
        self.nActions = nActions
        self.alpha = alpha
        self.gamma = gamma
        self.lambd = lambd
        self.af_kwargs = approximationFunctionArgs
        self.af = getValueFromDict(self.af_kwargs, "af")
        self.ftf = getValueFromDict(self.af_kwargs, "ftf")
        self.doAccumulateTraces = doAccumulateTraces
        self.doClearTraces = doClearTraces
        self.w = np.zeros([self.nParams], dtype=float)
        self.z = np.zeros([self.nParams], dtype=float)
        self.policy = FunctionApproximationPolicy(
            self.nParams,
            self.nActions,
            self.af_kwargs,
            actionSelectionMethod=actionSelectionMethod,
            epsilon=epsilon)

    def update(self, episode):
        t = len(episode) - 2
        state = episode[t]["state"]
        action = episode[t]["action"]
        reward = episode[t + 1]["reward"]
        next_state = episode[t + 1]["state"]
        next_action = episode[t + 1]["action"]
        done = episode[t + 1]["done"]
        # ftf is assumed to return binary (indicator) features, so the action
        # value is the sum of the weights of the active feature indices.
        x = self.ftf(state, action, **self.af_kwargs)
        xx = self.ftf(next_state, next_action, **self.af_kwargs)
        td_error = reward
        for i in np.nonzero(x)[0]:
            td_error -= self.w[i]
            if self.doAccumulateTraces:
                self.z[i] += 1  # accumulating traces
            else:
                self.z[i] = 1  # replacing traces
        if done:
            self.w += self.alpha * td_error * self.z
            self.policy.update(self.w)
            self.z *= 0.0
        else:
            for i in np.nonzero(xx)[0]:
                td_error += self.gamma * self.w[i]
            self.w += self.alpha * td_error * self.z
            self.policy.update(self.w)
            self.z = self.gamma * self.lambd * self.z
        if self.doClearTraces:
            # Clear the traces of all features that are not active in the
            # current or next state-action pair.
            idxToClear = np.ones(self.nParams, dtype=bool)
            idxToClear[np.nonzero(x)[0]] = False
            idxToClear[np.nonzero(xx)[0]] = False
            self.z[idxToClear] = 0.0

    def getValue(self, state, action=None):
        if action is None:
            return np.array([
                self.af(self.w, state, a, **self.af_kwargs)
                for a in range(self.nActions)
            ])
        else:
            return self.af(self.w, state, action, **self.af_kwargs)

    def selectAction(self, state):
        return self.policy.selectAction(state)

    def reset(self):
        self.w = np.zeros([self.nParams], dtype=float)
        self.z = np.zeros([self.nParams], dtype=float)

    def getName(self):
        return self.name
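

# A minimal feature-function sketch for SARSALambda (an assumption, not part
# of the original module): its update() sums individual weights of the active
# feature indices, so "ftf" is expected to return binary indicator features
# and "af" the sum of the corresponding weights. The names binary_features and
# binary_q are hypothetical; tile coding would activate several indices per
# (state, action) pair instead of one.
def binary_features(state, action, nStates=10, nActions=2, **kwargs):
    x = np.zeros(nStates * nActions)
    x[state * nActions + action] = 1.0
    return x


def binary_q(w, state, action, **kwargs):
    # Value of (state, action) = sum of the weights of its active features.
    return float(np.dot(w, binary_features(state, action, **kwargs)))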