def load_graph(self, override=''):
    """Load a Chainer Wing project file into the editor.

    When ``override`` is empty a file dialog is shown, starting in the
    training work directory; otherwise ``override`` is used as the path.
    The file is parsed *before* the current graph is cleared, so picking a
    corrupted or non-JSON file no longer wipes the project that is open.

    :param override: Path of the project file to load. An empty string
        means "ask the user with a file dialog".
    """
    if not override:
        init_path = TrainParamServer().get_work_dir()
        file_name = QtWidgets.QFileDialog.getOpenFileName(
            self, 'Open File', init_path,
            filter='Chainer Wing Files (*.json);; Any (*.*)')[0]
    else:
        file_name = override
    if not file_name:
        # Dialog cancelled (or empty override): nothing to load.
        return
    logger.debug('Attempting to load graph: {}'.format(file_name))
    with open(file_name, 'r') as fp:
        try:
            proj_dict = json.load(fp)
        except (json.decoder.JSONDecodeError, UnicodeDecodeError):
            # UnicodeDecodeError covers binary files picked via "Any (*.*)".
            util.disp_error(file_name + ' is corrupted.')
            return
    # Only discard the current graph once the new project parsed cleanly.
    self.drawer.clear_all_nodes()
    if 'graph' in proj_dict:
        self.drawer.graph.load_from_dict(proj_dict['graph'])
        self.statusBar.showMessage(
            'Graph loaded from {}.'.format(file_name), 2000)
        logger.info('Successfully loaded graph: {}'.format(file_name))
    if 'train' in proj_dict:
        TrainParamServer().load_from_dict(proj_dict['train'])
    self.settings.setValue('graph_file', file_name)
    self.update_data_label()
    self.setupNodeLib()
    # Project name is the file's base name with a trailing ".json" removed
    # (only the suffix, not every occurrence of ".json" in the name).
    project_name = file_name.split('/')[-1]
    if project_name.endswith('.json'):
        project_name = project_name[:-len('.json')]
    TrainParamServer()['ProjectName'] = project_name
def compile_init(self, nodes):
    """Generate the initializer lines that declare every Link node.

    :param nodes: Mapping of node ID to node instance; only values that are
        ``Link`` instances contribute a line.
    :return: The generated lines joined by newlines. Returns ``''`` if a
        Link node fails to produce its initializer (an error dialog is
        shown). Note that ``''`` is also the result when ``nodes`` holds no
        Link node at all.
    """
    links = []
    for node in nodes.values():
        if not isinstance(node, Link):
            continue
        try:
            links.append(' {0}={1}'.format(
                node.get_name(), node.call_init()))
        except Exception:
            # Narrowed from a bare ``except:`` so that KeyboardInterrupt and
            # SystemExit are no longer swallowed by the error dialog.
            util.disp_error(
                'Unset parameter was found in {0}'.format(node))
            return ''
    return '\n'.join(links)
def open_data_config(self):
    """Open the data-configuration dialog that matches the current task.

    Image tasks use ``ImageDataDialog``, which needs chainercv; if
    chainercv cannot be imported, an error with installation instructions
    is shown and no dialog is opened. Other tasks use ``DataDialog``.
    """
    if 'Image' in TrainParamServer()['Task']:
        try:
            import chainercv  # noqa: F401  (availability check only)
            data_dialog = ImageDataDialog(self, settings=self.settings)
        except ImportError:
            # The trailing space keeps the two sentences from running
            # together ("chainercv.See ...").
            util.disp_error('Failed to import chainercv. '
                            'See https://github.com/chainer/chainercv#installation')
            return
    else:
        data_dialog = DataDialog(self, settings=self.settings)
    data_dialog.show()
    self.update_data_label()
