Example #1
0
 def load_graph(self, override=''):
     """Load a saved project (graph and training parameters) from JSON.

     :param override: Path of the project file to load. When empty (or any
         falsy value), a file dialog starting in the work directory asks the
         user for one.

     The file is read and parsed *before* the current graph is cleared, so
     an unreadable or corrupted file leaves the canvas untouched.
     """
     import os.path

     if override:
         file_name = override
     else:
         init_path = TrainParamServer().get_work_dir()
         file_name = QtWidgets.QFileDialog.getOpenFileName(
             self, 'Open File', init_path,
             filter='Chainer Wing Files (*.json);; Any (*.*)')[0]
     if not file_name:
         return
     logger.debug('Attempting to load graph: {}'.format(file_name))
     try:
         with open(file_name, 'r') as fp:
             proj_dict = json.load(fp)
     except ValueError:
         # json.JSONDecodeError (and UnicodeDecodeError) subclass ValueError.
         util.disp_error(file_name + ' is corrupted.')
         return
     except OSError:
         # e.g. an override path that does not exist or is not readable.
         util.disp_error(file_name + ' could not be opened.')
         return

     # Only now that the file is known to be valid is the old graph dropped.
     self.drawer.clear_all_nodes()
     if 'graph' in proj_dict:
         self.drawer.graph.load_from_dict(proj_dict['graph'])
         self.statusBar.showMessage(
             'Graph loaded from {}.'.format(file_name), 2000)
         logger.info('Successfully loaded graph: {}'.format(file_name))
     if 'train' in proj_dict:
         TrainParamServer().load_from_dict(proj_dict['train'])
     self.settings.setValue('graph_file', file_name)
     self.update_data_label()
     self.setupNodeLib()

     # Project name = file name without directory and without a trailing
     # '.json' (only the suffix is stripped, not inner occurrences).
     project_name = os.path.basename(file_name)
     if project_name.endswith('.json'):
         project_name = project_name[:-len('.json')]
     TrainParamServer()['ProjectName'] = project_name
Example #2
0
 def compile_init(self, nodes):
     """Build the link-declaration lines of the generated network.

     :param nodes: Mapping whose values are graph nodes; only ``Link``
         instances contribute a line of the form ``name=<init call>``.
     :return: The declaration lines joined by newlines, or '' when a Link
         node could not be rendered (an error dialog is shown in that case).
     """
     links = []
     for node in nodes.values():
         if not isinstance(node, Link):
             continue
         try:
             links.append('            {0}={1}'.format(
                 node.get_name(), node.call_init()))
         except Exception:
             # Presumably raised when one of the node's parameters is unset
             # (per the message below) — TODO confirm the exact types.
             # Exception (not a bare except) lets KeyboardInterrupt and
             # SystemExit propagate.
             util.disp_error(
                 'Unset parameter was found in {0}'.format(node))
             return ''
     return '\n'.join(links)
Example #3
0
 def open_data_config(self):
     """Show the dataset configuration dialog for the current task.

     Tasks whose name contains 'Image' get ``ImageDataDialog``, which needs
     chainercv; if an ImportError occurs while probing for it or building the
     dialog, an error dialog is shown and no configuration dialog opens.
     Every other task gets ``DataDialog``.
     """
     if 'Image' in TrainParamServer()['Task']:
         try:
             import chainercv  # noqa: F401 -- only probes availability
             data_dialog = ImageDataDialog(self, settings=self.settings)
         except ImportError:
             # The two literals are concatenated, so the first one needs a
             # trailing space to separate the sentences.
             util.disp_error('Failed to import chainercv. '
                             'See https://github.com/chainer/chainercv#installation')
             return
     else:
         data_dialog = DataDialog(self, settings=self.settings)
     data_dialog.show()
     self.update_data_label()
Example #4
0
 def compile(self):
     """
     Translate the node graph into chainer code via ``compiler.Compiler``.

     :return: The compiler's result when it completes, or False when a
         parameter is unset or no loss function has been placed. Both
         failures show an error dialog; for an unset parameter the offending
         node is also flagged with ``runtime_error_happened``.
     """
     try:
         outcome = compiler.Compiler()(self.nodes)
     except util.ExistsInvalidParameter as exc:
         param_name = exc.args[1][1:]
         node_id = exc.args[0]
         util.disp_error('{0} is not set @{1}'.format(param_name, node_id))
         self.nodes[node_id].runtime_error_happened = True
         return False
     except compiler.NoLossError:
         util.disp_error('Please place loss function.')
         return False
     else:
         return outcome
