def _make_env():
    """Create a TF environment wrapping the continuous MountainCar task."""
    return tf_py_environment.TFPyEnvironment(
        suite_gym.load("MountainCarContinuous-v0"))
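# Hedged usage sketch for _make_env (not from the original snippet): inspect
# the wrapped environment's specs. Assumes the same tf_agents imports as above.
env = _make_env()
print(env.time_step_spec().observation)  # 2-d observation: [position, velocity]
print(env.action_spec())                 # 1-d continuous action in [-1, 1]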
train_py_env = gym_wrapper.GymWrapper(
    ChangeRewardMountainCarEnv(),
    discount=1,
    spec_dtype_map=None,
    auto_reset=True,
    render_kwargs=None,
)
eval_py_env = gym_wrapper.GymWrapper(
    ChangeRewardMountainCarEnv(),
    discount=1,
    spec_dtype_map=None,
    auto_reset=True,
    render_kwargs=None,
)
train_py_env = wrappers.TimeLimit(train_py_env, duration=200)
eval_py_env = wrappers.TimeLimit(eval_py_env, duration=200)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

RL_train(train_env, eval_env, fc_layer_params=(48, 64), name='_train')

# Setting num_iterations to 50000+ will let the agent converge to fewer than
# 110 steps.
iterations = range(len(returns))
plt.plot(iterations, returns)
plt.ylabel('Average Return')
plt.xlabel('Iterations')

iterations = range(len(steps))
plt.plot(iterations, steps)
plt.ylabel('Average Step')
plt.xlabel('Iterations')
def main(unused_argv):
    tf.compat.v1.enable_v2_behavior()  # The trainer only runs with V2 enabled.

    with tf.device('/CPU:0'):  # due to b/128333994
        if FLAGS.normalize_reward_fns:
            action_reward_fns = (
                environment_utilities.normalized_sliding_linear_reward_fn_generator(
                    CONTEXT_DIM, NUM_ACTIONS, REWARD_NOISE_VARIANCE))
        else:
            action_reward_fns = (
                environment_utilities.sliding_linear_reward_fn_generator(
                    CONTEXT_DIM, NUM_ACTIONS, REWARD_NOISE_VARIANCE))

        env = sspe.StationaryStochasticPyEnvironment(
            functools.partial(
                environment_utilities.context_sampling_fn,
                batch_size=BATCH_SIZE,
                context_dim=CONTEXT_DIM),
            action_reward_fns,
            batch_size=BATCH_SIZE)
        environment = tf_py_environment.TFPyEnvironment(env)

        optimal_reward_fn = functools.partial(
            environment_utilities.tf_compute_optimal_reward,
            per_action_reward_fns=action_reward_fns)
        optimal_action_fn = functools.partial(
            environment_utilities.tf_compute_optimal_action,
            per_action_reward_fns=action_reward_fns)

        network = q_network.QNetwork(
            input_tensor_spec=environment.time_step_spec().observation,
            action_spec=environment.action_spec(),
            fc_layer_params=LAYERS)

        if FLAGS.agent == 'LinUCB':
            agent = lin_ucb_agent.LinearUCBAgent(
                time_step_spec=environment.time_step_spec(),
                action_spec=environment.action_spec(),
                alpha=AGENT_ALPHA,
                dtype=tf.float32)
        elif FLAGS.agent == 'LinTS':
            agent = lin_ts_agent.LinearThompsonSamplingAgent(
                time_step_spec=environment.time_step_spec(),
                action_spec=environment.action_spec(),
                alpha=AGENT_ALPHA,
                dtype=tf.float32)
        elif FLAGS.agent == 'epsGreedy':
            agent = neural_epsilon_greedy_agent.NeuralEpsilonGreedyAgent(
                time_step_spec=environment.time_step_spec(),
                action_spec=environment.action_spec(),
                reward_network=network,
                optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=LR),
                epsilon=EPSILON)
        elif FLAGS.agent == 'Mix':
            emit_policy_info = policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
            agent_linucb = lin_ucb_agent.LinearUCBAgent(
                time_step_spec=environment.time_step_spec(),
                action_spec=environment.action_spec(),
                emit_policy_info=emit_policy_info,
                alpha=AGENT_ALPHA,
                dtype=tf.float32)
            agent_lints = lin_ts_agent.LinearThompsonSamplingAgent(
                time_step_spec=environment.time_step_spec(),
                action_spec=environment.action_spec(),
                emit_policy_info=emit_policy_info,
                alpha=AGENT_ALPHA,
                dtype=tf.float32)
            agent_epsgreedy = neural_epsilon_greedy_agent.NeuralEpsilonGreedyAgent(
                time_step_spec=environment.time_step_spec(),
                action_spec=environment.action_spec(),
                reward_network=network,
                optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=LR),
                emit_policy_info=emit_policy_info,
                epsilon=EPSILON)
            agent = exp3_mixture_agent.Exp3MixtureAgent(
                (agent_linucb, agent_lints, agent_epsgreedy))

        regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward_fn)
        suboptimal_arms_metric = tf_bandit_metrics.SuboptimalArmsMetric(
            optimal_action_fn)

        trainer.train(
            root_dir=FLAGS.root_dir,
            agent=agent,
            environment=environment,
            training_loops=TRAINING_LOOPS,
            steps_per_loop=STEPS_PER_LOOP,
            additional_metrics=[regret_metric, suboptimal_arms_metric])
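# Illustrative sketch (not part of the trainer above): RegretMetric tracks the
# gap between the oracle's reward and the agent's realized reward. A
# hypothetical helper making that explicit, reusing optimal_reward_fn as wired
# above:
def per_batch_regret(observation, agent_reward):
    # Regret = reward of the best arm minus the reward actually obtained.
    optimal_reward = optimal_reward_fn(observation)
    return tf.reduce_mean(optimal_reward - agent_reward)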
def main():
    env = SquigglesEnvironment(num_notes=2)
    env = tf_py_environment.TFPyEnvironment(env)

    N = env.observation_spec().shape[0]
    _, the_hits, actions = get_beats(N, ITER, env, policy_saved_filename)

    fpsClock = pygame.time.Clock()
    pygame.init()
    DISPLAY = pygame.display.set_mode((WIDTH, HEIGHT))
    pygame.display.set_caption("Squigs")

    # Here are different sounds to use:
    #
    # Glockenspiel:
    #   "sound_effects/19827__cabled-mess__glockenspiel/348882__cabled-mess__glockenspiel-18-g3-04.wav"
    #   "sound_effects/19827__cabled-mess__glockenspiel/348889__cabled-mess__glockenspiel-23-a3-05.wav"
    #   "sound_effects/19827__cabled-mess__glockenspiel/348895__cabled-mess__glockenspiel-24-bb3-01.wav"
    #   "sound_effects/19827__cabled-mess__glockenspiel/348904__cabled-mess__glockenspiel-29-b3-02.wav"
    #   "sound_effects/19827__cabled-mess__glockenspiel/348914__cabled-mess__glockenspiel-39-d4-04.wav"
    #   "sound_effects/19827__cabled-mess__glockenspiel/348918__cabled-mess__glockenspiel-40-e4-01.wav"
    #   "sound_effects/19827__cabled-mess__glockenspiel/348921__cabled-mess__glockenspiel-43-f4-01.wav"
    #   "sound_effects/19827__cabled-mess__glockenspiel/348870__cabled-mess__glockenspiel-04-d3-04.wav"
    #   "sound_effects/19827__cabled-mess__glockenspiel/348871__cabled-mess__glockenspiel-06-e3-01.wav"
    #   "sound_effects/19827__cabled-mess__glockenspiel/348878__cabled-mess__glockenspiel-11-f3-02.wav"
    #   "sound_effects/19827__cabled-mess__glockenspiel/348908__cabled-mess__glockenspiel-33-c4-02.wav"
    #
    # Metallophone:
    #   "sound_effects/9008__jamieblam__metallophone/146077__jamieblam__1d-hard.wav"
    #   "sound_effects/9008__jamieblam__metallophone/146079__jamieblam__1c-hard.wav"
    #   "sound_effects/9008__jamieblam__metallophone/146096__jamieblam__2e-hard.wav"
    #   "sound_effects/9008__jamieblam__metallophone/146100__jamieblam__2f-hard.wav"
    #   "sound_effects/9008__jamieblam__metallophone/146082__jamieblam__1f-hard.wav"
    #   "sound_effects/9008__jamieblam__metallophone/146091__jamieblam__2c-hard.wav"
    #   "sound_effects/9008__jamieblam__metallophone/146093__jamieblam__2b-hard.wav"
    #
    # Marimba:
    #   "sound_effects/21030__samulis__vsco-2-ce-percussion-marimba/373577__samulis__marimba-b3-marimba-hit-outrigger-b2-loud-01.wav"
    #   "sound_effects/21030__samulis__vsco-2-ce-percussion-marimba/373582__samulis__marimba-e-2-marimba-hit-outrigger-f1-loud-01.wav"

    env_slider = SoundSlider(
        sound_list=the_hits,
        position_x=0,
        position_y=HEIGHT // 3,
        height=HEIGHT // 7,
        width=WIDTH,
        color=(100, 100, 255),
        soundfile_name=[
            # "sound_effects/9008__jamieblam__metallophone/146079__jamieblam__1c-hard.wav"
            # "sound_effects/9008__jamieblam__metallophone/146097__jamieblam__2d-hard.wav"
            "sound_effects/drum11.wav"
        ])
    agent_slider = SoundSlider(
        sound_list=actions,
        position_x=0,
        position_y=HEIGHT * 2 // 3,
        height=HEIGHT // 7,
        width=WIDTH,
        color=(255, 150, 30),
        soundfile_name=[
            # "sound_effects/9008__jamieblam__metallophone/146084__jamieblam__1e-hard.wav"
            # "sound_effects/9008__jamieblam__metallophone/146087__jamieblam__1g-hard.wav"
            "sound_effects/first_clap.wav"
        ])
    barrier = SoundBarrier(
        position_x=WIDTH * 2 // 3,
        position_y=HEIGHT // 4,
        height=HEIGHT * 5 // 8,
        width=WIDTH // 56,
        color=(255, 100, 100),
        slider_list=[env_slider, agent_slider])

    start = False
    while True:
        DISPLAY.fill((0, 0, 0))
        pygame.event.pump()
        for event in pygame.event.get():
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_SPACE:
                    start = True
            if event.type == QUIT:
                pygame.quit()
                sys.exit()
        if start:
            env_slider.update()
            agent_slider.update()
            barrier.update()
        env_slider.render(DISPLAY)
        agent_slider.render(DISPLAY)
        barrier.render(DISPLAY)
        pygame.display.update()
        fpsClock.tick(FPS)
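# Conventional entry point for the pygame loop above (assumed; the original
# snippet does not show it):
if __name__ == '__main__':
    main()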
def main(_):
    tf.random.set_seed(FLAGS.seed)

    if FLAGS.models_dir is None:
        raise ValueError('You must set a value for models_dir.')

    env = suite_mujoco.load(FLAGS.env_name)
    env.seed(FLAGS.seed)
    env = tf_py_environment.TFPyEnvironment(env)

    sac = actor_lib.Actor(env.observation_spec().shape[0], env.action_spec())
    model_filename = os.path.join(FLAGS.models_dir, 'DM-' + FLAGS.env_name,
                                  str(FLAGS.model_seed), '1000000')
    sac.load_weights(model_filename)

    if FLAGS.std is None:
        if 'Reacher' in FLAGS.env_name:
            std = 0.5
        elif 'Ant' in FLAGS.env_name:
            std = 0.4
        elif 'Walker' in FLAGS.env_name:
            std = 2.0
        else:
            std = 0.75
    else:
        std = FLAGS.std

    def get_action(state):
        _, action, log_prob = sac(state, std)
        return action, log_prob

    dataset = dict(
        model_filename=model_filename,
        behavior_std=std,
        trajectories=dict(
            states=[],
            actions=[],
            log_probs=[],
            next_states=[],
            rewards=[],
            masks=[]))

    for i in range(FLAGS.num_episodes):
        timestep = env.reset()
        trajectory = dict(
            states=[],
            actions=[],
            log_probs=[],
            next_states=[],
            rewards=[],
            masks=[])

        while not timestep.is_last():
            action, log_prob = get_action(timestep.observation)
            next_timestep = env.step(action)

            trajectory['states'].append(timestep.observation)
            trajectory['actions'].append(action)
            trajectory['log_probs'].append(log_prob)
            trajectory['next_states'].append(next_timestep.observation)
            trajectory['rewards'].append(next_timestep.reward)
            trajectory['masks'].append(next_timestep.discount)

            timestep = next_timestep

        for k, v in trajectory.items():
            dataset['trajectories'][k].append(tf.concat(v, 0).numpy())
        logging.info('%d trajectories', i + 1)

    data_save_dir = os.path.join(FLAGS.save_dir, FLAGS.env_name,
                                 str(FLAGS.model_seed))
    if not tf.io.gfile.isdir(data_save_dir):
        tf.io.gfile.makedirs(data_save_dir)

    save_filename = os.path.join(data_save_dir, f'dualdice_{FLAGS.std}.pckl')
    with tf.io.gfile.GFile(save_filename, 'wb') as f:
        pickle.dump(dataset, f)
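# Hedged sketch: reading one of the trajectory pickles saved above back into
# memory. The path mirrors the save pattern; "<save_dir>" etc. are
# placeholders for your FLAGS values.
import pickle
import tensorflow as tf

with tf.io.gfile.GFile('<save_dir>/<env_name>/<model_seed>/dualdice_0.5.pckl',
                       'rb') as f:
    saved = pickle.load(f)
print(saved['behavior_std'], len(saved['trajectories']['states']))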
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        eval_env_name=None,
        env_load_fn=suite_mujoco.load,
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create the environment.
        tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        eval_env_name = eval_env_name or env_name
        eval_py_env = env_load_fn(eval_env_name)

        # Get the data specs from the environment.
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        eval_py_policy = py_tf_policy.PyTFPolicy(
            greedy_policy.GreedyPolicy(tf_agent.policy))

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=5 * batch_size,
            num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(
                    batch_size * 5)
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()
        train_op = tf_agent.train(trajectories)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(
                    train_step=global_step, step_metrics=train_metrics[:2]))

        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        with tf.compat.v1.Session() as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)

            # Initialize training.
            sess.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(sess)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            global_step_val = sess.run(global_step)

            if global_step_val == 0:
                # Initial eval of randomly initialized policy.
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    callback=eval_metrics_callback,
                    log=True,
                )
                sess.run(eval_summary_flush_op)

                # Run initial collect.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Checkpoint the initial replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)
                logging.info('Finished initial collect.')
            else:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    total_loss, _ = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (global_step_val - timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(
                        steps_per_second_summary,
                        feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_flush_op)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
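# Why _filter_invalid_transition exists, as a standalone sketch: sampling
# num_steps=2 windows from the replay buffer can straddle an episode boundary,
# and training on such a pair would mix transitions from unrelated episodes.
# The filter keeps a window only when its first step is not a boundary. A
# minimal, assumed illustration of the same predicate:
def keep_window(trajectories):
    # trajectories.is_boundary() has shape [2]; index 0 is the first step of
    # the sampled window.
    return ~trajectories.is_boundary()[0]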
def testPyenv(self):
    py_env = PYEnvironmentMock()
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    self.assertIsInstance(tf_env.pyenv,
                          batched_py_environment.BatchedPyEnvironment)
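# The same check outside a test harness, as a hedged sketch: TFPyEnvironment
# auto-wraps a single py environment in a BatchedPyEnvironment of batch 1.
env = tf_py_environment.TFPyEnvironment(PYEnvironmentMock())
assert isinstance(env.pyenv, batched_py_environment.BatchedPyEnvironment)
assert env.batch_size == 1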
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        eval_env_name=None,
        num_iterations=1000000,
        # Params for networks.
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100,),
        actor_lstm_size=(40,),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300,),
        critic_output_fc_layers=(100,),
        critic_lstm_size=(40,),
        num_parallel_environments=1,
        # Params for collect
        initial_collect_episodes=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        train_sequence_length=20,
        critic_learning_rate=3e-4,
        actor_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=_DEFAULT_REWARD_SCALE,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for RNN SAC on DM control."""
    root_dir = os.path.expanduser(root_dir)

    if reward_scale_factor == _DEFAULT_REWARD_SCALE:
        # Use value recommended by https://arxiv.org/abs/1801.01290
        if env_name.startswith('Humanoid'):
            reward_scale_factor = 20.0
        else:
            reward_scale_factor = 5.0

    summary_writer = tf.compat.v2.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []

        env_load_fn = functools.partial(
            suite_dm_control.load,
            task_name=task_name,
            env_wrappers=env_wrappers)

        if num_parallel_environments == 1:
            py_env = env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)

        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_rnn_network.CriticRnnNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size * num_parallel_environments,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
        average_return = tf_metrics.AverageReturnMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix='Train'),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='Train',
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=initial_collect_episodes)
        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d '
                'episodes with a random policy.', initial_collect_episodes)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            start_env_steps = env_steps.result()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            episode_steps = env_steps.result() - start_env_steps
            for _ in range(episode_steps):
                for _ in range(train_steps_per_iteration):
                    train_step()
                time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('env steps = %d, average return = %f',
                             env_steps.result(), average_return.result())
                env_steps_per_sec = (env_steps.result().numpy() -
                                     env_steps_before) / time_acc
                logging.info('%.3f env steps/sec', env_steps_per_sec)
                tf.compat.v2.summary.scalar(
                    name='env_steps_per_sec',
                    data=env_steps_per_sec,
                    step=env_steps.result())
                time_acc = 0
                env_steps_before = env_steps.result().numpy()

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=env_steps.result())

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=env_steps.result(),
                    summary_writer=summary_writer,
                    summary_prefix='Eval',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, env_steps.result().numpy())
                metric_utils.log_metrics(eval_metrics)

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
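# Hypothetical invocation sketch for the RNN SAC entry point above; the root
# directory and iteration count are placeholders, not values from the source.
train_eval(
    root_dir='/tmp/rnn_sac_cartpole',
    env_name='cartpole',
    task_name='balance',
    num_iterations=1000)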
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100,),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        agent_class=dqn_agent.DqnAgent,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_py_env = suite_gym.load(env_name)

        q_net = q_network.QNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=fc_layer_params)

        tf_agent = agent_class(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate),
            # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec(),
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        global_step = tf.train.get_or_create_global_step()

        replay_observer = [replay_buffer.add_batch]
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer,
            num_steps=initial_collect_steps).run()

        collect_policy = tf_agent.collect_policy()
        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=2).prefetch(3)
        iterator = dataset.make_initializable_iterator()
        trajectories, _ = iterator.get_next()
        train_op = tf_agent.train(
            experience=trajectories, train_step_counter=global_step)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=tf.contrib.checkpoint.List(train_metrics))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy(),
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])
        summary_op = tf.contrib.summary.all_summary_ops()

        with eval_summary_writer.as_default(), \
                tf.contrib.summary.always_record_summaries():
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            tf.contrib.summary.initialize(session=sess)
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(
                [train_op, summary_op, global_step])

            timed_at_step = sess.run(global_step)
            collect_time = 0
            train_time = 0
            steps_per_second_ph = tf.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                # Train/collect/eval.
                start_time = time.time()
                collect_call()
                collect_time += time.time() - start_time

                start_time = time.time()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _, global_step_val = train_step_call()
                train_time += time.time() - start_time

                if global_step_val % log_interval == 0:
                    tf.logging.info('step = %d, loss = %f', global_step_val,
                                    loss_info_value.loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    sess.run(
                        steps_per_second_summary,
                        feed_dict={steps_per_second_ph: steps_per_sec})
                    tf.logging.info('%.3f steps/sec' % steps_per_sec)
                    tf.logging.info('collect_time = {}, train_time = {}'.format(
                        collect_time, train_time))
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
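# Hypothetical invocation sketch: agent_class is pluggable, so the same loop
# can train a double-DQN variant. The root directory is a placeholder.
train_eval(
    root_dir='/tmp/dqn_cartpole',
    env_name='CartPole-v0',
    agent_class=dqn_agent.DdqnAgent)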
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        num_iterations=100000,
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100,),
        actor_lstm_size=(40,),
        critic_obs_fc_layers=(400,),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300,),
        critic_output_fc_layers=(100,),
        critic_lstm_size=(40,),
        # Params for collect
        initial_collect_steps=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=200,
        batch_size=64,
        train_sequence_length=10,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        gamma=0.995,
        reward_scale_factor=1.0,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=10000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        eval_metrics_callback=None):
    """A simple train and eval for TD3."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []
        environment = suite_dm_control.load(
            env_name, task_name, env_wrappers=env_wrappers)
        tf_env = tf_py_environment.TFPyEnvironment(environment)
        eval_py_env = suite_dm_control.load(
            env_name, task_name, env_wrappers=env_wrappers)

        actor_net = actor_rnn_network.ActorRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers)

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())
        critic_net = critic_rnn_network.CriticRnnNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
        )

        tf_agent = td3_agent.Td3Agent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            debug_summaries=debug_summaries,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # TODO(oars): Refactor drivers to better handle policy states. Remove
        # the policy reset and passing down an empty policy state to the
        # driver.
        collect_policy = tf_agent.collect_policy
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps).run(policy_state=policy_state)

        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run(
                policy_state=policy_state)

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)
        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        trajectories, unused_info = iterator.get_next()
        train_op = tf_agent.train(experience=trajectories)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])

        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(train_op)
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val - timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(
                        steps_per_second_summary,
                        feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
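# Hedged sketch of the Ornstein-Uhlenbeck exploration noise that the
# ou_stddev and ou_damping parameters above configure: temporally correlated
# noise that decays toward zero. One common discretization (an assumption for
# illustration, not TF-Agents' exact implementation):
import numpy as np

def ou_step(x, damping=0.15, stddev=0.2):
    # Next noise sample: damped previous value plus fresh Gaussian noise.
    return (1.0 - damping) * x + np.random.normal(scale=stddev, size=np.shape(x))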
print(env.action_spec())

time_step = env.reset()
print('Time step:')
print(time_step)

action = np.array(1, dtype=np.int32)

next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

# AGENT
fc_layer_params = (100, 50)
action_tensor_spec = tensor_spec.from_spec(env.action_spec())
num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

# Define a helper function to create Dense layers configured with the right
# activation and kernel initializer.
def dense_layer(num_units):
    return tf.keras.layers.Dense(
        num_units,
        activation=tf.keras.activations.relu,
        kernel_initializer=tf.keras.initializers.VarianceScaling(
            scale=2.0, mode='fan_in', distribution='truncated_normal'))
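# Hedged continuation sketch, mirroring the standard TF-Agents DQN tutorial
# this snippet appears to follow: stack the dense layers plus a final logits
# layer into a sequential Q-network. `sequential` is assumed to be
# tf_agents.networks.sequential; the initializer constants are from the
# tutorial, not this snippet.
dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
q_values_layer = tf.keras.layers.Dense(
    num_actions,
    activation=None,
    kernel_initializer=tf.keras.initializers.RandomUniform(
        minval=-0.03, maxval=0.03),
    bias_initializer=tf.keras.initializers.Constant(-0.2))
q_net = sequential.Sequential(dense_layers + [q_values_layer])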
def main():
    parser = argparse.ArgumentParser()

    # Essential parameters
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model stats and checkpoints "
        "will be written.")
    parser.add_argument(
        "--env",
        default=None,
        type=str,
        required=True,
        help="The environment to train the agent on")
    parser.add_argument(
        "--approx_env_boundaries",
        default=False,
        type=bool,
        help="Whether to get the env boundaries approximately")
    parser.add_argument("--max_horizon", default=5, type=int)
    parser.add_argument(
        "--atari",
        default=False,
        type=bool,
        help="Gets some data types correctly")

    # Agent parameters
    parser.add_argument("--reward_scale_factor", default=1.0, type=float)
    parser.add_argument("--debug_summaries", default=True, type=bool)
    parser.add_argument("--summarize_grads_and_vars", default=True, type=bool)

    # Transformer parameters
    parser.add_argument("--d_model", default=64, type=int)
    parser.add_argument("--num_layers", default=3, type=int)
    parser.add_argument("--dff", default=256, type=int)

    # Training parameters
    parser.add_argument(
        '--num_iterations', type=int, default=150000, help="steps in the env")
    parser.add_argument(
        '--num_iparallel',
        type=int,
        default=1,
        help="how many envs should run in parallel")
    parser.add_argument("--collect_steps_per_iteration", default=1, type=int)
    parser.add_argument("--train_steps_per_iteration", default=1, type=int)

    # Other parameters
    parser.add_argument("--num_eval_episodes", default=10, type=int)
    parser.add_argument("--eval_interval", default=1000, type=int)
    parser.add_argument("--log_interval", default=1000, type=int)
    parser.add_argument("--summary_interval", default=10, type=int)
    parser.add_argument("--run_graph_mode", default=True, type=bool)
    parser.add_argument("--checkpoint_interval", default=10000, type=int)
    parser.add_argument("--summary_flush", default=10,
                        type=int)  # What does this exactly do?

    # Hyperparameter-optimization parameters
    parser.add_argument(
        "--doubleQ",
        default=True,
        type=bool,
        help="Whether to use a DoubleQ agent")
    parser.add_argument("--custom_last_layer", default=True, type=bool)
    parser.add_argument("--custom_layer_init", default=0.5, type=float)
    parser.add_argument("--initial_collect_steps", default=1000, type=int)
    parser.add_argument(
        "--loss_function", default="element_wise_huber_loss", type=str)
    parser.add_argument("--num_heads", default=4, type=int)
    parser.add_argument("--normalize_env", default=False, type=bool)
    parser.add_argument(
        '--custom_lr_schedule',
        default="No",
        type=str,
        help="whether to use a custom LR schedule")
    parser.add_argument("--epsilon_greedy", default=0.1, type=float)
    parser.add_argument("--target_update_period", default=1, type=int)
    # Dropout rate (may be unused, depending on the Q-network). Setting this
    # to 0.0 somehow breaks the code; not relevant, just select a network
    # without dropout.
    parser.add_argument("--rate", default=0.1, type=float)
    parser.add_argument("--gradient_clipping", default=None, type=bool)
    parser.add_argument("--replay_buffer_max_length", default=100000, type=int)
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--learning_rate", default=1e-5, type=float)
    parser.add_argument(
        "--encoder_type",
        default=2,
        type=int,
        help="Which type of encoder is used for the model")
    parser.add_argument(
        "--layer_type",
        default=1,
        type=int,
        help="Which type of layer is used for the encoder")
    parser.add_argument("--target_update_tau", default=1, type=float)
    parser.add_argument("--gamma", default=1.0, type=float)

    args = parser.parse_args()

    # List of encoder modules which we can use to change encoder based on a
    # variable.
    global_step = tf.compat.v1.train.get_or_create_global_step()

    baseEnv = gym.make(args.env)
    env = suite_gym.load(args.env)
    eval_env = suite_gym.load(args.env)
    if args.normalize_env:
        env = NormalizeWrapper(env, args.approx_env_boundaries, args.env)
        eval_env = NormalizeWrapper(eval_env, args.approx_env_boundaries,
                                    args.env)
    env = PyhistoryWrapper(env, args.max_horizon, args.atari)
    eval_env = PyhistoryWrapper(eval_env, args.max_horizon, args.atari)
    tf_env = tf_py_environment.TFPyEnvironment(env)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_env)

    q_net = QTransformer(
        tf_env.observation_spec(),
        baseEnv.action_space.n,
        num_layers=args.num_layers,
        d_model=args.d_model,
        num_heads=args.num_heads,
        dff=args.dff,
        rate=args.rate,
        encoderType=args.encoder_type,
        enc_layer_type=args.layer_type,
        max_horizon=args.max_horizon,
        custom_layer=args.custom_layer_init,
        custom_last_layer=args.custom_last_layer)

    if args.custom_lr_schedule == "Transformer":
        # Build an LR schedule following the original transformer usage.
        learning_rate = CustomSchedule(args.d_model,
                                       int(args.num_iterations / 10))
        optimizer = tf.keras.optimizers.Adam(
            learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
    elif args.custom_lr_schedule == "Transformer_low":
        # Same schedule with a lower general learning rate.
        learning_rate = CustomSchedule(
            int(args.d_model / 2), int(args.num_iterations / 10))
        optimizer = tf.keras.optimizers.Adam(
            learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
    elif args.custom_lr_schedule == "Linear":
        lrs = LinearCustomSchedule(learning_rate, args.num_iterations)
        optimizer = tf.keras.optimizers.Adam(
            lrs, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
    else:
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=args.learning_rate)

    if args.loss_function == "element_wise_huber_loss":
        lf = element_wise_huber_loss
    elif args.loss_function == "element_wise_squared_loss":
        lf = element_wise_squared_loss

    if not args.doubleQ:
        agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=args.epsilon_greedy,
            target_update_tau=args.target_update_tau,
            target_update_period=args.target_update_period,
            td_errors_loss_fn=lf,
            optimizer=optimizer,
            gamma=args.gamma,
            reward_scale_factor=args.reward_scale_factor,
            gradient_clipping=args.gradient_clipping,
            debug_summaries=args.debug_summaries,
            summarize_grads_and_vars=args.summarize_grads_and_vars,
            train_step_counter=global_step)
    else:
        agent = dqn_agent.DdqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=args.epsilon_greedy,
            target_update_tau=args.target_update_tau,
            td_errors_loss_fn=lf,
            target_update_period=args.target_update_period,
            optimizer=optimizer,
            gamma=args.gamma,
            reward_scale_factor=args.reward_scale_factor,
            gradient_clipping=args.gradient_clipping,
            debug_summaries=args.debug_summaries,
            summarize_grads_and_vars=args.summarize_grads_and_vars,
            train_step_counter=global_step)
    agent.initialize()

    count_weights(q_net)

    train_eval(
        root_dir=args.output_dir,
        tf_env=tf_env,
        eval_tf_env=eval_tf_env,
        agent=agent,
        num_iterations=args.num_iterations,
        initial_collect_steps=args.initial_collect_steps,
        collect_steps_per_iteration=args.collect_steps_per_iteration,
        replay_buffer_capacity=args.replay_buffer_max_length,
        train_steps_per_iteration=args.train_steps_per_iteration,
        batch_size=args.batch_size,
        use_tf_functions=args.run_graph_mode,
        num_eval_episodes=args.num_eval_episodes,
        eval_interval=args.eval_interval,
        train_checkpoint_interval=args.checkpoint_interval,
        policy_checkpoint_interval=args.checkpoint_interval,
        rb_checkpoint_interval=args.checkpoint_interval,
        log_interval=args.log_interval,
        summary_interval=args.summary_interval,
        summaries_flush_secs=args.summary_flush)

    pickle.dump(args, open(args.output_dir + "/training_args.p", "wb"))
    print("Successfully trained and evaluated.")
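# Hedged sketch: restoring the pickled argparse namespace saved above to
# inspect a finished run's hyperparameters ("<output_dir>" is a placeholder).
import pickle

with open("<output_dir>/training_args.p", "rb") as f:
    restored_args = pickle.load(f)
print(restored_args)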
def train_eval( root_dir, env_name='MultiGrid-Empty-5x5-v0', env_load_fn=multiagent_gym_suite.load, random_seed=0, # Architecture params actor_fc_layers=(64, 64), value_fc_layers=(64, 64), lstm_size=(64, ), conv_filters=64, conv_kernel=3, direction_fc=5, entropy_regularization=0., use_attention_networks=False, # Specialized agents inactive_agent_ids=tuple(), # Params for collect num_environment_steps=25000000, collect_episodes_per_iteration=30, num_parallel_environments=5, replay_buffer_capacity=1001, # Per-environment # Params for train num_epochs=2, learning_rate=1e-4, # Params for eval num_eval_episodes=2, eval_interval=5, # Params for summaries and logging train_checkpoint_interval=100, policy_checkpoint_interval=100, log_interval=10, summary_interval=10, summaries_flush_secs=1, use_tf_functions=True, debug_summaries=True, summarize_grads_and_vars=True, eval_metrics_callback=None, reinit_checkpoint_dir=None, debug=True): """A simple train and eval for PPO.""" tf.compat.v1.enable_v2_behavior() if root_dir is None: raise AttributeError('train_eval requires a root_dir.') if debug: logging.info('In debug mode, turning tf_functions off') use_tf_functions = False for a in inactive_agent_ids: logging.info('Fixing and not training agent %d', a) # Load multiagent gym environment and determine number of agents gym_env = env_load_fn(env_name) n_agents = gym_env.n_agents # Set up logging root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') saved_model_dir = os.path.join(root_dir, 'policy_saved_model') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ multiagent_metrics.AverageReturnMetric(n_agents, buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): if random_seed is not None: tf.compat.v1.set_random_seed(random_seed) logging.info('Creating %d environments...', num_parallel_environments) wrappers = [] if use_attention_networks: wrappers = [ lambda env: utils.LSTMStateWrapper(env, lstm_size=lstm_size) ] eval_tf_env = tf_py_environment.TFPyEnvironment( env_load_fn(env_name, gym_kwargs=dict(seed=random_seed), gym_env_wrappers=wrappers)) # pylint: disable=g-complex-comprehension tf_env = tf_py_environment.TFPyEnvironment( parallel_py_environment.ParallelPyEnvironment([ functools.partial(env_load_fn, environment_name=env_name, gym_env_wrappers=wrappers, gym_kwargs=dict(seed=random_seed * 1234 + i)) for i in range(num_parallel_environments) ])) logging.info('Preparing to train...') environment_steps_metric = tf_metrics.EnvironmentSteps() step_metrics = [ tf_metrics.NumberOfEpisodes(), environment_steps_metric, ] train_metrics = step_metrics + [ multiagent_metrics.AverageReturnMetric( n_agents, batch_size=num_parallel_environments), tf_metrics.AverageEpisodeLengthMetric( batch_size=num_parallel_environments) ] logging.info('Creating agent...') tf_agent = multiagent_ppo.MultiagentPPO( tf_env.time_step_spec(), tf_env.action_spec(), n_agents=n_agents, learning_rate=learning_rate, actor_fc_layers=actor_fc_layers, value_fc_layers=value_fc_layers, lstm_size=lstm_size, conv_filters=conv_filters, 
conv_kernel=conv_kernel, direction_fc=direction_fc, entropy_regularization=entropy_regularization, num_epochs=num_epochs, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step, inactive_agent_ids=inactive_agent_ids, use_attention_networks=use_attention_networks) tf_agent.initialize() eval_policy = tf_agent.policy collect_policy = tf_agent.collect_policy logging.info('Allocating replay buffer ...') replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=num_parallel_environments, max_length=replay_buffer_capacity) logging.info('RB capacity: %i', replay_buffer.capacity) # If reinit_checkpoint_dir is provided, the last agent in the checkpoint is # reinitialized. The other agents are novices. # Otherwise, all agents are reinitialized from train_dir. if reinit_checkpoint_dir: reinit_checkpointer = common.Checkpointer( ckpt_dir=reinit_checkpoint_dir, agent=tf_agent, ) reinit_checkpointer.initialize_or_restore() temp_dir = os.path.join(train_dir, 'tmp') agent_checkpointer = common.Checkpointer( ckpt_dir=temp_dir, agent=tf_agent.agents[:-1], ) agent_checkpointer.save(global_step=0) tf_agent = multiagent_ppo.MultiagentPPO( tf_env.time_step_spec(), tf_env.action_spec(), n_agents=n_agents, learning_rate=learning_rate, actor_fc_layers=actor_fc_layers, value_fc_layers=value_fc_layers, lstm_size=lstm_size, conv_filters=conv_filters, conv_kernel=conv_kernel, direction_fc=direction_fc, entropy_regularization=entropy_regularization, num_epochs=num_epochs, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step, inactive_agent_ids=inactive_agent_ids, non_learning_agents=list(range(n_agents - 1)), use_attention_networks=use_attention_networks) agent_checkpointer = common.Checkpointer( ckpt_dir=temp_dir, agent=tf_agent.agents[:-1]) agent_checkpointer.initialize_or_restore() tf.io.gfile.rmtree(temp_dir) eval_policy = tf_agent.policy collect_policy = tf_agent.collect_policy train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=multiagent_metrics.MultiagentMetricsGroup( train_metrics, 'train_metrics')) if not reinit_checkpoint_dir: train_checkpointer.initialize_or_restore() logging.info('Successfully initialized train checkpointer') policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=eval_policy, global_step=global_step) saved_model = policy_saver.PolicySaver(eval_policy, train_step=global_step) logging.info('Successfully initialized policy saver.') print('Using TFDriver') if use_attention_networks: collect_driver = utils.StateTFDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, max_episodes=collect_episodes_per_iteration, disable_tf_function=not use_tf_functions) else: collect_driver = tf_driver.TFDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, max_episodes=collect_episodes_per_iteration, disable_tf_function=not use_tf_functions) def train_step(): trajectories = replay_buffer.gather_all() return tf_agent.train(experience=trajectories) if use_tf_functions: tf_agent.train = common.function(tf_agent.train, autograph=False) train_step = common.function(train_step) collect_time = 0 train_time = 0 timed_at_step = global_step.numpy() # How many consecutive steps was loss diverged for. loss_divergence_counter = 0 # Save operative config as late as possible to include used configurables. 
if global_step.numpy() == 0: config_filename = os.path.join( train_dir, 'operative_config-{}.gin'.format(global_step.numpy())) with tf.io.gfile.GFile(config_filename, 'wb') as f: f.write(gin.operative_config_str()) total_episodes = 0 logging.info('Commencing train loop!') while environment_steps_metric.result() < num_environment_steps: global_step_val = global_step.numpy() # Evaluation if global_step_val % eval_interval == 0: if debug: logging.info('Performing evaluation at step %d', global_step_val) results = multiagent_metrics.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', use_function=use_tf_functions, use_attention_networks=use_attention_networks) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) multiagent_metrics.log_metrics(eval_metrics) # Collect data if debug: logging.info('Collecting at step %d', global_step_val) start_time = time.time() time_step = tf_env.reset() policy_state = collect_policy.get_initial_state(tf_env.batch_size) if use_attention_networks: # Attention networks require previous policy state to compute attention # weights. time_step.observation['policy_state'] = ( policy_state['actor_network_state'][0], policy_state['actor_network_state'][1]) collect_driver.run(time_step, policy_state) collect_time += time.time() - start_time total_episodes += collect_episodes_per_iteration if debug: logging.info('Have collected a total of %d episodes', total_episodes) # Train if debug: logging.info('Training at step %d', global_step_val) start_time = time.time() total_loss, extra_loss = train_step() replay_buffer.clear() train_time += time.time() - start_time # Check for exploding losses. if (math.isnan(total_loss) or math.isinf(total_loss) or total_loss > MAX_LOSS): loss_divergence_counter += 1 if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS: logging.info( 'Loss diverged for too many timesteps, breaking...') break else: loss_divergence_counter = 0 for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=step_metrics) if global_step_val % log_interval == 0: logging.info('step = %d, total loss = %f', global_step_val, total_loss) for a in range(n_agents): if not inactive_agent_ids or a not in inactive_agent_ids: logging.info('Loss for agent %d = %f', a, extra_loss[a].loss) steps_per_sec = ((global_step_val - timed_at_step) / (collect_time + train_time)) logging.info('%.3f steps/sec', steps_per_sec) logging.info('collect_time = %.3f, train_time = %.3f', collect_time, train_time) with tf.compat.v2.summary.record_if(True): tf.compat.v2.summary.scalar(name='global_steps_per_sec', data=steps_per_sec, step=global_step) if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) saved_model_path = os.path.join( saved_model_dir, 'policy_' + ('%d' % global_step_val).zfill(9)) saved_model.save(saved_model_path) timed_at_step = global_step_val collect_time = 0 train_time = 0 # One final eval before exiting. 
results = multiagent_metrics.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', use_function=use_tf_functions, use_attention_networks=use_attention_networks) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) multiagent_metrics.log_metrics(eval_metrics)
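# The loss-divergence guard in the training loop above reads two module-level
# constants that are not defined in this excerpt. A minimal sketch with
# assumed values -- tune both to the reward scale of the task:
MAX_LOSS = 1e9  # treat any loss above this as diverged (assumed value)
TERMINATE_AFTER_DIVERGED_LOSS_STEPS = 100  # consecutive diverged steps before aborting (assumed value)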
def train_level(level, consecutive_wins_flag=5, collect_random_steps=True, max_iterations=num_iterations): """ create DQN agent to train a level of the game :param level: level of the game :param consecutive_wins_flag: number of consecutive wins in evaluation signifying the training is done :param collect_random_steps: whether to collect random steps at the beginning, always set to 'True' when the global step is 0. :param max_iterations: stop the training when it reaches the max iteration regardless of the result """ global saving_time cells = query_level(level) size = len(cells) env = tf_py_environment.TFPyEnvironment(GameEnv(size, cells)) eval_env = tf_py_environment.TFPyEnvironment(GameEnv(size, cells)) optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) fc_layer_params = (neuron_num_mapper[size], ) q_net = q_network.QNetwork(env.observation_spec()[0], env.action_spec(), fc_layer_params=fc_layer_params, activation_fn=tf.keras.activations.relu) global_step = tf.compat.v1.train.get_or_create_global_step() agent = dqn_agent.DdqnAgent( env.time_step_spec(), env.action_spec(), q_network=q_net, optimizer=optimizer, td_errors_loss_fn=common.element_wise_squared_loss, train_step_counter=global_step, observation_and_action_constraint_splitter=GameEnv. obs_and_mask_splitter) agent.initialize() replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=agent.collect_data_spec, batch_size=env.batch_size, max_length=replay_buffer_max_length) # drivers collect_driver = dynamic_step_driver.DynamicStepDriver( env, policy=agent.collect_policy, observers=[replay_buffer.add_batch], num_steps=collect_steps_per_iteration) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] eval_driver = dynamic_episode_driver.DynamicEpisodeDriver( eval_env, policy=agent.policy, observers=eval_metrics, num_episodes=num_eval_episodes) # checkpointer of the replay buffer and policy train_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( dir_path, 'trained_policies/train_lv{0}'.format(level)), max_to_keep=1, agent=agent, policy=agent.policy, global_step=global_step, replay_buffer=replay_buffer) # policy saver tf_policy_saver = policy_saver.PolicySaver(agent.policy) train_checkpointer.initialize_or_restore() # optimize by wrapping some of the code in a graph using TF function agent.train = common.function(agent.train) collect_driver.run = common.function(collect_driver.run) eval_driver.run = common.function(eval_driver.run) # collect initial replay data if collect_random_steps: initial_collect_policy = random_tf_policy.RandomTFPolicy( time_step_spec=env.time_step_spec(), action_spec=env.action_spec(), observation_and_action_constraint_splitter=GameEnv. obs_and_mask_splitter) dynamic_step_driver.DynamicStepDriver( env, initial_collect_policy, observers=[replay_buffer.add_batch], num_steps=initial_collect_steps).run() # Dataset generates trajectories with shape [Bx2x...] 
  dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                     sample_batch_size=batch_size,
                                     num_steps=2).prefetch(3)
  iterator = iter(dataset)

  # Train until `consecutive_wins_flag` consecutive evaluations reach an
  # average return above 10, or until `max_iterations` is hit.
  consecutive_eval_win = 0
  train_iterations = 0
  while consecutive_eval_win < consecutive_wins_flag and train_iterations < max_iterations:
    collect_driver.run()
    for _ in range(collect_steps_per_iteration):
      experience, _ = next(iterator)
      train_loss = agent.train(experience).loss

    # Evaluate the training at intervals.
    step = global_step.numpy()
    if step % eval_interval == 0:
      eval_driver.run()
      average_return = eval_metrics[0].result().numpy()
      average_len = eval_metrics[1].result().numpy()
      print("level: {0} step: {1} AverageReturn: {2} AverageLen: {3}".format(
          level, step, average_return, average_len))
      # Count consecutive winning evaluations.
      if average_return > 10:
        consecutive_eval_win += 1
      else:
        consecutive_eval_win = 0
    if step % save_interval == 0:
      start = time.time()
      train_checkpointer.save(global_step=step)
      saving_time += time.time() - start
    train_iterations += 1

  # Save the policy.
  train_checkpointer.save(global_step=global_step.numpy())
  tf_policy_saver.save(
      os.path.join(dir_path, 'trained_policies/policy_lv{0}'.format(level)))
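# `GameEnv.obs_and_mask_splitter`, passed as the
# observation_and_action_constraint_splitter in train_level above, is not
# shown in this excerpt. A minimal sketch, assuming the environment emits an
# (observation, valid_action_mask) tuple and that this function lives on
# GameEnv (e.g. as a staticmethod): the network sees the raw board while the
# policy masks out illegal actions.
def obs_and_mask_splitter(observation):
  """Split GameEnv's (board, valid_action_mask) observation tuple."""
  board, action_mask = observation
  return board, action_mask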
def get_env_and_policy(load_dir, env_name, alpha, env_seed=0, tabular_obs=False): if env_name == 'taxi': env = taxi.Taxi(tabular_obs=tabular_obs) env.seed(env_seed) policy_fn, policy_info_spec = taxi.get_taxi_policy(load_dir, env, alpha=alpha, py=False) tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env)) policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(), tf_env.action_spec(), policy_fn, policy_info_spec, emit_log_probability=True) elif env_name == 'grid': env = navigation.GridWalk(tabular_obs=tabular_obs) env.seed(env_seed) policy_fn, policy_info_spec = navigation.get_navigation_policy( env, epsilon_explore=0.1 + 0.6 * (1 - alpha), py=False) tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env)) policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(), tf_env.action_spec(), policy_fn, policy_info_spec, emit_log_probability=True) elif env_name == 'low_rank': env = low_rank.LowRank() env.seed(env_seed) policy_fn, policy_info_spec = low_rank.get_low_rank_policy( env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False) tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env)) policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(), tf_env.action_spec(), policy_fn, policy_info_spec, emit_log_probability=True) elif env_name == 'tree': env = tree.Tree(branching=2, depth=10) env.seed(env_seed) policy_fn, policy_info_spec = tree.get_tree_policy( env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False) tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env)) policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(), tf_env.action_spec(), policy_fn, policy_info_spec, emit_log_probability=True) elif env_name == 'lowrank_tree': env = tree.Tree(branching=2, depth=3, duplicate=10) env.seed(env_seed) policy_fn, policy_info_spec = tree.get_tree_policy( env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False) tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env)) policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(), tf_env.action_spec(), policy_fn, policy_info_spec, emit_log_probability=True) elif env_name.startswith('bandit'): num_arms = int(env_name[6:]) if len(env_name) > 6 else 2 env = bandit.Bandit(num_arms=num_arms) env.seed(env_seed) policy_fn, policy_info_spec = bandit.get_bandit_policy( env, epsilon_explore=1 - alpha, py=False) tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env)) policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(), tf_env.action_spec(), policy_fn, policy_info_spec, emit_log_probability=True) elif env_name == 'small_tree': env = tree.Tree(branching=2, depth=3, loop=True) env.seed(env_seed) policy_fn, policy_info_spec = tree.get_tree_policy( env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False) tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env)) policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(), tf_env.action_spec(), policy_fn, policy_info_spec, emit_log_probability=True) elif env_name == 'CartPole-v0': tf_env, policy = get_env_and_dqn_policy( env_name, os.path.join(load_dir, 'CartPole-v0', 'train', 'policy'), env_seed=env_seed, epsilon=0.3 + 0.15 * (1 - alpha)) elif env_name == 'cartpole': # Infinite-horizon cartpole. 
tf_env, policy = get_env_and_dqn_policy( 'CartPole-v0', os.path.join(load_dir, 'CartPole-v0-250', 'train', 'policy'), env_seed=env_seed, epsilon=0.3 + 0.15 * (1 - alpha)) env = InfiniteCartPole() tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env)) elif env_name == 'FrozenLake-v0': tf_env, policy = get_env_and_dqn_policy('FrozenLake-v0', os.path.join( load_dir, 'FrozenLake-v0', 'train', 'policy'), env_seed=env_seed, epsilon=0.2 * (1 - alpha), ckpt_file='ckpt-100000') elif env_name == 'frozenlake': # Infinite-horizon frozenlake. tf_env, policy = get_env_and_dqn_policy('FrozenLake-v0', os.path.join( load_dir, 'FrozenLake-v0', 'train', 'policy'), env_seed=env_seed, epsilon=0.2 * (1 - alpha), ckpt_file='ckpt-100000') env = InfiniteFrozenLake() tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env)) elif env_name in ['Reacher-v2', 'reacher']: if env_name == 'Reacher-v2': env = suites.load_mujoco(env_name) else: env = gym_wrapper.GymWrapper(InfiniteReacher()) env.seed(env_seed) tf_env = tf_py_environment.TFPyEnvironment(env) sac_policy = get_sac_policy(tf_env) directory = os.path.join(load_dir, 'Reacher-v2', 'train', 'policy') policy = load_policy(sac_policy, env_name, directory) policy = GaussianPolicy(policy, 0.4 - 0.3 * alpha, emit_log_probability=True) elif env_name == 'HalfCheetah-v2': env = suites.load_mujoco(env_name) env.seed(env_seed) tf_env = tf_py_environment.TFPyEnvironment(env) sac_policy = get_sac_policy(tf_env) directory = os.path.join(load_dir, env_name, 'train', 'policy') policy = load_policy(sac_policy, env_name, directory) policy = GaussianPolicy(policy, 0.2 - 0.1 * alpha, emit_log_probability=True) else: raise ValueError('Unrecognized environment %s.' % env_name) return tf_env, policy
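# `get_env_and_dqn_policy`, used by several branches above, is not defined in
# this excerpt. A rough sketch of what it likely does, assuming `load_dir`
# holds a DQN checkpoint and that the Q-network architecture matches the one
# used in training (the (100,) layer size and the checkpoint layout here are
# illustrative assumptions):
import os

import tensorflow as tf
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import q_network
from tf_agents.policies import epsilon_greedy_policy, q_policy


def get_env_and_dqn_policy(env_name, load_dir, env_seed=0, epsilon=0.1,
                           ckpt_file=None):
  env = suite_gym.load(env_name)
  env.seed(env_seed)
  tf_env = tf_py_environment.TFPyEnvironment(env)
  q_net = q_network.QNetwork(tf_env.time_step_spec().observation,
                             tf_env.action_spec(),
                             fc_layer_params=(100,))
  greedy = q_policy.QPolicy(tf_env.time_step_spec(), tf_env.action_spec(),
                            q_network=q_net)
  # Restore the trained Q-network weights, either from a specific checkpoint
  # file or from the latest one in the directory.
  checkpoint = tf.train.Checkpoint(policy=greedy)
  if ckpt_file is not None:
    checkpoint.restore(os.path.join(load_dir, ckpt_file))
  else:
    checkpoint.restore(tf.train.latest_checkpoint(load_dir))
  # Wrap the greedy policy so a fraction `epsilon` of actions is random.
  policy = epsilon_greedy_policy.EpsilonGreedyPolicy(greedy, epsilon=epsilon)
  return tf_env, policy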
def main(unused_argv):
  tf.compat.v1.enable_v2_behavior()  # The trainer only runs with V2 enabled.

  class LinearNormalReward(object):

    def __init__(self, theta):
      self.theta = theta

    def __call__(self, x):
      mu = np.dot(x, self.theta)
      return np.random.normal(mu, 1)

  def _global_context_sampling_fn():
    return np.random.randint(-10, 10, [4]).astype(np.float32)

  def _arm_context_sampling_fn():
    return np.random.randint(-2, 3, [5]).astype(np.float32)

  reward_fn = LinearNormalReward(HIDDEN_PARAM)

  env = sspe.StationaryStochasticPerArmPyEnvironment(
      _global_context_sampling_fn,
      _arm_context_sampling_fn,
      NUM_ACTIONS,
      reward_fn,
      batch_size=BATCH_SIZE)
  environment = tf_py_environment.TFPyEnvironment(env)

  obs_spec = environment.observation_spec()
  if FLAGS.network == 'commontower':
    network = (global_and_arm_feature_network
               .create_feed_forward_common_tower_network(
                   obs_spec, (4, 3), (3, 4), (4, 2)))
  elif FLAGS.network == 'dotproduct':
    network = (global_and_arm_feature_network
               .create_feed_forward_dot_product_network(
                   obs_spec, (4, 3, 6), (3, 4, 6)))

  if FLAGS.drop_arm_obs:

    def drop_arm_feature_fn(traj):
      transformed_traj = copy.deepcopy(traj)
      del transformed_traj.observation[bandit_spec_utils.PER_ARM_FEATURE_KEY]
      return transformed_traj
  else:
    drop_arm_feature_fn = None

  agent = neural_epsilon_greedy_agent.NeuralEpsilonGreedyAgent(
      time_step_spec=environment.time_step_spec(),
      action_spec=environment.action_spec(),
      reward_network=network,
      optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=LR),
      epsilon=EPSILON,
      accepts_per_arm_features=True,
      training_data_spec_transformation_fn=drop_arm_feature_fn,
      emit_policy_info=policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN)

  optimal_reward_fn = functools.partial(optimal_reward,
                                        hidden_param=HIDDEN_PARAM)
  optimal_action_fn = functools.partial(optimal_action,
                                        hidden_param=HIDDEN_PARAM)
  regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward_fn)
  suboptimal_arms_metric = tf_bandit_metrics.SuboptimalArmsMetric(
      optimal_action_fn)

  trainer.train(
      root_dir=FLAGS.root_dir,
      agent=agent,
      environment=environment,
      training_loops=TRAINING_LOOPS,
      steps_per_loop=STEPS_PER_LOOP,
      additional_metrics=[regret_metric, suboptimal_arms_metric],
      training_data_spec_transformation_fn=drop_arm_feature_fn)
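# `optimal_reward` and `optimal_action`, bound via functools.partial above,
# are not shown in this excerpt. They presumably apply the hidden linear
# parameter to the concatenated global+arm features of every arm; a sketch:
import tensorflow as tf
from tf_agents.bandits.specs import utils as bandit_spec_utils


def _expected_rewards(observation, hidden_param):
  """Per-arm expected rewards for a batched per-arm observation."""
  global_obs = observation[bandit_spec_utils.GLOBAL_FEATURE_KEY]
  arm_obs = observation[bandit_spec_utils.PER_ARM_FEATURE_KEY]
  num_arms = tf.shape(arm_obs)[1]
  # Tile the global features onto every arm, then apply the hidden linear
  # parameter to get each arm's mean reward.
  tiled_global = tf.tile(tf.expand_dims(global_obs, axis=1), [1, num_arms, 1])
  features = tf.concat([tiled_global, arm_obs], axis=-1)
  return tf.linalg.matvec(features, tf.cast(hidden_param, tf.float32))


def optimal_reward(observation, hidden_param):
  return tf.reduce_max(_expected_rewards(observation, hidden_param), axis=1)


def optimal_action(observation, hidden_param):
  return tf.argmax(_expected_rewards(observation, hidden_param), axis=1,
                   output_type=tf.int32)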
def record_env(): gif_path = "./images/test.gif" frames_path = "./images/episode-{i}-timestep-{t}.jpg" gating_bitmap = "./scenario-1/bitmaps/gating_mask.bmp" # pos_init = np.array([2.7, 2.0, 0.0]) # desired start state # pos_end_targ = np.array([-2.6, -1.5, 2.5]) # desired end state # state_init = np.array([2.7, 2.0, 0.0, 0.0, 0.0, 0.0]) # desired start state # state_end_targ = np.array([-2.6, -1.5, 2.5, 0.0, 0.0, 0.0]) # desired end state start_state = np.array( [2.7, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) # desired start state target_state = np.array( [-2.6, -1.5, 2.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) # desired end state eval_py_env = Quadcopter3DEnv(start_state=start_state, target_state=target_state, gating_bitmap=gating_bitmap) eval_env = tf_py_environment.TFPyEnvironment(eval_py_env) # eval_env = py_environment.PyEnvironment(eval_py_env) # eval_env = eval_py_env def policy(i): # return tf.reshape( # tf.constant([-0.1, 0.0, 0.0, 0.0], dtype=float_type), # [1, -1] # # tf.constant([-0.1, 0.01, 0.01, 0.01], dtype=float_type), [1, -1] # ) rng = np.random.default_rng(i) action = (tf_agents.specs.array_spec.sample_bounded_spec( eval_py_env.action_spec(), rng) * 1000) print(action) return tf.reshape(action, [1, -1]) num_episodes = 3 ts = [] for i in range(num_episodes): t = 0 print("Episode {i}".format(i=i)) time_step = eval_env.reset() while not time_step.is_last(): action = policy(i * t) time_step = eval_env.step(action) state = time_step.observation # fig, ax = eval_py_env.render(state) fig = eval_py_env.render(state) fig.savefig(frames_path.format(i=i, t=t)) t = t + 1 ts.append(t - 1) # writer.append_data(imageio.imread(frames_path.format(i=i, t=t))) # kwargs_write = {'fps':1.0, 'quantizer':'nq'} # imageio.mimsave('./powers.gif', [plot_for_offset(i/4, 100) for i in range(10)], fps=1) with imageio.get_writer(gif_path, mode="I", fps=8) as writer: for i, t in zip(range(num_episodes), ts): for t_ in range(t): writer.append_data( imageio.imread(frames_path.format(i=i, t=t_)))
def simulate():
  # Set up the environments the agent uses to train and to evaluate its
  # performance.
  envTrain = ComputerSnake.Snake()
  envEval = ComputerSnake.Snake(persistence=True)

  # Convert and wrap the training and evaluation environments in
  # TFPyEnvironments.
  train_env = tf_py_environment.TFPyEnvironment(envTrain)
  eval_env = tf_py_environment.TFPyEnvironment(envEval)

  # Set up the Q-network with the necessary parameters.
  fc_layer_params = (100,)
  q_net = q_network.QNetwork(
      train_env.observation_spec(),
      train_env.action_spec(),
      fc_layer_params=fc_layer_params
  )
  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
  train_step_counter = tf.Variable(0)

  # Set up and initialize the DQN agent. It takes in the time_step spec,
  # action spec, the Q-network, the optimizer, a loss function, and
  # train_step_counter.
  agent = dqn_agent.DqnAgent(
      train_env.time_step_spec(),
      train_env.action_spec(),
      q_network=q_net,
      optimizer=optimizer,
      td_errors_loss_fn=common.element_wise_squared_loss,
      train_step_counter=train_step_counter
  )
  agent.initialize()

  # Set up policies the agent can use.
  eval_policy = agent.policy
  collect_policy = agent.collect_policy

  # Policy which randomly selects an action at each step.
  random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                  train_env.action_spec())

  # Buffer to store previous states.
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=agent.collect_data_spec,
      batch_size=train_env.batch_size,
      max_length=replay_buffer_max_length)

  # Dataset generates trajectories with shape [Bx2x...] so that the agent has
  # access to both the current and previous state when computing the loss.
  # Parallel calls and prefetching are used to speed this up.
  dataset = replay_buffer.as_dataset(
      num_parallel_calls=3,
      sample_batch_size=batch_size,
      num_steps=2).prefetch(3)
  iterator = iter(dataset)

  # (Optional) Optimize by wrapping some of the code in a graph using
  # TF function.
  agent.train = common.function(agent.train)

  # Reset the train step.
  agent.train_step_counter.assign(0)

  # Evaluate the agent's policy once before training.
  avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)

  # Seed the replay buffer with 5000 steps collected by the random policy
  # before training begins.
  collect_data(train_env, random_policy, replay_buffer, steps=5000)
  train_env.reset()

  # Run the simulation to train the agent.
  scores_list = []
  num_steps_arr = []
  for currStep in range(num_iterations):
    # Collect a few steps using collect_policy and save to the replay buffer.
    for _ in range(collect_steps_per_iteration):
      collect_step(train_env, agent.collect_policy, replay_buffer)

    # Sample a batch of data from the buffer and update the agent's network.
    experience, unused_info = next(iterator)
    train_loss = agent.train(experience).loss

    # Number of training steps so far.
    step = agent.train_step_counter.numpy()

    # Log progress every log_interval training steps.
    if step % log_interval == 0:
      print('Moves made = {0}'.format(step))

    # Every eval_interval steps, evaluate the agent's policy, print the
    # results, and save them so they can be plotted later.
    if step % eval_interval == 0:
      avg_return = 0
      for i in range(num_eval_episodes):
        curr_return = compute_avg_return(eval_env, agent.policy, 1)
        scores_list.append(curr_return)
        num_steps_arr.append(currStep)
        avg_return += curr_return
      avg_return = avg_return / num_eval_episodes
      print('step = {0}: Average Return = {1}'.format(step, avg_return))

  plt.scatter(num_steps_arr, scores_list)
  plt.xlabel('Number of Steps Trained')
  plt.ylabel('Score')
  plt.title('Snake Reinforcement Learning')
  plt.show()
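# The `collect_step` / `collect_data` helpers called in simulate() above are
# not shown in this excerpt. They are presumably the standard TF-Agents
# pattern: step the environment once with the given policy and push the
# resulting transition into the replay buffer. A sketch:
from tf_agents.trajectories import trajectory


def collect_step(environment, policy, buffer):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)
  buffer.add_batch(traj)


def collect_data(env, policy, buffer, steps):
  for _ in range(steps):
    collect_step(env, policy, buffer)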
num_train_iterations = 2000  # @param
num_epochs = 5  # @param
learning_rate = 1e-4  # @param

# Params for summaries and logging
log_interval = 1  # @param
use_tf_functions = True
debug_summaries = False
summarize_grads_and_vars = False

num_eval_episodes = 10  # @param
eval_interval = 10  # @param

global_step = tf.compat.v1.train.get_or_create_global_step()
tf.compat.v1.set_random_seed(0)

eval_py_env = suite_gym.load(env_name)
tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

actor_net = actor_distribution_network.ActorDistributionNetwork(
    tf_env.observation_spec(),
    tf_env.action_spec(),
    fc_layer_params=actor_fc_layers)
value_net = value_network.ValueNetwork(tf_env.observation_spec(),
                                       fc_layer_params=value_fc_layers)

tf_agent = ppo_agent.PPOAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    optimizer,
    actor_net=actor_net,
    value_net=value_net,
    num_epochs=num_epochs,
    debug_summaries=debug_summaries,
    summarize_grads_and_vars=summarize_grads_and_vars,
    train_step_counter=global_step)
tf_agent.initialize()
  policy_save_handler.save("policy")
  with open("checkpoint/train_loss.pickle", "wb") as f:
    pickle.dump(all_train_loss, f)
  with open("checkpoint/all_metrics.pickle", "wb") as f:
    pickle.dump(all_metrics, f)
  with open("checkpoint/returns.pickle", "wb") as f:
    pickle.dump(returns, f)


if __name__ == '__main__':
  # tf_env = tf_py_environment.TFPyEnvironment(
  #     parallel_py_environment.ParallelPyEnvironment(
  #         [BombermanEnvironment] * N_PARALLEL_ENVIRONMENTS
  #     ))
  tf_env = tf_py_environment.TFPyEnvironment(BombermanEnvironment())
  eval_tf_env = tf_py_environment.TFPyEnvironment(BombermanEnvironment())

  q_net = QNetwork(tf_env.observation_spec(),
                   tf_env.action_spec(),
                   conv_layer_params=[(32, 3, 1), (32, 3, 1)],
                   fc_layer_params=[128, 64, 32])

  train_step = tf.Variable(0)
  update_period = 4
  optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)  # TODO: fine-tune the learning rate.
  # PolynomialDecay is reused here as an epsilon schedule: the "learning rate"
  # it emits is the exploration rate, decaying linearly from 1.0 to 0.01.
  epsilon_fn = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=1.0,  # initial epsilon
      decay_steps=250000 // update_period,
      end_learning_rate=0.01)  # final epsilon
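# A sketch of how the pieces above plug together: DqnAgent's epsilon_greedy
# argument accepts a callable, so the schedule can drive the exploration rate
# as a function of the training step. The target_update_period and gamma
# values below are assumptions, not taken from this document.
from tf_agents.agents.dqn import dqn_agent
from tf_agents.utils import common

agent = dqn_agent.DqnAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    target_update_period=2000,  # assumed value
    td_errors_loss_fn=common.element_wise_squared_loss,
    gamma=0.99,  # assumed value
    train_step_counter=train_step,
    epsilon_greedy=lambda: epsilon_fn(train_step))
agent.initialize()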
def testMethodPropagation(self): env = self._get_py_env(True, False, batch_size=1) env.foo = mock.Mock() tf_env = tf_py_environment.TFPyEnvironment(env) tf_env.foo() env.foo.assert_called_once()
def train_eval( root_dir, env_name='gym_solventx-v0', eval_env_name=None, env_load_fn=suite_gym.load, # The SAC paper reported: # Hopper and Cartpole results up to 1000000 iters, # Humanoid results up to 10000000 iters, # Other mujoco tasks up to 3000000 iters. num_iterations=3000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py # HalfCheetah and Ant take 10000 initial collection steps. # Other mujoco tasks take 1000. # Different choices roughly keep the initial episodes about the same. initial_collect_steps=10000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.math.squared_difference, gamma=0.99, reward_scale_factor=0.1, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=5000, policy_checkpoint_interval=2500, rb_checkpoint_interval=25000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=True, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for SAC.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): eval_env_name = eval_env_name or env_name gym_env = gym.make(env_name, config_file=config_file) py_env = suite_gym.wrap_env(gym_env, max_episode_steps=100) tf_env = tf_py_environment.TFPyEnvironment(py_env) eval_gym_env = gym.make(eval_env_name, config_file=config_file) eval_py_env = suite_gym.wrap_env(eval_gym_env, max_episode_steps=100) eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env) #tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name)) #eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(eval_env_name)) time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=tanh_normal_projection_network. 
TanhNormalProjectionNetwork) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, kernel_initializer='glorot_uniform', last_kernel_initializer='glorot_uniform') tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) tf_agent.initialize() # Make the replay buffer. replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), ] eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) collect_policy = tf_agent.collect_policy train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=eval_policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() if replay_buffer.num_frames() == 0: initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration) if use_tf_functions: initial_collect_driver.run = common.function( initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. 
logging.info( 'Initializing replay buffer by collecting experience for %d steps with ' 'a random policy.', initial_collect_steps) initial_collect_driver.run() results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 # Prepare replay buffer as dataset with invalid transitions filtered. def _filter_invalid_transition(trajectories, unused_arg1): return ~trajectories.is_boundary()[0] dataset = replay_buffer.as_dataset( sample_batch_size=batch_size, num_steps=2).unbatch().filter( _filter_invalid_transition).batch(batch_size).prefetch(5) # Dataset generates trajectories with shape [Bx2x...] iterator = iter(dataset) def train_step(): experience, _ = next(iterator) return tf_agent.train(experience) if use_tf_functions: train_step = common.function(train_step) for _ in range(num_iterations): start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) for _ in range(train_steps_per_iteration): train_loss = train_step() time_acc += time.time() - start_time global_step_val = global_step.numpy() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, train_loss.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar(name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step_val time_acc = 0 for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) if global_step_val % eval_interval == 0: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step_val) metric_utils.log_metrics(eval_metrics) if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) return train_loss
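# This train_eval references a module-level `config_file` when constructing
# the gym_solventx environments; it is not defined in this excerpt. A minimal
# assumed entry point (the flag name and default path are illustrative):
from absl import app, flags

flags.DEFINE_string('root_dir', '/tmp/solventx_sac',
                    'Directory for checkpoints and summaries.')
FLAGS = flags.FLAGS


def main(_):
  train_eval(FLAGS.root_dir)


if __name__ == '__main__':
  app.run(main)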
max_episode_steps = 5000000 # env = get_env(name='point_mass_full_goal', env_type='y', reward_type='sparse') # env = get_env(name='kitchen') env = get_env(name='playpen_reduced', task_list='rc_o', reward_type='sparse') base_dir = os.path.abspath('experiments/env_logs/playpen_reduced/symmetric/') env_log_dir = os.path.join(base_dir, 'rc_o/traj1/') # env = ResetFreeWrapper(env, reset_goal_frequency=500, full_reset_frequency=max_episode_steps) env = GoalTerminalResetWrapper( env, episodes_before_full_reset=max_episode_steps // 500, goal_reset_frequency=500) # env = Monitor(env, env_log_dir, video_callable=lambda x: x % 1 == 0, force=True) env = wrap_env(env) tf_env = tf_py_environment.TFPyEnvironment(env) tf_env.render = env.render time_step_spec = tf_env.time_step_spec() action_spec = tf_env.action_spec() policy = random_tf_policy.RandomTFPolicy( action_spec=action_spec, time_step_spec=time_step_spec) collect_data_spec = trajectory.Trajectory( step_type=time_step_spec.step_type, observation=time_step_spec.observation, action=action_spec, policy_info=policy.info_spec, next_step_type=time_step_spec.step_type, reward=time_step_spec.reward, discount=time_step_spec.discount) offline_data = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=collect_data_spec, batch_size=1, max_length=int(1e5))
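# A sketch of how the `offline_data` buffer above might be filled with random
# experience, reusing the random policy already constructed (the step count is
# an assumption):
from tf_agents.drivers import dynamic_step_driver

random_collect_driver = dynamic_step_driver.DynamicStepDriver(
    tf_env,
    policy,
    observers=[offline_data.add_batch],
    num_steps=10000)
random_collect_driver.run()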
def as_tf_env(env): return tf_py_environment.TFPyEnvironment(env)
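# Example usage of as_tf_env (assumes suite_gym is available): wrap a Gym
# environment so policies and drivers can consume batched tensors.
from tf_agents.environments import suite_gym

tf_env = as_tf_env(suite_gym.load('CartPole-v0'))
print(tf_env.time_step_spec())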
def train_eval( root_dir, env_name='CartPole-v0', num_iterations=1000, # TODO(b/127576522): rename to policy_fc_layers. actor_fc_layers=(100,), value_net_fc_layers=(100,), use_value_network=False, # Params for collect collect_episodes_per_iteration=2, replay_buffer_capacity=2000, # Params for train learning_rate=1e-3, gamma=0.9, gradient_clipping=None, normalize_returns=True, value_estimation_loss_coef=0.2, # Params for eval num_eval_episodes=10, eval_interval=100, # Params for checkpoints, summaries, and logging train_checkpoint_interval=100, policy_checkpoint_interval=100, rb_checkpoint_interval=200, log_interval=100, summary_interval=100, summaries_flush_secs=1, debug_summaries=True, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for Reinforce.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): eval_py_env = suite_gym.load(env_name) tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name)) # TODO(b/127870767): Handle distributions without gin. actor_net = actor_distribution_network.ActorDistributionNetwork( tf_env.time_step_spec().observation, tf_env.action_spec(), fc_layer_params=actor_fc_layers) if use_value_network: value_net = value_network.ValueNetwork( tf_env.time_step_spec().observation, fc_layer_params=value_net_fc_layers) tf_agent = reinforce_agent.ReinforceAgent( tf_env.time_step_spec(), tf_env.action_spec(), actor_network=actor_net, value_network=value_net if use_value_network else None, value_estimation_loss_coef=value_estimation_loss_coef, gamma=gamma, optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate), normalize_returns=normalize_returns, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] collect_policy = tf_agent.collect_policy collect_op = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_episodes=collect_episodes_per_iteration).run() experience = replay_buffer.gather_all() train_op = tf_agent.train(experience) clear_rb_op = replay_buffer.clear() train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer( 
ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) summary_ops = [] for train_metric in train_metrics: summary_ops.append(train_metric.tf_summaries( train_step=global_step, step_metrics=train_metrics[:2])) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(train_step=global_step) init_agent_op = tf_agent.initialize() with tf.compat.v1.Session() as sess: # Initialize the graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) # TODO(b/126239733): Remove once Periodically can be saved. common.initialize_uninitialized_variables(sess) sess.run(init_agent_op) sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) # Compute evaluation metrics. global_step_call = sess.make_callable(global_step) global_step_val = global_step_call() metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, ) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable([train_op, summary_ops]) clear_rb_call = sess.make_callable(clear_rb_op) timed_at_step = global_step_call() time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_second_ph, step=global_step) for _ in range(num_iterations): start_time = time.time() collect_call() total_loss, _ = train_step_call() clear_rb_call() time_acc += time.time() - start_time global_step_val = global_step_call() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, total_loss.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) sess.run( steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, )
def train_eval( root_dir, env_name='SocialBot-ICubWalkPID-v0', num_iterations=10000000, actor_fc_layers=(256, 128), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 128), # Params for collect initial_collect_steps=2000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, num_parallel_environments=12, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=5e-4, critic_learning_rate=5e-4, alpha_learning_rate=5e-4, td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=10, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=True, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for SAC.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env = tf_py_environment.TFPyEnvironment( parallel_py_environment.ParallelPyEnvironment( [lambda: suite_socialbot.load(env_name,wrap_with_process=False)] * num_parallel_environments)) eval_tf_env = tf_py_environment.TFPyEnvironment( suite_socialbot.load(env_name)) time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) tf_agent.initialize() # Make the replay buffer. 
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=num_parallel_environments, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()), tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()), ] eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) collect_policy = tf_agent.collect_policy train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=eval_policy, global_step=global_step) rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer, num_steps=initial_collect_steps) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration) if use_tf_functions: initial_collect_driver.run = common.function( initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. logging.info( 'Initializing replay buffer by collecting experience for %d steps with ' 'a random policy.', initial_collect_steps) initial_collect_driver.run() results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 # Dataset generates trajectories with shape [Bx2x...] 
dataset = replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = iter(dataset) for _ in range(num_iterations): start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) for _ in range(train_steps_per_iteration): experience, _ = next(iterator) train_loss = tf_agent.train(experience) time_acc += time.time() - start_time if global_step.numpy() % log_interval == 0: logging.info('step = %d, loss = %f', global_step.numpy(), train_loss.loss) steps_per_sec = ( global_step.numpy() - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step.numpy() time_acc = 0 for train_metric in train_metrics: train_metric.tf_summaries( train_step=global_step, step_metrics=train_metrics[:2]) if global_step.numpy() % eval_interval == 0: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) global_step_val = global_step.numpy() if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) return train_loss
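# `normal_projection_net`, passed as the continuous_projection_net in the SAC
# train_eval above, is not defined in this excerpt. SAC examples typically
# build it along these lines (a sketch, not necessarily the author's exact
# version):
from tf_agents.agents.sac import sac_agent
from tf_agents.networks import normal_projection_network


def normal_projection_net(action_spec, init_means_output_factor=0.1):
  return normal_projection_network.NormalProjectionNetwork(
      action_spec,
      mean_transform=None,
      state_dependent_std=True,
      init_means_output_factor=init_means_output_factor,
      std_transform=sac_agent.std_clip_transform,
      scale_distribution=True)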
def train_eval( root_dir, env_name='cartpole', task_name='balance', observations_allowlist='position', num_iterations=100000, actor_fc_layers=(400, 300), actor_output_fc_layers=(100, ), actor_lstm_size=(40, ), critic_obs_fc_layers=(400, ), critic_action_fc_layers=None, critic_joint_fc_layers=(300, ), critic_output_fc_layers=(100, ), critic_lstm_size=(40, ), # Params for collect initial_collect_episodes=1, collect_episodes_per_iteration=1, replay_buffer_capacity=100000, exploration_noise_std=0.1, # Params for target update target_update_tau=0.05, target_update_period=5, # Params for train train_steps_per_iteration=200, batch_size=64, actor_update_period=2, train_sequence_length=10, actor_learning_rate=1e-4, critic_learning_rate=1e-3, td_errors_loss_fn=None, gamma=0.995, reward_scale_factor=1.0, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=10, eval_interval=10000, # Params for checkpoints, summaries, and logging log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for TD3.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): if observations_allowlist is not None: env_wrappers = [ functools.partial( wrappers.FlattenObservationsWrapper, observations_allowlist=[observations_allowlist]) ] else: env_wrappers = [] tf_env = tf_py_environment.TFPyEnvironment( suite_dm_control.load(env_name, task_name, env_wrappers=env_wrappers)) eval_tf_env = tf_py_environment.TFPyEnvironment( suite_dm_control.load(env_name, task_name, env_wrappers=env_wrappers)) actor_net = actor_rnn_network.ActorRnnNetwork( tf_env.time_step_spec().observation, tf_env.action_spec(), input_fc_layer_params=actor_fc_layers, lstm_size=actor_lstm_size, output_fc_layer_params=actor_output_fc_layers) critic_net_input_specs = (tf_env.time_step_spec().observation, tf_env.action_spec()) critic_net = critic_rnn_network.CriticRnnNetwork( critic_net_input_specs, observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers, ) tf_agent = td3_agent.Td3Agent( tf_env.time_step_spec(), tf_env.action_spec(), actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), exploration_noise_std=exploration_noise_std, target_update_tau=target_update_tau, target_update_period=target_update_period, actor_update_period=actor_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, 
summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step, ) tf_agent.initialize() train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] eval_policy = tf_agent.policy collect_policy = tf_agent.collect_policy replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_episodes=initial_collect_episodes) collect_driver = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_episodes=collect_episodes_per_iteration) if use_tf_functions: initial_collect_driver.run = common.function( initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. logging.info( 'Initializing replay buffer by collecting experience for %d episodes ' 'with a random policy.', initial_collect_episodes) initial_collect_driver.run() results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 # Dataset generates trajectories with shape [BxTx...] dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=train_sequence_length + 1).prefetch(3) iterator = iter(dataset) def train_step(): experience, _ = next(iterator) return tf_agent.train(experience) if use_tf_functions: train_step = common.function(train_step) for _ in range(num_iterations): start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) for _ in range(train_steps_per_iteration): train_loss = train_step() time_acc += time.time() - start_time if global_step.numpy() % log_interval == 0: logging.info('step = %d, loss = %f', global_step.numpy(), train_loss.loss) steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar(name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step.numpy() time_acc = 0 for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) if global_step.numpy() % eval_interval == 0: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) return train_loss
def train_agent(iterations, modeldir, logdir, policydir): """Train and convert the model using TF Agents.""" train_py_env = planestrike_py_environment.PlaneStrikePyEnvironment( board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2 ) eval_py_env = planestrike_py_environment.PlaneStrikePyEnvironment( board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2 ) train_env = tf_py_environment.TFPyEnvironment(train_py_env) eval_env = tf_py_environment.TFPyEnvironment(eval_py_env) # Alternatively you could use ActorDistributionNetwork as actor_net actor_net = tfa.networks.Sequential( [ tfa.keras_layers.InnerReshape([BOARD_SIZE, BOARD_SIZE], [BOARD_SIZE**2]), tf.keras.layers.Dense(FC_LAYER_PARAMS, activation="relu"), tf.keras.layers.Dense(BOARD_SIZE**2), tf.keras.layers.Lambda(lambda t: tfp.distributions.Categorical(logits=t)), ], input_spec=train_py_env.observation_spec(), ) optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE) train_step_counter = tf.Variable(0) tf_agent = reinforce_agent.ReinforceAgent( train_env.time_step_spec(), train_env.action_spec(), actor_network=actor_net, optimizer=optimizer, normalize_returns=True, train_step_counter=train_step_counter, ) tf_agent.initialize() eval_policy = tf_agent.policy collect_policy = tf_agent.collect_policy tf_policy_saver = policy_saver.PolicySaver(collect_policy) # Use reverb as replay buffer replay_buffer_signature = tensor_spec.from_spec(tf_agent.collect_data_spec) table = reverb.Table( REPLAY_BUFFER_TABLE_NAME, max_size=REPLAY_BUFFER_CAPACITY, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1), signature=replay_buffer_signature, ) # specify signature here for validation at insertion time reverb_server = reverb.Server([table]) replay_buffer = reverb_replay_buffer.ReverbReplayBuffer( tf_agent.collect_data_spec, sequence_length=None, table_name=REPLAY_BUFFER_TABLE_NAME, local_server=reverb_server, ) replay_buffer_observer = reverb_utils.ReverbAddEpisodeObserver( replay_buffer.py_client, REPLAY_BUFFER_TABLE_NAME, REPLAY_BUFFER_CAPACITY ) # Optimize by wrapping some of the code in a graph using TF function. tf_agent.train = common.function(tf_agent.train) # Evaluate the agent's policy once before training. avg_return = compute_avg_return_and_steps( eval_env, tf_agent.policy, NUM_EVAL_EPISODES ) summary_writer = tf.summary.create_file_writer(logdir) for i in range(iterations): # Collect a few episodes using collect_policy and save to the replay buffer. collect_episode( train_py_env, collect_policy, COLLECT_EPISODES_PER_ITERATION, replay_buffer_observer, ) # Use data from the buffer and update the agent's network. iterator = iter(replay_buffer.as_dataset(sample_batch_size=1)) trajectories, _ = next(iterator) tf_agent.train(experience=trajectories) replay_buffer.clear() logger = tf.get_logger() if i % EVAL_INTERVAL == 0: avg_return, avg_episode_length = compute_avg_return_and_steps( eval_env, eval_policy, NUM_EVAL_EPISODES ) with summary_writer.as_default(): tf.summary.scalar("Average return", avg_return, step=i) tf.summary.scalar("Average episode length", avg_episode_length, step=i) summary_writer.flush() logger.info( "iteration = {0}: Average Return = {1}, Average Episode Length = {2}".format( i, avg_return, avg_episode_length ) ) summary_writer.close() tf_policy_saver.save(policydir)
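# `collect_episode`, called in the training loop above, is not shown in this
# excerpt. It is presumably a small PyDriver wrapper that runs the collect
# policy for a few episodes and lets the reverb observer write each finished
# episode into the table. A sketch:
from tf_agents.drivers import py_driver
from tf_agents.policies import py_tf_eager_policy


def collect_episode(environment, policy, num_episodes, observer):
  driver = py_driver.PyDriver(
      environment,
      py_tf_eager_policy.PyTFEagerPolicy(policy, use_tf_function=True),
      [observer],
      max_episodes=num_episodes)
  initial_time_step = environment.reset()
  driver.run(initial_time_step)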
def train_eval( root_dir, tf_master='', env_name='HalfCheetah-v1', env_load_fn=suite_mujoco.load, random_seed=0, # TODO(kbanoop): rename to policy_fc_layers. actor_fc_layers=(200, 100), value_fc_layers=(200, 100), use_rnns=False, # Params for collect num_environment_steps=10000000, collect_episodes_per_iteration=30, num_parallel_environments=30, replay_buffer_capacity=1001, # Per-environment # Params for train num_epochs=25, learning_rate=1e-4, # Params for eval num_eval_episodes=30, eval_interval=500, # Params for summaries and logging train_checkpoint_interval=100, policy_checkpoint_interval=50, rb_checkpoint_interval=200, log_interval=50, summary_interval=50, summaries_flush_secs=1, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for PPO.""" if root_dir is None: raise AttributeError('train_eval requires a root_dir.') root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ batched_py_metric.BatchedPyMetric( AverageReturnMetric, metric_args={'buffer_size': num_eval_episodes}, batch_size=num_parallel_environments), batched_py_metric.BatchedPyMetric( AverageEpisodeLengthMetric, metric_args={'buffer_size': num_eval_episodes}, batch_size=num_parallel_environments), ] eval_summary_writer_flush_op = eval_summary_writer.flush() global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf.compat.v1.set_random_seed(random_seed) eval_py_env = parallel_py_environment.ParallelPyEnvironment( [lambda: env_load_fn(env_name)] * num_parallel_environments) tf_env = tf_py_environment.TFPyEnvironment( parallel_py_environment.ParallelPyEnvironment( [lambda: env_load_fn(env_name)] * num_parallel_environments)) optimizer = tf.compat.v1.train.AdamOptimizer( learning_rate=learning_rate) if use_rnns: actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( tf_env.observation_spec(), tf_env.action_spec(), input_fc_layer_params=actor_fc_layers, output_fc_layer_params=None) value_net = value_rnn_network.ValueRnnNetwork( tf_env.observation_spec(), input_fc_layer_params=value_fc_layers, output_fc_layer_params=None) else: actor_net = actor_distribution_network.ActorDistributionNetwork( tf_env.observation_spec(), tf_env.action_spec(), fc_layer_params=actor_fc_layers) value_net = value_network.ValueNetwork( tf_env.observation_spec(), fc_layer_params=value_fc_layers) tf_agent = ppo_agent.PPOAgent( tf_env.time_step_spec(), tf_env.action_spec(), optimizer, actor_net=actor_net, value_net=value_net, num_epochs=num_epochs, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=num_parallel_environments, max_length=replay_buffer_capacity) eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy) # TODO(sguada): Reenable metrics when ready for batch data. 
environment_steps_metric = tf_metrics.EnvironmentSteps() environment_steps_count = environment_steps_metric.result() step_metrics = [ tf_metrics.NumberOfEpisodes(), environment_steps_metric, ] train_metrics = step_metrics + [ tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] # Add to replay buffer and other agent specific observers. replay_buffer_observer = [replay_buffer.add_batch] collect_policy = tf_agent.collect_policy collect_op = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=replay_buffer_observer + train_metrics, num_episodes=collect_episodes_per_iteration).run() trajectories = replay_buffer.gather_all() train_op, _ = tf_agent.train(experience=trajectories) with tf.control_dependencies([train_op]): clear_replay_op = replay_buffer.clear() with tf.control_dependencies([clear_replay_op]): train_op = tf.identity(train_op) train_checkpointer = common_utils.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common_utils.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common_utils.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) for train_metric in train_metrics: train_metric.tf_summaries() with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(step_metrics=step_metrics) init_agent_op = tf_agent.initialize() with tf.compat.v1.Session(tf_master) as sess: # Initialize graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) # TODO(sguada) Remove once Periodically can be saved. 
      common_utils.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      collect_time = 0
      train_time = 0
      timed_at_step = sess.run(global_step)
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.contrib.summary.scalar(
          name='global_steps/sec', tensor=steps_per_second_ph)

      while sess.run(environment_steps_count) < num_environment_steps:
        global_step_val = sess.run(global_step)
        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
          sess.run(eval_summary_writer_flush_op)

        start_time = time.time()
        sess.run(collect_op)
        collect_time += time.time() - start_time

        start_time = time.time()
        total_loss = sess.run(train_op)
        train_time += time.time() - start_time

        global_step_val = sess.run(global_step)
        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss)
          steps_per_sec = (
              (global_step_val - timed_at_step) / (collect_time + train_time))
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          logging.info(
              '%s', 'collect_time = {}, train_time = {}'.format(
                  collect_time, train_time))
          timed_at_step = global_step_val
          collect_time = 0
          train_time = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

      # One final eval before exiting.
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )
      sess.run(eval_summary_writer_flush_op)
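For completeness, here is a minimal sketch of how the PPO train_eval above could be wired up as a script entry point. The flag names, the required-flag check, and the enable_resource_variables call are illustrative assumptions, not part of the original example.

# Hypothetical driver for the PPO train_eval above; flag names are
# assumptions for illustration only.
from absl import app
from absl import flags

flags.DEFINE_string('root_dir', None,
                    'Root directory for checkpoints and summaries.')
flags.DEFINE_string('env_name', 'HalfCheetah-v1',
                    'MuJoCo environment to train on.')
FLAGS = flags.FLAGS


def main(_):
  # Resource variables are needed by the v2 summary writers used inside
  # train_eval, while the rest of the function still runs in graph mode.
  tf.compat.v1.enable_resource_variables()
  train_eval(FLAGS.root_dir, env_name=FLAGS.env_name)


if __name__ == '__main__':
  flags.mark_flag_as_required('root_dir')
  app.run(main)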
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    fc_layer_params=(100,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.contrib.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  with tf.contrib.summary.record_summaries_every_n_global_steps(
      summary_interval):
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    trajectory_spec = trajectory.from_transition(
        time_step=tf_env.time_step_spec(),
        action_step=policy_step.PolicyStep(action=tf_env.action_spec()),
        next_time_step=tf_env.time_step_spec())
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=trajectory_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    q_net = q_network.QNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=fc_layer_params)

    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars)
    tf_agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    # Policies are properties of the agent, not methods.
    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration)

    global_step = tf.compat.v1.train.get_or_create_global_step()

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps '
        'with a random policy.', initial_collect_steps)
    dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=initial_collect_steps).run()

    metrics = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(metrics, global_step.numpy())

    time_step = None
    policy_state = ()

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        experience, _ = next(iterator)
        train_loss = tf_agent.train(
            experience, train_step_counter=global_step)
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        # train() returns a LossInfo namedtuple; log its scalar loss field.
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.contrib.summary.scalar(
            name='global_steps/sec', tensor=steps_per_sec)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(step_metrics=train_metrics[:2])

      if global_step.numpy() % eval_interval == 0:
        metrics = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(metrics, global_step.numpy())

  return train_loss
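Because this DQN train_eval runs eagerly and accepts an eval_metrics_callback, a quick smoke test can shrink the iteration counts and log evaluation results as they arrive. The sketch below is a hypothetical invocation; the root_dir path, the reduced counts, and the callback name are illustrative, not values from the original script.

# Hypothetical smoke-test invocation of the eager DQN train_eval above.
def _log_eval_metrics(metrics, step):
  # `metrics` holds the evaluation results computed by
  # metric_utils.eager_compute at global step `step`.
  logging.info('eval @ step %d: %s', step, metrics)


train_eval(
    root_dir='/tmp/dqn_cartpole',  # illustrative path
    num_iterations=2000,
    initial_collect_steps=500,
    eval_interval=500,
    log_interval=500,
    eval_metrics_callback=_log_eval_metrics)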