示例#1
0
    def get_policy_graph_values(self) -> common.Graph2DValues:
        """Build a graph of the target policy's stake for each non-terminal state.

        Terminal states are skipped. The x series holds each state's capital and
        the single graph series holds the stake of the action that
        ``self.algorithm.target_policy`` returns for that state's index.

        Returns:
            A ``common.Graph2DValues`` titled "Policy" with x-axis "Capital"
            (0 to 100) and y-axis "Stake" (lower bound 0, upper bound None).
        """
        policy: TabularPolicy = self.algorithm.target_policy

        x_list: list[int] = []
        y_list: list[float] = []
        for s, state in enumerate(self.environment.states):
            if state.is_terminal:
                continue
            x_list.append(state.capital)
            # The policy is queried by state index s, not by the state object.
            action: Action = policy.get_action(s)  # type: ignore
            y_list.append(float(action.stake))
        x_values = np.array(x_list, dtype=int)
        y_values = np.array(y_list, dtype=float)

        g: common.Graph2DValues = common.Graph2DValues()
        g.title = "Policy"
        # The axis labels must be set BEFORE the series are built. Each
        # series copies its title from the label at construction time, so
        # building them first would give them the default labels.
        g.x_label = "Capital"
        g.y_label = "Stake"
        g.x_series = common.Series(title=g.x_label, values=x_values)
        g.graph_series = [common.Series(title=g.y_label, values=y_values)]
        g.x_min = 0.0
        g.x_max = 100.0
        g.y_min = 0.0
        g.y_max = None
        g.has_grid = True
        g.has_legend = False
        return g
示例#2
0
 def create(self) -> Comparison:
     """Compare tabular Q-learning with tabular SARSA, both at alpha=0.5.

     Results are broken down as return by episode. The graph uses a
     moving-average window of 19 with y-axis bounds -100..0, and the grid
     view shows the demo and Q values.
     """
     environment_parameters = EnvironmentParameters()
     comparison_settings = Settings()
     breakdown = common.BreakdownParameters(
         breakdown_type=common.BreakdownType.RETURN_BY_EPISODE)
     # Commented-out alternatives previously listed here:
     #   EXPECTED_SARSA with alpha=0.9, VQ with alpha=0.2.
     runs = [
         Settings(algorithm_parameters=common.AlgorithmParameters(
             algorithm_type=algorithm_type,
             alpha=0.5))
         for algorithm_type in (
             common.AlgorithmType.TABULAR_Q_LEARNING,
             common.AlgorithmType.TABULAR_SARSA,
         )
     ]
     # settings_list_multiprocessing=common.ParallelContextType.SPAWN,
     plot = common.Graph2DValues(
         has_grid=True,
         has_legend=True,
         moving_average_window_size=19,
         y_min=-100,
         y_max=0,
     )
     grid_view = common.GridViewParameters(show_demo=True, show_q=True)
     return Comparison(
         environment_parameters=environment_parameters,
         comparison_settings=comparison_settings,
         breakdown_parameters=breakdown,
         settings_list=runs,
         graph2d_values=plot,
         grid_view_parameters=grid_view,
     )
示例#3
0
 def create(self) -> Comparison:
     """Sweep alpha from 0.1 to 1.0 (step 0.1) for four tabular algorithms.

     The comparison is configured with runs=10 and training_episodes=100,
     and results are broken down as return by alpha. Settings are processed
     with the FORK_GLOBAL parallel context.
     """
     environment_parameters = EnvironmentParameters()
     comparison_settings = common.Settings(
         runs=10,
         # runs_multiprocessing=common.ParallelContextType.FORK_GLOBAL,
         training_episodes=100,
         # display_every_step=True,
     )
     swept_algorithms = [
         common.AlgorithmType.TABULAR_EXPECTED_SARSA,
         common.AlgorithmType.TABULAR_VQ,
         common.AlgorithmType.TABULAR_Q_LEARNING,
         common.AlgorithmType.TABULAR_SARSA,
     ]
     breakdown = common.BreakdownAlgorithmByAlpha(
         breakdown_type=common.BreakdownType.RETURN_BY_ALPHA,
         alpha_min=0.1,
         alpha_max=1.0,
         alpha_step=0.1,
         algorithm_type_list=swept_algorithms,
     )
     parallel_context = common.ParallelContextType.FORK_GLOBAL
     plot = common.Graph2DValues(has_grid=True,
                                 has_legend=True,
                                 y_min=-140,
                                 y_max=0)
     grid_view = common.GridViewParameters(show_policy=True, show_q=True)
     return Comparison(
         environment_parameters=environment_parameters,
         comparison_settings=comparison_settings,
         breakdown_parameters=breakdown,
         settings_list_multiprocessing=parallel_context,
         graph2d_values=plot,
         grid_view_parameters=grid_view,
     )
