# Example #1 (score: 0)
    def _build_replay_buffer(self, use_staging, use_cyclic_buffer):
        """Creates the replay buffer used by the agent.

        Args:
          use_staging: bool, if True, uses a staging area to prefetch data for
            faster training.
          use_cyclic_buffer: bool, forwarded to WrappedReplayBuffer; presumably
            controls whether the buffer wraps around when full — confirm
            against the WrappedReplayBuffer implementation.

        Returns:
          A WrappedReplayBuffer object.
        """
        return circular_replay_buffer.WrappedReplayBuffer(
            observation_shape=self.observation_shape,
            stack_size=self.stack_size,
            num_actions=self.num_actions,
            use_cyclic_buffer=use_cyclic_buffer,
            use_staging=use_staging,
            update_horizon=self.update_horizon,
            gamma=self.gamma,
            observation_dtype=self.observation_dtype.as_numpy_dtype,
            # Extra per-transition fields stored alongside the defaults.
            extra_storage_types=[
                circular_replay_buffer.ReplayElement('features',
                                                     self._features_shape,
                                                     np.float32),
                circular_replay_buffer.ReplayElement(
                    'state', self.buffer_state_shape,
                    self.observation_dtype.as_numpy_dtype),
                circular_replay_buffer.ReplayElement(
                    'next_state', self.buffer_state_shape,
                    self.observation_dtype.as_numpy_dtype),
                circular_replay_buffer.ReplayElement('next_action', (),
                                                     np.int32),
                circular_replay_buffer.ReplayElement('next_reward', (),
                                                     np.float32)
            ])
    def testSampleTransitionBatchExtra(self):
        """Samples batches (default/explicit sizes) and checks extra elements."""
        replay_capacity = 10
        memory = circular_replay_buffer.OutOfGraphReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=1,
            replay_capacity=replay_capacity,
            batch_size=2,
            extra_storage_types=[
                circular_replay_buffer.ReplayElement('extra1', [], np.float32),
                circular_replay_buffer.ReplayElement('extra2', [2], np.int8)
            ])
        num_adds = 50  # The number of transitions to add to the memory.
        for i in range(num_adds):
            # Observation pixels are filled with the add index i, so the exact
            # buffer contents at any slot can be computed below.
            memory.add(
                np.full((OBSERVATION_SHAPE, OBSERVATION_SHAPE),
                        i,
                        dtype=OBS_DTYPE), 0, 0, i % 4, 0,
                [0, 0])  # Every 4 transitions is terminal.
        # Test sampling with default batch size.
        for i in range(1000):
            batch = memory.sample_transition_batch()
            self.assertEqual(batch[0].shape[0], 2)
        # Test changing batch sizes.
        for i in range(1000):
            batch = memory.sample_transition_batch(BATCH_SIZE)
            self.assertEqual(batch[0].shape[0], BATCH_SIZE)
        # Verify we revert to default batch size.
        for i in range(1000):
            batch = memory.sample_transition_batch()
            self.assertEqual(batch[0].shape[0], 2)

        # Verify we can specify what indices to sample.
        indices = [1, 2, 3, 5, 8]
        expected_states = np.array([
            np.full((OBSERVATION_SHAPE, OBSERVATION_SHAPE, 1),
                    i,
                    dtype=OBS_DTYPE) for i in indices
        ])
        expected_next_states = (expected_states + 1) % replay_capacity
        # Because the replay buffer is circular, we can exactly compute what the
        # states will be at the specified indices by doing a little mod math:
        expected_states += num_adds - replay_capacity
        expected_next_states += num_adds - replay_capacity
        # This is replicating the formula that was used above to determine what
        # transitions are terminal when adding observation (i % 4).
        expected_terminal = np.array(
            [min((x + num_adds - replay_capacity) % 4, 1) for x in indices])
        # extra2 was always added as [0, 0], so sampled values must be zeros.
        expected_extra2 = np.zeros([len(indices), 2])
        batch = memory.sample_transition_batch(batch_size=len(indices),
                                               indices=indices)
        (states, action, reward, next_states, terminal, indices_batch, extra1,
         extra2) = batch
        self.assertAllEqual(states, expected_states)
        self.assertAllEqual(action, np.zeros(len(indices)))
        self.assertAllEqual(reward, np.zeros(len(indices)))
        self.assertAllEqual(next_states, expected_next_states)
        self.assertAllEqual(terminal, expected_terminal)
        self.assertAllEqual(indices_batch, indices)
        self.assertAllEqual(extra1, np.zeros(len(indices)))
        self.assertAllEqual(extra2, expected_extra2)
 def testConstructorWithExtraStorageTypes(self):
     """Constructing a WrappedReplayBuffer with extra storage types succeeds."""
     circular_replay_buffer.WrappedReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=STACK_SIZE,
         extra_storage_types=[
             circular_replay_buffer.ReplayElement('extra1', [], np.float32),
             circular_replay_buffer.ReplayElement('extra2', [2], np.int8)
         ])
