Exemplo n.º 1
0
def Validate_Skip(model_dir=None, syn_dir=None, out_dir=None):
    """Run a trained SkipNet on the synthetic validation split and dump predictions.

    Every 5th validation batch is passed through the network. The predicted
    normal, albedo, shading and reconstruction are logged with
    ``wandb_log_images`` and written as PNG files into
    ``out_dir + str(batch_index) + '/'``.

    Args:
        model_dir: Path of the SkipNet ``state_dict`` file to load.
        syn_dir: Root of the synthetic dataset; ``'train/'`` is appended to it
            directly, so it should end with a path separator.
        out_dir: Output prefix; the batch index is appended directly, so it
            should normally end with a path separator.
    """
    # Only the validation part of the split is used here.
    _, val_dataset = get_sfsnet_dataset(syn_dir=syn_dir + 'train/',
                                        validation_split=2, training_syn=True)
    val_dl = DataLoader(val_dataset, batch_size=32, shuffle=True)
    sfsnet_model = SkipNet()
    sfsnet_model.load_state_dict(torch.load(model_dir))
    # Inference only: switch layers such as BatchNorm/Dropout (if present) to
    # evaluation behaviour instead of per-batch training behaviour.
    sfsnet_model.eval()
    suffix = 'Val'

    wandb.init(tensorboard=True)
    # No gradients are needed for validation; skip building autograd graphs.
    with torch.no_grad():
        for bix, data in enumerate(val_dl):
            print("This is the {}th bix".format(bix))
            if bix % 5 != 0:
                continue
            out_dir_cur = out_dir + str(bix)
            # makedirs avoids spawning a shell (no injection via the path) and
            # also creates missing parent directories.
            os.makedirs(out_dir_cur, exist_ok=True)
            out_dir_cur += '/'
            albedo, normal, mask, sh, face = data
            predicted_normal, predicted_albedo, predicted_sh, out_shading, out_recon = sfsnet_model(face)
            wandb_log_images(wandb, predicted_normal, mask, suffix + ' Predicted Normal', bix,
                             suffix + ' Predicted Normal', path=out_dir_cur + 'predicted_normal.png')
            wandb_log_images(wandb, predicted_albedo, mask, suffix + ' Predicted Albedo', bix,
                             suffix + ' Predicted Albedo', path=out_dir_cur + 'predicted_albedo.png')
            wandb_log_images(wandb, out_shading, mask, suffix + ' Predicted Shading', bix,
                             suffix + ' Predicted Shading', path=out_dir_cur + 'predicted_shading.png')
            wandb_log_images(wandb, out_recon, mask, suffix + ' Out recon', bix,
                             suffix + ' Out recon', path=out_dir_cur + 'out_recon.png')
            print("We finished the logging process at the {}th bix".format(bix))
Exemplo n.º 2
0
def test(model_dir=None, img_path=None, out_dir=None, lighting_pos=None):
    """Decompose a single image with a trained SkipNet and save the results.

    The image is resized so that its shorter side is 128 px and converted to a
    tensor. It is then reshaped to ``[1, 3, 128, 128]``, so the input is
    expected to be square. SkipNet runs on it, and the predicted normal,
    albedo, shading and reconstruction are logged with ``wandb_log_images``.
    They are also written as PNG files whose names are appended directly to
    ``out_dir``.

    Args:
        model_dir: Path of the SkipNet ``state_dict`` file.
        img_path: Path of the input image.
        out_dir: Output path prefix.
        lighting_pos: Unused; kept for interface compatibility.
    """
    base_model = SkipNet()
    base_model.load_state_dict(torch.load(model_dir))
    # Inference: use evaluation behaviour for BatchNorm/Dropout (if present).
    base_model.eval()
    transform_one = transforms.Compose([
        transforms.Resize(128),
    ])
    transform_two = transforms.Compose([transforms.ToTensor()])
    # Force three channels, because RGBA or greyscale files would break the
    # 3-channel reshape below. Close the file handle once the resized copy exists.
    with Image.open(img_path) as raw_img:
        img = transform_one(raw_img.convert('RGB'))
    img = np.array(img)
    img = transform_two(img)
    img = torch.reshape(img, [1, 3, 128, 128])
    suffix = 'test'
    with torch.no_grad():
        predicted_normal, predicted_albedo, predicted_sh, out_shading, out_recon = base_model(img)
    wandb.init(tensorboard=True)
    wandb_log_images(wandb, predicted_normal, None, suffix + "Predicted Normal",
                     0, suffix + "Predicted Normal", path=out_dir + "_predicted_normal_new.png")
    wandb_log_images(wandb, predicted_albedo, None, suffix + "Predicted Albedo",
                     0, suffix + "Predicted Albedo", path=out_dir + "_predicted albedo_new.png")
    wandb_log_images(wandb, out_shading, None, suffix + "Out Shading",
                     0, suffix + "Out Shading", path=out_dir + "_out shading_new.png")
    wandb_log_images(wandb, out_recon, None, suffix + "Out Recon",
                     0, suffix + "Out Recon", path=out_dir + "reconstruction_-0.5.png")
