コード例 #1
0
    def __init__(self,
                 out_size,
                 style_channels,
                 img_channels,
                 noise_size=512,
                 rgb2bgr=False,
                 pretrained=None,
                 synthesis_cfg=dict(type='SynthesisNetwork'),
                 mapping_cfg=dict(type='MappingNetwork')):
        """Assemble the synthesis network and the style-mapping network.

        Both sub-network configs are deep-copied before being completed, so
        neither the caller's dicts nor the shared default dicts are mutated.
        Entries are filled with ``setdefault`` only, which means a value
        given explicitly in a config wins over the constructor argument.

        Args:
            out_size: Stored as ``self.out_size``; default ``out_size``
                entry of the synthesis config.
            style_channels: Stored as ``self.style_channels``; default
                ``style_channels`` entry of both configs.
            img_channels: Stored as ``self.img_channels``; default
                ``img_channels`` entry of the synthesis config.
            noise_size: Stored as ``self.noise_size``; default
                ``noise_size`` entry of the mapping config.
                Defaults to 512.
            rgb2bgr: Only stored on the instance here. Defaults to False.
            pretrained (dict, optional): When not None, unpacked as keyword
                arguments into ``self._load_pretrained_model``.
                Defaults to None.
            synthesis_cfg (dict): Config handed to ``build_module`` for
                ``self.synthesis``.
            mapping_cfg (dict): Config handed to ``build_module`` for
                ``self.style_mapping``.
        """
        super().__init__()
        self.out_size = out_size
        self.img_channels = img_channels
        self.style_channels = style_channels
        self.noise_size = noise_size
        self.rgb2bgr = rgb2bgr

        # synthesis network: complete a private copy of the config
        synthesis_cfg_ = deepcopy(synthesis_cfg)
        synthesis_defaults = (
            ('style_channels', style_channels),
            ('out_size', out_size),
            ('img_channels', img_channels),
        )
        for key, value in synthesis_defaults:
            synthesis_cfg_.setdefault(key, value)
        self._synthesis_cfg = synthesis_cfg_
        self.synthesis = build_module(synthesis_cfg_)

        # the mapping network needs `num_ws` reported by the synthesis
        # network, hence it has to be built second
        self.num_ws = self.synthesis.num_ws
        mapping_cfg_ = deepcopy(mapping_cfg)
        mapping_defaults = (
            ('noise_size', noise_size),
            ('style_channels', style_channels),
            ('num_ws', self.num_ws),
        )
        for key, value in mapping_defaults:
            mapping_cfg_.setdefault(key, value)
        self._mapping_cfg = mapping_cfg_
        self.style_mapping = build_module(mapping_cfg_)

        if pretrained is None:
            return
        self._load_pretrained_model(**pretrained)
コード例 #2
0
    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 disc_auxiliary_loss,
                 gen_auxiliary_loss=None,
                 train_cfg=None,
                 test_cfg=None):
        """Build the sub-modules and register the training-status buffers.

        Args:
            generator (dict): Config passed to ``build_module``; a deep copy
                is kept in ``self._gen_cfg``.
            discriminator (dict | None): Config for the discriminator.
                ``None`` leaves ``self.discriminator`` as ``None``.
            gan_loss (dict | None): Config for the GAN loss. ``None``
                leaves ``self.gan_loss`` as ``None``.
            disc_auxiliary_loss: Config for the discriminator auxiliary
                loss(es). Any falsy value disables them.
            gen_auxiliary_loss: Config for the generator auxiliary
                loss(es). Any falsy value disables them. Defaults to None.
            train_cfg (dict, optional): Deep-copied into ``self.train_cfg``
                when truthy, otherwise ``None`` is stored.
                Defaults to None.
            test_cfg (dict, optional): Deep-copied into ``self.test_cfg``
                when truthy, otherwise ``None`` is stored.
                ``self._parse_test_cfg`` runs whenever the argument is not
                None. Defaults to None.
        """
        super().__init__()

        def _build_or_none(cfg):
            # a missing config is allowed (e.g. in testing)
            return None if cfg is None else build_module(cfg)

        def _build_loss_list(cfg):
            # auxiliary losses are always exposed as a `nn.ModuleList`
            if not cfg:
                return None
            built = build_module(cfg)
            if isinstance(built, nn.ModuleList):
                return built
            return nn.ModuleList([built])

        self._gen_cfg = deepcopy(generator)
        self.generator = build_module(generator)
        self.discriminator = _build_or_none(discriminator)
        self.gan_loss = _build_or_none(gan_loss)
        self.disc_auxiliary_losses = _build_loss_list(disc_auxiliary_loss)
        self.gen_auxiliary_losses = _build_loss_list(gen_auxiliary_loss)

        # training status that is stored together with the weights
        self.register_buffer('shown_nkimg', torch.tensor(0.))
        self.register_buffer('_curr_transition_weight', torch.tensor(1.))

        self.train_cfg = None if not train_cfg else deepcopy(train_cfg)
        self.test_cfg = None if not test_cfg else deepcopy(test_cfg)

        # NOTE(review): `self.scales` is not assigned in this method;
        # presumably `_parse_train_cfg` sets it -- confirm.
        self._parse_train_cfg()
        first_scale = self.scales[0][0]
        last_scale = self.scales[-1][0]
        # this buffer is used to resume model easily
        self.register_buffer('_next_scale_int',
                             torch.tensor(first_scale, dtype=torch.int32))
        # TODO: init it with the same value as `_next_scale_int`
        # a dirty workaround for testing
        self.register_buffer('_curr_scale_int',
                             torch.tensor(last_scale, dtype=torch.int32))

        if test_cfg is not None:
            self._parse_test_cfg()
コード例 #3
0
ファイル: singan.py プロジェクト: youtang1993/mmgeneration
    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 disc_auxiliary_loss,
                 gen_auxiliary_loss=None,
                 train_cfg=None,
                 test_cfg=None):
        """Build the sub-modules and initialise the multi-stage state.

        Args:
            generator (dict): Config passed to ``build_module``; a deep copy
                is kept in ``self._gen_cfg``.
            discriminator (dict | None): Config for the discriminator.
                ``None`` leaves ``self.discriminator`` as ``None``.
            gan_loss (dict | None): Config for the GAN loss. ``None``
                leaves ``self.gan_loss`` as ``None``.
            disc_auxiliary_loss: Config for the discriminator auxiliary
                loss(es). Any falsy value disables them.
            gen_auxiliary_loss: Config for the generator auxiliary
                loss(es). Any falsy value disables them. Defaults to None.
            train_cfg (dict, optional): Deep-copied into ``self.train_cfg``
                when truthy, otherwise ``None`` is stored.
                Defaults to None.
            test_cfg (dict, optional): Deep-copied into ``self.test_cfg``
                when truthy, otherwise ``None`` is stored.
                ``self._parse_test_cfg`` runs whenever the argument is not
                None. Defaults to None.
        """
        super().__init__()

        def _build_or_none(cfg):
            # a missing config is allowed (e.g. in testing)
            return None if cfg is None else build_module(cfg)

        def _build_loss_list(cfg):
            # auxiliary losses are always exposed as a `nn.ModuleList`
            if not cfg:
                return None
            built = build_module(cfg)
            if isinstance(built, nn.ModuleList):
                return built
            return nn.ModuleList([built])

        self._gen_cfg = deepcopy(generator)
        self.generator = build_module(generator)
        self.discriminator = _build_or_none(discriminator)
        self.gan_loss = _build_or_none(gan_loss)
        self.disc_auxiliary_losses = _build_loss_list(disc_auxiliary_loss)
        self.gen_auxiliary_losses = _build_loss_list(gen_auxiliary_loss)

        # plain-python training status; `curr_stage` starts at -1
        self.curr_stage = -1
        self.noise_weights = [1]
        self.fixed_noises = []
        self.reals = []

        self.train_cfg = None if not train_cfg else deepcopy(train_cfg)
        self.test_cfg = None if not test_cfg else deepcopy(test_cfg)

        self._parse_train_cfg()
        if test_cfg is not None:
            self._parse_test_cfg()
