# Example #1
from rl_coach.agents.ddpg_agent import DDPGAgentParameters
from rl_coach.agents.dqn_agent import DQNAgentParameters
from rl_coach.graph_managers.graph_manager import SimpleSchedule, SimpleScheduleWithoutEvaluation
from rl_coach.core_types import EnvironmentSteps, TrainingSteps
from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.environments.first_test import ControlSuiteEnvironmentParameters

#########
# Agent #
#########
agent_params = DQNAgentParameters()
# rename the input embedder key from 'observation' to 'measurements'
# agent_params.network_wrappers['main'].input_embedders_parameters['measurements'] = agent_params.network_wrappers['main'].input_embedders_parameters.pop('observation')

####################
# Graph Scheduling #
####################
# SimpleSchedule with a short 10-environment-step heatup.
schedule_params = SimpleSchedule()
schedule_params.heatup_steps = EnvironmentSteps(10)

##############
# Validation #
##############
# Left at its defaults; uncomment the lines below to enable the preset reward test.
preset_validation_params = PresetValidationParameters()
# preset_validation_params.test = True
# preset_validation_params.min_reward_threshold = 20
# preset_validation_params.max_episodes_to_achieve_reward = 400

#################
# Visualization #
#################
vis_params = VisualizationParameters(render=False)

###############
# Environment #
###############
env_params = ControlSuiteEnvironmentParameters()

#################
# Graph Manager #
#################
graph_manager = BasicRLGraphManager(
    agent_params=agent_params,
    env_params=env_params,
    schedule_params=schedule_params,
    # Pass the configured vis_params; previously a fresh VisualizationParameters()
    # was passed here, so settings made on vis_params above were silently ignored.
    vis_params=vis_params,
    preset_validation_params=preset_validation_params)
# Example #2
from rl_coach.agents.human_agent import HumanAgentParameters
from rl_coach.architectures.middleware_parameters import LSTMMiddlewareParameters
from rl_coach.base_parameters import VisualizationParameters, MiddlewareScheme, PresetValidationParameters
from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps
# from rl_coach.environments.environment import SingleLevelSelection
# from rl_coach.environments.gym_environment import Atari, atari_deterministic_v4
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import SimpleSchedule
from rl_coach.environments.environment import SingleLevelSelection
from rl_coach.environments.lab_environment import LabEnvironmentParameters, level_scripts
####################
# Graph Scheduling #
####################
# 5,000,000 improve (training) steps, 5 evaluation episodes after every
# 20 environment episodes, and a 1,000-environment-step heatup.
# NOTE(review): a superseded duplicate of this config used TrainingSteps(50000);
# the 5000000 value below is the one that previously took effect — confirm intended.
schedule_params = SimpleSchedule()
schedule_params.improve_steps = TrainingSteps(5000000)
schedule_params.steps_between_evaluation_periods = EnvironmentEpisodes(20)
schedule_params.evaluation_steps = EnvironmentEpisodes(5)
schedule_params.heatup_steps = EnvironmentSteps(1000)

#########
# Agent #
#########
# Human agent; paired with human_control=True in the environment below.
agent_params = HumanAgentParameters()

###############
# Environment #
###############
env_params = LabEnvironmentParameters(
    level=SingleLevelSelection(level_scripts),
    human_control=True,
    width=100,
    height=100)
env_params.frame_skip = 5  # to make sure the gifs work without skipping steps

#################
# Visualization #
#################
vis_params = VisualizationParameters()
vis_params.dump_gifs = True
# vis_params.video_dump_methods = [SelectedPhaseOnlyDumpMethod(RunPhase.TEST), MaxDumpMethod()]

#################
# Graph Manager #
#################
graph_manager = BasicRLGraphManager(
    agent_params=agent_params,
    env_params=env_params,
    schedule_params=schedule_params,
    vis_params=vis_params)

# Standalone-run sketch, kept commented out: TaskParameters, logger and the
# video dump method classes referenced here are not imported by this preset.
# experiment_name = "TSPEasy"
# experiment_name = logger.get_experiment_name(experiment_name)
# experiment_path = logger.get_experiment_path(experiment_name)
# task_params = TaskParameters(experiment_path=experiment_path)
# graph_manager = graph_manager.create_graph(task_parameters=task_params)
# graph_manager.improve()