# Example #4 (score: 0)
    def get_storage_signature(self):
        """Return the list of ReplayElements stored per transition.

        The four default elements (observation, action, reward, terminal)
        come first, followed by any extra storage types configured on this
        buffer, in their original order.
        """
        default_elements = [
            circular_replay_buffer.ReplayElement('observation',
                                                 self._observation_shape,
                                                 self._observation_dtype),
            circular_replay_buffer.ReplayElement('action', (), np.int32),
            circular_replay_buffer.ReplayElement('reward', (), np.float64),
            circular_replay_buffer.ReplayElement('terminal', (), np.uint8),
        ]
        return default_elements + list(self._extra_storage_types)
    def testCheckAddTypes(self):
        """_check_add_types accepts a matching arg list and rejects a short one."""
        memory = circular_replay_buffer.OutOfGraphReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=STACK_SIZE,
            replay_capacity=5,
            batch_size=BATCH_SIZE,
            extra_storage_types=[
                circular_replay_buffer.ReplayElement('extra1', [], np.float32),
                circular_replay_buffer.ReplayElement('extra2', [2], np.int8)
            ])
        zeros = np.zeros(OBSERVATION_SHAPE)

        # Matches the full signature: observation, action, reward, terminal,
        # extra1 (scalar) and extra2 ([2]-shaped).
        memory._check_add_types(zeros, 0, 0, 0, 0, [0, 0])

        # Omitting the two extras must be rejected.
        with self.assertRaisesRegexp(ValueError, 'Add expects'):
            memory._check_add_types(zeros, 0, 0, 0)
# Example #6 (score: 0)
    def __init__(self, data_dir, replay_suffix, *args, **kwargs):  # pylint: disable=keyword-arg-before-vararg
        """Initialize the FixedReplayBuffer class.

    Args:
      data_dir: str, log Directory from which to load the replay buffer.
      replay_suffix: int, If not None, then only load the replay buffer
        corresponding to the specific suffix in data directory.
      *args: Arbitrary extra arguments.
      **kwargs: Arbitrary keyword arguments. Must include 'maxbuffernum',
        'stratified_sample', 'inorder' and 'prefer_early' (popped below
        without defaults, so omitting any raises KeyError).
    """
        self._args = args
        self._kwargs = kwargs
        # NOTE(review): `use_off_policy_replay_buffer` is not defined in this
        # scope -- presumably a module-level or gin-configured flag; confirm.
        # Also assumes kwargs contains 'extra_storage_types' when the flag is
        # set, otherwise the lookup below raises KeyError.
        if use_off_policy_replay_buffer:
            if self._kwargs['extra_storage_types'] is None:
                self._kwargs['extra_storage_types'] = []
            # Store the behavior policy's action probability per transition.
            self._kwargs['extra_storage_types'].append(
                circular_replay_buffer.ReplayElement('prob', [], np.float32))
        self._data_dir = data_dir
        self._loaded_buffers = False
        self.add_count = np.array(0)
        self._replay_suffix = replay_suffix
        # Popped so these options are not forwarded to the wrapped replay
        # buffer constructor through self._kwargs.
        self._maxbuffernum = kwargs.pop('maxbuffernum')
        self._stratified_sample = kwargs.pop('stratified_sample')
        self._inorder = kwargs.pop('inorder')
        self._prefer_early = kwargs.pop('prefer_early')
        self._loadcount = 0
        # Retry until a buffer loads; assumes the load helpers eventually set
        # self._loaded_buffers, otherwise this loops forever -- TODO confirm.
        while not self._loaded_buffers:
            if replay_suffix:
                assert replay_suffix >= 0, 'Please pass a non-negative replay suffix'
                self.load_single_buffer(replay_suffix)
            else:
                self._load_replay_buffers(num_buffers=1)
        self._loadcount = 0