コード例 #4
0
    def test_mse_loss(self):
        """Check the forward value, the reduction check and the loss name.

        The loss is built from ``self.config`` through ``build_module``;
        ``self.weight``, ``self.output_dict`` and ``self.loss_manually`` are
        fixtures defined outside this method.
        """
        # test forward
        config = deepcopy(self.config)
        config['rescale_mode'] = 'timestep_weight'
        loss_fn = build_module(config, default_args=dict(weight=self.weight))
        loss = loss_fn(self.output_dict)
        # `np.allclose` only returns a bool; without the `assert` a wrong
        # loss value would go unnoticed.
        assert np.allclose(loss, self.loss_manually)

        # test reduction raise error
        config = deepcopy(self.config)
        config['reduction'] = 'reduction'
        with pytest.raises(ValueError):
            build_module(config)

        # test return loss name
        config = deepcopy(self.config)
        config['loss_name'] = 'loss_name'
        loss_fn = build_module(config)
        assert loss_fn.loss_name() == 'loss_name'
コード例 #5
0
    def __init__(self,
                 input_scale,
                 num_classes=0,
                 in_channels=3,
                 out_channels=1,
                 base_channels=96,
                 sn_eps=1e-6,
                 init_type='ortho',
                 act_cfg=dict(type='ReLU'),
                 with_spectral_norm=True,
                 blocks_cfg=dict(type='BigGANDiscResBlock'),
                 arch_cfg=None,
                 pretrained=None):
        """Build the conv blocks, the decision layer and the class embedding.

        Args:
            input_scale: Stored on the instance and forwarded to
                ``self._get_default_arch_cfg`` when no ``arch_cfg`` is
                given.
            num_classes (int): When greater than 0 an ``nn.Embedding``
                named ``proj_y`` is created. Defaults to 0.
            in_channels (int): Forwarded to ``_get_default_arch_cfg``.
                Defaults to 3.
            out_channels (int): Output features of the final linear
                ``decision`` layer. Defaults to 1.
            base_channels (int): Forwarded to ``_get_default_arch_cfg``.
                Defaults to 96.
            sn_eps (float): ``eps`` used for spectral norm and passed to
                the blocks. Defaults to 1e-6.
            init_type (str): Forwarded to ``self.init_weights``.
                Defaults to 'ortho'.
            act_cfg (dict): Activation config for the blocks and for
                ``self.activate``. Defaults to ``dict(type='ReLU')``.
            with_spectral_norm (bool): Whether ``decision`` / ``proj_y``
                are wrapped with spectral norm; also passed to the blocks.
                Defaults to True.
            blocks_cfg (dict): Base config of each block; deep-copied
                before being completed.
            arch_cfg (dict, optional): Architecture description with the
                keys ``in_channels``, ``out_channels``, ``downsample`` and
                ``attention`` (as read below). Defaults to None.
            pretrained: Forwarded to ``self.init_weights``.
                Defaults to None.
        """
        super().__init__()
        self.input_scale = input_scale
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.base_channels = base_channels
        self.num_classes = num_classes

        if arch_cfg:
            self.arch = arch_cfg
        else:
            self.arch = self._get_default_arch_cfg(self.input_scale,
                                                   self.in_channels,
                                                   self.base_channels)

        # options shared by every block
        self.blocks_cfg = deepcopy(blocks_cfg)
        self.blocks_cfg.update({
            'act_cfg': act_cfg,
            'sn_eps': sn_eps,
            'with_spectral_norm': with_spectral_norm,
        })

        self.conv_blocks = nn.ModuleList()
        for block_idx, block_out_ch in enumerate(self.arch['out_channels']):
            # `self.blocks_cfg` is updated in place, so after the loop it
            # keeps the settings of the last block
            self.blocks_cfg.update({
                'in_channels': self.arch['in_channels'][block_idx],
                'out_channels': block_out_ch,
                'with_downsample': self.arch['downsample'][block_idx],
                'is_head_block': block_idx == 0,
            })
            self.conv_blocks.append(build_module(self.blocks_cfg))

            if not self.arch['attention'][block_idx]:
                continue
            attention = SelfAttentionBlock(
                block_out_ch,
                with_spectral_norm=with_spectral_norm,
                sn_eps=sn_eps)
            self.conv_blocks.append(attention)

        self.activate = build_activation_layer(act_cfg)

        feat_channels = self.arch['out_channels'][-1]
        decision = nn.Linear(feat_channels, out_channels)
        if with_spectral_norm:
            decision = spectral_norm(decision, eps=sn_eps)
        self.decision = decision

        # the embedding only exists when `num_classes` is positive
        if self.num_classes > 0:
            proj_y = nn.Embedding(self.num_classes, feat_channels)
            if with_spectral_norm:
                proj_y = spectral_norm(proj_y, eps=sn_eps)
            self.proj_y = proj_y

        self.init_weights(pretrained=pretrained, init_type=init_type)
コード例 #6
0
    def __init__(self,
                 in_channels,
                 embedding_channels,
                 use_scale_shift_norm,
                 dropout,
                 out_channels=None,
                 norm_cfg=dict(type='GN', num_groups=32),
                 act_cfg=dict(type='SiLU', inplace=False),
                 shortcut_kernel_size=1):
        """Build both conv stages, the embedding norm and the shortcut.

        Args:
            in_channels (int): Input channels of the first conv and of the
                shortcut conv.
            embedding_channels: Passed to the ``NormWithEmbedding`` module
                as ``embedding_channels``.
            use_scale_shift_norm: Passed to the ``NormWithEmbedding``
                module as ``use_scale_shift``.
            dropout: Argument of the ``nn.Dropout`` in the second stage.
            out_channels (int, optional): Output channels; ``None`` means
                "same as ``in_channels``". Defaults to None.
            norm_cfg (dict): Norm config; deep-copied before use.
            act_cfg (dict): Activation config used in both stages.
            shortcut_kernel_size (int): Kernel size of the learnable
                shortcut, must be 1 or 3. Defaults to 1.
        """
        super().__init__()
        if out_channels is None:
            out_channels = in_channels

        _norm_cfg = deepcopy(norm_cfg)

        # stage 1: norm -> activation -> 3x3 conv
        _norm_name, first_norm = build_norm_layer(_norm_cfg, in_channels)
        first_act = build_activation_layer(act_cfg)
        first_conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv_1 = nn.Sequential(first_norm, first_act, first_conv)

        # normalization layer that also receives the embedding
        self.norm_with_embedding = build_module(
            dict(type='NormWithEmbedding'),
            default_args=dict(
                in_channels=out_channels,
                embedding_channels=embedding_channels,
                use_scale_shift=use_scale_shift_norm,
                norm_cfg=_norm_cfg))

        # stage 2: activation -> dropout -> 3x3 conv
        second_act = build_activation_layer(act_cfg)
        second_drop = nn.Dropout(dropout)
        second_conv = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.conv_2 = nn.Sequential(second_act, second_drop, second_conv)

        assert shortcut_kernel_size in (1, 3), (
            'Only support `1` and `3` for `shortcut_kernel_size`, but '
            f'receive {shortcut_kernel_size}.')

        # a conv shortcut is only needed when the channel number changes
        self.learnable_shortcut = out_channels != in_channels
        if self.learnable_shortcut:
            # padding 1 for a 3x3 kernel, no padding otherwise
            shortcut_padding = int(shortcut_kernel_size == 3)
            self.shortcut = nn.Conv2d(in_channels,
                                      out_channels,
                                      shortcut_kernel_size,
                                      padding=shortcut_padding)

        self.init_weights()
コード例 #7
0
    def __init__(self, in_size, *args, data_aug=None, **kwargs):
        """StyleGANv2 Discriminator with adaptive augmentation.

        Args:
            in_size (int): The input size of images.
            *args: Extra positional arguments forwarded to the parent
                constructor.
            data_aug (dict, optional): Config for data
                augmentation. Defaults to None.
            **kwargs: Extra keyword arguments forwarded to the parent
                constructor.
        """
        super().__init__(in_size, *args, **kwargs)
        self.with_ada = data_aug is not None
        if self.with_ada:
            self.ada_aug = build_module(data_aug)
            # Assigning `requires_grad` on a module only creates a plain
            # attribute and freezes nothing; `requires_grad_` is the call
            # that switches gradients off for the module's parameters.
            # NOTE(review): assumes `build_module` returns an `nn.Module`
            # -- confirm.
            self.ada_aug.requires_grad_(False)
            # the plain flag of the original code is kept unchanged
            self.ada_aug.requires_grad = False
        self.log_size = int(np.log2(in_size))
