def test_main(self):
    """Check that each child is reachable under its stringified position.

    The container is built from two layers and every layer must be the very
    same object (identity, not equality) as the attribute named after its
    index.
    """
    layers = (nn.Conv2d(3, 3, 3), nn.ReLU())
    container = pystiche.SequentialModule(*layers)

    for position, expected in enumerate(layers):
        registered = getattr(container, str(position))
        assert registered is expected
def test_call(self):
    """Check that the forward pass matches ``nn.Sequential`` on shared layers.

    Both containers wrap the same layer objects, so feeding them the same
    random image must produce numerically close outputs.
    """
    # Fix the seed so the layer initialisation and the image are reproducible.
    torch.manual_seed(0)

    layers = (nn.Conv2d(3, 3, 3), nn.ReLU())
    image = torch.rand(1, 3, 256, 256)

    candidate = pystiche.SequentialModule(*layers)
    reference = nn.Sequential(*layers)

    candidate_output = candidate(image)
    reference_output = reference(image)

    ptu.assert_allclose(candidate_output, reference_output)
def johnson_alahi_li_2016_transformer_decoder(
    impl_params: bool = True,
    instance_norm: bool = True,
) -> pystiche.SequentialModule:
    """Build the decoder: two upsampling conv blocks, a 9x9 conv, an activation.

    Args:
        impl_params: If ``True``, the final activation is ``150 * tanh(x)``.
            Otherwise it is ``sigmoid(2 * x)``, which equals
            ``(tanh(x) + 1) / 2``.
        instance_norm: Forwarded to both conv blocks. It also selects the
            channel widths: ``64 -> 32 -> 16`` when ``True`` and
            ``128 -> 64 -> 32`` otherwise.

    Returns:
        The four modules wrapped in a :class:`pystiche.SequentialModule`.
    """
    # Pick the three channel widths once instead of repeating the conditional.
    if instance_norm:
        wide, medium, narrow = 64, 32, 16
    else:
        wide, medium, narrow = 128, 64, 32

    def scaled_tanh(x: torch.Tensor) -> torch.Tensor:
        return 150.0 * torch.tanh(x)

    def unit_range(x: torch.Tensor) -> torch.Tensor:
        # sgm(2*x) == (tanh(x) + 1) / 2
        return torch.sigmoid(2.0 * x)

    delimit = scaled_tanh if impl_params else unit_range

    class ValueRangeDelimiter(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return delimit(x)

    first_upsample = johnson_alahi_li_2016_conv_block(
        in_channels=wide,
        out_channels=medium,
        kernel_size=3,
        stride=2,
        upsample=True,
        instance_norm=instance_norm,
    )
    second_upsample = johnson_alahi_li_2016_conv_block(
        in_channels=medium,
        out_channels=narrow,
        kernel_size=3,
        stride=2,
        upsample=True,
        instance_norm=instance_norm,
    )
    to_image = nn.Conv2d(
        in_channels=narrow,
        out_channels=3,
        kernel_size=9,
        padding=same_size_padding(kernel_size=9),
    )

    return pystiche.SequentialModule(
        first_upsample, second_upsample, to_image, ValueRangeDelimiter()
    )
def johnson_alahi_li_2016_transformer_encoder(
    instance_norm: bool = True,
) -> pystiche.SequentialModule:
    """Build the encoder: padding, three conv blocks and five residual blocks.

    Args:
        instance_norm: Forwarded to every conv and residual block. It also
            selects the channel widths: ``16 -> 32 -> 64`` when ``True`` and
            ``32 -> 64 -> 128`` otherwise.

    Returns:
        The nine modules wrapped in a :class:`pystiche.SequentialModule`.
    """
    narrow, medium, wide = (16, 32, 64) if instance_norm else (32, 64, 128)
    num_residual_blocks = 5

    stem = (
        nn.ReflectionPad2d(40),
        johnson_alahi_li_2016_conv_block(
            in_channels=3,
            out_channels=narrow,
            kernel_size=9,
            instance_norm=instance_norm,
        ),
        johnson_alahi_li_2016_conv_block(
            in_channels=narrow,
            out_channels=medium,
            kernel_size=3,
            stride=2,
            instance_norm=instance_norm,
        ),
        johnson_alahi_li_2016_conv_block(
            in_channels=medium,
            out_channels=wide,
            kernel_size=3,
            stride=2,
            instance_norm=instance_norm,
        ),
    )
    # The residual blocks are identical in configuration but must be distinct
    # module instances, so each one is constructed separately and in order.
    residuals = tuple(
        johnson_alahi_li_2016_residual_block(
            channels=wide, instance_norm=instance_norm,
        )
        for _ in range(num_residual_blocks)
    )

    return pystiche.SequentialModule(*stem, *residuals)
def encoder(
    instance_norm: bool = True,
) -> pystiche.SequentialModule:
    r"""Encoder part of the :class:`Transformer` from :cite:`JAL2016`.

    The encoder consists of a reflection padding, three convolution blocks and
    five residual blocks.

    Args:
        instance_norm: If ``True``, use :class:`~torch.nn.InstanceNorm2d` rather
            than :class:`~torch.nn.BatchNorm2d` as described in the paper. In
            addition, the number of channels of the convolution layers is
            reduced by half.
    """

    def width(paper_channels: int):
        # Channel count given in the paper, adjusted for the normalization.
        return maybe_fix_num_channels(paper_channels, instance_norm)

    num_residual_blocks = 5

    stem = (
        nn.ReflectionPad2d(40),
        conv_block(
            in_channels=3,
            out_channels=width(32),
            kernel_size=9,
            instance_norm=instance_norm,
        ),
        conv_block(
            in_channels=width(32),
            out_channels=width(64),
            kernel_size=3,
            stride=2,
            instance_norm=instance_norm,
        ),
        conv_block(
            in_channels=width(64),
            out_channels=width(128),
            kernel_size=3,
            stride=2,
            instance_norm=instance_norm,
        ),
    )
    # Forward ``instance_norm`` so the residual blocks use the same
    # normalization as the convolution blocks above. Previously the flag was
    # dropped here, so ``instance_norm=False`` left the residual stage on the
    # helper's default.
    # NOTE(review): assumes residual_block() accepts ``instance_norm`` -- confirm.
    residuals = tuple(
        residual_block(channels=width(128), instance_norm=instance_norm)
        for _ in range(num_residual_blocks)
    )

    return pystiche.SequentialModule(*stem, *residuals)
def decoder(
    impl_params: bool = True,
    instance_norm: bool = True,
) -> pystiche.SequentialModule:
    r"""Decoder part of the :class:`Transformer` from :cite:`JAL2016`.

    The decoder consists of two upsampling convolution blocks, a final
    convolution to three channels and a value range delimiter.

    Args:
        impl_params: If ``True``, the output of the decoder is not externally
            pre-processed before being fed into the
            :func:`~pystiche_papers.johnson_alahi_li_2016.perceptual_loss`. Since
            this step is necessary to get meaningful encodings from the
            :func:`~pystiche_papers.johnson_alahi_li_2016.multi_layer_encoder`,
            the pre-processing transform has to be learned within the output
            layer of the decoder. To make this possible, ``150 * tanh(input)``
            is used as activation instead of the ``(tanh(input) + 1) / 2``
            given in the paper.
        instance_norm: If ``True``, use :class:`~torch.nn.InstanceNorm2d` rather
            than :class:`~torch.nn.BatchNorm2d` as described in the paper. In
            addition, the number of channels of the convolution layers is
            reduced by half.
    """

    def width(paper_channels: int):
        # Channel count given in the paper, adjusted for the normalization.
        return maybe_fix_num_channels(paper_channels, instance_norm)

    if impl_params:
        # https://github.com/pmeier/fast-neural-style/blob/813c83441953ead2adb3f65f4cc2d5599d735fa7/train.lua#L25
        # https://github.com/pmeier/fast-neural-style/blob/813c83441953ead2adb3f65f4cc2d5599d735fa7/fast_neural_style/models.lua#L137-L138
        # A tanh with a constant factor of 150 is used instead of the
        # (tanh(x) + 1) / 2 in the paper.
        def delimit(x: torch.Tensor) -> torch.Tensor:
            return 150.0 * torch.tanh(x)

    else:

        def delimit(x: torch.Tensor) -> torch.Tensor:
            # (tanh(x) + 1) / 2 == sgm(2*x)
            return torch.sigmoid(2.0 * x)

    class ValueRangeDelimiter(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return delimit(x)

    first_upsample = conv_block(
        in_channels=width(128),
        out_channels=width(64),
        kernel_size=3,
        stride=2,
        upsample=True,
        instance_norm=instance_norm,
    )
    second_upsample = conv_block(
        in_channels=width(64),
        out_channels=width(32),
        kernel_size=3,
        stride=2,
        upsample=True,
        instance_norm=instance_norm,
    )
    to_image = AutoPadConv2d(
        in_channels=width(32),
        out_channels=3,
        kernel_size=9,
    )

    return pystiche.SequentialModule(
        first_upsample, second_upsample, to_image, ValueRangeDelimiter()
    )