# Example 1
    # NOTE(review): the enclosing loop header is not visible in this chunk —
    # presumably iterating sample index `i`; confirm against the full file.
    # import ipdb; ipdb.set_trace()
    # Store this sample's FLAME parameter vector as float32 in the batch array.
    flm_params[i, :] = flame_param.astype('float32')
    if i == num_smpl_to_eval_on - 1:
        break  # stop once the requested number of samples has been collected

batch_size = 64
# FLAME mesh decoder taken from the DECA overlay visualizer; eval() puts it
# in inference mode (no dropout / batch-norm updates).
flame_decoder = overlay_visualizer.deca.flame.eval()

# Evaluate one trained generator per run id: restore its checkpoint, then
# generate images batch by batch.
for run_idx in run_ids_1:
    # import ipdb; ipdb.set_trace()
    # Generator configured from this run's settings; DataParallel spreads the
    # forward pass across all visible GPUs.
    generator_1 = torch.nn.DataParallel(
        StyledGenerator(embedding_vocab_size=69158,
                        rendered_flame_ascondition=settings_for_runs[run_idx]
                        ['rendered_flame_as_condition'],
                        normal_maps_as_cond=settings_for_runs[run_idx]
                        ['normal_maps_as_cond'],
                        core_tensor_res=core_tensor_res,
                        w_truncation_factor=1.0,
                        apply_sqrt2_fac_in_eq_lin=settings_for_runs[run_idx]
                        ['apply_sqrt2_fac_in_eq_lin'],
                        n_mlp=8)).cuda()
    model_idx = settings_for_runs[run_idx]['model_idx']
    # Restore the running-average generator weights for this run/checkpoint.
    ckpt1 = torch.load(
        f'{cnst.output_root}checkpoint/{run_idx}/{model_idx}.model')
    generator_1.load_state_dict(ckpt1['generator_running'])
    generator_1 = generator_1.eval()

    # images = np.zeros((num_smpl_to_eval_on, 3, resolution, resolution)).astype('float32')
    # Progress bar over sample indices, stepping by batch_size.
    pbar = tqdm.tqdm(range(0, num_smpl_to_eval_on, batch_size))
    pbar.set_description('Generating_images')
    flame_mesh_imgs = None
    mdl_id = 'mdl2_'  # prefix for this model's outputs — TODO confirm usage
    # NOTE(review): the loop body continues beyond this chunk.
# Example 2
# Override the default resolution if the saved parameter archive provides one
# (other_params is presumably a numpy NpzFile — `.files` lists its keys; TODO confirm).
if 'resolution' in other_params.files:
    resolution = other_params['resolution']

rendered_flame_as_condition = True
rows = 5   # grid rows in the sample sheet
cols = 6   # grid columns in the sample sheet
b_size = rows*cols  # batch size = one full grid of samples
n_frames = 32  # number of frames generated per sequence — TODO confirm
# Progressive-growing step index for the target resolution.
step_max = int(np.log2(resolution) - 2)  # starts from 4X4 hence the -2

# Fixed seed so sampled latents are reproducible across runs.
torch.manual_seed(7)
generator = StyledGenerator(flame_dim=code_size,
                            embedding_vocab_size=69158,
                            rendered_flame_ascondition=rendered_flame_as_condition,
                            inst_norm=use_inst_norm,
                            normal_maps_as_cond=normal_maps_as_cond,
                            core_tensor_res=core_tensor_res,
                            use_styled_conv_stylegan2=True,
                            n_mlp=8).cuda()

# embeddings = generator.get_embddings()
# Wrap in DataParallel *before* loading so the checkpoint's 'module.'-prefixed
# keys match — NOTE(review): implies the checkpoint was saved from a
# DataParallel-wrapped model; confirm.
generator = torch.nn.DataParallel(generator)
generator.load_state_dict(ckpt['generator_running'])
# generator.load_state_dict(ckpt['generator'])
# generator.eval()