Exemplo n.º 3
0
def thirdStageTraining(syn_data,
                       celeb_data,
                       batch_size=16,
                       num_epochs=20,
                       log_path=None,
                       use_cuda=True,
                       lr=0.0025,
                       weight_decay=0.005,
                       checkpoint_path='/home/hd8t/xiangyu.yin/results/metadata/checkpoints/SfsNet_Syn.pkl'):
    """Train ``SfsNetPipeline`` on the synthetic dataset with full supervision.

    Each epoch trains on every batch of ``syn_data + 'train/'``. The loss is a
    weighted sum of normal/albedo (L1), SH lighting (MSE) and reconstruction
    (L1 against ``de_norm(face)``) terms. After each epoch, the predictions and
    ground truth of the last batch are logged with ``wandb_log_images`` and the
    weights are saved to ``checkpoint_path``.

    Args:
        syn_data: Root of the synthetic dataset. ``'train/'`` and ``'test/'``
            are appended directly, so it should end with a path separator.
        celeb_data: Unused in this stage; kept for interface compatibility.
        batch_size: Mini-batch size for all data loaders.
        num_epochs: Number of training epochs.
        log_path: Output prefix. ``checkpoints/`` and
            ``out_images/syn/{train,val,test}/`` are created beneath it.
        use_cuda: Train on ``cuda:0`` when True and CUDA is available;
            otherwise train on the CPU.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        checkpoint_path: File the model ``state_dict`` is written to after every
            epoch and reloaded from at the start of epochs 2..num_epochs.
    """
    train_dataset, val_dataset = get_sfsnet_dataset(syn_dir=syn_data + 'train/',
                                                    read_from_csv=None,
                                                    validation_split=10)
    test_dataset, _ = get_sfsnet_dataset(syn_dir=syn_data + 'test/',
                                         read_from_csv=None,
                                         validation_split=0)

    model_checkpoint_dir = log_path + 'checkpoints/'
    out_images_dir = log_path + 'out_images/'
    out_syn_images_dir = out_images_dir + 'syn/'
    # Derive the device from use_cuda, so the model and the batches always live
    # on the same device.
    device = torch.device('cuda:0' if use_cuda and torch.cuda.is_available() else 'cpu')

    os.makedirs(model_checkpoint_dir, exist_ok=True)
    for split in ('train/', 'val/', 'test/'):
        os.makedirs(out_syn_images_dir + split, exist_ok=True)

    # These criteria hold no parameters or buffers, so they need no device move.
    normal_loss = nn.L1Loss()
    albedo_loss = nn.L1Loss()
    sh_loss = nn.MSELoss()
    recon_loss = nn.L1Loss()

    lamda_recon = 0.5
    lamda_normal = 0.5
    lamda_albedo = 0.5
    lamda_sh = 0.1

    # The loaders do not depend on the epoch; a shuffled loader reshuffles
    # every time it is iterated.
    syn_train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    syn_val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    syn_test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    syn_train_len = len(syn_train_dl)
    print("Synthetic dataset: Train data:", syn_train_len,
          ' Val data: ', len(syn_val_dl), ' Test data: ', len(syn_test_dl))

    wandb.init(tensorboard=True)
    for epoch in range(1, num_epochs + 1):
        # NOTE: the model and the Adam optimizer are rebuilt every epoch, and the
        # weights are restored from the previous epoch's checkpoint. The Adam
        # moment estimates therefore restart each epoch.
        sfsnet_model = SfsNetPipeline()
        if epoch > 1:
            sfsnet_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        sfsnet_model.to(device)
        optimizer = torch.optim.Adam(sfsnet_model.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)

        tloss = nloss = aloss = shloss = rloss = 0.0
        # Keep the last batch's tensors around for image logging after the loop.
        predicted_normal = predicted_albedo = out_shading = out_recon = None
        mask = face = normal = albedo = None
        syn_count = 0

        for albedo, normal, mask, sh, face in syn_train_dl:
            syn_count += 1
            albedo = albedo.to(device)
            normal = normal.to(device)
            mask = mask.to(device)
            sh = sh.to(device)
            face = apply_mask(face.to(device), mask)

            predicted_normal, predicted_albedo, predicted_sh, out_shading, out_recon = sfsnet_model(face)

            current_normal_loss = normal_loss(predicted_normal, normal)
            current_albedo_loss = albedo_loss(predicted_albedo, albedo)
            current_sh_loss = sh_loss(predicted_sh, sh)
            current_recon_loss = recon_loss(out_recon, de_norm(face))

            total_loss = (lamda_sh * current_sh_loss
                          + lamda_normal * current_normal_loss
                          + lamda_albedo * current_albedo_loss
                          + lamda_recon * current_recon_loss)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            t_val = total_loss.item()
            n_val = current_normal_loss.item()
            a_val = current_albedo_loss.item()
            sh_val = current_sh_loss.item()
            r_val = current_recon_loss.item()
            tloss += t_val
            nloss += n_val
            aloss += a_val
            shloss += sh_val
            rloss += r_val

            print(
                "Epoch {}/{}, synthetic data {}/{}. Synthetic total loss: {}, normal_loss: {}, albedo_loss: {}, sh_loss: {}, recon_loss: {}"
                .format(epoch, num_epochs, syn_count, syn_train_len, t_val,
                        n_val, a_val, sh_val, r_val))

        if syn_count:
            print("Epoch {}/{} mean synthetic losses. total: {}, normal: {}, albedo: {}, sh: {}, recon: {}"
                  .format(epoch, num_epochs, tloss / syn_count, nloss / syn_count,
                          aloss / syn_count, shloss / syn_count, rloss / syn_count))

        # Log the last batch of the epoch (skipped if the loader was empty,
        # because there is nothing to draw).
        file_name = out_syn_images_dir + 'train/' + 'train_' + str(epoch)
        if predicted_normal is not None:
            image_logs = [
                (predicted_normal, 'Train Predicted Normal', '_predicted_normal.png', {}),
                (predicted_albedo, 'Train Predicted Albedo', '_predicted_albedo.png', {}),
                (out_shading, 'Train Predicted Shading', '_predicted_shading.png', {'denormalize': False}),
                (out_recon, 'Train Recon', '_predicted_face.png', {'denormalize': False}),
                (face, 'Train Ground Truth', '_gt_face.png', {}),
                (normal, 'Train Ground Truth Normal', '_gt_normal.png', {}),
                (albedo, 'Train Ground Truth Albedo', '_gt_albedo.png', {}),
            ]
            for tensor, title, file_suffix, extra_kwargs in image_logs:
                wandb_log_images(wandb, tensor, mask, title, epoch, title,
                                 path=file_name + file_suffix, **extra_kwargs)

        torch.save(sfsnet_model.state_dict(), checkpoint_path)
