def __init__(self, opts):
    """Construct all DRIT sub-networks, their Adam optimizers and the L1 loss.

    Args:
        opts: parsed command-line options (input dims, discriminator
            scale/norm/spectral-norm flags, concat mode, no_ms flag).
    """
    super(DRIT, self).__init__()

    # Fixed learning rates: the content discriminator trains slower.
    base_lr = 0.0001
    content_dis_lr = base_lr / 2.5

    self.nz = 8                  # attribute-code dimensionality
    self.concat = opts.concat
    self.no_ms = opts.no_ms

    # --- discriminators (two per image domain, plus one content dis) ---
    if opts.dis_scale > 1:
        def build_dis(input_dim):
            # Multi-scale discriminator when more than one scale requested.
            return networks.MultiScaleDis(input_dim, opts.dis_scale,
                                          norm=opts.dis_norm, sn=opts.dis_spectral_norm)
    else:
        def build_dis(input_dim):
            return networks.Dis(input_dim, norm=opts.dis_norm, sn=opts.dis_spectral_norm)

    self.disA = build_dis(opts.input_dim_a)
    self.disB = build_dis(opts.input_dim_b)
    self.disA2 = build_dis(opts.input_dim_a)
    self.disB2 = build_dis(opts.input_dim_b)
    self.disContent = networks.Dis_content()

    # --- encoders and generator (variant chosen by the concat flag) ---
    self.enc_c = networks.E_content(opts.input_dim_a, opts.input_dim_b)
    if self.concat:
        self.enc_a = networks.E_attr_concat(
            opts.input_dim_a, opts.input_dim_b, self.nz,
            norm_layer=None,
            nl_layer=networks.get_non_linearity(layer_type='lrelu'))
        self.gen = networks.G_concat(opts.input_dim_a, opts.input_dim_b, nz=self.nz)
    else:
        self.enc_a = networks.E_attr(opts.input_dim_a, opts.input_dim_b, self.nz)
        self.gen = networks.G(opts.input_dim_a, opts.input_dim_b, nz=self.nz)

    # --- optimizers: identical Adam settings, only the lr differs ---
    def build_adam(params, learning_rate):
        return torch.optim.Adam(params, lr=learning_rate,
                                betas=(0.5, 0.999), weight_decay=0.0001)

    self.disA_opt = build_adam(self.disA.parameters(), base_lr)
    self.disB_opt = build_adam(self.disB.parameters(), base_lr)
    self.disA2_opt = build_adam(self.disA2.parameters(), base_lr)
    self.disB2_opt = build_adam(self.disB2.parameters(), base_lr)
    self.disContent_opt = build_adam(self.disContent.parameters(), content_dis_lr)
    self.enc_c_opt = build_adam(self.enc_c.parameters(), base_lr)
    self.enc_a_opt = build_adam(self.enc_a.parameters(), base_lr)
    self.gen_opt = build_adam(self.gen.parameters(), base_lr)

    # Reconstruction loss used during training.
    self.criterionL1 = torch.nn.L1Loss()
