def __init__(self, D=None, is_enc_shared=True, dim_z=128, D_ch=64, D_wide=True,
             resolution=128, D_kernel_size=3, D_attn='64', n_classes=1000,
             num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
             D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8, SN_eps=1e-12,
             output_dim=1, D_mixed_precision=False, D_fp16=False,
             D_init='ortho', skip_init=False, D_param='SN',
             optimizer_type='adam', weight_decay=5e-4, dataset_channel=3,
             **kwargs):
    super(Encoder, self).__init__()
    # Width multiplier
    self.ch = D_ch
    # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
    self.D_wide = D_wide
    # Resolution
    self.resolution = resolution
    # Kernel size
    self.kernel_size = D_kernel_size
    # Attention?
    self.attention = D_attn
    # Number of classes
    self.n_classes = n_classes
    # Activation
    self.activation = D_activation
    # Initialization style
    self.init = D_init
    # Parameterization style
    self.D_param = D_param
    # Weight decay
    self.weight_decay = weight_decay
    # Epsilon for spectral norm
    self.SN_eps = SN_eps
    # Fp16?
    self.fp16 = D_fp16
    # Architecture
    self.arch = D_arch(self.ch, self.attention)[resolution]

    # Which convs, batchnorms, and linear layers to use.
    # No option to turn off SN in D right now.
    if self.D_param == 'SN':
        self.which_conv = functools.partial(layers.SNConv2d,
                                            kernel_size=3, padding=1,
                                            num_svs=num_D_SVs,
                                            num_itrs=num_D_SV_itrs,
                                            eps=self.SN_eps)
        self.which_linear = functools.partial(layers.SNLinear,
                                              num_svs=num_D_SVs,
                                              num_itrs=num_D_SV_itrs,
                                              eps=self.SN_eps)
        self.which_embedding = functools.partial(layers.SNEmbedding,
                                                 num_svs=num_D_SVs,
                                                 num_itrs=num_D_SV_itrs,
                                                 eps=self.SN_eps)

    # Prepare model.
    # self.blocks is a doubly-nested list of modules; the outer loop is over
    # blocks at a given resolution (resblocks and/or self-attention).
    if not is_enc_shared:
        # Build the encoder trunk from scratch on raw images.
        self.arch['in_channels'][0] = dataset_channel
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[
                layers.DBlock(in_channels=self.arch['in_channels'][index],
                              out_channels=self.arch['out_channels'][index],
                              which_conv=self.which_conv,
                              wide=self.D_wide,
                              activation=self.activation,
                              preactivation=(index > 0),
                              downsample=(nn.AvgPool2d(2)
                                          if self.arch['downsample'][index] else None))
            ]]
            # If attention on this block, attach it to the end.
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in D at resolution %d' %
                      self.arch['resolution'][index])
                self.blocks[-1] += [
                    layers.Attention(self.arch['out_channels'][index],
                                     self.which_conv)
                ]
        self.blocks = nn.ModuleList(
            [nn.ModuleList(block) for block in self.blocks])
    else:
        # Share the discriminator's trunk, dropping its final block.
        self.blocks = D.blocks
        self.blocks = self.blocks[:-1]

    # Encoder-specific head: one non-downsampling block followed by
    # dim_reduction_enc downsampling blocks.
    self.blocks2 = []
    self.blocks2 += [[
        layers.DBlock(in_channels=self.arch['in_channels'][-1],
                      out_channels=self.arch['out_channels'][-1],
                      which_conv=self.which_conv,
                      wide=self.D_wide,
                      activation=self.activation,
                      preactivation=True,
                      downsample=None)
    ]]
    dim_reduction_enc = 2
    for index in range(dim_reduction_enc):
        self.blocks2 += [[
            layers.DBlock(in_channels=self.arch['in_channels'][-1],
                          out_channels=self.arch['out_channels'][-1],
                          which_conv=self.which_conv,
                          wide=self.D_wide,
                          activation=self.activation,
                          preactivation=True,
                          downsample=nn.AvgPool2d(2))
        ]]
    self.blocks2 = nn.ModuleList(
        [nn.ModuleList(block) for block in self.blocks2])

    # Output linear layers for the VAE head (mu and log-variance).
    reduction = 2**(dim_reduction_enc +
                    np.sum(np.array(self.arch['downsample'])))
    o_res_dim = resolution // reduction
    self.input_linear_dim = self.arch['out_channels'][-1] * o_res_dim**2
    print('Encoder total downsampling factor:', reduction)
    print('Encoder output spatial size:', o_res_dim)
    self.linear_mu1 = self.which_linear(self.input_linear_dim,
                                        self.input_linear_dim // 2)
    self.linear_lv1 = self.which_linear(self.input_linear_dim,
                                        self.input_linear_dim // 2)
    self.linear_mu2 = self.which_linear(self.input_linear_dim // 2, dim_z)
    self.linear_lv2 = self.which_linear(self.input_linear_dim // 2, dim_z)

    # Initialize weights
    if not skip_init:
        self.init_weights()

    # Set up optimizer
    self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
    if D_mixed_precision:
        print('Using fp16 adam in D...')
        import utils
        self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                  betas=(self.B1, self.B2), weight_decay=0,
                                  eps=self.adam_eps)
    else:
        if optimizer_type == 'adam':
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2), weight_decay=0,
                                    eps=self.adam_eps)
        elif optimizer_type == 'radam':
            self.optim = optimizers.RAdam(params=self.parameters(), lr=self.lr,
                                          betas=(self.B1, self.B2),
                                          weight_decay=self.weight_decay,
                                          eps=self.adam_eps)
        elif optimizer_type == 'ranger':
            self.optim = optimizers.Ranger(params=self.parameters(), lr=self.lr,
                                           betas=(self.B1, self.B2),
                                           weight_decay=self.weight_decay,
                                           eps=self.adam_eps)
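# --- Illustrative usage sketch (not part of the original training code) ---
# The Encoder's forward() is defined elsewhere in this file; the sketch below
# assumes it takes an image batch and returns (mu, logvar), which is then
# reparameterized into a latent z. The helper name _encoder_usage_example is
# hypothetical and exists only for illustration; adjust the call to the
# actual forward() signature.
def _encoder_usage_example():
    import torch
    # Build a stand-alone encoder trunk (is_enc_shared=False) so that no
    # shared Discriminator has to be passed in via D=...
    enc = Encoder(is_enc_shared=False, resolution=128, dim_z=128,
                  skip_init=True)
    x = torch.randn(2, 3, 128, 128)                        # batch of RGB images
    mu, logvar = enc(x)                                    # assumed forward() signature
    z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()   # reparameterization trick
    return z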
def __init__(self, D_ch=64, D_wide=True, resolution=128, D_kernel_size=3,
             D_attn='64', n_classes=1000, num_D_SVs=1, num_D_SV_itrs=1,
             D_activation=nn.ReLU(inplace=False), D_lr=2e-4, D_B1=0.0,
             D_B2=0.999, adam_eps=1e-8, SN_eps=1e-12, output_dim=1,
             D_mixed_precision=False, D_fp16=False, D_init='ortho',
             skip_init=False, D_param='SN', **kwargs):
    super(Discriminator, self).__init__()
    # Width multiplier
    self.ch = D_ch
    # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
    self.D_wide = D_wide
    # Resolution
    self.resolution = resolution
    # Kernel size
    self.kernel_size = D_kernel_size
    # Attention?
    self.attention = D_attn
    # Number of classes
    self.n_classes = n_classes
    # Activation
    self.activation = D_activation
    # Initialization style
    self.init = D_init
    # Parameterization style
    self.D_param = D_param
    # Epsilon for spectral norm
    self.SN_eps = SN_eps
    # Fp16?
    self.fp16 = D_fp16
    # Architecture
    self.arch = D_arch(self.ch, self.attention)[resolution]

    # Which convs, batchnorms, and linear layers to use.
    # No option to turn off SN in D right now.
    if self.D_param == 'SN':
        self.which_conv = functools.partial(layers.SNConv2d,
                                            kernel_size=3, padding=1,
                                            num_svs=num_D_SVs,
                                            num_itrs=num_D_SV_itrs,
                                            eps=self.SN_eps)
        self.which_linear = functools.partial(layers.SNLinear,
                                              num_svs=num_D_SVs,
                                              num_itrs=num_D_SV_itrs,
                                              eps=self.SN_eps)
        self.which_embedding = functools.partial(layers.SNEmbedding,
                                                 num_svs=num_D_SVs,
                                                 num_itrs=num_D_SV_itrs,
                                                 eps=self.SN_eps)

    # Prepare model.
    # self.blocks is a doubly-nested list of modules; the outer loop is over
    # blocks at a given resolution (resblocks and/or self-attention).
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
        self.blocks += [[
            layers.DBlock(in_channels=self.arch['in_channels'][index],
                          out_channels=self.arch['out_channels'][index],
                          which_conv=self.which_conv,
                          wide=self.D_wide,
                          activation=self.activation,
                          preactivation=(index > 0),
                          downsample=(nn.AvgPool2d(2)
                                      if self.arch['downsample'][index] else None))
        ]]
        # If attention on this block, attach it to the end.
        if self.arch['attention'][self.arch['resolution'][index]]:
            print('Adding attention layer in D at resolution %d' %
                  self.arch['resolution'][index])
            self.blocks[-1] += [
                layers.Attention(self.arch['out_channels'][index],
                                 self.which_conv)
            ]
    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList(
        [nn.ModuleList(block) for block in self.blocks])

    # Linear output layer. The output dimension is typically 1, but may be
    # larger if we're e.g. turning this into a VAE with an inference output.
    self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
    # Embedding for projection discrimination.
    self.embed = self.which_embedding(self.n_classes,
                                      self.arch['out_channels'][-1])

    # Initialize weights
    if not skip_init:
        self.init_weights()

    # Set up optimizer
    self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
    if D_mixed_precision:
        print('Using fp16 adam in D...')
        import utils
        self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                  betas=(self.B1, self.B2), weight_decay=0,
                                  eps=self.adam_eps)
    else:
        self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                betas=(self.B1, self.B2), weight_decay=0,
                                eps=self.adam_eps)
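# --- Illustrative usage sketch (not part of the original training code) ---
# Shows how the projection Discriminator above would typically be scored on a
# batch. Its forward() is defined elsewhere in this file; the call below
# assumes the usual BigGAN-style signature D(x, y) returning a
# [batch, output_dim] logit tensor. The helper name
# _discriminator_usage_example is hypothetical.
def _discriminator_usage_example():
    import torch
    D = Discriminator(resolution=128, n_classes=1000, skip_init=True)
    x = torch.randn(2, 3, 128, 128)        # image batch
    y = torch.randint(0, 1000, (2,))       # class labels for projection
    logits = D(x, y)                       # assumed forward() signature
    return logits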
def __init__(self, D_ch=64, D_wide=True, resolution=128, D_kernel_size=3,
             D_attn='64', n_classes=1000, num_D_SVs=1, num_D_SV_itrs=1,
             D_activation=nn.ReLU(inplace=False), D_lr=2e-4, D_B1=0.0,
             D_B2=0.999, adam_eps=1e-8, SN_eps=1e-12, output_dim=1,
             D_mixed_precision=False, D_fp16=False, D_init='ortho',
             skip_init=False, D_param='SN', decoder_skip_connection=True,
             **kwargs):
    super(Unet_Discriminator, self).__init__()
    # Width multiplier
    self.ch = D_ch
    # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
    self.D_wide = D_wide
    # Resolution
    self.resolution = resolution
    # Kernel size
    self.kernel_size = D_kernel_size
    # Attention?
    self.attention = D_attn
    # Number of classes
    self.n_classes = n_classes
    # Activation
    self.activation = D_activation
    # Initialization style
    self.init = D_init
    # Parameterization style
    self.D_param = D_param
    # Epsilon for spectral norm
    self.SN_eps = SN_eps
    # Fp16?
    self.fp16 = D_fp16

    # Indices of encoder blocks whose features are kept for the U-Net skips.
    if self.resolution == 128:
        self.save_features = [0, 1, 2, 3, 4]
    elif self.resolution == 256:
        self.save_features = [0, 1, 2, 3, 4, 5]

    self.out_channel_multiplier = 1  # 4
    # Architecture
    self.arch = D_unet_arch(
        self.ch, self.attention,
        out_channel_multiplier=self.out_channel_multiplier)[resolution]

    self.unconditional = kwargs["unconditional"]

    # Which convs, batchnorms, and linear layers to use.
    # No option to turn off SN in D right now.
    if self.D_param == 'SN':
        self.which_conv = functools.partial(layers.SNConv2d,
                                            kernel_size=3, padding=1,
                                            num_svs=num_D_SVs,
                                            num_itrs=num_D_SV_itrs,
                                            eps=self.SN_eps)
        self.which_linear = functools.partial(layers.SNLinear,
                                              num_svs=num_D_SVs,
                                              num_itrs=num_D_SV_itrs,
                                              eps=self.SN_eps)
        self.which_embedding = functools.partial(layers.SNEmbedding,
                                                 num_svs=num_D_SVs,
                                                 num_itrs=num_D_SV_itrs,
                                                 eps=self.SN_eps)

    # Prepare model.
    # self.blocks is a doubly-nested list of modules; the outer loop is over
    # blocks at a given resolution (resblocks and/or self-attention).
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
        if self.arch["downsample"][index]:
            # Encoder (downsampling) half of the U-Net.
            self.blocks += [[
                layers.DBlock(in_channels=self.arch['in_channels'][index],
                              out_channels=self.arch['out_channels'][index],
                              which_conv=self.which_conv,
                              wide=self.D_wide,
                              activation=self.activation,
                              preactivation=(index > 0),
                              downsample=(nn.AvgPool2d(2)
                                          if self.arch['downsample'][index] else None))
            ]]
        elif self.arch["upsample"][index]:
            # Decoder (upsampling) half of the U-Net.
            upsample_function = (
                functools.partial(F.interpolate, scale_factor=2,
                                  mode="nearest")  # mode="nearest" is the default
                if self.arch['upsample'][index] else None)
            self.blocks += [[
                layers.GBlock2(in_channels=self.arch['in_channels'][index],
                               out_channels=self.arch['out_channels'][index],
                               which_conv=self.which_conv,
                               # which_bn=self.which_bn,
                               activation=self.activation,
                               upsample=upsample_function,
                               skip_connection=True)
            ]]
        # If attention on this block (and it lies in the first five blocks),
        # attach it to the end.
        attention_condition = index < 5
        if (self.arch['attention'][self.arch['resolution'][index]]
                and attention_condition):
            print('Adding attention layer in D at resolution %d' %
                  self.arch['resolution'][index])
            print('index = ', index)
            self.blocks[-1] += [
                layers.Attention(self.arch['out_channels'][index],
                                 self.which_conv)
            ]
    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList(
        [nn.ModuleList(block) for block in self.blocks])

    # Final 1x1 conv producing the per-pixel decoder output.
    last_layer = nn.Conv2d(self.ch * self.out_channel_multiplier, 1,
                           kernel_size=1)
    self.blocks.append(last_layer)

    # Linear output layer. The output dimension is typically 1, but may be
    # larger if we're e.g. turning this into a VAE with an inference output.
    self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
    self.linear_middle = self.which_linear(16 * self.ch, output_dim)
    # Embedding for projection discrimination.
    # if not kwargs["agnostic_unet"] and not kwargs["unconditional"]:
    #     self.embed = self.which_embedding(self.n_classes,
    #                                       self.arch['out_channels'][-1] + extra)
    if not kwargs["unconditional"]:
        self.embed_middle = self.which_embedding(self.n_classes, 16 * self.ch)
        self.embed = self.which_embedding(self.n_classes,
                                          self.arch['out_channels'][-1])

    # Initialize weights
    if not skip_init:
        self.init_weights()

    print("_____params______")
    for name, param in self.named_parameters():
        print(name, param.size())

    # Set up optimizer
    self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
    if D_mixed_precision:
        print('Using fp16 adam in D...')
        import utils
        self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                  betas=(self.B1, self.B2), weight_decay=0,
                                  eps=self.adam_eps)
    else:
        self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                betas=(self.B1, self.B2), weight_decay=0,
                                eps=self.adam_eps)
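# --- Illustrative usage sketch (not part of the original training code) ---
# The 'unconditional' flag has no default value and is read straight from
# **kwargs in __init__ above, so it must always be supplied when constructing
# the module. This sketch only builds the U-Net discriminator and reports its
# parameter count; forward() is defined elsewhere in this file, and the helper
# name _unet_discriminator_usage_example is hypothetical.
def _unet_discriminator_usage_example():
    D = Unet_Discriminator(resolution=128, unconditional=True, skip_init=True)
    n_params = sum(p.numel() for p in D.parameters())
    print('Unet_Discriminator parameters: %d' % n_params)
    return D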