예제 #1
0
    def get_conv_layer(self, ni, upsample=False, blur_op=None, append_nl=True):
        """Assemble the sub-modules of one 3x3 convolutional layer.

        Args:
            ni: Number of input feature maps of the convolution.
            upsample: If true, `self.upsampler` is placed before the
                convolution.
            blur_op: Optional `nn.Module` placed right after the convolution.
            append_nl: If true, `self.nl` is placed after the bias.

        Returns:
            An `nn.ModuleList` with four entries, in this order:
            ``Sequential(upsampler?, conv, blur?)``, a `StyleAddNoise`
            module (or `None` when `self.use_noise` is false),
            ``Sequential(bias?, nl?, norms...)`` and the `LinearEx` that maps
            a vector of width `self.z_to_w.dims[-1]` to `2 * self.fmap`
            outputs.
        """
        # When noise or a blur sits between the convolution and the next
        # stage, the bias is split off into its own `Conv2dBias` module that
        # lives in the trailing Sequential; otherwise the conv owns its bias.
        detached_bias = self.use_noise or blur_op is not None
        conv = Conv2dEx(ni=ni,
                        nf=self.fmap,
                        ks=3,
                        stride=1,
                        padding=1,
                        init='He',
                        init_type='StyleGAN',
                        gain_sq_base=2.,
                        equalized_lr=self.equalized_lr,
                        include_bias=not detached_bias)
        trailing = [Conv2dBias(nf=self.fmap)] if detached_bias else []

        leading = [self.upsampler] if upsample else []
        leading.append(conv)
        if blur_op is not None:
            assert isinstance(blur_op, nn.Module)
            leading.append(blur_op)

        noise = StyleAddNoise(nf=self.fmap) if self.use_noise else None

        if append_nl:
            trailing.append(self.nl)
        if self.use_pixelnorm:
            trailing.append(NormalizeLayer('PixelNorm'))
        if self.use_instancenorm:
            trailing.append(NormalizeLayer('InstanceNorm'))

        w_to_style = LinearEx(nin_feat=self.z_to_w.dims[-1],
                              nout_feat=2 * self.fmap,
                              init='He',
                              init_type='StyleGAN',
                              gain_sq_base=1.,
                              equalized_lr=self.equalized_lr)

        return nn.ModuleList([
            nn.Sequential(*leading),
            noise,
            nn.Sequential(*trailing),
            w_to_style,
        ])
예제 #2
0
    def __init__(self,
                 len_latent=128,
                 fmap=FMAP_G,
                 upsampler=nn.Upsample(scale_factor=2, mode='nearest'),
                 blur_type=None,
                 nl=nn.ReLU(),
                 num_classes=0,
                 equalized_lr=False):
        """Create the layer stack stored in `self.generator_model`.

        Args:
            len_latent: Length of the latent vector.
            fmap: Base feature-map count; the residual blocks output
                8x, 4x, 2x and 1x this value.
            upsampler: Module handed to every `ResBlock2d`.
            blur_type: Forwarded to every `ResBlock2d`.
            nl: Non-linearity module; the same instance is shared by all
                residual blocks and the output head.
            num_classes: Extra input width appended to the latent vector.
            equalized_lr: Forwarded to every `LinearEx`/`Conv2dEx`/`ResBlock2d`.
        """
        super(Generator64PixResnet, self).__init__(64)

        self.len_latent = len_latent
        self.num_classes = num_classes

        self.equalized_lr = equalized_lr

        _fmap_init_64 = len_latent * FMAP_G_INIT_64_FCTR

        # Flatten the input, project it, and reshape it to an initial
        # (_fmap_init_64, RES_INIT, RES_INIT) feature map.
        stem = [
            Lambda(lambda x: x.view(-1, len_latent + num_classes)),
            LinearEx(nin_feat=len_latent + num_classes,
                     nout_feat=_fmap_init_64 * RES_INIT**2,
                     init='Xavier',
                     equalized_lr=equalized_lr),
            Lambda(lambda x: x.view(-1, _fmap_init_64, RES_INIT, RES_INIT)),
        ]

        # Channel widths at the boundaries of the four residual blocks.
        widths = [_fmap_init_64, 8 * fmap, 4 * fmap, 2 * fmap, 1 * fmap]
        res_stack = [
            ResBlock2d(ni=width_in,
                       nf=width_out,
                       ks=3,
                       norm_type='BatchNorm',
                       upsampler=upsampler,
                       init='He',
                       nl=nl,
                       equalized_lr=equalized_lr,
                       blur_type=blur_type)
            for width_in, width_out in zip(widths[:-1], widths[1:])
        ]

        # Output head: norm -> non-linearity -> 3x3 conv to FMAP_SAMPLES
        # channels -> tanh.
        head = [
            NormalizeLayer('BatchNorm', ni=1 * fmap),
            nl,
            Conv2dEx(ni=1 * fmap,
                     nf=FMAP_SAMPLES,
                     ks=3,
                     stride=1,
                     padding=1,
                     init='He',
                     equalized_lr=equalized_lr),
            nn.Tanh(),
        ]

        self.generator_model = nn.Sequential(*stem, *res_stack, *head)
