Example #1
    def __init__(self,
                 device,
                 content_encoder: nn.Module,
                 decoder: nn.Module,
                 content_disc: nn.Module,
                 source_disc: nn.Module,
                 scaler: Scaler = None) -> None:
        """Store the sub-modules, build the loss criteria and move the
        model to *device*.

        Args:
            device: target device passed to the base class; the whole
                module is moved there at the end.
            content_encoder: content-encoder sub-module.
            decoder: decoder sub-module.
            content_disc: discriminator applied to content codes.
            source_disc: discriminator applied in image space.
            scaler: value scaler; defaults to the identity ``Scaler(1., 0.)``.
        """
        super().__init__(device)
        self._content_encoder: nn.Module = content_encoder
        self._decoder: nn.Module = decoder

        self._content_disc: nn.Module = content_disc
        self._source_disc: nn.Module = source_disc

        # Fall back to an identity scaling when the caller provides none.
        if scaler is None:
            scaler = Scaler(1., 0.)
        self._scaler = scaler

        self._identity_criterion = nn.L1Loss()
        # MSE (least-squares GAN style) adversarial criteria.
        self._content_criterion = nn.MSELoss()
        self._source_criterion = nn.MSELoss()

        self.to(self._device)
Example #2
    def __init__(self,
                 device,
                 glow: Glow,
                 style_w: nn.Module,
                 content_disc: nn.Module,
                 style_disc: nn.Module,
                 scaler: Scaler = None) -> None:
        """Store the Glow network, style-weighting module and the two
        discriminators, then build the loss criteria.

        Args:
            device: target device passed to the base class; the whole
                module is moved there at the end.
            glow: Glow network sub-module.
            style_w: style-weighting sub-module.
            content_disc: discriminator for the content representation.
            style_disc: discriminator for the style representation.
            scaler: value scaler; defaults to the identity ``Scaler(1., 0.)``.
        """
        super().__init__(device)
        self._glow: Glow = glow

        self._style_w: nn.Module = style_w
        self._content_disc: nn.Module = content_disc
        self._style_disc: nn.Module = style_disc

        # Fall back to an identity scaling when the caller provides none.
        if scaler is None:
            scaler = Scaler(1., 0.)
        self._scaler = scaler

        self._weight_cycle_criterion = ListMSELoss()

        # MSE (least-squares GAN style) adversarial criteria.
        self._content_criterion = nn.MSELoss()
        self._style_criterion = nn.MSELoss()

        # reduction='none': presumably keeps per-element losses, matching
        # the torch reduction convention — confirm against ListMSELoss.
        self._norm_criterion = ListMSELoss(reduction='none')
        self._siamese_criterion1 = ListMSELoss()
        self._siamese_criterion2 = ListSiameseLoss()

        self.to(self._device)
Example #3
    def __init__(self,
                 device,
                 glow: Glow,
                 style_w: nn.Module,
                 content_disc: nn.Module,
                 style_disc: nn.Module,
                 scaler: Scaler = None) -> None:
        """Store the Glow network, style-weighting module and the two
        discriminators, then build the loss criteria.

        Args:
            device: target device passed to the base class; the whole
                module is moved there at the end.
            glow: Glow network sub-module.
            style_w: style-weighting sub-module.
            content_disc: discriminator for the content representation.
            style_disc: discriminator for the style representation.
            scaler: value scaler; defaults to the identity ``Scaler(1., 0.)``.
        """
        super().__init__(device)
        self._glow: Glow = glow
        self._style_w: nn.Module = style_w
        self._content_disc: nn.Module = content_disc
        self._style_disc: nn.Module = style_disc

        # Fall back to an identity scaling when the caller provides none.
        if scaler is None:
            scaler = Scaler(1., 0.)
        self._scaler = scaler

        self._identity_criterion = nn.L1Loss()
        self._weight_cycle_criterion = ListMSELoss()

        # MSE (least-squares GAN style) adversarial criteria.
        self._content_criterion = nn.MSELoss()
        self._style_criterion = nn.MSELoss()

        self._siamese_criterion = SiameseLoss()

        self.to(self._device)
    def __init__(self,
                 device,
                 G1: nn.Module,
                 G2: nn.Module,
                 D1: nn.Module,
                 D2: nn.Module,
                 vgg: nn.Module,
                 scaler: Scaler = None) -> None:
        """Register the generator/discriminator pairs, freeze the VGG
        feature network and build the loss criteria.
        """
        super().__init__(device)

        # Generator / discriminator pairs supplied by the caller.
        self._G1, self._G2 = G1, G2
        self._D1, self._D2 = D1, D2

        # Default to an identity scaling when no scaler is given.
        self._scaler = Scaler(1., 0.) if scaler is None else scaler

        # VGG parameters are frozen and the module stays in eval mode.
        self._vgg = vgg
        self._vgg.requires_grad_(False)
        self._vgg.eval()

        # Reconstruction and adversarial criteria.
        self._l1_criterion = torch.nn.L1Loss()
        self._l2_criterion = torch.nn.MSELoss()
        self._gan_criterion = torch.nn.MSELoss()

        self.to(self._device)
