def test_deterministic_dataset_from_heap_sampler_remover(self):
  # The table uses a MaxHeap sampler and MinHeap remover, so the name
  # reflects that rather than a uniform sampler.
  heap_sampler_remover_table = reverb.Table(
      name=self._table_name,
      sampler=reverb.selectors.MaxHeap(),
      remover=reverb.selectors.MinHeap(),
      max_size=100,
      max_times_sampled=0,
      rate_limiter=reverb.rate_limiters.MinSize(1))
  server = reverb.Server([heap_sampler_remover_table])
  replay = reverb_replay_buffer.ReverbReplayBuffer(
      self._data_spec,
      self._table_name,
      local_server=server,
      sequence_length=None)
  replay.as_dataset(single_deterministic_pass=True)
  server.stop()
def test_capacity_set(self):
  table_name = 'test_table'
  capacity = 100
  uniform_table = reverb.Table(
      table_name,
      max_size=capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(3))
  server = reverb.Server([uniform_table])
  data_spec = tensor_spec.TensorSpec((), tf.float32)
  replay = reverb_replay_buffer.ReverbReplayBuffer(
      data_spec, table_name, local_server=server, sequence_length=None)
  self.assertEqual(capacity, replay.capacity)
  server.stop()
def test_experimental_distribute_datasets_from_function(self, strategy):
  sequence_length = 3
  batch_size = 10
  self._insert_random_data(
      self._env, num_steps=sequence_length, sequence_length=sequence_length)

  replay = reverb_replay_buffer.ReverbReplayBuffer(
      self._data_spec,
      self._table_name,
      sequence_length=sequence_length,
      local_server=self._server)

  num_replicas = strategy.num_replicas_in_sync
  with strategy.scope():
    dataset = strategy.experimental_distribute_datasets_from_function(
        lambda _: replay.as_dataset(batch_size // num_replicas))
    iterator = iter(dataset)

  @common.function()
  def train_step():
    with strategy.scope():
      sample, _ = next(iterator)
      _, step = sample.observation
      loss = strategy.run(
          lambda x: tf.reduce_mean(x, axis=-1), args=(step,))
      return strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=0)

  # Test running eagerly.
  for _ in range(5):
    with strategy.scope():
      loss = train_step()
      self.assertEqual(batch_size, loss)

  # Test with wrapping into a tf.function.
  train_step_fn = common.function(train_step)
  for _ in range(5):
    with strategy.scope():
      loss = train_step_fn()
      self.assertEqual(batch_size, loss)
def test_dataset_with_preprocess(self):

  def validate_data_observer(traj):
    if not array_spec.check_arrays_nest(traj, self._array_data_spec):
      raise ValueError('Trajectory incompatible with array_data_spec')

  def preprocess(traj):
    episode, step = traj.observation
    return traj.replace(observation=(episode, step + 1))

  # Observe 10 steps from the env. This isn't the num_steps we're testing.
  self._insert_random_data(
      self._env,
      num_steps=10,
      additional_observers=[validate_data_observer],
      sequence_length=4)

  replay = reverb_replay_buffer.ReverbReplayBuffer(
      self._data_spec,
      self._table_name,
      local_server=self._server,
      sequence_length=4)

  dataset = replay.as_dataset(num_steps=2)
  for sample, _ in dataset.take(5):
    episode, step = sample.observation
    self.assertEqual(episode[0], episode[1])
    self.assertEqual(step[0] + 1, step[1])
    # From even to odd steps.
    self.assertEqual(0, step[0].numpy() % 2)
    self.assertEqual(1, step[1].numpy() % 2)

  dataset = replay.as_dataset(
      num_steps=2, sample_batch_size=1, sequence_preprocess_fn=preprocess)
  for sample, _ in dataset.take(5):
    episode, step = sample.observation
    self.assertEqual(episode[0, 0], episode[0, 1])
    # Makes sure the preprocess has happened.
    self.assertEqual(step[0, 0] + 1, step[0, 1])
    # From odd to even steps.
    self.assertEqual(1, step[0, 0].numpy() % 2)
    self.assertEqual(0, step[0, 1].numpy() % 2)
def test_single_episode_dataset(self):
  sequence_length = 3
  self._insert_random_data(
      self._env, num_steps=sequence_length, sequence_length=sequence_length)

  replay = reverb_replay_buffer.ReverbReplayBuffer(
      self._data_spec,
      self._table_name,
      sequence_length=None,
      local_server=self._server)

  # Make sure observations are off by 1 given we are counting transitions in
  # the env observations.
  dataset = replay.as_dataset()
  for sample, _ in dataset.take(5):
    episode, step = sample.observation
    self.assertEqual((sequence_length,), episode.shape)
    self.assertEqual((sequence_length,), step.shape)
    self.assertAllEqual([0] * sequence_length, episode - episode[:1])
    self.assertAllEqual(list(range(sequence_length)), step - step[:1])
def test_uniform_table(self):
  table_name = 'test_uniform_table'
  uniform_table = reverb.Table(
      table_name,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      max_size=1000,
      rate_limiter=reverb.rate_limiters.MinSize(3))
  reverb_server = reverb.Server([uniform_table])
  data_spec = tensor_spec.TensorSpec((), dtype=tf.int64)
  replay = reverb_replay_buffer.ReverbReplayBuffer(
      data_spec,
      table_name,
      local_server=reverb_server,
      sequence_length=1,
      dataset_buffer_size=1)

  with replay.py_client.trajectory_writer(num_keep_alive_refs=1) as writer:
    for i in range(3):
      writer.append(i)
      trajectory = writer.history[-1:]
      writer.create_item(table_name, trajectory=trajectory, priority=1)

  dataset = replay.as_dataset(
      sample_batch_size=1, num_steps=None, num_parallel_calls=1)
  iterator = iter(dataset)
  counts = [0] * 3
  for _ in range(1000):
    item_0 = next(iterator)[0].numpy()  # This is a matrix shaped 1x1.
    counts[int(item_0)] += 1

  # Comparing against 200 to avoid flakiness.
  self.assertGreater(counts[0], 200)
  self.assertGreater(counts[1], 200)
  self.assertGreater(counts[2], 200)
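# Context for the 200 threshold above (my own arithmetic, not part of the
# original test): each of the 1000 uniform draws hits a given item with
# probability 1/3, so the per-item count is Binomial(1000, 1/3) -- mean ~333,
# std ~14.9. A cutoff of 200 sits roughly 9 standard deviations below the
# mean, which makes the assertion effectively flake-proof.
import math

n, p = 1000, 1 / 3
mean = n * p                      # ~333.3
std = math.sqrt(n * p * (1 - p))  # ~14.9
print(mean, std, (mean - 200) / std)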
def test_dataset_samples_sequential(self, sequence_length):

  def validate_data_observer(traj):
    if not array_spec.check_arrays_nest(traj, self._array_data_spec):
      raise ValueError('Trajectory incompatible with array_data_spec')

  # Observe 20 steps from the env. This isn't the num_steps we're testing.
  self._insert_random_data(
      self._env,
      num_steps=20,
      additional_observers=[validate_data_observer],
      sequence_length=sequence_length or 4)

  replay = reverb_replay_buffer.ReverbReplayBuffer(
      self._data_spec,
      self._table_name,
      local_server=self._server,
      sequence_length=sequence_length)

  # Make sure observations belong to the same episode and their steps are off
  # by 1.
  for sample, _ in replay.as_dataset(num_steps=2).take(100):
    episode, step = sample.observation
    self.assertEqual(episode[0], episode[1])
    self.assertEqual(step[0] + 1, step[1])
def test_prioritized_table_max_sample(self):
  table_name = 'test_prioritized_table'
  table = reverb.Table(
      table_name,
      sampler=reverb.selectors.Prioritized(1.0),
      remover=reverb.selectors.Fifo(),
      max_times_sampled=10,
      rate_limiter=reverb.rate_limiters.MinSize(1),
      max_size=3)
  reverb_server = reverb.Server([table])
  data_spec = tensor_spec.TensorSpec((), dtype=tf.int64)
  replay = reverb_replay_buffer.ReverbReplayBuffer(
      data_spec,
      table_name,
      sequence_length=1,
      local_server=reverb_server,
      dataset_buffer_size=1)

  with replay.py_client.trajectory_writer(1) as writer:
    for i in range(3):
      writer.append(i)
      writer.create_item(
          table_name, trajectory=writer.history[-1:], priority=i)

  dataset = replay.as_dataset(sample_batch_size=3, num_parallel_calls=3)
  self.assertTrue(table.can_sample(3))
  iterator = iter(dataset)
  counts = [0] * 3
  for _ in range(10):
    item_0 = next(iterator)[0].numpy()  # This is a matrix shaped 1x3.
    for item in item_0:
      counts[int(item)] += 1
  self.assertFalse(table.can_sample(3))

  # Same number of counts due to the limit on max_times_sampled.
  self.assertEqual(counts[0], 10)  # priority 0
  self.assertEqual(counts[1], 10)  # priority 1
  self.assertEqual(counts[2], 10)  # priority 2
def test_prioritized_table(self):
  table_name = 'test_prioritized_table'
  prioritized_table = reverb.Table(
      table_name,
      sampler=reverb.selectors.Prioritized(1.0),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1),
      max_size=3)
  reverb_server = reverb.Server([prioritized_table])
  data_spec = tensor_spec.TensorSpec((), dtype=tf.int64)
  replay = reverb_replay_buffer.ReverbReplayBuffer(
      data_spec,
      table_name,
      sequence_length=1,
      local_server=reverb_server,
      dataset_buffer_size=1)

  with replay.py_client.writer(max_sequence_length=1) as writer:
    for i in range(3):
      writer.append(i)
      writer.create_item(table=table_name, num_timesteps=1, priority=i)

  dataset = replay.as_dataset(
      sample_batch_size=1, num_steps=None, num_parallel_calls=None)
  iterator = iter(dataset)
  counts = [0] * 3
  for _ in range(1000):
    item_0 = next(iterator)[0].numpy()  # This is a matrix shaped 1x1.
    counts[int(item_0)] += 1

  self.assertEqual(counts[0], 0)      # priority 0
  self.assertGreater(counts[1], 250)  # priority 1
  self.assertGreater(counts[2], 600)  # priority 2
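# Why these thresholds (my own arithmetic, not part of the original test):
# with Prioritized(priority_exponent=1.0), P(item) is proportional to
# priority**1.0. Priorities 0, 1, 2 give probabilities 0, 1/3, and 2/3, so
# over 1000 draws the expected counts are about 0, 333, and 667 --
# comfortably above the asserted 250 and 600.
priorities = [0, 1, 2]
total = sum(p ** 1.0 for p in priorities)
expected = [1000 * p ** 1.0 / total for p in priorities]
print(expected)  # [0.0, 333.33..., 666.66...]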
def test_batched_episodes_dataset(self, sequence_length):
  # Observe batch_size * sequence_length steps to have at least 3 episodes.
  batch_size = 3
  env = test_envs.EpisodeCountingEnv(steps_per_episode=sequence_length)
  self._insert_random_data(
      env,
      num_steps=batch_size * sequence_length,
      sequence_length=sequence_length)

  replay = reverb_replay_buffer.ReverbReplayBuffer(
      self._data_spec,
      self._table_name,
      sequence_length=None,
      local_server=self._server)

  dataset = replay.as_dataset(batch_size)
  for sample, _ in dataset.take(5):
    episode, step = sample.observation
    self.assertEqual((batch_size, sequence_length), episode.shape)
    self.assertEqual((batch_size, sequence_length), step.shape)
    for n in range(sequence_length):
      # Every step in a sampled sequence should belong to the same episode.
      self.assertAllEqual(episode[:, 0], episode[:, n])
      # Steps within each sampled sequence should be consecutive.
      self.assertAllEqual(step[:, 0] + n, step[:, n])
def train(
    root_dir: Text,
    environment_name: Text,
    strategy: tf.distribute.Strategy,
    replay_buffer_server_address: Text,
    variable_container_server_address: Text,
    suite_load_fn: Callable[
        [Text], py_environment.PyEnvironment] = suite_mujoco.load,
    # Training params
    learning_rate: float = 3e-4,
    batch_size: int = 256,
    num_iterations: int = 2000000,
    learner_iterations_per_call: int = 1) -> None:
  """Trains a SAC agent."""
  # Get the specs from the environment.
  logging.info('Training SAC with learning rate: %f', learning_rate)
  env = suite_load_fn(environment_name)
  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(env))

  # Create the agent.
  with strategy.scope():
    train_step = train_utils.create_train_step()
    agent = _create_agent(
        train_step=train_step,
        observation_tensor_spec=observation_tensor_spec,
        action_tensor_spec=action_tensor_spec,
        time_step_tensor_spec=time_step_tensor_spec,
        learning_rate=learning_rate)

  # Create the policy saver which saves the initial model now, then it
  # periodically checkpoints the policy weights.
  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  save_model_trigger = triggers.PolicySavedModelTrigger(
      saved_model_dir, agent, train_step, interval=1000)

  # Create the variable container.
  variables = {
      reverb_variable_container.POLICY_KEY: agent.collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.push(variables)

  # Create the replay buffer.
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      server_address=replay_buffer_server_address)

  # Initialize the dataset.
  def experience_dataset_fn():
    with strategy.scope():
      return reverb_replay.as_dataset(
          sample_batch_size=batch_size, num_steps=2).prefetch(3)

  # Create the learner.
  learning_triggers = [
      save_model_trigger,
      triggers.StepPerSecondLogTrigger(train_step, interval=1000)
  ]
  sac_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers,
      strategy=strategy)

  # Run the training loop.
  while train_step.numpy() < num_iterations:
    sac_learner.run(iterations=learner_iterations_per_call)
    variable_container.push(variables)
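# A hypothetical way to launch this trainer; the directories and server
# addresses below are made-up placeholders, not values from this module.
train(
    root_dir='/tmp/sac_train',
    environment_name='HalfCheetah-v2',
    strategy=tf.distribute.get_strategy(),  # default no-op strategy
    replay_buffer_server_address='localhost:8008',
    variable_container_server_address='localhost:8009')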
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    # Training params
    initial_collect_steps=1000,
    num_iterations=100000,
    fc_layer_params=(100,),
    # Agent params
    epsilon_greedy=0.1,
    batch_size=64,
    learning_rate=1e-3,
    n_step_update=1,
    gamma=0.99,
    target_update_tau=0.05,
    target_update_period=5,
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=100000,
    # Others
    policy_save_interval=1000,
    eval_interval=1000,
    eval_episodes=10):
  """Trains and evaluates DQN."""
  collect_env = suite_gym.load(env_name)
  eval_env = suite_gym.load(env_name)

  time_step_tensor_spec = tensor_spec.from_spec(collect_env.time_step_spec())
  action_tensor_spec = tensor_spec.from_spec(collect_env.action_spec())

  train_step = train_utils.create_train_step()
  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

  # Define a helper function to create Dense layers configured with the right
  # activation and kernel initializer.
  def dense_layer(num_units):
    return tf.keras.layers.Dense(
        num_units,
        activation=tf.keras.activations.relu,
        kernel_initializer=tf.keras.initializers.VarianceScaling(
            scale=2.0, mode='fan_in', distribution='truncated_normal'))

  # The QNetwork consists of a sequence of Dense layers followed by a dense
  # layer with `num_actions` units to generate one q_value per available
  # action as its output.
  dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
  q_values_layer = tf.keras.layers.Dense(
      num_actions,
      activation=None,
      kernel_initializer=tf.keras.initializers.RandomUniform(
          minval=-0.03, maxval=0.03),
      bias_initializer=tf.keras.initializers.Constant(-0.2))
  q_net = sequential.Sequential(dense_layers + [q_values_layer])

  agent = dqn_agent.DqnAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      q_network=q_net,
      epsilon_greedy=epsilon_greedy,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
      td_errors_loss_fn=common.element_wise_squared_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      train_step_counter=train_step)

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)

  dataset = reverb_replay.as_dataset(
      num_parallel_calls=3, sample_batch_size=batch_size,
      num_steps=2).prefetch(3)
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=100),
  ]

  dqn_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  # If we haven't trained yet, make sure we collect some random samples first
  # to fill up the Replay Buffer with some experience.
  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      observers=[rb_observer, env_step_metric],
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
  )

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)
  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    dqn_learner.run(iterations=1)

    if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
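# A hypothetical invocation of the train/eval loop above; root_dir is a
# placeholder, and reverb_port=None lets Reverb pick a free port.
train_eval('/tmp/dqn_cartpole', env_name='CartPole-v0', num_iterations=10000)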
def get_reverb_buffer(data_spec,
                      sequence_length=None,
                      table_name='uniform_table',
                      table=None,
                      reverb_server_address=None,
                      port=None,
                      replay_capacity=1000,
                      min_size_limiter_size=1):
  """Returns an instance of Reverb replay buffer.

  Either creates a local reverb server or uses a remote reverb server at
  reverb_server_address (if set). If reverb_server_address is None, creates
  a local server with a uniform table underneath.

  Args:
    data_spec: Spec of the data elements to be stored in the replay buffer.
    sequence_length: Integer specifying the sequence_lengths used to write to
      the given table.
    table_name: Name of the table to create.
    table: Optional table for the backing local server. If None, automatically
      creates a uniform sampling table.
    reverb_server_address: Address of the remote reverb server; if None, a
      local server is created.
    port: Port to launch the server on.
    replay_capacity: Optional (for the default uniform sampling table only,
      i.e. if table=None) capacity of the uniform sampling table for the
      local replay server.
    min_size_limiter_size: Optional (for the default uniform sampling table
      only, i.e. if table=None) minimum number of items required in the
      replay buffer before sampling can begin; used for the local server
      only.

  Returns:
    Reverb replay buffer instance.

  Note: if a local server is created, it is not returned. It can be retrieved
  through `local_server` on the returned replay buffer.
  """
  table_signature = tensor_spec.add_outer_dim(data_spec, sequence_length)
  if reverb_server_address is None:
    if table is None:
      table = _create_uniform_table(
          table_name,
          table_signature,
          table_capacity=replay_capacity,
          min_size_limiter_size=min_size_limiter_size)
    reverb_server = reverb.Server([table], port=port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        data_spec,
        sequence_length=sequence_length,
        table_name=table_name,
        local_server=reverb_server)
  else:
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        data_spec,
        sequence_length=sequence_length,
        table_name=table_name,
        server_address=reverb_server_address)
  return reverb_replay
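# A minimal usage sketch for get_reverb_buffer, assuming the names used in
# this module (tensor_spec, tf) are in scope; the spec and capacity are
# illustrative values only.
data_spec = tensor_spec.TensorSpec((), tf.float32)
replay = get_reverb_buffer(data_spec, sequence_length=2, replay_capacity=500)
# Per the docstring's note, the backing local server hangs off the returned
# buffer; stop it when done.
replay.local_server.stop()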
def train_agent(iterations, modeldir, logdir, policydir):
  """Train and convert the model using TF Agents."""
  # TODO: add code to instantiate the training and evaluation environments

  # TODO: add code to create a reinforcement learning agent that is going to
  # be trained

  tf_agent.initialize()

  eval_policy = tf_agent.policy
  collect_policy = tf_agent.collect_policy

  tf_policy_saver = policy_saver.PolicySaver(collect_policy)

  # Use reverb as replay buffer.
  replay_buffer_signature = tensor_spec.from_spec(tf_agent.collect_data_spec)
  table = reverb.Table(
      REPLAY_BUFFER_TABLE_NAME,
      max_size=REPLAY_BUFFER_CAPACITY,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1),
      signature=replay_buffer_signature,
  )  # specify signature here for validation at insertion time

  reverb_server = reverb.Server([table])
  replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
      tf_agent.collect_data_spec,
      sequence_length=None,
      table_name=REPLAY_BUFFER_TABLE_NAME,
      local_server=reverb_server,
  )
  replay_buffer_observer = reverb_utils.ReverbAddEpisodeObserver(
      replay_buffer.py_client, REPLAY_BUFFER_TABLE_NAME,
      REPLAY_BUFFER_CAPACITY)

  # Optimize by wrapping some of the code in a graph using TF function.
  tf_agent.train = common.function(tf_agent.train)

  # Evaluate the agent's policy once before training.
  avg_return = compute_avg_return_and_steps(eval_env, tf_agent.policy,
                                            NUM_EVAL_EPISODES)

  summary_writer = tf.summary.create_file_writer(logdir)

  for i in range(iterations):
    # TODO: add code to collect game episodes and train the agent

    logger = tf.get_logger()
    if i % EVAL_INTERVAL == 0:
      avg_return, avg_episode_length = compute_avg_return_and_steps(
          eval_env, eval_policy, NUM_EVAL_EPISODES)
      with summary_writer.as_default():
        tf.summary.scalar("Average return", avg_return, step=i)
        tf.summary.scalar("Average episode length", avg_episode_length,
                          step=i)
        summary_writer.flush()
      logger.info(
          "iteration = {0}: Average Return = {1}, Average Episode Length = {2}"
          .format(i, avg_return, avg_episode_length))

  summary_writer.close()

  tf_policy_saver.save(policydir)
def test_size_empty(self):
  replay = reverb_replay_buffer.ReverbReplayBuffer(
      self._data_spec,
      self._table_name,
      local_server=self._server,
      sequence_length=None)
  self.assertEqual(replay.num_frames(), 0)
def train_eval(
    root_dir,
    env_name='Pong-v0',
    # Training params
    update_frequency=4,  # Number of collect steps per policy update
    initial_collect_steps=50000,  # 50k collect steps
    num_iterations=50000000,  # 50M collect steps
    # Taken from Rainbow as it's not specified in Mnih,15.
    max_episode_frames_collect=50000,  # env frames observed by the agent
    max_episode_frames_eval=108000,  # env frames observed by the agent
    # Agent params
    epsilon_greedy=0.1,
    epsilon_decay_period=250000,  # 1M collect steps / update_frequency
    batch_size=32,
    learning_rate=0.00025,
    n_step_update=1,
    gamma=0.99,
    target_update_tau=1.0,
    target_update_period=2500,  # 10k collect steps / update_frequency
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    policy_save_interval=250000,
    eval_interval=1000,
    eval_episodes=30,
    debug_summaries=True):
  """Trains and evaluates DQN."""
  collect_env = suite_atari.load(
      env_name,
      max_episode_steps=max_episode_frames_collect,
      gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
  eval_env = suite_atari.load(
      env_name,
      max_episode_steps=max_episode_frames_eval,
      gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)

  unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()

  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1
  epsilon = tf.compat.v1.train.polynomial_decay(
      1.0,
      train_step,
      epsilon_decay_period,
      end_learning_rate=epsilon_greedy)
  agent = dqn_agent.DqnAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      q_network=create_q_network(num_actions),
      epsilon_greedy=epsilon,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.compat.v1.train.RMSPropOptimizer(
          learning_rate=learning_rate,
          decay=0.95,
          momentum=0.95,
          epsilon=0.01,
          centered=True),
      td_errors_loss_fn=common.element_wise_huber_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      train_step_counter=train_step,
      debug_summaries=debug_summaries)

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)

  dataset = reverb_replay.as_dataset(
      sample_batch_size=batch_size, num_steps=2).prefetch(3)
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=100),
  ]

  dqn_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  # If we haven't trained yet, make sure we collect some random samples first
  # to fill up the Replay Buffer with some experience.
  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=update_frequency,
      observers=[rb_observer, env_step_metric],
      metrics=actor.collect_metrics(10),
      reference_metrics=[env_step_metric],
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
  )

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)
  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      reference_metrics=[env_step_metric],
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    dqn_learner.run(iterations=1)

    if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
def train(
    root_dir,
    strategy,
    replay_buffer_server_address,
    variable_container_server_address,
    create_agent_fn,
    create_env_fn,
    # Training params
    learning_rate=3e-4,
    batch_size=256,
    num_iterations=32000,
    learner_iterations_per_call=100):
  """Trains a SAC agent."""
  # Get the specs from the environment.
  logging.info('Training SAC with learning rate: %f', learning_rate)
  env = create_env_fn()
  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(env))

  # Create the agent.
  with strategy.scope():
    train_step = train_utils.create_train_step()
    agent = create_agent_fn(train_step, observation_tensor_spec,
                            action_tensor_spec, time_step_tensor_spec,
                            learning_rate)
    agent.initialize()

  # Create the policy saver which saves the initial model now, then it
  # periodically checkpoints the policy weights.
  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  save_model_trigger = triggers.PolicySavedModelTrigger(
      saved_model_dir, agent, train_step, interval=1000)

  # Create the variable container.
  variables = {
      reverb_variable_container.POLICY_KEY: agent.collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.push(variables)

  # Create the replay buffer.
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      server_address=replay_buffer_server_address)

  # Initialize the dataset.
  def experience_dataset_fn():
    with strategy.scope():
      return reverb_replay.as_dataset(
          sample_batch_size=batch_size, num_steps=2).prefetch(3)

  # Create the learner.
  learning_triggers = [
      save_model_trigger,
      triggers.StepPerSecondLogTrigger(train_step, interval=1000)
  ]
  sac_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers,
      strategy=strategy)

  # Run the training loop.
  # TODO(b/162440911) change the loop to use train_step to handle preemptions
  for _ in range(num_iterations):
    sac_learner.run(iterations=learner_iterations_per_call)
    variable_container.push(variables)
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    # Training params
    num_iterations=1600,
    actor_fc_layers=(64, 64),
    value_fc_layers=(64, 64),
    learning_rate=3e-4,
    collect_sequence_length=2048,
    minibatch_size=64,
    num_epochs=10,
    # Agent params
    importance_ratio_clipping=0.2,
    lambda_value=0.95,
    discount_factor=0.99,
    entropy_regularization=0.,
    value_pred_loss_coef=0.5,
    use_gae=True,
    use_td_lambda_return=True,
    gradient_clipping=0.5,
    value_clipping=None,
    # Replay params
    reverb_port=None,
    replay_capacity=10000,
    # Others
    policy_save_interval=5000,
    summary_interval=1000,
    eval_interval=10000,
    eval_episodes=100,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """Trains and evaluates PPO (Importance Ratio Clipping).

  Args:
    root_dir: Main directory path where checkpoints, saved_models, and
      summaries will be written to.
    env_name: Name of the Mujoco environment to load.
    num_iterations: The number of iterations to perform collection and
      training.
    actor_fc_layers: List of fully_connected parameters for the actor
      network, where each item is the number of units in the layer.
    value_fc_layers: List of fully_connected parameters for the value
      network, where each item is the number of units in the layer.
    learning_rate: Learning rate used on the Adam optimizer.
    collect_sequence_length: Number of steps to take in each collect run.
    minibatch_size: Number of elements in each mini batch. If `None`, the
      entire collected sequence will be treated as one batch.
    num_epochs: Number of iterations to repeat over all collected data per
      data collection step. (Schulman, 2017) sets this to 10 for Mujoco, 15
      for Roboschool and 3 for Atari.
    importance_ratio_clipping: Epsilon in clipped, surrogate PPO objective.
      For more detail, see explanation at the top of the doc.
    lambda_value: Lambda parameter for TD-lambda computation.
    discount_factor: Discount factor for return computation. Defaults to
      `0.99`, which is the value used for all environments from
      (Schulman, 2017).
    entropy_regularization: Coefficient for entropy regularization loss term.
      Defaults to `0.0` because no entropy bonus was used in
      (Schulman, 2017).
    value_pred_loss_coef: Multiplier for value prediction loss to balance
      with policy gradient loss. Defaults to `0.5`, which was used for all
      environments in the OpenAI baseline implementation. This parameter is
      irrelevant unless you are sharing part of actor_net and value_net. In
      that case, you would want to tune this coefficient, whose value depends
      on the network architecture of your choice.
    use_gae: If True (the default here), uses generalized advantage
      estimation for computing per-timestep advantage. Else, just subtracts
      value predictions from empirical return.
    use_td_lambda_return: If True (the default here), uses td_lambda_return
      for training value function; here:
      `td_lambda_return = gae_advantage + value_predictions`. `use_gae` must
      be set to `True` as well to enable TD-lambda returns. If
      `use_td_lambda_return` is set to True while `use_gae` is False, the
      empirical return will be used and a warning will be logged.
    gradient_clipping: Norm length to clip gradients.
    value_clipping: Difference between new and old value predictions are
      clipped to this threshold. Value clipping could be helpful when
      training very deep networks. Default: no clipping.
    reverb_port: Port for reverb server; if None, use a randomly chosen
      unused port.
    replay_capacity: The maximum number of elements for the replay buffer.
      Items will be wasted if this is smaller than collect_sequence_length.
    policy_save_interval: How often, in train_steps, the policy will be
      saved.
    summary_interval: How often to write data into Tensorboard.
    eval_interval: How often to run evaluation, in train_steps.
    eval_episodes: Number of episodes to evaluate over.
    debug_summaries: Boolean for whether to gather debug summaries.
    summarize_grads_and_vars: If true, gradient summaries will be written.
  """
  collect_env = suite_mujoco.load(env_name)
  eval_env = suite_mujoco.load(env_name)
  num_environments = 1

  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))
  # TODO(b/172267869): Remove this conversion once TensorNormalizer stops
  # converting float64 inputs to float32.
  observation_tensor_spec = tf.TensorSpec(
      dtype=tf.float32, shape=observation_tensor_spec.shape)

  train_step = train_utils.create_train_step()
  actor_net_builder = ppo_actor_network.PPOActorNetwork()
  actor_net = actor_net_builder.create_sequential_actor_net(
      actor_fc_layers, action_tensor_spec)
  value_net = value_network.ValueNetwork(
      observation_tensor_spec,
      fc_layer_params=value_fc_layers,
      kernel_initializer=tf.keras.initializers.Orthogonal())

  current_iteration = tf.Variable(0, dtype=tf.int64)

  def learning_rate_fn():
    # Linearly decay the learning rate.
    return learning_rate * (1 - current_iteration / num_iterations)

  agent = ppo_clip_agent.PPOClipAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      optimizer=tf.keras.optimizers.Adam(
          learning_rate=learning_rate_fn, epsilon=1e-5),
      actor_net=actor_net,
      value_net=value_net,
      importance_ratio_clipping=importance_ratio_clipping,
      lambda_value=lambda_value,
      discount_factor=discount_factor,
      entropy_regularization=entropy_regularization,
      value_pred_loss_coef=value_pred_loss_coef,
      # This is a legacy argument for the number of times we repeat the data
      # inside of the train function, incompatible with mini batch learning.
      # We set the epoch number from the replay buffer and tf.Data instead.
      num_epochs=1,
      use_gae=use_gae,
      use_td_lambda_return=use_td_lambda_return,
      gradient_clipping=gradient_clipping,
      value_clipping=value_clipping,
      # TODO(b/150244758): Default compute_value_and_advantage_in_train to
      # False after Reverb open source.
      compute_value_and_advantage_in_train=False,
      # Skips updating normalizers in the agent, as it's handled in the
      # learner.
      update_normalizers_in_train=False,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step)
  agent.initialize()

  reverb_server = reverb.Server(
      [
          reverb.Table(  # Replay buffer storing experience for training.
              name='training_table',
              sampler=reverb.selectors.Fifo(),
              remover=reverb.selectors.Fifo(),
              rate_limiter=reverb.rate_limiters.MinSize(1),
              max_size=replay_capacity,
              max_times_sampled=1,
          ),
          reverb.Table(  # Replay buffer storing experience for normalization.
              name='normalization_table',
              sampler=reverb.selectors.Fifo(),
              remover=reverb.selectors.Fifo(),
              rate_limiter=reverb.rate_limiters.MinSize(1),
              max_size=replay_capacity,
              max_times_sampled=1,
          )
      ],
      port=reverb_port)

  # Create the replay buffer.
  reverb_replay_train = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=collect_sequence_length,
      table_name='training_table',
      server_address='localhost:{}'.format(reverb_server.port),
      # The only collected sequence is used to populate the batches.
      max_cycle_length=1,
      rate_limiter_timeout_ms=1000)
  reverb_replay_normalization = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=collect_sequence_length,
      table_name='normalization_table',
      server_address='localhost:{}'.format(reverb_server.port),
      # The only collected sequence is used to populate the batches.
      max_cycle_length=1,
      rate_limiter_timeout_ms=1000)

  rb_observer = reverb_utils.ReverbTrajectorySequenceObserver(
      reverb_replay_train.py_client,
      ['training_table', 'normalization_table'],
      sequence_length=collect_sequence_length,
      stride_length=collect_sequence_length)

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  collect_env_step_metric = py_metrics.EnvironmentSteps()
  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={
              triggers.ENV_STEP_METADATA_KEY: collect_env_step_metric
          }),
      triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval),
  ]

  def training_dataset_fn():
    return reverb_replay_train.as_dataset(
        sample_batch_size=num_environments,
        sequence_preprocess_fn=agent.preprocess_sequence)

  def normalization_dataset_fn():
    return reverb_replay_normalization.as_dataset(
        sample_batch_size=num_environments,
        sequence_preprocess_fn=agent.preprocess_sequence)

  agent_learner = ppo_learner.PPOLearner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn=training_dataset_fn,
      normalization_dataset_fn=normalization_dataset_fn,
      num_samples=1,
      num_epochs=num_epochs,
      minibatch_size=minibatch_size,
      shuffle_buffer_size=collect_sequence_length,
      triggers=learning_triggers)

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=collect_sequence_length,
      observers=[rb_observer],
      metrics=actor.collect_metrics(buffer_size=10) +
      [collect_env_step_metric],
      reference_metrics=[collect_env_step_metric],
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      summary_interval=summary_interval)

  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      agent.policy, use_tf_function=True)

  if eval_interval:
    logging.info('Initial evaluation.')
    eval_actor = actor.Actor(
        eval_env,
        eval_greedy_policy,
        train_step,
        metrics=actor.eval_metrics(eval_episodes),
        reference_metrics=[collect_env_step_metric],
        summary_dir=os.path.join(root_dir, 'eval'),
        episodes_per_run=eval_episodes)

    eval_actor.run_and_log()

  logging.info('Training on %s', env_name)
  last_eval_step = 0
  for i in range(num_iterations):
    collect_actor.run()
    rb_observer.flush()
    agent_learner.run()
    reverb_replay_train.clear()
    reverb_replay_normalization.clear()
    current_iteration.assign_add(1)

    # Eval only if `eval_interval` has been set. Then, eval if the current
    # train step is equal or greater than the `last_eval_step` +
    # `eval_interval`, or if this is the last iteration. This logic exists
    # because agent_learner.run() does not return after every train step.
    if (eval_interval and
        (agent_learner.train_step_numpy >= eval_interval + last_eval_step
         or i == num_iterations - 1)):
      logging.info('Evaluating.')
      eval_actor.run_and_log()
      last_eval_step = agent_learner.train_step_numpy

  rb_observer.close()
  reverb_server.stop()
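# A pure-arithmetic check of the linear learning-rate decay defined above
# (my own sanity check, not part of the original script): with the default
# learning_rate=3e-4 and num_iterations=1600, the Adam learning rate starts
# at 3e-4, halves by iteration 800, and approaches 0 at the final iteration.
for it in (0, 800, 1599):
  print(it, 3e-4 * (1 - it / 1600))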
def train_eval(
    root_dir,
    env_name,
    # Training params
    train_sequence_length,
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    num_iterations=100000,
    # RNN params.
    q_network_fn=q_lstm_network,  # defaults to q_lstm_network.
    # Agent params
    epsilon_greedy=0.1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    target_update_tau=0.05,
    target_update_period=5,
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=100000,
    # Others
    policy_save_interval=1000,
    eval_interval=1000,
    eval_episodes=10):
  """Trains and evaluates DQN."""
  collect_env = suite_gym.load(env_name)
  eval_env = suite_gym.load(env_name)

  unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()
  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

  q_net = q_network_fn(num_actions=num_actions)

  sequence_length = train_sequence_length + 1
  agent = dqn_agent.DqnAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      q_network=q_net,
      epsilon_greedy=epsilon_greedy,
      # n-step updates aren't supported with RNNs yet.
      n_step_update=1,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
      td_errors_loss_fn=common.element_wise_squared_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      train_step_counter=train_step)

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=sequence_length,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=sequence_length,
      stride_length=1,
      pad_end_of_episodes=True)

  def experience_dataset_fn():
    return reverb_replay.as_dataset(
        sample_batch_size=batch_size, num_steps=sequence_length)

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=100),
  ]

  dqn_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  # If we haven't trained yet, make sure we collect some random samples first
  # to fill up the Replay Buffer with some experience.
  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=collect_steps_per_iteration,
      observers=[rb_observer, env_step_metric],
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
  )

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)
  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    dqn_learner.run(iterations=1)

    if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
table_name = 'uniform_table'
replay_buffer_signature = tensor_spec.from_spec(tf_agent.collect_data_spec)
replay_buffer_signature = tensor_spec.add_outer_dim(replay_buffer_signature)

table = reverb.Table(
    table_name,
    max_size=replay_buffer_capacity,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=replay_buffer_signature)

reverb_server = reverb.Server([table])

replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
    tf_agent.collect_data_spec,
    table_name=table_name,
    sequence_length=None,
    local_server=reverb_server)

rb_observer = reverb_utils.ReverbAddEpisodeObserver(
    replay_buffer.py_client, table_name, replay_buffer_capacity)


def collect_episode(environment, policy, num_episodes):
  driver = py_driver.PyDriver(
      environment,
      py_tf_eager_policy.PyTFEagerPolicy(policy, use_tf_function=True),
      [rb_observer],
      max_episodes=num_episodes)
  # The original snippet ends with the driver construction; running it from
  # the environment's initial time step completes the episode collection
  # (standard PyDriver usage).
  initial_time_step = environment.reset()
  driver.run(initial_time_step)
def train_agent(iterations, modeldir, logdir, policydir):
  """Train and convert the model using TF Agents."""
  train_py_env = planestrike_py_environment.PlaneStrikePyEnvironment(
      board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2)
  eval_py_env = planestrike_py_environment.PlaneStrikePyEnvironment(
      board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2)

  train_env = tf_py_environment.TFPyEnvironment(train_py_env)
  eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

  # Alternatively you could use ActorDistributionNetwork as actor_net.
  actor_net = tfa.networks.Sequential(
      [
          tfa.keras_layers.InnerReshape([BOARD_SIZE, BOARD_SIZE],
                                        [BOARD_SIZE**2]),
          tf.keras.layers.Dense(FC_LAYER_PARAMS, activation='relu'),
          tf.keras.layers.Dense(BOARD_SIZE**2),
          tf.keras.layers.Lambda(
              lambda t: tfp.distributions.Categorical(logits=t)),
      ],
      input_spec=train_py_env.observation_spec())

  optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

  train_step_counter = tf.Variable(0)

  tf_agent = reinforce_agent.ReinforceAgent(
      train_env.time_step_spec(),
      train_env.action_spec(),
      actor_network=actor_net,
      optimizer=optimizer,
      normalize_returns=True,
      train_step_counter=train_step_counter)

  tf_agent.initialize()

  eval_policy = tf_agent.policy
  collect_policy = tf_agent.collect_policy

  tf_policy_saver = policy_saver.PolicySaver(collect_policy)

  # Use reverb as replay buffer.
  replay_buffer_signature = tensor_spec.from_spec(tf_agent.collect_data_spec)
  table = reverb.Table(
      REPLAY_BUFFER_TABLE_NAME,
      max_size=REPLAY_BUFFER_CAPACITY,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1),
      signature=replay_buffer_signature,
  )  # specify signature here for validation at insertion time

  reverb_server = reverb.Server([table])
  replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
      tf_agent.collect_data_spec,
      sequence_length=None,
      table_name=REPLAY_BUFFER_TABLE_NAME,
      local_server=reverb_server)
  replay_buffer_observer = reverb_utils.ReverbAddEpisodeObserver(
      replay_buffer.py_client, REPLAY_BUFFER_TABLE_NAME,
      REPLAY_BUFFER_CAPACITY)

  # Optimize by wrapping some of the code in a graph using TF function.
  tf_agent.train = common.function(tf_agent.train)

  # Evaluate the agent's policy once before training.
  avg_return = compute_avg_return_and_steps(eval_env, tf_agent.policy,
                                            NUM_EVAL_EPISODES)

  summary_writer = tf.summary.create_file_writer(logdir)

  for i in range(iterations):
    # Collect a few episodes using collect_policy and save to the replay
    # buffer.
    collect_episode(train_py_env, collect_policy,
                    COLLECT_EPISODES_PER_ITERATION, replay_buffer_observer)

    # Use data from the buffer and update the agent's network.
    iterator = iter(replay_buffer.as_dataset(sample_batch_size=1))
    trajectories, _ = next(iterator)
    tf_agent.train(experience=trajectories)
    replay_buffer.clear()

    logger = tf.get_logger()
    if i % EVAL_INTERVAL == 0:
      avg_return, avg_episode_length = compute_avg_return_and_steps(
          eval_env, eval_policy, NUM_EVAL_EPISODES)
      with summary_writer.as_default():
        tf.summary.scalar('Average return', avg_return, step=i)
        tf.summary.scalar('Average episode length', avg_episode_length,
                          step=i)
        summary_writer.flush()
      logger.info(
          'iteration = {0}: Average Return = {1}, Average Episode Length = {2}'
          .format(i, avg_return, avg_episode_length))

  summary_writer.close()

  tf_policy_saver.save(policydir)

  # Convert to a tflite model.
  converter = tf.lite.TFLiteConverter.from_saved_model(
      policydir, signature_keys=['action'])
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
      tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
  ]
  tflite_policy = converter.convert()
  with open(os.path.join(modeldir, 'planestrike_tf_agents.tflite'), 'wb') as f:
    f.write(tflite_policy)
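# A minimal sketch (my assumption, not from the original script) of loading
# the converted policy back with the TFLite interpreter; the model path is a
# placeholder mirroring the file written above. Invoking the 'action'
# signature requires the exact time_step input names, so this only obtains
# the signature runner.
interpreter = tf.lite.Interpreter(
    model_path='/tmp/model/planestrike_tf_agents.tflite')  # placeholder path
interpreter.allocate_tensors()
print(interpreter.get_signature_list())  # expect an 'action' signature
action_runner = interpreter.get_signature_runner('action')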
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    # Training params
    initial_collect_steps=10000,
    num_iterations=3200000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Agent params
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    gamma=0.99,
    target_update_tau=0.005,
    target_update_period=1,
    reward_scale_factor=0.1,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    # Defaults to not checkpointing saved policy. If you wish to enable this,
    # please note the caveat explained in README.md.
    policy_save_interval=-1,
    eval_interval=10000,
    eval_episodes=30,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """Trains and evaluates SAC."""
  logging.info('Training SAC on: %s', env_name)
  collect_env = suite_mujoco.load(env_name)
  eval_env = suite_mujoco.load(env_name)

  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=actor_fc_layers,
      continuous_projection_net=(
          tanh_normal_projection_network.TanhNormalProjectionNetwork))
  critic_net = critic_network.CriticNetwork(
      (observation_tensor_spec, action_tensor_spec),
      observation_fc_layer_params=critic_obs_fc_layers,
      action_fc_layer_params=critic_action_fc_layers,
      joint_fc_layer_params=critic_joint_fc_layers,
      kernel_initializer='glorot_uniform',
      last_kernel_initializer='glorot_uniform')

  agent = sac_agent.SacAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=actor_learning_rate),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=critic_learning_rate),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=alpha_learning_rate),
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      gradient_clipping=None,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step)
  agent.initialize()

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)

  dataset = reverb_replay.as_dataset(
      sample_batch_size=batch_size, num_steps=2).prefetch(50)
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=1000),
  ]

  agent_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      observers=[rb_observer, env_step_metric])

  tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)
  eval_actor = actor.Actor(
      eval_env,
      eval_greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    agent_learner.run(iterations=1)

    if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
def train_eval(
    root_dir,
    strategy: tf.distribute.Strategy,
    env_name='HalfCheetah-v2',
    # Training params
    initial_collect_steps=10000,
    num_iterations=3200000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Agent params
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    gamma=0.99,
    target_update_tau=0.005,
    target_update_period=1,
    reward_scale_factor=0.1,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    policy_save_interval=10000,
    replay_buffer_save_interval=100000,
    eval_interval=10000,
    eval_episodes=30,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """Trains and evaluates SAC."""
  logging.info('Training SAC on: %s', env_name)
  collect_env = suite_mujoco.load(env_name)
  eval_env = suite_mujoco.load(env_name)

  _, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  actor_net = create_sequential_actor_network(
      actor_fc_layers=actor_fc_layers, action_tensor_spec=action_tensor_spec)

  critic_net = create_sequential_critic_network(
      obs_fc_layer_units=critic_obs_fc_layers,
      action_fc_layer_units=critic_action_fc_layers,
      joint_fc_layer_units=critic_joint_fc_layers)

  with strategy.scope():
    train_step = train_utils.create_train_step()
    agent = sac_agent.SacAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.keras.optimizers.Adam(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.keras.optimizers.Adam(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.keras.optimizers.Adam(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step)
    agent.initialize()

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))

  reverb_checkpoint_dir = os.path.join(root_dir, learner.TRAIN_DIR,
                                       learner.REPLAY_BUFFER_CHECKPOINT_DIR)
  reverb_checkpointer = reverb.platform.checkpointers_lib.DefaultCheckpointer(
      path=reverb_checkpoint_dir)
  reverb_server = reverb.Server(
      [table], port=reverb_port, checkpointer=reverb_checkpointer)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)

  def experience_dataset_fn():
    return reverb_replay.as_dataset(
        sample_batch_size=batch_size, num_steps=2).prefetch(50)

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.ReverbCheckpointTrigger(
          train_step,
          interval=replay_buffer_save_interval,
          reverb_client=reverb_replay.py_client),
      # TODO(b/165023684): Add SIGTERM handler to checkpoint before
      # preemption.
      triggers.StepPerSecondLogTrigger(train_step, interval=1000),
  ]

  agent_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers,
      strategy=strategy)

  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      observers=[rb_observer, env_step_metric])

  tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)
  eval_actor = actor.Actor(
      eval_env,
      eval_greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    agent_learner.run(iterations=1)

    if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()