Example #1
0
def test_batch_from_to_numpy_without_copy():
    """Check that a numpy -> torch -> numpy round trip of a Batch reuses memory.

    The data pointers of a top-level array (``a``) and a nested one (``b.c``)
    are recorded before ``to_torch()`` and compared after ``to_numpy()``.
    The test also asserts that ``to_torch()`` really produced tensors.
    Without that check, the address comparison would pass trivially if
    the conversion left the numpy arrays untouched.
    """
    import torch  # local import: only needed to verify the intermediate step

    batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
    a_mem_addr_orig = batch.a.__array_interface__['data'][0]
    c_mem_addr_orig = batch.b.c.__array_interface__['data'][0]
    batch.to_torch()
    # Guard against a vacuous pass: the arrays must have become tensors.
    assert isinstance(batch.a, torch.Tensor)
    assert isinstance(batch.b.c, torch.Tensor)
    batch.to_numpy()
    a_mem_addr_new = batch.a.__array_interface__['data'][0]
    c_mem_addr_new = batch.b.c.__array_interface__['data'][0]
    assert a_mem_addr_new == a_mem_addr_orig
    assert c_mem_addr_new == c_mem_addr_orig
Example #2
0
def test_batch_from_to_numpy_without_copy():
    """Round-trip a Batch through torch and back, expecting the same buffers.

    Both a top-level field and a field of a nested Batch must become torch
    tensors after ``to_torch()``. After ``to_numpy()`` they must point at
    the same memory they started with, i.e. no copy was made.
    """
    def _data_ptr(arr):
        # Address of the array's data buffer as reported by numpy.
        return arr.__array_interface__['data'][0]

    batch = Batch(a=np.ones((1, )), b=Batch(c=np.ones((1, ))))
    before = (_data_ptr(batch["a"]), _data_ptr(batch["b"]["c"]))

    batch.to_torch()
    assert isinstance(batch["a"], torch.Tensor)
    assert isinstance(batch["b"]["c"], torch.Tensor)

    batch.to_numpy()
    after = (_data_ptr(batch["a"]), _data_ptr(batch["b"]["c"]))
    assert after[0] == before[0]
    assert after[1] == before[1]
Example #3
0
    def process_fn(self, batch: Batch, buffer=None,
                   indice=None) -> Batch:
        """Relabel rewards with the discriminator and compute episodic returns.

        ``batch.rew`` is overwritten with the reward returned by
        ``self.get_reward_by_discriminator``. It is optionally normalised
        (``self._rew_norm``), and then ``self.compute_episodic_return`` is
        applied with ``self._gamma`` and ``self._lambda``.

        :param batch: sampled transitions; ``batch.rew`` is replaced in place.
        :param buffer: not used by this method; kept for interface compatibility.
        :param indice: not used by this method; kept for interface compatibility.
        :return: the batch after ``compute_episodic_return``, converted back
            with ``to_torch()``.
        """
        rew, v_ = self.get_reward_by_discriminator(batch)
        batch.rew = rew

        if self._rew_norm:
            mean, std = batch.rew.mean(), batch.rew.std()
            # std.cpu() presumes the discriminator reward is a torch tensor
            # (possibly on GPU) -- NOTE(review): confirm against
            # get_reward_by_discriminator.
            if not np.isclose(std.cpu().numpy(), 0):
                batch.rew = (batch.rew - mean) / std

        # For lambda in {0, 1} the discriminator's bootstrap values are not
        # used (v_s_=None), as in the original code.
        if self._lambda in [0, 1]:
            v_ = None

        # Both branches go through the same numpy round trip. Previously the
        # lambda in {0, 1} branch passed the unconverted batch straight to
        # compute_episodic_return and returned it without to_torch().
        batch.to_numpy()
        batch = self.compute_episodic_return(
            batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
        batch.to_torch()
        return batch