def testCheckAddTypes(self):
        """Check `_check_add_types` against a full and a truncated argument list.

        The buffer is configured with two extra storage elements ("extra1",
        scalar float32, and "extra2", shape [2] int8). A call supplying six
        arguments must pass; a call supplying only four must raise a
        ValueError whose message matches "Add expects".
        """
        memory = circular_replay_buffer.ReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=STACK_SIZE,
            replay_capacity=5,
            batch_size=BATCH_SIZE,
            extra_storage_types=[
                circular_replay_buffer.ReplayElement("extra1", [], np.float32),
                circular_replay_buffer.ReplayElement("extra2", [2], np.int8),
            ],
        )
        zeros = np.zeros(OBSERVATION_SHAPE)

        # Complete argument list: must be accepted without raising.
        memory._check_add_types(zeros, 0, 0, 0, 0, [0, 0])

        # Truncated argument list (the two trailing values are missing).
        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which no longer exists on unittest.TestCase in Python 3.12+.
        with self.assertRaisesRegex(ValueError, "Add expects"):
            memory._check_add_types(zeros, 0, 0, 0)
    def testExtraAdd(self):
        """Check `add` on a buffer configured with extra storage elements.

        A call with six arguments must succeed and advance the cursor; a call
        with only four must raise a ValueError matching "Add expects" and,
        per the final assertion, leave the cursor where the successful add
        put it.
        """
        memory = circular_replay_buffer.ReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=STACK_SIZE,
            replay_capacity=5,
            batch_size=BATCH_SIZE,
            extra_storage_types=[
                circular_replay_buffer.ReplayElement("extra1", [], np.float32),
                circular_replay_buffer.ReplayElement("extra2", [2], np.int8),
            ],
        )
        # A freshly built buffer starts with its cursor at zero.
        self.assertEqual(memory.cursor(), 0)
        zeros = np.zeros(OBSERVATION_SHAPE)
        memory.add(zeros, 0, 0, 0, 0, [0, 0])

        # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`
        # alias, which no longer exists on unittest.TestCase in Python 3.12+.
        with self.assertRaisesRegex(ValueError, "Add expects"):
            memory.add(zeros, 0, 0, 0)
        # The cursor is expected at STACK_SIZE: STACK_SIZE - 1 zero-padding
        # adds plus the one successful add above. The rejected add is not
        # counted.
        self.assertEqual(memory.cursor(), STACK_SIZE)
    def testSampleTransitionBatchExtra(self):
        """Sample batches from a buffer that carries two extra storage elements.

        A capacity-10 buffer (stack size 1, default batch size 2) receives 50
        adds whose observation is filled with the add index. The test then
        checks that:
          * the default batch size is used when none is given,
          * an explicit batch size overrides it for that call only,
          * sampling at explicit indices returns the expected ten-element
            transition tuple, including the two extra elements.
        """
        capacity = 10
        default_batch_size = 2
        extra_elements = [
            circular_replay_buffer.ReplayElement("extra1", [], np.float32),
            circular_replay_buffer.ReplayElement("extra2", [2], np.int8),
        ]
        replay = circular_replay_buffer.ReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=1,
            replay_capacity=capacity,
            batch_size=default_batch_size,
            extra_storage_types=extra_elements,
        )

        # Fill (and wrap) the buffer. The fourth argument is `step % 4`, so it
        # is zero on every fourth add and non-zero on the others.
        total_adds = 50
        for step in range(total_adds):
            frame = np.full(OBSERVATION_SHAPE, step, dtype=OBS_DTYPE)
            replay.add(frame, 0, 0, step % 4, 0, [0, 0])

        # Three sampling phases, in order: default batch size, an explicit
        # batch size, then the default again to confirm it was not overwritten.
        sampling_phases = (
            ((), default_batch_size),
            ((BATCH_SIZE,), BATCH_SIZE),
            ((), default_batch_size),
        )
        for call_args, expected_size in sampling_phases:
            for _ in range(1000):
                sampled = replay.sample_transition_batch(*call_args)
                self.assertEqual(sampled[0].shape[0], expected_size)

        # Sampling at caller-chosen indices.
        indices = [1, 2, 3, 5, 8]
        num_indices = len(indices)
        stacked_shape = OBSERVATION_SHAPE + (1,)
        expected_states = np.array(
            [np.full(stacked_shape, idx, dtype=OBS_DTYPE) for idx in indices]
        )
        expected_next_states = (expected_states + 1) % capacity
        # The buffer is circular, so after `total_adds` adds the value stored
        # at a slot is the slot index shifted by the number of overwritten
        # entries.
        shift = total_adds - capacity
        expected_states += shift
        expected_next_states += shift
        # Mirror the `step % 4` rule used while adding, clamped to at most 1.
        expected_terminal = np.array([min((idx + shift) % 4, 1) for idx in indices])
        expected_extra2 = np.zeros([num_indices, 2])
        zero_vector = np.zeros(num_indices)

        (
            states,
            action,
            reward,
            next_states,
            next_action,
            next_reward,
            terminal,
            indices_batch,
            extra1,
            extra2,
        ) = replay.sample_transition_batch(batch_size=num_indices, indices=indices)

        comparisons = [
            (states, expected_states),
            (action, zero_vector),
            (reward, zero_vector),
            (next_action, zero_vector),
            (next_reward, zero_vector),
            (next_states, expected_next_states),
            (terminal, expected_terminal),
            (indices_batch, indices),
            (extra1, zero_vector),
            (extra2, expected_extra2),
        ]
        for actual, expected in comparisons:
            npt.assert_array_equal(actual, expected)