def test_resnet_res_layer():
    """Exercise ``ResLayer`` construction and forward shape in five setups."""

    def _forward_shape(module):
        # Every scenario feeds the same (1, 64, 56, 56) random input.
        return module(torch.randn(1, 64, 56, 56)).shape

    # Scenario 1: three Bottleneck blocks, none of which has a downsample.
    layer = ResLayer(Bottleneck, 64, 16, 3)
    assert len(layer) == 3
    for idx in range(len(layer)):
        assert layer[idx].conv1.in_channels == 64
        assert layer[idx].conv1.out_channels == 16
    for idx in range(len(layer)):
        assert layer[idx].downsample is None
    assert _forward_shape(layer) == torch.Size([1, 64, 56, 56])

    # Scenario 2: three Bottleneck blocks, downsample only on the first one.
    layer = ResLayer(Bottleneck, 64, 64, 3)
    assert layer[0].downsample[0].out_channels == 256
    for idx in range(1, len(layer)):
        assert layer[idx].downsample is None
    assert _forward_shape(layer) == torch.Size([1, 256, 56, 56])

    # Scenario 3: same as scenario 2 but with stride=2.
    layer = ResLayer(Bottleneck, 64, 64, 3, stride=2)
    first_down = layer[0].downsample[0]
    assert first_down.out_channels == 256
    assert first_down.stride == (2, 2)
    for idx in range(1, len(layer)):
        assert layer[idx].downsample is None
    assert _forward_shape(layer) == torch.Size([1, 256, 28, 28])

    # Scenario 4: stride=2 with avg_down=True -> AvgPool2d followed by a
    # stride-1 conv in the downsample branch.
    layer = ResLayer(Bottleneck, 64, 64, 3, stride=2, avg_down=True)
    assert isinstance(layer[0].downsample[0], AvgPool2d)
    down_conv = layer[0].downsample[1]
    assert down_conv.out_channels == 256
    assert down_conv.stride == (1, 1)
    for idx in range(1, len(layer)):
        assert layer[idx].downsample is None
    assert _forward_shape(layer) == torch.Size([1, 256, 28, 28])

    # Scenario 5: three BasicBlock with stride=2 and downsample_first=False,
    # so the downsample lives on the last block.
    layer = ResLayer(BasicBlock, 64, 64, 3, stride=2, downsample_first=False)
    last_down = layer[2].downsample[0]
    assert last_down.out_channels == 64
    assert last_down.stride == (2, 2)
    for idx in range(len(layer) - 1):
        assert layer[idx].downsample is None
    assert _forward_shape(layer) == torch.Size([1, 64, 28, 28])
def __init__(self,
             num_convs=4,
             in_channels=256,
             conv_out_channels=256,
             num_classes=80,
             loss_weight=1.0,
             conv_cfg=None,
             norm_cfg=None,
             conv_to_res=False,
             init_cfg=dict(
                 type='Normal', std=0.01, override=dict(name='fc'))):
    """Build the conv stack, global pooling, classifier and loss.

    Args:
        num_convs (int): number of 3x3 ``ConvModule`` layers; with
            ``conv_to_res`` it is halved and used as the number of
            ``SimplifiedBasicBlock`` blocks instead. Default: 4.
        in_channels (int): input channels of the first layer. Default: 256.
        conv_out_channels (int): output channels of the conv stack and
            input size of the ``fc`` layer. Default: 256.
        num_classes (int): output size of the ``fc`` layer. Default: 80.
        loss_weight (float): stored on the instance. Default: 1.0.
        conv_cfg (dict, optional): forwarded to every conv layer.
        norm_cfg (dict, optional): forwarded to every conv layer.
        conv_to_res (bool): build a ``ResLayer`` instead of plain convs.
            Default: False.
        init_cfg (dict): passed to the parent constructor.
    """
    super(GlobalContextHead, self).__init__(init_cfg)

    # Keep the raw hyper-parameters on the instance.
    self.num_convs = num_convs
    self.in_channels = in_channels
    self.conv_out_channels = conv_out_channels
    self.num_classes = num_classes
    self.loss_weight = loss_weight
    self.conv_cfg = conv_cfg
    self.norm_cfg = norm_cfg
    self.conv_to_res = conv_to_res
    self.fp16_enabled = False

    layer_cfgs = dict(conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg)
    if not self.conv_to_res:
        # Plain stack: only the first conv sees ``in_channels``.
        self.convs = nn.ModuleList([
            ConvModule(
                self.in_channels if idx == 0 else conv_out_channels,
                conv_out_channels,
                3,
                padding=1,
                **layer_cfgs) for idx in range(self.num_convs)
        ])
    else:
        # Residual variant: ``num_convs`` is overwritten with the number
        # of residual blocks actually built.
        res_block_count = num_convs // 2
        self.convs = ResLayer(SimplifiedBasicBlock, in_channels,
                              self.conv_out_channels, res_block_count,
                              **layer_cfgs)
        self.num_convs = res_block_count

    self.pool = nn.AdaptiveAvgPool2d(1)
    self.fc = nn.Linear(conv_out_channels, num_classes)
    self.criterion = nn.BCEWithLogitsLoss()
