def __init__(self,
        in_channels,                    # Channels of the incoming feature map.
        out_channels,                   # Channels produced by the convolution.
        w_dim,                          # Width of the intermediate latent vector w.
        resolution,                     # Spatial resolution handled by this layer.
        kernel_size     = 3,            # Side length of the square convolution kernel.
        up              = 1,            # Integer upsampling factor.
        use_noise       = True,         # Create the constant noise buffer and its learned strength?
        activation      = 'lrelu',      # Activation name: 'relu', 'lrelu', etc.
        resample_filter = [1,3,3,1],    # Low-pass filter taps used when resampling activations.
        conv_clamp      = None,         # Clamp conv output to +-X; None disables clamping.
        channels_last   = False,        # Create the weight in channels_last memory format?
    ):
        """Set up a style-modulated synthesis layer.

        Stores the configuration, registers the ``resample_filter`` buffer and
        creates: an affine layer mapping w (``w_dim``) to ``in_channels``
        values (presumably the per-input-channel styles), the convolution
        weight of shape [out_channels, in_channels, k, k], an optional
        [resolution, resolution] constant noise buffer with a scalar learned
        strength, and a zero-initialised bias of shape [out_channels].
        """
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.up = up
        self.use_noise = use_noise
        self.activation = activation
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2  # half-kernel padding
        act_spec = bias_act.activation_funcs[activation]
        self.act_gain = act_spec.def_gain

        # Creation order below determines how the global RNG is consumed; keep it.
        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        layout = torch.contiguous_format
        if channels_last:
            layout = torch.channels_last
        weight_shape = [out_channels, in_channels, kernel_size, kernel_size]
        self.weight = torch.nn.Parameter(torch.randn(weight_shape).to(memory_format=layout))
        if use_noise:
            self.register_buffer('noise_const', torch.randn([resolution, resolution]))
            self.noise_strength = torch.nn.Parameter(torch.zeros([]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
    def __init__(self, in_channels, out_channels, kernel_size, bias=True, activation='linear', up=1, down=1, resample_filter=[1, 3, 3, 1], conv_clamp=None, channels_last=False, trainable=True):
        """Set up a 2D convolution layer with optional bias, activation and resampling.

        Args:
            in_channels:     Number of input channels.
            out_channels:    Number of output channels.
            kernel_size:     Width and height of the square convolution kernel.
            bias:            Create a zero-initialised bias of shape [out_channels]?
            activation:      Activation name, used as a key into
                             ``bias_act.activation_funcs`` ('linear', 'relu', 'lrelu', ...).
            up:              Integer upsampling factor (stored on the module).
            down:            Integer downsampling factor (stored on the module).
            resample_filter: Low-pass filter taps, converted with
                             ``upfirdn2d.setup_filter`` and registered as the
                             ``resample_filter`` buffer.
            conv_clamp:      Clamp value stored on the module; None = no clamping.
            channels_last:   Create the weight in ``torch.channels_last`` memory format?
            trainable:       If True, weight/bias become ``torch.nn.Parameter``s;
                             otherwise they are registered as buffers. ``self.bias``
                             is ``None`` when ``bias`` is False.
        """
        super().__init__()

        self.activation = activation
        self.up = up
        self.down = down
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        # 1/sqrt(fan_in) with fan_in = in_channels * kernel_size**2; presumably
        # multiplied into the N(0, 1) weights at runtime.
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
        # An unknown activation name fails at this lookup. This is why the default
        # above is the lowercase 'linear' (the previous default 'Linear' did not
        # match the lowercase names such as 'relu'/'lrelu').
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
        bias = torch.zeros([out_channels]) if bias else None

        if trainable:
            self.weight = torch.nn.Parameter(weight)
            self.bias = torch.nn.Parameter(bias) if bias is not None else None
        else:
            # Frozen layer: buffers are not returned by Module.parameters().
            self.register_buffer('weight', weight)

            if bias is not None:
                self.register_buffer('bias', bias)
            else:
                self.bias = None
Пример #3
0
    def __init__(self,
        in_channels,                    # Input channel count.
        out_channels,                   # Output channel count.
        kernel_size,                    # Side length of the square convolution kernel.
        bias            = True,         # Create an additive bias (applied before the activation)?
        activation      = 'linear',     # Activation name: 'relu', 'lrelu', etc.
        up              = 1,            # Integer upsampling factor.
        down            = 1,            # Integer downsampling factor.
        resample_filter = [1,3,3,1],    # Low-pass filter taps used when resampling activations.
        conv_clamp      = None,         # Clamp the output to +-X; None disables clamping.
        channels_last   = False,        # Create the weight in channels_last memory format?
        trainable       = True,         # Store weight/bias as Parameters (True) or buffers (False)?
    ):
        """Set up a (possibly resampling) 2D convolution layer.

        Creates the convolution weight (N(0, 1) initialised, shape
        [out_channels, in_channels, kernel_size, kernel_size]) and, when
        ``bias`` is truthy, a zero-initialised bias of shape [out_channels].
        With ``trainable=False`` both tensors are registered as buffers, so
        they are not returned by ``Module.parameters()``.
        """
        super().__init__()
        self.activation = activation
        self.up, self.down = up, down
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        fan_in = in_channels * kernel_size ** 2
        self.weight_gain = 1 / np.sqrt(fan_in)
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        if channels_last:
            fmt = torch.channels_last
        else:
            fmt = torch.contiguous_format
        init_weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=fmt)
        init_bias = torch.zeros([out_channels]) if bias else None

        if not trainable:
            # Frozen layer: keep the tensors as buffers.
            self.register_buffer('weight', init_weight)
            if init_bias is None:
                self.bias = None
            else:
                self.register_buffer('bias', init_bias)
            return

        self.weight = torch.nn.Parameter(init_weight)
        self.bias = None if init_bias is None else torch.nn.Parameter(init_bias)
Пример #4
0
    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        w_dim,                          # Intermediate latent (W) dimensionality.
        resolution,                     # Resolution of this layer.
        countHW         = [1, 1],       # !!! custom: frame split count by (height, width).
        splitfine       = 0.,           # !!! custom: frame split edge fineness (float, from 0 up).
        size            = None,         # !!! custom: custom size (stored as self.size).
        scale_type      = None,         # !!! custom: scaling way: fit, centr, side, pad, padside.
        init_res        = [4, 4],       # !!! custom: initial (minimal) resolution for progressive training.
        kernel_size     = 3,            # Convolution kernel size.
        up              = 1,            # Integer upsampling factor.
        use_noise       = True,         # Enable noise input?
        activation      = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
        resample_filter = [1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
        conv_clamp      = None,         # Clamp conv output to +-X; None disables clamping.
        channels_last   = False,        # Create the weight in channels_last memory format?
    ):
        """Set up a style-modulated synthesis layer (custom frame-splitting variant).

        Compared with the stock layer, this stores the custom options
        ``countHW``, ``splitfine``, ``size``, ``scale_type`` and ``init_res``.
        It also sizes the constant noise buffer as
        [resolution*init_res[0]//4, resolution*init_res[1]//4], which equals
        [resolution, resolution] for the default init_res of [4, 4].

        NOTE(review): the ``countHW`` / ``init_res`` defaults are list literals
        that are created once and stored by reference on ``self``. Do not mutate
        them in place.
        """
        super().__init__()
        self.resolution = resolution
        # !!! custom options
        self.countHW = countHW
        self.splitfine = splitfine
        self.size = size
        self.scale_type = scale_type
        self.init_res = init_res
        self.up = up
        self.use_noise = use_noise
        self.activation = activation
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        fmt = torch.channels_last if channels_last else torch.contiguous_format
        kernel_shape = [out_channels, in_channels, kernel_size, kernel_size]
        self.weight = torch.nn.Parameter(torch.randn(kernel_shape).to(memory_format=fmt))
        if use_noise:
            # !!! custom: noise map sized from init_res instead of [resolution, resolution].
            noise_h = resolution * init_res[0] // 4
            noise_w = resolution * init_res[1] // 4
            self.register_buffer('noise_const', torch.randn([noise_h, noise_w]))
            self.noise_strength = torch.nn.Parameter(torch.zeros([]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
Пример #5
0
    def __init__(self,
        in_channels,                        # Number of input channels, 0 = first block.
        out_channels,                       # Number of output channels.
        w_dim,                              # Intermediate latent (W) dimensionality.
        resolution,                         # Resolution of this block.
        img_channels,                       # Number of output color channels.
        is_last,                            # Is this the last block?
        init_res            = [4,4],        # !!! custom: initial (minimal) resolution for progressive training.
        architecture        = 'skip',       # Architecture: 'orig', 'skip', 'resnet'.
        resample_filter     = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp          = None,         # Clamp conv outputs to +-X; None disables clamping.
        use_fp16            = False,        # Use FP16 for this block?
        fp16_channels_last  = False,        # Use channels-last memory format with FP16?
        **layer_kwargs,                     # Extra keyword arguments forwarded to each SynthesisLayer.
    ):
        """Build one synthesis block (custom variant with non-square ``init_res``).

        The first block (``in_channels == 0``) owns a learned constant of shape
        [out_channels, resolution*init_res[0]//4, resolution*init_res[1]//4].
        Every other block owns an upsampling ``conv0`` (up=2) instead. All
        blocks own ``conv1``. A ``torgb`` layer is added for the last block or
        the 'skip' architecture, and a 1x1 upsampling ``skip`` conv for
        non-first 'resnet' blocks.
        """
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.init_res = init_res  # !!! custom
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.num_conv = 0
        self.num_torgb = 0

        # Keyword arguments shared by conv0 and conv1 (init_res is !!! custom).
        shared = dict(w_dim=w_dim, resolution=resolution, init_res=init_res,
                      conv_clamp=conv_clamp, channels_last=self.channels_last)

        if in_channels == 0:
            # !!! custom: constant input sized from init_res.
            const_h = resolution * self.init_res[0] // 4
            const_w = resolution * self.init_res[1] // 4
            self.const = torch.nn.Parameter(torch.randn([out_channels, const_h, const_w]))
        else:
            self.conv0 = SynthesisLayer(in_channels, out_channels, up=2,
                resample_filter=resample_filter, **shared, **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels, out_channels, **shared, **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
                conv_clamp=conv_clamp, channels_last=self.channels_last)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
                resample_filter=resample_filter, channels_last=self.channels_last)
Пример #6
0
    def __init__(self,
        in_channels,                        # Number of input channels, 0 = first block.
        tmp_channels,                       # Number of intermediate channels.
        out_channels,                       # Number of output channels.
        resolution,                         # Resolution of this block.
        img_channels,                       # Number of input color channels.
        first_layer_idx,                    # Index of the first layer.
        init_res            = [4,4],        # !!! custom: initial (minimal) resolution for progressive training.
        architecture        = 'resnet',     # Architecture: 'orig', 'skip', 'resnet'.
        activation          = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
        resample_filter     = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp          = None,         # Clamp conv outputs to +-X; None disables clamping.
        use_fp16            = False,        # Use FP16 for this block?
        fp16_channels_last  = False,        # Use channels-last memory format with FP16?
        freeze_layers       = 0,            # Freeze-D: layers with index < freeze_layers are not trainable.
    ):
        """Build one discriminator block.

        Layers get consecutive indices starting at ``first_layer_idx``, in
        creation order: fromrgb (only for the first block or 'skip'), conv0,
        conv1 (down=2), and skip (only for 'resnet'). A layer is trainable iff
        its index is >= ``freeze_layers``. After construction,
        ``self.num_layers`` is the number of layers created.
        """
        assert in_channels in [0, tmp_channels]
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.resolution = resolution
        self.init_res = init_res  # !!! custom
        self.img_channels = img_channels
        self.first_layer_idx = first_layer_idx
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))

        self.num_layers = 0

        def claim_next_layer():
            # Take the next layer index and report whether Freeze-D leaves it trainable.
            idx = self.first_layer_idx + self.num_layers
            is_trainable = (idx >= freeze_layers)
            self.num_layers += 1
            return is_trainable

        if in_channels == 0 or architecture == 'skip':
            self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation,
                trainable=claim_next_layer(), conv_clamp=conv_clamp, channels_last=self.channels_last)

        self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
            trainable=claim_next_layer(), conv_clamp=conv_clamp, channels_last=self.channels_last)

        self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
            trainable=claim_next_layer(), resample_filter=resample_filter, conv_clamp=conv_clamp,
            channels_last=self.channels_last)

        if architecture == 'resnet':
            self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
                trainable=claim_next_layer(), resample_filter=resample_filter, channels_last=self.channels_last)
