def main(
        # General arguments
        exp_dir, resume_dir=None, start_epoch=None, epochs=(90,), iterations=None, resolutions=(128, 256),
        lr_gen=(1e-4,), lr_dis=(1e-4,), gpus=None, workers=4, batch_size=(64,), seed=None, log_freq=20,
        # Data arguments
        train_dataset='opencv_video_seq_dataset.VideoSeqDataset', val_dataset=None,
        numpy_transforms=None,
        tensor_transforms=('img_landmarks_transforms.ToTensor()',
                           'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        # Training arguments
        optimizer='optim.SGD(momentum=0.9,weight_decay=1e-4)',
        scheduler='lr_scheduler.StepLR(step_size=30,gamma=0.1)', pretrained=False,
        criterion_pixelwise='nn.L1Loss', criterion_id='vgg_loss.VGGLoss', criterion_attr='vgg_loss.VGGLoss',
        criterion_gan='gan_loss.GANLoss(use_lsgan=True)',
        generator='res_unet.MultiScaleResUNet(in_nc=4,out_nc=3)',
        discriminator='discriminators_pix2pix.MultiscaleDiscriminator',
        reenactment_model=None, seg_model=None, lms_model=None,
        pix_weight=0.1, rec_weight=1.0, gan_weight=0.001, background_value=-1.0):
    """Train the face inpainting GAN (generator ``Gc`` vs discriminator ``D``).

    Runs a multi-resolution training loop: for each resolution in ``resolutions`` it
    instantiates fresh optimizers/schedulers (from the ``optimizer``/``scheduler``
    factory strings), trains for the corresponding number of epochs, optionally
    validates, and checkpoints ``Gc``/``D`` to ``exp_dir``.

    Frozen auxiliary networks: ``Gr`` (reenactment), ``S`` (segmentation), ``L``
    (landmarks, HRNet-WFLW). ``reenactment_model``, ``seg_model`` and ``lms_model``
    are required checkpoint paths; ``None`` raises ``RuntimeError``.

    Per-resolution list arguments (``epochs``, ``lr_gen``, ``lr_dis``, ``batch_size``,
    ``iterations``) may be scalars or single-element sequences, in which case they are
    broadcast to all resolutions.

    Raises:
        RuntimeError: if ``exp_dir`` does not exist or a required model is missing.
    """

    def proces_epoch(dataset_loader, train=True):
        """Run one training or validation epoch; return the average rec loss."""
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode (auxiliary networks are always frozen)
        Gc.train(train)
        D.train(train)
        Gr.train(False)
        S.train(False)
        L.train(False)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs, scheduler_G.get_lr()[0]))

        # For each batch in the training data
        for i, (img, target) in enumerate(pbar):
            # Prepare input (no gradients needed for the frozen auxiliary networks)
            with torch.no_grad():
                # Push every pyramid image of every view to the device
                for j in range(len(img)):
                    for p in range(len(img[j])):
                        img[j][p] = img[j][p].to(device)

                # Compute landmark context heatmaps (the landmarks model expects
                # ImageNet normalization, hence the separate context mean/std)
                context = L(img[1][0].sub(context_mean).div(context_std))
                context = landmarks_utils.filter_landmarks(context)

                # Normalize each of the pyramid images to [-1, 1]
                for j in range(len(img)):
                    for p in range(len(img[j])):
                        img[j][p].sub_(img_mean).div_(img_std)

                # Compute segmentation of the target view
                target_seg = S(img[1][0])
                if target_seg.shape[2:] != (res, res):
                    target_seg = F.interpolate(target_seg, (res, res), mode='bicubic', align_corners=False)

                # Concatenate pyramid images with context to derive the final input
                input = []
                for p in range(len(img[0]) - 1, -1, -1):
                    context = F.interpolate(context, size=img[0][p].shape[2:], mode='bicubic',
                                            align_corners=False)
                    input.insert(0, torch.cat((img[0][p], context), dim=1))

                # Reenactment (frozen Gr), resized to the current training resolution
                reenactment_img = Gr(input)
                reenactment_seg = S(reenactment_img)
                if reenactment_img.shape[2:] != (res, res):
                    reenactment_img = F.interpolate(reenactment_img, (res, res), mode='bilinear',
                                                    align_corners=False)
                    reenactment_seg = F.interpolate(reenactment_seg, (res, res), mode='bilinear',
                                                    align_corners=False)

                # Remove unnecessary pyramid levels for the current resolution index
                for j in range(len(img)):
                    img[j] = img[j][-ri - 1:]

                # Source face: punch a random hair-shaped hole into the reenacted face
                reenactment_face_mask = reenactment_seg.argmax(1) == 1
                inpainting_mask = seg_utils.random_hair_inpainting_mask_tensor(reenactment_face_mask).to(device)
                reenactment_face_mask = reenactment_face_mask * (inpainting_mask == 0)
                reenactment_img_with_hole = reenactment_img.masked_fill(~reenactment_face_mask.unsqueeze(1),
                                                                        background_value)

                # Target face: the ground truth is the target face region only
                target_face_mask = (target_seg.argmax(1) == 1).unsqueeze(1)
                inpainting_target = img[1][0]
                inpainting_target.masked_fill_(~target_face_mask, background_value)

                # Inpainting input: holed image concatenated with the target face mask
                inpainting_input = torch.cat((reenactment_img_with_hole, target_face_mask.float()), dim=1)
                inpainting_input_pyd = img_utils.create_pyramid(inpainting_input, len(img[0]))

            # Face inpainting (gradients flow through Gc from here on)
            inpainting_pred = Gc(inpainting_input_pyd)

            # Fake Detection and Loss (detach so D's loss does not update Gc)
            inpainting_pred_pyd = img_utils.create_pyramid(inpainting_pred, len(img[0]))
            pred_fake_pool = D([x.detach() for x in inpainting_pred_pyd])
            loss_D_fake = criterion_gan(pred_fake_pool, False)

            # Real Detection and Loss
            inpainting_target_pyd = img_utils.create_pyramid(inpainting_target, len(img[0]))
            pred_real = D(inpainting_target_pyd)
            loss_D_real = criterion_gan(pred_real, True)

            loss_D_total = (loss_D_fake + loss_D_real) * 0.5

            # GAN loss (Fake Passability Loss)
            pred_fake = D(inpainting_pred_pyd)
            loss_G_GAN = criterion_gan(pred_fake, True)

            # Reconstruction losses: pixelwise + perceptual identity + attributes
            loss_pixelwise = criterion_pixelwise(inpainting_pred, inpainting_target)
            loss_id = criterion_id(inpainting_pred, inpainting_target)
            loss_attr = criterion_attr(inpainting_pred, inpainting_target)
            loss_rec = pix_weight * loss_pixelwise + 0.5 * loss_id + 0.5 * loss_attr

            loss_G_total = rec_weight * loss_rec + gan_weight * loss_G_GAN

            if train:
                # Update generator weights
                optimizer_G.zero_grad()
                loss_G_total.backward()
                optimizer_G.step()

                # Update discriminator weights
                optimizer_D.zero_grad()
                loss_D_total.backward()
                optimizer_D.step()

            logger.update('losses', pixelwise=loss_pixelwise, id=loss_id, attr=loss_attr, rec=loss_rec,
                          g_gan=loss_G_GAN, d_gan=loss_D_total)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg('%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images (last batch of the validation epoch)
            grid = img_utils.make_grid(img[0][0], reenactment_img, reenactment_img_with_hole,
                                       inpainting_pred, inpainting_target)
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['rec'].avg

    #################
    # Main pipeline #
    #################

    # Validation: normalize scalar arguments into per-resolution lists
    resolutions = resolutions if isinstance(resolutions, (list, tuple)) else [resolutions]
    lr_gen = lr_gen if isinstance(lr_gen, (list, tuple)) else [lr_gen]
    lr_dis = lr_dis if isinstance(lr_dis, (list, tuple)) else [lr_dis]
    epochs = epochs if isinstance(epochs, (list, tuple)) else [epochs]
    batch_size = batch_size if isinstance(batch_size, (list, tuple)) else [batch_size]
    iterations = iterations if iterations is None or isinstance(iterations, (list, tuple)) else [iterations]

    lr_gen = lr_gen * len(resolutions) if len(lr_gen) == 1 else lr_gen
    lr_dis = lr_dis * len(resolutions) if len(lr_dis) == 1 else lr_dis
    epochs = epochs * len(resolutions) if len(epochs) == 1 else epochs
    batch_size = batch_size * len(resolutions) if len(batch_size) == 1 else batch_size
    if iterations is not None:
        iterations = iterations * len(resolutions) if len(iterations) == 1 else iterations
        iterations = utils.str2int(iterations)

    if not os.path.isdir(exp_dir):
        raise RuntimeError('Experiment directory was not found: \'' + exp_dir + '\'')
    assert len(lr_gen) == len(resolutions)
    assert len(lr_dis) == len(resolutions)
    assert len(epochs) == len(resolutions)
    assert len(batch_size) == len(resolutions)
    assert iterations is None or len(iterations) == len(resolutions)

    # Seed
    utils.set_seed(seed)

    # Check CUDA device availability
    device, gpus = utils.set_device(gpus)

    # Initialize loggers
    logger = TensorBoardLogger(log_dir=exp_dir)

    # Initialize datasets
    numpy_transforms = obj_factory(numpy_transforms) if numpy_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_landmarks_transforms.Compose(numpy_transforms + tensor_transforms)
    train_dataset = obj_factory(train_dataset, transform=img_transforms)
    if val_dataset is not None:
        val_dataset = obj_factory(val_dataset, transform=img_transforms)

    # Create networks
    Gc = obj_factory(generator).to(device)
    D = obj_factory(discriminator).to(device)

    # Resume from a checkpoint or initialize the networks weights randomly
    checkpoint_dir = exp_dir if resume_dir is None else resume_dir
    Gc_path = os.path.join(checkpoint_dir, 'Gc_latest.pth')
    D_path = os.path.join(checkpoint_dir, 'D_latest.pth')
    best_loss = 1000000.
    curr_res = resolutions[0]
    optimizer_G_state, optimizer_D_state = None, None
    if os.path.isfile(Gc_path) and os.path.isfile(D_path):
        print("=> loading checkpoint from '{}'".format(checkpoint_dir))
        # Gc
        checkpoint = torch.load(Gc_path)
        if 'resolution' in checkpoint:
            curr_res = checkpoint['resolution']
            start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
        else:
            # Legacy checkpoint without resolution info: assume the next resolution
            curr_res = resolutions[1] if len(resolutions) > 1 else resolutions[0]
        best_loss = checkpoint['best_loss']
        Gc.apply(utils.init_weights)
        Gc.load_state_dict(checkpoint['state_dict'], strict=False)
        optimizer_G_state = checkpoint['optimizer']
        # D
        D.apply(utils.init_weights)
        if os.path.isfile(D_path):
            checkpoint = torch.load(D_path)
            D.load_state_dict(checkpoint['state_dict'], strict=False)
            optimizer_D_state = checkpoint['optimizer']
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_dir))
        if not pretrained:
            print("=> randomly initializing networks...")
            Gc.apply(utils.init_weights)
            D.apply(utils.init_weights)

    # Load reenactment model
    # BUG FIX: validate before calling os.path.basename, otherwise a missing model
    # raised TypeError instead of the intended RuntimeError
    if reenactment_model is None:
        raise RuntimeError('Reenactment model must be specified!')
    if not os.path.exists(reenactment_model):
        raise RuntimeError('Couldn\'t find reenactment model in path: ' + reenactment_model)
    print('=> Loading face reenactment model: "' + os.path.basename(reenactment_model) + '"...')
    checkpoint = torch.load(reenactment_model)
    Gr = obj_factory(checkpoint['arch']).to(device)
    Gr.load_state_dict(checkpoint['state_dict'])

    # Load segmentation model (same validation-before-print fix as above)
    if seg_model is None:
        raise RuntimeError('Segmentation model must be specified!')
    if not os.path.exists(seg_model):
        raise RuntimeError('Couldn\'t find segmentation model in path: ' + seg_model)
    print('=> Loading face segmentation model: "' + os.path.basename(seg_model) + '"...')
    checkpoint = torch.load(seg_model)
    S = obj_factory(checkpoint['arch']).to(device)
    S.load_state_dict(checkpoint['state_dict'])

    # Load face landmarks model
    if lms_model is None:
        raise RuntimeError('Landmarks model must be specified!')
    print('=> Loading face landmarks model: "' + os.path.basename(lms_model) + '"...')
    assert os.path.isfile(lms_model), 'The model path "%s" does not exist' % lms_model
    L = hrnet_wlfw().to(device)
    state_dict = torch.load(lms_model)
    L.load_state_dict(state_dict)

    # Initialize normalization tensors
    # Note: this is necessary because the landmarks model expects ImageNet statistics
    img_mean = torch.as_tensor([0.5, 0.5, 0.5], device=device).view(1, 3, 1, 1)
    img_std = torch.as_tensor([0.5, 0.5, 0.5], device=device).view(1, 3, 1, 1)
    context_mean = torch.as_tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    context_std = torch.as_tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    # Losses
    criterion_pixelwise = obj_factory(criterion_pixelwise).to(device)
    criterion_id = obj_factory(criterion_id).to(device)
    criterion_attr = obj_factory(criterion_attr).to(device)
    criterion_gan = obj_factory(criterion_gan).to(device)

    # Support multiple GPUs
    if gpus and len(gpus) > 1:
        Gc = nn.DataParallel(Gc, gpus)
        Gr = nn.DataParallel(Gr, gpus)
        D = nn.DataParallel(D, gpus)
        S = nn.DataParallel(S, gpus)
        L = nn.DataParallel(L, gpus)
        criterion_id.vgg = nn.DataParallel(criterion_id.vgg, gpus)
        criterion_attr.vgg = nn.DataParallel(criterion_attr.vgg, gpus)

    # For each resolution
    start_res_ind = int(np.log2(curr_res)) - int(np.log2(resolutions[0]))
    start_epoch = 0 if start_epoch is None else start_epoch
    for ri in range(start_res_ind, len(resolutions)):
        res = resolutions[ri]
        res_lr_gen = lr_gen[ri]
        res_lr_dis = lr_dis[ri]
        res_epochs = epochs[ri]
        res_iterations = iterations[ri] if iterations is not None else None
        res_batch_size = batch_size[ri]

        # Optimizer and scheduler (fresh instances per resolution)
        optimizer_G = obj_factory(optimizer, Gc.parameters(), lr=res_lr_gen)
        optimizer_D = obj_factory(optimizer, D.parameters(), lr=res_lr_dis)
        scheduler_G = obj_factory(scheduler, optimizer_G)
        scheduler_D = obj_factory(scheduler, optimizer_D)
        if optimizer_G_state is not None:
            optimizer_G.load_state_dict(optimizer_G_state)
            optimizer_G_state = None
        if optimizer_D_state is not None:
            optimizer_D.load_state_dict(optimizer_D_state)
            optimizer_D_state = None

        # Initialize data loaders
        if res_iterations is None:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, len(train_dataset))
        else:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, res_iterations)
        train_loader = tutils.data.DataLoader(train_dataset, batch_size=res_batch_size, sampler=train_sampler,
                                              num_workers=workers, pin_memory=True, drop_last=True,
                                              shuffle=False)
        if val_dataset is not None:
            if res_iterations is None:
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, len(val_dataset))
            else:
                val_iterations = (res_iterations * len(val_dataset.classes)) // len(train_dataset.classes)
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, val_iterations)
            val_loader = tutils.data.DataLoader(val_dataset, batch_size=res_batch_size, sampler=val_sampler,
                                                num_workers=workers, pin_memory=True, drop_last=True,
                                                shuffle=False)
        else:
            val_loader = None

        # For each epoch
        for epoch in range(start_epoch, res_epochs):
            total_loss = proces_epoch(train_loader, train=True)
            if val_loader is not None:
                with torch.no_grad():
                    total_loss = proces_epoch(val_loader, train=False)

            # Schedulers step (in PyTorch 1.1.0+ it must follow the epoch training and validation steps).
            # BUG FIX: the original tested `isinstance(scheduler, ...)` where `scheduler` is the factory
            # string, so the ReduceLROnPlateau branch could never run; test the instantiated scheduler.
            if isinstance(scheduler_G, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler_G.step(total_loss)
                scheduler_D.step(total_loss)
            else:
                scheduler_G.step()
                scheduler_D.step()

            # Save models checkpoints
            is_best = total_loss < best_loss
            best_loss = min(best_loss, total_loss)
            utils.save_checkpoint(exp_dir, 'Gc', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': Gc.module.state_dict() if gpus and len(gpus) > 1 else Gc.state_dict(),
                'optimizer': optimizer_G.state_dict(),
                'best_loss': best_loss,
            }, is_best)
            utils.save_checkpoint(exp_dir, 'D', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': D.module.state_dict() if gpus and len(gpus) > 1 else D.state_dict(),
                'optimizer': optimizer_D.state_dict(),
                'best_loss': best_loss,
            }, is_best)

        # Reset start epoch to 0 because it should only affect the first training resolution
        start_epoch = 0
