def forward(self, batched_inputs):
    """Run backbone, proposal generator and refinement head on a batch.

    Args:
        batched_inputs (list[dict]): one dict per image. ``"image"`` is required.
            ``"instances"`` (or the deprecated ``"targets"``) carries ground truth.
            ``"height"``/``"width"`` optionally set the output resolution.

    Returns:
        In training: the proposal-generator loss dict, updated with the
        refinement-head losses.
        In inference: one ``{"instances": Instances}`` dict per image, rescaled by
        ``detector_postprocess`` to ``height`` x ``width`` (defaulting to the
        network input size of that image).
    """
    images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs]
    images = ImageList.from_tensors(images, self.backbone.size_divisibility)
    features = self.backbone(images.tensor)

    if "instances" in batched_inputs[0]:
        gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
    elif "targets" in batched_inputs[0]:
        log_first_n(
            logging.WARN, "'targets' in the model inputs is now renamed to 'instances'!", n=10
        )
        gt_instances = [x["targets"].to(self.device) for x in batched_inputs]
    else:
        gt_instances = None

    proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)

    if not self.training and "instance" in self.gt_input:
        # Oracle testing: replace each image's proposals with its GT boxes and
        # classes (score 1), presumably to evaluate the refinement stage in isolation.
        assert gt_instances is not None
        refine_head = self.refinement_head.refine_head
        for im_i, gt_per_im in enumerate(gt_instances):
            bboxes = gt_per_im.gt_boxes.tensor
            oracle = Instances(proposals[im_i].image_size)
            oracle.pred_boxes = Boxes(bboxes)
            oracle.pred_classes = gt_per_im.gt_classes
            # NOTE(review): ones_like keeps the dtype of gt_classes, so these scores
            # are not floating point — confirm downstream consumers accept that.
            oracle.scores = torch.ones_like(gt_per_im.gt_classes, device=bboxes.device)
            if gt_per_im.has("gt_masks"):
                # Extreme points (top, left, bottom, right): one coordinate comes from
                # the mask polygons, the other from the matching box edge
                # (boxes are indexed as x0, y0, x1, y1).
                ext_pts_off = refine_head.get_simple_extreme_points(
                    gt_per_im.gt_masks.polygons
                ).to(bboxes.device)
                ex_t = torch.stack([ext_pts_off[:, None, 0], bboxes[:, None, 1]], dim=2)
                ex_l = torch.stack([bboxes[:, None, 0], ext_pts_off[:, None, 1]], dim=2)
                ex_b = torch.stack([ext_pts_off[:, None, 2], bboxes[:, None, 3]], dim=2)
                ex_r = torch.stack([bboxes[:, None, 2], ext_pts_off[:, None, 3]], dim=2)
                oracle.ext_points = ExtremePoints(torch.cat([ex_t, ex_l, ex_b, ex_r], dim=1))
            else:
                # No masks: use the box quadrangle from get_quadrangle as ext_points.
                quad = refine_head.get_quadrangle(bboxes).view(-1, 4, 2)
                oracle.ext_points = ExtremePoints(quad)
            proposals[im_i] = oracle

    # The refinement head runs in both modes. Its losses are used only in
    # training; its refined proposals are used only in inference.
    head_losses, proposals = self.refinement_head(features, proposals, gt_instances)

    if self.training:
        proposal_losses.update(head_losses)
        return proposal_losses

    processed_results = []
    for results_per_image, input_per_image, image_size in zip(
        proposals, batched_inputs, images.image_sizes
    ):
        height = input_per_image.get("height", image_size[0])
        width = input_per_image.get("width", image_size[1])
        processed_results.append(
            {"instances": detector_postprocess(results_per_image, height, width)}
        )
    return processed_results
