Example #1
 def testConstructorWithOutOfBoundsDiscountFactor(self):
   exception_string = r'Discount factor \(gamma\) must be in \[0, 1\]\.'
   with self.assertRaisesRegexp(ValueError, exception_string):
     circular_replay_buffer.WrappedReplayBuffer(
         observation_shape=OBSERVATION_SHAPE, stack_size=STACK_SIZE, gamma=-1)
   with self.assertRaisesRegexp(ValueError, exception_string):
     circular_replay_buffer.WrappedReplayBuffer(
         observation_shape=OBSERVATION_SHAPE, stack_size=STACK_SIZE, gamma=1.1)
Example #2
 def testCustomObsDataType(self):
   # Tests that observation store is initialized with the correct data type
   # when an observation_dtype argument is passed to the constructor.
   replay = circular_replay_buffer.WrappedReplayBuffer(
       observation_shape=OBSERVATION_SHAPE,
       stack_size=STACK_SIZE, replay_capacity=10, observation_dtype=np.int32)
   self.assertEqual(replay.memory._store['observation'].dtype, np.int32)
Example #3
 def testDefaultObsDataType(self):
   # Tests that the default data type for observations is np.uint8 for
   # integration with Atari 2600.
   replay = circular_replay_buffer.WrappedReplayBuffer(
       observation_shape=OBSERVATION_SHAPE,
       stack_size=STACK_SIZE, replay_capacity=10)
   self.assertEqual(replay.memory._store['observation'].dtype, np.uint8)
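Examples #2 and #3 only check the stored dtype; the short sketch below shows the same constructor in a standalone script and then feeds it with add(observation, action, reward, terminal), the call used in the staging examples further down. The import path assumes the Dopamine package layout, and the shape and stack-size values are placeholders rather than the tests' OBSERVATION_SHAPE/STACK_SIZE constants.

import numpy as np
from dopamine.replay_memory import circular_replay_buffer

OBSERVATION_SHAPE = (84, 84)  # placeholder, Atari-sized frames
STACK_SIZE = 4                # placeholder frame-stack depth

replay = circular_replay_buffer.WrappedReplayBuffer(
    observation_shape=OBSERVATION_SHAPE,
    stack_size=STACK_SIZE,
    replay_capacity=100,
    batch_size=8,
    observation_dtype=np.int32)

for i in range(10):
  observation = np.full(OBSERVATION_SHAPE, i, dtype=np.int32)
  replay.add(observation, 0, 1.0, 0)  # (observation, action, reward, terminal)

# The underlying store keeps the dtype passed to the constructor.
assert replay.memory._store['observation'].dtype == np.int32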
Example #4
 def testWrapperLoad(self):
   replay = circular_replay_buffer.WrappedReplayBuffer(
       observation_shape=OBSERVATION_SHAPE,
       stack_size=STACK_SIZE,
       replay_capacity=5,
       batch_size=BATCH_SIZE)
   self.assertNotEqual(replay.memory._store['observation'],
                       self._test_observation)
   self.assertNotEqual(replay.memory._store['action'], self._test_action)
   self.assertNotEqual(replay.memory._store['reward'], self._test_reward)
   self.assertNotEqual(replay.memory._store['terminal'], self._test_terminal)
   self.assertNotEqual(replay.memory.add_count, self._test_add_count)
   self.assertNotEqual(replay.memory.invalid_range, self._test_invalid_range)
   store_prefix = '$store$_'
   numpy_arrays = {
       store_prefix + 'observation': self._test_observation,
       store_prefix + 'action': self._test_action,
       store_prefix + 'reward': self._test_reward,
       store_prefix + 'terminal': self._test_terminal,
       'add_count': self._test_add_count,
       'invalid_range': self._test_invalid_range
   }
   for attr in numpy_arrays:
     filename = os.path.join(self._test_subdir, '{}_ckpt.3.gz'.format(attr))
     with tf.gfile.Open(filename, 'w') as f:
       with gzip.GzipFile(fileobj=f) as outfile:
         np.save(outfile, numpy_arrays[attr], allow_pickle=False)
   replay.load(self._test_subdir, '3')
   self.assertAllClose(replay.memory._store['observation'],
                       self._test_observation)
   self.assertAllClose(replay.memory._store['action'], self._test_action)
   self.assertAllClose(replay.memory._store['reward'], self._test_reward)
   self.assertAllClose(replay.memory._store['terminal'], self._test_terminal)
   self.assertEqual(replay.memory.add_count, self._test_add_count)
   self.assertAllClose(replay.memory.invalid_range, self._test_invalid_range)
Example #5
 def testConstructorWithZeroUpdateHorizon(self):
     with self.assertRaisesRegexp(ValueError,
                                  r'Update horizon must be positive\.'):
         circular_replay_buffer.WrappedReplayBuffer(
             observation_shape=OBSERVATION_SHAPE,
             stack_size=STACK_SIZE,
             update_horizon=0)
Example #6
    def _build_replay_buffer(self, use_staging, use_cyclic_buffer):
        """Creates the replay buffer used by the agent.

        Args:
          use_staging: bool, if True, uses a staging area to prefetch data for
            faster training.
          use_cyclic_buffer: bool, if True, use a cyclic (wrap-around) buffer
            that overwrites its oldest transitions once capacity is reached.

        Returns:
          A WrappedReplayBuffer object.
        """
        return circular_replay_buffer.WrappedReplayBuffer(
            observation_shape=self.observation_shape,
            stack_size=self.stack_size,
            num_actions=self.num_actions,
            use_cyclic_buffer=use_cyclic_buffer,
            use_staging=use_staging,
            update_horizon=self.update_horizon,
            gamma=self.gamma,
            observation_dtype=self.observation_dtype.as_numpy_dtype,
            extra_storage_types=[
                circular_replay_buffer.ReplayElement('features',
                                                     self._features_shape,
                                                     np.float32),
                circular_replay_buffer.ReplayElement(
                    'state', self.buffer_state_shape,
                    self.observation_dtype.as_numpy_dtype),
                circular_replay_buffer.ReplayElement(
                    'next_state', self.buffer_state_shape,
                    self.observation_dtype.as_numpy_dtype),
                circular_replay_buffer.ReplayElement('next_action', (),
                                                     np.int32),
                circular_replay_buffer.ReplayElement('next_reward', (),
                                                     np.float32)
            ])
Example #7
 def testConstructorWithExtraStorageTypes(self):
     circular_replay_buffer.WrappedReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=STACK_SIZE,
         extra_storage_types=[
             circular_replay_buffer.ReplayElement('extra1', [], np.float32),
             circular_replay_buffer.ReplayElement('extra2', [2], np.int8)
         ])
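The constructor call above only declares the extra elements. The sketch below shows how such a buffer might be fed, assuming (as in Dopamine's out-of-graph buffer) that the extra values are appended to add() as additional positional arguments in the order the ReplayElements were declared; the observation shape and constants are illustrative only.

import numpy as np
from dopamine.replay_memory import circular_replay_buffer

replay = circular_replay_buffer.WrappedReplayBuffer(
    observation_shape=(84, 84),
    stack_size=4,
    replay_capacity=100,
    batch_size=8,
    extra_storage_types=[
        circular_replay_buffer.ReplayElement('extra1', [], np.float32),
        circular_replay_buffer.ReplayElement('extra2', [2], np.int8)
    ])

observation = np.zeros((84, 84), dtype=np.uint8)
extra1 = np.float32(0.5)                   # scalar, shape []
extra2 = np.array([1, 2], dtype=np.int8)   # shape [2]
# (observation, action, reward, terminal, *extra_args)
replay.add(observation, 0, 1.0, 0, extra1, extra2)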
Example #8
 def testConstructorCapacityNotLargeEnough(self):
     with self.assertRaisesRegexp(
             ValueError, r'Update horizon \(5\) should be significantly '
             r'smaller than replay capacity \(5\)\.'):
         circular_replay_buffer.WrappedReplayBuffer(
             observation_shape=OBSERVATION_SHAPE,
             stack_size=STACK_SIZE,
             replay_capacity=5,
             update_horizon=5)
Example #9
 def testConstructorWithNoStaging(self):
     replay = circular_replay_buffer.WrappedReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=STACK_SIZE,
         replay_capacity=100,
         batch_size=BATCH_SIZE,
         use_staging=False)
     with self.test_session() as sess:
         for i in range(BATCH_SIZE * 2):
             observation = np.full(OBSERVATION_SHAPE, i, dtype=OBS_DTYPE)
             replay.add(observation, 2, 1, 0)
          self._verify_sampled_trajectories(sess.run(replay.transition))
Example #10
 def testReplayBufferIsLocked(self):
     """Tests that the is properly checked."""
     replay = circular_replay_buffer.WrappedReplayBuffer(
         observation_shape=(2, ),
         stack_size=1,
         replay_capacity=10,
         batch_size=2)
     lock = mock.MagicMock()
     replay.memory._lock = lock
     # Add one element.
     replay.add((1, 2), 0, 0, False)
     # Check that the lock went through the proper lock/unlock process.
     lock.__enter__.assert_called_once()
     lock.__exit__.assert_called_once()
Example #11
    def _build_replay_buffer(self, use_staging):
        """Creates the replay buffer used by the agent.

        Args:
          use_staging: bool, if True, uses a staging area to prefetch data for
            faster training.

        Returns:
          A WrappedReplayBuffer object.
        """
        return circular_replay_buffer.WrappedReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=STACK_SIZE,
            use_staging=use_staging,
            update_horizon=self.update_horizon,
            gamma=self.gamma)
Example #12
    def _build_replay_buffer(self, use_staging):
        """Creates the replay buffer used by the agent.

        Args:
          use_staging: bool, if True, uses a staging area to prefetch data for
            faster training.

        Returns:
          A WrappedReplayBuffer object.
        """
        return circular_replay_buffer.WrappedReplayBuffer(
            observation_shape=self.observation_shape,
            stack_size=self.stack_size,
            use_staging=use_staging,
            update_horizon=self.update_horizon,
            gamma=self.gamma,
            observation_dtype=self.observation_dtype.as_numpy_dtype)
Example #13
 def testWrapperSave(self):
   replay = circular_replay_buffer.WrappedReplayBuffer(
       observation_shape=OBSERVATION_SHAPE,
       stack_size=STACK_SIZE,
       replay_capacity=5,
       batch_size=BATCH_SIZE)
   replay.memory.observation = self._test_observation
   replay.memory.action = self._test_action
   replay.memory.reward = self._test_reward
   replay.memory.terminal = self._test_terminal
   replay.memory.add_count = self._test_add_count
   replay.memory.invalid_range = self._test_invalid_range
   replay.save(self._test_subdir, 3)
   for attr in replay.memory.__dict__:
     if attr.startswith('_'):
       continue
     filename = os.path.join(self._test_subdir, '{}_ckpt.3.gz'.format(attr))
     self.assertTrue(tf.gfile.Exists(filename))
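Examples #4 and #13 together document the checkpoint layout ('{attribute}_ckpt.{iteration}.gz' files in the checkpoint directory) and the save/load entry points. A minimal round trip, assuming the checkpoint directory already exists and is writable, might look like the sketch below; the path and sizes are placeholders.

import numpy as np
from dopamine.replay_memory import circular_replay_buffer

checkpoint_dir = '/tmp/replay_ckpt'  # placeholder, must already exist

replay = circular_replay_buffer.WrappedReplayBuffer(
    observation_shape=(84, 84), stack_size=4,
    replay_capacity=100, batch_size=8)
for i in range(10):
  replay.add(np.full((84, 84), i, dtype=np.uint8), 0, 1.0, 0)

# Writes one '<attribute>_ckpt.0.gz' file per stored attribute.
replay.save(checkpoint_dir, 0)

# Restores those files into a freshly constructed buffer.
restored = circular_replay_buffer.WrappedReplayBuffer(
    observation_shape=(84, 84), stack_size=4,
    replay_capacity=100, batch_size=8)
restored.load(checkpoint_dir, '0')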
Example #14
 def testConstructorWithStaging(self):
     replay = circular_replay_buffer.WrappedReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=STACK_SIZE,
         replay_capacity=100,
         batch_size=BATCH_SIZE,
         use_staging=True)
     # When staging is on, replay._prefetch_batch tries to prefetch transitions
     # for efficient sampling. Since no transitions have been added, this raises
     # an error.
     with self.assertRaisesOpError(
             'Cannot sample a batch with fewer than stack size'):
         self.evaluate(replay._prefetch_batch)
     with self.test_session() as sess:
         for i in range(BATCH_SIZE * 2):
             observation = np.full(OBSERVATION_SHAPE, i, dtype=OBS_DTYPE)
             replay.add(observation, 2, 1, 0)
         sess.run(replay._prefetch_batch)
          self._verify_sampled_trajectories(sess.run(replay.transition))
Example #15
    def _build_replay_buffer(self, use_staging):
        """Build WrappedReplayBuffer with custom OutOfGraphReplayBuffer."""
        replay_buffer_kwargs = dict(
            observation_shape=dqn_agent.NATURE_DQN_OBSERVATION_SHAPE,
            stack_size=dqn_agent.NATURE_DQN_STACK_SIZE,
            replay_capacity=self._replay_capacity,
            batch_size=self._batch_size,
            update_horizon=self.update_horizon,
            gamma=self.gamma,
            extra_storage_types=None,
            observation_dtype=np.uint8,
        )
        replay_memory = _OutOfGraphReplayBuffer(
            artificial_done=not self._generates_trainable_dones,
            **replay_buffer_kwargs)

        return circular_replay_buffer.WrappedReplayBuffer(
            wrapped_memory=replay_memory,
            use_staging=use_staging,
            **replay_buffer_kwargs)
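Example #15 relies on an _OutOfGraphReplayBuffer subclass defined elsewhere in that project. Purely to illustrate the wrapped_memory hook, the skeleton below sketches what such a subclass could look like on top of Dopamine's OutOfGraphReplayBuffer base class; the artificial_done handling shown here is a guess for demonstration, not the original implementation.

from dopamine.replay_memory import circular_replay_buffer


class _OutOfGraphReplayBuffer(circular_replay_buffer.OutOfGraphReplayBuffer):
    """Out-of-graph buffer with an extra artificial_done switch (sketch)."""

    def __init__(self, artificial_done, **kwargs):
        self._artificial_done = artificial_done
        super(_OutOfGraphReplayBuffer, self).__init__(**kwargs)

    def add(self, observation, action, reward, terminal, *args):
        if self._artificial_done:
            # Illustrative only: store every transition as non-terminal.
            terminal = 0
        super(_OutOfGraphReplayBuffer, self).add(
            observation, action, reward, terminal, *args)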
Example #16
def wrapped_replay_buffer(**kwargs):
    return circular_replay_buffer.WrappedReplayBuffer(**kwargs)
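The helper above simply forwards its keyword arguments to the WrappedReplayBuffer constructor, so it accepts any of the parameters seen in the other examples; the values below are illustrative.

import numpy as np

replay = wrapped_replay_buffer(
    observation_shape=(84, 84),
    stack_size=4,
    replay_capacity=1000,
    batch_size=32,
    observation_dtype=np.uint8)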
Example #17
    def __init__(self,
                 sess,
                 num_actions,
                 observation_shape=atari_lib.NATURE_DQN_OBSERVATION_SHAPE,
                 observation_dtype=atari_lib.NATURE_DQN_DTYPE,
                 stack_size=atari_lib.NATURE_DQN_STACK_SIZE,
                 network=atari_lib.nature_dqn_network,
                 gamma=0.99,
                 update_horizon=1,
                 min_replay_history=20000,
                 update_period=4,
                 target_update_period=8000,
                 exploration_mode='epsilon-greedy',
                 entropy_fn=tsallis_entropy,
                 entropic_index=0.5,
                 alpha=0.1,
                 epsilon_fn=linearly_decaying_epsilon,
                 epsilon_train=0.01,
                 epsilon_eval=0.001,
                 epsilon_decay_period=250000,
                 tf_device='/cpu:*',
                 eval_mode=False,
                 use_staging=True,
                 max_tf_checkpoints_to_keep=4,
                 optimizer=tf.train.RMSPropOptimizer(learning_rate=0.00025,
                                                     decay=0.95,
                                                     momentum=0.0,
                                                     epsilon=0.00001,
                                                     centered=True),
                 summary_writer=None,
                 summary_writing_frequency=500,
                 allow_partial_reload=False):
        """Initializes the agent and constructs the components of its graph.

    Args:
      sess: `tf.Session`, for executing ops.
      num_actions: int, number of actions the agent can take at any state.
      observation_shape: tuple of ints describing the observation shape.
      observation_dtype: tf.DType, specifies the type of the observations. Note
        that if your inputs are continuous, you should set this to tf.float32.
      stack_size: int, number of frames to use in state stack.
      network: function expecting three parameters:
        (num_actions, network_type, state). This function will return the
        network_type object containing the tensors output by the network.
        See dopamine.discrete_domains.atari_lib.nature_dqn_network as
        an example.
      gamma: float, discount factor with the usual RL meaning.
      update_horizon: int, horizon at which updates are performed, the 'n' in
        n-step update.
      min_replay_history: int, number of transitions that should be experienced
        before the agent begins training its value function.
      update_period: int, period between DQN updates.
      target_update_period: int, update period for the target network.
      exploration_mode: str, the exploration strategy to use
        (e.g. 'epsilon-greedy').
      entropy_fn: function used to compute the entropy regularization term
        (defaults to tsallis_entropy).
      entropic_index: float, entropic index parameter of the Tsallis entropy.
      alpha: float, weight of the entropy regularization term.
      epsilon_fn: function expecting 4 parameters:
        (decay_period, step, warmup_steps, epsilon). This function should return
        the epsilon value used for exploration during training.
      epsilon_train: float, the value to which the agent's epsilon is eventually
        decayed during training.
      epsilon_eval: float, epsilon used when evaluating the agent.
      epsilon_decay_period: int, length of the epsilon decay schedule.
      tf_device: str, Tensorflow device on which the agent's graph is executed.
      eval_mode: bool, True for evaluation and False for training.
      use_staging: bool, when True use a staging area to prefetch the next
        training batch, speeding training up by about 30%.
      max_tf_checkpoints_to_keep: int, the number of TensorFlow checkpoints to
        keep.
      optimizer: `tf.train.Optimizer`, for training the value function.
      summary_writer: SummaryWriter object for outputting training statistics.
        Summary writing disabled if set to None.
      summary_writing_frequency: int, frequency with which summaries will be
        written. Lower values will result in slower training.
      allow_partial_reload: bool, whether we allow reloading a partial agent
        (for instance, only the network parameters).
    """
        assert isinstance(observation_shape, tuple)
        tf.logging.info('Creating %s agent with the following parameters:',
                        self.__class__.__name__)
        tf.logging.info('\t gamma: %f', gamma)
        tf.logging.info('\t update_horizon: %f', update_horizon)
        tf.logging.info('\t min_replay_history: %d', min_replay_history)
        tf.logging.info('\t update_period: %d', update_period)
        tf.logging.info('\t target_update_period: %d', target_update_period)
        tf.logging.info('\t epsilon_train: %f', epsilon_train)
        tf.logging.info('\t epsilon_decay_period: %d', epsilon_decay_period)
        tf.logging.info('\t tf_device: %s', tf_device)
        tf.logging.info('\t use_staging: %s', use_staging)
        tf.logging.info('\t optimizer: %s', optimizer)
        tf.logging.info('\t max_tf_checkpoints_to_keep: %d',
                        max_tf_checkpoints_to_keep)

        self.num_actions = num_actions
        self.observation_shape = tuple(observation_shape)
        self.observation_dtype = observation_dtype
        self.stack_size = stack_size
        self.network = network
        self.gamma = gamma
        self.update_horizon = update_horizon
        self.cumulative_gamma = math.pow(gamma, update_horizon)
        self.min_replay_history = min_replay_history
        self.target_update_period = target_update_period
        self.exploration_mode = exploration_mode
        self.entropy_fn = entropy_fn
        self.entropic_index = entropic_index
        self.alpha = alpha
        self.epsilon_fn = epsilon_fn
        self.epsilon_train = epsilon_train
        self.epsilon_decay_period = epsilon_decay_period
        self.update_period = update_period
        self.eval_mode = eval_mode
        self.training_steps = 0
        self.optimizer = optimizer
        self.summary_writer = summary_writer
        self.summary_writing_frequency = summary_writing_frequency
        self.allow_partial_reload = allow_partial_reload

        with tf.device(tf_device):
            # Create a placeholder for the state input to the DQN network.
            # The last axis indicates the number of consecutive frames stacked.
            state_shape = (1, ) + self.observation_shape + (stack_size, )
            self.state = np.zeros(state_shape)
            self.state_ph = tf.placeholder(self.observation_dtype,
                                           state_shape,
                                           name='state_ph')
            self._replay = circular_replay_buffer.WrappedReplayBuffer(
                observation_shape=self.observation_shape,
                stack_size=self.stack_size,
                use_staging=use_staging,
                update_horizon=self.update_horizon,
                gamma=self.gamma,
                observation_dtype=self.observation_dtype.as_numpy_dtype)

            self._build_networks()

            self._train_op = self._build_train_op()
            self._sync_qt_ops = self._build_sync_op()

        if self.summary_writer is not None:
            # All tf.summaries should have been defined prior to running this.
            self._merged_summaries = tf.summary.merge_all()
        self._sess = sess
        self._saver = tf.train.Saver(max_to_keep=max_tf_checkpoints_to_keep)

        # Variables to be initialized by the agent once it interacts with the
        # environment.
        self._observation = None
        self._last_observation = None
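The replay buffer built in __init__ above is consumed by _build_networks and _build_train_op, which this example does not show. As a rough sketch only: assuming the Dopamine convention that the wrapped buffer exposes the sampled batch as tensors (self._replay.actions, self._replay.rewards, self._replay.terminals) and that _build_networks has defined self._replay_net_outputs and self._replay_next_target_net_outputs, an n-step Q-learning update could be assembled along these lines. Names and structure are assumptions patterned on Dopamine's DQN agent, not this agent's actual training op.

    def _build_train_op(self):
        """Sketch of an n-step Q-learning update on the sampled replay batch."""
        # One-hot of the actions taken in the replayed transitions.
        replay_action_one_hot = tf.one_hot(
            self._replay.actions, self.num_actions, 1., 0., name='action_one_hot')
        # Q(s, a) for the taken actions.
        replay_chosen_q = tf.reduce_sum(
            self._replay_net_outputs.q_values * replay_action_one_hot,
            axis=1, name='replay_chosen_q')
        # Bootstrapped n-step target; cumulative_gamma = gamma ** update_horizon.
        target = tf.stop_gradient(
            self._replay.rewards + self.cumulative_gamma *
            tf.reduce_max(self._replay_next_target_net_outputs.q_values, axis=1) *
            (1. - tf.cast(self._replay.terminals, tf.float32)))
        loss = tf.losses.huber_loss(
            target, replay_chosen_q, reduction=tf.losses.Reduction.NONE)
        return self.optimizer.minimize(tf.reduce_mean(loss))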