# Example 1
    def __init__(self,
                 module,
                 weight_init,
                 optimizer,
                 scheduler,
                 loss,
                 ensemble=False):
        """Build the GAN pair plus optimizers, schedulers and losses.

        Args:
            module: kwargs dict for ``getGANModule``; must contain
                ``module["param"]["unit"]`` for the discriminator width.
            weight_init: spec passed to ``Init`` to build the weight-init
                function; falsy to skip initialization.
            optimizer: optimizer spec (currently unused -- the optimizers
                below are hard-coded Adam/SGD).
            scheduler: kwargs dict for ``getScheduler``.
            loss: kwargs dict for ``getLoss``.
            ensemble: ensemble flag (currently unused).
        """
        self.G = getGANModule(**module)
        self.D = unet.SimpleClassify(1, 1, module["param"]["unit"])
        # BUG FIX: the original tested/used an undefined name ``init``;
        # the constructor parameter is called ``weight_init``.
        if weight_init:
            init_func = Init(weight_init)
            self.G.apply(init_func)
            self.D.apply(init_func)

        self.optimizer_G = torch.optim.Adam(params=self.G.parameters(),
                                            lr=0.001)
        # BUG FIX: the discriminator optimizer must update D's parameters;
        # the original passed self.G.parameters(), leaving D untrained.
        self.optimizer_D = torch.optim.SGD(params=self.D.parameters(), lr=0.01)

        self.scheduler_G = getScheduler(**scheduler,
                                        optimizer=self.optimizer_G)
        self.scheduler_D = getScheduler(**scheduler,
                                        optimizer=self.optimizer_D)

        self.criterion = getLoss(**loss)
        self.ganLoss = L.GANLoss(lsgan=True)

        # Prefer GPU when available.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
# Example 2
    def __init__(self, args):
        """Build the generator/discriminator pair, place them on the target
        device, and create criteria/optimizers when training is enabled.

        Args:
            args: namespace with at least ``cgan``, ``gpu_ids``, ``is_train``,
                ``use_net``, ``train_net``, ``gan_mode``, ``lr_hig``,
                ``lr_hid``, ``weight_decay`` and ``beta1``.
        """
        self.args = args

        generator = Generator(args, 1, 1, 7, 64, use_dropout=False)
        # Conditional GAN: D sees the condition stacked with the image,
        # hence 2 input channels instead of 1.
        if args.cgan:  # idiom fix: was 'args.cgan == True'
            discriminator = Discriminator(2, 64, n_layers=3)
        else:
            discriminator = Discriminator(1, 64, n_layers=3)

        self.device = torch.device("cuda:" + args.gpu_ids if torch.cuda.is_available() else "cpu")
        cudnn.benchmark = True  # There is BN issue for early version of PyTorch
        # see https://github.com/bearpaw/pytorch-pose/issues/33

        self.is_train = args.is_train

        # G&D -- instantiate only the networks requested via args.use_net.
        if 'hig' in args.use_net:
            self.netG = generator.net.to(self.device)

        if 'hid' in args.use_net:
            self.netD = discriminator.net.to(self.device)

        if self.is_train:
            self.criterionGAN = losses.GANLoss(args.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()

            # Optimizers only for the nets listed in args.train_net.
            if 'hig' in args.train_net:
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=args.lr_hig,
                                                    weight_decay=args.weight_decay,
                                                    betas=(args.beta1, 0.999))

            if 'hid' in args.train_net:
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=args.lr_hid,
                                                    weight_decay=args.weight_decay,
                                                    betas=(args.beta1, 0.999))
# Example 3
    def __init__(self, conf):
        """Build the KernelGAN trainer: networks, (multi-)GPU placement,
        preallocated input tensors, loss modules and optimizers.

        Args:
            conf: configuration object; reads ``input_crop_size``,
                ``G_kernel_size``, ``scale_factor``, ``g_lr``, ``d_lr``,
                ``beta1`` and ``input_image_path``.
        """
        # Expose up to 8 GPUs to this process.
        os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'

        # Acquire configuration
        self.conf = conf
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Define the GAN
        self.G = networks.Generator(conf)
        self.D = networks.Discriminator(conf)

        # Wrap in DataParallel when several GPUs are visible;
        # batches are split along dim 0 across devices.
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            print("gpu num : ", torch.cuda.device_count())
            # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)
        print("haha, gpu num : ", torch.cuda.device_count())
        self.G.to(self._device)
        self.D.to(self._device)

        # Calculate D's input & output shape according to the shaving done by the networks
        # (DataParallel hides the wrapped network behind ``.module``).
        if torch.cuda.device_count() > 1:
            self.d_input_shape = self.G.module.output_size
            self.d_output_shape = self.d_input_shape - self.D.module.forward_shave
        else:
            self.d_input_shape = self.G.output_size
            self.d_output_shape = self.d_input_shape - self.D.forward_shave

        # Input tensors (allocated once on the GPU, reused per iteration)
        self.g_input = torch.FloatTensor(1, 3, conf.input_crop_size, conf.input_crop_size).cuda()
        self.d_input = torch.FloatTensor(1, 3, self.d_input_shape, self.d_input_shape).cuda()

        # The kernel G is imitating
        self.curr_k = torch.FloatTensor(conf.G_kernel_size, conf.G_kernel_size).cuda()

        # Losses: adversarial term plus kernel-regularization terms
        # (bicubic prior, sum-to-one, boundaries, centralization, sparsity).
        self.GAN_loss_layer = loss.GANLoss(d_last_layer_size=self.d_output_shape).cuda()
        self.bicubic_loss = loss.DownScaleLoss(scale_factor=conf.scale_factor).cuda()
        self.sum2one_loss = loss.SumOfWeightsLoss().cuda()
        self.boundaries_loss = loss.BoundariesLoss(k_size=conf.G_kernel_size).cuda()
        self.centralized_loss = loss.CentralizedLoss(k_size=conf.G_kernel_size, scale_factor=conf.scale_factor).cuda()
        self.sparse_loss = loss.SparsityLoss().cuda()
        self.loss_bicubic = 0

        # Define loss function
        self.criterionGAN = self.GAN_loss_layer.forward

        # Initialize networks weights
        self.G.apply(networks.weights_init_G)
        self.D.apply(networks.weights_init_D)

        # Optimizers
        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=conf.g_lr, betas=(conf.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=conf.d_lr, betas=(conf.beta1, 0.999))

        print('*' * 60 + '\nSTARTED KernelGAN on: \"%s\"...' % conf.input_image_path)
