# Example #1
    def __init__(self, args, conv3x3=common.default_conv, conv1x1=None):
        """Build a DenseNet whose Dense layers share one learned filter basis.

        Args:
            args: one-element sequence; ``args[0]`` carries the
                hyper-parameters (``depth``, ``k``, ``n_basis``, ``n_group``,
                ``k_size1``, ``bottleneck``, ``reduction``,
                ``transition_group``, ``n_colors``, ``data_train``).
            conv3x3: constructor used for plain 3x3 convolutions.
            conv1x1: unused; kept for signature compatibility with siblings.

        Raises:
            ValueError: if ``args.data_train`` is neither ``'ImageNet'`` nor
                a ``'CIFAR*'`` name (previously this fell through and raised
                a confusing NameError at the classifier below).
        """
        super(DenseNet_Basis, self).__init__()
        args = args[0]
        n_basis = args.n_basis
        n_group = args.n_group
        kernel_size = args.k_size1

        # depth = 3 * n_blocks + 4; a bottleneck layer spends two convs,
        # so it halves the number of layers per dense block.
        n_blocks = (args.depth - 4) // 3
        if args.bottleneck:
            n_blocks //= 2

        k = args.k      # growth rate: channels added by each dense layer
        c_in = 2 * k    # channels entering the first dense block

        # Per-group size of the shared basis filters.
        basis_size = k * (n_blocks * 3 + 1) // n_group
        transition_group = args.transition_group

        def _dense_block(basis, in_channels):
            # One dense block: n_blocks layers, each growing input by k.
            module_list = []
            for _ in range(n_blocks):
                if args.bottleneck:
                    module_list.append(BottleNeck(in_channels, k, conv=conv3x3))
                else:
                    module_list.append(
                        Dense(basis, in_channels, k, conv=BasicBlock, args=args))
                in_channels += k

            return nn.Sequential(*module_list)

        # Filter basis shared by every Dense layer in the network.
        self.basis = nn.Parameter(nn.init.kaiming_uniform_(
            torch.Tensor(n_basis, basis_size, kernel_size, kernel_size)))

        module_list = []
        module_list.append(conv3x3(args.n_colors, c_in, 3, bias=False))

        for i in range(3):
            module_list.append(_dense_block(self.basis, c_in))
            c_in += k * n_blocks
            if i < 2:
                # Compress channels between dense blocks.
                c_out = int(math.floor(args.reduction * c_in))
                module_list.append(Transition(c_in, transition_group, c_out))
                c_in = c_out

        module_list.append(common.default_norm(c_in))
        module_list.append(common.default_act())
        module_list.append(nn.AvgPool2d(8))
        self.features = nn.Sequential(*module_list)

        if args.data_train == 'ImageNet':
            n_classes = 1000
        elif args.data_train.find('CIFAR') >= 0:
            # e.g. 'CIFAR10' -> 10, 'CIFAR100' -> 100
            n_classes = int(args.data_train[5:])
        else:
            # Fail fast with a clear message instead of an unbound-name error.
            raise ValueError(
                'unsupported dataset: {}'.format(args.data_train))

        self.classifier = nn.Linear(c_in, n_classes)

        common.init_kaiming(self)
    def __init__(self, args, conv3x3=common.default_conv, conv1x1=None):
        """Build a DenseNet whose Dense layers use grouped convolutions.

        Args:
            args: one-element sequence; ``args[0]`` carries the
                hyper-parameters (``depth``, ``k``, ``bottleneck``,
                ``group_size``, ``reduction``, ``n_colors``, ``data_train``).
            conv3x3: constructor used for 3x3 convolutions.
            conv1x1: unused; kept for signature compatibility with siblings.

        Raises:
            ValueError: if ``args.data_train`` is neither ``'ImageNet'`` nor
                a ``'CIFAR*'`` name (previously this fell through and raised
                a confusing NameError at the classifier below).
        """
        super(DenseNet_Group, self).__init__()
        args = args[0]
        # depth = 3 * n_blocks + 4; a bottleneck layer spends two convs.
        n_blocks = (args.depth - 4) // 3
        if args.bottleneck:
            n_blocks //= 2

        k = args.k      # growth rate: channels added by each dense layer
        c_in = 2 * k    # channels entering the first dense block

        def _dense_block(in_channels):
            # One dense block: n_blocks layers, each growing input by k.
            module_list = []
            for _ in range(n_blocks):
                if args.bottleneck:
                    module_list.append(BottleNeck(in_channels, k,
                                                  conv=conv3x3))
                else:
                    # Narrow layers use a small fixed group size; wider
                    # layers switch to the configured group size.
                    group_size = 3 if in_channels <= 252 else args.group_size
                    module_list.append(
                        Dense(in_channels,
                              k,
                              group_size=group_size,
                              conv=conv3x3))
                in_channels += k

            return nn.Sequential(*module_list)

        module_list = []
        module_list.append(conv3x3(args.n_colors, c_in, 3, bias=False))

        for i in range(3):
            module_list.append(_dense_block(c_in))
            c_in += k * n_blocks
            if i < 2:
                # Compress channels between dense blocks.
                c_out = int(math.floor(args.reduction * c_in))
                module_list.append(Transition(c_in, c_out))
                c_in = c_out

        module_list.append(common.default_norm(c_in))
        module_list.append(common.default_act())
        module_list.append(nn.AvgPool2d(8))
        self.features = nn.Sequential(*module_list)

        if args.data_train == 'ImageNet':
            n_classes = 1000
        elif args.data_train.find('CIFAR') >= 0:
            # e.g. 'CIFAR10' -> 10, 'CIFAR100' -> 100
            n_classes = int(args.data_train[5:])
        else:
            # Fail fast with a clear message instead of an unbound-name error.
            raise ValueError(
                'unsupported dataset: {}'.format(args.data_train))

        self.classifier = nn.Linear(c_in, n_classes)

        common.init_kaiming(self)
