def eval(args, model, data_loader):
    """Translate every sample in ``data_loader`` with ``model`` and return the FID score.

    Real target maps and generated maps are written as ``.tif`` files into
    ``<args.save>/real_result`` and ``<args.save>/fake_result``; the FID is then
    computed between those two directories with ``fid_score``.

    The model's train/eval mode is restored before returning. Previously the
    model was left in eval mode, so a training loop that called this function
    kept training with eval-mode layer behaviour (BatchNorm/Dropout etc.).

    Args:
        args: options namespace; ``save`` and ``gpu`` are read here (plus whatever
            ``get_device`` reads).
        model: generator, called as ``model(imgs)``.
        data_loader: iterable of dicts with ``'image'``, ``'map'`` and
            ``'im_name'`` entries.

    Returns:
        The FID score returned by ``fid_score``.
    """
    device = get_device(args)
    data_loader = tqdm(data_loader)
    was_training = model.training
    model.eval()
    model = model.to(device)
    fake_dir = osp.join(args.save, 'fake_result')
    real_dir = osp.join(args.save, 'real_result')
    create_dir(real_dir)
    create_dir(fake_dir)
    try:
        for sample in data_loader:
            imgs = sample['image'].to(device)
            maps = sample['map'].to(device)
            im_name = sample['im_name']
            with torch.no_grad():
                fakes = model(imgs)
            for b in range(imgs.size(0)):
                # Base file name, truncated at the first '.'.
                file_name = osp.split(im_name[b])[-1].split('.')[0]
                real_file = osp.join(real_dir, f'{file_name}.tif')
                fake_file = osp.join(fake_dir, f'{file_name}.tif')
                from_std_tensor_save_image(filename=real_file, data=maps[b].cpu())
                from_std_tensor_save_image(filename=fake_file, data=fakes[b].cpu())
    finally:
        # Give the model back in whatever mode the caller had it in.
        model.train(was_training)
    fid = fid_score(real_path=real_dir, fake_path=fake_dir, gpu=str(args.gpu))
    print(f'===> fid score:{fid:.4f}')
    return fid
def get_E(args):
    """Create the feature encoder ``E``, initialise its weights and wrap it in DataParallel.

    The encoder maps ``args.output_nc`` image channels to ``args.feat_num``
    feature channels; its architecture is configured by ``args.ngf``,
    ``args.n_downsample_global`` and the normalisation named by ``args.norm``.

    Returns:
        The encoder wrapped in ``nn.DataParallel`` and moved to ``get_device(args)``.
    """
    encoder = Encoder(
        args.output_nc,
        args.feat_num,
        args.ngf,
        args.n_downsample_global,
        get_norm_layer(norm_type=args.norm),
    )
    encoder.apply(weights_init)
    print(encoder)
    return nn.DataParallel(encoder).to(get_device(args))
def get_G(args, input_nc=None):
    """Build the generator selected by ``args.netG``, initialise it and wrap it in DataParallel.

    Args:
        args: options namespace (``netG``, ``norm``, ``output_nc``, ``ngf``,
            ``n_downsample_global``, ``n_blocks_global``, ... are read).
        input_nc: number of generator input channels. When ``None`` it is derived
            from ``args.label_nc``, plus one channel for the instance map when
            ``args.use_instance`` is set, plus ``args.feat_num`` encoded-feature
            channels when ``args.feat_num > 0``. An explicit value is used as is.

    Returns:
        The generator wrapped in ``nn.DataParallel`` and moved to ``get_device(args)``.

    Raises:
        NotImplementedError: if ``args.netG`` is not ``'global'``, ``'local'`` or
            ``'encoder'``. (Previously a bare string was raised, which itself
            fails with ``TypeError`` in Python 3.)
    """
    if input_nc is None:
        input_nc = args.label_nc
        if args.use_instance:
            input_nc += 1
        if args.feat_num > 0:
            input_nc += args.feat_num
    norm_layer = get_norm_layer(norm_type=args.norm)
    if args.netG == 'global':
        netG = GlobalGenerator(input_nc, args.output_nc, args.ngf,
                               args.n_downsample_global, args.n_blocks_global,
                               norm_layer)
    elif args.netG == 'local':
        netG = LocalEnhancer(input_nc, args.output_nc, args.ngf,
                             args.n_downsample_global, args.n_blocks_global,
                             args.n_local_enhancers, args.n_blocks_local,
                             norm_layer)
    elif args.netG == 'encoder':
        netG = Encoder(input_nc, args.output_nc, args.ngf,
                       args.n_downsample_global, norm_layer)
    else:
        raise NotImplementedError(
            f'generator not implemented! (netG={args.netG!r})')
    print(netG)
    netG.apply(weights_init)
    netG = nn.DataParallel(netG).to(get_device(args))
    return netG
def __init__(self, args):
    """Set up the VGG feature extractor, the L1 criterion and the per-level weights.

    Args:
        args: options namespace; ``args.vgg_type`` must be ``'vgg16'`` or
            ``'vgg19'`` and selects the backbone. ``get_device(args)`` chooses
            where the backbone is placed.
    """
    super(VGGLoss, self).__init__()
    assert args.vgg_type in ('vgg16', 'vgg19')
    if args.vgg_type == 'vgg16':
        backbone = Vgg16()
    else:
        backbone = Vgg19()
    self.vgg = nn.DataParallel(backbone).to(get_device(args))
    self.vgg.eval()
    self.criterion = nn.DataParallel(nn.L1Loss())
    # One weight per extracted feature level; presumably ordered shallow -> deep
    # to match the backbone's outputs (TODO confirm against the forward pass).
    self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
    print(
        f'===> {self.__class__.__name__} | vgg:{args.vgg_type} | loss:{self.criterion}'
    )
