コード例 #1
0
ファイル: gui.py プロジェクト: itolab-hayashi-rafik/weather
class Window(QtGui.QDialog):
    """Dialog that shows the Visualizer figure and drives a Worker.

    Layout:
      * the canvas plus two vertical sliders (K and N generator parameters);
      * a network-structure form (window size, m, r, hidden layer sizes) --
        editing any of these marks the network for re-setup on next start;
      * a learning form (epochs and learning rates) pushed to the worker live;
      * a delay slider and a single Start/Stop toggle button.
    """

    def __init__(self, parent=None):
        super(Window, self).__init__(parent)

        # visualizer
        self.vis = Visualizer(xlim=30)

        # Canvas widget that displays the visualizer's figure.
        self.canvas = FigureCanvas(self.vis.getFigure())

        # The matplotlib NavigationToolbar(self.canvas, self) is intentionally
        # not created.

        # Network-structure form: any edit forces worker.setup() on next start.
        self.window_size_line_edit = self._createLineEdit("10", self.dnnChanged)
        self.m_line_edit = self._createLineEdit("1", self.dnnChanged)
        self.r_line_edit = self._createLineEdit("2", self.dnnChanged)
        self.hidden_layer_sizes_line_edit = self._createLineEdit(
            "10,10,10", self.dnnChanged)

        self.input_form = QtGui.QFormLayout()
        self.input_form.addRow("Window Size:", self.window_size_line_edit)
        self.input_form.addRow("m:", self.m_line_edit)
        self.input_form.addRow("r:", self.r_line_edit)
        self.input_form.addRow("Hidden Layer Sizes:", self.hidden_layer_sizes_line_edit)

        # Learning form: values are pushed to the worker as they change.
        # Learning-rate sliders hold an exponent k, giving lr = 10**-k.
        self.pretrian_epochs_line_edit = self._createLineEdit("10", self.updateWorker)
        self.pretrain_lr_slider = self._createSlider(QtCore.Qt.Horizontal, 1, 10, 1)
        self.finetune_epochs_line_edit = self._createLineEdit("10", self.updateWorker)
        self.finetune_lr_slider = self._createSlider(QtCore.Qt.Horizontal, 1, 10, 1)

        self.learn_form = QtGui.QFormLayout()
        self.learn_form.addRow("finetune_epoch", self.finetune_epochs_line_edit)
        self.learn_form.addRow("finetune_lr", self.finetune_lr_slider)
        self.learn_form.addRow("pretrain_epoch", self.pretrian_epochs_line_edit)
        self.learn_form.addRow("pretrain_lr", self.pretrain_lr_slider)

        # Slider controlling the plot delay (value / 100).
        self.slider = self._createSlider(QtCore.Qt.Horizontal, 0, 99, 25)

        # Sliders controlling the generator's K and N parameters (value / 100).
        self.k_slider = self._createSlider(QtCore.Qt.Vertical, 0, 100, 0)
        self.n_slider = self._createSlider(QtCore.Qt.Vertical, 0, 100, 0)

        # Single Start/Stop toggle. It is connected exactly once; the action is
        # chosen from self._running, so handlers never accumulate across runs.
        self.start_stop_button = QtGui.QPushButton("Start")
        self.start_stop_button.clicked.connect(self._onStartStopClicked)

        # set the layout
        layout = QtGui.QGridLayout()
        layout.addWidget(self.canvas, 0, 0, 1, 2)
        layout.addWidget(self.k_slider, 0, 2, 1, 1)
        layout.addWidget(self.n_slider, 0, 3, 1, 1)
        layout.addLayout(self.input_form, 1, 0, 1, 1)
        layout.addLayout(self.learn_form, 1, 1, 1, 1)
        layout.addWidget(self.slider, 2, 0)
        layout.addWidget(self.start_stop_button, 2, 1)
        self.setLayout(layout)

        # setup worker
        self.need_setup = True
        self._running = False  # toggled by workerStarted / workerStopped
        self.worker = Worker(self.vis)

        # setup event dispatchers
        self.worker.started.connect(self.workerStarted)
        self.worker.updated.connect(self.updateFigure)
        self.worker.stopped.connect(self.workerStopped)

    def _createLineEdit(self, text, slot):
        """Return a QLineEdit initialised to *text* whose textChanged calls *slot*."""
        line_edit = QtGui.QLineEdit(text)
        line_edit.textChanged.connect(slot)
        return line_edit

    def _createSlider(self, orientation, minimum, maximum, value):
        """Return a slider with the given range and value, wired to updateWorker.

        The signal is connected only after the initial range/value are set, so
        building the widget never calls updateWorker before self.worker exists.
        """
        slider = QtGui.QSlider(orientation)
        slider.setRange(minimum, maximum)
        slider.setValue(value)
        slider.valueChanged.connect(self.updateWorker)
        return slider

    def _setDnnFieldsReadOnly(self, read_only):
        """Lock or unlock the network-structure fields."""
        for line_edit in (self.window_size_line_edit,
                          self.m_line_edit,
                          self.r_line_edit,
                          self.hidden_layer_sizes_line_edit):
            line_edit.setReadOnly(read_only)

    def _parseHiddenLayerSizes(self):
        """Parse the comma-separated hidden layer sizes field into a list of ints.

        Empty entries (e.g. from a trailing comma) are ignored.

        Raises:
            ValueError: if an entry is not an integer or no sizes are given.
        """
        sizes = [int(token, 10)
                 for token in self.hidden_layer_sizes_line_edit.text().split(",")
                 if token.strip()]
        if not sizes:
            raise ValueError("Hidden layer sizes must not be empty")
        return sizes

    def _onStartStopClicked(self, *_args):
        """Dispatch the Start/Stop button to stop() while running, else start()."""
        if self._running:
            self.stop()
        else:
            self.start()

    def start(self):
        """Read the form, (re)configure the worker if needed and start it.

        All fields are parsed before the UI changes. On invalid input a warning
        is shown and the button stays usable, instead of being left disabled.
        """
        try:
            window_size = int(self.window_size_line_edit.text(), 10)
            m = int(self.m_line_edit.text(), 10)
            r = int(self.r_line_edit.text(), 10)
            hidden_layer_sizes = self._parseHiddenLayerSizes()
            self.getLearningParams()  # validate so updateWorker below succeeds
        except ValueError as e:
            QtGui.QMessageBox.warning(self, "Invalid input", str(e))
            return

        # Disabled until the worker reports it has started (workerStarted).
        self.start_stop_button.setText("Stop")
        self.start_stop_button.setEnabled(False)

        if self.need_setup:
            self.worker.setup(m=m, r=r, window_size=window_size,
                              hidden_layer_sizes=hidden_layer_sizes, pretrain_step=1)
            self.need_setup = False
        self.updateWorker()
        self.worker.start()

    def stop(self):
        """Ask the worker to stop; the button is re-enabled by workerStopped."""
        self.start_stop_button.setText("Start")
        self.start_stop_button.setEnabled(False)
        self.worker.stop()

    def dnnChanged(self):
        """Mark the network structure as changed so start() re-runs setup."""
        self.need_setup = True

    def updateWorker(self):
        """Push the current generator, delay and learning settings to the worker.

        This runs on every keystroke in the epoch fields. While one of them
        does not hold an integer (e.g. it is empty mid-edit), the worker keeps
        its previous learning parameters instead of the slot raising.
        """
        self.worker.setGeneratorParams(self.getKValue(), self.getNValue())
        self.worker.setDelay(self.getDelayValue())
        try:
            learning_params = self.getLearningParams()
        except ValueError:
            # Incomplete input; keep the last valid learning parameters.
            return
        self.worker.setLearningParams(learning_params)

    def workerStarted(self):
        """Worker running: enable the button as "Stop" and lock structure fields."""
        self._running = True
        self.start_stop_button.setText("Stop")
        self.start_stop_button.setEnabled(True)
        self._setDnnFieldsReadOnly(True)

    def updateFigure(self):
        """Redraw the canvas after the worker updates the figure."""
        self.canvas.draw()

    def workerStopped(self):
        """Worker stopped: enable the button as "Start" and unlock structure fields."""
        self._running = False
        self.start_stop_button.setText("Start")
        self.start_stop_button.setEnabled(True)
        self._setDnnFieldsReadOnly(False)

    def getLearningParams(self):
        """Return the learning settings as a dict for Worker.setLearningParams.

        Learning rates are 10 ** -slider_value.

        Raises:
            ValueError: if an epoch field does not contain an integer.
        """
        return {
            "pretrain_epochs": int(self.pretrian_epochs_line_edit.text(), 10),
            "pretrain_lr": 1.0 / pow(10, self.pretrain_lr_slider.value()),
            "finetune_epochs": int(self.finetune_epochs_line_edit.text(), 10),
            "finetune_lr": 1.0 / pow(10, self.finetune_lr_slider.value()),
        }

    def getNValue(self):
        """Return the N slider position scaled to [0.0, 1.0]."""
        return self.n_slider.value() / 100.0

    def getKValue(self):
        """Return the K slider position scaled to [0.0, 1.0]."""
        return self.k_slider.value() / 100.0

    def getDelayValue(self):
        """Return the delay slider position divided by 100."""
        return self.slider.value() / 100.0