def __init__(self, device='cpu', last=nn.Sigmoid):
    super(SGAN, self).__init__()
    self.device = device
    self.net_g = G()
    self.net_d = D(last=last)
    self.criterion = GANLoss(relativistic=False)
    self.optim_G = Adam(self.net_g.parameters())
    self.optim_D = Adam(self.net_d.parameters())
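# A minimal usage sketch of the components wired up in this __init__, assuming the
# wrapper exposes a per-batch training step. The method name `train_step`, the
# `real` / `noise` arguments, and the criterion(prediction, target_is_real) call
# convention are assumptions based on how GANLoss is invoked elsewhere in this
# collection; the actual class may expose a different interface.
def train_step(self, real, noise):
    fake = self.net_g(noise)
    # discriminator update: real -> True, detached fake -> False
    d_loss = 0.5 * (self.criterion(self.net_d(real), True) +
                    self.criterion(self.net_d(fake.detach()), False))
    self.optim_D.zero_grad()
    d_loss.backward()
    self.optim_D.step()
    # generator update: push the discriminator's prediction on the fake towards "real"
    g_loss = self.criterion(self.net_d(fake), True)
    self.optim_G.zero_grad()
    g_loss.backward()
    self.optim_G.step()
    return d_loss.item(), g_loss.item()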
# content loss
if opt.content_loss_type == 'L1_Charbonnier':
    content_loss = L1_Charbonnier_loss()
elif opt.content_loss_type == 'L1':
    content_loss = torch.nn.L1Loss()
elif opt.content_loss_type == 'L2':
    content_loss = torch.nn.MSELoss()

# pixel loss
if opt.pixel_loss_type == 'L1':
    pixel_loss = torch.nn.L1Loss()
elif opt.pixel_loss_type == 'L2':
    pixel_loss = torch.nn.MSELoss()

# gan loss
GAN_loss = GANLoss(opt.gan_type, real_label_val=1.0, fake_label_val=0.0)
edge_loss = edgeV_loss()
tv_loss = TV_loss()

# GPU
if opt.cuda and not torch.cuda.is_available():  # check whether a GPU is available
    raise Exception('No GPU found, please run without --cuda')

print("===> Setting GPU")
if opt.cuda:
    print('cuda_mode:', opt.cuda)
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    feature_extractor = feature_extractor.cuda()
    content_loss = content_loss.cuda()
    pixel_loss = pixel_loss.cuda()
    GAN_loss = GAN_loss.cuda()
    edge_loss = edge_loss.cuda()
vec.append(word2vec[term])

del word2vec

# BERT Model
model = modeling.BertNoEmbed(vocab=vocab, hidden_size=1024, enc_num_layer=3)
model.load_state_dict(torch.load('checkpoint/bert-LanGen-last.pt')['state'])
model.cuda()

d_net = modeling.TextCNNClassify(vocab, vec, num_labels=2)
d_net.cuda()

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer_d = torch.optim.SGD(d_net.parameters(), lr=0.01)

label_smoothing = modeling.LabelSmoothing(len(vocab), 0, 0.1)
label_smoothing.cuda()

gan_loss = GANLoss()
gan_loss.cuda()

G_STEP = 1
D_STEP = 3
D_PRE = 5
SAVE_EVERY = 50
PENALTY_EPOCH = -1
DRAW_LEARNING_CURVE = False

data = []

# Tokenized input
print('Tokenization...')
with open('pair.csv') as PAIR:
    for line in tqdm(PAIR):
        [text, summary, _] = line.split(',')
        texts = []
cprint('==> Preparing Data Set: Complete\n', 'green')

################################################################################

cprint('==> Building Models', 'yellow')

netG = define_G(opt.input_nc, opt.output_nc, opt.ngf, norm='batch',
                use_dropout=False, gpu_ids=gpu_ids)
netD = define_D(opt.input_nc + opt.output_nc, opt.ndf, norm='batch',
                use_sigmoid=False, gpu_ids=gpu_ids)

print('---------- Networks initialized -------------')
print_network(netG)
print_network(netD)
print('-----------------------------------------------\n')

cprint('==> Building Models: Complete\n', 'green')

################################################################################

criterionGAN = GANLoss()
criterionL1 = nn.L1Loss()
criterionMSE = nn.MSELoss()

# setup optimizer
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

real_a = torch.FloatTensor(opt.batchSize, opt.input_nc, 256, 256)
real_b = torch.FloatTensor(opt.batchSize, opt.output_nc, 256, 256)

if opt.cuda:
    netD = netD.cuda()
    netG = netG.cuda()
    criterionGAN = criterionGAN.cuda()
    criterionL1 = criterionL1.cuda()
def train(opt):
    #### device
    device = torch.device('cuda:{}'.format(opt.gpu_id) if opt.gpu_id >= 0 else torch.device('cpu'))

    #### dataset
    data_loader = UnAlignedDataLoader()
    data_loader.initialize(opt)
    data_set = data_loader.load_data()
    print("The number of training images = %d." % len(data_set))

    #### initialize models
    ## declaration
    E_a2Zb = Encoder(input_nc=opt.input_nc, ngf=opt.ngf, norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout, n_blocks=9)
    G_Zb2b = Decoder(output_nc=opt.output_nc, ngf=opt.ngf, norm_type=opt.norm_type)
    T_Zb2Za = LatentTranslator(n_channels=256, norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_b = Discriminator(input_nc=opt.input_nc, ndf=opt.ndf, n_layers=opt.n_layers,
                        norm_type=opt.norm_type)

    E_b2Za = Encoder(input_nc=opt.input_nc, ngf=opt.ngf, norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout, n_blocks=9)
    G_Za2a = Decoder(output_nc=opt.output_nc, ngf=opt.ngf, norm_type=opt.norm_type)
    T_Za2Zb = LatentTranslator(n_channels=256, norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_a = Discriminator(input_nc=opt.input_nc, ndf=opt.ndf, n_layers=opt.n_layers,
                        norm_type=opt.norm_type)

    ## initialization
    E_a2Zb = init_net(E_a2Zb, init_type=opt.init_type).to(device)
    G_Zb2b = init_net(G_Zb2b, init_type=opt.init_type).to(device)
    T_Zb2Za = init_net(T_Zb2Za, init_type=opt.init_type).to(device)
    D_b = init_net(D_b, init_type=opt.init_type).to(device)

    E_b2Za = init_net(E_b2Za, init_type=opt.init_type).to(device)
    G_Za2a = init_net(G_Za2a, init_type=opt.init_type).to(device)
    T_Za2Zb = init_net(T_Za2Zb, init_type=opt.init_type).to(device)
    D_a = init_net(D_a, init_type=opt.init_type).to(device)

    print("+------------------------------------------------------+\nFinish initializing networks.")

    #### optimizer and criterion
    ## criterion
    criterionGAN = GANLoss(opt.gan_mode).to(device)
    criterionZId = nn.L1Loss()
    criterionIdt = nn.L1Loss()
    criterionCTC = nn.L1Loss()
    criterionZCyc = nn.L1Loss()

    ## optimizer
    optimizer_G = torch.optim.Adam(itertools.chain(E_a2Zb.parameters(), G_Zb2b.parameters(),
                                                   T_Zb2Za.parameters(), E_b2Za.parameters(),
                                                   G_Za2a.parameters(), T_Za2Zb.parameters()),
                                   lr=opt.lr, betas=(opt.beta1, opt.beta2))
    optimizer_D = torch.optim.Adam(itertools.chain(D_a.parameters(), D_b.parameters()),
                                   lr=opt.lr, betas=(opt.beta1, opt.beta2))

    ## scheduler
    scheduler = [get_scheduler(optimizer_G, opt), get_scheduler(optimizer_D, opt)]

    print("+------------------------------------------------------+\nFinish initializing the optimizers and criterions.")

    #### global variables
    checkpoints_pth = os.path.join(opt.checkpoints, opt.name)
    if not os.path.exists(checkpoints_pth):
        os.mkdir(checkpoints_pth)
        os.mkdir(os.path.join(checkpoints_pth, 'images'))

    record_fh = open(os.path.join(checkpoints_pth, 'records.txt'), 'w', encoding='utf-8')

    loss_names = ['GAN_A', 'Adv_A', 'Idt_A', 'CTC_A', 'ZId_A', 'ZCyc_A',
                  'GAN_B', 'Adv_B', 'Idt_B', 'CTC_B', 'ZId_B', 'ZCyc_B']

    fake_A_pool = ImagePool(opt.pool_size)  # image buffer to store previously generated images
    fake_B_pool = ImagePool(opt.pool_size)  # image buffer to store previously generated images

    print("+------------------------------------------------------+\nFinish preparing the other works.")
    print("+------------------------------------------------------+\nNow training is beginning ..")

    #### training
    cur_iter = 0
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()  # timer for the entire epoch

        for i, data in enumerate(data_set):
            ## setup inputs
            real_A = data['A'].to(device)
            real_B = data['B'].to(device)

            ## forward
            # image cycle / GAN
            latent_B = E_a2Zb(real_A)   #-> a -> Zb : E_a2b(a)
            fake_B = G_Zb2b(latent_B)   #-> Zb -> b' : G_b(E_a2b(a))
            latent_A = E_b2Za(real_B)   #-> b -> Za : E_b2a(b)
            fake_A = G_Za2a(latent_A)   #-> Za -> a' : G_a(E_b2a(b))

            # Idt
            '''
            rec_A = G_Za2a(E_b2Za(fake_B))  #-> b' -> Za' -> rec_a : G_a(E_b2a(fake_b))
            rec_B = G_Zb2b(E_a2Zb(fake_A))  #-> a' -> Zb' -> rec_b : G_b(E_a2b(fake_a))
            '''
            idt_latent_A = E_b2Za(real_A)   #-> a -> Za : E_b2a(a)
            idt_A = G_Za2a(idt_latent_A)    #-> Za -> idt_a : G_a(E_b2a(a))
            idt_latent_B = E_a2Zb(real_B)   #-> b -> Zb : E_a2b(b)
            idt_B = G_Zb2b(idt_latent_B)    #-> Zb -> idt_b : G_b(E_a2b(b))

            # ZIdt
            T_latent_A = T_Zb2Za(latent_B)  #-> Zb -> Za'' : T_b2a(E_a2b(a))
            T_rec_A = G_Za2a(T_latent_A)    #-> Za'' -> a'' : G_a(T_b2a(E_a2b(a)))
            T_latent_B = T_Za2Zb(latent_A)  #-> Za -> Zb'' : T_a2b(E_b2a(b))
            T_rec_B = G_Zb2b(T_latent_B)    #-> Zb'' -> b'' : G_b(T_a2b(E_b2a(b)))

            # CTC
            T_idt_latent_B = T_Za2Zb(idt_latent_A)  #-> a -> T_a2b(E_b2a(a))
            T_idt_latent_A = T_Zb2Za(idt_latent_B)  #-> b -> T_b2a(E_a2b(b))

            # ZCyc
            TT_latent_B = T_Za2Zb(T_latent_A)  #-> T_a2b(T_b2a(E_a2b(a)))
            TT_latent_A = T_Zb2Za(T_latent_B)  #-> T_b2a(T_a2b(E_b2a(b)))

            ### optimize parameters
            ## Generator updating
            set_requires_grad([D_b, D_a], False)  #-> set Discriminators to require no gradient
            optimizer_G.zero_grad()

            # GAN loss
            loss_G_A = criterionGAN(D_b(fake_B), True)
            loss_G_B = criterionGAN(D_a(fake_A), True)
            loss_GAN = loss_G_A + loss_G_B

            # Idt loss
            loss_idt_A = criterionIdt(idt_A, real_A)
            loss_idt_B = criterionIdt(idt_B, real_B)
            loss_Idt = loss_idt_A + loss_idt_B

            # Latent cross-identity loss
            loss_Zid_A = criterionZId(T_rec_A, real_A)
            loss_Zid_B = criterionZId(T_rec_B, real_B)
            loss_Zid = loss_Zid_A + loss_Zid_B

            # Latent cross-translation consistency
            loss_CTC_A = criterionCTC(T_idt_latent_A, latent_A)
            loss_CTC_B = criterionCTC(T_idt_latent_B, latent_B)
            loss_CTC = loss_CTC_B + loss_CTC_A

            # Latent cycle consistency
            loss_ZCyc_A = criterionZCyc(TT_latent_A, latent_A)
            loss_ZCyc_B = criterionZCyc(TT_latent_B, latent_B)
            loss_ZCyc = loss_ZCyc_B + loss_ZCyc_A

            loss_G = opt.lambda_gan * loss_GAN + opt.lambda_idt * loss_Idt + \
                     opt.lambda_zid * loss_Zid + opt.lambda_ctc * loss_CTC + \
                     opt.lambda_zcyc * loss_ZCyc

            # backward and gradient updating
            loss_G.backward()
            optimizer_G.step()

            ## Discriminator updating
            set_requires_grad([D_b, D_a], True)  #-> set Discriminators to require gradient
            optimizer_D.zero_grad()

            # backward D_b
            fake_B_ = fake_B_pool.query(fake_B)  #-> real_B, fake_B
            pred_real_B = D_b(real_B)
            loss_D_real_B = criterionGAN(pred_real_B, True)
            pred_fake_B = D_b(fake_B_)
            loss_D_fake_B = criterionGAN(pred_fake_B, False)
            loss_D_B = (loss_D_real_B + loss_D_fake_B) * 0.5
            loss_D_B.backward()

            # backward D_a
            fake_A_ = fake_A_pool.query(fake_A)  #-> real_A, fake_A
            pred_real_A = D_a(real_A)
            loss_D_real_A = criterionGAN(pred_real_A, True)
            pred_fake_A = D_a(fake_A_)
            loss_D_fake_A = criterionGAN(pred_fake_A, False)
            loss_D_A = (loss_D_real_A + loss_D_fake_A) * 0.5
            loss_D_A.backward()

            # update the gradients
            optimizer_D.step()

            ### validate here, both qualitatively and quantitatively
            ## record the losses
            if cur_iter % opt.log_freq == 0:
                # loss_names = ['GAN_A', 'Adv_A', 'Idt_A', 'CTC_A', 'ZId_A', 'ZCyc_A',
                #               'GAN_B', 'Adv_B', 'Idt_B', 'CTC_B', 'ZId_B', 'ZCyc_B']
                losses = [
                    loss_G_A.item(), loss_D_A.item(), loss_idt_A.item(),
                    loss_CTC_A.item(), loss_Zid_A.item(), loss_ZCyc_A.item(),
                    loss_G_B.item(), loss_D_B.item(), loss_idt_B.item(),
                    loss_CTC_B.item(), loss_Zid_B.item(), loss_ZCyc_B.item()
                ]

                # record
                line = ''
                for loss in losses:
                    line += '{} '.format(loss)
                record_fh.write(line[:-1] + '\n')

                # print out
                print('Epoch: %3d/%3d Iter: %9d --------------------------+'
                      % (epoch, opt.epoch, i))

                field_names = loss_names[:len(loss_names) // 2]
                table = PrettyTable(field_names=field_names)
                for l_n in field_names:
                    table.align[l_n] = 'c'  # PrettyTable alignment must be 'l', 'c' or 'r'
                table.add_row(losses[:len(field_names)])
                print(table.get_string(reversesort=True))

                field_names = loss_names[len(loss_names) // 2:]
                table = PrettyTable(field_names=field_names)
                for l_n in field_names:
                    table.align[l_n] = 'c'
                table.add_row(losses[-len(field_names):])
                print(table.get_string(reversesort=True))

            ## visualize
            if cur_iter % opt.vis_freq == 0:
                if opt.gpu_id >= 0:
                    real_A = real_A.cpu().data
                    real_B = real_B.cpu().data
                    fake_A = fake_A.cpu().data
                    fake_B = fake_B.cpu().data
                    idt_A = idt_A.cpu().data
                    idt_B = idt_B.cpu().data
                    T_rec_A = T_rec_A.cpu().data
                    T_rec_B = T_rec_B.cpu().data

                plt.subplot(241), plt.title('real_A'), plt.imshow(tensor2image_RGB(real_A[0, ...]))
                plt.subplot(242), plt.title('fake_B'), plt.imshow(tensor2image_RGB(fake_B[0, ...]))
                plt.subplot(243), plt.title('idt_A'), plt.imshow(tensor2image_RGB(idt_A[0, ...]))
                plt.subplot(244), plt.title('L_idt_A'), plt.imshow(tensor2image_RGB(T_rec_A[0, ...]))

                plt.subplot(245), plt.title('real_B'), plt.imshow(tensor2image_RGB(real_B[0, ...]))
                plt.subplot(246), plt.title('fake_A'), plt.imshow(tensor2image_RGB(fake_A[0, ...]))
                plt.subplot(247), plt.title('idt_B'), plt.imshow(tensor2image_RGB(idt_B[0, ...]))
                plt.subplot(248), plt.title('L_idt_B'), plt.imshow(tensor2image_RGB(T_rec_B[0, ...]))

                plt.savefig(os.path.join(checkpoints_pth, 'images', '%03d_%09d.jpg' % (epoch, i)))

            cur_iter += 1
            # break  #-> debug

        ## we have now finished one epoch, so update the learning rate
        update_learning_rate(schedulers=scheduler, opt=opt, optimizer=optimizer_D)

        ## save the model
        if epoch % opt.ckp_freq == 0:
            #-> save models: torch.save(model.state_dict(), PATH)
            #-> load models: model.load_state_dict(torch.load(PATH)); model.eval()
            if opt.gpu_id >= 0:
                E_a2Zb = E_a2Zb.cpu()
                G_Zb2b = G_Zb2b.cpu()
                T_Zb2Za = T_Zb2Za.cpu()
                D_b = D_b.cpu()
                E_b2Za = E_b2Za.cpu()
                G_Za2a = G_Za2a.cpu()
                T_Za2Zb = T_Za2Zb.cpu()
                D_a = D_a.cpu()
            '''
            torch.save(E_a2Zb.cpu().state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-E_a2b.pth' % epoch))
            torch.save(G_Zb2b.cpu().state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-G_b.pth' % epoch))
            torch.save(T_Zb2Za.cpu().state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-T_b2a.pth' % epoch))
            torch.save(D_b.cpu().state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-D_b.pth' % epoch))
            torch.save(E_b2Za.cpu().state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-E_b2a.pth' % epoch))
            torch.save(G_Za2a.cpu().state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-G_a.pth' % epoch))
            torch.save(T_Za2Zb.cpu().state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-T_a2b.pth' % epoch))
            torch.save(D_a.cpu().state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-D_a.pth' % epoch))
            '''
            torch.save(E_a2Zb.state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-E_a2b.pth' % epoch))
            torch.save(G_Zb2b.state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-G_b.pth' % epoch))
            torch.save(T_Zb2Za.state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-T_b2a.pth' % epoch))
            torch.save(D_b.state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-D_b.pth' % epoch))
            torch.save(E_b2Za.state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-E_b2a.pth' % epoch))
            torch.save(G_Za2a.state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-G_a.pth' % epoch))
            torch.save(T_Za2Zb.state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-T_a2b.pth' % epoch))
            torch.save(D_a.state_dict(),
                       os.path.join(checkpoints_pth, 'epoch_%3d-D_a.pth' % epoch))

            if opt.gpu_id >= 0:
                E_a2Zb = E_a2Zb.to(device)
                G_Zb2b = G_Zb2b.to(device)
                T_Zb2Za = T_Zb2Za.to(device)
                D_b = D_b.to(device)
                E_b2Za = E_b2Za.to(device)
                G_Za2a = G_Za2a.to(device)
                T_Za2Zb = T_Za2Zb.to(device)
                D_a = D_a.to(device)

            print("+Successfully saved models in epoch: %3d.-------------+" % epoch)
        # break  #-> debug

    record_fh.close()
    print("≧◔◡◔≦ Congratulations! Training is finished!")
def __init__(self, opt):
    """Pix2PixHD model

    Parameters
    ----------
    opt : ArgumentParser
        options of this model, e.g. gain, isAffine
    """
    super(Pix2PixHDModel, self).__init__()
    self.opt = opt

    if opt.gpu_ids == 0:
        self.device = torch.device("cuda:0")
    elif opt.gpu_ids == 1:
        self.device = torch.device("cuda:1")
    else:
        self.device = torch.device("cpu")

    # define each network
    input_nc = opt.label_num
    if not opt.no_use_feature:
        input_nc += opt.feature_nc
    if not opt.no_use_edge:
        input_nc += 1

    self.netG = define_G(
        input_nc=input_nc,
        output_nc=opt.output_nc,
        ngf=opt.ngf,
        g_type=opt.g_type,
        device=self.device,
        isAffine=opt.isAffine,
        use_relu=opt.use_relu,
    )

    input_nc = opt.output_nc
    if not opt.no_use_edge:
        input_nc += opt.label_num + 1
    else:
        input_nc += opt.label_num

    self.netD = define_D(
        input_nc=input_nc,
        ndf=opt.ndf,
        n_layers_D=opt.n_layers_D,
        device=self.device,
        isAffine=opt.isAffine,
        num_D=opt.num_D,
    )

    self.netE = define_E(
        input_nc=opt.output_nc,
        feat_num=opt.feature_nc,
        nef=opt.nef,
        device=self.device,
        isAffine=opt.isAffine,
    )

    # define each optimizer
    # initialize optimizer G & E
    # if opt.niter_fix_global > 0, freeze the global generator and train only the local enhancer
    if opt.niter_fix_global > 0:
        finetune_list = set()
        params = []
        for key, value in self.netG.named_parameters():
            if key.startswith("model" + str(opt.n_local_enhancers)):
                params += [value]
                finetune_list.add(key.split(".")[0])
        print("------------- Only training the local enhancer network (for %d epochs) ------------"
              % opt.niter_fix_global)
        print("The layers that are finetuned are ", sorted(finetune_list))
    else:
        params = list(self.netG.parameters())
    if not self.opt.no_use_feature:
        params += list(self.netE.parameters())
    self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
    self.scheduler_G = LinearDecayLR(self.optimizer_G, niter_decay=opt.niter_decay)

    # initialize optimizer D
    self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    self.scheduler_D = LinearDecayLR(self.optimizer_D, niter_decay=opt.niter_decay)

    # define loss functions
    if opt.gpu_ids == 0 or opt.gpu_ids == 1:
        self.Tensor = torch.cuda.FloatTensor
    else:
        self.Tensor = torch.FloatTensor
    self.criterionGAN = GANLoss(self.device, use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
    if not self.opt.no_fmLoss:
        self.criterionFM = FMLoss(num_D=opt.num_D, n_layers=opt.n_layers_D,
                                  lambda_feat=opt.lambda_feat)
    if not self.opt.no_pLoss:
        self.criterionP = PerceptualLoss(self.device, lambda_perceptual=opt.lambda_perceptual)
def train(learning_rate=0.0002, beta1=0.5, epochs=1):
    # parse data from the args passed
    data_dir = args.data
    batch_size = args.batch_size
    num_workers = args.num_workers

    # check that the data dir exists
    assert os.path.isdir(data_dir), "{} is not a valid directory".format(data_dir)

    '''
    # create dataset (transforms are also included here)
    print('Loading dataset...')
    dataset = DehazeDataset(data_dir)
    print('Dataset loaded successfully...')
    print('Dataset contains {} distinct datapoints in X(source) & Y(target) domain\n\n'.format(len(dataset)))

    # create custom DataLoader
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    '''

    # create G, F
    print('Loading Generators (G & F)...')
    G = Generator()
    F = Generator()
    print('Generators (G & F) loaded successfully...')

    # create Dx, Dy
    print('Loading Discriminators (Dx, Dy)...')
    Dx = Discriminator()
    Dy = Discriminator()
    print('Discriminators (Dx, Dy) loaded successfully...')

    # check generator summary
    # summary(G, (3, 256, 256))   # OR print(G) to print the Generator
    # check discriminator summary
    # summary(Dx, (3, 256, 256))  # OR print(Dx) to print the Discriminator

    # create the 3 loss functions - adversarial loss, cycle-consistency loss, identity loss
    criterionGAN = GANLoss().to(device)  ############## change device
    criterionCycle = nn.L1Loss()
    criterionIdt = nn.L1Loss()

    # create optimizers
    optimizers = []
    optimizer_G = optim.Adam(itertools.chain(G.parameters(), F.parameters()),
                             lr=learning_rate, betas=(beta1, 0.999))
    optimizer_D = optim.Adam(itertools.chain(Dx.parameters(), Dy.parameters()),
                             lr=learning_rate, betas=(beta1, 0.999))
    optimizers.append(optimizer_G)
    optimizers.append(optimizer_D)

    # make the dataset ready for training
    data_loader = CustomDatasetLoader()
    dataset = data_loader.load_data()
    print('Number of training images = %d' % len(dataset))

    # iterate over the dataset for training
    for epoch in range(epochs):  # outer loop over epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        # epoch_start_time = time.time()  # timer for the entire epoch
        # iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0  # number of training iterations in the current epoch, reset to 0 every epoch

        for i, batch in enumerate(dataset):  # inner loop within one epoch (see the sketch below)
            pass
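# A hedged sketch of what the empty inner loop above could perform: a standard
# CycleGAN-style update over G, F, Dx, Dy and the three criteria created in train().
# The batch keys 'X'/'Y', the lambda weights, and this helper's name are assumptions,
# not the project's actual code.
def cyclegan_step(batch, G, F, Dx, Dy,
                  criterionGAN, criterionCycle, criterionIdt,
                  optimizer_G, optimizer_D,
                  device='cpu', lambda_cycle=10.0, lambda_idt=0.5):
    real_x, real_y = batch['X'].to(device), batch['Y'].to(device)

    # forward passes through both mapping directions
    fake_y = G(real_x)   # X -> Y
    rec_x = F(fake_y)    # Y -> X (cycle back)
    fake_x = F(real_y)   # Y -> X
    rec_y = G(fake_x)    # X -> Y (cycle back)

    # generator update: adversarial + cycle-consistency + identity terms
    optimizer_G.zero_grad()
    loss_G = (criterionGAN(Dy(fake_y), True) +
              criterionGAN(Dx(fake_x), True) +
              lambda_cycle * (criterionCycle(rec_x, real_x) +
                              criterionCycle(rec_y, real_y)) +
              lambda_idt * (criterionIdt(G(real_y), real_y) +
                            criterionIdt(F(real_x), real_x)))
    loss_G.backward()
    optimizer_G.step()

    # discriminator update on detached fakes; zero_grad() above the backward call
    # clears any gradients that leaked into Dx/Dy during the generator step
    optimizer_D.zero_grad()
    loss_D = 0.5 * (criterionGAN(Dy(real_y), True) +
                    criterionGAN(Dy(fake_y.detach()), False) +
                    criterionGAN(Dx(real_x), True) +
                    criterionGAN(Dx(fake_x.detach()), False))
    loss_D.backward()
    optimizer_D.step()
    return loss_G.item(), loss_D.item()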
def run():
    # Dataset
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, ))
    ])
    dataset = datasets.MNIST('.', transform=transform, download=True)
    dataloader = data.DataLoader(dataset, batch_size=4)
    print("[INFO] Define DataLoader")

    # Define Model
    g = Generator()
    d = Discriminator()
    print("[INFO] Define Model")

    # optimizer, loss
    gan_loss = GANLoss()
    optim_G = optim.Adam(g.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optim_D = optim.Adam(d.parameters(), lr=0.0002, betas=(0.5, 0.999))
    print('[INFO] Define optimizer and loss')

    # train
    num_epoch = 2
    print('[INFO] Start Training!!')
    for epoch in range(num_epoch):
        total_batch = len(dataloader)
        for idx, (image, _) in enumerate(dataloader):
            d.train()
            g.train()

            # generate fake images
            noise = torch.randn(4, 100, 1, 1)
            output_fake = g(noise)

            # Loss
            d_loss_fake = gan_loss(d(output_fake.detach()), False)
            d_loss_real = gan_loss(d(image), True)
            d_loss = (d_loss_fake + d_loss_real) / 2
            g_loss = gan_loss(d(output_fake), True)

            # update
            optim_G.zero_grad()
            g_loss.backward()
            optim_G.step()

            optim_D.zero_grad()
            d_loss.backward()
            optim_D.step()

            if ((epoch * total_batch) + idx) % 1000 == 0:
                print('Epoch [%d/%d], Iter [%d/%d], D_loss: %.4f, G_loss: %.4f'
                      % (epoch, num_epoch, idx + 1, total_batch, d_loss.item(), g_loss.item()))

    save_model('model', 'GAN', g, {'loss': g_loss.item()})
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel
import os
import time
from glob import glob
from collections import OrderedDict
from os import makedirs, environ
from os.path import join, exists, split, isfile
# Note: these scipy.misc image helpers were deprecated in SciPy 1.0 and removed in
# later releases, so an older SciPy (or imageio/Pillow replacements) is required.
from scipy.misc import imread, imresize, imsave, imrotate

from loss import GANLoss, gram_matrix

cri_gan = GANLoss('gan', 1.0, 0.0)

from model import VGGMOD, SR, Discriminator, compute_gradient_penalty

torch.set_default_dtype(torch.float32)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# some global variables
MODEL_FOLDER = 'model'
SAMPLE_FOLDER = 'sample'
input_dir = 'sr_data/CUFED_128/input'   # original images
ref_dir = 'sr_data/CUFED_128/ref'       # reference images
map_dir = 'sr_data/CUFED_128/map_321'   # texture maps after texture swapping

use_gpu = True
use_train_ref = True
pre_load_img = True
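# A hedged compatibility sketch for running the script above on modern SciPy, where
# the scipy.misc import fails. These helpers, built on imageio and Pillow, are
# illustrative assumptions (they only handle the (height, width) tuple form of
# `size`), not the project's own code; swap them in for the scipy.misc import.
import numpy as np
import imageio
from PIL import Image

def imread(path):
    # read an image file into a numpy array
    return np.asarray(imageio.imread(path))

def imsave(path, arr):
    # write a numpy array to an image file
    imageio.imsave(path, arr)

def imresize(arr, size):
    # size is (height, width); PIL expects (width, height)
    h, w = size
    return np.asarray(Image.fromarray(arr).resize((w, h), Image.BILINEAR))

def imrotate(arr, angle):
    # rotate by `angle` degrees without expanding the canvas
    return np.asarray(Image.fromarray(arr).rotate(angle))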
def __init__(self, device='cpu'):
    super(RaSGAN, self).__init__(device=device)
    self.criterion = GANLoss(relativistic=True, average=True)
def __init__(self, device='cpu'):
    super(RSGAN, self).__init__(device=device, last=None)
    self.criterion = GANLoss(relativistic=True, average=False)
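# The SGAN / RSGAN / RaSGAN wrappers above differ only in the criterion flags. A
# GANLoss of this shape is commonly implemented along the lines of the relativistic
# GANs of Jolicoeur-Martineau (2018); the discriminator-side sketch below is an
# assumption about what `relativistic` and `average` select, not the actual loss
# module of this repository.
import torch
import torch.nn.functional as F

def relativistic_d_loss(real_logits, fake_logits, average):
    if average:
        # RaSGAN: compare each logit against the mean logit of the other class
        real_rel = real_logits - fake_logits.mean()
        fake_rel = fake_logits - real_logits.mean()
        return (F.binary_cross_entropy_with_logits(real_rel, torch.ones_like(real_rel)) +
                F.binary_cross_entropy_with_logits(fake_rel, torch.zeros_like(fake_rel)))
    # RSGAN: the discriminator only sees the pairwise difference of logits
    diff = real_logits - fake_logits
    return F.binary_cross_entropy_with_logits(diff, torch.ones_like(diff))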
data_dir = "/content/drive/My Drive/Grasping GAN/processed" model_dir = "/content/drive/My Drive/Grasping GAN/models" batch_size = 8 epochs = 1 lr = 0.01 dataset = GraspingDataset(data_dir) data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") net_g = define_G(3, 3, 64, "batch", False, "normal", 0.02, gpu_id=device) net_d = define_D(3 + 3, 64, "basic", gpu_id=device) criterionGAN = GANLoss().to(device) criterionL1 = nn.L1Loss().to(device) criterionMSE = nn.MSELoss().to(device) optimizer_g = optim.Adam(net_g.parameters(), lr=lr) optimizer_d = optim.Adam(net_d.parameters(), lr=lr) l1_weight = 10 for epoch in range(epochs): # train for iteration, batch in enumerate(data_loader, 1): # forward real_a, real_b = batch[0].to(device), batch[1].to(device) fake_b = net_g(real_a)