def test(args):
    """Translate a few random batches between domains A and B and save one grid.

    Three times, a fresh shuffled batch is drawn from each list dataset,
    translated A->B with ``Gba`` and B->A with ``Gab``, and the tensors
    ``[a_real, b_fake, b_real, a_fake]`` are collected. Everything collected is
    written to ``<args.results_dir>/sample.png`` with two images per row.
    If the checkpoint cannot be loaded, the untrained generators are used.

    Args:
        args: parsed options; reads ``crop_height``, ``crop_width``,
            ``dataset_dir``, ``batch_size``, ``ngf``, ``norm``,
            ``no_dropout``, ``gpu_ids`` and ``results_dir``.
    """
    # NOTE(review): Normalize is disabled here, so real images stay in [0, 1]
    # while the `(x + 1) / 2` rescale below assumes [-1, 1]; confirm which
    # range is intended before trusting the saved real images.
    transform = transforms.Compose([
        transforms.Resize((args.crop_height, args.crop_width)),
        transforms.ToTensor(),
    ])
    # The returned directory map is unused here; the call is kept in case the
    # helper has side effects — TODO confirm and drop if it does not.
    utils.get_testdata_link(args.dataset_dir)

    # NOTE(review): hard-coded, machine-specific list files; the two paths use
    # different volume names (with and without a trailing '1') — confirm.
    a_test_data = ListDataSet('/media/l/新加卷1/city/data/river/train_256_9w.lst', transform=transform)
    b_test_data = ListDataSet('/media/l/新加卷/city/jinan_z3.lst', transform=transform)
    a_test_loader = torch.utils.data.DataLoader(a_test_data, batch_size=args.batch_size,
                                                shuffle=True, num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data, batch_size=args.batch_size,
                                                shuffle=True, num_workers=4)

    gen_kwargs = dict(input_nc=3, output_nc=3, ngf=args.ngf, netG='resnet_9blocks',
                      norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    Gab = define_Gen(**gen_kwargs)
    Gba = define_Gen(**gen_kwargs)
    utils.print_networks([Gab, Gba], ['Gab', 'Gba'])

    try:
        ckpt = utils.load_checkpoint(
            '/media/l/新加卷/city/project/cycleGAN-PyTorch/checkpoints/horse2zebra/latest.ckpt')
        Gab.load_state_dict(ckpt['Gab'])
        Gba.load_state_dict(ckpt['Gba'])
    except Exception as err:
        # Best effort: continue with the untrained generators, but say why.
        print(' [*] No checkpoint!')
        print(' [*] Reason: %s' % err)

    Gab.eval()
    Gba.eval()
    res = []
    for _ in range(3):
        # A fresh iterator per round yields a new shuffled batch. Newer PyTorch
        # DataLoader iterators have no `.next()` alias, so use builtin next().
        a_real_test = next(iter(a_test_loader))[0]
        b_real_test = next(iter(b_test_loader))[0]
        a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])
        with torch.no_grad():
            a_fake_test = Gab(b_real_test)
            b_fake_test = Gba(a_real_test)
        res.extend([a_real_test, b_fake_test, b_real_test, a_fake_test])

    pic = (torch.cat(res, dim=0).data + 1) / 2.0
    if not os.path.isdir(args.results_dir):
        os.makedirs(args.results_dir)
    torchvision.utils.save_image(pic, args.results_dir + '/sample.png', nrow=2)
def test(args, epoch):
    """Translate one random test batch each way, score it with SSIM and save a grid.

    A batch from ``testA`` and one from ``testB`` are translated (A->B with
    ``Gba``, B->A with ``Gab``) and reconstructed. The grid
    ``[a_real, b_fake, a_recon, b_real, a_fake, b_recon]`` is saved under
    ``args.results_path`` as ``sample_<epoch>_<score_ab>_<score_ba>.jpg``,
    one batch per row, where the scores come from the SSIM module.

    Args:
        args: parsed options (dataset, network, checkpoint and result paths).
        epoch: value embedded in the output file name.

    Raises:
        Whatever ``utils.load_checkpoint`` / ``load_state_dict`` raise when the
        checkpoint is missing or incompatible (there is no fallback here).
    """
    transform = transforms.Compose([
        transforms.Resize((args.crop_height, args.crop_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    dataset_dirs = utils.get_testdata_link(args.dataset_dir)
    a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform)
    b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform)
    a_test_loader = torch.utils.data.DataLoader(a_test_data, batch_size=args.batch_size,
                                                shuffle=True, num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data, batch_size=args.batch_size,
                                                shuffle=True, num_workers=4)

    gen_kwargs = dict(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm,
                      use_dropout=args.use_dropout, gpu_ids=args.gpu_ids,
                      self_attn=args.self_attn, spectral=args.spectral)
    Gab = define_Gen(**gen_kwargs)
    Gba = define_Gen(**gen_kwargs)
    utils.print_networks([Gab, Gba], ['Gab', 'Gba'])

    ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_path))
    Gab.load_state_dict(ckpt['Gab'])
    Gba.load_state_dict(ckpt['Gba'])

    # Builtin next(): newer PyTorch DataLoader iterators have no `.next()` alias.
    a_real_test = next(iter(a_test_loader))[0]
    b_real_test = next(iter(b_test_loader))[0]
    a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

    Gab.eval()
    Gba.eval()
    with torch.no_grad():
        a_fake_test = Gab(b_real_test)
        b_fake_test = Gba(a_real_test)
        a_recon_test = Gab(b_fake_test)
        b_recon_test = Gba(a_fake_test)

        # SSIM on grayscale versions rescaled from [-1, 1] to [0, 1].
        gray = kornia.color.RgbToGrayscale()
        m = kornia.losses.SSIM(11, 'mean')
        ba_ssim = m(gray((a_real_test + 1) / 2.0), gray((b_fake_test + 1) / 2.0))
        ab_ssim = m(gray((b_real_test + 1) / 2.0), gray((a_fake_test + 1) / 2.0))

    pic = (torch.cat([a_real_test, b_fake_test, a_recon_test,
                      b_real_test, a_fake_test, b_recon_test], dim=0).data + 1) / 2.0
    if not os.path.isdir(args.results_path):
        os.makedirs(args.results_path)

    # NOTE(review): assumes kornia's SSIM module returns the loss (1 - SSIM) / 2
    # (as older kornia releases did), so 1 - 2 * loss recovers SSIM — confirm.
    # The score is formatted after the transform: rounding the loss first and
    # then computing 1 - 2*x can leave binary float noise in the file name.
    ba_score = '%.4f' % (1 - 2 * ba_ssim.item())
    ab_score = '%.4f' % (1 - 2 * ab_ssim.item())
    file_name = '/sample_' + str(epoch) + '_' + ba_score + '_' + ab_score + '.jpg'
    torchvision.utils.save_image(pic, args.results_path + file_name, nrow=args.batch_size)
