Example #1
def Split_ResNet101(cfg, progress=True):
    # ResNet-101: Bottleneck blocks, [3, 4, 23, 3] blocks per stage.
    model = ResNet(cfg, get_builder(cfg), Bottleneck, [3, 4, 23, 3])
    if cfg.pretrained == 'imagenet':
        arch = 'resnet101'
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        # strict=False tolerates keys that differ between the stock
        # checkpoint and this split/slim variant of the architecture.
        load_state_dict(model, state_dict, strict=False)
    return model
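
A minimal usage sketch. The Cfg stub below is hypothetical: the real cfg object comes from this repo's own config system and carries more fields (e.g. slim_factor, logger); only the pretrained field is exercised here. The same pattern applies to Split_ResNet18 in Example #2.

import torch

class Cfg:                     # hypothetical stand-in for the repo's config object
    pretrained = 'imagenet'    # any other value skips the checkpoint download

model = Split_ResNet101(Cfg(), progress=True)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # standard ImageNet input size
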
Example #2
def Split_ResNet18(cfg, progress=True):
    # ResNet-18: BasicBlock, [2, 2, 2, 2] blocks per stage.
    model = ResNet(cfg, get_builder(cfg), BasicBlock, [2, 2, 2, 2])
    if cfg.pretrained == 'imagenet':
        arch = 'resnet18'
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        load_state_dict(model, state_dict, strict=False)
    return model
Example #3
def Split_densenet201(cfg, pretrained=False, progress=True, **kwargs):
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        memory_efficient (bool): If True, uses checkpointing. Much more memory efficient,
            but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """
    # Positional args: growth_rate=32, block_config=(6, 12, 48, 32), num_init_features=64.
    return _densenet(cfg, get_builder(cfg), 'densenet201', 32, (6, 12, 48, 32),
                     64, pretrained, progress, **kwargs)
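
As a quick check on the name, the standard DenseNet depth convention (one stem conv, two convs per bottleneck dense layer, one conv per transition, plus the classifier) recovers 201 from the block config above. This arithmetic is illustrative, not code from the repo:

block_config = (6, 12, 48, 32)        # dense layers per dense block
depth = (1                            # initial 7x7 stem conv
         + 2 * sum(block_config)      # each dense layer = 1x1 + 3x3 conv
         + (len(block_config) - 1)    # one 1x1 conv per transition
         + 1)                         # final classifier
print(depth)                          # 201
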
Example #4
    def __init__(self, cfg, num_classes=1000, aux_logits=False, transform_input=False, init_weights=None,
                 blocks=None):  # AT: aux_logits is disabled by default
        super(GoogLeNet, self).__init__()
        if blocks is None:
            blocks = [BasicConv2d, Inception, InceptionAux]
        if init_weights is None:
            warnings.warn('The default weight initialization of GoogLeNet will be changed in future releases of '
                          'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
                          ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
            init_weights = True
        assert len(blocks) == 3

        builder = get_builder(cfg)
        slim_factor = cfg.slim_factor
        if slim_factor < 1:
            cfg.logger.info('WARNING: You are using a slim network')

        conv_block = blocks[0]
        inception_block = blocks[1]
        inception_aux_block = blocks[2]

        self.aux_logits = aux_logits
        self.transform_input = transform_input

        slim = lambda x: math.ceil(x * slim_factor)
        self.conv1 = conv_block(builder, 3, slim(64), kernel_size=7, stride=2)  # padding=3
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = conv_block(builder, slim(64), slim(64), kernel_size=1)
        self.conv3 = conv_block(builder, slim(64), slim(192), kernel_size=3)  # padding=1
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = inception_block(builder, slim(192), slim(64), slim(96), slim(128),
                                           slim(16), slim(32), slim(32))
        prev_out_channels = self.inception3a.out_channels  # 256 when slim_factor == 1
        concat_order = self.inception3a.concat_order

        self.inception3b = inception_block(builder, prev_out_channels, slim(128), slim(128), slim(192),
                                           slim(32), slim(96), slim(64), in_channels_order=concat_order)
        prev_out_channels = self.inception3b.out_channels  # 480 when slim_factor == 1
        concat_order = self.inception3b.concat_order

        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = inception_block(builder, prev_out_channels, slim(192), slim(96), slim(208),
                                           slim(16), slim(48), slim(64), in_channels_order=concat_order)
        prev_out_channels = self.inception4a.out_channels  # 512 when slim_factor == 1
        concat_order = self.inception4a.concat_order

        self.inception4b = inception_block(builder, prev_out_channels, slim(160), slim(112), slim(224),
                                           slim(24), slim(64), slim(64), in_channels_order=concat_order)
        prev_out_channels = self.inception4b.out_channels  # 512 when slim_factor == 1
        concat_order = self.inception4b.concat_order

        self.inception4c = inception_block(builder, prev_out_channels, slim(128), slim(128), slim(256),
                                           slim(24), slim(64), slim(64), in_channels_order=concat_order)
        prev_out_channels = self.inception4c.out_channels  # 512 when slim_factor == 1
        concat_order = self.inception4c.concat_order

        self.inception4d = inception_block(builder, prev_out_channels, slim(112), slim(144), slim(288),
                                           slim(32), slim(64), slim(64), in_channels_order=concat_order)
        prev_out_channels = self.inception4d.out_channels  # 528 when slim_factor == 1
        concat_order = self.inception4d.concat_order

        self.inception4e = inception_block(builder, prev_out_channels, slim(256), slim(160), slim(320),
                                           slim(32), slim(128), slim(128), in_channels_order=concat_order)
        prev_out_channels = self.inception4e.out_channels  # 832 when slim_factor == 1
        concat_order = self.inception4e.concat_order

        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = inception_block(builder, prev_out_channels, slim(256), slim(160), slim(320),
                                           slim(32), slim(128), slim(128), in_channels_order=concat_order)
        prev_out_channels = self.inception5a.out_channels  # 832 when slim_factor == 1
        concat_order = self.inception5a.concat_order

        self.inception5b = inception_block(builder, prev_out_channels, slim(384), slim(192), slim(384),
                                           slim(48), slim(128), slim(128), in_channels_order=concat_order)
        prev_out_channels = self.inception5b.out_channels  # 1024 when slim_factor == 1
        concat_order = self.inception5b.concat_order

        if aux_logits:
            # Aux heads tap inception4a (512 ch) and inception4d (528 ch);
            # note these hardcoded widths are not scaled by slim_factor.
            self.aux1 = inception_aux_block(builder, 512, num_classes)
            self.aux2 = inception_aux_block(builder, 528, num_classes)
        else:
            self.aux1 = None
            self.aux2 = None

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.2)
        # self.fc = nn.Linear(1024, num_classes)
        self.fc = builder.linear(prev_out_channels, num_classes, last_layer=True,
                                 in_channels_order=concat_order)

        if init_weights:
            self._initialize_weights()
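
For reference, the slim helper above rounds scaled channel counts up with math.ceil, so even narrow branches keep at least one channel. A small illustrative check (the slim_factor value here is an arbitrary example; in the model it comes from cfg.slim_factor):

import math

slim_factor = 0.5                             # example value only
slim = lambda x: math.ceil(x * slim_factor)
print(slim(64), slim(192), slim(16))          # 32 96 8 -- ceil keeps branches >= 1 channel
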