    def test_epsilon_greedy_policy_sampler(self):

        epsilon: float = 0.1
        action_value_fcn_dict: Dict[int, Dict[str, Union[float, int]]] = dict()
        action_value_fcn_dict[0] = dict(a=1, b=3, c=-1)
        action_value_fcn_dict[1] = dict(a=3, b=1)

        logger.info(get_pretty_json_str(action_value_fcn_dict))

        epsilon_greedy_policy_sampler = EpsilonGreedyPolicySampler(
            epsilon, action_value_fcn_dict)

        N = 10000
        empirical_action_value_fcn_dict: Dict[int, Dict[str, float]] = dict()
        for state in epsilon_greedy_policy_sampler.get_all_states():
            logger.info(f"state: {state}")

            action_value_fcn_for_one_state: Dict[str, float] = defaultdict(int)
            for _ in range(N):
                action_value_fcn_for_one_state[
                    epsilon_greedy_policy_sampler.get_action(state)] += 1.0

            for action in action_value_fcn_for_one_state:
                action_value_fcn_for_one_state[action] /= N

            empirical_action_value_fcn_dict[
                state] = action_value_fcn_for_one_state

        logger.info(get_pretty_json_str(empirical_action_value_fcn_dict))
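
        # A hedged extra check (sketch): assuming EpsilonGreedyPolicySampler
        # spreads the exploration probability epsilon uniformly over all
        # actions (the textbook formulation), the greedy action should be
        # sampled with frequency close to 1 - epsilon + epsilon / |A| and
        # every other action with frequency close to epsilon / |A|.
        for state, action_values in action_value_fcn_dict.items():
            greedy_action = max(action_values, key=action_values.get)
            num_actions = len(action_values)
            for action in action_values:
                expected_probability = epsilon / num_actions + (
                    1.0 - epsilon if action == greedy_action else 0.0)
                self.assertAlmostEqual(
                    empirical_action_value_fcn_dict[state].get(action, 0.0),
                    expected_probability, 1)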

    def test_generate_app_navigation_graph(self):

        shortest_path_length = 10

        app_nav_graph: AppNavigationGraph = generate_app_navigation_graph(shortest_path_length)

        figure: Figure
        axis: Axes

        figure, axis = plt.subplots()
        nx.draw(app_nav_graph.directed_graph, ax=axis, pos=app_nav_graph.pos)
        nx.draw_networkx_labels(
            app_nav_graph.directed_graph, ax=axis, pos=app_nav_graph.pos
        )

        axis.axis([-0.1, 1.1, -0.1, 1.1])

        figure.show()

        json_obj: dict = app_navigation_graph_to_json_obj(app_nav_graph)

        logger.debug("json_obj from graph")
        logger.debug(get_pretty_json_str(json_obj))

        # state_transition_graph_json_file_path = os.path.join(
        #    TestAppNavGraph.test_data_directory,
        #    "large_deterministic_state_transition_graph.json",
        # )

        # with open(state_transition_graph_json_file_path, 'w') as fout:
        #    fout.write(get_pretty_json_str(json_obj) + "\n")
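
        # A hedged sanity check (sketch): the pretty-printed representation of
        # json_obj should parse back as a JSON object.
        import json  # local import, used only by this sketch
        self.assertIsInstance(json.loads(get_pretty_json_str(json_obj)), dict)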

        self.assertEqual(True, True)

    def test_policy_with_simple_example(self):
        # test ProbabilisticPolicy

        state_action_probability_dict_dict = dict()

        state_action_probability_dict_dict[0] = dict(a=0.3, b=0.7)
        state_action_probability_dict_dict[1] = dict(a=0.8, b=0.2)

        policy = ProbabilisticPolicy(state_action_probability_dict_dict)

        logger.info(get_pretty_json_str(state_action_probability_dict_dict))

        N = 10000

        empirical_state_action_probability_dict_dict: Dict[int, Dict[
            str, float]] = dict()
        for state in state_action_probability_dict_dict:
            action_probability_dict = defaultdict(int)
            for _ in range(N):
                action_probability_dict[policy.get_action(state)] += 1

            for action in action_probability_dict:
                action_probability_dict[action] /= N

            empirical_state_action_probability_dict_dict[
                state] = action_probability_dict

        logger.info(
            get_pretty_json_str(empirical_state_action_probability_dict_dict))

        state: int
        action: str
        action_probability_dict: Dict[str, float]

        for state, action_probability_dict in (
                empirical_state_action_probability_dict_dict.items()):
            for action, probability in action_probability_dict.items():
                self.assertAlmostEqual(
                    state_action_probability_dict_dict[state][action],
                    probability, 1)
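
        # A hedged extra check (sketch): the empirical frequencies recorded for
        # each state should sum to (approximately) one, since every sampled
        # action is counted exactly once.
        for action_probability_dict in (
                empirical_state_action_probability_dict_dict.values()):
            self.assertAlmostEqual(
                sum(action_probability_dict.values()), 1.0, 5)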

    def test_with_random_walk_environment(self) -> None:

        # fix the NumPy global random seed so the test is reproducible
        numpy.random.seed(760104)

        num_nodes: int = TestOneStepTemporalDifferenceAlgorithm.num_nodes

        random_walk_environment: RandomWalkEnvironment = RandomWalkEnvironment(
            num_nodes)
        random_policy: EquallyProbableRandomPolicySampler = EquallyProbableRandomPolicySampler(
            ("left", "right"))

        gamma: float = 1.0

        if TestOneStepTemporalDifferenceAlgorithm.diminishing_step_size:
            # use a learning rate that decays with the iteration count
            def learning_rate_strategy(iter_num: int,
                                       episode_num: int) -> float:
                return 0.1 / (1.0 + iter_num * 0.001)
        else:
            # use a constant learning rate
            learning_rate_strategy: float = 0.1

        # the third constructor argument is the default (initial) state value
        one_step_temporal_difference_algorithm: OneStepTemporalDifferenceAlgorithm = (
            OneStepTemporalDifferenceAlgorithm(gamma, learning_rate_strategy,
                                               0.5))
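        # A hedged note: the predictor above is expected to apply the textbook
        # one-step TD(0) update after every observed transition,
        #     V(s) <- V(s) + alpha * (r + gamma * V(s') - V(s)),
        # though the exact details live inside OneStepTemporalDifferenceAlgorithm.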

        figure: Figure
        axis: Axes
        figure, axis = plt.subplots()

        max_num_episodes: int = 300
        max_num_transitions_per_episode: int = 100
        max_num_iters: int = max_num_episodes * 100

        total_num_episodes: int = 0
        for _ in range(5):
            one_step_temporal_difference_algorithm.predict(
                random_walk_environment,
                random_policy,
                max_num_episodes,
                max_num_iters,
                max_num_transitions_per_episode,
            )

            total_num_episodes += max_num_episodes

            state_value_fcn_dict: Dict[Any, float] = (
                one_step_temporal_difference_algorithm.get_state_value_fcn_dict())

            logger.debug(get_pretty_json_str(state_value_fcn_dict))

            random_walk_environment.draw_state_value_fcn_values(
                axis, state_value_fcn_dict, "o-", label=total_num_episodes)

        axis.legend()

        figure.show()

        logger.debug(
            get_pretty_json_str(
                one_step_temporal_difference_algorithm.state_value_fcn_dict))

        node: int
        err_list: List[float] = list()
        for node in range(1, num_nodes + 1):
            estimated_state_value_fcn_value: float = (
                one_step_temporal_difference_algorithm.state_value_fcn_dict[node])
            # for this random walk, the true state value is node / (num_nodes + 1)
            true_state_value_fcn_value: float = float(node) / (num_nodes + 1)

            logger.debug(
                f"{estimated_state_value_fcn_value} ~ {true_state_value_fcn_value}"
            )

            err: float = estimated_state_value_fcn_value - true_state_value_fcn_value
            logger.debug(err)
            err_list.append(err)

            self.assertAlmostEqual(estimated_state_value_fcn_value,
                                   true_state_value_fcn_value, 1)

        logger.debug(err_list)
        max_abs_error: float = np.abs(err_list).max()
        logger.debug(max_abs_error)
        self.assertAlmostEqual(max_abs_error, 0.04792, 5)

    def test_neighborhood_mapping_with_simple_example(self) -> None:

        deterministic_directed_graph_environment_json_dict: Dict[str, Any]
        with open(TestNeighborhoodMapping.
                  simple_deterministic_state_transition_graph_json_file_path
                  ) as fin:
            deterministic_directed_graph_environment_json_dict = json.load(fin)

        deterministic_directed_graph_environment: DeterministicDirectedGraphEnvironment = (
            create_deterministic_directed_graph_environment_from_json_obj(
                deterministic_directed_graph_environment_json_dict))

        logger.debug(
            f"environment start state: {deterministic_directed_graph_environment.start_state}"
        )

        with open(TestNeighborhoodMapping.simple_action_sequence_json_file_path
                  ) as fin:
            action_sequence_json_obj: dict = json.load(fin)

        action_value_fcn_dict_from_action_sequence: Dict[Any, Dict[
            Any, float]] = (action_sequence_json_obj_to_action_value_fcn_dict(
                action_sequence_json_obj))

        logger.debug(
            get_pretty_json_str(deterministic_directed_graph_environment.
                                state_transition_graph_dict))
        logger.debug(
            get_pretty_json_str(deterministic_directed_graph_environment.
                                state_transition_reward_dict))
        logger.debug(get_pretty_json_str(action_sequence_json_obj))
        logger.debug(
            get_pretty_json_str(action_value_fcn_dict_from_action_sequence))

        gamma: float = 0.9
        learning_rate: float = 0.1
        epsilon: float = 0.1
        default_state_value_fcn_value: float = 0.0

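        # Overall flow of this test: build an epsilon-greedy policy from the
        # recorded action sequence, run one-step TD(0) prediction to estimate
        # state values, and then use those state values to seed a one-step
        # Q-learning run on the same environment.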
        action_sequence_policy: EpsilonGreedyPolicySampler = EpsilonGreedyPolicySampler(
            epsilon, action_value_fcn_dict_from_action_sequence)

        one_step_temporal_difference_alg: OneStepTemporalDifferenceAlgorithm = OneStepTemporalDifferenceAlgorithm(
            gamma, learning_rate, default_state_value_fcn_value)

        one_step_temporal_difference_alg.predict(
            deterministic_directed_graph_environment,
            action_sequence_policy,
            100,
            10000,
            100,
            True,
            True,
            True,
        )

        figure: Figure
        axis: Axes

        figure, axis = plt.subplots()
        one_step_temporal_difference_alg.plot_value_fcn_history(axis)
        figure.show()

        logger.info(
            get_pretty_json_str(
                one_step_temporal_difference_alg.state_value_fcn_dict))

        default_action_value_fcn_value: float = np.array(
            list(one_step_temporal_difference_alg.state_value_fcn_dict.values()),
            float).mean()
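
        # A hedged note on the choice just above: seeding the Q-learning
        # agent's default action value with the mean of the TD-predicted state
        # values presumably starts unvisited (state, action) pairs near the
        # typical value scale of this environment rather than at zero.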

        one_step_q_learning_alg: OneStepQLearningAlgorithm = OneStepQLearningAlgorithm(
            gamma, learning_rate, 0.0, default_action_value_fcn_value)

        if TestNeighborhoodMapping.do_initialize:
            initialize_action_value_fcn_from_state_value_fcn(
                one_step_q_learning_alg.action_value_fcn_dict,
                one_step_temporal_difference_alg.state_value_fcn_dict,
            )

        one_step_q_learning_alg.learn(
            deterministic_directed_graph_environment,
            100,
            10000,
            100,
            10,
            True,
            False,
            False,
        )

        figure, axis = plt.subplots()
        one_step_q_learning_alg.plot_value_fcn_history(axis)
        figure.show()

        logger.info(
            get_pretty_json_str(one_step_q_learning_alg.action_value_fcn_dict))

        self.assertTrue(True)