def plugin_to_basicblock(module: nn.Module, ratio) -> nn.Module:
    """Recursively replace every `BasicBlock` in `module` with the `BasicBlock`
    class in scope here, constructed with `ratio`, carrying over the weights.

    A module is treated as a basic block when its class name contains
    ``'BasicBlock'``. Its conv1/bn1 and conv2/bn2 parameters are copied into
    the new block via `param_util`. Any other module is traversed and its
    converted children are swapped in place.

    Args:
        module: (nn.Module): containing module
        ratio: (float) reduction ratio passed to the new `BasicBlock`

    Returns:
        If `module` itself is a basic block, the newly built replacement block;
        otherwise `module` itself, with its basic blocks replaced in place.

    Warns:
        UserWarning: when a norm layer is neither a `_BatchNorm` nor a
        `GroupNorm`, so its parameters cannot be copied and the converted block
        keeps freshly initialized ones.
    """
    import warnings

    classname = module.__class__.__name__
    if classname.find('BasicBlock') == -1:
        # Not a basic block: recurse and rebind converted children in place.
        # Iterate over a snapshot so replacing children while walking them is
        # unambiguous.
        for name, sub_module in list(module.named_children()):
            module.add_module(name, plugin_to_basicblock(sub_module, ratio))
        return module

    # NOTE(review): the replacement is built with default device/dtype; whether
    # it ends up matching `module` depends on param_util's copy semantics.
    module_output = BasicBlock(module.conv1.in_channels,
                               module.conv1.out_channels,
                               ratio=ratio,
                               stride=module.stride,
                               downsample=module.downsample)
    for conv_name, norm_name in (('conv1', 'bn1'), ('conv2', 'bn2')):
        param_util.copy_conv_parameters(getattr(module, conv_name),
                                        getattr(module_output, conv_name))
        src_norm = getattr(module, norm_name)
        dst_norm = getattr(module_output, norm_name)
        if isinstance(src_norm, nn.modules.batchnorm._BatchNorm):
            param_util.copy_bn_parameters(src_norm, dst_norm)
        elif isinstance(src_norm, nn.GroupNorm):
            param_util.copy_weight_bias(src_norm, dst_norm)
        else:
            # Previously skipped silently, leaving random-initialized norms.
            warnings.warn(
                'plugin_to_basicblock: {} of type {} is not copied; the '
                'converted block keeps freshly initialized {} parameters'.format(
                    norm_name, type(src_norm).__name__, norm_name))
    # A freshly constructed module starts in training mode; keep the mode of
    # the block it replaces so an eval-mode model stays in eval mode.
    module_output.train(module.training)
    return module_output
def convert_conv2d_with_ws(module):
    """Recursively replace every `nn.Conv2d` in `module` with `Conv2D` (the
    `Conv2D with WS` layer), copying the convolution parameters across.

    The replacement is built with the same in/out channels, kernel size,
    stride, padding, dilation, groups, bias flag and padding mode as the
    original convolution. Its parameters are then copied with
    `param_util.copy_conv_parameters`.

    Args:
        module: (nn.Module): containing module

    Returns:
        If `module` itself is an `nn.Conv2d`, its `Conv2D` replacement;
        otherwise `module` itself, with its convolutions replaced in place.

    Example::
        >>> from simplecv.module import ResNetEncoder
        >>> m = ResNetEncoder({})
        >>> m = convert_conv2d_with_ws(m)
    """
    module_output = module
    # Match on type rather than on the class name: a name test such as
    # 'Conv' in classname also catches Conv1d/Conv3d, ConvTranspose2d and
    # containers like `ConvBlock`. Those either lack `in_channels` or would be
    # turned into the wrong layer type.
    if isinstance(module, nn.Conv2d):
        # NOTE(review): the replacement is built with default device/dtype;
        # whether it matches `module` depends on copy_conv_parameters.
        module_output = Conv2D(module.in_channels, module.out_channels,
                               module.kernel_size, module.stride,
                               module.padding, module.dilation,
                               module.groups, module.bias is not None,
                               module.padding_mode)
        param_util.copy_conv_parameters(module, module_output)
    # Children (if any) are converted and attached to the output module; for
    # a non-converted module this rebinds them in place.
    for name, sub_module in list(module.named_children()):
        module_output.add_module(name, convert_conv2d_with_ws(sub_module))
    return module_output
def plugin_to_resnet(module: nn.Module, ratio) -> nn.Module:
    """Recursively replace every `Bottleneck` in `module` with the
    `context_block.Bottleneck` built with `ratio`, carrying over the weights.

    A module is treated as a bottleneck when its class name contains
    ``'Bottleneck'``. Its conv1/bn1, conv2/bn2 and conv3/bn3 parameters are
    copied into the new block via `param_util`. Any other module is traversed
    and its converted children are swapped in place.

    Args:
        module: (nn.Module): containing module
        ratio: (float) reduction ratio

    Returns:
        If `module` itself is a bottleneck, the newly built replacement block;
        otherwise `module` itself, with its bottlenecks replaced in place.

    Warns:
        UserWarning: when a norm layer is neither a `_BatchNorm` nor a
        `GroupNorm`, so its parameters cannot be copied and the converted block
        keeps freshly initialized ones.

    Example::
        >>> # ratio 1/16 context blocks in stages c3-c5 (resnet layer2-layer4)
        >>> from simplecv.module import ResNetEncoder
        >>> m = ResNetEncoder({})
        >>> m.resnet.layer2 = plugin_to_resnet(m.resnet.layer2, 1 / 16.)
        >>> m.resnet.layer3 = plugin_to_resnet(m.resnet.layer3, 1 / 16.)
        >>> m.resnet.layer4 = plugin_to_resnet(m.resnet.layer4, 1 / 16.)
    """
    import warnings

    classname = module.__class__.__name__
    if classname.find('Bottleneck') == -1:
        # Not a bottleneck: recurse and rebind converted children in place.
        # Iterate over a snapshot so replacing children while walking them is
        # unambiguous.
        for name, sub_module in list(module.named_children()):
            module.add_module(name, plugin_to_resnet(sub_module, ratio))
        return module

    # NOTE(review): only in/out channels, stride and downsample are forwarded.
    # Dilation, groups or width of the source block are not, so verify the
    # result for dilated or ResNeXt-style backbones.
    # The replacement is also built with default device/dtype.
    module_output = Bottleneck(module.conv1.in_channels,
                               module.conv1.out_channels,
                               ratio=ratio,
                               stride=module.stride,
                               downsample=module.downsample)
    for conv_name, norm_name in (('conv1', 'bn1'),
                                 ('conv2', 'bn2'),
                                 ('conv3', 'bn3')):
        param_util.copy_conv_parameters(getattr(module, conv_name),
                                        getattr(module_output, conv_name))
        src_norm = getattr(module, norm_name)
        dst_norm = getattr(module_output, norm_name)
        if isinstance(src_norm, nn.modules.batchnorm._BatchNorm):
            param_util.copy_bn_parameters(src_norm, dst_norm)
        elif isinstance(src_norm, nn.GroupNorm):
            param_util.copy_weight_bias(src_norm, dst_norm)
        else:
            # Previously skipped silently, leaving random-initialized norms.
            warnings.warn(
                'plugin_to_resnet: {} of type {} is not copied; the '
                'converted block keeps freshly initialized {} parameters'.format(
                    norm_name, type(src_norm).__name__, norm_name))
    # A freshly constructed module starts in training mode; keep the mode of
    # the block it replaces so an eval-mode model stays in eval mode.
    module_output.train(module.training)
    return module_output