コード例 #8
0
    def test_vlb_loss(self):
        """Check forward values, logging options and rescaling of the loss.

        The loss is built from ``self.config`` through ``build_module``;
        ``self.output_dict``, ``self.weight``, ``self.loss_manually``,
        ``self.loss_disc_likelihood`` and ``self.loss_gaussian_kld`` are
        fixtures defined outside this method.

        Every ``np.allclose`` result is asserted: the function only returns
        a bool, so an unasserted call can never fail the test.
        """
        # test forward
        config = deepcopy(self.config)
        loss_fn = build_module(config)
        loss = loss_fn(self.output_dict)
        assert np.allclose(loss, self.loss_manually * 4)

        # test log_cfgs --> dict input
        config = deepcopy(self.config)
        config['log_cfgs'] = dict(type='name')
        loss_fn = build_module(config)
        assert isinstance(loss_fn.log_fn_list, list)

        # test log_cfgs --> no log_cfgs
        config = deepcopy(self.config)
        config['log_cfgs'] = None
        loss_fn = build_module(config)
        loss = loss_fn(self.output_dict)
        assert not loss_fn.log_vars

        # test rescale_cfg --> rescale is None
        config = deepcopy(self.config)
        config['rescale_mode'] = None
        loss_fn = build_module(config)
        loss = loss_fn(self.output_dict)
        assert np.allclose(loss, self.loss_manually)

        # TODO: test rescale_cfg --> test sampler

        # test rescale_cfg --> test weight
        config = deepcopy(self.config)
        config['rescale_mode'] = 'timestep_weight'
        weight = self.weight.clone()
        loss_fn = build_module(config, default_args=dict(weight=weight))
        loss = loss_fn(self.output_dict)
        loss_weighted_manually = (
            -(self.loss_disc_likelihood * weight)[0] +
            (self.loss_gaussian_kld * weight)[1:].sum()) / 4
        assert np.allclose(loss, loss_weighted_manually)

        # test rescale_cfg --> change weight
        # `weight` is modified in place and `loss_fn` is called again
        # without being rebuilt
        weight[0] += 1
        loss = loss_fn(self.output_dict)
        loss_weighted_manually = (
            -(self.loss_disc_likelihood * weight)[0] +
            (self.loss_gaussian_kld * weight)[1:].sum()) / 4
        assert np.allclose(loss, loss_weighted_manually)

        # test t = 0
        # NOTE(review): the first timestep is overwritten with 1 here and
        # the two log entries below are then expected to be 0 -- confirm
        # that "no sample at t = 0" is the intended scenario.
        config = deepcopy(self.config)
        output_dict = deepcopy(self.output_dict)
        output_dict['timesteps'][0] = 1
        loss_fn = build_module(config)
        loss = loss_fn(output_dict)
        assert loss_fn.log_vars['loss_vlb_quartile_0'] == 0
        assert loss_fn.log_vars['loss_DiscGaussianLogLikelihood'] == 0
