def load_model(self, item_embedding):
    self.model = Model(n_head=self.config.N_HEAD,
                       n_hid=item_embedding.shape[1],
                       n_seq=self.config.MAX_N_SEQ,
                       n_layer=self.config.N_LAYER,
                       item2vec=item_embedding).cuda()
    # NOTE: clip_grad_norm_ clips the *current* gradients once; calling it
    # here at construction time has no effect. It must be called after every
    # backward(), inside the training step (see the sketch below).
    torch_utils.clip_grad_norm_(self.model.parameters(), 5)
    self.optimizer = torch.optim.Adam(self.model.parameters(),
                                      lr=self.config.LR, eps=self.config.EPS)
    self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        self.optimizer,
        lr_lambda=LambdaLR(self.config.MAX_DECAY_STEP,
                           self.config.DECAY_STEP).step)  # MAX_DECAY_STEP > DECAY_STEP
    self.warmup_scheduler = warmup.LinearWarmup(
        self.optimizer, warmup_period=self.config.WARMUP_PERIOD)
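# The method above builds a per-step decay schedule plus a LinearWarmup from
# the pytorch-warmup package, but never shows how the two are advanced. A
# minimal per-iteration sketch, assuming the Tony-Y/pytorch_warmup API (recent
# releases wrap the scheduler step in a dampening() context manager; older
# ones called warmup_scheduler.dampen() after each step instead). The
# dataloader and loss computation are placeholders, not from the source:
def _train_step_sketch(self, dataloader):
    for batch in dataloader:
        self.optimizer.zero_grad()
        loss = self.model(batch).mean()  # placeholder loss
        loss.backward()
        # Gradient clipping belongs here, after backward() on every step.
        torch_utils.clip_grad_norm_(self.model.parameters(), 5)
        self.optimizer.step()
        with self.warmup_scheduler.dampening():  # pytorch-warmup >= 0.1
            self.lr_scheduler.step()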
def _train_initialize_variables(model_str, model_params, opt_params, cuda):
    """Helper function that initializes everything at the beginning of the train function."""
    # Params passed in as dict to model.
    model = eval(model_str)(model_params)
    model.train()  # important!
    optimizer = init_optimizer(opt_params, model)
    criterion = get_criterion(model_str)
    if opt_params['lr_scheduler'] is not None:
        if opt_params['lr_scheduler'] == 'plateau':
            scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=.5,
                                          patience=1, threshold=1e-3)
        elif opt_params['lr_scheduler'] == 'delayedexpo':
            scheduler = LambdaLR(optimizer, lr_lambda=[
                lambda epoch: float(epoch <= 4) + float(epoch > 4) * 1.2 ** (-epoch)])
        else:
            # The old message claimed only 'plateau' existed, but 'delayedexpo'
            # is implemented above as well.
            raise NotImplementedError(
                "only 'plateau' and 'delayedexpo' schedulers have been implemented so far")
    else:
        scheduler = None
    if cuda:
        model = model.cuda()
        if 'VAE' in model_str:
            model.is_cuda = True
    return model, criterion, optimizer, scheduler
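# The two scheduler kinds returned above must be stepped differently:
# ReduceLROnPlateau.step() expects the monitored metric, while LambdaLR.step()
# takes no argument. A minimal sketch of the calling loop; train_one_epoch and
# evaluate are hypothetical helpers, not from the source:
model, criterion, optimizer, scheduler = _train_initialize_variables(
    model_str, model_params, opt_params, cuda)
for epoch in range(n_epochs):
    train_one_epoch(model, criterion, optimizer)  # hypothetical helper
    val_loss = evaluate(model, criterion)         # hypothetical helper
    if isinstance(scheduler, ReduceLROnPlateau):
        scheduler.step(val_loss)  # plateau schedulers consume the metric
    elif scheduler is not None:
        scheduler.step()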
# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                               lr=args.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=args.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=args.lr, betas=(0.5, 0.999))

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(args.n_epochs, args.epoch, args.decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(args.n_epochs, args.epoch, args.decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(args.n_epochs, args.epoch, args.decay_epoch).step)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
target_real = torch.ones(args.batch_size, dtype=torch.float).unsqueeze(1).to(device)
# BUG FIX: the fake targets were torch.ones like the real ones, which makes
# the discriminator loss meaningless; they must be zeros.
target_fake = torch.zeros(args.batch_size, dtype=torch.float).unsqueeze(1).to(device)

wandb_step = 0
log_image_step = 50
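# Every snippet in this section passes LambdaLR(n_epochs, offset,
# decay_start_epoch).step into torch.optim.lr_scheduler.LambdaLR without
# defining it (the first snippet uses a two-argument per-step variant of the
# same idea). A sketch consistent with the widely copied PyTorch-CycleGAN
# utility: constant LR until decay_start_epoch, then linear decay to zero at
# n_epochs.
class LambdaLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "decay must start before training ends"
        self.n_epochs = n_epochs
        self.offset = offset  # epoch to resume from (0 for a fresh run)
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        # Multiplicative factor applied to the optimizer's base LR.
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (
            self.n_epochs - self.decay_start_epoch)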
disc_a_optimizer = torch.optim.Adam(disc_a.parameters(), lr=args.lr, betas=(0.5, 0.999))
disc_b_optimizer = torch.optim.Adam(disc_b.parameters(), lr=args.lr, betas=(0.5, 0.999))
gen_a_optimizer = torch.optim.Adam(gen_a.parameters(), lr=args.lr, betas=(0.5, 0.999))
gen_b_optimizer = torch.optim.Adam(gen_b.parameters(), lr=args.lr, betas=(0.5, 0.999))

disc_a_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    disc_a_optimizer, lr_lambda=LambdaLR(args.epochs, 0, args.constant_lr_epochs).step)
disc_b_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    disc_b_optimizer, lr_lambda=LambdaLR(args.epochs, 0, args.constant_lr_epochs).step)
gen_a_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    gen_a_optimizer, lr_lambda=LambdaLR(args.epochs, 0, args.constant_lr_epochs).step)
gen_b_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    gen_b_optimizer, lr_lambda=LambdaLR(args.epochs, 0, args.constant_lr_epochs).step)

a_fake_pool = ItemPool()
b_fake_pool = ItemPool()

ckpt_dir = '{}/checkpoints/{}'.format(args.root_dir, args.dataset)
mkdir(ckpt_dir)
def main(args):
    torch.manual_seed(0)
    if args.mb_D:
        # (the assert was dead code after the raise; checked first so it can fire)
        assert args.batch_size > 1, 'batch size needs to be larger than 1 if mb_D'
        raise NotImplementedError('mb_D not implemented')
    if args.img_norm != 'znorm':
        raise NotImplementedError('{} not implemented'.format(args.img_norm))
    assert args.act in ['relu', 'mish'], 'args.act = {}'.format(args.act)

    modelarch = 'C_{0}_{1}_{2}_{3}_{4}{5}{6}{7}{8}{9}{10}{11}{12}{13}{14}{15}{16}{17}{18}{19}{20}{21}{22}'.format(
        args.size, args.batch_size, args.lr, args.n_epochs, args.decay_epoch,     # 0-4
        '_G' if args.G_extra else '',                                             # 5
        '_D' if args.D_extra else '',                                             # 6
        '_U' if args.upsample else '',                                            # 7
        '_S' if args.slow_D else '',                                              # 8
        # BUG FIX: both ranges repeated the start value; they should run start -> end.
        '_RL{}-{}'.format(args.start_recon_loss_val, args.end_recon_loss_val),    # 9
        '_GL{}-{}'.format(args.start_gan_loss_val, args.end_gan_loss_val),        # 10
        '_prop' if args.keep_prop else '',                                        # 11
        '_' + args.img_norm,                                                      # 12
        '_WL' if args.wasserstein else '',                                        # 13
        '_MBD' if args.mb_D else '',                                              # 14
        '_FM' if args.fm_loss else '',                                            # 15
        '_BF{}'.format(args.buffer_size) if args.buffer_size != 50 else '',       # 16
        '_N' if args.add_noise else '',                                           # 17
        '_L{}'.format(args.load_iter) if args.load_iter > 0 else '',              # 18
        '_res{}'.format(args.n_resnet_blocks),                                    # 19
        '_n{}'.format(args.data_subset) if args.data_subset is not None else '',  # 20
        '_{}'.format(args.optim),                                                 # 21
        '_{}'.format(args.act))                                                   # 22

    samples_path = os.path.join(args.output_dir, modelarch, 'samples')
    safe_mkdirs(samples_path)
    model_path = os.path.join(args.output_dir, modelarch, 'models')
    safe_mkdirs(model_path)
    test_path = os.path.join(args.output_dir, modelarch, 'test')
    safe_mkdirs(test_path)

    ###### Definition of variables ######
    # Networks
    netG_A2B = Generator(args.input_nc, args.output_nc, img_size=args.size,
                         extra_layer=args.G_extra, upsample=args.upsample,
                         keep_weights_proportional=args.keep_prop,
                         n_residual_blocks=args.n_resnet_blocks, act=args.act)
    netG_B2A = Generator(args.output_nc, args.input_nc, img_size=args.size,
                         extra_layer=args.G_extra, upsample=args.upsample,
                         keep_weights_proportional=args.keep_prop,
                         n_residual_blocks=args.n_resnet_blocks, act=args.act)
    netD_A = Discriminator(args.input_nc, extra_layer=args.D_extra, mb_D=args.mb_D, x_size=args.size)
    netD_B = Discriminator(args.output_nc, extra_layer=args.D_extra, mb_D=args.mb_D, x_size=args.size)

    if args.cuda:
        netG_A2B.cuda()
        netG_B2A.cuda()
        netD_A.cuda()
        netD_B.cuda()

    if args.load_iter == 0:
        netG_A2B.apply(weights_init_normal)
        netG_B2A.apply(weights_init_normal)
        netD_A.apply(weights_init_normal)
        netD_B.apply(weights_init_normal)
    else:
        netG_A2B.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'G_A2B_{}.pth'.format(args.load_iter))))
        netG_B2A.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'G_B2A_{}.pth'.format(args.load_iter))))
        netD_A.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'D_A_{}.pth'.format(args.load_iter))))
        netD_B.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'D_B_{}.pth'.format(args.load_iter))))

    netG_A2B.train()
    netG_B2A.train()
    netD_A.train()
    netD_B.train()

    # Losses
    criterion_GAN = wasserstein_loss if args.wasserstein else torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    feat_criterion = torch.nn.HingeEmbeddingLoss()

    # I could also update D only if iters % 2 == 0
    lr_G = args.lr
    lr_D = args.lr / 2 if args.slow_D else args.lr

    # Optimizers & LR schedulers
    if args.optim == 'adam':
        optim = torch.optim.Adam
    elif args.optim == 'radam':
        optim = RAdam
    elif args.optim == 'ranger':
        optim = Ranger
    elif args.optim == 'rangerlars':
        optim = RangerLars
    else:
        raise NotImplementedError('args.optim = {} not implemented'.format(args.optim))

    optimizer_G = optim(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                        lr=lr_G, betas=(0.5, 0.999))
    # BUG FIX: D_A was given lr_G, silently disabling --slow_D for one
    # discriminator; both discriminators should run at lr_D.
    optimizer_D_A = optim(netD_A.parameters(), lr=lr_D, betas=(0.5, 0.999))
    optimizer_D_B = optim(netD_B.parameters(), lr=lr_D, betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G, lr_lambda=LambdaLR(args.n_epochs, args.load_iter, args.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A, lr_lambda=LambdaLR(args.n_epochs, args.load_iter, args.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B, lr_lambda=LambdaLR(args.n_epochs, args.load_iter, args.decay_epoch).step)

    # Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if args.cuda else torch.Tensor
    input_A = Tensor(args.batch_size, args.input_nc, args.size, args.size)
    input_B = Tensor(args.batch_size, args.output_nc, args.size, args.size)
    target_real = Variable(Tensor(args.batch_size).fill_(1.0), requires_grad=False)
    target_fake = Variable(Tensor(args.batch_size).fill_(0.0), requires_grad=False)

    fake_A_buffer = ReplayBuffer(args.buffer_size)
    fake_B_buffer = ReplayBuffer(args.buffer_size)

    # Transforms and dataloader for training set
    transforms_ = []
    if args.resize_crop:
        transforms_ += [transforms.Resize(int(args.size * 1.12), Image.BICUBIC),
                        transforms.RandomCrop(args.size)]
    else:
        transforms_ += [transforms.Resize(args.size, Image.BICUBIC)]
    if args.horizontal_flip:
        transforms_ += [transforms.RandomHorizontalFlip()]
    transforms_ += [transforms.ToTensor()]
    if args.add_noise:
        transforms_ += [transforms.Lambda(lambda x: x + torch.randn_like(x))]

    transforms_norm = []
    if args.img_norm == 'znorm':
        transforms_norm += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    elif 'scale01' in args.img_norm:
        # .mul is element-wise, so dimensions are preserved. Note ToTensor
        # already maps to [0, 1]; this rescale assumes inputs in [0, 255].
        transforms_norm += [transforms.Lambda(lambda x: x.mul(1 / 255))]
        if 'flip' in args.img_norm:
            transforms_norm += [transforms.Lambda(lambda x: (x - 1).abs())]
    else:
        raise ValueError('wrong --img_norm. only znorm|scale01|scale01flip')
    transforms_ += transforms_norm

    dataloader = DataLoader(
        ImageDataset(args.dataroot, transforms_=transforms_, unaligned=True, n=args.data_subset),
        batch_size=args.batch_size, shuffle=True, num_workers=args.n_cpu)

    # Transforms and dataloader for test set
    transforms_test_ = [transforms.Resize(args.size, Image.BICUBIC), transforms.ToTensor()]
    transforms_test_ += transforms_norm
    dataloader_test = DataLoader(
        ImageDataset(args.dataroot, transforms_=transforms_test_, mode='test'),
        batch_size=args.batch_size, shuffle=False, num_workers=args.n_cpu)

    ###### Training ######
    if args.load_iter == 0 and args.load_epoch != 0:
        print('****** NOTE: args.load_iter == 0 and args.load_epoch != 0 ******')
    iter = args.load_iter  # NOTE: shadows the builtin iter()
    prev_time = time.time()
    n_test = 10e10 if args.n_test is None else args.n_test
    n_sample = 10e10 if args.n_sample is None else args.n_sample

    rl_delta_x = args.n_epochs - args.recon_loss_epoch
    rl_delta_y = args.end_recon_loss_val - args.start_recon_loss_val
    gan_delta_x = args.n_epochs - args.gan_loss_epoch
    gan_delta_y = args.end_gan_loss_val - args.start_gan_loss_val

    for epoch in range(args.load_epoch, args.n_epochs):
        rl_effective_epoch = max(epoch - args.recon_loss_epoch, 0)
        recon_loss_rate = args.start_recon_loss_val + rl_effective_epoch * (rl_delta_y / rl_delta_x)
        gan_effective_epoch = max(epoch - args.gan_loss_epoch, 0)
        gan_loss_rate = args.start_gan_loss_val + gan_effective_epoch * (gan_delta_y / gan_delta_x)
        id_loss_rate = 5.0

        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B)
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A)

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake, _ = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)
            fake_A = netG_B2A(real_B)
            pred_fake, _ = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A)
            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B)

            # Total loss
            loss_G = (loss_identity_A + loss_identity_B) * id_loss_rate
            loss_G += (loss_GAN_A2B + loss_GAN_B2A) * gan_loss_rate
            loss_G += (loss_cycle_ABA + loss_cycle_BAB) * recon_loss_rate
            loss_G.backward()
            optimizer_G.step()

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()
            # Real loss
            pred_real, _ = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)
            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake, _ = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            if args.fm_loss:
                pred_real, feats_real = netD_A(real_A)
                pred_fake, feats_fake = netD_A(fake_A.detach())
                fm_loss_A = get_fm_loss(feats_real, feats_fake, feat_criterion, args.cuda)
                loss_D_A = loss_D_A * 0.1 + fm_loss_A * 0.9
            loss_D_A.backward()
            optimizer_D_A.step()

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()
            # Real loss
            pred_real, _ = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)
            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake, _ = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            if args.fm_loss:
                pred_real, feats_real = netD_B(real_B)
                pred_fake, feats_fake = netD_B(fake_B.detach())
                fm_loss_B = get_fm_loss(feats_real, feats_fake, feat_criterion, args.cuda)
                loss_D_B = loss_D_B * 0.1 + fm_loss_B * 0.9
            loss_D_B.backward()
            optimizer_D_B.step()

            if iter % args.log_interval == 0:
                print('---------------------')
                print('GAN loss:', as_np(loss_GAN_A2B), as_np(loss_GAN_B2A))
                print('Identity loss:', as_np(loss_identity_A), as_np(loss_identity_B))
                print('Cycle loss:', as_np(loss_cycle_ABA), as_np(loss_cycle_BAB))
                print('D loss:', as_np(loss_D_A), as_np(loss_D_B))
                if args.fm_loss:
                    print('fm loss:', as_np(fm_loss_A), as_np(fm_loss_B))
                print('recon loss rate:', recon_loss_rate)
                print('time:', time.time() - prev_time)
                prev_time = time.time()

            if iter % args.plot_interval == 0:
                pass

            if iter % args.image_save_interval == 0:
                # Integer division, so the directory is named '3', not '3.0'.
                samples_path_ = os.path.join(samples_path, str(iter // args.image_save_interval))
                safe_mkdirs(samples_path_)
                # New savedir
                test_pth_AB = os.path.join(test_path, str(iter // args.image_save_interval), 'AB')
                test_pth_BA = os.path.join(test_path, str(iter // args.image_save_interval), 'BA')
                safe_mkdirs(test_pth_AB)
                safe_mkdirs(test_pth_BA)
                for j, batch_ in enumerate(dataloader_test):
                    real_A_test = Variable(input_A.copy_(batch_['A']))
                    real_B_test = Variable(input_B.copy_(batch_['B']))
                    fake_AB_test = netG_A2B(real_A_test)
                    fake_BA_test = netG_B2A(real_B_test)
                    if j < n_sample:
                        recovered_ABA_test = netG_B2A(fake_AB_test)
                        recovered_BAB_test = netG_A2B(fake_BA_test)
                        fn = os.path.join(samples_path_, str(j))
                        imageio.imwrite(fn + '.A.jpg', tensor2image(real_A_test[0], args.img_norm))
                        imageio.imwrite(fn + '.B.jpg', tensor2image(real_B_test[0], args.img_norm))
                        imageio.imwrite(fn + '.BA.jpg', tensor2image(fake_BA_test[0], args.img_norm))
                        imageio.imwrite(fn + '.AB.jpg', tensor2image(fake_AB_test[0], args.img_norm))
                        imageio.imwrite(fn + '.ABA.jpg', tensor2image(recovered_ABA_test[0], args.img_norm))
                        imageio.imwrite(fn + '.BAB.jpg', tensor2image(recovered_BAB_test[0], args.img_norm))
                    if j < n_test:
                        fn_A = os.path.basename(batch_['img_A'][0])
                        imageio.imwrite(os.path.join(test_pth_AB, fn_A), tensor2image(fake_AB_test[0], args.img_norm))
                        fn_B = os.path.basename(batch_['img_B'][0])
                        imageio.imwrite(os.path.join(test_pth_BA, fn_B), tensor2image(fake_BA_test[0], args.img_norm))

            if iter % args.model_save_interval == 0:
                # Save models checkpoints
                torch.save(netG_A2B.state_dict(), os.path.join(model_path, 'G_A2B_{}.pth'.format(iter)))
                torch.save(netG_B2A.state_dict(), os.path.join(model_path, 'G_B2A_{}.pth'.format(iter)))
                torch.save(netD_A.state_dict(), os.path.join(model_path, 'D_A_{}.pth'.format(iter)))
                torch.save(netD_B.state_dict(), os.path.join(model_path, 'D_B_{}.pth'.format(iter)))

            iter += 1

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()
netD_A.load_state_dict(pretrained_dict)
pretrained_dict = torch.load('/net/cremi/smjoshi/espaces/travail/barcelona/PyTorch-CycleGAN/checkpoints/netD_B.pth')
netD_B.load_state_dict(pretrained_dict)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                               lr=opt['lr'], betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt['lr'], betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt['lr'], betas=(0.5, 0.999))

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt['n_epochs'], opt['epoch'], opt['decay_epoch']).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt['n_epochs'], opt['epoch'], opt['decay_epoch']).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt['n_epochs'], opt['epoch'], opt['decay_epoch']).step)

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if opt['cuda'] else torch.Tensor
input_A = Tensor(opt['batch_size'], opt['input_nc'], opt['size'], opt['size'])
input_B = Tensor(opt['batch_size'], opt['output_nc'], opt['size'], opt['size'])
target_real = Variable(Tensor(opt['batch_size']).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(opt['batch_size']).fill_(0.0), requires_grad=False)

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Dataset loader
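# ReplayBuffer -- the image history pool of Shrivastava et al., used to
# stabilize the discriminators -- is constructed throughout this section but
# never defined. A sketch consistent with the PyTorch-CycleGAN implementation
# that the checkpoint path above points at: each generated image is stored,
# and with probability 0.5 the discriminator is fed an older sample instead.
import random
import torch

class ReplayBuffer:
    def __init__(self, max_size=50):
        assert max_size > 0, 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                # Buffer not full yet: store and return the new image.
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                # Swap the new image in and return an older one.
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        return torch.cat(to_return)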
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                               lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor
input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
target_real = Variable(Tensor(opt.batchSize, 1).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(opt.batchSize, 1).fill_(0.0), requires_grad=False)

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Dataset loader
transforms_ = [
    transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                               lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_A = Tensor(opt.batch_size, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batch_size, opt.output_nc, opt.size, opt.size)
target_real = Variable(Tensor(opt.batch_size, 1).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(opt.batch_size, 1).fill_(0.0),
# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(encoder.parameters(),
                                               decoder_A2B.parameters(),
                                               decoder_B2A.parameters()),
                               lr=lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=lr, betas=(0.5, 0.999))

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(n_epochs, start_epoch, decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(n_epochs, start_epoch, decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(n_epochs, start_epoch, decay_epoch).step)

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if activate_cuda else torch.Tensor
input_A = Tensor(batch_size, input_nc, image_size, image_size)
input_B = Tensor(batch_size, output_nc, image_size, image_size)
target_real = Variable(Tensor(batch_size).fill_(1.0), requires_grad=False)
netG_en2zh = Generator(3, 3).to(device)
netG_zh2en = Generator(3, 3).to(device)
netD_en = Discriminator(3).to(device)
netD_zh = Discriminator(3).to(device)

netG_en2zh.apply(weights_init_normal)
netG_zh2en.apply(weights_init_normal)
netD_en.apply(weights_init_normal)
netD_zh.apply(weights_init_normal)

# Optimizers and learning rate schedulers
optimizer_G = Adam(itertools.chain(netG_en2zh.parameters(), netG_zh2en.parameters()),
                   lr=opt.lr, betas=BETAS)
optimizer_D_en = Adam(netD_en.parameters(), lr=opt.lr, betas=BETAS)
optimizer_D_zh = Adam(netD_zh.parameters(), lr=opt.lr, betas=BETAS)

lr_scheduler_G = lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, 0, DECAY_EPOCH).step)
lr_scheduler_D_en = lr_scheduler.LambdaLR(
    optimizer_D_en, lr_lambda=LambdaLR(opt.n_epochs, 0, DECAY_EPOCH).step)
lr_scheduler_D_zh = lr_scheduler.LambdaLR(
    optimizer_D_zh, lr_lambda=LambdaLR(opt.n_epochs, 0, DECAY_EPOCH).step)


def train():
    for epoch in range(opt.n_epochs):
        print('=== Starting epoch:', epoch, '===')
        # NOTE: since PyTorch 1.1 the scheduler is meant to be stepped *after*
        # the epoch's optimizer updates; stepping here, at the top of the
        # epoch, skips the first LR value of each schedule.
        lr_scheduler_G.step()
        lr_scheduler_D_en.step()
        lr_scheduler_D_zh.step()
        for index, data in enumerate(dataloader):
            real_data_en = data['en'].to(device)
            real_data_zh = data['zh'].to(device)
            ###################
# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_l1 = torch.nn.L1Loss()
criterion_feat = torch.nn.MSELoss()
criterion_VGG = VGGLoss()

# Optimizers & LR schedulers
optimizer_encoder = torch.optim.Adam(encoder.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_decoder = torch.optim.Adam(decoder.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(0.5, 0.999))
# optimizer_t = torch.optim.Adam(transformer.parameters(), lr=opt.lr, betas=(0.5, 0.999))

lr_scheduler_encoder = torch.optim.lr_scheduler.LambdaLR(
    optimizer_encoder, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_decoder = torch.optim.lr_scheduler.LambdaLR(
    optimizer_decoder, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
# lr_scheduler_t = torch.optim.lr_scheduler.LambdaLR(
#     optimizer_t, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
# input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
# input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
# target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
# target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)

fake_B_buffer = ReplayBuffer()

# Dataset loader
transforms_ = [
    transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
def main():
    cuda = torch.cuda.is_available()
    input_shape = (opt.channels, opt.img_height, opt.img_width)

    # Initialize generator and discriminator
    G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
    G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
    D_A = Discriminator(input_shape)
    D_B = Discriminator(input_shape)

    if cuda:
        G_AB = G_AB.cuda()
        G_BA = G_BA.cuda()
        D_A = D_A.cuda()
        D_B = D_B.cuda()
        criterion_GAN.cuda()
        criterion_cycle.cuda()
        criterion_identity.cuda()

    if opt.epoch != 0:
        # Load pretrained models
        G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, opt.epoch)))
        G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, opt.epoch)))
        D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (opt.dataset_name, opt.epoch)))
        D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (opt.dataset_name, opt.epoch)))
    else:
        # Initialize weights
        G_AB.apply(weights_init_normal)
        G_BA.apply(weights_init_normal)
        D_A.apply(weights_init_normal)
        D_B.apply(weights_init_normal)

    # Optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()),
                                   lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    # Learning rate update schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

    # Buffers of previously generated samples
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Image transformations
    transforms_ = [
        transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
        transforms.RandomCrop((opt.img_height, opt.img_width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    # Training data loader
    dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
    )
    # Test data loader
    val_dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="test"),
        batch_size=5,
        shuffle=True,
        num_workers=1,
    )

    def sample_images(batches_done):
        """Saves a generated sample from the test set"""
        imgs = next(iter(val_dataloader))
        G_AB.eval()
        G_BA.eval()
        real_A = Variable(imgs["A"].type(Tensor))
        fake_B = G_AB(real_A)
        real_B = Variable(imgs["B"].type(Tensor))
        fake_A = G_BA(real_B)
        # Arrange images along x-axis
        real_A = make_grid(real_A, nrow=5, normalize=True)
        real_B = make_grid(real_B, nrow=5, normalize=True)
        fake_A = make_grid(fake_A, nrow=5, normalize=True)
        fake_B = make_grid(fake_B, nrow=5, normalize=True)
        # Arrange images along y-axis
        image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
        save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done), normalize=False)

    # ----------
    #  Training
    # ----------
    prev_time = time.time()
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(batch["A"].type(Tensor))
            real_B = Variable(batch["B"].type(Tensor))

            # Adversarial ground truths
            valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
            fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------
            G_AB.train()
            G_BA.train()
            optimizer_G.zero_grad()

            # Identity loss
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)
            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)
            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity
            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator A
            # -----------------------
            optimizer_D_A.zero_grad()
            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples)
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2
            loss_D_A.backward()
            optimizer_D_A.step()

            # -----------------------
            #  Train Discriminator B
            # -----------------------
            optimizer_D_B.zero_grad()
            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2
            loss_D_B.backward()
            optimizer_D_B.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # --------------
            #  Log Progress
            # --------------
            # Determine approximate time left
            batches_done = epoch * len(dataloader) + i
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                % (epoch, opt.n_epochs, i, len(dataloader),
                   loss_D.item(), loss_G.item(), loss_GAN.item(),
                   loss_cycle.item(), loss_identity.item(), time_left))

            # If at sample interval save image
            if batches_done % opt.sample_interval == 0:
                sample_images(batches_done)

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))
            torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch))
            torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch))
            torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch))
D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (dataset_name, epoch))) else: # Initialize weights init_weights_of_model(G_AB, init_type=init_type, init_gain=init_gain) init_weights_of_model(G_BA, init_type=init_type, init_gain=init_gain) init_weights_of_model(D_A, init_type=init_type, init_gain=init_gain) init_weights_of_model(D_B, init_type=init_type, init_gain=init_gain) # Optimizers optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=(b1, b2)) optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(b1, b2)) optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2)) # Learning rate update schedulers lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR( optimizer_G, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step ) lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR( optimizer_D_A, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step ) lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR( optimizer_D_B, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step ) Tensor = torch.cuda.FloatTensor if CUDA else torch.Tensor # Buffers of previously generated samples fake_A_buffer = ReplayBuffer() fake_B_buffer = ReplayBuffer() # Image transformations
def __init__(self, epoch=0, n_epochs=1000, batchSize=1, lr=0.0002, decay_epoch=100,
             size=256, input_nc=3, output_nc=3, cuda=True, n_cpu=8, load_from_ckpt=False):
    self.epoch = epoch
    self.n_epochs = n_epochs
    self.batchSize = batchSize
    self.lr = lr
    self.decay_epoch = decay_epoch
    self.size = size
    self.input_nc = input_nc
    self.output_nc = output_nc
    self.cuda = cuda
    self.n_cpu = n_cpu
    rootA = "../dataset/monet_field_data"
    rootB = "../dataset/field_data"

    if torch.cuda.is_available() and not self.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    ###### Definition of variables ######
    # Networks
    self.netG_A2B = Generator(self.input_nc, self.output_nc)
    self.netG_B2A = Generator(self.output_nc, self.input_nc)
    self.netD_A = Discriminator(self.input_nc)
    self.netD_B = Discriminator(self.output_nc)

    if load_from_ckpt:
        print("loading from ckpt")
        self.netG_A2B.load_state_dict(torch.load('output/netG_A2B.pth'))
        self.netG_B2A.load_state_dict(torch.load('output/netG_B2A.pth'))
        self.netD_A.load_state_dict(torch.load('output/netD_A.pth'))
        self.netD_B.load_state_dict(torch.load('output/netD_B.pth'))
    else:
        self.netG_A2B.apply(weights_init_normal)
        self.netG_B2A.apply(weights_init_normal)
        self.netD_A.apply(weights_init_normal)
        self.netD_B.apply(weights_init_normal)

    if self.cuda:
        self.netG_A2B.cuda()
        self.netG_B2A.cuda()
        self.netD_A.cuda()
        self.netD_B.cuda()

    # Losses
    self.criterion_GAN = torch.nn.MSELoss()
    self.criterion_cycle = torch.nn.L1Loss()
    self.criterion_identity = torch.nn.L1Loss()

    # Optimizers & LR schedulers
    self.optimizer_G = torch.optim.Adam(
        itertools.chain(self.netG_A2B.parameters(), self.netG_B2A.parameters()),
        lr=self.lr, betas=(0.5, 0.999))
    self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=self.lr, betas=(0.5, 0.999))
    self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=self.lr, betas=(0.5, 0.999))

    self.lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        self.optimizer_G,
        lr_lambda=LambdaLR(self.n_epochs, self.epoch, self.decay_epoch).step)
    self.lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        self.optimizer_D_A,
        lr_lambda=LambdaLR(self.n_epochs, self.epoch, self.decay_epoch).step)
    self.lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        self.optimizer_D_B,
        lr_lambda=LambdaLR(self.n_epochs, self.epoch, self.decay_epoch).step)

    if load_from_ckpt:
        print('load states')
        checkpoint = torch.load('output/states.pth')
        '''
        self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        self.optimizer_D_A.load_state_dict(checkpoint['optimizer_D_A'])
        self.optimizer_D_B.load_state_dict(checkpoint['optimizer_D_B'])
        self.lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        self.lr_scheduler_D_A.load_state_dict(checkpoint['lr_scheduler_D_A'])
        self.lr_scheduler_D_B.load_state_dict(checkpoint['lr_scheduler_D_B'])
        '''
        self.lr = checkpoint['lr']
        self.epoch = checkpoint['epoch'] + 1

    # Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if self.cuda else torch.Tensor
    self.input_A = Tensor(self.batchSize, self.input_nc, self.size, self.size)
    self.input_B = Tensor(self.batchSize, self.output_nc, self.size, self.size)
    self.target_real = Variable(Tensor(self.batchSize).fill_(1.0), requires_grad=False)
    self.target_fake = Variable(Tensor(self.batchSize).fill_(0.0), requires_grad=False)

    self.fake_A_buffer = ReplayBuffer()
    self.fake_B_buffer = ReplayBuffer()

    # Dataset loader
    transforms_ = [
        transforms.Resize((int(self.size * 1.12), int(self.size * 1.12)), Image.BICUBIC),
        transforms.RandomCrop(self.size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    self.dataloader = DataLoader(
        ImageDataset(rootA, rootB, transforms_=transforms_, unaligned=True),
        batch_size=self.batchSize, shuffle=True, num_workers=self.n_cpu)
def train():
    G_AB = Generator(input_nc, output_nc)
    G_BA = Generator(output_nc, input_nc)
    D_A = Discriminator(input_nc)
    D_B = Discriminator(output_nc)
    G_AB.cuda()
    G_BA.cuda()
    D_A.cuda()
    D_B.cuda()
    G_AB.apply(weights_init)
    G_BA.apply(weights_init)
    D_A.apply(weights_init)
    D_B.apply(weights_init)

    # Losses
    GD_loss = nn.MSELoss()
    L1_loss = nn.L1Loss()
    L1_loss_identity = nn.L1Loss()

    optim_G = optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()),
                         lr=lr_G, betas=(0.5, 0.999))
    optim_D_A = optim.Adam(D_A.parameters(), lr=lr_D, betas=(0.5, 0.999))
    optim_D_B = optim.Adam(D_B.parameters(), lr=lr_D, betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optim_G, lr_lambda=LambdaLR(n_epochs, start_epoch, decay).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optim_D_A, lr_lambda=LambdaLR(n_epochs, start_epoch, decay).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optim_D_B, lr_lambda=LambdaLR(n_epochs, start_epoch, decay).step)

    Tensor = torch.cuda.FloatTensor
    input_A = Tensor(1, 3, 256, 256)
    input_B = Tensor(1, 3, 256, 256)
    fake_A_buffer = keep()
    fake_B_buffer = keep()

    if opt.opencv:
        print('OPENCV MODE')
        transforms_ = [
            T.Scale(286),
            T.RandomCrop(256),
            T.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    else:
        print('PIL MODE')
        transforms_ = [
            transforms.Resize(286, Image.BICUBIC),
            transforms.RandomCrop(256),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

    dataloader = DataLoader(
        Loadimage(opt.dataroot, transforms_=transforms_, unaligned=True, mode_opencv=opt.opencv),
        batch_size=1, shuffle=True, num_workers=8)

    for epoch in range(start_epoch, n_epochs):
        for i, batch in enumerate(dataloader):
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ########################################
            # Train Generator
            # A to B
            optim_G.zero_grad()
            # Identity loss
            same_B = G_AB(real_B)
            loss_identity_B = L1_loss_identity(same_B, real_B) * 5.0
            same_A = G_BA(real_A)
            loss_identity_A = L1_loss_identity(same_A, real_A) * 5.0

            fake_B = G_AB(real_A)
            pred_fake_B = D_B(fake_B)
            G_AB_Loss = GD_loss(pred_fake_B, Variable(torch.ones(pred_fake_B.size()).cuda()))

            # B to A
            fake_A = G_BA(real_B)
            pred_fake_A = D_A(fake_A)
            G_BA_Loss = GD_loss(pred_fake_A, Variable(torch.ones(pred_fake_A.size()).cuda()))

            # Fake B back to A
            similar_A = G_BA(fake_B)
            BA_cycle_loss = L1_loss(similar_A, real_A) * 10.0
            # Fake A back to B
            similar_B = G_AB(fake_A)
            AB_cycle_loss = L1_loss(similar_B, real_B) * 10.0

            # Total loss G
            G_loss = (G_AB_Loss + G_BA_Loss + BA_cycle_loss + AB_cycle_loss
                      + loss_identity_A + loss_identity_B)
            G_loss_identity = loss_identity_B + loss_identity_A
            G_loss_GAN = G_AB_Loss + G_BA_Loss
            G_loss_cycle = BA_cycle_loss + AB_cycle_loss

            # Optimize G
            G_loss.backward()
            optim_G.step()

            # .data[0] was removed in PyTorch 0.4; .item() is the replacement.
            loss_G_plot.append(G_loss.item())
            loss_G_identity_plot.append(G_loss_identity.item())
            loss_G_GAN_plot.append(G_loss_GAN.item())
            loss_G_cycle_plot.append(G_loss_cycle.item())

            #######################################
            # Train Discriminator
            # Discriminator D_A
            optim_D_A.zero_grad()
            pred_real_A = D_A(real_A)
            D_real_loss = GD_loss(pred_real_A, Variable(torch.ones(pred_real_A.size()).cuda()))
            fake_A = fake_A_buffer.empty_fill_data(fake_A)
            pred_d_fake_A = D_A(fake_A)
            D_fake_loss = GD_loss(pred_d_fake_A, Variable(torch.zeros(pred_d_fake_A.size()).cuda()))
            D_A_loss_total = (D_real_loss + D_fake_loss) * 0.5
            D_A_loss_total.backward()
            optim_D_A.step()

            # Discriminator D_B
            optim_D_B.zero_grad()
            pred_real_B = D_B(real_B)
            D_real_loss = GD_loss(pred_real_B, Variable(torch.ones(pred_real_B.size()).cuda()))
            fake_B = fake_B_buffer.empty_fill_data(fake_B)
            pred_d_fake_B = D_B(fake_B)
            D_fake_loss = GD_loss(pred_d_fake_B, Variable(torch.zeros(pred_d_fake_B.size()).cuda()))
            D_B_loss_total = (D_real_loss + D_fake_loss) * 0.5
            D_B_loss_total.backward()
            optim_D_B.step()

            D_Loss = D_A_loss_total + D_B_loss_total
            loss_D_plot.append(D_Loss.item())

            #####################################
            # Print all losses (the old label said D_B_Loss but printed the
            # combined discriminator loss)
            print('Epoch [%d/%d], Step [%d/%d], G_loss: %.4f, D_loss: %.4f' %
                  (epoch + 1, n_epochs, i + 1, len(dataloader), G_loss.item(), D_Loss.item()))

        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        torch.save(G_AB.state_dict(), 'output/G_AB.pth')
        torch.save(G_BA.state_dict(), 'output/G_BA.pth')
        torch.save(D_A.state_dict(), 'output/D_AB.pth')
        torch.save(D_B.state_dict(), 'output/D_BA.pth')

    # Loss curves (same five figures as before, drawn in one loop)
    x = np.linspace(start_epoch, n_epochs, num=len(loss_G_plot))
    for fig_idx, (series, title) in enumerate([
            (loss_G_plot, 'Loss_G'),
            (loss_G_identity_plot, 'Loss_G_Identity'),
            (loss_G_GAN_plot, 'Loss_G_GAN'),
            (loss_G_cycle_plot, 'Loss_G_Cycle'),
            (loss_D_plot, 'Loss_D')]):
        plt.figure(fig_idx + 1)
        plt.plot(x, series)
        plt.xticks(np.arange(start_epoch, n_epochs + 1, 25))
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title(title)
        plt.savefig(title + '.png')
def main(args):
    writer = SummaryWriter(os.path.join(args.out_dir, 'logs'))
    current_time = datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
    os.makedirs(os.path.join(args.out_dir, 'models', args.model_name + '_' + current_time))
    os.makedirs(os.path.join(args.out_dir, 'logs', args.model_name + '_' + current_time))

    G_AB = Generator(args.in_channel, args.out_channel).to(args.device)
    G_BA = Generator(args.in_channel, args.out_channel).to(args.device)
    D_A = Discriminator(args.in_channel).to(args.device)
    D_B = Discriminator(args.out_channel).to(args.device)
    segmen_B = Unet(3, 34).to(args.device)

    if args.model_path is not None:
        # BUG FIX: os.join.path does not exist; the function is os.path.join.
        AB_path = os.path.join(args.model_path, 'ab.pt')
        BA_path = os.path.join(args.model_path, 'ba.pt')
        DA_path = os.path.join(args.model_path, 'da.pt')
        DB_path = os.path.join(args.model_path, 'db.pt')
        segmen_path = os.path.join(args.model_path, 'semsg.pt')
        with open(AB_path, 'rb') as f:
            G_AB.load_state_dict(torch.load(f))
        with open(BA_path, 'rb') as f:
            G_BA.load_state_dict(torch.load(f))
        with open(DA_path, 'rb') as f:
            D_A.load_state_dict(torch.load(f))
        with open(DB_path, 'rb') as f:
            D_B.load_state_dict(torch.load(f))
        with open(segmen_path, 'rb') as f:
            segmen_B.load_state_dict(torch.load(f))
    else:
        G_AB.apply(weights_init_normal)
        G_BA.apply(weights_init_normal)
        D_A.apply(weights_init_normal)
        D_B.apply(weights_init_normal)

    G_AB = nn.DataParallel(G_AB)
    G_BA = nn.DataParallel(G_BA)
    D_A = nn.DataParallel(D_A)
    D_B = nn.DataParallel(D_B)
    segmen_B = nn.DataParallel(segmen_B)

    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    criterion_segmen = torch.nn.BCELoss()

    optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()),
                                   lr=args.lr, betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=args.lr, betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=args.lr, betas=(0.5, 0.999))
    optimizer_segmen_B = torch.optim.Adam(segmen_B.parameters(), lr=args.lr, betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G, lr_lambda=LambdaLR(args.n_epochs, args.epoch, args.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A, lr_lambda=LambdaLR(args.n_epochs, args.epoch, args.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B, lr_lambda=LambdaLR(args.n_epochs, args.epoch, args.decay_epoch).step)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    transforms_ = [transforms.ToTensor(),
                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    dataloader = DataLoader(
        ImgDataset(args.dataset_path, transforms_=transforms_, unaligned=True, device=args.device),
        batch_size=args.batchSize, shuffle=True, num_workers=0)
    logger = Logger(args.n_epochs, len(dataloader))

    target_real = Variable(torch.Tensor(args.batchSize, 1).fill_(1.)).to(args.device).detach()
    target_fake = Variable(torch.Tensor(args.batchSize, 1).fill_(0.)).to(args.device).detach()

    G_AB.train()
    G_BA.train()
    D_A.train()
    D_B.train()
    segmen_B.train()

    for epoch in range(args.epoch, args.n_epochs):
        for i, batch in enumerate(dataloader):
            real_A = batch['A'].clone()
            real_B = batch['B'].clone()
            B_label = batch['B_label'].clone()

            fake_b = G_AB(real_A)
            fake_a = G_BA(real_B)
            same_b = G_AB(real_B)
            same_a = G_BA(real_A)
            recovered_A = G_BA(fake_b)
            recovered_B = G_AB(fake_a)
            pred_Blabel = segmen_B(real_B)
            pred_fakeAlabel = segmen_B(fake_a)

            optimizer_segmen_B.zero_grad()
            # Segmentation loss; do we assume that it also learns how to
            # segment images after doing domain transfer?
            loss_segmen_B = (criterion_segmen(pred_Blabel, B_label)
                             + criterion_segmen(segmen_B(fake_a.detach()), B_label))
            loss_segmen_B.backward()
            optimizer_segmen_B.step()

            optimizer_G.zero_grad()
            # GAN loss
            pred_fakeb = D_B(fake_b)
            loss_gan_AB = criterion_GAN(pred_fakeb, target_real)
            pred_fakea = D_A(fake_a)
            loss_gan_BA = criterion_GAN(pred_fakea, target_real)
            # Identity loss
            loss_identity_B = criterion_identity(same_b, real_B) * 5
            loss_identity_A = criterion_identity(same_a, real_A) * 5
            # Cycle consistency loss
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10
            # Cycle segmentation difference loss
            loss_segmen_diff = criterion_segmen(segmen_B(recovered_B), pred_Blabel.detach())
            loss_G = (loss_gan_AB + loss_gan_BA + loss_identity_B + loss_identity_A
                      + loss_cycle_ABA + loss_cycle_BAB + loss_segmen_diff)
            loss_G.backward()
            optimizer_G.step()

            # Discriminator A
            optimizer_D_A.zero_grad()
            pred_realA = D_A(real_A)
            loss_D_A_real = criterion_GAN(pred_realA, target_real)
            fake_A = fake_A_buffer.push_and_pop(fake_a)
            pred_fakeA = D_A(fake_A.detach())
            loss_D_A_fake = criterion_GAN(pred_fakeA, target_fake)
            loss_D_A = (loss_D_A_real + loss_D_A_fake) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()

            # Discriminator B
            optimizer_D_B.zero_grad()
            pred_realB = D_B(real_B)
            loss_D_B_real = criterion_GAN(pred_realB, target_real)
            fake_B = fake_B_buffer.push_and_pop(fake_b)
            pred_fakeB = D_B(fake_B.detach())
            loss_D_B_fake = criterion_GAN(pred_fakeB, target_fake)
            loss_D_B = (loss_D_B_real + loss_D_B_fake) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()

            logger.log(
                {'loss_segmen_B': loss_segmen_B,
                 'loss_G': loss_G,
                 'loss_G_identity': (loss_identity_A + loss_identity_B),
                 'loss_G_GAN': (loss_gan_AB + loss_gan_BA),
                 'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB),
                 'loss_D': (loss_D_A + loss_D_B)},
                images={'real_A': real_A, 'real_B': real_B,
                        'fake_A': fake_a, 'fake_B': fake_b,
                        'reconstructed_A': recovered_A, 'reconstructed_B': recovered_B},
                out_dir=os.path.join(args.out_dir, 'logs',
                                     args.model_name + '_' + current_time + '/' + str(epoch)),
                writer=writer)

        if (epoch + 1) % args.save_per_epochs == 0:
            save_dir = os.path.join(args.out_dir, 'models',
                                    args.model_name + '_' + current_time, str(epoch))
            os.makedirs(save_dir)
            torch.save(G_AB.module.state_dict(), os.path.join(save_dir, 'ab.pt'))
            torch.save(G_BA.module.state_dict(), os.path.join(save_dir, 'ba.pt'))
            torch.save(D_A.module.state_dict(), os.path.join(save_dir, 'da.pt'))
            torch.save(D_B.module.state_dict(), os.path.join(save_dir, 'db.pt'))
            torch.save(segmen_B.module.state_dict(), os.path.join(save_dir, 'semsg.pt'))

        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()
def copy_pruned_weights(src_gen, dst_gen, cfg_mask):
    """Copy the surviving channels of src_gen into the pruned dst_gen.

    (This helper factors out the two identical copy loops the original had
    for A2B and B2A; the behavior is unchanged.)
    """
    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask[layer_id_in_cfg]
    for m0, m1 in zip(src_gen.modules(), dst_gen.modules()):
        if isinstance(m0, nn.Conv2d):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask)))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask)))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
            # Conv2d weights are (out_channels, in_channels, kH, kW).
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask
            if layer_id_in_cfg < len(cfg_mask):  # do not change in final layer
                end_mask = cfg_mask[layer_id_in_cfg]
            print(layer_id_in_cfg)
        elif isinstance(m0, nn.ConvTranspose2d):
            print('Into ConvTranspose...')
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask)))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask)))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
            # ConvTranspose2d weights are (in_channels, out_channels, kH, kW).
            w1 = m0.weight.data[idx0.tolist(), :, :, :].clone()
            w1 = w1[:, idx1.tolist(), :, :].clone()
            m1.weight.data = w1.clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]


def train_from_mask():
    # Load best-fitness binary masks
    mask_input_A2B = np.loadtxt("/cache/GA/txt/best_fitness_A2B.txt")
    mask_input_B2A = np.loadtxt("/cache/GA/txt/best_fitness_B2A.txt")
    cfg_mask_A2B = compute_layer_mask(mask_input_A2B, mask_chns)
    cfg_mask_B2A = compute_layer_mask(mask_input_B2A, mask_chns)

    netG_B2A = Generator(opt.output_nc, opt.input_nc)
    # BUG FIX: A2B was also built as (output_nc, input_nc); its channel order
    # must be reversed relative to B2A.
    netG_A2B = Generator(opt.input_nc, opt.output_nc)
    model_A2B = Generator_Prune(cfg_mask_A2B)
    model_B2A = Generator_Prune(cfg_mask_B2A)
    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    netG_A2B.load_state_dict(torch.load('/cache/log/output/netG_A2B.pth'))
    netG_B2A.load_state_dict(torch.load('/cache/log/output/netG_B2A.pth'))
    netD_A.load_state_dict(torch.load('/cache/log/output/netD_A.pth'))
    netD_B.load_state_dict(torch.load('/cache/log/output/netD_B.pth'))

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    # Copy the surviving weights into the pruned generators
    copy_pruned_weights(netG_A2B, model_A2B, cfg_mask_A2B)
    copy_pruned_weights(netG_B2A, model_B2A, cfg_mask_B2A)

    # Dataset loader
    netD_A = torch.nn.DataParallel(netD_A).cuda()
    netD_B = torch.nn.DataParallel(netD_B).cuda()
    model_A2B = torch.nn.DataParallel(model_A2B).cuda()
    model_B2A = torch.nn.DataParallel(model_B2A).cuda()

    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    lamda_loss_ID = 5.0
    lamda_loss_G = 1.0
    lamda_loss_cycle = 10.0

    optimizer_G = torch.optim.Adam(itertools.chain(model_A2B.parameters(), model_B2A.parameters()),
                                   lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    transforms_ = [
        transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
        transforms.RandomCrop(opt.size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    dataloader = DataLoader(
        ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True, mode='train'),
        batch_size=opt.batchSize, shuffle=True, drop_last=True)

    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()
            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = model_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B) * lamda_loss_ID  # initial 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = model_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A) * lamda_loss_ID  # initial 5.0
            # GAN loss
            fake_B = model_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real) * lamda_loss_G  # initial 1.0
            fake_A = model_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real) * lamda_loss_G  # initial 1.0
            # Cycle loss
            recovered_A = model_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * lamda_loss_cycle  # initial 10.0
            recovered_B = model_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * lamda_loss_cycle  # initial 10.0
            # Total loss
            loss_G = (loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A
                      + loss_cycle_ABA + loss_cycle_BAB)
            loss_G.backward()
            optimizer_G.step()

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()
            # Real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)
            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)
            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()
            # Real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)
            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)
            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()

        print("epoch:%d Loss G:%4f LossID_A:%4f LossID_B:%4f Loss_G_A2B:%4f Loss_G_B2A:%4f "
              "Loss_Cycle_ABA:%4f Loss_Cycle_BAB:%4f " %
              (epoch, loss_G, loss_identity_A, loss_identity_B,
               loss_GAN_A2B, loss_GAN_B2A, loss_cycle_ABA, loss_cycle_BAB))

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        if epoch % 20 == 0:
            # Save models checkpoints
            torch.save(model_A2B.module.state_dict(), '/cache/log/output/A2B_%d.pth' % epoch)
            torch.save(model_B2A.module.state_dict(), '/cache/log/output/B2A_%d.pth' % epoch)
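# One fragile spot in the mask-to-index conversion above:
# np.squeeze(np.argwhere(mask)) returns a 0-d array when exactly one channel
# survives, and .tolist() then yields a bare int, which changes the indexing
# semantics. A small hedged alternative -- mask_to_indices is my name, not
# from the source -- that avoids the edge case:
import numpy as np

def mask_to_indices(mask):
    """Return surviving channel indices as a 1-D int array.

    np.argwhere on a 1-D mask gives shape (k, 1); reshape(-1) keeps the
    result 1-D even when k == 1, unlike np.squeeze, which would produce a
    0-d array there.
    """
    return np.argwhere(np.asarray(mask)).reshape(-1)

# Example: a mask keeping a single channel still indexes correctly.
w = np.random.randn(4, 3, 3, 3)        # Conv2d-style (out, in, kH, kW)
idx = mask_to_indices([0, 1, 0])       # -> array([1])
pruned = w[:, idx.tolist(), :, :]      # shape (4, 1, 3, 3)
print(pruned.shape)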
# Losses
criterion_GAN = torch.nn.MSELoss()   # adversarial loss
criterion_cycle = torch.nn.L1Loss()  # cycle-consistency loss

# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                               lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)  # real
target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)  # fake

fake_A_buffer = ReplayBuffer()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
    parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
    parser.add_argument('--batchSize', type=int, default=1, help='size of the batches')
    parser.add_argument('--dataroot', type=str, default='datasets/data/',
                        help='root directory of the dataset')
    parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate')
    parser.add_argument('--decay_epoch', type=int, default=100,
                        help='epoch to start linearly decaying the learning rate to 0')
    parser.add_argument('--size', type=int, default=256, help='size of the data crop (squared assumed)')
    parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
    parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')
    parser.add_argument('--cuda', action='store_true', help='use GPU computation')
    parser.add_argument('--n_cpu', type=int, default=8,
                        help='number of cpu threads to use during batch generation')
    opt = parser.parse_args()
    print(opt)

    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    ###### Definition of variables ######
    # Networks
    netG_A2B = Generator(opt.input_nc, opt.output_nc)
    netG_B2A = Generator(opt.output_nc, opt.input_nc)
    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    if opt.cuda:
        netG_A2B.cuda()
        netG_B2A.cuda()
        netD_A.cuda()
        netD_B.cuda()

    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    # Optimizers & LR schedulers
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                                   netG_B2A.parameters()),
                                   lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    # Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Dataset loader
    transforms_ = [
        transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
        transforms.RandomCrop(opt.size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True),
                            batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)

    # Loss plot
    logger = Logger(opt.n_epochs, len(dataloader))
    ###################################

    ###### Training ######
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B) * 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A) * 5.0

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

            # Total loss
            loss_G = (loss_identity_A + loss_identity_B
                      + loss_GAN_A2B + loss_GAN_B2A
                      + loss_cycle_ABA + loss_cycle_BAB)
            loss_G.backward()
            optimizer_G.step()
            ###################################

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()
            ###################################

            # Progress report (http://localhost:8097)
            logger.log({'loss_G': loss_G,
                        'loss_G_identity': (loss_identity_A + loss_identity_B),
                        'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A),
                        'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB),
                        'loss_D': (loss_D_A + loss_D_B)},
                       images={'real_A': real_A, 'real_B': real_B,
                               'fake_A': fake_A, 'fake_B': fake_B})

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Save models checkpoints
        torch.save(netG_A2B.state_dict(), 'output/netG_A2B.pth')
        torch.save(netG_B2A.state_dict(), 'output/netG_B2A.pth')
        torch.save(netD_A.state_dict(), 'output/netD_A.pth')
        torch.save(netD_B.state_dict(), 'output/netD_B.pth')
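# The fake_A_buffer / fake_B_buffer used above implement the image-history
# trick: discriminators see a mix of current and past generator outputs, which
# stabilizes training. A minimal sketch of the ReplayBuffer these scripts
# assume (the utils version may differ in buffer size or sampling details):
import random
import torch

class ReplayBuffer:
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        # Store each incoming image; once the buffer is full, with
        # probability 0.5 return an older image in place of the new one.
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        return torch.cat(to_return)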
def train(self):
    num_channels = self.config.NUM_CHANNELS
    use_cuda = self.config.USE_CUDA
    lr = self.config.LEARNING_RATE

    # Networks
    netG_A2B = Generator(num_channels)
    netG_B2A = Generator(num_channels)
    netD_A = Discriminator(num_channels)
    netD_B = Discriminator(num_channels)
    # Batch-norm variants, kept for reference:
    # netG_A2B = Generator_BN(num_channels)
    # netG_B2A = Generator_BN(num_channels)
    # netD_A = Discriminator_BN(num_channels)
    # netD_B = Discriminator_BN(num_channels)

    if use_cuda:
        netG_A2B.cuda()
        netG_B2A.cuda()
        netD_A.cuda()
        netD_B.cuda()

    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

    # Losses
    criterion_GAN = torch.nn.BCELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    # Optimizers & LR schedulers (decay starts at the halfway point)
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                                   netG_B2A.parameters()),
                                   lr=lr, betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=lr, betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G, lr_lambda=LambdaLR(self.config.EPOCH, 0, self.config.EPOCH // 2).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A, lr_lambda=LambdaLR(self.config.EPOCH, 0, self.config.EPOCH // 2).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B, lr_lambda=LambdaLR(self.config.EPOCH, 0, self.config.EPOCH // 2).step)

    # Inputs & targets memory allocation
    # Tensor = LongTensor if use_cuda else torch.Tensor
    batch_size = self.config.BATCH_SIZE
    height, width, channels = self.config.INPUT_SHAPE
    input_A = FloatTensor(batch_size, channels, height, width)
    input_B = FloatTensor(batch_size, channels, height, width)
    target_real = Variable(FloatTensor(batch_size).fill_(1.0), requires_grad=False)
    target_fake = Variable(FloatTensor(batch_size).fill_(0.0), requires_grad=False)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    transforms_ = [transforms.RandomCrop((height, width)),
                   transforms.RandomHorizontalFlip(),
                   transforms.ToTensor(),
                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    dataloader = DataLoader(ImageDataset(self.config.DATA_DIR,
                                         self.config.DATASET_A,
                                         self.config.DATASET_B,
                                         transforms_=transforms_,
                                         unaligned=True),
                            batch_size=batch_size, shuffle=True,
                            num_workers=2, drop_last=True)

    # Loss plot
    logger = Logger(self.config.EPOCH, len(dataloader))

    now = datetime.datetime.now()
    datetime_sequence = "{0}{1:02d}{2:02d}_{3:02}{4:02d}".format(
        str(now.year)[-2:], now.month, now.day, now.hour, now.minute)
    output_name_1 = self.config.DATASET_A + "2" + self.config.DATASET_B
    output_name_2 = self.config.DATASET_B + "2" + self.config.DATASET_A
    experiment_dir = os.path.join(self.config.RESULT_DIR, datetime_sequence)
    sample_output_dir_1 = os.path.join(experiment_dir, "sample", output_name_1)
    sample_output_dir_2 = os.path.join(experiment_dir, "sample", output_name_2)
    weights_output_dir_1 = os.path.join(experiment_dir, "weights", output_name_1)
    weights_output_dir_2 = os.path.join(experiment_dir, "weights", output_name_2)
    weights_output_dir_resume = os.path.join(experiment_dir, "weights", "resume")
    os.makedirs(sample_output_dir_1, exist_ok=True)
    os.makedirs(sample_output_dir_2, exist_ok=True)
    os.makedirs(weights_output_dir_1, exist_ok=True)
    os.makedirs(weights_output_dir_2, exist_ok=True)
    os.makedirs(weights_output_dir_resume, exist_ok=True)

    counter = 0
    for epoch in range(self.config.EPOCH):
        # Loss CSV export, kept for reference:
        # logger.loss_df.to_csv(os.path.join(experiment_dir,
        #     self.config.DATASET_A + "_" + self.config.DATASET_B + ".csv"), index=False)

        if epoch % 100 == 0:
            torch.save(netG_A2B.state_dict(),
                       os.path.join(weights_output_dir_1, str(epoch).zfill(4) + 'netG_A2B.pth'))
            torch.save(netG_B2A.state_dict(),
                       os.path.join(weights_output_dir_2, str(epoch).zfill(4) + 'netG_B2A.pth'))
            torch.save(netD_A.state_dict(),
                       os.path.join(weights_output_dir_1, str(epoch).zfill(4) + 'netD_A.pth'))
            torch.save(netD_B.state_dict(),
                       os.path.join(weights_output_dir_2, str(epoch).zfill(4) + 'netD_B.pth'))

        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake_B = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake_B, target_real)

            fake_A = netG_B2A(real_B)
            pred_fake_A = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake_A, target_real)

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

            # Total loss
            loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
            loss_G.backward()
            optimizer_G.step()
            ###################################

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_A = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_A, target_real)

            # Fake loss
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A_.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_B = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_B, target_real)

            # Fake loss
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B_.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()

            # Progress report (http://localhost:8097)
            logger.log({'loss_G': loss_G,
                        'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A),
                        'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB),
                        'loss_D': (loss_D_A + loss_D_B)},
                       images={'real_A': real_A, 'real_B': real_B,
                               'fake_A': fake_A, 'fake_B': fake_B})

            if counter % 500 == 0:
                real_A_sample = real_A.cpu().detach().numpy()[0]
                pred_A_sample = fake_A.cpu().detach().numpy()[0]
                real_B_sample = real_B.cpu().detach().numpy()[0]
                pred_B_sample = fake_B.cpu().detach().numpy()[0]
                combine_sample_1 = np.concatenate([real_A_sample, pred_B_sample], axis=2)
                combine_sample_2 = np.concatenate([real_B_sample, pred_A_sample], axis=2)
                file_1 = "{0}_{1}.jpg".format(epoch, counter)
                output_sample_image(os.path.join(sample_output_dir_1, file_1), combine_sample_1)
                file_2 = "{0}_{1}.jpg".format(epoch, counter)
                output_sample_image(os.path.join(sample_output_dir_2, file_2), combine_sample_2)
            counter += 1

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

    torch.save(netG_A2B.state_dict(),
               os.path.join(weights_output_dir_1, str(self.config.EPOCH).zfill(4) + 'netG_A2B.pth'))
    torch.save(netG_B2A.state_dict(),
               os.path.join(weights_output_dir_2, str(self.config.EPOCH).zfill(4) + 'netG_B2A.pth'))
    torch.save(netD_A.state_dict(),
               os.path.join(weights_output_dir_1, str(self.config.EPOCH).zfill(4) + 'netD_A.pth'))
    torch.save(netD_B.state_dict(),
               os.path.join(weights_output_dir_2, str(self.config.EPOCH).zfill(4) + 'netD_B.pth'))
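# All four networks above are initialized through .apply(weights_init_normal).
# A minimal sketch of that initializer as conventionally written for CycleGAN
# (conv weights drawn from N(0, 0.02)); the project's own version may differ:
import torch

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)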
from utils import ReplayBuffer, LambdaLR, sample_images

# Load the args
args = TrainOptions().parse()

# Calculate the output size of the discriminator (PatchGAN); a worked example
# follows this block
patch = (1,
         args.img_height // (2 ** args.n_D_layers) - 2,
         args.img_width // (2 ** args.n_D_layers) - 2)

# Initialize generators and discriminators
G__AB, D__B, G__BA, D__A = Create_nets(args)

# Loss functions
criterion_GAN, criterion_cycle, criterion_identity = Get_loss_func(args)

# Optimizers
optimizer_G, optimizer_D_B, optimizer_D_A = Get_optimizers(args, G__AB, G__BA, D__B, D__A)

# Learning rate update schedulers
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(args.epoch_num, args.epoch_start, args.decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(args.epoch_num, args.epoch_start, args.decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(args.epoch_num, args.epoch_start, args.decay_epoch).step)

# Configure dataloaders
train_dataloader, test_dataloader, _ = Get_dataloader(args)

# Buffers of previously generated samples
fake_Y_A_buffer = ReplayBuffer()
fake_X_B_buffer = ReplayBuffer()

# ----------
#  Training
# ----------
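# Quick sanity check of the PatchGAN patch-size formula above, using
# hypothetical values img_height = img_width = 256 and n_D_layers = 4 (these
# are illustrative, not the project defaults): each downsampling layer halves
# the spatial size, and the trailing -2 accounts for the final conv's shrink.
img_height = img_width = 256
n_D_layers = 4
patch = (1,
         img_height // (2 ** n_D_layers) - 2,
         img_width // (2 ** n_D_layers) - 2)
print(patch)  # (1, 14, 14): the GAN loss targets a 1x14x14 label map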
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                               netG_B2A.parameters()),
                               lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

# A single decay policy shared by all three schedulers; step() is stateless,
# so one instance is safe to reuse
lambda_LR = LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch)
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lambda_LR.step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=lambda_LR.step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=lambda_LR.step)

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Dataset loader
def train_seg_model(args):
    # model
    model = None
    if args.model_name == "UNet":
        model = UNet(n_channels=args.in_channels, n_classes=args.class_num)
    elif args.model_name == "PSP":
        model = pspnet.PSPNet(n_classes=19, input_size=(512, 512))
        model.load_pretrained_model(
            model_path="./segnet/pspnet/pspnet101_cityscapes.caffemodel")
        model.classification = nn.Conv2d(512, args.class_num, kernel_size=1)
    else:
        raise AssertionError("Unknown model: {}".format(args.model_name))
    model = nn.DataParallel(model)
    model.cuda()

    # optimizer
    optimizer = None
    if args.optim_name == "Adam":
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                               lr=1.0e-3)
    elif args.optim_name == "SGD":
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                              lr=args.init_lr, momentum=0.9, weight_decay=0.0005)
    else:
        raise AssertionError("Unknown optimizer: {}".format(args.optim_name))
    # decay_start_epoch=0: linear decay over the whole run (see the check below)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=LambdaLR(args.maxepoch, 0, 0).step)

    # dataloader
    train_data_dir = os.path.join(args.data_dir, args.tumor_type, "train")
    train_dloader = gen_dloader(train_data_dir, args.batch_size, mode="train",
                                normalize=args.normalize, tumor_type=args.tumor_type)
    test_data_dir = os.path.join(args.data_dir, args.tumor_type, "val")
    val_dloader = gen_dloader(test_data_dir, args.batch_size, mode="val",
                              normalize=args.normalize, tumor_type=args.tumor_type)

    # training
    save_model_dir = os.path.join(args.model_dir, args.tumor_type, args.session)
    if not os.path.exists(save_model_dir):
        os.makedirs(save_model_dir)
    best_dice = 0.0
    for epoch in np.arange(0, args.maxepoch):
        print('Epoch {}/{}'.format(epoch + 1, args.maxepoch))
        print('-' * 10)
        since = time.time()
        for phase in ['train', 'val']:
            if phase == 'train':
                dloader = train_dloader
                scheduler.step()
                for param_group in optimizer.param_groups:
                    print("Current LR: {:.8f}".format(param_group['lr']))
                model.train()  # Set model to training mode
            else:
                dloader = val_dloader
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0
            for batch_ind, (imgs, masks) in enumerate(dloader):
                inputs = Variable(imgs.cuda())
                masks = Variable(masks.cuda())
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, masks, metrics, bce_weight=args.bce_weight)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics
                epoch_samples += inputs.size(0)
            print_metrics(metrics, epoch_samples, phase)
            epoch_dice = metrics['dice'] / epoch_samples
            # deep copy the model
            if phase == 'val' and (epoch_dice > best_dice or epoch > args.maxepoch - 5):
                best_dice = epoch_dice
                best_model = copy.deepcopy(model.state_dict())
                best_model_name = "-".join([
                    args.model_name, "{:03d}-{:.3f}.pth".format(epoch, best_dice)
                ])
                torch.save(best_model, os.path.join(save_model_dir, best_model_name))
        time_elapsed = time.time() - since
        print('Epoch {:2d} takes {:.0f}m {:.0f}s'.format(
            epoch, time_elapsed // 60, time_elapsed % 60))
        print("=" * 80)
    print("Training finished...")
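# With decay_start_epoch = 0 as in train_seg_model, the multiplier decays
# linearly from the first epoch: factor(e) = 1 - e / maxepoch. A self-contained
# check using a hypothetical maxepoch of 10 and a throwaway SGD optimizer
# (illustrative values only):
import torch

maxepoch = 10
opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda e: 1.0 - e / maxepoch)
for epoch in range(maxepoch):
    print(epoch, opt.param_groups[0]['lr'])  # 0.1, 0.09, ..., 0.01
    opt.step()    # optimizer steps first, then the scheduler
    sched.step()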
netD_A = Discriminator(3).to(device)
netD_B = Discriminator(3).to(device)

netG_A2B.apply(weights_init_normal)
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

# Optimizers and learning rate schedulers
optimizer_G = Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                   lr=opt.lr, betas=BETAS)
optimizer_D_en = Adam(netD_A.parameters(), lr=opt.lr, betas=BETAS)
optimizer_D_zh = Adam(netD_B.parameters(), lr=opt.lr, betas=BETAS)

lr_scheduler_G = lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, 0, DECAY_EPOCH).step)
lr_scheduler_D_en = lr_scheduler.LambdaLR(
    optimizer_D_en, lr_lambda=LambdaLR(opt.n_epochs, 0, DECAY_EPOCH).step)
lr_scheduler_D_zh = lr_scheduler.LambdaLR(
    optimizer_D_zh, lr_lambda=LambdaLR(opt.n_epochs, 0, DECAY_EPOCH).step)

train()