Example #1
0
    def __init__(self, mat_output_node: GMaterialOutputNode):
        """Build the loss-visualizer window for ``mat_output_node``.

        Widgets and state holders are created here; ``_init`` configures and
        lays out the widgets and ``_list_parameters`` fills the parameter
        table.
        """
        super().__init__(parent=None)
        self._mat_out_node = mat_output_node

        # --- Widgets ---
        self._layout = QGridLayout()
        self._splitter = QSplitter(Qt.Horizontal)
        self._plot = PlotWidget3D()
        # Canvas and axis are both taken from the 3D plot widget.
        self._canvas, self._fig_ax = (self._plot.get_canvas(),
                                      self._plot.get_axis())
        self._table_widget = QTableWidget(self)
        self._p1_res, self._p2_res = IntInput(0, 100), IntInput(0, 100)
        self._plot_button = QPushButton("Plot Loss")
        self._start_gd_button = QPushButton("Start Gradient Descent")

        # --- State ---
        # Holds the currently checked parameter items (capacity 2).
        self._item_queue = FIFOQueue(maxsize=2)
        self._bg_brush_selected = QBrush(QColor("#8bf9b0"))
        self._bg_brush_default = QBrush(QColor("#ffffff"))
        # Filled in later, once settings / target / run objects exist.
        self._settings = self._target_image = self._target_matrix = None
        self._out_node = self._thread = self._gd = None
        self._p1 = self._p2 = None
        self._hist_p1, self._hist_p2 = [], []
        self._plot_line3d = self._progress_dialog = None

        self._init()
        self._list_parameters()
    def __init__(self, cc: ControlCenter, *args):
        """Create the Python renderer widget bound to control center ``cc``.

        Any extra positional ``args`` are forwarded to the Qt base class.
        """
        super().__init__(*args)
        self.cc = cc

        # GUI components
        self._image_plot = ImagePlotter()
        self._layout = QVBoxLayout()
        self._width_input, self._height_input = IntInput(1, 500), IntInput(1, 500)
        self._resize_button = QPushButton("Resize")

        # Render resolution and the material/shader currently tracked.
        self._width = 100
        self._height = 100
        self._shader = self._material = None

        self._init_widget()
Example #3
0
    def _add_input_module(self, node_socket: NodeSocket, shader_input: Shader):
        """Create an input widget for ``node_socket`` and append it as a row.

        The widget class is chosen from the socket's data type. The graphics
        socket is placed in column 0 and the wrapping ``SocketModule`` in
        column 1 of the master layout, at the current input row.

        Raises:
            TypeError: if the socket's data type has no matching widget.
        """
        dtype = node_socket.dtype()
        socket = self._create_g_socket(node_socket)
        limits = shader_input.get_limits()
        lim_min, lim_max = limits[0], limits[1]
        display_label = shader_input.get_display_label()

        # Z value of the proxy item; choice inputs are raised above siblings
        # (presumably so their drop-down is not hidden — TODO confirm).
        z_value = 0
        if dtype == DataType.Float:
            input_widget = FloatInput(min_=lim_min, max_=lim_max, dtype=dtype)
        elif dtype == DataType.Int:
            input_widget = IntInput(min_=lim_min, max_=lim_max, dtype=dtype)
        elif dtype == DataType.Vec3_RGB:
            input_widget = ColorInput(dtype)
        elif dtype == DataType.Shader:
            input_widget = ShaderInput(dtype)
        elif dtype == DataType.Vec3_Float:
            # Three-component vector input.
            input_widget = ArrayInput(3, min_=lim_min, max_=lim_max, dtype=dtype)
        elif dtype == DataType.Int_Choice:
            input_widget = IntChoiceInput(values=shader_input.get_names())
            input_widget.raise_()
            z_value = 1
        else:
            raise TypeError("Data Type {} is not yet supported!".format(dtype))

        # Wrap the widget in a module bound to the socket's current value.
        socket_module = SocketModule(socket, display_label, input_widget)
        socket_module.input_changed.connect(self._notify_change)
        socket_module.set_label_palette(self._input_label_palette)
        socket_module.set_value(socket.value())
        self._in_socket_modules.append((socket, socket_module))

        proxy_item = self.node_scene.addWidget(socket_module)
        proxy_item.setZValue(z_value)
        row = self._input_index
        self._master_layout.addItem(socket, row, 0, Qt.AlignVCenter)
        self._master_layout.addItem(proxy_item, row, 1, Qt.AlignVCenter)
        self._input_index = row + 1
        # Keep a small margin below the last row.
        self._height = self._master_layout.preferredHeight() + 10

        # Inputs that cannot be connected keep their widget but hide the socket.
        if not shader_input.is_connectable():
            socket.setVisible(False)
