Example no. 1

# Imports reconstructed so the snippet is self-contained; the module paths
# below are assumptions based on the AlphaZero_Gomoku project layout.
import numpy as np
from pyemd import emd  # assumed EMD backend; the float64 casts below match pyemd's API
from policy_value_net_pytorch import PolicyValueNet  # assumed project module
# generate_matrix_dist_metric, keep_k_squares, normalize_matrix, and
# save_trimmed_policy_heatmap are assumed to be helpers defined elsewhere
# in this module.

def EMD_between_two_models_on_board(model1_name,
                                    input_plains_num_1,
                                    i1,
                                    model2_name,
                                    input_plains_num_2,
                                    i2,
                                    board1,
                                    board2,
                                    width=6,
                                    height=6,
                                    use_gpu=True):
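    """Load the two saved policies and return the Earth Mover's Distance
    (EMD) between their move-probability distributions on the given boards.
    """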
    model_file_1 = f'/home/lirontyomkin/AlphaZero_Gomoku/models/{model1_name}/current_policy_{i1}.model'
    policy_1 = PolicyValueNet(width,
                              height,
                              model_file=model_file_1,
                              input_plains_num=input_plains_num_1,
                              use_gpu=use_gpu)

    model_file_2 = f'/home/lirontyomkin/AlphaZero_Gomoku/models/{model2_name}/current_policy_{i2}.model'
    policy_2 = PolicyValueNet(width,
                              height,
                              model_file=model_file_2,
                              input_plains_num=input_plains_num_2,
                              use_gpu=use_gpu)

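    # Note: the board states are computed here but not consumed below.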
    board_current_state1 = board1.current_state(last_move=True,
                                                is_random_last_turn=False)
    board_current_state2 = board2.current_state(last_move=True,
                                                is_random_last_turn=False)

    acts_policy1, probas_policy1 = zip(*policy_1.policy_value_fn(board1)[0])
    acts_policy2, probas_policy2 = zip(*policy_2.policy_value_fn(board2)[0])

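    # Pairwise ground-distance matrix between board squares, used as the
    # EMD ground metric (exact metric defined by generate_matrix_dist_metric).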
    dist_matrix = generate_matrix_dist_metric(width)

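    # pyemd expects float64 histograms, hence the explicit casts
    # (assuming emd comes from pyemd).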
    distance = emd(np.asarray(probas_policy1, dtype='float64'),
                   np.asarray(probas_policy2, dtype='float64'), dist_matrix)

    return distance
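
# A minimal usage sketch, not part of the original repo: the Board
# constructor follows the AlphaZero_Gomoku game.Board API, and the model
# names and iteration numbers are hypothetical placeholders.
def _example_emd_between_models():
    from game import Board  # assumed project module

    board = Board(width=6, height=6, n_in_row=4)
    board.init_board()

    distance = EMD_between_two_models_on_board(
        model1_name="pt_6_6_4_p3_v7",   # hypothetical model name
        input_plains_num_1=3,
        i1=50,                          # hypothetical iteration
        model2_name="pt_6_6_4_p4_v10",  # hypothetical model name
        input_plains_num_2=4,
        i2=50,
        board1=board,
        board2=board)
    print(f"EMD between the two policies: {distance:.4f}")
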
def threshold_cutoff_policy(board,
                            model_name,
                            input_plains_num,
                            model_iteration,
                            open_path_threshold,
                            opponent_weight,
                            rounding=-1,
                            cutoff_threshold=0.05,
                            model_file=None,
                            is_random_last_turn=False,
                            board_name=" "):
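    """Return the model's move-probability heatmap for board, trimmed either
    by zeroing probabilities below a float cutoff_threshold < 1 or by keeping
    only the top-k squares when cutoff_threshold is an integer >= 1.
    """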

    # print(f"{model_name}_{model_iteration}, {open_path_threshold}, {opponent_weight}, {cutoff_threshold}")

    # Tag the model name so results for the random-last-turn variant are
    # kept separate from the regular ones.
    if is_random_last_turn:
        model_name_tag = model_name + "_random"
    else:
        model_name_tag = model_name

    width, height = board.width, board.height

    if model_file is None:
        model_file = f'/home/lirontyomkin/AlphaZero_Gomoku/models/{model_name}/current_policy_{model_iteration}.model'

    policy = PolicyValueNet(width,
                            height,
                            model_file=model_file,
                            input_plains_num=input_plains_num)
    board.set_is_random_last_turn(is_random_last_turn=is_random_last_turn,
                                  player=board.get_current_player())

    if is_random_last_turn:
        board.set_random_seed(model_iteration)

    acts_policy, probas_policy = zip(*policy.policy_value_fn(board)[0])

    # The network also assigns probability mass to unavailable squares, so
    # the probabilities over the legal moves may not sum to 1; renormalize
    # (guarding against an all-zero vector).
    probas_policy = np.asarray(probas_policy, dtype='float64')
    if probas_policy.sum() != 0:
        probas_policy = probas_policy / probas_policy.sum()

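    # Scatter the per-move probabilities onto a (width x height) grid and flip
    # it vertically so the heatmap matches the board orientation.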
    move_probs_policy = np.zeros(width * height)
    move_probs_policy[list(acts_policy)] = probas_policy
    move_probs_policy = move_probs_policy.reshape(width, height)
    move_probs_policy = np.flipud(move_probs_policy)

    # A float cutoff_threshold < 1 zeroes out every move below that
    # probability; an integer cutoff_threshold >= 1 keeps only the k most
    # probable squares.
    if cutoff_threshold < 1:
        move_probs_policy[move_probs_policy < cutoff_threshold] = 0
        heatmap_save_path = f"/home/lirontyomkin/AlphaZero_Gomoku/models_heatmaps/cutoff_threshold_{cutoff_threshold}/{model_name_tag}/iteration_{model_iteration}/"

    elif isinstance(cutoff_threshold, int):
        move_probs_policy = keep_k_squares(move_probs_policy, cutoff_threshold,
                                           board.height, board.width)
        heatmap_save_path = f"/home/lirontyomkin/AlphaZero_Gomoku/models_heatmaps/keep_{cutoff_threshold}_squares/{model_name_tag}/iteration_{model_iteration}/"

    move_probs_policy = normalize_matrix(move_probs_policy, board, rounding)

    # if not os.path.exists(heatmap_save_path):
    #     os.makedirs(heatmap_save_path)

    # # Make sure you only save once:
    # if open_path_threshold == -1 and opponent_weight == 0:
    #     save_trimmed_policy_heatmap(move_probs_policy, model_name, board, board_name, heatmap_save_path)

    return move_probs_policy
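
# A minimal usage sketch, not part of the original repo: the model name,
# iteration, and threshold values are hypothetical placeholders.
def _example_threshold_cutoff():
    from game import Board  # assumed project module

    board = Board(width=6, height=6, n_in_row=4)
    board.init_board()

    # Float < 1: zero out moves whose probability is below 5%.
    trimmed = threshold_cutoff_policy(board,
                                      model_name="pt_6_6_4_p3_v7",  # hypothetical
                                      input_plains_num=3,
                                      model_iteration=50,           # hypothetical
                                      open_path_threshold=-1,
                                      opponent_weight=0,
                                      cutoff_threshold=0.05)

    # Integer >= 1: keep only the 5 most probable squares instead.
    top_5 = threshold_cutoff_policy(board,
                                    model_name="pt_6_6_4_p3_v7",  # hypothetical
                                    input_plains_num=3,
                                    model_iteration=50,
                                    open_path_threshold=-1,
                                    opponent_weight=0,
                                    cutoff_threshold=5)
    print(trimmed)
    print(top_5)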