def test_call(self):
    """Exercise BatchCollator on each kind of input the test feeds it.

    Covers an already-built sample list, a list of individual ``Sample``
    objects, and a single-element list wrapping a sample list (the
    IterableDataset path).
    """
    collator = BatchCollator("vqa2", "train")

    # An already-built sample list comes back carrying the collator's
    # dataset name and type.
    built = collator(test_utils.build_random_sample_list())
    self.assertEqual(built.dataset_name, "vqa2")
    self.assertEqual(built.dataset_type, "train")

    # A list of samples is combined so that field ``a`` holds one row per sample.
    single = Sample()
    single.a = torch.tensor([1, 2], dtype=torch.int)
    stacked = collator([single, single])
    expected = torch.tensor([[1, 2], [1, 2]], dtype=torch.int)
    self.assertTrue(test_utils.compare_tensors(stacked.a, expected))

    # IterableDataset case: a one-element list holding a sample list should
    # collate to something equal to that sample list.
    wrapped = test_utils.build_random_sample_list()
    self.assertEqual(collator([wrapped]), wrapped)
def test_pin_memory(self):
    """Check that ``pin_memory`` pins ``y``/``z.y`` and leaves ``x``/``z.x`` unpinned.

    NOTE(review): pinning memory in torch generally requires an accelerator;
    confirm this test is skipped on CPU-only hosts (any skip decorator would
    sit outside this view).
    """
    sample_list = test_utils.build_random_sample_list()
    sample_list.pin_memory()

    # Fields expected to be pinned after the call.
    pinned_fields = [sample_list.y, sample_list.z.y]
    # Fields expected to stay unpinned. They may not even expose
    # ``is_pinned``, hence the hasattr guard below.
    unpinned_fields = [sample_list.x, sample_list.z.x]

    self.assertTrue(all(field.is_pinned() for field in pinned_fields))
    self.assertFalse(
        any(
            hasattr(field, "is_pinned") and field.is_pinned()
            for field in unpinned_fields
        )
    )
def test_to_device(self):
    """Check ``to_device`` across several cases.

    Covers string and ``torch.device`` targets, the CUDA request (which
    must yield CPU when no GPU is available), the same-device no-op path,
    and passthrough of a plain list batch.
    """
    cpu = torch.device("cpu")
    sample_list = test_utils.build_random_sample_list()

    # Both a device string and a torch.device object should land on CPU.
    for target in ("cpu", cpu):
        self.assertEqual(to_device(sample_list, target).get_device(), cpu)

    # Asking for CUDA: expect cuda:0 when available, otherwise CPU.
    moved = to_device(sample_list, "cuda")
    expected = torch.device("cuda:0") if torch.cuda.is_available() else cpu
    self.assertEqual(moved.get_device(), expected)

    # Moving to the device it already reports must return the very same object.
    self.assertIs(to_device(moved, moved.get_device()), moved)

    # A plain list-of-dicts batch must compare equal after to_device.
    custom_batch = [{"a": 1}]
    self.assertEqual(to_device(custom_batch), custom_batch)
def test_to_dict(self):
    """Check that ``to_dict`` returns a plain dict containing every expected key path.

    Dotted paths such as ``"z.x"`` are walked level by level. A missing
    segment is reported as an assertion failure naming the full path,
    rather than surfacing as a ``KeyError`` test error.
    """
    sample_list = test_utils.build_random_sample_list()
    sample_dict = sample_list.to_dict()
    self.assertIsInstance(sample_dict, dict)
    # A plain dict no longer supports attribute-style access to its keys.
    self.assertFalse(hasattr(sample_dict, "x"))

    keys_to_assert = ["x", "y", "z", "z.x", "z.y"]
    for key in keys_to_assert:
        current = sample_dict
        for sub_key in key.split("."):
            # Assert before indexing so a missing key fails cleanly instead
            # of raising KeyError on the next line.
            self.assertIn(sub_key, current, msg="missing key path {!r}".format(key))
            current = current[sub_key]