Code example #1
    def load_model():
        # this model has a last conv feature map as 14x14
        model_file = 'wideresnet18_places365.pth.tar'
        if not os.access(model_file, os.W_OK):
            os.system('wget http://places2.csail.mit.edu/models_places365/' +
                      model_file)
            os.system(
                'wget https://raw.githubusercontent.com/csailvision/places365/master/wideresnet.py'
            )

        import wideresnet
        model = wideresnet.resnet18(num_classes=365)
        checkpoint = torch.load(model_file,
                                map_location=lambda storage, loc: storage)
        state_dict = {
            str.replace(k, 'module.', ''): v
            for k, v in checkpoint['state_dict'].items()
        }
        model.load_state_dict(state_dict)
        model.eval()
        # hook the feature extractor
        features_names = ['layer4', 'avgpool']  # this is the last conv layer of the resnet
        for name in features_names:
            model._modules.get(name).register_forward_hook(hook_feature)
        return model
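Every example that registers forward hooks on 'layer4' and 'avgpool' assumes a module-level hook_feature callback that is never shown. As a point of reference, a minimal sketch of such a hook, assuming a module-level features_blobs list in the style of the places365 demo code, could look like this:

import numpy as np

features_blobs = []

def hook_feature(module, input, output):
    # copy the layer's output to the CPU and stash it for later CAM / attribute computation
    features_blobs.append(np.squeeze(output.data.cpu().numpy()))

After one forward pass, features_blobs[0] would then hold the 14x14 layer4 feature map and features_blobs[1] the pooled feature vector.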
Code example #2
File: places365.py  Project: mc261670164/OutdoorSent
 def __init__(this):
     this.features_blobs = []
     file_name_W = 'W_sceneattribute_wideresnet18.npy'
     if not os.access(file_name_W, os.W_OK):
         synset_url = 'http://places2.csail.mit.edu/models_places365/W_sceneattribute_wideresnet18.npy'
         os.system('wget ' + synset_url)
     this.W_attribute = np.load(file_name_W)
     model_file = 'wideresnet18_places365.pth.tar'
     if not os.access(model_file, os.W_OK):
         os.system('wget http://places2.csail.mit.edu/models_places365/' +
                   model_file)
         os.system(
             'wget https://raw.githubusercontent.com/csailvision/places365/master/wideresnet.py'
         )
     import wideresnet
     this.model = wideresnet.resnet18(num_classes=365)
     checkpoint = torch.load(model_file,
                             map_location=lambda storage, loc: storage)
     state_dict = {
         str.replace(k, 'module.', ''): v
         for k, v in checkpoint['state_dict'].items()
     }
     this.model.load_state_dict(state_dict)
     this.model.eval()
     features_names = ['layer4', 'avgpool']
     for name in features_names:
         this.model._modules.get(name).register_forward_hook(
             this.__hook_feature)
     this.tf = trn.Compose([
         trn.Resize((224, 224)),
         trn.ToTensor(),
         trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
     ])
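A hedged usage sketch for this class: once the transform this.tf and the hooked model are in place, a single image can be scored roughly as follows (the file name, the variable detector, and the top-k count are illustrative, not part of the project):

from PIL import Image
import torch
import torch.nn.functional as F

img = Image.open('example.jpg').convert('RGB')   # hypothetical input file
input_img = detector.tf(img).unsqueeze(0)        # detector: an instance of the class above
with torch.no_grad():
    logit = detector.model(input_img)
probs = F.softmax(logit, dim=1).squeeze(0)
top5 = torch.topk(probs, 5)                      # indices into the 365 Places categories
# detector.features_blobs now also holds the hooked layer4 / avgpool activations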
Code example #3
def load_model_wideresnet18():
    # this model has a last conv feature map as 14x14
    model_file = 'wideresnet18_places365.pth.tar'
    if not os.access(model_file, os.W_OK):
        wget.download("http://places2.csail.mit.edu/models_places365/" +
                      model_file)
        wget.download(
            "https://raw.githubusercontent.com/csailvision/places365/master/wideresnet.py"
        )

    import wideresnet
    model = wideresnet.resnet18(num_classes=365)
    checkpoint = torch.load(model_file,
                            map_location=lambda storage, loc: storage)
    state_dict = {
        str.replace(k, 'module.', ''): v
        for k, v in checkpoint['state_dict'].items()
    }
    model.load_state_dict(state_dict)

    # hacky way to deal with the upgraded batchnorm2D and avgpool layers...
    for i, (name, module) in enumerate(model._modules.items()):
        module = recursion_change_bn(model)
    model.avgpool = torch.nn.AvgPool2d(kernel_size=14, stride=1, padding=0)

    model.eval()
    # hook the feature extractor
    features_names = ['layer4',
                      'avgpool']  # this is the last conv layer of the resnet
    for name in features_names:
        model._modules.get(name).register_forward_hook(hook_feature)
    return model
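Several of the examples (e.g. #3, #5, #15, #19) call a helper recursion_change_bn that is never defined. In the places365 demo code it patches BatchNorm layers loaded from checkpoints saved with an older PyTorch version; a sketch along those lines (assumed, not copied from any of the projects above) is:

import torch

def recursion_change_bn(module):
    # walk the module tree and restore the track_running_stats attribute on
    # BatchNorm2d layers deserialized from an older PyTorch checkpoint
    if isinstance(module, torch.nn.BatchNorm2d):
        module.track_running_stats = 1
    else:
        for name, child in module._modules.items():
            recursion_change_bn(child)
    return module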
Code example #4
File: test_scene_attn.py  Project: leeorn/DEDUCE
        def load_model():
            # this model has a last conv feature map as 14x14

            model_file = 'models/wideresnet_best_attn.pth.tar'

            import wideresnet
            model = wideresnet.resnet18(num_classes=7)
            checkpoint = torch.load(model_file,
                                    map_location=lambda storage, loc: storage)
            state_dict = {
                str.replace(k, 'module.', ''): v
                for k, v in checkpoint['state_dict'].items()
            }
            model.load_state_dict(state_dict)
            model.eval()

            # the following is deprecated, everything is migrated to python36

            ## if you encounter the UnicodeDecodeError when use python3 to load the model, add the following line will fix it. Thanks to @soravux
            #from functools import partial
            #import pickle
            #pickle.load = partial(pickle.load, encoding="latin1")
            #pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
            #model = torch.load(model_file, map_location=lambda storage, loc: storage, pickle_module=pickle)

            # hook the feature extractor
            features_names = ['layer4', 'avgpool']  # this is the last conv layer of the resnet
            for name in features_names:
                model._modules.get(name).register_forward_hook(hook_feature)
            return model
Code example #5
def load_model():
    # this model has a last conv feature map as 14x14

    model_file = 'wideresnet18_places365.pth.tar'
    if not os.access(model_file, os.W_OK):
        os.system('wget http://places2.csail.mit.edu/models_places365/' +
                  model_file)
        os.system(
            'wget https://raw.githubusercontent.com/csailvision/places365/master/wideresnet.py'
        )

    import wideresnet
    model = wideresnet.resnet18(num_classes=365)
    checkpoint = torch.load(model_file,
                            map_location=lambda storage, loc: storage)
    state_dict = {
        str.replace(k, 'module.', ''): v
        for k, v in checkpoint['state_dict'].items()
    }
    model.load_state_dict(state_dict)
    # hacky way to deal with the upgraded batchnorm2D and avgpool layers...
    for i, (name, module) in enumerate(model._modules.items()):
        module = recursion_change_bn(model)
    model.avgpool = torch.nn.AvgPool2d(kernel_size=14, stride=1, padding=0)
    model.eval()
    model.cuda()
    return model
Code example #6
File: unified.py  Project: UMass-Rescue/scene_detect
def load_model():
    # fetch the pretrained weights of the model if not already present
    model_file = 'wideresnet18_places365.pth.tar'
    if not os.access(model_file, os.W_OK):
        os.system('wget http://places2.csail.mit.edu/models_places365/' +
                  model_file)
        os.system(
            'wget https://raw.githubusercontent.com/csailvision/places365/master/wideresnet.py'
        )

    # imported here since this model file may not be present in the beginning
    # and has been downloaded right before this
    # makes the model ready to run the forward pass
    import wideresnet
    model = wideresnet.resnet18(num_classes=365)
    checkpoint = torch.load(model_file,
                            map_location=lambda storage, loc: storage)
    state_dict = {
        str.replace(k, 'module.', ''): v
        for k, v in checkpoint['state_dict'].items()
    }
    model.load_state_dict(state_dict)
    model.eval()

    # hook the feature extractor
    features_names = ['layer4',
                      'avgpool']  # this is the last conv layer of the resnet
    for name in features_names:
        model._modules.get(name).register_forward_hook(hook_feature)
    return model
Code example #7
def load_model():
    # this model has a last conv feature map as 14x14
    file = 'wideresnet18_places365.pth.tar'
    model_file = './utils/wideresnet18_places365.pth.tar'
    prefix_path = ' -P ./utils/'
    if not os.access(model_file, os.W_OK):
        os.system('wget http://places2.csail.mit.edu/models_places365/' + file + prefix_path)
        os.system('wget https://raw.githubusercontent.com/csailvision/places365/master/wideresnet.py' + prefix_path)

    import wideresnet
    model = wideresnet.resnet18(num_classes=365)
    checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
    state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict)
    model.eval()



    # the following is deprecated, everything is migrated to python36

    ## if you encounter the UnicodeDecodeError when use python3 to load the model, add the following line will fix it. Thanks to @soravux
    #from functools import partial
    #import pickle
    #pickle.load = partial(pickle.load, encoding="latin1")
    #pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
    #model = torch.load(model_file, map_location=lambda storage, loc: storage, pickle_module=pickle)

    model.eval()
    # hook the feature extractor
    features_names = ['layer4','avgpool'] # this is the last conv layer of the resnet
    for name in features_names:
        model._modules.get(name).register_forward_hook(hook_feature)
    return model
Code example #8
def load_model(model_file):
    if not os.access(model_file, os.W_OK):
        os.system('wget http://places2.csail.mit.edu/models_places365/' +
                  model_file)
    os.system(
        'wget https://raw.githubusercontent.com/csailvision/places365/master/wideresnet.py'
        + ' -P ./utils')
    import wideresnet
    model = wideresnet.resnet18(num_classes=365)
    checkpoint = torch.load(model_file,
                            map_location=lambda storage, loc: storage)
    state_dict = {
        str.replace(k, 'module.', ''): v
        for k, v in checkpoint['state_dict'].items()
    }
    model.load_state_dict(state_dict)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    params = list(model.parameters())
    weight_softmax = params[-2].clone().numpy()
    weight_softmax[weight_softmax < 0] = 0

    conv_model = nn.Sequential(*list(model.children())[:-2])
    linear_layer = model.fc
    conv_model.eval()
    linear_layer.eval()
    return linear_layer, conv_model, weight_softmax
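This variant returns the clamped fc weights (weight_softmax) together with the truncated convolutional trunk, which is the usual setup for class activation mapping (CAM). A hedged sketch of how the returned pieces might be combined into a CAM, assuming the 14x14 feature-map size mentioned in the comments above, is:

import numpy as np
import torch

def compute_cam(conv_model, weight_softmax, input_img, class_idx):
    # feature map from the truncated trunk: shape (1, C, 14, 14)
    with torch.no_grad():
        feature_conv = conv_model(input_img).numpy()
    _, nc, h, w = feature_conv.shape
    # weighted sum of channels for the requested class, then normalise to [0, 1]
    cam = weight_softmax[class_idx].dot(feature_conv.reshape(nc, h * w))
    cam = cam.reshape(h, w)
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)
    return cam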
Code example #9
def load_model():
    # this model has a last conv feature map as 14x14

    model_file = 'wideresnet18_places365.pth.tar'
    if not os.access(model_file, os.W_OK):
        os.system('wget http://places2.csail.mit.edu/models_places365/' + model_file)
        os.system('wget https://raw.githubusercontent.com/csailvision/places365/master/wideresnet.py')

    import wideresnet
    model = wideresnet.resnet18(num_classes=365)
    checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
    state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict)
    model.eval()



    # the following is deprecated, everything is migrated to python36

    ## if you encounter the UnicodeDecodeError when use python3 to load the model, add the following line will fix it. Thanks to @soravux
    #from functools import partial
    #import pickle
    #pickle.load = partial(pickle.load, encoding="latin1")
    #pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
    #model = torch.load(model_file, map_location=lambda storage, loc: storage, pickle_module=pickle)

    model.eval()
    # hook the feature extractor
    features_names = ['layer4','avgpool'] # this is the last conv layer of the resnet
    for name in features_names:
        model._modules.get(name).register_forward_hook(hook_feature)
    return model
Code example #10
def main():
    global args, best_prec1, globaliter_train, globaliter_val
    globaliter_train = 0
    globaliter_val = 0
    args = parser.parse_args()
    print(args)
    # create model
    num_classes = args.num_classes
    if args.transfer != "":
        num_classes = 365
    print("=> creating model '{}'".format(args.arch))
    if args.arch.lower().startswith('wideresnet'):
        # a customized resnet model with last feature map size as 14x14 for better class activation mapping
        if args.arch.lower() == 'wideresnet50': 
            model  = wideresnet.resnet50(num_classes=num_classes)
        if args.arch.lower() == 'wideresnet18': 
            model  = wideresnet.resnet18(num_classes=num_classes)
    else:
        model = models.__dict__[args.arch](num_classes=num_classes)
    
    print(model)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))


    if args.transfer != "":
        if os.path.isfile(args.transfer):
            print("=> loading checkpoint '{}'".format(args.transfer))
            checkpoint = torch.load(args.transfer, map_location=lambda storage, loc: storage)
            state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
            model.load_state_dict(state_dict)
            # print("=> loaded checkpoint '{}' (epoch {})"
                #   .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.transfer))
        for param in model.parameters():
            param.requires_grad = False
        if freze == 2:
            for param in model.layer4.parameters():
                param.requires_grad = True
        if args.arch.lower() == 'wideresnet50':
            model.fc = nn.Linear(2040 , 2)
        else:
            model.fc = nn.Linear(512 , 2)
        model.fc.reset_parameters()
def load_model():
    # this model has a last conv feature map as 14x14

    model_file = 'wideresnet18_places365.pth.tar'
    model_file_path = Path(model_file)
    weight_url = 'http://places2.csail.mit.edu/models_places365/' + model_file
    if not os.access(model_file, os.W_OK):
        # os.system('wget http://places2.csail.mit.edu/models_places365/' + model_file)
        print('Downloading...', end=' ')
        resp = requests.get(weight_url)
        with model_file_path.open('wb') as f:
            f.write(resp.content)
        print('Done!')

        # os.system('wget https://raw.githubusercontent.com/csailvision/places365/master/wideresnet.py')
        widersnet_url = 'https://raw.githubusercontent.com/csailvision/places365/master/wideresnet.py'
        widersnet_name = 'wideresnet.py'
        widersnet_name_path = Path(widersnet_name)
        print('Downloading...', end=' ')
        resp = requests.get(widersnet_url)
        with widersnet_name_path.open('wb') as f:
            f.write(resp.content)
        print('Done!')

    import wideresnet
    model = wideresnet.resnet18(num_classes=365)
    checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
    state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict)
    
    # hacky way to deal with the upgraded batchnorm2D and avgpool layers...
    for i, (name, module) in enumerate(model._modules.items()):
        module = recursion_change_bn(model)
    model.avgpool = torch.nn.AvgPool2d(kernel_size=14, stride=1, padding=0)
    
    model.eval()



    # the following is deprecated, everything is migrated to python36

    ## if you encounter the UnicodeDecodeError when use python3 to load the model, add the following line will fix it. Thanks to @soravux
    #from functools import partial
    #import pickle
    #pickle.load = partial(pickle.load, encoding="latin1")
    #pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
    #model = torch.load(model_file, map_location=lambda storage, loc: storage, pickle_module=pickle)

    model.eval()
    # hook the feature extractor
    features_names = ['layer4','avgpool'] # this is the last conv layer of the resnet
    for name in features_names:
        model._modules.get(name).register_forward_hook(hook_feature)
    return model
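This version downloads the checkpoint with requests.get and buffers the whole response in memory before writing it out. A streamed variant (a sketch, not part of the original project) writes the file in chunks instead:

import requests
from pathlib import Path

def download_file(url, dest):
    # stream the response so a large checkpoint file is never held in memory at once
    with requests.get(url, stream=True) as resp:
        resp.raise_for_status()
        with Path(dest).open('wb') as f:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                f.write(chunk)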
def load_model(model_file):
    if args.model == 'resnet18':
        model = wideresnet.resnet18(num_classes=365)
    if args.model == 'resnet50':
        model = wideresnet.resnet50(num_classes=365)
    checkpoint = torch.load(model_file,
                            map_location=lambda storage, loc: storage)
    state_dict = {
        str.replace(k, 'module.', ''): v
        for k, v in checkpoint['state_dict'].items()
    }
    model.load_state_dict(state_dict)
    model.eval()
    return model
Code example #13
File: places365.py  Project: ihowson/ownphotos
    def load_model():
        # this model has a last conv feature map as 14x14
        # model_file = os.path.join(dir_places365_model,'whole_wideresnet18_places365_python36.pth.tar')
        model_file = os.path.join(dir_places365_model,'wideresnet18_places365.pth.tar')

        model = wideresnet.resnet18(num_classes=365)
        checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
        state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
        model.load_state_dict(state_dict)
        model.eval()
        # hook the feature extractor
        features_names = ['layer4','avgpool'] # this is the last conv layer of the resnet
        for name in features_names:
            model._modules.get(name).register_forward_hook(hook_feature)
        return model
Code example #14
File: scene_tagger.py  Project: raineydavid/voyage
def load_model():
    model_file = 'wideresnet18_places365.pth.tar'

    model = wideresnet.resnet18(num_classes=365)
    checkpoint = torch.load(model_file,
                            map_location=lambda storage, loc: storage)
    state_dict = {
        str.replace(k, 'module.', ''): v
        for k, v in checkpoint['state_dict'].items()
    }
    model.load_state_dict(state_dict)
    model.eval()
    features_names = ['layer4', 'avgpool']
    for name in features_names:
        model._modules.get(name).register_forward_hook(hook_feature)
    return model
Code example #15
def load_model(path, model):
    model_file = path
    if model == 'resnet18':
        model = wideresnet.resnet18(num_classes=365)
    if model == 'resnet50':
        model = wideresnet.resnet50(num_classes=365)
    checkpoint = torch.load(model_file,
                            map_location=lambda storage, loc: storage)
    state_dict = {
        str.replace(k, 'module.', ''): v
        for k, v in checkpoint['state_dict'].items()
    }
    #model.load_state_dict(checkpoint['state_dict'])
    model.load_state_dict(state_dict)
    for i, (name, module) in enumerate(model._modules.items()):
        module = recursion_change_bn(model)
    model.eval()
    return model
Code example #16
    def initialize(self, ctx):
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        self.model_dir = model_dir
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")

        # Read model serialize/pt file
        model_file = os.path.join(model_dir, 'wideresnet18_places365.pth.tar')

        self.model = wideresnet.resnet18(num_classes=365)
        checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
        state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
        self.model.load_state_dict(state_dict)
        self.model.eval()

        self.tf = self.returnTF()
        self.classes, self.labels_IO, self.labels_attribute, self.W_attribute = self.load_labels()
        
        self.initialized = True
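This handler relies on self.returnTF() and self.load_labels(), which are not shown. The transform is presumably the same 224x224 resize plus ImageNet normalization used in example #2; a sketch of such a returnTF method under that assumption (it would live on the handler class above) is:

import torchvision.transforms as trn

def returnTF(self):
    # 224x224 resize + ImageNet mean/std normalization, as in example #2
    return trn.Compose([
        trn.Resize((224, 224)),
        trn.ToTensor(),
        trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])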
Code example #17
    def load(self):
        model_file_dir = './pytorchserver/example_model/'
        model_file = os.path.join(model_file_dir, PYTORCH_FILE)
        py_files = []
        #for filename in os.listdir(model_file_dir):
        #    if filename.endswith('.py'):
        #        py_files.append(filename)
        #if len(py_files) == 1:
        #    model_class_file = os.path.join(model_file_dir, py_files[0])
        #elif len(py_files) == 0:
        #    raise Exception('Missing PyTorch Model Class File.')
        #else:
        #    raise Exception('More than one Python file is detected',
        #                    'Only one Python file is allowed within model_dir.')
        #model_class_name = self.model_class_name

        # Load the python class into memory
        #sys.path.append(os.path.dirname(model_class_file))
        #modulename = os.path.basename(model_class_file).split('.')[0].replace('-', '_')
        #model_class = getattr(importlib.import_module(modulename), model_class_name)

        # Make sure the model weight is transform with the right device in this machine
        #self.model = model_class().to(self.device)
        #self.model.load_state_dict(torch.load(model_file, map_location=self.device))
        #outfile = open('output.csv','a')
        import wideresnet 
        self.model = wideresnet.resnet18(num_classes=365)
        self.model.load_state_dict(torch.load(model_file))
        #model.load_state_dict(torch.load('./models/wideresnet18_places365.pth'))
        #params = list(model.parameters())
        #from UncertaintySampling import UncertaintySampling
        #device = torch.device('cuda:1')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:1')
            print("Train on GPU.")
        else:
            print("No cuda available")
            self.device = torch.device('cpu')

        self.model.to(self.device)
        self.model.eval()
        self.ready = True
Code example #18
File: places365.py  Project: lgezyxr/wqbackend
 def load_model():
     # TODO Should the model be reloaded for every photo? Wouldn't it be better to do that once?
     # this model has a last conv feature map as 14x14
     model_file = os.path.join(dir_places365_model,
                               'wideresnet18_places365.pth.tar')
     model = wideresnet.resnet18(num_classes=365)
     checkpoint = torch.load(model_file,
                             map_location=lambda storage, loc: storage)
     state_dict = {
         str.replace(k, 'module.', ''): v
         for k, v in checkpoint['state_dict'].items()
     }
     model.load_state_dict(state_dict)
     model.eval()
     # hook the feature extractor
     features_names = ['layer4', 'avgpool']  # this is the last conv layer of the resnet
     for name in features_names:
         model._modules.get(name).register_forward_hook(hook_feature)
     return model
Code example #19
def load_model():
    # this model has a last conv feature map as 14x14

    model_file = 'wideresnet18_places365.pth.tar'
    if not os.access(model_file, os.W_OK):
        os.system('wget http://places2.csail.mit.edu/models_places365/' +
                  model_file)
        os.system(
            'wget https://raw.githubusercontent.com/csailvision/places365/master/wideresnet.py'
        )

    import wideresnet
    model = wideresnet.resnet18(num_classes=365)
    checkpoint = torch.load(model_file,
                            map_location=lambda storage, loc: storage)
    state_dict = {
        str.replace(k, 'module.', ''): v
        for k, v in checkpoint['state_dict'].items()
    }
    model.load_state_dict(state_dict)

    # hacky way to deal with the upgraded batchnorm2D and avgpool layers...
    for i, (name, module) in enumerate(model._modules.items()):
        module = recursion_change_bn(model)
    model.avgpool = torch.nn.AvgPool2d(kernel_size=14, stride=1, padding=0)

    model.eval()

    # the following is deprecated, everything is migrated to python36

    ## if you encounter the UnicodeDecodeError when use python3 to load the model, add the following line will fix it. Thanks to @soravux
    #from functools import partial
    #import pickle
    #pickle.load = partial(pickle.load, encoding="latin1")
    #pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
    #model = torch.load(model_file, map_location=lambda storage, loc: storage, pickle_module=pickle)

    return model
Code example #20
def load_model(model_file):
    # this model has a last conv feature map as 14x14

    #model_file = 'wideresnet18_places365.pth.tar'
    if not os.access(model_file, os.W_OK):
        os.system('wget http://places2.csail.mit.edu/models_places365/' +
                  model_file)
        os.system(
            'wget https://raw.githubusercontent.com/csailvision/places365/master/wideresnet.py'
        )

    import wideresnet
    model = wideresnet.resnet18(num_classes=365)
    checkpoint = torch.load(model_file,
                            map_location=lambda storage, loc: storage)
    state_dict = {
        str.replace(k, 'module.', ''): v
        for k, v in checkpoint['state_dict'].items()
    }
    model.load_state_dict(state_dict)
    model.eval()

    return model
Code example #21
parser.add_argument('--weights',
                    '-w',
                    metavar='weights',
                    default='resnet18_best.pth.tar',
                    help='model architecture:')
args = parser.parse_args()
# th architecture to use
arch = args.arch
# arch = 'resnet18'

# load the pre-trained weights
model_file = args.weights
if args.arch.lower().startswith('wideresnet'):
    # a customized resnet model with last feature map size as 14x14 for better class activation mapping
    model = wideresnet.resnet18(num_classes=args.num_classes)
else:
    model = models.__dict__[arch](num_classes=args.num_classes)
# model = models.__dict__[arch](num_classes=args.num_classes)
checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
state_dict = {
    str.replace(k, 'module.', ''): v
    for k, v in checkpoint['state_dict'].items()
}
model.load_state_dict(state_dict)

# model = torch.nn.DataParallel(model)
model.cpu()
model.eval()

# load the image transformer
Code example #22
def main():
    global args, best_prec1
    args = parser.parse_args()
    print(args)
    # create model
    print("=> creating model '{}'".format(args.arch))
    if args.arch.lower().startswith('wideresnet'):
        # a customized resnet model with last feature map size as 14x14 for better class activation mapping
        model = wideresnet.resnet18(num_classes=args.num_classes)
    else:
        # model = models.__dict__[args.arch](num_classes=args.num_classes)
        model = resnet.resnet18(num_classes=args.num_classes)

    if args.arch.lower().startswith('alexnet') or args.arch.lower().startswith(
            'vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()
    print(model)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    data_dir = places_dir + '/places365_standard_home'
    traindir = os.path.join(data_dir, 'train')
    valdir = os.path.join(data_dir, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomSizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    accuracies_list = []
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        accuracies_list.append("%.2f" % prec1.tolist())
        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best, args.arch.lower())
    print("The best accuracy obtained during training is = {}".format(
        best_prec1))
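The training loop above calls adjust_learning_rate, train, validate and save_checkpoint, none of which are shown here. For orientation, a sketch of the learning-rate schedule and checkpointing helpers in the style of the standard PyTorch ImageNet/Places training scripts (assumed, not copied from this project) could be:

import shutil
import torch

def adjust_learning_rate(optimizer, epoch):
    # decay the initial learning rate by a factor of 10 every 30 epochs
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def save_checkpoint(state, is_best, filename='checkpoint'):
    # always keep the latest weights; copy them aside when validation accuracy improves
    torch.save(state, filename + '_latest.pth.tar')
    if is_best:
        shutil.copyfile(filename + '_latest.pth.tar', filename + '_best.pth.tar')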
Code example #23
File: train_placesCNN.py  Project: hafiz703/places365
def main():
    global args, best_prec1
    args = parser.parse_args()
    # print (args)
    # create model
    print("=> creating model '{}'".format(args.arch))
    if args.arch.lower().startswith('wideresnet_freeze'):
        
        model_file = 'wideresnet18_places365_1.pth.tar'
        model = wideresnet.resnet18(num_classes=365)
        checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
        state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
        model.load_state_dict(state_dict)
        model.eval()
        for name,param in model.named_parameters():             
            # print(name)
            if "fc" not in name:     
                param.requires_grad = False
         
    else:
        if args.arch.lower().startswith('wideresnet'):
            # a customized resnet model with last feature map size as 14x14 for better class activation mapping
            
            model  = wideresnet.resnet50(num_classes=args.num_classes)
        else:
            model = models.__dict__[args.arch](num_classes=args.num_classes)

        if args.arch.lower().startswith('alexnet') or args.arch.lower().startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    # print (model)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    print("DATA",args.data)
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, transforms.Compose([
            transforms.RandomSizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best, args.arch.lower())
Code example #24
                    help='num of class in the model')

args = parser.parse_args()
# th architecture to use
arch = args.arch
# arch = 'resnet18'

# load the pre-trained weights
model_file = '%s_places365.pth.tar' % arch
if not os.access(model_file, os.W_OK):
    weight_url = 'http://places2.csail.mit.edu/models_places365/' + model_file
    os.system('wget ' + weight_url)

if args.arch.lower().startswith('wideresnet'):
    # a customized resnet model with last feature map size as 14x14 for better class activation mapping
    model = wideresnet.resnet18(num_classes=365)
else:
    model = models.__dict__[arch](num_classes=args.num_classes)
# model.cuda()
checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
state_dict = {
    str.replace(k, 'module.', ''): v
    for k, v in checkpoint['state_dict'].items()
}
model.load_state_dict(state_dict)
# model = torch.nn.DataParallel(model)
model.cpu()
model.eval()

# load the image transformer
centre_crop = trn.Compose([