def __init__(self,
             out_size,
             style_channels,
             img_channels,
             noise_size=512,
             rgb2bgr=False,
             pretrained=None,
             synthesis_cfg=dict(type='SynthesisNetwork'),
             mapping_cfg=dict(type='MappingNetwork')):
    """Assemble the generator from a synthesis and a mapping sub-network.

    Args:
        out_size (int): Output resolution; injected into the synthesis
            config unless that config already sets ``out_size``.
        style_channels (int): Style width; injected into both sub-network
            configs unless already set.
        img_channels (int): Image channels; injected into the synthesis
            config unless already set.
        noise_size (int): Input noise dimension; injected into the mapping
            config unless already set. Defaults to 512.
        rgb2bgr (bool): Stored as ``self.rgb2bgr``. Defaults to False.
        pretrained (dict, optional): When given, unpacked as keyword
            arguments into ``self._load_pretrained_model``.
        synthesis_cfg (dict): Config of the synthesis network.
        mapping_cfg (dict): Config of the mapping network.
    """
    super().__init__()
    self.noise_size = noise_size
    self.style_channels = style_channels
    self.out_size = out_size
    self.img_channels = img_channels
    self.rgb2bgr = rgb2bgr

    # Deep-copy first so the shared default dicts are never mutated by
    # the ``setdefault`` calls below.
    self._synthesis_cfg = deepcopy(synthesis_cfg)
    synthesis_defaults = dict(style_channels=style_channels,
                              out_size=out_size,
                              img_channels=img_channels)
    for key, value in synthesis_defaults.items():
        self._synthesis_cfg.setdefault(key, value)
    self.synthesis = build_module(self._synthesis_cfg)

    # The synthesis network decides how many w codes it consumes; the
    # mapping network is configured with the same count.
    self.num_ws = self.synthesis.num_ws

    self._mapping_cfg = deepcopy(mapping_cfg)
    mapping_defaults = dict(noise_size=noise_size,
                            style_channels=style_channels,
                            num_ws=self.num_ws)
    for key, value in mapping_defaults.items():
        self._mapping_cfg.setdefault(key, value)
    self.style_mapping = build_module(self._mapping_cfg)

    if pretrained is not None:
        self._load_pretrained_model(**pretrained)
def __init__(self,
             generator,
             discriminator,
             gan_loss,
             disc_auxiliary_loss,
             gen_auxiliary_loss=None,
             train_cfg=None,
             test_cfg=None):
    """Build the networks, losses and scale-tracking training status.

    Args:
        generator (dict): Generator config; a deep copy is kept in
            ``self._gen_cfg``.
        discriminator (dict | None): Discriminator config. ``None`` is
            allowed (e.g. for testing) and leaves ``self.discriminator``
            as ``None``.
        gan_loss (dict | None): GAN loss config; ``None`` is allowed.
        disc_auxiliary_loss (dict | None): Config of the discriminator
            auxiliary loss(es). Falsy means none. The result is always an
            ``nn.ModuleList`` or ``None``.
        gen_auxiliary_loss (dict | None): Same as above for the generator.
        train_cfg (dict, optional): Deep-copied, then parsed by
            ``self._parse_train_cfg()``.
        test_cfg (dict, optional): Deep-copied. Parsed by
            ``self._parse_test_cfg()`` only when not None.
    """
    super().__init__()

    def _build_loss_list(loss_cfg):
        # Build optional auxiliary losses; a single module is wrapped so
        # the attribute is always iterable as an ``nn.ModuleList``.
        if not loss_cfg:
            return None
        built = build_module(loss_cfg)
        if isinstance(built, nn.ModuleList):
            return built
        return nn.ModuleList([built])

    self._gen_cfg = deepcopy(generator)
    self.generator = build_module(generator)

    # Discriminator and GAN loss may be omitted in testing.
    self.discriminator = (None if discriminator is None else
                          build_module(discriminator))
    self.gan_loss = None if gan_loss is None else build_module(gan_loss)

    self.disc_auxiliary_losses = _build_loss_list(disc_auxiliary_loss)
    self.gen_auxiliary_losses = _build_loss_list(gen_auxiliary_loss)

    # Training status kept as buffers so it is saved with the model.
    for buffer_name, init_value in (('shown_nkimg', 0.),
                                    ('_curr_transition_weight', 1.)):
        self.register_buffer(buffer_name, torch.tensor(init_value))

    self.train_cfg = deepcopy(train_cfg) if train_cfg else None
    self.test_cfg = deepcopy(test_cfg) if test_cfg else None

    # ``self.scales`` is read right below, so it is expected to be set
    # by ``_parse_train_cfg``.
    self._parse_train_cfg()

    first_scale = self.scales[0][0]
    last_scale = self.scales[-1][0]
    # this buffer is used to resume model easily
    self.register_buffer('_next_scale_int',
                         torch.tensor(first_scale, dtype=torch.int32))
    # TODO: init it with the same value as `_next_scale_int`
    # a dirty workaround for testing
    self.register_buffer('_curr_scale_int',
                         torch.tensor(last_scale, dtype=torch.int32))

    if test_cfg is not None:
        self._parse_test_cfg()
def __init__(self,
             generator,
             discriminator,
             gan_loss,
             disc_auxiliary_loss,
             gen_auxiliary_loss=None,
             train_cfg=None,
             test_cfg=None):
    """Build the networks, losses and stage-wise training status.

    Args:
        generator (dict): Generator config; a deep copy is kept in
            ``self._gen_cfg``.
        discriminator (dict | None): Discriminator config; ``None`` is
            allowed (e.g. for testing).
        gan_loss (dict | None): GAN loss config; ``None`` is allowed.
        disc_auxiliary_loss (dict | None): Config of the discriminator
            auxiliary loss(es). Falsy means none. The result is always an
            ``nn.ModuleList`` or ``None``.
        gen_auxiliary_loss (dict | None): Same as above for the generator.
        train_cfg (dict, optional): Deep-copied, then parsed by
            ``self._parse_train_cfg()``.
        test_cfg (dict, optional): Deep-copied. Parsed by
            ``self._parse_test_cfg()`` only when not None.
    """
    super().__init__()

    def _build_loss_list(loss_cfg):
        # Build optional auxiliary losses; a single module is wrapped so
        # the attribute is always an ``nn.ModuleList`` (or None).
        if not loss_cfg:
            return None
        built = build_module(loss_cfg)
        if isinstance(built, nn.ModuleList):
            return built
        return nn.ModuleList([built])

    self._gen_cfg = deepcopy(generator)
    self.generator = build_module(generator)

    # Discriminator and GAN loss may be omitted in testing.
    self.discriminator = (None if discriminator is None else
                          build_module(discriminator))
    self.gan_loss = None if gan_loss is None else build_module(gan_loss)

    self.disc_auxiliary_losses = _build_loss_list(disc_auxiliary_loss)
    self.gen_auxiliary_losses = _build_loss_list(gen_auxiliary_loss)

    # Training status: current stage index (-1 before training starts)
    # and per-stage bookkeeping lists.
    self.curr_stage = -1
    self.noise_weights = [1]
    self.fixed_noises = []
    self.reals = []

    self.train_cfg = deepcopy(train_cfg) if train_cfg else None
    self.test_cfg = deepcopy(test_cfg) if test_cfg else None

    self._parse_train_cfg()
    if test_cfg is not None:
        self._parse_test_cfg()
