Ejemplo n.º 1
0
    def __init__(self, n_envs, epslen, n_minibatches, state_shape, state_dtype,
                 action_shape, action_dtype):
        """Pre-allocate zero-filled storage for one rollout of every environment.

        Args:
            n_envs: number of parallel environments (first axis of every array).
            epslen: number of time steps stored per environment.
            n_minibatches: number of minibatches the environments are split
                into; ``n_envs`` must be divisible by it.
            state_shape, state_dtype: per-step shape and dtype of ``state``.
            action_shape, action_dtype: per-step shape and dtype of ``action``.
        """
        self.n_envs = n_envs
        self.epslen = epslen
        self.n_minibatches = n_minibatches

        # Environment row indices; minibatches are taken as slices of these.
        self.indices = np.arange(n_envs)
        self.minibatch_size = n_envs // n_minibatches

        evenly_split = n_envs // n_minibatches * n_minibatches == n_envs
        assert_colorize(
            evenly_split,
            f'#envs({n_envs}) is not divisible by #minibatches{n_minibatches}')

        self.basic_shape = (n_envs, epslen)
        rows_by_steps = self.basic_shape

        def per_step_scalar(shape=rows_by_steps):
            # float32 array holding one scalar per (env, step) cell.
            return np.zeros((*shape, 1), dtype=np.float32)

        # NOTE(review): the base class appears to be dict-like (entries are
        # read back via self[key] elsewhere) — confirm against the class header.
        super().__init__(
            state=np.zeros((*rows_by_steps, *state_shape), dtype=state_dtype),
            action=np.zeros((*rows_by_steps, *action_shape),
                            dtype=action_dtype),
            reward=per_step_scalar(),
            nonterminal=per_step_scalar(),
            # One extra time step, presumably for a bootstrap value — TODO confirm.
            value=per_step_scalar((n_envs, epslen + 1)),
            traj_ret=per_step_scalar(),
            advantage=per_step_scalar(),
            old_logpi=per_step_scalar(),
            mask=per_step_scalar(),
        )

        self.reset()
Ejemplo n.º 2
0
    def add(self, **data):
        """Store one time step for every environment at the current cursor.

        Each keyword names a buffer entry. Its value is written to time index
        ``self.idx`` (second axis) of that entry, and ``None`` values are
        skipped. Afterwards the cursor advances by one.
        """
        step = self.idx
        assert_colorize(
            step < self.epslen,
            f'Out-of-range idx {step}. Call "self.reset" beforehand')

        provided = ((key, val) for key, val in data.items() if val is not None)
        for key, val in provided:
            self[key][:, step] = val

        self.idx = step + 1
Ejemplo n.º 3
0
    def _log_tabular(self, key, val):
        """
        Log a value of some diagnostic.

        Call this only once for each diagnostic quantity, each iteration.
        After using ``log_tabular`` to store values for each diagnostic,
        make sure to call ``dump_tabular`` to write them out to file and
        stdout (otherwise they will not get saved anywhere).

        Args:
            key: name of the diagnostic; on the first row it becomes a header,
                and later rows may only use keys registered then.
            val: value recorded for ``key`` in the current row.
        """
        # Check for a duplicate before touching log_headers, so a rejected
        # call on the first row cannot leave a duplicated header behind.
        assert_colorize(key not in self.log_current_row, f"You already set {key} this iteration. Maybe you forgot to call dump_tabular()")
        if self.first_row:
            self.log_headers.append(key)
        else:
            assert_colorize(key in self.log_headers, f"Trying to introduce a new key {key} that you didn't include in the first iteration")
        self.log_current_row[key] = val
Ejemplo n.º 4
0
    def get_batch(self):
        """Return the next minibatch of stored data, flattened over time.

        Minibatches are consecutive slices of ``self.indices`` of length
        ``self.minibatch_size``. Successive calls cycle through all
        ``self.n_minibatches`` slices. Only the first ``self.idx`` time steps
        are included, and the environment and time axes are merged into one.

        Returns:
            dict mapping each of 'state', 'action', 'traj_ret', 'value',
            'advantage', 'old_logpi', 'mask' to an array of shape
            ``(minibatch_size * idx, *trailing_dims)``.
        """
        assert_colorize(
            self.ready,
            'PPOBuffer is not ready to be read. Call "self.finish" first')

        size = self.minibatch_size
        which = self.batch_idx
        self.batch_idx = (which + 1) % self.n_minibatches

        rows = self.indices[which * size:(which + 1) * size]
        n_steps = self.idx

        batch = {}
        for name in ('state', 'action', 'traj_ret', 'value', 'advantage',
                     'old_logpi', 'mask'):
            stored = self[name]
            batch[name] = stored[rows, :n_steps].reshape(
                (size * n_steps, *stored.shape[2:]))
        return batch
Ejemplo n.º 5
0
def moments(x, mask=None):
    """Return the (optionally masked) mean and standard deviation of ``x``.

    Args:
        x: NumPy array of values.
        mask: optional array selecting which entries of ``x`` take part.
            It is aligned with the *leading* axes of ``x``: trailing singleton
            axes are appended until its rank equals that of ``x``. Each mask
            axis must either match the corresponding axis of ``x`` or have
            size 1, in which case it broadcasts over that axis. Entries are
            presumably 0/1 (or boolean); the std formula below is only exact
            for binary masks.

    Returns:
        ``(mean, std)`` over all of ``x`` when ``mask`` is None, otherwise over
        the entries selected by ``mask``. The std is the population (ddof=0)
        standard deviation. If the mask selects nothing, both results are nan.
    """
    if mask is None:
        return np.mean(x), np.std(x)

    # Append trailing axes so the mask lines up with x's leading axes.
    while len(mask.shape) < len(x.shape):
        mask = mask[..., None]

    # Count the entries of x selected by the mask. A size-1 mask axis
    # broadcasts, so every mask entry covers x.shape[i] entries along it.
    n = np.sum(mask)
    for i, mask_dim in enumerate(mask.shape):
        if mask_dim != 1:
            assert_colorize(
                mask_dim == x.shape[i],
                f'{i}th dimension of mask{mask_dim} does not match '
                f'that of x{x.shape[i]}')
        else:
            n *= x.shape[i]

    # Zero out unselected entries, then average only over the selected ones.
    x_mask = x * mask
    x_mean = np.sum(x_mask) / n
    x_std = np.sqrt(np.sum(mask * (x_mask - x_mean)**2) / n)

    return x_mean, x_std
Ejemplo n.º 6
0
 def get_count(self, name):
     """Return ``len(self.store_dict[name])``.

     Fails the ``assert_colorize`` check, with a message naming the key,
     when ``name`` is not present in ``self.store_dict``.
     """
     # Pass a message like every other assert_colorize call site does, so a
     # missing key is reported by name.
     assert_colorize(
         name in self.store_dict,
         f'{name} is not in store_dict')
     return len(self.store_dict[name])