def test_blackjack(self):
    """Learn Blackjack with tabular Sarsa and check the usable-ace policy.

    Runs Sarsa for 30M episodes with learning_rate=1, then compares the
    greedy action table extracted by BlackjackGetter (1 = hit, 0 = stick)
    against a known-good policy table.
    """
    # Known-good hit/stick decisions; the row/column layout is whatever
    # BlackjackGetter.get() produces (presumably player sum x dealer
    # card -- TODO confirm against BlackjackGetter).
    expect_usable_a = [
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
    ]
    iterator = Sarsa(BlackjackEnv())
    iterator.run(30000000, learning_rate=1)
    # Dropped a debug print of the raw state table and dead commented-out
    # code (an older BlackjackEnvironment/monte_carlo_es driver).
    self.assertEqual(expect_usable_a, BlackjackGetter(iterator.env).get())
def test_blackjack(self):
    """Learn Blackjack with off-policy n-step Sarsa (n=3) and check the policy.

    Runs 3M episodes, then compares the greedy action table extracted by
    BlackjackGetter (1 = hit, 0 = stick) against a known-good policy table.
    """
    # Known-good hit/stick decisions; the row/column layout is whatever
    # BlackjackGetter.get() produces (presumably player sum x dealer
    # card -- TODO confirm against BlackjackGetter).
    expect_usable_a = [
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
    ]
    iterator = OffNStepSarsa(BlackjackEnv(), 3)  # second arg: n-step horizon
    iterator.run(3000000, learning_rate=0.5, epsilon=0.3)
    print(iterator.env.states.states)  # debug dump of the learned table
    # Plays and displays one episode under the learned policy (side
    # effects unknown from here -- NOTE(review): runs before the assert).
    iterator.show_one_episode()
    self.assertEqual(expect_usable_a, BlackjackGetter(iterator.env).get())
def main():
    """Run first-visit MC prediction on Blackjack at two episode budgets.

    Produces one value-function surface plot after 10k episodes and
    another after 500k episodes, both under the fixed sample_policy.
    """
    environment = BlackjackEnv()
    for episode_count, plot_title in ((10000, "10,000 Steps"),
                                      (500000, "500,000 Steps")):
        value_estimate = mc_prediction(sample_policy, environment,
                                       num_episodes=episode_count)
        plotting.plot_value_function(value_estimate, title=plot_title)
def getEnv(domain):
    """Resolve an environment name to a freshly constructed environment.

    Known local names map to the bundled environments; any other string
    is forwarded to ``gym.make``.

    :param domain: environment name, e.g. "Blackjack" or a Gym env id.
    :returns: an instantiated environment.
    :raises AssertionError: if the name is neither local nor a valid,
        installed Gym environment (same exception type and message as
        the original ``assert False`` path, so callers are unaffected).
    """
    local_envs = {
        "Blackjack": BlackjackEnv,
        "Gridworld": GridworldEnv,
        "CliffWalking": CliffWalkingEnv,
        "WindyGridworld": WindyGridworldEnv,
    }
    if domain in local_envs:
        return local_envs[domain]()
    try:
        return gym.make(domain)
    except Exception as err:
        # Original used a bare `except:` plus `assert False, ...`: the bare
        # except also swallowed SystemExit/KeyboardInterrupt, and asserts
        # are stripped under `python -O`. Raise explicitly instead.
        raise AssertionError(
            "Domain must be a valid (and installed) Gym environment") from err
def main():
    """Run epsilon-greedy MC control on Blackjack and plot the result.

    Learns an action-value table over 500k episodes, reduces it to a
    state-value table by taking each state's best action value, and
    renders the value surface.
    """
    env = BlackjackEnv()
    Q, policy = mc_control_epsilon_greedy(env, num_episodes=500000, epsilon=0.1)
    # Collapse the action-value table to state values: the value of a
    # state is the value of its greedy action.
    V = defaultdict(float)
    V.update({state: np.max(action_values)
              for state, action_values in Q.items()})
    plotting.plot_value_function(V, title="Optimal Value Function")
