Example #1
    def __init__(self, C_in, C_out, kernel_size=3, stride=1, dilation=1, groups=1, slimmable=True, width_mult_list=[1.]):
        super(DwsBlock, self).__init__()
        # Depthwise separable convolution: 3x3 depthwise (conv1) then 1x1 pointwise (conv2); conv1 downsamples when stride != 1
        self.C_in = C_in
        self.C_out = C_out
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.slimmable = slimmable
        self.width_mult_list = width_mult_list
        assert stride in [1, 2]
        if self.stride == 2: self.dilation = 1
        self.ratio = (1., 1.)

        self.relu = nn.ReLU(inplace=True)

        if self.slimmable:
            self.conv1 = USConv2d(C_in, C_in, 3, stride, padding=dilation, dilation=dilation, groups=C_in, bias=False, width_mult_list=width_mult_list)
            self.bn1 = USBatchNorm2d(C_in, width_mult_list)  # the depthwise conv outputs C_in channels

            self.conv2 = USConv2d(C_in, C_out, 1, 1, padding=0, dilation=dilation, groups=groups, bias=False, width_mult_list=width_mult_list)
            self.bn2 = USBatchNorm2d(C_out, width_mult_list)
       
        else:
            self.conv1 = Conv2d(C_in, C_in, 3, stride, padding=dilation, dilation=dilation, groups=C_in, bias=False)
            self.bn1 = BatchNorm2d(C_in)  # the depthwise conv outputs C_in channels

            self.conv2 = Conv2d(C_in, C_out, 1, 1, padding=0, dilation=dilation, groups=groups, bias=False)
            self.bn2 = BatchNorm2d(C_out)
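
The forward pass is not shown; a minimal sketch of how these layers would typically compose in a depthwise separable block (an assumption from the layer names, not confirmed code):

    def forward(self, x):
        # sketch: 3x3 depthwise -> BN -> ReLU, then 1x1 pointwise -> BN -> ReLU
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        return out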
Example #2
    def __init__(self, C_in, C_out, kernel_size=3, stride=1, dilation=1, groups=1, slimmable=True, width_mult_list=[1.]):
        super(BasicResidual_downup_1x, self).__init__()
        # conv1 always runs at stride 1; self.downsample is a 1x1 projection defined only when stride == 1
        groups = 1  # grouped convolution is disabled in this variant
        self.C_in = C_in
        self.C_out = C_out
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.slimmable = slimmable
        self.width_mult_list = width_mult_list
        assert stride in [1, 2]
        if self.stride == 2: self.dilation = 1
        self.ratio = (1., 1.)

        self.relu = nn.ReLU(inplace=True)
        if slimmable:
            self.conv1 = USConv2d(C_in, C_out, 3, 1, padding=dilation, dilation=dilation, groups=groups, bias=False, width_mult_list=width_mult_list)
            self.bn1 = USBatchNorm2d(C_out, width_mult_list)
            if self.stride==1:
                self.downsample = nn.Sequential(
                    USConv2d(C_in, C_out, 1, 1, padding=0, dilation=dilation, groups=groups, bias=False, width_mult_list=width_mult_list),
                    USBatchNorm2d(C_out, width_mult_list)
                )
        else:
            self.conv1 = nn.Conv2d(C_in, C_out, 3, 1, padding=dilation, dilation=dilation, groups=groups, bias=False)
            self.bn1 = BatchNorm2d(C_out)
            if self.stride==1:
                self.downsample = nn.Sequential(
                    nn.Conv2d(C_in, C_out, 1, 1, padding=0, dilation=dilation, groups=groups, bias=False),
                    BatchNorm2d(C_out)
                )
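
Every block here sets self.ratio = (1., 1.) without using it in __init__; in slimmable-network codebases this is usually paired with a set_ratio method that pushes width ratios into the US* sub-layers. A hypothetical sketch for this block, assuming USConv2d.set_ratio and USBatchNorm2d.set_ratio exist with these semantics:

    def set_ratio(self, ratio):
        # hypothetical: ratio = (input_width_ratio, output_width_ratio)
        assert self.slimmable and len(ratio) == 2
        self.ratio = ratio
        self.conv1.set_ratio(ratio)   # conv sees (in, out) ratios
        self.bn1.set_ratio(ratio[1])  # BN tracks the output width
        if self.stride == 1:
            self.downsample[0].set_ratio(ratio)
            self.downsample[1].set_ratio(ratio[1])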
Example #3
    def __init__(self, C_in, C_out, stride=1, slimmable=True, width_mult_list=[1.]):
        super(FactorizedReduce, self).__init__()
        assert stride in [1, 2]
        assert C_out % 2 == 0
        self.C_in = C_in
        self.C_out = C_out
        self.stride = stride
        self.slimmable = slimmable
        self.width_mult_list = width_mult_list
        self.ratio = (1., 1.)
        if stride == 1 and slimmable:
            self.conv1 = USConv2d(C_in, C_out, 1, stride=1, padding=0, bias=False, width_mult_list=width_mult_list)
            self.bn = USBatchNorm2d(C_out, width_mult_list)
            self.relu = nn.ReLU(inplace=True)
        elif stride == 2:
            # two parallel 1x1 stride-2 convs, each producing C_out // 2 channels
            self.relu = nn.ReLU(inplace=True)
            if slimmable:
                self.conv1 = USConv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False, width_mult_list=width_mult_list)
                self.conv2 = USConv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False, width_mult_list=width_mult_list)
                self.bn = USBatchNorm2d(C_out, width_mult_list)
            else:
                self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
                self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
                self.bn = BatchNorm2d(C_out)
        # note: stride == 1 with slimmable=False defines no layers here
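
A sketch of the usual forward for this kind of factorized reduction (DARTS-style offset concatenation; an assumption, since the real forward is not shown, and torch is assumed imported):

    def forward(self, x):
        if self.stride == 1:
            if self.slimmable:
                return self.relu(self.bn(self.conv1(x)))
            return x  # no layers were defined for this case in __init__
        x = self.relu(x)
        # shift the second branch by one pixel so the two stride-2 samplings interleave
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        return self.bn(out)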
Example #4
    def __init__(self, C_in, C_out, kernel_size=3, stride=1, padding=None, dilation=1, groups=1, bias=False, slimmable=False, width_mult_list=[1.]):
        super(Conv, self).__init__()
        self.C_in = C_in
        self.C_out = C_out
        self.kernel_size = kernel_size
        assert stride in [1, 2]
        self.stride = stride
        if padding is None:
            # choose padding so that h_out = ceil(h_in / stride) ("same"-style padding)
            self.padding = int(np.ceil((dilation * (kernel_size - 1) + 1 - stride) / 2.))
        else:
            self.padding = padding
        self.dilation = dilation
        assert type(groups) == int
        if kernel_size == 1:
            self.groups = 1
        else:
            self.groups = groups
        self.bias = bias
        self.slimmable = slimmable
        self.width_mult_list = width_mult_list
        self.ratio = (1., 1.)

        if slimmable:
            self.conv = USConv2d(C_in, C_out, kernel_size, stride, padding=self.padding, dilation=dilation, groups=self.groups, bias=bias, width_mult_list=width_mult_list)

        else:
            self.conv = Conv2d(C_in, C_out, kernel_size, stride, padding=self.padding, dilation=dilation, groups=self.groups, bias=bias)
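
The default-padding rule above is the standard "same"-style formula: it picks the padding that yields h_out = ceil(h_in / stride) under ordinary convolution arithmetic. A self-contained check of a few common cases (the helper name is illustrative):

import numpy as np

def same_padding(kernel_size, stride=1, dilation=1):
    # identical to the expression in Conv.__init__ above
    return int(np.ceil((dilation * (kernel_size - 1) + 1 - stride) / 2.))

assert same_padding(3, 1) == 1  # 3x3, stride 1 -> pad 1
assert same_padding(3, 2) == 1  # 3x3, stride 2 -> pad 1
assert same_padding(7, 2) == 3  # 7x7, stride 2 -> pad 3 (cf. Conv7x7 in Example #7)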
Example #5
    def __init__(self, in_dim, slimmable=True, width_mult_list=[1.]):
        super(ATT, self).__init__()
        self.channel_in = in_dim

        self.slimmable = slimmable
        self.width_mult_list = width_mult_list
        self.ratio = (1., 1.)
        
        if self.slimmable:
            self.query_conv = USConv2d(in_dim, in_dim // 8, 1, padding=0, stride=1, bias=False, width_mult_list=width_mult_list)
            self.key_conv = USConv2d(in_dim, in_dim // 8, 1, padding=0, stride=1, bias=False, width_mult_list=width_mult_list)
            self.value_conv = USConv2d(in_dim, in_dim, 1, padding=0, stride=1, bias=False, width_mult_list=width_mult_list)
        else:
            # note: these keep the nn.Conv2d default bias=True, unlike the slimmable branch
            self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
            self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
            self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable gate on the attention output, initialized to zero

        self.softmax = nn.Softmax(dim=-1)
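
These layers mirror the SAGAN self-attention module, so the forward presumably follows the usual pattern; a sketch under that assumption (torch assumed imported):

    def forward(self, x):
        B, C, H, W = x.size()
        q = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1)  # B x N x C//8
        k = self.key_conv(x).view(B, -1, H * W)                     # B x C//8 x N
        attn = self.softmax(torch.bmm(q, k))                        # B x N x N
        v = self.value_conv(x).view(B, -1, H * W)                   # B x C x N
        out = torch.bmm(v, attn.permute(0, 2, 1)).view(B, C, H, W)
        return self.gamma * out + x  # gamma starts at zero, so this begins as identity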
Example #6
    def _make_layers(self, cfg):
        # note: this variant's USConv2d/USBatchNorm2d take no explicit width_mult_list, unlike the examples above
        layers = []
        in_channels = 3
        for order, x in enumerate(cfg):
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                if order == 0:
                    # head: us=[False, True] keeps the 3-channel input un-slimmed, slims only the output
                    layers += [USConv2d(in_channels, x, kernel_size=3, padding=1, us=[False, True]),
                               USBatchNorm2d(x),
                               nn.ReLU(inplace=True)]
                else:
                    # body: both input and output widths are slimmable
                    layers += [USConv2d(in_channels, x, kernel_size=3, padding=1),
                               USBatchNorm2d(x),
                               nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)
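
A usage sketch with a VGG-11-style configuration; the cfg format is inferred from the loop (ints are conv output widths, 'M' inserts a 2x2 max-pool), and the exact cfg values here are illustrative:

        # hypothetical cfg following the common VGG-11 layout
        cfg = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        self.features = self._make_layers(cfg)  # typically called from the model's __init__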
Example #7
    def __init__(self, C_in, C_out, kernel_size=7, stride=1, dilation=1, groups=1, slimmable=True, width_mult_list=[1.]):
        super(Conv7x7, self).__init__()
        # plain 7x7 convolution (no BN here); conv1 downsamples the input when stride != 1
        self.C_in = C_in
        self.C_out = C_out
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.slimmable = slimmable
        self.width_mult_list = width_mult_list
        assert stride in [1, 2]
        if self.stride == 2: self.dilation = 1
        self.ratio = (1., 1.)

        if slimmable:
            self.conv1 = USConv2d(C_in, C_out, 7, stride, padding=3, dilation=dilation, groups=groups, bias=False, width_mult_list=width_mult_list)
        else:
            self.conv1 = Conv2d(C_in, C_out, 7, stride, padding=3, dilation=dilation, groups=groups, bias=False)
Example #8
    def __init__(self, C_in, C_out, stride=1, slimmable=True, width_mult_list=[1.]):
        super(SkipConnect, self).__init__()
        assert stride in [1, 2]
        assert C_out % 2 == 0, 'C_out=%d'%C_out
        self.C_in = C_in
        self.C_out = C_out
        self.stride = stride
        self.slimmable = slimmable
        self.width_mult_list = width_mult_list
        self.ratio = (1., 1.)

        self.kernel_size = 1
        self.padding = 0

        if slimmable:
            self.conv = USConv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False, width_mult_list=width_mult_list)
            self.bn = USBatchNorm2d(C_out, width_mult_list)
            self.relu = nn.ReLU(inplace=True)

        else:
            self.conv = Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
            self.bn = BatchNorm2d(C_out)
            self.relu = nn.ReLU(inplace=True)
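
Both branches build the same 1x1 projection, so the forward presumably reduces to one line; a sketch (assumption):

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))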
Example #9
    def __init__(
        self,
        *,
        dim,
        fmap_size,
        dim_out,
        proj_factor,
        downsample,
        slimmable=True,
        width_mult_list=[1.],
        heads=4,
        dim_head=128,
        rel_pos_emb=False,
        activation=nn.ReLU(inplace=True)
    ):
        super().__init__()

        self.slimmable = slimmable
        self.width_mult_list = width_mult_list

        # shortcut branch
        if slimmable:
            kernel_size, stride, padding = (3, 2, 1) if downsample else (1, 1, 0)
            self.sk = False
            self.shortcut = nn.Sequential(
                USConv2d(dim, dim_out, kernel_size, padding=padding, stride=stride, dilation=1, groups=1, bias=False, width_mult_list=width_mult_list),
                USBatchNorm2d(dim_out, width_mult_list),
                activation
            )
        else:
            if dim != dim_out or downsample:
                self.sk = False
                kernel_size, stride, padding = (3, 2, 1) if downsample else (1, 1, 0)

                self.shortcut = nn.Sequential(
                    nn.Conv2d(dim, dim_out, kernel_size, stride=stride, padding=padding, bias=False),
                    BatchNorm2d(dim_out),
                    activation
                )
            else:
                self.sk = True
                self.shortcut = nn.Identity()

        self.mix_bn1 = nn.ModuleList([])
        self.mix_bn2 = nn.ModuleList([])
        self.mix_bn3 = nn.ModuleList([])

        # proj_factor, heads, dim_head, rel_pos_emb and fmap_size are accepted
        # but unused in this variant; attention runs at the full output width
        attn_dim_in = dim_out
        attn_dim_out = attn_dim_in

        # the last BN's gamma is zero-initialized so the residual branch starts disabled
        if self.slimmable:
            self.mix_bn1.append(USBatchNorm2d(dim_out, width_mult_list))
            self.mix_bn2.append(USBatchNorm2d(dim_out, width_mult_list))
            self.mix_bn3.append(USBatchNorm2d(dim_out, width_mult_list))
            nn.init.zeros_(self.mix_bn3[0].weight)
        else:
            self.mix_bn1.append(BatchNorm2d(dim_out))
            self.mix_bn2.append(BatchNorm2d(dim_out))
            self.mix_bn3.append(BatchNorm2d(dim_out))
            nn.init.zeros_(self.mix_bn3[0].weight)

        if self.slimmable:
            self.net1 = USConv2d(dim, attn_dim_in, 1, padding=0, stride=1, dilation=1, groups=1, bias=False, width_mult_list=width_mult_list)

            self.net2 = nn.Sequential(
                activation,
                ATT(attn_dim_in, slimmable=True, width_mult_list=width_mult_list),
                nn.AvgPool2d((2, 2)) if downsample else nn.Identity()
            )

            self.net3 = nn.Sequential(
                activation,
                USConv2d(attn_dim_out, dim_out, 1, padding=0, stride=1, dilation=1, groups=1, bias=False, width_mult_list=width_mult_list),
            )

        else:
            self.net1 = nn.Conv2d(dim, attn_dim_in, 1, bias=False)

            self.net2 = nn.Sequential(
                activation,
                ATT(attn_dim_in, slimmable=False),
                nn.AvgPool2d((2, 2)) if downsample else nn.Identity()
            )

            self.net3 = nn.Sequential(
                activation,
                nn.Conv2d(attn_dim_out, dim_out, 1, bias=False),
            )

        # final activation
        self.activation = activation
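
A hypothetical sketch of how these pieces compose at width index 0, in the style of a BoTNet residual block; the real forward, and how the mix_bn lists are indexed, is not shown:

    def forward(self, x):
        shortcut = self.shortcut(x)            # identity when self.sk is True
        out = self.mix_bn1[0](self.net1(x))    # 1x1 in-projection + BN
        out = self.mix_bn2[0](self.net2(out))  # attention (+ optional avg-pool) + BN
        out = self.mix_bn3[0](self.net3(out))  # 1x1 out-projection + zero-init BN
        return self.activation(out + shortcut)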