# Standard and third-party imports used by the examples below; the
# project-specific helpers (SkipNet, SfsNetPipeline, wandb_log_images,
# get_sfsnet_dataset, get_celeba_dataset, apply_mask, de_norm,
# get_normal_in_range) are assumed to be provided by the surrounding repository.
import os

import numpy as np
import torch
import torch.nn as nn
import wandb
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms


def Validate_Skip(model_dir=None, syn_dir=None, out_dir=None):
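    """Run a trained SkipNet checkpoint over the synthetic validation split and
    log every 5th batch of predictions (normal, albedo, shading, reconstruction)
    to Weights & Biases and to ``out_dir``."""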
    train_dataset, val_dataset = get_sfsnet_dataset(syn_dir=syn_dir + 'train/',
                                                    validation_split=2, training_syn=True)
    val_dl = DataLoader(val_dataset, batch_size=32, shuffle=True)
    validation_len = len(val_dl)
    sfsnet_model = SkipNet()
    sfsnet_model.load_state_dict(torch.load(model_dir))
    sfsnet_model.eval()  # validation only: keep batch-norm/dropout in eval mode
    suffix = 'Val'

    wandb.init(tensorboard=True)
    for bix, data in enumerate(val_dl):
        print("This is the {}th bix".format(bix))
        if bix % 5 == 0:
            out_dir_cur = out_dir + str(bix)
            os.makedirs(out_dir_cur, exist_ok=True)
            out_dir_cur += '/'
            albedo, normal, mask, sh, face = data
            with torch.no_grad():
                predicted_normal, predicted_albedo, predicted_sh, out_shading, out_recon = sfsnet_model(face)
            wandb_log_images(wandb, predicted_normal, mask, suffix + ' Predicted Normal', bix,
                             suffix + ' Predicted Normal', path= out_dir_cur + 'predicted_normal.png')
            wandb_log_images(wandb, predicted_albedo, mask, suffix + ' Predicted Albedo', bix,
                             suffix + ' Predicted Albedo', path= out_dir_cur + 'predicted_albedo.png')
            wandb_log_images(wandb, out_shading, mask, suffix + ' Predicted Shading', bix,
                             suffix + ' Predicted Shading', path= out_dir_cur + 'predicted_shading.png')
            wandb_log_images(wandb, out_recon, mask, suffix + ' Out recon', bix,
                             suffix + ' Out recon', path= out_dir_cur + 'out_recon.png')
            print("We finished the logging process at the {}th bix".format(bix))
# Example 2
def test(model_dir=None, img_path=None, out_dir=None, lighting_pos=None):
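    """Run a single image through a trained SkipNet checkpoint and log the
    predicted normal, albedo, shading, and reconstruction to Weights & Biases."""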
    base_model = SkipNet()
    base_model.load_state_dict(torch.load(model_dir))
    base_model.eval()  # inference only
    transform_one = transforms.Compose([
        transforms.Resize((128, 128)),  # force a square 128x128 input
    ])
    transform_two = transforms.Compose([transforms.ToTensor()])
    img = transform_one(Image.open(img_path).convert('RGB'))
    img = transform_two(np.array(img))
    img = img.unsqueeze(0)  # add batch dimension -> [1, 3, 128, 128]
    suffix = 'test'
    with torch.no_grad():
        predicted_normal, predicted_albedo, predicted_sh, out_shading, out_recon = base_model(img)
    wandb.init(tensorboard=True)
    wandb_log_images(wandb, predicted_normal, None, suffix + " Predicted Normal",
                     0, suffix + " Predicted Normal", path=out_dir + "_predicted_normal_new.png")
    wandb_log_images(wandb, predicted_albedo, None, suffix + " Predicted Albedo",
                     0, suffix + " Predicted Albedo", path=out_dir + "_predicted_albedo_new.png")
    wandb_log_images(wandb, out_shading, None, suffix + " Out Shading",
                     0, suffix + " Out Shading", path=out_dir + "_out_shading_new.png")
    wandb_log_images(wandb, out_recon, None, suffix + " Out Recon",
                     0, suffix + " Out Recon", path=out_dir + "reconstruction_-0.5.png")
# Example 3
def thirdStageTraining(syn_data,
                       celeb_data,
                       batch_size=16,
                       num_epochs=20,
                       log_path=None,
                       use_cuda=True,
                       lr=0.0025,
                       weight_decay=0.005):
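    """Train the SfsNetPipeline on the synthetic SfSNet data: each epoch reloads
    the latest checkpoint, optimizes L1 losses on normal/albedo/reconstruction
    plus an MSE loss on the spherical-harmonics coefficients, logs the last
    batch of predictions to Weights & Biases, and saves the updated checkpoint."""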

    train_dataset, val_dataset = get_sfsnet_dataset(syn_dir=syn_data +
                                                    'train/',
                                                    read_from_csv=None,
                                                    validation_split=10)
    test_dataset, _ = get_sfsnet_dataset(syn_dir=syn_data + 'test/',
                                         read_from_csv=None,
                                         validation_split=0)

    model_checkpoint_dir = log_path + 'checkpoints/'
    out_images_dir = log_path + 'out_images/'
    out_syn_images_dir = out_images_dir + 'syn/'
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    prev_SkipNet_model = SkipNet()
    prev_SkipNet_model.load_state_dict(
        torch.load(
            '/home/hd8t/xiangyu.yin/results/metadata/checkpoints/Skip_First.pkl'
        ))
    prev_SkipNet_model.to(device)

    os.system("mkdir -p {}".format(model_checkpoint_dir))
    os.system("mkdir -p {}".format(out_syn_images_dir + 'train/'))
    os.system("mkdir -p {}".format(out_syn_images_dir + 'val/'))
    os.system("mkdir -p {}".format(out_syn_images_dir + 'test/'))

    normal_loss = nn.L1Loss()
    albedo_loss = nn.L1Loss()
    sh_loss = nn.MSELoss()
    recon_loss = nn.L1Loss()
    c_recon_loss = nn.L1Loss()
    c_sh_loss = nn.MSELoss()
    c_albedo_loss = nn.L1Loss()
    c_normal_loss = nn.L1Loss()

    if use_cuda:
        normal_loss = normal_loss.cuda()
        albedo_loss = albedo_loss.cuda()
        sh_loss = sh_loss.cuda()
        recon_loss = recon_loss.cuda()
        c_recon_loss = c_recon_loss.cuda()
        c_sh_loss = c_sh_loss.cuda()
        c_albedo_loss = c_albedo_loss.cuda()
        c_normal_loss = c_normal_loss.cuda()

    lamda_recon = 0.5
    lamda_normal = 0.5
    lamda_albedo = 0.5
    lamda_sh = 0.1

    wandb.init(tensorboard=True)
    for epoch in range(1, num_epochs + 1):
        tloss = 0
        nloss = 0
        aloss = 0
        shloss = 0
        rloss = 0
        c_tloss = 0
        c_nloss = 0
        c_aloss = 0
        c_shloss = 0
        c_reconloss = 0
        predicted_normal = None
        predicted_albedo = None
        out_shading = None
        out_recon = None
        mask = None
        face = None
        normal = None
        albedo = None
        c_predicted_normal = None
        c_predicted_albedo = None
        c_out_shading = None
        c_out_recon = None
        c_face = None

        sfsnet_model = SfsNetPipeline()
        if epoch > 1:
            sfsnet_model.load_state_dict(
                torch.load(
                    '/home/hd8t/xiangyu.yin/results/metadata/checkpoints/SfsNet_Syn.pkl'
                ))
        model_parameters = sfsnet_model.parameters()
        optimizer = torch.optim.Adam(model_parameters,
                                     lr=lr,
                                     weight_decay=weight_decay)
        sfsnet_model.to(device)

        syn_train_dl = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True)
        syn_val_dl = DataLoader(val_dataset,
                                batch_size=batch_size,
                                shuffle=False)
        syn_test_dl = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 shuffle=True)

        syn_train_len = len(syn_train_dl)

        if epoch == 1:
            print("Synthetic dataset: Train data:", len(syn_train_dl),
                  ' Val data: ', len(syn_val_dl), ' Test data: ',
                  len(syn_test_dl))
        # Initialize the training iterator
        syn_train_iter = iter(syn_train_dl)
        syn_count = 0
        celeb_count = 0
        # Until all synthetic training data has been processed
        while True:
            #Get and train on synthetic data
            data = next(syn_train_iter, None)
            if data is not None:
                syn_count += 1
                albedo, normal, mask, sh, face = data
                if use_cuda:
                    albedo = albedo.cuda()
                    normal = normal.cuda()
                    mask = mask.cuda()
                    sh = sh.cuda()
                    face = face.cuda()

                face = apply_mask(face, mask)

                predicted_normal, predicted_albedo, predicted_sh, out_shading, out_recon = sfsnet_model(
                    face)

                current_normal_loss = normal_loss(predicted_normal, normal)
                current_albedo_loss = albedo_loss(predicted_albedo, albedo)
                current_sh_loss = sh_loss(predicted_sh, sh)
                current_recon_loss = recon_loss(out_recon, de_norm(face))

                total_loss = lamda_sh * current_sh_loss + lamda_normal * current_normal_loss + \
                             lamda_albedo * current_albedo_loss + lamda_recon * current_recon_loss

                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                tloss += total_loss.item()
                nloss += current_normal_loss.item()
                aloss += current_albedo_loss.item()
                shloss += current_sh_loss.item()
                rloss += current_recon_loss.item()

                print(
                    "Epoch {}/{}, synthetic data {}/{}. Synthetic total loss: {}, normal_loss: {}, albedo_loss: {}, sh_loss: {}, recon_loss: {}"
                    .format(epoch, num_epochs, syn_count, syn_train_len,
                            total_loss, current_normal_loss,
                            current_albedo_loss, current_sh_loss,
                            current_recon_loss))
            else:
                break

        file_name = out_syn_images_dir + 'train/' + 'train_' + str(epoch)
        wandb_log_images(wandb,
                         predicted_normal,
                         mask,
                         'Train Predicted Normal',
                         epoch,
                         'Train Predicted Normal',
                         path=file_name + '_predicted_normal.png')
        wandb_log_images(wandb,
                         predicted_albedo,
                         mask,
                         'Train Predicted Albedo',
                         epoch,
                         'Train Predicted Albedo',
                         path=file_name + '_predicted_albedo.png')
        wandb_log_images(wandb,
                         out_shading,
                         mask,
                         'Train Predicted Shading',
                         epoch,
                         'Train Predicted Shading',
                         path=file_name + '_predicted_shading.png',
                         denormalize=False)
        wandb_log_images(wandb,
                         out_recon,
                         mask,
                         'Train Recon',
                         epoch,
                         'Train Recon',
                         path=file_name + '_predicted_face.png',
                         denormalize=False)
        wandb_log_images(wandb,
                         face,
                         mask,
                         'Train Ground Truth',
                         epoch,
                         'Train Ground Truth',
                         path=file_name + '_gt_face.png')
        wandb_log_images(wandb,
                         normal,
                         mask,
                         'Train Ground Truth Normal',
                         epoch,
                         'Train Ground Truth Normal',
                         path=file_name + '_gt_normal.png')
        wandb_log_images(wandb,
                         albedo,
                         mask,
                         'Train Ground Truth Albedo',
                         epoch,
                         'Train Ground Truth Albedo',
                         path=file_name + '_gt_albedo.png')

        torch.save(
            sfsnet_model.state_dict(),
            "/home/hd8t/xiangyu.yin/results/metadata/checkpoints/SfsNet_Syn.pkl"
        )
# Example 4
def predict_celeba(sfsnet_model,
                   dl,
                   use_cuda=False,
                   out_folder=None,
                   wandb=None,
                   suffix='celeba'):
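    """Decompose CelebA faces with a trained SfSNet model, log every 10th batch
    of predictions to Weights & Biases, and accumulate the reconstruction loss
    over those logged batches."""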
    tloss = 0
    recon_loss = nn.L1Loss()
    if use_cuda:
        recon_loss = recon_loss.cuda()

    for bix, data in enumerate(dl):
        face = data
        if use_cuda:
            face = face.cuda()
        predicted_normal, predicted_albedo, predicted_sh, predicted_shading, predicted_face = sfsnet_model(
            face)
        print("we have computed the No.{} face decomposition.".format(bix))
        if bix % 10 == 0:
            os.makedirs(out_folder + str(bix), exist_ok=True)
            file_name = out_folder + str(bix) + '/'
            predicted_normal = get_normal_in_range(predicted_normal)
            wandb_log_images(wandb,
                             predicted_normal,
                             None,
                             suffix + ' Predicted Normal',
                             bix,
                             suffix + ' Predicted Normal',
                             path=file_name + '_predicted_normal.png')
            wandb_log_images(wandb,
                             predicted_albedo,
                             None,
                             suffix + ' Predicted Albedo',
                             bix,
                             suffix + ' Predicted Albedo',
                             path=file_name + '_predicted_albedo.png')
            wandb_log_images(wandb,
                             predicted_shading,
                             None,
                             suffix + ' Predicted Shading',
                             bix,
                             suffix + ' Predicted Shading',
                             path=file_name + '_predicted_shading.png',
                             denormalize=False)
            wandb_log_images(wandb,
                             predicted_face,
                             None,
                             suffix + ' Predicted Face',
                             bix,
                             suffix + ' Predicted Face',
                             path=file_name + '_predicted_face.png',
                             denormalize=False)
            wandb_log_images(wandb,
                             face,
                             None,
                             suffix + ' Ground Truth',
                             bix,
                             suffix + ' Ground Truth',
                             path=file_name + '_gt_face.png')

            total_loss = recon_loss(predicted_face, face)
            tloss += total_loss.item()

    print("tloss is equal to", tloss)
# Example 5
def predict_sfsnet(sfs_net_model,
                   dl,
                   train_epoch_num=0,
                   use_cuda=False,
                   out_folder=None,
                   wandb=None,
                   suffix='Val'):
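    """Evaluate an SfSNet model on a synthetic dataloader: log predictions,
    ground truth, and faces re-rendered with the ground-truth spherical
    harmonics for every batch, then return the average total, normal, albedo,
    SH, and reconstruction losses."""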
    # debugging flag to dump image

    fix_bix_dump = 0
    normal_loss = nn.L1Loss()
    albedo_loss = nn.L1Loss()
    sh_loss = nn.MSELoss()
    recon_loss = nn.L1Loss()

    lamda_recon = 0.5
    lamda_albedo = 0.5
    lamda_normal = 0.5
    lamda_sh = 0.1

    if use_cuda:
        normal_loss = normal_loss.cuda()
        albedo_loss = albedo_loss.cuda()
        sh_loss = sh_loss.cuda()
        recon_loss = recon_loss.cuda()

    tloss = 0  # total loss
    nloss = 0  # normal loss
    aloss = 0  # albedo loss
    shloss = 0  # SH loss
    rloss = 0  # Reconstruction loss

    for bix, data in enumerate(dl):
        albedo, normal, mask, sh, face = data
        if use_cuda:
            albedo = albedo.cuda()
            normal = normal.cuda()
            mask = mask.cuda()
            sh = sh.cuda()
            face = face.cuda()

        predicted_normal, predicted_albedo, predicted_sh, predicted_shading, predicted_face = sfs_net_model(
            face)

        # save predictions in log folder
        os.makedirs(out_folder + str(bix), exist_ok=True)
        file_name = out_folder + str(bix) + "/"
        # log_images
        predicted_normal = get_normal_in_range(predicted_normal)
        gt_normal = get_normal_in_range(normal)

        wandb_log_images(wandb,
                         predicted_normal,
                         mask,
                         suffix + ' Predicted Normal',
                         train_epoch_num,
                         suffix + ' Predicted Normal',
                         path=file_name + 'predicted_normal.png')
        wandb_log_images(wandb,
                         predicted_albedo,
                         mask,
                         suffix + ' Predicted Albedo',
                         train_epoch_num,
                         suffix + ' Predicted Albedo',
                         path=file_name + 'predicted_albedo.png')
        wandb_log_images(wandb,
                         predicted_shading,
                         mask,
                         suffix + ' Predicted Shading',
                         train_epoch_num,
                         suffix + ' Predicted Shading',
                         path=file_name + 'predicted_shading.png',
                         denormalize=False)
        wandb_log_images(wandb,
                         predicted_face,
                         mask,
                         suffix + ' Predicted face',
                         train_epoch_num,
                         suffix + ' Predicted face',
                         path=file_name + 'predicted_face.png',
                         denormalize=False)
        wandb_log_images(wandb,
                         face,
                         mask,
                         suffix + ' Ground Truth',
                         train_epoch_num,
                         suffix + ' Ground Truth',
                         path=file_name + '_gt_face.png')
        wandb_log_images(wandb,
                         gt_normal,
                         mask,
                         suffix + ' Ground Truth Normal',
                         train_epoch_num,
                         suffix + ' Ground Truth Normal',
                         path=file_name + 'gt_normal.png')
        wandb_log_images(wandb,
                         albedo,
                         mask,
                         suffix + ' Ground Truth Albedo',
                         train_epoch_num,
                         suffix + ' Ground Truth Albedo',
                         path=file_name + 'gt_albedo.png')
        # Get face with real SH
        real_sh_face = sfs_net_model.get_face(sh, predicted_normal,
                                              predicted_albedo)
        wandb_log_images(wandb,
                         real_sh_face,
                         mask,
                         'Val Real SH Predicted Face',
                         train_epoch_num,
                         'Val Real SH Predicted Face',
                         path=file_name + 'real_sh_face.png')
        syn_face = sfs_net_model.get_face(sh, normal, albedo)
        wandb_log_images(wandb,
                         syn_face,
                         mask,
                         'Val Real SH GT Face',
                         train_epoch_num,
                         'Val Real SH GT Face',
                         path=file_name + '_syn_gt_face.png')

        # TODO
        # Dump SH as CSV or TXT file

        # Loss computation
        # Normal loss
        current_normal_loss = normal_loss(predicted_normal, normal)
        # Albedo loss
        current_albedo_loss = albedo_loss(predicted_albedo, albedo)
        # SH loss
        current_sh_loss = sh_loss(predicted_sh, sh)
        # Reconstruction loss
        current_recon_loss = recon_loss(predicted_face, face)

        total_loss = lamda_recon * current_recon_loss + lamda_normal * current_normal_loss \
                     + lamda_albedo * current_albedo_loss + lamda_sh * current_sh_loss

        # Logging for display and debugging purposes
        tloss += total_loss.item()
        nloss += current_normal_loss.item()
        aloss += current_albedo_loss.item()
        shloss += current_sh_loss.item()
        rloss += current_recon_loss.item()

    len_dl = len(dl)
    # wandb.log(
    #    {suffix + ' Total loss': tloss / len_dl, 'Val Albedo loss': aloss / len_dl, 'Val Normal loss': nloss / len_dl, \
    #     'Val SH loss': shloss / len_dl, 'Val Recon loss': rloss / len_dl}, step=train_epoch_num)

    # return average loss over dataset
    return tloss / len_dl, nloss / len_dl, aloss / len_dl, shloss / len_dl, rloss / len_dl
def thirdStageTraining(syn_data,
                       celeb_data,
                       batch_size=8,
                       num_epochs=20,
                       log_path=None,
                       use_cuda=True,
                       lr=0.0025,
                       weight_decay=0.005):
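    """Fine-tune the SfsNetPipeline on CelebA with pseudo-supervision: a fixed
    first-stage SkipNet provides target normals, albedos, and SH coefficients,
    while an L1 reconstruction loss ties the output back to the input face.
    The updated checkpoint is saved after every epoch."""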

    train_celeb_dataset, val_celeb_dataset = get_celeba_dataset(
        celeb_dir=celeb_data, validation_split=10)
    test_celeb_dataset, _ = get_celeba_dataset(celeb_dir=celeb_data,
                                               validation_split=0)

    model_checkpoint_dir = log_path + 'checkpoints/'
    out_images_dir = log_path + 'out_images/'
    out_celeb_images_dir = out_images_dir + 'celeb/'
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    prev_SkipNet_model = SkipNet()
    prev_SkipNet_model.load_state_dict(
        torch.load(
            '/home/hd8t/xiangyu.yin/results/metadata/checkpoints/Skip_First.pkl'
        ))
    prev_SkipNet_model.to(device)
    prev_SkipNet_model.eval()  # fixed teacher network used only to generate targets

    os.system("mkdir -p {}".format(model_checkpoint_dir))
    os.system("mkdir -p {}".format(out_celeb_images_dir + 'train/'))
    os.system("mkdir -p {}".format(out_celeb_images_dir + 'val/'))
    os.system("mkdir -p {}".format(out_celeb_images_dir + 'test/'))

    normal_loss = nn.L1Loss()
    albedo_loss = nn.L1Loss()
    sh_loss = nn.MSELoss()
    recon_loss = nn.L1Loss()
    c_recon_loss = nn.L1Loss()
    c_sh_loss = nn.MSELoss()
    c_albedo_loss = nn.L1Loss()
    c_normal_loss = nn.L1Loss()

    if use_cuda:
        normal_loss = normal_loss.cuda()
        albedo_loss = albedo_loss.cuda()
        sh_loss = sh_loss.cuda()
        recon_loss = recon_loss.cuda()
        c_recon_loss = c_recon_loss.cuda()
        c_sh_loss = c_sh_loss.cuda()
        c_albedo_loss = c_albedo_loss.cuda()
        c_normal_loss = c_normal_loss.cuda()

    lamda_recon = 0.5
    lamda_normal = 0.5
    lamda_albedo = 0.5
    lamda_sh = 0.1

    wandb.init(tensorboard=True)
    for epoch in range(1, num_epochs + 1):
        tloss = 0
        nloss = 0
        aloss = 0
        shloss = 0
        rloss = 0
        c_tloss = 0
        c_nloss = 0
        c_aloss = 0
        c_shloss = 0
        c_reconloss = 0
        predicted_normal = None
        predicted_albedo = None
        out_shading = None
        out_recon = None
        mask = None
        face = None
        normal = None
        albedo = None
        c_predicted_normal = None
        c_predicted_albedo = None
        c_out_shading = None
        c_out_recon = None
        c_face = None

        sfsnet_model = SfsNetPipeline()
        if epoch > 1:
            sfsnet_model.load_state_dict(
                torch.load(
                    '/home/hd8t/xiangyu.yin/results/metadata/checkpoints/SfsNet_Celeb.pkl'
                ))
        model_parameters = sfsnet_model.parameters()
        optimizer = torch.optim.Adam(model_parameters,
                                     lr=lr,
                                     weight_decay=weight_decay)
        sfsnet_model.to(device)

        celeb_train_dl = DataLoader(train_celeb_dataset,
                                    batch_size=batch_size,
                                    shuffle=True)
        celeb_val_dl = DataLoader(val_celeb_dataset,
                                  batch_size=batch_size,
                                  shuffle=False)
        celeb_test_dl = DataLoader(test_celeb_dataset,
                                   batch_size=batch_size,
                                   shuffle=True)

        celeb_train_len = len(celeb_train_dl)

        if epoch == 1:
            print("Celeb dataset: Train data:", len(celeb_train_dl),
                  ' Val data: ', len(celeb_val_dl), ' Test data: ',
                  len(celeb_test_dl))
        celeb_train_iter = iter(celeb_train_dl)
        celeb_count = 0
        # Until all CelebA training data has been processed
        while True:
            c_data = next(celeb_train_iter, None)
            if c_data is not None:
                celeb_count += 1
                c_mask = None
                if use_cuda:
                    c_data = c_data.cuda()

                c_face = c_data
                # Pseudo targets from the fixed first-stage SkipNet; gradients
                # through the teacher network are not needed.
                with torch.no_grad():
                    prevc_normal, prevc_albedo, prevc_sh, prevc_shading, prevc_recon = prev_SkipNet_model(
                        c_face)
                c_predicted_normal, c_predicted_albedo, c_predicted_sh, c_out_shading, c_out_recon = sfsnet_model(
                    c_face)

                c_current_normal_loss = c_normal_loss(c_predicted_normal,
                                                      prevc_normal)
                c_current_albedo_loss = c_albedo_loss(c_predicted_albedo,
                                                      prevc_albedo)
                c_current_sh_loss = c_sh_loss(c_predicted_sh, prevc_sh)
                c_current_recon_loss = c_recon_loss(c_out_recon,
                                                    de_norm(c_face))

                c_total_loss = lamda_sh * c_current_sh_loss + lamda_normal * c_current_normal_loss + lamda_albedo * c_current_albedo_loss +\
                              lamda_recon * c_current_recon_loss

                optimizer.zero_grad()
                c_total_loss.backward()
                optimizer.step()

                c_tloss += c_total_loss.item()
                c_nloss += c_current_normal_loss.item()
                c_aloss += c_current_albedo_loss.item()
                c_shloss += c_current_sh_loss.item()
                c_reconloss += c_current_recon_loss.item()
                print(
                    "Epoch {}/{}, celeb data {}/{}. Celeb total loss: {}, normal_loss: {}, albedo_loss: {}, sh_loss: {}, recon_loss: {}"
                    .format(epoch, num_epochs, celeb_count, celeb_train_len,
                            c_total_loss, c_current_normal_loss,
                            c_current_albedo_loss, c_current_sh_loss,
                            c_current_recon_loss))
            else:
                break

        # Log the last CelebA training batch of this epoch
        file_name = out_celeb_images_dir + 'train/' + 'train_' + str(epoch)
        wandb_log_images(wandb,
                         c_predicted_normal,
                         None,
                         'Train CelebA Predicted Normal',
                         epoch,
                         'Train CelebA Predicted Normal',
                         path=file_name + '_c_predicted_normal.png')
        wandb_log_images(wandb,
                         c_predicted_albedo,
                         None,
                         'Train CelebA Predicted Albedo',
                         epoch,
                         'Train CelebA Predicted Albedo',
                         path=file_name + '_c_predicted_albedo.png')
        wandb_log_images(wandb,
                         c_out_shading,
                         None,
                         'Train CelebA Predicted Shading',
                         epoch,
                         'Train CelebA Predicted Shading',
                         path=file_name + '_c_predicted_shading.png',
                         denormalize=False)
        wandb_log_images(wandb,
                         c_out_recon,
                         None,
                         'Train CelebA Recon',
                         epoch,
                         'Train CelebA Recon',
                         path=file_name + '_c_predicted_face.png',
                         denormalize=False)
        wandb_log_images(wandb,
                         c_face,
                         None,
                         'Train CelebA Ground Truth',
                         epoch,
                         'Train CelebA Ground Truth',
                         path=file_name + '_c_gt_face.png')
        torch.save(
            sfsnet_model.state_dict(),
            "/home/hd8t/xiangyu.yin/results/metadata/checkpoints/SfsNet_Celeb.pkl"
        )
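# A minimal usage sketch (not part of the original examples); the checkpoint,
# dataset, and output paths below are placeholders and must be adapted to the
# local setup before running.
if __name__ == '__main__':
    Validate_Skip(model_dir='/path/to/checkpoints/Skip_First.pkl',
                  syn_dir='/path/to/sfsnet_data/',
                  out_dir='/path/to/validation_out/')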