Exemple #1
0
    def trj_sample(self, trj_num=None):
        """Return a ``Batch`` of items drawn from ``self`` in random order.

        Items are drawn without replacement by slicing a random permutation of
        ``range(len(self))``.

        Args:
            trj_num: number of items to draw. ``None`` (the default) draws all
                ``len(self)`` items, shuffled.

        Returns:
            A ``Batch`` filled by ``Batch.cat_list`` with ``self[i]`` for each
            sampled index ``i``, in sampled order.

        Raises:
            ValueError: if ``trj_num`` is negative or larger than ``len(self)``.
        """
        total = len(self)
        if trj_num is None:
            trj_num = total
        elif not 0 <= trj_num <= total:
            # Reject negatives explicitly: a negative value would otherwise act
            # as a slice from the end (e.g. -1 -> all but one item).
            raise ValueError(f"trj_num must be in [0, {total}], got {trj_num}")

        # Always draw a full permutation and slice it, so the RNG is consumed
        # the same way whether or not trj_num was given.
        indices = np.random.permutation(total)[:trj_num]

        batch = Batch()
        batch.cat_list([self[i] for i in indices])
        return batch
Exemple #2
0
    def trj_by_policy(self, batch):
        """Re-run the policy over a batch of fixed-length clips, step by step.

        Let ``T = self.trj_clip_steps``. Rows ``step, step + T, step + 2T, ...``
        of ``batch`` are processed together as "step ``step``". This appears
        to assume ``batch`` is laid out as consecutive clips of ``T`` rows
        each (TODO confirm against callers). For each step, in order:

        * From the second step on, ``obs.s`` is rebuilt from the previous
          step's ``obs.s``. The state predicted by the previous forward pass
          (``act_m[:, o_len:]``) is appended along dim 1, and the first entry
          along dim 1 is dropped.
        * ``self`` is called on the step batch. ``act.a``, ``act.o`` and
          ``act.s_next`` are then overwritten with ``act_d``,
          ``act_m[:, :o_len]`` and ``act_m[:, o_len:]``, moved to CPU.

        Args:
            batch: a ``Batch`` with ``obs.s`` and ``act.o`` whose length is a
                multiple of ``self.trj_clip_steps``.

        Returns:
            A ``Batch`` of the rewritten step batches. Row ``r`` corresponds
            to row ``r`` of ``batch``.

        Raises:
            ValueError: if ``len(batch)`` is not a multiple of
                ``self.trj_clip_steps``.
        """
        clip_steps = self.trj_clip_steps
        if len(batch) % clip_steps != 0:
            raise ValueError(
                f"len(batch)={len(batch)} is not a multiple of "
                f"trj_clip_steps={clip_steps}"
            )

        infer_batch = None
        old_step_batch = None
        # The first o_len columns of act_m go to act.o; the remaining columns
        # go to act.s_next.
        o_len = batch.act.o.shape[-1]
        policy_batch_list = []
        for step in range(clip_steps):
            # Row `step` of every clip.
            step_batch = batch[np.arange(step, len(batch), clip_steps)]
            if infer_batch is not None:
                # Rebuild obs.s from the policy's own prediction: append the
                # previously predicted state along dim 1, then drop the oldest
                # entry along that dim.
                s = torch.cat(
                    [
                        old_step_batch.obs.s,
                        infer_batch.act_m[:, o_len:].cpu().reshape(len(old_step_batch), 1, -1),
                    ],
                    1,
                )
                step_batch.obs.s = s[:, 1:, :]
                # Drop our reference to the previous outputs before the next
                # forward pass, so they need not stay alive alongside the new
                # ones.
                del infer_batch

            infer_batch = self(step_batch)

            # Replace the recorded actions with the policy's outputs.
            step_batch.act.a = infer_batch.act_d.cpu()
            step_batch.act.o = infer_batch.act_m[:, :o_len].contiguous().cpu()
            step_batch.act.s_next = infer_batch.act_m[:, o_len:].contiguous().cpu()

            old_step_batch = step_batch
            policy_batch_list.append(step_batch)

        policy_batch = Batch()
        policy_batch.cat_list(policy_batch_list)

        # policy_batch is step-major: row j * n_clips + i is step j of clip i.
        # Output row i * clip_steps + j must take that row to restore batch's
        # row order, i.e. the transpose of a (clip_steps, n_clips) index grid.
        n_clips = len(policy_batch_list[0])
        re_indices = (
            np.arange(clip_steps * n_clips).reshape(clip_steps, n_clips).T.reshape(-1)
        )
        return policy_batch[re_indices]