class ResNet3d(nn.Module):
    """ResNet 3d backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        pretrained (str | None): Name of pretrained model.
        pretrained2d (bool): Whether to load pretrained 2D model.
            Default: True.
        in_channels (int): Channel num of input features. Default: 3.
        base_channels (int): Channel num of stem output features.
            Default: 64.
        out_indices (Sequence[int]): Indices of output feature.
            Default: (3, ).
        num_stages (int): Resnet stages. Default: 4.
        spatial_strides (Sequence[int]):
            Spatial strides of residual blocks of each stage.
            Default: ``(1, 2, 2, 2)``.
        temporal_strides (Sequence[int]):
            Temporal strides of residual blocks of each stage.
            Default: ``(1, 1, 1, 1)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        conv1_kernel (Sequence[int]): Kernel size of the first conv layer.
            Default: ``(5, 7, 7)``.
        conv1_stride_t (int): Temporal stride of the first conv layer.
            Default: 2.
        pool1_stride_t (int): Temporal stride of the first pooling layer.
            Default: 2.
        with_pool2 (bool): Whether to apply pool2 (a (2, 1, 1) temporal max
            pooling) after the first stage. Default: True.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the
            stride-two layer is the 3x3 conv layer, otherwise the stride-two
            layer is the first 1x1 conv layer. Default: 'pytorch'.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters. Default: -1.
        inflate (Sequence[int]): Inflate Dims of each block.
            Default: (1, 1, 1, 1).
        inflate_style (str): ``3x1x1`` or ``1x1x1``, which determines the
            kernel sizes and padding strides for conv1 and conv2 in each
            block. Default: '3x1x1'.
        conv_cfg (dict): Config for conv layers. Required keys are ``type``.
            Default: ``dict(type='Conv3d')``.
        norm_cfg (dict): Config for norm layers. Required keys are ``type``
            and ``requires_grad``.
            Default: ``dict(type='BN3d', requires_grad=True)``.
        act_cfg (dict): Config dict for activation layer.
            Default: ``dict(type='ReLU', inplace=True)``.
        norm_eval (bool): Whether to set BN layers to eval mode, namely,
            freeze running stats (mean and var). Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        non_local (Sequence[int]): Determine whether to apply non-local module
            in the corresponding block of each stages. Default: (0, 0, 0, 0).
        non_local_cfg (dict): Config for non-local module. Default: ``dict()``.
        zero_init_residual (bool): Whether to use zero initialization for
            the last norm layer of each residual block. Default: True.
        kwargs (dict, optional): Key arguments for "make_res_layer".

    Raises:
        KeyError: If ``depth`` is not one of the keys of ``arch_settings``.
    """

    # depth -> (block class, number of blocks in each of the 4 stages)
    arch_settings = {
        18: (BasicBlock3d, (2, 2, 2, 2)),
        34: (BasicBlock3d, (3, 4, 6, 3)),
        50: (Bottleneck3d, (3, 4, 6, 3)),
        101: (Bottleneck3d, (3, 4, 23, 3)),
        152: (Bottleneck3d, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 pretrained,
                 pretrained2d=True,
                 in_channels=3,
                 num_stages=4,
                 base_channels=64,
                 out_indices=(3, ),
                 spatial_strides=(1, 2, 2, 2),
                 temporal_strides=(1, 1, 1, 1),
                 dilations=(1, 1, 1, 1),
                 conv1_kernel=(5, 7, 7),
                 conv1_stride_t=2,
                 pool1_stride_t=2,
                 with_pool2=True,
                 style='pytorch',
                 frozen_stages=-1,
                 inflate=(1, 1, 1, 1),
                 inflate_style='3x1x1',
                 conv_cfg=dict(type='Conv3d'),
                 norm_cfg=dict(type='BN3d', requires_grad=True),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_eval=False,
                 with_cp=False,
                 non_local=(0, 0, 0, 0),
                 non_local_cfg=dict(),
                 zero_init_residual=True,
                 **kwargs):
        super().__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        self.depth = depth
        self.pretrained = pretrained
        self.pretrained2d = pretrained2d
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.spatial_strides = spatial_strides
        self.temporal_strides = temporal_strides
        self.dilations = dilations
        assert len(spatial_strides) == len(temporal_strides) == len(
            dilations) == num_stages
        self.conv1_kernel = conv1_kernel
        self.conv1_stride_t = conv1_stride_t
        self.pool1_stride_t = pool1_stride_t
        self.with_pool2 = with_pool2
        self.style = style
        self.frozen_stages = frozen_stages
        # NOTE(review): _ntuple is defined elsewhere; presumably it expands a
        # scalar into a num_stages-long tuple — confirm.
        self.stage_inflations = _ntuple(num_stages)(inflate)
        self.non_local_stages = _ntuple(num_stages)(non_local)
        self.inflate_style = inflate_style
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        # running input channel count, updated after each stage is built
        self.inplanes = self.base_channels

        self.non_local_cfg = non_local_cfg

        self._make_stem_layer()

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            spatial_stride = spatial_strides[i]
            temporal_stride = temporal_strides[i]
            dilation = dilations[i]
            # channel width doubles at every stage
            planes = self.base_channels * 2**i
            res_layer = self.make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                spatial_stride=spatial_stride,
                temporal_stride=temporal_stride,
                dilation=dilation,
                style=self.style,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg,
                act_cfg=self.act_cfg,
                non_local=self.non_local_stages[i],
                non_local_cfg=self.non_local_cfg,
                inflate=self.stage_inflations[i],
                inflate_style=self.inflate_style,
                with_cp=with_cp,
                **kwargs)
            self.inplanes = planes * self.block.expansion
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        # channel count of the last built stage's output
        self.feat_dim = self.block.expansion * self.base_channels * 2**(
            len(self.stage_blocks) - 1)

    def make_res_layer(self,
                       block,
                       inplanes,
                       planes,
                       blocks,
                       spatial_stride=1,
                       temporal_stride=1,
                       dilation=1,
                       style='pytorch',
                       inflate=1,
                       inflate_style='3x1x1',
                       non_local=0,
                       non_local_cfg=dict(),
                       norm_cfg=None,
                       act_cfg=None,
                       conv_cfg=None,
                       with_cp=False,
                       **kwargs):
        """Build residual layer for ResNet3D.

        Args:
            block (nn.Module): Residual module to be built.
            inplanes (int): Number of channels for the input feature
                in each block.
            planes (int): Number of channels for the output feature
                in each block.
            blocks (int): Number of residual blocks.
            spatial_stride (int | Sequence[int]): Spatial strides in
                residual and conv layers. Default: 1.
            temporal_stride (int | Sequence[int]): Temporal strides in
                residual and conv layers. Default: 1.
            dilation (int): Spacing between kernel elements. Default: 1.
            style (str): ``pytorch`` or ``caffe``. If set to ``pytorch``,
                the stride-two layer is the 3x3 conv layer, otherwise
                the stride-two layer is the first 1x1 conv layer.
                Default: ``pytorch``.
            inflate (int | Sequence[int]): Determine whether to inflate
                for each block. An int is applied to every block.
                Default: 1.
            inflate_style (str): ``3x1x1`` or ``1x1x1``, which determines
                the kernel sizes and padding strides for conv1 and conv2
                in each block. Default: '3x1x1'.
            non_local (int | Sequence[int]): Determine whether to apply
                non-local module in the corresponding block of each stages.
                An int is applied to every block. Default: 0.
            non_local_cfg (dict): Config for non-local module.
                Default: ``dict()``.
            conv_cfg (dict | None): Config for conv layers. Default: None.
            norm_cfg (dict | None): Config for norm layers. Default: None.
            act_cfg (dict | None): Config for activate layers. Default: None.
            with_cp (bool | None): Use checkpoint or not. Using checkpoint
                will save some memory while slowing down the training speed.
                Default: False.

        Returns:
            nn.Module: A residual layer for the given config.
        """
        inflate = inflate if not isinstance(inflate,
                                            int) else (inflate, ) * blocks
        non_local = non_local if not isinstance(
            non_local, int) else (non_local, ) * blocks
        assert len(inflate) == blocks and len(non_local) == blocks

        # A projection shortcut is needed whenever the first block changes the
        # output shape: a spatial stride, a temporal stride, or a channel
        # change. Without the temporal check, a temporal-only stride would
        # leave an identity shortcut whose temporal length differs from the
        # strided main path.
        downsample = None
        if (spatial_stride != 1 or temporal_stride != 1
                or inplanes != planes * block.expansion):
            downsample = ConvModule(
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=(temporal_stride, spatial_stride, spatial_stride),
                bias=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None)

        layers = []
        # only the first block carries the strides and the shortcut projection
        layers.append(
            block(
                inplanes,
                planes,
                spatial_stride=spatial_stride,
                temporal_stride=temporal_stride,
                dilation=dilation,
                downsample=downsample,
                style=style,
                inflate=(inflate[0] == 1),
                inflate_style=inflate_style,
                non_local=(non_local[0] == 1),
                non_local_cfg=non_local_cfg,
                norm_cfg=norm_cfg,
                conv_cfg=conv_cfg,
                act_cfg=act_cfg,
                with_cp=with_cp,
                **kwargs))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    inplanes,
                    planes,
                    spatial_stride=1,
                    temporal_stride=1,
                    dilation=dilation,
                    style=style,
                    inflate=(inflate[i] == 1),
                    inflate_style=inflate_style,
                    non_local=(non_local[i] == 1),
                    non_local_cfg=non_local_cfg,
                    norm_cfg=norm_cfg,
                    conv_cfg=conv_cfg,
                    act_cfg=act_cfg,
                    with_cp=with_cp,
                    **kwargs))

        return nn.Sequential(*layers)

    def _inflate_conv_params(self, conv3d, state_dict_2d, module_name_2d,
                             inflated_param_names):
        """Inflate a conv module from 2d to 3d.

        Args:
            conv3d (nn.Module): The destination conv3d module.
            state_dict_2d (OrderedDict): The state dict of pretrained 2d model.
            module_name_2d (str): The name of corresponding conv module in the
                2d model.
            inflated_param_names (list[str]): List of parameters that have been
                inflated. Names copied here are appended to it.
        """
        weight_2d_name = module_name_2d + '.weight'

        conv2d_weight = state_dict_2d[weight_2d_name]
        kernel_t = conv3d.weight.data.shape[2]

        # Repeat the 2d kernel along the temporal axis (dim 2) and divide by
        # kernel_t, so the weights summed over time equal the 2d weights.
        new_weight = conv2d_weight.data.unsqueeze(2).expand_as(
            conv3d.weight) / kernel_t
        conv3d.weight.data.copy_(new_weight)
        inflated_param_names.append(weight_2d_name)

        if conv3d.bias is not None:
            bias_2d_name = module_name_2d + '.bias'
            conv3d.bias.data.copy_(state_dict_2d[bias_2d_name])
            inflated_param_names.append(bias_2d_name)

    def _inflate_bn_params(self, bn3d, state_dict_2d, module_name_2d,
                           inflated_param_names):
        """Inflate a norm module from 2d to 3d.

        Args:
            bn3d (nn.Module): The destination bn3d module.
            state_dict_2d (OrderedDict): The state dict of pretrained 2d model.
            module_name_2d (str): The name of corresponding bn module in the
                2d model.
            inflated_param_names (list[str]): List of parameters that have been
                inflated. Names copied here are appended to it.
        """
        for param_name, param in bn3d.named_parameters():
            param_2d_name = f'{module_name_2d}.{param_name}'
            param_2d = state_dict_2d[param_2d_name]
            param.data.copy_(param_2d)
            inflated_param_names.append(param_2d_name)

        for param_name, param in bn3d.named_buffers():
            param_2d_name = f'{module_name_2d}.{param_name}'
            # some buffers like num_batches_tracked may not exist in old
            # checkpoints
            if param_2d_name in state_dict_2d:
                param_2d = state_dict_2d[param_2d_name]
                param.data.copy_(param_2d)
                inflated_param_names.append(param_2d_name)

    def inflate_weights(self, logger):
        """Inflate the resnet2d parameters to resnet3d.

        The differences between resnet3d and resnet2d mainly lie in an extra
        axis of conv kernel. To utilize the pretrained parameters in 2d model,
        the weight of conv2d models should be inflated to fit in the shapes of
        the 3d counterpart.

        Args:
            logger (logging.Logger): The logger used to print
                debugging information.
        """

        state_dict_r2d = _load_checkpoint(self.pretrained)
        if 'state_dict' in state_dict_r2d:
            state_dict_r2d = state_dict_r2d['state_dict']

        inflated_param_names = []
        for name, module in self.named_modules():
            if isinstance(module, ConvModule):
                # we use a ConvModule to wrap conv+bn+relu layers, thus the
                # name mapping is needed
                if 'downsample' in name:
                    # layer{X}.{Y}.downsample.conv->layer{X}.{Y}.downsample.0
                    original_conv_name = name + '.0'
                    # layer{X}.{Y}.downsample.bn->layer{X}.{Y}.downsample.1
                    original_bn_name = name + '.1'
                else:
                    # layer{X}.{Y}.conv{n}.conv->layer{X}.{Y}.conv{n}
                    original_conv_name = name
                    # layer{X}.{Y}.conv{n}.bn->layer{X}.{Y}.bn{n}
                    original_bn_name = name.replace('conv', 'bn')
                if original_conv_name + '.weight' not in state_dict_r2d:
                    logger.warning(f'Module not exist in the state_dict_r2d'
                                   f': {original_conv_name}')
                else:
                    shape_2d = state_dict_r2d[original_conv_name +
                                              '.weight'].shape
                    shape_3d = module.conv.weight.data.shape
                    # compare with the temporal axis (dim 2) removed
                    if shape_2d != shape_3d[:2] + shape_3d[3:]:
                        logger.warning(f'Weight shape mismatch for '
                                       f': {original_conv_name} : '
                                       f'3d weight shape: {shape_3d}; '
                                       f'2d weight shape: {shape_2d}. ')
                    else:
                        self._inflate_conv_params(module.conv, state_dict_r2d,
                                                  original_conv_name,
                                                  inflated_param_names)

                if original_bn_name + '.weight' not in state_dict_r2d:
                    logger.warning(f'Module not exist in the state_dict_r2d'
                                   f': {original_bn_name}')
                else:
                    self._inflate_bn_params(module.bn, state_dict_r2d,
                                            original_bn_name,
                                            inflated_param_names)

        # check if any parameters in the 2d checkpoint are not loaded
        remaining_names = set(
            state_dict_r2d.keys()) - set(inflated_param_names)
        if remaining_names:
            logger.info(f'These parameters in the 2d checkpoint are not loaded'
                        f': {remaining_names}')

    def _make_stem_layer(self):
        """Construct the stem layers consists of a conv+norm+act module and a
        pooling layer."""
        self.conv1 = ConvModule(
            self.in_channels,
            self.base_channels,
            kernel_size=self.conv1_kernel,
            stride=(self.conv1_stride_t, 2, 2),
            # "same"-style padding for odd kernel sizes on every axis
            padding=tuple([(k - 1) // 2 for k in _triple(self.conv1_kernel)]),
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.maxpool = nn.MaxPool3d(
            kernel_size=(1, 3, 3),
            stride=(self.pool1_stride_t, 2, 2),
            padding=(0, 1, 1))

        # built unconditionally; forward only applies it when with_pool2
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 1, 1), stride=(2, 1, 1))

    def _freeze_stages(self):
        """Prevent all the parameters from being optimized before
        ``self.frozen_stages``."""
        if self.frozen_stages >= 0:
            self.conv1.eval()
            for param in self.conv1.parameters():
                param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'layer{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self):
        """Initiate the parameters either from existing checkpoint or from
        scratch.

        Raises:
            TypeError: If ``self.pretrained`` is neither a str nor None.
        """
        if isinstance(self.pretrained, str):
            logger = get_root_logger()
            logger.info(f'load model from: {self.pretrained}')

            if self.pretrained2d:
                # Inflate 2D model into 3D model.
                self.inflate_weights(logger)

            else:
                # Directly load 3D model.
                load_checkpoint(
                    self, self.pretrained, strict=False, logger=logger)

        elif self.pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv3d):
                    kaiming_init(m)
                elif isinstance(m, _BatchNorm):
                    constant_init(m, 1)

            if self.zero_init_residual:
                # zero the last norm of each block
                for m in self.modules():
                    if isinstance(m, Bottleneck3d):
                        constant_init(m.conv3.bn, 0)
                    elif isinstance(m, BasicBlock3d):
                        constant_init(m.conv2.bn, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor | tuple[torch.Tensor]: The feature of the input
            samples extracted by the backbone. A single tensor when exactly
            one stage is selected by ``out_indices``, otherwise a tuple.
        """
        x = self.conv1(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i == 0 and self.with_pool2:
                x = self.pool2(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]

        return tuple(outs)

    def train(self, mode=True):
        """Set the optimization status when training."""
        super().train(mode)
        # re-apply freezing because super().train() resets every submodule
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
class MobileNetV2(nn.Module):
    """MobileNetV2 backbone.

    Args:
        pretrained (str | None): Name of pretrained model. Default: None.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Default: 1.0.
        out_indices (Sequence[int]): Output from which stages. Every index
            must be in ``range(0, 8)``; index 7 selects ``conv2``.
            Default: (7, ).
        frozen_stages (int): Stages to be frozen (all param fixed), in
            ``range(-1, 9)``. 0 freezes only the stem ``conv1``; ``k > 0``
            additionally freezes the first ``k`` stages. Note that the last
            stage in ``MobileNetV2`` is ``conv2``. Default: -1, which means
            not freezing any parameters.
        conv_cfg (dict): Config dict for convolution layer.
            Default: dict(type='Conv').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN2d', requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6', inplace=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.

    Raises:
        ValueError: If an item of ``out_indices`` is outside ``range(0, 8)``
            or ``frozen_stages`` is outside ``range(-1, 9)``.
    """

    # Parameters to build layers. 4 parameters are needed to construct a
    # layer, from left to right: expand_ratio, channel, num_blocks, stride.
    arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],
                     [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2],
                     [6, 320, 1, 1]]

    def __init__(self,
                 pretrained=None,
                 widen_factor=1.,
                 out_indices=(7, ),
                 frozen_stages=-1,
                 conv_cfg=dict(type='Conv'),
                 norm_cfg=dict(type='BN2d', requires_grad=True),
                 act_cfg=dict(type='ReLU6', inplace=True),
                 norm_eval=False,
                 with_cp=False):
        super().__init__()
        self.pretrained = pretrained
        self.widen_factor = widen_factor
        for index in out_indices:
            if index not in range(0, 8):
                raise ValueError('the item in out_indices must in '
                                 f'range(0, 8). But received {index}')

        if frozen_stages not in range(-1, 9):
            raise ValueError('frozen_stages must be in range(-1, 9). '
                             f'But received {frozen_stages}')
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        # running channel count; make_layer updates it as blocks are stacked
        self.in_channels = make_divisible(32 * widen_factor, 8)

        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.layers = []

        for i, layer_cfg in enumerate(self.arch_settings):
            expand_ratio, channel, num_blocks, stride = layer_cfg
            out_channels = make_divisible(channel * widen_factor, 8)
            inverted_res_layer = self.make_layer(
                out_channels=out_channels,
                num_blocks=num_blocks,
                stride=stride,
                expand_ratio=expand_ratio)
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, inverted_res_layer)
            self.layers.append(layer_name)

        # the final 1x1 conv is only widened, never narrowed, by widen_factor
        if widen_factor > 1.0:
            self.out_channel = int(1280 * widen_factor)
        else:
            self.out_channel = 1280

        layer = ConvModule(
            in_channels=self.in_channels,
            out_channels=self.out_channel,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.add_module('conv2', layer)
        self.layers.append('conv2')

    def make_layer(self, out_channels, num_blocks, stride, expand_ratio):
        """Stack InvertedResidual blocks to build a layer for MobileNetV2.

        Args:
            out_channels (int): out_channels of block.
            num_blocks (int): number of blocks.
            stride (int): stride of the first block; the remaining blocks
                use stride 1.
            expand_ratio (int): Expand the number of channels of the
                hidden layer in InvertedResidual by this ratio.

        Returns:
            nn.Sequential: The stacked blocks. ``self.in_channels`` is set
            to ``out_channels`` as a side effect.
        """
        layers = []
        for i in range(num_blocks):
            if i >= 1:
                stride = 1
            layers.append(
                InvertedResidual(
                    self.in_channels,
                    out_channels,
                    stride,
                    expand_ratio=expand_ratio,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    with_cp=self.with_cp))
            self.in_channels = out_channels

        return nn.Sequential(*layers)

    def init_weights(self):
        """Load ``self.pretrained`` if it is a path, else initialize from
        scratch.

        Raises:
            TypeError: If ``self.pretrained`` is neither a str nor None.
        """
        if isinstance(self.pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, self.pretrained, strict=False, logger=logger)
        elif self.pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the stem and every stage in order.

        Returns:
            torch.Tensor | tuple[torch.Tensor]: The outputs of the stages
            selected by ``out_indices``; a bare tensor when exactly one is
            selected.
        """
        x = self.conv1(x)

        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        if len(outs) == 1:
            return outs[0]

        return tuple(outs)

    def _freeze_stages(self):
        """Put the stem and the first ``frozen_stages`` stages in eval mode
        and stop their gradients."""
        if self.frozen_stages >= 0:
            self.conv1.eval()
            for param in self.conv1.parameters():
                param.requires_grad = False
        for i in range(1, self.frozen_stages + 1):
            layer_name = self.layers[i - 1]
            layer = getattr(self, layer_name)
            layer.eval()
            for param in layer.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Set train/eval mode, then re-apply stage freezing and
        ``norm_eval``."""
        super(MobileNetV2, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
class X3D(nn.Module):
    """X3D backbone. https://arxiv.org/pdf/2004.04730.pdf.

    Args:
        gamma_w (float): Global channel width expansion factor. Default: 1.
        gamma_b (float): Bottleneck channel width expansion factor. Default: 1.
        gamma_d (float): Network depth expansion factor. Default: 1.
        pretrained (str | None): Name of pretrained model. Default: None.
        in_channels (int): Channel num of input features. Default: 3.
        num_stages (int): Resnet stages. Default: 4.
        spatial_strides (Sequence[int]):
            Spatial strides of residual blocks of each stage.
            Default: ``(2, 2, 2, 2)``.
        frozen_stages (int): Stages to be frozen (all param fixed). If set to
            -1, it means not freezing any parameters. Default: -1.
        se_style (str): The style of inserting SE modules into BlockX3D, 'half'
            denotes insert into half of the blocks, while 'all' denotes insert
            into all blocks. Default: 'half'.
        se_ratio (float | None): The reduction ratio of squeeze and excitation
            unit. If set as None, it means not using SE unit. Default: 1 / 16.
        use_swish (bool): Whether to use swish as the activation function
            before and after the 3x3x3 conv. Default: True.
        conv_cfg (dict): Config for conv layers. Required keys are ``type``.
            Default: ``dict(type='Conv3d')``.
        norm_cfg (dict): Config for norm layers. Required keys are ``type``
            and ``requires_grad``.
            Default: ``dict(type='BN3d', requires_grad=True)``.
        act_cfg (dict): Config dict for activation layer.
            Default: ``dict(type='ReLU', inplace=True)``.
        norm_eval (bool): Whether to set BN layers to eval mode, namely, freeze
            running stats (mean and var). Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): Whether to use zero initialization for
            the last norm layer of each residual block. Default: True.
        kwargs (dict, optional): Key arguments for "make_res_layer".
    """

    def __init__(self,
                 gamma_w=1.0,
                 gamma_b=1.0,
                 gamma_d=1.0,
                 pretrained=None,
                 in_channels=3,
                 num_stages=4,
                 spatial_strides=(2, 2, 2, 2),
                 frozen_stages=-1,
                 se_style='half',
                 se_ratio=1 / 16,
                 use_swish=True,
                 conv_cfg=dict(type='Conv3d'),
                 norm_cfg=dict(type='BN3d', requires_grad=True),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=True,
                 **kwargs):
        super().__init__()
        self.gamma_w = gamma_w
        self.gamma_b = gamma_b
        self.gamma_d = gamma_d

        self.pretrained = pretrained
        self.in_channels = in_channels
        # Hard coded, can be changed by gamma_w
        self.base_channels = 24
        self.stage_blocks = [1, 2, 5, 3]

        # apply parameters gamma_w and gamma_d
        self.base_channels = self._round_width(self.base_channels,
                                               self.gamma_w)

        self.stage_blocks = [
            self._round_repeats(x, self.gamma_d) for x in self.stage_blocks
        ]

        self.num_stages = num_stages
        assert 1 <= num_stages <= 4
        self.spatial_strides = spatial_strides
        assert len(spatial_strides) == num_stages
        self.frozen_stages = frozen_stages

        self.se_style = se_style
        assert self.se_style in ['all', 'half']
        self.se_ratio = se_ratio
        assert (self.se_ratio is None) or (self.se_ratio > 0)
        self.use_swish = use_swish

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        self.block = BlockX3D
        self.stage_blocks = self.stage_blocks[:num_stages]
        # running input channel count of the next stage
        self.layer_inplanes = self.base_channels
        self._make_stem_layer()

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            spatial_stride = spatial_strides[i]
            inplanes = self.base_channels * 2**i
            planes = int(inplanes * self.gamma_b)

            res_layer = self.make_res_layer(
                self.block,
                self.layer_inplanes,
                inplanes,
                planes,
                num_blocks,
                spatial_stride=spatial_stride,
                se_style=self.se_style,
                se_ratio=self.se_ratio,
                use_swish=self.use_swish,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg,
                act_cfg=self.act_cfg,
                with_cp=with_cp,
                **kwargs)
            self.layer_inplanes = inplanes
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self.feat_dim = self.base_channels * 2**(len(self.stage_blocks) - 1)
        self.conv5 = ConvModule(
            self.feat_dim,
            int(self.feat_dim * self.gamma_b),
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.feat_dim = int(self.feat_dim * self.gamma_b)

    @staticmethod
    def _round_width(width, multiplier, min_depth=8, divisor=8):
        """Round width of filters based on width multiplier.

        The scaled width is rounded to the nearest multiple of ``divisor``
        (at least ``min_depth``) and bumped up by one ``divisor`` if rounding
        lost more than 10%. A falsy ``multiplier`` returns ``width``
        unchanged.
        """
        if not multiplier:
            return width

        width *= multiplier
        min_depth = min_depth or divisor
        new_filters = max(min_depth,
                          int(width + divisor / 2) // divisor * divisor)
        if new_filters < 0.9 * width:
            new_filters += divisor
        return int(new_filters)

    @staticmethod
    def _round_repeats(repeats, multiplier):
        """Round number of layers based on depth multiplier (ceiling).

        A falsy ``multiplier`` returns ``repeats`` unchanged.
        """
        if not multiplier:
            return repeats
        return int(math.ceil(multiplier * repeats))

    # the module is parameterized with gamma_b
    # no temporal_stride
    def make_res_layer(self,
                       block,
                       layer_inplanes,
                       inplanes,
                       planes,
                       blocks,
                       spatial_stride=1,
                       se_style='half',
                       se_ratio=None,
                       use_swish=True,
                       norm_cfg=None,
                       act_cfg=None,
                       conv_cfg=None,
                       with_cp=False,
                       **kwargs):
        """Build residual layer for X3D.

        Args:
            block (nn.Module): Residual module to be built.
            layer_inplanes (int): Number of channels for the input feature
                of the res layer.
            inplanes (int): Number of channels for the input feature in each
                block after the first one, which is also the channel count
                the shortcut projects to. Equals base_channels * 2**i in the
                caller.
            planes (int): Number of channels passed to each block as its
                second positional argument, which equals inplanes * gamma_b
                in the caller.
            blocks (int): Number of residual blocks.
            spatial_stride (int): Spatial strides in residual and conv layers.
                Default: 1.
            se_style (str): The style of inserting SE modules into BlockX3D,
                'half' denotes insert into half of the blocks (the even-indexed
                ones), while 'all' denotes insert into all blocks.
                Default: 'half'.
            se_ratio (float | None): The reduction ratio of squeeze and
                excitation unit. If set as None, it means not using SE unit.
                Default: None.
            use_swish (bool): Whether to use swish as the activation function
                before and after the 3x3x3 conv. Default: True.
            conv_cfg (dict | None): Config for conv layers. Default: None.
            norm_cfg (dict | None): Config for norm layers. Default: None.
            act_cfg (dict | None): Config for activate layers. Default: None.
            with_cp (bool | None): Use checkpoint or not. Using checkpoint
                will save some memory while slowing down the training speed.
                Default: False.

        Returns:
            nn.Module: A residual layer for the given config.

        Raises:
            NotImplementedError: If ``se_style`` is neither 'all' nor 'half'.
        """
        downsample = None
        if spatial_stride != 1 or layer_inplanes != inplanes:
            downsample = ConvModule(
                layer_inplanes,
                inplanes,
                kernel_size=1,
                stride=(1, spatial_stride, spatial_stride),
                padding=0,
                bias=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None)

        # honour the se_style argument rather than the instance attribute
        if se_style == 'all':
            use_se = [True] * blocks
        elif se_style == 'half':
            use_se = [i % 2 == 0 for i in range(blocks)]
        else:
            raise NotImplementedError

        layers = []
        # only the first block changes resolution / channels
        layers.append(
            block(
                layer_inplanes,
                planes,
                inplanes,
                spatial_stride=spatial_stride,
                downsample=downsample,
                se_ratio=se_ratio if use_se[0] else None,
                use_swish=use_swish,
                norm_cfg=norm_cfg,
                conv_cfg=conv_cfg,
                act_cfg=act_cfg,
                with_cp=with_cp,
                **kwargs))

        for i in range(1, blocks):
            layers.append(
                block(
                    inplanes,
                    planes,
                    inplanes,
                    spatial_stride=1,
                    se_ratio=se_ratio if use_se[i] else None,
                    use_swish=use_swish,
                    norm_cfg=norm_cfg,
                    conv_cfg=conv_cfg,
                    act_cfg=act_cfg,
                    with_cp=with_cp,
                    **kwargs))

        return nn.Sequential(*layers)

    def _make_stem_layer(self):
        """Construct the stem layers consists of a conv+norm+act module and a
        pooling layer."""
        # spatial-only conv (no norm/act), followed by a depthwise temporal
        # conv that carries the norm and activation
        self.conv1_s = ConvModule(
            self.in_channels,
            self.base_channels,
            kernel_size=(1, 3, 3),
            stride=(1, 2, 2),
            padding=(0, 1, 1),
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=None,
            act_cfg=None)
        self.conv1_t = ConvModule(
            self.base_channels,
            self.base_channels,
            kernel_size=(5, 1, 1),
            stride=(1, 1, 1),
            padding=(2, 0, 0),
            groups=self.base_channels,
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _freeze_stages(self):
        """Prevent all the parameters from being optimized before
        ``self.frozen_stages``."""
        if self.frozen_stages >= 0:
            self.conv1_s.eval()
            self.conv1_t.eval()
            for param in self.conv1_s.parameters():
                param.requires_grad = False
            for param in self.conv1_t.parameters():
                param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'layer{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self):
        """Initiate the parameters either from existing checkpoint or from
        scratch.

        Raises:
            TypeError: If ``self.pretrained`` is neither a str nor None.
        """
        if isinstance(self.pretrained, str):
            logger = get_root_logger()
            logger.info(f'load model from: {self.pretrained}')

            load_checkpoint(self, self.pretrained, strict=False, logger=logger)

        elif self.pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv3d):
                    kaiming_init(m)
                elif isinstance(m, _BatchNorm):
                    constant_init(m, 1)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, BlockX3D):
                        constant_init(m.conv3.bn, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The feature of the input
            samples extracted by the backbone.
        """
        x = self.conv1_s(x)
        x = self.conv1_t(x)
        for layer_name in self.res_layers:
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
        x = self.conv5(x)
        return x

    def train(self, mode=True):
        """Set the optimization status when training."""
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
class EfficientNet(nn.Module):
    """EfficientNet classifier: conv stem, MBConv blocks and a linear head.

    Args:
        in_channels (int): Channel num of input features. Default: 3.
        n_classes (int): Output size of the final ``nn.Linear``. Default: 13.
        scale (int): Compound-scaling index in ``range(0, 8)``; selects the
            (width, depth, dropout) coefficients from ``param_dict``.
            Default: 1.
        se_rate (float): Forwarded to every MBConv block as ``se_rate``.
            Default: 0.25.
        drop_connect_rate (float): Upper bound of the per-block drop-connect
            rate; the k-th of N blocks uses ``drop_connect_rate * k / N``.
            Default: 0.2.
        frozen_stages (int): -1 freezes nothing; 0 freezes only the stem;
            ``k > 0`` additionally freezes every block of the first ``k``
            stages. Default: -1.
        norm_cfg (dict): Config dict for normalization layers.
            Default: ``dict(type='SyncBN', momentum=0.01, eps=1e-3)``.
        act_cfg (dict): Config dict for activation layers.
            Default: ``dict(type='Swish')``.
        pretrained (str | None): Checkpoint to load in ``__init__``; None
            applies ``init_weights`` to every submodule. Default: None.

    Raises:
        TypeError: If ``pretrained`` is neither a str nor None.
    """

    # maps scale to coefficients of (width, depth, dropout)
    param_dict = {
        0: (1.0, 1.0, 0.2),
        1: (1.0, 1.1, 0.2),
        2: (1.1, 1.2, 0.3),
        3: (1.2, 1.4, 0.3),
        4: (1.4, 1.8, 0.4),
        5: (1.6, 2.2, 0.4),
        6: (1.8, 2.6, 0.5),
        7: (2.0, 3.1, 0.5),
    }

    def __init__(self,
                 in_channels: int = 3,
                 n_classes: int = 13,
                 scale: int = 1,
                 se_rate: float = 0.25,
                 drop_connect_rate: float = 0.2,
                 frozen_stages: int = -1,
                 norm_cfg=dict(type='SyncBN', momentum=0.01, eps=1e-3),
                 act_cfg=dict(type='Swish'),
                 pretrained=None):
        super(EfficientNet, self).__init__()
        assert scale in range(0, 8)
        self.frozen_stages = frozen_stages
        self.in_channels = in_channels
        width_coefficient, depth_coefficient, dropout_rate = self.param_dict[
            scale]
        self.width_coefficient = width_coefficient
        self.depth_coefficient = depth_coefficient
        self.divisor = 8
        self.n_classes = n_classes
        self.pretrained = pretrained

        list_channels = [32, 16, 24, 40, 80, 112, 192, 320, 1280]
        list_channels = [self._setup_channels(c) for c in list_channels]
        self.list_channels = list_channels

        list_num_repeats = [1, 2, 2, 3, 3, 4, 1]
        list_num_repeats = [self._setup_repeats(r) for r in list_num_repeats]
        self.list_num_repeats = list_num_repeats

        expand_rates = [1, 6, 6, 6, 6, 6, 6]
        strides = [1, 2, 2, 2, 1, 2, 1]
        kernel_sizes = [3, 3, 5, 3, 5, 5, 3]

        # Define stem
        self.stem = ConvModule(
            in_channels=self.in_channels,
            out_channels=self.list_channels[0],
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        # Define blocks
        blocks = []
        counter = 0
        num_blocks = sum(self.list_num_repeats)
        for idx in range(7):

            num_channels = self.list_channels[idx]
            next_num_channels = self.list_channels[idx + 1]
            num_repeats = self.list_num_repeats[idx]
            expand_rate = expand_rates[idx]
            kernel_size = kernel_sizes[idx]
            stride = strides[idx]
            # drop rate increases as depth increases
            drop_rate = drop_connect_rate * counter / num_blocks

            name = 'MBConv{}_{}'.format(expand_rate, counter)
            blocks.append((name,
                           MBConv(
                               num_channels,
                               next_num_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               expand_rate=expand_rate,
                               se_rate=se_rate,
                               drop_connect_rate=drop_rate,
                               norm_cfg=norm_cfg,
                               act_cfg=act_cfg)))
            counter += 1
            # remaining repeats keep the channel count and use stride 1
            for i in range(1, num_repeats):
                name = 'MBConv{}_{}'.format(expand_rate, counter)
                drop_rate = drop_connect_rate * counter / num_blocks
                blocks.append((name,
                               MBConv(
                                   next_num_channels,
                                   next_num_channels,
                                   kernel_size=kernel_size,
                                   stride=1,
                                   expand_rate=expand_rate,
                                   se_rate=se_rate,
                                   drop_connect_rate=drop_rate,
                                   norm_cfg=norm_cfg,
                                   act_cfg=act_cfg)))
                counter += 1

        self.blocks = nn.Sequential(OrderedDict(blocks))

        # Define head
        self.head = nn.Sequential(
            ConvModule(
                self.list_channels[-2],
                self.list_channels[-1],
                kernel_size=1,
                bias=False,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg), nn.AdaptiveAvgPool2d(1), Flatten(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(self.list_channels[-1], self.n_classes))

        if isinstance(pretrained, str):
            load_checkpoint(self, pretrained)
        elif pretrained is None:
            self.apply(init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

        self.freeze_stages()

    def _setup_repeats(self, num_repeats):
        """Scale a block count by ``depth_coefficient``, rounding up."""
        return int(math.ceil(self.depth_coefficient * num_repeats))

    def _setup_channels(self, num_channels):
        """To ensure the new number of channels can be divided by divisor, for
        example 8."""
        num_channels *= self.width_coefficient
        new_num_channels = math.floor(num_channels / self.divisor +
                                      0.5) * self.divisor
        # To ensure the new number of channels are greater or equal to divisor
        new_num_channels = max(self.divisor, new_num_channels)
        # To avoid number of channels shrink too much
        if new_num_channels < 0.9 * num_channels:
            new_num_channels += self.divisor
        return new_num_channels

    def freeze_stages(self):
        """Put the stem and the blocks of the first ``frozen_stages`` stages
        in eval mode and stop their gradients."""
        if self.frozen_stages >= 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False

        # blocks are stored flat; walk them stage by stage
        idx = 0
        for i in range(1, self.frozen_stages + 1):
            for _ in range(self.list_num_repeats[i - 1]):
                self.blocks[idx].eval()
                for param in self.blocks[idx].parameters():
                    param.requires_grad = False
                idx += 1

    def train(self, mode=True):
        """Set train/eval mode while keeping frozen stages in eval mode.

        ``nn.Module.train`` switches every submodule, which would undo the
        freezing done in ``__init__``; re-apply it afterwards.

        Returns:
            EfficientNet: ``self``, as ``nn.Module.train`` does.
        """
        super(EfficientNet, self).train(mode)
        self.freeze_stages()
        return self

    def forward(self, x):
        """Return the head's output for input ``x``."""
        f = self.stem(x)
        f = self.blocks(f)
        y = self.head(f)
        return y