def __init__(self, data_spec, batch_size=1, n_steps=2):
  self.data_spec = data_spec
  self.batch_size = batch_size
  self.n_steps = n_steps
  self.replay_buffer_capacity = 100000
  self.name = 'PER'
  self.server = reverb.Server([
      reverb.Table(
          name=self.name,
          max_size=self.replay_buffer_capacity,
          sampler=reverb.selectors.Prioritized(0.8),
          remover=reverb.selectors.Fifo(),
          rate_limiter=reverb.rate_limiters.MinSize(1))
  ])
  self.buffer = reverb_replay_buffer.ReverbReplayBuffer(
      data_spec=self.data_spec,
      sequence_length=self.n_steps,
      table_name=self.name,
      local_server=self.server)
  self.writer = reverb_utils.ReverbAddTrajectoryObserver(
      self.buffer.py_client,
      table_name=self.name,
      sequence_length=self.n_steps,
      stride_length=1,
      priority=1,
  )
  self.states = None
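# Sampling with a Prioritized selector is only half of PER: new priorities
# must be written back after TD errors are computed. A minimal sketch,
# assuming `per` is an instance of the class the __init__ above belongs to,
# and `td_error_for(key)` is a placeholder for the agent's real TD-error
# lookup (both names are illustrative, not part of the original code).
import numpy as np

client = per.buffer.py_client
for sample in client.sample(per.name, num_samples=4):
  key = sample[0].info.key  # Every step of a sequence item shares one key.
  new_priority = float(np.abs(td_error_for(key))) + 1e-6  # Avoid zero priority.
  client.mutate_priorities(per.name, updates={key: new_priority})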
def get_reverb_buffer_and_observer(data_spec,
                                   sequence_length=None,
                                   table_name='uniform_table',
                                   table=None,
                                   reverb_server_address=None,
                                   port=None,
                                   replay_capacity=1000,
                                   min_size_limiter_size=1,
                                   stride_length=1):
  """Returns an instance of Reverb replay buffer and observer to add items.

  Either creates a local Reverb server or uses a remote Reverb server at
  `reverb_server_address` (if set). If `reverb_server_address` is None,
  creates a local server with a uniform table underneath.

  Args:
    data_spec: Spec of the data elements to be stored in the replay buffer.
    sequence_length: Integer specifying the sequence lengths used to write to
      the given table.
    table_name: Name of the uniform table to create.
    table: Optional table for the backing local server. If None, automatically
      creates a uniform sampling table.
    reverb_server_address: Address of the remote Reverb server; if None, a
      local server is created.
    port: Port to launch the server on.
    replay_capacity: Capacity of the local replay server, if used (i.e. if
      `reverb_server_address` is None).
    min_size_limiter_size: Minimum number of items required in the replay
      buffer before sampling can begin; used for the local server only.
    stride_length: Integer stride of the sliding window for overlapping
      sequences.

  Returns:
    A tuple consisting of:
      - A Reverb replay buffer instance.
      - A replay buffer observer.

  Note: if a local server is created, it is not returned. It can be retrieved
  by calling `local_server()` on the returned replay buffer.
  """
  reverb_replay = get_reverb_buffer(
      data_spec=data_spec,
      sequence_length=sequence_length,
      table_name=table_name,
      table=table,
      reverb_server_address=reverb_server_address,
      port=port,
      replay_capacity=replay_capacity,
      min_size_limiter_size=min_size_limiter_size)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=sequence_length,
      stride_length=stride_length)
  return reverb_replay, rb_observer
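# A minimal calling sketch for the helper above, assuming `agent`, `env`, and
# `collect_policy` already exist (they are placeholders, not defined here);
# the point is only how the helper's outputs plug into a driver and a dataset.
reverb_replay, rb_observer = get_reverb_buffer_and_observer(
    data_spec=agent.collect_data_spec,
    sequence_length=2,
    table_name='uniform_table',
    replay_capacity=10000)

driver = py_driver.PyDriver(
    env, collect_policy, observers=[rb_observer], max_steps=100)
driver.run(env.reset())
dataset = reverb_replay.as_dataset(sample_batch_size=64, num_steps=2)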
def collect(summary_dir: Text,
            environment_name: Text,
            collect_policy: py_tf_eager_policy.PyTFEagerPolicyBase,
            replay_buffer_server_address: Text,
            variable_container_server_address: Text,
            suite_load_fn: Callable[
                [Text], py_environment.PyEnvironment] = suite_mujoco.load,
            initial_collect_steps: int = 10000,
            max_train_steps: int = 2000000) -> None:
  """Collects experience using a policy updated after every episode."""
  # Create the environment. For now support only single environment collection.
  collect_env = suite_load_fn(environment_name)

  # Create the variable container.
  train_step = train_utils.create_train_step()
  variables = {
      reverb_variable_container.POLICY_KEY: collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.update(variables)

  # Create the replay buffer observer.
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb.Client(replay_buffer_server_address),
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      sequence_length=2,
      stride_length=1)

  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  env_step_metric = py_metrics.EnvironmentSteps()
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=summary_dir,
      observers=[rb_observer, env_step_metric])

  # Run the experience collection loop.
  while train_step.numpy() < max_train_steps:
    logging.info('Collecting with policy at step: %d', train_step.numpy())
    collect_actor.run()
    variable_container.update(variables)
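# The loop above only pulls variables; it assumes a learner process elsewhere
# pushes fresh weights into the same container. A one-line sketch of that
# counterpart (the learner-side `variables` dict mirrors the one built here):
variable_container.push(variables)  # Run on the learner after each train step.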
def _insert_random_data(self,
                        env,
                        num_steps,
                        sequence_length=2,
                        additional_observers=None):
  """Inserts `num_steps` random observations into the Reverb server."""
  observers = [] if additional_observers is None else additional_observers
  traj_obs = reverb_utils.ReverbAddTrajectoryObserver(
      self._py_client, self._table_name, sequence_length=sequence_length)
  observers.append(traj_obs)
  policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                           env.action_spec())
  driver = py_driver.PyDriver(env, policy, observers=observers,
                              max_steps=num_steps)
  time_step = env.reset()
  driver.run(time_step)
  traj_obs.close()
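# The same seeding pattern works outside a test fixture. A self-contained
# sketch with an in-process server; the table name, capacity, and CartPole
# environment are illustrative choices, not from the original.
import reverb
from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.policies import random_py_policy
from tf_agents.replay_buffers import reverb_utils

env = suite_gym.load('CartPole-v0')
server = reverb.Server([
    reverb.Table(
        'uniform_table',
        max_size=1000,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1))
])
client = reverb.Client(f'localhost:{server.port}')
observer = reverb_utils.ReverbAddTrajectoryObserver(
    client, 'uniform_table', sequence_length=2)
policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                         env.action_spec())
driver = py_driver.PyDriver(env, policy, observers=[observer], max_steps=100)
driver.run(env.reset())
observer.close()  # Flushes any partially written sequences.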
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    # Training params
    initial_collect_steps=10000,
    num_iterations=3200000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Agent params
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    gamma=0.99,
    target_update_tau=0.005,
    target_update_period=1,
    reward_scale_factor=0.1,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    # Defaults to not checkpointing saved policy. If you wish to enable this,
    # please note the caveat explained in README.md.
    policy_save_interval=-1,
    eval_interval=10000,
    eval_episodes=30,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """Trains and evaluates SAC."""
  logging.info('Training SAC on: %s', env_name)
  collect_env = suite_mujoco.load(env_name)
  eval_env = suite_mujoco.load(env_name)

  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=actor_fc_layers,
      continuous_projection_net=tanh_normal_projection_network
      .TanhNormalProjectionNetwork)
  critic_net = critic_network.CriticNetwork(
      (observation_tensor_spec, action_tensor_spec),
      observation_fc_layer_params=critic_obs_fc_layers,
      action_fc_layer_params=critic_action_fc_layers,
      joint_fc_layer_params=critic_joint_fc_layers,
      kernel_initializer='glorot_uniform',
      last_kernel_initializer='glorot_uniform')

  agent = sac_agent.SacAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=actor_learning_rate),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=critic_learning_rate),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=alpha_learning_rate),
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      gradient_clipping=None,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step)
  agent.initialize()

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))

  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)

  dataset = reverb_replay.as_dataset(
      sample_batch_size=batch_size, num_steps=2).prefetch(50)
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=1000),
  ]

  agent_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      observers=[rb_observer, env_step_metric])

  tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      eval_greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    agent_learner.run(iterations=1)

    if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
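# Quick sanity check of the pipeline built inside train_eval above (sketch,
# assuming the table already holds at least one item). `as_dataset` yields
# (trajectory, sample_info) pairs; trajectory tensors are shaped
# [sample_batch_size, num_steps, ...].
experience, sample_info = next(iter(dataset))
print(experience.observation.shape)  # e.g. (256, 2, obs_dim)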
def _create_and_yield(client):
  yield reverb_utils.ReverbAddTrajectoryObserver(client, *args, **kwargs)
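# The closure above references *args/**kwargs from an enclosing scope, so it
# is presumably wrapped as a context manager. One plausible enclosing pattern
# (assumed, including the close() in the finally block, which the original
# does not show):
import contextlib

def trajectory_observer_ctx(client, *args, **kwargs):
  """Yields a ReverbAddTrajectoryObserver, closing it on exit (sketch)."""

  @contextlib.contextmanager
  def _create_and_yield(client):
    observer = reverb_utils.ReverbAddTrajectoryObserver(client, *args, **kwargs)
    try:
      yield observer
    finally:
      observer.close()  # Flush partially written sequences.

  return _create_and_yield(client)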
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    # Training params
    initial_collect_steps=1000,
    num_iterations=100000,
    fc_layer_params=(100,),
    # Agent params
    epsilon_greedy=0.1,
    batch_size=64,
    learning_rate=1e-3,
    n_step_update=1,
    gamma=0.99,
    target_update_tau=0.05,
    target_update_period=5,
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=100000,
    # Others
    policy_save_interval=1000,
    eval_interval=1000,
    eval_episodes=10):
  """Trains and evaluates DQN."""

  collect_env = suite_gym.load(env_name)
  eval_env = suite_gym.load(env_name)

  time_step_tensor_spec = tensor_spec.from_spec(collect_env.time_step_spec())
  action_tensor_spec = tensor_spec.from_spec(collect_env.action_spec())

  train_step = train_utils.create_train_step()
  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

  # Define a helper function to create Dense layers configured with the right
  # activation and kernel initializer.
  def dense_layer(num_units):
    return tf.keras.layers.Dense(
        num_units,
        activation=tf.keras.activations.relu,
        kernel_initializer=tf.keras.initializers.VarianceScaling(
            scale=2.0, mode='fan_in', distribution='truncated_normal'))

  # The QNetwork consists of a sequence of Dense layers followed by a dense
  # layer with `num_actions` units to generate one q_value per available
  # action as its output.
  dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
  q_values_layer = tf.keras.layers.Dense(
      num_actions,
      activation=None,
      kernel_initializer=tf.keras.initializers.RandomUniform(
          minval=-0.03, maxval=0.03),
      bias_initializer=tf.keras.initializers.Constant(-0.2))
  q_net = sequential.Sequential(dense_layers + [q_values_layer])

  agent = dqn_agent.DqnAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      q_network=q_net,
      epsilon_greedy=epsilon_greedy,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
      td_errors_loss_fn=common.element_wise_squared_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      train_step_counter=train_step)

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)

  dataset = reverb_replay.as_dataset(
      num_parallel_calls=3, sample_batch_size=batch_size,
      num_steps=2).prefetch(3)
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=100),
  ]

  dqn_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  # If we haven't trained yet, make sure we collect some random samples first
  # to fill up the replay buffer with some experience.
  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      observers=[rb_observer, env_step_metric],
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
  )

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    dqn_learner.run(iterations=1)

    if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
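# The sequence_length=2 above is tied to n_step_update=1: an n-step TD target
# needs n + 1 consecutive time steps in each sampled sequence. Illustrative
# values (not part of the original) showing how the two move together:
n_step_update = 3
sequence_length = n_step_update + 1  # Pass to the buffer, observer, and dataset.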
def train_eval(
    root_dir,
    env_name,
    # Training params
    train_sequence_length,
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    num_iterations=100000,
    # RNN params.
    q_network_fn=q_lstm_network,  # defaults to q_lstm_network.
    # Agent params
    epsilon_greedy=0.1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    target_update_tau=0.05,
    target_update_period=5,
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=100000,
    # Others
    policy_save_interval=1000,
    eval_interval=1000,
    eval_episodes=10):
  """Trains and evaluates DQN."""

  collect_env = suite_gym.load(env_name)
  eval_env = suite_gym.load(env_name)

  unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()
  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

  q_net = q_network_fn(num_actions=num_actions)

  sequence_length = train_sequence_length + 1
  agent = dqn_agent.DqnAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      q_network=q_net,
      epsilon_greedy=epsilon_greedy,
      # n-step updates aren't supported with RNNs yet.
      n_step_update=1,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
      td_errors_loss_fn=common.element_wise_squared_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      train_step_counter=train_step)

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=sequence_length,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=sequence_length,
      stride_length=1,
      pad_end_of_episodes=True)

  def experience_dataset_fn():
    return reverb_replay.as_dataset(
        sample_batch_size=batch_size, num_steps=sequence_length)

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=100),
  ]

  dqn_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  # If we haven't trained yet, make sure we collect some random samples first
  # to fill up the replay buffer with some experience.
  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=collect_steps_per_iteration,
      observers=[rb_observer, env_step_metric],
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
  )

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    dqn_learner.run(iterations=1)

    if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
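# The default `q_network_fn` above is not shown. One plausible shape for it,
# assuming TF-Agents' sequential.Sequential support for stateful Keras RNN
# layers; all layer sizes here are illustrative, not from the original.
def q_lstm_network(num_actions):
  """Sketch of an LSTM Q-network compatible with DqnAgent (assumed sizes)."""
  return sequential.Sequential([
      tf.keras.layers.Dense(64, activation='relu'),
      # return_state/return_sequences let Sequential thread the RNN state.
      tf.keras.layers.LSTM(64, return_state=True, return_sequences=True),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(num_actions),
  ])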
def train_eval(
    root_dir,
    env_name='Pong-v0',
    # Training params
    update_frequency=4,  # Number of collect steps per policy update
    initial_collect_steps=50000,  # 50k collect steps
    num_iterations=50000000,  # 50M collect steps
    # Taken from Rainbow as it's not specified in Mnih et al., 2015.
    max_episode_frames_collect=50000,  # env frames observed by the agent
    max_episode_frames_eval=108000,  # env frames observed by the agent
    # Agent params
    epsilon_greedy=0.1,
    epsilon_decay_period=250000,  # 1M collect steps / update_frequency
    batch_size=32,
    learning_rate=0.00025,
    n_step_update=1,
    gamma=0.99,
    target_update_tau=1.0,
    target_update_period=2500,  # 10k collect steps / update_frequency
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    policy_save_interval=250000,
    eval_interval=1000,
    eval_episodes=30,
    debug_summaries=True):
  """Trains and evaluates DQN."""

  collect_env = suite_atari.load(
      env_name,
      max_episode_steps=max_episode_frames_collect,
      gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
  eval_env = suite_atari.load(
      env_name,
      max_episode_steps=max_episode_frames_eval,
      gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)

  unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()
  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

  epsilon = tf.compat.v1.train.polynomial_decay(
      1.0,
      train_step,
      epsilon_decay_period,
      end_learning_rate=epsilon_greedy)
  agent = dqn_agent.DqnAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      q_network=create_q_network(num_actions),
      epsilon_greedy=epsilon,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.compat.v1.train.RMSPropOptimizer(
          learning_rate=learning_rate,
          decay=0.95,
          momentum=0.95,
          epsilon=0.01,
          centered=True),
      td_errors_loss_fn=common.element_wise_huber_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      train_step_counter=train_step,
      debug_summaries=debug_summaries)

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)

  dataset = reverb_replay.as_dataset(
      sample_batch_size=batch_size, num_steps=2).prefetch(3)
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=100),
  ]

  dqn_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  # If we haven't trained yet, make sure we collect some random samples first
  # to fill up the replay buffer with some experience.
  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=update_frequency,
      observers=[rb_observer, env_step_metric],
      metrics=actor.collect_metrics(10),
      reference_metrics=[env_step_metric],
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
  )

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      reference_metrics=[env_step_metric],
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    dqn_learner.run(iterations=1)

    if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
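# `create_q_network` is referenced above but not defined here. A plausible
# sketch following the Mnih et al. (2015) architecture (assumed, since the
# original definition is not shown): three conv layers, a 512-unit dense
# layer, and a linear head with one output per action.
def create_q_network(num_actions):
  return sequential.Sequential([
      tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32) / 255.0),
      tf.keras.layers.Conv2D(32, 8, 4, activation='relu'),
      tf.keras.layers.Conv2D(64, 4, 2, activation='relu'),
      tf.keras.layers.Conv2D(64, 3, 1, activation='relu'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation='relu'),
      tf.keras.layers.Dense(num_actions),
  ])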
dataset = reverb_replay.as_dataset(
    sample_batch_size=HyperParms.batch_size, num_steps=2).prefetch(50)
experience_dataset_fn = lambda: dataset

print(f" -- POLICIES ({now()}) -- ")
tf_eval_policy = tf_agent.policy
eval_policy = py_tf_eager_policy.PyTFEagerPolicy(
    tf_eval_policy, use_tf_function=True)
tf_collect_policy = tf_agent.collect_policy
collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
    tf_collect_policy, use_tf_function=True)
random_policy = random_py_policy.RandomPyPolicy(
    collect_env.time_step_spec(), collect_env.action_spec())

print(f" -- ACTORS ({now()}) -- ")
rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
    reverb_replay.py_client, table_name, sequence_length=2, stride_length=1)
initial_collect_actor = actor.Actor(
    collect_env,
    random_policy,
    train_step,
    steps_per_run=HyperParms.initial_collect_steps,
    observers=[rb_observer])
initial_collect_actor.run()

env_step_metric = py_metrics.EnvironmentSteps()
collect_actor = actor.Actor(
    collect_env,
    collect_policy,
    train_step,
    steps_per_run=1,
    observers=[rb_observer, env_step_metric])
def collect(task,
            root_dir,
            replay_buffer_server_address,
            variable_container_server_address,
            create_env_fn,
            initial_collect_steps=10000,
            num_iterations=10000000):
  """Collects experience using a policy updated after every episode."""
  # Create the environment. For now support only single environment collection.
  collect_env = create_env_fn()

  # Create the path for the serialized collect policy.
  collect_policy_saved_model_path = os.path.join(
      root_dir, learner.POLICY_SAVED_MODEL_DIR,
      learner.COLLECT_POLICY_SAVED_MODEL_DIR)
  saved_model_pb_path = os.path.join(collect_policy_saved_model_path,
                                     'saved_model.pb')
  try:
    # Wait for the collect policy to be output by the learner (timeout after 2
    # days), then load it.
    train_utils.wait_for_file(
        saved_model_pb_path, sleep_time_secs=2, num_retries=86400)
    collect_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
        collect_policy_saved_model_path, load_specs_from_pbtxt=True)
  except TimeoutError as e:
    # If the collect policy does not become available during the wait time of
    # the call `wait_for_file`, that probably means the learner is not running.
    logging.error('Could not get the file %s. Exiting.', saved_model_pb_path)
    raise e

  # Create the variable container.
  train_step = train_utils.create_train_step()
  variables = {
      reverb_variable_container.POLICY_KEY: collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.update(variables)

  # Create the replay buffer observer.
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb.Client(replay_buffer_server_address),
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      sequence_length=2,
      stride_length=1)

  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  env_step_metric = py_metrics.EnvironmentSteps()
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR, str(task)),
      observers=[rb_observer, env_step_metric])

  # Run the experience collection loop.
  for _ in range(num_iterations):
    logging.info('Collecting with policy at step: %d', train_step.numpy())
    collect_actor.run()
    variable_container.update(variables)
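# Illustrative launch of the collector above for one worker; the addresses,
# task id, and environment are placeholders, not values from the original.
collect(
    task=0,
    root_dir='/tmp/distributed_sac',
    replay_buffer_server_address='localhost:8008',
    variable_container_server_address='localhost:8009',
    create_env_fn=lambda: suite_mujoco.load('HalfCheetah-v2'))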
def train_eval(
    root_dir,
    strategy: tf.distribute.Strategy,
    env_name='HalfCheetah-v2',
    # Training params
    initial_collect_steps=10000,
    num_iterations=3200000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Agent params
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    gamma=0.99,
    target_update_tau=0.005,
    target_update_period=1,
    reward_scale_factor=0.1,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    policy_save_interval=10000,
    replay_buffer_save_interval=100000,
    eval_interval=10000,
    eval_episodes=30,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """Trains and evaluates SAC."""
  logging.info('Training SAC on: %s', env_name)
  collect_env = suite_mujoco.load(env_name)
  eval_env = suite_mujoco.load(env_name)

  _, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  actor_net = create_sequential_actor_network(
      actor_fc_layers=actor_fc_layers, action_tensor_spec=action_tensor_spec)

  critic_net = create_sequential_critic_network(
      obs_fc_layer_units=critic_obs_fc_layers,
      action_fc_layer_units=critic_action_fc_layers,
      joint_fc_layer_units=critic_joint_fc_layers)

  with strategy.scope():
    train_step = train_utils.create_train_step()
    agent = sac_agent.SacAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.keras.optimizers.Adam(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.keras.optimizers.Adam(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.keras.optimizers.Adam(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step)
    agent.initialize()

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))

  reverb_checkpoint_dir = os.path.join(root_dir, learner.TRAIN_DIR,
                                       learner.REPLAY_BUFFER_CHECKPOINT_DIR)
  reverb_checkpointer = reverb.platform.checkpointers_lib.DefaultCheckpointer(
      path=reverb_checkpoint_dir)
  reverb_server = reverb.Server([table],
                                port=reverb_port,
                                checkpointer=reverb_checkpointer)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)

  def experience_dataset_fn():
    return reverb_replay.as_dataset(
        sample_batch_size=batch_size, num_steps=2).prefetch(50)

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.ReverbCheckpointTrigger(
          train_step,
          interval=replay_buffer_save_interval,
          reverb_client=reverb_replay.py_client),
      # TODO(b/165023684): Add SIGTERM handler to checkpoint before preemption.
      triggers.StepPerSecondLogTrigger(train_step, interval=1000),
  ]

  agent_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers,
      strategy=strategy)

  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      observers=[rb_observer, env_step_metric])

  tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      eval_greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    agent_learner.run(iterations=1)

    if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
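# Illustrative entry point for the distributed variant above. The strategy
# choice is up to the caller; using strategy_utils.get_strategy here is an
# assumption, and the root_dir is a placeholder.
from tf_agents.train.utils import strategy_utils

strategy = strategy_utils.get_strategy(tpu=False, use_gpu=True)
train_eval('/tmp/sac_distributed', strategy=strategy)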