# Example 4
    def __init__(self, conf):
        """Single-GPU KernelGAN trainer setup: networks, preallocated input
        tensors, kernel-regularization losses, optimizers and a tensorboard
        iteration counter.

        Args:
            conf: configuration object; reads ``input_crop_size``,
                ``G_kernel_size``, ``scale_factor``, ``g_lr``, ``d_lr``,
                ``beta1`` and ``input_image_path``.
        """
        # Acquire configuration
        self.conf = conf

        # Define the GAN
        self.G = networks.Generator(conf).cuda()
        self.D = networks.Discriminator(conf).cuda()

        # Calculate D's input & output shape according to the shaving done by the networks
        self.d_input_shape = self.G.output_size
        self.d_output_shape = self.d_input_shape - self.D.forward_shave

        # Input tensors (allocated once on the GPU, reused per iteration)
        self.g_input = torch.FloatTensor(1, 3, conf.input_crop_size,
                                         conf.input_crop_size).cuda()
        self.d_input = torch.FloatTensor(1, 3, self.d_input_shape,
                                         self.d_input_shape).cuda()

        # The kernel G is imitating
        self.curr_k = torch.FloatTensor(conf.G_kernel_size,
                                        conf.G_kernel_size).cuda()

        # Losses: adversarial term plus kernel-regularization terms
        # (bicubic prior, sum-to-one, boundaries, centralization, sparsity).
        self.GAN_loss_layer = loss.GANLoss(
            d_last_layer_size=self.d_output_shape).cuda()
        self.bicubic_loss = loss.DownScaleLoss(
            scale_factor=conf.scale_factor).cuda()
        self.sum2one_loss = loss.SumOfWeightsLoss().cuda()
        self.boundaries_loss = loss.BoundariesLoss(
            k_size=conf.G_kernel_size).cuda()
        self.centralized_loss = loss.CentralizedLoss(
            k_size=conf.G_kernel_size, scale_factor=conf.scale_factor).cuda()
        self.sparse_loss = loss.SparsityLoss().cuda()
        self.loss_bicubic = 0

        # Define loss function
        self.criterionGAN = self.GAN_loss_layer.forward

        # Initialize networks weights
        self.G.apply(networks.weights_init_G)
        self.D.apply(networks.weights_init_D)

        # Optimizers
        self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                            lr=conf.g_lr,
                                            betas=(conf.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                            lr=conf.d_lr,
                                            betas=(conf.beta1, 0.999))

        self.iteration = 0  # for tensorboard
        # self.ground_truth_kernel = np.loadtxt(conf.ground_truth_kernel_path)
        # writer.add_image("ground_truth_kernel", (self.ground_truth_kernel - np.min(self.ground_truth_kernel)) / (np.max(self.ground_truth_kernel - np.min(self.ground_truth_kernel))), 0, dataformats="HW")

        print('*' * 60 +
              '\nSTARTED KernelGAN on: \"%s\"...' % conf.input_image_path)
