def _build_replay_buffer(self, use_staging, use_cyclic_buffer):
  """Creates the replay buffer used by the agent.

  Args:
    use_staging: bool, if True, uses a staging area to prefetch data for
      faster training.
    use_cyclic_buffer: bool, forwarded to the underlying buffer; controls
      whether storage wraps around once full.

  Returns:
    A WrappedReplayBuffer object.
  """
  return circular_replay_buffer.WrappedReplayBuffer(
      observation_shape=self.observation_shape,
      stack_size=self.stack_size,
      num_actions=self.num_actions,
      use_cyclic_buffer=use_cyclic_buffer,
      use_staging=use_staging,
      update_horizon=self.update_horizon,
      gamma=self.gamma,
      observation_dtype=self.observation_dtype.as_numpy_dtype,
      extra_storage_types=[
          # Per-transition extras stored alongside the default elements.
          circular_replay_buffer.ReplayElement('features',
                                               self._features_shape,
                                               np.float32),
          circular_replay_buffer.ReplayElement(
              'state', self.buffer_state_shape,
              self.observation_dtype.as_numpy_dtype),
          circular_replay_buffer.ReplayElement(
              'next_state', self.buffer_state_shape,
              self.observation_dtype.as_numpy_dtype),
          circular_replay_buffer.ReplayElement('next_action', (), np.int32),
          circular_replay_buffer.ReplayElement('next_reward', (), np.float32)
      ])
def testSampleTransitionBatchExtra(self):
  """Samples batches from a buffer with extra storage types and checks them."""
  capacity = 10
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=1,
      replay_capacity=capacity,
      batch_size=2,
      extra_storage_types=[
          circular_replay_buffer.ReplayElement('extra1', [], np.float32),
          circular_replay_buffer.ReplayElement('extra2', [2], np.int8)
      ])
  total_adds = 50  # The number of transitions to add to the memory.
  for step in range(total_adds):
    # Every 4 transitions is terminal.
    memory.add(
        np.full((OBSERVATION_SHAPE, OBSERVATION_SHAPE), step, dtype=OBS_DTYPE),
        0, 0, step % 4, 0, [0, 0])
  # Test sampling with default batch size.
  for _ in range(1000):
    batch = memory.sample_transition_batch()
    self.assertEqual(batch[0].shape[0], 2)
  # Test changing batch sizes.
  for _ in range(1000):
    batch = memory.sample_transition_batch(BATCH_SIZE)
    self.assertEqual(batch[0].shape[0], BATCH_SIZE)
  # Verify we revert to default batch size.
  for _ in range(1000):
    batch = memory.sample_transition_batch()
    self.assertEqual(batch[0].shape[0], 2)
  # Verify we can specify what indices to sample.
  indices = [1, 2, 3, 5, 8]
  expected_states = np.array([
      np.full((OBSERVATION_SHAPE, OBSERVATION_SHAPE, 1), j, dtype=OBS_DTYPE)
      for j in indices
  ])
  expected_next_states = (expected_states + 1) % capacity
  # Because the replay buffer is circular, we can exactly compute what the
  # states will be at the specified indices by doing a little mod math:
  wrap_offset = total_adds - capacity
  expected_states += wrap_offset
  expected_next_states += wrap_offset
  # This replicates the formula used above to decide which transitions are
  # terminal when adding observation (step % 4).
  expected_terminal = np.array(
      [min((j + wrap_offset) % 4, 1) for j in indices])
  expected_extra2 = np.zeros([len(indices), 2])
  batch = memory.sample_transition_batch(
      batch_size=len(indices), indices=indices)
  (states, action, reward, next_states, terminal, indices_batch, extra1,
   extra2) = batch
  self.assertAllEqual(states, expected_states)
  self.assertAllEqual(action, np.zeros(len(indices)))
  self.assertAllEqual(reward, np.zeros(len(indices)))
  self.assertAllEqual(next_states, expected_next_states)
  self.assertAllEqual(terminal, expected_terminal)
  self.assertAllEqual(indices_batch, indices)
  self.assertAllEqual(extra1, np.zeros(len(indices)))
  self.assertAllEqual(extra2, expected_extra2)
def testConstructorWithExtraStorageTypes(self):
  """Constructing a wrapped buffer with extra storage types must not raise."""
  extras = [
      circular_replay_buffer.ReplayElement('extra1', [], np.float32),
      circular_replay_buffer.ReplayElement('extra2', [2], np.int8),
  ]
  circular_replay_buffer.WrappedReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      extra_storage_types=extras)
def get_storage_signature(self):
  """Returns the list of ReplayElements stored for each transition.

  The default elements (observation, action, reward, terminal) come first,
  followed by any extra storage types in their declared order.
  """
  base_elements = [
      circular_replay_buffer.ReplayElement('observation',
                                           self._observation_shape,
                                           self._observation_dtype),
      circular_replay_buffer.ReplayElement('action', (), np.int32),
      circular_replay_buffer.ReplayElement('reward', (), np.float64),
      circular_replay_buffer.ReplayElement('terminal', (), np.uint8),
  ]
  return base_elements + list(self._extra_storage_types)