log_dir = os.path.join(cnst.output_root, 'gif_smpls/FFHQ')

# Optional fixed random background image, values clamped to the valid [0, 255] range.
if random_background:
    torch.manual_seed(2)
    back_ground_noise = (torch.randn((3, 224, 224), dtype=torch.float32)*255).clamp(min=0, max=255).cuda()
# Example 3
# Path to saved normalization parameters for the FFHQ dynamic FLAME fits.
normalization_file_path = f'{cnst.output_root}FFHQ_dynamicfit_normalization_hawen_parms.npz'

# Override the default resolution if the parameter archive provides one.
if 'resolution' in other_params.files:
    resolution = other_params['resolution']

rendered_flame_as_condition = True
rows = 5   # grid rows in the sample sheet
cols = 6   # grid columns in the sample sheet
b_size = rows * cols  # batch size = one full grid of samples
n_frames = 32  # number of frames generated per sequence — TODO confirm
# Progressive-growing step index for the target resolution.
step_max = int(np.log2(resolution) - 2)  # starts from 4X4 hence the -2

# Fixed seed so sampled latents are reproducible across runs.
torch.manual_seed(7)
generator = StyledGenerator(embedding_vocab_size=69158,
                            rendered_flame_ascondition=True,
                            normal_maps_as_cond=True,
                            core_tensor_res=4,
                            n_mlp=8,
                            apply_sqrt2_fac_in_eq_lin=True).cuda()

# embeddings = generator.get_embddings()
# Wrap in DataParallel *before* loading so the checkpoint's 'module.'-prefixed
# keys match the wrapped model's state_dict.
generator = torch.nn.DataParallel(generator)
generator.load_state_dict(ckpt['generator_running'])
# generator.load_state_dict(ckpt['generator'])
# generator.eval()

log_dir = os.path.join(cnst.output_root, 'gif_smpls/FFHQ')