# Example 5
 def __init__(self,pix2pix,hpe1,hpe2,args):
     """Assemble the fusion network from pretrained sub-networks: the
     pix2pix generator (hig), two hand-pose estimators (hpe1/hpe2), a
     skeleton discriminator (hpd) and a PCA reconstruction layer.

     Args:
         pix2pix: trained pix2pix wrapper; only its ``netG`` is used.
         hpe1: first hand-pose estimator module.
         hpe2: second hand-pose estimator module.
         args: namespace with ``device``, ``is_train``,
             ``discriminator_reconstruction``, ``skeleton_orig_dim``,
             ``skeleton_pca_dim``, ``lr``, ``weight_decay``, ``beta1``
             and ``gan_mode``.
     """
     super(FusionNet, self).__init__()
     self.net_hig=pix2pix.netG.to(args.device)
     self.net_hpe1=hpe1.to(args.device)
     self.net_hpe2=hpe2.to(args.device)
     
     # Discriminator input dim depends on whether it judges reconstructed
     # (original-dim) skeletons or PCA-compressed ones.
     if args.discriminator_reconstruction==True:
         net_hpd=SkeletonDiscriminator(args.skeleton_orig_dim)
     else:
         net_hpd=SkeletonDiscriminator(args.skeleton_pca_dim)
     self.net_hpd=net_hpd.to(args.device)
     self.net_recon=ReconstructionLayer(args.skeleton_pca_dim,args.skeleton_orig_dim)
     
     
     
     self.device=args.device
     self.is_train=args.is_train
     self.args=args
     
     # Per-branch output buffers, filled during forward passes.
     self.out={}
     self.out['hig']=[]
     self.out['hpe1']=[]
     self.out['hig_hpe1']=[]
     self.out['hpe2']=[]
     self.out['hpe1_hpd']=[]
     self.out['hig_hpe1_hpd']=[]
     self.out['hpe2_hpd']=[]
     
     #to print loss
     self.loss_D_real={}
     self.loss_D_fake={}
     self.loss_D_real['hpe2']=None
     self.loss_D_fake['hpe2']=None
     self.loss_D_real['hig_hpe1']=None
     self.loss_D_fake['hig_hpe1']=None
     
     
     if self.is_train:
         #self.criterionL1 = torch.nn.L1Loss()
         
         # One Adam optimizer per trainable sub-network (shared lr).
         self.optimizer_hig=torch.optim.Adam(self.net_hig.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay,betas=(args.beta1,0.999))
     
         self.optimizer_hpe2=torch.optim.Adam(self.net_hpe2.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay,betas=(args.beta1,0.999))
     
         self.optimizer_hpd=torch.optim.Adam(self.net_hpd.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay,betas=(args.beta1,0.999))
     
     
         self.JointsMSELoss= losses.JointsMSELoss().to(self.device)
         self.criterionGAN = losses.GANLoss(args.gan_mode).to(self.device)
    def __init__(self, args, dataloaders):
        """Build the multi-discriminator image-decomposition trainer: three
        PatchGAN discriminators, one generator, their optimizers and LR
        schedulers, the loss functions and all training bookkeeping state.

        Args:
            args: parsed options; reads ``ngf``, ``net_G``, ``lr``,
                ``exp_lr_scheduler_stepsize``, ``lambda_L1``, ``lambda_adv``,
                ``metric``, ``max_num_epochs``, ``checkpoint_dir``,
                ``vis_dir``, ``pixel_loss``, the ``enable_*`` flags,
                ``output_auto_enhance``, ``enable_synfake`` and
                ``print_models``.
            dataloaders: train/val dataloaders, stored as-is.

        Note: relies on a module-level ``device`` for tensor placement.
        """
        self.dataloaders = dataloaders
        # D1/D2 are shallow (2-layer) PatchGANs; D3 is a deeper 3-layer one.
        self.net_D1 = cycnet.define_D(input_nc=6,
                                      ndf=64,
                                      netD='n_layers',
                                      n_layers_D=2).to(device)
        self.net_D2 = cycnet.define_D(input_nc=6,
                                      ndf=64,
                                      netD='n_layers',
                                      n_layers_D=2).to(device)
        self.net_D3 = cycnet.define_D(input_nc=6,
                                      ndf=64,
                                      netD='n_layers',
                                      n_layers_D=3).to(device)
        # Generator maps 3 input channels to 6 (two stacked 3-channel layers).
        self.net_G = cycnet.define_G(input_nc=3,
                                     output_nc=6,
                                     ngf=args.ngf,
                                     netG=args.net_G,
                                     use_dropout=False,
                                     norm='none').to(device)
        # M.Amintoosi norm='instance'
        # self.net_G = cycnet.define_G(
        #     input_nc=3, output_nc=6, ngf=args.ngf, netG=args.net_G, use_dropout=False, norm='instance').to(device)

        # Learning rate and Beta1 for Adam optimizers
        self.lr = args.lr

        # define optimizers
        self.optimizer_G = optim.Adam(self.net_G.parameters(),
                                      lr=self.lr,
                                      betas=(0.5, 0.999))
        self.optimizer_D1 = optim.Adam(self.net_D1.parameters(),
                                       lr=self.lr,
                                       betas=(0.5, 0.999))
        self.optimizer_D2 = optim.Adam(self.net_D2.parameters(),
                                       lr=self.lr,
                                       betas=(0.5, 0.999))
        self.optimizer_D3 = optim.Adam(self.net_D3.parameters(),
                                       lr=self.lr,
                                       betas=(0.5, 0.999))

        # define lr schedulers (step decay, gamma=0.1)
        self.exp_lr_scheduler_G = lr_scheduler.StepLR(
            self.optimizer_G,
            step_size=args.exp_lr_scheduler_stepsize,
            gamma=0.1)
        self.exp_lr_scheduler_D1 = lr_scheduler.StepLR(
            self.optimizer_D1,
            step_size=args.exp_lr_scheduler_stepsize,
            gamma=0.1)
        self.exp_lr_scheduler_D2 = lr_scheduler.StepLR(
            self.optimizer_D2,
            step_size=args.exp_lr_scheduler_stepsize,
            gamma=0.1)
        self.exp_lr_scheduler_D3 = lr_scheduler.StepLR(
            self.optimizer_D3,
            step_size=args.exp_lr_scheduler_stepsize,
            gamma=0.1)

        # coefficient to balance loss functions
        self.lambda_L1 = args.lambda_L1
        self.lambda_adv = args.lambda_adv

        # based on which metric to update the "best" ckpt
        self.metric = args.metric

        # define some other vars to record the training states
        self.running_acc = []
        self.epoch_acc = 0
        if 'mse' in self.metric:
            self.best_val_acc = 1e9  # for mse, rmse, a lower score is better
        else:
            self.best_val_acc = 0.0  # for others (ssim, psnr), a higher score is better
        self.best_epoch_id = 0
        self.epoch_to_start = 0
        self.max_num_epochs = args.max_num_epochs
        self.G_pred1 = None
        self.G_pred2 = None
        self.batch = None
        self.G_loss = None
        self.D_loss = None
        self.is_training = False
        self.batch_id = 0
        self.epoch_id = 0
        self.checkpoint_dir = args.checkpoint_dir
        self.vis_dir = args.vis_dir
        # Image pools feed the discriminators a history of generated images.
        self.D1_fake_pool = utils.ImagePool(pool_size=50)
        self.D2_fake_pool = utils.ImagePool(pool_size=50)
        self.D3_fake_pool = utils.ImagePool(pool_size=50)

        # define the loss functions
        if args.pixel_loss == 'minimum_pixel_loss':
            self._pxl_loss = loss.MinimumPixelLoss(
                opt=1)  # 1 for L1 and 2 for L2
        elif args.pixel_loss == 'pixel_loss':
            self._pxl_loss = loss.PixelLoss(opt=1)  # 1 for L1 and 2 for L2
        else:
            # BUG FIX: the original passed args.pixel_loss as a second
            # exception argument, so the '%s' placeholder was never filled.
            raise NotImplementedError(
                'pixel loss function [%s] is not implemented' %
                args.pixel_loss)
        self._gan_loss = loss.GANLoss(gan_mode='vanilla').to(device)
        self._exclusion_loss = loss.ExclusionLoss()
        self._kurtosis_loss = loss.KurtosisLoss()
        # enable some losses?
        self.with_d1d2 = args.enable_d1d2
        self.with_d3 = args.enable_d3
        self.with_exclusion_loss = args.enable_exclusion_loss
        self.with_kurtosis_loss = args.enable_kurtosis_loss

        # m-th epoch to activate adversarial training
        self.m_epoch_activate_adv = int(self.max_num_epochs / 20) + 1

        # output auto-enhancement?
        self.output_auto_enhance = args.output_auto_enhance

        # use synfake to train D?
        self.synfake = args.enable_synfake

        # check and create model dir (idiom fix: 'not', instead of 'is False')
        if not os.path.exists(self.checkpoint_dir):
            os.mkdir(self.checkpoint_dir)
        if not os.path.exists(self.vis_dir):
            os.mkdir(self.vis_dir)

        # visualize model
        if args.print_models:
            self._visualize_models()