np.dstack([X, Y])) def plot_surface(X, Y, Z, title): fig = plt.figure(figsize=(20, 10)) ax = fig.add_subplot(111, projection='3d') surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=matplotlib.cm.coolwarm, vmin=-1.0, vmax=1.0) ax.set_xlabel('Player Sum') ax.set_ylabel('Dealer Showing') ax.set_zlabel('Value') ax.set_title(title) ax.view_init(ax.elev, -120) fig.colorbar(surf) plt.show() plot_surface(X, Y, Z_noace, "{} (No Usable Ace)".format(title)) plot_surface(X, Y, Z_ace, "{} (Usable Ace)".format(title)) if __name__ == '__main__': env = BlackjackEnv() steps = 500000 Q, policy = epsilon_monte_carlo(env, steps) V = get_value_function(Q, policy) plot_value_function(V)
import sys
import os
import gym

# Make the local reinforcement-learning checkout importable; this must
# run before the `lib.*` imports below, so import order is significant.
RF_REPO = "{}/distr/reinforcement-learning".format(os.environ["HOME"])
if RF_REPO not in sys.path:
    sys.path.append(RF_REPO)
from lib.envs.blackjack import BlackjackEnv
from collections import defaultdict
import numpy as np
from gym import spaces
from lib import plotting

# Shared module-level environment used by the prediction helpers.
env = BlackjackEnv()


def reverse_array(a):
    """Return a reversed copy of a 1-D numpy array.

    Replaces the original ``np.fliplr([a])[0]`` round-trip through a
    temporary 2-D array with the direct slice idiom; ``.copy()`` keeps
    the original semantics of returning an array independent of ``a``.
    """
    assert len(a.shape) == 1, "reverse_array expects a 1-D array"
    return a[::-1].copy()


def sample_policy(observation):
    """A policy that sticks if the player score is >= 20 and hits otherwise."""
    score, dealer_score, usable_ace = observation
    return 0 if score >= 20 else 1


policy = sample_policy
discount_factor = 1.0  # undiscounted returns, as usual for episodic Blackjack
from collections import defaultdict
import numpy as np
from lib import plotting
from lib.envs.blackjack import BlackjackEnv
from mdp.algorithms.double_q_learning import DoubleQLearning
from mdp.algorithms.mc_offline import McOfflinePolicy
from mdp.algorithms.mc_online import McOnline
from mdp.algorithms.q_learning import QLearning

if __name__ == '__main__':
    # Train tabular Q-learning on Blackjack for 500k episodes.
    learner = QLearning(BlackjackEnv())
    learner.run(500000)
    print(learner.env.states.states)
    # Reduce the learned action values to state values: each state takes
    # the value of its best action.
    action_values = learner.env.to_v()
    state_values = defaultdict(float)
    for state, per_action in action_values.items():
        state_values[state] = np.max(per_action)
    plotting.plot_value_function(state_values, title="Optimal Value Function")
from collections import defaultdict
import sys

import numpy as np

from lib import plotting
from lib.envs.blackjack import BlackjackEnv
from mdp.algorithms.double_q_learning import DoubleQLearning
from mdp.algorithms.mc_offline import McOfflinePolicy
from mdp.algorithms.mc_online import McOnline
from mdp.algorithms.off_n_step_sarsa import OffNStepSarsa
from mdp.algorithms.q_learning import QLearning

if __name__ == '__main__':
    # Removed a leftover debug `print(sys.path)`; imports regrouped
    # stdlib / third-party / local (the sys import is retained).
    # Train off-policy 3-step Sarsa on Blackjack for 500k episodes.
    iterator = OffNStepSarsa(BlackjackEnv(), 3)
    iterator.run(500000, epsilon=0.3, learning_rate=0.5)
    print(iterator.env.states.states)
    # Collapse the action-value table to state values (best action per state).
    Q = iterator.env.to_v()
    V = defaultdict(float)
    for state, actions in Q.items():
        V[state] = np.max(actions)
    plotting.plot_value_function(V, title="Optimal Value Function")
def setUp(self):
    """Create the deterministic test environment and a stick-on-20 policy."""
    def stick_on_twenty(observation):
        # Stick (0) once the player's sum reaches 20, otherwise hit (1).
        return 0 if observation[0] >= 20 else 1

    self.sample_policy = stick_on_twenty
    self.env = BlackjackEnv(test=True)
import numpy as np
import sys
if "../" not in sys.path:
    sys.path.append("../")
from lib.envs.blackjack import BlackjackEnv

env = BlackjackEnv()