예제 #3
0
    def __init__(self,
                 len_latent=512,
                 len_dlatent=512,
                 num_fcs=8,
                 lrmul=.01,
                 nl=nn.LeakyReLU(negative_slope=.2),
                 equalized_lr=True,
                 normalize_z=True):
        """Create the fully connected mapping stack.

        Args:
            len_latent: Width of the input latent vector.
            len_dlatent: Width of the output of the last FC layer.
            num_fcs: Number of `LinearEx` layers.
            lrmul: Forwarded to every `LinearEx`.
            nl: Non-linearity module registered after every FC layer (the
                same instance is reused for all of them).
            equalized_lr: Forwarded to every `LinearEx`.
            normalize_z: If true, a PixelNorm layer follows the flattening
                step in `self.preprocess_z`.
        """
        super(StyleMappingNetwork, self).__init__()

        flatten_z = Lambda(lambda x: x.view(-1, len_latent))
        if normalize_z:
            self.preprocess_z = nn.Sequential(flatten_z,
                                              NormalizeLayer('PixelNorm'))
        else:
            self.preprocess_z = flatten_z

        # Layer widths are interpolated linearly from len_latent to
        # len_dlatent and truncated to integers.
        self.dims = np.linspace(len_latent, len_dlatent,
                                num_fcs + 1).astype(np.int64)

        self.fc_mapping_model = nn.Sequential()
        width_pairs = zip(self.dims[:-1], self.dims[1:])
        for idx, (width_in, width_out) in enumerate(width_pairs):
            fc = LinearEx(nin_feat=width_in,
                          nout_feat=width_out,
                          init='He',
                          init_type='StyleGAN',
                          gain_sq_base=2.,
                          equalized_lr=equalized_lr,
                          lrmul=lrmul)
            self.fc_mapping_model.add_module('fc_{}'.format(idx), fc)
            self.fc_mapping_model.add_module('nl_{}'.format(idx), nl)
