def test_transformer(config, netG, train_iterators, monitor, param_file):
    """Run inference with the A->B boundary-map transformer and visualize results.

    Args:
        config: dict-like config with "norm_type", "num_test" and
            "test" -> {"vis_interval", "batch_size"} entries.
        netG: dict of generator constructors; only 'netG_A2B' is used here.
        train_iterators: (source iterator, target iterator) pair yielding
            boundary maps as the first tuple element.
        monitor: nnabla Monitor used as the output root.
        param_file: path to the pretrained netG_A2B parameter file (.h5).
    """
    netG_A2B = netG['netG_A2B']

    train_iterator_src, train_iterator_trg = train_iterators

    # Load boundary image to get Variable shapes
    bod_map_A = train_iterator_src.next()[0]
    bod_map_B = train_iterator_trg.next()[0]
    real_bod_map_A = nn.Variable(bod_map_A.shape)
    real_bod_map_B = nn.Variable(bod_map_B.shape)
    real_bod_map_A.persistent, real_bod_map_B.persistent = True, True

    ################### Graph Construction ####################
    # Generator (the parameter scopes must match training time so that
    # load_parameters below binds the pretrained weights to this graph)
    with nn.parameter_scope('netG_transformer'):
        with nn.parameter_scope('netG_A2B'):
            fake_bod_map_B = netG_A2B(
                real_bod_map_A, test=True,
                norm_type=config["norm_type"])  # (1, 15, 64, 64)
    fake_bod_map_B.persistent = True

    # load parameters of networks
    with nn.parameter_scope('netG_transformer'):
        with nn.parameter_scope('netG_A2B'):
            nn.load_parameters(param_file)

    monitor_vis = nm.MonitorImage('result',
                                  monitor,
                                  interval=config["test"]["vis_interval"],
                                  num_images=1,
                                  normalize_method=lambda x: x)

    # Test
    i = 0
    iter_per_epoch = train_iterator_src.size // config["test"]["batch_size"] + 1

    # a falsy config["num_test"] means "use the whole source dataset"
    if config["num_test"]:
        num_test = config["num_test"]
    else:
        num_test = train_iterator_src.size

    for _ in range(iter_per_epoch):
        bod_map_A = train_iterator_src.next()[0]
        bod_map_B = train_iterator_trg.next()[0]
        real_bod_map_A.d, real_bod_map_B.d = bod_map_A, bod_map_B

        # Generate fake images
        fake_bod_map_B.forward(clear_buffer=True)

        i += 1

        images_to_visualize = [
            real_bod_map_A.d, fake_bod_map_B.d, real_bod_map_B.d
        ]
        visuals = combine_images(images_to_visualize)
        monitor_vis.add(i, visuals)

        # fixed off-by-one: the original `i > num_test` processed one extra
        # batch beyond the requested number of test samples
        if i >= num_test:
            break
# Exemple #2 (score: 0) — pasted-example separator converted to a comment so the file parses
def get_monitors(config, loss_flags, loss_var_dict, test=False):
    """Build the monitors used for logging and visualization.

    Args:
        config: config object exposing `monitor_params` (with
            `monitor_path`, `monitor_freq` and an optional `info` suffix).
        loss_flags: object with boolean `use_*_loss` attributes selecting
            which losses are monitored.
        loss_var_dict: dict mapping loss names to nnabla Variables.
        test (bool): if True, only the image visualization monitor is
            built and returned (inference mode).

    Returns:
        `monitor_vis` when `test` is True, otherwise the tuple
        (monitors_gen, monitors_dis, monitor_time, monitor_vis, log_dir).
    """
    log_root_dir = config.monitor_params.monitor_path
    # one timestamped sub-directory per run
    log_dir = os.path.join(log_root_dir, get_current_time())

    # if additional information is given, add it
    if "info" in config.monitor_params:
        info = config.monitor_params.info
        log_dir = f'{log_dir}_{info}'

    master_monitor_misc = nm.Monitor(log_dir)
    monitor_vis = nm.MonitorImage('images', master_monitor_misc,
                                  interval=1, num_images=4,
                                  normalize_method=lambda x: x)
    if test:
        # when inference, returns the visualization monitor only
        return monitor_vis

    interval = config.monitor_params.monitor_freq
    monitoring_var_dict_gen = dict()
    monitoring_var_dict_dis = dict()

    if loss_flags.use_perceptual_loss:
        monitoring_var_dict_gen.update(
            {'perceptual_loss': loss_var_dict['perceptual_loss']})

    # the GAN loss contributes one term to each network; the original code
    # tested `use_gan_loss` twice in a row — merged into a single check
    if loss_flags.use_gan_loss:
        monitoring_var_dict_gen.update(
            {'gan_loss_gen': loss_var_dict['gan_loss_gen']})
        monitoring_var_dict_dis.update(
            {'gan_loss_dis': loss_var_dict['gan_loss_dis']})

    if loss_flags.use_feature_matching_loss:
        monitoring_var_dict_gen.update(
            {'feature_matching_loss': loss_var_dict['feature_matching_loss']})

    if loss_flags.use_equivariance_value_loss:
        monitoring_var_dict_gen.update(
            {'equivariance_value_loss': loss_var_dict['equivariance_value_loss']})

    if loss_flags.use_equivariance_jacobian_loss:
        monitoring_var_dict_gen.update(
            {'equivariance_jacobian_loss': loss_var_dict['equivariance_jacobian_loss']})

    # total generator loss is always monitored
    monitoring_var_dict_gen.update(
        {'total_loss_gen': loss_var_dict['total_loss_gen']})

    master_monitor_gen = nm.Monitor(log_dir)
    master_monitor_dis = nm.Monitor(log_dir)

    monitors_gen = MonitorManager(monitoring_var_dict_gen,
                                  master_monitor_gen, interval=interval)
    monitors_dis = MonitorManager(monitoring_var_dict_dis,
                                  master_monitor_dis, interval=interval)
    monitor_time = nm.MonitorTimeElapsed('time_training',
                                         master_monitor_misc, interval=interval)

    return monitors_gen, monitors_dis, monitor_time, monitor_vis, log_dir