示例#4
0
 def create(self) -> Comparison:
     """Sweep alpha from 0.1 to 1.0 (step 0.05) over a single long run.

     The comparison is configured with runs=1 and training_episodes=100_000,
     and results are broken down as return by alpha. Settings are processed
     with the SPAWN parallel context.
     """
     # TODO: Make work, once multiprocessing
     environment_parameters = EnvironmentParameters()
     comparison_settings = common.Settings(
         runs=1,
         training_episodes=100_000,
     )
     # common.AlgorithmType.VQ is commented out of this sweep.
     swept_algorithms = [
         common.AlgorithmType.TABULAR_EXPECTED_SARSA,
         common.AlgorithmType.TABULAR_Q_LEARNING,
         common.AlgorithmType.TABULAR_SARSA,
     ]
     breakdown = common.BreakdownAlgorithmByAlpha(
         breakdown_type=common.BreakdownType.RETURN_BY_ALPHA,
         alpha_min=0.1,
         alpha_max=1.0,
         alpha_step=0.05,
         algorithm_type_list=swept_algorithms,
     )
     parallel_context = common.ParallelContextType.SPAWN
     plot = common.Graph2DValues(has_grid=True,
                                 has_legend=True,
                                 y_min=-140,
                                 y_max=0)
     return Comparison(
         environment_parameters=environment_parameters,
         comparison_settings=comparison_settings,
         breakdown_parameters=breakdown,
         settings_list_multiprocessing=parallel_context,
         graph2d_values=plot,
     )
 def create(self):
     """Configure a single verbose DP value-iteration (V) run.

     ``theta`` sets the accuracy of policy evaluation. No
     ``environment_parameters`` argument is passed to ``Comparison``.
     """
     comparison_settings = Settings()
     value_iteration = common.AlgorithmParameters(
         theta=0.00001,  # accuracy of policy_evaluation
         algorithm_type=common.AlgorithmType.DP_VALUE_ITERATION_V,
         verbose=True,
     )
     return Comparison(
         comparison_settings=comparison_settings,
         settings_list=[Settings(algorithm_parameters=value_iteration)],
         graph2d_values=common.Graph2DValues(),
     )
示例#6
0
 def create(self):
     """Compare TD(0) and constant-alpha Monte Carlo across several alphas.

     TD(0) uses alpha 0.05, 0.1 and 0.15. Constant-alpha MC uses alpha
     0.01 to 0.04. Results are broken down as RMS by episode. The graph's
     y-axis is bounded to 0.0..0.25, and the grid view shows the result
     and V.
     """
     environment_parameters = EnvironmentParameters()
     comparison_settings = Settings()
     breakdown = common.BreakdownParameters(
         breakdown_type=common.BreakdownType.RMS_BY_EPISODE, )
     td_0 = common.AlgorithmType.TABULAR_TD_0
     mc_alpha = common.AlgorithmType.MC_CONSTANT_ALPHA
     # One Settings entry is created per (algorithm, alpha) pair, in this order.
     sweep = [
         (td_0, 0.05),
         (td_0, 0.1),
         (td_0, 0.15),
         (mc_alpha, 0.01),
         (mc_alpha, 0.02),
         (mc_alpha, 0.03),
         (mc_alpha, 0.04),
     ]
     runs = [
         Settings(algorithm_parameters=AlgorithmParameters(
             algorithm_type=algorithm_type,
             alpha=alpha,
         ))
         for algorithm_type, alpha in sweep
     ]
     parallel_context = common.ParallelContextType.FORK_GLOBAL
     plot = common.Graph2DValues(has_grid=True,
                                 has_legend=True,
                                 y_min=0.0,
                                 y_max=0.25)
     grid_view = common.GridViewParameters(show_result=True, show_v=True)
     return Comparison(
         environment_parameters=environment_parameters,
         comparison_settings=comparison_settings,
         breakdown_parameters=breakdown,
         settings_list=runs,
         settings_list_multiprocessing=parallel_context,
         graph2d_values=plot,
         grid_view_parameters=grid_view,
     )
