def predict(self, image=None, text=None):
    """Encode an image and/or a text into the shared embedding space.

    Args:
        image: optional image input accepted by ``get_img_feature``
            (presumably a file path — confirm against data_processing).
        text: optional raw text string.

    Returns:
        ``(img_encode, text_encode)`` as CPU tensors when both inputs are
        given, a single CPU encoding tensor when only one is given, or
        ``None`` when neither is given (implicit fall-through, preserved).
    """
    from data_processing import img_feature_get, text_feature_get

    # Decide the device once instead of re-querying CUDA availability.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        self.text_model.cuda()
        self.img_model.cuda()
    else:
        self.text_model.cpu()
        self.img_model.cpu()

    img_encode = None
    text_encode = None
    # Inference only: no_grad skips building the autograd graph. The
    # returned values are unchanged because outputs were detached anyway.
    with torch.no_grad():
        if image:
            feat = img_feature_get.get_img_feature([image])[0]
            if use_cuda:
                feat = feat.cuda()
            # Decoder reconstructions are not needed for prediction.
            img_encode, _, _ = self.img_model(feat)
        if text:
            feat = text_feature_get.get_text_feature([text])[0]
            if use_cuda:
                feat = feat.cuda()
            text_encode, _, _ = self.text_model(feat)

    if image and text:
        return img_encode.detach().cpu(), text_encode.detach().cpu()
    if image:
        return img_encode.detach().cpu()
    if text:
        return text_encode.detach().cpu()
def united_search(self, get_img=True):
    """Run a cross-modal top-3 search and display the results in the GUI.

    Args:
        get_img: if True, search images (``img_data``) using the loaded
            text file; otherwise search texts (``text_data``) using the
            loaded image file.

    Side effects: closes all current result labels, then shows the top-3
    matches in ``self.img_labels`` or ``self.text_labels``.
    """
    # Clear any previously displayed results.
    for label in self.img_labels:
        label.close()
    for label in self.text_labels:
        label.close()

    global img_data, text_data, _img_data, _text_data
    if get_img and self.text_file:
        # text -> image search (mode=2).
        # Fix: the file handle was previously opened and never closed.
        with open(self.text_file, 'r') as f:
            text = f.read()
        result = self.model.search_top3(
            mode=2,
            search_data=img_data,
            text=text_feature_get.get_text_feature([text])[0])
        for i in range(3):
            self.img_labels[i].setPixmap(
                QtGui.QPixmap(_img_data[result[i]]))
            self.img_labels[i].show()
    elif not get_img and self.img_file:
        # image -> text search (mode=1).
        result = self.model.search_top3(
            mode=1,
            search_data=text_data,
            img=img_feature_get.get_img_feature([self.img_file],
                                                mode=3)[0])
        for i in range(3):
            self.text_labels[i].setText(_text_data[result[i]])
            self.text_labels[i].show()
def run(self):
    """Worker entry point: extract and normalize text features.

    Computes features for the global raw texts ``_text_data``, min-max
    normalizes them to [0, 1], stores the result in the global
    ``text_data``, and decrements the global countdown ``n`` when done.
    """
    global _text_data, text_data
    text_data = text_feature_get.get_text_feature(texts=_text_data)
    mi = text_data.min().numpy()
    ma = text_data.max().numpy()
    # Min-max normalize the text features to [0, 1].
    # Guard against a zero range (all feature values equal), which would
    # otherwise divide by zero and fill the tensor with NaN/inf.
    if ma > mi:
        text_data = (text_data - mi) / (ma - mi)
    else:
        text_data = text_data - mi
    global n
    n -= 1
# Interactive search-loop body. NOTE(review): this fragment sits inside an
# enclosing ``while`` loop (the bare ``break`` below); ``model``,
# ``text_data``, ``img_data`` and ``y`` are defined outside this span —
# presumably ``y`` is the image-feature extraction mode. Confirm upstream.
mode = int(
    input(
        'search mode:1.img2text 2.text2img 3.img_text2text 4.img_text2img 5.exit\n'
    ))
if mode == 5:
    # Exit option: leave the enclosing interactive loop.
    break
data = input("input data:\n")
if mode == 1:
    # image -> text: query text_data with an image feature.
    result = model.search_top3(mode=mode,
                               search_data=text_data,
                               img=img_feature_get.get_img_feature(
                                   [data], mode=y)[0])
elif mode == 2:
    # text -> image: query img_data with a text feature.
    result = model.search_top3(mode=mode,
                               search_data=img_data,
                               text=text_feature_get.get_text_feature(
                                   [data])[0])
elif mode == 3:
    # image + text -> text: a second prompt reads the text half of the query.
    _data = input()
    result = model.search_top3(
        mode=mode,
        search_data=text_data,
        img=img_feature_get.get_img_feature([data], mode=y)[0],
        text=text_feature_get.get_text_feature([_data])[0])
else:
    # image + text -> image (mode 4, and any other unlisted value).
    _data = input()
    result = model.search_top3(
        mode=mode,
        search_data=img_data,
        img=img_feature_get.get_img_feature([data], mode=y)[0],
        text=text_feature_get.get_text_feature([_data])[0])
print("result:")