Example #5
0
 def __call__(self, nodes, **kwargs):
     """Compile ``nodes`` into a chainer script and write it to disk.

     :param nodes: Mapping of node IDs to graph nodes.
     :param kwargs: Accepted for interface compatibility; unused here.
     :return: True when the script was written, False otherwise (an error
         dialog is shown here or by the compile step that failed).
     """
     if not nodes:
         util.disp_error('Please place nodes and connect them'
                         ' before compilation.')
         return False
     init_impl = self.compile_init(nodes)
     if not init_impl:
         return False
     call_result = self.compile_call(nodes)
     # compile_call() may return a bare '' on failure instead of a 3-tuple;
     # check before unpacking so that case does not raise ValueError.
     if not call_result:
         return False
     call_impl, pred_impl, lossID = call_result
     if not (call_impl and pred_impl):
         return False
     params = TrainParamServer()
     classification = 'Class' in params['Task']
     # The context manager guarantees the script is flushed and closed, even
     # if rendering one of the templates raises part-way through.
     with open(params.get_net_name(), 'w') as net_file:
         net_file.write(TEMPLATES['NetTemplate']()(
             params['NetName'], init_impl, call_impl, pred_impl,
             lossID, classification))
         net_file.write(TEMPLATES['OptimizerTemplate']()(params))
         net_file.write(TEMPLATES['TrainerTemplate']()(params))
     return True
Example #6
0
    def mouseReleaseEvent(self, event):
        """
        Handle a mouse-button release on the canvas.

        Completes a pending pin drag (connecting the dragged pin to the pin
        under the cursor when there is one), clears the drag flags, turns a
        drawn selection frame into a group selection, and repaints.

        :param event: The Qt mouse event being handled.
        """
        super(Painter2D, self).mouseReleaseEvent(event)
        # Only a left-button release that ends a drag started on a pin can
        # create a connection.
        if event.button() == Qt.LeftButton and self.looseConnection and self.clickedPin:
            valid = True
            # Pin IDs appear to be '<nodeID>:I<name>' for inputs and
            # '<nodeID>:O<name>' for outputs (as the partition calls imply).
            if ':I' in self.clickedPin:
                # Dragged from an input pin: the release must hit an output pin.
                input_nodeID, _, input_name = self.clickedPin.partition(':I')
                try:
                    output_nodeID, _, output_name = self.getOutputPinAt(
                        event.pos()).partition(':O')
                except AttributeError:
                    # Presumably getOutputPinAt() returned None (no pin under
                    # the cursor) — TODO confirm.
                    valid = False
            else:
                # Dragged from an output pin: the release must hit an input pin.
                output_nodeID, _, output_name = self.clickedPin.partition(':O')
                try:
                    input_nodeID, _, input_name = self.getInputPinAt(
                        event.pos()).partition(':I')
                except AttributeError:
                    # Presumably getInputPinAt() returned None — TODO confirm.
                    valid = False
            if valid:
                try:
                    self.graph.connect(output_nodeID, output_name, input_nodeID,
                                       input_name)
                    # self.update_graph_stack()
                except TypeError:
                    # graph.connect() appears to signal incompatible pin types
                    # with TypeError (going by the message shown here).
                    util.disp_error('Cannot connect pins of different type')
            # The drag is over whether or not a connection was made.
            self.looseConnection = False
            self.clickedPin = None
        self.drag = False
        self.downOverNode = False

        # Rubber-band selection: order the frame's corner coordinates so that
        # x1 <= x2 and y1 <= y2 before collecting the enclosed nodes.
        # NOTE(review): the guard tests selectFrame/selectFrame_End while the
        # coordinates are read from _selectFrame/_selectFrame_End; presumably
        # the public names are properties backed by the private ones — confirm.
        if self.selectFrame and self.selectFrame_End:
            x1, x2 = self._selectFrame.x(), self._selectFrame_End.x()
            if x1 > x2:
                x2, x1 = x1, x2
            y1, y2 = self._selectFrame.y(), self._selectFrame_End.y()
            if y1 > y2:
                y2, y1 = y1, y2
            self.groupSelection = self.massNodeCollide(x1, y1,
                                                       x2, y2)
        # The selection frame is cleared on every release.
        self.selectFrame = None
        self.selectFrame_End = None
        self.repaint()
        self.update()
