def _make_batched_mock_gym_py_environment(self, multithreading, num_envs=3):
  self.time_step_spec = ts.time_step_spec(self.observation_spec)
  constructor = functools.partial(GymWrapperEnvironmentMock,
                                  self.observation_spec, self.action_spec)
  return batched_py_environment.BatchedPyEnvironment(
      envs=[constructor() for _ in range(num_envs)],
      multithreading=multithreading)
def testBatchedPyEnvCompatible(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in eager.')
  actor_net = actor_network.ActorNetwork(
      self._observation_tensor_spec,
      self._action_tensor_spec,
      fc_layer_params=(10,),
  )
  tf_policy = actor_policy.ActorPolicy(
      self._time_step_tensor_spec,
      self._action_tensor_spec,
      actor_network=actor_net)
  py_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_policy, batch_time_steps=False)
  env_ctr = lambda: random_py_environment.RandomPyEnvironment(  # pylint: disable=g-long-lambda
      self._observation_spec, self._action_spec)
  env = batched_py_environment.BatchedPyEnvironment(
      [env_ctr() for _ in range(3)])
  time_step = env.reset()
  for _ in range(20):
    action_step = py_policy.action(time_step)
    time_step = env.step(action_step.action)
def testTimeStepSpec(self, batch_py_env):
  py_env = PYEnvironmentMock()
  if batch_py_env:
    batched_py_env = batched_py_environment.BatchedPyEnvironment([py_env])
    tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
  else:
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
  spec = tf_env.time_step_spec()

  # step_type
  self.assertEqual(type(spec.step_type), specs.TensorSpec)
  self.assertEqual(spec.step_type.dtype, tf.int32)
  self.assertEqual(spec.step_type.shape, tf.TensorShape([]))

  # reward
  self.assertEqual(type(spec.reward), specs.TensorSpec)
  self.assertEqual(spec.reward.dtype, tf.float32)
  self.assertEqual(spec.reward.shape, tf.TensorShape([]))

  # discount
  self.assertEqual(type(spec.discount), specs.BoundedTensorSpec)
  self.assertEqual(spec.discount.dtype, tf.float32)
  self.assertEqual(spec.discount.shape, tf.TensorShape([]))
  self.assertEqual(spec.discount.minimum, 0.0)
  self.assertEqual(spec.discount.maximum, 1.0)

  # observation
  self.assertEqual(type(spec.observation), specs.TensorSpec)
def testBatchedEnvironment(self, max_steps, max_episodes, expected_length):
  expected_trajectories = [
      trajectory.Trajectory(
          step_type=np.array([0, 0]),
          observation=np.array([0, 0]),
          action=np.array([2, 1]),
          policy_info=np.array([4, 2]),
          next_step_type=np.array([1, 1]),
          reward=np.array([1., 1.]),
          discount=np.array([1., 1.])),
      trajectory.Trajectory(
          step_type=np.array([1, 1]),
          observation=np.array([2, 1]),
          action=np.array([1, 2]),
          policy_info=np.array([2, 4]),
          next_step_type=np.array([2, 1]),
          reward=np.array([1., 1.]),
          discount=np.array([0., 1.])),
      trajectory.Trajectory(
          step_type=np.array([2, 1]),
          observation=np.array([3, 3]),
          action=np.array([2, 1]),
          policy_info=np.array([4, 2]),
          next_step_type=np.array([0, 2]),
          reward=np.array([0., 1.]),
          discount=np.array([1., 0.]))
  ]

  env1 = driver_test_utils.PyEnvironmentMock(final_state=3)
  env2 = driver_test_utils.PyEnvironmentMock(final_state=4)
  env = batched_py_environment.BatchedPyEnvironment([env1, env2])
  tf_env = tf_py_environment.TFPyEnvironment(env)

  policy = driver_test_utils.TFPolicyMock(
      tf_env.time_step_spec(),
      tf_env.action_spec(),
      batch_size=2,
      initial_policy_state=tf.constant([1, 2], dtype=tf.int32))

  replay_buffer_observer = MockReplayBufferObserver()
  driver = tf_driver.TFDriver(
      tf_env,
      policy,
      observers=[replay_buffer_observer],
      max_steps=max_steps,
      max_episodes=max_episodes,
  )

  initial_time_step = tf_env.reset()
  initial_policy_state = tf.constant([1, 2], dtype=tf.int32)
  self.evaluate(driver.run(initial_time_step, initial_policy_state))
  trajectories = replay_buffer_observer.gather_all()

  self.assertEqual(
      len(trajectories), len(expected_trajectories[:expected_length]))
  for t1, t2 in zip(trajectories, expected_trajectories[:expected_length]):
    for t1_field, t2_field in zip(t1, t2):
      self.assertAllEqual(t1_field, t2_field)
def testBatchedFirstTimeStepAndOneStep(self):
  py_envs = [PYEnvironmentMock() for _ in range(3)]
  batched_py_env = batched_py_environment.BatchedPyEnvironment(py_envs)
  tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
  self.assertEqual(tf_env.batch_size, 3)

  time_step_0 = tf_env.current_time_step()
  time_step_0_val = self.evaluate(time_step_0)
  self.assertAllEqual([ts.StepType.FIRST] * 3, time_step_0_val.step_type)
  self.assertAllEqual([0.0] * 3, time_step_0_val.reward)
  self.assertAllEqual([1.0] * 3, time_step_0_val.discount)
  self.assertAllEqual(np.array([0, 0, 0]), time_step_0_val.observation)
  for py_env in py_envs:
    self.assertEqual([], py_env.actions_taken)
    self.assertEqual(1, py_env.resets)
    self.assertEqual(0, py_env.steps)
    self.assertEqual(0, py_env.episodes)

  time_step_1 = tf_env.step(np.array([1, 1, 1]))
  time_step_1_val = self.evaluate(time_step_1)
  self.assertAllEqual([ts.StepType.MID] * 3, time_step_1_val.step_type)
  self.assertAllEqual([0.] * 3, time_step_1_val.reward)
  self.assertAllEqual([1.0] * 3, time_step_1_val.discount)
  self.assertAllEqual(np.array([1, 1, 1]), time_step_1_val.observation)
  for py_env in py_envs:
    self.assertEqual([1], py_env.actions_taken)
    self.assertEqual(1, py_env.resets)
    self.assertEqual(1, py_env.steps)
    self.assertEqual(0, py_env.episodes)
def _make_batched_py_environment(self, num_envs=3):
  self.time_step_spec = ts.time_step_spec(self.observation_spec)
  constructor = functools.partial(random_py_environment.RandomPyEnvironment,
                                  self.observation_spec, self.action_spec)
  return batched_py_environment.BatchedPyEnvironment(
      envs=[constructor() for _ in range(num_envs)])
def __init__(self, environment):
  """Initializes a new `TFPyEnvironment`.

  Args:
    environment: Environment to interact with, implementing
      `py_environment.PyEnvironment`.

  Raises:
    TypeError: If `environment` is not a subclass of
      `py_environment.PyEnvironment`.
  """
  if not isinstance(environment, py_environment.PyEnvironment):
    raise TypeError(
        'Environment should implement py_environment.PyEnvironment')

  if not environment.batched:
    environment = batched_py_environment.BatchedPyEnvironment([environment])
  self._env = environment

  observation_spec = tensor_spec.from_spec(self._env.observation_spec())
  action_spec = tensor_spec.from_spec(self._env.action_spec())
  time_step_spec = ts.time_step_spec(observation_spec)
  batch_size = self._env.batch_size if self._env.batch_size else 1

  super(TFPyEnvironment, self).__init__(time_step_spec, action_spec,
                                        batch_size)

  # Gather all the dtypes of the elements in time_step.
  self._time_step_dtypes = [
      s.dtype for s in tf.nest.flatten(self.time_step_spec())
  ]

  self._time_step = None
  self._lock = threading.Lock()
def _create_env():
  if batch_size is None:
    py_env = PYEnvironmentMock()
  else:
    py_env = [PYEnvironmentMock() for _ in range(batch_size)]
  if batch_py_env:
    py_env = batched_py_environment.BatchedPyEnvironment(
        py_env if isinstance(py_env, list) else [py_env])
  return py_env
def testResetOp(self, batch_py_env):
  py_env = PYEnvironmentMock()
  if batch_py_env:
    batched_py_env = batched_py_environment.BatchedPyEnvironment([py_env])
    tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
  else:
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
  reset = tf_env.reset()
  self.evaluate(reset)
  self.assertEqual(1, py_env.resets)
  self.assertEqual(0, py_env.steps)
  self.assertEqual(0, py_env.episodes)
def testActionSpec(self, batch_py_env):
  py_env = PYEnvironmentMock()
  if batch_py_env:
    py_env = batched_py_environment.BatchedPyEnvironment([py_env])
  tf_env = tf_py_environment.TFPyEnvironment(py_env)
  self.assertTrue(tf_env.batched)
  self.assertEqual(tf_env.batch_size, 1)
  spec = tf_env.action_spec()
  self.assertEqual(type(spec), specs.BoundedTensorSpec)
  self.assertEqual(spec.dtype, tf.int32)
  self.assertEqual(spec.shape, tf.TensorShape([]))
  self.assertEqual(spec.name, 'action')
def testMultipleReset(self, batch_py_env):
  py_env = PYEnvironmentMock()
  if batch_py_env:
    batched_py_env = batched_py_environment.BatchedPyEnvironment([py_env])
    tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
  else:
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
  reset = tf_env.reset()
  self.evaluate(reset)
  self.assertEqual(1, py_env.resets)
  self.evaluate(reset)
  self.assertEqual(2, py_env.resets)
  self.evaluate(reset)
  self.assertEqual(3, py_env.resets)
def testFirstObservationIsPreservedAfterTwoSteps(self, batch_py_env):
  py_env = PYEnvironmentMock()
  if batch_py_env:
    batched_py_env = batched_py_environment.BatchedPyEnvironment([py_env])
    tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
  else:
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
  time_step = tf_env.current_time_step()
  with tf.control_dependencies([time_step.step_type]):
    action = tf.constant([1])
  next_time_step = tf_env.step(action)
  with tf.control_dependencies([next_time_step.step_type]):
    action = tf.constant([2])
  _, observation = self.evaluate([tf_env.step(action), time_step.observation])

  self.assertEqual(np.array([0]), observation)
def testFirstTimeStep(self, batch_py_env):
  py_env = PYEnvironmentMock()
  if batch_py_env:
    batched_py_env = batched_py_environment.BatchedPyEnvironment([py_env])
    tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
  else:
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
  time_step = tf_env.current_time_step()
  time_step = self.evaluate(time_step)
  self.assertAllEqual([ts.StepType.FIRST], time_step.step_type)
  self.assertAllEqual([0.0], time_step.reward)
  self.assertAllEqual([1.0], time_step.discount)
  self.assertAllEqual([0], time_step.observation)
  self.assertAllEqual([], py_env.actions_taken)
  self.assertEqual(1, py_env.resets)
  self.assertEqual(0, py_env.steps)
  self.assertEqual(0, py_env.episodes)
def testMultipleReset(self, batch_py_env):
  if tf.executing_eagerly():
    self.skipTest('b/123881757')
  py_env = PYEnvironmentMock()
  if batch_py_env:
    batched_py_env = batched_py_environment.BatchedPyEnvironment([py_env])
    tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
  else:
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
  reset = tf_env.reset()
  self.evaluate(reset)
  self.assertEqual(1, py_env.resets)
  self.evaluate(reset)
  self.assertEqual(2, py_env.resets)
  self.evaluate(reset)
  self.assertEqual(3, py_env.resets)
def __init__(self, flags):
  """Initialize runner."""
  self.num_episodes = flags['num_episodes']
  self.environment = batched_py_environment.BatchedPyEnvironment(
      [rl_env.make('Hanabi-Full', num_players=flags['players'])])
  self.agent_config = {
      'max_information_tokens':
          self.environment.envs[0].game.max_information_tokens(),
      'action_spec':
          self.environment.action_spec(),
      'observation_spec':
          self.environment.observation_spec(),
      'environment_batch_size':
          self.environment.batch_size
  }
  self.agent_1 = AGENT_CLASSES[flags['agent_1']](self.agent_config)
  self.agent_2 = AGENT_CLASSES[flags['agent_2']](self.agent_config)
def testOneStep(self, batch_py_env):
  py_env = PYEnvironmentMock()
  if batch_py_env:
    batched_py_env = batched_py_environment.BatchedPyEnvironment([py_env])
    tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
  else:
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
  time_step = tf_env.current_time_step()
  with tf.control_dependencies([time_step.step_type]):
    action = tf.constant([1])
  time_step = self.evaluate(tf_env.step(action))
  self.assertAllEqual([ts.StepType.MID], time_step.step_type)
  self.assertAllEqual([0.], time_step.reward)
  self.assertAllEqual([1.0], time_step.discount)
  self.assertAllEqual([1], time_step.observation)
  self.assertAllEqual([1], py_env.actions_taken)
  self.assertEqual(1, py_env.resets)
  self.assertEqual(1, py_env.steps)
  self.assertEqual(0, py_env.episodes)
def testTwoStepsDependenceOnTheFirst(self, batch_py_env):
  py_env = PYEnvironmentMock()
  if batch_py_env:
    batched_py_env = batched_py_environment.BatchedPyEnvironment([py_env])
    tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
  else:
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
  time_step = tf_env.current_time_step()
  with tf.control_dependencies([time_step.step_type]):
    action = tf.constant([1])
  time_step = tf_env.step(action)
  with tf.control_dependencies([time_step.step_type]):
    action = tf.constant([2])
  time_step = self.evaluate(tf_env.step(action))

  self.assertEqual(ts.StepType.LAST, time_step.step_type)
  self.assertEqual([2], time_step.observation)
  self.assertEqual(1., time_step.reward)
  self.assertEqual(0., time_step.discount)
  self.assertEqual([1, 2], py_env.actions_taken)
def test_metric_results_equal_with_batched_env(self):
  env_ctor = lambda: random_py_environment.RandomPyEnvironment(  # pylint: disable=g-long-lambda
      self._time_step_spec.observation, self._action_spec)
  batch_size = 5
  env = batched_py_environment.BatchedPyEnvironment(
      [env_ctor() for _ in range(batch_size)])
  tf_env = tf_py_environment.TFPyEnvironment(env)
  python_metrics, tensorflow_metrics = self._build_metrics(
      batch_size=batch_size)
  observers = python_metrics + tensorflow_metrics
  driver = dynamic_step_driver.DynamicStepDriver(
      tf_env, self._policy, observers=observers, num_steps=1000)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(driver.run())

  for python_metric, tensorflow_metric in zip(python_metrics,
                                              tensorflow_metrics):
    python_result = self.evaluate(python_metric.result())
    tensorflow_result = self.evaluate(tensorflow_metric.result())
    self.assertEqual(python_result, tensorflow_result)
def create_envs(env_name,
                use_multiprocessing,
                num_parallel_envs,
                visualize_eval=False,
                mock_train_envs=False):

  def env_load_fn(env_map_name, visualize=False, mock=False):
    env = gym_wrapper.GymWrapper(
        gym_env=SC2GymEnv(
            map_name=env_map_name, visualize=visualize, mock=mock),
        spec_dtype_map={
            gym.spaces.Box: np.float32,
            gym.spaces.Discrete: np.int32,
            gym.spaces.MultiBinary: np.float32
        },
    )
    return env

  if num_parallel_envs == 1:
    par_env = env_load_fn(env_map_name=env_name, mock=mock_train_envs)
  elif use_multiprocessing:
    par_env = parallel_py_environment.ParallelPyEnvironment(
        [lambda: env_load_fn(env_map_name=env_name, mock=mock_train_envs)] *
        num_parallel_envs,
        start_serially=False)
  else:
    par_env = batched_py_environment.BatchedPyEnvironment(envs=[
        env_load_fn(env_map_name=env_name, mock=mock_train_envs)
        for _ in range(num_parallel_envs)
    ])
  tf_env = tf_py_environment.TFPyEnvironment(par_env)
  tf_env.reset()

  eval_env = env_load_fn(env_name, visualize=visualize_eval)
  eval_env = tf_py_environment.TFPyEnvironment(eval_env)
  eval_env.reset()
  return tf_env, eval_env
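# Illustrative call to `create_envs` above. This is a sketch only: the map
# name and argument values are assumptions for demonstration, not values used
# elsewhere in this code.
train_env, eval_env = create_envs(
    env_name='MoveToBeacon',        # hypothetical SC2 map name
    use_multiprocessing=False,      # falls back to BatchedPyEnvironment
    num_parallel_envs=4,
    visualize_eval=False,
    mock_train_envs=True)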
def __init__(self, environment, check_dims=False, isolation=False):
  """Initializes a new `TFPyEnvironment`.

  Args:
    environment: Environment to interact with, implementing
      `py_environment.PyEnvironment`. Or a `callable` that returns an
      environment of this form. If a `callable` is provided and `isolation`
      is set, the callable is executed in the dedicated thread.
    check_dims: Whether to check the batch dimensions of actions in `step`.
    isolation: If this value is `False` (default), interactions with the
      environment will occur within whatever thread the methods of the
      `TFPyEnvironment` are run from. For example, in TF graph mode, methods
      like `step` are called from multiple threads created by the TensorFlow
      engine; calls to step the environment are guaranteed to be sequential,
      but not from the same thread. This creates problems for environments
      that are not thread-safe.

      Using isolation ensures not only that a dedicated thread (or
      thread-pool) is used to interact with the environment, but also that
      interaction with the environment happens in a serialized manner.

      If `isolation == True`, a dedicated thread is created for interactions
      with the environment. If `isolation` is an instance of
      `multiprocessing.pool.Pool` (this includes instances of
      `multiprocessing.pool.ThreadPool`, nee `multiprocessing.dummy.Pool`,
      and `multiprocessing.Pool`), then this pool is used to interact with
      the environment.

      **NOTE** If using `isolation` with a `BatchedPyEnvironment`, ensure
      you create the `BatchedPyEnvironment` with `multithreading=False`,
      since otherwise the multithreading in that wrapper reverses the
      effects of this one.

  Raises:
    TypeError: If `environment` is not an instance of
      `py_environment.PyEnvironment` or its subclasses, or is a callable
      that does not return an instance of `PyEnvironment`.
    TypeError: If `isolation` is not `True`, `False`, or an instance of
      `multiprocessing.pool.Pool`.
  """
  if not isolation:
    self._pool = None
  elif isinstance(isolation, pool.Pool):
    self._pool = isolation
  elif isolation:
    self._pool = pool.ThreadPool(1)
  else:
    raise TypeError(
        'isolation should be True, False, or an instance of '
        'a multiprocessing Pool or ThreadPool. Saw: {}'.format(isolation))

  if callable(environment):
    environment = self._execute(environment)
  if not isinstance(environment, py_environment.PyEnvironment):
    raise TypeError(
        'Environment should implement py_environment.PyEnvironment')

  if not environment.batched:
    # If executing in an isolated thread, do not enable multithreading for
    # this environment.
    environment = batched_py_environment.BatchedPyEnvironment(
        [environment], multithreading=not self._pool)
  self._env = environment
  self._check_dims = check_dims

  if isolation and getattr(self._env, '_parallel_execution', None):
    logging.warn(
        'Wrapped environment is executing in parallel. '
        'Perhaps it is a BatchedPyEnvironment with multithreading=True, '
        'or it is a ParallelPyEnvironment. This conflicts with the '
        '`isolation` arg passed to TFPyEnvironment: interactions with the '
        'wrapped environment are no longer guaranteed to happen in a common '
        'thread. Environment: %s', (self._env,))

  observation_spec = tensor_spec.from_spec(self._env.observation_spec())
  action_spec = tensor_spec.from_spec(self._env.action_spec())
  time_step_spec = ts.time_step_spec(observation_spec)
  batch_size = self._env.batch_size if self._env.batch_size else 1

  super(TFPyEnvironment, self).__init__(time_step_spec, action_spec,
                                        batch_size)

  # Gather all the dtypes of the elements in time_step.
  self._time_step_dtypes = [
      s.dtype for s in tf.nest.flatten(self.time_step_spec())
  ]

  self._time_step = None
  self._lock = threading.Lock()
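# Usage sketch for the isolation/multithreading interaction described in the
# docstring above. Illustrative only: the CartPole environment and batch size
# are assumptions, not taken from the surrounding code.
from tf_agents.environments import batched_py_environment
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment

# With `isolation=True`, build the BatchedPyEnvironment with
# `multithreading=False` so that all environment interaction stays on the
# single dedicated thread created by TFPyEnvironment.
example_batched_env = batched_py_environment.BatchedPyEnvironment(
    [suite_gym.load('CartPole-v0') for _ in range(2)], multithreading=False)
example_tf_env = tf_py_environment.TFPyEnvironment(
    example_batched_env, isolation=True)
example_first_step = example_tf_env.reset()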
def __init__(
    self,
    root_dir,
    env_name,
    num_iterations=200,
    max_episode_frames=108000,  # ALE frames
    terminal_on_life_loss=False,
    conv_layer_params=((32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)),
    fc_layer_params=(512,),
    # Params for collect
    initial_collect_steps=80000,  # ALE frames
    epsilon_greedy=0.01,
    epsilon_decay_period=1000000,  # ALE frames
    replay_buffer_capacity=1000000,
    # Params for train
    train_steps_per_iteration=1000000,  # ALE frames
    update_period=16,  # ALE frames
    target_update_tau=1.0,
    target_update_period=32000,  # ALE frames
    batch_size=32,
    learning_rate=2.5e-4,
    n_step_update=2,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    do_eval=True,
    eval_steps_per_iteration=500000,  # ALE frames
    eval_epsilon_greedy=0.001,
    # Params for checkpoints, summaries, and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=True,
    summarize_grads_and_vars=True,
    eval_metrics_callback=None):
  """A simple Atari train and eval for DQN.

  Args:
    root_dir: Directory to write log files to.
    env_name: Fully-qualified name of the Atari environment (i.e. Pong-v0).
    num_iterations: Number of train/eval iterations to run.
    max_episode_frames: Maximum length of a single episode, in ALE frames.
    terminal_on_life_loss: Whether to simulate an episode termination when a
      life is lost.
    conv_layer_params: Params for convolutional layers of QNetwork.
    fc_layer_params: Params for fully connected layers of QNetwork.
    initial_collect_steps: Number of ALE frames to process before beginning
      to train. Since this is in ALE frames, there will be
      initial_collect_steps/4 items in the replay buffer when training
      starts.
    epsilon_greedy: Final epsilon value to decay to for training.
    epsilon_decay_period: Period over which to decay epsilon, from 1.0 to
      epsilon_greedy (defined above).
    replay_buffer_capacity: Maximum number of items to store in the replay
      buffer.
    train_steps_per_iteration: Number of ALE frames to run through for each
      iteration of training.
    update_period: Run a train operation every update_period ALE frames.
    target_update_tau: Coefficient for soft target network updates (1.0 ==
      hard updates).
    target_update_period: Period, in ALE frames, to copy the live network to
      the target network.
    batch_size: Number of frames to include in each training batch.
    learning_rate: RMS optimizer learning rate.
    n_step_update: The number of steps to consider when computing TD error
      and TD loss. Applies standard single-step updates when set to 1.
    gamma: Discount for future rewards.
    reward_scale_factor: Scaling factor for rewards.
    gradient_clipping: Norm length to clip gradients.
    do_eval: If True, run an eval every iteration. If False, skip eval.
    eval_steps_per_iteration: Number of ALE frames to run through for each
      iteration of evaluation.
    eval_epsilon_greedy: Epsilon value to use for the evaluation policy (0 ==
      totally greedy policy).
    log_interval: Log stats to the terminal every log_interval training
      steps.
    summary_interval: Write TF summaries every summary_interval training
      steps.
    summaries_flush_secs: Flush summaries to disk every summaries_flush_secs
      seconds.
    debug_summaries: If True, write additional summaries for debugging (see
      dqn_agent for which summaries are written).
    summarize_grads_and_vars: Include gradients in summaries.
    eval_metrics_callback: A callback function that takes (metric_dict,
      global_step) as parameters. Called after every eval with the results
      of the evaluation.
""" self._update_period = update_period / ATARI_FRAME_SKIP self._train_steps_per_iteration = (train_steps_per_iteration / ATARI_FRAME_SKIP) self._do_eval = do_eval self._eval_steps_per_iteration = eval_steps_per_iteration / ATARI_FRAME_SKIP self._eval_epsilon_greedy = eval_epsilon_greedy self._initial_collect_steps = initial_collect_steps / ATARI_FRAME_SKIP self._summary_interval = summary_interval self._num_iterations = num_iterations self._log_interval = log_interval self._eval_metrics_callback = eval_metrics_callback with gin.unlock_config(): gin.bind_parameter(('tf_agents.environments.atari_preprocessing.' 'AtariPreprocessing.terminal_on_life_loss'), terminal_on_life_loss) root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() self._train_summary_writer = train_summary_writer self._eval_summary_writer = None if self._do_eval: self._eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) self._eval_metrics = [ py_metrics.AverageReturnMetric(name='PhaseAverageReturn', buffer_size=np.inf), py_metrics.AverageEpisodeLengthMetric( name='PhaseAverageEpisodeLength', buffer_size=np.inf), ] self._global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if(lambda: tf.math.equal( self._global_step % self._summary_interval, 0)): self._env = suite_atari.load( env_name, max_episode_steps=max_episode_frames / ATARI_FRAME_SKIP, gym_env_wrappers=suite_atari. DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING) self._env = batched_py_environment.BatchedPyEnvironment( [self._env]) observation_spec = tensor_spec.from_spec( self._env.observation_spec()) time_step_spec = ts.time_step_spec(observation_spec) action_spec = tensor_spec.from_spec(self._env.action_spec()) with tf.device('/cpu:0'): epsilon = tf.compat.v1.train.polynomial_decay( 1.0, self._global_step, epsilon_decay_period / ATARI_FRAME_SKIP / self._update_period, end_learning_rate=epsilon_greedy) with tf.device('/gpu:0'): optimizer = tf.compat.v1.train.RMSPropOptimizer( learning_rate=learning_rate, decay=0.95, momentum=0.0, epsilon=0.00001, centered=True) categorical_q_net = AtariCategoricalQNetwork( observation_spec, action_spec, conv_layer_params=conv_layer_params, fc_layer_params=fc_layer_params) agent = categorical_dqn_agent.CategoricalDqnAgent( time_step_spec, action_spec, categorical_q_network=categorical_q_net, optimizer=optimizer, epsilon_greedy=epsilon, n_step_update=n_step_update, target_update_tau=target_update_tau, target_update_period=(target_update_period / ATARI_FRAME_SKIP / self._update_period), gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=self._global_step) self._collect_policy = py_tf_policy.PyTFPolicy( agent.collect_policy) if self._do_eval: self._eval_policy = py_tf_policy.PyTFPolicy( epsilon_greedy_policy.EpsilonGreedyPolicy( policy=agent.policy, epsilon=self._eval_epsilon_greedy)) py_observation_spec = self._env.observation_spec() py_time_step_spec = ts.time_step_spec(py_observation_spec) py_action_spec = policy_step.PolicyStep( self._env.action_spec()) data_spec = trajectory.from_transition(py_time_step_spec, py_action_spec, py_time_step_spec) self._replay_buffer = 
      self._replay_buffer = py_hashed_replay_buffer.PyHashedReplayBuffer(
          data_spec=data_spec, capacity=replay_buffer_capacity)

    with tf.device('/cpu:0'):
      ds = self._replay_buffer.as_dataset(
          sample_batch_size=batch_size, num_steps=n_step_update + 1)
      ds = ds.prefetch(4)
      ds = ds.apply(tf.data.experimental.prefetch_to_device('/gpu:0'))

    with tf.device('/gpu:0'):
      self._ds_itr = tf.compat.v1.data.make_one_shot_iterator(ds)
      experience = self._ds_itr.get_next()
      self._train_op = agent.train(experience)

      self._env_steps_metric = py_metrics.EnvironmentSteps()
      self._step_metrics = [
          py_metrics.NumberOfEpisodes(),
          self._env_steps_metric,
      ]
      self._train_metrics = self._step_metrics + [
          py_metrics.AverageReturnMetric(buffer_size=10),
          py_metrics.AverageEpisodeLengthMetric(buffer_size=10),
      ]
      # The _train_phase_metrics average over an entire train iteration,
      # rather than the rolling average of the last 10 episodes.
      self._train_phase_metrics = [
          py_metrics.AverageReturnMetric(
              name='PhaseAverageReturn', buffer_size=np.inf),
          py_metrics.AverageEpisodeLengthMetric(
              name='PhaseAverageEpisodeLength', buffer_size=np.inf),
      ]
      self._iteration_metric = py_metrics.CounterMetric(name='Iteration')

      # Summaries written from python should run every time they are
      # generated.
      with tf.compat.v2.summary.record_if(True):
        self._steps_per_second_ph = tf.compat.v1.placeholder(
            tf.float32, shape=(), name='steps_per_sec_ph')
        self._steps_per_second_summary = tf.compat.v2.summary.scalar(
            name='global_steps_per_sec',
            data=self._steps_per_second_ph,
            step=self._global_step)

        for metric in self._train_metrics:
          metric.tf_summaries(
              train_step=self._global_step, step_metrics=self._step_metrics)

        for metric in self._train_phase_metrics:
          metric.tf_summaries(
              train_step=self._global_step,
              step_metrics=(self._iteration_metric,))
        self._iteration_metric.tf_summaries(train_step=self._global_step)

        if self._do_eval:
          with self._eval_summary_writer.as_default():
            for metric in self._eval_metrics:
              metric.tf_summaries(
                  train_step=self._global_step,
                  step_metrics=(self._iteration_metric,))

      self._train_checkpointer = common.Checkpointer(
          ckpt_dir=train_dir,
          agent=agent,
          global_step=self._global_step,
          optimizer=optimizer,
          metrics=metric_utils.MetricsGroup(
              self._train_metrics + self._train_phase_metrics +
              [self._iteration_metric], 'train_metrics'))
      self._policy_checkpointer = common.Checkpointer(
          ckpt_dir=os.path.join(train_dir, 'policy'),
          policy=agent.policy,
          global_step=self._global_step)
      self._rb_checkpointer = common.Checkpointer(
          ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
          max_to_keep=1,
          replay_buffer=self._replay_buffer)

      self._init_agent_op = agent.initialize()
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    fc_layer_params=(100,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    n_step_update=1,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    log_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  # Note this is a python environment.
  env = batched_py_environment.BatchedPyEnvironment(
      [suite_gym.load(env_name)])
  eval_py_env = suite_gym.load(env_name)

  # Convert specs to BoundedTensorSpec.
  action_spec = tensor_spec.from_spec(env.action_spec())
  observation_spec = tensor_spec.from_spec(env.observation_spec())
  time_step_spec = ts.time_step_spec(observation_spec)

  q_net = q_network.QNetwork(
      tensor_spec.from_spec(env.observation_spec()),
      tensor_spec.from_spec(env.action_spec()),
      fc_layer_params=fc_layer_params)

  # The agent must be in graph.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  agent = dqn_agent.DqnAgent(
      time_step_spec,
      action_spec,
      q_network=q_net,
      epsilon_greedy=epsilon_greedy,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
      td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      gradient_clipping=gradient_clipping,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=global_step)

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_policy.PyTFPolicy(tf_collect_policy)
  greedy_policy = py_tf_policy.PyTFPolicy(agent.policy)
  random_policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                                  env.action_spec())

  # Python replay buffer.
  replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
      capacity=replay_buffer_capacity,
      data_spec=tensor_spec.to_nest_array_spec(agent.collect_data_spec))

  time_step = env.reset()

  # Initialize the replay buffer with some transitions. We use the random
  # policy to initialize the replay buffer to make sure we get a good
  # distribution of actions.
  for _ in range(initial_collect_steps):
    time_step = collect_step(env, time_step, random_policy, replay_buffer)

  # TODO(b/112041045) Use global_step as counter.
  train_checkpointer = common.Checkpointer(
      ckpt_dir=train_dir, agent=agent, global_step=global_step)
  policy_checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_dir, 'policy'),
      policy=agent.policy,
      global_step=global_step)

  ds = replay_buffer.as_dataset(
      sample_batch_size=batch_size, num_steps=n_step_update + 1)
  ds = ds.prefetch(4)
  itr = tf.compat.v1.data.make_initializable_iterator(ds)
  experience = itr.get_next()

  train_op = common.function(agent.train)(experience)

  with eval_summary_writer.as_default(), \
       tf.compat.v2.summary.record_if(True):
    for eval_metric in eval_metrics:
      eval_metric.tf_summaries(train_step=global_step)

  with tf.compat.v1.Session() as session:
    train_checkpointer.initialize_or_restore(session)
    common.initialize_uninitialized_variables(session)
    session.run(itr.initializer)
    # Copy critic network values to the target critic network.
    session.run(agent.initialize())
    train = session.make_callable(train_op)
    global_step_call = session.make_callable(global_step)
    session.run(train_summary_writer.init())
    session.run(eval_summary_writer.init())

    # Compute initial evaluation metrics.
    global_step_val = global_step_call()
    metric_utils.compute_summaries(
        eval_metrics,
        eval_py_env,
        greedy_policy,
        num_episodes=num_eval_episodes,
        global_step=global_step_val,
        log=True,
        callback=eval_metrics_callback,
    )

    timed_at_step = global_step_val
    collect_time = 0
    train_time = 0
    steps_per_second_ph = tf.compat.v1.placeholder(
        tf.float32, shape=(), name='steps_per_sec_ph')
    steps_per_second_summary = tf.compat.v2.summary.scalar(
        name='global_steps_per_sec',
        data=steps_per_second_ph,
        step=global_step)

    for _ in range(num_iterations):
      start_time = time.time()
      for _ in range(collect_steps_per_iteration):
        time_step = collect_step(env, time_step, collect_policy,
                                 replay_buffer)
      collect_time += time.time() - start_time

      start_time = time.time()
      for _ in range(train_steps_per_iteration):
        loss = train()
      train_time += time.time() - start_time

      global_step_val = global_step_call()
      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, loss.loss)
        steps_per_sec = ((global_step_val - timed_at_step) /
                         (collect_time + train_time))
        session.run(
            steps_per_second_summary,
            feed_dict={steps_per_second_ph: steps_per_sec})
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info('%s', 'collect_time = {}, train_time = {}'.format(
            collect_time, train_time))
        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % eval_interval == 0:
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            greedy_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            log=True,
            callback=eval_metrics_callback,
        )
        # Reset timing to avoid counting eval time.
        timed_at_step = global_step_val
        start_time = time.time()
def testMultiStepEpisodicReplayBuffer(self):
  num_episodes = 5
  num_driver_episodes = 5

  # Create mock environment.
  py_env = batched_py_environment.BatchedPyEnvironment([
      driver_test_utils.PyEnvironmentMock(final_state=i + 1)
      for i in range(num_episodes)
  ])
  env = tf_py_environment.TFPyEnvironment(py_env)

  # Create mock policy.
  policy = driver_test_utils.TFPolicyMock(
      env.time_step_spec(), env.action_spec(), batch_size=num_episodes)

  # Create replay buffer and driver.
  replay_buffer = self._make_replay_buffer(env)
  stateful_buffer = episodic_replay_buffer.StatefulEpisodicReplayBuffer(
      replay_buffer, num_episodes)

  driver = dynamic_episode_driver.DynamicEpisodeDriver(
      env,
      policy,
      num_episodes=num_driver_episodes,
      observers=[stateful_buffer.add_batch])

  run_driver = driver.run()
  end_episodes = replay_buffer._maybe_end_batch_episodes(
      stateful_buffer.episode_ids, end_episode=True)
  completed_episodes = replay_buffer._completed_episodes()

  self.evaluate([
      tf.compat.v1.local_variables_initializer(),
      tf.compat.v1.global_variables_initializer()
  ])
  self.evaluate(run_driver)
  self.evaluate(end_episodes)
  completed_episodes = self.evaluate(completed_episodes)
  eps = [replay_buffer._get_episode(ep) for ep in completed_episodes]
  eps = self.evaluate(eps)
  episodes_length = [tf.nest.flatten(ep)[0].shape[0] for ep in eps]

  # Compare with expected output.
  self.assertAllEqual(completed_episodes, [3, 4, 5, 6, 7])
  self.assertAllEqual(episodes_length, [4, 4, 2, 1, 1])

  first = ts.StepType.FIRST
  mid = ts.StepType.MID
  last = ts.StepType.LAST

  step_types = [ep.step_type for ep in eps]
  observations = [ep.observation for ep in eps]
  rewards = [ep.reward for ep in eps]
  actions = [ep.action for ep in eps]

  self.assertAllClose(
      [[first, mid, mid, last], [first, mid, mid, mid], [first, last],
       [first], [first]], step_types)
  self.assertAllClose(
      [[0, 1, 3, 4], [0, 1, 3, 4], [0, 1], [0], [0]], observations)
  self.assertAllClose(
      [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2], [1], [1]], actions)
  self.assertAllClose(
      [[1, 1, 1, 0], [1, 1, 1, 1], [1, 0], [1], [1]], rewards)