Example #5
    def __init__(self,
                 device,
                 encoder: nn.Module,
                 decoder: nn.Module,
                 style_w: nn.Module,
                 content_disc: nn.Module,
                 style_disc: nn.Module,
                 scaler: Scaler = None) -> None:
        """Register the sub-modules, build the loss criteria and move the
        model to *device*.
        """
        super().__init__(device)

        # Sub-modules supplied by the caller.
        self._encoder = encoder
        self._decoder = decoder
        self._style_w = style_w
        self._content_disc = content_disc
        self._style_disc = style_disc

        # Default to an identity scaling when no scaler is given.
        self._scaler = Scaler(1., 0.) if scaler is None else scaler

        # Reconstruction-style criteria.
        self._identity_criterion = nn.MSELoss()
        self._cycle_criterion = nn.MSELoss()
        self._weight_cycle_criterion = nn.MSELoss()

        # Adversarial criteria operating on raw logits.
        self._content_criterion = nn.BCEWithLogitsLoss()
        self._style_criterion = nn.BCEWithLogitsLoss()

        self._siamese_criterion = SiameseLoss()

        self.to(self._device)
    def __init__(self, device) -> None:
        """Instantiate the concrete networks and delegate to the base
        constructor, then apply Xavier weight initialisation.
        """
        makeup_gen: nn.Module = GeneratorReferenceMakeup(64, 6, 3)
        demakeup_gen: nn.Module = GeneratorDeMakeup(64, 6, 3)
        disc_a: nn.Module = Discriminator(256, 64, 3, 3, 'SN')
        disc_b: nn.Module = Discriminator(256, 64, 3, 3, 'SN')
        # Pre-trained VGG-19 used as supplied to the parent constructor.
        vgg_net: nn.Module = torchvision.models.vgg19(pretrained=True)
        value_scaler: Scaler = Scaler(2., 0.5)

        super().__init__(device, makeup_gen, demakeup_gen, disc_a, disc_b,
                         vgg_net, value_scaler)
        # Xavier init for every sub-module registered above.
        self.apply(weights_init_xavier)
