def test_with_dynamic_step_driver(self):
    env = driver_test_utils.PyEnvironmentMock()
    tf_env = tf_py_environment.TFPyEnvironment(env)
    policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                            tf_env.action_spec())
    trajectory_spec = trajectory.from_transition(tf_env.time_step_spec(),
                                                 policy.policy_step_spec,
                                                 tf_env.time_step_spec())
    tfrecord_observer = example_encoding_dataset.TFRecordObserver(
        self.dataset_path, trajectory_spec)

    driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        policy,
        observers=[common.function(tfrecord_observer)],
        num_steps=10)

    self.evaluate(tf.compat.v1.global_variables_initializer())
    time_step = self.evaluate(tf_env.reset())
    initial_policy_state = policy.get_initial_state(batch_size=1)
    self.evaluate(
        common.function(driver.run)(time_step, initial_policy_state))
    tfrecord_observer.flush()
    tfrecord_observer.close()

    dataset = example_encoding_dataset.load_tfrecord_dataset(
        [self.dataset_path], buffer_size=2, as_trajectories=True)
    iterator = eager_utils.dataset_iterator(dataset)
    sample = self.evaluate(eager_utils.get_next(iterator))
    self.assertIsInstance(sample, trajectory.Trajectory)
def testMultiStepReplayBufferObservers(self):
    env = tf_py_environment.TFPyEnvironment(
        driver_test_utils.PyEnvironmentMock())
    policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                            env.action_spec())
    policy_state = policy.get_initial_state(1)
    replay_buffer = make_replay_buffer(env)

    driver = dynamic_step_driver.DynamicStepDriver(
        env, policy, num_steps=6, observers=[replay_buffer.add_batch])
    run_driver = driver.run(policy_state=policy_state)
    rb_gather_all = replay_buffer.gather_all()

    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.evaluate(run_driver)
    trajectories = self.evaluate(rb_gather_all)
    self.assertAllEqual(trajectories.step_type, [[0, 1, 2, 0, 1, 2, 0, 1]])
    self.assertAllEqual(trajectories.observation, [[0, 1, 3, 0, 1, 3, 0, 1]])
    self.assertAllEqual(trajectories.action, [[1, 2, 1, 1, 2, 1, 1, 2]])
    self.assertAllEqual(trajectories.policy_info, [[2, 4, 2, 2, 4, 2, 2, 4]])
    self.assertAllEqual(trajectories.next_step_type,
                        [[1, 2, 0, 1, 2, 0, 1, 2]])
    self.assertAllEqual(trajectories.reward,
                        [[1., 1., 0., 1., 1., 0., 1., 1.]])
    self.assertAllEqual(trajectories.discount,
                        [[1., 0., 1., 1., 0., 1., 1., 0.]])
def testOneStepUpdatesObservers(self):
    if tf.executing_eagerly():
        self.skipTest('b/123880556')
    env = tf_py_environment.TFPyEnvironment(
        driver_test_utils.PyEnvironmentMock())
    policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                            env.action_spec())
    policy_state_ph = tensor_spec.to_nest_placeholder(
        policy.policy_state_spec,
        default=0,
        name_scope='policy_state_ph',
        outer_dims=(1,))
    num_episodes_observer = driver_test_utils.NumEpisodesObserver()

    driver = dynamic_step_driver.DynamicStepDriver(
        env, policy, observers=[num_episodes_observer])
    run_driver = driver.run(policy_state=policy_state_ph)

    with self.session() as session:
        session.run(tf.compat.v1.global_variables_initializer())
        _, policy_state = session.run(run_driver)
        for _ in range(4):
            _, policy_state = session.run(
                run_driver, feed_dict={policy_state_ph: policy_state})
        self.assertEqual(self.evaluate(num_episodes_observer.num_episodes), 2)
def test_parallel_envs(self):
    env_num = 5
    ctors = [
        lambda: suite_socialbot.load(
            'SocialBot-CartPole-v0', wrap_with_process=False)
    ] * env_num
    self._env = parallel_py_environment.ParallelPyEnvironment(
        env_constructors=ctors, start_serially=False)
    tf_env = tf_py_environment.TFPyEnvironment(self._env)
    self.assertTrue(tf_env.batched)
    self.assertEqual(tf_env.batch_size, env_num)

    random_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    replay_buffer_capacity = 100
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        random_policy.trajectory_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    steps = 100
    step_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        random_policy,
        observers=[replay_buffer.add_batch],
        num_steps=steps)
    step_driver.run = common.function(step_driver.run)
    step_driver.run()
    self.assertIsNotNone(replay_buffer.get_next())
def create_collect_driver(train_env, agent, replay_buffer, collect_steps):
    return dynamic_step_driver.DynamicStepDriver(
        train_env,
        agent.collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=collect_steps,
    )
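# A minimal usage sketch for the helper above. The `train_env`, `agent`, and
# `replay_buffer` objects are assumed to be constructed elsewhere (e.g. a
# TFPyEnvironment, a TF-Agents agent, and a TFUniformReplayBuffer); the
# step count is illustrative. `run()` returns the last time step and policy
# state, which can be fed back into the next `run()` call.
collect_driver = create_collect_driver(
    train_env, agent, replay_buffer, collect_steps=100)
time_step, policy_state = collect_driver.run()
time_step, policy_state = collect_driver.run(time_step, policy_state)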
def testTwoObservers(self):
    env = tf_py_environment.TFPyEnvironment(
        driver_test_utils.PyEnvironmentMock())
    policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                            env.action_spec())
    policy_state = policy.get_initial_state(1)
    num_episodes_observer0 = driver_test_utils.NumEpisodesObserver(
        variable_scope='observer0')
    num_episodes_observer1 = driver_test_utils.NumEpisodesObserver(
        variable_scope='observer1')
    num_steps_transition_observer = (
        driver_test_utils.NumStepsTransitionObserver())

    driver = dynamic_step_driver.DynamicStepDriver(
        env,
        policy,
        num_steps=5,
        observers=[num_episodes_observer0, num_episodes_observer1],
        transition_observers=[num_steps_transition_observer],
    )
    run_driver = driver.run(policy_state=policy_state)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.evaluate(run_driver)
    self.assertEqual(self.evaluate(num_episodes_observer0.num_episodes), 2)
    self.assertEqual(self.evaluate(num_episodes_observer1.num_episodes), 2)
    self.assertEqual(
        self.evaluate(num_steps_transition_observer.num_steps), 5)
def create_driver(self):
    # A driver that simulates N steps in an environment.
    self._collect_driver = dynamic_step_driver.DynamicStepDriver(
        self._train_env,
        # A policy that can be used to collect data from the environment.
        self._agent.collect_policy,
        # A list of observers that are updated after every step in the
        # environment.
        observers=[self._replay_buffer_observer] + self._train_metrics,
        # The number of steps simulated: N steps.
        num_steps=param.DRIVER_STEPS)
def get_collection_driver(self):
    """Builds one collection driver per agent for tf-agents."""
    self._collection_driver = []
    for agent in self._agent:
        self._collection_driver.append(
            dynamic_step_driver.DynamicStepDriver(
                env=self._runtime,
                policy=agent._agent.collect_policy,  # the agent's policy
                observers=[agent._replay_buffer.add_batch],
                num_steps=1))
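# A minimal sketch (assumed to live in another method of the same class) of
# stepping every per-agent driver once after get_collection_driver() has
# populated self._collection_driver; the method name is illustrative.
def step_all_agents(self):
    for driver in self._collection_driver:
        # Each run() advances the shared runtime by num_steps=1 for that
        # agent's policy and pushes the trajectory into its replay buffer.
        driver.run()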
def test_validate_mask(self):
    env = tf_py_environment.TFPyEnvironment(self.env)
    policy = random_tf_policy.RandomTFPolicy(
        time_step_spec=env.time_step_spec(),
        action_spec=env.action_spec(),
        observation_and_action_constraint_splitter=(
            GameEnv.obs_and_mask_splitter))
    driver = dynamic_step_driver.DynamicStepDriver(env, policy, num_steps=1)
    for _ in range(10):
        time_step, _ = driver.run()
        action_step = policy.action(time_step)
        print(utils.get_action(action_step.action.numpy()[0], 3))
def test_metric_results_equal(self):
    python_metrics, tensorflow_metrics = self._build_metrics()
    observers = python_metrics + tensorflow_metrics
    driver = dynamic_step_driver.DynamicStepDriver(
        self._tf_env, self._policy, observers=observers, num_steps=1000)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.evaluate(driver.run())
    for python_metric, tensorflow_metric in zip(python_metrics,
                                                tensorflow_metrics):
        python_result = self.evaluate(python_metric.result())
        tensorflow_result = self.evaluate(tensorflow_metric.result())
        self.assertEqual(python_result, tensorflow_result)
def collect(tf_env,
            tf_policy,
            output_dir,
            checkpoint=None,
            num_iterations=500000,
            episodes_per_file=500,
            summary_interval=1000):
    """Collects experience with `tf_policy` and writes it to TFRecord files."""
    if not os.path.isdir(output_dir):
        logger.info('Making output directory %s...', output_dir)
        os.makedirs(output_dir)

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Make the replay buffer.
        replay_buffer = tfrecord_replay_buffer.TFRecordReplayBuffer(
            data_spec=tf_policy.trajectory_spec,
            experiment_id='exp',
            file_prefix=os.path.join(output_dir, 'data'),
            episodes_per_file=episodes_per_file)
        replay_observer = [replay_buffer.add_batch]

        collect_policy = tf_policy
        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env, collect_policy, observers=replay_observer,
            num_steps=1).run()

    with tf.compat.v1.Session() as sess:
        # Initialize training.
        try:
            common.initialize_uninitialized_variables(sess)
        except Exception:  # pylint: disable=broad-except
            pass

        # Restore checkpoint.
        if checkpoint is not None:
            if os.path.isdir(checkpoint):
                train_dir = os.path.join(checkpoint, 'train')
                checkpoint_path = tf.train.latest_checkpoint(train_dir)
            else:
                checkpoint_path = checkpoint
            restorer = tf.compat.v1.train.Saver(name='restorer')
            restorer.restore(sess, checkpoint_path)

        collect_call = sess.make_callable(collect_op)
        for _ in range(num_iterations):
            collect_call()
def testMultiStepUpdatesObservers(self):
    env = tf_py_environment.TFPyEnvironment(
        driver_test_utils.PyEnvironmentMock())
    policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                            env.action_spec())
    num_episodes_observer = driver_test_utils.NumEpisodesObserver()

    driver = dynamic_step_driver.DynamicStepDriver(
        env, policy, num_steps=5, observers=[num_episodes_observer])
    run_driver = driver.run()
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.evaluate(run_driver)
    self.assertEqual(self.evaluate(num_episodes_observer.num_episodes), 2)
def _driver(self, cfg):
    """Returns a step or episode driver built from `cfg`."""
    observers = [self.replay.add_batch] + self.observers
    if cfg["type"] == "episode":
        return dynamic_episode_driver.DynamicEpisodeDriver(
            self.env,
            self.agent.collect_policy,
            observers=observers,
            num_episodes=cfg["length"])
    elif cfg["type"] == "step":
        return dynamic_step_driver.DynamicStepDriver(
            self.env,
            self.agent.collect_policy,
            observers=observers,
            num_steps=cfg["length"])
    else:
        raise ValueError(
            "Unknown type of driver! Input is {}".format(cfg["type"]))
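# A minimal sketch of calling the factory above (assumed to run inside the
# same class). The cfg keys ("type", "length") mirror the two branches in
# _driver; the length values are illustrative.
step_driver = self._driver({"type": "step", "length": 100})
episode_driver = self._driver({"type": "episode", "length": 5})
step_driver.run()     # advances the env by 100 steps
episode_driver.run()  # advances the env by 5 full episodes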
def test_dmlab_env_run(self, scene):
    ctor = lambda: suite_dmlab.load(
        scene=scene,
        gym_env_wrappers=[wrappers.FrameResize],
        wrap_with_process=False)
    self._env = parallel_py_environment.ParallelPyEnvironment([ctor] * 4)
    env = tf_py_environment.TFPyEnvironment(self._env)
    self.assertEqual((84, 84, 3), env.observation_spec().shape)

    random_policy = random_tf_policy.RandomTFPolicy(
        env.time_step_spec(), env.action_spec())
    driver = dynamic_step_driver.DynamicStepDriver(
        env=env, policy=random_policy, observers=None, num_steps=10)
    driver.run(maximum_iterations=10)
def testOneStepReplayBufferObservers(self):
    if tf.executing_eagerly():
        self.skipTest('b/123880556')
    env = tf_py_environment.TFPyEnvironment(
        driver_test_utils.PyEnvironmentMock())
    policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                            env.action_spec())
    policy_state_ph = tensor_spec.to_nest_placeholder(
        policy.policy_state_spec,
        default=0,
        name_scope='policy_state_ph',
        outer_dims=(1,))
    replay_buffer = driver_test_utils.make_replay_buffer(policy)

    driver = dynamic_step_driver.DynamicStepDriver(
        env, policy, num_steps=1, observers=[replay_buffer.add_batch])
    run_driver = driver.run(policy_state=policy_state_ph)
    rb_gather_all = replay_buffer.gather_all()

    with self.session() as session:
        session.run(tf.compat.v1.global_variables_initializer())
        _, policy_state = session.run(run_driver)
        for _ in range(5):
            _, policy_state = session.run(
                run_driver, feed_dict={policy_state_ph: policy_state})
        trajectories = self.evaluate(rb_gather_all)
        self.assertAllEqual(trajectories.step_type,
                            [[0, 1, 2, 0, 1, 2, 0, 1]])
        self.assertAllEqual(trajectories.observation,
                            [[0, 1, 3, 0, 1, 3, 0, 1]])
        self.assertAllEqual(trajectories.action, [[1, 2, 1, 1, 2, 1, 1, 2]])
        self.assertAllEqual(trajectories.policy_info,
                            [[2, 4, 2, 2, 4, 2, 2, 4]])
        self.assertAllEqual(trajectories.next_step_type,
                            [[1, 2, 0, 1, 2, 0, 1, 2]])
        self.assertAllEqual(trajectories.reward,
                            [[1., 1., 0., 1., 1., 0., 1., 1.]])
        self.assertAllEqual(trajectories.discount,
                            [[1., 0., 1., 1., 0., 1., 1., 0.]])
def test_mario_env(self):
    ctor = lambda: suite_mario.load(
        'SuperMarioBros-Nes', 'Level1-1', wrap_with_process=False)
    self._env = parallel_py_environment.ParallelPyEnvironment([ctor] * 4)
    env = tf_py_environment.TFPyEnvironment(self._env)
    self.assertEqual(np.uint8, env.observation_spec().dtype)
    self.assertEqual((84, 84, 4), env.observation_spec().shape)

    random_policy = random_tf_policy.RandomTFPolicy(
        env.time_step_spec(), env.action_spec())
    metrics = [
        AverageReturnMetric(batch_size=4),
        AverageEpisodeLengthMetric(batch_size=4),
        EnvironmentSteps(),
        NumberOfEpisodes()
    ]
    driver = dynamic_step_driver.DynamicStepDriver(
        env, random_policy, observers=metrics, num_steps=10000)
    driver.run(maximum_iterations=10000)
def init_replay_buffer(tf_env, data_spec, train_metrics):
    """Creates and initializes a replay buffer."""
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=data_spec,
        batch_size=tf_env.batch_size,
        max_length=FLAGS.replay_buffer_size)

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps'
        ' with a random policy.', FLAGS.initial_collect_steps)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=FLAGS.initial_collect_steps).run()
    return replay_buffer
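# A minimal sketch of calling the helper above. The `agent` object and the
# empty metrics list are illustrative assumptions; the buffer and step sizes
# come from the surrounding module's FLAGS.
replay_buffer = init_replay_buffer(
    tf_env, data_spec=agent.collect_data_spec, train_metrics=[])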
def data_collection(agnt, env, policy):
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agnt.collect_data_spec,
        batch_size=env.batch_size,
        max_length=replay_buffer_max_length)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        env,
        policy,
        observers=[replay_buffer.add_batch],
        num_steps=collect_steps_per_iteration)

    # Initial data collection.
    collect_driver.run()

    # Dataset generates trajectories with shape [BxTx...] where
    # T = n_step_update + 1.
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3, sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)
    return iterator, collect_driver, replay_buffer
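# A minimal sketch of consuming what data_collection returns in a training
# loop; `agnt`, `env`, and `policy` are the same objects passed above, and
# the iteration count is illustrative.
iterator, collect_driver, replay_buffer = data_collection(agnt, env, policy)
for _ in range(1000):
    # Collect a few fresh steps, then train on a sampled mini-batch.
    collect_driver.run()
    experience, _ = next(iterator)
    train_loss = agnt.train(experience)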
def populate_buffer(env, replay_buffer, policy, spec, num_steps, batch_size):
    prepare = get_prepare(spec)

    def add_to_replay_buffer(transition):
        # Note: the logged policy step is discarded and the action is
        # re-sampled from `policy` on the prepared time step.
        time_step, _, next_time_step = transition
        time_step = prepare(time_step)
        next_time_step = prepare(next_time_step)
        action_step = policy.action(time_step)
        traj = trajectory.from_transition(time_step, action_step,
                                          next_time_step)
        traj_batched = tf.nest.map_structure(
            lambda t: tf.stack([t] * batch_size), traj)
        replay_buffer.add_batch(traj_batched)

    observers = [add_to_replay_buffer]
    driver = dynamic_step_driver.DynamicStepDriver(
        env, policy, transition_observers=observers, num_steps=num_steps)
    # The initial driver.run will reset the environment and initialize the
    # policy.
    driver.run()
def test_metric_results_equal_with_batched_env(self):
    env_ctor = lambda: random_py_environment.RandomPyEnvironment(  # pylint: disable=g-long-lambda
        self._time_step_spec.observation, self._action_spec)
    batch_size = 5
    env = batched_py_environment.BatchedPyEnvironment(
        [env_ctor() for _ in range(batch_size)])
    tf_env = tf_py_environment.TFPyEnvironment(env)
    python_metrics, tensorflow_metrics = self._build_metrics(
        batch_size=batch_size)
    observers = python_metrics + tensorflow_metrics
    driver = dynamic_step_driver.DynamicStepDriver(
        tf_env, self._policy, observers=observers, num_steps=1000)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.evaluate(driver.run())
    for python_metric, tensorflow_metric in zip(python_metrics,
                                                tensorflow_metrics):
        python_result = self.evaluate(python_metric.result())
        tensorflow_result = self.evaluate(tensorflow_metric.result())
        self.assertEqual(python_result, tensorflow_result)
def testBanditEnvironment(self):

    def _context_sampling_fn():
        return np.array([[5, -5], [2, -2]])

    reward_fns = [
        environment_utilities.LinearNormalReward(theta, sigma=0.0)
        for theta in ([1, 0], [0, 1])
    ]
    batch_size = 2
    py_env = sspe.StationaryStochasticPyEnvironment(
        _context_sampling_fn, reward_fns, batch_size=batch_size)
    env = tf_py_environment.TFPyEnvironment(py_env)
    policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(),
                                             env.action_spec())
    steps_per_loop = 4
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=policy.trajectory_spec,
        batch_size=batch_size,
        max_length=steps_per_loop)

    driver = dynamic_step_driver.DynamicStepDriver(
        env,
        policy,
        num_steps=steps_per_loop * batch_size,
        observers=[replay_buffer.add_batch])

    run_driver = driver.run()
    rb_gather_all = replay_buffer.gather_all()

    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.evaluate(run_driver)
    trajectories = self.evaluate(rb_gather_all)

    self.assertAllEqual(trajectories.step_type, [[0, 0, 0, 0], [0, 0, 0, 0]])
    self.assertAllEqual(trajectories.next_step_type,
                        [[2, 2, 2, 2], [2, 2, 2, 2]])
def setup(self):
    self.train_env = tf_py_environment.TFPyEnvironment(
        self.create_env_train())
    self.eval_env = tf_py_environment.TFPyEnvironment(self.create_env())

    self.observation_spec = self.train_env.observation_spec()
    print('obs:', self.observation_spec)
    self.action_spec = self.train_env.action_spec()
    print('action:', self.action_spec)

    self.nature_cnn = self.pre_processing_natureCnn()
    self.critic_net = self.critic(self.nature_cnn)
    self.actor_net = self.actor(self.nature_cnn)

    self.global_step = tf.compat.v1.train.get_or_create_global_step()
    self.tf_agent = self.sac_agent()
    self.tf_agent.initialize()

    self.replay_buffer = self.create_buffer()
    self.replay_observer = [self.replay_buffer.add_batch]

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        self.train_env.time_step_spec(), self.train_env.action_spec())
    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        self.train_env,
        initial_collect_policy,
        observers=self.replay_observer,
        num_steps=self.initial_collect_steps)

    if not self.resume_training:
        print('------- Filling Buffer -------')
        _ = initial_collect_driver.run()
        print('------- END -------')

    self.step_metrics, self.train_metrics = self.define_metrics()
def data_generation(self):
    # Set up a random policy.
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        self._train_env.time_step_spec(), self._train_env.action_spec())

    # Set up a driver with the random policy to collect data.
    init_driver = dynamic_step_driver.DynamicStepDriver(
        self._train_env,
        # A random policy that can be used to collect data from the
        # environment.
        initial_collect_policy,
        # A list of observers that are updated after every step in the
        # environment.
        observers=[
            self._replay_buffer_observer,
            Progress_viz(param.DATASET_STEPS)
        ],
        # The number of steps in the dataset.
        num_steps=param.DATASET_STEPS)

    # Record the sequence of state transitions and results in the observers.
    final_time_step, final_policy_state = init_driver.run()

    # Verify collected trajectories (optional).
    if self._visual_flag:
        trajectories, buffer_info = self._replay_buffer.get_next(
            sample_batch_size=2, num_steps=10)
        time_steps, action_steps, next_time_steps = trajectory.to_transition(
            trajectories)
        print("trajectories._fields", trajectories._fields)
        print("time_steps.observation.shape = ", time_steps.observation.shape)

    # Create a dataset from the replay buffer.
    self._dataset = self._replay_buffer.as_dataset(
        sample_batch_size=param.DATASET_BATCH,
        num_steps=param.DATASET_BUFFER_STEP,
        num_parallel_calls=param.DATASET_PARALLEL).prefetch(
            param.DATASET_PREFETCH)
def train_eval(
        root_dir,
        env_name='HalfCheetah-v1',
        env_load_fn=suite_mujoco.load,
        num_iterations=2000000,
        actor_fc_layers=(400, 300),
        critic_obs_fc_layers=(400,),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300,),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_parallel_environments=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=tf.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # TODO(kbanoop): Figure out if it is possible to avoid the with block.
    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):
        if num_parallel_environments > 1:
            tf_env = tf_py_environment.TFPyEnvironment(
                parallel_py_environment.ParallelPyEnvironment(
                    [lambda: env_load_fn(env_name)] *
                    num_parallel_environments))
        else:
            tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        eval_py_env = env_load_fn(env_name)

        actor_net = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers,
        )

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())
        critic_net = critic_network.CriticNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec(),
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        global_step = tf.train.get_or_create_global_step()

        collect_policy = tf_agent.collect_policy()
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3, sample_batch_size=batch_size,
            num_steps=2).prefetch(3)

        iterator = dataset.make_initializable_iterator()
        trajectories, unused_info = iterator.get_next()
        train_op = tf_agent.train(
            experience=trajectories, train_step_counter=global_step)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=tf.contrib.checkpoint.List(train_metrics))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy(),
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])
        summary_op = tf.contrib.summary.all_summary_ops()

        with eval_summary_writer.as_default(), \
                tf.contrib.summary.always_record_summaries():
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            tf.contrib.summary.initialize(session=sess)
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(
                [train_op, summary_op, global_step])

            timed_at_step = sess.run(global_step)
            time_acc = 0
            steps_per_second_ph = tf.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _, global_step_val = train_step_call()
                time_acc += time.time() - start_time

                if global_step_val % log_interval == 0:
                    tf.logging.info('step = %d, loss = %f', global_step_val,
                                    loss_info_value.loss)
                    steps_per_sec = (global_step_val - timed_at_step) / time_acc
                    tf.logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(
                        steps_per_second_summary,
                        feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
traj = trajectory.from_transition(time_step, action_step, next_time_step)
# Add trajectory to the replay buffer.
replay_buffer.add_batch(traj)

# for _ in range(1000):
#     collect_step(train_env, tf_agent.collect_policy)

dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=batch_size,
    num_steps=2).prefetch(3)

driver = dynamic_step_driver.DynamicStepDriver(
    train_env,
    collect_policy,
    observers=replay_observer + train_metrics,
    num_steps=1)

iterator = iter(dataset)

print(compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes))

tf_agent.train = common.function(tf_agent.train)
tf_agent.train_step_counter.assign(0)

final_time_step, policy_state = driver.run()

for _ in range(1000):
    final_time_step, _ = driver.run(final_time_step, policy_state)
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_tf_env.batch_size,
    max_length=1000000)

replay_buffer_observer = replay_buffer.add_batch

train_metrics = [
    tf_metrics.AverageReturnMetric(),
    tf_metrics.AverageEpisodeLengthMetric()
]

collect_driver = dynamic_step_driver.DynamicStepDriver(
    train_tf_env,
    agent.collect_policy,
    observers=[replay_buffer_observer] + train_metrics,
    num_steps=update_period)


class ShowProgress:
    def __init__(self, total):
        self.counter = 0
        self.total = total

    def __call__(self, trajectory):
        if not trajectory.is_boundary():
            self.counter += 1
        if self.counter % 100 == 0:
            print("\r{}/{}".format(self.counter, self.total), end="")
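# A minimal sketch of wiring ShowProgress in as a driver observer for an
# initial random collection; the random policy and the step count (20000)
# are illustrative, everything else is defined above.
initial_collect_policy = random_tf_policy.RandomTFPolicy(
    train_tf_env.time_step_spec(), train_tf_env.action_spec())
init_driver = dynamic_step_driver.DynamicStepDriver(
    train_tf_env,
    initial_collect_policy,
    # ShowProgress is called once per step, like any other observer.
    observers=[replay_buffer.add_batch, ShowProgress(20000)],
    num_steps=20000)
final_time_step, final_policy_state = init_driver.run()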
    train_step_counter=train_step_counter)
agent.initialize()
print("Agent ready")

collect_policy = agent.collect_policy
print("Policies ready")

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=buffer_length)
print("Replay buffer ready")

collect_driver = dynamic_step_driver.DynamicStepDriver(
    train_env,
    agent.collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=steps_iteration)
collect_driver.run()
print("Initial steps collected")

dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=batch_size,
    num_steps=2).prefetch(3)
iterator = iter(dataset)
print("Dataset ready")

agent.train_step_counter.assign(0)
avgBeforeTraining = average_reward_return(eval_env, agent.policy, 1)
best_avg.append(avgBeforeTraining)
best_policy.append(agent.policy)
def train_eval(
        root_dir,
        environment_name="broken_reacher",
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        initial_collect_steps=10000,
        real_initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        real_collect_interval=10,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        classifier_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=0.1,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        train_on_real=False,
        delta_r_warmup=0,
        random_seed=0,
        checkpoint_dir=None,
):
    """A simple train and eval for DARC."""
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, "train")
    eval_dir = os.path.join(root_dir, "eval")

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    if environment_name == "broken_reacher":
        get_env_fn = darc_envs.get_broken_reacher_env
    elif environment_name == "half_cheetah_obstacle":
        get_env_fn = darc_envs.get_half_cheetah_direction_env
    elif environment_name == "inverted_pendulum":
        get_env_fn = darc_envs.get_inverted_pendulum_env
    elif environment_name.startswith("broken_joint"):
        base_name = environment_name.split("broken_joint_")[1]
        get_env_fn = functools.partial(
            darc_envs.get_broken_joint_env, env_name=base_name)
    elif environment_name.startswith("falling"):
        base_name = environment_name.split("falling_")[1]
        get_env_fn = functools.partial(
            darc_envs.get_falling_env, env_name=base_name)
    else:
        raise NotImplementedError(
            "Unknown environment: %s" % environment_name)

    eval_name_list = ["sim", "real"]
    eval_env_list = [get_env_fn(mode) for mode in eval_name_list]

    eval_metrics_list = []
    for name in eval_name_list:
        eval_metrics_list.append([
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                name="AverageReturn_%s" % name),
        ])

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env_real = get_env_fn("real")
        if train_on_real:
            tf_env = get_env_fn("real")
        else:
            tf_env = get_env_fn("sim")

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=(
                tanh_normal_projection_network.TanhNormalProjectionNetwork),
        )
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer="glorot_uniform",
            last_kernel_initializer="glorot_uniform",
        )

        classifier = classifiers.build_classifier(observation_spec,
                                                  action_spec)

        tf_agent = darc_agent.DarcAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            classifier=classifier,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            classifier_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=classifier_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        # Make the replay buffers.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        replay_observer = [replay_buffer.add_batch]

        real_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        real_replay_observer = [real_replay_buffer.add_batch]

        sim_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesSim"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsSim"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnSim",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthSim",
            ),
        ]
        real_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesReal"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsReal"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnReal",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthReal",
            ),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(
                sim_train_metrics + real_train_metrics, "train_metrics"),
        )
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "policy"),
            policy=eval_policy,
            global_step=global_step,
        )
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "replay_buffer"),
            max_to_keep=1,
            replay_buffer=(replay_buffer, real_replay_buffer),
        )

        if checkpoint_dir is not None:
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
            assert checkpoint_path is not None
            train_checkpointer._load_status = train_checkpointer._checkpoint.restore(  # pylint: disable=protected-access
                checkpoint_path)
            train_checkpointer._load_status.initialize_or_restore()  # pylint: disable=protected-access
        else:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if replay_buffer.num_frames() == 0:
            initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env,
                initial_collect_policy,
                observers=replay_observer + sim_train_metrics,
                num_steps=initial_collect_steps,
            )
            real_initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env_real,
                initial_collect_policy,
                observers=real_replay_observer + real_train_metrics,
                num_steps=real_initial_collect_steps,
            )

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + sim_train_metrics,
            num_steps=collect_steps_per_iteration,
        )
        real_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env_real,
            collect_policy,
            observers=real_replay_observer + real_train_metrics,
            num_steps=collect_steps_per_iteration,
        )

        config_str = gin.operative_config_str()
        logging.info(config_str)
        with tf.compat.v1.gfile.Open(
                os.path.join(root_dir, "operative.gin"), "w") as f:
            f.write(config_str)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            real_initial_collect_driver.run = common.function(
                real_initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            real_collect_driver.run = common.function(real_collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if replay_buffer.num_frames() == 0:
            logging.info(
                "Initializing replay buffer by collecting experience for %d "
                "steps with a random policy.", initial_collect_steps)
            initial_collect_driver.run()
            real_initial_collect_driver.run()

        for eval_name, eval_env, eval_metrics in zip(eval_name_list,
                                                     eval_env_list,
                                                     eval_metrics_list):
            metric_utils.eager_compute(
                eval_metrics,
                eval_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix="Metrics-%s" % eval_name,
            )
            metric_utils.log_metrics(eval_metrics)

        time_step = None
        real_time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = (replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))
        real_dataset = (real_replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))

        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)
        real_iterator = iter(real_dataset)

        def train_step():
            experience, _ = next(iterator)
            real_experience, _ = next(real_iterator)
            return tf_agent.train(experience,
                                  real_experience=real_experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            assert not policy_state  # We expect policy_state == ().

            if (global_step.numpy() % real_collect_interval == 0 and
                    global_step.numpy() >= delta_r_warmup):
                real_time_step, policy_state = real_collect_driver.run(
                    time_step=real_time_step,
                    policy_state=policy_state,
                )

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info("step = %d, loss = %f", global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info("%.3f steps/sec", steps_per_sec)
                tf.compat.v2.summary.scalar(
                    name="global_steps_per_sec",
                    data=steps_per_sec,
                    step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in sim_train_metrics:
                train_metric.tf_summaries(
                    train_step=global_step,
                    step_metrics=sim_train_metrics[:2])
            for train_metric in real_train_metrics:
                train_metric.tf_summaries(
                    train_step=global_step,
                    step_metrics=real_train_metrics[:2])

            if global_step_val % eval_interval == 0:
                for eval_name, eval_env, eval_metrics in zip(
                        eval_name_list, eval_env_list, eval_metrics_list):
                    metric_utils.eager_compute(
                        eval_metrics,
                        eval_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix="Metrics-%s" % eval_name,
                    )
                    metric_utils.log_metrics(eval_metrics)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)

        return train_loss
def DDPG_Bipedal(root_dir):
    # Set up directories for results.
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train', str(run_id))
    eval_dir = os.path.join(root_dir, 'eval', str(run_id))
    vid_dir = os.path.join(root_dir, 'vid', str(run_id))

    # Set up summary writers for training and evaluation.
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        # Metric to record average return.
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        # Metric to record average episode length.
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    # Create global step.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Load environments with different wrappers.
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))
        eval_py_env = suite_gym.load(env_name)

        # Define actor network.
        actorNN = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=(400, 300),
        )

        # Define critic network.
        NN_input_specs = (tf_env.time_step_spec().observation,
                          tf_env.action_spec())
        criticNN = critic_network.CriticNetwork(
            NN_input_specs,
            observation_fc_layer_params=(400,),
            action_fc_layer_params=None,
            joint_fc_layer_params=(300,),
        )

        # Define and initialize the DDPG agent.
        agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actorNN,
            critic_network=criticNN,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
            gamma=gamma,
            train_step_counter=global_step)
        agent.initialize()

        # Determine which train metrics to display with the summary writer.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Set policies for evaluation and initial collection.
        eval_policy = agent.policy  # Actor policy.
        collect_policy = agent.collect_policy  # Actor policy with OU noise.

        # Set up the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        # Define driver for initial replay buffer filling.
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,  # Initialized with random parameters.
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps)

        # Define collect driver for collect steps per iteration.
        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            agent.train = common.function(agent.train)

        # Take initial random steps in tf_env and save them in the replay
        # buffer.
        logging.info(
            'Initializing replay buffer by collecting experience for %d '
            'steps with a random policy.', initial_collect_steps)
        initial_collect_driver.run()

        # Compute evaluation metrics.
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset outputs steps in batches of 64.
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3, sample_batch_size=64,
            num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            # Get experience from the dataset (replay buffer) and train the
            # agent on it.
            experience, _ = next(iterator)
            return agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()

            # Collect data for the replay buffer.
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )

            # Train on experience.
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(
                    name='iterations_per_sec',
                    data=steps_per_sec,
                    step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(
                    train_step=global_step, step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                metric_utils.log_metrics(eval_metrics)
                if results['AverageReturn'].numpy() >= 230.0:
                    video_score = create_video(
                        video_dir=vid_dir,
                        env_name="BipedalWalker-v2",
                        vid_policy=eval_policy,
                        video_id=global_step.numpy())

        return train_loss
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(100,),
        # Params for QRnnNetwork
        input_fc_layer_params=(50,),
        lstm_size=(20,),
        output_fc_layer_params=(20,),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with '
                'stateful networks (i.e., RNNs)')

        if train_sequence_length > 1:
            q_net = q_rnn_network.QRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=input_fc_layer_params,
                lstm_size=lstm_size,
                output_fc_layer_params=output_fc_layer_params)
        else:
            q_net = q_network.QNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=fc_layer_params)
            train_sequence_length = n_step_update

        # TODO(b/127301657): Decay epsilon based on global step, cf.
        # cl/188907839.
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d '
            'steps with a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_sec,
                    step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(
                    train_step=global_step, step_metrics=train_metrics[:2])

            if global_step.numpy() % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

        return train_loss