# Exemple #3 (score: 0) — pasted-example separator converted to a comment so the file parses
def test(config, netG, train_iterator, monitor, param_file):
    """Run inference with the boundary-map-to-image decoder and visualize.

    Args:
        config: dict-like config with "num_test" and
            "test" -> {"vis_interval", "batch_size"} entries.
        netG: callable building the decoder generator graph.
        train_iterator: iterator yielding (img, bod_map, bod_map_resize).
        monitor: nnabla Monitor used as the output root.
        param_file: path to the pretrained netG_decoder parameters (.h5).
    """
    # Load image and boundary image to get Variable shapes
    img, bod_map, bod_map_resize = train_iterator.next()

    real_img = nn.Variable(img.shape)
    real_bod_map = nn.Variable(bod_map.shape)
    real_bod_map_resize = nn.Variable(bod_map_resize.shape)

    ################### Graph Construction ####################
    # Generator in inference mode. Fixed: the original passed test=False,
    # running the network with training-time behavior at test time (the
    # combined-pipeline test function uses test=True for the same decoder).
    with nn.parameter_scope('netG_decoder'):
        fake_img = netG(real_bod_map, test=True)
    fake_img.persistent = True

    # load parameters of networks
    with nn.parameter_scope('netG_decoder'):
        nn.load_parameters(param_file)

    monitor_vis = nm.MonitorImage('result',
                                  monitor,
                                  interval=config["test"]["vis_interval"],
                                  num_images=4,
                                  normalize_method=lambda x: x)

    # Test
    i = 0
    iter_per_epoch = train_iterator.size // config["test"]["batch_size"] + 1

    # a falsy config["num_test"] means "use the whole dataset"
    if config["num_test"]:
        num_test = config["num_test"]
    else:
        num_test = train_iterator.size

    for _ in range(iter_per_epoch):
        img, bod_map, bod_map_resize = train_iterator.next()

        real_img.d = img
        real_bod_map.d = bod_map
        real_bod_map_resize.d = bod_map_resize

        # Generate fake image
        fake_img.forward(clear_buffer=True)

        i += 1

        images_to_visualize = [real_bod_map_resize.d, fake_img.d, img]
        visuals = combine_images(images_to_visualize)
        monitor_vis.add(i, visuals)

        # fixed off-by-one: the original `i > num_test` processed one extra
        # batch beyond the requested number of test samples
        if i >= num_test:
            break
def train(config, netG, netD, solver_netG, solver_netD, train_iterator,
          monitor):
    """Train the boundary-map-to-image decoder GAN (pix2pix-style).

    Builds the generator/discriminator graphs, defines GAN + weighted L1
    (+ optional VGG perceptual) losses, then alternates discriminator and
    generator updates, logging losses, saving visualizations once per
    epoch and network snapshots at the configured interval.

    Args:
        config: dict-like training configuration ("train", "monitor" keys).
        netG, netD: callables building the generator/discriminator graphs.
        solver_netG, solver_netD: nnabla solvers for each network.
        train_iterator: iterator yielding (img, bod_map, bod_map_resize).
        monitor: nnabla Monitor used as the output root.
    """
    # enable the VGG feature (perceptual) loss only when a positive weight is set
    if config["train"][
            "feature_loss"] and config["train"]["feature_loss"]["lambda"] > 0:
        print(
            f'Applying VGG feature Loss, weight: {config["train"]["feature_loss"]["lambda"]}.'
        )
        with_feature_loss = True
    else:
        with_feature_loss = False

    # Load image and boundary image to get Variable shapes
    img, bod_map, bod_map_resize = train_iterator.next()

    real_img = nn.Variable(img.shape)
    real_bod_map = nn.Variable(bod_map.shape)
    real_bod_map_resize = nn.Variable(bod_map_resize.shape)

    ################### Graph Construction ####################
    # Generator
    with nn.parameter_scope('netG_decoder'):
        fake_img = netG(real_bod_map, test=False)
    fake_img.persistent = True

    # unlinked copy lets the discriminator losses be evaluated/backpropagated
    # without immediately propagating gradients into the generator graph
    fake_img_unlinked = fake_img.get_unlinked_variable()

    # Discriminator (conditional: sees the resized boundary map concatenated
    # with either the generated or the real image)
    with nn.parameter_scope('netD_decoder'):
        pred_fake = netD(F.concatenate(real_bod_map_resize,
                                       fake_img_unlinked,
                                       axis=1),
                         test=False)
        pred_real = netD(F.concatenate(real_bod_map_resize, real_img, axis=1),
                         test=False)
    real_target = F.constant(1, pred_fake.shape)
    fake_target = F.constant(0, pred_real.shape)

    ################### Loss Definition ####################
    # for Generator
    gan_loss_G = gan_loss(pred_fake, real_target)
    gan_loss_G.persistent = True

    weight_L1 = config["train"]["weight_L1"]
    L1_loss = recon_loss(fake_img_unlinked, real_img)
    L1_loss.persistent = True
    loss_netG = gan_loss_G + weight_L1 * L1_loss

    if with_feature_loss:
        # images appear to be in [-1, 1]: rescaled to [0, 255] for VGG16
        feature_loss = vgg16_perceptual_loss(127.5 * (fake_img_unlinked + 1.),
                                             127.5 * (real_img + 1.))
        feature_loss.persistent = True
        loss_netG += feature_loss * config["train"]["feature_loss"]["lambda"]

    # for Discriminator
    loss_netD = (gan_loss(pred_real, real_target) +
                 gan_loss(pred_fake, fake_target)) * 0.5

    ################### Setting Solvers ####################
    # for Generator
    with nn.parameter_scope('netG_decoder'):
        solver_netG.set_parameters(nn.get_parameters())

    # for Discriminator
    with nn.parameter_scope('netD_decoder'):
        solver_netD.set_parameters(nn.get_parameters())

    ################### Create Monitors ####################
    interval = config["monitor"]["interval"]
    monitors_G_dict = {
        'loss_netG': loss_netG,
        'loss_gan': gan_loss_G,
        'L1_loss': L1_loss
    }

    if with_feature_loss:
        monitors_G_dict.update({'vgg_feature_loss': feature_loss})

    monitors_G = MonitorManager(monitors_G_dict, monitor, interval=interval)

    monitors_D_dict = {'loss_netD': loss_netD}
    monitors_D = MonitorManager(monitors_D_dict, monitor, interval=interval)

    monitor_time = nm.MonitorTimeElapsed('time_training',
                                         monitor,
                                         interval=interval)
    monitor_vis = nm.MonitorImage('result',
                                  monitor,
                                  interval=1,
                                  num_images=4,
                                  normalize_method=lambda x: x)

    # Dump training information
    with open(os.path.join(monitor._save_path, "training_info.yaml"),
              "w",
              encoding="utf-8") as f:
        f.write(yaml.dump(config))

    # Training
    epoch = config["train"]["epochs"]
    i = 0
    lr_decay_start_at = config["train"]["lr_decay_start_at"]
    iter_per_epoch = train_iterator.size // config["train"]["batch_size"] + 1
    for e in range(epoch):
        logger.info(f'Epoch = {e} / {epoch}')
        train_iterator._reset()  # rewind the iterator
        if e > lr_decay_start_at:
            # linear decay of the learning rate over the 50 epochs that
            # follow lr_decay_start_at
            decay_coeff = 1.0 - max(0, e - lr_decay_start_at) / 50.
            lr_decayed = config["train"]["lr"] * decay_coeff
            print(f"learning rate decayed to {lr_decayed}")
            solver_netG.set_learning_rate(lr_decayed)
            solver_netD.set_learning_rate(lr_decayed)

        for _ in range(iter_per_epoch):
            img, bod_map, bod_map_resize = train_iterator.next()
            # bod_map_noize = np.random.random_sample(bod_map.shape) * 0.01
            # bod_map_resize_noize = np.random.random_sample(bod_map_resize.shape) * 0.01

            real_img.d = img
            real_bod_map.d = bod_map  # + bod_map_noize
            real_bod_map_resize.d = bod_map_resize  # + bod_map_resize_noize

            # Generate fake image
            fake_img.forward(clear_no_need_grad=True)

            # Update Discriminator
            solver_netD.zero_grad()
            solver_netG.zero_grad()
            loss_netD.forward(clear_no_need_grad=True)
            loss_netD.backward(clear_buffer=True)
            solver_netD.update()

            # Update Generator
            solver_netD.zero_grad()
            solver_netG.zero_grad()
            # clear the gradient accumulated on the unlinked variable by the
            # discriminator pass before accumulating the generator losses
            fake_img_unlinked.grad.zero()
            loss_netG.forward(clear_no_need_grad=True)
            loss_netG.backward(clear_buffer=True)
            # propagate the gradient collected on fake_img_unlinked back
            # through the generator graph
            fake_img.backward(grad=None)
            solver_netG.update()

            # Monitors
            monitor_time.add(i)
            monitors_G.add(i)
            monitors_D.add(i)

            i += 1

        # per-epoch visualization: (resized boundary map, generated, real)
        images_to_visualize = [real_bod_map_resize.d, fake_img.d, img]
        visuals = combine_images(images_to_visualize)
        monitor_vis.add(i, visuals)

        if e % config["monitor"]["save_interval"] == 0 or e == epoch - 1:
            # Save parameters of networks
            netG_save_path = os.path.join(monitor._save_path,
                                          f'netG_decoder_{e}.h5')
            with nn.parameter_scope('netG_decoder'):
                nn.save_parameters(netG_save_path)
            netD_save_path = os.path.join(monitor._save_path,
                                          f'netD_decoder_{e}.h5')
            with nn.parameter_scope('netD_decoder'):
                nn.save_parameters(netD_save_path)
