plt.close(fig1) else: plt.show(fig1) # Plot the episode reward over time fig2 = plt.figure(figsize=(10,5)) rewards_smoothed = pd.Series(stats.episode_rewards).rolling(smoothing_window, min_periods=smoothing_window).mean() plt.plot(rewards_smoothed) plt.xlabel("Episode") plt.ylabel("Episode Reward (Smoothed)") plt.title("Episode Reward over Time (Smoothed over window size {})".format(smoothing_window)) if noshow: plt.close(fig2) else: plt.show(fig2) return fig1, fig2 if __name__ == "__main__": env = BlackjackEnv() random_policy = create_random_policy(env.nA) Q, _ = mc_control_importance_sampling(env, num_episodes=500000, behavior_policy=random_policy) V = defaultdict(float) for state, action_values in Q.items(): action_value = np.max(action_values) V[state] = action_value fig, axarr = plt.subplots(2, subplot_kw={'projection': '3d'}) plot_value_function(V, axarr, title="500.000 Steps") plt.show()
def test_q_values(self):
    """Check off-policy MC control against stored reference Q-values.

    Seeds NumPy's global RNG with 0, builds a ``BlackjackEnv`` in test
    mode, runs ``mc_control_importance_sampling`` for 500000 episodes
    under a two-action random behaviour policy and compares the returned
    Q mapping with the reference table to two decimal places.
    """
    # Reference table: one two-element list per 3-tuple state key.
    # NOTE(review): keys look like (player sum, dealer card, usable ace)
    # and the two numbers like per-action values -- confirm against the
    # environment definition.
    reference_q = {
        (18, 10, False): [-0.23533037475345092, -0.65069513406157],
        (20, 6, False): [0.6990585070611964, -0.8814504881450475],
        (19, 9, False): [0.23174294060370004, -0.74],
        (12, 9, False): [-0.5431985294117646, -0.29656419529837275],
        (17, 8, False): [-0.4034582132564843, -0.4707282246549266],
        (20, 9, True): [0.7628571428571427, 0.09944751381215464],
        (17, 4, False): [-0.12105751391465681, -0.5326237852845899],
        (13, 4, False): [-0.2312764955252001, -0.29011786038077975],
        (17, 1, False): [-0.6282051282051277, -0.6655389076848715],
        (13, 3, False): [-0.26743075453677173, -0.2716210343328985],
        (16, 6, False): [-0.10835322195704067, -0.4610136452241714],
        (20, 2, False): [0.6376912378303203, -0.8531152105812742],
        (20, 9, False): [0.7585848074921976, -0.8680203045685262],
        (21, 4, False): [0.8698830409356734, -1.0],
        (16, 1, False): [-0.7971721111652841, -0.6776007497656986],
        (15, 10, False): [-0.5712454852615625, -0.5446418205038894],
        (14, 5, False): [-0.13802816901408452, -0.3296193129062211],
        (14, 9, False): [-0.565416285452881, -0.4146797568957452],
        (13, 6, True): [-0.2627118644067797, 0.2666666666666668],
        (18, 10, True): [-0.1964996022275259, -0.1846153846153846],
        (18, 5, False): [0.2162293488824098, -0.6132542037586541],
        (21, 8, True): [0.922656960873521, 0.19824561403508784],
        (18, 3, False): [0.16515944788196105, -0.6347826086956525],
        (17, 3, False): [-0.13083213083213088, -0.5655934646804432],
        (20, 3, False): [0.6458835687220116, -0.8849165815457946],
        (19, 10, False): [-0.015288999378495947, -0.7419515847267275],
        (18, 1, False): [-0.36386386386386316, -0.6866096866096864],
        (14, 1, False): [-0.7548566142460688, -0.5900226757369622],
        (18, 9, False): [-0.2109337203676826, -0.5946843853820601],
        (14, 3, False): [-0.265155020823693, -0.3708165997322623],
        (16, 5, False): [-0.20906567992599415, -0.3809971777986831],
        (21, 5, True): [0.891840607210625, 0.3579545454545456],
        (20, 10, False): [0.4393263157894733, -0.8543532020124465],
        (15, 9, False): [-0.5498449268941082, -0.526289434151226],
        (16, 7, False): [-0.4891454965357966, -0.40610104861773155],
        (17, 10, False): [-0.47040094339622607, -0.6040100250626568],
        (19, 8, False): [0.5706627680311898, -0.7149829184968268],
        (17, 9, False): [-0.4056378404204498, -0.545539033457249],
        (13, 7, False): [-0.4789838337182441, -0.2375162831089877],
        (17, 5, False): [-0.02606177606177609, -0.5548141086749286],
        (13, 10, False): [-0.5757291788684543, -0.45740615868734547],
        (14, 4, False): [-0.21376146788990846, -0.31591448931116406],
        (12, 10, False): [-0.5766489421880333, -0.4186740077999534],
        (14, 10, False): [-0.5757015821688396, -0.5145094426531549],
        (16, 10, False): [-0.5621724630776569, -0.5581314477802275],
        (15, 2, False): [-0.3027101515847493, -0.4245939675174012],
        (16, 3, False): [-0.2693409742120339, -0.46125797629899834],
        (16, 2, False): [-0.2568509057129591, -0.4964166268514091],
        (16, 4, False): [-0.21896383186705776, -0.48490566037735827],
        (21, 2, True): [0.8876739562624256, 0.32078853046595013],
        (19, 1, False): [-0.1036889332003989, -0.7973231357552581],
        (19, 4, False): [0.4097258147956546, -0.7285067873303166],
        (21, 8, False): [0.9245723172628312, -1.0],
        (13, 8, False): [-0.524061810154526, -0.3537757437070942],
        (16, 8, False): [-0.5093139482053626, -0.4488497307880566],
        (15, 1, False): [-0.7456647398843936, -0.5942549371633743],
        (12, 5, False): [-0.14081996434937621, -0.1990846681922192],
        (15, 4, False): [-0.23978201634877452, -0.38447319778188516],
        (13, 6, False): [-0.15426997245179058, -0.18335619570187484],
        (21, 10, False): [0.8933922397233937, -1.0],
        (19, 3, False): [0.4270570418980312, -0.7368961973278526],
        (21, 1, False): [0.6351865955826352, -1.0],
        (20, 1, False): [0.15662650602409645, -0.8789297658862895],
        (20, 5, False): [0.6695088676671217, -0.8277925531914894],
        (18, 7, False): [0.3904576436222005, -0.5871954132823696],
        (20, 7, False): [0.7727583846680343, -0.8603089321692415],
        (21, 6, False): [0.8976631748589844, -1.0],
        (21, 10, True): [0.8835616438356135, 0.061950993989828944],
        (13, 10, True): [-0.5957219251336902, -0.03703703703703699],
        (12, 4, False): [-0.19615912208504802, -0.23269316659222872],
        (20, 4, False): [0.6774526678141133, -0.8362611866092168],
        (15, 10, True): [-0.5518913676042678, -0.19183673469387752],
        (15, 3, False): [-0.2550790067720088, -0.41751152073732706],
        (18, 2, False): [0.10782442748091597, -0.6102418207681367],
        (18, 4, False): [0.16543937162493852, -0.5975723622782455],
        (21, 4, True): [0.9018181818181826, 0.28173913043478266],
        (12, 3, False): [-0.2545931758530182, -0.26359447004608316],
        (13, 5, False): [-0.1932692307692308, -0.2750572082379863],
        (13, 1, False): [-0.7839059674502722, -0.5131052865393172],
        (16, 9, False): [-0.5247614720581559, -0.5062761506276147],
        (17, 9, True): [-0.42248062015503873, -0.18709677419354848],
        (21, 5, False): [0.8986852281515859, -1.0],
        (14, 2, False): [-0.27932960893854786, -0.39484777517564384],
        (18, 6, True): [0.22683706070287538, 0.19620253164556958],
        (15, 8, False): [-0.49580615097856545, -0.431539187913126],
        (15, 5, False): [-0.15183246073298437, -0.35169300225733613],
        (21, 9, False): [0.9377076411960139, -1.0],
        (12, 1, False): [-0.7730684326710832, -0.4712245781047172],
        (15, 6, False): [-0.15188679245283032, -0.353219696969697],
        (12, 8, False): [-0.515682656826568, -0.31906799809795494],
        (21, 7, False): [0.9144951140065153, -1.0],
        (21, 2, False): [0.8893344025661581, -1.0],
        (18, 6, False): [0.27338826951042144, -0.621114948199309],
        (20, 8, True): [0.7645259938837922, 0.15555555555555556],
        (12, 8, True): [-0.4285714285714286, 0.411764705882353],
        (12, 6, False): [-0.13891362422083722, -0.1832167832167831],
        (19, 6, False): [0.47623713865752093, -0.723625557206537],
        (19, 2, False): [0.37239979705733056, -0.7525150905432583],
        (19, 7, True): [0.6631578947368428, 0.25595238095238093],
        (20, 8, False): [0.7776617954070979, -0.8442622950819666],
        (17, 6, False): [0.04052165812761995, -0.5332403533240367],
        (14, 3, True): [-0.14716981132075468, -0.009433962264150898],
        (19, 7, False): [0.6104999999999992, -0.7406483790523696],
        (21, 3, False): [0.8782201405152218, -1.0],
        (16, 1, True): [-0.7762237762237766, -0.3028169014084505],
        (21, 1, True): [0.6742909423604755, -0.0976095617529882],
        (12, 7, False): [-0.44228157537347185, -0.1881818181818183],
        (15, 7, False): [-0.4949026876737715, -0.3318603623508622],
        (12, 10, True): [-0.5485327313769756, -0.1541666666666666],
        (17, 7, False): [-0.0702936928261917, -0.4908235294117645],
        (14, 6, False): [-0.11781076066790352, -0.28493150684931506],
        (16, 8, True): [-0.43772241992882555, -0.10967741935483877],
        (15, 6, True): [-0.15942028985507253, 0.17361111111111116],
        (14, 7, False): [-0.46685210941121913, -0.2968897266729508],
        (12, 2, False): [-0.30228471001757445, -0.26256458431188295],
        (17, 2, False): [-0.20066256507335523, -0.5998142989786454],
        (13, 2, False): [-0.2899628252788114, -0.35277516462841],
        (19, 5, False): [0.4169215086646282, -0.7076845806127565],
        (19, 8, True): [0.5273224043715847, 0.1807228915662651],
        (19, 1, True): [-0.10169491525423732, -0.2840236686390534],
        (20, 4, True): [0.626903553299492, 0.19170984455958548],
        (13, 4, True): [-0.2096774193548386, 0.2905982905982906],
        (17, 10, True): [-0.4914145543744891, -0.3162393162393163],
        (20, 1, True): [0.09164420485175204, -0.13089005235602094],
        (14, 4, True): [-0.11740890688259109, 0.22321428571428564],
        (13, 3, True): [-0.25345622119815664, 0.05434782608695648],
        (20, 6, True): [0.6878612716763011, 0.2857142857142856],
        (12, 5, True): [-0.19999999999999996, 0.2727272727272728],
        (19, 10, True): [0.008559201141226814, -0.17101449275362335],
        (16, 5, True): [-0.1184210526315789, 0.1486486486486487],
        (18, 8, False): [0.10637254901960781, -0.6134939759036155],
        (14, 8, False): [-0.5028546332894172, -0.3722763096893826],
        (14, 7, True): [-0.4942528735632186, -0.06153846153846153],
        (19, 9, True): [0.25867507886435326, -0.10447761194029857],
        (16, 9, True): [-0.5245283018867921, -0.14383561643835613],
        (21, 9, True): [0.9414455626715449, 0.10370370370370366],
        (13, 9, True): [-0.4615384615384616, 0.19148936170212763],
        (12, 6, True): [-0.20312499999999994, 0.1864406779661018],
        (21, 7, True): [0.9118457300275489, 0.252808988764045],
        (19, 5, True): [0.5297805642633239, 0.054216867469879554],
        (18, 1, True): [-0.36176470588235293, -0.42592592592592615],
        (21, 3, True): [0.8816169393647741, 0.2359767891682784],
        (15, 2, True): [-0.2845528455284551, -0.04065040650406501],
        (20, 3, True): [0.7316384180790965, 0.14942528735632185],
        (18, 7, True): [0.43181818181818166, 0.11695906432748544],
        (15, 7, True): [-0.47985347985347976, 0.06896551724137931],
        (12, 4, True): [-0.10091743119266058, 0.18181818181818177],
        (18, 8, True): [0.05014749262536869, 0.1079136690647482],
        (17, 2, True): [-0.1891891891891892, -0.1259842519685039],
        (17, 3, True): [-0.0899280575539568, 0.043209876543209895],
        (16, 10, True): [-0.58287795992714, -0.27560521415269995],
        (20, 10, True): [0.42847173761339813, -0.02462380300957595],
        (16, 2, True): [-0.362549800796813, -0.07575757575757579],
        (13, 9, False): [-0.50587211831231, -0.38563829787234005],
        (14, 1, True): [-0.8295964125560541, -0.1869158878504673],
        (18, 9, True): [-0.13504823151125409, -0.11764705882352944],
        (20, 5, True): [0.6820652173913053, 0.17708333333333337],
        (15, 5, True): [-0.25196850393700787, 0.027777777777777794],
        (20, 7, True): [0.7968337730870713, 0.1851851851851852],
        (16, 7, True): [-0.5053003533568905, -0.05673758865248227],
        (13, 7, True): [-0.4891774891774891, -0.017241379310344848],
        (12, 7, True): [-0.5419847328244272, 0.37333333333333335],
        (14, 10, True): [-0.547877591312932, -0.07954545454545457],
        (16, 3, True): [-0.19999999999999996, 0.027586206896551748],
        (15, 8, True): [-0.5502008032128513, -0.07913669064748201],
        (20, 2, True): [0.6198979591836736, 0.2848101265822785],
        (19, 6, True): [0.5433526011560694, 0.2336956521739131],
        (21, 6, True): [0.8909090909090909, 0.29304029304029283],
        (14, 9, True): [-0.5, -0.07272727272727275],
        (19, 4, True): [0.38855421686747016, 0.2530864197530865],
        (18, 2, True): [0.12871287128712872, 0.10457516339869276],
        (14, 2, True): [-0.29059829059829034, 0.025641025641025612],
        (15, 4, True): [-0.19215686274509802, -0.06086956521739138],
        (18, 4, True): [0.16279069767441848, 0.08284023668639053],
        (13, 1, True): [-0.7543859649122808, -0.3719008264462809],
        (18, 3, True): [0.08433734939759038, 0.20394736842105263],
        (16, 6, True): [-0.2666666666666665, -0.014598540145985398],
        (19, 3, True): [0.38601823708206706, 0.03428571428571428],
        (15, 9, True): [-0.6296296296296297, -0.12403100775193798],
        (13, 5, True): [-0.31225296442687717, 0.06060606060606062],
        (15, 1, True): [-0.7534246575342467, -0.47368421052631565],
        (17, 6, True): [-0.03859649122807019, 0.23333333333333334],
        (14, 6, True): [-0.146341463414634, 0.18584070796460178],
        (12, 1, True): [-0.7723577235772359, -0.2857142857142858],
        (15, 3, True): [-0.21481481481481488, 0.08088235294117646],
        (18, 5, True): [0.23262839879154074, 0.02366863905325443],
        (14, 8, True): [-0.561181434599156, -0.25833333333333325],
        (13, 8, True): [-0.6306306306306306, 0.11678832116788318],
        (13, 2, True): [-0.33333333333333326, 0.14999999999999988],
        (17, 5, True): [-0.043771043771043766, 0.04411764705882353],
        (12, 9, True): [-0.5238095238095237, -0.08333333333333337],
        (17, 1, True): [-0.6195652173913049, -0.3984962406015038],
        (12, 3, True): [-0.2982456140350876, -0.0888888888888889],
        (16, 4, True): [-0.1357142857142857, -0.04216867469879521],
        (19, 2, True): [0.457865168539326, 0.1381578947368422],
        (17, 4, True): [-0.20312500000000003, -0.07142857142857147],
        (17, 8, True): [-0.43050847457627134, -0.12592592592592594],
        (17, 7, True): [-0.11224489795918369, 0.12142857142857143],
        (12, 2, True): [-0.30097087378640786, 0.028169014084507],
        (14, 5, True): [-0.23320158102766794, 0.19587628865979384],
    }

    # Seed before anything that might draw random numbers so the run is
    # reproducible against the stored table.
    np.random.seed(0)
    blackjack = BlackjackEnv(test=True)
    behavior = create_random_policy(2)

    learned_q, _unused = mc_control_importance_sampling(
        blackjack,
        num_episodes=500000,
        behavior_policy=behavior,
    )

    self.assert_float_dict_almost_equal(reference_q, learned_q, decimal=2)
