def make_possibly_parallel_environment(env_name_):
    """Create the environment `env_name_`, in parallel when requested.

    `num_parallel_environments` and `env_load_fn` are free names read from
    the surrounding scope; they are not parameters of this function.

    Args:
      env_name_: Name handed to `env_load_fn`.

    Returns:
      `env_load_fn(env_name_)` itself when `num_parallel_environments == 1`,
      otherwise a `parallel_py_environment.ParallelPyEnvironment` built from
      `num_parallel_environments` identical constructors.
    """
    if num_parallel_environments != 1:
        # Every worker is built by the same constructor.
        constructors = [lambda: env_load_fn(env_name_)] * num_parallel_environments
        return parallel_py_environment.ParallelPyEnvironment(constructors)
    return env_load_fn(env_name_)
def train_eval_bomberman(root_dir, num_parallel_environments=4,
                         summary_interval=1000):
    """Set up directories, summary writers and environments for Bomberman.

    NOTE(review): this body only performs setup. It builds summary writers,
    metrics, environments and an optimizer, but never trains or evaluates and
    returns None. It looks unfinished; confirm before relying on it.

    Args:
      root_dir: Experiment root; `~` is expanded. The `train`, `eval`,
        `checkpoint` and `policy` sub-directories are derived from it.
      num_parallel_environments: Number of `BombermanEnvironment` workers in
        the training `ParallelPyEnvironment`.
      summary_interval: Summaries are recorded only on steps where
        `global_step % summary_interval == 0`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    ckpt_dir = os.path.join(root_dir, 'checkpoint')  # Not used yet.
    policy_dir = os.path.join(root_dir, 'policy')  # Not used yet.

    train_summary_writer = tf.summary.create_file_writer(train_dir,
                                                         flush_millis=1000)
    train_summary_writer.set_as_default()
    eval_summary_writer = tf.summary.create_file_writer(eval_dir)

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=10),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=10),
    ]

    global_step = tf.Variable(0)
    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [BombermanEnvironment] * num_parallel_environments))
        # Wrap the eval env as well, so `eval_tf_env` really is a TF
        # environment, like the training env above. The raw PyEnvironment
        # previously stored here did not match its name.
        eval_tf_env = tf_py_environment.TFPyEnvironment(BombermanEnvironment())
        optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
def train_eval(root_dir, env_name, env_load_fn, num_parallel_environments):
    """Create the global step and a parallel evaluation environment.

    NOTE(review): `root_dir`, `global_step` and `eval_py_env` are unused in
    this body and nothing is returned. This looks like a stub; confirm.

    Args:
      root_dir: Experiment root directory (currently unused).
      env_name: Name passed to `env_load_fn`.
      env_load_fn: Callable `env_load_fn(env_name)` building one environment.
      num_parallel_environments: Number of parallel environment workers.
    """
    global_step = tf.compat.v1.train.get_or_create_global_step()
    eval_py_env = parallel_py_environment.ParallelPyEnvironment(
        [lambda: env_load_fn(env_name)] * num_parallel_environments)
def _build_configuration(self):
    """Wire up scenario generation, runtime, environment, SAC agent and runner.

    Sets `_scenario_generator`, `_observer`, `_behavior_model`, `_evaluator`,
    `_viewer`, `_runtime`, `_agent` (a `SACAgent`) and `_runner` (a
    `SACRunner`) on `self`.
    """
    params = self._params
    self._scenario_generator = DeterministicDroneChallengeGeneration(
        num_scenarios=3, random_seed=0, params=params)
    self._observer = CustomObserver(params=params)
    self._behavior_model = DynamicModel(model_name="TripleIntegratorModel",
                                        params=params)
    self._evaluator = CustomEvaluator(params=params)
    # Alternative viewer: VideoRenderer(renderer=<this viewer>, world_step_time=0.2)
    self._viewer = MPViewer(params=params,
                            x_range=[-20, 20],
                            y_range=[-20, 20],
                            follow_agent_id=True)
    self._runtime = RuntimeRL(action_wrapper=self._behavior_model,
                              observer=self._observer,
                              evaluator=self._evaluator,
                              step_time=0.2,
                              viewer=self._viewer,
                              scenario_generator=self._scenario_generator)

    def make_wrapped_runtime():
        # Every parallel worker is constructed from `self._runtime`.
        return TFAWrapper(self._runtime)

    worker_count = params["ML"]["Agent"]["num_parallel_environments"]
    tfa_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [make_wrapped_runtime] * worker_count))
    self._agent = SACAgent(tfa_env, params=params)
    self._runner = SACRunner(tfa_env,
                             self._agent,
                             params=params,
                             unwrapped_runtime=self._runtime)
def create_environment(env_name='CartPole-v0',
                       env_load_fn=suite_gym.load,
                       num_parallel_environments=30,
                       nonparallel=False):
    """Create a (possibly parallel) environment wrapped as a TF environment.

    Args:
        env_name (str): Name passed to `env_load_fn`.
        env_load_fn (Callable): Builds one environment from `env_name`.
        num_parallel_environments (int): Number of worker environments used
            when `nonparallel` is False.
        nonparallel (bool): Build a single env in the current process, behind
            a `ThreadPyEnvironment`. Used so that game gin confs are exposed
            correctly to tensorboard.

    Returns:
        TFPyEnvironment wrapping the constructed environment. Each underlying
        environment is seeded with a value drawn from `np.random.randint`.
    """
    seed_upper = np.iinfo(np.int32).max

    def load():
        return env_load_fn(env_name)

    if not nonparallel:
        py_env = parallel_py_environment.ParallelPyEnvironment(
            [load] * num_parallel_environments)
        py_env.seed([
            np.random.randint(0, seed_upper)
            for _ in range(num_parallel_environments)
        ])
    else:
        # At most one unwrapped env can exist at a time. It is created and
        # stepped on its own thread: some simulators (e.g. social_bot /
        # gazebo) require `step` and `reset` to run on the thread that
        # created the env.
        py_env = ThreadPyEnvironment(load)
        py_env.seed(np.random.randint(0, seed_upper))
    return tf_py_environment.TFPyEnvironment(py_env)
def test_parallel_envs(self):
    # Five SocialBot CartPole workers driven by a random policy for 100 steps;
    # the replay buffer must then yield a sample.
    num_envs = 5
    num_steps = 100
    capacity = 100

    def make_env():
        return suite_socialbot.load('SocialBot-CartPole-v0',
                                    wrap_with_process=False)

    self._env = parallel_py_environment.ParallelPyEnvironment(
        env_constructors=[make_env] * num_envs, start_serially=False)
    tf_env = tf_py_environment.TFPyEnvironment(self._env)
    self.assertTrue(tf_env.batched)
    self.assertEqual(tf_env.batch_size, num_envs)

    policy = random_tf_policy.RandomTFPolicy(tf_env.time_step_spec(),
                                             tf_env.action_spec())
    buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        policy.trajectory_spec,
        batch_size=tf_env.batch_size,
        max_length=capacity)
    driver = dynamic_step_driver.DynamicStepDriver(
        tf_env, policy, observers=[buffer.add_batch], num_steps=num_steps)
    driver.run = common.function(driver.run)
    driver.run()
    self.assertIsNotNone(buffer.get_next())
def get_env(env_name,
            max_episode_steps=None,
            constant_task=None,
            num_parallel_environments=1):
    """Load the environment together with its task distribution.

    Args:
      env_name: (str) name of the environment.
      max_episode_steps: (int) maximum number of steps per episode, or None
        for no limit.
      constant_task: a fixed task to use for every episode, or None to sample
        tasks from the task distribution.
      num_parallel_environments: (int) number of parallel environments.

    Returns:
      tf_env: a TFPyEnvironment built from the dynamics and task distribution.
      task_distribution: the task distribution used for the environment.
    """
    def load_single(return_task_distribution=False):
        env, tasks = get_py_env(env_name,
                                max_episode_steps=max_episode_steps,
                                constant_task=constant_task)
        if not return_task_distribution:
            return env
        return env, tasks

    first_env, task_distribution = load_single(return_task_distribution=True)
    if num_parallel_environments <= 1:
        py_env = first_env
    else:
        # The single env above was only needed for its task distribution;
        # drop it and let each parallel worker build its own env.
        del first_env
        py_env = parallel_py_environment.ParallelPyEnvironment(
            [load_single] * num_parallel_environments)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    return tf_env, task_distribution
def get_tf_env():
    """Return a TFPyEnvironment over `FP.NUM_PARALLEL_ENVS` parallel workers.

    Each worker runs `test_env.CountingEnv(steps_per_episode=10)`.
    """
    def make_counting_env():
        return test_env.CountingEnv(steps_per_episode=10)

    constructors = [make_counting_env] * FP.NUM_PARALLEL_ENVS
    return tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(constructors))
def test_checks_constructors(self):
    # Passing an environment *instance* instead of a constructor callable must
    # be rejected with a TypeError mentioning "non-callable".
    self._set_default_specs()
    # pytype: disable=wrong-arg-types
    with self.assertRaisesRegex(TypeError, '.*non-callable.*'):
        not_a_constructor = random_py_environment.RandomPyEnvironment(
            self.observation_spec, self.action_spec)
        parallel_py_environment.ParallelPyEnvironment([not_a_constructor])
def _make_parallel_py_environment(self, constructor=None, num_envs=2,
                                  blocking=True):
    """Return a ParallelPyEnvironment of `num_envs` copies of `constructor`.

    When `constructor` is falsy, a `RandomPyEnvironment` over the default
    specs (installed by `self._set_default_specs()`) is used instead.
    """
    self._set_default_specs()
    if not constructor:
        constructor = functools.partial(
            random_py_environment.RandomPyEnvironment,
            self.observation_spec,
            self.action_spec)
    env_constructors = [constructor] * num_envs
    return parallel_py_environment.ParallelPyEnvironment(
        env_constructors=env_constructors, blocking=blocking)
def test_dmlab_env(self):
    # Two DMLab 'lt_chasm' workers with grayscale/resize/stack frame wrappers.
    def make_env():
        frame_wrappers = [
            wrappers.FrameGrayScale,
            wrappers.FrameResize,
            wrappers.FrameStack,
        ]
        return suite_dmlab.load(scene='lt_chasm',
                                gym_env_wrappers=frame_wrappers,
                                wrap_with_process=False)

    self._env = parallel_py_environment.ParallelPyEnvironment([make_env] * 2)
    env = tf_py_environment.TFPyEnvironment(self._env)
    # Expected observation shape after the frame wrappers.
    self.assertEqual((84, 84, 4), env.observation_spec().shape)
def _make_parallel_py_environment(self, constructor=None, num_envs=2):
    """Return a blocking ParallelPyEnvironment of `num_envs` workers.

    Also stores the fixed test specs on `self`: a (3, 3) float32 observation
    spec, the matching time-step spec, and a [7]-shaped float32 action spec
    bounded to [-1, 1]. When `constructor` is falsy, a `RandomPyEnvironment`
    over those specs is used.
    """
    obs_spec = array_spec.ArraySpec((3, 3), np.float32)
    self.observation_spec = obs_spec
    self.time_step_spec = ts.time_step_spec(obs_spec)
    self.action_spec = array_spec.BoundedArraySpec([7],
                                                   dtype=np.float32,
                                                   minimum=-1.0,
                                                   maximum=1.0)
    if not constructor:
        constructor = functools.partial(
            random_py_environment.RandomPyEnvironment,
            self.observation_spec,
            self.action_spec)
    return parallel_py_environment.ParallelPyEnvironment(
        env_constructors=[constructor] * num_envs, blocking=True)
def test_dmlab_env_run(self, scene):
    # Four DMLab workers for `scene`, stepped 10 times by a random policy.
    num_envs = 4
    num_steps = 10

    def make_env():
        return suite_dmlab.load(scene=scene,
                                gym_env_wrappers=[wrappers.FrameResize],
                                wrap_with_process=False)

    self._env = parallel_py_environment.ParallelPyEnvironment(
        [make_env] * num_envs)
    env = tf_py_environment.TFPyEnvironment(self._env)
    self.assertEqual((84, 84, 3), env.observation_spec().shape)

    policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(),
                                             env.action_spec())
    driver = dynamic_step_driver.DynamicStepDriver(env=env,
                                                   policy=policy,
                                                   observers=None,
                                                   num_steps=num_steps)
    driver.run(maximum_iterations=num_steps)
def _build_configuration(self):
    """Build the environment, PPO agent and runner around `Pendulum-v0`.

    NOTE(review): the runtime here is a plain gym env
    (`gym.make('Pendulum-v0')`), not a `RuntimeRL`.
    """
    self._runtime = gym.make('Pendulum-v0')

    def wrap_runtime():
        # Every parallel worker is constructed from `self._runtime`.
        return TFAWrapper(self._runtime)

    agent_params = self._params["ML"]["Agent"]
    # Presumably a (key, description, default) lookup on the params object,
    # falling back to 0 -- confirm against the params implementation.
    worker_count = agent_params["num_parallel_environments", "", 0]
    tfa_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [wrap_runtime] * worker_count))
    self._agent = PPOAgent(tfa_env, params=self._params)
    self._runner = PPORunner(tfa_env,
                             self._agent,
                             params=self._params,
                             unwrapped_runtime=self._runtime)
def create_environment(env_name='CartPole-v0',
                       env_load_fn=suite_gym.load,
                       num_parallel_environments=30):
    """Create an environment and wrap it as a TFPyEnvironment.

    Args:
        env_name (str): Name passed to `env_load_fn`.
        env_load_fn (Callable): Builds one environment from `env_name`.
        num_parallel_environments (int): 1 builds the env in-process; any
            other value builds that many workers in a ParallelPyEnvironment.

    Returns:
        TFPyEnvironment wrapping the built environment.
    """
    if num_parallel_environments == 1:
        return tf_py_environment.TFPyEnvironment(env_load_fn(env_name))

    if env_load_fn != suite_socialbot.load:
        loader = env_load_fn
    else:
        logging.info("suite_socialbot environment")

        # ParallelPyEnvironment already puts each env in its own process,
        # so the socialbot loader must not add another process wrapper.
        def loader(name):
            return suite_socialbot.load(name, wrap_with_process=False)

    def make_env():
        return loader(env_name)

    py_env = parallel_py_environment.ParallelPyEnvironment(
        [make_env] * num_parallel_environments)
    return tf_py_environment.TFPyEnvironment(py_env)
def test_mario_env(self):
    # Four Mario workers driven by a random policy while collecting metrics.
    batch = 4
    num_steps = 10000

    def make_env():
        return suite_mario.load('SuperMarioBros-Nes',
                                'Level1-1',
                                wrap_with_process=False)

    self._env = parallel_py_environment.ParallelPyEnvironment(
        [make_env] * batch)
    env = tf_py_environment.TFPyEnvironment(self._env)
    obs_spec = env.observation_spec()
    self.assertEqual(np.uint8, obs_spec.dtype)
    self.assertEqual((84, 84, 4), obs_spec.shape)

    policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(),
                                             env.action_spec())
    observers = [
        AverageReturnMetric(batch_size=batch),
        AverageEpisodeLengthMetric(batch_size=batch),
        EnvironmentSteps(),
        NumberOfEpisodes(),
    ]
    driver = dynamic_step_driver.DynamicStepDriver(env, policy, observers,
                                                   num_steps)
    driver.run(maximum_iterations=num_steps)
def create_envs(env_name,
                use_multiprocessing,
                num_parallel_envs,
                visualize_eval=False,
                mock_train_envs=False):
    """Create the training and evaluation environments for an SC2 map.

    Args:
      env_name: Map name passed to `SC2GymEnv(map_name=...)`.
      use_multiprocessing: When `num_parallel_envs != 1`, use a
        ParallelPyEnvironment instead of an in-process BatchedPyEnvironment.
      num_parallel_envs: Number of training environments.
      visualize_eval: Passed as `visualize` to the evaluation env.
      mock_train_envs: Passed as `mock` to every training env.

    Returns:
      `(tf_env, eval_env)`: two TFPyEnvironments, each reset once.
    """
    def env_load_fn(env_map_name, visualize=False, mock=False):
        sc2_env = SC2GymEnv(map_name=env_map_name, visualize=visualize,
                            mock=mock)
        return gym_wrapper.GymWrapper(
            gym_env=sc2_env,
            spec_dtype_map={
                gym.spaces.Box: np.float32,
                gym.spaces.Discrete: np.int32,
                gym.spaces.MultiBinary: np.float32,
            },
        )

    def make_train_env():
        return env_load_fn(env_map_name=env_name, mock=mock_train_envs)

    if num_parallel_envs == 1:
        train_py_env = make_train_env()
    elif use_multiprocessing:
        train_py_env = parallel_py_environment.ParallelPyEnvironment(
            [make_train_env] * num_parallel_envs, start_serially=False)
    else:
        train_py_env = batched_py_environment.BatchedPyEnvironment(
            envs=[make_train_env() for _ in range(num_parallel_envs)])

    tf_env = tf_py_environment.TFPyEnvironment(train_py_env)
    tf_env.reset()

    eval_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(env_name, visualize=visualize_eval))
    eval_env.reset()
    return tf_env, eval_env
def test():
    """Run SafeDynamicEpisodeDriver once with a random policy on point_mass.

    Uses one parallel point_mass worker per requested episode, feeding an
    episodic buffer, a stateful episodic view of it (as observer) and a
    uniform buffer.
    """
    num_episodes = 5

    def make_env():
        return point_mass.env_load_fn()

    py_env = parallel_py_environment.ParallelPyEnvironment(
        [make_env] * num_episodes)
    env = tf_py_environment.TFPyEnvironment(py_env)
    step_spec = env.time_step_spec()
    policy = random_tf_policy.RandomTFPolicy(step_spec, env.action_spec())
    traj_spec = trajectory.from_transition(step_spec,
                                           policy.policy_step_spec,
                                           step_spec)

    episodic_rb = episodic_replay_buffer.EpisodicReplayBuffer(traj_spec)
    stateful_rb = episodic_replay_buffer.StatefulEpisodicReplayBuffer(
        episodic_rb, num_episodes=num_episodes)
    uniform_rb = tf_uniform_replay_buffer.TFUniformReplayBuffer(traj_spec, 1)

    driver = safe_dynamic_episode_driver.SafeDynamicEpisodeDriver(
        env,
        policy,
        episodic_rb,
        uniform_rb,
        observers=[stateful_rb.add_batch],
        num_episodes=num_episodes)
    driver.run()
def train_eval(
        root_dir,
        env_name='MultiGrid-Empty-5x5-v0',
        env_load_fn=multiagent_gym_suite.load,
        random_seed=0,
        # Architecture params
        agent_class=multiagent_ppo.MultiagentPPO,
        actor_fc_layers=(64, 64),
        value_fc_layers=(64, 64),
        lstm_size=(64,),
        conv_filters=64,
        conv_kernel=3,
        direction_fc=5,
        entropy_regularization=0.,
        use_attention_networks=False,
        # Specialized agents
        inactive_agent_ids=tuple(),
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=5,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=2,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=2,
        eval_interval=5,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=100,
        log_interval=10,
        summary_interval=10,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=True,
        summarize_grads_and_vars=True,
        eval_metrics_callback=None,
        reinit_checkpoint_dir=None,
        debug=True):
    """A simple train and eval for multi-agent PPO.

    Trains `agent_class` on `num_parallel_environments` copies of `env_name`
    until `num_environment_steps` environment steps have been collected.
    Evaluation runs every `eval_interval` train steps and once more at the end.

    Args (selected):
      root_dir: Experiment root; `train/`, `eval/` and `policy_saved_model/`
        are created under it.
      random_seed: Global TF seed and base for per-environment gym seeds. May
        be None, in which case no TF seed is set and the environments receive
        `seed=None`.
      inactive_agent_ids: Agents that are fixed and not trained.
      reinit_checkpoint_dir: If set, only the last agent is restored from this
        checkpoint. The other agents are rebuilt as non-learning novices, and
        `train_dir` is not restored.
      debug: Turns tf.functions off and enables verbose logging.
      The remaining parameters configure the agent, collection, training,
      evaluation, checkpointing and summaries as their names indicate.

    Raises:
      AttributeError: if `root_dir` is None.
    """
    tf.compat.v1.enable_v2_behavior()

    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')
    if debug:
        logging.info('In debug mode, turning tf_functions off')
        use_tf_functions = False

    for a in inactive_agent_ids:
        logging.info('Fixing and not training agent %d', a)

    # Load multiagent gym environment and determine number of agents
    gym_env = env_load_fn(env_name)
    n_agents = gym_env.n_agents

    # Set up logging
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        multiagent_metrics.AverageReturnMetric(n_agents,
                                               buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)

        logging.info('Creating %d environments...', num_parallel_environments)
        wrappers = []
        if use_attention_networks:
            wrappers = [
                lambda env: utils.LSTMStateWrapper(env, lstm_size=lstm_size)
            ]

        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(env_name,
                        gym_kwargs=dict(seed=random_seed),
                        gym_env_wrappers=wrappers))

        def _worker_seed(i):
            # `random_seed` may legitimately be None (see the check above).
            # Previously `None * 1234` raised a TypeError here; now each
            # worker gets seed=None, just like the eval env.
            if random_seed is None:
                return None
            return random_seed * 1234 + i

        # pylint: disable=g-complex-comprehension
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                functools.partial(env_load_fn,
                                  environment_name=env_name,
                                  gym_env_wrappers=wrappers,
                                  gym_kwargs=dict(seed=_worker_seed(i)))
                for i in range(num_parallel_environments)
            ]))

        logging.info('Preparing to train...')
        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        bonus_metrics = [
            multiagent_metrics.MultiagentScalar(
                n_agents, name='UnscaledMultiagentBonus', buffer_size=1000),
        ]
        train_metrics = step_metrics + [
            multiagent_metrics.AverageReturnMetric(
                n_agents, batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        logging.info('Creating agent...')
        tf_agent = agent_class(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            n_agents=n_agents,
            learning_rate=learning_rate,
            actor_fc_layers=actor_fc_layers,
            value_fc_layers=value_fc_layers,
            lstm_size=lstm_size,
            conv_filters=conv_filters,
            conv_kernel=conv_kernel,
            direction_fc=direction_fc,
            entropy_regularization=entropy_regularization,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
            inactive_agent_ids=inactive_agent_ids)
        tf_agent.initialize()
        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        logging.info('Allocating replay buffer ...')
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)
        logging.info('RB capacity: %i', replay_buffer.capacity)

        # If reinit_checkpoint_dir is provided, the last agent in the checkpoint
        # is reinitialized. The other agents are novices.
        # Otherwise, all agents are reinitialized from train_dir.
        if reinit_checkpoint_dir:
            reinit_checkpointer = common.Checkpointer(
                ckpt_dir=reinit_checkpoint_dir,
                agent=tf_agent,
            )
            reinit_checkpointer.initialize_or_restore()
            temp_dir = os.path.join(train_dir, 'tmp')
            agent_checkpointer = common.Checkpointer(
                ckpt_dir=temp_dir,
                agent=tf_agent.agents[:-1],
            )
            agent_checkpointer.save(global_step=0)
            tf_agent = agent_class(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                n_agents=n_agents,
                learning_rate=learning_rate,
                actor_fc_layers=actor_fc_layers,
                value_fc_layers=value_fc_layers,
                lstm_size=lstm_size,
                conv_filters=conv_filters,
                conv_kernel=conv_kernel,
                direction_fc=direction_fc,
                entropy_regularization=entropy_regularization,
                num_epochs=num_epochs,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
                inactive_agent_ids=inactive_agent_ids,
                non_learning_agents=list(range(n_agents - 1)))
            agent_checkpointer = common.Checkpointer(
                ckpt_dir=temp_dir, agent=tf_agent.agents[:-1])
            agent_checkpointer.initialize_or_restore()
            tf.io.gfile.rmtree(temp_dir)
            eval_policy = tf_agent.policy
            collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=multiagent_metrics.MultiagentMetricsGroup(
                train_metrics + bonus_metrics, 'train_metrics'))
        if not reinit_checkpoint_dir:
            train_checkpointer.initialize_or_restore()
        logging.info('Successfully initialized train checkpointer')

        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        # Use a separate directory for the collect policy. Two checkpoint
        # managers sharing `train/policy` both write `ckpt-<step>` at the same
        # step, so the collect policy overwrote the eval policy's checkpoint.
        collect_policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'collect_policy'),
            policy=collect_policy,
            global_step=global_step)
        collect_saved_model = policy_saver.PolicySaver(collect_policy,
                                                       train_step=global_step)

        logging.info('Successfully initialized policy saver.')

        logging.info('Using TFDriver')
        if use_attention_networks:
            collect_driver = drivers.StateTFDriver(
                tf_env,
                collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                max_episodes=collect_episodes_per_iteration,
                disable_tf_function=not use_tf_functions)
        else:
            collect_driver = tf_driver.TFDriver(
                tf_env,
                collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                max_episodes=collect_episodes_per_iteration,
                disable_tf_function=not use_tf_functions)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        # How many consecutive steps was loss diverged for.
        loss_divergence_counter = 0

        # Save operative config as late as possible to include used configurables.
        if global_step.numpy() == 0:
            config_filename = os.path.join(
                train_dir, 'operative_config-{}.gin'.format(global_step.numpy()))
            with tf.io.gfile.GFile(config_filename, 'wb') as f:
                f.write(gin.operative_config_str())

        total_episodes = 0
        logging.info('Commencing train loop!')
        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()

            # Evaluation
            if global_step_val % eval_interval == 0:
                if debug:
                    logging.info('Performing evaluation at step %d',
                                 global_step_val)
                results = multiagent_metrics.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                    use_function=use_tf_functions,
                    use_attention_networks=use_attention_networks)
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                multiagent_metrics.log_metrics(eval_metrics)

            # Collect data
            if debug:
                logging.info('Collecting at step %d', global_step_val)
            start_time = time.time()
            time_step = tf_env.reset()
            policy_state = collect_policy.get_initial_state(tf_env.batch_size)
            if use_attention_networks:
                # Attention networks require previous policy state to compute
                # attention weights.
                time_step.observation['policy_state'] = (
                    policy_state['actor_network_state'][0],
                    policy_state['actor_network_state'][1])
            collect_driver.run(time_step, policy_state)
            collect_time += time.time() - start_time

            total_episodes += collect_episodes_per_iteration
            if debug:
                logging.info('Have collected a total of %d episodes',
                             total_episodes)

            # Train
            if debug:
                logging.info('Training at step %d', global_step_val)
            start_time = time.time()
            total_loss, extra_loss = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss) or
                    total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    logging.info(
                        'Loss diverged for too many timesteps, breaking...')
                    break
            else:
                loss_divergence_counter = 0

            for train_metric in train_metrics + bonus_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, total loss = %f', global_step_val,
                             total_loss)
                for a in range(n_agents):
                    if not inactive_agent_ids or a not in inactive_agent_ids:
                        logging.info('Loss for agent %d = %f', a,
                                     extra_loss[a].loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)
                # Restart the throughput window after each report.
                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                saved_model_path = os.path.join(
                    saved_model_dir,
                    'policy_' + ('%d' % global_step_val).zfill(9))
                saved_model.save(saved_model_path)
                collect_policy_checkpointer.save(global_step=global_step_val)
                collect_saved_model_path = os.path.join(
                    saved_model_dir,
                    'collect_policy_' + ('%d' % global_step_val).zfill(9))
                collect_saved_model.save(collect_saved_model_path)

        # One final eval before exiting.
        results = multiagent_metrics.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
            use_function=use_tf_functions,
            use_attention_networks=use_attention_networks)
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        multiagent_metrics.log_metrics(eval_metrics)
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=None,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        lstm_size=(20,),
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-3,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """A simple train and eval for PPO.

    Trains a `ppo_clip_agent.PPOClipAgent` on `num_parallel_environments`
    copies of `env_name` until `num_environment_steps` environment steps have
    been collected. Evaluation runs every `eval_interval` train steps and once
    more at the end.

    Args (selected):
      root_dir: Experiment root; `train/`, `eval/` and `policy_saved_model/`
        are created under it.
      use_rnns: Use LSTM actor/value networks instead of feed-forward ones.
      train_checkpoint_interval / policy_checkpoint_interval: Save the
        train checkpoint / policy checkpoint + SavedModel on train steps that
        are multiples of these values, independently of `log_interval`.
      The remaining parameters configure collection, training, evaluation and
      summaries as their names indicate.

    Raises:
      AttributeError: if `root_dir` is None.
    """
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)
        eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None,
                lstm_size=lstm_size)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

        tf_agent = ppo_clip_agent.PPOClipAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            entropy_regularization=0.0,
            importance_ratio_clipping=0.2,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        train_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)
                # Restart the throughput window after each report.
                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            # Checkpointing is deliberately NOT nested under the log branch.
            # Otherwise checkpoints would be skipped whenever a checkpoint
            # step is not also a multiple of `log_interval`.
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                saved_model_path = os.path.join(
                    saved_model_dir,
                    'policy_' + ('%d' % global_step_val).zfill(9))
                saved_model.save(saved_model_path)

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
def __init__(self,
             root_dir,
             env_load_fn=suite_gym.load,
             env_name='CartPole-v0',
             num_parallel_environments=1,
             agent_class=None,
             num_eval_episodes=30,
             write_summaries=True,
             summaries_flush_secs=10,
             eval_metrics_callback=None,
             env_metric_factories=None):
    """Evaluate policy checkpoints as they are produced.

    Args:
      root_dir: Main directory for experiment files.
      env_load_fn: Function to load the environment specified by env_name.
      env_name: Name of environment to evaluate in.
      num_parallel_environments: Number of environments to evaluate on in
        parallel.
      agent_class: TFAgent class to instantiate for evaluation.
      num_eval_episodes: Number of episodes to average evaluation over.
      write_summaries: Whether to write summaries to the file system.
      summaries_flush_secs: How frequently to flush summaries (in seconds).
      eval_metrics_callback: A function that will be called with evaluation
        results for every checkpoint.
      env_metric_factories: An iterable of metric factories. Use this for eval
        metrics that need access to the evaluated environment. A metric
        factory is a function that takes an environment and buffer_size as
        keyword arguments and returns an instance of py_metric.

    Raises:
      ValueError: when `agent_class` is not set, when
        num_parallel_environments > num_eval_episodes, or when
        `env_metric_factories` is given but the loaded environment is not a
        PyEnvironment.
    """
    if not agent_class:
        raise ValueError(
            'The `agent_class` parameter of Evaluator must be set.')
    if num_parallel_environments > num_eval_episodes:
        raise ValueError(
            'num_parallel_environments should not be greater than '
            'num_eval_episodes')

    self._num_eval_episodes = num_eval_episodes
    self._eval_metrics_callback = eval_metrics_callback
    # Flag that controls the eval cycle. If set, evaluation exits the eval
    # loop before the max checkpoint number is reached.
    self._terminate_early = False
    # Keep the root dir on self so derived classes have access to it.
    self._root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(self._root_dir, 'train')
    self._eval_dir = os.path.join(self._root_dir, 'eval')
    self._global_step = tf.compat.v1.train.get_or_create_global_step()
    self._env_name = env_name

    if num_parallel_environments == 1:
        eval_env = env_load_fn(env_name)
    else:
        eval_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)

    if isinstance(eval_env, py_environment.PyEnvironment):
        self._eval_tf_env = tf_py_environment.TFPyEnvironment(eval_env)
        self._eval_py_env = eval_env
    else:
        self._eval_tf_env = eval_env
        self._eval_py_env = None  # Can't generically convert to PyEnvironment.

    self._eval_metrics = [
        tf_metrics.AverageReturnMetric(
            buffer_size=self._num_eval_episodes,
            batch_size=self._eval_tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=self._num_eval_episodes,
            batch_size=self._eval_tf_env.batch_size),
    ]
    if env_metric_factories:
        # Explicit None check: `_eval_py_env` is None exactly when the loaded
        # env was not a PyEnvironment (see the branch above).
        if self._eval_py_env is None:
            raise ValueError(
                'The `env_metric_factories` parameter of Evaluator '
                'can only be used with a PyEnvironment environment.')
        for metric_factory in env_metric_factories:
            py_metric = metric_factory(environment=self._eval_py_env,
                                       buffer_size=self._num_eval_episodes)
            self._eval_metrics.append(tf_py_metric.TFPyMetric(py_metric))

    if write_summaries:
        self._eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            self._eval_dir, flush_millis=summaries_flush_secs * 1000)
        self._eval_summary_writer.set_as_default()
    else:
        self._eval_summary_writer = None

    environment_specs.set_observation_spec(
        self._eval_tf_env.observation_spec())
    environment_specs.set_action_spec(self._eval_tf_env.action_spec())

    # Agent params configured with gin.
    self._agent = agent_class(self._eval_tf_env.time_step_spec(),
                              self._eval_tf_env.action_spec())

    self._eval_policy = greedy_policy.GreedyPolicy(self._agent.policy)
    self._eval_policy.action = common.function(self._eval_policy.action)

    # Run the agent on dummy data to force instantiation of the network. Keras
    # doesn't create variables until you first use the layer. This is needed
    # for checkpoint restoration to work.
    dummy_obs = tensor_spec.sample_spec_nest(
        self._eval_tf_env.observation_spec(),
        outer_dims=(self._eval_tf_env.batch_size,))
    self._eval_policy.action(
        ts.restart(dummy_obs, batch_size=self._eval_tf_env.batch_size),
        self._eval_policy.get_initial_state(self._eval_tf_env.batch_size))

    self._policy_checkpoint = tf.train.Checkpoint(
        policy=self._agent.policy, global_step=self._global_step)
    self._policy_checkpoint_dir = os.path.join(train_dir, 'policy')
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_allowlist='position',
        eval_env_name=None,
        num_iterations=1000000,
        # Params for networks.
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100,),
        actor_lstm_size=(40,),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300,),
        critic_output_fc_layers=(100,),
        critic_lstm_size=(40,),
        num_parallel_environments=1,
        # Params for collect
        initial_collect_episodes=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        critic_learning_rate=3e-4,
        train_sequence_length=20,
        actor_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=0.1,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for RNN SAC on DM control.

    Builds a (possibly parallel) ``suite_dm_control`` training environment and
    a single evaluation environment. Constructs LSTM actor and critic networks
    and a ``SacAgent``, and restores train/replay-buffer checkpoints from
    ``root_dir`` if present. It then alternates episode collection with
    ``train_steps_per_iteration`` gradient steps per collected env step.

    Args:
      root_dir: Output directory. Summaries are written to it directly.
        Checkpoints go to ``train``, ``policy`` and ``replay_buffer``
        subdirectories.
      env_name: DM control domain used for training.
      task_name: DM control task passed to ``suite_dm_control.load``.
      observations_allowlist: If not None, a single observation key kept by
        ``FlattenObservationsWrapper``. Otherwise no wrapper is applied.
      eval_env_name: Domain used for evaluation. Defaults to ``env_name``.
      num_iterations: Number of collect iterations. Each runs
        ``collect_episodes_per_iteration`` episodes.
      *_fc_layers, *_lstm_size: Network layer configuration.
      num_parallel_environments: Number of training environments. Values
        greater than 1 use ``ParallelPyEnvironment``.
      initial_collect_episodes: Episodes collected with a uniform random
        policy when starting from an empty buffer.
      collect_episodes_per_iteration: Episodes collected per iteration.
      replay_buffer_capacity: Max length of the per-batch replay buffer.
      target_update_tau, target_update_period, td_errors_loss_fn, gamma,
      reward_scale_factor, gradient_clipping, *_learning_rate: SAC agent
        hyper-parameters.
      train_steps_per_iteration: Gradient steps per collected env step.
      batch_size: Number of sequences per training batch.
      train_sequence_length: Sequence length used for RNN training. One extra
        step is sampled so terminal transitions can be trained on.
      use_tf_functions: Wrap drivers, ``tf_agent.train`` and the train step in
        ``common.function``.
      num_eval_episodes: Episodes per evaluation. Also used as the buffer
        size of the averaging metrics.
      eval_interval, *_checkpoint_interval, log_interval, summary_interval:
        Periods in global (train) steps.
      summaries_flush_secs: Flush period of the summary writer.
      debug_summaries, summarize_grads_and_vars: Forwarded to the agent.
      eval_metrics_callback: Optional ``callback(results, env_steps)`` invoked
        after each evaluation.
    """
    root_dir = os.path.expanduser(root_dir)

    summary_writer = tf.compat.v2.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_allowlist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_allowlist=[observations_allowlist])
            ]
        else:
            env_wrappers = []

        env_load_fn = functools.partial(suite_dm_control.load,
                                        task_name=task_name,
                                        env_wrappers=env_wrappers)

        if num_parallel_environments == 1:
            py_env = env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers,
            continuous_projection_net=(
                tanh_normal_projection_network.TanhNormalProjectionNetwork))

        critic_net = critic_rnn_network.CriticRnnNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
        average_return = tf_metrics.AverageReturnMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix='Train'),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='Train',
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=initial_collect_episodes)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data only when nothing was restored.
        if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d episodes '
                'with a random policy.', initial_collect_episodes)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            # Reduce filter_fn over the full sampled trajectory. The sequence
            # is kept only if all elements except the last one pass the
            # filter. This allows training on terminal steps.
            return tf.reduce_all(~trajectories.is_boundary()[:-1])

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            start_env_steps = env_steps.result()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            episode_steps = env_steps.result() - start_env_steps
            # TODO(b/152648849)
            # Train once per collected env step so the global-step modulo
            # checks below are evaluated after every training update.
            for _ in range(episode_steps):
                for _ in range(train_steps_per_iteration):
                    train_step()
                # Accumulate only the time elapsed since the previous
                # accumulation. Without resetting start_time, the time since
                # the start of the iteration would be re-added on every inner
                # step.
                now = time.time()
                time_acc += now - start_time
                start_time = now

                if global_step.numpy() % log_interval == 0:
                    logging.info('env steps = %d, average return = %f',
                                 env_steps.result(), average_return.result())
                    env_steps_per_sec = (env_steps.result().numpy() -
                                         env_steps_before) / time_acc
                    logging.info('%.3f env steps/sec', env_steps_per_sec)
                    tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                                data=env_steps_per_sec,
                                                step=env_steps.result())
                    time_acc = 0
                    env_steps_before = env_steps.result().numpy()

                for train_metric in train_metrics:
                    train_metric.tf_summaries(train_step=env_steps.result())

                if global_step.numpy() % eval_interval == 0:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    if eval_metrics_callback is not None:
                        # `env_steps` is a metric object; its value is
                        # `result()`. The metric itself has no `.numpy()`.
                        eval_metrics_callback(results,
                                              env_steps.result().numpy())
                    metric_utils.log_metrics(eval_metrics)

                global_step_val = global_step.numpy()
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
def train_eval(
        load_root_dir,
        env_load_fn=None,
        gym_env_wrappers=None,
        monitor=False,
        env_name=None,
        agent_class=None,
        train_metrics_callback=None,
        # SacAgent args
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256),
        # Safety Critic training args
        safety_critic_joint_fc_layers=None,
        safety_critic_lr=3e-4,
        safety_critic_bias_init_val=None,
        safety_critic_kernel_scale=None,
        n_envs=None,
        target_safety=0.2,
        fail_weight=None,
        # Params for train
        num_global_steps=10000,
        batch_size=256,
        # Params for eval
        run_eval=False,
        eval_metrics=None,
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        summary_interval=1000,
        monitor_interval=5000,
        summaries_flush_secs=10,
        debug_summaries=False,
        seed=None):
    """Trains an offline safety critic from a previously trained agent.

    Rebuilds an agent of ``agent_class`` and restores its weights and replay
    buffer from ``<load_root_dir>/train``. It then resets the global step to 0
    and trains a fresh ``SafetyCriticOffline`` network. Each training batch
    combines half uniformly sampled transitions with half transitions drawn
    from a buffer containing only failure steps (``observation['task_agn_rew']``
    truthy) and the steps immediately preceding them. Summaries, the gin
    config and safety-critic checkpoints are written under
    ``<load_root_dir>/train/offline/<UTC timestamp>``.

    Args:
      load_root_dir: Root directory of the run to load from.
      env_load_fn: ``env_load_fn(env_name, gym_env_wrappers=...)`` returning a
        py environment.
      gym_env_wrappers: Gym wrappers passed to ``env_load_fn`` and applied to
        the monitor environment. Defaults to an empty list.
      monitor: If True, periodically record a rollout of the safe policy.
      env_name: Environment name. Also passed to ``gym.make`` for the monitor.
      agent_class: Agent class, or a key into ``ALGOS``.
      train_metrics_callback: Optional ``callback(results, step)`` invoked
        after each evaluation.
      actor_fc_layers, critic_joint_fc_layers: Layer sizes of the loaded
        agent's networks. These must match the checkpoint.
      safety_critic_joint_fc_layers: Layer sizes of the new safety critic.
      safety_critic_lr: Adam learning rate for the safety critic.
      safety_critic_bias_init_val: Optional constant bias initializer value.
      safety_critic_kernel_scale: Optional variance-scaling kernel scale.
      n_envs: Number of parallel eval environments. Defaults to
        ``num_eval_episodes``.
      target_safety: Safety threshold. Falls back to the agent's own
        ``_target_safety`` for safety agents when falsy.
      fail_weight: Optional loss weight for failure samples.
      num_global_steps: Number of safety-critic training steps.
      batch_size: Total batch size (half positive, half failure samples).
      run_eval: If True, periodically evaluate the safe policy.
      eval_metrics: Extra py metrics, wrapped as ``TFPyMetric`` during eval.
      num_eval_episodes, eval_interval: Evaluation configuration.
      train_checkpoint_interval, summary_interval, monitor_interval: Periods
        in global steps.
      summaries_flush_secs: Flush period of the summary writers.
      debug_summaries: Forwarded to the safety-critic ``train_step``.
      seed: Optional random seed for the eval environments and TF.
    """
    if isinstance(agent_class, str):
        assert agent_class in ALGOS, (
            'trainer.train_eval: agent_class {} invalid'.format(agent_class))
        agent_class = ALGOS.get(agent_class)

    # Avoid a shared mutable default; callers still receive a list.
    if gym_env_wrappers is None:
        gym_env_wrappers = []

    train_ckpt_dir = osp.join(load_root_dir, 'train')
    rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

    py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    if monitor:
        vid_path = os.path.join(load_root_dir, 'rollouts')
        monitor_env_wrapper = misc.monitor_freq(1, vid_path)
        monitor_env = gym.make(env_name)
        for wrapper in gym_env_wrappers:
            monitor_env = wrapper(monitor_env)
        monitor_env = monitor_env_wrapper(monitor_env)
        # auto_reset must be False to ensure Monitor works correctly
        monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

    if run_eval:
        eval_dir = os.path.join(load_root_dir, 'eval')
        n_envs = n_envs or num_eval_episodes
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(prefix='EvalMetrics',
                                           buffer_size=num_eval_episodes,
                                           batch_size=n_envs),
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='EvalMetrics',
                buffer_size=num_eval_episodes,
                batch_size=n_envs)
        ] + [
            tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
            for m in (eval_metrics or [])
        ]
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        if seed:
            seeds = [seed * n_envs + i for i in range(n_envs)]
            try:
                eval_tf_env.pyenv.seed(seeds)
            except Exception as e:  # pylint: disable=broad-except
                # Seeding is best-effort: not every environment supports it.
                logging.warning('Could not seed eval environments: %s', e)

    global_step = tf.compat.v1.train.get_or_create_global_step()

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)
    critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
        safety_critic_net = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=critic_joint_fc_layers)
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               safety_critic_network=safety_critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)
    else:
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)

    collect_data_spec = tf_agent.collect_data_spec
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=1000000)
    replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

    tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
    if agent_class in SAFETY_AGENTS:
        target_safety = target_safety or tf_agent._target_safety
    loaded_train_steps = global_step.numpy()
    logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir,
                 loaded_train_steps)
    # The offline safety-critic run counts its own steps from zero.
    global_step.assign(0)
    tf.summary.experimental.set_step(global_step)

    thresholds = [target_safety, 0.5]
    sc_metrics = [
        tf.keras.metrics.AUC(name='safety_critic_auc'),
        tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                        threshold=0.5),
        tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                       thresholds=thresholds),
        tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                        thresholds=thresholds),
        tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                       thresholds=thresholds),
        tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                        thresholds=thresholds)
    ]

    if seed:
        tf.compat.v1.set_random_seed(seed)

    # The caller's `summaries_flush_secs` is honored here; it was previously
    # overwritten with a hard-coded value.
    timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
    offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
    config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                             summarize_config=True)
    tf.function(config_saver.after_create_session)()

    sc_summary_writer = tf.compat.v2.summary.create_file_writer(
        offline_train_dir, flush_millis=summaries_flush_secs * 1000)
    sc_summary_writer.set_as_default()

    if safety_critic_kernel_scale is not None:
        ki = tf.compat.v1.variance_scaling_initializer(
            scale=safety_critic_kernel_scale,
            mode='fan_in',
            distribution='truncated_normal')
    else:
        ki = tf.compat.v1.keras.initializers.VarianceScaling(
            scale=1. / 3., mode='fan_in', distribution='uniform')

    if safety_critic_bias_init_val is not None:
        bi = tf.constant_initializer(safety_critic_bias_init_val)
    else:
        bi = None
    sc_net_off = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=safety_critic_joint_fc_layers,
        kernel_initializer=ki,
        value_bias_initializer=bi,
        name='SafetyCriticOffline')
    sc_net_off.create_variables()
    target_sc_net_off = common.maybe_copy_target_network_with_checks(
        sc_net_off, None, 'TargetSafetyCriticNetwork')
    optimizer = tf.keras.optimizers.Adam(safety_critic_lr)
    sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
    sc_checkpointer = common.Checkpointer(
        ckpt_dir=sc_net_off_ckpt_dir,
        safety_critic=sc_net_off,
        target_safety_critic=target_sc_net_off,
        optimizer=optimizer,
        global_step=global_step,
        max_to_keep=5)
    sc_checkpointer.initialize_or_restore()

    resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
    eval_policy = agents.SafeActorPolicyRSVar(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_net,
        safety_critic_network=sc_net_off,
        safety_threshold=target_safety,
        resample_counter=resample_counter,
        training=True)

    # Half of every batch is sampled uniformly from the full buffer.
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3, num_steps=2,
        sample_batch_size=batch_size // 2).prefetch(3)
    data = iter(dataset)
    full_data = replay_buffer.gather_all()

    # Split out failure steps, episode starts and their neighbours. These are
    # used by the safety-critic evaluation function below.
    fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
    fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
    init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
    before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
    after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
    before_fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
    after_init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

    # Copy failure steps (and the step before each) into a dedicated buffer.
    # The other half of every batch is sampled from it.
    filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
    filter_mask = tf.pad(
        filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]])
    n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

    failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec,
        batch_size=1,
        max_length=n_failures,
        dataset_window_shift=1)
    data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

    sc_dataset_neg = failure_buffer.as_dataset(
        num_parallel_calls=3, sample_batch_size=batch_size // 2,
        num_steps=2).prefetch(3)
    neg_data = iter(sc_dataset_neg)

    get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
    eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                                after_init_step, get_action)

    losses = []
    mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')
    target_update = train_utils.get_target_updater(sc_net_off,
                                                   target_sc_net_off)

    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        while global_step.numpy() < num_global_steps:
            pos_experience, _ = next(data)
            neg_experience, _ = next(neg_data)
            exp = data_utils.concat_batches(pos_experience, neg_experience,
                                            collect_data_spec)
            # Drop samples whose first step is an episode boundary.
            boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
            exp = nest_utils.fast_map_structure(
                lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
            safe_rew = exp.observation['task_agn_rew'][:, 1]
            if fail_weight:
                weights = tf.where(tf.cast(safe_rew, tf.bool),
                                   fail_weight / 0.5,
                                   (1 - fail_weight) / 0.5)
            else:
                weights = None
            train_loss, sc_loss, lam_loss = train_step(
                exp,
                safe_rew,
                tf_agent,
                sc_net=sc_net_off,
                target_sc_net=target_sc_net_off,
                metrics=sc_metrics,
                weights=weights,
                target_safety=target_safety,
                optimizer=optimizer,
                target_update=target_update,
                debug_summaries=debug_summaries)
            global_step.assign_add(1)
            global_step_val = global_step.numpy()
            losses.append(
                (train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy()))
            mean_loss(train_loss)
            with tf.name_scope('Losses'):
                tf.compat.v2.summary.scalar(name='sc_loss',
                                            data=sc_loss,
                                            step=global_step_val)
                tf.compat.v2.summary.scalar(name='lam_loss',
                                            data=lam_loss,
                                            step=global_step_val)
                if global_step_val % summary_interval == 0:
                    tf.compat.v2.summary.scalar(name=mean_loss.name,
                                                data=mean_loss.result(),
                                                step=global_step_val)

            if global_step_val % summary_interval == 0:
                with tf.name_scope('Metrics'):
                    for metric in sc_metrics:
                        if len(tf.squeeze(metric.result()).shape) == 0:
                            tf.compat.v2.summary.scalar(name=metric.name,
                                                        data=metric.result(),
                                                        step=global_step_val)
                        else:
                            # Threshold metrics report one value per entry
                            # of `thresholds`.
                            fmt_str = '_{}'.format(thresholds[0])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[0],
                                step=global_step_val)
                            fmt_str = '_{}'.format(thresholds[1])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[1],
                                step=global_step_val)
                        metric.reset_states()

            if global_step_val % eval_interval == 0:
                eval_sc(sc_net_off, step=global_step_val)
                if run_eval:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='EvalMetrics',
                    )
                    if train_metrics_callback is not None:
                        train_metrics_callback(results, global_step_val)
                    metric_utils.log_metrics(eval_metrics)
                    with eval_summary_writer.as_default():
                        for eval_metric in eval_metrics[2:]:
                            eval_metric.tf_summaries(
                                train_step=global_step,
                                step_metrics=eval_metrics[:2])

            if monitor and global_step_val % monitor_interval == 0:
                monitor_time_step = monitor_py_env.reset()
                monitor_policy_state = eval_policy.get_initial_state(1)
                ep_len = 0
                monitor_start = time.time()
                while not monitor_time_step.is_last():
                    monitor_action = eval_policy.action(monitor_time_step,
                                                        monitor_policy_state)
                    action, monitor_policy_state = (monitor_action.action,
                                                    monitor_action.state)
                    monitor_time_step = monitor_py_env.step(action)
                    ep_len += 1
                logging.debug(
                    'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                    global_step_val, ep_len, time.time() - monitor_start)

            if global_step_val % train_checkpoint_interval == 0:
                sc_checkpointer.save(global_step=global_step_val)
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        eval_env_name=None,
        env_load_fn=suite_mujoco.load,
        num_iterations=2000000,
        actor_fc_layers=(400, 300),
        critic_obs_fc_layers=(400,),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300,),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_parallel_environments=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for checkpoints, summaries, and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG.

    Builds the training environment (parallel when
    ``num_parallel_environments > 1``) and an evaluation environment. It
    creates actor/critic networks and a ``DdpgAgent``, fills the replay buffer
    with ``initial_collect_steps`` steps from the agent's collect policy, then
    alternates step collection with ``train_steps_per_iteration`` training
    updates on batches of 2-step trajectories.

    Args:
      root_dir: Output directory. Train/eval summaries go to ``train`` and
        ``eval`` subdirectories.
      env_name: Training environment name passed to ``env_load_fn``.
      eval_env_name: Evaluation environment name. Defaults to ``env_name``.
      env_load_fn: Callable ``env_load_fn(name)`` returning a py environment.
      num_iterations: Number of collect/train iterations.
      actor_fc_layers, critic_*_fc_layers: Network layer configuration.
      initial_collect_steps: Steps collected before training starts.
      collect_steps_per_iteration: Steps collected per iteration.
      num_parallel_environments: Number of training environments.
      replay_buffer_capacity: Max length of the replay buffer.
      ou_stddev, ou_damping, target_update_tau, target_update_period,
      dqda_clipping, td_errors_loss_fn, gamma, reward_scale_factor,
      gradient_clipping, *_learning_rate: DDPG hyper-parameters.
      train_steps_per_iteration: Training updates per iteration.
      batch_size: Number of trajectories per training batch.
      use_tf_functions: Wrap drivers and training in ``common.function``.
      num_eval_episodes, eval_interval: Evaluation configuration.
      log_interval, summary_interval: Periods in global steps.
      summaries_flush_secs: Flush period of the summary writers.
      debug_summaries, summarize_grads_and_vars: Forwarded to the agent.
      eval_metrics_callback: Optional ``callback(results, global_step)``
        invoked after each evaluation.

    Returns:
      The loss info of the last training step, or ``None`` if no training
      step ran (for example when ``num_iterations`` is 0).
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if num_parallel_environments > 1:
            tf_env = tf_py_environment.TFPyEnvironment(
                parallel_py_environment.ParallelPyEnvironment(
                    [lambda: env_load_fn(env_name)] * num_parallel_environments))
        else:
            tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        actor_net = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers,
        )

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_network.CriticNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        # Warm-up collection uses the agent's collect policy. It only feeds
        # the replay buffer, not the train metrics.
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'the collect policy.', initial_collect_steps)
        initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Defined up front so the final `return` is valid even when no
        # training step runs.
        train_loss = None
        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

        return train_loss
