def add_experience(self, state, action, reward, next_state, done, priority=1):
    '''
    Implementation for update() to add experience to memory, expanding the memory size if necessary.

    Advances the circular-buffer head (wrapping at max_size), writes the
    (state, action, reward, next_state, done, priority) tuple into the slot,
    and grows the occupancy counter until the buffer is full.

    @param state: observation to store at the new head slot
    @param action: discrete action index (scalar); iterable multi-actions are not yet supported
    @param reward: reward for the transition
    @param next_state: successor observation
    @param done: episode-termination flag
    @param priority: sampling priority for the slot (default 1)
    @raises NotImplementedError: if action is iterable (multi-action one-hot not implemented)
    '''
    # Move head pointer. Wrap around if necessary
    self.head = (self.head + 1) % self.max_size
    self.states[self.head] = state
    # make action into one_hot
    if _.is_iterable(action):  # non-singular action
        # self.actions[self.head] = one hot of multi-action (matrix) on a 3rd axis, to be implemented
        raise NotImplementedError
    else:
        # Bugfix: clear the slot's row before setting the one-hot bit. When the
        # circular buffer wraps, this row still holds the one-hot of the
        # overwritten experience; without the reset the row accumulates multiple
        # 1s and stops being a valid one-hot encoding.
        # NOTE(review): assumes self.actions rows are numpy arrays (support .fill) — confirm
        self.actions[self.head].fill(0)
        self.actions[self.head][action] = 1
    self.rewards[self.head] = reward
    self.next_states[self.head] = next_state
    self.dones[self.head] = done
    self.priorities[self.head] = priority
    # Actually occupied size of memory
    if self.true_size < self.max_size:
        self.true_size += 1
    self.total_experiences += 1
def calc_log_probs(algorithm, net, body, batch):
    '''
    Method to calculate log_probs fresh from batch data.
    Body already stores log_prob from self.net. This is used for PPO where log_probs needs to be recalculated.

    @param algorithm: algorithm object providing calc_pdparam(states, net=net)
    @param net: network used to compute the policy distribution parameters
    @param body: agent body carrying action_dim / action_pdtype metadata
    @param batch: dict with 'states' and 'actions' tensors of equal batch size
    @returns torch tensor of per-sample log-probabilities (stacked along dim 0)
    '''
    states, actions = batch['states'], batch['actions']
    action_dim = body.action_dim
    is_multi_action = ps.is_iterable(action_dim)
    # construct log_probs for each state-action
    pdparams = algorithm.calc_pdparam(states, net=net)
    pdparams = guard_multi_pdparams(pdparams, body)
    assert len(pdparams) == len(states), f'batch_size of pdparams: {len(pdparams)} vs states: {len(states)}'
    # NOTE: removed dead local `pdtypes = ACTION_PDS[body.action_type]` — it was
    # assigned but never read anywhere in this function.
    ActionPD = getattr(distributions, body.action_pdtype)
    log_probs = []
    for idx, pdparam in enumerate(pdparams):
        if not is_multi_action:  # already cloned for multi_action above
            pdparam = pdparam.clone()  # clone for grad safety
        _action, action_pd = sample_action_pd(ActionPD, pdparam, body)
        # sum over action dims so each sample contributes a scalar log-prob
        log_probs.append(action_pd.log_prob(actions[idx].float()).sum(dim=0))
    log_probs = torch.stack(log_probs)
    assert not torch.isnan(log_probs).any(), f'log_probs: {log_probs}, \npdparams: {pdparams} \nactions: {actions}'
    logger.debug(f'log_probs: {log_probs}')
    return log_probs
def guard_multi_pdparams(pdparams, body):
    '''Guard pdparams for multi action'''
    action_dim = body.action_dim
    # single-action body: nothing to guard, pass straight through
    if not ps.is_iterable(action_dim):
        return pdparams
    assert ps.is_list(pdparams)
    # clone for grad safety
    cloned = [tensor.clone() for tensor in pdparams]
    assert len(cloned) == len(action_dim), cloned
    # transpose into (batch_size, [action_dims])
    joined = torch.cat(cloned, dim=1)
    return [list(torch.split(row, action_dim, dim=0)) for row in joined]
def rounder(func, x, precision):
    '''Apply a rounding function to x (number or iterable of numbers) at the given decimal precision; return None when x is neither.'''
    factor = pow(10, precision)

    def scaled(value):
        # round in scaled integer space, then scale back down
        return func(value * factor) / factor

    if pyd.is_number(x):
        return scaled(x)
    if pyd.is_iterable(x):
        try:
            return [scaled(item) for item in x]
        except TypeError:
            # best-effort: non-numeric elements yield None instead of raising
            return None
    return None
def test_is_iterable(case, expected):
    '''Each parametrized case must match pydash's iterability verdict.'''
    actual = _.is_iterable(case)
    assert actual == expected