def test_parallel_envs(self):
    """Batched parallel envs expose the right batch size and can be driven.

    Builds 5 unwrapped CartPole envs inside a ParallelPyEnvironment (which
    handles process isolation itself), drives them with a random policy for
    100 steps, and checks the replay buffer received data.
    """
    num_envs = 5
    constructors = [
        lambda: suite_socialbot.load(
            'SocialBot-CartPole-v0', wrap_with_process=False)
    ] * num_envs
    self._env = parallel_py_environment.ParallelPyEnvironment(
        env_constructors=constructors, start_serially=False)
    tf_env = tf_py_environment.TFPyEnvironment(self._env)

    # The parallel wrapper must report itself as a batched env of num_envs.
    self.assertTrue(tf_env.batched)
    self.assertEqual(tf_env.batch_size, num_envs)

    policy = random_tf_policy.RandomTFPolicy(tf_env.time_step_spec(),
                                             tf_env.action_spec())
    capacity = 100
    buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        policy.trajectory_spec,
        batch_size=tf_env.batch_size,
        max_length=capacity)

    num_steps = 100
    driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        policy,
        observers=[buffer.add_batch],
        num_steps=num_steps)
    driver.run = common.function(driver.run)
    driver.run()

    # Collection must have written at least one sample into the buffer.
    self.assertIsNotNone(buffer.get_next())
def test_thread_env(self):
    """ThreadEnvironment preserves specs and steps without error.

    Wraps an unwrapped CartPole env in a ThreadEnvironment, verifies the
    observation/action specs, and steps a sampled action ten times.
    """
    self._env = thread_environment.ThreadEnvironment(
        lambda: suite_socialbot.load(
            environment_name='SocialBot-CartPole-v0',
            wrap_with_process=False))

    obs_spec = self._env.observation_spec()
    act_spec = self._env.action_spec()
    self.assertEqual(torch.float32, obs_spec.dtype)
    self.assertEqual((4, ), obs_spec.shape)
    self.assertEqual(torch.float32, act_spec.dtype)
    self.assertEqual((1, ), act_spec.shape)

    # Step the same sampled action repeatedly; we only care that it runs.
    action = act_spec.sample()
    for _ in range(10):
        time_step = self._env.step(action)
def create_environment(env_name='CartPole-v0',
                       env_load_fn=suite_gym.load,
                       num_parallel_environments=30):
    """Create a (possibly batched) TF environment.

    Args:
        env_name (str): env name
        env_load_fn (Callable): callable that creates an environment
        num_parallel_environments (int): num of parallel environments;
            when 1, the env is loaded directly without a parallel wrapper

    Returns:
        tf_py_environment.TFPyEnvironment: the created environment, batched
            when ``num_parallel_environments > 1``.
    """
    if num_parallel_environments == 1:
        py_env = env_load_fn(env_name)
    else:
        # Compare function identity with `is`, not `==`: we want to detect
        # exactly the suite_socialbot loader, not anything comparing equal.
        if env_load_fn is suite_socialbot.load:
            logging.info("suite_socialbot environment")
            # No need to wrap with process since ParallelPyEnvironment will
            # do it. The lambda parameter is renamed to avoid shadowing the
            # outer `env_name` argument.
            env_load_fn = lambda name: suite_socialbot.load(
                name, wrap_with_process=False)
        py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)
    return tf_py_environment.TFPyEnvironment(py_env)
def ctor(env_name, env_id=None):
    """Construct a socialbot env without its own process wrapper.

    Args:
        env_name (str): name of the environment to load.
        env_id: accepted for interface compatibility; not used here.
    """
    del env_id  # Unused; present only to satisfy the constructor signature.
    return suite_socialbot.load(
        environment_name=env_name, wrap_with_process=False)
def test_action_spec(self):
    """A process-wrapped CartPole env has a single float32 action."""
    self._env = suite_socialbot.load(
        'SocialBot-CartPole-v0', wrap_with_process=True)
    action_spec = self._env.action_spec()
    self.assertEqual(torch.float32, action_spec.dtype)
    self.assertEqual((1, ), action_spec.shape)
def test_socialbot_env_registered(self):
    """Loading a registered socialbot env yields an AlfEnvironment."""
    loaded_env = suite_socialbot.load(
        'SocialBot-CartPole-v0', wrap_with_process=True)
    self._env = loaded_env
    self.assertIsInstance(loaded_env, alf_environment.AlfEnvironment)
def train_eval(
        root_dir,
        env_name='SocialBot-ICubWalkPID-v0',
        num_iterations=10000000,
        actor_fc_layers=(256, 128),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 128),
        # Params for collect
        initial_collect_steps=2000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        num_parallel_environments=12,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=5e-4,
        critic_learning_rate=5e-4,
        alpha_learning_rate=5e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC.

    Builds a batched training env (`num_parallel_environments` copies of
    `env_name`), a single eval env, a SAC agent, and a uniform replay
    buffer; then alternates collection (via DynamicStepDriver) with
    gradient steps, periodically logging, evaluating, and checkpointing
    under `root_dir`/train and `root_dir`/eval.

    Returns:
        The last `train_loss` produced by `tf_agent.train`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    # Train summaries go to the default writer; eval summaries use a
    # separate writer passed explicitly to metric_utils.eager_compute.
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Only record summaries every `summary_interval` global steps.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # ParallelPyEnvironment provides the process isolation, so each
        # child env is loaded with wrap_with_process=False.
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: suite_socialbot.load(env_name,wrap_with_process=False)]
                * num_parallel_environments))
        # The eval env is a single (non-parallel) copy.
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_socialbot.load(env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        # NOTE(review): `normal_projection_net` is defined elsewhere in this
        # file/module — presumably a Normal-distribution projection head.
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        # Evaluate greedily; collect with the agent's stochastic policy.
        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        # Restore any previous state before building the drivers.
        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            # Compile the hot paths into graphs for speed.
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        initial_collect_driver.run()

        # One evaluation pass before training starts.
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3, sample_batch_size=batch_size,
            num_steps=2).prefetch(3)
        iterator = iter(dataset)

        for _ in range(num_iterations):
            start_time = time.time()
            # Collect, carrying env/policy state across iterations.
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                experience, _ = next(iterator)
                train_loss = tf_agent.train(experience)
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (
                    global_step.numpy() - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_sec,
                    step=global_step)
                # Reset the throughput window.
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                # The first two metrics (episodes, env steps) serve as the
                # x-axes for the others.
                train_metric.tf_summaries(
                    train_step=global_step, step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
def test_action_spec(self):
    """The CartPole action spec is a single float32 value."""
    cartpole = suite_socialbot.load('SocialBot-CartPole-v0')
    spec = cartpole.action_spec()
    self.assertEqual(np.float32, spec.dtype)
    self.assertEqual((1, ), spec.shape)
def test_socialbot_env_registered(self):
    """suite_socialbot.load returns a PyEnvironment for a registered name."""
    loaded = suite_socialbot.load('SocialBot-CartPole-v0')
    self.assertIsInstance(loaded, py_environment.PyEnvironment)