コード例 #1
0
def test_pad(device, dt, pad, config):
    """Check ``pad`` on tensors against the PIL result, the scripted op and batches.

    Args:
        device: Device passed to the data-creation helpers.
        dt: Optional dtype the tensors are cast to; ``None`` keeps them as created.
        pad: Padding specification (an ``int`` or a sequence).
        config: Extra keyword arguments forwarded to every ``pad`` call.
    """
    script_fn = torch.jit.script(F.pad)
    tensor, pil_img = _create_data(7, 8, device=device)
    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)

    # The float16-on-CPU combination is not exercised: leave without asserting.
    skip_case = dt == torch.float16 and device == "cpu"
    if skip_case:
        return

    if dt is not None:
        # Trivial cast of the uint8 data so that every dtype case is covered.
        tensor, batch_tensors = tensor.to(dt), batch_tensors.to(dt)

    padded = F_t.pad(tensor, pad, **config)
    padded_pil = F_pil.pad(pil_img, pad, **config)
    label = "{}, {}".format(pad, config)

    # The PIL comparison is done on uint8 data, so cast the tensor result if needed.
    padded_u8 = padded if padded.dtype == torch.uint8 else padded.to(torch.uint8)
    _assert_equal_tensor_to_pil(padded_u8, padded_pil, msg=label)

    # The scripted function is given a list, so a bare int is wrapped.
    script_pad = [pad] if isinstance(pad, int) else pad
    padded_scripted = script_fn(tensor, script_pad, **config)
    assert_equal(padded, padded_scripted, msg=label)

    _test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **config)
コード例 #2
0
def test_ten_crop(device):
    """Check ``F.ten_crop`` on tensors against PIL, the scripted op and batched input.

    Args:
        device: Device passed to the data-creation helpers.
    """
    script_ten_crop = torch.jit.script(F.ten_crop)

    img_tensor, pil_img = _create_data(32, 34, device=device)

    pil_crops = F.ten_crop(pil_img, [10, 11])

    # Eager first, then scripted: each of the ten crops must match its PIL counterpart.
    for crop_fn in (F.ten_crop, script_ten_crop):
        tensor_crops = crop_fn(img_tensor, [10, 11])
        for crop_idx in range(10):
            _assert_equal_tensor_to_pil(tensor_crops[crop_idx], pil_crops[crop_idx])

    # Cropping a batch must agree with cropping every sample on its own.
    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    batch_crops = F.ten_crop(batch_tensors, [10, 11])
    for sample_idx in range(len(batch_tensors)):
        sample_crops = F.ten_crop(batch_tensors[sample_idx, ...], [10, 11])
        assert len(sample_crops) == len(batch_crops)

        for crop_idx, expected_crop in enumerate(sample_crops):
            assert_equal(expected_crop, batch_crops[crop_idx][sample_idx, ...])

    # Scripted function on the batch must reproduce the eager batch result.
    scripted_batch_crops = script_ten_crop(batch_tensors, [10, 11])
    for eager_crop, scripted_crop in zip(batch_crops, scripted_batch_crops):
        assert_equal(eager_crop, scripted_crop)
コード例 #3
0
def _test_class_op(transform_cls, device, channels=3, meth_kwargs=None, test_exact_match=True, **match_kwargs):
    """Exercise a transform class on tensor, PIL, scripted and batched inputs.

    Args:
        transform_cls: Transform class; instantiated with ``meth_kwargs``.
        device: Device passed to the data-creation helpers.
        channels: Number of channels of the generated data.
        meth_kwargs: Constructor keyword arguments; a falsy value means none.
        test_exact_match: Selects the exact comparison helper when true,
            the approximate one otherwise.
        **match_kwargs: Forwarded to the chosen tensor-vs-PIL comparison helper.
    """
    if not meth_kwargs:
        meth_kwargs = {}

    # Build the transform through its class interface and script it.
    transform = transform_cls(**meth_kwargs)
    scripted_transform = torch.jit.script(transform)

    tensor, pil_img = _create_data(26, 34, channels, device=device)

    # Re-seed before each call so tensor and PIL inputs get the same transformation.
    torch.manual_seed(12)
    out_tensor = transform(tensor)
    torch.manual_seed(12)
    out_pil = transform(pil_img)

    if not test_exact_match:
        _assert_approx_equal_tensor_to_pil(out_tensor.float(), out_pil, **match_kwargs)
    else:
        _assert_equal_tensor_to_pil(out_tensor, out_pil, **match_kwargs)

    torch.manual_seed(12)
    out_scripted = scripted_transform(tensor)
    assert_equal(out_tensor, out_scripted)

    batch_tensors = _create_data_batch(height=23, width=34, channels=channels, num_samples=4, device=device)
    _test_transform_vs_scripted_on_batch(transform, scripted_transform, batch_tensors)

    # Saving the scripted transform to a temporary directory must succeed as well.
    with get_tmp_dir() as tmp_dir:
        save_path = os.path.join(tmp_dir, f"t_{transform_cls.__name__}.pt")
        scripted_transform.save(save_path)
