def get_unfiltered_action_space(
        self, output_action_space: AttentionActionSpace) -> DiscreteActionSpace:
    # expand a scalar bin count into one bin count per action dimension
    if isinstance(self.num_bins_per_dimension, int):
        self.num_bins_per_dimension = [self.num_bins_per_dimension] * output_action_space.shape[0]

    # create a discrete to linspace map to ease the extraction of attention actions
    # (n bins along a dimension require n + 1 grid points to delimit them)
    discrete_to_box = BoxDiscretization(
        [n + 1 for n in self.num_bins_per_dimension], self.force_int_bins)
    discrete_to_box.get_unfiltered_action_space(
        BoxActionSpace(output_action_space.shape, output_action_space.low, output_action_space.high))

    # each attention action is a (low corner, high corner) pair of grid points.
    # the grid points are flattened row-major, so the low corner of cell (i, j) sits at
    # index i * (cols + 1) + j and its high corner lies cols + 2 indices further on.
    rows, cols = self.num_bins_per_dimension
    start_ind = [i * (cols + 1) + j for i in range(rows) for j in range(cols)]
    end_ind = [i + cols + 2 for i in start_ind]
    self.target_actions = [
        np.array([discrete_to_box.target_actions[start], discrete_to_box.target_actions[end]])
        for start, end in zip(start_ind, end_ind)
    ]

    return super().get_unfiltered_action_space(output_action_space)
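# A minimal, standalone sketch (not part of the original method) of the corner-index
# arithmetic used above, assuming the grid points of a rows x cols attention grid are
# flattened row-major into (rows + 1) * (cols + 1) entries: cell (i, j) takes its low
# corner from index i * (cols + 1) + j and its high corner from that index + cols + 2.
def attention_cell_corner_indices(rows, cols):
    start_ind = [i * (cols + 1) + j for i in range(rows) for j in range(cols)]
    end_ind = [i + cols + 2 for i in start_ind]
    return list(zip(start_ind, end_ind))

# e.g. a 2 x 2 grid has 9 flattened grid points and 4 cells:
assert attention_cell_corner_indices(2, 2) == [(0, 4), (1, 5), (3, 7), (4, 8)]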
def test_filter():
    filter = BoxDiscretization(9)

    # passing an output space that is wrong
    with pytest.raises(ValueError):
        filter.validate_output_action_space(DiscreteActionSpace(10))

    # 1 dimensional box
    output_space = BoxActionSpace(1, 5, 15)
    input_space = filter.get_unfiltered_action_space(output_space)

    assert filter.target_actions == [[5.], [6.25], [7.5], [8.75], [10.],
                                     [11.25], [12.5], [13.75], [15.]]
    assert input_space.actions == list(range(9))

    action = 2

    result = filter.filter(action)
    assert result == [7.5]
    assert output_space.contains(result)

    # 2 dimensional box
    filter = BoxDiscretization(3)
    output_space = BoxActionSpace(2, 5, 15)
    input_space = filter.get_unfiltered_action_space(output_space)

    assert filter.target_actions == [[5., 5.], [5., 10.], [5., 15.], [10., 5.],
                                     [10., 10.], [10., 15.], [15., 5.],
                                     [15., 10.], [15., 15.]]
    assert input_space.actions == list(range(9))

    action = 2

    result = filter.filter(action)
    assert result == [5., 15.]
    assert output_space.contains(result)
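# A small standalone sketch (plain numpy/itertools, not the library implementation) showing
# how the 2-D targets asserted above can be reproduced: discretize each dimension with
# np.linspace and take the cartesian product of the per-dimension values in row-major
# (itertools.product) order, so discrete action k selects the k-th combination.
import itertools
import numpy as np

per_dim_values = np.linspace(5, 15, 3)                        # [ 5., 10., 15.]
targets = [list(p) for p in itertools.product(per_dim_values, repeat=2)]

assert targets == [[5., 5.], [5., 10.], [5., 15.], [10., 5.], [10., 10.],
                   [10., 15.], [15., 5.], [15., 10.], [15., 15.]]
assert targets[2] == [5., 15.]                                # matches filter.filter(2) above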
agent_params.network_wrappers['main'].learning_rate = 0.0001
agent_params.network_wrappers['main'].input_embedders_parameters = {
    "screen": InputEmbedderParameters(input_rescaling={'image': 3.0})
}
agent_params.network_wrappers['main'].heads_parameters = [
    DuelingQHeadParameters()
]
agent_params.memory.max_size = (MemoryGranularity.Transitions, 1000000)
# slave_agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(10000)
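# anneal the exploration rate linearly from 1.0 down to 0.1 over the first 1,000,000 steps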
agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.1, 1000000)
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(4)
agent_params.output_filter = \
    OutputFilter(
        action_filters=OrderedDict([
            ('discretization', BoxDiscretization(num_bins_per_dimension=4, force_int_bins=True))
        ]),
        is_a_reference_filter=False
    )
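# Note on the filter above (illustrative, assuming the CollectMineralShards action space is a
# 2-D box of screen coordinates): BoxDiscretization takes the cartesian product of the
# per-dimension bins, so num_bins_per_dimension=4 yields 4 ** 2 = 16 discrete actions, and
# force_int_bins=True snaps the bin values to integer pixel coordinates.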

###############
# Environment #
###############

env_params = StarCraft2EnvironmentParameters(level='CollectMineralShards')
env_params.feature_screen_maps_to_use = [5]
env_params.feature_minimap_maps_to_use = [5]

########
# Test #
########
schedule_params.evaluation_steps = EnvironmentEpisodes(1)
schedule_params.heatup_steps = EnvironmentSteps(1000)

#########
# Agent #
#########
agent_params = DDQNAgentParameters()
agent_params.network_wrappers['main'].learning_rate = 0.00025
agent_params.network_wrappers['main'].heads_parameters = [DuelingQHeadParameters()]
agent_params.network_wrappers['main'].middleware_parameters.scheme = MiddlewareScheme.Empty
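# rescale the gradients flowing from the heads into the shared layers by 1/sqrt(2),
# following the gradient-rescaling trick from the Dueling DQN paper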
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1/math.sqrt(2), 1/math.sqrt(2)]
agent_params.network_wrappers['main'].clip_gradients = 10
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(4)
agent_params.network_wrappers['main'].input_embedders_parameters['forward_camera'] = \
    agent_params.network_wrappers['main'].input_embedders_parameters.pop('observation')
agent_params.output_filter = OutputFilter()
agent_params.output_filter.add_action_filter('discretization', BoxDiscretization(5))

###############
# Environment #
###############
env_params = CarlaEnvironmentParameters()
env_params.level = 'town1'

vis_params = VisualizationParameters()
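# dump policy: record only evaluation (TEST phase) episodes, plus episodes that set a new maximum reward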
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
vis_params.dump_mp4 = False

graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                    schedule_params=schedule_params, vis_params=vis_params)
schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(20)
schedule_params.evaluation_steps = EnvironmentEpisodes(1)
schedule_params.heatup_steps = EnvironmentSteps(1000)

#########
# Agent #
#########
agent_params = DDQNAgentParameters()
agent_params.network_wrappers['main'].learning_rate = 0.00025
agent_params.network_wrappers['main'].heads_parameters = \
    [DuelingQHeadParameters(rescale_gradient_from_head_by_factor=1/math.sqrt(2))]
agent_params.network_wrappers['main'].middleware_parameters.scheme = MiddlewareScheme.Empty
agent_params.network_wrappers['main'].clip_gradients = 10
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(4)
agent_params.network_wrappers['main'].input_embedders_parameters['forward_camera'] = \
    agent_params.network_wrappers['main'].input_embedders_parameters.pop('observation')
agent_params.output_filter = OutputFilter()
agent_params.output_filter.add_action_filter('discretization',
                                             BoxDiscretization(5))

###############
# Environment #
###############
env_params = CarlaEnvironmentParameters()

graph_manager = BasicRLGraphManager(agent_params=agent_params,
                                    env_params=env_params,
                                    schedule_params=schedule_params,
                                    vis_params=VisualizationParameters())