Пример #1
0
    plt.close(fig1)
  else:
    plt.show(fig1)

  # Plot the episode reward over time
  fig2 = plt.figure(figsize=(10,5))
  rewards_smoothed = pd.Series(stats.episode_rewards).rolling(smoothing_window, min_periods=smoothing_window).mean()
  plt.plot(rewards_smoothed)
  plt.xlabel("Episode")
  plt.ylabel("Episode Reward (Smoothed)")
  plt.title("Episode Reward over Time (Smoothed over window size {})".format(smoothing_window))
  if noshow:
    plt.close(fig2)
  else:
    plt.show(fig2)

  return fig1, fig2

if __name__ == "__main__":
  # Learn Q with off-policy Monte Carlo control (weighted importance sampling
  # per the function name), using a behaviour policy built over env.nA actions.
  num_episodes = 500000
  env = BlackjackEnv()
  random_policy = create_random_policy(env.nA)
  Q, _ = mc_control_importance_sampling(env, num_episodes=num_episodes, behavior_policy=random_policy)

  # State value under the greedy policy: V(s) = max_a Q(s, a).
  # A defaultdict is kept (as in the original) so that, presumably, states
  # absent from Q read as 0.0 when plotted — confirm in plot_value_function.
  V = defaultdict(float)
  for state, action_values in Q.items():
    V[state] = np.max(action_values)

  fig, axarr = plt.subplots(2, subplot_kw={'projection': '3d'})
  # The run length is counted in episodes, not steps; derive the title from
  # the same constant so the label cannot drift from the actual run.
  plot_value_function(V, axarr, title="{:,} Episodes".format(num_episodes))
  plt.show()
    def test_q_values(self):
        """Regression-test ``mc_control_importance_sampling`` on Blackjack.

        Seeds NumPy's global RNG and runs the control algorithm for 500000
        episodes on ``BlackjackEnv(test=True)``. The behaviour policy comes
        from ``create_random_policy(2)``. The learned Q table is compared
        against stored reference values to 2 decimal places.
        """
        # Fix the global NumPy RNG so the run is reproducible; presumably the
        # environment and policies draw from np.random — TODO confirm.
        np.random.seed(0)
        env = BlackjackEnv(test=True)
        # Reference Q-values keyed by state tuple. Keys presumably follow the
        # (player_sum, dealer_showing_card, usable_ace) convention, and each
        # value holds one estimate per action (2 actions here) — confirm
        # against BlackjackEnv.
        expected_q_values = {
            (18, 10, False): [-0.23533037475345092, -0.65069513406157],
            (20, 6, False): [0.6990585070611964, -0.8814504881450475],
            (19, 9, False): [0.23174294060370004, -0.74],
            (12, 9, False): [-0.5431985294117646, -0.29656419529837275],
            (17, 8, False): [-0.4034582132564843, -0.4707282246549266],
            (20, 9, True): [0.7628571428571427, 0.09944751381215464],
            (17, 4, False): [-0.12105751391465681, -0.5326237852845899],
            (13, 4, False): [-0.2312764955252001, -0.29011786038077975],
            (17, 1, False): [-0.6282051282051277, -0.6655389076848715],
            (13, 3, False): [-0.26743075453677173, -0.2716210343328985],
            (16, 6, False): [-0.10835322195704067, -0.4610136452241714],
            (20, 2, False): [0.6376912378303203, -0.8531152105812742],
            (20, 9, False): [0.7585848074921976, -0.8680203045685262],
            (21, 4, False): [0.8698830409356734, -1.0],
            (16, 1, False): [-0.7971721111652841, -0.6776007497656986],
            (15, 10, False): [-0.5712454852615625, -0.5446418205038894],
            (14, 5, False): [-0.13802816901408452, -0.3296193129062211],
            (14, 9, False): [-0.565416285452881, -0.4146797568957452],
            (13, 6, True): [-0.2627118644067797, 0.2666666666666668],
            (18, 10, True): [-0.1964996022275259, -0.1846153846153846],
            (18, 5, False): [0.2162293488824098, -0.6132542037586541],
            (21, 8, True): [0.922656960873521, 0.19824561403508784],
            (18, 3, False): [0.16515944788196105, -0.6347826086956525],
            (17, 3, False): [-0.13083213083213088, -0.5655934646804432],
            (20, 3, False): [0.6458835687220116, -0.8849165815457946],
            (19, 10, False): [-0.015288999378495947, -0.7419515847267275],
            (18, 1, False): [-0.36386386386386316, -0.6866096866096864],
            (14, 1, False): [-0.7548566142460688, -0.5900226757369622],
            (18, 9, False): [-0.2109337203676826, -0.5946843853820601],
            (14, 3, False): [-0.265155020823693, -0.3708165997322623],
            (16, 5, False): [-0.20906567992599415, -0.3809971777986831],
            (21, 5, True): [0.891840607210625, 0.3579545454545456],
            (20, 10, False): [0.4393263157894733, -0.8543532020124465],
            (15, 9, False): [-0.5498449268941082, -0.526289434151226],
            (16, 7, False): [-0.4891454965357966, -0.40610104861773155],
            (17, 10, False): [-0.47040094339622607, -0.6040100250626568],
            (19, 8, False): [0.5706627680311898, -0.7149829184968268],
            (17, 9, False): [-0.4056378404204498, -0.545539033457249],
            (13, 7, False): [-0.4789838337182441, -0.2375162831089877],
            (17, 5, False): [-0.02606177606177609, -0.5548141086749286],
            (13, 10, False): [-0.5757291788684543, -0.45740615868734547],
            (14, 4, False): [-0.21376146788990846, -0.31591448931116406],
            (12, 10, False): [-0.5766489421880333, -0.4186740077999534],
            (14, 10, False): [-0.5757015821688396, -0.5145094426531549],
            (16, 10, False): [-0.5621724630776569, -0.5581314477802275],
            (15, 2, False): [-0.3027101515847493, -0.4245939675174012],
            (16, 3, False): [-0.2693409742120339, -0.46125797629899834],
            (16, 2, False): [-0.2568509057129591, -0.4964166268514091],
            (16, 4, False): [-0.21896383186705776, -0.48490566037735827],
            (21, 2, True): [0.8876739562624256, 0.32078853046595013],
            (19, 1, False): [-0.1036889332003989, -0.7973231357552581],
            (19, 4, False): [0.4097258147956546, -0.7285067873303166],
            (21, 8, False): [0.9245723172628312, -1.0],
            (13, 8, False): [-0.524061810154526, -0.3537757437070942],
            (16, 8, False): [-0.5093139482053626, -0.4488497307880566],
            (15, 1, False): [-0.7456647398843936, -0.5942549371633743],
            (12, 5, False): [-0.14081996434937621, -0.1990846681922192],
            (15, 4, False): [-0.23978201634877452, -0.38447319778188516],
            (13, 6, False): [-0.15426997245179058, -0.18335619570187484],
            (21, 10, False): [0.8933922397233937, -1.0],
            (19, 3, False): [0.4270570418980312, -0.7368961973278526],
            (21, 1, False): [0.6351865955826352, -1.0],
            (20, 1, False): [0.15662650602409645, -0.8789297658862895],
            (20, 5, False): [0.6695088676671217, -0.8277925531914894],
            (18, 7, False): [0.3904576436222005, -0.5871954132823696],
            (20, 7, False): [0.7727583846680343, -0.8603089321692415],
            (21, 6, False): [0.8976631748589844, -1.0],
            (21, 10, True): [0.8835616438356135, 0.061950993989828944],
            (13, 10, True): [-0.5957219251336902, -0.03703703703703699],
            (12, 4, False): [-0.19615912208504802, -0.23269316659222872],
            (20, 4, False): [0.6774526678141133, -0.8362611866092168],
            (15, 10, True): [-0.5518913676042678, -0.19183673469387752],
            (15, 3, False): [-0.2550790067720088, -0.41751152073732706],
            (18, 2, False): [0.10782442748091597, -0.6102418207681367],
            (18, 4, False): [0.16543937162493852, -0.5975723622782455],
            (21, 4, True): [0.9018181818181826, 0.28173913043478266],
            (12, 3, False): [-0.2545931758530182, -0.26359447004608316],
            (13, 5, False): [-0.1932692307692308, -0.2750572082379863],
            (13, 1, False): [-0.7839059674502722, -0.5131052865393172],
            (16, 9, False): [-0.5247614720581559, -0.5062761506276147],
            (17, 9, True): [-0.42248062015503873, -0.18709677419354848],
            (21, 5, False): [0.8986852281515859, -1.0],
            (14, 2, False): [-0.27932960893854786, -0.39484777517564384],
            (18, 6, True): [0.22683706070287538, 0.19620253164556958],
            (15, 8, False): [-0.49580615097856545, -0.431539187913126],
            (15, 5, False): [-0.15183246073298437, -0.35169300225733613],
            (21, 9, False): [0.9377076411960139, -1.0],
            (12, 1, False): [-0.7730684326710832, -0.4712245781047172],
            (15, 6, False): [-0.15188679245283032, -0.353219696969697],
            (12, 8, False): [-0.515682656826568, -0.31906799809795494],
            (21, 7, False): [0.9144951140065153, -1.0],
            (21, 2, False): [0.8893344025661581, -1.0],
            (18, 6, False): [0.27338826951042144, -0.621114948199309],
            (20, 8, True): [0.7645259938837922, 0.15555555555555556],
            (12, 8, True): [-0.4285714285714286, 0.411764705882353],
            (12, 6, False): [-0.13891362422083722, -0.1832167832167831],
            (19, 6, False): [0.47623713865752093, -0.723625557206537],
            (19, 2, False): [0.37239979705733056, -0.7525150905432583],
            (19, 7, True): [0.6631578947368428, 0.25595238095238093],
            (20, 8, False): [0.7776617954070979, -0.8442622950819666],
            (17, 6, False): [0.04052165812761995, -0.5332403533240367],
            (14, 3, True): [-0.14716981132075468, -0.009433962264150898],
            (19, 7, False): [0.6104999999999992, -0.7406483790523696],
            (21, 3, False): [0.8782201405152218, -1.0],
            (16, 1, True): [-0.7762237762237766, -0.3028169014084505],
            (21, 1, True): [0.6742909423604755, -0.0976095617529882],
            (12, 7, False): [-0.44228157537347185, -0.1881818181818183],
            (15, 7, False): [-0.4949026876737715, -0.3318603623508622],
            (12, 10, True): [-0.5485327313769756, -0.1541666666666666],
            (17, 7, False): [-0.0702936928261917, -0.4908235294117645],
            (14, 6, False): [-0.11781076066790352, -0.28493150684931506],
            (16, 8, True): [-0.43772241992882555, -0.10967741935483877],
            (15, 6, True): [-0.15942028985507253, 0.17361111111111116],
            (14, 7, False): [-0.46685210941121913, -0.2968897266729508],
            (12, 2, False): [-0.30228471001757445, -0.26256458431188295],
            (17, 2, False): [-0.20066256507335523, -0.5998142989786454],
            (13, 2, False): [-0.2899628252788114, -0.35277516462841],
            (19, 5, False): [0.4169215086646282, -0.7076845806127565],
            (19, 8, True): [0.5273224043715847, 0.1807228915662651],
            (19, 1, True): [-0.10169491525423732, -0.2840236686390534],
            (20, 4, True): [0.626903553299492, 0.19170984455958548],
            (13, 4, True): [-0.2096774193548386, 0.2905982905982906],
            (17, 10, True): [-0.4914145543744891, -0.3162393162393163],
            (20, 1, True): [0.09164420485175204, -0.13089005235602094],
            (14, 4, True): [-0.11740890688259109, 0.22321428571428564],
            (13, 3, True): [-0.25345622119815664, 0.05434782608695648],
            (20, 6, True): [0.6878612716763011, 0.2857142857142856],
            (12, 5, True): [-0.19999999999999996, 0.2727272727272728],
            (19, 10, True): [0.008559201141226814, -0.17101449275362335],
            (16, 5, True): [-0.1184210526315789, 0.1486486486486487],
            (18, 8, False): [0.10637254901960781, -0.6134939759036155],
            (14, 8, False): [-0.5028546332894172, -0.3722763096893826],
            (14, 7, True): [-0.4942528735632186, -0.06153846153846153],
            (19, 9, True): [0.25867507886435326, -0.10447761194029857],
            (16, 9, True): [-0.5245283018867921, -0.14383561643835613],
            (21, 9, True): [0.9414455626715449, 0.10370370370370366],
            (13, 9, True): [-0.4615384615384616, 0.19148936170212763],
            (12, 6, True): [-0.20312499999999994, 0.1864406779661018],
            (21, 7, True): [0.9118457300275489, 0.252808988764045],
            (19, 5, True): [0.5297805642633239, 0.054216867469879554],
            (18, 1, True): [-0.36176470588235293, -0.42592592592592615],
            (21, 3, True): [0.8816169393647741, 0.2359767891682784],
            (15, 2, True): [-0.2845528455284551, -0.04065040650406501],
            (20, 3, True): [0.7316384180790965, 0.14942528735632185],
            (18, 7, True): [0.43181818181818166, 0.11695906432748544],
            (15, 7, True): [-0.47985347985347976, 0.06896551724137931],
            (12, 4, True): [-0.10091743119266058, 0.18181818181818177],
            (18, 8, True): [0.05014749262536869, 0.1079136690647482],
            (17, 2, True): [-0.1891891891891892, -0.1259842519685039],
            (17, 3, True): [-0.0899280575539568, 0.043209876543209895],
            (16, 10, True): [-0.58287795992714, -0.27560521415269995],
            (20, 10, True): [0.42847173761339813, -0.02462380300957595],
            (16, 2, True): [-0.362549800796813, -0.07575757575757579],
            (13, 9, False): [-0.50587211831231, -0.38563829787234005],
            (14, 1, True): [-0.8295964125560541, -0.1869158878504673],
            (18, 9, True): [-0.13504823151125409, -0.11764705882352944],
            (20, 5, True): [0.6820652173913053, 0.17708333333333337],
            (15, 5, True): [-0.25196850393700787, 0.027777777777777794],
            (20, 7, True): [0.7968337730870713, 0.1851851851851852],
            (16, 7, True): [-0.5053003533568905, -0.05673758865248227],
            (13, 7, True): [-0.4891774891774891, -0.017241379310344848],
            (12, 7, True): [-0.5419847328244272, 0.37333333333333335],
            (14, 10, True): [-0.547877591312932, -0.07954545454545457],
            (16, 3, True): [-0.19999999999999996, 0.027586206896551748],
            (15, 8, True): [-0.5502008032128513, -0.07913669064748201],
            (20, 2, True): [0.6198979591836736, 0.2848101265822785],
            (19, 6, True): [0.5433526011560694, 0.2336956521739131],
            (21, 6, True): [0.8909090909090909, 0.29304029304029283],
            (14, 9, True): [-0.5, -0.07272727272727275],
            (19, 4, True): [0.38855421686747016, 0.2530864197530865],
            (18, 2, True): [0.12871287128712872, 0.10457516339869276],
            (14, 2, True): [-0.29059829059829034, 0.025641025641025612],
            (15, 4, True): [-0.19215686274509802, -0.06086956521739138],
            (18, 4, True): [0.16279069767441848, 0.08284023668639053],
            (13, 1, True): [-0.7543859649122808, -0.3719008264462809],
            (18, 3, True): [0.08433734939759038, 0.20394736842105263],
            (16, 6, True): [-0.2666666666666665, -0.014598540145985398],
            (19, 3, True): [0.38601823708206706, 0.03428571428571428],
            (15, 9, True): [-0.6296296296296297, -0.12403100775193798],
            (13, 5, True): [-0.31225296442687717, 0.06060606060606062],
            (15, 1, True): [-0.7534246575342467, -0.47368421052631565],
            (17, 6, True): [-0.03859649122807019, 0.23333333333333334],
            (14, 6, True): [-0.146341463414634, 0.18584070796460178],
            (12, 1, True): [-0.7723577235772359, -0.2857142857142858],
            (15, 3, True): [-0.21481481481481488, 0.08088235294117646],
            (18, 5, True): [0.23262839879154074, 0.02366863905325443],
            (14, 8, True): [-0.561181434599156, -0.25833333333333325],
            (13, 8, True): [-0.6306306306306306, 0.11678832116788318],
            (13, 2, True): [-0.33333333333333326, 0.14999999999999988],
            (17, 5, True): [-0.043771043771043766, 0.04411764705882353],
            (12, 9, True): [-0.5238095238095237, -0.08333333333333337],
            (17, 1, True): [-0.6195652173913049, -0.3984962406015038],
            (12, 3, True): [-0.2982456140350876, -0.0888888888888889],
            (16, 4, True): [-0.1357142857142857, -0.04216867469879521],
            (19, 2, True): [0.457865168539326, 0.1381578947368422],
            (17, 4, True): [-0.20312500000000003, -0.07142857142857147],
            (17, 8, True): [-0.43050847457627134, -0.12592592592592594],
            (17, 7, True): [-0.11224489795918369, 0.12142857142857143],
            (12, 2, True): [-0.30097087378640786, 0.028169014084507],
            (14, 5, True): [-0.23320158102766794, 0.19587628865979384]
        }

        # Behaviour policy over 2 actions (presumably uniform random — see
        # create_random_policy); the count matches the two values per state
        # in the reference table above.
        random_policy = create_random_policy(2)
        # NOTE(review): 500000 episodes likely makes this test slow to run.
        Q, _ = mc_control_importance_sampling(env,
                                              num_episodes=500000,
                                              behavior_policy=random_policy)
        # Compare learned Q to the reference at 2-decimal precision; the
        # helper's exact semantics are defined elsewhere on the test class.
        self.assert_float_dict_almost_equal(expected_q_values, Q, decimal=2)
