Example #1
0
class nStepOffPolicySARSA(nStepTDControlAgent):
    """n-step SARSA control agent corrected with importance sampling.

    The agent owns the target policy (a ``StochasticPolicy``); the policy
    that generated the experience is handed to ``sweepBuffer`` by the caller.
    """

    def __init__(self, nStates, nActions, alpha, gamma, n,
                 policyUpdateMethod="esoft", epsilon=0.1,
                 tieBreakingMethod="arbitrary", valueInit="zeros"):
        """Initialise the parent n-step agent, then build the target policy.

        ``nStates``, ``nActions``, ``alpha``, ``gamma``, ``n`` and
        ``valueInit`` are forwarded to the parent constructor; the remaining
        keyword arguments are forwarded to ``StochasticPolicy``.
        """
        super().__init__(nStates, nActions, alpha, gamma, n,
                         valueInit=valueInit)
        self.name = "n-step off-policy SARSA"
        policy_options = {
            "policyUpdateMethod": policyUpdateMethod,
            "epsilon": epsilon,
            "tieBreakingMethod": tieBreakingMethod,
        }
        self.policy = StochasticPolicy(self.nStates, self.nActions,
                                       **policy_options)

    def sweepBuffer(self, tau_start, tau_stop, t, T, behaviour_policy):
        """Update Q(S_tau, A_tau) for every tau in [tau_start, tau_stop).

        For each tau the discounted rewards stored at buffer indices
        tau+1 .. min(tau+n, t+1) are summed, a bootstrap term from the action
        value at index tau+n is added when tau+n <= t+1, and the TD error is
        scaled by the product of target/behaviour probability ratios over the
        same index window. ``T`` is accepted but not read by this method.
        """
        for tau in range(tau_start, tau_stop):
            origin = self.bufferExperience[tau]
            state = origin['state']
            action = origin['action']

            # Last buffer index contributing a reward to this return.
            horizon = min(tau + self.n, t + 1)
            window = range(tau + 1, horizon + 1)

            rewards = np.array(
                [self.bufferExperience[i]['reward'] for i in window])
            gammas = np.array(
                [self.gamma**k for k in range(horizon - tau)])

            # (state, action) pairs whose probabilities enter the ratio.
            visited = [(self.bufferExperience[i]['state'],
                        self.bufferExperience[i]['action']) for i in window]
            target_probs = [
                self.policy.getProbability(s, a) for s, a in visited
            ]
            behaviour_probs = [
                behaviour_policy.getProbability(s, a) for s, a in visited
            ]
            W = np.prod(np.array(target_probs) / np.array(behaviour_probs))

            G = np.sum(rewards * gammas)
            if tau + self.n <= t + 1:
                bootstrap = self.bufferExperience[tau + self.n]
                G += self.gamma**(self.n) * self.actionValueTable[
                    bootstrap['state'], bootstrap['action']]

            td_error = G - self.actionValueTable[state, action]
            self.actionValueTable[state, action] += self.alpha * W * td_error
            # Keep the target policy in sync with the freshly updated row.
            self.policy.update(state, self.actionValueTable[state, :])

    def selectAction(self, state, actionsAvailable=None):
        """Sample an action for ``state`` from the agent's own policy."""
        return self.policy.sampleAction(state, actionsAvailable)
Example #2
0
 def __init__(self, nStates, nActions, alpha, gamma, n, sigma,
              policyUpdateMethod="esoft", epsilon=0.1,
              tieBreakingMethod="arbitrary", valueInit="zeros"):
     """Initialise the n-step Q-sigma agent.

     ``nStates``, ``nActions``, ``alpha``, ``gamma``, ``n`` and
     ``valueInit`` go to the parent constructor, ``sigma`` is stored on the
     instance, and the remaining keyword arguments configure the
     ``StochasticPolicy`` kept in ``self.policy``.
     """
     super().__init__(nStates, nActions, alpha, gamma, n,
                      valueInit=valueInit)
     self.name = "n-step Q-sigma"
     self.sigma = sigma
     policy_options = {
         "policyUpdateMethod": policyUpdateMethod,
         "epsilon": epsilon,
         "tieBreakingMethod": tieBreakingMethod,
     }
     # TODO (marker kept from the original; its subject is not stated)
     self.policy = StochasticPolicy(self.nStates, self.nActions,
                                    **policy_options)
