def mobilenet_v2(pretrained='scratch', **kwargs):
    """Build a MobileNetV2 model with a choice of pre-training.

    Args:
        pretrained (str): 'imagenet' to transfer weights from the full
            ImageNet classifier, 'coco' to load a local COCO checkpoint,
            or 'scratch' for no pre-trained weights.
        **kwargs: forwarded to the MobileNetV2 constructor.

    Returns:
        MobileNetV2: the (optionally initialized) model.

    Raises:
        NotImplementedError: if `pretrained` is not one of the three
            accepted values.
    """
    model = MobileNetV2(**kwargs)
    if pretrained == 'imagenet':
        print('loading pre-trained imagenet model')
        model_full = mobilenet_v2_imagenet.mobilenet_v2(pretrained=True)
        model.load_pretrained(model_full)
    elif pretrained == 'coco':
        print('loading pre-trained COCO model')
        # Load checkpoint from the local models directory; map to CPU so
        # loading works on machines without the original GPU layout.
        checkpoint = torch.load(
            os.path.join(Path.models_dir(), 'mobilenet_v2_coco_80.pth'),
            map_location=lambda storage, loc: storage)
        # Handle checkpoints saved from a DataParallel wrapper: strip the
        # 'module.' prefix from every parameter name before loading.
        if 'module.' in list(checkpoint.keys())[0]:
            new_state_dict = OrderedDict()
            for k, v in checkpoint.items():
                name = k.replace('module.', '')  # remove `module.`
                new_state_dict[name] = v
        else:
            new_state_dict = checkpoint
        model.load_state_dict(new_state_dict)
    elif pretrained == 'scratch':
        print('using imagenet initialized from scratch')
    else:
        # BUG FIX: the original message omitted the supported 'coco'
        # option, misleading callers about valid choices.
        raise NotImplementedError(
            "select one of 'imagenet', 'coco' or 'scratch' for pre-training")
    return model
def Res_Deeplab(n_classes=21, pretrained=None):
    """Build a multi-scale DeepLab-ResNet (MS_Deeplab) model.

    Args:
        n_classes (int): number of output classes.
        pretrained (str or None): None for random initialization, 'voc'
            for the VOC-trained checkpoint, or 'ms_coco' for the
            COCO-initialized checkpoint.

    Returns:
        MS_Deeplab: the (optionally initialized) model.

    Raises:
        ValueError: if `pretrained` is an unrecognized string.
    """
    model = MS_Deeplab(Bottleneck, n_classes)
    if pretrained is not None:
        if pretrained == 'voc':
            pth_model = 'MS_DeepLab_resnet_trained_VOC.pth'
        elif pretrained == 'ms_coco':
            pth_model = 'MS_DeepLab_resnet_pretrained_COCO_init.pth'
        else:
            # BUG FIX: any other string previously fell through with
            # `pth_model` unbound, producing a confusing NameError at the
            # torch.load call below.
            raise ValueError(
                "pretrained must be None, 'voc' or 'ms_coco', got: "
                + repr(pretrained))
        saved_state_dict = torch.load(
            os.path.join(Path.models_dir(), pth_model),
            map_location=lambda storage, loc: storage)
        if n_classes != 21:
            # The checkpoint classifier (layer5) was trained for 21
            # classes; keep the model's freshly initialized layer5
            # weights instead of the checkpoint's.
            for key in saved_state_dict:
                if key.split('.')[1] == 'layer5':
                    saved_state_dict[key] = model.state_dict()[key]
        model.load_state_dict(saved_state_dict)
    return model
def resnet26(pretrained=False, remote=True):
    """Constructs a ResNet-26 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        remote (bool): If True, download the checkpoint; otherwise read
            'resnet26.pth' from the local models directory.
    """
    model = ResNet_C5(Bottleneck, [2, 2, 2, 2])
    if not pretrained:
        return model

    # Plain ImageNet ResNet-26 that will receive the checkpoint weights.
    model_IN = resnet.ResNet(block=Bottleneck, layers=[2, 2, 2, 2],
                             num_classes=1000)

    # Fetch the checkpoint, remotely or from disk.
    if remote:
        checkpoint = load_state_dict_from_url(model_urls['resnet26'],
                                              map_location='cpu',
                                              progress=True)
    else:
        checkpoint = torch.load(
            os.path.join(Path.models_dir(), 'resnet26.pth'),
            map_location=lambda storage, loc: storage)
    checkpoint = checkpoint['model_state']

    # Strip the 'module.' prefix left by DataParallel, if present.
    first_key = next(iter(checkpoint))
    if 'module.' in first_key:
        state_dict = OrderedDict(
            (key.replace('module.', ''), value)
            for key, value in checkpoint.items())
    else:
        state_dict = checkpoint

    model_IN.load_state_dict(state_dict)
    # Transfer the ImageNet weights into the dense-labelling network.
    model.load_pretrained(model_IN)
    return model
