Ejemplo n.º 1
0
 def __init__(self, args, conv3x3=common.default_conv):
     """Create three ``common.BasicBlock`` stages and a two-layer classifier.

     Args:
         args: Option object. ``args.data_train`` selects the number of
             output classes (10 when it equals ``'CIFAR10'``, otherwise
             100); the object itself is forwarded to every block.
         conv3x3: Convolution factory forwarded to every block.
     """
     super(CFQKBN, self).__init__()
     if args.data_train == 'CIFAR10':
         n_classes = 10
     else:
         n_classes = 100

     # Keyword settings shared by all three blocks.
     block_opts = dict(stride=1, bias=False, conv3x3=conv3x3, args=args)
     self.conv1 = common.BasicBlock(3, 32, 5, **block_opts)
     self.conv2 = common.BasicBlock(32, 32, 5, **block_opts)
     self.conv3 = common.BasicBlock(32, 64, 5, **block_opts)

     # The first linear layer takes 3 * 3 * 64 input features.
     self.fc1 = nn.Linear(in_features=3 * 3 * 64, out_features=64)
     self.relu = nn.ReLU(inplace=True)
     self.fc2 = nn.Linear(in_features=64, out_features=n_classes)
 def __init__(self,
              channel=128,
              reduction=2,
              ksize=3,
              scale=3,
              stride=1,
              softmax_scale=10,
              average=True,
              conv=common.default_conv):
     """Create the two matching branches and the assembly branch.

     Args:
         channel: Input channel count of every branch; also the output
             channel count of the assembly branch.
         reduction: Divisor applied to ``channel`` to get the output
             channel count of the two matching branches.
         ksize: Accepted but not read in this constructor.
         scale: Accepted but not read in this constructor.
         stride: Accepted but not read in this constructor.
         softmax_scale: Accepted but not read in this constructor.
         average: Accepted but not read in this constructor.
         conv: Convolution factory handed to ``common.BasicBlock``.
     """
     super(NonLocalAttention, self).__init__()

     def unit_block(out_channels):
         # Size argument 1, bn=False, and a fresh PReLU for each block.
         return common.BasicBlock(conv,
                                  channel,
                                  out_channels,
                                  1,
                                  bn=False,
                                  act=nn.PReLU())

     self.conv_match1 = unit_block(channel // reduction)
     self.conv_match2 = unit_block(channel // reduction)
     self.conv_assembly = unit_block(channel)
Ejemplo n.º 3
0
    def __init__(self, args, gan_type='GAN'):
        """Build the feature extractor and the linear classifier.

        Args:
            args: Option object; ``args.n_colors`` is the input channel
                count of the first block and ``args.patch_size`` is used to
                size the first linear layer.
            gan_type: Accepted but not read here; ``bn`` is always passed
                as ``True``.
        """
        super(Discriminator, self).__init__()

        depth = 3
        width = 64
        use_bn = True
        # One activation module instance is shared by every block and by
        # the classifier.
        leaky = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        layers = [
            common.BasicBlock(args.n_colors, width, 3, bn=use_bn, act=leaky)
        ]
        for idx in range(depth):
            prev_width = width
            if idx % 2 == 1:
                # Odd steps keep the stride at 1 and double the width.
                step = 1
                width *= 2
            else:
                # Even steps use stride 2 and keep the width.
                step = 2
            layers.append(
                common.BasicBlock(prev_width,
                                  width,
                                  3,
                                  stride=step,
                                  bn=use_bn,
                                  act=leaky))
        self.features = nn.Sequential(*layers)

        # The patch side is divided by 2 once per stride-2 step above.
        side = args.patch_size // (2**((depth + 1) // 2))
        self.classifier = nn.Sequential(
            nn.Linear(width * side**2, 1024),
            leaky,
            nn.Linear(1024, 1),
        )
Ejemplo n.º 4
0
    def __init__(self, conv=common.default_conv, **kwargs,):
        """Build the mean shifts, head, pre-process stages, body, upsampler and output.

        Args:
            conv: Convolution factory passed to every ``common`` block and
                called directly for the last layer of the body.
            **kwargs: Must contain ``input_channels``, ``num_ms_resblocks``,
                ``intermediate_channels``, ``default_kernel_size`` and
                ``rgb_range``.

        Raises:
            KeyError: If one of the keyword arguments listed above is missing.
        """
        super(MSMNetModel, self).__init__()
        # -------------- Define multi-scale model architecture here ------------
        # scale = 3
        input_channles = kwargs['input_channels']
        num_resblocks = kwargs['num_ms_resblocks']
        intermediate_channels = kwargs['intermediate_channels']
        kernel_size = kwargs['default_kernel_size']
        # A single ReLU instance is shared by every block that receives `activation`.
        activation = nn.ReLU(True)
        rgb_range = kwargs['rgb_range']
        self.sub_mean = common.MeanShift(rgb_range)
        self.add_mean = common.MeanShift(rgb_range, sign=1)

        # head to read scaled image
        _head = common.BasicBlock(conv, input_channles, intermediate_channels, kernel_size)

        # pre-process: three stages of two blocks each. The first two stages
        # start with a PreResBlock that takes 2*intermediate_channels inputs;
        # the third stage uses two plain ResBlocks.
        self.pre_process = nn.ModuleList([
            nn.Sequential(
                common.PreResBlock(conv, 2*intermediate_channels, intermediate_channels, 5, bn=True, act=activation),
                common.ResBlock(conv, intermediate_channels, 5, bn=True, act=activation)
            ),
            nn.Sequential(
                common.PreResBlock(conv, 2*intermediate_channels, intermediate_channels, 5, bn=True, act=activation),
                common.ResBlock(conv, intermediate_channels, 5, bn=True, act=activation)
            ),
            nn.Sequential(
                common.ResBlock(conv, intermediate_channels, 5, bn=True, act=activation),
                common.ResBlock(conv, intermediate_channels, 5, bn=True, act=activation)
            )
        ])

        # body: num_resblocks ResBlocks followed by one plain conv
        _body = [
            common.ResBlock(
                conv, intermediate_channels, kernel_size, bn=True, act=activation
            ) for _ in range(num_resblocks)
        ]
        _body.append(conv(intermediate_channels, intermediate_channels, kernel_size))

        # upsample to enlarge the scale
        self.upsample = common.Upsampler(conv, intermediate_channels, bn=False, act=False)

        _output = nn.ModuleList([
            # conv(3, 1, 3),
            common.BasicBlock(conv, intermediate_channels, intermediate_channels, kernel_size, act=nn.Sigmoid())
        ])


        # NOTE(review): `*_head` unpacks a single BasicBlock, so this relies on
        # common.BasicBlock being iterable over its layers (e.g. an
        # nn.Sequential subclass) -- TODO confirm.
        self.head = nn.Sequential(*_head)
        self.body = nn.Sequential(*_body)
        self.output = nn.Sequential(*_output)
Ejemplo n.º 5
0
 def __init__(self, conv=common.default_conv, **kwargs):
     """Build the fusion network as one ``nn.Sequential``.

     The sequence is: a block taking ``2 * intermediate_channels`` inputs,
     a stack of ``num_fusion_resblocks`` residual blocks, and a block with
     one output channel and a sigmoid activation.

     Args:
         conv: Convolution factory passed to every ``common`` block.
         **kwargs: Must contain ``num_fusion_resblocks`` and
             ``intermediate_channels``.

     Raises:
         KeyError: If one of the two keyword arguments is missing.
     """
     super(FusionModel, self).__init__()
     num_resblocks = kwargs['num_fusion_resblocks']
     width = kwargs['intermediate_channels']

     # The residual blocks are created before the entry and exit blocks.
     trunk = []
     for _ in range(num_resblocks):
         trunk.append(common.ResBlock(conv, width, 3, bn=True))

     entry = common.BasicBlock(conv, 2 * width, width, 3)
     middle = nn.Sequential(*trunk)
     gate = common.BasicBlock(conv, width, 1, 3, act=nn.Sigmoid())
     self.fusion_model = nn.Sequential(entry, middle, gate)
Ejemplo n.º 6
0
    def __init__(self, args, conv=common.default_conv):
        """Build the mean shifts, two-block head, recurrent body and conv tail.

        Args:
            args: Option object providing ``n_feats``, ``depth``, ``scale``,
                ``rgb_range`` and ``n_colors``.
            conv: Convolution factory passed to the head blocks.
        """
        super(MSSR, self).__init__()

        n_feats = args.n_feats
        self.depth = args.depth
        kernel_size = 3
        # args.scale[0] is read here, but the value is not used in this
        # constructor.
        scale = args.scale[0]

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)

        def head_block(in_channels):
            # Each head block gets its own PReLU instance.
            return common.BasicBlock(conv,
                                     in_channels,
                                     n_feats,
                                     kernel_size,
                                     stride=1,
                                     bias=True,
                                     bn=False,
                                     act=nn.PReLU())

        # Head: n_colors -> n_feats, then n_feats -> n_feats.
        head_layers = [head_block(args.n_colors), head_block(n_feats)]

        # Reconstruction module.
        self.body = RecurrentProjection(n_feats)

        # Tail: one convolution taking n_feats * depth input channels.
        tail_conv = nn.Conv2d(n_feats * self.depth,
                              args.n_colors,
                              kernel_size,
                              padding=(kernel_size // 2))

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*head_layers)
        self.tail = nn.Sequential(tail_conv)
Ejemplo n.º 7
0
 def __init__(self, in_channel, kernel_size=3, conv=common.default_conv):
     """Create the projection, resampling and error-encoding sub-modules.

     Args:
         in_channel: Channel count used as both input and output width of
             every layer created here.
         kernel_size: Forwarded to both ``MultisourceProjection`` modules
             and to the final ``common.BasicBlock``.
         conv: Convolution factory forwarded to the same modules.
     """
     super(RecurrentProjection, self).__init__()

     def resampler(layer_cls, size, step):
         # One strided layer (padding fixed at 2) followed by its own PReLU.
         return nn.Sequential(
             layer_cls(in_channel, in_channel, size, stride=step, padding=2),
             nn.PReLU())

     projection_opts = dict(kernel_size=kernel_size, conv=conv)
     self.multi_source_projection_1 = MultisourceProjection(
         in_channel, **projection_opts)
     self.multi_source_projection_2 = MultisourceProjection(
         in_channel, **projection_opts)

     self.down_sample_1 = resampler(nn.Conv2d, 6, 2)
     # No down_sample_2 attribute is created in this constructor.
     self.down_sample_3 = resampler(nn.Conv2d, 8, 4)
     self.down_sample_4 = resampler(nn.Conv2d, 8, 4)
     self.error_encode_1 = resampler(nn.ConvTranspose2d, 6, 2)
     self.error_encode_2 = resampler(nn.ConvTranspose2d, 8, 4)

     self.post_conv = common.BasicBlock(conv,
                                        in_channel,
                                        in_channel,
                                        kernel_size,
                                        stride=1,
                                        bias=True,
                                        act=nn.PReLU())
Ejemplo n.º 8
0
    def __init__(self, args, conv3x3=conv_factor, conv1x1=None):
        """Build the VGG feature stack and a single linear classifier.

        Args:
            args: Indexable container whose first element is the option
                object; that object provides ``no_bias``, ``n_colors``,
                ``vgg_type``, ``vgg_decom_type``, ``kernel_size`` and
                ``data_train``.
            conv3x3: Convolution factory used for blocks past the index
                threshold (see the loop below).
            conv1x1: Accepted but not read in this constructor.

        Raises:
            AssertionError: If ``args.data_train`` does not contain
                ``'CIFAR'`` (checked only after the feature stack is built).
        """
        super(VGG, self).__init__()
        # Only the first element of `args` is used from here on.
        args = args[0]
        # we use batch normalization for VGG
        norm = common.default_norm
        bias = not args.no_bias

        # Layer layouts keyed by vgg_type: an int is an output channel count,
        # 'M' is a 2x2 max-pool.
        configs = {
            'A':
            [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            'B': [
                64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512,
                512, 'M'
            ],
            '16': [
                64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512,
                'M', 512, 512, 512, 'M'
            ],
            '19': [
                64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512,
                512, 512, 'M', 512, 512, 512, 512, 'M'
            ],
            'ef': [
                32, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 256, 256, 256,
                'M', 256, 256, 256, 'M'
            ]
        }

        body_list = []
        in_channels = args.n_colors
        for i, v in enumerate(configs[args.vgg_type]):
            if v == 'M':
                body_list.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                # `i` indexes the config list, so 'M' entries count towards
                # the threshold `t`.
                t = 3 if args.vgg_decom_type == 'all' else 8
                if i <= t:
                    # Early entries use common.BasicBlock with the default conv.
                    body_list.append(
                        common.BasicBlock(in_channels,
                                          v,
                                          args.kernel_size,
                                          bias=bias,
                                          conv3x3=common.default_conv,
                                          norm=norm))
                else:
                    # Later entries use the module-level BasicBlock (not
                    # common.BasicBlock) with the supplied conv3x3.
                    body_list.append(
                        BasicBlock(in_channels,
                                   v,
                                   args.kernel_size,
                                   bias=bias,
                                   conv=conv3x3,
                                   norm=norm))
                in_channels = v

        # for CIFAR10 and CIFAR100 only
        assert (args.data_train.find('CIFAR') >= 0)
        # The class count is whatever follows the first five characters,
        # e.g. 'CIFAR10' -> 10.
        n_classes = int(args.data_train[5:])

        self.features = nn.Sequential(*body_list)
        self.classifier = nn.Linear(in_channels, n_classes)
Ejemplo n.º 9
0
    def __init__(self, conv=common.default_conv, **kwargs):
        """Build the attention network: conv head, dense down/up path, conv tail.

        The channel counts of the dense stages are tracked by hand: every
        ``*_output_channels`` value below is the stage's input channel count
        plus ``3 * dense_growth_rate``.

        Args:
            conv: Convolution factory passed to the ``common`` blocks and
                upsamplers.
            **kwargs: Must contain ``default_kernel_size``,
                ``intermediate_channels``, ``attention_input_channels`` and
                ``dense_growth_rate``.

        Raises:
            KeyError: If one of the keyword arguments listed above is missing.
        """
        super(AttentionModel, self).__init__()
        kernel_size = kwargs['default_kernel_size']
        intermediate_channels = kwargs['intermediate_channels']
        attention_input_channels = kwargs['attention_input_channels']
        dense_growth_rate = kwargs['dense_growth_rate']
        # -------------- Define attention model here ----------------------
        self.attention_conv_head = nn.Sequential(
            common.BasicBlock(conv, attention_input_channels, intermediate_channels, kernel_size),
            common.BasicBlock(conv, intermediate_channels, intermediate_channels, kernel_size),
        )
        # return_indices=True makes the pooling layer return (output, indices).
        self.attention_maxpool = nn.MaxPool2d(kernel_size=2, return_indices=True)
        # Down path, stage 1. The leading 3 is presumably the number of dense
        # layers, matching the 3*dense_growth_rate bookkeeping -- TODO confirm.
        self.down_dense_1 = common.DenseBlock(3, intermediate_channels, bn_size=4, 
                                        growth_rate=dense_growth_rate, drop_rate=0)
        down_dense_1_output_channels = intermediate_channels + 3*dense_growth_rate
        # Each transition is given half of its input channel count as output.
        self.down_dense_1_trans = common.Transition(down_dense_1_output_channels, 
                                                down_dense_1_output_channels // 2)
        # Down path, stage 2.
        self.down_dense_2 = common.DenseBlock(3, down_dense_1_output_channels // 2, 
                                        bn_size=4, growth_rate=dense_growth_rate, drop_rate=0)
        down_dense_2_output_channels = down_dense_1_output_channels // 2 + 3*dense_growth_rate
        self.down_dense_2_trans = common.Transition(down_dense_2_output_channels, 
                                            down_dense_2_output_channels // 2)
        # Bottom stage.
        self.bottom_dense = common.DenseBlock(3, down_dense_2_output_channels // 2, bn_size=4, 
                                        growth_rate=dense_growth_rate, drop_rate=0)
        bottom_dense_output_channels = down_dense_2_output_channels // 2 + 3*dense_growth_rate
        self.bottom_dense_upsample = common.Upsampler(conv, bottom_dense_output_channels)
        # Up path, stage 2: input width is the bottom output plus the stage-2
        # down output.
        self.up_dense_2 = common.DenseBlock(3,
                                            bottom_dense_output_channels+down_dense_2_output_channels,
                                            bn_size=4, growth_rate=dense_growth_rate, drop_rate=0)
        
        up_dense_2_output_channels = bottom_dense_output_channels \
                        + down_dense_2_output_channels \
                            + 3*dense_growth_rate
        self.up_dense_2_upsample = common.Upsampler(conv, up_dense_2_output_channels)

        # Up path, stage 1: input width is the stage-2 up output plus the
        # stage-1 down output.
        self.up_dense_1 = common.DenseBlock(3, up_dense_2_output_channels+down_dense_1_output_channels,
                                            bn_size=4, growth_rate=dense_growth_rate, drop_rate=0)
        up_dense_1_output_channels = up_dense_2_output_channels \
                        + down_dense_1_output_channels \
                            + 3*dense_growth_rate
        self.up_dense_1_upsample = common.Upsampler(conv, up_dense_1_output_channels)
        # Tail: input width is the stage-1 up output plus intermediate_channels;
        # the last block uses a sigmoid activation.
        self.attention_conv_tail = nn.Sequential(
            common.BasicBlock(conv, up_dense_1_output_channels+intermediate_channels,
                            intermediate_channels, kernel_size),
            common.BasicBlock(conv, intermediate_channels, intermediate_channels, kernel_size, act=nn.Sigmoid())
        )
Ejemplo n.º 10
0
 def basic_block(in_channels, out_channels, act):
     """Return a ``common.BasicBlock`` with bias=True and bn=False.

     ``conv`` and ``kernel_size`` are not parameters; they are resolved from
     the surrounding scope, which is not visible in this block.

     Args:
         in_channels: Input channel count forwarded to the block.
         out_channels: Output channel count forwarded to the block.
         act: Activation forwarded to the block.
     """
     fixed_opts = {'bias': True, 'bn': False, 'act': act}
     return common.BasicBlock(conv, in_channels, out_channels, kernel_size,
                              **fixed_opts)
Ejemplo n.º 11
0
 def __init__(self,
              level=5,
              res_scale=1,
              channel=64,
              reduction=2,
              ksize=3,
              stride=1,
              softmax_scale=10,
              average=True,
              conv=common.default_conv):
     """Store the attention settings and create the three branches.

     Args:
         level: Number of entries in ``self.scale``.
         res_scale: Stored as ``self.res_scale``.
         channel: Input channel count of every branch; also the output
             channel count of the assembly branch.
         reduction: Divisor applied to ``channel`` to get the output
             channel count of the two matching branches.
         ksize: Stored as ``self.ksize``.
         stride: Stored as ``self.stride``.
         softmax_scale: Stored as ``self.softmax_scale``.
         average: Stored as ``self.average``.
         conv: Convolution factory handed to ``common.BasicBlock``.
     """
     super(PyramidAttention, self).__init__()
     self.ksize = ksize
     self.stride = stride
     self.res_scale = res_scale
     self.softmax_scale = softmax_scale
     # One factor per level: 1 - step / 10 for step = 0 .. level - 1.
     self.scale = [1 - step / 10 for step in range(level)]
     self.average = average
     # Registered as a buffer so it is part of the module state.
     self.register_buffer('escape_NaN', torch.FloatTensor([1e-4]))

     def unit_block(out_channels):
         # Size argument 1, bn=False, and a fresh PReLU for each block.
         return common.BasicBlock(conv,
                                  channel,
                                  out_channels,
                                  1,
                                  bn=False,
                                  act=nn.PReLU())

     self.conv_match_L_base = unit_block(channel // reduction)
     self.conv_match = unit_block(channel // reduction)
     self.conv_assembly = unit_block(channel)
    def __init__(self,
                 channel=128,
                 reduction=2,
                 ksize=3,
                 scale=3,
                 stride=1,
                 softmax_scale=10,
                 average=True,
                 conv=common.default_conv):
        """Store the attention settings and create the three branches.

        Args:
            channel: Input channel count of every branch; also the output
                channel count of the assembly branch.
            reduction: Divisor applied to ``channel`` to get the output
                channel count of the two matching branches.
            ksize: Stored as ``self.ksize``.
            scale: Stored as ``self.scale``.
            stride: Stored as ``self.stride``.
            softmax_scale: Stored as ``self.softmax_scale``.
            average: Stored as ``self.average``.
            conv: Convolution factory handed to ``common.BasicBlock``.
        """
        super(CrossScaleAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.softmax_scale = softmax_scale

        self.scale = scale
        self.average = average
        # Registered as a buffer so it is part of the module state.
        self.register_buffer('escape_NaN', torch.FloatTensor([1e-4]))

        def unit_block(out_channels):
            # Size argument 1, bn=False, and a fresh PReLU for each block.
            return common.BasicBlock(conv,
                                     channel,
                                     out_channels,
                                     1,
                                     bn=False,
                                     act=nn.PReLU())

        self.conv_match_1 = unit_block(channel // reduction)
        self.conv_match_2 = unit_block(channel // reduction)
        self.conv_assembly = unit_block(channel)
Ejemplo n.º 13
0
    def __init__(self,
                 args,
                 conv3x3=common.default_conv,
                 conv1x1=common.default_conv):
        """Build the ResNet feature stack and its linear classifier.

        Args:
            args: Option object providing ``data_train``, ``kernel_size``,
                ``depth``, ``downsample_type``, ``no_bias`` and ``n_colors``.
            conv3x3: Convolution factory passed to the stem block and, via
                ``kwargs``, to ``self.make_layer``.
            conv1x1: Convolution factory passed to ``self.make_layer`` via
                ``kwargs``.
        """
        super(ResNet, self).__init__()

        # The class count is the suffix after the first five characters of a
        # CIFAR dataset name (e.g. 'CIFAR10' -> 10); any other name gives 200.
        n_classes = int(
            args.data_train[5:]) if args.data_train.find('CIFAR') >= 0 else 200
        kernel_size = args.kernel_size
        # Only depth 50 selects BottleNeck blocks; every other depth uses
        # ResBlock.
        if args.depth == 50:
            self.expansion = 4
            self.block = BottleNeck
            self.n_blocks = (args.depth - 2) // 9
        # elif args.depth <= 56:
        else:
            self.expansion = 1
            self.block = ResBlock
            self.n_blocks = (args.depth - 2) // 6
        # else:
        #     self.expansion = 4
        #     self.block = BottleNeck
        #     self.n_blocks = (args.depth - 2) // 9
        # Set before the make_layer calls below; presumably make_layer reads
        # and updates self.in_channels -- TODO confirm.
        self.in_channels = 16
        self.downsample_type = args.downsample_type
        bias = not args.no_bias

        kwargs = {'conv3x3': conv3x3, 'conv1x1': conv1x1, 'args': args}
        # The stem uses stride 1 for CIFAR dataset names and stride 2 otherwise.
        stride = 1 if args.data_train.find('CIFAR') >= 0 else 2
        m = [
            common.BasicBlock(args.n_colors,
                              16,
                              kernel_size=kernel_size,
                              stride=stride,
                              bias=bias,
                              conv3x3=conv3x3,
                              args=args),
            self.make_layer(self.n_blocks, 16, kernel_size, **kwargs),
            self.make_layer(self.n_blocks, 32, kernel_size, stride=2,
                            **kwargs),
            self.make_layer(self.n_blocks, 64, kernel_size, stride=2,
                            **kwargs),
            nn.AvgPool2d(8)
        ]
        fc = nn.Linear(64 * self.expansion, n_classes)

        self.features = nn.Sequential(*m)
        self.classifier = fc
Ejemplo n.º 14
0
    def __init__(self, args, conv3x3=common.default_conv, conv1x1=None):
        """Build the VGG feature stack, a dataset-specific classifier and bookkeeping lists.

        Args:
            args: Option object providing ``no_bias``, ``n_colors``,
                ``data_train``, ``vgg_type``, ``kernel_size`` and (for CIFAR
                dataset names) ``template``.
            conv3x3: Convolution factory; passed to ``common.BasicBlock`` for
                CIFAR/Tiny dataset names and called directly otherwise.
            conv1x1: Accepted but not read in this constructor.
        """
        super(VGG, self).__init__()
        # NOTE(review): the original comment here said batch normalization is
        # used for VGG, but `norm` is None, so the `norm is not None` branch
        # below never runs and no BatchNorm2d is added by this constructor.
        norm = None
        self.norm = norm
        bias = not args.no_bias

        # Layer layouts keyed by vgg_type: an int is an output channel count,
        # 'M' is a 2x2 max-pool.
        configs = {
            'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
            '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512,
                   'M'],
            'ef': [32, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 256, 256, 256, 'M', 256, 256, 256, 'M']
        }

        body_list = []
        in_channels = args.n_colors
        if args.data_train.find('CIFAR') >= 0 or args.data_train.find('Tiny') >= 0:

            for i, v in enumerate(configs[args.vgg_type]):
                if v == 'M':
                    body_list.append(nn.MaxPool2d(kernel_size=2, stride=2))
                else:
                    # Only the very first layer of a 'Tiny' dataset uses stride 2.
                    stride = 2 if i == 0 and args.data_train.find('Tiny') >= 0 else 1
                    body_list.append(common.BasicBlock(in_channels, v, args.kernel_size, stride=stride,
                                                       bias=bias, conv3x3=conv3x3))
                    in_channels = v
        else:
            # Other datasets: the conv and ReLU are appended as separate list
            # entries instead of one BasicBlock.
            for i, v in enumerate(configs[args.vgg_type]):
                if v == 'M':
                    body_list.append(nn.MaxPool2d(kernel_size=2, stride=2))
                else:
                    conv2d = conv3x3(in_channels, v, kernel_size=3)
                    if norm is not None:
                        body_list += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                    else:
                        body_list += [conv2d, nn.ReLU(inplace=True)]
                    in_channels = v

        # assert(args.data_train.find('CIFAR') >= 0)
        # self.body_list is the same list object as body_list, so the in-place
        # wrapping at the end of this method is visible through it.
        self.body_list = body_list
        self.features = nn.Sequential(*body_list)
        if args.data_train.find('CIFAR') >= 0:
            # The class count is the suffix after the first five characters,
            # e.g. 'CIFAR10' -> 10.
            n_classes = int(args.data_train[5:])
            if args.template.find('linear3') >= 0:
                self.classifier = nn.Sequential(nn.Linear(in_channels, in_channels),
                                                nn.Linear(in_channels, in_channels),
                                                nn.Linear(in_channels, n_classes))
            else:
                self.classifier = nn.Linear(in_channels, n_classes)
        # NOTE(review): this is a separate `if`, not an `elif` of the CIFAR
        # check above, so a dataset name containing both 'CIFAR' and 'Tiny'
        # would have its classifier replaced here. For any other dataset name
        # no classifier is assigned.
        if args.data_train.find('Tiny') >= 0:
            n_classes = 200
            self.classifier = nn.Sequential(nn.Linear(in_channels, in_channels), nn.Linear(in_channels, in_channels),
                                            nn.Linear(in_channels, n_classes))
        elif args.data_train == 'ImageNet':
            n_classes = 1000
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, n_classes),
            )
            # self.classifier = nn.Sequential(
            #     nn.Linear(512 * 7 * 7, n_classes)
            # )
        # print(conv3x3, conv3x3 == common.default_conv or conv3x3 == nn.Conv2d)

        # if conv3x3 == common.default_conv or conv3x3 == nn.Conv2d:
        #     self.load(args, strict=True)
        # Bookkeeping containers; they start empty (or zeroed) here.
        self.total_time = [0] * len(body_list)
        self.top1_err_list = []
        self.sum_list = []
        self.layer_num = -1  # layer index for the Cluster Pruning experiment
        self.spec_list = []  # inference time of layer `layer_num`
        self.current_ratio_list = []  # FLOPs ratio
        self.parameter_ratio_list = []  # parameter ratio
        self.timer_test_list = []  # inference time of the whole network


        # Keys are the positions of the non-'M' entries of configs['16'];
        # values number those entries 0..12.
        self.block_dict = {0: 0, 1: 1, 3: 2, 4: 3, 6: 4, 7: 5, 8: 6,
                           10: 7, 11: 8, 12: 9, 14: 10, 15: 11, 16: 12}

        # Wrap every entry of the list in its own nn.Sequential, in place.
        # self.features was built before this loop and keeps the unwrapped
        # modules.
        for i in range(len(body_list)):
            body_list[i] = nn.Sequential(body_list[i])
    def __init__(self,
                 in_channel,
                 kernel_size=3,
                 scale=2,
                 conv=common.default_conv):
        """Create the projection, resampling and error-encoding sub-modules for one scale.

        Args:
            in_channel: Channel count used as both input and output width of
                every layer created here.
            kernel_size: Forwarded to ``MultisourceProjection`` and to the
                final ``common.BasicBlock``.
            scale: Must be 2, 3 or 4; it selects the strided-layer settings
                and, for 4, adds a second set of modules.
            conv: Convolution factory forwarded to ``MultisourceProjection``
                and ``common.BasicBlock``.

        Raises:
            KeyError: If ``scale`` is not 2, 3 or 4.
        """
        super(RecurrentProjection, self).__init__()
        self.scale = scale
        # (kernel size, stride, padding) of the strided layers per scale.
        # Scale 4 reuses the scale-2 settings here and adds stride-4 layers
        # in the `scale == 4` branch at the end.
        stride_conv_ksize, stride, padding = {
            2: (6, 2, 2),
            3: (9, 3, 3),
            4: (6, 2, 2)
        }[scale]

        self.multi_source_projection = MultisourceProjection(
            in_channel, kernel_size=kernel_size, scale=scale, conv=conv)
        self.down_sample_1 = nn.Sequential(*[
            nn.Conv2d(in_channel,
                      in_channel,
                      stride_conv_ksize,
                      stride=stride,
                      padding=padding),
            nn.PReLU()
        ])
        # down_sample_2 exists only for scales 2 and 3.
        if scale != 4:
            self.down_sample_2 = nn.Sequential(*[
                nn.Conv2d(in_channel,
                          in_channel,
                          stride_conv_ksize,
                          stride=stride,
                          padding=padding),
                nn.PReLU()
            ])
        # Transposed convolution with the same kernel/stride/padding as the
        # down-sampling layers.
        self.error_encode = nn.Sequential(*[
            nn.ConvTranspose2d(in_channel,
                               in_channel,
                               stride_conv_ksize,
                               stride=stride,
                               padding=padding),
            nn.PReLU()
        ])
        self.post_conv = common.BasicBlock(conv,
                                           in_channel,
                                           in_channel,
                                           kernel_size,
                                           stride=1,
                                           bias=True,
                                           act=nn.PReLU())
        # Scale 4 only: a second projection plus stride-4 down-sampling and
        # error-encoding layers (kernel 8, padding 2).
        if scale == 4:
            self.multi_source_projection_2 = MultisourceProjection(
                in_channel, kernel_size=kernel_size, scale=scale, conv=conv)
            self.down_sample_3 = nn.Sequential(*[
                nn.Conv2d(in_channel, in_channel, 8, stride=4, padding=2),
                nn.PReLU()
            ])
            self.down_sample_4 = nn.Sequential(*[
                nn.Conv2d(in_channel, in_channel, 8, stride=4, padding=2),
                nn.PReLU()
            ])
            self.error_encode_2 = nn.Sequential(*[
                nn.ConvTranspose2d(
                    in_channel, in_channel, 8, stride=4, padding=2),
                nn.PReLU()
            ])
Ejemplo n.º 16
0
    def __init__(self, args, conv3x3=common.default_conv, conv1x1=None):
        """Build the VGG feature stack and classifier, then initialise or load weights.

        Args:
            args: Indexable container whose first element is the option
                object; that object provides ``no_bias``, ``n_basis``,
                ``basis_size``, ``n_colors``, ``vgg_type``,
                ``vgg_decom_type``, ``kernel_size``, ``data_train``,
                ``pretrained`` and ``extend``.
            conv3x3: Convolution factory passed to both kinds of block. When
                it equals ``common.default_conv`` the weights are also
                initialised or loaded at the end.
            conv1x1: Accepted but not read in this constructor.

        Raises:
            NotImplementedError: If ``conv3x3`` equals
                ``common.default_conv`` and ``args.data_train`` is neither a
                CIFAR name nor ``'ImageNet'``.
        """
        super(VGG_Basis, self).__init__()

        # we use batch normalization for VGG
        # Only the first element of `args` is used from here on.
        args = args[0]
        norm = common.default_norm
        bias = not args.no_bias
        n_basis = args.n_basis
        basis_size = args.basis_size

        # Layer layouts keyed by vgg_type: an int is an output channel count,
        # 'M' is a 2x2 max-pool.
        configs = {
            'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
            '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
            'ef': [32, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 256, 256, 256, 'M', 256, 256, 256, 'M']
        }

        body_list = []
        in_channels = args.n_colors
        for i, v in enumerate(configs[args.vgg_type]):
            if v == 'M':
                body_list.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                # `i` indexes the config list, so 'M' entries count towards
                # the threshold `t`.
                t = 3 if args.vgg_decom_type == 'all' else 8
                if i <= t:
                    body_list.append(common.BasicBlock(in_channels, v, args.kernel_size, bias=bias,
                                                       conv3x3=conv3x3, norm=norm))
                else:
                    # Later entries use the module-level BasicBlock, which also
                    # receives n_basis and basis_size.
                    body_list.append(BasicBlock(in_channels, v, n_basis, basis_size, args.kernel_size, bias=bias,
                                                conv=conv3x3, norm=norm))
                in_channels = v

        # assert(args.data_train.find('CIFAR') >= 0)
        self.features = nn.Sequential(*body_list)
        # For any dataset name other than CIFAR* or 'ImageNet' no classifier
        # is assigned.
        if args.data_train.find('CIFAR') >= 0:
            # The class count is the suffix after the first five characters,
            # e.g. 'CIFAR10' -> 10.
            n_classes = int(args.data_train[5:])
            self.classifier = nn.Linear(in_channels, n_classes)
        elif args.data_train == 'ImageNet':
            n_classes = 1000
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, n_classes),
            )

        # Weight initialisation/loading happens only with the default conv.
        if conv3x3 == common.default_conv:
            model_dir = os.path.join('..', 'models')
            os.makedirs(model_dir, exist_ok=True)
            if args.data_train.find('CIFAR') >= 0:
                if args.pretrained == 'download' or args.extend == 'download':
                    url = (
                        'https://cv.snu.ac.kr/'
                        'research/clustering_kernels/models/vgg16-89711a85.pt'
                    )

                    state = model_zoo.load_url(url, model_dir=model_dir)
                elif args.extend:
                    # NOTE(review): torch.load may unpickle arbitrary objects
                    # depending on the torch version; args.extend should only
                    # point at trusted checkpoints.
                    state = torch.load(args.extend)
                else:
                    # Nothing to load: initialise and stop here.
                    common.init_vgg(self)
                    return
            elif args.data_train == 'ImageNet':
                if args.pretrained == 'download':
                    url = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
                    state = model_zoo.load_url(url, model_dir=model_dir)
                else:
                    # Nothing to load: initialise and stop here.
                    common.init_vgg(self)
                    return
            else:
                raise NotImplementedError('Unavailable dataset {}'.format(args.data_train))
            # from IPython import embed; embed()
            # The second positional argument is `strict`; False requests a
            # non-strict load.
            self.load_state_dict(state, False)
Ejemplo n.º 17
0
    def __init__(self,
                 args,
                 conv3x3=common.default_conv,
                 conv1x1=common.default_conv):
        """Build the ResNet_Group feature extractor and classifier.

        Args:
            args: Indexable container; only its first element (the option
                namespace) is used. Fields read here: ``data_train``,
                ``depth``, ``kernel_size``, ``n_colors``, ``pretrained``
                and ``extend``.
            conv3x3: Convolution factory handed to the stem block and, for
                ImageNet, to every stage.
            conv1x1: Convolution factory handed to the ImageNet stages.

        Raises:
            NotImplementedError: If ``args.data_train`` is neither a CIFAR
                variant nor ``'ImageNet'``.
        """
        super(ResNet_Group, self).__init__()
        args = args[0]
        self.args = args
        m = []
        if args.data_train.find('CIFAR') >= 0:
            self.expansion = 1

            # Three stages of equal depth: depth = 6 * n_blocks + 2.
            self.n_blocks = (args.depth - 2) // 6
            self.in_channels = 16
            self.downsample_type = 'A'
            # The class count is the numeric suffix of the dataset name,
            # e.g. 'CIFAR10' -> 10.
            n_classes = int(args.data_train[5:])

            kwargs = {
                'kernel_size': args.kernel_size,
                'conv3x3': conv3x3,
            }
            m.append(common.BasicBlock(args.n_colors, 16, **kwargs))
            # The stem uses the caller's conv3x3; the stages below get this
            # module's own BasicBlock instead.
            kwargs['conv3x3'] = BasicBlock
            m.append(self.make_layer(16, self.n_blocks, **kwargs))
            m.append(self.make_layer(32, self.n_blocks, stride=2, **kwargs))
            m.append(self.make_layer(64, self.n_blocks, stride=2, **kwargs))
            m.append(nn.AvgPool2d(8))

            fc = nn.Linear(64 * self.expansion, n_classes)

        elif args.data_train == 'ImageNet':
            # depth -> (blocks per stage, block class, channel expansion)
            block_config = {
                18: ([2, 2, 2, 2], ResBlock, 1),
                34: ([3, 4, 6, 3], ResBlock, 1),
                50: ([3, 4, 6, 3], BottleNeck, 4),
                101: ([3, 4, 23, 3], BottleNeck, 4),
                152: ([3, 8, 36, 3], BottleNeck, 4)
            }
            n_blocks, self.block, self.expansion = block_config[args.depth]

            self.in_channels = 64
            self.downsample_type = 'C'
            n_classes = 1000
            kwargs = {
                'conv3x3': conv3x3,
                'conv1x1': conv1x1,
            }
            m.append(
                common.BasicBlock(args.n_colors,
                                  64,
                                  7,
                                  stride=2,
                                  conv3x3=conv3x3,
                                  bias=False))
            m.append(nn.MaxPool2d(3, 2, padding=1))
            m.append(self.make_layer(64, n_blocks[0], 3, **kwargs))
            m.append(self.make_layer(128, n_blocks[1], 3, stride=2, **kwargs))
            m.append(self.make_layer(256, n_blocks[2], 3, stride=2, **kwargs))
            m.append(self.make_layer(512, n_blocks[3], 3, stride=2, **kwargs))
            m.append(nn.AvgPool2d(7, 1))

            fc = nn.Linear(512 * self.expansion, n_classes)

        else:
            # Without this branch an unknown dataset only failed further
            # down, with an unbound-local error on `fc`.
            raise NotImplementedError(
                'Unavailable dataset {}'.format(args.data_train))

        self.features = nn.Sequential(*m)
        self.classifier = fc

        # Weights are only initialised/loaded when the default convolution
        # is used (i.e. when this is the child model).
        if conv3x3 == common.default_conv:
            if args.pretrained == 'download' or args.extend == 'download':
                # `models` is presumably torchvision.models -- TODO confirm.
                state = getattr(models,
                                'resnet{}'.format(args.depth))(pretrained=True)
            elif args.extend:
                state = torch.load(args.extend)
            else:
                common.init_kaiming(self)
                return

            # A checkpoint may hold a whole module or just its state dict;
            # accept both instead of assuming `.state_dict()` exists.
            source = state.state_dict() if hasattr(state,
                                                   'state_dict') else state
            target = self.state_dict()
            # Tensors are paired by position in key order, not by name; zip
            # stops at the shorter of the two key sequences.
            for s, t in zip(source.keys(), target.keys()):
                target[t].copy_(source[s])
Ejemplo n.º 18
0
    def __init__(self, args, conv3x3=common.default_conv):
        """Build the block-wise shared-basis ResNet (CIFAR only).

        Three basis tensors are created as learnable parameters and handed
        to the three stages so that blocks share them.

        Args:
            args: Indexable container; only its first element (the option
                namespace) is used. Fields read here: ``data_train``,
                ``n_basis1..3``, ``basis_size1..3``, ``depth``,
                ``kernel_size``, ``k_size2``, ``n_colors``, ``pretrained``
                and ``extend``.
            conv3x3: Convolution factory used by the stem block.

        Raises:
            NotImplementedError: If ``args.data_train`` is not a CIFAR
                variant.
        """
        super(ResNet_Basis_Blockwise, self).__init__()
        args = args[0]
        self.args = args
        m = []
        if args.data_train.find('CIFAR') >= 0:
            self.expansion = 1
            self.n_basis1 = args.n_basis1
            self.n_basis2 = args.n_basis2
            self.n_basis3 = args.n_basis3
            self.basis_size1 = args.basis_size1
            self.basis_size2 = args.basis_size2
            self.basis_size3 = args.basis_size3
            # Three stages of equal depth: depth = 6 * n_blocks + 2.
            self.n_blocks = (args.depth - 2) // 6
            self.k_size = args.kernel_size
            self.in_channels = 16
            self.downsample_type = 'A'
            # One basis per stage, each of shape
            # (n_basis, basis_size, k_size, k_size), Kaiming-uniform
            # initialised and registered as a parameter of this module.
            self.basis1 = nn.Parameter(
                nn.init.kaiming_uniform_(
                    torch.Tensor(self.n_basis1, self.basis_size1, self.k_size,
                                 self.k_size)))
            self.basis2 = nn.Parameter(
                nn.init.kaiming_uniform_(
                    torch.Tensor(self.n_basis2, self.basis_size2, self.k_size,
                                 self.k_size)))
            self.basis3 = nn.Parameter(
                nn.init.kaiming_uniform_(
                    torch.Tensor(self.n_basis3, self.basis_size3, self.k_size,
                                 self.k_size)))
            self.block = ResBlockDecom

            # The class count is the numeric suffix of the dataset name,
            # e.g. 'CIFAR10' -> 10.
            n_classes = int(args.data_train[5:])
            kwargs = {'kernel_size': args.kernel_size, 'conv3x3': conv3x3}
            m.append(common.BasicBlock(args.n_colors, 16, **kwargs))

            # The stem uses the caller's conv3x3; the stages below get this
            # module's own BasicBlock plus the basis settings.
            kwargs['conv3x3'] = BasicBlock
            kwargs['kernel_size2'] = args.k_size2
            # Stage 1: both basis slots point at basis1.
            kwargs['n_basis'] = self.n_basis1
            kwargs['basis_size'] = self.basis_size1
            kwargs['basis_l1'] = self.basis1
            kwargs['basis_l2'] = self.basis1

            m.append(self.make_layer(16, self.n_blocks, **kwargs))
            # Stage 2: basis_l1 stays on the previous stage's basis,
            # basis_l2 moves to basis2.
            kwargs['n_basis'] = self.n_basis2
            kwargs['basis_size'] = self.basis_size2
            kwargs['basis_l1'] = self.basis1
            kwargs['basis_l2'] = self.basis2

            m.append(self.make_layer(32, self.n_blocks, stride=2, **kwargs))
            # Stage 3: basis_l1 is the previous stage's basis2, basis_l2 is
            # basis3.
            kwargs['n_basis'] = self.n_basis3
            kwargs['basis_size'] = self.basis_size3
            kwargs['basis_l1'] = self.basis2
            kwargs['basis_l2'] = self.basis3
            m.append(self.make_layer(64, self.n_blocks, stride=2, **kwargs))
            m.append(nn.AvgPool2d(8))

            fc = nn.Linear(64 * self.expansion, n_classes)

        else:
            # Only the CIFAR layout is defined; without this branch other
            # datasets failed later with an unbound-local error on `fc`.
            raise NotImplementedError(
                'Unavailable dataset {}'.format(args.data_train))

        self.features = nn.Sequential(*m)
        self.classifier = fc

        # Weights are only initialised/loaded when the default convolution
        # is used (i.e. when this is the child model).
        if conv3x3 == common.default_conv:
            if args.pretrained == 'download' or args.extend == 'download':
                # `models` is presumably torchvision.models -- TODO confirm.
                state = getattr(models,
                                'resnet{}'.format(args.depth))(pretrained=True)
            elif args.extend:
                state = torch.load(args.extend)
            else:
                common.init_kaiming(self)
                return

            # A checkpoint may hold a whole module or just its state dict;
            # accept both instead of assuming `.state_dict()` exists.
            source = state.state_dict() if hasattr(state,
                                                   'state_dict') else state
            target = self.state_dict()
            # Tensors are paired by position in key order, not by name; zip
            # stops at the shorter of the two key sequences.
            for s, t in zip(source.keys(), target.keys()):
                target[t].copy_(source[s])
Ejemplo n.º 19
0
    def __init__(self, args, conv3x3=common.default_conv, conv1x1=None):
        """Build the VGG feature extractor and classifier.

        Args:
            args: Option namespace (used directly, not indexed). Fields read
                here: ``no_bias``, ``n_colors``, ``data_train``,
                ``vgg_type`` and ``kernel_size``. It is also forwarded to
                ``self.load``.
            conv3x3: Convolution factory used for every convolution layer.
            conv1x1: Not used by this constructor.
        """
        super(VGG, self).__init__()
        # No normalization layer is requested in this variant: `norm` is
        # fixed to None, so the BatchNorm branch below is never taken.
        norm = None
        self.norm = norm
        bias = not args.no_bias

        # Layer layouts: integers are output channel counts, 'M' marks a
        # 2x2 max-pool.
        configs = {
            'A':
            [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            'B': [
                64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512,
                512, 'M'
            ],
            '16': [
                64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512,
                'M', 512, 512, 512, 'M'
            ],
            '19': [
                64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512,
                512, 512, 'M', 512, 512, 512, 512, 'M'
            ],
            'ef': [
                32, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 256, 256, 256,
                'M', 256, 256, 256, 'M'
            ]
        }

        body_list = []
        in_channels = args.n_colors
        is_cifar = args.data_train.find('CIFAR') >= 0
        if is_cifar:
            # CIFAR: every convolution is wrapped in common.BasicBlock.
            for v in configs[args.vgg_type]:
                if v == 'M':
                    body_list.append(nn.MaxPool2d(kernel_size=2, stride=2))
                else:
                    body_list.append(
                        common.BasicBlock(in_channels,
                                          v,
                                          args.kernel_size,
                                          bias=bias,
                                          conv3x3=conv3x3,
                                          norm=norm))
                    in_channels = v
        else:
            # Any other dataset: bare 3x3 convolution followed by ReLU.
            for v in configs[args.vgg_type]:
                if v == 'M':
                    body_list.append(nn.MaxPool2d(kernel_size=2, stride=2))
                else:
                    conv2d = conv3x3(in_channels, v, kernel_size=3)
                    if norm is not None:
                        body_list += [
                            conv2d,
                            nn.BatchNorm2d(v),
                            nn.ReLU(inplace=True)
                        ]
                    else:
                        body_list += [conv2d, nn.ReLU(inplace=True)]
                    in_channels = v

        self.features = nn.Sequential(*body_list)
        if is_cifar:
            # The class count is the numeric suffix of the dataset name,
            # e.g. 'CIFAR10' -> 10.
            n_classes = int(args.data_train[5:])
            self.classifier = nn.Linear(in_channels, n_classes)
        elif args.data_train == 'ImageNet':
            n_classes = 1000
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, n_classes),
            )
        # NOTE(review): for any other dataset no classifier is attached
        # here -- confirm whether that is intentional.

        # Weights are only loaded when a plain convolution is in use;
        # `self.load` is defined elsewhere in the class.
        if conv3x3 == common.default_conv or conv3x3 == nn.Conv2d:
            self.load(args, strict=True)
Ejemplo n.º 20
0
    def __init__(self, args, conv3x3=common.default_conv, conv1x1=None):
        """Assemble the grouped VGG network (CIFAR datasets only).

        Args:
            args: Indexable container; only its first element (the option
                namespace) is used. Fields read here: ``no_bias``,
                ``group_size``, ``vgg_type``, ``vgg_decom_type``,
                ``kernel_size``, ``n_colors``, ``data_train``,
                ``pretrained`` and ``extend``.
            conv3x3: Convolution factory forwarded to every block.
            conv1x1: Not used by this constructor.
        """
        super(VGG_GROUP, self).__init__()
        args = args[0]
        norm = common.default_norm
        bias = not args.no_bias
        group_size = args.group_size

        # Layer layouts, one stage per row: integers are output channel
        # counts, 'M' marks a 2x2 max-pool.
        configs = {
            'A': [
                64, 'M',
                128, 'M',
                256, 256, 'M',
                512, 512, 'M',
                512, 512, 'M',
            ],
            'B': [
                64, 64, 'M',
                128, 128, 'M',
                256, 256, 'M',
                512, 512, 'M',
                512, 512, 'M',
            ],
            '16': [
                64, 64, 'M',
                128, 128, 'M',
                256, 256, 256, 'M',
                512, 512, 512, 'M',
                512, 512, 512, 'M',
            ],
            '19': [
                64, 64, 'M',
                128, 128, 'M',
                256, 256, 256, 256, 'M',
                512, 512, 512, 512, 'M',
                512, 512, 512, 512, 'M',
            ],
            'ef': [
                32, 64, 'M',
                128, 128, 'M',
                256, 256, 256, 'M',
                256, 256, 256, 'M',
                256, 256, 256, 'M',
            ],
        }

        layers = []
        width = args.n_colors
        for position, entry in enumerate(configs[args.vgg_type]):
            if entry == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue

            # Entries up to this index (counted over the whole layout,
            # pooling markers included) stay ordinary blocks; later ones
            # become this module's grouped BasicBlock.
            last_plain = 3 if args.vgg_decom_type == 'all' else 8
            if position > last_plain:
                block = BasicBlock(width,
                                   entry,
                                   group_size,
                                   args.kernel_size,
                                   bias=bias,
                                   conv=conv3x3,
                                   norm=norm)
            else:
                block = common.BasicBlock(width,
                                          entry,
                                          args.kernel_size,
                                          bias=bias,
                                          conv3x3=conv3x3,
                                          norm=norm)
            layers.append(block)
            width = entry

        # Only CIFAR10 / CIFAR100 are supported; the class count is the
        # numeric suffix of the dataset name.
        # NOTE(review): this check is an assert, so it disappears under
        # `python -O`.
        assert args.data_train.find('CIFAR') >= 0
        n_classes = int(args.data_train[5:])

        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(width, n_classes)

        # Weights are only initialised/loaded for the default convolution.
        if conv3x3 == common.default_conv:
            if args.pretrained == 'download' or args.extend == 'download':
                url = ('https://cv.snu.ac.kr/research/clustering_kernels/'
                       'models/vgg16-89711a85.pt')
                model_dir = os.path.join('..', 'models')
                os.makedirs(model_dir, exist_ok=True)
                self.load_state_dict(
                    torch.utils.model_zoo.load_url(url, model_dir=model_dir),
                    strict=False)
            elif args.extend:
                self.load_state_dict(torch.load(args.extend), strict=False)
            else:
                common.init_vgg(self)
Ejemplo n.º 21
0
    def __init__(self, args, conv3x3=common.default_conv, conv1x1=None):
        """Assemble the over-parameterised VGG network.

        Args:
            args: Option namespace (used directly, not indexed). Fields read
                here: ``no_bias``, ``overparam_type``, ``n_colors``,
                ``data_train``, ``vgg_type``, ``kernel_size`` and
                ``template``. It is also forwarded to ``self.load``.
            conv3x3: Convolution factory used on the non-CIFAR/non-Tiny
                path and to decide whether weights are loaded.
            conv1x1: Not used by this constructor.

        Raises:
            NotImplementedError: If ``args.overparam_type`` is not one of
                the recognised variants.
        """
        super(VGG, self).__init__()
        # Normalization is switched off in this variant.
        norm = None
        self.norm = norm
        bias = not args.no_bias

        # Variants served by common_tmp follow the naming scheme
        # `conv3x3_<overparam_type>`; 'R1'/'R2' come from common.
        tmp_variants = ('1', '2', '3', '4', '4tensor1', '4tensor2',
                        '4tensor3', '5', '5relu', '5bias1', '5bias2')
        variant = args.overparam_type
        if variant in tmp_variants:
            conv = getattr(common_tmp, 'conv3x3_' + variant)
        elif variant == 'R1':
            conv = common.ROPConv1
        elif variant == 'R2':
            conv = common.ROPConv2
        else:
            raise NotImplementedError

        # Layer layouts, one stage per row: integers are output channel
        # counts, 'M' marks a 2x2 max-pool.
        configs = {
            'A': [
                64, 'M',
                128, 'M',
                256, 256, 'M',
                512, 512, 'M',
                512, 512, 'M',
            ],
            'B': [
                64, 64, 'M',
                128, 128, 'M',
                256, 256, 'M',
                512, 512, 'M',
                512, 512, 'M',
            ],
            '16': [
                64, 64, 'M',
                128, 128, 'M',
                256, 256, 256, 'M',
                512, 512, 512, 'M',
                512, 512, 512, 'M',
            ],
            '19': [
                64, 64, 'M',
                128, 128, 'M',
                256, 256, 256, 256, 'M',
                512, 512, 512, 512, 'M',
                512, 512, 512, 512, 'M',
            ],
            'ef': [
                32, 64, 'M',
                128, 128, 'M',
                256, 256, 256, 'M',
                256, 256, 256, 'M',
                256, 256, 256, 'M',
            ],
        }

        layers = []
        width = args.n_colors
        if (args.data_train.find('CIFAR') >= 0
                or args.data_train.find('Tiny') >= 0):
            # CIFAR / Tiny: over-parameterised convolution in BasicBlock.
            for position, entry in enumerate(configs[args.vgg_type]):
                if entry == 'M':
                    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                    continue
                # Only the very first layer of a Tiny model is strided.
                if position == 0 and args.data_train.find('Tiny') >= 0:
                    stride = 2
                else:
                    stride = 1
                layers.append(
                    common.BasicBlock(width,
                                      entry,
                                      args.kernel_size,
                                      stride=stride,
                                      bias=bias,
                                      conv3x3=conv))
                width = entry
        else:
            # Other datasets: plain `conv3x3` (not the over-parameterised
            # one) followed by optional BatchNorm and ReLU.
            for entry in configs[args.vgg_type]:
                if entry == 'M':
                    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                    continue
                stage = [conv3x3(width, entry, kernel_size=3)]
                if norm is not None:
                    stage.append(nn.BatchNorm2d(entry))
                stage.append(nn.ReLU(inplace=True))
                layers.extend(stage)
                width = entry

        self.features = nn.Sequential(*layers)

        def linear_stack(n_out):
            # Three stacked linear layers with no activation in between.
            return nn.Sequential(
                nn.Linear(width, width),
                nn.Linear(width, width),
                nn.Linear(width, n_out))

        if args.data_train.find('CIFAR') >= 0:
            # The class count is the numeric suffix of the dataset name,
            # e.g. 'CIFAR10' -> 10.
            n_classes = int(args.data_train[5:])
            if args.template.find('linear3') >= 0:
                self.classifier = linear_stack(n_classes)
            else:
                self.classifier = nn.Linear(width, n_classes)
        # Deliberately a separate `if`: the Tiny / ImageNet check runs
        # independently of the CIFAR check above.
        if args.data_train.find('Tiny') >= 0:
            n_classes = 200
            self.classifier = linear_stack(n_classes)
        elif args.data_train == 'ImageNet':
            n_classes = 1000
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, n_classes),
            )

        # Weights are only loaded when a plain convolution was passed in;
        # `self.load` is defined elsewhere in the class.
        if conv3x3 == common.default_conv or conv3x3 == nn.Conv2d:
            self.load(args, strict=True)
Ejemplo n.º 22
0
    def __init__(self,
                 args,
                 conv3x3=common.default_conv,
                 conv1x1=common.default_conv):
        """Build the ResNet feature extractor and classifier.

        Args:
            args: Option namespace (used directly, not indexed). Fields read
                here: ``data_train``, ``depth``, ``kernel_size`` and
                ``n_colors``. It is also forwarded to ``self.load``.
            conv3x3: Convolution factory handed to the stem and every stage.
            conv1x1: Convolution factory handed to the ImageNet stages.

        Raises:
            NotImplementedError: If ``args.data_train`` is neither a CIFAR
                variant nor ``'ImageNet'``.
        """
        super(ResNet, self).__init__()

        m = []
        if args.data_train.find('CIFAR') >= 0:
            self.block = ResBlock
            self.expansion = 1

            # Three stages of equal depth: depth = 6 * n_blocks + 2.
            self.n_blocks = (args.depth - 2) // 6
            self.in_channels = 16
            self.downsample_type = 'A'
            # The class count is the numeric suffix of the dataset name,
            # e.g. 'CIFAR10' -> 10.
            n_classes = int(args.data_train[5:])

            kwargs = {
                'kernel_size': args.kernel_size,
                'conv3x3': conv3x3,
            }
            m.append(common.BasicBlock(args.n_colors, 16, **kwargs))
            m.append(self.make_layer(16, self.n_blocks, **kwargs))
            m.append(self.make_layer(32, self.n_blocks, stride=2, **kwargs))
            m.append(self.make_layer(64, self.n_blocks, stride=2, **kwargs))
            m.append(nn.AvgPool2d(8))

            fc = nn.Linear(64 * self.expansion, n_classes)

        elif args.data_train == 'ImageNet':
            # depth -> (blocks per stage, block class, channel expansion)
            block_config = {
                18: ([2, 2, 2, 2], ResBlock, 1),
                34: ([3, 4, 6, 3], ResBlock, 1),
                50: ([3, 4, 6, 3], BottleNeck, 4),
                101: ([3, 4, 23, 3], BottleNeck, 4),
                152: ([3, 8, 36, 3], BottleNeck, 4)
            }
            n_blocks, self.block, self.expansion = block_config[args.depth]

            self.in_channels = 64
            self.downsample_type = 'C'
            n_classes = 1000
            kwargs = {
                'conv3x3': conv3x3,
                'conv1x1': conv1x1,
            }
            m.append(
                common.BasicBlock(args.n_colors,
                                  64,
                                  7,
                                  stride=2,
                                  conv3x3=conv3x3,
                                  bias=False))
            m.append(nn.MaxPool2d(3, 2, padding=1))
            m.append(self.make_layer(64, n_blocks[0], 3, **kwargs))
            m.append(self.make_layer(128, n_blocks[1], 3, stride=2, **kwargs))
            m.append(self.make_layer(256, n_blocks[2], 3, stride=2, **kwargs))
            m.append(self.make_layer(512, n_blocks[3], 3, stride=2, **kwargs))
            m.append(nn.AvgPool2d(7, 1))

            fc = nn.Linear(512 * self.expansion, n_classes)

        else:
            # Without this branch an unknown dataset only failed further
            # down, with an unbound-local error on `fc`.
            raise NotImplementedError(
                'Unavailable dataset {}'.format(args.data_train))

        self.features = nn.Sequential(*m)
        self.classifier = fc

        # Weights are only loaded when the default convolution is used
        # (i.e. when this is the child model); `self.load` is defined
        # elsewhere in the class.
        if conv3x3 == common.default_conv:
            self.load(args, strict=True)