def get_D(args, input_nc=None):
    """Create the multiscale discriminator, initialise its weights and wrap it in DataParallel.

    Args:
        args: options namespace (``ndf``, ``n_layers_D``, ``norm``, ``use_lsgan``,
            ``num_D``, ``use_ganFeat_loss``, ... are read).
        input_nc: channels of the tensor fed to the discriminator. When ``None``
            it is ``args.input_nc + args.output_nc`` (source and target images
            concatenated), plus one channel when ``args.use_instance`` is set.

    Returns:
        The discriminator wrapped in ``nn.DataParallel`` on ``get_device(args)``.
    """
    if input_nc is None:
        # An older variant used ``args.label_nc + args.output_nc`` here.
        input_nc = args.input_nc + args.output_nc
        if args.use_instance:
            input_nc += 1
    discriminator = MultiscaleDiscriminator(
        input_nc,
        args.ndf,
        args.n_layers_D,
        get_norm_layer(norm_type=args.norm),
        args.use_lsgan,
        args.num_D,
        args.use_ganFeat_loss,
    )
    print(discriminator)
    discriminator.apply(weights_init)
    return nn.DataParallel(discriminator).to(get_device(args))
def __init__(self, args):
    """Prepare VGG16 features, MSE criteria and the style image's Gram matrices.

    The style image at ``args.style_image`` is loaded as RGB, passed through the
    second transform returned by ``get_transform(args)``, repeated
    ``args.batch_size`` times and run through VGG16 once (without gradients).
    The resulting feature maps and their Gram matrices are cached on ``self``.

    NOTE(review): no ``super().__init__()`` call is made here. If the enclosing
    class derives from ``nn.Module``, assigning ``self.vgg`` would fail — confirm
    the base class.
    """
    self.content_layer = args.content_layer
    device = get_device(args)
    self.vgg = nn.DataParallel(Vgg16()).to(device)
    self.vgg.eval()
    self.mse = nn.DataParallel(nn.MSELoss())
    self.mse_sum = nn.DataParallel(nn.MSELoss(reduction='sum'))
    style_pil = Image.open(args.style_image).convert('RGB')
    _, transform = get_transform(args)
    # Tile the single style image so it lines up with a training batch.
    style_batch = transform(style_pil).repeat(args.batch_size, 1, 1, 1).to(device)
    with torch.no_grad():
        self.style_features = self.vgg(style_batch)
        self.style_gram = [gram(feature_map) for feature_map in self.style_features]
