def __init__(self, train_set, test_set, loss_name="Vanilla", mixed_precision=False, learning_rate=2e-4, tmp_path=None, out_path=None):
    """Build the DCGAN: datasets, G/D networks, loss, optimizers and bookkeeping.

    ``loss_name`` selects the adversarial loss; an optional "-SN" suffix
    additionally enables spectral normalization on the discriminator.
    Raises ValueError for an unrecognized loss name.
    """
    super(DCGAN, self).__init__()
    # Datasets and output locations.
    self.train_set = train_set
    self.test_set = test_set
    self.tmp_path = tmp_path
    self.out_path = out_path
    # The generator is loss-agnostic; the discriminator's head (sigmoid vs.
    # raw critic scores) and spectral norm depend on the chosen loss.
    self.G = networks.Generator(name="G")
    use_sn = loss_name.endswith("-SN")
    base_loss = loss_name[:-3] if use_sn else loss_name
    if base_loss in ("WGAN", "WGAN-GP"):
        # Wasserstein critics output unbounded scores (no sigmoid).
        self.D = networks.Discriminator(name="If_is_real", use_sigmoid=False, sn=use_sn)
    elif base_loss in ("Vanilla", "LSGAN"):
        self.D = networks.Discriminator(name="If_is_real", use_sigmoid=True, sn=use_sn)
    else:
        raise ValueError("Do not support the loss " + loss_name)
    self.loss_name = base_loss
    self.model_list = [self.G, self.D]
    # Loss, optimizers, metrics, checkpointing and the fixed evaluation seed.
    self.gan_loss = GanLoss(self.loss_name)
    self.optimizers_list = self.optimizers_config(mixed_precision=mixed_precision, learning_rate=learning_rate)
    self.mixed_precision = mixed_precision
    self.matrics_list = self.matrics_config()
    self.checkpoint_config()
    self.get_seed()
