# Example no. 1 (score: 0)
    def test_PerceptualLoss_set_style_image(self):
        """``set_style_image`` must forward the image to the style operator."""
        torch.manual_seed(0)
        style_image = torch.rand(1, 1, 100, 100)

        # Two independent single-conv operators: one for content, one for style.
        def make_op():
            return FeatureReconstructionOperator(
                SequentialEncoder((nn.Conv2d(1, 1, 1),))
            )

        content_op = make_op()
        style_op = make_op()

        perceptual_loss = loss.PerceptualLoss(content_op, style_op)
        perceptual_loss.set_style_image(style_image)

        # The style operator should now hold exactly the image we passed in.
        self.assertTensorAlmostEqual(style_op.target_image, style_image)
# Report the compute device selected earlier.
# NOTE(review): ``device`` is defined above this chunk — confirm against the full file.
print(f"I'm working with {device}")

# Fetch the demo images used by this example (network side effect).
images = demo_images()
images.download()

########################################################################################
# At first we define a :class:`~pystiche.loss.perceptual.PerceptualLoss` that is used
# as optimization ``criterion``.

# Pretrained VGG19 encoder shared by the content and style operators below.
multi_layer_encoder = vgg19_multi_layer_encoder()

# Content term: feature reconstruction on a single deep layer.
content_layer = "relu4_2"
content_encoder = multi_layer_encoder.extract_single_layer_encoder(
    content_layer)
content_weight = 1e0
content_loss = FeatureReconstructionOperator(content_encoder,
                                             score_weight=content_weight)

# Style term configuration; the style loss itself is assembled further down.
style_layers = ("relu3_1", "relu4_1")
style_weight = 2e0

def get_style_op(encoder, layer_weight):
    """Create an MRF style operator over 3x3 patches sampled with stride 2."""
    return MRFOperator(encoder, 3, stride=2, score_weight=layer_weight)


style_loss = MultiLayerEncodingOperator(
    multi_layer_encoder,
# Example no. 3 (score: 0)
 def get_guided_perceptual_loss():
     content_loss = FeatureReconstructionOperator(
         SequentialEncoder((nn.Conv2d(1, 1, 1),))
     )
     style_loss = MultiRegionOperator(regions, get_op)
     return loss.GuidedPerceptualLoss(content_loss, style_loss)