Example #1
# Tail of the (truncated) Doom key_map dict, mapping action names to key codes:
    'USE': ord("f"),
}

# Imports this snippet relies on (module paths assumed from rl_coach's layout):
import numpy as np
from rl_coach.environments.environment import EnvironmentParameters
from rl_coach.filters.filter import InputFilter, OutputFilter
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
from rl_coach.filters.observation.observation_rgb_to_y_filter import ObservationRGBToYFilter
from rl_coach.filters.observation.observation_to_uint8_filter import ObservationToUInt8Filter
from rl_coach.filters.observation.observation_stacking_filter import ObservationStackingFilter
from rl_coach.filters.action.full_discrete_action_space_map import FullDiscreteActionSpaceMap
from rl_coach.spaces import ImageObservationSpace

DoomInputFilter = InputFilter(is_a_reference_filter=True)
DoomInputFilter.add_observation_filter(
    'observation', 'rescaling',
    ObservationRescaleToSizeFilter(
        ImageObservationSpace(np.array([60, 76, 3]), high=255)))
DoomInputFilter.add_observation_filter('observation', 'to_grayscale',
                                       ObservationRGBToYFilter())
DoomInputFilter.add_observation_filter('observation', 'to_uint8',
                                       ObservationToUInt8Filter(0, 255))
DoomInputFilter.add_observation_filter('observation', 'stacking',
                                       ObservationStackingFilter(3))

DoomOutputFilter = OutputFilter(is_a_reference_filter=True)
DoomOutputFilter.add_action_filter('to_discrete', FullDiscreteActionSpaceMap())


class DoomEnvironmentParameters(EnvironmentParameters):
    def __init__(self, level=None):
        super().__init__(level=level)
        self.default_input_filter = DoomInputFilter
        self.default_output_filter = DoomOutputFilter
        self.cameras = [DoomEnvironment.CameraTypes.OBSERVATION]

    @property
    def path(self):
        return 'rl_coach.environments.doom_environment:DoomEnvironment'
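A minimal usage sketch, assuming the level name 'basic' (rl_coach locates the environment class through the path property above):

# Instantiating the parameters makes the reference filters above the defaults
# for whichever agent is paired with the Doom environment.
env_params = DoomEnvironmentParameters(level='basic')
assert env_params.default_input_filter is DoomInputFilter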

Example #2
# The filter's definition (and its earlier 'screen' filters) is cut off in this
# snippet; the line below is reconstructed from the Doom pattern above:
StarcraftInputFilter = InputFilter(is_a_reference_filter=True)

StarcraftInputFilter.add_observation_filter('screen', 'to_uint8',
                                            ObservationToUInt8Filter(0, 255))

StarcraftInputFilter.add_observation_filter('minimap', 'move_axis',
                                            ObservationMoveAxisFilter(0, -1))
StarcraftInputFilter.add_observation_filter(
    'minimap', 'rescaling',
    ObservationRescaleToSizeFilter(
        PlanarMapsObservationSpace(np.array([64, 64, 1]),
                                   low=0,
                                   high=255,
                                   channels_axis=-1)))
StarcraftInputFilter.add_observation_filter('minimap', 'to_uint8',
                                            ObservationToUInt8Filter(0, 255))

SCREEN_SIZE = 84  # assumed module-level constant; matches the default screen_size below

StarcraftNormalizingOutputFilter = OutputFilter(is_a_reference_filter=True)
StarcraftNormalizingOutputFilter.add_action_filter(
    'normalization',
    LinearBoxToBoxMap(input_space_low=-SCREEN_SIZE / 2,
                      input_space_high=SCREEN_SIZE / 2 - 1))
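The normalizing filter linearly rescales actions from the agent's symmetric range into the environment's screen coordinates. An illustrative sketch of that mapping (not rl_coach code; the output bounds are an assumption):

# Illustrative only: the linear rescaling a LinearBoxToBoxMap performs,
# mapping x from [in_low, in_high] onto [out_low, out_high].
def linear_map(x, in_low, in_high, out_low, out_high):
    return out_low + (x - in_low) * (out_high - out_low) / (in_high - in_low)

# e.g. an agent action of 0 in [-SCREEN_SIZE / 2, SCREEN_SIZE / 2 - 1] lands
# near the center of an assumed [0, SCREEN_SIZE - 1] pixel range:
linear_map(0, -SCREEN_SIZE / 2, SCREEN_SIZE / 2 - 1, 0, SCREEN_SIZE - 1)  # -> 42.0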


class StarCraft2EnvironmentParameters(EnvironmentParameters):
    def __init__(self, level=None):
        super().__init__(level=level)
        self.screen_size = 84
        self.minimap_size = 64
        self.feature_minimap_maps_to_use = range(7)
        self.feature_screen_maps_to_use = range(17)
        self.observation_type = StarcraftObservationType.Features
        self.disable_fog = False
Example #3

agent_params.network_wrappers['main'].learning_rate = 0.0001
agent_params.network_wrappers['main'].input_embedders_parameters = {
    "screen": InputEmbedderParameters(input_rescaling={'image': 3.0})
}
agent_params.network_wrappers['main'].heads_parameters = [
    DuelingQHeadParameters()
]
agent_params.memory.max_size = (MemoryGranularity.Transitions, 1000000)
# slave_agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(10000)
agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.1, 1000000)
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(4)
agent_params.output_filter = \
    OutputFilter(
        action_filters=OrderedDict([
            ('discretization', BoxDiscretization(num_bins_per_dimension=4, force_int_bins=True))
        ]),
        is_a_reference_filter=False
    )
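BoxDiscretization replaces the continuous action box with a fixed grid of actions, here 4 evenly spaced values per dimension, rounded to integers by force_int_bins=True. An illustrative sketch (not the library implementation; the [0, 63] range is an assumption):

import numpy as np
# 4 evenly spaced, integer-rounded bins over an assumed [0, 63] coordinate range:
bins = np.rint(np.linspace(0, 63, num=4)).astype(int)   # -> array([ 0, 21, 42, 63])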

###############
# Environment #
###############

env_params = StarCraft2EnvironmentParameters(level='CollectMineralShards')
env_params.feature_screen_maps_to_use = [5]
env_params.feature_minimap_maps_to_use = [5]

########
# Test #
########
preset_validation_params = PresetValidationParameters()
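The snippet ends here; a preset of this shape would typically close by handing the parameter objects to a graph manager. A sketch under that assumption, following the pattern of the other presets shown here:

graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                    schedule_params=schedule_params, vis_params=vis_params,
                                    preset_validation_params=preset_validation_params)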
Example #4
import math  # needed for the gradient-rescaling factors below

schedule_params = ScheduleParameters()  # head of the snippet is cut off; reconstructed
schedule_params.evaluation_steps = EnvironmentEpisodes(1)
schedule_params.heatup_steps = EnvironmentSteps(1000)

#########
# Agent #
#########
agent_params = DDQNAgentParameters()
agent_params.network_wrappers['main'].learning_rate = 0.00025
agent_params.network_wrappers['main'].heads_parameters = [DuelingQHeadParameters()]
agent_params.network_wrappers['main'].middleware_parameters.scheme = MiddlewareScheme.Empty
agent_params.network_wrappers['main'].rescale_gradient_from_head_by_factor = [1/math.sqrt(2), 1/math.sqrt(2)]
agent_params.network_wrappers['main'].clip_gradients = 10
agent_params.algorithm.num_consecutive_playing_steps = EnvironmentSteps(4)
agent_params.network_wrappers['main'].input_embedders_parameters['forward_camera'] = \
    agent_params.network_wrappers['main'].input_embedders_parameters.pop('observation')
agent_params.output_filter = OutputFilter()
agent_params.output_filter.add_action_filter('discretization', BoxDiscretization(5))

###############
# Environment #
###############
env_params = CarlaEnvironmentParameters()
env_params.level = 'town1'

vis_params = VisualizationParameters()
vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]
vis_params.dump_mp4 = False

graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params,
                                    schedule_params=schedule_params, vis_params=vis_params)
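Presets are usually run through the coach launcher, but the graph manager can also be driven directly. A minimal sketch, assuming the standard rl_coach entry points:

# Build the graph, then run the heatup/train/evaluate loop from schedule_params.
from rl_coach.base_parameters import TaskParameters
graph_manager.create_graph(TaskParameters())
graph_manager.improve()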
Example #5

agent_params.network_wrappers['actor'].input_embedders_parameters[
    'observation'].scheme = [Dense(400)]
agent_params.network_wrappers['actor'].middleware_parameters.scheme = [
    Dense(300)
]
agent_params.network_wrappers['critic'].input_embedders_parameters[
    'observation'].scheme = [Dense(400)]
agent_params.network_wrappers['critic'].middleware_parameters.scheme = [
    Dense(300)
]
agent_params.network_wrappers['critic'].input_embedders_parameters[
    'action'].scheme = EmbedderScheme.Empty
agent_params.output_filter = \
    OutputFilter(
        action_filters=OrderedDict([
            ('discretization', BoxDiscretization(num_bins_per_dimension=int(1e6)))
        ]),
        is_a_reference_filter=False
    )
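With a million bins per dimension, the discretization is effectively a pass-through that lets a discrete-action agent drive a continuous control task. An illustrative check of the resulting resolution (the [-1, 1] action bounds are an assumption; MuJoCo tasks vary):

# Bin spacing for 1e6 bins over an assumed [-1, 1] action range: near-continuous.
spacing = (1.0 - (-1.0)) / (int(1e6) - 1)   # ~2e-6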

###############
# Environment #
###############
env_params = GymVectorEnvironment(level=SingleLevelSelection(mujoco_v2))

########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.test = True
preset_validation_params.min_reward_threshold = 500
preset_validation_params.max_episodes_to_achieve_reward = 1000