Example #1
    def compute_flow_and_conf(self, im1, im2):
        r"""Estimate the optical flow from im1 to im2 and a confidence mask.

        Args:
            im1 (4D tensor): First image batch (N x 3 x H x W).
            im2 (4D tensor): Second image batch, same size as im1.
        Returns:
            (tuple):
              - flow1 (4D tensor): Estimated flow from im1 to im2, resized
                back to the input resolution.
              - conf (4D tensor): Binary mask which is 1 where the warping
                error is below the threshold.
        """
        assert (im1.size()[1] == 3)
        assert (im1.size() == im2.size())
        old_h, old_w = im1.size()[2], im1.size()[3]
        new_h, new_w = old_h // 64 * 64, old_w // 64 * 64
        if old_h != new_h or old_w != new_w:
            im1 = F.interpolate(im1,
                                size=(new_h, new_w),
                                mode='bilinear',
                                align_corners=False)
            im2 = F.interpolate(im2,
                                size=(new_h, new_w),
                                mode='bilinear',
                                align_corners=False)
        # Stack the two frames along a new temporal dimension, as expected
        # by the flow network.
        data1 = torch.cat([im1.unsqueeze(2), im2.unsqueeze(2)], dim=2)
        with torch.no_grad():
            flow1 = self.flowNet(data1)
        # img_diff = torch.sum(abs(im1 - resample(im2, flow1)),
        #                      dim=1, keepdim=True)
        # conf = torch.clamp(1 - img_diff, 0, 1)

        # A pixel is confident when the photometric error between im1 and
        # the flow-warped im2 is below the threshold.
        conf = (self.norm(im1 - resample(im2, flow1)) < 0.02).float()

        # data2 = torch.cat([im2.unsqueeze(2), im1.unsqueeze(2)], dim=2)
        # with torch.no_grad():
        #     flow2 = self.flowNet(data2)
        # warped_flow2 = resample(flow2, flow1)
        # flow_sum = self.norm(flow1 + warped_flow2)
        # disocc = flow_sum > (0.05 * (self.norm(flow1) +
        # self.norm(warped_flow2)) + 0.5)
        # conf = 1 - disocc.float()

        if old_h != new_h or old_w != new_w:
            # Resize the flow back to the input resolution and rescale its
            # displacement values accordingly.
            flow1 = F.interpolate(flow1,
                                  size=(old_h, old_w),
                                  mode='bilinear',
                                  align_corners=False) * old_h / new_h
            conf = F.interpolate(conf,
                                 size=(old_h, old_w),
                                 mode='bilinear',
                                 align_corners=False)
        return flow1, conf
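
A minimal standalone sketch of the resizing logic above (sizes are
illustrative): inputs are snapped down to a multiple of 64, a common stride
requirement for FlowNet-style backbones, and when the flow is resized back
to the original resolution its displacement values must be rescaled along
with it.

import torch
import torch.nn.functional as F

im = torch.randn(1, 3, 300, 500)  # arbitrary input resolution
new_h, new_w = im.size(2) // 64 * 64, im.size(3) // 64 * 64  # 256, 448
im_small = F.interpolate(im, size=(new_h, new_w),
                         mode='bilinear', align_corners=False)

# Stand-in for the flow network output at the reduced resolution
# (2 channels: horizontal and vertical displacement).
flow_small = torch.randn(1, 2, new_h, new_w)

# Resize the flow back and rescale the displacements; the method above
# uses a single factor (old_h / new_h), which assumes both axes shrank
# by a similar ratio.
flow = F.interpolate(flow_small, size=im.shape[2:], mode='bilinear',
                     align_corners=False) * im.size(2) / new_h
print(flow.shape)  # torch.Size([1, 2, 300, 500])
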
Example #2
    def save_image(self, path, data):
        r"""Save the output images to path.
        Note that when generate_raw_output is False,
        first_net_G_output['fake_raw_images'] is None and will not be
        displayed. In model average mode, the flow visualization is plotted
        twice.

        Args:
            path (str): Save path.
            data (dict): Training data for the current iteration.
        """
        self.net_G.eval()
        if self.cfg.trainer.model_average:
            self.net_G.module.averaged_model.eval()
        self.net_G_output = None
        with torch.no_grad():
            first_net_G_output, net_G_output, all_info = self.gen_frames(data)
            if self.cfg.trainer.model_average:
                first_net_G_output_avg, net_G_output_avg, _ = self.gen_frames(
                    data, use_model_average=True)

        # Visualize labels.
        label_lengths = self.train_data_loader.dataset.get_label_lengths()
        labels = split_labels(data['label'], label_lengths)
        vis_labels_start, vis_labels_end = [], []
        for key, value in labels.items():
            if 'seg_maps' in key:
                vis_labels_start.append(self.visualize_label(value[:, -1]))
                vis_labels_end.append(self.visualize_label(value[:, 0]))
            else:
                normalize = self.train_data_loader.dataset.normalize[key]
                vis_labels_start.append(
                    tensor2im(value[:, -1], normalize=normalize))
                vis_labels_end.append(
                    tensor2im(value[:, 0], normalize=normalize))

        if is_master():
            vis_images = [
                *vis_labels_start,
                tensor2im(data['images'][:, -1]),
                tensor2im(net_G_output['fake_images']),
                tensor2im(net_G_output['fake_raw_images'])
            ]
            if self.cfg.trainer.model_average:
                vis_images += [
                    tensor2im(net_G_output_avg['fake_images']),
                    tensor2im(net_G_output_avg['fake_raw_images'])
                ]

            if self.sequence_length > 1:
                if net_G_output['guidance_images_and_masks'] is not None:
                    guidance_image = tensor2im(
                        net_G_output['guidance_images_and_masks'][:, :3])
                    guidance_mask = tensor2im(
                        net_G_output['guidance_images_and_masks'][:, 3:4],
                        normalize=False)
                else:
                    im = tensor2im(data['images'][:, -1])
                    guidance_image = [np.zeros_like(item) for item in im]
                    guidance_mask = [np.zeros_like(item) for item in im]
                vis_images += [guidance_image, guidance_mask]

                vis_images_first = [
                    *vis_labels_end,
                    tensor2im(data['images'][:, 0]),
                    tensor2im(first_net_G_output['fake_images']),
                    tensor2im(first_net_G_output['fake_raw_images']),
                    [np.zeros_like(item) for item in guidance_image],
                    [np.zeros_like(item) for item in guidance_mask]
                ]
                if self.cfg.trainer.model_average:
                    vis_images_first += [
                        tensor2im(first_net_G_output_avg['fake_images']),
                        tensor2im(first_net_G_output_avg['fake_raw_images'])
                    ]

                if self.use_flow:
                    flow_gt, conf_gt = self.criteria['Flow'].flowNet(
                        data['images'][:, -1], data['images'][:, -2])
                    warped_image_gt = resample(data['images'][:, -1], flow_gt)
                    vis_images_first += [
                        tensor2flow(flow_gt),
                        tensor2im(conf_gt, normalize=False),
                        tensor2im(warped_image_gt),
                    ]
                    vis_images += [
                        tensor2flow(net_G_output['fake_flow_maps']),
                        tensor2im(net_G_output['fake_occlusion_masks'],
                                  normalize=False),
                        tensor2im(net_G_output['warped_images']),
                    ]
                    if self.cfg.trainer.model_average:
                        vis_images_first += [
                            tensor2flow(flow_gt),
                            tensor2im(conf_gt, normalize=False),
                            tensor2im(warped_image_gt),
                        ]
                        vis_images += [
                            tensor2flow(net_G_output_avg['fake_flow_maps']),
                            tensor2im(net_G_output_avg['fake_occlusion_masks'],
                                      normalize=False),
                            tensor2im(net_G_output_avg['warped_images'])
                        ]

                vis_images = [
                    [np.vstack((im_first, im))
                     for im_first, im in zip(imgs_first, imgs)]
                    for imgs_first, imgs in zip(vis_images_first, vis_images)
                    if imgs is not None
                ]

            image_grid = np.hstack(
                [np.vstack(im) for im in vis_images if im is not None])

            print('Save output images to {}'.format(path))
            os.makedirs(os.path.dirname(path), exist_ok=True)
            imageio.imwrite(path, image_grid)

            # Gather all inputs and outputs for dumping into video.
            if self.sequence_length > 1:
                input_images, output_images, output_guidance = [], [], []
                for item in all_info['inputs']:
                    input_images.append(tensor2im(item['image'])[0])
                for item in all_info['outputs']:
                    output_images.append(tensor2im(item['fake_images'])[0])
                    if item['guidance_images_and_masks'] is not None:
                        output_guidance.append(
                            tensor2im(
                                item['guidance_images_and_masks'][:, :3])[0])
                    else:
                        output_guidance.append(np.zeros_like(
                            output_images[-1]))

                imageio.mimwrite(os.path.splitext(path)[0] + '.mp4',
                                 output_images,
                                 fps=2,
                                 macro_block_size=None)
                imageio.mimwrite(os.path.splitext(path)[0] + '_guidance.mp4',
                                 output_guidance,
                                 fps=2,
                                 macro_block_size=None)

            # for idx, item in enumerate(output_guidance):
            #     imageio.imwrite(os.path.splitext(
            #         path)[0] + '_guidance_%d.jpg' % (idx), item)
            # for idx, item in enumerate(input_images):
            #     imageio.imwrite(os.path.splitext(
            #         path)[0] + '_input_%d.jpg' % (idx), item)

        self.net_G.float()
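
For reference, a minimal sketch of the grid assembly used above (dummy
data, illustrative shapes): each entry of vis_images is a list of
per-sample H x W x 3 arrays, np.vstack stacks one visualization type's
samples vertically, and np.hstack places the visualization types side by
side.

import numpy as np

h, w = 4, 6  # illustrative per-image size
# Three visualization types, each a list of two per-sample images.
vis_images = [[np.zeros((h, w, 3), dtype=np.uint8) for _ in range(2)]
              for _ in range(3)]

image_grid = np.hstack([np.vstack(im) for im in vis_images if im is not None])
print(image_grid.shape)  # (8, 18, 3): samples stacked down, types across
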
Example #3
    def forward(self, data):
        r"""vid2vid generator forward.

        Args:
            data (dict): Dictionary of input data.
        Returns:
            output (dict): Dictionary of output data.
        """
        label = data['label']
        label_prev, img_prev = data['prev_labels'], data['prev_images']
        is_first_frame = img_prev is None
        z = data.get('z', None)
        bs, _, h, w = label.size()

        if self.is_pose_data:
            label, label_prev = extract_valid_pose_labels(
                [label, label_prev], self.pose_type, self.remove_face_labels)

        # Get SPADE conditional maps by embedding current label input.
        cond_maps_now = self.get_cond_maps(label, self.label_embedding)

        # Input to the generator will either be noise/segmentation map (for
        # first frame) or encoded previous frame (for subsequent frames).
        if is_first_frame:
            # First frame in the sequence, start from scratch.
            if self.use_segmap_as_input:
                x_img = F.interpolate(label, size=(self.sh, self.sw))
                x_img = self.fc(x_img)
            else:
                if z is None:
                    # fill_(0) zeroes the sampled noise, so z is
                    # deterministic here.
                    z = torch.randn(bs,
                                    self.z_dim,
                                    dtype=label.dtype,
                                    device=label.get_device()).fill_(0)
                x_img = self.fc(z).view(bs, -1, self.sh, self.sw)

            # Upsampling layers.
            for i in range(self.num_layers, self.num_downsamples_img, -1):
                j = min(self.num_downsamples_embed, i)
                x_img = getattr(self, 'up_' + str(i))(x_img, *cond_maps_now[j])
                x_img = self.upsample(x_img)
        else:
            # Not the first frame, will encode the previous frame and feed to
            # the generator.
            x_img = self.down_first(img_prev[:, -1])

            # Get label embedding for the previous frame.
            cond_maps_prev = self.get_cond_maps(label_prev[:, -1],
                                                self.label_embedding)

            # Downsampling layers.
            for i in range(self.num_downsamples_img + 1):
                j = min(self.num_downsamples_embed, i)
                x_img = getattr(self, 'down_' + str(i))(x_img,
                                                        *cond_maps_prev[j])
                if i != self.num_downsamples_img:
                    x_img = self.downsample(x_img)

            # Resnet blocks.
            j = min(self.num_downsamples_embed, self.num_downsamples_img + 1)
            for i in range(self.num_res_blocks):
                cond_maps = cond_maps_prev[j] if i < self.num_res_blocks // 2 \
                    else cond_maps_now[j]
                x_img = getattr(self, 'res_' + str(i))(x_img, *cond_maps)

        flow = mask = img_warp = None

        num_frames_G = self.num_frames_G
        # Whether to warp the previous frame or not.
        warp_prev = self.temporal_initialized and not is_first_frame and \
            label_prev.shape[1] == num_frames_G - 1
        if warp_prev:
            # Estimate flow & mask.
            label_concat = torch.cat([label_prev.view(bs, -1, h, w), label],
                                     dim=1)
            img_prev_concat = img_prev.view(bs, -1, h, w)
            flow, mask = self.flow_network_temp(label_concat, img_prev_concat)
            img_warp = resample(img_prev[:, -1], flow)
            if self.spade_combine:
                # if using SPADE combine, integrate the warped image (and
                # occlusion mask) into conditional inputs for SPADE.
                img_embed = torch.cat([img_warp, mask], dim=1)
                cond_maps_img = self.get_cond_maps(img_embed,
                                                   self.img_prev_embedding)
                x_raw_img = None

        # Main image generation branch.
        for i in range(self.num_downsamples_img, -1, -1):
            # Get SPADE conditional inputs.
            j = min(i, self.num_downsamples_embed)
            cond_maps = cond_maps_now[j]

            # For raw output generation.
            if self.generate_raw_output:
                if i >= self.num_multi_spade_layers - 1:
                    x_raw_img = x_img
                if i < self.num_multi_spade_layers:
                    x_raw_img = self.one_up_conv_layer(x_raw_img, cond_maps, i)

            # For final output.
            if warp_prev and i < self.num_multi_spade_layers:
                cond_maps += cond_maps_img[j]
            x_img = self.one_up_conv_layer(x_img, cond_maps, i)

        # Final conv layer.
        img_final = torch.tanh(self.conv_img(x_img))

        img_raw = None
        if self.spade_combine and self.generate_raw_output:
            img_raw = torch.tanh(self.conv_img(x_raw_img))
        if warp_prev and not self.spade_combine:
            img_raw = img_final
            img_final = img_final * mask + img_warp * (1 - mask)

        output = dict()
        output['fake_images'] = img_final
        output['fake_flow_maps'] = flow
        output['fake_occlusion_masks'] = mask
        output['fake_raw_images'] = img_raw
        output['warped_images'] = img_warp
        return output
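
When spade_combine is disabled, the final frame above is a per-pixel
linear blend of the generator's prediction and the flow-warped previous
frame, gated by the predicted occlusion mask. A minimal sketch of that
compositing step (values are illustrative):

import torch

img_final = torch.full((1, 3, 4, 4), 1.0)   # generator prediction
img_warp = torch.full((1, 3, 4, 4), -1.0)   # warped previous frame
mask = torch.zeros(1, 1, 4, 4)
mask[..., :2] = 1.0  # 1 = occluded / newly visible content

# Where mask is 1, keep the generator's own prediction; where it is 0,
# reuse the warped previous frame.
composite = img_final * mask + img_warp * (1 - mask)
print(composite[0, 0])  # first two columns are 1.0, the rest -1.0
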
Example #4
    def forward(self, data):
        r"""vid2vid generator forward.

        Args:
            data (dict): Dictionary of input data.
        Returns:
            output (dict): Dictionary of output data.
        """
        self._init_single_image_model()

        label = data['label']
        unprojection = data['unprojection']
        label_prev, img_prev = data['prev_labels'], data['prev_images']
        is_first_frame = img_prev is None
        z = data.get('z', None)
        bs, _, h, w = label.size()

        # Whether to warp the previous frame or not.
        flow = mask = img_warp = None
        warp_prev = self.temporal_initialized and not is_first_frame and \
            label_prev.shape[1] == self.num_frames_G - 1

        # Get guidance images and masks.
        guidance_images_and_masks, point_info = None, None
        if unprojection is not None:
            guidance_images_and_masks, point_info = \
                self.get_guidance_images_and_masks(unprojection)

        # Get SPADE conditional maps by embedding current label input.
        cond_maps_now = self.get_cond_maps(label, self.label_embedding)

        # Use single image model, if flow features are not available.
        # Guidance features are used whenever flow features are available.
        if self.single_image_model is not None and not warp_prev:
            # Get z vector for single image model.
            if self.single_image_model_z is None:
                bs = data['label'].size(0)
                z = torch.randn(bs,
                                self.single_image_model.style_dims,
                                dtype=torch.float32).cuda()
                if data['label'].dtype == torch.float16:
                    z = z.half()
                self.single_image_model_z = z

            # Get output image.
            data['z'] = self.single_image_model_z
            self.single_image_model.eval()
            with torch.no_grad():
                output = self.single_image_model.spade_generator(data)
            img_final = output['fake_images'].detach()
            fake_images_source = 'pretrained'
        else:
            # Input to the generator will either be noise/segmentation map (for
            # first frame) or encoded previous frame (for subsequent frames).
            if is_first_frame:
                # First frame in the sequence, start from scratch.
                if self.use_segmap_as_input:
                    x_img = F.interpolate(label, size=(self.sh, self.sw))
                    x_img = self.fc(x_img)
                else:
                    if z is None:
                        # fill_(0) zeroes the sampled noise, so z is
                        # deterministic here.
                        z = torch.randn(bs,
                                        self.z_dim,
                                        dtype=label.dtype,
                                        device=label.get_device()).fill_(0)
                    x_img = self.fc(z).view(bs, -1, self.sh, self.sw)

                # Upsampling layers.
                for i in range(self.num_layers, self.num_downsamples_img, -1):
                    j = min(self.num_downsamples_embed, i)
                    x_img = getattr(self, 'up_' + str(i))(x_img,
                                                          *cond_maps_now[j])
                    x_img = self.upsample(x_img)
            else:
                # Not the first frame, will encode the previous frame and feed
                # to the generator.
                x_img = self.down_first(img_prev[:, -1])

                # Get label embedding for the previous frame.
                cond_maps_prev = self.get_cond_maps(label_prev[:, -1],
                                                    self.label_embedding)

                # Downsampling layers.
                for i in range(self.num_downsamples_img + 1):
                    j = min(self.num_downsamples_embed, i)
                    x_img = getattr(self, 'down_' + str(i))(x_img,
                                                            *cond_maps_prev[j])
                    if i != self.num_downsamples_img:
                        x_img = self.downsample(x_img)

                # Resnet blocks.
                j = min(self.num_downsamples_embed,
                        self.num_downsamples_img + 1)
                for i in range(self.num_res_blocks):
                    cond_maps = cond_maps_prev[j] if \
                        i < self.num_res_blocks // 2 else cond_maps_now[j]
                    x_img = getattr(self, 'res_' + str(i))(x_img, *cond_maps)

            # Optical flow warped image features.
            if warp_prev:
                # Estimate flow & mask.
                label_concat = torch.cat(
                    [label_prev.view(bs, -1, h, w), label], dim=1)
                img_prev_concat = img_prev.view(bs, -1, h, w)
                flow, mask = self.flow_network_temp(label_concat,
                                                    img_prev_concat)
                img_warp = resample(img_prev[:, -1], flow)
                if self.spade_combine:
                    # if using SPADE combine, integrate the warped image (and
                    # occlusion mask) into conditional inputs for SPADE.
                    img_embed = torch.cat([img_warp, mask], dim=1)
                    cond_maps_img = self.get_cond_maps(img_embed,
                                                       self.img_prev_embedding)
                    x_raw_img = None

            # Main image generation branch.
            for i in range(self.num_downsamples_img, -1, -1):
                # Get SPADE conditional inputs.
                j = min(i, self.num_downsamples_embed)
                cond_maps = cond_maps_now[j]

                # For raw output generation.
                if self.generate_raw_output:
                    if i >= self.num_multi_spade_layers - 1:
                        x_raw_img = x_img
                    if i < self.num_multi_spade_layers:
                        x_raw_img = self.one_up_conv_layer(
                            x_raw_img, cond_maps, i)

                # Add flow and guidance features.
                if warp_prev:
                    if i < self.num_multi_spade_layers:
                        # Add flow.
                        cond_maps += cond_maps_img[j]
                        # Add guidance.
                        if guidance_images_and_masks is not None:
                            cond_maps += [guidance_images_and_masks]
                    elif not self.guidance_only_with_flow:
                        # Add guidance if it is to be applied to every layer.
                        if guidance_images_and_masks is not None:
                            cond_maps += [guidance_images_and_masks]

                x_img = self.one_up_conv_layer(x_img, cond_maps, i)

            # Final conv layer.
            img_final = torch.tanh(self.conv_img(x_img))
            fake_images_source = 'in_training'

        # Update the point cloud color dict of renderer.
        self.renderer_update_point_cloud(img_final, point_info)

        output = dict()
        output['fake_images'] = img_final
        output['fake_flow_maps'] = flow
        output['fake_occlusion_masks'] = mask
        output['fake_raw_images'] = None
        output['warped_images'] = img_warp
        output['guidance_images_and_masks'] = guidance_images_and_masks
        output['fake_images_source'] = fake_images_source
        return output
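
The resample() helper used throughout these examples is not defined here;
a common way to implement such backward warping is with F.grid_sample, as
in the sketch below (an assumption about its behavior, not the actual
library code):

import torch
import torch.nn.functional as F

def resample_sketch(image, flow):
    """Backward-warp image (N x C x H x W) with flow (N x 2 x H x W)."""
    _, _, h, w = image.shape
    # Base sampling grid in pixel coordinates.
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    base = torch.stack((xs, ys), dim=0).float().unsqueeze(0)  # 1 x 2 x H x W
    coords = base + flow
    # Normalize to [-1, 1] as required by grid_sample.
    x_norm = 2 * coords[:, 0] / (w - 1) - 1
    y_norm = 2 * coords[:, 1] / (h - 1) - 1
    grid = torch.stack((x_norm, y_norm), dim=3)  # N x H x W x 2
    return F.grid_sample(image, grid, align_corners=True)

img = torch.randn(1, 3, 8, 8)
flow = torch.zeros(1, 2, 8, 8)  # zero flow: warping is the identity
assert torch.allclose(resample_sketch(img, flow), img, atol=1e-5)
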
Example #5
    def compute_flow_losses(self, flow, warped_images, tgt_image, flow_gt,
                            flow_conf_gt, fg_mask, tgt_label, ref_label):
        r"""Compute losses on the generated flow maps.

        Args:
            flow (tensor or list of tensors): Generated flow maps.
            warped_images (tensor or list of tensors): Warped images using the
                flow maps.
            tgt_image (tensor): Target image for the warped image.
            flow_gt (tensor or list of tensors): Ground truth flow maps.
            flow_conf_gt (tensor or list of tensors): Confidence for the ground
                truth flow maps.
            fg_mask (tensor): Foreground mask for the target image.
            tgt_label (tensor): Target label map.
            ref_label (tensor): Reference label map.
        Returns:
            (tuple):
              - loss_flow_L1 (tensor): L1 loss compared to the ground truth
                flow.
              - loss_flow_warp (tensor): L1 loss between the warped image and
                the target image when using the flow to warp.
              - body_mask_diff (tensor): Difference between the warped body
                part map and the target body part map. Used for the pose
                dataset only.
        """
        loss_flow_L1 = torch.tensor(0., device=torch.device('cuda'))
        loss_flow_warp = torch.tensor(0., device=torch.device('cuda'))
        if isinstance(flow, list):
            # Compute flow losses for both warping reference -> target and
            # previous -> target.
            for i in range(len(flow)):
                loss_flow_L1_i, loss_flow_warp_i = \
                    self.compute_flow_loss(flow[i], warped_images[i], tgt_image,
                                           flow_gt[i], flow_conf_gt[i], fg_mask)
                loss_flow_L1 += loss_flow_L1_i
                loss_flow_warp += loss_flow_warp_i
        else:
            # Compute loss for warping either reference or previous images.
            loss_flow_L1, loss_flow_warp = \
                self.compute_flow_loss(flow, warped_images, tgt_image,
                                       flow_gt[-1], flow_conf_gt[-1], fg_mask)

        # For pose dataset only.
        body_mask_diff = None
        if self.warp_ref:
            if self.for_pose_dataset:
                # Warped reference body part map should be similar to target
                # body part map.
                body_mask = get_part_mask(tgt_label[:, 2])
                ref_body_mask = get_part_mask(ref_label[:, 2])
                warped_ref_body_mask = resample(ref_body_mask, flow[0])
                loss_flow_warp += self.criterion(warped_ref_body_mask,
                                                 body_mask)
                body_mask_diff = torch.sum(abs(warped_ref_body_mask -
                                               body_mask),
                                           dim=1,
                                           keepdim=True)

            if self.has_fg:
                # Warped reference foreground map should be similar to target
                # foreground map.
                fg_mask, ref_fg_mask = \
                    get_fg_mask([tgt_label, ref_label], True)
                warped_ref_fg_mask = resample(ref_fg_mask, flow[0])
                loss_flow_warp += self.criterion(warped_ref_fg_mask, fg_mask)

        return loss_flow_L1, loss_flow_warp, body_mask_diff
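
The per-pair helper compute_flow_loss called above is not shown; a minimal
sketch consistent with its call sites (the exact masking scheme is an
assumption) weights the flow L1 term by the ground-truth confidence and
compares the warped image against the target:

import torch
import torch.nn as nn

criterion = nn.L1Loss()

def compute_flow_loss_sketch(flow, warped_image, tgt_image,
                             flow_gt, flow_conf_gt, fg_mask):
    # L1 on the flow itself, counted only where the ground-truth flow is
    # confident and inside the foreground mask.
    loss_flow_L1 = criterion(flow * flow_conf_gt * fg_mask,
                             flow_gt * flow_conf_gt * fg_mask)
    # L1 between the flow-warped image and the target image.
    loss_flow_warp = criterion(warped_image, tgt_image)
    return loss_flow_L1, loss_flow_warp
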