# Example #1
# 0
def _load_dual_net(model_name, hidden_num):
    """Build a ``New_Dual_Net`` from the checkpoint ``model/<model_name>``.

    The network's input size is taken from the second dimension of the
    checkpoint's ``final_layer.weight`` so the architecture matches the saved
    weights. Tensors are mapped to the CPU while loading so a checkpoint saved
    from a GPU run can also be evaluated on a machine without CUDA; the caller
    moves the returned network to the GPU when it needs to.
    """
    model_dict = torch.load('model/' + model_name, map_location="cpu")
    n_size = model_dict["final_layer.weight"].size()[1]
    net = New_Dual_Net(n_size, hidden_num=hidden_num)
    net.load_state_dict(model_dict)
    return net


def _make_fixed_opponent(opponent_type, net, opponent_net):
    """Return the second player (not first to move) for a built-in opponent type.

    Args:
        opponent_type: one of "Aggro", "OM", "NR_OM", "ExItGreedy", "Greedy"
            or "Random".
        net: the evaluated network; used by "ExItGreedy" when no separate
            opponent network was loaded.
        opponent_net: the optional opponent network (may be None).

    Raises:
        ValueError: if ``opponent_type`` is not one of the names above.
    """
    if opponent_type == "Aggro":
        return Player(9,
                      False,
                      policy=AggroPolicy(),
                      mulligan=Min_cost_mulligan_policy())
    if opponent_type == "OM":
        return Player(9, False, policy=Opponent_Modeling_ISMCTSPolicy())
    if opponent_type == "NR_OM":
        return Player(9,
                      False,
                      policy=Non_Rollout_OM_ISMCTSPolicy(iteration=200),
                      mulligan=Min_cost_mulligan_policy())
    if opponent_type == "ExItGreedy":
        model = opponent_net if opponent_net is not None else net
        return Player(9, False, policy=Dual_NN_GreedyPolicy(origin_model=model))
    if opponent_type == "Greedy":
        return Player(9,
                      False,
                      policy=New_GreedyPolicy(),
                      mulligan=Simple_mulligan_policy())
    if opponent_type == "Random":
        return Player(9,
                      False,
                      policy=RandomPolicy(),
                      mulligan=Simple_mulligan_policy())
    raise ValueError("unknown fixed_opponent: {!r}".format(opponent_type))


def _next_result_name(result_dir, base_name):
    """Return ``base_name`` plus a zero-padded ``_NNN`` suffix for a new run.

    The suffix is one more than the number of regular files in ``result_dir``
    whose names start with ``base_name + "_"`` (earlier runs with the same
    model and deck list). Matching on the exact prefix avoids counting files
    of other models or deck lists that merely contain ``base_name`` as a
    substring. If the chosen name is already taken, the suffix is bumped until
    it is free so an earlier result file is never overwritten.
    """
    existing = {
        f for f in os.listdir(result_dir)
        if os.path.isfile(os.path.join(result_dir, f))
    }
    prefix = base_name + "_"
    same_name_count = sum(1 for f in existing if f.startswith(prefix))
    print("same_name_count:", same_name_count)
    index = same_name_count + 1
    while "{}{:0>3}".format(prefix, index) in existing:
        index += 1
    return "{}{:0>3}".format(prefix, index)


