def inference(self, batched_inputs):
    images = self.preprocess_image(batched_inputs)
    features = self.backbone(images.tensor)
    if self.proposal_generator:
        proposals, _ = self.proposal_generator(images, features)
    else:
        raise NotImplementedError
    detector_results, pan_detector_results = self.roi_heads(
        images, features, proposals)
    sem_seg_results, _ = self.sem_seg_head(features)
    pan_seg_results, _ = self.panoptic_head(None, sem_seg_results,
                                            pan_detector_results)

    processed_results = []
    for sem_seg_result, detector_result, pan_seg_result, input_per_image, image_size in zip(
            sem_seg_results, detector_results, pan_seg_results,
            batched_inputs, images.image_sizes):
        processed_result = {}
        height = input_per_image.get("height")
        width = input_per_image.get("width")
        sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height,
                                        width)
        detector_r = detector_postprocess(detector_result, height, width)
        processed_result.update({
            "sem_seg": sem_seg_r,
            "instances": detector_r
        })
        if self.combine_on:
            panoptic_r = combine_semantic_and_instance_outputs(
                detector_r,
                sem_seg_r.argmax(dim=0),
                self.combine_overlap_threshold,
                self.combine_stuff_area_limit,
                self.combine_instances_confidence_threshold,
            )
        else:
            pan_pred = sem_seg_postprocess(pan_seg_result["pan_logit"],
                                           image_size, height, width)
            del pan_seg_result["pan_logit"]
            pan_seg_result["pan_pred"] = pan_pred.argmax(dim=0)
            panoptic_r = pan_seg_postprocess(pan_seg_result,
                                             sem_seg_r.argmax(dim=0),
                                             self.stuff_num_classes,
                                             self.stuff_area_limit)
        processed_result.update({"panoptic_seg": panoptic_r})
        processed_results.append(processed_result)
    return processed_results
def f(batched_inputs, c2_inputs, c2_results):
    image_sizes = [[int(im[0]), int(im[1])] for im in c2_inputs["im_info"]]
    detector_results = assemble_rcnn_outputs_by_name(
        image_sizes, c2_results, force_mask_on=True
    )
    sem_seg_results = c2_results["sem_seg"]

    # copied from meta_arch/panoptic_fpn.py ...
    processed_results = []
    for sem_seg_result, detector_result, input_per_image, image_size in zip(
        sem_seg_results, detector_results, batched_inputs, image_sizes
    ):
        height = input_per_image.get("height", image_size[0])
        width = input_per_image.get("width", image_size[1])
        sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
        detector_r = detector_postprocess(detector_result, height, width)
        processed_results.append({"sem_seg": sem_seg_r, "instances": detector_r})

        if combine_on:
            panoptic_r = combine_semantic_and_instance_outputs(
                detector_r,
                sem_seg_r.argmax(dim=0),
                combine_overlap_threshold,
                combine_stuff_area_limit,
                combine_instances_confidence_threshold,
            )
            processed_results[-1]["panoptic_seg"] = panoptic_r
    return processed_results
def __call__(
    self,
    batched_inputs: List[Dict[str, Any]],
    tensor_inputs: torch.Tensor,
    tensor_outputs: torch.Tensor,
) -> List[Dict[str, Any]]:
    """
    Rescales sem_seg logits to original image input resolution,
    and packages the logits into D2Go's expected output format.

    Args:
        batched_inputs (List[Dict[str, Tensor]]): batched inputs from the dataloader.
        tensor_inputs (Tensor): tensorized inputs, e.g. from `PreprocessFunc`.
        tensor_outputs (Tensor): sem seg logits tensor from the model to process.

    Returns:
        processed_results (List[Dict]): List of D2Go output dicts ready to be
            used downstream in an Evaluator, for export, etc.
    """
    results = tensor_outputs  # nchw

    processed_results = []
    for result, input_per_image in zip(results, batched_inputs):
        height = input_per_image.get("height")
        width = input_per_image.get("width")
        image_tensor_shape = input_per_image["image"].shape
        image_size = (image_tensor_shape[1], image_tensor_shape[2])

        # D2's sem_seg_postprocess rescales sem seg masks to the
        # provided original input resolution.
        r = sem_seg_postprocess(result, image_size, height, width)
        processed_results.append({"sem_seg": r})
    return processed_results
def __call__(self, inputs, tensor_inputs, tensor_outputs):
    results = tensor_outputs  # nchw
    processed_results = []
    for result, input_per_image in zip(results, inputs):
        height = input_per_image.get("height")
        width = input_per_image.get("width")
        image_tensor_shape = input_per_image["image"].shape
        image_size = (image_tensor_shape[1], image_tensor_shape[2])
        # Rescale the sem seg logits back to the original input resolution.
        r = sem_seg_postprocess(result, image_size, height, width)
        processed_results.append({"sem_seg": r})
    return processed_results
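# All of the snippets here call detectron2's `sem_seg_postprocess`. For
# reference, the following is a minimal sketch of what that helper does (an
# illustrative re-implementation, not the library source): crop away the
# ImageList padding, then bilinearly resize the logits back to the original
# image resolution.
import torch
import torch.nn.functional as F


def sem_seg_postprocess_sketch(result: torch.Tensor, img_size, output_height, output_width):
    """Crop padded (C, H_pad, W_pad) logits to `img_size`, then resize to
    (C, output_height, output_width)."""
    result = result[:, : img_size[0], : img_size[1]].unsqueeze(0)
    result = F.interpolate(
        result, size=(output_height, output_width), mode="bilinear", align_corners=False
    )
    return result.squeeze(0)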
def forward(self, batched_inputs):
    """
    Args:
        batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
            Each item in the list contains the inputs for one image.
            For now, each item in the list is a dict that contains:
                image: Tensor, image in (C, H, W) format.
                sem_seg: semantic segmentation ground truth
                Other information that's included in the original dicts, such as:
                    "height", "width" (int): the output resolution of the model,
                        used in inference. See :meth:`postprocess` for details.
    Returns:
        list[dict]: Each dict is the output for one input image.
            The dict contains one key "sem_seg" whose value is a
            Tensor of the output resolution that represents the
            per-pixel segmentation prediction.
    """
    images = [x["image"].to(self.device) for x in batched_inputs]
    images = [self.normalizer(x) for x in images]
    images = ImageList.from_tensors(images, self.backbone.size_divisibility)
    size = images.tensor.size()[-2:]
    features = self.backbone(images.tensor)

    if "sem_seg" in batched_inputs[0]:
        targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
        targets = ImageList.from_tensors(
            targets, self.backbone.size_divisibility,
            self.sem_seg_head.ignore_value).tensor
    else:
        targets = None
    results, losses = self.sem_seg_head(features, size, targets)

    if self.training:
        return losses

    processed_results = []
    for result, input_per_image, image_size in zip(results, batched_inputs,
                                                   images.image_sizes):
        height = input_per_image.get("height")
        width = input_per_image.get("width")
        r = sem_seg_postprocess(result, image_size, height, width)
        processed_results.append({"sem_seg": r})
    return processed_results
def forward(self, batched_inputs):
    # complete image
    images = [x["image"].to(self.device) for x in batched_inputs]
    images = [self.normalizer(x) for x in images]
    images = ImageList.from_tensors(images,
                                    self.inpaint_net.size_divisibility)

    # triplet input maps:
    # erased regions
    masks = [x["mask"].to(self.device) for x in batched_inputs]
    masks = ImageList.from_tensors(masks, self.inpaint_net.size_divisibility)
    # mask the input image with masks
    erased_ims = images.tensor * (1. - masks.tensor)
    # ones map
    ones_ims = [
        torch.ones_like(x["mask"].to(self.device)) for x in batched_inputs
    ]
    ones_ims = ImageList.from_tensors(ones_ims,
                                      self.inpaint_net.size_divisibility)
    # The conv layers use zero padding; the ones map marks the image boundary
    # for the generation process.
    input_tensor = torch.cat([erased_ims, ones_ims.tensor, masks.tensor],
                             dim=1)
    coarse_inp, fine_inp, offset_flow = self.inpaint_net(
        input_tensor, masks.tensor)  # offset_flow is used for visualization

    if self.training:
        raise NotImplementedError
    else:
        processed_results = []
        inpainted_im = erased_ims * (
            1. - masks.tensor) + fine_inp * masks.tensor
        for result, input_per_image, image_size in zip(
                inpainted_im, batched_inputs, images.image_sizes):
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            # Reuse the semantic segmentation postprocess; it simply resizes
            # the result back to the original resolution.
            r = sem_seg_postprocess(result, image_size, height, width)
            processed_results.append({"inpainted": r})
        return processed_results
def forward(self, batched_inputs: list):
    images = [x["image"].to(self.device) for x in batched_inputs]
    images = [self.normalizer(x) for x in images]
    images = ImageList.from_tensors(images, size_divisibility=16)
    preds = self.run_model(images.tensor)

    if self.training:
        targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
        targets = ImageList.from_tensors(targets,
                                         size_divisibility=16,
                                         pad_value=255).tensor
        return dict(loss_sem_seg=F.cross_entropy(
            preds, targets, reduction="mean", ignore_index=255))

    processed_preds = []
    for pred, input_per_image, image_size in zip(preds, batched_inputs,
                                                 images.image_sizes):
        height = input_per_image.get("height")
        width = input_per_image.get("width")
        r = sem_seg_postprocess(pred, image_size, height, width)
        processed_preds.append({"sem_seg": r})
    return processed_preds
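# A minimal usage sketch (hypothetical helper, not part of the snippets above)
# showing the input format these forward() methods expect at inference time and
# how the postprocessed output is consumed: "height"/"width" are the original
# resolution that sem_seg_postprocess resizes the padded logits back to.
import torch


def run_single_image_inference(model, image: torch.Tensor, orig_h: int, orig_w: int):
    # `model` is assumed to be one of the segmentors above, already constructed.
    batched_inputs = [{"image": image, "height": orig_h, "width": orig_w}]
    model.eval()
    with torch.no_grad():
        outputs = model(batched_inputs)
    # outputs[0]["sem_seg"] is a (num_classes, orig_h, orig_w) logits tensor;
    # argmax over the class dimension gives per-pixel class ids.
    return outputs[0]["sem_seg"].argmax(dim=0)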
def forward(self, batched_inputs):
    """
    Args:
        batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
            Each item in the list contains the inputs for one image.
            For now, each item in the list is a dict that contains:

            * "image": Tensor, image in (C, H, W) format.
            * "sem_seg": semantic segmentation ground truth
            * "center": center points heatmap ground truth
            * "offset": pixel offsets to center points ground truth
            * Other information that's included in the original dicts, such as:
              "height", "width" (int): the output resolution of the model (may be
              different from input resolution), used in inference.

    Returns:
        list[dict]:
            each dict is the results for one image. The dict contains the following keys:

            * "panoptic_seg", "sem_seg": see documentation
              :doc:`/tutorials/models` for the standard output format
            * "instances": available if ``predict_instances is True``. see documentation
              :doc:`/tutorials/models` for the standard output format
    """
    images = [x["image"].to(self.device) for x in batched_inputs]
    images = [(x - self.pixel_mean) / self.pixel_std for x in images]
    # To avoid error in ASPP layer when input has different size.
    size_divisibility = (
        self.size_divisibility
        if self.size_divisibility > 0
        else self.backbone.size_divisibility
    )
    images = ImageList.from_tensors(images, size_divisibility)

    features = self.backbone(images.tensor)

    losses = {}
    if "sem_seg" in batched_inputs[0]:
        targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
        targets = ImageList.from_tensors(
            targets, size_divisibility, self.sem_seg_head.ignore_value
        ).tensor
        if "sem_seg_weights" in batched_inputs[0]:
            # The default D2 DatasetMapper may not contain "sem_seg_weights"
            # Avoid error in testing when default DatasetMapper is used.
            weights = [x["sem_seg_weights"].to(self.device) for x in batched_inputs]
            weights = ImageList.from_tensors(weights, size_divisibility).tensor
        else:
            weights = None
    else:
        targets = None
        weights = None
    sem_seg_results, sem_seg_losses = self.sem_seg_head(features, targets, weights)
    losses.update(sem_seg_losses)

    if "center" in batched_inputs[0] and "offset" in batched_inputs[0]:
        center_targets = [x["center"].to(self.device) for x in batched_inputs]
        center_targets = ImageList.from_tensors(
            center_targets, size_divisibility
        ).tensor.unsqueeze(1)
        center_weights = [x["center_weights"].to(self.device) for x in batched_inputs]
        center_weights = ImageList.from_tensors(center_weights, size_divisibility).tensor

        offset_targets = [x["offset"].to(self.device) for x in batched_inputs]
        offset_targets = ImageList.from_tensors(offset_targets, size_divisibility).tensor
        offset_weights = [x["offset_weights"].to(self.device) for x in batched_inputs]
        offset_weights = ImageList.from_tensors(offset_weights, size_divisibility).tensor
    else:
        center_targets = None
        center_weights = None
        offset_targets = None
        offset_weights = None

    center_results, offset_results, center_losses, offset_losses = self.ins_embed_head(
        features, center_targets, center_weights, offset_targets, offset_weights
    )
    losses.update(center_losses)
    losses.update(offset_losses)

    if self.training:
        return losses

    if self.benchmark_network_speed:
        return []

    processed_results = []
    for sem_seg_result, center_result, offset_result, input_per_image, image_size in zip(
        sem_seg_results, center_results, offset_results, batched_inputs, images.image_sizes
    ):
        height = input_per_image.get("height")
        width = input_per_image.get("width")
        r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
        c = sem_seg_postprocess(center_result, image_size, height, width)
        o = sem_seg_postprocess(offset_result, image_size, height, width)
        # Post-processing to get panoptic segmentation.
        panoptic_image, _ = get_panoptic_segmentation(
            r.argmax(dim=0, keepdim=True),
            c,
            o,
            thing_ids=self.meta.thing_dataset_id_to_contiguous_id.values(),
            label_divisor=self.meta.label_divisor,
            stuff_area=self.stuff_area,
            void_label=-1,
            threshold=self.threshold,
            nms_kernel=self.nms_kernel,
            top_k=self.top_k,
        )
        # For semantic segmentation evaluation.
        processed_results.append({"sem_seg": r})
        panoptic_image = panoptic_image.squeeze(0)
        semantic_prob = F.softmax(r, dim=0)

        # Write results to disk:
        img = input_per_image["image"]
        from detectron2.utils.visualizer import Visualizer
        from detectron2.data.detection_utils import convert_image_to_rgb
        from PIL import Image
        import os
        img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format).astype("uint8")
        img = np.array(Image.fromarray(img).resize((width, height)))
        v_panoptic = Visualizer(img, self.meta)
        v_panoptic = v_panoptic.draw_panoptic_seg_predictions(panoptic_image.cpu(), None)
        pan_img = v_panoptic.get_image()
        image_path = input_per_image['file_name'].split(os.sep)
        image_name = os.path.splitext(image_path[-1])[0]
        Image.fromarray(pan_img).save(
            os.path.join(
                '/home/ahabbas/projects/conseg/affinityNet/output_pdl/coco/eval_vis',
                image_name + '_panoptic.png'))

        # For panoptic segmentation evaluation.
        processed_results[-1]["panoptic_seg"] = (panoptic_image, None)
        # For instance segmentation evaluation.
        if self.predict_instances:
            instances = []
            panoptic_image_cpu = panoptic_image.cpu().numpy()
            for panoptic_label in np.unique(panoptic_image_cpu):
                if panoptic_label == -1:
                    continue
                pred_class = panoptic_label // self.meta.label_divisor
                isthing = pred_class in list(
                    self.meta.thing_dataset_id_to_contiguous_id.values()
                )
                # Get instance segmentation results.
                if isthing:
                    instance = Instances((height, width))
                    # Evaluation code takes continuous id starting from 0
                    instance.pred_classes = torch.tensor(
                        [pred_class], device=panoptic_image.device
                    )
                    mask = panoptic_image == panoptic_label
                    instance.pred_masks = mask.unsqueeze(0)
                    # Average semantic probability
                    sem_scores = semantic_prob[pred_class, ...]
                    sem_scores = torch.mean(sem_scores[mask])
                    # Center point probability
                    mask_indices = torch.nonzero(mask).float()
                    center_y, center_x = (
                        torch.mean(mask_indices[:, 0]),
                        torch.mean(mask_indices[:, 1]),
                    )
                    center_scores = c[0, int(center_y.item()), int(center_x.item())]
                    # Confidence score is semantic prob * center prob.
                    instance.scores = torch.tensor(
                        [sem_scores * center_scores], device=panoptic_image.device
                    )
                    # Get bounding boxes
                    instance.pred_boxes = BitMasks(instance.pred_masks).get_bounding_boxes()
                    instances.append(instance)
            if len(instances) > 0:
                processed_results[-1]["instances"] = Instances.cat(instances)

    return processed_results
def forward(self, batched_inputs):
    """
    Args:
        batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
            Each item in the list contains the inputs for one image.
            For now, each item in the list is a dict that contains:

            * "image": Tensor, image in (C, H, W) format.
            * "instances": Instances
            * "sem_seg": semantic segmentation ground truth.
            * Other information that's included in the original dicts, such as:
              "height", "width" (int): the output resolution of the model, used in inference.
              See :meth:`postprocess` for details.

    Returns:
        list[dict]:
            each dict is the results for one image. The dict contains the following keys:

            * "instances": see :meth:`GeneralizedRCNN.forward` for its format.
            * "sem_seg": see :meth:`SemanticSegmentor.forward` for its format.
            * "panoptic_seg": available when `PANOPTIC_FPN.COMBINE.ENABLED`.
              See the return value of
              :func:`combine_semantic_and_instance_outputs` for its format.
    """
    image_path = [x['file_name'] for x in batched_inputs]
    if self.training:
        flips = [x['flip'] for x in batched_inputs]
    else:
        flips = None

    if self.training:
        exemplar_input = self.get_exemplar_input(image_path, sample_size=1)
        if exemplar_input is not None:
            l = len(batched_inputs)
            batched_inputs = batched_inputs + exemplar_input
            images, features, proposals, gt_instances, gt_integral_sem_seg, _, losses = self._forward(
                batched_inputs)
            exemplar_features = {}
            for k, v in features.items():
                exemplar_features[k] = v[l:]
            exemplar_gt_instances = gt_instances[l:]
            image_path = [x['file_name'] for x in batched_inputs]
            if self.training:
                exemplar_flips = [x['flip'] for x in batched_inputs]
            else:
                exemplar_flips = None
            with torch.no_grad():
                exemplar_info = self.roi_heads.get_box_features(
                    exemplar_features, exemplar_gt_instances)
            detector_results, detector_losses = self.roi_heads(
                images,
                features,
                proposals,
                gt_instances,
                gt_integral_sem_seg,
                image_path=image_path,
                flips=exemplar_flips,
                exemplar_info=exemplar_info)
            del exemplar_info, exemplar_input
        else:
            exemplar_info = None
            images, features, proposals, gt_instances, gt_integral_sem_seg, _, losses = self._forward(
                batched_inputs)
            detector_results, detector_losses = self.roi_heads(
                images,
                features,
                proposals,
                gt_instances,
                gt_integral_sem_seg,
                image_path=image_path,
                flips=flips,
                exemplar_info=exemplar_info)
    else:
        exemplar_info = None
        images, features, proposals, gt_instances, gt_integral_sem_seg, sem_seg_results, losses = self._forward(
            batched_inputs)
        detector_results, detector_losses = self.roi_heads(
            images,
            features,
            proposals,
            gt_instances,
            gt_integral_sem_seg,
            image_path=image_path,
            flips=flips,
            exemplar_info=exemplar_info)

    if self.training:
        losses.update({
            k: v * self.instance_loss_weight
            for k, v in detector_losses.items()
        })
        return losses

    processed_results = []
    for sem_seg_result, detector_result, input_per_image, image_size in zip(
            sem_seg_results, detector_results, batched_inputs,
            images.image_sizes):
        height = input_per_image.get("height", image_size[0])
        width = input_per_image.get("width", image_size[1])
        sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height,
                                        width)
        detector_r = detector_postprocess(detector_result, height, width)
        processed_results.append({
            "sem_seg": sem_seg_r,
            "instances": detector_r
        })
        if self.combine_on:
            panoptic_r = combine_semantic_and_instance_outputs(
                detector_r,
                sem_seg_r.argmax(dim=0),
                self.combine_overlap_threshold,
                self.combine_stuff_area_limit,
                self.combine_instances_confidence_threshold,
            )
            processed_results[-1]["panoptic_seg"] = panoptic_r
    return processed_results
def forward(self, batched_inputs):
    """
    Args:
        batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
            Each item in the list contains the inputs for one image.
            For now, each item in the list is a dict that contains:

            * "image": Tensor, image in (C, H, W) format.
            * "sem_seg": semantic segmentation ground truth
            * "center": center points heatmap ground truth
            * "offset": pixel offsets to center points ground truth
            * Other information that's included in the original dicts, such as:
              "height", "width" (int): the output resolution of the model (may be
              different from input resolution), used in inference.

    Returns:
        list[dict]:
            each dict is the results for one image. The dict contains the following keys:

            * "instances": see :meth:`GeneralizedRCNN.forward` for its format.
            * "sem_seg": see :meth:`SemanticSegmentor.forward` for its format.
            * "panoptic_seg": see :func:`combine_semantic_and_instance_outputs`
              for its format.
    """
    images = [x["image"].to(self.device) for x in batched_inputs]
    images = [(x - self.pixel_mean) / self.pixel_std for x in images]
    size_divisibility = self.backbone.size_divisibility
    images = ImageList.from_tensors(images, size_divisibility)

    features = self.backbone(images.tensor)

    losses = {}
    if "sem_seg" in batched_inputs[0]:
        targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
        targets = ImageList.from_tensors(
            targets, size_divisibility,
            self.sem_seg_head.ignore_value).tensor
        if "sem_seg_weights" in batched_inputs[0]:
            # The default D2 DatasetMapper may not contain "sem_seg_weights"
            # Avoid error in testing when default DatasetMapper is used.
            weights = [
                x["sem_seg_weights"].to(self.device) for x in batched_inputs
            ]
            weights = ImageList.from_tensors(weights, size_divisibility).tensor
        else:
            weights = None
    else:
        targets = None
        weights = None
    sem_seg_results, sem_seg_losses = self.sem_seg_head(
        features, targets, weights)
    losses.update(sem_seg_losses)

    if "center" in batched_inputs[0] and "offset" in batched_inputs[0]:
        center_targets = [
            x["center"].to(self.device) for x in batched_inputs
        ]
        center_targets = ImageList.from_tensors(
            center_targets, size_divisibility).tensor.unsqueeze(1)
        center_weights = [
            x["center_weights"].to(self.device) for x in batched_inputs
        ]
        center_weights = ImageList.from_tensors(center_weights,
                                                size_divisibility).tensor

        offset_targets = [
            x["offset"].to(self.device) for x in batched_inputs
        ]
        offset_targets = ImageList.from_tensors(offset_targets,
                                                size_divisibility).tensor
        offset_weights = [
            x["offset_weights"].to(self.device) for x in batched_inputs
        ]
        offset_weights = ImageList.from_tensors(offset_weights,
                                                size_divisibility).tensor
    else:
        center_targets = None
        center_weights = None
        offset_targets = None
        offset_weights = None

    center_results, offset_results, center_losses, offset_losses = self.ins_embed_head(
        features, center_targets, center_weights, offset_targets,
        offset_weights)
    losses.update(center_losses)
    losses.update(offset_losses)

    if self.training:
        return losses

    processed_results = []
    for sem_seg_result, center_result, offset_result, input_per_image, image_size in zip(
            sem_seg_results, center_results, offset_results, batched_inputs,
            images.image_sizes):
        height = input_per_image.get("height")
        width = input_per_image.get("width")
        r = sem_seg_postprocess(sem_seg_result, image_size, height, width)
        c = sem_seg_postprocess(center_result, image_size, height, width)
        o = sem_seg_postprocess(offset_result, image_size, height, width)
        # Post-processing to get panoptic segmentation.
        panoptic_image, _ = get_panoptic_segmentation(
            r.argmax(dim=0, keepdim=True),
            c,
            o,
            thing_ids=self.meta.thing_dataset_id_to_contiguous_id.values(),
            label_divisor=self.meta.label_divisor,
            stuff_area=self.stuff_area,
            void_label=-1,
            threshold=self.threshold,
            nms_kernel=self.nms_kernel,
            top_k=self.top_k,
        )
        # For semantic segmentation evaluation.
        processed_results.append({"sem_seg": r})
        panoptic_image = panoptic_image.squeeze(0)
        semantic_prob = F.softmax(r, dim=0)
        # For panoptic segmentation evaluation.
        processed_results[-1]["panoptic_seg"] = (panoptic_image, None)
        # For instance segmentation evaluation.
        if self.predict_instances:
            instances = []
            panoptic_image_cpu = panoptic_image.cpu().numpy()
            for panoptic_label in np.unique(panoptic_image_cpu):
                if panoptic_label == -1:
                    continue
                pred_class = panoptic_label // self.meta.label_divisor
                isthing = pred_class in list(
                    self.meta.thing_dataset_id_to_contiguous_id.values())
                # Get instance segmentation results.
                if isthing:
                    instance = Instances((height, width))
                    # Evaluation code takes continuous id starting from 0
                    instance.pred_classes = torch.tensor(
                        [pred_class], device=panoptic_image.device)
                    mask = panoptic_image == panoptic_label
                    instance.pred_masks = mask.unsqueeze(0)
                    # Average semantic probability
                    sem_scores = semantic_prob[pred_class, ...]
                    sem_scores = torch.mean(sem_scores[mask])
                    # Center point probability
                    mask_indices = torch.nonzero(mask).float()
                    center_y, center_x = (
                        torch.mean(mask_indices[:, 0]),
                        torch.mean(mask_indices[:, 1]),
                    )
                    center_scores = c[0, int(center_y.item()), int(center_x.item())]
                    # Confidence score is semantic prob * center prob.
                    instance.scores = torch.tensor(
                        [sem_scores * center_scores],
                        device=panoptic_image.device)
                    # Get bounding boxes
                    instance.pred_boxes = BitMasks(
                        instance.pred_masks).get_bounding_boxes()
                    instances.append(instance)
            if len(instances) > 0:
                processed_results[-1]["instances"] = Instances.cat(instances)

    return processed_results
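# The panoptic ids produced by get_panoptic_segmentation above pack the class
# and the instance index into a single integer via `label_divisor`. A small
# illustrative decode with hypothetical values (label_divisor=1000 is the
# common default, not taken from the snippets above):
label_divisor = 1000
panoptic_label = 17003                          # hypothetical packed id
pred_class = panoptic_label // label_divisor    # semantic class id -> 17
instance_id = panoptic_label % label_divisor    # instance index within that class -> 3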
def forward(self, batched_inputs):
    """
    Args:
        batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
            Each item in the list contains the inputs for one image.
            For now, each item in the list is a dict that contains:

            * "image": Tensor, image in (C, H, W) format.
            * "instances": Instances
            * "sem_seg": semantic segmentation ground truth.
            * Other information that's included in the original dicts, such as:
              "height", "width" (int): the output resolution of the model, used in inference.
              See :meth:`postprocess` for details.

    Returns:
        list[dict]:
            each dict is the results for one image. The dict contains the following keys:

            * "instances": see :meth:`GeneralizedRCNN.forward` for its format.
            * "sem_seg": see :meth:`SemanticSegmentor.forward` for its format.
            * "panoptic_seg": available when `PANOPTIC_FPN.COMBINE.ENABLED`.
              See the return value of
              :func:`combine_semantic_and_instance_outputs` for its format.
    """
    images = [x["image"].to(self.device) for x in batched_inputs]
    images = [(x - self.pixel_mean) / self.pixel_std for x in images]
    images = ImageList.from_tensors(images, self.backbone.size_divisibility)
    features = self.backbone(images.tensor)

    if "proposals" in batched_inputs[0]:
        proposals = [
            x["proposals"].to(self.device) for x in batched_inputs
        ]
        proposal_losses = {}

    if "sem_seg" in batched_inputs[0]:
        gt_sem_seg = [x["sem_seg"].to(self.device) for x in batched_inputs]
        gt_sem_seg = ImageList.from_tensors(
            gt_sem_seg, self.backbone.size_divisibility,
            self.sem_seg_head.ignore_value).tensor
    else:
        gt_sem_seg = None
    sem_seg_results, sem_seg_losses = self.sem_seg_head(
        features, gt_sem_seg)

    if "instances" in batched_inputs[0]:
        gt_instances = [
            x["instances"].to(self.device) for x in batched_inputs
        ]
    else:
        gt_instances = None
    if self.proposal_generator:
        proposals, proposal_losses = self.proposal_generator(
            images, features, gt_instances)
    detector_results, detector_losses = self.roi_heads(
        images, features, proposals, gt_instances)

    if self.training:
        losses = {}
        losses.update(sem_seg_losses)
        losses.update({
            k: v * self.instance_loss_weight
            for k, v in detector_losses.items()
        })
        losses.update(proposal_losses)
        return losses

    processed_results = []
    for sem_seg_result, detector_result, input_per_image, image_size in zip(
            sem_seg_results, detector_results, batched_inputs,
            images.image_sizes):
        height = input_per_image.get("height", image_size[0])
        width = input_per_image.get("width", image_size[1])
        sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height,
                                        width)
        detector_r = detector_postprocess(detector_result, height, width)
        processed_results.append({
            "sem_seg": sem_seg_r,
            "instances": detector_r
        })
        if self.combine_on:
            panoptic_r = combine_semantic_and_instance_outputs(
                detector_r,
                sem_seg_r.argmax(dim=0),
                self.combine_overlap_threshold,
                self.combine_stuff_area_limit,
                self.combine_instances_confidence_threshold,
            )
            processed_results[-1]["panoptic_seg"] = panoptic_r
    return processed_results
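# Several snippets above merge instance and semantic predictions with
# detectron2's `combine_semantic_and_instance_outputs`. The sketch below is a
# simplified illustration of that merge, not the library implementation: paint
# high-confidence instance masks onto an empty canvas in score order, skip
# masks that mostly overlap already-claimed pixels, then fill the remaining
# pixels with stuff regions whose area exceeds the stuff-area limit.
import torch


def combine_outputs_sketch(instance_masks, instance_scores, sem_seg,
                           overlap_threshold=0.5, stuff_area_limit=4096,
                           score_threshold=0.5):
    # instance_masks: (N, H, W) bool; instance_scores: (N,); sem_seg: (H, W) class ids.
    panoptic = torch.zeros_like(sem_seg, dtype=torch.int32)
    segment_id = 0
    for idx in torch.argsort(instance_scores, descending=True):
        if instance_scores[idx] < score_threshold:
            break
        mask = instance_masks[idx]
        area = mask.sum().item()
        if area == 0:
            continue
        intersect = (mask & (panoptic > 0)).sum().item()
        if intersect / area > overlap_threshold:
            continue  # mostly covered by higher-scoring instances
        segment_id += 1
        panoptic[mask & (panoptic == 0)] = segment_id
    # Fill unclaimed pixels with sufficiently large stuff regions.
    for label in torch.unique(sem_seg).tolist():
        mask = (sem_seg == label) & (panoptic == 0)
        if mask.sum().item() < stuff_area_limit:
            continue
        segment_id += 1
        panoptic[mask] = segment_id
    return panoptic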
def forward(self, batched_inputs):
    # complete image
    images = [x["image"].to(self.device) for x in batched_inputs]
    images = [self.normalizer(x) for x in images]
    images = ImageList.from_tensors(images, self.generator.size_divisibility)

    # triplet input maps:
    # erased regions
    masks = [x["mask"].to(self.device) for x in batched_inputs]
    masks = ImageList.from_tensors(masks, self.generator.size_divisibility)
    # mask the input image with masks
    erased_ims = images.tensor * (1. - masks.tensor)
    # ones map
    ones_ims = [
        torch.ones_like(x["mask"].to(self.device)) for x in batched_inputs
    ]
    ones_ims = ImageList.from_tensors(ones_ims,
                                      self.generator.size_divisibility)
    # The conv layers use zero padding; the ones map marks the image boundary
    # for the generation process.
    input_tensor = torch.cat([erased_ims, ones_ims.tensor, masks.tensor],
                             dim=1)
    coarse_inp, fine_inp, offset_flow = self.generator(
        input_tensor, masks.tensor)  # offset_flow is used for visualization

    if self.training:
        # reconstruction loss
        losses = {}
        losses["loss_coarse_rec"] = self.loss_rec_weight * torch.abs(
            images.tensor - coarse_inp).mean()
        losses["loss_fine_rec"] = self.loss_rec_weight * torch.abs(
            images.tensor - fine_inp).mean()

        # discriminator
        real_and_fake_ims = torch.cat([images.tensor, fine_inp], dim=0)
        real_and_fake_masks = torch.cat([masks.tensor, masks.tensor],
                                        dim=0)  # append masks
        disc_pred = self.discriminator(
            torch.cat([real_and_fake_ims, real_and_fake_masks], dim=1))
        pred_for_real, pred_for_fake = torch.split(disc_pred,
                                                   disc_pred.size(0) // 2,
                                                   dim=0)
        # TODO: perhaps configure the loss function
        g_loss, d_loss = self.get_discriminator_hinge_loss(
            pred_for_real, pred_for_fake)
        losses['loss_gen'] = self.loss_gan_weight * g_loss
        losses['loss_disc'] = d_loss

        losses["generator_loss"] = sum([
            losses[k]
            for k in ["loss_coarse_rec", "loss_fine_rec", "loss_gen"]
        ])
        losses["discriminator_loss"] = losses["loss_disc"]
        return losses
    else:
        processed_results = []
        inpainted_im = erased_ims * (
            1. - masks.tensor) + fine_inp * masks.tensor
        for result, input_per_image, image_size in zip(
                inpainted_im, batched_inputs, images.image_sizes):
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            # Reuse the semantic segmentation postprocess; it simply resizes
            # the result back to the original resolution.
            r = sem_seg_postprocess(result, image_size, height, width)
            processed_results.append({"inpainted": r})
        return processed_results
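# The snippet above defers to `get_discriminator_hinge_loss`, which is not
# shown. As an illustrative sketch only, the standard hinge GAN formulation
# that such a helper typically computes looks like this (not the snippet
# author's implementation):
import torch
import torch.nn.functional as F


def hinge_gan_losses(pred_for_real: torch.Tensor, pred_for_fake: torch.Tensor):
    # Discriminator: push real logits above +1 and fake logits below -1.
    d_loss = F.relu(1.0 - pred_for_real).mean() + F.relu(1.0 + pred_for_fake).mean()
    # Generator: raise the discriminator's score on generated samples.
    g_loss = -pred_for_fake.mean()
    return g_loss, d_loss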
def forward(self, batched_inputs):
    """
    Args:
        batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
            Each item in the list contains the inputs for one image.
            For now, each item in the list is a dict that contains:
                image: Tensor, image in (C, H, W) format.
                instances: Instances
                sem_seg: semantic segmentation ground truth.
                Other information that's included in the original dicts, such as:
                    "height", "width" (int): the output resolution of the model,
                        used in inference. See :meth:`postprocess` for details.
    Returns:
        list[dict]: each dict is the results for one image. The dict contains the
            following keys:
                "instances": see :meth:`GeneralizedRCNN.forward` for its format.
                "sem_seg": see :meth:`SemanticSegmentor.forward` for its format.
                "panoptic_seg": available when `PANOPTIC_FPN.COMBINE.ENABLED`.
                    See the return value of
                    :func:`combine_semantic_and_instance_outputs` for its format.
    """
    images = [x["image"].to(self.device) for x in batched_inputs]
    images = [self.normalizer(x) for x in images]
    images = ImageList.from_tensors(images, self.backbone.size_divisibility)
    features = self.backbone(images.tensor)

    if self.combine_on:
        if "sem_seg" in batched_inputs[0]:
            gt_sem = [x["sem_seg"].to(self.device) for x in batched_inputs]
            gt_sem = ImageList.from_tensors(
                gt_sem, self.backbone.size_divisibility,
                self.panoptic_module.ignore_value).tensor
        else:
            gt_sem = None
        sem_seg_results, sem_seg_losses = self.panoptic_module(
            features, gt_sem)

    if "basis_sem" in batched_inputs[0]:
        basis_sem = [
            x["basis_sem"].to(self.device) for x in batched_inputs
        ]
        basis_sem = ImageList.from_tensors(basis_sem,
                                           self.backbone.size_divisibility,
                                           0).tensor
    else:
        basis_sem = None
    basis_out, basis_losses = self.basis_module(features, basis_sem)

    if "instances" in batched_inputs[0]:
        gt_instances = [
            x["instances"].to(self.device) for x in batched_inputs
        ]
    else:
        gt_instances = None
    proposals, proposal_losses = self.proposal_generator(
        images, features, gt_instances, self.top_layer)
    detector_results, detector_losses = self.blender(
        basis_out["bases"], proposals, gt_instances)

    if self.training:
        losses = {}
        losses.update(basis_losses)
        losses.update({
            k: v * self.instance_loss_weight
            for k, v in detector_losses.items()
        })
        losses.update(proposal_losses)
        if self.combine_on:
            losses.update(sem_seg_losses)
        return losses

    processed_results = []
    for i, (detector_result, input_per_image, image_size) in enumerate(
            zip(detector_results, batched_inputs, images.image_sizes)):
        height = input_per_image.get("height", image_size[0])
        width = input_per_image.get("width", image_size[1])
        detector_r = detector_postprocess(detector_result, height, width)
        processed_result = {"instances": detector_r}
        if self.combine_on:
            sem_seg_r = sem_seg_postprocess(sem_seg_results[i], image_size,
                                            height, width)
            processed_result["sem_seg"] = sem_seg_r
        if "seg_thing_out" in basis_out:
            seg_thing_r = sem_seg_postprocess(basis_out["seg_thing_out"],
                                              image_size, height, width)
            processed_result["sem_thing_seg"] = seg_thing_r
        if self.basis_module.visualize:
            processed_result["bases"] = basis_out["bases"]
        processed_results.append(processed_result)

        if self.combine_on:
            panoptic_r = combine_semantic_and_instance_outputs(
                detector_r, sem_seg_r.argmax(dim=0),
                self.combine_overlap_threshold,
                self.combine_stuff_area_limit,
                self.combine_instances_confidence_threshold)
            processed_results[-1]["panoptic_seg"] = panoptic_r
    return processed_results