def trj_sample(self, trj_num=None):
    """Draw a random selection of items from this object as a single Batch.

    Args:
        trj_num: Number of items to draw. ``None`` (the default) draws every
            item, in shuffled order.

    Returns:
        A new ``Batch`` on which ``cat_list`` was called with the selected
        items, listed in the shuffled order.

    Raises:
        AssertionError: If ``trj_num`` is larger than ``len(self)``.
    """
    total = len(self)
    if trj_num is not None:
        assert trj_num <= total

    # Slicing with ``[:None]`` keeps the whole permutation, so one expression
    # covers both the "take everything" and the "take trj_num" cases.
    chosen = np.random.permutation(total)[:trj_num]

    sampled = Batch()
    sampled.cat_list([self[idx] for idx in chosen])
    return sampled
def trj_by_policy(self, batch):
    """Replay a batch step by step through ``self`` and substitute its outputs.

    The batch is processed as ``T = self.trj_clip_steps`` strided slices:
    slice ``k`` holds entries ``k, k + T, k + 2T, ...``.
    NOTE(review): this presumably means ``batch`` is a flat concatenation of
    clips of ``T`` consecutive steps each -- confirm against the caller.

    For every slice, ``self(slice)`` is called and its result overwrites the
    slice's ``act.a``, ``act.o`` and ``act.s_next``. From the second slice on,
    ``obs.s`` is rebuilt beforehand from the previous slice's ``obs.s`` and
    the previous result, so each step consumes what the step before produced.

    Args:
        batch: Batch whose length is a multiple of ``self.trj_clip_steps`` and
            which provides ``act.o`` and ``obs.s``.

    Returns:
        The processed slices merged with ``Batch.cat_list`` and re-indexed so
        that entries are back in the interleaved order of the input.

    Raises:
        AssertionError: If ``len(batch)`` is not a multiple of
            ``self.trj_clip_steps``.
    """
    assert len(batch) % self.trj_clip_steps == 0

    # Number of leading ``act_m`` columns that are stored as ``act.o``; the
    # remaining columns are stored as ``act.s_next``.
    obs_width = batch.act.o.shape[-1]

    policy_out = None  # result of the most recent ``self(...)`` call
    prev_slice = None
    replayed = []
    for offset in range(self.trj_clip_steps):
        picks = np.arange(offset, len(batch), self.trj_clip_steps)
        current = batch[picks]

        if policy_out is not None:
            # Correct obs: append the trailing ``act_m`` columns of the
            # previous result as a new entry along dim 1 of the previous
            # ``obs.s``, then drop the first entry along that dim.
            predicted = policy_out.act_m[:, obs_width:].cpu()
            predicted = predicted.reshape(len(prev_slice), 1, -1)
            window = torch.cat([prev_slice.obs.s, predicted], 1)
            current.obs.s = window[:, 1:, :]
            # Drop the old result before the next call (presumably so its
            # tensors can be released first -- TODO confirm).
            del policy_out

        policy_out = self(current)

        # Correct act: replace the stored values with the fresh result.
        current.act.a = policy_out.act_d.cpu()
        current.act.o = policy_out.act_m[:, :obs_width].contiguous().cpu()
        current.act.s_next = policy_out.act_m[:, obs_width:].contiguous().cpu()

        prev_slice = current
        replayed.append(current)

    merged = Batch()
    merged.cat_list(replayed)

    # Assuming ``cat_list`` concatenates in list order, ``merged`` is
    # step-major (all entries of step 0, then step 1, ...). Transposing the
    # (step, entry) index grid yields entry-major order, i.e. position
    # ``i * T + j`` picks entry ``i`` of step ``j``.
    per_step = len(replayed[0])
    order = (
        np.arange(self.trj_clip_steps * per_step)
        .reshape(self.trj_clip_steps, per_step)
        .T.ravel()
        .astype(np.int32)
    )
    return merged[order]