def init_loss(self, opt):
    """Create the loss criteria and the Adam optimizers for netG and netD.

    Args:
        opt: option namespace; reads ``ganloss``, ``restruction_loss``,
            ``map_m_type``, ``lr``, ``beta1`` and ``beta2``.

    Raises:
        ValueError: if ``opt.ganloss`` is neither ``'gan'`` nor ``'lsgan'``.
    """
    Base_Model.init_loss(self, opt)

    # #####################
    # define loss functions
    # #####################
    # GAN loss: 'lsgan' selects the least-squares variant of GANLoss.
    if opt.ganloss == 'gan':
        self.criterionGAN = GANLoss(use_lsgan=False).to(self.device)
    elif opt.ganloss == 'lsgan':
        self.criterionGAN = GANLoss(use_lsgan=True).to(self.device)
    else:
        raise ValueError(
            "Unsupported ganloss {!r}; expected 'gan' or 'lsgan'".format(
                opt.ganloss))

    # identity loss
    self.criterionIdt = RestructionLoss(opt.restruction_loss).to(self.device)

    # feature metric loss (moved to the device like the other criteria)
    self.criterionFea = torch.nn.L1Loss().to(self.device)

    # map loss
    self.criterionMap = RestructionLoss(opt.map_m_type).to(self.device)

    # #####################
    # initialize optimizers
    # #####################
    # G and D share the same learning rate and Adam betas.
    self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                        lr=opt.lr, betas=(opt.beta1, opt.beta2))
    self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                        lr=opt.lr, betas=(opt.beta1, opt.beta2))
    self.optimizers = [self.optimizer_G, self.optimizer_D]
    def init_loss(self, opt):
        """Set up the reconstruction criteria and the generator's Adam optimizer.

        Only the generator is optimised here, so ``self.optimizers`` ends up
        holding just ``self.optimizer_G``. When ``self.opt.weight_decay_if`` is
        truthy, Adam is additionally given ``weight_decay=1e-4``.
        """
        Base_Model.init_loss(self, opt)

        # -- loss functions -------------------------------------------------
        # identity loss: loss type plus reduction mode
        self.criterionIdt = RestructionLoss(
            opt.idt_loss, opt.idt_reduction).to(self.device)
        # projection-map loss
        self.criterionMap = RestructionLoss(
            opt.map_projection_loss).to(self.device)

        # -- optimizers -----------------------------------------------------
        self.optimizers = []
        adam_kwargs = {'lr': opt.lr, 'betas': (opt.beta1, opt.beta2)}
        if self.opt.weight_decay_if:
            # optional L2 penalty on the generator weights
            adam_kwargs['weight_decay'] = 1e-4
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            **adam_kwargs)
        self.optimizers.append(self.optimizer_G)
示例#3
0
  def init_loss(self, opt):
    """Build the loss criteria and the Adam optimizers for netG and netD.

    NOTE(review): the GAN criterion is always the least-squares variant here;
    ``opt.ganloss`` is not consulted in this method -- confirm intended.
    """
    Base_Model.init_loss(self, opt)

    device = self.device

    # ---- loss functions ----
    self.criterionGAN = GANLoss(use_lsgan=True).to(device)  # GAN loss
    self.criterionIdt = RestructionLoss(
      opt.idt_loss, opt.idt_reduction).to(device)           # identity loss
    self.criterionFea = torch.nn.L1Loss().to(device)        # feature metric loss
    self.criterionMap = RestructionLoss(
      opt.map_projection_loss).to(device)                   # map loss

    # ---- optimizers: G and D share lr and betas ----
    adam_betas = (opt.beta1, opt.beta2)
    self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                        lr=opt.lr, betas=adam_betas)
    self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                        lr=opt.lr, betas=adam_betas)
    self.optimizers = [self.optimizer_G, self.optimizer_D]
  def init_network(self, opt):
    """Declare bookkeeping names and build netG (plus netD when training).

    Args:
      opt: option namespace holding the network hyper-parameters.

    Raises:
      AssertionError: if ``opt.multi_view`` is empty.
      ValueError: when training with an ``opt.ganloss`` other than
        ``'gan'`` or ``'lsgan'``.
    """
    Base_Model.init_network(self, opt)

    self.if_pool = opt.if_pool
    self.multi_view = opt.multi_view
    self.conditional_D = opt.conditional_D
    assert len(self.multi_view) > 0, 'opt.multi_view must contain at least one view'

    self.loss_names = ['D', 'G']
    self.metrics_names = ['Mse', 'CosineSimilarity', 'PSNR']
    self.visual_names = ['G_real', 'G_fake', 'G_input', 'G_Map_fake_F', 'G_Map_real_F', 'G_Map_fake_S', 'G_Map_real_S']

    # Optional loss terms are only reported when their weight is positive.
    # identity loss
    if self.opt.idt_lambda > 0:
      self.loss_names += ['idt']

    # feature metric loss
    if self.opt.fea_m_lambda > 0:
      self.loss_names += ['fea_m']

    # map loss
    if self.opt.map_m_lambda > 0:
      self.loss_names += ['map_m']

    if self.training:
      self.model_names = ['G', 'D']
    else:  # during test time, only load Gs
      self.model_names = ['G']

    self.netG = factory.define_3DG(opt.noise_len, opt.input_shape, opt.output_shape,
                                   opt.input_nc_G, opt.output_nc_G, opt.ngf, opt.which_model_netG,
                                   opt.n_downsampling, opt.norm_G, not opt.no_dropout,
                                   opt.init_type, self.gpu_ids, opt.n_blocks,
                                   opt.encoder_input_shape, opt.encoder_input_nc, opt.encoder_norm,
                                   opt.encoder_blocks, opt.skip_number, opt.activation_type, opt=opt)

    if self.training:
      # use_sigmoid is forwarded to factory.define_D
      if opt.ganloss == 'gan':
        use_sigmoid = True
      elif opt.ganloss == 'lsgan':
        use_sigmoid = False
      else:
        raise ValueError(
          "Unsupported ganloss {!r}; expected 'gan' or 'lsgan'".format(opt.ganloss))

      # conditional Discriminator: one extra input channel for the condition
      if self.conditional_D:
        d_input_channels = opt.input_nc_D + 1
      else:
        d_input_channels = opt.input_nc_D
      self.netD = factory.define_D(d_input_channels, opt.ndf,
                                   opt.which_model_netD,
                                   opt.n_layers_D, opt.norm_D,
                                   use_sigmoid, opt.init_type, self.gpu_ids,
                                   opt.discriminator_feature, num_D=opt.num_D, n_out_channels=opt.n_out_ChannelsD)
      if self.if_pool:
        self.fake_pool = ImagePool(opt.pool_size)