Example #3
0
 def __init__(self, nStates, nActions, gamma, policyUpdateMethod="greedy",
              epsilon=0.0, tieBreakingMethod="arbitrary"):
   """Store the problem sizes and discount, zero the Q-table, build the policy.

   ``policyUpdateMethod``, ``epsilon`` and ``tieBreakingMethod`` are passed
   straight through to ``StochasticPolicy``.
   """
   self.name = "Generic Monte Carlo Control Agent"
   self.nStates = nStates
   self.nActions = nActions
   self.gamma = gamma
   # One row per state, one column per action, all starting at 0.0.
   table_shape = (self.nStates, self.nActions)
   self.actionValueTable = np.zeros(table_shape, dtype=float)
   policy_options = {
     "policyUpdateMethod": policyUpdateMethod,
     "epsilon": epsilon,
     "tieBreakingMethod": tieBreakingMethod,
   }
   self.policy = StochasticPolicy(self.nStates, self.nActions, **policy_options)
Example #4
0
class nStepTreeBackup(nStepTDControlAgent):
    """n-step Tree Backup control agent.

    Returns are built from expectations under the agent's own policy, so no
    importance-sampling ratio is used and the behaviour policy is optional.
    """

    def __init__(self, nStates, nActions, alpha, gamma, n,
                 policyUpdateMethod="esoft", epsilon=0.1,
                 tieBreakingMethod="arbitrary", valueInit="zeros"):
        """Initialise the parent n-step agent, then build the policy.

        ``nStates``, ``nActions``, ``alpha``, ``gamma``, ``n`` and
        ``valueInit`` are forwarded to the parent constructor; the remaining
        keyword arguments are forwarded to ``StochasticPolicy``.
        """
        super().__init__(nStates, nActions, alpha, gamma, n,
                         valueInit=valueInit)
        self.name = "n-step Tree Backup"
        policy_options = {
            "policyUpdateMethod": policyUpdateMethod,
            "epsilon": epsilon,
            "tieBreakingMethod": tieBreakingMethod,
        }
        self.policy = StochasticPolicy(self.nStates, self.nActions,
                                       **policy_options)

    def sweepBuffer(self, tau_start, tau_stop, t, T, behaviour_policy=None):
        """Update Q(S_tau, A_tau) for every tau in [tau_start, tau_stop).

        The return starts either from the reward stored at index ``T`` (when
        t+1 >= T) or from the reward at index t+1 plus the discounted
        policy-expected action value of that state. It is then folded
        backwards from index min(t, T-1) down to tau+1.
        ``behaviour_policy`` is accepted but not read by this method.
        """
        for tau in range(tau_start, tau_stop):
            origin = self.bufferExperience[tau]
            state = origin['state']
            action = origin['action']

            if (t + 1) >= T:
                G = self.bufferExperience[T]['reward']
            else:
                frontier = self.bufferExperience[t + 1]
                last_state = frontier['state']
                last_reward = frontier['reward']
                expected_value = np.dot(
                    self.policy.getProbability(last_state),
                    self.actionValueTable[last_state, :])
                G = last_reward + self.gamma * expected_value

            for k in range(min(t, T - 1), tau, -1):
                step = self.bufferExperience[k]
                step_state = step['state']
                step_action = step['action']
                step_reward = step['reward']
                # Work on a copy so the policy's own data is never modified.
                action_probs = np.array(
                    self.policy.getProbability(step_state))
                taken_prob = action_probs[step_action]
                # Zero the taken action: its share is carried by G instead.
                action_probs[step_action] = 0.0
                untaken_value = np.dot(
                    action_probs, self.actionValueTable[step_state, :])
                G = (step_reward + self.gamma * untaken_value +
                     self.gamma * taken_prob * G)

            td_error = G - self.actionValueTable[state, action]
            self.actionValueTable[state, action] += self.alpha * td_error
            self.policy.update(state, self.actionValueTable[state, :])

    def selectAction(self, state, actionsAvailable=None):
        """Sample an action for ``state`` from the agent's own policy."""
        return self.policy.sampleAction(state, actionsAvailable)
 def __init__(self, nStates, nActions, alpha, doUseBaseline=True):
     """Store the hyper-parameters and create the preferences and policy.

     Every preference starts at 0.0001, the policy is a ``StochasticPolicy``
     using the "softmax" update with "consistent" tie breaking, and the
     running reward statistics start at zero.
     """
     self.nStates = nStates
     self.nActions = nActions
     self.alpha = alpha
     self.doUseBaseline = doUseBaseline
     table_shape = [self.nStates, self.nActions]
     self.preferencesTable = np.full(table_shape, 0.0001, dtype=float)
     policy_options = {
         "policyUpdateMethod": "softmax",
         "tieBreakingMethod": "consistent",
     }
     self.policy = StochasticPolicy(self.nStates, self.nActions,
                                    **policy_options)
     self.count = 0
     self.avgReward = 0.0