# Example #3
    def load(self, args, strict=True):
        """Initialize weights from a download, a checkpoint, or Kaiming init.

        For ImageNet: download a torchvision ResNet or load ``args.extend``,
        then copy its parameters into this model positionally; otherwise use
        Kaiming initialization. For other datasets: load ``args.pretrain``
        directly if given.

        Args:
            args: parsed options providing ``data_train``, ``pretrain``,
                ``extend`` and ``depth``.
            strict: passed to ``load_state_dict`` in the non-ImageNet path.
        """
        if args.data_train == 'ImageNet':
            if args.pretrain == 'download' or args.extend == 'download':
                # torchvision constructors take `pretrained`, not `pretrain`
                # (fixes a TypeError; matches the loader used elsewhere in
                # this file).
                state = getattr(models,
                                'resnet{}'.format(args.depth))(pretrained=True)
            elif args.extend:
                state = torch.load(args.extend)
            else:
                common.init_kaiming(self)
                return

            # Copy positionally; assumes both state dicts enumerate tensors
            # in the same order -- TODO confirm for non-matching topologies.
            source = state.state_dict()
            target = self.state_dict()
            for s, t in zip(source.keys(), target.keys()):
                target[t].copy_(source[s])
        else:
            if args.pretrain:
                self.load_state_dict(torch.load(args.pretrain), strict=strict)
