Example #1
def add_experience(self,
                   state,
                   action,
                   reward,
                   next_state,
                   done,
                   priority=1):
    '''Implementation for update() to add experience to memory, expanding the memory size if necessary'''
    # Move head pointer. Wrap around if necessary
    self.head = (self.head + 1) % self.max_size
    self.states[self.head] = state
    # make action into one-hot (`_` is pydash, e.g. `import pydash as _`)
    if _.is_iterable(action):
        # non-singular action
        # self.actions[self.head] = one-hot of multi-action (matrix) on a 3rd axis, to be implemented
        raise NotImplementedError
    else:
        self.actions[self.head][action] = 1
    self.rewards[self.head] = reward
    self.next_states[self.head] = next_state
    self.dones[self.head] = done
    self.priorities[self.head] = priority
    # Actually occupied size of memory
    if self.true_size < self.max_size:
        self.true_size += 1
    self.total_experiences += 1
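
The `(self.head + 1) % self.max_size` update is the standard ring-buffer pattern: the head pointer advances on every write and wraps to slot 0 once it reaches the end, so the oldest experience is overwritten first. A minimal self-contained sketch of just that mechanism (the class and names here are illustrative, not part of the original memory class):

import numpy as np

class RingBuffer:
    '''Toy buffer illustrating the wrap-around head pointer above.'''

    def __init__(self, max_size):
        self.max_size = max_size
        self.states = np.zeros(max_size)
        self.head = -1      # index of the most recently written slot
        self.true_size = 0  # occupied size, capped at max_size

    def add(self, state):
        # Advance the head and wrap around, overwriting the oldest slot
        self.head = (self.head + 1) % self.max_size
        self.states[self.head] = state
        if self.true_size < self.max_size:
            self.true_size += 1

buf = RingBuffer(3)
for s in range(5):
    buf.add(s)
assert list(buf.states) == [3, 4, 2] and buf.head == 1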
Example #2
def calc_log_probs(algorithm, net, body, batch):
    '''
    Method to calculate log_probs fresh from batch data
    Body already stores log_prob from self.net. This is used for PPO where log_probs needs to be recalculated.
    '''
    states, actions = batch['states'], batch['actions']
    action_dim = body.action_dim
    is_multi_action = ps.is_iterable(action_dim)
    # construct log_probs for each state-action
    pdparams = algorithm.calc_pdparam(states, net=net)
    pdparams = guard_multi_pdparams(pdparams, body)
    assert len(pdparams) == len(states), f'batch_size of pdparams: {len(pdparams)} vs states: {len(states)}'

    pdtypes = ACTION_PDS[body.action_type]
    ActionPD = getattr(distributions, body.action_pdtype)

    log_probs = []
    for idx, pdparam in enumerate(pdparams):
        if not is_multi_action:  # already cloned for multi-action above
            pdparam = pdparam.clone()  # clone for grad safety
        _action, action_pd = sample_action_pd(ActionPD, pdparam, body)
        log_probs.append(action_pd.log_prob(actions[idx].float()).sum(dim=0))
    log_probs = torch.stack(log_probs)
    assert not torch.isnan(log_probs).any(), f'log_probs: {log_probs}, \npdparams: {pdparams} \nactions: {actions}'
    logger.debug(f'log_probs: {log_probs}')
    return log_probs
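
Stripped of the multi-action handling, what calc_log_probs does per row is the usual torch.distributions pattern: build an action distribution from the policy parameters and query log_prob for the action actually taken. A minimal sketch for a single discrete action space (the logits and actions below are made-up stand-ins for the network output and batch['actions']):

import torch
from torch import distributions

# Stand-in for pdparams from algorithm.calc_pdparam: one row of logits per state
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.1, 3.0]])
actions = torch.tensor([0, 2])  # stand-in for batch['actions']

action_pd = distributions.Categorical(logits=logits)  # batch of distributions
log_probs = action_pd.log_prob(actions)               # shape: (batch_size,)
assert not torch.isnan(log_probs).any()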
Example #3
def guard_multi_pdparams(pdparams, body):
    '''Guard pdparams for multi action'''
    action_dim = body.action_dim
    is_multi_action = ps.is_iterable(action_dim)
    if is_multi_action:
        assert ps.is_list(pdparams)
        pdparams = [t.clone() for t in pdparams]  # clone for grad safety
        assert len(pdparams) == len(action_dim), pdparams
        # transpose into (batch_size, [action_dims])
        pdparams = [list(torch.split(t, action_dim, dim=0)) for t in torch.cat(pdparams, dim=1)]
    return pdparams
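
The transpose step is easiest to see on concrete shapes: each tensor in pdparams holds the parameters for one action component across the whole batch, and the cat-then-split turns that into one list of per-component parameters per state. A sketch with made-up dimensions (the action_dim of [2, 3] and batch size of 2 are illustrative):

import torch

action_dim = [2, 3]  # two action components
pdparams = [torch.randn(2, 2), torch.randn(2, 3)]  # one (batch, dim) tensor per component

cat = torch.cat(pdparams, dim=1)  # (batch_size, sum(action_dim)) == (2, 5)
rows = [list(torch.split(t, action_dim, dim=0)) for t in cat]
# rows[i] == [params for component 0, params for component 1] of state i
assert len(rows) == 2 and [chunk.shape[0] for chunk in rows[0]] == [2, 3]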
Example #4
def rounder(func, x, precision):
    '''Apply the rounding `func` to `x` (a number or an iterable of numbers) at the given decimal `precision`.'''
    precision = pow(10, precision)

    def rounder_func(item):
        return func(item * precision) / precision

    result = None

    if pyd.is_number(x):
        result = rounder_func(x)
    elif pyd.is_iterable(x):
        try:
            result = [rounder_func(item) for item in x]
        except TypeError:
            pass

    return result
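
For instance, assuming the rounder above (with pydash available as pyd), wrapping math.floor gives floor-to-decimals behavior, and an iterable input is mapped element-wise:

import math

assert rounder(math.floor, 2.583, 2) == 2.58          # floor(258.3) / 100
assert rounder(round, [1.234, 5.678], 1) == [1.2, 5.7]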
Example #5
def test_is_iterable(case, expected):
    assert _.is_iterable(case) == expected
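
As written, this test depends on a parametrize decorator that the excerpt omits. A hedged reconstruction with pytest, assuming is_iterable means "iter(case) succeeds" (the case/expected pairs here are illustrative, not pydash's own test data):

import pytest
import pydash as _

@pytest.mark.parametrize('case,expected', [
    ([1, 2, 3], True),  # lists are iterable
    ((1, 2), True),     # so are tuples
    (1, False),         # numbers are not
    (None, False),
])
def test_is_iterable(case, expected):
    assert _.is_iterable(case) == expected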