def _create_eval_actor(eval_env: YGOEnvironment, eval_policy: PyTFEagerPolicy, train_step) -> actor.Actor:
    """Builds the `actor.Actor` used for evaluation runs.

    The actor is configured with `episodes_per_run=_num_eval_episodes`, an
    `actor.eval_metrics` set buffered over the same number of episodes, and
    writes its summaries to `<tempdir>/eval`.

    Args:
      eval_env: Environment the evaluation episodes are played in.
      eval_policy: Policy driving the evaluation episodes.
      train_step: Train-step counter handed to the actor.

    Returns:
      The configured `actor.Actor`.
    """
    episode_count = _num_eval_episodes
    eval_metric_list = actor.eval_metrics(episode_count)
    eval_summary_dir = os.path.join(tempdir, 'eval')
    return actor.Actor(
        eval_env,
        eval_policy,
        train_step,
        episodes_per_run=episode_count,
        metrics=eval_metric_list,
        summary_dir=eval_summary_dir,
    )
def testEvalLocalPyActorRun(self):
    """An eval-configured Actor runs episodes without touching the replay buffer."""
    port = portpicker.pick_unused_port(portserver_address='localhost')
    components = self._build_components(port)
    env, agent, train_step, replay_buffer, _ = components
    summary_dir = self.create_tempdir().full_path

    policy = py_tf_eager_policy.PyTFEagerPolicy(
        agent.collect_policy, use_tf_function=True)

    # No observers are attached, so running this actor must not add frames.
    eval_actor = actor.Actor(
        env,
        policy,
        train_step,
        episodes_per_run=1,
        metrics=actor.eval_metrics(buffer_size=1),
        summary_dir=summary_dir,
    )

    self.assertEqual(replay_buffer.num_frames(), 0)

    run_count = 2
    for _ in range(run_count):
        eval_actor.run()

    self.assertEqual(replay_buffer.num_frames(), 0)
    # The first eval metric must have accumulated a positive value.
    first_metric = eval_actor._metrics[0]
    self.assertGreater(first_metric.result(), 0)
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    # Training params
    initial_collect_steps=10000,
    num_iterations=3200000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Agent params
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    gamma=0.99,
    target_update_tau=0.005,
    target_update_period=1,
    reward_scale_factor=0.1,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    # Defaults to not checkpointing saved policy. If you wish to enable this,
    # please note the caveat explained in README.md.
    policy_save_interval=-1,
    eval_interval=10000,
    eval_episodes=30,
    debug_summaries=False,
    summarize_grads_and_vars=False):
    """Trains and evaluates SAC.

    Builds a SAC agent for a `suite_mujoco` environment, stores experience in
    a local Reverb server, seeds the buffer with a random policy, then runs
    `num_iterations` iterations of (one collect step, one learner iteration).

    Args:
      root_dir: Base output directory. It is the learner's root directory;
        saved policies go to `root_dir/<learner.POLICY_SAVED_MODEL_DIR>`,
        collect summaries to `root_dir/<learner.TRAIN_DIR>` and eval
        summaries to `root_dir/eval`.
      env_name: Name passed to `suite_mujoco.load` for both the collect and
        the eval environment.
      initial_collect_steps: Steps collected with a uniform random policy
        before training starts.
      num_iterations: Number of training-loop iterations; each runs one
        collect step followed by `agent_learner.run(iterations=1)`.
      actor_fc_layers: `fc_layer_params` of the actor distribution network.
      critic_obs_fc_layers: `observation_fc_layer_params` of the critic.
      critic_action_fc_layers: `action_fc_layer_params` of the critic.
      critic_joint_fc_layers: `joint_fc_layer_params` of the critic.
      batch_size: `sample_batch_size` of the replay dataset.
      actor_learning_rate: Adam learning rate of the actor optimizer.
      critic_learning_rate: Adam learning rate of the critic optimizer.
      alpha_learning_rate: Adam learning rate of the agent's
        `alpha_optimizer`.
      gamma: Discount passed to `SacAgent`.
      target_update_tau: Passed to `SacAgent` as `target_update_tau`.
      target_update_period: Passed to `SacAgent` as `target_update_period`.
      reward_scale_factor: Passed to `SacAgent` as `reward_scale_factor`.
      reverb_port: Port handed to `reverb.Server` unchanged.
      replay_capacity: `max_size` of the Reverb table.
      policy_save_interval: `interval` of the `PolicySavedModelTrigger`.
      eval_interval: If truthy, evaluate once before training and then every
        time the learner's train step is a multiple of this value.
      eval_episodes: Episodes played per evaluation run.
      debug_summaries: Passed to `SacAgent`.
      summarize_grads_and_vars: Passed to `SacAgent`.
    """
    logging.info('Training SAC on: %s', env_name)
    # Independent environment instances for data collection and evaluation.
    collect_env = suite_mujoco.load(env_name)
    eval_env = suite_mujoco.load(env_name)

    observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_tensor_spec,
        action_tensor_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=tanh_normal_projection_network.
        TanhNormalProjectionNetwork)
    # The critic takes an (observation, action) pair as input.
    critic_net = critic_network.CriticNetwork(
        (observation_tensor_spec, action_tensor_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    agent = sac_agent.SacAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step)
    agent.initialize()

    # Single replay table: uniform sampling, FIFO eviction, MinSize(1) limiter.
    table_name = 'uniform_table'
    table = reverb.Table(
        table_name,
        max_size=replay_capacity,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1))

    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    # Writes overlapping length-2 trajectories (stride 1) into the table.
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)

    dataset = reverb_replay.as_dataset(
        sample_batch_size=batch_size, num_steps=2).prefetch(50)
    # The learner receives the same dataset object on every call.
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    # Counts collect-env steps; stored as metadata with each saved policy.
    env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=1000),
    ]

    agent_learner = learner.Learner(
        root_dir,
        train_step,
        agent,
        experience_dataset_fn,
        triggers=learning_triggers)

    # Seed the replay buffer with experience from a uniform random policy.
    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(
        collect_env,
        random_policy,
        train_step,
        steps_per_run=initial_collect_steps,
        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_collect_policy, use_tf_function=True)

    # One environment step per training-loop iteration.
    collect_actor = actor.Actor(
        collect_env,
        collect_policy,
        train_step,
        steps_per_run=1,
        metrics=actor.collect_metrics(10),
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
        observers=[rb_observer, env_step_metric])

    # Evaluation uses a greedy wrapper around the agent's policy.
    tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
    eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        eval_greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        agent_learner.run(iterations=1)

        if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    # Training params
    num_iterations=1600,
    actor_fc_layers=(64, 64),
    value_fc_layers=(64, 64),
    learning_rate=3e-4,
    collect_sequence_length=2048,
    minibatch_size=64,
    num_epochs=10,
    # Agent params
    importance_ratio_clipping=0.2,
    lambda_value=0.95,
    discount_factor=0.99,
    entropy_regularization=0.,
    value_pred_loss_coef=0.5,
    use_gae=True,
    use_td_lambda_return=True,
    gradient_clipping=0.5,
    value_clipping=None,
    # Replay params
    reverb_port=None,
    replay_capacity=10000,
    # Others
    policy_save_interval=5000,
    summary_interval=1000,
    eval_interval=10000,
    eval_episodes=100,
    debug_summaries=False,
    summarize_grads_and_vars=False):
    """Trains and evaluates PPO (Importance Ratio Clipping).

    Args:
      root_dir: Main directory path where checkpoints, saved_models, and
        summaries will be written to.
      env_name: Name for the Mujoco environment to load.
      num_iterations: The number of iterations to perform collection and
        training. Also sets the horizon of the linear learning-rate decay.
      actor_fc_layers: List of fully_connected parameters for the actor network,
        where each item is the number of units in the layer.
      value_fc_layers: List of fully_connected parameters for the value
        network, where each item is the number of units in the layer.
      learning_rate: Initial learning rate of the Adam optimizer. It is decayed
        linearly as `learning_rate * (1 - iteration / num_iterations)`, where
        `iteration` counts completed training-loop iterations.
      collect_sequence_length: Number of steps to take in each collect run.
      minibatch_size: Number of elements in each mini batch. If `None`, the entire
        collected sequence will be treated as one batch.
      num_epochs: Number of iterations to repeat over all collected data per data
        collection step. (Schulman, 2017) sets this to 10 for Mujoco, 15 for
        Roboschool and 3 for Atari. Passed to the learner; the agent's own
        `num_epochs` is fixed to 1.
      importance_ratio_clipping: Epsilon in clipped, surrogate PPO objective.
      lambda_value: Lambda parameter for TD-lambda computation.
      discount_factor: Discount factor for return computation. Defaults to
        `0.99`, which is the value used for all environments in
        (Schulman, 2017).
      entropy_regularization: Coefficient for entropy regularization loss term.
        Defaults to `0.0` because no entropy bonus was used in
        (Schulman, 2017).
      value_pred_loss_coef: Multiplier for value prediction loss to balance with
        policy gradient loss. Defaults to `0.5`, which was used for all
        environments in the OpenAI baseline implementation. This parameter is
        irrelevant unless you are sharing part of actor_net and value_net. In
        that case, you would want to tune this coefficient, whose value depends
        on the network architecture of your choice.
      use_gae: If True (the default here), uses generalized advantage
        estimation for computing per-timestep advantage. Else, just subtracts
        value predictions from empirical return.
      use_td_lambda_return: If True (the default here), uses td_lambda_return
        for training value function; here: `td_lambda_return = gae_advantage +
        value_predictions`. `use_gae` must be set to `True` as well to enable
        TD-lambda returns. If `use_td_lambda_return` is set to True while
        `use_gae` is False, the empirical return will be used and a warning will
        be logged.
      gradient_clipping: Norm length to clip gradients.
      value_clipping: Difference between new and old value predictions are
        clipped to this threshold. Value clipping could be helpful when training
        very deep networks. Default: no clipping.
      reverb_port: Port for reverb server, if None, use a randomly chosen unused
        port.
      replay_capacity: The maximum number of elements for each replay table.
        Items will be wasted if this is smaller than collect_sequence_length.
      policy_save_interval: How often, in train_steps, the policy will be saved.
      summary_interval: How often to write data into Tensorboard.
      eval_interval: How often to run evaluation, in train_steps. If falsy, no
        evaluation is run at all.
      eval_episodes: Number of episodes to evaluate over.
      debug_summaries: Boolean for whether to gather debug summaries.
      summarize_grads_and_vars: If true, gradient summaries will be written.
    """
    collect_env = suite_mujoco.load(env_name)
    eval_env = suite_mujoco.load(env_name)
    # Single collect environment; used as the dataset sample batch size below.
    num_environments = 1

    observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))
    # TODO(b/172267869): Remove this conversion once TensorNormalizer stops
    # converting float64 inputs to float32.
    observation_tensor_spec = tf.TensorSpec(
        dtype=tf.float32, shape=observation_tensor_spec.shape)

    train_step = train_utils.create_train_step()
    actor_net_builder = ppo_actor_network.PPOActorNetwork()
    actor_net = actor_net_builder.create_sequential_actor_net(
        actor_fc_layers, action_tensor_spec)
    value_net = value_network.ValueNetwork(
        observation_tensor_spec,
        fc_layer_params=value_fc_layers,
        kernel_initializer=tf.keras.initializers.Orthogonal())

    # Incremented once per training-loop iteration; drives the LR schedule.
    current_iteration = tf.Variable(0, dtype=tf.int64)

    def learning_rate_fn():
        # Linearly decay the learning rate.
        return learning_rate * (1 - current_iteration / num_iterations)

    agent = ppo_clip_agent.PPOClipAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=learning_rate_fn, epsilon=1e-5),
        actor_net=actor_net,
        value_net=value_net,
        importance_ratio_clipping=importance_ratio_clipping,
        lambda_value=lambda_value,
        discount_factor=discount_factor,
        entropy_regularization=entropy_regularization,
        value_pred_loss_coef=value_pred_loss_coef,
        # This is a legacy argument for the number of times we repeat the data
        # inside of the train function, incompatible with mini batch learning.
        # We set the epoch number from the replay buffer and tf.Data instead.
        num_epochs=1,
        use_gae=use_gae,
        use_td_lambda_return=use_td_lambda_return,
        gradient_clipping=gradient_clipping,
        value_clipping=value_clipping,
        # TODO(b/150244758): Default compute_value_and_advantage_in_train to False
        # after Reverb open source.
        compute_value_and_advantage_in_train=False,
        # Skips updating normalizers in the agent, as it's handled in the learner.
        update_normalizers_in_train=False,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step)
    agent.initialize()

    # Two FIFO tables fed by the same observer; max_times_sampled=1 means each
    # stored item is handed out at most once.
    reverb_server = reverb.Server(
        [
            reverb.Table(  # Replay buffer storing experience for training.
                name='training_table',
                sampler=reverb.selectors.Fifo(),
                remover=reverb.selectors.Fifo(),
                rate_limiter=reverb.rate_limiters.MinSize(1),
                max_size=replay_capacity,
                max_times_sampled=1,
            ),
            reverb.Table(  # Replay buffer storing experience for normalization.
                name='normalization_table',
                sampler=reverb.selectors.Fifo(),
                remover=reverb.selectors.Fifo(),
                rate_limiter=reverb.rate_limiters.MinSize(1),
                max_size=replay_capacity,
                max_times_sampled=1,
            )
        ],
        port=reverb_port)

    # Create the replay buffer.
    reverb_replay_train = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=collect_sequence_length,
        table_name='training_table',
        server_address='localhost:{}'.format(reverb_server.port),
        # The only collected sequence is used to populate the batches.
        max_cycle_length=1,
        rate_limiter_timeout_ms=1000)
    reverb_replay_normalization = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=collect_sequence_length,
        table_name='normalization_table',
        server_address='localhost:{}'.format(reverb_server.port),
        # The only collected sequence is used to populate the batches.
        max_cycle_length=1,
        rate_limiter_timeout_ms=1000)

    # Writes each full, non-overlapping collect sequence to both tables.
    rb_observer = reverb_utils.ReverbTrajectorySequenceObserver(
        reverb_replay_train.py_client, ['training_table', 'normalization_table'],
        sequence_length=collect_sequence_length,
        stride_length=collect_sequence_length)

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    collect_env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={
                triggers.ENV_STEP_METADATA_KEY: collect_env_step_metric
            }),
        triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval),
    ]

    def training_dataset_fn():
        return reverb_replay_train.as_dataset(
            sample_batch_size=num_environments,
            sequence_preprocess_fn=agent.preprocess_sequence)

    def normalization_dataset_fn():
        return reverb_replay_normalization.as_dataset(
            sample_batch_size=num_environments,
            sequence_preprocess_fn=agent.preprocess_sequence)

    agent_learner = ppo_learner.PPOLearner(
        root_dir,
        train_step,
        agent,
        experience_dataset_fn=training_dataset_fn,
        normalization_dataset_fn=normalization_dataset_fn,
        num_samples=1,
        num_epochs=num_epochs,
        minibatch_size=minibatch_size,
        shuffle_buffer_size=collect_sequence_length,
        triggers=learning_triggers)

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_collect_policy, use_tf_function=True)

    collect_actor = actor.Actor(
        collect_env,
        collect_policy,
        train_step,
        steps_per_run=collect_sequence_length,
        observers=[rb_observer],
        metrics=actor.collect_metrics(buffer_size=10) + [collect_env_step_metric],
        reference_metrics=[collect_env_step_metric],
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
        summary_interval=summary_interval)

    eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        agent.policy, use_tf_function=True)

    # eval_actor only exists when eval_interval is truthy; every later use is
    # guarded by the same condition.
    if eval_interval:
        logging.info('Intial evaluation.')
        eval_actor = actor.Actor(
            eval_env,
            eval_greedy_policy,
            train_step,
            metrics=actor.eval_metrics(eval_episodes),
            reference_metrics=[collect_env_step_metric],
            summary_dir=os.path.join(root_dir, 'eval'),
            episodes_per_run=eval_episodes)

        eval_actor.run_and_log()

    logging.info('Training on %s', env_name)
    last_eval_step = 0
    for i in range(num_iterations):
        collect_actor.run()
        # Push any buffered trajectory data to Reverb before training on it.
        rb_observer.flush()
        agent_learner.run()
        # Both tables are emptied after each learner run, so every iteration
        # trains only on data collected in that same iteration.
        reverb_replay_train.clear()
        reverb_replay_normalization.clear()
        current_iteration.assign_add(1)

        # Eval only if `eval_interval` has been set. Then, eval if the current train
        # step is equal or greater than the `last_eval_step` + `eval_interval` or if
        # this is the last iteration. This logic exists because agent_learner.run()
        # does not return after every train step.
        if (eval_interval and
                (agent_learner.train_step_numpy >= eval_interval + last_eval_step
                 or i == num_iterations - 1)):
            logging.info('Evaluating.')
            eval_actor.run_and_log()
            last_eval_step = agent_learner.train_step_numpy

    rb_observer.close()
    reverb_server.stop()