Exemplo n.º 4
0
from utils import wandb_log_images, random_split
import torch
import os
import glob
import wandb

#def predict_celeb(celeb_path=None, out_dir=None):
#     return True

if __name__ == "__main__":
    # Script entry: load the first-stage SkipNet and build a CelebA dataset
    # from every PNG in Celeb_path.
    Size_for_Image = 128
    suffix = 'celeba'
    check_dir = "/home/hd8t/xiangyu.yin/results/metadata/checkpoints/Skip_First.pkl"
    Celeb_path = "/home/hd8t/data/CelebA-HQ/original/"
    out_dir = "/home/hd8t/xiangyu.yin/results/metadata/out_images/celeba/"
    sfsnet_model = SkipNet()
    # Reuse check_dir instead of repeating the checkpoint path literal.
    sfsnet_model.load_state_dict(torch.load(check_dir))
    # glob returns files in arbitrary order; sort so the dataset order (and
    # any split derived from it) is reproducible.
    face = sorted(glob.glob(os.path.join(Celeb_path, "*.png")))
    # File stem without directory or final extension, e.g. ".../00012.png" -> "00012".
    name = [os.path.splitext(os.path.basename(path))[0] for path in face]
    datasize = len(face)
    # 2% of the images for validation, the rest for training.
    validation_count = datasize * 2 // 100
    train_count = datasize - validation_count
    transform = transforms.Compose([
        transforms.Resize(Size_for_Image),
        transforms.ToTensor()
    ])
    full_dataset = CelebDataset(face, name, transform)
