예제 #1
0
def main(_):
  """Train and evaluate on the Atari game named by --game_name.

  Returns 0 immediately (without training) when eager execution has already
  been enabled; otherwise builds a TrainEval for the game and runs it.
  """
  logging.set_verbosity(logging.INFO)
  if common.has_eager_been_enabled():
    # NOTE(review): appears to require graph mode, so it bails out under
    # eager execution -- confirm.
    return 0
  tf.enable_resource_variables()
  game_id = suite_atari.game(name=FLAGS.game_name)
  trainer = TrainEval(FLAGS.root_dir, game_id, **get_run_args())
  trainer.run()
예제 #2
0
def main(_):
    """Train and evaluate, using --environment_name when it is set.

    When --environment_name is None, the environment id is derived from
    --game_name via suite_atari.game.
    """
    logging.set_verbosity(logging.INFO)
    tf.enable_resource_variables()
    env_id = (FLAGS.environment_name
              if FLAGS.environment_name is not None
              else suite_atari.game(name=FLAGS.game_name))
    trainer = TrainEval(FLAGS.root_dir, env_id, **get_run_args())
    trainer.run()
예제 #3
0
def main(_):
  """Run abps_runners.EvalRunner on the Atari game named by --game_name."""
  logging.set_verbosity(logging.INFO)
  tf.enable_resource_variables()
  game_id = suite_atari.game(name=FLAGS.game_name)
  abps_runners.EvalRunner(
      root_dir=FLAGS.root_dir, env_name=game_id, **get_run_args()).run()
예제 #4
0
def main(_):
  """Run abps_runners.TrainRunner (in graph mode) on --game_name's game."""
  # Eager execution is switched off before anything else is configured.
  tf.disable_eager_execution()
  logging.set_verbosity(logging.INFO)
  tf.enable_resource_variables()
  game_id = suite_atari.game(name=FLAGS.game_name)
  abps_runners.TrainRunner(
      root_dir=FLAGS.root_dir, env_name=game_id, **get_run_args()).run()
예제 #5
0
def main(_):
    """Evaluate a baseline policy on the Atari game named by --game_name.

    Only --select_policy_way='independent' is supported; it runs
    baseline_runners.EvalRunner.

    Raises:
        ValueError: if --select_policy_way is any other value.
    """
    logging.set_verbosity(logging.INFO)
    tf.enable_resource_variables()
    if FLAGS.select_policy_way != 'independent':
        # Without this guard `runner` would be unbound at `runner.run()` and
        # the script would die with an UnboundLocalError instead of a clear
        # message about the bad flag value.
        raise ValueError(
            'Unsupported --select_policy_way=%r; only %r is supported.'
            % (FLAGS.select_policy_way, 'independent'))
    runner = baseline_runners.EvalRunner(
        root_dir=FLAGS.root_dir,
        env_name=suite_atari.game(name=FLAGS.game_name),
        **get_run_args())
    runner.run()
예제 #6
0
def main(_):
  """Run population-based training on the Atari game named by --game_name.

  --select_policy_way picks the runner:
    'independent' -> baseline_runners.PBTRunner
    'controller'  -> baseline_runners.PBTController

  Raises:
    ValueError: if --select_policy_way is any other value.
  """
  logging.set_verbosity(logging.INFO)
  tf.enable_resource_variables()
  way = FLAGS.select_policy_way
  if way == 'independent':
    runner_cls = baseline_runners.PBTRunner
  elif way == 'controller':
    runner_cls = baseline_runners.PBTController
  else:
    # Without this branch `runner` would be unbound at `runner.run()`,
    # surfacing as an UnboundLocalError rather than a clear flag error.
    raise ValueError(
        'Unsupported --select_policy_way=%r; expected %r or %r.'
        % (way, 'independent', 'controller'))
  runner = runner_cls(
      root_dir=FLAGS.root_dir,
      env_name=suite_atari.game(name=FLAGS.game_name),
      **get_run_args())
  runner.run()
예제 #7
0
def main(_):
    """Train and evaluate on --game_name after applying --gin_binding.

    parse_config_files_and_bindings is given None for its config-files
    argument, so only the bindings from --gin_binding are supplied.
    """
    logging.set_verbosity(logging.INFO)
    bindings = FLAGS.gin_binding
    gin.parse_config_files_and_bindings(None, bindings)
    game_id = suite_atari.game(name=FLAGS.game_name)
    TrainEval(FLAGS.root_dir, game_id).run()
