Example #1

import torch
import torch.nn as nn
from torchvision import models

import netvlad


def load_netvlad(checkpoint_path):
    # VGG16 backbone truncated before the last ReLU and MaxPool;
    # its conv5_3 output has 512 channels
    encoder_dim = 512
    encoder = models.vgg16(pretrained=False)
    layers = list(encoder.features.children())[:-2]
    encoder = nn.Sequential(*layers)
    model = nn.Module()
    model.add_module('encoder', encoder)
    vlad_layer = netvlad.NetVLAD(num_clusters=64,
                                 dim=encoder_dim,
                                 vladv2=False)
    model.add_module('pool', vlad_layer)

    # load the checkpoint onto the CPU, regardless of the device it was saved from
    checkpoint = torch.load(checkpoint_path,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    return model
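
A minimal usage sketch for the loader above; the checkpoint path and input size are hypothetical, and the output shapes assume the NetVLAD module flattens its descriptor (64 clusters x 512 channels):

model = load_netvlad('checkpoints/model_best.pth.tar')  # hypothetical path
model.eval()
with torch.no_grad():
    img = torch.randn(1, 3, 480, 640)   # dummy RGB image batch
    feats = model.encoder(img)          # conv5_3 feature map, shape (1, 512, 30, 40)
    desc = model.pool(feats)            # NetVLAD descriptor, shape (1, 64 * 512)
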
Example #2
    def initialize(self):
        self.init_param()
        opt = self.parser.parse_args()
        self.verbose = opt.verbose
        self.ipaddr = opt.ipaddr
        # hyperparameters to restore from a previous run's flags.json when resuming
        restore_var = [
            'lr', 'lrStep', 'lrGamma', 'weightDecay', 'momentum', 'runsPath',
            'savePath', 'arch', 'num_clusters', 'pooling', 'optim', 'margin',
            'seed', 'patience'
        ]
        if opt.resume:
            flag_file = join(opt.resume, 'checkpoints', 'flags.json')
            if exists(flag_file):
                with open(flag_file, 'r') as f:
                    stored_flags = {
                        '--' + k: str(v)
                        for k, v in json.load(f).items() if k in restore_var
                    }
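                    # boolean flags that match their stored default are dropped; the rest keep only the bare flag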
                    to_del = []
                    for flag, val in stored_flags.items():
                        for act in self.parser._actions:
                            if act.dest == flag[2:]:
                                # store_true / store_false flags take no value, so filter these out
                                if isinstance(act.const, bool):
                                    if val == str(act.default):
                                        to_del.append(flag)
                                    else:
                                        stored_flags[flag] = ''
                    for flag in to_del:
                        del stored_flags[flag]

                    # flatten {'--flag': 'value'} pairs into an argv-style list,
                    # dropping the empty strings left by store_true/store_false flags
                    train_flags = [
                        x for x in list(sum(stored_flags.items(), tuple()))
                        if len(x) > 0
                    ]
                    if self.verbose:
                        print('Restored flags:', train_flags)
                    opt = self.parser.parse_args(train_flags, namespace=opt)

        if self.verbose:
            print(opt)

        cuda = not opt.nocuda
        if cuda and not torch.cuda.is_available():
            raise Exception("No GPU found, please run with --nocuda")

        device = torch.device("cuda" if cuda else "cpu")

        random.seed(opt.seed)
        np.random.seed(opt.seed)
        torch.manual_seed(opt.seed)
        if cuda:
            torch.cuda.manual_seed(opt.seed)

        if self.verbose:
            print('===> Building model begin')

        pretrained = not opt.fromscratch
        if opt.arch.lower() == 'alexnet':
            self.encoder_dim = 256
            encoder = models.alexnet(pretrained=pretrained)
            # capture only the features and remove the last ReLU and MaxPool
            layers = list(encoder.features.children())[:-2]

            if pretrained:
                # if using pretrained weights, only train conv5
                for l in layers[:-1]:
                    for p in l.parameters():
                        p.requires_grad = False

        elif opt.arch.lower() == 'vgg16':
            self.encoder_dim = 512
            encoder = models.vgg16(pretrained=pretrained)
            # capture only the feature part and remove the last ReLU and MaxPool
            layers = list(encoder.features.children())[:-2]

            if pretrained:
                # if using pretrained weights, only train conv5_1, conv5_2, and conv5_3
                for l in layers[:-5]:
                    for p in l.parameters():
                        p.requires_grad = False

        else:
            raise ValueError('Unknown architecture: ' + opt.arch)

        if opt.mode.lower() == 'cluster' and not opt.vladv2:
            layers.append(L2Norm())

        encoder = nn.Sequential(*layers)
        model = nn.Module()
        model.add_module('encoder', encoder)

        if opt.mode.lower() != 'cluster':
            if opt.pooling.lower() == 'netvlad':
                net_vlad = netvlad.NetVLAD(num_clusters=opt.num_clusters,
                                           dim=self.encoder_dim,
                                           vladv2=opt.vladv2)
                if not opt.resume:
                    if opt.mode.lower() == 'train':
                        initcache = join(
                            opt.dataPath, 'centroids',
                            opt.arch + '_' + train_set.dataset + '_' +
                            str(opt.num_clusters) + '_desc_cen.hdf5')
                    else:
                        initcache = join(
                            opt.dataPath, 'centroids',
                            opt.arch + '_' + whole_test_set.dataset + '_' +
                            str(opt.num_clusters) + '_desc_cen.hdf5')

                    if not exists(initcache):
                        raise FileNotFoundError(
                            'Could not find clusters, please run with --mode=cluster before proceeding'
                        )

                    # seed NetVLAD with the cluster centroids and training descriptors cached in HDF5
                    with h5py.File(initcache, mode='r') as h5:
                        clsts = h5.get("centroids")[...]
                        traindescs = h5.get("descriptors")[...]
                        net_vlad.init_params(clsts, traindescs)
                        del clsts, traindescs

                model.add_module('pool', net_vlad)
            elif opt.pooling.lower() == 'max':
                global_pool = nn.AdaptiveMaxPool2d((1, 1))
                model.add_module(
                    'pool', nn.Sequential(global_pool, Flatten(), L2Norm()))
            elif opt.pooling.lower() == 'avg':
                global_pool = nn.AdaptiveAvgPool2d((1, 1))
                model.add_module(
                    'pool', nn.Sequential(global_pool, Flatten(), L2Norm()))
            else:
                raise ValueError('Unknown pooling type: ' + opt.pooling)

        isParallel = False
        if opt.nGPU > 1 and torch.cuda.device_count() > 1:
            model.encoder = nn.DataParallel(model.encoder)
            if opt.mode.lower() != 'cluster':
                model.pool = nn.DataParallel(model.pool)
            isParallel = True

        if not opt.resume:
            model = model.to(device)

        if opt.mode.lower() == 'train':
            if opt.optim.upper() == 'ADAM':
                optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                              model.parameters()),
                                       lr=opt.lr)  #, betas=(0,0.9))
            elif opt.optim.upper() == 'SGD':
                optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                             model.parameters()),
                                      lr=opt.lr,
                                      momentum=opt.momentum,
                                      weight_decay=opt.weightDecay)

                scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                      step_size=opt.lrStep,
                                                      gamma=opt.lrGamma)
            else:
                raise ValueError('Unknown optimizer: ' + opt.optim)

            # the original paper/code doesn't sqrt() the distances, but we do, so sqrt() the margin as well
            criterion = nn.TripletMarginLoss(margin=opt.margin**0.5,
                                             p=2,
                                             reduction='sum').to(device)

        if opt.resume:
            if opt.ckpt.lower() == 'latest':
                resume_ckpt = join(opt.resume, 'checkpoints',
                                   'checkpoint.pth.tar')
            elif opt.ckpt.lower() == 'best':
                resume_ckpt = join(opt.resume, 'checkpoints',
                                   'model_best.pth.tar')
            else:
                raise ValueError('Unknown checkpoint type: ' + opt.ckpt)

            if isfile(resume_ckpt):
                if self.verbose:
                    print("=> loading checkpoint '{}'".format(resume_ckpt))
                checkpoint = torch.load(
                    resume_ckpt, map_location=lambda storage, loc: storage)
                opt.start_epoch = checkpoint['epoch']
                best_metric = checkpoint['best_score']
                model.load_state_dict(checkpoint['state_dict'])
                model = model.to(device)
                if opt.mode.lower() == 'train':
                    optimizer.load_state_dict(checkpoint['optimizer'])
                if self.verbose:
                    print("=> loaded checkpoint '{}' (epoch {})".format(
                        resume_ckpt, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(resume_ckpt))

        self.model = model
        if self.verbose:
            print('===> Building model end(vps.py)')

        if opt.dataset.lower() == 'pittsburgh':
            from netvlad import pittsburgh as dataset
            return 0  # Failed

        elif opt.dataset.lower() == 'deepguider':
            from netvlad import etri_dbloader as dataset
            self.dataset_root_dir = dataset.root_dir
            self.dataset_struct_dir = dataset.struct_dir
            self.dataset_queries_dir = dataset.queries_dir
            return 1  # non-zero means success
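
For context, a hedged sketch of how this initializer might be driven; the wrapper class name `VPS` is an assumption (the method lives in vps.py per the log message above), and only the return convention comes from the code itself:

vps = VPS()                  # hypothetical class containing initialize() above
if vps.initialize() == 0:    # 0 signals failure (the pittsburgh branch)
    raise RuntimeError('initialization failed')
print('model built on', next(vps.model.parameters()).device)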