Example #7
0
    def load_from_dict(self, graph_state):
        """
        Recreate saved nodes and their connections in this graph.

        :param graph_state: Iterable of ``(saved_id, node_data)`` pairs, where
            ``node_data`` is a dict read here for the keys 'class',
            'position', 'name', 'inputs', 'inputConnections' and, optionally,
            'subgraph' (default 'main'). The pairs are walked twice (nodes
            first, then connections), so the iterable is materialized first.
        :return: Dictionary mapping the saved node IDs to the IDs of the newly
            created nodes. Entries that cannot be spawned are reported and
            skipped; connections touching them are skipped with a warning.
        """
        graph_state = list(graph_state)
        id_map = {}
        for saved_id, node_data in graph_state:
            try:
                restored_node = self.spawnNode(NODECLASSES[node_data['class']],
                                               position=node_data['position'],
                                               id=saved_id,
                                               name=node_data['name'])
            except KeyError:
                # .get(): the error path must not itself raise KeyError when
                # the saved entry has no 'class' key.
                util.disp_error('Unknown Node class **{}**'.format(
                    node_data.get('class')))
                continue
            restored_node.subgraph = node_data.get('subgraph', 'main')
            id_map[saved_id] = restored_node.ID
            for saved_input in node_data['inputs']:
                # Entries appear to be [name, type, ..., value]; only plain
                # scalar inputs carry a stored value to restore.
                if saved_input[1] in ('bool', 'int', 'float'):
                    restored_node.inputs[saved_input[0]].set_value(
                        saved_input[-1])
        for saved_id, node_data in graph_state:
            for input_name, output_id in node_data['inputConnections'].items():
                output_node, output_name = output_id.split(':O')
                try:
                    output_node = id_map[output_node]
                    self.connect(str(output_node), output_name,
                                 str(id_map[saved_id]), input_name)
                except KeyError:
                    print('Warning: Could not create connection '
                          'due to missing node.')

        self.update()
        return id_map
Example #8
0
    def update_preview(self):
        """Show the next training image, augmented, in the preview widget.

        Images are the ``*.jpg`` files one directory level below
        ``TrainData``; successive calls cycle through them in sorted order.
        The current augmentation settings are applied and the result, scaled
        to 300 px wide with its aspect ratio kept, is shown in
        ``self.preview`` (via the temporary file ``preview_temp.jpg``).
        """
        self.commit_all()
        train_dir = TrainParamServer()['TrainData']
        # Sorted so the cycling order does not depend on filesystem order.
        image_files = sorted(glob.glob(train_dir + '/*/*.jpg'))
        if not image_files:
            self.image_file = None
            util.disp_error('No image was found in {}.'.format(train_dir))
            return
        self.image_idx = self.image_idx % len(image_files)
        self.image_file = image_files[self.image_idx]
        image_array = chainercv.utils.read_image_as_array(self.image_file)

        params = TrainParamServer()
        resize_width = params['ResizeWidth']
        resize_height = params['ResizeHeight']

        crop_edit = params['Crop']
        crop_width = params['CropWidth']
        crop_height = params['CropHeight']

        use_random_x_flip = params['UseRandomXFlip']
        use_random_y_flip = params['UseRandomYFlip']
        use_random_rotate = params['UseRandomRotation']
        pca_lighting = params['PCAlighting']

        # augment_data is fed channel-first data; assumes read_image_as_array
        # returns (H, W, C) — TODO confirm.
        image_array = image_array.transpose(2, 0, 1).astype(numpy.float32)
        image_array = augment_data(image_array, resize_width, resize_height,
                                   use_random_x_flip, use_random_y_flip,
                                   use_random_rotate, pca_lighting, crop_edit,
                                   crop_width, crop_height)
        # Augmentation (e.g. PCA lighting) may leave values outside [0, 255];
        # clip so the uint8 cast saturates instead of wrapping around.
        image_array = numpy.clip(image_array.transpose(1, 2, 0), 0, 255)
        image_array = image_array.astype(numpy.uint8)

        im = PIL.Image.fromarray(image_array)
        # fromarray() maps array rows to image height and columns to width,
        # while resize() takes (width, height): scale height by height/width.
        height, width = image_array.shape[:2]
        im = im.resize((300, max(1, int(300 * height / width))))
        im.save('preview_temp.jpg')

        pixmap = QtGui.QPixmap('preview_temp.jpg')
        self.preview.setPixmap(pixmap)

        self.image_idx += 1