def __init__(self, conv_to_res=True, **kwargs):
    """Initialise the parent head, optionally swapping in residual blocks.

    Args:
        conv_to_res (bool): when True, ``self.convs`` is replaced by a
            ``ResLayer`` of ``self.num_convs // 2`` ``SimplifiedBasicBlock``
            blocks and ``self.num_convs`` is updated to that block count.
            Default: True.
        **kwargs: forwarded unchanged to the parent constructor.
    """
    super(SCNetSemanticHead, self).__init__(**kwargs)
    self.conv_to_res = conv_to_res
    if not self.conv_to_res:
        return

    res_block_count = self.num_convs // 2
    self.convs = ResLayer(
        SimplifiedBasicBlock,
        self.in_channels,
        self.conv_out_channels,
        res_block_count,
        conv_cfg=self.conv_cfg,
        norm_cfg=self.norm_cfg)
    self.num_convs = res_block_count
def __init__(self, conv_to_res=True, **kwargs):
    """Initialise the parent mask head, optionally using residual blocks.

    Args:
        conv_to_res (bool): when True, ``self.convs`` is replaced by a
            ``ResLayer`` of ``self.num_convs // 2`` ``SimplifiedBasicBlock``
            blocks; this path asserts ``self.conv_kernel_size == 3``.
            Default: True.
        **kwargs: forwarded unchanged to the parent constructor.
    """
    super(SCNetMaskHead, self).__init__(**kwargs)
    self.conv_to_res = conv_to_res
    if not conv_to_res:
        return

    assert self.conv_kernel_size == 3
    self.num_res_blocks = self.num_convs // 2
    self.convs = ResLayer(
        SimplifiedBasicBlock,
        self.in_channels,
        self.conv_out_channels,
        self.num_res_blocks,
        conv_cfg=self.conv_cfg,
        norm_cfg=self.norm_cfg)
class GlobalContextHead(nn.Module):
    """Global context head used in SCNet https://arxiv.org/abs/2012.10150.

    The head runs a conv stack on the last input feature map, pools it
    globally and classifies the pooled vector with a linear layer.

    Args:
        num_convs (int, optional): number of convolutional layers in the
            head. Default: 4.
        in_channels (int, optional): number of input channels. Default: 256.
        conv_out_channels (int, optional): number of output channels before
            the classification layer. Default: 256.
        num_classes (int, optional): output size of the classification
            layer. Default: 81.
        loss_weight (float, optional): global context loss weight.
            Default: 1.
        conv_cfg (dict, optional): config to init conv layer. Default: None.
        norm_cfg (dict, optional): config to init norm layer. Default: None.
        conv_to_res (bool, optional): if True, every 2 convs are grouped
            into 1 `SimplifiedBasicBlock` using a skip connection.
            Default: False.
    """

    def __init__(self,
                 num_convs=4,
                 in_channels=256,
                 conv_out_channels=256,
                 num_classes=81,
                 loss_weight=1.0,
                 conv_cfg=None,
                 norm_cfg=None,
                 conv_to_res=False):
        super().__init__()
        self.num_convs = num_convs
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.num_classes = num_classes
        self.loss_weight = loss_weight
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.conv_to_res = conv_to_res
        self.fp16_enabled = False

        layer_cfgs = dict(conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg)
        if not self.conv_to_res:
            # Plain stack: only the first conv sees ``in_channels``.
            self.convs = nn.ModuleList([
                ConvModule(
                    self.in_channels if idx == 0 else conv_out_channels,
                    conv_out_channels,
                    3,
                    padding=1,
                    **layer_cfgs) for idx in range(self.num_convs)
            ])
        else:
            # ``num_convs`` becomes the number of residual blocks so that
            # ``forward`` indexes the right number of modules.
            res_block_count = num_convs // 2
            self.convs = ResLayer(SimplifiedBasicBlock, in_channels,
                                  self.conv_out_channels, res_block_count,
                                  **layer_cfgs)
            self.num_convs = res_block_count

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(conv_out_channels, num_classes)
        self.criterion = nn.BCEWithLogitsLoss()

    def init_weights(self):
        """Init weights for the head: N(0, 0.01) fc weight, zero fc bias."""
        fc = self.fc
        nn.init.normal_(fc.weight, 0, 0.01)
        nn.init.constant_(fc.bias, 0)

    @auto_fp16()
    def forward(self, feats):
        """Forward function.

        Only ``feats[-1]`` is used. Returns the fc prediction together with
        the globally pooled feature.
        """
        x = feats[-1]
        for idx in range(self.num_convs):
            x = self.convs[idx](x)
        x = self.pool(x)

        # multi-class prediction on the flattened pooled feature
        flat = x.reshape(x.size(0), -1)
        mc_pred = self.fc(flat)
        return mc_pred, x

    @force_fp32(apply_to=('pred', ))
    def loss(self, pred, labels):
        """Loss function.

        Builds a multi-hot target per sample from the unique values of its
        label tensor and applies the weighted criterion.
        """
        unique_labels = [lbl.unique() for lbl in labels]
        targets = pred.new_zeros(pred.size())
        for row, present in enumerate(unique_labels):
            targets[row, present] = 1.0
        return self.loss_weight * self.criterion(pred, targets)
