def test_content_transform_grayscale_image(subtests, content_image, impl_params, instance_norm):
    """A single-channel input is transformed and replicated back to three channels."""
    content_image = F.rgb_to_grayscale(content_image)
    edge_size = 16

    hyper_parameters = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    )
    hyper_parameters.content_transform.edge_size = edge_size

    content_transform = paper.content_transform(
        impl_params=impl_params,
        instance_norm=instance_norm,
        hyper_parameters=hyper_parameters,
    )
    if instance_norm:
        utils.make_reproducible()
    actual = content_transform(content_image)

    if impl_params:
        if instance_norm:
            # Since the transform involves an uncontrollable random component, we can't
            # do a reference test here
            return
        reference_image = F.resize(content_image, edge_size)
    else:
        reference_image = transforms.CenterCrop(edge_size)(content_image)
    # The transform is expected to broadcast the grayscale result to RGB.
    desired = reference_image.repeat(1, 3, 1, 1)

    ptu.assert_allclose(actual, desired)
def test_style_loss(subtests, impl_params, instance_norm):
    """``paper.style_loss`` builds a ``MultiLayerEncodingLoss`` of ``GramLoss``es
    that mirrors the style_loss hyper parameters."""
    hyper_parameters = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    ).style_loss

    style_loss = paper.style_loss(
        impl_params=impl_params,
        instance_norm=instance_norm,
    )
    assert isinstance(style_loss, pystiche.loss.MultiLayerEncodingLoss)

    with subtests.test("losses"):
        assert all(
            isinstance(loss, paper.GramLoss) for loss in style_loss.children()
        )

    layers, layer_weights = zip(
        *((loss.encoder.layer, loss.score_weight) for loss in style_loss.children())
    )

    with subtests.test("layers"):
        assert layers == hyper_parameters.layers

    with subtests.test("layer_weights"):
        assert layer_weights == pytest.approx(hyper_parameters.layer_weights)

    with subtests.test("score_weight"):
        assert style_loss.score_weight == pytest.approx(hyper_parameters.score_weight)
def test_dataset(subtests, mocker, impl_params, instance_norm):
    """``paper.dataset`` wraps an ``ImageFolderDataset`` with the paper's
    content transform and sizing configuration."""
    hyper_parameters = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    )

    # Patch away file collection so no real directory is needed.
    mocker.patch(
        "pystiche_papers.ulyanov_et_al_2016._data.ImageFolderDataset._collect_image_files",
        return_value=[],
    )
    dataset = paper.dataset(
        "root", impl_params=impl_params, instance_norm=instance_norm
    )

    assert isinstance(dataset, torch.utils.data.IterableDataset)

    with subtests.test("dataset"):
        assert isinstance(dataset.dataset, ImageFolderDataset)

    with subtests.test("content_transform"):
        assert isinstance(dataset.transform, type(paper.content_transform()))

    with subtests.test("min_size"):
        assert dataset.min_size == hyper_parameters.content_transform.edge_size

    with subtests.test("length"):
        expected_length = hyper_parameters.num_batches * hyper_parameters.batch_size
        assert len(dataset) == expected_length
def test_stylization_smoke(stylization, postprocessor_mocks, input_image):
    """Smoke test: with mocked postprocessing, stylization resizes the input to
    the configured edge size and shifts it by 0.5."""
    edge_size = paper.hyper_parameters().content_transform.edge_size

    _, _, output_image = stylization(input_image)

    expected = F.resize(input_image, (edge_size, edge_size)) + 0.5
    ptu.assert_allclose(output_image, expected, rtol=1e-6)
def test_hyper_parameters_lr_scheduler(subtests, impl_params, instance_norm):
    """The lr_scheduler hyper parameters expose the paper's learning-rate decay.

    Bug fix: a conditional expression binds looser than ``==``, so the original
    ``assert a == 0.8 if impl_params else 0.7`` reduced to asserting the truthy
    constant ``0.7`` whenever ``impl_params`` was falsy — it could never fail
    on that branch. The expected value is now parenthesized.
    """
    hyper_parameters = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    )

    sub_params = "lr_scheduler"
    assert sub_params in hyper_parameters
    hyper_parameters = getattr(hyper_parameters, sub_params)

    with subtests.test("lr_decay"):
        assert hyper_parameters.lr_decay == (0.8 if impl_params else 0.7)
