Example #1
    def pick_a_card(self):
        # confirm that the number of cards on the table matches my seat
        assert (self.cards_on_table[0] + len(self.cards_on_table) -
                1) % 4 == self.place
        # utility data
        suit = self.decide_suit()  # inherited from MrRandom
        cards_dict = MrGreed.gen_cards_dict(self.cards_list)
        # if there is no other choice
        if cards_dict.get(suit) is not None and len(cards_dict[suit]) == 1:
            choice = cards_dict[suit][0]
            if print_level >= 1:
                log("I have no choice but %s" % (choice))
            return choice

        if print_level >= 1:
            log("my turn: %s, %s" % (self.cards_on_table, self.cards_list))
        fmt_scores = MrGreed.gen_fmt_scores(
            self.scores
        )  #in absolute order, because self.scores is in absolute order
        #log("fmt scores: %s"%(fmt_scores))
        d_legal = {
            c: 0
            for c in MrGreed.gen_legal_choice(suit, cards_dict,
                                              self.cards_list)
        }  # dict of legal choices
        sce_gen = ScenarioGen(self.place,
                              self.history,
                              self.cards_on_table,
                              self.cards_list,
                              number=MrRandTree.N_SAMPLE,
                              METHOD1_PREFERENCE=100)
        for cards_list_list in sce_gen:
            cards_lists = [None, None, None, None]
            cards_lists[self.place] = copy.copy(self.cards_list)
            for i in range(3):
                cards_lists[(self.place + i + 1) % 4] = cards_list_list[i]
            if print_level >= 1:
                log("get scenario: %s" % (cards_lists))
            cards_on_table_copy = copy.copy(self.cards_on_table)
            gamestate = GameState(cards_lists, fmt_scores, cards_on_table_copy,
                                  self.place)
            searcher = mcts(iterationLimit=200, explorationConstant=100)
            searcher.search(initialState=gamestate)
            for action, node in searcher.root.children.items():
                if print_level >= 1:
                    log("%s: %s" % (action, node))
                d_legal[action] += node.totalReward / node.numVisits
        if print_level >= 1:
            log("d_legal: %s" % (d_legal))
            input("press any key to continue...")
        best_choice = MrGreed.pick_best_from_dlegal(d_legal)
        return best_choice
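
The mcts searcher used throughout these examples takes a game-state object and returns statistics per action. Judging from the calls above (iterationLimit, explorationConstant, search(initialState=...), searcher.root.children), the state object must expose a small fixed interface. Below is a minimal self-contained sketch of such a state, on a toy take-away game invented purely for illustration; the required method names are inferred from how these examples use the library:

from copy import deepcopy
from mcts import mcts

class TakeAwayState:
    """Two players alternately remove 1-3 tokens; taking the last token wins."""

    def __init__(self, tokens=10, player=1):
        self.tokens = tokens
        self.player = player  # +1 or -1, matching the library's convention

    def getCurrentPlayer(self):
        return self.player

    def getPossibleActions(self):
        return [n for n in (1, 2, 3) if n <= self.tokens]

    def takeAction(self, action):
        nxt = deepcopy(self)
        nxt.tokens -= action
        nxt.player = -self.player
        return nxt

    def isTerminal(self):
        return self.tokens == 0

    def getReward(self):
        # the player who just moved took the last token and wins;
        # reward is expressed from player +1's point of view
        return -self.player

searcher = mcts(iterationLimit=200)
print(searcher.search(initialState=TakeAwayState()))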
Example #2
    def action(self):

        ## key used for importing/exporting from/to json
        format_list = []
        format_list.append(sorted(self.state["black"]))
        format_list.append(sorted(self.state["white"]))
        format_list.append(self.currentPlayer)

        format_key = str(format_list)
        '''
        ### Search action from model

        import json
        import os

        if os.path.exists('MCTS/model.txt'):
            with open('MCTS/model.txt') as json_file:
                model = json.load(json_file)
            if format_key in model:
                # print("found this move in the model!!")
                action = model[format_key]
                if action[0] == "BOOM":
                    return ("BOOM", tuple(action[1]))
                else:
                    return ("MOVE", action[1], tuple(action[2]), tuple(action[3]))
        '''

        currentState = MCTS(self.state, self.currentPlayer)

        agent = mcts(timeLimit=1000)
        action = agent.search(initialState=currentState)
        '''
        ### Export result

        import json
        import os

        json_decoded = {}
        if os.path.exists('MCTS/model.txt'):
            with open('MCTS/model.txt') as json_file:
                json_decoded = json.load(json_file)

        json_decoded[format_key] = action

        with open('MCTS/model.txt', 'w') as json_file:
            json.dump(json_decoded, json_file)
        '''

        return action
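
Note that this searcher is budgeted by wall-clock time rather than by iteration count, unlike Example #1. In the mcts package these snippets appear to use, timeLimit is given in milliseconds and is mutually exclusive with iterationLimit:

from mcts import mcts

timed_searcher = mcts(timeLimit=1000)        # stop after roughly one second
counted_searcher = mcts(iterationLimit=200)  # stop after exactly 200 rollouts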
Example #3
    def pick_a_card_complete_info(self):
        # confirm that the number of cards on the table matches my seat
        assert (self.cards_on_table[0] + len(self.cards_on_table) -
                1) % 4 == self.place

        #initialize gamestate
        #assert self.cards_list==self.cards_remain[self.place]
        gamestate = GameState(self.cards_remain, self.scores,
                              self.cards_on_table, self.place)

        #mcts
        suit = self.decide_suit()
        cards_dict = MrGreed.gen_cards_dict(self.cards_list)
        legal_choice = MrGreed.gen_legal_choice(suit, cards_dict,
                                                self.cards_list)
        searchnum = self.mcts_b + self.mcts_k * len(legal_choice)
        searcher = mcts(iterationLimit=searchnum,
                        rolloutPolicy=self.pv_policy,
                        explorationConstant=MCTS_EXPL)
        searcher.search(initialState=gamestate)
        d_legal_temp = {
            action: node.totalReward / node.numVisits
            for action, node in searcher.root.children.items()
        }
        # save data for training
        if self.train_mode:
            value_max = max(d_legal_temp.values())
            target_p = torch.zeros(52)
            legal_mask = torch.zeros(52)
            for k, v in d_legal_temp.items():
                target_p[ORDER_DICT[k]] = math.exp(BETA * (v - value_max))
                legal_mask[ORDER_DICT[k]] = 1
            target_p /= target_p.sum()
            target_v = torch.tensor(value_max - gamestate.getReward())
            netin = MrZeroTreeSimple.prepare_ohs(self.cards_remain,
                                                 self.cards_on_table,
                                                 self.scores, self.place)
            self.train_datas.append((netin, target_p, target_v, legal_mask))
        best_choice = MrGreed.pick_best_from_dlegal(d_legal_temp)
        return best_choice
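
The training targets computed above are a Boltzmann (softmax) distribution over the per-action values that MCTS returned, shifted by the maximum value so that exp() cannot overflow. The same transform in isolation, with an illustrative beta and made-up action values:

import math

def softmax_targets(action_values, beta):
    """Boltzmann distribution over action values; subtracting the
    maximum leaves the result unchanged but keeps exp() stable."""
    v_max = max(action_values.values())
    weights = {a: math.exp(beta * (v - v_max))
               for a, v in action_values.items()}
    total = sum(weights.values())
    return {a: w / total for a, w in weights.items()}

print(softmax_targets({"SQ": -13.0, "H2": 0.0, "C5": -1.0}, beta=0.2))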
Example #4
import sys
sys.path.append('..')
from ConnectXEnv.ConnectX_Game import ConnectX_MCTS
from MCTS.mcts import mcts

# build the MCTS searcher for a standard 7x6 Connect-4 board
my_mcts = mcts(ConnectX_MCTS(width=7, height=6, win_length=4))

# load the saved model
my_mcts.load_model('./temp_root.pkl')

# learn through self-play
my_mcts.self_play(numEps=100, numMCTSSims=200, display=False)

# save the model
my_mcts.save_model('./temp_root.pkl')
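
load_model and save_model here come from the ConnectX repository rather than from the mcts package, and their implementation is not shown. Assuming they simply persist the search tree between sessions (an assumption, not something this snippet confirms), the underlying pattern is roughly:

import pickle

# hypothetical sketch: persist a search tree (or any picklable root object)
def save_tree(root, path):
    with open(path, 'wb') as f:
        pickle.dump(root, f)

def load_tree(path):
    with open(path, 'rb') as f:
        return pickle.load(f)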
Example #5
    def pick_a_card(self):
        # confirm that the number of cards on the table matches my seat
        assert (self.cards_on_table[0] + len(self.cards_on_table) -
                1) % 4 == self.place
        # utility data
        suit = self.decide_suit()  # inherited from MrRandom
        cards_dict = MrGreed.gen_cards_dict(self.cards_list)
        # if there is no other choice
        if cards_dict.get(suit) is not None and len(cards_dict[suit]) == 1:
            choice = cards_dict[suit][0]
            if print_level >= 1:
                log("I have no choice but %s" % (choice))
            return choice
        if len(self.cards_list) == 1:
            return self.cards_list[0]
        if print_level >= 1:
            log("my turn: %s, %s, %s" %
                (self.cards_on_table, self.cards_list, self.scores))

        legal_choice = MrGreed.gen_legal_choice(suit, cards_dict,
                                                self.cards_list)
        #imp_cards=self.select_interact_cards(legal_choice)

        if self.sample_k >= 0:
            sce_num = self.sample_b + int(self.sample_k * len(self.cards_list))
            assert self.sample_b >= 0 and sce_num > 0
            sce_gen = ScenarioGen(self.place,
                                  self.history,
                                  self.cards_on_table,
                                  self.cards_list,
                                  number=sce_num)
            scenarios = list(sce_gen)
        else:
            assert self.sample_k < 0 and self.sample_b < 0
            sce_gen = ImpScenarioGen(self.place,
                                     self.history,
                                     self.cards_on_table,
                                     self.cards_list,
                                     suit,
                                     level=-1 * self.sample_k,
                                     num_per_imp=-1 * self.sample_b)
            #imp_cards=imp_cards,num_per_imp=-1*self.sample_b)
            scenarios = sce_gen.get_scenarios()

        #cards_played,scores_stage,void_info_stage=self.public_info()
        scenarios_weight = []
        cards_lists_list = []
        for cll in scenarios:
            if print_level >= 3:
                log("analyzing: %s" % (cll))
            cards_lists = [None, None, None, None]
            cards_lists[self.place] = copy.copy(self.cards_list)
            for i in range(3):
                cards_lists[(self.place + i + 1) % 4] = cll[i]

            #scenarios_weight.append(1.0)
            #scenarios_weight.append(self.possi_rectify(cards_lists,suit,cards_played,scores_stage,void_info_stage))
            scenarios_weight.append(self.possi_rectify(cards_lists, suit))

            #scenarios_weight[-1]*=self.int_equ_class(cards_lists,suit)
            #scenarios_weight[-1]*=self.int_equ_class_li(cards_lists,suit)

            cards_lists_list.append(cards_lists)
        del scenarios
        if print_level >= 2:
            log("scenarios_weight: %s" %
                (["%.4f" % (i) for i in scenarios_weight], ))
        weight_sum = sum(scenarios_weight)
        scenarios_weight = [i / weight_sum for i in scenarios_weight]
        assert abs(sum(scenarios_weight) -
                   1) < 1e-6, "scenario weight is %.8f: %s" % (
                       sum(scenarios_weight),
                       scenarios_weight,
                   )

        #legal_choice=MrGreed.gen_legal_choice(suit,cards_dict,self.cards_list)
        d_legal = {c: 0 for c in legal_choice}
        searchnum = self.mcts_b + self.mcts_k * len(legal_choice)
        for i, cards_lists in enumerate(cards_lists_list):
            #initialize gamestate
            gamestate = GameState(cards_lists, self.scores,
                                  self.cards_on_table, self.place)
            #mcts
            if self.mcts_k >= 0:
                searcher = mcts(iterationLimit=searchnum,
                                rolloutPolicy=self.pv_policy,
                                explorationConstant=MCTS_EXPL)
                searcher.search(initialState=gamestate)
                for action, node in searcher.root.children.items():
                    d_legal[action] += scenarios_weight[
                        i] * node.totalReward / node.numVisits
            elif self.mcts_k == -1:
                input("not using")
                netin = MrZeroTree.prepare_ohs(cards_lists,
                                               self.cards_on_table,
                                               self.scores, self.place)
                with torch.no_grad():
                    p, _ = self.pv_net(netin.to(self.device))
                p_legal = [(c, p[ORDER_DICT[c]]) for c in legal_choice]
                p_legal.sort(key=lambda x: x[1], reverse=True)
                d_legal[p_legal[0][0]] += 1
            elif self.mcts_k == -2:
                input("not using")
                assert self.sample_b == 1 and self.sample_k == 0 and self.mcts_b == 0, "This is raw-policy mode"
                netin = MrZeroTree.prepare_ohs_post_rect(
                    cards_lists, self.cards_on_table, self.scores, self.place)
                with torch.no_grad():
                    p, _ = self.pv_net(netin.to(self.device))
                p_legal = [(c, p[ORDER_DICT[c]]) for c in legal_choice]
                p_legal.sort(key=lambda x: x[1], reverse=True)
                return p_legal[0][0]
            else:
                raise Exception("reserved")

        if print_level >= 2:
            log("d_legal: %s" %
                ({k: float("%.1f" % (v))
                  for k, v in d_legal.items()}))
            #time.sleep(5+10*random.random())

        best_choice = MrGreed.pick_best_from_dlegal(d_legal)
        """
        if len(legal_choice)>1:
            g=self.g_aux[self.place]
            g.cards_on_table=copy.copy(self.cards_on_table)
            g.history=copy.deepcopy(self.history)
            g.scores=copy.deepcopy(self.scores)
            g.cards_list=copy.deepcopy(self.cards_list)
            gc=g.pick_a_card()

            netin=MrZeroTree.prepare_ohs(cards_lists,self.cards_on_table,self.scores,self.place)
            with torch.no_grad():
                p,_=self.pv_net(netin.to(self.device))

            p_legal=[(c,p[ORDER_DICT[c]].item()) for c in legal_choice if c[0]==gc[0]]
            v_max=max((v for c,v in p_legal))
            p_legal=[(c,1+BETA_POST_RECT*(v-v_max)/BETA) for c,v in p_legal]
            p_legal.sort(key=lambda x:x[1],reverse=True)
            p_choice=(v for c,v in p_legal if c==gc).__next__()
            possi=max(p_choice,0.2)
            log("greed, %s, %s, %s, %.4f"%(gc,suit,gc==p_legal[0][0],possi),logfile="stat_sim.txt",fileonly=True)

            p_legal=[(c,p[ORDER_DICT[c]].item()) for c in legal_choice if c[0]==best_choice[0]]
            v_max=max((v for c,v in p_legal))
            p_legal=[(c,1+BETA_POST_RECT*(v-v_max)/BETA) for c,v in p_legal]
            p_legal.sort(key=lambda x:x[1],reverse=True)
            p_choice=(v for c,v in p_legal if c==best_choice).__next__()
            possi=max(p_choice,0.2)
            log("zerotree, %s, %s, %s, %.4f"%(best_choice,suit,best_choice==p_legal[0][0],possi),logfile="stat_sim.txt",fileonly=True)"""

        return best_choice
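
This variant is determinized MCTS: the opponents' hands are hidden, so it samples complete deals (scenarios), runs an independent search per deal, and mixes each action's mean reward using the normalized per-scenario weights. The aggregation step in isolation, where evaluate_scenario is a hypothetical stand-in for one full MCTS run on a single deal:

def weighted_action_values(scenarios, weights, legal_actions, evaluate_scenario):
    # normalize the weights so they sum to one
    total = sum(weights)
    norm = [w / total for w in weights]
    d_legal = {a: 0.0 for a in legal_actions}
    for w, scenario in zip(norm, scenarios):
        # evaluate_scenario(scenario) -> {action: mean_reward}
        for action, value in evaluate_scenario(scenario).items():
            d_legal[action] += w * value
    return d_legal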
Example #6
    def pick_a_card(self):
        #input("in pick a card")
        # confirm that the number of cards on the table matches my seat
        assert (self.cards_on_table[0] + len(self.cards_on_table) -
                1) % 4 == self.place
        # utility data
        suit = self.decide_suit()  # inherited from MrRandom
        cards_dict = MrGreed.gen_cards_dict(self.cards_list)
        # if there is no other choice
        if cards_dict.get(suit) is not None and len(cards_dict[suit]) == 1:
            choice = cards_dict[suit][0]
            if print_level >= 1:
                log("I have no choice but %s." % (choice))
            return choice
        if len(self.cards_list) == 1:
            if print_level >= 1:
                log("There is only one card left.")
            return self.cards_list[0]
        if print_level >= 1:
            log("my turn: %s, %s, %s" %
                (self.cards_on_table, self.cards_list, self.scores))
        # generate scenarios
        sce_num = self.sample_b + int(self.sample_k * len(self.cards_list))
        sce_gen = ScenarioGen(self.place,
                              self.history,
                              self.cards_on_table,
                              self.cards_list,
                              number=sce_num)
        cards_lists_list = []
        for cll in sce_gen:
            cards_lists = [None, None, None, None]
            cards_lists[self.place] = copy.copy(self.cards_list)
            for i in range(3):
                cards_lists[(self.place + i + 1) % 4] = cll[i]
            cards_lists_list.append(cards_lists)
        # run MCTS on each scenario and average the results
        legal_choice = MrGreed.gen_legal_choice(suit, cards_dict,
                                                self.cards_list)
        d_legal = {c: 0 for c in legal_choice}
        searchnum = self.mcts_b + self.mcts_k * len(legal_choice)
        for i, cards_lists in enumerate(cards_lists_list):
            # initialize gamestate
            gamestate = GameState(cards_lists, self.scores,
                                  self.cards_on_table, self.place)
            # mcts
            if self.mcts_k >= 0:
                searcher = mcts(iterationLimit=searchnum,
                                rolloutPolicy=self.pv_policy,
                                explorationConstant=MCTS_EXPL)
                searcher.search(initialState=gamestate)
                for action, node in searcher.root.children.items():
                    d_legal[action] += (node.totalReward /
                                        node.numVisits) / len(cards_lists_list)
            elif self.mcts_k == -1:
                input("not using this mode")
                netin = MrZeroTreeSimple.prepare_ohs(cards_lists,
                                                     self.cards_on_table,
                                                     self.scores, self.place)
                with torch.no_grad():
                    p, _ = self.pv_net(netin.to(self.device))
                p_legal = [(c, p[ORDER_DICT[c]]) for c in legal_choice]
                p_legal.sort(key=lambda x: x[1], reverse=True)
                d_legal[p_legal[0][0]] += 1
            else:
                raise Exception("reserved")
        # pick the best choice and return it
        best_choice = MrGreed.pick_best_from_dlegal(d_legal)
        return best_choice