예제 #1
0
    def get_conv_layer(self, ni, upsample=False, blur_op=None, append_nl=True):
        """Build one StyleGAN synthesis conv stage producing ``self.fmap`` channels.

        Args:
            ni: number of input channels of the 3x3 convolution.
            upsample: if True, ``self.upsampler`` is placed before the conv.
            blur_op: optional ``nn.Module`` placed right after the conv.
            append_nl: if True, ``self.nl`` follows the (separate) bias.

        Returns:
            ``nn.ModuleList`` with four entries, in order:
            ``nn.Sequential(upsampler?, conv, blur?)``, the noise module
            (``StyleAddNoise`` or ``None``), ``nn.Sequential(bias?, nl?, norms...)``
            and the ``LinearEx`` mapping w to a ``2 * self.fmap`` style vector.
        """
        pre_conv = [self.upsampler] if upsample else []

        # With noise injection or a blur after the conv, the conv is built
        # without a bias and an explicit Conv2dBias goes into the post-conv
        # Sequential instead.
        split_bias = self.use_noise or blur_op is not None

        conv = Conv2dEx(ni=ni,
                        nf=self.fmap,
                        ks=3,
                        stride=1,
                        padding=1,
                        init='He',
                        init_type='StyleGAN',
                        gain_sq_base=2.,
                        equalized_lr=self.equalized_lr,
                        include_bias=not split_bias)
        post_conv = [Conv2dBias(nf=self.fmap)] if split_bias else []

        blur = []
        if blur_op is not None:
            assert isinstance(blur_op, nn.Module)
            blur.append(blur_op)

        noise = StyleAddNoise(nf=self.fmap) if self.use_noise else None

        if append_nl:
            post_conv.append(self.nl)
        if self.use_pixelnorm:
            post_conv.append(NormalizeLayer('PixelNorm'))
        if self.use_instancenorm:
            post_conv.append(NormalizeLayer('InstanceNorm'))

        to_style = LinearEx(nin_feat=self.z_to_w.dims[-1],
                            nout_feat=2 * self.fmap,
                            init='He',
                            init_type='StyleGAN',
                            gain_sq_base=1.,
                            equalized_lr=self.equalized_lr)

        conv_stage = nn.Sequential(*pre_conv, conv, *blur)
        return nn.ModuleList(
            [conv_stage, noise, nn.Sequential(*post_conv), to_style])
예제 #2
0
    def __init__(self,
                 len_latent=512,
                 len_dlatent=512,
                 num_fcs=8,
                 lrmul=.01,
                 nl=nn.LeakyReLU(negative_slope=.2),
                 equalized_lr=True,
                 normalize_z=True):
        """Build the StyleGAN mapping network from z to w.

        Args:
            len_latent: size of the input latent z.
            len_dlatent: size of the output latent w.
            num_fcs: number of fully-connected layers.
            lrmul: learning-rate multiplier passed to every ``LinearEx``.
            nl: activation inserted after every FC layer; the same instance is
                reused for all layers.
            equalized_lr: passed to every ``LinearEx``.
            normalize_z: if True, PixelNorm is applied to the flattened z.
        """
        super(StyleMappingNetwork, self).__init__()

        flatten_z = Lambda(lambda x: x.view(-1, len_latent))
        self.preprocess_z = (nn.Sequential(flatten_z, NormalizeLayer('PixelNorm'))
                             if normalize_z else flatten_z)

        # num_fcs + 1 layer boundaries, linearly interpolated from len_latent to
        # len_dlatent and truncated to integers.
        self.dims = np.linspace(len_latent, len_dlatent,
                                num_fcs + 1).astype(np.int64)
        self.fc_mapping_model = nn.Sequential()
        for idx, (n_in, n_out) in enumerate(zip(self.dims[:-1], self.dims[1:])):
            fc = LinearEx(nin_feat=n_in,
                          nout_feat=n_out,
                          init='He',
                          init_type='StyleGAN',
                          gain_sq_base=2.,
                          equalized_lr=equalized_lr,
                          lrmul=lrmul)
            self.fc_mapping_model.add_module('fc_{}'.format(idx), fc)
            self.fc_mapping_model.add_module('nl_{}'.format(idx), nl)