# Example #4
    def __init__(self,
                 args,
                 conv3x3=common.default_conv,
                 conv1x1=common.default_conv):
        """Build a group-convolution ResNet for CIFAR or ImageNet.

        Args:
            args: one-element sequence; ``args[0]`` carries the
                hyper-parameters (``depth``, ``data_train``, ``n_colors``,
                ``kernel_size``, ``pretrained``, ``extend``).
            conv3x3: constructor used for 3x3 convolutions.
            conv1x1: constructor used for 1x1 convolutions (ImageNet path).

        Raises:
            ValueError: if ``args.data_train`` is neither a ``'CIFAR*'`` name
                nor ``'ImageNet'`` (previously this fell through and raised a
                confusing NameError when building ``self.features``).
        """
        super(ResNet_Group, self).__init__()
        args = args[0]
        self.args = args
        m = []
        if args.data_train.find('CIFAR') >= 0:
            self.expansion = 1

            # depth = 6 * n_blocks + 2 for the CIFAR ResNet family.
            self.n_blocks = (args.depth - 2) // 6
            self.in_channels = 16
            self.downsample_type = 'A'
            # e.g. 'CIFAR10' -> 10, 'CIFAR100' -> 100
            n_classes = int(args.data_train[5:])

            kwargs = {
                'kernel_size': args.kernel_size,
                'conv3x3': conv3x3,
            }
            m.append(common.BasicBlock(args.n_colors, 16, **kwargs))
            # Residual stages use the grouped BasicBlock instead.
            kwargs['conv3x3'] = BasicBlock
            m.append(self.make_layer(16, self.n_blocks, **kwargs))
            m.append(self.make_layer(32, self.n_blocks, stride=2, **kwargs))
            m.append(self.make_layer(64, self.n_blocks, stride=2, **kwargs))
            m.append(nn.AvgPool2d(8))

            fc = nn.Linear(64 * self.expansion, n_classes)

        elif args.data_train == 'ImageNet':
            # depth -> (blocks per stage, block type, channel expansion)
            block_config = {
                18: ([2, 2, 2, 2], ResBlock, 1),
                34: ([3, 4, 6, 3], ResBlock, 1),
                50: ([3, 4, 6, 3], BottleNeck, 4),
                101: ([3, 4, 23, 3], BottleNeck, 4),
                152: ([3, 8, 36, 3], BottleNeck, 4)
            }
            n_blocks, self.block, self.expansion = block_config[args.depth]

            self.in_channels = 64
            self.downsample_type = 'C'
            n_classes = 1000
            kwargs = {
                'conv3x3': conv3x3,
                'conv1x1': conv1x1,
            }
            m.append(
                common.BasicBlock(args.n_colors,
                                  64,
                                  7,
                                  stride=2,
                                  conv3x3=conv3x3,
                                  bias=False))
            m.append(nn.MaxPool2d(3, 2, padding=1))
            m.append(self.make_layer(64, n_blocks[0], 3, **kwargs))
            m.append(self.make_layer(128, n_blocks[1], 3, stride=2, **kwargs))
            m.append(self.make_layer(256, n_blocks[2], 3, stride=2, **kwargs))
            m.append(self.make_layer(512, n_blocks[3], 3, stride=2, **kwargs))
            m.append(nn.AvgPool2d(7, 1))

            fc = nn.Linear(512 * self.expansion, n_classes)

        else:
            # Fail fast with a clear message instead of an unbound-name error.
            raise ValueError(
                'unsupported dataset: {}'.format(args.data_train))

        self.features = nn.Sequential(*m)
        self.classifier = fc

        # Weight loading happens here only for the parent model (default conv);
        # child/derived models initialize elsewhere.
        if conv3x3 == common.default_conv:
            if args.pretrained == 'download' or args.extend == 'download':
                state = getattr(models,
                                'resnet{}'.format(args.depth))(pretrained=True)
            elif args.extend:
                state = torch.load(args.extend)
            else:
                common.init_kaiming(self)
                return

            # Copy positionally; assumes both state dicts enumerate tensors
            # in the same order -- TODO confirm.
            source = state.state_dict()
            target = self.state_dict()
            for s, t in zip(source.keys(), target.keys()):
                target[t].copy_(source[s])