Example #6
0
 def __init__(self, nStates, nActions, alpha, gamma,
              actionSelectionMethod="esoft", epsilon=0.01,
              tieBreakingMethod="arbitrary", valueInit="zeros"):
     """Initialise the Expected SARSA agent.

     ``nStates``, ``nActions``, ``alpha``, ``gamma`` and ``valueInit`` are
     forwarded to the parent constructor. ``actionSelectionMethod`` is
     accepted but never read here: the policy is always created with
     ``policyUpdateMethod="esoft"``.
     """
     super().__init__(nStates, nActions, alpha, gamma, valueInit=valueInit)
     self.name = "Expected SARSA"
     policy_options = {
         "policyUpdateMethod": "esoft",
         "epsilon": epsilon,
         "tieBreakingMethod": tieBreakingMethod,
     }
     self.policy = StochasticPolicy(self.nStates, self.nActions,
                                    **policy_options)
Example #7
0
 def __init__(self, nStates, nActions, alpha, gamma, n, valueInit="zeros",
              policyUpdateMethod="greedy", epsilon=0.0,
              tieBreakingMethod="consistent"):
     """Initialise the n-step per-decision TD prediction agent.

     ``nStates``, ``alpha``, ``gamma``, ``n`` and ``valueInit`` are forwarded
     to the parent constructor (which is not given ``nActions``);
     ``nActions`` is stored here and, with the remaining keyword arguments,
     used to build the ``StochasticPolicy`` kept in ``self.policy``.
     """
     super().__init__(nStates, alpha, gamma, n, valueInit=valueInit)
     self.name = "n-step Per-Decision TD Prediction"
     self.nActions = nActions
     policy_options = {
         "policyUpdateMethod": policyUpdateMethod,
         "epsilon": epsilon,
         "tieBreakingMethod": tieBreakingMethod,
     }
     self.policy = StochasticPolicy(self.nStates, self.nActions,
                                    **policy_options)
Example #8
0
class nStepPerDecisionTDPrediction(nStepTDPredictionAgent):
    """n-step TD state-value prediction with per-decision importance sampling.

    The agent's own ``StochasticPolicy`` is the policy being evaluated; the
    policy that generated the experience is passed to ``sweepBuffer``.
    """

    def __init__(self, nStates, nActions, alpha, gamma, n, valueInit="zeros",
                 policyUpdateMethod="greedy", epsilon=0.0,
                 tieBreakingMethod="consistent"):
        """Initialise the parent prediction agent, then build the policy.

        ``nStates``, ``alpha``, ``gamma``, ``n`` and ``valueInit`` are
        forwarded to the parent constructor; ``nActions`` is stored here and
        the remaining keyword arguments configure ``StochasticPolicy``.
        """
        super().__init__(nStates, alpha, gamma, n, valueInit=valueInit)
        self.name = "n-step Per-Decision TD Prediction"
        self.nActions = nActions
        policy_options = {
            "policyUpdateMethod": policyUpdateMethod,
            "epsilon": epsilon,
            "tieBreakingMethod": tieBreakingMethod,
        }
        self.policy = StochasticPolicy(self.nStates, self.nActions,
                                       **policy_options)

    def sweepBuffer(self, tau_start, tau_stop, t, T, behaviour_policy):
        """Update V(S_tau) for every tau in [tau_start, tau_stop).

        The return starts from the value of the state stored at index
        min(T+1, t+1) and is folded backwards down to tau. At each step the
        one-step return is weighted by the target/behaviour probability ratio
        W, and the remaining (1 - W) weight goes to the current estimate of
        that step's state.
        """
        for tau in range(tau_start, tau_stop):
            state = self.bufferExperience[tau]['state']
            l = min(T + 1, t + 1)
            G = self.valueTable[self.bufferExperience[l]['state']]
            for k in reversed(range(tau, l)):
                step = self.bufferExperience[k]
                step_state = step['state']
                step_action = step['action']
                # The reward for acting at step k is stored with step k+1.
                step_reward = self.bufferExperience[k + 1]['reward']
                target_prob = self.policy.getProbability(
                    step_state, step_action)
                behaviour_prob = behaviour_policy.getProbability(
                    step_state, step_action)
                W = target_prob / behaviour_prob
                G = W * (step_reward + self.gamma * G) + (
                    1.0 - W) * self.valueTable[step_state]
            td_error = G - self.valueTable[state]
            self.valueTable[state] += self.alpha * td_error

    def reset(self):
        """Reset the parent agent's state and the policy."""
        super().reset()
        self.policy.reset()
