def train(self, args):
    # For transforming the input image
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Resize((args.load_height, args.load_width)),
        transforms.RandomCrop((args.crop_height, args.crop_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset_dirs = utils.get_traindata_link(args.dataset_dir)

    # Pytorch dataloaders
    a_loader = torch.utils.data.DataLoader(
        dsets.ImageFolder(dataset_dirs['trainA'], transform=transform),
        batch_size=args.batch_size, shuffle=True, num_workers=4)
    b_loader = torch.utils.data.DataLoader(
        dsets.ImageFolder(dataset_dirs['trainB'], transform=transform),
        batch_size=args.batch_size, shuffle=True, num_workers=4)

    a_fake_sample = utils.Sample_from_Pool()
    b_fake_sample = utils.Sample_from_Pool()

    max_len = max(len(a_loader), len(b_loader))
    steps = 0
    print_msg = 100  # assumed logging interval; not set elsewhere in this listing

    for epoch in range(self.start_epoch, args.epochs):
        lr = self.g_optimizer.param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

        a_it = iter(a_loader)
        b_it = iter(b_loader)
        for i in range(max_len):
            # Restart the shorter loader when it is exhausted and fetch a fresh batch
            try:
                a_real = next(a_it)[0]
            except StopIteration:
                a_it = iter(a_loader)
                a_real = next(a_it)[0]
            try:
                b_real = next(b_it)[0]
            except StopIteration:
                b_it = iter(b_loader)
                b_real = next(b_it)[0]

            # Generator Computations
            ##################################################
            set_grad([self.Da, self.Db], False)
            self.g_optimizer.zero_grad()

            a_real = Variable(a_real)
            b_real = Variable(b_real)
            a_real, b_real = utils.cuda([a_real, b_real])

            # Forward pass through generators
            ##################################################
            a_fake = self.Gab(b_real)
            b_fake = self.Gba(a_real)
            a_recon = self.Gab(b_fake)
            b_recon = self.Gba(a_fake)

            # Identity losses
            ###################################################
            a_idt = self.Gab(a_real)
            b_idt = self.Gba(b_real)

            # lamda = 1.75E+12
            lamda = args.lamda * args.idt_coef
            a_idt_loss = self.L1(a_idt, a_real) * lamda
            b_idt_loss = self.L1(b_idt, b_real) * lamda

            # a_real_features = vgg.get_features(a_real)
            # b_real_features = vgg.get_features(b_real)
            # a_fake_features = vgg.get_features(a_fake)
            # b_fake_features = vgg.get_features(b_fake)

            # Content losses
            # content_loss_weight = 1.50
            # content_loss_weight = 1
            # a_content_loss = vgg.get_content_loss(b_fake_features, a_real_features) * content_loss_weight
            # b_content_loss = vgg.get_content_loss(a_fake_features, b_real_features) * content_loss_weight

            # Style losses
            # style_loss_weight = 3.00E+05
            # style_loss_weight = 1
            # a_style_loss = vgg.get_style_loss(a_fake_features, a_real_features) * style_loss_weight
            # b_style_loss = vgg.get_style_loss(b_fake_features, b_real_features) * style_loss_weight

            # Adversarial losses
            ###################################################
            a_fake_dis = self.Da(a_fake)
            b_fake_dis = self.Db(b_fake)

            real_label = utils.cuda(Variable(torch.ones(a_fake_dis.size())))

            # gen_loss_weight = 4.50E+08
            gen_loss_weight = 1
            a_gen_loss = self.MSE(a_fake_dis, real_label) * gen_loss_weight
            b_gen_loss = self.MSE(b_fake_dis, real_label) * gen_loss_weight

            # Cycle consistency losses
            ###################################################
            a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
            b_cycle_loss = self.L1(b_recon, b_real) * args.lamda
            # lamda = 3.50E+12
            # a_cycle_loss = self.L1(a_recon, a_real) * lamda
            # b_cycle_loss = self.L1(b_recon, b_real) * lamda

            # gen_loss = a_gen_loss + b_gen_loss + \
            #            a_cycle_loss + b_cycle_loss + \
            #            a_style_loss + b_style_loss + \
            #            a_content_loss + b_content_loss + \
            #            a_idt_loss + b_idt_loss

            # Total generators losses
            ###################################################
            gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + \
                b_cycle_loss + a_idt_loss + b_idt_loss

            # Update generators
            ###################################################
            gen_loss.backward()
            self.g_optimizer.step()

            # Discriminator Computations
            #################################################
            set_grad([self.Da, self.Db], True)
            self.d_optimizer.zero_grad()

            # Sample from history of generated images
            #################################################
            a_fake = Variable(torch.Tensor(a_fake_sample([a_fake.cpu().data.numpy()])[0]))
            b_fake = Variable(torch.Tensor(b_fake_sample([b_fake.cpu().data.numpy()])[0]))
            a_fake, b_fake = utils.cuda([a_fake, b_fake])

            # Forward pass through discriminators
            #################################################
            a_real_dis = self.Da(a_real)
            a_fake_dis = self.Da(a_fake)
            b_real_dis = self.Db(b_real)
            b_fake_dis = self.Db(b_fake)
            real_label = utils.cuda(Variable(torch.ones(a_real_dis.size())))
            fake_label = utils.cuda(Variable(torch.zeros(a_fake_dis.size())))

            # Discriminator losses
            ##################################################
            a_dis_real_loss = self.MSE(a_real_dis, real_label)
            a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
            b_dis_real_loss = self.MSE(b_real_dis, real_label)
            b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

            # Total discriminators losses
            a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
            b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

            # Update discriminators
            ##################################################
            a_dis_loss.backward()
            b_dis_loss.backward()
            self.d_optimizer.step()

            steps += 1
            if steps % print_msg == 0:
                print("Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
                      (epoch, i + 1, max_len, gen_loss, a_dis_loss + b_dis_loss))

        # Override the latest checkpoint
        #######################################################
        utils.save_checkpoint({
            'epoch': epoch + 1,
            'Da': self.Da.state_dict(),
            'Db': self.Db.state_dict(),
            'Gab': self.Gab.state_dict(),
            'Gba': self.Gba.state_dict(),
            'd_optimizer': self.d_optimizer.state_dict(),
            'g_optimizer': self.g_optimizer.state_dict()
        }, '%s/latest.ckpt' % (args.checkpoint_dir))

        # Update learning rates
        ########################
        self.g_lr_scheduler.step()
        self.d_lr_scheduler.step()
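# The training loops in this file feed their discriminators a history of
# previously generated images through utils.Sample_from_Pool (the image
# buffer from the CycleGAN paper). The class below is only a minimal sketch
# consistent with how the helper is called here: it receives a list of numpy
# arrays and returns a list of the same length, randomly swapping each item
# with one already stored in the pool. The real utils.Sample_from_Pool may
# differ; the pool size of 50 is an assumption.
import copy
import random


class Sample_from_Pool(object):
    def __init__(self, max_elements=50):
        self.max_elements = max_elements
        self.items = []

    def __call__(self, in_items):
        return_items = []
        for item in in_items:
            if len(self.items) < self.max_elements:
                # Fill the pool first, returning the new item unchanged
                self.items.append(copy.deepcopy(item))
                return_items.append(item)
            elif random.random() > 0.5:
                # With probability 0.5, return a stored item and keep the new one
                idx = random.randrange(self.max_elements)
                stored = copy.deepcopy(self.items[idx])
                self.items[idx] = copy.deepcopy(item)
                return_items.append(stored)
            else:
                return_items.append(item)
        return return_items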
def train(self, args):
    transform = get_transformation((args.crop_height, args.crop_width), resize=True, dataset=args.dataset)

    # Make the dataset choice configurable
    if self.args.dataset == 'voc2012':
        labeled_set = VOCDataset(root_path=root, name='label', ratio=0.2,
                                 transformation=transform, augmentation=None)
        unlabeled_set = VOCDataset(root_path=root, name='unlabel', ratio=0.2,
                                   transformation=transform, augmentation=None)
        val_set = VOCDataset(root_path=root, name='val', ratio=0.5,
                             transformation=transform, augmentation=None)
    elif self.args.dataset == 'cityscapes':
        labeled_set = CityscapesDataset(root_path=root_cityscapes, name='label', ratio=0.5,
                                        transformation=transform, augmentation=None)
        unlabeled_set = CityscapesDataset(root_path=root_cityscapes, name='unlabel', ratio=0.5,
                                          transformation=transform, augmentation=None)
        val_set = CityscapesDataset(root_path=root_cityscapes, name='val', ratio=0.5,
                                    transformation=transform, augmentation=None)
    elif self.args.dataset == 'acdc':
        labeled_set = ACDCDataset(root_path=root_acdc, name='label', ratio=0.5,
                                  transformation=transform, augmentation=None)
        unlabeled_set = ACDCDataset(root_path=root_acdc, name='unlabel', ratio=0.5,
                                    transformation=transform, augmentation=None)
        val_set = ACDCDataset(root_path=root_acdc, name='val', ratio=0.5,
                              transformation=transform, augmentation=None)

    '''
    https://discuss.pytorch.org/t/about-the-relation-between-batch-size-and-length-of-data-loader/10510
    drop_last=True keeps every batch the same size by dropping the final,
    smaller batch.
    '''
    labeled_loader = DataLoader(labeled_set, batch_size=args.batch_size, shuffle=True, drop_last=True)
    unlabeled_loader = DataLoader(unlabeled_set, batch_size=args.batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, drop_last=True)

    new_img_fake_sample = utils.Sample_from_Pool()
    img_fake_sample = utils.Sample_from_Pool()
    gt_fake_sample = utils.Sample_from_Pool()

    img_dis_loss, gt_dis_loss, unsupervisedloss, fullsupervisedloss = 0, 0, 0, 0

    # Variable to regulate the update frequency between discriminators and generators
    counter = 0

    for epoch in range(self.start_epoch, args.epochs):
        lr = self.g_optimizer.param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

        self.Gsi.train()
        self.Gis.train()

        # if (epoch + 1) % 10 == 0:
        #     args.lamda_img = args.lamda_img + 0.08
        #     args.lamda_gt = args.lamda_gt + 0.04

        for i, ((l_img, l_gt, _), (unl_img, _, _)) in enumerate(zip(labeled_loader, unlabeled_loader)):
            # step
            step = epoch * min(len(labeled_loader), len(unlabeled_loader)) + i + 1

            l_img, unl_img, l_gt = utils.cuda([l_img, unl_img, l_gt], args.gpu_ids)

            # Generator Computations
            ##################################################
            set_grad([self.Di, self.Ds, self.old_Di], False)
            set_grad([self.old_Gsi, self.old_Gis], False)
            self.g_optimizer.zero_grad()

            # Forward pass through generators
            ##################################################
            fake_img = self.Gis(make_one_hot(l_gt, args.dataset, args.gpu_ids).float())
            fake_gt = self.Gsi(unl_img.float())  # has 21 channels
            lab_gt = self.Gsi(l_img)  # has 21 channels

            # Resize the generator outputs to the correct dimensions
            fake_img = self.interp(fake_img)
            fake_gt = self.interp(fake_gt)
            lab_gt = self.interp(lab_gt)

            # fake_gt = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0)  # drops the channel dimension
            # fake_gt = fake_gt.unsqueeze(1)  # back to a single channel
            # fake_gt = make_one_hot(fake_gt, args.dataset, args.gpu_ids)

            lab_loss_CE = self.CE(lab_gt, l_gt.squeeze(1))

            # Apply the activations again
            lab_gt = self.activation_softmax(lab_gt)
            fake_gt = self.activation_softmax(fake_gt)
            # fake_gt = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0)
            # fake_gt = fake_gt.unsqueeze(1)
            # fake_gt = make_one_hot(fake_gt, args.dataset, args.gpu_ids)
            # fake_img = self.activation_tanh(fake_img)

            recon_img = self.Gis(fake_gt.float())
            recon_lab_img = self.Gis(lab_gt.float())
            recon_gt = self.Gsi(fake_img.float())

            # Resize the reconstructed outputs to the correct dimensions
            recon_img = self.interp(recon_img)
            recon_lab_img = self.interp(recon_lab_img)
            recon_gt = self.interp(recon_gt)

            # New loss between the reconstructed images from the ResNet and DeepLab networks
            resnet_fake_gt = self.old_Gsi(unl_img.float())
            resnet_lab_gt = self.old_Gsi(l_img)
            resnet_lab_gt = self.activation_softmax(resnet_lab_gt)
            resnet_fake_gt = self.activation_softmax(resnet_fake_gt)
            resnet_recon_img = self.old_Gis(resnet_fake_gt.float())
            resnet_recon_lab_img = self.old_Gis(resnet_lab_gt.float())

            # Apply the tanh activations
            # recon_img = self.activation_tanh(recon_img)
            # recon_lab_img = self.activation_tanh(recon_lab_img)

            # Adversarial losses
            ###################################################
            fake_img_dis = self.Di(fake_img)
            resnet_fake_img_dis = self.old_Di(recon_img)

            # Pass a different type of input to Ds
            fake_gt_discriminator = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0)
            fake_gt_discriminator = fake_gt_discriminator.unsqueeze(1)
            fake_gt_discriminator = make_one_hot(fake_gt_discriminator, args.dataset, args.gpu_ids)
            fake_gt_dis = self.Ds(fake_gt_discriminator.float())
            # lab_gt_dis = self.Ds(lab_gt)

            real_label_gt = utils.cuda(Variable(torch.ones(fake_gt_dis.size())), args.gpu_ids)
            real_label_img = utils.cuda(Variable(torch.ones(fake_img_dis.size())), args.gpu_ids)

            # A cross-entropy loss would be better suited for classification here.
            img_gen_loss = self.MSE(fake_img_dis, real_label_img)
            gt_gen_loss = self.MSE(fake_gt_dis, real_label_gt)
            # gt_label_gen_loss = self.MSE(lab_gt_dis, real_label)

            # Cycle consistency losses
            ###################################################
            resnet_img_cycle_loss = self.MSE(resnet_fake_img_dis, real_label_img)
            # img_cycle_loss = self.L1(recon_img, unl_img)
            # img_cycle_loss_perceptual = perceptual_loss(recon_img, unl_img, args.gpu_ids)
            gt_cycle_loss = self.CE(recon_gt, l_gt.squeeze(1))
            # lab_img_cycle_loss = self.L1(recon_lab_img, l_img) * args.lamda

            # Total generators losses
            ###################################################
            # lab_loss_CE = self.CE(lab_gt, l_gt.squeeze(1))
            lab_loss_MSE = self.L1(fake_img, l_img)
            # lab_loss_perceptual = perceptual_loss(fake_img, l_img, args.gpu_ids)

            fullsupervisedloss = args.lab_CE_weight * lab_loss_CE + args.lab_MSE_weight * lab_loss_MSE
            unsupervisedloss = args.adversarial_weight * (img_gen_loss + gt_gen_loss) + \
                resnet_img_cycle_loss + gt_cycle_loss * args.lamda_gt

            gen_loss = fullsupervisedloss + unsupervisedloss

            # Update generators
            ###################################################
            gen_loss.backward()
            self.g_optimizer.step()

            if counter % 1 == 0:
                # Discriminator Computations
                #################################################
                set_grad([self.Di, self.Ds, self.old_Di], True)
                self.d_optimizer.zero_grad()

                # Sample from history of generated images
                #################################################
                if torch.rand(1) < 0.0:
                    fake_img = self.gauss_noise(fake_img.cpu())
                    fake_gt = self.gauss_noise(fake_gt.cpu())

                recon_img = Variable(torch.Tensor(new_img_fake_sample([recon_img.cpu().data.numpy()])[0]))
                fake_img = Variable(torch.Tensor(img_fake_sample([fake_img.cpu().data.numpy()])[0]))
                # lab_gt = Variable(torch.Tensor(gt_fake_sample([lab_gt.cpu().data.numpy()])[0]))
                fake_gt = Variable(torch.Tensor(gt_fake_sample([fake_gt.cpu().data.numpy()])[0]))

                recon_img, fake_img, fake_gt = utils.cuda([recon_img, fake_img, fake_gt], args.gpu_ids)

                # Forward pass through discriminators
                #################################################
                unl_img_dis = self.Di(unl_img)
                fake_img_dis = self.Di(fake_img)
                resnet_recon_img_dis = self.old_Di(resnet_recon_img)
                resnet_fake_img_dis = self.old_Di(recon_img)
                # lab_gt_dis = self.Ds(lab_gt)

                l_gt = make_one_hot(l_gt, args.dataset, args.gpu_ids)
                real_gt_dis = self.Ds(l_gt.float())

                fake_gt_discriminator = fake_gt.data.max(1)[1].squeeze_(1).squeeze_(0)
                fake_gt_discriminator = fake_gt_discriminator.unsqueeze(1)
                fake_gt_discriminator = make_one_hot(fake_gt_discriminator, args.dataset, args.gpu_ids)
                fake_gt_dis = self.Ds(fake_gt_discriminator.float())

                real_label_img = utils.cuda(Variable(torch.ones(unl_img_dis.size())), args.gpu_ids)
                fake_label_img = utils.cuda(Variable(torch.zeros(fake_img_dis.size())), args.gpu_ids)
                real_label_gt = utils.cuda(Variable(torch.ones(real_gt_dis.size())), args.gpu_ids)
                fake_label_gt = utils.cuda(Variable(torch.zeros(fake_gt_dis.size())), args.gpu_ids)

                # Discriminator losses
                ##################################################
                img_dis_real_loss = self.MSE(unl_img_dis, real_label_img)
                img_dis_fake_loss = self.MSE(fake_img_dis, fake_label_img)
                gt_dis_real_loss = self.MSE(real_gt_dis, real_label_gt)
                gt_dis_fake_loss = self.MSE(fake_gt_dis, fake_label_gt)
                # lab_gt_dis_fake_loss = self.MSE(lab_gt_dis, fake_label)
                cycle_img_dis_real_loss = self.MSE(resnet_recon_img_dis, real_label_img)
                cycle_img_dis_fake_loss = self.MSE(resnet_fake_img_dis, fake_label_img)

                # Total discriminators losses
                img_dis_loss = (img_dis_real_loss + img_dis_fake_loss) * 0.5
                gt_dis_loss = (gt_dis_real_loss + gt_dis_fake_loss) * 0.5
                # lab_gt_dis_loss = (gt_dis_real_loss + lab_gt_dis_fake_loss) * 0.33
                cycle_img_dis_loss = cycle_img_dis_real_loss + cycle_img_dis_fake_loss

                # Update discriminators
                ##################################################
                discriminator_loss = args.discriminator_weight * (img_dis_loss + gt_dis_loss) + cycle_img_dis_loss
                discriminator_loss.backward()
                # lab_gt_dis_loss.backward()
                self.d_optimizer.step()

            print("Epoch: (%3d) (%5d/%5d) | Dis Loss:%.2e | Unlab Gen Loss:%.2e | Lab Gen loss:%.2e" %
                  (epoch, i + 1, min(len(labeled_loader), len(unlabeled_loader)),
                   img_dis_loss + gt_dis_loss, unsupervisedloss, fullsupervisedloss))

            self.writer_semisuper.add_scalars('Dis Loss', {
                'img_dis_loss': img_dis_loss,
                'gt_dis_loss': gt_dis_loss,
                'cycle_img_dis_loss': cycle_img_dis_loss
            }, len(labeled_loader) * epoch + i)
            self.writer_semisuper.add_scalars('Unlabelled Loss', {
                'img_gen_loss': img_gen_loss,
                'gt_gen_loss': gt_gen_loss,
                'img_cycle_loss': resnet_img_cycle_loss,
                'gt_cycle_loss': gt_cycle_loss
            }, len(labeled_loader) * epoch + i)
            self.writer_semisuper.add_scalars('Labelled Loss', {
                'lab_loss_CE': lab_loss_CE,
                'lab_loss_MSE': lab_loss_MSE
            }, len(labeled_loader) * epoch + i)

            counter += 1

        # Compute the mean IoU on the validation set
        self.Gsi.eval()
        self.Gis.eval()
        with torch.no_grad():
            for i, (val_img, val_gt, _) in enumerate(val_loader):
                val_img, val_gt = utils.cuda([val_img, val_gt], args.gpu_ids)
                outputs = self.Gsi(val_img)
                outputs = self.interp(outputs)
                outputs = self.activation_softmax(outputs)
                pred = outputs.data.max(1)[1].cpu().numpy()
                gt = val_gt.squeeze().data.cpu().numpy()
                self.running_metrics_val.update(gt, pred)

        score, class_iou = self.running_metrics_val.get_scores()
        self.running_metrics_val.reset()
        print('The mIoU for the epoch is: ', score["Mean IoU : \t"])

        # Display the generator outputs on TensorBoard using validation images
        val_image, val_gt, _ = next(iter(val_loader))
        val_image, val_gt = utils.cuda([val_image, val_gt], args.gpu_ids)
        with torch.no_grad():
            fake_label = self.Gsi(val_image).detach()
            fake_label = self.interp(fake_label)
            fake_label = self.activation_softmax(fake_label)
            fake_label = fake_label.data.max(1)[1].squeeze_(1).squeeze_(0)
            fake_label = fake_label.unsqueeze(1)
            fake_label = make_one_hot(fake_label, args.dataset, args.gpu_ids)
            fake_img = self.Gis(fake_label).detach()
            fake_img = self.interp(fake_img)
            # fake_img = self.activation_tanh(fake_img)

            fake_img_from_labels = self.Gis(make_one_hot(val_gt, args.dataset, args.gpu_ids).float()).detach()
            fake_img_from_labels = self.interp(fake_img_from_labels)
            # fake_img_from_labels = self.activation_tanh(fake_img_from_labels)
            fake_label_regenerated = self.Gsi(fake_img_from_labels).detach()
            fake_label_regenerated = self.interp(fake_label_regenerated)
            fake_label_regenerated = self.activation_softmax(fake_label_regenerated)

        fake_prediction_label = fake_label.data.max(1)[1].squeeze_(1).cpu().numpy()
        fake_regenerated_label = fake_label_regenerated.data.max(1)[1].squeeze_(1).cpu().numpy()
        val_gt = val_gt.cpu()
        fake_img = fake_img.cpu()
        fake_img_from_labels = fake_img_from_labels.cpu()

        # Revert the normalization applied to these images
        if self.args.dataset == 'voc2012' or self.args.dataset == 'cityscapes':
            trans_mean = [0.5, 0.5, 0.5]
            trans_std = [0.5, 0.5, 0.5]
            for i in range(3):
                fake_img[:, i, :, :] = (fake_img[:, i, :, :] * trans_std[i]) + trans_mean[i]
                fake_img_from_labels[:, i, :, :] = (fake_img_from_labels[:, i, :, :] * trans_std[i]) + trans_mean[i]
        elif self.args.dataset == 'acdc':
            trans_mean = [0.5]
            trans_std = [0.5]
            for i in range(1):
                fake_img[:, i, :, :] = (fake_img[:, i, :, :] * trans_std[i]) + trans_mean[i]
                fake_img_from_labels[:, i, :, :] = (fake_img_from_labels[:, i, :, :] * trans_std[i]) + trans_mean[i]

        # display_tensor_* are the final tensors that will be displayed on TensorBoard
        display_tensor_label = torch.zeros([fake_label.shape[0], 3, fake_label.shape[2], fake_label.shape[3]])
        display_tensor_gt = torch.zeros([val_gt.shape[0], 3, val_gt.shape[2], val_gt.shape[3]])
        display_tensor_regen_label = torch.zeros([fake_label_regenerated.shape[0], 3,
                                                  fake_label_regenerated.shape[2],
                                                  fake_label_regenerated.shape[3]])
        for i in range(fake_prediction_label.shape[0]):
            new_img_label = fake_prediction_label[i]
            new_img_label = utils.colorize_mask(new_img_label, self.args.dataset)  # generated mask as a PIL.Image
            img_tensor_label = utils.PIL_to_tensor(new_img_label, self.args.dataset)
            display_tensor_label[i, :, :, :] = img_tensor_label
            display_tensor_gt[i, :, :, :] = val_gt[i]

            regen_label = fake_regenerated_label[i]
            regen_label = utils.colorize_mask(regen_label, self.args.dataset)
            regen_tensor_label = utils.PIL_to_tensor(regen_label, self.args.dataset)
            display_tensor_regen_label[i, :, :, :] = regen_tensor_label

        self.writer_semisuper.add_image(
            'Generated segmented image: ',
            torchvision.utils.make_grid(display_tensor_label, nrow=2, normalize=True), epoch)
        self.writer_semisuper.add_image(
            'Generated image back from segmentation: ',
            torchvision.utils.make_grid(fake_img, nrow=2, normalize=True), epoch)
        self.writer_semisuper.add_image(
            'Ground truth for the image: ',
            torchvision.utils.make_grid(display_tensor_gt, nrow=2, normalize=True), epoch)
        self.writer_semisuper.add_image(
            'Image generated from val labels: ',
            torchvision.utils.make_grid(fake_img_from_labels, nrow=2, normalize=True), epoch)
        self.writer_semisuper.add_image(
            'Labels generated back from the cycle: ',
            torchvision.utils.make_grid(display_tensor_regen_label, nrow=2, normalize=True), epoch)

        if score["Mean IoU : \t"] >= self.best_iou:
            self.best_iou = score["Mean IoU : \t"]

            # Override the latest checkpoint
            #######################################################
            utils.save_checkpoint({
                'epoch': epoch + 1,
                'Di': self.Di.state_dict(),
                'Ds': self.Ds.state_dict(),
                'Gis': self.Gis.state_dict(),
                'Gsi': self.Gsi.state_dict(),
                'd_optimizer': self.d_optimizer.state_dict(),
                'g_optimizer': self.g_optimizer.state_dict(),
                'best_iou': self.best_iou,
                'class_iou': class_iou
            }, '%s/latest_semisuper_cycleGAN.ckpt' % (args.checkpoint_dir))

        # Update learning rates
        ########################
        self.g_lr_scheduler.step()
        self.d_lr_scheduler.step()

    self.writer_semisuper.close()
def train(self, args):
    # Test input
    transform_test = transforms.Compose([
        transforms.Resize((args.crop_height, args.crop_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset_dirs = utils.get_testdata_link(args.dataset_dir)

    a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform_test)
    b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform_test)

    a_test_loader = torch.utils.data.DataLoader(a_test_data, batch_size=args.batch_size,
                                                shuffle=True, num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data, batch_size=args.batch_size,
                                                shuffle=True, num_workers=4)

    # For transforming the input image
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Resize((args.load_height, args.load_width)),
        transforms.RandomCrop((args.crop_height, args.crop_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset_dirs = utils.get_traindata_link(args.dataset_dir)

    # Pytorch dataloaders
    a_loader = torch.utils.data.DataLoader(
        dsets.ImageFolder(dataset_dirs['trainA'], transform=transform),
        batch_size=args.batch_size, shuffle=True, num_workers=4)
    b_loader = torch.utils.data.DataLoader(
        dsets.ImageFolder(dataset_dirs['trainB'], transform=transform),
        batch_size=args.batch_size, shuffle=True, num_workers=4)

    a_fake_sample = utils.Sample_from_Pool()
    b_fake_sample = utils.Sample_from_Pool()

    for epoch in range(self.start_epoch, args.epochs):
        lr = self.g_optimizer.param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

        for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
            # step
            step = epoch * min(len(a_loader), len(b_loader)) + i + 1

            # Generator Computations
            ##################################################
            set_grad([self.Da, self.Db], False)
            self.g_optimizer.zero_grad()

            a_real = Variable(a_real[0])
            b_real = Variable(b_real[0])
            a_real, b_real = utils.cuda([a_real, b_real])

            # Forward pass through generators
            ##################################################
            a_fake = self.Gab(b_real)
            b_fake = self.Gba(a_real)
            a_recon = self.Gab(b_fake)
            b_recon = self.Gba(a_fake)
            a_idt = self.Gab(a_real)
            b_idt = self.Gba(b_real)

            # Identity losses
            ###################################################
            a_idt_loss = self.L1(a_idt, a_real) * args.lamda * args.idt_coef
            b_idt_loss = self.L1(b_idt, b_real) * args.lamda * args.idt_coef

            # Adversarial losses
            ###################################################
            a_fake_dis = self.Da(a_fake)
            b_fake_dis = self.Db(b_fake)
            real_label = utils.cuda(Variable(torch.ones(a_fake_dis.size())))
            a_gen_loss = self.MSE(a_fake_dis, real_label)
            b_gen_loss = self.MSE(b_fake_dis, real_label)

            # Cycle consistency losses
            ###################################################
            a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
            b_cycle_loss = self.L1(b_recon, b_real) * args.lamda

            # Total generators losses
            ###################################################
            gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

            # Update generators
            ###################################################
            gen_loss.backward()
            self.g_optimizer.step()

            # Discriminator Computations
            #################################################
            set_grad([self.Da, self.Db], True)
            self.d_optimizer.zero_grad()

            # Sample from history of generated images
            #################################################
            a_fake = Variable(torch.Tensor(a_fake_sample([a_fake.cpu().data.numpy()])[0]))
            b_fake = Variable(torch.Tensor(b_fake_sample([b_fake.cpu().data.numpy()])[0]))
            a_fake, b_fake = utils.cuda([a_fake, b_fake])

            # Forward pass through discriminators
            #################################################
            a_real_dis = self.Da(a_real)
            a_fake_dis = self.Da(a_fake)
            b_real_dis = self.Db(b_real)
            b_fake_dis = self.Db(b_fake)
            real_label = utils.cuda(Variable(torch.ones(a_real_dis.size())))
            fake_label = utils.cuda(Variable(torch.zeros(a_fake_dis.size())))

            # Discriminator losses
            ##################################################
            a_dis_real_loss = self.MSE(a_real_dis, real_label)
            a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
            b_dis_real_loss = self.MSE(b_real_dis, real_label)
            b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

            # Total discriminators losses
            a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
            b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

            # Update discriminators
            ##################################################
            a_dis_loss.backward()
            b_dis_loss.backward()
            self.d_optimizer.step()

            print("Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
                  (epoch, i + 1, min(len(a_loader), len(b_loader)), gen_loss, a_dis_loss + b_dis_loss))

        # Override the latest checkpoint
        #######################################################
        utils.save_checkpoint({
            'epoch': epoch + 1,
            'Da': self.Da.state_dict(),
            'Db': self.Db.state_dict(),
            'Gab': self.Gab.state_dict(),
            'Gba': self.Gba.state_dict(),
            'd_optimizer': self.d_optimizer.state_dict(),
            'g_optimizer': self.g_optimizer.state_dict()
        }, '%s/latest.ckpt' % (args.checkpoint_dir))

        # Save sample images for the current epoch
        #######################################################
        a_real_test = Variable(next(iter(a_test_loader))[0], requires_grad=True)
        b_real_test = Variable(next(iter(b_test_loader))[0], requires_grad=True)
        a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

        self.Gab.eval()
        self.Gba.eval()

        with torch.no_grad():
            a_fake_test = self.Gab(b_real_test)
            b_fake_test = self.Gba(a_real_test)
            a_recon_test = self.Gab(b_fake_test)
            b_recon_test = self.Gba(a_fake_test)

        pic = (torch.cat([a_real_test, b_fake_test, a_recon_test,
                          b_real_test, a_fake_test, b_recon_test], dim=0).data + 1) / 2.0

        if not os.path.isdir(args.results_dir):
            os.makedirs(args.results_dir)

        torchvision.utils.save_image(pic, args.results_dir + '/sample_{}.jpg'.format(epoch), nrow=3)

        self.Gab.train()
        self.Gba.train()

        # Update learning rates
        ########################
        self.g_lr_scheduler.step()
        self.d_lr_scheduler.step()
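# Each variant above freezes and unfreezes the discriminators through
# set_grad(...) before the generator and discriminator updates. The helper
# below is a minimal sketch under the assumption that it simply flips
# requires_grad on all parameters of the given modules; the repository's
# actual implementation may differ.
def set_grad(nets, requires_grad=False):
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad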
def train(self, args):
    # For transforming the input image
    transform = transforms.Compose([
        # transforms.RandomHorizontalFlip(),
        transforms.Resize((480, 1440)),
        # transforms.RandomCrop((args.crop_height, args.crop_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset_dirs = utils.get_traindata_link(args.dataset_dir)

    # Pytorch dataloader over side-by-side merged image pairs
    dataset = torch.utils.data.DataLoader(
        dsets.ImageFolder('/train_merged/', transform=transform),
        batch_size=args.batch_size, shuffle=True, num_workers=4)
    # a_loader = torch.utils.data.DataLoader(dsets.ImageFolder(dataset_dirs['trainA'], transform=transform),
    #                                        batch_size=args.batch_size, shuffle=True, num_workers=4)
    # b_loader = torch.utils.data.DataLoader(dsets.ImageFolder(dataset_dirs['trainB'], transform=transform),
    #                                        batch_size=args.batch_size, shuffle=True, num_workers=4)

    a_fake_sample = utils.Sample_from_Pool()
    b_fake_sample = utils.Sample_from_Pool()

    for epoch in range(self.start_epoch, args.epochs):
        lr = self.g_optimizer.param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

        for i, x in enumerate(dataset):
            # step
            step = epoch * len(dataset) + i + 1

            # Generator Computations
            ##################################################
            set_grad([self.Da, self.Db], False)
            self.g_optimizer.zero_grad()

            x = Variable(x[0])
            x = utils.cuda([x])[0]
            shape = x.shape
            # Split the merged image into its left (domain A) and right (domain B) halves
            a_real, b_real = x[:, :, :, :shape[3] // 2], x[:, :, :, shape[3] // 2:]

            # Forward pass through generators
            ##################################################
            a_fake = self.Gab(b_real)
            b_fake = self.Gba(a_real)
            a_recon = self.Gab(b_fake)
            b_recon = self.Gba(a_fake)
            a_idt = self.Gab(a_real)
            b_idt = self.Gba(b_real)

            # Identity losses
            ###################################################
            a_idt_loss = self.L1(a_idt, a_real) * args.lamda * args.idt_coef
            b_idt_loss = self.L1(b_idt, b_real) * args.lamda * args.idt_coef

            # Adversarial losses
            ###################################################
            a_fake_dis = self.Da(a_fake)
            b_fake_dis = self.Db(b_fake)
            real_label = utils.cuda(Variable(torch.ones(a_fake_dis.size())))
            a_gen_loss = self.MSE(a_fake_dis, real_label)
            b_gen_loss = self.MSE(b_fake_dis, real_label)

            # Cycle consistency losses
            ###################################################
            a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
            b_cycle_loss = self.L1(b_recon, b_real) * args.lamda

            # Total generators losses
            ###################################################
            gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

            # Update generators
            ###################################################
            gen_loss.backward()
            self.g_optimizer.step()

            # Discriminator Computations
            #################################################
            set_grad([self.Da, self.Db], True)
            self.d_optimizer.zero_grad()

            # Sample from history of generated images
            #################################################
            a_fake = Variable(torch.Tensor(a_fake_sample([a_fake.cpu().data.numpy()])[0]))
            b_fake = Variable(torch.Tensor(b_fake_sample([b_fake.cpu().data.numpy()])[0]))
            a_fake, b_fake = utils.cuda([a_fake, b_fake])

            # Forward pass through discriminators
            #################################################
            a_real_dis = self.Da(a_real)
            a_fake_dis = self.Da(a_fake)
            b_real_dis = self.Db(b_real)
            b_fake_dis = self.Db(b_fake)
            real_label = utils.cuda(Variable(torch.ones(a_real_dis.size())))
            fake_label = utils.cuda(Variable(torch.zeros(a_fake_dis.size())))

            # Discriminator losses
            ##################################################
            a_dis_real_loss = self.MSE(a_real_dis, real_label)
            a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
            b_dis_real_loss = self.MSE(b_real_dis, real_label)
            b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

            # Total discriminators losses
            a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
            b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

            # Update discriminators
            ##################################################
            a_dis_loss.backward()
            b_dis_loss.backward()
            self.d_optimizer.step()

            print("Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
                  (epoch, i + 1, len(dataset), gen_loss, a_dis_loss + b_dis_loss))

        # Override the latest checkpoint
        #######################################################
        utils.save_checkpoint({
            'epoch': epoch + 1,
            'Da': self.Da.state_dict(),
            'Db': self.Db.state_dict(),
            'Gab': self.Gab.state_dict(),
            'Gba': self.Gba.state_dict(),
            'd_optimizer': self.d_optimizer.state_dict(),
            'g_optimizer': self.g_optimizer.state_dict()
        }, '%s/latest.ckpt' % (args.checkpoint_dir))

        # Update learning rates
        ########################
        self.g_lr_scheduler.step()
        self.d_lr_scheduler.step()
def train(self, args):
    # For transforming the input image
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        # transforms.Resize((args.load_height, args.load_width)),
        transforms.RandomCrop((args.crop_height, args.crop_width)),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset_dirs = utils.get_traindata_link(args.dataset_dir)
    dataset_a = ListDataSet('/media/l/新加卷1/city/data/river/train_256_9w.lst', transform=transform)
    dataset_b = ListDataSet('/media/l/新加卷/city/jinan_z3.lst', transform=transform)

    # Pytorch dataloaders
    a_loader = torch.utils.data.DataLoader(dataset_a, batch_size=args.batch_size,
                                           shuffle=True, num_workers=4)
    b_loader = torch.utils.data.DataLoader(dataset_b, batch_size=args.batch_size,
                                           shuffle=True, num_workers=4)

    a_fake_sample = utils.Sample_from_Pool()
    b_fake_sample = utils.Sample_from_Pool()

    for epoch in range(self.start_epoch, args.epochs):
        lr = self.g_optimizer.param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

        for i, (a_real, b_real) in enumerate(zip(a_loader, b_loader)):
            # step
            step = epoch * min(len(a_loader), len(b_loader)) + i + 1

            # Generator Computations
            ##################################################
            set_grad([self.Da, self.Db], False)
            self.g_optimizer.zero_grad()

            a_real = Variable(a_real[0])
            b_real = Variable(b_real[0])
            a_real, b_real = utils.cuda([a_real, b_real])

            # Forward pass through generators
            ##################################################
            a_fake = self.Gab(b_real)
            b_fake = self.Gba(a_real)
            a_recon = self.Gab(b_fake)
            b_recon = self.Gba(a_fake)
            a_idt = self.Gab(a_real)
            b_idt = self.Gba(b_real)

            # Identity losses
            ###################################################
            a_idt_loss = self.L1(a_idt, a_real) * args.lamda * args.idt_coef
            b_idt_loss = self.L1(b_idt, b_real) * args.lamda * args.idt_coef

            # Adversarial losses
            ###################################################
            a_fake_dis = self.Da(a_fake)
            b_fake_dis = self.Db(b_fake)
            real_label = utils.cuda(Variable(torch.ones(a_fake_dis.size())))
            a_gen_loss = self.MSE(a_fake_dis, real_label)
            b_gen_loss = self.MSE(b_fake_dis, real_label)

            # Cycle consistency losses
            ###################################################
            a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
            b_cycle_loss = self.L1(b_recon, b_real) * args.lamda

            # Total generators losses
            ###################################################
            gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

            # Update generators
            ###################################################
            gen_loss.backward()
            self.g_optimizer.step()

            # Discriminator Computations
            #################################################
            set_grad([self.Da, self.Db], True)
            self.d_optimizer.zero_grad()

            # Sample from history of generated images
            #################################################
            a_fake = Variable(torch.Tensor(a_fake_sample([a_fake.cpu().data.numpy()])[0]))
            b_fake = Variable(torch.Tensor(b_fake_sample([b_fake.cpu().data.numpy()])[0]))
            a_fake, b_fake = utils.cuda([a_fake, b_fake])

            # Forward pass through discriminators
            #################################################
            a_real_dis = self.Da(a_real)
            a_fake_dis = self.Da(a_fake)
            b_real_dis = self.Db(b_real)
            b_fake_dis = self.Db(b_fake)
            real_label = utils.cuda(Variable(torch.ones(a_real_dis.size())))
            fake_label = utils.cuda(Variable(torch.zeros(a_fake_dis.size())))

            # Discriminator losses
            ##################################################
            a_dis_real_loss = self.MSE(a_real_dis, real_label)
            a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
            b_dis_real_loss = self.MSE(b_real_dis, real_label)
            b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

            # Total discriminators losses
            a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
            b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

            # Update discriminators
            ##################################################
            a_dis_loss.backward()
            b_dis_loss.backward()
            self.d_optimizer.step()

            print("Epoch: (%3d) (%5d/%5d) | Gen Loss:%.4f | Dis Loss:%.4f" %
                  (epoch, i + 1, min(len(a_loader), len(b_loader)), gen_loss, a_dis_loss + b_dis_loss))

        # Override the latest checkpoint
        #######################################################
        utils.save_checkpoint({
            'epoch': epoch + 1,
            'Da': self.Da.state_dict(),
            'Db': self.Db.state_dict(),
            'Gab': self.Gab.state_dict(),
            'Gba': self.Gba.state_dict(),
            'd_optimizer': self.d_optimizer.state_dict(),
            'g_optimizer': self.g_optimizer.state_dict(),
        }, '%s/latest.ckpt' % (args.checkpoint_dir))

        # Update learning rates
        ########################
        self.g_lr_scheduler.step()
        self.d_lr_scheduler.step()
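# All of the loops move tensors to the GPU through utils.cuda, which is
# called both with a plain list (utils.cuda([a_real, b_real])) and with an
# explicit gpu_ids argument in the semi-supervised variant. The function
# below is only a minimal sketch under those assumptions; the real helper
# may handle device placement differently.
import torch


def cuda(xs, gpu_ids=None):
    # Fall back to CPU tensors when no GPU is available
    if not torch.cuda.is_available():
        return xs
    device = 'cuda' if not gpu_ids else 'cuda:{}'.format(gpu_ids[0])
    if isinstance(xs, (list, tuple)):
        return [x.to(device) for x in xs]
    return xs.to(device)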