model_saver = ModelSaver(save_path=args.save, name_list=[ f'{style_name}', f'{style_name}_{args.optimizer}', f'{style_name}_{args.scheduler}' ]) criterion = PerceptualLoss(args) model = ImageTransformNet() model = model_accelerate(args, model) model_saver.load(f'{style_name}', model=model) optimizer = Adam(model.parameters(), lr=args.lr) model_saver.load(f'{style_name}_{args.optimizer}', model=optimizer) epoch_now = len(logger.get_data(key='loss')) device = get_device(args) #### data_loader = get_dataloader_from_dir(args) data_loader = tqdm(data_loader) model.eval() with torch.no_grad(): # counter = 10 for i, (imgs, path) in enumerate(data_loader): imgs = imgs.to(device) # counter -= 1 # if counter < 0: # break y_hat = model(imgs) for index in range(y_hat.size(0)): # data = torch.cat([y_hat[index], imgs[index]], dim=2)
def train(args, get_dataloader_func=get_pix2pix_maps_dataloader):
    """Train a pix2pixHD generator ``G`` / multiscale discriminator ``D`` for image -> map.

    Training resumes from the ``ModelSaver``/``Logger`` state under ``args.save``;
    the starting epoch is the number of logged ``G_loss`` entries. Each batch
    runs one discriminator step and one generator step. After each epoch the
    hinge schedulers are stepped. FID is evaluated with ``eval`` every 10 epochs
    and on the final epoch, and every model/optimizer/scheduler is checkpointed.

    Args:
        args: options namespace (paths, learning rates, loss switches, logging).
        get_dataloader_func: factory ``(args, train=bool) -> DataLoader`` whose
            samples contain ``'image'`` (source) and ``'map'`` (target) tensors.
    """
    logger = Logger(save_path=args.save, json_name='img2map')
    model_saver = ModelSaver(save_path=args.save, name_list=[
        'G', 'D', 'E', 'G_optimizer', 'D_optimizer', 'E_optimizer',
        'G_scheduler', 'D_scheduler', 'E_scheduler'
    ])
    visualizer = Visualizer(
        keys=['image', 'encode_feature', 'fake', 'label', 'instance'])
    sw = SummaryWriter(args.tensorboard_path)
    G = get_G(args)
    D = get_D(args)
    model_saver.load('G', G)
    model_saver.load('D', D)
    G_optimizer = Adam(G.parameters(), lr=args.G_lr, betas=(args.beta1, 0.999))
    D_optimizer = Adam(D.parameters(), lr=args.D_lr, betas=(args.beta1, 0.999))
    model_saver.load('G_optimizer', G_optimizer)
    model_saver.load('D_optimizer', D_optimizer)
    G_scheduler = get_hinge_scheduler(args, G_optimizer)
    D_scheduler = get_hinge_scheduler(args, D_optimizer)
    model_saver.load('G_scheduler', G_scheduler)
    model_saver.load('D_scheduler', D_scheduler)
    device = get_device(args)
    GANLoss = get_GANLoss(args)
    if args.use_ganFeat_loss:
        DFLoss = get_DFLoss(args)
    if args.use_vgg_loss:
        VGGLoss = get_VGGLoss(args)
    if args.use_low_level_loss:
        LLLoss = get_low_level_loss(args)
    epoch_now = len(logger.get_data('G_loss'))
    for epoch in range(epoch_now, args.epochs):
        G_loss_list = []
        D_loss_list = []
        data_loader = tqdm(get_dataloader_func(args, train=True))
        for step, sample in enumerate(data_loader):
            imgs = sample['image'].to(device)
            maps = sample['map'].to(device)

            # --- Discriminator step: real (image, map) pairs vs. detached fakes ---
            D_optimizer.zero_grad()
            reals_maps = torch.cat([imgs.float(), maps.float()], dim=1)
            fakes = G(imgs).detach()
            fakes_maps = torch.cat([imgs.float(), fakes.float()], dim=1)
            D_real_outs = D(reals_maps)
            D_real_loss = GANLoss(D_real_outs, True)
            D_fake_outs = D(fakes_maps)
            D_fake_loss = GANLoss(D_fake_outs, False)
            D_loss = 0.5 * (D_real_loss + D_fake_loss)
            D_loss = D_loss.mean()
            D_loss.backward()
            D_loss = D_loss.item()
            D_optimizer.step()

            # --- Generator step ---
            G_optimizer.zero_grad()
            fakes = G(imgs)
            fakes_maps = torch.cat([imgs.float(), fakes.float()], dim=1)
            D_fake_outs = D(fakes_maps)
            gan_loss = GANLoss(D_fake_outs, True)
            G_loss = 0
            G_loss += gan_loss
            gan_loss = gan_loss.mean().item()
            if args.use_vgg_loss:
                # Compare against the target map (``maps``), like the low-level
                # loss below; comparing with the source image ``imgs`` would
                # pull the generated map towards the input photo.
                vgg_loss = VGGLoss(fakes, maps)
                G_loss += args.lambda_feat * vgg_loss
                vgg_loss = vgg_loss.mean().item()
            else:
                vgg_loss = 0.
            if args.use_ganFeat_loss:
                df_loss = DFLoss(D_fake_outs, D_real_outs)
                G_loss += args.lambda_feat * df_loss
                df_loss = df_loss.mean().item()
            else:
                df_loss = 0.
            if args.use_low_level_loss:
                ll_loss = LLLoss(fakes, maps)
                G_loss += args.lambda_feat * ll_loss
                ll_loss = ll_loss.mean().item()
            else:
                ll_loss = 0.
            G_loss = G_loss.mean()
            G_loss.backward()
            G_loss = G_loss.item()
            G_optimizer.step()

            data_loader.write(
                f'Epochs:{epoch} | Dloss:{D_loss:.6f} | Gloss:{G_loss:.6f}'
                f'| GANloss:{gan_loss:.6f} | VGGloss:{vgg_loss:.6f} | DFloss:{df_loss:.6f} '
                f'| LLloss:{ll_loss:.6f} | lr:{get_lr(G_optimizer):.8f}')
            G_loss_list.append(G_loss)
            D_loss_list.append(D_loss)

            # display
            if args.display and step % args.display == 0:
                visualizer.display(transforms.ToPILImage()(imgs[0].cpu()), 'image')
                visualizer.display(transforms.ToPILImage()(fakes[0].detach().cpu()), 'fake')
                visualizer.display(transforms.ToPILImage()(maps[0].cpu()), 'label')

            # tensorboard log
            if args.tensorboard_log and step % args.tensorboard_log == 0:
                total_steps = epoch * len(data_loader) + step
                sw.add_scalar('Loss/G', G_loss, total_steps)
                sw.add_scalar('Loss/D', D_loss, total_steps)
                sw.add_scalar('Loss/gan', gan_loss, total_steps)
                sw.add_scalar('Loss/vgg', vgg_loss, total_steps)
                sw.add_scalar('Loss/df', df_loss, total_steps)
                sw.add_scalar('Loss/ll', ll_loss, total_steps)
                sw.add_scalar('LR/G', get_lr(G_optimizer), total_steps)
                sw.add_scalar('LR/D', get_lr(D_optimizer), total_steps)
                # Use the global step so images are not overwritten every epoch.
                sw.add_image('img/real', imgs[0].cpu(), total_steps)
                sw.add_image('img/fake', fakes[0].cpu(), total_steps)
                sw.add_image('visual/label', maps[0].cpu(), total_steps)

        D_scheduler.step(epoch)
        G_scheduler.step(epoch)

        # ``epoch`` only reaches ``args.epochs - 1`` inside this loop.
        if epoch % 10 == 0 or epoch == args.epochs - 1:
            fid = eval(args, model=G,
                       data_loader=get_dataloader_func(args, train=False))
            # Evaluation may leave G in eval mode; resume training in train mode.
            G.train()
            # Lower FID is better: compare with the history *before* logging.
            previous_fids = list(logger.get_data(key='FID'))
            logger.log(key='FID', data=fid)
            if not previous_fids or fid < min(previous_fids):
                model_saver.save(f'G_{fid:.4f}', G)
                model_saver.save(f'D_{fid:.4f}', D)

        logger.log(key='D_loss', data=sum(D_loss_list) / float(len(D_loss_list)))
        logger.log(key='G_loss', data=sum(G_loss_list) / float(len(G_loss_list)))
        logger.save_log()
        logger.visualize()

        model_saver.save('G', G)
        model_saver.save('D', D)
        model_saver.save('G_optimizer', G_optimizer)
        model_saver.save('D_optimizer', D_optimizer)
        model_saver.save('G_scheduler', G_scheduler)
        model_saver.save('D_scheduler', D_scheduler)
