def _build_learner_with_strategy(self,
                                 create_agent_and_dataset_fn,
                                 strategy,
                                 sample_batch_size=2):
  if strategy is None:
    # Get default strategy if None provided.
    strategy = tf.distribute.get_strategy()

  with strategy.scope():
    tf_env = tf_py_environment.TFPyEnvironment(
        suite_gym.load('CartPole-v0'))

    train_step = train_utils.create_train_step()
    agent, dataset, dataset_fn, _ = create_agent_and_dataset_fn(
        tf_env.time_step_spec().observation, tf_env.action_spec(),
        tf_env.time_step_spec(), train_step, sample_batch_size)

    root_dir = os.path.join(self.create_tempdir().full_path, 'learner')
    test_learner = learner.Learner(
        root_dir=root_dir,
        train_step=train_step,
        agent=agent,
        experience_dataset_fn=dataset_fn)
    variables = agent.collect_policy.variables()
  return test_learner, dataset, variables, train_step
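A sketch of how this helper might be exercised in a test method. The `build_dqn_agent_and_dataset` factory is hypothetical; any function matching the `create_agent_and_dataset_fn` signature above would work, and the assertion is illustrative only.

def test_learner_trains_under_default_strategy(self):
  # `build_dqn_agent_and_dataset` is a hypothetical factory returning
  # (agent, dataset, dataset_fn, _) as expected by the helper above.
  test_learner, _, _, train_step = self._build_learner_with_strategy(
      build_dqn_agent_and_dataset, strategy=None, sample_batch_size=2)
  initial_step = train_step.numpy()
  test_learner.run(iterations=1)
  self.assertGreater(train_step.numpy(), initial_step)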
print(f" -- LEARNERS ({now()}) -- ") saved_model_dir = os.path.join(tempdir, learner.POLICY_SAVED_MODEL_DIR) # Triggers to save the agent's policy checkpoints. learning_triggers = [ triggers.PolicySavedModelTrigger(saved_model_dir, tf_agent, train_step, interval=HyperParms.policy_save_interval), triggers.StepPerSecondLogTrigger(train_step, interval=1000), ] agent_learner = learner.Learner(tempdir, train_step, tf_agent, experience_dataset_fn, triggers=learning_triggers) print(f" -- METRICS AND EVALUATION ({now()}) -- ") def get_eval_metrics(): eval_actor.run() results = {} for metric in eval_actor.metrics: results[metric.name] = metric.result() return results metrics = get_eval_metrics()
def __init__(self,
             root_dir,
             train_step,
             agent,
             max_num_sequences=None,
             minibatch_size=None,
             shuffle_buffer_size=None,
             after_train_strategy_step_fn=None,
             triggers=None,
             checkpoint_interval=100000,
             summary_interval=1000,
             use_kwargs_in_agent_train=False,
             strategy=None):
  """Initializes a PPOLearner instance.

  Args:
    root_dir: Main directory path where checkpoints, saved_models, and
      summaries will be written to.
    train_step: A scalar tf.int64 `tf.Variable` which will keep track of the
      number of train steps. This is used for artifacts created like
      summaries, or outputs in the root_dir.
    agent: `tf_agent.TFAgent` instance to train with.
    max_num_sequences: The maximum number of sequences to read from the input
      dataset in `run`. Defaults to None, in which case `run` terminates when
      the end of the dataset is reached (for instance when the rate limiter
      times out).
    minibatch_size: The minibatch size. The dataset used for training is
      shaped `[minibatch_size, 1, ...]`.
    shuffle_buffer_size: The buffer size for shuffling the trajectories
      before splitting them into mini batches. Only required when mini batch
      learning is enabled (minibatch_size is set). Otherwise it is ignored.
      Commonly set to a number 1-3x the episode length of your environment.
    after_train_strategy_step_fn: (Optional) callable of the form
      `fn(sample, loss)` which can be used, for example, to update priorities
      in a replay buffer, where sample is pulled from the
      `experience_iterator` and loss is a `LossInfo` named tuple returned
      from the agent. This is called after every train step. It runs using
      `strategy.run(...)`.
    triggers: List of callables of the form `trigger(train_step)`. After
      every `run` call every trigger is called with the current `train_step`
      value as an np scalar.
    checkpoint_interval: Number of train steps in between checkpoints. Note
      these are placed into triggers and so a check to generate a checkpoint
      only occurs after every `run` call. Set to -1 to disable. This only
      takes care of checkpointing the training process; policies must be
      explicitly exported through triggers.
    summary_interval: Number of train steps in between summaries. Note these
      are placed into triggers and so a check to generate a summary only
      occurs after every `run` call.
    use_kwargs_in_agent_train: If True the experience from the replay buffer
      is passed into the agent as kwargs. This requires samples from the RB
      to be of the form `dict(experience=experience, kwarg1=kwarg1, ...)`.
      This is useful if you have an agent with a custom argspec.
    strategy: (Optional) `tf.distribute.Strategy` to use during training.
  """
  if minibatch_size is not None and shuffle_buffer_size is None:
    raise ValueError(
        'shuffle_buffer_size must be provided if minibatch_size is not None.')

  if agent.update_normalizers_in_train:
    raise ValueError(
        'agent.update_normalizers_in_train should be set to False when '
        'PPOLearner is used.')

  self._agent = agent
  self._max_num_sequences = max_num_sequences
  self._minibatch_size = minibatch_size
  self._shuffle_buffer_size = shuffle_buffer_size

  self._generic_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn=None,
      after_train_strategy_step_fn=after_train_strategy_step_fn,
      triggers=triggers,
      checkpoint_interval=checkpoint_interval,
      summary_interval=summary_interval,
      use_kwargs_in_agent_train=use_kwargs_in_agent_train,
      strategy=strategy)
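For illustration, a hypothetical construction of this learner using only the arguments documented above; `ppo_tf_agent`, `policy_save_trigger`, `root_dir`, and `train_step` are placeholders assumed to be created elsewhere, and the agent must have `update_normalizers_in_train=False`.

# All names below are placeholders; values are illustrative only.
ppo_learner_instance = PPOLearner(
    root_dir=os.path.join(root_dir, 'learner'),
    train_step=train_step,
    agent=ppo_tf_agent,  # constructed with update_normalizers_in_train=False
    minibatch_size=64,
    shuffle_buffer_size=1024,  # roughly 1-3x the episode length
    triggers=[policy_save_trigger],
    strategy=tf.distribute.get_strategy())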
def train(
    root_dir,
    strategy,
    replay_buffer_server_address,
    variable_container_server_address,
    create_agent_fn,
    create_env_fn,
    # Training params
    learning_rate=3e-4,
    batch_size=256,
    num_iterations=32000,
    learner_iterations_per_call=100):
  """Trains a SAC agent."""
  # Get the specs from the environment.
  logging.info('Training SAC with learning rate: %f', learning_rate)
  env = create_env_fn()
  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(env))

  # Create the agent.
  with strategy.scope():
    train_step = train_utils.create_train_step()
    agent = create_agent_fn(train_step, observation_tensor_spec,
                            action_tensor_spec, time_step_tensor_spec,
                            learning_rate)
    agent.initialize()

  # Create the policy saver which saves the initial model now, then it
  # periodically checkpoints the policy weights.
  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  save_model_trigger = triggers.PolicySavedModelTrigger(
      saved_model_dir, agent, train_step, interval=1000)

  # Create the variable container.
  variables = {
      reverb_variable_container.POLICY_KEY: agent.collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.push(variables)

  # Create the replay buffer.
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      server_address=replay_buffer_server_address)

  # Initialize the dataset.
  def experience_dataset_fn():
    with strategy.scope():
      return reverb_replay.as_dataset(
          sample_batch_size=batch_size, num_steps=2).prefetch(3)

  # Create the learner.
  learning_triggers = [
      save_model_trigger,
      triggers.StepPerSecondLogTrigger(train_step, interval=1000)
  ]
  sac_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers,
      strategy=strategy)

  # Run the training loop.
  # TODO(b/162440911): Change the loop to use train_step to handle preemptions.
  for _ in range(num_iterations):
    sac_learner.run(iterations=learner_iterations_per_call)
    variable_container.push(variables)
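Launching this function might look like the sketch below; the strategy choice, server addresses, output path, and the `create_sac_agent` factory are placeholders, not defined in this snippet.

# All values below are illustrative placeholders.
strategy = tf.distribute.MirroredStrategy()
train(
    root_dir='/tmp/sac_train/learner',
    strategy=strategy,
    replay_buffer_server_address='localhost:8008',
    variable_container_server_address='localhost:8008',
    create_agent_fn=create_sac_agent,  # hypothetical agent factory
    create_env_fn=lambda: suite_mujoco.load('HalfCheetah-v2'))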
def __init__(self,
             root_dir: Text,
             train_step: tf.Variable,
             agent: ppo_agent.PPOAgent,
             experience_dataset_fn: Callable[..., tf.data.Dataset],
             normalization_dataset_fn: Callable[..., tf.data.Dataset],
             num_batches: int,
             num_epochs: int = 1,
             minibatch_size: Optional[int] = None,
             shuffle_buffer_size: Optional[int] = None,
             after_train_strategy_step_fn: Optional[Callable[
                 [types.NestedTensor, tf_agent.LossInfo], None]] = None,
             triggers: Callable[..., None] = None,
             checkpoint_interval: int = 100000,
             summary_interval: int = 1000,
             use_kwargs_in_agent_train: bool = False,
             strategy: Optional[tf.distribute.Strategy] = None):
  """Initializes a PPOLearner instance.

  ```python
  agent = ppo_agent.PPOAgent(
      ...,
      compute_value_and_advantage_in_train=False,
      # Skips updating normalizers in the agent, as it's handled in the
      # learner.
      update_normalizers_in_train=False)

  # train_replay_buffer and normalization_replay_buffer point to two Reverb
  # tables that are synchronized. Sampling is done in a FIFO fashion.
  def experience_dataset_fn():
    return train_replay_buffer.as_dataset(
        sample_batch_size, sequence_preprocess_fn=agent.preprocess_sequence)

  def normalization_dataset_fn():
    return normalization_replay_buffer.as_dataset(
        sample_batch_size, sequence_preprocess_fn=agent.preprocess_sequence)

  learner = PPOLearner(..., agent, experience_dataset_fn,
                       normalization_dataset_fn)
  learner.run()
  ```

  Args:
    root_dir: Main directory path where checkpoints, saved_models, and
      summaries will be written to.
    train_step: A scalar tf.int64 `tf.Variable` which will keep track of the
      number of train steps. This is used for artifacts created like
      summaries, or outputs in the root_dir.
    agent: `ppo_agent.PPOAgent` instance to train with. Note that
      update_normalizers_in_train should be set to `False`, otherwise a
      ValueError will be raised. We do not update normalizers in the agent
      again because we already update them in the learner. When mini batching
      is enabled, compute_value_and_advantage_in_train should be set to
      False, and preprocessing should be done as part of the data pipeline as
      part of `replay_buffer.as_dataset`.
    experience_dataset_fn: A function that will create an instance of a
      tf.data.Dataset used to sample experience for training. Each element in
      the dataset is a (Trajectory, SampleInfo) pair.
    normalization_dataset_fn: A function that will create an instance of a
      tf.data.Dataset used for normalization. This dataset is often from a
      separate reverb table that is synchronized with the table used in
      experience_dataset_fn. Each element in the dataset is a
      (Trajectory, SampleInfo) pair.
    num_batches: The number of batches to sample for training and
      normalization. If fewer than this number of batches exist in the
      dataset, the learner will wait for more data to be added, or until the
      reverb timeout is reached.
    num_epochs: The number of iterations to go through the same sequences.
    minibatch_size: The minibatch size. The dataset used for training is
      shaped `[minibatch_size, 1, ...]`. If None, full sequences will be fed
      into the agent. Please set this parameter to None for RNN networks,
      which require full sequences.
    shuffle_buffer_size: The buffer size for shuffling the trajectories
      before splitting them into mini batches. Only required when mini batch
      learning is enabled (minibatch_size is set). Otherwise it is ignored.
      Commonly set to a number 1-3x the episode length of your environment.
    after_train_strategy_step_fn: (Optional) callable of the form
      `fn(sample, loss)` which can be used, for example, to update priorities
      in a replay buffer, where sample is pulled from the
      `experience_iterator` and loss is a `LossInfo` named tuple returned
      from the agent. This is called after every train step. It runs using
      `strategy.run(...)`.
    triggers: List of callables of the form `trigger(train_step)`. After
      every `run` call every trigger is called with the current `train_step`
      value as an np scalar.
    checkpoint_interval: Number of train steps in between checkpoints. Note
      these are placed into triggers and so a check to generate a checkpoint
      only occurs after every `run` call. Set to -1 to disable. This only
      takes care of checkpointing the training process; policies must be
      explicitly exported through triggers.
    summary_interval: Number of train steps in between summaries. Note these
      are placed into triggers and so a check to generate a summary only
      occurs after every `run` call.
    use_kwargs_in_agent_train: If True the experience from the replay buffer
      is passed into the agent as kwargs. This requires samples from the RB
      to be of the form `dict(experience=experience, kwarg1=kwarg1, ...)`.
      This is useful if you have an agent with a custom argspec.
    strategy: (Optional) `tf.distribute.Strategy` to use during training.

  Raises:
    ValueError: If mini batching is enabled, but shuffle_buffer_size isn't
      provided.
    ValueError: If minibatch_size is passed in for RNN networks. RNNs require
      full sequences.
    ValueError: If mini batching is enabled, but
      agent._compute_value_and_advantage_in_train is set to `True`.
    ValueError: If agent.update_normalizers_in_train is set to `True`. The
      learner already updates the normalizers, so there is no need to update
      them again in the agent.
  """
  if minibatch_size and shuffle_buffer_size is None:
    raise ValueError(
        'shuffle_buffer_size must be provided if minibatch_size is not None.')

  if minibatch_size and (agent._actor_net.state_spec or
                         agent._value_net.state_spec):
    raise ValueError('minibatch_size must be set to None for RNN networks.')

  if minibatch_size and agent._compute_value_and_advantage_in_train:
    raise ValueError(
        'agent.compute_value_and_advantage_in_train should be set to False '
        'when mini batching is used.')

  if agent.update_normalizers_in_train:
    raise ValueError(
        'agent.update_normalizers_in_train should be set to False when '
        'PPOLearner is used.')

  strategy = strategy or tf.distribute.get_strategy()
  self._agent = agent
  self._minibatch_size = minibatch_size
  self._shuffle_buffer_size = shuffle_buffer_size
  self._num_epochs = num_epochs
  self._experience_dataset_fn = experience_dataset_fn
  self._normalization_dataset_fn = normalization_dataset_fn
  self._num_batches = num_batches

  self._generic_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn=None,
      after_train_strategy_step_fn=after_train_strategy_step_fn,
      triggers=triggers,
      checkpoint_interval=checkpoint_interval,
      summary_interval=summary_interval,
      use_kwargs_in_agent_train=use_kwargs_in_agent_train,
      strategy=strategy)

  self.num_replicas = strategy.num_replicas_in_sync
  self._create_datasets(strategy)
  self.num_frames_for_training = tf.Variable(0, dtype=tf.int32)
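A minimal outer loop, sketched under the assumption that a collect actor keeps the two synchronized Reverb tables filled; `collect_actor`, `ppo_learner`, and `num_iterations` are placeholders not defined above.

for _ in range(num_iterations):
  # Fill the synchronized train/normalization Reverb tables with fresh data.
  collect_actor.run()
  # Sample num_batches batches, update the normalizers, and train for
  # num_epochs, as documented in the constructor above.
  ppo_learner.run()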
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    # Training params
    initial_collect_steps=10000,
    num_iterations=3200000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Agent params
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    gamma=0.99,
    target_update_tau=0.005,
    target_update_period=1,
    reward_scale_factor=0.1,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    # Defaults to not checkpointing saved policy. If you wish to enable this,
    # please note the caveat explained in README.md.
    policy_save_interval=-1,
    eval_interval=10000,
    eval_episodes=30,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """Trains and evaluates SAC."""
  logging.info('Training SAC on: %s', env_name)
  collect_env = suite_mujoco.load(env_name)
  eval_env = suite_mujoco.load(env_name)

  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=actor_fc_layers,
      continuous_projection_net=tanh_normal_projection_network
      .TanhNormalProjectionNetwork)
  critic_net = critic_network.CriticNetwork(
      (observation_tensor_spec, action_tensor_spec),
      observation_fc_layer_params=critic_obs_fc_layers,
      action_fc_layer_params=critic_action_fc_layers,
      joint_fc_layer_params=critic_joint_fc_layers,
      kernel_initializer='glorot_uniform',
      last_kernel_initializer='glorot_uniform')

  agent = sac_agent.SacAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=actor_learning_rate),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=critic_learning_rate),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=alpha_learning_rate),
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      gradient_clipping=None,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step)
  agent.initialize()

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))

  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)

  dataset = reverb_replay.as_dataset(
      sample_batch_size=batch_size, num_steps=2).prefetch(50)
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()
  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=1000),
  ]

  agent_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      observers=[rb_observer, env_step_metric])

  tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      eval_greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    agent_learner.run(iterations=1)

    if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
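Wiring `train_eval` into an executable script could look like the sketch below; the flag name and default directory are illustrative, not part of the snippet above.

from absl import app
from absl import flags

# Illustrative flag; only root_dir is required by train_eval above.
flags.DEFINE_string('root_dir', '/tmp/sac_half_cheetah', 'Output directory.')
FLAGS = flags.FLAGS


def main(_):
  train_eval(FLAGS.root_dir)


if __name__ == '__main__':
  app.run(main)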
def train_eval(
    root_dir,
    env_name,
    # Training params
    train_sequence_length,
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    num_iterations=100000,
    # RNN params.
    q_network_fn=q_lstm_network,  # defaults to q_lstm_network.
    # Agent params
    epsilon_greedy=0.1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    target_update_tau=0.05,
    target_update_period=5,
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=100000,
    # Others
    policy_save_interval=1000,
    eval_interval=1000,
    eval_episodes=10):
  """Trains and evaluates DQN."""
  collect_env = suite_gym.load(env_name)
  eval_env = suite_gym.load(env_name)

  unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()
  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

  q_net = q_network_fn(num_actions=num_actions)

  sequence_length = train_sequence_length + 1
  agent = dqn_agent.DqnAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      q_network=q_net,
      epsilon_greedy=epsilon_greedy,
      # n-step updates aren't supported with RNNs yet.
      n_step_update=1,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
      td_errors_loss_fn=common.element_wise_squared_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      train_step_counter=train_step)

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=sequence_length,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbTrajectorySequenceObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=sequence_length,
      stride_length=1)

  dataset = reverb_replay.as_dataset(
      num_parallel_calls=3,
      sample_batch_size=batch_size,
      num_steps=sequence_length).prefetch(3)
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=100),
  ]

  dqn_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  # If we haven't trained yet, make sure we collect some random samples first
  # to fill up the Replay Buffer with some experience.
  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=collect_steps_per_iteration,
      observers=[rb_observer, env_step_metric],
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
  )

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    dqn_learner.run(iterations=1)

    if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
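An illustrative call to this RNN variant, assuming the default `q_lstm_network` factory; the values below are placeholders. Note that the Reverb table stores sequences of `train_sequence_length + 1` steps, so the training dataset yields `[batch_size, train_sequence_length + 1, ...]` tensors.

# Values are illustrative placeholders only.
train_eval(
    root_dir='/tmp/dqn_rnn_cartpole',
    env_name='CartPole-v0',
    train_sequence_length=20,
    num_iterations=10000)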
def train_eval(
    root_dir,
    env_name='Pong-v0',
    # Training params
    update_frequency=4,  # Number of collect steps per policy update
    initial_collect_steps=50000,  # 50k collect steps
    num_iterations=50000000,  # 50M collect steps
    # Taken from Rainbow as it's not specified in Mnih et al., 2015.
    max_episode_frames_collect=50000,  # env frames observed by the agent
    max_episode_frames_eval=108000,  # env frames observed by the agent
    # Agent params
    epsilon_greedy=0.1,
    epsilon_decay_period=250000,  # 1M collect steps / update_frequency
    batch_size=32,
    learning_rate=0.00025,
    n_step_update=1,
    gamma=0.99,
    target_update_tau=1.0,
    target_update_period=2500,  # 10k collect steps / update_frequency
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    policy_save_interval=250000,
    eval_interval=1000,
    eval_episodes=30,
    debug_summaries=True):
  """Trains and evaluates DQN."""
  collect_env = suite_atari.load(
      env_name,
      max_episode_steps=max_episode_frames_collect,
      gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
  eval_env = suite_atari.load(
      env_name,
      max_episode_steps=max_episode_frames_eval,
      gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)

  unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()
  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

  epsilon = tf.compat.v1.train.polynomial_decay(
      1.0,
      train_step,
      epsilon_decay_period,
      end_learning_rate=epsilon_greedy)
  agent = dqn_agent.DqnAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      q_network=create_q_network(num_actions),
      epsilon_greedy=epsilon,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.compat.v1.train.RMSPropOptimizer(
          learning_rate=learning_rate,
          decay=0.95,
          momentum=0.95,
          epsilon=0.01,
          centered=True),
      td_errors_loss_fn=common.element_wise_huber_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      train_step_counter=train_step,
      debug_summaries=debug_summaries)

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)

  dataset = reverb_replay.as_dataset(
      sample_batch_size=batch_size, num_steps=2).prefetch(3)
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=100),
  ]

  dqn_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  # If we haven't trained yet, make sure we collect some random samples first
  # to fill up the Replay Buffer with some experience.
  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=update_frequency,
      observers=[rb_observer, env_step_metric],
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
  )

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    dqn_learner.run(iterations=1)

    if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
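The Atari script above relies on a `create_q_network(num_actions)` helper that is not shown. A plausible sketch, assuming the standard Nature-DQN convolutional stack built with `tf_agents.networks.sequential.Sequential`, is given below; the exact layer configuration is an assumption, not taken from the snippet.

from tf_agents.networks import sequential


def create_q_network(num_actions):
  # Sketch only: cast stacked uint8 Atari frames to float in [0, 1], then
  # apply the usual conv/dense stack ending in one Q-value logit per action.
  return sequential.Sequential([
      tf.keras.layers.Lambda(lambda obs: tf.cast(obs, tf.float32) / 255.0),
      tf.keras.layers.Conv2D(32, 8, strides=4, activation='relu'),
      tf.keras.layers.Conv2D(64, 4, strides=2, activation='relu'),
      tf.keras.layers.Conv2D(64, 3, strides=1, activation='relu'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation='relu'),
      tf.keras.layers.Dense(num_actions),
  ])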