예제 #1
0
    def test_add_padding_last_3d(self):
        """Test pad_to_last function for 3d.

        Checks that ``pad_to_last`` right-pads ``nums_3d`` up to ``max_length``
        along the requested axis. It is called once without ``axis`` (expected
        to match padding the last axis) and once for each explicit axis 0, 1, 2.

        ``F.pad`` reads its padding tuple starting from the LAST dimension, in
        ``(left, right)`` pairs. For a 3-d tensor the order is
        ``(dim2_left, dim2_right, dim1_left, dim1_right, dim0_left, dim0_right)``.
        Each expected tensor therefore puts the right-pad amount in the slot of
        the axis under test. That amount is computed from THAT axis' size.
        """
        max_length = 10
        # F.pad returns a new tensor, so one converted source can be reused.
        source = torch.Tensor(nums_3d)

        # No axis given: expected to pad the last axis.
        tensor_padding = torch_algo_utils.pad_to_last(nums_3d,
                                                      total_length=max_length)
        expected = F.pad(source,
                         (0, max_length - nums_3d.shape[-1], 0, 0, 0, 0))
        assert expected.eq(tensor_padding).all()

        # axis=0 is the first dimension, i.e. the LAST pair in F.pad's tuple.
        tensor_padding = torch_algo_utils.pad_to_last(nums_3d,
                                                      total_length=max_length,
                                                      axis=0)
        expected = F.pad(source,
                         (0, 0, 0, 0, 0, max_length - nums_3d.shape[0]))
        assert expected.eq(tensor_padding).all()

        # axis=1 is the middle pair; its pad amount must come from shape[1].
        # Using shape[-1] is only correct when dims 1 and 2 are the same size.
        tensor_padding = torch_algo_utils.pad_to_last(nums_3d,
                                                      total_length=max_length,
                                                      axis=1)
        expected = F.pad(source,
                         (0, 0, 0, max_length - nums_3d.shape[1], 0, 0))
        assert expected.eq(tensor_padding).all()

        # axis=2 is the last dimension, i.e. the FIRST pair in F.pad's tuple.
        tensor_padding = torch_algo_utils.pad_to_last(nums_3d,
                                                      total_length=max_length,
                                                      axis=2)
        expected = F.pad(source,
                         (0, max_length - nums_3d.shape[2], 0, 0, 0, 0))
        assert expected.eq(tensor_padding).all()
예제 #2
0
    def test_add_padding_last_1d(self):
        """Check that pad_to_last right-pads a 1-d input to the target length.

        A 1-d input has a single axis, so the call without ``axis`` and the
        call with ``axis=0`` are both compared against the same ``F.pad``
        result.
        """
        target_length = 10
        want = F.pad(torch.Tensor(nums_1d),
                     (0, target_length - nums_1d.shape[-1]))

        # Same expectation for the implicit-axis call and the explicit axis=0 call.
        for axis_kwargs in ({}, {"axis": 0}):
            got = torch_algo_utils.pad_to_last(nums_1d,
                                               total_length=target_length,
                                               **axis_kwargs)
            assert want.eq(got).all()
예제 #3
0
 def test_out_of_index_error(self, nums):
     """Check that pad_to_last rejects an axis one past the last dimension.

     ``len(nums.shape)`` is the first axis index that does not exist on
     ``nums``, so the call is expected to raise ``IndexError``.
     """
     first_invalid_axis = len(nums.shape)
     with pytest.raises(IndexError):
         torch_algo_utils.pad_to_last(nums, total_length=10,
                                      axis=first_invalid_axis)