Example #9
0
    def compile_call(self, nodes):
        """Build the loss calls and prediction bodies of the generated network.

        For every ``Loss`` node, ``compile_node`` returns a chain whose first
        element is used as the loss and whose remaining elements are function
        nodes. Those functions are emitted in reverse chain order, each taking
        the previous function's name as its argument (the first one takes
        ``x``), followed by a ``return`` of the last emitted name.

        :param nodes: Mapping whose values are graph nodes.
        :return: ``(loss_calls, pred_bodies, loss_name)``. The first two are
            the per-loss strings joined with ', '. ``loss_name`` is the name
            of the loss entry of the last ``Loss`` node processed. Returns ''
            instead when a function has an unavailable input; an error dialog
            is shown in that case.
        :raises NoLossError: If ``nodes`` contains no ``Loss`` node.
        """
        indent = ' ' * 8
        call_all_loss = []
        call_all_pred = []
        loss = None
        for node in nodes.values():
            if not isinstance(node, Loss):
                continue

            chains = self.compile_node(node, nodes, [])
            loss, funcs = chains[0], chains[1:]
            compiled_pred = []
            previous_ID = ''
            try:
                for func in reversed(funcs):
                    # The first emitted function consumes the network input x.
                    argument = previous_ID or 'x'
                    compiled_pred.append(
                        ''.join((indent, func.call(), argument, ')')))
                    previous_ID = func.get_name()
            except InputNotAvailable:
                util.disp_error(
                    'Unset parameter was found in {0}'.format(node))
                return ''

            compiled_pred.append(indent + 'return ' + previous_ID)
            call_all_pred.append('\n'.join(compiled_pred))
            call_all_loss.append(loss.call())

        if loss is None:
            raise NoLossError('Please place loss function.')
        return (', '.join(call_all_loss), ', '.join(call_all_pred),
                loss.get_name())
Example #10
0
 def compile_and_exe(self):
     """Compile the graph and, only if that succeeds, execute the result.

     An error dialog is shown when compilation fails.
     """
     if not self.compile_runner():
         util.disp_error('Compilation failed.')
         return
     self.exe_runner()
Example #11
0
 def close(self):
     """Stop any running training process, then quit the application.

     An ordinary exception from killing the runner is reported with an
     error dialog and quitting proceeds anyway.
     """
     try:
         self.drawer.graph.killRunner()
     except Exception:
         # Presumably raised when no runner was ever started (per the
         # message) — TODO confirm. Exception, not a bare except, so that
         # KeyboardInterrupt/SystemExit are not swallowed.
         util.disp_error('No runner to kill.')
     QtWidgets.qApp.quit()
