def test_transformer(config, netG, train_iterators, monitor, param_file):
    netG_A2B = netG['netG_A2B']
    train_iterator_src, train_iterator_trg = train_iterators

    # Load boundary images once to get the Variable shapes.
    bod_map_A = train_iterator_src.next()[0]
    bod_map_B = train_iterator_trg.next()[0]
    real_bod_map_A = nn.Variable(bod_map_A.shape)
    real_bod_map_B = nn.Variable(bod_map_B.shape)
    real_bod_map_A.persistent, real_bod_map_B.persistent = True, True

    ################### Graph Construction ####################
    # Generator
    with nn.parameter_scope('netG_transformer'):
        with nn.parameter_scope('netG_A2B'):
            fake_bod_map_B = netG_A2B(real_bod_map_A, test=True,
                                      norm_type=config["norm_type"])  # (1, 15, 64, 64)
    fake_bod_map_B.persistent = True

    # Load parameters of the network.
    with nn.parameter_scope('netG_transformer'):
        with nn.parameter_scope('netG_A2B'):
            nn.load_parameters(param_file)

    monitor_vis = nm.MonitorImage('result', monitor,
                                  interval=config["test"]["vis_interval"],
                                  num_images=1,
                                  normalize_method=lambda x: x)

    # Test
    i = 0
    iter_per_epoch = train_iterator_src.size // config["test"]["batch_size"] + 1
    if config["num_test"]:
        num_test = config["num_test"]
    else:
        num_test = train_iterator_src.size

    for _ in range(iter_per_epoch):
        bod_map_A = train_iterator_src.next()[0]
        bod_map_B = train_iterator_trg.next()[0]
        real_bod_map_A.d, real_bod_map_B.d = bod_map_A, bod_map_B

        # Generate fake images.
        fake_bod_map_B.forward(clear_buffer=True)
        i += 1

        images_to_visualize = [real_bod_map_A.d,
                               fake_bod_map_B.d,
                               real_bod_map_B.d]
        visuals = combine_images(images_to_visualize)
        monitor_vis.add(i, visuals)
        if i > num_test:
            break
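# Usage sketch (illustrative, not part of the original script): `config`,
# `netG`, the iterator pair, and the epoch number in the weight file name are
# assumptions here; the file name follows the pattern used by
# train_transformer() below.
#
#   monitor = nm.Monitor('./result/test_transformer')
#   test_transformer(config, netG,
#                    (train_iterator_src, train_iterator_trg),
#                    monitor, param_file='netG_transformer_A2B_199.h5')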
def get_monitors(config, loss_flags, loss_var_dict, test=False):
    log_root_dir = config.monitor_params.monitor_path
    log_dir = os.path.join(log_root_dir, get_current_time())

    # If additional information is given, append it to the directory name.
    if "info" in config.monitor_params:
        info = config.monitor_params.info
        log_dir = f'{log_dir}_{info}'

    master_monitor_misc = nm.Monitor(log_dir)
    monitor_vis = nm.MonitorImage('images', master_monitor_misc,
                                  interval=1, num_images=4,
                                  normalize_method=lambda x: x)
    if test:
        # At inference time, return the visualization monitor only.
        return monitor_vis

    interval = config.monitor_params.monitor_freq

    monitoring_var_dict_gen = dict()
    monitoring_var_dict_dis = dict()

    if loss_flags.use_perceptual_loss:
        monitoring_var_dict_gen.update(
            {'perceptual_loss': loss_var_dict['perceptual_loss']})

    if loss_flags.use_gan_loss:
        monitoring_var_dict_gen.update(
            {'gan_loss_gen': loss_var_dict['gan_loss_gen']})
        monitoring_var_dict_dis.update(
            {'gan_loss_dis': loss_var_dict['gan_loss_dis']})

    if loss_flags.use_feature_matching_loss:
        monitoring_var_dict_gen.update(
            {'feature_matching_loss': loss_var_dict['feature_matching_loss']})

    if loss_flags.use_equivariance_value_loss:
        monitoring_var_dict_gen.update(
            {'equivariance_value_loss': loss_var_dict['equivariance_value_loss']})

    if loss_flags.use_equivariance_jacobian_loss:
        monitoring_var_dict_gen.update(
            {'equivariance_jacobian_loss': loss_var_dict['equivariance_jacobian_loss']})

    monitoring_var_dict_gen.update(
        {'total_loss_gen': loss_var_dict['total_loss_gen']})

    master_monitor_gen = nm.Monitor(log_dir)
    master_monitor_dis = nm.Monitor(log_dir)

    monitors_gen = MonitorManager(monitoring_var_dict_gen,
                                  master_monitor_gen, interval=interval)
    monitors_dis = MonitorManager(monitoring_var_dict_dis,
                                  master_monitor_dis, interval=interval)
    monitor_time = nm.MonitorTimeElapsed('time_training',
                                         master_monitor_misc,
                                         interval=interval)

    return monitors_gen, monitors_dis, monitor_time, monitor_vis, log_dir
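# Usage sketch (illustrative): the returned tuple is meant to be consumed in a
# training loop. Based on how MonitorManager is used elsewhere in this code,
# `add(i)` is assumed to forward the iteration index to every wrapped monitor;
# `max_iter` is a hypothetical name.
#
#   monitors_gen, monitors_dis, monitor_time, monitor_vis, log_dir = \
#       get_monitors(config, loss_flags, loss_var_dict)
#   for i in range(max_iter):
#       # ... forward / backward / update ...
#       monitors_gen.add(i)
#       monitors_dis.add(i)
#       monitor_time.add(i)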
def test(config, netG, train_iterator, monitor, param_file):
    # Load an image and boundary images once to get the Variable shapes.
    img, bod_map, bod_map_resize = train_iterator.next()
    real_img = nn.Variable(img.shape)
    real_bod_map = nn.Variable(bod_map.shape)
    real_bod_map_resize = nn.Variable(bod_map_resize.shape)

    ################### Graph Construction ####################
    # Generator
    with nn.parameter_scope('netG_decoder'):
        fake_img = netG(real_bod_map, test=True)  # inference mode
    fake_img.persistent = True

    # Load parameters of the network.
    with nn.parameter_scope('netG_decoder'):
        nn.load_parameters(param_file)

    monitor_vis = nm.MonitorImage('result', monitor,
                                  interval=config["test"]["vis_interval"],
                                  num_images=4,
                                  normalize_method=lambda x: x)

    # Test
    i = 0
    iter_per_epoch = train_iterator.size // config["test"]["batch_size"] + 1
    if config["num_test"]:
        num_test = config["num_test"]
    else:
        num_test = train_iterator.size

    for _ in range(iter_per_epoch):
        img, bod_map, bod_map_resize = train_iterator.next()
        real_img.d = img
        real_bod_map.d = bod_map
        real_bod_map_resize.d = bod_map_resize

        # Generate a fake image.
        fake_img.forward(clear_buffer=True)
        i += 1

        images_to_visualize = [real_bod_map_resize.d, fake_img.d, img]
        visuals = combine_images(images_to_visualize)
        monitor_vis.add(i, visuals)
        if i > num_test:
            break
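# Config sketch (hypothetical values; the keys are the ones accessed above).
# A falsy `num_test` makes the loop fall back to the iterator size.
#
#   num_test: 100
#   test:
#     batch_size: 1
#     vis_interval: 1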
def train(config, netG, netD, solver_netG, solver_netD, train_iterator,
          monitor):
    if config["train"]["feature_loss"] and \
            config["train"]["feature_loss"]["lambda"] > 0:
        print(f'Applying VGG feature loss, '
              f'weight: {config["train"]["feature_loss"]["lambda"]}.')
        with_feature_loss = True
    else:
        with_feature_loss = False

    # Load an image and boundary images once to get the Variable shapes.
    img, bod_map, bod_map_resize = train_iterator.next()
    real_img = nn.Variable(img.shape)
    real_bod_map = nn.Variable(bod_map.shape)
    real_bod_map_resize = nn.Variable(bod_map_resize.shape)

    ################### Graph Construction ####################
    # Generator
    with nn.parameter_scope('netG_decoder'):
        fake_img = netG(real_bod_map, test=False)
    fake_img.persistent = True
    fake_img_unlinked = fake_img.get_unlinked_variable()

    # Discriminator
    with nn.parameter_scope('netD_decoder'):
        pred_fake = netD(F.concatenate(real_bod_map_resize,
                                       fake_img_unlinked, axis=1),
                         test=False)
        pred_real = netD(F.concatenate(real_bod_map_resize,
                                       real_img, axis=1),
                         test=False)
    real_target = F.constant(1, pred_fake.shape)
    fake_target = F.constant(0, pred_real.shape)

    ################### Loss Definition ####################
    # for Generator
    gan_loss_G = gan_loss(pred_fake, real_target)
    gan_loss_G.persistent = True
    weight_L1 = config["train"]["weight_L1"]
    L1_loss = recon_loss(fake_img_unlinked, real_img)
    L1_loss.persistent = True
    loss_netG = gan_loss_G + weight_L1 * L1_loss
    if with_feature_loss:
        # VGG16 expects inputs in the [0, 255] range.
        feature_loss = vgg16_perceptual_loss(127.5 * (fake_img_unlinked + 1.),
                                             127.5 * (real_img + 1.))
        feature_loss.persistent = True
        loss_netG += feature_loss * config["train"]["feature_loss"]["lambda"]

    # for Discriminator
    loss_netD = (gan_loss(pred_real, real_target) +
                 gan_loss(pred_fake, fake_target)) * 0.5

    ################### Setting Solvers ####################
    # for Generator
    with nn.parameter_scope('netG_decoder'):
        solver_netG.set_parameters(nn.get_parameters())
    # for Discriminator
    with nn.parameter_scope('netD_decoder'):
        solver_netD.set_parameters(nn.get_parameters())

    ################### Create Monitors ####################
    interval = config["monitor"]["interval"]
    monitors_G_dict = {
        'loss_netG': loss_netG,
        'loss_gan': gan_loss_G,
        'L1_loss': L1_loss
    }
    if with_feature_loss:
        monitors_G_dict.update({'vgg_feature_loss': feature_loss})
    monitors_G = MonitorManager(monitors_G_dict, monitor, interval=interval)

    monitors_D_dict = {'loss_netD': loss_netD}
    monitors_D = MonitorManager(monitors_D_dict, monitor, interval=interval)

    monitor_time = nm.MonitorTimeElapsed('time_training', monitor,
                                         interval=interval)
    monitor_vis = nm.MonitorImage('result', monitor, interval=1, num_images=4,
                                  normalize_method=lambda x: x)

    # Dump training information.
    with open(os.path.join(monitor._save_path, "training_info.yaml"),
              "w", encoding="utf-8") as f:
        f.write(yaml.dump(config))

    # Training
    epoch = config["train"]["epochs"]
    i = 0
    lr_decay_start_at = config["train"]["lr_decay_start_at"]
    iter_per_epoch = train_iterator.size // config["train"]["batch_size"] + 1
    for e in range(epoch):
        logger.info(f'Epoch = {e} / {epoch}')
        train_iterator._reset()  # rewind the iterator
        if e > lr_decay_start_at:
            # Linearly decay the learning rate over 50 epochs.
            decay_coeff = 1.0 - max(0, e - lr_decay_start_at) / 50.
            lr_decayed = config["train"]["lr"] * decay_coeff
            print(f"learning rate decayed to {lr_decayed}")
            solver_netG.set_learning_rate(lr_decayed)
            solver_netD.set_learning_rate(lr_decayed)

        for _ in range(iter_per_epoch):
            img, bod_map, bod_map_resize = train_iterator.next()
            # bod_map_noize = np.random.random_sample(bod_map.shape) * 0.01
            # bod_map_resize_noize = np.random.random_sample(bod_map_resize.shape) * 0.01
            real_img.d = img
            real_bod_map.d = bod_map  # + bod_map_noize
            real_bod_map_resize.d = bod_map_resize  # + bod_map_resize_noize

            # Generate a fake image.
            fake_img.forward(clear_no_need_grad=True)

            # Update Discriminator
            solver_netD.zero_grad()
            solver_netG.zero_grad()
            loss_netD.forward(clear_no_need_grad=True)
            loss_netD.backward(clear_buffer=True)
            solver_netD.update()

            # Update Generator
            solver_netD.zero_grad()
            solver_netG.zero_grad()
            fake_img_unlinked.grad.zero()
            loss_netG.forward(clear_no_need_grad=True)
            loss_netG.backward(clear_buffer=True)
            fake_img.backward(grad=None)
            solver_netG.update()

            # Monitors
            monitor_time.add(i)
            monitors_G.add(i)
            monitors_D.add(i)
            i += 1

        images_to_visualize = [real_bod_map_resize.d, fake_img.d, img]
        visuals = combine_images(images_to_visualize)
        monitor_vis.add(i, visuals)

        if e % config["monitor"]["save_interval"] == 0 or e == epoch - 1:
            # Save parameters of the networks.
            netG_save_path = os.path.join(monitor._save_path,
                                          f'netG_decoder_{e}.h5')
            with nn.parameter_scope('netG_decoder'):
                nn.save_parameters(netG_save_path)
            netD_save_path = os.path.join(monitor._save_path,
                                          f'netD_decoder_{e}.h5')
            with nn.parameter_scope('netD_decoder'):
                nn.save_parameters(netD_save_path)
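# Worked example of the linear decay above (hypothetical values): with
# config["train"]["lr"] = 2e-4 and lr_decay_start_at = 50,
#   epoch 60:  decay_coeff = 1.0 - (60 - 50) / 50. = 0.8  ->  lr = 1.6e-4
#   epoch 100: decay_coeff = 0.0                          ->  lr = 0.0
# Note that decay_coeff is not clipped at zero, so the schedule assumes
# epochs <= lr_decay_start_at + 50.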
def animate(args):
    # Get the context.
    ctx = get_extension_context(args.context)
    nn.set_default_context(ctx)
    logger.setLevel(logging.ERROR)  # to suppress minor messages

    if not args.config:
        assert not args.params, \
            "pretrained weights file is given, but corresponding config " \
            "file is not. Please give both."
        download_provided_file(
            "https://nnabla.org/pretrained-models/nnabla-examples/GANs/first-order-model/voxceleb_trained_info.yaml")
        args.config = 'voxceleb_trained_info.yaml'

        download_provided_file(
            "https://nnabla.org/pretrained-models/nnabla-examples/GANs/first-order-model/pretrained_fomm_params.h5")

    config = read_yaml(args.config)

    dataset_params = config.dataset_params
    model_params = config.model_params

    if args.detailed:
        vis_params = config.visualizer_params
        visualizer = Visualizer(**vis_params)

    if not args.params:
        assert "log_dir" in config, \
            "no log_dir found in config. therefore failed to locate " \
            "pretrained parameters."
        param_file = os.path.join(config.log_dir, config.saved_parameters)
    else:
        param_file = args.params
    print(f"Loading {param_file} for image animation...")
    nn.load_parameters(param_file)

    bs, h, w, c = [1] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving_initial = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    filename = args.driving
    # the process is repeated until all the test data is used
    driving_video = read_video(
        filename, dataset_params.frame_shape)  # (#frames, h, w, 3)
    driving_video = np.transpose(
        driving_video, (0, 3, 1, 2))  # (#frames, 3, h, w)

    source_img = imread(args.source, channel_first=True,
                        size=(256, 256)) / 255.
    source_img = source_img[:3]

    source.d = np.expand_dims(source_img, 0)
    driving_initial.d = driving_video[0][:3, ]

    with nn.parameter_scope("kp_detector"):
        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=True, comm=False)
    persistent_all(kp_source)

    with nn.parameter_scope("kp_detector"):
        kp_driving_initial = detect_keypoint(driving_initial,
                                             **model_params.kp_detector_params,
                                             **model_params.common_params,
                                             test=True, comm=False)
    persistent_all(kp_driving_initial)

    with nn.parameter_scope("kp_detector"):
        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=True, comm=False)
    persistent_all(kp_driving)

    if args.adapt_movement_scale:
        nn.forward_all([kp_source["value"], kp_source["jacobian"],
                        kp_driving_initial["value"],
                        kp_driving_initial["jacobian"]])
        source_area = ConvexHull(kp_source['value'].d[0]).volume
        driving_area = ConvexHull(kp_driving_initial['value'].d[0]).volume
        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
    else:
        adapt_movement_scale = 1

    kp_norm = adjust_kp(kp_source=unlink_all(kp_source),
                        kp_driving=kp_driving,
                        kp_driving_initial=unlink_all(kp_driving_initial),
                        adapt_movement_scale=adapt_movement_scale,
                        use_relative_movement=args.unuse_relative_movement,
                        use_relative_jacobian=args.unuse_relative_jacobian)
    persistent_all(kp_norm)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=unlink_all(kp_source),
                                              kp_driving=kp_norm,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=True, comm=False)

    if not args.full and 'sparse_deformed' in generated:
        del generated['sparse_deformed']  # remove needless info
    persistent_all(generated)

    generated['kp_driving'] = kp_driving
    generated['kp_source'] = kp_source
    generated['kp_norm'] = kp_norm

    # `generated` contains the following Variables:
    # 'mask': (bs, num_kp+1, h/4, w/4) when scale_factor=0.25
    # 'sparse_deformed': (bs, num_kp+1, num_channel, h/4, w/4)
    # 'occlusion_map': (bs, 1, h/4, w/4)
    # 'deformed': (bs, c, h, w)
    # 'prediction': (bs, c, h, w)

    mode = "arbitrary"
    if "log_dir" in config:
        result_dir = os.path.join(args.out_dir,
                                  os.path.basename(config.log_dir), f"{mode}")
    else:
        result_dir = os.path.join(args.out_dir, "test_result", f"{mode}")

    # create an empty directory to save generated results
    _ = nm.Monitor(result_dir)

    # load the header images.
    header = imread("imgs/header_combined.png", channel_first=True)
    generated_images = list()

    # compute these in advance and reuse them
    nn.forward_all([kp_source["value"], kp_source["jacobian"]],
                   clear_buffer=True)
    nn.forward_all([kp_driving_initial["value"],
                    kp_driving_initial["jacobian"]],
                   clear_buffer=True)

    num_of_driving_frames = driving_video.shape[0]

    for frame_idx in tqdm(range(num_of_driving_frames)):
        driving.d = driving_video[frame_idx][:3, ]
        nn.forward_all([generated["prediction"], generated["deformed"]],
                       clear_buffer=True)

        if args.detailed:
            # visualize source w/ kp, driving w/ kp, deformed source,
            # generated w/ kp, generated image, and occlusion map
            visualization = visualizer.visualize(source=source.d,
                                                 driving=driving.d,
                                                 out=generated)
            if args.full:
                visualization = reshape_result(visualization)  # (H, W, C)
            combined_image = visualization.transpose(2, 0, 1)  # (C, H, W)
        elif args.only_generated:
            combined_image = np.clip(generated["prediction"].d[0], 0.0, 1.0)
            combined_image = (255 * combined_image).astype(np.uint8)  # (C, H, W)
        else:
            # visualize source, driving, and generated image
            driving_fake = np.concatenate(
                [np.clip(driving.d[0], 0.0, 1.0),
                 np.clip(generated["prediction"].d[0], 0.0, 1.0)], axis=2)
            header_source = np.concatenate(
                [np.clip(header / 255., 0.0, 1.0),
                 np.clip(source.d[0], 0.0, 1.0)], axis=2)
            combined_image = np.concatenate([header_source, driving_fake],
                                            axis=1)
            combined_image = (255 * combined_image).astype(np.uint8)

        generated_images.append(combined_image)

    # once the video is generated, save it.
    output_filename = f"{os.path.splitext(os.path.basename(filename))[0]}.mp4"
    output_filename = f"{os.path.basename(args.source)}_by_{output_filename}"
    output_filename = output_filename.replace("#", "_")

    if args.output_png:
        monitor_vis = nm.MonitorImage(output_filename, nm.Monitor(result_dir),
                                      interval=1, num_images=1,
                                      normalize_method=lambda x: x)
        for frame_idx, img in enumerate(generated_images):
            monitor_vis.add(frame_idx, img)
    else:
        generated_images = [_.transpose(1, 2, 0) for _ in generated_images]
        # you might need to change ffmpeg_params according to your environment.
        mimsave(f'{os.path.join(result_dir, output_filename)}',
                generated_images, fps=args.fps,
                ffmpeg_params=["-pix_fmt", "yuv420p",
                               "-vcodec", "libx264",
                               "-f", "mp4",
                               "-q", "0"])
    return
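# Usage sketch (illustrative; assumes the argparse flags mirror the attribute
# names accessed above, e.g. --source / --driving):
#
#   python animate.py --source path/to/source_face.png \
#                     --driving path/to/driving_video.mp4
#
# When --config/--params are omitted, the pretrained VoxCeleb config and
# weights are downloaded automatically at the top of animate().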
def test(encoder_config, transformer_config, decoder_config, encoder_netG,
         transformer_netG, decoder_netG, src_celeb_name, trg_celeb_name,
         test_iterator, monitor, encoder_param_file, transformer_param_file,
         decoder_param_file):
    # prepare nn.Variable
    real_img = nn.Variable((1, 3, 256, 256))
    real_bod_map = nn.Variable((1, 15, 64, 64))
    real_bod_map_resize = nn.Variable((1, 15, 256, 256))

    # encoder
    with nn.parameter_scope(encoder_config["model_name"]):
        _, preds = encoder_netG(
            real_img,
            batch_stat=False,
            planes=encoder_config["model"]["planes"],
            output_nc=encoder_config["model"]["output_nc"],
            num_stacks=encoder_config["model"]["num_stacks"],
            activation=encoder_config["model"]["activation"],
        )
    preds.persistent = True
    preds_unlinked = preds.get_unlinked_variable()

    # load parameters of the encoder
    with nn.parameter_scope(encoder_config["model_name"]):
        nn.load_parameters(encoder_param_file)

    # transformer (generator)
    with nn.parameter_scope('netG_transformer'):
        with nn.parameter_scope('netG_A2B'):
            fake_bod_map = transformer_netG(
                preds, test=True,
                norm_type=transformer_config["norm_type"])

    # load parameters of the transformer
    with nn.parameter_scope('netG_transformer'):
        with nn.parameter_scope('netG_A2B'):
            nn.load_parameters(transformer_param_file)

    fake_bod_map.persistent = True
    fake_bod_map_unlinked = fake_bod_map.get_unlinked_variable()

    # decoder
    with nn.parameter_scope('netG_decoder'):
        fake_img = decoder_netG(fake_bod_map_unlinked, test=True)
    fake_img.persistent = True

    # load parameters of the decoder
    with nn.parameter_scope('netG_decoder'):
        nn.load_parameters(decoder_param_file)

    monitor_vis = nm.MonitorImage('result', monitor, interval=1,
                                  num_images=1,
                                  normalize_method=lambda x: x)

    # test
    num_test_batches = test_iterator.size
    for i in range(num_test_batches):
        _real_img, _, _real_bod_map_resize = test_iterator.next()

        real_img.d = _real_img
        real_bod_map_resize.d = _real_bod_map_resize

        # run the three stages in order
        preds.forward(clear_no_need_grad=True)
        fake_bod_map.forward(clear_no_need_grad=True)
        fake_img.forward(clear_no_need_grad=True)

        images_to_visualize = [real_img.d, preds.d, fake_bod_map.d,
                               fake_img.d, real_bod_map_resize.d]
        visuals = combine_images(images_to_visualize)
        monitor_vis.add(i, visuals)
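# Pipeline sketch: shapes of the three chained stages, taken from the
# Variables declared above (the fake_img shape is inferred from real_img and
# is an assumption here):
#
#   real_img     (1, 3, 256, 256)  --encoder-->      preds         (1, 15, 64, 64)
#   preds        (1, 15, 64, 64)   --transformer-->  fake_bod_map  (1, 15, 64, 64)
#   fake_bod_map (1, 15, 64, 64)   --decoder-->      fake_img      (1, 3, 256, 256)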
def reconstruct(args):
    # Get the context.
    ctx = get_extension_context(args.context)
    nn.set_default_context(ctx)
    logger.setLevel(logging.ERROR)  # to suppress minor messages

    config = read_yaml(args.config)

    dataset_params = config.dataset_params
    model_params = config.model_params

    if args.detailed:
        vis_params = config.visualizer_params
        visualizer = Visualizer(**vis_params)

    if not args.params:
        assert "log_dir" in config, \
            "no log_dir found in config. therefore failed to locate " \
            "pretrained parameters."
        param_file = os.path.join(config.log_dir, config.saved_parameters)
    else:
        param_file = args.params
    nn.load_parameters(param_file)

    bs, h, w, c = [1] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving_initial = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    with nn.parameter_scope("kp_detector"):
        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=True, comm=False)
    persistent_all(kp_source)

    with nn.parameter_scope("kp_detector"):
        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=True, comm=False)
    persistent_all(kp_driving)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=unlink_all(kp_source),
                                              kp_driving=kp_driving,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=True, comm=False)

    if not args.full and 'sparse_deformed' in generated:
        del generated['sparse_deformed']  # remove needless info
    persistent_all(generated)

    generated['kp_driving'] = kp_driving
    generated['kp_source'] = kp_source

    # `generated` contains the following Variables:
    # 'mask': (bs, num_kp+1, h/4, w/4) when scale_factor=0.25
    # 'sparse_deformed': (bs, num_kp+1, num_channel, h/4, w/4)
    # 'occlusion_map': (bs, 1, h/4, w/4)
    # 'deformed': (bs, c, h, w)
    # 'prediction': (bs, c, h, w)

    mode = "reconstruction"
    if "log_dir" in config:
        result_dir = os.path.join(args.out_dir,
                                  os.path.basename(config.log_dir), f"{mode}")
    else:
        result_dir = os.path.join(args.out_dir, "test_result", f"{mode}")

    # create an empty directory to save generated results
    _ = nm.Monitor(result_dir)
    if args.eval:
        os.makedirs(os.path.join(result_dir, "png"), exist_ok=True)

    # load the header images.
    header = imread("imgs/header_combined.png", channel_first=True)

    filenames = sorted(glob.glob(os.path.join(dataset_params.root_dir,
                                              "test", "*")))
    recon_loss_list = list()

    for filename in tqdm(filenames):
        # the process is repeated until all the test data is used
        driving_video = read_video(
            filename, dataset_params.frame_shape)  # (#frames, h, w, 3)
        driving_video = np.transpose(
            driving_video, (0, 3, 1, 2))  # (#frames, 3, h, w)

        generated_images = list()
        source_img = driving_video[0]

        source.d = np.expand_dims(source_img, 0)
        driving_initial.d = driving_video[0]

        # compute these in advance and reuse them
        nn.forward_all([kp_source["value"], kp_source["jacobian"]],
                       clear_buffer=True)

        num_of_driving_frames = driving_video.shape[0]

        for frame_idx in tqdm(range(num_of_driving_frames)):
            driving.d = driving_video[frame_idx]
            nn.forward_all([generated["prediction"], generated["deformed"]],
                           clear_buffer=True)

            if args.detailed:
                # visualize source w/ kp, driving w/ kp, deformed source,
                # generated w/ kp, generated image, and occlusion map
                visualization = visualizer.visualize(source=source.d,
                                                     driving=driving.d,
                                                     out=generated)
                if args.full:
                    visualization = reshape_result(visualization)  # (H, W, C)
                combined_image = visualization.transpose(2, 0, 1)  # (C, H, W)
            elif args.only_generated:
                combined_image = np.clip(generated["prediction"].d[0],
                                         0.0, 1.0)
                combined_image = (255 * combined_image).astype(np.uint8)  # (C, H, W)
            else:
                # visualize source, driving, and generated image
                driving_fake = np.concatenate(
                    [np.clip(driving.d[0], 0.0, 1.0),
                     np.clip(generated["prediction"].d[0], 0.0, 1.0)], axis=2)
                header_source = np.concatenate(
                    [np.clip(header / 255., 0.0, 1.0),
                     np.clip(source.d[0], 0.0, 1.0)], axis=2)
                combined_image = np.concatenate(
                    [header_source, driving_fake], axis=1)
                combined_image = (255 * combined_image).astype(np.uint8)

            generated_images.append(combined_image)
            # compute the L1 distance for this frame.
            recon_loss_list.append(
                np.mean(np.abs(generated["prediction"].d[0] - driving.d[0])))

        # post-process only for the reconstruction evaluation.
        if args.eval:
            # crop the generated-image region only.
            if args.only_generated:
                eval_images = generated_images
            elif args.full:
                eval_images = [_[:, :h, 4 * w:5 * w]
                               for _ in generated_images]
            elif args.detailed:
                assert generated_images[0].shape == (c, h, 5 * w)
                eval_images = [_[:, :, 3 * w:4 * w]
                               for _ in generated_images]
            else:
                eval_images = [_[:, h:, w:] for _ in generated_images]
            # place the frames horizontally and save them for evaluation.
            image_for_eval = np.concatenate(eval_images,
                                            axis=2).transpose(1, 2, 0)
            imsave(os.path.join(result_dir, "png",
                                f"{os.path.basename(filename)}.png"),
                   image_for_eval)

        # once each video is generated, save it.
        output_filename = \
            f"{os.path.splitext(os.path.basename(filename))[0]}.mp4"
        if args.output_png:
            monitor_vis = nm.MonitorImage(output_filename,
                                          nm.Monitor(result_dir),
                                          interval=1, num_images=1,
                                          normalize_method=lambda x: x)
            for frame_idx, img in enumerate(generated_images):
                monitor_vis.add(frame_idx, img)
        else:
            generated_images = [_.transpose(1, 2, 0)
                                for _ in generated_images]
            # you might need to change ffmpeg_params according to your environment.
            mimsave(f'{os.path.join(result_dir, output_filename)}',
                    generated_images, fps=args.fps,
                    ffmpeg_params=["-pix_fmt", "yuv420p",
                                   "-vcodec", "libx264",
                                   "-f", "mp4",
                                   "-q", "0"])

    print(f"Reconstruction loss: {np.mean(recon_loss_list)}")
    return
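# Metric sketch: the value printed above is the per-pixel mean absolute error
# in the [0, 1] image range, averaged first within each frame and then over
# all frames of all test videos:
#
#   recon_loss = mean_over_frames(mean(|prediction - driving|))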
def train_transformer(config, netG, netD, solver_netG, solver_netD,
                      train_iterators, monitor):
    netG_A2B, netG_B2A = netG['netG_A2B'], netG['netG_B2A']
    netD_A, netD_B = netD['netD_A'], netD['netD_B']
    solver_netG_AB, solver_netG_BA = \
        solver_netG['netG_A2B'], solver_netG['netG_B2A']
    solver_netD_A, solver_netD_B = \
        solver_netD['netD_A'], solver_netD['netD_B']
    train_iterator_src, train_iterator_trg = train_iterators

    if config["train"]["cycle_loss"] and \
            config["train"]["cycle_loss"]["lambda"] > 0:
        print(f'Applying Cycle Loss, '
              f'weight: {config["train"]["cycle_loss"]["lambda"]}.')
        with_cycle_loss = True
    else:
        with_cycle_loss = False

    if config["train"]["shape_loss"] and \
            config["train"]["shape_loss"]["lambda"] > 0:
        print(f'Applying Shape Loss using PCA, '
              f'weight: {config["train"]["shape_loss"]["lambda"]}.')
        with_shape_loss = True
    else:
        with_shape_loss = False

    # Load boundary images once to get the Variable shapes.
    bod_map_A = train_iterator_src.next()[0]
    bod_map_B = train_iterator_trg.next()[0]
    real_bod_map_A = nn.Variable(bod_map_A.shape)
    real_bod_map_B = nn.Variable(bod_map_B.shape)
    real_bod_map_A.persistent, real_bod_map_B.persistent = True, True

    ################### Graph Construction ####################
    # Generator
    with nn.parameter_scope('netG_transformer'):
        with nn.parameter_scope('netG_A2B'):
            fake_bod_map_B = netG_A2B(real_bod_map_A, test=False,
                                      norm_type=config["norm_type"])  # (1, 15, 64, 64)
        with nn.parameter_scope('netG_B2A'):
            fake_bod_map_A = netG_B2A(real_bod_map_B, test=False,
                                      norm_type=config["norm_type"])  # (1, 15, 64, 64)
    fake_bod_map_B.persistent, fake_bod_map_A.persistent = True, True

    fake_bod_map_B_unlinked = fake_bod_map_B.get_unlinked_variable()
    fake_bod_map_A_unlinked = fake_bod_map_A.get_unlinked_variable()

    # Reconstruct the boundary images if the cycle loss is applied.
    if with_cycle_loss:
        with nn.parameter_scope('netG_transformer'):
            with nn.parameter_scope('netG_B2A'):
                recon_bod_map_A = netG_B2A(fake_bod_map_B_unlinked,
                                           test=False,
                                           norm_type=config["norm_type"])  # (1, 15, 64, 64)
            with nn.parameter_scope('netG_A2B'):
                recon_bod_map_B = netG_A2B(fake_bod_map_A_unlinked,
                                           test=False,
                                           norm_type=config["norm_type"])  # (1, 15, 64, 64)
        recon_bod_map_A.persistent, recon_bod_map_B.persistent = True, True

    # Discriminator
    with nn.parameter_scope('netD_transformer'):
        with nn.parameter_scope('netD_A'):
            pred_fake_A = netD_A(fake_bod_map_A_unlinked, test=False)
            pred_real_A = netD_A(real_bod_map_A, test=False)
        with nn.parameter_scope('netD_B'):
            pred_fake_B = netD_B(fake_bod_map_B_unlinked, test=False)
            pred_real_B = netD_B(real_bod_map_B, test=False)
    real_target = F.constant(1, pred_fake_A.shape)
    fake_target = F.constant(0, pred_real_A.shape)

    ################### Loss Definition ####################
    # Generator losses
    # LSGAN loss
    loss_gan_A = lsgan_loss(pred_fake_A, real_target)
    loss_gan_B = lsgan_loss(pred_fake_B, real_target)
    loss_gan_A.persistent, loss_gan_B.persistent = True, True
    loss_gan = loss_gan_A + loss_gan_B

    # Cycle loss
    if with_cycle_loss:
        loss_cycle_A = recon_loss(recon_bod_map_A, real_bod_map_A)
        loss_cycle_B = recon_loss(recon_bod_map_B, real_bod_map_B)
        loss_cycle_A.persistent, loss_cycle_B.persistent = True, True
        loss_cycle = loss_cycle_A + loss_cycle_B

    # Shape loss
    if with_shape_loss:
        with nn.parameter_scope("Align"):
            nn.load_parameters(
                config["train"]["shape_loss"]["align_param_path"])
            shape_bod_map_real_A = models.align_resnet(
                real_bod_map_A, fix_parameters=True)
            shape_bod_map_fake_B = models.align_resnet(
                fake_bod_map_B_unlinked, fix_parameters=True)
            shape_bod_map_real_B = models.align_resnet(
                real_bod_map_B, fix_parameters=True)
            shape_bod_map_fake_A = models.align_resnet(
                fake_bod_map_A_unlinked, fix_parameters=True)

        with nn.parameter_scope("PCA"):
            nn.load_parameters(
                config["train"]["shape_loss"]["PCA_param_path"])
            shape_bod_map_real_A = PF.affine(shape_bod_map_real_A, 212,
                                             fix_parameters=True)
            shape_bod_map_real_A = shape_bod_map_real_A[:, :3]
            shape_bod_map_fake_B = PF.affine(shape_bod_map_fake_B, 212,
                                             fix_parameters=True)
            shape_bod_map_fake_B = shape_bod_map_fake_B[:, :3]
            shape_bod_map_real_B = PF.affine(shape_bod_map_real_B, 212,
                                             fix_parameters=True)
            shape_bod_map_real_B = shape_bod_map_real_B[:, :3]
            shape_bod_map_fake_A = PF.affine(shape_bod_map_fake_A, 212,
                                             fix_parameters=True)
            shape_bod_map_fake_A = shape_bod_map_fake_A[:, :3]

        shape_bod_map_real_A.persistent, shape_bod_map_fake_A.persistent = True, True
        shape_bod_map_real_B.persistent, shape_bod_map_fake_B.persistent = True, True

        loss_shape_A = recon_loss(shape_bod_map_real_A, shape_bod_map_fake_B)
        loss_shape_B = recon_loss(shape_bod_map_real_B, shape_bod_map_fake_A)
        loss_shape_A.persistent, loss_shape_B.persistent = True, True
        loss_shape = loss_shape_A + loss_shape_B

    # Total generator loss
    loss_netG = loss_gan
    if with_cycle_loss:
        loss_netG += loss_cycle * config["train"]["cycle_loss"]["lambda"]
    if with_shape_loss:
        loss_netG += loss_shape * config["train"]["shape_loss"]["lambda"]

    # Discriminator losses
    loss_netD_A = lsgan_loss(pred_real_A, real_target) + \
        lsgan_loss(pred_fake_A, fake_target)
    loss_netD_B = lsgan_loss(pred_real_B, real_target) + \
        lsgan_loss(pred_fake_B, fake_target)
    loss_netD_A.persistent, loss_netD_B.persistent = True, True

    loss_netD = loss_netD_A + loss_netD_B

    ################### Setting Solvers ####################
    # Generator solvers
    with nn.parameter_scope('netG_transformer'):
        with nn.parameter_scope('netG_A2B'):
            solver_netG_AB.set_parameters(nn.get_parameters())
        with nn.parameter_scope('netG_B2A'):
            solver_netG_BA.set_parameters(nn.get_parameters())

    # Discriminator solvers
    with nn.parameter_scope('netD_transformer'):
        with nn.parameter_scope('netD_A'):
            solver_netD_A.set_parameters(nn.get_parameters())
        with nn.parameter_scope('netD_B'):
            solver_netD_B.set_parameters(nn.get_parameters())

    ################### Create Monitors ####################
    interval = config["monitor"]["interval"]
    monitors_G_dict = {
        'loss_netG': loss_netG,
        'loss_gan_A': loss_gan_A,
        'loss_gan_B': loss_gan_B
    }
    if with_cycle_loss:
        monitors_G_dict.update({
            'loss_cycle_A': loss_cycle_A,
            'loss_cycle_B': loss_cycle_B
        })
    if with_shape_loss:
        monitors_G_dict.update({
            'loss_shape_A': loss_shape_A,
            'loss_shape_B': loss_shape_B
        })
    monitors_G = MonitorManager(monitors_G_dict, monitor, interval=interval)

    monitors_D_dict = {
        'loss_netD': loss_netD,
        'loss_netD_A': loss_netD_A,
        'loss_netD_B': loss_netD_B
    }
    monitors_D = MonitorManager(monitors_D_dict, monitor, interval=interval)

    monitor_time = nm.MonitorTimeElapsed('time_training', monitor,
                                         interval=interval)
    monitor_vis = nm.MonitorImage('result', monitor, interval=1, num_images=4,
                                  normalize_method=lambda x: x)

    # Dump training information.
    with open(os.path.join(monitor._save_path, "training_info.yaml"),
              "w", encoding="utf-8") as f:
        f.write(yaml.dump(config))

    # Training
    epoch = config["train"]["epochs"]
    i = 0
    iter_per_epoch = \
        train_iterator_src.size // config["train"]["batch_size"] + 1
    for e in range(epoch):
        logger.info(f'Epoch = {e} / {epoch}')
        train_iterator_src._reset()  # rewind the iterator
        train_iterator_trg._reset()  # rewind the iterator
        for _ in range(iter_per_epoch):
            bod_map_A = train_iterator_src.next()[0]
            bod_map_B = train_iterator_trg.next()[0]
            real_bod_map_A.d, real_bod_map_B.d = bod_map_A, bod_map_B

            # Generate fake images.
            fake_bod_map_B.forward(clear_no_need_grad=True)
            fake_bod_map_A.forward(clear_no_need_grad=True)

            # Update Discriminator
            solver_netD_A.zero_grad()
            solver_netD_B.zero_grad()
            loss_netD.forward(clear_no_need_grad=True)
            loss_netD.backward(clear_buffer=True)
            if config["train"]["weight_decay"]:
                solver_netD_A.weight_decay(config["train"]["weight_decay"])
                solver_netD_B.weight_decay(config["train"]["weight_decay"])
            solver_netD_A.update()
            solver_netD_B.update()

            # Update Generator
            solver_netG_BA.zero_grad()
            solver_netG_AB.zero_grad()
            solver_netD_A.zero_grad()
            solver_netD_B.zero_grad()
            fake_bod_map_B_unlinked.grad.zero()
            fake_bod_map_A_unlinked.grad.zero()
            loss_netG.forward(clear_no_need_grad=True)
            loss_netG.backward(clear_buffer=True)
            fake_bod_map_B.backward(grad=None)
            fake_bod_map_A.backward(grad=None)
            solver_netG_AB.update()
            solver_netG_BA.update()

            # Monitors
            monitor_time.add(i)
            monitors_G.add(i)
            monitors_D.add(i)
            i += 1

        images_to_visualize = [real_bod_map_A.d,
                               fake_bod_map_B.d,
                               real_bod_map_B.d]
        if with_cycle_loss:
            images_to_visualize.extend([recon_bod_map_A.d,
                                        fake_bod_map_A.d,
                                        recon_bod_map_B.d])
        else:
            images_to_visualize.extend([fake_bod_map_A.d])
        visuals = combine_images(images_to_visualize)
        monitor_vis.add(i, visuals)

        if e % config["monitor"]["save_interval"] == 0 or e == epoch - 1:
            # Save parameters of the networks.
            netG_B2A_save_path = os.path.join(
                monitor._save_path, f'netG_transformer_B2A_{e}.h5')
            netG_A2B_save_path = os.path.join(
                monitor._save_path, f'netG_transformer_A2B_{e}.h5')
            with nn.parameter_scope('netG_transformer'):
                with nn.parameter_scope('netG_A2B'):
                    nn.save_parameters(netG_A2B_save_path)
                with nn.parameter_scope('netG_B2A'):
                    nn.save_parameters(netG_B2A_save_path)

            netD_A_save_path = os.path.join(
                monitor._save_path, f'netD_transformer_A_{e}.h5')
            netD_B_save_path = os.path.join(
                monitor._save_path, f'netD_transformer_B_{e}.h5')
            with nn.parameter_scope('netD_transformer'):
                with nn.parameter_scope('netD_A'):
                    nn.save_parameters(netD_A_save_path)
                with nn.parameter_scope('netD_B'):
                    nn.save_parameters(netD_B_save_path)
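# Helper sketch (assumption): lsgan_loss and recon_loss are imported from
# elsewhere in this codebase. Consistent with how they are called above,
# minimal versions could look like this.
#
#   def lsgan_loss(pred, target):
#       # least-squares GAN objective
#       return F.mean(F.squared_error(pred, target))
#
#   def recon_loss(x, y):
#       # L1 reconstruction distance
#       return F.mean(F.absolute_error(x, y))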
def train(config, train_iterator, valid_iterator, monitor):
    ################### Graph Construction ####################
    # Training graph
    img, htm = train_iterator.next()
    image = nn.Variable(img.shape)
    heatmap = nn.Variable(htm.shape)

    with nn.parameter_scope(config["model_name"]):
        preds = stacked_hourglass_net(
            image,
            batch_stat=True,
            planes=config["model"]["planes"],
            output_nc=config["model"]["output_nc"],
            num_stacks=config["model"]["num_stacks"],
            activation=config["model"]["activation"],
        )

    if config["finetune"]:
        assert os.path.isfile(config["finetune"]["param_path"]), \
            "params file not found."
        with nn.parameter_scope(config["model_name"]):
            nn.load_parameters(config["finetune"]["param_path"])

    # Loss Definition
    if config["loss_name"] == 'mse':
        def loss_func(pred, target):
            return F.mean(F.squared_error(pred, target))
    elif config["loss_name"] == 'bce':
        def loss_func(pred, target):
            return F.mean(F.binary_cross_entropy(pred, target))
    else:
        raise NotImplementedError

    # Every stack's output is supervised against the same target heatmap.
    losses = []
    for pred in preds:
        loss_local = loss_func(pred, heatmap)
        loss_local.persistent = True
        losses.append(loss_local)

    loss = nn.Variable()
    loss.d = 0
    for loss_local in losses:
        loss += loss_local

    ################### Setting Solvers ####################
    solver = S.Adam(config["train"]["lr"])
    with nn.parameter_scope(config["model_name"]):
        solver.set_parameters(nn.get_parameters())

    # Validation graph
    img, htm = valid_iterator.next()
    val_image = nn.Variable(img.shape)
    val_heatmap = nn.Variable(htm.shape)

    with nn.parameter_scope(config["model_name"]):
        val_preds = stacked_hourglass_net(
            val_image,
            batch_stat=False,
            planes=config["model"]["planes"],
            output_nc=config["model"]["output_nc"],
            num_stacks=config["model"]["num_stacks"],
            activation=config["model"]["activation"],
        )
    for i in range(len(val_preds)):
        val_preds[i].persistent = True

    # Loss definition for validation
    val_losses = []
    for pred in val_preds:
        loss_local = loss_func(pred, val_heatmap)
        loss_local.persistent = True
        val_losses.append(loss_local)

    val_loss = nn.Variable()
    val_loss.d = 0
    for loss_local in val_losses:
        val_loss += loss_local

    num_train_batches = train_iterator.size // train_iterator.batch_size + 1
    num_valid_batches = valid_iterator.size // valid_iterator.batch_size + 1

    ################### Create Monitors ####################
    monitors_train_dict = {'loss_total': loss}
    for i in range(len(losses)):
        monitors_train_dict.update({f'loss_{i}': losses[i]})

    monitors_val_dict = {'val_loss_total': val_loss}
    for i in range(len(val_losses)):
        monitors_val_dict.update({f'val_loss_{i}': val_losses[i]})

    monitors_train = MonitorManager(
        monitors_train_dict, monitor,
        interval=config["monitor"]["interval"] * num_train_batches)
    monitors_val = MonitorManager(
        monitors_val_dict, monitor,
        interval=config["monitor"]["interval"] * num_valid_batches)
    monitor_time = nm.MonitorTimeElapsed(
        'time', monitor,
        interval=config["monitor"]["interval"] * num_train_batches)
    monitor_vis = nm.MonitorImage(
        'result', monitor, interval=1, num_images=4,
        normalize_method=lambda x: x)
    monitor_vis_val = nm.MonitorImage(
        'result_val', monitor, interval=1, num_images=4,
        normalize_method=lambda x: x)
    os.mkdir(os.path.join(monitor._save_path, 'model'))

    # Dump training information.
    with open(os.path.join(monitor._save_path, "training_info.yaml"),
              "w", encoding="utf-8") as f:
        f.write(yaml.dump(config))

    # Training
    best_epoch = 0
    best_val_loss = np.inf
    for e in range(config["train"]["epochs"]):
        watch_val_loss = 0

        # training loop
        for i in range(num_train_batches):
            image.d, heatmap.d = train_iterator.next()
            solver.zero_grad()
            loss.forward()
            loss.backward(clear_buffer=True)
            solver.weight_decay(config["train"]["weight_decay"])
            solver.update()
            monitors_train.add(e * num_train_batches + i)
            monitor_time.add(e * num_train_batches + i)

        # validation loop
        for i in range(num_valid_batches):
            val_image.d, val_heatmap.d = valid_iterator.next()
            val_loss.forward(clear_buffer=True)
            monitors_val.add(e * num_valid_batches + i)
            watch_val_loss += val_loss.d.copy()
        watch_val_loss /= num_valid_batches

        # visualization
        visuals = combine_images([image.d, preds[0].d, preds[1].d, heatmap.d])
        monitor_vis.add(e, visuals)
        visuals_val = combine_images(
            [val_image.d, val_preds[0].d, val_preds[1].d, val_heatmap.d])
        monitor_vis_val.add(e, visuals_val)

        # Update the best result only when the validation loss improves,
        # and save weights on improvement or at the regular save interval.
        improved = watch_val_loss < best_val_loss
        if improved:
            best_val_loss = watch_val_loss
            best_epoch = e
        if improved or e % config["monitor"]["save_interval"] == 0:
            save_path = os.path.join(monitor._save_path,
                                     f'model/model_epoch-{e}.h5')
            with nn.parameter_scope(config["model_name"]):
                nn.save_parameters(save_path)

    # save the last parameters as well
    save_path = os.path.join(monitor._save_path, f'model/model_epoch-{e}.h5')
    with nn.parameter_scope(config["model_name"]):
        nn.save_parameters(save_path)

    logger.info(f'Best Epoch: {best_epoch}.')
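# Note on the summed loss above: stacked hourglass training supervises the
# output of every stack against the same target heatmap (intermediate
# supervision), so `loss` is the sum of num_stacks per-stack terms, each
# monitored separately as loss_0, loss_1, ...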