def test_batch_up_image_loader(image):
    """A loader's ``batch_size`` determines how far the image is batched up."""
    expected_batch_size = 3
    # The dataset contents are irrelevant; only the loader's batch size is read.
    empty_dataset = ()
    data_loader = DataLoader(empty_dataset, batch_size=expected_batch_size)

    result = utils.batch_up_image(image, loader=data_loader)

    assert extract_batch_size(result) == expected_batch_size
def test_batch_up_image_loader_with_batch_sampler_no_batch_size(subtests, image):
    """``batch_up_image`` raises ``RuntimeError`` for batch samplers without a usable
    ``batch_size``: either no such attribute set, or one set to ``None``.
    """

    class NoBatchSizeBatchSampler(BatchSampler):
        # Skips BatchSampler.__init__ entirely, so batch_size is never assigned.
        def __init__(self):
            pass

    class WrongTypeBatchSizeBatchSampler(BatchSampler):
        # batch_size is assigned, but with None instead of an integer.
        def __init__(self):
            self.batch_size = None

    scenarios = (
        ("no batch_size", NoBatchSizeBatchSampler),
        ("wrong type batch_size", WrongTypeBatchSizeBatchSampler),
    )
    for label, sampler_cls in scenarios:
        with subtests.test(label):
            data_loader = DataLoader((), batch_sampler=sampler_cls())
            with pytest.raises(RuntimeError):
                utils.batch_up_image(image, loader=data_loader)
def test_batch_up_image_loader_with_batch_sampler(image):
    """A loader driven by a ``BatchSampler`` yields the sampler's batch size."""
    expected_batch_size = 3
    empty_dataset = ()
    sampler = BatchSampler(
        SequentialSampler(empty_dataset), expected_batch_size, drop_last=False
    )
    data_loader = DataLoader(empty_dataset, batch_sampler=sampler)

    result = utils.batch_up_image(image, loader=data_loader)

    assert extract_batch_size(result) == expected_batch_size
def test_training_criterion_style_image(
    preprocessor_mocks,
    optimizer_mocks,
    style_transforms_mocks,
    transformer_mocks,
    training,
    image_loader,
    style_image,
):
    """The criterion handed over by ``training`` gets the style image as its
    style-loss target. The target is batched up to the loader's batch size,
    style-transformed and preprocessed.
    """
    # optimizer_mocks / transformer_mocks are requested only for their fixture
    # side effects; their values are not used here.
    preprocessor = preprocessor_mocks[1]
    style_transform = style_transforms_mocks[1]

    call_args = training(image_loader, style_image)[0]
    # Third positional argument; presumably the criterion passed to the mocked
    # optimization routine — confirm against the `training` fixture.
    criterion = call_args[2]

    actual_target = criterion.style_loss.target_image
    # NOTE(review): bare `batch_up_image` here, whereas sibling tests use
    # `utils.batch_up_image`; presumably imported directly in this module — confirm.
    batched_style_image = batch_up_image(style_image, image_loader.batch_size)
    expected_target = preprocessor(style_transform(batched_style_image))

    ptu.assert_allclose(actual_target, expected_target)
def test_batch_up_image_missing_arg(image):
    """Passing neither a batch size nor a loader raises ``RuntimeError``."""
    with pytest.raises(RuntimeError):
        # Intentionally omit both the positional batch size and ``loader=``.
        utils.batch_up_image(image)
def test_batch_up_image_with_batched_image(batch_image):
    """An image that is already batched is rejected with ``RuntimeError``."""
    requested_batch_size = 2
    with pytest.raises(RuntimeError):
        utils.batch_up_image(batch_image, requested_batch_size)
def test_batch_up_image_with_single_image(image):
    """A single (unbatched) image is batched up to the requested size."""
    expected_batch_size = 3
    single_image = make_single_image(image)

    result = utils.batch_up_image(single_image, expected_batch_size)

    assert extract_batch_size(result) == expected_batch_size
def test_batch_up_image(image):
    """An explicit positional batch size determines the resulting batch size."""
    expected_batch_size = 3

    result = utils.batch_up_image(image, expected_batch_size)

    assert extract_batch_size(result) == expected_batch_size