def test_mse_loss(self):
    """Check the MSE loss: forward value, invalid reduction and loss name."""
    # test forward with per-timestep weights
    config = deepcopy(self.config)
    config['rescale_mode'] = 'timestep_weight'
    loss_fn = build_module(config, default_args=dict(weight=self.weight))
    loss = loss_fn(self.output_dict)
    # `np.allclose` only returns a bool; without `assert` the comparison
    # was silently discarded and the forward check verified nothing.
    assert np.allclose(loss, self.loss_manually)

    # test reduction raise error: an unknown reduction must be rejected
    config = deepcopy(self.config)
    config['reduction'] = 'reduction'
    with pytest.raises(ValueError):
        build_module(config)

    # test return loss name
    config = deepcopy(self.config)
    config['loss_name'] = 'loss_name'
    loss_fn = build_module(config)
    assert loss_fn.loss_name() == 'loss_name'
def __init__(self,
             input_scale,
             num_classes=0,
             in_channels=3,
             out_channels=1,
             base_channels=96,
             sn_eps=1e-6,
             init_type='ortho',
             act_cfg=dict(type='ReLU'),
             with_spectral_norm=True,
             blocks_cfg=dict(type='BigGANDiscResBlock'),
             arch_cfg=None,
             pretrained=None):
    """Build the residual discriminator trunk and its decision heads.

    Args:
        input_scale (int): Input resolution; used to pick the default
            architecture when ``arch_cfg`` is not given.
        num_classes (int): When > 0, a class embedding ``proj_y`` is
            built. Defaults to 0.
        in_channels (int): Image channels. Defaults to 3.
        out_channels (int): Output features of the ``decision`` linear
            layer. Defaults to 1.
        base_channels (int): Base width for the default architecture.
        sn_eps (float): Epsilon for every spectral norm in this module.
        init_type (str): Forwarded to ``init_weights``.
        act_cfg (dict): Activation config for the blocks and the final
            activation.
        with_spectral_norm (bool): Whether spectral norm is applied.
        blocks_cfg (dict): Base config of each residual block.
        arch_cfg (dict, optional): Architecture with per-block lists
            ``in_channels``, ``out_channels``, ``downsample`` and
            ``attention``.
        pretrained (str, optional): Forwarded to ``init_weights``.
    """
    super().__init__()
    self.num_classes = num_classes
    self.out_channels = out_channels
    self.input_scale = input_scale
    self.in_channels = in_channels
    self.base_channels = base_channels
    if arch_cfg:
        self.arch = arch_cfg
    else:
        self.arch = self._get_default_arch_cfg(self.input_scale,
                                               self.in_channels,
                                               self.base_channels)
    arch = self.arch

    # Options shared by every block. Per-block options are written into
    # the same dict inside the loop, so it ends up holding the last
    # block's values.
    self.blocks_cfg = deepcopy(blocks_cfg)
    self.blocks_cfg.update(act_cfg=act_cfg,
                           sn_eps=sn_eps,
                           with_spectral_norm=with_spectral_norm)

    self.conv_blocks = nn.ModuleList()
    for block_idx, block_out_ch in enumerate(arch['out_channels']):
        self.blocks_cfg.update(in_channels=arch['in_channels'][block_idx],
                               out_channels=block_out_ch,
                               with_downsample=arch['downsample'][block_idx],
                               is_head_block=block_idx == 0)
        self.conv_blocks.append(build_module(self.blocks_cfg))
        if arch['attention'][block_idx]:
            self.conv_blocks.append(
                SelfAttentionBlock(block_out_ch,
                                   with_spectral_norm=with_spectral_norm,
                                   sn_eps=sn_eps))

    self.activate = build_activation_layer(act_cfg)

    trunk_channels = arch['out_channels'][-1]
    decision = nn.Linear(trunk_channels, out_channels)
    self.decision = (spectral_norm(decision, eps=sn_eps)
                     if with_spectral_norm else decision)

    if self.num_classes > 0:
        proj_y = nn.Embedding(self.num_classes, trunk_channels)
        self.proj_y = (spectral_norm(proj_y, eps=sn_eps)
                       if with_spectral_norm else proj_y)

    self.init_weights(pretrained=pretrained, init_type=init_type)
def __init__(self,
             in_channels,
             embedding_channels,
             use_scale_shift_norm,
             dropout,
             out_channels=None,
             norm_cfg=dict(type='GN', num_groups=32),
             act_cfg=dict(type='SiLU', inplace=False),
             shortcut_kernel_size=1):
    """Residual block with an embedding-conditioned normalization layer.

    Args:
        in_channels (int): Channels of the block input.
        embedding_channels (int): Channels of the conditioning embedding;
            forwarded to ``NormWithEmbedding``.
        use_scale_shift_norm (bool): Forwarded to ``NormWithEmbedding`` as
            ``use_scale_shift``.
        dropout (float): Dropout probability used in ``conv_2``.
        out_channels (int, optional): Output channels; ``in_channels`` when
            None.
        norm_cfg (dict): Normalization config (deep-copied before use).
        act_cfg (dict): Activation config for both conv branches.
        shortcut_kernel_size (int): Kernel size of the shortcut conv, which
            is built only when input and output channels differ. Must be 1
            or 3.
    """
    super().__init__()
    if out_channels is None:
        out_channels = in_channels
    block_norm_cfg = deepcopy(norm_cfg)

    # first branch: norm -> activation -> 3x3 conv (may change channels)
    _, first_norm = build_norm_layer(block_norm_cfg, in_channels)
    self.conv_1 = nn.Sequential(
        first_norm,
        build_activation_layer(act_cfg),
        nn.Conv2d(in_channels, out_channels, 3, padding=1))

    # second normalization, conditioned on the embedding
    self.norm_with_embedding = build_module(
        dict(type='NormWithEmbedding'),
        default_args=dict(in_channels=out_channels,
                          embedding_channels=embedding_channels,
                          use_scale_shift=use_scale_shift_norm,
                          norm_cfg=block_norm_cfg))

    # second branch: activation -> dropout -> 3x3 conv
    self.conv_2 = nn.Sequential(
        build_activation_layer(act_cfg),
        nn.Dropout(dropout),
        nn.Conv2d(out_channels, out_channels, 3, padding=1))

    assert shortcut_kernel_size in [
        1, 3
    ], ('Only support `1` and `3` for `shortcut_kernel_size`, but '
        f'receive {shortcut_kernel_size}.')

    # A projection shortcut is only needed when the channel count changes.
    self.learnable_shortcut = in_channels != out_channels
    if self.learnable_shortcut:
        # 'same' padding for the two supported kernel sizes
        pad = 1 if shortcut_kernel_size == 3 else 0
        self.shortcut = nn.Conv2d(in_channels,
                                  out_channels,
                                  shortcut_kernel_size,
                                  padding=pad)

    self.init_weights()
def __init__(self, in_size, *args, data_aug=None, **kwargs):
    """StyleGANv2 Discriminator with adaptive augmentation.

    Args:
        in_size (int): The input size of images.
        data_aug (dict, optional): Config for data augmentation. Defaults
            to None, in which case no augmentation module is built.

    Remaining positional and keyword arguments are forwarded unchanged to
    the parent discriminator.
    """
    super().__init__(in_size, *args, **kwargs)
    self.with_ada = data_aug is not None
    if self.with_ada:
        self.ada_aug = build_module(data_aug)
        # `nn.Module` has no `requires_grad` field: assigning
        # `module.requires_grad = False` only creates a plain attribute and
        # leaves every parameter trainable. `requires_grad_(False)` actually
        # freezes all parameters of the augmentation module.
        self.ada_aug.requires_grad_(False)
    self.log_size = int(np.log2(in_size))