Example #12
0
    def exe_prediction(self):
        """Run prediction and display the results in the result table.

        Returns immediately when the GPU option is set but
        ``util.check_cuda_available()`` is false. Otherwise runs the image or
        generic prediction runner (chosen by the task), writes the full
        result as CSV to ``PredOutputData`` when that parameter exists, and
        shows at most ``max_disp_rows`` rows, with labels appended as extra
        columns when the runner returned them. Expected failures are shown
        as error dialogs.
        """
        if TrainParamServer()['GPU'] and not util.check_cuda_available():
            return

        self.pred_progress.setText('Processing...')
        try:
            if 'Image' in TrainParamServer()['Task']:
                pred_runner = ImagePredictionRunner()
            else:
                pred_runner = PredictionRunner()
            result, label = pred_runner.run(self.classification.isChecked(),
                                            self.including_label.isChecked())
            if 'PredOutputData' in TrainParamServer().__dict__:
                numpy.savetxt(TrainParamServer()['PredOutputData'],
                              result,
                              delimiter=",")
            max_rows = self.max_disp_rows.value()
            result = result[:max_rows, :]
            if label is not None:
                label = label[:max_rows, :]
                result = numpy.hstack((result, label))
            self.result_table.setModel(ResultTableModel(result))
            self.pred_progress.setText('Prediction Finished!')
        except KeyError as ke:
            missing = ke.args[0] if ke.args else ''
            if missing == 'PredInputData':
                util.disp_error('Input Data for prediction is not set.')
            elif missing == 'PredModel':
                util.disp_error('Model for prediction is not set.')
            else:
                # A KeyError normally carries the missing key itself (a str);
                # indexing it with [0] would show only its first character.
                # Unwrap only when the argument is a non-empty list/tuple.
                if isinstance(missing, (list, tuple)) and missing:
                    missing = missing[0]
                util.disp_error(str(missing))
        except util.AbnormalDataCode as ac:
            if not os.path.isfile(TrainParamServer()['PredInputData']):
                util.disp_error('{} is not found'.format(
                    TrainParamServer()['PredInputData']))
                return
            if not os.path.isfile(TrainParamServer()['PredModel']):
                util.disp_error('{} is not found'.format(
                    TrainParamServer()['PredModel']))
                return
            util.disp_error(ac.args[0][0] + ' @' +
                            TrainParamServer()['PredInputData'])
        except ValueError:
            util.disp_error('Illegal data was found @' +
                            TrainParamServer()['PredInputData'])
        except type_check.InvalidType as error:
            last_node = util.get_executed_last_node()
            util.disp_error(str(error.args) + ' @node: ' + last_node)
        except FileNotFoundError as error:
            # filename/strerror are None when the error was raised without
            # them; fall back to the exception's own text in that case.
            if error.filename is not None and error.strerror is not None:
                util.disp_error('{}: {}'.format(error.filename, error.strerror))
            else:
                util.disp_error(str(error))
Example #13
0
 def commit(self):
     """Store this widget's text, converted to float, in the parameter server.

     The value is written under ``self.key``. When the conversion (or the
     store itself) raises ValueError, an error dialog is shown instead.
     """
     try:
         parsed = float(self.text())
         TrainParamServer()[self.key] = parsed
     except ValueError:
         util.disp_error('Optimizer parameter should be float.')
Example #14
0
    def run(self):
        """
        Run the compiled chainer script and report failures as error dialogs.

        Returns early, after showing an error, in two cases: the GPU option is
        set but no CUDA device is available, or constructing the runner raises
        ``SyntaxError`` (the generated script is invalid). For an InvalidType
        error, the last executed node is flagged with
        ``runtime_error_happened``.
        """
        self.clear_error()
        if TrainParamServer()['GPU'] and not util.check_cuda_available():
            # Literals are concatenated: keep the separating space.
            util.disp_error('GPU option is selected but available cuda device '
                            'is not found.')
            return

        try:
            self.runner = runner.TrainRunner()
        except SyntaxError:
            util.disp_error(
                'Generated chainer script ({}) is not valid.'.format(
                    TrainParamServer().get_net_name()))
            return
        try:
            self.runner.run()
        except util.AbnormalDataCode as error:
            util.disp_error(
                str(error.args[0][0]) + ' @' + TrainParamServer()['TrainData'])
        except ValueError as error:
            # str(error) equals args[0] for the usual single-argument case and
            # does not raise IndexError when the error has no arguments.
            util.disp_error('{0}\n'.format(error) +
                            'Illegal data was found @' +
                            TrainParamServer()['TrainData'])
        except FileNotFoundError as error:
            util.disp_error('{} is not found.'.format(error.filename))
        except util.UnexpectedFileExtension:
            util.disp_error('Unexpected file extension was found. '
                            'data should be ".csv", ".npz" or ".py"')
        except type_check.InvalidType as error:
            last_nodeID = util.get_executed_last_node()
            util.disp_error(str(error.args) + ' @node: ' + last_nodeID)
            self.nodes[last_nodeID].runtime_error_happened = True