def eval_fidiou(args, model_G, model_seg, data_loader):
    """Evaluate the segmentation network and the generator; return ``(fid, iou)``.

    For each sample, ``model_seg`` segments ``sample['A_seg']``. Its feature map
    is concatenated with ``sample['A']`` and fed to ``model_G``. The following
    are written as ``.tif`` under ``args.save``:

    * targets → ``real_result``
    * sources → ``real_source``
    * generated images → ``fake_result``
    * ground-truth segmentation → ``real_seg``
    * predicted segmentation → ``seg_result``

    FID is computed between ``real_result`` and ``fake_result``. IoU is computed
    from the arg-max predictions and ``sample['seg']`` with
    ``label_accuracy_score``.

    Everything runs under ``torch.no_grad()``, so no autograd graphs are built
    during evaluation. Both models are put back in train mode afterwards, even
    if an error occurs.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    device = get_device(args)
    data_loader = tqdm(data_loader)
    model_G.eval()
    model_seg.eval()
    model_G = model_G.to(device)
    model_seg = model_seg.to(device)
    label_preds = []
    label_targets = []
    real_seg_dir = osp.join(args.save, 'real_seg')
    real_dir = osp.join(args.save, 'real_result')
    A_dir = osp.join(args.save, 'real_source')
    seg_dir = osp.join(args.save, 'seg_result')
    fake_dir = osp.join(args.save, 'fake_result')
    create_dir(real_dir)
    create_dir(real_seg_dir)
    create_dir(A_dir)
    create_dir(seg_dir)
    create_dir(fake_dir)
    n_class = None
    try:
        with torch.no_grad():
            for sample in data_loader:
                # Inputs go to the same device the models were moved to.
                inputs = sample['A_seg'].to(device)
                labels = sample['seg'].squeeze(dim=1)
                imgs = sample['A'].to(device)

                outputs, feature_map = model_seg(inputs)
                bs, n_class, h, w = outputs.shape
                outs = outputs.cpu().numpy()
                pred = outs.transpose(0, 2, 3, 1).reshape(-1, n_class).argmax(
                    axis=1).reshape(bs, h, w)
                target = labels.cpu().numpy().reshape(bs, h, w)
                label_preds.append(pred)
                label_targets.append(target)

                imgs_plus = torch.cat((imgs, feature_map), 1)
                fakes = model_G(imgs_plus)

                im_name = sample['A_paths']
                for b in range(inputs.size(0)):
                    # "<grandparent dir>_<parent dir>_<stem before first '.'>"
                    head, tail = osp.split(im_name[b])
                    parents = head.split(os.sep)
                    file_name = (parents[-2] + '_' + parents[-1] + '_' +
                                 tail.split('.')[0])
                    real_file = osp.join(real_dir, f'{file_name}.tif')
                    real_seg_file = osp.join(real_seg_dir, f'{file_name}.tif')
                    A_file = osp.join(A_dir, f'{file_name}.tif')
                    seg_file = osp.join(seg_dir, f'{file_name}.tif')
                    fake_file = osp.join(fake_dir, f'{file_name}.tif')
                    from_std_tensor_save_image(filename=real_file,
                                               data=sample['B'][b].cpu())
                    from_std_tensor_save_image(filename=A_file,
                                               data=sample['A'][b].cpu())
                    from_std_tensor_save_image(filename=fake_file,
                                               data=fakes[b].cpu())
                    tmpimg = Image.fromarray(gray2rgb(sample['seg'][b].cpu().numpy()))
                    tmpimg.save(fp=real_seg_file)
                    tmpimg = Image.fromarray(gray2rgb(pred[b]))
                    tmpimg.save(fp=seg_file)
        if n_class is None:
            raise ValueError('eval_fidiou: data_loader yielded no samples')

        fid = fid_score(real_path=real_dir, fake_path=fake_dir, gpu=str(args.gpu))
        print(f'===> fid score:{fid:.4f}')
        from src.pix2pixHD.eval_iou import label_accuracy_score
        _, _, iou, _, _ = label_accuracy_score(label_targets, label_preds, n_class)
        print(f'===> iou score:{iou:.4f}')
    finally:
        model_seg.train()
        model_G.train()
    return fid, iou
def train(args, get_dataloader_func=get_pix2pix_maps_dataloader):
    """Jointly train DeepLabv3+ (``DLV3P``) and a pix2pixHD GAN conditioned on its features.

    Per batch:

    1. ``DLV3P`` segments ``sample['A_seg']`` and returns logits plus a feature
       map. The segmentation loss is the mean of the focal loss and the
       Lovasz-softmax loss.
    2. ``D`` is updated on real ``(A, B)`` pairs against detached fakes produced
       from ``cat(A, feature_map)``.
    3. ``G`` and both ``DLV3P`` optimizers are updated together on
       ``args._1002arg_GANloss_alpha * G_loss + args._1002arg_segloss_alpha * seg_loss``.

    Training resumes from the state under ``args.save``; the starting epoch is
    the number of logged ``FOCAL_loss`` entries. FID/IoU are computed with
    ``eval_fidiou`` every ``args._val_frequency`` epochs and on the last epoch.
    A ``best_*_<epoch>`` checkpoint set is saved whenever FID improves.

    Args:
        args: options namespace.
        get_dataloader_func: factory ``(args, train=bool, flag=int) -> DataLoader``
            whose samples contain ``'A'``, ``'A_seg'``, ``'B'`` and ``'seg'``.
    """
    import copy

    with open(os.path.join(args.save, 'args.json'), 'w') as f:
        json.dump(vars(args), f)
    logger = Logger(save_path=args.save, json_name='img2map_seg')
    epoch_now = len(logger.get_data('FOCAL_loss'))
    model_saver = ModelSaver(
        save_path=args.save,
        name_list=[
            'G', 'D', 'G_optimizer', 'D_optimizer', 'G_scheduler',
            'D_scheduler', 'DLV3P', "DLV3P_global_optimizer",
            "DLV3P_backbone_optimizer", "DLV3P_global_scheduler",
            "DLV3P_backbone_scheduler", 'best_G', 'best_D',
            'best_G_optimizer', 'best_D_optimizer', 'best_G_scheduler',
            'best_D_scheduler', 'best_DLV3P', "best_DLV3P_global_optimizer",
            "best_DLV3P_backbone_optimizer", "best_DLV3P_global_scheduler",
            "best_DLV3P_backbone_scheduler"
        ])
    sw = SummaryWriter(args.tensorboard_path)

    # 3 image channels + 256 channels of the segmentation network's feature map.
    G = get_G(args, input_nc=3 + 256)
    D = get_D(args)
    model_saver.load('G', G)
    model_saver.load('D', D)
    cfg = Configuration()
    cfg.MODEL_NUM_CLASSES = args.label_nc
    DLV3P = deeplabv3plus(cfg)
    if args.gpu:
        DLV3P = DLV3P.cuda()
    model_saver.load('DLV3P', DLV3P)

    G_optimizer = Adam(G.parameters(), lr=args.G_lr, betas=(args.beta1, 0.999))
    D_optimizer = Adam(D.parameters(), lr=args.D_lr, betas=(args.beta1, 0.999))
    seg_global_params, seg_backbone_params = DLV3P.get_paras()
    # 'initial_lr' is needed because the LambdaLR schedulers below are built
    # with last_epoch=epoch_now instead of the default -1.
    DLV3P_global_optimizer = torch.optim.Adam(
        [{'params': seg_global_params, 'initial_lr': args.seg_lr_global}],
        lr=args.seg_lr_global, betas=(args.beta1, 0.999))
    DLV3P_backbone_optimizer = torch.optim.Adam(
        [{'params': seg_backbone_params, 'initial_lr': args.seg_lr_backbone}],
        lr=args.seg_lr_backbone, betas=(args.beta1, 0.999))
    model_saver.load('G_optimizer', G_optimizer)
    model_saver.load('D_optimizer', D_optimizer)
    model_saver.load('DLV3P_global_optimizer', DLV3P_global_optimizer)
    model_saver.load('DLV3P_backbone_optimizer', DLV3P_backbone_optimizer)

    G_scheduler = get_hinge_scheduler(args, G_optimizer)
    D_scheduler = get_hinge_scheduler(args, D_optimizer)
    # Polynomial ("poly", power 0.9) decay for the segmentation optimizers.
    DLV3P_global_scheduler = torch.optim.lr_scheduler.LambdaLR(
        DLV3P_global_optimizer,
        lr_lambda=lambda epoch: (1 - epoch / args.epochs)**0.9,
        last_epoch=epoch_now)
    DLV3P_backbone_scheduler = torch.optim.lr_scheduler.LambdaLR(
        DLV3P_backbone_optimizer,
        lr_lambda=lambda epoch: (1 - epoch / args.epochs)**0.9,
        last_epoch=epoch_now)
    model_saver.load('G_scheduler', G_scheduler)
    model_saver.load('D_scheduler', D_scheduler)
    model_saver.load('DLV3P_global_scheduler', DLV3P_global_scheduler)
    model_saver.load('DLV3P_backbone_scheduler', DLV3P_backbone_scheduler)
    # Re-align the GAN learning rates with the resumed epoch (for fine-tuning).
    D_scheduler.step(epoch_now)
    G_scheduler.step(epoch_now)

    # Every checkpointed object, in save order.
    checkpoint = {
        'G': G,
        'D': D,
        'DLV3P': DLV3P,
        'G_optimizer': G_optimizer,
        'D_optimizer': D_optimizer,
        'DLV3P_global_optimizer': DLV3P_global_optimizer,
        'DLV3P_backbone_optimizer': DLV3P_backbone_optimizer,
        'G_scheduler': G_scheduler,
        'D_scheduler': D_scheduler,
        'DLV3P_global_scheduler': DLV3P_global_scheduler,
        'DLV3P_backbone_scheduler': DLV3P_backbone_scheduler,
    }

    device = get_device(args)
    GANLoss = get_GANLoss(args)
    if args.use_ganFeat_loss:
        DFLoss = get_DFLoss(args)
    if args.use_vgg_loss:
        VGGLoss = get_VGGLoss(args)
    if args.use_low_level_loss:
        LLLoss = get_low_level_loss(args)
    LVS_loss = lovasz_softmax

    # Focal-loss class weights: inverse class frequency, normalised so the
    # rarest class gets weight 1, optionally rescaled per class.
    data_loader_focal = tqdm(get_dataloader_func(args, train=True))
    alpha = label_nums(data_loader_focal, label_num=args.label_nc)
    tmp_min = min(alpha)
    assert tmp_min > 0
    for i in range(len(alpha)):
        alpha[i] = tmp_min / alpha[i]
    if args.focal_alpha_revise:
        assert len(args.focal_alpha_revise) == len(alpha)
        for i in range(len(alpha)):
            alpha[i] = alpha[i] * args.focal_alpha_revise[i]
    print(alpha)
    FOCAL_loss = FocalLoss(gamma=2, alpha=alpha)

    if epoch_now == args.epochs:
        print('get final models')
        # eval_fidiou returns a (fid, iou) tuple; only the IoU is logged here.
        fid, iou = eval_fidiou(
            args, model_G=G, model_seg=DLV3P,
            data_loader=get_pix2pix_maps_dataloader(args, train=False))
        logger.log(key='iou', data=iou)
        sw.add_scalar('eval/iou', iou, epoch_now)

    for epoch in range(epoch_now, args.epochs):
        G_loss_list = []
        D_loss_list = []
        LVS_loss_list = []
        FOCAL_loss_list = []
        # flag=2 makes the dataloader report an artificial ("fake") length.
        data_loader = get_dataloader_func(
            args, train=True, flag=(2 if args._usefakelen else 0))
        data_loader = tqdm(data_loader)
        for step, sample in enumerate(data_loader):
            # --- 1) Segmentation forward pass (DeepLabv3+) ---
            imgs_seg = sample['A_seg'].to(device)
            label_imgs = sample['seg'].type(torch.LongTensor).to(device)
            outputs, feature_map = DLV3P(imgs_seg)
            soft_outputs = torch.nn.functional.softmax(outputs, dim=1)
            lvs_loss = LVS_loss(soft_outputs, label_imgs, ignore=255)
            lvs_loss_value = lvs_loss.item()
            focal_loss = FOCAL_loss(outputs, label_imgs)
            focal_loss_value = focal_loss.item()
            # Back-propagated together with the generator loss below.
            seg_loss = (focal_loss + lvs_loss) * 0.5

            # --- 2) Discriminator step ---
            imgs = sample['A'].to(device)
            maps = sample['B'].to(device)
            imgs_plus = torch.cat((imgs, feature_map), 1)  # bs*(3+256)*h*w
            D_optimizer.zero_grad()
            reals_maps = torch.cat([imgs.float(), maps.float()], dim=1)
            fakes = G(imgs_plus).detach()
            fakes_maps = torch.cat([imgs.float(), fakes.float()], dim=1)
            D_real_outs = D(reals_maps)
            D_real_loss = GANLoss(D_real_outs, True)
            D_fake_outs = D(fakes_maps)
            D_fake_loss = GANLoss(D_fake_outs, False)
            D_loss = 0.5 * (D_real_loss + D_fake_loss)
            D_loss = D_loss.mean()
            D_loss.backward()
            D_loss = D_loss.item()
            D_optimizer.step()

            # --- 3) Joint generator + segmentation step ---
            fakes = G(imgs_plus)
            fakes_maps = torch.cat([imgs.float(), fakes.float()], dim=1)
            D_fake_outs = D(fakes_maps)
            gan_loss = GANLoss(D_fake_outs, True)
            G_loss = 0
            G_loss += gan_loss
            gan_loss = gan_loss.mean().item()
            if args.use_vgg_loss:
                # Compare against the target map ``maps`` (not the source ``imgs``).
                vgg_loss = VGGLoss(fakes, maps)
                G_loss += args.lambda_feat * vgg_loss
                vgg_loss = vgg_loss.mean().item()
            else:
                vgg_loss = 0.
            if args.use_ganFeat_loss:
                df_loss = DFLoss(D_fake_outs, D_real_outs)
                G_loss += args.lambda_feat * df_loss
                df_loss = df_loss.mean().item()
            else:
                df_loss = 0.
            if args.use_low_level_loss:
                ll_loss = LLLoss(fakes, maps)
                G_loss += args.lambda_feat * ll_loss
                ll_loss = ll_loss.mean().item()
            else:
                ll_loss = 0.
            G_loss = G_loss.mean()
            G_seg_loss = (args._1002arg_GANloss_alpha * G_loss +
                          args._1002arg_segloss_alpha * seg_loss)
            G_loss = G_loss.item()
            seg_loss = seg_loss.item()
            G_optimizer.zero_grad()
            DLV3P_global_optimizer.zero_grad()
            DLV3P_backbone_optimizer.zero_grad()
            G_seg_loss.backward()
            G_optimizer.step()
            DLV3P_global_optimizer.step()
            DLV3P_backbone_optimizer.step()
            G_seg_loss_value = G_seg_loss.item()

            data_loader.write(
                f'Epochs:{epoch} | Dloss:{D_loss:.6f} | Gloss:{G_loss:.6f}'
                f'| GANloss:{gan_loss:.6f} | VGGloss:{vgg_loss:.6f} | DFloss:{df_loss:.6f} '
                f'| LLloss:{ll_loss:.6f} | lr_gan:{get_lr(G_optimizer):.8f}'
                f'| FOCAL_loss:{focal_loss_value:.6f}|LVS_loss:{lvs_loss_value:.6f} '
                f'| lr_global:{get_lr(DLV3P_global_optimizer):.8f}| lr_backbone:{get_lr(DLV3P_backbone_optimizer):.8f}'
            )
            G_loss_list.append(G_loss)
            D_loss_list.append(D_loss)
            LVS_loss_list.append(lvs_loss_value)
            FOCAL_loss_list.append(focal_loss_value)

            # tensorboard log
            if args.tensorboard_log and step % args.tensorboard_log == 0:
                total_steps = epoch * len(data_loader) + step
                sw.add_scalar('Loss1/G', G_loss, total_steps)
                sw.add_scalar('Loss1/seg', seg_loss, total_steps)
                sw.add_scalar('Loss1/G_seg', G_seg_loss_value, total_steps)
                sw.add_scalar('Loss/G', G_loss, total_steps)
                sw.add_scalar('Loss/D', D_loss, total_steps)
                sw.add_scalar('Loss/gan', gan_loss, total_steps)
                sw.add_scalar('Loss/vgg', vgg_loss, total_steps)
                sw.add_scalar('Loss/df', df_loss, total_steps)
                sw.add_scalar('Loss/ll', ll_loss, total_steps)
                sw.add_scalar('Loss/LVS', lvs_loss_value, total_steps)
                sw.add_scalar('Loss/focal', focal_loss_value, total_steps)
                sw.add_scalar('LR/G', get_lr(G_optimizer), total_steps)
                sw.add_scalar('LR/D', get_lr(D_optimizer), total_steps)
                sw.add_scalar('LR/global_seg', get_lr(DLV3P_global_optimizer), total_steps)
                sw.add_scalar('LR/backbone_seg', get_lr(DLV3P_backbone_optimizer), total_steps)
                sw.add_image('img2/realA', tensor2im(imgs.data), total_steps, dataformats='HWC')
                sw.add_image('img2/fakeB', tensor2im(fakes.data), total_steps, dataformats='HWC')
                sw.add_image('img2/realB', tensor2im(maps.data), total_steps, dataformats='HWC')
                # .detach().cpu() so this also works when pred2gray returns a CUDA tensor.
                tmpsegmap = pred2gray(outputs)
                tmpsegmap = gray2rgb(tmpsegmap[0].detach().cpu().numpy())
                sw.add_image('img2/fake_segB', tmpsegmap, total_steps, dataformats='HWC')
                tmpsegmap = gray2rgb(label_imgs[0].detach().cpu().numpy())
                sw.add_image('img2/real_segB', tmpsegmap, total_steps, dataformats='HWC')

        D_scheduler.step(epoch)
        G_scheduler.step(epoch)
        DLV3P_global_scheduler.step()
        DLV3P_backbone_scheduler.step()

        if epoch % args._val_frequency == 0 or epoch == (args.epochs - 1):
            args2 = copy.deepcopy(args)
            args2.batch_size = args.batch_size_eval
            fid, iou = eval_fidiou(
                args, model_G=G, model_seg=DLV3P,
                data_loader=get_pix2pix_maps_dataloader(args2, train=False))
            # Lower FID is better. The former chained comparison
            # ``fid < get_min(...) == 0`` could never be true, so no best
            # checkpoint was ever written.
            previous_fids = list(logger.get_data(key='FID'))
            if not previous_fids or fid < min(previous_fids):
                for name, obj in checkpoint.items():
                    model_saver.save(f'best_{name}_{epoch}', obj)
            logger.log(key='FID', data=fid)
            logger.log(key='iou', data=iou)
            sw.add_scalar('eval/fid', fid, epoch)
            sw.add_scalar('eval/iou', iou, epoch)

        logger.log(key='D_loss', data=sum(D_loss_list) / float(len(D_loss_list)))
        logger.log(key='G_loss', data=sum(G_loss_list) / float(len(G_loss_list)))
        logger.log(key='LVS_loss', data=sum(LVS_loss_list) / float(len(LVS_loss_list)))
        logger.log(key='FOCAL_loss', data=sum(FOCAL_loss_list) / float(len(FOCAL_loss_list)))
        logger.save_log()

        for name, obj in checkpoint.items():
            model_saver.save(name, obj)
def train(args, get_dataloader_func=get_cityscapes_dataloader):
    """Train pix2pixHD (generator G, discriminator D, feature encoder E) for label -> image.

    Training resumes from the ``ModelSaver``/``Logger`` state under ``args.save``;
    the starting epoch is the number of logged ``G_loss`` entries. Per batch:

    * ``E`` encodes the real image (using the instance map).
    * ``G`` receives the one-hot semantic labels, the instance edges and the
      encoded features.
    * ``D`` sees ``(labels, edges, image)`` stacks.
    * A discriminator step runs first, followed by a joint generator + encoder step.

    Args:
        args: options namespace.
        get_dataloader_func: factory ``(args, train=bool) -> DataLoader`` whose
            samples contain ``'image'``, ``'instance'``, ``'label'`` and ``'smask'``.
    """
    logger = Logger(save_path=args.save, json_name='seg2img')
    model_saver = ModelSaver(save_path=args.save,
                             name_list=['G', 'D', 'E', 'G_optimizer', 'D_optimizer', 'E_optimizer',
                                        'G_scheduler', 'D_scheduler', 'E_scheduler'])
    visualizer = Visualizer(keys=['image', 'encode_feature', 'fake', 'label', 'instance'])
    sw = SummaryWriter(args.tensorboard_path)
    G = get_G(args)
    D = get_D(args)
    E = get_E(args)
    model_saver.load('G', G)
    model_saver.load('D', D)
    model_saver.load('E', E)
    G_optimizer = Adam(G.parameters(), lr=args.G_lr, betas=(args.beta1, 0.999))
    D_optimizer = Adam(D.parameters(), lr=args.D_lr, betas=(args.beta1, 0.999))
    E_optimizer = Adam(E.parameters(), lr=args.E_lr, betas=(args.beta1, 0.999))
    model_saver.load('G_optimizer', G_optimizer)
    model_saver.load('D_optimizer', D_optimizer)
    model_saver.load('E_optimizer', E_optimizer)
    G_scheduler = get_hinge_scheduler(args, G_optimizer)
    D_scheduler = get_hinge_scheduler(args, D_optimizer)
    E_scheduler = get_hinge_scheduler(args, E_optimizer)
    model_saver.load('G_scheduler', G_scheduler)
    model_saver.load('D_scheduler', D_scheduler)
    model_saver.load('E_scheduler', E_scheduler)
    device = get_device(args)
    GANLoss = get_GANLoss(args)
    if args.use_ganFeat_loss:
        DFLoss = get_DFLoss(args)
    if args.use_vgg_loss:
        VGGLoss = get_VGGLoss(args)
    epoch_now = len(logger.get_data('G_loss'))
    for epoch in range(epoch_now, args.epochs):
        G_loss_list = []
        D_loss_list = []
        data_loader = tqdm(get_dataloader_func(args, train=True))
        for step, sample in enumerate(data_loader):
            imgs = sample['image'].to(device)
            instances = sample['instance'].to(device)
            labels = sample['label'].to(device)
            smasks = sample['smask'].to(device)
            instances_edge = get_edges(instances)
            one_hot_labels = label_to_one_hot(smasks.long(), n_class=args.label_nc)
            # Encoder output; it is part of G's input, so E is trained through
            # the generator loss in the second step below.
            encode_features = E(imgs, instances)

            # --- Discriminator step (generator output detached) ---
            D_optimizer.zero_grad()
            labels_instE_encodeF = torch.cat(
                [one_hot_labels.float(), instances_edge.float(), encode_features.float()], dim=1)
            fakes = G(labels_instE_encodeF).detach()
            labels_instE_realimgs = torch.cat(
                [one_hot_labels.float(), instances_edge.float(), imgs.float()], dim=1)
            D_real_outs = D(labels_instE_realimgs)
            D_real_loss = GANLoss(D_real_outs, True)
            labels_instE_fakeimgs = torch.cat(
                [one_hot_labels.float(), instances_edge.float(), fakes.float()], dim=1)
            D_fake_outs = D(labels_instE_fakeimgs)
            D_fake_loss = GANLoss(D_fake_outs, False)
            D_loss = 0.5 * (D_real_loss + D_fake_loss)
            D_loss = D_loss.mean()
            D_loss.backward()
            D_loss = D_loss.item()
            D_optimizer.step()

            # --- Generator + encoder step ---
            G_optimizer.zero_grad()
            E_optimizer.zero_grad()
            fakes = G(labels_instE_encodeF)
            labels_instE_fakeimgs = torch.cat(
                [one_hot_labels.float(), instances_edge.float(), fakes.float()], dim=1)
            D_fake_outs = D(labels_instE_fakeimgs)
            gan_loss = GANLoss(D_fake_outs, True)
            G_loss = 0
            G_loss += gan_loss
            gan_loss = gan_loss.mean().item()
            if args.use_vgg_loss:
                vgg_loss = VGGLoss(fakes, imgs)
                G_loss += args.lambda_feat * vgg_loss
                vgg_loss = vgg_loss.mean().item()
            else:
                vgg_loss = 0.
            if args.use_ganFeat_loss:
                df_loss = DFLoss(D_fake_outs, D_real_outs)
                G_loss += args.lambda_feat * df_loss
                df_loss = df_loss.mean().item()
            else:
                df_loss = 0.
            G_loss = G_loss.mean()
            G_loss.backward()
            G_loss = G_loss.item()
            G_optimizer.step()
            E_optimizer.step()

            data_loader.write(f'Epochs:{epoch} | Dloss:{D_loss:.6f} | Gloss:{G_loss:.6f}'
                              f'| GANloss:{gan_loss:.6f} | VGGloss:{vgg_loss:.6f} '
                              f'| DFloss:{df_loss:.6f} | lr:{get_lr(G_optimizer):.8f}')
            G_loss_list.append(G_loss)
            D_loss_list.append(D_loss)

            # display
            if args.display and step % args.display == 0:
                visualizer.display(transforms.ToPILImage()(encode_features[0].detach().cpu()), 'encode_feature')
                visualizer.display(transforms.ToPILImage()(imgs[0].cpu()), 'image')
                visualizer.display(transforms.ToPILImage()(fakes[0].detach().cpu()), 'fake')
                visualizer.display(transforms.ToPILImage()(labels[0].cpu() * 15), 'label')
                visualizer.display(transforms.ToPILImage()(instances[0].cpu() * 15), 'instance')

            # tensorboard log
            if args.tensorboard_log and step % args.tensorboard_log == 0:
                total_steps = epoch * len(data_loader) + step
                sw.add_scalar('Loss/G', G_loss, total_steps)
                sw.add_scalar('Loss/D', D_loss, total_steps)
                sw.add_scalar('Loss/gan', gan_loss, total_steps)
                sw.add_scalar('Loss/vgg', vgg_loss, total_steps)
                sw.add_scalar('Loss/df', df_loss, total_steps)
                sw.add_scalar('LR/G', get_lr(G_optimizer), total_steps)
                sw.add_scalar('LR/D', get_lr(D_optimizer), total_steps)
                sw.add_scalar('LR/E', get_lr(E_optimizer), total_steps)
                # Global step (not the per-epoch ``step``) so images from
                # different epochs do not overwrite each other.
                sw.add_image('img/real', imgs[0].cpu(), total_steps)
                sw.add_image('img/fake', fakes[0].cpu(), total_steps)
                sw.add_image('visual/encode_feature', encode_features[0].cpu(), total_steps)
                sw.add_image('visual/instance', instances[0].cpu(), total_steps)
                sw.add_image('visual/label', labels[0].cpu(), total_steps)

        D_scheduler.step(epoch)
        G_scheduler.step(epoch)
        E_scheduler.step(epoch)

        logger.log(key='D_loss', data=sum(D_loss_list) / float(len(D_loss_list)))
        logger.log(key='G_loss', data=sum(G_loss_list) / float(len(G_loss_list)))
        logger.save_log()
        logger.visualize()

        model_saver.save('G', G)
        model_saver.save('D', D)
        model_saver.save('E', E)
        model_saver.save('G_optimizer', G_optimizer)
        model_saver.save('D_optimizer', D_optimizer)
        model_saver.save('E_optimizer', E_optimizer)
        model_saver.save('G_scheduler', G_scheduler)
        model_saver.save('D_scheduler', D_scheduler)
        model_saver.save('E_scheduler', E_scheduler)
def eval(args, model, data_loader, model_seg=None):
    """Generate images with ``model`` (optionally guided by ``model_seg``) and score them.

    For every sample, the following are written as ``.tif`` under ``args.save``:

    * the source ``sample['A']`` → ``real_source``
    * the target ``sample['B']`` → ``real_result``
    * the generated image → ``fake_result``

    FID is computed between ``real_result`` and ``fake_result``.

    When ``model_seg`` is given, it is run on ``sample['A']``. Its feature map is
    bilinearly resized to 64x64 and 128x128 and passed to ``model`` as extra
    inputs. Its class predictions are saved into ``seg_result`` and compared
    with ``sample['seg']`` to compute IoU.

    Both models are returned to the train/eval mode they were in on entry. The
    segmentation model used to be left in eval mode.

    Returns:
        ``(fid, iou)``; ``iou`` is ``None`` when ``model_seg`` is ``None``.

    Raises:
        ValueError: if ``model_seg`` is given but ``data_loader`` yields no
            samples (there would be nothing to compute IoU from).
    """
    device = get_device(args)
    data_loader = tqdm(data_loader)
    model_was_training = model.training
    model.eval()
    model = model.to(device)
    use_seg = model_seg is not None
    if use_seg:
        seg_was_training = model_seg.training
        model_seg.eval()
        model_seg = model_seg.to(device)
        from src.pix2pixHD.myutils import pred2gray
    label_preds = []
    label_targets = []
    n_class = None
    fake_dir = osp.join(args.save, 'fake_result')
    real_dir = osp.join(args.save, 'real_result')
    A_dir = osp.join(args.save, 'real_source')
    seg_dir = osp.join(args.save, 'seg_result')
    create_dir(real_dir)
    create_dir(fake_dir)
    create_dir(A_dir)
    create_dir(seg_dir)
    try:
        for sample in data_loader:
            imgs = sample['A'].to(device)
            maps = sample['B'].to(device)
            im_name = sample['A_paths']
            with torch.no_grad():
                if not use_seg:
                    fakes = model(imgs)
                else:
                    outputs, feature_map = model_seg(imgs)
                    # F.interpolate supersedes the deprecated F.upsample;
                    # align_corners=False is what both use when it is left as None.
                    input_2 = F.interpolate(feature_map, size=(64, 64),
                                            mode='bilinear', align_corners=False)
                    input_3 = F.interpolate(feature_map, size=(128, 128),
                                            mode='bilinear', align_corners=False)
                    fakes = model(imgs, input_2, input_3)
                    # Collect arg-max predictions and targets for the IoU.
                    bs, n_class, h, w = outputs.shape
                    outs = outputs.cpu().numpy()
                    pred = outs.transpose(0, 2, 3, 1).reshape(
                        -1, n_class).argmax(axis=1).reshape(bs, h, w)
                    target = sample['seg'].cpu().numpy().reshape(bs, h, w)
                    label_preds.append(pred)
                    label_targets.append(target)
            if use_seg:
                outputs = pred2gray(outputs)
            for b in range(imgs.size(0)):
                # "<grandparent dir>_<parent dir>_<stem before first '.'>"
                head, tail = osp.split(im_name[b])
                parents = head.split(os.sep)
                file_name = parents[-2] + '_' + parents[-1] + '_' + tail.split('.')[0]
                real_file = osp.join(real_dir, f'{file_name}.tif')
                fake_file = osp.join(fake_dir, f'{file_name}.tif')
                A_file = osp.join(A_dir, f'{file_name}.tif')
                from_std_tensor_save_image(filename=real_file, data=maps[b].cpu())
                from_std_tensor_save_image(filename=fake_file, data=fakes[b].cpu())
                from_std_tensor_save_image(filename=A_file, data=imgs[b].cpu())
                if use_seg:
                    seg_file = osp.join(seg_dir, f'{file_name}.tif')
                    from_std_tensor_save_image(filename=seg_file, data=outputs[b].cpu())
    finally:
        model.train(model_was_training)
        if use_seg:
            model_seg.train(seg_was_training)
    fid = fid_score(real_path=real_dir, fake_path=fake_dir, gpu=str(args.gpu))
    print(f'===> fid score:{fid:.4f}')
    iou = None
    if use_seg:
        if n_class is None:
            raise ValueError('eval: data_loader yielded no samples, cannot compute IoU')
        from src.pix2pixHD.eval_iou import label_accuracy_score
        _, _, iou, _, _ = label_accuracy_score(label_targets, label_preds, n_class)
    return fid, iou