Example #7
    def __init__(self, device, glow_path: str = '') -> None:
        """Build the Glow network, the style-weighting module and the two
        pyramid discriminators, then delegate to the base constructor.

        Args:
            device: target device forwarded to the parent ``__init__``.
            glow_path: optional path to a saved Glow state dict; when
                empty the Glow network keeps its fresh initialisation.
        """
        content_dim = 512
        style_dim = content_dim

        # Glow hyper-parameters.
        img_size = 32
        in_channel = 3
        n_flow = 32
        n_block = 4

        glow: Glow = Glow(img_size,
                          in_channel,
                          n_flow,
                          n_block,
                          affine=True,
                          conv_lu=True)
        if glow_path:
            glow.load_state_dict(torch.load(glow_path))

        style_w: nn.Module = BlockwiseWeight(img_size, in_channel, n_block)
        content_disc: nn.Module = PyramidDiscriminator(img_size,
                                                       2 * in_channel,
                                                       n_block,
                                                       bias=0)
        style_disc: nn.Module = PyramidDiscriminator(img_size,
                                                     2 * in_channel,
                                                     n_block,
                                                     bias=0)
        scaler: Scaler = Scaler(1., 0.5)

        # Xavier init for the freshly created modules; Glow keeps its own
        # initialisation (or the weights loaded above).
        style_w.apply(weights_init_xavier)
        content_disc.apply(weights_init_xavier)
        style_disc.apply(weights_init_xavier)

        super().__init__(device, glow, style_w, content_disc, style_disc,
                         scaler)

        self._content_dim = content_dim
        self._style_dim = style_dim
Example #8
    def __init__(self, device) -> None:
        """Build the concrete content encoder, decoder and discriminators,
        then delegate to the base constructor.

        Args:
            device: target device forwarded to the parent ``__init__``.
        """
        # Architecture hyper-parameters.
        dimension = 2
        in_channels = 3
        out_channels = 3
        content_dim = 256
        latent_dim = content_dim
        #style_dim = 32
        #latent_dim = content_dim + style_dim
        num_blocks = [4, 4]
        planes = [64, 128, 256]

        # Content encoder: conv stem -> max-pool -> resnet trunk, then a
        # per-position linear projection to content_dim applied in
        # channels-last layout (via Permute).
        content_encoder: nn.Module = nn.Sequential(
            nn.Conv2d(in_channels,
                      planes[0],
                      kernel_size=7,
                      stride=1,
                      padding=3,
                      bias=False), nn.InstanceNorm2d(planes[0], affine=True),
            nn.LeakyReLU(0.01), nn.MaxPool2d(kernel_size=3,
                                             stride=2,
                                             padding=1),
            simple_resnet(dimension,
                          num_blocks,
                          planes,
                          transpose=False,
                          norm='InstanceNorm',
                          activation='LeakyReLU',
                          pool=False), Permute((0, 2, 3, 1)),
            nn.Linear(planes[-1], content_dim))
        # Decoder: linear projection back to feature planes, transposed
        # resnet trunk and transposed convs, ending in Tanh.
        decoder: nn.Module = nn.Sequential(
            nn.Linear(latent_dim, planes[-1]), Permute((0, 3, 1, 2)),
            simple_resnet(dimension,
                          num_blocks,
                          planes,
                          transpose=True,
                          norm='InstanceNorm',
                          activation='LeakyReLU',
                          pool=False),
            nn.ConvTranspose2d(planes[0],
                               planes[0],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1,
                               bias=False),
            nn.InstanceNorm2d(planes[0], affine=True), nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(planes[0],
                               out_channels,
                               kernel_size=7,
                               stride=1,
                               padding=3,
                               output_padding=0), nn.Tanh())
        # Content discriminator: strided convs over the (channels-last)
        # content map, global average pool, single logit.
        content_disc: nn.Module = nn.Sequential(
            Permute((0, 3, 1, 2)),
            nn.Conv2d(content_dim,
                      256,
                      kernel_size=4,
                      stride=2,
                      padding=1,
                      bias=True), nn.InstanceNorm2d(256, affine=True),
            nn.LeakyReLU(0.01),
            nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(256, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(256, affine=True), nn.LeakyReLU(0.01),
            nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),
            nn.Linear(256, 1, bias=True))
        # Source discriminator: strided conv stack on the input image,
        # global average pool, single logit.
        source_disc: nn.Module = nn.Sequential(
            nn.Conv2d(in_channels,
                      64,
                      kernel_size=4,
                      stride=2,
                      padding=1,
                      bias=False), nn.InstanceNorm2d(64, affine=True),
            nn.LeakyReLU(0.01),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(64, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(128, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1,
                      bias=False), nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(0.01),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1,
                      bias=False), nn.InstanceNorm2d(256, affine=True),
            nn.LeakyReLU(0.01),
            nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1,
                      bias=False), nn.InstanceNorm2d(256, affine=True),
            nn.LeakyReLU(0.01), nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),
            nn.Linear(256, 1, bias=True))

        scaler: Scaler = Scaler(2., 0.5)

        super().__init__(device, content_encoder, decoder, content_disc,
                         source_disc, scaler)

        self._content_dim = content_dim
        # Resnet-style weight init for every registered sub-module.
        self.apply(weights_init_resnet)