def forward(self, features, pred_instances=None, targets=None):
    """Predict an edge map, derive an attention map and refine contours with the snake head.

    Args:
        features (dict[str, Tensor]): feature maps keyed by the names in
            ``self.in_features``.
        pred_instances (list[Instances] | None): per-image detections to refine
            (inference only).
        targets: in training, ``targets[0]`` is the edge target passed to
            ``self.loss`` and ``targets[1]`` is passed to ``self.refine_head``.
            In inference with ``"instance"`` in ``self.gt_input``,
            ``targets[1][0]`` holds the per-image ground-truth ``Instances``.

    Returns:
        Training: ``([], losses, [])``. ``losses`` holds the refine-head losses
        (empty if that head failed on this batch) plus ``"loss_edge_det"``.
        Inference: ``(att_map, {}, new_instances)``.
    """
    # Fuse the per-level heads by summation.
    for i, f in enumerate(self.in_features):
        if i == 0:
            x = self.scale_heads[i](features[f])
        else:
            x = x + self.scale_heads[i](features[f])

    pred_logits = self.predictor(x)
    pred_edge = pred_logits.sigmoid()
    att_map = self.attender(1 - pred_edge)  # regions that need evolution
    # The snake input is the same in both modes: attention map stacked on the fused features.
    snake_input = torch.cat([att_map, x], dim=1)

    if self.training:
        edge_target = targets[0]
        pred_edge_full = F.interpolate(
            pred_edge,
            scale_factor=self.common_stride,
            mode="bilinear",
            align_corners=False,
        )
        # Best-effort: batches with no polygon left after filtering make the refine
        # head fail (presumably on an empty torch.stack). Skip its losses for that
        # batch, but log the failure so that real errors (e.g. config errors) are
        # not hidden silently.
        try:
            _, poly_loss = self.refine_head(snake_input, None, targets[1])
        except Exception as err:
            import logging

            logging.getLogger(__name__).warning(
                "Skipping polygon losses for this batch: %s: %s", type(err).__name__, err
            )
            poly_loss = {}
        edge_loss = self.loss(pred_edge_full, edge_target) * self.loss_weight
        poly_loss.update({"loss_edge_det": edge_loss})
        return [], poly_loss, []

    if "instance" in self.gt_input:
        # Oracle testing: refine GT boxes/classes (score 1) instead of detections.
        gt_per_image = targets[1][0]
        assert gt_per_image is not None
        for im_i, gt_instances_per_im in enumerate(gt_per_image):
            bboxes = gt_instances_per_im.gt_boxes.tensor
            instances_per_im = Instances(pred_instances[im_i].image_size)
            instances_per_im.pred_boxes = Boxes(bboxes)
            instances_per_im.pred_classes = gt_instances_per_im.gt_classes
            instances_per_im.scores = torch.ones_like(
                gt_instances_per_im.gt_classes, device=bboxes.device
            )
            if gt_instances_per_im.has("gt_masks"):
                # Extreme points (top, left, bottom, right): one coordinate from the
                # mask polygons, the other from the matching box edge (x0, y0, x1, y1).
                ext_pts_off = self.refine_head.get_simple_extreme_points(
                    gt_instances_per_im.gt_masks.polygons
                ).to(bboxes.device)
                ex_t = torch.stack([ext_pts_off[:, None, 0], bboxes[:, None, 1]], dim=2)
                ex_l = torch.stack([bboxes[:, None, 0], ext_pts_off[:, None, 1]], dim=2)
                ex_b = torch.stack([ext_pts_off[:, None, 2], bboxes[:, None, 3]], dim=2)
                ex_r = torch.stack([bboxes[:, None, 2], ext_pts_off[:, None, 3]], dim=2)
                instances_per_im.ext_points = ExtremePoints(
                    torch.cat([ex_t, ex_l, ex_b, ex_r], dim=1)
                )
            pred_instances[im_i] = instances_per_im

    new_instances, _ = self.refine_head(snake_input, pred_instances, None)
    return att_map, {}, new_instances