def print_observation(observation):
    """Log a (player score, dealer score, usable ace) observation."""
    score, dealer_score, usable_ace = observation
    print("Player Score: {} (Usable Ace: {}), Dealer Score: {}".format(
        score, usable_ace, dealer_score))


def strategy(observation):
    """Fixed policy: stick (0) once the player score reaches 20, else hit (1)."""
    score, dealer_score, usable_ace = observation
    return 0 if score >= 20 else 1


# Play 20 episodes under the fixed policy, logging every transition.
# The inner bound of 100 steps is just a safety cap; Blackjack episodes
# terminate far sooner.
for episode_index in range(20):
    observation = env.reset()
    for _step in range(100):
        print_observation(observation)
        chosen_action = strategy(observation)
        print("Taking action: {}".format(
            ["Stick", "Hit"][chosen_action]))
        observation, reward, done, _ = env.step(chosen_action)
        if done:
            print_observation(observation)
            print("Game end. Reward: {}\n".format(float(reward)))
            break
def test_q_values(self):
    """Regression test: off-policy MC control with importance sampling.

    Seeds numpy's global RNG, runs mc_control_importance_sampling for
    500k episodes under a uniform random behavior policy, and compares
    the learned Q table against a frozen snapshot to 2 decimal places.
    """
    np.random.seed(0)  # fixed seed so the snapshot below is reproducible
    env = BlackjackEnv(test=True)
    # Frozen Q-value snapshot keyed by (player sum, dealer card, usable ace);
    # each value is a 2-element per-action list (presumably [stick, hit] --
    # TODO confirm against the environment's action encoding).
    expected_q_values = {
        (14, 10, False): [-0.55667244, -0.55666552], (20, 10, False): [0.85307258, 0.], (19, 10, False): [-0.03771662, -0.03913207], (15, 3, True): [-0.08365019, -0.06177606], (17, 3, True): [-0.09180328, -0.02758621], (18, 7, False): [0.80474543, 0.], (21, 9, True): [1.87136564, 0.], (15, 10, False): [-0.57202554, -0.57083906], (13, 5, False): [-0.19646018, -0.19606004], (17, 7, False): [-0.21999044, -0.22351798], (21, 4, True): [1.74716981, 0.], (21, 4, False): [1.75901639, 0.], (20, 6, False): [1.39322034, 0.], (16, 5, False): [-0.22599418, -0.22754491], (20, 2, False): [1.30929919, 0.], (17, 10, False): [-0.53484487, -0.53445418], (18, 10, False): [-0.36599388, -0.36597938], (18, 10, True): [-0.17896389, -0.17784711], (13, 10, False): [-0.50639033, -0.50642118], (12, 3, False): [-0.27648931, -0.27795976], (21, 3, False): [1.77389985, 0.], (20, 1, False): [0.27903732, 0.], (19, 6, False): [0.93563102, -0.00103093], (20, 3, False): [1.35117278, 0.], (16, 6, False): [-0.27862595, -0.27877121], (16, 8, False): [-0.49311164, -0.49056604], (15, 6, False): [-0.18892508, -0.19052002], (13, 4, False): [-0.22846782, -0.2326228], (16, 4, False): [-0.27368421, -0.27437859], (15, 2, False): [-0.37379753, -0.37822878], (12, 10, False): [-0.49994324, -0.49988649], (13, 9, True): [-0.30493274, -0.2962963], (13, 9, False): [-0.43541762, -0.43526171], (14, 9, False): [-0.49571106, -0.49582947], (16, 9, False): [-0.5288868, -0.52892562], (18, 5, False): [0.39844886, -0.00187705], (19, 3, False): [0.79941719, 0.], (20, 7, True): [1.56769596, 0.], (20, 4, True): [1.31073446, 0.], (16, 3, False): [-0.31508805, -0.31433998], (15, 4, True): [-0.19305019, -0.05042017], (20, 7, False): [1.57104011e+00, -1.36147039e-03], (17, 6, False): [0.02804642, -0.01353311], (12, 2, True): [-0.14, 0.10714286], (12, 8, False): [-0.3288807, -0.32779317], (16, 10, False): [-0.57658694, -0.57661009], (13, 5, True): [-0.05882353, 0.36078431],
        (16, 2, False): [-0.36167076, -0.36104513], (13, 2, True): [-0.04918033, 0.27705628], (13, 2, False): [-0.30961183, -0.31088561], (17, 8, False): [-0.43448276, -0.43632869], (20, 4, False): [1.32866667, 0.], (21, 8, True): [1.89107981, 0.], (14, 6, True): [-0.25104603, -0.26771654], (14, 1, False): [-0.68097015, -0.68043088], (18, 8, False): [0.2268431, 0.], (19, 4, False): [0.83059548, -0.00294985], (14, 2, False): [-0.30662983, -0.30350554], (18, 4, False): [0.35192308, -0.00196175], (21, 10, False): [1.787005, 0.], (12, 9, False): [-0.42573821, -0.42732049], (17, 9, False): [-0.48977695, -0.48971784], (18, 9, False): [-0.29426189, -0.2952381], (20, 5, False): [1.36117768, 0.], (19, 10, True): [-0.02996255, -0.05882353], (20, 10, True): [0.89829352, 0.], (15, 8, False): [-0.4364667, -0.43577982], (16, 3, True): [-0.19157088, -0.17161716], (16, 2, True): [-0.22900763, -0.24334601], (20, 9, False): [1.52642706, 0.], (15, 4, False): [-0.31670481, -0.31813953], (18, 6, False): [0.59704433, 0.], (16, 9, True): [-0.16494845, -0.14915254], (18, 2, False): [0.26942482, -0.00291971], (19, 2, True): [0.84726225, 0.], (21, 7, False): [1.85587045, 0.], (12, 6, False): [-0.21556886, -0.21290323], (13, 6, False): [-0.17260274, -0.17280453], (12, 4, True): [-0.01550388, 0.16666667], (21, 1, False): [1.30333592, 0.], (14, 8, False): [-0.39443155, -0.3945157], (20, 8, False): [1.57485637, 0.], (14, 5, False): [-0.22222222, -0.2221231], (15, 5, False): [-0.19739292, -0.20194535], (21, 9, False): [1.89735365, 0.], (16, 1, True): [-0.65562914, -0.66420664], (17, 5, False): [-0.07304181, -0.07662464], (21, 5, False): [1.77480315, 0.], (13, 7, False): [-0.41295547, -0.41235241], (16, 10, True): [-0.36783734, -0.36730946], (18, 3, False): [0.27807487, -0.00099108], (16, 8, True): [-0.02013423, 0.07604563], (16, 1, False): [-0.72951208, -0.72945522], (19, 2, False): [0.79284963, -0.00289995], (14, 10, True): [-0.40317776, -0.38410596], (12, 2, False): [-0.30944774, -0.30908269],
        (13, 3, False): [-0.31129477, -0.31302801], (14, 4, False): [-0.23059256, -0.23224044], (19, 8, False): [1.25283391, -0.00195027], (18, 1, True): [-0.51465798, -0.50666667], (16, 4, True): [-0.00746269, 0.1294964], (18, 1, False): [-0.48914616, -0.48893167], (21, 5, True): [1.76222435, 0.], (19, 9, False): [0.61055777, 0.], (21, 10, True): [1.78874539, 0.], (17, 1, False): [-0.67257509, -0.67368421], (15, 9, False): [-0.48830683, -0.4886313], (13, 7, True): [-0.08695652, 0.09022556], (20, 5, True): [1.44038929, 0.], (12, 7, True): [-0.3442623, -0.31007752], (12, 7, False): [-0.3150022, -0.31488203], (15, 3, False): [-0.33017975, -0.33074463], (14, 5, True): [-0.07659574, 0.09302326], (13, 1, False): [-0.6660542, -0.66544622], (12, 1, False): [-0.67491166, -0.66696468], (15, 7, False): [-0.40413112, -0.40112729], (21, 2, True): [1.76015109, 0.], (15, 1, False): [-0.7369403, -0.73708069], (21, 3, True): [1.81576448, 0.], (16, 6, True): [-0.07352941, 0.25559105], (14, 7, True): [-0.24590164, -0.07017544], (16, 7, False): [-0.43923445, -0.43661972], (13, 8, False): [-0.39911894, -0.39716312], (14, 3, False): [-0.2665424, -0.26531552], (17, 3, False): [-0.14574518, -0.14522059], (17, 2, False): [-0.26778243, -0.27407762], (18, 9, True): [-0.19016393, -0.22960725], (21, 8, False): [1.86946011, 0.], (12, 4, False): [-0.23551229, -0.23282783], (17, 4, False): [-0.14487117, -0.14627011], (17, 6, True): [-0.03870968, 0.16828479], (21, 2, False): [1.79160187, 0.], (14, 1, True): [-0.5785124, -0.58039216], (19, 7, False): [1.20190275, 0.], (19, 1, False): [-0.19358074, -0.1960396], (19, 5, False): [0.8641115, 0.], (14, 7, False): [-0.43006834, -0.43024894], (19, 9, True): [0.48, 0.], (21, 7, True): [1.8575152, 0.], (17, 10, True): [-0.35404255, -0.34064081], (17, 8, True): [-0.02749141, 0.23255814], (14, 6, False): [-0.17974453, -0.18083671], (21, 6, True): [1.83101045, 0.], (15, 10, True): [-0.35166994, -0.35151515], (14, 8, True): [-0.2578125, -0.03305785],
        (21, 1, True): [1.24786325, 0.], (19, 1, True): [-0.14044944, -0.15384615], (12, 9, True): [-0.24299065, -0.24778761], (17, 9, True): [-0.26058632, -0.26751592], (16, 5, True): [-0.07560137, -0.10332103], (12, 5, False): [-0.2344519, -0.22790489], (17, 7, True): [-0.02076125, 0.10996564], (19, 7, True): [1.31855956, 0.], (14, 2, True): [-0.06837607, 0.08510638], (15, 2, True): [-0.0661157, 0.01459854], (20, 1, True): [0.32608696, -0.01041667], (18, 5, True): [0.38235294, -0.01246106], (19, 5, True): [0.83888889, 0.], (14, 9, True): [-0.06896552, 0.14096916], (13, 6, True): [-0.10569106, 0.22110553], (21, 6, False): [1.78449612, 0.], (17, 4, True): [0.02572347, -0.09427609], (15, 7, True): [-0.048, 0.09677419], (15, 1, True): [-0.66666667, -0.625], (17, 1, True): [-0.5483871, -0.53061224], (15, 5, True): [-0.11673152, -0.09266409], (18, 6, True): [-0.00645161, 0.15584416], (18, 8, True): [0.30083565, 0.], (20, 8, True): [1.62871287, 0.], (13, 10, True): [-0.47659574, -0.48], (20, 6, True): [1.49431818, 0.], (15, 6, True): [-0.14516129, -0.10355987], (20, 2, True): [-0.00544959, 0.53107345], (20, 9, True): [-0.00527704, 0.10989011], (13, 8, True): [-0.02564103, 0.37383178], (13, 4, True): [-0.09338521, 0.18333333], (13, 1, True): [-0.51428571, -0.46280992], (12, 3, True): [-0.01680672, 0.15873016], (13, 3, True): [-0.22900763, -0.23423423], (17, 5, True): [-0.04137931, 0.27272727], (15, 9, True): [-0.3153527, -0.3153527], (18, 2, True): [0.24342105, -0.01384083], (14, 3, True): [-0.00873362, 0.13278008], (12, 6, True): [-0.08928571, 0.17391304], (18, 4, True): [0.28382838, 0.], (16, 7, True): [-0.10691824, 0.12456747], (15, 8, True): [-0.19920319, -0.08658009], (12, 8, True): [-0.14141414, 0.03539823], (14, 4, True): [-0.07582938, 0.03524229], (19, 4, True): [0.86980609, 0.], (19, 3, True): [0.81632653, -0.0247678], (19, 6, True): [0.95468278, -0.01223242], (18, 7, True): [1.03135889, -0.0125], (17, 2, True): [-0.16891892, -0.1384083], (20, 3, True): [1.36778116, 0.],
        (18, 3, True): [0.26911315, 0.], (19, 8, True): [1.23646724, 0.], (12, 5, True): [-0.01652893, 0.31858407], (12, 1, True): [-0.32653061, -0.30252101], (12, 10, True): [-0.47413793, -0.41322314]
    }
    # Behavior policy: uniform random over the env's actions.
    random_policy = create_random_policy(env.action_space.n)
    Q, _ = mc_control_importance_sampling(env, num_episodes=500000,
                                          behavior_policy=random_policy)
    # Loose tolerance (2 decimals) because MC estimates are noisy.
    self.assert_float_dict_almost_equal(expected_q_values, Q, decimal=2)