def __init__(self, conf):
    """Set up KernelGAN from configuration *conf*: G/D networks (wrapped in
    DataParallel when several GPUs are visible), pre-allocated input buffers,
    the kernel-regularization losses, and one Adam optimizer per network.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'
    # Acquire configuration
    self.conf = conf
    self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Define the GAN
    self.G = networks.Generator(conf)
    self.D = networks.Discriminator(conf)
    # Check the GPU count once (the original re-checked it in two places and
    # emitted duplicate debug prints); `multi_gpu` also records whether the
    # raw modules live behind DataParallel's `.module` attribute.
    multi_gpu = torch.cuda.device_count() > 1
    if multi_gpu:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        self.G = nn.DataParallel(self.G)
        self.D = nn.DataParallel(self.D)
    self.G.to(self._device)
    self.D.to(self._device)
    # Calculate D's input & output shape according to the shaving done by the networks
    raw_G = self.G.module if multi_gpu else self.G
    raw_D = self.D.module if multi_gpu else self.D
    self.d_input_shape = raw_G.output_size
    self.d_output_shape = self.d_input_shape - raw_D.forward_shave
    # Input tensors
    self.g_input = torch.FloatTensor(1, 3, conf.input_crop_size, conf.input_crop_size).cuda()
    self.d_input = torch.FloatTensor(1, 3, self.d_input_shape, self.d_input_shape).cuda()
    # The kernel G is imitating
    self.curr_k = torch.FloatTensor(conf.G_kernel_size, conf.G_kernel_size).cuda()
    # Losses
    self.GAN_loss_layer = loss.GANLoss(d_last_layer_size=self.d_output_shape).cuda()
    self.bicubic_loss = loss.DownScaleLoss(scale_factor=conf.scale_factor).cuda()
    self.sum2one_loss = loss.SumOfWeightsLoss().cuda()
    self.boundaries_loss = loss.BoundariesLoss(k_size=conf.G_kernel_size).cuda()
    self.centralized_loss = loss.CentralizedLoss(k_size=conf.G_kernel_size, scale_factor=conf.scale_factor).cuda()
    self.sparse_loss = loss.SparsityLoss().cuda()
    self.loss_bicubic = 0
    # Define loss function
    self.criterionGAN = self.GAN_loss_layer.forward
    # Initialize networks weights
    self.G.apply(networks.weights_init_G)
    self.D.apply(networks.weights_init_D)
    # Optimizers
    self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=conf.g_lr, betas=(conf.beta1, 0.999))
    self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=conf.d_lr, betas=(conf.beta1, 0.999))
    print('*' * 60 + '\nSTARTED KernelGAN on: \"%s\"...' % conf.input_image_path)
def __init__(self, conf):
    """Construct the KernelGAN model described by *conf*: G/D networks on the
    GPU, their pre-allocated input buffers, the loss terms and the two Adam
    optimizers.
    """
    self.conf = conf
    # Networks live on the GPU for the whole run.
    self.G = networks.Generator(conf).cuda()
    self.D = networks.Discriminator(conf).cuda()
    # D's input is G's output; D shaves the map again on its forward pass.
    self.d_input_shape = self.G.output_size
    self.d_output_shape = self.d_input_shape - self.D.forward_shave
    # Pre-allocated input tensors for the generator and the discriminator.
    crop = conf.input_crop_size
    self.g_input = torch.FloatTensor(1, 3, crop, crop).cuda()
    self.d_input = torch.FloatTensor(1, 3, self.d_input_shape, self.d_input_shape).cuda()
    # Buffer for the kernel G is currently imitating.
    k_size = conf.G_kernel_size
    self.curr_k = torch.FloatTensor(k_size, k_size).cuda()
    # Loss terms: adversarial + bicubic prior + the kernel regularizers.
    self.GAN_loss_layer = loss.GANLoss(d_last_layer_size=self.d_output_shape).cuda()
    self.bicubic_loss = loss.DownScaleLoss(scale_factor=conf.scale_factor).cuda()
    self.sum2one_loss = loss.SumOfWeightsLoss().cuda()
    self.boundaries_loss = loss.BoundariesLoss(k_size=k_size).cuda()
    self.centralized_loss = loss.CentralizedLoss(k_size=k_size, scale_factor=conf.scale_factor).cuda()
    self.sparse_loss = loss.SparsityLoss().cuda()
    self.loss_bicubic = 0
    # Adversarial criterion is the GAN loss layer's forward pass.
    self.criterionGAN = self.GAN_loss_layer.forward
    # Custom weight initialization for both networks.
    self.G.apply(networks.weights_init_G)
    self.D.apply(networks.weights_init_D)
    # One Adam optimizer per network, each with its own learning rate.
    self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=conf.g_lr, betas=(conf.beta1, 0.999))
    self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=conf.d_lr, betas=(conf.beta1, 0.999))
    self.iteration = 0  # tensorboard step counter
    print('*' * 60 + '\nSTARTED KernelGAN on: \"%s\"...' % conf.input_image_path)
def tensorboard_plot(domain_src, domain_tgt, domain_name):
    """Train CDAN on the (source, target) domain pair, logging to TensorBoard."""
    print('----------------{}---------------'.format(domain_name))
    root_path = r'E:\cht_project\domain_adaptation_images\imageCLEF_resnet50'
    # Source and target loaders come from the same pre-extracted-feature root.
    src_loader = get_data.get_src_dataloader(root_path, domain_src)
    tgt_loader = get_data.get_src_dataloader(root_path, domain_tgt)
    # Fresh classifier / conditional-discriminator heads for this domain pair.
    clf = networks.Classifier(in_dim=feature_dim, out_dim=num_classes).cuda()
    disc = networks.Discriminator(in_dim=num_classes*encoder_out_dim).cuda()
    with SummaryWriter('./runs/{}_1028'.format(domain_name)) as writer:
        CDAN.train(src_loader,
                   tgt_loader,
                   discriminator=disc,
                   classifier=clf,
                   train_epochs=81,
                   writer=writer)
def __init__(self, opts):
    # Build the SAVI2I model from run options: encoders, mapping network,
    # generator, and (in the 'train' phase) discriminators with optimizers.
    super(SAVI2I, self).__init__()
    self.opts = opts
    if opts.gpu >= 0:
        self.device = torch.device('cuda:%d' % opts.gpu)
    else:
        self.device = torch.device('cpu')
    # NOTE(review): called even on the CPU path (opts.gpu < 0); torch treats a
    # negative index as a no-op, but guarding this under `opts.gpu >= 0`
    # would make the intent explicit — confirm before changing.
    torch.cuda.set_device(opts.gpu)
    cudnn.benchmark = True
    self.phase = opts.phase
    # type selects the translation family: 1 -> *_style networks,
    # 0 -> *_shape networks (see the branches below).
    self.type = opts.type
    self.nz = opts.input_nz
    self.style_dim = opts.style_dim
    self.num_domains = opts.num_domains
    # Attribute encoder and the noise->style mapping network (data-parallel).
    self.enc_a = nn.DataParallel(networks.E_attr(img_size=opts.img_size, input_dim=opts.input_dim, nz=self.nz, n_domains=self.num_domains))
    self.f = nn.DataParallel(networks.MappingNetwork(nz=self.nz, n_domains=self.num_domains, n_style=self.style_dim, hidden_dim=512, hidden_layer=1))
    # Fixed VGG used elsewhere (presumably for perceptual loss — confirm).
    self.vgg = networks.VGG(self.device)
    # Content encoder + generator, picked per translation type.
    if self.type == 1:
        self.enc_c = nn.DataParallel(networks.E_content_style(img_size=opts.img_size, input_dim=opts.input_dim))
        self.gen = nn.DataParallel(networks.Generator_style(img_size=opts.img_size, style_dim=self.style_dim))
    elif self.type == 0:
        self.enc_c = nn.DataParallel(networks.E_content_shape(img_size=opts.img_size, input_dim=opts.input_dim))
        self.gen = nn.DataParallel(networks.Generator_shape(img_size=opts.img_size, style_dim=self.style_dim))
    # Training-only members: discriminators, optimizers, criteria.
    if self.phase == 'train':
        self.lr = opts.lr
        self.f_lr = opts.f_lr
        # Content discriminator uses a reduced learning rate.
        self.lr_dcontent = self.lr / 2.5
        self.dis = nn.DataParallel(networks.Discriminator(img_size=opts.img_size, num_domains=self.num_domains))
        if self.type == 1:
            self.disContent = nn.DataParallel(networks.Dis_content_style(c_dim=self.num_domains))
        elif self.type == 0:
            self.disContent = nn.DataParallel(networks.Dis_content_shape(c_dim=self.num_domains))
        # All optimizers: Adam with beta1=0, beta2=0.99 and small weight decay.
        self.dis_opt = torch.optim.Adam(self.dis.parameters(), lr=self.lr, betas=(0, 0.99), weight_decay=0.0001)
        self.enc_c_opt = torch.optim.Adam(self.enc_c.parameters(), lr=self.lr, betas=(0, 0.99), weight_decay=0.0001)
        self.enc_a_opt = torch.optim.Adam(self.enc_a.parameters(), lr=self.lr, betas=(0, 0.99), weight_decay=0.0001)
        self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=self.lr, betas=(0, 0.99), weight_decay=0.0001)
        self.f_opt = torch.optim.Adam(self.f.parameters(), lr=self.f_lr, betas=(0, 0.99), weight_decay=0.0001)
        self.disContent_opt = torch.optim.Adam(self.disContent.parameters(), lr=self.lr_dcontent, betas=(0, 0.99), weight_decay=0.0001)
        self.criterion_GAN = nn.BCEWithLogitsLoss()
        self.criterion_mmd = utils.get_mmd_loss()
def __init__(self, opt):
    """Assemble the UDAAN model: encoder(s), domain discriminator, classifier
    and the optional center/triplet losses.

    32x32 inputs use the LeNet family of networks; anything else uses the
    ResNet-50 family. Encoders are either shared or a source/target pair.
    """
    super(UDAAN, self).__init__()
    self.share_encoder = opt.share_encoder
    self.use_center_loss = opt.use_center_loss
    self.use_triplet_loss = opt.use_triplet_loss
    self.threshold_T = opt.threshold_T
    digits = opt.image_size == 32  # LeNet family vs. ResNet-50 family
    # Encoders: one shared, or a separate source/target pair.
    if opt.share_encoder:
        if digits:
            self.encoder = networks.LeNetEncoder()
            networks.init_weights(self.encoder)
        else:
            self.encoder = networks.ResNet50_encoder(opt.unfreeze_layers)
    else:
        if digits:
            self.encoder_s = networks.LeNetEncoder()
            self.encoder_t = networks.LeNetEncoder()
            networks.init_weights(self.encoder_s)
            networks.init_weights(self.encoder_t)
        else:
            self.encoder_s = networks.ResNet50_encoder(opt.unfreeze_layers)
            self.encoder_t = networks.ResNet50_encoder(opt.unfreeze_layers)
    # Discriminator and classifier matched to the encoder family.
    if digits:
        self.discriminator = networks.DigitDiscriminator()
        self.classifier = networks.LeNetClassifier()
    else:
        self.discriminator = networks.Discriminator()
        self.classifier = networks.Classifier(opt.num_classes)
    networks.init_weights(self.discriminator)
    networks.init_weights(self.classifier)
    # Optional auxiliary losses; center loss is sized by the classifier's
    # final-layer input width (fc1 for LeNet, fc for ResNet).
    if opt.use_center_loss:
        feat_dim = self.classifier.fc1.in_features if digits else self.classifier.fc.in_features
        self.center_loss = networks.CenterLoss(opt.num_classes, feat_dim)
    elif opt.use_triplet_loss:
        self.triplet_loss = networks.TripletLoss()
def init_model(model_dir, learning_rate=0.0002, beta1=0.5):
    """Create the GAN networks and their Adam optimizers.

    Args:
        model_dir: directory passed to ``utils.print_and_log`` for logging.
        learning_rate: Adam learning rate for both networks (default 2e-4,
            the original hard-coded value).
        beta1: Adam first-moment decay (default 0.5, as before).

    Returns:
        (discriminator_net, generator_net,
         discriminator_optimizer, generator_optimizer)
    """
    # Fall back to CPU instead of crashing on machines without CUDA
    # (the original hard-coded "cuda:0").
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    generator_net = networks.Generator().to(device)
    discriminator_net = networks.Discriminator().to(device)
    # Custom weight initialization, then log both architectures.
    generator_net.apply(networks.weights_init)
    discriminator_net.apply(networks.weights_init)
    utils.print_and_log(model_dir, generator_net)
    utils.print_and_log(model_dir, discriminator_net)
    discriminator_optimizer = optim.Adam(discriminator_net.parameters(), lr=learning_rate, betas=(beta1, 0.999))
    generator_optimizer = optim.Adam(generator_net.parameters(), lr=learning_rate, betas=(beta1, 0.999))
    return discriminator_net, generator_net, discriminator_optimizer, generator_optimizer
def main():
    """Training pipeline for the hint-conditioned cartoon-style GAN.

    Stages:
      1. Optional generator pre-training on a VGG-feature reconstruction loss
         (skipped when a latest generator checkpoint is supplied).
      2. Adversarial training of G/D with a content (VGG) loss; the
         edge-smoothed anime images act as an extra "fake" class for D.
    Checkpoints, preview images and loss histories go under
    ``<args.name>_results``.

    Fix vs. original: the periodic 'discriminator_latest.pkl' checkpoint now
    saves ``discriminator.state_dict()`` (the original saved the generator's
    weights under the discriminator's filename).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.backends.cudnn.enabled:
        torch.backends.cudnn.benchmark = True
    prepare_result()
    make_edge_promoting_img()
    # data_loader
    landscape_dataloader = CreateTrainDataLoader(args, "landscape")
    anime_dataloader = CreateTrainDataLoader(args, "anime")
    landscape_test_dataloader = CreateTestDataLoader(args, "landscape")
    anime_test_dataloader = CreateTestDataLoader(args, "anime")
    generator = networks.Generator(args.ngf)
    if args.latest_generator_model != '':
        if torch.cuda.is_available():
            generator.load_state_dict(torch.load(args.latest_generator_model))
        else:
            # cpu mode
            generator.load_state_dict(
                torch.load(args.latest_generator_model, map_location=lambda storage, loc: storage))
    discriminator = networks.Discriminator(args.in_ndc, args.out_ndc, args.ndf)
    if args.latest_discriminator_model != '':
        if torch.cuda.is_available():
            discriminator.load_state_dict(torch.load(args.latest_discriminator_model))
        else:
            discriminator.load_state_dict(
                torch.load(args.latest_discriminator_model, map_location=lambda storage, loc: storage))
    # VGG19 in feature mode serves as a fixed perceptual-feature extractor.
    VGG = networks.VGG19(init_weights=args.vgg_model, feature_mode=True)
    generator.to(device)
    discriminator.to(device)
    VGG.to(device)
    generator.train()
    discriminator.train()
    VGG.eval()
    G_optimizer = optim.Adam(generator.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
    D_optimizer = optim.Adam(discriminator.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))
    # G_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=G_optimizer, milestones=[args.train_epoch // 2, args.train_epoch // 4 * 3], gamma=0.1)
    # D_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=D_optimizer, milestones=[args.train_epoch // 2, args.train_epoch // 4 * 3], gamma=0.1)
    print('---------- Networks initialized -------------')
    utils.print_network(generator)
    utils.print_network(discriminator)
    utils.print_network(VGG)
    print('-----------------------------------------------')
    BCE_loss = nn.BCELoss().to(device)
    Hinge_loss = nn.HingeEmbeddingLoss().to(device)
    L1_loss = nn.L1Loss().to(device)
    MSELoss = nn.MSELoss().to(device)
    Adv_loss = BCE_loss  # adversarial criterion actually used below
    pre_train_hist = {}
    pre_train_hist['Recon_loss'] = []
    pre_train_hist['per_epoch_time'] = []
    pre_train_hist['total_time'] = []
    """ Pre-train reconstruction """
    if args.latest_generator_model == '':
        print('Pre-training start!')
        start_time = time.time()
        for epoch in range(args.pre_train_epoch):
            epoch_start_time = time.time()
            Recon_losses = []
            for lcimg, lhimg, lsimg in landscape_dataloader:
                lcimg, lhimg, lsimg = lcimg.to(device), lhimg.to(device), lsimg.to(device)
                # train generator G
                G_optimizer.zero_grad()
                # (x + 1) / 2 rescales before VGG — assumes images in [-1, 1].
                x_feature = VGG((lcimg + 1) / 2)
                mask = mask_gen()
                hint = torch.cat((lhimg * mask, mask), 1)
                gen_img = generator(lsimg, hint)
                G_feature = VGG((gen_img + 1) / 2)
                Recon_loss = 10 * L1_loss(G_feature, x_feature.detach())
                Recon_losses.append(Recon_loss.item())
                pre_train_hist['Recon_loss'].append(Recon_loss.item())
                Recon_loss.backward()
                G_optimizer.step()
            per_epoch_time = time.time() - epoch_start_time
            pre_train_hist['per_epoch_time'].append(per_epoch_time)
            print('[%d/%d] - time: %.2f, Recon loss: %.3f' %
                  ((epoch + 1), args.pre_train_epoch, per_epoch_time,
                   torch.mean(torch.FloatTensor(Recon_losses))))
            # Save 5 train + 5 test reconstruction previews every 5 epochs.
            if (epoch + 1) % 5 == 0:
                with torch.no_grad():
                    generator.eval()
                    for n, (lcimg, lhimg, lsimg) in enumerate(landscape_dataloader):
                        lcimg, lhimg, lsimg = lcimg.to(device), lhimg.to(device), lsimg.to(device)
                        mask = mask_gen()
                        hint = torch.cat((lhimg * mask, mask), 1)
                        g_recon = generator(lsimg, hint)
                        result = torch.cat((lcimg[0], g_recon[0]), 2)
                        path = os.path.join(
                            args.name + '_results', 'Reconstruction',
                            args.name + '_train_recon_' + f'epoch_{epoch}_' + str(n + 1) + '.png')
                        plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                        if n == 4:
                            break
                    for n, (lcimg, lhimg, lsimg) in enumerate(landscape_test_dataloader):
                        lcimg, lhimg, lsimg = lcimg.to(device), lhimg.to(device), lsimg.to(device)
                        mask = mask_gen()
                        hint = torch.cat((lhimg * mask, mask), 1)
                        g_recon = generator(lsimg, hint)
                        result = torch.cat((lcimg[0], g_recon[0]), 2)
                        path = os.path.join(
                            args.name + '_results', 'Reconstruction',
                            args.name + '_test_recon_' + f'epoch_{epoch}_' + str(n + 1) + '.png')
                        plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                        if n == 4:
                            break
        total_time = time.time() - start_time
        pre_train_hist['total_time'].append(total_time)
        with open(os.path.join(args.name + '_results', 'pre_train_hist.pkl'), 'wb') as f:
            pickle.dump(pre_train_hist, f)
        torch.save(generator.state_dict(),
                   os.path.join(args.name + '_results', 'generator_pretrain.pkl'))
    else:
        print('Load the latest generator model, no need to pre-train')
    train_hist = {}
    train_hist['Disc_loss'] = []
    train_hist['Gen_loss'] = []
    train_hist['Con_loss'] = []
    train_hist['per_epoch_time'] = []
    train_hist['total_time'] = []
    print('training start!')
    start_time = time.time()
    # Target label maps for the discriminator output (input_size/4 per side).
    real = torch.ones(args.batch_size, 1, args.input_size // 4, args.input_size // 4).to(device)
    fake = torch.zeros(args.batch_size, 1, args.input_size // 4, args.input_size // 4).to(device)
    for epoch in range(args.train_epoch):
        epoch_start_time = time.time()
        generator.train()
        Disc_losses = []
        Gen_losses = []
        Con_losses = []
        for i, ((acimg, ac_smooth_img, _), (lcimg, lhimg, lsimg)) in enumerate(
                zip(anime_dataloader, landscape_dataloader)):
            acimg, ac_smooth_img, lcimg, lhimg, lsimg = acimg.to(device), ac_smooth_img.to(
                device), lcimg.to(device), lhimg.to(device), lsimg.to(device)
            if i % args.n_dis == 0:
                # train G (only every n_dis steps)
                G_optimizer.zero_grad()
                mask = mask_gen()
                hint = torch.cat((lhimg * mask, mask), 1)
                gen_img = generator(lsimg, hint)
                D_fake = discriminator(gen_img)
                D_fake_loss = Adv_loss(D_fake, real)
                x_feature = VGG((lcimg + 1) / 2)
                G_feature = VGG((gen_img + 1) / 2)
                Con_loss = args.con_lambda * L1_loss(G_feature, x_feature.detach())
                Gen_loss = D_fake_loss + Con_loss
                Gen_losses.append(D_fake_loss.item())
                train_hist['Gen_loss'].append(D_fake_loss.item())
                Con_losses.append(Con_loss.item())
                train_hist['Con_loss'].append(Con_loss.item())
                Gen_loss.backward()
                G_optimizer.step()
                # G_scheduler.step()
            # train D
            D_optimizer.zero_grad()
            D_real = discriminator(acimg)
            D_real_loss = Adv_loss(D_real, real)
            # Hinge Loss (?)
            mask = mask_gen()
            hint = torch.cat((lhimg * mask, mask), 1)
            gen_img = generator(lsimg, hint)
            D_fake = discriminator(gen_img)
            D_fake_loss = Adv_loss(D_fake, fake)
            # Edge-smoothed anime images are trained toward the "fake" label.
            D_edge = discriminator(ac_smooth_img)
            D_edge_loss = Adv_loss(D_edge, fake)
            Disc_loss = D_real_loss + D_fake_loss + D_edge_loss
            # Disc_loss = D_real_loss + D_fake_loss
            Disc_losses.append(Disc_loss.item())
            train_hist['Disc_loss'].append(Disc_loss.item())
            Disc_loss.backward()
            D_optimizer.step()
        # G_scheduler.step()
        # D_scheduler.step()
        per_epoch_time = time.time() - epoch_start_time
        train_hist['per_epoch_time'].append(per_epoch_time)
        print(
            '[%d/%d] - time: %.2f, Disc loss: %.3f, Gen loss: %.3f, Con loss: %.3f'
            % ((epoch + 1), args.train_epoch, per_epoch_time,
               torch.mean(torch.FloatTensor(Disc_losses)),
               torch.mean(torch.FloatTensor(Gen_losses)),
               torch.mean(torch.FloatTensor(Con_losses))))
        if epoch % 2 == 1 or epoch == args.train_epoch - 1:
            with torch.no_grad():
                generator.eval()
                for n, (lcimg, lhimg, lsimg) in enumerate(landscape_dataloader):
                    lcimg, lhimg, lsimg = lcimg.to(device), lhimg.to(device), lsimg.to(device)
                    mask = mask_gen()
                    hint = torch.cat((lhimg * mask, mask), 1)
                    g_recon = generator(lsimg, hint)
                    result = torch.cat((lcimg[0], g_recon[0]), 2)
                    path = os.path.join(
                        args.name + '_results', 'Transfer',
                        str(epoch + 1) + '_epoch_' + args.name + '_train_' + str(n + 1) + '.png')
                    plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                    if n == 4:
                        break
                for n, (lcimg, lhimg, lsimg) in enumerate(landscape_test_dataloader):
                    lcimg, lhimg, lsimg = lcimg.to(device), lhimg.to(device), lsimg.to(device)
                    mask = mask_gen()
                    hint = torch.cat((lhimg * mask, mask), 1)
                    g_recon = generator(lsimg, hint)
                    result = torch.cat((lcimg[0], g_recon[0]), 2)
                    path = os.path.join(
                        args.name + '_results', 'Transfer',
                        str(epoch + 1) + '_epoch_' + args.name + '_test_' + str(n + 1) + '.png')
                    plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                    if n == 4:
                        break
            torch.save(generator.state_dict(),
                       os.path.join(args.name + '_results', 'generator_latest.pkl'))
            # BUG FIX: the original saved generator.state_dict() under the
            # discriminator's filename, making this checkpoint useless.
            torch.save(discriminator.state_dict(),
                       os.path.join(args.name + '_results', 'discriminator_latest.pkl'))
    total_time = time.time() - start_time
    train_hist['total_time'].append(total_time)
    print("Avg one epoch time: %.2f, total %d epochs time: %.2f" %
          (torch.mean(torch.FloatTensor(train_hist['per_epoch_time'])),
           args.train_epoch, total_time))
    print("Training finish!... save training results")
    torch.save(generator.state_dict(),
               os.path.join(args.name + '_results', 'generator_param.pkl'))
    torch.save(discriminator.state_dict(),
               os.path.join(args.name + '_results', 'discriminator_param.pkl'))
    with open(os.path.join(args.name + '_results', 'train_hist.pkl'), 'wb') as f:
        pickle.dump(train_hist, f)
import torch  # BUG FIX: `torch.randn` is used below but only sub-modules were imported
import torch.nn as nn
import torch.utils.data as data
import matplotlib.pyplot as plt

import networks as net
import dataloader as dl

##### Smoke-test the Generator #####
G = net.Generator()
input_image = torch.randn(128, 128, 3)
# NOTE(review): .view only reshapes memory; if a true (H, W, C) -> (N, C, H, W)
# layout change is intended, permute would be needed. Harmless for a
# shape-only smoke test with random data.
input_image = input_image.view(1, 3, 128, 128)
fake_image = G(input_image)
print('fake image:', fake_image.shape)

##### Smoke-test the Discriminator #####
D = net.Discriminator()
d_out = D(fake_image)
print('Discriminator output', nn.Sigmoid()(d_out).shape)  # squash output to [0, 1]

##### Smoke-test Dataset / DataLoader #####
train_img_A, train_img_B = dl.make_datapath_list(is_train=True)
# Build the Dataset
mean = (0.5, )
std = (0.5, )
train_dataset = dl.UnpairedDataset(train_img_A, train_img_B, transform=dl.ImageTransform(mean, std))
# Build the DataLoader
# "paragraph-vectors/data/sentences_train_model.dbow_numnoisewords\ # .2_vecdim.100_batchsize.32_lr.0.001000_epoch.100_loss.0.781092.csv" # ) # change this to val # val_dataset = dataservices.RecipeQADataset( # "recipeqa/new_val_cleaned.json", # "recipeqa/features/val", # val_embeddings # ) # val_dataloader = torch.utils.data.DataLoader( # val_dataset, batch_size=opt.batchSize, # shuffle=True, num_workers=int(opt.workers), # collate_fn=dataservices.batch_collator(device=device) # ) netG = networks.Generator(opt).to(device) netD = networks.Discriminator(opt).to(device) def weights_init(m): pass criterionD = nn.BCEWithLogitsLoss() # logsigmoid + binary cross entropy criterionG = nn.CrossEntropyLoss() # logsoftmax + multi-class cross entropy optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) def fscore(probabilities, labels): true_positives = float(
def main():
    # Paired image-to-image translation training (pix2pix-style):
    # G_A maps domain A -> B; the conditional D_B judges (output, input) pairs.
    # Get training options
    opt = get_opt()
    device = torch.device("cuda") if opt.cuda else torch.device("cpu")
    # Define the networks
    # netG_A: used to transfer image from domain A to domain B
    netG_A = networks.Generator(opt.input_nc, opt.output_nc, opt.ngf, opt.n_res, opt.dropout)
    if opt.u_net:
        # Optionally replace the ResNet-style generator with a U-Net.
        netG_A = networks.U_net(opt.input_nc, opt.output_nc, opt.ngf)
    # netD_B: used to test whether an image is from domain A
    # (conditional — it sees the translated image concatenated with its input).
    netD_B = networks.Discriminator(opt.input_nc + opt.output_nc, opt.ndf)
    # Initialize the networks
    if opt.cuda:
        netG_A.cuda()
        netD_B.cuda()
    utils.init_weight(netG_A)
    utils.init_weight(netD_B)
    if opt.pretrained:
        # Resume from fixed checkpoint paths (overwrites the fresh init above).
        netG_A.load_state_dict(torch.load('pretrained/netG_A.pth'))
        netD_B.load_state_dict(torch.load('pretrained/netD_B.pth'))
    # Define the loss functions
    criterion_GAN = utils.GANLoss()
    if opt.cuda:
        criterion_GAN.cuda()
    criterion_l1 = torch.nn.L1Loss()
    # Define the optimizers
    optimizer_G = torch.optim.Adam(netG_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    # Create learning rate schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda = utils.Lambda_rule(opt.epoch, opt.n_epochs, opt.n_epochs_decay).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda = utils.Lambda_rule(opt.epoch, opt.n_epochs, opt.n_epochs_decay).step)
    # Define the transform, and load the data
    transform = transforms.Compose([transforms.Resize((opt.sizeh, opt.sizew)), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    dataloader = DataLoader(PairedImage(opt.rootdir, transform = transform, mode = 'train'), batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
    # numpy arrays to store the loss of epoch
    loss_G_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    loss_D_B_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    # Training
    for epoch in range(opt.epoch, opt.n_epochs + opt.n_epochs_decay):
        start = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " start time :", start)
        # Empty list to store the loss of each mini-batch
        loss_G_list = []
        loss_D_B_list = []
        for i, batch in enumerate(dataloader):
            # Progress report every 20 steps (i % 20 == 1 ensures the loss
            # lists already contain at least one entry when printed).
            if i % 20 == 1:
                print("current step: ", i)
                current = time.strftime("%H:%M:%S")
                print("current time :", current)
                print("last loss G_A:", loss_G_list[-1], "last loss D_B:", loss_D_B_list[-1])
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)
            # Train the generator
            utils.set_requires_grad([netG_A], True)
            optimizer_G.zero_grad()
            # Compute fake images and reconstructed images
            fake_B = netG_A(real_A)
            # discriminators require no gradients when optimizing generators
            utils.set_requires_grad([netD_B], False)
            # GAN loss
            prediction_fake_B = netD_B(torch.cat((fake_B, real_A), dim=1))
            loss_gan = criterion_GAN(prediction_fake_B, True)
            # L1 loss, weighted by opt.l1_loss
            loss_l1 = criterion_l1(real_B, fake_B) * opt.l1_loss
            # total loss without the identity loss
            loss_G = loss_gan + loss_l1
            loss_G_list.append(loss_G.item())
            loss_G.backward()
            optimizer_G.step()
            # Train the discriminator
            utils.set_requires_grad([netG_A], False)
            utils.set_requires_grad([netD_B], True)
            # Train the discriminator D_B
            optimizer_D_B.zero_grad()
            # real images
            pred_real = netD_B(torch.cat((real_B, real_A), dim=1))
            loss_D_real = criterion_GAN(pred_real, True)
            # fake images (regenerated here so the graph is fresh for D's pass)
            fake_B = netG_A(real_A)
            pred_fake = netD_B(torch.cat((fake_B, real_A), dim=1))
            loss_D_fake = criterion_GAN(pred_fake, False)
            # total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B_list.append(loss_D_B.item())
            loss_D_B.backward()
            optimizer_D_B.step()
        # Update the learning rate
        lr_scheduler_G.step()
        lr_scheduler_D_B.step()
        # Save models checkpoints (overwritten every epoch)
        torch.save(netG_A.state_dict(), 'model/netG_A_pix.pth')
        torch.save(netD_B.state_dict(), 'model/netD_B_pix.pth')
        # Save other checkpoint information
        checkpoint = {'epoch': epoch, 'optimizer_G': optimizer_G.state_dict(), 'optimizer_D_B': optimizer_D_B.state_dict(), 'lr_scheduler_G': lr_scheduler_G.state_dict(), 'lr_scheduler_D_B': lr_scheduler_D_B.state_dict()}
        torch.save(checkpoint, 'model/checkpoint.pth')
        # Update the numpy arrays that record the loss
        loss_G_array[epoch] = sum(loss_G_list) / len(loss_G_list)
        loss_D_B_array[epoch] = sum(loss_D_B_list) / len(loss_D_B_list)
        np.savetxt('model/loss_G.txt', loss_G_array)
        np.savetxt('model/loss_D_B.txt', loss_D_B_array)
        end = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " end time :", end)
        print("G loss :", loss_G_array[epoch], "D_B loss :", loss_D_B_array[epoch])
def __init__(self, opt=None):
    """Build the CycleGAN model.

    Args:
        opt: option namespace (channels, out_channels, n_residual_blocks,
            image_size, lr, b1, device, ...). When None, defaults from
            ``self.default_option()`` are used.
    """
    super(cycleGAN, self).__init__()
    # `is None` instead of `== None`: identity is the correct None test
    # (and avoids surprises from custom __eq__ implementations).
    if opt is None:
        opt = self.default_option()
    self.opt = opt
    nic = opt.channels
    noc = opt.out_channels
    # model & loss names
    self.model_names = ['GenA', 'GenB', 'DisA', 'DisB']
    self.loss_names = [
        'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'
    ]
    # Generators: ResnetGenerator(nic, out_channels, num_residual_blocks)
    self.GenA = networks.ResnetGenerator(nic, noc, opt.n_residual_blocks)
    self.GenB = networks.ResnetGenerator(nic, noc, opt.n_residual_blocks)
    # Discriminators: Discriminator(nic, image_size)
    self.DisA = networks.Discriminator(nic, opt.image_size)
    self.DisB = networks.Discriminator(nic, opt.image_size)
    # Loss functions: adversarial (MSE), cycle-consistency (L1), identity (L1).
    self.criterion_GAN = nn.MSELoss()
    self.criterion_Cycle = nn.L1Loss()
    self.criterion_idt = nn.L1Loss()
    self.optimizers = []
    # One optimizer over both generators, one over both discriminators.
    self.optimizer_G = torch.optim.Adam(itertools.chain(
        self.GenA.parameters(), self.GenB.parameters()),
                                        lr=opt.lr,
                                        betas=(opt.b1, 0.999))
    self.optimizer_D = torch.optim.Adam(itertools.chain(
        self.DisA.parameters(), self.DisB.parameters()),
                                        lr=opt.lr,
                                        betas=(opt.b1, 0.999))
    self.optimizers.append(self.optimizer_G)
    self.optimizers.append(self.optimizer_D)
    # self.save_dir = opt.save_dir
    self.device = opt.device
    step_size = 100
    self.schedulers = [
        networks.get_scheduler(optimizer, step_size)
        for optimizer in self.optimizers
    ]
def main():
    """Full CycleGAN training loop: builds the two generators and two
    discriminators, trains them adversarially with cycle-consistency (and
    optional identity) losses, and checkpoints models/losses every epoch."""
    # Get training options
    opt = get_opt()
    # Define the networks
    # netG_A: used to transfer image from domain A to domain B
    # netG_B: used to transfer image from domain B to domain A
    netG_A = networks.Generator(opt.input_nc, opt.output_nc, opt.ngf,
                                opt.n_res, opt.dropout)
    netG_B = networks.Generator(opt.output_nc, opt.input_nc, opt.ngf,
                                opt.n_res, opt.dropout)
    if opt.u_net:
        # Optional U-Net generators replace the resnet-style ones.
        netG_A = networks.U_net(opt.input_nc, opt.output_nc, opt.ngf)
        netG_B = networks.U_net(opt.output_nc, opt.input_nc, opt.ngf)
    # netD_A: judges domain-A images (fed real_A / fake_A below)
    # netD_B: judges domain-B images (fed real_B / fake_B below)
    # NOTE(review): the original comments said the opposite; the usage in the
    # training loop (netD_A(real_A), netD_B(real_B)) grounds this reading.
    netD_A = networks.Discriminator(opt.input_nc, opt.ndf)
    netD_B = networks.Discriminator(opt.output_nc, opt.ndf)
    # Initialize the networks
    if opt.cuda:
        netG_A.cuda()
        netG_B.cuda()
        netD_A.cuda()
        netD_B.cuda()
    utils.init_weight(netG_A)
    utils.init_weight(netG_B)
    utils.init_weight(netD_A)
    utils.init_weight(netD_B)
    if opt.pretrained:
        # Resume from fixed pretrained checkpoint paths.
        netG_A.load_state_dict(torch.load('pretrained/netG_A.pth'))
        netG_B.load_state_dict(torch.load('pretrained/netG_B.pth'))
        netD_A.load_state_dict(torch.load('pretrained/netD_A.pth'))
        netD_B.load_state_dict(torch.load('pretrained/netD_B.pth'))
    # Define the loss functions
    criterion_GAN = utils.GANLoss()
    if opt.cuda:
        criterion_GAN.cuda()
    criterion_cycle = torch.nn.L1Loss()
    # Alternatively, can try MSE cycle consistency loss
    #criterion_cycle = torch.nn.MSELoss()
    criterion_identity = torch.nn.L1Loss()
    # Define the optimizers: one over both generators, one per discriminator.
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A.parameters(),
                                                   netG_B.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.beta1, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.beta1, 0.999))
    # Create learning rate schedulers (linear decay after opt.n_epochs).
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                    opt.n_epochs_decay).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                    opt.n_epochs_decay).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                    opt.n_epochs_decay).step)
    # Preallocated input buffers, reused (via copy_) for every mini-batch.
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batch_size, opt.input_nc, opt.sizeh, opt.sizew)
    input_B = Tensor(opt.batch_size, opt.output_nc, opt.sizeh, opt.sizew)
    # Define two image pools to store generated images (stabilizes D training)
    fake_A_pool = utils.ImagePool()
    fake_B_pool = utils.ImagePool()
    # Define the transform, and load the data
    transform = transforms.Compose([
        transforms.Resize((opt.sizeh, opt.sizew)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, ))
    ])
    dataloader = DataLoader(ImageDataset(opt.rootdir,
                                         transform=transform,
                                         mode='train'),
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.n_cpu)
    # numpy arrays to store the per-epoch mean loss
    loss_G_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    loss_D_A_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    loss_D_B_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    # Training
    for epoch in range(opt.epoch, opt.n_epochs + opt.n_epochs_decay):
        start = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " start time :", start)
        # Empty list to store the loss of each mini-batch
        loss_G_list = []
        loss_D_A_list = []
        loss_D_B_list = []
        for i, batch in enumerate(dataloader):
            if i % 50 == 1:
                # Progress report; i starts at 1 here so the loss lists are
                # guaranteed non-empty when indexed with [-1].
                print("current step: ", i)
                current = time.strftime("%H:%M:%S")
                print("current time :", current)
                print("last loss G:", loss_G_list[-1], "last loss D_A",
                      loss_D_A_list[-1], "last loss D_B", loss_D_B_list[-1])
            real_A = input_A.copy_(batch['A'])
            real_B = input_B.copy_(batch['B'])
            # Train the generator
            optimizer_G.zero_grad()
            # Compute fake images and reconstructed images
            fake_B = netG_A(real_A)
            fake_A = netG_B(real_B)
            if opt.identity_loss != 0:
                same_B = netG_A(real_B)
                same_A = netG_B(real_A)
            # discriminators require no gradients when optimizing generators
            utils.set_requires_grad([netD_A, netD_B], False)
            # Identity loss
            if opt.identity_loss != 0:
                loss_identity_A = criterion_identity(
                    same_A, real_A) * opt.identity_loss
                loss_identity_B = criterion_identity(
                    same_B, real_B) * opt.identity_loss
            # GAN loss: generators try to make D label fakes as real (True)
            prediction_fake_B = netD_B(fake_B)
            loss_gan_B = criterion_GAN(prediction_fake_B, True)
            prediction_fake_A = netD_A(fake_A)
            loss_gan_A = criterion_GAN(prediction_fake_A, True)
            # Cycle consistent loss
            recA = netG_B(fake_B)
            recB = netG_A(fake_A)
            loss_cycle_A = criterion_cycle(recA, real_A) * opt.cycle_loss
            loss_cycle_B = criterion_cycle(recB, real_B) * opt.cycle_loss
            # total loss without the identity loss
            loss_G = loss_gan_B + loss_gan_A + loss_cycle_A + loss_cycle_B
            if opt.identity_loss != 0:
                loss_G += loss_identity_A + loss_identity_B
            loss_G_list.append(loss_G.item())
            loss_G.backward()
            optimizer_G.step()
            # Train the discriminator
            utils.set_requires_grad([netD_A, netD_B], True)
            # Train the discriminator D_A
            optimizer_D_A.zero_grad()
            # real images
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, True)
            # fake images, drawn from the history pool; detach so no gradient
            # flows back into the generator here.
            fake_A = fake_A_pool.query(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, False)
            #total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A_list.append(loss_D_A.item())
            loss_D_A.backward()
            optimizer_D_A.step()
            # Train the discriminator D_B
            optimizer_D_B.zero_grad()
            # real images
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, True)
            # fake images
            fake_B = fake_B_pool.query(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, False)
            # total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B_list.append(loss_D_B.item())
            loss_D_B.backward()
            optimizer_D_B.step()
        # Update the learning rate (once per epoch)
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()
        # Save models checkpoints (overwritten every epoch)
        torch.save(netG_A.state_dict(), 'model/netG_A.pth')
        torch.save(netG_B.state_dict(), 'model/netG_B.pth')
        torch.save(netD_A.state_dict(), 'model/netD_A.pth')
        torch.save(netD_B.state_dict(), 'model/netD_B.pth')
        # Save other checkpoint information
        checkpoint = {
            'epoch': epoch,
            'optimizer_G': optimizer_G.state_dict(),
            'optimizer_D_A': optimizer_D_A.state_dict(),
            'optimizer_D_B': optimizer_D_B.state_dict(),
            'lr_scheduler_G': lr_scheduler_G.state_dict(),
            'lr_scheduler_D_A': lr_scheduler_D_A.state_dict(),
            'lr_scheduler_D_B': lr_scheduler_D_B.state_dict()
        }
        torch.save(checkpoint, 'model/checkpoint.pth')
        # Update the numpy arrays that record the loss
        loss_G_array[epoch] = sum(loss_G_list) / len(loss_G_list)
        loss_D_A_array[epoch] = sum(loss_D_A_list) / len(loss_D_A_list)
        loss_D_B_array[epoch] = sum(loss_D_B_list) / len(loss_D_B_list)
        np.savetxt('model/loss_G.txt', loss_G_array)
        np.savetxt('model/loss_D_A.txt', loss_D_A_array)
        # NOTE(review): lowercase 'b' is inconsistent with 'loss_D_A.txt';
        # downstream tooling reading these files should be checked.
        np.savetxt('model/loss_D_b.txt', loss_D_B_array)
        if epoch % 10 == 9:
            # Additionally keep a numbered snapshot every 10 epochs.
            torch.save(netG_A.state_dict(),
                       'model/netG_A' + str(epoch) + '.pth')
            torch.save(netG_B.state_dict(),
                       'model/netG_B' + str(epoch) + '.pth')
            torch.save(netD_A.state_dict(),
                       'model/netD_A' + str(epoch) + '.pth')
            torch.save(netD_B.state_dict(),
                       'model/netD_B' + str(epoch) + '.pth')
        end = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " end time :", end)
        print("G loss :", loss_G_array[epoch], "D_A loss :",
              loss_D_A_array[epoch], "D_B loss :", loss_D_B_array[epoch])