예제 #3
0
    def __init__(self,
                 len_latent=128,
                 fmap=FMAP_G,
                 upsampler=nn.Upsample(scale_factor=2, mode='nearest'),
                 blur_type=None,
                 nl=nn.ReLU(),
                 num_classes=0,
                 equalized_lr=False):
        """Build a ResNet-style generator (base class initialised with 64).

        The pipeline stored in ``self.generator_model`` is:

        * flatten the input to ``len_latent + num_classes`` features;
        * a dense layer to ``_fmap_init_64 * RES_INIT**2``, reshaped into a
          ``RES_INIT x RES_INIT`` map;
        * four ``ResBlock2d`` stages, each given ``upsampler``, with widths
          ``8*fmap, 4*fmap, 2*fmap, fmap``;
        * BatchNorm, ``nl``, a 3x3 conv to ``FMAP_SAMPLES`` channels, ``tanh``.

        Args:
            len_latent: size of the latent vector.
            fmap: base channel width.
            upsampler: resampling module passed to every ``ResBlock2d``.
            blur_type: passed to every ``ResBlock2d``.
            nl: activation module, shared by all stages.
            num_classes: extra input features beyond ``len_latent``.
            equalized_lr: passed to every layer.
        """
        super(Generator64PixResnet, self).__init__(64)

        self.len_latent = len_latent
        self.num_classes = num_classes

        self.equalized_lr = equalized_lr

        n_in = len_latent + num_classes
        _fmap_init_64 = len_latent * FMAP_G_INIT_64_FCTR

        stages = [
            Lambda(lambda x: x.view(-1, n_in)),
            LinearEx(nin_feat=n_in,
                     nout_feat=_fmap_init_64 * RES_INIT**2,
                     init='Xavier',
                     equalized_lr=equalized_lr),
            Lambda(lambda x: x.view(-1, _fmap_init_64, RES_INIT, RES_INIT)),
        ]

        # Residual stages: consecutive pairs of widths give (ni, nf).
        widths = (_fmap_init_64, 8 * fmap, 4 * fmap, 2 * fmap, 1 * fmap)
        for c_in, c_out in zip(widths, widths[1:]):
            stages.append(
                ResBlock2d(ni=c_in,
                           nf=c_out,
                           ks=3,
                           norm_type='BatchNorm',
                           upsampler=upsampler,
                           init='He',
                           nl=nl,
                           equalized_lr=equalized_lr,
                           blur_type=blur_type))

        stages.extend([
            NormalizeLayer('BatchNorm', ni=1 * fmap),
            nl,
            Conv2dEx(ni=1 * fmap,
                     nf=FMAP_SAMPLES,
                     ks=3,
                     stride=1,
                     padding=1,
                     init='He',
                     equalized_lr=equalized_lr),
            nn.Tanh(),
        ])
        self.generator_model = nn.Sequential(*stages)
예제 #4
0
    def get_conv_layer(self, ni, upsample=False, blur_op=None, append_nl=True):
        """Build one ProGAN conv stage producing ``self.fmap`` channels.

        Args:
            ni: number of input channels of the 3x3 convolution.
            upsample: if True, ``self.upsampler`` is placed before the conv.
            blur_op: optional ``nn.Module`` placed right after the conv; when
                given, the conv is bias-free and a ``Conv2dBias`` follows the blur.
            append_nl: if True, ``self.nl`` is appended.

        Returns:
            ``nn.Sequential(upsampler?, conv, blur?, bias?, nl?, pixelnorm?)``.
        """
        head = [self.upsampler] if upsample else []

        has_blur = blur_op is not None
        conv = Conv2dEx(ni=ni,
                        nf=self.fmap,
                        ks=3,
                        stride=1,
                        padding=1,
                        init='He',
                        init_type='ProGAN',
                        gain_sq_base=2.,
                        equalized_lr=self.equalized_lr,
                        include_bias=not has_blur)

        tail = []
        if has_blur:
            assert isinstance(blur_op, nn.Module)
            # The bias is applied after the blur rather than inside the conv.
            tail.extend([blur_op, Conv2dBias(nf=self.fmap)])
        if append_nl:
            tail.append(self.nl)
        if self.use_pixelnorm:
            tail.append(NormalizeLayer('PixelNorm'))

        return nn.Sequential(*head, conv, *tail)