コード例 #4
0
def test_x_crop(fn, method, out_length, size, device):
    """Check a multi-crop function and its class counterpart.

    Args:
        fn: Functional op returning a sequence of crops.
        method: Transform class instantiated with ``size``.
        out_length: Expected number of crops.
        size: Crop size passed to both ``fn`` and ``method``.
        device: Device passed to the data-creation helpers.
    """
    fn_kwargs = {"size": size}
    meth_kwargs = fn_kwargs
    scripted_functional = torch.jit.script(fn)

    tensor, pil_img = _create_data(height=20, width=20, device=device)
    tensor_crops = fn(tensor, **fn_kwargs)
    pil_crops = fn(pil_img, **fn_kwargs)
    assert len(tensor_crops) == len(pil_crops)
    assert len(tensor_crops) == out_length
    for tensor_crop, pil_crop in zip(tensor_crops, pil_crops):
        _assert_equal_tensor_to_pil(tensor_crop, pil_crop)

    # The scripted functional op gets a detached copy of the input.
    scripted_crops = scripted_functional(tensor.detach().clone(), **fn_kwargs)
    assert len(tensor_crops) == len(scripted_crops)
    assert len(scripted_crops) == out_length
    for tensor_crop, scripted_crop in zip(tensor_crops, scripted_crops):
        assert_equal(tensor_crop, scripted_crop)

    # Class interface: scripted instance must yield the same number of crops.
    transform = method(**meth_kwargs)
    scripted_transform = torch.jit.script(transform)
    class_output = scripted_transform(tensor)
    assert len(class_output) == len(scripted_crops)

    # Batch of tensors: each sample processed alone must match its slice of the batch result.
    batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=device)
    torch.manual_seed(12)
    batch_crops = transform(batch_tensors)

    for sample_idx in range(len(batch_tensors)):
        sample = batch_tensors[sample_idx, ...]
        torch.manual_seed(12)
        sample_crops = transform(sample)
        for sample_crop, batch_crop in zip(sample_crops, batch_crops):
            assert_equal(sample_crop, batch_crop[sample_idx, ...])
コード例 #5
0
def _test_functional_op(f, device, channels=3, fn_kwargs=None, test_exact_match=True, **match_kwargs):
    """Apply ``f`` to a tensor and a PIL image and compare the two results.

    Args:
        f: Functional op called on both inputs with ``fn_kwargs``.
        device: Device passed to the data-creation helper.
        channels: Number of channels of the generated data.
        fn_kwargs: Keyword arguments for ``f``; a falsy value means none.
        test_exact_match: Selects the exact comparison helper when true,
            the approximate one otherwise.
        **match_kwargs: Forwarded to the chosen comparison helper.
    """
    if not fn_kwargs:
        fn_kwargs = {}

    tensor, pil_img = _create_data(height=10, width=10, channels=channels, device=device)
    out_tensor = f(tensor, **fn_kwargs)
    out_pil = f(pil_img, **fn_kwargs)

    # Pick the comparison helper once, then run it with the caller's options.
    compare = _assert_equal_tensor_to_pil if test_exact_match else _assert_approx_equal_tensor_to_pil
    compare(out_tensor, out_pil, **match_kwargs)
コード例 #6
0
def test_hflip(device):
    """Check ``F.hflip`` on tensors against PIL, the scripted op and batched input.

    Args:
        device: Device passed to the data-creation helpers.
    """
    script_hflip = torch.jit.script(F.hflip)

    img_tensor, pil_img = _create_data(16, 18, device=device)
    flipped_tensor = F.hflip(img_tensor)
    flipped_pil = F.hflip(pil_img)
    _assert_equal_tensor_to_pil(flipped_tensor, flipped_pil)

    # The scripted function must give the same tensor as the eager call.
    flipped_scripted = script_hflip(img_tensor)
    assert_equal(flipped_tensor, flipped_scripted)

    _test_fn_on_batch(
        _create_data_batch(16, 18, num_samples=4, device=device),
        F.hflip,
    )
