def __init__(self, C_in, C_out, kernel_size=3, stride=1, dilation=1, groups=1, slimmable=True, width_mult_list=None):
    """Depthwise-separable conv block: 3x3 depthwise conv + BN, then 1x1 pointwise conv + BN.

    Args:
        C_in: input channel count.
        C_out: output channel count.
        kernel_size: stored for bookkeeping only; the depthwise conv is hard-coded to 3x3.
        stride: 1 or 2; stride 2 forces ``dilation`` back to 1.
        dilation: dilation (and padding) of the 3x3 depthwise conv.
        groups: group count of the 1x1 pointwise conv.
        slimmable: if True, use the universally-slimmable conv/BN variants.
        width_mult_list: candidate width multipliers for the slimmable ops
            (defaults to ``[1.]``; ``None`` avoids the shared mutable-default pitfall).
    """
    super(DwsBlock, self).__init__()
    if width_mult_list is None:
        width_mult_list = [1.]
    self.C_in = C_in
    self.C_out = C_out
    self.kernel_size = kernel_size
    self.stride = stride
    self.dilation = dilation
    self.groups = groups
    self.slimmable = slimmable
    self.width_mult_list = width_mult_list
    assert stride in [1, 2]
    if self.stride == 2:
        # Dilation is meaningless when downsampling here; reset it.
        self.dilation = 1
    self.ratio = (1., 1.)
    self.relu = nn.ReLU(inplace=True)
    if self.slimmable:
        self.conv1 = USConv2d(C_in, C_in, 3, stride, padding=dilation, dilation=dilation, groups=C_in, bias=False, width_mult_list=width_mult_list)
        # BUG FIX: conv1 is depthwise (C_in -> C_in), so bn1 must normalize C_in
        # channels. The original used C_out, which fails at forward time whenever
        # C_in != C_out; behavior is unchanged when C_in == C_out.
        self.bn1 = USBatchNorm2d(C_in, width_mult_list)
        self.conv2 = USConv2d(C_in, C_out, 1, 1, padding=0, dilation=dilation, groups=groups, bias=False, width_mult_list=width_mult_list)
        self.bn2 = USBatchNorm2d(C_out, width_mult_list)
    else:
        self.conv1 = Conv2d(C_in, C_in, 3, stride, padding=dilation, dilation=dilation, groups=C_in, bias=False)
        # Same channel fix as the slimmable branch.
        self.bn1 = BatchNorm2d(C_in)
        self.conv2 = Conv2d(C_in, C_out, 1, 1, padding=0, dilation=dilation, groups=groups, bias=False)
        self.bn2 = BatchNorm2d(C_out)
def __init__(self, C_in, C_out, kernel_size=3, stride=1, dilation=1, groups=1, slimmable=True, width_mult_list=[1.]):
    """Single 3x3 conv block with an optional 1x1 projection shortcut.

    The shortcut (``self.downsample``) is only built for stride 1; grouped
    convolution is deliberately disabled for this block.
    """
    super(BasicResidual_downup_1x, self).__init__()
    groups = 1  # force group count to 1 regardless of the argument
    self.C_in = C_in
    self.C_out = C_out
    self.kernel_size = kernel_size
    self.stride = stride
    self.dilation = dilation
    self.groups = groups
    self.slimmable = slimmable
    self.width_mult_list = width_mult_list
    assert stride in [1, 2]
    if self.stride == 2:
        self.dilation = 1
    self.ratio = (1., 1.)
    self.relu = nn.ReLU(inplace=True)
    if slimmable:
        self.conv1 = USConv2d(C_in, C_out, 3, 1, padding=dilation, dilation=dilation, groups=groups, bias=False, width_mult_list=width_mult_list)
        self.bn1 = USBatchNorm2d(C_out, width_mult_list)
        if self.stride == 1:
            # 1x1 projection shortcut; the stride-2 case defines no shortcut
            # here — presumably handled in forward() (not visible in this chunk).
            self.downsample = nn.Sequential(
                USConv2d(C_in, C_out, 1, 1, padding=0, dilation=dilation, groups=groups, bias=False, width_mult_list=width_mult_list),
                USBatchNorm2d(C_out, width_mult_list),
            )
    else:
        self.conv1 = nn.Conv2d(C_in, C_out, 3, 1, padding=dilation, dilation=dilation, groups=groups, bias=False)
        self.bn1 = BatchNorm2d(C_out)
        if self.stride == 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(C_in, C_out, 1, 1, padding=0, dilation=dilation, groups=groups, bias=False),
                BatchNorm2d(C_out),
            )