# Exemple #5 (score: 0) — pasted-example separator converted to a comment so the file parses
def animate(args):
    """Animate a source image with the motion of a driving video (first-order motion model).

    Downloads the pretrained config/weights when none are supplied, builds
    the keypoint-detector and occlusion-aware-generator inference graphs,
    then generates one output frame per driving-video frame and saves the
    result either as PNG frames or as an mp4.

    Args:
        args: parsed command-line namespace; fields read here include
            context, config, params, source, driving, detailed, full,
            only_generated, adapt_movement_scale, unuse_relative_movement,
            unuse_relative_jacobian, out_dir, output_png and fps.
    """
    # get context
    ctx = get_extension_context(args.context)
    nn.set_default_context(ctx)
    logger.setLevel(logging.ERROR)  # to supress minor messages

    if not args.config:
        assert not args.params, "pretrained weights file is given, but corresponding config file is not. Please give both."
        download_provided_file(
            "https://nnabla.org/pretrained-models/nnabla-examples/GANs/first-order-model/voxceleb_trained_info.yaml"
        )
        args.config = 'voxceleb_trained_info.yaml'

        download_provided_file(
            "https://nnabla.org/pretrained-models/nnabla-examples/GANs/first-order-model/pretrained_fomm_params.h5"
        )

    config = read_yaml(args.config)

    dataset_params = config.dataset_params
    model_params = config.model_params

    if args.detailed:
        vis_params = config.visualizer_params
        visualizer = Visualizer(**vis_params)

    if not args.params:
        assert "log_dir" in config, "no log_dir found in config. therefore failed to locate pretrained parameters."
        param_file = os.path.join(config.log_dir, config.saved_parameters)
    else:
        param_file = args.params
    print(f"Loading {param_file} for image animation...")
    nn.load_parameters(param_file)

    bs, h, w, c = [1] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving_initial = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    filename = args.driving

    # process repeated until all the test data is used
    driving_video = read_video(
        filename, dataset_params.frame_shape)  # (#frames, h, w, 3)
    driving_video = np.transpose(driving_video,
                                 (0, 3, 1, 2))  # (#frames, 3, h, w)

    source_img = imread(args.source, channel_first=True,
                        size=(256, 256)) / 255.
    source_img = source_img[:3]  # keep only the first 3 channels (RGB)

    source.d = np.expand_dims(source_img, 0)
    driving_initial.d = driving_video[0][:3, ]

    # keypoints of the source image
    with nn.parameter_scope("kp_detector"):
        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=True,
                                    comm=False)
        persistent_all(kp_source)

    # keypoints of the first driving frame (reference for relative motion)
    with nn.parameter_scope("kp_detector"):
        kp_driving_initial = detect_keypoint(driving_initial,
                                             **model_params.kp_detector_params,
                                             **model_params.common_params,
                                             test=True,
                                             comm=False)
        persistent_all(kp_driving_initial)

    # keypoints of the current driving frame (recomputed per frame below)
    with nn.parameter_scope("kp_detector"):
        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=True,
                                     comm=False)
        persistent_all(kp_driving)

    if args.adapt_movement_scale:
        # scale the driving motion by the ratio of the convex-hull areas of
        # the source keypoints and the initial driving-frame keypoints
        nn.forward_all([
            kp_source["value"], kp_source["jacobian"],
            kp_driving_initial["value"], kp_driving_initial["jacobian"]
        ])
        source_area = ConvexHull(kp_source['value'].d[0]).volume
        driving_area = ConvexHull(kp_driving_initial['value'].d[0]).volume
        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
    else:
        adapt_movement_scale = 1

    kp_norm = adjust_kp(kp_source=unlink_all(kp_source),
                        kp_driving=kp_driving,
                        kp_driving_initial=unlink_all(kp_driving_initial),
                        adapt_movement_scale=adapt_movement_scale,
                        use_relative_movement=args.unuse_relative_movement,
                        use_relative_jacobian=args.unuse_relative_jacobian)
    persistent_all(kp_norm)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=unlink_all(kp_source),
                                              kp_driving=kp_norm,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=True,
                                              comm=False)

    if not args.full and 'sparse_deformed' in generated:
        del generated['sparse_deformed']  # remove needless info

    persistent_all(generated)

    generated['kp_driving'] = kp_driving
    generated['kp_source'] = kp_source
    generated['kp_norm'] = kp_norm

    # generated contains these values;
    # 'mask': <Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
    # 'sparse_deformed': <Variable((bs, num_kp+1, num_channel, h/4, w/4))  # (bs, num_kp + 1, c, h, w)
    # 'occlusion_map': <Variable((bs, 1, h/4, w/4))
    # 'deformed': <Variable((bs, c, h, w))
    # 'prediction': <Variable((bs, c, h, w))

    mode = "arbitrary"
    if "log_dir" in config:
        result_dir = os.path.join(args.out_dir,
                                  os.path.basename(config.log_dir), f"{mode}")
    else:
        result_dir = os.path.join(args.out_dir, "test_result", f"{mode}")

    # create an empty directory to save generated results
    _ = nm.Monitor(result_dir)

    # load the header images.
    header = imread("imgs/header_combined.png", channel_first=True)
    generated_images = list()

    # compute these in advance and reuse
    nn.forward_all([kp_source["value"], kp_source["jacobian"]],
                   clear_buffer=True)
    nn.forward_all(
        [kp_driving_initial["value"], kp_driving_initial["jacobian"]],
        clear_buffer=True)

    num_of_driving_frames = driving_video.shape[0]

    for frame_idx in tqdm(range(num_of_driving_frames)):
        driving.d = driving_video[frame_idx][:3, ]
        nn.forward_all([generated["prediction"], generated["deformed"]],
                       clear_buffer=True)

        if args.detailed:
            # visualize source w/kp, driving w/kp, deformed source, generated w/kp, generated image, occlusion map
            visualization = visualizer.visualize(source=source.d,
                                                 driving=driving.d,
                                                 out=generated)
            if args.full:
                visualization = reshape_result(visualization)  # (H, W, C)
            combined_image = visualization.transpose(2, 0, 1)  # (C, H, W)

        elif args.only_generated:
            combined_image = np.clip(generated["prediction"].d[0], 0.0, 1.0)
            combined_image = (255 * combined_image).astype(
                np.uint8)  # (C, H, W)

        else:
            # visualize source, driving, and generated image
            driving_fake = np.concatenate([
                np.clip(driving.d[0], 0.0, 1.0),
                np.clip(generated["prediction"].d[0], 0.0, 1.0)
            ],
                                          axis=2)
            header_source = np.concatenate([
                np.clip(header / 255., 0.0, 1.0),
                np.clip(source.d[0], 0.0, 1.0)
            ],
                                           axis=2)
            combined_image = np.concatenate([header_source, driving_fake],
                                            axis=1)
            combined_image = (255 * combined_image).astype(np.uint8)

        generated_images.append(combined_image)

    # once each video is generated, save it.
    output_filename = f"{os.path.splitext(os.path.basename(filename))[0]}.mp4"
    output_filename = f"{os.path.basename(args.source)}_by_{output_filename}"
    output_filename = output_filename.replace("#", "_")
    if args.output_png:
        monitor_vis = nm.MonitorImage(output_filename,
                                      nm.Monitor(result_dir),
                                      interval=1,
                                      num_images=1,
                                      normalize_method=lambda x: x)
        for frame_idx, img in enumerate(generated_images):
            monitor_vis.add(frame_idx, img)
    else:
        generated_images = [_.transpose(1, 2, 0) for _ in generated_images]
        # you might need to change ffmpeg_params according to your environment.
        mimsave(f'{os.path.join(result_dir, output_filename)}',
                generated_images,
                fps=args.fps,
                ffmpeg_params=[
                    "-pix_fmt", "yuv420p", "-vcodec", "libx264", "-f", "mp4",
                    "-q", "0"
                ])

    return
