def _create_graph(self, task_parameters: TaskParameters) -> Tuple[List[MultiAgentLevelManager], List[Environment]]:
    # environment loading
    self.env_params.seed = task_parameters.seed
    self.env_params.experiment_path = task_parameters.experiment_path
    env = short_dynamic_import(self.env_params.path)(**self.env_params.__dict__,
                                                     visualization_parameters=self.visualization_parameters)

    # agent loading
    agents = OrderedDict()
    for agent_params in self.agents_params:
        agent_params.task_parameters = copy.copy(task_parameters)
        agent = short_dynamic_import(agent_params.path)(agent_params)
        agents[agent_params.name] = agent
        screen.log_title("Created agent: {}".format(agent_params.name))

        # Rollout workers ship their experience to the trainer through the DeepRacer
        # memory backend instead of keeping it in the agent's local replay memory.
        if hasattr(self, 'memory_backend_params') and \
                self.memory_backend_params.run_type == str(RunType.ROLLOUT_WORKER):
            agent.memory.memory_backend = deepracer_memory.DeepRacerRolloutBackEnd(
                self.memory_backend_params,
                agent_params.algorithm.num_consecutive_playing_steps,
                agent_params.name)

    # set level manager
    level_manager = MultiAgentLevelManager(agents=agents, environment=env, name="main_level",
                                           done_condition=self.done_condition)

    return [level_manager], [env]
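
# The environment and agent classes above are resolved at runtime from the `path`
# strings on their parameter objects. The sketch below illustrates the kind of lookup
# involved, assuming a "module.path:ClassName" format; the real
# rl_coach.utils.short_dynamic_import helper is more general (it also accepts
# file-system paths), so treat this purely as an illustration.
import importlib


def _dynamic_import_sketch(path):
    """Resolve 'some.module:SomeClass' to the class object (illustrative only)."""
    module_name, _, attr_name = path.partition(':')
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# Example: _dynamic_import_sketch('collections:OrderedDict')() builds an empty OrderedDict.
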
def rollout_worker(graph_manager, checkpoint_dir, data_store, num_workers, memory_backend_params):
    """
    Wait for the first checkpoint, then perform rollouts using the model.
    """
    wait_for_checkpoint(checkpoint_dir, data_store)

    task_parameters = TaskParameters()
    task_parameters.__dict__['checkpoint_restore_dir'] = checkpoint_dir

    graph_manager.create_graph(task_parameters)
    graph_manager.reset_internal_state()

    # Route every agent's experience through the DeepRacer rollout memory backend so
    # episodes are published to the trainer instead of being stored locally.
    for level in graph_manager.level_managers:
        for agent in level.agents.values():
            agent.memory.memory_backend = deepracer_memory.DeepRacerRolloutBackEnd(
                memory_backend_params,
                graph_manager.agent_params.algorithm.num_consecutive_playing_steps)

    try:
        with graph_manager.phase_context(RunPhase.TRAIN):
            last_checkpoint = 0

            # Split the configured number of playing steps evenly across the rollout workers.
            act_steps = math.ceil(
                graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps / num_workers)

            for _ in range(int(graph_manager.improve_steps.num_steps / act_steps)):
                if should_stop(checkpoint_dir):
                    break

                try:
                    # This will only work for DeepRacerRacetrackEnv environments
                    graph_manager.top_level_manager.environment.env.env.set_allow_servo_step_signals(True)
                except Exception as ex:
                    utils.json_format_logger(
                        "Method not defined in environment class: {}".format(ex),
                        **utils.build_system_error_dict(utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                                                        utils.SIMAPP_EVENT_ERROR_CODE_500))

                # Collect experience for this iteration, either step-based or episode-based.
                if type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentSteps:
                    graph_manager.act(EnvironmentSteps(num_steps=act_steps),
                                      wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes)
                elif type(graph_manager.agent_params.algorithm.num_consecutive_playing_steps) == EnvironmentEpisodes:
                    graph_manager.act(EnvironmentEpisodes(num_steps=act_steps))

                try:
                    # This will only work for DeepRacerRacetrackEnv environments
                    graph_manager.top_level_manager.environment.env.env.set_allow_servo_step_signals(False)
                    graph_manager.top_level_manager.environment.env.env.stop_car()
                except Exception as ex:
                    utils.json_format_logger(
                        "Method not defined in environment class: {}".format(ex),
                        **utils.build_system_error_dict(utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                                                        utils.SIMAPP_EVENT_ERROR_CODE_500))

                # SYNC: block until the trainer publishes the next checkpoint, then restore it.
                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == \
                        DistributedCoachSynchronizationType.SYNC:
                    data_store.load_from_store(expected_checkpoint_number=last_checkpoint + 1)
                    last_checkpoint = get_latest_checkpoint(checkpoint_dir)
                    graph_manager.restore_checkpoint()

                # ASYNC: restore only if a newer checkpoint has appeared since the last rollout.
                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == \
                        DistributedCoachSynchronizationType.ASYNC:
                    new_checkpoint = get_latest_checkpoint(checkpoint_dir)
                    if new_checkpoint > last_checkpoint:
                        graph_manager.restore_checkpoint()
                        last_checkpoint = new_checkpoint

    except Exception as ex:
        utils.json_format_logger(
            "An error occurred during simulation: {}".format(ex),
            **utils.build_system_error_dict(utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                                            utils.SIMAPP_EVENT_ERROR_CODE_500))
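
# rollout_worker relies on a few coordination helpers (wait_for_checkpoint, should_stop,
# get_latest_checkpoint) whose implementations are not shown in this excerpt. The sketches
# below illustrate one plausible file-based scheme -- a trainer-written marker file plus a
# TensorFlow-style "checkpoint" state file -- and are assumptions about their behaviour,
# not the actual DeepRacer implementation.
import os
import re
import time


def _should_stop_sketch(checkpoint_dir, marker_name='.finished'):
    """Assume the trainer drops a marker file into checkpoint_dir when training ends."""
    return os.path.exists(os.path.join(checkpoint_dir, marker_name))


def _get_latest_checkpoint_sketch(checkpoint_dir):
    """Parse the newest checkpoint number from a TensorFlow 'checkpoint' state file."""
    state_file = os.path.join(checkpoint_dir, 'checkpoint')
    if not os.path.exists(state_file):
        return -1
    with open(state_file) as f:
        contents = f.read()
    # Checkpoint files are assumed to be named "<step>_Step-<...>.ckpt".
    numbers = [int(n) for n in re.findall(r'model_checkpoint_path: "(\d+)', contents)]
    return max(numbers) if numbers else -1


def _wait_for_checkpoint_sketch(checkpoint_dir, data_store, poll_seconds=10):
    """Poll the data store until the first checkpoint is available locally."""
    while _get_latest_checkpoint_sketch(checkpoint_dir) < 0:
        data_store.load_from_store()
        time.sleep(poll_seconds)
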