def train(net, dataloader, epoch, opt, criterion):
    """Run one training epoch of the classifier.

    With probability ``mask_rate`` a batch is occluded by a generated mask
    (``imgs * (1 - masks)``) before the forward pass; otherwise the clean
    batch is used. Loss/accuracy running averages are logged every
    ``print_freq`` batches.

    Args:
        net: classification network; assumed already on the module-level ``device``.
        dataloader: yields ``(imgs, cls_ids)`` batches.
        epoch: current epoch index (logging only).
        opt: optimizer for ``net``'s parameters.
        criterion: classification loss (e.g. cross-entropy).
    """
    net.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()
    for i, (imgs, cls_ids) in enumerate(dataloader):
        data_time.update(time.time() - end)
        imgs, cls_ids = imgs.to(device), cls_ids.to(device)
        opt.zero_grad()
        if np.random.rand() < mask_rate:
            # FIX: generate the mask only when it is actually applied. The
            # original called generate_mask() and moved it to `device` on
            # every iteration, wasting work on the un-masked branch.
            # (If generate_mask() consumes np.random, the RNG stream shifts
            # slightly relative to the original — training is stochastic
            # either way.)
            masks = generate_mask().to(device)
            pred = net(imgs * (1 - masks))
        else:
            pred = net(imgs)
        loss = criterion(pred, cls_ids)
        loss.backward()
        opt.step()
        # measure accuracy and record loss
        prec1, prec5 = accuracy(pred, cls_ids, topk=(1, 5))
        losses.update(loss.item(), imgs.size(0))
        top1.update(prec1[0], imgs.size(0))
        top5.update(prec5[0], imgs.size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        if i % print_freq == 0:
            logger.info('Epoch: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                            epoch, i, len(dataloader), batch_time=batch_time,
                            data_time=data_time, loss=losses, top1=top1,
                            top5=top5))
def validate(net, dataloader, criterion):
    """Evaluate `net` on masked inputs.

    Every batch is occluded by a freshly generated mask before the forward
    pass (unlike the train loop, which only masks a fraction of batches).

    Returns:
        (top1.avg, top5.avg) running-average accuracies over the whole set.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    net.eval()
    with torch.no_grad():
        tick = time.time()
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            occlusion = generate_mask().to(device)
            logits = net(inputs * (1 - occlusion))
            loss = criterion(logits, targets)
            # record loss and top-1 / top-5 accuracy
            acc1, acc5 = accuracy(logits, targets, topk=(1, 5))
            n = inputs.size(0)
            losses.update(loss.item(), n)
            top1.update(acc1[0], n)
            top5.update(acc5[0], n)
            # measure elapsed time
            batch_time.update(time.time() - tick)
            tick = time.time()
            if batch_idx % print_freq == 0:
                logger.info(
                    'Test: [{0}/{1}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                    'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                        batch_idx, len(dataloader), batch_time=batch_time,
                        loss=losses, top1=top1, top5=top5))
    return top1.avg, top5.avg
def train(netG, netD, GANLoss, ReconLoss, DLoss, NLoss, optG, optD, dataloader, epoch, device=cuda0, val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan).

    One epoch: per batch, first update the discriminator on real vs.
    generated images, then update the generator on GAN + reconstruction +
    NLoss terms. Periodically logs scalars/images to tensorboard and runs
    validation every config.VAL_SUMMARY_FREQ batches.
    """
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {"g_loss": AverageMeter(), "r_loss": AverageMeter(),
              "whole_loss": AverageMeter(), "d_loss": AverageMeter(),
              'n_loss': AverageMeter()}
    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks, gray) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['mine']
        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()
        imgs, masks, gray = imgs.to(device), masks.to(device), gray.to(device)
        # after inversion, mask is 1 on masked (hole) regions
        masks = 1 - masks / 255.0
        imgs = (imgs / 127.5 - 1)
        gray = (gray / 127.5 - 1)
        # generator consumes the grayscale input plus the mask
        coarse_imgs, refined, mixed = netG(gray, masks)
        # NOTE(review): no compositing with the original image here — the raw
        # generator output is treated as the completed image.
        complete_imgs = mixed  # * masks + imgs * (1 - masks)
        pos_imgs = imgs
        neg_imgs = complete_imgs
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # retain_graph: the generator graph is reused for the G update below
        d_loss.backward(retain_graph=True)
        optD.step()
        # Optimize Generator
        optD.zero_grad(), netD.zero_grad()
        optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        # FIX: removed `pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)`
        # which immediately clobbered the fresh discriminator pass above and
        # fed stale (pre-D-update) scores into the generator loss.
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, mixed, masks)
        n_loss = NLoss(coarse_imgs, refined, mixed, imgs)
        whole_loss = g_loss + r_loss + n_loss
        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['n_loss'].update(n_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()
        optG.step()
        # Update time recorder
        batch_time.update(time.time() - end)
        if (i+1) % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f}, \t New Loss: {n_loss.val:.4f}, \t D Loss: {d_loss.val:.4f}"
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time,
                                data_time=data_time, whole_loss=losses['whole_loss'],
                                r_loss=losses['r_loss'], g_loss=losses['g_loss'],
                                d_loss=losses['d_loss'], n_loss=losses['n_loss']))
            # Tensorboard logger for scalars and images
            info_terms = {'WGLoss': whole_loss.item(), 'ReconLoss': r_loss.item(),
                          "GANLoss": g_loss.item(), "DLoss": d_loss.item()}
            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value, epoch*len(dataloader)+i)
            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_'+tag, value.avg, epoch*len(dataloader)+i)

            def img2photo(imgs):
                # NCHW -> NHWC, back to a 0..255 photo range
                return ((imgs+1)*127.5).transpose(1, 2).transpose(2, 3).detach().cpu().numpy()

            info = {
                'train/whole_imgs': img2photo(torch.cat(
                    [imgs * (1 - masks) + masks, refined, imgs * masks,
                     complete_imgs, imgs], dim=3))
            }
            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images, epoch*len(dataloader)+i)
        if (i+1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:
            with torch.no_grad():
                validate(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD,
                         val_datas, epoch, device, batch_n=i)
            netG.train()
            # netD.train()
        end = time.time()
def pretrainD(netG, netD, GANLoss, ReconLoss, DLoss, NLoss, optG, optD, dataloader, epoch, device=cuda0, val_datas=None):
    """
    Pretrain the discriminator for one epoch.

    Runs the same forward pipeline as the main GAN `train` (grayscale input
    through netG, raw generator output as the fake image) but only takes
    discriminator optimizer steps — the generator is never updated here.
    GANLoss/ReconLoss/NLoss/optG/val_datas are accepted for signature
    parity with `train` and unused.
    """
    logger.info("Pretraining D epoch %d"%epoch)
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {"g_loss": AverageMeter(), "r_loss": AverageMeter(),
              "whole_loss": AverageMeter(), "d_loss": AverageMeter(),
              'n_loss': AverageMeter()}
    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks, gray) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['mine']
        # Optimize Discriminator only
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()
        imgs, masks, gray = imgs.to(device), masks.to(device), gray.to(device)
        # after inversion, mask is 1 on masked (hole) regions
        masks = 1 - masks / 255.0
        imgs = (imgs / 127.5 - 1)
        gray = (gray / 127.5 - 1)
        coarse_imgs, refined, mixed = netG(gray, masks)
        # raw generator output used as the fake sample (no compositing)
        complete_imgs = mixed
        pos_imgs = imgs
        neg_imgs = complete_imgs
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        d_loss.backward(retain_graph=True)
        optD.step()
        batch_time.update(time.time() - end)
        if (i+1) % config.SUMMARY_FREQ == 0:
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, \t D Loss: {d_loss.val:.4f}"
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time,
                                data_time=data_time, d_loss=losses['d_loss']))
        # FIX: reset the per-iteration timer. The original never updated `end`
        # inside the loop, so batch_time/data_time measured cumulative time
        # since the start of the epoch and grew monotonically.
        end = time.time()