class BanditGradient():
    """Gradient bandit agent driven by per-(state, action) preferences.

    Preferences live in ``preferencesTable`` and are turned into action
    probabilities by a ``StochasticPolicy`` configured with the "softmax"
    update method. A running average of all rewards seen can be used as
    the baseline.
    """

    def __init__(self, nStates, nActions, alpha, doUseBaseline=True):
        """Store the hyper-parameters and create the preferences and policy.

        Args:
            nStates: number of rows of the preferences table.
            nActions: number of columns of the preferences table.
            alpha: step size applied to every preference update.
            doUseBaseline: when True the running average reward is
                subtracted from the reward; otherwise the baseline is 0.0.
        """
        self.nStates = nStates
        self.nActions = nActions
        self.alpha = alpha
        self.doUseBaseline = doUseBaseline
        self.preferencesTable = np.zeros([self.nStates, self.nActions],
                                         dtype=float) + 0.0001
        self.policy = StochasticPolicy(self.nStates,
                                       self.nActions,
                                       policyUpdateMethod="softmax",
                                       tieBreakingMethod="consistent")
        self.count = 0
        self.avgReward = 0.0

    def update(self, state, action, reward):
        """Apply one gradient step for the observed (state, action, reward).

        The preference of the taken action moves by
        ``alpha * (reward - baseline) * (1 - p)`` and every other action's by
        ``-alpha * (reward - baseline) * p``, where ``p`` is that action's
        probability before the update. The policy is then refreshed for
        ``state`` and the running average reward is updated.
        """
        baseline = self.avgReward if self.doUseBaseline else 0.0
        # Copy the probabilities so the policy's own data is never modified.
        probs = np.array(self.policy.getProbability(state), dtype=float)
        gradient = -probs
        gradient[action] += 1.0  # taken action: (1 - p); others: -p
        scale = self.alpha * (reward - baseline)
        self.preferencesTable[state, :] += scale * gradient
        # Pass only this state's row, matching every other policy.update()
        # call site; the original passed the whole 2-D table.
        self.policy.update(state, self.preferencesTable[state, :])
        self.count += 1
        self.avgReward += (1.0 / self.count) * (reward - self.avgReward)

    def selectAction(self, state):
        """Sample an action for ``state`` from the policy."""
        return self.policy.sampleAction(state)

    def reset(self):
        """Restore initial preferences, the policy and the reward statistics."""
        self.preferencesTable = np.zeros([self.nStates, self.nActions],
                                         dtype=float) + 0.0001
        # Without this the policy would keep probabilities learned before the
        # reset while the preferences start over.
        self.policy.reset()
        self.count = 0
        self.avgReward = 0.0
Example #10
0
class ExpectedSARSA(TDControlAgent):
    """Expected SARSA control agent with an epsilon-soft ``StochasticPolicy``."""

    def __init__(self, nStates, nActions, alpha, gamma,
                 actionSelectionMethod="esoft", epsilon=0.01,
                 tieBreakingMethod="arbitrary", valueInit="zeros"):
        """Initialise the parent TD agent, then build the policy.

        ``actionSelectionMethod`` is accepted but never read here: the policy
        is always created with ``policyUpdateMethod="esoft"``.
        """
        super().__init__(nStates, nActions, alpha, gamma, valueInit=valueInit)
        self.name = "Expected SARSA"
        policy_options = {
            "policyUpdateMethod": "esoft",
            "epsilon": epsilon,
            "tieBreakingMethod": tieBreakingMethod,
        }
        self.policy = StochasticPolicy(self.nStates, self.nActions,
                                       **policy_options)

    def update(self, episode):
        """Apply one Expected SARSA update per consecutive pair in ``episode``.

        Each entry is a mapping with "state", "action" and "reward" keys and
        an optional "allowedActions" key. When the successor entry lists its
        allowed actions, the expectation is taken over those actions only,
        with their probabilities renormalised.
        """
        for t in range(len(episode) - 1):
            current = episode[t]
            successor = episode[t + 1]
            state = current["state"]
            action = current["action"]
            reward = successor["reward"]
            next_state = successor["state"]

            next_probs = self.policy.getProbability(next_state)
            if "allowedActions" in successor:
                allowedActions = successor["allowedActions"]
                pdist = Numeric.normalize_sum(next_probs[allowedActions])
            else:
                allowedActions = np.array(range(self.nActions))
                pdist = next_probs

            expectedVal = np.dot(
                pdist, self.actionValueTable[next_state, allowedActions])
            td_error = (reward + self.gamma * expectedVal -
                        self.actionValueTable[state, action])
            self.actionValueTable[state, action] += self.alpha * td_error
            self.policy.update(state, self.actionValueTable[state, :])

    def selectAction(self, state, actionsAvailable=None):
        """Sample an action for ``state`` from the agent's own policy."""
        return self.policy.sampleAction(state, actionsAvailable)
