Exemple #1
0
    def tick_clocks(self, session):
        '''
        Tick the clock of every env in self.env_space.envs and report whether the session should end.

        An env counts as done when env.done is truthy or its 't' clock exceeds env.max_timestep.
        A done env has each body in env.nanflat_body_e logged through self.body_done_log, saves each
        body's algorithm when the episode number is positive and a multiple of
        env.env_spec['save_epi_frequency'] (only if that key is present), then ticks 'epi'.
        An env that is not done ticks 't' instead.

        session is only forwarded to analysis.get_session_data.
        Returns True when every env's 'epi' clock exceeds its max_episode, or when every body
        signals early stop (currently never, see the TODO below); False otherwise.
        '''
        from slm_lab.experiment import analysis
        # TODO simplify below

        env_dones = []
        body_end_sessions = []
        for env in self.env_space.envs:
            done = env.done or env.clock.get('t') > env.max_timestep
            env_dones.append(done)
            if done:
                epi = env.clock.get('epi')
                save_this_epi = 'save_epi_frequency' in env.env_spec and (
                    epi % env.env_spec['save_epi_frequency']) == 0
                # the save decision is the same for every body of this env, so decide once
                should_save = epi > 0 and save_this_epi
                for body in env.nanflat_body_e:
                    self.body_done_log(body)
                    if should_save:
                        body.agent.algorithm.save(epi=epi)
                env.clock.tick('epi')
            else:
                env.clock.tick('t')
            # read after ticking, so a just-finished episode is already counted
            body_end_sessions.append(env.clock.get('epi') > env.max_episode)

        env_early_stops = []
        if any(env_dones) and self.clock.get('epi') > analysis.MA_WINDOW:
            # only the second element of the returned pair is used here
            _, session_data = analysis.get_session_data(session)
            for aeb in session_data:
                aeb_df = session_data[aeb]
                util.downcast_float32(aeb_df)
                body = self.body_space.data[aeb]
                env_epi = body.env.clock.get('epi')
                if env_epi > max(analysis.MA_WINDOW, body.env.max_episode / 2):
                    aeb_fitness_sr = analysis.calc_aeb_fitness_sr(
                        aeb_df, body.env.name)
                    # computed but not yet consumed: early stop stays disabled until the TODO is resolved
                    strength = aeb_fitness_sr['strength']
                    # TODO properly trigger early stop
                    # env_early_stop = strength < analysis.NOISE_WINDOW
                env_early_stops.append(False)
        else:
            env_early_stops.append(False)

        # all([]) is True: an empty list (no session data) must not count as a unanimous early stop
        early_stop = bool(env_early_stops) and all(env_early_stops)
        return all(body_end_sessions) or early_stop
Exemple #2
0
    def tick_clocks(self, session):
        '''
        Tick the clock of every env in self.env_space.envs and report whether the session should end.

        An env counts as done when env.done is truthy or its 't' clock exceeds env.max_timestep.
        A done env has each body in env.nanflat_body_e logged through self.body_done_log, saves each
        body's algorithm when the episode number is positive and a multiple of
        env.env_spec['save_epi_frequency'] (only if that key is present), then ticks 'epi'.
        An env that is not done ticks 't' instead.

        session is only forwarded to analysis.get_session_data.
        Returns True when every env's 'epi' clock exceeds its max_episode, or when every body
        signals early stop (currently never, see the TODO below); False otherwise.
        '''
        from slm_lab.experiment import analysis
        # TODO simplify below

        env_dones = []
        body_end_sessions = []
        for env in self.env_space.envs:
            done = env.done or env.clock.get('t') > env.max_timestep
            env_dones.append(done)
            if done:
                epi = env.clock.get('epi')
                save_this_epi = 'save_epi_frequency' in env.env_spec and (epi % env.env_spec['save_epi_frequency']) == 0
                # the save decision is the same for every body of this env, so decide once
                should_save = epi > 0 and save_this_epi
                for body in env.nanflat_body_e:
                    self.body_done_log(body)
                    if should_save:
                        body.agent.algorithm.save(epi=epi)
                env.clock.tick('epi')
            else:
                env.clock.tick('t')
            # read after ticking, so a just-finished episode is already counted
            body_end_sessions.append(env.clock.get('epi') > env.max_episode)

        env_early_stops = []
        if any(env_dones) and self.clock.get('epi') > analysis.MA_WINDOW:
            # only the second element of the returned pair is used here
            _, session_data = analysis.get_session_data(session)
            for aeb in session_data:
                aeb_df = session_data[aeb]
                util.downcast_float32(aeb_df)
                body = self.body_space.data[aeb]
                env_epi = body.env.clock.get('epi')
                if env_epi > max(analysis.MA_WINDOW, body.env.max_episode / 2):
                    aeb_fitness_sr = analysis.calc_aeb_fitness_sr(aeb_df, body.env.name)
                    # computed but not yet consumed: early stop stays disabled until the TODO is resolved
                    strength = aeb_fitness_sr['strength']
                    # TODO properly trigger early stop
                    # env_early_stop = strength < analysis.NOISE_WINDOW
                env_early_stops.append(False)
        else:
            env_early_stops.append(False)

        # all([]) is True: an empty list (no session data) must not count as a unanimous early stop
        early_stop = bool(env_early_stops) and all(env_early_stops)
        return all(body_end_sessions) or early_stop