def test_vlb_loss(self):
    """Check the VLB loss: forward value, log configs, rescale modes and
    the logged terms when timesteps are altered."""
    # NOTE: every `np.allclose` below is asserted; previously the results
    # were discarded, so none of these value checks had any effect.

    # test forward
    config = deepcopy(self.config)
    loss_fn = build_module(config)
    loss = loss_fn(self.output_dict)
    assert np.allclose(loss, self.loss_manually * 4)

    # test log_cfgs --> dict input
    config = deepcopy(self.config)
    config['log_cfgs'] = dict(type='name')
    loss_fn = build_module(config)
    assert isinstance(loss_fn.log_fn_list, list)

    # test log_cfgs --> no log_cfgs
    config = deepcopy(self.config)
    config['log_cfgs'] = None
    loss_fn = build_module(config)
    loss_fn(self.output_dict)
    assert not loss_fn.log_vars

    # test rescale_cfg --> rescale is None
    config = deepcopy(self.config)
    config['rescale_mode'] = None
    loss_fn = build_module(config)
    loss = loss_fn(self.output_dict)
    assert np.allclose(loss, self.loss_manually)

    # TODO: test rescale_cfg --> test sampler

    # test rescale_cfg --> test weight
    config = deepcopy(self.config)
    config['rescale_mode'] = 'timestep_weight'
    weight = self.weight.clone()
    loss_fn = build_module(config, default_args=dict(weight=weight))

    def _weighted_manually():
        # Re-read `weight` on every call so in-place edits are reflected.
        return (-(self.loss_disc_likelihood * weight)[0] +
                (self.loss_gaussian_kld * weight)[1:].sum()) / 4

    loss = loss_fn(self.output_dict)
    assert np.allclose(loss, _weighted_manually())

    # test rescale_cfg --> change weight
    # `weight` is edited in place; the loss is expected to observe the
    # change (presumably it keeps a reference to the tensor — TODO confirm).
    weight[0] += 1
    loss = loss_fn(self.output_dict)
    assert np.allclose(loss, _weighted_manually())

    # test the t == 0 terms: timesteps[0] is overwritten with 1
    # (presumably removing the only t == 0 sample — TODO confirm against
    # the fixture), after which both logged terms are expected to be 0.
    config = deepcopy(self.config)
    output_dict = deepcopy(self.output_dict)
    output_dict['timesteps'][0] = 1
    loss_fn = build_module(config)
    loss_fn(output_dict)
    assert loss_fn.log_vars['loss_vlb_quartile_0'] == 0
    assert loss_fn.log_vars['loss_DiscGaussianLogLikelihood'] == 0