예제 #8
0
 def testGameSetAll(self):
     """All four positional components appear in the resulting Gym id."""
     expected = 'Pong-ramDeterministic-v4'
     actual = suite_atari.game('Pong', 'ram', 'Deterministic', 'v4')
     self.assertEqual(actual, expected)
예제 #9
0
def main(_):
    """Train and evaluate on --game_name after applying gin configuration.

    --gin_file and --gin_param are handed to gin before TrainEval is built.
    """
    logging.set_verbosity(logging.INFO)
    tf.enable_resource_variables()
    config_files, bindings = FLAGS.gin_file, FLAGS.gin_param
    gin.parse_config_files_and_bindings(config_files, bindings)
    game_id = suite_atari.game(name=FLAGS.game_name)
    TrainEval(FLAGS.root_dir, game_id).run()
예제 #10
0
 def testGameMode(self):
     """An explicit mode replaces the default 'NoFrameskip' suffix."""
     expected = 'PongDeterministic-v0'
     self.assertEqual(suite_atari.game('Pong', mode='Deterministic'), expected)
예제 #11
0
 def testGameObsType(self):
     """obs_type='ram' inserts '-ram' right after the game name."""
     expected = 'Pong-ramNoFrameskip-v0'
     self.assertEqual(suite_atari.game('Pong', obs_type='ram'), expected)
예제 #12
0
 def testGameName(self):
     """With only a name, the default 'NoFrameskip' mode and 'v0' are used."""
     self.assertEqual(suite_atari.game('Pong'), 'PongNoFrameskip-v0')
예제 #13
0
# Evaluation cadence and video logging (Colab form parameters).
num_eval_episodes = 1  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}
video_interval = 10000

# Fully connected layer sizes: two layers of 256 units.
fc_layer_params = (256, 256)

# -------------------Environment-------------------------------
# Alternatives tried here: 'Pong-ram-v0', 'Breakout-ram-v0',
# 'BreakoutNoFrameskip-v4', 'CartPole-v0'.
# The second positional argument of suite_atari.game is obs_type, so this is
# the RAM-observation id 'Breakout-ramNoFrameskip-v0'.
env_name = suite_atari.game('Breakout', 'ram')

# Standalone environment with episodes capped at 1000 steps; FireOnReset is
# applied as a Gym wrapper (presumably to press FIRE after reset -- confirm).
env = suite_atari.load(
    env_name,
    max_episode_steps=1000,
    gym_env_wrappers=[atari_wrappers.FireOnReset],
)
# For inspection: print(env.time_step_spec().observation), print(env.action_spec()).

