Exemplo n.º 1
0
 def __init__(self):
     """Configure SIL agent parameters.

     Passes SIL algorithm parameters, prioritized experience replay memory
     parameters and a single "main" SIL network to the base constructor.
     Exploration is passed as None.
     """
     algorithm_params = SILAlgorithmParameters()
     # TODO: exploration should differ between continuous action spaces
     #  (ContinuousEntropyExploration) and discrete action spaces
     #  (CategoricalExploration); the open question is how to choose between
     #  them here, so it is left as None for now.
     exploration_params = None
     memory_params = PrioritizedExperienceReplayParameters()
     network_params = {"main": SILNetworkParameters()}
     super().__init__(algorithm=algorithm_params,
                      exploration=exploration_params,
                      memory=memory_params,
                      networks=network_params)
Exemplo n.º 2
0
    def __init__(self):
        """Configure Rainbow DQN agent parameters.

        Selects the Rainbow DQN algorithm, a single "main" Rainbow DQN
        network wrapper, parameter-noise exploration and prioritized
        experience replay memory.
        """
        super().__init__()
        self.algorithm = RainbowDQNAlgorithmParameters()

        # Order matters: ParameterNoiseParameters alters the network wrapper
        # parameters it finds on this object, so the Rainbow wrappers have to
        # be in place before it is constructed.
        self.network_wrappers = dict(main=RainbowDQNNetworkParameters())
        self.exploration = ParameterNoiseParameters(self)

        self.memory = PrioritizedExperienceReplayParameters()
Exemplo n.º 3
0
 def __init__(self):
     """Configure Rainbow DQN agent parameters.

     Selects the Rainbow DQN algorithm, a single "main" Rainbow DQN network
     wrapper, parameter-noise exploration and prioritized experience replay
     memory.
     """
     super().__init__()
     self.algorithm = RainbowDQNAlgorithmParameters()
     # ParameterNoiseParameters changes the network wrapper parameters on the
     # object it is given, so the Rainbow wrappers must be assigned BEFORE it
     # is constructed. Assigning them afterwards would replace the wrappers it
     # modified and silently drop that change.
     self.network_wrappers = {"main": RainbowDQNNetworkParameters()}
     self.exploration = ParameterNoiseParameters(self)
     self.memory = PrioritizedExperienceReplayParameters()
Exemplo n.º 4
0
from rl_coach.agents.ddqn_agent import DDQNAgentParameters
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
from rl_coach.environments.environment import SingleLevelSelection
from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4, atari_schedule
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.memories.non_episodic.prioritized_experience_replay import PrioritizedExperienceReplayParameters
from rl_coach.schedules import LinearSchedule

#########
# Agent #
#########
agent_params = DDQNAgentParameters()
agent_params.network_wrappers['main'].learning_rate = 0.00025 / 4

# Replace the default memory with prioritized experience replay. Its beta
# follows LinearSchedule(0.4 -> 1) over 12.5M training iterations
# (12.5M training iterations = 50M steps = 200M frames).
agent_params.memory = PrioritizedExperienceReplayParameters()
agent_params.memory.beta = LinearSchedule(0.4, 1, 12500000)

###############
# Environment #
###############
# The Atari level is selected via SingleLevelSelection over
# atari_deterministic_v4.
env_params = Atari(level=SingleLevelSelection(atari_deterministic_v4))

########
# Test #
########
preset_validation_params = PresetValidationParameters()
preset_validation_params.trace_test_levels = ['breakout', 'pong', 'space_invaders']

graph_manager = BasicRLGraphManager(