def evaluate(env_name,
             saved_model_dir,
             env_load_fn=env_utils.load_dm_env_for_eval,
             num_episodes=1,
             eval_log_dir=None,
             continuous=False,
             max_train_step=math.inf,
             seconds_between_checkpoint_polls=5,
             num_retries=100,
             log_measurements=lambda metrics, current_step: None):
    """Evaluates a checkpoint directory.

    Checkpoints for the saved model to evaluate are assumed to be at the same
    directory level as the saved_model dir. ie:

    * saved_model_dir: root_dir/policies/greedy_policy
    * checkpoints_dir: root_dir/checkpoints

    Args:
      env_name: Name of the environment to evaluate in.
      saved_model_dir: String path to the saved model directory.
      env_load_fn: Function to load the environment specified by env_name.
      num_episodes: Number of episodes to evaluate per checkpoint.
      eval_log_dir: Optional path to output summaries of the evaluations. If
        None a default directory relative to the saved_model_dir will be used.
      continuous: If True all the evaluation will keep polling for new
        checkpoints.
      max_train_step: Maximum train_step to evaluate. Once a train_step greater
        or equal to this is evaluated the evaluations will terminate. Should set
        to <= train_eval.num_iterations to ensure that eval terminates.
      seconds_between_checkpoint_polls: The amount of time in seconds to wait
        between polls to see if new checkpoints appear in the continuous
        setting.
      num_retries: Number of retries for reading checkpoints.
      log_measurements: Function to log measurements. Called as
        `log_measurements(metrics, current_step)` after every completed
        evaluation run, with the eval metric list and the evaluated train step.
        The default does nothing.

    Raises:
      IOError: if the initial saved policy still cannot be loaded after
        `num_retries` attempts.
    """
    split = os.path.split(saved_model_dir)
    # Remove trailing slash if we have one.
    if not split[-1]:
        saved_model_dir = split[0]

    env = env_load_fn(env_name)

    # Load saved model.
    saved_model_path = os.path.join(saved_model_dir, 'saved_model.pb')
    while continuous and not tf.io.gfile.exists(saved_model_path):
        logging.info(
            'Waiting on the first checkpoint to become available at: %s',
            saved_model_path)
        time.sleep(seconds_between_checkpoint_polls)

    for _ in range(num_retries):
        try:
            policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
                saved_model_dir, load_specs_from_pbtxt=True)
            break
        except (tf.errors.OpError, tf.errors.DataLossError, IndexError,
                FileNotFoundError):
            logging.warning(
                'Encountered an error while loading a policy. This can '
                'happen when reading a checkpoint before it is fully written. '
                'Retrying...')
            time.sleep(seconds_between_checkpoint_polls)
    else:
        logging.error('Failed to load a checkpoint after retrying: %s',
                      saved_model_dir)
        # Without a policy nothing below can run. Previously execution fell
        # through here and crashed with a NameError on `policy`.
        raise IOError('Failed to load a checkpoint after retrying: %s' %
                      saved_model_dir)

    if max_train_step and policy.get_train_step() > max_train_step:
        logging.info(
            'Policy train_step (%d) > max_train_step (%d). No evaluations performed.',
            policy.get_train_step(), max_train_step)
        return

    # Assume saved_model dir is of the form: root_dir/policies/greedy_policy. This
    # requires going up two levels to get the root_dir.
    root_dir = os.path.dirname(os.path.dirname(saved_model_dir))
    log_dir = eval_log_dir or os.path.join(root_dir, 'eval')

    evaluated_checkpoints = set()
    # Variable shared with the actor so summaries are tagged with the step of
    # the checkpoint currently being evaluated.
    train_step = tf.Variable(policy.get_train_step(), dtype=tf.int64)
    metrics = actor.eval_metrics(buffer_size=num_episodes)
    eval_actor = actor.Actor(
        env,
        policy,
        train_step,
        metrics=metrics,
        episodes_per_run=num_episodes,
        summary_dir=log_dir)

    checkpoint_list = _get_checkpoints_to_evaluate(evaluated_checkpoints,
                                                   saved_model_dir)

    latest_eval_step = policy.get_train_step()
    while (checkpoint_list or continuous) and latest_eval_step < max_train_step:
        while not checkpoint_list and continuous:
            logging.info('Waiting on new checkpoints to become available.')
            time.sleep(seconds_between_checkpoint_polls)
            checkpoint_list = _get_checkpoints_to_evaluate(
                evaluated_checkpoints, saved_model_dir)
        checkpoint = checkpoint_list.pop()
        for _ in range(num_retries):
            try:
                policy.update_from_checkpoint(checkpoint)
                break
            except (tf.errors.OpError, IndexError):
                logging.warning(
                    'Encountered an error while evaluating a checkpoint. This can '
                    'happen when reading a checkpoint before it is fully written. '
                    'Retrying...')
                time.sleep(seconds_between_checkpoint_polls)
        else:
            # This seems to happen rarely. Just skip this checkpoint.
            logging.error('Failed to evaluate checkpoint after retrying: %s',
                          checkpoint)
            continue

        logging.info('Evaluating:\n\tStep:%d\tcheckpoint: %s',
                     policy.get_train_step(), checkpoint)
        eval_actor.train_step.assign(policy.get_train_step())

        # Plain step value of the loaded checkpoint. It is kept distinct from
        # the `train_step` variable held by the actor.
        current_step = policy.get_train_step()
        if triggers.ENV_STEP_METADATA_KEY in policy.get_metadata():
            env_step = policy.get_metadata()[
                triggers.ENV_STEP_METADATA_KEY].numpy()
            eval_actor.training_env_step = env_step

        if latest_eval_step <= current_step:
            eval_actor.run_and_log()
            log_measurements(metrics, current_step)
            latest_eval_step = policy.get_train_step()
        else:
            logging.info(
                'Skipping over train_step %d to avoid logging backwards in time.',
                current_step)
        evaluated_checkpoints.add(checkpoint)
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    # Training params
    initial_collect_steps=1000,
    num_iterations=100000,
    fc_layer_params=(100, ),
    # Agent params
    epsilon_greedy=0.1,
    batch_size=64,
    learning_rate=1e-3,
    n_step_update=1,
    gamma=0.99,
    target_update_tau=0.05,
    target_update_period=5,
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=100000,
    # Others
    policy_save_interval=1000,
    eval_interval=1000,
    eval_episodes=10):
    """Trains and evaluates DQN.

    Builds a feed-forward Q-network for a `suite_gym` environment, stores
    experience in a local Reverb server, seeds it with a random policy, then
    runs `num_iterations` iterations of (one collect step, one learner
    iteration).

    Args:
      root_dir: Base output directory for the learner, saved policies
        (`<learner.POLICY_SAVED_MODEL_DIR>`), collect summaries
        (`<learner.TRAIN_DIR>`) and eval summaries (`eval`).
      env_name: Name passed to `suite_gym.load` for the collect and eval envs.
      initial_collect_steps: Random-policy steps collected before training.
      num_iterations: Number of training-loop iterations.
      fc_layer_params: Unit counts of the hidden Dense layers.
      epsilon_greedy: Passed to `DqnAgent` as `epsilon_greedy`.
      batch_size: `sample_batch_size` of the replay dataset.
      learning_rate: Adam learning rate.
      n_step_update: Passed to `DqnAgent` as `n_step_update`.
      gamma: Passed to `DqnAgent` as `gamma`.
      target_update_tau: Passed to `DqnAgent` as `target_update_tau`.
      target_update_period: Passed to `DqnAgent` as `target_update_period`.
      reward_scale_factor: Passed to `DqnAgent` as `reward_scale_factor`.
      reverb_port: Port handed to `reverb.Server` unchanged.
      replay_capacity: `max_size` of the Reverb table.
      policy_save_interval: `interval` of the `PolicySavedModelTrigger`.
      eval_interval: If truthy, evaluate once before training and then every
        time the learner's train step is a multiple of this value.
      eval_episodes: Episodes played per evaluation run.
    """
    # Independent environment instances for data collection and evaluation.
    collect_env = suite_gym.load(env_name)
    eval_env = suite_gym.load(env_name)

    time_step_tensor_spec = tensor_spec.from_spec(collect_env.time_step_spec())
    action_tensor_spec = tensor_spec.from_spec(collect_env.action_spec())

    train_step = train_utils.create_train_step()
    # Size of the discrete action range (bounds are inclusive).
    num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

    # Define a helper function to create Dense layers configured with the right
    # activation and kernel initializer.
    def dense_layer(num_units):
        return tf.keras.layers.Dense(
            num_units,
            activation=tf.keras.activations.relu,
            kernel_initializer=tf.keras.initializers.VarianceScaling(
                scale=2.0, mode='fan_in', distribution='truncated_normal'))

    # QNetwork consists of a sequence of Dense layers followed by a dense layer
    # with `num_actions` units to generate one q_value per available action as
    # its output.
    dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
    q_values_layer = tf.keras.layers.Dense(
        num_actions,
        activation=None,
        kernel_initializer=tf.keras.initializers.RandomUniform(
            minval=-0.03, maxval=0.03),
        bias_initializer=tf.keras.initializers.Constant(-0.2))
    q_net = sequential.Sequential(dense_layers + [q_values_layer])

    # NOTE(review): agent.initialize() is not called here, unlike other
    # train_eval setups — confirm this is intended.
    agent = dqn_agent.DqnAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        train_step_counter=train_step)

    # Single replay table: uniform sampling, FIFO eviction, MinSize(1) limiter.
    table_name = 'uniform_table'
    table = reverb.Table(
        table_name,
        max_size=replay_capacity,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1))
    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    # Writes overlapping length-2 trajectories (stride 1) into the table.
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)

    dataset = reverb_replay.as_dataset(
        num_parallel_calls=3, sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    # The learner receives the same dataset object on every call.
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    # Counts collect-env steps; stored as metadata with each saved policy.
    env_step_metric = py_metrics.EnvironmentSteps()

    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=100),
    ]

    dqn_learner = learner.Learner(
        root_dir,
        train_step,
        agent,
        experience_dataset_fn,
        triggers=learning_triggers)

    # If we haven't trained yet make sure we collect some random samples first to
    # fill up the Replay Buffer with some experience.
    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(
        collect_env,
        random_policy,
        train_step,
        steps_per_run=initial_collect_steps,
        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_collect_policy, use_tf_function=True)

    # One environment step per training-loop iteration.
    collect_actor = actor.Actor(
        collect_env,
        collect_policy,
        train_step,
        steps_per_run=1,
        observers=[rb_observer, env_step_metric],
        metrics=actor.collect_metrics(10),
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
    )

    # Evaluation uses agent.policy rather than agent.collect_policy. The local
    # name `greedy_policy` is function-scoped only.
    tf_greedy_policy = agent.policy
    greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        dqn_learner.run(iterations=1)

        if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
