コード例 #1
0
    def __init__(self,
                 name: str,
                 schedule_params: ScheduleParameters,
                 vis_params: VisualizationParameters = None):
        """Initialize an empty graph manager.

        :param name: human-readable name of this graph manager.
        :param schedule_params: schedule configuration, forwarded to
            ``set_schedule_params``.
        :param vis_params: visualization configuration. When omitted (or None),
            a fresh ``VisualizationParameters()`` is created for this instance,
            so separate managers never share one mutable settings object.
        """
        # A default argument is evaluated once, when the ``def`` executes, so a
        # ``VisualizationParameters()`` default would be a single object shared
        # by every instance: mutating one manager's visualization settings would
        # silently change them for all others. Build a fresh one per instance.
        if vis_params is None:
            vis_params = VisualizationParameters()

        self.sess = None
        self.level_managers = []  # type: List[LevelManager]
        self.top_level_manager = None
        self.environments = []
        self.set_schedule_params(schedule_params)
        self.visualization_parameters = vis_params
        self.name = name
        self.task_parameters = None
        # Chained assignment: sets the backing field first, then assigns through
        # ``phase`` (presumably a property setter defined elsewhere in the class
        # -- confirm). level_managers / environments are initialized above in
        # case that setter iterates them.
        self._phase = self.phase = RunPhase.UNDEFINED
        self.preset_validation_params = PresetValidationParameters()
        self.reset_required = False

        # timers (last_checkpoint_saving_time is a time.time() timestamp)
        self.graph_creation_time = None
        self.last_checkpoint_saving_time = time.time()

        # counters: one independent step counter per run phase
        self.total_steps_counters = {
            RunPhase.HEATUP: TotalStepsCounter(),
            RunPhase.TRAIN: TotalStepsCounter(),
            RunPhase.TEST: TotalStepsCounter()
        }
        self.checkpoint_id = 0

        self.checkpoint_saver = None
        self.checkpoint_state_updater = None
        self.graph_logger = Logger()
        self.data_store = None
        self.is_batch_rl = False
        self.time_metric = TimeTypes.EpisodeNumber
コード例 #2
0
    def __init__(self,
                 name: str,
                 schedule_params: ScheduleParameters,
                 vis_params: VisualizationParameters = None):
        """Initialize an empty graph manager.

        :param name: human-readable name of this graph manager.
        :param schedule_params: schedule configuration; its ``heatup_steps``,
            ``evaluation_steps``, ``steps_between_evaluation_periods`` and
            ``improve_steps`` are copied onto this instance.
        :param vis_params: visualization configuration. When omitted (or None),
            a fresh ``VisualizationParameters()`` is created for this instance,
            so separate managers never share one mutable settings object.
        """
        # A default argument is evaluated once, when the ``def`` executes, so a
        # ``VisualizationParameters()`` default would be a single object shared
        # by every instance: mutating one manager's visualization settings would
        # silently change them for all others. Build a fresh one per instance.
        if vis_params is None:
            vis_params = VisualizationParameters()

        self.sess = None
        self.level_managers = []
        self.top_level_manager = None
        self.environments = []
        self.heatup_steps = schedule_params.heatup_steps
        self.evaluation_steps = schedule_params.evaluation_steps
        self.steps_between_evaluation_periods = schedule_params.steps_between_evaluation_periods
        self.improve_steps = schedule_params.improve_steps
        self.visualization_parameters = vis_params
        self.name = name
        self.task_parameters = None
        # Chained assignment: sets the backing field first, then assigns through
        # ``phase`` (presumably a property setter defined elsewhere in the class
        # -- confirm). level_managers / environments are initialized above in
        # case that setter iterates them.
        self._phase = self.phase = RunPhase.UNDEFINED
        self.preset_validation_params = PresetValidationParameters()

        # timers: time.time() timestamps; the None entries are presumably
        # filled in when the corresponding phase starts (set outside this method)
        self.graph_initialization_time = time.time()
        self.heatup_start_time = None
        self.training_start_time = None
        self.last_evaluation_start_time = None
        self.last_checkpoint_saving_time = time.time()

        # counters: one independent step counter per run phase
        self.total_steps_counters = {
            RunPhase.HEATUP: TotalStepsCounter(),
            RunPhase.TRAIN: TotalStepsCounter(),
            RunPhase.TEST: TotalStepsCounter()
        }
        self.checkpoint_id = 0

        self.checkpoint_saver = None
        self.graph_logger = Logger()