Пример #7
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 w_dim,
                 resolution,
                 kernel_size=3,
                 up=1,
                 use_noise=True,
                 activation='lrelu',
                 resample_filter=[1, 3, 3, 1],
                 conv_clamp=None,
                 channels_last=False):
        """Set up a style-modulated synthesis layer.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            w_dim: Intermediate latent (W) dimensionality.
            resolution: Resolution of this layer; sizes the noise buffer.
            kernel_size: Side length of the square convolution kernel.
            up: Integer upsampling factor.
            use_noise: Create the [resolution, resolution] ``noise_const``
                buffer and the scalar ``noise_strength`` parameter?
            activation: Activation name, a key of ``bias_act.activation_funcs``.
            resample_filter: Low-pass filter taps used when resampling.
            conv_clamp: Clamp value for the conv output; None disables clamping.
            channels_last: Create the weight in channels_last memory format?
        """
        super().__init__()

        self.resolution = resolution
        self.up = up
        self.use_noise = use_noise
        self.activation = activation
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        # Affine layer mapping w to in_channels values (presumably the modulation styles).
        self.affine = fcl(w_dim, in_channels, bias_init=1)
        if channels_last:
            weight_format = torch.channels_last
        else:
            weight_format = torch.contiguous_format
        conv_weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size])
        self.weight = torch.nn.Parameter(conv_weight.to(memory_format=weight_format))

        if use_noise:
            self.register_buffer('noise_const', torch.randn([resolution, resolution]))
            self.noise_strength = torch.nn.Parameter(torch.zeros([]))

        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
