def __init__(self, norm_ref, ch, n_shot, n_downsample_G):
    """Build the per-scale upsampling convs for reference image/label features.

    Args:
        norm_ref: normalization spec passed to each SPADEConv2d.
        ch: per-scale channel sizes; level i maps ch[i+1] -> ch[i].
        n_shot: number of reference shots (stored for use in forward).
        n_downsample_G: number of generator scales to build layers for.
    """
    super().__init__()
    # One upsampling conv pair (image stream + label stream) per scale.
    for level in range(n_downsample_G):
        c_lo, c_hi = ch[level], ch[level + 1]
        setattr(self, 'ref_img_up_%d' % level,
                SPADEConv2d(c_hi, c_lo, norm=norm_ref))
        setattr(self, 'ref_label_up_%d' % level,
                SPADEConv2d(c_hi, c_lo, norm=norm_ref))
    # Bookkeeping consumed elsewhere (e.g. by the forward pass).
    self.n_downsample_G = n_downsample_G
    self.n_shot = n_shot
def __init__(self, norm_ref, input_nc, nf, ch, n_shot, n_downsample_A):
    """Build query/key encoder stacks for the attention module.

    Args:
        norm_ref: normalization spec passed to each SPADEConv2d.
        input_nc: channel count of the raw input maps.
        nf: base feature width produced by the first convs.
        ch: per-scale channel sizes; level i maps ch[i] -> ch[i+1].
        n_shot: number of reference shots (stored for use in forward).
        n_downsample_A: number of stride-2 downsampling stages.
    """
    super().__init__()
    # Project raw inputs into nf-channel feature maps for both streams.
    self.atn_query_first = SPADEConv2d(input_nc, nf, norm=norm_ref)
    self.atn_key_first = SPADEConv2d(input_nc, nf, norm=norm_ref)
    # Mirrored stride-2 pyramids for the key and query streams.
    for stage in range(n_downsample_A):
        c_lo, c_hi = ch[stage], ch[stage + 1]
        setattr(self, 'atn_key_%d' % stage,
                SPADEConv2d(c_lo, c_hi, stride=2, norm=norm_ref))
        setattr(self, 'atn_query_%d' % stage,
                SPADEConv2d(c_lo, c_hi, stride=2, norm=norm_ref))
    # Bookkeeping consumed elsewhere.
    self.n_shot = n_shot
    self.n_downsample_A = n_downsample_A
def __init__(self, norm_ref, nf, ch, n_shot, n_downsample_G, n_downsample_A,
             isTrain, ref_nc=3):
    """Build a single-stream reference-image downsampling encoder.

    Args:
        norm_ref: normalization spec passed to each SPADEConv2d.
        nf: feature width after the first conv.
        ch: per-scale channel sizes; level i maps ch[i] -> ch[i+1].
        n_shot: number of reference shots (stored for use in forward).
        n_downsample_G: number of stride-2 downsampling stages to build.
        n_downsample_A: attention-module depth (stored for use in forward).
        isTrain: training-mode flag (stored for use in forward).
        ref_nc: channel count of the reference image (default RGB = 3).
    """
    super().__init__()
    # Stem conv: reference image -> nf feature channels.
    self.conv1 = SPADEConv2d(ref_nc, nf, norm=norm_ref)
    # Stride-2 pyramid over the generator's scales.
    for level in range(n_downsample_G):
        c_lo, c_hi = ch[level], ch[level + 1]
        setattr(self, 'ref_down_%d' % level,
                SPADEConv2d(c_lo, c_hi, stride=2, norm=norm_ref))
    # Bookkeeping consumed elsewhere.
    self.isTrain = isTrain
    self.n_shot = n_shot
    self.n_downsample_G = n_downsample_G
    self.n_downsample_A = n_downsample_A
def __init__(self, norm_ref, nf, ch, n_shot, n_downsample_G, n_downsample_A,
             isTrain, ref_nc=3, lmark_nc=1):
    """Build a two-stream (reference image + landmark) downsampling encoder.

    When more than one reference shot is used, extra fusion and attention
    layers are created at the attention resolution (level n_downsample_A - 1)
    to merge the per-shot feature maps.

    Args:
        norm_ref: normalization spec passed to each SPADE layer.
        nf: feature width after the stem convs.
        ch: per-scale channel sizes; level i maps ch[i] -> ch[i+1].
        n_shot: number of reference shots; > 1 enables fusion/attention layers.
        n_downsample_G: number of stride-2 downsampling stages to build.
        n_downsample_A: attention-module depth; fixes where fusion happens.
        isTrain: training-mode flag (stored for use in forward).
        ref_nc: channel count of the reference image (default RGB = 3).
        lmark_nc: channel count of the landmark map (default 1).
    """
    super().__init__()
    # Stem convs: one per input stream.
    self.conv1 = SPADEConv2d(ref_nc, nf, norm=norm_ref)
    self.conv2 = SPADEConv2d(lmark_nc, nf, norm=norm_ref)
    # Parallel stride-2 pyramids for the image and landmark streams.
    for level in range(n_downsample_G):
        c_lo, c_hi = ch[level], ch[level + 1]
        setattr(self, 'ref_down_img_%d' % level,
                SPADEConv2d(c_lo, c_hi, stride=2, norm=norm_ref))
        setattr(self, 'ref_down_lmark_%d' % level,
                SPADEConv2d(c_lo, c_hi, stride=2, norm=norm_ref))
        # Multi-shot case: fuse/attend over shots at the attention resolution.
        if n_shot > 1 and level == n_downsample_A - 1:
            self.fusion1 = SPADEConv2d(c_hi * 2, c_hi, norm=norm_ref)
            self.fusion2 = SPADEConv2d(c_hi * 2, c_hi, norm=norm_ref)
            self.fusion = SPADEResnetBlock(c_hi * 2, c_hi, norm=norm_ref)
            self.atten1 = SPADEConv2d(c_hi * n_shot, c_hi, norm=norm_ref)
            self.atten2 = SPADEConv2d(c_hi * n_shot, c_hi, norm=norm_ref)
    # Bookkeeping consumed elsewhere.
    self.isTrain = isTrain
    self.n_shot = n_shot
    self.n_downsample_G = n_downsample_G
    self.n_downsample_A = n_downsample_A
