Example #1
0
    def test_main(self):
        """Modules are registered under their index as attribute name."""
        layers = (nn.Conv2d(3, 3, 3), nn.ReLU())
        container = pystiche.SequentialModule(*layers)

        # The container has to expose the very same objects, not copies.
        for position, expected in enumerate(layers):
            assert getattr(container, str(position)) is expected
Example #2
0
    def test_call(self):
        """Calling the container matches ``nn.Sequential`` over the same modules."""
        # Seed before anything random happens so that both the parameter
        # initialisation and the random input are reproducible.
        torch.manual_seed(0)
        layers = (nn.Conv2d(3, 3, 3), nn.ReLU())
        image = torch.rand(1, 3, 256, 256)

        # Both wrappers share the identical module objects, so their outputs
        # should agree.
        reference = nn.Sequential(*layers)
        candidate = pystiche.SequentialModule(*layers)

        ptu.assert_allclose(candidate(image), reference(image))
Example #3
0
def johnson_alahi_li_2016_transformer_decoder(
    impl_params: bool = True,
    instance_norm: bool = True,
) -> pystiche.SequentialModule:
    """Assemble the decoder half of the transformer.

    The decoder chains two conv blocks built with ``upsample=True``, a 9x9
    convolution down to 3 output channels, and a final activation that limits
    the value range of the output.

    Args:
        impl_params: If ``True``, the final activation is ``150 * tanh(x)``.
            Otherwise ``sigmoid(2 * x)`` is used, which is identical to
            ``(tanh(x) + 1) / 2``.
        instance_norm: Forwarded to the conv blocks. If ``True``, every channel
            count is half of the one used otherwise.

    Returns:
        The layers wrapped in a :class:`pystiche.SequentialModule`.
    """

    def get_value_range_delimiter() -> nn.Module:
        def scaled_tanh(x: torch.Tensor) -> torch.Tensor:
            return 150.0 * torch.tanh(x)

        def unit_interval(x: torch.Tensor) -> torch.Tensor:
            # sgm(2*x) == (tanh(x) + 1) / 2
            return torch.sigmoid(2.0 * x)

        value_range_delimiter = scaled_tanh if impl_params else unit_interval

        class ValueRangeDelimiter(nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return value_range_delimiter(x)

        return ValueRangeDelimiter()

    def width(full: int) -> int:
        # Channel counts are given for the case without instance norm and are
        # halved when instance norm is used.
        return full // 2 if instance_norm else full

    first_upsample = johnson_alahi_li_2016_conv_block(
        in_channels=width(128),
        out_channels=width(64),
        kernel_size=3,
        stride=2,
        upsample=True,
        instance_norm=instance_norm,
    )
    second_upsample = johnson_alahi_li_2016_conv_block(
        in_channels=width(64),
        out_channels=width(32),
        kernel_size=3,
        stride=2,
        upsample=True,
        instance_norm=instance_norm,
    )
    to_image = nn.Conv2d(
        in_channels=width(32),
        out_channels=3,
        kernel_size=9,
        padding=same_size_padding(kernel_size=9),
    )
    delimiter = get_value_range_delimiter()

    return pystiche.SequentialModule(
        first_upsample, second_upsample, to_image, delimiter
    )
Example #4
0
def johnson_alahi_li_2016_transformer_encoder(
    instance_norm: bool = True,
) -> pystiche.SequentialModule:
    """Assemble the encoder half of the transformer.

    The encoder chains a reflection padding of 40, three conv blocks (the
    latter two with ``stride=2``), and five residual blocks.

    Args:
        instance_norm: Forwarded to every conv and residual block. If ``True``,
            every channel count is half of the one used otherwise.

    Returns:
        The layers wrapped in a :class:`pystiche.SequentialModule`.
    """

    def width(full: int) -> int:
        # Channel counts are given for the case without instance norm and are
        # halved when instance norm is used.
        return full // 2 if instance_norm else full

    stem = [
        nn.ReflectionPad2d(40),
        johnson_alahi_li_2016_conv_block(
            in_channels=3,
            out_channels=width(32),
            kernel_size=9,
            instance_norm=instance_norm,
        ),
        johnson_alahi_li_2016_conv_block(
            in_channels=width(32),
            out_channels=width(64),
            kernel_size=3,
            stride=2,
            instance_norm=instance_norm,
        ),
        johnson_alahi_li_2016_conv_block(
            in_channels=width(64),
            out_channels=width(128),
            kernel_size=3,
            stride=2,
            instance_norm=instance_norm,
        ),
    ]

    # Five identically configured residual blocks, each a separate instance.
    residuals = [
        johnson_alahi_li_2016_residual_block(
            channels=width(128),
            instance_norm=instance_norm,
        )
        for _ in range(5)
    ]

    return pystiche.SequentialModule(*stem, *residuals)
Example #5
0
def encoder(
    instance_norm: bool = True,
) -> pystiche.SequentialModule:
    r"""Encoder part of the :class:`Transformer` from :cite:`JAL2016`.

    The encoder chains a reflection padding of 40, three conv blocks (the latter
    two with ``stride=2``), and five residual blocks.

    Args:
        instance_norm: If ``True``, use :class:`~torch.nn.InstanceNorm2d` rather than
            :class:`~torch.nn.BatchNorm2d` as described in the paper. In addition, the
            number of channels of the convolution layers is reduced by half. The
            flag is forwarded to every conv block and every residual block.

    Returns:
        The layers wrapped in a :class:`pystiche.SequentialModule`.
    """
    residual_channels = maybe_fix_num_channels(128, instance_norm)

    modules = (
        nn.ReflectionPad2d(40),
        conv_block(
            in_channels=3,
            out_channels=maybe_fix_num_channels(32, instance_norm),
            kernel_size=9,
            instance_norm=instance_norm,
        ),
        conv_block(
            in_channels=maybe_fix_num_channels(32, instance_norm),
            out_channels=maybe_fix_num_channels(64, instance_norm),
            kernel_size=3,
            stride=2,
            instance_norm=instance_norm,
        ),
        conv_block(
            in_channels=maybe_fix_num_channels(64, instance_norm),
            out_channels=maybe_fix_num_channels(128, instance_norm),
            kernel_size=3,
            stride=2,
            instance_norm=instance_norm,
        ),
        # Forward instance_norm explicitly: without it the residual blocks would
        # use their own default normalisation regardless of what the caller
        # requested, contradicting the conv blocks above.
        *(
            residual_block(channels=residual_channels, instance_norm=instance_norm)
            for _ in range(5)
        ),
    )
    return pystiche.SequentialModule(*modules)
Example #6
0
def decoder(
    impl_params: bool = True,
    instance_norm: bool = True,
) -> pystiche.SequentialModule:
    r"""Decoder part of the :class:`Transformer` from :cite:`JAL2016`.

    Args:
        impl_params: If ``True``, the output of the decoder is not externally
            pre-processed before being fed into the
            :func:`~pystiche_papers.johnson_alahi_li_2016.perceptual_loss`. Since this
            step is necessary to get meaningful encodings from the
            :func:`~pystiche_papers.johnson_alahi_li_2016.multi_layer_encoder`, the
            pre-processing transform has to be learned within the output layer of the
            decoder. To make this possible, ``150 * tanh(input)`` is used as activation
            in contrast to the ``(tanh(input) + 1) / 2`` given in the paper.
        instance_norm: If ``True``, use :class:`~torch.nn.InstanceNorm2d` rather than
            :class:`~torch.nn.BatchNorm2d` as described in the paper. In addition, the
            number of channels of the convolution layers is reduced by half.

    Returns:
        The layers wrapped in a :class:`pystiche.SequentialModule`.
    """

    def get_value_range_delimiter() -> nn.Module:
        # https://github.com/pmeier/fast-neural-style/blob/813c83441953ead2adb3f65f4cc2d5599d735fa7/train.lua#L25
        # https://github.com/pmeier/fast-neural-style/blob/813c83441953ead2adb3f65f4cc2d5599d735fa7/fast_neural_style/models.lua#L137-L138
        # A tanh with a constant factor of 150 is used instead of the
        # (tanh(x) + 1) / 2 in the paper.
        def implementation_activation(x: torch.Tensor) -> torch.Tensor:
            return 150.0 * torch.tanh(x)

        def paper_activation(x: torch.Tensor) -> torch.Tensor:
            # (tanh(x) + 1) / 2 == sgm(2*x)
            return torch.sigmoid(2.0 * x)

        if impl_params:
            value_range_delimiter = implementation_activation
        else:
            value_range_delimiter = paper_activation

        class ValueRangeDelimiter(nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return value_range_delimiter(x)

        return ValueRangeDelimiter()

    def num_channels(base: int) -> int:
        # Shorthand so the layer definitions below stay readable.
        return maybe_fix_num_channels(base, instance_norm)

    first_upsample = conv_block(
        in_channels=num_channels(128),
        out_channels=num_channels(64),
        kernel_size=3,
        stride=2,
        upsample=True,
        instance_norm=instance_norm,
    )
    second_upsample = conv_block(
        in_channels=num_channels(64),
        out_channels=num_channels(32),
        kernel_size=3,
        stride=2,
        upsample=True,
        instance_norm=instance_norm,
    )
    output_conv = AutoPadConv2d(
        in_channels=num_channels(32),
        out_channels=3,
        kernel_size=9,
    )
    activation = get_value_range_delimiter()

    return pystiche.SequentialModule(
        first_upsample, second_upsample, output_conv, activation
    )