lr = 0.0001 betas = (0.5, 0.999) batch_size = 1 input_size = 128 # patch_size = 8(128/2^4) num_epochs = 200 lambda_cycle = 10.0 lambda_identity = 5.0 # GPUが使用できるか確認 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') print('使用デバイス:', device) # 各生成器と判別器をインスタンス化 G_A2B = net.Generator() G_B2A = net.Generator() D_A = net.Discriminator() D_B = net.Discriminator() # GPUが使用できるならモデルをGPUに載せて初期化 if device == 'cuda:0': G_A2B.cuda() G_B2A.cuda() D_A.cuda() D_B.cuda() print('ネットワークの初期化中...', end='') G_A2B.apply(utils.weights_init) G_B2A.apply(utils.weights_init) D_A.apply(utils.weights_init) D_B.apply(utils.weights_init) print('完了!')
if not os.path.exists(dn): os.makedirs(fp) manual_seed = 999 torch.manual_seed(manual_seed) batch_size = 1000 fixed_size = 10000 train_epoch = 20000 lr_d = 0.00002 lr_g = 0.0002 beta1 = 0.5 generator = networks.Generator(d_noise, d_data) discriminator = networks.Discriminator(d_data) s = sampler.SAMPLER() # plt.ion() fig = plt.figure(figsize=(6, 6)) fig.canvas.set_window_title("2D Generation") fig.suptitle(f"dimension data:{d_data}, dimension noise:{d_noise}") if d_data == 2: # 感觉这个if 可以写在sampler类里面,没必要在这里拿出来 sam = s.sampler(bs=batch_size) sam_show = s.sampler(bs=fixed_size) ax1 = fig.add_subplot(221) ax2 = fig.add_subplot(222) ax3 = fig.add_subplot(223) ax4 = fig.add_subplot(224)
def main(args): generator = networks.Generator() print(generator) discriminator = networks.Discriminator() print(discriminator)
num_train, num_test = len(train_data) , len(test_data) train_loader = DataLoader(train_data,batch_size = opt.batch_size, shuffle = True, num_workers = 4) test_loader = DataLoader(test_data,batch_size = opt.batch_size, shuffle = False, num_workers = 4) # for i in train_loader: # print(i[0].size()) train_labels = LoadLabel(TRAIN_DIR) train_labels_onehot = EncodingOnehot(train_labels, nclasses) test_labels = LoadLabel(TEST_DIR) test_labels_onehot = EncodingOnehot(test_labels, nclasses) Y = train_labels_onehot G = networks.Generator(opt.g_input_size,opt.g_hidden_size,opt.g_output_size) D = networks.Discriminator(opt.d_input_size,opt.d_hidden_size,opt.d_output_size) H = networks.Hashnet(opt.h_input_size,opt.h_hidden_size,opt.bit) # print(G) # print(D) # print(H) c3d = networks.C3D() c3d_dict = c3d.state_dict() pretrained_dict = torch.load('./c3d.pickle') pretrained_dict = {k : v for k, v in pretrained_dict.items() if k in c3d_dict} c3d_dict.update(pretrained_dict) c3d.load_state_dict(c3d_dict) c3d.fc8 = nn.Linear(4096,51) G.cuda() D.cuda() H.cuda()
def __init__(self, config):
    """Configure a progressive-growing GAN (PGGAN) run: copies settings from
    ``config``, builds generator/discriminator, prepares output directories
    and a TF summary writer, and grows the networks up to the configured
    start resolution (optionally inside a distribution strategy scope)."""
    super(PGGAN, self).__init__()
    self.tf_record_dir = config.tf_record_dir
    self.latent_size = config.latent_size
    self.label_size = config.label_size
    # Labels are optional; a zero label size means unconditional training.
    self.labels_exist = self.label_size > 0
    self.dimensionality = config.dimensionality
    self.num_channels = config.num_channels
    self.learning_rate = config.lr
    self.gpus = config.gpus
    # Number of discriminator updates per generator update.
    self.d_repeats = config.d_repeats
    # Iteration budgets are in units of kilo-iterations (config names: kiters_*).
    self.iters_per_transition = config.kiters_per_transition
    self.iters_per_resolution = config.kiters_per_resolution
    self.start_resolution = config.start_resolution
    self.target_resolution = config.target_resolution
    self.resolution_batch_size = config.resolution_batch_size
    self.img_ext = config.img_ext
    self.generator = networks.Generator(self.latent_size,
                                        dimensionality=self.dimensionality,
                                        num_channels=self.num_channels,
                                        fmap_base=config.g_fmap_base)
    self.discriminator = networks.Discriminator(
        self.label_size,
        dimensionality=self.dimensionality,
        num_channels=self.num_channels,
        fmap_base=config.d_fmap_base)
    # Per-run output layout: generated samples, model checkpoints, TB logs.
    self.run_id = Path(config.run_id)
    self.generated_dir = self.run_id.joinpath(config.generated_dir)
    self.model_dir = self.run_id.joinpath(config.model_dir)
    self.log_dir = self.run_id.joinpath(config.log_dir)
    self.run_id.mkdir(exist_ok=True)
    self.generated_dir.mkdir(exist_ok=True)
    self.model_dir.mkdir(exist_ok=True)
    self.train_summary_writer = tf.summary.create_file_writer(
        str(self.log_dir))
    # Grow both networks from 4x4 (resolution exponent 2) up to the start
    # resolution: one unconditional add_resolution, then one per doubling.
    current_resolution = 2
    self.add_resolution()
    while 2**current_resolution < self.start_resolution:
        self.add_resolution()
        current_resolution += 1
    self.strategy = config.strategy
    if self.strategy is not None:
        # NOTE(review): under a tf.distribute strategy the networks must be
        # (re)built inside strategy.scope(), so the construction and growth
        # code above is intentionally repeated here — the objects created
        # outside the scope are discarded. Consider factoring the shared
        # build steps into a helper to remove the duplication.
        with self.strategy.scope():
            self.generator = networks.Generator(
                self.latent_size,
                dimensionality=self.dimensionality,
                num_channels=self.num_channels,
                fmap_base=config.g_fmap_base)
            self.discriminator = networks.Discriminator(
                self.label_size,
                dimensionality=self.dimensionality,
                num_channels=self.num_channels,
                fmap_base=config.d_fmap_base)
            current_resolution = 2
            self.add_resolution()
            while 2**current_resolution < self.start_resolution:
                self.add_resolution()
                current_resolution += 1
