def __init__(self,
             name: str,
             schedule_params: ScheduleParameters,
             vis_params: VisualizationParameters = None):
    """
    Initialize the graph manager's bookkeeping state.

    Session, level managers, environments, checkpoint saver/state updater and
    data store all start out empty (``None`` / ``[]``); this method only sets
    up attributes, timers and per-phase step counters.

    :param name: name of this graph manager
    :param schedule_params: schedule parameters, applied via ``set_schedule_params``
    :param vis_params: visualization parameters. When omitted (or ``None``) a
        fresh ``VisualizationParameters()`` is created for this instance.
    """
    # A ``VisualizationParameters()`` default in the signature would be
    # evaluated once at definition time and shared (and mutable) across every
    # instance that omits the argument; build a fresh one per instance.
    if vis_params is None:
        vis_params = VisualizationParameters()

    self.sess = None
    self.level_managers = []  # type: List[LevelManager]
    self.top_level_manager = None
    self.environments = []
    self.set_schedule_params(schedule_params)
    self.visualization_parameters = vis_params
    self.name = name
    self.task_parameters = None
    self._phase = self.phase = RunPhase.UNDEFINED
    self.preset_validation_params = PresetValidationParameters()
    self.reset_required = False

    # timers
    self.graph_creation_time = None
    self.last_checkpoint_saving_time = time.time()

    # counters: one total-steps counter per run phase
    self.total_steps_counters = {
        RunPhase.HEATUP: TotalStepsCounter(),
        RunPhase.TRAIN: TotalStepsCounter(),
        RunPhase.TEST: TotalStepsCounter()
    }
    self.checkpoint_id = 0

    self.checkpoint_saver = None
    self.checkpoint_state_updater = None
    self.graph_logger = Logger()
    self.data_store = None
    self.is_batch_rl = False
    self.time_metric = TimeTypes.EpisodeNumber
def __init__(self,
             name: str,
             schedule_params: ScheduleParameters,
             vis_params: VisualizationParameters = None):
    """
    Initialize the graph manager's bookkeeping state.

    Copies the heatup / evaluation / improve step settings out of
    ``schedule_params``, records the initialization time and creates one
    total-steps counter per run phase. Session, level managers, environments
    and checkpoint saver start out empty (``None`` / ``[]``).

    :param name: name of this graph manager
    :param schedule_params: schedule parameters whose step settings are copied
    :param vis_params: visualization parameters. When omitted (or ``None``) a
        fresh ``VisualizationParameters()`` is created for this instance.
    """
    # A ``VisualizationParameters()`` default in the signature would be
    # evaluated once at definition time and shared (and mutable) across every
    # instance that omits the argument; build a fresh one per instance.
    if vis_params is None:
        vis_params = VisualizationParameters()

    self.sess = None
    self.level_managers = []  # type: List[LevelManager]
    self.top_level_manager = None
    self.environments = []

    # schedule settings
    self.heatup_steps = schedule_params.heatup_steps
    self.evaluation_steps = schedule_params.evaluation_steps
    self.steps_between_evaluation_periods = schedule_params.steps_between_evaluation_periods
    self.improve_steps = schedule_params.improve_steps

    self.visualization_parameters = vis_params
    self.name = name
    self.task_parameters = None
    self._phase = self.phase = RunPhase.UNDEFINED
    self.preset_validation_params = PresetValidationParameters()

    # timers
    self.graph_initialization_time = time.time()
    self.heatup_start_time = None
    self.training_start_time = None
    self.last_evaluation_start_time = None
    self.last_checkpoint_saving_time = time.time()

    # counters: one total-steps counter per run phase
    self.total_steps_counters = {
        RunPhase.HEATUP: TotalStepsCounter(),
        RunPhase.TRAIN: TotalStepsCounter(),
        RunPhase.TEST: TotalStepsCounter()
    }
    self.checkpoint_id = 0

    self.checkpoint_saver = None
    self.graph_logger = Logger()
