def __init__(
    self,
    width_multiplier=1.,
    num_classes=1000,
    torchvision_pretrained=False,
    pretrained_num_classes=1000,
    fix_bn=False,
    partial_bn=False,
):
    """
    Build the mnasnet model that matches ``width_multiplier`` and initialise its weights.

    :param width_multiplier: width multiplier; must compare equal to 0.5, 0.75, 1.0 or 1.3
    :param num_classes: number of classes (passed on to ``init_weights``)
    :param torchvision_pretrained: passed as ``pretrained`` to the model builder
    :param pretrained_num_classes: number of classes assumed for the pretrained model;
        passed as ``num_classes`` to the model builder and on to ``init_weights``
    :param fix_bn: fix (freeze) BN; only stored on the instance here
    :param partial_bn: train only the first BN layer; only stored on the instance here
    :raises ValueError: if ``width_multiplier`` matches none of the supported values
    """
    super(TorchvisionMNASNet, self).__init__()

    self.fix_bn = fix_bn
    self.partial_bn = partial_bn

    # Supported widths paired with their builders, checked in this order.
    builders = (
        (0.5, mnasnet0_5),
        (0.75, mnasnet0_75),
        (1.0, mnasnet1_0),
        (1.3, mnasnet1_3),
    )
    for supported_width, build in builders:
        if width_multiplier == supported_width:
            self.model = build(
                pretrained=torchvision_pretrained,
                num_classes=pretrained_num_classes,
            )
            break
    else:
        # The loop finished without a match.
        raise ValueError('no such value')

    self.init_weights(num_classes, pretrained_num_classes)
def __init__(
    self,
    # width multiplier
    width_multiplier=1.,
    # number of classes
    num_classes=1000,
    # whether to use the pretrained model
    torchvision_pretrained=False,
    # number of classes assumed for the pretrained model
    pretrained_num_classes=1000,
    # fix (freeze) BN
    fix_bn=False,
    # train only the first BN layer
    partial_bn=False,
):
    """Build the mnasnet model selected by ``width_multiplier``, then call ``init_weights``."""
    super(TorchvisionMNASNet, self).__init__()

    self.fix_bn, self.partial_bn = fix_bn, partial_bn

    # Supported widths and their builders; the first equal entry wins.
    width_to_builder = [
        (0.5, mnasnet0_5),
        (0.75, mnasnet0_75),
        (1.0, mnasnet1_0),
        (1.3, mnasnet1_3),
    ]
    build = next(
        (fn for width, fn in width_to_builder if width_multiplier == width),
        None,
    )
    if build is None:
        raise ValueError('no such value')

    # The builder receives the *pretrained* class count; the requested
    # num_classes is only handed to init_weights below.
    self.model = build(pretrained=torchvision_pretrained, num_classes=pretrained_num_classes)

    self.init_weights(num_classes, pretrained_num_classes)
def test_mnasnet(self):
    """Run ``exportTest`` on ``mnasnet1_0`` with an all-ones input batch."""
    # randn only provides the (BATCH_SIZE, 3, 224, 224) tensor; fill_ then
    # overwrites every element with 1.0.
    ones_batch = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
    x = Variable(ones_batch)

    # Build the model after the input, as the original expression order did.
    model = toC(mnasnet1_0())
    self.exportTest(model, toC(x), rtol=1e-3, atol=1e-5)
elif BACKBONE == 'resnet50': net = resnet50(num_classes=num_classes).to(device) elif BACKBONE == 'densenet121': net = densenet121(num_classes=num_classes).to(device) elif BACKBONE == 'mobilenet_v2': net = mobilenet_v2(num_classes=num_classes, width_mult=1.0, inverted_residual_setting=None, round_nearest=8).to(device) elif BACKBONE == 'shufflenet_v2_x1_5': net = shufflenet_v2_x1_5(num_classes=num_classes).to(device) elif BACKBONE == 'squeezenet1_0': net = squeezenet1_0(num_classes=num_classes).to(device) elif BACKBONE == 'squeezenet1_1': net = squeezenet1_1(num_classes=num_classes).to(device) elif BACKBONE == 'mnasnet0_5': net = mnasnet0_5(num_classes=num_classes).to(device) elif BACKBONE == 'mnasnet1_0': net = mnasnet1_0(num_classes=num_classes).to(device) elif BACKBONE == 'mobilenet_v1': net = MobileNetV1(num_classes=num_classes).to(device) else: raise Exception('unknow backbone: {}'.format(BACKBONE)) for m in net.modules(): if isinstance(m, nn.Conv2d): nn.init.xavier_normal_(m.weight.data) if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() elif isinstance(m, nn.Linear): nn.init.normal_(m.weight.data, 0, 0.1)