Пример #1
0
    def _extract_feature(self):
        """Run the photo and sketch networks over their test sets.

        Both ``self.photo_net`` and ``self.sketch_net`` are switched to eval
        mode, and all extraction happens under ``t.no_grad()``. A single
        ``Extractor`` is shared: it is first built around the photo network
        and applied to ``self.photo_test``, then re-pointed at the sketch
        network and applied to ``self.sketch_test``.

        Returns:
            A 4-tuple ``(photo_name, photo_feature, sketch_name,
            sketch_feature)`` taken from the ``'name'`` and ``'feature'``
            entries of the two extraction results.
        """
        with t.no_grad():
            for network in (self.photo_net, self.sketch_net):
                network.eval()

            extractor = Extractor(e_model=self.photo_net,
                                  vis=False,
                                  dataloader=True)
            photo_out = extractor.extract(self.photo_test)

            # Same extractor instance, sketch branch swapped in.
            extractor.reload_model(self.sketch_net)
            sketch_out = extractor.extract(self.sketch_test)

            # Unpack while still inside the no_grad scope.
            return (photo_out['name'], photo_out['feature'],
                    sketch_out['name'], sketch_out['feature'])
Пример #2
0
    def _extract_feature_embedding(self):
        """Extract test-set outputs from both networks in batches of ``self.test_bs``.

        Puts ``self.photo_net`` and ``self.sketch_net`` into eval mode and,
        under ``t.no_grad()``, runs one ``Extractor`` (constructed with
        ``cat_info=False``) first with the photo network over
        ``self.photo_test`` and then with the sketch network over
        ``self.sketch_test``.

        Returns:
            A 4-tuple ``(photo_name, photo_feature, sketch_name,
            sketch_feature)`` read from the ``'name'`` and ``'feature'``
            entries of the two extraction results.
        """
        with t.no_grad():
            self.photo_net.eval()
            self.sketch_net.eval()

            extractor = Extractor(e_model=self.photo_net,
                                  cat_info=False,
                                  vis=False,
                                  dataloader=True)

            def run(dataset):
                # Batch size is looked up on every call.
                return extractor.extract(dataset, batch_size=self.test_bs)

            photo_out = run(self.photo_test)

            # Re-point the shared extractor at the sketch branch.
            extractor.reload_model(self.sketch_net)
            sketch_out = run(self.sketch_test)

            return (photo_out['name'], photo_out['feature'],
                    sketch_out['name'], sketch_out['feature'])
Пример #3
0
# Checkpoint paths for the trained VGG models (sketch and photo branches).
SKETCH_VGG = '/data1/zzl/model/caffe2torch/vgg_triplet_loss/sketch/sketch_vgg_190.pth'
PHOTO_VGG = '/data1/zzl/model/caffe2torch/vgg_triplet_loss/photo/photo_vgg_190.pth'

# NOTE(review): FINE_TUNE_RESNET is not used by any visible statement; the
# ResNet section below loads PHOTO_RESNET instead. Confirm which is intended.
FINE_TUNE_RESNET = '/data1/zzl/model/caffe2torch/fine_tune/model_270.pth'

# NOTE(review): `device` is not used by any visible statement; the models are
# moved with `.cuda()` (the current default CUDA device), not to 'cuda:1'.
# Confirm which GPU is intended.
device = 'cuda:1'

# --- vgg --------------------------------------------------------------------
# VGG16 with classifier[6] replaced by a 4096 -> 125 linear layer.
vgg = vgg16(pretrained=False)
vgg.classifier[6] = nn.Linear(in_features=4096, out_features=125, bias=True)
# Load on CPU first so the checkpoint does not depend on the GPU it was saved from.
vgg.load_state_dict(t.load(PHOTO_VGG, map_location=t.device('cpu')))
vgg.cuda()
# Inference mode: VGG16's classifier contains Dropout layers, which are active
# in training mode and would randomise the extracted features.
vgg.eval()

ext = Extractor(pretrained=False)
ext.reload_model(vgg)

photo_feature = ext.extract_with_dataloader(test_photo_root,
                                            'photo-vgg-190epoch.pkl')

# Reuse the same network object with the sketch-branch weights.
# load_state_dict does not change train/eval mode, so vgg stays in eval mode.
vgg.load_state_dict(t.load(SKETCH_VGG, map_location=t.device('cpu')))
ext.reload_model(vgg)

sketch_feature = ext.extract_with_dataloader(test_set_root,
                                             'sketch-vgg-190epoch.pkl')

# --- resnet -----------------------------------------------------------------
# ResNet50 with its `fc` head replaced by a 2048 -> 125 linear layer.
resnet = resnet50()
resnet.fc = nn.Linear(in_features=2048, out_features=125)
# NOTE(review): PHOTO_RESNET is not defined in the visible code; confirm it is
# defined elsewhere in the file.
resnet.load_state_dict(t.load(PHOTO_RESNET, map_location=t.device('cpu')))
resnet.cuda()
# Inference mode: ResNet50's BatchNorm layers must use their running
# statistics (not per-batch statistics) when extracting features.
resnet.eval()