def test(encoder_config, transformer_config, decoder_config, encoder_netG,
         transformer_netG, decoder_netG, src_celeb_name, trg_celeb_name,
         test_iterator, monitor, encoder_param_file, transformer_param_file,
         decoder_param_file):
    """Run the full encoder -> transformer -> decoder inference pipeline.

    The encoder predicts a boundary map from a real image, the A->B
    transformer maps it toward the target identity, and the decoder renders
    the final image; all intermediate and final results are visualized.

    NOTE(review): decoder_config, src_celeb_name and trg_celeb_name are
    accepted but never used in this body.
    """
    # prepare nn.Variable (fixed shapes: a 256x256 RGB image, a 15-channel
    # 64x64 boundary map, and its 256x256 resized version)
    real_img = nn.Variable((1, 3, 256, 256))
    real_bod_map = nn.Variable((1, 15, 64, 64))
    real_bod_map_resize = nn.Variable((1, 15, 256, 256))

    # encoder
    with nn.parameter_scope(encoder_config["model_name"]):
        _, preds = encoder_netG(
            real_img,
            batch_stat=False,
            planes=encoder_config["model"]["planes"],
            output_nc=encoder_config["model"]["output_nc"],
            num_stacks=encoder_config["model"]["num_stacks"],
            activation=encoder_config["model"]["activation"],
        )
    preds.persistent = True
    # NOTE(review): preds_unlinked is created but the transformer below
    # consumes `preds` directly — confirm whether this is intentional
    preds_unlinked = preds.get_unlinked_variable()

    # load parameters of networks (after graph construction so the scope
    # already contains the encoder parameters to overwrite)
    with nn.parameter_scope(encoder_config["model_name"]):
        nn.load_parameters(encoder_param_file)

    # transformer
    # Generator
    with nn.parameter_scope('netG_transformer'):
        with nn.parameter_scope('netG_A2B'):
            fake_bod_map = transformer_netG(
                preds, test=True, norm_type=transformer_config["norm_type"])
    with nn.parameter_scope('netG_transformer'):
        with nn.parameter_scope('netG_A2B'):
            nn.load_parameters(transformer_param_file)

    fake_bod_map.persistent = True
    fake_bod_map_unlinked = fake_bod_map.get_unlinked_variable()

    # decoder
    with nn.parameter_scope('netG_decoder'):
        fake_img = decoder_netG(fake_bod_map_unlinked, test=True)
    fake_img.persistent = True

    # load parameters of networks
    with nn.parameter_scope('netG_decoder'):
        nn.load_parameters(decoder_param_file)

    monitor_vis = nm.MonitorImage('result',
                                  monitor,
                                  interval=1,
                                  num_images=1,
                                  normalize_method=lambda x: x)

    # test
    num_test_batches = test_iterator.size
    for i in range(num_test_batches):
        _real_img, _, _real_bod_map_resize = test_iterator.next()

        real_img.d = _real_img
        real_bod_map_resize.d = _real_bod_map_resize

        # Generator: run the three stages in pipeline order
        preds.forward(clear_no_need_grad=True)
        fake_bod_map.forward(clear_no_need_grad=True)
        fake_img.forward(clear_no_need_grad=True)

        images_to_visualize = [
            real_img.d, preds.d, fake_bod_map.d, fake_img.d,
            real_bod_map_resize.d
        ]
        visuals = combine_images(images_to_visualize)
        monitor_vis.add(i, visuals)
