コード例 #1
0
    def __init__(
        self,
        width_multiplier=1.,
        num_classes=1000,
        torchvision_pretrained=False,
        pretrained_num_classes=1000,
        fix_bn=False,
        partial_bn=False,
    ):
        """
        Build a torchvision MNASNet backbone of the requested width.

        :param width_multiplier: width multiplier; selects the torchvision
            constructor ``mnasnet0_5`` / ``mnasnet0_75`` / ``mnasnet1_0`` /
            ``mnasnet1_3`` for 0.5 / 0.75 / 1.0 / 1.3 respectively.
        :param num_classes: number of classes; passed to ``self.init_weights``
            together with ``pretrained_num_classes``.
        :param torchvision_pretrained: forwarded as ``pretrained=`` to the
            torchvision constructor.
        :param pretrained_num_classes: class count the torchvision model is
            constructed with (the assumed class count of the pretrained model).
        :param fix_bn: stored as ``self.fix_bn``; intended to freeze BN layers
            (consumed elsewhere, not visible here).
        :param partial_bn: stored as ``self.partial_bn``; intended to train only
            the first BN layer (consumed elsewhere, not visible here).
        :raises ValueError: if ``width_multiplier`` is not a supported width.
        """
        super(TorchvisionMNASNet, self).__init__()

        self.fix_bn = fix_bn
        self.partial_bn = partial_bn

        # Supported widths and their torchvision constructors. Entries are
        # compared with ``==`` in order, so any value equal to a listed width
        # (e.g. the int ``1``) selects that constructor.
        builders = (
            (0.5, mnasnet0_5),
            (0.75, mnasnet0_75),
            (1.0, mnasnet1_0),
            (1.3, mnasnet1_3),
        )
        builder = next(
            (fn for width, fn in builders if width_multiplier == width), None)
        if builder is None:
            raise ValueError(
                'no such value for width_multiplier: {!r} (supported: {})'.format(
                    width_multiplier,
                    ', '.join(str(width) for width, _ in builders)))

        # The backbone is built with the pretrained class count; init_weights
        # (defined elsewhere) receives both counts.
        self.model = builder(pretrained=torchvision_pretrained,
                             num_classes=pretrained_num_classes)

        self.init_weights(num_classes, pretrained_num_classes)
コード例 #2
0
ファイル: mnasnet_recognizer.py プロジェクト: lilujunai/ZCls
    def __init__(
        self,
        # width multiplier
        width_multiplier=1.,
        # number of classes
        num_classes=1000,
        # load torchvision pretrained weights
        torchvision_pretrained=False,
        # assumed class count of the pretrained model
        pretrained_num_classes=1000,
        # freeze BN
        fix_bn=False,
        # train only the first BN layer
        partial_bn=False,
    ):
        """
        Build a torchvision MNASNet backbone of the requested width.

        ``width_multiplier`` selects ``mnasnet0_5`` / ``mnasnet0_75`` /
        ``mnasnet1_0`` / ``mnasnet1_3`` for 0.5 / 0.75 / 1.0 / 1.3. The
        constructor receives ``pretrained=torchvision_pretrained`` and
        ``num_classes=pretrained_num_classes``. ``fix_bn`` and ``partial_bn``
        are only stored on the instance here (consumed elsewhere, not visible).
        Finally ``self.init_weights(num_classes, pretrained_num_classes)`` is
        called.

        Raises ValueError if ``width_multiplier`` is not a supported width.
        """
        super(TorchvisionMNASNet, self).__init__()

        self.fix_bn = fix_bn
        self.partial_bn = partial_bn

        # Supported widths and their torchvision constructors. Entries are
        # compared with ``==`` in order, so any value equal to a listed width
        # (e.g. the int ``1``) selects that constructor.
        builders = (
            (0.5, mnasnet0_5),
            (0.75, mnasnet0_75),
            (1.0, mnasnet1_0),
            (1.3, mnasnet1_3),
        )
        builder = next(
            (fn for width, fn in builders if width_multiplier == width), None)
        if builder is None:
            raise ValueError(
                'no such value for width_multiplier: {!r} (supported: {})'.format(
                    width_multiplier,
                    ', '.join(str(width) for width, _ in builders)))

        # The backbone is built with the pretrained class count; init_weights
        # (defined elsewhere) receives both counts.
        self.model = builder(pretrained=torchvision_pretrained,
                             num_classes=pretrained_num_classes)

        self.init_weights(num_classes, pretrained_num_classes)
コード例 #3
0
 def test_mnasnet(self):
     """Export-test torchvision ``mnasnet1_0`` on a constant all-ones batch.

     ``self.exportTest``, ``toC`` and ``BATCH_SIZE`` are defined elsewhere (not
     shown). Presumably the exported model's output is compared with PyTorch's
     within ``rtol``/``atol``; confirm against ``exportTest``.
     """
     # ``Variable`` is a no-op wrapper since PyTorch 0.4. A random tensor
     # filled with 1.0 is just a constant float32 tensor, so build it directly.
     x = torch.ones(BATCH_SIZE, 3, 224, 224)
     self.exportTest(toC(mnasnet1_0()), toC(x), rtol=1e-3, atol=1e-5)
コード例 #4
0
ファイル: train.py プロジェクト: JuneXia/proml
elif BACKBONE == 'resnet50':
    net = resnet50(num_classes=num_classes).to(device)
elif BACKBONE == 'densenet121':
    net = densenet121(num_classes=num_classes).to(device)
elif BACKBONE == 'mobilenet_v2':
    net = mobilenet_v2(num_classes=num_classes, width_mult=1.0, inverted_residual_setting=None, round_nearest=8).to(device)
elif BACKBONE == 'shufflenet_v2_x1_5':
    net = shufflenet_v2_x1_5(num_classes=num_classes).to(device)
elif BACKBONE == 'squeezenet1_0':
    net = squeezenet1_0(num_classes=num_classes).to(device)
elif BACKBONE == 'squeezenet1_1':
    net = squeezenet1_1(num_classes=num_classes).to(device)
elif BACKBONE == 'mnasnet0_5':
    net = mnasnet0_5(num_classes=num_classes).to(device)
elif BACKBONE == 'mnasnet1_0':
    net = mnasnet1_0(num_classes=num_classes).to(device)
elif BACKBONE == 'mobilenet_v1':
    net = MobileNetV1(num_classes=num_classes).to(device)
else:
    raise Exception('unknow backbone: {}'.format(BACKBONE))

for m in net.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight.data, 0, 0.1)