Code Example #1
        def __init__(self, worker_id, config, env_config, buffer_config):
            config_attr(self, config)
            cpu_affinity(f'Worker_{worker_id}')
            self._id = worker_id

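            # build this worker's vectorized environments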
            self._n_envvecs = env_config.pop('n_envvecs')
            env_config.pop('n_workers', None)
            self._envvecs = [
                create_env(env_config, force_envvec=True)
                for _ in range(self._n_envvecs)
            ]

            collect_fn = pkg.import_module('agent', config=config,
                                           place=-1).collect
            self._collect = functools.partial(collect_fn,
                                              env=None,
                                              step=None,
                                              reset=None)

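            # one local buffer per environment vector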
            buffer_config['force_envvec'] = True
            self._buffs = {
                eid: create_local_buffer(buffer_config)
                for eid in range(self._n_envvecs)
            }

            self._obs = {
                eid: e.output().obs
                for eid, e in enumerate(self._envvecs)
            }
            self._info = collections.defaultdict(list)
Code Example #2
File: actor.py Project: xlnwel/d2rl
        def __init__(self, model_fn, replay, config, model_config, env_config,
                     replay_config):
            config_actor('Learner', config)

            env = create_env(env_config)

            model = model_fn(config=model_config, env=env)

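            # derive the data format from the algorithm's agent module and build
            # a ray-aware dataset that reads from the (remote) replay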
            am = pkg.import_module('agent', config=config, place=-1)
            data_format = am.get_data_format(env=env,
                                             replay_config=replay_config,
                                             agent_config=config,
                                             model=model)
            dataset = create_dataset(replay,
                                     env,
                                     data_format=data_format,
                                     use_ray=True)

            super().__init__(
                name='Learner',
                config=config,
                models=model,
                dataset=dataset,
                env=env,
            )
Code Example #3
File: train.py Project: xlnwel/d2rl
def main(env_config, model_config, agent_config, buffer_config, train=train):
    silence_tf_logs()
    configure_gpu()
    configure_precision(agent_config['precision'])

    create_model, Agent = pkg.import_agent(config=agent_config)
    Buffer = pkg.import_module('buffer', config=agent_config).Buffer

    use_ray = env_config.get('n_workers', 1) > 1
    if use_ray:
        import ray
        from utility.ray_setup import sigint_shutdown_ray
        ray.init()
        sigint_shutdown_ray()

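    # build the training env, plus an evaluation env with a shifted seed
    # and reward-shaping keys removed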
    env = create_env(env_config, force_envvec=True)
    eval_env_config = env_config.copy()
    if 'num_levels' in eval_env_config:
        eval_env_config['num_levels'] = 0
    if 'seed' in eval_env_config:
        eval_env_config['seed'] += 1000
    eval_env_config['n_workers'] = 1
    for k in list(eval_env_config.keys()):
        # drop reward-shaping ('reward hack') keys for evaluation
        if 'reward' in k:
            eval_env_config.pop(k)
    eval_env = create_env(eval_env_config, force_envvec=True)

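    # close both environments cleanly on Ctrl-C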
    def sigint_handler(sig, frame):
        signal.signal(sig, signal.SIG_IGN)
        env.close()
        eval_env.close()
        sys.exit(0)

    signal.signal(signal.SIGINT, sigint_handler)

    models = create_model(model_config, env)

    buffer_config['n_envs'] = env.n_envs
    buffer_config['state_keys'] = models.state_keys
    buffer = Buffer(buffer_config)

    agent = Agent(config=agent_config, models=models, dataset=buffer, env=env)

    agent.save_config(
        dict(env=env_config,
             model=model_config,
             agent=agent_config,
             buffer=buffer_config))

    train(agent, env, eval_env, buffer)

    if use_ray:
        env.close()
        eval_env.close()
        ray.shutdown()
Code Example #4
File: actor.py Project: xlnwel/d2rl
        def __init__(self, *, worker_id, config, model_config, env_config,
                     buffer_config, model_fn, buffer_fn):
            config_actor(f'Worker_{worker_id}', config)
            self._id = worker_id

            self.env = create_env(env_config)

            buffer_config['n_envs'] = self.env.n_envs
            if 'seqlen' not in buffer_config:
                buffer_config['seqlen'] = self.env.max_episode_steps
            self.buffer = buffer_fn(buffer_config)

            models = model_fn(config=model_config, env=self.env)

            super().__init__(name=f'Worker_{worker_id}',
                             config=config,
                             models=models,
                             dataset=self.buffer,
                             env=self.env)

            # set up the runner
            import importlib
            em = importlib.import_module(
                f'env.{env_config["name"].split("_")[0]}')
            info_func = em.info_func if hasattr(em, 'info_func') else None
            self._run_mode = getattr(self, '_run_mode', RunMode.NSTEPS)
            assert self._run_mode in [RunMode.NSTEPS, RunMode.TRAJ]
            self.runner = Runner(self.env,
                                 self,
                                 nsteps=self.SYNC_PERIOD
                                 if self._run_mode == RunMode.NSTEPS else None,
                                 run_mode=self._run_mode,
                                 record_envs=getattr(self, '_record_envs',
                                                     None),
                                 info_func=info_func)

            # worker side prioritization
            self._worker_side_prioritization = getattr(
                self, '_worker_side_prioritization', False)
            self._return_stats = self._worker_side_prioritization \
                or buffer_config.get('max_steps', 0) > buffer_config.get('n_steps', 1)

            # set up self._collect with the <collect> function from the algorithm module
            collect_fn = pkg.import_module('agent',
                                           algo=self._algorithm,
                                           place=-1).collect
            self._collect = functools.partial(collect_fn, self.buffer)

            # the names of network modules that should be in sync with the learner
            if not hasattr(self, '_pull_names'):
                self._pull_names = [
                    k for k in self.model.keys() if 'target' not in k
                ]

            # used for recording worker side info
            self._info = collections.defaultdict(list)
Code Example #5
File: train.py Project: xlnwel/d2rl
def main(env_config, model_config, agent_config, replay_config):
    silence_tf_logs()
    configure_gpu()
    configure_precision(agent_config.get('precision', 32))

    use_ray = env_config.get('n_workers', 1) > 1
    if use_ray:
        import ray
        from utility.ray_setup import sigint_shutdown_ray
        ray.init()
        sigint_shutdown_ray()

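    # training env; the evaluation env below runs a single environment
    # with reward-shaping keys removed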
    env = create_env(env_config)
    eval_env_config = env_config.copy()
    eval_env_config['n_workers'] = 1
    eval_env_config['n_envs'] = 1
    reward_keys = [k for k in eval_env_config.keys() if 'reward' in k]
    for k in reward_keys:
        eval_env_config.pop(k)
    eval_env = create_env(eval_env_config, force_envvec=True)

    agent_config['N_UPDATES'] *= env_config['n_workers'] * env_config['n_envs']
    create_model, Agent = pkg.import_agent(config=agent_config)
    models = create_model(model_config, env)

    n_workers = env_config.get('n_workers', 1)
    n_envs = env_config.get('n_envs', 1)
    replay_config['n_envs'] = n_workers * n_envs
    replay_config['seqlen'] = env.max_episode_steps
    if getattr(models, 'state_keys', ()):
        replay_config['state_keys'] = list(models.state_keys)
    replay = create_replay(replay_config)
    replay.load_data()

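    # derive the data format and build the dataset that samples from the replay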
    am = pkg.import_module('agent', config=agent_config)
    data_format = am.get_data_format(env=env,
                                     replay_config=replay_config,
                                     agent_config=agent_config,
                                     model=models)
    dataset = create_dataset(replay, env, data_format=data_format)

    agent = Agent(config=agent_config, models=models, dataset=dataset, env=env)

    agent.save_config(
        dict(env=env_config,
             model=model_config,
             agent=agent_config,
             replay=replay_config))

    train(agent, env, eval_env, replay)

    if use_ray:
        ray.shutdown()
Code Example #6
def main(env_config,
         model_config,
         agent_config,
         replay_config,
         n,
         record=False,
         size=(128, 128),
         video_len=1000,
         fps=30,
         save=False):
    silence_tf_logs()
    configure_gpu()
    configure_precision(agent_config.get('precision', 32))

    use_ray = env_config.get('n_workers', 0) > 1
    if use_ray:
        import ray
        ray.init()
        sigint_shutdown_ray()

    algo_name = agent_config['algorithm']
    env_name = env_config['name']

    try:
        make_env = pkg.import_module('env', algo_name, place=-1).make_env
    except Exception:
        # fall back to the default environment factory if the algorithm
        # doesn't provide its own make_env
        make_env = None
    env_config.pop('reward_clip', False)
    env = create_env(env_config, env_fn=make_env)
    create_model, Agent = pkg.import_agent(config=agent_config)
    models = create_model(model_config, env)

    agent = Agent(config=agent_config, models=models, dataset=None, env=env)

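    # evaluate for at least one episode per environment in the vectorized env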
    if n < env.n_envs:
        n = env.n_envs
    scores, epslens, video = evaluate(env,
                                      agent,
                                      n,
                                      record=record,
                                      size=size,
                                      video_len=video_len)
    pwc(f'After running {n} episodes',
        f'Score: {np.mean(scores):.3g}\tEpslen: {np.mean(epslens):.3g}',
        color='cyan')

    if record:
        save_video(f'{algo_name}-{env_name}', video, fps=fps)
    if use_ray:
        ray.shutdown()
Code Example #7
File: train.py Project: xlnwel/d2rl
def main(env_config, model_config, agent_config, replay_config):
    silence_tf_logs()
    configure_gpu()
    configure_precision(agent_config['precision'])

    use_ray = env_config.get('n_workers', 0) > 1
    if use_ray:
        import ray
        ray.init()
        sigint_shutdown_ray()

    env = create_env(env_config, make_env, force_envvec=True)
    eval_env_config = env_config.copy()
    eval_env_config['n_envs'] = 1
    eval_env_config['n_workers'] = 1
    eval_env = create_env(eval_env_config, make_env)

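    # point the replay at the experiment's data directory and load the stored data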
    replay_config['dir'] = agent_config['root_dir'].replace('logs', 'data')
    replay = create_replay(replay_config)
    replay.load_data()
    dtype = global_policy().compute_dtype
    data_format = pkg.import_module(
        'agent', config=agent_config).get_data_format(
            env=env,
            batch_size=agent_config['batch_size'],
            sample_size=agent_config['sample_size'],
            dtype=dtype)
    process = functools.partial(process_with_env,
                                env=env,
                                obs_range=[-.5, .5],
                                one_hot_action=True,
                                dtype=dtype)
    dataset = Dataset(replay, data_format, process)

    create_model, Agent = pkg.import_agent(config=agent_config)
    models = create_model(model_config, env)

    agent = Agent(config=agent_config, models=models, dataset=dataset, env=env)

    agent.save_config(
        dict(env=env_config,
             model=model_config,
             agent=agent_config,
             replay=replay_config))

    train(agent, env, eval_env, replay)
Code Example #8
File: train.py Project: xlnwel/d2rl
def main(env_config, model_config, agent_config, buffer_config):
    silence_tf_logs()
    configure_gpu()
    configure_precision(agent_config['precision'])

    create_model, Agent = pkg.import_agent(config=agent_config)
    Buffer = pkg.import_module('buffer', config=agent_config).Buffer

    use_ray = env_config.get('n_workers', 1) > 1
    if use_ray:
        import ray
        from utility.ray_setup import sigint_shutdown_ray
        ray.init()
        sigint_shutdown_ray()

    env = create_env(env_config, force_envvec=True)
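    # evaluation env: single worker and env, shifted seed, reward-shaping keys removed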
    eval_env_config = env_config.copy()
    eval_env_config['seed'] += 1000
    eval_env_config['n_workers'] = 1
    eval_env_config['n_envs'] = 1
    for k in list(eval_env_config.keys()):
        # drop reward-shaping ('reward hack') keys for evaluation
        if 'reward' in k:
            eval_env_config.pop(k)
    eval_env = create_env(eval_env_config, force_envvec=True)

    models = create_model(model_config, env)

    buffer_config['n_envs'] = env.n_envs
    buffer = Buffer(buffer_config)

    agent = Agent(config=agent_config, models=models, dataset=buffer, env=env)

    agent.save_config(
        dict(env=env_config,
             model=model_config,
             agent=agent_config,
             buffer=buffer_config))

    train(agent, env, eval_env, buffer)

    if use_ray:
        ray.shutdown()
Code Example #9
File: train.py Project: xlnwel/d2rl
def train(agent, env, eval_env, replay):
    collect_fn = pkg.import_module('agent', algo=agent.name).collect
    collect = functools.partial(collect_fn, replay)

    _, step = replay.count_episodes()
    step = max(agent.env_step, step)

    runner = Runner(env, agent, step=step)

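    # random actor used to warm up the replay; it stashes the previous action
    # on the function object so it can be returned alongside the new action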
    def random_actor(*args, **kwargs):
        prev_action = random_actor.prev_action
        random_actor.prev_action = action = env.random_action()
        return action, {'prev_action': prev_action}
    random_actor.prev_action = np.zeros_like(env.random_action()) \
        if isinstance(env.random_action(), np.ndarray) else 0
    while not replay.good_to_learn():
        step = runner.run(action_selector=random_actor, step_fn=collect)

    to_log = Every(agent.LOG_PERIOD)
    to_eval = Every(agent.EVAL_PERIOD)
    print('Training starts...')
    while step < int(agent.MAX_STEPS):
        start_step = step
        start_t = time.time()
        agent.learn_log(step)
        step = runner.run(step_fn=collect, nsteps=agent.TRAIN_PERIOD)
        duration = time.time() - start_t
        agent.store(fps=(step - start_step) / duration,
                    tps=agent.N_UPDATES / duration)

        if to_eval(step):
            with TempStore(agent.get_states, agent.reset_states):
                score, epslen, video = evaluate(eval_env,
                                                agent,
                                                record=agent.RECORD,
                                                size=(64, 64))
                if agent.RECORD:
                    video_summary(f'{agent.name}/sim', video, step=step)
                agent.store(eval_score=score, eval_epslen=epslen)

        if to_log(step):
            agent.log(step)
            agent.save()
Code Example #10
def train(agent, env, eval_env, buffer):
    collect_fn = pkg.import_module('agent', algo=agent.name).collect
    collect = functools.partial(collect_fn, buffer)

    step = agent.env_step
    runner = Runner(env, agent, step=step, nsteps=agent.N_STEPS)
    exp_buffer = get_expert_data(f'{buffer.DATA_PATH}-{env.name}')

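    # warm up observation and reward running statistics with random rollouts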
    if step == 0 and agent.is_obs_normalized:
        print('Start to initialize running stats...')
        for _ in range(10):
            runner.run(action_selector=env.random_action, step_fn=collect)
            agent.update_obs_rms(np.concatenate(buffer['obs']))
            agent.update_reward_rms(buffer['reward'], buffer['discount'])
            buffer.reset()
        buffer.clear()
        agent.save(print_terminal_info=True)

    runner.step = step
    # print("Initial running stats:", *[f'{k:.4g}' for k in agent.get_running_stats() if k])
    to_log = Every(agent.LOG_PERIOD, agent.LOG_PERIOD)
    to_eval = Every(agent.EVAL_PERIOD)
    rt = Timer('run')
    tt = Timer('train')
    et = Timer('eval')
    lt = Timer('log')
    print('Training starts...')
    while step < agent.MAX_STEPS:
        start_env_step = agent.env_step
        agent.before_run(env)
        with rt:
            step = runner.run(step_fn=collect)
        agent.store(fps=(step - start_env_step) / rt.last())
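        # update the discriminator (disc_learn_log) on expert data, then relabel
        # the collected rewards with the agent's learned reward function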
        buffer.reshape_to_sample()
        agent.disc_learn_log(exp_buffer)
        buffer.compute_reward_with_func(agent.compute_reward)
        buffer.reshape_to_store()

        # NOTE: normalizing rewards here may introduce some inconsistency
        # if normalized rewards are fed as an input to the network.
        # One can reconcile this by moving normalization to collect
        # or by feeding the network unnormalized rewards.
        # The latter is adopted in our implementation.
        # However, the following line currently doesn't store
        # a copy of the unnormalized rewards.
        agent.update_reward_rms(buffer['reward'], buffer['discount'])
        buffer.update('reward',
                      agent.normalize_reward(buffer['reward']),
                      field='all')
        agent.record_last_env_output(runner.env_output)
        value = agent.compute_value()
        buffer.finish(value)

        start_train_step = agent.train_step
        with tt:
            agent.learn_log(step)
        agent.store(tps=(agent.train_step - start_train_step) / tt.last())
        buffer.reset()

        if to_eval(agent.train_step) or step > agent.MAX_STEPS:
            with TempStore(agent.get_states, agent.reset_states):
                with et:
                    eval_score, eval_epslen, video = evaluate(
                        eval_env,
                        agent,
                        n=agent.N_EVAL_EPISODES,
                        record=agent.RECORD,
                        size=(64, 64))
                if agent.RECORD:
                    video_summary(f'{agent.name}/sim', video, step=step)
                agent.store(eval_score=eval_score, eval_epslen=eval_epslen)

        if to_log(agent.train_step) and agent.contains_stats('score'):
            with lt:
                agent.store(
                    **{
                        'train_step': agent.train_step,
                        'time/run': rt.total(),
                        'time/train': tt.total(),
                        'time/eval': et.total(),
                        'time/log': lt.total(),
                        'time/run_mean': rt.average(),
                        'time/train_mean': tt.average(),
                        'time/eval_mean': et.average(),
                        'time/log_mean': lt.average(),
                    })
                agent.log(step)
                agent.save()
Code Example #11
File: train.py Project: xlnwel/d2rl
def train(agent, env, eval_env, replay):
    collect_fn = pkg.import_module('agent', algo=agent.name).collect
    collect = functools.partial(collect_fn, replay)

    em = pkg.import_module(env.name.split("_")[0], pkg='env')
    info_func = em.info_func if hasattr(em, 'info_func') else None

    env_step = agent.env_step
    runner = Runner(env,
                    agent,
                    step=env_step,
                    run_mode=RunMode.TRAJ,
                    info_func=info_func)
    agent.TRAIN_PERIOD = env.max_episode_steps
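    # fill the replay with complete episodes before learning starts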
    while not replay.good_to_learn():
        env_step = runner.run(step_fn=collect)
        replay.finish_episodes()

    to_eval = Every(agent.EVAL_PERIOD)
    to_log = Every(agent.LOG_PERIOD, agent.LOG_PERIOD)
    to_record = Every(agent.EVAL_PERIOD * 10)
    rt = Timer('run')
    tt = Timer('train')
    # et = Timer('eval')
    lt = Timer('log')
    print('Training starts...')
    while env_step <= int(agent.MAX_STEPS):
        with rt:
            env_step = runner.run(step_fn=collect)
        replay.finish_episodes()
        assert np.all(runner.env_output.reset), \
            (runner.env_output.reset, env.info().get('score', 0), env.info().get('epslen', 0))
        with tt:
            agent.learn_log(env_step)

        # if to_eval(env_step):
        #     with TempStore(agent.get_states, agent.reset_states):
        #         with et:
        #             record = agent.RECORD and to_record(env_step)
        #             eval_score, eval_epslen, video = evaluate(
        #                 eval_env, agent, n=agent.N_EVAL_EPISODES,
        #                 record=agent.RECORD, size=(64, 64))
        #             if record:
        #                 video_summary(f'{agent.name}/sim', video, step=env_step)
        #             agent.store(
        #                 eval_score=eval_score,
        #                 eval_epslen=eval_epslen)

        if to_log(env_step):
            with lt:
                fps = rt.average() * agent.TRAIN_PERIOD
                tps = tt.average() * agent.N_UPDATES

                agent.store(
                    env_step=agent.env_step,
                    train_step=agent.train_step,
                    fps=fps,
                    tps=tps,
                )
                agent.store(
                    **{
                        'train_step': agent.train_step,
                        'time/run': rt.total(),
                        'time/train': tt.total(),
                        # 'time/eval': et.total(),
                        'time/log': lt.total(),
                        'time/run_mean': rt.average(),
                        'time/train_mean': tt.average(),
                        # 'time/eval_mean': et.average(),
                        'time/log_mean': lt.average(),
                    })
                agent.log(env_step)
                agent.save()
Code Example #12
def main(env_config, model_config, agent_config, replay_config):
    gpus = tf.config.list_physical_devices('GPU')
    ray.init(num_cpus=os.cpu_count(), num_gpus=len(gpus))

    sigint_shutdown_ray()

    default_agent_config.update(agent_config)
    agent_config = default_agent_config

    replay = create_replay_center(replay_config)

    model_fn, Agent = pkg.import_agent(config=agent_config)
    am = pkg.import_module('actor', config=agent_config)
    fm = pkg.import_module('func', config=agent_config)

    monitor = fm.create_monitor(config=agent_config)

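    # create remote workers that prefill the central replay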
    Worker = am.get_worker_class(Agent)
    workers = []
    for wid in range(agent_config['n_workers']):
        worker = fm.create_worker(Worker=Worker,
                                  worker_id=wid,
                                  model_fn=model_fn,
                                  config=agent_config,
                                  model_config=model_config,
                                  env_config=env_config,
                                  buffer_config=replay_config)
        worker.prefill_replay.remote(replay)
        workers.append(worker)

    Evaluator = am.get_evaluator_class(Agent)
    evaluator = fm.create_evaluator(Evaluator=Evaluator,
                                    model_fn=model_fn,
                                    config=agent_config,
                                    model_config=model_config,
                                    env_config=env_config)

    Learner = am.get_learner_class(Agent)
    learner = fm.create_learner(Learner=Learner,
                                model_fn=model_fn,
                                replay=replay,
                                config=agent_config,
                                model_config=model_config,
                                env_config=env_config,
                                replay_config=replay_config)

    learner.start_learning.remote()
    for w in workers:
        w.run.remote(learner, replay, monitor)
    evaluator.run.remote(learner, monitor)

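    # keep the driver alive; the monitor records training stats periodically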
    elapsed_time = 0
    interval = 10
    while not ray.get(monitor.is_over.remote()):
        time.sleep(interval)
        elapsed_time += interval
        if elapsed_time % agent_config['LOG_PERIOD'] == 0:
            monitor.record_train_stats.remote(learner)

    ray.get(learner.save.remote())

    ray.shutdown()
Code Example #13
File: train.py Project: xlnwel/d2rl
def main(env_config, model_config, agent_config, replay_config):
    ray.init(num_cpus=os.cpu_count(), num_gpus=1)

    sigint_shutdown_ray()

    default_agent_config.update(agent_config)
    agent_config = default_agent_config

    replay = create_replay_center(replay_config)

    model_fn, Agent = pkg.import_agent(config=agent_config)
    am = pkg.import_module('actor', config=agent_config)
    fm = pkg.import_module('func', config=agent_config)

    # create the monitor
    monitor = fm.create_monitor(config=agent_config)

    # create workers
    Worker = am.get_worker_class()
    workers = []
    for wid in range(agent_config['n_workers']):
        worker = fm.create_worker(Worker=Worker,
                                  worker_id=wid,
                                  config=agent_config,
                                  env_config=env_config,
                                  buffer_config=replay_config)
        worker.set_handler.remote(replay=replay)
        worker.set_handler.remote(monitor=monitor)
        workers.append(worker)

    # create the learner
    Learner = am.get_learner_class(Agent)
    learner = fm.create_learner(Learner=Learner,
                                model_fn=model_fn,
                                replay=replay,
                                config=agent_config,
                                model_config=model_config,
                                env_config=env_config,
                                replay_config=replay_config)
    learner.start_learning.remote()

    # create the evaluator
    Evaluator = am.get_evaluator_class(Agent)
    evaluator = fm.create_evaluator(Evaluator=Evaluator,
                                    model_fn=model_fn,
                                    config=agent_config,
                                    model_config=model_config,
                                    env_config=env_config)
    evaluator.run.remote(learner, monitor)

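    # create actors, each serving a contiguous slice of the workers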
    Actor = am.get_actor_class(Agent)
    actors = []
    na = agent_config['n_actors']
    nw = agent_config['n_workers']
    assert nw % na == 0, f"n_workers({nw}) is not divisible by n_actors({na})"
    wpa = nw // na
    for aid in range(agent_config['n_actors']):
        actor = fm.create_actor(Actor=Actor,
                                actor_id=aid,
                                model_fn=model_fn,
                                config=agent_config,
                                model_config=model_config,
                                env_config=env_config)
        actor.start.remote(workers[aid * wpa:(aid + 1) * wpa], learner,
                           monitor)
        actors.append(actor)

    elapsed_time = 0
    interval = 10
    # put the main thread to sleep;
    # the monitor records training stats periodically
    while not ray.get(monitor.is_over.remote()):
        time.sleep(interval)
        elapsed_time += interval
        if elapsed_time % agent_config['LOG_PERIOD'] == 0:
            monitor.record_train_stats.remote(learner)

    ray.get(learner.save.remote())

    ray.shutdown()
Code Example #14
File: train.py Project: xlnwel/d2rl
def train(agent, env, eval_env, replay):
    collect_fn = pkg.import_module('agent', algo=agent.name).collect
    collect = functools.partial(collect_fn, replay)

    env_step = agent.env_step
    runner = Runner(env, agent, step=env_step, nsteps=agent.TRAIN_PERIOD)
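    # fill the replay until it has enough data to start learning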
    while not replay.good_to_learn():
        env_step = runner.run(
            # NOTE: random actions below make a huge difference for MuJoCo tasks;
            # by default we don't use them, as it's not conventional practice.
            # action_selector=env.random_action,
            step_fn=collect)

    to_eval = Every(agent.EVAL_PERIOD)
    to_log = Every(agent.LOG_PERIOD, agent.LOG_PERIOD)
    to_record = Every(agent.EVAL_PERIOD * 10)
    rt = Timer('run')
    tt = Timer('train')
    et = Timer('eval')
    lt = Timer('log')
    print('Training starts...')
    while env_step <= int(agent.MAX_STEPS):
        with rt:
            env_step = runner.run(step_fn=collect)
        with tt:
            agent.learn_log(env_step)

        if to_eval(env_step):
            with TempStore(agent.get_states, agent.reset_states):
                with et:
                    record = agent.RECORD and to_record(env_step)
                    eval_score, eval_epslen, video = evaluate(
                        eval_env,
                        agent,
                        n=agent.N_EVAL_EPISODES,
                        record=agent.RECORD,
                        size=(64, 64))
                    if record:
                        video_summary(f'{agent.name}/sim',
                                      video,
                                      step=env_step)
                    agent.store(eval_score=eval_score, eval_epslen=eval_epslen)

        if to_log(env_step):
            with lt:
                fps = rt.average() * agent.TRAIN_PERIOD
                tps = tt.average() * agent.N_UPDATES

                agent.store(
                    env_step=agent.env_step,
                    train_step=agent.train_step,
                    fps=fps,
                    tps=tps,
                )
                agent.store(
                    **{
                        'train_step': agent.train_step,
                        'time/run': rt.total(),
                        'time/train': tt.total(),
                        'time/eval': et.total(),
                        'time/log': lt.total(),
                        'time/run_mean': rt.average(),
                        'time/train_mean': tt.average(),
                        'time/eval_mean': et.average(),
                        'time/log_mean': lt.average(),
                    })
                agent.log(env_step)
                agent.save()