示例#5
0
    def init_network(self, opt):
        """Declare bookkeeping names and build netG (plus netD when training).

        Args:
            opt: option namespace holding the network hyper-parameters.

        Raises:
            ValueError: when training with an ``opt.ganloss`` other than
                ``'gan'``, ``'lsgan'``, ``'wgan'`` or ``'wgan_gp'``.
        """
        Base_Model.init_network(self, opt)

        self.loss_names = ['D', 'G']
        self.metrics_names = ['Mse', 'CosineSimilarity']
        self.visual_names = [
            'G_real', 'G_fake', 'G_input', 'G_map_fake', 'G_map_real'
        ]

        # Optional loss terms are only reported when their weight is positive.
        # identity loss
        if self.opt.idt_lambda > 0:
            self.loss_names += ['idt']

        # feature metric loss
        if self.opt.fea_m_lambda > 0:
            self.loss_names += ['fea_m']

        # map loss
        if self.opt.map_m_lambda > 0:
            self.loss_names += ['map_m']

        if self.training:
            self.model_names = ['G', 'D']
        else:  # during test time, only load Gs
            self.model_names = ['G']

        self.netG = factory.define_3DG(
            opt.noise_len, opt.input_shape, opt.output_shape, opt.input_nc_G,
            opt.output_nc_G, opt.ngf, opt.which_model_netG, opt.n_downsampling,
            opt.norm_G, not opt.no_dropout, opt.init_type, self.gpu_ids,
            opt.n_blocks, opt.encoder_input_shape, opt.encoder_input_nc,
            opt.encoder_norm)

        if self.training:
            # use_sigmoid is forwarded to factory.define_D; the WGAN variants
            # also report their extra loss terms.
            if opt.ganloss == 'gan':
                use_sigmoid = True
            elif opt.ganloss == 'lsgan':
                use_sigmoid = False
            elif opt.ganloss == 'wgan':
                self.loss_names += ['wasserstein']
                use_sigmoid = False
            elif opt.ganloss == 'wgan_gp':
                self.loss_names += ['wasserstein', 'grad_penalty']
                use_sigmoid = False
            else:
                raise ValueError(
                    "Unsupported ganloss {!r}; expected one of 'gan', "
                    "'lsgan', 'wgan', 'wgan_gp'".format(opt.ganloss))

            self.netD = factory.define_D(opt.input_nc_D, opt.ndf,
                                         opt.which_model_netD, opt.n_layers_D,
                                         opt.norm_D, use_sigmoid,
                                         opt.init_type, self.gpu_ids,
                                         opt.discriminator_feature)

            self.fake_pool = ImagePool(opt.pool_size)