def main(source_path, target_path,
         arch='res_unet_split.MultiScaleResUNet(in_nc=71,out_nc=(3,3),flat_layers=(2,0,2,3),ngf=128)',
         reenactment_model_path='../weights/ijbc_msrunet_256_2_0_reenactment_v1.pth',
         seg_model_path='../weights/lfw_figaro_unet_256_2_0_segmentation_v1.pth',
         inpainting_model_path='../weights/ijbc_msrunet_256_2_0_inpainting_v1.pth',
         blend_model_path='../weights/ijbc_msrunet_256_2_0_blending_v1.pth',
         pose_model_path='../weights/hopenet_robust_alpha1.pth',
         pil_transforms1=('landmark_transforms.FaceAlignCrop', 'landmark_transforms.Resize(256)',
                          'landmark_transforms.Pyramids(2)'),
         pil_transforms2=('landmark_transforms.FaceAlignCrop', 'landmark_transforms.Resize(256)',
                          'landmark_transforms.Pyramids(2)', 'landmark_transforms.LandmarksToHeatmaps'),
         tensor_transforms1=('landmark_transforms.ToTensor()',
                             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         tensor_transforms2=('landmark_transforms.ToTensor()',
                             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         output_path=None, min_radius=2.0, crop_size=256, reverse_output=False, verbose=0, output_crop=False,
         display=False):
    """Swap every source face onto every target face between two image directories.

    Pipeline per (source, target) pair: reenactment (``Gr``) -> segmentation (``Gs``)
    -> inpainting (``Gi``) -> blending (``Gb``), with head pose from ``Gp`` (Hopenet).
    Results are written to ``output_path`` as ``<source>_<target>.jpg``; already
    existing outputs are skipped. ``verbose`` (0-3) selects increasingly detailed
    debug montages instead of just the final blend.

    Raises:
        RuntimeError: if no faces are detected in either directory or a model
            fails to load.
    """
    torch.set_grad_enabled(False)
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
    device, gpus = utils.set_device()

    # Load reenactment model
    Gr = obj_factory(arch).to(device)
    checkpoint = torch.load(reenactment_model_path)
    Gr.load_state_dict(checkpoint['state_dict'])
    Gr.train(False)

    # Load face segmentation model (either a raw .pth checkpoint or a TorchScript module)
    if seg_model_path is not None:
        print('Loading face segmentation model: "' + os.path.basename(seg_model_path) + '"...')
        if seg_model_path.endswith('.pth'):
            checkpoint = torch.load(seg_model_path)
            Gs = obj_factory(checkpoint['arch']).to(device)
            Gs.load_state_dict(checkpoint['state_dict'])
        else:
            Gs = torch.jit.load(seg_model_path, map_location=device)
        if Gs is None:
            raise RuntimeError('Failed to load face segmentation model!')
        Gs.eval()
    else:
        Gs = None

    # Load face inpainting model
    # BUG FIX: the original guarded this block on `seg_model_path` (copy-paste from the
    # segmentation block) and reported a segmentation failure; guard and report inpainting.
    if inpainting_model_path is not None:
        print('Loading face inpainting model: "' + os.path.basename(inpainting_model_path) + '"...')
        if inpainting_model_path.endswith('.pth'):
            checkpoint = torch.load(inpainting_model_path)
            Gi = obj_factory(checkpoint['arch']).to(device)
            Gi.load_state_dict(checkpoint['state_dict'])
        else:
            Gi = torch.jit.load(inpainting_model_path, map_location=device)
        if Gi is None:
            raise RuntimeError('Failed to load face inpainting model!')
        Gi.eval()
    else:
        Gi = None

    # Load blending model
    checkpoint = torch.load(blend_model_path)
    Gb = obj_factory(checkpoint['arch']).to(device)
    Gb.load_state_dict(checkpoint['state_dict'])
    Gb.train(False)

    # Load pose model
    Gp = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    Gp.load_state_dict(checkpoint['state_dict'])
    Gp.train(False)

    # Initialize transforms
    pil_transforms1 = obj_factory(pil_transforms1) if pil_transforms1 is not None else []
    pil_transforms2 = obj_factory(pil_transforms2) if pil_transforms2 is not None else []
    tensor_transforms1 = obj_factory(tensor_transforms1) if tensor_transforms1 is not None else []
    tensor_transforms2 = obj_factory(tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms1 = landmark_transforms.ComposePyramids(pil_transforms1 + tensor_transforms1)
    img_transforms2 = landmark_transforms.ComposePyramids(pil_transforms2 + tensor_transforms2)

    # Extract landmarks, bounding boxes and head poses for both directories
    source_frame_indices, source_landmarks, source_bboxes, source_eulers = \
        extract_landmarks_bboxes_euler_from_images(source_path, Gp, fa, device=device)
    if source_frame_indices.size == 0:
        raise RuntimeError('No faces were detected in the source image directory: ' + source_path)
    target_frame_indices, target_landmarks, target_bboxes, target_eulers = \
        extract_landmarks_bboxes_euler_from_images(target_path, Gp, fa, device=device)
    if target_frame_indices.size == 0:
        raise RuntimeError('No faces were detected in the target image directory: ' + target_path)

    source_img_paths = glob(os.path.join(source_path, '*.jpg'))
    target_img_paths = glob(os.path.join(target_path, '*.jpg'))

    # For each source image with a detected face
    source_valid_frame_ind = 0
    for k, source_img_path in tqdm(enumerate(source_img_paths), unit='images', total=len(source_img_paths)):
        if k not in source_frame_indices:
            continue
        source_img_bgr = cv2.imread(source_img_path)
        if source_img_bgr is None:
            continue
        source_img_rgb = source_img_bgr[:, :, ::-1]
        curr_source_tensor, curr_source_landmarks, curr_source_bbox = img_transforms1(
            source_img_rgb, source_landmarks[source_valid_frame_ind], source_bboxes[source_valid_frame_ind])
        source_valid_frame_ind += 1
        for j in range(len(curr_source_tensor)):
            curr_source_tensor[j] = curr_source_tensor[j].to(device)

        # For each target image with a detected face
        target_valid_frame_ind = 0
        for i, target_img_path in enumerate(target_img_paths):
            curr_output_name = '_'.join([os.path.splitext(os.path.basename(source_img_path))[0],
                                         os.path.splitext(os.path.basename(target_img_path))[0]]) + '.jpg'
            curr_output_path = os.path.join(output_path, curr_output_name)
            if os.path.isfile(curr_output_path):
                # Output already exists: still advance the valid-frame cursor
                target_valid_frame_ind += 1
                continue
            target_img_bgr = cv2.imread(target_img_path)
            if target_img_bgr is None:
                continue
            if i not in target_frame_indices:
                continue
            target_img_rgb = target_img_bgr[:, :, ::-1]
            curr_target_tensor, curr_target_landmarks, curr_target_bbox = img_transforms2(
                target_img_rgb, target_landmarks[target_valid_frame_ind],
                target_bboxes[target_valid_frame_ind])
            curr_target_euler = target_eulers[target_valid_frame_ind]
            target_valid_frame_ind += 1

            # Reenactment: source appearance driven by target landmark heatmaps
            reenactment_input_tensor = []
            for j in range(len(curr_source_tensor)):
                curr_target_landmarks[j] = curr_target_landmarks[j].to(device)
                reenactment_input_tensor.append(
                    torch.cat((curr_source_tensor[j], curr_target_landmarks[j]), dim=0).unsqueeze(0))
            reenactment_img_tensor, reenactment_seg_tensor = Gr(reenactment_input_tensor)

            # Segmentation: face masks for both the target and the reenacted image
            target_img_tensor = curr_target_tensor[0].unsqueeze(0).to(device)
            target_seg_pred_tensor = Gs(target_img_tensor)
            target_mask_tensor = target_seg_pred_tensor.argmax(1) == 1

            aligned_face_mask_tensor = reenactment_seg_tensor.argmax(1) == 1
            aligned_background_mask_tensor = ~aligned_face_mask_tensor
            aligned_img_no_background_tensor = reenactment_img_tensor.clone()
            aligned_img_no_background_tensor.masked_fill_(aligned_background_mask_tensor.unsqueeze(1), -1.0)

            # Face inpainting: complete the reenacted face inside the target face mask
            inpainting_input_tensor = torch.cat(
                (aligned_img_no_background_tensor, target_mask_tensor.unsqueeze(1).float()), dim=1)
            inpainting_input_tensor_pyd = create_pyramid(inpainting_input_tensor, len(curr_target_tensor))
            completion_tensor = Gi(inpainting_input_tensor_pyd)

            # Blending: paste the completed face onto the target and blend seams
            transfer_tensor = transfer_mask(completion_tensor, target_img_tensor, target_mask_tensor)
            blend_input_tensor = torch.cat(
                (transfer_tensor, target_img_tensor, target_mask_tensor.unsqueeze(1).float()), dim=1)
            blend_input_tensor_pyd = create_pyramid(blend_input_tensor, len(curr_target_tensor))
            blend_tensor = Gb(blend_input_tensor_pyd)
            blend_img = tensor2bgr(blend_tensor)

            # Render output according to the requested verbosity level
            if verbose == 0:
                render_img = blend_img if output_crop else crop2img(target_img_bgr, blend_img,
                                                                    curr_target_bbox[0].numpy())
            elif verbose == 1:
                reenactment_only_tensor = transfer_mask(reenactment_img_tensor, target_img_tensor,
                                                        aligned_face_mask_tensor & target_mask_tensor)
                reenactment_only_img = tensor2bgr(reenactment_only_tensor)
                completion_only_img = tensor2bgr(transfer_tensor)
                transfer_tensor = transfer_mask(aligned_img_no_background_tensor, target_img_tensor,
                                                target_mask_tensor)
                blend_input_tensor = torch.cat(
                    (transfer_tensor, target_img_tensor, target_mask_tensor.unsqueeze(1).float()), dim=1)
                blend_input_tensor_pyd = create_pyramid(blend_input_tensor, len(curr_target_tensor))
                blend_tensor = Gb(blend_input_tensor_pyd)
                blend_only_img = tensor2bgr(blend_tensor)
                render_img = np.concatenate(
                    (reenactment_only_img, completion_only_img, blend_only_img, blend_img), axis=1)
            elif verbose == 2:
                reenactment_img_bgr = tensor2bgr(reenactment_img_tensor)
                reenactment_seg_bgr = tensor2bgr(blend_seg_pred(reenactment_img_tensor, reenactment_seg_tensor))
                target_seg_bgr = tensor2bgr(blend_seg_pred(target_img_tensor, target_seg_pred_tensor))
                aligned_img_no_background_bgr = tensor2bgr(aligned_img_no_background_tensor)
                completion_bgr = tensor2bgr(completion_tensor)
                transfer_bgr = tensor2bgr(transfer_tensor)
                target_cropped_bgr = tensor2bgr(target_img_tensor)
                pose_axis_bgr = draw_axis(np.zeros_like(target_cropped_bgr), curr_target_euler[0],
                                          curr_target_euler[1], curr_target_euler[2])
                render_img1 = np.concatenate((reenactment_img_bgr, reenactment_seg_bgr, target_seg_bgr), axis=1)
                render_img2 = np.concatenate((aligned_img_no_background_bgr, completion_bgr, transfer_bgr),
                                             axis=1)
                render_img3 = np.concatenate((pose_axis_bgr, blend_img, target_cropped_bgr), axis=1)
                render_img = np.concatenate((render_img1, render_img2, render_img3), axis=0)
            elif verbose == 3:
                source_cropped_bgr = tensor2bgr(curr_source_tensor[0].unsqueeze(0))
                target_cropped_bgr = tensor2bgr(target_img_tensor)
                render_img = np.concatenate((source_cropped_bgr, target_cropped_bgr, blend_img), axis=1)

            cv2.imwrite(curr_output_path, render_img)
            if display:
                cv2.imshow('render_img', render_img)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
def __init__(
        self,
        # General arguments:
        resolution=d('resolution'), crop_scale=d('crop_scale'), gpus=d('gpus'),
        cpu_only=d('cpu_only'), display=d('display'), verbose=d('verbose'),
        # Detection arguments:
        detection_model=d('detection_model'), det_batch_size=d('det_batch_size'), det_postfix=d('det_postfix'),
        # Sequence arguments:
        iou_thresh=d('iou_thresh'), min_length=d('min_length'), min_size=d('min_size'),
        center_kernel=d('center_kernel'), size_kernel=d('size_kernel'), smooth_det=d('smooth_det'),
        seq_postfix=d('seq_postfix'), write_empty=d('write_empty'),
        # Pose arguments:
        pose_model=d('pose_model'), pose_batch_size=d('pose_batch_size'), pose_postfix=d('pose_postfix'),
        cache_pose=d('cache_pose'), cache_frontal=d('cache_frontal'), smooth_poses=d('smooth_poses'),
        # Landmarks arguments:
        lms_model=d('lms_model'), lms_batch_size=d('lms_batch_size'), landmarks_postfix=d('landmarks_postfix'),
        cache_landmarks=d('cache_landmarks'), smooth_landmarks=d('smooth_landmarks'),
        # Segmentation arguments:
        seg_model=d('seg_model'), seg_batch_size=d('seg_batch_size'),
        segmentation_postfix=d('segmentation_postfix'), cache_segmentation=d('cache_segmentation'),
        smooth_segmentation=d('smooth_segmentation'), seg_remove_mouth=d('seg_remove_mouth')):
    """Initialize the video preprocessing pipeline.

    Stores the configuration, creates the face detector, selects the compute
    device, and optionally loads the pose / landmarks / segmentation models
    (each only when its corresponding ``cache_*`` flag is set).

    All defaults come from ``d(...)``, presumably a lookup into a shared default
    configuration defined elsewhere in this file — confirm against its definition.

    NOTE(review): the initialization order matters — ``self.device`` must be set
    (via ``set_device``) before the ``load_model`` calls and the normalization
    tensors below; do not reorder.
    """
    # General
    self.resolution = resolution
    self.crop_scale = crop_scale
    self.display = display
    self.verbose = verbose

    # Detection (the detector owns its own batch size, model and GPU selection)
    self.face_detector = FaceDetector(det_postfix, detection_model, gpus, det_batch_size, display)
    self.det_postfix = det_postfix

    # Sequences
    self.iou_thresh = iou_thresh
    self.min_length = min_length
    self.min_size = min_size
    self.center_kernel = center_kernel
    self.size_kernel = size_kernel
    self.smooth_det = smooth_det
    self.seq_postfix = seq_postfix
    self.write_empty = write_empty

    # Pose
    self.pose_batch_size = pose_batch_size
    self.pose_postfix = pose_postfix
    self.cache_pose = cache_pose
    self.cache_frontal = cache_frontal
    self.smooth_poses = smooth_poses

    # Landmarks
    self.smooth_landmarks = smooth_landmarks
    self.landmarks_postfix = landmarks_postfix
    self.cache_landmarks = cache_landmarks
    self.lms_batch_size = lms_batch_size

    # Segmentation
    self.smooth_segmentation = smooth_segmentation
    self.segmentation_postfix = segmentation_postfix
    self.cache_segmentation = cache_segmentation
    self.seg_batch_size = seg_batch_size
    # Mouth removal requires landmarks, so it is only enabled together with landmark caching
    self.seg_remove_mouth = seg_remove_mouth and cache_landmarks

    # Initialize device (gradients are never needed in this pipeline)
    torch.set_grad_enabled(False)
    self.device, self.gpus = set_device(gpus, not cpu_only)

    # Load models (each is optional, gated by its cache flag)
    self.face_pose = load_model(pose_model, 'face pose', self.device) if cache_pose else None
    self.L = load_model(lms_model, 'face landmarks', self.device) if cache_landmarks else None
    self.S = load_model(seg_model, 'face segmentation', self.device) if cache_segmentation else None

    # Initialize heatmap encoder
    self.heatmap_encoder = LandmarksHeatMapEncoder().to(self.device)

    # Initialize normalization tensors
    # Note: this is necessary because of the landmarks model
    self.img_mean = torch.as_tensor([0.5, 0.5, 0.5], device=self.device).view(1, 3, 1, 1)
    self.img_std = torch.as_tensor([0.5, 0.5, 0.5], device=self.device).view(1, 3, 1, 1)
    self.context_mean = torch.as_tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
    self.context_std = torch.as_tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)

    # Support multiple GPUs (wrapping must happen after the models are loaded)
    if self.gpus and len(self.gpus) > 1:
        self.face_pose = nn.DataParallel(self.face_pose, self.gpus) if self.face_pose is not None else None
        self.L = nn.DataParallel(self.L, self.gpus) if self.L is not None else None
        self.S = nn.DataParallel(self.S, self.gpus) if self.S is not None else None

    # Initialize temportal smoothing
    if smooth_segmentation > 0:
        self.smooth_seg = TemporalSmoothing(3, smooth_segmentation).to(self.device)
    else:
        self.smooth_seg = None

    # Initialize output videos format
    self.fourcc = cv2.VideoWriter_fourcc(*'avc1')
def main(
        # General arguments
        exp_dir, resume_dir=None, start_epoch=None, epochs=(90,), iterations=None, resolutions=(128, 256),
        learning_rate=(1e-1,), gpus=None, workers=4, batch_size=(64,), seed=None, log_freq=20,
        # Data arguments
        train_dataset='fsgan.image_seg_dataset.ImageSegDataset', val_dataset=None,
        numpy_transforms=None,
        tensor_transforms=('img_landmarks_transforms.ToTensor()',
                           'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        # Training arguments
        optimizer='optim.SGD(momentum=0.9,weight_decay=1e-4)',
        scheduler='lr_scheduler.StepLR(step_size=30,gamma=0.1)', criterion='nn.CrossEntropyLoss',
        model='fsgan.models.simple_unet.UNet(n_classes=3,feature_scale=1)', pretrained=False,
        benchmark='fsgan.train_segmentation.IOUBenchmark(3)'):
    """Train the face segmentation model over a sequence of resolutions.

    For each resolution it instantiates a fresh optimizer and scheduler from the
    ``optimizer``/``scheduler`` factory strings, trains for the configured epochs,
    optionally validates, and checkpoints the model (with its ``arch`` string and a
    per-resolution best loss) to ``exp_dir``.

    Per-resolution list arguments (``epochs``, ``learning_rate``, ``batch_size``,
    ``iterations``) may be scalars or single-element sequences, which are broadcast
    to all resolutions.

    Raises:
        RuntimeError: if ``exp_dir`` does not exist.
    """

    def proces_epoch(dataset_loader, train=True):
        """Run one training or validation epoch; return the average total loss."""
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode
        model.train(train)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs, scheduler.get_lr()[0]))

        # For each batch in the training data
        for i, (input, target) in enumerate(pbar):
            # Prepare input (targets arrive one-hot; convert to class indices)
            input = input.to(device)
            target = target.to(device)
            with torch.no_grad():
                target = target.argmax(dim=1)

            # Execute model
            pred = model(input)

            # Calculate loss
            loss_total = criterion(pred, target)

            # Run benchmark
            benchmark_res = benchmark(pred, target) if benchmark is not None else {}

            if train:
                # Update model weights
                optimizer.zero_grad()
                loss_total.backward()
                optimizer.step()

            logger.update('losses', total=loss_total)
            logger.update('bench', **benchmark_res)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg('%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images (last batch of the validation epoch)
            seg_pred = blend_seg_pred(input, pred)
            seg_gt = blend_seg_label(input, target)
            grid = img_utils.make_grid(input, seg_pred, seg_gt)
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['total'].avg

    #################
    # Main pipeline #
    #################

    # Validation: normalize scalar arguments into per-resolution lists
    resolutions = resolutions if isinstance(resolutions, (list, tuple)) else [resolutions]
    learning_rate = learning_rate if isinstance(learning_rate, (list, tuple)) else [learning_rate]
    epochs = epochs if isinstance(epochs, (list, tuple)) else [epochs]
    batch_size = batch_size if isinstance(batch_size, (list, tuple)) else [batch_size]
    iterations = iterations if iterations is None or isinstance(iterations, (list, tuple)) else [iterations]

    learning_rate = learning_rate * len(resolutions) if len(learning_rate) == 1 else learning_rate
    epochs = epochs * len(resolutions) if len(epochs) == 1 else epochs
    batch_size = batch_size * len(resolutions) if len(batch_size) == 1 else batch_size
    if iterations is not None:
        iterations = iterations * len(resolutions) if len(iterations) == 1 else iterations
        iterations = utils.str2int(iterations)

    if not os.path.isdir(exp_dir):
        raise RuntimeError('Experiment directory was not found: \'' + exp_dir + '\'')
    assert len(learning_rate) == len(resolutions)
    assert len(epochs) == len(resolutions)
    assert len(batch_size) == len(resolutions)
    assert iterations is None or len(iterations) == len(resolutions)

    # BUG FIX: `optimizer` and `scheduler` are rebound to instantiated objects inside
    # the resolution loop below; without saving the factory-spec strings first, the
    # second resolution would pass an optimizer *object* back into obj_factory.
    optimizer_spec, scheduler_spec = optimizer, scheduler

    # Seed
    utils.set_seed(seed)

    # Check CUDA device availability
    device, gpus = utils.set_device(gpus)

    # Initialize loggers
    logger = TensorBoardLogger(log_dir=exp_dir)

    # Initialize datasets
    numpy_transforms = obj_factory(numpy_transforms) if numpy_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_landmarks_transforms.Compose(numpy_transforms + tensor_transforms)
    train_dataset = obj_factory(train_dataset, transform=img_transforms)
    if val_dataset is not None:
        val_dataset = obj_factory(val_dataset, transform=img_transforms)

    # Create networks (keep the arch string for the checkpoint before rebinding `model`)
    arch = utils.get_arch(model, num_classes=len(train_dataset.classes))
    model = obj_factory(model, num_classes=len(train_dataset.classes)).to(device)

    # Resume from a checkpoint or initialize the networks weights randomly
    checkpoint_dir = exp_dir if resume_dir is None else resume_dir
    model_path = os.path.join(checkpoint_dir, 'model_latest.pth')
    best_loss = 1e6
    curr_res = resolutions[0]
    optimizer_state = None
    if os.path.isfile(model_path):
        print("=> loading checkpoint from '{}'".format(checkpoint_dir))
        checkpoint = torch.load(model_path)
        if 'resolution' in checkpoint:
            curr_res = checkpoint['resolution']
            start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
        # The best loss is tracked per resolution
        best_loss_key = 'best_loss_%d' % curr_res
        best_loss = checkpoint[best_loss_key] if best_loss_key in checkpoint else best_loss
        model.apply(utils.init_weights)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        optimizer_state = checkpoint['optimizer']
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_dir))
        if not pretrained:
            print("=> randomly initializing networks...")
            model.apply(utils.init_weights)

    # Losses
    criterion = obj_factory(criterion).to(device)

    # Benchmark
    benchmark = obj_factory(benchmark).to(device)

    # Support multiple GPUs
    if gpus and len(gpus) > 1:
        model = nn.DataParallel(model, gpus)

    # For each resolution
    start_res_ind = int(np.log2(curr_res)) - int(np.log2(resolutions[0]))
    start_epoch = 0 if start_epoch is None else start_epoch
    for ri in range(start_res_ind, len(resolutions)):
        res = resolutions[ri]
        res_lr = learning_rate[ri]
        res_epochs = epochs[ri]
        res_iterations = iterations[ri] if iterations is not None else None
        res_batch_size = batch_size[ri]

        # Optimizer and scheduler (always built from the original factory specs)
        optimizer = obj_factory(optimizer_spec, model.parameters(), lr=res_lr)
        scheduler = obj_factory(scheduler_spec, optimizer)
        if optimizer_state is not None:
            optimizer.load_state_dict(optimizer_state)
            optimizer_state = None

        # Initialize data loaders
        if res_iterations is None:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, len(train_dataset))
        else:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, res_iterations)
        train_loader = tutils.data.DataLoader(train_dataset, batch_size=res_batch_size, sampler=train_sampler,
                                              num_workers=workers, pin_memory=True, drop_last=True,
                                              shuffle=False)
        if val_dataset is not None:
            if res_iterations is None:
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, len(val_dataset))
            else:
                val_iterations = (res_iterations * len(val_dataset)) // len(train_dataset)
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, val_iterations)
            val_loader = tutils.data.DataLoader(val_dataset, batch_size=res_batch_size, sampler=val_sampler,
                                                num_workers=workers, pin_memory=True, drop_last=True,
                                                shuffle=False)
        else:
            val_loader = None

        # For each epoch
        for epoch in range(start_epoch, res_epochs):
            total_loss = proces_epoch(train_loader, train=True)
            if val_loader is not None:
                with torch.no_grad():
                    total_loss = proces_epoch(val_loader, train=False)
            # Reset the benchmark's accumulated state between epochs, if it has any
            if hasattr(benchmark, 'reset'):
                benchmark.reset()

            # Schedulers step (in PyTorch 1.1.0+ it must follow the epoch training and validation steps)
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(total_loss)
            else:
                scheduler.step()

            # Save models checkpoints
            is_best = total_loss < best_loss
            best_loss = min(best_loss, total_loss)
            utils.save_checkpoint(exp_dir, 'model', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict() if gpus and len(gpus) > 1 else model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_loss_%d' % res: best_loss,
                'arch': arch,
            }, is_best)

        # Reset start epoch and best loss; they should only affect the first training resolution
        start_epoch = 0
        best_loss = 1e6
