Example #1
def train(args):

    # Input shapes: batch size, channels, height, width.
    bs, ch, h, w = args.batch_size, 3, args.loadSizeH, args.loadSizeW

    # Determine normalization method.
    if args.norm == "instance":
        norm_layer = functools.partial(PF.instance_normalization,
                                       fix_parameters=True,
                                       no_bias=True,
                                       no_scale=True)
    else:
        norm_layer = PF.batch_normalization

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(models.generator,
                                  input_nc=args.input_nc,
                                  output_nc=args.output_nc,
                                  ngf=args.ngf,
                                  norm_layer=norm_layer,
                                  use_dropout=False,
                                  n_blocks=9,
                                  padding_type='reflect')
    discriminator = functools.partial(models.discriminator,
                                      input_nc=args.output_nc,
                                      ndf=args.ndf,
                                      n_layers=args.n_layers_D,
                                      norm_layer=norm_layer,
                                      use_sigmoid=False)

    # --------------------- Computation Graphs --------------------

    # Input images and masks of the source and target domains
    x = nn.Variable([bs, ch, h, w], need_grad=False)
    a = nn.Variable([bs, 1, h, w], need_grad=False)

    y = nn.Variable([bs, ch, h, w], need_grad=False)
    b = nn.Variable([bs, 1, h, w], need_grad=False)

    # Apply image augmentation and get an unlinked variable
    xa_aug = image_augmentation(args, x, a)
    xa_aug.persistent = True
    xa_aug_unlinked = xa_aug.get_unlinked_variable()

    yb_aug = image_augmentation(args, y, b)
    yb_aug.persistent = True
    yb_aug_unlinked = yb_aug.get_unlinked_variable()

    # Variables used for the image pool (history of generated images)
    x_history = nn.Variable([bs, ch, h, w])
    a_history = nn.Variable([bs, 1, h, w])
    y_history = nn.Variable([bs, ch, h, w])
    b_history = nn.Variable([bs, 1, h, w])

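    # NOTE: the unlinked copies above cut the graph so that the discriminator
    # update does not backpropagate into the generators; generator gradients
    # are later propagated manually via xa_fake.backward() / yb_fake.backward().
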
    # Generate Images (x -> y')
    with nn.parameter_scope("gen_x2y"):
        yb_fake = generator(xa_aug_unlinked)
    yb_fake.persistent = True
    yb_fake_unlinked = yb_fake.get_unlinked_variable()

    # Generate Images (y -> x')
    with nn.parameter_scope("gen_y2x"):
        xa_fake = generator(yb_aug_unlinked)
    xa_fake.persistent = True
    xa_fake_unlinked = xa_fake.get_unlinked_variable()

    # Reconstruct Images (y' -> x)
    with nn.parameter_scope("gen_y2x"):
        xa_recon = generator(yb_fake_unlinked)
    xa_recon.persistent = True

    # Reconstruct Images (x' -> y)
    with nn.parameter_scope("gen_x2y"):
        yb_recon = generator(xa_fake_unlinked)
    yb_recon.persistent = True

    # Use Discriminator on y' and x'
    with nn.parameter_scope("dis_y"):
        d_y_fake = discriminator(yb_fake_unlinked)
    d_y_fake.persistent = True

    with nn.parameter_scope("dis_x"):
        d_x_fake = discriminator(xa_fake_unlinked)
    d_x_fake.persistent = True

    # Use Discriminator on y and x
    with nn.parameter_scope("dis_y"):
        d_y_real = discriminator(yb_aug_unlinked)

    with nn.parameter_scope("dis_x"):
        d_x_real = discriminator(xa_aug_unlinked)

    # Identity Mapping (x -> x)
    with nn.parameter_scope("gen_y2x"):
        xa_idt = generator(xa_aug_unlinked)

    # Identity Mapping (y -> y)
    with nn.parameter_scope("gen_x2y"):
        yb_idt = generator(yb_aug_unlinked)

    # -------------------- Loss --------------------

    # (LS)GAN Loss (for Discriminator)
    loss_dis_y = (loss.lsgan_loss(d_y_fake, False) +
                  loss.lsgan_loss(d_y_real, True)) * 0.5
    loss_dis_x = (loss.lsgan_loss(d_x_fake, False) +
                  loss.lsgan_loss(d_x_real, True)) * 0.5
    loss_dis = loss_dis_x + loss_dis_y

    # Cycle Consistency Loss
    loss_cyc_x = args.lambda_cyc * loss.recon_loss(xa_recon, xa_aug_unlinked)
    loss_cyc_y = args.lambda_cyc * loss.recon_loss(yb_recon, yb_aug_unlinked)
    loss_cyc = loss_cyc_x + loss_cyc_y

    # Identity Mapping Loss
    loss_idt_x = args.lambda_idt * loss.recon_loss(xa_idt, xa_aug_unlinked)
    loss_idt_y = args.lambda_idt * loss.recon_loss(yb_idt, yb_aug_unlinked)
    loss_idt = loss_idt_x + loss_idt_y

    # Context Preserving Loss
    loss_ctx_x = args.lambda_ctx * \
        loss.context_preserving_loss(xa_aug_unlinked, yb_fake_unlinked)
    loss_ctx_y = args.lambda_ctx * \
        loss.context_preserving_loss(yb_aug_unlinked, xa_fake_unlinked)
    loss_ctx = loss_ctx_x + loss_ctx_y

    # (LS)GAN Loss (for Generator)
    d_loss_gen_x = loss.lsgan_loss(d_x_fake, True)
    d_loss_gen_y = loss.lsgan_loss(d_y_fake, True)
    d_loss_gen = d_loss_gen_x + d_loss_gen_y

    # Total Loss for Generator
    loss_gen = loss_cyc + loss_idt + loss_ctx + d_loss_gen

    # --------------------- Solvers --------------------

    # Initial learning rates
    G_lr = args.learning_rate_G
    #D_lr = args.learning_rate_D
    # Contrary to the description in the paper, D_lr is set equal to G_lr.
    D_lr = args.learning_rate_G

    # Define solvers
    solver_gen_x2y = S.Adam(G_lr, args.beta1, args.beta2)
    solver_gen_y2x = S.Adam(G_lr, args.beta1, args.beta2)
    solver_dis_x = S.Adam(D_lr, args.beta1, args.beta2)
    solver_dis_y = S.Adam(D_lr, args.beta1, args.beta2)

    # Set Parameters to each solver
    with nn.parameter_scope("gen_x2y"):
        solver_gen_x2y.set_parameters(nn.get_parameters())

    with nn.parameter_scope("gen_y2x"):
        solver_gen_y2x.set_parameters(nn.get_parameters())

    with nn.parameter_scope("dis_x"):
        solver_dis_x.set_parameters(nn.get_parameters())

    with nn.parameter_scope("dis_y"):
        solver_dis_y.set_parameters(nn.get_parameters())

    # Convenience functions for manipulating the solvers
    def solvers_zero_grad():
        # Zeroing Gradients of all solvers
        solver_gen_x2y.zero_grad()
        solver_gen_y2x.zero_grad()
        solver_dis_x.zero_grad()
        solver_dis_y.zero_grad()

    def solvers_set_learning_rate(new_D_lr, new_G_lr):
        # Update the learning rate of every solver
        solver_gen_x2y.set_learning_rate(new_G_lr)
        solver_gen_y2x.set_learning_rate(new_G_lr)
        solver_dis_x.set_learning_rate(new_D_lr)
        solver_dis_y.set_learning_rate(new_D_lr)

    # -------------------- Data Iterators --------------------

    ds_train_A = insta_gan_data_source(args,
                                       train=True,
                                       domain="A",
                                       shuffle=True)
    di_train_A = insta_gan_data_iterator(ds_train_A, args.batch_size)

    ds_train_B = insta_gan_data_source(args,
                                       train=True,
                                       domain="B",
                                       shuffle=True)
    di_train_B = insta_gan_data_iterator(ds_train_B, args.batch_size)

    # -------------------- Monitors --------------------

    monitoring_targets_dis = {
        'discriminator_loss_x': loss_dis_x,
        'discriminator_loss_y': loss_dis_y
    }
    monitors_dis = Monitors(args, monitoring_targets_dis)

    monitoring_targets_gen = {
        'generator_loss_x': d_loss_gen_x,
        'generator_loss_y': d_loss_gen_y,
        'reconstruction_loss_x': loss_cyc_x,
        'reconstruction_loss_y': loss_cyc_y,
        'identity_mapping_loss_x': loss_idt_x,
        'identity_mapping_loss_y': loss_idt_y,
        'context_preserving_loss_x': loss_ctx_x,
        'context_preserving_loss_y': loss_ctx_y
    }
    monitors_gen = Monitors(args, monitoring_targets_gen)

    monitor_time = MonitorTimeElapsed("Training_time",
                                      Monitor(args.monitor_path),
                                      args.log_step)

    # Training loop
    epoch = 0
    n_images = max([ds_train_B.size, ds_train_A.size])
    print("{} images exist.".format(n_images))
    max_iter = args.max_epoch * n_images // args.batch_size
    decay_epochs = args.max_epoch - args.lr_decay_start_epoch

    for i in range(max_iter):
        if i % (n_images // args.batch_size) == 0 and i > 0:
            # Epoch boundary: advance the epoch counter and decay the LR.
            epoch += 1
            print("epoch {}".format(epoch))
            if epoch >= args.lr_decay_start_epoch:
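                # Linearly anneal both learning rates towards zero over the
                # remaining (max_epoch - lr_decay_start_epoch) epochs.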
                new_D_lr = D_lr * \
                    (1.0 - max(0, epoch - args.lr_decay_start_epoch - 1) /
                     float(decay_epochs - 1))
                new_G_lr = G_lr * \
                    (1.0 - max(0, epoch - args.lr_decay_start_epoch - 1) /
                     float(decay_epochs - 1))
                solvers_set_learning_rate(new_D_lr, new_G_lr)
                print("Current learning rate for Discriminator: {}".format(
                    solver_dis_x.learning_rate()))
                print("Current learning rate for Generator: {}".format(
                    solver_gen_x2y.learning_rate()))

        # Get data
        x_data, a_data = di_train_A.next()
        y_data, b_data = di_train_B.next()
        x.d, a.d = x_data, a_data
        y.d, b.d = y_data, b_data

        solvers_zero_grad()

        # Image Augmentation
        nn.forward_all([xa_aug, yb_aug], clear_buffer=True)

        # Generate fake images
        nn.forward_all([xa_fake, yb_fake], clear_no_need_grad=True)

        # -------- Train Discriminator --------

        loss_dis.forward(clear_no_need_grad=True)
        monitors_dis.add(i)

        loss_dis.backward(clear_buffer=True)
        solver_dis_x.update()
        solver_dis_y.update()

        # -------- Train Generators --------

        # The gradients computed in the discriminator step remain, so reset them.
        xa_fake_unlinked.grad.zero()
        yb_fake_unlinked.grad.zero()
        solvers_zero_grad()

        loss_gen.forward(clear_no_need_grad=True)

        monitors_gen.add(i)
        monitor_time.add(i)

        loss_gen.backward(clear_buffer=True)
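        # loss_gen.backward() stops at the unlinked fake images; push the
        # gradients accumulated there back through the generators (grad=None
        # reuses the gradient values already stored in each variable).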
        xa_fake.backward(grad=None, clear_buffer=True)
        yb_fake.backward(grad=None, clear_buffer=True)
        solver_gen_x2y.update()
        solver_gen_y2x.update()

        if i % (n_images // args.batch_size) == 0:
            # save translation results after every epoch.
            save_images(args,
                        i,
                        xa_aug,
                        yb_fake,
                        domain="x",
                        reconstructed=xa_recon)
            save_images(args,
                        i,
                        yb_aug,
                        xa_fake,
                        domain="y",
                        reconstructed=yb_recon)

    # save pretrained parameters
    nn.save_parameters(os.path.join(args.model_save_path,
                                    'params_%06d.h5' % i))
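
A minimal sketch of a CLI entry point for this train() function, assuming an
argparse-based configuration; the flag names mirror a subset of the attributes
the function reads, and the default values below are illustrative guesses, not
the authors' settings:

import argparse

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--loadSizeH", type=int, default=240)
    parser.add_argument("--loadSizeW", type=int, default=360)
    parser.add_argument("--norm", type=str, default="instance")
    parser.add_argument("--ngf", type=int, default=64)
    parser.add_argument("--ndf", type=int, default=64)
    parser.add_argument("--n_layers_D", type=int, default=3)
    parser.add_argument("--lambda_cyc", type=float, default=10.0)
    parser.add_argument("--lambda_idt", type=float, default=10.0)
    parser.add_argument("--lambda_ctx", type=float, default=10.0)
    parser.add_argument("--learning_rate_G", type=float, default=2e-4)
    parser.add_argument("--beta1", type=float, default=0.5)
    parser.add_argument("--beta2", type=float, default=0.999)
    parser.add_argument("--max_epoch", type=int, default=200)
    parser.add_argument("--lr_decay_start_epoch", type=int, default=100)
    # ... remaining flags (input_nc, output_nc, monitor_path, etc.) omitted
    return parser.parse_args()

if __name__ == '__main__':
    train(get_args())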
Example #2
def train(args):

    input_photo = tf.placeholder(
        tf.float32, [args.batch_size, args.patch_size, args.patch_size, 3])
    input_superpixel = tf.placeholder(
        tf.float32, [args.batch_size, args.patch_size, args.patch_size, 3])
    input_cartoon = tf.placeholder(
        tf.float32, [args.batch_size, args.patch_size, args.patch_size, 3])
    # output => the generated (fake) cartoon image
    output = network.unet_generator(input_photo)
    # Refine the output with a guided filter, using the input photo as guidance.
    output = guided_filter(input_photo, output, r=1)

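    # Guided self-filtering (r=5) extracts smoothed, low-frequency "surface"
    # representations of the generated image and the reference cartoon.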
    blur_fake = guided_filter(output, output, r=5, eps=2e-1)
    blur_cartoon = guided_filter(input_cartoon, input_cartoon, r=5, eps=2e-1)

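    # A random color shift produces grayscale "texture" representations that
    # de-emphasize color and luminance.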
    gray_fake, gray_cartoon = utils.color_shift(output, input_cartoon)

    d_loss_gray, g_loss_gray = loss.lsgan_loss(network.disc_sn,
                                               gray_cartoon,
                                               gray_fake,
                                               scale=1,
                                               patch=True,
                                               name='disc_gray')
    d_loss_blur, g_loss_blur = loss.lsgan_loss(network.disc_sn,
                                               blur_cartoon,
                                               blur_fake,
                                               scale=1,
                                               patch=True,
                                               name='disc_blur')

    vgg_model = loss.Vgg19('vgg19_no_fc.npy')
    vgg_photo = vgg_model.build_conv4_4(input_photo)
    vgg_output = vgg_model.build_conv4_4(output)
    vgg_superpixel = vgg_model.build_conv4_4(input_superpixel)
    h, w, c = vgg_photo.get_shape().as_list()[1:]

    photo_loss = tf.reduce_mean(
        tf.losses.absolute_difference(vgg_photo, vgg_output)) / (h * w * c)
    superpixel_loss = tf.reduce_mean(
        tf.losses.absolute_difference(vgg_superpixel, vgg_output)) / (h * w * c)
    recon_loss = photo_loss + superpixel_loss
    tv_loss = loss.total_variation_loss(output)

    g_loss_total = 1e4 * tv_loss + 1e-1 * g_loss_blur + g_loss_gray + 2e2 * recon_loss
    d_loss_total = d_loss_blur + d_loss_gray

    all_vars = tf.trainable_variables()
    gene_vars = [var for var in all_vars if 'gene' in var.name]
    disc_vars = [var for var in all_vars if 'disc' in var.name]

    tf.summary.scalar('tv_loss', tv_loss)
    tf.summary.scalar('photo_loss', photo_loss)
    tf.summary.scalar('superpixel_loss', superpixel_loss)
    tf.summary.scalar('recon_loss', recon_loss)
    tf.summary.scalar('d_loss_gray', d_loss_gray)
    tf.summary.scalar('g_loss_gray', g_loss_gray)
    tf.summary.scalar('d_loss_blur', d_loss_blur)
    tf.summary.scalar('g_loss_blur', g_loss_blur)
    tf.summary.scalar('d_loss_total', d_loss_total)
    tf.summary.scalar('g_loss_total', g_loss_total)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):

        g_optim = tf.train.AdamOptimizer(args.adv_train_lr, beta1=0.5, beta2=0.99)\
                                        .minimize(g_loss_total, var_list=gene_vars)

        d_optim = tf.train.AdamOptimizer(args.adv_train_lr, beta1=0.5, beta2=0.99)\
                                        .minimize(d_loss_total, var_list=disc_vars)
    '''
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    '''
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_fraction)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    train_writer = tf.summary.FileWriter(args.save_dir + '/train_log')
    summary_op = tf.summary.merge_all()
    saver = tf.train.Saver(var_list=gene_vars, max_to_keep=20)

    with tf.device('/device:GPU:0'):

        sess.run(tf.global_variables_initializer())
        saver.restore(sess,
                      tf.train.latest_checkpoint('pretrain/saved_models'))

        face_photo_dir = 'dataset/photo_face'
        face_photo_list = utils.load_image_list(face_photo_dir)
        scenery_photo_dir = 'dataset/photo_scenery'
        scenery_photo_list = utils.load_image_list(scenery_photo_dir)

        face_cartoon_dir = 'dataset/cartoon_face'
        face_cartoon_list = utils.load_image_list(face_cartoon_dir)
        scenery_cartoon_dir = 'dataset/cartoon_scenery'
        scenery_cartoon_list = utils.load_image_list(scenery_cartoon_dir)

        for total_iter in tqdm(range(args.total_iter)):

            if np.mod(total_iter, 5) == 0:
                photo_batch = utils.next_batch(face_photo_list,
                                               args.batch_size)
                cartoon_batch = utils.next_batch(face_cartoon_list,
                                                 args.batch_size)
            else:
                photo_batch = utils.next_batch(scenery_photo_list,
                                               args.batch_size)
                cartoon_batch = utils.next_batch(scenery_cartoon_list,
                                                 args.batch_size)

            inter_out = sess.run(output,
                                 feed_dict={
                                     input_photo: photo_batch,
                                     input_superpixel: photo_batch,
                                     input_cartoon: cartoon_batch
                                 })
            '''
            Adaptive coloring has to be applied with the clip_by_value
            in the last layer of the generator network, which is not very
            stable. To stably reproduce our results, please use power=1.0
            and comment out the clip_by_value function in network.py first.
            If this works, then try adaptive coloring with clip_by_value.
            '''
            if args.use_enhance:
                superpixel_batch = utils.selective_adacolor(inter_out,
                                                            power=1.2)
            else:
                superpixel_batch = utils.simple_superpixel(inter_out,
                                                           seg_num=200)
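            # Either way, the superpixel batch acts as the "structure" target
            # that superpixel_loss matches against in VGG feature space.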

            _, g_loss, r_loss = sess.run(
                [g_optim, g_loss_total, recon_loss],
                feed_dict={
                    input_photo: photo_batch,
                    input_superpixel: superpixel_batch,
                    input_cartoon: cartoon_batch
                })

            _, d_loss, train_info = sess.run(
                [d_optim, d_loss_total, summary_op],
                feed_dict={
                    input_photo: photo_batch,
                    input_superpixel: superpixel_batch,
                    input_cartoon: cartoon_batch
                })

            train_writer.add_summary(train_info, total_iter)

            if np.mod(total_iter + 1, 50) == 0:

                print('Iter: {}, d_loss: {}, g_loss: {}, recon_loss: {}'.format(
                    total_iter, d_loss, g_loss, r_loss))
                if np.mod(total_iter + 1, 500) == 0:
                    saver.save(sess,
                               args.save_dir + '/saved_models/model',
                               write_meta_graph=False,
                               global_step=total_iter)

                    photo_face = utils.next_batch(face_photo_list,
                                                  args.batch_size)
                    cartoon_face = utils.next_batch(face_cartoon_list,
                                                    args.batch_size)
                    photo_scenery = utils.next_batch(scenery_photo_list,
                                                     args.batch_size)
                    cartoon_scenery = utils.next_batch(scenery_cartoon_list,
                                                       args.batch_size)

                    result_face = sess.run(output,
                                           feed_dict={
                                               input_photo: photo_face,
                                               input_superpixel: photo_face,
                                               input_cartoon: cartoon_face
                                           })

                    result_scenery = sess.run(output,
                                              feed_dict={
                                                  input_photo: photo_scenery,
                                                  input_superpixel:
                                                  photo_scenery,
                                                  input_cartoon:
                                                  cartoon_scenery
                                              })

                    utils.write_batch_image(
                        result_face, args.save_dir + '/images',
                        str(total_iter) + '_face_result.jpg', 4)
                    utils.write_batch_image(
                        photo_face, args.save_dir + '/images',
                        str(total_iter) + '_face_photo.jpg', 4)

                    utils.write_batch_image(
                        result_scenery, args.save_dir + '/images',
                        str(total_iter) + '_scenery_result.jpg', 4)
                    utils.write_batch_image(
                        photo_scenery, args.save_dir + '/images',
                        str(total_iter) + '_scenery_photo.jpg', 4)
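
Note that loss.lsgan_loss above returns the discriminator and generator losses
as a pair. A minimal sketch of such a helper in TF1 style, assuming the
discriminator callable accepts (x, scale, patch, name, reuse) as suggested by
the call sites; the actual implementation in the repository may differ:

def lsgan_loss(discriminator, real, fake, scale=1, patch=False, name='disc'):
    # Least-squares GAN: the discriminator pushes real outputs towards 1 and
    # fake outputs towards 0; the generator pushes fake outputs towards 1.
    d_real = discriminator(real, scale, patch=patch, name=name, reuse=False)
    d_fake = discriminator(fake, scale, patch=patch, name=name, reuse=True)
    d_loss = 0.5 * (tf.reduce_mean(tf.square(d_real - 1.0)) +
                    tf.reduce_mean(tf.square(d_fake)))
    g_loss = tf.reduce_mean(tf.square(d_fake - 1.0))
    return d_loss, g_loss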
Example #3
def train(args):

    # get context

    ctx = get_extension_context(args.context)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    config = read_yaml(args.config)

    if args.info:
        config.monitor_params.info = args.info

    if comm.size == 1:
        comm = None
    else:
        # disable logger output except on rank 0
        if comm.rank > 0:
            import logging
            logger.setLevel(logging.ERROR)

    test = False
    train_params = config.train_params
    dataset_params = config.dataset_params
    model_params = config.model_params

    loss_flags = get_loss_flags(train_params)

    start_epoch = 0

    rng = np.random.RandomState(device_id)
    data_iterator = frame_data_iterator(
        root_dir=dataset_params.root_dir,
        frame_shape=dataset_params.frame_shape,
        id_sampling=dataset_params.id_sampling,
        is_train=True,
        random_seed=rng,
        augmentation_params=dataset_params.augmentation_params,
        batch_size=train_params['batch_size'],
        shuffle=True,
        with_memory_cache=False,
        with_file_cache=False)

    if n_devices > 1:
        data_iterator = data_iterator.slice(rng=rng,
                                            num_of_slices=comm.size,
                                            slice_pos=comm.rank)
        # workaround to avoid using the on-memory data cache
        data_iterator._data_source._on_memory = False
        logger.info("Disabled on memory data cache.")

    bs, h, w, c = [train_params.batch_size] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    with nn.parameter_scope("kp_detector"):
        # kp_X = {"value": Variable((bs, 10, 2)), "jacobian": Variable((bs, 10, 2, 2))}

        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=test,
                                    comm=comm)
        persistent_all(kp_source)

        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=test,
                                     comm=comm)
        persistent_all(kp_driving)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=kp_source,
                                              kp_driving=kp_driving,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=test,
                                              comm=comm)
        # generated is a dictionary containing:
        # 'mask': Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
        # 'sparse_deformed': Variable((bs, num_kp + 1, num_channel, h/4, w/4))
        # 'occlusion_map': Variable((bs, 1, h/4, w/4))
        # 'deformed': Variable((bs, c, h, w))
        # 'prediction': Variable((bs, c, h, w)) Only this is fed to discriminator.

    generated["prediction"].persistent = True

    pyramide_real = get_image_pyramid(driving, train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_real)

    pyramide_fake = get_image_pyramid(generated['prediction'],
                                      train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_fake)

    total_loss_G = None  # dummy; defined properly below
    loss_var_dict = {}

    # perceptual loss using VGG19 (always applied)
    if loss_flags.use_perceptual_loss:
        logger.info("Use Perceptual Loss.")
        scales = train_params.scales
        weights = train_params.loss_weights.perceptual
        vgg_param_path = train_params.vgg_param_path
        percep_loss = perceptual_loss(pyramide_real, pyramide_fake, scales,
                                      weights, vgg_param_path)
        percep_loss.persistent = True
        loss_var_dict['perceptual_loss'] = percep_loss
        total_loss_G = percep_loss

    # (LS)GAN loss and feature matching loss
    if loss_flags.use_gan_loss:
        logger.info("Use GAN Loss.")
        with nn.parameter_scope("discriminator"):
            discriminator_maps_generated = multiscale_discriminator(
                pyramide_fake,
                kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test,
                comm=comm)

            discriminator_maps_real = multiscale_discriminator(
                pyramide_real,
                kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test,
                comm=comm)

        for v in discriminator_maps_generated["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_generated["prediction_map_1"].persistent = True

        for v in discriminator_maps_real["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_real["prediction_map_1"].persistent = True

        for i, scale in enumerate(model_params.discriminator_params.scales):
            key = f'prediction_map_{scale}'.replace('.', '-')
            lsgan_loss_weight = train_params.loss_weights.generator_gan
            # LSGAN loss for Generator
            if i == 0:
                gan_loss_gen = lsgan_loss(discriminator_maps_generated[key],
                                          lsgan_loss_weight)
            else:
                gan_loss_gen += lsgan_loss(discriminator_maps_generated[key],
                                           lsgan_loss_weight)
            # LSGAN loss for Discriminator
            if i == 0:
                gan_loss_dis = lsgan_loss(discriminator_maps_real[key],
                                          lsgan_loss_weight,
                                          discriminator_maps_generated[key])
            else:
                gan_loss_dis += lsgan_loss(discriminator_maps_real[key],
                                           lsgan_loss_weight,
                                           discriminator_maps_generated[key])
        gan_loss_dis.persistent = True
        loss_var_dict['gan_loss_dis'] = gan_loss_dis
        total_loss_D = gan_loss_dis
        total_loss_D.persistent = True

        gan_loss_gen.persistent = True
        loss_var_dict['gan_loss_gen'] = gan_loss_gen
        total_loss_G += gan_loss_gen

        if loss_flags.use_feature_matching_loss:
            logger.info("Use Feature Matching Loss.")
            fm_weights = train_params.loss_weights.feature_matching
            fm_loss = feature_matching_loss(discriminator_maps_real,
                                            discriminator_maps_generated,
                                            model_params, fm_weights)
            fm_loss.persistent = True
            loss_var_dict['feature_matching_loss'] = fm_loss
            total_loss_G += fm_loss

    # transform loss
    if loss_flags.use_equivariance_value_loss or loss_flags.use_equivariance_jacobian_loss:
        transform = Transform(bs, **config.train_params.transform_params)
        transformed_frame = transform.transform_frame(driving)
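        # Equivariance constraint: keypoints detected on a randomly warped
        # driving frame should agree with the warped original keypoints.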

        with nn.parameter_scope("kp_detector"):
            transformed_kp = detect_keypoint(transformed_frame,
                                             **model_params.kp_detector_params,
                                             **model_params.common_params,
                                             test=test,
                                             comm=comm)
        persistent_all(transformed_kp)

        # Value loss part
        if loss_flags.use_equivariance_value_loss:
            logger.info("Use Equivariance Value Loss.")
            warped_kp_value = transform.warp_coordinates(
                transformed_kp['value'])
            eq_value_weight = train_params.loss_weights.equivariance_value

            eq_value_loss = equivariance_value_loss(kp_driving['value'],
                                                    warped_kp_value,
                                                    eq_value_weight)
            eq_value_loss.persistent = True
            loss_var_dict['equivariance_value_loss'] = eq_value_loss
            total_loss_G += eq_value_loss

        # jacobian loss part
        if loss_flags.use_equivariance_jacobian_loss:
            logger.info("Use Equivariance Jacobian Loss.")
            arithmetic_jacobian = transform.jacobian(transformed_kp['value'])
            eq_jac_weight = train_params.loss_weights.equivariance_jacobian
            eq_jac_loss = equivariance_jacobian_loss(
                kp_driving['jacobian'], arithmetic_jacobian,
                transformed_kp['jacobian'], eq_jac_weight)
            eq_jac_loss.persistent = True
            loss_var_dict['equivariance_jacobian_loss'] = eq_jac_loss
            total_loss_G += eq_jac_loss

    assert total_loss_G is not None
    total_loss_G.persistent = True
    loss_var_dict['total_loss_gen'] = total_loss_G

    # -------------------- Create Monitors --------------------
    monitors_gen, monitors_dis, monitor_time, monitor_vis, log_dir = get_monitors(
        config, loss_flags, loss_var_dict)

    if device_id == 0:
        # Dump training info .yaml
        _ = shutil.copy(args.config, log_dir)  # copy the config yaml
        training_info_yaml = os.path.join(log_dir, "training_info.yaml")
        os.rename(os.path.join(log_dir, os.path.basename(args.config)),
                  training_info_yaml)
        # then add additional information
        with open(training_info_yaml, "a", encoding="utf-8") as f:
            f.write(f"\nlog_dir: {log_dir}\nsaved_parameter: None")

    # -------------------- Solver Setup --------------------
    solvers = setup_solvers(train_params)
    solver_generator = solvers["generator"]
    solver_discriminator = solvers["discriminator"]
    solver_kp_detector = solvers["kp_detector"]

    # max epochs
    num_epochs = train_params['num_epochs']

    # iterations per epoch
    num_iter_per_epoch = data_iterator.size // bs
    # multiplied by num_repeats if the dataset is repeated
    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        num_iter_per_epoch *= config.train_params.num_repeats

    # decay the learning rate when the current epoch reaches one of the
    # epochs defined in epoch_milestones
    lr_decay_at_epochs = train_params['epoch_milestones']  # e.g. [60, 90]
    gamma = 0.1  # decay rate

    # -------------------- For finetuning ---------------------
    if args.ft_params:
        assert os.path.isfile(args.ft_params)
        logger.info(f"load {args.ft_params} for finetuning.")
        nn.load_parameters(args.ft_params)
        start_epoch = int(
            os.path.splitext(os.path.basename(
                args.ft_params))[0].split("epoch_")[1])

        # set solver's state
        for name, solver in solvers.items():
            saved_states = os.path.join(
                os.path.dirname(args.ft_params),
                f"state_{name}_at_epoch_{start_epoch}.h5")
            solver.load_states(saved_states)

        start_epoch += 1
        logger.info(f"Resuming from epoch {start_epoch}.")

    logger.info(
        f"Start training. Total epochs: {num_epochs - start_epoch}, {num_iter_per_epoch * n_devices} iters/epoch."
    )

    for e in range(start_epoch, num_epochs):
        logger.info(f"Epoch: {e} / {num_epochs}.")
        data_iterator._reset()  # rewind the iterator at the beginning

        # learning rate scheduler
        if e in lr_decay_at_epochs:
            logger.info("Learning rate decayed.")
            learning_rate_decay(solvers, gamma=gamma)

        for i in range(num_iter_per_epoch):
            _driving, _source = data_iterator.next()
            source.d = _source
            driving.d = _driving

            # update generator and keypoint detector
            total_loss_G.forward()

            if device_id == 0:
                monitors_gen.add((e * num_iter_per_epoch + i) * n_devices)

            solver_generator.zero_grad()
            solver_kp_detector.zero_grad()

            callback = None
            if n_devices > 1:
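                # All-reduce the generator / keypoint-detector gradients
                # across devices during backward, packed into 2 MiB buffers.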
                params = [x.grad for x in solver_generator.get_parameters().values()] + \
                         [x.grad for x in solver_kp_detector.get_parameters().values()]
                callback = comm.all_reduce_callback(params, 2 << 20)
            total_loss_G.backward(clear_buffer=True,
                                  communicator_callbacks=callback)

            solver_generator.update()
            solver_kp_detector.update()

            if loss_flags.use_gan_loss:
                # update discriminator

                total_loss_D.forward(clear_no_need_grad=True)
                if device_id == 0:
                    monitors_dis.add((e * num_iter_per_epoch + i) * n_devices)

                solver_discriminator.zero_grad()

                callback = None
                if n_devices > 1:
                    params = [
                        x.grad for x in
                        solver_discriminator.get_parameters().values()
                    ]
                    callback = comm.all_reduce_callback(params, 2 << 20)
                total_loss_D.backward(clear_buffer=True,
                                      communicator_callbacks=callback)

                solver_discriminator.update()

            if device_id == 0:
                monitor_time.add((e * num_iter_per_epoch + i) * n_devices)

            if device_id == 0 and (
                (e * num_iter_per_epoch + i) *
                    n_devices) % config.monitor_params.visualize_freq == 0:
                images_to_visualize = [
                    source.d, driving.d, generated["prediction"].d
                ]
                visuals = combine_images(images_to_visualize)
                monitor_vis.add((e * num_iter_per_epoch + i) * n_devices,
                                visuals)

        if device_id == 0:
            if e % train_params.checkpoint_freq == 0 or e == num_epochs - 1:
                save_parameters(e, log_dir, solvers)

    return
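
learning_rate_decay() above is called but not shown. A minimal sketch
consistent with its call site, multiplying each solver's current learning rate
by gamma (the repository's actual helper may differ):

def learning_rate_decay(solvers, gamma=0.1):
    for solver in solvers.values():
        lr = solver.learning_rate()
        solver.set_learning_rate(lr * gamma)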