def __init__(self, *args, **kwargs):
        """Build the training, density and validation replay samplers.

        The parent initializer runs first; this body then reads
        ``self._replay_description``, ``self._device``, ``self._params`` and
        ``self._buffer`` (presumably set by that parent initializer).

        Three ``OffsetSampler`` instances are created with the same sequence
        length, column indices, discount factor and ``soft_cutoff=False``;
        they differ only in the buffer they read from:

        * ``self._sampler`` wraps the existing ``self._buffer``;
        * ``self._density_sampler`` wraps a new ``RamBuffer`` of
          ``params.density_replay_size`` rows;
        * ``self._valid_sampler`` wraps a new ``RamBuffer`` of
          ``int(params.replay_size * params.validation_fraction)`` rows.
        """
        super().__init__(*args, **kwargs)
        description = self._replay_description
        params = self._params
        n_columns = description.num_columns
        target_device = self._device

        def make_sampler(buffer):
            # Shared sampler configuration; only the backing buffer varies.
            return offset_sampler.OffsetSampler(
                buffer,
                params.sequence_length,
                description.get_index("weight_column"),
                description.get_index("cum_weight_column"),
                description.get_index("time_left"),
                params.discount_factor,
                soft_cutoff=False)

        self._sampler = make_sampler(self._buffer)

        self._density_buffer = ram_buffer.RamBuffer(
            params.density_replay_size, n_columns, target_device)
        self._density_sampler = make_sampler(self._density_buffer)

        validation_rows = int(params.replay_size * params.validation_fraction)
        self._valid_buffer = ram_buffer.RamBuffer(
            validation_rows, n_columns, target_device)
        self._valid_sampler = make_sampler(self._valid_buffer)
    def __init__(self,
                 make_env: Callable[[],
                                    environment.Environment[Union[np.ndarray,
                                                                  int]]],
                 device: torch.device, params: AsyncReplayAgentParams):
        """Start the worker pool, then build a local env and replay buffer.

        Args:
            make_env: Zero-argument environment factory. It is handed to every
                worker process through ``_initialize_process_env`` and is also
                called once here to obtain a local environment.
            device: Torch device used for the replay buffer.
            params: Agent configuration; ``actor_device``, ``num_envs`` and
                ``replay_size`` are read here.
        """
        self._params = params
        self._device = device
        self._actor_device = torch.device(params.actor_device)

        # Workers are started with the 'spawn' start method, and each one runs
        # ``self._initialize_process_env(make_env)`` when it starts up.
        spawn_context = torch.multiprocessing.get_context('spawn')
        self._executor = futures.ProcessPoolExecutor(
            max_workers=params.num_envs,
            mp_context=spawn_context,
            initializer=self._initialize_process_env,
            initargs=(make_env,))
        self._futures = None

        self._make_env = make_env
        env = make_env()
        self._env = env
        self._state_dim = env.state_dim
        self._action_dim = env.action_dim

        # The replay layout is derived from the local env instance.
        description = self.get_description(env)
        self._replay_description = description
        self._buffer = ram_buffer.RamBuffer(
            params.replay_size, description.num_columns, device)

        reporting.register_field("return")
        self._env_steps = 0
        reporting.register_field("env_steps")
        self.__update_count = 0
    def test_uniform_hard_cutoff(self):
        """Uniform-offset sampling with ``soft_cutoff=False``.

        Checks that each sampled sequence stays inside one episode, that only
        the expected start rows are drawn, and that the offsets are roughly
        uniformly distributed.
        """
        replay = ram_buffer.RamBuffer(14, 4, torch.device('cpu'))
        # Positional args: buffer, sequence_length=5, weight col 1,
        # cumulative-weight col 2, time-left col 3, discount 0.5, soft_cutoff.
        sampler = offset_sampler.OffsetSampler(replay, 5, 1, 2, 3, 0.5, False)

        sampler.add_episode(
            np.array([
                [0, 0, 0, 0],
                [1, 0, 0, 0],
                [2, 0, 0, 0],
                [3, 0, 0, 0],
                [4, 0, 0, 0],
                [5, 0, 0, 0],
                [6, 0, 0, 0],
            ]))
        sampler.add_episode(
            np.array([
                [10, 0, 0, 0],
                [11, 0, 0, 0],
                [12, 0, 0, 0],
                [13, 0, 0, 0],
                [14, 0, 0, 0],
                [15, 0, 0, 0],
                [16, 0, 0, 0],
                [17, 0, 0, 0],
            ]))

        num_samples = 10000
        samples, offsets = sampler.sample_uniform_offset(num_samples)
        first_values = samples[:, 0, 0]
        # Column-0 values the test expects to see at the start of a sequence.
        valid_nums = [1, 2, 10, 11, 12, 13]
        counts = np.array([(first_values == i).sum() for i in valid_nums])
        # Episodes are tagged by the tens digit of column 0 (0-6 vs 10-17), so
        # equal tens digits mean both steps come from the same episode.
        self.assertTrue(
            (first_values // 10 == samples[:, 1, 0] // 10).all().item())
        self.assertTrue(np.all(counts > 0), counts)
        # Every sample must start at one of the expected rows.
        self.assertEqual(counts.sum(), num_samples)
        offset_bincount = np.bincount(offsets)
        self.assertTrue(np.all(offset_bincount < 2300), offset_bincount)
        self.assertTrue(np.all(offset_bincount > 1700), offset_bincount)
Example #4
0
    def __init__(self, make_env: Callable[[], environment.Environment[Union[np.ndarray, int]]], device: torch.device,
                 params: ReplayAgentParams):
        """Create and reset the environment, then allocate a replay buffer.

        Args:
            make_env: Zero-argument factory returning the environment.
            device: Torch device passed to the replay buffer.
            params: Agent configuration; ``replay_size`` and ``buffer_type``
                (``'ram'`` or ``'vram'``) are read here.

        Raises:
            ValueError: If ``params.buffer_type`` is neither ``'ram'`` nor
                ``'vram'``.
        """
        self._params = params
        self._env = make_env()
        self._env.reset()
        self._state_dim = self._env.state_dim
        self._action_dim = self._env.action_dim
        self._device = device

        # state, action, reward, next_state, timeout, terminal, action_logprob
        # -> two state vectors, one action vector and four single columns.
        buffer_dim = 2 * self._state_dim + self._action_dim + 4
        if params.buffer_type == 'ram':
            buffer_cls = ram_buffer.RamBuffer
        elif params.buffer_type == 'vram':
            buffer_cls = vram_buffer.VramBuffer
        else:
            # ``assert False`` would be stripped under ``python -O`` and leave
            # ``self._buffer`` undefined; fail loudly with a clear message.
            raise ValueError(
                f"Unknown buffer_type {params.buffer_type!r}; "
                "expected 'ram' or 'vram'")
        self._buffer = buffer_cls(self._params.replay_size, buffer_dim, device)

        reporting.register_field("return")
    def test_uniform_soft_cutoff(self):
        """Uniform-offset sampling with ``soft_cutoff=True``.

        Checks that each sampled sequence stays inside one episode, that every
        expected start row is drawn, and that offset counts strictly decrease
        from one offset to the next.
        """
        replay = ram_buffer.RamBuffer(14, 4, torch.device('cpu'))
        # Positional args: buffer, sequence_length=5, weight col 1,
        # cumulative-weight col 2, time-left col 3, discount 0.5, soft_cutoff.
        sampler = offset_sampler.OffsetSampler(replay, 5, 1, 2, 3, 0.5, True)

        sampler.add_episode(
            np.array([
                [0, 0, 0, 0],
                [1, 0, 0, 0],
                [2, 0, 0, 0],
                [3, 0, 0, 0],
                [4, 0, 0, 0],
                [5, 0, 0, 0],
                [6, 0, 0, 0],
            ],
                     dtype=np.float32))
        sampler.add_episode(
            np.array([
                [10, 0, 0, 0],
                [11, 0, 0, 0],
                [12, 0, 0, 0],
                [13, 0, 0, 0],
                [14, 0, 0, 0],
                [15, 0, 0, 0],
                [16, 0, 0, 0],
                [17, 0, 0, 0],
            ],
                     dtype=np.float32))

        num_samples = 10000
        samples, offsets = sampler.sample_uniform_offset(num_samples)
        first_values = samples[:, 0, 0]
        # Column-0 values the test expects to see at the start of a sequence.
        valid_nums = [1, 2, 3, 4, 5, 6, 10, 11, 12, 13, 14, 15, 16, 17]
        counts = np.array([(first_values == i).sum() for i in valid_nums])
        # Equal tens digits in column 0 mean both steps share an episode.
        self.assertTrue(
            (first_values // 10 == samples[:, 1, 0] // 10).all().item())
        self.assertTrue(np.all(counts > 0), counts)
        self.assertEqual(counts.sum(), num_samples)
        offset_bincount = np.bincount(offsets)
        for i in range(4):
            with self.subTest(offset=i):
                self.assertLess(offset_bincount[i + 1], offset_bincount[i],
                                offset_bincount)
    def test_geometric_hard_cutoff(self):
        """Discounted-offset sampling with ``soft_cutoff=False``.

        Checks that each sampled sequence stays inside one episode, that only
        the expected start rows are drawn, and that consecutive offset counts
        have a ratio close to 2.
        """
        replay = ram_buffer.RamBuffer(14, 4, torch.device('cpu'))
        # Positional args: buffer, sequence_length=5, weight col 1,
        # cumulative-weight col 2, time-left col 3, discount 0.5, soft_cutoff.
        sampler = offset_sampler.OffsetSampler(replay, 5, 1, 2, 3, 0.5, False)

        sampler.add_episode(
            np.array([
                [0, 0, 0, 0],
                [1, 0, 0, 0],
                [2, 0, 0, 0],
                [3, 0, 0, 0],
                [4, 0, 0, 0],
                [5, 0, 0, 0],
                [6, 0, 0, 0],
            ]))
        sampler.add_episode(
            np.array([
                [10, 0, 0, 0],
                [11, 0, 0, 0],
                [12, 0, 0, 0],
                [13, 0, 0, 0],
                [14, 0, 0, 0],
                [15, 0, 0, 0],
                [16, 0, 0, 0],
                [17, 0, 0, 0],
            ]))

        num_samples = 10000
        samples, offsets = sampler.sample_discounted_offset(num_samples)
        first_values = samples[:, 0, 0]
        # Column-0 values the test expects to see at the start of a sequence.
        valid_nums = [1, 2, 10, 11, 12, 13]
        counts = np.array([(first_values == i).sum() for i in valid_nums])
        # Equal tens digits in column 0 mean both steps share an episode.
        self.assertTrue(
            (first_values // 10 == samples[:, 1, 0] // 10).all().item())
        self.assertTrue(np.all(counts > 0), counts)
        self.assertEqual(counts.sum(), num_samples)
        offset_bincount = np.bincount(offsets)
        for i in range(4):
            # With discount_factor=0.5 the test expects each offset to be
            # drawn about twice as often as the next one.
            ratio = offset_bincount[i] / offset_bincount[i + 1]
            with self.subTest(offset=i, ratio=ratio):
                self.assertGreater(ratio, 1.8)
                self.assertLess(ratio, 2.2)
    def test_sample_weights(self):
        """WeightedSampler draws rows in proportion to column 2 (the weight).

        Rows with weight 0 must never be drawn. After the buffer is refilled,
        the sampling distribution must follow the new weights.
        """

        def assert_ratio_between(numerator, denominator, low, high):
            # Strict bounds, matching the original ``high > ratio > low``.
            ratio = numerator / denominator
            self.assertGreater(ratio, low)
            self.assertLess(ratio, high)

        replay = ram_buffer.RamBuffer(5, 4, torch.device('cpu'))
        # Column 2 is the weight; column 3 is presumably the cumulative
        # weight maintained by the sampler.
        sampler = weighted_sampler.WeightedSampler(replay, 2, 3)
        replay.add_samples(np.array([[0, 0, 1, 0], [1, 1, 1, 0], [2, 2, 2,
                                                                  0]]))
        replay.add_samples(np.array([
            [3, 3, 0, 0],
            [4, 4, 3, 0],
        ]))

        samples = sampler.sample_weighted(100000).detach().cpu().numpy()
        # Column 0 identifies the row: weights are 1, 1, 2, 0, 3.
        sample_bincount = np.bincount(samples[:, 0].astype(np.int64))
        assert_ratio_between(sample_bincount[2], sample_bincount[0], 1.8, 2.2)
        assert_ratio_between(sample_bincount[4], sample_bincount[1], 2.7, 3.3)
        self.assertEqual(sample_bincount[3], 0)

        # The buffer was created with 5 rows, so the test expects these two
        # rows to replace the two oldest ones. Column 1 still identifies the
        # slot (0-4), so counts are taken on column 1 from here on.
        replay.add_samples(np.array([
            [5, 0, 1, 0],
            [6, 1, 3, 0],
        ]))
        samples = sampler.sample_weighted(100000).detach().cpu().numpy()
        sample_bincount = np.bincount(samples[:, 1].astype(np.int64))
        assert_ratio_between(sample_bincount[2], sample_bincount[0], 1.8, 2.2)
        assert_ratio_between(sample_bincount[4], sample_bincount[1], 0.8, 1.2)
        self.assertEqual(sample_bincount[3], 0)
 def build_buffer(self):
     """Return a ``RamBuffer`` of capacity 5 with 2 columns on the CPU.

     The CPU device is also stored on ``self._device``.
     """
     cpu = torch.device('cpu')
     self._device = cpu
     return ram_buffer.RamBuffer(5, 2, device=cpu)
        else:
            before_head_indices = np.searchsorted(before_head[:, self._cum_weight_column], weight_samples)
            after_head_indices = np.searchsorted(after_head[:, self._cum_weight_column], weight_samples)
            head = before_head.shape[0]
            indices = np.where(after_head_indices == self._ram_buffer.size - head,
                               before_head_indices + self._ram_buffer.size - head, after_head_indices)

        return indices

    def sample_weighted(self, batch_size: int) -> torch.Tensor:
        """Load ``batch_size`` rows chosen by ``sample_weighted_indices``.

        Args:
            batch_size: Number of rows to draw.

        Returns:
            Whatever ``self._ram_buffer.load`` returns for the drawn indices.
        """
        return self._ram_buffer.load(self.sample_weighted_indices(batch_size))


if __name__ == '__main__':
    replay = ram_buffer.RamBuffer(5, 4, torch.device('cpu'))
    sampler = WeightedSampler(replay, 2, 3)
    replay.add_samples(np.array([
        [0, 0, 1, 0],
        [1, 1, 1, 0],
        [2, 2, 2, 0]
    ]))
    replay.add_samples(np.array([
        [3, 3, 0, 0],
        [4, 4, 3, 0],
    ]))

    samples = sampler.sample_weighted(10000).detach().cpu().numpy()
    print(np.bincount(samples[:, 0].astype(np.int64)))
    replay.add_samples(np.array([
        [5, 0, 1, 0],