def train():
    """Train the CycleSR model, checkpointing to ./checkpoints/<timestamp>.

    If ``--load_model`` is given, training resumes from that checkpoint
    directory and the global step is recovered from the checkpoint filename.
    Runs until the queue coordinator stops (or Ctrl-C), saving the model
    every 10000 steps and once more on shutdown.
    """
    if FLAGS.load_model is not None:
        # BUG FIX: the original used str.lstrip("checkpoints/"), which strips
        # any of the characters c,h,e,k,p,o,i,n,t,s,/ from the left — not the
        # literal prefix — and can mangle model names.  Strip the exact
        # prefix instead.
        load_model = FLAGS.load_model
        if load_model.startswith("checkpoints/"):
            load_model = load_model[len("checkpoints/"):]
        checkpoints_dir = "checkpoints/" + load_model
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
    try:
        os.makedirs(checkpoints_dir)
    except os.error:
        # directory already exists (e.g. when resuming)
        pass

    graph = tf.Graph()
    with graph.as_default():  # build the whole model under its own graph
        cycle_sr = CycleSR(input_ture_LR_x=FLAGS.X_LR_INPUT,
                           input_ture_HR_y=FLAGS.Y_HR_INPUT,
                           dis_true_HR_x=FLAGS.X_HR_DIS,
                           dis_true_LR_y=FLAGS.Y_LR_DIS,
                           batch_size=FLAGS.batch_size,
                           image_size=FLAGS.image_size,
                           use_lsgan=FLAGS.use_lsgan,
                           norm=FLAGS.norm,
                           lambda1=FLAGS.lambda1,
                           lambda2=FLAGS.lambda2,
                           learning_rate=FLAGS.learning_rate,
                           beta1=FLAGS.beta1,
                           ngf=FLAGS.ngf)
        G_loss, D_Y_loss, F_loss, D_X_loss, dis_fake_HR_x, dis_fake_LR_y = (
            cycle_sr.model())
        optimizers = cycle_sr.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # checkpoint paths look like ".../model.ckpt-<step>.meta"
            step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            dis_fake_HR_x_pool = ImagePool(FLAGS.pool_size)
            dis_fake_LR_y_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # BUG FIX: the original rebound the graph tensors
                # `dis_fake_HR_x`/`dis_fake_LR_y` to the numpy arrays that
                # sess.run returned, so the second loop iteration fetched
                # numpy arrays instead of tensors and crashed.  Keep the
                # fetched values in separate names.
                fake_HR_x_val, fake_LR_y_val = sess.run(
                    [dis_fake_HR_x, dis_fake_LR_y])

                # train, feeding the discriminators replayed fakes from the
                # history pools
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                         summary_op],
                        feed_dict={
                            cycle_sr.dis_fake_HR_x:
                                dis_fake_HR_x_pool.query(fake_HR_x_val),
                            cycle_sr.dis_fake_LR_y:
                                dis_fake_LR_y_pool.query(fake_LR_y_val)
                        }))
                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info(' G_loss : {}'.format(G_loss_val))
                    logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info(' F_loss : {}'.format(F_loss_val))
                    logging.info(' D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            # always persist the latest weights before shutting down
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
def train():
    """Train CycleGAN, checkpointing to ./checkpoints/<timestamp>.

    The critic/discriminator optimizer ("C") is stepped 5 times per
    generator step, and the learning rate is halved every
    ``FLAGS.decay_step`` steps after a 2x warm period.  Resumes from
    ``--load_model`` when given.
    """
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
    try:
        os.makedirs(checkpoints_dir)
    except os.error:
        # directory already exists (e.g. when resuming)
        pass

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            X_train_file=FLAGS.X,
            Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size_w=FLAGS.image_size_w,
            image_size_h=FLAGS.image_size_h,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            # BUG FIX: lambda2 was passed FLAGS.lambda1, silently tying the
            # two cycle-consistency weights together.
            lambda2=FLAGS.lambda2,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf,
        )
        G_loss, C_loss, fake_y, fake_x = cycle_gan.model()
        G_optimizer, C_optimizer = cycle_gan.optimize(G_loss, C_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # checkpoint paths look like ".../model.ckpt-<step>.meta"
            step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # step-decayed LR: halves every decay_step steps after a
                # 2*decay_step warm period
                adjusted_lr = (FLAGS.learning_rate *
                               0.5 ** max(0, (step / FLAGS.decay_step) - 2))
                feed_ = {cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                         cycle_gan.fake_x: fake_X_pool.query(fake_x_val),
                         cycle_gan.learning_rate: adjusted_lr}

                # update D 5 times before update G
                for i in range(5):
                    _ = sess.run(C_optimizer, feed_dict=feed_)
                _ = sess.run(G_optimizer, feed_dict=feed_)

                G_loss_val, C_loss_val, summary = (
                    sess.run([G_loss, C_loss, summary_op], feed_dict=feed_))
                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info(' G_loss : {}'.format(G_loss_val))
                    logging.info(' C_loss : {}'.format(C_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            # always persist the latest weights before shutting down
            save_path = saver.save(sess, checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
class SimpleGAN(Trainer):
    """LSGAN-style image-to-image GAN trainer.

    Pairs a ``UNetGenerator`` against an ``NLayerDiscriminator``, using MSE
    losses against real/fake target labels (LSGAN), an ``ImagePool`` replay
    buffer for discriminator training, and a ``Visualizer`` for live loss
    plots and sample images.  Hooks (``training_step`` etc.) follow the
    project's ``Trainer`` protocol — presumably Lightning-like; confirm
    against ``Trainer``.
    """

    def __init__(self, parsed_args, parsed_groups):
        # Each component is configured from its own pre-parsed argument group.
        super().__init__(**parsed_groups['trainer arguments'])
        self.gen = UNetGenerator(**parsed_groups['generator arguments'])
        self.disc = NLayerDiscriminator(
            **parsed_groups['discriminator arguments'])
        init_weights(self.gen)
        init_weights(self.disc)
        # Scalar targets for the LSGAN losses; expanded to the
        # discriminator's output shape inside _shared_step.
        self.real_label = torch.tensor(1.0)
        self.fake_label = torch.tensor(0.0)
        self.crit = torch.nn.MSELoss()  # LSGAN
        self.image_pool = ImagePool(parsed_args.pool_size,
                                    parsed_args.replay_prob)
        # Index of the validation batch whose images are visualized;
        # re-drawn at random after every validation epoch.
        self.sel_ind = 0
        # Maps [-1, 1]-normalized images back to [0, 255] for display
        # (assumes inputs are normalized to [-1, 1] — TODO confirm).
        self.un_normalize = lambda x: 255. * (1 + x.clamp(min=-1, max=1)) / 2.
        self.parsed_args = parsed_args
        self.n_vis = parsed_args.n_vis  # how many images per visualization
        self.vis = Visualizer()

    def configure_optimizers(self):
        """Build Adam optimizers for D and G plus linear-decay schedulers.

        The LR stays flat until ``frac_decay_start`` of ``n_epochs``, then
        decays linearly to zero at the final epoch.
        """
        opt1 = torch.optim.Adam(self.disc.parameters(),
                                lr=self.parsed_args.lr)
        opt2 = torch.optim.Adam(self.gen.parameters(), lr=self.parsed_args.lr)
        N = self.parsed_args.n_epochs
        N_start = int(N * self.parsed_args.frac_decay_start)
        # 1.0 until N_start, then linearly down to 0.0 at epoch N
        sched_lamb = lambda x: 1.0 - max(0, x - N_start) / (N - N_start)
        sched1 = torch.optim.lr_scheduler.LambdaLR(opt1, lr_lambda=sched_lamb)
        sched2 = torch.optim.lr_scheduler.LambdaLR(opt2, lr_lambda=sched_lamb)
        return opt1, opt2, sched1, sched2

    def on_fit_start(self):
        # Bring up the visualization backend for the run.
        self.vis.start()

    def on_fit_end(self):
        self.vis.stop()

    def _shared_step(self, batch, save_img, is_train):
        """One GAN step shared by train and validation.

        Computes the discriminator loss (real vs. pooled fake) and the
        generator loss (fake judged as real), accumulating optimizer steps
        through ``Result.step`` only when training.  Returns the populated
        ``Result``.
        """
        res = Result()
        X, Y_real = batch
        Y_fake = self.gen(X)
        if is_train:
            # Train D on a replayed fake from the history pool; detach so D's
            # loss does not backprop into the generator.
            Y_pool = self.image_pool.query(Y_fake.detach())
        else:
            Y_pool = Y_fake
        res.recon_error = self.crit(Y_real, Y_fake)
        real_predict = self.disc(Y_real)
        fake_predict = self.disc(Y_pool)
        real_label = self.real_label.expand_as(real_predict)
        fake_label = self.fake_label.expand_as(fake_predict)
        # LSGAN discriminator loss, averaged over the real and fake halves.
        disc_loss = 0.5 * (self.crit(real_predict, real_label) + \
                           self.crit(fake_predict, fake_label))
        if is_train:
            # NOTE(review): Result.step presumably performs backward +
            # optimizer step for the corresponding optimizer — confirm
            # against the Result/Trainer implementation.
            res.step(disc_loss)
        res.disc_loss = disc_loss
        # Generator loss: D should judge the (non-detached) fake as real.
        gen_predict = self.disc(Y_fake)
        gen_loss = self.crit(gen_predict, real_label)
        if is_train:
            res.step(gen_loss)
        res.gen_loss = gen_loss
        if save_img:
            # [input, generated, ground truth], de-normalized for display.
            res.img = [
                self.un_normalize(X[:self.n_vis]),
                self.un_normalize(Y_fake[:self.n_vis]),
                self.un_normalize(Y_real[:self.n_vis])
            ]
        return res

    def training_step(self, batch, batch_idx):
        # Save sample images only for the first batch of each epoch.
        res = self._shared_step(batch,
                                save_img=(batch_idx == 0),
                                is_train=True)
        return res

    def validation_step(self, batch, batch_idx):
        # Save images for the randomly selected validation batch.
        res = self._shared_step(batch,
                                save_img=(batch_idx == self.sel_ind),
                                is_train=False)
        return res

    def _shared_end(self, result_outputs, is_train):
        """Plot epoch-mean G/D losses and show the collected sample images."""
        phase = 'Train' if is_train else 'Valid'
        self.vis.plot('Gen. Loss', phase + ' Loss', self.current_epoch,
                      torch.mean(torch.stack(result_outputs.gen_loss)))
        self.vis.plot('Disc. Loss', phase + ' Loss', self.current_epoch,
                      torch.mean(torch.stack(result_outputs.img[0] and
                                 result_outputs.disc_loss)
                      if False else torch.stack(result_outputs.disc_loss)))
        # Collate the saved [input, fake, real] triple into one image grid.
        # NOTE(review): concatenation axes assume NCHW tensors — confirm.
        collated_imgs = torch.cat([*torch.cat(result_outputs.img[0], dim=3)],
                                  dim=1)
        self.vis.show_image(phase + ' Images', collated_imgs)

    def training_epoch_end(self, training_outputs):
        self._shared_end(training_outputs, is_train=True)

    def validation_epoch_end(self, validation_outputs):
        self._shared_end(validation_outputs, is_train=False)
        # Pick a fresh validation batch to visualize next epoch.
        self.sel_ind = random.randint(0, len(self.validation_loader) - 1)
        # Mean reconstruction error is the model-selection metric.
        return torch.mean(torch.stack(validation_outputs.recon_error))
def main():
    """Train a CycleGAN (G_A: A->B, G_B: B->A with discriminators D_A/D_B)
    on unaligned image pairs.

    Uses the standard CycleGAN objective: adversarial + cycle-consistency +
    identity losses.  Epoch-mean losses are written to TensorBoard and the
    four networks are checkpointed every ``save_freq`` epochs.
    """
    args = args_initialize()
    save_freq = args.save_freq
    epochs = args.num_epoch
    cuda = args.cuda

    train_dataset = UnalignedDataset(is_train=True)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0
    )

    net_G_A = ResNetGenerator(input_nc=3, output_nc=3)
    net_G_B = ResNetGenerator(input_nc=3, output_nc=3)
    net_D_A = Discriminator()
    net_D_B = Discriminator()
    if args.cuda:
        net_G_A = net_G_A.cuda()
        net_G_B = net_G_B.cuda()
        net_D_A = net_D_A.cuda()
        net_D_B = net_D_B.cuda()

    # History buffers of generated images, replayed when training the
    # discriminators (stabilizes GAN training).
    fake_A_pool = ImagePool(50)
    fake_B_pool = ImagePool(50)

    criterionGAN = GANLoss(cuda=cuda)
    criterionCycle = torch.nn.L1Loss()
    criterionIdt = torch.nn.L1Loss()

    # One optimizer over both generators; separate ones per discriminator.
    optimizer_G = torch.optim.Adam(
        itertools.chain(net_G_A.parameters(), net_G_B.parameters()),
        lr=args.lr, betas=(args.beta1, 0.999)
    )
    optimizer_D_A = torch.optim.Adam(net_D_A.parameters(), lr=args.lr,
                                     betas=(args.beta1, 0.999))
    optimizer_D_B = torch.optim.Adam(net_D_B.parameters(), lr=args.lr,
                                     betas=(args.beta1, 0.999))

    log_dir = './logs'
    checkpoints_dir = './checkpoints'
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoints_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

    for epoch in range(epochs):
        # Running sums of the 8 tracked losses, averaged per epoch below.
        running_loss = np.zeros((8))
        for batch_idx, data in enumerate(train_loader):
            input_A = data['A']
            input_B = data['B']
            if cuda:
                input_A = input_A.cuda()
                input_B = input_B.cuda()
            real_A = Variable(input_A)
            real_B = Variable(input_B)

            """ Backward net_G """
            optimizer_G.zero_grad()

            lambda_idt = 0.5
            lambda_A = 10.0
            lambda_B = 10.0

            # Identity loss: feed each generator an image already in its
            # output domain; ideally it should return it unchanged.
            idt_B = net_G_A(real_B)
            loss_idt_A = criterionIdt(idt_B, real_B) * lambda_B * lambda_idt
            idt_A = net_G_B(real_A)
            loss_idt_B = criterionIdt(idt_A, real_A) * lambda_A * lambda_idt

            # GAN loss = D_A(G_A(A))
            # G_A wants D_A to judge its generated fakes as real (True).
            fake_B = net_G_A(real_A)
            pred_fake = net_D_A(fake_B)
            loss_G_A = criterionGAN(pred_fake, True)

            fake_A = net_G_B(real_B)
            pred_fake = net_D_B(fake_A)
            loss_G_B = criterionGAN(pred_fake, True)

            # Cycle-consistency losses: A -> B -> A and B -> A -> B.
            rec_A = net_G_B(fake_B)
            loss_cycle_A = criterionCycle(rec_A, real_A) * lambda_A
            rec_B = net_G_A(fake_A)
            loss_cycle_B = criterionCycle(rec_B, real_B) * lambda_B

            loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
            loss_G.backward()
            optimizer_G.step()

            """ update D_A """
            optimizer_D_A.zero_grad()
            # Replay a (possibly historical) fake from the pool; detach so
            # the discriminator loss does not reach the generator.
            fake_B = fake_B_pool.query(fake_B.data)
            pred_real = net_D_A(real_B)
            loss_D_real = criterionGAN(pred_real, True)
            pred_fake = net_D_A(fake_B.detach())
            loss_D_fake = criterionGAN(pred_fake, False)
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()

            """ update D_B """
            optimizer_D_B.zero_grad()
            fake_A = fake_A_pool.query(fake_A.data)
            pred_real = net_D_B(real_A)
            loss_D_real = criterionGAN(pred_real, True)
            pred_fake = net_D_B(fake_A.detach())
            loss_D_fake = criterionGAN(pred_fake, False)
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()

            # Order matches the writer.add_scalar calls below.
            ret_loss = np.array([
                loss_G_A.data.detach().cpu().numpy(),
                loss_D_A.data.detach().cpu().numpy(),
                loss_G_B.data.detach().cpu().numpy(),
                loss_D_B.data.detach().cpu().numpy(),
                loss_cycle_A.data.detach().cpu().numpy(),
                loss_cycle_B.data.detach().cpu().numpy(),
                loss_idt_A.data.detach().cpu().numpy(),
                loss_idt_B.data.detach().cpu().numpy()
            ])
            running_loss += ret_loss

        """ Save checkpoints """
        if (epoch + 1) % save_freq == 0:
            save_network(net_G_A, 'G_A', str(epoch + 1))
            save_network(net_D_A, 'D_A', str(epoch + 1))
            save_network(net_G_B, 'G_B', str(epoch + 1))
            save_network(net_D_B, 'D_B', str(epoch + 1))

        # Epoch means of every tracked loss.
        running_loss /= len(train_loader)
        losses = running_loss
        print('epoch %d, losses: %s' % (epoch + 1, running_loss))
        writer.add_scalar('loss_G_A', losses[0], epoch)
        writer.add_scalar('loss_D_A', losses[1], epoch)
        writer.add_scalar('loss_G_B', losses[2], epoch)
        writer.add_scalar('loss_D_B', losses[3], epoch)
        writer.add_scalar('loss_cycle_A', losses[4], epoch)
        writer.add_scalar('loss_cycle_B', losses[5], epoch)
        writer.add_scalar('loss_idt_A', losses[6], epoch)
        writer.add_scalar('loss_idt_B', losses[7], epoch)
def train(args):
    """Train a GAN-based restoration network.

    Generator ``netG`` vs. patch discriminator ``netD``, optimized with an
    L1 reconstruction loss, a VGG-feature perceptual loss (weighted by
    ``args.p_factor``) and a BCE adversarial loss (weighted by
    ``args.g_factor``).  Checkpoints, scalars and logs go through
    ``SaveData``.
    """
    print(args)

    def _set_requires_grad(net, flag):
        # BUG FIX: the original called ``netD.requires_grad(True)``, but
        # nn.Module exposes no callable ``requires_grad`` (only
        # ``requires_grad_``), which raises AttributeError unless the
        # project's Discriminator defines a custom method.  Toggle the flag
        # on the parameters directly — equivalent and unambiguous.
        for p in net.parameters():
            p.requires_grad = flag

    # networks
    netG = Generator()
    netG = netG.cuda()
    netD = Discriminator()
    netD = netD.cuda()

    # losses: L1 for reconstruction, L2 in VGG feature space, BCE for GAN
    l1_loss = nn.L1Loss().cuda()
    l2_loss = nn.MSELoss().cuda()
    bce_loss = nn.BCELoss().cuda()

    # optimizers
    optimizerG = optim.Adam(netG.parameters(), lr=args.glr)
    optimizerD = optim.Adam(netD.parameters(), lr=args.dlr)

    # step LR decay
    schedulerG = lr_scheduler.StepLR(optimizerG, args.lr_step_size,
                                     args.lr_gamma)
    schedulerD = lr_scheduler.StepLR(optimizerD, args.lr_step_size,
                                     args.lr_gamma)

    # utility for saving models, parameters and logs
    save = SaveData(args.save_dir, args.exp, True)
    save.save_params(args)

    dataset = MyDataset(args.data_dir, is_train=True)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=int(args.n_threads))

    # PatchGAN targets: ones for real patches, zeros for fake ones
    real_label = Variable(
        torch.ones([1, 1, args.patch_gan, args.patch_gan],
                   dtype=torch.float)).cuda()
    fake_label = Variable(
        torch.zeros([1, 1, args.patch_gan, args.patch_gan],
                    dtype=torch.float)).cuda()

    image_pool = ImagePool(args.pool_size)
    vgg = Vgg16(requires_grad=False)  # frozen feature extractor
    vgg.cuda()

    for epoch in range(args.epochs):
        print("* Epoch {}/{}".format(epoch + 1, args.epochs))
        # NOTE(review): stepping schedulers before the optimizers is the
        # pre-1.1 PyTorch convention; kept to preserve the LR schedule.
        schedulerG.step()
        schedulerD.step()

        d_total_real_loss = 0
        d_total_fake_loss = 0
        d_total_loss = 0
        g_total_res_loss = 0
        g_total_per_loss = 0
        g_total_gan_loss = 0
        g_total_loss = 0

        netG.train()
        netD.train()

        for batch, images in tqdm(enumerate(dataloader)):
            input_image, target_image = images
            input_image = Variable(input_image.cuda())
            target_image = Variable(target_image.cuda())

            output_image = netG(input_image)

            # ---- Update D ----
            _set_requires_grad(netD, True)
            netD.zero_grad()
            # real image
            real_output = netD(target_image)
            d_real_loss = bce_loss(real_output, real_label)
            d_real_loss.backward()
            d_real_loss = d_real_loss.data.cpu().numpy()
            d_total_real_loss += d_real_loss
            # fake image, replayed through the history pool
            fake_image = output_image.detach()
            fake_image = Variable(image_pool.query(fake_image.data))
            fake_output = netD(fake_image)
            d_fake_loss = bce_loss(fake_output, fake_label)
            d_fake_loss.backward()
            d_fake_loss = d_fake_loss.data.cpu().numpy()
            d_total_fake_loss += d_fake_loss
            # combined discriminator loss (for bookkeeping only)
            d_total_loss += d_real_loss + d_fake_loss
            optimizerD.step()

            # ---- Update G (D frozen) ----
            _set_requires_grad(netD, False)
            netG.zero_grad()
            # reconstruction loss
            g_res_loss = l1_loss(output_image, target_image)
            g_res_loss.backward(retain_graph=True)
            g_res_loss = g_res_loss.data.cpu().numpy()
            g_total_res_loss += g_res_loss
            # perceptual loss in VGG feature space
            g_per_loss = args.p_factor * l2_loss(vgg(output_image),
                                                 vgg(target_image))
            g_per_loss.backward(retain_graph=True)
            g_per_loss = g_per_loss.data.cpu().numpy()
            g_total_per_loss += g_per_loss
            # adversarial loss
            output = netD(output_image)
            g_gan_loss = args.g_factor * bce_loss(output, real_label)
            g_gan_loss.backward()
            g_gan_loss = g_gan_loss.data.cpu().numpy()
            g_total_gan_loss += g_gan_loss
            # combined generator loss (for bookkeeping only)
            g_total_loss += g_res_loss + g_per_loss + g_gan_loss
            optimizerG.step()

        # epoch means (``batch`` holds the last batch index)
        d_total_real_loss = d_total_real_loss / (batch + 1)
        d_total_fake_loss = d_total_fake_loss / (batch + 1)
        d_total_loss = d_total_loss / (batch + 1)
        save.add_scalar('D/real', d_total_real_loss, epoch)
        save.add_scalar('D/fake', d_total_fake_loss, epoch)
        save.add_scalar('D/total', d_total_loss, epoch)

        g_total_res_loss = g_total_res_loss / (batch + 1)
        g_total_per_loss = g_total_per_loss / (batch + 1)
        g_total_gan_loss = g_total_gan_loss / (batch + 1)
        g_total_loss = g_total_loss / (batch + 1)
        save.add_scalar('G/res', g_total_res_loss, epoch)
        save.add_scalar('G/per', g_total_per_loss, epoch)
        save.add_scalar('G/gan', g_total_gan_loss, epoch)
        save.add_scalar('G/total', g_total_loss, epoch)

        if epoch % args.period == 0:
            log = "Train d_loss: {:.5f} \t g_loss: {:.5f}".format(
                d_total_loss, g_total_loss)
            print(log)
            save.save_log(log)
            save.save_model(netG, epoch)
class CycleGANModel():
    """CycleGAN with an auxiliary learned forward-dynamics model ``F_A``.

    Naming follows the reference CycleGAN code rather than the paper:
    G_A (paper G), G_B (paper F), D_A (paper D_Y), D_B (paper D_X).
    Domain-B samples are (state_t0, action, state_t1) triples; when
    ``opt.dynamic`` is set, a third cycle pushes the translated t0 state
    through ``F_A`` with the action and closes the cycle at t1.
    """

    def __init__(self, opt):
        self.opt = opt
        self.dynamic = opt.dynamic
        self.isTrain = opt.istrain
        self.Tensor = torch.cuda.FloatTensor
        # load/define networks
        self.netG_A = GModel(opt).cuda()
        self.netG_B = GModel(opt).cuda()
        self.netF_A = Fmodel().cuda()  # forward-dynamics model
        self.dataF = Fdata.get_loader()  # (state0, action, state1) loader
        if self.isTrain:
            self.netD_A = DModel(opt).cuda()
            self.netD_B = DModel(opt).cuda()
        if self.isTrain:
            self.fake_A_pool = ImagePool(pool_size=128)
            self.fake_B_pool = ImagePool(pool_size=128)
            # define loss functions
            self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
            if opt.loss == 'l1':
                self.criterionCycle = torch.nn.L1Loss()
                self.criterionIdt = torch.nn.L1Loss()
            elif opt.loss == 'l2':
                self.criterionCycle = torch.nn.MSELoss()
                self.criterionIdt = torch.nn.MSELoss()
            # Initialize optimizers.  F_A's parameters get lr 0.0 here:
            # it is pre-trained separately in train_forward() and must stay
            # frozen while the generators learn.
            self.optimizer_G = torch.optim.Adam([{
                'params': self.netG_A.parameters(),
                'lr': 1e-3
            }, {
                'params': self.netF_A.parameters(),
                'lr': 0.0
            }, {
                'params': self.netG_B.parameters(),
                'lr': 1e-3
            }])
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters())
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                # NOTE(review): appends the optimizer itself rather than a
                # real LR scheduler — a placeholder until a get_scheduler
                # helper is wired up.
                self.schedulers.append(optimizer)

        print('---------- Networks initialized -------------')
        print('-----------------------------------------------')

    def train_forward(self):
        """Pre-train the forward-dynamics model ``F_A``.

        Fits ``F_A(state0, action) -> state1`` with an L1 loss over
        ``self.dataF`` for 100 epochs.
        """
        optimizer = torch.optim.Adam(self.netF_A.parameters(), lr=1e-3)
        # BUG FIX: the original assigned MSELoss here and immediately
        # overwrote it with L1Loss; the dead assignment has been removed.
        loss_fn = torch.nn.L1Loss()
        for epoch in range(100):
            epoch_loss = 0
            for i, item in enumerate(self.dataF):
                state0 = item[0].float().cuda()
                action = item[1].float().cuda()
                state1 = item[2].float().cuda()
                pred = self.netF_A(state0, action)
                loss = loss_fn(pred, state1)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print('epoch:{} loss:{:.7f}'.format(epoch,
                                                epoch_loss / len(self.dataF)))
        print('forward model has been trained!')

    def set_input(self, input):
        """Store a batch: input[0] is domain A; input[1] is the B-domain
        (state_t0, action, state_t1) triple."""
        self.input_A = input[0]
        self.input_Bt0 = input[1][0]
        self.input_Bt1 = input[1][2]
        self.action = input[1][1]

    def forward(self):
        # Move the stored batch to the GPU as float Variables.
        self.real_A = Variable(self.input_A).float().cuda()
        self.real_Bt0 = Variable(self.input_Bt0).float().cuda()
        self.real_Bt1 = Variable(self.input_Bt1).float().cuda()
        self.action = Variable(self.action).float().cuda()

    def test(self):
        """Run both translation cycles without training (inference only)."""
        self.forward()
        # NOTE(review): volatile=True is the legacy (pre-0.4) no-grad flag;
        # kept for compatibility with the project's PyTorch version.
        real_A = Variable(self.input_A, volatile=True).float().cuda()
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_Bt0, volatile=True).float().cuda()
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data

    def backward_D_basic(self, netD, real, fake):
        """Standard discriminator update: real -> True, fake -> False.

        The fake is detached so gradients stop at the discriminator.
        Returns the combined (halved) loss after backward().
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        # D_A judges domain-B images; fakes are replayed from the pool.
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_Bt0, fake_B)
        self.loss_D_A = loss_D_A.item()

    def backward_D_B(self):
        # D_B judges domain-A images; fakes are replayed from the pool.
        fake_A = self.fake_A_pool.query(self.fake_At0)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        """Generator losses: identity + three adversarial/cycle terms.

        Cycles: A->B->A, B(t0)->A->B(t0), and — only added to the total
        when ``dynamic`` — B(t0)->A, rolled forward by F_A with the action,
        then back to B(t1).
        """
        lambda_idt = 0.5
        lambda_A = 100.0
        lambda_B = 100.0
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_Bt0)
            loss_idt_A = self.criterionIdt(
                idt_A, self.real_Bt0) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B,
                                           self.real_A) * lambda_A * lambda_idt
            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        lambda_G = 1.0
        # --------first cycle-----------#
        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True) * lambda_G
        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
        # ---------second cycle---------#
        # GAN loss D_B(G_B(B))
        fake_At0 = self.netG_B(self.real_Bt0)
        pred_fake = self.netD_B(fake_At0)
        loss_G_Bt0 = self.criterionGAN(pred_fake, True) * lambda_G
        # Backward cycle loss
        rec_Bt0 = self.netG_A(fake_At0)
        loss_cycle_Bt0 = self.criterionCycle(rec_Bt0, self.real_Bt0) * lambda_B
        # ---------third (dynamics) cycle---------#
        # Roll the translated state forward with the action, then cycle back.
        fake_At1 = self.netF_A(fake_At0, self.action)
        pred_fake = self.netD_B(fake_At1)
        loss_G_Bt1 = self.criterionGAN(pred_fake, True) * lambda_G
        rec_Bt1 = self.netG_A(fake_At1)
        loss_cycle_Bt1 = self.criterionCycle(rec_Bt1, self.real_Bt1) * lambda_B

        # combined loss (dynamics terms only when enabled)
        loss_G = loss_idt_A + loss_idt_B
        loss_G = loss_G + loss_G_A + loss_cycle_A
        loss_G = loss_G + loss_G_Bt0 + loss_cycle_Bt0
        if self.dynamic:
            loss_G = loss_G + loss_G_Bt1 + loss_cycle_Bt1
        loss_G.backward()

        # cache outputs/losses for the discriminators and for logging
        self.fake_B = fake_B.data
        self.fake_At0 = fake_At0.data
        self.fake_At1 = fake_At1.data
        self.rec_A = rec_A.data
        self.rec_Bt0 = rec_Bt0.data
        self.rec_Bt1 = rec_Bt1.data
        self.loss_G_A = loss_G_A.item()
        self.loss_G_Bt0 = loss_G_Bt0.item()
        self.loss_G_Bt1 = loss_G_Bt1.item()
        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_Bt0 = loss_cycle_Bt0.item()
        self.loss_cycle_Bt1 = loss_cycle_Bt1.item()

    def optimize_parameters(self):
        """One training step: update generators, then both discriminators."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return the latest scalar losses keyed for logging."""
        ret_errors = OrderedDict([('D_A', self.loss_D_A),
                                  ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B),
                                  ('G_B', self.loss_G_Bt0),
                                  ('Cyc_B', self.loss_cycle_Bt0)])
        ret_errors['idt_A'] = self.loss_idt_A
        ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, path):
        save_filename = 'model_{}.pth'.format(network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    def save(self, path):
        self.save_network(self.netG_A, 'G_A', path)
        self.save_network(self.netD_A, 'D_A', path)
        self.save_network(self.netG_B, 'G_B', path)
        self.save_network(self.netD_B, 'D_B', path)

    def load_network(self, network, network_label, path):
        weight_filename = 'model_{}.pth'.format(network_label)
        weight_path = os.path.join(path, weight_filename)
        network.load_state_dict(torch.load(weight_path))

    def load(self, path):
        # Only the generators are restored (sufficient for inference).
        self.load_network(self.netG_A, 'G_A', path)
        self.load_network(self.netG_B, 'G_B', path)

    def plot_points(self, item, label):
        # Scatter a (N, 2) point tensor — states here are 2-D points.
        item = item.cpu().data.numpy()
        plt.scatter(item[:, 0], item[:, 1], label=label)

    def visual(self, path):
        """Save a scatter plot of real A, its translation, and the
        reconstruction, with lines linking each real->fake pair."""
        plt.rcParams['figure.figsize'] = (8.0, 3.0)
        plt.xlim(-4, 4)
        plt.ylim(-1.5, 1.5)
        self.plot_points(self.real_A, 'realA')
        self.plot_points(self.fake_B, 'fake_B')
        self.plot_points(self.rec_A, 'rec_A')
        for p1, p2 in zip(self.real_A, self.fake_B):
            p1, p2 = p1.cpu().data.numpy(), p2.cpu().data.numpy()
            plt.plot([p1[0], p2[0]], [p1[1], p2[1]])
        plt.legend()
        plt.savefig(path)
        plt.cla()
        plt.clf()
class cycleGan():
    """CycleGAN wrapper: ResNet generators ``G_XtoY``/``G_YtoX`` vs.
    PatchGAN discriminators ``D_X``/``D_Y``.

    Trained with least-squares GAN losses (module-level ``real_mse_loss`` /
    ``fake_mse_loss``) plus L1 cycle-consistency weighted by
    ``opt.lambda_A`` / ``opt.lambda_B``.  Identity loss is not used.
    """

    def __init__(self, opt):
        super(cycleGan, self).__init__()
        self.opt = opt
        self.G_XtoY = RensetGenerator(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.norm, not opt.no_dropout,
                                      opt.n_blocks,
                                      opt.padding_type).to(device)
        self.G_YtoX = RensetGenerator(opt.output_nc, opt.input_nc, opt.ngf,
                                      opt.norm, not opt.no_dropout,
                                      opt.n_blocks,
                                      opt.padding_type).to(device)
        self.D_X = PatchDiscriminator(opt.output_nc, opt.ndf, opt.n_layers_D,
                                      opt.norm).to(device)
        self.D_Y = PatchDiscriminator(opt.input_nc, opt.ndf, opt.n_layers_D,
                                      opt.norm).to(device)

        print(self.G_XtoY)
        print("Parameters: ", len(list(self.G_XtoY.parameters())))
        print(self.G_YtoX)
        print("Parameters: ", len(list(self.G_YtoX.parameters())))
        print(self.D_X)
        print("Parameters: ", len(list(self.D_X.parameters())))
        print(self.D_Y)
        print("Parameters: ", len(list(self.D_Y.parameters())))

        # history pools of generated images for discriminator training
        self.fake_X_pool = ImagePool(opt.pool_size)
        self.fake_Y_pool = ImagePool(opt.pool_size)

        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()  # for identity loss (currently unused)

        # one optimizer over both generators; separate ones per discriminator
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.G_XtoY.parameters(), self.G_YtoX.parameters()),
                                            lr=opt.lr,
                                            betas=[opt.beta1, 0.999])
        self.optimizer_DX = torch.optim.Adam(self.D_X.parameters(),
                                             lr=opt.lr,
                                             betas=[opt.beta1, 0.999])
        self.optimizer_DY = torch.optim.Adam(self.D_Y.parameters(),
                                             lr=opt.lr,
                                             betas=[opt.beta1, 0.999])

    def get_input(self, inputX, inputY):
        """Store the current real batches from domains X and Y."""
        self.inputX = inputX
        self.inputY = inputY

    def forward(self):
        """Run both cycles: Y->X->Y and X->Y->X."""
        self.fake_X = self.G_YtoX(self.inputY).to(device)
        self.rec_Y = self.G_XtoY(self.fake_X).to(device)
        self.fake_Y = self.G_XtoY(self.inputX).to(device)
        self.rec_X = self.G_YtoX(self.fake_Y).to(device)

    def backward_D_basic(self, netD, real, fake):
        """Standard LSGAN discriminator update: real -> 1, fake -> 0.

        The fake is detached so no gradient reaches the generator.
        """
        pred_real = netD(real)
        loss_D_real = real_mse_loss(pred_real)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = fake_mse_loss(pred_fake)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_X(self):
        """Calculate GAN loss for discriminator D_X, optionally replaying
        fakes from the history pool."""
        if self.opt.pool:
            fake_X = self.fake_X_pool.query(self.fake_X)
            self.loss_D_X = self.backward_D_basic(self.D_X, self.inputX,
                                                  fake_X)
        else:
            self.loss_D_X = self.backward_D_basic(self.D_X, self.inputX,
                                                  self.fake_X)

    def backward_D_Y(self):
        """Calculate GAN loss for discriminator D_Y, optionally replaying
        fakes from the history pool."""
        if self.opt.pool:
            fake_Y = self.fake_Y_pool.query(self.fake_Y)
            self.loss_D_Y = self.backward_D_basic(self.D_Y, self.inputY,
                                                  fake_Y)
        else:
            self.loss_D_Y = self.backward_D_basic(self.D_Y, self.inputY,
                                                  self.fake_Y)

    def backward_G(self):
        """Generator losses: adversarial (both directions) + cycle terms.

        Identity loss is not implemented here.
        """
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # GAN loss D_X(G_YtoX(Y)): fakes should be judged real
        self.loss_G_X = real_mse_loss(self.D_X(self.fake_X))
        # GAN loss D_Y(G_XtoY(X))
        self.loss_G_Y = real_mse_loss(self.D_Y(self.fake_Y))
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_X = self.criterionCycle(self.rec_X,
                                                self.inputX) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_Y = self.criterionCycle(self.rec_Y,
                                                self.inputY) * lambda_B
        # combined loss and calculate gradients
        self.loss_G = self.loss_G_X + self.loss_G_Y + self.loss_cycle_X + self.loss_cycle_Y
        self.loss_G.backward()

    def change_lr(self, new_lr):
        """Rebuild all three optimizers at ``new_lr``.

        NOTE(review): recreating Adam discards its moment estimates; updating
        ``param_group['lr']`` in place would preserve state, but the rebuild
        is kept to match existing behavior.
        """
        self.opt.lr = new_lr
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.G_XtoY.parameters(), self.G_YtoX.parameters()),
                                            lr=self.opt.lr,
                                            betas=[self.opt.beta1, 0.999])
        self.optimizer_DX = torch.optim.Adam(self.D_X.parameters(),
                                             lr=self.opt.lr,
                                             betas=[self.opt.beta1, 0.999])
        self.optimizer_DY = torch.optim.Adam(self.D_Y.parameters(),
                                             lr=self.opt.lr,
                                             betas=[self.opt.beta1, 0.999])

    def set_requires_grad(self, nets, requires_grad=False):
        """Enable/disable gradients for a network or list of networks."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def optimize(self):
        """One training step: update generators once, then discriminators.

        BUG FIX: the original repeated the whole forward/backward_G/step
        sequence twice verbatim (the second time with the discriminators'
        gradients enabled) — an evident copy-paste error that double-stepped
        the generators.  The standard single G update is kept, and the
        discriminators train on the fakes produced by that forward pass.
        """
        # Update generators with discriminators frozen.
        self.set_requires_grad([self.D_X, self.D_Y], False)
        self.forward()
        self.optimizer_G.zero_grad()  # set generator gradients to zero
        self.backward_G()             # calculate gradients for both generators
        self.optimizer_G.step()       # update generator weights
        # Update discriminators on the cached (detached) fakes.
        self.set_requires_grad([self.D_X, self.D_Y], True)
        self.optimizer_DX.zero_grad()
        self.optimizer_DY.zero_grad()
        self.backward_D_X()
        self.backward_D_Y()
        self.optimizer_DX.step()
        self.optimizer_DY.step()
def main():
    """Entry point: stage zipped training images from the PAI bucket, build
    a WGAN-flavored CycleGAN, and run the training loop.

    Checkpoints are written to ``FLAGS.checkpointDir`` every 1000 steps;
    losses are printed and summarized every 25 steps.
    """
    num_epoch = 100000
    pool_size = 20
    batch_size = 1
    oldpath = FLAGS.buckets
    RealPicPath = 'picF'
    AnimaPicPaht = 'picG'  # (sic) variable name kept — referenced throughout
    useCopyfile = True
    if useCopyfile:
        # copy the zipped datasets out of the bucket and unpack into ./temp
        trainfiles = ['picf1.zip', 'picf2.zip', 'picg1.zip']
        print(trainfiles)
        for f in trainfiles:
            fn = utils.pai_copy(f, oldpath)
            utils.Unzip(fn)
        RealPicPath = os.path.join('temp', RealPicPath)
        AnimaPicPaht = os.path.join('temp', AnimaPicPaht)
    print(RealPicPath)
    print(AnimaPicPaht)

    sess = tf.InteractiveSession(
        config=tf.ConfigProto(allow_soft_placement=True))
    # BUG FIX: the original was missing the comma after decay_steps, which
    # made this constructor call a SyntaxError.
    cycle_gan = CycleGAN(
        X_train_file=AnimaPicPaht,
        Y_train_file=RealPicPath,
        batch_size=batch_size,
        image_size=(270, 480),
        lossfunc='wgan',
        norm='instance',
        learning_rate=2e-4,
        start_decay_step=10000,
        decay_steps=100000,
        optimizer='RMSProp'
    )
    Ga2b_loss, Da2b_loss, Gb2a_loss, Db2a_loss, fake_a, fake_b, real_a, real_b = (
        cycle_gan.build())
    optimizers = cycle_gan.optimize(Ga2b_loss, Da2b_loss, Gb2a_loss,
                                    Db2a_loss)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.checkpointDir)
    saver = tf.train.Saver(max_to_keep=0)  # keep every checkpoint

    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    fake_a_pool = ImagePool(pool_size)
    fake_b_pool = ImagePool(pool_size)

    print('start train')
    start_time = time.time()
    for step in range(1, num_epoch + 1):
        # get previously generated images
        fake_a_val, fake_b_val = sess.run([fake_a, fake_b])

        # train, feeding the discriminators replayed fakes from the pools
        _, Ga2b_loss_val, Da2b_loss_val, Gb2a_loss_val, Db2a_loss_val, real_a_val, real_b_val, summary = (
            sess.run(
                [optimizers, Ga2b_loss, Da2b_loss, Gb2a_loss, Db2a_loss,
                 real_a, real_b, summary_op],
                feed_dict={cycle_gan.fake_a: fake_a_pool.query(fake_a_val),
                           cycle_gan.fake_b: fake_b_pool.query(fake_b_val)}
            )
        )
        # per-step wall time, reset every iteration
        elapsed_time = time.time() - start_time
        start_time = time.time()

        if step % 25 == 0:
            print('Ga2b_loss_val : %s--Da2b_loss_val : %s--Gb2a_loss_val : %s--Db2a_loss_val : %s--'
                  % (Ga2b_loss_val, Da2b_loss_val, Gb2a_loss_val,
                     Db2a_loss_val))
            print('step : %s --elapsed_time : %s' % (step, elapsed_time))
            print('adding summary...')
            train_writer.add_summary(summary, step)
            train_writer.flush()

        if step % 1000 == 0:
            # write_meta_graph=False keeps checkpoints small; the graph is
            # rebuilt in code on restore.
            save_path = saver.save(sess,
                                   os.path.join(FLAGS.checkpointDir,
                                                "model.ckpt"),
                                   global_step=step,
                                   write_meta_graph=False)
            print("Model saved in file: %s" % save_path)

    coord.request_stop()
    coord.join(threads)
class Model():
    """CycleGAN model with two generators (G_A: A->B, G_B: B->A) and two
    discriminators (D_A, D_B), trained with adversarial, cycle-consistency
    and optional identity losses.

    Extra loss terms (perception / image / TV / "realistic" Laplacian loss)
    are computed in backward_G but are currently commented out of the total
    generator objective ``loss_G``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add CycleGAN-specific command-line options and return the parser."""
        parser.set_defaults(no_dropout=True)  # default CycleGAN did not use dropout
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
        return parser

    def __init__(self, opt):
        """Build generators/discriminators, losses and optimizers from ``opt``."""
        # BaseModel.__init__(self, opt)
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device(
            'cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        if opt.preprocess != 'scale_width':
            # with [scale_width], input images might have different sizes,
            # which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = None  # used for learning rate policy 'plateau'
        # losses reported by get_current_losses (attribute name = 'loss_' + entry)
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        # , 'perception_G_A', 'perception_G_B', 'image_G_A', 'image_G_B',
        # 'tv_G_A', 'tv_G_B', 'rl_G_A', 'rl_G_B']
        # specify the images you want to save/display. The training/test
        # scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')
        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']
        self.netG_A = define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                               not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                               not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        if self.isTrain:  # define discriminators
            self.netD_A = define_D(opt.output_nc, opt.ndf, opt.netD, opt.norm,
                                   opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = define_D(opt.input_nc, opt.ndf, opt.netD, opt.norm,
                                   opt.init_type, opt.init_gain, self.gpu_ids)
        if self.isTrain:
            if opt.lambda_identity > 0.0:
                # only works when input and output images have the same number of channels
                assert(opt.input_nc == opt.output_nc)
            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # frozen VGG16 feature extractor used by the perception loss
            vgg = vgg16(pretrained=True)
            loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
            for param in loss_network.parameters():
                param.requires_grad = False
            loss_network.cuda()  # NOTE(review): unconditional .cuda() — will fail on CPU-only runs; confirm
            self.criterionLossnetwork = loss_network
            self.criterionMse = torch.nn.MSELoss()
            self.criterionTv = TVLoss()
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(),
                                                                self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(),
                                                                self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Unpack a dataloader batch; 'direction' decides which side is A/B."""
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.A_paths = input['A_paths'][0]
        self.B_paths = input['B_paths'][0]
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

    def backward_D_basic(self, netD, real, fake):
        """Adversarial loss for one discriminator; backprops and returns the loss.

        ``fake`` is detached so gradients do not flow into the generators.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Discriminator D_A step, using a pooled (history) fake_B sample."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Discriminator D_B step, using a pooled (history) fake_A sample."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def realistic_loss_grad(self, image, laplacian_m):
        """Quadratic form x^T L x summed over channels (Laplacian smoothness term).

        ``image`` is squeezed to (C, H, W) and each channel is flattened to a
        column vector before the matrix products.
        """
        img = image.squeeze(0)
        channel, height, width = img.size()
        loss = 0
        for i in range(channel):
            grad = torch.mm(laplacian_m, img[i, :, :].reshape(-1, 1))
            loss += torch.mm(img[i, :, :].reshape(1, -1), grad)
        return loss

    def backward_G(self):
        """Compute all generator losses and backprop the combined loss_G."""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # Perception Loss (VGG feature-space MSE)
        self.loss_perception_G_A = self.criterionMse(self.criterionLossnetwork(self.fake_A),
                                                     self.criterionLossnetwork(self.real_A)) * 0.5
        self.loss_perception_G_B = self.criterionMse(self.criterionLossnetwork(self.fake_B),
                                                     self.criterionLossnetwork(self.real_B)) * 0.5
        # Image Loss (pixel-space MSE)
        self.loss_image_G_A = self.criterionMse(self.fake_A, self.real_A) * 20.0
        self.loss_image_G_B = self.criterionMse(self.fake_B, self.real_B) * 20.0
        # TV Loss
        self.loss_tv_G_A = self.criterionTv(self.fake_A) * 2e-8
        self.loss_tv_G_B = self.criterionTv(self.fake_B) * 2e-8
        # real loss: per-sample Laplacian quadratic form, averaged over the batch
        print('Computing Laplacian matrix of content image')
        self.loss_rl_G_A = 0
        self.loss_rl_G_B = 0
        for i in range(self.real_A.size()[0]):
            L_A = compute_lap(self.real_A[i])
            L_B = compute_lap(self.real_B[i])
            self.loss_rl_G_A += self.realistic_loss_grad(self.fake_A[i], L_A) * 0.00001
            self.loss_rl_G_B += self.realistic_loss_grad(self.fake_B[i], L_B) * 0.00001
        self.loss_rl_G_A = torch.div(self.loss_rl_G_A, float(self.real_A.size()[0]))
        self.loss_rl_G_B = torch.div(self.loss_rl_G_B, float(self.real_B.size()[0]))
        # NOTE: only GAN + cycle + identity terms enter the objective;
        # the perception/image/tv/rl terms below are intentionally disabled.
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + \
            self.loss_idt_B
        # + self.loss_perception_G_A + self.loss_perception_G_B + self.loss_image_G_A + \
        # self.loss_image_G_B + self.loss_tv_G_A + self.loss_tv_G_B + self.loss_rl_G_A + self.loss_rl_G_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()  # compute fake images and reconstruction images.
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gs
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()             # calculate gradients for G_A and G_B
        self.optimizer_G.step()       # update G_A and G_B's weights
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()  # set D_A and D_B's gradients to zero
        self.backward_D_A()           # calculate gradients for D_A
        self.backward_D_B()           # calculate graidents for D_B
        self.optimizer_D.step()       # update D_A and D_B's weights
        return self.real_A, self.fake_A, self.real_B, self.fake_B, self.loss_G_A, self.loss_G_B, self.loss_D_A, \
            self.loss_D_B, self.loss_cycle_A, self.loss_cycle_B, self.loss_idt_A, self.loss_idt_B
        # self.loss_perception_G_A, self.loss_perception_G_B, self.loss_image_G_A, self.loss_image_G_B, \
        # self.loss_tv_G_A, self.loss_tv_G_B, self.loss_rl_G_A, self.loss_rl_G_B

    def setup(self, opt):
        """Create LR schedulers (train) and/or load pretrained weights (test/resume)."""
        if self.isTrain:
            self.schedulers = [get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(opt.verbose)

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Forward pass without gradients, then refresh visuals."""
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Hook for subclasses; no extra visuals by default."""
        pass

    def get_image_paths(self):
        """Return the image paths of the current batch."""
        return self.image_paths

    def update_learning_rate(self):
        """Step every scheduler (passing self.metric for 'plateau') and print the LR."""
        for scheduler in self.schedulers:
            scheduler.step(self.metric)
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def get_current_visuals(self):
        """Return an OrderedDict name -> image tensor for every visual_name."""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return an OrderedDict name -> float of the current loss values."""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(
                    getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save every model in model_names as '<epoch>_net_<name>.pth' in save_dir."""
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)
                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # move to CPU for a portable checkpoint, then back to GPU
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Drop InstanceNorm buffers from pre-0.4 checkpoints (recursive on key path)."""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        """Load every model in model_names from '<epoch>_net_<name>.pth' in save_dir."""
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata
                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print parameter counts (and the full nets when verbose) for all models."""
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Enable/disable gradients for the given net or list of nets."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
loss_vertex_A = criterionCycle(fake_B, real_B) loss_vertex_B = criterionCycle(fake_A, real_A) loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B - cc_A * lambda_cc - cc_B * lambda_cc + loss_vertex_A * lambda_vertex + loss_vertex_B * lambda_vertex """ calculate gradients for G_A and G_B """ loss_G.backward() optimizer_G_A.step() # update G_A and G_B's weights optimizer_G_B.step() # update G_A and G_B's weights # train D_A and D_B set_requires_grad([netD_A, netD_B], True) optimizer_D_A.zero_grad() # set D_A and D_B's gradients to zero optimizer_D_B.zero_grad() # set D_A and D_B's gradients to zero """Calculate GAN loss for discriminator D_A""" fake_B = fake_B_pool.query(fake_B) loss_D_A = backward_D_basic(netD_A, real_B, fake_B, 0.1) """Calculate GAN loss for discriminator D_B""" fake_A = fake_A_pool.query(fake_A) loss_D_B = backward_D_basic(netD_B, real_A, fake_A, 0.1) optimizer_D_A.step() # update D_A and D_B's weights optimizer_D_B.step() # update D_A and D_B's weights print( "[{}:{}/{}] IDT_A={:.4}, IDT_B={:.4}, G_A={:.4}, G_B={:.4}, CYCLE_A={:.4}, CYCLE_B={:.4}, D_A={:.4}, D_B={:.4}, CC_A={:.4}, CC_B={:.4}" .format(epoch, batch_idx, len(train_dataloader), loss_idt_A, loss_idt_B, loss_G_A, loss_G_B, loss_cycle_A, loss_cycle_B, loss_D_A, loss_D_B, cc_A, cc_B)) writer.add_scalars('Train/IDT_loss', {
def train():
    """Build the CycleGAN graph and run the training loop.

    Resumes from ``checkpoint/<FLAGS.load_model>`` when given, otherwise
    creates a new time-stamped checkpoint directory. Summaries every 100
    steps, checkpoints every 10000 steps; a final checkpoint is always
    written in the ``finally`` block.
    """
    if FLAGS.load_model is not None:
        checkpoint_dir = 'checkpoint/' + FLAGS.load_model
    else:
        current_time = datetime.now().strftime('%Y%m%d-%H%M')
        checkpoint_dir = 'checkpoint/{}'.format(current_time)
    try:
        os.makedirs(checkpoint_dir)
    except os.error:
        pass

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(X_train_file=FLAGS.X,
                             Y_train_file=FLAGS.Y,
                             batch_size=FLAGS.batch_size,
                             image_size=FLAGS.image_size,
                             use_lsgan=FLAGS.use_lsgan,
                             norm=FLAGS.norm,
                             lambda1=FLAGS.lambda1,
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             beta1=FLAGS.beta1,
                             ngf=FLAGS.ngf)
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoint_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            # restore graph + weights and recover the step from the meta path
            checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + '.meta'
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
            step = int(meta_graph_path.split('-')[2].split('.')[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # history pools of previously generated images for D training
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)
            print('Begin to train...')
            while not coord.should_stop():
                # get previously generated images
                print('tf.Session().Run [fake_y, fake_x] ')
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])
                # train
                print('Calculate loss...')
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                        feed_dict={
                            # BUG FIX: fake_y was queried from fake_X_pool,
                            # mixing the X/Y histories fed to D_Y.
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                        }))
                if step % 100 == 0:
                    train_writer.add_summary(summary, step)
                    train_writer.flush()
                if step % 100 == 0:
                    logging.info('-------------Step %d------------' % step)
                    logging.info('G_loss: {}'.format(G_loss_val))
                    logging.info('D_Y_loss: {}'.format(D_Y_loss_val))
                    logging.info('F_loss: {}'.format(F_loss_val))
                    logging.info('D_X_loss:{}'.format(D_X_loss_val))
                    logging.info('********************************')
                if step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoint_dir + '/model.ckpt',
                                           global_step=step)
                    logging.info('* Model saved in file %s' % save_path)
                step += 1
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess,
                                   checkpoint_dir + '/model.ckpt',
                                   global_step=step)
            logging.info('Model saved in file %s' % save_path)
            coord.request_stop()
            coord.join(threads)
class CycleGanModel(BaseModel):
    """CycleGAN with naming by input domain: netGenA maps A->B, netGenB maps
    B->A, netDisA judges domain-A images and netDisB domain-B images.
    """

    def __init__(self, opt):
        """Build generators, discriminators, losses, pools and optimizers."""
        super(CycleGanModel, self).__init__(opt)
        print('-------------- Networks initializing -------------')
        self.mode = None
        # specify the training losses you want to print out.
        # The program will call base_model.get_current_losses
        self.lossNames = [
            'loss{}'.format(i) for i in [
                'GenA', 'DisA', 'CycleA', 'IdtA', 'DisB', 'GenB', 'CycleB', 'IdtB'
            ]
        ]
        self.lossGenA, self.lossDisA, self.lossCycleA, self.lossIdtA = 0, 0, 0, 0
        self.lossGenB, self.lossDisB, self.lossCycleB, self.lossIdtB = 0, 0, 0, 0
        # define loss functions
        self.criterionGAN = networks.GANLoss(use_lsgan=opt.lsgan).to(
            opt.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        # specify the training miou you want to print out.
        # The program will call base_model.get_current_mious
        self.miouNames = []
        # specify the images you want to save/display.
        # The program will call base_model.get_current_visuals
        # only image doesn't have prefix
        imageNamesA = ['realA', 'fakeA', 'recA', 'idtA']
        imageNamesB = ['realB', 'fakeB', 'recB', 'idtB']
        self.imageNames = imageNamesA + imageNamesB
        self.realA, self.fakeA, self.recA, self.idtA = None, None, None, None
        self.realB, self.fakeB, self.recB, self.idtB = None, None, None, None
        # specify the models you want to save to the disk. The program will
        # call base_model.save_networks and base_model.load_networks
        # naming is by the input domain
        self.modelNames = [
            'net{}'.format(i) for i in ['GenA', 'DisA', 'GenB', 'DisB']
        ]
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_RGB (G), G_D (F), D_RGB (D_Y), D_D (D_X)
        self.netGenA = networks.define_G(opt.inputCh, opt.inputCh, opt.ngf,
                                         opt.which_model_netG, opt.norm,
                                         opt.dropout, opt.init_type,
                                         opt.init_gain, opt.gpuIds)
        self.netDisA = networks.define_D(opt.inputCh, opt.inputCh,
                                         opt.which_model_netD, opt.n_layers_D,
                                         opt.norm, not opt.lsgan,
                                         opt.init_type, opt.init_gain,
                                         opt.gpuIds)
        self.netGenB = networks.define_G(opt.inputCh, opt.inputCh, opt.ngf,
                                         opt.which_model_netG, opt.norm,
                                         opt.dropout, opt.init_type,
                                         opt.init_gain, opt.gpuIds)
        self.netDisB = networks.define_D(opt.inputCh, opt.inputCh,
                                         opt.which_model_netD, opt.n_layers_D,
                                         opt.norm, not opt.lsgan,
                                         opt.init_type, opt.init_gain,
                                         opt.gpuIds)
        self.set_requires_grad(
            [self.netGenA, self.netGenB, self.netDisA, self.netDisB], True)
        # define image pool (history of generated images for D training)
        self.fakeAPool = ImagePool(opt.pool_size)
        self.fakeBPool = ImagePool(opt.pool_size)
        # initialize optimizers
        self.optimizerG = getOptimizer(itertools.chain(
            self.netGenA.parameters(), self.netGenB.parameters()),
            opt=opt.opt,
            lr=opt.lr,
            beta1=opt.beta1,
            momentum=opt.momentum,
            weight_decay=opt.weight_decay)
        self.optimizerD = getOptimizer(itertools.chain(
            self.netDisA.parameters(), self.netDisB.parameters()),
            opt=opt.opt,
            lr=opt.lr,
            beta1=opt.beta1,
            momentum=opt.momentum,
            weight_decay=opt.weight_decay)
        self.optimizers = []
        self.optimizers.append(self.optimizerG)
        self.optimizers.append(self.optimizerD)
        print('--------------------------------------------------')

    def name(self):
        """Return the model's display name."""
        return 'CycleGanModel'

    def set_input(self, input):
        """Unpack a paired batch: input[0] is domain A, input[1] is domain B."""
        self.realA = input[0]['image'].to(self.opt.device)
        self.realB = input[1]['image'].to(self.opt.device)

    def forward(self):
        """Compute fakes and cycle reconstructions for both domains."""
        self.fakeA = self.netGenB(self.realB)
        self.fakeB = self.netGenA(self.realA)
        self.recA = self.netGenB(self.fakeB)
        self.recB = self.netGenA(self.fakeA)

    def backward_dis_basic(self, netDis, real, fake):
        """Adversarial loss for one discriminator; backprops and returns a float.

        ``fake`` is detached so the generators receive no gradient here.
        """
        # Real
        predReal = netDis(real)
        lossDisReal = self.criterionGAN(predReal, True)
        # Fake
        predFake = netDis(fake.detach())
        lossDisFake = self.criterionGAN(predFake, False)
        # Combined loss
        lossDis = (lossDisReal + lossDisFake) * 0.5
        # backward
        lossDis.backward()
        return float(lossDis)

    def backward_dis_A(self):
        """DisA step: pooled fakeA history vs realA."""
        fakeA = self.fakeAPool.query(self.fakeA)
        self.lossDisA = self.backward_dis_basic(self.netDisA, self.realA, fakeA)

    def backward_dis_B(self):
        """DisB step: pooled fakeB history vs realB."""
        fakeB = self.fakeBPool.query(self.fakeB)
        self.lossDisB = self.backward_dis_basic(self.netDisB, self.realB, fakeB)

    def backward_gen(self, retain_graph=False):
        """Compute generator losses (GAN + cycle + optional identity) and backprop."""
        lambdaIdt = self.opt.lambdaIdentity
        lambdaA = self.opt.lambdaA
        lambdaB = self.opt.lambdaB
        # Identity loss
        self.forward()
        if lambdaIdt > 0:
            # GenB should be identity if realA is fed.
            self.idtA = self.netGenB(self.realA)
            lossIdtA = self.criterionIdt(self.idtA,
                                         self.realA) * lambdaA * lambdaIdt
            # GenA should be identity if realB is fed.
            self.idtB = self.netGenA(self.realB)
            lossIdtB = self.criterionIdt(self.idtB,
                                         self.realB) * lambdaB * lambdaIdt
        else:
            lossIdtA = 0
            lossIdtB = 0
        # GAN D loss
        lossGenA = self.criterionGAN(self.netDisB(self.fakeB), True)
        # GAN D loss
        lossGenB = self.criterionGAN(self.netDisA(self.fakeA), True)
        # Forward cycle loss
        lossCycleA = self.criterionCycle(self.recA, self.realA) * lambdaA
        # Backward cycle loss
        lossCycleB = self.criterionCycle(self.recB, self.realB) * lambdaB
        # combined loss
        lossG = lossGenA + lossGenB + lossCycleA + lossCycleB + lossIdtA + lossIdtB
        lossG.backward(retain_graph=retain_graph)
        # move image to cpu (store plain floats for reporting)
        self.lossGenA = float(lossGenA)
        self.lossGenB = float(lossGenB)
        self.lossCycleA = float(lossCycleA)
        self.lossCycleB = float(lossCycleB)
        self.lossIdtA = float(lossIdtA)
        self.lossIdtB = float(lossIdtB)

    def optimize_parameters(self):
        """One training iteration: generator step, then discriminator step."""
        # GenA and GenB
        self.set_requires_grad([self.netDisA, self.netDisB], False)
        self.optimizerG.zero_grad()
        self.backward_gen()
        self.optimizerG.step()
        # DisA and DisB
        self.set_requires_grad([self.netDisA, self.netDisB], True)
        self.optimizerD.zero_grad()
        self.backward_dis_A()
        self.backward_dis_B()
        self.optimizerD.step()
def train():
    """Train a CycleGAN with alternating D/G updates.

    The discriminators are updated every iteration; the generators are
    updated once every ``cfg.D_times`` discriminator updates. Optionally
    resumes from ``cfg.load_model`` and/or warm-starts weights from
    ``cfg.new_pretrain``.
    """
    if cfg.load_model is not None:
        checkpoints_dir = cfg.load_model
    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN()
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        # separate optimizer ops so D and G can be stepped independently
        G_optimizers, D_optimizers = cycle_gan.optimize(G_loss, D_Y_loss,
                                                        F_loss, D_X_loss,
                                                        gan=cfg.gan)
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(cfg.tb_dir, graph)
        #for v in tf.global_variables():
        #    print(v.name)
        if cfg.new_pretrain is not None:
            # saver restores the pretrained subset; saver_dump writes checkpoints
            var_to_restore = []
            for v in tf.global_variables():
                var_to_restore.append(v)
            saver = tf.train.Saver(var_to_restore)
            saver_dump = tf.train.Saver()
        else:
            saver = tf.train.Saver()
            saver_dump = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if cfg.load_model is not None:
            # resume: restore graph + weights and recover step from the meta path
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            step = int(meta_graph_path.split("-")[1].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0
        print(
            '--------------------------------------------------------------------------------'
        )
        if cfg.new_pretrain is not None:
            saver.restore(sess, cfg.new_pretrain)
        ## TODO dataset
        trainA = Dataset(cfg.trainA_dir)
        trainB = Dataset(cfg.trainB_dir)
        # train
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        D_times = 0        # D updates since the last G update
        G_train_times = 0  # total G updates so far
        try:
            # history pools of previously generated images for D training
            fake_Y_pool = ImagePool(cfg.pool_size)
            fake_X_pool = ImagePool(cfg.pool_size)
            while not coord.should_stop():
                st_t = time.time()
                # generate data
                x_image = sess.run(trainA.data)[0]
                #x_image = x_image + tf.random_normal(shape=tf.shape(x_image), mean=0.0, stddev=0.1, dtype=tf.float32)
                y_image = sess.run(trainB.data)[0]
                # y_image = y_image + tf.random_normal(shape=tf.shape(y_image), mean=0.0, stddev=0.1, dtype=tf.float32)
                data_time = time.time() - st_t
                st_t = time.time()
                # generate fake_x, fake_y
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x],
                                                  feed_dict={
                                                      cycle_gan.x_image: x_image,
                                                      cycle_gan.y_image: y_image
                                                  })
                gen_fake_time = time.time() - st_t
                st_t = time.time()
                # train
                # Discrminator (updated every iteration)
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = \
                    sess.run([D_optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                             feed_dict={
                                 cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                                 cycle_gan.fake_x: fake_X_pool.query(fake_x_val),
                                 cycle_gan.x_image: x_image,
                                 cycle_gan.y_image: y_image})
                # Generator: stepped once per cfg.D_times discriminator updates
                if D_times > 0 and D_times % cfg.D_times == 0:
                    D_times = 0
                    G_train_times += 1
                    _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = \
                        sess.run([G_optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                                 feed_dict={
                                     cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                                     cycle_gan.fake_x: fake_X_pool.query(fake_x_val),
                                     cycle_gan.x_image: x_image,
                                     cycle_gan.y_image: y_image})
                bp_time = time.time() - st_t
                train_writer.add_summary(summary, step)
                train_writer.flush()
                if step % 1 == 0:
                    logging.info(
                        'step {} | G_loss : {:.4f} | D_Y_loss : {:.4f} | F_loss : {:.4f} |'
                        'D_X_loss : {:.4f} | g_train_times: {} | data {:.3f}s | gen_fake {:.3f}s | bp {:.3f}s'
                        .format(step, G_loss_val, D_Y_loss_val, F_loss_val,
                                D_X_loss_val, G_train_times, data_time,
                                gen_fake_time, bp_time))
                if step % 100 == 0:
                    save_path = saver_dump.save(sess,
                                                cfg.model_dump_dir + '/model.ckpt',
                                                global_step=step)
                    logging.info('model saved in files: %s' % save_path)
                D_times += 1
                step += 1
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            # always leave a final checkpoint behind
            save_path = saver_dump.save(sess,
                                        cfg.model_dump_dir + '/model.ckpt',
                                        global_step=step)
            logging.info('model saved in files: %s' % save_path)
            coord.request_stop()
            coord.join(threads)
def train():
    """Build the CycleGAN graph and run the training loop.

    If FLAGS.load_model is set, training resumes from that checkpoint
    directory; otherwise a new time-stamped directory is created. Image
    results and checkpoints are written every 1000 steps, and training
    stops once ``step`` reaches FLAGS.epho.
    """
    # Resume from an existing checkpoint dir when given; otherwise create a
    # new time-stamped one (plus the image-result dir).
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
    try:
        os.makedirs(checkpoints_dir)
        os.makedirs(FLAGS.res_im_path)
    except os.error:
        pass

    graph = tf.Graph()
    with graph.as_default():
        # instantiate the CycleGAN and build the graph
        cycle_gan = CycleGAN(FLAGS)
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x, real_y, real_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)
        # summaries
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver(max_to_keep=10)

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            # restore weights and recover the step counter from the meta path
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        # start the input queue runners
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # BUG FIX: pre-bind these so the `finally` block cannot raise a
        # NameError (masking the real error) when an exception occurs before
        # the first loop iteration completes.
        fake_y_val = fake_x_val = real_y_in = real_x_in = None
        try:
            # image-history pools for discriminator training
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)
            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val, real_y_in, real_x_in = sess.run(
                    [fake_y, fake_x, real_y, real_x])
                # train
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                        }))
                train_writer.add_summary(summary, step)
                train_writer.flush()
                # report current losses
                if step % 1 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))
                if step % 1000 == 0:
                    ops.save_img_result(fake_y_val, fake_x_val, real_y_in,
                                        real_x_in, FLAGS.res_im_path, step)
                if step % 1000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)
                step += 1
                if step == FLAGS.epho:
                    coord.request_stop()  # signal the end of training
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            if fake_y_val is not None:
                # only dump images if at least one batch was produced
                ops.save_img_result(fake_y_val, fake_x_val, real_y_in,
                                    real_x_in, FLAGS.res_im_path, step)
            logging.info("Model saved in file: %s" % save_path)
            coord.request_stop()  # stop training
            coord.join(threads)
