Example #1
def main():

    args = get_args()

    config = read_yaml(args.config_path)

    nn.set_auto_forward(True)
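    # init_nnabla (assumed helper) sets up the CUDA extension context and returns
    # a communicator; rank 0 is used below to print log messages only once.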
    comm = init_nnabla(ext_name="cuda", device_id='0', type_config='float')

    if args.save_results_dir != '':
        config.log.save_results_dir = args.save_results_dir

    config.data.color_perturb = args.data_perturb in ('color', 'both')
    config.data.occ_perturb = args.data_perturb in ('occ', 'both')

    if comm is None or comm.rank == 0:
        if config.data.color_perturb:
            print('Applying color perturbation to the dataset')
        if config.data.occ_perturb:
            print('Applying occlusion perturbation to the dataset')
        if not config.data.color_perturb and not config.data.occ_perturb:
            print('No perturbation will be applied to the data')

    train_nerf(config, comm, args.model, args.dataset)
Example #2
    def parse(self, args=''):
        if args == '':
            opt = self.parser.parse_args()
        else:
            opt = self.parser.parse_args(args)

        # Load training config
        if opt.train_config is not None:
            opt.train_config = read_yaml(opt.train_config)
            cfg = opt.train_config
            opt.dataset = cfg.dataset.dataset
            opt.arch = cfg.model.arch
            opt.num_layers = cfg.model.num_layers
            opt.pretrained_model_dir = cfg.model.pretrained_model_dir
            opt.batch_size = cfg.train.batch_size
            opt.num_epochs = cfg.train.num_epochs
            lr_cfg = cfg.learning_rate_config
            if 'epochs' not in lr_cfg or lr_cfg.epochs is None:
                lr_cfg.epochs = cfg.train.num_epochs
            mp_cfg = cfg.mixed_precision
            opt.mixed_precision = mp_cfg.mixed_precision
            opt.channel_last = mp_cfg.channel_last
            opt.loss_scaling = mp_cfg.loss_scaling
            opt.use_dynamic_loss_scaling = mp_cfg.use_dynamic_loss_scaling

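        # Turn the comma-separated GPU string into a list of device indices
        # ([-1] when the first entry is negative).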
        opt.gpus_str = opt.gpus
        opt.gpus = [int(gpu) for gpu in opt.gpus.split(',')]
        opt.gpus = [i for i in range(len(opt.gpus))] if opt.gpus[0] >= 0 else [-1]
        opt.test_scales = [float(i) for i in opt.test_scales.split(',')]

        opt.fix_res = not opt.keep_res
        print(
            'Fix size testing.' if opt.fix_res else 'Keep resolution testing.')

        if opt.head_conv == -1:  # init default head_conv
            opt.head_conv = 256 if 'dlav0' in opt.arch else 64
        opt.down_ratio = 4
        opt.pad = 31
        opt.num_stacks = 1

        opt.exp_dir = os.path.join(opt.root_output_dir, "exp", opt.task)
        if opt.save_dir is None:
            opt.save_dir = os.path.join(opt.exp_dir, opt.exp_id)
        opt.debug_dir = os.path.join(opt.save_dir, 'debug')
        print('The output will be saved to ', opt.save_dir)
        os.makedirs(opt.save_dir, exist_ok=True)

        return opt
Example #3
        trainer.train(epoch)

        if epoch % config['val']['interval'] == 0 and val_loader is not None:
            trainer.validate(epoch)

        if comm.rank == 0:
            if epoch % config['train']['save_param_step_interval'] == 0 or epoch == config['train']['num_epochs']-1:
                trainer.save_checkpoint(
                    config['model']['saved_models_dir'], epoch, pixelcnn=args.pixelcnn_prior)


if __name__ == '__main__':

    parser = make_parser()
    args = parser.parse_args()
    config = read_yaml(os.path.join('configs', '{}.yaml'.format(args.data)))
    ctx = get_extension_context(
        config['extension_module'], device_id=config['device_id'])
    nn.set_auto_forward(True)

    if args.data == 'mnist':
        data_iterator = mnist_iterator
    elif args.data == 'imagenet':
        data_iterator = imagenet_iterator
    elif args.data == 'cifar10':
        data_iterator = cifar10_iterator
    else:
        print('Dataset not recognized')
        exit(1)

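    # Wrap the context in a communicator; comm.rank == 0 gates checkpoint saving
    # in the training loop above.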
    comm = CommunicatorWrapper(ctx)
Example #4
    parser.add_argument('--g-n-scales',
                        type=int,
                        default=1,
                        help='Number of generator resolution stacks')
    parser.add_argument("--d-n-scales",
                        type=int,
                        default=2,
                        help='Number of layers of discriminator pyramids')

    return parser


if __name__ == '__main__':

    parser = make_parser()
    config = read_yaml(os.path.join('configs', 'gender.yaml'))
    args = parser.parse_args()
    config.nnabla_context.device_id = args.device_id
    config.gender_faces.data_dir = args.data_root
    config.train.save_path = args.save_path
    config.train.batch_size = args.batch_size
    config.model.g_n_scales = args.g_n_scales
    config.model.d_n_scales = args.d_n_scales

    # nn.set_auto_forward(True)

    ctx = get_extension_context(config.nnabla_context.ext_name)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(ctx)

    image_shape = tuple(x * config.model.g_n_scales
Example #5
                        help='Only for style mixing: Batch size for style B')
    parser.add_argument('--use_tf_weights', action='store_true', default=False,
                        help='Use TF trained weights converted to NNabla')

    parser.add_argument('--img_path', type=str,
                        default='',
                        help='Image path for latent space projection')

    return parser


if __name__ == '__main__':

    parser = make_parser()
    args = parser.parse_args()
    config = read_yaml(os.path.join('configs', f'{args.data}.yaml'))
    ctx = get_extension_context(args.extension_module)
    nn.set_auto_forward(args.auto_forward or args.test)

    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(ctx)

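    # Create the Monitor only on rank 0 so that multi-GPU runs log just once.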
    monitor = None
    if comm is not None:
        if comm.rank == 0:
            monitor = Monitor(args.monitor_path)
            start_time = time.time()

    few_shot_config = None
    if args.few_shot is not None:
        few_shot_config = read_yaml(os.path.join(
Example #6
                        help='Only for style mixing: Batch size for style B')
    parser.add_argument('--use_tf_weights', action='store_true', default=False,
                        help='Use TF trained weights converted to NNabla')

    parser.add_argument('--img_path', type=str,
                        default='/FFHQ/images1024x1024/00000/00399.png',
                        help='Image path for latent space projection')

    return parser


if __name__ == '__main__':

    parser = make_parser()
    args = parser.parse_args()
    config = read_yaml(os.path.join('configs', f'{args.data}.yaml'))
    ctx = get_extension_context(args.extension_module)
    nn.set_auto_forward(args.auto_forward or args.test)

    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(ctx)

    monitor = None
    if comm is not None:
        if comm.rank == 0:
            monitor = Monitor(args.monitor_path)
            start_time = time.time()

    if args.train:
        style_gan = Train(monitor, config, args, comm)
    if args.test:
Example #7
def main(**kwargs):
    # set training args
    args = AttrDict(kwargs)

    assert os.path.exists(
        args.config
    ), f"{args.config} is not found. Please make sure the config file exists."
    conf = read_yaml(args.config)

    comm = init_nnabla(ext_name="cudnn",
                       device_id=args.device_id,
                       type_config="float",
                       random_pseed=True)
    if args.sampling_interval is None:
        args.sampling_interval = 1

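    # Keep every sampling_interval-th diffusion step (always including the final
    # step) to shorten the reverse process.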
    use_timesteps = list(
        range(0, conf.num_diffusion_timesteps, args.sampling_interval))
    if use_timesteps[-1] != conf.num_diffusion_timesteps - 1:
        # The last step should be included always.
        use_timesteps.append(conf.num_diffusion_timesteps - 1)

    # setup model variance type
    model_var_type = ModelVarType.FIXED_SMALL
    if "model_var_type" in conf:
        model_var_type = ModelVarType.get_vartype_from_key(conf.model_var_type)

    model = Model(beta_strategy=conf.beta_strategy,
                  use_timesteps=use_timesteps,
                  model_var_type=model_var_type,
                  num_diffusion_timesteps=conf.num_diffusion_timesteps,
                  attention_num_heads=conf.num_attention_heads,
                  attention_resolutions=conf.attention_resolutions,
                  scale_shift_norm=conf.ssn,
                  base_channels=conf.base_channels,
                  channel_mult=conf.channel_mult,
                  num_res_blocks=conf.num_res_blocks)

    # load parameters
    assert os.path.exists(
        args.h5
    ), f"{args.h5} is not found. Please make sure the h5 file exists."
    nn.parameter.load_parameters(args.h5)

    # Generate
    # sampling
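    # Each process draws B samples per iteration; iterate until args.samples
    # images have been generated across all ranks.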
    B = args.batch_size
    num_samples_per_iter = B * comm.n_procs
    num_iter = (args.samples + num_samples_per_iter -
                1) // num_samples_per_iter

    local_saved_cnt = 0
    for i in range(num_iter):
        logger.info(f"Generate samples {i + 1} / {num_iter}.")
        sample_out, _, x_starts = model.sample(shape=(B, ) +
                                               conf.image_shape[1:],
                                               dump_interval=1,
                                               use_ema=args.ema,
                                               progress=comm.rank == 0,
                                               use_ddim=args.ddim)

        # scale back to [0, 255]
        sample_out = (sample_out + 1) * 127.5

        if args.tiled:
            save_path = os.path.join(args.output_dir,
                                     f"gen_{local_saved_cnt}_{comm.rank}.png")
            save_tiled_image(sample_out.astype(np.uint8), save_path)
            local_saved_cnt += 1
        else:
            for b in range(B):
                save_path = os.path.join(
                    args.output_dir, f"gen_{local_saved_cnt}_{comm.rank}.png")
                imsave(save_path,
                       sample_out[b].astype(np.uint8),
                       channel_first=True)
                local_saved_cnt += 1

        # create video for x_starts
        if args.save_xstart:
            clips = []
            # Use a separate index so the outer sampling-iteration counter `i` is not shadowed.
            for step in range(len(x_starts)):
                xstart = x_starts[step][1]
                assert isinstance(xstart, np.ndarray)
                im = get_tiled_image(np.clip((xstart + 1) * 127.5, 0, 255),
                                     channel_last=False).astype(np.uint8)
                clips.append(im)

            clip = mp.ImageSequenceClip(clips, fps=5)
            clip.write_videofile(
                os.path.join(
                    args.output_dir,
                    f"pred_x0_along_time_{local_saved_cnt}_{comm.rank}.mp4"))
Example #8
def main():

    parser = argparse.ArgumentParser()

    parser.add_argument('--output-filename',
                        '-o',
                        type=str,
                        default='video.gif',
                        help="name of an output file.")
    parser.add_argument('--output-static-filename',
                        '-os',
                        type=str,
                        default='video_static.gif',
                        help="name of an output file.")
    parser.add_argument('--config-path',
                        '-c',
                        type=str,
                        default='configs/llff.yaml',
                        required=True,
                        help='model and training configuration file')
    parser.add_argument('--weight-path',
                        '-w',
                        type=str,
                        default='configs/llff.yaml',
                        required=True,
                        help='path to pretrained NeRF parameters')
    parser.add_argument(
        '--model',
        type=str,
        choices=['wild', 'uncertainty', 'appearance', 'vanilla'],
        required=True,
        help='Select the model to train')

    parser.add_argument('--visualization-type',
                        '-v',
                        type=str,
                        choices=['zoom', '360-rotation', 'default'],
                        default='default',
                        help='type of visualization')

    parser.add_argument(
        '--downscale',
        '-d',
        default=1,
        type=float,
        help="downsampling factor for the rendered images for faster inference"
    )

    parser.add_argument(
        '--num-images',
        '-n',
        default=120,
        type=int,
        help="Number of images to generate for the output video/gif")

    parser.add_argument("--fast",
                        help="Use Fast NeRF architecture",
                        action="store_true")

    args = parser.parse_args()

    use_transient = False
    use_embedding = False

    if args.model == 'wild':
        use_transient = True
        use_embedding = True
    elif args.model == 'uncertainty':
        use_transient = True
    elif args.model == 'appearance':
        use_embedding = True

    config = read_yaml(args.config_path)

    config.data.downscale = args.downscale

    nn.set_auto_forward(True)
    ctx = get_extension_context('cuda')
    nn.set_default_context(ctx)
    nn.load_parameters(args.weight_path)

    _, _, render_poses, hwf, _, _, near_plane, far_plane = get_data(config)
    height, width, focal_length = hwf
    print(
        f'Rendering with Height {height}, Width {width}, Focal Length: {focal_length}'
    )

    # mapping_net = MLP
    encode_position_function = get_encoding_function(
        config.train.num_encodings_position, True, True)
    if config.train.use_view_directions:
        encode_direction_function = get_encoding_function(
            config.train.num_encodings_direction, True, True)
    else:
        encode_direction_function = None

    frames = []
    if use_transient:
        static_frames = []

    if args.visualization_type == '360-rotation':
        print('The 360 degree rotation result will not work with LLFF data!')
        pbar = tqdm(np.linspace(0, 360, args.num_images, endpoint=False))
    elif args.visualization_type == 'zoom':
        pbar = tqdm(
            np.linspace(near_plane, far_plane, args.num_images,
                        endpoint=False))
    else:
        args.num_images = min(args.num_images, render_poses.shape[0])
        pbar = tqdm(
            np.arange(0, render_poses.shape[0],
                      render_poses.shape[0] // args.num_images))

    print(f'Rendering {args.num_images} poses...')

    for th in pbar:

        if args.visualization_type == '360-rotation':
            pose = nn.NdArray.from_numpy_array(pose_spherical(th, -30., 4.))
        elif args.visualization_type == 'zoom':
            pose = nn.NdArray.from_numpy_array(trans_t(th))
        else:
            pose = nn.NdArray.from_numpy_array(render_poses[th][:3, :4])
            # pose = nn.NdArray.from_numpy_array(render_poses[0][:3, :4])

        ray_directions, ray_origins = get_ray_bundle(height, width,
                                                     focal_length, pose)

        ray_directions = F.reshape(ray_directions, (-1, 3))
        ray_origins = F.reshape(ray_origins, (-1, 3))

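        # Render in fixed-size ray batches to keep memory usage bounded.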
        num_ray_batches = ray_directions.shape[0] // config.train.ray_batch_size + 1

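        # Optional appearance / transient embeddings (used by the 'wild',
        # 'appearance' and 'uncertainty' models).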
        app_emb, trans_emb = None, None
        if use_embedding:
            with nn.parameter_scope('embedding_a'):
                embed_inp = nn.NdArray.from_numpy_array(
                    np.full((config.train.chunksize_fine, ), 1, dtype=int))
                app_emb = PF.embed(embed_inp, config.train.n_vocab,
                                   config.train.n_app)

        if use_transient:
            with nn.parameter_scope('embedding_t'):
                embed_inp = nn.NdArray.from_numpy_array(
                    np.full((config.train.chunksize_fine, ), th, dtype=int))
                trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                     config.train.n_trans)

            static_rgb_map_fine_list, transient_rgb_map_fine_list = [], []

        rgb_map_fine_list = []

        for i in trange(num_ray_batches):
            start = i * config.train.ray_batch_size
            end = (i + 1) * config.train.ray_batch_size
            if i != num_ray_batches - 1:
                ray_d = ray_directions[start:end]
                ray_o = ray_origins[start:end]
            else:
                ray_d = ray_directions[start:, :]
                ray_o = ray_origins[start:, :]

            if use_transient:
                _, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, _, _, _ = forward_pass(
                    ray_d,
                    ray_o,
                    near_plane,
                    far_plane,
                    app_emb,
                    trans_emb,
                    encode_position_function,
                    encode_direction_function,
                    config,
                    use_transient,
                    hwf=hwf,
                    fast=args.fast)

                static_rgb_map_fine_list.append(static_rgb_map_fine)
                transient_rgb_map_fine_list.append(transient_rgb_map_fine)

            else:
                _, _, _, _, rgb_map_fine, _, _, _ = \
                    forward_pass(ray_d, ray_o, near_plane, far_plane, app_emb, trans_emb, encode_position_function,
                                 encode_direction_function, config, use_transient, hwf=hwf, fast=args.fast)
            rgb_map_fine_list.append(rgb_map_fine)

        rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
        rgb_map_fine = F.reshape(rgb_map_fine, (height, width, 3))

        if use_transient:
            static_rgb_map_fine = F.concatenate(*static_rgb_map_fine_list,
                                                axis=0)
            static_rgb_map_fine = F.reshape(static_rgb_map_fine,
                                            (height, width, 3))

        frames.append(
            (255 * np.clip(rgb_map_fine.data, 0, 1)).astype(np.uint8))
        if use_transient:
            static_frames.append(
                (255 * np.clip(static_rgb_map_fine.data, 0, 1)).astype(
                    np.uint8))

    imageio.mimwrite(args.output_filename, frames, fps=30)
    if use_transient:
        imageio.mimwrite(args.output_static_filename, static_frames, fps=30)
Example #9
def train():
    bs_train, bs_valid = args.train_batch_size, args.val_batch_size
    extension_module = args.context
    ctx = get_extension_context(
        extension_module, device_id=args.device_id, type_config=args.type_config
    )
    nn.set_default_context(ctx)

    if args.input:
        train_loader, val_loader, n_train_samples, n_val_samples = load_data(
            bs_train, bs_valid
        )

    else:
        train_data_source = data_source_cifar10(
            train=True, shuffle=True, label_shuffle=True
        )
        val_data_source = data_source_cifar10(train=False, shuffle=False)
        n_train_samples = len(train_data_source.labels)
        n_val_samples = len(val_data_source.labels)
        # Data Iterator
        train_loader = data_iterator(
            train_data_source, bs_train, None, False, False)
        val_loader = data_iterator(
            val_data_source, bs_valid, None, False, False)

        if args.shuffle_label:
            if not os.path.exists(args.output):
                os.makedirs(args.output)
            np.save(os.path.join(args.output, "x_train.npy"),
                    train_data_source.images)
            np.save(
                os.path.join(args.output, "y_shuffle_train.npy"),
                train_data_source.labels,
            )
            np.save(os.path.join(args.output, "y_train.npy"),
                    train_data_source.raw_label)
            np.save(os.path.join(args.output, "x_val.npy"),
                    val_data_source.images)
            np.save(os.path.join(args.output, "y_val.npy"),
                    val_data_source.labels)

    if args.model == "resnet23":
        model_prediction = resnet23_prediction
    elif args.model == "resnet56":
        model_prediction = resnet56_prediction
    prediction = functools.partial(
        model_prediction, ncls=10, nmaps=64, act=F.relu, seed=args.seed)

    # Create training graphs
    test = False
    image_train = nn.Variable((bs_train, 3, 32, 32))
    label_train = nn.Variable((bs_train, 1))
    pred_train, _ = prediction(image_train, test)

    loss_train = loss_function(pred_train, label_train)

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((bs_valid, 1))
    pred_valid, _ = prediction(image_valid, test)
    loss_val = loss_function(pred_valid, label_valid)

    for param in nn.get_parameters().values():
        param.grad.zero()

    cfg = read_yaml("./learning_rate.yaml")
    print(cfg)
    lr_sched = create_learning_rate_scheduler(cfg.learning_rate_config)
    solver = S.Momentum(momentum=0.9, lr=lr_sched.get_lr())
    solver.set_parameters(nn.get_parameters())
    start_point = 0

    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed

    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=1)
    monitor_err = MonitorSeries("Training error", monitor, interval=1)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=1)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)
    monitor_vloss = MonitorSeries("Test loss", monitor, interval=1)

    # save_nnp
    contents = save_nnp({"x": image_valid}, {"y": pred_valid}, bs_valid)
    save.save(
        os.path.join(args.model_save_path,
                     (args.model+"_epoch0_result.nnp")), contents
    )

    train_iter = math.ceil(n_train_samples / bs_train)
    val_iter = math.ceil(n_val_samples / bs_valid)

    # Training-loop
    for i in range(start_point, args.train_epochs):
        lr_sched.set_epoch(i)
        solver.set_learning_rate(lr_sched.get_lr())
        print("Learning Rate: ", lr_sched.get_lr())
        # Validation
        ve = 0.0
        vloss = 0.0
        print("## Validation")
        for j in range(val_iter):
            image, label = val_loader.next()
            image_valid.d = image
            label_valid.d = label
            loss_val.forward()
            vloss += loss_val.data.data.copy() * bs_valid
            ve += categorical_error(pred_valid.d, label)
        ve /= val_iter
        vloss /= n_val_samples

        monitor_verr.add(i, ve)
        monitor_vloss.add(i, vloss)

        if int(i % args.model_save_interval) == 0:
            # save checkpoint file
            save_checkpoint(args.model_save_path, i, solver)

        # Forward/Zerograd/Backward
        print("## Training")
        e = 0.0
        loss = 0.0
        for k in range(train_iter):

            image, label = train_loader.next()
            image_train.d = image
            label_train.d = label
            loss_train.forward()
            solver.zero_grad()
            loss_train.backward()
            solver.update()
            e += categorical_error(pred_train.d, label_train.d)
            loss += loss_train.data.data.copy() * bs_train
        e /= train_iter
        loss /= n_train_samples

        monitor_loss.add(i, loss)
        monitor_err.add(i, e)
        monitor_time.add(i)

    nn.save_parameters(
        os.path.join(args.model_save_path, "params_%06d.h5" %
                     (args.train_epochs))
    )

    # save_nnp_lastepoch
    contents = save_nnp({"x": image_valid}, {"y": pred_valid}, bs_valid)
    save.save(os.path.join(args.model_save_path,
              (args.model+"_result.nnp")), contents)
Example #10
common_utils_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), '..', '..', 'utils'))
sys.path.append(common_utils_path)
from neu.yaml_wrapper import read_yaml

# --- utils ---
from EWC.utils import make_parser

# --- EWC ---
from EWC.EWC_loss import ElasticWeightConsolidation

if __name__ == "__main__":
    # [parameters]
    parser = make_parser()
    args = parser.parse_args()
    config = read_yaml(args.config)
    g_code_name = sys.argv[0]
    # [hyper parameters]
    g_gen_path = os.path.join(args.pre_trained_model,
                              'ffhq-slim-gen-256-config-e.h5')
    g_disc_path = os.path.join(args.pre_trained_model,
                               'ffhq-slim-disc-256-config-e-corrected.h5')
    g_apply_function_type_list = ['Convolution']
    # [calc core]
    ctx = get_extension_context(args.extension_module,
                                device_id=args.device_id)
    nn.set_default_context(ctx)
    # [load network]
    print('[{}] Load networks'.format(g_code_name))
    with nn.parameter_scope('Generator'):
        nn.load_parameters(g_gen_path)
Example #11
def main():

    parser = argparse.ArgumentParser()

    parser.add_argument('--output-filename',
                        '-o',
                        type=str,
                        default='video.gif',
                        help="name of an output file.")
    parser.add_argument('--output-static-filename',
                        '-os',
                        type=str,
                        default='video_static.gif',
                        help="name of an output file.")
    parser.add_argument('--config-path',
                        '-c',
                        type=str,
                        default='configs/llff.yaml',
                        required=True,
                        help='model and training configuration file')
    parser.add_argument('--weight-path',
                        '-w',
                        type=str,
                        default='configs/llff.yaml',
                        required=True,
                        help='path to pretrained NeRF parameters')
    parser.add_argument(
        '--model',
        type=str,
        choices=['wild', 'uncertainty', 'appearance', 'vanilla'],
        required=True,
        help='Select the model to train')

    parser.add_argument('--visualization-type',
                        '-v',
                        type=str,
                        choices=['zoom', '360-rotation', 'default'],
                        default='default-render-poses',
                        help='type of visualization')

    parser.add_argument(
        '--downscale',
        '-d',
        default=1,
        type=float,
        help="downsampling factor for the rendered images for faster inference"
    )

    parser.add_argument(
        '--num-images',
        '-n',
        default=120,
        type=int,
        help="Number of images to generate for the output video/gif")

    args = parser.parse_args()

    nn.set_auto_forward(True)
    comm = init_nnabla(ext_name="cudnn")

    use_transient = False
    use_embedding = False

    if args.model == 'wild':
        use_transient = True
        use_embedding = True
    elif args.model == 'uncertainty':
        use_transient = True
    elif args.model == 'appearance':
        use_embedding = True

    config = read_yaml(args.config_path)

    config.data.downscale = args.downscale
    nn.load_parameters(args.weight_path)

    data_source = get_photo_tourism_dataiterator(config, 'test', comm)

    # Pose, Appearance index for generating novel views
    # as well as camera trajectory is hard-coded here.
    data_source.test_appearance_idx = 125
    pose_idx = 125
    dx = np.linspace(-0.2, 0.15, args.num_images // 3)
    dy = -0.15
    dz = np.linspace(0.1, 0.22, args.num_images // 3)

    embed_idx_list = list(data_source.poses_dict.keys())

    data_source.poses_test = np.tile(data_source.poses_dict[pose_idx],
                                     (args.num_images, 1, 1))
    for i in range(0, args.num_images // 3):
        data_source.poses_test[i, 0, 3] += dx[i]
        data_source.poses_test[i, 1, 3] += dy
    for i in range(args.num_images // 3, args.num_images // 2):
        data_source.poses_test[i, 0, 3] += dx[len(dx) - 1 - i]
        data_source.poses_test[i, 1, 3] += dy

    for i in range(args.num_images // 2, 5 * args.num_images // 6):
        data_source.poses_test[i, 2, 3] += dz[i - args.num_images // 2]
        data_source.poses_test[i, 1, 3] += dy
        data_source.poses_test[i, 0, 3] += dx[len(dx) // 2]

    for i in range(5 * args.num_images // 6, args.num_images):
        data_source.poses_test[i, 2, 3] += dz[args.num_images - 1 - i]
        data_source.poses_test[i, 1, 3] += dy
        data_source.poses_test[i, 0, 3] += dx[len(dx) // 2]

    # mapping_net = MLP
    encode_position_function = get_encoding_function(
        config.train.num_encodings_position, True, True)
    if config.train.use_view_directions:
        encode_direction_function = get_encoding_function(
            config.train.num_encodings_direction, True, True)
    else:
        encode_direction_function = None

    frames = []
    if use_transient:
        static_frames = []

    pbar = tqdm(np.arange(0, data_source.poses_test.shape[0]))
    data_source._size = data_source.poses_test.shape[0]
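    # Hard-coded 400x400 test resolution and a pinhole camera with a 60-degree
    # horizontal field of view (focal = w / (2 * tan(30 deg))).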
    data_source.test_img_w = 400
    data_source.test_img_h = 400
    data_source.test_focal = data_source.test_img_w / 2 / np.tan(np.pi / 6)
    data_source.test_K = np.array(
        [[data_source.test_focal, 0, data_source.test_img_w / 2],
         [0, data_source.test_focal, data_source.test_img_h / 2], [0, 0, 1]])

    data_source._indexes = np.arange(0, data_source._size)

    di = data_iterator(data_source, batch_size=1)

    print(f'Rendering {args.num_images} poses...')

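    # Two appearance-embedding indices and per-frame blend weights; the loop
    # below interpolates between them to vary appearance across the video.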
    a = [1, 128]
    alpha = np.linspace(0, 1, args.num_images)

    for th in pbar:

        rays, embed_inp = di.next()
        ray_origins = nn.NdArray.from_numpy_array(rays[0, :, :3])
        ray_directions = nn.NdArray.from_numpy_array(rays[0, :, 3:6])
        near_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 6])
        far_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 7])

        embed_inp = nn.NdArray.from_numpy_array(
            embed_inp[0, :config.train.chunksize_fine])
        image_shape = (data_source.test_img_w, data_source.test_img_h, 3)

        ray_directions = F.reshape(ray_directions, (-1, 3))
        ray_origins = F.reshape(ray_origins, (-1, 3))

        num_ray_batches = (ray_directions.shape[0] +
                           config.train.ray_batch_size -
                           1) // config.train.ray_batch_size

        app_emb, trans_emb = None, None
        if use_embedding:
            with nn.parameter_scope('embedding_a'):
                embed_inp_app = nn.NdArray.from_numpy_array(
                    np.full((config.train.chunksize_fine, ), a[0], dtype=int))
                app_emb = PF.embed(embed_inp_app, config.train.n_vocab,
                                   config.train.n_app)

                embed_inp_app = nn.NdArray.from_numpy_array(
                    np.full((config.train.chunksize_fine, ), a[1], dtype=int))
                app_emb_2 = PF.embed(embed_inp_app, config.train.n_vocab,
                                     config.train.n_app)

                app_emb = app_emb * alpha[th] + app_emb_2 * (1 - alpha[th])

        if use_transient:
            with nn.parameter_scope('embedding_t'):
                trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                     config.train.n_trans)

            static_rgb_map_fine_list, transient_rgb_map_fine_list = [], []

        rgb_map_fine_list = []

        for i in trange(num_ray_batches):
            start = i * config.train.ray_batch_size
            end = (i + 1) * config.train.ray_batch_size
            ray_d = ray_directions[start:end]
            ray_o = ray_origins[start:end]

            near_plane = near_plane_[start:end]
            far_plane = far_plane_[start:end]

            if use_transient:
                _, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, _, _, _ = forward_pass(
                    ray_d, ray_o, near_plane, far_plane, app_emb, trans_emb,
                    encode_position_function, encode_direction_function,
                    config, use_transient)

                static_rgb_map_fine_list.append(static_rgb_map_fine)
                transient_rgb_map_fine_list.append(transient_rgb_map_fine)

            else:
                _, _, _, _, rgb_map_fine, _, _, _ = \
                    forward_pass(ray_d, ray_o, near_plane, far_plane, app_emb, trans_emb,
                                 encode_position_function, encode_direction_function, config, use_transient)

            rgb_map_fine_list.append(rgb_map_fine)

        rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
        rgb_map_fine = F.reshape(rgb_map_fine, image_shape)

        if use_transient:
            static_rgb_map_fine = F.concatenate(*static_rgb_map_fine_list,
                                                axis=0)
            static_rgb_map_fine = F.reshape(static_rgb_map_fine, image_shape)

        frames.append(
            (255 * np.clip(rgb_map_fine.data, 0, 1)).astype(np.uint8))
        if use_transient:
            static_frames.append(
                (255 * np.clip(static_rgb_map_fine.data, 0, 1)).astype(
                    np.uint8))

    imageio.mimwrite(args.output_filename, frames, fps=30)
    if use_transient:
        imageio.mimwrite(args.output_static_filename, static_frames, fps=30)