예제 #4
0
    def __init__(self,
                 fmap=FMAP_D * 2,
                 pooler=nn.AvgPool2d(kernel_size=2, stride=2),
                 blur_type=None,
                 nl=nn.ReLU(),
                 num_classes=0,
                 equalized_lr=False):
        """Initialise the base discriminator, then replace its input layers
        and add the `linear_aux` head.

        Args:
            fmap: Feature-map count of `conv1` and input width of
                `linear_aux`.
            pooler: Pooling module handed to `FastResBlock2dDownsample`.
            blur_type: Forwarded to `FastResBlock2dDownsample`.
            nl: Non-linearity module.
            num_classes: Output width of `linear_aux`.
            equalized_lr: Forwarded to the layers built here.
        """
        # Forward the arguments by keyword. The previous positional call used
        # the order (fmap, pooler, nl, num_classes, equalized_lr, blur_type),
        # which does not match the (fmap, pooler, blur_type, nl, num_classes,
        # equalized_lr) order this constructor itself declares; with keywords
        # the binding no longer depends on the parent's positional order.
        # NOTE(review): the base class is not visible from this method;
        # confirm its __init__ uses these parameter names.
        super(DiscriminatorAC32PixResnet, self).__init__(
            fmap=fmap,
            pooler=pooler,
            blur_type=blur_type,
            nl=nl,
            num_classes=num_classes,
            equalized_lr=equalized_lr)

        # Here the input is viewed with FMAP_SAMPLES channels only (no
        # num_classes added); `self.res` is looked up when the lambda runs.
        self.view1 = Lambda(
            lambda x: x.view(-1, FMAP_SAMPLES, self.res, self.res))
        self.conv1 = FastResBlock2dDownsample(ni=FMAP_SAMPLES,
                                              nf=fmap,
                                              ks=3,
                                              pooler=pooler,
                                              init='Xavier',
                                              nl=nl,
                                              equalized_lr=equalized_lr,
                                              blur_type=blur_type)
        # Additional linear head with one output per class.
        self.linear_aux = LinearEx(nin_feat=fmap,
                                   nout_feat=num_classes,
                                   init='Xavier',
                                   equalized_lr=equalized_lr)
예제 #5
0
    def __init__(self,
                 num_classes,
                 len_latent=512,
                 len_dlatent=512,
                 num_fcs=8,
                 lrmul=.01,
                 nl=nn.LeakyReLU(negative_slope=.2),
                 equalized_lr=True,
                 normalize_z=True,
                 embed_cond_vars=True):
        """Create the class-conditioned fully connected mapping stack.

        Args:
            num_classes: Width of the conditioning input.
            len_latent: Width of the latent vector.
            len_dlatent: Target width of the last FC layer.
            num_fcs: Number of `LinearEx` layers in `fc_mapping_model`.
            lrmul: Forwarded to every FC layer.
            nl: Non-linearity module registered after every FC layer (one
                shared instance).
            equalized_lr: Forwarded to every FC layer.
            normalize_z: If true, `fc_mapping_model` starts with PixelNorm.
            embed_cond_vars: If true, a bias-free `LinearEx` named
                `class_embedding` maps `num_classes` to `len_latent`.
        """
        super(StyleConditionedMappingNetwork, self).__init__()

        self.len_latent = len_latent
        self.num_classes = num_classes

        self.embed_cond_vars = embed_cond_vars
        if embed_cond_vars:
            self.class_embedding = LinearEx(nin_feat=num_classes,
                                            nout_feat=len_latent,
                                            init=None,
                                            init_type='Standard Normal',
                                            include_bias=False)

        # Width of the first FC layer's input: twice the latent width when
        # the class embedding is used, otherwise latent width + num_classes.
        if self.embed_cond_vars:
            first_width = 2 * len_latent
        else:
            first_width = len_latent + num_classes

        # The remaining widths are interpolated from len_latent to
        # len_dlatent over num_fcs points, then the input width is prepended.
        self.dims = np.linspace(len_latent, len_dlatent,
                                num_fcs).astype(np.int64)
        self.dims = np.insert(self.dims, 0, first_width)

        self.fc_mapping_model = nn.Sequential()
        if normalize_z:
            self.fc_mapping_model.add_module('pixelnorm',
                                             NormalizeLayer('PixelNorm'))
        for idx in range(num_fcs):
            fc = LinearEx(nin_feat=self.dims[idx],
                          nout_feat=self.dims[idx + 1],
                          init='He',
                          init_type='StyleGAN',
                          gain_sq_base=2.,
                          equalized_lr=equalized_lr,
                          lrmul=lrmul)
            self.fc_mapping_model.add_module('fc_{}'.format(idx), fc)
            self.fc_mapping_model.add_module('nl_{}'.format(idx), nl)
