def compute_flow_and_conf(self, im1, im2):
    assert im1.size()[1] == 3
    assert im1.size() == im2.size()
    old_h, old_w = im1.size()[2], im1.size()[3]
    # The flow network expects spatial sizes divisible by 64.
    new_h, new_w = old_h // 64 * 64, old_w // 64 * 64
    if old_h != new_h or old_w != new_w:
        im1 = F.interpolate(im1, size=(new_h, new_w), mode='bilinear',
                            align_corners=False)
        im2 = F.interpolate(im2, size=(new_h, new_w), mode='bilinear',
                            align_corners=False)
    data1 = torch.cat([im1.unsqueeze(2), im2.unsqueeze(2)], dim=2)
    with torch.no_grad():
        flow1 = self.flowNet(data1)

    # Confidence: a pixel is trusted if warping im2 by the estimated flow
    # reproduces im1 within a small tolerance.
    # Alternative (soft) confidence:
    # img_diff = torch.sum(abs(im1 - resample(im2, flow1)),
    #                      dim=1, keepdim=True)
    # conf = torch.clamp(1 - img_diff, 0, 1)
    conf = (self.norm(im1 - resample(im2, flow1)) < 0.02).float()

    # Alternative: forward-backward consistency check.
    # data2 = torch.cat([im2.unsqueeze(2), im1.unsqueeze(2)], dim=2)
    # with torch.no_grad():
    #     flow2 = self.flowNet(data2)
    # warped_flow2 = resample(flow2, flow1)
    # flow_sum = self.norm(flow1 + warped_flow2)
    # disocc = flow_sum > (0.05 * (self.norm(flow1) +
    #                              self.norm(warped_flow2)) + 0.5)
    # conf = 1 - disocc.float()

    if old_h != new_h or old_w != new_w:
        # Rescale flow magnitudes along with the spatial size. A single
        # scale factor is used, which assumes height and width are resized
        # by (nearly) the same ratio.
        flow1 = F.interpolate(flow1, size=(old_h, old_w), mode='bilinear',
                              align_corners=False) * old_h / new_h
        conf = F.interpolate(conf, size=(old_h, old_w), mode='bilinear',
                             align_corners=False)
    return flow1, conf
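# A minimal, self-contained sketch of the warp-error confidence test used in
# `compute_flow_and_conf` above. `resample_sketch` is a hypothetical
# grid_sample-based stand-in for the library's `resample`, and a per-pixel L1
# sum stands in for `self.norm`; the flow is assumed to be in pixel units
# with channel order (dx, dy).
import torch
import torch.nn.functional as F


def resample_sketch(image, flow):
    """Backward-warp `image` (N, C, H, W) by `flow` (N, 2, H, W)."""
    n, _, h, w = image.size()
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    grid = torch.stack((xs, ys)).float().to(image.device)  # (2, H, W)
    grid = grid.unsqueeze(0) + flow                        # displaced coords
    # Normalize to [-1, 1], the range grid_sample expects.
    gx = 2.0 * grid[:, 0] / max(w - 1, 1) - 1.0
    gy = 2.0 * grid[:, 1] / max(h - 1, 1) - 1.0
    grid = torch.stack((gx, gy), dim=-1)                   # (N, H, W, 2)
    return F.grid_sample(image, grid, align_corners=True)


im1, im2 = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
flow = torch.zeros(1, 2, 64, 64)            # zero flow: identity warp
err = torch.sum(torch.abs(im1 - resample_sketch(im2, flow)),
                dim=1, keepdim=True)
conf = (err < 0.02).float()                 # 1 where the flow explains im1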
def save_image(self, path, data):
    r"""Save the output images to path.
    Note that when ``generate_raw_output`` is False,
    ``first_net_G_output['fake_raw_images']`` is None and will not be
    displayed. In model-average mode, the flow visualization is plotted
    twice.

    Args:
        path (str): Save path.
        data (dict): Training data for current iteration.
    """
    self.net_G.eval()
    if self.cfg.trainer.model_average:
        self.net_G.module.averaged_model.eval()
    self.net_G_output = None
    with torch.no_grad():
        first_net_G_output, net_G_output, all_info = self.gen_frames(data)
        if self.cfg.trainer.model_average:
            first_net_G_output_avg, net_G_output_avg = self.gen_frames(
                data, use_model_average=True)

    # Visualize labels.
    label_lengths = self.train_data_loader.dataset.get_label_lengths()
    labels = split_labels(data['label'], label_lengths)
    vis_labels_start, vis_labels_end = [], []
    for key, value in labels.items():
        if 'seg_maps' in key:
            vis_labels_start.append(self.visualize_label(value[:, -1]))
            vis_labels_end.append(self.visualize_label(value[:, 0]))
        else:
            normalize = self.train_data_loader.dataset.normalize[key]
            vis_labels_start.append(
                tensor2im(value[:, -1], normalize=normalize))
            vis_labels_end.append(
                tensor2im(value[:, 0], normalize=normalize))

    if is_master():
        vis_images = [
            *vis_labels_start,
            tensor2im(data['images'][:, -1]),
            tensor2im(net_G_output['fake_images']),
            tensor2im(net_G_output['fake_raw_images'])]
        if self.cfg.trainer.model_average:
            vis_images += [
                tensor2im(net_G_output_avg['fake_images']),
                tensor2im(net_G_output_avg['fake_raw_images'])]

        if self.sequence_length > 1:
            if net_G_output['guidance_images_and_masks'] is not None:
                guidance_image = tensor2im(
                    net_G_output['guidance_images_and_masks'][:, :3])
                guidance_mask = tensor2im(
                    net_G_output['guidance_images_and_masks'][:, 3:4],
                    normalize=False)
            else:
                im = tensor2im(data['images'][:, -1])
                guidance_image = [np.zeros_like(item) for item in im]
                guidance_mask = [np.zeros_like(item) for item in im]
            vis_images += [guidance_image, guidance_mask]

            vis_images_first = [
                *vis_labels_end,
                tensor2im(data['images'][:, 0]),
                tensor2im(first_net_G_output['fake_images']),
                tensor2im(first_net_G_output['fake_raw_images']),
                [np.zeros_like(item) for item in guidance_image],
                [np.zeros_like(item) for item in guidance_mask]]
            if self.cfg.trainer.model_average:
                vis_images_first += [
                    tensor2im(first_net_G_output_avg['fake_images']),
                    tensor2im(first_net_G_output_avg['fake_raw_images'])]

            if self.use_flow:
                flow_gt, conf_gt = self.criteria['Flow'].flowNet(
                    data['images'][:, -1], data['images'][:, -2])
                warped_image_gt = resample(data['images'][:, -1], flow_gt)
                vis_images_first += [
                    tensor2flow(flow_gt),
                    tensor2im(conf_gt, normalize=False),
                    tensor2im(warped_image_gt)]
                vis_images += [
                    tensor2flow(net_G_output['fake_flow_maps']),
                    tensor2im(net_G_output['fake_occlusion_masks'],
                              normalize=False),
                    tensor2im(net_G_output['warped_images'])]
                if self.cfg.trainer.model_average:
                    vis_images_first += [
                        tensor2flow(flow_gt),
                        tensor2im(conf_gt, normalize=False),
                        tensor2im(warped_image_gt)]
                    vis_images += [
                        tensor2flow(net_G_output_avg['fake_flow_maps']),
                        tensor2im(net_G_output_avg['fake_occlusion_masks'],
                                  normalize=False),
                        tensor2im(net_G_output_avg['warped_images'])]

            # Stack the first-frame tile on top of the last-frame tile for
            # each sample in every visualization column.
            vis_images = [[np.vstack((im_first, im))
                           for im_first, im in zip(imgs_first, imgs)]
                          for imgs_first, imgs in zip(vis_images_first,
                                                      vis_images)
                          if imgs is not None]

        image_grid = np.hstack([np.vstack(im) for im in vis_images
                                if im is not None])

        print('Save output images to {}'.format(path))
        os.makedirs(os.path.dirname(path), exist_ok=True)
        imageio.imwrite(path, image_grid)

        # Gather all inputs and outputs for dumping into video.
        if self.sequence_length > 1:
            input_images, output_images, output_guidance = [], [], []
            for item in all_info['inputs']:
                input_images.append(tensor2im(item['image'])[0])
            for item in all_info['outputs']:
                output_images.append(tensor2im(item['fake_images'])[0])
                if item['guidance_images_and_masks'] is not None:
                    output_guidance.append(tensor2im(
                        item['guidance_images_and_masks'][:, :3])[0])
                else:
                    output_guidance.append(
                        np.zeros_like(output_images[-1]))

            imageio.mimwrite(os.path.splitext(path)[0] + '.mp4',
                             output_images, fps=2, macro_block_size=None)
            imageio.mimwrite(os.path.splitext(path)[0] + '_guidance.mp4',
                             output_guidance, fps=2, macro_block_size=None)
            # Per-frame dumps, kept for debugging:
            # for idx, item in enumerate(output_guidance):
            #     imageio.imwrite(os.path.splitext(
            #         path)[0] + '_guidance_%d.jpg' % (idx), item)
            # for idx, item in enumerate(input_images):
            #     imageio.imwrite(os.path.splitext(
            #         path)[0] + '_input_%d.jpg' % (idx), item)

    self.net_G.float()
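# A toy sketch of the grid assembly in `save_image` above: each entry of
# `vis_images` is a per-sample list of HxWx3 arrays; first-frame tiles are
# stacked on top of last-frame tiles, then samples vertically, then columns
# horizontally. Sizes below are illustrative.
import numpy as np

batch, h, w = 2, 4, 4
col_first = [np.zeros((h, w, 3), np.uint8) for _ in range(batch)]   # frame 0
col_last = [np.full((h, w, 3), 255, np.uint8) for _ in range(batch)]  # frame -1
vis_images_first, vis_images = [col_first], [col_last]

vis_images = [[np.vstack((im_first, im))
               for im_first, im in zip(imgs_first, imgs)]
              for imgs_first, imgs in zip(vis_images_first, vis_images)]
image_grid = np.hstack([np.vstack(im) for im in vis_images])
print(image_grid.shape)  # (batch * 2 * h, w, 3) -> (16, 4, 3)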
def forward(self, data):
    r"""vid2vid generator forward.

    Args:
        data (dict): Dictionary of input data.
    Returns:
        output (dict): Dictionary of output data.
    """
    label = data['label']
    label_prev, img_prev = data['prev_labels'], data['prev_images']
    is_first_frame = img_prev is None
    # `data` is a dict, so use .get() rather than getattr().
    z = data.get('z', None)
    bs, _, h, w = label.size()

    if self.is_pose_data:
        label, label_prev = extract_valid_pose_labels(
            [label, label_prev], self.pose_type, self.remove_face_labels)

    # Get SPADE conditional maps by embedding current label input.
    cond_maps_now = self.get_cond_maps(label, self.label_embedding)

    # Input to the generator will either be noise / segmentation map (for
    # the first frame) or encoded previous frame (for subsequent frames).
    if is_first_frame:
        # First frame in the sequence, start from scratch.
        if self.use_segmap_as_input:
            x_img = F.interpolate(label, size=(self.sh, self.sw))
            x_img = self.fc(x_img)
        else:
            if z is None:
                # Zero-filled latent code for deterministic output.
                z = torch.randn(bs, self.z_dim, dtype=label.dtype,
                                device=label.device).fill_(0)
            x_img = self.fc(z).view(bs, -1, self.sh, self.sw)

        # Upsampling layers.
        for i in range(self.num_layers, self.num_downsamples_img, -1):
            j = min(self.num_downsamples_embed, i)
            x_img = getattr(self, 'up_' + str(i))(x_img, *cond_maps_now[j])
            x_img = self.upsample(x_img)
    else:
        # Not the first frame: encode the previous frame and feed it to
        # the generator.
        x_img = self.down_first(img_prev[:, -1])

        # Get label embedding for the previous frame.
        cond_maps_prev = self.get_cond_maps(label_prev[:, -1],
                                            self.label_embedding)

        # Downsampling layers.
        for i in range(self.num_downsamples_img + 1):
            j = min(self.num_downsamples_embed, i)
            x_img = getattr(self, 'down_' + str(i))(x_img,
                                                    *cond_maps_prev[j])
            if i != self.num_downsamples_img:
                x_img = self.downsample(x_img)

        # ResNet blocks.
        j = min(self.num_downsamples_embed, self.num_downsamples_img + 1)
        for i in range(self.num_res_blocks):
            cond_maps = cond_maps_prev[j] if i < self.num_res_blocks // 2 \
                else cond_maps_now[j]
            x_img = getattr(self, 'res_' + str(i))(x_img, *cond_maps)

    flow = mask = img_warp = None
    num_frames_G = self.num_frames_G

    # Whether to warp the previous frame or not.
    warp_prev = self.temporal_initialized and not is_first_frame and \
        label_prev.shape[1] == num_frames_G - 1
    if warp_prev:
        # Estimate flow & mask.
        label_concat = torch.cat([label_prev.view(bs, -1, h, w), label],
                                 dim=1)
        img_prev_concat = img_prev.view(bs, -1, h, w)
        flow, mask = self.flow_network_temp(label_concat, img_prev_concat)
        img_warp = resample(img_prev[:, -1], flow)
        if self.spade_combine:
            # If using SPADE combine, integrate the warped image (and
            # occlusion mask) into conditional inputs for SPADE.
            img_embed = torch.cat([img_warp, mask], dim=1)
            cond_maps_img = self.get_cond_maps(img_embed,
                                               self.img_prev_embedding)

    x_raw_img = None
    # Main image generation branch.
    for i in range(self.num_downsamples_img, -1, -1):
        # Get SPADE conditional inputs.
        j = min(i, self.num_downsamples_embed)
        cond_maps = cond_maps_now[j]

        # For raw output generation.
        if self.generate_raw_output:
            if i >= self.num_multi_spade_layers - 1:
                x_raw_img = x_img
            if i < self.num_multi_spade_layers:
                x_raw_img = self.one_up_conv_layer(x_raw_img, cond_maps, i)

        # For final output. Use non-mutating concatenation so that
        # cond_maps_now[j] stays intact for reuse at other layers.
        if warp_prev and i < self.num_multi_spade_layers:
            cond_maps = cond_maps + cond_maps_img[j]
        x_img = self.one_up_conv_layer(x_img, cond_maps, i)

    # Final conv layer.
    img_final = torch.tanh(self.conv_img(x_img))

    img_raw = None
    if self.spade_combine and self.generate_raw_output:
        img_raw = torch.tanh(self.conv_img(x_raw_img))
    if warp_prev and not self.spade_combine:
        # Without SPADE combine, composite the hallucinated image and the
        # warped previous frame using the soft occlusion mask.
        img_raw = img_final
        img_final = img_final * mask + img_warp * (1 - mask)

    output = dict()
    output['fake_images'] = img_final
    output['fake_flow_maps'] = flow
    output['fake_occlusion_masks'] = mask
    output['fake_raw_images'] = img_raw
    output['warped_images'] = img_warp
    return output
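# A toy numeric sketch of the occlusion-mask compositing at the end of
# forward() when spade_combine is off: the soft mask selects between newly
# synthesized pixels and the flow-warped previous frame. Values are made up.
import torch

img_hallucinated = torch.full((1, 3, 2, 2), 0.8)   # freshly generated pixels
img_warp = torch.full((1, 3, 2, 2), -0.4)          # warped previous frame
mask = torch.tensor([[[[1.0, 0.0], [0.5, 0.0]]]])  # 1 = trust hallucination
img_final = img_hallucinated * mask + img_warp * (1 - mask)
print(img_final[0, 0])  # tensor([[ 0.8000, -0.4000], [ 0.2000, -0.4000]])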
def forward(self, data):
    r"""vid2vid generator forward.

    Args:
        data (dict): Dictionary of input data.
    Returns:
        output (dict): Dictionary of output data.
    """
    self._init_single_image_model()
    label = data['label']
    unprojection = data['unprojection']
    label_prev, img_prev = data['prev_labels'], data['prev_images']
    is_first_frame = img_prev is None
    # `data` is a dict, so use .get() rather than getattr().
    z = data.get('z', None)
    bs, _, h, w = label.size()

    # Whether to warp the previous frame or not.
    flow = mask = img_warp = None
    warp_prev = self.temporal_initialized and not is_first_frame and \
        label_prev.shape[1] == self.num_frames_G - 1

    # Get guidance images and masks.
    guidance_images_and_masks, point_info = None, None
    if unprojection is not None:
        guidance_images_and_masks, point_info = \
            self.get_guidance_images_and_masks(unprojection)

    # Get SPADE conditional maps by embedding current label input.
    cond_maps_now = self.get_cond_maps(label, self.label_embedding)

    # Use the single-image model if flow features are not available.
    # Guidance features are used whenever flow features are available.
    if self.single_image_model is not None and not warp_prev:
        # Get the z vector for the single-image model. It is drawn once
        # and cached, so all frames of a sequence share the same style.
        if self.single_image_model_z is None:
            bs = data['label'].size(0)
            z = torch.randn(bs, self.single_image_model.style_dims,
                            dtype=torch.float32).cuda()
            if data['label'].dtype == torch.float16:
                z = z.half()
            self.single_image_model_z = z
        # Get output image.
        data['z'] = self.single_image_model_z
        self.single_image_model.eval()
        with torch.no_grad():
            output = self.single_image_model.spade_generator(data)
        img_final = output['fake_images'].detach()
        fake_images_source = 'pretrained'
    else:
        # Input to the generator will either be noise / segmentation map
        # (for the first frame) or encoded previous frame (for subsequent
        # frames).
        if is_first_frame:
            # First frame in the sequence, start from scratch.
            if self.use_segmap_as_input:
                x_img = F.interpolate(label, size=(self.sh, self.sw))
                x_img = self.fc(x_img)
            else:
                if z is None:
                    # Zero-filled latent code for deterministic output.
                    z = torch.randn(bs, self.z_dim, dtype=label.dtype,
                                    device=label.device).fill_(0)
                x_img = self.fc(z).view(bs, -1, self.sh, self.sw)

            # Upsampling layers.
            for i in range(self.num_layers, self.num_downsamples_img, -1):
                j = min(self.num_downsamples_embed, i)
                x_img = getattr(self, 'up_' + str(i))(x_img,
                                                      *cond_maps_now[j])
                x_img = self.upsample(x_img)
        else:
            # Not the first frame: encode the previous frame and feed it
            # to the generator.
            x_img = self.down_first(img_prev[:, -1])

            # Get label embedding for the previous frame.
            cond_maps_prev = self.get_cond_maps(label_prev[:, -1],
                                                self.label_embedding)

            # Downsampling layers.
            for i in range(self.num_downsamples_img + 1):
                j = min(self.num_downsamples_embed, i)
                x_img = getattr(self, 'down_' + str(i))(x_img,
                                                        *cond_maps_prev[j])
                if i != self.num_downsamples_img:
                    x_img = self.downsample(x_img)

            # ResNet blocks.
            j = min(self.num_downsamples_embed,
                    self.num_downsamples_img + 1)
            for i in range(self.num_res_blocks):
                cond_maps = cond_maps_prev[j] if \
                    i < self.num_res_blocks // 2 else cond_maps_now[j]
                x_img = getattr(self, 'res_' + str(i))(x_img, *cond_maps)

        # Optical-flow-warped image features.
        if warp_prev:
            # Estimate flow & mask.
            label_concat = torch.cat(
                [label_prev.view(bs, -1, h, w), label], dim=1)
            img_prev_concat = img_prev.view(bs, -1, h, w)
            flow, mask = self.flow_network_temp(label_concat,
                                                img_prev_concat)
            img_warp = resample(img_prev[:, -1], flow)
            if self.spade_combine:
                # If using SPADE combine, integrate the warped image (and
                # occlusion mask) into conditional inputs for SPADE.
                img_embed = torch.cat([img_warp, mask], dim=1)
                cond_maps_img = self.get_cond_maps(img_embed,
                                                   self.img_prev_embedding)

        x_raw_img = None
        # Main image generation branch.
        for i in range(self.num_downsamples_img, -1, -1):
            # Get SPADE conditional inputs.
            j = min(i, self.num_downsamples_embed)
            cond_maps = cond_maps_now[j]

            # For raw output generation.
            if self.generate_raw_output:
                if i >= self.num_multi_spade_layers - 1:
                    x_raw_img = x_img
                if i < self.num_multi_spade_layers:
                    x_raw_img = self.one_up_conv_layer(x_raw_img,
                                                       cond_maps, i)

            # Add flow and guidance features. Use non-mutating
            # concatenation so cond_maps_now[j] stays intact for reuse.
            if warp_prev:
                if i < self.num_multi_spade_layers:
                    # Add flow features.
                    cond_maps = cond_maps + cond_maps_img[j]
                    # Add guidance features.
                    if guidance_images_and_masks is not None:
                        cond_maps = cond_maps + [guidance_images_and_masks]
            elif not self.guidance_only_with_flow:
                # Add guidance if it is to be applied to every layer.
                if guidance_images_and_masks is not None:
                    cond_maps = cond_maps + [guidance_images_and_masks]

            x_img = self.one_up_conv_layer(x_img, cond_maps, i)

        # Final conv layer.
        img_final = torch.tanh(self.conv_img(x_img))
        fake_images_source = 'in_training'

    # Update the point cloud color dict of the renderer.
    self.renderer_update_point_cloud(img_final, point_info)

    output = dict()
    output['fake_images'] = img_final
    output['fake_flow_maps'] = flow
    output['fake_occlusion_masks'] = mask
    output['fake_raw_images'] = None
    output['warped_images'] = img_warp
    output['guidance_images_and_masks'] = guidance_images_and_masks
    output['fake_images_source'] = fake_images_source
    return output
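# A sketch of the per-sequence style caching used for the single-image
# fallback above: one z is drawn on the first call and reused for all later
# frames, so the pretrained model's output stays temporally stable. The
# class and `style_dims` value below are illustrative, not from this repo.
import torch


class StyleCache:
    def __init__(self, style_dims=256):
        self.style_dims = style_dims
        self.z = None

    def get(self, batch_size):
        if self.z is None:                  # first frame of the sequence
            self.z = torch.randn(batch_size, self.style_dims)
        return self.z                       # same z for every later frame


cache = StyleCache()
z0, z1 = cache.get(1), cache.get(1)
assert torch.equal(z0, z1)  # identical across frames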
def compute_flow_losses(self, flow, warped_images, tgt_image, flow_gt,
                        flow_conf_gt, fg_mask, tgt_label, ref_label):
    r"""Compute losses on the generated flow maps.

    Args:
        flow (tensor or list of tensors): Generated flow maps.
        warped_images (tensor or list of tensors): Warped images using the
            flow maps.
        tgt_image (tensor): Target image for the warped image.
        flow_gt (tensor or list of tensors): Ground truth flow maps.
        flow_conf_gt (tensor or list of tensors): Confidence for the
            ground truth flow maps.
        fg_mask (tensor): Foreground mask for the target image.
        tgt_label (tensor): Target label map.
        ref_label (tensor): Reference label map.
    Returns:
        (tuple):
          - loss_flow_L1 (tensor): L1 loss compared to ground truth flow.
          - loss_flow_warp (tensor): L1 loss between the warped image and
            the target image when using the flow to warp.
          - body_mask_diff (tensor): Difference between warped body part
            map and target body part map. Used for pose dataset only.
    """
    loss_flow_L1 = torch.tensor(0., device=torch.device('cuda'))
    loss_flow_warp = torch.tensor(0., device=torch.device('cuda'))
    if isinstance(flow, list):
        # Compute flow losses for both warping reference -> target and
        # previous -> target.
        for i in range(len(flow)):
            loss_flow_L1_i, loss_flow_warp_i = \
                self.compute_flow_loss(flow[i], warped_images[i],
                                       tgt_image, flow_gt[i],
                                       flow_conf_gt[i], fg_mask)
            loss_flow_L1 += loss_flow_L1_i
            loss_flow_warp += loss_flow_warp_i
    else:
        # Compute loss for warping either reference or previous images.
        loss_flow_L1, loss_flow_warp = \
            self.compute_flow_loss(flow, warped_images, tgt_image,
                                   flow_gt[-1], flow_conf_gt[-1], fg_mask)

    # For pose dataset only.
    body_mask_diff = None
    if self.warp_ref:
        if self.for_pose_dataset:
            # The warped reference body part map should be similar to the
            # target body part map.
            body_mask = get_part_mask(tgt_label[:, 2])
            ref_body_mask = get_part_mask(ref_label[:, 2])
            warped_ref_body_mask = resample(ref_body_mask, flow[0])
            loss_flow_warp += self.criterion(warped_ref_body_mask,
                                             body_mask)
            body_mask_diff = torch.sum(
                abs(warped_ref_body_mask - body_mask), dim=1, keepdim=True)

        if self.has_fg:
            # The warped reference foreground map should be similar to the
            # target foreground map.
            fg_mask, ref_fg_mask = get_fg_mask([tgt_label, ref_label], True)
            warped_ref_fg_mask = resample(ref_fg_mask, flow[0])
            loss_flow_warp += self.criterion(warped_ref_fg_mask, fg_mask)

    return loss_flow_L1, loss_flow_warp, body_mask_diff
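# A hedged sketch of what the per-flow helper `compute_flow_loss` (singular)
# plausibly computes: a confidence-masked L1 penalty against the pseudo
# ground-truth flow plus an L1 warping loss. This is an illustrative
# stand-in, not the repo's actual implementation.
import torch
import torch.nn.functional as F


def flow_loss_sketch(flow, warped, tgt_image, flow_gt, conf_gt):
    # Penalize flow error only where the pseudo ground truth is confident.
    loss_l1 = torch.sum(torch.abs(flow - flow_gt) * conf_gt) \
        / torch.clamp(conf_gt.sum(), min=1.0)
    # The warped previous frame should match the target frame everywhere.
    loss_warp = F.l1_loss(warped, tgt_image)
    return loss_l1, loss_warp


loss_l1, loss_warp = flow_loss_sketch(
    torch.zeros(1, 2, 8, 8), torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8),
    torch.zeros(1, 2, 8, 8), torch.ones(1, 1, 8, 8))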