def test_smoke_no_transform(self, device):
    """apply_perspective without a transform returns a tensor shaped like its input."""
    # Random tensors are drawn in the same order as before, so a seeded run
    # produces identical values (including the per-sample apply mask).
    inputs = torch.rand(1, 2, 3, 4).to(device)
    apply_mask = torch.rand(1) < 0.5
    src_corners = torch.rand(1, 4, 2).to(device)
    dst_corners = torch.rand(1, 4, 2).to(device)

    params = {
        'batch_prob': apply_mask,
        'start_points': src_corners,
        'end_points': dst_corners,
    }
    warped = F.apply_perspective(inputs, params, return_transform=False)

    assert warped.shape == inputs.shape
def test_smoke(self, device):
    """apply_perspective with an explicit interpolation entry preserves the input shape."""
    # Same random-draw order as before; only the CPU RNG is used for the mask.
    inputs = torch.rand(1, 2, 3, 4).to(device)
    apply_mask = torch.rand(1) < 0.5
    src_corners = torch.rand(1, 4, 2).to(device)
    dst_corners = torch.rand(1, 4, 2).to(device)

    params = {
        'batch_prob': apply_mask,
        'start_points': src_corners,
        'end_points': dst_corners,
        # NOTE(review): presumably an index into the library's resample-mode enum — confirm.
        'interpolation': torch.tensor(1),
    }
    warped = F.apply_perspective(inputs, params)

    assert warped.shape == inputs.shape
def test_smoke_transform(self, device):
    """apply_perspective with return_transform=True yields (output, 3x3 transform)."""
    # Random tensors drawn in the original order so seeded runs are unchanged.
    inputs = torch.rand(1, 2, 3, 4).to(device)
    apply_mask = torch.rand(1) < 0.5
    src_corners = torch.rand(1, 4, 2).to(device)
    dst_corners = torch.rand(1, 4, 2).to(device)

    params = {
        'batch_prob': apply_mask,
        'start_points': src_corners,
        'end_points': dst_corners,
    }
    result = F.apply_perspective(inputs, params, return_transform=True)

    # Check the container before unpacking so a wrong return type fails clearly.
    assert isinstance(result, tuple)
    assert len(result) == 2
    warped, transform = result
    assert warped.shape == inputs.shape
    assert transform.shape == (1, 3, 3)