def forward(self, pyd):
    pyd = create_pyramid(pyd, self.n_local_enhancers)

    # Call the global generator at the coarsest level
    if len(pyd) == 1:
        return self.base(pyd[-1])
    x = pyd[-1]
    x = self.base.in_conv(x)
    x = self.base.inner(x)

    # Apply an enhancer for each level
    for n in range(1, len(pyd)):
        enhancer = getattr(self, 'enhancer%d' % n)
        x = enhancer.extract_features(pyd[self.n_local_enhancers - n], x)
        if n == self.n_local_enhancers:
            x = enhancer.out_conv(x)

    return x
def forward(self, pyd):
    pyd = create_pyramid(pyd, self.n_local_enhancers)

    # Call global at the coarsest level
    if len(pyd) == 1:
        return self.base(pyd[-1])
    x = pyd[-1]
    x = self.base.in_conv(x)
    x = self.base.inner(x)

    # Apply enhancer for each level
    for n in range(1, len(pyd)):
        enhancer = getattr(self, 'enhancer%d' % n)
        x = enhancer.extract_features(pyd[self.n_local_enhancers - n], x)
        if n == self.n_local_enhancers:
            x = enhancer.out_conv(x)

    return x
def forward(self, pyd):
    pyd = create_pyramid(pyd, self.n_local_enhancers)

    # Call the global generator at the coarsest level
    if len(pyd) == 1:
        return self.base(pyd[-1])
    x = pyd[-1]
    x = self.base.in_conv(x)
    x = self.base.inner(x)

    # Apply an enhancer for each level; the last enhancer has one output head per out_nc entry
    for n in range(1, len(pyd)):
        enhancer = getattr(self, 'enhancer%d' % n)
        x = enhancer.extract_features(pyd[self.n_local_enhancers - n], x)
        if n == self.n_local_enhancers:
            output = []
            for i in range(len(self.out_nc)):
                out_conv = getattr(enhancer, 'out_conv%d' % (i + 1))
                output.append(out_conv(x))
            return tuple(output)
def forward(self, pyd):
    pyd[1] = torch.nn.functional.interpolate(pyd[0], scale_factor=0.5, mode='area')
    pyd = create_pyramid(pyd, self.n_local_enhancers)

    # Call global at the coarsest level
    if len(pyd) == 1:
        return self.base(pyd[-1])
    x = pyd[-1]
    x = self.base.in_conv(x)
    x = self.base.inner(x)

    # Apply enhancer for each level
    for n in range(1, len(pyd)):
        enhancer = getattr(self, 'enhancer%d' % n)
        x = enhancer.extract_features(pyd[self.n_local_enhancers - n], x)
        if n == self.n_local_enhancers:
            x = enhancer.out_conv(x)

    return x
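# All of the forward passes above rely on a create_pyramid helper whose body is
# not shown here. The sketch below is a minimal assumption inferred from how it
# is called (it must accept either a single tensor or a partially built list of
# pyramid levels, finest first); it is illustrative, not the repository's
# actual implementation, and is named create_pyramid_sketch to make that clear.
import torch
import torch.nn.functional as F

def create_pyramid_sketch(img, n_levels):
    # Accept either a single tensor or a partially built pyramid (finest first)
    pyd = list(img) if isinstance(img, (list, tuple)) else [img]
    # Downscale the coarsest existing level until n_levels are available;
    # 'area' averaging gives antialiased downsampling, matching the usage above
    while len(pyd) < n_levels:
        pyd.append(F.interpolate(pyd[-1], scale_factor=0.5, mode='area'))
    return pyd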
def __call__(self, source_path, target_path, output_path=None, select_source='longest',
             select_target='longest', finetune=None):
    is_vid = os.path.splitext(source_path)[1] == '.mp4'
    finetune = self.finetune_enabled and is_vid if finetune is None else finetune and is_vid

    # Validation
    assert os.path.isfile(source_path), 'Source path "%s" does not exist' % source_path
    assert os.path.isfile(target_path), 'Target path "%s" does not exist' % target_path

    # Cache input
    source_cache_dir, source_seq_file_path, _ = self.cache(source_path)
    target_cache_dir, target_seq_file_path, _ = self.cache(target_path)

    # Load sequences from file
    with open(source_seq_file_path, "rb") as fp:  # Unpickling
        source_seq_list = pickle.load(fp)
    with open(target_seq_file_path, "rb") as fp:  # Unpickling
        target_seq_list = pickle.load(fp)

    # Select source and target sequence
    source_seq = select_seq(source_seq_list, select_source)
    target_seq = select_seq(target_seq_list, select_target)

    # Set source and target sequence video paths
    src_path_no_ext, src_ext = os.path.splitext(source_path)
    src_vid_seq_name = os.path.basename(src_path_no_ext) + '_seq%02d%s' % (source_seq.id, src_ext)
    src_vid_seq_path = os.path.join(source_cache_dir, src_vid_seq_name)
    tgt_path_no_ext, tgt_ext = os.path.splitext(target_path)
    tgt_vid_seq_name = os.path.basename(tgt_path_no_ext) + '_seq%02d%s' % (target_seq.id, tgt_ext)
    tgt_vid_seq_path = os.path.join(target_cache_dir, tgt_vid_seq_name)

    # Set output path
    if output_path is not None:
        if os.path.isdir(output_path):
            output_filename = f'{os.path.basename(src_path_no_ext)}_{os.path.basename(tgt_path_no_ext)}.mp4'
            output_path = os.path.join(output_path, output_filename)

    # Initialize appearance map
    src_transform = img_lms_pose_transforms.Compose([Rotate(), Pyramids(2), ToTensor(), Normalize()])
    tgt_transform = img_lms_pose_transforms.Compose([ToTensor(), Normalize()])
    appearance_map = AppearanceMapDataset(src_vid_seq_path, tgt_vid_seq_path, src_transform, tgt_transform,
                                          self.landmarks_postfix, self.pose_postfix,
                                          self.segmentation_postfix, self.min_radius)
    appearance_map_loader = DataLoader(appearance_map, batch_size=self.batch_size, num_workers=1,
                                       pin_memory=True, drop_last=False, shuffle=False)

    # Initialize video writer
    self.video_renderer.init(target_path, target_seq, output_path, _appearance_map=appearance_map)

    # Finetune reenactment model on source sequences
    if finetune:
        self.finetune(src_vid_seq_path, self.finetune_save)

    print(f'=> Face swapping: "{src_vid_seq_name}" -> "{tgt_vid_seq_name}"...')

    # For each batch of frames in the target video
    for i, (src_frame, src_landmarks, src_poses, bw, tgt_frame, tgt_landmarks, tgt_pose, tgt_mask) \
            in enumerate(tqdm(appearance_map_loader, unit='batches', file=sys.stdout)):
        # Prepare input
        for p in range(len(src_frame)):
            src_frame[p] = src_frame[p].to(self.device)
        tgt_frame = tgt_frame.to(self.device)
        tgt_landmarks = tgt_landmarks.to(self.device)
        tgt_mask = tgt_mask.unsqueeze(1).int().to(self.device).bool()  # TODO: check if the boolean tensor bug is fixed
        bw = bw.to(self.device)
        bw_indices = torch.nonzero(torch.any(bw > 0, dim=0), as_tuple=True)[0]
        bw = bw[:, bw_indices]

        # For each source frame perform reenactment
        reenactment_triplet = []
        for j in bw_indices:
            input = []
            for p in range(len(src_frame)):
                context = self.landmarks_decoders[p](tgt_landmarks)
                input.append(torch.cat((src_frame[p][:, j], context), dim=1))

            # Reenactment
            reenactment_triplet.append(self.Gr(input).unsqueeze(1))
        reenactment_tensor = torch.cat(reenactment_triplet, dim=1)

        # Barycentric interpolation of reenacted frames
        reenactment_tensor = (reenactment_tensor * bw.view(*bw.shape, 1, 1, 1)).sum(dim=1)

        # Compute reenactment segmentation
        reenactment_seg = self.S(reenactment_tensor)
        reenactment_background_mask_tensor = (reenactment_seg.argmax(1) != 1).unsqueeze(1)

        # Remove the background of the aligned face
        reenactment_tensor.masked_fill_(reenactment_background_mask_tensor, -1.0)

        # Soften target mask
        soft_tgt_mask, eroded_tgt_mask = self.smooth_mask(tgt_mask)

        # Complete face
        inpainting_input_tensor = torch.cat((reenactment_tensor, eroded_tgt_mask.float()), dim=1)
        inpainting_input_tensor_pyd = create_pyramid(inpainting_input_tensor, 2)
        completion_tensor = self.Gc(inpainting_input_tensor_pyd)

        # Blend faces
        transfer_tensor = transfer_mask(completion_tensor, tgt_frame, eroded_tgt_mask)
        blend_input_tensor = torch.cat((transfer_tensor, tgt_frame, eroded_tgt_mask.float()), dim=1)
        blend_input_tensor_pyd = create_pyramid(blend_input_tensor, 2)
        blend_tensor = self.Gb(blend_input_tensor_pyd)
        result_tensor = blend_tensor * soft_tgt_mask + tgt_frame * (1 - soft_tgt_mask)

        # Write output
        if self.verbose == 0:
            self.video_renderer.write(result_tensor)
        elif self.verbose == 1:
            curr_src_frames = [src_frame[0][:, i] for i in range(src_frame[0].shape[1])]
            self.video_renderer.write(*curr_src_frames, result_tensor, tgt_frame)
        else:
            curr_src_frames = [src_frame[0][:, i] for i in range(src_frame[0].shape[1])]
            tgt_seg_blend = blend_seg_label(tgt_frame, tgt_mask.squeeze(1), alpha=0.2)
            soft_tgt_mask = soft_tgt_mask.mul(2.).sub(1.).repeat(1, 3, 1, 1)
            self.video_renderer.write(*curr_src_frames, result_tensor, tgt_frame, reenactment_tensor,
                                      completion_tensor, transfer_tensor, soft_tgt_mask, tgt_seg_blend,
                                      tgt_pose)

    # Load original reenactment weights
    if finetune:
        if self.gpus and len(self.gpus) > 1:
            self.Gr.module.load_state_dict(self.reenactment_state_dict)
        else:
            self.Gr.load_state_dict(self.reenactment_state_dict)

    # Finalize video and wait for the video writer to finish writing
    self.video_renderer.finalize()
    self.video_renderer.wait_until_finished()
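# The blending step above depends on two helpers whose bodies are not shown:
# transfer_mask and self.smooth_mask. The sketches below are assumptions
# inferred purely from how they are called (signatures, kernel size, and the
# erosion rule are illustrative), not the repository's actual implementations.
import torch
import torch.nn.functional as F

def transfer_mask_sketch(img1, img2, mask):
    # Take img1 inside the mask and img2 outside it; (B, 1, H, W) mask
    # broadcasts over the (B, C, H, W) images
    mask = mask.float()
    return img1 * mask + img2 * (1.0 - mask)

def smooth_mask_sketch(mask, kernel_radius=8):
    # Blur the binary mask with a box filter to get a soft alpha matte, and
    # derive an eroded mask from pixels whose whole neighborhood was set
    kernel_size = 2 * kernel_radius + 1
    weight = torch.ones(1, 1, kernel_size, kernel_size, device=mask.device)
    weight /= weight.numel()
    soft = F.conv2d(mask.float(), weight, padding=kernel_radius)
    eroded = soft >= 1.0 - 1e-6
    return soft, eroded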
def proces_epoch(dataset_loader, train=True):
    stage = 'TRAINING' if train else 'VALIDATION'
    total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
    pbar = tqdm(dataset_loader, unit='batches')

    # Set networks training mode
    Gc.train(train)
    D.train(train)
    Gr.train(False)
    S.train(False)
    L.train(False)

    # Reset logger
    logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
        stage, res, res, epoch + 1, res_epochs, scheduler_G.get_lr()[0]))

    # For each batch in the training data
    for i, (img, target) in enumerate(pbar):
        # Prepare input
        with torch.no_grad():
            # For each view images
            for j in range(len(img)):
                # For each pyramid image: push to device
                for p in range(len(img[j])):
                    img[j][p] = img[j][p].to(device)

            # Compute context
            context = L(img[1][0].sub(context_mean).div(context_std))
            context = landmarks_utils.filter_landmarks(context)

            # Normalize each of the pyramid images
            for j in range(len(img)):
                for p in range(len(img[j])):
                    img[j][p].sub_(img_mean).div_(img_std)

            # Compute segmentation
            target_seg = S(img[1][0])
            if target_seg.shape[2:] != (res, res):
                target_seg = F.interpolate(target_seg, (res, res), mode='bicubic', align_corners=False)

            # Concatenate pyramid images with context to derive the final input
            input = []
            for p in range(len(img[0]) - 1, -1, -1):
                context = F.interpolate(context, size=img[0][p].shape[2:], mode='bicubic',
                                        align_corners=False)
                input.insert(0, torch.cat((img[0][p], context), dim=1))

            # Reenactment
            reenactment_img = Gr(input)
            reenactment_seg = S(reenactment_img)
            if reenactment_img.shape[2:] != (res, res):
                reenactment_img = F.interpolate(reenactment_img, (res, res), mode='bilinear',
                                                align_corners=False)
                reenactment_seg = F.interpolate(reenactment_seg, (res, res), mode='bilinear',
                                                align_corners=False)

            # Remove unnecessary pyramids
            for j in range(len(img)):
                img[j] = img[j][-ri - 1:]

            # Source face
            reenactment_face_mask = reenactment_seg.argmax(1) == 1
            inpainting_mask = seg_utils.random_hair_inpainting_mask_tensor(reenactment_face_mask).to(device)
            reenactment_face_mask = reenactment_face_mask * (inpainting_mask == 0)
            reenactment_img_with_hole = reenactment_img.masked_fill(~reenactment_face_mask.unsqueeze(1),
                                                                    background_value)

            # Target face
            target_face_mask = (target_seg.argmax(1) == 1).unsqueeze(1)
            inpainting_target = img[1][0]
            inpainting_target.masked_fill_(~target_face_mask, background_value)

            # Inpainting input
            inpainting_input = torch.cat((reenactment_img_with_hole, target_face_mask.float()), dim=1)
            inpainting_input_pyd = img_utils.create_pyramid(inpainting_input, len(img[0]))

        # Face inpainting
        inpainting_pred = Gc(inpainting_input_pyd)

        # Fake Detection and Loss
        inpainting_pred_pyd = img_utils.create_pyramid(inpainting_pred, len(img[0]))
        pred_fake_pool = D([x.detach() for x in inpainting_pred_pyd])
        loss_D_fake = criterion_gan(pred_fake_pool, False)

        # Real Detection and Loss
        inpainting_target_pyd = img_utils.create_pyramid(inpainting_target, len(img[0]))
        pred_real = D(inpainting_target_pyd)
        loss_D_real = criterion_gan(pred_real, True)
        loss_D_total = (loss_D_fake + loss_D_real) * 0.5

        # GAN loss (Fake Passability Loss)
        pred_fake = D(inpainting_pred_pyd)
        loss_G_GAN = criterion_gan(pred_fake, True)

        # Reconstruction loss
        loss_pixelwise = criterion_pixelwise(inpainting_pred, inpainting_target)
        loss_id = criterion_id(inpainting_pred, inpainting_target)
        loss_attr = criterion_attr(inpainting_pred, inpainting_target)
        loss_rec = pix_weight * loss_pixelwise + 0.5 * loss_id + 0.5 * loss_attr

        loss_G_total = rec_weight * loss_rec + gan_weight * loss_G_GAN

        if train:
            # Update generator weights
            optimizer_G.zero_grad()
            loss_G_total.backward()
            optimizer_G.step()

            # Update discriminator weights
            optimizer_D.zero_grad()
            loss_D_total.backward()
            optimizer_D.step()

        logger.update('losses', pixelwise=loss_pixelwise, id=loss_id, attr=loss_attr, rec=loss_rec,
                      g_gan=loss_G_GAN, d_gan=loss_D_total)
        total_iter += dataset_loader.batch_size

        # Batch logs
        pbar.set_description(str(logger))
        if train and i % log_freq == 0:
            logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

    # Epoch logs
    logger.log_scalars_avg('%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
    if not train:
        # Log images
        grid = img_utils.make_grid(img[0][0], reenactment_img, reenactment_img_with_hole,
                                   inpainting_pred, inpainting_target)
        logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

    return logger.log_dict['losses']['rec'].avg
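# criterion_gan above is applied to multi-scale discriminator outputs with a
# boolean "is real" target, in the style of pix2pixHD. A minimal least-squares
# GAN loss matching that calling convention is sketched below; the repository's
# actual criterion may differ in objective or weighting.
import torch
import torch.nn as nn

class GANLossSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.MSELoss()  # LSGAN objective

    def forward(self, pred, target_is_real):
        # The discriminator may return a single tensor, a list of tensors
        # (one per scale), or a list of lists of intermediate features
        if not isinstance(pred, (list, tuple)):
            pred = [pred]
        total = 0
        for p in pred:
            # When features are returned, the last entry is the prediction map
            p = p[-1] if isinstance(p, (list, tuple)) else p
            target = torch.full_like(p, 1.0 if target_is_real else 0.0)
            total = total + self.loss(p, target)
        return total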
def proces_epoch(dataset_loader, train=True):
    stage = 'TRAINING' if train else 'VALIDATION'
    total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
    pbar = tqdm(dataset_loader, unit='batches')

    # Set networks training mode
    G.train(train)
    D.train(train)
    S.train(False)
    L.train(False)

    # Reset logger
    logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
        stage, res, res, epoch + 1, res_epochs, scheduler_G.get_lr()[0]))

    # For each batch in the training data
    for i, (img, target) in enumerate(pbar):
        # Prepare input
        with torch.no_grad():
            # For each view images
            for j in range(len(img)):
                # For each pyramid image: push to device
                for p in range(len(img[j])):
                    img[j][p] = img[j][p].to(device)

            # Compute context
            context = L(img[1][0].sub(context_mean).div(context_std))
            context = landmarks_utils.filter_context(context)

            # Normalize each of the pyramid images
            for j in range(len(img)):
                for p in range(len(img[j])):
                    img[j][p].sub_(img_mean).div_(img_std)

            # Compute segmentation
            seg = S(img[1][0])
            if seg.shape[2:] != (res, res):
                seg = F.interpolate(seg, (res, res), mode='bicubic', align_corners=False)

            # Remove unnecessary pyramids
            for j in range(len(img)):
                img[j] = img[j][-ri - 1:]

            # Concatenate pyramid images with context to derive the final input
            input = []
            for p in range(len(img[0]) - 1, -1, -1):
                context = F.interpolate(context, size=img[0][p].shape[2:], mode='bicubic',
                                        align_corners=False)
                input.append(torch.cat((img[0][p], context), dim=1))
            input = input[::-1]

        # Reenactment
        img_pred, seg_pred = G(input)

        # Fake Detection and Loss
        img_pred_pyd = img_utils.create_pyramid(img_pred, len(img[0]))
        pred_fake_pool = D([x.detach() for x in img_pred_pyd])
        loss_D_fake = criterion_gan(pred_fake_pool, False)

        # Real Detection and Loss
        pred_real = D(img[1])
        loss_D_real = criterion_gan(pred_real, True)
        loss_D_total = (loss_D_fake + loss_D_real) * 0.5

        # GAN loss (Fake Passability Loss)
        pred_fake = D(img_pred_pyd)
        loss_G_GAN = criterion_gan(pred_fake, True)

        # Reconstruction and segmentation loss
        loss_pixelwise = criterion_pixelwise(img_pred, img[1][0])
        loss_id = criterion_id(img_pred, img[1][0])
        loss_attr = criterion_attr(img_pred, img[1][0])
        loss_rec = 0.1 * loss_pixelwise + 0.5 * loss_id + 0.5 * loss_attr
        loss_seg = criterion_pixelwise(seg_pred, seg)

        loss_G_total = rec_weight * loss_rec + seg_weight * loss_seg + gan_weight * loss_G_GAN

        if train:
            # Update generator weights
            optimizer_G.zero_grad()
            loss_G_total.backward()
            optimizer_G.step()

            # Update discriminator weights
            optimizer_D.zero_grad()
            loss_D_total.backward()
            optimizer_D.step()

        logger.update('losses', pixelwise=loss_pixelwise, id=loss_id, attr=loss_attr, rec=loss_rec,
                      seg=loss_seg, g_gan=loss_G_GAN, d_gan=loss_D_total)
        total_iter += dataset_loader.batch_size

        # Batch logs
        pbar.set_description(str(logger))
        if train and i % log_freq == 0:
            logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

    # Epoch logs
    logger.log_scalars_avg('%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
    if not train:
        # Log images
        blend_seg_pred = seg_utils.blend_seg_pred(img[1][0], seg_pred)
        blend_seg = seg_utils.blend_seg_pred(img[1][0], seg)
        grid = img_utils.make_grid(img[0][0], img_pred, img[1][0], blend_seg_pred, blend_seg)
        logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

    return logger.log_dict['losses']['rec'].avg
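# proces_epoch reads epoch, res, ri, and the optimizers and schedulers as
# globals set by the surrounding training script, which is not shown here. A
# driver loop consistent with that contract might look like the sketch below;
# train_loader, val_loader, and save_checkpoint are hypothetical names, and the
# exact checkpointing policy is an assumption.
best_loss = float('inf')
for epoch in range(res_epochs):
    # One pass over the training set with gradient updates
    proces_epoch(train_loader, train=True)

    # One gradient-free pass over the validation set
    with torch.no_grad():
        val_loss = proces_epoch(val_loader, train=False)

    # Step the learning-rate schedulers once per epoch
    scheduler_G.step()
    scheduler_D.step()

    # Keep the checkpoint with the lowest validation reconstruction loss
    if val_loss < best_loss:
        best_loss = val_loss
        save_checkpoint()  # hypothetical helper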
def main(source_path, target_path,
         arch='res_unet_split.MultiScaleResUNet(in_nc=71,out_nc=(3,3),flat_layers=(2,0,2,3),ngf=128)',
         model_path='../weights/ijbc_msrunet_256_1_2_reenactment_stepwise_v1.pth',
         pose_model_path='../weights/hopenet_robust_alpha1.pth',
         pil_transforms1=('landmark_transforms.FaceAlignCrop(bbox_scale=1.2)',
                          'landmark_transforms.Resize(256)', 'landmark_transforms.Pyramids(2)'),
         pil_transforms2=('landmark_transforms.FaceAlignCrop(bbox_scale=1.2)',
                          'landmark_transforms.Resize(256)', 'landmark_transforms.Pyramids(2)'),
         tensor_transforms1=('landmark_transforms.ToTensor()',
                             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         tensor_transforms2=('landmark_transforms.ToTensor()',
                             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         output_path=None, crop_size=256, display=False):
    torch.set_grad_enabled(False)

    # Initialize models
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, flip_input=True)
    device, gpus = utils.set_device()
    G = obj_factory(arch).to(device)
    checkpoint = torch.load(model_path)
    G.load_state_dict(checkpoint['state_dict'])
    G.train(False)

    # Initialize pose
    Gp = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    Gp.load_state_dict(checkpoint['state_dict'])
    Gp.train(False)

    # Initialize landmarks to heatmaps
    landmarks2heatmaps = [LandmarkHeatmap(kernel_size=13, size=(256, 256)).to(device),
                          LandmarkHeatmap(kernel_size=7, size=(128, 128)).to(device)]

    # Initialize transformations
    pil_transforms1 = obj_factory(pil_transforms1) if pil_transforms1 is not None else []
    pil_transforms2 = obj_factory(pil_transforms2) if pil_transforms2 is not None else []
    tensor_transforms1 = obj_factory(tensor_transforms1) if tensor_transforms1 is not None else []
    tensor_transforms2 = obj_factory(tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms1 = landmark_transforms.ComposePyramids(pil_transforms1 + tensor_transforms1)
    img_transforms2 = landmark_transforms.ComposePyramids(pil_transforms2 + tensor_transforms2)

    # Process source image
    source_bgr = cv2.imread(source_path)
    source_rgb = source_bgr[:, :, ::-1]
    source_landmarks, source_bbox = process_image(fa, source_rgb, crop_size)
    if source_bbox is None:
        raise RuntimeError("Couldn't detect a face in source image: " + source_path)
    source_tensor, source_landmarks, source_bbox = img_transforms1(source_rgb, source_landmarks,
                                                                   source_bbox)
    source_cropped_bgr = tensor2bgr(source_tensor[0] if isinstance(source_tensor, list) else source_tensor)
    for i in range(len(source_tensor)):
        source_tensor[i] = source_tensor[i].unsqueeze(0).to(device)

    # Extract landmarks, bounding boxes, euler angles, and 3D landmarks from target video
    frame_indices, landmarks, bboxes, eulers, landmarks_3d = \
        extract_landmarks_bboxes_euler_3d_from_video(target_path, Gp, fa, device=device)
    if frame_indices.size == 0:
        raise RuntimeError('No faces were detected in the target video: ' + target_path)

    # Open target video file
    cap = cv2.VideoCapture(target_path)
    if not cap.isOpened():
        raise RuntimeError('Failed to read target video: ' + target_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Initialize output video file
    if output_path is not None:
        if os.path.isdir(output_path):
            output_filename = os.path.splitext(os.path.basename(source_path))[0] + '_' + \
                              os.path.splitext(os.path.basename(target_path))[0] + '.mp4'
            output_path = os.path.join(output_path, output_filename)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vid = cv2.VideoWriter(output_path, fourcc, fps,
                                  (source_cropped_bgr.shape[1] * 3, source_cropped_bgr.shape[0]))
    else:
        out_vid = None

    # For each frame in the target video
    valid_frame_ind = 0
    for i in tqdm(range(total_frames)):
        ret, target_bgr = cap.read()
        if target_bgr is None:
            continue
        if i not in frame_indices:
            continue
        target_rgb = target_bgr[:, :, ::-1]
        target_tensor, target_landmarks, target_bbox = img_transforms2(
            target_rgb, landmarks_3d[valid_frame_ind], bboxes[valid_frame_ind])
        target_euler = eulers[valid_frame_ind]
        valid_frame_ind += 1

        # TODO: Calculate the number of required reenactment iterations
        reenactment_iterations = 2

        # Generate landmarks sequence
        target_landmarks_sequence = []
        for ri in range(1, reenactment_iterations):
            interp_landmarks = []
            for j in range(len(source_tensor)):
                alpha = float(ri) / reenactment_iterations
                curr_interp_landmarks_np = interpolate_points(source_landmarks[j].cpu().numpy(),
                                                              target_landmarks[j].cpu().numpy(),
                                                              alpha=alpha)
                interp_landmarks.append(torch.from_numpy(curr_interp_landmarks_np))
            target_landmarks_sequence.append(interp_landmarks)
        target_landmarks_sequence.append(target_landmarks)

        # Iterative reenactment
        out_img_tensor = source_tensor
        for curr_target_landmarks in target_landmarks_sequence:
            out_img_tensor = create_pyramid(out_img_tensor, 2)
            input_tensor = []
            for j in range(len(out_img_tensor)):
                curr_target_landmarks[j] = curr_target_landmarks[j].unsqueeze(0).to(device)
                curr_target_landmarks[j] = landmarks2heatmaps[j](curr_target_landmarks[j])
                input_tensor.append(torch.cat((out_img_tensor[j], curr_target_landmarks[j]), dim=1))
            out_img_tensor, out_seg_tensor = G(input_tensor)

        # Convert back to numpy images
        out_img_bgr = tensor2bgr(out_img_tensor)
        frame_cropped_bgr = tensor2bgr(target_tensor[0])

        # Render
        render_img = np.concatenate((source_cropped_bgr, out_img_bgr, frame_cropped_bgr), axis=1)
        if out_vid is not None:
            out_vid.write(render_img)
        if out_vid is None or display:
            cv2.imshow('render_img', render_img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
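# interpolate_points above produces landmark sets partway between the source
# and target pose for the iterative reenactment. A plain linear interpolation
# matching its call signature is sketched below; the repository's version may
# additionally handle alignment or normalization.
import numpy as np

def interpolate_points_sketch(points1, points2, alpha=0.5):
    # Blend two (N, 2) or (N, 3) landmark arrays: alpha=0 returns points1,
    # alpha=1 returns points2
    return ((1.0 - alpha) * points1 + alpha * points2).astype(points1.dtype)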
def proces_epoch(dataset_loader, train=True):
    stage = 'TRAINING' if train else 'VALIDATION'
    total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
    pbar = tqdm(dataset_loader, unit='batches')

    # Set networks training mode
    G.train(train)
    D.train(train)

    # Reset logger
    logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
        stage, res, res, epoch + 1, res_epochs, optimizer_G.param_groups[0]['lr']))

    # For each batch in the training data
    for i, (img, landmarks, target) in enumerate(pbar):
        # Prepare input
        with torch.no_grad():
            # For each view images and landmarks
            landmarks[1] = landmarks[1].to(device)
            for j in range(len(img)):
                # For each pyramid image: push to device
                for p in range(len(img[j])):
                    img[j][p] = img[j][p].to(device)

            # Remove unnecessary pyramids
            for j in range(len(img)):
                img[j] = img[j][-ri - 1:]

            # Concatenate pyramid images with context to derive the final input
            input = []
            for p in range(len(img[0])):
                context = res_landmarks_decoders[p](landmarks[1])
                input.append(torch.cat((img[0][p], context), dim=1))

        # Reenactment
        img_pred = G(input)

        # Fake Detection and Loss
        img_pred_pyd = img_utils.create_pyramid(img_pred, len(img[0]))
        pred_fake_pool = D([x.detach() for x in img_pred_pyd])
        loss_D_fake = criterion_gan(pred_fake_pool, False)

        # Real Detection and Loss
        pred_real = D(img[1])
        loss_D_real = criterion_gan(pred_real, True)
        loss_D_total = (loss_D_fake + loss_D_real) * 0.5

        # GAN loss (Fake Passability Loss)
        pred_fake = D(img_pred_pyd)
        loss_G_GAN = criterion_gan(pred_fake, True)

        # Reconstruction loss
        loss_pixelwise = criterion_pixelwise(img_pred, img[1][0])
        loss_id = criterion_id(img_pred, img[1][0])
        loss_attr = criterion_attr(img_pred, img[1][0])
        loss_rec = 0.1 * loss_pixelwise + 0.5 * loss_id + 0.5 * loss_attr

        loss_G_total = rec_weight * loss_rec + gan_weight * loss_G_GAN

        if train:
            # Update generator weights
            optimizer_G.zero_grad()
            loss_G_total.backward()
            optimizer_G.step()

            # Update discriminator weights
            optimizer_D.zero_grad()
            loss_D_total.backward()
            optimizer_D.step()

        logger.update('losses', pixelwise=loss_pixelwise, id=loss_id, attr=loss_attr, rec=loss_rec,
                      g_gan=loss_G_GAN, d_gan=loss_D_total)
        total_iter += dataset_loader.batch_size

        # Batch logs
        pbar.set_description(str(logger))
        if train and i % log_freq == 0:
            logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

    # Epoch logs
    logger.log_scalars_avg('%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
    if not train:
        # Log images
        grid = img_utils.make_grid(img[0][0], img_pred, img[1][0])
        logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

    return logger.log_dict['losses']['rec'].avg
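# The globals consumed by proces_epoch (criteria, optimizers, schedulers, and
# loss weights) are configured by the surrounding training script. The sketch
# below shows one plausible setup consistent with the losses computed above;
# the concrete classes, learning rates, and weights are assumptions, and the
# identity/attribute criteria are only placeholders for perceptual losses.
import torch.nn as nn
import torch.optim as optim

criterion_pixelwise = nn.L1Loss()  # pixelwise reconstruction term
# criterion_id / criterion_attr would be perceptual losses computed on
# pretrained identity and attribute networks (not reproduced here)
optimizer_G = optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.999))
scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=10, gamma=0.5)
scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, step_size=10, gamma=0.5)
rec_weight, gan_weight = 1.0, 0.001  # relative loss weights (illustrative)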