Example #4
0
 def _widget_from_type(self, type_, subtype):
     """Return a new input widget matching ``type_`` (and ``subtype``).

     For ``tuple``/``list`` types, ``subtype`` selects between the multiple
     float and multiple int widgets. Returns ``None`` for any combination
     that has no widget.
     """
     if type_ in [float, DataType.Float]:
         return FloatInput()
     if type_ in [int, DataType.Int]:
         return IntInput()
     if type_ in [str]:
         return StringInput()
     if type_ in [typing.Iterable[int]]:
         return MultipleIntInput()
     if type_ in [typing.Iterable[float]]:
         return MultipleFloatInput()
     if type_ in [tuple, list]:
         if subtype == float:
             return MultipleFloatInput()
         if subtype == int:
             return MultipleIntInput()
     # Unsupported type, or a sequence whose element type is unsupported.
     return None
Example #5
0
    def __init__(self, cc: ControlCenter, *args):
        """Create the settings panel bound to control center ``cc``.

        Extra positional ``args`` are forwarded to the Qt base class.
        """
        super().__init__(*args)
        # Define gui elements
        self.cc = cc
        self._layout = QVBoxLayout()
        self._load_texture_button = QPushButton("Load Texture")
        self._texture_label = QLabel("load texture...")
        self._match_button = QPushButton("Match Texture")
        self._loss_combo_box = QComboBox()
        self._loss_settings_group = FunctionSettingsGroup()
        self._optimizer_combo_box = QComboBox()
        self._optimizer_settings_group = FunctionSettingsGroup()
        self._width_input = IntInput(0, 1000)
        self._height_input = IntInput(0, 1000)
        self._max_iter_input = IntInput(0, 10000)
        self._early_stopping_loss_thresh = FloatInput(0, 1)
        self._save_data_button = QPushButton("Save Data")
        self._load_settings_button = QPushButton("Load Settings")
        self._restore_best_button = QPushButton("Restore Best Parameters")

        # Define data: class name -> class, in the order shown in the combos.
        # (Named explicitly so no local shadows the stdlib ``os`` module.)
        loss_classes = [losses.XSELoss, losses.SquaredBinLoss, losses.NeuralLoss]
        self._loss_func_map = {cls.__name__: cls for cls in loss_classes}

        optimizer_classes = [
            optim.Adam, optim.AdamW, optim.Adagrad, optim.RMSprop
        ]
        self._optimizer_map = {cls.__name__: cls for cls in optimizer_classes}

        self.loaded_image = None
        self._max_iter = 100
        self._settings = GradientDescentSettings()
        self._is_running = False
        self._is_cleared = True
        self._last_load_path = None
        self._last_save_path = None

        self._init_widget()
