Example no. 1
    def _test_batch_training(self,
                             gpu,
                             steps=5000,
                             load_model=False,
                             require_success=True):
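        """Train a batch agent with evaluation, then save it.

        Optionally resumes from a previously saved agent and replay
        buffer, and, when require_success is True, asserts that every
        final evaluation episode reaches the successful return.
        """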

        random_seed.set_random_seed(1)
        logging.basicConfig(level=logging.DEBUG)

        env, _ = self.make_vec_env_and_successful_return(test=False)
        test_env, successful_return = self.make_vec_env_and_successful_return(
            test=True)
        agent = self.make_agent(env, gpu)

        if load_model:
            print("Load agent from", self.agent_dirname)
            agent.load(self.agent_dirname)
            agent.replay_buffer.load(self.rbuf_filename)

        # Train with periodic evaluation; training stops early once the
        # evaluation score reaches successful_score.
        train_agent_batch_with_evaluation(
            agent=agent,
            env=env,
            steps=steps,
            outdir=self.tmpdir,
            eval_interval=200,
            eval_n_steps=None,
            eval_n_episodes=5,
            successful_score=1,
            eval_env=test_env,
        )
        env.close()

        # Final evaluation: run a fixed number of test episodes.
        n_test_runs = 5
        eval_returns, _ = batch_run_evaluation_episodes(
            test_env,
            agent,
            n_steps=None,
            n_episodes=n_test_runs,
        )
        test_env.close()
        n_succeeded = np.sum(np.asarray(eval_returns) >= successful_return)
        if require_success:
            assert n_succeeded == n_test_runs

        # Save the agent and replay buffer so a later run can resume from them.
        agent.save(self.agent_dirname)
        agent.replay_buffer.save(self.rbuf_filename)
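This helper is meant to be driven by a concrete test method. A minimal sketch of such a driver, assuming it lives in the same test class and that passing gpu=None selects the CPU (both are assumptions about the harness, not shown in the snippet above): train from scratch once, then resume from the saved agent and replay buffer.

    def test_batch_training(self):
        # Hypothetical driver; the method name and the gpu=None convention
        # are assumptions, not part of the original snippet.
        # First run: train from scratch and require evaluation success.
        self._test_batch_training(gpu=None, steps=5000, load_model=False)
        # Second run: resume from the files saved by the first run; a short
        # continuation need not reach the success threshold again.
        self._test_batch_training(
            gpu=None, steps=300, load_model=True, require_success=False)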
Example no. 2
    def run_func(process_idx):
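        # Executed once in each worker process: seed the process, build
        # per-process training and evaluation environments, attach shared
        # agent state, and run the training loop.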
        random_seed.set_random_seed(random_seeds[process_idx])

        env = make_env(process_idx, test=False)
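        # Reuse the training env for evaluation when there is no evaluator;
        # otherwise build a separate test-mode env.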
        if evaluator is None:
            eval_env = env
        else:
            eval_env = make_env(process_idx, test=True)
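        # With shared memory, each worker builds its own agent object but
        # shares the designated attributes (e.g. model parameters) with the
        # master agent.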
        if make_agent is not None:
            local_agent = make_agent(process_idx)
            if use_shared_memory:
                for attr in agent.shared_attributes:
                    setattr(local_agent, attr, getattr(agent, attr))
        else:
            local_agent = agent
        local_agent.process_idx = process_idx

        def f():
            train_loop(
                process_idx=process_idx,
                counter=counter,
                episodes_counter=episodes_counter,
                agent=local_agent,
                env=env,
                steps=steps,
                outdir=outdir,
                max_episode_len=max_episode_len,
                evaluator=evaluator,
                successful_score=successful_score,
                stop_event=stop_event,
                exception_event=exception_event,
                eval_env=eval_env,
                global_step_hooks=global_step_hooks,
                logger=logger,
            )

        if profile:
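            # Profile this worker's training loop and write per-process stats.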
            import cProfile

            cProfile.runctx("f()", globals(), locals(),
                            "profile-{}.out".format(os.getpid()))
        else:
            f()

        env.close()
        if eval_env is not env:
            eval_env.close()
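run_func is written to be executed once per worker process, so the surrounding code has to fan it out across workers and wait for all of them. A minimal dispatch sketch using only the standard library (the library ships its own async helpers; n_process and the fork start method are assumptions here, not the source's implementation):

    import multiprocessing as mp

    # Fan run_func out over n_process workers and wait for all of them.
    # A fork start method is assumed so the run_func closure (which captures
    # agent, counters, events, etc.) is inherited by the children; closures
    # cannot be pickled for the spawn start method.
    ctx = mp.get_context("fork")
    workers = [ctx.Process(target=run_func, args=(idx,))
               for idx in range(n_process)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()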