def q3_learn_on(gamma, alpha, ep_num, ep_len, eps, pond: FishPond, Q_asterisk):
    """
    creates episodes for Q - learning with action choosing using the learned Q
    :param ep_num: the number of episodes to create
    :param ep_len: the length of each episode to create
    :param eps: epsilon rate for the greedy policy
    :param Q_asterisk: optimal Q
    """
    err, Q = q3_learn_off(gamma, alpha, 1, ep_len, eps, pond, Q_asterisk)
    for e in range(ep_num - 1):
        pond.reset()
        print(e)
        for j in range(ep_len):
            x, y = pond.current_state
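            # act greedily with respect to the current Q estimate (no exploration)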
            n_a = np.argmax(Q[x, y, :])
            action = actions[n_a]
            reached_end = pond.perform_action(action)
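            # reward: -1 per step, 0 upon reaching the end state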
            r_s_a = 0 if reached_end else -1
            x_tag, y_tag = pond.current_state
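            # tabular Q-learning update: Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))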
            Q[x, y, n_a] = Q[x, y, n_a] + alpha * (
                r_s_a + gamma * np.max(Q[x_tag, y_tag, :]) - Q[x, y, n_a])
            if reached_end:
                break
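        # record the MSE against the optimal Q after each episode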
        err = np.append(err, get_MSE(Q, Q_asterisk))
    return err, Q


def q0_sample_policy(pond: FishPond):
    """Sample code for running a policy and plotting the trajectory."""
    for i in range(30):
        action = the_right_policy()
        reached_end = pond.perform_action(action)
        pond.plot()
        if reached_end:
            break
    print('Done')
    plt.savefig('Q0_' + pond_name + '.png')
    plt.show()


def q3_play_a_game(pond: FishPond, pi, section):
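    """Runs a single game following the tabular policy pi (state -> action
    index) and saves the trajectory plot; section tags the output filename."""
    # the Manhattan distance from start to end bounds the shortest path length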
    path_len = (np.abs(pond.start_state[0] - pond.end_state[0]) +
                np.abs(pond.start_state[1] - pond.end_state[1]))
    pond.reset()
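    # give the policy up to three times the shortest-path length to finish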
    for i in range(3 * path_len):
        action = actions[pi[pond.current_state[0], pond.current_state[1]]]
        reached_end = pond.perform_action(action)
        pond.plot()
        if reached_end:
            break
    print('Done')
    plt.savefig('Q3_' + pond_name + '_' + section + '_game.png')
    plt.show()


def q1_greedy_policy(pond: FishPond):
    """given a fishpond, runs a game with respect to the greedy policy as implemented in the function
    the_greedy_policy"""
    path_len = (np.abs(pond.start_state[0] - pond.end_state[0]) +
                np.abs(pond.start_state[1] - pond.end_state[1]))
    pond.reset()
    # a single episode, capped at three times the shortest-path length
    for i in range(3 * path_len):
        action = the_greedy_policy(pond.current_state, pond.end_state)
        reached_end = pond.perform_action(action)
        pond.plot()
        if reached_end:
            break
    print('Done')
    plt.savefig('Q1_' + pond_name + '.png')
    plt.show()


def q2_greedy_policy(pond: FishPond, gamma):
    """given a fishpond, runs a single game with the policy computed with the policy iteration procedure implemented
    in the function q2_learn_q_phi"""
    path_len = (np.abs(pond.start_state[0] - pond.end_state[0]) +
                np.abs(pond.start_state[1] - pond.end_state[1]))
    pi, _ = q2_learn_q_phi(pond, gamma)
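    # pi holds per-state action indices; cast so it can index into `actions`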
    pi = pi.astype(np.int64)
    pond.reset()
    for i in range(3 * path_len):
        action = actions[pi[pond.current_state[0], pond.current_state[1]]]
        reached_end = pond.perform_action(action)
        pond.plot()
        if reached_end:
            break
    print('Done')
    plt.savefig('Q2_' + pond_name + '.png')
    plt.show()


def q3_learn_off(gamma, alpha, ep_num, ep_len, eps, pond: FishPond,
                 Q_asterisk):
    """
    creates episodes for Q - learning with action choosing with greedy policy
    :param ep_num: the number of episodes to create
    :param ep_len: the length of each episode to create
    :param eps: epsilon rate for the greedy policy
    """
    Q = np.zeros((pond.pond_size[0], pond.pond_size[1], 4))
    err = np.empty(0)
    for e in range(ep_num):
        pond.reset()
        print(e)
        for j in range(ep_len):
            x, y = pond.current_state
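            # epsilon-greedy: the hand-crafted greedy action w.p. 1-eps, a random action w.p. eps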
            a1 = the_greedy_policy(pond.current_state, pond.end_state)
            a2 = np.random.choice(actions)
            action = np.random.choice([a1, a2], 1, p=[1 - eps, eps])[0]
            n_a = numed_actions[action]
            reached_end = pond.perform_action(action)
            r_s_a = 0 if reached_end else -1
            x_tag, y_tag = pond.current_state
            Q[x, y, n_a] = Q[x, y, n_a] + alpha * (
                r_s_a + gamma * np.max(Q[x_tag, y_tag, :]) - Q[x, y, n_a])
            if reached_end:
                break
        err = np.append(err, get_MSE(Q, Q_asterisk))
    return err, Q
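

# A minimal, self-contained sanity check of the tabular Q-learning update used
# above. This is a sketch, independent of FishPond: a deterministic 1-D
# corridor with the same reward scheme (-1 per step, 0 at the goal). The grid
# size and hyperparameters below are illustrative assumptions, not values from
# the exercise.
def _q_learning_corridor_demo(n=5, gamma=0.9, alpha=0.5, eps=0.1, episodes=200):
    rng = np.random.default_rng(0)
    Q = np.zeros((n, 2))  # states 0..n-1; actions: 0 = left, 1 = right
    for _ in range(episodes):
        s = 0
        while s != n - 1:
            # epsilon-greedy exploration over the learned Q
            a = int(rng.integers(2)) if rng.random() < eps else int(np.argmax(Q[s]))
            s_next = max(s - 1, 0) if a == 0 else s + 1
            r = 0 if s_next == n - 1 else -1
            # the same update rule as in q3_learn_off / q3_learn_on
            Q[s, a] += alpha * (r + gamma * np.max(Q[s_next]) - Q[s, a])
            s = s_next
    # greedily following the returned Q should now move right in every state
    return Q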