Example #11
0
class MCControlAgent:
  """Generic Monte Carlo control agent: a tabular Q-function plus a policy.

  This class holds the action-value table and the ``StochasticPolicy`` and
  offers accessors; it defines no learning rule of its own.
  """

  def __init__(self, nStates, nActions, gamma, policyUpdateMethod="greedy", epsilon=0.0, tieBreakingMethod="arbitrary"):
    """Store the problem sizes and discount, zero the Q-table, build the policy.

    ``policyUpdateMethod``, ``epsilon`` and ``tieBreakingMethod`` are passed
    straight through to ``StochasticPolicy``.
    """
    self.name = "Generic Monte Carlo Control Agent"
    self.nStates = nStates
    self.nActions = nActions
    self.gamma = gamma
    self.actionValueTable = np.zeros([self.nStates, self.nActions], dtype=float)
    self.policy = StochasticPolicy(self.nStates, self.nActions, policyUpdateMethod=policyUpdateMethod,
      epsilon=epsilon, tieBreakingMethod=tieBreakingMethod)

  def selectAction(self, state, actionsAvailable=None):
    """Sample an action for ``state`` from the policy."""
    return self.policy.sampleAction(state, actionsAvailable)

  def getGreedyAction(self, state, actionsAvailable=None):
    """Return the action chosen by ``selectAction_greedy`` for ``state``.

    When ``actionsAvailable`` is given, only those actions are considered
    and the returned value is one of them.
    """
    if actionsAvailable is None:
      actionValues = self.actionValueTable[state, :]
      actionList = np.array(range(self.nActions))
    else:
      actionValues = self.actionValueTable[state, actionsAvailable]
      actionList = np.array(actionsAvailable)
    # selectAction_greedy yields a position within actionValues; map it back
    # to the corresponding action id.
    actionIdx = selectAction_greedy(actionValues)
    return actionList[actionIdx]

  def getValue(self, state):
    """Return the policy-weighted average of the action values of ``state``."""
    return np.dot(self.policy.getProbability(state), self.actionValueTable[state, :])

  def getActionValue(self, state, action):
    """Return the stored action value for (``state``, ``action``)."""
    return self.actionValueTable[state, action]

  def getName(self):
    """Return the agent's display name."""
    return self.name

  def reset(self):
    """Zero the action-value table and reset the policy."""
    # The builtin float replaces np.float, an alias removed in NumPy 1.24,
    # and matches the dtype used in __init__.
    self.actionValueTable = np.zeros([self.nStates, self.nActions], dtype=float)
    self.policy.reset()
  # Number of independent experiments and episodes per experiment
  nExperiments = 100
  nEpisodes = 10000
  
  # Agent
  gamma = 1.0
  
  # Environment
  env = Blackjack()

  #startState_dealerHand = [env.LABEL_ACE, min(env.VAL_FACECARDS, np.random.choice(env.deck))]
  # Fixed starting hands: two aces for the dealer, three aces for the player
  startState_dealerHand = [env.LABEL_ACE, env.LABEL_ACE]
  startState_playerHand = [env.LABEL_ACE, env.LABEL_ACE, env.LABEL_ACE]
  startState = env.setHands(startState_playerHand, startState_dealerHand)
  # NOTE(review): presumably the reference value of startState under the
  # target policy - confirm the source of this constant
  startState_groundTruth = -0.27726

  # Behaviour policy built with the StochasticPolicy defaults
  behaviour_policy = StochasticPolicy(env.nStates, env.nActions)
  
  # One slot per episode for each of the two estimators
  mse_weighted = np.zeros(nEpisodes)
  mse_ordinary = np.zeros(nEpisodes)
  for idx_experiment in range(nExperiments):
    # Two fresh agents per experiment, differing only in doUseWeightedIS
    agent_ordinary = MonteCarloOffPolicyPrediction(env.nStates, env.nActions, gamma, doUseWeightedIS=False)
    agent_weighted = MonteCarloOffPolicyPrediction(env.nStates, env.nActions, gamma, doUseWeightedIS=True)
    # Visit every (player sum, dealer showing, usable ace) combination
    for i in range(env.nStatesPlayerSum-1, -1, -1):
      for j in range(env.nStatesDealerShowing):
        for k in [env.USABLE_ACE_YES, env.USABLE_ACE_NO]:
          idx_state = env.getLinearIndex(env.minPlayerSum+i, env.minDealerShowing+j, k)
          # For player sums below 20 the target policy puts all probability on HIT
          if(env.minPlayerSum+i<20):
            actionProb = np.zeros(env.nActions)
            actionProb[env.ACTION_HIT] = 1.0
            agent_ordinary.policy.update(idx_state, actionProb)
            agent_weighted.policy.update(idx_state, actionProb)
