Пример #1
0
    def test_run_async(self):
        """Four workers each add 1000 to a shared counter; total must be 4000."""
        n_workers = 4
        n_increments = 1000
        shared_total = mp.Value('l', 0)

        def bump(process_idx):
            # Take the value's own lock for each increment so concurrent
            # read-modify-write updates from different workers are not lost.
            lock = shared_total.get_lock()
            for _ in range(n_increments):
                with lock:
                    shared_total.value += 1

        async_.run_async(n_workers, bump)
        self.assertEqual(shared_total.value, n_workers * n_increments)
Пример #2
0
    def test_run_async_exit_code(self):
        """run_async emits one AbnormalExitWarning per abnormally exiting child.

        Children that call ``sys.exit(0)`` must not trigger any
        ``async_.AbnormalExitWarning``; children that kill themselves with
        SIGSEGV must trigger exactly one such warning each (4 in total).
        """
        def run_with_exit_code_0(process_idx):
            sys.exit(0)

        def run_with_exit_code_11(process_idx):
            # Terminate the child abnormally by sending itself SIGSEGV.
            os.kill(os.getpid(), signal.SIGSEGV)

        def count_abnormal_exit_warnings(records):
            # Only count AbnormalExitWarning so that unrelated warnings
            # (e.g. deprecations) raised during run_async do not affect
            # the assertions.
            return sum(issubclass(r.category, async_.AbnormalExitWarning)
                       for r in records)

        with warnings.catch_warnings(record=True) as ws:
            # Record every warning regardless of the global filter setup
            # (e.g. -W ignore / -W error or test-runner filter config).
            warnings.simplefilter('always')
            async_.run_async(4, run_with_exit_code_0)
        # There should be no AbnormalExitWarning
        self.assertEqual(count_abnormal_exit_warnings(ws), 0)

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter('always')
            async_.run_async(4, run_with_exit_code_11)
        # There should be 4 AbnormalExitWarning, one per child process
        self.assertEqual(count_abnormal_exit_warnings(ws), 4)
Пример #3
0
    def test_run_async_exit_code(self):
        """run_async warns exactly once per child that exits abnormally.

        Children exiting via ``sys.exit(0)`` produce no
        ``async_.AbnormalExitWarning``; children killed by SIGSEGV produce
        one each.
        """
        def run_with_exit_code_0(process_idx):
            sys.exit(0)

        def run_with_exit_code_11(process_idx):
            # Terminate the child abnormally by sending itself SIGSEGV.
            os.kill(os.getpid(), signal.SIGSEGV)

        def n_abnormal(records):
            # Booleans sum as 0/1, so this counts matching records.
            return sum(issubclass(r.category, async_.AbnormalExitWarning)
                       for r in records)

        with warnings.catch_warnings(record=True) as ws:
            # catch_warnings(record=True) does not reset filters; force
            # 'always' so ignore/error/once filters cannot hide warnings.
            warnings.simplefilter('always')
            async_.run_async(4, run_with_exit_code_0)
            # There should be no AbnormalExitWarning
            self.assertEqual(n_abnormal(ws), 0)

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter('always')
            async_.run_async(4, run_with_exit_code_11)
            # There should be 4 AbnormalExitWarning
            self.assertEqual(n_abnormal(ws), 4)