예제 #5
0
    def __init__(self,
                 num_classes,
                 len_latent=512,
                 len_dlatent=512,
                 num_fcs=8,
                 lrmul=.01,
                 nl=nn.LeakyReLU(negative_slope=.2),
                 equalized_lr=True,
                 normalize_z=True,
                 embed_cond_vars=True):
        """Build a class-conditioned StyleGAN mapping network.

        Args:
            num_classes: number of conditioning inputs.
            len_latent: size of the latent z, and of the class embedding.
            len_dlatent: size of the output latent w (``self.dims[-1]`` when
                ``num_fcs > 0``).
            num_fcs: number of fully-connected layers.
            lrmul: learning-rate multiplier passed to each FC ``LinearEx``.
            nl: activation inserted after every FC layer; one shared instance.
            equalized_lr: passed to each FC ``LinearEx``.
            normalize_z: if True, PixelNorm is the first module of
                ``self.fc_mapping_model``.
            embed_cond_vars: if True, a bias-free ``LinearEx`` from
                ``num_classes`` to ``len_latent`` is created and the first FC
                layer expects ``2 * len_latent`` inputs; otherwise it expects
                ``len_latent + num_classes`` inputs.
        """
        super(StyleConditionedMappingNetwork, self).__init__()

        self.len_latent = len_latent
        self.num_classes = num_classes

        self.embed_cond_vars = embed_cond_vars
        if embed_cond_vars:
            self.class_embedding = LinearEx(nin_feat=num_classes,
                                            nout_feat=len_latent,
                                            init=None,
                                            init_type='Standard Normal',
                                            include_bias=False)

        # Output widths of the FC layers, interpolated from len_latent to
        # len_dlatent and truncated to integers.
        self.dims = np.linspace(len_latent, len_dlatent,
                                num_fcs).astype(np.int64)
        if num_fcs > 0:
            # np.linspace(start, stop, num=1) returns only [start], which would
            # make a single-layer network output len_latent instead of
            # len_dlatent. For num >= 2 the last sample already equals stop, so
            # this assignment does not change anything in that case.
            self.dims[-1] = len_dlatent
        # Prepend the width of the first FC layer's input.
        first_in = 2 * len_latent if self.embed_cond_vars else len_latent + num_classes
        self.dims = np.insert(self.dims, 0, first_in)

        self.fc_mapping_model = nn.Sequential()
        if normalize_z:
            self.fc_mapping_model.add_module('pixelnorm',
                                             NormalizeLayer('PixelNorm'))
        for seq_n in range(num_fcs):
            self.fc_mapping_model.add_module(
                'fc_' + str(seq_n),
                LinearEx(nin_feat=self.dims[seq_n],
                         nout_feat=self.dims[seq_n + 1],
                         init='He',
                         init_type='StyleGAN',
                         gain_sq_base=2.,
                         equalized_lr=equalized_lr,
                         lrmul=lrmul))
            self.fc_mapping_model.add_module('nl_' + str(seq_n), nl)