def __init__(self, opts):
    """Construct all UID sub-networks, their Adam optimizers and loss functions.

    Args:
        opts: parsed command-line options (lr, input dims, discriminator
            settings, concat mode, lambdaB/lambdaI weights, percep choice).
    """
    super(UID, self).__init__()
    # FIX: the original wrapped this whole constructor in
    # `with torch.autograd.set_detect_anomaly(True):`.  Anomaly detection
    # only affects autograd recording/backward, and nothing here runs a
    # forward or backward pass, so the context was a debugging leftover.
    # Re-enable it around the training step if NaN gradients need tracing.

    # parameters
    lr = opts.lr
    self.nz = 8                  # attribute-code dimensionality
    self.concat = opts.concat
    self.lambdaB = opts.lambdaB
    self.lambdaI = opts.lambdaI

    # discriminators (two per domain)
    if opts.dis_scale > 1:
        self.disA = networks.MultiScaleDis(opts.input_dim_a, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disB = networks.MultiScaleDis(opts.input_dim_b, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disA2 = networks.MultiScaleDis(opts.input_dim_a, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disB2 = networks.MultiScaleDis(opts.input_dim_b, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
    else:
        self.disA = networks.Dis(opts.input_dim_a, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disB = networks.Dis(opts.input_dim_b, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disA2 = networks.Dis(opts.input_dim_a, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disB2 = networks.Dis(opts.input_dim_b, norm=opts.dis_norm, sn=opts.dis_spectral_norm)

    # encoders
    self.enc_c = networks.E_content(opts.input_dim_a, opts.input_dim_b)
    if self.concat:
        # NOTE(review): only input_dim_b is passed here, unlike the DRIT
        # variant which passes (input_dim_a, input_dim_b, nz) — verify this
        # matches E_attr_concat's signature in this project's networks module.
        self.enc_a = networks.E_attr_concat(opts.input_dim_b, self.nz,
                                            norm_layer=None,
                                            nl_layer=networks.get_non_linearity(layer_type='lrelu'))
    else:
        self.enc_a = networks.E_attr(opts.input_dim_a, opts.input_dim_b, self.nz)

    # generator
    if self.concat:
        self.gen = networks.G_concat(opts.input_dim_a, opts.input_dim_b, nz=self.nz)
    else:
        self.gen = networks.G(opts.input_dim_a, opts.input_dim_b, nz=self.nz)

    # optimizers: identical Adam settings for every module
    self.disA_opt = torch.optim.Adam(self.disA.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
    self.disB_opt = torch.optim.Adam(self.disB.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
    self.disA2_opt = torch.optim.Adam(self.disA2.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
    self.disB2_opt = torch.optim.Adam(self.disB2.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
    self.enc_c_opt = torch.optim.Adam(self.enc_c.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
    self.enc_a_opt = torch.optim.Adam(self.enc_a.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
    self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)

    # Setup the loss functions for training
    self.criterionL1 = torch.nn.L1Loss()
    if opts.percep == 'default':
        self.perceptualLoss = networks.PerceptualLoss(nn.MSELoss(), opts.gpu, opts.percp_layer)
    elif opts.percep == 'face':
        self.perceptualLoss = networks.PerceptualLoss16(nn.MSELoss(), opts.gpu, opts.percp_layer)
    else:
        self.perceptualLoss = networks.MultiPerceptualLoss(nn.MSELoss(), opts.gpu)
def __init__(self, opts):
    """Construct SAVI2I encoders, mapping network, generator and — in the
    'train' phase — discriminators, optimizers and losses.

    Args:
        opts: parsed options (gpu index, phase, type 0/1, img_size,
            input_dim, input_nz, style_dim, num_domains, lr, f_lr).
    """
    super(SAVI2I, self).__init__()
    self.opts = opts

    # Device selection.  BUGFIX: torch.cuda.set_device() was previously
    # reached even with opts.gpu < 0 (CPU mode), which raises; all CUDA-only
    # setup is now guarded by the gpu check.
    if opts.gpu >= 0:
        self.device = torch.device('cuda:%d' % opts.gpu)
        torch.cuda.set_device(opts.gpu)
        cudnn.benchmark = True
    else:
        self.device = torch.device('cpu')

    self.phase = opts.phase
    self.type = opts.type            # 1 = style translation, 0 = shape translation
    self.nz = opts.input_nz
    self.style_dim = opts.style_dim
    self.num_domains = opts.num_domains

    # Networks shared by both types.
    self.enc_a = nn.DataParallel(
        networks.E_attr(img_size=opts.img_size, input_dim=opts.input_dim,
                        nz=self.nz, n_domains=self.num_domains))
    self.f = nn.DataParallel(
        networks.MappingNetwork(nz=self.nz, n_domains=self.num_domains,
                                n_style=self.style_dim, hidden_dim=512, hidden_layer=1))
    self.vgg = networks.VGG(self.device)

    # Type-specific content encoder and generator.
    if self.type == 1:
        self.enc_c = nn.DataParallel(networks.E_content_style(img_size=opts.img_size, input_dim=opts.input_dim))
        self.gen = nn.DataParallel(networks.Generator_style(img_size=opts.img_size, style_dim=self.style_dim))
    elif self.type == 0:
        self.enc_c = nn.DataParallel(networks.E_content_shape(img_size=opts.img_size, input_dim=opts.input_dim))
        self.gen = nn.DataParallel(networks.Generator_shape(img_size=opts.img_size, style_dim=self.style_dim))
    else:
        # Previously an unknown type silently left enc_c/gen undefined and
        # crashed later with AttributeError; fail fast with a clear message.
        raise ValueError('opts.type must be 0 (shape) or 1 (style), got %r' % (opts.type,))

    if self.phase == 'train':
        self.lr = opts.lr
        self.f_lr = opts.f_lr
        self.lr_dcontent = self.lr / 2.5   # content discriminator trains slower

        self.dis = nn.DataParallel(networks.Discriminator(img_size=opts.img_size, num_domains=self.num_domains))
        if self.type == 1:
            self.disContent = nn.DataParallel(networks.Dis_content_style(c_dim=self.num_domains))
        elif self.type == 0:
            self.disContent = nn.DataParallel(networks.Dis_content_shape(c_dim=self.num_domains))

        # Optimizers: Adam with betas=(0, 0.99) and small weight decay.
        self.dis_opt = torch.optim.Adam(self.dis.parameters(), lr=self.lr, betas=(0, 0.99), weight_decay=0.0001)
        self.enc_c_opt = torch.optim.Adam(self.enc_c.parameters(), lr=self.lr, betas=(0, 0.99), weight_decay=0.0001)
        self.enc_a_opt = torch.optim.Adam(self.enc_a.parameters(), lr=self.lr, betas=(0, 0.99), weight_decay=0.0001)
        self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=self.lr, betas=(0, 0.99), weight_decay=0.0001)
        self.f_opt = torch.optim.Adam(self.f.parameters(), lr=self.f_lr, betas=(0, 0.99), weight_decay=0.0001)
        self.disContent_opt = torch.optim.Adam(self.disContent.parameters(), lr=self.lr_dcontent, betas=(0, 0.99), weight_decay=0.0001)

        # Losses.
        self.criterion_GAN = nn.BCEWithLogitsLoss()
        self.criterion_mmd = utils.get_mmd_loss()
def __init__(self, opts):
    """Construct this DRIT variant's networks, optimizers and loss functions.

    Args:
        opts: parsed options (per-module learning rates, loss weights,
            input dims, discriminator settings, concat/no_ms flags,
            contrastive-loss margin/measure/max_violation).
    """
    super(DRIT, self).__init__()

    # parameters: per-module learning rates
    lr = 0.0001
    self.nz = 8                  # attribute-code dimensionality
    self.concat = opts.concat
    self.lr_subspace = opts.lr_subspace
    self.lr_dis = opts.lr_dis
    self.lr_enc = opts.lr_enc
    self.lr_gen = opts.lr_gen
    self.lr_gen_attr = opts.lr_gen_attr
    self.lr_pre_subspace = opts.lr_pre_subspace
    self.lr_pre_enc = opts.lr_pre_enc
    self.lr_pre_gen = opts.lr_pre_gen

    # loss weights and options
    self.margin = opts.margin
    self.semantic_w = opts.semantic_w
    self.recon_w = opts.recon_w
    self.gan_w = opts.gan_w
    self.content_w = opts.content_w
    self.no_ms = opts.no_ms

    self.loss_1 = nn.BCEWithLogitsLoss()
    # Network that maps image/text features into the common subspace.
    self.subspace = model_1.IDCM_NN(img_input_dim=4096, text_input_dim=300)
    self.criterion = ContrastiveLoss(margin=opts.margin, measure=opts.measure,
                                     max_violation=opts.max_violation)
    # NOTE(review): hard-codes GPU 0 and fails on CPU-only machines —
    # consider deriving the device from opts instead.
    one = torch.tensor(1, dtype=torch.float).cuda(0)
    self.mone = one * -1

    # discriminators
    if opts.dis_scale > 1:
        self.disA = networks.MultiScaleDis(opts.input_dim_a, n_scale=opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disB = networks.MultiScaleDis(opts.input_dim_b, n_scale=opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disA_attr = networks.MultiScaleDis(opts.input_dim_a, n_scale=opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disB_attr = networks.MultiScaleDis(opts.input_dim_b, n_scale=opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
    else:
        self.disA = networks.Dis(opts.input_dim_a, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disB = networks.Dis(opts.input_dim_b, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        # BUGFIX: disA_attr/disB_attr were only created in the multi-scale
        # branch, but their optimizers below are built unconditionally, so
        # dis_scale <= 1 crashed with AttributeError.  Create single-scale
        # attribute discriminators here, mirroring disA/disB.
        self.disA_attr = networks.Dis(opts.input_dim_a, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disB_attr = networks.Dis(opts.input_dim_b, norm=opts.dis_norm, sn=opts.dis_spectral_norm)

    # encoders
    self.enc_c = networks.E_content(opts.input_dim_a, opts.input_dim_b)
    if self.concat:
        self.enc_a = networks.E_attr_concat(opts.input_dim_a, opts.input_dim_b, self.nz,
                                            norm_layer=None,
                                            nl_layer=networks.get_non_linearity(layer_type='lrelu'))
    else:
        self.enc_a = networks.E_attr(opts.input_dim_a, opts.input_dim_b, self.nz)

    # generators
    if self.concat:
        self.gen = networks.G_concat(opts.input_dim_a, opts.input_dim_b, nz=self.nz)
    else:
        self.gen = networks.G(opts.input_dim_a, opts.input_dim_b, nz=self.nz)
    self.gen_attr = networks.G_a(opts, opts.input_dim_a, opts.input_dim_b)

    # optimizers: Adam, betas=(0.5, 0.999), weight_decay=1e-4 everywhere
    self.subspace_opt = torch.optim.Adam(self.subspace.parameters(), lr=self.lr_subspace, betas=(0.5, 0.999), weight_decay=0.0001)
    self.disA_opt = torch.optim.Adam(self.disA.parameters(), lr=self.lr_dis, betas=(0.5, 0.999), weight_decay=0.0001)
    self.disB_opt = torch.optim.Adam(self.disB.parameters(), lr=self.lr_dis, betas=(0.5, 0.999), weight_decay=0.0001)
    self.disA_attr_opt = torch.optim.Adam(self.disA_attr.parameters(), lr=self.lr_dis, betas=(0.5, 0.999), weight_decay=0.0001)
    self.disB_attr_opt = torch.optim.Adam(self.disB_attr.parameters(), lr=self.lr_dis, betas=(0.5, 0.999), weight_decay=0.0001)
    self.enc_c_opt = torch.optim.Adam(self.enc_c.parameters(), lr=self.lr_enc, betas=(0.5, 0.999), weight_decay=0.0001)
    self.enc_a_opt = torch.optim.Adam(self.enc_a.parameters(), lr=self.lr_enc, betas=(0.5, 0.999), weight_decay=0.0001)
    self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=self.lr_gen, betas=(0.5, 0.999), weight_decay=0.0001)
    self.gen_attr_opt = torch.optim.Adam(self.gen_attr.parameters(), lr=self.lr_gen_attr, betas=(0.5, 0.999), weight_decay=0.0001)
    # pre-training optimizers (same modules, pre-training learning rates)
    self.subspace_pre_opt = torch.optim.Adam(self.subspace.parameters(), lr=self.lr_pre_subspace, betas=(0.5, 0.999), weight_decay=0.0001)
    self.enc_c_pre_opt = torch.optim.Adam(self.enc_c.parameters(), lr=self.lr_pre_enc, betas=(0.5, 0.999), weight_decay=0.0001)
    self.enc_a_pre_opt = torch.optim.Adam(self.enc_a.parameters(), lr=self.lr_pre_enc, betas=(0.5, 0.999), weight_decay=0.0001)
    self.gen_pre_opt = torch.optim.Adam(self.gen.parameters(), lr=self.lr_pre_gen, betas=(0.5, 0.999), weight_decay=0.0001)

    # Setup the loss functions for training.
    self.criterionL1 = torch.nn.L1Loss()
    # FIX: 'reduce'/'size_average' kwargs are deprecated (removed in newer
    # torch); reduce=True + size_average=True is exactly reduction='mean'.
    self.MSE_loss_fn = torch.nn.MSELoss(reduction='mean')