def _extract_feature(self):
    """Embed the photo and sketch test sets with the current networks.

    Runs with gradient tracking disabled. Both ``self.photo_net`` and
    ``self.sketch_net`` are put into eval mode and are *not* switched back
    afterwards; a caller that needs training mode again must restore it.

    A single ``Extractor`` is built around the photo network and is then
    pointed at the sketch network via ``reload_model``. No ``batch_size`` is
    passed to ``extract``, so whatever default ``Extractor.extract`` has is
    used.

    Returns:
        tuple: ``(photo_name, photo_feature, sketch_name, sketch_feature)``,
        read from the ``'name'`` / ``'feature'`` entries of the dicts that
        ``Extractor.extract`` returns for each test set.
    """
    with t.no_grad():
        # Inference behaviour for layers such as dropout / batch-norm.
        for net in (self.photo_net, self.sketch_net):
            net.eval()

        extractor = Extractor(e_model=self.photo_net, vis=False, dataloader=True)
        photos = extractor.extract(self.photo_test)

        # Reuse the same extractor, swapping in the sketch branch.
        extractor.reload_model(self.sketch_net)
        sketches = extractor.extract(self.sketch_test)

    return photos['name'], photos['feature'], sketches['name'], sketches['feature']
def _extract_feature_embedding(self):
    """Embed the photo and sketch test sets in batches of ``self.test_bs``.

    Gradient tracking is disabled for the whole extraction. Both networks are
    left in eval mode; training mode is not restored here.

    The ``Extractor`` is created with ``cat_info=False`` and initially wraps
    ``self.photo_net``. After the photo pass, ``reload_model`` swaps in
    ``self.sketch_net`` for the sketch pass. Both passes use
    ``batch_size=self.test_bs``.

    Returns:
        tuple: ``(photo_name, photo_feature, sketch_name, sketch_feature)``,
        taken from the ``'name'`` / ``'feature'`` entries of each extraction
        result dict.
    """
    with t.no_grad():
        self.photo_net.eval()
        self.sketch_net.eval()

        extractor = Extractor(
            e_model=self.photo_net,
            cat_info=False,
            vis=False,
            dataloader=True,
        )
        photo_out = extractor.extract(self.photo_test, batch_size=self.test_bs)
        extractor.reload_model(self.sketch_net)
        sketch_out = extractor.extract(self.sketch_test, batch_size=self.test_bs)

    # Flatten to (photo name, photo feature, sketch name, sketch feature).
    return tuple(
        out[key]
        for out in (photo_out, sketch_out)
        for key in ('name', 'feature')
    )
# The trained model root for vgg SKETCH_VGG = '/data1/zzl/model/caffe2torch/vgg_triplet_loss/sketch/sketch_vgg_190.pth' PHOTO_VGG = '/data1/zzl/model/caffe2torch/vgg_triplet_loss/photo/photo_vgg_190.pth' FINE_TUNE_RESNET = '/data1/zzl/model/caffe2torch/fine_tune/model_270.pth' device = 'cuda:1' '''vgg''' vgg = vgg16(pretrained=False) vgg.classifier[6] = nn.Linear(in_features=4096, out_features=125, bias=True) vgg.load_state_dict(t.load(PHOTO_VGG, map_location=t.device('cpu'))) vgg.cuda() ext = Extractor(pretrained=False) ext.reload_model(vgg) photo_feature = ext.extract_with_dataloader(test_photo_root, 'photo-vgg-190epoch.pkl') vgg.load_state_dict(t.load(SKETCH_VGG, map_location=t.device('cpu'))) ext.reload_model(vgg) sketch_feature = ext.extract_with_dataloader(test_set_root, 'sketch-vgg-190epoch.pkl') '''resnet''' resnet = resnet50() resnet.fc = nn.Linear(in_features=2048, out_features=125) resnet.load_state_dict(t.load(PHOTO_RESNET, map_location=t.device('cpu'))) resnet.cuda()