Code example #1
    def test_blackjack(self):
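        # Expected greedy action table for usable-ace hands after training
        # (1 = hit, 0 = stick, per the action coding used by BlackjackEnv).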
        expect_usable_a = [
            [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        ]
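        # Train one-step Sarsa on Blackjack: 30,000,000 episodes with a fixed
        # learning rate of 1, then compare the extracted action table.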
        iterator = Sarsa(BlackjackEnv())
        iterator.run(30000000, learning_rate=1)
        print(iterator.env.states.states)

        self.assertEqual(expect_usable_a, BlackjackGetter(iterator.env).get())
Code example #2
    def test_blackjack(self):
        expect_usable_a = [
            [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        ]
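        # Same check with off-policy n-step Sarsa (n = 3): 3,000,000 episodes,
        # learning_rate=0.5 and epsilon=0.3 for exploration.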
        iterator = OffNStepSarsa(BlackjackEnv(), 3)
        iterator.run(3000000, learning_rate=0.5, epsilon=0.3)
        print(iterator.env.states.states)
        iterator.show_one_episode()

        self.assertEqual(expect_usable_a, BlackjackGetter(iterator.env).get())
Code example #3
def main():
    env = BlackjackEnv()

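    # Monte Carlo prediction of the state-value function under sample_policy,
    # run twice to compare the estimates after 10,000 and 500,000 episodes.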
    V_10k = mc_prediction(sample_policy, env, num_episodes=10000)
    plotting.plot_value_function(V_10k, title="10,000 Steps")

    V_500k = mc_prediction(sample_policy, env, num_episodes=500000)
    plotting.plot_value_function(V_500k, title="500,000 Steps")
Code example #4
def getEnv(domain):
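    # Map a domain name to one of the bundled environments; any other name is
    # forwarded to gym.make.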
    if domain == "Blackjack":
        return BlackjackEnv()
    elif domain == "Gridworld":
        return GridworldEnv()
    elif domain == "CliffWalking":
        return CliffWalkingEnv()
    elif domain == "WindyGridworld":
        return WindyGridworldEnv()
    else:
        try:
            return gym.make(domain)
        except Exception:
            raise ValueError(
                "Domain must be a valid (and installed) Gym environment")
Code example #5
def main():
    env = BlackjackEnv()

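    # On-policy Monte Carlo control with an epsilon-greedy policy
    # (epsilon = 0.1) over 500,000 episodes.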
    Q, policy = mc_control_epsilon_greedy(env,
                                          num_episodes=500000,
                                          epsilon=0.1)

    # For plotting: Create value function from action-value function
    # by picking the best action at each state
    V = defaultdict(float)
    for state, actions in Q.items():
        action_value = np.max(actions)
        V[state] = action_value
    plotting.plot_value_function(V, title="Optimal Value Function")
Code example #6
    Z_noace = np.apply_along_axis(lambda s: V[(s[0], s[1], False)], 2,
                                  np.dstack([X, Y]))
    Z_ace = np.apply_along_axis(lambda s: V[(s[0], s[1], True)], 2,
                                np.dstack([X, Y]))

    def plot_surface(X, Y, Z, title):
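        # Draw one 3-D value surface over (player sum, dealer showing).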
        fig = plt.figure(figsize=(20, 10))
        ax = fig.add_subplot(111, projection='3d')
        surf = ax.plot_surface(X,
                               Y,
                               Z,
                               rstride=1,
                               cstride=1,
                               cmap=matplotlib.cm.coolwarm,
                               vmin=-1.0,
                               vmax=1.0)
        ax.set_xlabel('Player Sum')
        ax.set_ylabel('Dealer Showing')
        ax.set_zlabel('Value')
        ax.set_title(title)
        ax.view_init(ax.elev, -120)
        fig.colorbar(surf)
        plt.show()

    plot_surface(X, Y, Z_noace, "{} (No Usable Ace)".format(title))
    plot_surface(X, Y, Z_ace, "{} (Usable Ace)".format(title))


if __name__ == '__main__':
    env = BlackjackEnv()
    steps = 500000
    Q, policy = epsilon_monte_carlo(env, steps)
    V = get_value_function(Q, policy)
    plot_value_function(V)
Code example #7
import sys
import os
import gym
RF_REPO = "{}/distr/reinforcement-learning".format(os.environ["HOME"])
if RF_REPO not in sys.path:
    sys.path.append(RF_REPO)

from lib.envs.blackjack import BlackjackEnv
from collections import defaultdict
import numpy as np
from gym import spaces
from lib import plotting

env = BlackjackEnv()


def reverse_array(a):
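    # Reverse a 1-D array; equivalent to a[::-1].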
    assert len(a.shape) == 1
    return np.fliplr([a])[0]


def sample_policy(observation):
    """
    A policy that sticks if the player score is >= 20 and hits otherwise.
    """
    score, dealer_score, usable_ace = observation
    return 0 if score >= 20 else 1


policy = sample_policy
discount_factor = 1.0
Code example #8
from collections import defaultdict

import numpy as np

from lib import plotting
from lib.envs.blackjack import BlackjackEnv
from mdp.algorithms.double_q_learning import DoubleQLearning
from mdp.algorithms.mc_offline import McOfflinePolicy
from mdp.algorithms.mc_online import McOnline
from mdp.algorithms.q_learning import QLearning

if __name__ == '__main__':
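    # Train tabular Q-learning on Blackjack, then collapse the action-value
    # table to a state-value function (max over actions) for plotting.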
    iterator = QLearning(BlackjackEnv())
    iterator.run(500000)
    print(iterator.env.states.states)
    Q = iterator.env.to_v()
    V = defaultdict(float)
    for state, actions in Q.items():
        action_value = np.max(actions)
        V[state] = action_value
    plotting.plot_value_function(V, title="Optimal Value Function")
Code example #9
from collections import defaultdict

import numpy as np
from mdp.algorithms.double_q_learning import DoubleQLearning
from mdp.algorithms.mc_offline import McOfflinePolicy
from mdp.algorithms.mc_online import McOnline
from mdp.algorithms.off_n_step_sarsa import OffNStepSarsa
from lib import plotting
from lib.envs.blackjack import BlackjackEnv
from mdp.algorithms.q_learning import QLearning

if __name__ == '__main__':
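    # Same pipeline as above, with off-policy 3-step Sarsa in place of
    # Q-learning.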
    iterator = OffNStepSarsa(BlackjackEnv(), 3)
    iterator.run(500000, epsilon=0.3, learning_rate=0.5)
    print(iterator.env.states.states)
    Q = iterator.env.to_v()
    V = defaultdict(float)
    for state, actions in Q.items():
        action_value = np.max(actions)
        V[state] = action_value
    plotting.plot_value_function(V, title="Optimal Value Function")
Code example #10
    def setUp(self):
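        # Threshold policy: stick (0) when the player sum is >= 20, otherwise
        # hit (1); the environment is built in test mode.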
        self.sample_policy = lambda observation: 0 if observation[0] >= 20 else 1
        self.env = BlackjackEnv(test=True)
Code example #11
File: testenv.py  Project: vicpang/intro_rf
import numpy as np
import sys
if "../" not in sys.path:
    sys.path.append("../")
from lib.envs.blackjack import BlackjackEnv

env = BlackjackEnv()


def print_observation(observation):
    score, dealer_score, usable_ace = observation
    print("Player Score: {} (Usable Ace: {}), Dealer Score: {}".format(
          score, usable_ace, dealer_score))

def strategy(observation):
    score, dealer_score, usable_ace = observation
    # Stick (action 0) if the score is >= 20, hit (action 1) otherwise
    return 0 if score >= 20 else 1

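# Roll out 20 episodes under the fixed threshold strategy, printing each step.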
for i_episode in range(20):
    observation = env.reset()
    for t in range(100):
        print_observation(observation)
        action = strategy(observation)
        print("Taking action: {}".format( ["Stick", "Hit"][action]))
        observation, reward, done, _ = env.step(action)
        if done:
            print_observation(observation)
            print("Game end. Reward: {}\n".format(float(reward)))
            break
Code example #12
    def test_q_values(self):
        np.random.seed(0)
        env = BlackjackEnv(test=True)
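        # Q-values expected (to two decimals) from off-policy MC control with
        # importance sampling, given the fixed seed and test-mode environment.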
        expected_q_values = {
            (14, 10, False): [-0.55667244, -0.55666552],
            (20, 10, False): [0.85307258, 0.],
            (19, 10, False): [-0.03771662, -0.03913207],
            (15, 3, True): [-0.08365019, -0.06177606],
            (17, 3, True): [-0.09180328, -0.02758621],
            (18, 7, False): [0.80474543, 0.],
            (21, 9, True): [1.87136564, 0.],
            (15, 10, False): [-0.57202554, -0.57083906],
            (13, 5, False): [-0.19646018, -0.19606004],
            (17, 7, False): [-0.21999044, -0.22351798],
            (21, 4, True): [1.74716981, 0.],
            (21, 4, False): [1.75901639, 0.],
            (20, 6, False): [1.39322034, 0.],
            (16, 5, False): [-0.22599418, -0.22754491],
            (20, 2, False): [1.30929919, 0.],
            (17, 10, False): [-0.53484487, -0.53445418],
            (18, 10, False): [-0.36599388, -0.36597938],
            (18, 10, True): [-0.17896389, -0.17784711],
            (13, 10, False): [-0.50639033, -0.50642118],
            (12, 3, False): [-0.27648931, -0.27795976],
            (21, 3, False): [1.77389985, 0.],
            (20, 1, False): [0.27903732, 0.],
            (19, 6, False): [0.93563102, -0.00103093],
            (20, 3, False): [1.35117278, 0.],
            (16, 6, False): [-0.27862595, -0.27877121],
            (16, 8, False): [-0.49311164, -0.49056604],
            (15, 6, False): [-0.18892508, -0.19052002],
            (13, 4, False): [-0.22846782, -0.2326228],
            (16, 4, False): [-0.27368421, -0.27437859],
            (15, 2, False): [-0.37379753, -0.37822878],
            (12, 10, False): [-0.49994324, -0.49988649],
            (13, 9, True): [-0.30493274, -0.2962963],
            (13, 9, False): [-0.43541762, -0.43526171],
            (14, 9, False): [-0.49571106, -0.49582947],
            (16, 9, False): [-0.5288868, -0.52892562],
            (18, 5, False): [0.39844886, -0.00187705],
            (19, 3, False): [0.79941719, 0.],
            (20, 7, True): [1.56769596, 0.],
            (20, 4, True): [1.31073446, 0.],
            (16, 3, False): [-0.31508805, -0.31433998],
            (15, 4, True): [-0.19305019, -0.05042017],
            (20, 7, False): [1.57104011e+00, -1.36147039e-03],
            (17, 6, False): [0.02804642, -0.01353311],
            (12, 2, True): [-0.14, 0.10714286],
            (12, 8, False): [-0.3288807, -0.32779317],
            (16, 10, False): [-0.57658694, -0.57661009],
            (13, 5, True): [-0.05882353, 0.36078431],
            (16, 2, False): [-0.36167076, -0.36104513],
            (13, 2, True): [-0.04918033, 0.27705628],
            (13, 2, False): [-0.30961183, -0.31088561],
            (17, 8, False): [-0.43448276, -0.43632869],
            (20, 4, False): [1.32866667, 0.],
            (21, 8, True): [1.89107981, 0.],
            (14, 6, True): [-0.25104603, -0.26771654],
            (14, 1, False): [-0.68097015, -0.68043088],
            (18, 8, False): [0.2268431, 0.],
            (19, 4, False): [0.83059548, -0.00294985],
            (14, 2, False): [-0.30662983, -0.30350554],
            (18, 4, False): [0.35192308, -0.00196175],
            (21, 10, False): [1.787005, 0.],
            (12, 9, False): [-0.42573821, -0.42732049],
            (17, 9, False): [-0.48977695, -0.48971784],
            (18, 9, False): [-0.29426189, -0.2952381],
            (20, 5, False): [1.36117768, 0.],
            (19, 10, True): [-0.02996255, -0.05882353],
            (20, 10, True): [0.89829352, 0.],
            (15, 8, False): [-0.4364667, -0.43577982],
            (16, 3, True): [-0.19157088, -0.17161716],
            (16, 2, True): [-0.22900763, -0.24334601],
            (20, 9, False): [1.52642706, 0.],
            (15, 4, False): [-0.31670481, -0.31813953],
            (18, 6, False): [0.59704433, 0.],
            (16, 9, True): [-0.16494845, -0.14915254],
            (18, 2, False): [0.26942482, -0.00291971],
            (19, 2, True): [0.84726225, 0.],
            (21, 7, False): [1.85587045, 0.],
            (12, 6, False): [-0.21556886, -0.21290323],
            (13, 6, False): [-0.17260274, -0.17280453],
            (12, 4, True): [-0.01550388, 0.16666667],
            (21, 1, False): [1.30333592, 0.],
            (14, 8, False): [-0.39443155, -0.3945157],
            (20, 8, False): [1.57485637, 0.],
            (14, 5, False): [-0.22222222, -0.2221231],
            (15, 5, False): [-0.19739292, -0.20194535],
            (21, 9, False): [1.89735365, 0.],
            (16, 1, True): [-0.65562914, -0.66420664],
            (17, 5, False): [-0.07304181, -0.07662464],
            (21, 5, False): [1.77480315, 0.],
            (13, 7, False): [-0.41295547, -0.41235241],
            (16, 10, True): [-0.36783734, -0.36730946],
            (18, 3, False): [0.27807487, -0.00099108],
            (16, 8, True): [-0.02013423, 0.07604563],
            (16, 1, False): [-0.72951208, -0.72945522],
            (19, 2, False): [0.79284963, -0.00289995],
            (14, 10, True): [-0.40317776, -0.38410596],
            (12, 2, False): [-0.30944774, -0.30908269],
            (13, 3, False): [-0.31129477, -0.31302801],
            (14, 4, False): [-0.23059256, -0.23224044],
            (19, 8, False): [1.25283391, -0.00195027],
            (18, 1, True): [-0.51465798, -0.50666667],
            (16, 4, True): [-0.00746269, 0.1294964],
            (18, 1, False): [-0.48914616, -0.48893167],
            (21, 5, True): [1.76222435, 0.],
            (19, 9, False): [0.61055777, 0.],
            (21, 10, True): [1.78874539, 0.],
            (17, 1, False): [-0.67257509, -0.67368421],
            (15, 9, False): [-0.48830683, -0.4886313],
            (13, 7, True): [-0.08695652, 0.09022556],
            (20, 5, True): [1.44038929, 0.],
            (12, 7, True): [-0.3442623, -0.31007752],
            (12, 7, False): [-0.3150022, -0.31488203],
            (15, 3, False): [-0.33017975, -0.33074463],
            (14, 5, True): [-0.07659574, 0.09302326],
            (13, 1, False): [-0.6660542, -0.66544622],
            (12, 1, False): [-0.67491166, -0.66696468],
            (15, 7, False): [-0.40413112, -0.40112729],
            (21, 2, True): [1.76015109, 0.],
            (15, 1, False): [-0.7369403, -0.73708069],
            (21, 3, True): [1.81576448, 0.],
            (16, 6, True): [-0.07352941, 0.25559105],
            (14, 7, True): [-0.24590164, -0.07017544],
            (16, 7, False): [-0.43923445, -0.43661972],
            (13, 8, False): [-0.39911894, -0.39716312],
            (14, 3, False): [-0.2665424, -0.26531552],
            (17, 3, False): [-0.14574518, -0.14522059],
            (17, 2, False): [-0.26778243, -0.27407762],
            (18, 9, True): [-0.19016393, -0.22960725],
            (21, 8, False): [1.86946011, 0.],
            (12, 4, False): [-0.23551229, -0.23282783],
            (17, 4, False): [-0.14487117, -0.14627011],
            (17, 6, True): [-0.03870968, 0.16828479],
            (21, 2, False): [1.79160187, 0.],
            (14, 1, True): [-0.5785124, -0.58039216],
            (19, 7, False): [1.20190275, 0.],
            (19, 1, False): [-0.19358074, -0.1960396],
            (19, 5, False): [0.8641115, 0.],
            (14, 7, False): [-0.43006834, -0.43024894],
            (19, 9, True): [0.48, 0.],
            (21, 7, True): [1.8575152, 0.],
            (17, 10, True): [-0.35404255, -0.34064081],
            (17, 8, True): [-0.02749141, 0.23255814],
            (14, 6, False): [-0.17974453, -0.18083671],
            (21, 6, True): [1.83101045, 0.],
            (15, 10, True): [-0.35166994, -0.35151515],
            (14, 8, True): [-0.2578125, -0.03305785],
            (21, 1, True): [1.24786325, 0.],
            (19, 1, True): [-0.14044944, -0.15384615],
            (12, 9, True): [-0.24299065, -0.24778761],
            (17, 9, True): [-0.26058632, -0.26751592],
            (16, 5, True): [-0.07560137, -0.10332103],
            (12, 5, False): [-0.2344519, -0.22790489],
            (17, 7, True): [-0.02076125, 0.10996564],
            (19, 7, True): [1.31855956, 0.],
            (14, 2, True): [-0.06837607, 0.08510638],
            (15, 2, True): [-0.0661157, 0.01459854],
            (20, 1, True): [0.32608696, -0.01041667],
            (18, 5, True): [0.38235294, -0.01246106],
            (19, 5, True): [0.83888889, 0.],
            (14, 9, True): [-0.06896552, 0.14096916],
            (13, 6, True): [-0.10569106, 0.22110553],
            (21, 6, False): [1.78449612, 0.],
            (17, 4, True): [0.02572347, -0.09427609],
            (15, 7, True): [-0.048, 0.09677419],
            (15, 1, True): [-0.66666667, -0.625],
            (17, 1, True): [-0.5483871, -0.53061224],
            (15, 5, True): [-0.11673152, -0.09266409],
            (18, 6, True): [-0.00645161, 0.15584416],
            (18, 8, True): [0.30083565, 0.],
            (20, 8, True): [1.62871287, 0.],
            (13, 10, True): [-0.47659574, -0.48],
            (20, 6, True): [1.49431818, 0.],
            (15, 6, True): [-0.14516129, -0.10355987],
            (20, 2, True): [-0.00544959, 0.53107345],
            (20, 9, True): [-0.00527704, 0.10989011],
            (13, 8, True): [-0.02564103, 0.37383178],
            (13, 4, True): [-0.09338521, 0.18333333],
            (13, 1, True): [-0.51428571, -0.46280992],
            (12, 3, True): [-0.01680672, 0.15873016],
            (13, 3, True): [-0.22900763, -0.23423423],
            (17, 5, True): [-0.04137931, 0.27272727],
            (15, 9, True): [-0.3153527, -0.3153527],
            (18, 2, True): [0.24342105, -0.01384083],
            (14, 3, True): [-0.00873362, 0.13278008],
            (12, 6, True): [-0.08928571, 0.17391304],
            (18, 4, True): [0.28382838, 0.],
            (16, 7, True): [-0.10691824, 0.12456747],
            (15, 8, True): [-0.19920319, -0.08658009],
            (12, 8, True): [-0.14141414, 0.03539823],
            (14, 4, True): [-0.07582938, 0.03524229],
            (19, 4, True): [0.86980609, 0.],
            (19, 3, True): [0.81632653, -0.0247678],
            (19, 6, True): [0.95468278, -0.01223242],
            (18, 7, True): [1.03135889, -0.0125],
            (17, 2, True): [-0.16891892, -0.1384083],
            (20, 3, True): [1.36778116, 0.],
            (18, 3, True): [0.26911315, 0.],
            (19, 8, True): [1.23646724, 0.],
            (12, 5, True): [-0.01652893, 0.31858407],
            (12, 1, True): [-0.32653061, -0.30252101],
            (12, 10, True): [-0.47413793, -0.41322314]
        }

        random_policy = create_random_policy(env.action_space.n)
        Q, _ = mc_control_importance_sampling(env,
                                              num_episodes=500000,
                                              behavior_policy=random_policy)
        self.assert_float_dict_almost_equal(expected_q_values, Q, decimal=2)