def __init__(self, opt):
    """Construct the full generator from the option namespace ``opt``.

    Builds, in order: hyper-parameter bookkeeping, the reference-image
    encoder, the adaptive SPADE/conv weight-generation FC stacks, the label
    embedding network, the main SPADE-resblock decoder, the multi-reference
    attention encoders, the optional KLD (VAE) heads, and the optional
    reference-flow network.
    """
    super().__init__()

    # ----------------------------- parameters -----------------------------
    self.opt = opt
    self.n_downsample_G = n_downsample_G = opt.n_downsample_G  # generator downsamplings
    self.n_downsample_A = n_downsample_A = opt.n_downsample_A  # attention downsamplings
    self.nf = nf = opt.ngf                                     # base channel size
    self.nf_max = nf_max = min(1024, nf * (2 ** n_downsample_G))
    # Channel size per scale, doubling each level and capped at nf_max.
    self.ch = ch = [min(nf_max, nf * (2 ** i)) for i in range(n_downsample_G + 2)]

    # SPADE configuration.
    self.norm = norm = opt.norm_G
    self.conv_ks = conv_ks = opt.conv_ks    # conv kernel size in main branch
    self.embed_ks = embed_ks = opt.embed_ks  # conv kernel size in embedding network
    self.spade_ks = spade_ks = opt.spade_ks  # conv kernel size in SPADE
    self.spade_combine = opt.spade_combine   # combine ref/prev frames with current using SPADE
    self.n_sc_layers = opt.n_sc_layers       # number of layers to perform spade combine
    # Hidden channel sizes in the SPADE modules: tripled (current/ref/prev)
    # for the layers where spade-combine is active, single otherwise.
    ch_hidden = []
    for i in range(n_downsample_G + 1):
        if not self.spade_combine or i >= self.n_sc_layers:
            ch_hidden += [[ch[i]]]
        else:
            ch_hidden += [[ch[i]] * 3]
    self.ch_hidden = ch_hidden

    # Adaptive SPADE / convolution configuration.
    self.adap_spade = opt.adaptive_spade  # adaptive weight generation for SPADE
    self.adap_embed = opt.adaptive_spade and not opt.no_adaptive_embed  # adaptive label embedding
    self.adap_conv = opt.adaptive_conv    # adaptive main-branch convolutions
    # -1 means "all generator layers".
    self.n_adaptive_layers = (opt.n_adaptive_layers
                              if opt.n_adaptive_layers != -1 else n_downsample_G)

    # Reference-image encoding: how the reference label map is used.
    self.concat_label_ref = 'concat' in opt.use_label_ref
    self.mul_label_ref = 'mul' in opt.use_label_ref
    self.sh_fix = self.sw_fix = 32  # spatial size after adaptive pooling
    # Spatial size at the generator bottleneck.
    self.sw = opt.fineSize // (2 ** opt.n_downsample_G)
    self.sh = int(self.sw / opt.aspect_ratio)

    # Weight generation.
    self.n_fc_layers = n_fc_layers = opt.n_fc_layers  # FC layers per weight generator

    # ------------------------------ network -------------------------------
    norm_ref = norm.replace('spade', '')
    input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
    ref_nc = opt.output_nc + (input_nc if self.concat_label_ref else 0)
    self.ref_img_first = SPADEConv2d(ref_nc, nf, norm=norm_ref)
    if self.mul_label_ref:
        self.ref_label_first = SPADEConv2d(input_nc, nf, norm=norm_ref)
    ref_conv = SPADEResnetBlock if opt.res_for_ref else SPADEConv2d

    # Reference image encoding: down/up pyramids (label stream optional).
    for i in range(n_downsample_G):
        ch_in, ch_out = ch[i], ch[i + 1]
        setattr(self, 'ref_img_down_%d' % i,
                ref_conv(ch_in, ch_out, stride=2, norm=norm_ref))
        setattr(self, 'ref_img_up_%d' % i,
                ref_conv(ch_out, ch_in, norm=norm_ref))
        if self.mul_label_ref:
            setattr(self, 'ref_label_down_%d' % i,
                    ref_conv(ch_in, ch_out, stride=2, norm=norm_ref))
            setattr(self, 'ref_label_up_%d' % i,
                    ref_conv(ch_out, ch_in, norm=norm_ref))

    # SPADE / main-branch weight generation: small FC stacks that emit the
    # flattened weights (+1 for bias) of each adaptive layer.
    if self.adap_spade or self.adap_conv:
        for i in range(self.n_adaptive_layers):
            ch_in, ch_out = ch[i], ch[i + 1]
            conv_ks2 = conv_ks ** 2
            embed_ks2 = embed_ks ** 2
            spade_ks2 = spade_ks ** 2
            ch_h = ch_hidden[i][0]

            fc_names, fc_outs = [], []
            if self.adap_spade:
                # "* 2" covers gamma and beta branches of SPADE.
                fc0_out = fcs_out = (ch_h * spade_ks2 + 1) * 2
                fc1_out = (ch_h * spade_ks2 + 1) * (1 if ch_in != ch_out else 2)
                fc_names += ['fc_spade_0', 'fc_spade_1', 'fc_spade_s']
                fc_outs += [fc0_out, fc1_out, fcs_out]
                if self.adap_embed:
                    fc_names += ['fc_spade_e']
                    fc_outs += [ch_in * embed_ks2 + 1]
            if self.adap_conv:
                fc0_out = ch_out * conv_ks2 + 1
                fc1_out = ch_in * conv_ks2 + 1
                fcs_out = ch_out + 1
                fc_names += ['fc_conv_0', 'fc_conv_1', 'fc_conv_s']
                fc_outs += [fc0_out, fc1_out, fcs_out]

            for n, name in enumerate(fc_names):
                # Input is either per-channel stats or pooled spatial features.
                fc_in = ch_out if self.mul_label_ref else self.sh_fix * self.sw_fix
                fc_layer = [sn(nn.Linear(fc_in, ch_out))]
                for _ in range(1, n_fc_layers):
                    fc_layer += [sn(nn.Linear(ch_out, ch_out))]
                fc_layer += [sn(nn.Linear(ch_out, fc_outs[n]))]
                setattr(self, '%s_%d' % (name, i), nn.Sequential(*fc_layer))

    # Label embedding network.
    self.label_embedding = LabelEmbedder(
        opt, input_nc, opt.netS,
        params_free_layers=(self.n_adaptive_layers if self.adap_embed else 0))

    # Main branch: SPADE resblocks from coarse to fine.
    for i in reversed(range(n_downsample_G + 1)):
        setattr(
            self, 'up_%d' % i,
            SPADEResnetBlock(
                ch[i + 1], ch[i], norm=norm, hidden_nc=ch_hidden[i],
                conv_ks=conv_ks, spade_ks=spade_ks,
                conv_params_free=(self.adap_conv and i < self.n_adaptive_layers),
                norm_params_free=(self.adap_spade and i < self.n_adaptive_layers)))
    self.conv_img = nn.Conv2d(nf, 3, kernel_size=3, padding=1)
    self.up = functools.partial(F.interpolate, scale_factor=2)

    # Attention encoders used to combine multiple reference images.
    if opt.n_shot > 1:
        self.atn_query_first = SPADEConv2d(input_nc, nf, norm=norm_ref)
        self.atn_key_first = SPADEConv2d(input_nc, nf, norm=norm_ref)
        for i in range(n_downsample_A):
            f_in, f_out = ch[i], ch[i + 1]
            setattr(self, 'atn_key_%d' % i,
                    SPADEConv2d(f_in, f_out, stride=2, norm=norm_ref))
            setattr(self, 'atn_query_%d' % i,
                    SPADEConv2d(f_in, f_out, stride=2, norm=norm_ref))

    # KLD (VAE-style) heads, enabled only when the KLD loss weight is positive.
    self.use_kld = opt.lambda_kld > 0
    self.z_dim = 256
    if self.use_kld:
        f_in = min(nf_max, nf * (2 ** n_downsample_G)) * self.sh * self.sw
        f_out = min(nf_max, nf * (2 ** n_downsample_G)) * self.sh * self.sw
        self.fc_mu_ref = nn.Linear(f_in, self.z_dim)
        self.fc_var_ref = nn.Linear(f_in, self.z_dim)
        self.fc = nn.Linear(self.z_dim, f_out)

    # Flow networks.
    self.warp_prev = False  # set True later when training on multiple frames
    # Warp the reference image and combine with the synthesized frame.
    self.warp_ref = opt.warp_ref and not opt.for_face
    if self.warp_ref:
        self.flow_network_ref = FlowGenerator(opt, 2)
        if self.spade_combine:
            self.img_ref_embedding = LabelEmbedder(opt, opt.output_nc + 1, opt.sc_arch)