def reconstruct(args):
    """Reconstruct test videos with a trained first-order motion model.

    For every video under <dataset_root>/test, the first frame is used as
    the source image and the model re-generates all frames from the
    driving keypoints. Per-frame L1 reconstruction losses are accumulated
    and their mean is printed at the end. Results are saved as mp4 files
    or PNG frames, optionally with extra crops for evaluation.

    Args:
        args: parsed command-line namespace; fields read here include
            context, config, params, detailed, full, only_generated, eval,
            out_dir, output_png and fps.
    """
    # get context
    ctx = get_extension_context(args.context)
    nn.set_default_context(ctx)
    logger.setLevel(logging.ERROR)  # to supress minor messages

    config = read_yaml(args.config)

    dataset_params = config.dataset_params
    model_params = config.model_params

    if args.detailed:
        vis_params = config.visualizer_params
        visualizer = Visualizer(**vis_params)

    if not args.params:
        assert "log_dir" in config, "no log_dir found in config. therefore failed to locate pretrained parameters."
        param_file = os.path.join(
            config.log_dir, config.saved_parameters)
    else:
        param_file = args.params
    nn.load_parameters(param_file)

    bs, h, w, c = [1] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving_initial = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    # keypoints of the source image
    with nn.parameter_scope("kp_detector"):
        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=True, comm=False)
        persistent_all(kp_source)

    # keypoints of the current driving frame (recomputed per frame below)
    with nn.parameter_scope("kp_detector"):
        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=True, comm=False)
        persistent_all(kp_driving)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=unlink_all(kp_source),
                                              kp_driving=kp_driving,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=True, comm=False)

    if not args.full and 'sparse_deformed' in generated:
        del generated['sparse_deformed']  # remove needless info

    persistent_all(generated)

    generated['kp_driving'] = kp_driving
    generated['kp_source'] = kp_source

    # generated contains these values;
    # 'mask': <Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
    # 'sparse_deformed': <Variable((bs, num_kp+1, num_channel, h/4, w/4))  # (bs, num_kp + 1, c, h, w)
    # 'occlusion_map': <Variable((bs, 1, h/4, w/4))
    # 'deformed': <Variable((bs, c, h, w))
    # 'prediction': <Variable((bs, c, h, w))

    mode = "reconstruction"
    if "log_dir" in config:
        result_dir = os.path.join(args.out_dir, os.path.basename(config.log_dir), f"{mode}")
    else:
        result_dir = os.path.join(args.out_dir, "test_result", f"{mode}")

    # create an empty directory to save generated results
    _ = nm.Monitor(result_dir)
    if args.eval:
        os.makedirs(os.path.join(result_dir, "png"), exist_ok=True)

    # load the header images.
    header = imread("imgs/header_combined.png", channel_first=True)

    filenames = sorted(glob.glob(os.path.join(
        dataset_params.root_dir, "test", "*")))
    recon_loss_list = list()

    for filename in tqdm(filenames):
        # process repeated until all the test data is used
        driving_video = read_video(
            filename, dataset_params.frame_shape)  # (#frames, h, w, 3)
        driving_video = np.transpose(
            driving_video, (0, 3, 1, 2))  # (#frames, 3, h, w)

        generated_images = list()
        # the first frame of each video serves as the source image
        source_img = driving_video[0]

        source.d = np.expand_dims(source_img, 0)
        driving_initial.d = driving_video[0]

        # compute these in advance and reuse
        nn.forward_all(
            [kp_source["value"], kp_source["jacobian"]], clear_buffer=True)

        num_of_driving_frames = driving_video.shape[0]

        for frame_idx in tqdm(range(num_of_driving_frames)):
            driving.d = driving_video[frame_idx]
            nn.forward_all([generated["prediction"],
                            generated["deformed"]], clear_buffer=True)

            if args.detailed:
                # visualize source w/kp, driving w/kp, deformed source, generated w/kp, generated image, occlusion map
                visualization = visualizer.visualize(
                    source=source.d, driving=driving.d, out=generated)
                if args.full:
                    visualization = reshape_result(visualization)  # (H, W, C)
                combined_image = visualization.transpose(2, 0, 1)  # (C, H, W)

            elif args.only_generated:
                combined_image = np.clip(
                    generated["prediction"].d[0], 0.0, 1.0)
                combined_image = (
                    255*combined_image).astype(np.uint8)  # (C, H, W)

            else:
                # visualize source, driving, and generated image
                driving_fake = np.concatenate([np.clip(driving.d[0], 0.0, 1.0),
                                               np.clip(generated["prediction"].d[0], 0.0, 1.0)], axis=2)
                header_source = np.concatenate([np.clip(header / 255., 0.0, 1.0),
                                                np.clip(source.d[0], 0.0, 1.0)], axis=2)
                combined_image = np.concatenate(
                    [header_source, driving_fake], axis=1)
                combined_image = (255*combined_image).astype(np.uint8)

            generated_images.append(combined_image)
            # compute L1 distance per frame.
            recon_loss_list.append(
                np.mean(np.abs(generated["prediction"].d[0] - driving.d[0])))

        # post process only for reconstruction evaluation.
        if args.eval:
            # crop generated images region only.
            if args.only_generated:
                eval_images = generated_images
            elif args.full:
                eval_images = [_[:, :h, 4*w:5*w] for _ in generated_images]
            elif args.detailed:
                assert generated_images[0].shape == (c, h, 5*w)
                eval_images = [_[:, :, 3*w:4*w] for _ in generated_images]
            else:
                eval_images = [_[:, h:, w:] for _ in generated_images]
            # place them horizontally and save for evaluation.
            image_for_eval = np.concatenate(
                eval_images, axis=2).transpose(1, 2, 0)
            imsave(os.path.join(result_dir, "png", f"{os.path.basename(filename)}.png"),
                   image_for_eval)

        # once each video is generated, save it.
        output_filename = f"{os.path.splitext(os.path.basename(filename))[0]}.mp4"
        if args.output_png:
            monitor_vis = nm.MonitorImage(output_filename, nm.Monitor(result_dir),
                                          interval=1, num_images=1,
                                          normalize_method=lambda x: x)
            for frame_idx, img in enumerate(generated_images):
                monitor_vis.add(frame_idx, img)
        else:
            generated_images = [_.transpose(1, 2, 0) for _ in generated_images]
            # you might need to change ffmpeg_params according to your environment.
            mimsave(f'{os.path.join(result_dir, output_filename)}', generated_images,
                    fps=args.fps,
                    ffmpeg_params=["-pix_fmt", "yuv420p",
                                   "-vcodec", "libx264",
                                   "-f", "mp4",
                                   "-q", "0"])
    print(f"Reconstruction loss: {np.mean(recon_loss_list)}")

    return
