import os
import random
import itertools
from collections import deque

import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
# F.normalize(tensor, mean, std, inplace) below matches torchvision's functional API.
from torchvision.transforms import functional as F
from tqdm import tqdm
# The project may use tensorboardX instead; either provides SummaryWriter.
from torch.utils.tensorboard import SummaryWriter

from model.resfcn256 import ResFCN256
from tools.WLP300dataset import PRNetDataset, ToTensor, ToNormalize, ToResize
from tools.prnet_loss import WeightMaskLoss, INFO
from config.config import FLAGS
from utils.utils import save_image, test_data_preprocess, make_all_grids, make_grid
from utils.losses import SSIM

# The GAN / CycleGAN variants further below additionally use Discriminator, Discriminator1,
# DataLoaderX and cal_aflw2000_nme from project modules that are not shown in this file.

# Set random seed for reproducibility
manualSeed = 5
INFO("Random Seed", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
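# Optional (not part of the original script): seeding Python and torch alone does not pin
# down cuDNN's kernel selection, so fully deterministic GPU runs would also need:
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False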
def main(data_dir):
    origin_img, uv_map_gt, uv_map_predicted = None, None, None

    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToResize((256, 256)),
                              ToNormalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
                          ]))
    wlp300_dataloader = DataLoader(dataset=wlp300, batch_size=FLAGS['batch_size'],
                                   shuffle=True, num_workers=1)

    # 2) Intermediate Processing.
    transform_img = transforms.Compose([
        # transforms.ToTensor(),
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    model = ResFCN256(resolution_input=256, resolution_output=256, channel=3, size=16)

    # Load the pre-trained weight
    if FLAGS['resume'] != "" and os.path.exists(os.path.join(FLAGS['pretrained'], FLAGS['resume'])):
        state = torch.load(os.path.join(FLAGS['pretrained'], FLAGS['resume']))
        model.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO("Pre-trained weight cannot load successfully, train from scratch!")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to(FLAGS["device"])

    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS["lr"], betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])
    bce_loss = torch.nn.BCEWithLogitsLoss()
    bce_loss.to(FLAGS["device"])  # Loss function for adversarial

    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        loss_list_G, stat_list = [], []
        for i, sample in enumerate(bar):
            uv_map, origin = sample['uv_map'].to(FLAGS['device']), sample['origin'].to(FLAGS['device'])

            # Inference.
            optimizer.zero_grad()
            uv_map_result = model(origin)
            loss_g = bce_loss(uv_map_result, uv_map)
            loss_g.backward()
            loss_list_G.append(loss_g.item())
            optimizer.step()

        if ep % FLAGS["save_interval"] == 0:
            with torch.no_grad():
                print(" {} [BCE ({})]".format(ep, loss_list_G[-1]))
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(origin), test_data_preprocess(gt_uv_map)
                origin, gt_uv_map = transform_img(origin), transform_img(gt_uv_map)
                origin_in = origin.unsqueeze_(0).cuda()
                pred_uv_map = model(origin_in).detach().cpu()
                save_image([origin.cpu(), gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                           os.path.join(FLAGS['images'], str(ep) + '.png'),
                           nrow=1, normalize=True)

                # Save model
                print("Save model")
                state = {
                    'prnet': model.state_dict(),
                    'Loss': loss_list_G,
                    'start_epoch': ep
                }
                torch.save(state, os.path.join(FLAGS['images'], '{}.pth'.format(ep)))

        scheduler.step()
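# The trainer above writes checkpoints with the keys 'prnet', 'Loss' and 'start_epoch'.
# A minimal reload sketch for inference (the helper name and default device are
# illustrative only, not part of the project API):
def load_trained_prnet(checkpoint_path, device="cuda"):
    net = ResFCN256(resolution_input=256, resolution_output=256, channel=3, size=16)
    state = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state['prnet'])
    net.to(device).eval()
    return net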
def main(data_dir):
    # 0) Tensorboard Writer.
    writer = SummaryWriter(FLAGS['summary_path'])
    origin_img, uv_map_gt, uv_map_predicted = None, None, None

    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToNormalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
                          ]))
    wlp300_dataloader = DataLoader(dataset=wlp300, batch_size=FLAGS['batch_size'],
                                   shuffle=True, num_workers=4)

    # 2) Intermediate Processing.
    transform_img = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    model = ResFCN256()

    # Load the pre-trained weight
    if FLAGS['resume'] and os.path.exists(os.path.join(FLAGS['images'], "latest.pth")):
        state = torch.load(os.path.join(FLAGS['images'], "latest.pth"))
        model.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO("Pre-trained weight cannot load successfully, train from scratch!")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to("cuda")

    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS["lr"], betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])

    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        Loss_list, Stat_list = [], []
        for i, sample in enumerate(bar):
            uv_map, origin = sample['uv_map'].to(FLAGS['device']), sample['origin'].to(FLAGS['device'])

            # Inference.
            uv_map_result = model(origin)

            # Loss & ssim stat.
            logit = loss(uv_map_result, uv_map)
            stat_logit = stat_loss(uv_map_result, uv_map)

            # Record Loss.
            Loss_list.append(logit.item())
            Stat_list.append(stat_logit.item())

            # Update.
            optimizer.zero_grad()
            logit.backward()
            optimizer.step()
            bar.set_description(" {} [Loss(Paper)] {} [SSIM({})] {}".format(
                ep, Loss_list[-1], FLAGS["gauss_kernel"], Stat_list[-1]))

            # Record Training information in Tensorboard.
            if origin_img is None and uv_map_gt is None:
                origin_img, uv_map_gt = origin, uv_map
                uv_map_predicted = uv_map_result

            writer.add_scalar("Original Loss", Loss_list[-1], FLAGS["summary_step"])
            writer.add_scalar("SSIM Loss", Stat_list[-1], FLAGS["summary_step"])

            grid_1, grid_2, grid_3 = make_grid(origin_img, normalize=True), make_grid(uv_map_gt), make_grid(uv_map_predicted)
            writer.add_image('original', grid_1, FLAGS["summary_step"])
            writer.add_image('gt_uv_map', grid_2, FLAGS["summary_step"])
            writer.add_image('predicted_uv_map', grid_3, FLAGS["summary_step"])
            writer.add_graph(model, uv_map)

        if ep % FLAGS["save_interval"] == 0:
            with torch.no_grad():
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(origin), test_data_preprocess(gt_uv_map)
                # origin, gt_uv_map = transform_img(origin), transform_img(gt_uv_map)
                origin_in = origin.unsqueeze_(0).cuda()
                pred_uv_map = model(origin_in).detach().cpu()
                save_image([origin.cpu(), gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                           os.path.join(FLAGS['images'], str(ep) + '.png'),
                           nrow=1, normalize=True)

                # Save model
                state = {
                    'prnet': model.state_dict(),
                    'Loss': Loss_list,
                    'start_epoch': ep,
                }
                torch.save(state, os.path.join(FLAGS['images'], 'latest.pth'))

        scheduler.step()

    writer.close()
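# WeightMaskLoss above weights each UV-map pixel by a fixed mask loaded from
# FLAGS["mask_path"] (e.g. higher weight around the landmarks and central face region).
# tools/prnet_loss.py is not reproduced in this file, so the class below is only an
# assumed minimal form of such a mask-weighted MSE, not the repository's exact code.
class MaskedMSESketch(torch.nn.Module):
    def __init__(self, mask_path):
        super().__init__()
        # mask: H x W grayscale weight image in [0, 1], broadcast over the 3 UV channels.
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.0
        self.register_buffer("mask", torch.from_numpy(mask))

    def forward(self, pred, gt):
        # Per-pixel squared error, averaged over channels, weighted by the mask, then averaged.
        return (((pred - gt) ** 2).mean(dim=1) * self.mask).mean()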
def main(data_dir):
    # 0) Tensorboard Writer.
    writer = SummaryWriter(FLAGS['summary_path'])
    origin_img, uv_map_gt, uv_map_predicted = None, None, None

    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToResize((416, 416)),
                              ToNormalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
                          ]))
    wlp300_dataloader = DataLoader(dataset=wlp300, batch_size=FLAGS['batch_size'],
                                   shuffle=True, num_workers=1)

    # 2) Intermediate Processing.
    transform_img = transforms.Compose([
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    model = ResFCN256(resolution_input=416, resolution_output=416, channel=3, size=16)
    discriminator = Discriminator()

    # Load the pre-trained weight
    if FLAGS['resume'] != "" and os.path.exists(os.path.join(FLAGS['pretrained'], FLAGS['resume'])):
        state = torch.load(os.path.join(FLAGS['pretrained'], FLAGS['resume']))
        model.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO("Pre-trained weight cannot load successfully, train from scratch!")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to(FLAGS["device"])
    discriminator.to(FLAGS["device"])

    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS["lr"], betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=FLAGS["lr"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])
    bce_loss = torch.nn.BCEWithLogitsLoss()
    bce_loss.to(FLAGS["device"])  # Loss function for adversarial

    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        loss_list_G, stat_list = [], []
        loss_list_D = []
        for i, sample in enumerate(bar):
            uv_map, origin = sample['uv_map'].to(FLAGS['device']), sample['origin'].to(FLAGS['device'])

            # Inference.
            optimizer.zero_grad()
            uv_map_result = model(origin)

            # Update D
            optimizer_D.zero_grad()
            fake_detach = uv_map_result.detach()
            d_fake = discriminator(fake_detach)
            d_real = discriminator(uv_map)
            retain_graph = False
            if FLAGS['gan_type'] == 'GAN':
                loss_d = bce_loss(d_real, d_fake)
            elif FLAGS['gan_type'].find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if FLAGS['gan_type'].find('GP') >= 0:
                    epsilon = torch.rand(fake_detach.shape[0]).view(-1, 1, 1, 1)
                    epsilon = epsilon.to(fake_detach.device)
                    hat = fake_detach.mul(1 - epsilon) + uv_map.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = discriminator(hat)
                    gradients = torch.autograd.grad(outputs=d_hat.sum(), inputs=hat,
                                                    retain_graph=True, create_graph=True,
                                                    only_inputs=True)[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty
            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
            elif FLAGS['gan_type'] == 'RGAN':
                better_real = d_real - d_fake.mean(dim=0, keepdim=True)
                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
                loss_d = bce_loss(better_real, better_fake)
                retain_graph = True

            if discriminator.training:
                loss_list_D.append(loss_d.item())
                loss_d.backward(retain_graph=retain_graph)
                optimizer_D.step()

                if 'WGAN' in FLAGS['gan_type']:
                    for p in discriminator.parameters():
                        p.data.clamp_(-1, 1)

            # Update G
            d_fake_bp = discriminator(uv_map_result)  # for backpropagation, use fake as it is
            if FLAGS['gan_type'] == 'GAN':
                label_real = torch.ones_like(d_fake_bp)
                loss_g = bce_loss(d_fake_bp, label_real)
            elif FLAGS['gan_type'].find('WGAN') >= 0:
                loss_g = -d_fake_bp.mean()
            elif FLAGS['gan_type'] == 'RGAN':
                better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
                better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
                loss_g = bce_loss(better_fake, better_real)

            loss_g.backward()
            loss_list_G.append(loss_g.item())
            optimizer.step()

            stat_logit = stat_loss(uv_map_result, uv_map)
            stat_list.append(stat_logit.item())
            # bar.set_description(" {} [Loss(Paper)] {} [Loss(D)] {} [SSIM({})] {}".format(
            #     ep, loss_list_G[-1], loss_list_D[-1], FLAGS["gauss_kernel"], stat_list[-1]))

            # Record Training information in Tensorboard.
            """
            if origin_img is None and uv_map_gt is None:
                origin_img, uv_map_gt = origin, uv_map
                uv_map_predicted = uv_map_result

            writer.add_scalar("Original Loss", loss_list_G[-1], FLAGS["summary_step"])
            writer.add_scalar("D Loss", loss_list_D[-1], FLAGS["summary_step"])
            writer.add_scalar("SSIM Loss", stat_list[-1], FLAGS["summary_step"])

            grid_1, grid_2, grid_3 = make_grid(origin_img, normalize=True), make_grid(uv_map_gt), make_grid(uv_map_predicted)
            writer.add_image('original', grid_1, FLAGS["summary_step"])
            writer.add_image('gt_uv_map', grid_2, FLAGS["summary_step"])
            writer.add_image('predicted_uv_map', grid_3, FLAGS["summary_step"])
            writer.add_graph(model, uv_map)
            """

        if ep % FLAGS["save_interval"] == 0:
            with torch.no_grad():
                print(" {} [Loss(Paper)] {} [Loss(D)] {} [SSIM({})] {}".format(
                    ep, loss_list_G[-1], loss_list_D[-1], FLAGS["gauss_kernel"], stat_list[-1]))
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(origin), test_data_preprocess(gt_uv_map)
                origin, gt_uv_map = transform_img(origin), transform_img(gt_uv_map)
                origin_in = origin.unsqueeze_(0).cuda()
                pred_uv_map = model(origin_in).detach().cpu()
                save_image([origin.cpu(), gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                           os.path.join(FLAGS['images'], str(ep) + '.png'),
                           nrow=1, normalize=True)

                # Save model
                print("Save model")
                state = {
                    'prnet': model.state_dict(),
                    'Loss': loss_list_G,
                    'start_epoch': ep,
                    'Loss_D': loss_list_D,
                }
                torch.save(state, os.path.join(FLAGS['images'], '{}.pth'.format(ep)))

        scheduler.step()

    writer.close()
def main(data_dir):
    # 0) Tensorboard Writer.
    writer = SummaryWriter(FLAGS['summary_path'])
    origin_img, uv_map_gt, uv_map_predicted = None, None, None

    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToNormalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
                          ]))
    wlp300_dataloader = DataLoader(dataset=wlp300, batch_size=FLAGS['batch_size'],
                                   shuffle=True, num_workers=4)

    # 2) Intermediate Processing.
    transform_img = transforms.Compose([
        # transforms.ToTensor(),
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    model = ResFCN256()
    discriminator = Discriminator1()

    # Load the pre-trained weight
    if FLAGS['resume'] and os.path.exists(os.path.join(FLAGS['images'], "latest.pth")):
        state = torch.load(os.path.join(FLAGS['images'], "latest.pth"))
        model.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO("Pre-trained weight cannot load successfully, train from scratch!")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to("cuda")
    discriminator.to("cuda")

    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS["lr"], betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=FLAGS["lr"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])
    bce_loss = torch.nn.BCEWithLogitsLoss()
    bce_loss.to("cuda")  # Loss function for adversarial

    Tensor = torch.cuda.FloatTensor

    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        Loss_list, Stat_list = [], []
        Loss_list_D = []
        for i, sample in enumerate(bar):
            # Drop the last batch: uv_map batches are 16 x 3 x 256 x 256,
            # but the final batch only contains 10 samples.
            if i == 111:
                break
            uv_map, origin = sample['uv_map'].to(FLAGS['device']), sample['origin'].to(FLAGS['device'])

            # generate fake label
            fake_label = Variable(torch.zeros([16, 512])).cuda()
            # generate real label
            real_label = Variable(torch.ones([16, 512])).cuda()

            # Sample noise as generator input
            # z = Variable(Tensor(np.random.normal(0, 1, ([16, 3, 256, 256]))))

            # Inference.
            uv_map_result = model(origin)

            # Loss & ssim stat.
            # Loss measures generator's ability to fool the discriminator
            logit = bce_loss(discriminator(uv_map_result), real_label)
            stat_logit = stat_loss(uv_map_result, uv_map)

            # Measure discriminator's ability to classify real from generated samples
            real_loss = bce_loss(discriminator(uv_map), real_label)
            fake_loss = bce_loss(discriminator(uv_map_result.detach()), fake_label)
            d_loss = (real_loss + fake_loss) / 2

            # Record Loss.
            Loss_list.append(logit.item())
            Loss_list_D.append(d_loss.item())
            Stat_list.append(stat_logit.item())

            # Update.
            optimizer.zero_grad()
            logit.backward()
            optimizer.step()
            bar.set_description(" {} [Loss(Paper)] {} [Loss(D)] {} [SSIM({})] {}".format(
                ep, Loss_list[-1], Loss_list_D[-1], FLAGS["gauss_kernel"], Stat_list[-1]))

            optimizer_D.zero_grad()
            d_loss.backward(retain_graph=True)
            optimizer_D.step()

            # Record Training information in Tensorboard.
            if origin_img is None and uv_map_gt is None:
                origin_img, uv_map_gt = origin, uv_map
                uv_map_predicted = uv_map_result

            writer.add_scalar("Original Loss", Loss_list[-1], FLAGS["summary_step"])
            writer.add_scalar("D Loss", Loss_list_D[-1], FLAGS["summary_step"])
            writer.add_scalar("SSIM Loss", Stat_list[-1], FLAGS["summary_step"])

            grid_1, grid_2, grid_3 = make_grid(origin_img, normalize=True), make_grid(uv_map_gt), make_grid(uv_map_predicted)
            writer.add_image('original', grid_1, FLAGS["summary_step"])
            writer.add_image('gt_uv_map', grid_2, FLAGS["summary_step"])
            writer.add_image('predicted_uv_map', grid_3, FLAGS["summary_step"])
            writer.add_graph(model, uv_map)

        if ep % FLAGS["save_interval"] == 0:
            with torch.no_grad():
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(origin), test_data_preprocess(gt_uv_map)
                origin, gt_uv_map = transform_img(origin), transform_img(gt_uv_map)
                origin_in = origin.unsqueeze_(0).cuda()
                pred_uv_map = model(origin_in).detach().cpu()
                save_image([origin.cpu(), gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                           os.path.join(FLAGS['images'], str(ep) + '.png'),
                           nrow=1, normalize=True)

                # Save model
                print("Save model")
                state = {
                    'prnet': model.state_dict(),
                    'Loss': Loss_list,
                    'start_epoch': ep,
                    'Loss_D': Loss_list_D,
                }
                torch.save(state, os.path.join(FLAGS['images'], 'latest.pth'))

        scheduler.step()

    writer.close()
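# The labels above are hard-coded to shape [16, 512] (batch size x discriminator output
# size), which is why the incomplete final batch has to be skipped with "if i == 111: break".
# A shape-agnostic alternative (a sketch, assuming Discriminator1 returns one logit tensor
# per sample) derives the labels from the discriminator output instead:
#
#     d_out = discriminator(uv_map_result)
#     real_label = torch.ones_like(d_out)
#     fake_label = torch.zeros_like(d_out)
#
# With that change the last, smaller batch no longer needs to be dropped.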
def main(data_dir):
    origin_img, uv_map_gt, uv_map_predicted = None, None, None

    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToResize((416, 416)),
                              ToNormalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
                          ]))
    wlp300_dataloader = DataLoader(dataset=wlp300, batch_size=FLAGS['batch_size'],
                                   shuffle=True, num_workers=1)

    # 2) Intermediate Processing.
    transform_img = transforms.Compose([
        # transforms.ToTensor(),
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    g_x = ResFCN256(resolution_input=416, resolution_output=416, channel=3, size=16)
    g_y = ResFCN256(resolution_input=416, resolution_output=416, channel=3, size=16)
    d_x = Discriminator()
    d_y = Discriminator()

    # Load the pre-trained weight
    if FLAGS['resume'] != "" and os.path.exists(os.path.join(FLAGS['pretrained'], FLAGS['resume'])):
        state = torch.load(os.path.join(FLAGS['pretrained'], FLAGS['resume']))
        try:
            g_x.load_state_dict(state['g_x'])
            g_y.load_state_dict(state['g_y'])
            d_x.load_state_dict(state['d_x'])
            d_y.load_state_dict(state['d_y'])
        except Exception:
            g_x.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO("Pre-trained weight cannot load successfully, train from scratch!")

    if torch.cuda.device_count() > 1:
        g_x = torch.nn.DataParallel(g_x)
        g_y = torch.nn.DataParallel(g_y)
    g_x.to(FLAGS["device"])
    g_y.to(FLAGS["device"])
    d_x.to(FLAGS["device"])
    d_y.to(FLAGS["device"])

    optimizer_g = torch.optim.Adam(itertools.chain(g_x.parameters(), g_y.parameters()),
                                   lr=FLAGS["lr"], betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(itertools.chain(d_x.parameters(), d_y.parameters()),
                                   lr=FLAGS["lr"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.99)
    stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])
    bce_loss = torch.nn.BCEWithLogitsLoss()
    bce_loss.to(FLAGS["device"])  # Loss function for adversarial
    l1_loss = nn.L1Loss().to(FLAGS["device"])
    lambda_X = 10
    lambda_Y = 10

    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        loss_list_cycle_x = []
        loss_list_cycle_y = []
        loss_list_d_x = []
        loss_list_d_y = []
        real_label = torch.ones(FLAGS['batch_size']).to(FLAGS['device'])
        fake_label = torch.zeros(FLAGS['batch_size']).to(FLAGS['device'])
        for i, sample in enumerate(bar):
            real_y, real_x = sample['uv_map'].to(FLAGS['device']), sample['origin'].to(FLAGS['device'])

            # x -> y' -> x^
            optimizer_g.zero_grad()
            fake_y = g_x(real_x)
            prediction = d_x(fake_y)
            loss_g_x = bce_loss(prediction, real_label)
            x_hat = g_y(fake_y)
            loss_cycle_x = l1_loss(x_hat, real_x) * lambda_X
            loss_x = loss_g_x + loss_cycle_x
            loss_x.backward(retain_graph=True)
            optimizer_g.step()
            loss_list_cycle_x.append(loss_x.item())

            # y -> x' -> y^
            optimizer_g.zero_grad()
            fake_x = g_y(real_y)
            prediction = d_y(fake_x)
            loss_g_y = bce_loss(prediction, real_label)
            y_hat = g_x(fake_x)
            loss_cycle_y = l1_loss(y_hat, real_y) * lambda_Y
            loss_y = loss_g_y + loss_cycle_y
            loss_y.backward(retain_graph=True)
            optimizer_g.step()
            loss_list_cycle_y.append(loss_y.item())

            # d_x
            optimizer_d.zero_grad()
            pred_real = d_x(real_y)
            loss_d_x_real = bce_loss(pred_real, real_label)
            pred_fake = d_x(fake_y)
            loss_d_x_fake = bce_loss(pred_fake, fake_label)
            loss_d_x = (loss_d_x_real + loss_d_x_fake) * 0.5
            loss_d_x.backward()
            loss_list_d_x.append(loss_d_x.item())
            optimizer_d.step()
            if 'WGAN' in FLAGS['gan_type']:
                for p in d_x.parameters():
                    p.data.clamp_(-1, 1)

            # d_y
            optimizer_d.zero_grad()
            pred_real = d_y(real_x)
            loss_d_y_real = bce_loss(pred_real, real_label)
            pred_fake = d_y(fake_x)
            loss_d_y_fake = bce_loss(pred_fake, fake_label)
            loss_d_y = (loss_d_y_real + loss_d_y_fake) * 0.5
            loss_d_y.backward()
            loss_list_d_y.append(loss_d_y.item())
            optimizer_d.step()
            if 'WGAN' in FLAGS['gan_type']:
                for p in d_y.parameters():
                    p.data.clamp_(-1, 1)

        if ep % FLAGS["save_interval"] == 0:
            with torch.no_grad():
                print(" {} [Loss_G_X] {} [Loss_G_Y] {} [Loss_D_X] {} [Loss_D_Y] {}".format(
                    ep, loss_list_cycle_x[-1], loss_list_cycle_y[-1],
                    loss_list_d_x[-1], loss_list_d_y[-1]))
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(origin), test_data_preprocess(gt_uv_map)
                origin, gt_uv_map = transform_img(origin), transform_img(gt_uv_map)
                origin_in = origin.unsqueeze_(0).cuda()
                pred_uv_map = g_x(origin_in).detach().cpu()
                save_image([origin.cpu(), gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                           os.path.join(FLAGS['images'], str(ep) + '.png'),
                           nrow=1, normalize=True)

                # Save model
                print("Save model")
                state = {
                    'g_x': g_x.state_dict(),
                    'g_y': g_y.state_dict(),
                    'd_x': d_x.state_dict(),
                    'd_y': d_y.state_dict(),
                    'start_epoch': ep,
                }
                torch.save(state, os.path.join(FLAGS['images'], '{}.pth'.format(ep)))

        scheduler.step()
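# Objective of the cycle above, as implemented in the loop (X = input images,
# Y = UV position maps):
#   loss_x   = BCE(d_x(g_x(x)), 1) + lambda_X * L1(g_y(g_x(x)), x)
#   loss_y   = BCE(d_y(g_y(y)), 1) + lambda_Y * L1(g_x(g_y(y)), y)
#   loss_d_x = 0.5 * (BCE(d_x(y), 1) + BCE(d_x(g_x(x)), 0))
#   loss_d_y = 0.5 * (BCE(d_y(x), 1) + BCE(d_y(g_y(y)), 0))
# i.e. a CycleGAN-style adversarial + cycle-consistency objective, with the identity
# loss omitted.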
def main(data_dir):
    # 0) Tensorboard Writer.
    writer = SummaryWriter(FLAGS['summary_path'])
    origin_img, uv_map_gt, uv_map_predicted = None, None, None

    # 1) Create Dataset of 300_WLP.
    train_data_dir = [
        '/home/beitadoge/Github/PRNet_PyTorch/Data/PRNet_PyTorch_Data/300WLP_AFW_HELEN_LFPW',
        '/home/beitadoge/Github/PRNet_PyTorch/Data/PRNet_PyTorch_Data/300WLP_AFW_HELEN_LFPW_Flip',
        '/home/beitadoge/Github/PRNet_PyTorch/Data/PRNet_PyTorch_Data/300WLP_IBUG_Src_Flip'
    ]
    wlp300 = PRNetDataset(root_dir=train_data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToNormalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
                          ]))

    # 2) Create DataLoader. (DataLoaderX is a project-provided DataLoader wrapper, not shown here.)
    wlp300_dataloader = DataLoaderX(dataset=wlp300, batch_size=FLAGS['batch_size'],
                                    shuffle=True, num_workers=4)

    # 3) Create PRNet model.
    start_epoch, target_epoch = FLAGS['start_epoch'], FLAGS['target_epoch']
    model = ResFCN256()

    # GPU
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to("cuda")

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS["lr"], betas=(0.5, 0.999))
    # scheduler_MultiStepLR = torch.optim.lr_scheduler.MultiStepLR(optimizer, [11], gamma=0.5, last_epoch=-1)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1,
                                                           patience=5, min_lr=1e-6, verbose=False)
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    # scheduler_StepLR = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5, last_epoch=-1)

    # apex mixed-precision training
    # from apex import amp
    # model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)

    # Loss
    stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
    loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])

    # Load the pre-trained weight
    if FLAGS['resume'] and os.path.exists(os.path.join(FLAGS['model_path'], "latest.pth")):
        # This is a dict with keys: ['prnet', 'optimizer', 'start_epoch']
        state = torch.load(os.path.join(FLAGS['model_path'], "latest.pth"))
        model.load_state_dict(state['prnet'])
        optimizer.load_state_dict(state['optimizer'])
        # amp.load_state_dict(state['amp'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO("Pre-trained weight cannot load successfully, train from scratch!")

    # Tensorboard
    model_input = torch.rand(FLAGS['batch_size'], 3, 256, 256)
    writer.add_graph(model, model_input.to("cuda"))

    nme_mean = 999
    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        Loss_list, Stat_list = deque(maxlen=len(bar)), deque(maxlen=len(bar))
        model.train()
        for i, sample in enumerate(bar):
            uv_map, origin = sample['uv_map'].to(FLAGS['device']), sample['origin'].to(FLAGS['device'])

            # Inference.
            uv_map_result = model(origin)

            # Loss & ssim stat.
            logit_loss = loss(uv_map_result, uv_map)
            stat_logit = stat_loss(uv_map_result, uv_map)

            # Record Loss.
            Loss_list.append(logit_loss.item())
            Stat_list.append(stat_logit.item())

            # Update.
            optimizer.zero_grad()
            logit_loss.backward()
            # with amp.scale_loss(logit_loss, optimizer) as scaled_loss:
            #     scaled_loss.backward()
            optimizer.step()

            lr = optimizer.param_groups[0]['lr']
            bar.set_description(" {} lr {} [Loss(Paper)] {:.5f} [SSIM({})] {:.5f}".format(
                ep, lr, np.mean(Loss_list), FLAGS["gauss_kernel"], np.mean(Stat_list)))

            # Record Training information in Tensorboard.
            # if origin_img is None and uv_map_gt is None:
            #     origin_img, uv_map_gt = origin, uv_map
            #     uv_map_predicted = uv_map_result

            # Write per-step summaries to Tensorboard
            # FLAGS["summary_step"] += 1
            # if FLAGS["summary_step"] % 500 == 0:
            #     writer.add_scalar("Original Loss", Loss_list[-1], FLAGS["summary_step"])
            #     writer.add_scalar("SSIM Loss", Stat_list[-1], FLAGS["summary_step"])
            #     grid_1, grid_2, grid_3 = make_grid(origin_img, normalize=True), make_grid(uv_map_gt), make_grid(uv_map_predicted)
            #     writer.add_image('original', grid_1, FLAGS["summary_step"])
            #     writer.add_image('gt_uv_map', grid_2, FLAGS["summary_step"])
            #     writer.add_image('predicted_uv_map', grid_3, FLAGS["summary_step"])
            #     writer.add_graph(model, uv_map)

        # After each epoch, write the mean loss to Tensorboard
        loss_mean = np.mean(Loss_list)
        writer.add_scalar("Original Loss", loss_mean, ep)
        lr = optimizer.param_groups[0]['lr']
        writer.add_scalar("lr", lr, ep)

        # scheduler_StepLR.step()
        scheduler.step(loss_mean)

        del Loss_list
        del Stat_list

        # Test && Cal AFLW2000's NME
        model.eval()
        if ep % FLAGS["save_interval"] == 0:
            with torch.no_grad():
                nme_mean = cal_aflw2000_nme(model, '/home/beitadoge/Data/PRNet_PyTorch_Data/AFLW2000')
                print("NME IS {}".format(nme_mean))
                writer.add_scalar("Aflw2000_nme", nme_mean, ep)

                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = cv2.imread("./test_data/obama_uv_posmap.jpg")
                origin, gt_uv_map = test_data_preprocess(origin), test_data_preprocess(gt_uv_map)
                origin_in = F.normalize(origin, FLAGS["normalize_mean"],
                                        FLAGS["normalize_std"], False).unsqueeze_(0).cuda()
                pred_uv_map = model(origin_in).detach().cpu()
                save_image([origin.cpu(), gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                           os.path.join(FLAGS['model_path'], str(ep) + '.png'),
                           nrow=1, normalize=True)

            # # Save model
            # state = {
            #     'prnet': model.state_dict(),
            #     'Loss': Loss_list,
            #     'start_epoch': ep,
            # }
            # torch.save(state, os.path.join(FLAGS['model_path'], 'epoch{}.pth'.format(ep)))

            checkpoint = {
                'prnet': model.state_dict(),  # keyed 'prnet' so the resume branch above can read it
                'optimizer': optimizer.state_dict(),
                # 'amp': amp.state_dict(),
                'start_epoch': ep
            }
            torch.save(checkpoint, os.path.join(FLAGS['model_path'], 'latest.pth'))

        # adjust_learning_rate(lr, ep, optimizer)
        # scheduler.step(nme_mean)

    writer.close()
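# cal_aflw2000_nme is a project helper that is not shown in this file. As a rough sketch
# of what such an evaluation typically computes (an assumption, not the repository's exact
# definition): per-image landmark error normalised by the ground-truth bounding-box size,
# averaged over the AFLW2000 images.
def nme_sketch(pred_landmarks, gt_landmarks):
    # pred_landmarks, gt_landmarks: numpy arrays of shape (68, 2) for one image.
    errors = np.linalg.norm(pred_landmarks - gt_landmarks, axis=1)
    bbox_w = gt_landmarks[:, 0].max() - gt_landmarks[:, 0].min()
    bbox_h = gt_landmarks[:, 1].max() - gt_landmarks[:, 1].min()
    normaliser = np.sqrt(bbox_w * bbox_h)
    return errors.mean() / normaliser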