Example #3
    def forward(self,
                x,
                latmask,
                w,
                noise_mode='random',
                fused_modconv=True,
                gain=1):
        # def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
        assert noise_mode in ['random', 'const', 'none']
        in_resolution = self.resolution // self.up
        # misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
        styles = self.affine(w)

        noise = None
        if self.use_noise and noise_mode == 'random':
            # !!! custom
            sz = self.size if self.up == 2 and self.size is not None else x.shape[2:]
            noise = torch.randn([x.shape[0], 1, *sz],
                                device=x.device) * self.noise_strength
            # noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
        if self.use_noise and noise_mode == 'const':
            noise = self.noise_const * self.noise_strength
            # !!! custom noise size
            noise_size = self.size if self.up == 2 and self.size is not None and self.resolution > 4 else x.shape[2:]
            noise = fix_size(noise.unsqueeze(0).unsqueeze(0),
                             noise_size,
                             scale_type=self.scale_type)[0][0]

        # print(x.shape, noise.shape, self.size, self.up)

        flip_weight = (self.up == 1)  # slightly faster
        x = modulated_conv2d(
            x=x,
            weight=self.weight,
            styles=styles,
            noise=noise,
            up=self.up,
            latmask=latmask,
            countHW=self.countHW,
            splitfine=self.splitfine,
            size=self.size,
            scale_type=self.scale_type,  # !!! custom
            padding=self.padding,
            resample_filter=self.resample_filter,
            flip_weight=flip_weight,
            fused_modconv=fused_modconv)

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        x = bias_act.bias_act(x,
                              self.bias.to(x.dtype),
                              act=self.activation,
                              gain=act_gain,
                              clamp=act_clamp)
        return x
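
For orientation, the bias_act.bias_act() call at the end of this forward() reduces to a few lines of plain PyTorch on the reference (non-CUDA) path. A minimal sketch, assuming the 'lrelu' activation (default gain sqrt(2)) and ignoring the fused custom op:

import math
import torch
import torch.nn.functional as F

def bias_act_lrelu_sketch(x, bias, act_gain=math.sqrt(2), act_clamp=None):
    # Rough equivalent of bias_act.bias_act(x, b, act='lrelu', gain=act_gain, clamp=act_clamp)
    # on the reference path: add bias, apply leaky ReLU, scale by the gain, then optionally clamp.
    x = x + bias.reshape(1, -1, 1, 1).to(x.dtype)
    x = F.leaky_relu(x, negative_slope=0.2) * act_gain
    if act_clamp is not None:
        x = x.clamp(-act_clamp, act_clamp)
    return x

y = bias_act_lrelu_sketch(torch.randn(2, 8, 16, 16), torch.zeros(8), act_clamp=256.0)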
Example #4
    def forward(self,
                x,
                img,
                ws,
                latmask,
                dconst,
                force_fp32=False,
                fused_modconv=None,
                **layer_kwargs):
        # def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
        misc.assert_shape(ws,
                          [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            with misc.suppress_tracer_warnings():  # this value will be treated as a constant
                fused_modconv = (not self.training) and (
                    dtype == torch.float32 or int(x.shape[0]) == 1)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
            # !!! custom const size
            if 'side' in self.scale_type and 'symm' in self.scale_type:  # looks better
                const_size = self.init_res if self.size is None else self.size
                x = fix_size(x, const_size, self.scale_type)
            # distortion technique from Aydao
            x += dconst
        else:
            # misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            # !!! custom latmask
            x = self.conv1(x,
                           None,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           **layer_kwargs)
            # x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            # !!! custom latmask
            x = self.conv0(x,
                           latmask,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           **layer_kwargs)
            x = self.conv1(x,
                           None,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           gain=np.sqrt(0.5),
                           **layer_kwargs)
            # x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            # x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            # !!! custom latmask
            x = self.conv0(x,
                           latmask,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           **layer_kwargs)
            x = self.conv1(x,
                           None,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           **layer_kwargs)
            # x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            # x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB.
        if img is not None:
            # !!! custom img size
            # misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
            img = upfirdn2d.upsample2d(img, self.resample_filter)
            img = fix_size(img, self.size, scale_type=self.scale_type)

        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32,
                     memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img
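
The ToRGB/skip handling above keeps a running RGB image that gets 2x-upsampled, resized to the custom layer size, and summed with each block's ToRGB output. A rough sketch of that accumulation, with F.interpolate standing in for both upfirdn2d.upsample2d() and fix_size() (the real code uses an FIR filter and the repo's scale_type modes):

import torch
import torch.nn.functional as F

def skip_accumulate_sketch(img, rgb_new, size):
    # img: running RGB image from earlier blocks (or None); rgb_new: this block's ToRGB output.
    if img is not None:
        img = F.interpolate(img, scale_factor=2, mode='bilinear', align_corners=False)    # ~ upsample2d
        img = F.interpolate(img, size=tuple(size), mode='bilinear', align_corners=False)  # ~ fix_size
    rgb_new = rgb_new.to(torch.float32)
    return rgb_new if img is None else img + rgb_new

# Accumulate over two hypothetical blocks with non-square custom sizes.
img = None
img = skip_accumulate_sketch(img, torch.randn(1, 3, 4, 6), size=(4, 6))
img = skip_accumulate_sketch(img, torch.randn(1, 3, 9, 13), size=(9, 13))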
Example #5
def modulated_conv2d(
    x,  # Input tensor of shape [batch_size, in_channels, in_height, in_width].
    weight,  # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
    styles,  # Modulation coefficients of shape [batch_size, in_channels].
    # !!! custom
    latmask,  # mask for split-frame latents blending
    countHW=[1, 1],  # frame split count by height,width
    splitfine=0.,  # frame split edge fineness (float from 0+)
    size=None,  # custom size
    scale_type=None,  # scaling way: fit, centr, side, pad, padside
    noise=None,  # Optional noise tensor to add to the output activations.
    up=1,  # Integer upsampling factor.
    down=1,  # Integer downsampling factor.
    padding=0,  # Padding with respect to the upsampled image.
    resample_filter=None,  # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
    demodulate=True,  # Apply weight demodulation?
    flip_weight=True,  # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
    fused_modconv=True,  # Perform modulation, convolution, and demodulation as a single fused operation?
):
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw])  # [OIkk]
    misc.assert_shape(x, [batch_size, in_channels, None, None])  # [NIHW]
    misc.assert_shape(styles, [batch_size, in_channels])  # [NI]

    # Pre-normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(
            float('inf'), dim=[1, 2, 3], keepdim=True))  # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1,
                                      keepdim=True)  # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0)  # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1)  # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1)  # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(x=x,
                                            w=weight.to(x.dtype),
                                            f=resample_filter,
                                            up=up,
                                            down=down,
                                            padding=padding,
                                            flip_weight=flip_weight)
        # !!! custom size & multi latent blending
        if size is not None and up == 2:
            x = fix_size(x, size, scale_type)
            x = multimask(x, size, latmask, countHW, splitfine)
        if demodulate and noise is not None:
            x = fma.fma(x,
                        dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1),
                        noise.to(x.dtype))
        elif demodulate:
            x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Execute as one fused op using grouped convolution.
    with misc.suppress_tracer_warnings():  # this value will be treated as a constant
        batch_size = int(batch_size)
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(x=x,
                                        w=w.to(x.dtype),
                                        f=resample_filter,
                                        up=up,
                                        down=down,
                                        padding=padding,
                                        groups=batch_size,
                                        flip_weight=flip_weight)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    # !!! custom size & multi latent blending
    if size is not None and up == 2:
        x = fix_size(x, size, scale_type)
        x = multimask(x, size, latmask, countHW, splitfine)
    if noise is not None:
        x = x.add_(noise)
    return x
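
The fused branch above folds per-sample modulated weights into a single grouped convolution. A self-contained sketch (plain F.conv2d, no resampling, noise, or latmask blending) checking that this grouped-conv trick matches an explicit per-sample loop:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C_in, C_out, k, H = 4, 6, 5, 3, 16          # batch, in/out channels, kernel, spatial size
x = torch.randn(N, C_in, H, H)
weight = torch.randn(C_out, C_in, k, k)
styles = torch.randn(N, C_in)

# Modulate and demodulate per sample, exactly as in the fused branch above.
w = weight.unsqueeze(0) * styles.reshape(N, 1, -1, 1, 1)   # [N, C_out, C_in, k, k]
dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()    # [N, C_out]
w = w * dcoefs.reshape(N, -1, 1, 1, 1)

# Fused path: fold the batch into the channel axis and run one grouped convolution.
y_fused = F.conv2d(x.reshape(1, -1, H, H), w.reshape(-1, C_in, k, k),
                   padding=1, groups=N).reshape(N, C_out, H, H)

# Reference path: one ordinary convolution per sample with its own modulated weights.
y_loop = torch.stack([F.conv2d(x[i:i + 1], w[i], padding=1)[0] for i in range(N)])

print(torch.allclose(y_fused, y_loop, atol=1e-4))  # expected: True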
def G_synthesis_stylegan2(
    dlatents_in,  # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
    latmask,  # mask for split-frame latents blending
    dconst,  # initial (const) layer displacement
    latmask_res=[1, 1],  # resolution of external mask for blending
    countW=1,  # frame split count by width
    countH=1,  # frame split count by height
    splitfine=0.,  # frame split edge sharpness (float from 0)
    size=None,  # Output size
    scale_type=None,  # scaling way: fit, centr, side, pad, padside
    init_res=[4, 4],  # Initial (minimum) resolution for progressive training
    dlatent_size=512,  # Disentangled latent (W) dimensionality.
    num_channels=3,  # Number of output color channels.
    resolution=1024,  # Base model resolution (corresponding to the layer count)
    fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
    fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
    fmap_min=1,  # Minimum number of feature maps in any layer.
    fmap_max=512,  # Maximum number of feature maps in any layer.
    randomize_noise=True,  # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
    architecture='skip',  # Architecture: 'orig', 'skip', 'resnet'.
    nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
    dtype='float32',  # Data type to use for activations and outputs.
    resample_kernel=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations. None = no filtering.
    fused_modconv=True,  # Implement modulated_conv2d_layer() as a single fused op?
    verbose=False,  # Print calculated intermediate layer sizes?
    impl='cuda',  # Custom ops implementation - cuda (original) or ref (no compiling)
    **_kwargs):  # Ignore unrecognized keyword args.

    res_log2 = int(np.log2(resolution))
    assert resolution == 2**res_log2 and resolution >= 4

    # calculate intermediate layers sizes for arbitrary output resolution
    custom_res = (resolution * init_res[0] // 4, resolution * init_res[1] // 4)
    if size is None: size = custom_res
    if init_res != [4, 4] and verbose:
        print(' .. init res', init_res, size)
    keep_first_layers = 2 if scale_type == 'fit' else None
    hws = hw_scales(size, custom_res, res_log2 - 2, keep_first_layers, verbose)
    if verbose: print(hws, '..', custom_res, res_log2 - 1)

    # multi latent blending
    latmask.set_shape([None, *latmask_res])
    dconst.set_shape([None, dlatent_size, *init_res])
    splitfine = tf.cast(splitfine, tf.float32)

    def nf(stage):
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min,
                       fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity
    num_layers = res_log2 * 2 - 2
    images_out = None

    # Primary inputs.
    dlatents_in.set_shape([None, num_layers, dlatent_size])
    dlatents_in = tf.cast(dlatents_in, dtype)

    # Noise inputs.
    noise_inputs = []
    for layer_idx in range(num_layers - 1):
        res = (layer_idx + 5) // 2
        shape = [1, 1, 2**(res - 2) * init_res[0], 2**(res - 2) * init_res[1]]
        noise_inputs.append(
            tf.get_variable('noise%d' % layer_idx,
                            shape=shape,
                            initializer=tf.initializers.random_normal(),
                            trainable=False))

    # Single convolution layer with all the bells and whistles.
    def layer(x, layer_idx, size, fmaps, kernel, up=False):
        x = modulated_conv2d_layer(x,
                                   dlatents_in[:, layer_idx],
                                   fmaps=fmaps,
                                   kernel=kernel,
                                   up=up,
                                   resample_kernel=resample_kernel,
                                   fused_modconv=fused_modconv,
                                   impl=impl)
        if size is not None and up is True:
            x = fix_size(x, size, scale_type)
            # multi latent blending
            x = multimask(x, size, latmask, countH, countW, splitfine)
        if randomize_noise:
            noise = tf.random_normal(
                [tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
        else:
            noise = tf.cast(noise_inputs[layer_idx], x.dtype)
            noise = fix_size(noise, (x.shape[2], x.shape[3]),
                             scale_type=scale_type)
        noise_strength = tf.get_variable('noise_strength',
                                         shape=[],
                                         initializer=tf.initializers.zeros())
        x += noise * tf.cast(noise_strength, x.dtype)
        return apply_bias_act(x, act=act, impl=impl)

    # Building blocks for main layers.
    def block(x, res, size):  # res = 3..res_log2
        t = x
        with tf.variable_scope('Conv0_up'):
            x = layer(x,
                      layer_idx=res * 2 - 5,
                      size=size,
                      fmaps=nf(res - 1),
                      kernel=3,
                      up=True)
        with tf.variable_scope('Conv1'):
            x = layer(x,
                      layer_idx=res * 2 - 4,
                      size=size,
                      fmaps=nf(res - 1),
                      kernel=3)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t,
                                 fmaps=nf(res - 1),
                                 kernel=1,
                                 up=True,
                                 resample_kernel=resample_kernel,
                                 impl=impl)
                if size is not None:
                    t = fix_size(t, (x.shape[2], x.shape[3]),
                                 scale_type=scale_type)
                x = (x + t) * (1 / np.sqrt(2))
        return x

    def upsample(y):
        with tf.variable_scope('Upsample'):
            return upsample_2d(y, k=resample_kernel, impl=impl)

    def torgb(x, y, res):  # res = 2..res_log2
        with tf.variable_scope('ToRGB'):
            t = modulated_conv2d_layer(x,
                                       dlatents_in[:, res * 2 - 3],
                                       fmaps=num_channels,
                                       kernel=1,
                                       demodulate=False,
                                       fused_modconv=fused_modconv,
                                       impl=impl)
            t = apply_bias_act(t, impl=impl)
            return t if y is None else y + t

    # Early layers.
    y = None
    with tf.variable_scope('4x4'):
        with tf.variable_scope('Const'):
            x = tf.get_variable('const',
                                shape=[1, nf(1), *init_res],
                                initializer=tf.initializers.random_normal())
            x = tf.tile(tf.cast(x, dtype), [tf.shape(dlatents_in)[0], 1, 1, 1])
            # distortion technique from Aydao
            x += dconst
        with tf.variable_scope('Conv'):
            x = layer(x, layer_idx=0, size=None, fmaps=nf(1), kernel=3)
        if architecture == 'skip':
            y = torgb(x, y, 2)

    # Main layers.
    for res in range(3, res_log2 + 1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            x = block(x, res, hws[res - 2])
            if architecture == 'skip':
                y = upsample(y)
                if size is not None:
                    y = fix_size(y, hws[res - 2], scale_type=scale_type)
            if architecture == 'skip' or res == res_log2:
                y = torgb(x, y, res)
    images_out = y

    assert images_out.dtype == tf.as_dtype(dtype)
    return tf.identity(images_out, name='images_out')
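
With the defaults above (fmap_base=16<<10, fmap_decay=1.0, fmap_max=512), nf() gives the usual StyleGAN2 channel schedule. A quick sketch, reusing the same formula, to print the per-resolution feature-map counts for resolution=1024:

import numpy as np

def nf(stage, fmap_base=16 << 10, fmap_decay=1.0, fmap_min=1, fmap_max=512):
    # Same formula as nf() inside G_synthesis_stylegan2 above.
    return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min, fmap_max)

resolution = 1024
for res in range(2, int(np.log2(resolution)) + 1):
    print('%dx%d: %d fmaps' % (2**res, 2**res, nf(res - 1)))
# 4x4 .. 64x64: 512 fmaps (capped by fmap_max), then 128x128: 256, 256x256: 128, 512x512: 64, 1024x1024: 32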