def train_transformer(config, netG, netD, solver_netG, solver_netD,
                      train_iterators, monitor):
    """Train the boundary-map transformer (CycleGAN-style A<->B translation).

    Builds both generator directions (A2B, B2A) and both discriminators,
    optionally adds cycle-consistency and PCA-based shape losses, then runs
    the adversarial training loop, logging losses/images and periodically
    saving parameters.

    Parameters
    ----------
    config : dict
        Training configuration (norm type, loss weights, epochs, intervals).
    netG : dict
        Generator callables under keys 'netG_A2B' and 'netG_B2A'.
    netD : dict
        Discriminator callables under keys 'netD_A' and 'netD_B'.
    solver_netG : dict
        Generator solvers, same keys as ``netG``.
    solver_netD : dict
        Discriminator solvers, same keys as ``netD``.
    train_iterators : tuple
        (source iterator, target iterator); ``next()[0]`` yields a batch of
        boundary maps.
    monitor : nnabla.monitor.Monitor
        Monitor used for logging; its ``_save_path`` is used for outputs.
    """

    netG_A2B, netG_B2A = netG['netG_A2B'], netG['netG_B2A']
    netD_A, netD_B = netD['netD_A'], netD['netD_B']
    solver_netG_AB, solver_netG_BA = solver_netG['netG_A2B'], solver_netG[
        'netG_B2A']
    solver_netD_A, solver_netD_B = solver_netD['netD_A'], solver_netD['netD_B']

    train_iterator_src, train_iterator_trg = train_iterators

    # Enable cycle loss only when configured with a positive weight.
    if config["train"][
            "cycle_loss"] and config["train"]["cycle_loss"]["lambda"] > 0:
        print(
            f'Applying Cycle Loss, weight: {config["train"]["cycle_loss"]["lambda"]}.'
        )
        with_cycle_loss = True
    else:
        with_cycle_loss = False

    # Enable PCA-based shape loss only when configured with a positive weight.
    if config["train"][
            "shape_loss"] and config["train"]["shape_loss"]["lambda"] > 0:
        print(
            f'Applying Shape Loss using PCA, weight: {config["train"]["shape_loss"]["lambda"]}.'
        )
        with_shape_loss = True
    else:
        with_shape_loss = False

    # Load boundary image to get Variable shapes
    bod_map_A = train_iterator_src.next()[0]
    bod_map_B = train_iterator_trg.next()[0]
    real_bod_map_A = nn.Variable(bod_map_A.shape)
    real_bod_map_B = nn.Variable(bod_map_B.shape)
    # Keep input buffers alive across forward/backward for visualization.
    real_bod_map_A.persistent, real_bod_map_B.persistent = True, True

    ################### Graph Construction ####################
    # Generator
    with nn.parameter_scope('netG_transformer'):
        with nn.parameter_scope('netG_A2B'):
            fake_bod_map_B = netG_A2B(
                real_bod_map_A, test=False,
                norm_type=config["norm_type"])  # (1, 15, 64, 64)
        with nn.parameter_scope('netG_B2A'):
            fake_bod_map_A = netG_B2A(
                real_bod_map_B, test=False,
                norm_type=config["norm_type"])  # (1, 15, 64, 64)
    fake_bod_map_B.persistent, fake_bod_map_A.persistent = True, True

    # Unlinked copies detach the discriminator/cycle sub-graphs from the
    # generators, so the generator gradients can be propagated manually
    # (see the explicit fake_bod_map_*.backward calls in the training loop).
    fake_bod_map_B_unlinked = fake_bod_map_B.get_unlinked_variable()
    fake_bod_map_A_unlinked = fake_bod_map_A.get_unlinked_variable()

    # Reconstruct images if cycle loss is applied.
    if with_cycle_loss:
        with nn.parameter_scope('netG_transformer'):
            with nn.parameter_scope('netG_B2A'):
                recon_bod_map_A = netG_B2A(
                    fake_bod_map_B_unlinked,
                    test=False,
                    norm_type=config["norm_type"])  # (1, 15, 64, 64)
            with nn.parameter_scope('netG_A2B'):
                recon_bod_map_B = netG_A2B(
                    fake_bod_map_A_unlinked,
                    test=False,
                    norm_type=config["norm_type"])  # (1, 15, 64, 64)
        recon_bod_map_A.persistent, recon_bod_map_B.persistent = True, True

    # Discriminator
    with nn.parameter_scope('netD_transformer'):
        with nn.parameter_scope('netD_A'):
            pred_fake_A = netD_A(fake_bod_map_A_unlinked, test=False)
            pred_real_A = netD_A(real_bod_map_A, test=False)
        with nn.parameter_scope('netD_B'):
            pred_fake_B = netD_B(fake_bod_map_B_unlinked, test=False)
            pred_real_B = netD_B(real_bod_map_B, test=False)
    # LSGAN targets: 1 for "real", 0 for "fake".
    real_target = F.constant(1, pred_fake_A.shape)
    fake_target = F.constant(0, pred_real_A.shape)

    ################### Loss Definition ####################
    # Generator loss
    # LSGAN loss
    loss_gan_A = lsgan_loss(pred_fake_A, real_target)
    loss_gan_B = lsgan_loss(pred_fake_B, real_target)
    loss_gan_A.persistent, loss_gan_B.persistent = True, True
    loss_gan = loss_gan_A + loss_gan_B

    # Cycle loss
    if with_cycle_loss:
        loss_cycle_A = recon_loss(recon_bod_map_A, real_bod_map_A)
        loss_cycle_B = recon_loss(recon_bod_map_B, real_bod_map_B)
        loss_cycle_A.persistent, loss_cycle_B.persistent = True, True
        loss_cycle = loss_cycle_A + loss_cycle_B

    # Shape loss
    if with_shape_loss:
        # Frozen alignment network extracts shape features from both the
        # real and generated boundary maps.
        with nn.parameter_scope("Align"):
            nn.load_parameters(
                config["train"]["shape_loss"]["align_param_path"])
            shape_bod_map_real_A = models.align_resnet(real_bod_map_A,
                                                       fix_parameters=True)
            shape_bod_map_fake_B = models.align_resnet(fake_bod_map_B_unlinked,
                                                       fix_parameters=True)

            shape_bod_map_real_B = models.align_resnet(real_bod_map_B,
                                                       fix_parameters=True)
            shape_bod_map_fake_A = models.align_resnet(fake_bod_map_A_unlinked,
                                                       fix_parameters=True)

        # Frozen PCA projection (affine to 212 dims); only the first 3
        # components are compared in the loss.
        with nn.parameter_scope("PCA"):
            nn.load_parameters(config["train"]["shape_loss"]["PCA_param_path"])
            shape_bod_map_real_A = PF.affine(shape_bod_map_real_A,
                                             212,
                                             fix_parameters=True)
            shape_bod_map_real_A = shape_bod_map_real_A[:, :3]

            shape_bod_map_fake_B = PF.affine(shape_bod_map_fake_B,
                                             212,
                                             fix_parameters=True)
            shape_bod_map_fake_B = shape_bod_map_fake_B[:, :3]

            shape_bod_map_real_B = PF.affine(shape_bod_map_real_B,
                                             212,
                                             fix_parameters=True)
            shape_bod_map_real_B = shape_bod_map_real_B[:, :3]

            shape_bod_map_fake_A = PF.affine(shape_bod_map_fake_A,
                                             212,
                                             fix_parameters=True)
            shape_bod_map_fake_A = shape_bod_map_fake_A[:, :3]

        shape_bod_map_real_A.persistent, shape_bod_map_fake_A.persistent = True, True
        shape_bod_map_real_B.persistent, shape_bod_map_fake_B.persistent = True, True

        # Fake B should match real A's shape code, and vice versa.
        loss_shape_A = recon_loss(shape_bod_map_real_A, shape_bod_map_fake_B)
        loss_shape_B = recon_loss(shape_bod_map_real_B, shape_bod_map_fake_A)
        loss_shape_A.persistent, loss_shape_B.persistent = True, True
        loss_shape = loss_shape_A + loss_shape_B

    # Total Generator Loss
    loss_netG = loss_gan

    if with_cycle_loss:
        loss_netG += loss_cycle * config["train"]["cycle_loss"]["lambda"]

    if with_shape_loss:
        loss_netG += loss_shape * config["train"]["shape_loss"]["lambda"]

    # Discriminator loss
    loss_netD_A = lsgan_loss(pred_real_A, real_target) + \
        lsgan_loss(pred_fake_A, fake_target)
    loss_netD_B = lsgan_loss(pred_real_B, real_target) + \
        lsgan_loss(pred_fake_B, fake_target)
    loss_netD_A.persistent, loss_netD_B.persistent = True, True

    loss_netD = loss_netD_A + loss_netD_B

    ################### Setting Solvers ####################
    # Generator solver
    with nn.parameter_scope('netG_transformer'):
        with nn.parameter_scope('netG_A2B'):
            solver_netG_AB.set_parameters(nn.get_parameters())
        with nn.parameter_scope('netG_B2A'):
            solver_netG_BA.set_parameters(nn.get_parameters())

    # Discriminator solver
    with nn.parameter_scope('netD_transformer'):
        with nn.parameter_scope('netD_A'):
            solver_netD_A.set_parameters(nn.get_parameters())
        with nn.parameter_scope('netD_B'):
            solver_netD_B.set_parameters(nn.get_parameters())

    ################### Create Monitors ####################
    interval = config["monitor"]["interval"]
    monitors_G_dict = {
        'loss_netG': loss_netG,
        'loss_gan_A': loss_gan_A,
        'loss_gan_B': loss_gan_B
    }

    if with_cycle_loss:
        monitors_G_dict.update({
            'loss_cycle_A': loss_cycle_A,
            'loss_cycle_B': loss_cycle_B
        })

    if with_shape_loss:
        monitors_G_dict.update({
            'loss_shape_A': loss_shape_A,
            'loss_shape_B': loss_shape_B
        })

    monitors_G = MonitorManager(monitors_G_dict, monitor, interval=interval)

    monitors_D_dict = {
        'loss_netD': loss_netD,
        'loss_netD_A': loss_netD_A,
        'loss_netD_B': loss_netD_B
    }
    monitors_D = MonitorManager(monitors_D_dict, monitor, interval=interval)

    monitor_time = nm.MonitorTimeElapsed('time_training',
                                         monitor,
                                         interval=interval)
    monitor_vis = nm.MonitorImage('result',
                                  monitor,
                                  interval=1,
                                  num_images=4,
                                  normalize_method=lambda x: x)

    # Dump training information
    with open(os.path.join(monitor._save_path, "training_info.yaml"),
              "w",
              encoding="utf-8") as f:
        f.write(yaml.dump(config))

    # Training
    epoch = config["train"]["epochs"]
    i = 0
    iter_per_epoch = train_iterator_src.size // config["train"][
        "batch_size"] + 1
    for e in range(epoch):
        logger.info(f'Epoch = {e} / {epoch}')
        train_iterator_src._reset()  # rewind the iterator
        train_iterator_trg._reset()  # rewind the iterator
        for _ in range(iter_per_epoch):
            bod_map_A = train_iterator_src.next()[0]
            bod_map_B = train_iterator_trg.next()[0]
            real_bod_map_A.d, real_bod_map_B.d = bod_map_A, bod_map_B

            # Generate fake image
            fake_bod_map_B.forward(clear_no_need_grad=True)
            fake_bod_map_A.forward(clear_no_need_grad=True)

            # Update Discriminator (generators are detached via the
            # unlinked variables, so only D parameters receive gradients).
            solver_netD_A.zero_grad()
            solver_netD_B.zero_grad()
            loss_netD.forward(clear_no_need_grad=True)
            loss_netD.backward(clear_buffer=True)
            if config["train"]["weight_decay"]:
                solver_netD_A.weight_decay(config["train"]["weight_decay"])
                solver_netD_B.weight_decay(config["train"]["weight_decay"])
            solver_netD_A.update()
            solver_netD_B.update()

            # Update Generator: backward through loss_netG fills the grads
            # of the unlinked fakes; the explicit fake_bod_map_*.backward
            # calls then propagate those grads into the generator weights.
            solver_netG_BA.zero_grad()
            solver_netG_AB.zero_grad()
            solver_netD_A.zero_grad()
            solver_netD_B.zero_grad()
            fake_bod_map_B_unlinked.grad.zero()
            fake_bod_map_A_unlinked.grad.zero()
            loss_netG.forward(clear_no_need_grad=True)
            loss_netG.backward(clear_buffer=True)
            fake_bod_map_B.backward(grad=None)
            fake_bod_map_A.backward(grad=None)
            solver_netG_AB.update()
            solver_netG_BA.update()

            # Monitors
            monitor_time.add(i)
            monitors_G.add(i)
            monitors_D.add(i)

            i += 1

        # Per-epoch visualization of real / fake / reconstructed maps.
        images_to_visualize = [
            real_bod_map_A.d, fake_bod_map_B.d, real_bod_map_B.d
        ]
        if with_cycle_loss:
            images_to_visualize.extend(
                [recon_bod_map_A.d, fake_bod_map_A.d, recon_bod_map_B.d])
        else:
            images_to_visualize.extend([fake_bod_map_A.d])
        visuals = combine_images(images_to_visualize)
        monitor_vis.add(i, visuals)

        if e % config["monitor"]["save_interval"] == 0 or e == epoch - 1:
            # Save parameters of networks
            netG_B2A_save_path = os.path.join(monitor._save_path,
                                              f'netG_transformer_B2A_{e}.h5')
            netG_A2B_save_path = os.path.join(monitor._save_path,
                                              f'netG_transformer_A2B_{e}.h5')
            with nn.parameter_scope('netG_transformer'):
                with nn.parameter_scope('netG_A2B'):
                    nn.save_parameters(netG_A2B_save_path)
                with nn.parameter_scope('netG_B2A'):
                    nn.save_parameters(netG_B2A_save_path)

            netD_A_save_path = os.path.join(monitor._save_path,
                                            f'netD_transformer_A_{e}.h5')
            netD_B_save_path = os.path.join(monitor._save_path,
                                            f'netD_transformer_B_{e}.h5')
            with nn.parameter_scope('netD_transformer'):
                with nn.parameter_scope('netD_A'):
                    nn.save_parameters(netD_A_save_path)
                with nn.parameter_scope('netD_B'):
                    nn.save_parameters(netD_B_save_path)