예제 #6
0
    def __init__(self,
                 final_res,
                 len_latent=512,
                 upsampler=nn.Upsample(scale_factor=2, mode='nearest'),
                 blur_type=None,
                 nl=nn.LeakyReLU(negative_slope=.2),
                 num_classes=0,
                 equalized_lr=True,
                 normalize_z=True,
                 use_pixelnorm=True):
        """Build the initial (lowest-resolution) block of a ProGAN generator.

        ``self.gen_blocks[0]`` is: dense layer -> reshape to
        ``(_fmap_init, RES_INIT, RES_INIT)`` -> ``nl`` -> PixelNorm? -> 3x3 conv to
        ``self.fmap`` channels -> ``nl`` -> PixelNorm?. A toRGB layer for
        ``self.fmap`` channels is then created through ``self._update_torgb``.

        Args:
            final_res: passed to the base-class constructor.
            len_latent: size of the latent vector.
            upsampler: stored as ``self.upsampler`` (not used in this block).
            blur_type: stored as ``self.gen_blur_type`` (not used in this block).
            nl: activation module used after the dense and conv layers.
            num_classes: extra input features; the flattened input is expected
                to have ``len_latent + num_classes`` features.
            equalized_lr: passed to every layer built here.
            normalize_z: if True, PixelNorm is applied in ``self.preprocess_z``.
            use_pixelnorm: if True, PixelNorm follows each activation.
        """
        # Zero-argument super() binds to the class this method is defined in.
        # `super(self.__class__, self)` binds to the *runtime* class instead and
        # recurses forever as soon as this class is subclassed.
        super().__init__(final_res)

        self.gen_blocks = nn.ModuleList()

        self.upsampler = upsampler
        # keep fading-in layers simple
        self.upsampler_skip_connection = \
            lambda xb: F.interpolate(xb, scale_factor=2, mode='nearest')

        self.gen_blur_type = blur_type

        self.nl = nl

        self.len_latent = len_latent
        self.num_classes = num_classes

        self.equalized_lr = equalized_lr

        # NOTE: the same NormalizeLayer instance is inserted twice into the
        # Sequential below (presumably parameter-free -- confirm).
        norms = []
        self.use_pixelnorm = use_pixelnorm
        if use_pixelnorm:
            norms.append(NormalizeLayer('PixelNorm'))

        if normalize_z:
            self.preprocess_z = nn.Sequential(
                Lambda(lambda x: x.view(-1, len_latent + num_classes)),
                NormalizeLayer('PixelNorm'))
        else:
            self.preprocess_z = Lambda(
                lambda x: x.view(-1, len_latent + num_classes))

        _fmap_init = len_latent * FMAP_G_INIT_FCTR
        self.gen_blocks.append(
            nn.Sequential(
                LinearEx(
                    nin_feat=len_latent + num_classes,
                    nout_feat=_fmap_init * RES_INIT**2,
                    init='He',
                    init_type='ProGAN',
                    gain_sq_base=2. / 16,
                    equalized_lr=equalized_lr
                ),  # this can be done with a transpose conv layer as well (efficiency)
                Lambda(lambda x: x.view(-1, _fmap_init, RES_INIT, RES_INIT)),
                nl,
                *norms,
                Conv2dEx(ni=_fmap_init,
                         nf=self.fmap,
                         ks=3,
                         stride=1,
                         padding=1,
                         init='He',
                         init_type='ProGAN',
                         gain_sq_base=2.,
                         equalized_lr=equalized_lr),
                nl,
                *norms))

        self.prev_torgb = None
        self._update_torgb(ni=self.fmap)
