class SBIRWarpper:
    """Bundle a Caffe network, an ``SMTSApi`` dataset handle and cached image
    features so that a query (``im``) can be ranked via ``sketch_retrieval``.

    All setup happens eagerly in the constructor: the model config for the
    dataset is loaded, the dataset is opened, the network is built on the
    configured GPU, and the feature cache is filled.
    """

    def __init__(self, model_path, dataset):
        """Load config, dataset, network and cached features.

        Args:
            model_path: weights file handed to ``caffe.Net``.
            dataset: dataset name; used both to select the model config and
                to open the dataset.
        """
        # Declare the attributes up front so they exist before any loader runs.
        self.net_ = self.dataset_ = self.image_feats_ = None
        self.config_ = load_model_config(dataset)
        # The feature cache reads both self.net_ and self.dataset_.name,
        # so it is populated last.
        self.load_dataset(dataset)
        self.set_cnn_model(model_path, self.config_.deploy_file)
        self.load_cached_features()

    def load_dataset(self, dataset_name):
        """Open ``dataset_name`` under the configured root into ``self.dataset_``."""
        root = self.config_.DATASET_ROOT
        self.dataset_ = SMTSApi(name=dataset_name, dataset_root=root)

    def load_cached_features(self):
        """Fill ``self.image_feats_`` from ``cache_features`` for the current net/dataset."""
        dataset_name = self.dataset_.name
        self.image_feats_ = cache_features(self.net_, dataset_name)

    def set_cnn_model(self, weight_path, deploy_file):
        """Select the configured GPU, switch Caffe to GPU mode, then build the
        network in TEST phase and store it on ``self.net_``.

        Args:
            weight_path: trained weights file.
            deploy_file: network definition file.
        """
        gpu_id = self.config_.gpu_id
        # Device and mode are configured before the net is constructed.
        caffe.set_device(gpu_id)
        caffe.set_mode_gpu()
        network = caffe.Net(deploy_file, weight_path, caffe.TEST)
        self.net_ = network

    def run_retrieval(self, im):
        """Rank the cached images against the query ``im``.

        ``im`` is presumably a sketch image in whatever format
        ``sketch_retrieval`` expects (TODO confirm).

        Returns:
            Whatever ``self.dataset_.get_image_pathes`` returns for the
            ranking on the 'test' split (by its name, the image paths).
        """
        ranking = sketch_retrieval(self.net_, im, self.image_feats_)
        paths = self.dataset_.get_image_pathes(ranking, 'test')
        return paths
def load_dataset(self, dataset_name):
    """Attach an ``SMTSApi`` for ``dataset_name`` to ``self.dataset_``.

    The dataset root is read from ``self.config_.DATASET_ROOT``.
    """
    # NOTE(review): this is a module-level duplicate of
    # SBIRWarpper.load_dataset that still takes ``self``; it looks like a
    # stray copy — confirm whether anything calls it before removing.
    api_kwargs = {"dataset_root": self.config_.DATASET_ROOT, "name": dataset_name}
    self.dataset_ = SMTSApi(**api_kwargs)
def load_triplets(triplet_path, subset):
    """Load the triplets of ``subset`` from ``triplet_path``.

    Returns:
        A pair ``(triplets, negatives)``, where ``negatives`` is
        ``make_negative_list(triplets)``.
    """
    triplets = SMTSApi(triplet_path).get_triplets(subset)
    negatives = make_negative_list(triplets)
    return triplets, negatives
def load_triplets_bbox(triplet_path, subset):
    """Load the triplets and bounding boxes of ``subset`` from ``triplet_path``.

    Returns:
        A triple ``(triplets, negatives, bbox)``, where ``negatives`` is
        ``make_negative_list_bbox(triplets, subset)``.
    """
    api = SMTSApi(triplet_path)
    # get_triplets_bbox is unpacked into exactly two values.
    result = api.get_triplets_bbox(subset)
    triplets, bbox = result
    negatives = make_negative_list_bbox(triplets, subset)
    return triplets, negatives, bbox