# Example #5
    def __init__(self, args, conv3x3=common.default_conv):
        """Build a CIFAR ResNet whose convolutions are decomposed over
        per-stage filter bases (blockwise basis sharing).

        Args:
            args: one-element sequence; ``args[0]`` carries the
                hyper-parameters (``depth``, ``data_train``, ``n_colors``,
                ``kernel_size``, ``k_size2``, ``n_basis1..3``,
                ``basis_size1..3``, ``pretrained``, ``extend``).
            conv3x3: constructor used for 3x3 convolutions.

        Raises:
            ValueError: if ``args.data_train`` is not a ``'CIFAR*'`` name
                (previously this fell through and raised a confusing
                NameError when building the classifier).
        """
        super(ResNet_Basis_Blockwise, self).__init__()
        args = args[0]
        self.args = args
        m = []
        if args.data_train.find('CIFAR') >= 0:
            self.expansion = 1
            self.n_basis1 = args.n_basis1
            self.n_basis2 = args.n_basis2
            self.n_basis3 = args.n_basis3
            self.basis_size1 = args.basis_size1
            self.basis_size2 = args.basis_size2
            self.basis_size3 = args.basis_size3
            # depth = 6 * n_blocks + 2 for the CIFAR ResNet family.
            self.n_blocks = (args.depth - 2) // 6
            self.k_size = args.kernel_size
            self.in_channels = 16
            self.downsample_type = 'A'
            # One learned filter basis per residual stage.
            self.basis1 = nn.Parameter(
                nn.init.kaiming_uniform_(
                    torch.Tensor(self.n_basis1, self.basis_size1, self.k_size,
                                 self.k_size)))
            self.basis2 = nn.Parameter(
                nn.init.kaiming_uniform_(
                    torch.Tensor(self.n_basis2, self.basis_size2, self.k_size,
                                 self.k_size)))
            self.basis3 = nn.Parameter(
                nn.init.kaiming_uniform_(
                    torch.Tensor(self.n_basis3, self.basis_size3, self.k_size,
                                 self.k_size)))
            self.block = ResBlockDecom

            # e.g. 'CIFAR10' -> 10, 'CIFAR100' -> 100
            n_classes = int(args.data_train[5:])
            kwargs = {'kernel_size': args.kernel_size, 'conv3x3': conv3x3}
            m.append(common.BasicBlock(args.n_colors, 16, **kwargs))

            # Stage 1: both layers of each block draw from basis1.
            kwargs['conv3x3'] = BasicBlock
            kwargs['kernel_size2'] = args.k_size2
            kwargs['n_basis'] = self.n_basis1
            kwargs['basis_size'] = self.basis_size1
            kwargs['basis_l1'] = self.basis1
            kwargs['basis_l2'] = self.basis1
            m.append(self.make_layer(16, self.n_blocks, **kwargs))

            # Stage 2: first layer keeps basis1, second switches to basis2.
            kwargs['n_basis'] = self.n_basis2
            kwargs['basis_size'] = self.basis_size2
            kwargs['basis_l1'] = self.basis1
            kwargs['basis_l2'] = self.basis2
            m.append(self.make_layer(32, self.n_blocks, stride=2, **kwargs))

            # Stage 3: first layer keeps basis2, second switches to basis3.
            kwargs['n_basis'] = self.n_basis3
            kwargs['basis_size'] = self.basis_size3
            kwargs['basis_l1'] = self.basis2
            kwargs['basis_l2'] = self.basis3
            m.append(self.make_layer(64, self.n_blocks, stride=2, **kwargs))
            m.append(nn.AvgPool2d(8))

            fc = nn.Linear(64 * self.expansion, n_classes)

        else:
            # Fail fast with a clear message instead of an unbound-name error.
            raise ValueError(
                'unsupported dataset: {}'.format(args.data_train))

        self.features = nn.Sequential(*m)
        self.classifier = fc

        # Weight loading happens here only for the parent model (default conv);
        # child/derived models initialize elsewhere.
        if conv3x3 == common.default_conv:
            if args.pretrained == 'download' or args.extend == 'download':
                state = getattr(models,
                                'resnet{}'.format(args.depth))(pretrained=True)
            elif args.extend:
                state = torch.load(args.extend)
            else:
                common.init_kaiming(self)
                return

            # Copy positionally; assumes both state dicts enumerate tensors
            # in the same order -- TODO confirm.
            source = state.state_dict()
            target = self.state_dict()
            for s, t in zip(source.keys(), target.keys()):
                target[t].copy_(source[s])