def __init__(self,
             in_channels,
             out_channels,
             num_scales,
             kernel_size=3,
             padding=0,
             num_layers=5,
             base_channels=32,
             min_feat_channels=32,
             out_act_cfg=dict(type='Tanh'),
             padding_mode='zero',
             pad_at_head=True,
             interp_pad=False,
             noise_with_pad=False,
             positional_encoding=None,
             first_stage_in_channels=None,
             **kwargs):
    """Build one ``GeneratorBlock`` per scale plus the padding layers.

    Args:
        in_channels (int): Input channels of every stage except possibly
            the first one.
        out_channels (int): Output channels of every stage.
        num_scales (int): Index of the last scale; ``num_scales + 1``
            blocks are built.
        kernel_size (int): Conv kernel size; with ``num_layers`` it sets
            the head padding size ``self.pad_head``.
        padding (int): Forwarded to each ``GeneratorBlock``.
        num_layers (int): Layers per block.
        base_channels (int): Block width at scale 0. It doubles every 4
            scales and is capped at 128.
        min_feat_channels (int): Same growth rule as ``base_channels``.
        out_act_cfg (dict): Output activation config for each block.
        padding_mode (str): ``'zero'`` or ``'reflect'``; anything else
            raises ``NotImplementedError``.
        pad_at_head, interp_pad, noise_with_pad (bool): Stored as-is.
        positional_encoding (dict, optional): When given, built into
            ``self.head_position_encode``.
        first_stage_in_channels (int, optional): Overrides ``in_channels``
            for scale 0 when truthy.
        **kwargs: Forwarded to every ``GeneratorBlock``.
    """
    super(SinGANMultiScaleGenerator, self).__init__()
    self.pad_at_head = pad_at_head
    self.interp_pad = interp_pad
    self.noise_with_pad = noise_with_pad

    self.with_positional_encode = positional_encoding is not None
    if self.with_positional_encode:
        self.head_position_encode = build_module(positional_encoding)

    # total spatial shrinkage of a block's unpadded convs, per side
    self.pad_head = int((kernel_size - 1) / 2 * num_layers)
    self.blocks = nn.ModuleList()
    self.upsample = partial(F.interpolate,
                            mode='bicubic',
                            align_corners=True)

    for scale_idx in range(num_scales + 1):
        growth = 2**(scale_idx // 4)
        stage_in_ch = in_channels
        if scale_idx == 0 and first_stage_in_channels:
            stage_in_ch = first_stage_in_channels
        self.blocks.append(
            GeneratorBlock(
                in_channels=stage_in_ch,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=padding,
                num_layers=num_layers,
                base_channels=min(base_channels * growth, 128),
                min_feat_channels=min(min_feat_channels * growth, 128),
                out_act_cfg=out_act_cfg,
                padding_mode=padding_mode,
                **kwargs))

    if padding_mode == 'zero':
        pad_layer_cls = nn.ZeroPad2d
    elif padding_mode == 'reflect':
        pad_layer_cls = nn.ReflectionPad2d
    else:
        raise NotImplementedError(
            f'Padding mode {padding_mode} is not supported')

    self.noise_padding_layer = pad_layer_cls(self.pad_head)
    self.img_padding_layer = pad_layer_cls(self.pad_head)
    # the mask is reflection-padded regardless of `padding_mode`
    self.mask_padding_layer = nn.ReflectionPad2d(self.pad_head)
    if padding_mode == 'reflect':
        mmcv.print_log('Using Reflection padding', 'mmgen')
def __init__(self,
             out_size,
             style_channels,
             num_mlps=8,
             channel_multiplier=2,
             blur_kernel=[1, 3, 3, 1],
             lr_mlp=0.01,
             default_style_mode='mix',
             eval_style_mode='single',
             mix_prob=0.9,
             no_pad=False,
             deconv2conv=False,
             interp_pad=None,
             up_config=dict(scale_factor=2, mode='nearest'),
             up_after_conv=False,
             head_pos_encoding=None,
             head_pos_size=(4, 4),
             interp_head=False):
    """Build the style mapping MLP, the 4x4 head and the upsampling trunk.

    Args:
        out_size (int): Output resolution; ``log2(out_size)`` sets the
            number of upsampling stages.
        style_channels (int): Width of the style vectors and mapping MLP.
        num_mlps (int): Number of ``EqualLinearActModule`` layers in the
            mapping network.
        channel_multiplier (int): Multiplies the channel counts for
            resolutions >= 64 (see ``self.channels``).
        blur_kernel (list[int]): Forwarded to the style convs.
        lr_mlp (float): ``lr_mul`` of the equalized-lr mapping layers.
        default_style_mode, eval_style_mode (str): Stored as-is.
        mix_prob (float): Stored as-is.
        no_pad (bool): Forwarded to the style convs. When the constant input
            is used, its size grows from 4 to 6.
        deconv2conv (bool): Forwarded to the style convs.
        interp_pad (optional): Forwarded to the style convs;
            ``self.with_interp_pad`` records whether it is set.
        up_config (dict): Upsampling config (deep-copied) for the style
            convs.
        up_after_conv (bool): Forwarded to the trunk style convs.
        head_pos_encoding (dict, optional): When given, a positional
            encoding replaces the constant input. For ``'CatersianGrid'``,
            ``'CSG'`` and ``'CSG2d'`` the first conv takes 2 input channels.
        head_pos_size (tuple[int]): Stored as-is.
        interp_head (bool): Stored as-is.
    """
    super().__init__()
    self.out_size = out_size
    self.style_channels = style_channels
    self.num_mlps = num_mlps
    self.channel_multiplier = channel_multiplier
    self.lr_mlp = lr_mlp
    # keep the construction-time mode separately from the mutable one
    self._default_style_mode = default_style_mode
    self.default_style_mode = default_style_mode
    self.eval_style_mode = eval_style_mode
    self.mix_prob = mix_prob
    self.no_pad = no_pad
    self.deconv2conv = deconv2conv
    self.interp_pad = interp_pad
    self.with_interp_pad = interp_pad is not None
    self.up_config = deepcopy(up_config)
    self.up_after_conv = up_after_conv
    self.head_pos_encoding = head_pos_encoding
    self.head_pos_size = head_pos_size
    self.interp_head = interp_head

    # define style mapping layers
    mapping_layers = [PixelNorm()]

    for _ in range(num_mlps):
        mapping_layers.append(
            EqualLinearActModule(style_channels,
                                 style_channels,
                                 equalized_lr_cfg=dict(lr_mul=lr_mlp,
                                                       gain=1.),
                                 act_cfg=dict(type='fused_bias')))

    self.style_mapping = nn.Sequential(*mapping_layers)

    # feature channels per resolution
    self.channels = {
        4: 512,
        8: 512,
        16: 512,
        32: 512,
        64: 256 * channel_multiplier,
        128: 128 * channel_multiplier,
        256: 64 * channel_multiplier,
        512: 32 * channel_multiplier,
        1024: 16 * channel_multiplier,
    }

    in_ch = self.channels[4]
    # constant input layer
    if self.head_pos_encoding:
        # grid-type encodings feed 2 channels into the first conv
        if self.head_pos_encoding['type'] in [
                'CatersianGrid', 'CSG', 'CSG2d'
        ]:
            in_ch = 2
        self.head_pos_enc = build_module(self.head_pos_encoding)
    else:
        size_ = 4
        if self.no_pad:
            size_ += 2
        self.constant_input = ConstantInput(self.channels[4], size=size_)

    # 4x4 stage
    self.conv1 = ModulatedPEStyleConv(in_ch,
                                      self.channels[4],
                                      kernel_size=3,
                                      style_channels=style_channels,
                                      blur_kernel=blur_kernel,
                                      deconv2conv=self.deconv2conv,
                                      no_pad=self.no_pad,
                                      up_config=self.up_config,
                                      interp_pad=self.interp_pad)
    self.to_rgb1 = ModulatedToRGB(self.channels[4],
                                  style_channels,
                                  upsample=False)

    # generator backbone (8x8 --> higher resolutions)
    self.log_size = int(np.log2(self.out_size))

    self.convs = nn.ModuleList()
    # NOTE(review): `upsamples` is never filled in this constructor —
    # confirm whether it is used elsewhere before removing it.
    self.upsamples = nn.ModuleList()
    self.to_rgbs = nn.ModuleList()

    in_channels_ = self.channels[4]

    # each stage: an upsampling style conv, a same-resolution style conv
    # and an upsampling to-RGB layer
    for i in range(3, self.log_size + 1):
        out_channels_ = self.channels[2**i]

        self.convs.append(
            ModulatedPEStyleConv(in_channels_,
                                 out_channels_,
                                 3,
                                 style_channels,
                                 upsample=True,
                                 blur_kernel=blur_kernel,
                                 deconv2conv=self.deconv2conv,
                                 no_pad=self.no_pad,
                                 up_config=self.up_config,
                                 interp_pad=self.interp_pad,
                                 up_after_conv=self.up_after_conv))
        self.convs.append(
            ModulatedPEStyleConv(out_channels_,
                                 out_channels_,
                                 3,
                                 style_channels,
                                 upsample=False,
                                 blur_kernel=blur_kernel,
                                 deconv2conv=self.deconv2conv,
                                 no_pad=self.no_pad,
                                 up_config=self.up_config,
                                 interp_pad=self.interp_pad,
                                 up_after_conv=self.up_after_conv))
        self.to_rgbs.append(
            ModulatedToRGB(out_channels_, style_channels, upsample=True))

        in_channels_ = out_channels_

    # number of style vectors consumed and of noise injection points
    self.num_latents = self.log_size * 2 - 2
    self.num_injected_noises = self.num_latents - 1

    # register buffer for injected noises
    noises = self.make_injected_noise()
    for layer_idx in range(self.num_injected_noises):
        self.register_buffer(f'injected_noise_{layer_idx}',
                             noises[layer_idx])
def __init__(self,
             input_scale,
             num_classes=0,
             base_channels=128,
             input_channels=3,
             attention_cfg=dict(type='SelfAttentionBlock'),
             attention_after_nth_block=-1,
             channels_cfg=None,
             downsample_cfg=None,
             from_rgb_cfg=dict(type='SNGANDiscHeadResBlock'),
             blocks_cfg=dict(type='SNGANDiscResBlock'),
             act_cfg=dict(type='ReLU'),
             with_spectral_norm=True,
             sn_eps=1e-12,
             init_cfg=dict(type='BigGAN'),
             pretrained=None):
    """Build the from-RGB head, residual trunk and projection heads.

    Args:
        input_scale (int): Input resolution; used as the key into
            dict-form ``channels_cfg`` / ``downsample_cfg``.
        num_classes (int): When > 0, a class embedding ``proj_y`` is built
            for conditional synthesis. Defaults to 0.
        base_channels (int): Base width; per-block widths are
            ``factor * base_channels``.
        input_channels (int): Image channels fed to ``from_rgb``.
        attention_cfg (dict): Config of the inserted attention blocks.
        attention_after_nth_block (int | list[int]): Block indices after
            which attention is inserted. ``from_rgb`` counts as block 1 and
            the i-th (0-based) residual block as block ``i + 2``. The
            default ``-1`` inserts none.
        channels_cfg (dict | list, optional): Channel factor per residual
            block, or a dict mapping ``input_scale`` to such a list.
            Defaults to the class default when None.
        downsample_cfg (dict | list, optional): Downsample flag per
            residual block, with the same structure as ``channels_cfg``.
        from_rgb_cfg (dict): Config of the head block.
        blocks_cfg (dict): Config of each residual block.
        act_cfg (dict): Activation config.
        with_spectral_norm (bool): Whether spectral norm is applied.
        sn_eps (float): Spectral-norm epsilon for the residual blocks and
            the ``decision`` / ``proj_y`` heads.
        init_cfg (dict): Init config; its ``type`` is stored as
            ``self.init_type``.
        pretrained (str, optional): Forwarded to ``init_weights``.

    Raises:
        KeyError: ``input_scale`` is missing from a dict-form config.
        ValueError: A config has an unsupported type, the two configs
            differ in length, or ``attention_after_nth_block`` contains a
            non-int.
    """
    super().__init__()

    self.init_type = init_cfg.get('type', None)

    # add SN options and activation function options to cfg
    # NOTE(review): unlike `blocks_cfg`, `sn_eps` is not forwarded to the
    # head block — confirm whether the head block accepts it.
    self.from_rgb_cfg = deepcopy(from_rgb_cfg)
    self.from_rgb_cfg.setdefault('act_cfg', act_cfg)
    self.from_rgb_cfg.setdefault('with_spectral_norm', with_spectral_norm)
    self.from_rgb_cfg.setdefault('init_cfg', init_cfg)

    # add SN options and activation function options to cfg
    self.blocks_cfg = deepcopy(blocks_cfg)
    self.blocks_cfg.setdefault('act_cfg', act_cfg)
    self.blocks_cfg.setdefault('with_spectral_norm', with_spectral_norm)
    self.blocks_cfg.setdefault('sn_eps', sn_eps)
    self.blocks_cfg.setdefault('init_cfg', init_cfg)

    channels_cfg = deepcopy(self._defualt_channels_cfg) \
        if channels_cfg is None else deepcopy(channels_cfg)
    if isinstance(channels_cfg, dict):
        if input_scale not in channels_cfg:
            raise KeyError(f'`input_scale={input_scale}` is not found in '
                           '`channels_cfg`, only support configs for '
                           f'{list(channels_cfg.keys())}')
        self.channel_factor_list = channels_cfg[input_scale]
    elif isinstance(channels_cfg, list):
        self.channel_factor_list = channels_cfg
    else:
        raise ValueError('Only support list or dict for `channels_cfg`, '
                         f'receive {type(channels_cfg)}')

    downsample_cfg = deepcopy(self._defualt_downsample_cfg) \
        if downsample_cfg is None else deepcopy(downsample_cfg)
    if isinstance(downsample_cfg, dict):
        if input_scale not in downsample_cfg:
            raise KeyError(f'`input_scale={input_scale}` is not found in '
                           '`downsample_cfg`, only support configs for '
                           f'{list(downsample_cfg.keys())}')
        self.downsample_list = downsample_cfg[input_scale]
    elif isinstance(downsample_cfg, list):
        self.downsample_list = downsample_cfg
    else:
        raise ValueError('Only support list or dict for `downsample_cfg`, '
                         f'receive {type(downsample_cfg)}')

    if len(self.downsample_list) != len(self.channel_factor_list):
        raise ValueError('`downsample_cfg` should have same length with '
                         '`channels_cfg`, but receive '
                         f'{len(self.downsample_list)} and '
                         f'{len(self.channel_factor_list)}.')

    # check `attention_after_nth_block`
    if not isinstance(attention_after_nth_block, list):
        attention_after_nth_block = [attention_after_nth_block]
    if not all(isinstance(idx, int) for idx in attention_after_nth_block):
        raise ValueError('`attention_after_nth_block` only support int or '
                         'a list of int. Please check your input type.')

    self.from_rgb = build_module(
        self.from_rgb_cfg,
        dict(in_channels=input_channels, out_channels=base_channels))

    self.conv_blocks = nn.ModuleList()
    # add self-attention block after the first block
    if 1 in attention_after_nth_block:
        attn_cfg_ = deepcopy(attention_cfg)
        attn_cfg_['in_channels'] = base_channels
        self.conv_blocks.append(build_module(attn_cfg_))

    for idx in range(len(self.downsample_list)):
        factor_input = 1 if idx == 0 else self.channel_factor_list[idx - 1]
        factor_output = self.channel_factor_list[idx]

        # get block-specific config
        block_cfg_ = deepcopy(self.blocks_cfg)
        block_cfg_['downsample'] = self.downsample_list[idx]
        block_cfg_['in_channels'] = factor_input * base_channels
        block_cfg_['out_channels'] = factor_output * base_channels
        self.conv_blocks.append(build_module(block_cfg_))

        # build self-attention block
        # the first ConvBlock is `from_rgb` block,
        # add 2 to get the index of the ConvBlocks
        if idx + 2 in attention_after_nth_block:
            attn_cfg_ = deepcopy(attention_cfg)
            attn_cfg_['in_channels'] = factor_output * base_channels
            self.conv_blocks.append(build_module(attn_cfg_))

    self.decision = nn.Linear(factor_output * base_channels, 1)

    # Use the same epsilon as the residual blocks (previously the torch
    # default was silently used here).
    if with_spectral_norm:
        self.decision = spectral_norm(self.decision, eps=sn_eps)

    self.num_classes = num_classes

    # In this case, discriminator is designed for conditional synthesis.
    if num_classes > 0:
        self.proj_y = nn.Embedding(num_classes,
                                   factor_output * base_channels)
        if with_spectral_norm:
            self.proj_y = spectral_norm(self.proj_y, eps=sn_eps)

    self.activate = build_activation_layer(act_cfg)
    self.init_weights(pretrained)
def __init__(self,
             output_scale,
             num_classes=0,
             base_channels=64,
             out_channels=3,
             input_scale=4,
             noise_size=128,
             attention_cfg=dict(type='SelfAttentionBlock'),
             attention_after_nth_block=0,
             channels_cfg=None,
             blocks_cfg=dict(type='SNGANGenResBlock'),
             act_cfg=dict(type='ReLU'),
             use_cbn=True,
             auto_sync_bn=True,
             with_spectral_norm=False,
             with_embedding_spectral_norm=None,
             norm_eps=1e-4,
             sn_eps=1e-12,
             init_cfg=dict(type='BigGAN'),
             pretrained=None):
    """Build the noise projection, residual trunk and to-RGB head.

    Args:
        output_scale (int): Output resolution; used as the key into a
            dict-form ``channels_cfg``.
        num_classes (int): Forwarded to every residual block.
        base_channels (int): Base width; block widths are
            ``factor * base_channels``.
        out_channels (int): Image channels produced by ``to_rgb``.
        input_scale (int): Spatial size of the feature map produced from
            the noise by ``noise2feat``.
        noise_size (int): Input noise dimension.
        attention_cfg (dict): Config of the inserted attention blocks.
        attention_after_nth_block (int | list[int]): 1-based indices of
            residual blocks followed by attention. The default ``0`` inserts
            none.
        channels_cfg (dict | list, optional): Channel factor per block, or a
            dict mapping ``output_scale`` to such a list. Defaults to the
            class default when None.
        blocks_cfg (dict): Config of each residual block.
        act_cfg (dict): Activation config.
        use_cbn, auto_sync_bn (bool): Forwarded to the residual blocks.
            ``auto_sync_bn`` also switches the to-RGB norm to ``SyncBN``
            when distributed training is initialized.
        with_spectral_norm (bool): Whether spectral norm is applied.
        with_embedding_spectral_norm (bool, optional): Falls back to
            ``with_spectral_norm`` when None.
        norm_eps (float): Epsilon of the normalization layers.
        sn_eps (float): Spectral-norm epsilon for the residual blocks and
            ``noise2feat``.
        init_cfg (dict): Init config; its ``type`` is stored as
            ``self.init_type``.
        pretrained (str, optional): Forwarded to ``init_weights``.

    Raises:
        KeyError: ``output_scale`` is missing from a dict-form config.
        ValueError: ``channels_cfg`` has an unsupported type, or
            ``attention_after_nth_block`` contains a non-int.
    """
    super().__init__()
    self.input_scale = input_scale
    self.output_scale = output_scale
    self.noise_size = noise_size
    self.num_classes = num_classes
    self.init_type = init_cfg.get('type', None)

    self.blocks_cfg = deepcopy(blocks_cfg)

    self.blocks_cfg.setdefault('num_classes', num_classes)
    self.blocks_cfg.setdefault('act_cfg', act_cfg)
    self.blocks_cfg.setdefault('use_cbn', use_cbn)
    self.blocks_cfg.setdefault('auto_sync_bn', auto_sync_bn)
    self.blocks_cfg.setdefault('with_spectral_norm', with_spectral_norm)

    # fall back to `with_spectral_norm` when
    # `with_embedding_spectral_norm` is not defined
    with_embedding_spectral_norm = with_embedding_spectral_norm \
        if with_embedding_spectral_norm is not None else with_spectral_norm
    self.blocks_cfg.setdefault('with_embedding_spectral_norm',
                               with_embedding_spectral_norm)
    self.blocks_cfg.setdefault('init_cfg', init_cfg)
    self.blocks_cfg.setdefault('norm_eps', norm_eps)
    self.blocks_cfg.setdefault('sn_eps', sn_eps)

    channels_cfg = deepcopy(self._default_channels_cfg) \
        if channels_cfg is None else deepcopy(channels_cfg)
    if isinstance(channels_cfg, dict):
        if output_scale not in channels_cfg:
            raise KeyError(f'`output_scale={output_scale}` is not found in '
                           '`channels_cfg`, only support configs for '
                           f'{list(channels_cfg.keys())}')
        self.channel_factor_list = channels_cfg[output_scale]
    elif isinstance(channels_cfg, list):
        self.channel_factor_list = channels_cfg
    else:
        raise ValueError('Only support list or dict for `channels_cfg`, '
                         f'receive {type(channels_cfg)}')

    # project the noise to an (input_scale x input_scale) feature map
    self.noise2feat = nn.Linear(
        noise_size,
        input_scale**2 * base_channels * self.channel_factor_list[0])
    # Use the same epsilon as the residual blocks (previously the torch
    # default was silently used here).
    if with_spectral_norm:
        self.noise2feat = spectral_norm(self.noise2feat, eps=sn_eps)

    # check `attention_after_nth_block`
    if not isinstance(attention_after_nth_block, list):
        attention_after_nth_block = [attention_after_nth_block]
    if not is_list_of(attention_after_nth_block, int):
        raise ValueError('`attention_after_nth_block` only support int or '
                         'a list of int. Please check your input type.')

    self.conv_blocks = nn.ModuleList()
    # positions of the attention blocks inside `self.conv_blocks`
    self.attention_block_idx = []
    for idx in range(len(self.channel_factor_list)):
        factor_input = self.channel_factor_list[idx]
        # the last block outputs `base_channels` (factor 1)
        factor_output = self.channel_factor_list[idx+1] \
            if idx < len(self.channel_factor_list)-1 else 1

        # get block-specific config
        block_cfg_ = deepcopy(self.blocks_cfg)
        block_cfg_['in_channels'] = factor_input * base_channels
        block_cfg_['out_channels'] = factor_output * base_channels
        self.conv_blocks.append(build_module(block_cfg_))

        # build self-attention block
        # `idx` is start from 0, add 1 to get the index
        if idx + 1 in attention_after_nth_block:
            self.attention_block_idx.append(len(self.conv_blocks))
            attn_cfg_ = deepcopy(attention_cfg)
            attn_cfg_['in_channels'] = factor_output * base_channels
            self.conv_blocks.append(build_module(attn_cfg_))

    to_rgb_norm_cfg = dict(type='BN', eps=norm_eps)
    if check_dist_init() and auto_sync_bn:
        to_rgb_norm_cfg['type'] = 'SyncBN'

    self.to_rgb = ConvModule(factor_output * base_channels,
                             out_channels,
                             kernel_size=3,
                             stride=1,
                             padding=1,
                             bias=True,
                             norm_cfg=to_rgb_norm_cfg,
                             act_cfg=act_cfg,
                             order=('norm', 'act', 'conv'),
                             with_spectral_norm=with_spectral_norm)
    self.final_act = build_activation_layer(dict(type='Tanh'))

    self.init_weights(pretrained)
def __init__(self,
             output_scale,
             noise_size=120,
             num_classes=0,
             out_channels=3,
             base_channels=96,
             input_scale=4,
             with_shared_embedding=True,
             shared_dim=128,
             sn_eps=1e-6,
             init_type='ortho',
             split_noise=True,
             act_cfg=dict(type='ReLU'),
             upsample_cfg=dict(type='nearest', scale_factor=2),
             with_spectral_norm=True,
             auto_sync_bn=True,
             blocks_cfg=dict(type='BigGANGenResBlock'),
             arch_cfg=None,
             out_norm_cfg=dict(type='BN'),
             pretrained=None,
             rgb2bgr=False):
    """Build the noise projection, class embedding, residual trunk and
    output layer.

    Args:
        output_scale (int): Output resolution; used to pick the default
            architecture when ``arch_cfg`` is not given.
        noise_size (int): Input noise dimension. With ``split_noise`` it is
            rounded down to a multiple of ``self.num_slots``.
        num_classes (int): Number of classes; 0 means unconditional. Then
            ``with_shared_embedding`` must be False, or an
            ``AssertionError`` is raised.
        out_channels (int): Image channels of ``output_layer``.
        base_channels (int): Base width for the default architecture.
        input_scale (int): Spatial size of the feature map produced by
            ``noise2feat``.
        with_shared_embedding (bool): Build one ``nn.Embedding`` shared by
            all blocks; otherwise ``nn.Identity`` is used.
        shared_dim (int): Dimension of the shared embedding.
        sn_eps (float): Epsilon for every spectral norm in this module.
        init_type (str): Forwarded to ``init_weights``.
        split_noise (bool): Split the noise into ``self.num_slots`` equal
            chunks. Must be False for conditional models without a shared
            embedding.
        act_cfg (dict): Activation config.
        upsample_cfg (dict): Upsampling config, deep-copied. It is given to
            blocks whose ``arch['upsample']`` entry is truthy.
        with_spectral_norm (bool): Whether spectral norm is applied.
        auto_sync_bn (bool): Forwarded to the residual blocks.
        blocks_cfg (dict): Config of each residual block.
        arch_cfg (dict, optional): Architecture with per-block lists
            ``in_channels``, ``out_channels``, ``upsample``, ``attention``.
        out_norm_cfg (dict): Norm config of ``output_layer``.
        pretrained (str, optional): Forwarded to ``init_weights``.
        rgb2bgr (bool): Stored as ``self.rgb2bgr``.
    """
    super().__init__()
    self.noise_size = noise_size
    self.num_classes = num_classes
    self.shared_dim = shared_dim
    self.with_shared_embedding = with_shared_embedding
    self.output_scale = output_scale
    self.arch = arch_cfg if arch_cfg else self._get_default_arch_cfg(
        self.output_scale, base_channels)
    self.input_scale = input_scale
    self.split_noise = split_noise
    self.blocks_cfg = deepcopy(blocks_cfg)
    self.upsample_cfg = deepcopy(upsample_cfg)
    self.rgb2bgr = rgb2bgr

    # Validity Check
    # If 'num_classes' equals to zero, we shall set 'with_shared_embedding'
    # to False. (The caller must do this; it is asserted, not auto-set.)
    if num_classes == 0:
        assert not self.with_shared_embedding
    else:
        if not self.with_shared_embedding:
            # If not `with_shared_embedding`, we will use `nn.Embedding` to
            # replace the original `Linear` layer in conditional BN.
            # Meanwhile, we do not adopt split noises.
            assert not self.split_noise

    # If using split latents, we may need to adjust noise_size
    if self.split_noise:
        # Number of places z slots into: one chunk for `noise2feat` plus one
        # per entry of `arch['in_channels']` (presumably one per block)
        self.num_slots = len(self.arch['in_channels']) + 1
        self.noise_chunk_size = self.noise_size // self.num_slots
        # Recalculate latent dimensionality for even splitting into chunks
        self.noise_size = self.noise_chunk_size * self.num_slots
    else:
        self.num_slots = 1
        self.noise_chunk_size = 0

    # First linear layer: one noise chunk -> (input_scale x input_scale)
    # feature map with arch['in_channels'][0] channels
    self.noise2feat = nn.Linear(
        self.noise_size // self.num_slots,
        self.arch['in_channels'][0] * (self.input_scale**2))
    if with_spectral_norm:
        self.noise2feat = spectral_norm(self.noise2feat, eps=sn_eps)

    # If using 'shared_embedding', we will get an unified embedding of
    # label for all blocks. If not, we just pass the label to each
    # block.
    if with_shared_embedding:
        self.shared_embedding = nn.Embedding(num_classes, shared_dim)
    else:
        self.shared_embedding = nn.Identity()

    # width of the conditioning vector each block receives
    if num_classes > 0:
        self.dim_after_concat = (self.shared_dim + self.noise_chunk_size
                                 if self.with_shared_embedding else
                                 self.num_classes)
    else:
        self.dim_after_concat = self.noise_chunk_size

    self.blocks_cfg.update(
        dict(dim_after_concat=self.dim_after_concat,
             act_cfg=act_cfg,
             sn_eps=sn_eps,
             input_is_label=(num_classes > 0)
             and (not with_shared_embedding),
             with_spectral_norm=with_spectral_norm,
             auto_sync_bn=auto_sync_bn))

    self.conv_blocks = nn.ModuleList()
    for index, out_ch in enumerate(self.arch['out_channels']):
        # change args to adapt to current block
        self.blocks_cfg.update(
            dict(in_channels=self.arch['in_channels'][index],
                 out_channels=out_ch,
                 upsample_cfg=self.upsample_cfg
                 if self.arch['upsample'][index] else None))
        self.conv_blocks.append(build_module(self.blocks_cfg))
        if self.arch['attention'][index]:
            self.conv_blocks.append(
                SelfAttentionBlock(out_ch,
                                   with_spectral_norm=with_spectral_norm,
                                   sn_eps=sn_eps))

    # norm -> act -> 3x3 conv to image channels
    self.output_layer = SNConvModule(self.arch['out_channels'][-1],
                                     out_channels,
                                     kernel_size=3,
                                     padding=1,
                                     with_spectral_norm=with_spectral_norm,
                                     spectral_norm_cfg=dict(eps=sn_eps),
                                     act_cfg=act_cfg,
                                     norm_cfg=out_norm_cfg,
                                     bias=True,
                                     order=('norm', 'act', 'conv'))

    self.init_weights(pretrained=pretrained, init_type=init_type)
def __init__(self,
             generator,
             discriminator,
             gan_loss,
             cycle_loss,
             id_loss=None,
             train_cfg=None,
             test_cfg=None,
             pretrained=None):
    """Build paired generators/discriminators, image buffers and losses.

    Args:
        generator (dict): Config used for both generators ``'a'``/``'b'``.
        discriminator (dict): Config used for both discriminators.
        gan_loss (dict): GAN loss config; must not be None.
        cycle_loss (dict): Cycle-consistency loss config; must not be None.
        id_loss (dict, optional): Identity loss config. When its
            ``loss_weight`` is > 0, the generator config must have equal
            ``in_channels`` and ``out_channels``.
        train_cfg (dict, optional): Reads ``buffer_size`` (50),
            ``disc_steps`` (1), ``disc_init_steps`` (0) and ``direction``
            ('a2b').
        test_cfg (dict, optional): Reads ``direction`` (used only without
            ``train_cfg``), ``show_input`` (False) and ``test_direction``
            ('a2b').
        pretrained (str, optional): Forwarded to ``init_weights``.
    """
    super().__init__()

    self.train_cfg = train_cfg
    self.test_cfg = test_cfg

    def _cfg_get(cfg, key, default):
        # Read `key` from an optional config; a missing config (None)
        # yields `default`.
        return default if cfg is None else cfg.get(key, default)

    # identity loss only works when input and output images have the same
    # number of channels
    # NOTE(review): `id_loss.get('loss_weight')` is None when the key is
    # absent and `None > 0.0` raises TypeError — confirm every id_loss
    # config sets `loss_weight`.
    if id_loss is not None and id_loss.get('loss_weight') > 0.0:
        assert generator.get('in_channels') == generator.get(
            'out_channels')

    domains = ('a', 'b')

    # generators
    self.generators = nn.ModuleDict()
    for domain in domains:
        self.generators[domain] = build_module(generator)

    # discriminators
    self.discriminators = nn.ModuleDict()
    for domain in domains:
        self.discriminators[domain] = build_module(discriminator)

    # GAN image buffers
    self.buffer_size = _cfg_get(self.train_cfg, 'buffer_size', 50)
    self.image_buffers = {
        domain: GANImageBuffer(self.buffer_size)
        for domain in domains
    }

    # losses
    assert gan_loss is not None  # gan loss cannot be None
    self.gan_loss = build_module(gan_loss)
    assert cycle_loss is not None  # cycle loss cannot be None
    self.cycle_loss = build_module(cycle_loss)
    self.id_loss = build_module(id_loss) if id_loss else None

    # others
    self.disc_steps = _cfg_get(self.train_cfg, 'disc_steps', 1)
    self.disc_init_steps = _cfg_get(self.train_cfg, 'disc_init_steps', 0)
    # the training config takes precedence over the test config
    if self.train_cfg is not None:
        self.direction = self.train_cfg.get('direction', 'a2b')
    else:
        self.direction = _cfg_get(self.test_cfg, 'direction', 'a2b')
    self.step_counter = 0  # counting training steps
    self.show_input = _cfg_get(self.test_cfg, 'show_input', False)

    # In CycleGAN, if not showing input, we can decide the translation
    # direction in the test mode, i.e., whether to output fake_b or fake_a
    if not self.show_input:
        test_direction = _cfg_get(self.test_cfg, 'test_direction', 'a2b')
        # a 'b2a' model flips the requested test direction
        if self.direction == 'b2a':
            test_direction = 'b2a' if test_direction == 'a2b' else 'a2b'
        self.test_direction = test_direction

    self.init_weights(pretrained)
    self.use_ema = False
def __init__(self,
             image_size,
             in_channels=3,
             base_channels=128,
             resblocks_per_downsample=3,
             num_timesteps=1000,
             use_rescale_timesteps=True,
             dropout=0,
             embedding_channels=-1,
             num_classes=0,
             channels_cfg=None,
             output_cfg=dict(mean='eps', var='learned_range'),
             norm_cfg=dict(type='GN', num_groups=32),
             act_cfg=dict(type='SiLU', inplace=False),
             shortcut_kernel_size=1,
             use_scale_shift_norm=False,
             num_heads=4,
             time_embedding_mode='sin',
             time_embedding_cfg=None,
             resblock_cfg=dict(type='DenoisingResBlock'),
             attention_cfg=dict(type='MultiHeadAttention'),
             downsample_conv=True,
             upsample_conv=True,
             downsample_cfg=dict(type='DenoisingDownsample'),
             upsample_cfg=dict(type='DenoisingUpsample'),
             attention_res=[16, 8],
             pretrained=None):
    """Build the time/label embeddings and the encoder-middle-decoder UNet.

    Args:
        image_size (int | list[int]): Square image size. A 2-element list
            must have equal entries; otherwise ``TypeError`` (wrong type) or
            ``AssertionError`` is raised.
        in_channels (int): Image channels.
        base_channels (int): Base width; level widths are
            ``factor * base_channels``.
        resblocks_per_downsample (int): Residual blocks per encoder level.
            The decoder uses one more per level.
        num_timesteps (int): Stored as-is.
        use_rescale_timesteps (bool): Stored as-is.
        dropout (float): Default dropout for the residual blocks.
        embedding_channels (int): Time/label embedding width; ``-1`` means
            ``4 * base_channels``.
        num_classes (int): When != 0, a label ``nn.Embedding`` is built.
        channels_cfg (dict | list, optional): Channel factor per level, or a
            dict mapping ``image_size`` to such a list. Defaults to the
            class default when None.
        output_cfg (dict): ``mean`` / ``var`` modes. Output channels are
            doubled unless ``var`` is None.
        norm_cfg, act_cfg (dict): Defaults for the residual / attention
            blocks and the output layer.
        shortcut_kernel_size, use_scale_shift_norm: Defaults forwarded to
            the residual block config.
        num_heads (int): Default number of attention heads.
        time_embedding_mode (str), time_embedding_cfg (dict, optional):
            Forwarded to ``TimeEmbedding``.
        resblock_cfg, attention_cfg, downsample_cfg, upsample_cfg (dict):
            Block configs (deep-copied; the other arguments only fill in
            missing keys).
        downsample_conv, upsample_conv (bool): Default ``with_conv`` for the
            down/upsampling blocks.
        attention_res (list[int]): Feature resolutions at which attention
            blocks are inserted.
        pretrained (str, optional): Forwarded to ``init_weights``.

    Raises:
        TypeError: ``image_size`` is neither int nor list.
        KeyError: ``image_size`` is missing from a dict-form
            ``channels_cfg``.
        ValueError: ``channels_cfg`` has an unsupported type.
    """

    super().__init__()

    self.num_classes = num_classes
    self.num_timesteps = num_timesteps
    self.use_rescale_timesteps = use_rescale_timesteps

    self.output_cfg = deepcopy(output_cfg)
    self.mean_mode = self.output_cfg.get('mean', 'eps')
    self.var_mode = self.output_cfg.get('var', 'learned_range')

    # double output_channels to output mean and var at same time
    out_channels = in_channels if self.var_mode is None \
        else 2 * in_channels
    self.out_channels = out_channels

    # check type of image_size
    if not isinstance(image_size, int) and not isinstance(
            image_size, list):
        raise TypeError(
            'Only support `int` and `list[int]` for `image_size`.')
    if isinstance(image_size, list):
        assert len(
            image_size) == 2, 'The length of `image_size` should be 2.'
        assert image_size[0] == image_size[
            1], 'Width and height of the image should be same.'
        image_size = image_size[0]
    self.image_size = image_size

    channels_cfg = deepcopy(self._default_channels_cfg) \
        if channels_cfg is None else deepcopy(channels_cfg)
    if isinstance(channels_cfg, dict):
        if image_size not in channels_cfg:
            raise KeyError(f'`image_size={image_size} is not found in '
                           '`channels_cfg`, only support configs for '
                           f'{[chn for chn in channels_cfg.keys()]}')
        self.channel_factor_list = channels_cfg[image_size]
    elif isinstance(channels_cfg, list):
        self.channel_factor_list = channels_cfg
    else:
        raise ValueError('Only support list or dict for `channels_cfg`, '
                         f'receive {type(channels_cfg)}')

    embedding_channels = base_channels * 4 \
        if embedding_channels == -1 else embedding_channels
    self.time_embedding = TimeEmbedding(
        base_channels,
        embedding_channels=embedding_channels,
        embedding_mode=time_embedding_mode,
        embedding_cfg=time_embedding_cfg,
        act_cfg=act_cfg)

    if self.num_classes != 0:
        self.label_embedding = nn.Embedding(self.num_classes,
                                            embedding_channels)

    self.resblock_cfg = deepcopy(resblock_cfg)
    self.resblock_cfg.setdefault('dropout', dropout)
    self.resblock_cfg.setdefault('norm_cfg', norm_cfg)
    self.resblock_cfg.setdefault('act_cfg', act_cfg)
    self.resblock_cfg.setdefault('embedding_channels', embedding_channels)
    self.resblock_cfg.setdefault('use_scale_shift_norm',
                                 use_scale_shift_norm)
    self.resblock_cfg.setdefault('shortcut_kernel_size',
                                 shortcut_kernel_size)

    # get scales of ResBlock to apply attention
    # (`scale` below is the cumulative downsampling factor, so attention is
    # added where image_size // scale equals a resolution in attention_res)
    attention_scale = [image_size // int(res) for res in attention_res]
    self.attention_cfg = deepcopy(attention_cfg)
    self.attention_cfg.setdefault('num_heads', num_heads)
    self.attention_cfg.setdefault('norm_cfg', norm_cfg)

    self.downsample_cfg = deepcopy(downsample_cfg)
    self.downsample_cfg.setdefault('with_conv', downsample_conv)
    self.upsample_cfg = deepcopy(upsample_cfg)
    self.upsample_cfg.setdefault('with_conv', upsample_conv)

    # init the channel scale factor
    scale = 1
    self.in_blocks = nn.ModuleList([
        EmbedSequential(
            nn.Conv2d(in_channels, base_channels, 3, 1, padding=1))
    ])
    # output channels of every encoder block, in order; the decoder pops
    # them to size its concatenated inputs
    self.in_channels_list = [base_channels]

    # construct the encoder part of Unet
    for level, factor in enumerate(self.channel_factor_list):
        in_channels_ = base_channels if level == 0 \
            else base_channels * self.channel_factor_list[level - 1]
        out_channels_ = base_channels * factor

        for _ in range(resblocks_per_downsample):
            layers = [
                build_module(self.resblock_cfg, {
                    'in_channels': in_channels_,
                    'out_channels': out_channels_
                })
            ]
            in_channels_ = out_channels_

            if scale in attention_scale:
                layers.append(
                    build_module(self.attention_cfg,
                                 {'in_channels': in_channels_}))

            self.in_channels_list.append(in_channels_)
            self.in_blocks.append(EmbedSequential(*layers))

        # every level except the last ends with a downsampling block
        if level != len(self.channel_factor_list) - 1:
            self.in_blocks.append(
                EmbedSequential(
                    build_module(self.downsample_cfg,
                                 {'in_channels': in_channels_})))
            self.in_channels_list.append(in_channels_)
            scale *= 2

    # construct the bottom part of Unet
    self.mid_blocks = EmbedSequential(
        build_module(self.resblock_cfg, {'in_channels': in_channels_}),
        build_module(self.attention_cfg, {'in_channels': in_channels_}),
        build_module(self.resblock_cfg, {'in_channels': in_channels_}),
    )

    # construct the decoder part of Unet
    # work on a copy so `self.in_channels_list` stays intact
    in_channels_list = deepcopy(self.in_channels_list)
    self.out_blocks = nn.ModuleList()
    for level, factor in enumerate(self.channel_factor_list[::-1]):
        for idx in range(resblocks_per_downsample + 1):
            # input width = current width + popped encoder width
            # (presumably concatenated in forward — confirm)
            layers = [
                build_module(
                    self.resblock_cfg, {
                        'in_channels':
                        in_channels_ + in_channels_list.pop(),
                        'out_channels': base_channels * factor
                    })
            ]
            in_channels_ = base_channels * factor
            if scale in attention_scale:
                layers.append(
                    build_module(self.attention_cfg,
                                 {'in_channels': in_channels_}))
            # the last block of every level except the final one upsamples
            if (level != len(self.channel_factor_list) - 1
                    and idx == resblocks_per_downsample):
                layers.append(
                    build_module(self.upsample_cfg,
                                 {'in_channels': in_channels_}))
                scale //= 2
            self.out_blocks.append(EmbedSequential(*layers))

    # norm -> act -> 3x3 conv to `out_channels`
    self.out = ConvModule(
        in_channels=in_channels_,
        out_channels=out_channels,
        kernel_size=3,
        padding=1,
        act_cfg=act_cfg,
        norm_cfg=norm_cfg,
        bias=True,
        order=('norm', 'act', 'conv'))

    self.init_weights(pretrained)