def test_hyper_parameters_optimizer(subtests, impl_params, instance_norm):
    """The optimizer hyper parameters expose the paper's learning rate.

    Bug fix: a conditional expression binds looser than ``==``, so the original
    ``assert a == 1e-3 if ... else 1e-1`` reduced to asserting the truthy
    constant ``1e-1`` whenever the condition was falsy — it could never fail on
    that branch. The expected value is now parenthesized.
    """
    hyper_parameters = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    )

    sub_params = "optimizer"
    assert sub_params in hyper_parameters
    hyper_parameters = getattr(hyper_parameters, sub_params)

    with subtests.test("lr"):
        assert hyper_parameters.lr == (
            1e-3 if impl_params and instance_norm else 1e-1
        )
def test_hyper_parameters_content_transform(subtests, impl_params, instance_norm):
    """The content_transform hyper parameters use a fixed edge size of 256."""
    top_level_params = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    )

    assert "content_transform" in top_level_params
    content_transform_params = getattr(top_level_params, "content_transform")

    with subtests.test("edge_size"):
        assert content_transform_params.edge_size == 256
def test_hyper_parameters_num_images(impl_params, instance_norm):
    """The total number of training images is batches-per-epoch times epochs.

    Bug fix: in the original, the trailing ``if impl_params else 2_000``
    applied to the *comparison* (``(num_images == X) if impl_params else
    2_000``), so for falsy ``impl_params`` the assert evaluated the truthy
    constant ``2_000`` and could never fail. The expected value is now fully
    parenthesized.
    """
    hyper_parameters = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    )

    num_batches_per_epoch = hyper_parameters.num_batches
    num_epochs = hyper_parameters.num_epochs
    num_images = num_batches_per_epoch * num_epochs

    assert num_images == (
        (50_000 if instance_norm else 3_000) if impl_params else 2_000
    )
def test_style_transform(subtests, impl_params, instance_norm):
    """``paper.style_transform`` mirrors the style_transform hyper parameters."""
    expected = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    ).style_transform

    style_transform = paper.style_transform(
        impl_params=impl_params, instance_norm=instance_norm
    )

    with subtests.test("edge_size"):
        assert style_transform.edge_size == expected.edge_size

    with subtests.test("interpolation"):
        assert style_transform.interpolation == expected.interpolation
def test_image_loader(subtests, impl_params, instance_norm):
    """``paper.image_loader`` builds a ``DataLoader`` with the configured batch size."""
    hyper_parameters = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    )

    # An empty dataset is sufficient; only loader configuration is checked.
    image_loader = paper.image_loader(
        (), impl_params=impl_params, instance_norm=instance_norm
    )

    assert isinstance(image_loader, DataLoader)

    with subtests.test("batch_size"):
        assert image_loader.batch_size == hyper_parameters.batch_size
def test_hyper_parameters_content_loss(subtests, impl_params, instance_norm):
    """The content_loss hyper parameters match the paper's configuration."""
    top_level_params = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    )

    assert "content_loss" in top_level_params
    content_loss_params = getattr(top_level_params, "content_loss")

    with subtests.test("layer"):
        assert content_loss_params.layer == "relu4_2"

    with subtests.test("score_weight"):
        expected_weight = 6e-1 if impl_params and not instance_norm else 1e0
        assert content_loss_params.score_weight == pytest.approx(expected_weight)
def test_hyper_parameters_lr_decay_delay(impl_params, instance_norm):
    """Check the number of batches before the first lr decay and between decays.

    Bug fix: in both original asserts the trailing ``if impl_params else …``
    applied to the comparison rather than to the expected value, so for falsy
    ``impl_params`` each assert reduced to a truthy constant (``1000`` /
    ``200``) and could never fail. The expected values are now parenthesized.
    """
    hyper_parameters = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    )

    num_batches = hyper_parameters.num_batches
    delay = hyper_parameters.lr_scheduler.delay

    # The first decay happens only after the delay of full epochs has elapsed.
    num_batches_before_first_decay = num_batches * (delay + 1)
    num_batches_between_decays = num_batches

    assert num_batches_before_first_decay == (
        (2_000 if instance_norm else 300) if impl_params else 1000
    )
    assert num_batches_between_decays == (
        (2_000 if instance_norm else 300) if impl_params else 200
    )