def train_eval(
        root_dir,
        env_load_fn=get_env,
        random_seed=None,
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=10,
        num_parallel_environments=10,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=10,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        policy_save_interval=10000,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Trains an on-policy agent until ``num_environment_steps`` env steps.

    Each iteration collects ``collect_episodes_per_iteration`` episodes from
    ``num_parallel_environments`` parallel environments into a replay buffer.
    It trains once on the whole buffer (``gather_all``) and then clears the
    buffer. The agent, networks and metrics come from the module helpers
    ``get_agent``, ``get_actor_and_value_network`` and ``get_metrics``.

    Args:
      root_dir: Output directory for summaries (``train``/``eval``),
        checkpoints (``train``, ``train/policy``) and saved policies
        (``policy_saved_model``).
      env_load_fn: Zero-argument callable returning a py environment.
      random_seed: Optional TF graph-level random seed.
      num_environment_steps: Stop once this many env steps were collected.
      collect_episodes_per_iteration: Episodes collected per iteration.
      num_parallel_environments: Number of parallel training environments.
      replay_buffer_capacity: Per-environment replay buffer length.
      num_epochs, learning_rate: Forwarded to ``get_agent``.
      num_eval_episodes, eval_interval: Evaluation configuration.
      train_checkpoint_interval, policy_checkpoint_interval,
      policy_save_interval, log_interval, summary_interval: Periods in train
        steps. Checkpoints and saved policies are written only on log steps.
      summaries_flush_secs: Flush period of the summary writers.
      use_tf_functions: Wrap collection and training in ``common.function``.
      debug_summaries, summarize_grads_and_vars: Currently unused here.
    """
    if random_seed is not None:
        # `tf.set_random_seed` does not exist in TF2; the compat.v1 alias works
        # in both TF1 and TF2.
        tf.compat.v1.set_random_seed(random_seed)
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    logging.info('Running %d episodes in parallel', num_parallel_environments)
    logging.info('Collecting %d episodes per step',
                 collect_episodes_per_iteration)
    logging.info('Using replay buffer capacity of %d', replay_buffer_capacity)

    train_summary_writer = tf.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()
    eval_summary_writer = tf.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn())
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn()] * num_parallel_environments))

    actor_net, value_net = get_actor_and_value_network(
        tf_env.action_spec(), tf_env.observation_spec())

    # Single train-step counter shared by the agent, the checkpointers, the
    # policy saver, the summaries and every interval check below.
    train_steps = tf.Variable(0)
    with tf.summary.record_if(
            lambda: tf.math.equal(train_steps % summary_interval, 0)):
        tf_agent = get_agent(time_step_spec=tf_env.time_step_spec(),
                             action_spec=tf_env.action_spec(),
                             actor_net=actor_net,
                             value_net=value_net,
                             num_epochs=num_epochs,
                             step_counter=train_steps,
                             learning_rate=learning_rate)
        tf_agent.initialize()

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        step_metrics, train_metrics, eval_metrics = get_metrics(
            n_parallel_env=num_parallel_environments,
            num_eval_episodes=num_eval_episodes)

        # The stop condition needs an EnvironmentSteps metric that the collect
        # driver actually updates, i.e. one among the observer metrics.
        # NOTE(review): assumes get_metrics builds it with
        # tf_metrics.EnvironmentSteps; otherwise a fresh (non-checkpointed)
        # one is attached below.
        environment_steps_metric = next(
            (m for m in train_metrics
             if isinstance(m, tf_metrics.EnvironmentSteps)), None)
        extra_observers = []
        if environment_steps_metric is None:
            logging.warning('No EnvironmentSteps metric in train_metrics; '
                            'creating an un-checkpointed one.')
            environment_steps_metric = tf_metrics.EnvironmentSteps()
            extra_observers.append(environment_steps_metric)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=train_steps,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=train_steps)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=train_steps)

        train_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics +
            extra_observers,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = train_steps.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = train_steps.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=train_steps,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            # On-policy: discard the data once it has been trained on.
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=train_steps,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=train_steps)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_save_interval == 0:
                    saved_model_path = os.path.join(
                        saved_model_dir,
                        'policy_' + ('%d' % global_step_val).zfill(9))
                    saved_model.save(saved_model_path)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=train_steps,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
def train_eval(
        root_dir,
        gpu=0,
        env_load_fn=None,
        model_ids=None,
        eval_env_mode='headless',
        num_iterations=1000000,
        conv_layer_params=None,
        encoder_fc_layers=[256],
        actor_fc_layers=[400, 300],
        critic_obs_fc_layers=[400],
        critic_action_fc_layers=None,
        critic_joint_fc_layers=[300],
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_parallel_environments=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        eval_only=False,
        eval_deterministic=False,
        num_parallel_environments_eval=1,
        model_ids_eval=None,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=10000,
        rb_checkpoint_interval=50000,
        log_interval=100,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG.

    Builds a training ``TFPyEnvironment`` over a ``ParallelPyEnvironment`` of
    ``num_parallel_environments`` envs and a separate parallel eval env. It
    then builds a DDPG agent whose actor and critic use the same
    ``preprocessing_layers`` dict (keys ``'depth_seg'`` and ``'sensor'``),
    fills a uniform replay buffer, and trains for ``num_iterations`` in a
    ``tf.compat.v1.Session``.

    Args:
        root_dir: Output root; ``train/`` holds train summaries and
            checkpoints, ``eval/`` holds eval summaries and (in eval-only
            mode) ``episodes_vis.pkl``.
        gpu: Forwarded unchanged as the third argument of ``env_load_fn``.
        env_load_fn: Called as ``env_load_fn(model_id, mode, gpu)``. ``mode``
            is ``'headless'`` for training envs and ``eval_env_mode`` for
            eval envs.
        model_ids / model_ids_eval: Optional per-env model ids. Their length
            must equal the corresponding number of parallel envs (checked
            with ``assert``).
        eval_env_mode: Mode for eval envs. ``'gui'`` requires exactly one
            eval env.
        eval_only: If true, restore the train checkpoint, run one
            evaluation, print and pickle the episode metrics, and return
            without training.
        eval_deterministic: Evaluate with a ``GreedyPolicy`` wrapper around
            the agent policy.
        Remaining arguments are forwarded to the networks, the agent, the
        replay buffer, and the checkpoint/summary/log cadence as named.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            py_metrics.AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments_eval),
        batched_py_metric.BatchedPyMetric(
            py_metrics.AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments_eval),
    ]
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Summaries built inside this block are only recorded on steps that are
    # multiples of summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if model_ids is None:
            model_ids = [None] * num_parallel_environments
        else:
            assert len(model_ids) == num_parallel_environments, \
                'model ids provided, but length not equal to num_parallel_environments'

        if model_ids_eval is None:
            model_ids_eval = [None] * num_parallel_environments_eval
        else:
            assert len(model_ids_eval) == num_parallel_environments_eval,\
                'model ids eval provided, but length not equal to num_parallel_environments_eval'

        # model_id is bound as a default argument so each constructor keeps
        # its own id (avoids the late-binding closure pitfall).
        tf_py_env = [
            lambda model_id=model_ids[i]: env_load_fn(model_id, 'headless', gpu)
            for i in range(num_parallel_environments)
        ]
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(tf_py_env))

        if eval_env_mode == 'gui':
            assert num_parallel_environments_eval == 1, 'only one GUI env is allowed'
        eval_py_env = [
            lambda model_id=model_ids_eval[i]: env_load_fn(
                model_id, eval_env_mode, gpu)
            for i in range(num_parallel_environments_eval)
        ]
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            eval_py_env)

        # Get the data specs from the environment
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()
        print('observation_spec', observation_spec)
        print('action_spec', action_spec)

        glorot_uniform_initializer = tf.compat.v1.keras.initializers.glorot_uniform(
        )
        # NOTE(review): the dict keys presumably must match the observation
        # structure ('depth_seg' / 'sensor') -- confirm against env_load_fn.
        preprocessing_layers = {
            'depth_seg': tf.keras.Sequential(
                mlp_layers(
                    conv_layer_params=conv_layer_params,
                    fc_layer_params=encoder_fc_layers,
                    kernel_initializer=glorot_uniform_initializer,
                )),
            'sensor': tf.keras.Sequential(
                mlp_layers(
                    conv_layer_params=None,
                    fc_layer_params=encoder_fc_layers,
                    kernel_initializer=glorot_uniform_initializer,
                )),
        }
        preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

        # NOTE(review): actor and critic receive the same Keras layer objects
        # in preprocessing_layers; whether the encoder weights end up shared
        # depends on how the networks copy them -- confirm this is intended.
        actor_net = actor_network.ActorNetwork(
            observation_spec,
            action_spec,
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            fc_layer_params=actor_fc_layers,
            kernel_initializer=glorot_uniform_initializer,
        )

        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer=glorot_uniform_initializer,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.compat.v1.Session(config=config)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        if eval_deterministic:
            eval_py_policy = py_tf_policy.PyTFPolicy(
                greedy_policy.GreedyPolicy(tf_agent.policy))
        else:
            eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                buffer_size=100,
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=100,
                batch_size=num_parallel_environments),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)

        # Step counts are multiplied by the env count so each env performs
        # the configured number of steps.
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps * num_parallel_environments).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration * num_parallel_environments).run()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            # Drop samples whose first step is an episode boundary.
            return ~trajectories.is_boundary()[0]

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=5,
            sample_batch_size=5 * batch_size,
            num_steps=2).apply(tf.data.experimental.unbatch()).filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()
        train_op = tf_agent.train(trajectories)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics))

        with eval_summary_writer.as_default(), tf.compat.v2.summary.record_if(
                True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step,
                                         step_metrics=step_metrics)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        init_agent_op = tf_agent.initialize()

        # The session is closed on every exit path, including the eval_only
        # early return and exceptions; previously it leaked on both.
        try:
            with sess.as_default():
                # Initialize the graph.
                train_checkpointer.initialize_or_restore(sess)

                if eval_only:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=0,
                        callback=eval_metrics_callback,
                        tf_summaries=False,
                        log=True,
                    )
                    # Flatten the per-env episode lists, keep the first
                    # num_eval_episodes.
                    episodes = eval_py_env.get_stored_episodes()
                    episodes = [
                        episode for sublist in episodes for episode in sublist
                    ][:num_eval_episodes]
                    metrics = episode_utils.get_metrics(episodes)
                    for key in sorted(metrics.keys()):
                        print(key, ':', metrics[key])

                    save_path = os.path.join(eval_dir, 'episodes_vis.pkl')
                    episode_utils.save(episodes, save_path)
                    print('EVAL DONE')
                    return

                # Initialize training.
                rb_checkpointer.initialize_or_restore(sess)
                sess.run(dataset_iterator.initializer)
                common.initialize_uninitialized_variables(sess)
                sess.run(init_agent_op)
                sess.run(train_summary_writer.init())
                sess.run(eval_summary_writer.init())

                global_step_val = sess.run(global_step)

                # Only a fresh run (step 0) evaluates the random policy and
                # seeds the replay buffer; resumed runs skip this.
                if global_step_val == 0:
                    # Initial eval of randomly initialized policy
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=0,
                        callback=eval_metrics_callback,
                        tf_summaries=True,
                        log=True,
                    )
                    # Run initial collect.
                    logging.info('Global step %d: Running initial collect op.',
                                 global_step_val)
                    sess.run(initial_collect_op)

                    # Checkpoint the initial replay buffer contents.
                    rb_checkpointer.save(global_step=global_step_val)

                    logging.info('Finished initial collect.')
                else:
                    logging.info('Global step %d: Skipping initial collect op.',
                                 global_step_val)

                collect_call = sess.make_callable(collect_op)
                train_step_call = sess.make_callable([train_op, summary_ops])
                global_step_call = sess.make_callable(global_step)

                timed_at_step = sess.run(global_step)
                time_acc = 0
                steps_per_second_ph = tf.compat.v1.placeholder(
                    tf.float32, shape=(), name='steps_per_sec_ph')
                steps_per_second_summary = tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_second_ph,
                    step=global_step)

                for _ in range(num_iterations):
                    start_time = time.time()
                    collect_call()
                    for _ in range(train_steps_per_iteration):
                        loss_info_value, _ = train_step_call()
                    time_acc += time.time() - start_time
                    global_step_val = global_step_call()

                    if global_step_val % log_interval == 0:
                        logging.info('step = %d, loss = %f', global_step_val,
                                     loss_info_value.loss)
                        steps_per_sec = (global_step_val - timed_at_step) / time_acc
                        logging.info('%.3f steps/sec', steps_per_sec)
                        sess.run(steps_per_second_summary,
                                 feed_dict={steps_per_second_ph: steps_per_sec})
                        timed_at_step = global_step_val
                        time_acc = 0

                    if global_step_val % train_checkpoint_interval == 0:
                        train_checkpointer.save(global_step=global_step_val)

                    if global_step_val % policy_checkpoint_interval == 0:
                        policy_checkpointer.save(global_step=global_step_val)

                    if global_step_val % rb_checkpoint_interval == 0:
                        rb_checkpointer.save(global_step=global_step_val)

                    if global_step_val % eval_interval == 0:
                        # Pass the real step (previously hard-coded 0) so the
                        # eval callback/log are labelled with it.
                        metric_utils.compute_summaries(
                            eval_metrics,
                            eval_py_env,
                            eval_py_policy,
                            num_episodes=num_eval_episodes,
                            global_step=global_step_val,
                            callback=eval_metrics_callback,
                            tf_summaries=True,
                            log=True,
                        )
                        with eval_summary_writer.as_default(
                        ), tf.compat.v2.summary.record_if(True):
                            with tf.name_scope('Metrics/'):
                                episodes = eval_py_env.get_stored_episodes()
                                episodes = [
                                    episode for sublist in episodes
                                    for episode in sublist
                                ][:num_eval_episodes]
                                metrics = episode_utils.get_metrics(episodes)
                                for key in sorted(metrics.keys()):
                                    print(key, ':', metrics[key])
                                    # NOTE(review): this builds a new graph op
                                    # on every eval, so the graph grows over
                                    # the run.
                                    metric_op = tf.compat.v2.summary.scalar(
                                        name=key,
                                        data=metrics[key],
                                        step=global_step_val)
                                    sess.run(metric_op)
                        sess.run(eval_summary_flush_op)
        finally:
            sess.close()
def _parse_bool_flag(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to bool.

    ``argparse`` with ``type=bool`` calls ``bool(string)``. That makes every
    non-empty string, including ``"False"``, evaluate to ``True``. This
    converter is used for the boolean flags instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling (argparse reports it as a usage error).
    """
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got %r' % (value,))


def main():
    """Parse CLI flags, build a PPO-clip agent with a QTransformer actor, and train it.

    Training envs are ``args.num_parallel`` ``PyhistoryWrapper``-wrapped gym
    envs in a ``ParallelPyEnvironment``; the eval env is a single wrapped
    env. After ``train_eval`` returns, the parsed args are pickled to
    ``<output_dir>/training_args.p``.

    ``--custom_lr_schedule`` accepts ``Transformer``, ``Transformer_low`` or
    ``Linear``. Any other value uses a constant-LR
    ``tf.compat.v1.train.AdamOptimizer``.
    """
    logging.set_verbosity(logging.INFO)
    tf.compat.v1.enable_v2_behavior()
    parser = argparse.ArgumentParser()

    # Essential parameters
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model stats and checkpoints will be written.")
    parser.add_argument("--env", default=None, type=str, required=True,
                        help="The environment to train the agent on")
    parser.add_argument("--max_horizon", default=4, type=int)
    parser.add_argument("--atari", default=False, type=_parse_bool_flag,
                        help="Gets some data Types correctly")

    # Agent parameters
    parser.add_argument("--reward_scale_factor", default=1.0, type=float)
    parser.add_argument("--debug_summaries", default=False, type=_parse_bool_flag)
    parser.add_argument("--summarize_grads_and_vars", default=False, type=_parse_bool_flag)

    # Transformer parameters
    parser.add_argument("--d_model", default=64, type=int)
    parser.add_argument("--num_layers", default=3, type=int)
    parser.add_argument("--dff", default=256, type=int)

    # Training parameters
    parser.add_argument('--num_iterations', type=int, default=100000,
                        help="steps in the env")
    parser.add_argument('--num_parallel', type=int, default=30,
                        help="how many envs should run in parallel")
    parser.add_argument("--collect_episodes_per_iteration", default=1, type=int)
    parser.add_argument('--num_epochs', type=int, default=25,
                        help='Number of epochs for computing policy updates.')

    # Other parameters
    parser.add_argument("--num_eval_episodes", default=10, type=int)
    parser.add_argument("--eval_interval", default=1000, type=int)
    parser.add_argument("--log_interval", default=10, type=int)
    parser.add_argument("--summary_interval", default=1000, type=int)
    parser.add_argument("--run_graph_mode", default=True, type=_parse_bool_flag)
    parser.add_argument("--checkpoint_interval", default=1000, type=int)
    # NOTE(review): forwarded positionally to train_eval right after
    # summary_interval -- presumably the summary flush period in seconds;
    # confirm against train_eval's signature.
    parser.add_argument("--summary_flush", default=10, type=int)

    # HP opt params
    #parser.add_argument("--doubleQ", default=True, type=bool,help="Whether to use a DoubleQ agent")
    parser.add_argument("--custom_last_layer", default=True, type=_parse_bool_flag)
    parser.add_argument("--custom_layer_init", default=1.0, type=float)
    parser.add_argument("--initial_collect_steps", default=5000, type=int)
    #parser.add_argument("--loss_function", default="element_wise_huber_loss", type=str)
    parser.add_argument("--num_heads", default=4, type=int)
    parser.add_argument("--normalize_env", default=False, type=_parse_bool_flag)
    parser.add_argument('--custom_lr_schedule', default="No", type=str,
                        help="whether to use a custom LR schedule")
    #parser.add_argument("--epsilon_greedy", default=0.3, type=float)
    #parser.add_argument("--target_update_period", default=1000, type=int)
    # Dropout rate (may be unused depending on the q network). The original
    # author noted that 0.0 breaks the code; pick a network without dropout
    # instead.
    parser.add_argument("--rate", default=0.1, type=float)
    parser.add_argument("--gradient_clipping", default=True, type=_parse_bool_flag)
    parser.add_argument("--replay_buffer_max_length", default=1001, type=int)
    #parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--learning_rate", default=1e-4, type=float)
    parser.add_argument("--encoder_type", default=3, type=int,
                        help="Which Type of encoder is used for the model")
    parser.add_argument("--layer_type", default=3, type=int,
                        help="Which Type of layer is used for the encoder")
    #parser.add_argument("--target_update_tau", default=1, type=float)
    #parser.add_argument("--gamma", default=0.99, type=float)
    args = parser.parse_args()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Raw gym env, used only to read the discrete action count below.
    base_env = gym.make(args.env)

    eval_tf_env = tf_py_environment.TFPyEnvironment(
        PyhistoryWrapper(suite_gym.load(args.env), args.max_horizon, args.atari))
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: PyhistoryWrapper(suite_gym.load(args.env), args.max_horizon, args.atari)]
            * args.num_parallel))

    # NOTE(review): this actor network is overwritten by the QTransformer
    # below and never used. It is left in place because constructing it may
    # affect Keras layer-name uniquification -- confirm before deleting.
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        fc_layer_params=(200, 100),
        activation_fn=tf.keras.activations.tanh)
    value_net = value_network.ValueNetwork(
        tf_env.observation_spec(),
        fc_layer_params=(200, 100),
        activation_fn=tf.keras.activations.tanh)

    actor_net = QTransformer(
        tf_env.observation_spec(),
        base_env.action_space.n,
        num_layers=args.num_layers,
        d_model=args.d_model,
        num_heads=args.num_heads,
        dff=args.dff,
        rate=args.rate,
        encoderType=args.encoder_type,
        enc_layer_type=args.layer_type,
        max_horizon=args.max_horizon,
        custom_layer=args.custom_layer_init,
        custom_last_layer=args.custom_last_layer)
    # The raw env was only needed for its action space; release it.
    base_env.close()

    if args.custom_lr_schedule == "Transformer":
        # Builds an LR schedule following the original Transformer usage.
        learning_rate = CustomSchedule(args.d_model, int(args.num_iterations / 10))
        optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                             epsilon=1e-9)
    elif args.custom_lr_schedule == "Transformer_low":
        # Same Transformer schedule with a lower overall LR (half d_model).
        learning_rate = CustomSchedule(int(args.d_model / 2), int(args.num_iterations / 10))
        optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                             epsilon=1e-9)
    elif args.custom_lr_schedule == "Linear":
        # Previously referenced an undefined local `learning_rate` (NameError).
        lrs = LinearCustomSchedule(args.learning_rate, args.num_iterations)
        optimizer = tf.keras.optimizers.Adam(lrs, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
    else:
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=args.learning_rate)

    tf_agent = ppo_clip_agent.PPOClipAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        entropy_regularization=0.0,
        importance_ratio_clipping=0.2,
        normalize_observations=False,
        normalize_rewards=False,
        use_gae=True,
        num_epochs=args.num_epochs,
        debug_summaries=args.debug_summaries,
        summarize_grads_and_vars=args.summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # NOTE(review): positional call; argument meanings depend on the
    # train_eval signature this script targets -- confirm before reordering.
    train_eval(
        args.output_dir,
        0,  # NOTE(review): meaning of this positional 0 is unclear.
        # TODO(b/127576522): rename to policy_fc_layers.
        tf_agent,
        eval_tf_env,
        tf_env,
        # Params for collect
        args.num_iterations,
        args.collect_episodes_per_iteration,
        args.num_parallel,
        args.replay_buffer_max_length,  # Per-environment
        # Params for train
        args.num_epochs,
        args.learning_rate,
        # Params for eval
        args.num_eval_episodes,
        args.eval_interval,
        # Params for summaries and logging
        args.checkpoint_interval,
        args.checkpoint_interval,
        args.checkpoint_interval,
        args.log_interval,
        args.summary_interval,
        args.summary_flush,
        args.debug_summaries,
        args.summarize_grads_and_vars,
        args.run_graph_mode,
        None)

    # Close the file deterministically (the original leaked the handle).
    with open(args.output_dir + "/training_args.p", "wb") as args_file:
        pickle.dump(args, args_file)
    print("Successfully trained and evaluation.")