# # train_py_env = suite_gym.load(env_name)
# # eval_py_env = suite_gym.load(env_name)
train_py_env = suite_atari.load(env_name,
                                max_episode_steps=10000,
def train_eval(
        root_dir,
        env_name=suite_atari.game('Breakout'),
        env_load_fn=suite_atari.load,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """A simple train and eval for PPO.

    Each iteration does the following:
      1. Collects `collect_episodes_per_iteration` episodes from
         `num_parallel_environments` copies of `env_name` with the collect
         policy.
      2. Trains the PPO agent on everything in the replay buffer.
      3. Clears the buffer.

    This repeats until `num_environment_steps` environment steps have been
    collected, then a final evaluation is run.

    Outputs:
      - Summaries: `<root_dir>/train` and `<root_dir>/eval`.
      - Agent checkpoints: `<root_dir>/train`.
      - Policy checkpoints: `<root_dir>/train/policy`.
      - SavedModels: `<root_dir>/policy_saved_model`.

    Interval checks compare the global step read at the start of each
    iteration.

    Args:
        root_dir: Output directory; `~` is expanded. Must not be None.
        env_name: Environment id passed to `env_load_fn`.
        env_load_fn: Callable mapping `env_name` to a Python environment.
        random_seed: Global TF random seed.
        actor_fc_layers: Fully connected layer sizes for the actor network.
        value_fc_layers: Fully connected layer sizes for the value network.
        use_rnns: Use the RNN actor/value networks instead of feed-forward.
        num_environment_steps: Stop once this many environment steps have
            been collected.
        collect_episodes_per_iteration: Episodes collected per iteration.
        num_parallel_environments: Number of training environment copies.
        replay_buffer_capacity: Replay buffer length per environment.
        num_epochs: Forwarded to `PPOAgent`.
        learning_rate: Adam learning rate.
        num_eval_episodes: Episodes per evaluation.
        eval_interval: Evaluate when the global step is a positive multiple
            of this value.
        train_checkpoint_interval: Save an agent checkpoint when the global
            step is a multiple of this value.
        policy_checkpoint_interval: Save a policy checkpoint and SavedModel
            when the global step is a multiple of this value.
        log_interval: Log loss/throughput when the global step is a multiple
            of this value.
        summary_interval: Record TF summaries only on steps that are
            multiples of this value.
        summaries_flush_secs: Summary writer flush period in seconds.
        use_tf_functions: Wrap collection and training in `common.function`.
        debug_summaries: Forwarded to `PPOAgent`.
        summarize_grads_and_vars: Forwarded to `PPOAgent`.

    Raises:
        AttributeError: If `root_dir` is None.
    """
    import contextlib  # Local import: only used for environment cleanup.

    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf.compat.v1.set_random_seed(random_seed)

        with contextlib.ExitStack() as cleanup:
            # Close both Python environments on exit, including on error, so
            # the ParallelPyEnvironment worker processes are shut down rather
            # than leaked.
            eval_py_env = cleanup.enter_context(
                contextlib.closing(env_load_fn(env_name)))
            eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
            train_py_env = cleanup.enter_context(contextlib.closing(
                parallel_py_environment.ParallelPyEnvironment(
                    [lambda: env_load_fn(env_name)] * num_parallel_environments)))
            tf_env = tf_py_environment.TFPyEnvironment(train_py_env)

            optimizer = tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate)

            if use_rnns:
                actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                    tf_env.observation_spec(),
                    tf_env.action_spec(),
                    input_fc_layer_params=actor_fc_layers,
                    output_fc_layer_params=None)
                value_net = value_rnn_network.ValueRnnNetwork(
                    tf_env.observation_spec(),
                    input_fc_layer_params=value_fc_layers,
                    output_fc_layer_params=None)
            else:
                actor_net = actor_distribution_network.ActorDistributionNetwork(
                    tf_env.observation_spec(),
                    tf_env.action_spec(),
                    fc_layer_params=actor_fc_layers)
                value_net = value_network.ValueNetwork(
                    tf_env.observation_spec(), fc_layer_params=value_fc_layers)

            tf_agent = ppo_agent.PPOAgent(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                optimizer,
                actor_net=actor_net,
                value_net=value_net,
                num_epochs=num_epochs,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step)
            tf_agent.initialize()

            environment_steps_metric = tf_metrics.EnvironmentSteps()
            step_metrics = [
                tf_metrics.NumberOfEpisodes(),
                environment_steps_metric,
            ]

            train_metrics = step_metrics + [
                tf_metrics.AverageReturnMetric(
                    batch_size=num_parallel_environments),
                tf_metrics.AverageEpisodeLengthMetric(
                    batch_size=num_parallel_environments),
            ]

            eval_policy = tf_agent.policy
            collect_policy = tf_agent.collect_policy

            # The buffer only ever holds the latest collection round: it is
            # read with gather_all() and cleared after every train step.
            replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                tf_agent.collect_data_spec,
                batch_size=num_parallel_environments,
                max_length=replay_buffer_capacity)

            train_checkpointer = common.Checkpointer(
                ckpt_dir=train_dir,
                agent=tf_agent,
                global_step=global_step,
                metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
            policy_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(train_dir, 'policy'),
                policy=eval_policy,
                global_step=global_step)
            saved_model = policy_saver.PolicySaver(
                eval_policy, train_step=global_step)

            train_checkpointer.initialize_or_restore()

            collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
                tf_env,
                collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                num_episodes=collect_episodes_per_iteration)

            def train_step():
                trajectories = replay_buffer.gather_all()
                return tf_agent.train(experience=trajectories)

            if use_tf_functions:
                # TODO(b/123828980): Enable once the cause for slowdown was identified.
                collect_driver.run = common.function(
                    collect_driver.run, autograph=False)
                tf_agent.train = common.function(tf_agent.train, autograph=False)
                train_step = common.function(train_step)

            collect_time = 0
            train_time = 0
            timed_at_step = global_step.numpy()

            while environment_steps_metric.result() < num_environment_steps:
                global_step_val = global_step.numpy()
                if global_step_val % eval_interval == 0 and global_step_val > 0:
                    metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='Metrics',
                    )

                start_time = time.time()
                collect_driver.run()
                collect_time += time.time() - start_time

                start_time = time.time()
                total_loss, _ = train_step()
                replay_buffer.clear()
                train_time += time.time() - start_time

                for train_metric in train_metrics:
                    train_metric.tf_summaries(
                        train_step=global_step, step_metrics=step_metrics)

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val, total_loss)
                    steps_per_sec = (
                        (global_step_val - timed_at_step) / (collect_time + train_time))
                    logging.info('%.3f steps/sec', steps_per_sec)
                    logging.info('collect_time = %s, train_time = %s',
                                 collect_time, train_time)
                    with tf.compat.v2.summary.record_if(True):
                        tf.compat.v2.summary.scalar(
                            name='global_steps_per_sec', data=steps_per_sec,
                            step=global_step)

                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                # Checkpointing is deliberately not nested under the
                # log_interval check: otherwise checkpoint intervals that are
                # not multiples of log_interval would skip saves.
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)
                    saved_model_path = os.path.join(
                        saved_model_dir,
                        'policy_' + ('%d' % global_step_val).zfill(9))
                    saved_model.save(saved_model_path)

            # One final eval before exiting (environments still open here).
            metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
