def main(dataset='fake_detection.datasets.image_list_dataset.ImageListDataset', np_transforms=None,
         tensor_transforms=('img_landmarks_transforms.ToTensor()',
                            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         workers=4, batch_size=4):
    import time
    from fsgan.utils.obj_factory import obj_factory
    from fsgan.utils.img_utils import tensor2bgr

    np_transforms = obj_factory(np_transforms) if np_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_landmarks_transforms.Compose(np_transforms + tensor_transforms)
    dataset = obj_factory(dataset, transform=img_transforms)
    dataloader = data.DataLoader(dataset, batch_size=batch_size, num_workers=workers, pin_memory=True,
                                 drop_last=True, shuffle=True)

    start = time.time()
    if isinstance(dataset, ImageListDataset):
        for img, target in dataloader:
            print(img.shape)
            print(target)

            # For each batch
            for b in range(img.shape[0]):
                render_img = tensor2bgr(img[b]).copy()
                cv2.imshow('render_img', render_img)
                if cv2.waitKey(0) & 0xFF == ord('q'):
                    break
    else:
        for img1, img2, target in dataloader:
            print(img1.shape)
            print(img2.shape)
            print(target)

            # For each batch
            for b in range(target.shape[0]):
                left_img = tensor2bgr(img1[b]).copy()
                right_img = tensor2bgr(img2[b]).copy()
                render_img = np.concatenate((left_img, right_img), axis=1)
                cv2.imshow('render_img', render_img)
                if cv2.waitKey(0) & 0xFF == ord('q'):
                    break
    end = time.time()
    print('elapsed time: %f[s]' % (end - start))

def main(input, np_transforms=None, tensor_transforms=None, batch_size=4):
    from torchvision.transforms import Compose
    from fsgan.utils.obj_factory import obj_factory

    np_transforms = obj_factory(np_transforms) if np_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = Compose(np_transforms + tensor_transforms)

    img = cv2.imread(input)
    pose = np.array([1., 2., 3.])
    x = img_transforms((img, pose))

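
# Usage sketch (hedged): obj_factory is assumed to map an iterable of string specs to a list of
# constructed objects, which is exactly how the entry points in this repository build their
# transform chains; the import path for img_landmarks_transforms follows the fsgan package layout.
from fsgan.datasets import img_landmarks_transforms
from fsgan.utils.obj_factory import obj_factory

tensor_transforms = obj_factory(('img_landmarks_transforms.ToTensor()',
                                 'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'))
img_transforms = img_landmarks_transforms.Compose(tensor_transforms)
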
def main(dataset='fsgan.datasets.image_seg_dataset.ImageSegDataset', np_transforms1=None, np_transforms2=None,
         tensor_transforms1=('img_landmarks_transforms.ToTensor()',
                             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         tensor_transforms2=('img_landmarks_transforms.ToTensor()',), workers=4, batch_size=4):
    import time
    from fsgan.utils.obj_factory import obj_factory
    from fsgan.utils.seg_utils import blend_seg_pred, blend_seg_label
    from fsgan.utils.img_utils import tensor2bgr

    np_transforms1 = obj_factory(np_transforms1) if np_transforms1 is not None else []
    tensor_transforms1 = obj_factory(tensor_transforms1) if tensor_transforms1 is not None else []
    img_transforms1 = img_landmarks_transforms.Compose(np_transforms1 + tensor_transforms1)
    np_transforms2 = obj_factory(np_transforms2) if np_transforms2 is not None else []
    tensor_transforms2 = obj_factory(tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms2 = img_landmarks_transforms.Compose(np_transforms2 + tensor_transforms2)

    dataset = obj_factory(dataset, transform=img_transforms1, target_transform=img_transforms2)
    dataloader = data.DataLoader(dataset, batch_size=batch_size, num_workers=workers, pin_memory=True,
                                 drop_last=True, shuffle=True)

    start = time.time()
    for img, seg in dataloader:
        # Blend the segmentation prediction over the whole batch once, then render each item
        blend_tensor = blend_seg_pred(img, seg)

        # For each batch
        for b in range(img.shape[0]):
            render_img = tensor2bgr(blend_tensor[b])
            # render_img = tensor2bgr(img[b])
            cv2.imshow('render_img', render_img)
            if cv2.waitKey(0) & 0xFF == ord('q'):
                break
    end = time.time()
    print('elapsed time: %f[s]' % (end - start))

def main(model='res_unet.ResUNet', res=(256,)):
    from fsgan.utils.obj_factory import obj_factory

    model = obj_factory(model)
    if len(res) == 1:
        # res is a tuple; index its single entry when building the input tensor
        img = torch.rand(1, model.in_nc, res[0], res[0])
        pred = model(img)
        print(pred.shape)
    else:
        # Build a list of random inputs, finest resolution first
        img = []
        for i in range(1, len(res) + 1):
            img.append(torch.rand(1, model.in_nc, res[-i], res[-i]))
        pred = model(img)
        print(pred.shape)

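
# Usage sketch (hedged): the single-resolution call matches the default signature above; the
# multi-resolution call assumes a pyramid-capable architecture such as the MultiScaleResUNet
# specs used elsewhere in this repository (the in_nc/out_nc values below are illustrative).
if __name__ == '__main__':
    main(model='res_unet.ResUNet', res=(256,))
    main(model='res_unet.MultiScaleResUNet(in_nc=3,out_nc=3)', res=(128, 256))
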
def load_model(model_path, name='', device=None, arch=None, return_checkpoint=False, train=False):
    """ Load a model from checkpoint.

    This is a utility function that combines the model weights and architecture (string representation) to easily
    load any model without explicit knowledge of its class.

    Args:
        model_path (str): Path to the model's checkpoint (.pth)
        name (str): The name of the model (for printing and error management)
        device (torch.device): The device to load the model to
        arch (str): The model's architecture (string representation)
        return_checkpoint (bool): If True, the checkpoint will be returned as well
        train (bool): If True, the model will be set to train mode, else it will be set to test mode

    Returns:
        (nn.Module, dict (optional)): A tuple that contains:
            - model (nn.Module): The loaded model
            - checkpoint (dict, optional): The model's checkpoint (only if return_checkpoint is True)
    """
    assert model_path is not None, '%s model must be specified!' % name
    assert os.path.exists(model_path), 'Couldn\'t find %s model in path: %s' % (name, model_path)
    print('=> Loading %s model: "%s"...' % (name, os.path.basename(model_path)))
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    assert arch is not None or 'arch' in checkpoint, 'Couldn\'t determine %s model architecture!' % name
    arch = checkpoint['arch'] if arch is None else arch
    model = obj_factory(arch)
    if device is not None:
        model.to(device)
    model.load_state_dict(checkpoint['state_dict'])
    model.train(train)

    if return_checkpoint:
        return model, checkpoint
    else:
        return model

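
# Usage sketch (hedged): loading checkpoints referenced elsewhere in this repository by path
# alone; the architecture string is taken from each checkpoint's 'arch' entry.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Gb = load_model('../weights/ijbc_msrunet_256_2_0_blending_v1.pth', 'face blending', device)
Gr, checkpoint = load_model('../weights/ijbc_msrunet_256_2_0_reenactment_v1.pth', 'face reenactment', device,
                            return_checkpoint=True)
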
def __init__(
        self, resolution=d('resolution'), crop_scale=d('crop_scale'), gpus=d('gpus'),
        cpu_only=d('cpu_only'), display=d('display'), verbose=d('verbose'), encoder_codec=d('encoder_codec'),
        # Detection arguments:
        detection_model=d('detection_model'), det_batch_size=d('det_batch_size'), det_postfix=d('det_postfix'),
        # Sequence arguments:
        iou_thresh=d('iou_thresh'), min_length=d('min_length'), min_size=d('min_size'),
        center_kernel=d('center_kernel'), size_kernel=d('size_kernel'), smooth_det=d('smooth_det'),
        seq_postfix=d('seq_postfix'), write_empty=d('write_empty'),
        # Pose arguments:
        pose_model=d('pose_model'), pose_batch_size=d('pose_batch_size'), pose_postfix=d('pose_postfix'),
        cache_pose=d('cache_pose'), cache_frontal=d('cache_frontal'), smooth_poses=d('smooth_poses'),
        # Landmarks arguments:
        lms_model=d('lms_model'), lms_batch_size=d('lms_batch_size'), landmarks_postfix=d('landmarks_postfix'),
        cache_landmarks=d('cache_landmarks'), smooth_landmarks=d('smooth_landmarks'),
        # Segmentation arguments:
        seg_model=d('seg_model'), smooth_segmentation=d('smooth_segmentation'),
        segmentation_postfix=d('segmentation_postfix'), cache_segmentation=d('cache_segmentation'),
        seg_batch_size=d('seg_batch_size'), seg_remove_mouth=d('seg_remove_mouth'),
        # Finetune arguments:
        finetune=d('finetune'), finetune_iterations=d('finetune_iterations'), finetune_lr=d('finetune_lr'),
        finetune_batch_size=d('finetune_batch_size'), finetune_workers=d('finetune_workers'),
        finetune_save=d('finetune_save'),
        # Swapping arguments:
        batch_size=d('batch_size'), reenactment_model=d('reenactment_model'),
        completion_model=d('completion_model'), blending_model=d('blending_model'),
        criterion_id=d('criterion_id'), min_radius=d('min_radius'), output_crop=d('output_crop'),
        renderer_process=d('renderer_process')):
    super(FaceSwapping, self).__init__(
        resolution, crop_scale, gpus, cpu_only, display, verbose, encoder_codec,
        detection_model=detection_model, det_batch_size=det_batch_size, det_postfix=det_postfix,
        iou_thresh=iou_thresh, min_length=min_length, min_size=min_size, center_kernel=center_kernel,
        size_kernel=size_kernel, smooth_det=smooth_det, seq_postfix=seq_postfix, write_empty=write_empty,
        pose_model=pose_model, pose_batch_size=pose_batch_size, pose_postfix=pose_postfix,
        cache_pose=True, cache_frontal=cache_frontal, smooth_poses=smooth_poses,
        lms_model=lms_model, lms_batch_size=lms_batch_size, landmarks_postfix=landmarks_postfix,
        cache_landmarks=True, smooth_landmarks=smooth_landmarks,
        seg_model=seg_model, seg_batch_size=seg_batch_size, segmentation_postfix=segmentation_postfix,
        cache_segmentation=True, smooth_segmentation=smooth_segmentation, seg_remove_mouth=seg_remove_mouth)
    self.batch_size = batch_size
    self.min_radius = min_radius
    self.output_crop = output_crop
    self.finetune_enabled = finetune
    self.finetune_iterations = finetune_iterations
    self.finetune_lr = finetune_lr
    self.finetune_batch_size = finetune_batch_size
    self.finetune_workers = finetune_workers
    self.finetune_save = finetune_save

    # Load reenactment model
    self.Gr, checkpoint = load_model(reenactment_model, 'face reenactment', self.device, return_checkpoint=True)
    self.Gr.arch = checkpoint['arch']
    self.reenactment_state_dict = checkpoint['state_dict']

    # Load all other models
    self.Gc = load_model(completion_model, 'face completion', self.device)
    self.Gb = load_model(blending_model, 'face blending', self.device)

    # Initialize landmarks decoders (inserting at the front orders the list highest resolution first)
    self.landmarks_decoders = []
    for res in (128, 256):
        self.landmarks_decoders.insert(0, LandmarksHeatMapDecoder(res).to(self.device))

    # Initialize losses
    self.criterion_pixelwise = nn.L1Loss().to(self.device)
    self.criterion_id = obj_factory(criterion_id).to(self.device)

    # Support multiple GPUs
    if self.gpus and len(self.gpus) > 1:
        self.Gr = nn.DataParallel(self.Gr, self.gpus)
        self.Gc = nn.DataParallel(self.Gc, self.gpus)
        self.Gb = nn.DataParallel(self.Gb, self.gpus)
        self.criterion_id.vgg = nn.DataParallel(self.criterion_id.vgg, self.gpus)

    # Initialize soft erosion
    self.smooth_mask = SoftErosion(kernel_size=21, threshold=0.6).to(self.device)

    # Initialize video writer
    self.video_renderer = FaceSwappingRenderer(self.display, self.verbose, self.output_crop, self.resolution,
                                               self.crop_scale, encoder_codec, renderer_process)
    self.video_renderer.start()

def main(
        # General arguments
        exp_dir, resume_dir=None, start_epoch=None, epochs=(90,), iterations=None, resolutions=(128, 256),
        lr_gen=(1e-4,), lr_dis=(1e-4,), gpus=None, workers=4, batch_size=(64,), seed=None, log_freq=20,
        # Data arguments
        train_dataset='opencv_video_seq_dataset.VideoSeqDataset', val_dataset=None, numpy_transforms=None,
        tensor_transforms=('img_landmarks_transforms.ToTensor()',
                           'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        # Training arguments
        optimizer='optim.SGD(momentum=0.9,weight_decay=1e-4)',
        scheduler='lr_scheduler.StepLR(step_size=30,gamma=0.1)', pretrained=False,
        criterion_pixelwise='nn.L1Loss', criterion_id='vgg_loss.VGGLoss', criterion_attr='vgg_loss.VGGLoss',
        criterion_gan='gan_loss.GANLoss(use_lsgan=True)',
        generator='res_unet.MultiScaleResUNet(in_nc=4,out_nc=3)',
        discriminator='discriminators_pix2pix.MultiscaleDiscriminator',
        reenactment_model=None, seg_model=None, lms_model=None,
        pix_weight=0.1, rec_weight=1.0, gan_weight=0.001, background_value=-1.0):
    def process_epoch(dataset_loader, train=True):
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode
        Gc.train(train)
        D.train(train)
        Gr.train(False)
        S.train(False)
        L.train(False)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs, scheduler_G.get_lr()[0]))

        # For each batch in the training data
        for i, (img, target) in enumerate(pbar):
            # Prepare input
            with torch.no_grad():
                # For each view images
                for j in range(len(img)):
                    # For each pyramid image: push to device
                    for p in range(len(img[j])):
                        img[j][p] = img[j][p].to(device)

                # Compute context
                context = L(img[1][0].sub(context_mean).div(context_std))
                context = landmarks_utils.filter_landmarks(context)

                # Normalize each of the pyramid images
                for j in range(len(img)):
                    for p in range(len(img[j])):
                        img[j][p].sub_(img_mean).div_(img_std)

                # # Compute segmentation
                # seg = []
                # for j in range(len(img)):
                #     curr_seg = S(img[j][0])
                #     if curr_seg.shape[2:] != (res, res):
                #         curr_seg = F.interpolate(curr_seg, (res, res), mode='bicubic', align_corners=False)
                #     seg.append(curr_seg)

                # Compute segmentation
                target_seg = S(img[1][0])
                if target_seg.shape[2:] != (res, res):
                    target_seg = F.interpolate(target_seg, (res, res), mode='bicubic', align_corners=False)

                # Concatenate pyramid images with context to derive the final input
                input = []
                for p in range(len(img[0]) - 1, -1, -1):
                    context = F.interpolate(context, size=img[0][p].shape[2:], mode='bicubic', align_corners=False)
                    input.insert(0, torch.cat((img[0][p], context), dim=1))

                # Reenactment
                reenactment_img = Gr(input)
                reenactment_seg = S(reenactment_img)
                if reenactment_img.shape[2:] != (res, res):
                    reenactment_img = F.interpolate(reenactment_img, (res, res), mode='bilinear',
                                                    align_corners=False)
                    reenactment_seg = F.interpolate(reenactment_seg, (res, res), mode='bilinear',
                                                    align_corners=False)

                # Remove unnecessary pyramids
                for j in range(len(img)):
                    img[j] = img[j][-ri - 1:]

                # Source face
                reenactment_face_mask = reenactment_seg.argmax(1) == 1
                inpainting_mask = seg_utils.random_hair_inpainting_mask_tensor(reenactment_face_mask).to(device)
                reenactment_face_mask = reenactment_face_mask * (inpainting_mask == 0)
                reenactment_img_with_hole = reenactment_img.masked_fill(~reenactment_face_mask.unsqueeze(1),
                                                                        background_value)

                # Target face
                target_face_mask = (target_seg.argmax(1) == 1).unsqueeze(1)
                inpainting_target = img[1][0]
                inpainting_target.masked_fill_(~target_face_mask, background_value)

                # Inpainting input
                inpainting_input = torch.cat((reenactment_img_with_hole, target_face_mask.float()), dim=1)
                inpainting_input_pyd = img_utils.create_pyramid(inpainting_input, len(img[0]))

            # Face inpainting
            inpainting_pred = Gc(inpainting_input_pyd)

            # Fake Detection and Loss
            inpainting_pred_pyd = img_utils.create_pyramid(inpainting_pred, len(img[0]))
            pred_fake_pool = D([x.detach() for x in inpainting_pred_pyd])
            loss_D_fake = criterion_gan(pred_fake_pool, False)

            # Real Detection and Loss
            inpainting_target_pyd = img_utils.create_pyramid(inpainting_target, len(img[0]))
            pred_real = D(inpainting_target_pyd)
            loss_D_real = criterion_gan(pred_real, True)
            loss_D_total = (loss_D_fake + loss_D_real) * 0.5

            # GAN loss (Fake Passability Loss)
            pred_fake = D(inpainting_pred_pyd)
            loss_G_GAN = criterion_gan(pred_fake, True)

            # Reconstruction
            loss_pixelwise = criterion_pixelwise(inpainting_pred, inpainting_target)
            loss_id = criterion_id(inpainting_pred, inpainting_target)
            loss_attr = criterion_attr(inpainting_pred, inpainting_target)
            loss_rec = pix_weight * loss_pixelwise + 0.5 * loss_id + 0.5 * loss_attr

            loss_G_total = rec_weight * loss_rec + gan_weight * loss_G_GAN

            if train:
                # Update generator weights
                optimizer_G.zero_grad()
                loss_G_total.backward()
                optimizer_G.step()

                # Update discriminator weights
                optimizer_D.zero_grad()
                loss_D_total.backward()
                optimizer_D.step()

            logger.update('losses', pixelwise=loss_pixelwise, id=loss_id, attr=loss_attr, rec=loss_rec,
                          g_gan=loss_G_GAN, d_gan=loss_D_total)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg('%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images
            grid = img_utils.make_grid(img[0][0], reenactment_img, reenactment_img_with_hole, inpainting_pred,
                                       inpainting_target)
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['rec'].avg

    #################
    # Main pipeline #
    #################

    # Validation
    resolutions = resolutions if isinstance(resolutions, (list, tuple)) else [resolutions]
    lr_gen = lr_gen if isinstance(lr_gen, (list, tuple)) else [lr_gen]
    lr_dis = lr_dis if isinstance(lr_dis, (list, tuple)) else [lr_dis]
    epochs = epochs if isinstance(epochs, (list, tuple)) else [epochs]
    batch_size = batch_size if isinstance(batch_size, (list, tuple)) else [batch_size]
    iterations = iterations if iterations is None or isinstance(iterations, (list, tuple)) else [iterations]

    lr_gen = lr_gen * len(resolutions) if len(lr_gen) == 1 else lr_gen
    lr_dis = lr_dis * len(resolutions) if len(lr_dis) == 1 else lr_dis
    epochs = epochs * len(resolutions) if len(epochs) == 1 else epochs
    batch_size = batch_size * len(resolutions) if len(batch_size) == 1 else batch_size
    if iterations is not None:
        iterations = iterations * len(resolutions) if len(iterations) == 1 else iterations
        iterations = utils.str2int(iterations)

    if not os.path.isdir(exp_dir):
        raise RuntimeError('Experiment directory was not found: \'' + exp_dir + '\'')
    assert len(lr_gen) == len(resolutions)
    assert len(lr_dis) == len(resolutions)
    assert len(epochs) == len(resolutions)
    assert len(batch_size) == len(resolutions)
    assert iterations is None or len(iterations) == len(resolutions)

    # Seed
    utils.set_seed(seed)

    # Check CUDA device availability
    device, gpus = utils.set_device(gpus)

    # Initialize loggers
    logger = TensorBoardLogger(log_dir=exp_dir)

    # Initialize datasets
    numpy_transforms = obj_factory(numpy_transforms) if numpy_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_landmarks_transforms.Compose(numpy_transforms + tensor_transforms)
    train_dataset = obj_factory(train_dataset, transform=img_transforms)
    if val_dataset is not None:
        val_dataset = obj_factory(val_dataset, transform=img_transforms)

    # Create networks
    Gc = obj_factory(generator).to(device)
    D = obj_factory(discriminator).to(device)

    # Resume from a checkpoint or initialize the networks weights randomly
    checkpoint_dir = exp_dir if resume_dir is None else resume_dir
    Gc_path = os.path.join(checkpoint_dir, 'Gc_latest.pth')
    D_path = os.path.join(checkpoint_dir, 'D_latest.pth')
    best_loss = 1000000.
    curr_res = resolutions[0]
    optimizer_G_state, optimizer_D_state = None, None
    if os.path.isfile(Gc_path) and os.path.isfile(D_path):
        print("=> loading checkpoint from '{}'".format(checkpoint_dir))
        # Gc
        checkpoint = torch.load(Gc_path)
        if 'resolution' in checkpoint:
            curr_res = checkpoint['resolution']
            start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
        else:
            curr_res = resolutions[1] if len(resolutions) > 1 else resolutions[0]
        best_loss = checkpoint['best_loss']
        Gc.apply(utils.init_weights)
        Gc.load_state_dict(checkpoint['state_dict'], strict=False)
        optimizer_G_state = checkpoint['optimizer']
        # D
        D.apply(utils.init_weights)
        if os.path.isfile(D_path):
            checkpoint = torch.load(D_path)
            D.load_state_dict(checkpoint['state_dict'], strict=False)
            optimizer_D_state = checkpoint['optimizer']
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_dir))
        if not pretrained:
            print("=> randomly initializing networks...")
            Gc.apply(utils.init_weights)
            D.apply(utils.init_weights)

    # Load reenactment model (validate the path before using it)
    if reenactment_model is None:
        raise RuntimeError('Reenactment model must be specified!')
    if not os.path.exists(reenactment_model):
        raise RuntimeError('Couldn\'t find reenactment model in path: ' + reenactment_model)
    print('=> Loading face reenactment model: "' + os.path.basename(reenactment_model) + '"...')
    checkpoint = torch.load(reenactment_model)
    Gr = obj_factory(checkpoint['arch']).to(device)
    Gr.load_state_dict(checkpoint['state_dict'])

    # Load segmentation model (validate the path before using it)
    if seg_model is None:
        raise RuntimeError('Segmentation model must be specified!')
    if not os.path.exists(seg_model):
        raise RuntimeError('Couldn\'t find segmentation model in path: ' + seg_model)
    print('=> Loading face segmentation model: "' + os.path.basename(seg_model) + '"...')
    checkpoint = torch.load(seg_model)
    S = obj_factory(checkpoint['arch']).to(device)
    S.load_state_dict(checkpoint['state_dict'])

    # Load face landmarks model (validate the path before using it)
    assert os.path.isfile(lms_model), 'The model path "%s" does not exist' % lms_model
    print('=> Loading face landmarks model: "' + os.path.basename(lms_model) + '"...')
    L = hrnet_wlfw().to(device)
    state_dict = torch.load(lms_model)
    L.load_state_dict(state_dict)

    # Initialize normalization tensors
    # Note: this is necessary because of the landmarks model
    img_mean = torch.as_tensor([0.5, 0.5, 0.5], device=device).view(1, 3, 1, 1)
    img_std = torch.as_tensor([0.5, 0.5, 0.5], device=device).view(1, 3, 1, 1)
    context_mean = torch.as_tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    context_std = torch.as_tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    # Losses
    criterion_pixelwise = obj_factory(criterion_pixelwise).to(device)
    criterion_id = obj_factory(criterion_id).to(device)
    criterion_attr = obj_factory(criterion_attr).to(device)
    criterion_gan = obj_factory(criterion_gan).to(device)

    # Support multiple GPUs
    if gpus and len(gpus) > 1:
        Gc = nn.DataParallel(Gc, gpus)
        Gr = nn.DataParallel(Gr, gpus)
        D = nn.DataParallel(D, gpus)
        S = nn.DataParallel(S, gpus)
        L = nn.DataParallel(L, gpus)
        criterion_id.vgg = nn.DataParallel(criterion_id.vgg, gpus)
        criterion_attr.vgg = nn.DataParallel(criterion_attr.vgg, gpus)

    # For each resolution
    start_res_ind = int(np.log2(curr_res)) - int(np.log2(resolutions[0]))
    start_epoch = 0 if start_epoch is None else start_epoch
    for ri in range(start_res_ind, len(resolutions)):
        res = resolutions[ri]
        res_lr_gen = lr_gen[ri]
        res_lr_dis = lr_dis[ri]
        res_epochs = epochs[ri]
        res_iterations = iterations[ri] if iterations is not None else None
        res_batch_size = batch_size[ri]

        # Optimizer and scheduler
        optimizer_G = obj_factory(optimizer, Gc.parameters(), lr=res_lr_gen)
        optimizer_D = obj_factory(optimizer, D.parameters(), lr=res_lr_dis)
        scheduler_G = obj_factory(scheduler, optimizer_G)
        scheduler_D = obj_factory(scheduler, optimizer_D)
        if optimizer_G_state is not None:
            optimizer_G.load_state_dict(optimizer_G_state)
            optimizer_G_state = None
        if optimizer_D_state is not None:
            optimizer_D.load_state_dict(optimizer_D_state)
            optimizer_D_state = None

        # Initialize data loaders
        if res_iterations is None:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, len(train_dataset))
        else:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, res_iterations)
        train_loader = tutils.data.DataLoader(train_dataset, batch_size=res_batch_size, sampler=train_sampler,
                                              num_workers=workers, pin_memory=True, drop_last=True, shuffle=False)
        if val_dataset is not None:
            if res_iterations is None:
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, len(val_dataset))
            else:
                val_iterations = (res_iterations * len(val_dataset.classes)) // len(train_dataset.classes)
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, val_iterations)
            val_loader = tutils.data.DataLoader(val_dataset, batch_size=res_batch_size, sampler=val_sampler,
                                                num_workers=workers, pin_memory=True, drop_last=True,
                                                shuffle=False)
        else:
            val_loader = None

        # For each epoch
        for epoch in range(start_epoch, res_epochs):
            total_loss = process_epoch(train_loader, train=True)
            if val_loader is not None:
                with torch.no_grad():
                    total_loss = process_epoch(val_loader, train=False)

            # Schedulers step (in PyTorch 1.1.0+ it must follow after the epoch training and validation steps)
            # Note: the check must be made against the constructed scheduler object, not the string spec
            if isinstance(scheduler_G, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler_G.step(total_loss)
                scheduler_D.step(total_loss)
            else:
                scheduler_G.step()
                scheduler_D.step()

            # Save models checkpoints
            is_best = total_loss < best_loss
            best_loss = min(best_loss, total_loss)
            utils.save_checkpoint(exp_dir, 'Gc', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': Gc.module.state_dict() if gpus and len(gpus) > 1 else Gc.state_dict(),
                'optimizer': optimizer_G.state_dict(),
                'best_loss': best_loss,
            }, is_best)
            utils.save_checkpoint(exp_dir, 'D', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': D.module.state_dict() if gpus and len(gpus) > 1 else D.state_dict(),
                'optimizer': optimizer_D.state_dict(),
                'best_loss': best_loss,
            }, is_best)

        # Reset start epoch to 0 because it should only affect the first training resolution
        start_epoch = 0

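
# Illustration (hedged): img_utils.create_pyramid is assumed to return a finest-first list of
# progressively downsampled tensors; that convention is what lets the training loop above feed
# the same batch to the multi-scale generator and discriminator at every resolution.
import torch
from fsgan.utils import img_utils

x = torch.rand(1, 3, 256, 256)
pyd = img_utils.create_pyramid(x, 2)
print([tuple(p.shape[2:]) for p in pyd])  # expected: [(256, 256), (128, 128)]
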
def main(source_path, target_path,
         arch='res_unet_split.MultiScaleResUNet(in_nc=71,out_nc=(3,3),flat_layers=(2,0,2,3),ngf=128)',
         reenactment_model_path='../weights/ijbc_msrunet_256_2_0_reenactment_v1.pth',
         seg_model_path='../weights/lfw_figaro_unet_256_2_0_segmentation_v1.pth',
         inpainting_model_path='../weights/ijbc_msrunet_256_2_0_inpainting_v1.pth',
         blend_model_path='../weights/ijbc_msrunet_256_2_0_blending_v1.pth',
         pose_model_path='../weights/hopenet_robust_alpha1.pth',
         pil_transforms1=('landmark_transforms.FaceAlignCrop', 'landmark_transforms.Resize(256)',
                          'landmark_transforms.Pyramids(2)'),
         pil_transforms2=('landmark_transforms.FaceAlignCrop', 'landmark_transforms.Resize(256)',
                          'landmark_transforms.Pyramids(2)', 'landmark_transforms.LandmarksToHeatmaps'),
         tensor_transforms1=('landmark_transforms.ToTensor()',
                             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         tensor_transforms2=('landmark_transforms.ToTensor()',
                             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         output_path=None, min_radius=2.0, crop_size=256, reverse_output=False, verbose=0, output_crop=False,
         display=False):
    torch.set_grad_enabled(False)
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
    device, gpus = utils.set_device()

    # Load reenactment model
    Gr = obj_factory(arch).to(device)
    checkpoint = torch.load(reenactment_model_path)
    Gr.load_state_dict(checkpoint['state_dict'])
    Gr.train(False)

    # Load segmentation model
    if seg_model_path is not None:
        print('Loading face segmentation model: "' + os.path.basename(seg_model_path) + '"...')
        if seg_model_path.endswith('.pth'):
            checkpoint = torch.load(seg_model_path)
            Gs = obj_factory(checkpoint['arch']).to(device)
            Gs.load_state_dict(checkpoint['state_dict'])
        else:
            Gs = torch.jit.load(seg_model_path, map_location=device)
        if Gs is None:
            raise RuntimeError('Failed to load face segmentation model!')
        Gs.eval()
    else:
        Gs = None

    # Load inpainting model (gated on its own path, not the segmentation path)
    if inpainting_model_path is not None:
        print('Loading face inpainting model: "' + os.path.basename(inpainting_model_path) + '"...')
        if inpainting_model_path.endswith('.pth'):
            checkpoint = torch.load(inpainting_model_path)
            Gi = obj_factory(checkpoint['arch']).to(device)
            Gi.load_state_dict(checkpoint['state_dict'])
        else:
            Gi = torch.jit.load(inpainting_model_path, map_location=device)
        if Gi is None:
            raise RuntimeError('Failed to load face inpainting model!')
        Gi.eval()
    else:
        Gi = None

    # Load blending model
    checkpoint = torch.load(blend_model_path)
    Gb = obj_factory(checkpoint['arch']).to(device)
    Gb.load_state_dict(checkpoint['state_dict'])
    Gb.train(False)

    # Load pose model
    Gp = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    Gp.load_state_dict(checkpoint['state_dict'])
    Gp.train(False)

    # Initialize transformations
    pil_transforms1 = obj_factory(pil_transforms1) if pil_transforms1 is not None else []
    pil_transforms2 = obj_factory(pil_transforms2) if pil_transforms2 is not None else []
    tensor_transforms1 = obj_factory(tensor_transforms1) if tensor_transforms1 is not None else []
    tensor_transforms2 = obj_factory(tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms1 = landmark_transforms.ComposePyramids(pil_transforms1 + tensor_transforms1)
    img_transforms2 = landmark_transforms.ComposePyramids(pil_transforms2 + tensor_transforms2)

    # Extract landmarks, bounding boxes, and euler angles from the source and target image directories
    source_frame_indices, source_landmarks, source_bboxes, source_eulers = \
        extract_landmarks_bboxes_euler_from_images(source_path, Gp, fa, device=device)
    if source_frame_indices.size == 0:
        raise RuntimeError('No faces were detected in the source image directory: ' + source_path)
    target_frame_indices, target_landmarks, target_bboxes, target_eulers = \
        extract_landmarks_bboxes_euler_from_images(target_path, Gp, fa, device=device)
    if target_frame_indices.size == 0:
        raise RuntimeError('No faces were detected in the target image directory: ' + target_path)

    source_img_paths = glob(os.path.join(source_path, '*.jpg'))
    target_img_paths = glob(os.path.join(target_path, '*.jpg'))

    # For each source image
    source_valid_frame_ind = 0
    for k, source_img_path in tqdm(enumerate(source_img_paths), unit='images', total=len(source_img_paths)):
        if k not in source_frame_indices:
            continue
        source_img_bgr = cv2.imread(source_img_path)
        if source_img_bgr is None:
            continue
        source_img_rgb = source_img_bgr[:, :, ::-1]
        curr_source_tensor, curr_source_landmarks, curr_source_bbox = img_transforms1(
            source_img_rgb, source_landmarks[source_valid_frame_ind], source_bboxes[source_valid_frame_ind])
        source_valid_frame_ind += 1
        for j in range(len(curr_source_tensor)):
            curr_source_tensor[j] = curr_source_tensor[j].to(device)

        # For each target image
        target_valid_frame_ind = 0
        for i, target_img_path in enumerate(target_img_paths):
            curr_output_name = '_'.join([os.path.splitext(os.path.basename(source_img_path))[0],
                                         os.path.splitext(os.path.basename(target_img_path))[0]]) + '.jpg'
            curr_output_path = os.path.join(output_path, curr_output_name)
            if os.path.isfile(curr_output_path):
                target_valid_frame_ind += 1
                continue
            target_img_bgr = cv2.imread(target_img_path)
            if target_img_bgr is None:
                continue
            if i not in target_frame_indices:
                continue
            target_img_rgb = target_img_bgr[:, :, ::-1]
            curr_target_tensor, curr_target_landmarks, curr_target_bbox = img_transforms2(
                target_img_rgb, target_landmarks[target_valid_frame_ind],
                target_bboxes[target_valid_frame_ind])
            curr_target_euler = target_eulers[target_valid_frame_ind]
            target_valid_frame_ind += 1

            # Reenactment
            reenactment_input_tensor = []
            for j in range(len(curr_source_tensor)):
                curr_target_landmarks[j] = curr_target_landmarks[j].to(device)
                reenactment_input_tensor.append(
                    torch.cat((curr_source_tensor[j], curr_target_landmarks[j]), dim=0).unsqueeze(0))
            reenactment_img_tensor, reenactment_seg_tensor = Gr(reenactment_input_tensor)

            # Segmentation
            target_img_tensor = curr_target_tensor[0].unsqueeze(0).to(device)
            target_seg_pred_tensor = Gs(target_img_tensor)
            target_mask_tensor = target_seg_pred_tensor.argmax(1) == 1

            # Remove the background of the aligned face
            aligned_face_mask_tensor = reenactment_seg_tensor.argmax(1) == 1
            aligned_background_mask_tensor = ~aligned_face_mask_tensor
            aligned_img_no_background_tensor = reenactment_img_tensor.clone()
            aligned_img_no_background_tensor.masked_fill_(aligned_background_mask_tensor.unsqueeze(1), -1.0)

            # Inpainting
            inpainting_input_tensor = torch.cat(
                (aligned_img_no_background_tensor, target_mask_tensor.unsqueeze(1).float()), dim=1)
            inpainting_input_tensor_pyd = create_pyramid(inpainting_input_tensor, len(curr_target_tensor))
            completion_tensor = Gi(inpainting_input_tensor_pyd)

            # Blending
            transfer_tensor = transfer_mask(completion_tensor, target_img_tensor, target_mask_tensor)
            blend_input_tensor = torch.cat(
                (transfer_tensor, target_img_tensor, target_mask_tensor.unsqueeze(1).float()), dim=1)
            blend_input_tensor_pyd = create_pyramid(blend_input_tensor, len(curr_target_tensor))
            blend_tensor = Gb(blend_input_tensor_pyd)
            blend_img = tensor2bgr(blend_tensor)

            # Render output
            if verbose == 0:
                render_img = blend_img if output_crop else crop2img(target_img_bgr, blend_img,
                                                                    curr_target_bbox[0].numpy())
            elif verbose == 1:
                reenactment_only_tensor = transfer_mask(reenactment_img_tensor, target_img_tensor,
                                                        aligned_face_mask_tensor & target_mask_tensor)
                reenactment_only_img = tensor2bgr(reenactment_only_tensor)
                completion_only_img = tensor2bgr(transfer_tensor)
                transfer_tensor = transfer_mask(aligned_img_no_background_tensor, target_img_tensor,
                                                target_mask_tensor)
                blend_input_tensor = torch.cat(
                    (transfer_tensor, target_img_tensor, target_mask_tensor.unsqueeze(1).float()), dim=1)
                blend_input_tensor_pyd = create_pyramid(blend_input_tensor, len(curr_target_tensor))
                blend_tensor = Gb(blend_input_tensor_pyd)
                blend_only_img = tensor2bgr(blend_tensor)
                render_img = np.concatenate(
                    (reenactment_only_img, completion_only_img, blend_only_img, blend_img), axis=1)
            elif verbose == 2:
                reenactment_img_bgr = tensor2bgr(reenactment_img_tensor)
                reenactment_seg_bgr = tensor2bgr(blend_seg_pred(reenactment_img_tensor, reenactment_seg_tensor))
                target_seg_bgr = tensor2bgr(blend_seg_pred(target_img_tensor, target_seg_pred_tensor))
                aligned_img_no_background_bgr = tensor2bgr(aligned_img_no_background_tensor)
                completion_bgr = tensor2bgr(completion_tensor)
                transfer_bgr = tensor2bgr(transfer_tensor)
                target_cropped_bgr = tensor2bgr(target_img_tensor)
                pose_axis_bgr = draw_axis(np.zeros_like(target_cropped_bgr), curr_target_euler[0],
                                          curr_target_euler[1], curr_target_euler[2])
                render_img1 = np.concatenate((reenactment_img_bgr, reenactment_seg_bgr, target_seg_bgr), axis=1)
                render_img2 = np.concatenate((aligned_img_no_background_bgr, completion_bgr, transfer_bgr),
                                             axis=1)
                render_img3 = np.concatenate((pose_axis_bgr, blend_img, target_cropped_bgr), axis=1)
                render_img = np.concatenate((render_img1, render_img2, render_img3), axis=0)
            elif verbose == 3:
                source_cropped_bgr = tensor2bgr(curr_source_tensor[0].unsqueeze(0))
                target_cropped_bgr = tensor2bgr(target_img_tensor)
                render_img = np.concatenate((source_cropped_bgr, target_cropped_bgr, blend_img), axis=1)

            cv2.imwrite(curr_output_path, render_img)
            if display:
                cv2.imshow('render_img', render_img)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

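
# Hedged sketch of the mask transfer convention assumed above: take src pixels where the mask is
# set and dst pixels elsewhere, broadcasting the (N, H, W) boolean mask over the channel axis.
import torch

def transfer_mask_sketch(src, dst, mask):
    return torch.where(mask.unsqueeze(1), src, dst)
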
def main(
        # General arguments
        exp_dir, resume_dir=None, start_epoch=None, epochs=(90,), iterations=None, resolutions=(128, 256),
        learning_rate=(1e-1,), gpus=None, workers=4, batch_size=(64,), seed=None, log_freq=20,
        # Data arguments
        train_dataset='fsgan.datasets.image_seg_dataset.ImageSegDataset', val_dataset=None,
        numpy_transforms=None,
        tensor_transforms=('img_landmarks_transforms.ToTensor()',
                           'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        # Training arguments
        optimizer='optim.SGD(momentum=0.9,weight_decay=1e-4)',
        scheduler='lr_scheduler.StepLR(step_size=30,gamma=0.1)', criterion='nn.CrossEntropyLoss',
        model='fsgan.models.simple_unet.UNet(n_classes=3,feature_scale=1)', pretrained=False,
        benchmark='fsgan.train_segmentation.IOUBenchmark(3)'):
    def process_epoch(dataset_loader, train=True):
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode
        model.train(train)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs, scheduler.get_lr()[0]))

        # For each batch in the training data
        for i, (input, target) in enumerate(pbar):
            # Prepare input
            input = input.to(device)
            target = target.to(device)
            with torch.no_grad():
                target = target.argmax(dim=1)

            # Execute model
            pred = model(input)

            # Calculate loss
            loss_total = criterion(pred, target)

            # Run benchmark
            benchmark_res = benchmark(pred, target) if benchmark is not None else {}

            if train:
                # Update generator weights
                optimizer.zero_grad()
                loss_total.backward()
                optimizer.step()

            logger.update('losses', total=loss_total)
            logger.update('bench', **benchmark_res)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg('%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images
            seg_pred = blend_seg_pred(input, pred)
            seg_gt = blend_seg_label(input, target)
            grid = img_utils.make_grid(input, seg_pred, seg_gt)
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['total'].avg

    #################
    # Main pipeline #
    #################

    # Validation
    resolutions = resolutions if isinstance(resolutions, (list, tuple)) else [resolutions]
    learning_rate = learning_rate if isinstance(learning_rate, (list, tuple)) else [learning_rate]
    epochs = epochs if isinstance(epochs, (list, tuple)) else [epochs]
    batch_size = batch_size if isinstance(batch_size, (list, tuple)) else [batch_size]
    iterations = iterations if iterations is None or isinstance(iterations, (list, tuple)) else [iterations]

    learning_rate = learning_rate * len(resolutions) if len(learning_rate) == 1 else learning_rate
    epochs = epochs * len(resolutions) if len(epochs) == 1 else epochs
    batch_size = batch_size * len(resolutions) if len(batch_size) == 1 else batch_size
    if iterations is not None:
        iterations = iterations * len(resolutions) if len(iterations) == 1 else iterations
        iterations = utils.str2int(iterations)

    if not os.path.isdir(exp_dir):
        raise RuntimeError('Experiment directory was not found: \'' + exp_dir + '\'')
    assert len(learning_rate) == len(resolutions)
    assert len(epochs) == len(resolutions)
    assert len(batch_size) == len(resolutions)
    assert iterations is None or len(iterations) == len(resolutions)

    # Seed
    utils.set_seed(seed)

    # Check CUDA device availability
    device, gpus = utils.set_device(gpus)

    # Initialize loggers
    logger = TensorBoardLogger(log_dir=exp_dir)

    # Initialize datasets
    numpy_transforms = obj_factory(numpy_transforms) if numpy_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_landmarks_transforms.Compose(numpy_transforms + tensor_transforms)
    train_dataset = obj_factory(train_dataset, transform=img_transforms)
    if val_dataset is not None:
        val_dataset = obj_factory(val_dataset, transform=img_transforms)

    # Create networks
    arch = utils.get_arch(model, num_classes=len(train_dataset.classes))
    model = obj_factory(model, num_classes=len(train_dataset.classes)).to(device)

    # Resume from a checkpoint or initialize the networks weights randomly
    checkpoint_dir = exp_dir if resume_dir is None else resume_dir
    model_path = os.path.join(checkpoint_dir, 'model_latest.pth')
    best_loss = 1e6
    curr_res = resolutions[0]
    optimizer_state = None
    if os.path.isfile(model_path):
        print("=> loading checkpoint from '{}'".format(checkpoint_dir))
        # model
        checkpoint = torch.load(model_path)
        if 'resolution' in checkpoint:
            curr_res = checkpoint['resolution']
            start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
        # else:
        #     curr_res = resolutions[1] if len(resolutions) > 1 else resolutions[0]
        best_loss_key = 'best_loss_%d' % curr_res
        best_loss = checkpoint[best_loss_key] if best_loss_key in checkpoint else best_loss
        model.apply(utils.init_weights)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        optimizer_state = checkpoint['optimizer']
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_dir))
        if not pretrained:
            print("=> randomly initializing networks...")
            model.apply(utils.init_weights)

    # Losses
    criterion = obj_factory(criterion).to(device)

    # Benchmark
    benchmark = obj_factory(benchmark).to(device)

    # Support multiple GPUs
    if gpus and len(gpus) > 1:
        model = nn.DataParallel(model, gpus)

    # For each resolution
    start_res_ind = int(np.log2(curr_res)) - int(np.log2(resolutions[0]))
    start_epoch = 0 if start_epoch is None else start_epoch
    for ri in range(start_res_ind, len(resolutions)):
        res = resolutions[ri]
        res_lr = learning_rate[ri]
        res_epochs = epochs[ri]
        res_iterations = iterations[ri] if iterations is not None else None
        res_batch_size = batch_size[ri]

        # Optimizer and scheduler
        optimizer = obj_factory(optimizer, model.parameters(), lr=res_lr)
        scheduler = obj_factory(scheduler, optimizer)
        if optimizer_state is not None:
            optimizer.load_state_dict(optimizer_state)

        # Initialize data loaders
        if res_iterations is None:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, len(train_dataset))
        else:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, res_iterations)
        train_loader = tutils.data.DataLoader(train_dataset, batch_size=res_batch_size, sampler=train_sampler,
                                              num_workers=workers, pin_memory=True, drop_last=True, shuffle=False)
        if val_dataset is not None:
            if res_iterations is None:
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, len(val_dataset))
            else:
                val_iterations = (res_iterations * len(val_dataset)) // len(train_dataset)
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, val_iterations)
            val_loader = tutils.data.DataLoader(val_dataset, batch_size=res_batch_size, sampler=val_sampler,
                                                num_workers=workers, pin_memory=True, drop_last=True,
                                                shuffle=False)
        else:
            val_loader = None

        # For each epoch
        for epoch in range(start_epoch, res_epochs):
            total_loss = process_epoch(train_loader, train=True)
            if val_loader is not None:
                with torch.no_grad():
                    total_loss = process_epoch(val_loader, train=False)
            if hasattr(benchmark, 'reset'):
                benchmark.reset()

            # Schedulers step (in PyTorch 1.1.0+ it must follow after the epoch training and validation steps)
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(total_loss)
            else:
                scheduler.step()

            # Save models checkpoints
            is_best = total_loss < best_loss
            best_loss = min(best_loss, total_loss)
            utils.save_checkpoint(exp_dir, 'model', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict() if gpus and len(gpus) > 1 else model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_loss_%d' % res: best_loss,
                'arch': arch,
            }, is_best)

        # Reset start epoch to 0 because it should only affect the first training resolution
        start_epoch = 0
        best_loss = 1e6

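
# Hedged sketch of the per-class IoU that a benchmark like IOUBenchmark(3) is assumed to compute
# from (N, C, H, W) logits and (N, H, W) integer labels.
import torch

def iou_per_class(pred, target, n_classes=3):
    labels = pred.argmax(dim=1)
    ious = []
    for c in range(n_classes):
        intersection = ((labels == c) & (target == c)).sum().float()
        union = ((labels == c) | (target == c)).sum().float()
        ious.append((intersection / union).item() if union > 0 else float('nan'))
    return ious
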
def main(source_path, target_path,
         arch='res_unet_split.MultiScaleResUNet(in_nc=71,out_nc=(3,3),flat_layers=(2,0,2,3),ngf=128)',
         model_path='../weights/ijbc_msrunet_256_2_0_reenactment_v1.pth',
         pose_model_path='../weights/hopenet_robust_alpha1.pth',
         pil_transforms1=('landmark_transforms.FaceAlignCrop', 'landmark_transforms.Resize(256)',
                          'landmark_transforms.Pyramids(2)'),
         pil_transforms2=('landmark_transforms.FaceAlignCrop', 'landmark_transforms.Resize(256)',
                          'landmark_transforms.Pyramids(2)', 'landmark_transforms.LandmarksToHeatmaps'),
         tensor_transforms1=('landmark_transforms.ToTensor()',
                             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         tensor_transforms2=('landmark_transforms.ToTensor()',
                             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         output_path=None, crop_size=256, display=False):
    torch.set_grad_enabled(False)

    # Initialize models
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
    device, gpus = utils.set_device()
    G = obj_factory(arch).to(device)
    checkpoint = torch.load(model_path)
    G.load_state_dict(checkpoint['state_dict'])
    G.train(False)

    # Initialize pose
    Gp = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    Gp.load_state_dict(checkpoint['state_dict'])
    Gp.train(False)

    # Initialize transformations
    pil_transforms1 = obj_factory(pil_transforms1) if pil_transforms1 is not None else []
    pil_transforms2 = obj_factory(pil_transforms2) if pil_transforms2 is not None else []
    tensor_transforms1 = obj_factory(tensor_transforms1) if tensor_transforms1 is not None else []
    tensor_transforms2 = obj_factory(tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms1 = landmark_transforms.ComposePyramids(pil_transforms1 + tensor_transforms1)
    img_transforms2 = landmark_transforms.ComposePyramids(pil_transforms2 + tensor_transforms2)

    # Process source image
    source_bgr = cv2.imread(source_path)
    source_rgb = source_bgr[:, :, ::-1]
    source_landmarks, source_bbox = process_image(fa, source_rgb, crop_size)
    if source_bbox is None:
        raise RuntimeError("Couldn't detect a face in source image: " + source_path)
    source_tensor, source_landmarks, source_bbox = img_transforms1(source_rgb, source_landmarks, source_bbox)
    source_cropped_bgr = tensor2bgr(source_tensor[0] if isinstance(source_tensor, list) else source_tensor)
    for i in range(len(source_tensor)):
        source_tensor[i] = source_tensor[i].to(device)

    # Extract landmarks and bounding boxes from target video
    frame_indices, landmarks, bboxes, eulers = extract_landmarks_bboxes_euler_from_video(target_path, Gp,
                                                                                         device=device)
    if frame_indices.size == 0:
        raise RuntimeError('No faces were detected in the target video: ' + target_path)

    # Open target video file
    cap = cv2.VideoCapture(target_path)
    if not cap.isOpened():
        raise RuntimeError('Failed to read video: ' + target_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Initialize output video file
    if output_path is not None:
        if os.path.isdir(output_path):
            output_filename = os.path.splitext(os.path.basename(source_path))[0] + '_' + \
                              os.path.splitext(os.path.basename(target_path))[0] + '.mp4'
            output_path = os.path.join(output_path, output_filename)
        print(output_path)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vid = cv2.VideoWriter(output_path, fourcc, fps,
                                  (source_cropped_bgr.shape[1] * 3, source_cropped_bgr.shape[0]))
    else:
        out_vid = None

    # For each frame in the target video
    valid_frame_ind = 0
    for i in tqdm(range(total_frames)):
        ret, frame = cap.read()
        if frame is None:
            continue
        if i not in frame_indices:
            continue
        frame_rgb = frame[:, :, ::-1]
        frame_tensor, frame_landmarks, frame_bbox = img_transforms2(frame_rgb, landmarks[valid_frame_ind],
                                                                    bboxes[valid_frame_ind])
        valid_frame_ind += 1
        # frame_cropped_rgb, frame_landmarks = process_cached_frame(frame_rgb, landmarks[valid_frame_ind],
        #                                                           bboxes[valid_frame_ind], size)
        # frame_cropped_bgr = frame_cropped_rgb[:, :, ::-1].copy()
        # valid_frame_ind += 1
        #
        # frame_tensor, frame_landmarks_tensor = prepare_generator_input(frame_cropped_rgb, frame_landmarks)
        # frame_landmarks_tensor.to(device)

        input_tensor = []
        for j in range(len(source_tensor)):
            frame_landmarks[j] = frame_landmarks[j].to(device)
            input_tensor.append(torch.cat((source_tensor[j], frame_landmarks[j]), dim=0).unsqueeze(0).to(device))
        out_img_tensor, out_seg_tensor = G(input_tensor)

        # Transfer image1 mask to image2
        # face_mask_tensor = out_seg_tensor.argmax(1) == 1  # face
        # face_mask_tensor = out_seg_tensor.argmax(1) == 2  # hair
        # face_mask_tensor = out_seg_tensor.argmax(1) >= 1  # head
        # target_img_tensor = frame_tensor[0].view(1, frame_tensor[0].shape[0],
        #                                          frame_tensor[0].shape[1], frame_tensor[0].shape[2]).to(device)

        # Convert back to numpy images
        out_img_bgr = tensor2bgr(out_img_tensor)
        frame_cropped_bgr = tensor2bgr(frame_tensor[0])

        # Render
        # for point in np.round(frame_landmarks).astype(int):
        #     cv2.circle(frame_cropped_bgr, (point[0], point[1]), 2, (0, 0, 255), -1)
        render_img = np.concatenate((source_cropped_bgr, out_img_bgr, frame_cropped_bgr), axis=1)
        if out_vid is not None:
            out_vid.write(render_img)
        if out_vid is None or display:
            cv2.imshow('render_img', render_img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

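
# Usage sketch (hedged): the paths below are illustrative; per the logic above, output_path may
# be either a directory (the output file name is derived from the source and target names) or
# an explicit .mp4 path.
# main('docs/examples/source.jpg', 'docs/examples/target.mp4', output_path='output', display=True)
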
def main(source_path, target_path,
         arch='res_unet_split.MultiScaleResUNet(in_nc=71,out_nc=(3,3),flat_layers=(2,0,2,3),ngf=128)',
         model_path='../weights/ijbc_msrunet_256_1_2_reenactment_stepwise_v1.pth',
         pose_model_path='../weights/hopenet_robust_alpha1.pth',
         pil_transforms1=('landmark_transforms.FaceAlignCrop(bbox_scale=1.2)', 'landmark_transforms.Resize(256)',
                          'landmark_transforms.Pyramids(2)'),
         pil_transforms2=('landmark_transforms.FaceAlignCrop(bbox_scale=1.2)', 'landmark_transforms.Resize(256)',
                          'landmark_transforms.Pyramids(2)'),
         tensor_transforms1=('landmark_transforms.ToTensor()',
                             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         tensor_transforms2=('landmark_transforms.ToTensor()',
                             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         output_path=None, crop_size=256, display=False):
    torch.set_grad_enabled(False)

    # Initialize models
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, flip_input=True)
    device, gpus = utils.set_device()
    G = obj_factory(arch).to(device)
    checkpoint = torch.load(model_path)
    G.load_state_dict(checkpoint['state_dict'])
    G.train(False)

    # Initialize pose
    Gp = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    Gp.load_state_dict(checkpoint['state_dict'])
    Gp.train(False)

    # Initialize landmarks to heatmaps
    landmarks2heatmaps = [LandmarkHeatmap(kernel_size=13, size=(256, 256)).to(device),
                          LandmarkHeatmap(kernel_size=7, size=(128, 128)).to(device)]

    # Initialize transformations
    pil_transforms1 = obj_factory(pil_transforms1) if pil_transforms1 is not None else []
    pil_transforms2 = obj_factory(pil_transforms2) if pil_transforms2 is not None else []
    tensor_transforms1 = obj_factory(tensor_transforms1) if tensor_transforms1 is not None else []
    tensor_transforms2 = obj_factory(tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms1 = landmark_transforms.ComposePyramids(pil_transforms1 + tensor_transforms1)
    img_transforms2 = landmark_transforms.ComposePyramids(pil_transforms2 + tensor_transforms2)

    # Process source image
    source_bgr = cv2.imread(source_path)
    source_rgb = source_bgr[:, :, ::-1]
    source_landmarks, source_bbox = process_image(fa, source_rgb, crop_size)
    if source_bbox is None:
        raise RuntimeError("Couldn't detect a face in source image: " + source_path)
    source_tensor, source_landmarks, source_bbox = img_transforms1(source_rgb, source_landmarks, source_bbox)
    source_cropped_bgr = tensor2bgr(source_tensor[0] if isinstance(source_tensor, list) else source_tensor)
    for i in range(len(source_tensor)):
        source_tensor[i] = source_tensor[i].unsqueeze(0).to(device)

    # Extract landmarks, bounding boxes, euler angles, and 3D landmarks from target video
    frame_indices, landmarks, bboxes, eulers, landmarks_3d = \
        extract_landmarks_bboxes_euler_3d_from_video(target_path, Gp, fa, device=device)
    if frame_indices.size == 0:
        raise RuntimeError('No faces were detected in the target video: ' + target_path)

    # Open target video file
    cap = cv2.VideoCapture(target_path)
    if not cap.isOpened():
        raise RuntimeError('Failed to read target video: ' + target_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Initialize output video file
    if output_path is not None:
        if os.path.isdir(output_path):
            output_filename = os.path.splitext(os.path.basename(source_path))[0] + '_' + \
                              os.path.splitext(os.path.basename(target_path))[0] + '.mp4'
            output_path = os.path.join(output_path, output_filename)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vid = cv2.VideoWriter(output_path, fourcc, fps,
                                  (source_cropped_bgr.shape[1] * 3, source_cropped_bgr.shape[0]))
    else:
        out_vid = None

    # For each frame in the target video
    valid_frame_ind = 0
    for i in tqdm(range(total_frames)):
        ret, target_bgr = cap.read()
        if target_bgr is None:
            continue
        if i not in frame_indices:
            continue
        target_rgb = target_bgr[:, :, ::-1]
        target_tensor, target_landmarks, target_bbox = img_transforms2(target_rgb, landmarks_3d[valid_frame_ind],
                                                                       bboxes[valid_frame_ind])
        target_euler = eulers[valid_frame_ind]
        valid_frame_ind += 1

        # TODO: Calculate the number of required reenactment iterations
        reenactment_iterations = 2

        # Generate landmarks sequence
        target_landmarks_sequence = []
        for ri in range(1, reenactment_iterations):
            interp_landmarks = []
            for j in range(len(source_tensor)):
                alpha = float(ri) / reenactment_iterations
                curr_interp_landmarks_np = interpolate_points(source_landmarks[j].cpu().numpy(),
                                                              target_landmarks[j].cpu().numpy(), alpha=alpha)
                interp_landmarks.append(torch.from_numpy(curr_interp_landmarks_np))
            target_landmarks_sequence.append(interp_landmarks)
        target_landmarks_sequence.append(target_landmarks)

        # Iterative reenactment
        out_img_tensor = source_tensor
        for curr_target_landmarks in target_landmarks_sequence:
            out_img_tensor = create_pyramid(out_img_tensor, 2)
            input_tensor = []
            for j in range(len(out_img_tensor)):
                curr_target_landmarks[j] = curr_target_landmarks[j].unsqueeze(0).to(device)
                curr_target_landmarks[j] = landmarks2heatmaps[j](curr_target_landmarks[j])
                input_tensor.append(torch.cat((out_img_tensor[j], curr_target_landmarks[j]), dim=1))
            out_img_tensor, out_seg_tensor = G(input_tensor)

        # Convert back to numpy images
        out_img_bgr = tensor2bgr(out_img_tensor)
        frame_cropped_bgr = tensor2bgr(target_tensor[0])

        # Render
        # for point in np.round(frame_landmarks).astype(int):
        #     cv2.circle(frame_cropped_bgr, (point[0], point[1]), 2, (0, 0, 255), -1)
        render_img = np.concatenate((source_cropped_bgr, out_img_bgr, frame_cropped_bgr), axis=1)
        if out_vid is not None:
            out_vid.write(render_img)
        if out_vid is None or display:
            cv2.imshow('render_img', render_img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

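
# Hedged sketch: interpolate_points is assumed to blend two landmark arrays linearly; with
# alpha stepping from 1/n to 1, this produces the stepwise landmark sequence that drives the
# iterative reenactment above.
import numpy as np

def interpolate_points_sketch(points0, points1, alpha):
    return (1.0 - alpha) * points0 + alpha * points1
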
def main(dataset='opencv_video_seq_dataset.VideoSeqDataset', np_transforms=None,
         tensor_transforms=('img_landmarks_transforms.ToTensor()',
                            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         workers=4, batch_size=4):
    import time
    from fsgan.utils.obj_factory import obj_factory
    from fsgan.utils.img_utils import tensor2bgr

    np_transforms = obj_factory(np_transforms) if np_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_landmarks_transforms.Compose(np_transforms + tensor_transforms)
    dataset = obj_factory(dataset, transform=img_transforms)
    # dataset = VideoSeqDataset(root_path, img_list_path, transform=img_transforms, frame_window=frame_window)
    dataloader = data.DataLoader(dataset, batch_size=batch_size, num_workers=workers, pin_memory=True,
                                 drop_last=True, shuffle=True)

    start = time.time()
    for frame_window, landmarks_window in dataloader:
        # print(frame_window.shape)
        if isinstance(frame_window, (list, tuple)):
            # For each batch
            for b in range(frame_window[0].shape[0]):
                # For each frame window in the list
                for p in range(len(frame_window)):
                    # For each frame in the window
                    for f in range(frame_window[p].shape[2]):
                        print(frame_window[p][b, :, f, :, :].shape)

                        # Render
                        render_img = tensor2bgr(frame_window[p][b, :, f, :, :]).copy()
                        landmarks = landmarks_window[p][b, f, :, :].numpy()
                        # for point in np.round(landmarks).astype(int):
                        for point in landmarks:
                            cv2.circle(render_img, (point[0], point[1]), 2, (0, 0, 255), -1)
                        cv2.imshow('render_img', render_img)
                        if cv2.waitKey(0) & 0xFF == ord('q'):
                            break
        else:
            # For each batch
            for b in range(frame_window.shape[0]):
                # For each frame in the window
                for f in range(frame_window.shape[2]):
                    print(frame_window[b, :, f, :, :].shape)

                    # Render
                    render_img = tensor2bgr(frame_window[b, :, f, :, :]).copy()
                    landmarks = landmarks_window[b, f, :, :].numpy()
                    # for point in np.round(landmarks).astype(int):
                    for point in landmarks:
                        cv2.circle(render_img, (point[0], point[1]), 2, (0, 0, 255), -1)
                    cv2.imshow('render_img', render_img)
                    if cv2.waitKey(0) & 0xFF == ord('q'):
                        break
    end = time.time()
    print('elapsed time: %f[s]' % (end - start))

def main(
        # General arguments
        exp_dir, resume_dir=None, start_epoch=None, epochs=(90,), iterations=None,
        resolutions=(128, 256), lr_gen=(1e-4,), lr_dis=(1e-4,), gpus=None,
        workers=4, batch_size=(64,), seed=None, log_freq=20,
        # Data arguments
        train_dataset='opencv_video_seq_dataset.VideoSeqDataset', val_dataset=None,
        numpy_transforms=None,
        tensor_transforms=('img_lms_pose_transforms.ToTensor()',
                           'img_lms_pose_transforms.Normalize()'),
        # Training arguments
        optimizer='optim.SGD(momentum=0.9,weight_decay=1e-4)',
        scheduler='lr_scheduler.StepLR(step_size=30,gamma=0.1)', pretrained=False,
        criterion_pixelwise='nn.L1Loss', criterion_id='vgg_loss.VGGLoss',
        criterion_attr='vgg_loss.VGGLoss', criterion_gan='gan_loss.GANLoss(use_lsgan=True)',
        generator='res_unet.MultiScaleResUNet(in_nc=101,out_nc=3)',
        discriminator='discriminators_pix2pix.MultiscaleDiscriminator',
        rec_weight=1.0, gan_weight=0.001):
    def process_epoch(dataset_loader, train=True):
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode
        G.train(train)
        D.train(train)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs, optimizer_G.param_groups[0]['lr']))

        # For each batch in the training data
        for i, (img, landmarks, target) in enumerate(pbar):
            # Prepare input
            with torch.no_grad():
                # Push the target view's landmarks to the device
                landmarks[1] = landmarks[1].to(device)

                # For each view: push its image pyramid to the device
                for j in range(len(img)):
                    for p in range(len(img[j])):
                        img[j][p] = img[j][p].to(device)

                # Keep only the pyramid levels needed for the current resolution
                for j in range(len(img)):
                    img[j] = img[j][-ri - 1:]

                # Concatenate pyramid images with the landmarks context to derive the final input
                input = []
                for p in range(len(img[0])):
                    context = res_landmarks_decoders[p](landmarks[1])
                    input.append(torch.cat((img[0][p], context), dim=1))

            # Reenactment
            img_pred = G(input)

            # Fake detection and loss
            img_pred_pyd = img_utils.create_pyramid(img_pred, len(img[0]))
            pred_fake_pool = D([x.detach() for x in img_pred_pyd])
            loss_D_fake = criterion_gan(pred_fake_pool, False)

            # Real detection and loss
            pred_real = D(img[1])
            loss_D_real = criterion_gan(pred_real, True)
            loss_D_total = (loss_D_fake + loss_D_real) * 0.5

            # GAN loss (fake passability loss)
            pred_fake = D(img_pred_pyd)
            loss_G_GAN = criterion_gan(pred_fake, True)

            # Reconstruction loss: weighted sum of pixelwise, identity, and attribute terms
            loss_pixelwise = criterion_pixelwise(img_pred, img[1][0])
            loss_id = criterion_id(img_pred, img[1][0])
            loss_attr = criterion_attr(img_pred, img[1][0])
            loss_rec = 0.1 * loss_pixelwise + 0.5 * loss_id + 0.5 * loss_attr

            loss_G_total = rec_weight * loss_rec + gan_weight * loss_G_GAN

            if train:
                # Update generator weights
                optimizer_G.zero_grad()
                loss_G_total.backward()
                optimizer_G.step()

                # Update discriminator weights
                optimizer_D.zero_grad()
                loss_D_total.backward()
                optimizer_D.step()

            logger.update('losses', pixelwise=loss_pixelwise, id=loss_id, attr=loss_attr,
                          rec=loss_rec, g_gan=loss_G_GAN, d_gan=loss_D_total)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg('%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images
            grid = img_utils.make_grid(img[0][0], img_pred, img[1][0])
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['rec'].avg

    #################
    # Main pipeline #
    #################

    # Validate and broadcast the per-resolution arguments
    resolutions = resolutions if isinstance(resolutions, (list, tuple)) else [resolutions]
    lr_gen = lr_gen if isinstance(lr_gen, (list, tuple)) else [lr_gen]
    lr_dis = lr_dis if isinstance(lr_dis, (list, tuple)) else [lr_dis]
    epochs = epochs if isinstance(epochs, (list, tuple)) else [epochs]
    batch_size = batch_size if isinstance(batch_size, (list, tuple)) else [batch_size]
    iterations = iterations if iterations is None or isinstance(iterations, (list, tuple)) else [iterations]
    lr_gen = lr_gen * len(resolutions) if len(lr_gen) == 1 else lr_gen
    lr_dis = lr_dis * len(resolutions) if len(lr_dis) == 1 else lr_dis
    epochs = epochs * len(resolutions) if len(epochs) == 1 else epochs
    batch_size = batch_size * len(resolutions) if len(batch_size) == 1 else batch_size
    if iterations is not None:
        iterations = iterations * len(resolutions) if len(iterations) == 1 else iterations
        iterations = utils.str2int(iterations)

    if not os.path.isdir(exp_dir):
        raise RuntimeError('Experiment directory was not found: \'' + exp_dir + '\'')
    assert len(lr_gen) == len(resolutions)
    assert len(lr_dis) == len(resolutions)
    assert len(epochs) == len(resolutions)
    assert len(batch_size) == len(resolutions)
    assert iterations is None or len(iterations) == len(resolutions)

    # Seed
    utils.set_seed(seed)

    # Check CUDA device availability
    device, gpus = utils.set_device(gpus)

    # Initialize loggers
    logger = TensorBoardLogger(log_dir=exp_dir)

    # Initialize datasets
    numpy_transforms = obj_factory(numpy_transforms) if numpy_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_lms_pose_transforms.Compose(numpy_transforms + tensor_transforms)
    train_dataset = obj_factory(train_dataset, transform=img_transforms)
    if val_dataset is not None:
        val_dataset = obj_factory(val_dataset, transform=img_transforms)

    # Create networks
    G_arch = utils.get_arch(generator)
    D_arch = utils.get_arch(discriminator)
    G = obj_factory(generator).to(device)
    D = obj_factory(discriminator).to(device)

    # Resume from a checkpoint or initialize the networks' weights randomly
    checkpoint_dir = exp_dir if resume_dir is None else resume_dir
    G_path = os.path.join(checkpoint_dir, 'G_latest.pth')
    D_path = os.path.join(checkpoint_dir, 'D_latest.pth')
    best_loss = 1e6
    curr_res = resolutions[0]
    optimizer_G_state, optimizer_D_state = None, None
    if os.path.isfile(G_path) and os.path.isfile(D_path):
        print("=> loading checkpoint from '{}'".format(checkpoint_dir))
        # G
        checkpoint = torch.load(G_path)
        if 'resolution' in checkpoint:
            curr_res = checkpoint['resolution']
            start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
        best_loss_key = 'best_loss_%d' % curr_res
        best_loss = checkpoint[best_loss_key] if best_loss_key in checkpoint else best_loss
        G.apply(utils.init_weights)
        G.load_state_dict(checkpoint['state_dict'], strict=False)
        optimizer_G_state = checkpoint['optimizer']
        # D
        D.apply(utils.init_weights)
        if os.path.isfile(D_path):
            checkpoint = torch.load(D_path)
            D.load_state_dict(checkpoint['state_dict'], strict=False)
            optimizer_D_state = checkpoint['optimizer']
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_dir))
        if not pretrained:
            print("=> randomly initializing networks...")
            G.apply(utils.init_weights)
            D.apply(utils.init_weights)

    # Initialize landmarks decoders (highest resolution first)
    landmarks_decoders = []
    for res in resolutions:
        landmarks_decoders.insert(0, landmarks_utils.LandmarksHeatMapDecoder(res).to(device))

    # Losses
    criterion_pixelwise = obj_factory(criterion_pixelwise).to(device)
    criterion_id = obj_factory(criterion_id).to(device)
    criterion_attr = obj_factory(criterion_attr).to(device)
    criterion_gan = obj_factory(criterion_gan).to(device)

    # Support multiple GPUs
    if gpus and len(gpus) > 1:
        G = nn.DataParallel(G, gpus)
        D = nn.DataParallel(D, gpus)
        criterion_id.vgg = nn.DataParallel(criterion_id.vgg, gpus)
        criterion_attr.vgg = nn.DataParallel(criterion_attr.vgg, gpus)
        landmarks_decoders = [nn.DataParallel(ld, gpus) for ld in landmarks_decoders]

    # For each resolution
    start_res_ind = int(np.log2(curr_res)) - int(np.log2(resolutions[0]))
    start_epoch = 0 if start_epoch is None else start_epoch
    for ri in range(start_res_ind, len(resolutions)):
        res = resolutions[ri]
        res_lr_gen = lr_gen[ri]
        res_lr_dis = lr_dis[ri]
        res_epochs = epochs[ri]
        res_iterations = iterations[ri] if iterations is not None else None
        res_batch_size = batch_size[ri]
        res_landmarks_decoders = landmarks_decoders[-ri - 1:]

        # Optimizer and scheduler
        optimizer_G = obj_factory(optimizer, G.parameters(), lr=res_lr_gen)
        optimizer_D = obj_factory(optimizer, D.parameters(), lr=res_lr_dis)
        scheduler_G = obj_factory(scheduler, optimizer_G)
        scheduler_D = obj_factory(scheduler, optimizer_D)
        if optimizer_G_state is not None:
            optimizer_G.load_state_dict(optimizer_G_state)
            optimizer_G_state = None
        if optimizer_D_state is not None:
            optimizer_D.load_state_dict(optimizer_D_state)
            optimizer_D_state = None

        # Initialize data loaders
        if res_iterations is None:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, len(train_dataset))
        else:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, res_iterations)
        train_loader = tutils.data.DataLoader(train_dataset, batch_size=res_batch_size, sampler=train_sampler,
                                              num_workers=workers, pin_memory=True, drop_last=True, shuffle=False)
        if val_dataset is not None:
            if res_iterations is None:
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, len(val_dataset))
            else:
                val_iterations = (res_iterations * len(val_dataset.classes)) // len(train_dataset.classes)
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, val_iterations)
            val_loader = tutils.data.DataLoader(val_dataset, batch_size=res_batch_size, sampler=val_sampler,
                                                num_workers=workers, pin_memory=True, drop_last=True, shuffle=False)
        else:
            val_loader = None

        # For each epoch
        for epoch in range(start_epoch, res_epochs):
            total_loss = process_epoch(train_loader, train=True)
            if val_loader is not None:
                with torch.no_grad():
                    total_loss = process_epoch(val_loader, train=False)

            # Schedulers step (in PyTorch 1.1.0+ it must follow the epoch's training and validation steps)
            if isinstance(scheduler_G, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler_G.step(total_loss)
                scheduler_D.step(total_loss)
            else:
                scheduler_G.step()
                scheduler_D.step()

            # Save models checkpoints
            is_best = total_loss < best_loss
            best_loss = min(best_loss, total_loss)
            utils.save_checkpoint(exp_dir, 'G', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': G.module.state_dict() if gpus and len(gpus) > 1 else G.state_dict(),
                'optimizer': optimizer_G.state_dict(),
                'best_loss_%d' % res: best_loss,
                'arch': G_arch,
            }, is_best)
            utils.save_checkpoint(exp_dir, 'D', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': D.module.state_dict() if gpus and len(gpus) > 1 else D.state_dict(),
                'optimizer': optimizer_D.state_dict(),
                'best_loss_%d' % res: best_loss,
                'arch': D_arch,
            }, is_best)

        # Reset the start epoch and best loss because they should only affect the first training resolution
        start_epoch = 0
        best_loss = 1e6
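# The training loop above feeds the discriminator multi-scale inputs built by
# img_utils.create_pyramid. As a rough illustration of that step only, such a
# helper can be sketched as below; the assumed behavior (successive 2x average
# pooling, pass-through for inputs that are already pyramids) is a guess, and
# the actual FSGAN implementation may use a different downsampling scheme.
import torch
import torch.nn.functional as F

def create_pyramid_sketch(img, levels):
    """Return a list of `levels` tensors, each half the resolution of the previous."""
    if isinstance(img, (list, tuple)):
        return img  # assume the input is already a pyramid
    pyramid = [img]
    for _ in range(levels - 1):
        pyramid.append(F.avg_pool2d(pyramid[-1], 2))
    return pyramid

# Example: a 256x256 batch becomes [256x256, 128x128] for a 2-level pyramid
# pyd = create_pyramid_sketch(torch.randn(4, 3, 256, 256), 2)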
def main(dataset='fsgan.datasets.seq_dataset.SeqDataset', np_transforms=None,
         tensor_transforms=('img_lms_pose_transforms.ToTensor()',
                            'img_lms_pose_transforms.Normalize()'),
         workers=4, batch_size=4):
    import time
    import fsgan
    from fsgan.utils.obj_factory import obj_factory
    from fsgan.utils.img_utils import tensor2bgr

    np_transforms = obj_factory(np_transforms) if np_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_lms_pose_transforms.Compose(np_transforms + tensor_transforms)
    dataset = obj_factory(dataset, transform=img_transforms)
    dataloader = data.DataLoader(dataset, batch_size=batch_size, num_workers=workers,
                                 pin_memory=True, drop_last=True, shuffle=True)

    start = time.time()
    if isinstance(dataset, fsgan.datasets.seq_dataset.SeqPairDataset):
        for frame, landmarks, pose, target in dataloader:
            pass
    elif isinstance(dataset, fsgan.datasets.seq_dataset.SeqDataset):
        for frame, landmarks, pose in dataloader:
            # For each batch
            for b in range(frame.shape[0]):
                # Render the frame with its landmarks and pose
                render_img = tensor2bgr(frame[b]).copy()
                # Landmarks are normalized by image size, pose by 99 degrees
                curr_landmarks = landmarks[b].numpy() * render_img.shape[0]
                curr_pose = pose[b].numpy() * 99.
                # cv2.circle requires integer pixel coordinates
                for point in np.round(curr_landmarks).astype(int):
                    cv2.circle(render_img, (point[0], point[1]), 2, (0, 0, 255), -1)
                msg = 'Pose: %.1f, %.1f, %.1f' % (curr_pose[0], curr_pose[1], curr_pose[2])
                cv2.putText(render_img, msg, (10, 20), cv2.FONT_HERSHEY_SIMPLEX,
                            0.5, (0, 0, 255), 1, cv2.LINE_AA)
                cv2.imshow('render_img', render_img)
                if cv2.waitKey(0) & 0xFF == ord('q'):
                    break
    end = time.time()
    print('elapsed time: %f[s]' % (end - start))
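# Hypothetical entry point for the test harness above. The dataset spec string
# is parsed by obj_factory; the positional root-directory argument shown here
# is an assumption for illustration only and must match the real SeqDataset
# constructor signature.
if __name__ == '__main__':
    main(dataset="fsgan.datasets.seq_dataset.SeqDataset('/path/to/video_dataset')",
         workers=4, batch_size=4)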