# Example 7
    def __init__(self, conf):
        """Set up a cycle-consistent down/up-scaling trainer: generators
        G_DN/G_UP, discriminator D_DN, losses, optimizers, the input image
        and (optional) ground-truth image/kernel used for PSNR debugging.

        Args:
            conf: configuration; reads ``scale_factor``,
                ``scale_factor_downsampler``, ``lr_G_DN``, ``lr_D_DN``,
                ``lr_G_UP``, ``beta1``, ``input_image_path``, ``gt_path``,
                ``kernel_path`` and ``debug``.
        """
        # Fix random seed
        torch.manual_seed(0)
        torch.backends.cudnn.deterministic = True  # slightly reduces throughput

        # Acquire configuration
        self.conf = conf

        # Define the networks
        self.G_DN = networks.Generator_DN().cuda()
        self.D_DN = networks.Discriminator_DN().cuda()
        self.G_UP = networks.Generator_UP().cuda()

        # Losses: adversarial + cycle/interp L1 + downsampler regularization
        self.criterion_gan = loss.GANLoss().cuda()
        self.criterion_cycle = torch.nn.L1Loss()
        self.criterion_interp = torch.nn.L1Loss()
        self.regularization = loss.DownsamplerRegularization(
            conf.scale_factor_downsampler, self.G_DN.G_kernel_size)

        # Initialize networks weights
        self.G_DN.apply(networks.weights_init_G_DN)
        self.D_DN.apply(networks.weights_init_D_DN)
        self.G_UP.apply(networks.weights_init_G_UP)

        # Optimizers (one Adam per network, separate learning rates)
        self.optimizer_G_DN = torch.optim.Adam(self.G_DN.parameters(),
                                               lr=conf.lr_G_DN,
                                               betas=(conf.beta1, 0.999))
        self.optimizer_D_DN = torch.optim.Adam(self.D_DN.parameters(),
                                               lr=conf.lr_D_DN,
                                               betas=(conf.beta1, 0.999))
        self.optimizer_G_UP = torch.optim.Adam(self.G_UP.parameters(),
                                               lr=conf.lr_G_UP,
                                               betas=(conf.beta1, 0.999))

        # Read input image; crop so H and W are divisible by the scale factor
        self.in_img = util.read_image(conf.input_image_path)
        self.in_img_t = util.im2tensor(self.in_img)
        b_x = self.in_img_t.shape[2] % conf.scale_factor
        b_y = self.in_img_t.shape[3] % conf.scale_factor
        self.in_img_cropped_t = self.in_img_t[..., b_x:, b_y:]

        # Optional ground truth for evaluation (both may be absent).
        self.gt_img = util.read_image(
            conf.gt_path) if conf.gt_path is not None else None
        self.gt_kernel = loadmat(
            conf.kernel_path
        )['Kernel'] if conf.kernel_path is not None else None

        if self.gt_kernel is not None:
            # Pad and center the GT kernel, then precompute the image
            # downscaled with it as the DN reference.
            self.gt_kernel = np.pad(self.gt_kernel, 1, 'constant')
            self.gt_kernel = util.kernel_shift(self.gt_kernel,
                                               sf=conf.scale_factor)
            self.gt_kernel_t = torch.FloatTensor(self.gt_kernel).cuda()

            self.gt_downsampled_img_t = util.downscale_with_kernel(
                self.in_img_cropped_t, self.gt_kernel_t)
            self.gt_downsampled_img = util.tensor2im(self.gt_downsampled_img_t)

        # Debug variables (PSNR histories exist only when GT is available)
        self.debug_steps = []
        self.UP_psnrs = [] if self.gt_img is not None else None
        self.DN_psnrs = [] if self.gt_kernel is not None else None

        if self.conf.debug:
            self.loss_GANs = []
            self.loss_cycle_forwards = []
            self.loss_cycle_backwards = []
            self.loss_interps = []
            self.loss_Discriminators = []

        self.iter = 0
