class Tester(unittest.TestCase):
    """``ImageFolder`` tests against the fixture tree in ``test/assets/dataset``.

    The fixture holds two class folders, ``a`` and ``b``; the image files each
    folder is expected to contain are listed in the class attributes below.
    """

    root = get_file_path_2('test/assets/dataset/')
    classes = ['a', 'b']
    class_a_images = [
        get_file_path_2(os.path.join('test/assets/dataset/a/', fname))
        for fname in ('a1.png', 'a2.png', 'a3.png')
    ]
    class_b_images = [
        get_file_path_2(os.path.join('test/assets/dataset/b/', fname))
        for fname in ('b1.png', 'b2.png', 'b3.png', 'b4.png')
    ]

    def test_image_folder(self):
        """Check class discovery, the sample listing and item access."""
        ds = ImageFolder(Tester.root, loader=lambda p: p)

        # Every expected class is found, and classes/class_to_idx agree.
        self.assertEqual(sorted(Tester.classes), sorted(ds.classes))
        for name in Tester.classes:
            idx = ds.class_to_idx[name]
            self.assertEqual(name, ds.classes[idx])

        # Each image path is paired with the index of its folder's class.
        idx_a, idx_b = ds.class_to_idx['a'], ds.class_to_idx['b']
        expected = sorted(
            [(p, idx_a) for p in Tester.class_a_images]
            + [(p, idx_b) for p in Tester.class_b_images]
        )
        self.assertEqual(expected, ds.imgs)

        # With the identity loader, indexing yields the same (path, idx) pairs.
        fetched = sorted(ds[i] for i in range(len(ds)))
        self.assertEqual(expected, fetched)

    def test_transform(self):
        """Check that ``transform`` is applied to every loaded sample."""
        sentinel = get_file_path_2('test/assets/dataset/a/a1.png')
        seen = []
        dataset = ImageFolder(Tester.root, loader=lambda x: x,
                              transform=mock_transform(sentinel, seen))

        samples = [dataset[i][0] for i in range(len(dataset))]
        self.assertEqual([sentinel] * len(samples), samples)

        # The test expects mock_transform to have recorded every raw sample
        # (here: the image path, because the loader is the identity).
        expected_inputs = sorted(Tester.class_a_images + Tester.class_b_images)
        self.assertEqual(expected_inputs, sorted(seen))

    def test_target_transform(self):
        """Check that ``target_transform`` is applied to every class index."""
        replacement = 1
        received = []
        dataset = ImageFolder(Tester.root, loader=lambda x: x,
                              target_transform=mock_transform(replacement, received))

        targets_out = [dataset[i][1] for i in range(len(dataset))]
        self.assertEqual([replacement] * len(targets_out), targets_out)

        # The test expects mock_transform to have recorded each original class
        # index it was called with: one per image in the folder.
        idx_a = dataset.class_to_idx['a']
        idx_b = dataset.class_to_idx['b']
        expected = ([idx_a] * len(Tester.class_a_images)
                    + [idx_b] * len(Tester.class_b_images))
        self.assertEqual(sorted(expected), sorted(received))
