def test_extract_batch_size():
    """extract_batch_size reads dim 0 of a 4D tensor and rejects a 3D one."""
    expected = 3
    batched = torch.zeros(expected, 1, 1, 1)
    assert image_.extract_batch_size(batched) == expected

    # A 3D tensor has no batch dimension and must be rejected.
    unbatched = torch.zeros(1, 1, 1)
    with pytest.raises(TypeError):
        image_.extract_batch_size(unbatched)
def batch_up_image(
    image: torch.Tensor,
    desired_batch_size: Optional[int] = None,
    loader: Optional[DataLoader] = None,
) -> torch.Tensor:
    """Repeat an unbatched (or batch-size-1) image to form a batch.

    Args:
        image: Image to batch up. If ``is_single_image`` reports it as a single
            image it is first passed through ``make_batched_image``; otherwise
            its batch size (per ``extract_batch_size``) must not exceed 1.
        desired_batch_size: Number of repetitions along the first dimension.
            If ``None``, it is derived from ``loader``.
        loader: Data loader consulted when ``desired_batch_size`` is ``None``.
            ``loader.batch_size`` is used if set, otherwise
            ``loader.batch_sampler.batch_size``.

    Returns:
        The batched image repeated ``desired_batch_size`` times along the first
        dimension (``image.repeat(desired_batch_size, 1, 1, 1)``).

    Raises:
        RuntimeError: If neither ``desired_batch_size`` nor ``loader`` is given,
            if no integer batch size can be obtained from ``loader``, or if
            ``image`` is already batched with a batch size larger than 1.
    """

    def extract_batch_size_from_loader(loader: DataLoader) -> int:
        # Return the batch size the loader yields, or raise RuntimeError.
        batch_size = loader.batch_size
        if batch_size is not None:
            return batch_size

        # DataLoader.batch_size is None e.g. when a custom batch_sampler is
        # passed, so fall back to the sampler's own batch_size attribute.
        try:
            batch_size = loader.batch_sampler.batch_size  # type: ignore[union-attr]
        except AttributeError as error:
            raise RuntimeError(
                "Unable to determine the batch size from the loader: "
                "loader.batch_size is None and loader.batch_sampler has no "
                "'batch_size' attribute."
            ) from error

        # Explicit check instead of `assert`: asserts are stripped under `python -O`.
        if not isinstance(batch_size, int):
            raise RuntimeError(
                "Unable to determine the batch size from the loader: expected "
                "loader.batch_sampler.batch_size to be an int, got "
                f"{type(batch_size).__name__}."
            )
        return batch_size

    if desired_batch_size is None and loader is None:
        raise RuntimeError("Either desired_batch_size or loader has to be passed.")

    if desired_batch_size is None:
        desired_batch_size = extract_batch_size_from_loader(cast(DataLoader, loader))

    if is_single_image(image):
        image = make_batched_image(image)
    elif extract_batch_size(image) > 1:
        raise RuntimeError("image is already batched with a batch size larger than 1.")

    return image.repeat(desired_batch_size, 1, 1, 1)
def test_batch_up_image_loader(image):
    """batch_up_image takes the batch size from the loader's batch_size."""
    loader_batch_size = 3
    loader = DataLoader((), batch_size=loader_batch_size)

    result = utils.batch_up_image(image, loader=loader)

    assert extract_batch_size(result) == loader_batch_size
def calculate_score(self, input_repr, target_repr, ctx):
    """Return the parent class's score, optionally divided by the batch size.

    When ``self.double_batch_size_mean`` is truthy, the score computed by the
    superclass is divided by the batch size of ``input_repr`` (as reported by
    ``extract_batch_size``). Otherwise the score is returned unchanged.
    """
    # NOTE(review): the flag name suggests the superclass already averages over
    # the batch, making this a second division by the batch size. Confirm.
    score = super().calculate_score(input_repr, target_repr, ctx)
    if self.double_batch_size_mean:
        score = score / extract_batch_size(input_repr)
    return score
def test_batch_up_image_loader_with_batch_sampler(image):
    """batch_up_image falls back to the batch size of the loader's batch_sampler."""
    sampler_batch_size = 3
    empty_dataset = ()
    sampler = BatchSampler(
        SequentialSampler(empty_dataset), sampler_batch_size, drop_last=False
    )
    loader = DataLoader(empty_dataset, batch_sampler=sampler)

    result = utils.batch_up_image(image, loader=loader)

    assert extract_batch_size(result) == sampler_batch_size
def batch_up_image(
    image: torch.Tensor,
    desired_batch_size: Optional[int] = None,
    loader: Optional[DataLoader] = None,
) -> torch.Tensor:
    """Repeat an unbatched (or batch-size-1) image to form a batch.

    Args:
        image: Image to batch up. If ``is_single_image`` reports it as a single
            image it is first passed through ``make_batched_image``; otherwise
            its batch size (per ``extract_batch_size``) must not exceed 1.
        desired_batch_size: Number of repetitions along the first dimension.
            If ``None``, it is derived from ``loader``.
        loader: Data loader consulted when ``desired_batch_size`` is ``None``.
            ``loader.batch_size`` is used if set, otherwise
            ``loader.batch_sampler.batch_size``.

    Returns:
        The batched image repeated ``desired_batch_size`` times along the first
        dimension (``image.repeat(desired_batch_size, 1, 1, 1)``).

    Raises:
        RuntimeError: If neither ``desired_batch_size`` nor ``loader`` is given,
            if ``image`` is already batched with a batch size larger than 1, or
            if no integer batch size can be obtained from ``loader``.
    """
    if desired_batch_size is None and loader is None:
        raise RuntimeError("Either desired_batch_size or loader has to be passed.")

    if is_single_image(image):
        image = make_batched_image(image)
    elif extract_batch_size(image) > 1:
        raise RuntimeError("image is already batched with a batch size larger than 1.")

    if desired_batch_size is None:
        desired_batch_size = loader.batch_size
    if desired_batch_size is None:
        # DataLoader.batch_size is None e.g. when a custom batch_sampler is
        # passed, so fall back to the sampler's own batch_size attribute.
        try:
            desired_batch_size = loader.batch_sampler.batch_size
        except AttributeError as error:
            raise RuntimeError(
                "Unable to determine the batch size from the loader: "
                "loader.batch_size is None and loader.batch_sampler has no "
                "'batch_size' attribute."
            ) from error
        # Without this check a None/non-int value would only surface later as
        # an unrelated TypeError from image.repeat().
        if not isinstance(desired_batch_size, int):
            raise RuntimeError(
                "Unable to determine the batch size from the loader: expected "
                "loader.batch_sampler.batch_size to be an int, got "
                f"{type(desired_batch_size).__name__}."
            )

    return image.repeat(desired_batch_size, 1, 1, 1)
def test_batch_up_image_with_single_image(image):
    """batch_up_image accepts an unbatched image and batches it up."""
    requested = 3
    single = make_single_image(image)

    result = utils.batch_up_image(single, requested)

    assert extract_batch_size(result) == requested
def test_batch_up_image(image):
    """batch_up_image repeats a batch-size-1 image to the requested batch size."""
    requested = 3
    assert extract_batch_size(utils.batch_up_image(image, requested)) == requested