Example #13
0
class nStepQSigma(nStepTDControlAgent):
    """n-step Q(sigma) control agent.

    ``sigma`` blends the importance-sampling ratio (weight ``sigma``) with
    the target-policy probability of the taken action (weight ``1 - sigma``)
    at every backup step.
    """

    def __init__(self, nStates, nActions, alpha, gamma, n, sigma,
                 policyUpdateMethod="esoft", epsilon=0.1,
                 tieBreakingMethod="arbitrary", valueInit="zeros"):
        """Initialise the parent n-step agent, store sigma, build the policy.

        ``nStates``, ``nActions``, ``alpha``, ``gamma``, ``n`` and
        ``valueInit`` are forwarded to the parent constructor; the remaining
        keyword arguments are forwarded to ``StochasticPolicy``.
        """
        super().__init__(nStates, nActions, alpha, gamma, n,
                         valueInit=valueInit)
        self.name = "n-step Q-sigma"
        self.sigma = sigma
        policy_options = {
            "policyUpdateMethod": policyUpdateMethod,
            "epsilon": epsilon,
            "tieBreakingMethod": tieBreakingMethod,
        }
        # TODO (marker kept from the original; its subject is not stated)
        self.policy = StochasticPolicy(self.nStates, self.nActions,
                                       **policy_options)

    def sweepBuffer(self, tau_start, tau_stop, t, T, behaviour_policy):
        """Update Q(S_tau, A_tau) for every tau in [tau_start, tau_stop).

        When t+1 < T the return starts from the action value stored for
        index t+1. It is then folded backwards from index t+1 down to tau+1.
        At index ``T`` the return is replaced by that step's reward.
        """
        for tau in range(tau_start, tau_stop):
            origin = self.bufferExperience[tau]
            state = origin['state']
            action = origin['action']

            if (t + 1) < T:
                frontier = self.bufferExperience[t + 1]
                G = self.actionValueTable[frontier['state'],
                                          frontier['action']]

            for k in range(t + 1, tau, -1):
                step = self.bufferExperience[k]
                step_state = step['state']
                step_action = step['action']
                step_reward = step['reward']
                if k == T:
                    # At index T the return is just the stored reward.
                    G = step_reward
                    continue
                sigma = self.sigma
                action_probs = np.array(
                    self.policy.getProbability(step_state))
                p = action_probs[step_action]
                b = behaviour_policy.getProbability(step_state, step_action)
                W = p / b
                # Expected action value of step_state under the target policy.
                V = np.dot(action_probs,
                           self.actionValueTable[step_state, :])
                blend = sigma * W + (1.0 - sigma) * p
                correction = G - self.actionValueTable[step_state,
                                                       step_action]
                G = (step_reward + self.gamma * blend * correction +
                     self.gamma * V)

            td_error = G - self.actionValueTable[state, action]
            self.actionValueTable[state, action] += self.alpha * td_error
            self.policy.update(state, self.actionValueTable[state, :])

    def selectAction(self, state, actionsAvailable=None):
        """Sample an action for ``state`` from the agent's own policy."""
        return self.policy.sampleAction(state, actionsAvailable)
Example #14
0
  # Environment
  # 4x4 grid; terminal cells at (0,0) and (3,3)
  sizeX = 4
  sizeY = 4
  defaultReward = -1.0
  terminalStates= [(0,0), (3,3)]
  
  # Agent
  gamma = 0.9
  thresh_convergence = 1e-30
  n = 5
  # NOTE(review): the two step sizes are not used within this excerpt;
  # presumably one per n-step TD variant - confirm against the rest of the script
  alpha_TDnOP = 0.001
  alpha_TDnPD = 0.001
 
  env = DeterministicGridWorld(sizeX, sizeY, defaultReward=defaultReward, terminalStates=terminalStates)
  # Behaviour policy is a simple stochastic policy with equiprobable actions
  behaviour_policy = StochasticPolicy(env.nStates, env.nActions)
  # Load target policy q table
  # We will use the optimal policy learned via VI as target policy
  # These are the values learned in chapter04/03_GridWorld_2_VI.py
  with open('gridworld_2_qtable.npy', 'rb') as f:
    targetPolicy_qTable = np.load(f)  
  # Feed each state's row of the loaded table to the target policy
  target_policy = StochasticPolicy(env.nStates, env.nActions)
  for s in range(env.nStates):
    target_policy.update(s, targetPolicy_qTable[s,:])
  # A policy evaluation agent will provide the ground truth
  agent_PE = PolicyEvaluation(env.nStates, env.nActions, gamma, thresh_convergence, env.computeExpectedValue)
  
  env.printEnv()
  
  # Policy evaluation for reference
  for e in range(nEpisodes):