def __init__(self, options):
    """Set up the depth-estimation trainer: networks (encoder/decoder, pose,
    optional refinement and GAN discriminator), optimizers, LR schedulers,
    KITTI data loaders, tensorboard writers and reprojection helpers.

    BUG FIX: the original wrote ``param.requeires_grad = False`` (typo) in
    every freeze loop; that assigns a junk attribute and silently leaves the
    parameters trainable. Corrected to ``param.requires_grad`` throughout.
    """
    self.opt = options
    self.dropout = options.dropout
    self.refine_stage = list(range(options.refine_stage))
    self.refine = options.refine
    self.crop_mode = options.crop_mode
    self.gan = options.gan
    self.gan2 = options.gan2
    self.edge_refine = options.edge_refine
    self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)

    # checking height and width are multiples of 32
    assert self.opt.height % 32 == 0, "'height' must be a multiple of 32"
    assert self.opt.width % 32 == 0, "'width' must be a multiple of 32"

    self.models = {}
    self.parameters_to_train = []
    self.parameters_to_train_refine = []
    self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")
    self.num_scales = len(self.opt.scales)
    self.num_input_frames = len(self.opt.frame_ids)
    self.num_pose_frames = 2 if self.opt.pose_model_input == "pairs" else self.num_input_frames
    assert self.opt.frame_ids[0] == 0, "frame_ids must start with 0"
    self.use_pose_net = not (self.opt.use_stereo
                             and self.opt.frame_ids == [0])

    if self.refine:
        # Per-stage crop sizes for the iterative refinement cascade.
        if len(self.refine_stage) > 4:
            if self.crop_mode == 'b' or self.crop_mode == 'cl':
                self.crop_h = [128, 168, 192, 192, 192]
                self.crop_w = [192, 256, 384, 448, 640]
                # self.crop_h = [192,192,192,192,192]
                # self.crop_w = [192,256,384,448,640]
            else:
                self.crop_h = [96, 128, 160, 192, 192]
                self.crop_w = [192, 256, 384, 448, 640]
        else:
            self.crop_h = [96, 128, 160, 192]
            self.crop_w = [192, 256, 384, 640]
        # Select the refinement propagation module variant.
        if self.opt.refine_model == 's':
            self.models["mid_refine"] = networks.Simple_Propagate(
                self.crop_h, self.crop_w, self.crop_mode, self.dropout)
        elif self.opt.refine_model == 'i':
            self.models["mid_refine"] = networks.Iterative_Propagate(
                self.crop_h, self.crop_w, self.crop_mode, False)
        elif self.opt.refine_model == 'is':
            self.models["mid_refine"] = networks.Iterative_Propagate_seq(
                self.crop_h, self.crop_w, self.crop_mode, False)
        elif self.opt.refine_model == 'io':
            self.models["mid_refine"] = networks.Iterative_Propagate_old(
                self.crop_h, self.crop_w, self.crop_mode, False)
        self.models["mid_refine"].to(self.device)
        self.parameters_to_train_refine += list(
            self.models["mid_refine"].parameters())

    # NOTE(review): if both gan and gan2 are set, gan2's Discriminator_group
    # overwrites "netD" while the pix2pix branch below still picks the `gan`
    # loss — confirm the flags are meant to be mutually exclusive.
    if self.gan:
        self.models["netD"] = networks.Discriminator()
        self.models["netD"].to(self.device)
        self.parameters_D = list(self.models["netD"].parameters())
    if self.gan2:
        self.models["netD"] = networks.Discriminator_group()
        self.models["netD"].to(self.device)
        self.parameters_D = list(self.models["netD"].parameters())

    # Depth encoder/decoder.
    self.models["encoder"] = networks.ResnetEncoder(
        self.opt.num_layers,
        self.opt.weights_init == "pretrained",
        num_input_images=1)
    self.models["encoder"].to(self.device)
    self.parameters_to_train += list(self.models["encoder"].parameters())
    if self.refine:
        # Freeze the encoder during refinement training (typo fixed).
        for param in self.models["encoder"].parameters():
            param.requires_grad = False

    self.models["depth"] = networks.DepthDecoder(
        self.models["encoder"].num_ch_enc,
        self.opt.scales,
        refine=self.refine)
    self.models["depth"].to(self.device)
    self.parameters_to_train += list(self.models["depth"].parameters())

    # Pose network.
    self.models["pose_encoder"] = networks.ResnetEncoder(
        self.opt.num_layers,
        self.opt.weights_init == "pretrained",
        num_input_images=self.num_pose_frames)
    self.models["pose_encoder"].to(self.device)
    self.parameters_to_train += list(
        self.models["pose_encoder"].parameters())
    if self.refine:
        for param in self.models["pose_encoder"].parameters():
            param.requires_grad = False

    self.models["pose"] = networks.PoseDecoder(
        self.models["pose_encoder"].num_ch_enc,
        num_input_features=1,
        num_frames_to_predict_for=2)
    self.models["pose"].to(self.device)
    self.parameters_to_train += list(self.models["pose"].parameters())
    if self.refine:
        for param in self.models["pose"].parameters():
            param.requires_grad = False

    if self.opt.load_weights_folder is not None:
        self.load_model()

    if self.refine:
        # Keep a frozen copy of the (loaded) depth decoder as reference.
        self.models["depth_ref"] = copy.deepcopy(self.models["depth"])
        self.models["depth_ref"].to(self.device)
        #self.parameters_to_train_refine += list(self.models["depth_ref"].parameters())
        for param in self.models["depth"].parameters():
            param.requires_grad = False
        for param in self.models["depth_ref"].parameters():
            param.requires_grad = False

    # In refinement mode only the refinement parameters are optimized.
    if self.refine:
        parameters_to_train = self.parameters_to_train_refine
    else:
        parameters_to_train = self.parameters_to_train
    self.model_optimizer = optim.Adam(parameters_to_train,
                                      self.opt.learning_rate)
    self.model_lr_scheduler = optim.lr_scheduler.StepLR(
        self.model_optimizer, self.opt.scheduler_step_size, 0.1)

    if self.gan or self.gan2:
        self.D_optimizer = optim.Adam(self.parameters_D, 1e-4)
        self.model_lr_scheduler_D = optim.lr_scheduler.StepLR(
            self.D_optimizer, self.opt.scheduler_step_size, 0.1)
        if self.gan:
            self.pix2pix = networks.pix2pix_loss_iter(
                self.model_optimizer,
                self.D_optimizer,
                self.models["netD"],
                self.opt,
                self.crop_h,
                self.crop_w,
                mode=self.crop_mode,
            )
        else:
            self.pix2pix = networks.pix2pix_loss_iter2(
                self.model_optimizer,
                self.D_optimizer,
                self.models["netD"],
                self.opt,
                self.crop_h,
                self.crop_w,
                mode=self.crop_mode,
            )

    print("Training model named:\n ", self.opt.model_name)
    print("Models and tensorboard events files are saved to:\n ",
          self.opt.log_dir)
    print("Training is using:\n ", self.device)

    # data
    datasets_dict = {
        "kitti": datasets.KITTIRAWDataset,
        "kitti_odom": datasets.KITTIOdomDataset,
        "kitti_depth": datasets.KITTIDepthDataset
    }
    self.dataset = datasets_dict[self.opt.dataset]
    fpath = os.path.join(os.path.dirname(__file__), "splits",
                         self.opt.split, "{}_files_p.txt")
    train_filenames = readlines(fpath.format("train"))
    val_filenames = readlines(fpath.format("val"))
    img_ext = '.png' if self.opt.png else '.jpg'
    num_train_samples = len(train_filenames)
    self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs
    # NOTE(review): crop_h/crop_w are only defined when refine is on; this
    # path appears to assume refine=True — confirm for non-refine runs.
    train_dataset = self.dataset(self.opt.data_path,
                                 train_filenames,
                                 self.opt.height,
                                 self.opt.width,
                                 self.opt.frame_ids,
                                 4,
                                 is_train=True,
                                 img_ext=img_ext,
                                 refine=self.refine,
                                 crop_mode=self.crop_mode,
                                 crop_h=self.crop_h,
                                 crop_w=self.crop_w)
    self.train_loader = DataLoader(train_dataset,
                                   self.opt.batch_size,
                                   True,
                                   num_workers=self.opt.num_workers,
                                   pin_memory=True,
                                   drop_last=True)
    val_dataset = self.dataset(self.opt.data_path,
                               val_filenames,
                               self.opt.height,
                               self.opt.width,
                               self.opt.frame_ids,
                               4,
                               is_train=False,
                               img_ext=img_ext,
                               refine=self.refine,
                               crop_mode=self.crop_mode,
                               crop_h=self.crop_h,
                               crop_w=self.crop_w)
    self.val_loader = DataLoader(val_dataset,
                                 self.opt.batch_size,
                                 True,
                                 num_workers=self.opt.num_workers,
                                 pin_memory=True,
                                 drop_last=True)
    self.val_iter = iter(self.val_loader)

    self.writers = {}
    for mode in ["train", "val"]:
        self.writers[mode] = SummaryWriter(
            os.path.join(self.log_path, mode))

    if not self.opt.no_ssim:
        self.ssim = SSIM()
        self.ssim.to(self.device)

    # Per-refinement-stage backprojection / projection helpers.
    self.backproject_depth = {}
    self.project_3d = {}
    for scale in self.refine_stage:
        h = self.crop_h[scale]
        w = self.crop_w[scale]
        self.backproject_depth[scale] = BackprojectDepth(
            self.opt.batch_size, h, w)
        self.backproject_depth[scale].to(self.device)
        self.project_3d[scale] = Project3D(self.opt.batch_size, h, w)
        self.project_3d[scale].to(self.device)

    self.depth_metric_names = [
        "de/abs_rel", "de/sq_rel", "de/rms", "de/log_rms", "da/a1",
        "da/a2", "da/a3"
    ]
    print("Using split:\n ", self.opt.split)
    print(
        "There are {:d} training items and {:d} validation items\n".format(
            len(train_dataset), len(val_dataset)))
    self.save_opts()