Exemplo n.º 5
0
def FirstStage_Training(syn_path=None, model_dir=None, learning_rate=0.00125,
                        weight_decay=0.0001, nums_epoch=20, batch_size=32):
    """Train SkipNet on the synthetic dataset (first training stage).

    Each epoch trains on every batch of ``syn_path + 'train/'``. The loss is a
    weighted sum of normal/albedo (L1), SH lighting (MSE) and reconstruction
    (L1 against the input face) terms. The weights are saved to
    ``model_dir + "Skip_First.pkl"`` after every epoch and reloaded from there
    at the start of the next one.

    Args:
        syn_path: Root of the synthetic dataset. ``'train/'`` and ``'test/'``
            are appended directly, so it should end with a path separator.
        model_dir: Checkpoint prefix; ``"Skip_First.pkl"`` is appended directly.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        nums_epoch: Number of training epochs.
        batch_size: Mini-batch size for all data loaders.
    """
    # Always define use_cuda, so CPU-only machines fall back cleanly instead of
    # hitting an unbound name.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    checkpoint_path = model_dir + "Skip_First" + ".pkl"

    train_dataset, val_dataset = get_sfsnet_dataset(syn_dir=syn_path + 'train/',
                                                    validation_split=2, training_syn=True)
    test_dataset, _ = get_sfsnet_dataset(syn_dir=syn_path + 'test/', validation_split=0,
                                         training_syn=True)

    # These criteria hold no parameters or buffers, so they need no device move.
    normal_loss = nn.L1Loss()
    albedo_loss = nn.L1Loss()
    lighting_loss = nn.MSELoss()
    recon_loss = nn.L1Loss()

    lambda_normal = 0.5
    lambda_albedo = 0.5
    lambda_sh = 0.1
    lambda_recon = 0.5

    # The loaders do not depend on the epoch; a shuffled loader reshuffles
    # every time it is iterated.
    syn_train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    syn_val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
    syn_test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    len_syn_train = len(syn_train_dl)

    for epoch in range(nums_epoch):
        print('Synthetic dataset: Train data: ', len_syn_train, ' Val data: ', len(syn_val_dl), ' Test data: ',
              len(syn_test_dl))
        sum_total = sum_normal = sum_albedo = sum_sh = sum_recon = 0.0
        # NOTE: the model and the Adam optimizer are rebuilt every epoch, and the
        # weights are restored from the previous checkpoint. The Adam moment
        # estimates therefore restart each epoch.
        sfsnet_model = SkipNet()
        if epoch > 0:
            sfsnet_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        sfsnet_model.to(device)
        optimizer = torch.optim.Adam(sfsnet_model.parameters(), lr=learning_rate, weight_decay=weight_decay)

        num_batches = 0
        for bix, (albedo, normal, _mask, sh, face) in enumerate(syn_train_dl):
            # The mask is not used by any loss in this stage.
            albedo = albedo.to(device)
            normal = normal.to(device)
            sh = sh.to(device)
            face = face.to(device)

            predicted_normal, predicted_albedo, predicted_sh, produced_shading, produced_recon = sfsnet_model(face)
            current_normal_loss = normal_loss(predicted_normal, normal)
            current_albedo_loss = albedo_loss(predicted_albedo, albedo)
            current_sh_loss = lighting_loss(predicted_sh, sh)
            current_recon_loss = recon_loss(produced_recon, face)

            total_loss = lambda_normal * current_normal_loss + lambda_albedo * current_albedo_loss + \
                         lambda_sh * current_sh_loss + lambda_recon * current_recon_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            t_val = total_loss.item()
            n_val = current_normal_loss.item()
            a_val = current_albedo_loss.item()
            sh_val = current_sh_loss.item()
            r_val = current_recon_loss.item()
            sum_total += t_val
            sum_normal += n_val
            sum_albedo += a_val
            sum_sh += sh_val
            sum_recon += r_val
            num_batches += 1

            # Values are passed in the same order as their labels.
            print('Epoch: {} - Total Loss : {}, Normal Loss: {}, Albedo Loss: {}, SH Loss:{}, Recon Loss:{}'.format(
                epoch, t_val, n_val, a_val, sh_val, r_val))
            print('This is {} / {} of training dataline'.format(bix, (len_syn_train - 1)))

        if num_batches:
            print('Epoch: {} - Mean Total Loss : {}, Normal Loss: {}, Albedo Loss: {}, SH Loss:{}, Recon Loss:{}'.format(
                epoch, sum_total / num_batches, sum_normal / num_batches, sum_albedo / num_batches,
                sum_sh / num_batches, sum_recon / num_batches))

        torch.save(sfsnet_model.state_dict(), checkpoint_path)