예제 #6
0
    def __init__(self,
                 fmap=FMAP_D * 2,
                 pooler=nn.AvgPool2d(kernel_size=2, stride=2),
                 blur_type=None,
                 nl=nn.ReLU(),
                 num_classes=0,
                 equalized_lr=False):
        """Create `view1`, `conv1`, `resblocks` and `linear1`.

        Args:
            fmap: Feature-map count used by every block.
            pooler: Pooling module given to `conv1` and to the first
                residual block.
            blur_type: Forwarded to all residual blocks.
            nl: Non-linearity module (one shared instance).
            num_classes: Extra input channels added to FMAP_SAMPLES.
            equalized_lr: Forwarded to all layers built here.
        """
        super(Discriminator32PixResnet, self).__init__(32)

        self.num_classes = num_classes
        self.equalized_lr = equalized_lr

        n_in = FMAP_SAMPLES + num_classes

        # `self.res` is looked up when the lambda runs, not here.
        self.view1 = Lambda(lambda x: x.view(-1, FMAP_SAMPLES + num_classes,
                                             self.res, self.res))
        self.conv1 = FastResBlock2dDownsample(ni=n_in,
                                              nf=fmap,
                                              ks=3,
                                              pooler=pooler,
                                              init='Xavier',
                                              nl=nl,
                                              equalized_lr=equalized_lr,
                                              blur_type=blur_type)

        # Arguments shared by the three residual blocks.
        block_kwargs = dict(ni=fmap,
                            nf=fmap,
                            ks=3,
                            norm_type='LayerNorm',
                            init='He',
                            nl=nl,
                            equalized_lr=equalized_lr,
                            blur_type=blur_type)
        quarter_res = self.res // 4
        self.resblocks = nn.Sequential(
            # Only the first block receives the pooler.
            ResBlock2d32Pix(pooler=pooler, res=self.res // 2, **block_kwargs),
            ResBlock2d32Pix(res=quarter_res, **block_kwargs),
            ResBlock2d32Pix(res=quarter_res, **block_kwargs),
            nl,
            nn.AvgPool2d(kernel_size=quarter_res, stride=quarter_res),
            Lambda(lambda x: x.view(-1, fmap)),  # final feature space
            # (the original author left a trailing LayerNorm disabled here)
        )
        self.linear1 = LinearEx(nin_feat=fmap,
                                nout_feat=1,
                                init='Xavier',
                                equalized_lr=equalized_lr)
예제 #7
0
    def __init__(self,
                 fmap=FMAP_D,
                 pooler=nn.AvgPool2d(kernel_size=2, stride=2),
                 blur_type=None,
                 nl=nn.ReLU(),
                 num_classes=0,
                 equalized_lr=False):
        """Create `view1`, `conv1`, `resblocks` and `linear1`.

        Args:
            fmap: Base feature-map count; the residual blocks widen it to
                2x, 4x, 8x and 8x.
            pooler: Pooling module given to every residual block.
            blur_type: Forwarded to every residual block.
            nl: Non-linearity module (one shared instance).
            num_classes: Extra input channels added to FMAP_SAMPLES.
            equalized_lr: Forwarded to all layers built here.
        """
        super(Discriminator64PixResnet, self).__init__(64)

        self.num_classes = num_classes
        self.equalized_lr = equalized_lr

        # `self.res` is looked up when the lambda runs, not here.
        self.view1 = Lambda(lambda x: x.view(-1, FMAP_SAMPLES + num_classes,
                                             self.res, self.res))
        self.conv1 = Conv2dEx(ni=FMAP_SAMPLES + num_classes,
                              nf=1 * fmap,
                              ks=3,
                              stride=1,
                              padding=1,
                              init='Xavier',
                              equalized_lr=equalized_lr)

        # (input multiplier, output multiplier, resolution divisor) for each
        # residual block, in order.
        stage_table = ((1, 2, 1), (2, 4, 2), (4, 8, 4), (8, 8, 8))
        stages = [
            ResBlock2d(ni=mult_in * fmap,
                       nf=mult_out * fmap,
                       ks=3,
                       norm_type='LayerNorm',
                       pooler=pooler,
                       init='He',
                       nl=nl,
                       res=self.res // divisor,
                       equalized_lr=equalized_lr,
                       blur_type=blur_type)
            for mult_in, mult_out, divisor in stage_table
        ]
        flat_width = RES_FEATURE_SPACE**2 * 8 * fmap
        self.resblocks = nn.Sequential(
            *stages,
            Lambda(lambda x: x.view(-1, RES_FEATURE_SPACE**2 * 8 * fmap)),  # final feature space
            # (the original author left a trailing LayerNorm disabled here)
        )
        self.linear1 = LinearEx(nin_feat=flat_width,
                                nout_feat=1,
                                init='Xavier',
                                equalized_lr=equalized_lr)
예제 #8
0
    def __init__(self,
                 fmap=FMAP_D,
                 pooler=nn.AvgPool2d(kernel_size=2, stride=2),
                 blur_type=None,
                 nl=nn.ReLU(),
                 num_classes=0,
                 equalized_lr=False):
        """Initialise the base discriminator, then replace its input layers
        and add the `linear_aux` head.

        Args:
            fmap: Base feature-map count; `conv1` outputs `1 * fmap` maps.
            pooler: Pooling module forwarded to the base constructor.
            blur_type: Forwarded to the base constructor.
            nl: Non-linearity module forwarded to the base constructor.
            num_classes: Output width of `linear_aux`.
            equalized_lr: Forwarded to the layers built here.
        """
        # Forward the arguments by keyword. The previous positional call used
        # the order (fmap, pooler, nl, num_classes, equalized_lr, blur_type),
        # which does not match the (fmap, pooler, blur_type, nl, num_classes,
        # equalized_lr) order this constructor itself declares; with keywords
        # the binding no longer depends on the parent's positional order.
        # NOTE(review): the base class is not visible from this method;
        # confirm its __init__ uses these parameter names.
        super(DiscriminatorAC64PixResnet, self).__init__(
            fmap=fmap,
            pooler=pooler,
            blur_type=blur_type,
            nl=nl,
            num_classes=num_classes,
            equalized_lr=equalized_lr)

        # Here the input is viewed with FMAP_SAMPLES channels only (no
        # num_classes added); `self.res` is looked up when the lambda runs.
        self.view1 = Lambda(
            lambda x: x.view(-1, FMAP_SAMPLES, self.res, self.res))
        self.conv1 = Conv2dEx(ni=FMAP_SAMPLES,
                              nf=1 * fmap,
                              ks=3,
                              stride=1,
                              padding=1,
                              init='Xavier',
                              equalized_lr=equalized_lr)
        # Additional linear head with one output per class.
        self.linear_aux = LinearEx(nin_feat=RES_FEATURE_SPACE**2 * 8 * fmap,
                                   nout_feat=num_classes,
                                   init='Xavier',
                                   equalized_lr=equalized_lr)
예제 #9
0
    def __init__(self,
                 final_res,
                 len_latent=512,
                 upsampler=nn.Upsample(scale_factor=2, mode='nearest'),
                 blur_type=None,
                 nl=nn.LeakyReLU(negative_slope=.2),
                 num_classes=0,
                 equalized_lr=True,
                 normalize_z=True,
                 use_pixelnorm=True):
        """Create the first generator block and the initial to-RGB layer.

        Args:
            final_res: Forwarded to the base-class constructor.
            len_latent: Width of the latent vector.
            upsampler: Stored as `self.upsampler`.
            blur_type: Stored as `self.gen_blur_type`.
            nl: Non-linearity module (stored, and reused inside the block).
            num_classes: Extra input width appended to the latent vector.
            equalized_lr: Forwarded to the layers built here.
            normalize_z: If true, `preprocess_z` applies PixelNorm after
                flattening the input.
            use_pixelnorm: If true, PixelNorm follows each non-linearity in
                the first block.
        """
        # NOTE(review): `super(self.__class__, self)` resolves relative to the
        # instance's class, not the defining class, so it behaves differently
        # from `super()` under subclassing; left untouched because the class
        # hierarchy is not visible from here.
        super(self.__class__, self).__init__(final_res)

        self.gen_blocks = nn.ModuleList()

        self.upsampler = upsampler
        # keep fading-in layers simple
        self.upsampler_skip_connection = \
            lambda xb: F.interpolate(xb, scale_factor=2, mode='nearest')

        self.gen_blur_type = blur_type

        self.nl = nl

        self.len_latent = len_latent
        self.num_classes = num_classes

        self.equalized_lr = equalized_lr

        self.use_pixelnorm = use_pixelnorm
        # At most one PixelNorm instance; it is inserted twice below.
        norms = [NormalizeLayer('PixelNorm')] if use_pixelnorm else []

        flatten_z = Lambda(lambda x: x.view(-1, len_latent + num_classes))
        if normalize_z:
            self.preprocess_z = nn.Sequential(flatten_z,
                                              NormalizeLayer('PixelNorm'))
        else:
            self.preprocess_z = flatten_z

        _fmap_init = len_latent * FMAP_G_INIT_FCTR
        first_block = [
            # this can be done with a transpose conv layer as well (efficiency)
            LinearEx(nin_feat=len_latent + num_classes,
                     nout_feat=_fmap_init * RES_INIT**2,
                     init='He',
                     init_type='ProGAN',
                     gain_sq_base=2. / 16,
                     equalized_lr=equalized_lr),
            Lambda(lambda x: x.view(-1, _fmap_init, RES_INIT, RES_INIT)),
            nl,
        ]
        first_block.extend(norms)
        first_block.append(
            Conv2dEx(ni=_fmap_init,
                     nf=self.fmap,
                     ks=3,
                     stride=1,
                     padding=1,
                     init='He',
                     init_type='ProGAN',
                     gain_sq_base=2.,
                     equalized_lr=equalized_lr))
        first_block.append(nl)
        first_block.extend(norms)
        self.gen_blocks.append(nn.Sequential(*first_block))

        self.prev_torgb = None
        self._update_torgb(ni=self.fmap)
예제 #10
0
    def __init__(self,
                 final_res,
                 pooler=nn.AvgPool2d(kernel_size=2, stride=2),
                 blur_type=None,
                 nl=nn.LeakyReLU(negative_slope=.2),
                 num_classes=0,
                 equalized_lr=True,
                 mbstd_group_size=4):
        """Create the last discriminator block and the initial from-RGB layer.

        Args:
            final_res: Forwarded to the base-class constructor.
            pooler: Stored as `self.pooler`.
            blur_type: Stored as `self.disc_blur_type`.
            nl: Non-linearity module (stored, and reused inside the block).
            num_classes: Extra input channels added to FMAP_SAMPLES.
            equalized_lr: Forwarded to the layers built here.
            mbstd_group_size: Stored before `self.get_mbstd_layer()` is
                called.

        Raises:
            RuntimeError: If `self.cls_base.__name__` is neither 'ProGAN'
                nor 'StyleGAN'.
        """
        # NOTE(review): `super(self.__class__, self)` resolves relative to the
        # instance's class, not the defining class; left untouched because the
        # class hierarchy is not visible from here.
        super(self.__class__, self).__init__(final_res)
        self.init_type = self.cls_base.__name__
        supported_bases = ('ProGAN', 'StyleGAN')
        if self.init_type not in supported_bases:
            raise RuntimeError(
                'This class can only inherit from either `ProGAN` or `StyleGAN` base classes currently.'
            )

        self.disc_blocks = nn.ModuleList()

        self.num_classes = num_classes

        # `self.curr_res` is looked up when the lambda runs, not here.
        self.preprocess_x = Lambda(lambda x: x.view(
            -1, FMAP_SAMPLES + num_classes, self.curr_res, self.curr_res))

        self.pooler = pooler
        # keep fading-in layers simple
        self.pooler_skip_connection = \
            lambda xb: F.avg_pool2d(xb, kernel_size=2, stride=2)

        self.disc_blur_type = blur_type

        self.nl = nl

        self.equalized_lr = equalized_lr

        self.mbstd_group_size = mbstd_group_size
        mbstd_layer = self.get_mbstd_layer()

        self.prev_fromrgb = None
        self._update_fromrgb(nf=self.fmap)

        _fmap_end = self.fmap * FMAP_D_END_FCTR

        # Initialisation settings shared by every layer of the final block.
        init_kwargs = dict(init='He',
                           init_type=self.init_type,
                           equalized_lr=equalized_lr)
        # One extra input channel when a minibatch-stddev layer is present.
        extra_channels = 1 if mbstd_layer else 0

        final_block = list(mbstd_layer)
        final_block += [
            # this can be done with a linear layer as well (efficiency)
            Conv2dEx(ni=self.fmap + extra_channels,
                     nf=self.fmap,
                     ks=3,
                     stride=1,
                     padding=1,
                     gain_sq_base=2.,
                     **init_kwargs),
            nl,
            Conv2dEx(ni=self.fmap,
                     nf=_fmap_end,
                     ks=4,
                     stride=1,
                     padding=0,
                     gain_sq_base=2.,
                     **init_kwargs),
            nl,
            Lambda(lambda x: x.view(-1, _fmap_end)),
            LinearEx(nin_feat=_fmap_end,
                     nout_feat=1,
                     gain_sq_base=1.,
                     **init_kwargs),
        ]
        self.disc_blocks.insert(0, nn.Sequential(*final_block))
예제 #11
0
    def __init__(self,
                 final_res,
                 latent_distribution='normal',
                 len_latent=512,
                 len_dlatent=512,
                 mapping_num_fcs=8,
                 mapping_lrmul=.01,
                 use_instancenorm=True,
                 use_noise=True,
                 upsampler=nn.Upsample(scale_factor=2, mode='nearest'),
                 blur_type=None,
                 nl=nn.LeakyReLU(negative_slope=.2),
                 num_classes=0,
                 equalized_lr=True,
                 normalize_z=True,
                 use_pixelnorm=False,
                 pct_mixing_reg=.9,
                 truncation_trick_params={
                     'beta': .995,
                     'psi': .7,
                     'cutoff_stage': 4
                 }):
        """Create the mapping network, the constant input and the first two
        generator layers, then the initial to-RGB layer.

        Args:
            final_res: Forwarded to the base-class constructor.
            latent_distribution: Stored as `self.latent_distribution`.
            len_latent: Width of the latent vector.
            len_dlatent: Forwarded to the mapping network.
            mapping_num_fcs: Number of FC layers of the mapping network.
            mapping_lrmul: `lrmul` of the mapping network.
            use_instancenorm: If true, InstanceNorm is added to each layer's
                norm list.
            use_noise: If true, `StyleAddNoise` modules and separate
                `Conv2dBias` modules are created; otherwise the conv carries
                its own bias.
            upsampler: Stored as `self.upsampler`.
            blur_type: Stored as `self.gen_blur_type`.
            nl: Non-linearity module (stored and shared).
            num_classes: Must be an int; a non-zero value selects
                `StyleConditionedMappingNetwork`.
            equalized_lr: Forwarded to the layers built here.
            normalize_z: Forwarded to the mapping network.
            use_pixelnorm: If true, PixelNorm is added to each layer's norm
                list.
            pct_mixing_reg: Stored; its truthiness sets `_use_mixing_reg`.
            truncation_trick_params: Dict with keys 'beta', 'psi' and
                'cutoff_stage'. It is only read here, never modified.
        """
        # NOTE(review): `super(self.__class__, self)` resolves relative to the
        # instance's class, not the defining class; left untouched because the
        # class hierarchy is not visible from here.
        super(self.__class__, self).__init__(final_res)

        self.gen_layers = nn.ModuleList()

        self.upsampler = upsampler
        # keep fading-in layers simple
        self.upsampler_skip_connection = \
            lambda xb: F.interpolate(xb, scale_factor=2, mode='nearest')

        self.gen_blur_type = blur_type

        self.nl = nl

        self.equalized_lr = equalized_lr

        self.pct_mixing_reg = pct_mixing_reg
        self._use_mixing_reg = bool(pct_mixing_reg)

        self.latent_distribution = latent_distribution
        self.len_latent = len_latent
        self.len_dlatent = len_dlatent
        assert isinstance(num_classes, int)
        self.num_classes = num_classes

        # Mapping network: the conditioned variant is used as soon as there
        # is at least one class.
        mapping_kwargs = dict(len_latent=len_latent,
                              len_dlatent=len_dlatent,
                              num_fcs=mapping_num_fcs,
                              lrmul=mapping_lrmul,
                              nl=nl,
                              equalized_lr=equalized_lr,
                              normalize_z=normalize_z)
        if num_classes:
            self.z_to_w = StyleConditionedMappingNetwork(
                num_classes, **mapping_kwargs)
        else:
            self.z_to_w = StyleMappingNetwork(**mapping_kwargs)

        _fmap_init = len_latent * FMAP_G_INIT_FCTR

        # initializing the input to 1 has about the same effect as applying
        # PixelNorm to the input
        self.const_input = nn.Parameter(
            torch.FloatTensor(1, _fmap_init, RES_INIT, RES_INIT).fill_(1))

        self._use_noise = use_noise
        self._trained_with_noise = use_noise

        # With noise the bias is split off into separate Conv2dBias modules;
        # without noise the convolution carries its own bias.
        conv = Conv2dEx(ni=_fmap_init,
                        nf=self.fmap,
                        ks=3,
                        stride=1,
                        padding=1,
                        init='He',
                        init_type='StyleGAN',
                        gain_sq_base=2.,
                        equalized_lr=equalized_lr,
                        include_bias=not use_noise)
        if use_noise:
            noise = [
                StyleAddNoise(nf=_fmap_init),
                StyleAddNoise(nf=self.fmap),
            ]
            bias = (
                [Conv2dBias(nf=_fmap_init)],
                [Conv2dBias(nf=self.fmap)],
            )
        else:
            noise = [None, None]
            # NOTE: without noise, the bias would get directly added to the
            #       constant input, so the constant input can just learn this
            #       bias; theoretically there shouldn't be a need to include
            #       the bias either. There may be numerical approximation
            #       problems from backprop, however.
            bias = ([], [])

        # The same norm instances are shared by both layers below.
        norms = []
        self.use_pixelnorm = use_pixelnorm
        if use_pixelnorm:
            norms.append(NormalizeLayer('PixelNorm'))
        self.use_instancenorm = use_instancenorm
        if use_instancenorm:
            norms.append(NormalizeLayer('InstanceNorm'))

        # One style projection per layer, each producing 2 values per
        # feature map of that layer.
        w_to_styles = tuple(
            LinearEx(nin_feat=self.z_to_w.dims[-1],
                     nout_feat=2 * layer_width,
                     init='He',
                     init_type='StyleGAN',
                     gain_sq_base=1.,
                     equalized_lr=equalized_lr)
            for layer_width in (_fmap_init, self.fmap))

        trunc = truncation_trick_params
        assert 0. <= trunc['beta'] <= 1.
        self.w_ewma_beta = trunc['beta']
        # allow psi to be any number you want, perhaps worthy of
        # experimentation
        self._w_eval_psi = trunc['psi']
        cutoff_stage = trunc['cutoff_stage']
        assert (cutoff_stage is None
                or (isinstance(cutoff_stage, int)
                    and 0 < cutoff_stage <= int(np.log2(self.final_res)) - 2))
        self._trunc_cutoff_stage = cutoff_stage
        # set the below to `False` if you want to turn off during evaluation
        # mode
        self.use_truncation_trick = bool(self._trunc_cutoff_stage)
        self.w_ewma = None

        # First layer has no convolution (its first slot is None); the second
        # layer uses `conv`.
        for layer_idx, layer_conv in enumerate((None, conv)):
            self.gen_layers.append(
                nn.ModuleList([
                    layer_conv,
                    noise[layer_idx],
                    nn.Sequential(*bias[layer_idx], nl, *norms),
                    w_to_styles[layer_idx],
                ]))

        self.prev_torgb = None
        self._update_torgb(ni=self.fmap)