Пример #4
0
def train_agent_async(outdir, processes, make_env,
                      profile=False,
                      steps=8 * 10 ** 7,
                      eval_interval=10 ** 6,
                      eval_n_runs=10,
                      max_episode_len=None,
                      step_offset=0,
                      successful_score=None,
                      agent=None,
                      make_agent=None,
                      global_step_hooks=None,
                      save_best_so_far_agent=True,
                      logger=None,
                      ):
    """Train agent asynchronously using multiprocessing.

    Either `agent` or `make_agent` must be specified.

    Args:
        outdir (str): Path to the directory to output things.
        processes (int): Number of processes.
        make_env (callable): (process_idx, test) -> Environment.
        profile (bool): Profile if set True. Each worker writes its profile
            to ``profile-<pid>.out`` in the current directory.
        steps (int): Number of global time steps for training.
        eval_interval (int): Interval of evaluation. If set to None, the agent
            will not be evaluated at all.
        eval_n_runs (int): Number of runs for each time of evaluation.
        max_episode_len (int): Maximum episode length.
        step_offset (int): Time step from which training starts.
        successful_score (float): Finish training if the mean score is greater
            or equal to this value if not None
        agent (Agent): Agent to train.
        make_agent (callable): (process_idx) -> Agent
        global_step_hooks (list or None): List of callable objects that
            accept (env, agent, step) as arguments. They are called every
            global step. See chainerrl.experiments.hooks. None (the default)
            means no hooks.
        save_best_so_far_agent (bool): If set to True, after each evaluation,
            if the score (= mean return of evaluation episodes) exceeds
            the best-so-far score, the current agent is saved.
        logger (logging.Logger): Logger used in this function.

    Returns:
        Trained agent. If `agent` was None, this is the agent built by
        ``make_agent(0)`` in the calling process, whose shared objects are
        attached to every worker's agent.

    Raises:
        ValueError: If neither `agent` nor `make_agent` is specified.
    """

    logger = logger or logging.getLogger(__name__)

    # Use None as the default instead of a shared mutable [] default.
    if global_step_hooks is None:
        global_step_hooks = []

    # Prevent numpy from using multiple threads. Set before the workers are
    # started so they inherit it; it stays set in this process afterwards.
    os.environ['OMP_NUM_THREADS'] = '1'

    # Counters and the stop flag shared with every worker process.
    counter = mp.Value('l', 0)
    episodes_counter = mp.Value('l', 0)
    training_done = mp.Value('b', False)  # bool

    if agent is None:
        # Explicit check instead of assert, which is stripped under -O.
        if make_agent is None:
            raise ValueError('Either agent or make_agent must be specified.')
        agent = make_agent(0)

    # Objects extracted from the template agent; per-process agents created
    # by make_agent are attached to these same objects in run_func.
    shared_objects = extract_shared_objects_from_agent(agent)
    set_shared_objects(agent, shared_objects)

    if eval_interval is None:
        evaluator = None
    else:
        evaluator = AsyncEvaluator(
            n_runs=eval_n_runs,
            eval_interval=eval_interval, outdir=outdir,
            max_episode_len=max_episode_len,
            step_offset=step_offset,
            save_best_so_far_agent=save_best_so_far_agent,
            logger=logger,
        )

    def run_func(process_idx):
        """Worker body: build env(s) and agent, run train_loop, clean up."""
        random_seed.set_random_seed(process_idx)

        env = make_env(process_idx, test=False)
        # Without an evaluator no separate test env is needed.
        eval_env = env
        try:
            if evaluator is not None:
                eval_env = make_env(process_idx, test=True)
            if make_agent is not None:
                local_agent = make_agent(process_idx)
                set_shared_objects(local_agent, shared_objects)
            else:
                local_agent = agent
            local_agent.process_idx = process_idx

            def f():
                train_loop(
                    process_idx=process_idx,
                    counter=counter,
                    episodes_counter=episodes_counter,
                    agent=local_agent,
                    env=env,
                    steps=steps,
                    outdir=outdir,
                    max_episode_len=max_episode_len,
                    evaluator=evaluator,
                    successful_score=successful_score,
                    training_done=training_done,
                    eval_env=eval_env,
                    global_step_hooks=global_step_hooks,
                    logger=logger)

            if profile:
                import cProfile
                # runctx resolves 'f' from the locals() of this scope.
                cProfile.runctx('f()', globals(), locals(),
                                'profile-{}.out'.format(os.getpid()))
            else:
                f()
        finally:
            # Close environments even if setup or training raised; the
            # nested finally ensures eval_env is closed if env.close() fails.
            try:
                env.close()
            finally:
                if eval_env is not env:
                    eval_env.close()

    async_.run_async(processes, run_func)

    return agent