# InstaGAN training loop (NNabla). Assumes the repo's local modules
# (`models`, `loss`, the data iterators, `Monitors`, `image_augmentation`,
# `save_images`) plus:
#   import os
#   import functools
#   import nnabla as nn
#   import nnabla.parametric_functions as PF
#   import nnabla.solvers as S
#   from nnabla.monitor import Monitor, MonitorTimeElapsed

def train(args):
    # Variable size.
    bs, ch, h, w = args.batch_size, 3, args.loadSizeH, args.loadSizeW

    # Determine normalization method.
    if args.norm == "instance":
        norm_layer = functools.partial(PF.instance_normalization,
                                       fix_parameters=True,
                                       no_bias=True, no_scale=True)
    else:
        norm_layer = PF.batch_normalization

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(models.generator,
                                  input_nc=args.input_nc,
                                  output_nc=args.output_nc,
                                  ngf=args.ngf,
                                  norm_layer=norm_layer,
                                  use_dropout=False,
                                  n_blocks=9,
                                  padding_type='reflect')
    discriminator = functools.partial(models.discriminator,
                                      input_nc=args.output_nc,
                                      ndf=args.ndf,
                                      n_layers=args.n_layers_D,
                                      norm_layer=norm_layer,
                                      use_sigmoid=False)

    # --------------------- Computation Graphs --------------------
    # Input images and masks of both source / target domains.
    x = nn.Variable([bs, ch, h, w], need_grad=False)
    a = nn.Variable([bs, 1, h, w], need_grad=False)
    y = nn.Variable([bs, ch, h, w], need_grad=False)
    b = nn.Variable([bs, 1, h, w], need_grad=False)

    # Apply image augmentation and get an unlinked variable.
    xa_aug = image_augmentation(args, x, a)
    xa_aug.persistent = True
    xa_aug_unlinked = xa_aug.get_unlinked_variable()
    yb_aug = image_augmentation(args, y, b)
    yb_aug.persistent = True
    yb_aug_unlinked = yb_aug.get_unlinked_variable()

    # Variables used for the image pool.
    x_history = nn.Variable([bs, ch, h, w])
    a_history = nn.Variable([bs, 1, h, w])
    y_history = nn.Variable([bs, ch, h, w])
    b_history = nn.Variable([bs, 1, h, w])

    # Generate images (x -> y').
    with nn.parameter_scope("gen_x2y"):
        yb_fake = generator(xa_aug_unlinked)
    yb_fake.persistent = True
    yb_fake_unlinked = yb_fake.get_unlinked_variable()

    # Generate images (y -> x').
    with nn.parameter_scope("gen_y2x"):
        xa_fake = generator(yb_aug_unlinked)
    xa_fake.persistent = True
    xa_fake_unlinked = xa_fake.get_unlinked_variable()

    # Reconstruct images (y' -> x).
    with nn.parameter_scope("gen_y2x"):
        xa_recon = generator(yb_fake_unlinked)
    xa_recon.persistent = True

    # Reconstruct images (x' -> y).
    with nn.parameter_scope("gen_x2y"):
        yb_recon = generator(xa_fake_unlinked)
    yb_recon.persistent = True

    # Apply the discriminators to y' and x'.
    with nn.parameter_scope("dis_y"):
        d_y_fake = discriminator(yb_fake_unlinked)
    d_y_fake.persistent = True
    with nn.parameter_scope("dis_x"):
        d_x_fake = discriminator(xa_fake_unlinked)
    d_x_fake.persistent = True

    # Apply the discriminators to y and x.
    with nn.parameter_scope("dis_y"):
        d_y_real = discriminator(yb_aug_unlinked)
    with nn.parameter_scope("dis_x"):
        d_x_real = discriminator(xa_aug_unlinked)

    # Identity mapping (x -> x).
    with nn.parameter_scope("gen_y2x"):
        xa_idt = generator(xa_aug_unlinked)

    # Identity mapping (y -> y).
    with nn.parameter_scope("gen_x2y"):
        yb_idt = generator(yb_aug_unlinked)

    # -------------------- Loss --------------------
    # (LS)GAN loss (for the discriminators).
    loss_dis_x = (loss.lsgan_loss(d_y_fake, False) +
                  loss.lsgan_loss(d_y_real, True)) * 0.5
    loss_dis_y = (loss.lsgan_loss(d_x_fake, False) +
                  loss.lsgan_loss(d_x_real, True)) * 0.5
    loss_dis = loss_dis_x + loss_dis_y

    # Cycle-consistency loss.
    loss_cyc_x = args.lambda_cyc * loss.recon_loss(xa_recon, xa_aug_unlinked)
    loss_cyc_y = args.lambda_cyc * loss.recon_loss(yb_recon, yb_aug_unlinked)
    loss_cyc = loss_cyc_x + loss_cyc_y

    # Identity-mapping loss.
    loss_idt_x = args.lambda_idt * loss.recon_loss(xa_idt, xa_aug_unlinked)
    loss_idt_y = args.lambda_idt * loss.recon_loss(yb_idt, yb_aug_unlinked)
    loss_idt = loss_idt_x + loss_idt_y

    # Context-preserving loss.
    loss_ctx_x = args.lambda_ctx * \
        loss.context_preserving_loss(xa_aug_unlinked, yb_fake_unlinked)
    loss_ctx_y = args.lambda_ctx * \
        loss.context_preserving_loss(yb_aug_unlinked, xa_fake_unlinked)
    loss_ctx = loss_ctx_x + loss_ctx_y

    # (LS)GAN loss (for the generators).
    d_loss_gen_x = loss.lsgan_loss(d_x_fake, True)
    d_loss_gen_y = loss.lsgan_loss(d_y_fake, True)
    d_loss_gen = d_loss_gen_x + d_loss_gen_y

    # Total loss for the generators.
    loss_gen = loss_cyc + loss_idt + loss_ctx + d_loss_gen

    # --------------------- Solvers --------------------
    # Initial learning rates.
    G_lr = args.learning_rate_G
    # D_lr = args.learning_rate_D
    # As opposed to the description in the paper, D_lr is set the same as G_lr.
    D_lr = args.learning_rate_G

    # Define solvers.
    solver_gen_x2y = S.Adam(G_lr, args.beta1, args.beta2)
    solver_gen_y2x = S.Adam(G_lr, args.beta1, args.beta2)
    solver_dis_x = S.Adam(D_lr, args.beta1, args.beta2)
    solver_dis_y = S.Adam(D_lr, args.beta1, args.beta2)

    # Set the parameters for each solver.
    with nn.parameter_scope("gen_x2y"):
        solver_gen_x2y.set_parameters(nn.get_parameters())
    with nn.parameter_scope("gen_y2x"):
        solver_gen_y2x.set_parameters(nn.get_parameters())
    with nn.parameter_scope("dis_x"):
        solver_dis_x.set_parameters(nn.get_parameters())
    with nn.parameter_scope("dis_y"):
        solver_dis_y.set_parameters(nn.get_parameters())

    # Create convenience functions for manipulating the solvers.
    def solvers_zero_grad():
        # Zero the gradients of all solvers.
        solver_gen_x2y.zero_grad()
        solver_gen_y2x.zero_grad()
        solver_dis_x.zero_grad()
        solver_dis_y.zero_grad()

    def solvers_update_parameters(new_D_lr, new_G_lr):
        # Learning-rate updater.
        solver_gen_x2y.set_learning_rate(new_G_lr)
        solver_gen_y2x.set_learning_rate(new_G_lr)
        solver_dis_x.set_learning_rate(new_D_lr)
        solver_dis_y.set_learning_rate(new_D_lr)

    # -------------------- Data Iterators --------------------
    ds_train_A = insta_gan_data_source(args, train=True,
                                       domain="A", shuffle=True)
    di_train_A = insta_gan_data_iterator(ds_train_A, args.batch_size)
    ds_train_B = insta_gan_data_source(args, train=True,
                                       domain="B", shuffle=True)
    di_train_B = insta_gan_data_iterator(ds_train_B, args.batch_size)

    # -------------------- Monitors --------------------
    monitoring_targets_dis = {'discriminator_loss_x': loss_dis_x,
                              'discriminator_loss_y': loss_dis_y}
    monitors_dis = Monitors(args, monitoring_targets_dis)

    monitoring_targets_gen = {'generator_loss_x': d_loss_gen_x,
                              'generator_loss_y': d_loss_gen_y,
                              'reconstruction_loss_x': loss_cyc_x,
                              'reconstruction_loss_y': loss_cyc_y,
                              'identity_mapping_loss_x': loss_idt_x,
                              'identity_mapping_loss_y': loss_idt_y,
                              'content_preserving_loss_x': loss_ctx_x,
                              'content_preserving_loss_y': loss_ctx_y}
    monitors_gen = Monitors(args, monitoring_targets_gen)

    monitor_time = MonitorTimeElapsed("Training_time",
                                      Monitor(args.monitor_path),
                                      args.log_step)

    # Training loop.
    epoch = 0
    n_images = max([ds_train_B.size, ds_train_A.size])
    print("{} images exist.".format(n_images))
    max_iter = args.max_epoch * n_images // args.batch_size
    decay_iter = args.max_epoch - args.lr_decay_start_epoch

    for i in range(max_iter):
        if i % (n_images // args.batch_size) == 0 and i > 0:
            # Learning rate decay.
            epoch += 1
            print("epoch {}".format(epoch))
            if epoch >= args.lr_decay_start_epoch:
                new_D_lr = D_lr * \
                    (1.0 - max(0, epoch - args.lr_decay_start_epoch - 1) /
                     float(decay_iter - 1))
                new_G_lr = G_lr * \
                    (1.0 - max(0, epoch - args.lr_decay_start_epoch - 1) /
                     float(decay_iter - 1))
                solvers_update_parameters(new_D_lr, new_G_lr)
                print("Current learning rate for Discriminator: {}".format(
                    solver_dis_x.learning_rate()))
                print("Current learning rate for Generator: {}".format(
                    solver_gen_x2y.learning_rate()))

        # Get data.
        x_data, a_data = di_train_A.next()
        y_data, b_data = di_train_B.next()
        x.d, a.d = x_data, a_data
        y.d, b.d = y_data, b_data

        solvers_zero_grad()

        # Image augmentation.
        nn.forward_all([xa_aug, yb_aug], clear_buffer=True)

        # Generate fake images.
        nn.forward_all([xa_fake, yb_fake], clear_no_need_grad=True)

        # -------- Train Discriminators --------
        loss_dis.forward(clear_no_need_grad=True)
        monitors_dis.add(i)
        loss_dis.backward(clear_buffer=True)
        solver_dis_x.update()
        solver_dis_y.update()

        # -------- Train Generators --------
        # Since the gradients computed above remain, reset them to zero.
        xa_fake_unlinked.grad.zero()
        yb_fake_unlinked.grad.zero()
        solvers_zero_grad()

        loss_gen.forward(clear_no_need_grad=True)
        monitors_gen.add(i)
        monitor_time.add(i)
        loss_gen.backward(clear_buffer=True)

        # Propagate the gradients accumulated on the unlinked fake images
        # back through the generators (grad=None reuses the grad buffers).
        xa_fake.backward(grad=None, clear_buffer=True)
        yb_fake.backward(grad=None, clear_buffer=True)
        solver_gen_x2y.update()
        solver_gen_y2x.update()

        if i % (n_images // args.batch_size) == 0:
            # Save translation results after every epoch.
            save_images(args, i, xa_aug, yb_fake, domain="x",
                        reconstructed=xa_recon)
            save_images(args, i, yb_aug, xa_fake, domain="y",
                        reconstructed=yb_recon)
            # Save the trained parameters.
            nn.save_parameters(os.path.join(args.model_save_path,
                                            'params_%06d.h5' % i))
# White-box cartoonization training loop (TensorFlow 1.x). Assumes the
# repo's local modules (`network`, `loss`, `utils`, `guided_filter`) plus:
#   import numpy as np
#   import tensorflow as tf
#   from tqdm import tqdm

def train(args):
    input_photo = tf.placeholder(
        tf.float32, [args.batch_size, args.patch_size, args.patch_size, 3])
    input_superpixel = tf.placeholder(
        tf.float32, [args.batch_size, args.patch_size, args.patch_size, 3])
    input_cartoon = tf.placeholder(
        tf.float32, [args.batch_size, args.patch_size, args.patch_size, 3])

    # output => fake picture.
    output = network.unet_generator(input_photo)
    # output = guided_filter(input_photo, output, r=1)

    # Surface representation: edge-preserving blur of both domains.
    blur_fake = guided_filter(output, output, r=5, eps=2e-1)
    blur_cartoon = guided_filter(input_cartoon, input_cartoon, r=5, eps=2e-1)

    # Texture representation: random grayscale projection of both domains.
    gray_fake, gray_cartoon = utils.color_shift(output, input_cartoon)

    d_loss_gray, g_loss_gray = loss.lsgan_loss(network.disc_sn,
                                               gray_cartoon, gray_fake,
                                               scale=1, patch=True,
                                               name='disc_gray')
    d_loss_blur, g_loss_blur = loss.lsgan_loss(network.disc_sn,
                                               blur_cartoon, blur_fake,
                                               scale=1, patch=True,
                                               name='disc_blur')

    vgg_model = loss.Vgg19('vgg19_no_fc.npy')
    vgg_photo = vgg_model.build_conv4_4(input_photo)
    vgg_output = vgg_model.build_conv4_4(output)
    vgg_superpixel = vgg_model.build_conv4_4(input_superpixel)
    h, w, c = vgg_photo.get_shape().as_list()[1:]

    photo_loss = tf.reduce_mean(
        tf.losses.absolute_difference(vgg_photo, vgg_output)) / (h * w * c)
    superpixel_loss = tf.reduce_mean(
        tf.losses.absolute_difference(vgg_superpixel, vgg_output)) / (h * w * c)
    recon_loss = photo_loss + superpixel_loss
    tv_loss = loss.total_variation_loss(output)

    g_loss_total = 1e4 * tv_loss + 1e-1 * g_loss_blur + g_loss_gray \
        + 2e2 * recon_loss
    d_loss_total = d_loss_blur + d_loss_gray

    all_vars = tf.trainable_variables()
    gene_vars = [var for var in all_vars if 'gene' in var.name]
    disc_vars = [var for var in all_vars if 'disc' in var.name]

    tf.summary.scalar('tv_loss', tv_loss)
    tf.summary.scalar('photo_loss', photo_loss)
    tf.summary.scalar('superpixel_loss', superpixel_loss)
    tf.summary.scalar('recon_loss', recon_loss)
    tf.summary.scalar('d_loss_gray', d_loss_gray)
    tf.summary.scalar('g_loss_gray', g_loss_gray)
    tf.summary.scalar('d_loss_blur', d_loss_blur)
    tf.summary.scalar('g_loss_blur', g_loss_blur)
    tf.summary.scalar('d_loss_total', d_loss_total)
    tf.summary.scalar('g_loss_total', g_loss_total)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        g_optim = tf.train.AdamOptimizer(args.adv_train_lr,
                                         beta1=0.5, beta2=0.99) \
            .minimize(g_loss_total, var_list=gene_vars)
        d_optim = tf.train.AdamOptimizer(args.adv_train_lr,
                                         beta1=0.5, beta2=0.99) \
            .minimize(d_loss_total, var_list=disc_vars)

    '''
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    '''
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_fraction)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    train_writer = tf.summary.FileWriter(args.save_dir + '/train_log')
    summary_op = tf.summary.merge_all()
    saver = tf.train.Saver(var_list=gene_vars, max_to_keep=20)

    with tf.device('/device:GPU:0'):
        sess.run(tf.global_variables_initializer())
        saver.restore(sess,
                      tf.train.latest_checkpoint('pretrain/saved_models'))

        face_photo_dir = 'dataset/photo_face'
        face_photo_list = utils.load_image_list(face_photo_dir)
        scenery_photo_dir = 'dataset/photo_scenery'
        scenery_photo_list = utils.load_image_list(scenery_photo_dir)

        face_cartoon_dir = 'dataset/cartoon_face'
        face_cartoon_list = utils.load_image_list(face_cartoon_dir)
        scenery_cartoon_dir = 'dataset/cartoon_scenery'
        scenery_cartoon_list = utils.load_image_list(scenery_cartoon_dir)

        for total_iter in tqdm(range(args.total_iter)):
            # Alternate between face and scenery batches
            # (a face batch every fifth iteration).
            if np.mod(total_iter, 5) == 0:
                photo_batch = utils.next_batch(face_photo_list,
                                               args.batch_size)
                cartoon_batch = utils.next_batch(face_cartoon_list,
                                                 args.batch_size)
            else:
                photo_batch = utils.next_batch(scenery_photo_list,
                                               args.batch_size)
                cartoon_batch = utils.next_batch(scenery_cartoon_list,
                                                 args.batch_size)

            inter_out = sess.run(output,
                                 feed_dict={input_photo: photo_batch,
                                            input_superpixel: photo_batch,
                                            input_cartoon: cartoon_batch})

            '''
            Adaptive coloring has to be applied with the clip_by_value
            in the last layer of the generator network, which is not
            very stable. To stably reproduce our results, please use
            power=1.0 and comment out the clip_by_value function in
            network.py first. If this works, then try adaptive coloring
            with clip_by_value.
            '''
            if args.use_enhance:
                superpixel_batch = utils.selective_adacolor(inter_out,
                                                            power=1.2)
            else:
                superpixel_batch = utils.simple_superpixel(inter_out,
                                                           seg_num=200)

            _, g_loss, r_loss = sess.run(
                [g_optim, g_loss_total, recon_loss],
                feed_dict={input_photo: photo_batch,
                           input_superpixel: superpixel_batch,
                           input_cartoon: cartoon_batch})

            _, d_loss, train_info = sess.run(
                [d_optim, d_loss_total, summary_op],
                feed_dict={input_photo: photo_batch,
                           input_superpixel: superpixel_batch,
                           input_cartoon: cartoon_batch})

            train_writer.add_summary(train_info, total_iter)

            if np.mod(total_iter + 1, 50) == 0:
                print('Iter: {}, d_loss: {}, g_loss: {}, recon_loss: {}'
                      .format(total_iter, d_loss, g_loss, r_loss))

                if np.mod(total_iter + 1, 500) == 0:
                    saver.save(sess,
                               args.save_dir + '/saved_models/model',
                               write_meta_graph=False,
                               global_step=total_iter)

                    photo_face = utils.next_batch(face_photo_list,
                                                  args.batch_size)
                    cartoon_face = utils.next_batch(face_cartoon_list,
                                                    args.batch_size)
                    photo_scenery = utils.next_batch(scenery_photo_list,
                                                     args.batch_size)
                    cartoon_scenery = utils.next_batch(scenery_cartoon_list,
                                                       args.batch_size)

                    result_face = sess.run(
                        output,
                        feed_dict={input_photo: photo_face,
                                   input_superpixel: photo_face,
                                   input_cartoon: cartoon_face})
                    result_scenery = sess.run(
                        output,
                        feed_dict={input_photo: photo_scenery,
                                   input_superpixel: photo_scenery,
                                   input_cartoon: cartoon_scenery})

                    utils.write_batch_image(
                        result_face, args.save_dir + '/images',
                        str(total_iter) + '_face_result.jpg', 4)
                    utils.write_batch_image(
                        photo_face, args.save_dir + '/images',
                        str(total_iter) + '_face_photo.jpg', 4)
                    utils.write_batch_image(
                        result_scenery, args.save_dir + '/images',
                        str(total_iter) + '_scenery_result.jpg', 4)
                    utils.write_batch_image(
                        photo_scenery, args.save_dir + '/images',
                        str(total_iter) + '_scenery_photo.jpg', 4)
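
# ----------------------------------------------------------------------
# For reference: a minimal sketch of the `utils.color_shift` step used
# above to build the texture representation. The idea (from the white-box
# cartoonization paper) is to project the generator output and the cartoon
# reference to the same randomly weighted grayscale so the texture
# discriminator cannot rely on color or luminance. The sampling ranges
# and RGB channel order below are assumptions for illustration, not the
# repository's exact values.
import tensorflow as tf


def color_shift_sketch(image1, image2):
    # One random weight per channel, shared by both images so the pair
    # receives an identical projection.
    r = tf.random.uniform([1], 0.199, 0.399)
    g = tf.random.uniform([1], 0.487, 0.687)
    b = tf.random.uniform([1], 0.014, 0.214)

    def project(img):
        r_ch, g_ch, b_ch = tf.split(img, 3, axis=3)
        # Weighted channel average, renormalized to keep the value range.
        return (r * r_ch + g * g_ch + b * b_ch) / (r + g + b)

    return project(image1), project(image2)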
# First Order Motion Model training loop (NNabla, multi-GPU capable).
# Assumes the repo's local model, loss, and data helpers plus:
#   import os
#   import shutil
#   import numpy as np
#   import nnabla as nn
#   import nnabla.communicators as C
#   from nnabla.ext_utils import get_extension_context
#   from nnabla.logger import logger

def train(args):
    # Get context.
    ctx = get_extension_context(args.context)
    # Note: "Paralell" is NNabla's own spelling of this API.
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    config = read_yaml(args.config)

    if args.info:
        config.monitor_params.info = args.info

    if comm.size == 1:
        comm = None
    else:
        # Disable outputs from the logger except on rank 0.
        if comm.rank > 0:
            import logging
            logger.setLevel(logging.ERROR)

    test = False
    train_params = config.train_params
    dataset_params = config.dataset_params
    model_params = config.model_params

    loss_flags = get_loss_flags(train_params)

    start_epoch = 0
    rng = np.random.RandomState(device_id)

    data_iterator = frame_data_iterator(
        root_dir=dataset_params.root_dir,
        frame_shape=dataset_params.frame_shape,
        id_sampling=dataset_params.id_sampling,
        is_train=True,
        random_seed=rng,
        augmentation_params=dataset_params.augmentation_params,
        batch_size=train_params['batch_size'],
        shuffle=True,
        with_memory_cache=False,
        with_file_cache=False)

    if n_devices > 1:
        data_iterator = data_iterator.slice(rng=rng,
                                            num_of_slices=comm.size,
                                            slice_pos=comm.rank)
        # Workaround not to use the memory cache.
        data_iterator._data_source._on_memory = False
        logger.info("Disabled on memory data cache.")

    bs, h, w, c = [train_params.batch_size] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    with nn.parameter_scope("kp_detector"):
        # kp_X = {"value": Variable((bs, 10, 2)),
        #         "jacobian": Variable((bs, 10, 2, 2))}
        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=test, comm=comm)
        persistent_all(kp_source)

        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=test, comm=comm)
        persistent_all(kp_driving)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=kp_source,
                                              kp_driving=kp_driving,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=test, comm=comm)
    # `generated` is a dictionary containing:
    #   'mask': Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
    #   'sparse_deformed': Variable((bs, num_kp+1, num_channel, h/4, w/4))
    #   'occlusion_map': Variable((bs, 1, h/4, w/4))
    #   'deformed': Variable((bs, c, h, w))
    #   'prediction': Variable((bs, c, h, w))  # Only this is fed to the discriminator.
    generated["prediction"].persistent = True

    pyramide_real = get_image_pyramid(driving, train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_real)
    pyramide_fake = get_image_pyramid(generated['prediction'],
                                      train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_fake)

    total_loss_G = None  # Dummy; defined temporarily.
    loss_var_dict = {}

    # Perceptual loss using VGG19 (always applied).
    if loss_flags.use_perceptual_loss:
        logger.info("Use Perceptual Loss.")
        scales = train_params.scales
        weights = train_params.loss_weights.perceptual
        vgg_param_path = train_params.vgg_param_path
        percep_loss = perceptual_loss(pyramide_real, pyramide_fake,
                                      scales, weights, vgg_param_path)
        percep_loss.persistent = True
        loss_var_dict['perceptual_loss'] = percep_loss
        total_loss_G = percep_loss

    # (LS)GAN loss and feature matching loss.
    if loss_flags.use_gan_loss:
        logger.info("Use GAN Loss.")
        with nn.parameter_scope("discriminator"):
            discriminator_maps_generated = multiscale_discriminator(
                pyramide_fake,
                kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test, comm=comm)
            discriminator_maps_real = multiscale_discriminator(
                pyramide_real,
                kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test, comm=comm)

        for v in discriminator_maps_generated["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_generated["prediction_map_1"].persistent = True
        for v in discriminator_maps_real["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_real["prediction_map_1"].persistent = True

        for i, scale in enumerate(model_params.discriminator_params.scales):
            key = f'prediction_map_{scale}'.replace('.', '-')
            lsgan_loss_weight = train_params.loss_weights.generator_gan
            # LSGAN loss for the generator.
            if i == 0:
                gan_loss_gen = lsgan_loss(discriminator_maps_generated[key],
                                          lsgan_loss_weight)
            else:
                gan_loss_gen += lsgan_loss(discriminator_maps_generated[key],
                                           lsgan_loss_weight)
            # LSGAN loss for the discriminator.
            if i == 0:
                gan_loss_dis = lsgan_loss(discriminator_maps_real[key],
                                          lsgan_loss_weight,
                                          discriminator_maps_generated[key])
            else:
                gan_loss_dis += lsgan_loss(discriminator_maps_real[key],
                                           lsgan_loss_weight,
                                           discriminator_maps_generated[key])

        gan_loss_dis.persistent = True
        loss_var_dict['gan_loss_dis'] = gan_loss_dis
        total_loss_D = gan_loss_dis
        total_loss_D.persistent = True

        gan_loss_gen.persistent = True
        loss_var_dict['gan_loss_gen'] = gan_loss_gen
        total_loss_G += gan_loss_gen

        if loss_flags.use_feature_matching_loss:
            logger.info("Use Feature Matching Loss.")
            fm_weights = train_params.loss_weights.feature_matching
            fm_loss = feature_matching_loss(discriminator_maps_real,
                                            discriminator_maps_generated,
                                            model_params, fm_weights)
            fm_loss.persistent = True
            loss_var_dict['feature_matching_loss'] = fm_loss
            total_loss_G += fm_loss

    # Transform (equivariance) losses.
    if loss_flags.use_equivariance_value_loss or loss_flags.use_equivariance_jacobian_loss:
        transform = Transform(bs, **config.train_params.transform_params)
        transformed_frame = transform.transform_frame(driving)
        with nn.parameter_scope("kp_detector"):
            transformed_kp = detect_keypoint(transformed_frame,
                                             **model_params.kp_detector_params,
                                             **model_params.common_params,
                                             test=test, comm=comm)
        persistent_all(transformed_kp)

        # Value loss part.
        if loss_flags.use_equivariance_value_loss:
            logger.info("Use Equivariance Value Loss.")
            warped_kp_value = transform.warp_coordinates(
                transformed_kp['value'])
            eq_value_weight = train_params.loss_weights.equivariance_value
            eq_value_loss = equivariance_value_loss(kp_driving['value'],
                                                    warped_kp_value,
                                                    eq_value_weight)
            eq_value_loss.persistent = True
            loss_var_dict['equivariance_value_loss'] = eq_value_loss
            total_loss_G += eq_value_loss

        # Jacobian loss part.
        if loss_flags.use_equivariance_jacobian_loss:
            logger.info("Use Equivariance Jacobian Loss.")
            arithmetic_jacobian = transform.jacobian(transformed_kp['value'])
            eq_jac_weight = train_params.loss_weights.equivariance_jacobian
            eq_jac_loss = equivariance_jacobian_loss(
                kp_driving['jacobian'],
                arithmetic_jacobian,
                transformed_kp['jacobian'],
                eq_jac_weight)
            eq_jac_loss.persistent = True
            loss_var_dict['equivariance_jacobian_loss'] = eq_jac_loss
            total_loss_G += eq_jac_loss

    assert total_loss_G is not None
    total_loss_G.persistent = True
    loss_var_dict['total_loss_gen'] = total_loss_G

    # -------------------- Create Monitors --------------------
    monitors_gen, monitors_dis, monitor_time, monitor_vis, log_dir = get_monitors(
        config, loss_flags, loss_var_dict)

    if device_id == 0:
        # Dump the training info as .yaml.
        _ = shutil.copy(args.config, log_dir)  # copy the config yaml
        training_info_yaml = os.path.join(log_dir, "training_info.yaml")
        os.rename(os.path.join(log_dir, os.path.basename(args.config)),
                  training_info_yaml)
        # Then append additional information.
        with open(training_info_yaml, "a", encoding="utf-8") as f:
            f.write(f"\nlog_dir: {log_dir}\nsaved_parameter: None")

    # -------------------- Solver Setup --------------------
    solvers = setup_solvers(train_params)
    solver_generator = solvers["generator"]
    solver_discriminator = solvers["discriminator"]
    solver_kp_detector = solvers["kp_detector"]

    # Max epochs.
    num_epochs = train_params['num_epochs']

    # Iterations per epoch; scaled by num_repeats when specified
    # ('and' avoids a KeyError when the key is absent).
    num_iter_per_epoch = data_iterator.size // bs
    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        num_iter_per_epoch *= config.train_params.num_repeats

    # Decay the learning rate once the current epoch reaches a milestone.
    lr_decay_at_epochs = train_params['epoch_milestones']  # e.g. [60, 90]
    gamma = 0.1  # decay rate

    # -------------------- For Finetuning ---------------------
    if args.ft_params:
        assert os.path.isfile(args.ft_params)
        logger.info(f"load {args.ft_params} for finetuning.")
        nn.load_parameters(args.ft_params)
        start_epoch = int(os.path.splitext(
            os.path.basename(args.ft_params))[0].split("epoch_")[1])

        # Restore each solver's state.
        for name, solver in solvers.items():
            saved_states = os.path.join(
                os.path.dirname(args.ft_params),
                f"state_{name}_at_epoch_{start_epoch}.h5")
            solver.load_states(saved_states)

        start_epoch += 1
        logger.info(f"Resuming from epoch {start_epoch}.")

    logger.info(
        f"Start training. Total epoch: {num_epochs - start_epoch}, "
        f"{num_iter_per_epoch * n_devices} iter/epoch.")

    for e in range(start_epoch, num_epochs):
        logger.info(f"Epoch: {e} / {num_epochs}.")
        data_iterator._reset()  # Rewind the iterator at the beginning.

        # Learning rate scheduler.
        if e in lr_decay_at_epochs:
            logger.info("Learning rate decayed.")
            learning_rate_decay(solvers, gamma=gamma)

        for i in range(num_iter_per_epoch):
            _driving, _source = data_iterator.next()
            source.d = _source
            driving.d = _driving

            # Update the generator and the keypoint detector.
            total_loss_G.forward()

            if device_id == 0:
                monitors_gen.add((e * num_iter_per_epoch + i) * n_devices)

            solver_generator.zero_grad()
            solver_kp_detector.zero_grad()

            callback = None
            if n_devices > 1:
                params = [x.grad for x in solver_generator.get_parameters().values()] + \
                         [x.grad for x in solver_kp_detector.get_parameters().values()]
                callback = comm.all_reduce_callback(params, 2 << 20)
            total_loss_G.backward(clear_buffer=True,
                                  communicator_callbacks=callback)

            solver_generator.update()
            solver_kp_detector.update()

            if loss_flags.use_gan_loss:
                # Update the discriminator.
                total_loss_D.forward(clear_no_need_grad=True)

                if device_id == 0:
                    monitors_dis.add((e * num_iter_per_epoch + i) * n_devices)

                solver_discriminator.zero_grad()

                callback = None
                if n_devices > 1:
                    params = [
                        x.grad
                        for x in solver_discriminator.get_parameters().values()
                    ]
                    callback = comm.all_reduce_callback(params, 2 << 20)
                total_loss_D.backward(clear_buffer=True,
                                      communicator_callbacks=callback)

                solver_discriminator.update()

            if device_id == 0:
                monitor_time.add((e * num_iter_per_epoch + i) * n_devices)

            if device_id == 0 and \
                    ((e * num_iter_per_epoch + i) * n_devices) % config.monitor_params.visualize_freq == 0:
                images_to_visualize = [source.d, driving.d,
                                       generated["prediction"].d]
                visuals = combine_images(images_to_visualize)
                monitor_vis.add((e * num_iter_per_epoch + i) * n_devices,
                                visuals)

        if device_id == 0:
            if e % train_params.checkpoint_freq == 0 or e == num_epochs - 1:
                save_parameters(e, log_dir, solvers)

    return
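
# ----------------------------------------------------------------------
# For reference: a minimal sketch of the `learning_rate_decay` helper
# invoked at each epoch milestone above; assumed behavior (scale every
# solver's current learning rate by `gamma`, matching the NNabla solver
# API used elsewhere in this file), not necessarily the repository's
# exact implementation.
def learning_rate_decay_sketch(solvers, gamma=0.1):
    for solver in solvers.values():
        lr = solver.learning_rate()
        solver.set_learning_rate(lr * gamma)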