def train_rnn(config):
    """Train the RNN lip-generation GAN.

    Builds the dataset/dataloader, the generator and up to three
    discriminators (frame, video, lip-reading), runs the adversarial
    training loop, periodically logs/visualizes, and checkpoints the
    models every epoch and at the end.

    Args:
        config: options namespace; reads ``gpu``, ``use_npy``/``use_seq``,
            dataset/model hyper-parameters, ``discriminator*`` switches,
            ``lr``, ``epochs``, ``save_dir``, ``test_dir`` and
            ``teacher_force_ratio``.

    Raises:
        ValueError: if neither ``config.use_npy`` nor ``config.use_seq``
            is set (no dataset could be constructed).
    """
    num_gpu = len(config.gpu)
    if config.use_npy:
        dataset_train = NpySeqDataset(train_file=config.filename_list,
                                      config=config,
                                      transform=transforms.Compose([
                                          Resizer(config.size_image),
                                          Normalizer(),
                                          ToTensor()
                                      ]))
    elif config.use_seq:
        dataset_train = SeqDataset(train_file=config.filename_list,
                                   use_mask=None,
                                   config=config,
                                   transform=transforms.Compose([
                                       Resizer(config.size_image),
                                       Normalizer(),
                                       ToTensor()
                                   ]))
    else:
        # BUG FIX: the original fell through with dataset_train undefined
        # (NameError at DataLoader construction); fail explicitly instead.
        raise ValueError('config must set use_npy or use_seq')
    dataloader = DataLoader(dataset_train,
                            num_workers=8,
                            collate_fn=collater,
                            batch_size=config.batch_size,
                            shuffle=True)

    G_model = model_G.LipGeneratorRNN(config.audio_encoder,
                                      config.img_encoder,
                                      config.img_decoder,
                                      config.rnn_type,
                                      config.size_image,
                                      config.num_output_length,
                                      if_tanh=config.if_tanh)
    D_model = model_D.Discriminator(config)
    D_v_model = model_D.DiscriminatorVideo(config)
    D_model_lip = model_D.DiscriminatorLip(config)

    adversarial_lip_loss = loss.GAN_LR_Loss()
    adversarial_loss = loss.GANLoss()
    recon_loss = loss.ReconLoss()

    # Optionally resume G, and warm-start the lip discriminator from a
    # pretrained lip-reading model.
    if config.ckpt is not None:
        load_ckpt(G_model, config.ckpt)
    if config.discriminator_lip == 'lip_read':
        load_ckpt(D_model_lip, config.ckpt_lipmodel, prefix='discriminator.')

    # optimizer = optim.Adam(filter(lambda p: p.requires_grad, G_model.parameters()), lr=0.002, betas=(0.5, 0.999))
    optimizer_G = optim.Adam(G_model.parameters(),
                             lr=config.lr,
                             betas=(0.5, 0.999))
    if config.discriminator is not None:
        optimizer_D = optim.Adam(D_model.parameters(),
                                 lr=0.001,
                                 betas=(0.5, 0.999))
    if config.discriminator_v is not None:
        optimizer_D_v = optim.Adam(D_v_model.parameters(),
                                   lr=0.001,
                                   betas=(0.5, 0.999))
    if config.discriminator_lip is not None:
        optimizer_D_lip = optim.Adam(D_model_lip.parameters(),
                                     lr=0.001,
                                     betas=(0.5, 0.999))

    # Multi-GPU: wrap everything in DataParallel; otherwise plain .cuda().
    if num_gpu > 1:
        G_model = torch.nn.DataParallel(G_model,
                                        device_ids=list(
                                            range(num_gpu))).cuda()
        D_model = torch.nn.DataParallel(D_model,
                                        device_ids=list(
                                            range(num_gpu))).cuda()
        D_v_model = torch.nn.DataParallel(D_v_model,
                                          device_ids=list(
                                              range(num_gpu))).cuda()
        D_model_lip = torch.nn.DataParallel(D_model_lip,
                                            device_ids=list(
                                                range(num_gpu))).cuda()
    else:
        G_model = G_model.cuda()
        D_model = D_model.cuda()
        D_model_lip = D_model_lip.cuda()
        D_v_model = D_v_model.cuda()

    adversarial_lip_loss = adversarial_lip_loss.cuda()
    adversarial_loss = adversarial_loss.cuda()
    recon_loss = recon_loss.cuda()

    writer = SummaryWriter(log_dir=config.save_dir)

    # First batch is cached and reused for per-epoch visualization.
    sample_inputs = None
    for epoch_num in range(config.epochs):
        G_model.train()
        D_model.train()
        D_v_model.train()
        D_model_lip.train()
        for iter_num, data in enumerate(dataloader):
            n_iter = len(dataloader) * epoch_num + iter_num
            # idiom fix: was 'sample_inputs == None'
            if sample_inputs is None:
                sample_inputs = (data['img'].cuda(), data['audio'].cuda(),
                                 data['gt'].cuda(), data['len'].cuda())
            try:
                input_images = data['img'].cuda()
                input_audios = data['audio'].cuda()
                gts_orignal = data['gt'].cuda()

                # Generator forward + reconstruction loss.
                G_images_orignal = G_model(
                    input_images,
                    input_audios,
                    valid_len=data['len'].cuda(),
                    teacher_forcing_ratio=config.teacher_force_ratio)
                loss_EG = recon_loss(G_images_orignal,
                                     gts_orignal,
                                     valid_len=data['len'].cuda())
                G_loss = loss_EG

                # Frame discriminator: adversarial terms for G and D.
                if config.discriminator is not None:
                    loss_G_GAN = adversarial_loss(D_model(G_images_orignal),
                                                  is_real=True)
                    loss_D_real = adversarial_loss(D_model(gts_orignal),
                                                   is_real=True)
                    loss_D_fake = adversarial_loss(D_model(
                        G_images_orignal.detach()),
                                                   is_real=False)
                    G_loss = G_loss + 0.002 * loss_G_GAN
                    D_loss = loss_D_real + loss_D_fake

                # Video discriminator on a sampled clip of frames.
                if config.discriminator_v is not None:
                    clip_range = get_clip_range(data['len'],
                                                config.num_frames_D)
                    loss_G_GAN_v = adversarial_loss(D_v_model(
                        G_images_orignal, config.num_frames_D,
                        clip_range.cuda()),
                                                    is_real=True)
                    loss_D_real_v = adversarial_loss(D_v_model(
                        gts_orignal, config.num_frames_D, clip_range.cuda()),
                                                     is_real=True)
                    loss_D_fake_v = adversarial_loss(D_v_model(
                        G_images_orignal.detach(), config.num_frames_D,
                        clip_range.cuda()),
                                                     is_real=False)
                    G_loss = G_loss + config.D_v_weight * loss_G_GAN_v
                    D_loss_v = loss_D_real_v + loss_D_fake_v

                # Lip-reading discriminator with label-aware loss.
                if config.discriminator_lip is not None:
                    clip_range = get_clip_range(data['len'],
                                                config.num_frames_lipNet)

                    loss_G_GAN_lip = adversarial_lip_loss(
                        D_model_lip(G_images_orignal, clip_range.cuda(),
                                    data['lip'].cuda()),
                        is_real=True,
                        targets=data['label'].cuda())
                    loss_D_real_lip = adversarial_lip_loss(
                        D_model_lip(gts_orignal, clip_range.cuda(),
                                    data['lip'].cuda()),
                        is_real=True,
                        targets=data['label'].cuda())
                    loss_D_fake_lip = adversarial_lip_loss(
                        D_model_lip(G_images_orignal.detach(),
                                    clip_range.cuda(), data['lip'].cuda()),
                        is_real=False,
                        targets=data['label'].cuda())
                    G_loss = G_loss + config.D_lip_weight * loss_G_GAN_lip
                    D_loss_lip = loss_D_real_lip + loss_D_fake_lip

                # for generator
                optimizer_G.zero_grad()
                G_loss.backward()
                optimizer_G.step()

                # for discriminator
                if config.discriminator is not None:
                    optimizer_D.zero_grad()
                    D_loss.backward()
                    optimizer_D.step()

                if config.discriminator_lip is not None:
                    optimizer_D_lip.zero_grad()
                    D_loss_lip.backward()
                    optimizer_D_lip.step()

                if config.discriminator_v is not None:
                    optimizer_D_v.zero_grad()
                    D_loss_v.backward()
                    optimizer_D_v.step()

                # Periodic console + tensorboard logging.
                if iter_num % 20 == 0:
                    print('Epoch: {} | Iteration: {} | EG loss: {:1.5f}: '.
                          format(epoch_num, iter_num, float(loss_EG)))
                    if config.discriminator is not None:
                        print('D loss {:1.5f} | G_GAN loss {:1.5f} : '.format(
                            float(D_loss), float(loss_G_GAN)))
                        writer.add_scalar('D_loss', D_loss, n_iter)
                    if config.discriminator_v is not None:
                        print('D_v loss: {:1.5f} | G_GAN_v loss {:1.5f} : '.
                              format(float(D_loss_v), float(loss_G_GAN_v)))
                        writer.add_scalar('D_loss_v', D_loss_v, n_iter)
                    if config.discriminator_lip is not None:
                        print('D_lip loss {:1.5f} | G_GAN_lip loss {:1.5f} : '.
                              format(float(D_loss_lip), float(loss_G_GAN_lip)))
                        writer.add_scalar('D_loss_lip', D_loss_lip, n_iter)

                    writer.add_scalar('loss_EG', loss_EG, n_iter)

            except Exception as e:
                # Deliberate best-effort: a bad batch is logged and skipped
                # rather than aborting the whole training run.
                print(e)
                traceback.print_exc()

        # visualize some results
        sample(sample_inputs, G_model, epoch_num, config.save_dir,
               config.teacher_force_ratio)
        test(G_model, config.test_dir, config.save_dir, config.size_image)

        # Per-epoch checkpoints (unwrap DataParallel before saving).
        if isinstance(G_model, torch.nn.DataParallel):
            torch.save(
                G_model.module.state_dict(),
                os.path.join(config.save_dir,
                             'model_G{}.pt'.format(epoch_num)))
            if config.discriminator_lip is not None:
                torch.save(
                    D_model_lip.module.state_dict(),
                    os.path.join(config.save_dir,
                                 'model_D{}.pt'.format(epoch_num)))
        else:
            torch.save(
                G_model.state_dict(),
                os.path.join(config.save_dir,
                             'model_G{}.pt'.format(epoch_num)))
            if config.discriminator_lip is not None:
                torch.save(
                    D_model_lip.state_dict(),
                    os.path.join(config.save_dir,
                                 'model_D{}.pt'.format(epoch_num)))

    # Final generator checkpoint.
    if isinstance(G_model, torch.nn.DataParallel):
        torch.save(G_model.module.state_dict(),
                   os.path.join(config.save_dir, 'model_G_final.pt'))
    else:
        torch.save(G_model.state_dict(),
                   os.path.join(config.save_dir, 'model_G_final.pt'))
