def batch_up_image(
    image: torch.Tensor,
    desired_batch_size: Optional[int] = None,
    loader: Optional[DataLoader] = None,
) -> torch.Tensor:
    """Repeat an image along the leading (batch) dimension.

    Args:
        image: Single image, or batched image whose batch size is at most 1.
        desired_batch_size: Number of copies in the returned batch. If
            ``None``, the batch size is extracted from ``loader``.
        loader: Data loader to extract the batch size from when
            ``desired_batch_size`` is not given.

    Returns:
        ``image`` (batched via ``make_batched_image`` if it was a single image)
        repeated ``desired_batch_size`` times along the first dimension.

    Raises:
        RuntimeError: If neither ``desired_batch_size`` nor ``loader`` is
            given, if no integer batch size can be extracted from ``loader``,
            or if ``image`` already has a batch size greater than 1.
    """

    def extract_batch_size_from_loader(loader: DataLoader) -> int:
        # Prefer the loader's own batch size; fall back to its batch sampler
        # when the loader does not report one.
        batch_size = loader.batch_size
        if batch_size is not None:
            return batch_size

        try:
            batch_size = loader.batch_sampler.batch_size  # type: ignore[union-attr]
        except AttributeError as error:
            raise RuntimeError(
                "Unable to extract the batch size from the loader: neither "
                "loader.batch_size nor loader.batch_sampler.batch_size is "
                "available."
            ) from error

        # Explicit check rather than `assert`, so the validation is not
        # stripped when Python runs with -O.
        if not isinstance(batch_size, int):
            raise RuntimeError(
                "Expected loader.batch_sampler.batch_size to be an int, but "
                f"got {type(batch_size).__name__}."
            )
        return batch_size

    if desired_batch_size is None and loader is None:
        raise RuntimeError(
            "Either desired_batch_size or loader has to be passed."
        )

    if desired_batch_size is None:
        desired_batch_size = extract_batch_size_from_loader(
            cast(DataLoader, loader)
        )

    if is_single_image(image):
        image = make_batched_image(image)
    elif extract_batch_size(image) > 1:
        raise RuntimeError(
            "Only single images or batched images with a batch size of 1 can "
            "be batched up."
        )

    return image.repeat(desired_batch_size, 1, 1, 1)
def get_single_and_batched_pystiche_images(pystiche_image):
    """Return ``(single, batched)`` variants of a pystiche image.

    Whichever variant the input already is, it is returned unchanged in its
    slot. The other variant is derived with ``make_batched_image`` or
    ``make_single_image``.
    """
    if not is_single_image(pystiche_image):
        batched = pystiche_image
        return make_single_image(batched), batched

    single = pystiche_image
    return single, make_batched_image(single)
def batch_up_image(
    image: torch.Tensor,
    desired_batch_size: Optional[int] = None,
    loader: Optional[DataLoader] = None,
) -> torch.Tensor:
    """Repeat an image along the leading (batch) dimension.

    Args:
        image: Single image, or batched image whose batch size is at most 1.
        desired_batch_size: Number of copies in the returned batch. If
            ``None``, the batch size is taken from ``loader``.
        loader: Data loader to take the batch size from when
            ``desired_batch_size`` is not given.

    Returns:
        ``image`` (batched via ``make_batched_image`` if it was a single image)
        repeated ``desired_batch_size`` times along the first dimension.

    Raises:
        RuntimeError: If neither ``desired_batch_size`` nor ``loader`` is
            given, if ``image`` already has a batch size greater than 1, or if
            no integer batch size can be extracted from ``loader``.
    """
    if desired_batch_size is None and loader is None:
        raise RuntimeError(
            "Either desired_batch_size or loader has to be passed."
        )

    if is_single_image(image):
        image = make_batched_image(image)
    elif extract_batch_size(image) > 1:
        raise RuntimeError(
            "Only single images or batched images with a batch size of 1 can "
            "be batched up."
        )

    if desired_batch_size is None:
        # Guaranteed non-None by the first check; cast only informs the
        # type checker.
        loader = cast(DataLoader, loader)
        desired_batch_size = loader.batch_size
        if desired_batch_size is None:
            # Fall back to the batch sampler when the loader does not report
            # a batch size itself.
            try:
                desired_batch_size = loader.batch_sampler.batch_size  # type: ignore[union-attr]
            except AttributeError as error:
                raise RuntimeError(
                    "Unable to extract the batch size from the loader: "
                    "neither loader.batch_size nor "
                    "loader.batch_sampler.batch_size is available."
                ) from error
            if not isinstance(desired_batch_size, int):
                raise RuntimeError(
                    "Expected loader.batch_sampler.batch_size to be an int, "
                    f"but got {type(desired_batch_size).__name__}."
                )

    return image.repeat(desired_batch_size, 1, 1, 1)
def test_make_batched_image():
    """make_batched_image should turn a single image into a batched image."""
    result = image_.make_batched_image(torch.empty(1, 1, 1))
    assert image_.is_batched_image(result)