def load_env(mode='default'):
    env = None
    if mode == 'default':
        env = suite_gym.load(config.ENV_NAME)
    elif mode == 'nonskipping':
        max_episode_steps = 27000  # <=> 108k ALE frames since 1 step = 4 frames
        env = suite_atari.load(
            config.ENV_NONSKIPPING_NAME,
            max_episode_steps=max_episode_steps,
            gym_env_wrappers=[AtariPreprocessing, FrameStack4])
    return env
def _env(self, cfg):
    """create a tf_env"""
    if cfg["lib"] == "gym":
        return tf_py_environment.TFPyEnvironment(suite_gym.load(cfg["name"]))
    elif cfg["lib"] == "atari":
        return tf_py_environment.TFPyEnvironment(
            suite_atari.load(
                cfg["name"],
                max_episode_steps=50000,
                gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING))
    else:
        raise NotImplementedError
# ---------------------Parser----------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--mode', choices=['train', 'test'], default='train')
# parser.add_argument('--env-name', type=str, default='BreakoutDeterministic-v4')
# parser.add_argument('--weights', type=str, default=None)
args = parser.parse_args()

# -------------------Environment-------------------------------
# env_name = 'BreakoutDeterministic-v4'
env_name = 'Pong-v0'
DEFAULT_ATARI_GYM_WRAPPERS = (atari_preprocessing.AtariPreprocessing,)
DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING = DEFAULT_ATARI_GYM_WRAPPERS + (
    atari_wrappers.FrameStack4,)

env = suite_atari.load(
    env_name, gym_env_wrappers=DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
env.reset()

print('Observation Spec:')
print(env.time_step_spec().observation)
print('Reward Spec:')
print(env.time_step_spec().reward)
print('Action Spec:')
print(env.action_spec())

# train_py_env = suite_gym.load(env_name)
# eval_py_env = suite_gym.load(env_name)
# train_env = tf_py_environment.TFPyEnvironment(train_py_env)
# eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
def __init__(
        self,
        root_dir,
        env_name,
        num_iterations=200,
        max_episode_frames=108000,  # ALE frames
        terminal_on_life_loss=False,
        conv_layer_params=((32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)),
        fc_layer_params=(512,),
        # Params for collect
        initial_collect_steps=80000,  # ALE frames
        epsilon_greedy=0.01,
        epsilon_decay_period=1000000,  # ALE frames
        replay_buffer_capacity=1000000,
        # Params for train
        train_steps_per_iteration=1000000,  # ALE frames
        update_period=16,  # ALE frames
        target_update_tau=1.0,
        target_update_period=32000,  # ALE frames
        batch_size=32,
        learning_rate=2.5e-4,
        n_step_update=2,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        do_eval=True,
        eval_steps_per_iteration=500000,  # ALE frames
        eval_epsilon_greedy=0.001,
        # Params for checkpoints, summaries, and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=True,
        summarize_grads_and_vars=True,
        eval_metrics_callback=None):
    """A simple Atari train and eval for DQN.

    Args:
        root_dir: Directory to write log files to.
        env_name: Fully-qualified name of the Atari environment (e.g. Pong-v0).
        num_iterations: Number of train/eval iterations to run.
        max_episode_frames: Maximum length of a single episode, in ALE frames.
        terminal_on_life_loss: Whether to simulate an episode termination when
            a life is lost.
        conv_layer_params: Params for convolutional layers of QNetwork.
        fc_layer_params: Params for fully connected layers of QNetwork.
        initial_collect_steps: Number of ALE frames to process before beginning
            to train. Since this is in ALE frames, there will be
            initial_collect_steps/4 items in the replay buffer when training
            starts.
        epsilon_greedy: Final epsilon value to decay to for training.
        epsilon_decay_period: Period over which to decay epsilon, from 1.0 to
            epsilon_greedy (defined above).
        replay_buffer_capacity: Maximum number of items to store in the replay
            buffer.
        train_steps_per_iteration: Number of ALE frames to run through for each
            iteration of training.
        update_period: Run a train operation every update_period ALE frames.
        target_update_tau: Coefficient for soft target network updates
            (1.0 == hard updates).
        target_update_period: Period, in ALE frames, to copy the live network
            to the target network.
        batch_size: Number of frames to include in each training batch.
        learning_rate: RMS optimizer learning rate.
        n_step_update: The number of steps to consider when computing TD error
            and TD loss. Applies standard single-step updates when set to 1.
        gamma: Discount for future rewards.
        reward_scale_factor: Scaling factor for rewards.
        gradient_clipping: Norm length to clip gradients.
        do_eval: If True, run an eval every iteration. If False, skip eval.
        eval_steps_per_iteration: Number of ALE frames to run through for each
            iteration of evaluation.
        eval_epsilon_greedy: Epsilon value to use for the evaluation policy
            (0 == totally greedy policy).
        log_interval: Log stats to the terminal every log_interval training
            steps.
        summary_interval: Write TF summaries every summary_interval training
            steps.
        summaries_flush_secs: Flush summaries to disk every
            summaries_flush_secs seconds.
        debug_summaries: If True, write additional summaries for debugging
            (see dqn_agent for which summaries are written).
        summarize_grads_and_vars: Include gradients in summaries.
        eval_metrics_callback: A callback function that takes
            (metric_dict, global_step) as parameters. Called after every eval
            with the results of the evaluation.
""" self._update_period = update_period / ATARI_FRAME_SKIP self._train_steps_per_iteration = (train_steps_per_iteration / ATARI_FRAME_SKIP) self._do_eval = do_eval self._eval_steps_per_iteration = eval_steps_per_iteration / ATARI_FRAME_SKIP self._eval_epsilon_greedy = eval_epsilon_greedy self._initial_collect_steps = initial_collect_steps / ATARI_FRAME_SKIP self._summary_interval = summary_interval self._num_iterations = num_iterations self._log_interval = log_interval self._eval_metrics_callback = eval_metrics_callback with gin.unlock_config(): gin.bind_parameter(('tf_agents.environments.atari_preprocessing.' 'AtariPreprocessing.terminal_on_life_loss'), terminal_on_life_loss) root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() self._train_summary_writer = train_summary_writer self._eval_summary_writer = None if self._do_eval: self._eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) self._eval_metrics = [ py_metrics.AverageReturnMetric(name='PhaseAverageReturn', buffer_size=np.inf), py_metrics.AverageEpisodeLengthMetric( name='PhaseAverageEpisodeLength', buffer_size=np.inf), ] self._global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if(lambda: tf.math.equal( self._global_step % self._summary_interval, 0)): self._env = suite_atari.load( env_name, max_episode_steps=max_episode_frames / ATARI_FRAME_SKIP, gym_env_wrappers=suite_atari. DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING) self._env = batched_py_environment.BatchedPyEnvironment( [self._env]) observation_spec = tensor_spec.from_spec( self._env.observation_spec()) time_step_spec = ts.time_step_spec(observation_spec) action_spec = tensor_spec.from_spec(self._env.action_spec()) with tf.device('/cpu:0'): epsilon = tf.compat.v1.train.polynomial_decay( 1.0, self._global_step, epsilon_decay_period / ATARI_FRAME_SKIP / self._update_period, end_learning_rate=epsilon_greedy) with tf.device('/gpu:0'): optimizer = tf.compat.v1.train.RMSPropOptimizer( learning_rate=learning_rate, decay=0.95, momentum=0.0, epsilon=0.00001, centered=True) categorical_q_net = AtariCategoricalQNetwork( observation_spec, action_spec, conv_layer_params=conv_layer_params, fc_layer_params=fc_layer_params) agent = categorical_dqn_agent.CategoricalDqnAgent( time_step_spec, action_spec, categorical_q_network=categorical_q_net, optimizer=optimizer, epsilon_greedy=epsilon, n_step_update=n_step_update, target_update_tau=target_update_tau, target_update_period=(target_update_period / ATARI_FRAME_SKIP / self._update_period), gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=self._global_step) self._collect_policy = py_tf_policy.PyTFPolicy( agent.collect_policy) if self._do_eval: self._eval_policy = py_tf_policy.PyTFPolicy( epsilon_greedy_policy.EpsilonGreedyPolicy( policy=agent.policy, epsilon=self._eval_epsilon_greedy)) py_observation_spec = self._env.observation_spec() py_time_step_spec = ts.time_step_spec(py_observation_spec) py_action_spec = policy_step.PolicyStep( self._env.action_spec()) data_spec = trajectory.from_transition(py_time_step_spec, py_action_spec, py_time_step_spec) self._replay_buffer = 
py_hashed_replay_buffer.PyHashedReplayBuffer( data_spec=data_spec, capacity=replay_buffer_capacity) with tf.device('/cpu:0'): ds = self._replay_buffer.as_dataset( sample_batch_size=batch_size, num_steps=n_step_update + 1) ds = ds.prefetch(4) ds = ds.apply( tf.data.experimental.prefetch_to_device('/gpu:0')) with tf.device('/gpu:0'): self._ds_itr = tf.compat.v1.data.make_one_shot_iterator(ds) experience = self._ds_itr.get_next() self._train_op = agent.train(experience) self._env_steps_metric = py_metrics.EnvironmentSteps() self._step_metrics = [ py_metrics.NumberOfEpisodes(), self._env_steps_metric, ] self._train_metrics = self._step_metrics + [ py_metrics.AverageReturnMetric(buffer_size=10), py_metrics.AverageEpisodeLengthMetric(buffer_size=10), ] # The _train_phase_metrics average over an entire train iteration, # rather than the rolling average of the last 10 episodes. self._train_phase_metrics = [ py_metrics.AverageReturnMetric(name='PhaseAverageReturn', buffer_size=np.inf), py_metrics.AverageEpisodeLengthMetric( name='PhaseAverageEpisodeLength', buffer_size=np.inf), ] self._iteration_metric = py_metrics.CounterMetric( name='Iteration') # Summaries written from python should run every time they are # generated. with tf.compat.v2.summary.record_if(True): self._steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') self._steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=self._steps_per_second_ph, step=self._global_step) for metric in self._train_metrics: metric.tf_summaries(train_step=self._global_step, step_metrics=self._step_metrics) for metric in self._train_phase_metrics: metric.tf_summaries( train_step=self._global_step, step_metrics=(self._iteration_metric, )) self._iteration_metric.tf_summaries( train_step=self._global_step) if self._do_eval: with self._eval_summary_writer.as_default(): for metric in self._eval_metrics: metric.tf_summaries( train_step=self._global_step, step_metrics=(self._iteration_metric, )) self._train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=agent, global_step=self._global_step, optimizer=optimizer, metrics=metric_utils.MetricsGroup( self._train_metrics + self._train_phase_metrics + [self._iteration_metric], 'train_metrics')) self._policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=agent.policy, global_step=self._global_step) self._rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=self._replay_buffer) self._init_agent_op = agent.initialize()
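# The __init__ above expresses most hyperparameters in ALE frames and divides
# them by ATARI_FRAME_SKIP to get agent (collect) steps, because
# AtariPreprocessing repeats each action for 4 ALE frames. A minimal sketch of
# that conversion (the helper name ale_frames_to_steps is illustrative and not
# part of the original code):
ATARI_FRAME_SKIP = 4


def ale_frames_to_steps(num_ale_frames, frame_skip=ATARI_FRAME_SKIP):
    """Convert a budget given in ALE frames into agent (collect) steps."""
    return num_ale_frames // frame_skip


# 108k ALE frames per episode <=> 27000 agent steps, matching the comment in
# the first snippet of this collection.
assert ale_frames_to_steps(108000) == 27000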
def _setup(self):
    self.env = suite_atari.load(
        self.environment_name,
        max_episode_steps=self.max_episode_steps,
        gym_env_wrappers=[AtariPreprocessing, FrameStack4])
initial_policy = 0
if len(sys.argv) == 2:
    initial_policy = int(sys.argv[1])

policy = None
if initial_policy != 0:
    policy = load_policy_from_disk(initial_policy)

logging.getLogger().setLevel(logging.INFO)

max_episode_steps = 27000
environment_name = "BreakoutNoFrameskip-v4"
env = suite_atari.load(environment_name,
                       max_episode_steps=max_episode_steps,
                       gym_env_wrappers=[AtariPreprocessing, FrameStack4])

tf.random.set_seed(42)
np.random.seed(42)
env.seed(42)

tf_env = TFPyEnvironment(env)

# for i in range(max_episode_steps):
#     timestep = tf_env.step(np.array(np.random.choice([0, 1, 2, 3]), dtype=np.int32))
#     if timestep.is_last():
#         print("game over", i)
#         break
#     tf_env.render(mode="human")
#     time.sleep(0.2)
def train_eval(
        root_dir,
        env_name='Pong-v0',
        # Training params
        update_frequency=4,  # Number of collect steps per policy update
        initial_collect_steps=50000,  # 50k collect steps
        num_iterations=50000000,  # 50M collect steps
        # Taken from Rainbow as it's not specified in Mnih,15.
        max_episode_frames_collect=50000,  # env frames observed by the agent
        max_episode_frames_eval=108000,  # env frames observed by the agent
        # Agent params
        epsilon_greedy=0.1,
        epsilon_decay_period=250000,  # 1M collect steps / update_frequency
        batch_size=32,
        learning_rate=0.00025,
        n_step_update=1,
        gamma=0.99,
        target_update_tau=1.0,
        target_update_period=2500,  # 10k collect steps / update_frequency
        reward_scale_factor=1.0,
        # Replay params
        reverb_port=None,
        replay_capacity=1000000,
        # Others
        policy_save_interval=250000,
        eval_interval=1000,
        eval_episodes=30,
        debug_summaries=True):
    """Trains and evaluates DQN."""
    collect_env = suite_atari.load(
        env_name,
        max_episode_steps=max_episode_frames_collect,
        gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
    eval_env = suite_atari.load(
        env_name,
        max_episode_steps=max_episode_frames_eval,
        gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)

    unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()

    num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1
    epsilon = tf.compat.v1.train.polynomial_decay(
        1.0,
        train_step,
        epsilon_decay_period,
        end_learning_rate=epsilon_greedy)
    agent = dqn_agent.DqnAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        q_network=create_q_network(num_actions),
        epsilon_greedy=epsilon,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.RMSPropOptimizer(
            learning_rate=learning_rate,
            decay=0.95,
            momentum=0.95,
            epsilon=0.01,
            centered=True),
        td_errors_loss_fn=common.element_wise_huber_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        train_step_counter=train_step,
        debug_summaries=debug_summaries)

    table_name = 'uniform_table'
    table = reverb.Table(
        table_name,
        max_size=replay_capacity,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1))
    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)

    dataset = reverb_replay.as_dataset(
        sample_batch_size=batch_size, num_steps=2).prefetch(3)
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()

    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=100),
    ]

    dqn_learner = learner.Learner(
        root_dir,
        train_step,
        agent,
        experience_dataset_fn,
        triggers=learning_triggers)

    # If we haven't trained yet make sure we collect some random samples first
    # to fill up the Replay Buffer with some experience.
    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(
        collect_env,
        random_policy,
        train_step,
        steps_per_run=initial_collect_steps,
        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_collect_policy, use_tf_function=True)

    collect_actor = actor.Actor(
        collect_env,
        collect_policy,
        train_step,
        steps_per_run=update_frequency,
        observers=[rb_observer, env_step_metric],
        metrics=actor.collect_metrics(10),
        reference_metrics=[env_step_metric],
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
    )

    tf_greedy_policy = agent.policy
    greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        reference_metrics=[env_step_metric],
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        dqn_learner.run(iterations=1)

        if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
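# A hypothetical invocation of the train_eval defined above (the directory and
# the reduced num_iterations are placeholder values for illustration, not
# values taken from the original script):
if __name__ == '__main__':
    train_eval(root_dir='/tmp/dqn_atari_pong', num_iterations=100000)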
def testAtariObsSpec(self):
    env = suite_atari.load('Pong-v0')
    self.assertIsInstance(env, py_environment.Base)
    self.assertEqual(np.uint8, env.observation_spec().dtype)
    self.assertEqual((84, 84, 1), env.observation_spec().shape)
def testAtariActionSpec(self):
    env = suite_atari.load('Pong-v0')
    self.assertIsInstance(env, py_environment.Base)
    self.assertEqual(np.int64, env.action_spec().dtype)
    self.assertEqual((), env.action_spec().shape)
def testAtariEnvRegistered(self):
    env = suite_atari.load('Pong-v0')
    self.assertIsInstance(env, py_environment.Base)
    self.assertIsInstance(env, atari_wrappers.AtariTimeLimit)
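# A hedged companion to the tests above (not part of the original test file;
# the method name is illustrative): with FrameStack4 added on top of
# AtariPreprocessing via DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING, the
# observation spec gains a 4-frame channel axis.
def testAtariStackedObsSpec(self):
    env = suite_atari.load(
        'Pong-v0',
        gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
    self.assertIsInstance(env, py_environment.Base)
    self.assertEqual(np.uint8, env.observation_spec().dtype)
    self.assertEqual((84, 84, 4), env.observation_spec().shape)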
video_interval = 10000
fc_layer_params = (256, 256)

# -------------------Environment-------------------------------
# env_name = 'Pong-ram-v0'
# env_name = 'Breakout-ram-v0'
# env_name = 'BreakoutNoFrameskip-v4'
# env_name = 'CartPole-v0'
env_name = suite_atari.game('Breakout', 'ram')
# print(env_name)
env = suite_atari.load(env_name,
                       max_episode_steps=1000,
                       gym_env_wrappers=[atari_wrappers.FireOnReset])
# print(env)
# env.reset()
# print('Observation Spec:')
# print(env.time_step_spec().observation)
# print('Action Spec:')
# print(env.action_spec())

# # train_py_env = suite_gym.load(env_name)
# # eval_py_env = suite_gym.load(env_name)
train_py_env = suite_atari.load(env_name,
                                max_episode_steps=10000,
                                gym_env_wrappers=[atari_wrappers.FireOnReset])
eval_py_env = suite_atari.load(env_name,
                               max_episode_steps=10000,
                               gym_env_wrappers=[atari_wrappers.FireOnReset])
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=1000000,
        actor_fc_layers=(100,),
        critic_obs_fc_layers=(64,),
        critic_action_fc_layers=(100,),
        critic_joint_fc_layers=(100,),
        # Params for collect
        initial_collect_steps=50,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=5000,
        # Params for summaries and logging
        train_checkpoint_interval=1000,
        policy_checkpoint_interval=2000,
        rb_checkpoint_interval=3000,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1000,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC."""
    global interrupted
    dir_key = root_dir
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        env_name = 'Pong-v0'
        # train_env = T4TFEnv(metrics_key=dir_key)
        eval_env = suite_atari.load(env_name)
        train_env = ActionLoggerWrapper(env=suite_atari.load(env_name))
        train_env.reset()
        eval_env.reset()

        tf_env = tf_py_environment.TFPyEnvironment(train_env)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_env)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            # conv_layer_params=((32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)),
            conv_layer_params=[(8, (3, 3), 1)],
            # lstm_size=(40,),
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_conv_layer_params=[(8, (3, 3), 1)],
            # observation_conv_layer_params=((32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)),
            # lstm_size=(40,),
            observation_fc_layer_params=None,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)
        print(actor_net)
        print(critic_net)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        for _ in range(num_iterations):
            if interrupted:
                train_env.interrupted = True
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                if interrupted:
                    train_env.interrupted = True
                experience, _ = next(iterator)
                train_loss = tf_agent.train(experience)
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                # actor_net.save(os.path.join(root_dir, 'actor'), save_format='tf')
                # critic_net.save(os.path.join(root_dir, 'critic'), save_format='tf')
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                obs_shape = time_step.observation.shape
                print(obs_shape)
                tf.compat.v2.summary.image(
                    name='input_image',
                    data=np.reshape(
                        time_step.observation,
                        (1, obs_shape[1], obs_shape[2], obs_shape[3])),
                    step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

            global_step_val = global_step.numpy()
            print('global step: %d' % global_step_val)
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)

    return train_loss
def main(_):
    # Environment
    env_name = "Breakout-v4"
    train_num_parallel_environments = 5
    max_steps_per_episode = 1000
    # Replay buffer
    replay_buffer_capacity = 50000
    init_replay_buffer = 500
    # Driver
    collect_steps_per_iteration = 1 * train_num_parallel_environments
    # Training
    train_batch_size = 32
    train_iterations = 100000
    train_summary_interval = 200
    train_checkpoint_interval = 200
    # Evaluation
    eval_num_parallel_environments = 5
    eval_summary_interval = 500
    eval_num_episodes = 20
    # File paths
    path = pathlib.Path(__file__)
    parent_dir = path.parent.resolve()
    folder_name = path.stem + time.strftime("_%Y%m%d_%H%M%S")
    train_checkpoint_dir = str(parent_dir / folder_name / "train_checkpoint")
    train_summary_dir = str(parent_dir / folder_name / "train_summary")
    eval_summary_dir = str(parent_dir / folder_name / "eval_summary")

    # Parallel training environment
    tf_env = TFPyEnvironment(
        ParallelPyEnvironment([
            lambda: suite_atari.load(
                env_name,
                env_wrappers=[
                    lambda env: TimeLimit(env, duration=max_steps_per_episode)
                ],
                gym_env_wrappers=[AtariPreprocessing, FrameStack4],
            )
        ] * train_num_parallel_environments))
    tf_env.seed([42] * tf_env.batch_size)
    tf_env.reset()

    # Parallel evaluation environment
    eval_tf_env = TFPyEnvironment(
        ParallelPyEnvironment([
            lambda: suite_atari.load(
                env_name,
                env_wrappers=[
                    lambda env: TimeLimit(env, duration=max_steps_per_episode)
                ],
                gym_env_wrappers=[AtariPreprocessing, FrameStack4],
            )
        ] * eval_num_parallel_environments))
    eval_tf_env.seed([42] * eval_tf_env.batch_size)
    eval_tf_env.reset()

    # Creating the Deep Q-Network
    preprocessing_layer = keras.layers.Lambda(
        lambda obs: tf.cast(obs, np.float32) / 255.)
    conv_layer_params = [(32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)]
    fc_layer_params = [512]

    q_net = QNetwork(tf_env.observation_spec(),
                     tf_env.action_spec(),
                     preprocessing_layers=preprocessing_layer,
                     conv_layer_params=conv_layer_params,
                     fc_layer_params=fc_layer_params)

    # Creating the DQN Agent
    optimizer = keras.optimizers.RMSprop(lr=2.5e-4,
                                         rho=0.95,
                                         momentum=0.0,
                                         epsilon=0.00001,
                                         centered=True)

    epsilon_fn = keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=1.0,  # initial ε
        decay_steps=2500000,
        end_learning_rate=0.01)  # final ε

    global_step = tf.compat.v1.train.get_or_create_global_step()

    agent = DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        target_update_period=200,
        td_errors_loss_fn=keras.losses.Huber(reduction="none"),
        gamma=0.99,  # discount factor
        train_step_counter=global_step,
        epsilon_greedy=lambda: epsilon_fn(global_step))
    agent.initialize()

    # Creating the Replay Buffer
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    # Observer: Replay Buffer Observer
    replay_buffer_observer = replay_buffer.add_batch

    # Observer: Training Metrics
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(batch_size=tf_env.batch_size),
    ]

    # Creating the Collect Driver
    collect_driver = DynamicStepDriver(
        tf_env,
        agent.collect_policy,
        observers=[replay_buffer_observer] + train_metrics,
        num_steps=collect_steps_per_iteration)

    # Initialize replay buffer
    initial_collect_policy = RandomTFPolicy(tf_env.time_step_spec(),
                                            tf_env.action_spec())
    init_driver = DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer_observer, ShowProgress()],
        num_steps=init_replay_buffer)
    final_time_step, final_policy_state = init_driver.run()

    # Creating the Dataset
    dataset = replay_buffer.as_dataset(sample_batch_size=train_batch_size,
                                       num_steps=2,
                                       num_parallel_calls=3).prefetch(3)

    # Optimize by wrapping some of the code in a graph using TF function.
    collect_driver.run = function(collect_driver.run)
    agent.train = function(agent.train)

    print("\n\n++++++++++++++++++++++++++++++++++\n")

    # Create checkpoint
    train_checkpointer = Checkpointer(
        ckpt_dir=train_checkpoint_dir,
        max_to_keep=1,
        agent=agent,
        # replay_buffer=replay_buffer,
        global_step=global_step,
        # metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')
    )

    # Restore checkpoint
    # train_checkpointer.initialize_or_restore()

    # Summary writers and metrics
    train_summary_writer = tf.summary.create_file_writer(train_summary_dir)
    eval_summary_writer = tf.summary.create_file_writer(eval_summary_dir)
    eval_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(batch_size=eval_tf_env.batch_size,
                                       buffer_size=eval_num_episodes),
        tf_metrics.AverageEpisodeLengthMetric(batch_size=eval_tf_env.batch_size,
                                              buffer_size=eval_num_episodes)
    ]

    # Create evaluate callback function
    eval_callback = evaluate(eval_metrics=eval_metrics,
                             eval_tf_env=eval_tf_env,
                             eval_policy=agent.policy,
                             eval_num_episodes=eval_num_episodes,
                             train_step=global_step,
                             eval_summary_writer=eval_summary_writer)

    # Train agent
    train_agent(tf_env=tf_env,
                train_iterations=train_iterations,
                global_step=global_step,
                agent=agent,
                dataset=dataset,
                collect_driver=collect_driver,
                train_metrics=train_metrics,
                train_checkpointer=train_checkpointer,
                train_checkpoint_interval=train_checkpoint_interval,
                train_summary_writer=train_summary_writer,
                train_summary_interval=train_summary_interval,
                eval_summary_interval=eval_summary_interval,
                eval_callback=eval_callback)

    print("\n\n++++++++++ END OF TF_AGENTS RL TRAINING ++++++++++\n\n")
def breakout_v4(seed=42):
    env = suite_gym.load("Breakout-v4")
    env.seed(seed)
    env.reset()

    repeating_env = ActionRepeat(env, times=4)
    for name in dir(tf_agents.environments.wrappers):
        obj = getattr(tf_agents.environments.wrappers, name)
        if hasattr(obj, "__base__") and issubclass(
                obj, tf_agents.environments.wrappers.PyEnvironmentBaseWrapper):
            print("{:27s} {}".format(name, obj.__doc__.split("\n")[0]))

    limited_repeating_env = suite_gym.load(
        "Breakout-v4",
        gym_env_wrappers=[partial(TimeLimit, max_episode_steps=10000)],
        env_wrappers=[partial(ActionRepeat, times=4)],
    )

    max_episode_steps = 27000  # <=> 108k ALE frames since 1 step = 4 frames
    environment_name = "BreakoutNoFrameskip-v4"

    env = suite_atari.load(
        environment_name,
        max_episode_steps=max_episode_steps,
        gym_env_wrappers=[AtariPreprocessing, FrameStack4],
    )
    env.seed(42)
    env.reset()
    time_step = env.step(np.array(1))  # FIRE
    for _ in range(4):
        time_step = env.step(np.array(3))  # LEFT

    def plot_observation(obs):
        # Since there are only 3 color channels, you cannot display 4 frames
        # with one primary color per frame. So this code computes the delta
        # between the current frame and the mean of the other frames, and it
        # adds this delta to the red and blue channels to get a pink color for
        # the current frame.
        obs = obs.astype(np.float32)
        img_ = obs[..., :3]
        current_frame_delta = np.maximum(
            obs[..., 3] - obs[..., :3].mean(axis=-1), 0.0)
        img_[..., 0] += current_frame_delta
        img_[..., 2] += current_frame_delta
        img_ = np.clip(img_ / 150, 0, 1)
        plt.imshow(img_)
        plt.axis("off")

    plt.figure(figsize=(6, 6))
    plot_observation(time_step.observation)
    plt.tight_layout()
    plt.savefig("./images/preprocessed_breakout_plot.png", format="png", dpi=300)
    plt.show()

    tf_env = TFPyEnvironment(env)

    preprocessing_layer = keras.layers.Lambda(
        lambda obs: tf.cast(obs, np.float32) / 255.0)
    conv_layer_params = [(32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)]
    fc_layer_params = [512]

    q_net = QNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        preprocessing_layers=preprocessing_layer,
        conv_layer_params=conv_layer_params,
        fc_layer_params=fc_layer_params,
    )

    # see TF-agents issue #113
    # optimizer = keras.optimizers.RMSprop(lr=2.5e-4, rho=0.95, momentum=0.0,
    #                                      epsilon=0.00001, centered=True)
    train_step = tf.Variable(0)
    update_period = 4  # run a training step every 4 collect steps
    optimizer = tf.compat.v1.train.RMSPropOptimizer(learning_rate=2.5e-4,
                                                    decay=0.95,
                                                    momentum=0.0,
                                                    epsilon=0.00001,
                                                    centered=True)
    epsilon_fn = keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=1.0,  # initial ε
        decay_steps=250000 // update_period,  # <=> 1,000,000 ALE frames
        end_learning_rate=0.01,
    )  # final ε
    agent = DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        target_update_period=2000,  # <=> 32,000 ALE frames
        td_errors_loss_fn=keras.losses.Huber(reduction="none"),
        gamma=0.99,  # discount factor
        train_step_counter=train_step,
        epsilon_greedy=lambda: epsilon_fn(train_step),
    )
    agent.initialize()

    from tf_agents.replay_buffers import tf_uniform_replay_buffer

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=1000000)

    replay_buffer_observer = replay_buffer.add_batch

    class ShowProgress:
        def __init__(self, total):
            self.counter = 0
            self.total = total

        def __call__(self, trajectory):
            if not trajectory.is_boundary():
                self.counter += 1
            if self.counter % 100 == 0:
                print("\r{}/{}".format(self.counter, self.total), end="")

    from tf_agents.metrics import tf_metrics

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    from tf_agents.eval.metric_utils import log_metrics
    import logging

    logging.getLogger().setLevel(logging.INFO)
    log_metrics(train_metrics)

    from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver

    collect_driver = DynamicStepDriver(
        tf_env,
        agent.collect_policy,
        observers=[replay_buffer_observer] + train_metrics,
        num_steps=update_period,
    )  # collect 4 steps for each training iteration

    from tf_agents.policies.random_tf_policy import RandomTFPolicy

    initial_collect_policy = RandomTFPolicy(tf_env.time_step_spec(),
                                            tf_env.action_spec())
    init_driver = DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch, ShowProgress(20000)],
        num_steps=20000,
    )  # <=> 80,000 ALE frames
    final_time_step, final_policy_state = init_driver.run()
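    # The example above stops after the initial random collection. A hedged
    # sketch of the training loop that typically follows, mirroring the
    # dataset/iterator pattern used by the other snippets in this collection
    # (the batch size of 64 and n_iterations are illustrative, not values from
    # the original code):
    from tf_agents.utils.common import function

    dataset = replay_buffer.as_dataset(sample_batch_size=64,
                                       num_steps=2,
                                       num_parallel_calls=3).prefetch(3)
    collect_driver.run = function(collect_driver.run)
    agent.train = function(agent.train)

    def train_agent(n_iterations):
        time_step = None
        policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)
        iterator = iter(dataset)
        for iteration in range(n_iterations):
            time_step, policy_state = collect_driver.run(time_step, policy_state)
            trajectories, buffer_info = next(iterator)
            train_loss = agent.train(trajectories)
            print("\r{} loss: {:.5f}".format(iteration, train_loss.loss.numpy()),
                  end="")

    train_agent(n_iterations=10000)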