def test_hyper_parameters_style_transform(subtests, impl_params, instance_norm):
    """The style_transform hyper parameters match the paper's configuration."""
    top_level_params = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    )

    assert "style_transform" in top_level_params
    style_transform_params = getattr(top_level_params, "style_transform")

    with subtests.test("edge_size"):
        assert style_transform_params.edge_size == 256

    with subtests.test("interpolation"):
        if impl_params and instance_norm:
            expected_interpolation = InterpolationMode.BICUBIC
        else:
            expected_interpolation = InterpolationMode.BILINEAR
        assert style_transform_params.interpolation == expected_interpolation
def test_content_loss(subtests, impl_params, instance_norm):
    """``paper.content_loss`` builds a ``FeatureReconstructionLoss`` that
    mirrors the content_loss hyper parameters."""
    expected = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    ).content_loss

    content_loss = paper.content_loss(
        impl_params=impl_params, instance_norm=instance_norm
    )
    assert isinstance(content_loss, pystiche.loss.FeatureReconstructionLoss)

    with subtests.test("layer"):
        assert content_loss.encoder.layer == expected.layer

    with subtests.test("score_weight"):
        assert content_loss.score_weight == pytest.approx(expected.score_weight)
def test_ulyanov_et_al_2016_lr_scheduler(subtests, impl_params, instance_norm):
    """``paper.lr_scheduler`` builds a ``DelayedExponentialLR`` that mirrors
    the lr_scheduler hyper parameters."""
    dummy_optimizer = paper.optimizer(nn.Conv2d(3, 3, 1))
    expected = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    ).lr_scheduler

    lr_scheduler = paper.lr_scheduler(
        dummy_optimizer, impl_params=impl_params, instance_norm=instance_norm
    )
    assert isinstance(lr_scheduler, paper.DelayedExponentialLR)

    with subtests.test("lr_decay"):
        assert lr_scheduler.gamma == expected.lr_decay

    with subtests.test("delay"):
        assert lr_scheduler.delay == expected.delay
def test_hyper_parameters_style_loss(subtests, impl_params, instance_norm):
    """The style_loss hyper parameters match the paper's configuration."""
    top_level_params = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    )

    assert "style_loss" in top_level_params
    style_loss_params = getattr(top_level_params, "style_loss")

    with subtests.test("layer"):
        if impl_params and instance_norm:
            expected_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1")
        else:
            expected_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1")
        assert style_loss_params.layers == expected_layers

    with subtests.test("layer_weights"):
        expected_weight = 1e3 if impl_params and not instance_norm else 1e0
        assert style_loss_params.layer_weights == pytest.approx(
            [expected_weight] * len(style_loss_params.layers)
        )

    with subtests.test("score_weight"):
        assert style_loss_params.score_weight == pytest.approx(1e0)
def test_ulyanov_et_al_2016_optimizer(subtests, impl_params, instance_norm):
    """``paper.optimizer`` builds a single-group Adam over every transformer
    parameter, with the configured learning rate."""
    transformer = nn.Conv2d(3, 3, 1)
    expected_params = tuple(transformer.parameters())
    expected = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    ).optimizer

    optimizer = paper.optimizer(
        transformer, impl_params=impl_params, instance_norm=instance_norm
    )

    assert isinstance(optimizer, optim.Adam)
    assert len(optimizer.param_groups) == 1

    param_group = optimizer.param_groups[0]

    with subtests.test(msg="optimization params"):
        assert len(param_group["params"]) == len(expected_params)
        # Identity check: the optimizer must hold the exact parameter tensors.
        for actual_param, expected_param in zip(param_group["params"], expected_params):
            assert actual_param is expected_param

    with subtests.test(msg="optimizer properties"):
        assert param_group["lr"] == ptu.approx(expected.lr)
def test_hyper_parameters_num_batches(impl_params, instance_norm):
    """num_batches follows the paper's per-configuration values."""
    hyper_parameters = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    )

    if impl_params:
        expected = 2000 if instance_norm else 300
    else:
        expected = 200
    assert hyper_parameters.num_batches == expected
def test_hyper_parameters_batch_size(impl_params, instance_norm):
    """batch_size follows the paper's per-configuration values."""
    hyper_parameters = paper.hyper_parameters(
        impl_params=impl_params, instance_norm=instance_norm
    )

    if impl_params:
        expected = 1 if instance_norm else 4
    else:
        expected = 16
    assert hyper_parameters.batch_size == expected
def test_hyper_parameters_smoke():
    """``paper.hyper_parameters`` returns a ``HyperParameters`` container."""
    assert isinstance(paper.hyper_parameters(), HyperParameters)