def _test_batch_equivalence_case(self, stateful, state_tuple):
    """
    Test that doing things in batches is consistent,
    given the model parameters.
    """
    env_fns = [lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8')
               for x in range(15)]
    model = SimpleModel((5, 3), stateful=stateful, state_tuple=state_tuple)
    unbatched_rollouts = []
    for env_fn in env_fns:
        batched_env = batched_gym_env([env_fn], sync=True)
        trunc_roller = TruncatedRoller(batched_env, model, 17)
        for _ in range(3):
            unbatched_rollouts.extend(trunc_roller.rollouts())
    batched_rollouts = []
    batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, 17)
    for _ in range(3):
        batched_rollouts.extend(trunc_roller.rollouts())
    _compare_rollout_batch(self, unbatched_rollouts, batched_rollouts,
                           ordered=False)
def test_dummy_equiv_dtype(dtype):
    """
    Test that batched_gym_env() gives something equivalent
    to a synchronous environment.
    """
    def make_fn(seed):
        """
        Get an environment constructor with a seed.
        """
        return lambda: SimpleEnv(seed, SHAPE, dtype)
    fns = [make_fn(i) for i in range(SUB_BATCH_SIZE * NUM_SUB_BATCHES)]
    real = batched_gym_env(fns, num_sub_batches=NUM_SUB_BATCHES)
    dummy = batched_gym_env(fns, num_sub_batches=NUM_SUB_BATCHES, sync=True)
    try:
        _assert_resets_equal(dummy, real)
        np.random.seed(1337)
        for _ in range(NUM_STEPS):
            joint_shape = (SUB_BATCH_SIZE,) + SHAPE
            actions = np.array(np.random.randint(0, 0x100, size=joint_shape),
                               dtype=dtype)
            _assert_steps_equal(actions, dummy, real)
    finally:
        dummy.close()
        real.close()
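# The test above relies on module-level constants defined outside this
# excerpt. A minimal sketch of plausible definitions follows; the names
# come from the test itself, but these particular values are assumptions,
# not the source module's actual settings.
SHAPE = (3, 8)          # observation shape passed to SimpleEnv
SUB_BATCH_SIZE = 4      # environments per sub-batch
NUM_SUB_BATCHES = 2     # sub-batches per batched environment
NUM_STEPS = 20          # steps to compare between the dummy and real envs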
def test_batched_stack(concat):
    """
    Test that BatchedFrameStack is equivalent to a regular
    batched FrameStackEnv.
    """
    envs = [lambda idx=i: SimpleEnv(idx + 2, (3, 2, 5), 'float32')
            for i in range(6)]
    env1 = BatchedFrameStack(batched_gym_env(envs, num_sub_batches=3, sync=True),
                             concat=concat)
    env2 = batched_gym_env([lambda env=e: FrameStackEnv(env(), concat=concat)
                            for e in envs],
                           num_sub_batches=3, sync=True)
    for j in range(50):
        for i in range(3):
            if j == 0 or (j + i) % 17 == 0:
                env1.reset_start(sub_batch=i)
                env2.reset_start(sub_batch=i)
                obs1 = env1.reset_wait(sub_batch=i)
                obs2 = env2.reset_wait(sub_batch=i)
                assert np.allclose(obs1, obs2)
            actions = [env1.action_space.sample() for _ in range(2)]
            env1.step_start(actions, sub_batch=i)
            env2.step_start(actions, sub_batch=i)
            obs1, rews1, dones1, _ = env1.step_wait(sub_batch=i)
            obs2, rews2, dones2, _ = env2.step_wait(sub_batch=i)
            assert np.allclose(obs1, obs2)
            assert np.array(rews1 == rews2).all()
            assert np.array(dones1 == dones2).all()
def test_async_creation_exit():
    """
    Test that an exception is forwarded when the environment
    constructor exits.
    """
    try:
        batched_gym_env([lambda: sys.exit(1)] * 4)
    except RuntimeError:
        return
    pytest.fail('should have gotten exception')
def test_async_creation_exception():
    """
    Test that an exception is forwarded when the environment
    constructor fails.
    """
    try:
        def raiser():
            raise ValueError('hello world')
        batched_gym_env([raiser] * 4)
    except RuntimeError:
        return
    pytest.fail('should have gotten exception')
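# Both tests above exercise the same contract: when an environment
# constructor raises or exits inside a worker process, batched_gym_env
# surfaces the failure in the parent as a RuntimeError. A minimal sketch
# of how calling code might guard creation; make_env is a hypothetical
# zero-argument constructor, not a name from the source.
def make_batched_env_or_die(make_env, num_envs=4):
    try:
        return batched_gym_env([make_env] * num_envs)
    except RuntimeError as exc:
        # Construction failed in a subprocess; the original error is
        # forwarded here instead of hanging the parent process.
        print('environment creation failed:', exc)
        raise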
def test_truncation(stateful, state_tuple):
    """
    Test sequence truncation for TruncatedRoller with a batch
    of one environment.
    """
    def env_fn():
        return SimpleEnv(7, (5, 3), 'uint8')
    env = env_fn()
    model = SimpleModel(env.action_space.low.shape,
                        stateful=stateful, state_tuple=state_tuple)
    basic_roller = BasicRoller(env, model, min_episodes=5)
    expected = basic_roller.rollouts()
    total_timesteps = sum([x.num_steps for x in expected])
    batched_env = batched_gym_env([env_fn], sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, total_timesteps // 2 + 1)
    actual1 = trunc_roller.rollouts()
    assert actual1[-1].trunc_end
    actual2 = trunc_roller.rollouts()
    expected1, expected2 = _artificial_truncation(expected,
                                                  len(actual1) - 1,
                                                  actual1[-1].num_steps)
    assert len(actual2) == len(expected2) + 1
    actual2 = actual2[:-1]
    _compare_rollout_batch(actual1, expected1)
    _compare_rollout_batch(actual2, expected2)
def _test_truncation_case(self, stateful, state_tuple):
    """
    Test rollout truncation and continuation for a specific
    set of model parameters.
    """
    env_fn = lambda: SimpleEnv(7, (5, 3), 'uint8')
    env = env_fn()
    model = SimpleModel(env.action_space.low.shape,
                        stateful=stateful, state_tuple=state_tuple)
    basic_roller = BasicRoller(env, model, min_episodes=5)
    expected = basic_roller.rollouts()
    total_timesteps = sum([x.num_steps for x in expected])
    batched_env = batched_gym_env([env_fn], sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, total_timesteps // 2 + 1)
    actual1 = trunc_roller.rollouts()
    self.assertTrue(actual1[-1].trunc_end)
    actual2 = trunc_roller.rollouts()
    expected1, expected2 = _artificial_truncation(expected,
                                                  len(actual1) - 1,
                                                  actual1[-1].num_steps)
    self.assertEqual(len(actual2), len(expected2) + 1)
    actual2 = actual2[:-1]
    _compare_rollout_batch(self, actual1, expected1)
    _compare_rollout_batch(self, actual2, expected2)
def test_output_consistency():
    """
    Test that outputs from stepping are consistent with the
    batched model outputs.
    """
    with tf.Graph().as_default():
        with tf.Session() as sess:
            env = batched_gym_env([DummyEnv] * 16, sync=True)
            model = ReActFF(sess, *gym_spaces(env),
                            input_scale=1.0,
                            input_dtype=tf.float32,
                            base=lambda x: tf.layers.dense(x, 12),
                            actor=MatMul,
                            critic=lambda x: MatMul(x, 1))
            sess.run(tf.global_variables_initializer())
            roller = TruncatedRoller(env, model, 8)
            for _ in range(10):
                rollouts = roller.rollouts()
                info = next(model.batches(rollouts))
                actor_out, critic_out = sess.run(model.batch_outputs(),
                                                 feed_dict=info['feed_dict'])
                idxs = enumerate(zip(info['rollout_idxs'],
                                     info['timestep_idxs']))
                for i, (rollout_idx, timestep_idx) in idxs:
                    outs = rollouts[rollout_idx].model_outs[timestep_idx]
                    assert np.allclose(actor_out[i], outs['action_params'][0])
                    assert np.allclose(critic_out[i], outs['values'][0])
def _test_batch_equivalence_case(self, stateful, state_tuple, **roller_kwargs):
    """
    Test BasicRoller equivalence when using a batch of environments.
    """
    env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
    model = SimpleModel((4, 5), stateful=stateful, state_tuple=state_tuple)
    batched_env = batched_gym_env([env_fn] * 21, num_sub_batches=7, sync=True)
    ep_roller = EpisodeRoller(batched_env, model, **roller_kwargs)
    actual = ep_roller.rollouts()
    total_steps = sum([r.num_steps for r in actual])
    self.assertTrue(len(actual) >= ep_roller.min_episodes)
    self.assertTrue(total_steps >= ep_roller.min_steps)
    if 'min_steps' not in roller_kwargs:
        num_eps = ep_roller.min_episodes + batched_env.num_envs - 1
        self.assertTrue(len(actual) == num_eps)
    basic_roller = BasicRoller(env_fn(), model, min_episodes=len(actual))
    expected = basic_roller.rollouts()
    _compare_rollout_batch(self, actual, expected)
def test_ep_batches(stateful, state_tuple, limits):
    """
    Test that EpisodeRoller is equivalent to a BasicRoller
    when run on a batch of envs.
    """
    def env_fn():
        return SimpleEnv(3, (4, 5), 'uint8')
    model = SimpleModel((4, 5), stateful=stateful, state_tuple=state_tuple)
    batched_env = batched_gym_env([env_fn] * 21, num_sub_batches=7, sync=True)
    ep_roller = EpisodeRoller(batched_env, model, **limits)
    actual = ep_roller.rollouts()
    total_steps = sum([r.num_steps for r in actual])
    assert len(actual) >= ep_roller.min_episodes
    assert total_steps >= ep_roller.min_steps
    if 'min_steps' not in limits:
        num_eps = ep_roller.min_episodes + batched_env.num_envs - 1
        assert len(actual) == num_eps
    basic_roller = BasicRoller(env_fn(), model, min_episodes=len(actual))
    expected = basic_roller.rollouts()
    _compare_rollout_batch(actual, expected)
def test_mixed_batch():
    """
    Test a batch with a bunch of different environments.
    """
    env_fns = [lambda s=seed: SimpleEnv(s, (1, 2, 3), 'float32')
               for seed in [3, 3, 3, 3, 3, 3]]  # [5, 8, 1, 9, 3, 2]
    make_agent = lambda: SimpleModel((1, 2, 3), stateful=True)
    for num_sub in [1, 2, 3]:
        batched_player = BatchedPlayer(batched_gym_env(env_fns,
                                                       num_sub_batches=num_sub),
                                       make_agent(), 3)
        expected_eps = []
        for player in [BasicPlayer(env_fn(), make_agent(), 3)
                       for env_fn in env_fns]:
            transes = [t for _ in range(50) for t in player.play()]
            expected_eps.extend(_separate_episodes(transes))
        actual_transes = [t for _ in range(50) for t in batched_player.play()]
        actual_eps = _separate_episodes(actual_transes)
        assert len(expected_eps) == len(actual_eps)
        for episode in expected_eps:
            found = False
            for i, actual in enumerate(actual_eps):
                if _episodes_equivalent(episode, actual):
                    del actual_eps[i]
                    found = True
                    break
            assert found
def main():
    """
    Entry-point for the program.
    """
    args = _parse_args()
    env = batched_gym_env([partial(make_single_env, args.game)] * args.workers)
    # Using BatchedFrameStack with concat=False is more
    # memory efficient than other stacking options.
    env = BatchedFrameStack(env, num_images=4, concat=False)
    with tf.Session() as sess:
        def make_net(name):
            return NatureQNetwork(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  name,
                                  dueling=True)
        dqn = DQN(make_net('online'), make_net('target'))
        player = BatchedPlayer(env, EpsGreedyQNetwork(dqn.online_net,
                                                      args.epsilon))
        optimize = dqn.optimize(learning_rate=args.lr)
        sess.run(tf.global_variables_initializer())
        reward_hist = []
        total_steps = 0

        def _handle_ep(steps, rew):
            nonlocal total_steps
            total_steps += steps
            reward_hist.append(rew)
            if len(reward_hist) == REWARD_HISTORY:
                print('%d steps: mean=%f' %
                      (total_steps, sum(reward_hist) / len(reward_hist)))
                reward_hist.clear()

        dqn.train(num_steps=int(1e7),
                  player=player,
                  replay_buffer=UniformReplayBuffer(args.buffer_size),
                  optimize_op=optimize,
                  target_interval=args.target_interval,
                  batch_size=args.batch_size,
                  min_buffer_size=args.min_buffer_size,
                  handle_ep=_handle_ep)
    env.close()
def test_single_batch():
    """
    Test BatchedPlayer when the batch size is 1.
    """
    make_env = lambda: SimpleEnv(9, (1, 2, 3), 'float32')
    make_agent = lambda: SimpleModel((1, 2, 3), stateful=True)
    basic_player = BasicPlayer(make_env(), make_agent(), 3)
    batched_player = BatchedPlayer(batched_gym_env([make_env]), make_agent(), 3)
    for _ in range(50):
        transes1 = basic_player.play()
        transes2 = batched_player.play()
        assert len(transes1) == len(transes2)
        for trans1, trans2 in zip(transes1, transes2):
            assert _transitions_equal(trans1, trans2)
def test_trunc_batches(stateful, state_tuple):
    """
    Test that TruncatedRoller produces the same result for
    batches as it does for individual environments.
    """
    env_fns = [lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8')
               for x in range(15)]
    model = SimpleModel((5, 3), stateful=stateful, state_tuple=state_tuple)
    unbatched_rollouts = []
    for env_fn in env_fns:
        batched_env = batched_gym_env([env_fn], sync=True)
        trunc_roller = TruncatedRoller(batched_env, model, 17)
        for _ in range(3):
            unbatched_rollouts.extend(trunc_roller.rollouts())
    batched_rollouts = []
    batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, 17)
    for _ in range(3):
        batched_rollouts.extend(trunc_roller.rollouts())
    _compare_rollout_batch(unbatched_rollouts, batched_rollouts, ordered=False)
def make_env(args):
    maze_data = ("A.......\n" +
                 "wwwwwww.\n" +
                 "wwx.www.\n" +
                 "www.www.\n" +
                 "www.www.\n" +
                 "www.www.\n" +
                 "www.www.\n" +
                 "........")
    maze = parse_2d_maze(maze_data)

    def _make_env():
        return TimeLimit(HorizonEnv(maze, sparse_rew=True, horizon=2),
                         max_episode_steps=args.max_timesteps)

    return batched_gym_env([_make_env] * args.num_envs, sync=True)
def test_batched_env_rollouts(benchmark):
    """
    Benchmark rollouts with a batched environment and a regular
    truncated roller.
    """
    env = batched_gym_env([lambda: gym.make('Pong-v0')] * 8)
    try:
        agent = ActorCritic(gym_space_distribution(env.action_space),
                            gym_space_distribution(env.observation_space))
        agent.actor = _testing_ensemble()
        agent.critic = _testing_ensemble(num_outs=1)
        roller = TruncatedRoller(env, agent, 128)
        with agent.frozen():
            benchmark(roller.rollouts)
    finally:
        env.close()
def main():
    """Run DQN until the environment throws an exception."""
    env_fns, env_names = create_envs()
    env = BatchedFrameStack(batched_gym_env(env_fns), num_images=4, concat=False)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    with tf.Session(config=config) as sess:
        dqn = DQN(*rainbow_models(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  min_val=-200,
                                  max_val=200))
        player = NStepPlayer(BatchedPlayer(env, dqn.online_net), 3)
        optimize = dqn.optimize(learning_rate=1e-4)  # Use ADAM
        sess.run(tf.global_variables_initializer())
        reward_hist = []
        total_steps = 0

        def _handle_ep(steps, rew, env_rewards):
            nonlocal total_steps
            total_steps += steps
            reward_hist.append(rew)
            # total_steps % 1 == 0 is always true, so this logs after
            # every episode.
            if total_steps % 1 == 0:
                print('%d episodes, %d steps: mean of last 100 episodes=%f' %
                      (len(reward_hist), total_steps,
                       sum(reward_hist[-100:]) / len(reward_hist[-100:])))

        dqn.train(num_steps=2000000000,  # Make sure an exception arrives before we stop.
                  player=player,
                  replay_buffer=PrioritizedReplayBuffer(500000, 0.5, 0.4,
                                                        epsilon=0.1),
                  optimize_op=optimize,
                  train_interval=1,
                  target_interval=8192,
                  batch_size=32,
                  min_buffer_size=20000,
                  handle_ep=_handle_ep,
                  num_envs=len(env_fns),
                  save_interval=10)
def test_ep_basic_equivalence(stateful, state_tuple, limits):
    """
    Test that EpisodeRoller is equivalent to a BasicRoller
    when run on a single environment.
    """
    env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
    env = env_fn()
    model = SimpleModel(env.action_space.low.shape,
                        stateful=stateful, state_tuple=state_tuple)
    basic_roller = BasicRoller(env, model, **limits)
    expected = basic_roller.rollouts()
    batched_env = batched_gym_env([env_fn], sync=True)
    ep_roller = EpisodeRoller(batched_env, model, **limits)
    actual = ep_roller.rollouts()
    _compare_rollout_batch(actual, expected)
def _test_basic_equivalence_case(self, stateful, state_tuple):
    """
    Test BasicRoller equivalence for a specific set of model settings.
    """
    env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
    env = env_fn()
    model = SimpleModel(env.action_space.low.shape,
                        stateful=stateful, state_tuple=state_tuple)
    basic_roller = BasicRoller(env, model, min_episodes=5)
    expected = basic_roller.rollouts()
    total_timesteps = sum([x.num_steps for x in expected])
    batched_env = batched_gym_env([env_fn], sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, total_timesteps)
    actual = trunc_roller.rollouts()
    _compare_rollout_batch(self, actual, expected)
def _test_basic_equivalence_case(self, stateful, state_tuple, **roller_kwargs):
    """
    Test BasicRoller equivalence for a single env in a specific case.
    """
    env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
    env = env_fn()
    model = SimpleModel(env.action_space.low.shape,
                        stateful=stateful, state_tuple=state_tuple)
    basic_roller = BasicRoller(env, model, **roller_kwargs)
    expected = basic_roller.rollouts()
    batched_env = batched_gym_env([env_fn], sync=True)
    ep_roller = EpisodeRoller(batched_env, model, **roller_kwargs)
    actual = ep_roller.rollouts()
    _compare_rollout_batch(self, actual, expected)
def test_trunc_basic_equivalence(stateful, state_tuple):
    """
    Test that TruncatedRoller is equivalent to BasicRoller for
    batches of one environment when the episodes end cleanly.
    """
    env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
    env = env_fn()
    model = SimpleModel(env.action_space.low.shape,
                        stateful=stateful, state_tuple=state_tuple)
    basic_roller = BasicRoller(env, model, min_episodes=5)
    expected = basic_roller.rollouts()
    total_timesteps = sum([x.num_steps for x in expected])
    batched_env = batched_gym_env([env_fn], sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, total_timesteps)
    actual = trunc_roller.rollouts()
    _compare_rollout_batch(actual, expected)
def test_multiple_batches(self):
    """
    Make sure calling rollouts multiple times works.
    """
    env_fn = lambda: SimpleEnv(3, (4, 5), 'uint8')
    env = env_fn()
    try:
        model = SimpleModel(env.action_space.low.shape)
    finally:
        env.close()
    batched_env = batched_gym_env([env_fn], sync=True)
    try:
        ep_roller = EpisodeRoller(batched_env, model,
                                  min_episodes=5, min_steps=7)
        first = ep_roller.rollouts()
        for _ in range(3):
            _compare_rollout_batch(self, first, ep_roller.rollouts())
    finally:
        batched_env.close()
def learn_pong():
    """Train an agent."""
    env = batched_gym_env([make_single_env] * NUM_WORKERS)
    try:
        agent = ActorCritic(gym_space_distribution(env.action_space),
                            gym_space_vectorizer(env.observation_space))
        with tf.Session() as sess:
            a2c = A2C(sess, agent, target_kl=TARGET_KL)
            roller = TruncatedRoller(env, agent, HORIZON)
            total_steps = 0
            rewards = []
            print("Training... Don't expect progress for ~400K steps.")
            while True:
                with agent.frozen():
                    rollouts = roller.rollouts()
                for rollout in rollouts:
                    total_steps += rollout.num_steps
                    if not rollout.trunc_end:
                        rewards.append(rollout.total_reward)
                agent.actor.extend(a2c.policy_update(rollouts,
                                                     POLICY_STEP,
                                                     NUM_STEPS,
                                                     min_leaf=MIN_LEAF,
                                                     feature_frac=FEATURE_FRAC))
                agent.critic.extend(a2c.value_update(rollouts,
                                                     VALUE_STEP,
                                                     NUM_STEPS,
                                                     min_leaf=MIN_LEAF,
                                                     feature_frac=FEATURE_FRAC))
                if rewards:
                    print('%d steps: mean=%f' %
                          (total_steps, sum(rewards[-10:]) / len(rewards[-10:])))
                else:
                    print('%d steps: no episodes complete yet' % total_steps)
    finally:
        env.close()
def test_trunc_drop_states():
    """
    Test TruncatedRoller with drop_states=True.
    """
    env_fns = [lambda seed=x: SimpleEnv(seed, (5, 3), 'uint8')
               for x in range(15)]
    model = SimpleModel((5, 3), stateful=True, state_tuple=True)
    expected_rollouts = []
    batched_env = batched_gym_env(env_fns, num_sub_batches=3, sync=True)
    trunc_roller = TruncatedRoller(batched_env, model, 17)
    for _ in range(3):
        expected_rollouts.extend(trunc_roller.rollouts())
    for rollout in expected_rollouts:
        for model_out in rollout.model_outs:
            model_out['states'] = None
    actual_rollouts = []
    trunc_roller = TruncatedRoller(batched_env, model, 17, drop_states=True)
    for _ in range(3):
        actual_rollouts.extend(trunc_roller.rollouts())
    _compare_rollout_batch(actual_rollouts, expected_rollouts)
# Get our envs before we import TensorFlow, in case they need
# their own TF instance.
from multienv import getEnvFns
from anyrl.envs import batched_gym_env

env_fns = getEnvFns(bk2dir='data/record/')
env = batched_gym_env(env_fns)

import sys
from rainbow import train

train(env, output_dir='/tmp/', num_steps=100000)
def getBatchedEnv(bk2dir=None):
    env_fns = getEnvFns(bk2dir=bk2dir)
    return batched_gym_env(env_fns)
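# A minimal usage sketch for getBatchedEnv, based only on the batched-env
# API exercised elsewhere in this file (reset_start/reset_wait/close);
# the bk2dir value is illustrative, not a path from the source.
env = getBatchedEnv(bk2dir='data/record/')
try:
    env.reset_start()
    first_obs = env.reset_wait()
finally:
    env.close()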
def __init__(self, FLAGS):
    self.FLAGS = FLAGS
    self.is_training_ph = tf.placeholder(tf.bool, [])

    with rl_util.Timer('building_envs', self.FLAGS):
        fns = [make_fn(i, FLAGS) for i in range(FLAGS['num_envs'])]
        if self.FLAGS['num_envs'] == 1:
            self.env = fns[0]()
        else:
            self.env = batched_gym_env(fns)

    self.action_dist = rl_util.convert_to_dict_dist(self.env.action_space.spaces)
    self.obs_vectorizer = rl_util.convert_to_dict_dist(self.env.observation_space.spaces)

    # TF INIT STUFF
    # TODO: make these passable as cmd line params
    self.global_itr = tf.get_variable('global_itr',
                                      initializer=tf.constant(1, dtype=tf.int32),
                                      trainable=False)
    self.inc_global = tf.assign_add(self.global_itr,
                                    tf.constant(1, dtype=tf.int32))

    # sarsd phs
    in_batch_shape = (None,) + self.obs_vectorizer.out_shape
    self.sarsd_phs = {}
    self.sarsd_phs['s'] = tf.placeholder(tf.float32, shape=in_batch_shape,
                                         name='s_ph')
    self.sarsd_phs['a'] = tf.placeholder(tf.float32,
                                         (None,) + self.action_dist.out_shape,
                                         name='a_ph')  # actions that were taken
    self.sarsd_phs['s_next'] = tf.placeholder(tf.float32, shape=in_batch_shape,
                                              name='s_next_ph')
    self.sarsd_phs['r'] = tf.placeholder(tf.float32, shape=(None,), name='r_ph')
    self.sarsd_phs['d'] = tf.placeholder(tf.float32, shape=(None,), name='d_ph')

    self.embed_phs = {
        key: tf.placeholder(tf.float32,
                            shape=(None, self.FLAGS['embed_shape']),
                            name='{}_ph'.format(key))
        for key in self.FLAGS['embeds']
    }
    for key in ['a', 'r', 'd']:
        self.embed_phs[key] = self.sarsd_phs[key]

    # Pre-compute these transforms so we don't have to do it all the time.
    # sarsd vals
    self.sarsd_vals = rl_util.sarsd_to_vals(self.sarsd_phs,
                                            self.obs_vectorizer, self.FLAGS)

    # ALGO SETUP
    Encoder = VAE if 'vae' in self.FLAGS['cnn_gn'] else DYN
    EName = 'VAE' if 'vae' in self.FLAGS['cnn_gn'] else 'DYN'
    if self.FLAGS['goal_dyn'] != '':
        self.goal_model = DYN(self.sarsd_vals, self.sarsd_phs,
                              self.action_dist, self.obs_vectorizer, FLAGS,
                              conv='cnn' in self.FLAGS['goal_dyn'],
                              name='GoalDYN',
                              compute_grad=False).model
    else:
        self.goal_model = None
    if self.FLAGS['aac']:
        self.value_encoder = Encoder(self.sarsd_vals, self.sarsd_phs,
                                     self.action_dist, self.obs_vectorizer,
                                     FLAGS, conv=False, name='Value' + EName)
        if self.FLAGS['value_goal']:
            self.goal_model = self.value_encoder.model
    else:
        self.value_encoder = None
    self.encoder = Encoder(self.sarsd_vals, self.sarsd_phs,
                           self.action_dist, self.obs_vectorizer, FLAGS,
                           conv='cnn' in self.FLAGS['cnn_gn'],
                           goal_model=self.goal_model,
                           is_training_ph=self.is_training_ph)

    Agent = {'scripted': Scripted, 'sac': SAC}[self.FLAGS['agent']]
    self.agent = Agent(sas_vals=self.sarsd_vals,
                       sas_phs=self.sarsd_phs,
                       embed_phs=self.embed_phs,
                       action_dist=self.action_dist,
                       obs_vectorizer=self.obs_vectorizer,
                       FLAGS=FLAGS,
                       dyn=self.encoder if self.FLAGS['share_dyn'] else None,
                       value_dyn=self.value_encoder if self.FLAGS['aac'] else None,
                       is_training_ph=self.is_training_ph)
    if self.FLAGS['agent'] == 'sac':
        self.scripted_agent = Scripted(
            sas_vals=self.sarsd_vals,
            sas_phs=self.sarsd_phs,
            embed_phs=self.embed_phs,
            action_dist=self.action_dist,
            obs_vectorizer=self.obs_vectorizer,
            FLAGS=FLAGS,
            dyn=self.encoder if self.FLAGS['share_dyn'] else None)

    self.eval_vals = {}
    if self.FLAGS['grad_summaries']:
        self.eval_vals['summary'] = tf.summary.merge(self.encoder.grad_summaries)
    else:
        self.eval_vals['summary'] = self.encoder.eval_vals.pop('summary')
        #self.eval_vals['summary'] = tf.no_op()
    self.eval_vals.update(self.encoder.eval_vals)
    if self.FLAGS['aac']:
        self.eval_vals.update(prefix_vals('value', self.value_encoder.eval_vals))
    self.sess = get_session()
    self.sess.run(tf.global_variables_initializer())
    self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                max_to_keep=10,
                                keep_checkpoint_every_n_hours=0.5)
    self.train_writer = tf.summary.FileWriter(
        os.path.join(self.FLAGS['log_path'], 'tb'), self.sess.graph)

    # Maybe load pre-trained variables from a path.
    if self.FLAGS['load_path'] != '':
        def get_vars_to_load(path):
            load_names = [name for name, _
                          in tf.contrib.framework.list_variables(path)]
            vars = [var for var in tf.global_variables()
                    if var.name[:-2] in load_names and 'Adam' not in var.name]
            return vars

        path = self.FLAGS['load_path']
        loader = tf.train.Saver(var_list=get_vars_to_load(path))
        loader.restore(self.sess, path)
        print()
        print('Loaded trained variables from', path)
        print()

    # SUMMARY STUFF
    self.pred_plot_ph = tf.placeholder(tf.uint8)
    self.plot_summ = tf.summary.image('mdn_contour', self.pred_plot_ph,
                                      max_outputs=3)
    self.value_plot_summ = tf.summary.image('value_mdn_contour',
                                            self.pred_plot_ph,
                                            max_outputs=3)
    self.gif_paths_ph = tf.placeholder(tf.string, shape=(None,),
                                       name='gif_path_ph')
    self.gif_summ = rl_util.gif_summary('rollout', self.gif_paths_ph,
                                        max_outputs=3)
    #self.sess.graph.finalize()

    # Make a separate process for GIFs because they take a long time.
    self.gif_rollout_queue = Queue(maxsize=3)
    self.gif_path_queue = Queue(maxsize=3)
    # TODO: what is passed in a python process?
    if self.FLAGS['run_rl_optim']:
        self.gif_proc = Process(target=gif_plotter,
                                daemon=True,
                                args=(self.obs_vectorizer, self.action_dist,
                                      self.FLAGS, self.gif_rollout_queue,
                                      self.gif_path_queue))
        self.gif_proc.start()

    self.mean_summ = defaultdict(lambda: None)
    self.mean_summ_phs = defaultdict(lambda: None)
    numvars()

    if self.FLAGS['threading']:
        # multi-thread for a moderate speed-up
        self.rollout_queue = queue.Queue(maxsize=3)
        self.rollout_thread = Thread(target=self.rollout_maker, daemon=True)
def main():
    """
    Entry-point for the program.
    """
    args = _parse_args()
    # batched_gym_env runs several copies of the env in parallel.
    # make_single_env applies GrayscaleEnv and DownsampleEnv:
    #   GrayscaleEnv turns RGB observations into grayscale.
    #   DownsampleEnv downsamples the observation by the specified factor
    #   (e.g. 2x smaller).
    env = batched_gym_env([partial(make_single_env, args.game)] * args.workers)
    env_test = make_single_env(args.game)
    print('OBSSSS', env_test.observation_space)
    #env = CustomWrapper(args.game)
    # Using BatchedFrameStack with concat=False is more
    # memory efficient than other stacking options.
    env = BatchedFrameStack(env, num_images=4, concat=False)
    with tf.Session() as sess:
        dqn = DQN(*rainbow_models(sess,
                                  env.action_space.n,
                                  gym_space_vectorizer(env.observation_space),
                                  min_val=-200,
                                  max_val=200))
        player = BatchedPlayer(env, EpsGreedyQNetwork(dqn.online_net,
                                                      args.epsilon))
        optimize = dqn.optimize(learning_rate=args.lr)
        sess.run(tf.global_variables_initializer())
        reward_hist = []
        total_steps = 0

        def _handle_ep(steps, rew):
            nonlocal total_steps
            total_steps += steps
            reward_hist.append(rew)
            if len(reward_hist) == REWARD_HISTORY:
                print('%d steps: mean=%f' %
                      (total_steps, sum(reward_hist) / len(reward_hist)))
                reward_hist.clear()

        dqn.train(num_steps=int(1e7),
                  player=player,
                  replay_buffer=UniformReplayBuffer(args.buffer_size),
                  optimize_op=optimize,
                  target_interval=args.target_interval,
                  batch_size=args.batch_size,
                  min_buffer_size=args.min_buffer_size,
                  handle_ep=_handle_ep)
    env.close()