def testCheckAddTypes(self):
  """_check_add_types accepts a full signature and rejects missing extras."""
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=5,
      batch_size=BATCH_SIZE,
      extra_storage_types=[
          circular_replay_buffer.ReplayElement('extra1', [], np.float32),
          circular_replay_buffer.ReplayElement('extra2', [2], np.int8)
      ])
  observation = np.zeros(OBSERVATION_SHAPE)
  # Full signature: observation, action, reward, terminal, extra1, extra2.
  memory._check_add_types(observation, 0, 0, 0, 0, [0, 0])
  # Omitting the extra elements must be rejected.
  with self.assertRaisesRegexp(ValueError, 'Add expects'):
    memory._check_add_types(observation, 0, 0, 0)
def __init__(self, data_dir, replay_suffix, *args, **kwargs):  # pylint: disable=keyword-arg-before-vararg
  """Initialize the FixedReplayBuffer class.

  Args:
    data_dir: str, log Directory from which to load the replay buffer.
    replay_suffix: int, If not None, then only load the replay buffer
      corresponding to the specific suffix in data directory.
    *args: Arbitrary extra arguments.
    **kwargs: Arbitrary keyword arguments.
  """
  self._args = args
  # NOTE(review): self._kwargs aliases `kwargs`; the pops below also mutate it.
  self._kwargs = kwargs
  # NOTE(review): `use_off_policy_replay_buffer` is neither a parameter nor
  # defined in this method; presumably a module-level (possibly gin-bound)
  # flag defined elsewhere in the file — confirm, otherwise this is a
  # NameError at construction time.
  if use_off_policy_replay_buffer:
    # NOTE(review): assumes 'extra_storage_types' is always present in kwargs
    # (KeyError otherwise) — confirm all call sites pass it, even as None.
    if self._kwargs['extra_storage_types'] is None:
      self._kwargs['extra_storage_types'] = []
    # Off-policy mode stores a per-transition sampling probability.
    self._kwargs['extra_storage_types'].append(
        circular_replay_buffer.ReplayElement('prob', [], np.float32))
  self._data_dir = data_dir
  self._loaded_buffers = False
  self.add_count = np.array(0)
  self._replay_suffix = replay_suffix
  # These pops have no defaults, so all four are effectively required
  # keyword arguments (KeyError if missing); popping also removes them from
  # self._kwargs before it is forwarded to the underlying buffer.
  self._maxbuffernum = kwargs.pop('maxbuffernum')
  self._stratified_sample = kwargs.pop('stratified_sample')
  self._inorder = kwargs.pop('inorder')
  self._prefer_early = kwargs.pop('prefer_early')
  self._loadcount = 0
  # Retry until at least one buffer has been loaded successfully.
  while not self._loaded_buffers:
    if replay_suffix:
      # NOTE(review): `if replay_suffix:` is False for suffix 0, so a 0
      # suffix falls through to random loading even though the assert below
      # explicitly permits non-negative values — confirm whether
      # `replay_suffix is not None` was intended.
      assert replay_suffix >= 0, 'Please pass a non-negative replay suffix'
      self.load_single_buffer(replay_suffix)
    else:
      self._load_replay_buffers(num_buffers=1)
  self._loadcount = 0
def get_transition_elements(self, batch_size=None):
  """Returns the type signature of sample_transition_batch.

  Args:
    batch_size: int or None; when None the buffer's default batch size is
      used.

  Returns:
    list of ReplayElements, one per returned batch component.
  """
  if batch_size is None:
    batch_size = self._batch_size
  elements = [
      circular_replay_buffer.ReplayElement(
          'state', (batch_size,) + self._state_shape,
          self._observation_dtype),
      circular_replay_buffer.ReplayElement('action', (batch_size,),
                                           np.int32),
      circular_replay_buffer.ReplayElement('reward', (batch_size,),
                                           np.float64),
      circular_replay_buffer.ReplayElement(
          'next_state', (batch_size,) + self._state_shape,
          self._observation_dtype),
      circular_replay_buffer.ReplayElement('terminal', (batch_size,),
                                           np.uint8),
      circular_replay_buffer.ReplayElement('indices', (batch_size,),
                                           np.int32),
  ]
  # Extra storage types are batched the same way as the default elements.
  elements.extend(
      circular_replay_buffer.ReplayElement(
          extra.name, (batch_size,) + tuple(extra.shape), extra.type)
      for extra in self._extra_storage_types)
  return elements
