def get_unfiltered_action_space(self, output_action_space: AttentionActionSpace) -> DiscreteActionSpace:
    if isinstance(self.num_bins_per_dimension, int):
        self.num_bins_per_dimension = [self.num_bins_per_dimension] * output_action_space.shape[0]

    # create a discrete to linspace map to ease the extraction of attention actions
    discrete_to_box = BoxDiscretization([n + 1 for n in self.num_bins_per_dimension], self.force_int_bins)
    discrete_to_box.get_unfiltered_action_space(
        BoxActionSpace(output_action_space.shape, output_action_space.low, output_action_space.high))

    rows, cols = self.num_bins_per_dimension
    start_ind = [i * (cols + 1) + j for i in range(rows) for j in range(cols)]
    end_ind = [i + cols + 2 for i in start_ind]
    self.target_actions = [
        np.array([discrete_to_box.target_actions[start], discrete_to_box.target_actions[end]])
        for start, end in zip(start_ind, end_ind)
    ]

    return super().get_unfiltered_action_space(output_action_space)
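
# A minimal, standalone sketch (not part of the filter class) of the corner-index
# arithmetic used above, assuming num_bins_per_dimension = [2, 2]. The helper
# BoxDiscretization is built with n + 1 bins per dimension, so its target_actions
# form a (rows + 1) x (cols + 1) grid of corner points laid out row-major:
#
#     0 1 2
#     3 4 5
#     6 7 8
#
# Each attention sub-box is then described by a (top-left, bottom-right) corner pair.
rows, cols = 2, 2
start_ind = [i * (cols + 1) + j for i in range(rows) for j in range(cols)]
end_ind = [i + cols + 2 for i in start_ind]
print(list(zip(start_ind, end_ind)))  # [(0, 4), (1, 5), (3, 7), (4, 8)]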
def test_filter():
    filter = BoxDiscretization(9)

    # passing an output space that is wrong
    with pytest.raises(ValueError):
        filter.validate_output_action_space(DiscreteActionSpace(10))

    # 1 dimensional box
    output_space = BoxActionSpace(1, 5, 15)
    input_space = filter.get_unfiltered_action_space(output_space)
    assert filter.target_actions == [[5.], [6.25], [7.5], [8.75], [10.], [11.25], [12.5], [13.75], [15.]]
    assert input_space.actions == list(range(9))
    action = 2
    result = filter.filter(action)
    assert result == [7.5]
    assert output_space.contains(result)

    # 2 dimensional box
    filter = BoxDiscretization(3)
    output_space = BoxActionSpace(2, 5, 15)
    input_space = filter.get_unfiltered_action_space(output_space)
    assert filter.target_actions == [[5., 5.], [5., 10.], [5., 15.],
                                     [10., 5.], [10., 10.], [10., 15.],
                                     [15., 5.], [15., 10.], [15., 15.]]
    assert input_space.actions == list(range(9))
    action = 2
    result = filter.filter(action)
    assert result == [5., 15.]
    assert output_space.contains(result)
agent_params.network_wrappers['main'].learning_rate = 0.0001
agent_params.network_wrappers['main'].input_embedders_parameters = {
    "screen": InputEmbedderParameters(input_rescaling={'image': 3.0})
}
agent_params.network_wrappers['main'].heads_parameters = [DuelingQHeadParameters()]
agent_params.memory.max_size = (MemoryGranularity.Transitions, 1000000)
# slave_agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(10000)
agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.1, 1000000)
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(4)
agent_params.output_filter = \
    OutputFilter(
        action_filters=OrderedDict([
            ('discretization', BoxDiscretization(num_bins_per_dimension=4, force_int_bins=True))
        ]),
        is_a_reference_filter=False
    )

###############
# Environment #
###############
env_params = StarCraft2EnvironmentParameters(level='CollectMineralShards')
env_params.feature_screen_maps_to_use = [5]
env_params.feature_minimap_maps_to_use = [5]

########
# Test #
########
schedule_params.evaluation_steps = EnvironmentEpisodes(1)
schedule_params.heatup_steps = EnvironmentSteps(1000)

#########
# Agent #
#########
agent_params = DDQNAgentParameters()
agent_params.network_wrappers['main'].learning_rate = 0.00025
agent_params.network_wrappers['main'].heads_parameters = [DuelingQHeadParameters()]
agent_params.network_wrappers['main'].middleware_parameters.scheme = MiddlewareScheme.Empty
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1/math.sqrt(2), 1/math.sqrt(2)]
agent_params.network_wrappers['main'].clip_gradients = 10
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(4)
agent_params.network_wrappers['main'].input_embedders_parameters['forward_camera'] = \
    agent_params.network_wrappers['main'].input_embedders_parameters.pop('observation')
agent_params.output_filter = OutputFilter()
agent_params.output_filter.add_action_filter('discretization', BoxDiscretization(5))

###############
# Environment #
###############
env_params = CarlaEnvironmentParameters()
env_params.level = 'town1'

vis_params = VisualizationParameters()
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
vis_params.dump_mp4 = False

graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                    schedule_params=schedule_params, vis_params=vis_params)
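
# A minimal, hedged usage sketch (assumptions: the rl_coach package layout and an
# arbitrary experiment path): presets like this one are normally launched through
# the coach CLI, which picks up the `graph_manager` defined above, but it can also
# be driven directly:
if __name__ == '__main__':
    from rl_coach.base_parameters import TaskParameters

    # build the graph and run the training/evaluation schedule configured above
    graph_manager.create_graph(TaskParameters(experiment_path='./experiments/carla_ddqn'))
    graph_manager.improve()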
schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(20)
schedule_params.evaluation_steps = EnvironmentEpisodes(1)
schedule_params.heatup_steps = EnvironmentSteps(1000)

#########
# Agent #
#########
agent_params = DDQNAgentParameters()
agent_params.network_wrappers['main'].learning_rate = 0.00025
agent_params.network_wrappers['main'].heads_parameters = \
    [DuelingQHeadParameters(rescale_gradient_from_head_by_factor=1/math.sqrt(2))]
agent_params.network_wrappers['main'].middleware_parameters.scheme = MiddlewareScheme.Empty
agent_params.network_wrappers['main'].clip_gradients = 10
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(4)
agent_params.network_wrappers['main'].input_embedders_parameters['forward_camera'] = \
    agent_params.network_wrappers['main'].input_embedders_parameters.pop('observation')
agent_params.output_filter = OutputFilter()
agent_params.output_filter.add_action_filter('discretization', BoxDiscretization(5))

###############
# Environment #
###############
env_params = CarlaEnvironmentParameters()

graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                    schedule_params=schedule_params,
                                    vis_params=VisualizationParameters())