def testConstructorWithOutOfBoundsDiscountFactor(self):
  exception_string = r'Discount factor \(gamma\) must be in \[0, 1\]\.'
  with self.assertRaisesRegexp(ValueError, exception_string):
    circular_replay_buffer.WrappedReplayBuffer(
        observation_shape=OBSERVATION_SHAPE, stack_size=STACK_SIZE, gamma=-1)
  with self.assertRaisesRegexp(ValueError, exception_string):
    circular_replay_buffer.WrappedReplayBuffer(
        observation_shape=OBSERVATION_SHAPE, stack_size=STACK_SIZE, gamma=1.1)
def testCustomObsDataType(self):
  # Tests that observation store is initialized with the correct data type
  # when an observation_dtype argument is passed to the constructor.
  replay = circular_replay_buffer.WrappedReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=10,
      observation_dtype=np.int32)
  self.assertEqual(replay.memory._store['observation'].dtype, np.int32)
def testDefaultObsDataType(self):
  # Tests that the default data type for observations is np.uint8 for
  # integration with Atari 2600.
  replay = circular_replay_buffer.WrappedReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=10)
  self.assertEqual(replay.memory._store['observation'].dtype, np.uint8)
def testWrapperLoad(self):
  replay = circular_replay_buffer.WrappedReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=5,
      batch_size=BATCH_SIZE)
  self.assertNotEqual(replay.memory._store['observation'],
                      self._test_observation)
  self.assertNotEqual(replay.memory._store['action'], self._test_action)
  self.assertNotEqual(replay.memory._store['reward'], self._test_reward)
  self.assertNotEqual(replay.memory._store['terminal'], self._test_terminal)
  self.assertNotEqual(replay.memory.add_count, self._test_add_count)
  self.assertNotEqual(replay.memory.invalid_range, self._test_invalid_range)
  store_prefix = '$store$_'
  numpy_arrays = {
      store_prefix + 'observation': self._test_observation,
      store_prefix + 'action': self._test_action,
      store_prefix + 'reward': self._test_reward,
      store_prefix + 'terminal': self._test_terminal,
      'add_count': self._test_add_count,
      'invalid_range': self._test_invalid_range
  }
  for attr in numpy_arrays:
    filename = os.path.join(self._test_subdir, '{}_ckpt.3.gz'.format(attr))
    with tf.gfile.Open(filename, 'w') as f:
      with gzip.GzipFile(fileobj=f) as outfile:
        np.save(outfile, numpy_arrays[attr], allow_pickle=False)
  replay.load(self._test_subdir, '3')
  self.assertAllClose(replay.memory._store['observation'],
                      self._test_observation)
  self.assertAllClose(replay.memory._store['action'], self._test_action)
  self.assertAllClose(replay.memory._store['reward'], self._test_reward)
  self.assertAllClose(replay.memory._store['terminal'], self._test_terminal)
  self.assertEqual(replay.memory.add_count, self._test_add_count)
  self.assertAllClose(replay.memory.invalid_range, self._test_invalid_range)
def testConstructorWithZeroUpdateHorizon(self):
  with self.assertRaisesRegexp(ValueError,
                               r'Update horizon must be positive\.'):
    circular_replay_buffer.WrappedReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        update_horizon=0)
def _build_replay_buffer(self, use_staging, use_cyclic_buffer):
  """Creates the replay buffer used by the agent.

  Args:
    use_staging: bool, if True, uses a staging area to prefetch data for
      faster training.
    use_cyclic_buffer: bool, whether the underlying replay buffer behaves as
      a cyclic (wrap-around) buffer.

  Returns:
    A WrappedReplayBuffer object.
  """
  return circular_replay_buffer.WrappedReplayBuffer(
      observation_shape=self.observation_shape,
      stack_size=self.stack_size,
      num_actions=self.num_actions,
      use_cyclic_buffer=use_cyclic_buffer,
      use_staging=use_staging,
      update_horizon=self.update_horizon,
      gamma=self.gamma,
      observation_dtype=self.observation_dtype.as_numpy_dtype,
      extra_storage_types=[
          circular_replay_buffer.ReplayElement('features',
                                               self._features_shape,
                                               np.float32),
          circular_replay_buffer.ReplayElement(
              'state', self.buffer_state_shape,
              self.observation_dtype.as_numpy_dtype),
          circular_replay_buffer.ReplayElement(
              'next_state', self.buffer_state_shape,
              self.observation_dtype.as_numpy_dtype),
          circular_replay_buffer.ReplayElement('next_action', (), np.int32),
          circular_replay_buffer.ReplayElement('next_reward', (), np.float32)
      ])
def testConstructorWithExtraStorageTypes(self):
  circular_replay_buffer.WrappedReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      extra_storage_types=[
          circular_replay_buffer.ReplayElement('extra1', [], np.float32),
          circular_replay_buffer.ReplayElement('extra2', [2], np.int8)
      ])
def testConstructorCapacityNotLargeEnough(self):
  with self.assertRaisesRegexp(
      ValueError, r'Update horizon \(5\) should be significantly '
      r'smaller than replay capacity \(5\)\.'):
    circular_replay_buffer.WrappedReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=5,
        update_horizon=5)
def testConstructorWithNoStaging(self):
  replay = circular_replay_buffer.WrappedReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=100,
      batch_size=BATCH_SIZE,
      use_staging=False)
  with self.test_session() as sess:
    for i in range(BATCH_SIZE * 2):
      observation = np.full(OBSERVATION_SHAPE, i, dtype=OBS_DTYPE)
      replay.add(observation, 2, 1, 0)
    self._verify_sampled_trajectories(sess.run(replay.transition))
def testReplayBufferIsLocked(self):
  """Tests that the replay buffer's lock is properly acquired and released."""
  replay = circular_replay_buffer.WrappedReplayBuffer(
      observation_shape=(2,),
      stack_size=1,
      replay_capacity=10,
      batch_size=2)
  lock = mock.MagicMock()
  replay.memory._lock = lock
  # Add one element.
  replay.add((1, 2), 0, 0, False)
  # Check that the lock went through the proper lock/unlock process.
  lock.__enter__.assert_called_once()
  lock.__exit__.assert_called_once()
def _build_replay_buffer(self, use_staging):
  """Creates the replay buffer used by the agent.

  Args:
    use_staging: bool, if True, uses a staging area to prefetch data for
      faster training.

  Returns:
    A WrappedReplayBuffer object.
  """
  return circular_replay_buffer.WrappedReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      use_staging=use_staging,
      update_horizon=self.update_horizon,
      gamma=self.gamma)
def _build_replay_buffer(self, use_staging):
  """Creates the replay buffer used by the agent.

  Args:
    use_staging: bool, if True, uses a staging area to prefetch data for
      faster training.

  Returns:
    A WrappedReplayBuffer object.
  """
  return circular_replay_buffer.WrappedReplayBuffer(
      observation_shape=self.observation_shape,
      stack_size=self.stack_size,
      use_staging=use_staging,
      update_horizon=self.update_horizon,
      gamma=self.gamma,
      observation_dtype=self.observation_dtype.as_numpy_dtype)
def testWrapperSave(self):
  replay = circular_replay_buffer.WrappedReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=5,
      batch_size=BATCH_SIZE)
  replay.memory.observation = self._test_observation
  replay.memory.action = self._test_action
  replay.memory.reward = self._test_reward
  replay.memory.terminal = self._test_terminal
  replay.memory.add_count = self._test_add_count
  replay.memory.invalid_range = self._test_invalid_range
  replay.save(self._test_subdir, 3)
  for attr in replay.memory.__dict__:
    if attr.startswith('_'):
      continue
    filename = os.path.join(self._test_subdir, '{}_ckpt.3.gz'.format(attr))
    self.assertTrue(tf.gfile.Exists(filename))
def testConstructorWithStaging(self):
  replay = circular_replay_buffer.WrappedReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=100,
      batch_size=BATCH_SIZE,
      use_staging=True)
  # When staging is on, replay._prefetch_batch tries to prefetch transitions
  # for efficient sampling. Since no transitions have been added, this raises
  # an error.
  with self.assertRaisesOpError(
      'Cannot sample a batch with fewer than stack size'):
    self.evaluate(replay._prefetch_batch)
  with self.test_session() as sess:
    for i in range(BATCH_SIZE * 2):
      observation = np.full(OBSERVATION_SHAPE, i, dtype=OBS_DTYPE)
      replay.add(observation, 2, 1, 0)
    sess.run(replay._prefetch_batch)
    self._verify_sampled_trajectories(sess.run(replay.transition))
def _build_replay_buffer(self, use_staging):
  """Build WrappedReplayBuffer with custom OutOfGraphReplayBuffer."""
  replay_buffer_kwargs = dict(
      observation_shape=dqn_agent.NATURE_DQN_OBSERVATION_SHAPE,
      stack_size=dqn_agent.NATURE_DQN_STACK_SIZE,
      replay_capacity=self._replay_capacity,
      batch_size=self._batch_size,
      update_horizon=self.update_horizon,
      gamma=self.gamma,
      extra_storage_types=None,
      observation_dtype=np.uint8,
  )
  replay_memory = _OutOfGraphReplayBuffer(
      artificial_done=not self._generates_trainable_dones,
      **replay_buffer_kwargs)
  return circular_replay_buffer.WrappedReplayBuffer(
      wrapped_memory=replay_memory,
      use_staging=use_staging,
      **replay_buffer_kwargs)
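# `_OutOfGraphReplayBuffer` is defined elsewhere in this project. As a
# hedged, illustrative sketch only (not the project's implementation), such a
# subclass could extend circular_replay_buffer.OutOfGraphReplayBuffer and
# simply record the extra `artificial_done` flag passed above:
class _ExampleOutOfGraphBuffer(circular_replay_buffer.OutOfGraphReplayBuffer):
  """Illustrative subclass that remembers an `artificial_done` flag."""

  def __init__(self, artificial_done, **kwargs):
    # Store whether terminal markers should be synthesized at rollout
    # boundaries; all other constructor arguments are forwarded unchanged.
    self._artificial_done = artificial_done
    super(_ExampleOutOfGraphBuffer, self).__init__(**kwargs)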
def wrapped_replay_buffer(**kwargs):
  return circular_replay_buffer.WrappedReplayBuffer(**kwargs)
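# Hedged usage sketch for the factory above. The shapes and capacities here
# are illustrative placeholders, not values taken from this project's
# configuration.
def _example_wrapped_replay_buffer():
  # Small Atari-style buffer: 84x84 frames stacked 4 deep.
  return wrapped_replay_buffer(
      observation_shape=(84, 84),
      stack_size=4,
      replay_capacity=1000,
      batch_size=32)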
def __init__(self,
             sess,
             num_actions,
             observation_shape=atari_lib.NATURE_DQN_OBSERVATION_SHAPE,
             observation_dtype=atari_lib.NATURE_DQN_DTYPE,
             stack_size=atari_lib.NATURE_DQN_STACK_SIZE,
             network=atari_lib.nature_dqn_network,
             gamma=0.99,
             update_horizon=1,
             min_replay_history=20000,
             update_period=4,
             target_update_period=8000,
             exploration_mode='epsilon-greedy',
             entropy_fn=tsallis_entropy,
             entropic_index=0.5,
             alpha=0.1,
             epsilon_fn=linearly_decaying_epsilon,
             epsilon_train=0.01,
             epsilon_eval=0.001,
             epsilon_decay_period=250000,
             tf_device='/cpu:*',
             eval_mode=False,
             use_staging=True,
             max_tf_checkpoints_to_keep=4,
             optimizer=tf.train.RMSPropOptimizer(
                 learning_rate=0.00025,
                 decay=0.95,
                 momentum=0.0,
                 epsilon=0.00001,
                 centered=True),
             summary_writer=None,
             summary_writing_frequency=500,
             allow_partial_reload=False):
  """Initializes the agent and constructs the components of its graph.

  Args:
    sess: `tf.Session`, for executing ops.
    num_actions: int, number of actions the agent can take at any state.
    observation_shape: tuple of ints describing the observation shape.
    observation_dtype: tf.DType, specifies the type of the observations. Note
      that if your inputs are continuous, you should set this to tf.float32.
    stack_size: int, number of frames to use in state stack.
    network: function expecting three parameters: (num_actions, network_type,
      state). This function will return the network_type object containing
      the tensors output by the network. See
      dopamine.discrete_domains.atari_lib.nature_dqn_network as an example.
    gamma: float, discount factor with the usual RL meaning.
    update_horizon: int, horizon at which updates are performed, the 'n' in
      n-step update.
    min_replay_history: int, number of transitions that should be experienced
      before the agent begins training its value function.
    update_period: int, period between DQN updates.
    target_update_period: int, update period for the target network.
    exploration_mode: str, which exploration strategy to use.
    entropy_fn: function used to compute the entropy regularization term.
    entropic_index: float, entropic index passed to the entropy function.
    alpha: float, weight of the entropy regularization term.
    epsilon_fn: function expecting 4 parameters: (decay_period, step,
      warmup_steps, epsilon). This function should return the epsilon value
      used for exploration during training.
    epsilon_train: float, the value to which the agent's epsilon is
      eventually decayed during training.
    epsilon_eval: float, the value of epsilon used during evaluation.
    epsilon_decay_period: int, length of the epsilon decay schedule.
    tf_device: str, Tensorflow device on which the agent's graph is executed.
    eval_mode: bool, True for evaluation and False for training.
    use_staging: bool, when True use a staging area to prefetch the next
      training batch, speeding training up by about 30%.
    max_tf_checkpoints_to_keep: int, the number of TensorFlow checkpoints to
      keep.
    optimizer: `tf.train.Optimizer`, for training the value function.
    summary_writer: SummaryWriter object for outputting training statistics.
      Summary writing disabled if set to None.
    summary_writing_frequency: int, frequency with which summaries will be
      written. Lower values will result in slower training.
    allow_partial_reload: bool, whether we allow reloading a partial agent
      (for instance, only the network parameters).
  """
  assert isinstance(observation_shape, tuple)
  tf.logging.info('Creating %s agent with the following parameters:',
                  self.__class__.__name__)
  tf.logging.info('\t gamma: %f', gamma)
  tf.logging.info('\t update_horizon: %f', update_horizon)
  tf.logging.info('\t min_replay_history: %d', min_replay_history)
  tf.logging.info('\t update_period: %d', update_period)
  tf.logging.info('\t target_update_period: %d', target_update_period)
  tf.logging.info('\t epsilon_train: %f', epsilon_train)
  tf.logging.info('\t epsilon_decay_period: %d', epsilon_decay_period)
  tf.logging.info('\t tf_device: %s', tf_device)
  tf.logging.info('\t use_staging: %s', use_staging)
  tf.logging.info('\t optimizer: %s', optimizer)
  tf.logging.info('\t max_tf_checkpoints_to_keep: %d',
                  max_tf_checkpoints_to_keep)

  self.num_actions = num_actions
  self.observation_shape = tuple(observation_shape)
  self.observation_dtype = observation_dtype
  self.stack_size = stack_size
  self.network = network
  self.gamma = gamma
  self.update_horizon = update_horizon
  self.cumulative_gamma = math.pow(gamma, update_horizon)
  self.min_replay_history = min_replay_history
  self.target_update_period = target_update_period
  self.exploration_mode = exploration_mode
  self.entropy_fn = entropy_fn
  self.entropic_index = entropic_index
  self.alpha = alpha
  self.epsilon_fn = epsilon_fn
  self.epsilon_train = epsilon_train
  self.epsilon_decay_period = epsilon_decay_period
  self.update_period = update_period
  self.eval_mode = eval_mode
  self.training_steps = 0
  self.optimizer = optimizer
  self.summary_writer = summary_writer
  self.summary_writing_frequency = summary_writing_frequency
  self.allow_partial_reload = allow_partial_reload

  with tf.device(tf_device):
    # Create a placeholder for the state input to the DQN network.
    # The last axis indicates the number of consecutive frames stacked.
    state_shape = (1,) + self.observation_shape + (stack_size,)
    self.state = np.zeros(state_shape)
    self.state_ph = tf.placeholder(self.observation_dtype, state_shape,
                                   name='state_ph')
    self._replay = circular_replay_buffer.WrappedReplayBuffer(
        observation_shape=self.observation_shape,
        stack_size=self.stack_size,
        use_staging=use_staging,
        update_horizon=self.update_horizon,
        gamma=self.gamma,
        observation_dtype=self.observation_dtype.as_numpy_dtype)

    self._build_networks()
    self._train_op = self._build_train_op()
    self._sync_qt_ops = self._build_sync_op()

  if self.summary_writer is not None:
    # All tf.summaries should have been defined prior to running this.
    self._merged_summaries = tf.summary.merge_all()
  self._sess = sess
  self._saver = tf.train.Saver(max_to_keep=max_tf_checkpoints_to_keep)

  # Variables to be initialized by the agent once it interacts with the
  # environment.
  self._observation = None
  self._last_observation = None