def create(self) -> Comparison:
    """Build a comparison of tabular Q-learning vs tabular Sarsa (alpha=0.5 each),
    graphing return by episode with a moving average."""
    # Alternatives previously trialled here (left for reference):
    #   EXPECTED_SARSA with alpha=0.9, VQ with alpha=0.2
    compared_algorithms = (
        common.AlgorithmType.TABULAR_Q_LEARNING,
        common.AlgorithmType.TABULAR_SARSA,
    )
    settings_list = [
        Settings(algorithm_parameters=common.AlgorithmParameters(
            algorithm_type=algorithm_type,
            alpha=0.5,
        ))
        for algorithm_type in compared_algorithms
    ]
    graph2d_values = common.Graph2DValues(
        has_grid=True,
        has_legend=True,
        moving_average_window_size=19,
        y_min=-100,
        y_max=0,
    )
    return Comparison(
        environment_parameters=EnvironmentParameters(),
        comparison_settings=Settings(),
        breakdown_parameters=common.BreakdownParameters(
            breakdown_type=common.BreakdownType.RETURN_BY_EPISODE),
        settings_list=settings_list,
        # settings_list_multiprocessing=common.ParallelContextType.SPAWN,
        graph2d_values=graph2d_values,
        grid_view_parameters=common.GridViewParameters(show_demo=True, show_q=True),
    )
def __init__(self):
    """Configure Jack's Car Rental: environment size/rules, DP policy-iteration
    settings, the 3D value-function plot axes, and the policy grid view."""
    super().__init__()
    # Problem statement caps each location at 20 cars.
    self._max_cars: int = 20
    self._environment_parameters = EnvironmentParameters(
        max_cars=self._max_cars,
        # True enables the extra rules from the book's challenge variant;
        # set False for the plain problem.
        extra_rules=True,
    )
    self._comparison_settings = common.Settings(
        gamma=0.9,
        policy_parameters=common.PolicyParameters(
            policy_type=common.PolicyType.TABULAR_DETERMINISTIC,
        ),
        algorithm_parameters=common.AlgorithmParameters(
            theta=0.1,  # accuracy of policy_evaluation
        ),
        display_every_step=True,
    )
    self._graph3d_values = common.Graph3DValues(
        x_label="Cars at 1st location",
        y_label="Cars at 2nd location",
        z_label="V(s)",
        x_min=0,
        x_max=self._max_cars,
        y_min=0,
        y_max=self._max_cars,
    )
    self._grid_view_parameters = common.GridViewParameters(
        grid_view_type=common.GridViewType.JACKS,
        show_result=True,
        show_policy=True,
    )
def create(self):
    """Build a single-run comparison using verbose value iteration over V."""
    value_iteration = Settings(
        algorithm_parameters=common.AlgorithmParameters(
            algorithm_type=common.AlgorithmType.DP_VALUE_ITERATION_V,
            theta=0.00001,  # accuracy of policy_evaluation
            verbose=True,
        ),
    )
    return Comparison(
        # environment_parameters=EnvironmentParameters(),
        comparison_settings=Settings(),
        settings_list=[value_iteration],
        graph2d_values=common.Graph2DValues(),
    )
def create(self) -> Comparison:
    """Build a comparison running first-visit Monte-Carlo prediction of Q for
    100k episodes, deriving V from Q at the end for the 3D plot."""
    mc_prediction = Settings(
        algorithm_parameters=common.AlgorithmParameters(
            algorithm_type=common.AlgorithmType.MC_PREDICTION_Q,
            first_visit=True,
            verbose=True,
            derive_v_from_q_as_final_step=True,
        ),
        training_episodes=100_000,
    )
    return Comparison(
        environment_parameters=EnvironmentParameters(),
        comparison_settings=Settings(),
        settings_list=[mc_prediction],
        graph3d_values=self._graph3d_values,
        grid_view_parameters=self._grid_view_parameters,
    )