Example #9
    def __init__(self, device) -> None:
        """Build the concrete content encoder, decoder and discriminators,
        then delegate to the base constructor.

        Args:
            device: target device forwarded to the parent ``__init__``.
        """
        # Architecture hyper-parameters.
        dimension = 2
        in_channels = 3
        content_dim = 512
        latent_dim = content_dim
        #style_dim = 64
        #latent_dim = content_dim + style_dim
        num_blocks = [4]
        planes = [64, 64]

        #content_encoder: nn.Module = nn.Sequential(
        #    nn.Conv2d(in_channels, planes[0], kernel_size=5, stride=1, padding=2, bias=False),
        #    nn.InstanceNorm2d(planes[0], affine=True), nn.ReLU(), nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        #    simple_resnet(dimension, num_blocks, planes,
        #                  transpose=False, norm='InstanceNorm', activation='ReLU', pool=False),
        #    nn.Flatten(start_dim=1), nn.Linear(planes[-1]*7*7, content_dim))
        #style_encoder: nn.Module = nn.Sequential(
        #    nn.Conv2d(in_channels, planes[0], kernel_size=5, stride=1, padding=2, bias=False),
        #    nn.InstanceNorm2d(planes[0], affine=True), nn.ReLU(), nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        #    simple_resnet(dimension, num_blocks, planes,
        #                  transpose=False, norm='InstanceNorm', activation='ReLU', pool=False),
        #    nn.Flatten(start_dim=1), nn.Linear(planes[-1]*7*7, style_dim))
        #decoder: nn.Module = nn.Sequential(
        #    nn.Linear(latent_dim, planes[-1]*7*7), View((-1, planes[-1], 7, 7)),
        #    simple_resnet(dimension, num_blocks, planes,
        #                  transpose=True, norm='InstanceNorm', activation='ReLU', pool=False),
        #    nn.ConvTranspose2d(planes[0], planes[0], kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
        #    nn.InstanceNorm2d(planes[0], affine=True), nn.ReLU(),
        #    nn.ConvTranspose2d(planes[0], in_channels, kernel_size=5, stride=1, padding=2, output_padding=0),
        #    nn.Tanh())
        # Content encoder: conv stem -> max-pool -> resnet trunk, flattened
        # to a content_dim vector (assumes 7x7 spatial maps at the trunk
        # output — implied by the planes[-1]*7*7 Linear input size).
        content_encoder: nn.Module = nn.Sequential(
            nn.Conv2d(in_channels,
                      planes[0],
                      kernel_size=5,
                      stride=1,
                      padding=2,
                      bias=False), nn.InstanceNorm2d(planes[0], affine=True),
            nn.LeakyReLU(0.01), nn.MaxPool2d(kernel_size=3,
                                             stride=2,
                                             padding=1),
            simple_resnet(dimension,
                          num_blocks,
                          planes,
                          transpose=False,
                          norm='InstanceNorm',
                          activation='LeakyReLU',
                          pool=False), nn.Flatten(start_dim=1),
            nn.Linear(planes[-1] * 7 * 7, content_dim))
        # Decoder: linear map back to a 7x7 feature map, transposed resnet
        # trunk and transposed convs, ending in Tanh.
        decoder: nn.Module = nn.Sequential(
            nn.Linear(latent_dim, planes[-1] * 7 * 7),
            View((-1, planes[-1], 7, 7)),
            simple_resnet(dimension,
                          num_blocks,
                          planes,
                          transpose=True,
                          norm='InstanceNorm',
                          activation='LeakyReLU',
                          pool=False),
            nn.ConvTranspose2d(planes[0],
                               planes[0],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1,
                               bias=False),
            nn.InstanceNorm2d(planes[0], affine=True), nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(planes[0],
                               in_channels,
                               kernel_size=5,
                               stride=1,
                               padding=2,
                               output_padding=0), nn.Tanh())

        # Content discriminator: a small batch-normalised MLP on the
        # content vector, single logit.
        content_disc: nn.Module = nn.Sequential(
            nn.Linear(content_dim, 256, bias=False), nn.BatchNorm1d(256),
            nn.LeakyReLU(0.01), nn.Linear(256, 64, bias=False),
            nn.BatchNorm1d(64), nn.LeakyReLU(0.01), nn.Linear(64, 1,
                                                              bias=True))
        scaler: Scaler = Scaler(2., 0.5)
        #source_disc: nn.Module = nn.Sequential(
        #    nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1, bias=False), nn.InstanceNorm2d(64, affine=True), nn.LeakyReLU(0.01),
        #    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False), nn.InstanceNorm2d(128, affine=True), nn.LeakyReLU(0.01),
        #    nn.Flatten(), nn.Linear(7*7*128, 1, bias=True))
        # Source discriminator: conv stack on the input image, flattened
        # to a single logit (7x7 final spatial map implied by the Linear
        # input size).
        source_disc: nn.Module = nn.Sequential(
            nn.Conv2d(in_channels,
                      64,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False), nn.InstanceNorm2d(64, affine=True),
            nn.LeakyReLU(0.01),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(64, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(128, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1,
                      bias=False), nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(0.01), nn.Flatten(),
            nn.Linear(7 * 7 * 128, 1, bias=True))

        super().__init__(device, content_encoder, decoder, content_disc,
                         source_disc, scaler)

        self._content_dim = content_dim
        # Resnet-style weight init for every registered sub-module.
        self.apply(weights_init_resnet)
Example #10
    def __init__(self, device) -> None:
        """Build the encoder/decoder, style weighting and six
        discriminators, delegate the first five modules to the base
        constructor, and register the remaining discriminators and
        criteria directly on this instance.

        Args:
            device: target device forwarded to the parent ``__init__``.
        """
        # Architecture hyper-parameters.
        dimension = 2
        in_channels = 3
        out_channels = 3
        latent_dim = 256
        num_blocks = [4, 4]
        planes = [64, 128, 256]

        # Encoder: conv stem -> max-pool -> resnet trunk, then a
        # per-position linear projection to latent_dim in channels-last
        # layout (via Permute).
        encoder: nn.Module = nn.Sequential(
            nn.Conv2d(in_channels,
                      planes[0],
                      kernel_size=7,
                      stride=1,
                      padding=3,
                      bias=False), nn.InstanceNorm2d(planes[0], affine=True),
            nn.LeakyReLU(0.01), nn.MaxPool2d(kernel_size=3,
                                             stride=2,
                                             padding=1),
            simple_resnet(dimension,
                          num_blocks,
                          planes,
                          transpose=False,
                          norm='InstanceNorm',
                          activation='LeakyReLU',
                          pool=False), Permute((0, 2, 3, 1)),
            nn.Linear(planes[-1], latent_dim))
        # Decoder: mirror of the encoder using transposed convolutions,
        # ending in Tanh.
        decoder: nn.Module = nn.Sequential(
            nn.Linear(latent_dim, planes[-1]), Permute((0, 3, 1, 2)),
            simple_resnet(dimension,
                          num_blocks,
                          planes,
                          transpose=True,
                          norm='InstanceNorm',
                          activation='LeakyReLU',
                          pool=False),
            nn.ConvTranspose2d(planes[0],
                               planes[0],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1,
                               bias=False),
            nn.InstanceNorm2d(planes[0], affine=True), nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(planes[0],
                               out_channels,
                               kernel_size=7,
                               stride=1,
                               padding=3,
                               output_padding=0), nn.Tanh())
        # Style weighting: a single bias-free linear map on the latent
        # space.
        style_w: nn.Module = nn.Linear(latent_dim, latent_dim, bias=False)
        # Content discriminator: strided convs over the (channels-last)
        # latent map, flattened to a spectral-normalised single logit.
        content_disc: nn.Module = nn.Sequential(
            Permute((0, 3, 1, 2)),
            nn.Conv2d(latent_dim,
                      64,
                      kernel_size=4,
                      stride=2,
                      padding=1,
                      bias=True), nn.InstanceNorm2d(64, affine=True),
            nn.LeakyReLU(0.01),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(64, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(64, affine=True), nn.LeakyReLU(0.01),
            nn.Flatten(), spectral_norm(nn.Linear(4 * 4 * 64, 1, bias=False)))
        # Style discriminator: same topology as content_disc.
        style_disc: nn.Module = nn.Sequential(
            Permute((0, 3, 1, 2)),
            nn.Conv2d(latent_dim,
                      64,
                      kernel_size=4,
                      stride=2,
                      padding=1,
                      bias=True), nn.InstanceNorm2d(64, affine=True),
            nn.LeakyReLU(0.01),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(64, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True),
            nn.InstanceNorm2d(64, affine=True), nn.LeakyReLU(0.01),
            nn.Flatten(), spectral_norm(nn.Linear(4 * 4 * 64, 1, bias=False)))
        scaler: Scaler = Scaler(2., 0.5)

        super().__init__(device, encoder, decoder, style_w, content_disc,
                         style_disc, scaler)

        # Additional image-space discriminators registered directly on this
        # instance (not passed to the parent constructor).
        self._source_disc: nn.Module = nn.Sequential(
            nn.Conv2d(in_channels,
                      32,
                      kernel_size=4,
                      stride=2,
                      padding=1,
                      bias=False), nn.InstanceNorm2d(32, affine=True),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(32, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(64, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(64, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(128, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1,
                      bias=False), nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(0.01), nn.Flatten(),
            spectral_norm(nn.Linear(4 * 4 * 128, 1, bias=False)))
        # Reference discriminator: same topology as _source_disc.
        self._reference_disc: nn.Module = nn.Sequential(
            nn.Conv2d(in_channels,
                      32,
                      kernel_size=4,
                      stride=2,
                      padding=1,
                      bias=False), nn.InstanceNorm2d(32, affine=True),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(32, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(64, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(64, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(128, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1,
                      bias=False), nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(0.01), nn.Flatten(),
            spectral_norm(nn.Linear(4 * 4 * 128, 1, bias=False)))

        # Segmentation discriminators: conv stacks over the (channels-last)
        # latent map, upsampling once and emitting a 15-channel map
        # (presumably 15 segmentation classes — confirm against the
        # dataset's label set).
        self._content_seg_disc: nn.Module = nn.Sequential(
            Permute((0, 3, 1, 2)),
            nn.Conv2d(latent_dim,
                      128,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=True), nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(0.01),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),
            nn.InstanceNorm2d(128, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.InstanceNorm2d(64, affine=True), nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(64,
                               64,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1,
                               bias=True), nn.InstanceNorm2d(64, affine=True),
            nn.LeakyReLU(0.01),
            spectral_norm(
                nn.Conv2d(64,
                          15,
                          kernel_size=7,
                          stride=1,
                          padding=3,
                          bias=False)))
        # Same topology as _content_seg_disc, applied to the style path.
        self._style_seg_disc: nn.Module = nn.Sequential(
            Permute((0, 3, 1, 2)),
            nn.Conv2d(latent_dim,
                      128,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=True), nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(0.01),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),
            nn.InstanceNorm2d(128, affine=True), nn.LeakyReLU(0.01),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.InstanceNorm2d(64, affine=True), nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(64,
                               64,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1,
                               bias=True), nn.InstanceNorm2d(64, affine=True),
            nn.LeakyReLU(0.01),
            spectral_norm(
                nn.Conv2d(64,
                          15,
                          kernel_size=7,
                          stride=1,
                          padding=3,
                          bias=False)))

        # Adversarial criteria on raw logits.
        self._source_criterion = nn.BCEWithLogitsLoss()
        self._reference_criterion = nn.BCEWithLogitsLoss()

        # Segmentation criteria over the 15-channel class maps.
        self._content_seg_criterion = nn.CrossEntropyLoss()
        self._style_seg_criterion = nn.CrossEntropyLoss()

        self.to(self._device)
        # Resnet-style weight init for every registered sub-module.
        self.apply(weights_init_resnet)
Example #11
    def __init__(self, device) -> None:
        """Build the concrete encoder, decoder, style weighting and the
        two MLP discriminators, then delegate to the base constructor.

        Args:
            device: target device forwarded to the parent ``__init__``.
        """
        # Architecture hyper-parameters.
        dimension = 2
        in_channels = 3
        latent_dim = 1024
        num_blocks = [4]
        planes = [64, 64]

        # Encoder: conv stem -> max-pool -> resnet trunk, flattened to a
        # latent_dim vector (7x7 trunk output implied by the Linear input
        # size planes[-1]*7*7).
        encoder: nn.Module = nn.Sequential(
            nn.Conv2d(in_channels,
                      planes[0],
                      kernel_size=5,
                      stride=1,
                      padding=2,
                      bias=False), nn.InstanceNorm2d(planes[0], affine=True),
            nn.LeakyReLU(0.01), nn.MaxPool2d(kernel_size=3,
                                             stride=2,
                                             padding=1),
            simple_resnet(dimension,
                          num_blocks,
                          planes,
                          transpose=False,
                          norm='InstanceNorm',
                          activation='LeakyReLU',
                          pool=False), nn.Flatten(start_dim=1),
            nn.Linear(planes[-1] * 7 * 7, latent_dim))
        # Decoder: linear map back to a 7x7 feature map, transposed resnet
        # trunk and transposed convs, ending in Tanh.
        decoder: nn.Module = nn.Sequential(
            nn.Linear(latent_dim, planes[-1] * 7 * 7),
            View((-1, planes[-1], 7, 7)),
            simple_resnet(dimension,
                          num_blocks,
                          planes,
                          transpose=True,
                          norm='InstanceNorm',
                          activation='LeakyReLU',
                          pool=False),
            nn.ConvTranspose2d(planes[0],
                               planes[0],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1,
                               bias=False),
            nn.InstanceNorm2d(planes[0], affine=True), nn.LeakyReLU(0.01),
            nn.ConvTranspose2d(planes[0],
                               in_channels,
                               kernel_size=5,
                               stride=1,
                               padding=2,
                               output_padding=0), nn.Tanh())
        # Style weighting: a single bias-free linear map on the latent
        # space.
        style_w: nn.Module = nn.Linear(latent_dim, latent_dim, bias=False)
        # Batch-normalised MLP discriminators on the latent vector, each
        # producing a single logit.
        content_disc: nn.Module = nn.Sequential(
            nn.Linear(latent_dim, 256, bias=False), nn.BatchNorm1d(256),
            nn.LeakyReLU(0.01), nn.Linear(256, 64, bias=False),
            nn.BatchNorm1d(64), nn.LeakyReLU(0.01), nn.Linear(64, 1,
                                                              bias=True))
        style_disc: nn.Module = nn.Sequential(
            nn.Linear(latent_dim, 256, bias=False),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, 64, bias=False),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.01),
            nn.Linear(64, 1, bias=True))
        scaler: Scaler = Scaler(2., 0.5)

        super().__init__(device, encoder, decoder, style_w, content_disc,
                         style_disc, scaler)
        # Resnet-style weight init for every registered sub-module.
        self.apply(weights_init_resnet)