Example #6
0
class SettingsPanel(QWidget):
    """Side panel holding the texture-matching (gradient descent) settings.

    The panel owns a ``GradientDescentSettings`` instance, keeps it in sync
    with its input widgets and reports changes and user actions through the
    Qt signals below.
    """

    match_start = pyqtSignal()
    match_stop = pyqtSignal()
    reset_requested = pyqtSignal()
    texture_loaded = pyqtSignal(Image.Image)
    # Emitted with the whole settings object whenever any single setting
    # changes (the changed key itself is not part of the signal).
    settings_changed = pyqtSignal(GradientDescentSettings)
    save_data = pyqtSignal(str)
    restore_best = pyqtSignal()

    def __init__(self, cc: ControlCenter, *args):
        """Create the settings panel bound to control center ``cc``.

        Extra positional ``args`` are forwarded to the Qt base class.
        """
        super().__init__(*args)
        # Define gui elements
        self.cc = cc
        self._layout = QVBoxLayout()
        self._load_texture_button = QPushButton("Load Texture")
        self._texture_label = QLabel("load texture...")
        self._match_button = QPushButton("Match Texture")
        self._loss_combo_box = QComboBox()
        self._loss_settings_group = FunctionSettingsGroup()
        self._optimizer_combo_box = QComboBox()
        self._optimizer_settings_group = FunctionSettingsGroup()
        self._width_input = IntInput(0, 1000)
        self._height_input = IntInput(0, 1000)
        self._max_iter_input = IntInput(0, 10000)
        self._early_stopping_loss_thresh = FloatInput(0, 1)
        self._save_data_button = QPushButton("Save Data")
        self._load_settings_button = QPushButton("Load Settings")
        self._restore_best_button = QPushButton("Restore Best Parameters")

        # Define data: class name -> class, in the order shown in the combos.
        # (Named explicitly so no local shadows the stdlib ``os`` module.)
        loss_classes = [losses.XSELoss, losses.SquaredBinLoss, losses.NeuralLoss]
        self._loss_func_map = {cls.__name__: cls for cls in loss_classes}

        optimizer_classes = [
            optim.Adam, optim.AdamW, optim.Adagrad, optim.RMSprop
        ]
        self._optimizer_map = {cls.__name__: cls for cls in optimizer_classes}

        self.loaded_image = None
        self._max_iter = 100
        self._settings = GradientDescentSettings()
        self._is_running = False
        self._is_cleared = True
        self._last_load_path = None
        self._last_save_path = None

        self._init_widget()

    def _init_widget(self):
        """Connect the widgets to the settings and build the vertical layout."""
        self._match_button.setEnabled(False)
        self._match_button.clicked.connect(self._toggle_matching)
        self._load_texture_button.clicked.connect(self._load_texture)

        # --- Setup loss function combo box ---
        self._loss_combo_box.addItems(list(self._loss_func_map))
        self._loss_combo_box.currentIndexChanged.connect(self._set_loss_func)
        self._loss_combo_box.setCurrentIndex(0)
        self._settings.loss_func = self._loss_func_map[
            self._loss_combo_box.currentText()]

        # --- Setup loss function settings group ---
        self._loss_settings_group.set_function(self._settings.loss_func)

        # --- Setup optimizer combo box ---
        self._optimizer_combo_box.addItems(list(self._optimizer_map))
        self._optimizer_combo_box.currentIndexChanged.connect(
            self._set_optimizer)
        self._optimizer_combo_box.setCurrentIndex(0)
        self._settings.optimizer = self._optimizer_map[
            self._optimizer_combo_box.currentText()]

        # --- Setup optimizer function settings group ---
        self._optimizer_settings_group.set_function(self._settings.optimizer)

        # --- Setup render size input ---
        self._width_input.set_value(self._settings.render_width)
        self._height_input.set_value(self._settings.render_height)
        self._width_input.input_changed.connect(lambda: self._change_settings(
            "render_width", self._width_input.get_gl_value()))
        self._height_input.input_changed.connect(lambda: self._change_settings(
            "render_height", self._height_input.get_gl_value()))

        # --- Setup max iterations input ---
        self._max_iter_input.input_changed.connect(
            lambda: self._change_settings("max_iter",
                                          self._max_iter_input.get_gl_value()))
        self._max_iter_input.set_value(100)

        # --- Setup early stopping loss threshold input ---
        self._early_stopping_loss_thresh.input_changed.connect(
            lambda: self._change_settings(
                "early_stopping_thresh",
                self._early_stopping_loss_thresh.get_gl_value()))
        self._early_stopping_loss_thresh.set_value(0.01)

        # --- Setup save/load data button ---
        self._save_data_button.clicked.connect(self._save_data)
        self._load_settings_button.clicked.connect(self._load_settings)
        self._restore_best_button.clicked.connect(
            lambda: self.restore_best.emit())

        self._layout.addWidget(self._match_button)
        self._layout.addWidget(self._load_texture_button)
        self._layout.addWidget(
            LabelledInput("Loss function", self._loss_combo_box))
        self._layout.addWidget(self._loss_settings_group)
        self._layout.addWidget(
            LabelledInput("Optimizer", self._optimizer_combo_box))
        self._layout.addWidget(self._optimizer_settings_group)
        self._layout.addWidget(LabelledInput("Render width",
                                             self._width_input))
        self._layout.addWidget(
            LabelledInput("Render height", self._height_input))
        self._layout.addWidget(
            LabelledInput("Max iterations", self._max_iter_input))
        self._layout.addWidget(
            LabelledInput("Early stopping loss thresh",
                          self._early_stopping_loss_thresh))
        self._layout.addWidget(self._save_data_button)
        self._layout.addWidget(self._load_settings_button)
        self._layout.addWidget(self._restore_best_button)

        self._layout.setAlignment(Qt.AlignTop)
        self.setLayout(self._layout)

    def settings(self) -> GradientDescentSettings:
        """Return the settings, refreshed from the argument widgets.

        ``loss_func`` is normalised to a class (an instance is replaced by its
        class), and the loss/optimizer keyword arguments are read from their
        settings groups.
        """
        loss_func = self._settings.loss_func
        if isinstance(loss_func, Loss):
            self._settings.loss_func = loss_func.__class__

        self._settings.loss_args = self._loss_settings_group.to_dict()
        self._settings.optimizer_args = (
            self._optimizer_settings_group.to_dict())
        return self._settings

    def _load_texture(self):
        """Let the user pick a target image; enable matching once loaded."""
        filename, _ = QFileDialog.getOpenFileName(
            self,
            "Open Texture",
            filter="Image File (*.png *.jpg *.jpeg *.bmp)")

        if filename:
            self.loaded_image = Image.open(filename)
            self._match_button.setEnabled(True)
            self.texture_loaded.emit(self.loaded_image)

    def _set_loss_func(self, index: int):
        """Apply the loss function selected in the combo box.

        ``index`` comes from ``currentIndexChanged`` and is unused; the
        selection is read from the combo box text.
        """
        loss_func = self._loss_func_map[self._loss_combo_box.currentText()]

        # The map stores classes, so ``isinstance`` alone never matched them;
        # ``issubclass`` handles NeuralLoss (and subclasses) as a class.
        is_neural_loss = (
            (isinstance(loss_func, type)
             and issubclass(loss_func, losses.NeuralLoss))
            or isinstance(loss_func, losses.NeuralLoss))
        if is_neural_loss:
            # Presumably the network expects 224x224 inputs — TODO confirm.
            self._width_input.set_value(224)
            self._height_input.set_value(224)

        self._change_settings("loss_func", loss_func)
        self._loss_settings_group.set_function(loss_func)

    def _set_optimizer(self, index: int):
        """Apply the optimizer selected in the combo box (``index`` unused)."""
        optimizer = self._optimizer_map[
            self._optimizer_combo_box.currentText()]
        self._change_settings("optimizer", optimizer)
        self._optimizer_settings_group.set_function(optimizer)

    def set_gd_finished(self):
        """Mark the run as finished; the match button now offers a reset."""
        self._match_button.setText("Reset")
        self._is_running = False

    def set_gd_finishing(self):
        """Show that a stop was requested and the run is winding down."""
        self._match_button.setText("Stopping...")

    def _change_settings(self, var: str, new_val: typing.Any):
        """Set attribute ``var`` on the settings and emit ``settings_changed``."""
        _logger.debug("New value for setting %s -> %s.", var, new_val)
        setattr(self._settings, var, new_val)
        self.settings_changed.emit(self._settings)

    def _toggle_matching(self):
        """Advance the match button's state machine.

        running -> stopped (emit ``match_stop``);
        stopped -> cleared (emit ``reset_requested``);
        cleared -> running (emit ``match_start``).
        """
        if not self._is_running:
            if not self._is_cleared:
                self._is_cleared = True
                self.reset_requested.emit()
                self._match_button.setText("Match Texture")
            else:
                self._is_running = True
                self._is_cleared = False
                self._match_button.setText("Stop")
                self.match_start.emit()
        else:
            self._is_running = False
            self._is_cleared = False
            self._match_button.setText("Reset")
            self.match_stop.emit()

    def _save_data(self):
        """Ask for an HDF5 output path and emit ``save_data`` with it.

        The suggested file name is built from the active material name and
        the selected loss/optimizer class names.
        """
        material = self.cc.active_material
        if material is None:
            # Previously this crashed with AttributeError on ``None.name``.
            _logger.warning("Cannot save data: no active material.")
            return

        filename = to_filename(material.name,
                               self._settings.loss_func.__name__,
                               self._settings.optimizer.__name__) + ".hdf5"
        # Start in the directory of the previous save, if there was one.
        directory = (str(self._last_save_path / filename)
                     if self._last_save_path else filename)
        path, _ = QFileDialog.getSaveFileName(self,
                                              "Save Data",
                                              directory=directory,
                                              filter="HDF5 (*.hdf5)")

        if path:
            self._last_save_path = Path(path).parent
            self.save_data.emit(path)

    def _load_settings(self):
        """Load settings from a user-chosen HDF5 file and refresh the widgets."""
        # An empty directory is QFileDialog's default (no start directory).
        directory = str(self._last_load_path) if self._last_load_path else ""
        path, _ = QFileDialog.getOpenFileName(self,
                                              "Load Settings",
                                              directory=directory,
                                              filter="HDF5 (*.hdf5)")
        if not path:
            return

        self._last_load_path = Path(path).parent
        # The file used to be left open; close it once the settings are read.
        # NOTE(review): assumes load_from_hdf5 copies values out of the file
        # rather than keeping lazy h5py dataset handles — confirm.
        with h5py.File(path, "r") as f:
            self._settings.load_from_hdf5(f)
        self._update_from_settings()
        _logger.info("Loaded settings from file {}".format(path))

    def _update_from_settings(self):
        """Push the values held in ``self._settings`` back into the widgets."""
        self._width_input.set_value(self._settings.render_width)
        self._height_input.set_value(self._settings.render_height)
        self._early_stopping_loss_thresh.set_value(
            self._settings.early_stopping_thresh)
        self._max_iter_input.set_value(self._settings.max_iter)

        if self._settings.loss_func:
            self._loss_combo_box.setCurrentIndex(
                list(self._loss_func_map).index(
                    self._settings.loss_func.__name__))
            self._loss_settings_group.load_from_dict(self._settings.loss_args)

        if self._settings.optimizer:
            self._optimizer_combo_box.setCurrentIndex(
                list(self._optimizer_map).index(
                    self._settings.optimizer.__name__))
            self._optimizer_settings_group.load_from_dict(
                self._settings.optimizer_args)
