def test_add_padding_last_1d(self):
    """Test pad_to_last function for 1d.

    Padding the 1-D fixture ``nums_1d`` to ``max_length`` must match
    ``F.pad`` appending zeros on the right, both when no axis is given and
    when ``axis=0`` is passed explicitly.
    """
    max_length = 10
    expected = F.pad(torch.Tensor(nums_1d),
                     (0, max_length - nums_1d.shape[-1]))

    tensor_padding = pad_to_last(nums_1d, total_length=max_length)
    # ``eq`` broadcasts, so compare shapes explicitly; otherwise a result of
    # the wrong (but broadcastable) size could still pass the element check.
    assert tensor_padding.shape == expected.shape
    assert expected.eq(tensor_padding).all()

    tensor_padding = pad_to_last(nums_1d, total_length=max_length, axis=0)
    assert tensor_padding.shape == expected.shape
    assert expected.eq(tensor_padding).all()
def test_add_padding_last_2d(self):
    """Test pad_to_last function for 2d.

    Checks padding of the 2-D fixture ``nums_2d`` to ``max_length`` with no
    explicit axis, along ``axis=0`` and along ``axis=1``, comparing each
    result against the equivalent ``F.pad`` call.
    """
    max_length = 10

    tensor_padding = pad_to_last(nums_2d, total_length=max_length)
    expected = F.pad(torch.Tensor(nums_2d),
                     (0, max_length - nums_2d.shape[-1]))
    # ``eq`` broadcasts, so the shape must be checked separately to catch a
    # result padded along the wrong axis or to the wrong length.
    assert tensor_padding.shape == expected.shape
    assert expected.eq(tensor_padding).all()

    tensor_padding = pad_to_last(nums_2d, total_length=max_length, axis=0)
    # F.pad's pad tuple runs from the last dim backwards:
    # (last_left, last_right, first_left, first_right).
    expected = F.pad(torch.Tensor(nums_2d),
                     (0, 0, 0, max_length - nums_2d.shape[0]))
    assert tensor_padding.shape == expected.shape
    assert expected.eq(tensor_padding).all()

    tensor_padding = pad_to_last(nums_2d, total_length=max_length, axis=1)
    expected = F.pad(torch.Tensor(nums_2d),
                     (0, max_length - nums_2d.shape[-1], 0, 0))
    assert tensor_padding.shape == expected.shape
    assert expected.eq(tensor_padding).all()
def process_samples(self, paths):
    r"""Process sample data based on the collected paths.

    Notes: P is the maximum episode length (self.max_episode_length)

    Args:
        paths (list[dict]): A list of collected paths. Each path is read
            for its 'observations', 'actions' and 'rewards' entries.

    Returns:
        torch.Tensor: The observations of the environment
            with shape :math:`(N, P, O*)`.
        torch.Tensor: The actions fed to the environment
            with shape :math:`(N, P, A*)`.
        torch.Tensor: The acquired rewards with shape :math:`(N, P)`.
        torch.Tensor: The discounted returns (discounted cumulative sum
            of rewards using ``self.discount``) with shape :math:`(N, P)`.
        torch.Tensor: Number of valid (unpadded) steps in each path, as an
            int32 tensor with shape :math:`(N,)`.
        torch.Tensor: Value function estimation at each step
            with shape :math:`(N, P)`.

    """
    max_length = self.max_episode_length

    # Episode length per path, taken from the number of actions taken.
    valids = torch.tensor([len(path['actions']) for path in paths],
                          dtype=torch.int32)

    # Pad each per-path sequence out to P steps so they can be stacked.
    obs = torch.stack([
        pad_to_last(path['observations'], total_length=max_length, axis=0)
        for path in paths
    ])
    actions = torch.stack([
        pad_to_last(path['actions'], total_length=max_length, axis=0)
        for path in paths
    ])
    rewards = torch.stack([
        pad_to_last(path['rewards'], total_length=max_length)
        for path in paths
    ])
    # NOTE(review): the discount_cumsum result is copied before padding,
    # presumably because it can be a negatively-strided array that torch
    # cannot wrap directly — confirm before removing the copy.
    returns = torch.stack([
        pad_to_last(tu.discount_cumsum(path['rewards'], self.discount).copy(),
                    total_length=max_length)
        for path in paths
    ])

    # Baselines are only used as targets/estimates here; no autograd graph.
    with torch.no_grad():
        baselines = self._value_function(obs)

    return obs, actions, rewards, returns, valids, baselines
def test_out_of_index_error(self, nums):
    """Padding along an axis one past the last dimension must fail.

    ``len(nums.shape)`` is the first axis index that ``nums`` does not have,
    so ``pad_to_last`` is expected to raise ``IndexError`` for it.
    """
    first_missing_axis = len(nums.shape)
    with pytest.raises(IndexError):
        pad_to_last(nums, axis=first_missing_axis, total_length=10)