def __init__(self,
             num_convs=4,
             roi_feat_size=14,
             in_channels=256,
             conv_kernel_size=3,
             conv_out_channels=256,
             num_classes=80,
             class_agnostic=False,
             upsample_cfg=dict(type='deconv', scale_factor=2),
             conv_cfg=None,
             norm_cfg=None,
             conv_to_res=False,
             loss_mask=dict(
                 type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)):
    """Build the conv stack, the optional upsample layer and the logits conv.

    Args:
        num_convs (int): number of conv layers (halved into residual blocks
            when ``conv_to_res`` is True). Default: 4.
        roi_feat_size (int | tuple): stored as a pair; reserved, not used.
        in_channels (int): input channels of the first conv. Default: 256.
        conv_kernel_size (int): kernel size of the plain convs; must be 3
            when ``conv_to_res`` is True. Default: 3.
        conv_out_channels (int): channels produced by the conv stack.
        num_classes (int): logits channels unless ``class_agnostic``.
        class_agnostic (bool): predict a single mask channel if True.
        upsample_cfg (dict): ``type`` must be one of None, 'deconv',
            'nearest', 'bilinear', 'carafe'; ``scale_factor`` is popped
            from a copy of this dict.
        conv_cfg (dict, optional): config forwarded to the conv layers.
        norm_cfg (dict, optional): config forwarded to the conv layers.
        conv_to_res (bool): use a ``ResLayer`` instead of plain convs.
        loss_mask (dict): config handed to ``build_loss``.

    Raises:
        ValueError: if the upsample type is not one of the accepted values.
    """
    super(FCNMaskHead, self).__init__()

    self.upsample_cfg = upsample_cfg.copy()
    accepted_methods = (None, 'deconv', 'nearest', 'bilinear', 'carafe')
    if self.upsample_cfg['type'] not in accepted_methods:
        raise ValueError(
            f'Invalid upsample method {self.upsample_cfg["type"]}, '
            'accepted methods are "deconv", "nearest", "bilinear", '
            '"carafe"')

    self.num_convs = num_convs
    # WARN: roi_feat_size is reserved and not used
    self.roi_feat_size = _pair(roi_feat_size)
    self.in_channels = in_channels
    self.conv_kernel_size = conv_kernel_size
    self.conv_out_channels = conv_out_channels
    self.upsample_method = self.upsample_cfg.get('type')
    # ``scale_factor`` is removed from the stored cfg and re-added below
    # only under the key each upsample type expects.
    self.scale_factor = self.upsample_cfg.pop('scale_factor', None)
    self.num_classes = num_classes
    self.class_agnostic = class_agnostic
    self.conv_cfg = conv_cfg
    self.norm_cfg = norm_cfg
    self.conv_to_res = conv_to_res
    self.fp16_enabled = False
    self.loss_mask = build_loss(loss_mask)

    if conv_to_res:
        assert conv_kernel_size == 3
        self.num_res_blocks = num_convs // 2
        self.convs = ResLayer(
            SimplifiedBasicBlock,
            in_channels,
            self.conv_out_channels,
            self.num_res_blocks,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)
    else:
        same_pad = (self.conv_kernel_size - 1) // 2
        self.convs = nn.ModuleList([
            ConvModule(
                self.in_channels if idx == 0 else self.conv_out_channels,
                self.conv_out_channels,
                self.conv_kernel_size,
                padding=same_pad,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg) for idx in range(self.num_convs)
        ])

    # With no convs at all the upsample layer sees the raw input channels.
    if self.num_convs > 0:
        upsample_in_channels = self.conv_out_channels
    else:
        upsample_in_channels = in_channels

    method = self.upsample_method
    layer_cfg = self.upsample_cfg.copy()
    if method is None:
        self.upsample = None
    else:
        if method == 'deconv':
            layer_cfg.update(
                in_channels=upsample_in_channels,
                out_channels=self.conv_out_channels,
                kernel_size=self.scale_factor,
                stride=self.scale_factor)
        elif method == 'carafe':
            layer_cfg.update(
                channels=upsample_in_channels,
                scale_factor=self.scale_factor)
        else:
            # suppress warnings
            layer_cfg.update(
                scale_factor=self.scale_factor,
                mode=method,
                align_corners=None if method == 'nearest' else False)
        self.upsample = build_upsample_layer(layer_cfg)

    if self.class_agnostic:
        out_channels = 1
    else:
        out_channels = self.num_classes
    if method == 'deconv':
        logits_in_channel = self.conv_out_channels
    else:
        logits_in_channel = upsample_in_channels
    self.conv_logits = Conv2d(logits_in_channel, out_channels, 1)
    self.relu = nn.ReLU(inplace=True)
    self.debug_imgs = None