Пример #8
0
    def __init__(
            self,
            in_channels,  # Number of input channels, 0 = first block (uses GenInput).
            out_channels,  # Number of output channels.
            w_dim,  # Intermediate latent (W) dimensionality.
            resolution,  # Nominal resolution of this block.
            img_channels,  # Number of output color channels.
            is_last,  # Is this the last block?
            architecture='skip',  # Architecture: 'orig', 'skip', 'resnet'.
            resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
            conv_clamp=None,  # Clamp conv outputs to +-X; None disables clamping.
            use_fp16=False,  # Use FP16 for this block?
            fp16_channels_last=False,  # Use channels-last memory format with FP16?
            cfg={},  # Block config, wrapped with OmegaConf (input, coords, upsampling_mode, ...).
            **layer_kwargs,  # Extra keyword arguments forwarded to every SynthesisLayer.
    ):
        """Build a synthesis block driven by an OmegaConf config.

        Blocks whose ``resolution`` is at or below ``cfg.input.resolution`` run
        at the input resolution without upsampling (``up = 1``). Larger blocks
        use ``up = 2`` and take input at ``resolution // 2``. The first block
        (``in_channels == 0``) gets a ``GenInput`` module instead of ``conv0``.
        Later blocks may concatenate ``CoordFuser`` features onto ``conv0``'s
        input. Extra per-resolution convolutions come from
        ``cfg.num_extra_convs[str(resolution)]``.

        NOTE(review): the default ``cfg={}`` has no ``input`` section, so a
        full config presumably has to be passed in practice. TODO confirm.
        """
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()

        self.cfg = OmegaConf.create(cfg)
        self.in_channels = in_channels
        self.w_dim = w_dim

        input_res = self.cfg.input.resolution
        if resolution > input_res:
            # Regular block: upsample 2x from the previous block's resolution.
            self.resolution, self.up, self.input_resolution = resolution, 2, resolution // 2
        else:
            # At/below the configured input resolution: stay at it, no upsampling.
            self.resolution, self.up, self.input_resolution = input_res, 1, input_res

        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.num_conv = 0
        self.num_torgb = 0

        kernel_size = 3
        if self.cfg.coords.enabled:
            kernel_size = self.cfg.coords.kernel_size

        def fmm_instance_norm():
            # Looked up from the raw cfg only when a layer actually needs it.
            return cfg.get('fmm', {}).get('instance_norm', False)

        if in_channels == 0:
            self.input = GenInput(self.cfg.input, out_channels, w_dim)
            conv1_in_channels = self.input.total_dim
        else:
            wants_coords = self.cfg.coords.enabled and (
                not self.cfg.coords.per_resolution
                or self.resolution > self.input_resolution)
            if wants_coords:
                assert self.architecture != 'resnet'
                self.coord_fuser = CoordFuser(self.cfg.coords, self.w_dim, self.resolution)
                conv0_in_channels = in_channels + self.coord_fuser.total_dim
            else:
                self.coord_fuser = None
                conv0_in_channels = in_channels

            # A configured upsampling_mode means conv0 itself does not upsample.
            up_for_conv0 = 1 if self.cfg.upsampling_mode is not None else self.up
            self.conv0 = SynthesisLayer(conv0_in_channels, out_channels,
                                        w_dim=w_dim,
                                        resolution=self.resolution,
                                        up=up_for_conv0,
                                        resample_filter=resample_filter,
                                        conv_clamp=conv_clamp,
                                        channels_last=self.channels_last,
                                        kernel_size=kernel_size,
                                        cfg=cfg,
                                        **layer_kwargs)
            self.num_conv += 1
            conv1_in_channels = out_channels

        self.conv1 = SynthesisLayer(conv1_in_channels, out_channels,
                                    w_dim=w_dim,
                                    resolution=self.resolution,
                                    conv_clamp=conv_clamp,
                                    channels_last=self.channels_last,
                                    kernel_size=kernel_size,
                                    cfg=cfg,
                                    instance_norm=(in_channels > 0 and fmm_instance_norm()),
                                    **layer_kwargs)
        self.num_conv += 1

        res_key = str(self.resolution)
        num_extra = self.cfg.get('num_extra_convs', {}).get(res_key, 0)
        if num_extra > 0:
            assert self.architecture != 'resnet', "Not implemented for resnet"
            extra_layers = []
            for _ in range(num_extra):
                extra_layers.append(SynthesisLayer(out_channels, out_channels,
                                                   w_dim=w_dim,
                                                   resolution=self.resolution,
                                                   conv_clamp=conv_clamp,
                                                   channels_last=self.channels_last,
                                                   kernel_size=kernel_size,
                                                   instance_norm=fmm_instance_norm(),
                                                   cfg=cfg,
                                                   **layer_kwargs))
            self.extra_convs = nn.ModuleList(extra_layers)
            self.num_conv += len(self.extra_convs)
        else:
            self.extra_convs = None

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels, img_channels,
                                    w_dim=w_dim,
                                    conv_clamp=conv_clamp,
                                    channels_last=self.channels_last)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            self.skip = Conv2dLayer(in_channels, out_channels,
                                    kernel_size=1,
                                    bias=False,
                                    up=self.up,
                                    resample_filter=resample_filter,
                                    channels_last=self.channels_last)