def train(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD, dataloader, epoch, device=device, val_datas=None):
    """ Train Phase, for training and spectral normalization patch gan in Free-Form Image Inpainting with Gated Convolution (snpgan) """
    # NOTE(review): this variant normalizes inputs with `imgs / 250` (range
    # roughly [0, ~1]) instead of the `imgs / 127.5 - 1` used elsewhere in
    # this file — confirm whether that is intentional for this dataset.
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }
    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        #masks = masks['random_free_form']
        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()
        imgs, masks = imgs.to(device), masks.to(device)
        imgs = (imgs / 250)
        # mask is 1 on masked region
        coarse_imgs, recon_imgs = netG(imgs, masks)
        #print(attention.size(), )
        # composite: generated content in the hole, original elsewhere
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)
        # discriminator input: image + mask + all-ones channel
        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # retain_graph: the generator graph is reused for the G update below
        d_loss.backward(retain_graph=True)
        optD.step()
        # Optimize Generator
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        #pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)
        whole_loss = g_loss + r_loss
        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()
        optG.step()
        # Update time recorder
        batch_time.update(time.time() - end)
        # NOTE(review): the triple quote below opens a string literal and is
        # never closed in this view — the remainder of this loop body (logging,
        # `end` reset) appears to be commented out or truncated. Verify the
        # closing quotes exist further down in the original file.
        '''
def validate(nets, loss_terms, opts, dataloader, epoch, network_type, devices=(cuda0, cuda1), batch_n="whole_test_show"):
    """
    validate phase

    Multi-scale validation: each batch is evaluated at every size in
    TRAIN_SIZES, feeding the previous scale's completed image into the next
    (for the 'l2h_*' network types). Saves real/generated/composited PNGs
    per size under result_dir and accumulates GAN / reconstruction /
    perceptual / style / discriminator losses. Perceptual and style losses
    are computed on devices[1]; everything else on devices[0].
    """
    netD, netG = nets["netD"], nets["netG"]
    ReconLoss, DLoss, PercLoss, GANLoss, StyleLoss = loss_terms[
        'ReconLoss'], loss_terms['DLoss'], loss_terms["PercLoss"], loss_terms[
        "GANLoss"], loss_terms["StyleLoss"]
    optG, optD = opts['optG'], opts['optD']  # unused here; kept for signature parity
    device0, device1 = devices
    netG.to(device0)
    netD.to(device0)
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "p_loss": AverageMeter(),
        "s_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }
    # NOTE(review): .train() right after .eval() re-enables dropout/batchnorm
    # updates during "validation" — confirm this is intended.
    netG.train()
    netD.train()
    end = time.time()
    val_save_dir = os.path.join(
        result_dir, "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_comp_dir = os.path.join(val_save_dir, "comp")
    # one sub-directory per evaluated size under real/gen/comp
    for size in SIZES_TAGS:
        if not os.path.exists(os.path.join(val_save_real_dir, size)):
            os.makedirs(os.path.join(val_save_real_dir, size))
        if not os.path.exists(os.path.join(val_save_gen_dir, size)):
            os.makedirs(os.path.join(val_save_gen_dir, size))
        if not os.path.exists(os.path.join(val_save_comp_dir, size)):
            os.makedirs(os.path.join(val_save_comp_dir, size))
    info = {}
    t = 0
    for i, (ori_imgs, ori_masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        pre_imgs = ori_imgs
        # normalized full-resolution image seeds the coarse-to-fine chain
        pre_complete_imgs = (pre_imgs / 127.5 - 1)
        for s_i, size in enumerate(TRAIN_SIZES):
            masks = ori_masks['val']
            masks = F.interpolate(masks, size)
            # binarize after interpolation (interp produces fractional values)
            masks = (masks > 0).type(torch.FloatTensor)
            imgs = F.interpolate(ori_imgs, size)
            if imgs.size(1) != 3:
                print(t, imgs.size())
            pre_inter_imgs = F.interpolate(pre_complete_imgs, size)
            imgs, masks, pre_complete_imgs, pre_inter_imgs = imgs.to(
                device0), masks.to(device0), pre_complete_imgs.to(
                device0), pre_inter_imgs.to(device0)
            # masks = (masks > 0).type(torch.FloatTensor)
            # imgs, masks = imgs.to(device), masks.to(device)
            imgs = (imgs / 127.5 - 1)
            # mask is 1 on masked region
            # forward — generator call signature depends on the architecture
            if network_type == 'l2h_unet':
                recon_imgs = netG(imgs, masks, pre_complete_imgs, pre_inter_imgs, size)
            elif network_type == 'l2h_gated':
                recon_imgs = netG(imgs, masks, pre_inter_imgs)
            elif network_type == 'sa_gated':
                recon_imgs, _ = netG(imgs, masks)
            # NOTE(review): recon_imgs is unbound for any other network_type.
            complete_imgs = recon_imgs * masks + imgs * (1 - masks)
            pos_imgs = torch.cat(
                [imgs, masks, torch.full_like(masks, 1.)], dim=1)
            # NOTE(review): fake sample uses recon_imgs here, not the
            # composite complete_imgs used by the train loops — confirm.
            neg_imgs = torch.cat(
                [recon_imgs, masks, torch.full_like(masks, 1.)], dim=1)
            pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
            pred_pos_neg = netD(pos_neg_imgs)
            pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
            g_loss = GANLoss(pred_neg)
            # NOTE(review): recon_imgs is passed for both the coarse and the
            # refined slot of ReconLoss (no coarse output here) — confirm.
            r_loss = ReconLoss(imgs, recon_imgs, recon_imgs, masks)
            # perceptual/style losses are evaluated on the second device
            imgs, recon_imgs, complete_imgs = imgs.to(device1), recon_imgs.to(
                device1), complete_imgs.to(device1)
            p_loss = PercLoss(imgs, recon_imgs) + PercLoss(imgs, complete_imgs)
            s_loss = StyleLoss(imgs, recon_imgs) + StyleLoss(
                imgs, complete_imgs)
            p_loss, s_loss = p_loss.to(device0), s_loss.to(device0)
            imgs, recon_imgs, complete_imgs = imgs.to(device0), recon_imgs.to(
                device0), complete_imgs.to(device0)
            whole_loss = r_loss + p_loss  # g_loss + r_loss
            # Update the recorder for losses
            losses['g_loss'].update(g_loss.item(), imgs.size(0))
            losses['r_loss'].update(r_loss.item(), imgs.size(0))
            losses['p_loss'].update(p_loss.item(), imgs.size(0))
            losses['s_loss'].update(s_loss.item(), imgs.size(0))
            losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
            d_loss = DLoss(pred_pos, pred_neg)
            losses['d_loss'].update(d_loss.item(), imgs.size(0))
            # feed this scale's composite into the next (larger) scale
            pre_complete_imgs = complete_imgs
            # Update time recorder
            batch_time.update(time.time() - end)
            # Logger logging
            # if t < config.STATIC_VIEW_SIZE:
            print(i, size)
            # save the first image of the batch at this scale as PNGs
            real_img = img2photo(imgs)
            gen_img = img2photo(recon_imgs)
            comp_img = img2photo(complete_imgs)
            real_img = Image.fromarray(real_img[0].astype(np.uint8))
            gen_img = Image.fromarray(gen_img[0].astype(np.uint8))
            comp_img = Image.fromarray(comp_img[0].astype(np.uint8))
            real_img.save(
                os.path.join(val_save_real_dir, SIZES_TAGS[s_i],
                             "{}.png".format(i)))
            gen_img.save(
                os.path.join(val_save_gen_dir, SIZES_TAGS[s_i],
                             "{}.png".format(i)))
            comp_img.save(
                os.path.join(val_save_comp_dir, SIZES_TAGS[s_i],
                             "{}.png".format(i)))
            end = time.time()
def validate(nets, loss_terms, opts, dataloader, epoch, devices=(cuda0, cuda1), batch_n="whole"):
    """
    validate phase

    Evaluates generator + discriminator on the validation set (with an
    auxiliary exemplar input `img_exs`), collecting the first
    config.STATIC_VIEW_SIZE batches as image strips, then on the next batch
    logs losses, saves real/generated PNGs, computes FID/SSIM, and stops.

    Args:
        nets: {"netG": ..., "netD": ...}.
        loss_terms: {"GANLoss", "ReconLoss", "L1ReconLoss", "DLoss"}.
        opts: {"optG", "optD"} — unused here, kept for signature parity.
        dataloader: yields (imgs, img_exs, masks) with masks['val'].
        epoch: epoch index for directory naming / tensorboard steps.
        devices: (primary, secondary) devices; only devices[0] is used.
        batch_n: batch index, or a string label for the save directory.
    """
    netG, netD = nets["netG"], nets["netD"]
    GANLoss, ReconLoss, L1ReconLoss, DLoss = loss_terms["GANLoss"], loss_terms["ReconLoss"], loss_terms["L1ReconLoss"], loss_terms["DLoss"]
    optG, optD = opts["optG"], opts["optD"]
    device0, device1 = devices[0], devices[1]
    netG.to(device0)
    netD.to(device0)
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {"g_loss": AverageMeter(), "r_loss": AverageMeter(),
              "r_ex_loss": AverageMeter(), "whole_loss": AverageMeter(),
              'd_loss': AverageMeter(), 'mask_d_loss': AverageMeter(),
              'mask_rec_loss': AverageMeter(), 'mask_whole_loss': AverageMeter()}
    end = time.time()
    # FIX: `batch_n + 1` raised TypeError with the default batch_n="whole";
    # use the same isinstance guard as the sibling validate() functions.
    val_save_dir = os.path.join(
        result_dir,
        "val_{}_{}".format(epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_inf_dir = os.path.join(val_save_dir, "inf")
    # FIX: create each directory independently (the original only checked the
    # "real" dir and would crash if it existed but the others did not).
    os.makedirs(val_save_real_dir, exist_ok=True)
    os.makedirs(val_save_gen_dir, exist_ok=True)
    os.makedirs(val_save_inf_dir, exist_ok=True)
    info = {}
    for i, data in enumerate(dataloader):
        data_time.update(time.time() - end, 1)
        imgs, img_exs, masks = data
        masks = masks['val']
        imgs, img_exs, masks = imgs.to(device0), img_exs.to(device0), masks.to(device0)
        imgs = (imgs / 127.5 - 1)
        img_exs = (img_exs / 127.5 - 1)
        # mask is 1 on masked region
        # forward
        coarse_imgs, recon_imgs, recon_ex_imgs = netG(imgs, img_exs, masks)
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)
        # discriminator input: image + mask + all-ones channel
        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat([complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)
        r_ex_loss = L1ReconLoss(img_exs, recon_ex_imgs)
        whole_loss = g_loss + r_loss + r_ex_loss
        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['r_ex_loss'].update(r_ex_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # Update time recorder
        batch_time.update(time.time() - end, 1)
        if (i+1) < config.STATIC_VIEW_SIZE:
            # collect an image strip (masked | coarse | recon | real | composite)
            def img2photo(imgs):
                return ((imgs+1)*127.5).transpose(1, 2).transpose(2, 3).detach().cpu().numpy()
            info['val/whole_imgs/{}'.format(i)] = {
                "img": img2photo(torch.cat(
                    [imgs * (1 - masks), coarse_imgs, recon_imgs, imgs,
                     complete_imgs], dim=3)),
            }
        else:
            # final batch: log, dump saved strips to PNG, score, and stop
            logger.info("Validation Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f},\t Ex Recon Loss:{r_ex_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}"
                        .format(epoch, i+1, len(dataloader),
                                batch_time=batch_time, data_time=data_time,
                                whole_loss=losses['whole_loss'],
                                r_loss=losses['r_loss'],
                                r_ex_loss=losses['r_ex_loss'],
                                g_loss=losses['g_loss'],
                                d_loss=losses['d_loss']))
            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('val/avg_'+tag, value.avg, epoch*len(dataloader)+i)
            j = 0
            for tag, datas in info.items():
                images = datas["img"]
                h, w = images.shape[1], images.shape[2] // 5
                for kv, val_img in enumerate(images):
                    # columns 3 and 4 of the strip are the real and generated images
                    real_img = val_img[:, (3*w):(4*w), :]
                    gen_img = val_img[:, (4*w):(5*w), :]
                    real_img = Image.fromarray(real_img.astype(np.uint8))
                    gen_img = Image.fromarray(gen_img.astype(np.uint8))
                    real_img.save(os.path.join(val_save_real_dir, "{}.png".format(j)))
                    gen_img.save(os.path.join(val_save_gen_dir, "{}.png".format(j)))
                    j += 1
                tensorboardlogger.image_summary(tag, images, epoch)
            path1, path2 = val_save_real_dir, val_save_gen_dir
            fid_score = metrics['fid']([path1, path2], cuda=False)
            ssim_score = metrics['ssim']([path1, path2])
            tensorboardlogger.scalar_summary('val/fid', fid_score.item(), epoch*len(dataloader)+i)
            tensorboardlogger.scalar_summary('val/ssim', ssim_score.item(), epoch*len(dataloader)+i)
            break
        end = time.time()
def validate(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD, dataloader, epoch, device=cuda0, batch_n="whole"):
    """
    validate phase

    Collects image strips for the first config.STATIC_VIEW_SIZE batches,
    then on the next batch logs running losses, writes raw/real/generated
    PNGs via cv2, computes FID/SSIM and stops. Finally pushes average losses
    to wandb. optG/optD are unused; kept for signature parity.
    """
    netG.to(device)
    netD.to(device)
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }
    # NOTE(review): .train() right after .eval() re-enables dropout/batchnorm
    # updates during "validation" — confirm this is intended.
    netG.train()
    netD.train()
    end = time.time()
    val_save_dir = os.path.join(
        result_dir, "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_raw_dir = os.path.join(val_save_dir, 'raw')
    # val_save_inf_dir = os.path.join(val_save_dir, "inf")
    # NOTE(review): only the "real" dir is checked; if it exists while the
    # others do not, the following makedirs calls are skipped entirely.
    if not os.path.exists(val_save_real_dir):
        os.makedirs(val_save_real_dir)
        os.makedirs(val_save_gen_dir)
        os.makedirs(val_save_raw_dir)
    info = {}
    for i, (imgs, masks, mean, std) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']
        #masks = (masks > 0).type(torch.FloatTensor)
        imgs, masks = imgs.to(device), masks.to(device)
        # imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region — inputs here are already normalized by
        # the dataloader (per-batch mean/std are provided for un-normalizing).
        # forward
        coarse_imgs, recon_imgs = netG(imgs, masks)
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)
        # discriminator input: image + mask + all-ones channel
        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)
        whole_loss = g_loss + r_loss
        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # Update time recorder
        batch_time.update(time.time() - end)
        # Logger logging
        if i + 1 < config.STATIC_VIEW_SIZE:
            # collect a 5-column strip: masked | coarse | recon | real | composite
            # (each column un-normalized with the batch's mean/std)
            def img2photo(imgs):
                return (imgs * 255).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()
            # return ((imgs+1)*127.5).transpose(1,2).transpose(2,3).detach().cpu().numpy()
            info['val/whole_imgs/{}'.format(i)] = img2photo(
                torch.cat(
                    [((imgs * std.cpu().numpy()[0]) + mean.cpu().numpy()[0]) *
                     (1 - masks) + masks,
                     ((coarse_imgs * std.cpu().numpy()[0]) + mean.cpu().numpy()[0]),
                     ((recon_imgs * std.cpu().numpy()[0]) + mean.cpu().numpy()[0]),
                     ((imgs * std.cpu().numpy()[0]) + mean.cpu().numpy()[0]),
                     ((complete_imgs * std.cpu().numpy()[0]) + mean.cpu().numpy()[0])],
                    dim=3))
        else:
            # final batch: log, dump the collected strips to PNG, score, stop
            logger.info("Validation Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}"
                        .format(epoch, i+1, len(dataloader),
                                batch_time=batch_time, data_time=data_time,
                                whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                                ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # vis.line([[r_loss.item(), (1 / config.GAN_LOSS_ALPHA) * g_loss.item(), d_loss.item()]],
            #          epoch*len(dataloader)+i, win='validation_loss', update='append')
            j = 0
            for tag, images in info.items():
                h, w = images.shape[1], images.shape[2] // 5
                for val_img in images:
                    raw_img = val_img[:, 0:w, :]
                    # raw_img = ((raw_img - np.min(raw_img)) / np.max(raw_img)) * 255
                    real_img = val_img[:, (3 * w):(4 * w), :]
                    # real_img = ((real_img - np.min(real_img)) / np.max(real_img)) * 255
                    gen_img = val_img[:, (4 * w):, :]
                    # gen_img = ((gen_img - np.min(gen_img)) / np.max(gen_img)) * 255
                    cv2.imwrite(
                        os.path.join(val_save_real_dir, "{}.png".format(j)),
                        real_img)
                    cv2.imwrite(
                        os.path.join(val_save_gen_dir, "{}.png".format(j)),
                        gen_img)
                    cv2.imwrite(
                        os.path.join(val_save_raw_dir, "{}.png".format(j)),
                        raw_img)
                    j += 1
            # tensorboardlogger.image_summary(tag, images, epoch)
            path1, path2 = val_save_real_dir, val_save_gen_dir
            fid_score = metrics['fid']([path1, path2], cuda=False)
            ssim_score = metrics['ssim']([path1, path2])
            # vis.line([[fid_score.item(),ssim_score.item()]], [epoch*len(dataloader)+i], win='validation_metric', update='append')
            # tensorboardlogger.scalar_summary('val/fid', fid_score.item(), epoch*len(dataloader)+i)
            # tensorboardlogger.scalar_summary('val/ssim', ssim_score.item(), epoch*len(dataloader)+i)
            break
        end = time.time()
    # vis.line([[losses['r_loss'].out(), (1 / config.GAN_LOSS_ALPHA) * losses['g_loss'].out(), losses['d_loss'].out()]],
    #          [epoch], win='validation_loss', update='append')
    wandb.log({
        "val_r_loss": losses['r_loss'].out(),
        "val_g_loss": (1 / config.GAN_LOSS_ALPHA) * losses['g_loss'].out(),
        "val_d_loss": losses['d_loss'].out()
    })
def validate(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD, dataloader, epoch, device=cuda0):
    """
    validate phase

    Runs generator + discriminator over the validation set, logging running
    losses every batch and pushing image summaries to tensorboard for the
    first config.STATIC_VIEW_SIZE images. optG/optD are unused; kept for
    signature parity with train(). No gradient steps are taken, but note
    that no torch.no_grad() guard is used here.
    """
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }
    # NOTE(review): .train() right after .eval() re-enables dropout/batchnorm
    # updates during "validation" — confirm this is intended.
    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']
        #masks = (masks > 0).type(torch.FloatTensor)
        imgs, masks = imgs.to(device), masks.to(device)
        imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        # forward
        coarse_imgs, recon_imgs = netG(imgs, masks)
        # composite: generated content in the hole, original elsewhere
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)
        # discriminator input: image + mask + all-ones channel
        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)
        whole_loss = g_loss + r_loss
        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # Update time recorder
        batch_time.update(time.time() - end)
        # Logger logging (every batch)
        logger.info("Validation Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f}\t,"
                    "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}"
                    .format(epoch, i, len(dataloader),
                            batch_time=batch_time, data_time=data_time,
                            whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                            ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))
        if i * config.BATCH_SIZE < config.STATIC_VIEW_SIZE:
            # tensorboard image summaries for the first few batches
            def img2photo(imgs):
                # NCHW -> NHWC, back to a 0..255 photo range
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()
            info = {
                'val/ori_imgs': img2photo(imgs),
                'val/coarse_imgs': img2photo(coarse_imgs),
                'val/recon_imgs': img2photo(recon_imgs),
                'val/comp_imgs': img2photo(complete_imgs),
                'val/whole_imgs': img2photo(
                    torch.cat([imgs, coarse_imgs, recon_imgs, complete_imgs],
                              dim=3))
            }
            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images, i)
        end = time.time()
def train(nets, loss_terms, opts, dataloader, epoch, devices=(cuda0, cuda1), val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan)

    Variant with an auxiliary exemplar input (img_exs) reconstructed by an
    extra L1 term. Per batch: discriminator step, then generator step on
    GAN + reconstruction + exemplar-reconstruction losses. A disabled
    mask-discriminator branch is kept commented throughout.
    """
    netG, netD = nets["netG"], nets["netD"]
    GANLoss, ReconLoss, L1ReconLoss, DLoss = loss_terms["GANLoss"], loss_terms["ReconLoss"], loss_terms["L1ReconLoss"], loss_terms["DLoss"]
    optG, optD = opts["optG"], opts["optD"]
    device0, device1 = devices[0], devices[1]  # device1 only used by the disabled mask-D branch
    netG.to(device0)
    netD.to(device0)
    # maskNetD.to(device1)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {"g_loss":AverageMeter(), "r_loss":AverageMeter(), "r_ex_loss":AverageMeter(), "whole_loss":AverageMeter(), 'd_loss':AverageMeter(),}
    # 'mask_d_loss':AverageMeter(), 'mask_rec_loss':AverageMeter(),'mask_whole_loss':AverageMeter()}
    netG.train()
    netD.train()
    # maskNetD.train()
    end = time.time()
    for i, data in enumerate(dataloader):
        data_time.update(time.time() - end)
        imgs, img_exs, masks = data
        masks = masks['random_free_form']
        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()
        imgs, img_exs, masks = imgs.to(device0), img_exs.to(device0), masks.to(device0)
        imgs = (imgs / 127.5 - 1)
        img_exs = (img_exs / 127.5 - 1)
        # mask is 1 on masked region
        coarse_imgs, recon_imgs, recon_ex_imgs = netG(imgs, img_exs, masks)
        # composite: generated content in the hole, original elsewhere
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)
        # discriminator input: image + mask + all-ones channel
        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat([complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        #mask_pos_neg_imgs = torch.cat([imgs, complete_imgs], dim=0)
        # Discriminator Loss
        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        # retain_graph: the generator graph is reused for the G update below
        d_loss.backward(retain_graph=True)
        optD.step()
        # Mask Discriminator Loss (disabled)
        # mask_pos_neg_imgs = mask_pos_neg_imgs.to(device1)
        # masks = masks.to(device1)
        # mask_pred_pos_neg = maskNetD(mask_pos_neg_imgs)
        # mask_pred_pos, mask_pred_neg = torch.chunk(mask_pred_pos_neg, 2, dim=0)
        # mask_d_loss = DLoss(mask_pred_pos*masks , mask_pred_neg*masks )
        # mask_rec_loss = L1ReconLoss(mask_pred_neg, masks, masks=masks)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # losses['mask_d_loss'].update(mask_d_loss.item(), imgs.size(0))
        # losses['mask_rec_loss'].update(mask_rec_loss.item(), imgs.size(0))
        # mask_whole_loss = mask_rec_loss
        # losses['mask_whole_loss'].update(mask_whole_loss.item(), imgs.size(0))
        # mask_whole_loss.backward(retain_graph=True)
        # optMD.step()
        # Optimize Generator
        # masks = masks.to(device0)
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad(),# optMD.zero_grad(), maskNetD.zero_grad()
        # fresh discriminator pass on the fake (post D-update) for the G loss
        pred_neg = netD(neg_imgs)
        #pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)
        r_ex_loss = L1ReconLoss(img_exs, recon_ex_imgs)
        whole_loss = g_loss + r_loss + r_ex_loss
        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['r_ex_loss'].update(r_ex_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()
        optG.step()
        # Update time recorder
        batch_time.update(time.time() - end)
        if (i+1) % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f}, \t Ex Recon Loss:{r_ex_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}, " \
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time \
                                ,whole_loss=losses['whole_loss'], r_loss=losses['r_loss'], r_ex_loss=losses['r_ex_loss'] \
                                ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # Tensorboard logger for scaler and images
            info_terms = {'WGLoss':whole_loss.item(), 'ReconLoss':r_loss.item(), "GANLoss":g_loss.item(), "DLoss":d_loss.item(), }
            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value, epoch*len(dataloader)+i)
            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_'+tag, value.avg, epoch*len(dataloader)+i)
            def img2photo(imgs):
                # NCHW -> NHWC, back to a 0..255 photo range
                return ((imgs+1)*127.5).transpose(1,2).transpose(2,3).detach().cpu().numpy()
            # strip: masked | coarse | recon | real | composite
            info = {
                'train/whole_imgs':img2photo(torch.cat([imgs * (1 - masks), coarse_imgs, recon_imgs, imgs, complete_imgs], dim=3))
            }
            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images, epoch*len(dataloader)+i)
        if (i+1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:
            validate(nets, loss_terms, opts, val_datas , epoch, devices, batch_n=i)
            # validate() calls .eval(); restore training mode
            netG.train()
            netD.train()
            #maskNetD.train()
        end = time.time()
def train(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD, dataloader, epoch, device=cuda0):
    """
    Train one epoch of the SN-PatchGAN inpainting model
    (Free-Form Image Inpainting with Gated Convolution).

    Args:
        netG: generator; called as ``netG(imgs, masks)`` and returns
            ``(coarse_imgs, recon_imgs)``.
        netD: spectral-norm patch discriminator; consumes 5-channel input
            (image + mask + ones).
        GANLoss, ReconLoss, DLoss: generator-adversarial, reconstruction and
            discriminator loss callables.
        optG, optD: optimizers for netG / netD.
        dataloader: yields ``(imgs, masks)`` where ``imgs`` are in [0, 255]
            and ``masks`` is a dict holding a 'random_free_form' mask
            (1 on the masked region).
        epoch: current epoch index, used only for logging steps.
        device: device to move each batch to.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }
    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']
        # ---- Optimize Discriminator ----
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()
        imgs, masks = imgs.to(device), masks.to(device)
        imgs = (imgs / 127.5 - 1)  # [0,255] -> [-1,1]; mask is 1 on masked region
        coarse_imgs, recon_imgs = netG(imgs, masks)
        # Paste generated content into the hole only; keep known pixels.
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)
        # D sees image + mask + constant-ones channel; real and fake are
        # batched together so one forward pass scores both.
        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # retain_graph: the generator graph is reused below for the G update.
        d_loss.backward(retain_graph=True)
        optD.step()
        # ---- Optimize Generator ----
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)
        whole_loss = g_loss + r_loss
        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()
        optG.step()
        # Update time recorder
        batch_time.update(time.time() - end)
        if i % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}" \
                        .format(epoch, i, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                        ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # Tensorboard logger for scaler and images
            info_terms = {
                'WGLoss': whole_loss.item(),
                'ReconLoss': r_loss.item(),
                "GANLoss": g_loss.item(),
                "DLoss": d_loss.item()
            }
            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value,
                                                 epoch * len(dataloader) + i)
            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_' + tag, value.avg,
                                                 epoch * len(dataloader) + i)

            def img2photo(imgs):
                # [-1,1] NCHW tensor -> [0,255] NHWC numpy for image logging.
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            info = {
                'train/ori_imgs': img2photo(imgs),
                'train/coarse_imgs': img2photo(coarse_imgs),
                'train/recon_imgs': img2photo(recon_imgs),
                'train/comp_imgs': img2photo(complete_imgs),
                'train/whole_imgs':
                img2photo(
                    torch.cat([imgs, coarse_imgs, recon_imgs, complete_imgs],
                              dim=3))
            }
            for tag, images in info.items():
                # BUG FIX: use the same global step as the scalar summaries
                # above (was plain `i`, which restarts every epoch and makes
                # later epochs overwrite earlier epochs' images).
                tensorboardlogger.image_summary(tag, images,
                                                epoch * len(dataloader) + i)
        end = time.time()
def validate(nets, loss_terms, opts, dataloader, epoch, devices=(cuda0, cuda1), batch_n="whole"):
    """
    Multi-scale validation pass for the progressive (netSR-assisted)
    inpainting pipeline.

    For the first ``config.STATIC_VIEW_SIZE`` batches only comparison image
    strips are collected in ``info`` (one per scale).  On the first batch at
    or past that threshold, the collected strips are written to disk as
    real/generated crops, FID and SSIM are computed on the largest scale,
    both are logged to tensorboard, and iteration stops.  Finally the
    generator/discriminator weights are checkpointed under ``log_dir``.

    Args:
        nets: dict with "netD", "netG" and the super-resolution net "netSR".
        loss_terms: dict of loss callables ('ReconLoss', 'DLoss', "PercLoss",
            "GANLoss", "StyleLoss"); style loss is currently disabled.
        opts: dict with 'optG'/'optD' (unpacked but not stepped here).
        dataloader: yields ``(ori_imgs, ori_masks)`` with ``ori_masks['val']``.
        epoch: current epoch, used for the save-dir name and logging steps.
        devices: ``(device0, device1)``; the perceptual loss runs on device1.
        batch_n: batch index (or the string "whole") used in the save-dir name.
    """
    netD, netG, netSR = nets["netD"], nets["netG"], nets["netSR"]
    ReconLoss, DLoss, PercLoss, GANLoss, StyleLoss = loss_terms[
        'ReconLoss'], loss_terms['DLoss'], loss_terms["PercLoss"], loss_terms[
            "GANLoss"], loss_terms["StyleLoss"]
    optG, optD = opts['optG'], opts['optD']
    device0, device1 = devices
    netG.to(device0)
    netD.to(device0)
    netSR.to(device0)
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "p_loss": AverageMeter(),
        "s_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }
    # NOTE(review): these two calls immediately undo the eval() calls above,
    # so validation actually runs with dropout/batch-norm in TRAIN mode.
    # Looks like a bug (the sibling simple validate() has the same pattern);
    # confirm intent before relying on validation metrics.
    netG.train()
    netD.train()
    end = time.time()
    val_save_dir = os.path.join(
        result_dir,
        "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_inf_dir = os.path.join(val_save_dir, "inf")
    if not os.path.exists(val_save_real_dir):
        os.makedirs(val_save_real_dir)
        os.makedirs(val_save_gen_dir)
        os.makedirs(val_save_inf_dir)
    info = {}
    for i, (ori_imgs, ori_masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        pre_imgs = ori_imgs
        # Initial "previous completion": original image with the hole filled
        # by the constant mask value, rescaled to [-1, 1].
        pre_complete_imgs = (pre_imgs / 127.5 - 1)
        pre_complete_imgs = pre_complete_imgs * (
            1 - ori_masks['val']) + ori_masks['val']
        pre_inter_imgs = F.interpolate(pre_complete_imgs, TRAIN_SIZES[0])
        # Progressive refinement: run the generator once per training scale,
        # feeding each scale the previous scale's completion.
        for s_j, size in enumerate(TRAIN_SIZES):
            masks = ori_masks['val']
            masks = F.interpolate(masks, size)
            masks = (masks > 0).type(torch.FloatTensor)  # re-binarize after resize
            imgs = F.interpolate(ori_imgs, size)
            if s_j == 0:
                # First scale: plain resize of the initial completion.
                pre_inter_imgs = F.interpolate(pre_complete_imgs, size)
            else:
                # Later scales: upscale the previous completion with netSR
                # (netSR works in [0, 255] space, hence the rescaling).
                pre_complete_imgs = (pre_complete_imgs + 1) * 127.5
                pre_inter_imgs = netSR(pre_complete_imgs, 2)
                pre_inter_imgs = (pre_inter_imgs / 127.5 - 1)
            #upsampled_imgs = pre_inter_imgs
            imgs, masks, pre_complete_imgs, pre_inter_imgs = imgs.to(
                device0), masks.to(device0), pre_complete_imgs.to(
                    device0), pre_inter_imgs.to(device0)
            #masks = (masks > 0).type(torch.FloatTensor)
            upsampled_imgs = pre_inter_imgs
            #imgs, masks = imgs.to(device), masks.to(device)
            imgs = (imgs / 127.5 - 1)
            # mask is 1 on masked region
            # forward
            recon_imgs = netG(imgs, masks, pre_complete_imgs, pre_inter_imgs,
                              size)
            complete_imgs = recon_imgs * masks + imgs * (1 - masks)
            pos_imgs = torch.cat(
                [imgs, masks, torch.full_like(masks, 1.)], dim=1)
            neg_imgs = torch.cat(
                [complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
            pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
            pred_pos_neg = netD(pos_neg_imgs)
            pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
            g_loss = GANLoss(pred_neg)
            # recon_imgs is passed for both the coarse and refined slots —
            # presumably because this netG returns a single output; confirm
            # against ReconLoss's signature.
            r_loss = ReconLoss(imgs, recon_imgs, recon_imgs, masks)
            # Perceptual loss is evaluated on the second device, then the
            # tensors are moved back.
            imgs, recon_imgs, complete_imgs = imgs.to(device1), recon_imgs.to(
                device1), complete_imgs.to(device1)
            p_loss = PercLoss(imgs, recon_imgs) + PercLoss(imgs, complete_imgs)
            #s_loss = StyleLoss(imgs, recon_imgs) + StyleLoss(imgs, complete_imgs)
            #p_loss, s_loss = p_loss.to(device0), s_loss.to(device0)
            p_loss = p_loss.to(device0)
            imgs, recon_imgs, complete_imgs = imgs.to(device0), recon_imgs.to(
                device0), complete_imgs.to(device0)
            whole_loss = r_loss + p_loss + g_loss  #+ s_loss#g_loss + r_loss
            # Update the recorder for losses
            losses['g_loss'].update(g_loss.item(), imgs.size(0))
            losses['r_loss'].update(r_loss.item(), imgs.size(0))
            losses['p_loss'].update(p_loss.item(), imgs.size(0))
            losses['s_loss'].update(0, imgs.size(0))  # style loss disabled
            losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
            d_loss = DLoss(pred_pos, pred_neg)
            losses['d_loss'].update(d_loss.item(), imgs.size(0))
            # Feed this scale's completion to the next scale.
            # NOTE(review): no detach() here — presumably fine in validation,
            # but it keeps the autograd graph alive across scales.
            pre_complete_imgs = complete_imgs
            # Update time recorder
            batch_time.update(time.time() - end)
            # Logger logging
            if i + 1 < config.STATIC_VIEW_SIZE:

                def img2photo(imgs):
                    # [-1,1] NCHW tensor -> [0,255] NHWC numpy.
                    return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                        2, 3).detach().cpu().numpy()

                # Strip layout: [masked | upsampled | recon | real | complete]
                info['val/{}whole_imgs/{}'.format(size, i)] = img2photo(
                    torch.cat([
                        imgs * (1 - masks), upsampled_imgs, recon_imgs, imgs,
                        complete_imgs
                    ],
                              dim=3))
            else:
                logger.info("Validation Epoch {0}, [{1}/{2}]: Size:{size}, Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f}\t,"
                            "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f},\t Perc Loss:{p_loss.val:.4f},\tStyle Loss:{s_loss.val:.4f}"
                            .format(epoch, i+1, len(dataloader),size=size, batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                                ,g_loss=losses['g_loss'], d_loss=losses['d_loss'], p_loss=losses['p_loss'], s_loss=losses['s_loss']))
                j = 0
                # Make one real/gen subdirectory per size tag.
                for size in SIZES_TAGS:
                    if not os.path.exists(os.path.join(val_save_real_dir,
                                                       size)):
                        os.makedirs(os.path.join(val_save_real_dir, size))
                        os.makedirs(os.path.join(val_save_gen_dir, size))
                for tag, images in info.items():
                    h, w = images.shape[1], images.shape[2] // 5
                    # Recover which training size this tag belongs to.
                    s_i = 0
                    for i_, s in enumerate(TRAIN_SIZES):
                        if "{}".format(s) in tag:
                            size_tag = "{}".format(s)
                            s_i = i_
                            break
                    for val_img in images:
                        # 4th panel is the real image, 5th the completion.
                        real_img = val_img[:, (3 * w):(4 * w), :]
                        gen_img = val_img[:, (4 * w):, :]
                        real_img = Image.fromarray(real_img.astype(np.uint8))
                        gen_img = Image.fromarray(gen_img.astype(np.uint8))
                        real_img.save(
                            os.path.join(val_save_real_dir, SIZES_TAGS[s_i],
                                         "{}_{}.png".format(size_tag, j)))
                        gen_img.save(
                            os.path.join(val_save_gen_dir, SIZES_TAGS[s_i],
                                         "{}_{}.png".format(size_tag, j)))
                        j += 1
                    tensorboardlogger.image_summary(tag, images, epoch)
                # FID/SSIM are computed only on the largest size.
                path1, path2 = os.path.join(
                    val_save_real_dir,
                    SIZES_TAGS[len(SIZES_TAGS) - 1]), os.path.join(
                        val_save_gen_dir, SIZES_TAGS[len(SIZES_TAGS) - 1])
                fid_score = metrics['fid']([path1, path2], cuda=False)
                ssim_score = metrics['ssim']([path1, path2])
                tensorboardlogger.scalar_summary('val/fid', fid_score.item(),
                                                 epoch * len(dataloader) + i)
                tensorboardlogger.scalar_summary('val/ssim',
                                                 ssim_score.item(),
                                                 epoch * len(dataloader) + i)
                # Stop refining this batch once metrics are written.
                break
        end = time.time()
    # Checkpoint both nets (moved to CPU first so the file is device-agnostic).
    saved_model = {
        'epoch': epoch + 1,
        'netG_state_dict': netG.to(cpu0).state_dict(),
        'netD_state_dict': netD.to(cpu0).state_dict(),
        # 'optG' : optG.state_dict(),
        # 'optD' : optD.state_dict()
    }
    torch.save(saved_model, '{}/latest_ckpt.pth.tar'.format(log_dir, epoch + 1))
def train(nets, loss_terms, opts, dataloader, epoch, devices=(cuda0, cuda1), val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan).

    Multi-scale variant: each batch is refined progressively over
    ``TRAIN_SIZES``, with the previous scale's completed image upscaled by
    ``netSR`` and fed to the generator at the next scale.

    Args:
        nets: dict with "netD", "netG" and the super-resolution net "netSR".
        loss_terms: dict of loss callables ('ReconLoss', 'DLoss', 'GANLoss',
            "PercLoss", "StyleLoss"); style loss is currently disabled.
        opts: dict with 'optG'/'optD' optimizers.
        dataloader: yields ``(ori_imgs, ori_masks)`` where ``ori_masks`` has
            'random_free_form' and 'random_bbox' masks.
        epoch: current epoch index (logging only).
        devices: ``(device0, device1)``; perceptual loss runs on device1.
        val_datas: optional validation dataloader, run every
            ``config.VAL_SUMMARY_FREQ`` batches.
    """
    netD, netG, netSR = nets["netD"], nets["netG"], nets["netSR"]
    ReconLoss, DLoss, GANLoss, PercLoss, StyleLoss = loss_terms[
        'ReconLoss'], loss_terms['DLoss'], loss_terms['GANLoss'], loss_terms[
            "PercLoss"], loss_terms["StyleLoss"]
    optG, optD = opts['optG'], opts['optD']
    device0, device1 = devices
    netG.to(device0)
    netD.to(device0)
    netSR.to(device0)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "s_loss": AverageMeter(),
        'p_loss': AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }
    netG.train()
    netD.train()
    end = time.time()
    for i, (ori_imgs, ori_masks) in enumerate(dataloader):
        # Randomly pick rectangular (30%) or free-form (70%) masks.
        ff_mask, rect_mask = ori_masks['random_free_form'], ori_masks[
            'random_bbox']
        if np.random.rand() < 0.3:
            ori_masks = rect_mask
        else:
            ori_masks = ff_mask
        # Optimize Discriminator
        # mask is 1 on masked region
        # Initial "previous completion": image rescaled to [-1,1] with the
        # hole filled by the mask value, resized to the smallest scale.
        pre_complete_imgs = ori_imgs
        pre_complete_imgs = (pre_complete_imgs / 127.5 - 1)
        pre_complete_imgs = pre_complete_imgs * (1 - ori_masks) + ori_masks
        pre_complete_imgs = F.interpolate(pre_complete_imgs, TRAIN_SIZES[0])
        # One D+G update per scale, coarse to fine.
        for s_j, size in enumerate(TRAIN_SIZES):
            data_time.update(time.time() - end)
            optD.zero_grad(), netD.zero_grad(), netG.zero_grad(
            ), optG.zero_grad()
            #Reshape
            masks = F.interpolate(ori_masks, size)
            masks = (masks > 0).type(torch.FloatTensor)  # re-binarize after resize
            imgs = F.interpolate(ori_imgs, size)
            if s_j == 0:
                pre_inter_imgs = F.interpolate(pre_complete_imgs, size)
            else:
                # netSR upscales in [0,255] space, hence the rescaling
                # round-trip.
                pre_complete_imgs = (pre_complete_imgs + 1) * 127.5
                pre_inter_imgs = netSR(pre_complete_imgs, 2)
                pre_inter_imgs = (pre_inter_imgs / 127.5 - 1)
            imgs, masks, pre_complete_imgs, pre_inter_imgs = imgs.to(
                device0), masks.to(device0), pre_complete_imgs.to(
                    device0), pre_inter_imgs.to(device0)
            imgs = (imgs / 127.5 - 1)
            upsampled_imgs = pre_inter_imgs
            recon_imgs = netG(imgs, masks, pre_complete_imgs, pre_inter_imgs,
                              size)
            #print(attention.size(), )
            # Paste generated content into the hole only.
            complete_imgs = recon_imgs * masks + imgs * (1 - masks)
            # D input: image + mask + ones channel; real/fake in one batch.
            pos_imgs = torch.cat(
                [imgs, masks, torch.full_like(masks, 1.)], dim=1)
            neg_imgs = torch.cat(
                [complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
            pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
            pred_pos_neg = netD(pos_neg_imgs)
            pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
            d_loss = DLoss(pred_pos, pred_neg)
            losses['d_loss'].update(d_loss.item(), imgs.size(0))
            #print(size)
            # Skip the D step whenever i % 3 == 0, i.e. D is updated on
            # two of every three batches.
            if i % 3:
                d_loss.backward(retain_graph=True)
                optD.step()
            # Optimize Generator
            optD.zero_grad(), netD.zero_grad(), optG.zero_grad(
            ), netG.zero_grad()
            pred_neg = netD(neg_imgs)
            #pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
            g_loss = GANLoss(pred_neg)
            # recon_imgs fills both the coarse and refined slots — presumably
            # because this netG returns a single output; confirm against
            # ReconLoss's signature.
            r_loss = ReconLoss(imgs, recon_imgs, recon_imgs, masks)
            # Perceptual loss runs on the second device, then moves back.
            imgs, recon_imgs, complete_imgs = imgs.to(device1), recon_imgs.to(
                device1), complete_imgs.to(device1)
            p_loss = PercLoss(imgs, recon_imgs) + PercLoss(imgs, complete_imgs)
            #s_loss = StyleLoss(imgs, recon_imgs) + StyleLoss(imgs, complete_imgs)
            #p_loss, s_loss = p_loss.to(device0), s_loss.to(device0)
            p_loss = p_loss.to(device0)
            imgs, recon_imgs, complete_imgs = imgs.to(device0), recon_imgs.to(
                device0), complete_imgs.to(device0)
            whole_loss = r_loss + p_loss + g_loss
            # Update the recorder for losses
            losses['g_loss'].update(g_loss.item(), imgs.size(0))
            losses['p_loss'].update(p_loss.item(), imgs.size(0))
            losses['s_loss'].update(0, imgs.size(0))  # style loss disabled
            losses['r_loss'].update(r_loss.item(), imgs.size(0))
            losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
            # retain_graph: complete_imgs is carried to the next scale via
            # pre_complete_imgs below, so its graph must survive this step.
            # NOTE(review): no detach() on the carry-over — gradients can
            # flow across scales; confirm this is intended.
            whole_loss.backward(retain_graph=True)
            optG.step()
            pre_complete_imgs = complete_imgs
            # Update time recorder
            batch_time.update(time.time() - end)
            if (i + 1) % config.SUMMARY_FREQ == 0:
                # Logger logging
                logger.info("Epoch {0}, [{1}/{2}]:Size:{size} Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                            "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}, \t Perc Loss:{p_loss.val:.4f}, \t Style Loss:{s_loss.val:.4f}" \
                            .format(epoch, i+1, len(dataloader), size=size, batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                                ,g_loss=losses['g_loss'], d_loss=losses['d_loss'], p_loss=losses['p_loss'], s_loss=losses['s_loss']))
                # Tensorboard logger for scaler and images
                info_terms = {
                    '{}WGLoss'.format(size): whole_loss.item(),
                    '{}ReconLoss'.format(size): r_loss.item(),
                    "{}GANLoss".format(size): g_loss.item(),
                    "{}DLoss".format(size): d_loss.item(),
                    "{}PercLoss".format(size): p_loss.item()
                }
                for tag, value in info_terms.items():
                    tensorboardlogger.scalar_summary(
                        tag, value, epoch * len(dataloader) + i)
                for tag, value in losses.items():
                    tensorboardlogger.scalar_summary(
                        'avg_' + tag, value.avg, epoch * len(dataloader) + i)

                def img2photo(imgs):
                    # [-1,1] NCHW tensor -> [0,255] NHWC numpy.
                    return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                        2, 3).detach().cpu().numpy()

                # info = { 'train/ori_imgs':img2photo(imgs),
                #          'train/coarse_imgs':img2photo(coarse_imgs),
                #          'train/recon_imgs':img2photo(recon_imgs),
                #          'train/comp_imgs':img2photo(complete_imgs),
                # Strip layout: [masked | upsampled | recon | real | complete]
                info = {
                    'train/{}whole_imgs'.format(size):
                    img2photo(
                        torch.cat([
                            imgs * (1 - masks), upsampled_imgs, recon_imgs,
                            imgs, complete_imgs
                        ],
                                  dim=3))
                }
                for tag, images in info.items():
                    tensorboardlogger.image_summary(
                        tag, images, epoch * len(dataloader) + i)
            end = time.time()
        if (i + 1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:
            validate(nets, loss_terms, opts, val_datas, epoch, devices,
                     batch_n=i)
            # validate() flips modes and moves nets; restore training state.
            netG.train()
            netD.train()
            netG.to(device0)
            netD.to(device0)
def train(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD, dataloader, epoch, device=cuda0, val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan).

    Edge-guided variant: a Canny edge map is computed per image on the CPU
    and passed to the generator as an extra guidance input.

    Args:
        netG: generator; ``netG(imgs, masks, guide)`` returns
            ``(coarse_imgs, recon_imgs, attention)``.
        netD: spectral-norm patch discriminator.
        GANLoss, ReconLoss, DLoss: loss callables.
        optG, optD: optimizers for netG / netD.
        dataloader: yields ``(imgs, masks)`` with imgs in [0, 255] and
            ``masks['random_free_form']`` (1 on the masked region).
        epoch: current epoch index (logging only).
        device: device for the batch and nets.
        val_datas: optional validation dataloader, run every
            ``config.VAL_SUMMARY_FREQ`` batches.
    """
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }
    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']
        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()
        # Build the per-image Canny edge guidance on CPU.
        guide = []
        transform = transforms.Compose([transforms.ToPILImage()])
        for k in range(imgs.shape[0]):
            im = transform(imgs[k])
            im = np.array(im)
            # cv2.imwrite('test.jpg', im)
            im = cv2.Canny(image=im, threshold1=20, threshold2=220)
            # cv2.imwrite('test1.jpg', im)
            # exit(1)
            guide.append(im)
        guide = torch.FloatTensor(guide)
        guide = guide[:, None, :, :]  # add channel dim -> (N, 1, H, W)
        imgs, masks, guide = imgs.to(device), masks.to(device), guide.to(
            device)
        imgs = (imgs / 127.5 - 1)  # [0,255] -> [-1,1]
        # mask is 1 on masked region
        guide = guide / 255.0  # Canny output is 0/255 -> 0/1
        coarse_imgs, recon_imgs, attention = netG(imgs, masks, guide)
        # print(attention.size(), )
        # Paste generated content into the hole only.
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)
        # D input: image + mask + ones channel; real/fake in one batch.
        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # retain_graph: generator graph is reused for the G update below.
        d_loss.backward(retain_graph=True)
        optD.step()
        # Optimize Generator
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        # pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)
        whole_loss = g_loss + r_loss
        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()
        optG.step()
        # Update time recorder
        batch_time.update(time.time() - end)
        if (i + 1) % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info(
                "Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}" \
                .format(epoch, i + 1, len(dataloader), batch_time=batch_time, data_time=data_time,
                        whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                        , g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # Tensorboard logger for scaler and images
            info_terms = {
                'WGLoss': whole_loss.item(),
                'ReconLoss': r_loss.item(),
                "GANLoss": g_loss.item(),
                "DLoss": d_loss.item()
            }
            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value,
                                                 epoch * len(dataloader) + i)
            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_' + tag, value.avg,
                                                 epoch * len(dataloader) + i)

            def img2photo(imgs):
                # [-1,1] NCHW tensor -> [0,255] NHWC numpy.
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            # info = { 'train/ori_imgs':img2photo(imgs),
            #          'train/coarse_imgs':img2photo(coarse_imgs),
            #          'train/recon_imgs':img2photo(recon_imgs),
            #          'train/comp_imgs':img2photo(complete_imgs),
            # Strip layout: [masked | coarse | recon | real | complete]
            info = {
                'train/whole_imgs':
                img2photo(
                    torch.cat([
                        imgs * (1 - masks), coarse_imgs, recon_imgs, imgs,
                        complete_imgs
                    ],
                              dim=3))
            }
            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images,
                                                epoch * len(dataloader) + i)
        if (i + 1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:
            validate(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD,
                     val_datas, epoch, device, batch_n=i)
            # validate() flips modes; restore training state.
            netG.train()
            netD.train()
        end = time.time()
def train(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD, dataloader, epoch, device=cuda0, val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan).

    Hybrid variant: each batch is also inpainted classically with
    ``cv2.inpaint`` (Telea), and the generator emits an extra weight channel
    used to blend its own output with the traditional result.

    Args:
        netG: generator; ``netG(imgs, masks)`` returns coarse output and a
            4-channel refined output (3 RGB + 1 blend-weight channel).
        netD: spectral-norm patch discriminator.
        GANLoss, ReconLoss, DLoss: loss callables.
        optG, optD: optimizers for netG / netD.
        dataloader: yields ``(imgs, masks)`` with imgs in [0, 255] and
            ``masks['random_free_form']`` (1 on the masked region).
        epoch: current epoch index (logging only).
        device: device for the batch and nets.
        val_datas: optional validation dataloader, run every
            ``config.VAL_SUMMARY_FREQ`` batches.
    """
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }
    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']
        #traditional inpainting
        # Classical Telea inpainting per image, rescaled to [-1, 1].
        for item in range(len(imgs)):
            img = np.array(transforms.ToPILImage()(imgs[item]))
            mask = np.array(transforms.ToPILImage()(masks[item]))
            res = cv2.inpaint(img, mask, 3, cv2.INPAINT_TELEA)
            res = transforms.ToTensor()(res)
            res = (res * 255) / 127.5 - 1
            if item:
                traditional_inpaint = torch.cat((traditional_inpaint, res))
            else:
                traditional_inpaint = res
        # NOTE(review): hard-coded (BATCH_SIZE, 3, 256, 256) will fail on a
        # final batch smaller than config.BATCH_SIZE (drop_last assumed?) and
        # pins the image size to 256x256 — confirm dataloader guarantees.
        traditional_inpaint = torch.reshape(traditional_inpaint,
                                            (config.BATCH_SIZE, 3, 256, 256))
        traditional_inpaint = traditional_inpaint.to(device)
        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()
        imgs, masks = imgs.to(device), masks.to(device)
        imgs = (imgs / 127.5 - 1)  # [0,255] -> [-1,1]
        # mask is 1 on masked region
        #print(type(masks))
        #print(type(guidence))
        #exit()
        coarse_imgs, recon_imgs_with_weight = netG(imgs, masks)
        # Split the 4-channel output: RGB + per-pixel blend weight in [0,1].
        recon_imgs = recon_imgs_with_weight[:, 0:3, :, :]
        weight_layer = (recon_imgs_with_weight[:, 3:, :, :] + 1.0) / 2
        # Blend learned reconstruction with the traditional inpainting.
        recon_imgs = weight_layer * recon_imgs + (
            1 - weight_layer) * traditional_inpaint
        #print(attention.size(), )
        # Paste generated content into the hole only.
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)
        # D input: image + mask + ones channel; real/fake in one batch.
        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # retain_graph: generator graph is reused for the G update below.
        d_loss.backward(retain_graph=True)
        optD.step()
        # Optimize Generator
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        #pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)
        whole_loss = g_loss + r_loss
        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()
        optG.step()
        # Update time recorder
        batch_time.update(time.time() - end)
        if (i + 1) % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}" \
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                            ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # Tensorboard logger for scaler and images
            info_terms = {
                'WGLoss': whole_loss.item(),
                'ReconLoss': r_loss.item(),
                "GANLoss": g_loss.item(),
                "DLoss": d_loss.item()
            }
            for tag, value in info_terms.items():
                tensorboardlogger.scalar_summary(tag, value,
                                                 epoch * len(dataloader) + i)
            for tag, value in losses.items():
                tensorboardlogger.scalar_summary('avg_' + tag, value.avg,
                                                 epoch * len(dataloader) + i)

            def img2photo(imgs):
                # [-1,1] NCHW tensor -> [0,255] NHWC numpy.
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            # info = { 'train/ori_imgs':img2photo(imgs),
            #          'train/coarse_imgs':img2photo(coarse_imgs),
            #          'train/recon_imgs':img2photo(recon_imgs),
            #          'train/comp_imgs':img2photo(complete_imgs),
            # Strip layout: [masked | coarse | recon | real | complete]
            info = {
                'train/whole_imgs':
                img2photo(
                    torch.cat([
                        imgs * (1 - masks), coarse_imgs, recon_imgs, imgs,
                        complete_imgs
                    ],
                              dim=3))
            }
            for tag, images in info.items():
                tensorboardlogger.image_summary(tag, images,
                                                epoch * len(dataloader) + i)
        if (i + 1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:
            validate(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD,
                     val_datas, epoch, device, batch_n=i)
            # validate() flips modes; restore training state.
            netG.train()
            netD.train()
        end = time.time()
def validate(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD, dataloader, epoch, device=cuda0, batch_n="whole"):
    """
    Validation pass for the SN-PatchGAN inpainting model.

    Collects comparison strips for the first ``config.STATIC_VIEW_SIZE``
    batches; on the first batch past that threshold it saves the real and
    generated crops to disk, computes FID and SSIM over them, logs both to
    tensorboard, and stops.

    Args:
        netG, netD: generator / discriminator (switched to eval mode here;
            callers restore train mode after this returns).
        GANLoss, ReconLoss, DLoss: loss callables.
        optG, optD: accepted for signature parity with train(); not stepped.
        dataloader: yields ``(imgs, masks)`` with ``masks['val']``.
        epoch: current epoch, used in the save-dir name and logging steps.
        device: device for the batch and nets.
        batch_n: batch index (or "whole") used in the save-dir name.
    """
    netG.to(device)
    netD.to(device)
    # BUG FIX: the original called netG.train()/netD.train() right after the
    # eval() calls below, so validation ran with dropout active and
    # batch-norm updating its running stats. Keep both nets in eval mode for
    # the whole pass; the train loops call .train() again after validate().
    netG.eval()
    netD.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        "d_loss": AverageMeter()
    }
    end = time.time()
    val_save_dir = os.path.join(
        result_dir,
        "val_{}_{}".format(
            epoch, batch_n if isinstance(batch_n, str) else batch_n + 1))
    val_save_real_dir = os.path.join(val_save_dir, "real")
    val_save_gen_dir = os.path.join(val_save_dir, "gen")
    val_save_inf_dir = os.path.join(val_save_dir, "inf")
    if not os.path.exists(val_save_real_dir):
        os.makedirs(val_save_real_dir)
        os.makedirs(val_save_gen_dir)
        os.makedirs(val_save_inf_dir)
    info = {}
    for i, (imgs, masks) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['val']
        imgs, masks = imgs.to(device), masks.to(device)
        imgs = (imgs / 127.5 - 1)  # [0,255] -> [-1,1]; mask is 1 on masked region
        # forward
        coarse_imgs, recon_imgs, attention = netG.forward(imgs, masks)
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)
        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)
        whole_loss = g_loss + r_loss
        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # Update time recorder
        batch_time.update(time.time() - end)
        # Logger logging
        if i + 1 < config.STATIC_VIEW_SIZE:

            def img2photo(imgs):
                # [-1,1] NCHW tensor -> [0,255] NHWC numpy.
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            # Strip layout: [masked | coarse | recon | real | complete]
            info['val/whole_imgs/{}'.format(i)] = img2photo(
                torch.cat([
                    imgs * (1 - masks), coarse_imgs, recon_imgs, imgs,
                    complete_imgs
                ],
                          dim=3))
        else:
            logger.info(
                "Validation Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f},\t Whole Gen Loss:{whole_loss.val:.4f}\t,"
                "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}"
                .format(epoch, i + 1, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                    , g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            j = 0
            for tag, images in info.items():
                h, w = images.shape[1], images.shape[2] // 5
                for val_img in images:
                    # 4th panel is the real image, 5th the completion.
                    real_img = val_img[:, (3 * w):(4 * w), :]
                    gen_img = val_img[:, (4 * w):, :]
                    real_img = Image.fromarray(real_img.astype(np.uint8))
                    gen_img = Image.fromarray(gen_img.astype(np.uint8))
                    real_img.save(
                        os.path.join(val_save_real_dir, "{}.png".format(j)))
                    gen_img.save(
                        os.path.join(val_save_gen_dir, "{}.png".format(j)))
                    j += 1
                tensorboardlogger.image_summary(tag, images, epoch)
            path1, path2 = val_save_real_dir, val_save_gen_dir
            fid_score = metrics['fid']([path1, path2], cuda=False)
            ssim_score = metrics['ssim']([path1, path2])
            tensorboardlogger.scalar_summary('val/fid', fid_score.item(),
                                             epoch * len(dataloader) + i)
            tensorboardlogger.scalar_summary('val/ssim', ssim_score.item(),
                                             epoch * len(dataloader) + i)
            # Metrics written once — stop validating.
            break
        end = time.time()
def train(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD, dataloader, epoch, device=cuda0, val_datas=None):
    """
    Train Phase, for training and spectral normalization patch gan in
    Free-Form Image Inpainting with Gated Convolution (snpgan).

    W&B variant: the dataloader also yields per-image (mean, std) and images
    arrive already normalised (the /127.5 rescale is commented out);
    epoch-average losses are pushed to Weights & Biases after the loop.

    Args:
        netG: generator; ``netG(imgs, masks)`` returns
            ``(coarse_imgs, recon_imgs)``.
        netD: spectral-norm patch discriminator.
        GANLoss, ReconLoss, DLoss: loss callables.
        optG, optD: optimizers for netG / netD.
        dataloader: yields ``(imgs, masks, mean, std)``; mean/std are unused
            here.
        epoch: current epoch index (logging only).
        device: device for the batch and nets.
        val_datas: optional validation dataloader, run every
            ``config.VAL_SUMMARY_FREQ`` batches.
    """
    # wandb.watch(netG, netD)
    netG.to(device)
    netD.to(device)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = {
        "g_loss": AverageMeter(),
        "r_loss": AverageMeter(),
        "whole_loss": AverageMeter(),
        'd_loss': AverageMeter()
    }
    netG.train()
    netD.train()
    end = time.time()
    for i, (imgs, masks, mean, std) in enumerate(dataloader):
        data_time.update(time.time() - end)
        masks = masks['random_free_form']
        # Optimize Discriminator
        optD.zero_grad(), netD.zero_grad(), netG.zero_grad(), optG.zero_grad()
        imgs, masks = imgs.to(device), masks.to(device)
        # Images are assumed pre-normalised by the dataset in this variant.
        # imgs = (imgs / 127.5 - 1)
        # mask is 1 on masked region
        coarse_imgs, recon_imgs = netG(imgs, masks)
        #print(attention.size(), )
        # Paste generated content into the hole only.
        complete_imgs = recon_imgs * masks + imgs * (1 - masks)
        # D input: image + mask + ones channel; real/fake in one batch.
        pos_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        neg_imgs = torch.cat(
            [complete_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        pos_neg_imgs = torch.cat([pos_imgs, neg_imgs], dim=0)
        pred_pos_neg = netD(pos_neg_imgs)
        pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        d_loss = DLoss(pred_pos, pred_neg)
        losses['d_loss'].update(d_loss.item(), imgs.size(0))
        # retain_graph: generator graph is reused for the G update below.
        d_loss.backward(retain_graph=True)
        optD.step()
        # Optimize Generator
        optD.zero_grad(), netD.zero_grad(), optG.zero_grad(), netG.zero_grad()
        pred_neg = netD(neg_imgs)
        #pred_pos, pred_neg = torch.chunk(pred_pos_neg, 2, dim=0)
        g_loss = GANLoss(pred_neg)
        r_loss = ReconLoss(imgs, coarse_imgs, recon_imgs, masks)
        whole_loss = g_loss + r_loss
        # Update the recorder for losses
        losses['g_loss'].update(g_loss.item(), imgs.size(0))
        losses['r_loss'].update(r_loss.item(), imgs.size(0))
        losses['whole_loss'].update(whole_loss.item(), imgs.size(0))
        whole_loss.backward()
        optG.step()
        # Update time recorder
        batch_time.update(time.time() - end)
        if (i + 1) % config.SUMMARY_FREQ == 0:
            # Logger logging
            logger.info("Epoch {0}, [{1}/{2}]: Batch Time:{batch_time.val:.4f},\t Data Time:{data_time.val:.4f}, Whole Gen Loss:{whole_loss.val:.4f}\t,"
                        "Recon Loss:{r_loss.val:.4f},\t GAN Loss:{g_loss.val:.4f},\t D Loss:{d_loss.val:.4f}" \
                        .format(epoch, i+1, len(dataloader), batch_time=batch_time, data_time=data_time, whole_loss=losses['whole_loss'], r_loss=losses['r_loss'] \
                            ,g_loss=losses['g_loss'], d_loss=losses['d_loss']))
            # Tensorboard logger for scaler and images
            # NOTE(review): dead code — the consuming scalar_summary loops
            # below are commented out, and the first two values are
            # AverageMeter objects rather than floats; clean up if revived.
            info_terms = {
                'ReconLoss': losses['r_loss'],
                "GANLoss": losses['g_loss'],
                "DLoss": d_loss.item()
            }
            # vis.line([[r_loss.item(), (1 / config.GAN_LOSS_ALPHA) * g_loss.item(), d_loss.item()]],
            #          [epoch*len(dataloader)+i], win='train_loss', update='append')
            # for tag, value in info_terms.items():
            #     tensorboardlogger.scalar_summary(tag, value, epoch*len(dataloader)+i)
            #
            # for tag, value in losses.items():
            #     tensorboardlogger.scalar_summary('avg_'+tag, value.avg, epoch*len(dataloader)+i)

            def img2photo(imgs):
                # return (((imgs*0.263)+0.472)*255).transpose(1,2).transpose(2,3).detach().cpu().numpy()
                # [-1,1] NCHW tensor -> [0,255] NHWC numpy.
                return ((imgs + 1) * 127.5).transpose(1, 2).transpose(
                    2, 3).detach().cpu().numpy()

            # info = { 'train/ori_imgs':img2photo(imgs),
            #          'train/coarse_imgs':img2photo(coarse_imgs),
            #          'train/recon_imgs':img2photo(recon_imgs),
            #          'train/comp_imgs':img2photo(complete_imgs),
            # Strip layout: [masked(+white hole) | coarse | recon | real | complete]
            info = {
                'train/whole_imgs':
                img2photo(
                    torch.cat([
                        imgs * (1 - masks) + masks, coarse_imgs, recon_imgs,
                        imgs, complete_imgs
                    ],
                              dim=3))
            }
            # for tag, images in info.items():
            #     tensorboardlogger.image_summary(tag, images, epoch*len(dataloader)+i)
        if (i + 1) % config.VAL_SUMMARY_FREQ == 0 and val_datas is not None:
            validate(netG, netD, GANLoss, ReconLoss, DLoss, optG, optD,
                     val_datas, epoch, device, batch_n=i)
            # validate() flips modes; restore training state.
            netG.train()
            netD.train()
        end = time.time()
    # vis.line([[losses['r_loss'].out(), (1 / config.GAN_LOSS_ALPHA) * losses['g_loss'].out(), losses['d_loss'].out()]],
    #          [epoch], win='train_loss', update='append')
    # Epoch-level W&B logging.
    # NOTE(review): relies on AverageMeter exposing an .out() accessor —
    # confirm against the project's AverageMeter implementation.
    wandb.log({
        "train_r_loss": losses['r_loss'].out(),
        "train_g_loss": (1 / config.GAN_LOSS_ALPHA) * losses['g_loss'].out(),
        "train_d_loss": losses['d_loss'].out()
    })