Exemplo n.º 2
0
    def test_imagefolder(self):
        """Check ``ImageFolder`` on a temporary copy of the ``imagefolder`` fixture.

        Verifies class discovery, the ``classes``/``class_to_idx`` round trip,
        the detected ``(path, class_index)`` samples and item access. It does
        this first for all files, then with an ``is_valid_file`` filter that
        keeps only files whose *name* contains ``'3'``.
        """
        # TODO: create the fake data on-the-fly
        FAKEDATA_DIR = get_file_path_2(
            os.path.dirname(os.path.abspath(__file__)), 'assets', 'fakedata')

        def name_has_3(path):
            # Look at the file name only. The dataset root comes from
            # get_tmp_dir and its path may itself contain a '3'. Testing the
            # full path would then accept every file and make the filtered
            # checks below vacuous.
            return '3' in os.path.basename(path)

        def expected_samples(dataset, a_files, b_files):
            # Sorted (path, class_index) pairs, using the dataset's own indices.
            class_a_idx = dataset.class_to_idx['a']
            class_b_idx = dataset.class_to_idx['b']
            return sorted([(f, class_a_idx) for f in a_files] +
                          [(f, class_b_idx) for f in b_files])

        def check_samples(dataset, expected):
            # The detected sample list must match exactly.
            self.assertEqual(expected, dataset.imgs)
            # With the identity loader, indexing yields the same pairs.
            outputs = sorted([dataset[i] for i in range(len(dataset))])
            self.assertEqual(expected, outputs)

        with get_tmp_dir(
                src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
            classes = sorted(['a', 'b'])
            class_a_image_files = [
                os.path.join(root, 'a', file)
                for file in ('a1.png', 'a2.png', 'a3.png')
            ]
            class_b_image_files = [
                os.path.join(root, 'b', file)
                for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')
            ]
            dataset = torchvision.datasets.ImageFolder(root,
                                                       loader=lambda x: x)

            # test if all classes are present
            self.assertEqual(classes, sorted(dataset.classes))

            # test if combination of classes and class_to_index functions correctly
            for cls in classes:
                self.assertEqual(cls,
                                 dataset.classes[dataset.class_to_idx[cls]])

            # test if all images were detected and are returned correctly
            check_samples(dataset,
                          expected_samples(dataset, class_a_image_files,
                                           class_b_image_files))

            # redo the tests with only some files marked as valid
            dataset = torchvision.datasets.ImageFolder(
                root, loader=lambda x: x, is_valid_file=name_has_3)
            self.assertEqual(classes, sorted(dataset.classes))
            check_samples(
                dataset,
                expected_samples(
                    dataset,
                    [f for f in class_a_image_files if name_has_3(f)],
                    [f for f in class_b_image_files if name_has_3(f)]))
Exemplo n.º 3
0
    def _test_serialization_container(self, unique_key, filecontext_lambda):
        """Check that ``torch.load`` detects a changed class source.

        A ``Net`` from ``network1.py`` is saved under a temporary module name
        and loaded back. When ``can_retrieve_source`` is set, this load must
        raise no warnings. The same module name is then re-registered with
        ``network2.py`` and the checkpoint is loaded again. When
        ``can_retrieve_source`` is set, that load must emit exactly one
        ``SourceChangeWarning``.

        Args:
            unique_key: suffix that makes the temporary module name unique.
            filecontext_lambda: zero-argument callable returning a context
                manager that yields a writable, seekable checkpoint object.
        """
        tmpmodule_name = 'tmpmodule{}'.format(unique_key)

        def data_file(name):
            # Path of a helper file under torch/testing/_internal/data.
            return get_file_path_2(
                os.path.dirname(os.path.dirname(torch.__file__)), 'torch',
                'testing', '_internal', 'data', name)

        def import_module(name, filename):
            # Load `filename` as module `name` and register it in sys.modules
            # so that classes pickled from it can be resolved by module name.
            import importlib.util
            spec = importlib.util.spec_from_file_location(name, filename)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            sys.modules[module.__name__] = module
            return module

        with filecontext_lambda() as checkpoint:
            module = import_module(tmpmodule_name, data_file('network1.py'))
            torch.save(module.Net(), checkpoint)

            # First check that the checkpoint can be loaded without warnings
            checkpoint.seek(0)
            with warnings.catch_warnings(record=True) as w:
                loaded = torch.load(checkpoint)
                self.assertIsInstance(loaded, module.Net)
                if can_retrieve_source:
                    self.assertEqual(len(w), 0)

            # Replace the module with different source
            module = import_module(tmpmodule_name, data_file('network2.py'))
            checkpoint.seek(0)
            with warnings.catch_warnings(record=True) as w:
                loaded = torch.load(checkpoint)
                self.assertIsInstance(loaded, module.Net)
                if can_retrieve_source:
                    self.assertEqual(len(w), 1)
                    # Compare the category by name: assertTrue(category, msg)
                    # treated the string as a message and always passed.
                    self.assertEqual(w[0].category.__name__,
                                     'SourceChangeWarning')
Exemplo n.º 4
0
    def test_transform(self):
        """Every sample passes through ``transform``, which sees the raw paths."""
        expected_output = get_file_path_2('test/assets/dataset/a/a1.png')
        received = []
        transform = mock_transform(expected_output, received)
        dataset = ImageFolder(Tester.root, loader=lambda x: x, transform=transform)

        outputs = []
        for index in range(len(dataset)):
            outputs.append(dataset[index][0])
        self.assertEqual([expected_output] * len(outputs), outputs)

        # The identity loader means the transform's recorded inputs should be
        # exactly the fixture image paths (order-independent).
        self.assertEqual(sorted(Tester.class_a_images + Tester.class_b_images),
                         sorted(received))
Exemplo n.º 5
0
    def test_transform(self):
        # The loader is the identity, so `transform` receives raw file paths.
        # The test expects every transformed sample to equal `marker`.
        marker = get_file_path_2('test/assets/dataset/a/a1.png')
        calls = []
        dataset = ImageFolder(
            Tester.root,
            loader=lambda path: path,
            transform=mock_transform(marker, calls),
        )
        size = len(dataset)
        transformed = [dataset[k][0] for k in range(size)]
        self.assertEqual([marker] * len(transformed), transformed)

        all_images = Tester.class_a_images + Tester.class_b_images
        self.assertEqual(sorted(all_images), sorted(calls))
Exemplo n.º 6
0
import torchvision.datasets.utils as utils
import unittest
import unittest.mock
import zipfile
import tarfile
import gzip
import warnings
from torch._utils_internal import get_file_path_2
from urllib.error import URLError
import itertools
import lzma

from common_utils import get_tmp_dir, call_args_to_kwargs_only

# ``os`` is not among this snippet's imports; bring it into scope here so
# building TEST_FILE does not raise NameError.
import os

# Absolute path of the JPEG fixture, resolved relative to this file.
TEST_FILE = get_file_path_2(os.path.dirname(os.path.abspath(__file__)),
                            'assets', 'encode_jpeg',
                            'grace_hopper_517x606.jpg')


class Tester(unittest.TestCase):
    def test_check_md5(self):
        """``check_md5`` accepts TEST_FILE's real digest and rejects an empty one."""
        expected_digest = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
        self.assertTrue(utils.check_md5(TEST_FILE, expected_digest))
        self.assertFalse(utils.check_md5(TEST_FILE, ''))

    def test_check_integrity(self):
        """Exercise ``utils.check_integrity`` with and without an MD5 digest.

        Expected contract (per torchvision's documented behavior): a missing
        file fails; an existing file passes when no MD5 is given; otherwise the
        file's digest must match the given MD5.
        """
        existing_fpath = TEST_FILE
        nonexisting_fpath = ''
        correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
        false_md5 = ''
        self.assertTrue(utils.check_integrity(existing_fpath, correct_md5))
        self.assertFalse(utils.check_integrity(existing_fpath, false_md5))
        self.assertTrue(utils.check_integrity(existing_fpath))
        self.assertFalse(utils.check_integrity(nonexisting_fpath))
Exemplo n.º 7
0
import contextlib
import gzip
import os
import tarfile
import zipfile

import pytest
import torchvision.datasets.utils as utils
from torch._utils_internal import get_file_path_2
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS

# Absolute path of the JPEG fixture, resolved next to this test module.
TEST_FILE = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)),
    "assets",
    "encode_jpeg",
    "grace_hopper_517x606.jpg",
)