def compute_targets_for_polys(self, targets):
    """Build per-instance snake inputs and regression targets for training.

    For every GT instance whose mask yields a contour this collects:
      * ``quadrangle_locs``: 40 points sampled along the box quadrangle (input of
        the extreme-point stage).
      * ``quadrangle_targets``: the instance's 4 extreme points
        (top, left, bottom, right), shape (4, 2).
      * ``octagon_locs``: ``self.num_sampling`` points sampled clockwise along the
        octagon built from the extreme points (input of the evolution stage).
      * ``octagon_targets``: points on the clockwise GT contour, rotated so that
        the first one is nearest to the first octagon point.
      * ``scales``: ``min(w, h)`` of the GT box.
      * ``image_idx``: index of the source image in ``targets``.
    Each entry is stacked along a new leading instance dimension.

    Args:
        targets (list[Instances]): per-image GT with ``gt_boxes`` (indexed as
            x0, y0, x1, y1) and ``gt_masks``.

    Returns:
        dict[str, Tensor] with the keys listed above.

    Raises:
        ValueError: if ``self.initial`` is not ``"octagon"``.
        RuntimeError: raised by ``torch.stack`` when no instance in the batch
            yields a contour.
    """
    if self.initial != "octagon":
        raise ValueError("Invalid initial input!")

    init_sample_locations = []
    init_sample_targets = []
    poly_sample_locations = []
    poly_sample_targets = []
    image_index = []
    scales = []

    def to_tensor(points, device):
        # The [::-1] flips below create negatively-strided numpy views, which torch
        # may refuse to wrap; hand it a contiguous copy instead.
        return torch.tensor(np.ascontiguousarray(points), device=device)

    for im_i, targets_per_im in enumerate(targets):
        bboxes = targets_per_im.gt_boxes.tensor
        if bboxes.numel() == 0:  # no GT in this image
            continue
        gt_masks = targets_per_im.gt_masks

        # Box sizes; min(w, h) becomes the per-instance normalization scale.
        ws = bboxes[:, 2] - bboxes[:, 0]
        hs = bboxes[:, 3] - bboxes[:, 1]

        quadrangles = self.get_quadrangle(bboxes).cpu().numpy().reshape(-1, 4, 2)  # (k, 4, 2)

        # Extreme-point offsets from the polygons: [t_H_off, l_V_off, b_H_off, r_V_off].
        # The other coordinate of each extreme point is the matching box edge.
        ext_pts_off = self.get_simple_extreme_points(gt_masks.polygons).to(bboxes.device)
        ex_t = torch.stack([ext_pts_off[:, None, 0], bboxes[:, None, 1]], dim=2)
        ex_l = torch.stack([bboxes[:, None, 0], ext_pts_off[:, None, 1]], dim=2)
        ex_b = torch.stack([ext_pts_off[:, None, 2], bboxes[:, None, 3]], dim=2)
        ex_r = torch.stack([bboxes[:, None, 2], ext_pts_off[:, None, 3]], dim=2)
        ext_points = torch.cat([ex_t, ex_l, ex_b, ex_r], dim=1)  # (k, 4, 2)
        octagons = (
            ExtremePoints(torch.cat([ex_t, ex_l, ex_b, ex_r], dim=1))
            .get_octagons()
            .cpu()
            .numpy()
            .reshape(-1, 8, 2)
        )  # (k, 8, 2)

        # Per instance: a (P, 2) numpy contour, or None when unusable.
        contours = self.get_simple_contour(gt_masks)

        for quad, octagon, cnt, ext, w, h in zip(
            quadrangles, octagons, contours, ext_points, ws, hs
        ):
            if cnt is None:
                continue
            scale = torch.min(w, h)

            # Both the GT contour and the octagon samples must run clockwise so that
            # their point orders agree.
            if Polygon(cnt).exterior.is_ccw:
                cnt = cnt[::-1]
            assert not Polygon(cnt).exterior.is_ccw, "1) contour must be clock-wise!"

            quad_sampled_pts = self.uniform_sample(quad, 40)

            oct_sampled_pts = self.uniform_sample(octagon, self.num_sampling)
            if Polygon(oct_sampled_pts).exterior.is_ccw:
                oct_sampled_pts = oct_sampled_pts[::-1]
            assert not Polygon(oct_sampled_pts).exterior.is_ccw, "1) contour must be clock-wise!"

            # Densely sample the GT contour (len(cnt) * num_sampling points) and rotate
            # it to start at the sample nearest the first octagon point. Then keep every
            # len(cnt)-th sample, presumably leaving num_sampling targets that pair with
            # the octagon points.
            dense_targets = self.uniform_sample(cnt, len(cnt) * self.num_sampling)
            tt_idx = np.argmin(np.power(dense_targets - oct_sampled_pts[0], 2).sum(axis=1))
            oct_sampled_targets = np.roll(dense_targets, -tt_idx, axis=0)[::len(cnt)]

            init_sample_locations.append(to_tensor(quad_sampled_pts, bboxes.device))
            init_sample_targets.append(ext)
            poly_sample_locations.append(to_tensor(oct_sampled_pts, bboxes.device))
            poly_sample_targets.append(to_tensor(oct_sampled_targets, bboxes.device))
            image_index.append(im_i)
            scales.append(scale)

    init_sample_locations = torch.stack(init_sample_locations, dim=0)
    init_sample_targets = torch.stack(init_sample_targets, dim=0)
    poly_sample_locations = torch.stack(poly_sample_locations, dim=0)
    poly_sample_targets = torch.stack(poly_sample_targets, dim=0)
    image_index = torch.tensor(image_index, device=bboxes.device)
    scales = torch.stack(scales, dim=0)
    return {
        "quadrangle_locs": init_sample_locations,
        "quadrangle_targets": init_sample_targets,
        "octagon_locs": poly_sample_locations,
        "octagon_targets": poly_sample_targets,
        "scales": scales,
        "image_idx": image_index,
    }