コード例 #7
0
def test_crop(device, top, left, height, width):
    """Check ``F.crop`` on tensors against PIL, the scripted op and batched input.

    Args:
        device: Device passed to the data-creation helpers.
        top: ``top`` argument forwarded to ``F.crop``.
        left: ``left`` argument forwarded to ``F.crop``.
        height: ``height`` argument forwarded to ``F.crop``.
        width: ``width`` argument forwarded to ``F.crop``.
    """
    script_crop = torch.jit.script(F.crop)

    img_tensor, pil_img = _create_data(16, 18, device=device)

    expected_pil = F.crop(pil_img, top, left, height, width)

    # Eager first, then scripted: both tensor results must match the PIL crop.
    for crop_fn in (F.crop, script_crop):
        cropped = crop_fn(img_tensor, top, left, height, width)
        _assert_equal_tensor_to_pil(cropped, expected_pil)

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)
コード例 #8
0
def test_center_crop(device):
    """Check ``F.center_crop`` on tensors against PIL, the scripted op and batched input.

    Args:
        device: Device passed to the data-creation helpers.
    """
    script_center_crop = torch.jit.script(F.center_crop)

    img_tensor, pil_img = _create_data(32, 34, device=device)

    expected_pil = F.center_crop(pil_img, [10, 11])

    # Eager first, then scripted: both tensor results must match the PIL crop.
    for crop_fn in (F.center_crop, script_center_crop):
        cropped = crop_fn(img_tensor, [10, 11])
        _assert_equal_tensor_to_pil(cropped, expected_pil)

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
コード例 #9
0
    def test_translations(self, device, height, width, dt, t, fn):
        """Compare a pure translation done by ``fn`` on a tensor with ``F.affine`` on PIL.

        Args:
            device: Device passed to the data-creation helper.
            height: Height of the generated data.
            width: Width of the generated data.
            dt: Optional dtype the tensor is cast to; ``None`` keeps it as created.
            t: Translation passed as ``translate``.
            fn: Callable applied to the tensor with the same affine arguments.
        """

        def affine_kwargs():
            # Fresh dict (and fresh shear list) per call: angle 0, scale 1, no shear.
            return {
                "angle": 0,
                "translate": t,
                "scale": 1.0,
                "shear": [0.0, 0.0],
                "interpolation": NEAREST,
            }

        # 3) Test translation
        tensor, pil_img = _create_data(height, width, device=device)

        # The float16-on-CPU combination is not exercised: leave without asserting.
        skip_case = dt == torch.float16 and device == "cpu"
        if skip_case:
            return

        if dt is not None:
            tensor = tensor.to(dtype=dt)

        expected_pil = F.affine(pil_img, **affine_kwargs())
        result = fn(tensor, **affine_kwargs())

        # The PIL comparison is done on uint8 data, so cast the result if needed.
        result_u8 = result if result.dtype == torch.uint8 else result.to(torch.uint8)

        _assert_equal_tensor_to_pil(result_u8, expected_pil)
コード例 #10
0
    def _test_op_list_output(self,
                             func,
                             method,
                             out_length,
                             fn_kwargs=None,
                             meth_kwargs=None):
        """Check an op that returns a list of outputs, as a function and as a class.

        Args:
            func: Name of the functional op, looked up on ``F``.
            method: Name of the transform class, looked up on ``T``.
            out_length: Expected number of outputs.
            fn_kwargs: Keyword arguments for the functional op; ``None`` means none.
            meth_kwargs: Constructor keyword arguments for the class; ``None`` means none.
        """
        fn_kwargs = {} if fn_kwargs is None else fn_kwargs
        meth_kwargs = {} if meth_kwargs is None else meth_kwargs

        functional = getattr(F, func)
        scripted_functional = torch.jit.script(functional)

        tensor, pil_img = _create_data(height=20, width=20, device=self.device)
        tensor_outputs = functional(tensor, **fn_kwargs)
        pil_outputs = functional(pil_img, **fn_kwargs)
        self.assertEqual(len(tensor_outputs), len(pil_outputs))
        self.assertEqual(len(tensor_outputs), out_length)
        for tensor_out, pil_out in zip(tensor_outputs, pil_outputs):
            _assert_equal_tensor_to_pil(tensor_out, pil_out)

        # The scripted functional op gets a detached copy of the input.
        scripted_outputs = scripted_functional(tensor.detach().clone(), **fn_kwargs)
        self.assertEqual(len(tensor_outputs), len(scripted_outputs))
        self.assertEqual(len(scripted_outputs), out_length)
        for tensor_out, scripted_out in zip(tensor_outputs, scripted_outputs):
            mismatch_msg = "{} vs {}".format(tensor_out, scripted_out)
            assert_equal(tensor_out, scripted_out, msg=mismatch_msg)

        # Class interface: scripted instance must yield the same number of outputs.
        transform = getattr(T, method)(**meth_kwargs)
        scripted_transform = torch.jit.script(transform)
        class_output = scripted_transform(tensor)
        self.assertEqual(len(class_output), len(scripted_outputs))

        # Batch of tensors: each sample processed alone must match its slice of the batch result.
        batch_tensors = _create_data_batch(height=23,
                                           width=34,
                                           channels=3,
                                           num_samples=4,
                                           device=self.device)
        torch.manual_seed(12)
        batch_outputs = transform(batch_tensors)

        for sample_idx in range(len(batch_tensors)):
            sample = batch_tensors[sample_idx, ...]
            torch.manual_seed(12)
            sample_outputs = transform(sample)
            for sample_out, batch_out in zip(sample_outputs, batch_outputs):
                batch_slice = batch_out[sample_idx, ...]
                assert_equal(
                    sample_out,
                    batch_slice,
                    msg="{} vs {}".format(sample_out, batch_slice),
                )

        # Saving the scripted transform to a temporary directory must succeed as well.
        with get_tmp_dir() as tmp_dir:
            save_path = os.path.join(tmp_dir, "t_op_list_{}.pt".format(method))
            scripted_transform.save(save_path)