Example #1
def style_loss(
    impl_params: bool = True,
    multi_layer_encoder: Optional[enc.MultiLayerEncoder] = None,
    hyper_parameters: Optional[HyperParameters] = None,
) -> loss.MultiLayerEncodingLoss:
    r"""Style loss from :cite:`JAL2016`.

    Args:
        impl_params: Switch the behavior and hyper-parameters between the reference
            implementation of the original authors and what is described in the paper.
            For details see :ref:`here <johnson_alahi_li_2016-impl_params>`.
        multi_layer_encoder: Pretrained :class:`~pystiche.enc.MultiLayerEncoder`. If
            omitted, the default
            :func:`~pystiche_papers.johnson_alahi_li_2016.multi_layer_encoder` is used.
        hyper_parameters: Hyper parameters. If omitted,
            :func:`~pystiche_papers.johnson_alahi_li_2016.hyper_parameters` is used.
    """
    if multi_layer_encoder is None:
        multi_layer_encoder = _multi_layer_encoder(impl_params=impl_params)

    if hyper_parameters is None:
        hyper_parameters = _hyper_parameters()

    def get_encoding_op(encoder: enc.Encoder, layer_weight: float) -> GramLoss:
        return GramLoss(encoder,
                        impl_params=impl_params,
                        score_weight=layer_weight)

    return loss.MultiLayerEncodingLoss(
        multi_layer_encoder,
        hyper_parameters.style_loss.layers,
        get_encoding_op,
        layer_weights=hyper_parameters.style_loss.layer_weights,
        score_weight=hyper_parameters.style_loss.score_weight,
    )
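
A minimal usage sketch (not from the library source; `style_image` and `input_image` are assumed to be preloaded image tensors, and `set_target_image` follows pystiche's comparison-loss API):

criterion = style_loss()
criterion.set_target_image(style_image)  # encode the style targets once
score = criterion(input_image)  # evaluate the style loss for the current image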
Example #2
def make_loss(
    loss_str: str, layers_str: str, score_weight: float,
    mle: enc.MultiLayerEncoder
) -> Union[loss.ComparisonLoss, loss.MultiLayerEncodingLoss]:
    loss_str_normalized = loss_str.lower().replace("_", "").replace("-", "")
    if loss_str_normalized not in LOSSES:
        raise ValueError(
            add_suggestion(
                f"Unknown loss '{loss_str}'.",
                word=loss_str_normalized,
                possibilities=tuple(zip(*LOSSES.values()))[0],
            )
        )

    _, loss_fn = LOSSES[loss_str_normalized]

    layers = [layer.strip() for layer in layers_str.split(",")]
    layers = sorted(layer for layer in layers if layer)
    for layer in layers:
        mle.verify(layer)

    if len(layers) == 1:
        return loss_fn(mle.extract_encoder(layers[0]), score_weight)
    else:
        return loss.MultiLayerEncodingLoss(mle,
                                           layers,
                                           loss_fn,
                                           score_weight=score_weight)
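
The `LOSSES` registry is not shown above. From the lookup pattern it maps a normalized name to a `(display_name, loss_factory)` pair, where the factory accepts `(encoder, score_weight)`. A hypothetical sketch of a compatible registry and call:

LOSSES = {
    # normalized name -> (display name, factory taking (encoder, score_weight))
    "gram": (
        "gram",
        lambda encoder, score_weight: loss.GramLoss(encoder, score_weight=score_weight),
    ),
    "featurereconstruction": (
        "feature_reconstruction",
        lambda encoder, score_weight: loss.FeatureReconstructionLoss(
            encoder, score_weight=score_weight
        ),
    ),
}

mle = enc.vgg19_multi_layer_encoder()
style = make_loss("gram", "relu1_1, relu2_1, relu3_1", 1e3, mle)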
Example #3
def style_loss(
    impl_params: bool = True,
    multi_layer_encoder: Optional[enc.MultiLayerEncoder] = None,
    hyper_parameters: Optional[HyperParameters] = None,
) -> loss.MultiLayerEncodingLoss:
    r"""Style loss from :cite:`LW2016`.

    Args:
        impl_params: Switch the behavior and hyper-parameters between the reference
            implementation of the original authors and what is described in the paper.
            For details see :ref:`here <li_wand_2016-impl_params>`.
        multi_layer_encoder: Pretrained multi-layer encoder. If
            omitted, :func:`~pystiche_papers.li_wand_2016.multi_layer_encoder` is used.
        hyper_parameters: Hyper parameters. If omitted,
            :func:`~pystiche_papers.li_wand_2016.hyper_parameters` is used.

    .. seealso::

        - :class:`pystiche_papers.li_wand_2016.MRFLoss`
    """
    if multi_layer_encoder is None:
        multi_layer_encoder = _multi_layer_encoder()

    if hyper_parameters is None:
        hyper_parameters = _hyper_parameters(impl_params=impl_params)

    def encoding_loss_fn(encoder: enc.Encoder, layer_weight: float) -> MRFLoss:
        return MRFLoss(
            encoder,
            hyper_parameters.style_loss.patch_size,  # type: ignore[union-attr]
            impl_params=impl_params,
            stride=hyper_parameters.style_loss.stride,  # type: ignore[union-attr]
            target_transforms=_target_transforms(
                impl_params=impl_params, hyper_parameters=hyper_parameters),
            score_weight=layer_weight,
        )

    return loss.MultiLayerEncodingLoss(
        multi_layer_encoder,
        hyper_parameters.style_loss.layers,
        encoding_loss_fn,
        layer_weights=hyper_parameters.style_loss.layer_weights,
        score_weight=hyper_parameters.style_loss.score_weight,
    )
Example #4
def style_loss(
    impl_params: bool = True,
    instance_norm: bool = True,
    multi_layer_encoder: Optional[enc.MultiLayerEncoder] = None,
    hyper_parameters: Optional[HyperParameters] = None,
) -> loss.MultiLayerEncodingLoss:
    r"""Style loss from :cite:`ULVL2016,UVL2017`.

    Args:
        impl_params: Switch the behavior and hyper-parameters between the reference
            implementation of the original authors and what is described in the paper.
            For details see :ref:`here <ulyanov_et_al_2016-impl_params>`.
        instance_norm: Switch the behavior and hyper-parameters between both
            publications of the original authors. For details see
            :ref:`here <ulyanov_et_al_2016-instance_norm>`.
        multi_layer_encoder: Pretrained :class:`~pystiche.enc.MultiLayerEncoder`. If
            omitted, :func:`~pystiche_papers.ulyanov_et_al_2016.multi_layer_encoder`
            is used.
        hyper_parameters: Hyper parameters. If omitted,
            :func:`~pystiche_papers.ulyanov_et_al_2016.hyper_parameters` is used.

    .. seealso::

        - :class:`pystiche_papers.ulyanov_et_al_2016.GramLoss`
    """
    if multi_layer_encoder is None:
        multi_layer_encoder = _multi_layer_encoder()

    if hyper_parameters is None:
        hyper_parameters = _hyper_parameters(impl_params=impl_params,
                                             instance_norm=instance_norm)

    def get_encoding_op(encoder: enc.Encoder, layer_weight: float) -> GramLoss:
        return GramLoss(encoder,
                        impl_params=impl_params,
                        score_weight=layer_weight)

    return loss.MultiLayerEncodingLoss(
        multi_layer_encoder,
        hyper_parameters.style_loss.layers,
        get_encoding_op,
        layer_weights=hyper_parameters.style_loss.layer_weights,
        score_weight=hyper_parameters.style_loss.score_weight,
    )
Example #5
def _get_perceptual_loss(
    self,
    *,
    backbone: str,
    content_layer: str,
    content_weight: float,
    style_layers: Sequence[str],
    style_weight: float,
) -> loss.PerceptualLoss:
    # the backbone factory is assumed to return a (multi-layer encoder, ...) pair
    mle, _ = cast(
        Tuple[enc.MultiLayerEncoder, Any], self.backbones.get(backbone)()
    )
    content_loss = loss.FeatureReconstructionLoss(
        mle.extract_encoder(content_layer), score_weight=content_weight
    )
    style_loss = loss.MultiLayerEncodingLoss(
        mle,
        style_layers,
        lambda encoder, layer_weight: self._modified_gram_loss(
            encoder, score_weight=layer_weight
        ),
        layer_weights="sum",
        score_weight=style_weight,
    )
    return loss.PerceptualLoss(content_loss, style_loss)
########################################################################################
# We use the :class:`~pystiche.loss.GramLoss` introduced by Gatys, Ecker, and Bethge
# :cite:`GEB2016` as ``style_loss``. Unlike before, we use multiple ``style_layers``.
# The individual losses can be conveniently bundled in a
# :class:`~pystiche.loss.MultiLayerEncodingLoss`.

style_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1")
style_weight = 1e3


def get_style_op(encoder, layer_weight):
    return loss.GramLoss(encoder, score_weight=layer_weight)


style_loss = loss.MultiLayerEncodingLoss(
    multi_layer_encoder, style_layers, get_style_op, score_weight=style_weight,
)
print(style_loss)


########################################################################################
# We combine the ``content_loss`` and ``style_loss`` into a joined
# :class:`~pystiche.loss.PerceptualLoss`, which will serve as optimization criterion.

perceptual_loss = loss.PerceptualLoss(content_loss, style_loss).to(device)
print(perceptual_loss)
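

########################################################################################
# A short sketch of how this criterion would be consumed (assuming preloaded
# ``content_image``, ``style_image``, and ``input_image`` tensors; the
# :func:`pystiche.optim.image_optimization` loop is one way to drive it):

from pystiche import optim

perceptual_loss.set_content_image(content_image)
perceptual_loss.set_style_image(style_image)
output_image = optim.image_optimization(input_image, perceptual_loss, num_steps=500)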


########################################################################################
# Region-wise style losses
# ------------------------
def get_region_op(region, region_weight):
    return loss.MultiLayerEncodingLoss(
        multi_layer_encoder, style_layers, get_style_op, score_weight=region_weight,
    )


class GramOperator(loss.GramLoss):
    def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
        # Normalize the Gram matrix by the number of channels so that the
        # magnitude of the style score does not grow with the layer width.
        repr = super().enc_to_repr(enc)
        num_channels = repr.size()[1]
        return repr / num_channels


style_layers = ("relu1_2", "relu2_2", "relu3_3", "relu4_3")
style_weight = 1e10
style_loss = loss.MultiLayerEncodingLoss(
    multi_layer_encoder,
    style_layers,
    lambda encoder, layer_weight: GramOperator(encoder, score_weight=layer_weight),
    layer_weights="sum",
    score_weight=style_weight,
)

perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
perceptual_loss = perceptual_loss.to(device)
print(perceptual_loss)


########################################################################################
# Training
# --------
#
# In a first step we load the style image that will be used to train the
# ``transformer``.
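
# A minimal sketch of that loading step, assuming pystiche's built-in demo image
# collection and preset ``size`` and ``device`` values (not shown here):
from pystiche import demo

style_image = demo.images()["paint"].read(size=size, device=device)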