def __init__(self, *args, **kwargs):
        """Initialise the parent, then create the three offset samplers.

        All positional and keyword arguments are forwarded unchanged to the
        parent initialiser. Afterwards this method creates:

        * ``self._sampler`` over the existing ``self._buffer``;
        * ``self._density_buffer`` / ``self._density_sampler`` sized by
          ``self._params.density_replay_size``;
        * ``self._valid_buffer`` / ``self._valid_sampler`` sized by
          ``replay_size * validation_fraction`` truncated to an int.

        Every sampler is built with the same sequence length, column indices,
        discount factor and ``soft_cutoff=False``.
        """
        super().__init__(*args, **kwargs)
        description = self._replay_description
        num_columns = description.num_columns
        target_device = self._device

        def build_sampler(storage):
            # Single place that fixes the configuration shared by all three
            # samplers; only the backing buffer differs between them.
            return offset_sampler.OffsetSampler(
                storage,
                self._params.sequence_length,
                description.get_index("weight_column"),
                description.get_index("cum_weight_column"),
                description.get_index("time_left"),
                self._params.discount_factor,
                soft_cutoff=False)

        self._sampler = build_sampler(self._buffer)

        self._density_buffer = ram_buffer.RamBuffer(
            self._params.density_replay_size, num_columns, target_device)
        self._density_sampler = build_sampler(self._density_buffer)

        # The validation buffer holds a fraction of the main replay size.
        valid_capacity = int(
            self._params.replay_size * self._params.validation_fraction)
        self._valid_buffer = ram_buffer.RamBuffer(
            valid_capacity, num_columns, target_device)
        self._valid_sampler = build_sampler(self._valid_buffer)
    def test_uniform_hard_cutoff(self):
        """Check ``sample_uniform_offset`` on a sampler built with ``False``.

        Two episodes (7 rows valued 0-6 and 8 rows valued 10-17 in column 0)
        are added to a 14-row, 4-column CPU buffer, then 10000 samples are
        drawn. The test asserts that:

        * every sample's first entry is one of 1, 2, 10, 11, 12, 13 and each
          of those values occurs at least once;
        * entries 0 and 1 of each sample share the same tens digit (the two
          episodes use disjoint tens ranges, so this means same episode);
        * every offset bin holds between 1700 and 2300 draws (exclusive).

        NOTE(review): the last positional constructor argument is presumably
        ``soft_cutoff`` -- confirm against ``OffsetSampler``.
        """
        replay = ram_buffer.RamBuffer(14, 4, torch.device('cpu'))
        sampler = offset_sampler.OffsetSampler(replay, 5, 1, 2, 3, 0.5, False)

        sampler.add_episode(
            np.array([
                [0, 0, 0, 0],
                [1, 0, 0, 0],
                [2, 0, 0, 0],
                [3, 0, 0, 0],
                [4, 0, 0, 0],
                [5, 0, 0, 0],
                [6, 0, 0, 0],
            ]))
        sampler.add_episode(
            np.array([
                [10, 0, 0, 0],
                [11, 0, 0, 0],
                [12, 0, 0, 0],
                [13, 0, 0, 0],
                [14, 0, 0, 0],
                [15, 0, 0, 0],
                [16, 0, 0, 0],
                [17, 0, 0, 0],
            ]))

        samples, offsets = sampler.sample_uniform_offset(10000)
        valid_nums = [1, 2, 10, 11, 12, 13]
        counts = np.array([(samples[:, 0, 0] == i).sum() for i in valid_nums])
        self.assertTrue(
            (samples[:, 0, 0] // 10 == samples[:, 1, 0] // 10).all().item())
        self.assertTrue(np.all(counts > 0))
        # If the valid values do not account for every draw, some sample
        # started at a value outside ``valid_nums``.
        self.assertEqual(counts.sum(), 10000)
        # Compute the histogram once and report it when a bound is violated.
        offset_bincount = np.bincount(offsets)
        self.assertTrue(
            np.all(2300 > offset_bincount),
            msg="offset counts: {}".format(offset_bincount))
        self.assertTrue(
            np.all(offset_bincount > 1700),
            msg="offset counts: {}".format(offset_bincount))
    def test_uniform_soft_cutoff(self):
        """Check ``sample_uniform_offset`` on a sampler built with ``True``.

        Two float32 episodes (7 rows valued 0-6 and 8 rows valued 10-17 in
        column 0) are added to a 14-row, 4-column CPU buffer, then 10000
        samples are drawn. The test asserts that:

        * every sample's first entry is one of 1-6 or 10-17 and each of those
          values occurs at least once;
        * entries 0 and 1 of each sample share the same tens digit (the two
          episodes use disjoint tens ranges, so this means same episode);
        * the offset histogram is strictly decreasing over bins 0 through 4.

        NOTE(review): the last positional constructor argument is presumably
        ``soft_cutoff`` -- confirm against ``OffsetSampler``.
        """
        replay = ram_buffer.RamBuffer(14, 4, torch.device('cpu'))
        sampler = offset_sampler.OffsetSampler(replay, 5, 1, 2, 3, 0.5, True)

        sampler.add_episode(
            np.array([
                [0, 0, 0, 0],
                [1, 0, 0, 0],
                [2, 0, 0, 0],
                [3, 0, 0, 0],
                [4, 0, 0, 0],
                [5, 0, 0, 0],
                [6, 0, 0, 0],
            ],
                     dtype=np.float32))
        sampler.add_episode(
            np.array([
                [10, 0, 0, 0],
                [11, 0, 0, 0],
                [12, 0, 0, 0],
                [13, 0, 0, 0],
                [14, 0, 0, 0],
                [15, 0, 0, 0],
                [16, 0, 0, 0],
                [17, 0, 0, 0],
            ],
                     dtype=np.float32))

        samples, offsets = sampler.sample_uniform_offset(10000)
        valid_nums = [1, 2, 3, 4, 5, 6, 10, 11, 12, 13, 14, 15, 16, 17]
        counts = np.array([(samples[:, 0, 0] == i).sum() for i in valid_nums])
        self.assertTrue(
            (samples[:, 0, 0] // 10 == samples[:, 1, 0] // 10).all().item())
        self.assertTrue(np.all(counts > 0))
        # If the valid values do not account for every draw, some sample
        # started at a value outside ``valid_nums``.
        self.assertEqual(counts.sum(), 10000)
        offset_bincount = np.bincount(offsets)
        for i in range(4):
            # Each successive offset must be strictly less frequent.
            self.assertLess(offset_bincount[i + 1], offset_bincount[i])
    def test_geometric_hard_cutoff(self):
        """Check ``sample_discounted_offset`` on a sampler built with ``False``.

        Two episodes (7 rows valued 0-6 and 8 rows valued 10-17 in column 0)
        are added to a 14-row, 4-column CPU buffer, then 10000 samples are
        drawn. The test asserts that:

        * every sample's first entry is one of 1, 2, 10, 11, 12, 13 and each
          of those values occurs at least once;
        * entries 0 and 1 of each sample share the same tens digit (the two
          episodes use disjoint tens ranges, so this means same episode);
        * for bins 0 through 3, the count of each offset divided by the count
          of the next offset lies strictly between 1.8 and 2.2.

        NOTE(review): the ratio of about 2 presumably reflects the 0.5 passed
        as the sixth constructor argument (a discount factor) -- confirm
        against ``OffsetSampler``.
        """
        replay = ram_buffer.RamBuffer(14, 4, torch.device('cpu'))
        sampler = offset_sampler.OffsetSampler(replay, 5, 1, 2, 3, 0.5, False)

        sampler.add_episode(
            np.array([
                [0, 0, 0, 0],
                [1, 0, 0, 0],
                [2, 0, 0, 0],
                [3, 0, 0, 0],
                [4, 0, 0, 0],
                [5, 0, 0, 0],
                [6, 0, 0, 0],
            ]))
        sampler.add_episode(
            np.array([
                [10, 0, 0, 0],
                [11, 0, 0, 0],
                [12, 0, 0, 0],
                [13, 0, 0, 0],
                [14, 0, 0, 0],
                [15, 0, 0, 0],
                [16, 0, 0, 0],
                [17, 0, 0, 0],
            ]))

        samples, offsets = sampler.sample_discounted_offset(10000)
        valid_nums = [1, 2, 10, 11, 12, 13]
        counts = np.array([(samples[:, 0, 0] == i).sum() for i in valid_nums])
        self.assertTrue(
            (samples[:, 0, 0] // 10 == samples[:, 1, 0] // 10).all().item())
        self.assertTrue(np.all(counts > 0))
        # If the valid values do not account for every draw, some sample
        # started at a value outside ``valid_nums``.
        self.assertEqual(counts.sum(), 10000)
        offset_bincount = np.bincount(offsets)
        for i in range(4):
            # Separate bounds so a failure reports the offending ratio.
            ratio = offset_bincount[i] / offset_bincount[i + 1]
            self.assertLess(ratio, 2.2)
            self.assertGreater(ratio, 1.8)