def check_score():
    """Evaluate a saved model against an opponent over every deck pairing.

    Configuration comes from module-level names: ``args``, ``cpu_num``,
    ``fixed_opponent``, ``evaluate_num`` and ``deck_id_2_name``.

    Players:
        "Alice" (first player) is driven by ``model/<args.model_name>``.
        "Bob" is either the built-in policy named by ``fixed_opponent`` or, when
        that is None, the network ``model/<args.opponent_model_name>``.

    Games:
        Games for every ordered pair of the decks in ``args.deck_list`` are
        played by ``multi_battle`` workers in a process pool. The workers
        accumulate counters in a shared array.

    Output:
        The summary is printed and written as a tab-separated file into
        ``./Battle_Result/``. The directory is created if missing.

    Raises:
        ValueError: if ``fixed_opponent`` names an unknown policy, or if it is
            None and no ``args.opponent_model_name`` was given.
    """
    print(args)
    p_size = cpu_num
    print("use cpu num:{}".format(p_size))

    cuda_flg = args.cuda == "True"
    model_name = args.model_name

    result_dir = "./Battle_Result"
    os.makedirs(result_dir, exist_ok=True)
    result_name = _next_result_name(
        result_dir, "{}:{}".format(model_name.split(".")[0], args.deck_list))

    net = _load_dual_net(model_name, args.hidden_num[0])
    opponent_net = None
    if args.opponent_model_name is not None:
        opponent_net = _load_dual_net(args.opponent_model_name,
                                      args.hidden_num[1])

    if torch.cuda.is_available() and cuda_flg:
        # Only the opponent network is moved to the GPU; ``net`` is forced
        # onto the CPU below either way.
        if opponent_net is not None:
            opponent_net = opponent_net.cuda()
        print("cuda is available.")
    # NOTE(review): p1's network stays on the CPU even when CUDA is requested,
    # while its policy still receives cuda=cuda_flg — confirm this is intended.
    net.cpu()

    if args.greedy_mode is not None:
        p1 = Player(9, True, policy=Dual_NN_GreedyPolicy(origin_model=net))
    else:
        p1 = Player(9,
                    True,
                    policy=New_Dual_NN_Non_Rollout_OM_ISMCTSPolicy(
                        origin_model=net,
                        cuda=cuda_flg,
                        iteration=args.step_iter),
                    mulligan=Min_cost_mulligan_policy())
    p1.name = "Alice"

    if fixed_opponent is not None:
        p2 = _make_fixed_opponent(fixed_opponent, net, opponent_net)
    else:
        # An explicit check instead of ``assert`` so it also runs under -O.
        if opponent_net is None:
            raise ValueError("args.opponent_model_name is required when "
                             "fixed_opponent is None")
        p2 = Player(9,
                    False,
                    policy=New_Dual_NN_Non_Rollout_OM_ISMCTSPolicy(
                        origin_model=opponent_net, cuda=cuda_flg),
                    mulligan=Min_cost_mulligan_policy())
    p2.name = "Bob"

    deck_list = tuple(map(int, args.deck_list.split(",")))
    print(deck_list)
    # Every ordered (p1 deck, p2 deck) pair, mirror matches included.
    test_deck_list = tuple(itertools.product(deck_list, deck_list))
    episode_num = evaluate_num
    match_num = len(test_deck_list)

    p1_name = p1.policy.name.replace("origin", args.model_name)
    if args.opponent_model_name is not None:
        p2_name = p2.policy.name.replace("origin", args.opponent_model_name)
    else:
        p2_name = p2.policy.name.replace("origin", args.model_name)

    # The manager is used as a context manager so its server process is shut
    # down once the counters have been copied out.
    with Manager() as manager:
        # Three integer counters per deck pair, filled in by multi_battle.
        shared_array = manager.Array("i", [0] * (3 * match_num))
        iter_data = [(p1, p2, shared_array, episode_num, p_id, test_deck_list)
                     for p_id in range(p_size)]
        freeze_support()
        print(p1_name)
        print(p2_name)
        with Pool(p_size, initializer=tqdm.set_lock,
                  initargs=(RLock(), )) as pool:  # max processes: 8
            # Results arrive through shared_array; return values are unused.
            pool.map(multi_battle, iter_data)
        print("\n" * (match_num + 1))
        counts = shared_array[:]  # one round trip to the manager
        print(shared_array)

    # Per pair: (counter 3*i+1, counter 3*i+2), reported below as WR and
    # first_WR. NOTE(review): first_WR is doubled, presumably because each
    # side moves first in half the games — confirm against multi_battle.
    battle_result = {
        deck_pair: tuple(counts[3 * index + 1:3 * index + 3])
        for index, deck_pair in enumerate(test_deck_list)
    }
    summary_lines = []
    for key in sorted(battle_result):
        wins, first_wins = battle_result[key]
        cell = "{}:WR:{:.2%},first_WR:{:.2%}".format(
            key, wins / episode_num, 2 * first_wins / episode_num)
        print(cell)
        summary_lines.append(cell)
    print(battle_result)

    # newline="" as recommended for csv.writer; lineterminator is explicit.
    with open(os.path.join(result_dir, result_name), "w", newline="") as f:
        writer = csv.writer(f, delimiter='\t', lineterminator='\n')
        writer.writerow(["{} vs {}".format(p1_name, p2_name)] +
                        [deck_id_2_name[deck_id] for deck_id in deck_list])
        for i in deck_list:
            writer.writerow([deck_id_2_name[i]] +
                            [battle_result[(i, j)] for j in deck_list])
        for cell in summary_lines:
            writer.writerow([cell])