コード例 #9
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 num_scales,
                 kernel_size=3,
                 padding=0,
                 num_layers=5,
                 base_channels=32,
                 min_feat_channels=32,
                 out_act_cfg=dict(type='Tanh'),
                 padding_mode='zero',
                 pad_at_head=True,
                 interp_pad=False,
                 noise_with_pad=False,
                 positional_encoding=None,
                 first_stage_in_channels=None,
                 **kwargs):
        """Create one ``GeneratorBlock`` per scale plus the padding layers.

        Args:
            in_channels (int): Input channels of every block, except the
                first one when ``first_stage_in_channels`` is truthy.
            out_channels (int): Output channels of every block.
            num_scales (int): ``num_scales + 1`` blocks are created.
            kernel_size (int): Forwarded to the blocks; also used to
                compute ``self.pad_head``. Defaults to 3.
            padding (int): Forwarded to the blocks. Defaults to 0.
            num_layers (int): Forwarded to the blocks; also used to
                compute ``self.pad_head``. Defaults to 5.
            base_channels (int): Base width; doubled every four scales and
                capped at 128. Defaults to 32.
            min_feat_channels (int): Minimum width; doubled every four
                scales and capped at 128. Defaults to 32.
            out_act_cfg (dict): Forwarded to the blocks.
            padding_mode (str): ``'zero'`` or ``'reflect'``; any other
                value raises ``NotImplementedError``. Defaults to 'zero'.
            pad_at_head (bool): Only stored here. Defaults to True.
            interp_pad (bool): Only stored here. Defaults to False.
            noise_with_pad (bool): Only stored here. Defaults to False.
            positional_encoding (dict, optional): When not None, built
                with ``build_module`` into ``self.head_position_encode``.
                Defaults to None.
            first_stage_in_channels (int, optional): Input channels of the
                first block when truthy. Defaults to None.
            **kwargs: Forwarded to every ``GeneratorBlock``.
        """
        super(SinGANMultiScaleGenerator, self).__init__()

        self.pad_at_head = pad_at_head
        self.interp_pad = interp_pad
        self.noise_with_pad = noise_with_pad

        self.with_positional_encode = positional_encoding is not None
        if self.with_positional_encode:
            self.head_position_encode = build_module(positional_encoding)

        self.pad_head = int((kernel_size - 1) / 2 * num_layers)
        self.blocks = nn.ModuleList()

        self.upsample = partial(
            F.interpolate, mode='bicubic', align_corners=True)

        for scale in range(num_scales + 1):
            # widths double every four scales and are capped at 128
            width_mult = 2**(scale // 4)
            block_base_ch = min(base_channels * width_mult, 128)
            block_min_feat_ch = min(min_feat_channels * width_mult, 128)

            # only the first stage may use a different input width
            block_in_ch = in_channels
            if scale == 0 and first_stage_in_channels:
                block_in_ch = first_stage_in_channels

            block = GeneratorBlock(in_channels=block_in_ch,
                                   out_channels=out_channels,
                                   kernel_size=kernel_size,
                                   padding=padding,
                                   num_layers=num_layers,
                                   base_channels=block_base_ch,
                                   min_feat_channels=block_min_feat_ch,
                                   out_act_cfg=out_act_cfg,
                                   padding_mode=padding_mode,
                                   **kwargs)
            self.blocks.append(block)

        # noise and image share the selected padding type, while the mask
        # is reflection-padded in both modes
        if padding_mode == 'zero':
            content_pad_layer = nn.ZeroPad2d
        elif padding_mode == 'reflect':
            content_pad_layer = nn.ReflectionPad2d
        else:
            raise NotImplementedError(
                f'Padding mode {padding_mode} is not supported')

        self.noise_padding_layer = content_pad_layer(self.pad_head)
        self.img_padding_layer = content_pad_layer(self.pad_head)
        self.mask_padding_layer = nn.ReflectionPad2d(self.pad_head)
        if padding_mode == 'reflect':
            mmcv.print_log('Using Reflection padding', 'mmgen')
コード例 #10
0
ファイル: mspie.py プロジェクト: plyfager/mmgeneration
    def __init__(self,
                 out_size,
                 style_channels,
                 num_mlps=8,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 lr_mlp=0.01,
                 default_style_mode='mix',
                 eval_style_mode='single',
                 mix_prob=0.9,
                 no_pad=False,
                 deconv2conv=False,
                 interp_pad=None,
                 up_config=dict(scale_factor=2, mode='nearest'),
                 up_after_conv=False,
                 head_pos_encoding=None,
                 head_pos_size=(4, 4),
                 interp_head=False):
        """Build the mapping network, the head input and the conv backbone.

        Args:
            out_size (int): Output resolution; ``int(log2(out_size))``
                decides how many backbone stages are created.
            style_channels (int): Width of the style code; used by the
                mapping layers and by every modulated layer.
            num_mlps (int): Number of ``EqualLinearActModule`` layers in
                the mapping network. Defaults to 8.
            channel_multiplier (int): Multiplier for the channel widths of
                resolutions 64 and above. Defaults to 2.
            blur_kernel (list[int]): Forwarded to every
                ``ModulatedPEStyleConv``. Defaults to [1, 3, 3, 1].
            lr_mlp (float): ``lr_mul`` of the mapping layers.
                Defaults to 0.01.
            default_style_mode (str): Only stored here. Defaults to 'mix'.
            eval_style_mode (str): Only stored here. Defaults to 'single'.
            mix_prob (float): Only stored here. Defaults to 0.9.
            no_pad (bool): Forwarded to the convs; also makes the constant
                input 6x6 instead of 4x4. Defaults to False.
            deconv2conv (bool): Forwarded to the convs. Defaults to False.
            interp_pad: Forwarded to the convs; ``self.with_interp_pad``
                records whether it is not None. Defaults to None.
            up_config (dict): Deep-copied, then forwarded to the convs.
            up_after_conv (bool): Forwarded to the backbone convs (not to
                ``conv1``). Defaults to False.
            head_pos_encoding (dict, optional): When truthy, a positional
                encoding module replaces the constant input.
                Defaults to None.
            head_pos_size (tuple[int]): Only stored here.
                Defaults to (4, 4).
            interp_head (bool): Only stored here. Defaults to False.
        """
        super().__init__()

        # plain hyper-parameters
        self.out_size = out_size
        self.style_channels = style_channels
        self.num_mlps = num_mlps
        self.channel_multiplier = channel_multiplier
        self.lr_mlp = lr_mlp
        self.mix_prob = mix_prob
        self.eval_style_mode = eval_style_mode
        self.default_style_mode = default_style_mode
        self._default_style_mode = default_style_mode

        # padding / upsampling / positional-encoding options
        self.no_pad = no_pad
        self.deconv2conv = deconv2conv
        self.interp_pad = interp_pad
        self.with_interp_pad = interp_pad is not None
        self.up_config = deepcopy(up_config)
        self.up_after_conv = up_after_conv
        self.head_pos_encoding = head_pos_encoding
        self.head_pos_size = head_pos_size
        self.interp_head = interp_head

        # define style mapping layers
        mapping_layers = [PixelNorm()]
        mapping_layers.extend(
            EqualLinearActModule(
                style_channels,
                style_channels,
                equalized_lr_cfg=dict(lr_mul=lr_mlp, gain=1.),
                act_cfg=dict(type='fused_bias'))
            for _ in range(num_mlps))
        self.style_mapping = nn.Sequential(*mapping_layers)

        # resolution -> channel width
        self.channels = {res: 512 for res in (4, 8, 16, 32)}
        scaled_widths = ((64, 256), (128, 128), (256, 64), (512, 32),
                         (1024, 16))
        for res, width in scaled_widths:
            self.channels[res] = width * channel_multiplier

        head_channels = self.channels[4]

        # head input: positional encoding or a learned constant
        in_ch = head_channels
        if self.head_pos_encoding:
            grid_types = ('CatersianGrid', 'CSG', 'CSG2d')
            if self.head_pos_encoding['type'] in grid_types:
                in_ch = 2
            self.head_pos_enc = build_module(self.head_pos_encoding)
        else:
            size_ = 6 if self.no_pad else 4
            self.constant_input = ConstantInput(head_channels, size=size_)

        # keyword arguments shared by every modulated style conv
        shared_conv_kwargs = dict(
            blur_kernel=blur_kernel,
            deconv2conv=self.deconv2conv,
            no_pad=self.no_pad,
            up_config=self.up_config,
            interp_pad=self.interp_pad)

        # 4x4 stage
        self.conv1 = ModulatedPEStyleConv(
            in_ch,
            head_channels,
            kernel_size=3,
            style_channels=style_channels,
            **shared_conv_kwargs)
        self.to_rgb1 = ModulatedToRGB(
            head_channels, style_channels, upsample=False)

        # generator backbone (8x8 --> higher resolutions)
        self.log_size = int(np.log2(self.out_size))

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()

        in_channels_ = head_channels
        for log_res in range(3, self.log_size + 1):
            out_channels_ = self.channels[2**log_res]

            # an upsampling conv followed by a plain conv
            stage_convs = ((in_channels_, True), (out_channels_, False))
            for conv_in_channels, with_upsample in stage_convs:
                self.convs.append(
                    ModulatedPEStyleConv(
                        conv_in_channels,
                        out_channels_,
                        3,
                        style_channels,
                        upsample=with_upsample,
                        up_after_conv=self.up_after_conv,
                        **shared_conv_kwargs))
            self.to_rgbs.append(
                ModulatedToRGB(out_channels_, style_channels, upsample=True))

            in_channels_ = out_channels_

        self.num_latents = self.log_size * 2 - 2
        self.num_injected_noises = self.num_latents - 1

        # register buffer for injected noises
        noises = self.make_injected_noise()
        for layer_idx in range(self.num_injected_noises):
            buffer_name = f'injected_noise_{layer_idx}'
            self.register_buffer(buffer_name, noises[layer_idx])
コード例 #11
0
    def __init__(self,
                 input_scale,
                 num_classes=0,
                 base_channels=128,
                 input_channels=3,
                 attention_cfg=dict(type='SelfAttentionBlock'),
                 attention_after_nth_block=-1,
                 channels_cfg=None,
                 downsample_cfg=None,
                 from_rgb_cfg=dict(type='SNGANDiscHeadResBlock'),
                 blocks_cfg=dict(type='SNGANDiscResBlock'),
                 act_cfg=dict(type='ReLU'),
                 with_spectral_norm=True,
                 sn_eps=1e-12,
                 init_cfg=dict(type='BigGAN'),
                 pretrained=None):
        """Build the head block, the residual blocks and the output layers.

        Args:
            input_scale: Key used to pick the per-scale entry when
                ``channels_cfg`` / ``downsample_cfg`` are dicts.
            num_classes (int): When greater than 0 an ``nn.Embedding``
                named ``proj_y`` is created. Defaults to 0.
            base_channels (int): Base width; every block uses a multiple
                of it. Defaults to 128.
            input_channels (int): Input channels of the ``from_rgb``
                block. Defaults to 3.
            attention_cfg (dict): Config of the attention blocks;
                deep-copied for every block that is built.
            attention_after_nth_block (int | list[int]): 1-based block
                indices (the ``from_rgb`` block counts as block 1) after
                which an attention block is inserted. Defaults to -1.
            channels_cfg (dict | list, optional): Channel factors, either
                a list or a dict keyed by scale. ``None`` selects
                ``self._defualt_channels_cfg``. Defaults to None.
            downsample_cfg (dict | list, optional): Per-block downsample
                settings, either a list or a dict keyed by scale. ``None``
                selects ``self._defualt_downsample_cfg``.
                Defaults to None.
            from_rgb_cfg (dict): Config of the head block.
            blocks_cfg (dict): Base config of the residual blocks.
            act_cfg (dict): Activation config, used as default for the
                blocks and for ``self.activate``.
            with_spectral_norm (bool): Default for the blocks; also
                decides whether ``decision`` / ``proj_y`` get spectral
                norm. Defaults to True.
            sn_eps (float): Default ``sn_eps`` of the residual blocks.
                Defaults to 1e-12.
            init_cfg (dict): Its ``type`` is stored as ``self.init_type``;
                also the default ``init_cfg`` of the blocks.
            pretrained: Forwarded to ``self.init_weights``.
                Defaults to None.

        Raises:
            KeyError: If a dict config has no entry for ``input_scale``.
            ValueError: If a config is neither list nor dict, if the two
                configs differ in length or are empty, or if
                ``attention_after_nth_block`` contains non-int values.
        """

        super().__init__()

        self.init_type = init_cfg.get('type', None)

        # add SN options and activation function options to cfg
        self.from_rgb_cfg = deepcopy(from_rgb_cfg)
        self.from_rgb_cfg.setdefault('act_cfg', act_cfg)
        self.from_rgb_cfg.setdefault('with_spectral_norm', with_spectral_norm)
        self.from_rgb_cfg.setdefault('init_cfg', init_cfg)

        # add SN options and activation function options to cfg
        self.blocks_cfg = deepcopy(blocks_cfg)
        self.blocks_cfg.setdefault('act_cfg', act_cfg)
        self.blocks_cfg.setdefault('with_spectral_norm', with_spectral_norm)
        self.blocks_cfg.setdefault('sn_eps', sn_eps)
        self.blocks_cfg.setdefault('init_cfg', init_cfg)

        def _select_for_scale(cfg, cfg_name):
            # A dict is indexed by `input_scale`, a list is used as is.
            # The error messages name the argument that is really at fault.
            if isinstance(cfg, dict):
                if input_scale not in cfg:
                    raise KeyError(
                        f'`input_scale={input_scale}` is not found in '
                        f'`{cfg_name}`, only support configs for '
                        f'{list(cfg.keys())}')
                return cfg[input_scale]
            if isinstance(cfg, list):
                return cfg
            raise ValueError(f'Only support list or dict for `{cfg_name}`, '
                             f'receive {type(cfg)}')

        channels_cfg = deepcopy(self._defualt_channels_cfg
                                if channels_cfg is None else channels_cfg)
        self.channel_factor_list = _select_for_scale(channels_cfg,
                                                     'channels_cfg')

        downsample_cfg = deepcopy(self._defualt_downsample_cfg
                                  if downsample_cfg is None else
                                  downsample_cfg)
        self.downsample_list = _select_for_scale(downsample_cfg,
                                                 'downsample_cfg')

        if len(self.downsample_list) != len(self.channel_factor_list):
            raise ValueError('`downsample_cfg` should have same length with '
                             '`channels_cfg`, but receive '
                             f'{len(self.downsample_list)} and '
                             f'{len(self.channel_factor_list)}.')
        # the output width is taken from the last block, so at least one
        # block is required
        if len(self.channel_factor_list) == 0:
            raise ValueError('`channels_cfg` and `downsample_cfg` should '
                             'describe at least one block.')

        # check `attention_after_nth_block`
        if not isinstance(attention_after_nth_block, list):
            attention_after_nth_block = [attention_after_nth_block]
        if not all(isinstance(idx, int) for idx in attention_after_nth_block):
            raise ValueError('`attention_after_nth_block` only support int or '
                             'a list of int. Please check your input type.')

        self.from_rgb = build_module(
            self.from_rgb_cfg,
            dict(in_channels=input_channels, out_channels=base_channels))

        self.conv_blocks = nn.ModuleList()
        # add self-attention block after the first block
        if 1 in attention_after_nth_block:
            attn_cfg_ = deepcopy(attention_cfg)
            attn_cfg_['in_channels'] = base_channels
            self.conv_blocks.append(build_module(attn_cfg_))

        # the first residual block is fed by `from_rgb` (factor 1)
        factor_input = 1
        block_settings = zip(self.channel_factor_list, self.downsample_list)
        for idx, (factor_output, downsample) in enumerate(block_settings):
            # get block-specific config
            block_cfg_ = deepcopy(self.blocks_cfg)
            block_cfg_['downsample'] = downsample
            block_cfg_['in_channels'] = factor_input * base_channels
            block_cfg_['out_channels'] = factor_output * base_channels
            self.conv_blocks.append(build_module(block_cfg_))

            # build self-attention block
            # the first ConvBlock is `from_rgb` block,
            # add 2 to get the index of the ConvBlocks
            if idx + 2 in attention_after_nth_block:
                attn_cfg_ = deepcopy(attention_cfg)
                attn_cfg_['in_channels'] = factor_output * base_channels
                self.conv_blocks.append(build_module(attn_cfg_))

            factor_input = factor_output

        # `factor_output` now holds the factor of the last block
        self.decision = nn.Linear(factor_output * base_channels, 1)

        if with_spectral_norm:
            self.decision = spectral_norm(self.decision)

        self.num_classes = num_classes

        # In this case, discriminator is designed for conditional synthesis.
        if num_classes > 0:
            self.proj_y = nn.Embedding(num_classes,
                                       factor_output * base_channels)
            if with_spectral_norm:
                self.proj_y = spectral_norm(self.proj_y)

        self.activate = build_activation_layer(act_cfg)
        self.init_weights(pretrained)
コード例 #12
0
    def __init__(self,
                 output_scale,
                 num_classes=0,
                 base_channels=64,
                 out_channels=3,
                 input_scale=4,
                 noise_size=128,
                 attention_cfg=dict(type='SelfAttentionBlock'),
                 attention_after_nth_block=0,
                 channels_cfg=None,
                 blocks_cfg=dict(type='SNGANGenResBlock'),
                 act_cfg=dict(type='ReLU'),
                 use_cbn=True,
                 auto_sync_bn=True,
                 with_spectral_norm=False,
                 with_embedding_spectral_norm=None,
                 norm_eps=1e-4,
                 sn_eps=1e-12,
                 init_cfg=dict(type='BigGAN'),
                 pretrained=None):
        """Build the noise projection, the residual blocks and the RGB head.

        Args:
            output_scale: Stored on the instance; key used to pick the
                channel factors when ``channels_cfg`` is a dict.
            num_classes (int): Stored on the instance; default
                ``num_classes`` of the blocks. Defaults to 0.
            base_channels (int): Base width; every block uses a multiple
                of it. Defaults to 64.
            out_channels (int): Output channels of ``to_rgb``.
                Defaults to 3.
            input_scale (int): Its square scales the output size of
                ``noise2feat``. Defaults to 4.
            noise_size (int): Input features of ``noise2feat``.
                Defaults to 128.
            attention_cfg (dict): Config of the attention blocks;
                deep-copied for every block that is built.
            attention_after_nth_block (int | list[int]): 1-based block
                indices after which an attention block is inserted.
                Defaults to 0.
            channels_cfg (dict | list, optional): Channel factors, either
                a list or a dict keyed by scale. ``None`` selects
                ``self._default_channels_cfg``. Defaults to None.
            blocks_cfg (dict): Base config of the residual blocks;
                deep-copied and completed with ``setdefault``.
            act_cfg (dict): Default activation of the blocks; also used by
                ``to_rgb``.
            use_cbn (bool): Default ``use_cbn`` of the blocks.
                Defaults to True.
            auto_sync_bn (bool): Default for the blocks; together with
                ``check_dist_init()`` selects ``SyncBN`` for ``to_rgb``.
                Defaults to True.
            with_spectral_norm (bool): Default for the blocks; also
                applied to ``noise2feat`` and ``to_rgb``.
                Defaults to False.
            with_embedding_spectral_norm (bool, optional): Falls back to
                ``with_spectral_norm`` when None. Defaults to None.
            norm_eps (float): Default ``norm_eps`` of the blocks and
                ``eps`` of the ``to_rgb`` norm. Defaults to 1e-4.
            sn_eps (float): Default ``sn_eps`` of the blocks.
                Defaults to 1e-12.
            init_cfg (dict): Its ``type`` is stored as ``self.init_type``;
                also the default ``init_cfg`` of the blocks.
            pretrained: Forwarded to ``self.init_weights``.
                Defaults to None.

        Raises:
            KeyError: If ``channels_cfg`` is a dict without an entry for
                ``output_scale``.
            ValueError: If ``channels_cfg`` is neither list nor dict, or if
                ``attention_after_nth_block`` is not an int / list of int.
        """

        super().__init__()

        self.output_scale = output_scale
        self.input_scale = input_scale
        self.num_classes = num_classes
        self.noise_size = noise_size
        self.init_type = init_cfg.get('type', None)

        # the embedding spectral norm follows `with_spectral_norm` unless
        # it is given explicitly
        if with_embedding_spectral_norm is None:
            with_embedding_spectral_norm = with_spectral_norm

        # explicit entries of `blocks_cfg` win over the arguments below
        self.blocks_cfg = deepcopy(blocks_cfg)
        block_defaults = (
            ('num_classes', num_classes),
            ('act_cfg', act_cfg),
            ('use_cbn', use_cbn),
            ('auto_sync_bn', auto_sync_bn),
            ('with_spectral_norm', with_spectral_norm),
            ('with_embedding_spectral_norm', with_embedding_spectral_norm),
            ('init_cfg', init_cfg),
            ('norm_eps', norm_eps),
            ('sn_eps', sn_eps),
        )
        for key, value in block_defaults:
            self.blocks_cfg.setdefault(key, value)

        # resolve the channel factors for this output scale
        if channels_cfg is None:
            channels_cfg = self._default_channels_cfg
        channels_cfg = deepcopy(channels_cfg)
        if isinstance(channels_cfg, dict):
            if output_scale not in channels_cfg:
                raise KeyError(f'`output_scale={output_scale} is not found in '
                               '`channel_cfg`, only support configs for '
                               f'{[chn for chn in channels_cfg.keys()]}')
            factors = channels_cfg[output_scale]
        elif isinstance(channels_cfg, list):
            factors = channels_cfg
        else:
            raise ValueError('Only support list or dict for `channel_cfg`, '
                             f'receive {type(channels_cfg)}')
        self.channel_factor_list = factors

        first_feat_size = (
            input_scale**2 * base_channels * self.channel_factor_list[0])
        noise2feat = nn.Linear(noise_size, first_feat_size)
        if with_spectral_norm:
            noise2feat = spectral_norm(noise2feat)
        self.noise2feat = noise2feat

        # check `attention_after_nth_block`
        if not isinstance(attention_after_nth_block, list):
            attention_after_nth_block = [attention_after_nth_block]
        if not is_list_of(attention_after_nth_block, int):
            raise ValueError('`attention_after_nth_block` only support int or '
                             'a list of int. Please check your input type.')

        self.conv_blocks = nn.ModuleList()
        self.attention_block_idx = []
        last_idx = len(self.channel_factor_list) - 1
        for idx, factor_input in enumerate(self.channel_factor_list):
            # the last block always ends with factor 1
            if idx < last_idx:
                factor_output = self.channel_factor_list[idx + 1]
            else:
                factor_output = 1

            # get block-specific config
            block_cfg_ = deepcopy(self.blocks_cfg)
            block_cfg_['in_channels'] = factor_input * base_channels
            block_cfg_['out_channels'] = factor_output * base_channels
            self.conv_blocks.append(build_module(block_cfg_))

            # `idx` starts from 0, add 1 to get the block number
            if idx + 1 not in attention_after_nth_block:
                continue
            # remember where the attention block sits in `conv_blocks`
            self.attention_block_idx.append(len(self.conv_blocks))
            attn_cfg_ = deepcopy(attention_cfg)
            attn_cfg_['in_channels'] = factor_output * base_channels
            self.conv_blocks.append(build_module(attn_cfg_))

        # SyncBN is used only when distributed training is initialised and
        # `auto_sync_bn` is enabled
        use_sync_bn = check_dist_init() and auto_sync_bn
        to_rgb_norm_cfg = dict(
            type='SyncBN' if use_sync_bn else 'BN', eps=norm_eps)

        self.to_rgb = ConvModule(
            factor_output * base_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            norm_cfg=to_rgb_norm_cfg,
            act_cfg=act_cfg,
            order=('norm', 'act', 'conv'),
            with_spectral_norm=with_spectral_norm)
        self.final_act = build_activation_layer(dict(type='Tanh'))

        self.init_weights(pretrained)
コード例 #13
0
    def __init__(self,
                 output_scale,
                 noise_size=120,
                 num_classes=0,
                 out_channels=3,
                 base_channels=96,
                 input_scale=4,
                 with_shared_embedding=True,
                 shared_dim=128,
                 sn_eps=1e-6,
                 init_type='ortho',
                 split_noise=True,
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='nearest', scale_factor=2),
                 with_spectral_norm=True,
                 auto_sync_bn=True,
                 blocks_cfg=dict(type='BigGANGenResBlock'),
                 arch_cfg=None,
                 out_norm_cfg=dict(type='BN'),
                 pretrained=None,
                 rgb2bgr=False):
        """Assemble the generator layers from the architecture config.

        Args:
            output_scale (int): Forwarded to ``_get_default_arch_cfg`` when
                no ``arch_cfg`` is given.
            noise_size (int): Requested noise length. With ``split_noise``
                it is rounded down to a multiple of the number of slots.
            num_classes (int): Number of classes; ``0`` requires
                ``with_shared_embedding`` to be False.
            out_channels (int): Output channels of the final conv layer.
            base_channels (int): Forwarded to ``_get_default_arch_cfg``.
            input_scale (int): The first linear layer emits
                ``in_channels[0] * input_scale ** 2`` features.
            with_shared_embedding (bool): Use one ``nn.Embedding`` for the
                label instead of ``nn.Identity``.
            shared_dim (int): Width of the shared embedding.
            sn_eps (float): ``eps`` handed to every spectral-norm wrapper.
            init_type (str): Forwarded to ``init_weights``.
            split_noise (bool): Split the noise into one chunk per slot.
            act_cfg (dict): Activation config for blocks and output layer.
            upsample_cfg (dict): Upsample config given to blocks whose
                ``arch['upsample']`` entry is truthy.
            with_spectral_norm (bool): Wrap layers with spectral norm.
            auto_sync_bn (bool): Forwarded to every block config.
            blocks_cfg (dict): Base config of the residual blocks.
            arch_cfg (dict | None): Explicit architecture description.
            out_norm_cfg (dict): Norm config of the output layer.
            pretrained: Forwarded to ``init_weights``.
            rgb2bgr (bool): Stored on the instance.
        """
        super().__init__()

        # plain bookkeeping attributes
        self.noise_size = noise_size
        self.num_classes = num_classes
        self.shared_dim = shared_dim
        self.with_shared_embedding = with_shared_embedding
        self.output_scale = output_scale
        self.input_scale = input_scale
        self.split_noise = split_noise
        self.rgb2bgr = rgb2bgr
        self.blocks_cfg = deepcopy(blocks_cfg)
        self.upsample_cfg = deepcopy(upsample_cfg)
        self.arch = arch_cfg or self._get_default_arch_cfg(
            self.output_scale, base_channels)
        arch = self.arch

        # Sanity checks: an unconditional model cannot own a shared
        # embedding, and a conditional model without one must not split
        # its noise.
        if num_classes == 0:
            assert not self.with_shared_embedding
        elif not self.with_shared_embedding:
            assert not self.split_noise

        # Work out how the noise vector is partitioned.
        if not self.split_noise:
            self.num_slots = 1
            self.noise_chunk_size = 0
        else:
            # one chunk per block plus one for the first linear layer
            self.num_slots = len(arch['in_channels']) + 1
            self.noise_chunk_size = self.noise_size // self.num_slots
            # shrink the noise size so that it splits evenly
            self.noise_size = self.noise_chunk_size * self.num_slots

        # Linear projection from the (first chunk of the) noise to features.
        linear_in = self.noise_size // self.num_slots
        linear_out = arch['in_channels'][0] * (self.input_scale**2)
        noise2feat = nn.Linear(linear_in, linear_out)
        if with_spectral_norm:
            noise2feat = spectral_norm(noise2feat, eps=sn_eps)
        self.noise2feat = noise2feat

        # Either embed the label once for all blocks, or hand it through
        # unchanged so that each block deals with it.
        self.shared_embedding = (nn.Embedding(num_classes, shared_dim)
                                 if with_shared_embedding else nn.Identity())

        # Width of the conditioning vector that every block receives.
        if num_classes > 0:
            if self.with_shared_embedding:
                self.dim_after_concat = (self.shared_dim +
                                         self.noise_chunk_size)
            else:
                self.dim_after_concat = self.num_classes
        else:
            self.dim_after_concat = self.noise_chunk_size

        # Settings shared by all residual blocks.
        label_is_input = (num_classes > 0) and (not with_shared_embedding)
        self.blocks_cfg.update({
            'dim_after_concat': self.dim_after_concat,
            'act_cfg': act_cfg,
            'sn_eps': sn_eps,
            'input_is_label': label_is_input,
            'with_spectral_norm': with_spectral_norm,
            'auto_sync_bn': auto_sync_bn,
        })

        self.conv_blocks = nn.ModuleList()
        for stage, stage_out in enumerate(arch['out_channels']):
            # per-stage overrides of the shared block config
            stage_in = arch['in_channels'][stage]
            stage_upsample = (self.upsample_cfg
                              if arch['upsample'][stage] else None)
            self.blocks_cfg.update({
                'in_channels': stage_in,
                'out_channels': stage_out,
                'upsample_cfg': stage_upsample,
            })
            self.conv_blocks.append(build_module(self.blocks_cfg))

            # optional self-attention right after this stage
            if arch['attention'][stage]:
                attention = SelfAttentionBlock(
                    stage_out,
                    with_spectral_norm=with_spectral_norm,
                    sn_eps=sn_eps)
                self.conv_blocks.append(attention)

        # Final norm -> act -> conv projection to the output channels.
        last_channels = arch['out_channels'][-1]
        self.output_layer = SNConvModule(
            last_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            bias=True,
            order=('norm', 'act', 'conv'),
            norm_cfg=out_norm_cfg,
            act_cfg=act_cfg,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=dict(eps=sn_eps))

        self.init_weights(pretrained=pretrained, init_type=init_type)