class PythonRenderingWidget(QWidget):
    """Window that renders the active material in Python and shows the image.

    The widget follows the control center's active material, re-renders when
    the material changes or its shader becomes ready, and lets the user
    change the render resolution.
    """

    closed = pyqtSignal()

    def __init__(self, cc: ControlCenter, *args):
        """Create the renderer widget bound to control center ``cc``.

        Extra positional ``args`` are forwarded to the Qt base class.
        """
        super().__init__(*args)

        self.cc = cc

        # Define gui components
        self._image_plot = ImagePlotter()
        self._layout = QVBoxLayout()
        self._width_input = IntInput(1, 500)
        self._height_input = IntInput(1, 500)
        self._resize_button = QPushButton("Resize")

        # Define widget data
        self._width, self._height = 100, 100
        self._shader = None
        self._material = None

        self._init_widget()

    def _init_widget(self):
        """Build the size controls and plot, and start following the material."""
        self.setWindowTitle("DiPTeR - Python Renderer")

        # Setup settings controls
        settings_layout = QHBoxLayout()
        settings_layout.setAlignment(Qt.AlignLeft)
        self._width_input.set_value(self._width)
        self._height_input.set_value(self._height)
        self._resize_button.clicked.connect(self._handle_resize)
        settings_layout.addWidget(LabelledInput("Width", self._width_input))
        settings_layout.addWidget(LabelledInput("Height", self._height_input))
        settings_layout.addWidget(self._resize_button)
        self._layout.addLayout(settings_layout)

        # Add plotting widget
        self._layout.addWidget(self._image_plot)

        self.cc.active_material_changed.connect(self._material_changed)
        if self.cc.active_material:
            self._material_changed(self.cc.active_material)

        self.setLayout(self._layout)

    def _render(self):
        """Render the tracked material at the current size and display it.

        Does nothing while no material is tracked. Previously, clicking
        *Resize* before any material became active raised AttributeError.
        """
        if self._material is None:
            return

        node = self._material.get_material_output_node()
        # perf_counter is the monotonic clock intended for measuring durations.
        start = time.perf_counter()
        img, _ = node.render(self._width, self._height)
        total_time = time.perf_counter() - start
        _logger.debug("Rendering DONE in {:.4f}s.".format(total_time))

        self._image_plot.set_image(img)

    def _material_changed(self, mat: Material):
        """Switch the signal connections from the previous material to ``mat``.

        ``mat`` may be ``None``; the widget then just stops following the old
        material.
        """
        # Disconnect signals from previous material
        if self._material is not None:
            self._material.shader_ready.disconnect(self._render)
            self._material.changed.disconnect(self._handle_material_changed)

        self._material = mat
        if mat is None:
            return

        mat.shader_ready.connect(self._render)
        mat.changed.connect(self._handle_material_changed)

        # The shader may already be compiled, in which case shader_ready
        # will not fire for it again.
        if mat.shader:
            self._render()

    def _set_title(self):
        """Set the plot title to the current shader's class name."""
        # NOTE(review): ``self._axis`` is never assigned in this class, so
        # calling this raises AttributeError; it is not called in this class.
        shader = self._material.shader
        self._axis.set_title("Python Render ({})".format(
            shader.__class__.__name__))

    def _handle_material_changed(self):
        """Re-render after the tracked material reports a change."""
        self._render()

    def _handle_resize(self):
        """Apply the width/height inputs, update the plot ranges and re-render."""
        self._width = self._width_input.get_gl_value()
        self._height = self._height_input.get_gl_value()
        self._image_plot.set_x_range(0, self._width)
        self._image_plot.set_y_range(0, self._height)
        self._render()

    def closeEvent(self, event: QCloseEvent):
        """Emit ``closed`` before Qt handles the close event."""
        self.closed.emit()
        super().closeEvent(event)
