Example #1
def init_nnabla(conf=None, ext_name=None, device_id=None, type_config=None):
    import nnabla as nn
    from nnabla.ext_utils import get_extension_context
    from .comm import CommunicatorWrapper
    if conf is None:
        conf = AttrDict()  # AttrDict: attribute-access dict helper defined alongside this function
    if ext_name is not None:
        conf.ext_name = ext_name
    if device_id is not None:
        conf.device_id = device_id
    if type_config is not None:
        conf.type_config = type_config

    # set context
    ctx = get_extension_context(ext_name=conf.ext_name,
                                device_id=conf.device_id,
                                type_config=conf.type_config)

    # init communicator
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)

    # disable outputs from logger except rank==0
    if comm.rank > 0:
        from nnabla import logger
        import logging

        logger.setLevel(logging.ERROR)

    return comm
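
A minimal usage sketch for this helper (the extension name, device id, and type config below are illustrative values, not taken from the original; a CUDA/cuDNN build of nnabla is assumed):

comm = init_nnabla(ext_name="cudnn", device_id="0", type_config="float")
print(comm.rank, comm.size)  # per-process rank and world size for multi-GPU runs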
Example #2
def init_nnabla(ctx_config):
    import nnabla as nn
    from nnabla.ext_utils import get_extension_context
    from comm import CommunicatorWrapper

    # set context
    ctx = get_extension_context(**ctx_config)

    # init communicator
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)

    # disable outputs from logger except rank==0
    if comm.rank > 0:
        from nnabla import logger
        import logging

        logger.setLevel(logging.ERROR)

    return comm
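
This dict-based variant can be driven straight from a config file, since its keys mirror get_extension_context's keyword arguments; a sketch with illustrative values (not from the original):

ctx_config = {"ext_name": "cudnn", "device_id": "0", "type_config": "float"}
comm = init_nnabla(ctx_config)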
Example #3
def animate(args):

    # get context
    ctx = get_extension_context(args.context)
    nn.set_default_context(ctx)
    logger.setLevel(logging.ERROR)  # to suppress minor messages

    if not args.config:
        assert not args.params, "a pretrained weights file was given, but the corresponding config file was not. Please give both."
        download_provided_file(
            "https://nnabla.org/pretrained-models/nnabla-examples/GANs/first-order-model/voxceleb_trained_info.yaml"
        )
        args.config = 'voxceleb_trained_info.yaml'

        download_provided_file(
            "https://nnabla.org/pretrained-models/nnabla-examples/GANs/first-order-model/pretrained_fomm_params.h5"
        )

    config = read_yaml(args.config)

    dataset_params = config.dataset_params
    model_params = config.model_params

    if args.detailed:
        vis_params = config.visualizer_params
        visualizer = Visualizer(**vis_params)

    if not args.params:
        assert "log_dir" in config, "no log_dir found in config, so the pretrained parameters cannot be located."
        param_file = os.path.join(config.log_dir, config.saved_parameters)
    else:
        param_file = args.params
    print(f"Loading {param_file} for image animation...")
    nn.load_parameters(param_file)

    bs, h, w, c = [1] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving_initial = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    filename = args.driving

    # load the driving video
    driving_video = read_video(
        filename, dataset_params.frame_shape)  # (#frames, h, w, 3)
    driving_video = np.transpose(driving_video,
                                 (0, 3, 1, 2))  # (#frames, 3, h, w)

    source_img = imread(args.source, channel_first=True,
                        size=(256, 256)) / 255.
    source_img = source_img[:3]

    source.d = np.expand_dims(source_img, 0)
    driving_initial.d = driving_video[0][:3, ]

    with nn.parameter_scope("kp_detector"):
        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=True,
                                    comm=False)
        persistent_all(kp_source)

    with nn.parameter_scope("kp_detector"):
        kp_driving_initial = detect_keypoint(driving_initial,
                                             **model_params.kp_detector_params,
                                             **model_params.common_params,
                                             test=True,
                                             comm=False)
        persistent_all(kp_driving_initial)

    with nn.parameter_scope("kp_detector"):
        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=True,
                                     comm=False)
        persistent_all(kp_driving)

    if args.adapt_movement_scale:
        nn.forward_all([
            kp_source["value"], kp_source["jacobian"],
            kp_driving_initial["value"], kp_driving_initial["jacobian"]
        ])
        source_area = ConvexHull(kp_source['value'].d[0]).volume
        driving_area = ConvexHull(kp_driving_initial['value'].d[0]).volume
        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
    else:
        adapt_movement_scale = 1

    kp_norm = adjust_kp(kp_source=unlink_all(kp_source),
                        kp_driving=kp_driving,
                        kp_driving_initial=unlink_all(kp_driving_initial),
                        adapt_movement_scale=adapt_movement_scale,
                        use_relative_movement=args.unuse_relative_movement,
                        use_relative_jacobian=args.unuse_relative_jacobian)
    persistent_all(kp_norm)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=unlink_all(kp_source),
                                              kp_driving=kp_norm,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=True,
                                              comm=False)

    if not args.full and 'sparse_deformed' in generated:
        del generated['sparse_deformed']  # remove needless info

    persistent_all(generated)

    generated['kp_driving'] = kp_driving
    generated['kp_source'] = kp_source
    generated['kp_norm'] = kp_norm

    # generated contains these values:
    # 'mask': Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
    # 'sparse_deformed': Variable((bs, num_kp+1, num_channel, h/4, w/4))
    # 'occlusion_map': Variable((bs, 1, h/4, w/4))
    # 'deformed': Variable((bs, c, h, w))
    # 'prediction': Variable((bs, c, h, w))

    mode = "arbitrary"
    if "log_dir" in config:
        result_dir = os.path.join(args.out_dir,
                                  os.path.basename(config.log_dir), f"{mode}")
    else:
        result_dir = os.path.join(args.out_dir, "test_result", f"{mode}")

    # create an empty directory to save generated results
    _ = nm.Monitor(result_dir)

    # load the header images.
    header = imread("imgs/header_combined.png", channel_first=True)
    generated_images = list()

    # compute these in advance and reuse
    nn.forward_all([kp_source["value"], kp_source["jacobian"]],
                   clear_buffer=True)
    nn.forward_all(
        [kp_driving_initial["value"], kp_driving_initial["jacobian"]],
        clear_buffer=True)

    num_of_driving_frames = driving_video.shape[0]

    for frame_idx in tqdm(range(num_of_driving_frames)):
        driving.d = driving_video[frame_idx][:3, ]
        nn.forward_all([generated["prediction"], generated["deformed"]],
                       clear_buffer=True)

        if args.detailed:
            # visualize source w/kp, driving w/kp, deformed source, generated w/kp, generated image, occlusion map
            visualization = visualizer.visualize(source=source.d,
                                                 driving=driving.d,
                                                 out=generated)
            if args.full:
                visualization = reshape_result(visualization)  # (H, W, C)
            combined_image = visualization.transpose(2, 0, 1)  # (C, H, W)

        elif args.only_generated:
            combined_image = np.clip(generated["prediction"].d[0], 0.0, 1.0)
            combined_image = (255 * combined_image).astype(
                np.uint8)  # (C, H, W)

        else:
            # visualize source, driving, and generated image
            driving_fake = np.concatenate([
                np.clip(driving.d[0], 0.0, 1.0),
                np.clip(generated["prediction"].d[0], 0.0, 1.0)
            ],
                                          axis=2)
            header_source = np.concatenate([
                np.clip(header / 255., 0.0, 1.0),
                np.clip(source.d[0], 0.0, 1.0)
            ],
                                           axis=2)
            combined_image = np.concatenate([header_source, driving_fake],
                                            axis=1)
            combined_image = (255 * combined_image).astype(np.uint8)

        generated_images.append(combined_image)

    # once each video is generated, save it.
    output_filename = f"{os.path.splitext(os.path.basename(filename))[0]}.mp4"
    output_filename = f"{os.path.basename(args.source)}_by_{output_filename}"
    output_filename = output_filename.replace("#", "_")
    if args.output_png:
        monitor_vis = nm.MonitorImage(output_filename,
                                      nm.Monitor(result_dir),
                                      interval=1,
                                      num_images=1,
                                      normalize_method=lambda x: x)
        for frame_idx, img in enumerate(generated_images):
            monitor_vis.add(frame_idx, img)
    else:
        generated_images = [_.transpose(1, 2, 0) for _ in generated_images]
        # you might need to change ffmpeg_params according to your environment.
        mimsave(f'{os.path.join(result_dir, output_filename)}',
                generated_images,
                fps=args.fps,
                ffmpeg_params=[
                    "-pix_fmt", "yuv420p", "-vcodec", "libx264", "-f", "mp4",
                    "-q", "0"
                ])

    return
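
animate() reads everything from args; below is a minimal argparse sketch covering the attributes the function touches. The flag names and defaults are assumptions for illustration, not taken from the original repository (note the unuse_* flags default to True and are switched off when passed, matching how they are forwarded to adjust_kp):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--context", default="cudnn")
parser.add_argument("--config", default=None)
parser.add_argument("--params", default=None)
parser.add_argument("--source", required=True)   # path to the source image
parser.add_argument("--driving", required=True)  # path to the driving video
parser.add_argument("--out-dir", dest="out_dir", default="result")
parser.add_argument("--fps", type=int, default=24)
parser.add_argument("--detailed", action="store_true")
parser.add_argument("--full", action="store_true")
parser.add_argument("--only-generated", dest="only_generated", action="store_true")
parser.add_argument("--output-png", dest="output_png", action="store_true")
parser.add_argument("--adapt-movement-scale", dest="adapt_movement_scale", action="store_true")
parser.add_argument("--unuse-relative-movement", dest="unuse_relative_movement", action="store_false")
parser.add_argument("--unuse-relative-jacobian", dest="unuse_relative_jacobian", action="store_false")
animate(parser.parse_args())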
Example #4
def reconstruct(args):

    # get context
    ctx = get_extension_context(args.context)
    nn.set_default_context(ctx)
    logger.setLevel(logging.ERROR)  # to suppress minor messages

    config = read_yaml(args.config)

    dataset_params = config.dataset_params
    model_params = config.model_params

    if args.detailed:
        vis_params = config.visualizer_params
        visualizer = Visualizer(**vis_params)

    if not args.params:
        assert "log_dir" in config, "no log_dir found in config, so the pretrained parameters cannot be located."
        param_file = os.path.join(
            config.log_dir, config.saved_parameters)
    else:
        param_file = args.params
    nn.load_parameters(param_file)

    bs, h, w, c = [1] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving_initial = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    with nn.parameter_scope("kp_detector"):
        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=True, comm=False)
        persistent_all(kp_source)

    with nn.parameter_scope("kp_detector"):
        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=True, comm=False)
        persistent_all(kp_driving)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=unlink_all(kp_source),
                                              kp_driving=kp_driving,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=True, comm=False)

    if not args.full and 'sparse_deformed' in generated:
        del generated['sparse_deformed']  # remove needless info

    persistent_all(generated)

    generated['kp_driving'] = kp_driving
    generated['kp_source'] = kp_source

    # generated contains these values:
    # 'mask': Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
    # 'sparse_deformed': Variable((bs, num_kp+1, num_channel, h/4, w/4))
    # 'occlusion_map': Variable((bs, 1, h/4, w/4))
    # 'deformed': Variable((bs, c, h, w))
    # 'prediction': Variable((bs, c, h, w))

    mode = "reconstruction"
    if "log_dir" in config:
        result_dir = os.path.join(args.out_dir, os.path.basename(config.log_dir), f"{mode}")
    else:
        result_dir = os.path.join(args.out_dir, "test_result", f"{mode}")

    # create an empty directory to save generated results
    _ = nm.Monitor(result_dir)
    if args.eval:
        os.makedirs(os.path.join(result_dir, "png"), exist_ok=True)

    # load the header images.
    header = imread("imgs/header_combined.png", channel_first=True)

    filenames = sorted(glob.glob(os.path.join(
        dataset_params.root_dir, "test", "*")))
    recon_loss_list = list()

    for filename in tqdm(filenames):
        # repeat until all the test data is used
        driving_video = read_video(
            filename, dataset_params.frame_shape)  # (#frames, h, w, 3)
        driving_video = np.transpose(
            driving_video, (0, 3, 1, 2))  # (#frames, 3, h, w)

        generated_images = list()
        source_img = driving_video[0]

        source.d = np.expand_dims(source_img, 0)
        driving_initial.d = driving_video[0]

        # compute these in advance and reuse
        nn.forward_all(
            [kp_source["value"], kp_source["jacobian"]], clear_buffer=True)

        num_of_driving_frames = driving_video.shape[0]

        for frame_idx in tqdm(range(num_of_driving_frames)):
            driving.d = driving_video[frame_idx]
            nn.forward_all([generated["prediction"],
                            generated["deformed"]], clear_buffer=True)

            if args.detailed:
                # visualize source w/kp, driving w/kp, deformed source, generated w/kp, generated image, occlusion map
                visualization = visualizer.visualize(
                    source=source.d, driving=driving.d, out=generated)
                if args.full:
                    visualization = reshape_result(visualization)  # (H, W, C)
                combined_image = visualization.transpose(2, 0, 1)  # (C, H, W)

            elif args.only_generated:
                combined_image = np.clip(
                    generated["prediction"].d[0], 0.0, 1.0)
                combined_image = (
                    255*combined_image).astype(np.uint8)  # (C, H, W)

            else:
                # visualize source, driving, and generated image
                driving_fake = np.concatenate([np.clip(driving.d[0], 0.0, 1.0),
                                               np.clip(generated["prediction"].d[0], 0.0, 1.0)], axis=2)
                header_source = np.concatenate([np.clip(header / 255., 0.0, 1.0),
                                                np.clip(source.d[0], 0.0, 1.0)], axis=2)
                combined_image = np.concatenate(
                    [header_source, driving_fake], axis=1)
                combined_image = (255*combined_image).astype(np.uint8)

            generated_images.append(combined_image)
            # compute L1 distance per frame.
            recon_loss_list.append(
                np.mean(np.abs(generated["prediction"].d[0] - driving.d[0])))

        # post process only for reconstruction evaluation.
        if args.eval:
            # crop generated images region only.
            if args.only_generated:
                eval_images = generated_images
            elif args.full:
                eval_images = [_[:, :h, 4*w:5*w] for _ in generated_images]
            elif args.detailed:
                assert generated_images[0].shape == (c, h, 5*w)
                eval_images = [_[:, :, 3*w:4*w] for _ in generated_images]
            else:
                eval_images = [_[:, h:, w:] for _ in generated_images]
            # place them horizontally and save for evaluation.
            image_for_eval = np.concatenate(
                eval_images, axis=2).transpose(1, 2, 0)
            imsave(os.path.join(result_dir, "png", f"{os.path.basename(filename)}.png"),
                   image_for_eval)

        # once each video is generated, save it.
        output_filename = f"{os.path.splitext(os.path.basename(filename))[0]}.mp4"
        if args.output_png:
            monitor_vis = nm.MonitorImage(output_filename, nm.Monitor(result_dir),
                                          interval=1, num_images=1,
                                          normalize_method=lambda x: x)
            for frame_idx, img in enumerate(generated_images):
                monitor_vis.add(frame_idx, img)
        else:
            generated_images = [_.transpose(1, 2, 0) for _ in generated_images]
            # you might need to change ffmpeg_params according to your environment.
            mimsave(f'{os.path.join(result_dir, output_filename)}', generated_images,
                    fps=args.fps,
                    ffmpeg_params=["-pix_fmt", "yuv420p",
                                   "-vcodec", "libx264",
                                   "-f", "mp4",
                                   "-q", "0"])
    print(f"Reconstruction loss: {np.mean(recon_loss_list)}")

    return
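
The reconstruction score printed above is the mean of per-frame L1 distances between the generated prediction and the driving frame; a standalone sketch of the same metric on dummy data (array shapes are illustrative):

import numpy as np

pred = np.random.rand(3, 256, 256)    # stand-in for generated["prediction"].d[0]
target = np.random.rand(3, 256, 256)  # stand-in for driving.d[0]
l1 = np.mean(np.abs(pred - target))   # one entry of recon_loss_list
print(f"per-frame L1: {l1}")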
Example #5
def train(args):

    # get context
    ctx = get_extension_context(args.context)
    comm = C.MultiProcessDataParalellCommunicator(ctx)  # "Paralell" is nnabla's own spelling of this API
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    config = read_yaml(args.config)

    if args.info:
        config.monitor_params.info = args.info

    if comm.size == 1:
        comm = None
    else:
        # disable logger outputs except on rank 0
        if comm.rank > 0:
            import logging
            logger.setLevel(logging.ERROR)

    test = False
    train_params = config.train_params
    dataset_params = config.dataset_params
    model_params = config.model_params

    loss_flags = get_loss_flags(train_params)

    start_epoch = 0

    rng = np.random.RandomState(device_id)
    data_iterator = frame_data_iterator(
        root_dir=dataset_params.root_dir,
        frame_shape=dataset_params.frame_shape,
        id_sampling=dataset_params.id_sampling,
        is_train=True,
        random_seed=rng,
        augmentation_params=dataset_params.augmentation_params,
        batch_size=train_params['batch_size'],
        shuffle=True,
        with_memory_cache=False,
        with_file_cache=False)

    if n_devices > 1:
        data_iterator = data_iterator.slice(rng=rng,
                                            num_of_slices=comm.size,
                                            slice_pos=comm.rank)
        # workaround to avoid using the on-memory data cache
        data_iterator._data_source._on_memory = False
        logger.info("Disabled on memory data cache.")

    bs, h, w, c = [train_params.batch_size] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    with nn.parameter_scope("kp_detector"):
        # kp_X = {"value": Variable((bs, 10, 2)), "jacobian": Variable((bs, 10, 2, 2))}

        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=test,
                                    comm=comm)
        persistent_all(kp_source)

        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=test,
                                     comm=comm)
        persistent_all(kp_driving)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=kp_source,
                                              kp_driving=kp_driving,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=test,
                                              comm=comm)
        # generated is a dictionary containing;
        # 'mask': Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
        # 'sparse_deformed': Variable((bs, num_kp + 1, num_channel, h/4, w/4))
        # 'occlusion_map': Variable((bs, 1, h/4, w/4))
        # 'deformed': Variable((bs, c, h, w))
        # 'prediction': Variable((bs, c, h, w)) Only this is fed to discriminator.

    generated["prediction"].persistent = True

    pyramide_real = get_image_pyramid(driving, train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_real)

    pyramide_fake = get_image_pyramid(generated['prediction'],
                                      train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_fake)

    total_loss_G = None  # dummy; properly assigned below
    loss_var_dict = {}

    # perceptual loss using VGG19 (always applied)
    if loss_flags.use_perceptual_loss:
        logger.info("Use Perceptual Loss.")
        scales = train_params.scales
        weights = train_params.loss_weights.perceptual
        vgg_param_path = train_params.vgg_param_path
        percep_loss = perceptual_loss(pyramide_real, pyramide_fake, scales,
                                      weights, vgg_param_path)
        percep_loss.persistent = True
        loss_var_dict['perceptual_loss'] = percep_loss
        total_loss_G = percep_loss

    # (LS)GAN loss and feature matching loss
    if loss_flags.use_gan_loss:
        logger.info("Use GAN Loss.")
        with nn.parameter_scope("discriminator"):
            discriminator_maps_generated = multiscale_discriminator(
                pyramide_fake,
                kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test,
                comm=comm)

            discriminator_maps_real = multiscale_discriminator(
                pyramide_real,
                kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test,
                comm=comm)

        for v in discriminator_maps_generated["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_generated["prediction_map_1"].persistent = True

        for v in discriminator_maps_real["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_real["prediction_map_1"].persistent = True

        for i, scale in enumerate(model_params.discriminator_params.scales):
            key = f'prediction_map_{scale}'.replace('.', '-')
            lsgan_loss_weight = train_params.loss_weights.generator_gan
            # LSGAN loss for Generator
            if i == 0:
                gan_loss_gen = lsgan_loss(discriminator_maps_generated[key],
                                          lsgan_loss_weight)
            else:
                gan_loss_gen += lsgan_loss(discriminator_maps_generated[key],
                                           lsgan_loss_weight)
            # LSGAN loss for Discriminator
            if i == 0:
                gan_loss_dis = lsgan_loss(discriminator_maps_real[key],
                                          lsgan_loss_weight,
                                          discriminator_maps_generated[key])
            else:
                gan_loss_dis += lsgan_loss(discriminator_maps_real[key],
                                           lsgan_loss_weight,
                                           discriminator_maps_generated[key])
        gan_loss_dis.persistent = True
        loss_var_dict['gan_loss_dis'] = gan_loss_dis
        total_loss_D = gan_loss_dis
        total_loss_D.persistent = True

        gan_loss_gen.persistent = True
        loss_var_dict['gan_loss_gen'] = gan_loss_gen
        total_loss_G += gan_loss_gen

        if loss_flags.use_feature_matching_loss:
            logger.info("Use Feature Matching Loss.")
            fm_weights = train_params.loss_weights.feature_matching
            fm_loss = feature_matching_loss(discriminator_maps_real,
                                            discriminator_maps_generated,
                                            model_params, fm_weights)
            fm_loss.persistent = True
            loss_var_dict['feature_matching_loss'] = fm_loss
            total_loss_G += fm_loss

    # transform loss
    if loss_flags.use_equivariance_value_loss or loss_flags.use_equivariance_jacobian_loss:
        transform = Transform(bs, **config.train_params.transform_params)
        transformed_frame = transform.transform_frame(driving)

        with nn.parameter_scope("kp_detector"):
            transformed_kp = detect_keypoint(transformed_frame,
                                             **model_params.kp_detector_params,
                                             **model_params.common_params,
                                             test=test,
                                             comm=comm)
        persistent_all(transformed_kp)

        # Value loss part
        if loss_flags.use_equivariance_value_loss:
            logger.info("Use Equivariance Value Loss.")
            warped_kp_value = transform.warp_coordinates(
                transformed_kp['value'])
            eq_value_weight = train_params.loss_weights.equivariance_value

            eq_value_loss = equivariance_value_loss(kp_driving['value'],
                                                    warped_kp_value,
                                                    eq_value_weight)
            eq_value_loss.persistent = True
            loss_var_dict['equivariance_value_loss'] = eq_value_loss
            total_loss_G += eq_value_loss

        # jacobian loss part
        if loss_flags.use_equivariance_jacobian_loss:
            logger.info("Use Equivariance Jacobian Loss.")
            arithmetic_jacobian = transform.jacobian(transformed_kp['value'])
            eq_jac_weight = train_params.loss_weights.equivariance_jacobian
            eq_jac_loss = equivariance_jacobian_loss(
                kp_driving['jacobian'], arithmetic_jacobian,
                transformed_kp['jacobian'], eq_jac_weight)
            eq_jac_loss.persistent = True
            loss_var_dict['equivariance_jacobian_loss'] = eq_jac_loss
            total_loss_G += eq_jac_loss

    assert total_loss_G is not None
    total_loss_G.persistent = True
    loss_var_dict['total_loss_gen'] = total_loss_G

    # -------------------- Create Monitors --------------------
    monitors_gen, monitors_dis, monitor_time, monitor_vis, log_dir = get_monitors(
        config, loss_flags, loss_var_dict)

    if device_id == 0:
        # Dump training info .yaml
        _ = shutil.copy(args.config, log_dir)  # copy the config yaml
        training_info_yaml = os.path.join(log_dir, "training_info.yaml")
        os.rename(os.path.join(log_dir, os.path.basename(args.config)),
                  training_info_yaml)
        # then add additional information
        with open(training_info_yaml, "a", encoding="utf-8") as f:
            f.write(f"\nlog_dir: {log_dir}\nsaved_parameter: None")

    # -------------------- Solver Setup --------------------
    solvers = setup_solvers(train_params)
    solver_generator = solvers["generator"]
    solver_discriminator = solvers["discriminator"]
    solver_kp_detector = solvers["kp_detector"]

    # max epochs
    num_epochs = train_params['num_epochs']

    # iteration per epoch
    num_iter_per_epoch = data_iterator.size // bs
    # scaled up by num_repeats when configured
    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        num_iter_per_epoch *= config.train_params.num_repeats

    # decay the learning rate when the current epoch reaches one of the epochs below
    lr_decay_at_epochs = train_params['epoch_milestones']  # e.g. [60, 90]
    gamma = 0.1  # decay rate

    # -------------------- For finetuning ---------------------
    if args.ft_params:
        assert os.path.isfile(args.ft_params)
        logger.info(f"load {args.ft_params} for finetuning.")
        nn.load_parameters(args.ft_params)
        start_epoch = int(
            os.path.splitext(os.path.basename(
                args.ft_params))[0].split("epoch_")[1])

        # set solver's state
        for name, solver in solvers.items():
            saved_states = os.path.join(
                os.path.dirname(args.ft_params),
                f"state_{name}_at_epoch_{start_epoch}.h5")
            solver.load_states(saved_states)

        start_epoch += 1
        logger.info(f"Resuming from epoch {start_epoch}.")

    logger.info(
        f"Start training. Total epoch: {num_epochs - start_epoch}, {num_iter_per_epoch * n_devices} iter/epoch."
    )

    for e in range(start_epoch, num_epochs):
        logger.info(f"Epoch: {e} / {num_epochs}.")
        data_iterator._reset()  # rewind the iterator at the start of each epoch

        # learning rate scheduler
        if e in lr_decay_at_epochs:
            logger.info("Learning rate decayed.")
            learning_rate_decay(solvers, gamma=gamma)

        for i in range(num_iter_per_epoch):
            _driving, _source = data_iterator.next()
            source.d = _source
            driving.d = _driving

            # update generator and keypoint detector
            total_loss_G.forward()

            if device_id == 0:
                monitors_gen.add((e * num_iter_per_epoch + i) * n_devices)

            solver_generator.zero_grad()
            solver_kp_detector.zero_grad()

            callback = None
            if n_devices > 1:
                params = [x.grad for x in solver_generator.get_parameters().values()] + \
                         [x.grad for x in solver_kp_detector.get_parameters().values()]
                callback = comm.all_reduce_callback(params, 2 << 20)
            total_loss_G.backward(clear_buffer=True,
                                  communicator_callbacks=callback)

            solver_generator.update()
            solver_kp_detector.update()

            if loss_flags.use_gan_loss:
                # update discriminator

                total_loss_D.forward(clear_no_need_grad=True)
                if device_id == 0:
                    monitors_dis.add((e * num_iter_per_epoch + i) * n_devices)

                solver_discriminator.zero_grad()

                callback = None
                if n_devices > 1:
                    params = [
                        x.grad for x in
                        solver_discriminator.get_parameters().values()
                    ]
                    callback = comm.all_reduce_callback(params, 2 << 20)
                total_loss_D.backward(clear_buffer=True,
                                      communicator_callbacks=callback)

                solver_discriminator.update()

            if device_id == 0:
                monitor_time.add((e * num_iter_per_epoch + i) * n_devices)

            if device_id == 0 and (
                (e * num_iter_per_epoch + i) *
                    n_devices) % config.monitor_params.visualize_freq == 0:
                images_to_visualize = [
                    source.d, driving.d, generated["prediction"].d
                ]
                visuals = combine_images(images_to_visualize)
                monitor_vis.add((e * num_iter_per_epoch + i) * n_devices,
                                visuals)

        if device_id == 0:
            if e % train_params.checkpoint_freq == 0 or e == num_epochs - 1:
                save_parameters(e, log_dir, solvers)

    return
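
learning_rate_decay() is called above but not defined in this excerpt; a plausible sketch built on nnabla's Solver API (solver.learning_rate() and solver.set_learning_rate() are real nnabla methods, but the helper's body here is an assumption):

def learning_rate_decay(solvers, gamma=0.1):
    # scale each solver's current learning rate by gamma (hypothetical helper)
    for solver in solvers.values():
        solver.set_learning_rate(solver.learning_rate() * gamma)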