def compile(self):
    """Compile the node graph into chainer code.

    :return: Whatever ``compiler.Compiler()`` returns for ``self.nodes``
        when no handled error occurs. Returns ``False`` when a node has an
        unset parameter (that node is flagged with
        ``runtime_error_happened``) or when no loss function is placed.
    """
    try:
        return compiler.Compiler()(self.nodes)
    except util.ExistsInvalidParameter as error:
        node_id = error.args[0]
        param_name = error.args[1][1:]
        util.disp_error('{0} is not set @{1}'.format(param_name, node_id))
        self.nodes[node_id].runtime_error_happened = True
    except compiler.NoLossError:
        util.disp_error('Please place loss function.')
    return False
def __call__(self, nodes, **kwargs):
    """Compile ``nodes`` into a chainer script and write it to disk.

    The script is written to ``TrainParamServer().get_net_name()`` and
    consists of the network, optimizer and trainer templates, in that order.

    :param nodes: Mapping of node ID to node instance.
    :return: ``True`` if the script was written, ``False`` otherwise.
    """
    if not nodes:
        util.disp_error('Please place nodes and connect them'
                        ' before compilation.')
        return False
    init_impl = self.compile_init(nodes)
    if not init_impl:
        return False
    compiled_call = self.compile_call(nodes)
    if not compiled_call:
        # compile_call may signal failure with a bare empty value, which
        # cannot be unpacked into three names.
        return False
    call_impl, pred_impl, lossID = compiled_call
    if not (call_impl and pred_impl):
        return False
    classification = 'Class' in TrainParamServer()['Task']
    # ``with`` guarantees the generated script is flushed and closed even if
    # one of the template writers raises.
    with open(TrainParamServer().get_net_name(), 'w') as net_file:
        net_file.write(TEMPLATES['NetTemplate']()(
            TrainParamServer()['NetName'], init_impl, call_impl, pred_impl,
            lossID, classification))
        net_file.write(TEMPLATES['OptimizerTemplate']()(TrainParamServer()))
        net_file.write(TEMPLATES['TrainerTemplate']()(TrainParamServer()))
    return True
def mouseReleaseEvent(self, event):
    """Finish a pin-connection drag and/or a rubber-band selection.

    If a left-button drag started on a pin (``looseConnection`` and
    ``clickedPin`` set), try to connect it to the pin under the cursor.
    Then reset all drag state and, if a selection frame was drawn, select
    the nodes inside it.
    """
    super(Painter2D, self).mouseReleaseEvent(event)
    if event.button() == Qt.LeftButton and self.looseConnection and self.clickedPin:
        valid = True
        if ':I' in self.clickedPin:
            # Drag started on an input pin: the pin under the cursor must be
            # an output pin.
            input_nodeID, _, input_name = self.clickedPin.partition(':I')
            try:
                output_nodeID, _, output_name = self.getOutputPinAt(
                    event.pos()).partition(':O')
            except AttributeError:
                # presumably getOutputPinAt returns None when no pin is under
                # the cursor — TODO confirm
                valid = False
        else:
            # Drag started on an output pin: look for an input pin instead.
            output_nodeID, _, output_name = self.clickedPin.partition(':O')
            try:
                input_nodeID, _, input_name = self.getInputPinAt(
                    event.pos()).partition(':I')
            except AttributeError:
                valid = False
        if valid:
            try:
                self.graph.connect(output_nodeID, output_name,
                                   input_nodeID, input_name)
                # self.update_graph_stack()
            except TypeError:
                # graph.connect raises TypeError for incompatible pin types.
                util.disp_error('Cannot connect pins of different type')
    # Any mouse release ends a pending connection or node drag.
    self.looseConnection = False
    self.clickedPin = None
    self.drag = False
    self.downOverNode = False
    if self.selectFrame and self.selectFrame_End:
        # NOTE(review): the guard checks the public names but the
        # coordinates come from the underscore attributes — presumably
        # selectFrame is a property over _selectFrame; confirm.
        # Normalise the frame corners so (x1, y1) is the top-left.
        x1, x2 = self._selectFrame.x(), self._selectFrame_End.x()
        if x1 > x2:
            x2, x1 = x1, x2
        y1, y2 = self._selectFrame.y(), self._selectFrame_End.y()
        if y1 > y2:
            y2, y1 = y1, y2
        self.groupSelection = self.massNodeCollide(x1, y1, x2, y2)
    self.selectFrame = None
    self.selectFrame_End = None
    self.repaint()
    self.update()