def train_eval(
    root_dir,
    env_name,
    # Training params
    train_sequence_length,
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    num_iterations=100000,
    # RNN params.
    q_network_fn=q_lstm_network,  # defaults to q_lstm_network.
    # Agent params
    epsilon_greedy=0.1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    target_update_tau=0.05,
    target_update_period=5,
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=100000,
    # Others
    policy_save_interval=1000,
    eval_interval=1000,
    eval_episodes=10):
    """Trains and evaluates DQN.

    Variant for recurrent Q-networks: the replay buffer stores and samples
    sequences of `train_sequence_length + 1` steps.

    Args:
      root_dir: Base output directory for the learner, saved policies
        (`<learner.POLICY_SAVED_MODEL_DIR>`), collect summaries
        (`<learner.TRAIN_DIR>`) and eval summaries (`eval`).
      env_name: Name passed to `suite_gym.load` for the collect and eval envs.
      train_sequence_length: Training sequence length; one extra step is
        stored and sampled per item.
      initial_collect_steps: Random-policy steps collected before training.
      collect_steps_per_iteration: Collect steps per training-loop iteration.
      num_iterations: Number of training-loop iterations.
      q_network_fn: Called as `q_network_fn(num_actions=...)` to build the
        Q-network.
      epsilon_greedy: Passed to `DqnAgent` as `epsilon_greedy`.
      batch_size: `sample_batch_size` of the replay dataset.
      learning_rate: Adam learning rate.
      gamma: Passed to `DqnAgent` as `gamma`.
      target_update_tau: Passed to `DqnAgent` as `target_update_tau`.
      target_update_period: Passed to `DqnAgent` as `target_update_period`.
      reward_scale_factor: Passed to `DqnAgent` as `reward_scale_factor`.
      reverb_port: Port handed to `reverb.Server` unchanged.
      replay_capacity: `max_size` of the Reverb table.
      policy_save_interval: `interval` of the `PolicySavedModelTrigger`.
      eval_interval: If truthy, evaluate once before training and then every
        time the learner's train step is a multiple of this value.
      eval_episodes: Episodes played per evaluation run.
    """
    # Independent environment instances for data collection and evaluation.
    collect_env = suite_gym.load(env_name)
    eval_env = suite_gym.load(env_name)

    unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()

    # Size of the discrete action range (bounds are inclusive).
    num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1
    q_net = q_network_fn(num_actions=num_actions)

    # Stored/sampled sequences carry one step more than the training length.
    sequence_length = train_sequence_length + 1
    # NOTE(review): agent.initialize() is not called here, unlike other
    # train_eval setups — confirm this is intended.
    agent = dqn_agent.DqnAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        # n-step updates aren't supported with RNNs yet.
        n_step_update=1,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        train_step_counter=train_step)

    # Single replay table: uniform sampling, FIFO eviction, MinSize(1) limiter.
    table_name = 'uniform_table'
    table = reverb.Table(
        table_name,
        max_size=replay_capacity,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1))
    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=sequence_length,
        table_name=table_name,
        local_server=reverb_server)
    # Writes overlapping (stride 1) sequences; pad_end_of_episodes=True is
    # passed so episode ends are padded to the full sequence length.
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=sequence_length,
        stride_length=1,
        pad_end_of_episodes=True)

    # Builds a fresh dataset each time the learner calls it.
    def experience_dataset_fn():
        return reverb_replay.as_dataset(
            sample_batch_size=batch_size, num_steps=sequence_length)

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    # Counts collect-env steps; stored as metadata with each saved policy.
    env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=100),
    ]

    dqn_learner = learner.Learner(
        root_dir,
        train_step,
        agent,
        experience_dataset_fn,
        triggers=learning_triggers)

    # If we haven't trained yet make sure we collect some random samples first to
    # fill up the Replay Buffer with some experience.
    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(
        collect_env,
        random_policy,
        train_step,
        steps_per_run=initial_collect_steps,
        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_collect_policy, use_tf_function=True)

    collect_actor = actor.Actor(
        collect_env,
        collect_policy,
        train_step,
        steps_per_run=collect_steps_per_iteration,
        observers=[rb_observer, env_step_metric],
        metrics=actor.collect_metrics(10),
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
    )

    # Evaluation uses agent.policy rather than agent.collect_policy.
    tf_greedy_policy = agent.policy
    greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        dqn_learner.run(iterations=1)

        if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
