class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        pretrained (str | None): Name of pretrained model. Default: None.
        torchvision_pretrain (bool): Whether the pretrained checkpoint comes
            from torchvision. Default: True.
        in_channels (int): Channel num of input features. Default: 3.
        num_stages (int): Resnet stages. Default: 4.
        out_indices (Sequence[int]): Indices of output feature.
            Default: (3, ).
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        style (str): ``pytorch`` or ``caffe``. If set to ``pytorch``, the
            stride-two layer is the 3x3 conv layer, otherwise the stride-two
            layer is the first 1x1 conv layer. Default: ``pytorch``.
        frozen_stages (int): Stages to be frozen (all params fixed). -1 means
            not freezing any parameters. Default: -1.
        conv_cfg (dict): Config for conv layers. Default: dict(type='Conv').
        norm_cfg (dict): Config for norm layers. Required keys are ``type``
            and ``requires_grad``.
            Default: dict(type='BN2d', requires_grad=True).
        act_cfg (dict): Config for activation layers.
            Default: dict(type='ReLU', inplace=True).
        norm_eval (bool): Whether to set BN layers to eval mode, namely,
            freeze running stats (mean and var). Default: False.
        partial_bn (bool): Whether to freeze all BN layers except the first
            one (partial BN). Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Default: False.
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 pretrained=None,
                 torchvision_pretrain=True,
                 in_channels=3,
                 num_stages=4,
                 out_indices=(3, ),
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 style='pytorch',
                 frozen_stages=-1,
                 conv_cfg=dict(type='Conv'),
                 norm_cfg=dict(type='BN2d', requires_grad=True),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_eval=False,
                 partial_bn=False,
                 with_cp=False):
        super().__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        self.depth = depth
        self.in_channels = in_channels
        self.pretrained = pretrained
        self.torchvision_pretrain = torchvision_pretrain
        self.num_stages = num_stages
        assert 1 <= num_stages <= 4
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.style = style
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.partial_bn = partial_bn
        self.with_cp = with_cp

        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = 64

        self._make_stem_layer()

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            planes = 64 * 2**i
            res_layer = make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                with_cp=with_cp)
            self.inplanes = planes * self.block.expansion
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self.feat_dim = self.block.expansion * 64 * 2**(
            len(self.stage_blocks) - 1)

    def _make_stem_layer(self):
        """Construct the stem layers consisting of a conv+norm+act module and
        a pooling layer."""
        self.conv1 = ConvModule(
            self.in_channels,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    @staticmethod
    def _load_conv_params(conv, state_dict_tv, module_name_tv,
                          loaded_param_names):
        """Load the conv parameters of resnet from torchvision.

        Args:
            conv (nn.Module): The destination conv module.
            state_dict_tv (OrderedDict): The state dict of pretrained
                torchvision model.
            module_name_tv (str): The name of corresponding conv module in
                the torchvision model.
            loaded_param_names (list[str]): List of parameters that have been
                loaded.
        """
        weight_tv_name = module_name_tv + '.weight'
        if conv.weight.data.shape == state_dict_tv[weight_tv_name].shape:
            conv.weight.data.copy_(state_dict_tv[weight_tv_name])
            loaded_param_names.append(weight_tv_name)

        if conv.bias is not None:
            bias_tv_name = module_name_tv + '.bias'
            if conv.bias.data.shape == state_dict_tv[bias_tv_name].shape:
                conv.bias.data.copy_(state_dict_tv[bias_tv_name])
                loaded_param_names.append(bias_tv_name)

    @staticmethod
    def _load_bn_params(bn, state_dict_tv, module_name_tv,
                        loaded_param_names):
        """Load the bn parameters of resnet from torchvision.

        Args:
            bn (nn.Module): The destination bn module.
            state_dict_tv (OrderedDict): The state dict of pretrained
                torchvision model.
            module_name_tv (str): The name of corresponding bn module in the
                torchvision model.
            loaded_param_names (list[str]): List of parameters that have been
                loaded.
        """
        for param_name, param in bn.named_parameters():
            param_tv_name = f'{module_name_tv}.{param_name}'
            param_tv = state_dict_tv[param_tv_name]
            if param.data.shape == param_tv.shape:
                param.data.copy_(param_tv)
                loaded_param_names.append(param_tv_name)

        for param_name, param in bn.named_buffers():
            param_tv_name = f'{module_name_tv}.{param_name}'
            # some buffers like num_batches_tracked may not exist
            if param_tv_name in state_dict_tv:
                param_tv = state_dict_tv[param_tv_name]
                if param.data.shape == param_tv.shape:
                    param.data.copy_(param_tv)
                    loaded_param_names.append(param_tv_name)

    def _load_torchvision_checkpoint(self, logger=None):
        """Initiate the parameters from torchvision pretrained checkpoint."""
        state_dict_torchvision = _load_checkpoint(self.pretrained)
        if 'state_dict' in state_dict_torchvision:
            state_dict_torchvision = state_dict_torchvision['state_dict']

        loaded_param_names = []
        for name, module in self.named_modules():
            if isinstance(module, ConvModule):
                # we use a ConvModule to wrap conv+bn+relu layers, thus the
                # name mapping is needed
                if 'downsample' in name:
                    # layer{X}.{Y}.downsample.conv -> layer{X}.{Y}.downsample.0
                    original_conv_name = name + '.0'
                    # layer{X}.{Y}.downsample.bn -> layer{X}.{Y}.downsample.1
                    original_bn_name = name + '.1'
                else:
                    # layer{X}.{Y}.conv{n}.conv -> layer{X}.{Y}.conv{n}
                    original_conv_name = name
                    # layer{X}.{Y}.conv{n}.bn -> layer{X}.{Y}.bn{n}
                    original_bn_name = name.replace('conv', 'bn')
                self._load_conv_params(module.conv, state_dict_torchvision,
                                       original_conv_name, loaded_param_names)
                self._load_bn_params(module.bn, state_dict_torchvision,
                                     original_bn_name, loaded_param_names)

        # check if any parameters in the 2d checkpoint are not loaded
        remaining_names = set(
            state_dict_torchvision.keys()) - set(loaded_param_names)
        if remaining_names:
            logger.info(
                'These parameters in pretrained checkpoint are not loaded'
                f': {remaining_names}')

    def init_weights(self):
        """Initiate the parameters either from existing checkpoint or from
        scratch."""
        if isinstance(self.pretrained, str):
            logger = get_root_logger()
            if self.torchvision_pretrain:
                # torchvision's
                self._load_torchvision_checkpoint(logger)
            else:
                # ours
                load_checkpoint(
                    self, self.pretrained, strict=False, logger=logger)
        elif self.pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor or tuple[torch.Tensor]: The feature of the input
                samples extracted by the backbone; a tuple when multiple
                ``out_indices`` are specified.
        """
        x = self.conv1(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        return tuple(outs)

    def _freeze_stages(self):
        """Prevent all the parameters from being optimized before
        ``self.frozen_stages``."""
        if self.frozen_stages >= 0:
            self.conv1.bn.eval()
            for m in self.conv1.modules():
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'layer{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def _partial_bn(self):
        """Freeze all BatchNorm2d layers except the first one."""
        logger = get_root_logger()
        logger.info('Freezing BatchNorm2d except the first one.')
        count_bn = 0
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                count_bn += 1
                if count_bn >= 2:
                    m.eval()
                    # shutdown update in frozen mode
                    m.weight.requires_grad = False
                    m.bias.requires_grad = False

    def train(self, mode=True):
        """Set the optimization status when training."""
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
        if mode and self.partial_bn:
            self._partial_bn()
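# A minimal usage sketch (not part of the original module): it assumes the
# helpers referenced above (BasicBlock, Bottleneck, make_res_layer, the mmcv
# ConvModule and init utilities) are importable in this file's context. The
# shapes below follow from the default out_indices=(3, ).
def _example_resnet_forward():
    import torch

    backbone = ResNet(depth=50, frozen_stages=1, norm_eval=True)
    backbone.init_weights()  # random init since pretrained is None
    backbone.train()  # stem and layer1 stay frozen via _freeze_stages()
    feats = backbone(torch.randn(2, 3, 224, 224))
    # a single stride-32 feature map is returned: (2, 2048, 7, 7), matching
    # backbone.feat_dim == 2048 for ResNet-50
    assert feats.shape == (2, 2048, 7, 7) and backbone.feat_dim == 2048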
class KpHead(nn.Module):
    """Keypoint-style detection head that upsamples and fuses multi-level
    features, then predicts center classification, scale regression and
    offset maps on the fused grid.

    Args:
        lateral_connect (bool): Whether to reduce each upsampled level with a
            lateral 1x1 conv before deconvolution.
        in_channels (Sequence[int]): Channel num of each input level.
        out_channels (int): Channel num of the fused feature.
        start_level (int): Index of the first usable input level.
        fusion_level (int): Level whose resolution the upsampled features are
            brought to; levels ``fusion_level + 1 .. end_level`` are fused.
        end_level (int): Index of the last input level to fuse.
        csp_center_loss (dict): Config for the center classification loss.
        regr_h_loss (dict): Config for the scale (height) regression loss.
        regr_offset_loss (dict): Config for the offset regression loss.
        conv_cfg (dict | None): Config for conv layers.
        norm_cfg (dict | None): Config for norm layers.
        act_cfg (dict | None): Config for activation layers.
    """

    def __init__(self,
                 lateral_connect,
                 in_channels,
                 out_channels,
                 start_level,
                 fusion_level,
                 end_level,
                 csp_center_loss=None,
                 regr_h_loss=None,
                 regr_offset_loss=None,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None):
        super().__init__()
        self.lateral_connect = lateral_connect
        self.num_ins = len(in_channels)
        self.start_level = start_level
        self.fusion_level = fusion_level
        self.end_level = end_level
        self.concat_level = self.end_level - self.fusion_level
        assert self.fusion_level >= self.start_level

        if self.lateral_connect:
            self.la_convs = nn.ModuleList()
        self.deconvs = nn.ModuleList()
        self.l2_norms = nn.ModuleList()
        for i in range(self.concat_level):
            if self.lateral_connect:
                # lateral 1x1 conv for level fusion_level + 1 + i, reducing it
                # to out_channels before the matching deconv
                l_conv = ConvModule(
                    in_channels[self.fusion_level + 1 + i],
                    out_channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    inplace=False)
                self.la_convs.append(l_conv)
            l2_norm = L2Normalization(out_channels, scale=10.0)
            self.l2_norms.append(l2_norm)
        for i in range(self.fusion_level + 1, self.end_level + 1):
            if i == self.fusion_level + 1:
                stride, padding = 2, 1
            else:
                stride, padding = 4, 0
            # with lateral convs the deconv input is already out_channels
            d_conv = DeconvModule(
                out_channels if self.lateral_connect else in_channels[i],
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=padding)
            self.deconvs.append(d_conv)

        self.feat = ConvModule(
            out_channels * self.concat_level,
            out_channels,
            3,
            stride=1,
            padding=1,
            norm_cfg=dict(type='BN', requires_grad=True),
            inplace=False)
        self.cls_conv = nn.Conv2d(out_channels, 1, 1)
        self.reg_conv = nn.Conv2d(out_channels, 1, 1)
        self.offset_conv = nn.Conv2d(out_channels, 2, 1)
        self.init_weights()

        self.csp_center_loss = CSPCenterLoss(**csp_center_loss)
        self.regr_loss = RegrHLoss(**regr_h_loss)
        self.offset_loss = RegrOffsetLoss(**regr_offset_loss)

    def init_weights(self):
        """Initialize convs with Xavier; the classification bias is set to
        -log((1 - pi) / pi) with prior pi = 0.01."""
        for m in self.feat.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
        nn.init.xavier_normal_(self.cls_conv.weight)
        nn.init.xavier_normal_(self.reg_conv.weight)
        nn.init.xavier_normal_(self.offset_conv.weight)
        nn.init.constant_(self.cls_conv.bias, -math.log(0.99 / 0.01))
        nn.init.constant_(self.reg_conv.bias, 0)
        nn.init.constant_(self.offset_conv.bias, 0)

    def forward(self, inputs):
        assert len(inputs) == self.num_ins
        deconv_outs = []
        if self.lateral_connect:
            # reduce each fused level with its lateral conv, then upsample
            la_features = [
                la_conv(inputs[self.fusion_level + 1 + i])
                for i, la_conv in enumerate(self.la_convs)
            ]
            for i, d_conv in enumerate(self.deconvs):
                deconv_outs.append(d_conv(la_features[i]))
        else:
            for i, d_conv in enumerate(self.deconvs):
                deconv_outs.append(d_conv(inputs[self.fusion_level + 1 + i]))
        norm_outs = []
        for i in range(len(deconv_outs)):
            norm_outs.append(self.l2_norms[i](deconv_outs[i]))
        cat_out = torch.cat(norm_outs, dim=1)
        cat_out = self.feat(cat_out)
        x_class = self.cls_conv(cat_out)
        x_regr = self.reg_conv(cat_out)
        x_offset = self.offset_conv(cat_out)
        return [x_class, x_regr, x_offset]

    def loss(self, preds, seman_map, scale_map, offset_map):
        cls_pred, regr_pred, offset_pred = preds
        loss_cls = self.csp_center_loss(cls_pred, seman_map)
        loss_regr = self.regr_loss(regr_pred, scale_map)
        loss_offset = self.offset_loss(offset_pred, offset_map)
        return dict(
            loss_cls=loss_cls, loss_regr=loss_regr, loss_offset=loss_offset)
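# A minimal usage sketch (not part of the original module) showing how the
# head consumes multi-level backbone features. The channel list and spatial
# sizes are illustrative (strides 4/8/16 of a 640x1280 image); the empty loss
# configs are hypothetical placeholders that assume CSPCenterLoss, RegrHLoss
# and RegrOffsetLoss provide usable defaults.
def _example_kphead_forward():
    import torch

    head = KpHead(
        lateral_connect=False,
        in_channels=[256, 512, 1024],
        out_channels=256,
        start_level=0,
        fusion_level=0,
        end_level=2,
        csp_center_loss=dict(),   # placeholder config (assumed defaults)
        regr_h_loss=dict(),       # placeholder config (assumed defaults)
        regr_offset_loss=dict())  # placeholder config (assumed defaults)

    inputs = [
        torch.randn(1, 256, 160, 320),  # stride 4; checked but not fused
        torch.randn(1, 512, 80, 160),   # stride 8, upsampled 2x by deconv
        torch.randn(1, 1024, 40, 80),   # stride 16, upsampled 4x by deconv
    ]
    cls_map, regr_map, offset_map = head(inputs)
    # all predictions live on the stride-4 grid of the fusion level:
    # cls (1, 1, 160, 320), scale (1, 1, 160, 320), offset (1, 2, 160, 320)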