def load_from_dict(self, graph_state):
    """Rebuild nodes and connections from a saved graph state.

    ``graph_state`` is iterated twice. Each item is a ``(saved_id,
    node_data)`` pair. The first pass spawns the nodes and restores
    bool/int/float input values. The second pass recreates the connections.
    Nodes with an unknown class are skipped with an error dialog.
    Connections that reference a missing node are skipped with a printed
    warning.

    :param graph_state: Iterable of ``(saved_id, node_data)`` pairs.
    :return: Dictionary mapping each saved node ID to the ID of the newly
        spawned node.
    """
    id_map = {}
    for saved_id, node_data in graph_state:
        try:
            node = self.spawnNode(NODECLASSES[node_data['class']],
                                  position=node_data['position'],
                                  id=saved_id,
                                  name=node_data['name'])
        except KeyError:
            util.disp_error('Unknown Node class **{}**'.format(
                node_data['class']))
            continue
        node.subgraph = node_data.get('subgraph', 'main')
        id_map[saved_id] = node.ID
        for pin in node_data['inputs']:
            if pin[1] not in ('bool', 'int', 'float'):
                continue
            node.inputs[pin[0]].set_value(pin[-1])

    for saved_id, node_data in graph_state:
        connections = node_data['inputConnections']
        for input_name, output_ref in connections.items():
            source_id, output_name = output_ref.split(':O')
            try:
                new_source_id = id_map[source_id]
                self.connect(str(new_source_id), output_name,
                             str(id_map[saved_id]), input_name)
            except KeyError:
                print('Warning: Could not create connection '
                      'due to missing node.')
    self.update()
    return id_map
def update_preview(self):
    """Show an augmented preview of the next training image.

    Images are looked up as ``<TrainData>/*/*.jpg``. Each call shows the
    next image (wrapping around) after applying the augmentation settings
    currently stored in ``TrainParamServer``. The preview is scaled to a
    width of 300 px, preserving the aspect ratio.
    """
    self.commit_all()
    image_files = glob.glob(TrainParamServer()['TrainData'] + '/*/*.jpg')
    if not image_files:
        self.image_file = None
        util.disp_error('No image was found in {}.'.format(
            TrainParamServer()['TrainData']))
        return
    self.image_idx = self.image_idx % len(image_files)
    self.image_file = image_files[self.image_idx]
    image_array = chainercv.utils.read_image_as_array(self.image_file)
    resize_width = TrainParamServer()['ResizeWidth']
    resize_height = TrainParamServer()['ResizeHeight']
    crop_edit = TrainParamServer()['Crop']
    crop_width = TrainParamServer()['CropWidth']
    crop_height = TrainParamServer()['CropHeight']
    use_random_x_flip = TrainParamServer()['UseRandomXFlip']
    use_random_y_flip = TrainParamServer()['UseRandomYFlip']
    use_random_rotate = TrainParamServer()['UseRandomRotation']
    pca_lighting = TrainParamServer()['PCAlighting']
    # augment_data works on channel-first float arrays.
    image_array = image_array.transpose(2, 0, 1).astype(numpy.float32)
    image_array = augment_data(image_array, resize_width, resize_height,
                               use_random_x_flip, use_random_y_flip,
                               use_random_rotate, pca_lighting, crop_edit,
                               crop_width, crop_height)
    # Augmentation (e.g. PCA lighting) can push values outside [0, 255];
    # clip so the uint8 cast saturates instead of wrapping around.
    image_array = numpy.clip(image_array, 0, 255)
    image_array = image_array.transpose(1, 2, 0).astype(numpy.uint8)
    im = PIL.Image.fromarray(image_array)
    # image_array is (height, width, channels) here, while PIL's resize
    # takes (width, height): keep width 300 and scale the height by H / W.
    height, width = image_array.shape[:2]
    im = im.resize((300, int(300 * height / width)))
    im.save('preview_temp.jpg')
    pixmap = QtGui.QPixmap('preview_temp.jpg')
    self.preview.setPixmap(pixmap)
    self.image_idx += 1
