def test_with_normal_context_and_normal_reward(self):

  def _context_sampling_fn():
    return np.random.normal(0, 3, [1, 2])

  def _reward_fn(x):
    return np.random.normal(2 * x[0], abs(x[1]) + 1)

  env = sspe.StationaryStochasticPyEnvironment(_context_sampling_fn,
                                               [_reward_fn])
  time_step_spec = env.time_step_spec()
  action_spec = env.action_spec()

  random_policy = random_py_policy.RandomPyPolicy(
      time_step_spec=time_step_spec, action_spec=action_spec)

  for _ in range(5):
    time_step = env.reset()
    self.assertTrue(
        check_unbatched_time_step_spec(
            time_step=time_step,
            time_step_spec=time_step_spec,
            batch_size=env.batch_size))
    action = random_policy.action(time_step).action
    time_step = env.step(action)
def testMetricIsComputedCorrectly(self):

  def reward_fn(*unused_args):
    reward = np.random.uniform()
    reward_fn.total_reward += reward
    return reward

  reward_fn.total_reward = 0

  action_spec = array_spec.BoundedArraySpec((1,), np.int32, -10, 10)
  observation_spec = array_spec.BoundedArraySpec((1,), np.int32, -10, 10)
  env = random_py_environment.RandomPyEnvironment(
      observation_spec, action_spec, reward_fn=reward_fn)

  policy = random_py_policy.RandomPyPolicy(
      time_step_spec=None, action_spec=action_spec)

  average_return = py_metrics.AverageReturnMetric()

  num_episodes = 10
  results = metric_utils.compute([average_return], env, policy, num_episodes)
  self.assertAlmostEqual(
      reward_fn.total_reward / num_episodes,
      results[average_return.name],
      places=5)
def test_with_uniform_context_and_normal_mu_reward(self):

  def _context_sampling_fn():
    return np.random.randint(-10, 10, [1, 4])

  reward_fns = [
      LinearNormalReward(theta)
      for theta in ([0, 1, 2, 3], [3, 2, 1, 0], [-1, -2, -3, -4])
  ]

  env = sspe.StationaryStochasticPyEnvironment(_context_sampling_fn,
                                               reward_fns)
  time_step_spec = env.time_step_spec()
  action_spec = env.action_spec()

  random_policy = random_py_policy.RandomPyPolicy(
      time_step_spec=time_step_spec, action_spec=action_spec)

  for _ in range(5):
    time_step = env.reset()
    self.assertTrue(
        check_unbatched_time_step_spec(
            time_step=time_step,
            time_step_spec=time_step_spec,
            batch_size=env.batch_size))
    action = random_policy.action(time_step).action
    time_step = env.step(action)
def validate_py_environment(
    environment: py_environment.PyEnvironment,
    episodes: int = 5,
    observation_and_action_constraint_splitter: Optional[
        types.Splitter] = None):
  """Validates the environment follows the defined specs."""
  time_step_spec = environment.time_step_spec()
  action_spec = environment.action_spec()

  random_policy = random_py_policy.RandomPyPolicy(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      observation_and_action_constraint_splitter=(
          observation_and_action_constraint_splitter))

  if environment.batch_size is not None:
    batched_time_step_spec = array_spec.add_outer_dims_nest(
        time_step_spec, outer_dims=(environment.batch_size,))
  else:
    batched_time_step_spec = time_step_spec

  episode_count = 0
  time_step = environment.reset()
  while episode_count < episodes:
    if not array_spec.check_arrays_nest(time_step, batched_time_step_spec):
      raise ValueError('Given `time_step`: %r does not match expected '
                       '`time_step_spec`: %r' %
                       (time_step, batched_time_step_spec))

    action = random_policy.action(time_step).action
    time_step = environment.step(action)

    episode_count += np.sum(time_step.is_last())
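For environments whose observation is a dict that carries an action mask, a splitter function can be passed through `observation_and_action_constraint_splitter`. A minimal sketch follows, assuming hypothetical 'observation' and 'mask' keys and a hypothetical `masked_env`; it is not taken from the snippets above.

def example_splitter(obs):
  # Return the (observation, mask) pair that RandomPyPolicy expects from the
  # splitter; the dict keys here are illustrative assumptions.
  return obs['observation'], obs['mask']

# validate_py_environment(
#     masked_env, episodes=5,
#     observation_and_action_constraint_splitter=example_splitter)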
def collect(summary_dir: Text,
            environment_name: Text,
            collect_policy: py_tf_eager_policy.PyTFEagerPolicyBase,
            replay_buffer_server_address: Text,
            variable_container_server_address: Text,
            suite_load_fn: Callable[
                [Text], py_environment.PyEnvironment] = suite_mujoco.load,
            initial_collect_steps: int = 10000,
            max_train_steps: int = 2000000) -> None:
  """Collects experience using a policy updated after every episode."""
  # Create the environment. For now support only single environment collection.
  collect_env = suite_load_fn(environment_name)

  # Create the variable container.
  train_step = train_utils.create_train_step()
  variables = {
      reverb_variable_container.POLICY_KEY: collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.update(variables)

  # Create the replay buffer observer.
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb.Client(replay_buffer_server_address),
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      sequence_length=2,
      stride_length=1)

  random_policy = random_py_policy.RandomPyPolicy(collect_env.time_step_spec(),
                                                  collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  env_step_metric = py_metrics.EnvironmentSteps()
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=summary_dir,
      observers=[rb_observer, env_step_metric])

  # Run the experience collection loop.
  while train_step.numpy() < max_train_steps:
    logging.info('Collecting with policy at step: %d', train_step.numpy())
    collect_actor.run()
    variable_container.update(variables)
def create_random_gif():
  """Create a gif showing a random policy."""
  env_params = {
      'monster_speed': 0.7,
      'timeout_factor': 20,
      'step_size': 0.05,
      'n_actions': 8
  }
  py_env = LakeMonsterEnvironment(**env_params)
  policy = random_py_policy.RandomPyPolicy(
      time_step_spec=None, action_spec=py_env.action_spec())

  save_path = os.path.join(configs.ASSETS_DIR, 'random.gif')
  episode_as_gif(py_env, policy, save_path=save_path)
def testGeneratesActions(self):
  action_spec = [
      array_spec.BoundedArraySpec((2, 3), np.int32, -10, 10),
      array_spec.BoundedArraySpec((1, 2), np.int32, -10, 10)
  ]
  policy = random_py_policy.RandomPyPolicy(
      time_step_spec=None, action_spec=action_spec)

  action_step = policy.action(None)
  tf.nest.assert_same_structure(action_spec, action_step.action)
  self.assertTrue(np.all(action_step.action[0] >= -10))
  self.assertTrue(np.all(action_step.action[0] <= 10))
  self.assertTrue(np.all(action_step.action[1] >= -10))
  self.assertTrue(np.all(action_step.action[1] <= 10))
def _initial_collect(self):
  """Collect initial experience before training begins."""
  logging.info('Collecting initial experience...')
  time_step_spec = ts.time_step_spec(self._env.observation_spec())
  random_policy = random_py_policy.RandomPyPolicy(time_step_spec,
                                                  self._env.action_spec())
  time_step = self._env.reset()
  while self._replay_buffer.size < self._initial_collect_steps:
    if self.game_over():
      time_step = self._env.reset()
    action_step = random_policy.action(time_step)
    next_time_step = self._env.step(action_step.action)
    self._replay_buffer.add_batch(
        trajectory.from_transition(time_step, action_step, next_time_step))
    time_step = next_time_step
  logging.info('Done.')
def test_with_random_policy(self):

  def _global_context_sampling_fn():
    abc = np.array(['a', 'b', 'c'])
    return {
        'global1': np.random.randint(-2, 3, [3, 4]),
        'global2': abc[np.random.randint(0, 2, [1])]
    }

  def _arm_context_sampling_fn():
    aabbcc = np.array(['aa', 'bb', 'cc'])
    return {
        'arm1': np.random.randint(-3, 4, [5]),
        'arm2': np.random.randint(-3, 4, [3, 1]),
        'arm3': aabbcc[np.random.randint(0, 2, [1])]
    }

  def _reward_fn(global_obs, arm_obs):
    return global_obs['global1'][2, 1] + arm_obs['arm1'][4]

  env = ssspe.StationaryStochasticStructuredPyEnvironment(
      _global_context_sampling_fn,
      _arm_context_sampling_fn,
      6,
      _reward_fn,
      batch_size=2)
  time_step_spec = env.time_step_spec()
  action_spec = array_spec.BoundedArraySpec(
      shape=(), minimum=0, maximum=5, dtype=np.int32)

  random_policy = random_py_policy.RandomPyPolicy(
      time_step_spec=time_step_spec, action_spec=action_spec)

  for _ in range(5):
    time_step = env.reset()
    self.assertTrue(
        check_unbatched_time_step_spec(
            time_step=time_step,
            time_step_spec=time_step_spec,
            batch_size=env.batch_size))
    action = random_policy.action(time_step).action
    self.assertAllEqual(action.shape, [2])
    self.assertAllGreaterEqual(action, 0)
    self.assertAllLess(action, 6)
    time_step = env.step(action)
    self.assertEqual(time_step.reward.shape, (2,))
def testGeneratesBatchedActions(self):
  action_spec = [
      array_spec.BoundedArraySpec((2, 3), np.int32, -10, 10),
      array_spec.BoundedArraySpec((1, 2), np.int32, -10, 10)
  ]
  policy = random_py_policy.RandomPyPolicy(
      time_step_spec=None, action_spec=action_spec, outer_dims=(3,))

  action_step = policy.action(None)
  nest.assert_same_structure(action_spec, action_step.action)
  self.assertEqual((3, 2, 3), action_step.action[0].shape)
  self.assertEqual((3, 1, 2), action_step.action[1].shape)
  self.assertTrue(np.all(action_step.action[0] >= -10))
  self.assertTrue(np.all(action_step.action[0] <= 10))
  self.assertTrue(np.all(action_step.action[1] >= -10))
  self.assertTrue(np.all(action_step.action[1] <= 10))
def _insert_random_data(self,
                        env,
                        num_steps,
                        sequence_length=2,
                        additional_observers=None):
  """Inserts `num_steps` random observations into the Reverb server."""
  observers = [] if additional_observers is None else additional_observers
  traj_obs = reverb_utils.ReverbAddTrajectoryObserver(
      self._py_client, self._table_name, sequence_length=sequence_length)
  observers.append(traj_obs)
  policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                           env.action_spec())
  driver = py_driver.PyDriver(
      env, policy, observers=observers, max_steps=num_steps)
  time_step = env.reset()
  driver.run(time_step)
  traj_obs.close()
def test_with_variable_num_actions_masking(self):

  def _global_context_sampling_fn():
    return np.random.randint(-10, 10, [4])

  def _arm_context_sampling_fn():
    return np.random.randint(-2, 3, [5])

  def _num_actions_fn():
    return np.random.randint(0, 7)

  reward_fn = LinearNormalReward([0, 1, 2, 3, 4, 5, 6, 7, 8])

  env = sspe.StationaryStochasticPerArmPyEnvironment(
      _global_context_sampling_fn,
      _arm_context_sampling_fn,
      6,
      reward_fn,
      _num_actions_fn,
      batch_size=2,
      add_num_actions_feature=False)
  time_step_spec = env.time_step_spec()
  self.assertAllEqual(time_step_spec.observation[1].shape, [6])
  action_spec = array_spec.BoundedArraySpec(
      shape=(), minimum=0, maximum=5, dtype=np.int32)

  random_policy = random_py_policy.RandomPyPolicy(
      time_step_spec=time_step_spec, action_spec=action_spec)

  for _ in range(5):
    time_step = env.reset()
    self.assertTrue(
        check_unbatched_time_step_spec(
            time_step=time_step,
            time_step_spec=time_step_spec,
            batch_size=env.batch_size))
    action = random_policy.action(time_step).action
    self.assertAllEqual(action.shape, [2])
    self.assertAllGreaterEqual(action, 0)
    self.assertAllLess(action, 6)
    time_step = env.step(action)
def testGeneratesBatchedActionsWithoutSpecifyingOuterDims(self):
  action_spec = [
      array_spec.BoundedArraySpec((2, 3), np.int32, -10, 10),
      array_spec.BoundedArraySpec((1, 2), np.int32, -10, 10)
  ]
  time_step_spec = time_step.time_step_spec(
      observation_spec=array_spec.ArraySpec((1,), np.int32))
  policy = random_py_policy.RandomPyPolicy(
      time_step_spec=time_step_spec, action_spec=action_spec)

  action_step = policy.action(
      time_step.restart(np.array([[1], [2], [3]], dtype=np.int32)))
  tf.nest.assert_same_structure(action_spec, action_step.action)
  self.assertEqual((3, 2, 3), action_step.action[0].shape)
  self.assertEqual((3, 1, 2), action_step.action[1].shape)
  self.assertTrue(np.all(action_step.action[0] >= -10))
  self.assertTrue(np.all(action_step.action[0] <= 10))
  self.assertTrue(np.all(action_step.action[1] >= -10))
  self.assertTrue(np.all(action_step.action[1] <= 10))
def testRandomPyPolicyGeneratesActionTensors(self):
  array_action_spec = array_spec.BoundedArraySpec((7,), np.int32, -10, 10)

  observation = tf.ones([3], tf.float32)
  time_step = ts.restart(observation)
  observation_spec = tensor_spec.TensorSpec.from_tensor(observation)
  time_step_spec = ts.time_step_spec(observation_spec)

  tf_py_random_policy = tf_py_policy.TFPyPolicy(
      random_py_policy.RandomPyPolicy(
          time_step_spec=time_step_spec, action_spec=array_action_spec))

  batched_time_step = nest_utils.batch_nested_tensors(time_step)
  action_step = tf_py_random_policy.action(time_step=batched_time_step)
  action, new_policy_state = self.evaluate(
      [action_step.action, action_step.state])

  self.assertEqual((1,) + array_action_spec.shape, action.shape)
  self.assertTrue(np.all(action >= array_action_spec.minimum))
  self.assertTrue(np.all(action <= array_action_spec.maximum))
  self.assertEqual(new_policy_state, ())
def _create_collect_actor(
    collect_env: YGOEnvironment, collect_policy: PyTFEagerPolicy, train_step,
    rb_observer: ReverbAddTrajectoryObserver) -> actor.Actor:
  initial_collect_actor = actor.Actor(
      collect_env,
      random_py_policy.RandomPyPolicy(collect_env.time_step_spec(),
                                      collect_env.action_spec()),
      train_step,
      episodes_per_run=_initial_collect_episodes,
      observers=[rb_observer])
  initial_collect_actor.run()

  return actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      episodes_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(tempdir, learner.TRAIN_DIR),
      observers=[rb_observer, py_metrics.EnvironmentSteps()])
def validate_py_environment(environment, episodes=5):
  """Validates the environment follows the defined specs."""
  time_step_spec = environment.time_step_spec()
  action_spec = environment.action_spec()

  random_policy = random_py_policy.RandomPyPolicy(
      time_step_spec=time_step_spec, action_spec=action_spec)

  episode_count = 0
  time_step = environment.reset()
  while episode_count < episodes:
    if not array_spec.check_arrays_nest(time_step, time_step_spec):
      raise ValueError('Given `time_step`: %r does not match expected '
                       '`time_step_spec`: %r' % (time_step, time_step_spec))

    action = random_policy.action(time_step).action
    time_step = environment.step(action)

    if time_step.is_last():
      episode_count += 1
def profile_env(env_str, max_ep_len, n_steps=None, env_wrappers=()):
  # Use an immutable default for the wrapper list to avoid the mutable default
  # argument pitfall; an empty tuple behaves the same here.
  n_steps = n_steps or max_ep_len * 2
  profile = [None]

  def profile_fn(p):
    assert isinstance(p, cProfile.Profile)
    profile[0] = p

  py_env = suite_gym.load(
      env_str, gym_env_wrappers=env_wrappers, max_episode_steps=max_ep_len)
  env = wrappers.PerformanceProfiler(
      py_env, process_profile_fn=profile_fn, process_steps=n_steps)
  policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                           env.action_spec())
  driver = py_driver.PyDriver(env, policy, [], max_steps=n_steps)

  time_step = env.reset()
  policy_state = policy.get_initial_state()
  for _ in range(n_steps):
    time_step, policy_state = driver.run(time_step, policy_state)

  stats = pstats.Stats(profile[0])
  stats.print_stats()
def testPyPolicyIsBatchedTrue(self):
  action_dims = 5
  observation_dims = 3
  batch_size = 2
  array_action_spec = array_spec.BoundedArraySpec((action_dims,), np.int32,
                                                  -10, 10)
  observation_spec = array_spec.ArraySpec((observation_dims,), np.float32)
  array_time_step_spec = ts.time_step_spec(observation_spec)

  observation = tf.ones([batch_size, observation_dims], tf.float32)
  time_step = ts.restart(observation, batch_size=batch_size)

  tf_py_random_policy = tf_py_policy.TFPyPolicy(
      random_py_policy.RandomPyPolicy(
          time_step_spec=array_time_step_spec,
          action_spec=array_action_spec),
      py_policy_is_batched=True)

  action_step = tf_py_random_policy.action(time_step=time_step)
  action = self.evaluate(action_step.action)

  self.assertEqual(action.shape, (batch_size, action_dims))
def testRandomPyPolicyGeneratesActionTensors(self):
  if tf.executing_eagerly():
    self.skipTest('b/123935604')

  py_action_spec = array_spec.BoundedArraySpec((7,), np.int32, -10, 10)

  observation = tf.ones([3], tf.float32)
  time_step = ts.restart(observation)
  observation_spec = tensor_spec.TensorSpec.from_tensor(observation)
  time_step_spec = ts.time_step_spec(observation_spec)

  tf_py_random_policy = tf_py_policy.TFPyPolicy(
      random_py_policy.RandomPyPolicy(
          time_step_spec=time_step_spec, action_spec=py_action_spec))

  action_step = tf_py_random_policy.action(time_step=time_step)
  py_action, py_new_policy_state = self.evaluate(
      [action_step.action, action_step.state])

  self.assertEqual(py_action.shape, py_action_spec.shape)
  self.assertTrue(np.all(py_action >= py_action_spec.minimum))
  self.assertTrue(np.all(py_action <= py_action_spec.maximum))
  self.assertEqual(py_new_policy_state, ())
def testMasking(self):
  batch_size = 1000

  time_step_spec = time_step.time_step_spec(
      observation_spec=array_spec.ArraySpec((1,), np.int32))
  action_spec = array_spec.BoundedArraySpec((), np.int64, -5, 5)

  # We create a fixed mask here for testing purposes. Normally the mask would
  # be part of the observation.
  mask = [0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0]
  np_mask = np.array(mask)
  batched_mask = np.array([mask for _ in range(batch_size)])

  policy = random_py_policy.RandomPyPolicy(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      observation_and_action_constraint_splitter=(
          lambda obs: (obs, batched_mask)))

  my_time_step = time_step.restart(time_step_spec, batch_size)
  action_step = policy.action(my_time_step)

  tf.nest.assert_same_structure(action_spec, action_step.action)

  # Sample from the policy 1000 times and ensure that actions considered
  # invalid according to the mask are never chosen.
  action_ = self.evaluate(action_step.action)
  self.assertTrue(np.all(action_ >= -5))
  self.assertTrue(np.all(action_ <= 5))
  self.assertAllEqual(np_mask[action_ - action_spec.minimum],
                      np.ones([batch_size]))

  # Ensure that all valid actions occur somewhere within the batch. Because we
  # sample 1000 times, the chance of this failing for any particular action is
  # (2/3)^1000, roughly 1e-176.
  for index in range(action_spec.minimum, action_spec.maximum + 1):
    if np_mask[index - action_spec.minimum]:
      self.assertIn(index, action_)
def testPolicyStateSpecIsEmpty(self):
  policy = random_py_policy.RandomPyPolicy(time_step_spec=None, action_spec=[])
  self.assertEqual(policy.policy_state_spec, ())
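Outside a test harness, the same construct-and-act pattern can be exercised standalone. A minimal sketch follows; the spec shapes, bounds, and the optional `seed` are chosen purely for illustration and are not taken from the snippets above.

import numpy as np
from tf_agents.policies import random_py_policy
from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts

# Assumed specs: a 4-dim float observation and a scalar integer action in [0, 3].
action_spec = array_spec.BoundedArraySpec((), np.int32, minimum=0, maximum=3)
observation_spec = array_spec.ArraySpec((4,), np.float32)
time_step_spec = ts.time_step_spec(observation_spec)

# Seeding is optional; it makes the sampled action sequence reproducible.
policy = random_py_policy.RandomPyPolicy(
    time_step_spec=time_step_spec, action_spec=action_spec, seed=0)

first_step = ts.restart(np.zeros((4,), dtype=np.float32))
action_step = policy.action(first_step)
print(action_step.action)  # An int32 scalar sampled uniformly from [0, 3].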
def train_eval(
    root_dir,
    # Dataset params
    env_name,
    data_dir=None,
    load_pretrained=False,
    pretrained_model_dir=None,
    img_pad=4,
    frame_shape=(84, 84, 3),
    frame_stack=3,
    num_augmentations=2,  # K and M in DrQ
    # Training params
    contrastive_loss_weight=1.0,
    contrastive_loss_temperature=0.5,
    image_encoder_representation=True,
    initial_collect_steps=1000,
    num_train_steps=3000000,
    actor_fc_layers=(1024, 1024),
    critic_joint_fc_layers=(1024, 1024),
    # Agent params
    batch_size=256,
    actor_learning_rate=1e-3,
    critic_learning_rate=1e-3,
    alpha_learning_rate=1e-3,
    encoder_learning_rate=1e-3,
    actor_update_freq=2,
    gamma=0.99,
    target_update_tau=0.01,
    target_update_period=2,
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=100000,
    # Others
    checkpoint_interval=10000,
    policy_save_interval=5000,
    eval_interval=10000,
    summary_interval=250,
    debug_summaries=False,
    eval_episodes_per_run=10,
    summarize_grads_and_vars=False):
  """Trains and evaluates SAC."""
  collect_env = env_utils.load_dm_env_for_training(
      env_name, frame_shape, frame_stack=frame_stack)
  eval_env = env_utils.load_dm_env_for_eval(
      env_name, frame_shape, frame_stack=frame_stack)

  logging.info('Data directory: %s', data_dir)
  logging.info('Num train steps: %d', num_train_steps)
  logging.info('Contrastive loss coeff: %.2f', contrastive_loss_weight)
  logging.info('Contrastive loss temperature: %.4f',
               contrastive_loss_temperature)
  logging.info('load_pretrained: %s', 'yes' if load_pretrained else 'no')
  logging.info('encoder representation: %s',
               'yes' if image_encoder_representation else 'no')

  load_episode_data = (contrastive_loss_weight > 0)
  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()
  image_encoder = networks.ImageEncoder(observation_tensor_spec)

  actor_net = model_utils.Actor(
      observation_tensor_spec,
      action_tensor_spec,
      image_encoder=image_encoder,
      fc_layers=actor_fc_layers,
      image_encoder_representation=image_encoder_representation)

  critic_net = networks.Critic(
      (observation_tensor_spec, action_tensor_spec),
      image_encoder=image_encoder,
      joint_fc_layers=critic_joint_fc_layers)
  critic_net_2 = networks.Critic(
      (observation_tensor_spec, action_tensor_spec),
      image_encoder=image_encoder,
      joint_fc_layers=critic_joint_fc_layers)

  target_image_encoder = networks.ImageEncoder(observation_tensor_spec)
  target_critic_net_1 = networks.Critic(
      (observation_tensor_spec, action_tensor_spec),
      image_encoder=target_image_encoder)
  target_critic_net_2 = networks.Critic(
      (observation_tensor_spec, action_tensor_spec),
      image_encoder=target_image_encoder)

  agent = pse_drq_agent.DrQSacModifiedAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      critic_network_2=critic_net_2,
      target_critic_network=target_critic_net_1,
      target_critic_network_2=target_critic_net_2,
      actor_update_frequency=actor_update_freq,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=actor_learning_rate),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=critic_learning_rate),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=alpha_learning_rate),
      contrastive_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=encoder_learning_rate),
      contrastive_loss_weight=contrastive_loss_weight,
      contrastive_loss_temperature=contrastive_loss_temperature,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      use_log_alpha_in_alpha_loss=False,
      gradient_clipping=None,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step,
      num_augmentations=num_augmentations)
  agent.initialize()

  # Setup the replay buffer.
  reverb_replay, rb_observer = (
      replay_buffer_utils.get_reverb_buffer_and_observer(
          agent.collect_data_spec,
          sequence_length=2,
          replay_capacity=replay_capacity,
          port=reverb_port))

  # pylint: disable=g-long-lambda
  if num_augmentations == 0:
    image_aug = lambda traj, meta: (dict(
        experience=traj, augmented_obs=[], augmented_next_obs=[]), meta)
  else:
    image_aug = lambda traj, meta: pse_drq_agent.image_aug(
        traj, meta, img_pad, num_augmentations)
  augmented_dataset = reverb_replay.as_dataset(
      sample_batch_size=batch_size,
      num_steps=2).unbatch().map(image_aug, num_parallel_calls=3)
  augmented_iterator = iter(augmented_dataset)

  trajs = augmented_dataset.batch(batch_size).prefetch(50)
  if load_episode_data:
    # Load full episodes and zip them.
    episodes = dataset_utils.load_episodes(
        os.path.join(data_dir, 'episodes2'), img_pad)
    episode_iterator = iter(episodes)
    dataset = tf.data.Dataset.zip((trajs, episodes)).prefetch(10)
  else:
    dataset = trajs
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir, agent, train_step, interval=policy_save_interval),
      triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval),
  ]

  agent_learner = model_utils.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn=experience_dataset_fn,
      triggers=learning_triggers,
      checkpoint_interval=checkpoint_interval,
      summary_interval=summary_interval,
      load_episode_data=load_episode_data,
      use_kwargs_in_agent_train=True,
      # Turn off the initialization of the optimizer variables since the agent
      # expects different batching for the `training_data_spec` and
      # `train_argspec` which can't be handled in general by the initialization
      # logic in the learner.
      run_optimizer_variable_init=False)

  # If we haven't trained yet, make sure we collect some random samples first
  # to fill up the Replay Buffer with some experience.
  train_dir = os.path.join(root_dir, learner.TRAIN_DIR)

  # Code for loading the pretrained policy.
  if load_pretrained:
    # Note that num_train_steps is the same as the max_train_step we want to
    # load the pretrained policy for in our experiments.
    pretrained_policy = model_utils.load_pretrained_policy(
        pretrained_model_dir, num_train_steps)
    initial_collect_policy = pretrained_policy

    agent.policy.update_partial(pretrained_policy)
    agent.collect_policy.update_partial(pretrained_policy)
    logging.info('Restored pretrained policy.')
  else:
    initial_collect_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())

  initial_collect_actor = actor.Actor(
      collect_env,
      initial_collect_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      observers=[rb_observer],
      metrics=actor.collect_metrics(buffer_size=10),
      summary_dir=train_dir,
      summary_interval=summary_interval,
      name='CollectActor')

  # If restarting with train_step > 0, the replay buffer will be empty except
  # for random experience. Populate the buffer with some on-policy experience.
  if load_pretrained or (agent_learner.train_step_numpy > 0):
    for _ in range(batch_size * 50):
      collect_actor.run()

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes_per_run,
      metrics=actor.eval_metrics(buffer_size=10),
      summary_dir=os.path.join(root_dir, 'eval'),
      summary_interval=-1,
      name='EvalTrainActor')

  if eval_interval:
    logging.info('Evaluating.')
    img_summary(
        next(augmented_iterator)[0], eval_actor.summary_writer, train_step)
    if load_episode_data:
      contrastive_img_summary(
          next(episode_iterator), agent, eval_actor.summary_writer, train_step)
    eval_actor.run_and_log()

  logging.info('Saving operative gin config file.')
  gin_path = os.path.join(train_dir, 'train_operative_gin_config.txt')
  with tf.io.gfile.GFile(gin_path, mode='w') as f:
    f.write(gin.operative_config_str())

  logging.info('Training starting at: %r', train_step.numpy())
  while train_step < num_train_steps:
    collect_actor.run()
    agent_learner.run(iterations=1)

    if (not eval_interval) and (train_step % 10000 == 0):
      img_summary(
          next(augmented_iterator)[0], agent_learner.train_summary_writer,
          train_step)

    if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      img_summary(
          next(augmented_iterator)[0], eval_actor.summary_writer, train_step)
      if load_episode_data:
        contrastive_img_summary(
            next(episode_iterator), agent, eval_actor.summary_writer,
            train_step)
      eval_actor.run_and_log()
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    fc_layer_params=(100,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    n_step_update=1,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    log_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  # Note this is a python environment.
  env = batched_py_environment.BatchedPyEnvironment(
      [suite_gym.load(env_name)])
  eval_py_env = suite_gym.load(env_name)

  # Convert specs to BoundedTensorSpec.
  action_spec = tensor_spec.from_spec(env.action_spec())
  observation_spec = tensor_spec.from_spec(env.observation_spec())
  time_step_spec = ts.time_step_spec(observation_spec)

  q_net = q_network.QNetwork(
      tensor_spec.from_spec(env.observation_spec()),
      tensor_spec.from_spec(env.action_spec()),
      fc_layer_params=fc_layer_params)

  # The agent must be in graph.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  agent = dqn_agent.DqnAgent(
      time_step_spec,
      action_spec,
      q_network=q_net,
      epsilon_greedy=epsilon_greedy,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
      td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      gradient_clipping=gradient_clipping,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=global_step)

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_policy.PyTFPolicy(tf_collect_policy)
  greedy_policy = py_tf_policy.PyTFPolicy(agent.policy)
  random_policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                                  env.action_spec())

  # Python replay buffer.
  replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
      capacity=replay_buffer_capacity,
      data_spec=tensor_spec.to_nest_array_spec(agent.collect_data_spec))

  time_step = env.reset()

  # Initialize the replay buffer with some transitions. We use the random
  # policy to initialize the replay buffer to make sure we get a good
  # distribution of actions.
  for _ in range(initial_collect_steps):
    time_step = collect_step(env, time_step, random_policy, replay_buffer)

  # TODO(b/112041045) Use global_step as counter.
  train_checkpointer = common.Checkpointer(
      ckpt_dir=train_dir, agent=agent, global_step=global_step)
  policy_checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_dir, 'policy'),
      policy=agent.policy,
      global_step=global_step)

  ds = replay_buffer.as_dataset(
      sample_batch_size=batch_size, num_steps=n_step_update + 1)
  ds = ds.prefetch(4)
  itr = tf.compat.v1.data.make_initializable_iterator(ds)

  experience = itr.get_next()

  train_op = common.function(agent.train)(experience)

  with eval_summary_writer.as_default(), \
       tf.compat.v2.summary.record_if(True):
    for eval_metric in eval_metrics:
      eval_metric.tf_summaries(train_step=global_step)

  with tf.compat.v1.Session() as session:
    train_checkpointer.initialize_or_restore(session)
    common.initialize_uninitialized_variables(session)
    session.run(itr.initializer)
    # Copy critic network values to the target critic network.
    session.run(agent.initialize())
    train = session.make_callable(train_op)
    global_step_call = session.make_callable(global_step)
    session.run(train_summary_writer.init())
    session.run(eval_summary_writer.init())

    # Compute initial evaluation metrics.
    global_step_val = global_step_call()
    metric_utils.compute_summaries(
        eval_metrics,
        eval_py_env,
        greedy_policy,
        num_episodes=num_eval_episodes,
        global_step=global_step_val,
        log=True,
        callback=eval_metrics_callback,
    )

    timed_at_step = global_step_val
    collect_time = 0
    train_time = 0
    steps_per_second_ph = tf.compat.v1.placeholder(
        tf.float32, shape=(), name='steps_per_sec_ph')
    steps_per_second_summary = tf.compat.v2.summary.scalar(
        name='global_steps_per_sec',
        data=steps_per_second_ph,
        step=global_step)

    for _ in range(num_iterations):
      start_time = time.time()
      for _ in range(collect_steps_per_iteration):
        time_step = collect_step(env, time_step, collect_policy, replay_buffer)
      collect_time += time.time() - start_time
      start_time = time.time()
      for _ in range(train_steps_per_iteration):
        loss = train()
      train_time += time.time() - start_time

      global_step_val = global_step_call()
      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, loss.loss)
        steps_per_sec = ((global_step_val - timed_at_step) /
                         (collect_time + train_time))
        session.run(
            steps_per_second_summary,
            feed_dict={steps_per_second_ph: steps_per_sec})
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info('%s', 'collect_time = {}, train_time = {}'.format(
            collect_time, train_time))
        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % eval_interval == 0:
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            greedy_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            log=True,
            callback=eval_metrics_callback,
        )
        # Reset timing to avoid counting eval time.
        timed_at_step = global_step_val
        start_time = time.time()
def _create_random_policy_from_env(env):
  return random_py_policy.RandomPyPolicy(
      ts.time_step_spec(env.observation_spec()), env.action_spec())
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    # Training params
    initial_collect_steps=1000,
    num_iterations=100000,
    fc_layer_params=(100,),
    # Agent params
    epsilon_greedy=0.1,
    batch_size=64,
    learning_rate=1e-3,
    n_step_update=1,
    gamma=0.99,
    target_update_tau=0.05,
    target_update_period=5,
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=100000,
    # Others
    policy_save_interval=1000,
    eval_interval=1000,
    eval_episodes=10):
  """Trains and evaluates DQN."""

  collect_env = suite_gym.load(env_name)
  eval_env = suite_gym.load(env_name)

  time_step_tensor_spec = tensor_spec.from_spec(collect_env.time_step_spec())
  action_tensor_spec = tensor_spec.from_spec(collect_env.action_spec())

  train_step = train_utils.create_train_step()
  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

  # Define a helper function to create Dense layers configured with the right
  # activation and kernel initializer.
  def dense_layer(num_units):
    return tf.keras.layers.Dense(
        num_units,
        activation=tf.keras.activations.relu,
        kernel_initializer=tf.keras.initializers.VarianceScaling(
            scale=2.0, mode='fan_in', distribution='truncated_normal'))

  # QNetwork consists of a sequence of Dense layers followed by a dense layer
  # with `num_actions` units to generate one q_value per available action as
  # its output.
  dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
  q_values_layer = tf.keras.layers.Dense(
      num_actions,
      activation=None,
      kernel_initializer=tf.keras.initializers.RandomUniform(
          minval=-0.03, maxval=0.03),
      bias_initializer=tf.keras.initializers.Constant(-0.2))
  q_net = sequential.Sequential(dense_layers + [q_values_layer])

  agent = dqn_agent.DqnAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      q_network=q_net,
      epsilon_greedy=epsilon_greedy,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
      td_errors_loss_fn=common.element_wise_squared_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      train_step_counter=train_step)

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)

  dataset = reverb_replay.as_dataset(
      num_parallel_calls=3, sample_batch_size=batch_size,
      num_steps=2).prefetch(3)
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=100),
  ]

  dqn_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  # If we haven't trained yet, make sure we collect some random samples first
  # to fill up the Replay Buffer with some experience.
  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      observers=[rb_observer, env_step_metric],
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
  )

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    dqn_learner.run(iterations=1)

    if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
def train_eval(
    root_dir,
    env_name,
    # Training params
    train_sequence_length,
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    num_iterations=100000,
    # RNN params.
    q_network_fn=q_lstm_network,  # defaults to q_lstm_network.
    # Agent params
    epsilon_greedy=0.1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    target_update_tau=0.05,
    target_update_period=5,
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=100000,
    # Others
    policy_save_interval=1000,
    eval_interval=1000,
    eval_episodes=10):
  """Trains and evaluates DQN."""

  collect_env = suite_gym.load(env_name)
  eval_env = suite_gym.load(env_name)

  unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()
  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

  q_net = q_network_fn(num_actions=num_actions)

  sequence_length = train_sequence_length + 1
  agent = dqn_agent.DqnAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      q_network=q_net,
      epsilon_greedy=epsilon_greedy,
      # n-step updates aren't supported with RNNs yet.
      n_step_update=1,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
      td_errors_loss_fn=common.element_wise_squared_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      train_step_counter=train_step)

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=sequence_length,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=sequence_length,
      stride_length=1,
      pad_end_of_episodes=True)

  def experience_dataset_fn():
    return reverb_replay.as_dataset(
        sample_batch_size=batch_size, num_steps=sequence_length)

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=100),
  ]

  dqn_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  # If we haven't trained yet, make sure we collect some random samples first
  # to fill up the Replay Buffer with some experience.
  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=collect_steps_per_iteration,
      observers=[rb_observer, env_step_metric],
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
  )

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    dqn_learner.run(iterations=1)

    if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
def __init__(self,
             greedy_policy: py_policy.PyPolicy,
             epsilon: types.Float,
             random_policy: Optional[random_py_policy.RandomPyPolicy] = None,
             epsilon_decay_end_count: Optional[types.Float] = None,
             epsilon_decay_end_value: Optional[types.Float] = None,
             random_seed: Optional[types.Seed] = None):
  """Initializes the epsilon-greedy policy.

  Args:
    greedy_policy: An instance of py_policy.PyPolicy to use as the greedy
      policy.
    epsilon: The probability 0.0 <= epsilon <= 1.0 with which an action will
      be selected at random.
    random_policy: An instance of random_py_policy.RandomPyPolicy to use as
      the random policy. If None is provided, a RandomPyPolicy will be
      automatically created with the greedy_policy's action_spec and
      observation_spec and random_seed.
    epsilon_decay_end_count: If set, anneal the epsilon every time this
      policy is used, until it hits the epsilon_decay_end_value.
    epsilon_decay_end_value: The value of epsilon to use when the policy
      usage count hits epsilon_decay_end_count.
    random_seed: Seed used to create numpy.random.RandomState. /dev/urandom
      will be used if it's None.

  Raises:
    ValueError: If epsilon is not between 0.0 and 1.0. Or if
      epsilon_decay_end_value is invalid when epsilon_decay_end_count is set.
  """
  if not 0 <= epsilon <= 1.0:
    raise ValueError('epsilon should be in [0.0, 1.0]')

  self._greedy_policy = greedy_policy
  if random_policy is None:
    self._random_policy = random_py_policy.RandomPyPolicy(
        time_step_spec=greedy_policy.time_step_spec,
        action_spec=greedy_policy.action_spec,
        seed=random_seed)
  else:
    self._random_policy = random_policy
  # TODO(b/110841809) consider making epsilon be provided by a function.
  self._epsilon = epsilon
  self._epsilon_decay_end_count = epsilon_decay_end_count
  if epsilon_decay_end_count is not None:
    if epsilon_decay_end_value is None or epsilon_decay_end_value >= epsilon:
      raise ValueError('Invalid value for epsilon_decay_end_value {}'.format(
          epsilon_decay_end_value))
    self._epsilon_decay_step_factor = float(
        epsilon - epsilon_decay_end_value) / epsilon_decay_end_count
  self._epsilon_decay_end_value = epsilon_decay_end_value
  self._random_seed = random_seed  # Keep it for copy method.
  self._rng = np.random.RandomState(random_seed)
  # Total times action method has been called.
  self._count = 0

  super(EpsilonGreedyPolicy, self).__init__(greedy_policy.time_step_spec,
                                            greedy_policy.action_spec,
                                            greedy_policy.policy_state_spec,
                                            greedy_policy.info_spec)
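A hedged usage sketch of the constructor above: epsilon anneals toward 0.1 over roughly the first 10000 calls to `action()`. The wrapped `greedy_policy` and the numbers are illustrative assumptions, not taken from the snippets above.

# `greedy_policy` is assumed to be an existing py_policy.PyPolicy instance,
# e.g. an agent's greedy policy wrapped for python execution.
policy = EpsilonGreedyPolicy(
    greedy_policy=greedy_policy,
    epsilon=1.0,
    epsilon_decay_end_count=10000,
    epsilon_decay_end_value=0.1,
    random_seed=0)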
    sequence_length=2,
    table_name=table_name,
    local_server=reverb_server)

dataset = reverb_replay.as_dataset(
    sample_batch_size=HyperParms.batch_size, num_steps=2).prefetch(50)
experience_dataset_fn = lambda: dataset

print(f" -- POLICIES ({now()}) -- ")
tf_eval_policy = tf_agent.policy
eval_policy = py_tf_eager_policy.PyTFEagerPolicy(
    tf_eval_policy, use_tf_function=True)
tf_collect_policy = tf_agent.collect_policy
collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
    tf_collect_policy, use_tf_function=True)
random_policy = random_py_policy.RandomPyPolicy(collect_env.time_step_spec(),
                                                collect_env.action_spec())

print(f" -- ACTORS ({now()}) -- ")
rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
    reverb_replay.py_client, table_name, sequence_length=2, stride_length=1)

initial_collect_actor = actor.Actor(
    collect_env,
    random_policy,
    train_step,
    steps_per_run=HyperParms.initial_collect_steps,
    observers=[rb_observer])
initial_collect_actor.run()
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    # Training params
    initial_collect_steps=10000,
    num_iterations=3200000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Agent params
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    gamma=0.99,
    target_update_tau=0.005,
    target_update_period=1,
    reward_scale_factor=0.1,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    # Defaults to not checkpointing saved policy. If you wish to enable this,
    # please note the caveat explained in README.md.
    policy_save_interval=-1,
    eval_interval=10000,
    eval_episodes=30,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """Trains and evaluates SAC."""
  logging.info('Training SAC on: %s', env_name)
  collect_env = suite_mujoco.load(env_name)
  eval_env = suite_mujoco.load(env_name)

  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=actor_fc_layers,
      continuous_projection_net=(
          tanh_normal_projection_network.TanhNormalProjectionNetwork))
  critic_net = critic_network.CriticNetwork(
      (observation_tensor_spec, action_tensor_spec),
      observation_fc_layer_params=critic_obs_fc_layers,
      action_fc_layer_params=critic_action_fc_layers,
      joint_fc_layer_params=critic_joint_fc_layers,
      kernel_initializer='glorot_uniform',
      last_kernel_initializer='glorot_uniform')

  agent = sac_agent.SacAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=actor_learning_rate),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=critic_learning_rate),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=alpha_learning_rate),
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      gradient_clipping=None,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step)
  agent.initialize()

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)

  dataset = reverb_replay.as_dataset(
      sample_batch_size=batch_size, num_steps=2).prefetch(50)
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=1000),
  ]

  agent_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      observers=[rb_observer, env_step_metric])

  tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      eval_greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    agent_learner.run(iterations=1)

    if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
def train_eval(
    root_dir,
    env_name='Pong-v0',
    # Training params
    update_frequency=4,  # Number of collect steps per policy update
    initial_collect_steps=50000,  # 50k collect steps
    num_iterations=50000000,  # 50M collect steps
    # Taken from Rainbow as it's not specified in Mnih,15.
    max_episode_frames_collect=50000,  # env frames observed by the agent
    max_episode_frames_eval=108000,  # env frames observed by the agent
    # Agent params
    epsilon_greedy=0.1,
    epsilon_decay_period=250000,  # 1M collect steps / update_frequency
    batch_size=32,
    learning_rate=0.00025,
    n_step_update=1,
    gamma=0.99,
    target_update_tau=1.0,
    target_update_period=2500,  # 10k collect steps / update_frequency
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    policy_save_interval=250000,
    eval_interval=1000,
    eval_episodes=30,
    debug_summaries=True):
  """Trains and evaluates DQN."""

  collect_env = suite_atari.load(
      env_name,
      max_episode_steps=max_episode_frames_collect,
      gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
  eval_env = suite_atari.load(
      env_name,
      max_episode_steps=max_episode_frames_eval,
      gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)

  unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()
  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

  epsilon = tf.compat.v1.train.polynomial_decay(
      1.0,
      train_step,
      epsilon_decay_period,
      end_learning_rate=epsilon_greedy)
  agent = dqn_agent.DqnAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      q_network=create_q_network(num_actions),
      epsilon_greedy=epsilon,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.compat.v1.train.RMSPropOptimizer(
          learning_rate=learning_rate,
          decay=0.95,
          momentum=0.95,
          epsilon=0.01,
          centered=True),
      td_errors_loss_fn=common.element_wise_huber_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      train_step_counter=train_step,
      debug_summaries=debug_summaries)

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)

  dataset = reverb_replay.as_dataset(
      sample_batch_size=batch_size, num_steps=2).prefetch(3)
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=100),
  ]

  dqn_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  # If we haven't trained yet, make sure we collect some random samples first
  # to fill up the Replay Buffer with some experience.
  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=update_frequency,
      observers=[rb_observer, env_step_metric],
      metrics=actor.collect_metrics(10),
      reference_metrics=[env_step_metric],
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
  )

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      reference_metrics=[env_step_metric],
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    dqn_learner.run(iterations=1)

    if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()