# Example 9
    def __init__(self, pix2pix, hpe1, hpe2, args):
        """Assemble the fusion network with optional 2-GPU DataParallel
        wrapping: pix2pix generator (hig), two hand-pose estimators
        (hpe1/hpe2), skeleton discriminator (hpd), image discriminator
        (hid) and a PCA reconstruction layer.

        Args:
            pix2pix: trained pix2pix wrapper; provides ``netG``/``netD``.
            hpe1: first hand-pose estimator module.
            hpe2: second hand-pose estimator module.
            args: namespace with ``gpu_ids``, ``device``, ``is_train``,
                ``discriminator_reconstruction``, ``skeleton_orig_dim``,
                ``skeleton_pca_dim``, per-net learning rates, ``weight_decay``,
                ``beta1`` and ``gan_mode``.
        """
        super(FusionNet, self).__init__()
        # 'all' means: spread each sub-network over GPUs 0 and 1.
        if args.gpu_ids == 'all':
            self.net_hig = nn.DataParallel(pix2pix.netG, device_ids=[0, 1])
            self.net_hpe1 = nn.DataParallel(hpe1, device_ids=[0, 1])
            self.net_hpe2 = nn.DataParallel(hpe2, device_ids=[0, 1])

            self.net_hig.to(args.device)
            self.net_hpe1.to(args.device)
            self.net_hpe2.to(args.device)
        else:
            self.net_hig = pix2pix.netG.to(args.device)
            self.net_hpe1 = hpe1.to(args.device)
            self.net_hpe2 = hpe2.to(args.device)

        # Discriminator input dim depends on whether it judges reconstructed
        # (original-dim) skeletons or PCA-compressed ones.
        if args.discriminator_reconstruction == True:
            net_hpd = SkeletonDiscriminator(args.skeleton_orig_dim)
        else:
            net_hpd = SkeletonDiscriminator(args.skeleton_pca_dim)

        if args.gpu_ids == 'all':
            self.net_hpd = nn.DataParallel(net_hpd, device_ids=[0, 1])
            self.net_hpd.to(args.device)
        else:
            self.net_hpd = net_hpd.to(args.device)

        if args.gpu_ids == 'all':
            self.net_hid = nn.DataParallel(pix2pix.netD, device_ids=[0, 1])
            self.net_hid.to(args.device)
        else:
            #self.net_hid=pix2pix.netD.to(args.device)
            # NOTE(review): the single-GPU branch assigns the skeleton
            # discriminator (net_hpd) to net_hid instead of pix2pix.netD
            # (see the commented-out line) -- possibly unintended; confirm.
            self.net_hid = net_hpd.to(args.device)

        self.net_recon = ReconstructionLayer(args.skeleton_pca_dim,
                                             args.skeleton_orig_dim)

        self.device = args.device
        self.is_train = args.is_train
        self.args = args

        # Per-branch output buffers, filled during forward passes.
        self.out = {}
        self.out['hig'] = []
        self.out['hpe1'] = []
        self.out['hig_hpe1'] = []
        self.out['hpe2'] = []
        self.out['hpe1_hpd'] = []
        self.out['hig_hpe1_hpd'] = []
        self.out['hpe2_hpd'] = []

        #to print loss
        self.loss_hpd_real = {}
        self.loss_hpd_fake = {}
        self.loss_hpd_real['hig_hpe1'] = None
        self.loss_hpd_real['hpe2'] = None
        self.loss_hpd_fake['hig_hpe1'] = None
        self.loss_hpd_fake['hpe2'] = None

        #
        self.use_blurdata = False
        self.selection_idx = 0

        #
        if self.is_train:
            #self.criterionL1 = torch.nn.L1Loss()

            # One Adam optimizer per trainable sub-network, each with
            # its own learning rate from args.
            self.optimizer_hig = torch.optim.Adam(
                self.net_hig.parameters(),
                lr=args.lr_hig,
                weight_decay=args.weight_decay,
                betas=(args.beta1, 0.999))

            self.optimizer_hpe2 = torch.optim.Adam(
                self.net_hpe2.parameters(),
                lr=args.lr_hpe2,
                weight_decay=args.weight_decay,
                betas=(args.beta1, 0.999))

            self.optimizer_hpd = torch.optim.Adam(
                self.net_hpd.parameters(),
                lr=args.lr_hpd,
                weight_decay=args.weight_decay,
                betas=(args.beta1, 0.999))

            self.optimizer_hid = torch.optim.Adam(
                self.net_hid.parameters(),
                lr=args.lr_hid,
                weight_decay=args.weight_decay,
                betas=(args.beta1, 0.999))

            self.JointsMSELoss = losses.JointsMSELoss().to(self.device)
            self.criterionGAN = losses.GANLoss(args.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
Beispiel #10
0
    def __init__(self, hpe1_orig, pix2pix, hpe1, hpe2, hpe1_refine, args):
        """Wire together the optional sub-networks of the fusion pipeline.

        Args:
            hpe1_orig: pretrained pose estimator for original images, or None.
            pix2pix: container exposing ``netG`` (generator) and ``netD``
                (image discriminator), or None.
            hpe1: pose estimator fed by the generator output, or None.
            hpe2: second pose estimator, or None.
            hpe1_refine: refinement pose estimator, or None.
            args: config object; must provide ``device``, ``is_train``,
                ``use_net`` and ``train_net``, plus ``skeleton_*_dim``,
                ``lr_*``, ``weight_decay``, ``beta1`` and ``gan_mode`` for
                the branches that are enabled.
        """
        super(FusionNet, self).__init__()
        self.args = args

        # Only the sub-nets that were actually supplied are registered and
        # moved to the target device.  (PEP 8: compare to None with `is`.)
        if pix2pix is not None:
            self.net_hig = pix2pix.netG.to(args.device)
        if hpe1 is not None:
            self.net_hpe1 = hpe1.to(args.device)
        if hpe2 is not None:
            self.net_hpe2 = hpe2.to(args.device)
        if hpe1_orig is not None:
            self.net_hpe1_orig = hpe1_orig.to(args.device)
        if hpe1_refine is not None:
            self.net_hpe1_refine = hpe1_refine.to(args.device)

        # Skeleton discriminator: when a reconstruction layer is in the
        # pipeline it judges original-dimension skeletons, otherwise the
        # PCA-reduced ones.
        if 'hpd' in args.use_net:
            if 'recon' in args.use_net:
                net_hpd = SkeletonDiscriminator(args.skeleton_orig_dim)
            else:
                net_hpd = SkeletonDiscriminator(args.skeleton_pca_dim)
            self.net_hpd = net_hpd.to(args.device)

        if 'hid' in args.use_net:
            # NOTE(review): assumes pix2pix is not None whenever 'hid' is
            # requested — confirm against the caller.
            self.net_hid = pix2pix.netD.to(args.device)
        if 'recon' in args.use_net:
            self.net_recon = ReconstructionLayer(args.skeleton_pca_dim,
                                                 args.skeleton_orig_dim)

        self.device = args.device
        self.is_train = args.is_train

        # Per-branch forward outputs, keyed by pipeline stage.
        self.out = {}
        for key in ('hpe1_orig', 'hig', 'hig_hpe1', 'hpe2', 'hpe1_refine',
                    'hpe1_orig_hpd', 'hig_hpe1_hpd', 'hpe2_hpd',
                    'hpe1_refine_hpd'):
            self.out[key] = []

        # Latest real/fake discriminator losses per branch, kept for logging.
        self.loss_hpd_real = {}
        self.loss_hpd_fake = {}
        for key in ('hpe1_orig', 'hig_hpe1', 'hpe2', 'hpe1_refine'):
            self.loss_hpd_real[key] = None
            self.loss_hpd_fake[key] = None

        self.use_blurdata = False
        self.selection_idx = 0

        if self.is_train:
            def _make_adam(params, lr):
                # Single place for the Adam hyper-parameters shared by all
                # sub-net optimizers.
                return torch.optim.Adam(params,
                                        lr=lr,
                                        weight_decay=args.weight_decay,
                                        betas=(args.beta1, 0.999))

            # NOTE(review): if args.train_net is a string, `'hpe1' in ...`
            # also matches 'hpe1_refine' (substring); exact membership only
            # holds when train_net is a list/set.  Behavior intentionally
            # preserved — confirm which container the caller passes.
            if 'hpe1' in args.train_net:
                self.optimizer_hpe1 = _make_adam(self.net_hpe1.parameters(),
                                                 args.lr_hpe1)
            if 'hpe1_refine' in args.train_net:
                self.optimizer_hpe1_refine = _make_adam(
                    self.net_hpe1_refine.parameters(), args.lr_hpe1_refine)
            if 'hig' in args.train_net:
                self.optimizer_hig = _make_adam(self.net_hig.parameters(),
                                                args.lr_hig)
            if 'hpe2' in args.train_net:
                self.optimizer_hpe2 = _make_adam(self.net_hpe2.parameters(),
                                                 args.lr_hpe2)
            if 'hpd' in args.train_net:
                self.optimizer_hpd = _make_adam(self.net_hpd.parameters(),
                                                args.lr_hpd)
            if 'hid' in args.train_net:
                self.optimizer_hid = _make_adam(self.net_hid.parameters(),
                                                args.lr_hid)

            self.JointsMSELoss = losses.JointsMSELoss().to(self.device)
            self.criterionGAN = losses.GANLoss(args.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()