def __init__(self, C_in, C_out, stride=1, slimmable=True, width_mult_list=[1.]):
    """Factorized reduction.

    stride 1 + slimmable: a plain 1x1 conv + BN.
    stride 2: two parallel stride-2 1x1 convs of C_out // 2 channels each
    (their outputs are presumably concatenated in forward(), hence the
    even-C_out requirement).
    """
    super(FactorizedReduce, self).__init__()
    assert stride in [1, 2]
    assert C_out % 2 == 0
    self.C_in = C_in
    self.C_out = C_out
    self.stride = stride
    self.slimmable = slimmable
    self.width_mult_list = width_mult_list
    self.ratio = (1., 1.)
    if stride == 1 and slimmable:
        self.conv1 = USConv2d(C_in, C_out, 1, stride=1, padding=0, bias=False, width_mult_list=width_mult_list)
        self.bn = USBatchNorm2d(C_out, width_mult_list)
        self.relu = nn.ReLU(inplace=True)
    elif stride == 2:
        self.relu = nn.ReLU(inplace=True)
        if slimmable:
            self.conv1 = USConv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False, width_mult_list=width_mult_list)
            self.conv2 = USConv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False, width_mult_list=width_mult_list)
            self.bn = USBatchNorm2d(C_out, width_mult_list)
        else:
            self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
            self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
            self.bn = BatchNorm2d(C_out)
    # NOTE(review): stride == 1 with slimmable == False builds no layers at
    # all — forward() presumably treats that case as an identity; confirm there.
def __init__(self, C_in, C_out, kernel_size=3, stride=1, padding=None, dilation=1, groups=1, bias=False, slimmable=False, width_mult_list=[1.]):
    """Thin wrapper around a single conv, optionally slimmable.

    When ``padding`` is None it is derived so that the spatial size follows
    h_out = h_in / stride ("same"-style padding for the given dilation).
    """
    super(Conv, self).__init__()
    self.C_in = C_in
    self.C_out = C_out
    self.kernel_size = kernel_size
    assert stride in [1, 2]
    self.stride = stride
    if padding is None:
        # effective kernel extent is dilation * (k - 1) + 1; pad to keep h_out = h_in / s
        self.padding = int(np.ceil((dilation * (kernel_size - 1) + 1 - stride) / 2.))
    else:
        self.padding = padding
    self.dilation = dilation
    assert type(groups) == int
    # A 1x1 conv never uses grouping here.
    self.groups = 1 if kernel_size == 1 else groups
    self.bias = bias
    self.slimmable = slimmable
    self.width_mult_list = width_mult_list
    self.ratio = (1., 1.)
    if slimmable:
        self.conv = USConv2d(C_in, C_out, kernel_size, stride, padding=self.padding, dilation=dilation, groups=self.groups, bias=bias, width_mult_list=width_mult_list)
    else:
        self.conv = Conv2d(C_in, C_out, kernel_size, stride, padding=self.padding, dilation=dilation, groups=self.groups, bias=bias)
