Beispiel #1
0
def generate_default_image_optim_loop_processing_asset(
    root, file="default_image_optim_loop_processing"
):
    """Build the reference asset for the image optimization loop with Caffe pre/postprocessing.

    Seeds torch's global RNG with 0 and draws a random ``(1, 3, 32, 32)`` input
    image. It then delegates to ``_generate_default_image_optim_loop_asset``
    with these arguments:

    - a ``TotalVariationOperator`` criterion;
    - an Adam optimizer factory (``lr=0.1``);
    - 5 steps;
    - ``CaffePreprocessing`` / ``CaffePostprocessing`` instances.

    Args:
        root: Directory that is joined with ``file``; the result is passed on
            as the first argument (presumably the output location).
        file: Asset name inside ``root``.
    """
    torch.manual_seed(0)  # resets the process-wide default torch generator
    input_image = torch.rand(1, 3, 32, 32)
    criterion = TotalVariationOperator()

    def make_optimizer(image):
        from torch.optim import Adam

        # requires_grad_ works in place, so ``image`` itself is optimized.
        image.requires_grad_(True)
        return Adam([image], lr=1e-1)

    steps = 5
    pre = CaffePreprocessing()
    post = CaffePostprocessing()

    target = path.join(root, file)
    _generate_default_image_optim_loop_asset(
        target,
        input_image,
        criterion,
        get_optimizer=make_optimizer,
        num_steps=steps,
        preprocessor=pre,
        postprocessor=post,
    )
def generate_default_image_pyramid_optim_loop_asset(
    root, file="pyramid_image_optimization"
):
    """Build the reference asset for the default image-pyramid optimization loop.

    Seeds torch's global RNG with 0 and draws a random ``(1, 3, 32, 32)`` input
    image. It then delegates to
    ``_generate_default_image_pyramid_optim_loop_asset`` with a
    ``TotalVariationOperator`` criterion, an ``ImagePyramid((16, 32), 3)`` and
    an Adam optimizer factory (``lr=0.1``).

    Args:
        root: Directory that is joined with ``file``; the result is passed on
            as the first argument (presumably the output location).
        file: Asset name inside ``root``.
    """
    torch.manual_seed(0)  # resets the process-wide default torch generator
    input_image = torch.rand(1, 3, 32, 32)
    criterion = TotalVariationOperator()
    pyramid = ImagePyramid((16, 32), 3)

    def make_optimizer(image):
        from torch.optim import Adam

        # requires_grad_ works in place, so ``image`` itself is optimized.
        image.requires_grad_(True)
        return Adam([image], lr=1e-1)

    target = path.join(root, file)
    _generate_default_image_pyramid_optim_loop_asset(
        target, input_image, criterion, pyramid, get_optimizer=make_optimizer
    )
def generate_default_image_optim_loop_asset(root, file="image_optimization"):
    """Build the reference asset for the default image optimization loop.

    Seeds torch's global RNG with 0 and draws a random ``(1, 3, 32, 32)`` input
    image. It then delegates to ``_generate_default_image_optim_loop_asset``
    with a ``TotalVariationOperator`` criterion, an Adam optimizer factory
    (``lr=0.1``) and 5 steps. No pre- or postprocessor is passed.

    Args:
        root: Directory that is joined with ``file``; the result is passed on
            as the first argument (presumably the output location).
        file: Asset name inside ``root``.
    """
    torch.manual_seed(0)  # resets the process-wide default torch generator
    input_image = torch.rand(1, 3, 32, 32)
    criterion = TotalVariationOperator()

    def make_optimizer(image):
        from torch.optim import Adam

        # requires_grad_ works in place, so ``image`` itself is optimized.
        image.requires_grad_(True)
        return Adam([image], lr=1e-1)

    steps = 5
    target = path.join(root, file)
    _generate_default_image_optim_loop_asset(
        target,
        input_image,
        criterion,
        get_optimizer=make_optimizer,
        num_steps=steps,
    )
Beispiel #4
0
    def test_PerceptualLoss(self):
        """Check the component bookkeeping of ``loss.PerceptualLoss``.

        For every subset of components yielded by ``powerset``:

        - A subset without ``content_loss`` and without ``style_loss`` must make
          the constructor raise ``RuntimeError``.
        - For any other subset, each supplied component must report
          ``has_<name>`` as true and be stored as the very same object under
          ``<name>``.
        - Each omitted component must report ``has_<name>`` as false.
        """
        required_components = {"content_loss", "style_loss"}
        all_components = {*required_components, "regularization"}
        # Use one operator per component so the identity checks below also catch
        # a component being stored under the wrong attribute. A single shared
        # instance would make every ``assertIs`` pass regardless.
        operators = {
            component: TotalVariationOperator() for component in all_components
        }

        for components in powerset(all_components):
            kwargs = {component: operators[component] for component in components}

            if not required_components.intersection(components):
                # Pass the actual subset (e.g. only "regularization"). Calling
                # with no arguments every time would leave the
                # regularization-only combination untested.
                with self.assertRaises(RuntimeError):
                    loss.PerceptualLoss(**kwargs)
                continue

            perceptual_loss = loss.PerceptualLoss(**kwargs)

            for component in components:
                self.assertTrue(getattr(perceptual_loss, f"has_{component}"))
                self.assertIs(
                    getattr(perceptual_loss, component), operators[component]
                )

            for component in all_components - set(components):
                self.assertFalse(getattr(perceptual_loss, f"has_{component}"))