# %% from feature_extractor import FeatureExtractor from normalizer import Normalizer from reader import PSBDataset, DataSet from helper.config import FEATURE_DATA_FILE, DEBUG, DATA_PATH_NORMED_SUBSET, DATA_PATH_NORMED, CLASS_FILE, DATA_PATH_PSB, DATA_PATH_DEBUG import time import pyvista as pv from pprint import pprint # %% if __name__ == "__main__": print("=" * 10 + "Testing full pipeline for mono pipeline" + "=" * 10) descriptor = DataSet._extract_descr("ant.off") data_item = DataSet.mono_run_pipeline(descriptor) normed_data_item = Normalizer.mono_run_pipeline(data_item) features_data_item = FeatureExtractor.mono_run_pipeline(normed_data_item) pprint(features_data_item) plotter = pv.Plotter(shape=(1, 2)) plotter.subplot(0, 0) plotter.add_text("Unnormalized", font_size=30) plotter.add_mesh(pv.PolyData(data_item["data"]["vertices"], data_item["data"]["faces"]), show_edges=True) plotter.show_bounds(all_edges=True) plotter.subplot(0, 1) plotter.add_text("Normalized", font_size=30) plotter.add_mesh(pv.PolyData(normed_data_item["data"]["vertices"], normed_data_item["data"]["faces"]),
class MainWindow(Qt.QMainWindow):
    """Simple mesh viewer window.

    Provides a "Mesh" menu to load a mesh into a side-by-side
    ``BackgroundPlotter``, to show every step of the normalization pipeline,
    and to print the extracted features of a chosen mesh.
    """

    def __init__(self, parent=None, show=True):
        Qt.QMainWindow.__init__(self, parent)
        self.ds = reader.DataSet("")
        self.meshes = []
        self.plotter = BackgroundPlotter(shape=(1, 2), border_color='white', title="MMR Visualization")
        self.setWindowTitle('MMR UI')
        self.frame = Qt.QFrame()
        vlayout = Qt.QVBoxLayout()
        self.normalizer = Normalizer()
        self.frame.setLayout(vlayout)
        self.setCentralWidget(self.frame)

        # File menu
        mainMenu = self.menuBar()
        fileMenu = mainMenu.addMenu('File')
        exitButton = Qt.QAction('Exit', self)
        exitButton.setShortcut('Ctrl+Q')
        exitButton.triggered.connect(self.close)
        fileMenu.addAction(exitButton)

        # Mesh menu
        meshMenu = mainMenu.addMenu('Mesh')
        self.load_mesh = Qt.QAction('Load mesh', self)
        self.load_mesh.triggered.connect(
            lambda: self.add_mesh(self.open_file_name_dialog()))
        meshMenu.addAction(self.load_mesh)

        self.show_norm_pipeline = Qt.QAction('Show norm pipeline', self)
        self.show_norm_pipeline.triggered.connect(
            lambda: self.show_processing(self.open_file_name_dialog()))
        meshMenu.addAction(self.show_norm_pipeline)

        self.extract_features = Qt.QAction('Extract features', self)
        # Routed through print_features so a cancelled dialog (None) is not
        # handed to the feature extractor.
        self.extract_features.triggered.connect(
            lambda: self.print_features(self.open_file_name_dialog()))
        meshMenu.addAction(self.extract_features)

        if show:
            self.show()

    def add_mesh(self, mesh):
        """Render ``mesh["poly_data"]`` and show its extracted features in a table.

        Args:
            mesh: dict as returned by ``DataSet._read``; falsy when nothing was chosen.

        Returns:
            None.
        """
        if not mesh:
            print(f"Can't render object of type {type(mesh)}")
            return None
        self.meshes.append(mesh["poly_data"])
        self.plotter.add_mesh(mesh["poly_data"])
        # The original called ``self.fe``, which is never assigned on this class;
        # use the static FeatureExtractor entry point as the menu action does.
        df = pd.DataFrame.from_dict(FeatureExtractor.mono_run_pipeline(mesh))
        self.tableWidget = TableWidget(df, self)
        self.frame.layout().addWidget(self.tableWidget)
        self.plotter.reset_camera()

    def print_features(self, mesh):
        """Print the features of ``mesh``; do nothing if no mesh was chosen."""
        if not mesh:
            print(f"Can't extract features of object of type {type(mesh)}")
            return None
        print(FeatureExtractor.mono_run_pipeline(mesh))

    def open_file_name_dialog(self):
        """Ask the user for a mesh file and read it.

        Returns:
            The result of ``DataSet._read`` for the chosen file, or None if the
            dialog was cancelled.
        """
        # Qt name filters are space-separated globs ("*.off"), separated by ";;".
        fileName, _ = QFileDialog.getOpenFileName(
            self, caption="Choose shape to view.",
            filter="All Files (*);;Model Files (*.obj *.off *.ply *.stl)")
        if fileName:
            mesh = DataSet._read(fileName)
            return mesh
        return None

    def show_processing(self, mesh):
        """Normalize ``mesh`` and plot every recorded pipeline step in a 2-row grid.

        Args:
            mesh: dict as returned by ``DataSet._read``; falsy when nothing was chosen.

        Returns:
            None.
        """
        if not mesh:
            print(f"Can't render mesh of type {type(mesh)}")
            return None
        new_data = self.normalizer.mono_run_pipeline(mesh)
        history = new_data["history"]
        num_of_operations = len(history)
        # Two rows, enough columns for every step (ceil division, at least one).
        # The original used n // 2 columns but indexed as if there were 3,
        # which only worked for exactly 6 steps.
        n_cols = max(1, (num_of_operations + 1) // 2)
        plt = BackgroundPlotter(shape=(2, n_cols))
        plt.show_axes_all()
        for idx, step in enumerate(history):
            row, col = divmod(idx, n_cols)
            plt.subplot(row, col)
            if step["op"] == "Center":
                # Reference cube drawn only for the centering step.
                plt.add_mesh(pv.Cube().extract_all_edges())
            curr_mesh = pv.PolyData(step["data"]["vertices"], step["data"]["faces"])
            plt.add_mesh(curr_mesh, color='w', show_edges=True)
            plt.reset_camera()
            plt.view_isometric()
            plt.add_text(step["op"] + "\nVertices: " + str(len(curr_mesh.points))
                         + "\nFaces: " + str(curr_mesh.n_faces))
            plt.show_grid()
class MainWindow(Qt.QMainWindow):
    """Source-mesh query window.

    The user loads (or drag-and-drops) a mesh, which is normalized and has its
    features extracted. The mesh is rendered in a ``QtInteractor``, the
    features are shown in a ``TableWidget`` whose buttons plot histograms, and
    a ``SimilarMeshesListWindow`` is opened for the query.
    """

    def __init__(self, parent=None, show=True):
        Qt.QMainWindow.__init__(self, parent)
        with open('config.json') as f:
            self.config_data = json.load(f)
        self.query_matcher = QueryMatcher(self.config_data["FEATURE_DATA_FILE"])
        self.supported_file_types = [".ply", ".off"]
        self.buttons = {}
        self.ds = reader.DataSet("")
        self.meshes = []
        self.normalizer = Normalizer()
        self.smlw = None
        self.setWindowTitle('Source Mesh Window')
        self.frame = Qt.QFrame()
        self.QTIplotter = None
        self.vlayout = Qt.QVBoxLayout()
        self.frame.setLayout(self.vlayout)
        self.setCentralWidget(self.frame)
        self.hist_dict = {}
        self.setAcceptDrops(True)

        # Create main menu
        mainMenu = self.menuBar()
        fileMenu = mainMenu.addMenu('File')
        exitButton = Qt.QAction('Exit', self)
        exitButton.setShortcut('Ctrl+Q')
        exitButton.triggered.connect(self.close)
        fileMenu.addAction(exitButton)

        viewMenu = mainMenu.addMenu('View')
        tsne_action = Qt.QAction('Plot tSNE', self)
        tsne_action.triggered.connect(self.plot_tsne)
        viewMenu.addAction(tsne_action)

        # Create load button and init action
        self.load_button = QPushButton("Load or drop mesh to query")
        self.load_button.clicked.connect(lambda: self.load_and_prep_query_mesh(self.open_file_name_dialog()))
        self.load_button.setFont(QtGui.QFont("arial", 30))

        # Create Plots widget
        self.graphWidget = pg.PlotWidget()
        self.graphWidget.setBackground('w')

        # Create and add widgets to layout
        _, _, mapping_of_labels = get_sizes_features(features_file=self.config_data["FEATURE_DATA_FILE"],
                                                     with_labels=True)
        self.hist_labels = [val for key, val in mapping_of_labels.items() if "hist_" in key]
        self.tableWidget = TableWidget({}, self, {})
        self.tableWidget.hide()
        self.vlayout.addWidget(self.load_button)

        # Position MainWindow: top-left of the available screen area, 40% wide.
        available = QDesktopWidget().availableGeometry()
        self.move(available.topLeft())
        # QWidget.resize takes ints; passing the float product raises TypeError
        # on recent PyQt5 / Python versions.
        self.resize(int(available.width() * 0.4), available.height() - 50)
        if show:
            self.show()

    def dragEnterEvent(self, e):
        """Accept drags that carry URLs (files); ignore everything else."""
        if e.mimeData().hasUrls():
            e.accept()
        else:
            e.ignore()

    def dropEvent(self, e):
        """Load a single dropped file as the query mesh after validating its type."""
        urls = e.mimeData().urls()
        if len(urls) == 1:
            file = QUrl.toLocalFile(urls[0])
            # Dropped files previously bypassed the extension check.
            if self.check_file(file):
                self.load_and_prep_query_mesh(DataSet._read(file))
        else:
            error_dialog = QtWidgets.QErrorMessage(parent=self)
            error_dialog.showMessage("Please drag only one mesh at the time.")

    def check_file(self, fileName):
        """Return True if ``fileName`` has a supported extension.

        Otherwise show an error dialog and return False.
        """
        if not fileName.endswith(tuple(self.supported_file_types)):
            error_dialog = QtWidgets.QErrorMessage(parent=self)
            error_dialog.showMessage(("Selected file not supported."
                                      f"\nPlease select mesh files of type: {self.supported_file_types}"))
            return False
        return True

    def open_file_name_dialog(self):
        """Ask the user for a mesh file.

        Returns:
            The result of ``DataSet._read`` for a supported file, or False if
            the dialog was cancelled or the file type is unsupported.
        """
        # Qt name filters are space-separated globs, e.g. "Model Files (*.ply *.off)".
        model_globs = " ".join("*" + ext for ext in self.supported_file_types)
        fileName, _ = QFileDialog.getOpenFileName(self, caption="Choose shape to view.",
                                                  filter=f"All Files (*);;Model Files ({model_globs})")
        # A cancelled dialog yields "": bail out silently. The original
        # `not (fileName or check)` never validated real files and showed an
        # error dialog on cancel instead.
        if not fileName or not self.check_file(fileName):
            return False
        mesh = DataSet._read(fileName)
        return mesh

    def load_and_prep_query_mesh(self, data):
        """Normalize ``data``, extract its features and refresh all query widgets.

        Args:
            data: dict as returned by ``DataSet._read``; falsy values are ignored.
        """
        if not data:
            return
        self.load_button.setFont(QtGui.QFont("arial", 10))

        # Normalize query mesh
        normed_data = self.normalizer.mono_run_pipeline(data)
        last_step = normed_data["history"][-1]["data"]
        normed_mesh = pv.PolyData(last_step["vertices"], last_step["faces"])
        normed_data['poly_data'] = normed_mesh

        # Extract features
        _, _, mapping_of_labels = get_sizes_features(features_file=self.config_data["FEATURE_DATA_FILE"],
                                                     with_labels=True, drop_feat=["timestamp"])
        mapping_of_labels_reversed = {val: key for key, val in mapping_of_labels.items()}
        features_dict = FeatureExtractor.mono_run_pipeline_old(normed_data)
        # Keep only known features, keyed by their display label, sorted by label.
        features_dict_carefully_selected = OrderedDict(
            sorted({mapping_of_labels[key]: val for key, val in features_dict.items()
                    if key in mapping_of_labels}.items(),
                   key=lambda t: t[0]))
        features_df = pd.DataFrame([features_dict_carefully_selected]).T.reset_index()
        self.hist_labels = [val for key, val in mapping_of_labels.items() if "hist_" in key]
        self.skeleton_labels = [val for key, val in mapping_of_labels.items() if "skeleton_" in key]

        # Update plotter & feature table.
        # The QtInteractor cannot be updated in place (remove + add mesh), so it
        # is discarded and recreated every time a mesh is loaded.
        self.tableWidget.deleteLater()
        if self.QTIplotter is not None:
            self.vlayout.removeWidget(self.QTIplotter)
            # removeWidget only detaches it from the layout; without closing it the
            # old interactor stays a child of self.frame and keeps its render window.
            self.QTIplotter.close()
            self.QTIplotter.deleteLater()
        self.QTIplotter = QtInteractor(self.frame)
        self.vlayout.addWidget(self.QTIplotter)
        self.tableWidget = TableWidget(features_dict_carefully_selected, self, mapping_of_labels_reversed)
        self.tableWidget.show()
        self.tableWidget.horizontalHeader().setSectionResizeMode(QtWidgets.QHeaderView.Stretch)
        self.vlayout.addWidget(self.tableWidget)
        self.QTIplotter.add_mesh(normed_mesh, show_edges=True)
        self.QTIplotter.isometric_view()
        self.QTIplotter.show_bounds(grid='front', location='outer', all_edges=True)
        self.vlayout.addWidget(self.graphWidget)

        # Compare shapes: replace the previous similar-meshes window, if any.
        if self.smlw:
            self.smlw.deleteLater()
            if len(self.smlw.smw_list) != 0:
                self.smlw.smw_list[0].deleteLater()
        self.smlw = SimilarMeshesListWindow(features_dict)
        self.buttons = self.tableWidget.get_buttons_in_table()
        self.hist_dict = features_df.set_index("index").tail(n=len(self.hist_labels)).to_dict()
        for key, value in self.buttons.items():
            # Default args bind the current key/data; a plain closure would see
            # only the last loop values.
            value.clicked.connect(lambda state, x=key, y=features_dict_carefully_selected[key]:
                                  self.plot_selected_hist(x, y))
        self.smlw.show()

    def plot_selected_hist(self, hist_title, hist_data):
        """Plot ``hist_data`` (one value per bin) as a red line titled ``hist_title``."""
        self.graphWidget.clear()
        if len(hist_data) == 0:
            # min()/max() below would raise on an empty histogram.
            return
        styles = {"color": "#f00", "font-size": "15px"}
        # Pen styles live on the Qt namespace; QtCore itself has no SolidLine attribute.
        pen = pg.mkPen(color=(255, 0, 0), width=5, style=QtCore.Qt.SolidLine)
        self.graphWidget.setTitle(hist_title, color="b", size="15pt")
        self.graphWidget.setLabel("left", "Values", **styles)
        self.graphWidget.setLabel("bottom", "Bins", **styles)
        self.graphWidget.addLegend()
        self.graphWidget.showGrid(x=True, y=True)
        bins = np.arange(0, len(hist_data))
        # Bins are plotted at 0..n-1, so the range must start at 0 (was 1).
        self.graphWidget.setXRange(0, len(hist_data) - 1)
        self.graphWidget.setYRange(min(hist_data), max(hist_data))
        self.graphWidget.plot(bins, hist_data, pen=pen)

    def plot_tsne(self):
        """Show a t-SNE plot of the query matcher's feature database."""
        labels = [dic["label"].replace("_", " ").title() for dic in self.query_matcher.features_raw]
        filename = "tsne_visualizer"
        tsne_plotter = TsneVisualiser(
            self.query_matcher,
            labels,  # labels_coarse,
            filename,
            False,
            False)
        tsne_plotter.plot()
def _find_mesh_path(paths, mesh_name):
    """Return the path whose file name (with or without extension) equals mesh_name.

    Falls back to the first path that merely contains mesh_name, matching the
    original lookup. Raises IndexError if nothing matches. An exact match is
    tried first because a bare substring test lets "m1" match ".../m12.off".
    """
    for p in paths:
        base = os.path.basename(p)
        if base == mesh_name or os.path.splitext(base)[0] == mesh_name:
            return p
    return [p for p in paths if mesh_name in p][0]


def plot_comparison(sample_labels, distance):
    """Query one sample mesh per class and render it next to its retrieved matches.

    Args:
        sample_labels: class labels; for each, the first mesh with that label in
            the feature database is used as the query.
        distance: "knn" to match with ``QueryMatcher.perform_knn``; anything
            else uses the custom cosine/Wasserstein weighted pipeline.

    Side effects:
        Prints progress and writes an off-screen screenshot to
        ``fig/comparison_knn.jpg`` or ``figs/comparison_custom_distance.jpg``.
    """
    qm = QueryMatcher(FEATURE_DATA_FILE)
    # Count class instances once instead of rebuilding a DataFrame per label.
    label_counts = Counter(pd.DataFrame(qm.features_flattened)["label"])
    labelled_occurences = tuple(zip(sample_labels, [label_counts.get(lbl) for lbl in sample_labels]))
    names = [[f for f in qm.features_raw if f["label"] == lbl][0]["name"] for lbl in sample_labels]
    sampled_labelled = dict(zip(labelled_occurences, names))

    paths = []
    for path, subdirs, files in os.walk(DATA_PATH_PSB):
        for name in files:
            # ("off" or "ply") evaluates to just "off", so .ply files were never collected.
            if name.lower().endswith((".off", ".ply")):
                paths.append(os.path.join(path, name))

    n_singletons, n_distributionals, mapping_of_labels = get_sizes_features(with_labels=True)
    n_hist = sum(1 for key in mapping_of_labels if "hist_" in key)
    n_skeleton = sum(1 for key in mapping_of_labels if "skeleton_" in key)

    if distance != "knn":
        # Custom: cosine on the first (singleton) block, Wasserstein on each distribution.
        weights = [3] + [100] * n_hist + [1] * n_skeleton
        function_pipeline = [cosine] + [wasserstein_distance] * n_hist + [wasserstein_distance] * n_skeleton
    else:
        # KNN
        weights = [1] + [1] * n_hist + [1] * n_skeleton
        function_pipeline = [QueryMatcher.perform_knn] + [QueryMatcher.perform_knn] * n_distributionals

    normalizer = Normalizer()
    out_dict = defaultdict(list)
    for info_tuple, mesh_idx in sampled_labelled.items():
        full_path = _find_mesh_path(paths, mesh_idx)
        print(f"Processing: {full_path}")
        mesh = DataSet._read(Path(full_path))
        normed_data = normalizer.mono_run_pipeline(mesh)
        last_step = normed_data["history"][-1]["data"]
        normed_mesh = pv.PolyData(last_step["vertices"], last_step["faces"])
        normed_data['poly_data'] = normed_mesh
        features_dict = FeatureExtractor.mono_run_pipeline_old(normed_data)
        indices, distance_values, _ = qm.match_with_db(features_dict, k=10,
                                                      distance_functions=function_pipeline,
                                                      weights=weights)
        if mesh_idx in indices:
            # Drop the query's own hit; a 0 distance is inserted for the query instead.
            idx_of_idx = indices.index(mesh_idx)
            indices.remove(mesh_idx)
            del distance_values[idx_of_idx]
            distance_values.insert(0, 0)
        # NOTE(review): indices are sliced from 4 and distances from 5, so the
        # distance printed under each match is not obviously the one belonging to
        # it (and the query's inserted 0 is sliced away). Kept as-is because the
        # intended selection is unclear; confirm before relying on the labels.
        indices = indices[4:]
        indices.insert(0, mesh_idx)
        distance_values = distance_values[5:]
        out_dict[info_tuple].append({mesh_idx: (indices, distance_values)})
    print(out_dict)

    # (6, 5) stays the minimum layout; grow it instead of failing when there
    # are more classes or more displayed matches than that.
    max_cols = max((min(len(ids), len(dists))
                    for entries in out_dict.values()
                    for entry in entries
                    for ids, dists in entry.values()), default=0)
    plotter = pv.Plotter(off_screen=True, shape=(max(6, len(out_dict)), max(5, max_cols)))
    for class_idx, (key, val) in enumerate(out_dict.items()):
        print(class_idx)
        class_name = key[0].replace('_', ' ').title()
        for v in val:
            shown_names, shown_dists = list(v.values())[0]
            for el_idx, (name, dist) in enumerate(zip(shown_names, list(shown_dists))):
                print(el_idx)
                plotter.subplot(class_idx, el_idx)
                full_path = _find_mesh_path(paths, name)
                mesh = DataSet._read(Path(full_path))
                curr_mesh = pv.PolyData(mesh["data"]["vertices"], mesh["data"]["faces"])
                plotter.add_mesh(curr_mesh, color='r')
                plotter.reset_camera()
                plotter.view_isometric()
                if el_idx != 0:
                    plotter.add_text(f"{el_idx} - Dist: {round(dist,4)}", font_size=20)
                elif class_idx == 0:
                    plotter.add_text(f" Query\nClass: {class_name}" + f"\nInstances: {key[1]}", font_size=20)
                else:
                    plotter.add_text(f"Class: {class_name}" + f"\nInstances: {key[1]}", font_size=20)

    # NOTE(review): the two outputs go to different directories ("fig" vs
    # "figs"); one is likely a typo. Both must exist for the screenshot to succeed.
    if distance == "knn":
        plotter.screenshot(os.path.join("fig", "comparison_knn.jpg"), window_size=(1920, 2160))
    else:
        plotter.screenshot(os.path.join("figs", "comparison_custom_distance.jpg"), window_size=(1920, 2160))