def se_mobilenet_v2(pretrained=False, features=False, n_class=1000,
                    last_channel=1280, remote=True):
    """Build an SE-MobileNet-v2, optionally the features-only variant.

    Args:
        pretrained (bool): load ImageNet pre-trained weights.
        features (bool): return the feature-extractor variant instead of
            the full classifier.
        n_class (int): number of classifier outputs.
        last_channel (int): width of the final channel dimension.
        remote (bool): download the checkpoint vs. read it locally.
    """
    model_cls = SEMobileNetV2Features if features else SEMobileNetV2
    model = model_cls(n_class=n_class, last_channel=last_channel)

    if not pretrained:
        return model

    print('Loading Imagenet pre-trained SE-MobileNet-v2')
    # Fetch the checkpoint, remotely or from the local models directory.
    if remote:
        checkpoint = load_state_dict_from_url(
            model_urls['se_mobilenet_v2_1280'], map_location='cpu',
            progress=True)
    else:
        checkpoint = torch.load(
            os.path.join(Path.models_dir(), 'se_mobilenet_v2_1280.pth'),
            map_location=lambda storage, loc: storage)
    checkpoint = checkpoint['model_state']

    # Strip the 'module.' prefix left by DataParallel, if present.
    first_key = next(iter(checkpoint))
    if 'module.' in first_key:
        state_dict = OrderedDict(
            (key.replace('module.', ''), value)
            for key, value in checkpoint.items())
    else:
        state_dict = checkpoint

    model.load_state_dict(state_dict)
    return model
def get_state_dict_se(model_name, remote=True):
    """Fetch a checkpoint's state dict, normalizing DataParallel names.

    Args:
        model_name (str): key into `model_urls` (remote) or the basename
            of the local '.pth' file.
        remote (bool): download vs. read from the local models directory.

    Returns:
        dict: the 'model_state' mapping with any 'module.' prefixes
        stripped from parameter names.
    """
    # Fetch the checkpoint, remotely or from disk.
    if remote:
        checkpoint = load_state_dict_from_url(model_urls[model_name],
                                              map_location='cpu',
                                              progress=True)
    else:
        checkpoint = torch.load(
            os.path.join(Path.models_dir(), model_name + '.pth'),
            map_location=lambda storage, loc: storage)
    checkpoint = checkpoint['model_state']

    # Strip the 'module.' prefix left by DataParallel, if present.
    first_key = next(iter(checkpoint))
    if 'module.' not in first_key:
        return checkpoint
    return OrderedDict(
        (key.replace('module.', ''), value)
        for key, value in checkpoint.items())
def resnet26(pretrained=False, features=False, remote=True, **kwargs):
    """Constructs a ResNet-26 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        features (bool): build the feature-extractor variant instead of
            the full classifier.
        remote (bool): download the checkpoint vs. read it locally.
        **kwargs: forwarded to the ResNet constructor.
    """
    model_cls = ResNetFeatures if features else ResNet
    model = model_cls(Bottleneck, [2, 2, 2, 2], **kwargs)

    if not pretrained:
        return model

    print('Loading resnet26 Imagenet')
    # Fetch the checkpoint, remotely or from the local models directory.
    if remote:
        checkpoint = load_state_dict_from_url(model_urls['resnet26'],
                                              map_location='cpu',
                                              progress=True)
    else:
        checkpoint = torch.load(
            os.path.join(Path.models_dir(), 'resnet26.pth'),
            map_location=lambda storage, loc: storage)
    checkpoint = checkpoint['model_state']

    # Strip the 'module.' prefix left by DataParallel, if present.
    first_key = next(iter(checkpoint))
    if 'module.' in first_key:
        state_dict = OrderedDict(
            (key.replace('module.', ''), value)
            for key, value in checkpoint.items())
    else:
        state_dict = checkpoint

    model.load_state_dict(state_dict)
    return model
def mobilenet_v2(pretrained=False, features=False, n_class=1000,
                 last_channel=1280, remote=True):
    """Build a MobileNetV2, optionally the features-only variant.

    Args:
        pretrained (bool): load ImageNet pre-trained weights.
        features (bool): return the feature-extractor variant instead of
            the full classifier.
        n_class (int): number of classifier outputs.
        last_channel (int): width of the final channel dimension.
        remote (bool): download the checkpoint vs. read it locally.
    """
    model_cls = MobileNetV2Features if features else MobileNetV2
    model = model_cls(n_class=n_class, last_channel=last_channel)

    if pretrained:
        # Fetch the checkpoint, remotely or from the local models
        # directory, always mapped to CPU.
        if remote:
            checkpoint = load_state_dict_from_url(
                model_urls['mobilenet_v2_1280'], map_location='cpu',
                progress=True)
        else:
            checkpoint = torch.load(
                os.path.join(Path.models_dir(), 'mobilenet_v2.pth'),
                map_location='cpu')
        model.load_state_dict(checkpoint)

    return model
def resnext_101_32x4d():
    """Return the ResNeXt-101 (32x4d) model with locally stored weights.

    Loads 'resnext_101_32x4d.pth' from the models directory into the
    module-level `resnext_101_32x4d_model` instance and returns it.
    """
    model = resnext_101_32x4d_model
    weights_path = os.path.join(Path.models_dir(), 'resnext_101_32x4d.pth')
    model.load_state_dict(torch.load(weights_path))
    return model