Пример #1
0
    def test_call(self):
        """Exercise ``BatchCollator`` on three kinds of input.

        Covers an already built sample list, a plain list of ``Sample``
        objects, and a one-element list wrapping a sample list.
        """
        collator = BatchCollator("vqa2", "train")

        # Scenario 1: an already built sample list gets tagged with the
        # dataset name and type the collator was constructed with.
        prebuilt = test_utils.build_random_sample_list()
        collated = collator(prebuilt)
        self.assertEqual(collated.dataset_name, "vqa2")
        self.assertEqual(collated.dataset_type, "train")

        # Scenario 2: a list of samples; field ``a`` of the result must match
        # the two per-sample tensors arranged as rows.
        item = Sample()
        item.a = torch.tensor([1, 2], dtype=torch.int)
        stacked = collator([item, item])
        expected_a = torch.tensor([[1, 2], [1, 2]], dtype=torch.int)
        self.assertTrue(test_utils.compare_tensors(stacked.a, expected_a))

        # Scenario 3 (IterableDataset case): a sample list wrapped in a
        # one-element list must compare equal to the wrapped sample list.
        wrapped = test_utils.build_random_sample_list()
        unwrapped = collator([wrapped])
        self.assertEqual(unwrapped, wrapped)
Пример #2
0
    def test_pin_memory(self):
        """Check which fields report pinned memory after ``pin_memory()``.

        ``y`` and ``z.y`` must report ``is_pinned()``; ``x`` and ``z.x`` must
        not. Each field is asserted on its own so that a failure names the
        offending field instead of collapsing into a single boolean.
        """
        sample_list = test_utils.build_random_sample_list()
        sample_list.pin_memory()

        expected_pinned = (("y", sample_list.y), ("z.y", sample_list.z.y))
        expected_unpinned = (("x", sample_list.x), ("z.x", sample_list.z.x))

        for name, value in expected_pinned:
            self.assertTrue(
                value.is_pinned(), msg="field %r should be pinned" % name
            )

        for name, value in expected_unpinned:
            # The hasattr guard tolerates values that do not expose
            # ``is_pinned`` at all; those count as not pinned.
            is_pinned = hasattr(value, "is_pinned") and value.is_pinned()
            self.assertFalse(
                is_pinned, msg="field %r should not be pinned" % name
            )
    def test_to_device(self):
        """Check the device reported after ``to_device`` for cpu and cuda targets.

        Also covers moving a batch to the device it already reports (the same
        object must come back) and passing a plain list-of-dicts batch.
        """
        sample_list = test_utils.build_random_sample_list()
        cpu = torch.device("cpu")

        # The target may be given either as a string or as a torch.device.
        for target in ("cpu", torch.device("cpu")):
            moved = to_device(sample_list, target)
            self.assertEqual(moved.get_device(), cpu)

        moved = to_device(sample_list, "cuda")

        # A "cuda" request must report cuda:0 when CUDA is available and
        # cpu otherwise.
        expected = torch.device("cuda:0") if torch.cuda.is_available() else cpu
        self.assertEqual(moved.get_device(), expected)

        # Moving to the device already reported must return the same object.
        moved_again = to_device(moved, moved.get_device())
        self.assertTrue(moved_again is moved)

        # A batch that is a plain list of dicts must come back equal to the
        # input when no device is given.
        plain_batch = [{"a": 1}]
        self.assertEqual(to_device(plain_batch), plain_batch)
    def test_to_dict(self):
        """Check that ``to_dict()`` returns a plain dict with the expected keys.

        Dotted entries such as ``"z.x"`` are walked one level at a time. Each
        level is asserted with ``assertIn`` so that a missing key is reported
        as an assertion failure naming the key path, rather than surfacing as
        a ``KeyError`` from the lookup that follows.
        """
        sample_list = test_utils.build_random_sample_list()
        sample_dict = sample_list.to_dict()
        self.assertIsInstance(sample_dict, dict)
        # Attribute-style access (hasattr) must no longer work on the result.
        self.assertFalse(hasattr(sample_dict, "x"))

        for key in ("x", "y", "z", "z.x", "z.y"):
            current = sample_dict
            # str.split returns a one-element list for keys without a dot, so
            # flat and nested keys share the same traversal.
            for sub_key in key.split("."):
                self.assertIn(
                    sub_key, current, msg="missing key path %r" % key
                )
                current = current[sub_key]