def test(args):
    """Translate one random test batch each way and save it with reconstructions.

    The grid ``[a_real, b_fake, a_recon, b_real, a_fake, b_recon]`` is written
    to ``<args.results_dir>/sample.jpg`` three images per row. If the
    checkpoint at ``args.checkpoint`` cannot be loaded, the untrained
    generators are used.

    Args:
        args: parsed options (dataset, network, checkpoint and result paths).
    """
    transform = transforms.Compose([
        transforms.Resize((args.crop_height, args.crop_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    dataset_dirs = utils.get_testdata_link(args.dataset_dir)
    a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform)
    b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform)
    a_test_loader = torch.utils.data.DataLoader(a_test_data, batch_size=args.batch_size,
                                                shuffle=True, num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data, batch_size=args.batch_size,
                                                shuffle=True, num_workers=4)

    gen_kwargs = dict(input_nc=3, output_nc=3, ngf=args.ngf, netG='resnet_9blocks',
                      norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    Gab = define_Gen(**gen_kwargs)
    Gba = define_Gen(**gen_kwargs)
    utils.print_networks([Gab, Gba], ['Gab', 'Gba'])

    try:
        ckpt = utils.load_checkpoint(str(args.checkpoint))
        Gab.load_state_dict(ckpt['Gab'])
        Gba.load_state_dict(ckpt['Gba'])
    except Exception as err:
        # Best effort: continue with the untrained generators, but say why.
        print(' [*] No checkpoint!')
        print(' [*] Reason: %s' % err)

    # Builtin next(): newer PyTorch DataLoader iterators have no `.next()` alias.
    a_real_test = next(iter(a_test_loader))[0]
    b_real_test = next(iter(b_test_loader))[0]
    a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

    Gab.eval()
    Gba.eval()
    with torch.no_grad():
        a_fake_test = Gab(b_real_test)
        b_fake_test = Gba(a_real_test)
        a_recon_test = Gab(b_fake_test)
        b_recon_test = Gba(a_fake_test)

    # Map generator/normalised outputs from [-1, 1] back to [0, 1] for saving.
    pic = (torch.cat([a_real_test, b_fake_test, a_recon_test,
                      b_real_test, a_fake_test, b_recon_test], dim=0).data + 1) / 2.0
    if not os.path.isdir(args.results_dir):
        os.makedirs(args.results_dir)
    torchvision.utils.save_image(pic, args.results_dir + '/sample.jpg', nrow=3)
def define_load_gen(args, one_direction=True, gen_name='Gba'):
    """Build generator(s) from ``args``, load them from the latest checkpoint, set eval mode.

    Args:
        args: parsed options; the checkpoint is read from
            ``<args.checkpoint_path>/latest.ckpt``.
        one_direction: if true, build a single generator; otherwise build both.
        gen_name: checkpoint key used for the single generator.

    Returns:
        One generator when ``one_direction`` is true, else the tuple ``(Gab, Gba)``.
    """
    ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_path))

    def _build():
        return define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net,
                          norm=args.norm, use_dropout=args.use_dropout, gpu_ids=args.gpu_ids,
                          self_attn=args.self_attn, spectral=args.spectral)

    keys = [gen_name] if one_direction else ['Gab', 'Gba']
    # Same order as before: build all, then load all, then switch all to eval.
    nets = [_build() for _ in keys]
    for net, key in zip(nets, keys):
        net.load_state_dict(ckpt[key])
    for net in nets:
        net.eval()

    if one_direction:
        return nets[0]
    return tuple(nets)
