def init_nnabla(conf=None, ext_name=None, device_id=None, type_config=None):
    import nnabla as nn
    from nnabla.ext_utils import get_extension_context
    from .comm import CommunicatorWrapper

    if conf is None:
        conf = AttrDict()
    if ext_name is not None:
        conf.ext_name = ext_name
    if device_id is not None:
        conf.device_id = device_id
    if type_config is not None:
        conf.type_config = type_config

    # set context
    ctx = get_extension_context(
        ext_name=conf.ext_name, device_id=conf.device_id, type_config=conf.type_config)

    # init communicator
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)

    # disable outputs from logger except rank==0
    if comm.rank > 0:
        from nnabla import logger
        import logging
        logger.setLevel(logging.ERROR)

    return comm
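
# Hedged usage sketch (not part of the original source): one plausible way to call
# the conf-based init_nnabla above via its keyword overrides. The "cudnn" extension,
# device id "0", and "float" type_config are illustrative values, not repo defaults.
def _example_init_nnabla_with_overrides():
    # keyword overrides are written into a fresh AttrDict inside init_nnabla
    comm = init_nnabla(ext_name="cudnn", device_id="0", type_config="float")
    # the returned CommunicatorWrapper carries the per-process rank and context
    print(f"rank {comm.rank} initialized with context {comm.ctx}")
    return comm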
def init_nnabla(ctx_config):
    import nnabla as nn
    from nnabla.ext_utils import get_extension_context
    from comm import CommunicatorWrapper

    # set context
    ctx = get_extension_context(**ctx_config)

    # init communicator
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)

    # disable outputs from logger except rank==0
    if comm.rank > 0:
        from nnabla import logger
        import logging
        logger.setLevel(logging.ERROR)

    return comm
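
# Hedged usage sketch (not part of the original source): the dict-based variant above
# forwards ctx_config straight to get_extension_context, so its keys must match that
# function's keyword arguments. The values below are illustrative assumptions.
def _example_init_nnabla_with_dict():
    ctx_config = {
        "ext_name": "cudnn",     # assumed GPU extension; "cpu" also works
        "device_id": "0",
        "type_config": "float",
    }
    comm = init_nnabla(ctx_config)
    return comm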
def animate(args):
    # get context
    ctx = get_extension_context(args.context)
    nn.set_default_context(ctx)
    logger.setLevel(logging.ERROR)  # to suppress minor messages

    if not args.config:
        assert not args.params, "pretrained weights file is given, but corresponding config file is not. Please give both."
        download_provided_file(
            "https://nnabla.org/pretrained-models/nnabla-examples/GANs/first-order-model/voxceleb_trained_info.yaml")
        args.config = 'voxceleb_trained_info.yaml'

        download_provided_file(
            "https://nnabla.org/pretrained-models/nnabla-examples/GANs/first-order-model/pretrained_fomm_params.h5")

    config = read_yaml(args.config)
    dataset_params = config.dataset_params
    model_params = config.model_params

    if args.detailed:
        vis_params = config.visualizer_params
        visualizer = Visualizer(**vis_params)

    if not args.params:
        assert "log_dir" in config, "no log_dir found in config. therefore failed to locate pretrained parameters."
        param_file = os.path.join(config.log_dir, config.saved_parameters)
    else:
        param_file = args.params
    print(f"Loading {param_file} for image animation...")
    nn.load_parameters(param_file)

    bs, h, w, c = [1] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving_initial = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    filename = args.driving

    # process is repeated until all the test data is used
    driving_video = read_video(
        filename, dataset_params.frame_shape)  # (#frames, h, w, 3)
    driving_video = np.transpose(
        driving_video, (0, 3, 1, 2))  # (#frames, 3, h, w)

    source_img = imread(args.source, channel_first=True,
                        size=(256, 256)) / 255.
    source_img = source_img[:3]

    source.d = np.expand_dims(source_img, 0)
    driving_initial.d = driving_video[0][:3, ]

    with nn.parameter_scope("kp_detector"):
        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=True, comm=False)
        persistent_all(kp_source)

    with nn.parameter_scope("kp_detector"):
        kp_driving_initial = detect_keypoint(driving_initial,
                                             **model_params.kp_detector_params,
                                             **model_params.common_params,
                                             test=True, comm=False)
        persistent_all(kp_driving_initial)

    with nn.parameter_scope("kp_detector"):
        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=True, comm=False)
        persistent_all(kp_driving)

    if args.adapt_movement_scale:
        nn.forward_all([kp_source["value"], kp_source["jacobian"],
                        kp_driving_initial["value"], kp_driving_initial["jacobian"]])
        source_area = ConvexHull(kp_source['value'].d[0]).volume
        driving_area = ConvexHull(kp_driving_initial['value'].d[0]).volume
        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
    else:
        adapt_movement_scale = 1

    kp_norm = adjust_kp(kp_source=unlink_all(kp_source),
                        kp_driving=kp_driving,
                        kp_driving_initial=unlink_all(kp_driving_initial),
                        adapt_movement_scale=adapt_movement_scale,
                        use_relative_movement=args.unuse_relative_movement,
                        use_relative_jacobian=args.unuse_relative_jacobian)
    persistent_all(kp_norm)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=unlink_all(kp_source),
                                              kp_driving=kp_norm,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=True, comm=False)

    if not args.full and 'sparse_deformed' in generated:
        del generated['sparse_deformed']  # remove needless info
    persistent_all(generated)

    generated['kp_driving'] = kp_driving
    generated['kp_source'] = kp_source
    generated['kp_norm'] = kp_norm

    # generated contains these values;
    # 'mask': <Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
    # 'sparse_deformed': <Variable((bs, num_kp+1, num_channel, h/4, w/4))  # (bs, num_kp + 1, c, h, w)
    # 'occlusion_map': <Variable((bs, 1, h/4, w/4))
    # 'deformed': <Variable((bs, c, h, w))
    # 'prediction': <Variable((bs, c, h, w))

    mode = "arbitrary"
    if "log_dir" in config:
        result_dir = os.path.join(args.out_dir,
                                  os.path.basename(config.log_dir), f"{mode}")
    else:
        result_dir = os.path.join(args.out_dir, "test_result", f"{mode}")

    # create an empty directory to save generated results
    _ = nm.Monitor(result_dir)

    # load the header image.
    header = imread("imgs/header_combined.png", channel_first=True)

    generated_images = list()

    # compute these in advance and reuse
    nn.forward_all([kp_source["value"], kp_source["jacobian"]],
                   clear_buffer=True)
    nn.forward_all([kp_driving_initial["value"], kp_driving_initial["jacobian"]],
                   clear_buffer=True)

    num_of_driving_frames = driving_video.shape[0]

    for frame_idx in tqdm(range(num_of_driving_frames)):
        driving.d = driving_video[frame_idx][:3, ]
        nn.forward_all([generated["prediction"], generated["deformed"]],
                       clear_buffer=True)

        if args.detailed:
            # visualize source w/kp, driving w/kp, deformed source, generated w/kp, generated image, occlusion map
            visualization = visualizer.visualize(
                source=source.d, driving=driving.d, out=generated)
            if args.full:
                visualization = reshape_result(visualization)  # (H, W, C)
            combined_image = visualization.transpose(2, 0, 1)  # (C, H, W)

        elif args.only_generated:
            combined_image = np.clip(generated["prediction"].d[0], 0.0, 1.0)
            combined_image = (255 * combined_image).astype(np.uint8)  # (C, H, W)

        else:
            # visualize source, driving, and generated image
            driving_fake = np.concatenate([np.clip(driving.d[0], 0.0, 1.0),
                                           np.clip(generated["prediction"].d[0], 0.0, 1.0)],
                                          axis=2)
            header_source = np.concatenate([np.clip(header / 255., 0.0, 1.0),
                                            np.clip(source.d[0], 0.0, 1.0)],
                                           axis=2)
            combined_image = np.concatenate([header_source, driving_fake], axis=1)
            combined_image = (255 * combined_image).astype(np.uint8)

        generated_images.append(combined_image)

    # once each video is generated, save it.
    output_filename = f"{os.path.splitext(os.path.basename(filename))[0]}.mp4"
    output_filename = f"{os.path.basename(args.source)}_by_{output_filename}"
    output_filename = output_filename.replace("#", "_")

    if args.output_png:
        monitor_vis = nm.MonitorImage(output_filename, nm.Monitor(result_dir),
                                      interval=1, num_images=1,
                                      normalize_method=lambda x: x)
        for frame_idx, img in enumerate(generated_images):
            monitor_vis.add(frame_idx, img)
    else:
        generated_images = [_.transpose(1, 2, 0) for _ in generated_images]
        # you might need to change ffmpeg_params according to your environment.
        mimsave(f'{os.path.join(result_dir, output_filename)}', generated_images,
                fps=args.fps,
                ffmpeg_params=["-pix_fmt", "yuv420p",
                               "-vcodec", "libx264",
                               "-f", "mp4",
                               "-q", "0"])

    return
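
# Hedged sketch (not part of the original source): a minimal argparse front end for
# animate(), reconstructed from the attributes the function actually reads
# (args.context, args.config, args.params, args.source, args.driving, args.out_dir,
# args.fps, and the boolean switches). Flag spellings and default values are
# assumptions, not the repo's official CLI.
def _example_animate_cli():
    import argparse
    parser = argparse.ArgumentParser(description="image animation demo (sketch)")
    parser.add_argument("--config", default=None, help="trained-info yaml")
    parser.add_argument("--params", default=None, help="pretrained .h5 weights")
    parser.add_argument("--source", required=True, help="source image")
    parser.add_argument("--driving", required=True, help="driving video")
    parser.add_argument("--out-dir", dest="out_dir", default="result")
    parser.add_argument("--context", default="cudnn")
    parser.add_argument("--fps", type=int, default=24)
    parser.add_argument("--detailed", action="store_true")
    parser.add_argument("--full", action="store_true")
    parser.add_argument("--only-generated", dest="only_generated", action="store_true")
    parser.add_argument("--output-png", dest="output_png", action="store_true")
    parser.add_argument("--adapt-movement-scale", dest="adapt_movement_scale",
                        action="store_true")
    # the "unuse_*" attributes read by animate() look like store_false flags,
    # i.e. relative movement/jacobian stay enabled unless these flags are passed.
    parser.add_argument("--unuse-relative-movement", dest="unuse_relative_movement",
                        action="store_false")
    parser.add_argument("--unuse-relative-jacobian", dest="unuse_relative_jacobian",
                        action="store_false")
    args = parser.parse_args()
    animate(args)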
def reconstruct(args):
    # get context
    ctx = get_extension_context(args.context)
    nn.set_default_context(ctx)
    logger.setLevel(logging.ERROR)  # to suppress minor messages

    config = read_yaml(args.config)
    dataset_params = config.dataset_params
    model_params = config.model_params

    if args.detailed:
        vis_params = config.visualizer_params
        visualizer = Visualizer(**vis_params)

    if not args.params:
        assert "log_dir" in config, "no log_dir found in config. therefore failed to locate pretrained parameters."
        param_file = os.path.join(config.log_dir, config.saved_parameters)
    else:
        param_file = args.params
    nn.load_parameters(param_file)

    bs, h, w, c = [1] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving_initial = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    with nn.parameter_scope("kp_detector"):
        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=True, comm=False)
        persistent_all(kp_source)

    with nn.parameter_scope("kp_detector"):
        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=True, comm=False)
        persistent_all(kp_driving)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=unlink_all(kp_source),
                                              kp_driving=kp_driving,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=True, comm=False)

    if not args.full and 'sparse_deformed' in generated:
        del generated['sparse_deformed']  # remove needless info
    persistent_all(generated)

    generated['kp_driving'] = kp_driving
    generated['kp_source'] = kp_source

    # generated contains these values;
    # 'mask': <Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
    # 'sparse_deformed': <Variable((bs, num_kp+1, num_channel, h/4, w/4))  # (bs, num_kp + 1, c, h, w)
    # 'occlusion_map': <Variable((bs, 1, h/4, w/4))
    # 'deformed': <Variable((bs, c, h, w))
    # 'prediction': <Variable((bs, c, h, w))

    mode = "reconstruction"
    if "log_dir" in config:
        result_dir = os.path.join(args.out_dir,
                                  os.path.basename(config.log_dir), f"{mode}")
    else:
        result_dir = os.path.join(args.out_dir, "test_result", f"{mode}")

    # create an empty directory to save generated results
    _ = nm.Monitor(result_dir)
    if args.eval:
        os.makedirs(os.path.join(result_dir, "png"), exist_ok=True)

    # load the header image.
    header = imread("imgs/header_combined.png", channel_first=True)

    filenames = sorted(glob.glob(os.path.join(
        dataset_params.root_dir, "test", "*")))
    recon_loss_list = list()

    for filename in tqdm(filenames):
        # process is repeated until all the test data is used
        driving_video = read_video(
            filename, dataset_params.frame_shape)  # (#frames, h, w, 3)
        driving_video = np.transpose(
            driving_video, (0, 3, 1, 2))  # (#frames, 3, h, w)

        generated_images = list()
        source_img = driving_video[0]

        source.d = np.expand_dims(source_img, 0)
        driving_initial.d = driving_video[0]

        # compute these in advance and reuse
        nn.forward_all(
            [kp_source["value"], kp_source["jacobian"]], clear_buffer=True)

        num_of_driving_frames = driving_video.shape[0]

        for frame_idx in tqdm(range(num_of_driving_frames)):
            driving.d = driving_video[frame_idx]
            nn.forward_all([generated["prediction"], generated["deformed"]],
                           clear_buffer=True)

            if args.detailed:
                # visualize source w/kp, driving w/kp, deformed source, generated w/kp, generated image, occlusion map
                visualization = visualizer.visualize(
                    source=source.d, driving=driving.d, out=generated)
                if args.full:
                    visualization = reshape_result(visualization)  # (H, W, C)
                combined_image = visualization.transpose(2, 0, 1)  # (C, H, W)

            elif args.only_generated:
                combined_image = np.clip(
                    generated["prediction"].d[0], 0.0, 1.0)
                combined_image = (
                    255 * combined_image).astype(np.uint8)  # (C, H, W)

            else:
                # visualize source, driving, and generated image
                driving_fake = np.concatenate([np.clip(driving.d[0], 0.0, 1.0),
                                               np.clip(generated["prediction"].d[0], 0.0, 1.0)],
                                              axis=2)
                header_source = np.concatenate([np.clip(header / 255., 0.0, 1.0),
                                                np.clip(source.d[0], 0.0, 1.0)],
                                               axis=2)
                combined_image = np.concatenate(
                    [header_source, driving_fake], axis=1)
                combined_image = (255 * combined_image).astype(np.uint8)

            generated_images.append(combined_image)
            # compute L1 distance per frame.
            recon_loss_list.append(
                np.mean(np.abs(generated["prediction"].d[0] - driving.d[0])))

        # post process only for reconstruction evaluation.
        if args.eval:
            # crop generated images region only.
            if args.only_generated:
                eval_images = generated_images
            elif args.full:
                eval_images = [_[:, :h, 4*w:5*w] for _ in generated_images]
            elif args.detailed:
                assert generated_images[0].shape == (c, h, 5*w)
                eval_images = [_[:, :, 3*w:4*w] for _ in generated_images]
            else:
                eval_images = [_[:, h:, w:] for _ in generated_images]
            # place them horizontally and save for evaluation.
            image_for_eval = np.concatenate(
                eval_images, axis=2).transpose(1, 2, 0)
            imsave(os.path.join(result_dir, "png",
                                f"{os.path.basename(filename)}.png"),
                   image_for_eval)

        # once each video is generated, save it.
        output_filename = f"{os.path.splitext(os.path.basename(filename))[0]}.mp4"
        if args.output_png:
            monitor_vis = nm.MonitorImage(output_filename, nm.Monitor(result_dir),
                                          interval=1, num_images=1,
                                          normalize_method=lambda x: x)
            for frame_idx, img in enumerate(generated_images):
                monitor_vis.add(frame_idx, img)
        else:
            generated_images = [_.transpose(1, 2, 0) for _ in generated_images]
            # you might need to change ffmpeg_params according to your environment.
            mimsave(f'{os.path.join(result_dir, output_filename)}', generated_images,
                    fps=args.fps,
                    ffmpeg_params=["-pix_fmt", "yuv420p",
                                   "-vcodec", "libx264",
                                   "-f", "mp4",
                                   "-q", "0"])

    print(f"Reconstruction loss: {np.mean(recon_loss_list)}")
    return
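
# Hedged sketch (not part of the original source): illustrates why reconstruct()
# crops `[:, h:, w:]` in the default (non-detailed) case. The combined frame is a
# 2x2 tile: top row = [header | source], bottom row = [driving | prediction], so
# the generated image sits in the bottom-right quadrant. Argument names below are
# illustrative; all inputs are assumed to be (c, h, w) float arrays in [0, 1].
def _example_default_tile_and_crop(header_img, source_img, driving_img, prediction):
    import numpy as np
    top = np.concatenate([header_img, source_img], axis=2)      # (c, h, 2w)
    bottom = np.concatenate([driving_img, prediction], axis=2)  # (c, h, 2w)
    tile = np.concatenate([top, bottom], axis=1)                # (c, 2h, 2w)
    c, h, w = prediction.shape
    generated_only = tile[:, h:, w:]                            # bottom-right quadrant
    assert generated_only.shape == prediction.shape
    return tile, generated_only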
def train(args):
    # get context
    ctx = get_extension_context(args.context)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    config = read_yaml(args.config)

    if args.info:
        config.monitor_params.info = args.info

    if comm.size == 1:
        comm = None
    else:
        # disable outputs from logger except its rank = 0
        if comm.rank > 0:
            import logging
            logger.setLevel(logging.ERROR)

    test = False
    train_params = config.train_params
    dataset_params = config.dataset_params
    model_params = config.model_params

    loss_flags = get_loss_flags(train_params)

    start_epoch = 0

    rng = np.random.RandomState(device_id)
    data_iterator = frame_data_iterator(root_dir=dataset_params.root_dir,
                                        frame_shape=dataset_params.frame_shape,
                                        id_sampling=dataset_params.id_sampling,
                                        is_train=True,
                                        random_seed=rng,
                                        augmentation_params=dataset_params.augmentation_params,
                                        batch_size=train_params['batch_size'],
                                        shuffle=True,
                                        with_memory_cache=False,
                                        with_file_cache=False)

    if n_devices > 1:
        data_iterator = data_iterator.slice(rng=rng,
                                            num_of_slices=comm.size,
                                            slice_pos=comm.rank)
        # workaround not to use memory cache
        data_iterator._data_source._on_memory = False
        logger.info("Disabled on memory data cache.")

    bs, h, w, c = [train_params.batch_size] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    with nn.parameter_scope("kp_detector"):
        # kp_X = {"value": Variable((bs, 10, 2)), "jacobian": Variable((bs, 10, 2, 2))}
        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=test, comm=comm)
        persistent_all(kp_source)

        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=test, comm=comm)
        persistent_all(kp_driving)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=kp_source,
                                              kp_driving=kp_driving,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=test, comm=comm)

    # generated is a dictionary containing;
    # 'mask': Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
    # 'sparse_deformed': Variable((bs, num_kp + 1, num_channel, h/4, w/4))
    # 'occlusion_map': Variable((bs, 1, h/4, w/4))
    # 'deformed': Variable((bs, c, h, w))
    # 'prediction': Variable((bs, c, h, w))  Only this is fed to the discriminator.

    generated["prediction"].persistent = True

    pyramide_real = get_image_pyramid(driving, train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_real)

    pyramide_fake = get_image_pyramid(generated['prediction'],
                                      train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_fake)

    total_loss_G = None  # dummy; defined temporarily
    loss_var_dict = {}

    # perceptual loss using VGG19 (always applied)
    if loss_flags.use_perceptual_loss:
        logger.info("Use Perceptual Loss.")
        scales = train_params.scales
        weights = train_params.loss_weights.perceptual
        vgg_param_path = train_params.vgg_param_path
        percep_loss = perceptual_loss(pyramide_real, pyramide_fake,
                                      scales, weights, vgg_param_path)
        percep_loss.persistent = True
        loss_var_dict['perceptual_loss'] = percep_loss
        total_loss_G = percep_loss

    # (LS)GAN loss and feature matching loss
    if loss_flags.use_gan_loss:
        logger.info("Use GAN Loss.")
        with nn.parameter_scope("discriminator"):
            discriminator_maps_generated = multiscale_discriminator(pyramide_fake,
                                                                    kp=unlink_all(kp_driving),
                                                                    **model_params.discriminator_params,
                                                                    **model_params.common_params,
                                                                    test=test, comm=comm)
            discriminator_maps_real = multiscale_discriminator(pyramide_real,
                                                               kp=unlink_all(kp_driving),
                                                               **model_params.discriminator_params,
                                                               **model_params.common_params,
                                                               test=test, comm=comm)

        for v in discriminator_maps_generated["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_generated["prediction_map_1"].persistent = True

        for v in discriminator_maps_real["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_real["prediction_map_1"].persistent = True

        for i, scale in enumerate(model_params.discriminator_params.scales):
            key = f'prediction_map_{scale}'.replace('.', '-')
            lsgan_loss_weight = train_params.loss_weights.generator_gan
            # LSGAN loss for Generator
            if i == 0:
                gan_loss_gen = lsgan_loss(discriminator_maps_generated[key],
                                          lsgan_loss_weight)
            else:
                gan_loss_gen += lsgan_loss(discriminator_maps_generated[key],
                                           lsgan_loss_weight)
            # LSGAN loss for Discriminator
            if i == 0:
                gan_loss_dis = lsgan_loss(discriminator_maps_real[key],
                                          lsgan_loss_weight,
                                          discriminator_maps_generated[key])
            else:
                gan_loss_dis += lsgan_loss(discriminator_maps_real[key],
                                           lsgan_loss_weight,
                                           discriminator_maps_generated[key])

        gan_loss_dis.persistent = True
        loss_var_dict['gan_loss_dis'] = gan_loss_dis
        total_loss_D = gan_loss_dis
        total_loss_D.persistent = True

        gan_loss_gen.persistent = True
        loss_var_dict['gan_loss_gen'] = gan_loss_gen
        total_loss_G += gan_loss_gen

        if loss_flags.use_feature_matching_loss:
            logger.info("Use Feature Matching Loss.")
            fm_weights = train_params.loss_weights.feature_matching
            fm_loss = feature_matching_loss(discriminator_maps_real,
                                            discriminator_maps_generated,
                                            model_params, fm_weights)
            fm_loss.persistent = True
            loss_var_dict['feature_matching_loss'] = fm_loss
            total_loss_G += fm_loss

    # transform loss
    if loss_flags.use_equivariance_value_loss or loss_flags.use_equivariance_jacobian_loss:
        transform = Transform(bs, **config.train_params.transform_params)
        transformed_frame = transform.transform_frame(driving)

        with nn.parameter_scope("kp_detector"):
            transformed_kp = detect_keypoint(transformed_frame,
                                             **model_params.kp_detector_params,
                                             **model_params.common_params,
                                             test=test, comm=comm)
        persistent_all(transformed_kp)

        # Value loss part
        if loss_flags.use_equivariance_value_loss:
            logger.info("Use Equivariance Value Loss.")
            warped_kp_value = transform.warp_coordinates(transformed_kp['value'])
            eq_value_weight = train_params.loss_weights.equivariance_value
            eq_value_loss = equivariance_value_loss(kp_driving['value'],
                                                    warped_kp_value,
                                                    eq_value_weight)
            eq_value_loss.persistent = True
            loss_var_dict['equivariance_value_loss'] = eq_value_loss
            total_loss_G += eq_value_loss

        # Jacobian loss part
        if loss_flags.use_equivariance_jacobian_loss:
            logger.info("Use Equivariance Jacobian Loss.")
            arithmetic_jacobian = transform.jacobian(transformed_kp['value'])
            eq_jac_weight = train_params.loss_weights.equivariance_jacobian
            eq_jac_loss = equivariance_jacobian_loss(kp_driving['jacobian'],
                                                     arithmetic_jacobian,
                                                     transformed_kp['jacobian'],
                                                     eq_jac_weight)
            eq_jac_loss.persistent = True
            loss_var_dict['equivariance_jacobian_loss'] = eq_jac_loss
            total_loss_G += eq_jac_loss

    assert total_loss_G is not None
    total_loss_G.persistent = True
    loss_var_dict['total_loss_gen'] = total_loss_G

    # -------------------- Create Monitors --------------------
    monitors_gen, monitors_dis, monitor_time, monitor_vis, log_dir = get_monitors(
        config, loss_flags, loss_var_dict)

    if device_id == 0:
        # Dump training info .yaml
        _ = shutil.copy(args.config, log_dir)  # copy the config yaml
        training_info_yaml = os.path.join(log_dir, "training_info.yaml")
        os.rename(os.path.join(log_dir, os.path.basename(args.config)),
                  training_info_yaml)
        # then add additional information
        with open(training_info_yaml, "a", encoding="utf-8") as f:
            f.write(f"\nlog_dir: {log_dir}\nsaved_parameter: None")

    # -------------------- Solver Setup --------------------
    solvers = setup_solvers(train_params)
    solver_generator = solvers["generator"]
    solver_discriminator = solvers["discriminator"]
    solver_kp_detector = solvers["kp_detector"]

    # max epochs
    num_epochs = train_params['num_epochs']

    # iterations per epoch; will be increased by num_repeats
    num_iter_per_epoch = data_iterator.size // bs
    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        num_iter_per_epoch *= config.train_params.num_repeats

    # decay learning rate when the current epoch reaches one of the epochs
    # defined in epoch_milestones
    lr_decay_at_epochs = train_params['epoch_milestones']  # ex. [60, 90]
    gamma = 0.1  # decay rate

    # -------------------- For finetuning ---------------------
    if args.ft_params:
        assert os.path.isfile(args.ft_params)
        logger.info(f"load {args.ft_params} for finetuning.")
        nn.load_parameters(args.ft_params)
        start_epoch = int(os.path.splitext(
            os.path.basename(args.ft_params))[0].split("epoch_")[1])

        # set solver's state
        for name, solver in solvers.items():
            saved_states = os.path.join(
                os.path.dirname(args.ft_params),
                f"state_{name}_at_epoch_{start_epoch}.h5")
            solver.load_states(saved_states)

        start_epoch += 1
        logger.info(f"Resuming from epoch {start_epoch}.")

    logger.info(
        f"Start training. Total epoch: {num_epochs - start_epoch}, {num_iter_per_epoch * n_devices} iter/epoch.")

    for e in range(start_epoch, num_epochs):
        logger.info(f"Epoch: {e} / {num_epochs}.")
        data_iterator._reset()  # rewind the iterator at the beginning

        # learning rate scheduler
        if e in lr_decay_at_epochs:
            logger.info("Learning rate decayed.")
            learning_rate_decay(solvers, gamma=gamma)

        for i in range(num_iter_per_epoch):
            _driving, _source = data_iterator.next()
            source.d = _source
            driving.d = _driving

            # update generator and keypoint detector
            total_loss_G.forward()

            if device_id == 0:
                monitors_gen.add((e * num_iter_per_epoch + i) * n_devices)

            solver_generator.zero_grad()
            solver_kp_detector.zero_grad()

            callback = None
            if n_devices > 1:
                params = [x.grad for x in solver_generator.get_parameters().values()] + \
                         [x.grad for x in solver_kp_detector.get_parameters().values()]
                callback = comm.all_reduce_callback(params, 2 << 20)
            total_loss_G.backward(clear_buffer=True,
                                  communicator_callbacks=callback)

            solver_generator.update()
            solver_kp_detector.update()

            if loss_flags.use_gan_loss:
                # update discriminator
                total_loss_D.forward(clear_no_need_grad=True)

                if device_id == 0:
                    monitors_dis.add((e * num_iter_per_epoch + i) * n_devices)

                solver_discriminator.zero_grad()

                callback = None
                if n_devices > 1:
                    params = [x.grad for x in solver_discriminator.get_parameters().values()]
                    callback = comm.all_reduce_callback(params, 2 << 20)
                total_loss_D.backward(clear_buffer=True,
                                      communicator_callbacks=callback)

                solver_discriminator.update()

            if device_id == 0:
                monitor_time.add((e * num_iter_per_epoch + i) * n_devices)

            if device_id == 0 and \
                    ((e * num_iter_per_epoch + i) * n_devices) % config.monitor_params.visualize_freq == 0:
                images_to_visualize = [source.d, driving.d, generated["prediction"].d]
                visuals = combine_images(images_to_visualize)
                monitor_vis.add((e * num_iter_per_epoch + i) * n_devices, visuals)

        if device_id == 0:
            if e % train_params.checkpoint_freq == 0 or e == num_epochs - 1:
                save_parameters(e, log_dir, solvers)

    return
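
# Hedged sketch (not part of the original source): a minimal entry point for train(),
# reconstructed from the attributes it reads (args.context, args.config, args.info,
# args.ft_params). Flag spellings and defaults are assumptions. With
# MultiProcessDataParalellCommunicator, multi-GPU runs are typically launched with one
# process per device, e.g. `mpirun -N 4 python train.py --config ...`; the exact
# launcher options depend on your MPI installation.
def _example_train_entry_point():
    import argparse
    parser = argparse.ArgumentParser(description="FOMM training (sketch)")
    parser.add_argument("--config", required=True, help="training config yaml")
    parser.add_argument("--context", default="cudnn", help="extension context")
    parser.add_argument("--info", default=None, help="note stored in monitor params")
    parser.add_argument("--ft-params", dest="ft_params", default=None,
                        help="parameter .h5 file to resume/finetune from")
    args = parser.parse_args()
    train(args)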