def __init__(self, in_dim, slimmable=True, width_mult_list=[1.]):
    """Self-attention module: 1x1 query/key projections to in_dim // 8 channels,
    a full-width 1x1 value projection, a learnable residual scale ``gamma``
    (initialized to zero) and a last-dim softmax.
    """
    super(ATT, self).__init__()
    # NOTE(review): attribute name keeps the original's "chanel_in" spelling —
    # it is part of the public surface, so the typo is preserved on purpose.
    self.chanel_in = in_dim
    self.slimmable = slimmable
    self.width_mult_list = width_mult_list
    self.ratio = (1., 1.)
    if self.slimmable:
        self.query_conv = USConv2d(in_dim, in_dim // 8, 1, padding=0, stride=1, bias=False, width_mult_list=width_mult_list)
        self.key_conv = USConv2d(in_dim, in_dim // 8, 1, padding=0, stride=1, bias=False, width_mult_list=width_mult_list)
        self.value_conv = USConv2d(in_dim, in_dim, 1, padding=0, stride=1, bias=False, width_mult_list=width_mult_list)
    else:
        # Note: these keep nn.Conv2d's default bias=True, unlike the
        # bias-free slimmable branch — reproduced as in the original.
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
    self.gamma = nn.Parameter(torch.zeros(1))
    self.softmax = nn.Softmax(dim=-1)
def _make_layers(self, cfg):
    """Build a VGG-style stack from ``cfg``: 'M' inserts a 2x2 max-pool, an
    int adds a 3x3 conv + US-BN + ReLU with that output width.
    """
    layers = []
    in_channels = 3
    for idx, spec in enumerate(cfg):
        if spec == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        if idx == 0:
            # Head conv: the 3-channel RGB input width is not slimmable,
            # only the output width is (us=[False, True]).
            conv = USConv2d(in_channels, spec, kernel_size=3, padding=1, us=[False, True])
        else:
            conv = USConv2d(in_channels, spec, kernel_size=3, padding=1)
        layers.extend([conv, USBatchNorm2d(spec), nn.ReLU(inplace=True)])
        in_channels = spec
    # 1x1 kernel with stride 1 is an identity pooling; kept for parity with
    # the original architecture definition.
    layers.append(nn.AvgPool2d(kernel_size=1, stride=1))
    return nn.Sequential(*layers)
def __init__(self, C_in, C_out, kernel_size=7, stride=1, dilation=1, groups=1, slimmable=True, width_mult_list=[1.]):
    """Stem-style 7x7 conv (padding fixed at 3), optionally slimmable.

    ``kernel_size`` is stored for bookkeeping but the conv itself is
    hard-coded to 7x7; stride 2 resets ``dilation`` to 1.
    """
    super(Conv7x7, self).__init__()
    self.C_in = C_in
    self.C_out = C_out
    self.kernel_size = kernel_size
    self.stride = stride
    self.dilation = dilation
    self.groups = groups
    self.slimmable = slimmable
    self.width_mult_list = width_mult_list
    assert stride in [1, 2]
    if self.stride == 2:
        self.dilation = 1
    self.ratio = (1., 1.)
    if slimmable:
        self.conv1 = USConv2d(C_in, C_out, 7, stride, padding=3, dilation=dilation, groups=groups, bias=False, width_mult_list=width_mult_list)
    else:
        self.conv1 = Conv2d(C_in, C_out, 7, stride, padding=3, dilation=dilation, groups=groups, bias=False)
def __init__(self, C_in, C_out, stride=1, slimmable=True, width_mult_list=[1.]):
    """Skip connection realized as a 1x1 conv + BN + ReLU (handles channel
    and/or stride changes between the two endpoints)."""
    super(SkipConnect, self).__init__()
    assert stride in [1, 2]
    assert C_out % 2 == 0, 'C_out=%d'%C_out
    self.C_in = C_in
    self.C_out = C_out
    self.stride = stride
    self.slimmable = slimmable
    self.width_mult_list = width_mult_list
    self.ratio = (1., 1.)
    self.kernel_size = 1
    self.padding = 0
    if slimmable:
        self.conv = USConv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False, width_mult_list=width_mult_list)
        self.bn = USBatchNorm2d(C_out, width_mult_list)
        self.relu = nn.ReLU(inplace=True)
    else:
        # The projection is built unconditionally here (an earlier revision
        # gated it on stride == 2 or C_in != C_out).
        self.conv = Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
        self.bn = BatchNorm2d(C_out)
        self.relu = nn.ReLU(inplace=True)
def __init__(
    self,
    *,
    dim,
    fmap_size,
    dim_out,
    proj_factor,
    downsample,
    slimmable=True,
    width_mult_list=[1.],
    heads=4,
    dim_head=128,
    rel_pos_emb=False,
    activation=nn.ReLU(inplace=True)
):
    """Bottleneck-style attention block: a conv shortcut plus a three-stage
    main path (1x1 conv -> self-attention -> 1x1 conv), each stage followed
    by one of the ``mix_bn*`` batch norms; the last BN's gamma is zeroed so
    the block starts as (near) identity on the main path.

    ``fmap_size``, ``proj_factor``, ``heads``, ``dim_head`` and
    ``rel_pos_emb`` are accepted for interface compatibility but unused in
    this constructor (attn widths are fixed to ``dim_out``).
    """
    super().__init__()

    # --- shortcut branch -------------------------------------------------
    self.slimmable = slimmable
    self.width_mult_list = width_mult_list
    if slimmable:
        # Slimmable mode always projects (never a bare identity shortcut).
        kernel_size, stride, padding = (3, 2, 1) if downsample else (1, 1, 0)
        self.sk = False
        self.shortcut = nn.Sequential(
            USConv2d(dim, dim_out, kernel_size, padding=padding, stride=stride, dilation=1, groups=1, bias=False, width_mult_list=width_mult_list),
            USBatchNorm2d(dim_out, width_mult_list),
            activation
        )
    elif dim != dim_out or downsample:
        self.sk = False
        kernel_size, stride, padding = (3, 2, 1) if downsample else (1, 1, 0)
        self.shortcut = nn.Sequential(
            nn.Conv2d(dim, dim_out, kernel_size, stride=stride, padding=padding, bias=False),
            BatchNorm2d(dim_out),
            activation
        )
    else:
        # Same shape in and out: plain skip.
        self.sk = True
        self.shortcut = nn.Identity()

    # --- per-stage batch norms ------------------------------------------
    self.mix_bn1 = nn.ModuleList([])
    self.mix_bn2 = nn.ModuleList([])
    self.mix_bn3 = nn.ModuleList([])
    # Attention operates at the full output width (no proj_factor contraction).
    attn_dim_in = dim_out
    attn_dim_out = attn_dim_in
    if self.slimmable:
        self.mix_bn1.append(USBatchNorm2d(dim_out, width_mult_list))
        self.mix_bn2.append(USBatchNorm2d(dim_out, width_mult_list))
        self.mix_bn3.append(USBatchNorm2d(dim_out, width_mult_list))
    else:
        self.mix_bn1.append(BatchNorm2d(dim_out))
        self.mix_bn2.append(BatchNorm2d(dim_out))
        self.mix_bn3.append(BatchNorm2d(dim_out))
    # Zero-init the final BN gamma so the main path starts suppressed.
    nn.init.zeros_(self.mix_bn3[0].weight)

    # --- main path -------------------------------------------------------
    if self.slimmable:
        self.net1 = USConv2d(dim, attn_dim_in, 1, padding=0, stride=1, dilation=1, groups=1, bias=False, width_mult_list=width_mult_list)
        self.net2 = nn.Sequential(
            activation,
            ATT(attn_dim_in, slimmable=True, width_mult_list=width_mult_list),
            nn.AvgPool2d((2, 2)) if downsample else nn.Identity()
        )
        self.net3 = nn.Sequential(
            activation,
            USConv2d(attn_dim_out, dim_out, 1, padding=0, stride=1, dilation=1, groups=1, bias=False, width_mult_list=width_mult_list),
        )
    else:
        self.net1 = nn.Conv2d(dim, attn_dim_in, 1, bias=False)
        self.net2 = nn.Sequential(
            activation,
            ATT(attn_dim_in, slimmable=False),
            nn.AvgPool2d((2, 2)) if downsample else nn.Identity()
        )
        self.net3 = nn.Sequential(
            activation,
            nn.Conv2d(attn_dim_out, dim_out, 1, bias=False),
        )

    # final activation (note: the default is a module instance shared by all
    # blocks built with the default argument — harmless for stateless ReLU)
    self.activation = activation