class CycleGANModel():
    """CycleGAN augmented with a learned forward-dynamics model.

    Naming follows the reference CycleGAN code (paper names in
    parentheses): G_A (G), G_B (F), D_A (D_Y), D_B (D_X).  Domain A is
    the low-dimensional state, domain B is the image; netF_A predicts
    the next state from a (state, action) pair.
    """

    def __init__(self, opt):
        self.opt = opt
        self.dynamic = opt.dynamic      # enable the third (dynamics) cycle
        self.isTrain = opt.istrain
        self.Tensor = torch.cuda.FloatTensor
        # Generators and the forward-dynamics model.
        self.netG_A = state2img().cuda()   # state -> image
        self.netG_B = img2state().cuda()   # image -> state
        self.netF_A = Fmodel().cuda()      # (state, action) -> next state
        self.dataF = CDFdata.get_loader(opt)
        self.train_forward()
        if self.isTrain:
            self.netD_A = imgDmodel().cuda()
            self.netD_B = stateDmodel().cuda()
        if self.isTrain:
            self.fake_A_pool = ImagePool(pool_size=128)
            self.fake_B_pool = ImagePool(pool_size=128)
            # Loss functions: GAN loss plus an L1 or L2 cycle/identity loss.
            self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
            if opt.loss == 'l1':
                self.criterionCycle = torch.nn.L1Loss()
                self.criterionIdt = torch.nn.L1Loss()
            elif opt.loss == 'l2':
                self.criterionCycle = torch.nn.MSELoss()
                self.criterionIdt = torch.nn.MSELoss()
            # Optimizers.  netF_A participates in the generator graph but is
            # frozen here (lr=0.0): it is pre-trained in train_forward().
            self.optimizer_G = torch.optim.Adam(
                [{'params': self.netG_A.parameters(), 'lr': 1e-3},
                 {'params': self.netF_A.parameters(), 'lr': 0.0},
                 {'params': self.netG_B.parameters(), 'lr': 1e-3}])
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters())
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            # NOTE(review): no real scheduler is attached — the optimizer
            # itself is appended as a placeholder, as in the original.
            for optimizer in self.optimizers:
                self.schedulers.append(optimizer)
        print('---------- Networks initialized -------------')
        print('-----------------------------------------------')

    def train_forward(self, pretrained=True):
        """Pre-train the forward-dynamics model netF_A, or load it.

        When ``pretrained`` is True (the default) the weights are loaded
        from ./pred.pth and no training happens.  Otherwise netF_A is
        trained for 100 epochs with an L1 loss on next-state prediction.
        """
        if pretrained:
            self.netF_A.load_state_dict(torch.load('./pred.pth'))
            return None
        optimizer = torch.optim.Adam(self.netF_A.parameters(), lr=1e-3)
        loss_fn = torch.nn.L1Loss()
        for epoch in range(100):
            epoch_loss = 0
            for i, item in enumerate(tqdm(self.dataF)):
                state, action, result = item[1]
                state = state.float().cuda()
                action = action.float().cuda()
                result = result.float().cuda()
                out = self.netF_A(state, action)
                loss = loss_fn(out, result)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print('epoch:{} loss:{:.7f}'.format(epoch,
                                                epoch_loss / len(self.dataF)))
        print('forward model has been trained!')

    def set_input(self, input):
        """Unpack one data-loader item.

        A is the state domain, B is the image domain; t0/t1 are two
        consecutive time steps linked by ``action``.
        """
        # A is state
        self.input_A = input[1][0]
        # B is img
        self.input_Bt0 = input[0][0]
        self.input_Bt1 = input[0][2]
        self.action = input[0][1]
        self.gt0 = input[2][0].float().cuda()  # ground-truth state at t0
        self.gt1 = input[2][1].float().cuda()  # ground-truth state at t1

    def forward(self):
        """Move the current inputs onto the GPU as float Variables."""
        self.real_A = Variable(self.input_A).float().cuda()
        self.real_Bt0 = Variable(self.input_Bt0).float().cuda()
        self.real_Bt1 = Variable(self.input_Bt1).float().cuda()
        self.action = Variable(self.action).float().cuda()

    def test(self):
        """Run both translation directions without training."""
        self.forward()
        real_A = Variable(self.input_A, volatile=True).float().cuda()
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data
        real_B = Variable(self.input_Bt0, volatile=True).float().cuda()
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data

    def backward_D_basic(self, netD, real, fake):
        """Standard discriminator loss: real vs. detached fake, halved."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detached so gradients do not reach the generator)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Image-domain discriminator step, fed from the history pool."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_Bt0, fake_B)
        self.loss_D_A = loss_D_A.item()

    def backward_D_B(self):
        """State-domain discriminator step, fed from the history pool."""
        fake_A = self.fake_A_pool.query(self.fake_At0)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        """Generator step over up to three cycles (A->B, B->A, dynamics)."""
        # NOTE(review): lambda_idt is negative, so the identity branch is
        # never taken — confirm disabling identity loss is intended.
        lambda_idt = -0.5
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        lambda_F = self.opt.lambda_F
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_Bt0)
            loss_idt_A = self.criterionIdt(
                idt_A, self.real_Bt0) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(
                idt_B, self.real_A) * lambda_A * lambda_idt
            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        lambda_G_A = 1.0
        lambda_G_B = 1.0
        # --------first cycle: A -> B -> A -----------#
        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True) * lambda_G_A
        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
        # ---------second cycle: B(t0) -> A -> B(t0) ---------#
        # GAN loss D_B(G_B(B))
        fake_At0 = self.netG_B(self.real_Bt0)
        pred_fake = self.netD_B(fake_At0)
        loss_G_Bt0 = self.criterionGAN(pred_fake, True) * lambda_G_B
        # Backward cycle loss
        rec_Bt0 = self.netG_A(fake_At0)
        loss_cycle_Bt0 = self.criterionCycle(rec_Bt0, self.real_Bt0) * lambda_B
        # ---------third cycle: roll the dynamics model forward ---------#
        fake_At1 = self.netF_A(fake_At0, self.action)
        pred_fake = self.netD_B(fake_At1)
        loss_G_Bt1 = self.criterionGAN(pred_fake, True) * lambda_G_B
        rec_Bt1 = self.netG_A(fake_At1)
        loss_cycle_Bt1 = self.criterionCycle(rec_Bt1, self.real_Bt1) * lambda_F
        # Combined loss; the dynamics-cycle terms only count when enabled.
        loss_G = loss_idt_A + loss_idt_B
        loss_G = loss_G + loss_G_A + loss_cycle_A
        loss_G = loss_G + loss_G_Bt0 + loss_cycle_Bt0
        if self.dynamic:
            loss_G = loss_G + loss_G_Bt1 + loss_cycle_Bt1
        loss_G.backward()
        self.fake_B = fake_B.data
        self.fake_At0 = fake_At0.data
        self.fake_At1 = fake_At1.data
        self.rec_A = rec_A.data
        self.rec_Bt0 = rec_Bt0.data
        self.rec_Bt1 = rec_Bt1.data
        self.loss_G_A = loss_G_A.item()
        self.loss_G_Bt0 = loss_G_Bt0.item()
        self.loss_G_Bt1 = loss_G_Bt1.item()
        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_Bt0 = loss_cycle_Bt0.item()
        self.loss_cycle_Bt1 = loss_cycle_Bt1.item()
        # BUG FIX: the original read `self.gt`, which is never assigned
        # (set_input stores gt0/gt1) and raised AttributeError.  fake_At0
        # is the translated t0 state, so compare against gt0.
        self.loss_state_l1 = self.criterionCycle(self.fake_At0,
                                                 self.gt0).item()

    def optimize_parameters(self):
        """One full optimization step: generators, then both discriminators."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return the most recent scalar losses keyed by short names."""
        ret_errors = OrderedDict([('L1', self.loss_state_l1),
                                  ('D_A', self.loss_D_A),
                                  ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B),
                                  ('G_B0', self.loss_G_Bt0),
                                  ('G_B1', self.loss_G_Bt1),
                                  ('Cyc_B0', self.loss_cycle_Bt0),
                                  ('Cyc_B1', self.loss_cycle_Bt1)])
        return ret_errors

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, path):
        """Save one network's state dict as model_<label>.pth under path."""
        save_filename = 'model_{}.pth'.format(network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    def save(self, path):
        """Save the generators and discriminators."""
        self.save_network(self.netG_A, 'G_A', path)
        self.save_network(self.netD_A, 'D_A', path)
        self.save_network(self.netG_B, 'G_B', path)
        self.save_network(self.netD_B, 'D_B', path)

    def load_network(self, network, network_label, path):
        """Load one network's state dict from model_<label>.pth under path."""
        weight_filename = 'model_{}.pth'.format(network_label)
        weight_path = os.path.join(path, weight_filename)
        network.load_state_dict(torch.load(weight_path))

    def load(self, path):
        """Load the two generators (discriminators are not restored)."""
        self.load_network(self.netG_A, 'G_A', path)
        self.load_network(self.netG_B, 'G_B', path)

    def plot_points(self, item, label):
        """Scatter-plot the first two coordinates of a batch of states."""
        item = item.cpu().data.numpy()
        plt.scatter(item[:, 0], item[:, 1], label=label)

    def visual(self, path):
        """Tile real/reconstructed images per sample and save to path."""
        imgs = []
        for i in range(self.real_A.shape[0]):
            imgs_i = [self.real_Bt0[i], self.rec_Bt0[i]]
            imgs_i += [self.real_Bt1[i], self.rec_Bt1[i]]
            imgs_i += [self.fake_B[i]]
            imgs_i = torch.cat(imgs_i, 2).cpu()
            imgs.append(imgs_i)
        imgs = torch.cat(imgs, 1)
        imgs = (imgs + 1) / 2  # [-1, 1] -> [0, 1] for ToPILImage
        imgs = transforms.ToPILImage()(imgs)
        imgs.save(path)
def train():
    """Train CycleGAN, resuming from FLAGS.load_model when it is set."""
    if FLAGS.load_model is not None:
        # Resume into an existing checkpoint directory.
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:
        # Fresh run: timestamped checkpoint directory (format made
        # consistent with the other train() variants — no spaces).
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()  # build the computation graph
    with graph.as_default():
        cycle_gan = CycleGAN(X_train_file=FLAGS.X,
                             Y_train_file=FLAGS.Y,
                             batch_size=FLAGS.batch_size,
                             image_size=FLAGS.image_size,
                             use_lsgan=FLAGS.use_lsgan,
                             norm=FLAGS.norm,
                             lambda1=FLAGS.lambda1,
                             # BUG FIX: was lambda2=FLAGS.lambda1
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             beta1=FLAGS.beta1,
                             ngf=FLAGS.ngf)
        # Losses for both generators/discriminators, plus the generated
        # fake_y (from x) and fake_x (from y) tensors.
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            # Restore the latest model; the step counter is recovered from
            # the checkpoint filename (model.ckpt-<step>.meta).
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()  # thread management
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # History pools of previously generated images fed to the
            # discriminators.  BUG FIX: the first pool was created with
            # `FLASG.pool_size` — a NameError typo for FLAGS.
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # First generate the current fake images ...
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # ... then run one optimization step, feeding the
                # discriminators samples drawn from the history pools
                # instead of only the freshest fakes.
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                        }))

                if step % 100 == 0:
                    # Summaries and console logging every 100 steps (the
                    # two identical guards were merged).
                    train_writer.add_summary(summary, step)
                    train_writer.flush()
                    logging.info('----------step %d:--------------' % step)
                    logging.info(' G_loss : {}'.format(G_loss_val))
                    logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info(' F_loss : {}'.format(F_loss_val))
                    logging.info(' D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            # Save the final model before shutting the input queues down.
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            coord.request_stop()
            coord.join(threads)
class CycleGAN(ModelBackbone):
    """Standard CycleGAN (Zhu et al., 2017).

    Naming follows the reference implementation; paper names in
    parentheses: G_A (G), G_B (F), D_A (D_Y), D_B (D_X).
    """

    def __init__(self, p):
        super(CycleGAN, self).__init__(p)
        # (The original bound p.batchSize / p.cropSize to unused locals —
        # removed.)
        # Generators for both directions.
        self.netG_A = networks.define_G(p.input_nc, p.output_nc, p.ngf,
                                        p.which_model_netG, p.norm,
                                        not p.no_dropout, p.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(p.output_nc, p.input_nc, p.ngf,
                                        p.which_model_netG, p.norm,
                                        not p.no_dropout, p.init_type,
                                        self.gpu_ids)
        if self.isTrain:
            # LSGAN uses raw scores; a vanilla GAN needs a sigmoid output.
            use_sigmoid = p.no_lsgan
            self.netD_A = networks.define_D(p.output_nc, p.ndf,
                                            p.which_model_netD, p.n_layers_D,
                                            p.norm, use_sigmoid, p.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(p.input_nc, p.ndf,
                                            p.which_model_netD, p.n_layers_D,
                                            p.norm, use_sigmoid, p.init_type,
                                            self.gpu_ids)
        if not self.isTrain or p.continue_train:
            # Restore weights for testing or to continue training.
            which_epoch = p.which_epoch
            self.load_model(self.netG_A, 'G_A', which_epoch)
            self.load_model(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_model(self.netD_A, 'D_A', which_epoch)
                self.load_model(self.netD_B, 'D_B', which_epoch)
        if self.isTrain:
            self.old_lr = p.lr
            self.fake_A_pool = ImagePool(p.pool_size)
            self.fake_B_pool = ImagePool(p.pool_size)
            # Loss functions.
            self.criterionGAN = networks.GANLoss(use_lsgan=not p.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # One optimizer for both generators, one per discriminator.
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                lr=p.lr,
                betas=(p.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=p.lr,
                                                  betas=(p.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=p.lr,
                                                  betas=(p.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, p))

    def name(self):
        return 'CycleGAN'

    def set_input(self, inp):
        """Store one batch; direction AtoB decides which side is A."""
        AtoB = self.p.which_direction == 'AtoB'
        input_A = inp['A' if AtoB else 'B']
        input_B = inp['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            input_A = input_A.cuda(self.gpu_ids[0])
            input_B = input_B.cuda(self.gpu_ids[0])
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = inp['A_path' if AtoB else 'B_path']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run both translation cycles without gradients."""
        with torch.no_grad():
            real_A = Variable(self.input_A)
            fake_B = self.netG_A(real_A)
            self.rec_A = self.netG_B(fake_B).data
            self.fake_B = fake_B.data
            real_B = Variable(self.input_B)
            fake_A = self.netG_B(real_B)
            self.rec_B = self.netG_A(fake_A).data
            self.fake_A = fake_A.data

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Standard discriminator loss: real vs. detached fake, halved."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detached so gradients do not reach the generator)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.item()

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        """Generator step: GAN + cycle (+ optional identity) losses."""
        lambda_idt = self.p.identity
        lambda_A = self.p.lambda_A
        lambda_B = self.p.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A,
                                           self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B,
                                           self.real_A) * lambda_A * lambda_idt
            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)
        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)
        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
        # combined loss
        loss_G = (loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B +
                  loss_idt_A + loss_idt_B)
        loss_G.backward()
        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data
        self.loss_G_A = loss_G_A.item()
        self.loss_G_B = loss_G_B.item()
        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_B = loss_cycle_B.item()

    def optimize_parameters(self):
        """One full step: generators first, then both discriminators."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return the most recent scalar losses keyed by short names."""
        ret_errors = OrderedDict([('D_A', self.loss_D_A),
                                  ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B),
                                  ('G_B', self.loss_G_B),
                                  ('Cyc_B', self.loss_cycle_B)])
        if self.p.identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    def get_current_visuals(self):
        """Return the current real/fake/reconstructed images as numpy."""
        real_A = tensor2im(self.input_A)
        fake_B = tensor2im(self.fake_B)
        rec_A = tensor2im(self.rec_A)
        real_B = tensor2im(self.input_B)
        fake_A = tensor2im(self.fake_A)
        rec_B = tensor2im(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                                   ('rec_A', rec_A), ('real_B', real_B),
                                   ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.isTrain and self.p.identity > 0.0:
            ret_visuals['idt_A'] = tensor2im(self.idt_A)
            ret_visuals['idt_B'] = tensor2im(self.idt_B)
        return ret_visuals

    def save(self, label):
        self.save_model(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_model(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_model(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_model(self.netD_B, 'D_B', label, self.gpu_ids)
def train():
    """Train CycleGAN together with the auxiliary segmentation network."""
    if FLAGS.load_model is not None:
        # Accept the model name with or without a "checkpoints/" prefix.
        # (The original used str.lstrip("checkpoints/"), which strips a
        # *character set*, not a prefix, and can eat leading characters
        # of the model name.)
        model_dir = FLAGS.load_model
        if model_dir.startswith("checkpoints/"):
            model_dir = model_dir[len("checkpoints/"):]
        checkpoints_dir = "checkpoints/" + model_dir
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()
    variable_to_restore = []
    with graph.as_default():
        # The segmentation net is built first so that exactly its
        # variables are captured below for separate restoration.
        segmentation = SegmentationNN('combined_model')
        variable_to_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        cycle_gan = CycleGAN(X_train_file=FLAGS.X,
                             Y_train_file=FLAGS.Y,
                             batch_size=FLAGS.batch_size,
                             image_size=FLAGS.image_size,
                             use_lsgan=FLAGS.use_lsgan,
                             norm=FLAGS.norm,
                             lambda1=FLAGS.lambda1,
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             beta1=FLAGS.beta1,
                             ngf=FLAGS.ngf,
                             segmentation=segmentation)
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            # Resume: restore everything from the latest checkpoint; the
            # step counter comes from the filename (model.ckpt-<step>.meta).
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            # Fresh run: load the pre-trained segmentation weights only.
            print('variables', variable_to_restore)
            restore1 = tf.train.Saver(variable_to_restore)
            restore1.restore(sess, 'Segmentation/lib/real.ckpt')
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)
            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # Constant RGB reference (R=95, G=40, B=20) and a ones
                # mask, rebuilt and fed to the model every step.
                R = 95 * np.ones([cycle_gan.batch_size, 256, 256])
                G = 40 * np.ones([cycle_gan.batch_size, 256, 256])
                B = 20 * np.ones([cycle_gan.batch_size, 256, 256])
                ones = np.ones(([cycle_gan.batch_size, 256, 256, 3]))
                RGB = np.stack([R, G, B], axis=3)

                _fake_x = fake_X_pool.query(fake_x_val)
                _fake_y = fake_Y_pool.query(fake_y_val)
                # Train.  BUG FIX: each placeholder now receives the
                # history-pool sample of its OWN domain — the original fed
                # _fake_x into cycle_gan.fake_y and _fake_y into
                # cycle_gan.fake_x (swapped), unlike every other train()
                # variant in this file.
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.fake_y: _fake_y,
                            cycle_gan.fake_x: _fake_x,
                            cycle_gan.RGB: RGB,
                            cycle_gan.ones: ones,
                            cycle_gan.covered1: False,
                            cycle_gan.covered2: True
                        }))

                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
# NOTE(review): this is the interior of a training loop whose enclosing
# `def` is not visible here; coord/sess/step/FLAGS/cycle_gan and the
# pools come from that outer scope.

# Per-test-image accumulators for the most recent predictions, one slot
# per image, initialised to 0.
x_last_test_predict_list = []
for last_i in range(len(testimage_x_list)):
    x_last_test_predict_list.append(0)
y_last_test_predict_list = []
# NOTE(review): the y-list is also sized by len(testimage_x_list) —
# presumably both test sets have the same length; verify.
for last_i in range(len(testimage_x_list)):
    y_last_test_predict_list.append(0)
while not coord.should_stop():
    if step <= 25:
        # Early in training, give the discriminators extra update steps
        # (FLAGS.dis_pretrain of them) before each generator update.
        # NOTE(review): fake_y_val / fake_x_val are used in the feed dict
        # before this statement assigns them — they must have been
        # initialised before this loop (not visible here); confirm.
        for i in range(FLAGS.dis_pretrain):
            _, fake_y_val, fake_x_val = sess.run(
                [D_optimizer, fake_y, fake_x],
                feed_dict={
                    cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                    cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                })
    # Joint update: run all optimizers and fetch every loss, the
    # generated/reconstructed batches, the summary op, and the raw
    # discriminator outputs for both domains.
    _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, fake_y_val, fake_x_val,\
    real_y, real_x, reconstructed_y_val, reconstructed_x_val, summary, \
    D_Y_output_real_1_val, D_Y_output_fake_forG_1_val, D_Y_output_real_2_val, D_Y_output_fake_forG_2_val, \
    D_X_output_real_1_val, D_X_output_fake_forG_1_val, D_X_output_real_2_val, D_X_output_fake_forG_2_val = \
        sess.run([optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x, y, x,
                  reconstructed_y, reconstructed_x, summary_op,
                  D_Y_output_real_1, D_Y_output_fake_forG_1, D_Y_output_real_2, D_Y_output_fake_forG_2,
                  D_X_output_real_1, D_X_output_fake_forG_1, D_X_output_real_2, D_X_output_fake_forG_2],
                 feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)})
class resgan(BaseModel):
    """Residual GAN: a generator learns to denoise A, with an adversarial
    term and a noise-consistency L1 reconstruction term."""

    def init_architecture(self, opt):
        """Build the generator (and, when training, the discriminator,
        its optimizers and the fake-image history pool)."""
        self.opt = opt
        self.netG = define_G(opt.in_nc, opt.out_nc, opt.nz, opt.ngf,
                             which_model_netG=opt.G_model)
        if opt.use_gpu:
            self.netG.cuda()
        if opt.isTrain:
            self.netD = define_D(opt.in_nc, opt.ngf, 'basic_128')
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            if opt.use_gpu:
                self.netD.cuda()
            self.optimizers = [self.optimizer_G, self.optimizer_D]
            self.fake_A_pool = ImagePool(opt.pool_size)

    def forward(self):
        """Expose the current inputs and run the generator on A."""
        self.real_A = self.input_A
        self.real_B = self.input_B
        self.G_fake_B = self.netG(self.real_A)

    def update_D(self, netD, real, fake, optim):
        """One discriminator step on a real batch and a history-pool fake
        batch; returns the (already applied) discriminator loss."""
        pooled_fake = self.fake_A_pool.query(fake.data)
        pred_on_fake = netD(pooled_fake)
        pred_on_real = netD(real)
        loss_fake = self.critGAN(pred_on_fake, False)
        loss_real = self.critGAN(pred_on_real, True)
        d_loss = (loss_fake + loss_real) * 0.5
        optim.zero_grad()
        d_loss.backward()
        optim.step()
        return d_loss

    def update_G(self):
        """One generator step: adversarial loss plus an L1 consistency
        term — the noise removed from A, re-applied to B, should be
        removable again.  Returns the applied total loss."""
        adv_loss = self.critGAN(self.netD(self.G_fake_B), True)
        # Estimated noise = input minus denoised output; add it onto B
        # and require the generator to recover B.
        noise_est = self.real_A - self.G_fake_B
        noisy_est = self.real_B + noise_est
        rec_B = self.netG(noisy_est)
        rec_loss = self.critL1(rec_B, self.real_B)
        total_loss = adv_loss + self.opt.l1_lambda * rec_loss
        self.optimizer_G.zero_grad()
        total_loss.backward()
        self.optimizer_G.step()
        return total_loss

    def optimize_parameters(self):
        """Forward pass, then a generator step followed by a
        discriminator step on the detached fake."""
        self.forward()
        self.update_G()
        self.update_D(self.netD, self.real_B, self.G_fake_B.detach(),
                      self.optimizer_D)

    def save(self):
        """Persist both networks."""
        self.save_network(self.netG, 'G')
        self.save_network(self.netD, 'D')
class Model: def initialize(self, cfg): self.cfg = cfg ## set devices if cfg['GPU_IDS']: assert (torch.cuda.is_available()) self.device = torch.device('cuda:{}'.format(cfg['GPU_IDS'][0])) torch.backends.cudnn.benchmark = True print('Using %d GPUs' % len(cfg['GPU_IDS'])) else: self.device = torch.device('cpu') # define network if cfg['ARCHI'] == 'alexnet': self.netB = networks.netB_alexnet() self.netH = networks.netH_alexnet() if self.cfg['USE_DA'] and self.cfg['TRAIN']: self.netD = networks.netD_alexnet(self.cfg['DA_LAYER']) elif cfg['ARCHI'] == 'vgg16': raise NotImplementedError self.netB = networks.netB_vgg16() self.netH = networks.netH_vgg16() if self.cfg['USE_DA'] and self.cfg['TRAIN']: self.netD = netD_vgg16(self.cfg['DA_LAYER']) elif 'resnet' in cfg['ARCHI']: self.netB = networks.netB_resnet34() self.netH = networks.netH_resnet34() if self.cfg['USE_DA'] and self.cfg['TRAIN']: self.netD = networks.netD_resnet(self.cfg['DA_LAYER']) else: raise ValueError('Un-supported network') ## initialize network param. 
self.netB = networks.init_net(self.netB, cfg['GPU_IDS'], 'xavier') self.netH = networks.init_net(self.netH, cfg['GPU_IDS'], 'xavier') if self.cfg['USE_DA'] and self.cfg['TRAIN']: self.netD = networks.init_net(self.netD, cfg['GPU_IDS'], 'xavier') # loss, optimizer, and scherduler if cfg['TRAIN']: self.total_steps = 0 ## Output path self.save_dir = os.path.join( cfg['OUTPUT_PATH'], cfg['ARCHI'], datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) if not os.path.isdir(self.save_dir): os.makedirs(self.save_dir) # self.logger = Logger(self.save_dir) ## model names self.model_names = ['netB', 'netH'] ## loss self.criterionGAN = networks.GANLoss().to(self.device) self.criterionDepth1 = torch.nn.MSELoss().to(self.device) self.criterionNorm = torch.nn.CosineEmbeddingLoss().to(self.device) self.criterionEdge = torch.nn.BCELoss().to(self.device) ## optimizers self.lr = cfg['LR'] self.optimizers = [] self.optimizer_B = torch.optim.Adam(self.netB.parameters(), lr=cfg['LR'], betas=(cfg['BETA1'], cfg['BETA2'])) self.optimizer_H = torch.optim.Adam(self.netH.parameters(), lr=cfg['LR'], betas=(cfg['BETA1'], cfg['BETA2'])) self.optimizers.append(self.optimizer_B) self.optimizers.append(self.optimizer_H) if cfg['USE_DA']: self.real_pool = ImagePool(cfg['POOL_SIZE']) self.syn_pool = ImagePool(cfg['POOL_SIZE']) self.model_names.append('netD') ## use SGD for discriminator self.optimizer_D = torch.optim.SGD( self.netD.parameters(), lr=cfg['LR'], momentum=cfg['MOMENTUM'], weight_decay=cfg['WEIGHT_DECAY']) self.optimizers.append(self.optimizer_D) ## LR scheduler self.schedulers = [ networks.get_scheduler(optimizer, cfg) for optimizer in self.optimizers ] else: ## testing self.model_names = ['netB', 'netH'] self.criterionDepth1 = torch.nn.MSELoss().to(self.device) self.criterionNorm = torch.nn.CosineEmbeddingLoss( reduction='none').to(self.device) self.load_dir = os.path.join(cfg['CKPT_PATH']) self.criterionNorm_eval = torch.nn.CosineEmbeddingLoss( reduction='none').to(self.device) if cfg['TEST'] 
or cfg['RESUME']: self.load_networks(cfg['EPOCH_LOAD']) def set_input(self, inputs): if self.cfg['GRAY']: _ch = np.random.randint(3) _syn = inputs['syn']['color'][:, _ch, :, :] self.input_syn_color = torch.stack((_syn, _syn, _syn), dim=1).to(self.device) else: self.input_syn_color = inputs['syn']['color'].to(self.device) self.input_syn_dep = inputs['syn']['depth'].to(self.device) self.input_syn_edge = inputs['syn']['edge'].to(self.device) self.input_syn_edge_count = inputs['syn']['edge_pix'].to(self.device) self.input_syn_norm = inputs['syn']['normal'].to(self.device) if self.cfg['USE_DA']: if self.cfg['GRAY']: _ch = np.random.randint(3) _real = inputs['real'][0][:, _ch, :, :] self.input_real_color = torch.stack((_real, _real, _real), dim=1).to(self.device) else: self.input_real_color = inputs['real'][0].to(self.device) def forward(self): self.feat_syn = self.netB(self.input_syn_color) # TODO: make it compatible with other networks in addition to ResNet self.head_pred = self.netH(self.feat_syn['layer1'], self.feat_syn['layer2'], self.feat_syn['layer3'], self.feat_syn['layer4']) if self.cfg['USE_DA'] and self.cfg['TRAIN']: self.feat_real = self.netB(self.input_real_color) self.pred_D_real = self.netD(self.feat_real[self.cfg['DA_LAYER']]) self.pred_D_syn = self.netD(self.feat_syn[self.cfg['DA_LAYER']]) return self.head_pred def backward_BH(self): ## forward to compute prediction # TODO: replace this with self.head_pred to avoid computation twice self.task_pred = self.netH(self.feat_syn['layer1'], self.feat_syn['layer2'], self.feat_syn['layer3'], self.feat_syn['layer4']) # depth depth_diff = self.task_pred['depth'] - self.input_syn_dep _n = self.task_pred['depth'].size(0) * self.task_pred['depth'].size( 2) * self.task_pred['depth'].size(3) loss_depth2 = depth_diff.sum().pow_(2).div_(_n).div_(_n) loss_depth1 = self.criterionDepth1(self.task_pred['depth'], self.input_syn_dep) self.loss_dep = self.cfg['DEP_WEIGHT'] * (loss_depth1 - loss_depth2) * 0.5 # surface normal ch 
        # NOTE(review): this span begins mid-method — the enclosing `def backward_BH`
        # and the left-hand side of the first statement (presumably `ch`) lie above
        # the visible region; the fragment is reproduced as found.
        = self.task_pred['norm'].size(1)
        # Flatten NCHW -> (N*H*W, C) so each row is one per-pixel normal vector.
        _pred = self.task_pred['norm'].permute(0, 2, 3, 1).contiguous().view(-1, ch)
        _gt = self.input_syn_norm.permute(0, 2, 3, 1).contiguous().view(-1, ch)
        # Ground-truth normals are stored as uint8-style [0, 255]; map to [-1, 1].
        _gt = (_gt / 127.5) - 1
        _pred = torch.nn.functional.normalize(_pred, dim=1)
        # Store the unit-normalized prediction back in image layout, rescaled to
        # [0, 255] for visualization elsewhere.
        self.task_pred['norm'] = _pred.view(self.task_pred['norm'].size(0),
                                            self.task_pred['norm'].size(2),
                                            self.task_pred['norm'].size(3),
                                            3).permute(0, 3, 1, 2)
        self.task_pred['norm'] = (self.task_pred['norm'] + 1) * 127.5
        # CosineEmbeddingLoss with target +1 penalizes angular deviation.
        cos_label = torch.ones(_gt.size(0)).to(self.device)
        self.loss_norm = self.cfg['NORM_WEIGHT'] * self.criterionNorm(
            _pred, _gt, cos_label)
        # # edge
        self.loss_edge = self.cfg['EDGE_WEIGHT'] * self.criterionEdge(
            self.task_pred['edge'], self.input_syn_edge)
        ## combined loss
        loss = self.loss_dep + self.loss_norm + self.loss_edge
        if self.cfg['USE_DA']:
            # Adversarial term: push synthetic features to look "real" to netD.
            pred_syn = self.netD(self.feat_syn[self.cfg['DA_LAYER']])
            self.loss_DA = self.criterionGAN(pred_syn, True)
            loss += self.loss_DA * self.cfg['DA_WEIGHT']
        loss.backward()

    def backward_D(self):
        """Compute and backprop the domain discriminator loss (syn=fake, real=real)."""
        ## Synthetic
        # stop backprop to netB by detaching
        _feat_s = self.syn_pool.query(
            self.feat_syn[self.cfg['DA_LAYER']].detach().cpu())
        pred_syn = self.netD(_feat_s.to(self.device))
        self.loss_D_syn = self.criterionGAN(pred_syn, False)
        ## Real
        _feat_r = self.real_pool.query(
            self.feat_real[self.cfg['DA_LAYER']].detach().cpu())
        pred_real = self.netD(_feat_r.to(self.device))
        self.loss_D_real = self.criterionGAN(pred_real, True)
        ## Combined
        self.loss_D = (self.loss_D_syn + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def optimize(self):
        """One training step: (optionally) update netD, then update netB/netH."""
        self.total_steps += 1
        self.train_mode()
        self.forward()
        # if DA, update on real data
        if self.cfg['USE_DA']:
            self.set_requires_grad(self.netD, True)
            self.set_requires_grad([self.netB, self.netH], False)
            self.optimizer_D.zero_grad()
            self.backward_D()
            self.optimizer_D.step()
        # update on synthetic data
        if self.cfg['USE_DA']:
            self.set_requires_grad([self.netB, self.netH], True)
            self.set_requires_grad(self.netD, False)
        self.optimizer_B.zero_grad()
        self.optimizer_H.zero_grad()
        self.backward_BH()
        self.optimizer_B.step()
        self.optimizer_H.step()

    def train_mode(self):
        """Put all networks into training mode (BN/dropout active)."""
        self.netB.train()
        self.netH.train()
        if self.cfg['USE_DA']:
            self.netD.train()

    # make models eval mode during test time
    def eval_mode(self):
        """Put all networks into evaluation mode."""
        self.netB.eval()
        self.netH.eval()
        if self.cfg['USE_DA']:
            self.netD.eval()

    def angle_error_ratio(self, angle_degree, base_angle_degree):
        """Fraction of entries in `angle_degree` strictly below the threshold.

        Returns (ratio, boolean map). For 2-D input the sum is per-row (dim=1).
        NOTE(review): in the 2-D case each row sum is divided by the TOTAL
        element count (`nelement()`), not the per-row count — looks like a bug
        if per-image ratios are intended; confirm against callers.
        """
        logic_map = torch.gt(
            base_angle_degree *
            torch.ones_like(angle_degree, device=self.device), angle_degree)
        if len(logic_map.size()) == 1:
            num_pixels = torch.sum(logic_map).float()
        else:
            num_pixels = torch.sum(logic_map, dim=1).float()
        ratio = torch.div(
            num_pixels,
            torch.tensor(angle_degree.nelement(),
                         device=self.device,
                         dtype=torch.float64))
        return ratio, logic_map

    def normal_angle(self):
        """Per-pixel angular error (degrees) between predicted and GT normals.

        NOTE(review): uses 3.14 rather than math.pi — small systematic bias.
        """
        # surface normal
        ch = self.head_pred['norm'].size(1)
        _pred = self.head_pred['norm'].permute(0, 2, 3,
                                               1).contiguous().view(-1, ch)
        _gt = self.input_syn_norm.permute(0, 2, 3, 1).contiguous().view(-1, ch)
        _gt = (_gt / 127.5) - 1
        _pred = torch.nn.functional.normalize(_pred, dim=1)
        _gt = torch.nn.functional.normalize(_gt, dim=1)
        cos_label = torch.ones(_gt.size(0)).to(self.device)
        # criterionNorm_eval returns 1 - cos(theta); recover cos and clamp to
        # [-1, 1] so acos stays in-domain despite float error.
        norm_diff = self.criterionNorm_eval(_pred, _gt, cos_label)
        cos_val = 1 - norm_diff
        cos_val = torch.max(cos_val,
                            -torch.ones_like(cos_val, device=self.device))
        cos_val = torch.min(cos_val,
                            torch.ones_like(cos_val, device=self.device))
        angle_rad = torch.acos(cos_val)
        angle_degree = angle_rad / 3.14 * 180
        return angle_degree

    def test(self):
        """Evaluate on the current batch; return angular-error metrics as numpy."""
        self.eval_mode()
        with torch.no_grad():
            self.forward()
            angle_degree = self.normal_angle()
            # ratio metrics (standard surface-normal thresholds)
            ratio_11, _ = self.angle_error_ratio(angle_degree, 11.25)
            ratio_22, _ = self.angle_error_ratio(angle_degree, 22.5)
            ratio_30, _ = self.angle_error_ratio(angle_degree, 30.0)
            ratio_45, _ = self.angle_error_ratio(angle_degree, 45.0)
            # image-wise metrics
            batch_size = self.head_pred['norm'].size(0)
            # TODO double check if it's image-wise
            batch_angles = angle_degree.view(batch_size, -1)
            image_mean = torch.mean(batch_angles, dim=1)
            image_score, _ = self.angle_error_ratio(batch_angles, 45.0)
            return {
                'batch_size': batch_size,
                'pixel_error': angle_degree.cpu().detach().numpy(),
                'image_mean': image_mean.cpu().detach().numpy(),
                'image_score': image_score.cpu().detach().numpy(),
                'ratio_11': ratio_11.cpu().detach().numpy(),
                'ratio_22': ratio_22.cpu().detach().numpy(),
                'ratio_30': ratio_30.cpu().detach().numpy(),
                'ratio_45': ratio_45.cpu().detach().numpy()
            }

    def out_logic_map(self, epoch_num, img_num):
        """Save a visualization (GT / prediction / good-pixel mask) to VIS_PATH."""
        self.eval_mode()
        with torch.no_grad():
            self.forward()
            # surface normal
            batch_size = self.head_pred['norm'].size(0)
            ch = self.head_pred['norm'].size(1)
            _pred = self.head_pred['norm'].permute(0, 2, 3,
                                                   1).contiguous().view(-1, ch)
            _gt = self.input_syn_norm.permute(0, 2, 3,
                                              1).contiguous().view(-1, ch)
            _gt = (_gt / 127.5) - 1
            _pred = torch.nn.functional.normalize(_pred, dim=1)
            _gt = torch.nn.functional.normalize(_gt, dim=1)
            cos_label = torch.ones(_gt.size(0)).to(self.device)
            norm_diff = self.criterionNorm_eval(_pred, _gt, cos_label)
            cos_val = 1 - norm_diff
            cos_val = torch.max(cos_val,
                                -torch.ones_like(cos_val, device=self.device))
            cos_val = torch.min(cos_val,
                                torch.ones_like(cos_val, device=self.device))
            angle_rad = torch.acos(cos_val)
            angle_degree = angle_rad / 3.14 * 180
            # ratio metrics: c marks pixels with error < 11.25 degrees
            ratio_11, c = self.angle_error_ratio(angle_degree, 11.25)
            # Replicate the boolean map to 3 channels and restore image layout.
            good_pixel_img = torch.cat(
                (c.view(-1, 1), c.view(-1, 1), c.view(-1, 1)),
                1).view(self.head_pred['norm'].size(0),
                        self.head_pred['norm'].size(2),
                        self.head_pred['norm'].size(3),
                        3).permute(0, 3, 1, 2)
            self.head_pred['norm'] = _pred.view(self.head_pred['norm'].size(0),
                                                self.head_pred['norm'].size(2),
                                                self.head_pred['norm'].size(3),
                                                3).permute(0, 3, 1, 2)
            self.head_pred['norm'] = (self.head_pred['norm'] + 1) * 127.5
            vis_norm = torch.cat((self.input_syn_norm, self.head_pred['norm'],
                                  good_pixel_img.float() * 255),
                                 dim=0)
            map_path = '%s/ep%d' % (self.cfg['VIS_PATH'], epoch_num)
            if not os.path.isdir(map_path):
                os.makedirs(map_path)
            torchvision.utils.save_image(vis_norm.detach(),
                                         '%s/%d_norm.jpg' % (map_path, img_num),
                                         nrow=1,
                                         normalize=True)

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        """Step all LR schedulers and cache the current learning rate."""
        for scheduler in self.schedulers:
            scheduler.step()
        # FIXME(review): `self.cfgimizers` is almost certainly a typo for
        # `self.optimizers` — this line raises AttributeError as written.
        self.lr = self.cfgimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % self.lr)

    # return visualization images. train.py will save the images.
    def visualize_pred(self, ep=0):
        """Every VIS_FREQ steps, dump color/normal/depth grids to save_dir/vis."""
        vis_dir = os.path.join(self.save_dir, 'vis')
        if not os.path.isdir(vis_dir):
            os.makedirs(vis_dir)
        if self.total_steps % self.cfg['VIS_FREQ'] == 0:
            num_pic = min(10, self.task_pred['norm'].size(0))
            torchvision.utils.save_image(self.input_syn_color[0:num_pic].cpu(),
                                         '%s/ep_%d_iter_%d_color.jpg' %
                                         (vis_dir, ep, self.total_steps),
                                         nrow=num_pic,
                                         normalize=True)
            # GT on the first row, prediction on the second.
            vis_norm = torch.cat((self.input_syn_norm[0:num_pic],
                                  self.task_pred['norm'][0:num_pic]),
                                 dim=0)
            torchvision.utils.save_image(vis_norm.detach(),
                                         '%s/ep_%d_iter_%d_norm.jpg' %
                                         (vis_dir, ep, self.total_steps),
                                         nrow=num_pic,
                                         normalize=True)
            vis_depth = torch.cat((self.input_syn_dep[0:num_pic],
                                   self.task_pred['depth'][0:num_pic]),
                                  dim=0)
            torchvision.utils.save_image(vis_depth.detach(),
                                         '%s/ep_%d_iter_%d_depth.jpg' %
                                         (vis_dir, ep, self.total_steps),
                                         nrow=num_pic,
                                         normalize=True)
            # TODO Jason: visualization
            # edge_vis = torch.nn.functional.sigmoid(self.task_pred['edge'])
            # vis_edge = torch.cat((self.input_syn_edge[0:num_pic], edge_vis[0:num_pic]), dim=0)
            # torchvision.utils.save_image(vis_edge.detach(),
            #                              '%s/ep_%d_iter_%d_edge.jpg' % (vis_dir,ep,self.total_steps),
            #                              nrow=num_pic, normalize=False)
            if self.cfg['USE_DA']:
                torchvision.utils.save_image(
                    self.input_real_color[0:num_pic].cpu(),
                    '%s/ep_%d_iter_%d_real.jpg' %
                    (vis_dir, ep, self.total_steps),
                    nrow=num_pic,
                    normalize=True)
            print('==> Saved epoch %d total step %d visualization to %s' %
                  (ep, self.total_steps, vis_dir))

    # print on screen, log into tensorboard
    def print_n_log_losses(self, ep=0):
        """Every PRINT_FREQ steps, print current task (and DA) losses."""
        if self.total_steps % self.cfg['PRINT_FREQ'] == 0:
            print('\nEpoch: %d Total_step: %d LR: %f' %
                  (ep, self.total_steps, self.lr))
            # print('Train on tasks: Loss_dep: %.4f | Loss_edge: %.4f | Loss_norm: %.4f'
            #       % (self.loss_dep, self.loss_edge, self.loss_norm))
            print('Train on tasks: Loss_dep: %.4f | Loss_norm: %.4f' %
                  (self.loss_dep, self.loss_norm))
            info = {
                'loss_dep': self.loss_dep,
                'loss_norm': self.loss_norm  #,
                # 'loss_edge': self.loss_edge
            }
            if self.cfg['USE_DA']:
                print(
                    'Train for DA: Loss_D_syn: %.4f | Loss_D_real: %.4f | Loss_DA: %.4f'
                    % (self.loss_D_syn, self.loss_D_real, self.loss_DA))
                info['loss_D_syn'] = self.loss_D_syn
                info['loss_D_real'] = self.loss_D_real
                info['loss_DA'] = self.loss_DA
            # for tag, value in info.items():
            #     self.logger.scalar_summary(tag, value, self.total_steps)

    # save models to the disk
    def save_networks(self, which_epoch):
        """Save every network in model_names as <name>_ep<epoch>.pth."""
        for name in self.model_names:
            save_filename = '%s_ep%s.pth' % (name, which_epoch)
            save_path = os.path.join(self.save_dir, save_filename)
            net = getattr(self, name)
            if isinstance(net, torch.nn.DataParallel):
                # Save the unwrapped module so checkpoints load without DataParallel.
                torch.save(net.module.cpu().state_dict(), save_path)
            else:
                torch.save(net.cpu().state_dict(), save_path)
            print('==> Saved networks to %s' % save_path)
            # NOTE(review): `torch.cuda.is_available` is missing `()` — the bare
            # function object is always truthy, so this branch always runs.
            if torch.cuda.is_available:
                net.cuda(self.device)

    # load models from the disk
    def load_networks(self, which_epoch):
        """Load every network in model_names from load_dir for the given epoch."""
        self.save_dir = self.load_dir
        print('loading networks...')
        if which_epoch == 'None':
            print('epoch is None')
            exit()
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_ep%s.pth' % (name, which_epoch)
                load_path = os.path.join(self.load_dir, load_filename)
                net = getattr(self, name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path,
                                        map_location=str(self.device))
                net.load_state_dict(state_dict)

    # set requies_grad=Fasle to avoid computation
    def set_requires_grad(self, nets, requires_grad=False):
        """Enable/disable gradients for a network or list of networks."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
class CycleGANModel():
    """Half-cycle GAN mapping images to states with a pretrained forward
    dynamics model (netF_A); only G_B (img->state) and D_B are trained."""

    def __init__(self, opt):
        self.opt = opt
        self.dynamic = opt.dynamic
        self.isTrain = opt.istrain
        self.Tensor = torch.cuda.FloatTensor
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_B = img2state().cuda()
        self.netF_A = Fmodel().cuda()
        self.dataF = CDFdata.get_loader(opt)
        # Loads (or trains) the forward model; frozen via lr=0.0 below.
        self.train_forward(pretrained=True)
        self.gt_buffer = []
        self.pred_buffer = []
        # if self.isTrain:
        self.netD_B = stateDmodel().cuda()
        # if self.isTrain:
        self.fake_A_pool = ImagePool(pool_size=128)
        self.fake_B_pool = ImagePool(pool_size=128)
        # define loss functions
        self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
        if opt.loss == 'l1':
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
        elif opt.loss == 'l2':
            self.criterionCycle = torch.nn.MSELoss()
            self.criterionIdt = torch.nn.MSELoss()
        # initialize optimizers
        # self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()))
        # netF_A gets lr=0.0: it stays frozen while sharing the optimizer.
        self.optimizer_G = torch.optim.Adam([{
            'params': self.netF_A.parameters(),
            'lr': 0.0
        }, {
            'params': self.netG_B.parameters(),
            'lr': 1e-3
        }])
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())
        print('---------- Networks initialized ---------------')
        print('-----------------------------------------------')

    def train_forward(self, pretrained=False):
        """Load the forward dynamics model, or train it for 10 epochs and save it."""
        if pretrained:
            self.netF_A.load_state_dict(torch.load('./pred_large.pth'))
            return None
        optimizer = torch.optim.Adam(self.netF_A.parameters(), lr=1e-3)
        loss_fn = torch.nn.L1Loss()
        for epoch in range(10):
            epoch_loss = 0
            for i, item in enumerate(tqdm(self.dataF)):
                state, action, result = item[1]
                state = state.float().cuda()
                action = action.float().cuda()
                result = result.float().cuda()
                out = self.netF_A(state, action)
                loss = loss_fn(out, result)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print('epoch:{} loss:{:.7f}'.format(epoch,
                                                epoch_loss / len(self.dataF)))
        torch.save(self.netF_A.state_dict(), './pred_large.pth')
        print('forward model has been trained!')

    def set_input(self, input):
        """Unpack a batch: states, image pair (t0, t1) with action, GT states."""
        # AtoB = self.opt.which_direction == 'AtoB'
        # input_A = input['A' if AtoB else 'B']
        # input_B = input['B' if AtoB else 'A']
        # A is state
        self.input_A = input[1][0]
        # B is img
        self.input_Bt0 = input[0][0]
        self.input_Bt1 = input[0][2]
        self.action = input[0][1]
        self.gt0 = input[2][0].float().cuda()
        self.gt1 = input[2][1].float().cuda()

    def forward(self):
        """Move the current batch onto the GPU as float Variables."""
        self.real_A = Variable(self.input_A).float().cuda()
        self.real_Bt0 = Variable(self.input_Bt0).float().cuda()
        self.real_Bt1 = Variable(self.input_Bt1).float().cuda()
        self.action = Variable(self.action).float().cuda()

    def test(self):
        """Run forward + loss computation without optimizer steps.

        backward_G/backward_D_basic skip .backward() when isTrain is False.
        """
        # forward
        self.forward()
        # G_A and G_B
        self.backward_G()
        self.backward_D_B()

    def backward_D_basic(self, netD, real, fake):
        """Standard GAN discriminator loss; backprops only in training mode."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        if self.isTrain:
            loss_D.backward()
        return loss_D

    def backward_D_B(self):
        """Discriminator update for D_B using pooled fake states."""
        fake_A = self.fake_A_pool.query(self.fake_At0)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        """Generator losses: GAN at t0, GAN at predicted t1, dynamics cycle."""
        lambda_G_B0 = 1.0
        lambda_G_B1 = 1.0
        lambda_F = 500.0
        # GAN loss D_B(G_B(B))
        fake_At0 = self.netG_B(self.real_Bt0)
        pred_fake = self.netD_B(fake_At0)
        loss_G_Bt0 = self.criterionGAN(pred_fake, True) * lambda_G_B0
        # GAN loss D_B(G_B(B))
        fake_At1 = self.netF_A(fake_At0, self.action)
        pred_fake = self.netD_B(fake_At1)
        loss_G_Bt1 = self.criterionGAN(pred_fake, True) * lambda_G_B1
        # cycle loss: F(G_B(img_t0), action) should match G_B(img_t1)
        pred_At1 = self.netG_B(self.real_Bt1)
        cycle_label = torch.zeros_like(fake_At1).float().cuda()
        loss_cycle = self.criterionCycle(fake_At1 - pred_At1,
                                         cycle_label) * lambda_F
        # combined loss
        loss_G = loss_G_Bt0 + loss_G_Bt1 + loss_cycle
        if self.isTrain:
            loss_G.backward()
        self.fake_At0 = fake_At0.data
        self.fake_At1 = fake_At1.data
        self.loss_G_Bt0 = loss_G_Bt0.item()
        self.loss_G_Bt1 = loss_G_Bt1.item()
        self.loss_cycle = loss_cycle.item()
        # Diagnostic-only errors against ground-truth states (not optimized).
        self.loss_state_lt0 = self.criterionCycle(self.fake_At0,
                                                  self.gt0).item()
        self.loss_state_lt1 = self.criterionCycle(self.fake_At1,
                                                  self.gt1).item()
        # Accumulate for the scatter plots drawn by show_points().
        self.gt_buffer.append(self.gt0.cpu().data.numpy())
        self.gt_buffer.append(self.gt1.cpu().data.numpy())
        self.pred_buffer.append(self.fake_At0.cpu().data.numpy())
        self.pred_buffer.append(self.fake_At1.cpu().data.numpy())

    def optimize_parameters(self):
        """One optimization step: generator first, then discriminator D_B."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return the latest scalar losses keyed for logging."""
        ret_errors = OrderedDict([('L_t0', self.loss_state_lt0),
                                  ('L_t1', self.loss_state_lt1),
                                  ('D_B', self.loss_D_B),
                                  ('G_B0', self.loss_G_Bt0),
                                  ('G_B1', self.loss_G_Bt1),
                                  ('Cyc', self.loss_cycle)])
        # if self.opt.identity > 0.0:
        #     ret_errors['idt_A'] = self.loss_idt_A
        #     ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, path):
        save_filename = 'model_{}.pth'.format(network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    def save(self, path):
        # NOTE(review): saves under labels 'G_B2'/'D_B2' while load() reads
        # 'G_B' — the labels do not round-trip; confirm this is intentional.
        self.save_network(self.netG_B, 'G_B2', path)
        self.save_network(self.netD_B, 'D_B2', path)

    def load_network(self, network, network_label, path):
        weight_filename = 'model_{}.pth'.format(network_label)
        weight_path = os.path.join(path, weight_filename)
        network.load_state_dict(torch.load(weight_path))

    def load(self, path):
        self.load_network(self.netG_B, 'G_B', path)

    def show_points(self):
        """Scatter-plot buffered GT vs predicted state coordinates."""
        # num_images = min(imgs.shape[0],num_images)
        ncols = 1
        nrows = 3
        _, axes = plt.subplots(ncols, nrows, figsize=(nrows * 3, ncols * 3))
        axes = axes.flatten()
        gt_data = np.vstack(self.gt_buffer)
        pred_data = np.vstack(self.pred_buffer)
        print(abs(gt_data - pred_data).mean(0))
        for ax_i, ax in enumerate(axes):
            if ax_i < nrows:
                ax.scatter(gt_data[:, ax_i],
                           pred_data[:, ax_i],
                           s=3,
                           label='xyz_{}'.format(ax_i))
            else:
                ax.scatter(self.npdata(self.fake_At1[:, ax_i - nrows]),
                           self.npdata(self.gt1[:, ax_i - nrows]),
                           label='t1_{}'.format(ax_i - nrows))

    def npdata(self, item):
        """Tensor -> numpy array (via CPU)."""
        return item.cpu().data.numpy()

    def visual(self, path):
        """Render the buffered scatter plots to `path` and reset the buffers."""
        # plt.xlim(-4,4)
        # plt.ylim(-1.5,1.5)
        self.show_points()
        plt.legend()
        plt.savefig(path)
        plt.cla()
        plt.clf()
        self.gt_buffer = []
        self.pred_buffer = []
def train():
    """Build the CycleGAN graph and run the training loop.

    Resumes from `FLAGS.load_model` if given, otherwise creates a fresh
    timestamped checkpoint directory. Saves a checkpoint every 10000 steps
    and on shutdown. Runs until the input queues are exhausted or the
    process is interrupted.
    """
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            X_train_file=FLAGS.X,
            Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            # BUG FIX: previously passed FLAGS.lambda1 here, silently ignoring
            # the lambda2 flag (cf. the sibling train() that passes lambda2).
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf
        )
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    # Cap GPU memory at 50% and bound CPU parallelism.
    GPU_OPTIONS = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
    DEVICE_CONFIG_GPU = tf.ConfigProto(device_count={"CPU": 6},
                                       gpu_options=GPU_OPTIONS,
                                       intra_op_parallelism_threads=0,
                                       inter_op_parallelism_threads=0)

    with tf.Session(graph=graph, config=DEVICE_CONFIG_GPU) as sess:
        if FLAGS.load_model is not None:
            # Restore the latest checkpoint; recover the step counter from the
            # "model.ckpt-<step>.meta" filename.
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            # History pools stabilize discriminator training (Shrivastava et al.).
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # train
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                        feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                                   cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
                    )
                )
                if step % 100 == 0:
                    train_writer.add_summary(summary, step)
                    train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            # Always persist the latest weights before shutting down.
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
class CycleGANModel():
    """Standard two-sided CycleGAN (G_A: A->B, G_B: B->A, discriminators
    D_A/D_B) with cycle-consistency and identity losses."""

    def __init__(self, opt):
        self.opt = opt
        self.isTrain = opt.istrain
        self.Tensor = torch.cuda.FloatTensor
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = GModel(opt).cuda()
        self.netG_B = GModel(opt).cuda()
        if self.isTrain:
            self.netD_A = DModel(opt).cuda()
            self.netD_B = DModel(opt).cuda()
        if self.isTrain:
            self.fake_A_pool = ImagePool(pool_size=128)
            self.fake_B_pool = ImagePool(pool_size=128)
            # define loss functions
            self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
            if opt.loss == 'l1':
                self.criterionCycle = torch.nn.L1Loss()
                self.criterionIdt = torch.nn.L1Loss()
            elif opt.loss == 'l2':
                self.criterionCycle = torch.nn.MSELoss()
                self.criterionIdt = torch.nn.MSELoss()
            # initialize optimizers
            # Both generators share one optimizer, as in the reference impl.
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(),
                                self.netG_B.parameters()))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters())
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                # NOTE(review): get_scheduler is commented out, so `schedulers`
                # actually holds the optimizers themselves; any code that calls
                # scheduler.step() on these will step the optimizer instead.
                # self.schedulers.append(networks.get_scheduler(optimizer, opt))
                self.schedulers.append(optimizer)

        print('---------- Networks initialized -------------')
        # networks.print_network(self.netG_A)
        # networks.print_network(self.netG_B)
        # if self.isTrain:
        #     networks.print_network(self.netD_A)
        #     networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Store the current (A, B) image batch; moved to GPU in forward()."""
        # AtoB = self.opt.which_direction == 'AtoB'
        # input_A = input['A' if AtoB else 'B']
        # input_B = input['B' if AtoB else 'A']
        self.input_A = input[0]
        self.input_B = input[1]

    def forward(self):
        self.real_A = Variable(self.input_A).float().cuda()
        self.real_B = Variable(self.input_B).float().cuda()

    def test(self):
        """Run both translation directions and cache fakes/reconstructions.

        NOTE(review): `volatile=True` is a pre-0.4 PyTorch idiom; on newer
        versions use `with torch.no_grad():` instead.
        """
        self.forward()
        real_A = Variable(self.input_A, volatile=True).float().cuda()
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_B, volatile=True).float().cuda()
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data

    def backward_D_basic(self, netD, real, fake):
        """Standard GAN discriminator loss; detaches fakes so G is untouched."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.item()

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        """Generator losses: identity + GAN (both directions) + cycle (both)."""
        lambda_idt = 0.5
        lambda_A = self.opt.lambda_AB
        lambda_B = self.opt.lambda_AB
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A,
                                           self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B,
                                           self.real_A) * lambda_A * lambda_idt
            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        lambda_G = 1.0
        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True) * lambda_G
        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True) * lambda_G
        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
        # combined loss
        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data
        self.loss_G_A = loss_G_A.item()
        self.loss_G_B = loss_G_B.item()
        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_B = loss_cycle_B.item()

    def optimize_parameters(self):
        """One full optimization step: generators, then D_A, then D_B."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return latest scalar losses (requires backward_G/backward_D_* run)."""
        ret_errors = OrderedDict([('D_A', self.loss_D_A),
                                  ('G_A', self.loss_G_A),
                                  ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B),
                                  ('G_B', self.loss_G_B),
                                  ('Cyc_B', self.loss_cycle_B)])
        # if self.opt.identity > 0.0:
        ret_errors['idt_A'] = self.loss_idt_A
        ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, path):
        save_filename = 'model_{}.pth'.format(network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    def save(self, path):
        self.save_network(self.netG_A, 'G_A', path)
        self.save_network(self.netD_A, 'D_A', path)
        self.save_network(self.netG_B, 'G_B', path)
        self.save_network(self.netD_B, 'D_B', path)

    def load_network(self, network, network_label, path):
        weight_filename = 'model_{}.pth'.format(network_label)
        weight_path = os.path.join(path, weight_filename)
        network.load_state_dict(torch.load(weight_path))

    def load(self, path):
        # Only the generators are restored; discriminators start fresh.
        self.load_network(self.netG_A, 'G_A', path)
        self.load_network(self.netG_B, 'G_B', path)

    def visual(self, path):
        """Save a grid image: per-row [real_A | fake_B | rec_A | real_B]."""
        imgs = []
        for i in range(self.real_A.shape[0]):
            imgs_i = [self.real_A[i], self.fake_B[i]]
            imgs_i += [self.rec_A[i], self.real_B[i]]
            imgs_i = torch.cat(imgs_i, 2).cpu()
            imgs.append(imgs_i)
        imgs = torch.cat(imgs, 1)
        # Map from [-1, 1] back to [0, 1] for PIL.
        imgs = (imgs + 1) / 2
        imgs = transforms.ToPILImage()(imgs)
        imgs.save(path)
def train():
    """Build the paired disentangling GAN graph and run the training loop.

    Resumes from `FLAGS.load_model` if given, otherwise trains into a
    checkpoint directory named after the global `nameNet`. Saves a checkpoint
    every 10000 steps and on shutdown.
    """
    if FLAGS.load_model is not None:
        # BUG FIX: str.lstrip("checkpoints/") strips a *character set*, not a
        # prefix (it would eat leading 'c','h','e','k','p','o','i','n','t','s','/'
        # from the model name). Strip the literal prefix instead.
        load_model = FLAGS.load_model
        if load_model.startswith("checkpoints/"):
            load_model = load_model[len("checkpoints/"):]
        checkpoints_dir = "checkpoints/" + load_model
    else:
        checkpoints_dir = "checkpoints/{}".format(nameNet)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()
    with graph.as_default():
        paired_gan = PairedGANDisenRevFullRep(
            XY_train_file=FLAGS.XY,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambdaRecon=FLAGS.lambdaRecon,
            lambdaAlign=FLAGS.lambdaAlign,
            lambdaRev=FLAGS.lambdaRev,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            nfS=FLAGS.nfS,
            nfE=FLAGS.nfE
        )
        G_loss, D_Y_loss, Dex_Y_loss, F_loss, D_X_loss, Dex_X_loss, A_loss, DC_loss, fake_y, fake_x, fake_ex_y, fake_ex_x = paired_gan.model()
        optimizers = paired_gan.optimize(G_loss, D_Y_loss, Dex_Y_loss, F_loss,
                                         D_X_loss, Dex_X_loss, A_loss, DC_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            # Restore the latest checkpoint; recover the step counter from the
            # "model.ckpt-<step>.meta" filename.
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            # One history pool per discriminator input.
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)
            fake_ex_Y_pool = ImagePool(FLAGS.pool_size)
            fake_ex_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val, fake_ex_y_val, fake_ex_x_val = sess.run(
                    [fake_y, fake_x, fake_ex_y, fake_ex_x])

                # train
                # (BUG FIX: a bare `train` token — a comment missing its '#' —
                # made this a syntax error, and the ex-pools were fed
                # fake_y_val/fake_x_val instead of the fetched
                # fake_ex_y_val/fake_ex_x_val, which were never used.)
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, A_loss_val, DC_loss_val, summary = (
                    sess.run(
                        [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, A_loss, DC_loss, summary_op],
                        feed_dict={paired_gan.fake_y: fake_Y_pool.query(fake_y_val),
                                   paired_gan.fake_x: fake_X_pool.query(fake_x_val),
                                   paired_gan.fake_ex_y: fake_ex_Y_pool.query(fake_ex_y_val),
                                   paired_gan.fake_ex_x: fake_ex_X_pool.query(fake_ex_x_val)}
                    )
                )

                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))
                    logging.info('  A_loss   : {}'.format(A_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            # Always persist the latest weights before shutting down.
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
class Model():
    """Multi-task (depth / surface-normal / edge) network with optional
    adversarial synthetic-to-real domain adaptation (netD on netB features)."""

    def initialize(self, cfg):
        """Build networks, losses, optimizers and schedulers from a cfg dict."""
        self.cfg = cfg
        ## set devices
        if cfg['GPU_IDS']:
            assert torch.cuda.is_available()
            self.device = torch.device('cuda:{}'.format(cfg['GPU_IDS'][0]))
            torch.backends.cudnn.benchmark = True
            print('Using %d GPUs' % len(cfg['GPU_IDS']))
        else:
            self.device = torch.device('cpu')

        # define network (netD only exists when training with DA)
        if cfg['ARCHI'] == 'alexnet':
            self.netB = networks.netB_alexnet()
            self.netH = networks.netH_alexnet()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = networks.netD_alexnet(self.cfg['DA_LAYER'])
        elif cfg['ARCHI'] == 'vgg16':
            raise NotImplementedError
            self.netB = networks.netB_vgg16()
            self.netH = networks.netH_vgg16()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = netD_vgg16(self.cfg['DA_LAYER'])
        elif 'resnet' in cfg['ARCHI']:
            raise NotImplementedError
            self.netB = networks.netB_resnet()
            self.netH = networks.netH_resnet()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = networks.netD_resnet(self.cfg['DA_LAYER'])
        else:
            raise ValueError('Un-supported network')

        ## initialize network param.
        self.netB = networks.init_net(self.netB, cfg['GPU_IDS'], 'xavier')
        self.netH = networks.init_net(self.netH, cfg['GPU_IDS'], 'xavier')
        if self.cfg['USE_DA'] and self.cfg['TRAIN']:
            self.netD = networks.init_net(self.netD, cfg['GPU_IDS'], 'xavier')
            # Only print netD when it exists; an unconditional print would
            # raise AttributeError in non-DA / test configurations.
            print(self.netB, self.netH, self.netD)

        # loss, optimizer, and scheduler
        if cfg['TRAIN']:
            self.total_steps = 0
            ## Output path
            self.save_dir = os.path.join(
                cfg['OUTPUT_PATH'], cfg['ARCHI'],
                datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
            if not os.path.isdir(self.save_dir):
                os.makedirs(self.save_dir)
            self.logger = Logger(self.save_dir)
            ## model names
            self.model_names = ['netB', 'netH']
            ## loss
            self.criterionGAN = networks.GANLoss().to(self.device)
            self.criterionDepth1 = torch.nn.MSELoss().to(self.device)
            self.criterionNorm = torch.nn.CosineEmbeddingLoss().to(self.device)
            # define during running, rely on data weight
            self.criterionEdge = None
            ## optimizers
            self.lr = cfg['LR']
            self.optimizers = []
            self.optimizer_B = torch.optim.Adam(self.netB.parameters(),
                                                lr=cfg['LR'],
                                                betas=(cfg['BETA1'],
                                                       cfg['BETA2']))
            self.optimizer_H = torch.optim.Adam(self.netH.parameters(),
                                                lr=cfg['LR'],
                                                betas=(cfg['BETA1'],
                                                       cfg['BETA2']))
            self.optimizers.append(self.optimizer_B)
            self.optimizers.append(self.optimizer_H)
            if cfg['USE_DA']:
                self.real_pool = ImagePool(cfg['POOL_SIZE'])
                self.syn_pool = ImagePool(cfg['POOL_SIZE'])
                self.model_names.append('netD')
                ## use SGD for discriminator
                self.optimizer_D = torch.optim.SGD(
                    self.netD.parameters(),
                    lr=cfg['LR'],
                    momentum=cfg['MOMENTUM'],
                    weight_decay=cfg['WEIGHT_DECAY'])
                self.optimizers.append(self.optimizer_D)
            ## LR scheduler
            self.schedulers = [
                networks.get_scheduler(optimizer, cfg)
                for optimizer in self.optimizers
            ]
        if cfg['TEST'] or cfg['RESUME']:
            self.load_networks(cfg['CKPT_PATH'])

    def set_input(self, inputs):
        """Move one batch of synthetic (and optionally real) data to the device.

        With cfg['GRAY'], a random channel is replicated to 3 channels as a
        grayscale augmentation.
        """
        if self.cfg['GRAY']:
            _ch = np.random.randint(3)
            _syn = inputs['syn']['color'][:, _ch, :, :]
            self.input_syn_color = torch.stack((_syn, _syn, _syn),
                                               dim=1).to(self.device)
        else:
            self.input_syn_color = inputs['syn']['color'].to(self.device)
        self.input_syn_dep = inputs['syn']['depth'].to(self.device)
        self.input_syn_edge = inputs['syn']['edge'].to(self.device)
        self.input_syn_edge_count = inputs['syn']['edge_pix'].to(self.device)
        self.input_syn_norm = inputs['syn']['normal'].to(self.device)
        if self.cfg['USE_DA']:
            if self.cfg['GRAY']:
                _ch = np.random.randint(3)
                _real = inputs['real'][0][:, _ch, :, :]
                self.input_real_color = torch.stack((_real, _real, _real),
                                                    dim=1).to(self.device)
            else:
                self.input_real_color = inputs['real'][0].to(self.device)

    def forward(self):
        """Run backbone + heads on synthetic input (and netD on both domains)."""
        self.feat_syn = self.netB(self.input_syn_color)
        self.head_pred = self.netH(self.feat_syn['out'])
        if self.cfg['USE_DA'] and self.cfg['TRAIN']:
            self.feat_real = self.netB(self.input_real_color)
            self.pred_D_real = self.netD(self.feat_real[self.cfg['DA_LAYER']])
            self.pred_D_syn = self.netD(self.feat_syn[self.cfg['DA_LAYER']])

    def backward_BH(self):
        """Task losses (depth + normal + edge, plus optional DA term) and backprop."""
        ## forward to compute prediction
        self.task_pred = self.netH(self.feat_syn['out'])
        # depth
        depth_diff = self.task_pred['depth'] - self.input_syn_dep
        _n = self.task_pred['depth'].size(0) * self.task_pred['depth'].size(
            2) * self.task_pred['depth'].size(3)
        # NOTE(review): this squares the *mean* difference (sum/n)^2, not the
        # mean of squared differences — confirm this scale-invariant-style term
        # is intended.
        loss_depth2 = depth_diff.sum().div_(_n).pow(2).mul_(0.5)
        loss_depth1 = self.criterionDepth1(self.task_pred['depth'],
                                           self.input_syn_dep)
        self.loss_dep = self.cfg['DEP_WEIGHT'] * (loss_depth1 +
                                                  loss_depth2) * 0.5
        # surface normal: flatten to (N*H*W, C), map GT [0,255] -> [-1,1],
        # unit-normalize the prediction, and use cosine-embedding loss.
        ch = self.task_pred['norm'].size(1)
        _pred = self.task_pred['norm'].permute(0, 2, 3,
                                               1).contiguous().view(-1, ch)
        _gt = self.input_syn_norm.permute(0, 2, 3, 1).contiguous().view(-1, ch)
        _gt = (_gt / 127.5) - 1
        _pred = torch.nn.functional.normalize(_pred, dim=1)
        # Cache prediction back in image layout, rescaled for visualization.
        self.task_pred['norm'] = _pred.view(self.task_pred['norm'].size(0),
                                            self.task_pred['norm'].size(2),
                                            self.task_pred['norm'].size(3),
                                            3).permute(0, 3, 1, 2)
        self.task_pred['norm'] = (self.task_pred['norm'] + 1) * 127.5
        cos_label = torch.ones(_gt.size(0)).to(self.device)
        self.loss_norm = self.cfg['NORM_WEIGHT'] * self.criterionNorm(
            _pred, _gt, cos_label)
        # edge: BCE with a positive weight equal to the non-edge/edge pixel ratio
        weight_e = (self.task_pred['edge'].size(2) *
                    self.task_pred['edge'].size(3) -
                    self.input_syn_edge_count) / self.input_syn_edge_count
        self.criterionEdge = torch.nn.BCEWithLogitsLoss(
            weight=weight_e.float().view(-1, 1, 1, 1)).to(self.device)
        self.loss_edge = self.cfg['EDGE_WEIGHT'] * self.criterionEdge(
            self.task_pred['edge'], self.input_syn_edge)
        ## combined loss
        loss = self.loss_edge + self.loss_norm + self.loss_dep
        if self.cfg['USE_DA']:
            # NOTE(review): feat_syn is detached here, so this DA term cannot
            # propagate into netB (unlike the sibling class, which does not
            # detach) — confirm which behavior is intended.
            pred_syn = self.netD(self.feat_syn[self.cfg['DA_LAYER']].detach())
            self.loss_DA = self.criterionGAN(pred_syn, True)
            loss += self.loss_DA * self.cfg['DA_WEIGHT']
        loss.backward()

    def backward_D(self):
        """Compute and backprop the domain discriminator loss (syn=fake, real=real)."""
        ## Synthetic
        # stop backprop to netB by detaching
        _feat_s = self.syn_pool.query(
            self.feat_syn[self.cfg['DA_LAYER']].detach().cpu())
        pred_syn = self.netD(_feat_s.to(self.device))
        self.loss_D_syn = self.criterionGAN(pred_syn, False)
        ## Real
        _feat_r = self.real_pool.query(
            self.feat_real[self.cfg['DA_LAYER']].detach().cpu())
        pred_real = self.netD(_feat_r.to(self.device))
        self.loss_D_real = self.criterionGAN(pred_real, True)
        ## Combined
        self.loss_D = (self.loss_D_syn + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def optimize(self):
        """One training step: (optionally) update netD, then update netB/netH."""
        self.total_steps += 1
        self.forward()
        # if DA, update on real data
        if self.cfg['USE_DA']:
            self.set_requires_grad(self.netD, True)
            self.set_requires_grad([self.netB, self.netH], False)
            self.optimizer_D.zero_grad()
            self.backward_D()
            self.optimizer_D.step()
            # re-enable task nets / freeze netD for the task update
            # (guarded: self.netD does not exist when USE_DA is off)
            self.set_requires_grad([self.netB, self.netH], True)
            self.set_requires_grad(self.netD, False)
        # update on synthetic data
        self.optimizer_B.zero_grad()
        self.optimizer_H.zero_grad()
        self.backward_BH()
        self.optimizer_B.step()
        self.optimizer_H.step()

    # make models eval mode during test time
    def eval(self):
        """Put all networks into evaluation mode."""
        self.netB.eval()
        self.netH.eval()
        # BUG FIX: netD only exists when training with DA; calling
        # self.netD.eval() unconditionally raised AttributeError otherwise.
        if self.cfg['USE_DA'] and self.cfg['TRAIN']:
            self.netD.eval()

    # used in test time, wrapping `forward` in no_grad() so we don't save
    # intermediate steps for backprop
    def test(self):
        with torch.no_grad():
            self.forward()

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        """Step all LR schedulers and cache the current learning rate."""
        for scheduler in self.schedulers:
            scheduler.step()
        # BUG FIX: was `self.cfgimizers[0]` — a typo that raised AttributeError.
        self.lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % self.lr)

    # return visualization images. train.py will save the images.
    def visualize_pred(self, ep=0):
        """Every VIS_FREQ steps, dump color/normal/depth/edge grids to save_dir/vis."""
        vis_dir = os.path.join(self.save_dir, 'vis')
        if not os.path.isdir(vis_dir):
            os.makedirs(vis_dir)
        if self.total_steps % self.cfg['VIS_FREQ'] == 0:
            num_pic = min(8, self.task_pred['norm'].size(0))
            torchvision.utils.save_image(self.input_syn_color[0:num_pic].cpu(),
                                         '%s/ep_%d_iter_%d_color.jpg' %
                                         (vis_dir, ep, self.total_steps),
                                         nrow=num_pic,
                                         normalize=True)
            # GT on the first row, prediction on the second.
            vis_norm = torch.cat((self.input_syn_norm[0:num_pic],
                                  self.task_pred['norm'][0:num_pic]),
                                 dim=0)
            torchvision.utils.save_image(vis_norm.detach(),
                                         '%s/ep_%d_iter_%d_norm.jpg' %
                                         (vis_dir, ep, self.total_steps),
                                         nrow=num_pic,
                                         normalize=True)
            vis_depth = torch.cat((self.input_syn_dep[0:num_pic],
                                   self.task_pred['depth'][0:num_pic]),
                                  dim=0)
            torchvision.utils.save_image(vis_depth.detach(),
                                         '%s/ep_%d_iter_%d_depth.jpg' %
                                         (vis_dir, ep, self.total_steps),
                                         nrow=num_pic,
                                         normalize=True)
            # TODO: visualization
            # torch.sigmoid is the non-deprecated equivalent of
            # torch.nn.functional.sigmoid (identical output).
            edge_vis = torch.sigmoid(self.task_pred['edge'])
            vis_edge = torch.cat(
                (self.input_syn_edge[0:num_pic], edge_vis[0:num_pic]), dim=0)
            torchvision.utils.save_image(vis_edge.detach(),
                                         '%s/ep_%d_iter_%d_edge.jpg' %
                                         (vis_dir, ep, self.total_steps),
                                         nrow=num_pic,
                                         normalize=False)
            if self.cfg['USE_DA']:
                torchvision.utils.save_image(
                    self.input_real_color[0:num_pic].cpu(),
                    '%s/ep_%d_iter_%d_real.jpg' %
                    (vis_dir, ep, self.total_steps),
                    nrow=num_pic,
                    normalize=True)
            print('==> Saved epoch %d total step %d visualization to %s' %
                  (ep, self.total_steps, vis_dir))

    # print on screen, log into tensorboard
    def print_n_log_losses(self, ep=0):
        """Every PRINT_FREQ steps, print and log current task (and DA) losses."""
        if self.total_steps % self.cfg['PRINT_FREQ'] == 0:
            print('\nEpoch: %d Total_step: %d LR: %f' %
                  (ep, self.total_steps, self.lr))
            print(
                'Train on tasks: Loss_dep: %.4f | Loss_edge: %.4f | Loss_norm: %.4f'
                % (self.loss_dep, self.loss_edge, self.loss_norm))
            info = {
                'loss_dep': self.loss_dep,
                'loss_norm': self.loss_norm,
                'loss_edge': self.loss_edge
            }
            if self.cfg['USE_DA']:
                print(
                    'Train for DA: Loss_D_syn: %.4f | Loss_D_real: %.4f | Loss_DA: %.4f'
                    % (self.loss_D_syn, self.loss_D_real, self.loss_DA))
                info['loss_D_syn'] = self.loss_D_syn
                info['loss_D_real'] = self.loss_D_real
                info['loss_DA'] = self.loss_DA
            for tag, value in info.items():
                self.logger.scalar_summary(tag, value, self.total_steps)

    # save models to the disk
    def save_networks(self, which_epoch):
        """Save every network in model_names as <name>_ep<epoch>.pth."""
        for name in self.model_names:
            save_filename = '%s_ep%s.pth' % (name, which_epoch)
            save_path = os.path.join(self.save_dir, save_filename)
            net = getattr(self, name)
            if isinstance(net, torch.nn.DataParallel):
                # Save the unwrapped module so checkpoints load without DataParallel.
                torch.save(net.module.cpu().state_dict(), save_path)
            else:
                torch.save(net.cpu().state_dict(), save_path)
            print('==> Saved to %s' % save_path)
            # BUG FIX: was `if torch.cuda.is_available:` (missing call) — the
            # bare function object is always truthy, so CPU-only runs would
            # also try to move the net to CUDA.
            if torch.cuda.is_available():
                net.cuda(self.device)

    # load models from the disk
    def load_networks(self, which_epoch):
        """Load every network in model_names for the given epoch tag."""
        for name in self.model_names:
            if isinstance(name, str):
                # BUG FIX: filename now matches save_networks()'s
                # '<name>_ep<epoch>.pth' pattern (was '<epoch>_<name>.pth',
                # which could never load what save_networks wrote).
                load_filename = '%s_ep%s.pth' % (name, which_epoch)
                load_path = os.path.join(self.save_dir, load_filename)
                # BUG FIX: model_names entries are already 'netB'/'netH'/'netD';
                # the old `getattr(self, 'net' + name)` looked up 'netnetB'.
                net = getattr(self, name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path,
                                        map_location=str(self.device))
                net.load_state_dict(state_dict)

    # set requires_grad=False to avoid computation
    def set_requires_grad(self, nets, requires_grad=False):
        """Enable/disable gradients for a network or list of networks."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
def train():
    """Train the CycleGAN, resuming from FLAGS.load_model when given.

    Creates (or reuses) a checkpoints directory, builds the graph, restores
    the latest checkpoint if resuming, and runs the training loop until the
    input queues are exhausted or the user interrupts. A final checkpoint is
    always saved on the way out.
    """
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            X_train_file=FLAGS.X,
            Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            # BUG FIX: was lambda2=FLAGS.lambda1, which silently ignored the
            # lambda2 flag (cycle-consistency weight of the second direction).
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf)
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # Global step is the trailing "-<step>" of "model.ckpt-<step>.meta";
            # use the last '-' field so extra dashes in the path cannot break it.
            step = int(meta_graph_path.split("-")[-1].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            # history pools of previously generated images (CycleGAN trick to
            # stabilize the discriminators); created once, before the loop
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # train
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                        }))

                # this variant writes summaries and logs only every 100 steps
                if step % 100 == 0:
                    train_writer.add_summary(summary, step)
                    train_writer.flush()
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            # always persist a final checkpoint before shutting down
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
def train():
    """Train the CycleGAN from scratch (no checkpoint resume in this variant).

    Builds a timestamped checkpoints directory, constructs the graph, and
    runs the training loop until the queues close or the user interrupts;
    a final checkpoint is saved on exit.
    """
    current_time = datetime.now().strftime("%Y%m%d-%H%M")
    checkpoints_dir = "checkpoints/{}".format(current_time)
    os.makedirs(checkpoints_dir, exist_ok=True)

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(X_train_file=FLAGS.X_train_file,
                             Y_train_file=FLAGS.Y_train_file,
                             batch_size=FLAGS.batch_size,
                             image_size=FLAGS.image_size,
                             use_lsgan=FLAGS.use_lsgan,
                             norm=FLAGS.norm,
                             lambda1=FLAGS.lambda1,
                             # BUG FIX: was lambda2=FLAGS.lambda1, which
                             # silently ignored the lambda2 flag.
                             lambda2=FLAGS.lambda2,
                             learning_rate=FLAGS.learning_rate,
                             beta1=FLAGS.beta1)
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            step = 0
            # BUG FIX: the image pools were being rebuilt inside the loop,
            # wiping the generated-image history every iteration and reducing
            # the pool to a pass-through. Create them once, before the loop.
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # update previously generated images
                fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

                # train
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)
                        }))

                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 10000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            # always persist a final checkpoint before shutting down
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
def train():
    """Train the segmentation-guided CycleGAN, resuming from FLAGS.load_model.

    Identical loop to the plain CycleGAN trainer, except that a railway
    segmentation network produces masks for the real and generated images,
    which are fed into the generator/discriminator losses.
    """
    if FLAGS.load_model is not None:
        # BUG FIX: was FLAGS.load_model.lstrip("checkpoints/"). str.lstrip
        # strips a *character set*, not a prefix, so any model name starting
        # with one of "checkpoints/" (e.g. "ckpt-…", "test…") got mangled.
        load_model = FLAGS.load_model
        if load_model.startswith("checkpoints/"):
            load_model = load_model[len("checkpoints/"):]
        checkpoints_dir = "checkpoints/" + load_model
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()

    # add segmentation work-flow yw3025
    segementation = Segementation(None)

    with graph.as_default():
        cycle_gan = CycleGAN(
            X_train_file=FLAGS.X,
            Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf,
        )
        G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x, y, x = (
            cycle_gan.model())
        optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # Global step is the trailing "-<step>" of "model.ckpt-<step>.meta";
            # use the last '-' field so extra dashes in the path cannot break it.
            step = int(meta_graph_path.split("-")[-1].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            fake_Y_pool = ImagePool(FLAGS.pool_size)
            fake_X_pool = ImagePool(FLAGS.pool_size)

            while not coord.should_stop():
                # get previously generated images
                fake_y_val, fake_x_val, y_val, x_val = sess.run(
                    [fake_y, fake_x, y, x])

                # add segmentation work-flow yw3025
                mask_y = segementation.get_result_railway(y_val)
                mask_x = segementation.get_result_railway(x_val)
                mask_fake_x = segementation.get_result_railway(fake_x_val)
                mask_fake_y = segementation.get_result_railway(fake_y_val)

                # train yw3025
                _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
                    sess.run(
                        [
                            optimizers, G_loss, D_Y_loss, F_loss, D_X_loss,
                            summary_op
                        ],
                        feed_dict={
                            cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val),
                            cycle_gan.g_mask_x: mask_x,
                            cycle_gan.g_mask_y: mask_y,
                            cycle_gan.g_mask_fake_y: mask_fake_y,
                            cycle_gan.g_mask_fake_x: mask_fake_x,
                            cycle_gan.mask_x: mask_x,
                            cycle_gan.mask_y: mask_y
                        }))

                train_writer.add_summary(summary, step)
                train_writer.flush()

                if step % 100 == 0:
                    logging.info('-----------Step %d:-------------' % step)
                    logging.info('  G_loss   : {}'.format(G_loss_val))
                    logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
                    logging.info('  F_loss   : {}'.format(F_loss_val))
                    logging.info('  D_X_loss : {}'.format(D_X_loss_val))

                if step % 3000 == 0:
                    save_path = saver.save(sess,
                                           checkpoints_dir + "/model.ckpt",
                                           global_step=step)
                    logging.info("Model saved in file: %s" % save_path)

                step += 1
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            # always persist a final checkpoint before shutting down
            save_path = saver.save(sess,
                                   checkpoints_dir + "/model.ckpt",
                                   global_step=step)
            logging.info("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
class CycleMcdModel(BaseModel):
    """Joint CycleGAN + Maximum Classifier Discrepancy (MCD) model.

    A CycleGAN pair (GenA/GenB with DisA/DisB) translates images between
    domains A and B, while a shared segmentation backbone (Features) with two
    classifiers (Classifier1/2) is trained MCD-style on both original and
    translated images.
    """

    def __init__(self, opt):
        super(CycleMcdModel, self).__init__(opt)
        print('-------------- Networks initializing -------------')
        self.mode = None
        # training losses reported via base_model.get_current_losses
        self.lossNames = [
            'loss{}'.format(i) for i in [
                'GenA', 'DisA', 'CycleA', 'IdtA', 'DisB', 'GenB', 'CycleB',
                'IdtB', 'Supervised', 'UnsupervisedClassifier',
                'UnsupervisedFeature'
            ]
        ]
        self.lossGenA, self.lossDisA, self.lossCycleA, self.lossIdtA = 0, 0, 0, 0
        self.lossGenB, self.lossDisB, self.lossCycleB, self.lossIdtB = 0, 0, 0, 0
        self.lossSupervised, self.lossUnsupervisedClassifier, self.lossUnsupervisedFeature = 0, 0, 0
        # define loss functions
        self.criterionGAN = networks.GANLoss(use_lsgan=opt.lsgan).to(
            opt.device)  # lsgan = True use MSE loss, False use BCE loss
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        self.criterionSeg = CrossEntropyLoss2d(opt)  # 2d for each pixels
        self.criterionDis = Distance(opt)
        # training mIoU meters reported via base_model.get_current_mious
        self.miouNames = [
            'miou{}'.format(i) for i in
            ['SupervisedA', 'UnsupervisedA', 'SupervisedB', 'UnsupervisedB']
        ]
        self.miouSupervisedA = IouEval(opt.nClass)
        self.miouUnsupervisedA = IouEval(opt.nClass)
        self.miouSupervisedB = IouEval(opt.nClass)
        self.miouUnsupervisedB = IouEval(opt.nClass)
        # images to save/display via base_model.get_current_visuals
        imageNamesA = [
            'realA', 'fakeA', 'recA', 'idtA', 'supervisedA', 'predSupervisedA',
            'gndSupervisedA', 'unsupervisedA', 'predUnsupervisedA',
            'gndUnsupervisedA'
        ]
        imageNamesB = [
            'realB', 'fakeB', 'recB', 'idtB', 'supervisedB', 'predSupervisedB',
            'gndSupervisedB', 'unsupervisedB', 'predUnsupervisedB',
            'gndUnsupervisedB'
        ]
        self.imageNames = imageNamesA + imageNamesB
        self.realA, self.fakeA, self.recA, self.idtA = None, None, None, None
        self.supervisedA, self.predSupervisedA, self.gndSupervisedA = None, None, None
        self.unsupervisedA, self.predUnsupervisedA, self.gndUnsupervisedA = None, None, None
        self.realB, self.fakeB, self.recB, self.idtB = None, None, None, None
        self.supervisedB, self.predSupervisedB, self.gndSupervisedB = None, None, None
        self.unsupervisedB, self.predUnsupervisedB, self.gndUnsupervisedB = None, None, None
        # models saved/loaded by base_model.save_networks / load_networks;
        # naming is by the input domain
        # Cycle gan model: 'GenA', 'DisA', 'GenB', 'DisB'
        # Mcd model      : 'Features', 'Classifier1', 'Classifier2'
        self.modelNames = [
            'net{}'.format(i) for i in [
                'GenA', 'DisA', 'GenB', 'DisB', 'Features', 'Classifier1',
                'Classifier2'
            ]
        ]
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_RGB (G), G_D (F), D_RGB (D_Y), D_D (D_X)
        self.netGenA = networks.define_G(opt.inputCh, opt.inputCh, opt.ngf,
                                         opt.which_model_netG, opt.norm,
                                         opt.dropout, opt.init_type,
                                         opt.init_gain, opt.gpuIds)
        self.netDisA = networks.define_D(opt.inputCh, opt.inputCh,
                                         opt.which_model_netD, opt.n_layers_D,
                                         opt.norm, not opt.lsgan,
                                         opt.init_type, opt.init_gain,
                                         opt.gpuIds)
        self.netGenB = networks.define_G(opt.inputCh, opt.inputCh, opt.ngf,
                                         opt.which_model_netG, opt.norm,
                                         opt.dropout, opt.init_type,
                                         opt.init_gain, opt.gpuIds)
        self.netDisB = networks.define_D(opt.inputCh, opt.inputCh,
                                         opt.which_model_netD, opt.n_layers_D,
                                         opt.norm, not opt.lsgan,
                                         opt.init_type, opt.init_gain,
                                         opt.gpuIds)
        self.netFeatures = self.initNet(
            DRNSegBase(model_name=opt.segNet,
                       n_class=opt.nClass,
                       input_ch=opt.inputCh))
        self.netClassifier1 = self.initNet(
            DRNSegPixelClassifier(n_class=opt.nClass))
        self.netClassifier2 = self.initNet(
            DRNSegPixelClassifier(n_class=opt.nClass))
        self.set_requires_grad([
            self.netGenA, self.netGenB, self.netDisA, self.netDisB,
            self.netFeatures, self.netClassifier1, self.netClassifier2
        ], True)
        # define image pool
        self.fakeAPool = ImagePool(opt.pool_size)
        self.fakeBPool = ImagePool(opt.pool_size)
        # initialize optimizers
        self.optimizerG = getOptimizer(itertools.chain(
            self.netGenA.parameters(), self.netGenB.parameters()),
                                       opt=opt.cycleOpt,
                                       lr=opt.lr,
                                       beta1=opt.beta1,
                                       momentum=opt.momentum,
                                       weight_decay=opt.weight_decay)
        self.optimizerD = getOptimizer(itertools.chain(
            self.netDisA.parameters(), self.netDisB.parameters()),
                                       opt=opt.cycleOpt,
                                       lr=opt.lr,
                                       beta1=opt.beta1,
                                       momentum=opt.momentum,
                                       weight_decay=opt.weight_decay)
        self.optimizerF = getOptimizer(itertools.chain(
            self.netFeatures.parameters()),
                                       opt=opt.mcdOpt,
                                       lr=opt.lr,
                                       beta1=opt.beta1,
                                       momentum=opt.momentum,
                                       weight_decay=opt.weight_decay)
        self.optimizerC = getOptimizer(itertools.chain(
            self.netClassifier1.parameters(),
            self.netClassifier2.parameters()),
                                       opt=opt.mcdOpt,
                                       lr=opt.lr,
                                       beta1=opt.beta1,
                                       momentum=opt.momentum,
                                       weight_decay=opt.weight_decay)
        self.optimizers = [
            self.optimizerG, self.optimizerD, self.optimizerF, self.optimizerC
        ]
        self.colorize = Colorize()
        print('--------------------------------------------------')

    def name(self):
        return 'CycleMcdModel'

    def current_images(self):
        """Return an OrderedDict of displayable tensors (first batch item only)."""
        imageNames = [
            'realA', 'fakeA', 'recA', 'idtA', 'realB', 'fakeB', 'recB', 'idtB',
            'supervisedA', 'supervisedB', 'unsupervisedA', 'unsupervisedB'
        ]
        segmentationMapNames = [
            'predSupervisedA', 'gndSupervisedA', 'predUnsupervisedA',
            'gndUnsupervisedA', 'predSupervisedB', 'gndSupervisedB',
            'predUnsupervisedB', 'gndUnsupervisedB'
        ]
        visual_ret = OrderedDict()
        for name in self.imageNames:
            if name in imageNames:
                # plain images: undo the input normalization
                visual_ret[name] = self.invTransform(getattr(self, name)[0])
            elif name in segmentationMapNames:
                # label maps: colorize and rescale to [0, 1] CHW float
                visual_ret[name] = \
                    self.colorize(getattr(self, name)[0]).permute(2, 0, 1).float() / 255
            else:
                raise NotImplementedError
        return visual_ret

    def set_input(self, input):
        """Move one batch of supervised/unsupervised A/B images and labels to the device."""
        self.supervisedA = input['supervisedA']['image'].to(self.opt.device)
        self.gndSupervisedA = input['supervisedA']['label'].to(self.opt.device)
        self.unsupervisedA = input['unsupervisedA']['image'].to(
            self.opt.device)
        self.gndUnsupervisedA = input['unsupervisedA']['label'].to(
            self.opt.device)
        self.supervisedB = input['supervisedB']['image'].to(self.opt.device)
        self.gndSupervisedB = input['supervisedB']['label'].to(self.opt.device)
        self.unsupervisedB = input['unsupervisedB']['image'].to(
            self.opt.device)
        self.gndUnsupervisedB = input['unsupervisedB']['label'].to(
            self.opt.device)

    def forward(self):
        # intentionally a no-op; the backward_* methods run their own forwards
        '''
        self.predSupervisedA = self.forwardSegmentation(self.supervisedA)
        self.predUnsupervisedA = self.forwardSegmentation(self.unsupervisedA)
        self.predSupervisedB = self.forwardSegmentation(self.supervisedB)
        self.predUnsupervisedB = self.forwardSegmentation(self.unsupervisedB)
        '''

    def backward_dis_basic(self, netDis, real, fake):
        """GAN loss + backward for one discriminator; returns the loss as float."""
        # Real
        predReal = netDis(real)
        lossDisReal = self.criterionGAN(predReal, True)
        # Fake (detached: discriminator update must not touch the generator)
        predFake = netDis(fake.detach())
        lossDisFake = self.criterionGAN(predFake, False)
        # Combined loss
        lossDis = (lossDisReal + lossDisFake) * 0.5
        # backward
        lossDis.backward()
        return float(lossDis)

    def backward_dis_A(self):
        fakeA = self.fakeAPool.query(self.fakeA)
        self.lossDisA = self.backward_dis_basic(self.netDisA, self.realA,
                                                fakeA)
        # free GPU memory once the discriminator has consumed the fake
        self.fakeA = self.fakeA.to('cpu')

    def backward_dis_B(self):
        fakeB = self.fakeBPool.query(self.fakeB)
        self.lossDisB = self.backward_dis_basic(self.netDisB, self.realB,
                                                fakeB)
        self.fakeB = self.fakeB.to('cpu')

    def backward_gen(self, retain_graph=False):
        """Generator pass: GAN + cycle + identity losses, then backward."""
        lambdaIdt = self.opt.lambdaIdentity
        lambdaA = self.opt.lambdaA
        lambdaB = self.opt.lambdaB
        # both supervised and unsupervised images go through the generators
        self.realA = torch.cat([self.supervisedA, self.unsupervisedA], 0)
        self.realB = torch.cat([self.supervisedB, self.unsupervisedB], 0)
        self.fakeA = self.netGenB(self.realB)
        self.fakeB = self.netGenA(self.realA)
        self.recA = self.netGenB(self.fakeB)
        self.recB = self.netGenA(self.fakeA)
        # Identity loss
        if lambdaIdt > 0:
            # GenB should be identity if realA is fed.
            self.idtA = self.netGenB(self.realA)
            lossIdtA = self.criterionIdt(self.idtA,
                                         self.realA) * lambdaA * lambdaIdt
            # GenA should be identity if realB is fed.
            self.idtB = self.netGenA(self.realB)
            lossIdtB = self.criterionIdt(self.idtB,
                                         self.realB) * lambdaB * lambdaIdt
        else:
            lossIdtA = 0
            lossIdtB = 0
        # GAN loss for GenA (fool DisB) and GenB (fool DisA)
        lossGenA = self.criterionGAN(self.netDisB(self.fakeB), True)
        lossGenB = self.criterionGAN(self.netDisA(self.fakeA), True)
        # Forward cycle loss
        lossCycleA = self.criterionCycle(self.recA, self.realA) * lambdaA
        # Backward cycle loss
        lossCycleB = self.criterionCycle(self.recB, self.realB) * lambdaB
        # combined loss
        lossG = lossGenA + lossGenB + lossCycleA + lossCycleB + lossIdtA + lossIdtB
        lossG.backward(retain_graph=retain_graph)
        # move images back to cpu (fakeA/fakeB are moved by backward_dis_*,
        # which still needs them on the device). The original code repeated
        # the recA/recB transfers twice; the duplicates were redundant no-ops
        # and are removed here.
        self.realA = self.realA.to('cpu')
        self.realB = self.realB.to('cpu')
        self.recA = self.recA.to('cpu')
        self.recB = self.recB.to('cpu')
        self.lossGenA = float(lossGenA)
        self.lossGenB = float(lossGenB)
        self.lossCycleA = float(lossCycleA)
        self.lossCycleB = float(lossCycleB)
        self.lossIdtA = float(lossIdtA)
        self.lossIdtB = float(lossIdtB)

    def optimize_parameters_cyclegan(self):
        """One CycleGAN step: update generators, then discriminators."""
        # GenA and GenB
        self.set_requires_grad([self.netDisA, self.netDisB], False)
        self.optimizerG.zero_grad()
        self.backward_gen()
        self.optimizerG.step()
        # DisA and DisB
        self.set_requires_grad([self.netDisA, self.netDisB], True)
        self.optimizerD.zero_grad()
        self.backward_dis_A()
        self.backward_dis_B()
        self.optimizerD.step()

    def forward_mcd(self, data):
        """Run the shared backbone and both classifiers on `data`."""
        feature = self.netFeatures(data)
        pred1 = self.netClassifier1(feature)
        pred2 = self.netClassifier2(feature)
        return pred1, pred2

    def backward_supervised(self, retain_graph=False):
        """Segmentation loss on labeled data (original + translated copies)."""
        supervised = self.concate_from_A(self.supervisedA)
        # labels are duplicated: same ground truth for original and translation
        gnd = self.gndSupervisedA.repeat(2, 1, 1)
        feature = self.netFeatures(supervised)
        supervisedPred1 = self.netClassifier1(feature)
        supervisedPred2 = self.netClassifier2(feature)
        lossSupervisedA = self.criterionSeg(supervisedPred1, gnd) \
            + self.criterionSeg(supervisedPred2, gnd)
        lossSupervisedA.backward(retain_graph=retain_graph)
        self.predSupervisedA = (supervisedPred1 +
                                supervisedPred2).argmax(1).to('cpu')
        self.miouSupervisedA.update(self.predSupervisedA, gnd)

        supervised = self.concate_from_B(self.supervisedB)
        gnd = self.gndSupervisedB.repeat(2, 1, 1)
        feature = self.netFeatures(supervised)
        supervisedPred1 = self.netClassifier1(feature)
        supervisedPred2 = self.netClassifier2(feature)
        lossSupervisedB = self.criterionSeg(supervisedPred1, gnd) \
            + self.criterionSeg(supervisedPred2, gnd)
        lossSupervisedB.backward(retain_graph=retain_graph)
        self.predSupervisedB = (supervisedPred1 +
                                supervisedPred2).argmax(1).to('cpu')
        self.miouSupervisedB.update(self.predSupervisedB, gnd)

        self.lossSupervised = float(lossSupervisedA) + float(lossSupervisedB)

    def backward_unsupervised_classifier(self, retain_graph=False):
        """MCD step B: keep supervised accuracy while *maximizing* classifier
        discrepancy on unlabeled data (note the minus sign)."""
        # A domain
        supervised = self.concate_from_A(self.supervisedA)
        supervisedGnd = self.gndSupervisedA.repeat(2, 1, 1)
        unsupervised = self.concate_from_A(self.unsupervisedA)
        unsupervisedGnd = self.gndUnsupervisedA.repeat(2, 1, 1)
        # forward supervised
        supervisedPred1, supervisedPred2 = self.forward_mcd(supervised)
        # forward unsupervised
        unsupervisedPred1, unsupervisedPred2 = self.forward_mcd(unsupervised)
        lossUnsupervisedClassifierA = self.criterionSeg(supervisedPred1, supervisedGnd) \
            + self.criterionSeg(supervisedPred2, supervisedGnd) \
            - self.criterionDis(unsupervisedPred1, unsupervisedPred2)
        lossUnsupervisedClassifierA.backward(retain_graph=retain_graph)
        self.predUnsupervisedA = (unsupervisedPred1 +
                                  unsupervisedPred2).argmax(1).to('cpu')
        self.miouUnsupervisedA.update(self.predUnsupervisedA, unsupervisedGnd)

        # B domain
        supervised = self.concate_from_B(self.supervisedB)
        supervisedGnd = self.gndSupervisedB.repeat(2, 1, 1)
        unsupervised = self.concate_from_B(self.unsupervisedB)
        unsupervisedGnd = self.gndUnsupervisedB.repeat(2, 1, 1)
        # forward supervised
        supervisedPred1, supervisedPred2 = self.forward_mcd(supervised)
        # forward unsupervised
        unsupervisedPred1, unsupervisedPred2 = self.forward_mcd(unsupervised)
        lossUnsupervisedClassifierB = self.criterionSeg(supervisedPred1, supervisedGnd) \
            + self.criterionSeg(supervisedPred2, supervisedGnd) \
            - self.criterionDis(unsupervisedPred1, unsupervisedPred2)
        lossUnsupervisedClassifierB.backward(retain_graph=retain_graph)
        self.predUnsupervisedB = (unsupervisedPred1 +
                                  unsupervisedPred2).argmax(1).to('cpu')
        self.miouUnsupervisedB.update(self.predUnsupervisedB, unsupervisedGnd)

        self.lossUnsupervisedClassifier = float(lossUnsupervisedClassifierA) + \
            float(lossUnsupervisedClassifierB)

    def backward_unsupervised_feature(self, retain_graph=False):
        """MCD step C: train the backbone to *minimize* classifier discrepancy."""
        # A domain
        unsupervised = self.concate_from_A(self.unsupervisedA)
        # forward unsupervised
        unsupervisedPred1, unsupervisedPred2 = self.forward_mcd(unsupervised)
        lossUnsupervisedFeatureA = self.criterionDis(unsupervisedPred1, unsupervisedPred2) \
            * self.opt.nTimesDLoss
        lossUnsupervisedFeatureA.backward(retain_graph=retain_graph)
        # B domain
        unsupervised = self.concate_from_B(self.unsupervisedB)
        # forward unsupervised
        unsupervisedPred1, unsupervisedPred2 = self.forward_mcd(unsupervised)
        lossUnsupervisedFeatureB = self.criterionDis(unsupervisedPred1, unsupervisedPred2) \
            * self.opt.nTimesDLoss
        lossUnsupervisedFeatureB.backward(retain_graph=retain_graph)

        self.lossUnsupervisedFeature = float(lossUnsupervisedFeatureA) + \
            float(lossUnsupervisedFeatureB)

    def concate_from_A(self, A):
        """Stack an A-domain batch with its A->B translation."""
        B = self.netGenA(A)
        return torch.cat([A, B], 0)

    def concate_from_B(self, B):
        """Stack a B-domain batch with its B->A translation."""
        A = self.netGenB(B)
        return torch.cat([A, B], 0)

    def optimize_parameters_mcd(self):
        """One MCD step: (A) supervised F+C, (B) C on discrepancy, (C) F on discrepancy."""
        # update F and C for Source
        self.set_requires_grad([self.netClassifier1, self.netClassifier2],
                               True)
        self.optimizerF.zero_grad()
        self.optimizerC.zero_grad()
        self.backward_supervised(retain_graph=False)
        self.optimizerF.step()
        self.optimizerC.step()
        # update C for Target
        self.set_requires_grad([self.netFeatures], False)
        self.optimizerC.zero_grad()
        self.backward_unsupervised_classifier()
        self.optimizerC.step()
        # update F for Target
        self.set_requires_grad([self.netFeatures], True)
        self.set_requires_grad([self.netClassifier1, self.netClassifier2],
                               False)
        for _ in range(self.opt.k):
            # NOTE(review): optimizerG also steps here — the discrepancy loss
            # flows through the generators via concate_from_A/B, so the
            # generators are updated too; confirm this is intended.
            self.optimizerG.zero_grad()
            self.optimizerF.zero_grad()
            self.backward_unsupervised_feature()
            self.optimizerG.step()
            self.optimizerF.step()

    def optimize_parameters(self):
        self.optimize_parameters_cyclegan()
        self.optimize_parameters_mcd()