Example #15
0
                                 for y in range(13, 30)])
    else:
        sys.exit("ERROR: trackID not recognized")

    # Build the race-track environment from settings chosen earlier in the script
    env = RaceTrack(sizeX,
                    sizeY,
                    startStates=startStates,
                    terminalStates=terminalStates,
                    impassableStates=outOfTrackStates,
                    defaultReward=defaultReward,
                    crashReward=outOfTrackReward,
                    finishReward=finishReward,
                    p_actionFail=p_actionFail)
    agent = MonteCarloOffPolicyControl(env.nStates, env.nActions, gamma)
    # Actions are drawn from this separate epsilon-soft policy, not from the agent
    behaviour_policy = StochasticPolicy(env.nStates,
                                        env.nActions,
                                        policyUpdateMethod="esoft",
                                        epsilon=epsilon)

    for e in range(nEpochs):

        # Progress message every 1000 epochs
        if (e % 1000 == 0):
            print("Epoch : ", e)

        # The experience list starts with one empty record
        experiences = [{}]
        state = env.reset()
        done = False
        while not done:

            # Sample among the actions the environment reports as available
            action = behaviour_policy.sampleAction(state,
                                                   env.getAvailableActions())
Example #16
0
        agent_nStepSARSA = nStepSARSA(env.nStates,
                                      env.nActions,
                                      alpha_nStepSARSA,
                                      gamma_nStepSARSA,
                                      n_nStepSARSA,
                                      epsilon=epsilon_nStepSARSA)
        print("running:", agent_nStepSARSA.getName())
        cum_reward_nStepSARSA, nStepsPerEpisode_nStepSARSA = runExperiment(
            nEpisodes, env, agent_nStepSARSA)

        agent_nStepTB = nStepTreeBackup(env.nStates, env.nActions,
                                        alpha_nStepTB, gamma_nStepTB,
                                        n_nStepTB)
        print("running:", agent_nStepTB.getName())
        policy_behaviour = StochasticPolicy(env.nStates,
                                            env.nActions,
                                            policyUpdateMethod="esoft",
                                            epsilon=epsilon_behaviourPolicy)
        cum_reward_nStepTB, nStepsPerEpisode_nStepTB = runExperiment(
            nEpisodes, env, agent_nStepTB, policy_behaviour,
            doUpdateBehaviourPolicy)

        agent_nStepQSigma = nStepQSigma(env.nStates, env.nActions,
                                        alpha_nStepQSigma, gamma_nStepQSigma,
                                        n_nStepQSigma, sigma_nStepQSigma)
        print("running:", agent_nStepQSigma.getName())
        policy_behaviour = StochasticPolicy(env.nStates,
                                            env.nActions,
                                            policyUpdateMethod="esoft",
                                            epsilon=epsilon_behaviourPolicy)
        cum_reward_nStepQSigma, nStepsPerEpisode_nStepQSigma = runExperiment(
            nEpisodes, env, agent_nStepQSigma, policy_behaviour,
    # Reward and terminal cells passed to the grid-world constructor below
    defaultReward = -1.0
    terminalStates = [(0, 0), (3, 3)]

    # Agent
    gamma = 1.0
    thresh_convergence = 1e-30
    n = 5
    alpha_TDn = 0.01
    alpha_TD = 0.01
    # NOTE(review): alpha_sumTDError is not used within this excerpt
    alpha_sumTDError = 0.01

    env = DeterministicGridWorld(sizeX,
                                 sizeY,
                                 defaultReward=defaultReward,
                                 terminalStates=terminalStates)
    # Policy built with the StochasticPolicy defaults; it is the one evaluated below
    policy = StochasticPolicy(env.nStates, env.nActions)
    agent_PE = PolicyEvaluation(env.nStates, env.nActions, gamma,
                                thresh_convergence, env.computeExpectedValue)

    # TD agent to validate the TDn implementation
    agent_TD = TDPrediction(env.nStates, alpha_TD, gamma)
    agent_TDn = nStepTDPrediction(env.nStates, alpha_TDn, gamma, n)

    env.printEnv()

    # Policy evaluation for reference
    for e in range(nEpisodes):
        # evaluate() returns the largest change and a convergence flag;
        # the flag is not acted upon inside this loop
        deltaMax, isConverged = agent_PE.evaluate(policy)

        print("Episode : ", e, " Delta: ", deltaMax)
from IRL.agents.MonteCarlo import MonteCarloPrediction
from IRL.utils.Policies import StochasticPolicy

# Script entry point: Monte Carlo prediction on the coin-flip game
if __name__=="__main__":

	nEpisodes = 100000

	# Environment
	maxCapital = 100
	prob_heads = 0.4

	# Agent
	gamma = 1.0

	env = CoinFlipGame(maxCapital, prob_heads)
	# Policy built with the StochasticPolicy defaults; it generates the episodes
	policy = StochasticPolicy(env.nStates, env.nActions)
	agent = MonteCarloPrediction(env.nStates, gamma, doUseAllVisits=False)
	
	#env.printEnv()
	
	for e in range(nEpisodes):
	
		# Progress message every 1000 episodes
		if(e%1000==0):
			print("Episode : ", e)
			
		# The experience list starts with one empty record
		experiences = [{}]
		state = env.reset()
		done = False	
		while not done:
		
			# Sample among the actions the environment reports as available
			action = policy.sampleAction(state, env.getAvailableActions())
Example #19
0
  # Hyper-parameters, grouped by suffix (DP, QL, GTD, ETD)
  alpha_DP = 0.01 
  gamma_DP = 0.99
  thresh_convergence = 1e-10
  
  alpha_QL = 0.01
  gamma_QL = 0.99

  alpha_GTD = 0.005 
  beta_GTD = 0.05
  gamma_GTD = 0.99
  
  alpha_ETD = 0.03
  gamma_ETD = 0.99
  
  env = BairdsCounterExample()
  # Behaviour policy: in every state, dashed action with 6/7 and solid with 1/7
  behaviour_policy = StochasticPolicy(env.nStates, env.nActions)
  behaviour_policy.actionProbabilityTable[:,env.ACTION_IDX_DASHED] = 6.0/7.0
  behaviour_policy.actionProbabilityTable[:,env.ACTION_IDX_SOLID] = 1.0/7.0

  # Target policy: all probability on the solid action in every state
  target_policy = StochasticPolicy(env.nStates, env.nActions)
  target_policy.actionProbabilityTable[:,:] = 0.0
  target_policy.actionProbabilityTable[:,env.ACTION_IDX_SOLID] = 1.0
  
  # Fixed linear state encoding: every state but the last gets 2 in its own
  # column and 1 in column 7; row 6 is then set to 1 in column 6 and 2 in column 7
  # NOTE(review): the hard-coded indices 6 and 7 assume env.nStates == 7 and
  # nParams == 8 - confirm
  stateEncodingMatrix = np.zeros([env.nStates, nParams])
  for i in range(env.nStates-1):
    stateEncodingMatrix[i,i] = 2
    stateEncodingMatrix[i,7] = 1
  stateEncodingMatrix[6,6] = 1
  stateEncodingMatrix[6,7] = 2
  # Keyword bundle for the function approximator: transform, its derivative,
  # the feature-transform class, and the encoding matrix built above
  approximationFunctionArgs = {'af':linearTransform, 'afd':dLinearTransform, 'ftf':FixedStateEncoding, 'stateEncodingMatrix':stateEncodingMatrix}
Example #20
0
from mpl_toolkits.mplot3d import Axes3D

from IRL.environments.Gambler import Blackjack
from IRL.agents.MonteCarlo import MonteCarloOffPolicyPrediction
from IRL.utils.Policies import StochasticPolicy

# Script entry point: off-policy Monte Carlo prediction on Blackjack
if __name__ == "__main__":

    nEpisodes = 500000

    # Agent
    gamma = 1.0

    env = Blackjack()
    agent = MonteCarloOffPolicyPrediction(env.nStates, env.nActions, gamma)
    # Behaviour policy built with the StochasticPolicy defaults
    policy_behaviour = StochasticPolicy(env.nStates, env.nActions)
    # Set the agent's (target) policy for every
    # (player sum, dealer showing, usable ace) combination
    for i in range(env.nStatesPlayerSum - 1, -1, -1):
        for j in range(env.nStatesDealerShowing):
            for k in [env.USABLE_ACE_YES, env.USABLE_ACE_NO]:
                idx_state = env.getLinearIndex(env.minPlayerSum + i,
                                               env.minDealerShowing + j, k)
                # HIT with probability 1 below a player sum of 20, otherwise STICK
                if (env.minPlayerSum + i < 20):
                    actionProb = np.zeros(env.nActions)
                    actionProb[env.ACTION_HIT] = 1.0
                    agent.policy.update(idx_state, actionProb)
                else:
                    actionProb = np.zeros(env.nActions)
                    actionProb[env.ACTION_STICK] = 1.0
                    agent.policy.update(idx_state, actionProb)

    #env.printEnv()