def new_encode_decode_encoder_encoder(img, img_metas, extract_layer_ids=[],
                                      use_resize=True):
    """Encode images with backbone and decode into a semantic segmentation
    map of the same size as input.

    ``obj`` is the segmentor instance captured from the enclosing scope.
    """
    x = obj.extract_feat(img, extract_layer_ids=extract_layer_ids)
    if len(extract_layer_ids) > 0:
        x, feats = x
    out = obj._decode_head_forward_test(x, img_metas)
    if use_resize:
        out = resize(
            input=out,
            size=img.shape[2:],
            mode='bilinear',
            align_corners=obj.align_corners)
    if hasattr(obj, 'auxiliary_head'):
        out2 = obj._auxiliary_head_forward_test(x, img_metas)
        out2 = resize(
            input=out2,
            size=img.shape[2:],
            mode='bilinear',
            align_corners=obj.align_corners)
        if len(extract_layer_ids) > 0:
            return out, out2, feats
        return out, out2, None
    if len(extract_layer_ids) > 0:
        return out, feats
    return out, None
def forward(self, x):
    output = []

    # sub 1
    output.append(self.conv_sub1(x))

    # sub 2
    x = resize(
        x,
        scale_factor=0.5,
        mode='bilinear',
        align_corners=self.align_corners)
    x = self.backbone.stem(x)
    x = self.backbone.maxpool(x)
    x = self.backbone.layer1(x)
    x = self.backbone.layer2(x)
    output.append(self.conv_sub2(x))

    # sub 4
    x = resize(
        x,
        scale_factor=0.5,
        mode='bilinear',
        align_corners=self.align_corners)
    x = self.backbone.layer3(x)
    x = self.backbone.layer4(x)
    psp_outs = self.psp_modules(x) + [x]
    psp_outs = torch.cat(psp_outs, dim=1)
    x = self.psp_bottleneck(psp_outs)
    output.append(self.conv_sub4(x))

    return output
def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) # print(x.size()) aspp_outs = [ resize( self.image_pool(x), size=x.size()[2:], mode='bilinear', align_corners=self.align_corners) ] aspp_outs.extend(self.aspp_modules(x)) aspp_outs = torch.cat(aspp_outs, dim=1) output = self.bottleneck(aspp_outs) if self.c1_bottleneck is not None: c1_output = self.c1_bottleneck(inputs[0]) output = resize( input=output, size=c1_output.shape[2:], mode='bilinear', align_corners=self.align_corners) output = torch.cat([output, c1_output], dim=1) output = self.sep_bottleneck(output) output = self.cls_seg(output) # print(a) return output
def forward(self, x):
    outs = list(self.backbone(x))
    avg = F.adaptive_avg_pool2d(outs[-1], 1)
    avg_feat = self.conv_avg(avg)
    feature_up = resize(
        avg_feat,
        size=outs[-1].shape[2:],
        mode=self.upsample_mode,
        align_corners=self.align_corners)
    arms_out = []
    for i in range(len(self.arms)):
        x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up
        feature_up = resize(
            x_arm,
            size=outs[len(outs) - 1 - i - 1].shape[2:],
            mode=self.upsample_mode,
            align_corners=self.align_corners)
        feature_up = self.convs[i](feature_up)
        arms_out.append(feature_up)
    feat_fuse = self.ffm(outs[0], arms_out[1])

    # `outputs` contains four feature maps:
    # `outs[0]` feeds the `STDCHead` auxiliary head,
    # the two feature maps of `arms_out` feed auxiliary heads, and
    # `feat_fuse` feeds the decoder head.
    outputs = [outs[0]] + list(arms_out) + [feat_fuse]
    return tuple(outputs)
def forward(self, inputs):
    assert len(inputs) == len(self.in_channels)

    # build laterals
    laterals = [
        lateral_conv(inputs[i + self.start_level])
        for i, lateral_conv in enumerate(self.lateral_convs)
    ]

    # build top-down path
    used_backbone_levels = len(laterals)
    for i in range(used_backbone_levels - 1, 0, -1):
        # In some cases, fixing `scale_factor` (e.g. 2) is preferred, but
        # it cannot co-exist with `size` in `F.interpolate`.
        if 'scale_factor' in self.upsample_cfg:
            laterals[i - 1] = laterals[i - 1] + resize(
                laterals[i], **self.upsample_cfg)
        else:
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + resize(
                laterals[i], size=prev_shape, **self.upsample_cfg)

    # build outputs
    # part 1: from original levels
    outs = [
        self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
    ]
    # part 2: add extra levels
    if self.num_outs > len(outs):
        # use max pool to get more levels on top of outputs
        # (e.g., Faster R-CNN, Mask R-CNN)
        if not self.add_extra_convs:
            for i in range(self.num_outs - used_backbone_levels):
                outs.append(F.max_pool2d(outs[-1], 1, stride=2))
        # add conv layers on top of original feature maps (RetinaNet)
        else:
            if self.add_extra_convs == 'on_input':
                extra_source = inputs[self.backbone_end_level - 1]
            elif self.add_extra_convs == 'on_lateral':
                extra_source = laterals[-1]
            elif self.add_extra_convs == 'on_output':
                extra_source = outs[-1]
            else:
                raise NotImplementedError
            outs.append(self.fpn_convs[used_backbone_levels](extra_source))
            for i in range(used_backbone_levels + 1, self.num_outs):
                if self.relu_before_extra_convs:
                    outs.append(self.fpn_convs[i](F.relu(outs[-1])))
                else:
                    outs.append(self.fpn_convs[i](outs[-1]))
    return tuple(outs)
def forward(self, x):
    x_4, x_8, x_16, x_32 = self.backbone(x)
    x_gap = self.gap_conv(x_32)

    x_32_arm = self.arm32(x_32)
    x_32_sum = x_32_arm + x_gap
    x_32_up = resize(input=x_32_sum, size=x_16.shape[2:], mode='nearest')
    x_32_up = self.conv_head32(x_32_up)

    x_16_arm = self.arm16(x_16)
    x_16_sum = x_16_arm + x_32_up
    x_16_up = resize(input=x_16_sum, size=x_8.shape[2:], mode='nearest')
    x_16_up = self.conv_head16(x_16_up)

    return x_16_up, x_32_up
def _forward_feature(self, inputs):
    """Forward function for feature maps before classifying each pixel with
    ``self.cls_seg`` fc.

    Args:
        inputs (list[Tensor]): List of multi-level img features.

    Returns:
        feats (Tensor): A tensor of shape (batch_size, self.channels, H, W),
            the feature map from the last layer of the decoder head.
    """
    inputs = self._transform_inputs(inputs)

    # build laterals
    laterals = [
        lateral_conv(inputs[i])
        for i, lateral_conv in enumerate(self.lateral_convs)
    ]
    laterals.append(self.psp_forward(inputs))

    # build top-down path
    used_backbone_levels = len(laterals)
    for i in range(used_backbone_levels - 1, 0, -1):
        prev_shape = laterals[i - 1].shape[2:]
        laterals[i - 1] = laterals[i - 1] + resize(
            laterals[i],
            size=prev_shape,
            mode='bilinear',
            align_corners=self.align_corners)

    # build outputs
    fpn_outs = [
        self.fpn_convs[i](laterals[i])
        for i in range(used_backbone_levels - 1)
    ]
    # append psp feature
    fpn_outs.append(laterals[-1])

    for i in range(used_backbone_levels - 1, 0, -1):
        fpn_outs[i] = resize(
            fpn_outs[i],
            size=fpn_outs[0].shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
    fpn_outs = torch.cat(fpn_outs, dim=1)
    feats = self.fpn_bottleneck(fpn_outs)
    return feats
def contrastive_losses(self, seg_logits, gt_semantic_seg, seg_logits1,
                       seg_label, img_metas):
    """Compute pixel-wise contrastive loss."""
    loss = dict()
    seg_logit = resize(
        input=seg_logits,
        size=gt_semantic_seg.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners)
    for i, logit in enumerate(seg_logits1):
        seg_logits1[i] = resize(
            input=logit,
            size=img_metas[0]['img_shape'][:2],
            mode='bilinear',
            align_corners=self.align_corners)
    for i, label in enumerate(seg_label):
        seg_label[i] = resize(
            input=label,
            size=img_metas[0]['img_shape'][:2],
            mode='bilinear',
            align_corners=self.align_corners)
    if self.sampler is not None:
        seg_weight = self.sampler.sample(seg_logit, gt_semantic_seg)
    else:
        seg_weight = None
    gt_semantic_seg = gt_semantic_seg.squeeze(1)
    loss['loss_seg'] = self.loss_decode(
        seg_logits1,
        seg_label,
        seg_logit,
        gt_semantic_seg,
        img_metas,
        weight=seg_weight,
        ignore_index=self.ignore_index)
    loss['acc_seg'] = accuracy(seg_logit, gt_semantic_seg)
    return loss
def forward(self, inputs): """Forward function.""" assert len(inputs) == len(self.in_channels), 'Length of inputs must \ be the same with self.in_channels!' feats = [ self.conv_layers[i - self.start_level](inputs[i]) for i in range(self.start_level, self.backbone_end_level) ] h, w = feats[0].shape[2:] for i in range(1, len(feats)): feats[i] = resize(feats[i], size=(h, w), mode='bilinear', align_corners=self.align_corners) feat = torch.cat(feats, dim=1) concat_feat = torch.cat([ self.dilation_layers[i](feat) for i in range(len(self.dilations)) ], dim=1) outs = [] # Default: outs[2] is the output of JPU for decoder head, outs[1] is # the feature map from backbone for auxiliary head. Additionally, # outs[0] can also be used for auxiliary head. for i in range(self.start_level, self.backbone_end_level - 1): outs.append(inputs[i]) outs.append(concat_feat) return tuple(outs)
def resize_pos_embed(pos_embed, input_shape, pos_shape, mode):
    """Resize pos_embed weights.

    Resize pos_embed using bicubic interpolate method.

    Args:
        pos_embed (torch.Tensor): Position embedding weights.
        input_shape (tuple): Tuple for (downsampled input image height,
            downsampled input image width).
        pos_shape (tuple): The resolution of downsampled origin training
            image.
        mode (str): Algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``.

    Return:
        torch.Tensor: The resized pos_embed of shape [B, L_new, C]
    """
    assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
    pos_h, pos_w = pos_shape
    # split off the cls token, then restore the 2D layout of the grid tokens
    cls_token_weight = pos_embed[:, 0]
    pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
    pos_embed_weight = pos_embed_weight.reshape(
        1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
    pos_embed_weight = resize(
        pos_embed_weight, size=input_shape, align_corners=False, mode=mode)
    cls_token_weight = cls_token_weight.unsqueeze(1)
    pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
    pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
    return pos_embed
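# Hedged usage sketch for `resize_pos_embed` above: the shapes are made-up
# assumptions (a ViT-style embedding with one leading cls token), not values
# taken from any real checkpoint or config.
def _example_resize_pos_embed():
    # assumed: trained with a 32x32 token grid and embedding dim 768
    pos_embed = torch.randn(1, 1 + 32 * 32, 768)
    # resize for an input whose downsampled token grid is 30x40
    resized = resize_pos_embed(pos_embed, (30, 40), (32, 32), mode='bicubic')
    assert resized.shape == (1, 1 + 30 * 40, 768)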
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
                **kwargs) -> list:
    if not self.is_cuda_available:
        img = img.detach().cpu()
    elif self.device_id >= 0:
        img = img.cuda(self.device_id)
    device_type = img.device.type
    self.io_binding.bind_input(
        name='input',
        device_type=device_type,
        device_id=self.device_id,
        element_type=np.float32,
        shape=img.shape,
        buffer_ptr=img.data_ptr())
    self.sess.run_with_iobinding(self.io_binding)
    seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
    # 'whole' inference mode might support dynamic reshape
    ori_shape = img_meta[0]['ori_shape']
    if not (ori_shape[0] == seg_pred.shape[-2]
            and ori_shape[1] == seg_pred.shape[-1]):
        seg_pred = torch.from_numpy(seg_pred).float()
        seg_pred = resize(seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
        seg_pred = seg_pred.long().detach().cpu().numpy()
    seg_pred = seg_pred[0]
    seg_pred = list(seg_pred)
    return seg_pred
def losses(self, seg_logit, seg_label):
    """Compute segmentation loss."""
    loss = dict()
    seg_logit = resize(
        input=seg_logit,
        size=seg_label.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners)
    if self.sampler is not None:
        seg_weight = self.sampler.sample(seg_logit, seg_label)
    else:
        seg_weight = None
    seg_label = seg_label.squeeze(1)
    for loss_decode in self.loss_decode:
        if loss_decode.loss_name not in loss:
            loss[loss_decode.loss_name] = loss_decode(
                seg_logit,
                seg_label,
                weight=seg_weight,
                ignore_index=self.ignore_index)
        else:
            loss[loss_decode.loss_name] += loss_decode(
                seg_logit,
                seg_label,
                weight=seg_weight,
                ignore_index=self.ignore_index)
    loss['acc_seg'] = accuracy(seg_logit, seg_label)
    return loss
def _transform_inputs(self, inputs):
    """Transform inputs for decoder.

    Args:
        inputs (list[Tensor]): List of multi-level img features.

    Returns:
        Tensor: The transformed inputs
    """
    if self.input_transform == 'resize_concat':
        inputs = [inputs[i] for i in self.in_index]
        upsampled_inputs = [
            resize(
                input=x,
                size=inputs[0].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners) for x in inputs
        ]
        inputs = torch.cat(upsampled_inputs, dim=1)
    elif self.input_transform == 'multiple_select':
        inputs = [inputs[i] for i in self.in_index]
    else:
        inputs = inputs[self.in_index]
    return inputs
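# Hedged standalone sketch of the 'resize_concat' branch above, written with
# F.interpolate directly; the function name, the index choice and the
# align_corners value are illustrative assumptions, not part of the source.
def _example_resize_concat(feats, in_index=(0, 1)):
    selected = [feats[i] for i in in_index]
    # every selected map is brought to the first map's resolution, then the
    # maps are concatenated along the channel dimension
    upsampled = [
        F.interpolate(
            x,
            size=selected[0].shape[2:],
            mode='bilinear',
            align_corners=False) for x in selected
    ]
    return torch.cat(upsampled, dim=1)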
def forward(self, x): """Forward function.""" pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale) # [batch_size, channels, h, w] x = self.input_redu_conv(x) # [batch_size, channels, pool_scale, pool_scale] pooled_x = self.pooled_redu_conv(pooled_x) batch_size = x.size(0) # [batch_size, pool_scale * pool_scale, channels] pooled_x = pooled_x.view(batch_size, self.channels, -1).permute(0, 2, 1).contiguous() # [batch_size, h * w, pool_scale * pool_scale] affinity_matrix = self.gla(x + resize( self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:]) ).permute(0, 2, 3, 1).reshape( batch_size, -1, self.pool_scale**2) affinity_matrix = F.sigmoid(affinity_matrix) # [batch_size, h * w, channels] z_out = torch.matmul(affinity_matrix, pooled_x) # [batch_size, channels, h * w] z_out = z_out.permute(0, 2, 1).contiguous() # [batch_size, channels, h, w] z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3)) z_out = self.residual_conv(z_out) z_out = F.relu(z_out + x) if self.fusion: z_out = self.fusion_conv(z_out) return z_out
def forward(self, x_d, x_s):
    detail_dwconv = self.detail_dwconv(x_d)
    detail_down = self.detail_down(x_d)
    semantic_conv = self.semantic_conv(x_s)
    semantic_dwconv = self.semantic_dwconv(x_s)
    semantic_conv = resize(
        input=semantic_conv,
        size=detail_dwconv.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners)
    fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv)
    fuse_2 = detail_down * torch.sigmoid(semantic_dwconv)
    fuse_2 = resize(
        input=fuse_2,
        size=fuse_1.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners)
    output = self.conv(fuse_1 + fuse_2)
    return output
def forward(self, query_feats, key_feats):
    """Forward function."""
    context = super(ObjectAttentionBlock, self).forward(
        query_feats, key_feats)
    output = self.bottleneck(torch.cat([context, query_feats], dim=1))
    if self.query_downsample is not None:
        # `resize` needs an explicit target size; presumably the intent is
        # to upsample the output back to the resolution of `query_feats`.
        output = resize(output, size=query_feats.shape[2:])
    return output
def forward(self, *inputs):
    x = inputs[0]
    if len(inputs) == 2:
        if x.shape != inputs[1].shape:
            res = resize(
                inputs[1],
                size=(x.shape[2], x.shape[3]),
                mode='bilinear',
                align_corners=False)
        else:
            res = inputs[1]
        x = x + self.res_conv_unit1(res)
    x = self.res_conv_unit2(x)
    x = resize(
        x,
        scale_factor=2,
        mode='bilinear',
        align_corners=self.align_corners)
    x = self.project(x)
    return x
def encode_decode(self, img, img_metas):
    """Encode images with backbone and decode into a semantic segmentation
    map of the same size as input."""
    x = self.extract_feat(img)
    out = self._decode_head_forward_test(x, img_metas)
    out = resize(
        input=out,
        size=img.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners)
    return out
def forward(self, input):
    conv_out = self.conv(input)
    pool_out = self.pool(input)
    pool_out = resize(
        input=pool_out,
        size=conv_out.size()[2:],
        mode='bilinear',
        align_corners=False)
    output = torch.cat([conv_out, pool_out], 1)
    output = self.bn(output)
    output = self.act(output)
    return output
def forward(self, x): """Forward function.""" ppm_outs = [] for ppm in self: ppm_out = ppm(x) upsampled_ppm_out = resize(ppm_out, size=x.size()[2:], mode='bilinear', align_corners=self.align_corners) ppm_outs.append(upsampled_ppm_out) return ppm_outs
def slide_inference(self, img, img_meta, rescale):
    """Inference by sliding-window with overlap.

    If h_crop > h_img or w_crop > w_img, the small patch will be used to
    decode without padding.
    """
    h_stride, w_stride = self.test_cfg.stride
    h_crop, w_crop = self.test_cfg.crop_size
    batch_size, _, h_img, w_img = img.size()
    num_classes = self.num_classes
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    preds = [
        img.new_zeros((batch_size, num_classes, h_img, w_img)),
        img.new_zeros((batch_size, num_classes, h_img, w_img)),
        img.new_zeros((batch_size, num_classes, h_img, w_img))
    ]
    count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1 = h_idx * h_stride
            x1 = w_idx * w_stride
            y2 = min(y1 + h_crop, h_img)
            x2 = min(x1 + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = img[:, :, y1:y2, x1:x2]
            crop_seg_logits = self.encode_decode(crop_img, img_meta)
            for idx, crop_seg_logit in enumerate(crop_seg_logits):
                preds[idx] += F.pad(
                    crop_seg_logit,
                    (int(x1), int(preds[idx].shape[3] - x2), int(y1),
                     int(preds[idx].shape[2] - y2)))
            count_mat[:, :, y1:y2, x1:x2] += 1
    assert (count_mat == 0).sum() == 0
    if torch.onnx.is_in_onnx_export():
        # cast count_mat to constant while exporting to ONNX
        count_mat = torch.from_numpy(
            count_mat.cpu().detach().numpy()).to(device=img.device)
    preds = [pred / count_mat for pred in preds]
    if rescale:
        preds = [
            resize(
                pred,
                size=img_meta[0]['ori_shape'][:2],
                mode='bilinear',
                align_corners=self.align_corners,
                warning=False) for pred in preds
        ]
    # unlike the single-map variant, this returns a list of logit maps
    return preds
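# Hedged sketch of the window arithmetic used in `slide_inference` above; the
# default sizes are made-up illustration values. Windows near the right and
# bottom borders shift their origin back, so every crop keeps the full crop
# size whenever the image is large enough.
def _example_slide_grid(h_img=512, w_img=683, h_crop=512, w_crop=512,
                        h_stride=341, w_stride=341):
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    windows = []
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1, x1 = h_idx * h_stride, w_idx * w_stride
            y2, x2 = min(y1 + h_crop, h_img), min(x1 + w_crop, w_img)
            y1, x1 = max(y2 - h_crop, 0), max(x2 - w_crop, 0)
            windows.append((y1, y2, x1, x2))
    # e.g. with the defaults: [(0, 512, 0, 512), (0, 512, 171, 683)]
    return windows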
def head(self, x, inputs):
    aspp_outs = [
        resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
    ]
    aspp_outs.extend(self.aspp_modules(x))
    aspp_outs = torch.cat(aspp_outs, dim=1)
    output = self.bottleneck(aspp_outs)
    if self.c1_bottleneck is not None:
        c1_output = self.c1_bottleneck(inputs[0])
        output = resize(
            input=output,
            size=c1_output.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        output = torch.cat([output, c1_output], dim=1)
    output = self.sep_bottleneck(output)
    output = self.cls_seg(output)
    return output
def forward(self, higher_res_feature, lower_res_feature):
    lower_res_feature = resize(
        lower_res_feature,
        size=higher_res_feature.size()[2:],
        mode='bilinear',
        align_corners=self.align_corners)
    lower_res_feature = self.dwconv(lower_res_feature)
    lower_res_feature = self.conv_lower_res(lower_res_feature)
    higher_res_feature = self.conv_higher_res(higher_res_feature)
    out = higher_res_feature + lower_res_feature
    return self.relu(out)
def forward(self, inputs): """Forward function.""" inputs = self._transform_inputs(inputs) # build laterals laterals = [ lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs) ] laterals.append(self.psp_forward(inputs)) # build top-down path used_backbone_levels = len(laterals) for i in range(used_backbone_levels - 1, 0, -1): prev_shape = laterals[i - 1].shape[2:] laterals[i - 1] += resize( laterals[i], size=prev_shape, mode='bilinear', align_corners=self.align_corners) # build outputs fpn_outs = [ self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1) ] # append psp feature fpn_outs.append(laterals[-1]) for i in range(used_backbone_levels - 1, 0, -1): fpn_outs[i] = resize( fpn_outs[i], size=fpn_outs[0].shape[2:], mode='bilinear', align_corners=self.align_corners) fpn_outs = torch.cat(fpn_outs, dim=1) output = self.fpn_bottleneck(fpn_outs) output = self.cls_seg(output) return output
def forward(self, inputs): """Forward function.""" inputs = self._transform_inputs(inputs) x = inputs[-1] x = self.aspp_conv(x) * resize(self.image_pool(x), size=x.size()[2:], mode='bilinear', align_corners=self.align_corners) x = self.conv_up_input(x) for i in range(len(self.branch_channels) - 1, -1, -1): x = resize(x, size=inputs[i].size()[2:], mode='bilinear', align_corners=self.align_corners) x = torch.cat([x, self.convs[i](inputs[i])], 1) x = self.conv_ups[i](x) return self.cls_seg(x)
def whole_inference(self, img, img_meta, rescale):
    """Inference with full image."""
    seg_logit = self.encode_decode(img, img_meta)
    if rescale:
        seg_logit = resize(
            seg_logit,
            size=img_meta[0]['ori_shape'][:2],
            mode='bilinear',
            align_corners=self.align_corners,
            warning=False)
    return seg_logit
def encode_decode(self, img, img_metas): """Encode images with backbone and decode into a semantic segmentation map of the same size as input.""" x = self.extract_feat_l(img) out = self.decode_head_l[0].forward_test(x, img_metas, self.test_cfg) for i in range(1, self.num_stages): out = self.decode_head_l[i].forward_test(x, out, img_metas, self.test_cfg) out = resize(input=out, size=img.shape[2:], mode='bilinear', align_corners=self.align_corners) return out
def forward(self, x_low, x_high):
    x_low = resize(
        x_low,
        size=x_high.size()[2:],
        mode='bilinear',
        align_corners=self.align_corners)
    # Note: Different from the original paper, `x_low` goes through
    # `self.conv_low` rather than another 1x1 conv classifier before being
    # used for the auxiliary head.
    x_low = self.conv_low(x_low)
    x_high = self.conv_high(x_high)
    x = x_low + x_high
    x = F.relu(x, inplace=True)
    return x, x_low
def forward(self, inputs): """Forward function.""" x = self._transform_inputs(inputs) aspp_outs = [ resize(self.image_pool(x), size=x.size()[2:], mode='bilinear', align_corners=self.align_corners) ] aspp_outs.extend(self.aspp_modules(x)) aspp_outs = torch.cat(aspp_outs, dim=1) output = self.bottleneck(aspp_outs) output = self.cls_seg(output) return output
def forward(self, inputs, prev_output):
    """Forward function."""
    feats = self.bottleneck(inputs[self.feature_key])
    cur_prob = resize(
        input=prev_output,
        size=feats.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners)
    context = self.spatial_gather_module(feats, cur_prob)
    output = self.object_context_block(feats, context)

    # build decoder
    for i in range(self.decoder_stage):
        low_level_key = self.low_level_key[i]
        low_level_feats = self.project[i](inputs[low_level_key])
        output = resize(
            output,
            size=low_level_feats.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        output = torch.cat([output, low_level_feats], dim=1)
        output = self.fuse[i](output)

    output = self.cls_seg(output)
    return output