def testRunOnce(self, max_steps, max_episodes, expected_steps):
  env = driver_test_utils.PyEnvironmentMock()
  tf_env = tf_py_environment.TFPyEnvironment(env)
  policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                          tf_env.action_spec())
  replay_buffer_observer = MockReplayBufferObserver()
  transition_replay_buffer_observer = MockReplayBufferObserver()
  driver = tf_driver.TFDriver(
      tf_env,
      policy,
      observers=[replay_buffer_observer],
      transition_observers=[transition_replay_buffer_observer],
      max_steps=max_steps,
      max_episodes=max_episodes)
  initial_time_step = tf_env.reset()
  initial_policy_state = policy.get_initial_state(batch_size=1)
  self.evaluate(driver.run(initial_time_step, initial_policy_state))
  trajectories = replay_buffer_observer.gather_all()
  self.assertEqual(trajectories, self._trajectories[:expected_steps])
  transitions = transition_replay_buffer_observer.gather_all()
  self.assertLen(transitions, expected_steps)
  # Each transition is a (TimeStep, Action, NextTimeStep) triple.
  self.assertLen(transitions[0], 3)
def testMultipleRunMaxEpisodes(self):
  num_episodes = 2
  num_expected_steps = 6
  env = driver_test_utils.PyEnvironmentMock()
  tf_env = tf_py_environment.TFPyEnvironment(env)
  policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                          tf_env.action_spec())
  replay_buffer_observer = MockReplayBufferObserver()
  driver = tf_driver.TFDriver(
      tf_env,
      policy,
      observers=[replay_buffer_observer],
      max_steps=None,
      max_episodes=1,
  )
  time_step = tf_env.reset()
  policy_state = policy.get_initial_state(batch_size=1)
  for _ in range(num_episodes):
    time_step, policy_state = self.evaluate(
        driver.run(time_step, policy_state))
  trajectories = replay_buffer_observer.gather_all()
  self.assertEqual(trajectories, self._trajectories[:num_expected_steps])
def test_with_tf_driver(self):
  env = driver_test_utils.PyEnvironmentMock()
  tf_env = tf_py_environment.TFPyEnvironment(env)
  policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                          tf_env.action_spec())
  trajectory_spec = trajectory.from_transition(tf_env.time_step_spec(),
                                               policy.policy_step_spec,
                                               tf_env.time_step_spec())
  tfrecord_observer = example_encoding_dataset.TFRecordObserver(
      self.dataset_path, trajectory_spec)
  driver = tf_driver.TFDriver(
      tf_env, policy, [tfrecord_observer], max_steps=10)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  time_step = self.evaluate(tf_env.reset())
  initial_policy_state = policy.get_initial_state(batch_size=1)
  self.evaluate(
      common.function(driver.run)(time_step, initial_policy_state))
  tfrecord_observer.flush()
  tfrecord_observer.close()
  dataset = example_encoding_dataset.load_tfrecord_dataset(
      [self.dataset_path], buffer_size=2, as_trajectories=True)
  iterator = eager_utils.dataset_iterator(dataset)
  sample = self.evaluate(eager_utils.get_next(iterator))
  self.assertIsInstance(sample, trajectory.Trajectory)
def testBatchedEnvironment(self, max_steps, max_episodes, expected_length):
  expected_trajectories = [
      trajectory.Trajectory(
          step_type=np.array([0, 0]),
          observation=np.array([0, 0]),
          action=np.array([2, 1]),
          policy_info=np.array([4, 2]),
          next_step_type=np.array([1, 1]),
          reward=np.array([1., 1.]),
          discount=np.array([1., 1.])),
      trajectory.Trajectory(
          step_type=np.array([1, 1]),
          observation=np.array([2, 1]),
          action=np.array([1, 2]),
          policy_info=np.array([2, 4]),
          next_step_type=np.array([2, 1]),
          reward=np.array([1., 1.]),
          discount=np.array([0., 1.])),
      trajectory.Trajectory(
          step_type=np.array([2, 1]),
          observation=np.array([3, 3]),
          action=np.array([2, 1]),
          policy_info=np.array([4, 2]),
          next_step_type=np.array([0, 2]),
          reward=np.array([0., 1.]),
          discount=np.array([1., 0.]))
  ]

  env1 = driver_test_utils.PyEnvironmentMock(final_state=3)
  env2 = driver_test_utils.PyEnvironmentMock(final_state=4)
  env = batched_py_environment.BatchedPyEnvironment([env1, env2])
  tf_env = tf_py_environment.TFPyEnvironment(env)

  policy = driver_test_utils.TFPolicyMock(
      tf_env.time_step_spec(),
      tf_env.action_spec(),
      batch_size=2,
      initial_policy_state=tf.constant([1, 2], dtype=tf.int32))

  replay_buffer_observer = MockReplayBufferObserver()
  driver = tf_driver.TFDriver(
      tf_env,
      policy,
      observers=[replay_buffer_observer],
      max_steps=max_steps,
      max_episodes=max_episodes,
  )
  initial_time_step = tf_env.reset()
  initial_policy_state = tf.constant([1, 2], dtype=tf.int32)
  self.evaluate(driver.run(initial_time_step, initial_policy_state))
  trajectories = replay_buffer_observer.gather_all()
  self.assertEqual(
      len(trajectories), len(expected_trajectories[:expected_length]))
  for t1, t2 in zip(trajectories, expected_trajectories[:expected_length]):
    for t1_field, t2_field in zip(t1, t2):
      self.assertAllEqual(t1_field, t2_field)
def testValueErrorOnInvalidArgs(self, max_steps, max_episodes):
  env = driver_test_utils.PyEnvironmentMock()
  tf_env = tf_py_environment.TFPyEnvironment(env)
  policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                          tf_env.action_spec())
  replay_buffer_observer = MockReplayBufferObserver()
  with self.assertRaises(ValueError):
    tf_driver.TFDriver(
        tf_env,
        policy,
        observers=[replay_buffer_observer],
        max_steps=max_steps,
        max_episodes=max_episodes,
    )
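
# Illustrative (non-test) TFDriver usage mirroring what the mocks above
# exercise. This is only a sketch: it assumes a Gym environment and a random
# policy from tf_agents, and `my_observer` is a hypothetical callable that
# receives each batched trajectory.
#
#   env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
#   policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(),
#                                            env.action_spec())
#   driver = tf_driver.TFDriver(
#       env, policy, observers=[my_observer], max_steps=100)
#   time_step = env.reset()
#   driver.run(time_step, policy.get_initial_state(env.batch_size))
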
def eager_compute(metrics,
                  environment,
                  policy,
                  num_episodes=1,
                  train_step=None,
                  summary_writer=None,
                  summary_prefix='',
                  use_function=True,
                  use_attention_networks=False):
  """Compute metrics using `policy` on the `environment`.

  *NOTE*: Because placeholders are not compatible with Eager mode, we cannot
  use python policies. Because we use tf_policies, we need the environment
  time_steps to be tensors, which makes it easier to use a tf_env for
  evaluations. Otherwise this method mirrors `compute` directly.

  Args:
    metrics: List of metrics to compute.
    environment: tf_environment instance.
    policy: tf_policy instance used to step the environment.
    num_episodes: Number of episodes to compute the metrics over.
    train_step: An optional step to write summaries against.
    summary_writer: An optional writer for generating metric summaries.
    summary_prefix: An optional prefix scope for metric summaries.
    use_function: Option to enable use of `tf.function` when collecting the
      metrics.
    use_attention_networks: Option to use attention network architecture in
      the agent. These policies carry state through the observation, so they
      are stepped with a state-aware driver (mirroring `train_eval`).

  Returns:
    A dictionary of results {metric_name: metric_value}.
  """
  for metric in metrics:
    metric.reset()
  multiagent_metrics = [m for m in metrics if 'Multiagent' in m.name]

  if use_attention_networks:
    driver = drivers.StateTFDriver(
        environment,
        policy,
        observers=metrics,
        max_episodes=num_episodes,
        disable_tf_function=not use_function)
  else:
    driver = tf_driver.TFDriver(
        environment,
        policy,
        observers=metrics,
        max_episodes=num_episodes,
        disable_tf_function=not use_function)

  def run_driver():
    time_step = environment.reset()
    policy_state = policy.get_initial_state(environment.batch_size)
    if use_attention_networks:
      # Attention networks need the previous policy state to compute attention
      # weights, so it is routed through the observation.
      time_step.observation['policy_state'] = (
          policy_state['actor_network_state'][0],
          policy_state['actor_network_state'][1])
    driver.run(time_step, policy_state)

  if use_function:
    common.function(run_driver)()
  else:
    run_driver()

  results = [(metric.name, metric.result()) for metric in metrics]
  for m in multiagent_metrics:
    for a in range(m.n_agents):
      results.append((m.name + '_agent' + str(a), m.result_for_agent(a)))

  # TODO(b/120301678): Remove the summaries and merge with compute.
  if train_step and summary_writer:
    with summary_writer.as_default():
      for m in metrics:
        tag = common.join_scope(summary_prefix, m.name)
        tf.compat.v2.summary.scalar(name=tag, data=m.result(), step=train_step)
        if 'Multiagent' in m.name:
          for a in range(m.n_agents):
            tf.compat.v2.summary.scalar(
                name=tag + '_agent' + str(a),
                data=m.result_for_agent(a),
                step=train_step)
  # TODO(b/130249101): Add an option to log metrics.
  return collections.OrderedDict(results)
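
# Illustrative usage sketch for `eager_compute` (assumptions: `eval_env` is a
# TFPyEnvironment, `agent` is a trained multiagent tf_agent built elsewhere,
# and the episode count is arbitrary):
#
#   eval_metrics = [
#       multiagent_metrics.AverageReturnMetric(n_agents, buffer_size=10),
#   ]
#   results = eager_compute(
#       eval_metrics, eval_env, agent.policy, num_episodes=10,
#       summary_prefix='Metrics')
#   for name, value in results.items():
#     logging.info('%s = %s', name, value)
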
def train_eval(
    root_dir,
    env_name='MultiGrid-Empty-5x5-v0',
    env_load_fn=multiagent_gym_suite.load,
    random_seed=0,
    # Architecture params
    agent_class=multiagent_ppo.MultiagentPPO,
    actor_fc_layers=(64, 64),
    value_fc_layers=(64, 64),
    lstm_size=(64,),
    conv_filters=64,
    conv_kernel=3,
    direction_fc=5,
    entropy_regularization=0.,
    use_attention_networks=False,
    # Specialized agents
    inactive_agent_ids=tuple(),
    # Params for collect
    num_environment_steps=25000000,
    collect_episodes_per_iteration=30,
    num_parallel_environments=5,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=2,
    learning_rate=1e-4,
    # Params for eval
    num_eval_episodes=2,
    eval_interval=5,
    # Params for summaries and logging
    train_checkpoint_interval=100,
    policy_checkpoint_interval=100,
    log_interval=10,
    summary_interval=10,
    summaries_flush_secs=1,
    use_tf_functions=True,
    debug_summaries=True,
    summarize_grads_and_vars=True,
    eval_metrics_callback=None,
    reinit_checkpoint_dir=None,
    debug=True):
  """A simple train and eval for PPO."""
  tf.compat.v1.enable_v2_behavior()

  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  if debug:
    logging.info('In debug mode, turning tf_functions off')
    use_tf_functions = False

  for a in inactive_agent_ids:
    logging.info('Fixing and not training agent %d', a)

  # Load multiagent gym environment and determine number of agents.
  gym_env = env_load_fn(env_name)
  n_agents = gym_env.n_agents

  # Set up logging.
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      multiagent_metrics.AverageReturnMetric(
          n_agents, buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)

    logging.info('Creating %d environments...', num_parallel_environments)
    wrappers = []
    if use_attention_networks:
      wrappers = [
          lambda env: utils.LSTMStateWrapper(env, lstm_size=lstm_size)
      ]

    eval_tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(
            env_name,
            gym_kwargs=dict(seed=random_seed),
            gym_env_wrappers=wrappers))
    # pylint: disable=g-complex-comprehension
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment([
            functools.partial(env_load_fn, environment_name=env_name,
                              gym_env_wrappers=wrappers,
                              gym_kwargs=dict(seed=random_seed * 1234 + i))
            for i in range(num_parallel_environments)
        ]))

    logging.info('Preparing to train...')
    environment_steps_metric = tf_metrics.EnvironmentSteps()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]
    bonus_metrics = [
        multiagent_metrics.MultiagentScalar(
            n_agents, name='UnscaledMultiagentBonus', buffer_size=1000),
    ]
    train_metrics = step_metrics + [
        multiagent_metrics.AverageReturnMetric(
            n_agents, batch_size=num_parallel_environments),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=num_parallel_environments),
    ]

    logging.info('Creating agent...')
    tf_agent = agent_class(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        n_agents=n_agents,
        learning_rate=learning_rate,
        actor_fc_layers=actor_fc_layers,
        value_fc_layers=value_fc_layers,
        lstm_size=lstm_size,
        conv_filters=conv_filters,
        conv_kernel=conv_kernel,
        direction_fc=direction_fc,
        entropy_regularization=entropy_regularization,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step,
        inactive_agent_ids=inactive_agent_ids)
    tf_agent.initialize()
    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    logging.info('Allocating replay buffer ...')
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)
    logging.info('RB capacity: %i', replay_buffer.capacity)

    # If reinit_checkpoint_dir is provided, the first n-1 agents are restored
    # from that checkpoint and frozen (non-learning), while the last agent is
    # freshly initialized as the sole learner. Otherwise, all agents are
    # restored from train_dir.
    if reinit_checkpoint_dir:
      reinit_checkpointer = common.Checkpointer(
          ckpt_dir=reinit_checkpoint_dir,
          agent=tf_agent,
      )
      reinit_checkpointer.initialize_or_restore()
      temp_dir = os.path.join(train_dir, 'tmp')
      agent_checkpointer = common.Checkpointer(
          ckpt_dir=temp_dir,
          agent=tf_agent.agents[:-1],
      )
      agent_checkpointer.save(global_step=0)

      tf_agent = agent_class(
          tf_env.time_step_spec(),
          tf_env.action_spec(),
          n_agents=n_agents,
          learning_rate=learning_rate,
          actor_fc_layers=actor_fc_layers,
          value_fc_layers=value_fc_layers,
          lstm_size=lstm_size,
          conv_filters=conv_filters,
          conv_kernel=conv_kernel,
          direction_fc=direction_fc,
          entropy_regularization=entropy_regularization,
          num_epochs=num_epochs,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
          inactive_agent_ids=inactive_agent_ids,
          non_learning_agents=list(range(n_agents - 1)))
      agent_checkpointer = common.Checkpointer(
          ckpt_dir=temp_dir, agent=tf_agent.agents[:-1])
      agent_checkpointer.initialize_or_restore()
      tf.io.gfile.rmtree(temp_dir)
      eval_policy = tf_agent.policy
      collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=multiagent_metrics.MultiagentMetricsGroup(
            train_metrics + bonus_metrics, 'train_metrics'))
    if not reinit_checkpoint_dir:
      train_checkpointer.initialize_or_restore()
    logging.info('Successfully initialized train checkpointer')

    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    saved_model = policy_saver.PolicySaver(eval_policy, train_step=global_step)

    # Keep the collect policy checkpoint in its own directory so it does not
    # clobber the eval policy checkpoint above.
    collect_policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'collect_policy'),
        policy=collect_policy,
        global_step=global_step)
    collect_saved_model = policy_saver.PolicySaver(
        collect_policy, train_step=global_step)
    logging.info('Successfully initialized policy saver.')

    logging.info('Using TFDriver')
    if use_attention_networks:
      collect_driver = drivers.StateTFDriver(
          tf_env,
          collect_policy,
          observers=[replay_buffer.add_batch] + train_metrics,
          max_episodes=collect_episodes_per_iteration,
          disable_tf_function=not use_tf_functions)
    else:
      collect_driver = tf_driver.TFDriver(
          tf_env,
          collect_policy,
          observers=[replay_buffer.add_batch] + train_metrics,
          max_episodes=collect_episodes_per_iteration,
          disable_tf_function=not use_tf_functions)

    def train_step():
      trajectories = replay_buffer.gather_all()
      return tf_agent.train(experience=trajectories)

    if use_tf_functions:
      tf_agent.train = common.function(tf_agent.train, autograph=False)
      train_step = common.function(train_step)

    collect_time = 0
    train_time = 0
    timed_at_step = global_step.numpy()

    # How many consecutive steps the loss has been diverged for.
    loss_divergence_counter = 0

    # Save operative config as late as possible to include used configurables.
    if global_step.numpy() == 0:
      config_filename = os.path.join(
          train_dir, 'operative_config-{}.gin'.format(global_step.numpy()))
      with tf.io.gfile.GFile(config_filename, 'wb') as f:
        f.write(gin.operative_config_str())

    total_episodes = 0
    logging.info('Commencing train loop!')
    while environment_steps_metric.result() < num_environment_steps:
      global_step_val = global_step.numpy()

      # Evaluation.
      if global_step_val % eval_interval == 0:
        if debug:
          logging.info('Performing evaluation at step %d', global_step_val)
        results = multiagent_metrics.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
            use_function=use_tf_functions,
            use_attention_networks=use_attention_networks)
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        multiagent_metrics.log_metrics(eval_metrics)

      # Collect data.
      if debug:
        logging.info('Collecting at step %d', global_step_val)
      start_time = time.time()
      time_step = tf_env.reset()
      policy_state = collect_policy.get_initial_state(tf_env.batch_size)
      if use_attention_networks:
        # Attention networks require previous policy state to compute attention
        # weights.
        time_step.observation['policy_state'] = (
            policy_state['actor_network_state'][0],
            policy_state['actor_network_state'][1])
      collect_driver.run(time_step, policy_state)
      collect_time += time.time() - start_time

      total_episodes += collect_episodes_per_iteration
      if debug:
        logging.info('Have collected a total of %d episodes', total_episodes)

      # Train.
      if debug:
        logging.info('Training at step %d', global_step_val)
      start_time = time.time()
      total_loss, extra_loss = train_step()
      replay_buffer.clear()
      train_time += time.time() - start_time

      # Check for exploding losses.
      if (math.isnan(total_loss) or math.isinf(total_loss) or
          total_loss > MAX_LOSS):
        loss_divergence_counter += 1
        if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
          logging.info('Loss diverged for too many timesteps, breaking...')
          break
      else:
        loss_divergence_counter = 0

      for train_metric in train_metrics + bonus_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=step_metrics)

      if global_step_val % log_interval == 0:
        logging.info('step = %d, total loss = %f', global_step_val, total_loss)
        for a in range(n_agents):
          if not inactive_agent_ids or a not in inactive_agent_ids:
            logging.info('Loss for agent %d = %f', a, extra_loss[a].loss)
        steps_per_sec = ((global_step_val - timed_at_step) /
                         (collect_time + train_time))
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info('collect_time = %.3f, train_time = %.3f', collect_time,
                     train_time)
        with tf.compat.v2.summary.record_if(True):
          tf.compat.v2.summary.scalar(
              name='global_steps_per_sec', data=steps_per_sec,
              step=global_step)

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)
        saved_model_path = os.path.join(
            saved_model_dir, 'policy_' + ('%d' % global_step_val).zfill(9))
        saved_model.save(saved_model_path)
        collect_policy_checkpointer.save(global_step=global_step_val)
        collect_saved_model_path = os.path.join(
            saved_model_dir,
            'collect_policy_' + ('%d' % global_step_val).zfill(9))
        collect_saved_model.save(collect_saved_model_path)

    # One final eval before exiting.
    results = multiagent_metrics.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
        use_function=use_tf_functions,
        use_attention_networks=use_attention_networks)
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    multiagent_metrics.log_metrics(eval_metrics)
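
# Illustrative invocation sketch for `train_eval` (the root directory, step
# budget, and parallelism below are arbitrary; every keyword argument comes
# from the signature above):
#
#   train_eval(
#       root_dir='/tmp/multigrid_ppo',
#       env_name='MultiGrid-Empty-5x5-v0',
#       num_environment_steps=100000,
#       num_parallel_environments=4,
#       debug=False)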