class FCNMaskHead(nn.Module):
    """Mask head: conv stack -> optional upsample layer -> 1x1 logits conv.

    Args:
        num_convs (int): number of ``ConvModule`` layers; when
            ``conv_to_res`` is True, ``num_convs // 2`` residual blocks are
            built instead. Default: 4.
        roi_feat_size (int | tuple): stored as a pair; reserved, not used.
        in_channels (int): input channels of the first conv. Default: 256.
        conv_kernel_size (int): kernel size of the plain convs; must be 3
            when ``conv_to_res`` is True. Default: 3.
        conv_out_channels (int): channels produced by the conv stack.
            Default: 256.
        num_classes (int): number of logits channels unless
            ``class_agnostic`` is set. Default: 80.
        class_agnostic (bool): predict a single mask channel. Default: False.
        upsample_cfg (dict): ``type`` must be one of None, 'deconv',
            'nearest', 'bilinear', 'carafe'; ``scale_factor`` is popped
            from a copy of this dict.
        conv_cfg (dict, optional): config forwarded to the conv layers.
        norm_cfg (dict, optional): config forwarded to the conv layers.
        conv_to_res (bool): use a ``ResLayer`` of ``SimplifiedBasicBlock``
            instead of plain convs. Default: False.
        loss_mask (dict): config handed to ``build_loss``.
    """

    def __init__(self,
                 num_convs=4,
                 roi_feat_size=14,
                 in_channels=256,
                 conv_kernel_size=3,
                 conv_out_channels=256,
                 num_classes=80,
                 class_agnostic=False,
                 upsample_cfg=dict(type='deconv', scale_factor=2),
                 conv_cfg=None,
                 norm_cfg=None,
                 conv_to_res=False,
                 loss_mask=dict(
                     type='CrossEntropyLoss', use_mask=True,
                     loss_weight=1.0)):
        super(FCNMaskHead, self).__init__()
        self.upsample_cfg = upsample_cfg.copy()
        if self.upsample_cfg['type'] not in [
                None, 'deconv', 'nearest', 'bilinear', 'carafe'
        ]:
            raise ValueError(
                f'Invalid upsample method {self.upsample_cfg["type"]}, '
                'accepted methods are "deconv", "nearest", "bilinear", '
                '"carafe"')
        self.num_convs = num_convs
        # WARN: roi_feat_size is reserved and not used
        self.roi_feat_size = _pair(roi_feat_size)
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_out_channels = conv_out_channels
        self.upsample_method = self.upsample_cfg.get('type')
        # Popped here and re-added below under the key each upsample type
        # expects.
        self.scale_factor = self.upsample_cfg.pop('scale_factor', None)
        self.num_classes = num_classes
        self.class_agnostic = class_agnostic
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.conv_to_res = conv_to_res
        self.fp16_enabled = False
        self.loss_mask = build_loss(loss_mask)

        if conv_to_res:
            assert conv_kernel_size == 3
            self.num_res_blocks = num_convs // 2
            self.convs = ResLayer(
                SimplifiedBasicBlock,
                in_channels,
                self.conv_out_channels,
                self.num_res_blocks,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg)
        else:
            self.convs = nn.ModuleList()
            for i in range(self.num_convs):
                in_channels = (
                    self.in_channels if i == 0 else self.conv_out_channels)
                padding = (self.conv_kernel_size - 1) // 2
                self.convs.append(
                    ConvModule(
                        in_channels,
                        self.conv_out_channels,
                        self.conv_kernel_size,
                        padding=padding,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg))
        # With no convs at all the upsample layer sees the input channels.
        upsample_in_channels = (
            self.conv_out_channels if self.num_convs > 0 else in_channels)
        upsample_cfg_ = self.upsample_cfg.copy()
        if self.upsample_method is None:
            self.upsample = None
        elif self.upsample_method == 'deconv':
            upsample_cfg_.update(
                in_channels=upsample_in_channels,
                out_channels=self.conv_out_channels,
                kernel_size=self.scale_factor,
                stride=self.scale_factor)
            self.upsample = build_upsample_layer(upsample_cfg_)
        elif self.upsample_method == 'carafe':
            upsample_cfg_.update(
                channels=upsample_in_channels, scale_factor=self.scale_factor)
            self.upsample = build_upsample_layer(upsample_cfg_)
        else:
            # suppress warnings
            align_corners = (None
                             if self.upsample_method == 'nearest' else False)
            upsample_cfg_.update(
                scale_factor=self.scale_factor,
                mode=self.upsample_method,
                align_corners=align_corners)
            self.upsample = build_upsample_layer(upsample_cfg_)

        out_channels = 1 if self.class_agnostic else self.num_classes
        logits_in_channel = (
            self.conv_out_channels
            if self.upsample_method == 'deconv' else upsample_in_channels)
        self.conv_logits = Conv2d(logits_in_channel, out_channels, 1)
        self.relu = nn.ReLU(inplace=True)
        self.debug_imgs = None

    def init_weights(self):
        """Initialise the upsample layer and the logits conv.

        ``CARAFEPack`` layers use their own ``init_weights``. Any other
        layer is Kaiming-initialised only if it actually owns a ``weight``;
        layers without one (e.g. the layer built for the 'nearest' /
        'bilinear' upsample types, presumably parameter-free) are skipped
        instead of raising ``AttributeError``. A missing bias is skipped too.
        """
        for m in [self.upsample, self.conv_logits]:
            if m is None:
                continue
            if isinstance(m, CARAFEPack):
                m.init_weights()
                continue
            weight = getattr(m, 'weight', None)
            if weight is None:
                # Nothing learnable to initialise on this layer.
                continue
            nn.init.kaiming_normal_(
                weight, mode='fan_out', nonlinearity='relu')
            bias = getattr(m, 'bias', None)
            if bias is not None:
                nn.init.constant_(bias, 0)

    @auto_fp16()
    def forward(self, x):
        """Run convs, the optional upsample layer and the logits conv.

        ReLU is applied after upsampling only for the 'deconv' method.
        """
        for conv in self.convs:
            x = conv(x)
        if self.upsample is not None:
            x = self.upsample(x)
            if self.upsample_method == 'deconv':
                x = self.relu(x)
        mask_pred = self.conv_logits(x)
        return mask_pred

    def get_targets(self, sampling_results, gt_masks, rcnn_train_cfg):
        """Compute mask targets from the positive samples.

        Collects ``pos_bboxes`` and ``pos_assigned_gt_inds`` from every
        sampling result and forwards them to ``mask_target``.
        """
        pos_proposals = [res.pos_bboxes for res in sampling_results]
        pos_assigned_gt_inds = [
            res.pos_assigned_gt_inds for res in sampling_results
        ]
        mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
                                   gt_masks, rcnn_train_cfg)
        return mask_targets

    @force_fp32(apply_to=('mask_pred', ))
    def loss(self, mask_pred, mask_targets, labels):
        """Compute the mask loss.

        Returns:
            dict: ``{'loss_mask': ...}``. With zero predictions the value is
            ``mask_pred.sum()``; in class-agnostic mode all labels are
            replaced by zeros before calling the loss.
        """
        loss = dict()
        if mask_pred.size(0) == 0:
            loss_mask = mask_pred.sum()
        else:
            if self.class_agnostic:
                loss_mask = self.loss_mask(mask_pred, mask_targets,
                                           torch.zeros_like(labels))
            else:
                loss_mask = self.loss_mask(mask_pred, mask_targets, labels)
        loss['loss_mask'] = loss_mask
        return loss

    def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
                      ori_shape, scale_factor, rescale):
        """Get segmentation masks from mask_pred and bboxes.

        Args:
            mask_pred (Tensor or ndarray): shape (n, #class, h, w).
                For single-scale testing, mask_pred is the direct output of
                model, whose type is Tensor, while for multi-scale testing,
                it will be converted to numpy array outside of this method.
            det_bboxes (Tensor): shape (n, 4/5)
            det_labels (Tensor): shape (n, )
            rcnn_test_cfg (dict): rcnn testing config; ``mask_thr_binary``
                is read from it.
            ori_shape: original image size; the first two entries are used
                as (height, width).
            scale_factor (float | sequence): image scale factor. A float is
                applied to both sides; otherwise entries 0 and 1 are used as
                the width and height scales.
            rescale (bool): if True, masks are pasted at ``ori_shape`` and
                boxes are divided by ``scale_factor``; otherwise masks are
                pasted at the scaled image size and boxes are left as-is.

        Returns:
            list[list[ndarray]]: one list per class holding the pasted
            full-image mask of every detection with that label. During ONNX
            export the pasted mask tensor is returned directly instead.
        """
        if isinstance(mask_pred, torch.Tensor):
            mask_pred = mask_pred.sigmoid()
        else:
            mask_pred = det_bboxes.new_tensor(mask_pred)

        device = mask_pred.device
        cls_segms = [[] for _ in range(self.num_classes)
                     ]  # BG is not included in num_classes
        bboxes = det_bboxes[:, :4]
        labels = det_labels

        if rescale:
            img_h, img_w = ori_shape[:2]
        else:
            if isinstance(scale_factor, float):
                img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)
                img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
            else:
                w_scale, h_scale = scale_factor[0], scale_factor[1]
                img_h = np.round(ori_shape[0] * h_scale.item()).astype(
                    np.int32)
                img_w = np.round(ori_shape[1] * w_scale.item()).astype(
                    np.int32)
            # Boxes already live in the scaled image, so do not rescale them.
            scale_factor = 1.0

        if not isinstance(scale_factor, (float, torch.Tensor)):
            scale_factor = bboxes.new_tensor(scale_factor)
        bboxes = bboxes / scale_factor

        if torch.onnx.is_in_onnx_export():
            # TODO: Remove after F.grid_sample is supported.
            from torchvision.models.detection.roi_heads \
                import paste_masks_in_image
            masks = paste_masks_in_image(mask_pred, bboxes, ori_shape[:2])
            thr = rcnn_test_cfg.get('mask_thr_binary', 0)
            if thr > 0:
                masks = masks >= thr
            return masks

        N = len(mask_pred)
        # The actual implementation split the input into chunks,
        # and paste them chunk by chunk.
        if device.type == 'cpu':
            # CPU is most efficient when they are pasted one by one with
            # skip_empty=True, so that it performs minimal number of
            # operations.
            num_chunks = N
        else:
            # GPU benefits from parallelism for larger chunks,
            # but may have memory issue
            num_chunks = int(
                np.ceil(N * img_h * img_w * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
            assert (num_chunks <=
                    N), 'Default GPU_MEM_LIMIT is too small; try increasing it'
        chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

        threshold = rcnn_test_cfg.mask_thr_binary
        im_mask = torch.zeros(
            N,
            img_h,
            img_w,
            device=device,
            dtype=torch.bool if threshold >= 0 else torch.uint8)

        if not self.class_agnostic:
            # Keep only the channel of each detection's predicted class.
            mask_pred = mask_pred[range(N), labels][:, None]

        for inds in chunks:
            masks_chunk, spatial_inds = _do_paste_mask(
                mask_pred[inds],
                bboxes[inds],
                img_h,
                img_w,
                skip_empty=device.type == 'cpu')

            if threshold >= 0:
                masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
            else:
                # for visualization and debugging
                masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

            im_mask[(inds, ) + spatial_inds] = masks_chunk

        for i in range(N):
            cls_segms[labels[i]].append(im_mask[i].detach().cpu().numpy())
        return cls_segms
def make_res_layer(self, **kwargs):
    """Create the ``ResLayer`` for one stage.

    All keyword arguments are forwarded unchanged to ``ResLayer``.
    """
    stage = ResLayer(**kwargs)
    return stage
def __init__(self,
             num_ins,
             fusion_level,
             num_convs=4,
             in_channels=256,
             conv_out_channels=256,
             num_classes=183,
             ignore_label=255,
             loss_weight=0.2,
             conv_cfg=None,
             norm_cfg=None,
             conv_to_res=False):
    """Build lateral 1x1 convs, the conv stack and the two output convs.

    Args:
        num_ins (int): number of lateral 1x1 convs to create.
        fusion_level (int): stored on the instance.
        num_convs (int): number of 3x3 convs (halved into residual blocks
            when ``conv_to_res`` is True). Default: 4.
        in_channels (int): channels of the lateral convs and of the first
            stacked conv. Default: 256.
        conv_out_channels (int): channels of the conv stack and of the
            embedding conv. Default: 256.
        num_classes (int): output channels of ``conv_logits``. Default: 183.
        ignore_label (int): ``ignore_index`` of the cross-entropy loss.
            Default: 255.
        loss_weight (float): stored on the instance. Default: 0.2.
        conv_cfg (dict, optional): config forwarded to the conv layers.
        norm_cfg (dict, optional): config forwarded to the conv layers.
        conv_to_res (bool): use a ``ResLayer`` instead of plain convs.
            Default: False.
    """
    super(FusedSemanticHead, self).__init__()
    self.num_ins = num_ins
    self.fusion_level = fusion_level
    self.num_convs = num_convs
    self.in_channels = in_channels
    self.conv_out_channels = conv_out_channels
    self.num_classes = num_classes
    self.ignore_label = ignore_label
    self.loss_weight = loss_weight
    self.conv_cfg = conv_cfg
    self.norm_cfg = norm_cfg
    self.conv_to_res = conv_to_res
    self.fp16_enabled = False

    layer_cfgs = dict(conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg)

    # One 1x1 lateral conv per input level.
    self.lateral_convs = nn.ModuleList([
        ConvModule(
            self.in_channels,
            self.in_channels,
            1,
            inplace=False,
            **layer_cfgs) for _ in range(self.num_ins)
    ])

    if not self.conv_to_res:
        self.convs = nn.ModuleList([
            ConvModule(
                self.in_channels if idx == 0 else conv_out_channels,
                conv_out_channels,
                3,
                padding=1,
                **layer_cfgs) for idx in range(self.num_convs)
        ])
    else:
        # ``num_convs`` is overwritten with the residual block count.
        res_block_count = num_convs // 2
        self.convs = ResLayer(SimplifiedBasicBlock, in_channels,
                              self.conv_out_channels, res_block_count,
                              **layer_cfgs)
        self.num_convs = res_block_count

    self.conv_embedding = ConvModule(conv_out_channels, conv_out_channels,
                                     1, **layer_cfgs)
    self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1)
    self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label)