示例#7
0
 def create(self):
     """Configure off-policy MC control on grid TRACK_3 with skid probability 0.1.

     Failure adds an extra reward of -100.0 (the problem statement uses 0.0),
     and Q values start at -40.0. Results are broken down as return by
     episode with a moving-average window of 101. The grid view is a
     POSITION view with the demo and trail shown.
     """
     # TODO: Problem with the first step not learning and crashing?
     #  Try grids.TRACK_1 for example (3rd position crash)
     environment_parameters = EnvironmentParameters(
         grid=grids.TRACK_3,
         extra_reward_for_failure=-100.0,  # 0.0 in problem statement
         skid_probability=0.1,
     )
     comparison_settings = Settings()
     breakdown = common.BreakdownParameters(
         breakdown_type=common.BreakdownType.RETURN_BY_EPISODE)
     # Commented-out alternatives previously listed here:
     #   EXPECTED_SARSA with alpha=0.9, VQ with alpha=0.2,
     #   Q_LEARNING with alpha=0.5.
     off_policy_mc = Settings(algorithm_parameters=common.AlgorithmParameters(
         algorithm_type=common.AlgorithmType.MC_CONTROL_OFF_POLICY,
         initial_q_value=-40.0,
     ))
     # settings_list_multiprocessing=common.ParallelContextType.SPAWN,
     plot = common.Graph2DValues(
         has_grid=True,
         has_legend=True,
         moving_average_window_size=101,
         y_min=-200,
         y_max=0,
     )
     grid_view = common.GridViewParameters(
         grid_view_type=common.GridViewType.POSITION,
         show_demo=True,
         show_trail=True,
     )
     return Comparison(
         environment_parameters=environment_parameters,
         comparison_settings=comparison_settings,
         breakdown_parameters=breakdown,
         settings_list=[off_policy_mc],
         graph2d_values=plot,
         grid_view_parameters=grid_view,
     )
示例#8
0
 def create(self) -> Comparison:
     """Configure one tabular SARSA run (alpha=0.5, Q initialised to 0.0).

     The environment's ``random_wind`` comes from ``self._random_wind``.
     Results are broken down as episode by timestep, and the grid view
     shows the demo and Q values.
     """
     environment_parameters = EnvironmentParameters(
         random_wind=self._random_wind, )
     comparison_settings = Settings()
     breakdown = common.BreakdownParameters(
         breakdown_type=common.BreakdownType.EPISODE_BY_TIMESTEP, )
     sarsa = common.AlgorithmParameters(
         algorithm_type=common.AlgorithmType.TABULAR_SARSA,
         alpha=0.5,
         initial_q_value=0.0,
     )
     runs = [Settings(algorithm_parameters=sarsa)]
     plot = common.Graph2DValues(has_grid=True, has_legend=True)
     grid_view = common.GridViewParameters(show_demo=True, show_q=True)
     return Comparison(
         environment_parameters=environment_parameters,
         comparison_settings=comparison_settings,
         breakdown_parameters=breakdown,
         settings_list=runs,
         graph2d_values=plot,
         grid_view_parameters=grid_view,
     )
示例#9
0
    def get_state_graph_values(self) -> common.Graph2DValues:
        """Build a graph of the state-value estimate V(s) for each non-terminal state.

        Terminal states are skipped. The x series holds each state's capital and
        the single graph series holds ``self.algorithm.V[s]``, indexed by the
        state's position in ``self.environment.states``.

        Returns:
            A ``common.Graph2DValues`` titled "V(s)" with x-axis "Capital"
            (0 to 100) and y-axis "V(s)" (0 to 1).
        """
        x_list: list[int] = []
        y_list: list[float] = []
        for s, state in enumerate(self.environment.states):
            if state.is_terminal:
                continue
            x_list.append(state.capital)
            y_list.append(self.algorithm.V[s])
        x_values = np.array(x_list, dtype=int)
        y_values = np.array(y_list, dtype=float)

        g: common.Graph2DValues = common.Graph2DValues()
        g.title = "V(s)"
        # The axis labels must be set BEFORE the series are built. Each
        # series copies its title from the label at construction time, so
        # building them first would give them the default labels.
        g.x_label = "Capital"
        g.y_label = "V(s)"
        g.x_series = common.Series(title=g.x_label, values=x_values)
        g.graph_series = [common.Series(title=g.y_label, values=y_values)]
        g.x_min = 0.0
        g.x_max = 100.0
        g.y_min = 0.0
        g.y_max = 1.0
        g.has_grid = True
        g.has_legend = False
        return g