def train(config, train_iterator, valid_iterator, monitor):
    """Train the stacked-hourglass heatmap regressor.

    Builds a training graph (batch statistics on) and a validation graph
    (batch statistics off), optimizes with Adam plus weight decay, logs
    per-stack losses and sample images through ``monitor``, and saves
    parameters whenever the validation loss improves or at the configured
    save interval.

    Parameters
    ----------
    config : dict
        Training configuration (model hyperparameters, lr, epochs,
        monitor intervals, optional finetune checkpoint).
    train_iterator, valid_iterator :
        Data iterators; each ``next()`` yields an (image, heatmap) pair.
    monitor : nnabla.monitor.Monitor
        Monitor used for logging; its ``_save_path`` is used for outputs.

    Raises
    ------
    FileNotFoundError
        If ``config["finetune"]["param_path"]`` is set but does not exist.
    NotImplementedError
        If ``config["loss_name"]`` is neither 'mse' nor 'bce'.
    """

    ################### Graph Construction ####################
    # Training graph: fetch one batch only to obtain the Variable shapes.
    img, htm = train_iterator.next()
    image = nn.Variable(img.shape)
    heatmap = nn.Variable(htm.shape)

    with nn.parameter_scope(config["model_name"]):
        preds = stacked_hourglass_net(
            image,
            batch_stat=True,
            planes=config["model"]["planes"],
            output_nc=config["model"]["output_nc"],
            num_stacks=config["model"]["num_stacks"],
            activation=config["model"]["activation"],
        )

    if config["finetune"]:
        # BUGFIX: the original wrote `os.path.isfile(...), "..."` — a bare
        # tuple expression whose result was discarded, so a missing
        # checkpoint was silently ignored. Fail fast instead.
        if not os.path.isfile(config["finetune"]["param_path"]):
            raise FileNotFoundError(
                f"params file not found: {config['finetune']['param_path']}")
        with nn.parameter_scope(config["model_name"]):
            nn.load_parameters(config["finetune"]["param_path"])

    # Loss Definition: one intermediate-supervision loss per hourglass stack.
    if config["loss_name"] == 'mse':
        def loss_func(pred, target): return F.mean(
            F.squared_error(pred, target))
    elif config["loss_name"] == 'bce':
        def loss_func(pred, target): return F.mean(
            F.binary_cross_entropy(pred, target))
    else:
        raise NotImplementedError

    losses = []
    for pred in preds:
        loss_local = loss_func(pred, heatmap)
        loss_local.persistent = True  # keep per-stack values for monitoring
        losses.append(loss_local)

    # Total loss = sum of per-stack losses, seeded with a zero Variable.
    loss = nn.Variable()
    loss.d = 0
    for loss_local in losses:
        loss += loss_local

    ################### Setting Solvers ####################
    solver = S.Adam(config["train"]["lr"])
    with nn.parameter_scope(config["model_name"]):
        solver.set_parameters(nn.get_parameters())

    # Validation graph (batch_stat=False -> use accumulated BN statistics).
    img, htm = valid_iterator.next()
    val_image = nn.Variable(img.shape)
    val_heatmap = nn.Variable(htm.shape)

    with nn.parameter_scope(config["model_name"]):
        val_preds = stacked_hourglass_net(
            val_image,
            batch_stat=False,
            planes=config["model"]["planes"],
            output_nc=config["model"]["output_nc"],
            num_stacks=config["model"]["num_stacks"],
            activation=config["model"]["activation"],
        )

    for i in range(len(val_preds)):
        val_preds[i].persistent = True  # needed later for visualization

    # Validation loss definition, mirroring the training losses.
    val_losses = []
    for pred in val_preds:
        loss_local = loss_func(pred, val_heatmap)
        loss_local.persistent = True
        val_losses.append(loss_local)

    val_loss = nn.Variable()
    val_loss.d = 0
    for loss_local in val_losses:
        val_loss += loss_local

    num_train_batches = train_iterator.size // train_iterator.batch_size + 1
    num_valid_batches = valid_iterator.size // valid_iterator.batch_size + 1

    ################### Create Monitors ####################
    monitors_train_dict = {'loss_total': loss}

    for i in range(len(losses)):
        monitors_train_dict.update({f'loss_{i}': losses[i]})

    monitors_val_dict = {'val_loss_total': val_loss}

    for i in range(len(val_losses)):
        monitors_val_dict.update({f'val_loss_{i}': val_losses[i]})

    monitors_train = MonitorManager(
        monitors_train_dict, monitor, interval=config["monitor"]["interval"]*num_train_batches)
    monitors_val = MonitorManager(
        monitors_val_dict, monitor, interval=config["monitor"]["interval"]*num_valid_batches)
    monitor_time = nm.MonitorTimeElapsed(
        'time', monitor, interval=config["monitor"]["interval"]*num_train_batches)
    monitor_vis = nm.MonitorImage(
        'result', monitor, interval=1, num_images=4, normalize_method=lambda x: x)
    monitor_vis_val = nm.MonitorImage(
        'result_val', monitor, interval=1, num_images=4, normalize_method=lambda x: x)

    # exist_ok so a restart into the same save path does not crash
    # (the original os.mkdir raised FileExistsError on re-runs).
    os.makedirs(os.path.join(monitor._save_path, 'model'), exist_ok=True)

    # Dump training information
    with open(os.path.join(monitor._save_path, "training_info.yaml"), "w", encoding="utf-8") as f:
        f.write(yaml.dump(config))

    # Training
    best_epoch = 0
    best_val_loss = np.inf

    for e in range(config["train"]["epochs"]):
        watch_val_loss = 0

        # training loop
        for i in range(num_train_batches):
            image.d, heatmap.d = train_iterator.next()

            solver.zero_grad()
            loss.forward()
            loss.backward(clear_buffer=True)
            solver.weight_decay(config["train"]["weight_decay"])
            solver.update()

            monitors_train.add(e*num_train_batches + i)
            monitor_time.add(e*num_train_batches + i)

        # validation loop
        for i in range(num_valid_batches):
            val_image.d, val_heatmap.d = valid_iterator.next()
            val_loss.forward(clear_buffer=True)
            monitors_val.add(e*num_valid_batches + i)

            watch_val_loss += val_loss.d.copy()

        watch_val_loss /= num_valid_batches

        # visualization
        visuals = combine_images([image.d, preds[0].d, preds[1].d, heatmap.d])
        monitor_vis.add(e, visuals)

        visuals_val = combine_images(
            [val_image.d, val_preds[0].d, val_preds[1].d, val_heatmap.d])
        monitor_vis_val.add(e, visuals_val)

        # BUGFIX: the original updated best_val_loss/best_epoch whenever
        # EITHER the loss improved OR the save interval hit, so "best"
        # could be overwritten with a worse loss. Track the best only on a
        # genuine improvement; save a checkpoint on either event.
        improved = watch_val_loss < best_val_loss
        if improved:
            best_val_loss = watch_val_loss
            best_epoch = e
        if improved or e % config["monitor"]["save_interval"] == 0:
            save_path = os.path.join(
                monitor._save_path, 'model/model_epoch-{}.h5'.format(e))
            with nn.parameter_scope(config["model_name"]):
                nn.save_parameters(save_path)

    # save the last parameters as well
    save_path = os.path.join(
        monitor._save_path, 'model/model_epoch-{}.h5'.format(e))
    with nn.parameter_scope(config["model_name"]):
        nn.save_parameters(save_path)

    logger.info(f'Best Epoch: {best_epoch}.')