コード例 #3
0
    def __init__(
        self,
        agents_params: List[AgentParameters],
        env_params: EnvironmentParameters,
        schedule_params: ScheduleParameters,
        vis_params: VisualizationParameters = None,
        preset_validation_params:
        PresetValidationParameters = None):
        """Initialize a graph manager that drives several agents in one environment.

        Every entry of ``agents_params`` is modified in place. It is given a
        unique ``name`` ("agent" when there is a single agent, otherwise
        "agent_<index>") and its own shallow copy of ``vis_params``. It also
        receives copies of the environment's default input/output filters when
        it has none.

        :param agents_params: parameters of every agent. The first entry is also
            exposed as ``self.agent_params``.
        :param env_params: environment parameters; supplies the default filters.
        :param schedule_params: schedule configuration, forwarded to
            ``set_schedule_params``.
        :param vis_params: visualization configuration; a fresh
            ``VisualizationParameters()`` is used when omitted (or None).
        :param preset_validation_params: preset validation configuration; a fresh
            ``PresetValidationParameters()`` is used when omitted (or None).
        :raises IndexError: if ``agents_params`` is empty.
        """
        # Defaults are evaluated once, at ``def`` time, so object defaults would
        # be shared (and mutated) across every instance. Build fresh ones here.
        if vis_params is None:
            vis_params = VisualizationParameters()
        if preset_validation_params is None:
            preset_validation_params = PresetValidationParameters()

        # Finalize each agent's name (and per-agent settings) BEFORE building any
        # dict keyed by ``agent_params.name``. Otherwise those dicts would be keyed
        # by the stale, pre-rename names.
        single_agent = len(agents_params) == 1
        for agent_index, agent_params in enumerate(agents_params):
            if single_agent:
                agent_params.name = "agent"
            else:
                agent_params.name = "agent_{}".format(agent_index)
            # shallow copy so agents do not share the caller's object
            agent_params.visualization = copy.copy(vis_params)
            if agent_params.input_filter is None:
                agent_params.input_filter = copy.copy(
                    env_params.default_input_filter())
            if agent_params.output_filter is None:
                agent_params.output_filter = copy.copy(
                    env_params.default_output_filter())

        # one (not yet created) session per agent, keyed by the final agent name
        self.sess = {agent_params.name: None for agent_params in agents_params}
        self.level_managers = []  # type: List[MultiAgentLevelManager]
        self.top_level_manager = None
        self.environments = []
        self.set_schedule_params(schedule_params)
        self.visualization_parameters = vis_params
        self.name = 'multi_agent_graph'
        self.task_parameters = None
        # Chained assignment: sets the backing field first, then assigns through
        # ``phase`` (presumably a property setter defined elsewhere -- confirm).
        self._phase = self.phase = RunPhase.UNDEFINED
        self.preset_validation_params = preset_validation_params
        self.reset_required = False
        self.num_checkpoints_to_keep = 4  # TODO: make this a parameter

        # timers (last_checkpoint_saving_time is a time.time() timestamp)
        self.graph_creation_time = None
        self.last_checkpoint_saving_time = time.time()

        # counters: one independent step counter per run phase
        self.total_steps_counters = {
            RunPhase.HEATUP: TotalStepsCounter(),
            RunPhase.TRAIN: TotalStepsCounter(),
            RunPhase.TEST: TotalStepsCounter()
        }
        self.checkpoint_id = 0

        # one (not yet created) checkpoint saver per agent, keyed by agent name
        self.checkpoint_saver = {
            agent_params.name: None
            for agent_params in agents_params
        }
        self.checkpoint_state_updater = None
        self.graph_logger = Logger()
        self.data_store = None
        self.is_batch_rl = False
        self.time_metric = TimeTypes.EpisodeNumber

        self.env_params = env_params
        self.agents_params = agents_params
        # NOTE(review): exposes only the first agent's parameters; raises
        # IndexError for an empty list. Original author flagged this as a stopgap.
        self.agent_params = agents_params[0]  # ...(find a better way)...
コード例 #4
0
def test_add_total_steps_counter():
    """Adding EnvironmentSteps(10) to a fresh counter yields a 10-step result."""
    result = TotalStepsCounter() + EnvironmentSteps(10)
    assert result.num_steps == 10
コード例 #5
0
def test_total_steps_counter_less_than():
    """A fresh counter is not below a +0 offset of itself, but is below +1."""
    base = TotalStepsCounter()
    # (offset added to the counter, whether base < base + offset should hold)
    for offset, expect_less in ((0, False), (1, True)):
        target = base + EnvironmentSteps(offset)
        assert bool(base < target) is expect_less
コード例 #6
0
def test_add_total_steps_counter_non_zero():
    """An offset is added on top of steps the counter has already recorded."""
    prior, offset = 10, 10
    tally = TotalStepsCounter()
    tally[EnvironmentSteps] = tally[EnvironmentSteps] + prior
    combined = tally + EnvironmentSteps(offset)
    assert combined.num_steps == prior + offset