# Optional fixed random background image.
# NOTE(review): the back_ground_noise statement below is truncated in this chunk.
if random_background:
    torch.manual_seed(2)
    back_ground_noise = (torch.randn(
        (3, 224, 224), dtype=torch.float32) * 255).clamp(min=0,
rendered_flame_as_condition = True
rows = 5   # grid rows in the sample sheet
cols = 6   # grid columns in the sample sheet
b_size = rows * cols  # batch size = one full grid of samples
n_frames = 32  # number of frames generated per sequence — TODO confirm
# Progressive-growing step index for the target resolution.
step_max = int(np.log2(resolution) - 2)  # starts from 4X4 hence the -2

torch.manual_seed(7)  # fixed seed for reproducible latent sampling

# Build the run configuration (plus dataset and FLAME parameter estimator)
# from command-line arguments.
parser = argparse.ArgumentParser(description='Progressive Growing of GANs')
args, dataset, flame_param_est = update_config(parser)

generator = StyledGenerator(
    embedding_vocab_size=args.embedding_vocab_size,
    rendered_flame_ascondition=args.rendered_flame_as_condition,
    normal_maps_as_cond=args.normal_maps_as_cond,
    core_tensor_res=args.core_tensor_res,
    n_mlp=args.nmlp_for_z_to_w).cuda()

# embeddings = generator.get_embddings()
# Wrap in DataParallel *before* loading so the checkpoint's 'module.'-prefixed
# keys match the wrapped model's state_dict.
generator = torch.nn.DataParallel(generator)
generator.load_state_dict(ckpt['generator_running'])
# generator.load_state_dict(ckpt['generator'])
# generator.eval()

log_dir = os.path.join(cnst.output_root, 'gif_smpls/FFHQ')

# Optional fixed random background image.
# NOTE(review): the back_ground_noise statement below is truncated in this chunk.
if random_background:
    torch.manual_seed(2)
    back_ground_noise = (torch.randn(
        (3, 224, 224), dtype=torch.float32) * 255).clamp(min=0,
# Example 5
import argparse
from configurations import update_config
from my_utils.photometric_optimization.models import FLAME
from my_utils.photometric_optimization import util

###################################### Voca training Seq ######################################################
# Render a VOCA (speech-driven FLAME) sequence with a trained generator.
ignore_global_rotation = False  # when True, zero out the global head rotation
resolution = 256
run_idx = 29  # id of the training run whose checkpoint is loaded below

# Build the run configuration (plus dataset and FLAME parameter estimator)
# from command-line arguments.
parser = argparse.ArgumentParser(description='Progressive Growing of GANs')
args, dataset, flame_param_est = update_config(parser)

generator_1 = torch.nn.DataParallel(StyledGenerator(embedding_vocab_size=args.embedding_vocab_size,
                                                    rendered_flame_ascondition=args.rendered_flame_as_condition,
                                                    normal_maps_as_cond=args.normal_maps_as_cond,
                                                    core_tensor_res=args.core_tensor_res,
                                                    n_mlp=args.nmlp_for_z_to_w)).cuda()
model_idx = '026000_1'  # checkpoint iteration tag
ckpt1 = torch.load(f'{cnst.output_root}checkpoint/{run_idx}/{model_idx}.model')
# Restore the running-average generator weights and switch to inference mode.
generator_1.load_state_dict(ckpt1['generator_running'])
generator_1 = generator_1.eval()

# Pre-fitted FLAME parameters for the VOCA sequence.
seqs = np.load(cnst.voca_flame_seq_file)

# Pose vector = [global rotation (cols 0:3) | cols 6:9]; columns 3:6 are
# skipped — NOTE(review): presumably jaw/neck split per FLAME convention,
# confirm against the fitting code.
if ignore_global_rotation:
    pose = np.hstack((seqs['frame_pose_params'][:, 0:3]*0, seqs['frame_pose_params'][:, 6:9]))
else:
    pose = np.hstack((seqs['frame_pose_params'][:, 0:3], seqs['frame_pose_params'][:, 6:9]))

# One expression-parameter row per frame.
num_frames = seqs['frame_exp_params'].shape[0]
# Example 6
    # NOTE(review): this fragment is the interior of a function/`main` whose
    # definition is outside this chunk.
    # FID computation is skipped in debug mode (it needs the real-image
    # statistics over the FFHQ database).
    if not args.debug:
        fid_computer = compute_fid.FidComputer(
            database_root_dir=cnst.ffhq_images_root_dir,
            true_img_stats_dir=cnst.true_img_stats_dir)
    else:
        fid_computer = None

    # Per-run output directories for checkpoints and generated samples.
    os.makedirs(f'{args.chk_pt_dir}/checkpoint/{str(args.run_id)}',
                exist_ok=True)
    os.makedirs(f'{args.chk_pt_dir}/sample/{str(args.run_id)}', exist_ok=True)

    generator = StyledGenerator(
        embedding_vocab_size=args.embedding_vocab_size,
        rendered_flame_ascondition=args.rendered_flame_as_condition,
        normal_maps_as_cond=args.normal_maps_as_cond,
        core_tensor_res=args.core_tensor_res,
        n_mlp=args.nmlp_for_z_to_w,
        apply_sqrt2_fac_in_eq_lin=args.apply_sqrt_in_eq_linear)

    # from my_utils.print_model_summary import summary
    # summary(generator, (3, 2, 2), device='cpu')

    # Draw the generator's computation graph using a dummy conditioning input.
    # The channel count is 3 per enabled condition map (normal map and/or
    # rendered FLAME), as computed from the two boolean flags below.
    from my_utils.graph_writer import graph_writer
    graph_writer.draw(
        generator, f'Style_gan_mdl_run_id{args.run_id}.png', (16, 38),
        torch.zeros((1,
                     int(3 * (int(args.normal_maps_as_cond) +
                              int(args.rendered_flame_as_condition))), 2, 2)))

    generator = nn.DataParallel(generator).cuda()