def thirdStageTraining(syn_data,
                       celeb_data,
                       batch_size=8,
                       num_epochs=20,
                       log_path=None,
                       use_cuda=True,
                       lr=0.0025,
                       weight_decay=0.005,
                       teacher_checkpoint='/home/hd8t/xiangyu.yin/results/metadata/checkpoints/Skip_First.pkl',
                       checkpoint_path='/home/hd8t/xiangyu.yin/results/metadata/checkpoints/SfsNet_Celeb.pkl'):
    """Train ``SfsNetPipeline`` on CelebA using a frozen SkipNet as teacher.

    The SkipNet loaded from ``teacher_checkpoint`` provides the normal, albedo
    and SH targets for each CelebA batch. The reconstruction term compares
    against ``de_norm(c_face)``. After each epoch, the last batch is logged
    with ``wandb_log_images`` and the weights are saved to ``checkpoint_path``;
    from epoch 2 on they are reloaded from there first.

    Args:
        syn_data: Unused in this stage; kept for interface compatibility.
        celeb_data: CelebA directory passed to ``get_celeba_dataset``.
        batch_size: Mini-batch size for all data loaders.
        num_epochs: Number of training epochs.
        log_path: Output prefix. ``checkpoints/`` and
            ``out_images/celeb/{train,val,test}/`` are created beneath it.
        use_cuda: Train on ``cuda:0`` when True and CUDA is available;
            otherwise train on the CPU.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        teacher_checkpoint: SkipNet ``state_dict`` used as the frozen teacher.
        checkpoint_path: File the model ``state_dict`` is written to after every
            epoch and reloaded from at the start of epochs 2..num_epochs.
    """
    train_celeb_dataset, val_celeb_dataset = get_celeba_dataset(
        celeb_dir=celeb_data, validation_split=10)
    # NOTE(review): the "test" set is built from the same celeb_dir with no
    # split, so it presumably includes the training images; here only its
    # length is printed.
    test_celeb_dataset, _ = get_celeba_dataset(celeb_dir=celeb_data,
                                               validation_split=0)

    model_checkpoint_dir = log_path + 'checkpoints/'
    out_images_dir = log_path + 'out_images/'
    out_celeb_images_dir = out_images_dir + 'celeb/'
    # Derive the device from use_cuda, so the models and the batches always live
    # on the same device.
    device = torch.device('cuda:0' if use_cuda and torch.cuda.is_available() else 'cpu')

    prev_SkipNet_model = SkipNet()
    prev_SkipNet_model.load_state_dict(torch.load(teacher_checkpoint, map_location=device))
    prev_SkipNet_model.to(device)
    # The teacher is frozen. eval() keeps it from updating normalisation
    # statistics (if it has any), and its forward pass below runs without
    # autograd, so no graph or .grad is kept for it.
    prev_SkipNet_model.eval()

    os.makedirs(model_checkpoint_dir, exist_ok=True)
    for split in ('train/', 'val/', 'test/'):
        os.makedirs(out_celeb_images_dir + split, exist_ok=True)

    # These criteria hold no parameters or buffers, so they need no device move.
    c_recon_loss = nn.L1Loss()
    c_sh_loss = nn.MSELoss()
    c_albedo_loss = nn.L1Loss()
    c_normal_loss = nn.L1Loss()

    lamda_recon = 0.5
    lamda_normal = 0.5
    lamda_albedo = 0.5
    lamda_sh = 0.1

    # The loaders do not depend on the epoch; a shuffled loader reshuffles
    # every time it is iterated.
    celeb_train_dl = DataLoader(train_celeb_dataset, batch_size=batch_size, shuffle=True)
    celeb_val_dl = DataLoader(val_celeb_dataset, batch_size=batch_size, shuffle=False)
    celeb_test_dl = DataLoader(test_celeb_dataset, batch_size=batch_size, shuffle=True)
    celeb_train_len = len(celeb_train_dl)
    print("Celeb dataset: Train data:", celeb_train_len,
          ' Val data: ', len(celeb_val_dl), ' Test data: ', len(celeb_test_dl))

    wandb.init(tensorboard=True)
    for epoch in range(1, num_epochs + 1):
        # NOTE: the model and the Adam optimizer are rebuilt every epoch, and the
        # weights are restored from the previous epoch's checkpoint. The Adam
        # moment estimates therefore restart each epoch.
        sfsnet_model = SfsNetPipeline()
        if epoch > 1:
            sfsnet_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        sfsnet_model.to(device)
        optimizer = torch.optim.Adam(sfsnet_model.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)

        c_tloss = c_nloss = c_aloss = c_shloss = c_reconloss = 0.0
        # Keep the last batch's tensors around for image logging after the loop.
        c_predicted_normal = c_predicted_albedo = c_out_shading = c_out_recon = None
        c_face = None
        celeb_count = 0

        for c_face in celeb_train_dl:
            celeb_count += 1
            c_face = c_face.to(device)
            with torch.no_grad():
                prevc_normal, prevc_albedo, prevc_sh, _, _ = prev_SkipNet_model(c_face)
            c_predicted_normal, c_predicted_albedo, c_predicted_sh, c_out_shading, c_out_recon = sfsnet_model(
                c_face)

            c_current_normal_loss = c_normal_loss(c_predicted_normal, prevc_normal)
            c_current_albedo_loss = c_albedo_loss(c_predicted_albedo, prevc_albedo)
            c_current_sh_loss = c_sh_loss(c_predicted_sh, prevc_sh)
            c_current_recon_loss = c_recon_loss(c_out_recon, de_norm(c_face))

            c_total_loss = (lamda_sh * c_current_sh_loss
                            + lamda_normal * c_current_normal_loss
                            + lamda_albedo * c_current_albedo_loss
                            + lamda_recon * c_current_recon_loss)

            optimizer.zero_grad()
            c_total_loss.backward()
            optimizer.step()

            t_val = c_total_loss.item()
            n_val = c_current_normal_loss.item()
            a_val = c_current_albedo_loss.item()
            sh_val = c_current_sh_loss.item()
            r_val = c_current_recon_loss.item()
            c_tloss += t_val
            c_nloss += n_val
            c_aloss += a_val
            c_shloss += sh_val
            c_reconloss += r_val
            print(
                "Epoch {}/{}, celeb data {}/{}. Celeb total loss: {}, normal_loss: {}, albedo_loss: {}, sh_loss: {}, recon_loss: {}"
                .format(epoch, num_epochs, celeb_count, celeb_train_len, t_val,
                        n_val, a_val, sh_val, r_val))

        if celeb_count:
            print("Epoch {}/{} mean celeb losses. total: {}, normal: {}, albedo: {}, sh: {}, recon: {}"
                  .format(epoch, num_epochs, c_tloss / celeb_count, c_nloss / celeb_count,
                          c_aloss / celeb_count, c_shloss / celeb_count, c_reconloss / celeb_count))

        # Log the last CelebA batch of the epoch (skipped if the loader was empty).
        file_name = out_celeb_images_dir + 'train/' + 'train_' + str(epoch)
        if c_predicted_normal is not None:
            image_logs = [
                (c_predicted_normal, 'Train CelebA Predicted Normal', '_c_predicted_normal.png', {}),
                (c_predicted_albedo, 'Train CelebA Predicted Albedo', '_c_predicted_albedo.png', {}),
                (c_out_shading, 'Train CelebA Predicted Shading', '_c_predicted_shading.png', {'denormalize': False}),
                (c_out_recon, 'Train CelebA Recon', '_c_predicted_face.png', {'denormalize': False}),
                (c_face, 'Train CelebA Ground Truth', '_c_gt_face.png', {}),
            ]
            for tensor, title, file_suffix, extra_kwargs in image_logs:
                wandb_log_images(wandb, tensor, None, title, epoch, title,
                                 path=file_name + file_suffix, **extra_kwargs)

        torch.save(sfsnet_model.state_dict(), checkpoint_path)