def forward(self, features, pred_instances=None, targets=None):
    """Two-stage snake refinement: predict extreme points, then evolve contours.

    Args:
        features (Tensor): snake input feature map. It is indexed as
            ``features.shape[2]`` / ``[3]`` (height/width) below, presumably
            (N, C, H, W) — TODO confirm.
        pred_instances (list[Instances] | None): per-image detections (inference only).
        targets (list[Instances] | None): per-image GT, converted by
            ``self.compute_targets_for_polys`` (training only).

    Returns:
        Training: ``([], {"loss_evolve": ..., "loss_init": ...})``.
        Inference: ``(new_instances, {})`` from ``self.predict_postprocess``, or
        ``(pred_instances, {})`` unchanged when no quadrangles were sampled.
    """
    if self.training:
        training_targets = self.compute_targets_for_polys(targets)
        locations = training_targets["octagon_locs"]
        reg_targets = training_targets["octagon_targets"]
        scales = training_targets["scales"]
        image_idx = training_targets["image_idx"]
        init_locations = training_targets["quadrangle_locs"]
        init_targets = training_targets["quadrangle_targets"]
    else:
        assert pred_instances is not None
        init_locations, image_idx = self.sample_quadrangles_fast(pred_instances)
        if len(init_locations) == 0:
            return pred_instances, {}

    # enhance bottom features TODO: maybe reduce later
    for i in range(self.num_convs):
        features = self.bottom_out[i](features)

    # Stage 1: extreme points predicted from the quadrangle samples.
    pred_exts = self.init(self.init_snake, features, init_locations, image_idx)

    if not self.training:
        # NOTE(review): the hard-coded 4 assumes `features` has stride 4 w.r.t. the
        # input image — confirm.
        h = features.shape[2] * 4
        w = features.shape[3] * 4
        poly_sample_locations = []
        for i, instance_per_im in enumerate(pred_instances):
            if instance_per_im.has("ext_points"):
                # GT extreme points were injected upstream (oracle testing); use as-is.
                exts = instance_per_im.ext_points.tensor
            else:
                # Boolean indexing returns a copy, so clamping does not touch pred_exts.
                exts = pred_exts[image_idx == i]  # N x 4 x 2
                exts[..., 0] = torch.clamp(exts[..., 0], min=0, max=w - 1)
                exts[..., 1] = torch.clamp(exts[..., 1], min=0, max=h - 1)
                instance_per_im.ext_points = ExtremePoints(exts)
            poly_sample_locations.append(self.get_octagon(exts, self.num_sampling))
        locations = cat(poly_sample_locations, dim=0)

    # Stage 2: each deformer evolves the output of the previous one.
    location_preds = []
    pred_location = locations
    for i in range(len(self.num_iter)):
        deformer = getattr(self, "deformer" + str(i))
        pred_location = self.evolve(deformer, features, pred_location, image_idx)
        location_preds.append(pred_location)

    if self.training:
        norm = scales[:, None, None]
        # NOTE(review): each stage's loss is divided by a fixed 3, which is an
        # average only when len(self.num_iter) == 3 — confirm that is intended.
        evolve_loss = 0
        for pred in location_preds:
            evolve_loss += self.loss_reg(pred / norm, reg_targets / norm) / 3
        init_loss = self.loss_reg(pred_exts / norm, init_targets / norm)
        losses = {
            "loss_evolve": evolve_loss * self.refine_loss_weight,
            "loss_init": init_loss * self.refine_loss_weight,
        }
        return [], losses

    new_instances = self.predict_postprocess(
        pred_instances, locations, location_preds, image_idx
    )
    return new_instances, {}