def compile_call(self, nodes):
    """Generate the loss-call and prediction code for every Loss node.

    For each ``Loss`` node, its chain is built with ``compile_node``. The
    first chain element is the loss. The remaining functions are emitted in
    reverse order as nested prediction calls, starting from ``x``.

    :param nodes: Mapping of node ID to node instance.
    :return: ``(loss_calls, pred_bodies, loss_name)`` on success, where the
        first two are comma-joined strings and ``loss_name`` is the name of
        the last Loss processed. Returns ``('', '', '')`` when a function has
        an unavailable input (an error dialog is shown).
    :raises NoLossError: If ``nodes`` contains no Loss node.
    """
    call_all_loss = []
    call_all_pred = []
    loss = None
    for node in nodes.values():
        if not isinstance(node, Loss):
            continue
        chains = self.compile_node(node, nodes, [])
        loss = chains[0]
        funcs = chains[1:]
        compiled_pred = []
        previous_ID = ''
        try:
            for func in reversed(funcs):
                if previous_ID:
                    pred_call = ''.join(
                        (' ' * 8, func.call(), previous_ID, ')'))
                else:
                    # The first function in the chain consumes the input x.
                    pred_call = ''.join((' ' * 8, func.call(), 'x)'))
                compiled_pred.append(pred_call)
                previous_ID = func.get_name()
        except InputNotAvailable:
            util.disp_error(
                'Unset parameter was found in {0}'.format(node))
            # Return three empty strings so callers that unpack the result
            # into three names get a clean failure instead of a ValueError.
            return '', '', ''
        compiled_pred.append(' return ' + previous_ID)
        call_all_pred.append('\n'.join(compiled_pred))
        call_all_loss.append(loss.call())
    if loss is None:
        raise NoLossError('Please place loss function.')
    return ', '.join(call_all_loss), ', '.join(
        call_all_pred), loss.get_name()
def compile_and_exe(self):
    """Compile the graph and, only if compilation succeeds, run it.

    On failure an error dialog is shown and nothing is executed.
    """
    if self.compile_runner():
        self.exe_runner()
    else:
        util.disp_error('Compilation failed.')
def close(self):
    """Stop any running training process and quit the application.

    If stopping the runner fails, an error dialog is shown and the
    application still quits.
    """
    try:
        self.drawer.graph.killRunner()
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # are not mistaken for "no runner".
        util.disp_error('No runner to kill.')
    QtWidgets.qApp.quit()
