def __init__(self, style_dim=32, resolution=16, max_dim=256, in_channel=1,
             init='N02', SN_param=False, norm='none', share_wid=True):
    super(StyleEncoder, self).__init__()
    self.reduce_len_scale = 16
    self.share_wid = share_wid
    self.style_dim = style_dim

    ######################################
    # Construct Backbone
    ######################################
    nf = resolution
    cnn_f = [nn.ConstantPad2d(2, -1),
             Conv2dBlock(in_channel, nf, 5, 1, 0, norm='none', activation='none')]
    for i in range(2):
        nf_out = min([int(nf * 2), max_dim])
        cnn_f += [ActFirstResBlock(nf, nf, None, 'lrelu', norm, sn=SN_param)]
        cnn_f += [nn.ReflectionPad2d((1, 1, 0, 0))]
        cnn_f += [ActFirstResBlock(nf, nf_out, None, 'lrelu', norm, sn=SN_param)]
        cnn_f += [nn.ReflectionPad2d(1)]
        cnn_f += [nn.MaxPool2d(kernel_size=3, stride=2)]
        nf = min([nf_out, max_dim])

    df = nf
    for i in range(1):
        df_out = min([int(df * 2), max_dim])
        cnn_f += [ActFirstResBlock(df, df, None, 'lrelu', norm, sn=SN_param)]
        cnn_f += [ActFirstResBlock(df, df_out, None, 'lrelu', norm, sn=SN_param)]
        cnn_f += [nn.MaxPool2d(kernel_size=3, stride=2)]
        df = min([df_out, max_dim])

    df_out = min([int(df * 2), max_dim])
    cnn_f += [ActFirstResBlock(df, df, None, 'lrelu', norm, sn=SN_param)]
    cnn_f += [ActFirstResBlock(df, df_out, None, 'lrelu', norm, sn=SN_param)]
    self.cnn_backbone = nn.Sequential(*cnn_f)
    # df_out = max_dim
    # df = max_dim // 2

    ######################################
    # Construct StyleEncoder
    ######################################
    cnn_e = [nn.ReflectionPad2d((1, 1, 0, 0)),
             Conv2dBlock(df_out, df, 3, 2, 0, norm=norm,
                         activation='lrelu', activation_first=True)]
    self.cnn_wid = nn.Sequential(*cnn_e)
    self.linear_style = nn.Sequential(
        nn.Linear(df, df),
        nn.LeakyReLU()
    )
    self.mu = nn.Linear(df, style_dim)
    self.logvar = nn.Linear(df, style_dim)

    if init != 'none':
        init_weights(self, init)
    torch.nn.init.constant_(self.logvar.weight.data, 0.)
    torch.nn.init.constant_(self.logvar.bias.data, -10.)
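# Usage sketch (illustrative only, not part of the original model code). A minimal,
# hedged example of constructing the style encoder above. It assumes StyleEncoder
# subclasses nn.Module and that torch and the custom blocks (Conv2dBlock,
# ActFirstResBlock) are imported as elsewhere in this file. The forward() method is
# not shown here, so the sketch only chains the registered sub-modules in one
# plausible order; the input shape and variable names are hypothetical.
style_encoder = StyleEncoder(style_dim=32, resolution=16, max_dim=256, in_channel=1)
dummy_imgs = torch.randn(4, 1, 64, 256)                 # (batch, channel, height, width) word crops
feat = style_encoder.cnn_wid(style_encoder.cnn_backbone(dummy_imgs))
pooled = feat.mean(dim=(2, 3))                          # global average pool -> (batch, channels)
h = style_encoder.linear_style(pooled)
style_mu, style_logvar = style_encoder.mu(h), style_encoder.logvar(h)   # each (batch, style_dim)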
def __init__(self, n_class, resolution=16, max_dim=256, in_channel=1,
             norm='none', init='none', rnn_depth=1, dropout=0.0, bidirectional=True):
    super(Recognizer, self).__init__()
    self.len_scale = 8
    self.use_rnn = rnn_depth > 0
    self.bidirectional = bidirectional

    ######################################
    # Construct Backbone
    ######################################
    nf = resolution
    cnn_f = [nn.ConstantPad2d(2, -1),
             Conv2dBlock(in_channel, nf, 5, 1, 0, norm='none', activation='none')]
    for i in range(2):
        nf_out = min([int(nf * 2), max_dim])
        cnn_f += [ActFirstResBlock(nf, nf, None, 'relu', norm, 'zero', dropout=dropout / 2)]
        cnn_f += [nn.ZeroPad2d((1, 1, 0, 0))]
        cnn_f += [ActFirstResBlock(nf, nf_out, None, 'relu', norm, 'zero', dropout=dropout / 2)]
        cnn_f += [nn.ZeroPad2d(1)]
        cnn_f += [nn.MaxPool2d(kernel_size=3, stride=2)]
        nf = min([nf_out, max_dim])

    df = nf
    for i in range(2):
        df_out = min([int(df * 2), max_dim])
        cnn_f += [ActFirstResBlock(df, df, None, 'relu', norm, 'zero', dropout=dropout)]
        cnn_f += [ActFirstResBlock(df, df_out, None, 'relu', norm, 'zero', dropout=dropout)]
        if i < 1:
            cnn_f += [nn.MaxPool2d(kernel_size=3, stride=2)]
        else:
            cnn_f += [nn.ZeroPad2d((1, 1, 0, 0))]
        df = min([df_out, max_dim])

    ######################################
    # Construct Classifier
    ######################################
    cnn_c = [nn.ReLU(),
             Conv2dBlock(df, df, 3, 1, 0, norm=norm, activation='relu')]
    self.cnn_backbone = nn.Sequential(*cnn_f)
    self.cnn_ctc = nn.Sequential(*cnn_c)
    if self.use_rnn:
        if bidirectional:
            self.rnn_ctc = DeepBLSTM(df, df, rnn_depth, bidirectional=True)
        else:
            self.rnn_ctc = DeepLSTM(df, df, rnn_depth)
    self.ctc_cls = nn.Linear(df, n_class)

    if init != 'none':
        init_weights(self, init)
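# Usage sketch (illustrative only, not part of the original model code). A hedged
# example of building the CTC recognizer above; it assumes Recognizer subclasses
# nn.Module and that DeepBLSTM / Conv2dBlock / ActFirstResBlock are importable as
# elsewhere in this file. n_class would be the character vocabulary size including
# the CTC blank; the forward()/decoding path is not shown here, so only construction
# and a parameter count are sketched.
recognizer = Recognizer(n_class=80, resolution=16, max_dim=256, in_channel=1,
                        rnn_depth=1, dropout=0.1, bidirectional=True)
num_params = sum(p.numel() for p in recognizer.parameters())
print('Recognizer parameters: %d' % num_params)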
def __init__(self, n_writer=284, resolution=16, max_dim=256, in_channel=1,
             init='N02', SN_param=False, dropout=0.0, norm='bn'):
    super(WriterIdentifier, self).__init__()
    self.reduce_len_scale = 16

    ######################################
    # Construct Backbone
    ######################################
    nf = resolution
    cnn_f = [nn.ConstantPad2d(2, -1),
             Conv2dBlock(in_channel, nf, 5, 1, 0, norm='none', activation='none')]
    for i in range(2):
        nf_out = min([int(nf * 2), max_dim])
        cnn_f += [ActFirstResBlock(nf, nf, None, 'lrelu', norm, sn=SN_param, dropout=dropout / 2)]
        cnn_f += [nn.ReflectionPad2d((1, 1, 0, 0))]
        cnn_f += [ActFirstResBlock(nf, nf_out, None, 'lrelu', norm, sn=SN_param, dropout=dropout / 2)]
        cnn_f += [nn.ReflectionPad2d(1)]
        cnn_f += [nn.MaxPool2d(kernel_size=3, stride=2)]
        nf = min([nf_out, max_dim])

    df = nf
    for i in range(1):
        df_out = min([int(df * 2), max_dim])
        cnn_f += [ActFirstResBlock(df, df, None, 'lrelu', norm, sn=SN_param, dropout=dropout)]
        cnn_f += [ActFirstResBlock(df, df_out, None, 'lrelu', norm, sn=SN_param, dropout=dropout)]
        cnn_f += [nn.MaxPool2d(kernel_size=3, stride=2)]
        df = min([df_out, max_dim])

    df_out = min([int(df * 2), max_dim])
    cnn_f += [ActFirstResBlock(df, df, None, 'lrelu', norm, sn=SN_param, dropout=dropout / 2)]
    cnn_f += [ActFirstResBlock(df, df_out, None, 'lrelu', norm, sn=SN_param, dropout=dropout / 2)]
    self.cnn_backbone = nn.Sequential(*cnn_f)

    ######################################
    # Construct WriterIdentifier
    ######################################
    cnn_w = [nn.ReflectionPad2d((1, 1, 0, 0)),
             Conv2dBlock(df_out, df, 3, 2, 0, norm=norm,
                         activation='lrelu', activation_first=True)]
    self.cnn_wid = nn.Sequential(*cnn_w)
    self.linear_wid = nn.Sequential(
        nn.Linear(df, df),
        nn.LeakyReLU(),
        nn.Linear(df, n_writer),
    )

    if init != 'none':
        init_weights(self, init)
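# Usage sketch (illustrative only, not part of the original model code). A hedged
# example of instantiating the writer identifier above, assuming WriterIdentifier
# subclasses nn.Module; n_writer=284 matches the constructor default (e.g. the number
# of training writers). The forward() pooling and classification steps are not shown
# in this file, so the sketch only pushes a dummy grayscale image through the backbone
# and the convolutional writer head.
writer_id = WriterIdentifier(n_writer=284, resolution=16, max_dim=256, in_channel=1, norm='bn')
dummy = torch.randn(2, 1, 64, 192)                      # hypothetical (batch, channel, H, W)
wid_feat = writer_id.cnn_wid(writer_id.cnn_backbone(dummy))   # feature map before pooling / linear_wid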
def __init__(self, D_ch=64, D_wide=True, resolution=128, D_kernel_size=3, D_attn='64',
             n_class=1000, num_D_SVs=1, num_D_SV_itrs=1,
             D_activation=nn.ReLU(inplace=False), SN_eps=1e-12, output_dim=1,
             D_fp16=False, init='ortho', D_param='SN', bn_linear='embed',
             input_nc=3, one_hot=False):
    super(Discriminator, self).__init__()
    self.name = 'D'
    # one_hot representation
    self.one_hot = one_hot
    # Width multiplier
    self.ch = D_ch
    # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
    self.D_wide = D_wide
    # Resolution
    self.resolution = resolution
    # Kernel size
    self.kernel_size = D_kernel_size
    # Attention?
    self.attention = D_attn
    # Number of classes
    self.n_classes = n_class
    # Activation
    self.activation = D_activation
    # Initialization style
    self.init = init
    # Parameterization style
    self.D_param = D_param
    # Epsilon for Spectral Norm?
    self.SN_eps = SN_eps
    # Fp16?
    self.fp16 = D_fp16
    # Architecture
    self.arch = D_arch(self.ch, self.attention, input_nc)[resolution]

    # Which convs, batchnorms, and linear layers to use
    # No option to turn off SN in D right now
    if self.D_param == 'SN':
        self.which_conv = functools.partial(layers.SNConv2d,
                                            kernel_size=3, padding=1,
                                            num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                            eps=self.SN_eps)
        self.which_linear = functools.partial(layers.SNLinear,
                                              num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                              eps=self.SN_eps)
        self.which_embedding = functools.partial(layers.SNEmbedding,
                                                 num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                 eps=self.SN_eps)
        if bn_linear == 'SN':
            self.which_embedding = functools.partial(layers.SNLinear,
                                                     num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
    else:
        self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
        self.which_linear = nn.Linear
        # We use a non-spectral-normed embedding here regardless;
        # For some reason applying SN to G's embedding seems to randomly cripple G
        self.which_embedding = nn.Embedding
    if one_hot:
        self.which_embedding = functools.partial(layers.SNLinear,
                                                 num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                 eps=self.SN_eps)

    # Prepare model
    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
        self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
                                       out_channels=self.arch['out_channels'][index],
                                       which_conv=self.which_conv,
                                       wide=self.D_wide,
                                       activation=self.activation,
                                       preactivation=(index > 0),
                                       downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
        # If attention on this block, attach it to the end
        if self.arch['attention'][self.arch['resolution'][index]]:
            print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
            self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]
    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
    # Linear output layer. The output dimension is typically 1, but may be
    # larger if we're e.g. turning this into a VAE with an inference output
    self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
    # Embedding for projection discrimination
    self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

    # Initialize weights
    if self.init != 'none':
        self = init_weights(self, self.init)
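# Usage sketch (illustrative only, not part of the original model code). A hedged
# example of constructing the BigGAN-style discriminator above; it assumes the helper
# modules (layers.SNConv2d, layers.DBlock, D_arch, ...) are importable as in the rest
# of this file. The forward() pass is not shown here; in BigGAN-style code the
# self.embed table typically supplies a projection term, but that is an assumption.
netD = Discriminator(D_ch=64, resolution=128, n_class=1000, input_nc=3,
                     D_param='SN', init='ortho')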
def __init__(self, G_ch=64, style_dim=128, bottom_width=4, bottom_height=4, resolution=128,
             G_kernel_size=3, G_attn='64', n_class=1000,
             num_G_SVs=1, num_G_SV_itrs=1,
             G_shared=True, shared_dim=0, no_hier=False,
             cross_replica=False, mybn=False,
             G_activation=nn.ReLU(inplace=False),
             BN_eps=1e-5, SN_eps=1e-12, G_fp16=False,
             init='ortho', G_param='SN', norm_style='bn', bn_linear='embed',
             input_nc=3, one_hot=False, first_layer=False, one_hot_k=1):
    super(Generator, self).__init__()
    dim_z = style_dim
    self.name = 'G'
    # Use class only in first layer
    self.first_layer = first_layer
    # Use one hot vector representation for input class
    self.one_hot = one_hot
    # Use one hot k vector representation for input class if k is larger than 0.
    # If it's 0, simply use the class number and not a k-hot encoding.
    self.one_hot_k = one_hot_k
    # Channel width multiplier
    self.ch = G_ch
    # Dimensionality of the latent space
    self.dim_z = dim_z
    # The initial width dimension
    self.bottom_width = bottom_width
    # The initial height dimension
    self.bottom_height = bottom_height
    # Resolution of the output
    self.resolution = resolution
    # Kernel size?
    self.kernel_size = G_kernel_size
    # Attention?
    self.attention = G_attn
    # number of classes, for use in categorical conditional generation
    self.n_classes = n_class
    # Use shared embeddings?
    self.G_shared = G_shared
    # Dimensionality of the shared embedding? Unused if not using G_shared
    self.shared_dim = shared_dim if shared_dim > 0 else dim_z
    # Hierarchical latent space?
    self.hier = not no_hier
    # Cross replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # nonlinearity for residual blocks
    self.activation = G_activation
    # Initialization style
    self.init = init
    # Parameterization style
    self.G_param = G_param
    # Normalization style
    self.norm_style = norm_style
    # Epsilon for BatchNorm?
    self.BN_eps = BN_eps
    # Epsilon for Spectral Norm?
    self.SN_eps = SN_eps
    # fp16?
    self.fp16 = G_fp16
    # Architecture dict
    self.arch = G_arch(self.ch, self.attention)[resolution]
    self.bn_linear = bn_linear

    # If using hierarchical latents, adjust z
    if self.hier:
        # Number of places z slots into
        self.num_slots = len(self.arch['in_channels']) + 1
        self.z_chunk_size = (self.dim_z // self.num_slots)
        # Recalculate latent dimensionality for even splitting into chunks
        self.dim_z = self.z_chunk_size * self.num_slots
    else:
        self.num_slots = 1
        self.z_chunk_size = 0

    # Which convs, batchnorms, and linear layers to use
    if self.G_param == 'SN':
        self.which_conv = functools.partial(layers.SNConv2d,
                                            kernel_size=3, padding=1,
                                            num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                            eps=self.SN_eps)
        self.which_linear = functools.partial(layers.SNLinear,
                                              num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                              eps=self.SN_eps)
    else:
        self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
        self.which_linear = nn.Linear

    # We use a non-spectral-normed embedding here regardless;
    # For some reason applying SN to G's embedding seems to randomly cripple G
    if one_hot:
        self.which_embedding = functools.partial(layers.SNLinear,
                                                 num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                 eps=self.SN_eps)
    else:
        self.which_embedding = nn.Embedding

    bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                 else self.which_embedding)
    if self.bn_linear == 'SN':
        bn_linear = functools.partial(self.which_linear, bias=False)
    if self.G_shared:
        input_size = self.shared_dim + self.z_chunk_size
    elif self.hier:
        if self.first_layer:
            input_size = self.z_chunk_size
        else:
            input_size = self.n_classes + self.z_chunk_size
        self.which_bn = functools.partial(layers.ccbn,
                                          which_linear=bn_linear,
                                          cross_replica=self.cross_replica,
                                          mybn=self.mybn,
                                          input_size=input_size,
                                          norm_style=self.norm_style,
                                          eps=self.BN_eps)
    else:
        input_size = self.n_classes
        self.which_bn = functools.partial(layers.bn,
                                          cross_replica=self.cross_replica,
                                          mybn=self.mybn,
                                          eps=self.BN_eps)

    # Prepare model
    # If not using shared embeddings, self.shared is just a passthrough
    self.shared = (self.which_embedding(self.n_classes, self.shared_dim) if G_shared
                   else layers.identity())

    # First linear layer
    # The parameters for the first linear layer depend on the different input variations.
    if self.first_layer:
        # print('one_hot:{} one_hot_k:{}'.format(self.one_hot, self.one_hot_k))
        if self.one_hot:
            self.linear = self.which_linear(self.dim_z // self.num_slots + self.n_classes,
                                            self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
        else:
            self.linear = self.which_linear(self.dim_z // self.num_slots + 1,
                                            self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
        if self.one_hot_k == 1:
            self.linear = self.which_linear((self.dim_z // self.num_slots) * self.n_classes,
                                            self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
        if self.one_hot_k > 1:
            self.linear = self.which_linear(self.dim_z // self.num_slots + self.n_classes * self.one_hot_k,
                                            self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
        if self.one_hot_k == 0:
            self.linear = self.which_linear(self.n_classes,
                                            self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))
    else:
        self.linear = self.which_linear(self.dim_z // self.num_slots,
                                        self.arch['in_channels'][0] * (self.bottom_width * self.bottom_height))

    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    # while the inner loop is over a given block
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
        if 'kernel1' in self.arch.keys():
            padd1 = 1 if self.arch['kernel1'][index] > 1 else 0
            padd2 = 1 if self.arch['kernel2'][index] > 1 else 0
            conv1 = functools.partial(layers.SNConv2d,
                                      kernel_size=self.arch['kernel1'][index], padding=padd1,
                                      num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                      eps=self.SN_eps)
            conv2 = functools.partial(layers.SNConv2d,
                                      kernel_size=self.arch['kernel2'][index], padding=padd2,
                                      num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                      eps=self.SN_eps)
            self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
                                           out_channels=self.arch['out_channels'][index],
                                           which_conv1=conv1,
                                           which_conv2=conv2,
                                           which_bn=self.which_bn,
                                           activation=self.activation,
                                           upsample=(functools.partial(F.interpolate, scale_factor=self.arch['upsample'][index])
                                                     if index < len(self.arch['upsample']) else None))]]
        else:
            self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
                                           out_channels=self.arch['out_channels'][index],
                                           which_conv1=self.which_conv,
                                           which_conv2=self.which_conv,
                                           which_bn=self.which_bn,
                                           activation=self.activation,
                                           upsample=(functools.partial(F.interpolate, scale_factor=self.arch['upsample'][index])
                                                     if index < len(self.arch['upsample']) else None))]]

        # If attention on this block, attach it to the end
        # print('index ', index, self.arch['resolution'][index])
        if self.arch['attention'][self.arch['resolution'][index]]:
            print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
            self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    # output layer: batchnorm-relu-conv.
    # Consider using a non-spectral conv here
    self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                cross_replica=self.cross_replica,
                                                mybn=self.mybn),
                                      self.activation,
                                      self.which_conv(self.arch['out_channels'][-1], input_nc))

    # Initialize weights. Optionally skip init for testing.
    if self.init != 'none':
        self = init_weights(self, self.init)
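# Usage sketch (illustrative only, not part of the original model code). A hedged
# example of constructing the BigGAN-style generator above with example argument
# values; it assumes layers.GBlock, G_arch, etc. are importable as in the rest of this
# file. How z and the class conditioning are consumed by forward() is not shown here,
# so only construction and a latent draw are sketched. G_shared=False is passed so that
# self.which_bn is set by the hierarchical branch of the constructor above.
netG = Generator(G_ch=64, style_dim=128, bottom_width=4, bottom_height=4,
                 resolution=128, n_class=1000, G_shared=False, input_nc=3)
z = torch.randn(2, netG.dim_z)   # netG.dim_z may be rounded down for even chunking when hier is on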