コード例 #14
0
ファイル: cyclegan.py プロジェクト: youtang1993/mmgeneration
    def __init__(self,
                 generator,
                 discriminator,
                 gan_loss,
                 cycle_loss,
                 id_loss=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        """Build generators, discriminators, image buffers and losses.

        Args:
            generator (dict): Config given to ``build_module``; it is built
                twice and stored under the keys ``'a'`` and ``'b'``.
            discriminator (dict): Config given to ``build_module``; it is
                built twice and stored under the keys ``'a'`` and ``'b'``.
            gan_loss (dict): Config of the GAN loss. Must not be None.
            cycle_loss (dict): Config of the cycle loss. Must not be None.
            id_loss (dict | None): Config of the identity loss. When it is
                None or empty no identity loss is built.
            train_cfg (dict | None): Optional training options. The keys
                read here are ``buffer_size``, ``disc_steps``,
                ``disc_init_steps`` and ``direction``.
            test_cfg (dict | None): Optional testing options. The keys read
                here are ``direction`` (only when ``train_cfg`` is None),
                ``show_input`` and ``test_direction``.
            pretrained: Forwarded to ``self.init_weights``.

        Raises:
            AssertionError: If the identity loss is active while the
                generator's ``in_channels`` and ``out_channels`` differ, or
                if ``gan_loss`` / ``cycle_loss`` is None.
        """
        super().__init__()

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # Identity loss only works when input and output images have the
        # same number of channels. The check is skipped when no identity
        # loss will be built (None or empty config) or when its weight is
        # not positive. A config without an explicit `loss_weight` is
        # treated as active; comparing the missing value (None) with 0.0
        # would otherwise raise a TypeError.
        if id_loss:
            id_loss_weight = id_loss.get('loss_weight')
            if id_loss_weight is None or id_loss_weight > 0.0:
                assert generator.get('in_channels') == generator.get(
                    'out_channels')

        # generators
        self.generators = nn.ModuleDict()
        self.generators['a'] = build_module(generator)
        self.generators['b'] = build_module(generator)

        # discriminators
        self.discriminators = nn.ModuleDict()
        self.discriminators['a'] = build_module(discriminator)
        self.discriminators['b'] = build_module(discriminator)

        # GAN image buffers
        self.image_buffers = dict()
        self.buffer_size = (50 if self.train_cfg is None else
                            self.train_cfg.get('buffer_size', 50))
        self.image_buffers['a'] = GANImageBuffer(self.buffer_size)
        self.image_buffers['b'] = GANImageBuffer(self.buffer_size)

        # losses
        assert gan_loss is not None  # gan loss cannot be None
        self.gan_loss = build_module(gan_loss)
        assert cycle_loss is not None  # cycle loss cannot be None
        self.cycle_loss = build_module(cycle_loss)
        self.id_loss = build_module(id_loss) if id_loss else None

        # others
        self.disc_steps = 1 if self.train_cfg is None else self.train_cfg.get(
            'disc_steps', 1)
        self.disc_init_steps = (0 if self.train_cfg is None else
                                self.train_cfg.get('disc_init_steps', 0))
        # `direction` comes from train_cfg when present, otherwise from
        # test_cfg, and falls back to 'a2b'.
        if self.train_cfg is None:
            self.direction = ('a2b' if self.test_cfg is None else
                              self.test_cfg.get('direction', 'a2b'))
        else:
            self.direction = self.train_cfg.get('direction', 'a2b')
        self.step_counter = 0  # counting training steps
        self.show_input = (False if self.test_cfg is None else
                           self.test_cfg.get('show_input', False))
        # In CycleGAN, if not showing input, we can decide the translation
        # direction in the test mode, i.e., whether to output fake_b or fake_a
        # NOTE: `test_direction` is only defined when `show_input` is falsy.
        if not self.show_input:
            self.test_direction = ('a2b' if self.test_cfg is None else
                                   self.test_cfg.get('test_direction', 'a2b'))
            # a 'b2a' training direction flips the requested test direction
            if self.direction == 'b2a':
                self.test_direction = ('b2a' if self.test_direction == 'a2b'
                                       else 'a2b')

        self.init_weights(pretrained)

        self.use_ema = False
コード例 #15
0
ファイル: denoising.py プロジェクト: nbei/mmgeneration
    def __init__(self,
                 image_size,
                 in_channels=3,
                 base_channels=128,
                 resblocks_per_downsample=3,
                 num_timesteps=1000,
                 use_rescale_timesteps=True,
                 dropout=0,
                 embedding_channels=-1,
                 num_classes=0,
                 channels_cfg=None,
                 output_cfg=dict(mean='eps', var='learned_range'),
                 norm_cfg=dict(type='GN', num_groups=32),
                 act_cfg=dict(type='SiLU', inplace=False),
                 shortcut_kernel_size=1,
                 use_scale_shift_norm=False,
                 num_heads=4,
                 time_embedding_mode='sin',
                 time_embedding_cfg=None,
                 resblock_cfg=dict(type='DenoisingResBlock'),
                 attention_cfg=dict(type='MultiHeadAttention'),
                 downsample_conv=True,
                 upsample_conv=True,
                 downsample_cfg=dict(type='DenoisingDownsample'),
                 upsample_cfg=dict(type='DenoisingUpsample'),
                 attention_res=[16, 8],
                 pretrained=None):
        """Build the encoder, bottleneck, decoder and output layer.

        Args:
            image_size (int | list[int]): Image size. A list must hold two
                equal values and is reduced to its first element.
            in_channels (int): Input channels of the first conv layer. The
                output layer emits ``in_channels`` channels when the ``var``
                mode is None and ``2 * in_channels`` otherwise.
            base_channels (int): Channel unit multiplied by the per-level
                factors.
            resblocks_per_downsample (int): Residual blocks per encoder
                level; each decoder level builds one more.
            num_timesteps (int): Stored on the instance.
            use_rescale_timesteps (bool): Stored on the instance.
            dropout (float): Default ``dropout`` entry of ``resblock_cfg``.
            embedding_channels (int): Width of the time (and label)
                embedding; ``-1`` selects ``base_channels * 4``.
            num_classes (int): When non-zero, an ``nn.Embedding`` for the
                labels is created.
            channels_cfg (dict | list | None): Channel factors per level,
                either as a list or as a dict keyed by image size. None
                selects ``self._default_channels_cfg``.
            output_cfg (dict): Read for the keys ``mean`` and ``var``.
            norm_cfg (dict): Default norm config of the residual and
                attention blocks; also used by the output layer.
            act_cfg (dict): Default activation config of the residual
                blocks; also used by the time embedding and output layer.
            shortcut_kernel_size (int): Default ``shortcut_kernel_size``
                entry of ``resblock_cfg``.
            use_scale_shift_norm (bool): Default ``use_scale_shift_norm``
                entry of ``resblock_cfg``.
            num_heads (int): Default ``num_heads`` entry of
                ``attention_cfg``.
            time_embedding_mode (str): Forwarded to ``TimeEmbedding``.
            time_embedding_cfg (dict | None): Forwarded to
                ``TimeEmbedding``.
            resblock_cfg (dict): Base config of the residual blocks.
            attention_cfg (dict): Base config of the attention blocks.
            downsample_conv (bool): Default ``with_conv`` entry of
                ``downsample_cfg``.
            upsample_conv (bool): Default ``with_conv`` entry of
                ``upsample_cfg``.
            downsample_cfg (dict): Base config of the downsample blocks.
            upsample_cfg (dict): Base config of the upsample blocks.
            attention_res (list): Resolutions at which attention is added;
                converted to downsample factors via ``image_size // res``.
            pretrained: Forwarded to ``self.init_weights``.

        Raises:
            TypeError: If ``image_size`` is neither an int nor a list.
            AssertionError: If a list ``image_size`` does not hold exactly
                two equal values.
            KeyError: If ``channels_cfg`` is a dict without ``image_size``.
            ValueError: If ``channels_cfg`` is neither a dict nor a list.
        """

        super().__init__()

        self.num_classes = num_classes
        self.num_timesteps = num_timesteps
        self.use_rescale_timesteps = use_rescale_timesteps

        # copy so that the (shared, mutable) default dict is never aliased
        self.output_cfg = deepcopy(output_cfg)
        self.mean_mode = self.output_cfg.get('mean', 'eps')
        self.var_mode = self.output_cfg.get('var', 'learned_range')

        # double output_channels to output mean and var at same time
        out_channels = in_channels if self.var_mode is None \
            else 2 * in_channels
        self.out_channels = out_channels

        # check type of image_size
        if not isinstance(image_size, int) and not isinstance(
                image_size, list):
            raise TypeError(
                'Only support `int` and `list[int]` for `image_size`.')
        if isinstance(image_size, list):
            assert len(
                image_size) == 2, 'The length of `image_size` should be 2.'
            assert image_size[0] == image_size[
                1], 'Width and height of the image should be same.'
            # only square images are supported, so keep a single int
            image_size = image_size[0]
        self.image_size = image_size

        # resolve the per-level channel factors (dict lookup or plain list)
        channels_cfg = deepcopy(self._default_channels_cfg) \
            if channels_cfg is None else deepcopy(channels_cfg)
        if isinstance(channels_cfg, dict):
            if image_size not in channels_cfg:
                raise KeyError(f'`image_size={image_size} is not found in '
                               '`channels_cfg`, only support configs for '
                               f'{[chn for chn in channels_cfg.keys()]}')
            self.channel_factor_list = channels_cfg[image_size]
        elif isinstance(channels_cfg, list):
            self.channel_factor_list = channels_cfg
        else:
            raise ValueError('Only support list or dict for `channels_cfg`, '
                             f'receive {type(channels_cfg)}')

        # `-1` means "derive from base_channels"
        embedding_channels = base_channels * 4 \
            if embedding_channels == -1 else embedding_channels
        self.time_embedding = TimeEmbedding(
            base_channels,
            embedding_channels=embedding_channels,
            embedding_mode=time_embedding_mode,
            embedding_cfg=time_embedding_cfg,
            act_cfg=act_cfg)

        # label embedding exists only for class-conditional models
        if self.num_classes != 0:
            self.label_embedding = nn.Embedding(self.num_classes,
                                                embedding_channels)

        # `setdefault` keeps any value the caller already put in the block
        # configs; the keyword arguments only fill the gaps
        self.resblock_cfg = deepcopy(resblock_cfg)
        self.resblock_cfg.setdefault('dropout', dropout)
        self.resblock_cfg.setdefault('norm_cfg', norm_cfg)
        self.resblock_cfg.setdefault('act_cfg', act_cfg)
        self.resblock_cfg.setdefault('embedding_channels', embedding_channels)
        self.resblock_cfg.setdefault('use_scale_shift_norm',
                                     use_scale_shift_norm)
        self.resblock_cfg.setdefault('shortcut_kernel_size',
                                     shortcut_kernel_size)

        # get scales of ResBlock to apply attention
        attention_scale = [image_size // int(res) for res in attention_res]
        self.attention_cfg = deepcopy(attention_cfg)
        self.attention_cfg.setdefault('num_heads', num_heads)
        self.attention_cfg.setdefault('norm_cfg', norm_cfg)

        self.downsample_cfg = deepcopy(downsample_cfg)
        self.downsample_cfg.setdefault('with_conv', downsample_conv)
        self.upsample_cfg = deepcopy(upsample_cfg)
        self.upsample_cfg.setdefault('with_conv', upsample_conv)

        # running downsample factor: doubled after every downsample block,
        # halved after every upsample block, and compared against
        # `attention_scale` to decide where attention is inserted
        scale = 1
        self.in_blocks = nn.ModuleList([
            EmbedSequential(
                nn.Conv2d(in_channels, base_channels, 3, 1, padding=1))
        ])
        # output channels of every encoder block, in build order; the
        # decoder pops from a copy of this list to size its inputs
        self.in_channels_list = [base_channels]

        # construct the encoder part of Unet
        for level, factor in enumerate(self.channel_factor_list):
            in_channels_ = base_channels if level == 0 \
                else base_channels * self.channel_factor_list[level - 1]
            out_channels_ = base_channels * factor

            for _ in range(resblocks_per_downsample):
                layers = [
                    build_module(self.resblock_cfg, {
                        'in_channels': in_channels_,
                        'out_channels': out_channels_
                    })
                ]
                # later blocks of this level take the level's output width
                in_channels_ = out_channels_

                if scale in attention_scale:
                    layers.append(
                        build_module(self.attention_cfg,
                                     {'in_channels': in_channels_}))

                self.in_channels_list.append(in_channels_)
                self.in_blocks.append(EmbedSequential(*layers))

            # every level except the last ends with a downsample block
            if level != len(self.channel_factor_list) - 1:
                self.in_blocks.append(
                    EmbedSequential(
                        build_module(self.downsample_cfg,
                                     {'in_channels': in_channels_})))
                self.in_channels_list.append(in_channels_)
                scale *= 2

        # construct the bottom part of Unet
        # NOTE(review): no 'out_channels' is passed to these residual
        # blocks; presumably the block defaults to keeping the input width
        # -- confirm against the residual block implementation.
        self.mid_blocks = EmbedSequential(
            build_module(self.resblock_cfg, {'in_channels': in_channels_}),
            build_module(self.attention_cfg, {'in_channels': in_channels_}),
            build_module(self.resblock_cfg, {'in_channels': in_channels_}),
        )

        # construct the decoder part of Unet
        # work on a copy so that `self.in_channels_list` stays intact
        in_channels_list = deepcopy(self.in_channels_list)
        self.out_blocks = nn.ModuleList()
        for level, factor in enumerate(self.channel_factor_list[::-1]):
            for idx in range(resblocks_per_downsample + 1):
                # input width = current width + width of the matching
                # encoder block (presumably a concatenated skip connection
                # in `forward` -- not visible here)
                layers = [
                    build_module(
                        self.resblock_cfg, {
                            'in_channels':
                            in_channels_ + in_channels_list.pop(),
                            'out_channels': base_channels * factor
                        })
                ]
                in_channels_ = base_channels * factor
                if scale in attention_scale:
                    layers.append(
                        build_module(self.attention_cfg,
                                     {'in_channels': in_channels_}))
                # the last block of every level except the final one also
                # upsamples
                if (level != len(self.channel_factor_list) - 1
                        and idx == resblocks_per_downsample):
                    layers.append(
                        build_module(self.upsample_cfg,
                                     {'in_channels': in_channels_}))
                    scale //= 2
                self.out_blocks.append(EmbedSequential(*layers))

        # final norm -> act -> conv projection to the output channels
        self.out = ConvModule(in_channels=in_channels_,
                              out_channels=out_channels,
                              kernel_size=3,
                              padding=1,
                              act_cfg=act_cfg,
                              norm_cfg=norm_cfg,
                              bias=True,
                              order=('norm', 'act', 'conv'))

        self.init_weights(pretrained)