def test_add_padding_last_3d(self):
    """Test pad_to_last function for 3d.

    Pads ``nums_3d`` up to ``max_length`` using the default axis and then
    axis 0, 1 and 2 in turn. Each result is compared element-wise with the
    same padding built directly through ``F.pad``.
    """
    max_length = 10
    tensor_3d = torch.Tensor(nums_3d)

    # F.pad consumes (before, after) pairs starting from the LAST
    # dimension, so the tuples below are ordered (dim 2, dim 1, dim 0).

    # Default axis: expected to match padding of the last dimension.
    tensor_padding = torch_algo_utils.pad_to_last(nums_3d,
                                                  total_length=max_length)
    expected = F.pad(tensor_3d,
                     (0, max_length - nums_3d.shape[-1], 0, 0, 0, 0))
    assert expected.eq(tensor_padding).all()

    # axis=0: only the first dimension grows.
    tensor_padding = torch_algo_utils.pad_to_last(nums_3d,
                                                  total_length=max_length,
                                                  axis=0)
    expected = F.pad(tensor_3d,
                     (0, 0, 0, 0, 0, max_length - nums_3d.shape[0]))
    assert expected.eq(tensor_padding).all()

    # axis=1: the pad amount must come from shape[1], not shape[-1];
    # the two only coincide when dims 1 and 2 have equal size.
    tensor_padding = torch_algo_utils.pad_to_last(nums_3d,
                                                  total_length=max_length,
                                                  axis=1)
    expected = F.pad(tensor_3d,
                     (0, 0, 0, max_length - nums_3d.shape[1], 0, 0))
    assert expected.eq(tensor_padding).all()

    # axis=2: explicit last dimension.
    tensor_padding = torch_algo_utils.pad_to_last(nums_3d,
                                                  total_length=max_length,
                                                  axis=2)
    expected = F.pad(tensor_3d,
                     (0, max_length - nums_3d.shape[2], 0, 0, 0, 0))
    assert expected.eq(tensor_padding).all()
def test_add_padding_last_1d(self):
    """Check pad_to_last on 1d input.

    The default-axis call and the explicit ``axis=0`` call are both
    compared element-wise with one ``F.pad`` reference result.
    """
    max_length = 10
    pad_amount = max_length - nums_1d.shape[-1]
    expected = F.pad(torch.Tensor(nums_1d), (0, pad_amount))

    # Default axis.
    padded_default = torch_algo_utils.pad_to_last(nums_1d,
                                                  total_length=max_length)
    assert expected.eq(padded_default).all()

    # Explicit axis=0 is compared with the same reference.
    padded_axis0 = torch_algo_utils.pad_to_last(nums_1d,
                                                total_length=max_length,
                                                axis=0)
    assert expected.eq(padded_axis0).all()
def test_out_of_index_error(self, nums):
    """Check that pad_to_last raises IndexError for an out-of-range axis.

    The axis passed equals the number of dimensions of ``nums``, which is
    one past the largest valid axis index.
    """
    out_of_range_axis = len(nums.shape)
    with pytest.raises(IndexError):
        torch_algo_utils.pad_to_last(nums,
                                     total_length=10,
                                     axis=out_of_range_axis)