def exe_prediction(self):
    """Run prediction with the current settings and show the result table.

    Image tasks use ``ImagePredictionRunner``; other tasks use
    ``PredictionRunner``. If ``PredOutputData`` is configured, the full
    result is also saved as CSV. At most ``max_disp_rows`` rows are
    displayed, with labels appended as extra columns when present. Known
    failures are reported through error dialogs instead of propagating.
    """
    if TrainParamServer()['GPU'] and not util.check_cuda_available():
        # Tell the user why nothing happens instead of returning silently.
        util.disp_error('GPU option is selected but available cuda device '
                        'is not found.')
        return
    self.pred_progress.setText('Processing...')
    try:
        if 'Image' in TrainParamServer()['Task']:
            runner = ImagePredictionRunner()
        else:
            runner = PredictionRunner()
        result, label = runner.run(self.classification.isChecked(),
                                   self.including_label.isChecked())
        if 'PredOutputData' in TrainParamServer().__dict__:
            numpy.savetxt(TrainParamServer()['PredOutputData'], result,
                          delimiter=",")
        result = result[:self.max_disp_rows.value(), :]
        if label is not None:
            label = label[:self.max_disp_rows.value(), :]
            result = numpy.hstack((result, label))
        self.result_table.setModel(ResultTableModel(result))
        self.pred_progress.setText('Prediction Finished!')
    except KeyError as ke:
        if ke.args[0] == 'PredInputData':
            util.disp_error('Input Data for prediction is not set.')
        elif ke.args[0] == 'PredModel':
            util.disp_error('Model for prediction is not set.')
        else:
            util.disp_error(ke.args[0][0])
    except util.AbnormalDataCode as ac:
        # Prefer the more helpful "file not found" messages when applicable.
        if not os.path.isfile(TrainParamServer()['PredInputData']):
            util.disp_error('{} is not found'.format(
                TrainParamServer()['PredInputData']))
            return
        if not os.path.isfile(TrainParamServer()['PredModel']):
            util.disp_error('{} is not found'.format(
                TrainParamServer()['PredModel']))
            return
        util.disp_error(ac.args[0][0] + ' @' +
                        TrainParamServer()['PredInputData'])
    except ValueError:
        util.disp_error('Irregular data was found @' +
                        TrainParamServer()['PredInputData'])
    except type_check.InvalidType as error:
        last_node = util.get_executed_last_node()
        # format() instead of "+" so an unknown (None) node id cannot raise
        # TypeError inside the error handler.
        util.disp_error('{} @node: {}'.format(error.args, last_node))
    except FileNotFoundError as error:
        # filename / strerror may be None depending on how it was raised.
        util.disp_error('{}: {}'.format(error.filename,
                                        error.strerror or error))
def commit(self):
    """Store the widget's text under ``self.key`` as a float parameter.

    If the text does not parse as a float, an error dialog is shown and
    nothing is stored.
    """
    try:
        value = float(self.text())
    except ValueError:
        util.disp_error('Optimizer parameter should be float.')
        return
    TrainParamServer()[self.key] = value
def run(self):
    """Run the compiled chainer training script.

    Reports, via error dialogs instead of raising: a missing CUDA device
    when the GPU option is set, an invalid generated script, and the known
    failures of the training run itself (bad data, missing files,
    unexpected extensions, type-check errors). For type-check errors the
    last executed node is flagged with ``runtime_error_happened``.
    """
    self.clear_error()
    if TrainParamServer()['GPU'] and not util.check_cuda_available():
        util.disp_error('GPU option is selected but available cuda device '
                        'is not found.')
        return
    try:
        self.runner = runner.TrainRunner()
    except SyntaxError:
        util.disp_error(
            'Generated chainer script ({}) is not valid.'.format(
                TrainParamServer().get_net_name()))
        return
    try:
        self.runner.run()
    except util.AbnormalDataCode as error:
        util.disp_error(
            str(error.args[0][0]) + ' @' + TrainParamServer()['TrainData'])
    except ValueError as error:
        util.disp_error('{0}\n'.format(error.args[0]) +
                        'Irregular data was found @' +
                        TrainParamServer()['TrainData'])
    except FileNotFoundError as error:
        util.disp_error('{} is not found.'.format(error.filename))
    except util.UnexpectedFileExtension:
        util.disp_error('Unexpected file extension was found. '
                        'data should be ".csv", ".npz" or ".py"')
    except type_check.InvalidType as error:
        last_nodeID = util.get_executed_last_node()
        # format() so an unknown (None) node id cannot raise TypeError here.
        util.disp_error('{} @node: {}'.format(error.args, last_nodeID))
        # Only flag the node if it exists in the current graph; otherwise
        # a KeyError would escape the error handler.
        if last_nodeID in self.nodes:
            self.nodes[last_nodeID].runtime_error_happened = True