def train_eval(
        root_dir,
        tf_master='',
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=50,
        rb_checkpoint_interval=200,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for PPO.

    Builds parallel train and eval environments from ``env_load_fn(env_name)``
    and a PPO agent (RNN or feed-forward networks, per ``use_rnns``). In a
    ``tf.compat.v1.Session`` it repeatedly collects
    ``collect_episodes_per_iteration`` episodes, trains on the whole replay
    buffer, and clears it. This continues until the ``EnvironmentSteps``
    metric reaches ``num_environment_steps``.

    Evaluation runs whenever the global step is a multiple of
    ``eval_interval``, and once more before returning. Train, policy, and
    replay-buffer checkpoints are written under ``<root_dir>/train``.

    Raises:
        AttributeError: if ``root_dir`` is None.
    """
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
        batched_py_metric.BatchedPyMetric(
            AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
    ]
    eval_summary_writer_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Summaries built in this block are recorded only on steps that are
    # multiples of summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf.compat.v1.set_random_seed(random_seed)
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(), fc_layer_params=value_fc_layers)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        environment_steps_count = environment_steps_metric.result()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Add to replay buffer and other agent specific observers.
        replay_buffer_observer = [replay_buffer.add_batch]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_buffer_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        trajectories = replay_buffer.gather_all()

        train_op, _ = tf_agent.train(experience=trajectories)

        # Clear the buffer only after training consumed it, and make the
        # returned train_op depend on the clear.
        with tf.control_dependencies([train_op]):
            clear_replay_op = replay_buffer.clear()

        with tf.control_dependencies([clear_replay_op]):
            train_op = tf.identity(train_op)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(
                train_step=global_step, step_metrics=step_metrics)

        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                # NOTE(review): unlike the train metrics, no train_step is
                # passed -- confirm eval summaries land at the intended step.
                eval_metric.tf_summaries(step_metrics=step_metrics)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session(tf_master) as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            collect_time = 0
            train_time = 0
            timed_at_step = sess.run(global_step)
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            # NOTE(review): tf.contrib.summary is mixed with tf.compat.v2
            # writers here; tf.contrib does not exist in TF2.
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            # Defined up front so the final eval below has a step even when
            # the loop body never runs (e.g. resuming a run that already
            # reached num_environment_steps); previously UnboundLocalError.
            global_step_val = timed_at_step

            while sess.run(environment_steps_count) < num_environment_steps:
                global_step_val = sess.run(global_step)
                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_writer_flush_op)

                start_time = time.time()
                sess.run(collect_op)
                collect_time += time.time() - start_time
                start_time = time.time()
                total_loss = sess.run(train_op)
                train_time += time.time() - start_time

                global_step_val = sess.run(global_step)
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val, total_loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info(
                        '%s', 'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

            # One final eval before exiting.
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )
            sess.run(eval_summary_writer_flush_op)
def train_eval(
        root_dir,
        env_name='HalfCheetah-v1',
        env_load_fn=suite_mujoco.load,
        num_iterations=2000000,
        actor_fc_layers=(400, 300),
        critic_obs_fc_layers=(400,),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300,),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_parallel_environments=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=tf.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG.

    Uses the legacy ``tf.contrib.summary`` / ``tf.Session`` APIs. The
    training env is a single env, or a ``ParallelPyEnvironment`` when
    ``num_parallel_environments > 1``; evaluation uses one separate py env.

    After an unconditional initial collect of ``initial_collect_steps``
    steps, each of ``num_iterations`` iterations collects
    ``collect_steps_per_iteration`` steps and then runs
    ``train_steps_per_iteration`` train steps on 2-step samples from the
    replay buffer. Logging, checkpointing, and evaluation happen on the
    configured global-step intervals.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # TODO(kbanoop): Figure out if it is possible to avoid the with block.
    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):
        if num_parallel_environments > 1:
            tf_env = tf_py_environment.TFPyEnvironment(
                parallel_py_environment.ParallelPyEnvironment(
                    [lambda: env_load_fn(env_name)] * num_parallel_environments))
        else:
            tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        eval_py_env = env_load_fn(env_name)

        actor_net = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers,
        )

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_network.CriticNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec(),
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        global_step = tf.train.get_or_create_global_step()

        collect_policy = tf_agent.collect_policy()
        # NOTE(review): the initial collect uses the (noisy) collect policy
        # and runs even when resuming from a checkpoint.
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=2).prefetch(3)
        iterator = dataset.make_initializable_iterator()
        trajectories, unused_info = iterator.get_next()
        train_op = tf_agent.train(
            experience=trajectories, train_step_counter=global_step)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=tf.contrib.checkpoint.List(train_metrics))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy(),
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])
        # Captures the summary ops created so far. Do not move this line:
        # ops created after it are not included.
        summary_op = tf.contrib.summary.all_summary_ops()

        with eval_summary_writer.as_default(), \
                tf.contrib.summary.always_record_summaries():
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            tf.contrib.summary.initialize(session=sess)
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_op, global_step])

            timed_at_step = sess.run(global_step)
            time_acc = 0
            steps_per_second_ph = tf.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _, global_step_val = train_step_call()
                time_acc += time.time() - start_time

                if global_step_val % log_interval == 0:
                    tf.logging.info('step = %d, loss = %f', global_step_val,
                                    loss_info_value.loss)
                    steps_per_sec = (global_step_val - timed_at_step) / time_acc
                    # Lazy %-args (was eager '%' formatting of the message).
                    tf.logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(
                        steps_per_second_summary,
                        feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
def main(_):
    """Run train_eval once for every Hanabi game configuration in the sweep.

    The sweep is the Cartesian product of the module-level constants
    COLORS, RANKS, NUM_PLAYERS, HAND_SIZES, MAX_INFORMATION_TOKENS,
    MAX_LIFE_TOKENS, CUSTOM_REWARDS and PENALTIES_LAST_HINT_TOKEN. Iteration
    order is the same as the original nested loops (last constant varies
    fastest). Each configuration gets fresh parallel eval and train
    environments built with load_hanabi_env.

    If TF is executing eagerly, this logs a warning and returns without
    training.
    """
    import itertools  # local import: only this entry point needs it

    tf.compat.v1.enable_resource_variables()
    if tf.executing_eagerly():
        # Originally a test skip: self.skipTest('b/123777119'); secondary
        # bug ('b/123775375'). Keep the early exit, but make it visible.
        logging.warning(
            'TF is executing eagerly; main() returns without training.')
        return

    logging.set_verbosity(logging.INFO)

    # todo: when this training is done, try different learning rates and architectures
    # Original author's estimate: 2 * 1 * 1 * 4 * 4 * 2 = 64 total iterations
    # (the real count depends on the sweep constants).
    sweep = itertools.product(COLORS, RANKS, NUM_PLAYERS, HAND_SIZES,
                              MAX_INFORMATION_TOKENS, MAX_LIFE_TOKENS,
                              CUSTOM_REWARDS, PENALTIES_LAST_HINT_TOKEN)
    # loop over game params to create different configs
    for (colors, ranks, num_players, hand_size, max_information_tokens,
         max_life_tokens, custom_reward, penalty) in sweep:
        config = {
            "colors": colors,
            "ranks": ranks,
            "players": num_players,
            "hand_size": hand_size,
            "max_information_tokens": max_information_tokens,
            "max_life_tokens": max_life_tokens,
            "observation_type": OBSERVATION_TYPE,
            "custom_reward": custom_reward,
            "penalty_last_hint_token": penalty,
            "per_card_reward": True
        }

        # ################################################ #
        # --------------- Load Environments -------------- #
        # ################################################ #
        # `config` is bound as a default argument so every constructor uses
        # this iteration's dict, even if it is invoked after `config` has
        # been rebound (late-binding closure pitfall).
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda config=config: load_hanabi_env(config)]
            * FLAGS.num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda config=config: load_hanabi_env(config)]
                * FLAGS.num_parallel_environments))

        train_eval(
            root_dir=FLAGS.root_dir,
            summary_dir=FLAGS.summary_dir,
            game_config=config,
            tf_master=FLAGS.master,
            replay_buffer_capacity=FLAGS.replay_buffer_capacity,
            env_load_fn=load_hanabi_env,
            num_environment_steps=FLAGS.num_environment_steps,
            num_parallel_environments=FLAGS.num_parallel_environments,
            num_epochs=FLAGS.num_epochs,
            collect_episodes_per_iteration=FLAGS.collect_episodes_per_iteration,
            num_eval_episodes=FLAGS.num_eval_episodes,
            use_rnns=FLAGS.use_rnns,
            eval_py_env=eval_py_env,
            tf_env=tf_env)

        # NOTE(review): `del` only drops the references. Whether the env
        # worker processes are shut down depends on train_eval / the env
        # classes -- consider an explicit close() if workers accumulate
        # across the sweep.
        del eval_py_env
        del tf_env