def testFakeMTRandomness(self):
    """Check that a naive, process-wide RNG swap is shared between threads.

    Two threads each create their own generator with ``tar.gen_rng()`` and
    install it by overwriting the single ``rng._rng`` attribute. Both then
    read ``tar.get_rng().get_state()`` while the other thread's swap is
    (presumably) also in effect, and the test asserts that element ``[1]``
    of the two captured states is numerically equal.
    """
    # NOTE(review): presumably ``tar.get_rng()`` reads this same ``_rng``
    # attribute -- confirm against tartist.random.
    from tartist.random import rng as shared_rng

    mutex = threading.Lock()
    # Remember what was installed before the test touched anything, so the
    # test can put it back regardless of how the threads interleave.
    original = shared_rng._rng

    @contextlib.contextmanager
    def fake_with_rng(rrr):
        # The lock only guards the individual read/write of ``_rng``; it
        # does not make the swap thread-local.
        with mutex:
            backup = shared_rng._rng
            shared_rng._rng = rrr
        try:
            yield rrr
        finally:
            # Restore even when the managed body raises.
            with mutex:
                shared_rng._rng = backup

    q = queue.Queue()

    def proc():
        rng = tar.gen_rng()
        with fake_with_rng(rng):
            time.sleep(0.5)
            state = tar.get_rng().get_state()
            time.sleep(0.5)
            q.put(state)

    try:
        threads = [Thread(target=proc) for _ in range(2)]
        map_exec(Thread.start, threads)
        map_exec(Thread.join, threads)
    finally:
        # Each thread's ``backup`` may be the other thread's generator, so
        # the per-thread restores can leave a test generator installed.
        with mutex:
            shared_rng._rng = original

    v1, v2 = q.get(), q.get()
    self.assertTrue(np.allclose(v1[1], v2[1]))
def main_inference_play_multithread(trainer):
    """Run inference plays concurrently, one thread per play.

    The number of threads is read from the ``ppo.inference.nr_plays``
    environment setting. Every thread builds and compiles its own function
    on the network's ``theta`` output, evaluates one player with it, and
    reports the score to the ``summary_histories`` runtime entry when that
    entry exists.

    Args:
        trainer: object exposing ``env.make_func()``,
            ``env.network.outputs`` and a ``runtime`` mapping.
    """

    def runner():
        infer = trainer.env.make_func()
        theta = trainer.env.network.outputs['theta']
        infer.compile({'theta': theta})

        episode_score = _evaluate(make_player(), infer)

        histories = trainer.runtime.get('summary_histories', None)
        if histories is None:
            # Nothing to report to.
            return
        histories.put_async_scalar('inference/score', episode_score)

    workers = []
    for _ in range(get_env('ppo.inference.nr_plays')):
        workers.append(threading.Thread(target=runner))

    # Start every worker before joining any of them.
    map_exec(threading.Thread.start, workers)
    map_exec(threading.Thread.join, workers)
def main_inference_play_multithread(trainer):
    """Run inference plays concurrently, one thread per play.

    The number of threads is read from the ``dqn.inference.nr_plays``
    environment setting. Every thread builds and compiles its own function
    on the network's ``q_argmax`` output, plays one episode with it, and
    reports the score to the ``summary_histories`` runtime entry when that
    entry exists.

    Args:
        trainer: object exposing ``env.make_func()``,
            ``env.network.outputs`` and a ``runtime`` mapping.
    """

    def runner():
        infer = trainer.env.make_func()
        infer.compile(trainer.env.network.outputs['q_argmax'])

        def policy(state):
            # Add a leading axis before the call and take the first entry
            # of what comes back.
            return infer(state=state[np.newaxis])[0]

        episode_score = make_player().evaluate_one_episode(policy)

        histories = trainer.runtime.get('summary_histories', None)
        if histories is None:
            # Nothing to report to.
            return
        histories.put_async_scalar('inference/score', episode_score)

    workers = []
    for _ in range(get_env('dqn.inference.nr_plays')):
        workers.append(threading.Thread(target=runner))

    # Start every worker before joining any of them.
    map_exec(threading.Thread.start, workers)
    map_exec(threading.Thread.join, workers)
def testMTRandomness(self):
    """Check that ``tar.with_rng`` gives each thread its own generator state.

    Two threads each create a generator with ``tar.gen_rng()``, activate it
    through ``tar.with_rng`` and capture ``tar.get_rng().get_state()`` while
    the other thread's context is (given the sleeps) also active. The test
    asserts that element ``[1]`` of the two captured states differs.
    """
    results = queue.Queue()

    def worker():
        own_rng = tar.gen_rng()
        with tar.with_rng(own_rng):
            # Sleep on both sides of the read so the two contexts overlap.
            time.sleep(0.5)
            snapshot = tar.get_rng().get_state()
            time.sleep(0.5)
            results.put(snapshot)

    workers = [Thread(target=worker) for _ in range(2)]
    map_exec(Thread.start, workers)
    map_exec(Thread.join, workers)

    first = results.get()
    second = results.get()
    states_match = np.allclose(first[1], second[1])
    self.assertFalse(states_match)
def _compute_e_var(rs, ret_variance): e_r = sum(rs) / len(rs) if ret_variance: rs_sqr = map_exec(lambda x: x**2, rs) var_r = sum(rs_sqr) / len(rs_sqr) - e_r**2 return e_r, var_r return e_r