def patch_url_redirection(mocker, redirect_url):
    """Patch ``urlopen`` in ``torchvision.datasets.utils`` to fake a redirect.

    Every call to the patched ``urlopen`` acts as a context manager yielding a
    response object whose ``url`` attribute is ``redirect_url``.

    Returns:
        Whatever ``mocker.patch`` returns for the patched target.
    """
    class _RedirectedResponse:
        # Stand-in response exposing only the final URL.
        def __init__(self, url):
            self.url = url

    def _open(*args, **kwargs):
        yield _RedirectedResponse(redirect_url)

    fake_urlopen = contextlib.contextmanager(_open)
    target = "torchvision.datasets.utils.urllib.request.urlopen"
    return mocker.patch(target, side_effect=fake_urlopen)


class TestDatasetsUtils:
    def test_get_redirect_url(self, mocker):
Exemplo n.º 8
0
import unittest
import math
import random
import numpy as np
from PIL import Image
try:
    import accimage
except ImportError:
    accimage = None

try:
    from scipy import stats
except ImportError:
    stats = None

import os

from torch._utils_internal import get_file_path_2

# Resolve the fixture relative to this file rather than the current working
# directory, so the tests can be launched from any directory.
GRACE_HOPPER = get_file_path_2(os.path.dirname(os.path.abspath(__file__)),
                               'assets', 'grace_hopper_517x606.jpg')


class Tester(unittest.TestCase):

    def test_crop(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) / 2) * 2
        owidth = random.randint(5, (width - 2) / 2) * 2

        img = torch.ones(3, height, width)
        oh1 = (height - oheight) // 2
        ow1 = (width - owidth) // 2
        imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth]
        imgnarrow.fill_(0)