예제 #7
0
    def __init__(self,
                 final_res,
                 latent_distribution='normal',
                 len_latent=512,
                 len_dlatent=512,
                 mapping_num_fcs=8,
                 mapping_lrmul=.01,
                 use_instancenorm=True,
                 use_noise=True,
                 upsampler=nn.Upsample(scale_factor=2, mode='nearest'),
                 blur_type=None,
                 nl=nn.LeakyReLU(negative_slope=.2),
                 num_classes=0,
                 equalized_lr=True,
                 normalize_z=True,
                 use_pixelnorm=False,
                 pct_mixing_reg=.9,
                 truncation_trick_params={
                     'beta': .995,
                     'psi': .7,
                     'cutoff_stage': 4
                 }):
        """Build the mapping network and initial block of a StyleGAN generator.

        Creates ``self.z_to_w``, a learned constant input ``self.const_input``,
        and two entries in ``self.gen_layers``. Each entry is an
        ``nn.ModuleList([conv or None, noise or None, nn.Sequential(bias?, nl,
        norms...), w-to-style LinearEx])``. It also creates a toRGB layer via
        ``self._update_torgb``.

        Args:
            final_res: passed to the base-class constructor.
            latent_distribution: stored as ``self.latent_distribution``.
            len_latent: size of the latent z.
            len_dlatent: size of the latent w.
            mapping_num_fcs: number of FC layers in the mapping network.
            mapping_lrmul: learning-rate multiplier of the mapping network.
            use_instancenorm: if True, InstanceNorm follows each activation.
            use_noise: if True, ``StyleAddNoise`` modules are created and conv
                biases are split into separate ``Conv2dBias`` layers.
            upsampler: stored as ``self.upsampler``.
            blur_type: stored as ``self.gen_blur_type``.
            nl: activation module, shared by all layers.
            num_classes: 0 for an unconditioned mapping network, otherwise
                the number of classes for ``StyleConditionedMappingNetwork``.
                Must be an ``int``.
            equalized_lr: passed to every layer built here.
            normalize_z: passed to the mapping network.
            use_pixelnorm: if True, PixelNorm follows each activation.
            pct_mixing_reg: stored as ``self.pct_mixing_reg``; mixing
                regularization is enabled iff it is truthy.
            truncation_trick_params: dict with keys ``'beta'`` (in [0, 1]),
                ``'psi'`` and ``'cutoff_stage'`` (None or an int in
                ``(0, log2(final_res) - 2]``). Only read here, so the shared
                default dict is never mutated.
        """
        # Zero-argument super() binds to the class this method is defined in.
        # `super(self.__class__, self)` binds to the *runtime* class instead and
        # recurses forever as soon as this class is subclassed.
        super().__init__(final_res)

        self.gen_layers = nn.ModuleList()

        self.upsampler = upsampler
        # keep fading-in layers simple
        self.upsampler_skip_connection = \
            lambda xb: F.interpolate(xb, scale_factor=2, mode='nearest')

        self.gen_blur_type = blur_type

        self.nl = nl

        self.equalized_lr = equalized_lr

        self.pct_mixing_reg = pct_mixing_reg
        self._use_mixing_reg = bool(pct_mixing_reg)

        self.latent_distribution = latent_distribution
        self.len_latent = len_latent
        self.len_dlatent = len_dlatent
        assert isinstance(num_classes, int)
        self.num_classes = num_classes

        # Mapping Network initialization:
        if not num_classes:
            self.z_to_w = StyleMappingNetwork(len_latent=len_latent,
                                              len_dlatent=len_dlatent,
                                              num_fcs=mapping_num_fcs,
                                              lrmul=mapping_lrmul,
                                              nl=nl,
                                              equalized_lr=equalized_lr,
                                              normalize_z=normalize_z)
        else:
            self.z_to_w = StyleConditionedMappingNetwork(
                num_classes,
                len_latent=len_latent,
                len_dlatent=len_dlatent,
                num_fcs=mapping_num_fcs,
                lrmul=mapping_lrmul,
                nl=nl,
                equalized_lr=equalized_lr,
                normalize_z=normalize_z)

        _fmap_init = len_latent * FMAP_G_INIT_FCTR

        # initializing the input to 1 has about the same effect as applying PixelNorm to the input
        self.const_input = nn.Parameter(
            torch.FloatTensor(1, _fmap_init, RES_INIT, RES_INIT).fill_(1))

        self._use_noise = use_noise
        self._trained_with_noise = use_noise
        if use_noise:
            # The conv is bias-free and explicit Conv2dBias layers are placed in
            # the post-noise Sequentials, so the bias comes after the noise
            # (assuming forward uses the ModuleList entries in order).
            conv = Conv2dEx(ni=_fmap_init,
                            nf=self.fmap,
                            ks=3,
                            stride=1,
                            padding=1,
                            init='He',
                            init_type='StyleGAN',
                            gain_sq_base=2.,
                            equalized_lr=equalized_lr,
                            include_bias=False)
            noise = [
                StyleAddNoise(nf=_fmap_init),
                StyleAddNoise(nf=self.fmap),
            ]
            bias = (
                [Conv2dBias(nf=_fmap_init)],
                [Conv2dBias(nf=self.fmap)],
            )
        else:
            conv = Conv2dEx(ni=_fmap_init,
                            nf=self.fmap,
                            ks=3,
                            stride=1,
                            padding=1,
                            init='He',
                            init_type='StyleGAN',
                            gain_sq_base=2.,
                            equalized_lr=equalized_lr,
                            include_bias=True)
            noise = [None, None]
            # NOTE: without noise, the bias would get directly added to the constant input, so the constant input can just learn this bias,
            #       so theoretically, there shouldn't be a need to include the bias either. There may be numerical approximation problems from backprop, however.
            bias = (
                [],
                [],
            )

        norms = []
        self.use_pixelnorm = use_pixelnorm
        if use_pixelnorm:
            norms.append(NormalizeLayer('PixelNorm'))
        self.use_instancenorm = use_instancenorm
        if use_instancenorm:
            norms.append(NormalizeLayer('InstanceNorm'))

        w_to_styles = (
            LinearEx(nin_feat=self.z_to_w.dims[-1],
                     nout_feat=2 * _fmap_init,
                     init='He',
                     init_type='StyleGAN',
                     gain_sq_base=1.,
                     equalized_lr=equalized_lr),
            LinearEx(nin_feat=self.z_to_w.dims[-1],
                     nout_feat=2 * self.fmap,
                     init='He',
                     init_type='StyleGAN',
                     gain_sq_base=1.,
                     equalized_lr=equalized_lr),
        )
        assert 0. <= truncation_trick_params['beta'] <= 1.
        self.w_ewma_beta = truncation_trick_params['beta']
        # allow psi to be any number you want, perhaps worthy of experimentation
        self._w_eval_psi = truncation_trick_params['psi']
        cutoff_stage = truncation_trick_params['cutoff_stage']
        assert (cutoff_stage is None or
                (isinstance(cutoff_stage, int) and
                 0 < cutoff_stage <= int(np.log2(self.final_res)) - 2))
        self._trunc_cutoff_stage = cutoff_stage
        # set the below to `False` if you want to turn off during evaluation mode
        self.use_truncation_trick = bool(self._trunc_cutoff_stage)
        self.w_ewma = None

        self.gen_layers.append(
            nn.ModuleList([
                None, noise[0],
                nn.Sequential(*bias[0], nl, *norms), w_to_styles[0]
            ]))
        self.gen_layers.append(
            nn.ModuleList([
                conv, noise[1],
                nn.Sequential(*bias[1], nl, *norms), w_to_styles[1]
            ]))

        self.prev_torgb = None
        self._update_torgb(ni=self.fmap)
