Пример #1
0
    def __init__(self,
                 arch: str = "shufflenet_v2_x2_0",
                 num_classes: int = 1000,
                 torchvision_pretrained: bool = False,
                 pretrained_num_classes: int = 1000,
                 fix_bn: bool = False,
                 partial_bn: bool = False):
        """Wrap a torchvision ShuffleNetV2 variant selected by name.

        Args:
            arch: One of 'shufflenet_v2_x2_0', 'shufflenet_v2_x1_5',
                'shufflenet_v2_x1_0' or 'shufflenet_v2_x0_5'.
            num_classes: Forwarded to ``self.init_weights`` (defined
                elsewhere in this class).
            torchvision_pretrained: Passed as ``pretrained`` to the selected
                torchvision constructor.
            pretrained_num_classes: Passed as ``num_classes`` to the
                torchvision constructor and also forwarded to
                ``self.init_weights``.
            fix_bn: Stored as ``self.fix_bn``; presumably consulted by other
                methods of this class (not visible here).
            partial_bn: Stored as ``self.partial_bn``; presumably consulted
                by other methods of this class (not visible here).

        Raises:
            ValueError: If ``arch`` is not one of the supported names.
        """
        super(TorchvisionShuffleNetV2, self).__init__()

        self.fix_bn = fix_bn
        self.partial_bn = partial_bn

        # Pick the constructor first and call it once below, so the shared
        # keyword arguments live in a single place. Each name is only looked
        # up in the branch that is actually taken.
        if arch == 'shufflenet_v2_x2_0':
            builder = shufflenet_v2_x2_0
        elif arch == 'shufflenet_v2_x1_5':
            builder = shufflenet_v2_x1_5
        elif arch == 'shufflenet_v2_x1_0':
            builder = shufflenet_v2_x1_0
        elif arch == 'shufflenet_v2_x0_5':
            builder = shufflenet_v2_x0_5
        else:
            raise ValueError(
                f"unsupported arch {arch!r}; expected one of "
                "'shufflenet_v2_x2_0', 'shufflenet_v2_x1_5', "
                "'shufflenet_v2_x1_0', 'shufflenet_v2_x0_5'")

        self.model = builder(pretrained=torchvision_pretrained,
                             num_classes=pretrained_num_classes)

        # NOTE(review): init_weights is defined elsewhere; presumably it adapts
        # the classifier from pretrained_num_classes to num_classes — confirm.
        self.init_weights(num_classes, pretrained_num_classes)
Пример #2
0
 def __init__(self, num_classes):
     """Build a pretrained ShuffleNetV2 x1.0 with a one-output sigmoid head.

     Args:
         num_classes: NOTE(review): accepted but unused — the head is
             hard-wired to a single sigmoid output; confirm this is intended.
     """
     super(Model, self).__init__()
     self.net = shufflenet_v2_x1_0(pretrained=True)
     num_feature = self.net.fc.in_features
     # torchvision's ShuffleNetV2.forward ends with ``self.fc`` and the model
     # has no ``classifier`` attribute, so the head must replace ``fc``.
     # Assigning to ``classifier`` left the original layer in use and the
     # sigmoid head was never executed.
     self.net.fc = nn.Sequential(
         nn.Linear(num_feature, 1),
         nn.Sigmoid()
     )
Пример #3
0
def get_shfflenet(pre_trained=True,
                  model_file='DenseDesc/shufflenet_v2_x1_0_best.pth.tar'):
    """Build a 365-class torchvision ShuffleNetV2 x1.0, optionally with weights.

    Args:
        pre_trained: If true, load weights from ``model_file``.
        model_file: Path to a checkpoint whose ``'state_dict'`` entry holds
            the model weights. Defaults to the previously hard-coded path.

    Returns:
        The ShuffleNetV2 model.
    """
    import torchvision.models.shufflenetv2 as shufflenet

    model_shuffle_net = shufflenet.shufflenet_v2_x1_0(num_classes=365)
    if pre_trained:
        # Deserialize every tensor onto the CPU, whatever device it was saved on.
        checkpoint = torch.load(model_file, map_location='cpu')
        # Keys presumably carry a leading 'module.' from an nn.DataParallel
        # wrapper. Strip only that leading prefix; str.replace would also
        # mangle keys that contain 'module.' further in.
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint['state_dict'].items()
        }
        model_shuffle_net.load_state_dict(state_dict)
    return model_shuffle_net
def get_shufflenet(load_ckpt=True,
                   model_file='Place365/shufflenet_v2_x1_0_best.pth.tar'):
    """Build ShuffleNet V2 x1.0 with a 365-class head, for use as a base net.

    Loading the checkpoint is best-effort: if the file cannot be opened, a
    message is printed and the randomly initialized model is returned.

    Args:
        load_ckpt: If true, try to load weights from ``model_file``.
        model_file: Path to a checkpoint whose ``'state_dict'`` entry holds
            the weights (presumably trained on Places365, per the original
            description). Defaults to the previously hard-coded path.

    Returns:
        The ShuffleNetV2 model.
    """
    # `shufflenet` was not bound anywhere in scope for this function; import
    # it locally, as the sibling loader does.
    import torchvision.models.shufflenetv2 as shufflenet

    model_shuffle_net = shufflenet.shufflenet_v2_x1_0(num_classes=365)
    if load_ckpt:
        try:
            checkpoint = torch.load(model_file, map_location='cpu')
        except IOError:
            print('Fail to load checkpoint, because no such file:', model_file)
        else:
            # Strip only a leading 'module.' (presumably from nn.DataParallel);
            # str.replace would also rewrite it in the middle of a key.
            prefix = 'module.'
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in checkpoint['state_dict'].items()
            }
            model_shuffle_net.load_state_dict(state_dict)
    return model_shuffle_net
Пример #5
0
validate_set = torchvision.datasets.ImageFolder(
    root=conf["dataset"]["validate"], transform=transform_list)
# Validation does not benefit from a random order; a fixed order keeps
# evaluation runs reproducible.
validate_loader = torch.utils.data.DataLoader(
    validate_set,
    batch_size=conf["parameters"]["batch_size"],
    shuffle=False,
    num_workers=4)

# Both splits must map folder names to the same label indices, otherwise
# validation metrics are meaningless. An explicit raise (unlike `assert`)
# still runs under `python -O`.
if train_set.class_to_idx != validate_set.class_to_idx:
    raise ValueError(
        "train and validate datasets have different class mappings: "
        f"{train_set.class_to_idx} != {validate_set.class_to_idx}")

# Model
model_name = args.model if args.model else conf["parameters"]["model_name"]
print(f">> Building {model_name} model...")
# Constructors are wrapped in lambdas so that only the selected model is
# built (and only its pretrained weights loaded), rather than all of them.
model_builders = {
    "shufflenet": lambda: shufflenetv2.shufflenet_v2_x1_0(pretrained=True),
    "mobilenet": lambda: mobilenet.mobilenet_v2(pretrained=True),
    "vgg": lambda: vgg11_bn(pretrained=True),
    "resnet": lambda: resnet18(pretrained=True),
    "crnn": lambda: crnn.CRNN(),
}
if model_name not in model_builders:
    raise ValueError(f"unknown model {model_name!r}; "
                     f"expected one of {sorted(model_builders)}")
model = model_builders[model_name]()
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),
                      lr=conf["parameters"]["learning_rate"],
                      momentum=0.9,
                      weight_decay=5e-4)