from tf_agents.environments import suite_atari
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.networks import actor_distribution_rnn_network
from tf_agents.networks import value_network
from tf_agents.networks import value_rnn_network
from tf_agents.policies import policy_saver
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common


# Output location.
flags.DEFINE_string(
    'root_dir',
    os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
    'Root directory for writing logs/summaries/checkpoints.')

# Environment selection; the default is the Gym id built by
# suite_atari.game('Breakout').
flags.DEFINE_string(
    'env_name', suite_atari.game('Breakout'), 'Name of an environment')

# Collection and training.
flags.DEFINE_integer(
    'replay_buffer_capacity', 512, 'Replay buffer capacity per env.')
flags.DEFINE_integer(
    'num_parallel_environments', 16,
    'Number of environments to run in parallel')
flags.DEFINE_integer(
    'num_environment_steps', 10000000,
    'Number of environment steps to run before finishing.')
flags.DEFINE_integer(
    'num_epochs', 25, 'Number of epochs for computing policy updates.')
flags.DEFINE_integer(
    'collect_episodes_per_iteration', 16,
    'The number of episodes to take in the environment before each update. '
    'This is the total across all parallel environments.')

# Evaluation.
flags.DEFINE_integer(
    'num_eval_episodes', 16, 'The number of episodes to run eval on.')
예제 #16
0
def get_game_id():
    """Return the suite_atari Gym id for the configured game.

    Calls copy_rom() first (defined elsewhere; presumably makes the ROM
    available -- confirm). It then builds the id from --game_name,
    --game_mode and --game_version.
    """
    copy_rom()
    game_kwargs = dict(
        name=FLAGS.game_name,
        mode=FLAGS.game_mode,
        version=FLAGS.game_version,
    )
    return suite_atari.game(**game_kwargs)
예제 #17
0
def main(_):
    """Train and evaluate on the Atari game named by --game_name."""
    logging.set_verbosity(logging.INFO)
    tf.enable_resource_variables()
    game_id = suite_atari.game(name=FLAGS.game_name)
    run_kwargs = get_run_args()
    trainer = TrainEval(FLAGS.root_dir, game_id, **run_kwargs)
    trainer.run()
예제 #18
0
 def testGameVersion(self):
     """An explicit version replaces the default 'v0' suffix."""
     expected = 'PongNoFrameskip-v4'
     self.assertEqual(suite_atari.game('Pong', version='v4'), expected)