求解问题:
    评估该策略的好坏。

求解过程
    以庄家明牌的点数和玩家当前手牌的总点数构成一个二维状态空间，并按玩家手中是否有可用的 A 分成两种情况分别处理。在每一局中记录所经历的各个状态（庄家明牌与玩家点数），并计算该局的最终收益（回报）。通过模拟大量牌局，对每个状态的回报取平均值作为该状态的价值，得到如下图示。

最终结果
    无论玩家手中是否有可用的 A，该策略下绝大多数状态的价值都较低，只有当玩家点数达到 21 时，状态价值才有明显提升。图中只考虑玩家点数在 12–21 之间的状态：点数低于 12 时，再要一张牌不可能爆牌，可以放心地继续叫牌，因此这些状态不存在需要决策的问题，也就没有分析意义。

"""
import numpy as np
import time
from blackjack import BlackjackEnv

# Module-level Blackjack ("21") environment, constructed at import time.
env = BlackjackEnv()


# Print the information contained in one observation.
def displayObservation(observation):
    """Print a human-readable summary of a Blackjack observation.

    ``observation`` is unpacked as (player score, dealer score, usable-ace
    flag), so it must contain exactly three items. Returns None.
    """
    player_score, dealer_score, has_usable_ace = observation
    template = "玩家分数为: {}(是否有可用的Ace: {}), 庄家的分数为: {}"
    message = template.format(player_score, has_usable_ace, dealer_score)
    print(message)


# Policy for the current observation (original comment: "show the policy
# corresponding to the current observation").
def policy(observation):
    """Unpack ``observation`` into score, dealer score and usable-ace flag.

    NOTE(review): as visible here, the body only unpacks the observation and
    implicitly returns None; the decision logic appears to have been cut off
    — confirm against the full source.
    """
    # Split the observation into its three components. ``usableSce`` is
    # presumably a typo for ``usableAce``; none of the names are used below.
    score, dealerScore, usableSce = observation
Пример #4
0
            returnn = self._visited_states_returns[key]
            new_val = \
                state.value_pi + self.alpha * (returnn - state.value_pi)
            state.value_pi = new_val

    def remember(self, obs, reward):
        """Register ``obs`` as visited and add ``reward`` to every visited state.

        A newly seen observation starts with an accumulated return of 0 before
        ``reward`` is added. Every observation recorded so far, including
        ``obs``, therefore receives ``reward``. The dict is updated in place.
        """
        returns = self._visited_states_returns
        returns.setdefault(obs, 0)
        # Only values change during the loop, so iterating the keys is safe.
        for state_key in returns:
            returns[state_key] += reward


if __name__ == '__main__':
    env = BlackjackEnv()
    agent = BlackjackAgent()

    total_ep = 100000
    # NOTE(review): the episode loop below appears truncated (no agent/env
    # stepping is visible) — confirm against the full source.
    for episode in range(total_ep):
        if episode % 1000 == 0:
            # Progress report. The template must be formatted explicitly;
            # passing it and the values as separate print() arguments would
            # print the literal "{0} / {1}" placeholders.
            print('episode: {0} / {1}'.format(episode, total_ep))
        # print()
        # print('  ==  EPISODE {0} START  =='.format(episode))

        obs = env.reset()
        agent.reset()

        # print('HAND:  ', end=''); print(env.player_hand)
        # print('STATE pp={0}, ha={1}, dc={2}'.format(obs[0], obs[1], obs[2]))
Пример #5
0
 def setUp(self):
   self.sample_policy = lambda observation: 0 if observation[0] >= 20 else 1
   self.env = BlackjackEnv(test=True)