def TrainOneRound(self, afterstate_num, alpha=0.1):
    """
    Q-learning following Sutton and Barto 6.5.

    Input:
        afterstate_num: encoded number of the afterstate of target_policy
            to start training with.
    Note that in each round the opponent makes a move first, then the
    target policy.
    """
    afterstate = State(from_base10=afterstate_num)
    while not afterstate.is_terminal():
        # opponent makes a move
        beforestate_num = self.random_policy.move(afterstate.get_num())
        beforestate = State(from_base10=beforestate_num)
        if beforestate.is_terminal():
            # the opponent's move ended the game: back up the terminal reward
            r = beforestate.get_reward()
            self.target_policy.v_dict[afterstate.get_num()] += alpha * (
                r - self.target_policy.v_dict[afterstate.get_num()])
            break
        else:
            # back up the best reward-plus-value over the target policy's
            # candidate afterstates
            s_primes = beforestate.legal_afterstates()
            candidates = []
            for s_prime in s_primes:
                r = State(from_base10=s_prime).get_reward()
                q = self.target_policy.v_dict[s_prime]
                candidates.append(r + q)
            if beforestate.turn == 1:
                self.target_policy.v_dict[afterstate.get_num()] += alpha * (
                    max(candidates) -
                    self.target_policy.v_dict[afterstate.get_num()])
            else:
                self.target_policy.v_dict[afterstate.get_num()] += alpha * (
                    min(candidates) -
                    self.target_policy.v_dict[afterstate.get_num()])
        # behavior policy picks the next afterstate to continue the episode
        afterstate_num = self.random_policy.move(beforestate_num)
        afterstate = State(from_base10=afterstate_num)
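# A minimal, standalone sketch of the TD backup TrainOneRound applies:
# V(afterstate) += alpha * (target - V(afterstate)), where the target is
# either the terminal reward or the max/min over candidate afterstate
# values. The dict `v`, the label 's0', and the candidate list below are
# hypothetical stand-ins for target_policy.v_dict and the r + q values
# built from legal_afterstates(); this is an illustration, not the repo's
# implementation.
def q_update(v, afterstate, candidates, alpha=0.1, maximize=True):
    """Nudge v[afterstate] toward the best (or worst) candidate value."""
    target = max(candidates) if maximize else min(candidates)
    v[afterstate] += alpha * (target - v[afterstate])
    return v[afterstate]

v = {'s0': 0.0}
# Player to move maximizes: the value moves 10% of the way toward the best candidate.
assert abs(q_update(v, 's0', candidates=[0.0, 1.0, -1.0]) - 0.1) < 1e-12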
def test_legal_afterstates():
    # full board, no legal afterstate
    state = State(board=[[2, 2, 2], [1, 1, 1], [1, 2, 2]], turn=1)
    assert not state.legal_afterstates()

    # one legal afterstate
    state = State(board=[[2, 2, 2], [1, 1, 1], [1, 0, 2]], turn=1)
    assert state.legal_afterstates() == [
        State([[2, 2, 2], [1, 1, 1], [1, 1, 2]], turn=2).get_num()
    ]

    # 3 legal afterstates
    state = State(board=[[2, 2, 2], [1, 1, 1], [0, 0, 0]], turn=2)
    temp = state.legal_afterstates()
    assert len(temp) == 3
    num1 = State(board=[[2, 2, 2], [1, 1, 1], [2, 0, 0]]).get_num()
    num2 = State(board=[[2, 2, 2], [1, 1, 1], [0, 2, 0]]).get_num()
    num3 = State(board=[[2, 2, 2], [1, 1, 1], [0, 0, 2]]).get_num()
    assert set(temp) == {num1, num2, num3}
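# Hedged sketch of the base-3 state numbering these tests appear to rely
# on. The assumption (not confirmed in this file) is that get_num() packs
# the turn as a leading base-3 digit followed by the nine board cells,
# which would match the range int('1' + '0' * 9, 3) .. int('2' * 10, 3)
# swept by ValueIteration and PolicyImprovement below. encode_state and
# decode_state are hypothetical helpers, not the repo's State API.
import numpy as np

def encode_state(board, turn):
    """Pack a 3x3 board of 0/1/2 cells plus the turn into one base-3 integer."""
    digits = [turn] + [cell for row in board for cell in row]
    return int(''.join(str(d) for d in digits), 3)

def decode_state(num):
    """Inverse of encode_state: recover (board, turn) from the integer."""
    digits = [int(c) for c in np.base_repr(num, base=3)]
    turn, flat = digits[0], digits[1:]
    return [flat[i:i + 3] for i in range(0, 9, 3)], turn

num = encode_state([[2, 2, 2], [1, 1, 1], [1, 0, 2]], turn=1)
assert decode_state(num) == ([[2, 2, 2], [1, 1, 1], [1, 0, 2]], 1)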
def ValueIteration(self, theta=0.01):
    t = time.time()
    while True:
        delta = 0
        # sweep over every encoded state number
        for num in range(int('1' + '0' * 9, 3), int('2' * 10, 3) + 1):
            v = self.policy_1.v_dict[num]
            state = State(from_base10=num)
            if state.is_terminal():
                self.policy_1.v_dict[num] = state.get_reward()
            else:
                # the opponent (policy_2) moves first from this state
                opponent_afterstate = State(
                    from_base10=self.policy_2.move_dict[num])
                if opponent_afterstate.is_terminal():
                    self.policy_1.v_dict[num] = opponent_afterstate.get_reward()
                else:
                    # back up the best successor value (max for one side,
                    # min for the other)
                    s_prime_choices = opponent_afterstate.legal_afterstates()
                    if state.turn == 2:
                        vi_update = max(self.policy_1.v_dict[x]
                                        for x in s_prime_choices)
                    else:
                        vi_update = min(self.policy_1.v_dict[x]
                                        for x in s_prime_choices)
                    self.policy_1.v_dict[num] = vi_update
            delta = max(delta, np.abs(v - self.policy_1.v_dict[num]))
        self.i_epoch += 1
        if delta < theta:
            print('Value function has converged!')
            print("Trained %i epochs so far." % self.i_epoch)
            self.policy_ever_changed = self.policy_1.be_greedy()
            with open(self.write_path, "wb") as f:
                pickle.dump((self.policy_1, self.i_epoch), f)
            break
        # checkpoint roughly every 10 seconds
        if time.time() - t > 10:
            t = time.time()
            print("Trained %i epochs so far." % self.i_epoch)
            self.policy_ever_changed = self.policy_1.be_greedy()
            with open(self.write_path, "wb") as f:
                pickle.dump((self.policy_1, self.i_epoch), f)
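# Self-contained toy illustrating the sweep-until-delta-below-theta pattern
# that ValueIteration follows, on a hypothetical 3-state chain instead of
# the tic-tac-toe state space (State, policy_1 and policy_2 are not used
# here). States 0 and 2 are terminal with rewards 0 and 1; state 1 backs
# up the best discounted successor value each sweep.
def toy_value_iteration(theta=0.01, gamma=0.9):
    v = {0: 0.0, 1: 0.0, 2: 0.0}
    terminal_reward = {0: 0.0, 2: 1.0}
    while True:
        delta = 0.0
        for s in v:
            old = v[s]
            if s in terminal_reward:
                v[s] = terminal_reward[s]
            else:
                v[s] = gamma * max(v[s - 1], v[s + 1])
            delta = max(delta, abs(old - v[s]))
        if delta < theta:  # same stopping rule as above
            return v

assert abs(toy_value_iteration()[1] - 0.9) < 1e-12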
def PolicyImprovement(self):
    """
    Policy Improvement following Sutton and Barto 4.3.
    Against rush opponent, with afterstates.
    """
    self.policy_stable = True
    for num in range(int('1' + '0' * 9, 3), int('2' * 10, 3) + 1):
        state = State(from_base10=num)
        if not state.is_terminal():
            old_action_num = self.policy_1.move_dict[num]
            # pick the best afterstate: player 1 maximizes, player 2 minimizes
            afterstate_nums = state.legal_afterstates()
            afterstate_values = [
                self.policy_1.v_dict[x] for x in afterstate_nums
            ]
            best = (np.argmax(afterstate_values)
                    if state.turn == 1 else np.argmin(afterstate_values))
            self.policy_1.move_dict[num] = afterstate_nums[best]
            if old_action_num != self.policy_1.move_dict[num]:
                self.policy_stable = False
                self.policy_ever_changed = True
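# Standalone sketch of the greedy improvement step performed above: for a
# single state, pick the afterstate whose estimated value is best for the
# player to move. The afterstate numbers and values below are hypothetical
# stand-ins for legal_afterstates() and policy_1.v_dict.
import numpy as np

def improve_one_state(afterstate_nums, afterstate_values, turn):
    """Return the afterstate a greedy player would move to."""
    best = (np.argmax(afterstate_values)
            if turn == 1 else np.argmin(afterstate_values))
    return afterstate_nums[best]

# Player 1 maximizes, player 2 minimizes the same value estimates.
nums, values = [101, 102, 103], [0.2, -0.5, 0.7]
assert improve_one_state(nums, values, turn=1) == 103
assert improve_one_state(nums, values, turn=2) == 102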