예제 #8
0
    def __init__(self,
                 ni,
                 nf,
                 ks,
                 norm_type,
                 upsampler=None,
                 pooler=None,
                 init='He',
                 nl=nn.ReLU(),
                 res=None,
                 flip_sampling=False,
                 equalized_lr=False,
                 blur_type=None):
        """Set up the layers of a 2D residual block.

        Main path: ``conv_layer_1`` and ``conv_layer_2``, each built as
        ``NormalizeLayer -> nl -> conv`` with optional resampling/blur.
        Shortcut: ``skip_connection``, which is either the identity or a 1x1
        ``Conv2dEx`` projection with the same resampling.

        Args:
            ni: number of input channels.
            nf: number of output channels.
            ks: kernel size of the two main-path convolutions.
            norm_type: passed to both ``NormalizeLayer`` instances.
            upsampler: optional module placed before the first conv and at the
                start of the shortcut. Mutually exclusive with ``pooler``.
            pooler: optional module placed after the second conv and in the
                shortcut. Mutually exclusive with ``upsampler``.
            init: init scheme of the two main-path convs (the shortcut conv
                always uses 'Xavier').
            nl: activation module, shared by both main-path layers.
            res: passed to ``NormalizeLayer`` (presumably a spatial
                resolution -- TODO confirm).
            flip_sampling: alternative rule for ``self.nif``; see note below.
            equalized_lr: passed to every ``Conv2dEx``.
            blur_type: if not None, a ``get_blur_op`` module is inserted next to
                the resampling op in the main path and the shortcut (unused when
                neither ``upsampler`` nor ``pooler`` is given).
        """
        # Upsampling and pooling cannot be combined in one block.
        assert not (upsampler is not None and pooler is not None)

        padding = (ks - 1) // 2  # 'SAME' padding for stride 1 conv

        # nif = width between the two main-path convs, i.e. it decides which
        # conv performs the ni -> nf change. Without flip_sampling: the first
        # conv when upsampling, otherwise the second conv.
        # NOTE(review): with flip_sampling, both expressions yield the same nif
        # whenever an upsampler or a pooler is given; the flag only changes the
        # no-resampling case (nif becomes nf). Confirm this is intended.
        if not flip_sampling:
            self.nif = nf if (upsampler is not None and pooler is None) else ni
        else:
            self.nif = ni if (upsampler is None and pooler is not None) else nf
        # convs[0]: ni -> nif and convs[1]: nif -> nf (main path, ks x ks);
        # convs[2]: ni -> nf 1x1 projection used by the shortcut.
        # NOTE: a plain tuple (not nn.ModuleList), so these convs are registered
        # as submodules only through the nn.Sequential containers built below;
        # convs[2] stays unregistered when the shortcut is the identity.
        self.convs = (
            Conv2dEx(ni,
                     self.nif,
                     ks=ks,
                     stride=1,
                     padding=padding,
                     init=init,
                     equalized_lr=equalized_lr),
            Conv2dEx(self.nif,
                     nf,
                     ks=ks,
                     stride=1,
                     padding=padding,
                     init=init,
                     equalized_lr=equalized_lr),
            Conv2dEx(ni,
                     nf,
                     ks=1,
                     stride=1,
                     padding=0,
                     init='Xavier',
                     equalized_lr=equalized_lr),  # this is same as a FC layer
        )

        # One blur module sized for convs[0]'s output width (nif); the same
        # instance is reused in the main path and in the shortcut below. The
        # widths line up because nif == nf when upsampling and nif == ni when
        # pooling.
        blur_op = get_blur_op(
            blur_type=blur_type,
            num_channels=self.convs[0].nf) if blur_type is not None else None

        # Normalization + activation preceding each main-path conv
        # (input widths ni and nif respectively).
        _norm_nls = (
            [NormalizeLayer(norm_type, ni=ni, res=res), nl],
            [NormalizeLayer(norm_type, ni=self.convs[0].nf, res=res), nl],
        )

        # _ops = (main-path layer 1 ops, main-path layer 2 ops, shortcut ops).
        if upsampler is not None:
            # Upsample before the first conv (blur after it); the shortcut is
            # upsample -> 1x1 conv (-> blur).
            _mostly_linear_op_1 = [
                upsampler, self.convs[0], blur_op
            ] if blur_type is not None else [upsampler, self.convs[0]]
            _mostly_linear_op_2 = [
                upsampler, self.convs[2], blur_op
            ] if blur_type is not None else [upsampler, self.convs[2]]
            _ops = (
                _mostly_linear_op_1,
                [self.convs[1]],
                _mostly_linear_op_2,
            )
        elif pooler is not None:
            # Pool after the second conv (blur before it); the shortcut is
            # (blur ->) pool -> 1x1 conv.
            _mostly_linear_op_1 = [
                blur_op, self.convs[1], pooler
            ] if blur_type is not None else [self.convs[1], pooler]
            _mostly_linear_op_2 = [
                blur_op, pooler, self.convs[2]
            ] if blur_type is not None else [pooler, self.convs[2]]
            _ops = (
                [self.convs[0]],
                _mostly_linear_op_1,
                _mostly_linear_op_2,
            )
        else:
            # No resampling: plain convs (blur_op, if any, is not used).
            _ops = (
                [self.convs[0]],
                [self.convs[1]],
                [self.convs[2]],
            )

        self.conv_layer_1 = nn.Sequential(*(_norm_nls[0] + _ops[0]))
        self.conv_layer_2 = nn.Sequential(*(_norm_nls[1] + _ops[1]))

        # Use the projection shortcut whenever the spatial size or the channel
        # count changes; otherwise pass the input through unchanged.
        if (upsampler is not None or pooler is not None) or ni != nf:
            self.skip_connection = nn.Sequential(*(_ops[2]))
        else:
            self.skip_connection = Lambda(lambda x: x)