def testExtraAdd(self):
  """add() with extras advances the cursor; missing extras raise."""
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=5,
      batch_size=BATCH_SIZE,
      extra_storage_types=[
          circular_replay_buffer.ReplayElement('extra1', [], np.float32),
          circular_replay_buffer.ReplayElement('extra2', [2], np.int8)
      ])
  self.assertEqual(memory.cursor(), 0)
  observation = np.zeros(OBSERVATION_SHAPE)
  memory.add(observation, 0, 0, 0, 0, [0, 0])
  with self.assertRaisesRegexp(ValueError, 'Add expects'):
    memory.add(observation, 0, 0, 0)
  # Check if the cursor moved STACK_SIZE -1 zeros adds + 1, (the one above).
  self.assertEqual(memory.cursor(), STACK_SIZE)
def testAddWithAdditionalArgsAndPriority(self):
  """Prioritized add() must receive a value for every extra element."""
  extras = [
      circular_replay_buffer.ReplayElement('test_item', (), np.float32)
  ]
  memory = self.create_default_memory(extra_storage_types=extras)
  self.assertEqual(memory.cursor(), 0)
  observation = np.zeros(SCREEN_SIZE)
  memory.add(observation, 0, 0.0, 0, 0.0, priority=1.0)
  self.assertEqual(memory.cursor(), STACK_SIZE)
  self.assertEqual(memory.add_count, STACK_SIZE)
  # Check that the prioritized replay buffer expects an additional argument
  # for test_item.
  with self.assertRaisesRegexp(ValueError, 'Add expects'):
    memory.add(observation, 0, 0, 0, priority=1.0)
def get_add_args_signature(self):
  """The signature of the add function.

  The signature is the same as the one for OutOfGraphReplayBuffer, with an
  added priority.

  Returns:
    list of ReplayElements defining the type of the argument signature needed
      by the add function.
  """
  priority_element = crb.ReplayElement('priority', (), np.float32)
  base_signature = super(SAILOutOfGraphPrioritizedReplayBuffer,
                         self).get_add_args_signature()
  return base_signature + [priority_element]
def get_transition_elements(self, batch_size=None):
  """Returns a 'type signature' for sample_transition_batch.

  Args:
    batch_size: int, number of transitions returned. If None, the default
      batch_size will be used.

  Returns:
    signature: A namedtuple describing the method's return type signature.
  """
  # Resolve the default here as well as in the parent: the parent handles
  # None internally, but the 'sampling_probabilities' element below needs a
  # concrete batch dimension (otherwise its shape would be (None,)).
  batch_size = self._batch_size if batch_size is None else batch_size
  parent_transition_type = (super(
      SAILOutOfGraphPrioritizedReplayBuffer,
      self).get_transition_elements(batch_size))
  # Per-sample priority weights appended after the parent's elements.
  probabilities_type = [
      crb.ReplayElement('sampling_probabilities', (batch_size,), np.float32)
  ]
  return parent_transition_type + probabilities_type
def get_storage_signature(self):
  """Returns a default list of elements to be stored in this replay memory.

  Note - Derived classes may return a different signature.

  Returns:
    list of ReplayElements defining the type of the contents stored.
  """
  base_elements = [
      crb.ReplayElement('observation', self._observation_shape,
                        self._observation_dtype),
      crb.ReplayElement('action', self._action_shape, self._action_dtype),
      crb.ReplayElement('reward', self._reward_shape, self._reward_dtype),
      crb.ReplayElement('terminal', (), self._terminal_dtype),
      crb.ReplayElement('return', (), np.float32),
      crb.ReplayElement('episode_num', (), np.int32),
  ]
  # Extras follow the defaults, preserving their declared order.
  return base_elements + list(self._extra_storage_types)
def get_transition_elements(self, batch_size=None):
  """Returns a 'type signature' for sample_transition_batch.

  Args:
    batch_size: int, number of transitions returned. If None, the default
      batch_size will be used.

  Returns:
    signature: A namedtuple describing the method's return type signature.
  """
  if batch_size is None:
    batch_size = self._batch_size
  # All sampled components share the same leading batch dimension.
  lead = (batch_size,)
  elements = [
      crb.ReplayElement('state', lead + self._state_shape,
                        self._observation_dtype),
      crb.ReplayElement('action', lead + self._action_shape,
                        self._action_dtype),
      crb.ReplayElement('reward', lead + self._reward_shape,
                        self._reward_dtype),
      crb.ReplayElement('next_state', lead + self._state_shape,
                        self._observation_dtype),
      crb.ReplayElement('next_action', lead + self._action_shape,
                        self._action_dtype),
      crb.ReplayElement('next_reward', lead + self._reward_shape,
                        self._reward_dtype),
      crb.ReplayElement('terminal', lead, self._terminal_dtype),
      crb.ReplayElement('indices', lead, np.int32),
      crb.ReplayElement('return', lead, np.float32),
      crb.ReplayElement('episode_num', lead, np.int32),
  ]
  elements.extend(
      crb.ReplayElement(extra.name, lead + tuple(extra.shape), extra.type)
      for extra in self._extra_storage_types)
  return elements