class FusedSemanticHead(nn.Module):
    r"""Multi-level fused semantic segmentation head.

    .. code-block:: none

        in_1 -> 1x1 conv ---
                            |
        in_2 -> 1x1 conv -- |
                           ||
        in_3 -> 1x1 conv - ||
                          |||                  /-> 1x1 conv (mask prediction)
        in_4 -> 1x1 conv -----> 3x3 convs (*4)
                            |                  \-> 1x1 conv (feature)
        in_5 -> 1x1 conv ---
    """  # noqa: W605

    def __init__(self,
                 num_ins,
                 fusion_level,
                 num_convs=4,
                 in_channels=256,
                 conv_out_channels=256,
                 num_classes=183,
                 ignore_label=255,
                 loss_weight=0.2,
                 conv_cfg=None,
                 norm_cfg=None,
                 conv_to_res=False):
        super().__init__()
        self.num_ins = num_ins
        self.fusion_level = fusion_level
        self.num_convs = num_convs
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.num_classes = num_classes
        self.ignore_label = ignore_label
        self.loss_weight = loss_weight
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.conv_to_res = conv_to_res
        self.fp16_enabled = False

        layer_cfgs = dict(conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg)

        # One 1x1 lateral conv per input level.
        self.lateral_convs = nn.ModuleList([
            ConvModule(
                self.in_channels,
                self.in_channels,
                1,
                inplace=False,
                **layer_cfgs) for _ in range(self.num_ins)
        ])

        if not self.conv_to_res:
            self.convs = nn.ModuleList([
                ConvModule(
                    self.in_channels if idx == 0 else conv_out_channels,
                    conv_out_channels,
                    3,
                    padding=1,
                    **layer_cfgs) for idx in range(self.num_convs)
            ])
        else:
            # ``num_convs`` becomes the residual block count so that
            # ``forward`` indexes the right number of modules.
            res_block_count = num_convs // 2
            self.convs = ResLayer(SimplifiedBasicBlock, in_channels,
                                  self.conv_out_channels, res_block_count,
                                  **layer_cfgs)
            self.num_convs = res_block_count

        self.conv_embedding = ConvModule(conv_out_channels, conv_out_channels,
                                         1, **layer_cfgs)
        self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1)
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label)

    def init_weights(self):
        """Kaiming-initialise the logits conv."""
        kaiming_init(self.conv_logits)

    @auto_fp16()
    def forward(self, feats):
        """Fuse all levels at the fusion level's resolution and predict.

        Returns:
            tuple: ``(mask_pred, x)`` - the output of ``conv_logits`` and the
            output of ``conv_embedding``, both computed from the same
            fused-and-convolved feature.
        """
        level = self.fusion_level
        x = self.lateral_convs[level](feats[level])
        fused_size = tuple(x.shape[-2:])

        for idx, feat in enumerate(feats):
            if idx == level:
                # Already used as the base of the sum.
                continue
            resized = F.interpolate(
                feat, size=fused_size, mode='bilinear', align_corners=True)
            x += self.lateral_convs[idx](resized)

        for idx in range(self.num_convs):
            x = self.convs[idx](x)

        mask_pred = self.conv_logits(x)
        x = self.conv_embedding(x)
        return mask_pred, x

    @force_fp32(apply_to=('mask_pred', ))
    def loss(self, mask_pred, labels):
        """Weighted cross-entropy between ``mask_pred`` and ``labels``.

        ``labels`` has dimension 1 squeezed out and is cast to long before
        being passed to the criterion.
        """
        target = labels.squeeze(1).long()
        loss_semantic_seg = self.criterion(mask_pred, target)
        loss_semantic_seg *= self.loss_weight
        return loss_semantic_seg
