Example #1
File: model.py  Project: jiaoyiping630/DRIT
    def __init__(self, opts):
        super(DRIT, self).__init__()

        # parameters
        lr = 0.0001
        lr_dcontent = lr / 2.5
        self.nz = 8
        self.concat = opts.concat
        self.no_ms = opts.no_ms

        # discriminators
        if opts.dis_scale > 1:
            self.disA = networks.MultiScaleDis(opts.input_dim_a, opts.dis_scale, norm=opts.dis_norm,
                                               sn=opts.dis_spectral_norm)
            self.disB = networks.MultiScaleDis(opts.input_dim_b, opts.dis_scale, norm=opts.dis_norm,
                                               sn=opts.dis_spectral_norm)
            self.disA2 = networks.MultiScaleDis(opts.input_dim_a, opts.dis_scale, norm=opts.dis_norm,
                                                sn=opts.dis_spectral_norm)
            self.disB2 = networks.MultiScaleDis(opts.input_dim_b, opts.dis_scale, norm=opts.dis_norm,
                                                sn=opts.dis_spectral_norm)
        else:
            self.disA = networks.Dis(opts.input_dim_a, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
            self.disB = networks.Dis(opts.input_dim_b, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
            self.disA2 = networks.Dis(opts.input_dim_a, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
            self.disB2 = networks.Dis(opts.input_dim_b, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disContent = networks.Dis_content()

        # encoders
        self.enc_c = networks.E_content(opts.input_dim_a, opts.input_dim_b)
        if self.concat:
            self.enc_a = networks.E_attr_concat(opts.input_dim_a, opts.input_dim_b, self.nz,
                                                norm_layer=None,
                                                nl_layer=networks.get_non_linearity(layer_type='lrelu'))
        else:
            self.enc_a = networks.E_attr(opts.input_dim_a, opts.input_dim_b, self.nz)

        # generator
        if self.concat:
            self.gen = networks.G_concat(opts.input_dim_a, opts.input_dim_b, nz=self.nz)
        else:
            self.gen = networks.G(opts.input_dim_a, opts.input_dim_b, nz=self.nz)

        # optimizers
        self.disA_opt = torch.optim.Adam(self.disA.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
        self.disB_opt = torch.optim.Adam(self.disB.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
        self.disA2_opt = torch.optim.Adam(self.disA2.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
        self.disB2_opt = torch.optim.Adam(self.disB2.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
        self.disContent_opt = torch.optim.Adam(self.disContent.parameters(), lr=lr_dcontent, betas=(0.5, 0.999),
                                               weight_decay=0.0001)
        self.enc_c_opt = torch.optim.Adam(self.enc_c.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
        self.enc_a_opt = torch.optim.Adam(self.enc_a.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
        self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)

        # Setup the loss function for training
        self.criterionL1 = torch.nn.L1Loss()
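
The constructor above expects an argparse-style opts namespace. A minimal usage sketch follows; the option values are illustrative assumptions, not the project's defaults:

# Hypothetical usage: build an opts namespace and instantiate the model.
# Field names follow the constructor above; the values are assumptions for illustration.
import argparse

opts = argparse.Namespace(
    concat=1,               # use the concat-style attribute encoder/generator
    no_ms=False,            # keep the mode-seeking option enabled
    dis_scale=3,            # > 1 selects the multi-scale discriminators
    dis_norm='None',
    dis_spectral_norm=False,
    input_dim_a=3,
    input_dim_b=3,
)
model = DRIT(opts)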
Example #2
  def __init__(self, opts):
    super(UID, self).__init__()
    with torch.autograd.set_detect_anomaly(True):
      # parameters
      lr = opts.lr
      self.nz = 8
      self.concat = opts.concat
      self.lambdaB = opts.lambdaB
      self.lambdaI = opts.lambdaI

      # discriminators
      if opts.dis_scale > 1:
        self.disA = networks.MultiScaleDis(opts.input_dim_a, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disB = networks.MultiScaleDis(opts.input_dim_b, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disA2 = networks.MultiScaleDis(opts.input_dim_a, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disB2 = networks.MultiScaleDis(opts.input_dim_b, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
      else:
        self.disA = networks.Dis(opts.input_dim_a, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disB = networks.Dis(opts.input_dim_b, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disA2 = networks.Dis(opts.input_dim_a, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
        self.disB2 = networks.Dis(opts.input_dim_b, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
          
      # encoders
      self.enc_c = networks.E_content(opts.input_dim_a, opts.input_dim_b)
      if self.concat:
        self.enc_a = networks.E_attr_concat(opts.input_dim_b, self.nz,
            norm_layer=None, nl_layer=networks.get_non_linearity(layer_type='lrelu'))
      else:
        self.enc_a = networks.E_attr(opts.input_dim_a, opts.input_dim_b, self.nz)

      # generator
      if self.concat:
        self.gen = networks.G_concat(opts.input_dim_a, opts.input_dim_b, nz=self.nz)
      else:
        self.gen = networks.G(opts.input_dim_a, opts.input_dim_b, nz=self.nz)

      # optimizers
      self.disA_opt = torch.optim.Adam(self.disA.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
      self.disB_opt = torch.optim.Adam(self.disB.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
      self.disA2_opt = torch.optim.Adam(self.disA2.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
      self.disB2_opt = torch.optim.Adam(self.disB2.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)

      self.enc_c_opt = torch.optim.Adam(self.enc_c.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
      self.enc_a_opt = torch.optim.Adam(self.enc_a.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
      self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)

      # Setup the loss function for training
      self.criterionL1 = torch.nn.L1Loss()
      if opts.percep == 'default':
          self.perceptualLoss = networks.PerceptualLoss(nn.MSELoss(), opts.gpu, opts.percp_layer)
      elif opts.percep == 'face':
          self.perceptualLoss = networks.PerceptualLoss16(nn.MSELoss(), opts.gpu, opts.percp_layer)
      else:
          self.perceptualLoss = networks.MultiPerceptualLoss(nn.MSELoss(), opts.gpu)
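
Note that torch.autograd.set_detect_anomaly(True) used as a context manager only affects autograd operations executed inside the with block; since no backward pass runs during __init__, wrapping the constructor as above has little practical effect. A minimal, self-contained sketch of the placement that actually exercises anomaly checking:

import torch

# Anomaly detection inspects the backward pass, so it should wrap code that
# calls .backward() (or be enabled globally before training starts).
torch.autograd.set_detect_anomaly(True)

x = torch.randn(4, requires_grad=True)
y = (x * 2).sum()
y.backward()   # runs with anomaly checking enabled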
Example #3
  def __init__(self, opts):
    super(SAVI2I, self).__init__()
    self.opts = opts
    if opts.gpu >= 0:
        # pin the CUDA device and enable cudnn autotuning only when a GPU is requested
        self.device = torch.device('cuda:%d' % opts.gpu)
        torch.cuda.set_device(opts.gpu)
        cudnn.benchmark = True
    else:
        self.device = torch.device('cpu')

    self.phase = opts.phase
    self.type = opts.type
    self.nz = opts.input_nz
    self.style_dim = opts.style_dim
    self.num_domains = opts.num_domains


    self.enc_a = nn.DataParallel(
        networks.E_attr(img_size=opts.img_size, input_dim=opts.input_dim, nz=self.nz, n_domains=self.num_domains))
    self.f = nn.DataParallel(
        networks.MappingNetwork(nz=self.nz, n_domains=self.num_domains, n_style=self.style_dim, hidden_dim=512,
                                hidden_layer=1))
    self.vgg = networks.VGG(self.device)
    if self.type==1:
      self.enc_c = nn.DataParallel(networks.E_content_style(img_size=opts.img_size, input_dim=opts.input_dim))
      self.gen = nn.DataParallel(networks.Generator_style(img_size=opts.img_size, style_dim=self.style_dim))
    elif self.type==0:
      self.enc_c = nn.DataParallel(networks.E_content_shape(img_size=opts.img_size, input_dim=opts.input_dim))
      self.gen = nn.DataParallel(networks.Generator_shape(img_size=opts.img_size, style_dim=self.style_dim))

    if self.phase == 'train':
      self.lr = opts.lr
      self.f_lr = opts.f_lr
      self.lr_dcontent = self.lr/2.5
      self.dis = nn.DataParallel(networks.Discriminator(img_size=opts.img_size, num_domains=self.num_domains))
      if self.type==1:
        self.disContent = nn.DataParallel(networks.Dis_content_style(c_dim=self.num_domains))
      elif self.type==0:
        self.disContent = nn.DataParallel(networks.Dis_content_shape(c_dim=self.num_domains))
      self.dis_opt = torch.optim.Adam(self.dis.parameters(), lr=self.lr, betas=(0, 0.99), weight_decay=0.0001)
      self.enc_c_opt = torch.optim.Adam(self.enc_c.parameters(), lr=self.lr, betas=(0, 0.99), weight_decay=0.0001)
      self.enc_a_opt = torch.optim.Adam(self.enc_a.parameters(), lr=self.lr, betas=(0, 0.99), weight_decay=0.0001)
      self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=self.lr, betas=(0, 0.99), weight_decay=0.0001)
      self.f_opt = torch.optim.Adam(self.f.parameters(), lr=self.f_lr, betas=(0, 0.99), weight_decay=0.0001)
      self.disContent_opt = torch.optim.Adam(self.disContent.parameters(), lr=self.lr_dcontent, betas=(0, 0.99), weight_decay=0.0001)

      self.criterion_GAN = nn.BCEWithLogitsLoss()
      self.criterion_mmd = utils.get_mmd_loss()
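
All sub-networks in this example are wrapped in nn.DataParallel, so the underlying network is reachable through the .module attribute; saving .module.state_dict() keeps checkpoint keys free of the "module." prefix. A small self-contained sketch of that behavior (the Linear layer is a toy stand-in for the networks above):

import torch
import torch.nn as nn

# nn.DataParallel wraps a module; the original network is available as .module.
net = nn.DataParallel(nn.Linear(8, 4))
print(type(net.module))                        # torch.nn.Linear
torch.save(net.module.state_dict(), 'net.pth') # keys without the "module." prefix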
Example #4
    def __init__(self, opts):
        super(DRIT, self).__init__()

        # parameters
        lr = 0.0001
        self.nz = 8
        self.concat = opts.concat
        self.lr_subspace = opts.lr_subspace
        #self.lr_MI=opts.lr_MI
        self.lr_dis = opts.lr_dis
        self.lr_enc = opts.lr_enc
        self.lr_gen = opts.lr_gen
        self.lr_gen_attr = opts.lr_gen_attr

        self.lr_pre_subspace = opts.lr_pre_subspace
        #self.lr_pre_MI = opts.lr_pre_MI
        self.lr_pre_enc = opts.lr_pre_enc
        self.lr_pre_gen = opts.lr_pre_gen

        self.margin = opts.margin
        self.semantic_w = opts.semantic_w
        self.recon_w = opts.recon_w
        #self.MI_w = opts.MI_w
        self.gan_w = opts.gan_w
        self.content_w = opts.content_w
        self.no_ms = opts.no_ms
        self.loss_1 = nn.BCEWithLogitsLoss()
        #映射到公共子空间的网络
        self.subspace = model_1.IDCM_NN(img_input_dim=4096, text_input_dim=300)
        #self.MI=KLLoss()
        self.criterion = ContrastiveLoss(margin=opts.margin,
                                         measure=opts.measure,
                                         max_violation=opts.max_violation)
        one = torch.tensor(1, dtype=torch.float).cuda(0)
        self.mone = one * -1

        # discriminators
        if opts.dis_scale > 1:  #=3
            self.disA = networks.MultiScaleDis(opts.input_dim_a,
                                               n_scale=opts.dis_scale,
                                               norm=opts.dis_norm,
                                               sn=opts.dis_spectral_norm)
            self.disB = networks.MultiScaleDis(opts.input_dim_b,
                                               n_scale=opts.dis_scale,
                                               norm=opts.dis_norm,
                                               sn=opts.dis_spectral_norm)
            self.disA_attr = networks.MultiScaleDis(opts.input_dim_a,
                                                    n_scale=opts.dis_scale,
                                                    norm=opts.dis_norm,
                                                    sn=opts.dis_spectral_norm)
            self.disB_attr = networks.MultiScaleDis(opts.input_dim_b,
                                                    n_scale=opts.dis_scale,
                                                    norm=opts.dis_norm,
                                                    sn=opts.dis_spectral_norm)

        else:
            self.disA = networks.Dis(opts.input_dim_a,
                                     norm=opts.dis_norm,
                                     sn=opts.dis_spectral_norm)
            self.disB = networks.Dis(opts.input_dim_b,
                                     norm=opts.dis_norm,
                                     sn=opts.dis_spectral_norm)

        # encoders
        self.enc_c = networks.E_content(opts.input_dim_a, opts.input_dim_b)
        if self.concat:
            self.enc_a = networks.E_attr_concat(opts.input_dim_a, opts.input_dim_b, self.nz,
                norm_layer=None, nl_layer=networks.get_non_linearity(layer_type='lrelu'))
        else:
            self.enc_a = networks.E_attr(opts.input_dim_a, opts.input_dim_b,
                                         self.nz)

        # generator
        if self.concat:
            self.gen = networks.G_concat(opts.input_dim_a,
                                         opts.input_dim_b,
                                         nz=self.nz)
        else:
            self.gen = networks.G(opts.input_dim_a,
                                  opts.input_dim_b,
                                  nz=self.nz)

        self.gen_attr = networks.G_a(opts, opts.input_dim_a, opts.input_dim_b)

        # optimizers
        self.subspace_opt = torch.optim.Adam(self.subspace.parameters(),
                                             lr=self.lr_subspace,
                                             betas=(0.5, 0.999),
                                             weight_decay=0.0001)
        #self.MI_opt = torch.optim.Adam(self.MI.parameters(), lr=self.lr_MI, betas=(0.5, 0.999), weight_decay=0.0001)
        self.disA_opt = torch.optim.Adam(self.disA.parameters(),
                                         lr=self.lr_dis,
                                         betas=(0.5, 0.999),
                                         weight_decay=0.0001)
        self.disB_opt = torch.optim.Adam(self.disB.parameters(),
                                         lr=self.lr_dis,
                                         betas=(0.5, 0.999),
                                         weight_decay=0.0001)
        self.disA_attr_opt = torch.optim.Adam(self.disA_attr.parameters(),
                                              lr=self.lr_dis,
                                              betas=(0.5, 0.999),
                                              weight_decay=0.0001)
        self.disB_attr_opt = torch.optim.Adam(self.disB_attr.parameters(),
                                              lr=self.lr_dis,
                                              betas=(0.5, 0.999),
                                              weight_decay=0.0001)
        self.enc_c_opt = torch.optim.Adam(self.enc_c.parameters(),
                                          lr=self.lr_enc,
                                          betas=(0.5, 0.999),
                                          weight_decay=0.0001)
        self.enc_a_opt = torch.optim.Adam(self.enc_a.parameters(),
                                          lr=self.lr_enc,
                                          betas=(0.5, 0.999),
                                          weight_decay=0.0001)
        self.gen_opt = torch.optim.Adam(self.gen.parameters(),
                                        lr=self.lr_gen,
                                        betas=(0.5, 0.999),
                                        weight_decay=0.0001)
        self.gen_attr_opt = torch.optim.Adam(self.gen_attr.parameters(),
                                             lr=self.lr_gen_attr,
                                             betas=(0.5, 0.999),
                                             weight_decay=0.0001)

        self.subspace_pre_opt = torch.optim.Adam(self.subspace.parameters(),
                                                 lr=self.lr_pre_subspace,
                                                 betas=(0.5, 0.999),
                                                 weight_decay=0.0001)
        #self.MI_pre_opt = torch.optim.Adam(self.MI.parameters(), lr=self.lr_pre_MI, betas=(0.5, 0.999), weight_decay=0.0001)
        self.enc_c_pre_opt = torch.optim.Adam(self.enc_c.parameters(),
                                              lr=self.lr_pre_enc,
                                              betas=(0.5, 0.999),
                                              weight_decay=0.0001)
        self.enc_a_pre_opt = torch.optim.Adam(self.enc_a.parameters(),
                                              lr=self.lr_pre_enc,
                                              betas=(0.5, 0.999),
                                              weight_decay=0.0001)
        self.gen_pre_opt = torch.optim.Adam(self.gen.parameters(),
                                            lr=self.lr_pre_gen,
                                            betas=(0.5, 0.999),
                                            weight_decay=0.0001)

        # Setup the loss function for training
        self.criterionL1 = torch.nn.L1Loss()
        self.MSE_loss_fn = torch.nn.MSELoss(reduction='mean')  # mean reduction; replaces the deprecated reduce/size_average flags
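
With this many sub-networks, a common pattern is to keep the optimizers in a dict so zero_grad and step can be applied uniformly. A self-contained sketch of the idea; the two Linear layers are toy stand-ins for the encoders, generators, and discriminators above:

import torch
import torch.nn as nn

# Toy stand-ins; a real model would register enc_c, enc_a, gen, the
# discriminators, etc. in the same dictionary.
nets = {'enc': nn.Linear(8, 8), 'gen': nn.Linear(8, 8)}
optimizers = {
    name: torch.optim.Adam(net.parameters(), lr=1e-4, betas=(0.5, 0.999), weight_decay=1e-4)
    for name, net in nets.items()
}

x = torch.randn(2, 8)
loss = nets['gen'](nets['enc'](x)).pow(2).mean()

for opt in optimizers.values():
    opt.zero_grad()
loss.backward()
for opt in optimizers.values():
    opt.step()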