def __init__(self, in_chs, out_chs, dw_kernel_size=3,
             stride=1, pad_type='', act_fn=F.relu, noskip=False,
             exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
             se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
             shuffle_type=None, bn_args=_BN_ARGS_PT, drop_connect_rate=0.):
    super(InvertedResidual, self).__init__()
    mid_chs = int(in_chs * exp_ratio)
    self.has_se = se_ratio is not None and se_ratio > 0.
    self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
    self.act_fn = act_fn
    self.drop_connect_rate = drop_connect_rate

    # Point-wise expansion
    self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
    self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args)

    self.shuffle_type = shuffle_type
    if shuffle_type is not None and isinstance(exp_kernel_size, list):
        self.shuffle = ChannelShuffle(len(exp_kernel_size))

    # Depth-wise convolution
    self.conv_dw = select_conv2d(
        mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
    self.bn2 = nn.BatchNorm2d(mid_chs, **bn_args)

    # Squeeze-and-excitation
    if self.has_se:
        se_base_chs = mid_chs if se_reduce_mid else in_chs
        self.se = SqueezeExcite(
            mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)),
            act_fn=act_fn, gate_fn=se_gate_fn)

    # Point-wise linear projection
    self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
    self.bn3 = nn.BatchNorm2d(out_chs, **bn_args)
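# Usage sketch for the inverted residual block (illustrative only; parameter
# values are arbitrary and it is assumed that the class's forward() applies the
# expansion -> depth-wise -> (SE) -> projection sequence built above):
#
#   import torch
#   block = InvertedResidual(in_chs=16, out_chs=24, dw_kernel_size=3,
#                            stride=2, exp_ratio=6.0, se_ratio=0.25)
#   y = block(torch.randn(1, 16, 56, 56))  # expected shape: [1, 24, 28, 28]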
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
             stride=1, pad_type='', act_fn=F.relu, noskip=False,
             pw_kernel_size=1, pw_act=False,
             se_ratio=0., se_gate_fn=sigmoid,
             bn_args=_BN_ARGS_PT, drop_connect_rate=0.):
    super(DepthwiseSeparableConv, self).__init__()
    assert stride in [1, 2]
    self.has_se = se_ratio is not None and se_ratio > 0.
    self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
    self.has_pw_act = pw_act  # activation after point-wise conv
    self.act_fn = act_fn
    self.drop_connect_rate = drop_connect_rate

    self.conv_dw = select_conv2d(
        in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
    self.bn1 = nn.BatchNorm2d(in_chs, **bn_args)

    # Squeeze-and-excitation
    if self.has_se:
        self.se = SqueezeExcite(
            in_chs, reduce_chs=max(1, int(in_chs * se_ratio)),
            act_fn=act_fn, gate_fn=se_gate_fn)

    self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
    self.bn2 = nn.BatchNorm2d(out_chs, **bn_args)
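# Usage sketch (illustrative; values are arbitrary and the class's forward() is
# presumed to run conv_dw -> bn1 -> act -> (se) -> conv_pw -> bn2 as wired above):
#
#   block = DepthwiseSeparableConv(in_chs=32, out_chs=16, dw_kernel_size=3,
#                                  stride=1, se_ratio=0.25)
#   y = block(torch.randn(1, 32, 112, 112))  # expected shape: [1, 16, 112, 112]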
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
             channel_multiplier=1.0, channel_divisor=8, channel_min=None,
             pad_type='', act_fn=F.relu, drop_rate=0., drop_connect_rate=0.,
             se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=_BN_ARGS_PT,
             global_pool='avg', head_conv='default', weight_init='goog'):
    super(GenMUXNet, self).__init__()
    self.num_classes = num_classes
    self.drop_rate = drop_rate
    self.act_fn = act_fn
    self.num_features = num_features

    stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
    self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
    self.bn1 = nn.BatchNorm2d(stem_size, **bn_args)
    in_chs = stem_size

    builder = _BlockBuilder(
        channel_multiplier, channel_divisor, channel_min,
        pad_type, act_fn, se_gate_fn, se_reduce_mid,
        bn_args, drop_connect_rate, verbose=_DEBUG)
    self.blocks = nn.Sequential(*builder(in_chs, block_args))
    in_chs = builder.in_chs

    if not head_conv or head_conv == 'none':
        self.efficient_head = False
        self.conv_head = None
        assert in_chs == self.num_features
    else:
        self.efficient_head = head_conv == 'efficient'
        self.conv_head = select_conv2d(in_chs, self.num_features, 1, padding=pad_type)
        self.bn2 = None if self.efficient_head else nn.BatchNorm2d(self.num_features, **bn_args)

    self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
    self.classifier = nn.Linear(
        self.num_features * self.global_pool.feat_mult(), self.num_classes)

    for m in self.modules():
        if weight_init == 'goog':
            _initialize_weight_goog(m)
        else:
            _initialize_weight_default(m)
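# Usage sketch (illustrative; `block_args` must be a decoded architecture
# definition, i.e. a per-stage list of block-argument dicts consumed by
# _BlockBuilder, normally produced by this module's builder utilities rather
# than written by hand):
#
#   model = GenMUXNet(block_args, num_classes=1000, channel_multiplier=1.0,
#                     drop_rate=0.2, head_conv='default')
#   logits = model(torch.randn(1, 3, 224, 224))  # expected shape: [1, 1000]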
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
             stride=1, pad_type='', act_fn=F.relu, noskip=False,
             exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
             se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
             shuffle_type=None, bn_args=_BN_ARGS_PT, drop_connect_rate=0.,
             split_ratio=0.75, shuffle_groups=2, dw_group_factor=1, scales=0):
    super(MuxInvertedResidual, self).__init__()
    assert in_chs == out_chs, "should only be used when input channels == output channels"
    assert stride < 2, "should NOT be used to down-sample"

    self.split = SplitBlock(split_ratio)
    in_chs = int(in_chs * split_ratio)
    out_chs = int(out_chs * split_ratio)

    mid_chs = int(in_chs * exp_ratio)
    self.has_se = se_ratio is not None and se_ratio > 0.
    self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
    self.act_fn = act_fn
    self.drop_connect_rate = drop_connect_rate

    # Point-wise expansion
    self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
    self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args)

    # Depth-wise/group-wise convolution
    self.conv_dw = select_conv2d(
        mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type,
        groups=mid_chs // dw_group_factor, scales=scales)
    self.bn2 = nn.BatchNorm2d(mid_chs, **bn_args)

    # Squeeze-and-excitation
    if self.has_se:
        se_base_chs = mid_chs if se_reduce_mid else in_chs
        self.se = SqueezeExcite(
            mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)),
            act_fn=act_fn, gate_fn=se_gate_fn)

    # Point-wise linear projection
    self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
    self.bn3 = nn.BatchNorm2d(out_chs, **bn_args)

    self.shuffle = ChannelShuffle(groups=shuffle_groups)
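# Usage sketch (illustrative; in_chs must equal out_chs and stride must be 1, as
# asserted above; presumably the forward() splits channels via SplitBlock, runs
# only the split_ratio fraction through the layers built here, concatenates the
# untouched remainder, and finishes with the channel shuffle):
#
#   block = MuxInvertedResidual(in_chs=64, out_chs=64, exp_ratio=3.0,
#                               split_ratio=0.75, shuffle_groups=2)
#   y = block(torch.randn(1, 64, 28, 28))  # expected shape: [1, 64, 28, 28]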