# %%
from feature_extractor import FeatureExtractor
from normalizer import Normalizer
from reader import PSBDataset, DataSet
from helper.config import FEATURE_DATA_FILE, DEBUG, DATA_PATH_NORMED_SUBSET, DATA_PATH_NORMED, CLASS_FILE, DATA_PATH_PSB, DATA_PATH_DEBUG
import time
import pyvista as pv
from pprint import pprint

# %%
if __name__ == "__main__":

    print("=" * 10 + "Testing full pipeline for mono pipeline" + "=" * 10)
    descriptor = DataSet._extract_descr("ant.off")
    data_item = DataSet.mono_run_pipeline(descriptor)
    normed_data_item = Normalizer.mono_run_pipeline(data_item)
    features_data_item = FeatureExtractor.mono_run_pipeline(normed_data_item)

    pprint(features_data_item)

    plotter = pv.Plotter(shape=(1, 2))
    plotter.subplot(0, 0)
    plotter.add_text("Unnormalized", font_size=30)
    plotter.add_mesh(pv.PolyData(data_item["data"]["vertices"],
                                 data_item["data"]["faces"]),
                     show_edges=True)
    plotter.show_bounds(all_edges=True)
    plotter.subplot(0, 1)
    plotter.add_text("Normalized", font_size=30)
    plotter.add_mesh(pv.PolyData(normed_data_item["data"]["vertices"],
                                 normed_data_item["data"]["faces"]),
# Example #2
# 0
class MainWindow(Qt.QMainWindow):
    """Main window for loading meshes and inspecting the normalisation pipeline.

    Meshes are rendered in a separate ``BackgroundPlotter`` window. The "Mesh"
    menu offers three actions, each of which first asks the user for a file:

    * Load mesh: add the mesh to the plotter and show its features in a table.
    * Show norm pipeline: open a grid with one viewport per normalisation step.
    * Extract features: print the extracted features to stdout.
    """

    def __init__(self, parent=None, show=True):
        super().__init__(parent)
        self.ds = reader.DataSet("")
        self.meshes = []
        self.plotter = BackgroundPlotter(shape=(1, 2),
                                         border_color='white',
                                         title="MMR Visualization")
        self.setWindowTitle('MMR UI')
        self.frame = Qt.QFrame()
        vlayout = Qt.QVBoxLayout()
        self.normalizer = Normalizer()
        self.frame.setLayout(vlayout)
        self.setCentralWidget(self.frame)

        main_menu = self.menuBar()
        file_menu = main_menu.addMenu('File')
        exit_action = Qt.QAction('Exit', self)
        exit_action.setShortcut('Ctrl+Q')
        exit_action.triggered.connect(self.close)
        file_menu.addAction(exit_action)
        mesh_menu = main_menu.addMenu('Mesh')

        self.load_mesh = Qt.QAction('Load mesh', self)
        self.load_mesh.triggered.connect(
            lambda: self.add_mesh(self.open_file_name_dialog()))
        mesh_menu.addAction(self.load_mesh)

        self.show_norm_pipeline = Qt.QAction('Show norm pipeline', self)
        self.show_norm_pipeline.triggered.connect(
            lambda: self.show_processing(self.open_file_name_dialog()))
        mesh_menu.addAction(self.show_norm_pipeline)

        self.extract_features = Qt.QAction('Extract features', self)
        # Route through a guard so a cancelled dialog (None) is not fed to
        # the feature pipeline.
        self.extract_features.triggered.connect(
            lambda: self._print_features(self.open_file_name_dialog()))
        mesh_menu.addAction(self.extract_features)

        if show:
            self.show()

    def add_mesh(self, mesh):
        """Render *mesh* and add a table with its extracted features.

        Args:
            mesh: dict produced by ``DataSet._read`` (must contain
                ``"poly_data"``), or a falsy value when no file was chosen.
        """
        if not mesh:
            print(f"Can't render object of type {type(mesh)}")
            return None

        self.meshes.append(mesh["poly_data"])
        self.plotter.add_mesh(mesh["poly_data"])
        # The feature extractor is used through its class-level pipeline,
        # exactly like the "Extract features" action does.
        features = FeatureExtractor.mono_run_pipeline(mesh)
        df = pd.DataFrame.from_dict(features)
        self.tableWidget = TableWidget(df, self)
        self.frame.layout().addWidget(self.tableWidget)
        self.plotter.reset_camera()

    def _print_features(self, mesh):
        """Print the features of *mesh*; do nothing useful if no file was chosen."""
        if not mesh:
            print(f"Can't extract features of object of type {type(mesh)}")
            return None
        print(FeatureExtractor.mono_run_pipeline(mesh))

    def open_file_name_dialog(self):
        """Ask the user for a mesh file and read it.

        Returns:
            The result of ``DataSet._read`` for the chosen file, or ``None``
            when the dialog is cancelled.
        """
        # Qt name filters use wildcard patterns separated by spaces, and
        # filter groups separated by ";;".
        fileName, _ = QFileDialog.getOpenFileName(
            self,
            caption="Choose shape to view.",
            filter="All Files (*);;Model Files (*.obj *.off *.ply *.stl)")
        if not fileName:
            return None
        return DataSet._read(fileName)

    def show_processing(self, mesh):
        """Open a window showing every step of the normalisation pipeline.

        Each entry of the pipeline's ``history`` is drawn in its own viewport
        of a two-row grid, labelled with the operation name and the vertex and
        face counts. The "Center" step additionally shows a unit cube outline.
        """
        if not mesh:
            print(f"Can't render mesh of type {type(mesh)}")
            return None

        history = self.normalizer.mono_run_pipeline(mesh)["history"]
        if not history:
            print("Normalization pipeline produced no steps to show.")
            return None

        # Size the grid from the number of steps and index it the same way,
        # so every step gets a viewport (two rows, as many columns as needed).
        n_cols = (len(history) + 1) // 2
        plotter = BackgroundPlotter(shape=(2, n_cols))
        plotter.show_axes_all()
        for idx, step in enumerate(history):
            plotter.subplot(*divmod(idx, n_cols))
            if step["op"] == "Center":
                plotter.add_mesh(pv.Cube().extract_all_edges())
            curr_mesh = pv.PolyData(step["data"]["vertices"],
                                    step["data"]["faces"])
            plotter.add_mesh(curr_mesh, color='w', show_edges=True)
            plotter.reset_camera()
            plotter.view_isometric()
            plotter.add_text(step["op"] + "\nVertices: " +
                             str(len(curr_mesh.points)) + "\nFaces: " +
                             str(curr_mesh.n_faces))
            plotter.show_grid()
# Example #3
# 0
class MainWindow(Qt.QMainWindow):
    """Window for querying the shape database with a single mesh.

    A mesh is loaded through the button / file dialog or dropped onto the
    window. It is normalised, its features are extracted and shown in a table
    next to a 3D view, and a ``SimilarMeshesListWindow`` is opened for the
    query features. Histogram features can be plotted by clicking their
    buttons in the table.
    """

    def __init__(self, parent=None, show=True):
        super().__init__(parent)
        with open('config.json') as f:
            self.config_data = json.load(f)
        self.query_matcher = QueryMatcher(self.config_data["FEATURE_DATA_FILE"])
        self.supported_file_types = [".ply", ".off"]
        self.buttons = {}
        self.ds = reader.DataSet("")
        self.meshes = []
        self.normalizer = Normalizer()
        self.smlw = None
        self.setWindowTitle('Source Mesh Window')
        self.frame = Qt.QFrame()
        self.QTIplotter = None  # created lazily on the first loaded mesh
        self.vlayout = Qt.QVBoxLayout()
        self.frame.setLayout(self.vlayout)
        self.setCentralWidget(self.frame)
        self.hist_dict = {}
        self.setAcceptDrops(True)

        # Create main menu
        main_menu = self.menuBar()
        file_menu = main_menu.addMenu('File')
        exit_action = Qt.QAction('Exit', self)
        exit_action.setShortcut('Ctrl+Q')
        exit_action.triggered.connect(self.close)
        file_menu.addAction(exit_action)

        view_menu = main_menu.addMenu('View')
        tsne_action = Qt.QAction('Plot tSNE', self)
        tsne_action.triggered.connect(self.plot_tsne)
        view_menu.addAction(tsne_action)

        # Create load button and init action
        self.load_button = QPushButton("Load or drop mesh to query")
        self.load_button.clicked.connect(lambda: self.load_and_prep_query_mesh(self.open_file_name_dialog()))
        self.load_button.setFont(QtGui.QFont("arial", 30))
        # Create Plots widget
        self.graphWidget = pg.PlotWidget()
        self.graphWidget.setBackground('w')

        # Create and add widgets to layout
        _, _, mapping_of_labels = get_sizes_features(features_file=self.config_data["FEATURE_DATA_FILE"], with_labels=True)
        self.hist_labels = [val for key, val in mapping_of_labels.items() if "hist_" in key]
        self.tableWidget = TableWidget({}, self, {})
        self.tableWidget.hide()
        self.vlayout.addWidget(self.load_button)

        # Position MainWindow on the left 40% of the available screen area.
        geometry = QDesktopWidget().availableGeometry()
        self.move(geometry.topLeft())
        # QWidget.resize takes ints; a float width is rejected by newer PyQt5.
        self.resize(int(geometry.width() * 0.4), geometry.height() - 50)

        if show:
            self.show()

    def dragEnterEvent(self, e):
        """Accept drags that carry URLs (files); ignore everything else."""
        if e.mimeData().hasUrls():
            e.accept()
        else:
            e.ignore()

    def dropEvent(self, e):
        """Load a single dropped mesh file as the query, if its type is supported."""
        urls = e.mimeData().urls()
        if len(urls) != 1:
            error_dialog = QtWidgets.QErrorMessage(parent=self)
            error_dialog.showMessage("Please drag only one mesh at the time.")
            return
        file = urls[0].toLocalFile()
        if self.check_file(file):
            self.load_and_prep_query_mesh(DataSet._read(file))

    def check_file(self, fileName):
        """Return True if *fileName* has a supported extension.

        Otherwise an error dialog listing the supported types is shown and
        False is returned.
        """
        if fileName.endswith(tuple(self.supported_file_types)):
            return True
        error_dialog = QtWidgets.QErrorMessage(parent=self)
        error_dialog.showMessage(("Selected file not supported." f"\nPlease select mesh files of type: {self.supported_file_types}"))
        return False

    def open_file_name_dialog(self):
        """Ask the user for a mesh file and read it.

        Returns:
            The result of ``DataSet._read`` for the chosen file, or ``False``
            when the dialog is cancelled or the file type is unsupported.
        """
        # Qt name filters use space-separated wildcard patterns; ";;" separates groups.
        patterns = " ".join(f"*{ext}" for ext in self.supported_file_types)
        fileName, _ = QFileDialog.getOpenFileName(
            self, caption="Choose shape to view.",
            filter=f"Model Files ({patterns});;All Files (*)")
        # A cancelled dialog returns "" and must not trigger the type error.
        if not fileName or not self.check_file(fileName):
            return False
        return DataSet._read(fileName)

    def load_and_prep_query_mesh(self, data):
        """Normalise *data*, show it with its features, and open the similar-mesh list.

        Args:
            data: mesh dict as produced by ``DataSet._read``; falsy values are ignored.
        """
        if not data:
            return
        self.load_button.setFont(QtGui.QFont("arial", 10))

        # Normalize query mesh; the last history entry holds the final result.
        normed_data = self.normalizer.mono_run_pipeline(data)
        final_step = normed_data["history"][-1]["data"]
        normed_mesh = pv.PolyData(final_step["vertices"], final_step["faces"])
        normed_data['poly_data'] = normed_mesh

        # Extract features and keep only those that have a display label,
        # sorted by that label.
        _, _, mapping_of_labels = get_sizes_features(features_file=self.config_data["FEATURE_DATA_FILE"], with_labels=True, drop_feat=["timestamp"])
        mapping_of_labels_reversed = {val: key for key, val in mapping_of_labels.items()}
        features_dict = FeatureExtractor.mono_run_pipeline_old(normed_data)
        labelled_features = {mapping_of_labels[key]: val
                             for key, val in features_dict.items() if key in mapping_of_labels}
        features_dict_carefully_selected = OrderedDict(sorted(labelled_features.items(), key=lambda t: t[0]))
        features_df = pd.DataFrame([features_dict_carefully_selected]).T.reset_index()
        self.hist_labels = [val for key, val in mapping_of_labels.items() if "hist_" in key]
        self.skeleton_labels = [val for key, val in mapping_of_labels.items() if "skeleton_" in key]

        # Update plotter & feature table.
        # QtInteractor cannot swap its mesh in place, so the interactor (and
        # the table) are destroyed and rebuilt for every loaded mesh.
        self.tableWidget.deleteLater()
        if self.QTIplotter is not None:
            self.vlayout.removeWidget(self.QTIplotter)
            # removeWidget only detaches; delete so the old view does not linger.
            self.QTIplotter.deleteLater()
        self.QTIplotter = QtInteractor(self.frame)
        self.vlayout.addWidget(self.QTIplotter)
        self.tableWidget = TableWidget(features_dict_carefully_selected, self, mapping_of_labels_reversed)
        self.tableWidget.show()
        self.tableWidget.horizontalHeader().setSectionResizeMode(QtWidgets.QHeaderView.Stretch)
        self.vlayout.addWidget(self.tableWidget)
        self.QTIplotter.add_mesh(normed_mesh, show_edges=True)
        self.QTIplotter.isometric_view()
        self.QTIplotter.show_bounds(grid='front', location='outer', all_edges=True)
        # Move the histogram plot below the new widgets without adding it to
        # the layout a second time (removeWidget is a no-op on first load).
        self.vlayout.removeWidget(self.graphWidget)
        self.vlayout.addWidget(self.graphWidget)

        # Compare shapes
        if self.smlw:
            self.smlw.deleteLater()
            if len(self.smlw.smw_list) != 0:
                self.smlw.smw_list[0].deleteLater()
        self.smlw = SimilarMeshesListWindow(features_dict)

        self.buttons = self.tableWidget.get_buttons_in_table()
        self.hist_dict = features_df.set_index("index").tail(n=len(self.hist_labels)).to_dict()
        for key, button in self.buttons.items():
            # Bind key/data as defaults so each button plots its own histogram.
            button.clicked.connect(lambda state, x=key, y=features_dict_carefully_selected[key]: self.plot_selected_hist(x, y))
        self.smlw.show()

    def plot_selected_hist(self, hist_title, hist_data):
        """Plot *hist_data* (one value per bin) as a red line titled *hist_title*."""
        self.graphWidget.clear()
        if len(hist_data) == 0:
            # Nothing to draw; min()/max() below would fail on an empty sequence.
            return
        styles = {"color": "#f00", "font-size": "15px"}
        pen = pg.mkPen(color=(255, 0, 0), width=5, style=QtCore.Qt.SolidLine)
        self.graphWidget.setTitle(hist_title, color="b", size="15pt")
        self.graphWidget.setLabel("left", "Values", **styles)
        self.graphWidget.setLabel("bottom", "Bins", **styles)
        self.graphWidget.addLegend()
        self.graphWidget.showGrid(x=True, y=True)
        # Bins are plotted at x = 0 .. len-1, so the view must span exactly that.
        self.graphWidget.setXRange(0, len(hist_data) - 1)
        self.graphWidget.setYRange(min(hist_data), max(hist_data))
        self.graphWidget.plot(np.arange(len(hist_data)), hist_data, pen=pen)

    def plot_tsne(self):
        """Show a t-SNE plot of the feature database, labelled by title-cased class."""
        labels = [dic["label"].replace("_", " ").title() for dic in self.query_matcher.features_raw]
        filename = "tsne_visualizer"

        tsne_plotter = TsneVisualiser(
            self.query_matcher,
            labels,  # labels_coarse,
            filename,
            False,
            False)
        tsne_plotter.plot()
# Example #4
# 0
def _find_mesh_path(paths, mesh_name):
    """Return the entry of *paths* that belongs to the mesh called *mesh_name*.

    An exact match on the file name (with or without extension) is preferred,
    so that e.g. ``m1`` does not resolve to ``m100.off``. Failing that, the
    first path containing *mesh_name* as a substring is returned.

    Raises:
        FileNotFoundError: if no path matches.
    """
    mesh_name = str(mesh_name)
    for path in paths:
        file_name = os.path.basename(path)
        if mesh_name in (file_name, os.path.splitext(file_name)[0]):
            return path
    for path in paths:
        if mesh_name in path:
            return path
    raise FileNotFoundError(f"No mesh file found for {mesh_name!r}")


def _select_neighbours(query_name, indices, distance_values, skip=4, n_shown=4):
    """Build one comparison row: the query followed by *n_shown* matches.

    The query is removed from the match results (if present), so it does not
    appear as its own neighbour. The first *skip* remaining matches are then
    skipped. The returned lists are aligned: ``distances[i]`` belongs to
    ``names[i]``, and the query gets distance ``0``.
    """
    names = list(indices)
    dists = list(distance_values)
    if query_name in names:
        pos = names.index(query_name)
        del names[pos]
        del dists[pos]
    window = slice(skip, skip + n_shown)
    return [query_name] + names[window], [0] + dists[window]


def plot_comparison(sample_labels, distance):
    """Render a grid comparing query meshes with their database matches.

    For every label in *sample_labels*, the first database mesh of that label
    is read, normalised, featurised and matched against the database
    (``k=10``). Each row of the off-screen plot shows the query followed by
    four matches (positions 5-8 of the returned list after the query itself
    is removed), annotated with their distances. The image is written to
    ``figs/comparison_knn.jpg`` or ``figs/comparison_custom_distance.jpg``.

    Args:
        sample_labels: class labels; one query mesh is taken per label.
        distance: ``"knn"`` selects the KNN matcher with uniform weights; any
            other value selects cosine + Wasserstein with custom weights.
    """
    qm = QueryMatcher(FEATURE_DATA_FILE)

    # Number of database instances per label, shown in the row header.
    label_counts = Counter(pd.DataFrame(qm.features_flattened)["label"])
    labelled_occurences = tuple(
        (lbl, label_counts.get(lbl)) for lbl in sample_labels)
    # The first database mesh of each label serves as that row's query.
    names = [[f for f in qm.features_raw if f["label"] == lbl][0]["name"]
             for lbl in sample_labels]
    sampled_labelled = dict(zip(labelled_occurences, names))

    # Every OFF/PLY file in the PSB dataset, used to resolve mesh names.
    paths = [
        os.path.join(root, file_name)
        for root, _, files in os.walk(DATA_PATH_PSB)
        for file_name in files
        if file_name.endswith((".off", ".ply"))
    ]

    _, n_distributionals, mapping_of_labels = get_sizes_features(
        with_labels=True)

    n_hist = len(
        [key for key in mapping_of_labels if "hist_" in key])
    n_skeleton = len(
        [key for key in mapping_of_labels if "skeleton_" in key])

    if distance != "knn":
        # Custom
        weights = ([3]) + \
                  ([100] * n_hist) + \
                  ([1] * n_skeleton)

        function_pipeline = [cosine] + \
                            ([wasserstein_distance] * n_hist) + \
                            ([wasserstein_distance] * n_skeleton)
    else:
        # KNN
        weights = ([1]) + \
                  ([1] * n_hist) + \
                  ([1] * n_skeleton)

        function_pipeline = [QueryMatcher.perform_knn] + (
            [QueryMatcher.perform_knn] * n_distributionals)

    n_shown = 4  # matches per row; the grid has n_shown + 1 columns
    normalizer = Normalizer()
    out_dict = defaultdict(list)
    for info_tuple, mesh_idx in sampled_labelled.items():
        full_path = _find_mesh_path(paths, mesh_idx)
        print(f"Processing: {full_path}")
        mesh = DataSet._read(Path(full_path))
        normed_data = normalizer.mono_run_pipeline(mesh)
        final_step = normed_data["history"][-1]["data"]
        normed_data['poly_data'] = pv.PolyData(final_step["vertices"],
                                               final_step["faces"])

        features_dict = FeatureExtractor.mono_run_pipeline_old(normed_data)

        indices, distance_values, _ = qm.match_with_db(
            features_dict,
            k=10,
            distance_functions=function_pipeline,
            weights=weights)
        out_dict[info_tuple].append({
            mesh_idx: _select_neighbours(mesh_idx, indices, distance_values,
                                         skip=4, n_shown=n_shown)
        })

    plotter = pv.Plotter(off_screen=True,
                         shape=(max(1, len(out_dict)), n_shown + 1))
    for class_idx, (info_tuple, results) in enumerate(out_dict.items()):
        class_name = info_tuple[0].replace('_', ' ').title()
        for result in results:
            row_names, row_distances = next(iter(result.values()))
            for el_idx, (name, dist) in enumerate(zip(row_names, row_distances)):
                plotter.subplot(class_idx, el_idx)

                mesh = DataSet._read(Path(_find_mesh_path(paths, name)))
                curr_mesh = pv.PolyData(mesh["data"]["vertices"],
                                        mesh["data"]["faces"])
                plotter.add_mesh(curr_mesh, color='r')
                plotter.reset_camera()
                plotter.view_isometric()
                if el_idx != 0:
                    plotter.add_text(f"{el_idx} - Dist: {round(dist,4)}",
                                     font_size=20)
                elif class_idx == 0:
                    plotter.add_text(
                        f"             Query\nClass: {class_name}"
                        + f"\nInstances: {info_tuple[1]}",
                        font_size=20)
                else:
                    plotter.add_text(f"Class: {class_name}" +
                                     f"\nInstances: {info_tuple[1]}",
                                     font_size=20)

    # Both variants go to the same, portable, pre-created output directory.
    out_name = ("comparison_knn.jpg" if distance == "knn"
                else "comparison_custom_distance.jpg")
    os.makedirs("figs", exist_ok=True)
    plotter.screenshot(os.path.join("figs", out_name),
                       window_size=(1920, 2160))