def __init__(self, in_features: int, out_features: int, activation: nn.Module = ReLUInPlace, expansion: int = 4, **kwargs):
    super().__init__(in_features, out_features, activation, **kwargs)
    self.expansion = expansion
    self.expanded_features = out_features * self.expansion
    # pre-activation bottleneck: BN -> act -> 1x1 conv expands the channels,
    # then BN -> act -> 3x3 conv projects back down to out_features
    self.block = nn.Sequential(
        OrderedDict(
            {
                "bn1": nn.BatchNorm2d(in_features),
                "act1": activation(),
                "conv1": Conv2dPad(in_features, self.expanded_features, kernel_size=1, bias=False, **kwargs),
                "bn2": nn.BatchNorm2d(self.expanded_features),
                "act2": activation(),
                "conv2": Conv2dPad(self.expanded_features, out_features, kernel_size=3, bias=False, **kwargs),
            }
        )
    )
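# For reference, the channel bookkeeping of the bottleneck above can be checked with
# plain torch modules. This is an illustrative sketch, not library code: nn.Conv2d
# with padding=1 stands in for Conv2dPad's default "same" padding, and the numbers
# are arbitrary.
import torch
from torch import nn
from collections import OrderedDict

in_features, out_features, expansion = 64, 32, 4
expanded = out_features * expansion
bottleneck = nn.Sequential(OrderedDict({
    "bn1": nn.BatchNorm2d(in_features),
    "act1": nn.ReLU(inplace=True),
    "conv1": nn.Conv2d(in_features, expanded, kernel_size=1, bias=False),
    "bn2": nn.BatchNorm2d(expanded),
    "act2": nn.ReLU(inplace=True),
    "conv2": nn.Conv2d(expanded, out_features, kernel_size=3, padding=1, bias=False),
}))

x = torch.rand(1, in_features, 7, 7)
assert bottleneck(x).shape == (1, out_features, 7, 7)  # channels change, H and W do not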
def test_Conv2dPad():
    x = torch.rand((1, 1, 5, 5))
    # default mode applies "same"-style padding, so the spatial size is preserved
    block = Conv2dPad(1, 5, kernel_size=3)
    res = block(x)
    assert x.shape[-1] == res.shape[-1]
    assert x.shape[-2] == res.shape[-2]
    # no padding: a 3x3 kernel on a 5x5 input yields a 3x3 output
    block = Conv2dPad(1, 5, kernel_size=3, mode=None)
    res = block(x)
    assert x.shape[-1] != res.shape[-1]
    assert x.shape[-2] != res.shape[-2]
    assert res.shape[-1] == 3
    assert res.shape[-2] == 3
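# A minimal sketch of the behaviour the test above exercises. This is an assumption
# about Conv2dPad, not its actual implementation: it appears to act like nn.Conv2d,
# with the default mode padding by kernel_size // 2 ("same" output size for odd
# kernels at stride 1) and mode=None falling back to a plain unpadded convolution.
import torch
from torch import nn

class SameConv2d(nn.Conv2d):  # hypothetical stand-in for Conv2dPad
    def __init__(self, in_channels, out_channels, kernel_size, mode="auto", **kwargs):
        padding = kernel_size // 2 if mode is not None else 0
        super().__init__(in_channels, out_channels, kernel_size, padding=padding, **kwargs)

x = torch.rand(1, 1, 5, 5)
assert SameConv2d(1, 5, kernel_size=3)(x).shape[-2:] == (5, 5)
assert SameConv2d(1, 5, kernel_size=3, mode=None)(x).shape[-2:] == (3, 3)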
def __init__(self, in_features: int, out_features: int, stride: int = 2):
    super().__init__(
        OrderedDict(
            {
                'pool': nn.AvgPool2d((2, 2)) if stride == 2 else nn.Identity(),
                'conv': Conv2dPad(in_features, out_features, kernel_size=1, bias=False),
                'bn': nn.BatchNorm2d(out_features),
            }
        )
    )
def __init__(self, in_features: int, out_features: int, activation: nn.Module = ReLUInPlace, stride: int = 1, shortcut: nn.Module = ResNetShorcut, **kwargs):
    super().__init__(in_features, out_features, activation, stride=stride, shortcut=shortcut, **kwargs)
    # pre-activation ordering: BN -> act -> conv, twice
    self.block = nn.Sequential(
        OrderedDict(
            {
                'bn1': nn.BatchNorm2d(in_features),
                'act1': activation(),
                'conv1': Conv2dPad(in_features, out_features, kernel_size=3, bias=False, stride=stride, **kwargs),
                'bn2': nn.BatchNorm2d(out_features),
                'act2': activation(),
                'conv2': Conv2dPad(out_features, out_features, kernel_size=3, bias=False),
            }
        )
    )
    # in the pre-activation variant no activation follows the residual sum
    self.act = nn.Identity()
def __init__(self, in_features: int, out_features: int, stride: int = 2):
    super().__init__()
    self.conv = Conv2dPad(in_features, out_features, kernel_size=1, stride=stride, bias=False)
    self.bn = nn.BatchNorm2d(out_features)
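# The forward pass of this module-style shortcut is not shown here; given the
# attributes defined above it presumably just chains the two layers, along the
# lines of the sketch below (an assumption, not the library's actual code):
#
#     def forward(self, x):
#         return self.bn(self.conv(x))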
def __init__(self, in_features: int, out_features: int, stride: int = 2):
    super().__init__(
        OrderedDict(
            {
                "pool": nn.AvgPool2d((2, 2), ceil_mode=True) if stride == 2 else nn.Identity(),
                "conv": Conv2dPad(in_features, out_features, kernel_size=1, bias=False),
                "bn": nn.BatchNorm2d(out_features),
            }
        )
    )
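# The only difference from the earlier pooled shortcut is ceil_mode=True, which
# matters for odd spatial sizes: with a 7x7 input, floor mode (the default) pools
# to 3x3 and drops the last row/column, while ceil mode pools to 4x4 and keeps them.
import torch
from torch import nn

x = torch.rand(1, 1, 7, 7)
assert nn.AvgPool2d((2, 2))(x).shape[-2:] == (3, 3)
assert nn.AvgPool2d((2, 2), ceil_mode=True)(x).shape[-2:] == (4, 4)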
def __init__(
    self,
    in_features: int,
    out_features: int,
    activation: nn.Module = ReLUInPlace,
    stride: int = 1,
    shortcut: nn.Module = ResNetShorcut,
    **kwargs,
):
    super().__init__()
    # a projection shortcut is needed whenever the residual branch changes shape
    self.should_apply_shortcut = in_features != out_features or stride != 1
    self.block = nn.Sequential(
        OrderedDict(
            {
                "conv1": Conv2dPad(in_features, out_features, kernel_size=3, stride=stride, bias=False, **kwargs),
                "bn1": nn.BatchNorm2d(out_features),
                "act1": activation(),
                "conv2": Conv2dPad(out_features, out_features, kernel_size=3, bias=False),
                "bn2": nn.BatchNorm2d(out_features),
            }
        )
    )
    self.shortcut = (
        shortcut(in_features, out_features, stride=stride)
        if self.should_apply_shortcut
        else nn.Identity()
    )
    self.act = activation()
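# The forward pass is not shown here, but given the attributes defined above it is
# presumably the usual residual pattern (a sketch, not the library's actual code):
#
#     def forward(self, x):
#         res = self.shortcut(x)   # identity, or a projection when shapes change
#         x = self.block(x)
#         x += res
#         return self.act(x)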
def __init__(self, in_features: int, factor: int = 2, activation: nn.Module = ReLUInPlace):
    super().__init__(
        OrderedDict(
            {
                "bn": nn.BatchNorm2d(in_features),
                "act": activation(),
                "conv": Conv2dPad(in_features, in_features // factor, kernel_size=1, bias=False),
                "pool": nn.AvgPool2d(kernel_size=2, stride=2),
            }
        )
    )
def __init__(self, in_features: int, factor: int = 2, activation: nn.Module = ReLUInPlace):
    super().__init__()
    # transition: reduce the channels by `factor` with a 1x1 conv,
    # then halve H and W with average pooling
    self.block = nn.Sequential(
        OrderedDict(
            {
                'bn': nn.BatchNorm2d(in_features),
                'act': activation(),
                'conv': Conv2dPad(in_features, in_features // factor, kernel_size=1, bias=False),
                'pool': nn.AvgPool2d(kernel_size=2, stride=2),
            }
        )
    )
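# Quick shape check of the transition above, rebuilt with plain torch modules; the
# numbers are illustrative. With factor=2, the 1x1 conv halves the channels and the
# 2x2 average pool halves the spatial dimensions.
import torch
from torch import nn

in_features = 64
transition = nn.Sequential(
    nn.BatchNorm2d(in_features),
    nn.ReLU(inplace=True),
    nn.Conv2d(in_features, in_features // 2, kernel_size=1, bias=False),
    nn.AvgPool2d(kernel_size=2, stride=2),
)
x = torch.rand(1, in_features, 8, 8)
assert transition(x).shape == (1, 32, 4, 4)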
def __init__(self, in_features: int, out_features: int, activation: nn.Module = ReLUInPlace, *args, **kwargs):
    super().__init__()
    self.block = nn.Sequential(
        OrderedDict(
            {
                "bn": nn.BatchNorm2d(in_features),
                "act": activation(),
                "conv": Conv2dPad(in_features, out_features, kernel_size=3, *args, **kwargs),
            }
        )
    )