def main(input_path, output_path=None, seq_postfix='_dsfd_seq.pkl', output_postfix='_dsfd_seq_lms_euler.pkl',
         pose_model_path='weights/hopenet_robust_alpha1.pkl', smooth_det=False, smooth_euler=False, gpus=None,
         cpu_only=False, batch_size=16):
    """Append per-frame head pose (euler angles) to cached face sequences of a video.

    Reads a pickled sequence list from ``<input_path stem> + seq_postfix``, runs the Hopenet
    pose estimator on a crop around each detection, stores the resulting euler array on each
    sequence (``seq.euler``), and pickles the augmented list to ``output_path`` (or
    ``<input_path stem> + output_postfix`` when ``output_path`` is None).

    Raises:
        RuntimeError: If the video cannot be opened or a frame read fails mid-sequence.
    """
    cache_path = os.path.splitext(input_path)[0] + seq_postfix
    output_path = os.path.splitext(input_path)[0] + output_postfix if output_path is None else output_path

    # Initialize device (inference only, so gradients are globally disabled)
    torch.set_grad_enabled(False)
    device, gpus = set_device(gpus, not cpu_only)

    # Load sequences from file
    with open(cache_path, "rb") as fp:  # Unpickling
        seq_list = pickle.load(fp)

    # Load pose model
    # NOTE(review): here the checkpoint is the raw state dict; other scripts in this file
    # load checkpoint['state_dict'] — presumably the .pkl weights file differs. Verify.
    face_pose = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    face_pose.load_state_dict(checkpoint)
    face_pose.train(False)

    # Open input video file
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise RuntimeError('Failed to read video: ' + input_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    input_vid_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    input_vid_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Smooth sequence bounding boxes (in place, before cropping)
    if smooth_det:
        for seq in seq_list:
            seq.smooth()

    # For each sequence: progress bar counts individual detections across all sequences
    total_detections = sum([len(s) for s in seq_list])
    pbar = tqdm(range(total_detections), unit='detections')
    for seq in seq_list:
        euler = []
        frame_cropped_tensor_list = []
        # Seek to the sequence's first frame; detections are assumed to be consecutive frames
        cap.set(cv2.CAP_PROP_POS_FRAMES, seq.start_index)

        # For each detection bounding box in the current sequence
        for i, det in enumerate(seq.detections):
            ret, frame_bgr = cap.read()
            if frame_bgr is None:
                raise RuntimeError('Failed to read frame from video!')
            frame_rgb = frame_bgr[:, :, ::-1]

            # Crop frame: convert detection corners (x1,y1,x2,y2) to (x,y,w,h), scale by 1.2,
            # then resize the crop to Hopenet's 224x224 input
            bbox = np.concatenate((det[:2], det[2:] - det[:2]))
            bbox = scale_bbox(bbox, 1.2)
            frame_cropped_rgb = crop_img(frame_rgb, bbox)
            frame_cropped_rgb = cv2.resize(frame_cropped_rgb, (224, 224), interpolation=cv2.INTER_CUBIC)
            frame_cropped_tensor = rgb2tensor(frame_cropped_rgb).to(device)

            # Gather batches: flush when a full batch accumulated or at the sequence's last detection
            frame_cropped_tensor_list.append(frame_cropped_tensor)
            if len(frame_cropped_tensor_list) < batch_size and (i + 1) < len(seq):
                continue
            frame_cropped_tensor_batch = torch.cat(frame_cropped_tensor_list, dim=0)

            # Calculate euler angles
            curr_euler_batch = face_pose(frame_cropped_tensor_batch)  # Yaw, Pitch, Roll
            curr_euler_batch = curr_euler_batch.cpu().numpy()

            # For each prediction in the batch
            for b, curr_euler in enumerate(curr_euler_batch):
                # Add euler to list
                euler.append(curr_euler)
                # Render (debug visualization, kept disabled)
                # render_img = tensor2bgr(frame_cropped_tensor_batch[b]).copy()
                # cv2.putText(render_img, '(%.2f, %.2f, %.2f)' % (curr_euler[0], curr_euler[1], curr_euler[2]),
                #             (15, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
                # cv2.imshow('render_img', render_img)
                # if cv2.waitKey(0) & 0xFF == ord('q'):
                #     break

            # Clear lists
            frame_cropped_tensor_list.clear()
            pbar.update(len(frame_cropped_tensor_batch))

        # Add euler angles to the sequence and optionally smooth them over time
        euler = np.array(euler)
        if smooth_euler:
            euler = smooth(euler)
        seq.euler = euler

    # Write final sequence list to file
    with open(output_path, "wb") as fp:  # Pickling
        pickle.dump(seq_list, fp)
def main(source_path, target_path,
         arch='res_unet_split.MultiScaleResUNet(in_nc=71,out_nc=(3,3),flat_layers=(2,0,2,3),ngf=128)',
         model_path='../weights/ijbc_msrunet_256_2_0_reenactment_v1.pth',
         pose_model_path='../weights/hopenet_robust_alpha1.pth',
         pil_transforms1=('landmark_transforms.FaceAlignCrop', 'landmark_transforms.Resize(256)',
                          'landmark_transforms.Pyramids(2)'),
         pil_transforms2=('landmark_transforms.FaceAlignCrop', 'landmark_transforms.Resize(256)',
                          'landmark_transforms.Pyramids(2)', 'landmark_transforms.LandmarksToHeatmaps'),
         tensor_transforms1=('landmark_transforms.ToTensor()',
                             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         tensor_transforms2=('landmark_transforms.ToTensor()',
                             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         output_path=None, crop_size=256, display=False):
    """Reenact a single source face image with the motion of a target video.

    For every target frame with a detected face, the generator receives the source image
    pyramid concatenated (channel-wise, per pyramid level) with the target's landmark
    heatmaps, and produces the reenacted face. The rendered output is a horizontal strip
    [source | reenactment | target crop], written to ``output_path`` (file, or directory in
    which a name is derived from the inputs) and/or shown on screen.

    Raises:
        RuntimeError: If no face is found in the source image or target video, or the video
            cannot be opened.
    """
    torch.set_grad_enabled(False)

    # Initialize models
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
    device, gpus = utils.set_device()
    G = obj_factory(arch).to(device)
    checkpoint = torch.load(model_path)
    G.load_state_dict(checkpoint['state_dict'])
    G.train(False)

    # Initialize pose
    Gp = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    Gp.load_state_dict(checkpoint['state_dict'])
    Gp.train(False)

    # Initialize transformations (factory specs -> callables; composed over pyramid levels)
    pil_transforms1 = obj_factory(pil_transforms1) if pil_transforms1 is not None else []
    pil_transforms2 = obj_factory(pil_transforms2) if pil_transforms2 is not None else []
    tensor_transforms1 = obj_factory(tensor_transforms1) if tensor_transforms1 is not None else []
    tensor_transforms2 = obj_factory(tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms1 = landmark_transforms.ComposePyramids(pil_transforms1 + tensor_transforms1)
    img_transforms2 = landmark_transforms.ComposePyramids(pil_transforms2 + tensor_transforms2)

    # Process source image
    source_bgr = cv2.imread(source_path)
    source_rgb = source_bgr[:, :, ::-1]
    source_landmarks, source_bbox = process_image(fa, source_rgb, crop_size)
    if source_bbox is None:
        raise RuntimeError("Couldn't detect a face in source image: " + source_path)
    source_tensor, source_landmarks, source_bbox = img_transforms1(source_rgb, source_landmarks, source_bbox)
    source_cropped_bgr = tensor2bgr(source_tensor[0] if isinstance(source_tensor, list) else source_tensor)
    for i in range(len(source_tensor)):
        source_tensor[i] = source_tensor[i].to(device)

    # Extract landmarks and bounding boxes from target video (only frames with detections)
    frame_indices, landmarks, bboxes, eulers = extract_landmarks_bboxes_euler_from_video(
        target_path, Gp, device=device)
    if frame_indices.size == 0:
        raise RuntimeError('No faces were detected in the target video: ' + target_path)

    # Open target video file
    cap = cv2.VideoCapture(target_path)
    if not cap.isOpened():
        raise RuntimeError('Failed to read video: ' + target_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Initialize output video file: a directory output_path gets a derived file name;
    # frame width is 3x the crop width because three panels are concatenated per frame
    if output_path is not None:
        if os.path.isdir(output_path):
            output_filename = os.path.splitext(os.path.basename(source_path))[0] + '_' + \
                              os.path.splitext(os.path.basename(target_path))[0] + '.mp4'
            output_path = os.path.join(output_path, output_filename)
        print(output_path)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vid = cv2.VideoWriter(output_path, fourcc, fps,
                                  (source_cropped_bgr.shape[1] * 3, source_cropped_bgr.shape[0]))
    else:
        out_vid = None

    # For each frame in the target video; valid_frame_ind indexes the cached per-detection arrays
    valid_frame_ind = 0
    for i in tqdm(range(total_frames)):
        ret, frame = cap.read()
        if frame is None:
            continue
        if i not in frame_indices:
            continue
        frame_rgb = frame[:, :, ::-1]
        frame_tensor, frame_landmarks, frame_bbox = img_transforms2(frame_rgb, landmarks[valid_frame_ind],
                                                                    bboxes[valid_frame_ind])
        valid_frame_ind += 1
        # frame_cropped_rgb, frame_landmarks = process_cached_frame(frame_rgb, landmarks[valid_frame_ind],
        #                                                           bboxes[valid_frame_ind], size)
        # frame_cropped_bgr = frame_cropped_rgb[:, :, ::-1].copy()
        # valid_frame_ind += 1
        # frame_tensor, frame_landmarks_tensor = prepare_generator_input(frame_cropped_rgb, frame_landmarks)
        # frame_landmarks_tensor.to(device)

        # Build generator input: per pyramid level, source image channels + target landmark heatmaps
        input_tensor = []
        for j in range(len(source_tensor)):
            frame_landmarks[j] = frame_landmarks[j].to(device)
            input_tensor.append(
                torch.cat((source_tensor[j], frame_landmarks[j]), dim=0).unsqueeze(0).to(device))
        out_img_tensor, out_seg_tensor = G(input_tensor)

        # Transfer image1 mask to image2 (disabled experiments)
        # face_mask_tensor = out_seg_tensor.argmax(1) == 1  # face
        # face_mask_tensor = out_seg_tensor.argmax(1) == 2  # hair
        # face_mask_tensor = out_seg_tensor.argmax(1) >= 1  # head
        # target_img_tensor = frame_tensor[0].view(1, frame_tensor[0].shape[0],
        #                                          frame_tensor[0].shape[1], frame_tensor[0].shape[2]).to(device)

        # Convert back to numpy images
        out_img_bgr = tensor2bgr(out_img_tensor)
        frame_cropped_bgr = tensor2bgr(frame_tensor[0])

        # Render
        # for point in np.round(frame_landmarks).astype(int):
        #     cv2.circle(frame_cropped_bgr, (point[0], point[1]), 2, (0, 0, 255), -1)
        render_img = np.concatenate((source_cropped_bgr, out_img_bgr, frame_cropped_bgr), axis=1)
        if out_vid is not None:
            out_vid.write(render_img)
        if out_vid is None or display:
            cv2.imshow('render_img', render_img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
def main(source_path, target_path,
         arch='res_unet_split.MultiScaleResUNet(in_nc=71,out_nc=(3,3),flat_layers=(2,0,2,3),ngf=128)',
         model_path='../weights/ijbc_msrunet_256_1_2_reenactment_stepwise_v1.pth',
         pose_model_path='../weights/hopenet_robust_alpha1.pth',
         pil_transforms1=('landmark_transforms.FaceAlignCrop(bbox_scale=1.2)', 'landmark_transforms.Resize(256)',
                          'landmark_transforms.Pyramids(2)'),
         pil_transforms2=('landmark_transforms.FaceAlignCrop(bbox_scale=1.2)', 'landmark_transforms.Resize(256)',
                          'landmark_transforms.Pyramids(2)'),
         tensor_transforms1=('landmark_transforms.ToTensor()',
                             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         tensor_transforms2=('landmark_transforms.ToTensor()',
                             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         output_path=None, crop_size=256, display=False):
    """Stepwise face reenactment: reenact a source image toward each target frame in stages.

    Unlike the single-pass variant, the target 3D landmarks are reached through a short
    sequence of interpolated landmark sets; the generator is applied iteratively, feeding
    its own output back in at each step. Output is a [source | reenactment | target] strip
    per frame, written to ``output_path`` and/or displayed.

    Raises:
        RuntimeError: If no face is found in the source image or target video, or the video
            cannot be opened.
    """
    torch.set_grad_enabled(False)

    # Initialize models (3D landmarks here, with flipped-input augmentation)
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, flip_input=True)
    device, gpus = utils.set_device()
    G = obj_factory(arch).to(device)
    checkpoint = torch.load(model_path)
    G.load_state_dict(checkpoint['state_dict'])
    G.train(False)

    # Initialize pose
    Gp = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    Gp.load_state_dict(checkpoint['state_dict'])
    Gp.train(False)

    # Initialize landmarks to heatmaps: one renderer per pyramid level
    # (index 0 -> 256x256 full resolution, index 1 -> 128x128 half resolution)
    landmarks2heatmaps = [LandmarkHeatmap(kernel_size=13, size=(256, 256)).to(device),
                          LandmarkHeatmap(kernel_size=7, size=(128, 128)).to(device)]

    # Initialize transformations
    pil_transforms1 = obj_factory(pil_transforms1) if pil_transforms1 is not None else []
    pil_transforms2 = obj_factory(pil_transforms2) if pil_transforms2 is not None else []
    tensor_transforms1 = obj_factory(tensor_transforms1) if tensor_transforms1 is not None else []
    tensor_transforms2 = obj_factory(tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms1 = landmark_transforms.ComposePyramids(pil_transforms1 + tensor_transforms1)
    img_transforms2 = landmark_transforms.ComposePyramids(pil_transforms2 + tensor_transforms2)

    # Process source image
    source_bgr = cv2.imread(source_path)
    source_rgb = source_bgr[:, :, ::-1]
    source_landmarks, source_bbox = process_image(fa, source_rgb, crop_size)
    if source_bbox is None:
        raise RuntimeError("Couldn't detect a face in source image: " + source_path)
    source_tensor, source_landmarks, source_bbox = img_transforms1(source_rgb, source_landmarks, source_bbox)
    source_cropped_bgr = tensor2bgr(source_tensor[0] if isinstance(source_tensor, list) else source_tensor)
    for i in range(len(source_tensor)):
        source_tensor[i] = source_tensor[i].unsqueeze(0).to(device)

    # Extract landmarks, bounding boxes, euler angles, and 3D landmarks from target video
    frame_indices, landmarks, bboxes, eulers, landmarks_3d = \
        extract_landmarks_bboxes_euler_3d_from_video(target_path, Gp, fa, device=device)
    if frame_indices.size == 0:
        raise RuntimeError('No faces were detected in the target video: ' + target_path)

    # Open target video file
    cap = cv2.VideoCapture(target_path)
    if not cap.isOpened():
        raise RuntimeError('Failed to read target video: ' + target_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Initialize output video file (3 panels wide per output frame)
    if output_path is not None:
        if os.path.isdir(output_path):
            output_filename = os.path.splitext(os.path.basename(source_path))[0] + '_' + \
                              os.path.splitext(os.path.basename(target_path))[0] + '.mp4'
            output_path = os.path.join(output_path, output_filename)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vid = cv2.VideoWriter(output_path, fourcc, fps,
                                  (source_cropped_bgr.shape[1] * 3, source_cropped_bgr.shape[0]))
    else:
        out_vid = None

    # For each frame in the target video; valid_frame_ind indexes the cached per-detection arrays
    valid_frame_ind = 0
    for i in tqdm(range(total_frames)):
        ret, target_bgr = cap.read()
        if target_bgr is None:
            continue
        if i not in frame_indices:
            continue
        target_rgb = target_bgr[:, :, ::-1]
        target_tensor, target_landmarks, target_bbox = img_transforms2(target_rgb, landmarks_3d[valid_frame_ind],
                                                                       bboxes[valid_frame_ind])
        target_euler = eulers[valid_frame_ind]
        valid_frame_ind += 1

        # TODO: Calculate the number of required reenactment iterations
        reenactment_iterations = 2

        # Generate landmarks sequence: linear interpolations from source toward target landmarks
        # (alpha = ri / iterations), ending with the exact target landmarks
        target_landmarks_sequence = []
        for ri in range(1, reenactment_iterations):
            interp_landmarks = []
            for j in range(len(source_tensor)):
                alpha = float(ri) / reenactment_iterations
                curr_interp_landmarks_np = interpolate_points(source_landmarks[j].cpu().numpy(),
                                                              target_landmarks[j].cpu().numpy(), alpha=alpha)
                interp_landmarks.append(torch.from_numpy(curr_interp_landmarks_np))
            target_landmarks_sequence.append(interp_landmarks)
        target_landmarks_sequence.append(target_landmarks)

        # Iterative reenactment: feed the previous output back in at each landmark step.
        # NOTE(review): curr_target_landmarks entries are mutated in place (replaced by their
        # heatmaps); for the final entry this also mutates target_landmarks. Confirm intended.
        out_img_tensor = source_tensor
        for curr_target_landmarks in target_landmarks_sequence:
            out_img_tensor = create_pyramid(out_img_tensor, 2)
            input_tensor = []
            for j in range(len(out_img_tensor)):
                curr_target_landmarks[j] = curr_target_landmarks[j].unsqueeze(0).to(device)
                curr_target_landmarks[j] = landmarks2heatmaps[j](curr_target_landmarks[j])
                input_tensor.append(torch.cat((out_img_tensor[j], curr_target_landmarks[j]), dim=1))
            out_img_tensor, out_seg_tensor = G(input_tensor)

        # Convert back to numpy images
        out_img_bgr = tensor2bgr(out_img_tensor)
        frame_cropped_bgr = tensor2bgr(target_tensor[0])

        # Render
        # for point in np.round(frame_landmarks).astype(int):
        #     cv2.circle(frame_cropped_bgr, (point[0], point[1]), 2, (0, 0, 255), -1)
        render_img = np.concatenate((source_cropped_bgr, out_img_bgr, frame_cropped_bgr), axis=1)
        if out_vid is not None:
            out_vid.write(render_img)
        if out_vid is None or display:
            cv2.imshow('render_img', render_img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
def main(
        # General arguments
        exp_dir, resume_dir=None, start_epoch=None, epochs=(90,), iterations=None, resolutions=(128, 256),
        lr_gen=(1e-4,), lr_dis=(1e-4,), gpus=None, workers=4, batch_size=(64,), seed=None, log_freq=20,
        # Data arguments
        train_dataset='opencv_video_seq_dataset.VideoSeqDataset', val_dataset=None, numpy_transforms=None,
        tensor_transforms=('img_lms_pose_transforms.ToTensor()', 'img_lms_pose_transforms.Normalize()'),
        # Training arguments
        optimizer='optim.SGD(momentum=0.9,weight_decay=1e-4)',
        scheduler='lr_scheduler.StepLR(step_size=30,gamma=0.1)', pretrained=False,
        criterion_pixelwise='nn.L1Loss', criterion_id='vgg_loss.VGGLoss', criterion_attr='vgg_loss.VGGLoss',
        criterion_gan='gan_loss.GANLoss(use_lsgan=True)',
        generator='res_unet.MultiScaleResUNet(in_nc=101,out_nc=3)',
        discriminator='discriminators_pix2pix.MultiscaleDiscriminator',
        rec_weight=1.0, gan_weight=0.001):
    """Train the face reenactment generator/discriminator pair over multiple resolutions.

    Per-resolution arguments (``epochs``, ``lr_gen``, ``lr_dis``, ``batch_size``,
    ``iterations``) may be given as a single value, which is broadcast to all resolutions.
    Checkpoints ``G_latest.pth``/``D_latest.pth`` are loaded from ``resume_dir`` (or
    ``exp_dir``) if present; best models are tracked by the validation (or training)
    reconstruction loss.

    Raises:
        RuntimeError: If ``exp_dir`` does not exist.

    BUG FIX: the scheduler type check previously tested ``scheduler`` — which in this
    function remains the factory *specification string* (only ``scheduler_G``/``scheduler_D``
    are instantiated) — so the ``ReduceLROnPlateau`` branch was unreachable and such a
    scheduler would have been stepped without its metric. The check now tests the
    instantiated ``scheduler_G``.
    """
    def proces_epoch(dataset_loader, train=True):
        """Run one training or validation epoch; return the average reconstruction loss."""
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode
        G.train(train)
        D.train(train)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs, optimizer_G.param_groups[0]['lr']))

        # For each batch in the training data
        for i, (img, landmarks, target) in enumerate(pbar):
            # Prepare input
            with torch.no_grad():
                # For each view images and landmarks
                landmarks[1] = landmarks[1].to(device)
                for j in range(len(img)):
                    # landmarks[j] = landmarks[j].to(device)
                    # For each pyramid image: push to device
                    for p in range(len(img[j])):
                        img[j][p] = img[j][p].to(device)

                # Remove unnecessary pyramids: keep only the levels used at this resolution
                for j in range(len(img)):
                    img[j] = img[j][-ri - 1:]

                # Concatenate pyramid images with context (decoded landmark heatmaps)
                # to derive the final input
                input = []
                for p in range(len(img[0])):
                    context = res_landmarks_decoders[p](landmarks[1])
                    input.append(torch.cat((img[0][p], context), dim=1))

            # Reenactment
            img_pred = G(input)

            # Fake Detection and Loss (detached so D's loss does not backprop into G)
            img_pred_pyd = img_utils.create_pyramid(img_pred, len(img[0]))
            pred_fake_pool = D([x.detach() for x in img_pred_pyd])
            loss_D_fake = criterion_gan(pred_fake_pool, False)

            # Real Detection and Loss
            pred_real = D(img[1])
            loss_D_real = criterion_gan(pred_real, True)
            loss_D_total = (loss_D_fake + loss_D_real) * 0.5

            # GAN loss (Fake Passability Loss)
            pred_fake = D(img_pred_pyd)
            loss_G_GAN = criterion_gan(pred_fake, True)

            # Reconstruction loss: weighted mix of pixelwise, identity, and attribute terms
            loss_pixelwise = criterion_pixelwise(img_pred, img[1][0])
            loss_id = criterion_id(img_pred, img[1][0])
            loss_attr = criterion_attr(img_pred, img[1][0])
            loss_rec = 0.1 * loss_pixelwise + 0.5 * loss_id + 0.5 * loss_attr
            loss_G_total = rec_weight * loss_rec + gan_weight * loss_G_GAN

            if train:
                # Update generator weights
                optimizer_G.zero_grad()
                loss_G_total.backward()
                optimizer_G.step()

                # Update discriminator weights
                optimizer_D.zero_grad()
                loss_D_total.backward()
                optimizer_D.step()

            logger.update('losses', pixelwise=loss_pixelwise, id=loss_id, attr=loss_attr, rec=loss_rec,
                          g_gan=loss_G_GAN, d_gan=loss_D_total)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg('%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images
            grid = img_utils.make_grid(img[0][0], img_pred, img[1][0])
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['rec'].avg

    #################
    # Main pipeline #
    #################

    # Validation: normalize per-resolution arguments to lists and broadcast singletons
    resolutions = resolutions if isinstance(resolutions, (list, tuple)) else [resolutions]
    lr_gen = lr_gen if isinstance(lr_gen, (list, tuple)) else [lr_gen]
    lr_dis = lr_dis if isinstance(lr_dis, (list, tuple)) else [lr_dis]
    epochs = epochs if isinstance(epochs, (list, tuple)) else [epochs]
    batch_size = batch_size if isinstance(batch_size, (list, tuple)) else [batch_size]
    iterations = iterations if iterations is None or isinstance(iterations, (list, tuple)) else [iterations]
    lr_gen = lr_gen * len(resolutions) if len(lr_gen) == 1 else lr_gen
    lr_dis = lr_dis * len(resolutions) if len(lr_dis) == 1 else lr_dis
    epochs = epochs * len(resolutions) if len(epochs) == 1 else epochs
    batch_size = batch_size * len(resolutions) if len(batch_size) == 1 else batch_size
    if iterations is not None:
        iterations = iterations * len(resolutions) if len(iterations) == 1 else iterations
        iterations = utils.str2int(iterations)
    if not os.path.isdir(exp_dir):
        raise RuntimeError('Experiment directory was not found: \'' + exp_dir + '\'')
    assert len(lr_gen) == len(resolutions)
    assert len(lr_dis) == len(resolutions)
    assert len(epochs) == len(resolutions)
    assert len(batch_size) == len(resolutions)
    assert iterations is None or len(iterations) == len(resolutions)

    # Seed
    utils.set_seed(seed)

    # Check CUDA device availability
    device, gpus = utils.set_device(gpus)

    # Initialize loggers
    logger = TensorBoardLogger(log_dir=exp_dir)

    # Initialize datasets
    numpy_transforms = obj_factory(numpy_transforms) if numpy_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_lms_pose_transforms.Compose(numpy_transforms + tensor_transforms)
    train_dataset = obj_factory(train_dataset, transform=img_transforms)
    if val_dataset is not None:
        val_dataset = obj_factory(val_dataset, transform=img_transforms)

    # Create networks
    G_arch = utils.get_arch(generator)
    D_arch = utils.get_arch(discriminator)
    G = obj_factory(generator).to(device)
    D = obj_factory(discriminator).to(device)

    # Resume from a checkpoint or initialize the networks weights randomly
    checkpoint_dir = exp_dir if resume_dir is None else resume_dir
    G_path = os.path.join(checkpoint_dir, 'G_latest.pth')
    D_path = os.path.join(checkpoint_dir, 'D_latest.pth')
    best_loss = 1e6
    curr_res = resolutions[0]
    optimizer_G_state, optimizer_D_state = None, None
    if os.path.isfile(G_path) and os.path.isfile(D_path):
        print("=> loading checkpoint from '{}'".format(checkpoint_dir))
        # G
        checkpoint = torch.load(G_path)
        if 'resolution' in checkpoint:
            curr_res = checkpoint['resolution']
            start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
        # else:
        #     curr_res = resolutions[1] if len(resolutions) > 1 else resolutions[0]
        best_loss_key = 'best_loss_%d' % curr_res
        best_loss = checkpoint[best_loss_key] if best_loss_key in checkpoint else best_loss
        G.apply(utils.init_weights)
        G.load_state_dict(checkpoint['state_dict'], strict=False)
        optimizer_G_state = checkpoint['optimizer']
        # D
        D.apply(utils.init_weights)
        if os.path.isfile(D_path):
            checkpoint = torch.load(D_path)
            D.load_state_dict(checkpoint['state_dict'], strict=False)
            optimizer_D_state = checkpoint['optimizer']
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_dir))
        if not pretrained:
            print("=> randomly initializing networks...")
            G.apply(utils.init_weights)
            D.apply(utils.init_weights)

    # Initialize landmarks decoders (one per resolution; inserted so coarser come last)
    landmarks_decoders = []
    for res in resolutions:
        landmarks_decoders.insert(0, landmarks_utils.LandmarksHeatMapDecoder(res).to(device))

    # Losses
    criterion_pixelwise = obj_factory(criterion_pixelwise).to(device)
    criterion_id = obj_factory(criterion_id).to(device)
    criterion_attr = obj_factory(criterion_attr).to(device)
    criterion_gan = obj_factory(criterion_gan).to(device)

    # Support multiple GPUs
    if gpus and len(gpus) > 1:
        G = nn.DataParallel(G, gpus)
        D = nn.DataParallel(D, gpus)
        criterion_id.vgg = nn.DataParallel(criterion_id.vgg, gpus)
        criterion_attr.vgg = nn.DataParallel(criterion_attr.vgg, gpus)
        landmarks_decoders = [nn.DataParallel(ld, gpus) for ld in landmarks_decoders]

    # For each resolution (resume at the checkpoint's resolution if one was loaded)
    start_res_ind = int(np.log2(curr_res)) - int(np.log2(resolutions[0]))
    start_epoch = 0 if start_epoch is None else start_epoch
    for ri in range(start_res_ind, len(resolutions)):
        res = resolutions[ri]
        res_lr_gen = lr_gen[ri]
        res_lr_dis = lr_dis[ri]
        res_epochs = epochs[ri]
        res_iterations = iterations[ri] if iterations is not None else None
        res_batch_size = batch_size[ri]
        res_landmarks_decoders = landmarks_decoders[-ri - 1:]

        # Optimizer and scheduler (fresh per resolution; optimizer states restored once)
        optimizer_G = obj_factory(optimizer, G.parameters(), lr=res_lr_gen)
        optimizer_D = obj_factory(optimizer, D.parameters(), lr=res_lr_dis)
        scheduler_G = obj_factory(scheduler, optimizer_G)
        scheduler_D = obj_factory(scheduler, optimizer_D)
        if optimizer_G_state is not None:
            optimizer_G.load_state_dict(optimizer_G_state)
            optimizer_G_state = None
        if optimizer_D_state is not None:
            optimizer_D.load_state_dict(optimizer_D_state)
            optimizer_D_state = None

        # Initialize data loaders (weighted sampling; optional fixed iteration count)
        if res_iterations is None:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, len(train_dataset))
        else:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, res_iterations)
        train_loader = tutils.data.DataLoader(train_dataset, batch_size=res_batch_size, sampler=train_sampler,
                                              num_workers=workers, pin_memory=True, drop_last=True, shuffle=False)
        if val_dataset is not None:
            if res_iterations is None:
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, len(val_dataset))
            else:
                val_iterations = (res_iterations * len(val_dataset.classes)) // len(train_dataset.classes)
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, val_iterations)
            val_loader = tutils.data.DataLoader(val_dataset, batch_size=res_batch_size, sampler=val_sampler,
                                                num_workers=workers, pin_memory=True, drop_last=True, shuffle=False)
        else:
            val_loader = None

        # For each epoch
        for epoch in range(start_epoch, res_epochs):
            total_loss = proces_epoch(train_loader, train=True)
            if val_loader is not None:
                with torch.no_grad():
                    total_loss = proces_epoch(val_loader, train=False)

            # Schedulers step (in PyTorch 1.1.0+ it must follow after the epoch training and
            # validation steps). FIX: test the instantiated scheduler_G, not the factory spec
            # string `scheduler` (which is never rebound here, so the old check was always False).
            if isinstance(scheduler_G, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler_G.step(total_loss)
                scheduler_D.step(total_loss)
            else:
                scheduler_G.step()
                scheduler_D.step()

            # Save models checkpoints
            is_best = total_loss < best_loss
            best_loss = min(best_loss, total_loss)
            utils.save_checkpoint(exp_dir, 'G', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': G.module.state_dict() if gpus and len(gpus) > 1 else G.state_dict(),
                'optimizer': optimizer_G.state_dict(),
                'best_loss_%d' % res: best_loss,
                'arch': G_arch,
            }, is_best)
            utils.save_checkpoint(exp_dir, 'D', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': D.module.state_dict() if gpus and len(gpus) > 1 else D.state_dict(),
                'optimizer': optimizer_D.state_dict(),
                'best_loss_%d' % res: best_loss,
                'arch': D_arch,
            }, is_best)

        # Reset start epoch and best loss: they should only affect the first trained resolution
        start_epoch = 0
        best_loss = 1e6