Exemple #3
0
    def tick_clocks(self, session):
        '''
        Tick the clock of every env in self.env_space.envs and report whether the session should end.

        An env counts as done when env.done is truthy or its 't' clock exceeds env.max_timestep.
        A done env has its entries in the 'done' data space (both data and swap_data) set to 1,
        a "Done: ..." line logged, and its 'epi' clock ticked. An env that is not done ticks 't'.

        Early stop is only evaluated when at least one env finished on this call and self.clock's
        'epi' exceeds analysis.MA_WINDOW. A body then votes for early stop when its env's episode
        count exceeds max(analysis.MA_WINDOW, max_episode / 2) and its fitness 'strength' is below
        analysis.NOISE_WINDOW.

        session is only forwarded to analysis.get_session_data.
        Returns True when every env's 'epi' clock exceeds its max_episode, or when there is at
        least one body and all of them vote for early stop; False otherwise.
        '''
        from slm_lab.experiment import analysis

        done_space = self.data_spaces['done']
        env_dones = []
        body_end_sessions = []
        for env in self.env_space.envs:
            clock = env.clock
            done = env.done or clock.get('t') > env.max_timestep
            env_dones.append(done)
            if done:
                # flag this env as done in both layouts of the done space
                done_space.data[:, env.e, :] = 1.
                done_space.swap_data[env.e, :, :] = 1.
                msg = f'Done: trial {self.info_space.get("trial")} session {self.info_space.get("session")} env {env.e} epi {clock.get("epi")}, t {clock.get("t")}'
                logger.info(msg)
                clock.tick('epi')
            else:
                clock.tick('t')
            # read after ticking, so a just-finished episode is already counted
            body_end_sessions.append(clock.get('epi') > env.max_episode)

        env_early_stops = []
        if any(env_dones) and self.clock.get('epi') > analysis.MA_WINDOW:
            # only the second element of the returned pair is used here
            _, session_data = analysis.get_session_data(session)
            for aeb in session_data:
                aeb_df = session_data[aeb]
                util.downcast_float32(aeb_df)
                body = self.body_space.data[aeb]
                env_epi = body.env.clock.get('epi')
                if env_epi > max(analysis.MA_WINDOW, body.env.max_episode / 2):
                    aeb_fitness_sr = analysis.calc_aeb_fitness_sr(
                        aeb_df, body.env.name)
                    strength = aeb_fitness_sr['strength']
                    env_early_stop = strength < analysis.NOISE_WINDOW
                else:
                    env_early_stop = False
                env_early_stops.append(env_early_stop)
        else:
            env_early_stops.append(False)

        # all([]) is True: an empty list (no session data) must not count as a unanimous early stop
        early_stop = bool(env_early_stops) and all(env_early_stops)
        return all(body_end_sessions) or early_stop