示例#6
0
    def init_loss(self, opt):
        """Set up an MSE criterion and a weight-decayed Adam optimizer for netG.

        Only the generator is optimised; ``self.optimizers`` ends up holding
        just ``self.optimizer_G``.
        """
        Base_Model.init_loss(self, opt)

        # single mean-squared-error training criterion
        self.criterion = torch.nn.MSELoss()

        generator_params = self.netG.parameters()
        self.optimizer_G = torch.optim.Adam(
            generator_params,
            lr=opt.lr,
            betas=(opt.beta1, opt.beta2),
            weight_decay=1e-4,  # fixed L2 penalty on the generator weights
        )

        self.optimizers = [self.optimizer_G]
示例#7
0
    def init_loss(self, opt):
        """Create the loss criteria and the Adam optimizers for netG and netD.

        Args:
            opt: option namespace; reads ``ganloss``, ``restruction_loss``,
                ``map_m_type``, ``lr``, ``beta1`` and ``beta2`` (and
                ``self.opt.CT_MIN_MAX`` / ``self.opt.XRAY1_MIN_MAX``).

        Raises:
            ValueError: if ``opt.ganloss`` or ``opt.restruction_loss`` names
                an unsupported loss.
        """
        Base_Model.init_loss(self, opt)

        # #####################
        # define loss functions
        # #####################
        # GAN loss
        if opt.ganloss == 'gan':
            self.criterionGAN = GANLoss(use_lsgan=False).to(self.device)
        elif opt.ganloss == 'lsgan':
            self.criterionGAN = GANLoss(use_lsgan=True).to(self.device)
        elif opt.ganloss == 'wgan':
            self.criterionGAN = WGANLoss(grad_penalty=False).to(self.device)
        elif opt.ganloss == 'wgan_gp':
            self.criterionGAN = WGANLoss(grad_penalty=True).to(self.device)
        else:
            raise ValueError(
                "Unsupported ganloss {!r}; expected one of 'gan', 'lsgan', "
                "'wgan', 'wgan_gp'".format(opt.ganloss))

        # identity loss
        if opt.restruction_loss == 'mse':
            print('Restruction loss: MSE')
            self.criterionIdt = torch.nn.MSELoss()
        elif opt.restruction_loss == 'l1':
            print('Restruction loss: l1')
            self.criterionIdt = torch.nn.L1Loss()
        else:
            raise ValueError(
                "Unsupported restruction_loss {!r}; expected 'mse' or "
                "'l1'".format(opt.restruction_loss))

        # feature metric loss
        self.criterionFea = torch.nn.L1Loss()

        # map loss: the CT / X-ray min-max ranges are handed to Map_loss
        self.criterionMap = Map_loss(direct_mean=opt.map_m_type,
                                     predict_transition=self.opt.CT_MIN_MAX,
                                     gt_transition=self.opt.XRAY1_MIN_MAX).to(
                                         self.device)

        # #####################
        # initialize optimizers
        # #####################
        # G and D share the same learning rate and Adam betas.
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, opt.beta2))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, opt.beta2))
        self.optimizers = [self.optimizer_G, self.optimizer_D]
示例#8
0
    def init_network(self, opt):
        """Build a generator-only model trained with an MSE objective.

        Stores the pooling / multi-view options, declares the loss, metric,
        visual and model names, and creates ``self.netG`` as an
        ``Xray3DVolumes_2DModel`` initialised through ``init_net``.
        """
        Base_Model.init_network(self, opt)

        self.if_pool, self.multi_view = opt.if_pool, opt.multi_view
        assert len(self.multi_view) > 0

        # bookkeeping names
        self.loss_names = ['MSE']
        self.metrics_names = ['Mse', 'CosineSimilarity', 'PSNR']
        self.visual_names = [
            'G_real', 'G_fake', 'G_input',
            'G_Map_fake_F', 'G_Map_real_F',
            'G_Map_fake_S', 'G_Map_real_S',
        ]
        self.model_names = ['G']  # only a generator in this model

        raw_generator = Xray3DVolumes_2DModel()
        self.netG = raw_generator
        self.netG = init_net(raw_generator, opt.init_type, self.gpu_ids)
示例#9
0
    def init_network(self, opt):
        """Build the generator and declare names for a generator-only model.

        The identity loss ``'idt'`` is always reported; ``'map_m'`` is added
        only when ``self.opt.map_projection_lambda`` is positive.
        """
        Base_Model.init_network(self, opt)

        self.if_pool = opt.if_pool
        self.multi_view = opt.multi_view
        assert len(self.multi_view) > 0

        self.metrics_names = ['Mse', 'CosineSimilarity', 'PSNR']
        self.visual_names = [
            'G_real', 'G_fake', 'G_input1', 'G_input2',
            'G_Map_fake_F', 'G_Map_real_F',
            'G_Map_fake_S', 'G_Map_real_S',
        ]

        # positional arguments of factory.define_3DG, in call order
        generator_args = (
            opt.noise_len, opt.input_shape, opt.output_shape,
            opt.input_nc_G, opt.output_nc_G, opt.ngf,
            opt.which_model_netG, opt.n_downsampling, opt.norm_G,
            not opt.no_dropout, opt.init_type, self.gpu_ids,
            opt.n_blocks, opt.encoder_input_shape, opt.encoder_input_nc,
            opt.encoder_norm, opt.encoder_blocks, opt.skip_number,
            opt.activation_type,
        )
        self.netG = factory.define_3DG(*generator_args, opt=opt)

        self.loss_names = ['idt']
        self.model_names = ['G']
        # map loss is only reported when it carries weight
        if self.opt.map_projection_lambda > 0:
            self.loss_names.append('map_m')