# Example #2
# 0
def run_main():
    import subprocess
    from torch.utils.tensorboard import SummaryWriter
    print(args)
    p_size = cpu_num
    print("use cpu num:{}".format(p_size))
    print("w_d:{}".format(weight_decay))
    std_th = args.th

    loss_history = []

    cuda_flg = args.cuda is not None
    node_num = args.node_num
    net = New_Dual_Net(node_num, rand=args.rand, hidden_num=args.hidden_num[0])
    print(next(net.parameters()).is_cuda)

    if args.model_name is not None:
        PATH = 'model/' + args.model_name
        net.load_state_dict(torch.load(PATH))
    if torch.cuda.is_available() and cuda_flg:
        net = net.cuda()
        print(next(net.parameters()).is_cuda)
    net.zero_grad()
    epoch_interval = args.save_interval
    G = Game()

    episode_len = args.episode_num
    batch_size = args.batch_size
    iteration = args.iteration_num
    epoch_num = args.epoch_num
    import datetime
    t1 = datetime.datetime.now()
    print(t1)
    #print(net)
    prev_net = copy.deepcopy(net)
    optimizer = optim.Adam(net.parameters(), weight_decay=weight_decay)

    date = "{}_{}_{}_{}".format(t1.month, t1.day, t1.hour, t1.minute)
    LOG_PATH = "{}episode_{}nodes_deckids{}_{}/".format(
        episode_len, node_num, args.fixed_deck_ids, date)
    writer = SummaryWriter(log_dir="./logs/" + LOG_PATH)
    TAG = "{}_{}_{}".format(episode_len, node_num, args.fixed_deck_ids)
    early_stopper = EarlyStopping(patience=args.loss_th, verbose=True)
    th = args.WR_th
    last_updated = 0
    reset_count = 0
    min_loss = 100
    loss_th = args.loss_th

    for epoch in range(epoch_num):

        net.cpu()
        prev_net.cpu()
        net.share_memory()

        print("epoch {}".format(epoch + 1))
        t3 = datetime.datetime.now()
        R = New_Dual_ReplayMemory(100000)
        test_R = New_Dual_ReplayMemory(100000)
        episode_len = args.episode_num
        if args.greedy_mode is not None:
            p1 = Player(9,
                        True,
                        policy=Dual_NN_GreedyPolicy(origin_model=net),
                        mulligan=Min_cost_mulligan_policy())
            p2 = Player(9,
                        False,
                        policy=Dual_NN_GreedyPolicy(origin_model=net),
                        mulligan=Min_cost_mulligan_policy())
        else:
            p1 = Player(9,
                        True,
                        policy=New_Dual_NN_Non_Rollout_OM_ISMCTSPolicy(
                            origin_model=net,
                            cuda=False,
                            iteration=args.step_iter),
                        mulligan=Min_cost_mulligan_policy())
            p2 = Player(9,
                        False,
                        policy=New_Dual_NN_Non_Rollout_OM_ISMCTSPolicy(
                            origin_model=net,
                            cuda=False,
                            iteration=args.step_iter),
                        mulligan=Min_cost_mulligan_policy())

        p1.name = "Alice"
        p2.name = "Bob"
        manager = Manager()
        shared_value = manager.Value("i", 0)
        #iter_data = [[p1, p2,shared_value,single_iter,i] for i in range(double_p_size)]
        iter_data = [[p1, p2, shared_value, episode_len, i]
                     for i in range(p_size)]
        freeze_support()
        with Pool(p_size, initializer=tqdm.set_lock,
                  initargs=(RLock(), )) as pool:
            memory = pool.map(multi_preparation, iter_data)
        print("\n" * (p_size + 1))
        del p1
        del p2
        del iter_data
        battle_data = [cell.pop(-1) for cell in memory]

        sum_of_choice = max(
            sum([cell["sum_of_choices"] for cell in battle_data]), 1)
        sum_of_code = max(sum([cell["sum_code"] for cell in battle_data]), 1)
        win_num = sum([cell["win_num"] for cell in battle_data])
        sum_end_turn = sum([cell["end_turn"] for cell in battle_data])
        # [[result_data, result_data,...], [result_data, result_data,...],...]
        # result_data: 1対戦のデータ
        origin_memories = list(itertools.chain.from_iterable(memory))
        print(type(memory), type(origin_memories),
              int(episode_len * args.data_rate), len(origin_memories))
        memories = list(
            itertools.chain.from_iterable(
                origin_memories[:int(episode_len * args.data_rate)]))
        test_memories = list(
            itertools.chain.from_iterable(
                origin_memories[int(episode_len * args.data_rate):]))
        follower_attack_num = 0
        all_able_to_follower_attack = 0
        memos = [memories, test_memories]
        for i in range(2):
            for data in memos[i]:
                after_state = {
                    "hand_ids": data[0]['hand_ids'],
                    "hand_card_costs": data[0]['hand_card_costs'],
                    "follower_card_ids": data[0]['follower_card_ids'],
                    "amulet_card_ids": data[0]['amulet_card_ids'],
                    "follower_stats": data[0]['follower_stats'],
                    "follower_abilities": data[0]['follower_abilities'],
                    "able_to_evo": data[0]['able_to_evo'],
                    "life_data": data[0]['life_data'],
                    "pp_data": data[0]['pp_data'],
                    "able_to_play": data[0]['able_to_play'],
                    "able_to_attack": data[0]['able_to_attack'],
                    "able_to_creature_attack":
                    data[0]['able_to_creature_attack'],
                    "deck_data": data[0]['deck_data']
                }
                before_state = data[2]
                hit_flg = int(1 in data[3]['able_to_choice'][10:35])
                all_able_to_follower_attack += hit_flg
                follower_attack_num += hit_flg * int(data[1] >= 10
                                                     and data[1] <= 34)
                if i == 0:
                    R.push(after_state, data[1], before_state, data[3],
                           data[4])
                else:
                    test_R.push(after_state, data[1], before_state, data[3],
                                data[4])

        print("win_rate:{:.3%}".format(win_num / episode_len))
        print("mean_of_num_of_choice:{:.3f}".format(sum_of_choice /
                                                    sum_of_code))
        print("follower_attack_ratio:{:.3%}".format(
            follower_attack_num / max(1, all_able_to_follower_attack)))
        print("mean end_turn:{:.3f}".format(sum_end_turn / episode_len))
        print("train_data_size:{}".format(len(R.memory)))
        print("test_data_size:{}".format(len(test_R.memory)))
        net.train()
        prev_net = copy.deepcopy(net)

        p, pai, z, states = None, None, None, None
        batch = len(
            R.memory) // batch_num if batch_num is not None else batch_size
        print("batch_size:{}".format(batch))
        pass_flg = False
        if args.multi_train is not None:
            if last_updated > args.max_update_interval - 3:
                net = New_Dual_Net(node_num,
                                   rand=args.rand,
                                   hidden_num=args.hidden_num[0])
                reset_count += 1
                print("reset_num:", reset_count)
            p_size = min(args.cpu_num, 3)
            if cuda_flg:
                torch.cuda.empty_cache()
                net = net.cuda()
            net.share_memory()
            net.train()
            net.zero_grad()
            all_data = R.sample(batch_size,
                                all=True,
                                cuda=cuda_flg,
                                multi=args.multi_sample_num)
            all_states, all_actions, all_rewards = all_data
            memory_len = all_actions.size()[0]
            all_data_ids = list(range(memory_len))
            train_ids = random.sample(all_data_ids, k=memory_len)
            test_data = test_R.sample(batch_size,
                                      all=True,
                                      cuda=cuda_flg,
                                      multi=args.multi_sample_num)
            test_states, test_actions, test_rewards = test_data
            test_memory_len = test_actions.size()[0]
            test_data_range = list(range(test_memory_len))
            test_ids = list(range(test_memory_len))
            min_loss = [0, 0.0, 100, 100, 100]
            best_train_data = [100, 100, 100]
            w_list = args.w_list
            epoch_list = args.epoch_list
            next_net = net  #[copy.deepcopy(net) for k in range(len(epoch_list))]
            #[copy.deepcopy(net) for k in range(len(w_list))]
            #iteration_num = int(memory_len//batch)*iteration #(int(memory_len * 0.85) // batch)*iteration
            weight_scale = 0
            freeze_support()
            print("pid:", os.getpid())
            #cmd = "pgrep --parent {} | xargs kill -9".format(int(os.getpid()))
            #proc = subprocess.call( cmd , shell=True)
            with Pool(p_size, initializer=tqdm.set_lock,
                      initargs=(RLock(), )) as pool:
                for epoch_scale in range(len(epoch_list)):
                    target_net = copy.deepcopy(net)
                    target_net.train()
                    target_net.share_memory()
                    #print("weight_decay:",w_list[weight_scale])
                    print("epoch_num:", epoch_list[epoch_scale])
                    iteration_num = int(
                        memory_len / batch) * epoch_list[epoch_scale]
                    iter_data = [[
                        target_net, all_data, batch,
                        int(iteration_num / p_size), train_ids, i,
                        w_list[weight_scale]
                    ] for i in range(p_size)]
                    torch.cuda.empty_cache()
                    if p_size == 1:
                        loss_data = [multi_train(iter_data[0])]
                    else:
                        freeze_support()
                        loss_data = pool.map(multi_train, iter_data)
                        # pool.terminate()  # add this.
                        # pool.close()  # add this.
                        print("\n" * p_size)
                    sum_of_loss = sum(map(lambda data: data[0], loss_data))
                    sum_of_MSE = sum(map(lambda data: data[1], loss_data))
                    sum_of_CEE = sum(map(lambda data: data[2], loss_data))
                    train_overall_loss = sum_of_loss / iteration_num
                    train_state_value_loss = sum_of_MSE / iteration_num
                    train_action_value_loss = sum_of_CEE / iteration_num
                    print("AVE | Over_All_Loss(train): {:.3f} | MSE: {:.3f} | CEE:{:.3f}" \
                          .format(train_overall_loss, train_state_value_loss, train_action_value_loss))
                    #all_states, all_actions, all_rewards = all_data
                    test_ids_len = len(test_ids)
                    #separate_num = test_ids_len
                    separate_num = test_ids_len // batch
                    states_keys = tuple(
                        test_states.keys())  #tuple(all_states.keys())
                    value_keys = tuple(test_states['values'].keys()
                                       )  #tuple(all_states['values'].keys())
                    normal_states_keys = tuple(
                        set(states_keys) -
                        {'values', 'detailed_action_codes', 'before_states'})

                    action_code_keys = tuple(
                        test_states['detailed_action_codes'].keys())

                    target_net.eval()
                    iteration_num = int(memory_len // batch)
                    partition = test_memory_len // p_size
                    iter_data = [[
                        target_net, test_data, batch,
                        test_data_range[i *
                                        partition:min(test_memory_len -
                                                      1, (i + 1) * partition)],
                        i
                    ] for i in range(p_size)]
                    freeze_support()
                    loss_data = pool.map(multi_eval, iter_data)
                    print("\n" * p_size)
                    sum_of_loss = sum(map(lambda data: data[0], loss_data))
                    sum_of_MSE = sum(map(lambda data: data[1], loss_data))
                    sum_of_CEE = sum(map(lambda data: data[2], loss_data))
                    test_overall_loss = sum_of_loss / p_size
                    test_state_value_loss = sum_of_MSE / p_size
                    test_action_value_loss = sum_of_CEE / p_size
                    pass_flg = test_overall_loss > loss_th

                    print("AVE | Over_All_Loss(test ): {:.3f} | MSE: {:.3f} | CEE:{:.3f}" \
                          .format(test_overall_loss, test_state_value_loss, test_action_value_loss))

                    target_epoch = epoch_list[epoch_scale]
                    print("debug1:", target_epoch)
                    writer.add_scalars(
                        TAG + "/" + 'Over_All_Loss', {
                            'train:' + str(target_epoch): train_overall_loss,
                            'test:' + str(target_epoch): test_overall_loss
                        }, epoch)
                    writer.add_scalars(
                        TAG + "/" + 'state_value_loss', {
                            'train:' + str(target_epoch):
                            train_state_value_loss,
                            'test:' + str(target_epoch): test_state_value_loss
                        }, epoch)
                    writer.add_scalars(
                        TAG + "/" + 'action_value_loss', {
                            'train:' + str(target_epoch):
                            train_action_value_loss,
                            'test:' + str(target_epoch): test_action_value_loss
                        }, epoch)
                    print("debug2:", target_epoch)
                    if min_loss[
                            2] > test_overall_loss and test_overall_loss > train_overall_loss:
                        next_net = target_net
                        min_loss = [
                            epoch_scale, epoch_list[epoch_scale],
                            test_overall_loss, test_state_value_loss,
                            test_action_value_loss
                        ]
                        print("current best:", min_loss)
                print("finish training")
                pool.terminate()  # add this.
                pool.close()  # add this.
                #print(cmd)
                #proc = subprocess.call( cmd , shell=True)
            print("\n" * p_size + "best_data:", min_loss)
            net = next_net  #copy.deepcopy(next_nets[min_loss[0]])
            #del next_net
            loss_history.append(sum_of_loss / iteration)
            p_size = cpu_num

        else:
            prev_optimizer = copy.deepcopy(optimizer)
            if cuda_flg:
                net = net.cuda()
                prev_net = prev_net.cuda()
                optimizer = optim.Adam(net.parameters(),
                                       weight_decay=weight_decay)
                optimizer.load_state_dict(prev_optimizer.state_dict())
                #optimizer = optimizer.cuda()

            current_net = copy.deepcopy(
                net).cuda() if cuda_flg else copy.deepcopy(net)
            all_data = R.sample(batch_size, all=True, cuda=cuda_flg)
            all_states, all_actions, all_rewards = all_data
            states_keys = list(all_states.keys())
            value_keys = list(all_states['values'].keys())
            normal_states_keys = tuple(
                set(states_keys) -
                {'values', 'detailed_action_codes', 'before_states'})
            action_code_keys = list(all_states['detailed_action_codes'].keys())
            memory_len = all_actions.size()[0]
            all_data_ids = list(range(memory_len))
            train_ids = random.sample(all_data_ids, k=int(memory_len * 0.8))
            test_ids = list(set(all_data_ids) - set(train_ids))
            train_num = iteration * len(train_ids)
            nan_count = 0

            for i in tqdm(range(train_num)):
                key = random.sample(train_ids, k=batch)
                states = {}
                states.update({
                    dict_key: torch.clone(all_states[dict_key][key])
                    for dict_key in normal_states_keys
                })
                states['values'] = {sub_key: torch.clone(all_states['values'][sub_key][key]) \
                                    for sub_key in value_keys}
                states['detailed_action_codes'] = {
                    sub_key: torch.clone(
                        all_states['detailed_action_codes'][sub_key][key])
                    for sub_key in action_code_keys
                }
                orig_before_states = all_states["before_states"]
                states['before_states'] = {
                    dict_key: torch.clone(orig_before_states[dict_key][key])
                    for dict_key in normal_states_keys
                }
                states['before_states']['values'] = {sub_key: torch.clone(orig_before_states['values'][sub_key][key]) \
                                                     for sub_key in value_keys}

                actions = all_actions[key]
                rewards = all_rewards[key]

                states['target'] = {'actions': actions, 'rewards': rewards}
                net.zero_grad()
                optimizer.zero_grad()
                with detect_anomaly():
                    p, v, loss = net(states, target=True)
                    if True not in torch.isnan(loss[0]):
                        loss[0].backward()
                        optimizer.step()
                        current_net = copy.deepcopy(net)
                        prev_optimizer = copy.deepcopy(optimizer)
                    else:
                        if nan_count < 5:
                            print("loss:{}".format(nan_count))
                            print(loss)
                        net = current_net
                        optimizer = optim.Adam(net.parameters(),
                                               weight_decay=weight_decay)
                        optimizer.load_state_dict(prev_optimizer.state_dict())
                        nan_count += 1
            print("nan_count:{}/{}".format(nan_count, train_num))
            train_ids_len = len(train_ids)
            separate_num = train_ids_len
            train_objective_loss = 0
            train_MSE = 0
            train_CEE = 0
            nan_batch_num = 0

            for i in tqdm(range(separate_num)):
                key = [train_ids[i]]
                #train_ids[2*i:2*i+2] if 2*i+2 < train_ids_len else train_ids[train_ids_len-2:train_ids_len]
                states = {}
                states.update({
                    dict_key: torch.clone(all_states[dict_key][key])
                    for dict_key in normal_states_keys
                })
                states['values'] = {sub_key: torch.clone(all_states['values'][sub_key][key]) \
                                    for sub_key in value_keys}
                states['detailed_action_codes'] = {
                    sub_key: torch.clone(
                        all_states['detailed_action_codes'][sub_key][key])
                    for sub_key in action_code_keys
                }
                orig_before_states = all_states["before_states"]
                states['before_states'] = {
                    dict_key: torch.clone(orig_before_states[dict_key][key])
                    for dict_key in normal_states_keys
                }
                states['before_states']['values'] = {sub_key: torch.clone(orig_before_states['values'][sub_key][key]) \
                                                     for sub_key in value_keys}

                actions = all_actions[key]
                rewards = all_rewards[key]
                states['target'] = {'actions': actions, 'rewards': rewards}
                del loss
                torch.cuda.empty_cache()
                _, _, loss = net(states, target=True)
                if True in torch.isnan(loss[0]):
                    if nan_batch_num < 5:
                        print("loss")
                        print(loss)
                    separate_num -= 1
                    nan_batch_num += 1
                    continue
                train_objective_loss += float(loss[0].item())
                train_MSE += float(loss[1].item())
                train_CEE += float(loss[2].item())
            separate_num = max(1, separate_num)
            #writer.add_scalar(LOG_PATH + "WIN_RATE", win_num / episode_len, epoch)
            print("nan_batch_ids:{}/{}".format(nan_batch_num, train_ids_len))
            print(train_MSE, separate_num)
            train_objective_loss /= separate_num
            train_MSE /= separate_num
            train_CEE /= separate_num
            print("AVE(train) | Over_All_Loss: {:.3f} | MSE: {:.3f} | CEE:{:.3f}" \
                  .format(train_objective_loss,train_MSE,train_CEE))
            test_ids_len = len(test_ids)
            batch_len = 512 if 512 < test_ids_len else 128
            separate_num = test_ids_len // batch_len
            #separate_num = test_ids_len
            test_objective_loss = 0
            test_MSE = 0
            test_CEE = 0
            for i in tqdm(range(separate_num)):
                #key = [test_ids[i]]#test_ids[i*batch_len:min(test_ids_len,(i+1)*batch_len)]
                key = test_ids[i * batch_len:min(test_ids_len, (i * 1) *
                                                 batch_len)]
                states = {}
                states.update({
                    dict_key: torch.clone(all_states[dict_key][key])
                    for dict_key in normal_states_keys
                })
                states['values'] = {sub_key: torch.clone(all_states['values'][sub_key][key]) \
                                    for sub_key in value_keys}
                states['detailed_action_codes'] = {
                    sub_key: torch.clone(
                        all_states['detailed_action_codes'][sub_key][key])
                    for sub_key in action_code_keys
                }
                orig_before_states = all_states["before_states"]
                states['before_states'] = {
                    dict_key: torch.clone(orig_before_states[dict_key][key])
                    for dict_key in normal_states_keys
                }
                states['before_states']['values'] = {sub_key: torch.clone(orig_before_states['values'][sub_key][key]) \
                                                     for sub_key in value_keys}
                actions = all_actions[key]
                rewards = all_rewards[key]
                states['target'] = {'actions': actions, 'rewards': rewards}
                del loss
                torch.cuda.empty_cache()
                p, v, loss = net(states, target=True)
                if True in torch.isnan(loss[0]):
                    separate_num -= 1
                    continue
                test_objective_loss += float(loss[0].item())
                test_MSE += float(loss[1].item())
                test_CEE += float(loss[2].item())
            print("")
            for batch_id in range(1):
                print("states:{}".format(batch_id))
                print("p:{}".format(p[batch_id]))
                print("pi:{}".format(actions[batch_id]))
                print("v:{} z:{}".format(v[batch_id], rewards[batch_id]))
            del p, v
            del actions
            del all_data
            del all_states
            del all_actions
            del all_rewards
            separate_num = max(1, separate_num)
            print(test_MSE, separate_num)
            test_objective_loss /= separate_num
            test_MSE /= separate_num
            test_CEE /= separate_num
            writer.add_scalars(LOG_PATH + 'Over_All_Loss', {
                'train': train_objective_loss,
                'test': test_objective_loss
            }, epoch)
            writer.add_scalars(LOG_PATH + 'MSE', {
                'train': train_MSE,
                'test': test_MSE
            }, epoch)
            writer.add_scalars(LOG_PATH + 'CEE', {
                'train': train_CEE,
                'test': test_CEE
            }, epoch)
            print("AVE | Over_All_Loss: {:.3f} | MSE: {:.3f} | CEE:{:.3f}" \
                  .format(test_objective_loss, test_MSE, test_CEE))

            loss_history.append(test_objective_loss)
            if early_stopper.validate(test_objective_loss): break

        print("evaluate step")
        del R
        del test_R
        net.cpu()
        prev_net.cpu()
        print("evaluate ready")
        if pass_flg:
            min_WR = 0
            WR = 0
            print("evaluation of this epoch is passed.")
        else:
            if args.greedy_mode is not None:
                p1 = Player(9,
                            True,
                            policy=Dual_NN_GreedyPolicy(origin_model=net),
                            mulligan=Min_cost_mulligan_policy())
                p2 = Player(9,
                            False,
                            policy=Dual_NN_GreedyPolicy(origin_model=net),
                            mulligan=Min_cost_mulligan_policy())
            else:
                p1 = Player(9,
                            True,
                            policy=New_Dual_NN_Non_Rollout_OM_ISMCTSPolicy(
                                origin_model=net,
                                cuda=False,
                                iteration=args.step_iter),
                            mulligan=Min_cost_mulligan_policy())

                p2 = Player(9,
                            False,
                            policy=New_Dual_NN_Non_Rollout_OM_ISMCTSPolicy(
                                origin_model=prev_net,
                                cuda=False,
                                iteration=args.step_iter),
                            mulligan=Min_cost_mulligan_policy())

            p1.name = "Alice"
            p2.name = "Bob"
            test_deck_list = tuple(
                100, ) if deck_flg is None else deck_flg  # (0,1,4,10,13)
            test_deck_list = tuple(
                itertools.product(test_deck_list, test_deck_list))
            test_episode_len = evaluate_num  #100
            match_num = len(test_deck_list)

            manager = Manager()
            shared_array = manager.Array(
                "i", [0 for _ in range(3 * len(test_deck_list))])
            #iter_data = [(p1, p2,test_episode_len, p_id ,cell) for p_id,cell in enumerate(deck_pairs)]
            iter_data = [(p1, p2, shared_array, test_episode_len, p_id,
                          test_deck_list) for p_id in range(p_size)]

            freeze_support()
            with Pool(p_size, initializer=tqdm.set_lock,
                      initargs=(RLock(), )) as pool:
                _ = pool.map(multi_battle, iter_data)
            print("\n" * (match_num + 1))
            del iter_data
            del p1
            del p2
            match_num = len(test_deck_list)  #if deck_flg is None else p_size
            min_WR = 1.0
            Battle_Result = {(deck_id[0], deck_id[1]): \
                                 tuple(shared_array[3*index+1:3*index+3]) for index, deck_id in enumerate(test_deck_list)}
            #for memory_cell in memory:
            #    #Battle_Result[memory_cell[0]] = memory_cell[1]
            #    #min_WR = min(min_WR,memory_cell[1])
            print(shared_array)
            result_table = {}
            for key in sorted(list((Battle_Result.keys()))):
                cell_WR = Battle_Result[key][0] / test_episode_len
                cell_first_WR = 2 * Battle_Result[key][1] / test_episode_len
                print("{}:train_WR:{:.2%},first_WR:{:.2%}"\
                      .format(key,cell_WR,cell_first_WR))
                if key[::-1] not in result_table:
                    result_table[key] = cell_WR
                else:
                    result_table[key[::-1]] = (result_table[key[::-1]] +
                                               cell_WR) / 2
            print(result_table)
            min_WR = min(result_table.values())
            WR = sum(result_table.values()) / len(result_table.values())
            del result_table

        win_flg = False
        #WR=1.0
        writer.add_scalars(TAG + "/" + 'win_rate', {
            'mean': WR,
            'min': min_WR,
            'threthold': th
        }, epoch)
        if WR >= th or (len(deck_flg) > 1 and min_WR > 0.5):
            win_flg = True
            print("new_model win! WR:{:.1%} min:{:.1%}".format(WR, min_WR))
        else:
            del net
            net = None
            net = prev_net
            print("new_model lose... WR:{:.1%}".format(WR))
        torch.cuda.empty_cache()
        t4 = datetime.datetime.now()
        print(t4 - t3)
        # or (epoch_num > 4 and (epoch+1) % epoch_interval == 0 and epoch+1 < epoch_num)
        if win_flg and last_updated > 0:
            PATH = "model/{}_{}_{}in{}_{}_nodes.pth".format(
                t1.month, t1.day, epoch + 1, epoch_num, node_num)
            if torch.cuda.is_available() and cuda_flg:
                PATH = "model/{}_{}_{}in{}_{}_nodes_cuda.pth".format(
                    t1.month, t1.day, epoch + 1, epoch_num, node_num)
            torch.save(net.state_dict(), PATH)
            print("{} is saved.".format(PATH))
            last_updated = 0
        else:
            last_updated += 1
            print("last_updated:", last_updated)
            if last_updated > args.max_update_interval:
                print("update finished.")
                break
        if len(loss_history) > epoch_interval - 1:
            #UB = np.std(loss_history[-epoch_interval:-1])/(np.sqrt(2*epoch) + 1)
            UB = np.std(loss_history) / (np.sqrt(epoch) + 1)
            print("{:<2} std:{}".format(epoch, UB))
            if UB < std_th:
                break

    writer.close()
    #pool.terminate()
    #pool.close()
    print('Finished Training')

    PATH = "model/{}_{}_finished_{}_nodes.pth".format(t1.month, t1.day,
                                                      node_num)
    if torch.cuda.is_available() and cuda_flg:
        PATH = "model/{}_{}_finished_{}_nodes_cuda.pth".format(
            t1.month, t1.day, node_num)
    torch.save(net.state_dict(), PATH)
    print("{} is saved.".format(PATH))
    t2 = datetime.datetime.now()
    print(t2)
    print(t2 - t1)