def remove_batchnorm(self) -> 'ResBlockBottleneck':
    """Strip batch normalization from this block, in place.

    The module-level ``remove_batchnorm`` helper is applied to the
    ``branch`` and then the ``shortcut`` sub-modules, after which
    ``branch.conv3`` is passed through ``constant_init(..., 0)``.

    Returns:
        ``self``, so the call can be chained.

    Raises:
        AssertionError: if ``branch.conv3`` is not an ``nn.Conv2d``.
    """
    # Inside the method body this name resolves to the module-level helper,
    # not to this method. Attributes are looked up one at a time so each
    # sub-module is fetched right before it is processed.
    for submodule_name in ('branch', 'shortcut'):
        remove_batchnorm(getattr(self, submodule_name))

    final_conv = self.branch.conv3
    assert isinstance(final_conv, nn.Conv2d)
    # NOTE(review): presumably re-zeroes the last conv so the residual branch
    # starts out contributing nothing -- confirm against constant_init.
    constant_init(final_conv, 0)
    return self
def __init__(self,
             in_channels: int,
             out_channels: int,
             stride: int = 1) -> None:
    """Build the pre-activation residual block.

    Args:
        in_channels: number of channels given to the first batch norm and
            the first 3x3 convolution.
        out_channels: number of channels produced by both convolutions.
        stride: stride given to the first 3x3 convolution and to the
            shortcut factory.
    """
    super(PreactResBlock, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.stride = stride

    # Empty containers; nothing is added to them here.
    self.pre = CondSeq()
    self.preact = CondSeq()

    # Layers are instantiated in the same order they are stored, so any
    # random initialization happens in a fixed sequence.
    layers = collections.OrderedDict()
    layers['bn1'] = constant_init(nn.BatchNorm2d(in_channels), 1)
    layers['relu'] = nn.ReLU(inplace=True)
    layers['conv1'] = kaiming(
        Conv3x3(in_channels, out_channels, stride=stride, bias=False))
    layers['bn2'] = constant_init(nn.BatchNorm2d(out_channels), 1)
    layers['relu2'] = nn.ReLU(inplace=True)
    # NOTE(review): presumably zero-initialized so the residual branch
    # starts out contributing nothing -- confirm against constant_init.
    layers['conv2'] = constant_init(Conv3x3(out_channels, out_channels), 0)
    self.branch = CondSeq(layers)

    self.shortcut = make_preact_resnet_shortcut(in_channels, out_channels,
                                                stride)

    # preact_skip() is only invoked when the channel count changes.
    if in_channels != out_channels:
        self.preact_skip()

    self.post = CondSeq()
def wide(self, divisor: int = 2) -> 'PreactResBlockBottleneck':
    """Replace ``self.branch`` with a bottleneck of a different width.

    The new branch is BN-ReLU-Conv1x1, BN-ReLU-Conv3x3, BN-ReLU-Conv1x1,
    where the two inner stages use ``out_channels // divisor`` channels.
    Only ``self.branch`` is reassigned; no other attribute is written.

    Args:
        divisor: the hidden width is ``self.out_channels // divisor``.

    Returns:
        ``self``, so the call can be chained.
    """
    c_in = self.in_channels
    c_out = self.out_channels
    conv_stride = self.stride
    hidden = c_out // divisor

    # Layers are instantiated in the same order they are stored, so any
    # random initialization happens in a fixed sequence.
    layers = collections.OrderedDict()
    layers['bn1'] = constant_init(nn.BatchNorm2d(c_in), 1)
    layers['relu'] = nn.ReLU(inplace=True)
    layers['conv1'] = kaiming(Conv1x1(c_in, hidden, bias=False))
    layers['bn2'] = constant_init(nn.BatchNorm2d(hidden), 1)
    layers['relu2'] = nn.ReLU(inplace=True)
    # The block's stride is carried by the middle 3x3 convolution.
    layers['conv2'] = kaiming(
        Conv3x3(hidden, hidden, stride=conv_stride, bias=False))
    layers['bn3'] = constant_init(nn.BatchNorm2d(hidden), 1)
    layers['relu3'] = nn.ReLU(inplace=True)
    # NOTE(review): presumably zero-initialized so the residual branch
    # starts out contributing nothing -- confirm against constant_init.
    layers['conv3'] = constant_init(Conv1x1(hidden, c_out), 0)

    self.branch = CondSeq(layers)
    return self