def train(epochs, batch_size, input_dir, model_save_dir):
    """Train the SRGAN: alternate discriminator updates on real/fake HR
    batches with GAN updates that push the generator toward VGG + adversarial
    loss, saving weights every 50 epochs.

    :param epochs: total number of training epochs
    :param batch_size: mini-batch size
    :param input_dir: directory containing the training images
    :param model_save_dir: directory prefix for saved .h5 weight files
    """
    # Make an instance of the VGG class (provides the perceptual loss)
    vgg_model = VGG_MODEL(image_shape)

    # Get High-Resolution(HR) [148,148,3] in this case and corresponding Low-Resolution(LR) images
    x_train_lr, x_train_hr = utils.load_training_data(input_dir,
                                                      [148, 148, 3])

    # Based on the batch size, get the total number of batches
    # (floor division replaces the original int(a / b)).
    batch_count = x_train_hr.shape[0] // batch_size

    # Get the downscaled image shape based on the downscale factor
    image_shape_downscaled = utils.get_downscaled_shape(
        image_shape, downscale_factor)

    # Generator consumes LR images; discriminator judges HR images.
    generator = networks.Generator(input_shape=image_shape_downscaled)
    discriminator = networks.Discriminator(image_shape)

    # Get the optimizer to tweak parameters based on loss
    optimizer = vgg_model.get_optimizer()

    # Compile generator, discriminator and the combined GAN (the GAN trains
    # the generator while the discriminator is held fixed).
    generator.compile(loss=vgg_model.vgg_loss, optimizer=optimizer)
    discriminator.compile(loss="binary_crossentropy", optimizer=optimizer)
    gan = networks.GAN_Network(generator, discriminator,
                               image_shape_downscaled, optimizer,
                               vgg_model.vgg_loss)

    # Run training for the number of epochs defined
    for e in range(1, epochs + 1):
        print('-' * 15, 'Epoch %d' % e, '-' * 15)
        for _ in tqdm(range(batch_count)):
            # Get the next batch of LR and HR images
            image_batch_lr, image_batch_hr = utils.get_random_batch(
                x_train_lr, x_train_hr, x_train_hr.shape[0], batch_size)
            generated_images_sr = generator.predict(image_batch_lr)

            # One-sided label smoothing: reals in (0.8, 1.0], fakes in [0, 0.2).
            real_data_Y = np.ones(
                batch_size) - np.random.random_sample(batch_size) * 0.2
            fake_data_Y = np.random.random_sample(batch_size) * 0.2

            # NOTE(review): toggling `trainable` after compile() only takes
            # effect on recompile in Keras — confirm this matches the
            # intended freeze behavior for this Keras version.
            discriminator.trainable = True
            d_loss_real = discriminator.train_on_batch(
                image_batch_hr, real_data_Y)
            d_loss_fake = discriminator.train_on_batch(
                generated_images_sr, fake_data_Y)
            discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

            # Fresh random batch for the generator (GAN) update.
            rand_nums = np.random.randint(0,
                                          x_train_hr.shape[0],
                                          size=batch_size)
            image_batch_hr = x_train_hr[rand_nums]
            image_batch_lr = x_train_lr[rand_nums]
            gan_Y = np.ones(
                batch_size) - np.random.random_sample(batch_size) * 0.2
            discriminator.trainable = False
            gan_loss = gan.train_on_batch(image_batch_lr,
                                          [image_batch_hr, gan_Y])

            print("discriminator_loss : %f" % discriminator_loss)
            print("gan_loss :", gan_loss)
            # (removed dead `gan_loss = str(gan_loss)` and leftover debug
            # prints of array shapes)

        if e % 50 == 0:
            # Periodic checkpoints of both networks.
            generator.save_weights(model_save_dir + 'gen_model%d.h5' % e)
            discriminator.save_weights(model_save_dir + 'dis_model%d.h5' % e)

    networks.save_model(gan)