예제 #1
0
from model.resfcn256 import ResFCN256

from tools.WLP300dataset import PRNetDataset, ToTensor, ToNormalize, ToResize
from tools.prnet_loss import WeightMaskLoss, INFO

from config.config import FLAGS

from utils.utils import save_image, test_data_preprocess, make_all_grids, make_grid
from utils.losses import SSIM

from torchvision import transforms
from torch.utils.data import DataLoader

# Set the random seed for reproducibility: the same value seeds both Python's
# `random` module and PyTorch's RNG.
manualSeed = 5
INFO("Random Seed", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)


def main(data_dir):
    origin_img, uv_map_gt, uv_map_predicted = None, None, None

    if not os.path.exists(FLAGS['images']):
        os.mkdir(FLAGS['images'])

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToResize((256, 256)),
예제 #2
0
def main(data_dir):
    """Train PRNet (ResFCN256) with a BCE-with-logits objective.

    Builds the dataset/dataloader rooted at ``data_dir``, optionally resumes
    from ``FLAGS['pretrained']/FLAGS['resume']``, and trains up to
    ``FLAGS['target_epoch']``. Every ``FLAGS['save_interval']`` epochs a
    preview image and a ``<epoch>.pth`` checkpoint are written to
    ``FLAGS['images']``.

    Args:
        data_dir: Root directory handed to ``PRNetDataset``.
    """
    device = FLAGS["device"]

    # exist_ok removes the check-then-create race of exists() + mkdir().
    os.makedirs(FLAGS['images'], exist_ok=True)

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToResize((256, 256)),
                              ToNormalize(FLAGS["normalize_mean"],
                                          FLAGS["normalize_std"])
                          ]))

    wlp300_dataloader = DataLoader(dataset=wlp300,
                                   batch_size=FLAGS['batch_size'],
                                   shuffle=True,
                                   num_workers=1)

    # 2) Intermediate Processing (applied to the fixed preview sample only).
    transform_img = transforms.Compose([
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model.
    target_epoch = FLAGS['target_epoch']
    model = ResFCN256(resolution_input=256,
                      resolution_output=256,
                      channel=3,
                      size=16)

    # Load the pre-trained weight (before any DataParallel wrapping, so the
    # checkpoint must hold un-prefixed parameter names).
    resume_path = None
    if FLAGS['resume'] != "":
        resume_path = os.path.join(FLAGS['pretrained'], FLAGS['resume'])
    if resume_path is not None and os.path.exists(resume_path):
        # map_location lets a GPU-saved checkpoint load on the configured device.
        state = torch.load(resume_path, map_location=device)
        model.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO(
            "Pre-trained weight cannot load successfully, train from scratch!")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=FLAGS["lr"],
                                 betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    bce_loss = torch.nn.BCEWithLogitsLoss()
    bce_loss.to(device)

    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        loss_list_G = []
        for sample in bar:
            uv_map = sample['uv_map'].to(device)
            origin = sample['origin'].to(device)

            # Inference and update.
            optimizer.zero_grad()
            uv_map_result = model(origin)
            loss_g = bce_loss(uv_map_result, uv_map)
            loss_g.backward()
            loss_list_G.append(loss_g.item())
            optimizer.step()

        if ep % FLAGS["save_interval"] == 0:

            with torch.no_grad():
                # An empty dataloader leaves no loss to report.
                last_loss = loss_list_G[-1] if loss_list_G else float("nan")
                print(" {} [BCE ({})]".format(ep, last_loss))
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(
                    origin), test_data_preprocess(gt_uv_map)

                origin, gt_uv_map = transform_img(origin), transform_img(
                    gt_uv_map)

                # Use the configured device rather than a hard-coded .cuda().
                origin_in = origin.unsqueeze_(0).to(device)
                pred_uv_map = model(origin_in).detach().cpu()

                save_image(
                    [origin.cpu(),
                     gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                    os.path.join(FLAGS['images'],
                                 str(ep) + '.png'),
                    nrow=1,
                    normalize=True)

            # Save model. Unwrap DataParallel so the keys carry no "module."
            # prefix and can be loaded by the resume path above.
            print("Save model")
            net = model.module if isinstance(
                model, torch.nn.DataParallel) else model
            state = {
                'prnet': net.state_dict(),
                'Loss': loss_list_G,
                'start_epoch': ep
            }
            torch.save(state, os.path.join(FLAGS['images'],
                                           '{}.pth'.format(ep)))

            # NOTE(review): the LR decays only on save epochs, not every
            # epoch -- confirm this placement is intended.
            scheduler.step()
예제 #3
0
def main(data_dir):
    """Train PRNet with the weighted-mask loss and log to TensorBoard.

    Builds the dataset/dataloader rooted at ``data_dir``, optionally resumes
    from ``FLAGS['images']/latest.pth``, and trains up to
    ``FLAGS['target_epoch']``. Loss, SSIM and image grids are written to the
    ``SummaryWriter`` at ``FLAGS['summary_path']``. Every
    ``FLAGS['save_interval']`` epochs a preview image and ``latest.pth`` are
    written to ``FLAGS['images']``.

    Args:
        data_dir: Root directory handed to ``PRNetDataset``.
    """
    device = FLAGS['device']

    # 0) Tensorboard Writer.
    writer = SummaryWriter(FLAGS['summary_path'])
    try:
        origin_img, uv_map_gt, uv_map_predicted = None, None, None

        os.makedirs(FLAGS['images'], exist_ok=True)

        # 1) Create Dataset of 300_WLP & Dataloader.
        wlp300 = PRNetDataset(root_dir=data_dir,
                              transform=transforms.Compose([
                                  ToTensor(),
                                  ToNormalize(FLAGS["normalize_mean"],
                                              FLAGS["normalize_std"])
                              ]))

        wlp300_dataloader = DataLoader(dataset=wlp300,
                                       batch_size=FLAGS['batch_size'],
                                       shuffle=True,
                                       num_workers=4)
        steps_per_epoch = len(wlp300_dataloader)
        # FLAGS["summary_step"] is treated as the base offset of the step axis.
        base_step = FLAGS["summary_step"] or 0

        # 2) Create PRNet model.
        target_epoch = FLAGS['target_epoch']
        model = ResFCN256()

        # Load the pre-trained weight (before any DataParallel wrapping).
        latest_path = os.path.join(FLAGS['images'], "latest.pth")
        if FLAGS['resume'] and os.path.exists(latest_path):
            state = torch.load(latest_path, map_location=device)
            model.load_state_dict(state['prnet'])
            start_epoch = state['start_epoch']
            INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
        else:
            start_epoch = 0
            INFO("Pre-trained weight cannot load successfully, train from scratch!")

        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        # Batches are moved to FLAGS['device'], so the model must live there too.
        model.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS["lr"], betas=(0.5, 0.999))
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

        stat_loss = SSIM(mask_path=FLAGS["mask_path"], gauss=FLAGS["gauss_kernel"])
        loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])

        graph_logged = False
        for ep in range(start_epoch, target_epoch):
            bar = tqdm(wlp300_dataloader)
            Loss_list, Stat_list = [], []
            for i, sample in enumerate(bar):
                uv_map = sample['uv_map'].to(device)
                origin = sample['origin'].to(device)

                # Inference.
                uv_map_result = model(origin)

                # Loss & ssim stat.
                logit = loss(uv_map_result, uv_map)
                stat_logit = stat_loss(uv_map_result, uv_map)

                # Record Loss.
                Loss_list.append(logit.item())
                Stat_list.append(stat_logit.item())

                # Update.
                optimizer.zero_grad()
                logit.backward()
                optimizer.step()
                bar.set_description(" {} [Loss(Paper)] {} [SSIM({})] {}".format(
                    ep, Loss_list[-1], FLAGS["gauss_kernel"], Stat_list[-1]))

                # Record Training information in Tensorboard.
                # The first batch seen is kept as the fixed reference pair.
                if origin_img is None and uv_map_gt is None:
                    origin_img, uv_map_gt = origin, uv_map
                uv_map_predicted = uv_map_result

                # A monotonically increasing step; logging everything at one
                # constant step collapses the curves to a single point.
                global_step = base_step + ep * steps_per_epoch + i

                writer.add_scalar("Original Loss", Loss_list[-1], global_step)
                writer.add_scalar("SSIM Loss", Stat_list[-1], global_step)

                grid_1 = make_grid(origin_img, normalize=True)
                grid_2 = make_grid(uv_map_gt)
                grid_3 = make_grid(uv_map_predicted)

                writer.add_image('original', grid_1, global_step)
                writer.add_image('gt_uv_map', grid_2, global_step)
                writer.add_image('predicted_uv_map', grid_3, global_step)

                # The graph is static: trace it once, with the real model input.
                if not graph_logged:
                    writer.add_graph(model, origin)
                    graph_logged = True

            if ep % FLAGS["save_interval"] == 0:
                with torch.no_grad():
                    origin = cv2.imread("./test_data/obama_origin.jpg")
                    gt_uv_map = np.load("./test_data/test_obama.npy")
                    # The preview sample is fed as returned by
                    # test_data_preprocess; no extra normalization is applied.
                    origin, gt_uv_map = test_data_preprocess(origin), test_data_preprocess(gt_uv_map)

                    origin_in = origin.unsqueeze_(0).to(device)
                    pred_uv_map = model(origin_in).detach().cpu()

                    save_image([origin.cpu(), gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                               os.path.join(FLAGS['images'], str(ep) + '.png'), nrow=1, normalize=True)

                # Save model. Unwrap DataParallel so the keys carry no
                # "module." prefix and can be loaded by the resume path above.
                net = model.module if isinstance(model, torch.nn.DataParallel) else model
                state = {
                    'prnet': net.state_dict(),
                    'Loss': Loss_list,
                    'start_epoch': ep,
                }
                torch.save(state, latest_path)

                # NOTE(review): the LR decays only on save epochs, not every
                # epoch -- confirm this placement is intended.
                scheduler.step()
    finally:
        # Flush and close the event file even if training raises.
        writer.close()
def main(data_dir):
    """Train PRNet adversarially against a ``Discriminator``.

    ``FLAGS['gan_type']`` selects the objective: ``'GAN'``, ``'RGAN'``, or any
    name containing ``'WGAN'`` (adding ``'GP'`` enables the gradient
    penalty). Every ``FLAGS['save_interval']`` epochs a preview image and a
    ``<epoch>.pth`` checkpoint are written to ``FLAGS['images']``.

    Args:
        data_dir: Root directory handed to ``PRNetDataset``.

    Raises:
        ValueError: If ``FLAGS['gan_type']`` is not a supported objective.
    """
    device = FLAGS["device"]
    gan_type = FLAGS['gan_type']
    is_wgan = 'WGAN' in gan_type
    use_gp = is_wgan and 'GP' in gan_type
    if gan_type not in ('GAN', 'RGAN') and not is_wgan:
        # Fail early instead of hitting an undefined loss inside the loop.
        raise ValueError("Unsupported gan_type: {!r}".format(gan_type))

    # 0) Tensorboard Writer (opened and closed, but not written to here).
    writer = SummaryWriter(FLAGS['summary_path'])
    try:
        os.makedirs(FLAGS['images'], exist_ok=True)

        # 1) Create Dataset of 300_WLP & Dataloader.
        wlp300 = PRNetDataset(root_dir=data_dir,
                              transform=transforms.Compose([
                                  ToTensor(),
                                  ToResize((416, 416)),
                                  ToNormalize(FLAGS["normalize_mean"],
                                              FLAGS["normalize_std"])
                              ]))

        wlp300_dataloader = DataLoader(dataset=wlp300,
                                       batch_size=FLAGS['batch_size'],
                                       shuffle=True,
                                       num_workers=1)

        # 2) Intermediate Processing (applied to the fixed preview sample).
        transform_img = transforms.Compose([
            transforms.Normalize(FLAGS["normalize_mean"],
                                 FLAGS["normalize_std"])
        ])

        # 3) Create PRNet model.
        target_epoch = FLAGS['target_epoch']
        model = ResFCN256(resolution_input=416,
                          resolution_output=416,
                          channel=3,
                          size=16)
        discriminator = Discriminator()

        # Load the pre-trained weight (before any DataParallel wrapping).
        resume_path = None
        if FLAGS['resume'] != "":
            resume_path = os.path.join(FLAGS['pretrained'], FLAGS['resume'])
        if resume_path is not None and os.path.exists(resume_path):
            state = torch.load(resume_path, map_location=device)
            model.load_state_dict(state['prnet'])
            start_epoch = state['start_epoch']
            INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
        else:
            start_epoch = 0
            INFO(
                "Pre-trained weight cannot load successfully, train from scratch!"
            )

        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        model.to(device)
        discriminator.to(device)

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=FLAGS["lr"],
                                     betas=(0.5, 0.999))
        optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                       lr=FLAGS["lr"])
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                           gamma=0.99)

        stat_loss = SSIM(mask_path=FLAGS["mask_path"],
                         gauss=FLAGS["gauss_kernel"])
        bce_loss = torch.nn.BCEWithLogitsLoss()
        bce_loss.to(device)

        def adversarial_bce(as_real, as_fake):
            """BCE pushing ``as_real`` logits toward 1 and ``as_fake`` toward 0.

            BCEWithLogitsLoss takes (logits, targets); passing the second
            logit tensor as the target would be wrong.
            """
            return (bce_loss(as_real, torch.ones_like(as_real)) +
                    bce_loss(as_fake, torch.zeros_like(as_fake)))

        def last_or_nan(values):
            """Return the most recent entry, or NaN when nothing was recorded."""
            return values[-1] if values else float("nan")

        for ep in range(start_epoch, target_epoch):
            bar = tqdm(wlp300_dataloader)
            loss_list_G, stat_list = [], []
            loss_list_D = []
            for sample in bar:
                uv_map = sample['uv_map'].to(device)
                origin = sample['origin'].to(device)

                # Inference.
                optimizer.zero_grad()
                uv_map_result = model(origin)

                # Update D on a detached fake so no gradient reaches G.
                optimizer_D.zero_grad()
                fake_detach = uv_map_result.detach()
                d_fake = discriminator(fake_detach)
                d_real = discriminator(uv_map)
                if gan_type == 'GAN':
                    loss_d = adversarial_bce(d_real, d_fake)
                elif is_wgan:
                    loss_d = (d_fake - d_real).mean()
                    if use_gp:
                        epsilon = torch.rand(fake_detach.shape[0]).view(
                            -1, 1, 1, 1)
                        epsilon = epsilon.to(fake_detach.device)
                        hat = fake_detach.mul(1 - epsilon) + uv_map.mul(
                            epsilon)
                        hat.requires_grad = True
                        d_hat = discriminator(hat)
                        gradients = torch.autograd.grad(outputs=d_hat.sum(),
                                                        inputs=hat,
                                                        retain_graph=True,
                                                        create_graph=True,
                                                        only_inputs=True)[0]
                        gradients = gradients.view(gradients.size(0), -1)
                        gradient_norm = gradients.norm(2, dim=1)
                        gradient_penalty = 10 * gradient_norm.sub(1).pow(
                            2).mean()
                        loss_d = loss_d + gradient_penalty
                else:
                    # RGAN, from ESRGAN: Enhanced Super-Resolution Generative
                    # Adversarial Networks.
                    better_real = d_real - d_fake.mean(dim=0, keepdim=True)
                    better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
                    loss_d = adversarial_bce(better_real, better_fake)

                if discriminator.training:
                    loss_list_D.append(loss_d.item())
                    loss_d.backward()
                    optimizer_D.step()

                    # Weight clipping belongs to plain WGAN; the gradient
                    # penalty variant replaces it.
                    if is_wgan and not use_gp:
                        for p in discriminator.parameters():
                            p.data.clamp_(-1, 1)

                # Update G: fake is passed un-detached so gradients reach G.
                d_fake_bp = discriminator(uv_map_result)
                if gan_type == 'GAN':
                    label_real = torch.ones_like(d_fake_bp)
                    loss_g = bce_loss(d_fake_bp, label_real)
                elif is_wgan:
                    loss_g = -d_fake_bp.mean()
                else:
                    # d_real does not depend on G; detaching avoids
                    # backpropagating through a graph whose discriminator
                    # weights were just updated in place.
                    d_real_const = d_real.detach()
                    better_real = d_real_const - d_fake_bp.mean(dim=0,
                                                                keepdim=True)
                    better_fake = d_fake_bp - d_real_const.mean(dim=0,
                                                                keepdim=True)
                    # For G the roles flip: fake should look real.
                    loss_g = adversarial_bce(better_fake, better_real)

                loss_g.backward()
                loss_list_G.append(loss_g.item())
                optimizer.step()

                # SSIM is a monitoring statistic only; no graph is needed.
                with torch.no_grad():
                    stat_logit = stat_loss(uv_map_result, uv_map)
                stat_list.append(stat_logit.item())

            if ep % FLAGS["save_interval"] == 0:

                with torch.no_grad():
                    print(" {} [Loss(Paper)] {} [Loss(D)] {} [SSIM({})] {}".
                          format(ep, last_or_nan(loss_list_G),
                                 last_or_nan(loss_list_D),
                                 FLAGS["gauss_kernel"],
                                 last_or_nan(stat_list)))
                    origin = cv2.imread("./test_data/obama_origin.jpg")
                    gt_uv_map = np.load("./test_data/test_obama.npy")
                    origin, gt_uv_map = test_data_preprocess(
                        origin), test_data_preprocess(gt_uv_map)

                    origin, gt_uv_map = transform_img(origin), transform_img(
                        gt_uv_map)

                    # Use the configured device rather than a hard-coded .cuda().
                    origin_in = origin.unsqueeze_(0).to(device)
                    pred_uv_map = model(origin_in).detach().cpu()

                    save_image(
                        [origin.cpu(),
                         gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                        os.path.join(FLAGS['images'],
                                     str(ep) + '.png'),
                        nrow=1,
                        normalize=True)

                # Save model. Unwrap DataParallel so the keys carry no
                # "module." prefix and can be loaded by the resume path above.
                print("Save model")
                net = model.module if isinstance(
                    model, torch.nn.DataParallel) else model
                state = {
                    'prnet': net.state_dict(),
                    'Loss': loss_list_G,
                    'start_epoch': ep,
                    'Loss_D': loss_list_D,
                }
                torch.save(
                    state,
                    os.path.join(FLAGS['images'], '{}.pth'.format(ep)))

                # NOTE(review): the LR decays only on save epochs, not every
                # epoch -- confirm this placement is intended.
                scheduler.step()
    finally:
        writer.close()
예제 #5
0
def main(data_dir):
    """Train PRNet against ``Discriminator1`` with a BCE GAN objective.

    Builds the dataset/dataloader rooted at ``data_dir``, optionally resumes
    from ``FLAGS['images']/latest.pth``, and alternates one generator and one
    discriminator update per batch. Losses, SSIM and image grids go to
    TensorBoard. Every ``FLAGS['save_interval']`` epochs a preview image and
    ``latest.pth`` are written to ``FLAGS['images']``.

    Args:
        data_dir: Root directory handed to ``PRNetDataset``.
    """
    device = FLAGS['device']

    # 0) Tensorboard Writer.
    writer = SummaryWriter(FLAGS['summary_path'])
    try:
        origin_img, uv_map_gt, uv_map_predicted = None, None, None

        os.makedirs(FLAGS['images'], exist_ok=True)

        # 1) Create Dataset of 300_WLP & Dataloader.
        wlp300 = PRNetDataset(root_dir=data_dir,
                              transform=transforms.Compose([
                                  ToTensor(),
                                  ToNormalize(FLAGS["normalize_mean"],
                                              FLAGS["normalize_std"])
                              ]))

        wlp300_dataloader = DataLoader(dataset=wlp300,
                                       batch_size=FLAGS['batch_size'],
                                       shuffle=True,
                                       num_workers=4)
        steps_per_epoch = len(wlp300_dataloader)
        # FLAGS["summary_step"] is treated as the base offset of the step axis.
        base_step = FLAGS["summary_step"] or 0

        # 2) Intermediate Processing (applied to the fixed preview sample).
        transform_img = transforms.Compose([
            transforms.Normalize(FLAGS["normalize_mean"],
                                 FLAGS["normalize_std"])
        ])

        # 3) Create PRNet model.
        target_epoch = FLAGS['target_epoch']
        model = ResFCN256()
        discriminator = Discriminator1()

        # Load the pre-trained weight (before any DataParallel wrapping).
        latest_path = os.path.join(FLAGS['images'], "latest.pth")
        if FLAGS['resume'] and os.path.exists(latest_path):
            state = torch.load(latest_path, map_location=device)
            model.load_state_dict(state['prnet'])
            start_epoch = state['start_epoch']
            INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
        else:
            start_epoch = 0
            INFO(
                "Pre-trained weight cannot load successfully, train from scratch!"
            )

        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        # Batches are moved to FLAGS['device'], so the networks must live there.
        model.to(device)
        discriminator.to(device)

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=FLAGS["lr"],
                                     betas=(0.5, 0.999))
        optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                       lr=FLAGS["lr"])
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                           gamma=0.99)

        stat_loss = SSIM(mask_path=FLAGS["mask_path"],
                         gauss=FLAGS["gauss_kernel"])
        bce_loss = torch.nn.BCEWithLogitsLoss()
        bce_loss.to(device)

        graph_logged = False
        for ep in range(start_epoch, target_epoch):
            bar = tqdm(wlp300_dataloader)
            Loss_list, Stat_list = [], []
            Loss_list_D = []
            for i, sample in enumerate(bar):
                uv_map = sample['uv_map'].to(device)
                origin = sample['origin'].to(device)

                # Inference.
                uv_map_result = model(origin)

                # --- Generator update: fool the discriminator. ---
                # Labels take the shape of the discriminator output, so a
                # smaller final batch needs no special-casing.
                d_on_fake = discriminator(uv_map_result)
                logit = bce_loss(d_on_fake, torch.ones_like(d_on_fake))

                optimizer.zero_grad()
                logit.backward()
                optimizer.step()

                # --- Discriminator update: real -> 1, generated -> 0. ---
                # zero_grad also discards gradients the generator's backward
                # pass left on the discriminator parameters.
                optimizer_D.zero_grad()
                d_real_out = discriminator(uv_map)
                # Detach the *input* so gradients reach the discriminator but
                # stop before the generator (detaching the output would cut
                # the discriminator out of the fake loss entirely).
                d_fake_out = discriminator(uv_map_result.detach())
                real_loss = bce_loss(d_real_out, torch.ones_like(d_real_out))
                fake_loss = bce_loss(d_fake_out, torch.zeros_like(d_fake_out))
                d_loss = (real_loss + fake_loss) / 2
                d_loss.backward()
                optimizer_D.step()

                # SSIM is a monitoring statistic only; no graph is needed.
                with torch.no_grad():
                    stat_logit = stat_loss(uv_map_result, uv_map)

                # Record Loss.
                Loss_list.append(logit.item())
                Loss_list_D.append(d_loss.item())
                Stat_list.append(stat_logit.item())

                bar.set_description(
                    " {} [Loss(Paper)] {} [Loss(D)] {} [SSIM({})] {}".format(
                        ep, Loss_list[-1], Loss_list_D[-1],
                        FLAGS["gauss_kernel"], Stat_list[-1]))

                # Record Training information in Tensorboard.
                # The first batch seen is kept as the fixed reference pair.
                if origin_img is None and uv_map_gt is None:
                    origin_img, uv_map_gt = origin, uv_map
                uv_map_predicted = uv_map_result

                # A monotonically increasing step; logging everything at one
                # constant step collapses the curves to a single point.
                global_step = base_step + ep * steps_per_epoch + i

                writer.add_scalar("Original Loss", Loss_list[-1], global_step)
                writer.add_scalar("D Loss", Loss_list_D[-1], global_step)
                writer.add_scalar("SSIM Loss", Stat_list[-1], global_step)

                grid_1 = make_grid(origin_img, normalize=True)
                grid_2 = make_grid(uv_map_gt)
                grid_3 = make_grid(uv_map_predicted)

                writer.add_image('original', grid_1, global_step)
                writer.add_image('gt_uv_map', grid_2, global_step)
                writer.add_image('predicted_uv_map', grid_3, global_step)

                # The graph is static: trace it once, with the real model input.
                if not graph_logged:
                    writer.add_graph(model, origin)
                    graph_logged = True

            if ep % FLAGS["save_interval"] == 0:

                with torch.no_grad():
                    origin = cv2.imread("./test_data/obama_origin.jpg")
                    gt_uv_map = np.load("./test_data/test_obama.npy")
                    origin, gt_uv_map = test_data_preprocess(
                        origin), test_data_preprocess(gt_uv_map)

                    origin, gt_uv_map = transform_img(origin), transform_img(
                        gt_uv_map)

                    origin_in = origin.unsqueeze_(0).to(device)
                    pred_uv_map = model(origin_in).detach().cpu()

                    save_image(
                        [origin.cpu(),
                         gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                        os.path.join(FLAGS['images'],
                                     str(ep) + '.png'),
                        nrow=1,
                        normalize=True)

                # Save model. Unwrap DataParallel so the keys carry no
                # "module." prefix and can be loaded by the resume path above.
                print("Save model")
                net = model.module if isinstance(
                    model, torch.nn.DataParallel) else model
                state = {
                    'prnet': net.state_dict(),
                    'Loss': Loss_list,
                    'start_epoch': ep,
                    'Loss_D': Loss_list_D,
                }
                torch.save(state, latest_path)

                # NOTE(review): the LR decays only on save epochs, not every
                # epoch -- confirm this placement is intended.
                scheduler.step()
    finally:
        # Flush and close the event file even if training raises.
        writer.close()
def main(data_dir):
    """Train a CycleGAN-style pair of PRNet generators and discriminators.

    ``g_x`` maps the input image to a UV map and ``g_y`` maps a UV map back
    to an image; ``d_x`` judges UV maps and ``d_y`` judges images. Each batch
    runs two generator updates (adversarial + L1 cycle loss) followed by two
    discriminator updates. Every ``FLAGS['save_interval']`` epochs a preview
    image and a ``<epoch>.pth`` checkpoint are written to ``FLAGS['images']``.

    Args:
        data_dir: Root directory handed to ``PRNetDataset``.
    """
    device = FLAGS["device"]

    os.makedirs(FLAGS['images'], exist_ok=True)

    # 1) Create Dataset of 300_WLP & Dataloader.
    wlp300 = PRNetDataset(root_dir=data_dir,
                          transform=transforms.Compose([
                              ToTensor(),
                              ToResize((416, 416)),
                              ToNormalize(FLAGS["normalize_mean"],
                                          FLAGS["normalize_std"])
                          ]))

    wlp300_dataloader = DataLoader(dataset=wlp300,
                                   batch_size=FLAGS['batch_size'],
                                   shuffle=True,
                                   num_workers=1)

    # 2) Intermediate Processing (applied to the fixed preview sample only).
    transform_img = transforms.Compose([
        transforms.Normalize(FLAGS["normalize_mean"], FLAGS["normalize_std"])
    ])

    # 3) Create PRNet model.
    target_epoch = FLAGS['target_epoch']
    g_x = ResFCN256(resolution_input=416,
                    resolution_output=416,
                    channel=3,
                    size=16)
    g_y = ResFCN256(resolution_input=416,
                    resolution_output=416,
                    channel=3,
                    size=16)
    d_x = Discriminator()
    d_y = Discriminator()

    # Load the pre-trained weight (before any DataParallel wrapping).
    resume_path = None
    if FLAGS['resume'] != "":
        resume_path = os.path.join(FLAGS['pretrained'], FLAGS['resume'])
    if resume_path is not None and os.path.exists(resume_path):
        state = torch.load(resume_path, map_location=device)
        if all(key in state for key in ('g_x', 'g_y', 'd_x', 'd_y')):
            g_x.load_state_dict(state['g_x'])
            g_y.load_state_dict(state['g_y'])
            d_x.load_state_dict(state['d_x'])
            d_y.load_state_dict(state['d_y'])
        else:
            # Checkpoint without the four CycleGAN entries: only the forward
            # generator can be initialised, from its 'prnet' weights.
            g_x.load_state_dict(state['prnet'])
        start_epoch = state['start_epoch']
        INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
    else:
        start_epoch = 0
        INFO(
            "Pre-trained weight cannot load successfully, train from scratch!")

    if torch.cuda.device_count() > 1:
        g_x = torch.nn.DataParallel(g_x)
        g_y = torch.nn.DataParallel(g_y)
        d_x = torch.nn.DataParallel(d_x)
        d_y = torch.nn.DataParallel(d_y)
    g_x.to(device)
    g_y.to(device)
    d_x.to(device)
    d_y.to(device)

    optimizer_g = torch.optim.Adam(itertools.chain(g_x.parameters(),
                                                   g_y.parameters()),
                                   lr=FLAGS["lr"],
                                   betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(itertools.chain(d_x.parameters(),
                                                   d_y.parameters()),
                                   lr=FLAGS["lr"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.99)

    bce_loss = torch.nn.BCEWithLogitsLoss()
    bce_loss.to(device)
    l1_loss = torch.nn.L1Loss().to(device)
    lambda_X = 10
    lambda_Y = 10

    def unwrap(net):
        """Return the underlying module of a DataParallel wrapper."""
        return net.module if isinstance(net, torch.nn.DataParallel) else net

    def last_or_nan(values):
        """Return the most recent entry, or NaN when nothing was recorded."""
        return values[-1] if values else float("nan")

    for ep in range(start_epoch, target_epoch):
        bar = tqdm(wlp300_dataloader)
        loss_list_cycle_x = []
        loss_list_cycle_y = []
        loss_list_d_x = []
        loss_list_d_y = []
        for sample in bar:
            real_y = sample['uv_map'].to(device)
            real_x = sample['origin'].to(device)

            # Labels are created with ones_like / zeros_like on each
            # discriminator output so they match its shape and device.

            # x -> y' -> x^
            optimizer_g.zero_grad()
            fake_y = g_x(real_x)
            prediction = d_x(fake_y)
            loss_g_x = bce_loss(prediction, torch.ones_like(prediction))
            x_hat = g_y(fake_y)
            loss_cycle_x = l1_loss(x_hat, real_x) * lambda_X
            loss_x = loss_g_x + loss_cycle_x
            loss_x.backward()
            optimizer_g.step()
            loss_list_cycle_x.append(loss_x.item())

            # y -> x' -> y^
            optimizer_g.zero_grad()
            fake_x = g_y(real_y)
            prediction = d_y(fake_x)
            loss_g_y = bce_loss(prediction, torch.ones_like(prediction))
            y_hat = g_x(fake_x)
            loss_cycle_y = l1_loss(y_hat, real_y) * lambda_Y
            loss_y = loss_g_y + loss_cycle_y
            loss_y.backward()
            optimizer_g.step()
            loss_list_cycle_y.append(loss_y.item())

            # d_x: fakes are detached so the discriminator loss does not
            # backpropagate into generator graphs whose weights were already
            # updated in place above.
            optimizer_d.zero_grad()
            pred_real = d_x(real_y)
            loss_d_x_real = bce_loss(pred_real, torch.ones_like(pred_real))
            pred_fake = d_x(fake_y.detach())
            loss_d_x_fake = bce_loss(pred_fake, torch.zeros_like(pred_fake))
            loss_d_x = (loss_d_x_real + loss_d_x_fake) * 0.5
            loss_d_x.backward()
            loss_list_d_x.append(loss_d_x.item())
            optimizer_d.step()
            if 'WGAN' in FLAGS['gan_type']:
                for p in d_x.parameters():
                    p.data.clamp_(-1, 1)

            # d_y
            optimizer_d.zero_grad()
            pred_real = d_y(real_x)
            loss_d_y_real = bce_loss(pred_real, torch.ones_like(pred_real))
            pred_fake = d_y(fake_x.detach())
            loss_d_y_fake = bce_loss(pred_fake, torch.zeros_like(pred_fake))
            loss_d_y = (loss_d_y_real + loss_d_y_fake) * 0.5
            loss_d_y.backward()
            loss_list_d_y.append(loss_d_y.item())
            optimizer_d.step()
            if 'WGAN' in FLAGS['gan_type']:
                for p in d_y.parameters():
                    p.data.clamp_(-1, 1)

        if ep % FLAGS["save_interval"] == 0:

            with torch.no_grad():
                # The generator figures are the recorded totals
                # (adversarial + cycle) of each direction.
                print(
                    " {} [Loss_G_X] {} [Loss_G_Y] {} [Loss_D_X] {} [Loss_D_Y] {}"
                    .format(ep, last_or_nan(loss_list_cycle_x),
                            last_or_nan(loss_list_cycle_y),
                            last_or_nan(loss_list_d_x),
                            last_or_nan(loss_list_d_y)))
                origin = cv2.imread("./test_data/obama_origin.jpg")
                gt_uv_map = np.load("./test_data/test_obama.npy")
                origin, gt_uv_map = test_data_preprocess(
                    origin), test_data_preprocess(gt_uv_map)

                origin, gt_uv_map = transform_img(origin), transform_img(
                    gt_uv_map)

                # Use the configured device rather than a hard-coded .cuda().
                origin_in = origin.unsqueeze_(0).to(device)
                pred_uv_map = g_x(origin_in).detach().cpu()

                save_image(
                    [origin.cpu(),
                     gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
                    os.path.join(FLAGS['images'],
                                 str(ep) + '.png'),
                    nrow=1,
                    normalize=True)

            # Save model. Unwrapped state dicts carry no "module." prefix and
            # can be loaded by the resume path above.
            print("Save model")
            state = {
                'g_x': unwrap(g_x).state_dict(),
                'g_y': unwrap(g_y).state_dict(),
                'd_x': unwrap(d_x).state_dict(),
                'd_y': unwrap(d_y).state_dict(),
                'start_epoch': ep,
            }
            torch.save(state, os.path.join(FLAGS['images'],
                                           '{}.pth'.format(ep)))

            # NOTE(review): the LR decays only on save epochs, not every
            # epoch -- confirm this placement is intended.
            scheduler.step()
예제 #7
0
def _resume_from_checkpoint(model, optimizer):
    """Restore model and optimizer state from the latest checkpoint, if any.

    Looks in ``FLAGS['model_path']`` when ``FLAGS['resume']`` is truthy.
    Both ``latest.pth`` and the historically misspelled ``lastest.pth``
    (the name ``main`` actually writes) are accepted; when both exist the
    most recently modified file wins. The weights may be stored under
    either the ``'prnet'`` or the ``'model'`` key.

    Args:
        model: Network whose ``load_state_dict`` receives the stored weights.
        optimizer: Optimizer whose ``load_state_dict`` receives the stored
            optimizer state.

    Returns:
        The ``'start_epoch'`` value stored in the checkpoint, or 0 when
        nothing was loaded.
    """
    if FLAGS['resume']:
        candidates = [
            os.path.join(FLAGS['model_path'], filename)
            for filename in ("latest.pth", "lastest.pth")
        ]
        existing = [path for path in candidates if os.path.exists(path)]
        if existing:
            state = torch.load(max(existing, key=os.path.getmtime))
            # Older checkpoints used 'prnet'; main() below writes 'model'.
            weights_key = 'prnet' if 'prnet' in state else 'model'
            model.load_state_dict(state[weights_key])
            optimizer.load_state_dict(state['optimizer'])
            # amp.load_state_dict(state['amp'])
            start_epoch = state['start_epoch']
            INFO("Load the pre-trained weight! Start from Epoch", start_epoch)
            return start_epoch

    INFO("Pre-trained weight cannot load successfully, train from scratch!")
    return 0


def _train_one_epoch(model, dataloader, optimizer, loss, stat_loss, ep):
    """Run one optimisation pass over ``dataloader``.

    Only ``loss`` is back-propagated; ``stat_loss`` is evaluated purely for
    the progress-bar read-out.

    Args:
        model: Network to train (switched to train mode here).
        dataloader: Iterable of samples with ``'uv_map'`` and ``'origin'``
            entries.
        optimizer: Optimizer stepped once per batch.
        loss: Callable producing the training loss from (prediction, target).
        stat_loss: Callable producing the monitoring statistic.
        ep: Current epoch index, shown in the progress bar.

    Returns:
        Mean of the per-batch training loss over the epoch.
    """
    bar = tqdm(dataloader)
    loss_history = deque(maxlen=len(bar))
    stat_history = deque(maxlen=len(bar))

    model.train()
    for sample in bar:
        uv_map = sample['uv_map'].to(FLAGS['device'])
        origin = sample['origin'].to(FLAGS['device'])

        # Inference.
        uv_map_result = model(origin)

        # Loss & ssim stat.
        logit_loss = loss(uv_map_result, uv_map)
        stat_logit = stat_loss(uv_map_result, uv_map)

        # Record Loss.
        loss_history.append(logit_loss.item())
        stat_history.append(stat_logit.item())

        # Update.
        optimizer.zero_grad()
        logit_loss.backward()
        optimizer.step()

        lr = optimizer.param_groups[0]['lr']
        bar.set_description(
            " {} lr {} [Loss(Paper)] {:.5f} [SSIM({})] {:.5f}".format(
                ep, lr, np.mean(loss_history), FLAGS["gauss_kernel"],
                np.mean(stat_history)))

    return np.mean(loss_history)


def _evaluate_and_preview(model, writer, ep):
    """Log the AFLW2000 NME and save a preview image for epoch ``ep``.

    The preview stacks the test image, its ground-truth UV map and the
    predicted UV map into ``<FLAGS['model_path']>/<ep>.png``.

    Args:
        model: Network to evaluate; the caller is expected to have put it
            in eval mode.
        writer: Tensorboard writer receiving the ``Aflw2000_nme`` scalar.
        ep: Current epoch index (scalar step and image file name).
    """
    with torch.no_grad():
        nme_mean = cal_aflw2000_nme(
            model, '/home/beitadoge/Data/PRNet_PyTorch_Data/AFLW2000')
        print("NME IS {}".format(nme_mean))
        writer.add_scalar("Aflw2000_nme", nme_mean, ep)

        origin = cv2.imread("./test_data/obama_origin.jpg")
        gt_uv_map = cv2.imread("./test_data/obama_uv_posmap.jpg")
        origin = test_data_preprocess(origin)
        gt_uv_map = test_data_preprocess(gt_uv_map)

        # The model lives on "cuda"; move the input there explicitly so the
        # forward pass also works when the model is not wrapped in
        # DataParallel (no-op if the tensor is already on the GPU).
        origin_in = F.normalize(origin, FLAGS["normalize_mean"],
                                FLAGS["normalize_std"],
                                False).unsqueeze_(0).to("cuda")
        pred_uv_map = model(origin_in).detach().cpu()

        save_image(
            [origin.cpu(),
             gt_uv_map.unsqueeze_(0).cpu(), pred_uv_map],
            os.path.join(FLAGS['model_path'],
                         str(ep) + '.png'),
            nrow=1,
            normalize=True)


def main(data_dir):
    """Train ResFCN256 on 300W-LP with Tensorboard logging and checkpoints.

    Per epoch: train on the whole dataset, log the mean loss and learning
    rate, step the ReduceLROnPlateau scheduler on the mean loss, and every
    ``FLAGS['save_interval']`` epochs evaluate NME on AFLW2000, save a
    preview image and write a checkpoint into ``FLAGS['model_path']``.

    Args:
        data_dir: NOTE(review): currently unused - the training directories
            are hard-coded below. Kept for interface compatibility.
    """
    # 0) Tensorboard writer.
    writer = SummaryWriter(FLAGS['summary_path'])
    try:
        # 1) Create Dataset of 300_WLP.
        train_data_dir = [
            '/home/beitadoge/Github/PRNet_PyTorch/Data/PRNet_PyTorch_Data/300WLP_AFW_HELEN_LFPW',
            '/home/beitadoge/Github/PRNet_PyTorch/Data/PRNet_PyTorch_Data/300WLP_AFW_HELEN_LFPW_Flip',
            '/home/beitadoge/Github/PRNet_PyTorch/Data/PRNet_PyTorch_Data/300WLP_IBUG_Src_Flip'
        ]
        wlp300 = PRNetDataset(root_dir=train_data_dir,
                              transform=transforms.Compose([
                                  ToTensor(),
                                  ToNormalize(FLAGS["normalize_mean"],
                                              FLAGS["normalize_std"])
                              ]))

        # 2) Create DataLoader.
        wlp300_dataloader = DataLoaderX(dataset=wlp300,
                                        batch_size=FLAGS['batch_size'],
                                        shuffle=True,
                                        num_workers=4)

        # 3) Create PRNet model.
        target_epoch = FLAGS['target_epoch']
        model = ResFCN256()

        # GPU: use every visible device when there is more than one.
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        model.to("cuda")

        # Optimizer and plateau-driven learning-rate schedule.
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=FLAGS["lr"],
                                     betas=(0.5, 0.999))
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               mode='min',
                                                               factor=0.1,
                                                               patience=5,
                                                               min_lr=1e-6)

        # Loss: the weighted-mask loss is optimised, SSIM is monitored only.
        stat_loss = SSIM(mask_path=FLAGS["mask_path"],
                         gauss=FLAGS["gauss_kernel"])
        loss = WeightMaskLoss(mask_path=FLAGS["mask_path"])

        # Load the pre-trained weight (falls back to epoch 0).
        # NOTE(review): checkpoints store the epoch that just finished, so
        # resuming repeats that epoch - confirm whether ep + 1 is intended.
        start_epoch = _resume_from_checkpoint(model, optimizer)

        # Checkpoints and preview images are written here; make sure the
        # directory exists before the first save.
        os.makedirs(FLAGS['model_path'], exist_ok=True)

        # Tensorboard graph. Traced in eval mode so the random input does
        # not touch any running statistics; each epoch re-enables train
        # mode. Graph export is diagnostic only, so a tracing failure is
        # reported instead of aborting the run.
        model_input = torch.rand(FLAGS['batch_size'], 3, 256,
                                 256).to("cuda")
        model.eval()
        try:
            writer.add_graph(model, model_input)
        except Exception as err:  # best-effort diagnostic
            INFO("Could not write the model graph to Tensorboard:", err)

        for ep in range(start_epoch, target_epoch):
            loss_mean = _train_one_epoch(model, wlp300_dataloader, optimizer,
                                         loss, stat_loss, ep)

            # After every epoch, write the mean loss and lr to Tensorboard.
            writer.add_scalar("Original Loss", loss_mean, ep)
            lr = optimizer.param_groups[0]['lr']
            writer.add_scalar("lr", lr, ep)
            scheduler.step(loss_mean)

            # Test && Cal AFLW2000's NME
            model.eval()
            if ep % FLAGS["save_interval"] == 0:
                _evaluate_and_preview(model, writer, ep)

                # Save model. The file name 'lastest.pth' (sic) and the
                # 'model' key are kept so existing checkpoints and tooling
                # keep working; _resume_from_checkpoint understands both.
                checkpoint = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    # 'amp': amp.state_dict(),
                    'start_epoch': ep
                }
                torch.save(checkpoint,
                           os.path.join(FLAGS['model_path'], 'lastest.pth'))
    finally:
        # Flush and close the event file even when training aborts.
        writer.close()