Example #8
0
class LossVisualizer(QWidget):
    """Window that plots a material's loss surface over two parameters.

    The user checks two shader parameters in the table, sets their Min/Max
    range and grid resolution, and plots the loss surface. A gradient-descent
    run on the checked parameters can then be overlaid on the surface.
    """

    def __init__(self, mat_output_node: GMaterialOutputNode):
        """Build the window for ``mat_output_node`` and list its parameters."""
        super().__init__(parent=None)
        self._mat_out_node = mat_output_node

        # Declare widgets
        self._layout = QGridLayout()
        self._splitter = QSplitter(Qt.Horizontal)
        self._plot = PlotWidget3D()
        self._canvas = self._plot.get_canvas()
        self._fig_ax = self._plot.get_axis()
        self._table_widget = QTableWidget(self)
        self._p1_res = IntInput(0, 100)
        self._p2_res = IntInput(0, 100)
        self._plot_button = QPushButton("Plot Loss")
        self._start_gd_button = QPushButton("Start Gradient Descent")

        # Declare data
        # Checked parameter items; capacity 2 (see _item_state_changed).
        self._item_queue = FIFOQueue(maxsize=2)
        self._bg_brush_selected = QBrush(QColor("#8bf9b0"))
        self._bg_brush_default = QBrush(QColor("#ffffff"))
        self._settings = None
        self._target_image = None
        self._target_matrix = None
        self._out_node = None
        self._thread = None
        self._gd = None
        self._p1 = None
        self._p2 = None
        self._hist_p1 = []
        self._hist_p2 = []
        self._plot_line3d = None
        self._progress_dialog = None

        self._init()
        self._list_parameters()

    def _init(self):
        """Configure the table, plot axes, resolution inputs and buttons."""
        self.setWindowTitle("Loss Visualizer")

        # Setup table widget
        self._table_widget.setColumnCount(3)
        self._table_widget.setHorizontalHeaderLabels(
            ["Parameter", "Min", "Max"])
        self._table_widget.setColumnWidth(1, 50)
        self._table_widget.setColumnWidth(2, 50)

        # Setup plot
        self._fig_ax.set_title("Loss Surface")
        self._fig_ax.set_xlabel("Parameter ?")
        self._fig_ax.set_ylabel("Parameter ?")
        self._fig_ax.set_zlabel("Loss Value")

        # Setup resolution input
        self._p1_res.set_value(20)
        self._p2_res.set_value(20)
        p1_module = Module("Param 1 Res.", self._p1_res)
        p2_module = Module("Param 2 Res.", self._p2_res)

        # Setup buttons
        self._plot_button.clicked.connect(self._plot_loss)
        self._start_gd_button.clicked.connect(self._start_gd)
        self._start_gd_button.setEnabled(False)

        # Add widgets to layout
        self._splitter.addWidget(self._table_widget)
        self._splitter.addWidget(self._plot)
        self._layout.addWidget(self._splitter, 0, 0, 1, 5)
        self._layout.addWidget(p1_module, 1, 1)
        self._layout.addWidget(p2_module, 1, 2)
        self._layout.addWidget(self._plot_button, 1, 3)
        self._layout.addWidget(self._start_gd_button, 1, 4)

        self.setLayout(self._layout)

    def _list_parameters(self):
        """Fill the table with the backend node's parameters.

        A 10x10 render is done once to obtain the parameter dict. Each
        parameter gets a checkbox row with Min/Max inputs; vector parameters
        get a disabled header row plus one checkable row per component.
        """
        _, param_dict = self._mat_out_node.get_backend_node().render(
            10, 10, retain_graph=True)
        row = 0
        self._table_widget.setRowCount(len(param_dict))

        for key in param_dict:
            param = param_dict[key]
            param.set_modified_arg(key)
            limits = param.get_limits()
            # Scalar rows allow ranges one limit-span beyond the limits.
            diff = limits[1] - limits[0]
            min_item = FloatInput(limits[0] - diff, limits[1] + diff)
            min_item.set_value(limits[0])
            max_item = FloatInput(limits[0] - diff, limits[1] + diff)
            max_item.set_value(limits[1])
            item = CheckboxItem(key, content={"param": param, "index": -1})
            item.state_changed.connect(self._item_state_changed)
            self._table_widget.setCellWidget(row, 0, item)
            self._table_widget.setCellWidget(row, 1, min_item)
            self._table_widget.setCellWidget(row, 2, max_item)

            row += 1

            if param.is_vector():
                # The vector itself is not selectable, only its components.
                item.set_checkable(False)
                item.setEnabled(False)

                for i in range(param.shape()[1]):
                    self._table_widget.insertRow(row)
                    min_item = FloatInput(limits[0], limits[1])
                    min_item.set_value(limits[0])
                    max_item = FloatInput(limits[0], limits[1])
                    max_item.set_value(limits[1])
                    sub_item = CheckboxItem("  [{}]".format(i),
                                            content={
                                                "param": param,
                                                "index": i
                                            })
                    sub_item.state_changed.connect(self._item_state_changed)
                    self._table_widget.setCellWidget(row, 0, sub_item)
                    self._table_widget.setCellWidget(row, 1, min_item)
                    self._table_widget.setCellWidget(row, 2, max_item)
                    row += 1

        self._table_widget.resizeColumnToContents(0)
        self._table_widget.resizeRowsToContents()

    def _checked_items(
            self) -> typing.List[typing.Tuple[QWidget, QWidget, QWidget]]:
        """Return ``(checkbox, min_input, max_input)`` for each checked row."""
        checked = []

        for i in range(self._table_widget.rowCount()):
            item = self._table_widget.cellWidget(i, 0)
            min_item = self._table_widget.cellWidget(i, 1)
            max_item = self._table_widget.cellWidget(i, 2)
            if item.get_state() == Qt.Checked:
                checked.append((item, min_item, max_item))

        return checked

    def _item_state_changed(self, item: CheckboxItem):
        """Track checked items; checking a third unchecks the oldest one."""
        if item.get_state() == Qt.Checked:
            if self._item_queue.is_full():
                first_item = self._item_queue.pop()
                first_item.set_state(Qt.Unchecked)

            self._item_queue.put(item)
        elif item.get_state() == Qt.Unchecked:
            if item in self._item_queue:
                self._item_queue.remove(item)

    def _warn(self, title: str, text: str):
        """Show a modal warning message box."""
        msg = QMessageBox(QMessageBox.Warning, title, text)
        msg.exec()

    def _plot_loss(self):
        """Evaluate and plot the loss over a grid of the two checked parameters.

        Each checked parameter is swept linearly between its Min and Max
        inputs (``Param 1/2 Res.`` samples). The material is rendered at every
        grid point and compared to the target image. The original parameter
        values are always restored afterwards, including when the user
        cancels or rendering raises.
        """
        checked_items = self._checked_items()
        if len(checked_items) < 2:
            # Previously an IndexError.
            self._warn("Need two parameters!",
                       "Check two parameters to plot the loss surface.")
            return

        W, H = self._settings.render_width, self._settings.render_height
        R1, R2 = self._p1_res.get_gl_value(), self._p2_res.get_gl_value()
        if R1 < 1 or R2 < 1:
            # A zero resolution leaves the grid empty and the minimum unset.
            self._warn("Invalid resolution!",
                       "Both parameter resolutions must be at least 1.")
            return

        self._fig_ax.clear()
        progress_dialog = QProgressDialog("Calculating loss surface...",
                                          "Cancel", 0, R1 - 1, self)
        progress_dialog.setWindowTitle("Calculating")
        progress_dialog.setWindowModality(Qt.WindowModal)
        progress_dialog.setMinimumDuration(1)

        self._target_matrix = image_funcs.image_to_tensor(
            self._target_image, (W, H))
        loss_surface = np.empty((R1, R2))
        loss_f = self._settings.loss_func(**self._settings.loss_args)

        # X axis: first checked parameter (index -1 means scalar parameter).
        item1 = checked_items[0][0]
        item1_min = checked_items[0][1].get_gl_value()
        item1_max = checked_items[0][2].get_gl_value()
        self._fig_ax.set_xlabel(item1.label)
        self._p1: Parameter = item1.content["param"]
        p1_index = item1.content["index"]
        p1_values = torch.from_numpy(
            np.linspace(item1_min, item1_max, num=R1, endpoint=True))

        # Y axis: second checked parameter.
        item2 = checked_items[1][0]
        item2_min = checked_items[1][1].get_gl_value()
        item2_max = checked_items[1][2].get_gl_value()
        self._fig_ax.set_ylabel(item2.label)
        self._p2: Parameter = item2.content["param"]
        p2_index = item2.content["index"]
        p2_values = torch.from_numpy(
            np.linspace(item2_min, item2_max, num=R2, endpoint=True))

        min_loss = np.finfo(np.float32).max
        min_loss_p1 = None
        min_loss_p2 = None

        self._p1.save_value()
        self._p2.save_value()
        try:
            for i in range(R1):
                self._p1.set_value(p1_values[i], index=p1_index)
                progress_dialog.setValue(i)

                if progress_dialog.wasCanceled():
                    return

                for j in range(R2):
                    self._p2.set_value(p2_values[j], index=p2_index)
                    r, _ = self._mat_out_node.get_backend_node().render(
                        W, H, retain_graph=True)
                    loss = loss_f(
                        r, self._target_matrix).detach().clone().cpu().numpy()

                    if loss < min_loss:
                        min_loss = loss
                        min_loss_p1 = self._p1.get_value(p1_index)
                        min_loss_p2 = self._p2.get_value(p2_index)

                    loss_surface[i, j] = loss

                _logger.info("{:.2f}% complete...".format((i + 1) / R1 * 100))
        finally:
            # Previously a cancel returned here with the parameters still at
            # their sweep values.
            self._p1.restore_value()
            self._p2.restore_value()

        # Default meshgrid indexing is "ij", matching loss_surface[i, j].
        P1, P2 = torch.meshgrid([p1_values, p2_values])

        self._fig_ax.plot_surface(P1, P2, loss_surface, cmap=plt.cm.viridis)
        self._fig_ax.set_zlim(bottom=0)

        # Add min value marker
        self._fig_ax.plot([min_loss_p1], [min_loss_p2], [min_loss],
                          marker='+',
                          color="#ff00ff",
                          markersize=14,
                          markeredgewidth=2.5)

        self._canvas.draw()
        self._start_gd_button.setEnabled(True)

    def _start_gd(self):
        """Run gradient descent on the checked parameters in a worker QThread."""
        self._hist_p2 = np.empty(self._settings.max_iter)
        self._hist_p1 = np.empty(self._settings.max_iter)

        self._progress_dialog = QProgressDialog(
            "Performing Gradient Descent...", "Cancel", 0,
            self._settings.max_iter, self)
        self._progress_dialog.setWindowTitle("Calculating")
        self._progress_dialog.setWindowModality(Qt.WindowModal)
        self._progress_dialog.setMinimumDuration(1)

        checked_items = self._checked_items()
        params = {
            i[0].content["param"].get_modified_arg(): i[0].content["param"]
            for i in checked_items
        }

        self._gd = GradientDescent(self._target_image, self._out_node,
                                   self._settings)
        self._gd.set_active_parameters(params)
        self._thread = QThread()
        self._gd.iteration_done.connect(self._gd_callback)
        self._gd.moveToThread(self._thread)
        self._thread.started.connect(self._gd.run)
        self._gd.finished.connect(self._finish_gradient_descent)

        _logger.debug("Started Gradient Descent Thread...")

        self._thread.start()

    def _gd_callback(self, info: dict):
        """Record one iteration's parameter values; stop the run on cancel."""
        params = info["params"]
        i = info["iter"]
        loss = info["loss"]
        self._progress_dialog.setValue(i)

        if self._progress_dialog.wasCanceled():
            self._gd.stop()
            self._thread.quit()

        x = params[self._p1.get_modified_arg()]
        y = params[self._p2.get_modified_arg()]
        self._hist_p1[i] = x
        self._hist_p2[i] = y
        _logger.debug("{}. Loss: {}, P1: {}, P2: {}".format(i, loss, x, y))

    def _finish_gradient_descent(self, params, loss_hist, _):
        """Stop the worker if needed and overlay the descent path on the plot."""
        if self._thread.isRunning():
            _logger.info("Stopping Gradient Descent Thread...")
            self._gd.stop()
            self._thread.quit()
            self._gd.restore_params()

        self._progress_dialog.setValue(self._progress_dialog.maximum())

        # Overlay the optimisation path (parameter history vs. loss).
        num_iter = len(loss_hist)
        x = params[self._p1.get_modified_arg()]
        y = params[self._p2.get_modified_arg()]
        self._hist_p1[num_iter - 1] = x.get_value()
        self._hist_p2[num_iter - 1] = y.get_value()
        xs = self._hist_p1[0:num_iter]
        ys = self._hist_p2[0:num_iter]
        # NOTE(review): title and axis labels below are hard-coded for an HSV
        # shader and overwrite the parameter labels set in _plot_loss.
        self._fig_ax.set_title("HSV Shader Loss Surface", fontsize=18)
        self._fig_ax.plot(xs,
                          ys,
                          loss_hist,
                          color="#ff656dff",
                          marker="o",
                          mfc="#c44e52ff",
                          mec="#ff656dff",
                          lw=2)
        self._fig_ax.set_xlabel("Hue")
        self._fig_ax.set_ylabel("Saturation")
        self._fig_ax.set_zlabel("Loss")
        self._canvas.draw()
        self._canvas.flush_events()

    def open(self, settings: GradientDescentSettings, target: Image,
             mat_out_node: GMaterialOutputNode):
        """Show the window for ``settings``/``target``; warn if no target."""
        if target is None:
            msg = QMessageBox(
                QMessageBox.Warning, "Need to set Target Texture!",
                "Can not open loss visualizer because target texture is not set."
            )
            msg.exec()
        else:
            self._settings = settings
            self._target_image = target
            self._out_node = mat_out_node
            super().show()