Exemplo n.º 1
0
    def test_forward(self, device, dtype, keepdim):
        """Run a forward pass with the augmentation's hooks replaced by mocks.

        ``transform_tensor`` is mocked to add a leading dimension
        (``unsqueeze(dim=0)``), ``apply_transform`` to crop the last two
        dimensions to ``[:2, :2]``, and ``__check_batching__`` to do nothing.
        The expected result is built from the input with the same operations.
        When ``keepdim`` is true, the test expects the added leading dimension
        to be absent from the output.
        """
        torch.manual_seed(42)
        in_tensor = torch.rand((12, 3, 4, 5), device=device, dtype=dtype)
        if keepdim:
            expected_output = in_tensor[..., :2, :2]
        else:
            expected_output = in_tensor.unsqueeze(dim=0)[..., :2, :2]
        augmentation = _BasicAugmentationBase(p=0.3, p_batch=1.0, keepdim=keepdim)

        def fake_generate_parameters(shape):
            # One "degrees" value per element of the leading dimension.
            return {'degrees': torch.arange(0, shape[0], device=device, dtype=dtype)}

        # Patchers are only constructed here; nothing is patched until `with` enters them.
        apply_patch = patch.object(augmentation, "apply_transform", autospec=True)
        params_patch = patch.object(augmentation, "generate_parameters", autospec=True)
        tensor_patch = patch.object(augmentation, "transform_tensor", autospec=True)
        batching_patch = patch.object(augmentation, "__check_batching__", autospec=True)

        with apply_patch as apply_transform, params_patch as generate_parameters, \
                tensor_patch as transform_tensor, batching_patch as check_batching:
            generate_parameters.side_effect = fake_generate_parameters
            transform_tensor.side_effect = lambda x: x.unsqueeze(dim=0)
            # Parameter names mirror the original lambdas in case of keyword calls.
            apply_transform.side_effect = lambda input, params: input[..., :2, :2]
            check_batching.side_effect = lambda input: None
            output = augmentation(in_tensor)
            assert output.shape == expected_output.shape
            assert_close(output, expected_output)
Exemplo n.º 2
0
 def test_infer_input(self, device, dtype):
     """Check ``__infer_input__`` against a mocked ``transform_tensor``.

     ``transform_tensor`` is mocked to insert a singleton dimension at
     position 2. The test expects the result to have that shape and to hold
     the original values along the new axis.
     """
     in_tensor = torch.rand((2, 3, 4, 5), device=device, dtype=dtype)
     augmentation = _BasicAugmentationBase(p=1., p_batch=1)
     tensor_patch = patch.object(augmentation, "transform_tensor", autospec=True)
     with tensor_patch as transform_tensor:
         transform_tensor.side_effect = lambda x: x.unsqueeze(dim=2)
         inferred = augmentation.__infer_input__(in_tensor)
         assert inferred.shape == torch.Size([2, 3, 1, 4, 5])
         # Indexing the singleton axis should recover the untouched input.
         assert_allclose(in_tensor, inferred[:, :, 0])
Exemplo n.º 3
0
 def test_forward_params(self, p, p_batch, same_on_batch, num, seed, device, dtype):
     """Check the parameter dict returned by ``__forward_parameters__``.

     ``generate_parameters`` is mocked to return one ``degrees`` entry per
     element of the leading dimension of the shape it receives. The test
     expects the result to contain ``batch_prob``. It also expects the number
     of ``degrees`` entries to equal both the sum of ``batch_prob`` and the
     parametrized ``num``.
     """
     torch.manual_seed(seed)
     batch_shape = (12,)
     augmentation = _BasicAugmentationBase(p, p_batch, same_on_batch)

     def fake_generate_parameters(shape):
         degrees = torch.arange(0, shape[0], device=device, dtype=dtype)
         return {'degrees': degrees}

     with patch.object(augmentation, "generate_parameters", autospec=True) as generate_parameters:
         generate_parameters.side_effect = fake_generate_parameters
         params = augmentation.__forward_parameters__(batch_shape, p, p_batch, same_on_batch)
         assert "batch_prob" in params
         n_selected = params['batch_prob'].sum().item()
         assert len(params['degrees']) == n_selected == num
Exemplo n.º 4
0
    def test_forward(self, device, dtype):
        """Run a forward pass with the sampling and transform hooks mocked out.

        ``transform_tensor`` is mocked to insert a dimension at position 1, and
        ``apply_transform`` to crop the last two dimensions to ``[:2, :2]``.
        The output is compared with those same operations applied directly to
        the input.
        """
        torch.manual_seed(42)
        in_tensor = torch.rand((12, 3, 4, 5), device=device, dtype=dtype)
        expected_output = in_tensor.unsqueeze(dim=1)[..., :2, :2]
        augmentation = _BasicAugmentationBase(p=.3, p_batch=1.)

        def fake_generate_parameters(shape):
            # One "degrees" value per element of the leading dimension.
            degrees = torch.arange(0, shape[0], device=device, dtype=dtype)
            return {'degrees': degrees}

        # Patchers are only constructed here; nothing is patched until `with` enters them.
        apply_patch = patch.object(augmentation, "apply_transform", autospec=True)
        params_patch = patch.object(augmentation, "generate_parameters", autospec=True)
        tensor_patch = patch.object(augmentation, "transform_tensor", autospec=True)

        with apply_patch as apply_transform, params_patch as generate_parameters, \
                tensor_patch as transform_tensor:
            generate_parameters.side_effect = fake_generate_parameters
            transform_tensor.side_effect = lambda x: x.unsqueeze(dim=1)
            apply_transform.side_effect = lambda input, params: input[..., :2, :2]
            output = augmentation(in_tensor)
            assert output.shape == expected_output.shape
            assert_allclose(output, expected_output)
Exemplo n.º 5
0
 def test_smoke(self):
     """``str()`` of the augmentation should read ``ClassName(arg=value, ...)``."""
     expected_repr = "_BasicAugmentationBase(p=0.5, p_batch=1.0, same_on_batch=True)"
     base = _BasicAugmentationBase(p=0.5, p_batch=1.0, same_on_batch=True)
     actual_repr = str(base)
     assert actual_repr == expected_repr
Exemplo n.º 6
0
 def test_smoke(self, device, dtype):
     """``str()`` of the augmentation should list its arguments without the class name.

     ``device`` and ``dtype`` are accepted but unused here, presumably as
     pytest fixtures shared with the other tests. TODO: confirm.
     """
     expected_repr = "p=0.5, p_batch=1.0, same_on_batch=True"
     base = _BasicAugmentationBase(p=0.5, p_batch=1.0, same_on_batch=True)
     actual_repr = str(base)
     assert actual_repr == expected_repr