def _generate_replay_buffer(self, rb_cls):
  stack_count = 4
  shape = (15, 15, stack_count)
  single_shape = (15, 15, 1)
  observation_spec = array_spec.ArraySpec(shape, np.int32, 'obs')
  time_step_spec = ts.time_step_spec(observation_spec)
  action_spec = policy_step.PolicyStep(array_spec.BoundedArraySpec(
      shape=(), dtype=np.int32, minimum=0, maximum=1, name='action'))
  self._trajectory_spec = trajectory.from_transition(
      time_step_spec, action_spec, time_step_spec)
  self._capacity = 32
  self._replay_buffer = rb_cls(
      data_spec=self._trajectory_spec, capacity=self._capacity)

  # Generate N frames: the value of pixels is the frame index.
  # The observations will be generated by stacking K frames out of those N,
  # generating some redundancies between the observations.
  single_frames = []
  frame_count = 100
  for k in range(frame_count):
    single_frames.append(np.full(single_shape, k, dtype=np.int32))

  # Add stacks of frames to the replay buffer.
  time_steps = []
  for k in range(len(single_frames) - stack_count + 1):
    observation = np.concatenate(single_frames[k:k + stack_count], axis=-1)
    time_steps.append(ts.transition(observation, reward=0.0))

  self._transition_count = len(time_steps) - 1
  dummy_action = policy_step.PolicyStep(np.int32(0))
  for k in range(self._transition_count):
    self._replay_buffer.add_batch(nest_utils.batch_nested_array(
        trajectory.from_transition(
            time_steps[k], dummy_action, time_steps[k + 1])))
def _collect_step(self, time_step, metric_observers, train=False):
  """Run a single step (or 2 steps on life loss) in the environment."""
  if train:
    policy = self._collect_policy
  else:
    policy = self._eval_policy

  with self._action_timer:
    action_step = policy.action(time_step)
  with self._step_timer:
    next_time_step = self._env.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)

  if next_time_step.is_last() and not self.game_over():
    traj = traj._replace(discount=np.array([1.0], dtype=np.float32))

  if train:
    self._store_to_rb(traj)

  # When AtariPreprocessing.terminal_on_life_loss is True, we receive LAST
  # time_steps when lives are lost but the game is not over. In this mode,
  # the replay buffer and agent's policy must see the life loss as a LAST
  # step and the subsequent step as a FIRST step. However, we do not want to
  # actually terminate the episode, and metrics should be computed as if all
  # steps were MID steps, since life loss is not actually a terminal event
  # (it is mostly a trick to make it easier to propagate rewards backwards
  # by shortening episode durations from the agent's perspective).
  if next_time_step.is_last() and not self.game_over():
    # Update metrics as if this is a mid-episode step.
    next_time_step = ts.transition(
        next_time_step.observation, next_time_step.reward)
    self._observe(metric_observers, trajectory.from_transition(
        time_step, action_step, next_time_step))

    # Produce the next step as if this is the first step of an episode and
    # store to RB as such. The next_time_step will be a MID time step.
    reward = time_step.reward
    time_step = ts.restart(next_time_step.observation)
    with self._action_timer:
      action_step = policy.action(time_step)
    with self._step_timer:
      next_time_step = self._env.step(action_step.action)
    if train:
      self._store_to_rb(
          trajectory.from_transition(time_step, action_step, next_time_step))

    # Update metrics as if this is a mid-episode step.
    time_step = ts.transition(time_step.observation, reward)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)

  self._observe(metric_observers, traj)
  return next_time_step
def loop_body(counter, time_step, policy_state):
  """Runs a step in the environment; tf.while_loop calls this body repeatedly.

  Args:
    counter: Episode counters per batch index. Shape [batch_size].
    time_step: TimeStep tuple with elements shape [batch_size, ...].
    policy_state: Policy state tensor of shape [batch_size, policy_state_dim].
      Pass empty tuple for non-recurrent policies.

  Returns:
    loop_vars for the next iteration of tf.while_loop.
  """
  action_step = self._policy.action(time_step, policy_state)
  policy_state = action_step.state
  next_time_step = self._env.step(action_step.action)

  traj = trajectory.from_transition(time_step, action_step, next_time_step)
  observer_ops = [observer(traj) for observer in self._observers]
  with tf.control_dependencies([tf.group(observer_ops)]):
    time_step, next_time_step, policy_state = nest.map_structure(
        tf.identity, (time_step, next_time_step, policy_state))

  # The while-loop counter is incremented only on boundary (episode reset)
  # steps.
  counter += tf.cast(traj.is_boundary(), dtype=tf.int32)

  return [counter, next_time_step, policy_state]
def run(self, time_step, policy_state=()):
  """Run policy in environment given initial time_step and policy_state.

  Args:
    time_step: The initial time_step.
    policy_state: The initial policy_state.

  Returns:
    A tuple (final time_step, final policy_state).
  """
  num_steps = 0
  num_episodes = 0
  while num_steps < self._max_steps and num_episodes < self._max_episodes:
    action_step = self.policy.action(time_step, policy_state)
    next_time_step = self.env.step(action_step.action)

    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    for observer in self.observers:
      observer(traj)

    num_episodes += np.sum(traj.is_last())
    num_steps += np.sum(~traj.is_boundary())

    time_step = next_time_step
    policy_state = action_step.state

  return time_step, policy_state
def _setup_specs(self):
  self._policy_step_spec = policy_step.PolicyStep(
      action=self._action_spec,
      state=self._policy_state_spec,
      info=self._info_spec)
  self._trajectory_spec = trajectory.from_transition(
      self._time_step_spec, self._policy_step_spec, self._time_step_spec)
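Note that trajectory.from_transition maps specs the same way it maps concrete values: given a (TimeStep, PolicyStep, TimeStep) triple of specs it returns a Trajectory spec. A minimal standalone sketch of that pattern, assuming the TF-Agents tensor_spec, ts, policy_step, and trajectory modules are imported as in the surrounding snippets; the observation and action specs here are made-up placeholders:

# Sketch only: a 4-element float observation and a binary action, both
# hypothetical, just to show spec-in/spec-out behavior.
observation_spec = tensor_spec.TensorSpec((4,), tf.float32, 'obs')
time_step_spec = ts.time_step_spec(observation_spec)
action_spec = tensor_spec.BoundedTensorSpec(
    (), tf.int32, minimum=0, maximum=1, name='action')
policy_step_spec = policy_step.PolicyStep(action=action_spec)
trajectory_spec = trajectory.from_transition(
    time_step_spec, policy_step_spec, time_step_spec)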
def _fill_replay_buffer(self):
  # Generate N frames: the value of pixels is the frame index.
  # The observations will be generated by stacking K frames out of those N,
  # generating some redundancies between the observations.
  single_frames = []
  frame_count = 100
  for k in range(frame_count):
    single_frames.append(np.full(self._single_shape, k, dtype=np.int32))

  # Add stacks of frames to the replay buffer.
  time_steps = []
  for k in range(len(single_frames) - self._stack_count + 1):
    observation = np.concatenate(
        single_frames[k:k + self._stack_count], axis=-1)
    time_steps.append(ts.transition(observation, reward=0.0))

  self._transition_count = len(time_steps) - 1
  dummy_action = policy_step.PolicyStep(np.int32(0))
  for k in range(self._transition_count):
    self._replay_buffer.add_batch(
        nest_utils.batch_nested_array(
            trajectory.from_transition(
                time_steps[k], dummy_action, time_steps[k + 1])))
def collect_step(environment, policy):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)
  return traj
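Since this variant only builds and returns the trajectory, the caller decides where it goes. A minimal sketch of a collection loop built on it; `environment`, `policy`, `buffer`, and the step budget are assumptions, with `buffer` standing in for any replay buffer whose data_spec was built with trajectory.from_transition as in the other snippets:

# Sketch: drive collect_step for a fixed budget and store each trajectory.
# All four names below are placeholders, not part of the original snippet.
def collect_data(environment, policy, buffer, num_collect_steps=100):
  for _ in range(num_collect_steps):
    buffer.add_batch(collect_step(environment, policy))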
def collect_step(environment, policy):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)

  # Add trajectory to the replay buffer.
  # (`replay_buffer` is assumed to be defined in the enclosing scope.)
  replay_buffer.add_batch(traj)
def collect_step(env, time_step, py_policy, replay_buffer):
  """Steps the environment and collects experience into the replay buffer."""
  action_step = py_policy.action(time_step)
  next_time_step = env.step(action_step.action)
  if not time_step.is_last():
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    replay_buffer.add_batch(traj)
  return next_time_step
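A sketch of driving this py-policy variant: the caller threads the returned time_step back in, and no explicit mid-loop reset is needed because TF-Agents PyEnvironments start a new episode on the step after a LAST time step, which is exactly the boundary transition the function above declines to store. `env`, `py_policy`, `replay_buffer`, and the budget are assumptions:

# Sketch: collect a fixed number of steps; all names are placeholders.
num_collect_steps = 1000  # assumed collection budget
time_step = env.reset()
for _ in range(num_collect_steps):
  time_step = collect_step(env, time_step, py_policy, replay_buffer)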
def make_replay_buffer(tf_env):
  """Default replay buffer factory."""
  time_step_spec = tf_env.time_step_spec()
  action_step_spec = policy_step.PolicyStep(
      tf_env.action_spec(), (), tensor_spec.TensorSpec((), tf.int32))
  trajectory_spec = trajectory.from_transition(
      time_step_spec, action_step_spec, time_step_spec)
  return tf_uniform_replay_buffer.TFUniformReplayBuffer(
      trajectory_spec, batch_size=1)
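A hedged usage sketch for this factory: because the buffer's data_spec mirrors trajectory.from_transition of the environment's specs, a driver can push batched trajectories straight into it via add_batch. Here `tf_env` and `collect_policy` are assumed to exist (and the policy's info must match the int32 scalar info spec declared above); dynamic_step_driver is the TF-Agents module of the same name:

# Sketch: pair the buffer with a step driver; names are placeholders.
replay_buffer = make_replay_buffer(tf_env)
collect_driver = dynamic_step_driver.DynamicStepDriver(
    tf_env,
    collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=100)
collect_driver.run()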
def _make_ppo_trajectory_spec(self, action_distribution_params_spec):
  # Make policy_step_spec with action_spec, empty tuple for policy_state, and
  # (act_log_prob_spec, value_pred_spec, action_distribution_params_spec) for
  # info.
  policy_step_spec = policy_step.PolicyStep(
      action=self.action_spec(),
      state=self._policy.policy_state_spec(),
      info=action_distribution_params_spec)
  trajectory_spec = trajectory.from_transition(
      self.time_step_spec(), policy_step_spec, self.time_step_spec())
  return trajectory_spec
def _initial_collect(self):
  """Collect initial experience before training begins."""
  logging.info('Collecting initial experience...')
  time_step_spec = ts.time_step_spec(self._env.observation_spec())
  random_policy = random_py_policy.RandomPyPolicy(
      time_step_spec, self._env.action_spec())
  time_step = self._env.reset()
  while self._replay_buffer.size < self._initial_collect_steps:
    if self.game_over():
      time_step = self._env.reset()
    action_step = random_policy.action(time_step)
    next_time_step = self._env.step(action_step.action)
    self._replay_buffer.add_batch(trajectory.from_transition(
        time_step, action_step, next_time_step))
    time_step = next_time_step
  logging.info('Done.')
def _create_replay_buffer(self, rb_cls):
  self._stack_count = 4
  self._single_shape = (15, 15, 1)
  shape = (15, 15, self._stack_count)
  observation_spec = array_spec.ArraySpec(shape, np.int32, 'obs')
  time_step_spec = ts.time_step_spec(observation_spec)
  action_spec = policy_step.PolicyStep(
      array_spec.BoundedArraySpec(
          shape=(), dtype=np.int32, minimum=0, maximum=1, name='action'))
  self._trajectory_spec = trajectory.from_transition(
      time_step_spec, action_spec, time_step_spec)
  self._capacity = 32
  self._replay_buffer = rb_cls(
      data_spec=self._trajectory_spec, capacity=self._capacity)
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    fc_layer_params=(100,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.contrib.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  # TODO(kbanoop): Figure out if it is possible to avoid the with block.
  with tf.contrib.summary.record_summaries_every_n_global_steps(
      summary_interval):

    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    trajectory_spec = trajectory.from_transition(
        time_step=tf_env.time_step_spec(),
        action_step=policy_step.PolicyStep(action=tf_env.action_spec()),
        next_time_step=tf_env.time_step_spec())
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=trajectory_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    q_net = q_network.QNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=fc_layer_params)

    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate),
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars)
    tf_agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy()
    collect_policy = tf_agent.collect_policy()

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration)

    global_step = tf.train.get_or_create_global_step()

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    # Collect initial replay data.
    tf.logging.info(
        'Initializing replay buffer by collecting experience for %d steps '
        'with a random policy.' % initial_collect_steps)
    dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=initial_collect_steps).run()

    metrics = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(metrics, global_step.numpy())

    time_step = None
    policy_state = ()

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience, train_step_counter=global_step)
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        tf.logging.info('step = %d, loss = %f', global_step.numpy(),
                        train_loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        tf.logging.info('%.3f steps/sec' % steps_per_sec)
        tf.contrib.summary.scalar(name='global_steps/sec',
                                  tensor=steps_per_sec)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(step_metrics=train_metrics[:2])

      if global_step.numpy() % eval_interval == 0:
        metrics = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(metrics, global_step.numpy())
    return train_loss
def metric_observer(time_step, action, next_time_step, policy_state):
  action_step = policy_step.PolicyStep(action, policy_state, ())
  traj = trajectory.from_transition(time_step, action_step, next_time_step)
  return metric(traj)
def __init__(
    self,
    root_dir,
    env_name,
    num_iterations=200,
    max_episode_frames=108000,  # ALE frames
    terminal_on_life_loss=False,
    conv_layer_params=((32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)),
    fc_layer_params=(512,),
    # Params for collect
    initial_collect_steps=80000,  # ALE frames
    epsilon_greedy=0.01,
    epsilon_decay_period=1000000,  # ALE frames
    replay_buffer_capacity=1000000,
    # Params for train
    train_steps_per_iteration=1000000,  # ALE frames
    update_period=16,  # ALE frames
    target_update_tau=1.0,
    target_update_period=32000,  # ALE frames
    batch_size=32,
    learning_rate=2.5e-4,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    do_eval=True,
    eval_steps_per_iteration=500000,  # ALE frames
    eval_epsilon_greedy=0.001,
    # Params for checkpoints, summaries, and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple Atari train and eval for DQN.

  Args:
    root_dir: Directory to write log files to.
    env_name: Fully-qualified name of the Atari environment (e.g. Pong-v0).
    num_iterations: Number of train/eval iterations to run.
    max_episode_frames: Maximum length of a single episode, in ALE frames.
    terminal_on_life_loss: Whether to simulate an episode termination when a
      life is lost.
    conv_layer_params: Params for convolutional layers of QNetwork.
    fc_layer_params: Params for fully connected layers of QNetwork.
    initial_collect_steps: Number of ALE frames to process before beginning
      to train. Since this is in ALE frames, there will be
      initial_collect_steps/4 items in the RB when training starts.
    epsilon_greedy: Final epsilon value to decay to for training.
    epsilon_decay_period: Period over which to decay epsilon, from 1.0 to
      epsilon_greedy (defined above).
    replay_buffer_capacity: Maximum number of items to store in the RB.
    train_steps_per_iteration: Number of ALE frames to run through for each
      iteration of training.
    update_period: Run a train operation every update_period ALE frames.
    target_update_tau: Coefficient for soft target network updates (1.0 ==
      hard updates).
    target_update_period: Period, in ALE frames, to copy the live network to
      the target network.
    batch_size: Number of frames to include in each training batch.
    learning_rate: RMS optimizer learning rate.
    gamma: Discount for future rewards.
    reward_scale_factor: Scaling factor for rewards.
    gradient_clipping: Norm length to clip gradients.
    do_eval: If True, run an eval every iteration. If False, skip eval.
    eval_steps_per_iteration: Number of ALE frames to run through for each
      iteration of evaluation.
    eval_epsilon_greedy: Epsilon value to use for the evaluation policy (0 ==
      totally greedy policy).
    log_interval: Log stats to the terminal every log_interval training
      steps.
    summary_interval: Write TF summaries every summary_interval training
      steps.
    summaries_flush_secs: Flush summaries to disk every summaries_flush_secs
      seconds.
    debug_summaries: If True, write additional summaries for debugging (see
      dqn_agent for which summaries are written).
    summarize_grads_and_vars: Include gradients in summaries.
    eval_metrics_callback: A callback function that takes (metric_dict,
      global_step) as parameters. Called after every eval with the results
      of the evaluation.
""" self._update_period = update_period / ATARI_FRAME_SKIP self._train_steps_per_iteration = (train_steps_per_iteration / ATARI_FRAME_SKIP) self._do_eval = do_eval self._eval_steps_per_iteration = eval_steps_per_iteration / ATARI_FRAME_SKIP self._eval_epsilon_greedy = eval_epsilon_greedy self._initial_collect_steps = initial_collect_steps / ATARI_FRAME_SKIP self._summary_interval = summary_interval self._num_iterations = num_iterations self._log_interval = log_interval self._eval_metrics_callback = eval_metrics_callback with gin.unlock_config(): gin.bind_parameter('AtariPreprocessing.terminal_on_life_loss', terminal_on_life_loss) root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() self._train_summary_writer = train_summary_writer self._eval_summary_writer = None if self._do_eval: eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) self._eval_summary_writer = eval_summary_writer self._eval_metrics = [ py_metrics.AverageReturnMetric(name='PhaseAverageReturn', buffer_size=np.inf), py_metrics.AverageEpisodeLengthMetric( name='PhaseAverageEpisodeLength', buffer_size=np.inf), ] self._global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if(lambda: tf.math.equal( self._global_step % self._summary_interval, 0)): self._env = suite_atari.load( env_name, max_episode_steps=max_episode_frames / ATARI_FRAME_SKIP, gym_env_wrappers=suite_atari. DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING) self._env = batched_py_environment.BatchedPyEnvironment( [self._env]) observation_spec = tensor_spec.from_spec( self._env.observation_spec()) time_step_spec = ts.time_step_spec(observation_spec) action_spec = tensor_spec.from_spec(self._env.action_spec()) with tf.device('/cpu:0'): epsilon = tf.compat.v1.train.polynomial_decay( 1.0, self._global_step, epsilon_decay_period / ATARI_FRAME_SKIP / self._update_period, end_learning_rate=epsilon_greedy) with tf.device('/gpu:0'): optimizer = tf.compat.v1.train.RMSPropOptimizer( learning_rate=learning_rate, decay=0.95, momentum=0.0, epsilon=0.00001, centered=True) q_net = AtariQNetwork(observation_spec, action_spec, conv_layer_params=conv_layer_params, fc_layer_params=fc_layer_params) tf_agent = dqn_agent.DqnAgent( time_step_spec, action_spec, q_network=q_net, optimizer=optimizer, epsilon_greedy=epsilon, target_update_tau=target_update_tau, target_update_period=(target_update_period / ATARI_FRAME_SKIP / self._update_period), td_errors_loss_fn=dqn_agent.element_wise_huber_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=self._global_step) self._collect_policy = py_tf_policy.PyTFPolicy( tf_agent.collect_policy) if self._do_eval: self._eval_policy = py_tf_policy.PyTFPolicy( epsilon_greedy_policy.EpsilonGreedyPolicy( policy=tf_agent.policy, epsilon=self._eval_epsilon_greedy)) py_observation_spec = self._env.observation_spec() py_time_step_spec = ts.time_step_spec(py_observation_spec) py_action_spec = policy_step.PolicyStep( self._env.action_spec()) data_spec = trajectory.from_transition(py_time_step_spec, py_action_spec, py_time_step_spec) self._replay_buffer = ( py_hashed_replay_buffer.PyHashedReplayBuffer( 
data_spec=data_spec, capacity=replay_buffer_capacity)) with tf.device('/cpu:0'): ds = self._replay_buffer.as_dataset( sample_batch_size=batch_size, num_steps=2).prefetch(4) # TODO(b/123242430): Add prefetch_to_device back here in order to # improve performance once errors are resolved. self._ds_itr = tf.compat.v1.data.make_one_shot_iterator(ds) experience = self._ds_itr.get_next() with tf.device('/gpu:0'): self._train_op = tf_agent.train(experience) self._env_steps_metric = py_metrics.EnvironmentSteps() self._step_metrics = [ py_metrics.NumberOfEpisodes(), self._env_steps_metric, ] self._train_metrics = self._step_metrics + [ py_metrics.AverageReturnMetric(buffer_size=10), py_metrics.AverageEpisodeLengthMetric(buffer_size=10), ] # The _train_phase_metrics average over an entire train iteration, # rather than the rolling average of the last 10 episodes. self._train_phase_metrics = [ py_metrics.AverageReturnMetric(name='PhaseAverageReturn', buffer_size=np.inf), py_metrics.AverageEpisodeLengthMetric( name='PhaseAverageEpisodeLength', buffer_size=np.inf), ] self._iteration_metric = py_metrics.CounterMetric( name='Iteration') # Summaries written from python should run every time they are # generated. with tf.compat.v2.summary.record_if(True): self._steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') self._steps_per_second_summary = tf.contrib.summary.scalar( name='global_steps/sec', tensor=self._steps_per_second_ph) for metric in self._train_metrics: metric.tf_summaries(step_metrics=self._step_metrics) for metric in self._train_phase_metrics: metric.tf_summaries( step_metrics=(self._iteration_metric, )) self._iteration_metric.tf_summaries() if self._do_eval: with eval_summary_writer.as_default(): for metric in self._eval_metrics: metric.tf_summaries( step_metrics=(self._iteration_metric, )) self._train_checkpointer = common_utils.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=self._global_step, optimizer=optimizer, metrics=metric_utils.MetricsGroup( self._train_metrics + self._train_phase_metrics + [self._iteration_metric], 'train_metrics')) self._policy_checkpointer = common_utils.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=tf_agent.policy, global_step=self._global_step) self._rb_checkpointer = common_utils.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=self._replay_buffer) self._init_agent_op = tf_agent.initialize()