def train_eval(
    root_dir,
    env_name='Pong-v0',
    # Training params
    update_frequency=4,  # Number of collect steps per policy update
    initial_collect_steps=50000,  # 50k collect steps
    # Training-loop iterations. Each runs `update_frequency` collect steps, so
    # the defaults give 200M collect steps. NOTE(review): the original comment
    # here said "50M collect steps"; confirm which total is intended.
    num_iterations=50000000,
    # Taken from Rainbow as it's not specified in Mnih,15.
    max_episode_frames_collect=50000,  # env frames observed by the agent
    max_episode_frames_eval=108000,  # env frames observed by the agent
    # Agent params
    epsilon_greedy=0.1,
    epsilon_decay_period=250000,  # 1M collect steps / update_frequency
    batch_size=32,
    learning_rate=0.00025,
    n_step_update=1,
    gamma=0.99,
    target_update_tau=1.0,
    target_update_period=2500,  # 10k collect steps / update_frequency
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    policy_save_interval=250000,
    eval_interval=1000,
    eval_episodes=30,
    debug_summaries=True):
    """Trains and evaluates DQN.

    Atari variant: frame-stacked `suite_atari` environments, RMSProp, Huber
    TD loss and an epsilon schedule decayed over train steps. Each
    training-loop iteration runs `update_frequency` collect steps followed by
    one learner iteration.

    Args:
      root_dir: Base output directory for the learner, saved policies
        (`<learner.POLICY_SAVED_MODEL_DIR>`), collect summaries
        (`<learner.TRAIN_DIR>`) and eval summaries (`eval`).
      env_name: Name passed to `suite_atari.load`.
      update_frequency: Collect steps per training-loop iteration.
      initial_collect_steps: Random-policy steps collected before training.
      num_iterations: Number of training-loop iterations.
      max_episode_frames_collect: `max_episode_steps` of the collect env.
      max_episode_frames_eval: `max_episode_steps` of the eval env.
      epsilon_greedy: Final epsilon value of the decay schedule.
      epsilon_decay_period: Train steps over which epsilon decays from 1.0.
      batch_size: `sample_batch_size` of the replay dataset.
      learning_rate: RMSProp learning rate.
      n_step_update: Passed to `DqnAgent` as `n_step_update`.
      gamma: Passed to `DqnAgent` as `gamma`.
      target_update_tau: Passed to `DqnAgent` as `target_update_tau`.
      target_update_period: Passed to `DqnAgent` as `target_update_period`.
      reward_scale_factor: Passed to `DqnAgent` as `reward_scale_factor`.
      reverb_port: Port handed to `reverb.Server` unchanged.
      replay_capacity: `max_size` of the Reverb table.
      policy_save_interval: `interval` of the `PolicySavedModelTrigger`.
      eval_interval: If truthy, evaluate once before training and then every
        time the learner's train step is a multiple of this value.
      eval_episodes: Episodes played per evaluation run.
      debug_summaries: Passed to `DqnAgent`.
    """
    # Collect and eval envs differ only in their episode length cap.
    collect_env = suite_atari.load(
        env_name,
        max_episode_steps=max_episode_frames_collect,
        gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
    eval_env = suite_atari.load(
        env_name,
        max_episode_steps=max_episode_frames_eval,
        gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)

    unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()

    # Size of the discrete action range (bounds are inclusive).
    num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1
    # Epsilon decays from 1.0 to `epsilon_greedy` over `epsilon_decay_period`
    # train steps.
    epsilon = tf.compat.v1.train.polynomial_decay(
        1.0,
        train_step,
        epsilon_decay_period,
        end_learning_rate=epsilon_greedy)
    # NOTE(review): agent.initialize() is not called here, unlike other
    # train_eval setups — confirm this is intended.
    agent = dqn_agent.DqnAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        q_network=create_q_network(num_actions),
        epsilon_greedy=epsilon,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.RMSPropOptimizer(
            learning_rate=learning_rate,
            decay=0.95,
            momentum=0.95,
            epsilon=0.01,
            centered=True),
        td_errors_loss_fn=common.element_wise_huber_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        train_step_counter=train_step,
        debug_summaries=debug_summaries)

    # Single replay table: uniform sampling, FIFO eviction, MinSize(1) limiter.
    table_name = 'uniform_table'
    table = reverb.Table(
        table_name,
        max_size=replay_capacity,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1))
    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    # Writes overlapping length-2 trajectories (stride 1) into the table.
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)

    dataset = reverb_replay.as_dataset(
        sample_batch_size=batch_size, num_steps=2).prefetch(3)
    # The learner receives the same dataset object on every call.
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    # Counts collect-env steps; saved-policy metadata and reference metric
    # for the collect/eval actors.
    env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=100),
    ]

    dqn_learner = learner.Learner(
        root_dir,
        train_step,
        agent,
        experience_dataset_fn,
        triggers=learning_triggers)

    # If we haven't trained yet make sure we collect some random samples first to
    # fill up the Replay Buffer with some experience.
    random_policy = random_py_policy.RandomPyPolicy(collect_env.time_step_spec(),
                                                    collect_env.action_spec())
    initial_collect_actor = actor.Actor(
        collect_env,
        random_policy,
        train_step,
        steps_per_run=initial_collect_steps,
        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    # `update_frequency` environment steps per training-loop iteration.
    collect_actor = actor.Actor(
        collect_env,
        collect_policy,
        train_step,
        steps_per_run=update_frequency,
        observers=[rb_observer, env_step_metric],
        metrics=actor.collect_metrics(10),
        reference_metrics=[env_step_metric],
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
    )

    # Evaluation uses agent.policy rather than agent.collect_policy.
    tf_greedy_policy = agent.policy
    greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy,
                                                       use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        reference_metrics=[env_step_metric],
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        dqn_learner.run(iterations=1)

        if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
def train_eval(
    root_dir,
    # Dataset params
    env_name,
    data_dir=None,
    load_pretrained=False,
    pretrained_model_dir=None,
    img_pad=4,
    frame_shape=(84, 84, 3),
    frame_stack=3,
    num_augmentations=2,  # K and M in DrQ
    # Training params
    contrastive_loss_weight=1.0,
    contrastive_loss_temperature=0.5,
    image_encoder_representation=True,
    initial_collect_steps=1000,
    num_train_steps=3000000,
    actor_fc_layers=(1024, 1024),
    critic_joint_fc_layers=(1024, 1024),
    # Agent params
    batch_size=256,
    actor_learning_rate=1e-3,
    critic_learning_rate=1e-3,
    alpha_learning_rate=1e-3,
    encoder_learning_rate=1e-3,
    actor_update_freq=2,
    gamma=0.99,
    target_update_tau=0.01,
    target_update_period=2,
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=100000,
    # Others
    checkpoint_interval=10000,
    policy_save_interval=5000,
    eval_interval=10000,
    summary_interval=250,
    debug_summaries=False,
    eval_episodes_per_run=10,
    summarize_grads_and_vars=False):
  """Trains and evaluates SAC.

  Builds a `pse_drq_agent.DrQSacModifiedAgent` on image observations and fills
  a Reverb replay buffer with initial experience. It then alternates one
  collect step with one learner iteration until `train_step` reaches
  `num_train_steps`. When `eval_interval` is set, it evaluates the greedy
  policy before training and every `eval_interval` train steps.

  Args:
    root_dir: Directory under which saved policies, checkpoints, train
      summaries, eval summaries and the operative gin config are written.
    env_name: Environment name passed to the `env_utils` loaders.
    data_dir: Directory whose `episodes2` subdirectory holds the episode data
      for the contrastive loss. Required when `contrastive_loss_weight > 0`.
    load_pretrained: If True, restore the pretrained policy from
      `pretrained_model_dir` into the agent's policies and use it for the
      initial collection instead of a random policy.
    pretrained_model_dir: Directory of the pretrained policy. `num_train_steps`
      is used as the step of the pretrained policy to load.
    num_augmentations: Number of image augmentations (K and M in DrQ). 0
      disables augmentation (empty augmented observation lists).
    contrastive_loss_weight: Weight of the contrastive loss. Values > 0 also
      enable loading episode data from `data_dir`.
    eval_interval: Evaluate every this many train steps. A falsy value
      disables evaluation; image summaries are then written to the train
      summary writer every 10000 steps instead.
    Remaining arguments are forwarded, under the same meaning as their names,
    to the environment loaders, networks, agent, replay buffer, learner and
    actors.

  Raises:
    ValueError: If `contrastive_loss_weight > 0` and `data_dir` is None.
  """
  # The contrastive loss trains on full episodes loaded from `data_dir`.
  # Validate this before any expensive setup: otherwise the failure would be a
  # TypeError from os.path.join(None, ...) after envs and networks are built.
  load_episode_data = (contrastive_loss_weight > 0)
  if load_episode_data and data_dir is None:
    raise ValueError(
        'data_dir must be set when contrastive_loss_weight > 0, since episode '
        'data for the contrastive loss is loaded from it.')

  collect_env = env_utils.load_dm_env_for_training(
      env_name, frame_shape, frame_stack=frame_stack)
  eval_env = env_utils.load_dm_env_for_eval(
      env_name, frame_shape, frame_stack=frame_stack)

  logging.info('Data directory: %s', data_dir)
  logging.info('Num train steps: %d', num_train_steps)
  logging.info('Contrastive loss coeff: %.2f', contrastive_loss_weight)
  logging.info('Contrastive loss temperature: %.4f',
               contrastive_loss_temperature)
  logging.info('load_pretrained: %s', 'yes' if load_pretrained else 'no')
  logging.info('encoder representation: %s',
               'yes' if image_encoder_representation else 'no')

  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()
  # The actor and both online critics share a single image encoder.
  image_encoder = networks.ImageEncoder(observation_tensor_spec)

  actor_net = model_utils.Actor(
      observation_tensor_spec,
      action_tensor_spec,
      image_encoder=image_encoder,
      fc_layers=actor_fc_layers,
      image_encoder_representation=image_encoder_representation)

  critic_net = networks.Critic(
      (observation_tensor_spec, action_tensor_spec),
      image_encoder=image_encoder,
      joint_fc_layers=critic_joint_fc_layers)
  critic_net_2 = networks.Critic(
      (observation_tensor_spec, action_tensor_spec),
      image_encoder=image_encoder,
      joint_fc_layers=critic_joint_fc_layers)

  # Both target critics share their own, separate image encoder.
  # NOTE(review): the target critics are built without `joint_fc_layers`, so
  # they use `networks.Critic`'s default. Confirm that default matches
  # `critic_joint_fc_layers`, since target updates presumably need matching
  # variable shapes.
  target_image_encoder = networks.ImageEncoder(observation_tensor_spec)
  target_critic_net_1 = networks.Critic(
      (observation_tensor_spec, action_tensor_spec),
      image_encoder=target_image_encoder)
  target_critic_net_2 = networks.Critic(
      (observation_tensor_spec, action_tensor_spec),
      image_encoder=target_image_encoder)

  agent = pse_drq_agent.DrQSacModifiedAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      critic_network_2=critic_net_2,
      target_critic_network=target_critic_net_1,
      target_critic_network_2=target_critic_net_2,
      actor_update_frequency=actor_update_freq,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=actor_learning_rate),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=critic_learning_rate),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=alpha_learning_rate),
      contrastive_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=encoder_learning_rate),
      contrastive_loss_weight=contrastive_loss_weight,
      contrastive_loss_temperature=contrastive_loss_temperature,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      use_log_alpha_in_alpha_loss=False,
      gradient_clipping=None,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step,
      num_augmentations=num_augmentations)
  agent.initialize()

  # Setup the replay buffer.
  reverb_replay, rb_observer = (
      replay_buffer_utils.get_reverb_buffer_and_observer(
          agent.collect_data_spec,
          sequence_length=2,
          replay_capacity=replay_capacity,
          port=reverb_port))

  # pylint: disable=g-long-lambda
  if num_augmentations == 0:
    image_aug = lambda traj, meta: (dict(
        experience=traj, augmented_obs=[], augmented_next_obs=[]), meta)
  else:
    image_aug = lambda traj, meta: pse_drq_agent.image_aug(
        traj, meta, img_pad, num_augmentations)
  # Samples are unbatched so augmentation is applied per transition, then
  # re-batched below for training.
  augmented_dataset = reverb_replay.as_dataset(
      sample_batch_size=batch_size, num_steps=2).unbatch().map(
          image_aug, num_parallel_calls=3)
  # Separate iterator used only to draw examples for image summaries.
  augmented_iterator = iter(augmented_dataset)

  trajs = augmented_dataset.batch(batch_size).prefetch(50)

  if load_episode_data:
    # Load full episodes and zip them with the replay batches.
    episodes = dataset_utils.load_episodes(
        os.path.join(data_dir, 'episodes2'), img_pad)
    episode_iterator = iter(episodes)
    dataset = tf.data.Dataset.zip((trajs, episodes)).prefetch(10)
  else:
    dataset = trajs
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir, agent, train_step, interval=policy_save_interval),
      triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval),
  ]

  agent_learner = model_utils.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn=experience_dataset_fn,
      triggers=learning_triggers,
      checkpoint_interval=checkpoint_interval,
      summary_interval=summary_interval,
      load_episode_data=load_episode_data,
      use_kwargs_in_agent_train=True,
      # Turn off the initialization of the optimizer variables since, the agent
      # expects different batching for the `training_data_spec` and
      # `train_argspec` which can't be handled in general by the initialization
      # logic in the learner.
      run_optimizer_variable_init=False)

  train_dir = os.path.join(root_dir, learner.TRAIN_DIR)

  # Choose the policy for the initial collection that fills the replay buffer:
  # the pretrained policy if requested, otherwise a uniformly random policy.
  if load_pretrained:
    # Note that num_train_steps is same as the max_train_step we want to
    # load the pretrained policy for our experiments
    pretrained_policy = model_utils.load_pretrained_policy(
        pretrained_model_dir, num_train_steps)
    initial_collect_policy = pretrained_policy

    agent.policy.update_partial(pretrained_policy)
    agent.collect_policy.update_partial(pretrained_policy)
    logging.info('Restored pretrained policy.')
  else:
    initial_collect_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())

  initial_collect_actor = actor.Actor(
      collect_env,
      initial_collect_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      observers=[rb_observer],
      metrics=actor.collect_metrics(buffer_size=10),
      summary_dir=train_dir,
      summary_interval=summary_interval,
      name='CollectActor')

  # If restarting with train_step > 0, the replay buffer will be empty
  # except for random experience. Populate the buffer with some on-policy
  # experience.
  if load_pretrained or (agent_learner.train_step_numpy > 0):
    for _ in range(batch_size * 50):
      collect_actor.run()

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes_per_run,
      metrics=actor.eval_metrics(buffer_size=10),
      summary_dir=os.path.join(root_dir, 'eval'),
      summary_interval=-1,
      name='EvalTrainActor')

  if eval_interval:
    logging.info('Evaluating.')
    img_summary(
        next(augmented_iterator)[0], eval_actor.summary_writer, train_step)
    if load_episode_data:
      contrastive_img_summary(
          next(episode_iterator), agent, eval_actor.summary_writer, train_step)
    eval_actor.run_and_log()

  logging.info('Saving operative gin config file.')
  gin_path = os.path.join(train_dir, 'train_operative_gin_config.txt')
  with tf.io.gfile.GFile(gin_path, mode='w') as f:
    f.write(gin.operative_config_str())

  logging.info('Training Starting at: %r', train_step.numpy())
  while train_step < num_train_steps:
    collect_actor.run()
    agent_learner.run(iterations=1)

    # Without evaluation, still emit periodic image summaries to the train
    # summary writer.
    if (not eval_interval) and (train_step % 10000 == 0):
      img_summary(
          next(augmented_iterator)[0],
          agent_learner.train_summary_writer,
          train_step)

    if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      img_summary(
          next(augmented_iterator)[0], eval_actor.summary_writer, train_step)
      if load_episode_data:
        contrastive_img_summary(next(episode_iterator), agent,
                                eval_actor.summary_writer, train_step)
      eval_actor.run_and_log()

  # Close the replay-buffer observer once training is done.
  rb_observer.close()