示例#10
0
    def init_network(self, opt):
        """Build netG and, when training, a volume D plus a projection-map D.

        Args:
            opt: option namespace holding the network hyper-parameters.

        Raises:
            ValueError: when training with an empty ``opt.multi_view``;
                ``netD_Map`` takes ``len(opt.multi_view)`` input channels,
                so at least one view is required.
        """
        Base_Model.init_network(self, opt)

        self.if_pool = opt.if_pool
        self.multi_view = opt.multi_view
        self.conditional_D = opt.conditional_D
        # axis permutations, presumably used to reorder volumes per view --
        # TODO confirm against the forward pass
        self.order_map_list = [(0, 1, 2, 3, 4), (0, 1, 3, 2, 4),
                               (0, 1, 4, 2, 3)]
        # netD_Map is built with len(multi_view) input channels, so training
        # with no views would silently create a zero-channel discriminator.
        if self.training and len(self.multi_view) == 0:
            raise ValueError(
                'opt.multi_view must contain at least one view when training')

        self.loss_names = ['D', 'G']
        self.metrics_names = ['Mse', 'CosineSimilarity', 'PSNR']
        self.visual_names = [
            'G_real', 'G_fake', 'G_input', 'G_Map_fake_F', 'G_Map_real_F',
            'G_Map_fake_S', 'G_Map_real_S'
        ]
        self.model_names = ['G']

        if self.training:
            self.model_names += ['D']

            # identity loss
            if self.opt.idt_lambda > 0:
                self.loss_names += ['idt']

            # feature metric loss (volume discriminator)
            if self.opt.feature_D_lambda > 0:
                self.loss_names += ['fea_m']

            # adversarial losses of the projection-map discriminator
            self.loss_names += ['D_Map', 'G_Map']

            # feature metric loss (map discriminator)
            if self.opt.feature_D_map_lambda > 0:
                self.loss_names += ['fea_m_Map']

            # map projection (identity) loss
            if self.opt.map_projection_lambda > 0:
                self.loss_names += ['idt_Map']

            # models
            self.model_names += ['D_Map']

        self.netG = factory.define_3DG(opt.noise_len,
                                       opt.input_shape,
                                       opt.output_shape,
                                       opt.input_nc_G,
                                       opt.output_nc_G,
                                       opt.ngf,
                                       opt.which_model_netG,
                                       opt.n_downsampling,
                                       opt.norm_G,
                                       not opt.no_dropout,
                                       opt.init_type,
                                       self.gpu_ids,
                                       opt.n_blocks,
                                       opt.encoder_input_shape,
                                       opt.encoder_input_nc,
                                       opt.encoder_norm,
                                       opt.encoder_blocks,
                                       opt.skip_number,
                                       opt.activation_type,
                                       opt=opt)

        if self.training:
            # out of discriminator is not probability when
            # GAN loss is LSGAN
            use_sigmoid = False

            # conditional Discriminator: one extra input channel
            if self.conditional_D:
                d_input_channels = opt.input_nc_D + 1
            else:
                d_input_channels = opt.input_nc_D
            self.netD = factory.define_D(d_input_channels,
                                         opt.ndf,
                                         opt.which_model_netD,
                                         opt.n_layers_D,
                                         opt.norm_D,
                                         use_sigmoid,
                                         opt.init_type,
                                         self.gpu_ids,
                                         opt.discriminator_feature,
                                         num_D=opt.num_D,
                                         n_out_channels=opt.n_out_ChannelsD)

            # one input channel per projection view
            self.netD_Map = factory.define_D(
                len(self.multi_view),
                opt.map_ndf,
                opt.map_which_model_netD,
                opt.map_n_layers_D,
                opt.map_norm_D,
                use_sigmoid,
                opt.init_type,
                self.gpu_ids,
                opt.discriminator_feature,
                opt.map_num_D,  # presumably define_D's num_D -- confirm
                n_out_channels=opt.map_n_out_ChannelsD)
            if self.if_pool:
                self.fake_pool = ImagePool(opt.pool_size)
                self.fake_pool_Map = ImagePool(opt.map_pool_size)