class GlobalContextHead(nn.Module):
    """Global context head.

    Runs a conv stack on the last input feature map, pools it globally and
    classifies the pooled vector with a linear layer.
    """

    def __init__(self,
                 num_ins,
                 num_convs=4,
                 in_channels=256,
                 conv_out_channels=256,
                 num_classes=81,
                 loss_weight=1.0,
                 conv_cfg=None,
                 norm_cfg=None,
                 conv_to_res=False):
        super().__init__()
        self.num_ins = num_ins
        self.num_convs = num_convs
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.num_classes = num_classes
        self.loss_weight = loss_weight
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.conv_to_res = conv_to_res
        self.fp16_enabled = False

        layer_cfgs = dict(conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg)
        if not self.conv_to_res:
            # Plain stack: only the first conv sees ``in_channels``.
            self.convs = nn.ModuleList([
                ConvModule(
                    self.in_channels if idx == 0 else conv_out_channels,
                    conv_out_channels,
                    3,
                    padding=1,
                    **layer_cfgs) for idx in range(self.num_convs)
            ])
        else:
            # ``num_convs`` becomes the residual block count so that
            # ``forward`` indexes the right number of modules.
            res_block_count = num_convs // 2
            self.convs = ResLayer(SimplifiedBasicBlock, in_channels,
                                  self.conv_out_channels, res_block_count,
                                  **layer_cfgs)
            self.num_convs = res_block_count

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(conv_out_channels, num_classes)
        self.criterion = nn.BCEWithLogitsLoss()

    def init_weights(self):
        """Draw the fc weight from N(0, 0.01) and zero the fc bias."""
        fc = self.fc
        nn.init.normal_(fc.weight, 0, 0.01)
        nn.init.constant_(fc.bias, 0)

    @auto_fp16()
    def forward(self, feats):
        """Return the fc prediction and the pooled feature of ``feats[-1]``."""
        x = feats[-1]
        for idx in range(self.num_convs):
            x = self.convs[idx](x)
        x = self.pool(x)

        # multi-class prediction on the flattened pooled feature
        flat = x.reshape(x.size(0), -1)
        mc_pred = self.fc(flat)
        return mc_pred, x

    @force_fp32(apply_to=('pred', ))
    def loss(self, pred, labels):
        """Weighted criterion against multi-hot targets.

        The target row of each sample has 1.0 at the unique values found in
        that sample's label tensor and 0 elsewhere.
        """
        unique_labels = [lbl.unique() for lbl in labels]
        targets = pred.new_zeros(pred.size())
        for row, present in enumerate(unique_labels):
            targets[row, present] = 1.0
        return self.loss_weight * self.criterion(pred, targets)