def train_eval(
    root_dir,
    dataset_path,
    env_name,
    # Training params
    tpu=False,
    use_gpu=False,
    num_gradient_updates=1000000,
    actor_fc_layers=(256, 256),
    critic_joint_fc_layers=(256, 256, 256),
    # Agent params
    batch_size=256,
    bc_steps=0,
    actor_learning_rate=3e-5,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    reward_scale_factor=1.0,
    cql_alpha_learning_rate=3e-4,
    cql_alpha=5.0,
    cql_tau=10.0,
    num_cql_samples=10,
    reward_noise_variance=0.0,
    include_critic_entropy_term=False,
    use_lagrange_cql_alpha=True,
    log_cql_alpha_clipping=None,
    softmax_temperature=1.0,
    # Data params
    reward_shift=0.0,
    action_clipping=None,
    use_trajectories=False,
    data_shuffle_buffer_size_per_record=1,
    data_shuffle_buffer_size=100,
    data_num_shards=1,
    data_block_length=10,
    data_parallel_reads=None,
    data_parallel_calls=10,
    data_prefetch=10,
    data_cycle_length=10,
    # Others
    policy_save_interval=10000,
    eval_interval=10000,
    summary_interval=1000,
    learner_iterations_per_call=1,
    eval_episodes=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    seed=None):
  """Trains and evaluates CQL-SAC.

  Trains a `cql_sac_agent.CqlSacAgent` from a static TFRecord dataset. No
  environment steps are collected; a dummy trajectory is fed to the env-step
  metric per gradient update instead. The greedy policy is evaluated in the
  loaded environment whenever `train_step` reaches or passes a multiple of
  `eval_interval`.

  Args:
    root_dir: Directory for saved policies, learner summaries and eval
      summaries.
    dataset_path: Either a `.tfrecord` file pattern, or a root directory in
      which `<env_name>/<env_name>*.tfrecord` is globbed.
    env_name: Environment name passed to `load_d4rl`.
    num_gradient_updates: Total number of learner iterations to run, rounded
      down to a multiple of `learner_iterations_per_call`.
    learner_iterations_per_call: Learner iterations per `Learner.run` call.
    eval_interval: Evaluation period in train steps; falsy disables eval.
    seed: Seed for TensorFlow, NumPy and the agent.
    Remaining arguments are forwarded, under the same meaning as their names,
    to the dataset builder, networks, agent, learner triggers and eval actor.

  Raises:
    ValueError: If no files match the resolved dataset pattern.
  """
  logging.info('Training CQL-SAC on: %s', env_name)
  tf.random.set_seed(seed)
  np.random.seed(seed)

  # Load environment.
  env = load_d4rl(env_name)
  tf_env = tf_py_environment.TFPyEnvironment(env)
  strategy = strategy_utils.get_strategy(tpu, use_gpu)

  if not dataset_path.endswith('.tfrecord'):
    dataset_path = os.path.join(dataset_path, env_name,
                                '%s*.tfrecord' % env_name)
  logging.info('Loading dataset from %s', dataset_path)
  dataset_paths = tf.io.gfile.glob(dataset_path)
  if not dataset_paths:
    # Fail fast with the resolved pattern instead of an opaque error from
    # inside the input pipeline.
    raise ValueError('No dataset files match: %s' % dataset_path)

  # Create dataset.
  with strategy.scope():
    dataset = create_tf_record_dataset(
        dataset_paths,
        batch_size,
        shuffle_buffer_size_per_record=data_shuffle_buffer_size_per_record,
        shuffle_buffer_size=data_shuffle_buffer_size,
        num_shards=data_num_shards,
        cycle_length=data_cycle_length,
        block_length=data_block_length,
        num_parallel_reads=data_parallel_reads,
        num_parallel_calls=data_parallel_calls,
        num_prefetch=data_prefetch,
        strategy=strategy,
        reward_shift=reward_shift,
        action_clipping=action_clipping,
        use_trajectories=use_trajectories)

  # Create agent.
  time_step_spec = tf_env.time_step_spec()
  observation_spec = time_step_spec.observation
  action_spec = tf_env.action_spec()
  with strategy.scope():
    train_step = train_utils.create_train_step()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=tanh_normal_projection_network
        .TanhNormalProjectionNetwork)

    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    agent = cql_sac_agent.CqlSacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.keras.optimizers.Adam(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.keras.optimizers.Adam(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.keras.optimizers.Adam(
            learning_rate=alpha_learning_rate),
        cql_alpha=cql_alpha,
        num_cql_samples=num_cql_samples,
        include_critic_entropy_term=include_critic_entropy_term,
        use_lagrange_cql_alpha=use_lagrange_cql_alpha,
        cql_alpha_learning_rate=cql_alpha_learning_rate,
        target_update_tau=5e-3,
        target_update_period=1,
        random_seed=seed,
        cql_tau=cql_tau,
        reward_noise_variance=reward_noise_variance,
        num_bc_steps=bc_steps,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        log_cql_alpha_clipping=log_cql_alpha_clipping,
        softmax_temperature=softmax_temperature,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step)
    agent.initialize()

  # Create learner.
  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  collect_env_step_metric = py_metrics.EnvironmentSteps()
  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={
              triggers.ENV_STEP_METADATA_KEY: collect_env_step_metric
          }),
      triggers.StepPerSecondLogTrigger(train_step, interval=100)
  ]
  cql_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn=lambda: dataset,
      triggers=learning_triggers,
      summary_interval=summary_interval,
      strategy=strategy)

  # Create actor for evaluation.
  tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)
  eval_actor = actor.Actor(
      env,
      eval_greedy_policy,
      train_step,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
      episodes_per_run=eval_episodes)

  # Run.
  dummy_trajectory = trajectory.mid((), (), (), 0., 1.)
  num_learner_iterations = int(
      num_gradient_updates // learner_iterations_per_call)
  previous_step = train_step.numpy()
  for _ in range(num_learner_iterations):
    # Mimic collecting environment steps since we loaded a static dataset.
    for _ in range(learner_iterations_per_call):
      collect_env_step_metric(dummy_trajectory)

    cql_learner.run(iterations=learner_iterations_per_call)
    current_step = train_step.numpy()
    # `train_step` advances by `learner_iterations_per_call` per run. A plain
    # `current_step % eval_interval == 0` test therefore skips every
    # evaluation whose exact step is never hit when the two values are not
    # aligned. Evaluate whenever a multiple of `eval_interval` was crossed;
    # with aligned values this fires on exactly the same steps.
    if eval_interval and (current_step // eval_interval >
                          previous_step // eval_interval):
      eval_actor.run_and_log()
    previous_step = current_step
# NOTE(review): this span appears to be a truncated fragment of another
# training script. It reads names that are not assigned anywhere in this span:
# `collect_env`, `collect_policy`, `train_step`, `rb_observer`, `eval_env`,
# `eval_policy`, `num_eval_episodes`, `tempdir`, `tf_agent` and
# `policy_save_interval`. Its final `learner.Learner(tempdir,` call is never
# closed. Confirm where the rest of this code lives before relying on it.
# Counts environment steps seen by the collect actor (it is passed as an
# observer below).
env_step_metric = py_metrics.EnvironmentSteps()
# Collect actor: one step per `run()`, writing to the replay observer and
# env-step metric, with train summaries under `tempdir/<learner.TRAIN_DIR>`.
collect_actor = actor.Actor(
    collect_env,
    collect_policy,
    train_step,
    steps_per_run=1,
    metrics=actor.collect_metrics(10),
    summary_dir=os.path.join(tempdir, learner.TRAIN_DIR),
    observers=[rb_observer, env_step_metric])

# Eval actor: `num_eval_episodes` episodes per `run()`, summaries under
# `tempdir/eval`.
eval_actor = actor.Actor(
    eval_env,
    eval_policy,
    train_step,
    episodes_per_run=num_eval_episodes,
    metrics=actor.eval_metrics(num_eval_episodes),
    summary_dir=os.path.join(tempdir, 'eval'),
)

saved_model_dir = os.path.join(tempdir, learner.POLICY_SAVED_MODEL_DIR)

# Triggers to save the agent's policy checkpoints.
learning_triggers = [
    triggers.PolicySavedModelTrigger(
        saved_model_dir, tf_agent, train_step, interval=policy_save_interval),
    triggers.StepPerSecondLogTrigger(train_step, interval=1000),
]

agent_learner = learner.Learner(tempdir,
def train_eval(
    root_dir,
    strategy: tf.distribute.Strategy,
    env_name='HalfCheetah-v2',
    # Training params
    initial_collect_steps=10000,
    num_iterations=3200000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Agent params
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    gamma=0.99,
    target_update_tau=0.005,
    target_update_period=1,
    reward_scale_factor=0.1,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    policy_save_interval=10000,
    replay_buffer_save_interval=100000,
    eval_interval=10000,
    eval_episodes=30,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """Trains and evaluates SAC.

  Builds a SAC agent under `strategy` and starts a local Reverb server whose
  table is checkpointed under `root_dir`. It performs a random initial
  collection, then runs `num_iterations` rounds of one collect step followed
  by one learner iteration. When `eval_interval` is set, the greedy policy is
  evaluated before training and every `eval_interval` train steps.

  Args:
    root_dir: Directory for saved policies, learner checkpoints, replay-buffer
      checkpoints, train summaries and eval summaries.
    strategy: Distribution strategy under which the train step and agent are
      created and the learner runs.
    env_name: MuJoCo environment name passed to `suite_mujoco.load`.
    initial_collect_steps: Random-policy steps collected before training.
    num_iterations: Number of collect/train rounds.
    replay_buffer_save_interval: Train-step period of Reverb checkpoints.
    eval_interval: Evaluation period in train steps; falsy disables eval.
    Remaining arguments are forwarded, under the same meaning as their names,
    to the networks, agent, replay table, learner triggers and actors.
  """
  import contextlib  # Local import: only needed for server cleanup below.

  logging.info('Training SAC on: %s', env_name)
  collect_env = suite_mujoco.load(env_name)
  eval_env = suite_mujoco.load(env_name)

  _, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  actor_net = create_sequential_actor_network(
      actor_fc_layers=actor_fc_layers, action_tensor_spec=action_tensor_spec)

  critic_net = create_sequential_critic_network(
      obs_fc_layer_units=critic_obs_fc_layers,
      action_fc_layer_units=critic_action_fc_layers,
      joint_fc_layer_units=critic_joint_fc_layers)

  with strategy.scope():
    train_step = train_utils.create_train_step()
    agent = sac_agent.SacAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.keras.optimizers.Adam(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.keras.optimizers.Adam(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.keras.optimizers.Adam(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step)
    agent.initialize()

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))

  reverb_checkpoint_dir = os.path.join(root_dir, learner.TRAIN_DIR,
                                       learner.REPLAY_BUFFER_CHECKPOINT_DIR)
  reverb_checkpointer = reverb.platform.checkpointers_lib.DefaultCheckpointer(
      path=reverb_checkpoint_dir)

  # Cleanup callbacks run in reverse registration order on every exit path,
  # including exceptions from collection, training or evaluation. The observer
  # is therefore closed before the server is stopped, the same order as the
  # success path.
  with contextlib.ExitStack() as cleanup:
    reverb_server = reverb.Server(
        [table], port=reverb_port, checkpointer=reverb_checkpointer)
    cleanup.callback(reverb_server.stop)

    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)
    cleanup.callback(rb_observer.close)

    def experience_dataset_fn():
      return reverb_replay.as_dataset(
          sample_batch_size=batch_size, num_steps=2).prefetch(50)

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={
                triggers.ENV_STEP_METADATA_KEY: env_step_metric
            }),
        triggers.ReverbCheckpointTrigger(
            train_step,
            interval=replay_buffer_save_interval,
            reverb_client=reverb_replay.py_client),
        # TODO(b/165023684): Add SIGTERM handler to checkpoint before preemption.
        triggers.StepPerSecondLogTrigger(train_step, interval=1000),
    ]

    agent_learner = learner.Learner(
        root_dir,
        train_step,
        agent,
        experience_dataset_fn,
        triggers=learning_triggers,
        strategy=strategy)

    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(
        collect_env,
        random_policy,
        train_step,
        steps_per_run=initial_collect_steps,
        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_collect_policy, use_tf_function=True)

    collect_actor = actor.Actor(
        collect_env,
        collect_policy,
        train_step,
        steps_per_run=1,
        metrics=actor.collect_metrics(10),
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
        observers=[rb_observer, env_step_metric])

    tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
    eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        eval_greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
      collect_actor.run()
      agent_learner.run(iterations=1)

      if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
        logging.info('Evaluating.')
        eval_actor.run_and_log()