def __init__(
        self,
        agents_params: List[AgentParameters],
        env_params: EnvironmentParameters,
        schedule_params: ScheduleParameters,
        vis_params: VisualizationParameters = None,
        preset_validation_params: PresetValidationParameters = None,
        num_checkpoints_to_keep: int = 4):
    """
    Initialize a multi-agent graph manager named 'multi_agent_graph'.

    The given agent parameter objects are modified in place: each agent is
    renamed to "agent" (single agent) or "agent_<index>" (several agents),
    receives a shallow copy of ``vis_params`` as its ``visualization``, and
    gets a shallow copy of the environment's default input/output filter
    when its own filter is ``None``.

    :param agents_params: parameters of every agent; must be non-empty
        (``agents_params[0]`` is exposed as ``self.agent_params``)
    :param env_params: environment parameters, also used for default filters
    :param schedule_params: schedule parameters, applied via ``set_schedule_params``
    :param vis_params: visualization parameters; a fresh
        ``VisualizationParameters()`` is created when omitted
    :param preset_validation_params: preset validation parameters; a fresh
        ``PresetValidationParameters()`` is created when omitted
    :param num_checkpoints_to_keep: number of checkpoints to retain
        (previously hard-coded to 4)
    """
    # Defaults built in the signature would be evaluated once and shared
    # (mutably) by every instance; create fresh objects per instance instead.
    if vis_params is None:
        vis_params = VisualizationParameters()
    if preset_validation_params is None:
        preset_validation_params = PresetValidationParameters()

    # Normalize agent names (and fill per-agent defaults) BEFORE building the
    # per-agent dictionaries below, so those dicts are keyed by the final names
    # rather than by whatever names the caller passed in.
    multiple_agents = len(agents_params) != 1
    for agent_index, agent_params in enumerate(agents_params):
        agent_params.name = "agent_{}".format(agent_index) if multiple_agents else "agent"
        agent_params.visualization = copy.copy(vis_params)
        if agent_params.input_filter is None:
            agent_params.input_filter = copy.copy(env_params.default_input_filter())
        if agent_params.output_filter is None:
            agent_params.output_filter = copy.copy(env_params.default_output_filter())

    self.sess = {agent_params.name: None for agent_params in agents_params}
    self.level_managers = []  # type: List[MultiAgentLevelManager]
    self.top_level_manager = None
    self.environments = []
    self.set_schedule_params(schedule_params)
    self.visualization_parameters = vis_params
    self.name = 'multi_agent_graph'
    self.task_parameters = None
    self._phase = self.phase = RunPhase.UNDEFINED
    self.preset_validation_params = preset_validation_params
    self.reset_required = False
    self.num_checkpoints_to_keep = num_checkpoints_to_keep

    # timers
    self.graph_creation_time = None
    self.last_checkpoint_saving_time = time.time()

    # counters: one total-steps counter per run phase
    self.total_steps_counters = {
        RunPhase.HEATUP: TotalStepsCounter(),
        RunPhase.TRAIN: TotalStepsCounter(),
        RunPhase.TEST: TotalStepsCounter()
    }
    self.checkpoint_id = 0

    self.checkpoint_saver = {
        agent_params.name: None for agent_params in agents_params
    }
    self.checkpoint_state_updater = None
    self.graph_logger = Logger()
    self.data_store = None
    self.is_batch_rl = False
    self.time_metric = TimeTypes.EpisodeNumber

    self.env_params = env_params
    self.agents_params = agents_params
    # NOTE: only the first agent is exposed as ``agent_params``
    # (original TODO: find a better way).
    self.agent_params = agents_params[0]
def test_add_total_steps_counter():
    """A fresh counter plus 10 environment steps yields a 10-step target."""
    fresh_counter = TotalStepsCounter()
    target = fresh_counter + EnvironmentSteps(10)
    assert target.num_steps == 10
def test_total_steps_counter_less_than():
    """The counter is below a target only when the target adds at least one step."""
    total = TotalStepsCounter()

    # zero extra steps: the counter is not strictly below the target
    target = total + EnvironmentSteps(0)
    assert not total < target

    # one extra step: now it is
    target = total + EnvironmentSteps(1)
    assert total < target
def test_add_total_steps_counter_non_zero():
    """Steps already on the counter are included when adding more steps."""
    total = TotalStepsCounter()
    total[EnvironmentSteps] += 10  # pre-load the counter with 10 steps
    target = total + EnvironmentSteps(10)
    assert target.num_steps == 20