def create(self):
    """Build a racetrack comparison running off-policy Monte-Carlo control,
    graphing return by episode and showing a demo run with its trail."""
    # TODO: Problem with the first step not learning and crashing?
    # Try grids.TRACK_1 for example (3rd position crash)
    environment_parameters = EnvironmentParameters(
        grid=grids.TRACK_3,
        extra_reward_for_failure=-100.0,  # 0.0 in problem statement
        skid_probability=0.1,
    )
    # Alternatives previously trialled here (left for reference):
    #   EXPECTED_SARSA alpha=0.9, VQ alpha=0.2, Q_LEARNING alpha=0.5
    mc_off_policy = Settings(algorithm_parameters=common.AlgorithmParameters(
        algorithm_type=common.AlgorithmType.MC_CONTROL_OFF_POLICY,
        initial_q_value=-40.0,
    ))
    return Comparison(
        environment_parameters=environment_parameters,
        comparison_settings=Settings(),
        breakdown_parameters=common.BreakdownParameters(
            breakdown_type=common.BreakdownType.RETURN_BY_EPISODE,
        ),
        settings_list=[mc_off_policy],
        # settings_list_multiprocessing=common.ParallelContextType.SPAWN,
        graph2d_values=common.Graph2DValues(
            has_grid=True,
            has_legend=True,
            moving_average_window_size=101,
            y_min=-200,
            y_max=0,
        ),
        grid_view_parameters=common.GridViewParameters(
            grid_view_type=common.GridViewType.POSITION,
            show_demo=True,
            show_trail=True,
        ),
    )
def create(self) -> Comparison:
    """Build a comparison running verbose value iteration over V on this
    instance's pre-configured environment, with 3D value plot and grid view."""
    dp_value_iteration = Settings(
        algorithm_parameters=common.AlgorithmParameters(
            algorithm_type=common.AlgorithmType.DP_VALUE_ITERATION_V,
            verbose=True,
            theta=0.1,  # accuracy of policy_evaluation
        ),
    )
    return Comparison(
        environment_parameters=self._environment_parameters,
        comparison_settings=Settings(),
        settings_list=[dp_value_iteration],
        graph3d_values=self._graph3d_values,
        grid_view_parameters=self._grid_view_parameters,
    )
def create(self) -> Comparison:
    """Build a comparison running verbose policy iteration over Q on this
    instance's pre-configured environment, deriving V from Q at the end
    for the 3D value plot.

    Fix: the original assigned local aliases (`graph3d_values`,
    `grid_view_parameters`) but the call passed `self._grid_view_parameters`
    directly, leaving one local unused; both dead locals are removed and the
    attributes are passed directly, matching the sibling value-iteration
    factory's style.
    """
    policy_iteration = Settings(
        algorithm_parameters=common.AlgorithmParameters(
            algorithm_type=common.AlgorithmType.DP_POLICY_ITERATION_Q,
            verbose=True,
            derive_v_from_q_as_final_step=True,
            theta=0.1,  # accuracy of policy_evaluation
        ),
    )
    return Comparison(
        environment_parameters=self._environment_parameters,
        comparison_settings=Settings(),
        settings_list=[policy_iteration],
        graph3d_values=self._graph3d_values,
        grid_view_parameters=self._grid_view_parameters,
    )
def create(self) -> Comparison:
    """Build a windy-gridworld comparison running tabular Sarsa (alpha=0.5),
    graphing episode count against timestep."""
    sarsa = Settings(algorithm_parameters=common.AlgorithmParameters(
        algorithm_type=common.AlgorithmType.TABULAR_SARSA,
        alpha=0.5,
        initial_q_value=0.0,
    ))
    return Comparison(
        environment_parameters=EnvironmentParameters(
            random_wind=self._random_wind,
        ),
        comparison_settings=Settings(),
        breakdown_parameters=common.BreakdownParameters(
            breakdown_type=common.BreakdownType.EPISODE_BY_TIMESTEP,
        ),
        settings_list=[sarsa],
        graph2d_values=common.Graph2DValues(
            has_grid=True,
            has_legend=True,
        ),
        grid_view_parameters=common.GridViewParameters(
            show_demo=True,
            show_q=True,
        ),
    )
def create(self) -> Comparison:
    """Build a comparison running first-visit on-policy Monte-Carlo control
    with exploring starts for 100k episodes, deriving V from Q at the end."""
    mc_control = Settings(
        algorithm_parameters=common.AlgorithmParameters(
            algorithm_type=common.AlgorithmType.MC_CONTROL_ON_POLICY,
            first_visit=True,
            exploring_starts=True,
            derive_v_from_q_as_final_step=True,
            verbose=True,
        ),
        training_episodes=100_000,
    )
    return Comparison(
        environment_parameters=EnvironmentParameters(),
        comparison_settings=Settings(),
        settings_list=[mc_control],
        graph3d_values=self._graph3d_values,
        grid_view_parameters=self._grid_view_parameters,
    )