Пример #9
0
    def __init__(
        self,
        xflip=0,
        rotate90=0,
        xint=0,
        xint_max=0.125,
        scale=0,
        rotate=0,
        aniso=0,
        xfrac=0,
        scale_std=0.2,
        rotate_max=1,
        aniso_std=0.2,
        xfrac_std=0.125,
        brightness=0,
        contrast=0,
        lumaflip=0,
        hue=0,
        saturation=0,
        brightness_std=0.2,
        contrast_std=0.5,
        hue_max=1,
        saturation_std=1,
        imgfilter=0,
        imgfilter_bands=[1, 1, 1, 1],
        imgfilter_std=1,
        noise=0,
        cutout=0,
        noise_std=0.1,
        cutout_size=0.5,
    ):
        """Configure the augmentation pipeline and build its fixed filters.

        Each scalar argument is converted with ``float`` and stored under an
        attribute of the same name. ``imgfilter_bands`` is copied with
        ``list``, so the shared default list is never aliased. Three buffers
        are registered:

        * ``p``: scalar initialised to 1, described as the overall multiplier
          for augmentation probability.
        * ``Hz_geom``: ``upfirdn2d.setup_filter`` applied to the 'sym6'
          wavelet coefficients (lowpass for geometric augmentations).
        * ``Hz_fbank``: float32 filter bank with 4 rows, built from the
          'sym2' wavelet, for image-space filtering.

        ``wavelets`` is a module-level table of wavelet coefficients that is
        not shown here.
        """
        super().__init__()
        self.register_buffer('p', torch.ones(
            []))  # Overall multiplier for augmentation probability.

        # Pixel blitting.
        self.xflip = float(xflip)  # Probability multiplier for x-flip.
        self.rotate90 = float(
            rotate90)  # Probability multiplier for 90 degree rotations.
        self.xint = float(
            xint)  # Probability multiplier for integer translation.
        self.xint_max = float(
            xint_max
        )  # Range of integer translation, relative to image dimensions.

        # General geometric transformations.
        self.scale = float(
            scale)  # Probability multiplier for isotropic scaling.
        self.rotate = float(
            rotate)  # Probability multiplier for arbitrary rotation.
        self.aniso = float(
            aniso)  # Probability multiplier for anisotropic scaling.
        self.xfrac = float(
            xfrac)  # Probability multiplier for fractional translation.
        self.scale_std = float(
            scale_std)  # Log2 standard deviation of isotropic scaling.
        self.rotate_max = float(
            rotate_max)  # Range of arbitrary rotation, 1 = full circle.
        self.aniso_std = float(
            aniso_std)  # Log2 standard deviation of anisotropic scaling.
        self.xfrac_std = float(
            xfrac_std
        )  # Standard deviation of fractional translation, relative to image dimensions.

        # Color transformations.
        self.brightness = float(
            brightness)  # Probability multiplier for brightness.
        self.contrast = float(contrast)  # Probability multiplier for contrast.
        self.lumaflip = float(
            lumaflip)  # Probability multiplier for luma flip.
        self.hue = float(hue)  # Probability multiplier for hue rotation.
        self.saturation = float(
            saturation)  # Probability multiplier for saturation.
        self.brightness_std = float(
            brightness_std)  # Standard deviation of brightness.
        self.contrast_std = float(
            contrast_std)  # Log2 standard deviation of contrast.
        self.hue_max = float(
            hue_max)  # Range of hue rotation, 1 = full circle.
        self.saturation_std = float(
            saturation_std)  # Log2 standard deviation of saturation.

        # Image-space filtering.
        self.imgfilter = float(
            imgfilter)  # Probability multiplier for image-space filtering.
        self.imgfilter_bands = list(
            imgfilter_bands
        )  # Probability multipliers for individual frequency bands (copied, not aliased).
        self.imgfilter_std = float(
            imgfilter_std
        )  # Log2 standard deviation of image-space filter amplification.

        # Image-space corruptions.
        self.noise = float(
            noise)  # Probability multiplier for additive RGB noise.
        self.cutout = float(cutout)  # Probability multiplier for cutout.
        self.noise_std = float(
            noise_std)  # Standard deviation of additive RGB noise.
        self.cutout_size = float(
            cutout_size
        )  # Size of the cutout rectangle, relative to image dimensions.

        # Setup orthogonal lowpass filter for geometric augmentations.
        self.register_buffer('Hz_geom',
                             upfirdn2d.setup_filter(wavelets['sym6']))

        # Construct filter bank for image-space filtering.
        Hz_lo = np.asarray(wavelets['sym2'])  # H(z)
        Hz_hi = Hz_lo * ((-1)**np.arange(Hz_lo.size))  # H(-z): alternate-sign modulation of H(z)
        Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2  # H(z) * H(z^-1) / 2
        Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2  # H(-z) * H(-z^-1) / 2
        # Start from a 4x1 array holding a unit impulse in row 0 only.
        Hz_fbank = np.eye(4, 1)  # Bandpass(H(z), b_i)
        for i in range(1, Hz_fbank.shape[0]):
            # Upsample every row by 2: interleave zeros (dstack + reshape) and
            # drop the trailing zero, so length n becomes 2n - 1.
            Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)
                                  ]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
            # Row-wise full convolution with the low-pass kernel Hz_lo2.
            Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
            # Add the high-pass kernel Hz_hi2, centred, into row i.
            Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) //
                     2:(Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
        # 4 rows, presumably one per entry of imgfilter_bands (TODO confirm).
        self.register_buffer('Hz_fbank',
                             torch.as_tensor(Hz_fbank, dtype=torch.float32))
Пример #10
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 w_dim,
                 resolution,
                 img_channels,
                 is_last,
                 architecture='skip',
                 resample_filter=[1, 3, 3, 1],
                 conv_clamp=None,
                 use_fp16=False,
                 fp16_channels_last=False,
                 **layer_kwargs):
        """Build one synthesis block.

        Args:
            in_channels: Input feature channels. 0 marks the first block, which
                gets a learned ``const`` of shape [out_channels, resolution,
                resolution] instead of an upsampling ``conv0``.
            out_channels: Output feature channels.
            w_dim: Intermediate latent (W) dimensionality.
            resolution: Resolution of this block.
            img_channels: Number of output color channels (ToRGB).
            is_last: Is this the last block? Forces a ``torgb`` layer.
            architecture: 'orig', 'skip' or 'resnet'.
            resample_filter: Low-pass filter taps used when resampling.
            conv_clamp: Clamp conv outputs to +-X; None disables clamping.
            use_fp16: Use FP16 for this block?
            fp16_channels_last: Use channels-last memory format with FP16?
            **layer_kwargs: Forwarded to both ``synthesis_layer`` calls.
        """
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()

        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter',
                             upfirdn2d.setup_filter(resample_filter))
        self.num_conv = 0
        self.num_torgb = 0

        is_first = in_channels == 0
        # Keyword arguments common to conv0 and conv1.
        layer_args = dict(w_dim=w_dim,
                          resolution=resolution,
                          conv_clamp=conv_clamp,
                          channels_last=self.channels_last)

        if is_first:
            self.const = torch.nn.Parameter(
                torch.randn([out_channels, resolution, resolution]))
        else:
            self.conv0 = synthesis_layer(in_channels, out_channels,
                                         up=2,
                                         resample_filter=resample_filter,
                                         **layer_args,
                                         **layer_kwargs)
            self.num_conv += 1

        self.conv1 = synthesis_layer(out_channels, out_channels,
                                     **layer_args,
                                     **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels, img_channels,
                                    w_dim=w_dim,
                                    conv_clamp=conv_clamp,
                                    channels_last=self.channels_last)
            self.num_torgb += 1

        if not is_first and architecture == 'resnet':
            self.skip = Conv2dLayer(in_channels, out_channels,
                                    kernel_size=1,
                                    bias=False,
                                    up=2,
                                    resample_filter=resample_filter,
                                    channels_last=self.channels_last)