def forward(self, features, pred_instances=None, targets=None):
    """Edge-guided snake head: optionally predict an edge map, build the snake
    input from features plus edge/attention priors, and run ``self.refine_head``.

    Flags read from ``self``:
      * ``edge_on``: fuse ``self.scale_heads`` into ``x``, predict ``pred_edge``
        and, in training, add ``"loss_edge_det"``.
      * ``attention``: ``att_map = self.attender(1 - pred_edge)`` is stacked on
        the snake features (overriding the other priors).
      * ``edge_in``: in training, stack a mean-filtered, downsampled, binarized GT
        edge map on the features.
      * ``selective_refine``: in training, stack a dilated, binarized, downsampled
        GT edge map on the features.
      * ``strong_feat``: use the fused map ``x`` instead of ``features["p2"]``.
      * ``pred_edge``: (selective_refine, not strong_feat) also stack ``pred_logits``.
      * ``gt_input``: in inference, ``"edge"`` / ``"instance"`` substitute GT edges /
        GT instances for the predicted ones.

    Args:
        features (dict[str, Tensor]): feature maps keyed by name (``self.in_features``
            and ``"p2"`` are read).
        pred_instances (list[Instances] | None): per-image detections (inference).
        targets: ``targets[0]`` is the GT edge map, ``targets[1]`` is passed to
            the refine head in training; ``targets[1][0]`` holds per-image GT
            ``Instances`` for oracle inference.

    Returns:
        Training: ``([], losses, [])``.
        Inference: ``(edge_map, {}, new_instances)``. ``edge_map`` is ``att_map``
        when ``attention`` is on, a random 1x1x5x5 placeholder when ``edge_on`` is
        off, and otherwise ``pred_edge`` (possibly thresholded or replaced by GT).
    """
    if self.edge_on:
        with timer.env("pfpn_back"):
            # Fuse the per-level heads by summation.
            for i, f in enumerate(self.in_features):
                if i == 0:
                    x = self.scale_heads[i](features[f])
                else:
                    x = x + self.scale_heads[i](features[f])

    if self.edge_on:
        with timer.env("edge"):
            pred_logits = self.predictor(x)
            pred_edge = pred_logits.sigmoid()
            if self.attention:
                att_map = self.attender(
                    1 - pred_edge
                )  # regions that need evolution

    # NOTE(review): when edge_on is False, `x`, `pred_logits`, `pred_edge` and
    # `att_map` are never defined, so the strong_feat / pred_edge / attention /
    # edge_map_thre paths below would raise NameError. These flags presumably
    # require edge_on — confirm.
    if self.training:
        edge_target = targets[0]
        if self.edge_in:
            # GT edge prior: drop ignore pixels, smooth, downsample by
            # common_stride, binarize.
            edge_prior = targets[0].unsqueeze(1).float().clone()  # (B, 1, H, W)
            edge_prior[edge_prior == self.ignore_value] = 0  # remove ignore value
            edge_prior = self.mean_filter(edge_prior)
            edge_prior = F.interpolate(
                edge_prior,
                scale_factor=1 / self.common_stride,
                mode="bilinear",
                align_corners=False,
            )
            edge_prior[edge_prior > 0] = 1
            if self.strong_feat:
                snake_input = torch.cat([edge_prior, x], dim=1)
            else:
                snake_input = torch.cat([edge_prior, features["p2"]], dim=1)
        else:
            if self.strong_feat:
                snake_input = x
            else:
                snake_input = features["p2"]

        if self.edge_on:
            # Upsample the edge prediction by common_stride, presumably to the
            # resolution of the GT edge map it is compared with below.
            pred_edge_full = F.interpolate(
                pred_edge,
                scale_factor=self.common_stride,
                mode="bilinear",
                align_corners=False,
            )

        if self.selective_refine:
            # GT edge prior: drop ignore pixels, dilate, binarize, downsample.
            # This overrides any snake_input chosen above.
            edge_prior = targets[0].unsqueeze(1).float().clone()  # (B, 1, H, W)
            edge_prior[edge_prior == self.ignore_value] = 0  # remove ignore value
            edge_prior = self.dilate_filter(edge_prior)
            edge_prior[edge_prior > 0] = 1
            edge_prior = F.interpolate(
                edge_prior,
                scale_factor=1 / self.common_stride,
                mode="bilinear",
                align_corners=False,
            )
            if self.strong_feat:
                snake_input = torch.cat([edge_prior, x], dim=1)
            else:
                if self.pred_edge:
                    snake_input = torch.cat(
                        [edge_prior, pred_logits, features["p2"]], dim=1
                    )
                else:
                    snake_input = torch.cat([edge_prior, features["p2"]], dim=1)

        if self.attention:
            # Attention replaces any edge prior chosen above.
            if self.strong_feat:
                snake_input = torch.cat([att_map, x], dim=1)
            else:
                # the pred_edge option is not supported here yet
                snake_input = torch.cat([att_map, features["p2"]], dim=1)

        # NOTE(review): there is no guard here, so a batch with no polygon left after
        # filtering makes refine_head raise. An earlier comment mentioned a
        # "quick fix" for this that is not present — confirm the intended handling.
        _, poly_loss = self.refine_head(snake_input, None, targets[1])

        if self.edge_on:
            edge_loss = self.loss(pred_edge_full, edge_target) * self.loss_weight
            poly_loss.update(
                {
                    "loss_edge_det": edge_loss,
                }
            )

        return [], poly_loss, []
    else:
        if self.edge_in or self.selective_refine:
            if self.edge_map_thre > 0:
                # Binarize the predicted edge probabilities.
                pred_edge = (pred_edge > self.edge_map_thre).float()

            if "edge" in self.gt_input:
                # Oracle testing: use the GT edge map, processed like in training.
                assert targets[0] is not None
                pred_edge = targets[0].unsqueeze(1).float().clone()
                pred_edge[pred_edge == self.ignore_value] = 0  # remove ignore value

                if self.selective_refine:
                    pred_edge = self.dilate_filter(pred_edge)

                pred_edge = F.interpolate(
                    pred_edge,
                    scale_factor=1 / self.common_stride,
                    mode="bilinear",
                    align_corners=False,
                )

                pred_edge[pred_edge > 0] = 1

            if self.strong_feat:
                snake_input = torch.cat([pred_edge, x], dim=1)
            else:
                snake_input = torch.cat([pred_edge, features["p2"]], dim=1)
        else:
            if self.strong_feat:
                snake_input = x
            else:
                snake_input = features["p2"]

        if self.attention:
            if self.strong_feat:
                snake_input = torch.cat([att_map, x], dim=1)
            else:
                # the pred_edge option is not supported here yet
                snake_input = torch.cat([att_map, features["p2"]], dim=1)

        if "instance" in self.gt_input:
            # Oracle testing: refine GT boxes/classes (score 1) instead of detections.
            assert targets[1][0] is not None

            for im_i in range(len(targets[1][0])):
                gt_instances_per_im = targets[1][0][im_i]
                bboxes = gt_instances_per_im.gt_boxes.tensor
                # NOTE(review): `_image_size` is private; detectron2's Instances
                # exposes the same value as the public `image_size` property.
                instances_per_im = Instances(pred_instances[im_i]._image_size)
                instances_per_im.pred_boxes = Boxes(bboxes)
                instances_per_im.pred_classes = gt_instances_per_im.gt_classes
                instances_per_im.scores = torch.ones_like(
                    gt_instances_per_im.gt_classes, device=bboxes.device
                )
                if gt_instances_per_im.has("gt_masks"):
                    # Extreme points (top, left, bottom, right): one coordinate from
                    # the mask polygons, the other from the matching box edge.
                    gt_masks = gt_instances_per_im.gt_masks
                    ext_pts_off = self.refine_head.get_simple_extreme_points(
                        gt_masks.polygons
                    ).to(bboxes.device)
                    ex_t = torch.stack(
                        [ext_pts_off[:, None, 0], bboxes[:, None, 1]], dim=2
                    )
                    ex_l = torch.stack(
                        [bboxes[:, None, 0], ext_pts_off[:, None, 1]], dim=2
                    )
                    ex_b = torch.stack(
                        [ext_pts_off[:, None, 2], bboxes[:, None, 3]], dim=2
                    )
                    ex_r = torch.stack(
                        [bboxes[:, None, 2], ext_pts_off[:, None, 3]], dim=2
                    )
                    instances_per_im.ext_points = ExtremePoints(
                        torch.cat([ex_t, ex_l, ex_b, ex_r], dim=1)
                    )

                    # An experiment that fed GT contours as pred_polys ("theoretic
                    # limit" test) used to live here as commented-out code.

                pred_instances[im_i] = instances_per_im

        new_instances, _ = self.refine_head(snake_input, pred_instances, None)

        # NOTE(review): without edge_on there is no edge map, so random data is
        # returned as a placeholder; callers presumably ignore it in that case.
        if not self.edge_on:
            pred_edge = torch.rand(1, 1, 5, 5, device=snake_input.device)

        if self.attention:
            pred_edge = att_map

        return pred_edge, {}, new_instances