Example #1
    def __init__(
            self,
            w_dim,  # Intermediate latent (W) dimensionality.
            img_resolution,  # Output image resolution.
            img_channels,  # Number of color channels.
            # !!! custom
            init_res=[4, 4],  # Initial (minimal) resolution for progressive training
            size=None,  # Output size
            scale_type=None,  # scaling way: fit, centr, side, pad, padside
            channel_base=32768,  # Overall multiplier for the number of channels.
            channel_max=512,  # Maximum number of channels in any layer.
            num_fp16_res=0,  # Use FP16 for the N highest resolutions.
            verbose=False,  # Print intermediate layer sizes while building.
            **block_kwargs,  # Arguments for SynthesisBlock.
    ):
        assert img_resolution >= 4 and img_resolution & (img_resolution -
                                                         1) == 0
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.res_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.fmap_base = channel_base
        self.block_resolutions = [2**i for i in range(2, self.res_log2 + 1)]
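        # Channels per block resolution: channel_base // res, capped at channel_max.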
        channels_dict = {
            res: min(channel_base // res, channel_max)
            for res in self.block_resolutions
        }
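        # Lowest block resolution that runs in FP16 (the num_fp16_res highest resolutions, never below 8).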
        fp16_resolution = max(2**(self.res_log2 + 1 - num_fp16_res), 8)

        # calculate intermediate layers sizes for arbitrary output resolution
        custom_res = (img_resolution * init_res[0] // 4,
                      img_resolution * init_res[1] // 4)
        if size is None: size = custom_res
        if init_res != [4, 4] and verbose:
            print(' .. init res', init_res, size)
        keep_first_layers = 2 if scale_type == 'fit' else None
        hws = hw_scales(size, custom_res, self.res_log2 - 2, keep_first_layers,
                        verbose)
        if verbose: print(hws, '..', custom_res, self.res_log2 - 1)

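        # Build one SynthesisBlock per resolution and count the W vectors they consume (num_ws).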
        self.num_ws = 0
        for i, res in enumerate(self.block_resolutions):
            in_channels = channels_dict[res // 2] if res > 4 else 0
            out_channels = channels_dict[res]
            use_fp16 = (res >= fp16_resolution)
            is_last = (res == self.img_resolution)
            block = SynthesisBlock(
                in_channels,
                out_channels,
                w_dim=w_dim,
                resolution=res,
                init_res=init_res,
                scale_type=scale_type,
                size=hws[i],  # !!! custom
                img_channels=img_channels,
                is_last=is_last,
                use_fp16=use_fp16,
                **block_kwargs)
            self.num_ws += block.num_conv
            if is_last:
                self.num_ws += block.num_torgb
            setattr(self, f'b{res}', block)
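
This constructor appears to come from a SynthesisNetwork-style torch.nn.Module (as in the StyleGAN2-ADA codebase), extended with custom init_res / size / scale_type arguments for non-square, arbitrarily sized output. A minimal usage sketch under that assumption; the class name and argument values below are illustrative, not taken from the excerpt:
# Hypothetical usage sketch -- assumes this __init__ sits inside a
# SynthesisNetwork(torch.nn.Module) class, with SynthesisBlock, hw_scales
# and numpy (np) available as in StyleGAN2-ADA.
synthesis = SynthesisNetwork(
    w_dim=512,               # intermediate latent (W) dimensionality
    img_resolution=1024,     # must be a power of two >= 4
    img_channels=3,
    init_res=[4, 8],         # non-square base grid
    size=(1024, 2048),       # arbitrary target output size
    scale_type='fit',
    verbose=True,
)
print(synthesis.num_ws)      # number of W vectors the blocks will consume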

Example #2
def G_synthesis_stylegan2(
    dlatents_in,  # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
    latmask,  # mask for split-frame latents blending
    dconst,  # initial (const) layer displacement
    latmask_res=[1, 1],  # resolution of external mask for blending
    countW=1,  # frame split count by width
    countH=1,  # frame split count by height
    splitfine=0.,  # frame split edge sharpness (float from 0)
    size=None,  # Output size
    scale_type=None,  # scaling way: fit, centr, side, pad, padside
    init_res=[4, 4],  # Initial (minimum) resolution for progressive training
    dlatent_size=512,  # Disentangled latent (W) dimensionality.
    num_channels=3,  # Number of output color channels.
    resolution=1024,  # Base model resolution (corresponding to the layer count)
    fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
    fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
    fmap_min=1,  # Minimum number of feature maps in any layer.
    fmap_max=512,  # Maximum number of feature maps in any layer.
    randomize_noise=True,  # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
    architecture='skip',  # Architecture: 'orig', 'skip', 'resnet'.
    nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
    dtype='float32',  # Data type to use for activations and outputs.
    resample_kernel=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations. None = no filtering.
    fused_modconv=True,  # Implement modulated_conv2d_layer() as a single fused op?
    verbose=False,  # Print intermediate layer sizes while building the graph.
    impl='cuda',  # Custom ops implementation - cuda (original) or ref (no compiling)
    **_kwargs):  # Ignore unrecognized keyword args.

    res_log2 = int(np.log2(resolution))
    assert resolution == 2**res_log2 and resolution >= 4

    # calculate intermediate layers sizes for arbitrary output resolution
    custom_res = (resolution * init_res[0] // 4, resolution * init_res[1] // 4)
    if size is None: size = custom_res
    if init_res != [4, 4] and verbose:
        print(' .. init res', init_res, size)
    keep_first_layers = 2 if scale_type == 'fit' else None
    hws = hw_scales(size, custom_res, res_log2 - 2, keep_first_layers, verbose)
    if verbose: print(hws, '..', custom_res, res_log2 - 1)

    # multi latent blending
    latmask.set_shape([None, *latmask_res])
    dconst.set_shape([None, dlatent_size, *init_res])
    splitfine = tf.cast(splitfine, tf.float32)

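    # Feature maps for a given stage: fmap_base / 2**(stage * fmap_decay), clamped to [fmap_min, fmap_max].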
    def nf(stage):
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min,
                       fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity
    num_layers = res_log2 * 2 - 2
    images_out = None

    # Primary inputs.
    dlatents_in.set_shape([None, num_layers, dlatent_size])
    dlatents_in = tf.cast(dlatents_in, dtype)

    # Noise inputs.
    noise_inputs = []
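    # Fixed per-layer noise buffers, sized to each layer's init_res-scaled resolution (used when randomize_noise is False).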
    for layer_idx in range(num_layers - 1):
        res = (layer_idx + 5) // 2
        shape = [1, 1, 2**(res - 2) * init_res[0], 2**(res - 2) * init_res[1]]
        noise_inputs.append(
            tf.get_variable('noise%d' % layer_idx,
                            shape=shape,
                            initializer=tf.initializers.random_normal(),
                            trainable=False))

    # Single convolution layer with all the bells and whistles.
    def layer(x, layer_idx, size, fmaps, kernel, up=False):
        x = modulated_conv2d_layer(x,
                                   dlatents_in[:, layer_idx],
                                   fmaps=fmaps,
                                   kernel=kernel,
                                   up=up,
                                   resample_kernel=resample_kernel,
                                   fused_modconv=fused_modconv,
                                   impl=impl)
        if size is not None and up is True:
            x = fix_size(x, size, scale_type)
            # multi latent blending
            x = multimask(x, size, latmask, countH, countW, splitfine)
        if randomize_noise:
            noise = tf.random_normal(
                [tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
        else:
            noise = tf.cast(noise_inputs[layer_idx], x.dtype)
            noise = fix_size(noise, (x.shape[2], x.shape[3]),
                             scale_type=scale_type)
        noise_strength = tf.get_variable('noise_strength',
                                         shape=[],
                                         initializer=tf.initializers.zeros())
        x += noise * tf.cast(noise_strength, x.dtype)
        return apply_bias_act(x, act=act, impl=impl)

    # Building blocks for main layers.
    def block(x, res, size):  # res = 3..res_log2
        t = x
        with tf.variable_scope('Conv0_up'):
            x = layer(x,
                      layer_idx=res * 2 - 5,
                      size=size,
                      fmaps=nf(res - 1),
                      kernel=3,
                      up=True)
        with tf.variable_scope('Conv1'):
            x = layer(x,
                      layer_idx=res * 2 - 4,
                      size=size,
                      fmaps=nf(res - 1),
                      kernel=3)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t,
                                 fmaps=nf(res - 1),
                                 kernel=1,
                                 up=True,
                                 resample_kernel=resample_kernel,
                                 impl=impl)
                if size is not None:
                    t = fix_size(t, (x.shape[2], x.shape[3]),
                                 scale_type=scale_type)
                x = (x + t) * (1 / np.sqrt(2))
        return x

    def upsample(y):
        with tf.variable_scope('Upsample'):
            return upsample_2d(y, k=resample_kernel, impl=impl)

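    # 1x1 modulated conv to RGB (no demodulation); in 'skip' mode the result is added to the upsampled running image y.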
    def torgb(x, y, res):  # res = 2..res_log2
        with tf.variable_scope('ToRGB'):
            t = modulated_conv2d_layer(x,
                                       dlatents_in[:, res * 2 - 3],
                                       fmaps=num_channels,
                                       kernel=1,
                                       demodulate=False,
                                       fused_modconv=fused_modconv,
                                       impl=impl)
            t = apply_bias_act(t, impl=impl)
            return t if y is None else y + t

    # Early layers.
    y = None
    with tf.variable_scope('4x4'):
        with tf.variable_scope('Const'):
            x = tf.get_variable('const',
                                shape=[1, nf(1), *init_res],
                                initializer=tf.initializers.random_normal())
            x = tf.tile(tf.cast(x, dtype), [tf.shape(dlatents_in)[0], 1, 1, 1])
            # distortion technique from Aydao
            x += dconst
        with tf.variable_scope('Conv'):
            x = layer(x, layer_idx=0, size=None, fmaps=nf(1), kernel=3)
        if architecture == 'skip':
            y = torgb(x, y, 2)

    # Main layers.
    for res in range(3, res_log2 + 1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            x = block(x, res, hws[res - 2])
            if architecture == 'skip':
                y = upsample(y)
                if size is not None:
                    y = fix_size(y, hws[res - 2], scale_type=scale_type)
            if architecture == 'skip' or res == res_log2:
                y = torgb(x, y, res)
    images_out = y

    assert images_out.dtype == tf.as_dtype(dtype)
    return tf.identity(images_out, name='images_out')
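
G_synthesis_stylegan2 above is the TensorFlow 1.x graph-builder variant of the same customization. A rough call sketch, assuming TF1 graph mode and the original StyleGAN2 helper ops (modulated_conv2d_layer, conv2d_layer, upsample_2d, apply_bias_act, hw_scales, fix_size, multimask) are in scope; the placeholder shapes follow the defaults declared above and the scope name is illustrative:
# Hypothetical call sketch -- assumes the TF1 StyleGAN2 environment.
import numpy as np
import tensorflow as tf  # TF 1.x graph-mode API assumed

num_layers = int(np.log2(1024)) * 2 - 2                 # 18 layers at 1024 px
dlatents_in = tf.placeholder(tf.float32, [None, num_layers, 512])
latmask = tf.placeholder(tf.float32, [None, 1, 1])      # split-frame blending mask
dconst = tf.placeholder(tf.float32, [None, 512, 4, 4])  # const-layer displacement

with tf.variable_scope('G_synthesis', reuse=tf.AUTO_REUSE):
    images_out = G_synthesis_stylegan2(
        dlatents_in, latmask, dconst,
        size=(720, 1280),        # arbitrary output size
        scale_type='fit',
        resolution=1024,
        randomize_noise=False,
        impl='ref')              # reference ops, no CUDA compiling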