求解问题: 评估该策略的好坏。 求解过程 使用庄家显示的牌面值、玩家当前牌面总分值来确定一个二维状态空间,区分手中有无A分别处理。统计每一牌局下决定状态的庄家和玩家牌面的状态数据,同时计算其最终收获。通过模拟多次牌局,计算每一个状态下的平均值,得到如下图示。 最终结果 无论玩家手中是否有A牌,该策略在绝大多数情况下各状态价值都较低,只有在玩家拿到21分时状态价值有一个明显的提升。当前牌的分数(12 - 21),低于12时,你可以安全的再叫牌,所以没意义。 """ import numpy as np import time from blackjack import BlackjackEnv # 定义 21点 的环境 env = BlackjackEnv() # 显示一个observation的信息 def displayObservation(observation): # 将observation进行分解 score, dealerScore, usableAce = observation print("玩家分数为: {}(是否有可用的Ace: {}), 庄家的分数为: {}".format( score, usableAce, dealerScore)) # 显示当前observation对应的 策略 def policy(observation): # 将observation进行分解 score, dealerScore, usableSce = observation
returnn = self._visited_states_returns[key] new_val = \ state.value_pi + self.alpha * (returnn - state.value_pi) state.value_pi = new_val def remember(self, obs, reward): if obs not in self._visited_states_returns: self._visited_states_returns[obs] = 0 for obs in self._visited_states_returns: self._visited_states_returns[obs] += reward if __name__ == '__main__': env = BlackjackEnv() agent = BlackjackAgent() total_ep = 100000 for episode in range(total_ep): if episode % 1000 == 0: print('episode: {0} / {1}', episode, total_ep) # print() # print(' == EPISODE {0} START =='.format(episode)) obs = env.reset() agent.reset() # print('HAND: ', end=''); print(env.player_hand) # print('STATE pp={0}, ha={1}, dc={2}'.format(obs[0], obs[1], obs[2]))
def setUp(self):
    """Prepare the fixtures shared by the test methods.

    Stores a fixed sample policy on ``self.sample_policy`` and a
    ``BlackjackEnv`` created with ``test=True`` on ``self.env``.
    """

    def threshold_policy(observation):
        # Returns 0 once the first observation field reaches 20, else 1.
        # NOTE(review): presumably 0/1 are the env's two action ids and
        # observation[0] is the player's total -- confirm against the env.
        if observation[0] >= 20:
            return 0
        return 1

    self.sample_policy = threshold_policy
    self.env = BlackjackEnv(test=True)