def __init__(self, args):
    """Build the supervised image->segmentation model ``Gsi`` and its optimiser.

    Resumes from ``<args.checkpoint_dir>/latest_supervised_model.ckpt`` when it
    can be loaded; otherwise starts at epoch 0 with ``best_iou = -100``.

    Raises:
        ValueError: if ``args.dataset`` is not 'voc2012', 'cityscapes' or 'acdc'.
    """
    # Number of segmentation output channels (classes) per supported dataset.
    n_classes = {'voc2012': 21, 'cityscapes': 20, 'acdc': 4}
    if args.dataset not in n_classes:
        # Without this check an unknown dataset left n_channels unset.
        raise ValueError('Unsupported dataset: %r' % (args.dataset,))
    self.n_channels = n_classes[args.dataset]

    # Define the network: image -> segmentation.
    self.Gsi = define_Gen(input_nc=3, output_nc=self.n_channels, ngf=args.ngf,
                          netG='resnet_9blocks_softmax', norm=args.norm,
                          use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    utils.print_networks([self.Gsi], ['Gsi'])

    self.CE = nn.CrossEntropyLoss()
    self.activation_softmax = nn.Softmax2d()
    self.gsi_optimizer = torch.optim.Adam(self.Gsi.parameters(), lr=args.lr, betas=(0.9, 0.999))

    # Writer for tensorboard.
    self.writer_supervised = SummaryWriter(tensorboard_loc + '_supervised')
    self.running_metrics_val = utils.runningScore(self.n_channels, args.dataset)
    self.args = args

    if not os.path.isdir(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
    try:
        ckpt = utils.load_checkpoint('%s/latest_supervised_model.ckpt' % (args.checkpoint_dir))
        self.start_epoch = ckpt['epoch']
        self.Gsi.load_state_dict(ckpt['Gsi'])
        self.gsi_optimizer.load_state_dict(ckpt['gsi_optimizer'])
        self.best_iou = ckpt['best_iou']
    except Exception as err:
        # Best effort: fall back to a fresh start but report the reason.
        # Anything loaded before the failure keeps its loaded values.
        print(' [*] No checkpoint!')
        print(' [*] Reason: %s' % err)
        self.start_epoch = 0
        self.best_iou = -100
def __init__(self, args):
    """Build the CycleGAN generators, discriminators, losses, optimisers and schedulers.

    Resumes from ``<args.checkpoint_dir>/latest.ckpt`` when it can be loaded;
    otherwise training starts at epoch 0.
    """
    # Define the network.
    gen_kwargs = dict(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net,
                      norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    dis_kwargs = dict(input_nc=3, ndf=args.ndf, netD=args.dis_net, n_layers_D=3,
                      norm=args.norm, gpu_ids=args.gpu_ids)
    self.Gab = define_Gen(**gen_kwargs)
    self.Gba = define_Gen(**gen_kwargs)
    self.Da = define_Dis(**dis_kwargs)
    self.Db = define_Dis(**dis_kwargs)
    utils.print_networks([self.Gab, self.Gba, self.Da, self.Db], ['Gab', 'Gba', 'Da', 'Db'])

    # Define loss criteria.
    self.MSE = nn.MSELoss()
    self.L1 = nn.L1Loss()

    # Optimizers.
    self.g_optimizer = torch.optim.Adam(
        itertools.chain(self.Gab.parameters(), self.Gba.parameters()),
        lr=args.lr, betas=(0.5, 0.999))
    self.d_optimizer = torch.optim.Adam(
        itertools.chain(self.Da.parameters(), self.Db.parameters()),
        lr=args.lr, betas=(0.5, 0.999))
    # NOTE(review): both schedulers get a constant 0 as the second LambdaLR
    # argument (presumably an epoch offset) and are not restored from the
    # checkpoint, so a resumed run may restart the LR schedule — confirm.
    self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        self.g_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
    self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        self.d_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

    # Try loading checkpoint.
    if not os.path.isdir(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
    try:
        ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir))
        self.start_epoch = ckpt['epoch']
        self.Da.load_state_dict(ckpt['Da'])
        self.Db.load_state_dict(ckpt['Db'])
        self.Gab.load_state_dict(ckpt['Gab'])
        self.Gba.load_state_dict(ckpt['Gba'])
        self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
        self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
    except Exception as err:
        # Best effort (e.g. missing file or key, incompatible weights): start
        # fresh but report why. Modules loaded before the failure keep their weights.
        print(' [*] No checkpoint!')
        print(' [*] Reason: %s' % err)
        self.start_epoch = 0
def validation(args):
    """Run the saved segmentation model(s) on the validation split and save the outputs.

    For ``args.model == 'supervised_model'``, the colourised predictions of
    ``Gsi`` go to ``<validation_dir>/supervised/``. For
    ``'semisupervised_cycleGAN'``, the outputs go under
    ``<validation_dir>/unsupervised/``:

    - the predicted labels;
    - the labels regenerated from images that ``Gis`` synthesised from the
      ground truth;
    - the image regenerated from the predicted map;
    - the image synthesised from the ground-truth labels.

    The ``best_iou`` stored in the loaded checkpoint (0 if none) is printed at the end.

    Raises:
        ValueError: if ``args.dataset`` is not 'voc2012', 'cityscapes' or 'acdc'.
    """
    # Number of segmentation channels (classes) per supported dataset.
    n_classes = {'voc2012': 21, 'cityscapes': 20, 'acdc': 4}
    if args.dataset not in n_classes:
        raise ValueError('Unsupported dataset: %r' % (args.dataset,))
    n_channels = n_classes[args.dataset]

    transform = get_transformation((args.crop_height, args.crop_width), resize=True,
                                   dataset=args.dataset)
    if args.dataset == 'voc2012':
        val_set = VOCDataset(root_path=root, name='val', ratio=0.5,
                             transformation=transform, augmentation=None)
    elif args.dataset == 'cityscapes':
        val_set = CityscapesDataset(root_path=root_cityscapes, name='val', ratio=0.5,
                                    transformation=transform, augmentation=None)
    else:  # 'acdc'
        val_set = ACDCDataset(root_path=root_acdc, name='val', ratio=0.5,
                              transformation=transform, augmentation=None)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False)

    Gsi = define_Gen(input_nc=3, output_nc=n_channels, ngf=args.ngf, netG='deeplab',
                     norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    Gis = define_Gen(input_nc=n_channels, output_nc=3, ngf=args.ngf, netG='deeplab',
                     norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)

    best_iou = 0
    # Resize network outputs to the crop size.
    interp = nn.Upsample(size=(args.crop_height, args.crop_width), mode='bilinear',
                         align_corners=True)
    activation_softmax = nn.Softmax2d()
    activation_tanh = nn.Tanh()

    if args.model == 'supervised_model':
        try:
            ckpt = utils.load_checkpoint('%s/latest_supervised_model.ckpt' % (args.checkpoint_dir))
            Gsi.load_state_dict(ckpt['Gsi'])
            best_iou = ckpt['best_iou']
        except Exception as err:
            print(' [*] No checkpoint!')
            print(' [*] Reason: %s' % err)

        Gsi.eval()
        with torch.no_grad():  # inference only; no autograd graph needed
            for i, (image_test, real_segmentation, image_name) in enumerate(val_loader):
                image_test = utils.cuda(image_test, args.gpu_ids)
                seg_map = activation_softmax(interp(Gsi(image_test)))
                # Argmax over dim 1 (assumed to be the class axis of an
                # NCHW map) gives one label per pixel.
                prediction = seg_map.data.max(1)[1].squeeze_(1).cpu().numpy()
                for j in range(prediction.shape[0]):
                    # colorize_mask returns an image object with .save() (paletted).
                    new_img = utils.colorize_mask(prediction[j], args.dataset)
                    new_img.save(os.path.join(
                        args.validation_dir + '/supervised/' + image_name[j] + '.png'))
                print('Epoch-', str(i + 1), ' Done!')
        print('The iou of the resulting segment maps: ', str(best_iou))

    elif args.model == 'semisupervised_cycleGAN':
        try:
            ckpt = utils.load_checkpoint('%s/latest_semisuper_cycleGAN.ckpt' % (args.checkpoint_dir))
            Gsi.load_state_dict(ckpt['Gsi'])
            Gis.load_state_dict(ckpt['Gis'])
            best_iou = ckpt['best_iou']
        except Exception as err:
            print(' [*] No checkpoint!')
            print(' [*] Reason: %s' % err)

        Gsi.eval()
        # Gis was previously left in training mode during validation.
        Gis.eval()
        with torch.no_grad():
            for i, (image_test, real_segmentation, image_name) in enumerate(val_loader):
                image_test, real_segmentation = utils.cuda([image_test, real_segmentation],
                                                           args.gpu_ids)
                seg_map = activation_softmax(interp(Gsi(image_test)))
                fake_img = activation_tanh(interp(Gis(seg_map)))
                one_hot_labels = make_one_hot(real_segmentation, args.dataset, args.gpu_ids).float()
                fake_img_from_labels = activation_tanh(interp(Gis(one_hot_labels)))
                fake_label_regenerated = activation_softmax(interp(Gsi(fake_img_from_labels)))

                prediction = seg_map.data.max(1)[1].squeeze_(1).cpu().numpy()
                fake_regenerated_label = fake_label_regenerated.data.max(1)[1].squeeze_(1).cpu().numpy()
                fake_img = fake_img.cpu()
                fake_img_from_labels = fake_img_from_labels.cpu()

                # Revert the mean/std 0.5 transformation on the generated images.
                if args.dataset in ('voc2012', 'cityscapes'):
                    trans_mean, trans_std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
                else:  # 'acdc'
                    # NOTE(review): only channel 0 of the 3-channel Gis output is
                    # de-normalised for acdc — confirm this is intended.
                    trans_mean, trans_std = [0.5], [0.5]
                for k, (mean, std) in enumerate(zip(trans_mean, trans_std)):
                    fake_img[:, k, :, :] = fake_img[:, k, :, :] * std + mean
                    fake_img_from_labels[:, k, :, :] = fake_img_from_labels[:, k, :, :] * std + mean

                for j in range(prediction.shape[0]):
                    new_img = utils.colorize_mask(prediction[j], args.dataset)
                    regen_label = utils.colorize_mask(fake_regenerated_label[j], args.dataset)
                    new_img.save(os.path.join(
                        args.validation_dir + '/unsupervised/generated_labels/' + image_name[j] + '.png'))
                    regen_label.save(os.path.join(
                        args.validation_dir + '/unsupervised/regenerated_labels/' + image_name[j] + '.png'))
                    torchvision.utils.save_image(fake_img[j], os.path.join(
                        args.validation_dir + '/unsupervised/regenerated_image/' + image_name[j] + '.jpg'))
                    torchvision.utils.save_image(fake_img_from_labels[j], os.path.join(
                        args.validation_dir + '/unsupervised/image_from_labels/' + image_name[j] + '.jpg'))
                print('Epoch-', str(i + 1), ' Done!')
        print('The iou of the resulting segment maps: ', str(best_iou))
args.checkpoint_path = args.checkpoint_dir + args.identifier args.results_path = args.results_dir + args.identifier args.gpu_ids = [] for i in range(torch.cuda.device_count()): args.gpu_ids.append(i) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # - if one_direction: G = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm, use_dropout=args.use_dropout, gpu_ids=args.gpu_ids, self_attn=args.self_attn, spectral=args.spectral) else: Gab = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm, use_dropout=args.use_dropout, gpu_ids=args.gpu_ids, self_attn=args.self_attn, spectral=args.spectral) Gba = define_Gen(input_nc=3,
def __init__(self, args):
    """Build the supervised DeepLab image->segmentation model ``Gsi`` and its optimiser.

    For VOC2012 and Cityscapes, ``Gsi`` is first initialised from the
    pretrained weights at ``pretrained_loc`` (matching names and shapes only).
    Training then resumes from
    ``<args.checkpoint_dir>/latest_supervised_model.ckpt`` when it loads;
    otherwise it starts at epoch 0 with ``best_iou = -100``.

    Raises:
        ValueError: if ``args.dataset`` is not 'voc2012', 'cityscapes' or 'acdc'.
    """
    # Number of segmentation output channels (classes) per supported dataset.
    n_classes = {'voc2012': 21, 'cityscapes': 20, 'acdc': 4}
    if args.dataset not in n_classes:
        raise ValueError('Unsupported dataset: %r' % (args.dataset,))
    self.n_channels = n_classes[args.dataset]

    # Define the network: image -> segmentation.
    self.Gsi = define_Gen(input_nc=3, output_nc=self.n_channels, ngf=args.ngf, netG='deeplab',
                          norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)

    # Pretrained weights are only used for VOC and Cityscapes.
    if args.dataset != 'acdc':
        # map_location='cpu' lets a GPU-saved file load on any machine; copy_
        # below moves each value onto its parameter's own device.
        saved_state_dict = torch.load(pretrained_loc, map_location='cpu')
        # state_dict() holds references to Gsi's parameters and buffers, so the
        # in-place copy_ updates Gsi directly (no load_state_dict needed).
        for name, param in self.Gsi.state_dict().items():
            if name in saved_state_dict and param.size() == saved_state_dict[name].size():
                param.copy_(saved_state_dict[name])

    utils.print_networks([self.Gsi], ['Gsi'])

    # Upsample network outputs to the training crop size / a 512x512 validation size.
    self.interp = nn.Upsample(size=(args.crop_height, args.crop_width), mode='bilinear',
                              align_corners=True)
    self.interp_val = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=True)

    self.CE = nn.CrossEntropyLoss()
    self.activation_softmax = nn.Softmax2d()
    self.gsi_optimizer = torch.optim.Adam(self.Gsi.parameters(), lr=args.lr, betas=(0.9, 0.999))

    # Writer for tensorboard.
    self.writer_supervised = SummaryWriter(tensorboard_loc + '_supervised')
    self.running_metrics_val = utils.runningScore(self.n_channels, args.dataset)
    self.args = args

    if not os.path.isdir(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
    try:
        ckpt = utils.load_checkpoint('%s/latest_supervised_model.ckpt' % (args.checkpoint_dir))
        self.start_epoch = ckpt['epoch']
        self.Gsi.load_state_dict(ckpt['Gsi'])
        self.gsi_optimizer.load_state_dict(ckpt['gsi_optimizer'])
        self.best_iou = ckpt['best_iou']
    except Exception as err:
        # Best effort: fall back to a fresh start but report the reason.
        print(' [*] No checkpoint!')
        print(' [*] Reason: %s' % err)
        self.start_epoch = 0
        self.best_iou = -100
def __init__(self, args):
    """Build the semi-supervised CycleGAN with DeepLab generators and pixel discriminators.

    Networks:

    - ``Gsi`` maps images to segmentation maps; ``Gis`` maps segmentation maps
      back to images.
    - ``Di`` and ``Ds`` are the image and segmentation discriminators.
    - ``old_Gis``, ``old_Gsi`` and ``old_Di`` are ResNet-based models. For
      VOC2012, ``old_Gis`` and ``old_Gsi`` are loaded from
      ``./ckpt_for_Arnab_loss.ckpt`` when available.

    Training resumes from ``<args.checkpoint_dir>/latest_semisuper_cycleGAN.ckpt``
    when it loads; otherwise it starts at epoch 0 with ``best_iou = -100``.

    Raises:
        ValueError: if ``args.dataset`` is not 'voc2012', 'cityscapes' or 'acdc'.
    """
    # Number of segmentation channels (classes) per supported dataset.
    n_classes = {'voc2012': 21, 'cityscapes': 20, 'acdc': 4}
    if args.dataset not in n_classes:
        raise ValueError('Unsupported dataset: %r' % (args.dataset,))
    self.n_channels = n_classes[args.dataset]

    gen_opts = dict(ngf=args.ngf, norm=args.norm, use_dropout=not args.no_dropout,
                    gpu_ids=args.gpu_ids)
    dis_opts = dict(ndf=args.ndf, netD='pixel', n_layers_D=3, norm=args.norm,
                    gpu_ids=args.gpu_ids)

    # Define the networks (same construction order as before).
    # Segmentation -> image.
    self.Gis = define_Gen(input_nc=self.n_channels, output_nc=3, netG='deeplab', **gen_opts)
    # Image -> segmentation.
    self.Gsi = define_Gen(input_nc=3, output_nc=self.n_channels, netG='deeplab', **gen_opts)
    self.Di = define_Dis(input_nc=3, **dis_opts)
    self.Ds = define_Dis(input_nc=self.n_channels, **dis_opts)

    # Older ResNet-based models; weights are only loaded for VOC2012 below.
    self.old_Gis = define_Gen(input_nc=self.n_channels, output_nc=3, netG='resnet_9blocks',
                              **gen_opts)
    self.old_Gsi = define_Gen(input_nc=3, output_nc=self.n_channels,
                              netG='resnet_9blocks_softmax', **gen_opts)
    self.old_Di = define_Dis(input_nc=3, **dis_opts)

    # NOTE: initialising Gsi/Gis from the pretrained DeepLab weights is disabled.

    if args.dataset == 'voc2012':
        try:
            ckpt_for_Arnab_loss = utils.load_checkpoint('./ckpt_for_Arnab_loss.ckpt')
            self.old_Gis.load_state_dict(ckpt_for_Arnab_loss['Gis'])
            self.old_Gsi.load_state_dict(ckpt_for_Arnab_loss['Gsi'])
        except Exception as err:
            print('**There is an error in loading the ckpt_for_Arnab_loss**')
            print(' [*] Reason: %s' % err)

    utils.print_networks([self.Gis, self.Gsi, self.Di, self.Ds], ['Gis', 'Gsi', 'Di', 'Ds'])
    self.args = args

    # Upsample network outputs to the crop size.
    self.interp = nn.Upsample((args.crop_height, args.crop_width), mode='bilinear',
                              align_corners=True)
    self.MSE = nn.MSELoss()
    self.L1 = nn.L1Loss()
    self.CE = nn.CrossEntropyLoss()
    self.activation_softmax = nn.Softmax2d()
    self.activation_tanh = nn.Tanh()
    self.activation_sigmoid = nn.Sigmoid()

    # Tensorboard writer.
    self.writer_semisuper = SummaryWriter(tensorboard_loc + '_semisuper')
    self.running_metrics_val = utils.runningScore(self.n_channels, args.dataset)

    # For adding gaussian noise.
    self.gauss_noise = utils.GaussianNoise(sigma=0.2)

    # Optimizers.
    self.g_optimizer = torch.optim.Adam(
        itertools.chain(self.Gis.parameters(), self.Gsi.parameters()),
        lr=args.lr, betas=(0.5, 0.999))
    self.d_optimizer = torch.optim.Adam(
        itertools.chain(self.Di.parameters(), self.Ds.parameters()),
        lr=args.lr, betas=(0.5, 0.999))
    self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        self.g_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
    self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        self.d_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

    # Try loading checkpoint.
    if not os.path.isdir(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
    try:
        ckpt = utils.load_checkpoint('%s/latest_semisuper_cycleGAN.ckpt' % (args.checkpoint_dir))
        self.start_epoch = ckpt['epoch']
        self.Di.load_state_dict(ckpt['Di'])
        self.Ds.load_state_dict(ckpt['Ds'])
        self.Gis.load_state_dict(ckpt['Gis'])
        self.Gsi.load_state_dict(ckpt['Gsi'])
        self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
        self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
        self.best_iou = ckpt['best_iou']
    except Exception as err:
        # Best effort: start fresh but report why.
        print(' [*] No checkpoint!')
        print(' [*] Reason: %s' % err)
        self.start_epoch = 0
        self.best_iou = -100
def __init__(self, args):
    """Build the semi-supervised CycleGAN with ResNet generators and pixel discriminators.

    ``Gsi`` maps images to segmentation maps and ``Gis`` maps segmentation
    maps back to images; ``Di``/``Ds`` are the discriminators. Training resumes
    from ``<args.checkpoint_dir>/latest_semisuper_cycleGAN.ckpt`` when it loads;
    otherwise it starts at epoch 0 with ``best_iou = -100``.

    Raises:
        ValueError: if ``args.dataset`` is not 'voc2012', 'cityscapes' or 'acdc'.
    """
    # Number of segmentation channels (classes) per supported dataset.
    n_classes = {'voc2012': 21, 'cityscapes': 20, 'acdc': 4}
    if args.dataset not in n_classes:
        raise ValueError('Unsupported dataset: %r' % (args.dataset,))
    self.n_channels = n_classes[args.dataset]

    # Segmentation -> image.
    self.Gis = define_Gen(input_nc=self.n_channels, output_nc=3, ngf=args.ngf,
                          netG='resnet_9blocks', norm=args.norm,
                          use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    # Image -> segmentation.
    self.Gsi = define_Gen(input_nc=3, output_nc=self.n_channels, ngf=args.ngf,
                          netG='resnet_9blocks_softmax', norm=args.norm,
                          use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    self.Di = define_Dis(input_nc=3, ndf=args.ndf, netD='pixel', n_layers_D=3,
                         norm=args.norm, gpu_ids=args.gpu_ids)
    # NOTE(review): Ds takes a single input channel while Gsi emits n_channels;
    # confirm what representation of the segmentation is fed to Ds.
    self.Ds = define_Dis(input_nc=1, ndf=args.ndf, netD='pixel', n_layers_D=3,
                         norm=args.norm, gpu_ids=args.gpu_ids)
    utils.print_networks([self.Gis, self.Gsi, self.Di, self.Ds], ['Gis', 'Gsi', 'Di', 'Ds'])

    self.args = args
    self.MSE = nn.MSELoss()
    self.L1 = nn.L1Loss()
    self.CE = nn.CrossEntropyLoss()
    self.activation_softmax = nn.Softmax2d()

    # Tensorboard writer.
    self.writer_semisuper = SummaryWriter(tensorboard_loc + '_semisuper')
    self.running_metrics_val = utils.runningScore(self.n_channels, args.dataset)

    # For adding gaussian noise.
    self.gauss_noise = utils.GaussianNoise(sigma=0.2)

    # Optimizers.
    self.g_optimizer = torch.optim.Adam(
        itertools.chain(self.Gis.parameters(), self.Gsi.parameters()),
        lr=args.lr, betas=(0.5, 0.999))
    self.d_optimizer = torch.optim.Adam(
        itertools.chain(self.Di.parameters(), self.Ds.parameters()),
        lr=args.lr, betas=(0.5, 0.999))
    self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        self.g_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
    self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        self.d_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

    # Try loading checkpoint.
    if not os.path.isdir(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
    try:
        ckpt = utils.load_checkpoint('%s/latest_semisuper_cycleGAN.ckpt' % (args.checkpoint_dir))
        self.start_epoch = ckpt['epoch']
        self.Di.load_state_dict(ckpt['Di'])
        self.Ds.load_state_dict(ckpt['Ds'])
        self.Gis.load_state_dict(ckpt['Gis'])
        self.Gsi.load_state_dict(ckpt['Gsi'])
        self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
        self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
        self.best_iou = ckpt['best_iou']
    except Exception as err:
        # Best effort: start fresh but report why.
        print(' [*] No checkpoint!')
        print(' [*] Reason: %s' % err)
        self.start_epoch = 0
        self.best_iou = -100
def test(args):
    """Segment the test split with ``Gsi`` and save colourised label maps.

    ``args.model`` selects the checkpoint and the output folder:

    - ``'supervised_model'``: ``latest_supervised_model.ckpt``, saving to
      ``<results_dir>/supervised/``.
    - ``'semisupervised_cycleGAN'``: ``latest_semisuper_cycleGAN.ckpt``,
      saving to ``<results_dir>/unsupervised/``.

    For any other value no images are produced. If the checkpoint cannot be
    loaded, the untrained network is used.

    Raises:
        ValueError: if ``args.dataset`` is not 'voc2012', 'cityscapes' or 'acdc'.
    """
    # Number of segmentation channels (classes) per supported dataset.
    n_classes = {'voc2012': 21, 'cityscapes': 20, 'acdc': 4}
    if args.dataset not in n_classes:
        raise ValueError('Unsupported dataset: %r' % (args.dataset,))
    n_channels = n_classes[args.dataset]

    transform = get_transformation((args.crop_height, args.crop_width), resize=True,
                                   dataset=args.dataset)
    if args.dataset == 'voc2012':
        test_set = VOCDataset(root_path=root, name='test', ratio=0.5,
                              transformation=transform, augmentation=None)
    elif args.dataset == 'cityscapes':
        test_set = CityscapesDataset(root_path=root_cityscapes, name='test', ratio=0.5,
                                     transformation=transform, augmentation=None)
    else:  # 'acdc'
        test_set = ACDCDataset(root_path=root_acdc, name='test', ratio=0.5,
                               transformation=transform, augmentation=None)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

    Gsi = define_Gen(input_nc=3, output_nc=n_channels, ngf=args.ngf,
                     netG='resnet_9blocks_softmax', norm=args.norm,
                     use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    activation_softmax = nn.Softmax2d()

    # model name -> (checkpoint file, output sub-folder); both branches were
    # otherwise identical.
    model_outputs = {
        'supervised_model': ('latest_supervised_model.ckpt', '/supervised/'),
        'semisupervised_cycleGAN': ('latest_semisuper_cycleGAN.ckpt', '/unsupervised/'),
    }
    if args.model not in model_outputs:
        return
    ckpt_name, out_subdir = model_outputs[args.model]

    try:
        ckpt = utils.load_checkpoint('%s/%s' % (args.checkpoint_dir, ckpt_name))
        Gsi.load_state_dict(ckpt['Gsi'])
    except Exception as err:
        print(' [*] No checkpoint!')
        print(' [*] Reason: %s' % err)

    Gsi.eval()
    with torch.no_grad():  # inference only
        for i, (image_test, image_name) in enumerate(test_loader):
            image_test = utils.cuda(image_test, args.gpu_ids)
            seg_map = activation_softmax(Gsi(image_test))
            # Per-pixel argmax over dim 1 (assumed class axis). The batch axis
            # is kept: the old extra squeeze_(0) removed it for single-image
            # batches, so the loop below walked image rows instead of images.
            prediction = seg_map.data.max(1)[1].cpu().numpy()
            for j in range(prediction.shape[0]):
                # colorize_mask returns an image object with .save() (paletted).
                new_img = utils.colorize_mask(prediction[j], args.dataset)
                new_img.save(os.path.join(args.results_dir + out_subdir + image_name[j] + '.png'))
            print('Epoch-', str(i + 1), ' Done!')
def _translate_and_save(loader, G_fwd, G_back, gray, ssim, batch_size, fake_dir, recon_dir):
    """Translate every image of ``loader`` one way and back, saving both results.

    Each translation is saved to ``fake_dir`` and its reconstruction to
    ``recon_dir``, both under the source image's file name.

    Args:
        loader: unshuffled DataLoader over an ``ImageFolder``-style dataset
            (it must expose ``dataset.samples``).
        G_fwd: generator mapping the loader's domain to the other domain.
        G_back: generator mapping back to the loader's domain.
        gray: RGB-to-grayscale module applied before SSIM.
        ssim: SSIM module comparing a real image with its translation.
        batch_size: the loader's batch size, used to map batches to file names.
        fake_dir: output folder for translations, ending with '/'.
        recon_dir: output folder for reconstructions, ending with '/'.

    Returns:
        ``(names, scores)``: source file names and the matching SSIM module outputs.
    """
    names, scores = [], []
    samples = loader.dataset.samples
    for i, (real, _) in enumerate(loader):
        # The loader is not shuffled, so batch i holds samples[i*bs:(i+1)*bs].
        fnames = samples[i * batch_size:(i + 1) * batch_size]
        real, = utils.cuda([real])
        with torch.no_grad():
            fake = G_fwd(real)
            recon = G_back(fake)
            for j in range(real.size(0)):
                real_j = real[j].unsqueeze(0)
                fake_j = fake[j].unsqueeze(0)
                recon_j = recon[j].unsqueeze(0)
                # Images are mapped from [-1, 1] to [0, 1] for SSIM and saving.
                scores.append(ssim(gray((real_j + 1) / 2.0), gray((fake_j + 1) / 2.0)).item())
                name = fnames[j][0].split('/')[-1]
                names.append(name)
                torchvision.utils.save_image((fake_j.data + 1) / 2.0, fake_dir + name)
                torchvision.utils.save_image((recon_j.data + 1) / 2.0, recon_dir + name)
    return names, scores


def gen_samples(args, epoch):
    """Translate every test (or hand-picked sample) image and save per-image outputs.

    Outputs, all under ``args.results_path``:

    - A-domain images are translated with ``Gba`` (saved in ``b_fake/``) and
      reconstructed with ``Gab`` (saved in ``a_recon/``).
    - B-domain images go through ``Gab`` (``a_fake/``) and back through
      ``Gba`` (``b_recon/``).
    - The SSIM module output between each real image and its translation is
      written to ``b_fake/SSIM_A_to_B.csv`` and ``a_fake/SSIM_B_to_A.csv``.

    Args:
        args: parsed options; ``args.specific_samples`` selects the
            ``sampleA``/``sampleB`` folders instead of ``testA``/``testB``.
        epoch: unused; kept for interface compatibility.
    """
    transform = transforms.Compose([
        transforms.Resize((args.crop_height, args.crop_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    if args.specific_samples:
        dataset_dirs = utils.get_sampledata_link(args.dataset_dir)
        a_test_data = dsets.ImageFolder(dataset_dirs['sampleA'], transform=transform)
        b_test_data = dsets.ImageFolder(dataset_dirs['sampleB'], transform=transform)
    else:
        dataset_dirs = utils.get_testdata_link(args.dataset_dir)
        a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform)
        b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform)
    a_test_loader = torch.utils.data.DataLoader(a_test_data, batch_size=args.batch_size,
                                                shuffle=False, num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data, batch_size=args.batch_size,
                                                shuffle=False, num_workers=4)

    gen_kwargs = dict(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm,
                      use_dropout=args.use_dropout, gpu_ids=args.gpu_ids,
                      self_attn=args.self_attn, spectral=args.spectral)
    Gab = define_Gen(**gen_kwargs)
    Gba = define_Gen(**gen_kwargs)
    utils.print_networks([Gab, Gba], ['Gab', 'Gba'])

    ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_path))
    Gab.load_state_dict(ckpt['Gab'])
    Gba.load_state_dict(ckpt['Gba'])
    Gab.eval()
    Gba.eval()

    # Loop-invariant SSIM helpers, built once.
    gray = kornia.color.RgbToGrayscale()
    m = kornia.losses.SSIM(11, 'mean')

    results_path = args.results_path
    # Create every output folder up front so the CSVs can be written even
    # when a split is empty.
    for sub in ('/b_fake/', '/a_recon/', '/a_fake/', '/b_recon/'):
        os.makedirs(results_path + sub, exist_ok=True)

    # Each domain is processed over its whole loader. The old zip() stopped at
    # the shorter split, failed on unequal last batches and assumed a batch
    # size of 16 when mapping batches to file names.
    a_names, ba_ssims = _translate_and_save(a_test_loader, Gba, Gab, gray, m, args.batch_size,
                                            results_path + '/b_fake/', results_path + '/a_recon/')
    b_names, ab_ssims = _translate_and_save(b_test_loader, Gab, Gba, gray, m, args.batch_size,
                                            results_path + '/a_fake/', results_path + '/b_recon/')

    df1 = pd.DataFrame(list(zip(a_names, ba_ssims)), columns=['Name', 'SSIM_A_to_B'])
    df2 = pd.DataFrame(list(zip(b_names, ab_ssims)), columns=['Name', 'SSIM_B_to_A'])
    df1.to_csv(results_path + '/b_fake/' + 'SSIM_A_to_B.csv')
    df2.to_csv(results_path + '/a_fake/' + 'SSIM_B_to_A.csv')
def test(args):
    """Translate the whole testA split, pairing it with cycling testB batches, and save grids.

    For each testA batch ``i``, the next testB batch is drawn (restarting testB
    when it runs out). Both are translated and reconstructed, and the grid
    ``[a_real, b_fake, a_recon, b_real, a_fake, b_recon]`` is saved to
    ``<args.results_dir>/sample-<i>.jpg`` three images per row. If the
    checkpoint cannot be loaded, the untrained generators are used.

    Args:
        args: parsed options (dataset, network, checkpoint and result paths).
    """
    transform = transforms.Compose([
        transforms.Resize((args.crop_height, args.crop_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    dataset_dirs = utils.get_testdata_link(args.dataset_dir)
    a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform)
    b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform)
    a_test_loader = torch.utils.data.DataLoader(a_test_data, batch_size=args.batch_size,
                                                shuffle=True, num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data, batch_size=args.batch_size,
                                                shuffle=True, num_workers=4)

    gen_kwargs = dict(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm,
                      use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    Gab = define_Gen(**gen_kwargs)
    Gba = define_Gen(**gen_kwargs)
    utils.print_networks([Gab, Gba], ['Gab', 'Gba'])

    def _strip_module_prefix(state_dict):
        """Remove a leading 'module.' from each key (presumably added by nn.DataParallel)."""
        prefix = 'module.'
        # Strip only the prefix: str.replace would also mangle keys such as
        # 'submodule.weight'.
        return {(k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()}

    try:
        ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir))
        if len(args.gpu_ids) > 0:
            Gab.load_state_dict(ckpt['Gab'])
            Gba.load_state_dict(ckpt['Gba'])
        else:
            Gab.load_state_dict(_strip_module_prefix(ckpt['Gab']))
            Gba.load_state_dict(_strip_module_prefix(ckpt['Gba']))
    except Exception as err:
        print(' [*] No checkpoint!')
        print(' [*] Reason: %s' % err)

    Gab.eval()
    Gba.eval()
    if not os.path.isdir(args.results_dir):
        os.makedirs(args.results_dir)

    # The old loop returned after the first batch (leftover debug code) and
    # its StopIteration handlers reset the iterator without fetching a batch.
    b_it = iter(b_test_loader)
    for i, (a_real_test, _) in enumerate(a_test_loader):
        try:
            b_real_test = next(b_it)[0]
        except StopIteration:
            # testB ran out before testA: start another pass over testB.
            b_it = iter(b_test_loader)
            b_real_test = next(b_it)[0]
        a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

        with torch.no_grad():
            a_fake_test = Gab(b_real_test)
            b_fake_test = Gba(a_real_test)
            a_recon_test = Gab(b_fake_test)
            b_recon_test = Gba(a_fake_test)

        # Map normalised/generator outputs from [-1, 1] back to [0, 1].
        pic = (torch.cat([a_real_test, b_fake_test, a_recon_test,
                          b_real_test, a_fake_test, b_recon_test], dim=0).data + 1) / 2.0
        path = "{}/sample-{}.jpg".format(args.results_dir, i)
        torchvision.utils.save_image(pic, path, nrow=3)