Пример #11
0
    def __init__(
            self,
            in_channels,  # Number of input channels, 0 = first block.
            out_channels,  # Number of output channels.
            w_dim,  # Intermediate latent (W) dimensionality.
            resolution,  # Resolution of this block.
            img_channels,  # Number of output color channels.
            is_last,  # Is this the last block?
            segmentation_channels,  # Number of segmentation channels (only used for try-on).
            architecture='skip',  # Architecture: 'orig', 'skip', 'resnet'.
            resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
            conv_clamp=None,  # Clamp conv outputs to +-X; None disables clamping.
            use_fp16=False,  # Use FP16 for this block?
            fp16_channels_last=False,  # Use channels-last memory format with FP16?
            **layer_kwargs,  # Forwarded to both SynthesisLayer calls.
    ):
        """Build one synthesis block with both an RGB and a segmentation head.

        Same layout as the stock block (const or conv0, then conv1, optional
        resnet ``skip``). Wherever a ``torgb`` layer is created (last block or
        'skip' architecture), a ``tosegmentation`` layer with
        ``segmentation_channels`` outputs is created next to it. The last
        block also registers two buffers: ``eps`` (1e-8) and ``zeroTensor``
        (a 1-element zero tensor).
        """
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.segmentation_channels = segmentation_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter',
                             upfirdn2d.setup_filter(resample_filter))
        self.num_conv = 0
        self.num_torgb = 0
        self.num_tosegmentation = 0

        is_first = in_channels == 0
        # Keyword arguments common to conv0 and conv1.
        conv_args = dict(w_dim=w_dim,
                         resolution=resolution,
                         conv_clamp=conv_clamp,
                         channels_last=self.channels_last)
        # Keyword arguments common to the RGB and segmentation heads.
        head_args = dict(w_dim=w_dim,
                         conv_clamp=conv_clamp,
                         channels_last=self.channels_last)

        if is_first:
            self.const = torch.nn.Parameter(
                torch.randn([out_channels, resolution, resolution]))
        else:
            self.conv0 = SynthesisLayer(in_channels, out_channels,
                                        up=2,
                                        resample_filter=resample_filter,
                                        **conv_args,
                                        **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels, out_channels,
                                    **conv_args,
                                    **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels, img_channels, **head_args)
            self.tosegmentation = ToSegmentationLayer(
                out_channels, segmentation_channels, **head_args)
            self.num_tosegmentation += 1
            self.num_torgb += 1

        if is_last:
            self.register_buffer('eps', torch.tensor(1e-8))
            self.register_buffer('zeroTensor', torch.zeros(1))

        if not is_first and architecture == 'resnet':
            self.skip = Conv2dLayer(in_channels, out_channels,
                                    kernel_size=1,
                                    bias=False,
                                    up=2,
                                    resample_filter=resample_filter,
                                    channels_last=self.channels_last)