# Example #7 (score: 0)
    def get_transition_elements(self, batch_size=None):
        """Return the type signature of sample_transition_batch.

        Args:
          batch_size: int or None; when None, the buffer's default batch
            size is used.

        Returns:
          list of ReplayElements, each with a leading batch dimension.
        """
        if batch_size is None:
            batch_size = self._batch_size

        def batched(name, shape, dtype):
            # Prefix the per-transition shape with the batch dimension.
            return circular_replay_buffer.ReplayElement(
                name, (batch_size,) + tuple(shape), dtype)

        signature = [
            batched('state', self._state_shape, self._observation_dtype),
            batched('action', (), np.int32),
            batched('reward', (), np.float64),
            batched('next_state', self._state_shape,
                    self._observation_dtype),
            batched('terminal', (), np.uint8),
            batched('indices', (), np.int32),
        ]
        signature.extend(
            batched(extra.name, extra.shape, extra.type)
            for extra in self._extra_storage_types)
        return signature
    def testExtraAdd(self):
        """add() accepts the extra elements and rejects a short arg list."""
        memory = circular_replay_buffer.OutOfGraphReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=STACK_SIZE,
            replay_capacity=5,
            batch_size=BATCH_SIZE,
            extra_storage_types=[
                circular_replay_buffer.ReplayElement('extra1', [], np.float32),
                circular_replay_buffer.ReplayElement('extra2', [2], np.int8)
            ])
        self.assertEqual(memory.cursor(), 0)
        zeros = np.zeros(OBSERVATION_SHAPE)
        memory.add(zeros, 0, 0, 0, 0, [0, 0])

        # Omitting the two extra arguments must be rejected.
        with self.assertRaisesRegexp(ValueError, 'Add expects'):
            memory.add(zeros, 0, 0, 0)
        # Check if the cursor moved STACK_SIZE -1 zeros adds + 1, (the one above).
        self.assertEqual(memory.cursor(), STACK_SIZE)
    def testAddWithAdditionalArgsAndPriority(self):
        """Prioritized add() accepts an extra arg plus priority, checks arity."""
        memory = self.create_default_memory(extra_storage_types=[
            circular_replay_buffer.ReplayElement('test_item', (), np.float32)
        ])
        self.assertEqual(memory.cursor(), 0)
        zeros = np.zeros(SCREEN_SIZE)

        memory.add(zeros, 0, 0.0, 0, 0.0, priority=1.0)
        self.assertEqual(memory.cursor(), STACK_SIZE)
        self.assertEqual(memory.add_count, STACK_SIZE)

        # Check that the prioritized replay buffer expects an additional argument
        # for test_item.
        with self.assertRaisesRegexp(ValueError, 'Add expects'):
            memory.add(zeros, 0, 0, 0, priority=1.0)
# Example #10 (score: 0)
    def get_add_args_signature(self):
        """The signature of the add function.

    The signature is the same as the one for OutOfGraphReplayBuffer, with an
    added priority.

    Returns:
      list of ReplayElements defining the type of the argument signature needed
        by the add function.
    """
        priority_element = crb.ReplayElement('priority', (), np.float32)
        base_signature = super(SAILOutOfGraphPrioritizedReplayBuffer,
                               self).get_add_args_signature()
        return base_signature + [priority_element]
# Example #11 (score: 0)
    def get_transition_elements(self, batch_size=None):
        """Returns a 'type signature' for sample_transition_batch.

    Args:
      batch_size: int, number of transitions returned. If None, the default
        batch_size will be used.
    Returns:
      signature: A namedtuple describing the method's return type signature.
    """
        # Resolve the default batch size here: the parent substitutes its own
        # default, so passing None through would leave the locally-built
        # 'sampling_probabilities' element with a (None,) shape, inconsistent
        # with every other element in the signature.
        batch_size = self._batch_size if batch_size is None else batch_size
        parent_transition_type = (super(
            SAILOutOfGraphPrioritizedReplayBuffer,
            self).get_transition_elements(batch_size))
        probabilities_type = [
            crb.ReplayElement('sampling_probabilities', (batch_size, ),
                              np.float32)
        ]
        return parent_transition_type + probabilities_type
# Example #12 (score: 0)
    def get_storage_signature(self):
        """Returns a default list of elements to be stored in this replay memory.

    Note - Derived classes may return a different signature.

    Returns:
      list of ReplayElements defining the type of the contents stored.
    """
        # Default elements first, then any configured extras in order.
        signature = [
            crb.ReplayElement('observation', self._observation_shape,
                              self._observation_dtype),
            crb.ReplayElement('action', self._action_shape,
                              self._action_dtype),
            crb.ReplayElement('reward', self._reward_shape,
                              self._reward_dtype),
            crb.ReplayElement('terminal', (), self._terminal_dtype),
            crb.ReplayElement('return', (), np.float32),
            crb.ReplayElement('episode_num', (), np.int32),
        ]
        signature.extend(self._extra_storage_types)
        return signature
# Example #13 (score: 0)
    def get_transition_elements(self, batch_size=None):
        """Returns a 'type signature' for sample_transition_batch.

    Args:
      batch_size: int, number of transitions returned. If None, the default
        batch_size will be used.
    Returns:
      signature: A namedtuple describing the method's return type signature.
    """
        if batch_size is None:
            batch_size = self._batch_size

        def batched(name, shape, dtype):
            # Every element's shape is the per-item shape with a leading
            # batch dimension.
            return crb.ReplayElement(name, (batch_size,) + tuple(shape),
                                     dtype)

        signature = [
            batched('state', self._state_shape, self._observation_dtype),
            batched('action', self._action_shape, self._action_dtype),
            batched('reward', self._reward_shape, self._reward_dtype),
            batched('next_state', self._state_shape,
                    self._observation_dtype),
            batched('next_action', self._action_shape, self._action_dtype),
            batched('next_reward', self._reward_shape, self._reward_dtype),
            batched('terminal', (), self._terminal_dtype),
            batched('indices', (), np.int32),
            batched('return', (), np.float32),
            batched('episode_num', (), np.int32),
        ]
        signature.extend(
            batched(extra.name, extra.shape, extra.type)
            for extra in self._extra_storage_types)
        return signature