def get_data_pred(self):
    """Write the list of images to predict into ``pred_label.txt``.

    ``TrainParamServer()['PredInputData']`` may be either a directory or a
    single file. For a directory, images with any extension returned by
    ``util.for_image_extensions()`` are collected from the directory itself
    and from its immediate sub-directories.

    :return: path of the written list file (``<work_dir>/pred_label.txt``).
    :raises Exception: if the directory contains no image file.
    :raises FileNotFoundError: if ``PredInputData`` is neither a directory
        nor a file.
    """
    train_server = TrainParamServer()
    pred_input = train_server['PredInputData']
    if os.path.isdir(pred_input):
        image_files = []
        for ext in util.for_image_extensions():
            image_files += glob.glob(pred_input + '/*.{}'.format(ext))
            image_files += glob.glob(pred_input + '/*/*.{}'.format(ext))
        if not image_files:
            raise Exception('No image file in {}'.format(pred_input))
    elif os.path.isfile(pred_input):
        image_files = [pred_input]
    else:
        raise FileNotFoundError(pred_input + ' is not found.')

    pred_label_file = os.path.join(train_server.get_work_dir(),
                                   'pred_label.txt')
    with open(pred_label_file, 'w') as fw:
        # One image path per line; every collected image must be listed.
        for image in image_files:
            fw.write(image + '\n')
    return pred_label_file
def get_data_train(self):
    """Prepare the image lists, label table and mean image for training.

    Training images come from ``TrainData``. When ``UseSameData`` is set,
    those images are split by ``TestDataRatio`` (shuffled first if
    ``Shuffle`` is set); otherwise the test images are read from
    ``TestData``.

    Writes ``label_conversion.txt``, ``train_label.txt`` and
    ``test_label.txt`` into the work directory, then calls
    ``self.compute_mean`` on the training images.
    """
    train_server = TrainParamServer()
    train_images, train_labels = self.get_all_images(train_server['TrainData'])
    if train_server['UseSameData']:
        split_idx = int(len(train_images) * train_server['TestDataRatio'])
        indices = numpy.arange(len(train_images))
        if train_server['Shuffle']:
            numpy.random.shuffle(indices)
        # NOTE(review): the first TestDataRatio fraction becomes the
        # *training* split and the remainder the test split; confirm intent.
        train_idx = indices[:split_idx]
        test_idx = indices[split_idx:]
        test_images = train_images[test_idx]
        test_labels = train_labels[test_idx]
        train_images = train_images[train_idx]
        train_labels = train_labels[train_idx]
    else:
        # A separate test set: reading TrainData here would make the test
        # set identical to the training set.
        test_images, test_labels = self.get_all_images(
            train_server['TestData'])

    all_labels = sorted(set(numpy.hstack((train_labels, test_labels))))
    work_dir = train_server.get_work_dir()

    label_conversion_file = os.path.join(work_dir, 'label_conversion.txt')
    self.make_label_conversion_file(all_labels, label_conversion_file)

    train_label_file = os.path.join(work_dir, 'train_label.txt')
    self.make_image_list(train_images, train_labels, train_label_file)
    test_label_file = os.path.join(work_dir, 'test_label.txt')
    self.make_image_list(test_images, test_labels, test_label_file)

    self.compute_mean(train_images)
def compute_mean(self, images):
    """Average the augmented training images and save the result.

    Each image is read as float32, its axes are reordered with
    ``transpose(2, 0, 1)`` (presumably HWC -> CHW), it is passed through
    ``augment_data`` with the current augmentation settings, and it is
    accumulated. The mean is saved to ``<work_dir>/mean.npy``.

    :param images: sequence of image paths accepted by
        ``_read_image_as_array``.
    """
    print('compute mean image')
    total = 0
    count = len(images)

    server = TrainParamServer()
    width = server['ResizeWidth']
    height = server['ResizeHeight']
    crop_mode = server['Crop']
    crop_w = server['CropWidth']
    crop_h = server['CropHeight']
    x_flip = server['UseRandomXFlip']
    y_flip = server['UseRandomYFlip']
    rotate = server['UseRandomRotation']
    pca = server['PCAlighting']

    for path in images:
        pixels = _read_image_as_array(path, numpy.float32)
        pixels = pixels.transpose(2, 0, 1).astype(numpy.float32)
        total += augment_data(pixels, width, height, x_flip, y_flip,
                              rotate, pca, crop_mode, crop_w, crop_h)

    destination = os.path.join(TrainParamServer().get_work_dir(), 'mean.npy')
    numpy.save(destination, total / count)
def load_graph(self, override=''):
    """Load a project JSON file into the editor.

    :param override: path of the file to load. When empty, a file dialog is
        opened in the work directory; cancelling it does nothing.

    The current graph is cleared only after the file has been parsed, so
    opening a corrupted file reports an error without discarding the graph
    that is currently displayed. The ``graph`` and ``train`` sections are
    loaded when present, and the project name becomes the file's base name
    without its ``.json`` suffix.
    """
    if override:
        file_name = override
    else:
        init_path = TrainParamServer().get_work_dir()
        file_name = QtWidgets.QFileDialog.getOpenFileName(
            self, 'Open File', init_path,
            filter='Chainer Wing Files (*.json);; Any (*.*)')[0]
    if not file_name:
        return

    logger.debug('Attempting to load graph: {}'.format(file_name))
    with open(file_name, 'r') as fp:
        try:
            proj_dict = json.load(fp)
        except json.decoder.JSONDecodeError:
            util.disp_error(file_name + ' is corrupted.')
            return

    self.drawer.clear_all_nodes()
    if 'graph' in proj_dict:
        self.drawer.graph.load_from_dict(proj_dict['graph'])
        self.statusBar.showMessage(
            'Graph loaded from {}.'.format(file_name), 2000)
        logger.info('Successfully loaded graph: {}'.format(file_name))
    if 'train' in proj_dict:
        TrainParamServer().load_from_dict(proj_dict['train'])
    self.settings.setValue('graph_file', file_name)
    self.update_data_label()
    self.setupNodeLib()

    # Strip only a trailing '.json', not every occurrence in the name.
    project_name = os.path.basename(file_name)
    if project_name.endswith('.json'):
        project_name = project_name[:-len('.json')]
    TrainParamServer()['ProjectName'] = project_name
def __init__(self):
    """Load the generated network module and set up training state.

    The network file named by ``TrainParamServer().get_net_name()`` is
    loaded as module ``net_run``. The progress bar is sized by the
    ``Epoch`` setting, and no chainerui server is running yet.
    """
    server = TrainParamServer()
    self.module = machinery.SourceFileLoader(
        'net_run', server.get_net_name()).load_module()
    # Progress bar should be initialized after loading module file.
    self.pbar = CWProgressBar(server['Epoch'])
    self.chainerui_server = None
def __init__(self, settings, parent):
    """Combo box choosing how input data is preprocessed.

    The initial selection is the project's ``PreProcessor_idx`` if that is
    present, otherwise the persisted ``PreProcessor`` setting. The selected
    text is published as ``TrainParamServer()['PreProcessor']``.
    """
    self.parent = parent
    self.settings = settings
    super(PreProcessorEdit, self).__init__()
    self.addItems(('Do Nothing', 'MinMax Scale'))

    server = TrainParamServer()
    if 'PreProcessor_idx' in server.__dict__:
        start_index = server['PreProcessor_idx']
    else:
        start_index = settings.value('PreProcessor', type=int)
    self.setCurrentIndex(start_index)
    server['PreProcessor'] = self.currentText()
def __init__(self, settings, parent):
    """Combo box choosing the cropping mode for image data.

    The initial selection is the project's ``Crop_idx`` if that is present,
    otherwise the persisted ``Crop`` setting. The selected text is published
    as ``TrainParamServer()['Crop']``.
    """
    self.parent = parent
    self.settings = settings
    super(CropEdit, self).__init__()
    self.addItems(('Do Nothing', 'Center Crop', 'Random Crop'))

    server = TrainParamServer()
    if 'Crop_idx' in server.__dict__:
        start_index = server['Crop_idx']
    else:
        start_index = settings.value('Crop', type=int)
    self.setCurrentIndex(start_index)
    server['Crop'] = self.currentText()
def __init__(self, settings, parent, key):
    """Check box bound to the boolean parameter ``key``.

    A value already held by ``TrainParamServer`` takes precedence over the
    persisted QSettings value. The resulting check state is written back to
    ``TrainParamServer()[key]``.
    """
    self.parent = parent
    self.settings = settings
    super(DataCheckBox, self).__init__()
    self.key = key

    checked = settings.value(key, type=bool)
    server = TrainParamServer()
    if key not in server.__dict__:
        server[key] = checked
    else:
        checked = server[key]
    self.setChecked(checked)
    server[key] = self.isChecked()
def __init__(self, settings, parent, key, data_type=float):
    """Line edit bound to the numeric parameter ``key``.

    :param data_type: type used to read the persisted setting.

    A value already held by ``TrainParamServer`` takes precedence. Otherwise
    the persisted setting is used and stored in ``TrainParamServer``.
    """
    super(DataLineEdit, self).__init__()
    self.parent = parent
    self.settings = settings
    self.data_type = data_type
    self.key = key

    # NOTE(review): any falsy stored value (including 0) falls back to 100.
    initial = settings.value(key, type=data_type) or 100
    server = TrainParamServer()
    if key not in server.__dict__:
        server[key] = initial
    else:
        initial = server[key]
    self.setText(str(initial))
def __init__(self, settings, parent, key):
    """'Browse' button that selects a data directory for ``key``.

    The current value comes from ``TrainParamServer`` if present. Otherwise
    it comes from the persisted setting (``'./'`` when that is empty) and is
    stored in ``TrainParamServer``. The value is shown in a
    ``DataFileLabel``, and clicking the button opens the directory dialog.
    """
    self.parent = parent
    self.settings = settings
    super(DataDirEdit, self).__init__('Browse')

    persisted = settings.value(key, type=str) or './'
    server = TrainParamServer()
    if key not in server.__dict__:
        self.value = persisted
        server[key] = persisted
    else:
        self.value = server[key]

    self.key = key
    self.label = DataFileLabel(settings, parent, key)
    self.label.setText(self.value)
    self.clicked.connect(self.open_dialog)
def update_report(self):
    """Replace the first two tabs with fresh Loss and Accuracy graphs.

    Images are taken from the result directory. If looking it up raises
    KeyError, they fall back to ``result/``.
    """
    def image_path(suffix, fallback):
        try:
            return TrainParamServer().get_result_dir() + suffix
        except KeyError:
            return fallback

    for _ in range(2):
        self.removeTab(0)

    self.loss_widget = GraphWidget(
        image_path("/loss.png", "result/loss.png"), parent=self)
    self.addTab(self.loss_widget, 'Loss')

    self.acc_widget = GraphWidget(
        image_path("/accuracy.png", "result/accuracy.png"), parent=self)
    self.addTab(self.acc_widget, 'Accuracy')
def exe_prediction(self):
    """Run prediction and show the results in the result table.

    Image tasks use ``ImagePredictionRunner``; other tasks use
    ``PredictionRunner``. When ``PredOutputData`` is set, the full result is
    saved there as CSV. At most ``max_disp_rows`` rows are displayed, with
    the labels appended as extra columns when they are returned.

    Nothing happens if GPU use is requested but CUDA is unavailable. Known
    failure modes are reported with ``util.disp_error`` rather than raised.
    """
    train_server = TrainParamServer()
    if train_server['GPU'] and not util.check_cuda_available():
        return
    self.pred_progress.setText('Processing...')
    try:
        if 'Image' in train_server['Task']:
            runner = ImagePredictionRunner()
        else:
            runner = PredictionRunner()
        result, label = runner.run(self.classification.isChecked(),
                                   self.including_label.isChecked())
        if 'PredOutputData' in train_server.__dict__:
            numpy.savetxt(train_server['PredOutputData'], result,
                          delimiter=",")
        max_rows = self.max_disp_rows.value()
        result = result[:max_rows, :]
        if label is not None:
            label = label[:max_rows, :]
            result = numpy.hstack((result, label))
        self.result_table.setModel(ResultTableModel(result))
        self.pred_progress.setText('Prediction Finished!')
    except KeyError as ke:
        # args[0] is the missing key itself (a string), not a sequence of
        # messages, so it must not be indexed again.
        missing = ke.args[0] if ke.args else ''
        if missing == 'PredInputData':
            util.disp_error('Input Data for prediction is not set.')
        elif missing == 'PredModel':
            util.disp_error('Model for prediction is not set.')
        else:
            util.disp_error(str(missing))
    except util.AbnormalDataCode as ac:
        pred_input = train_server['PredInputData']
        if not os.path.isfile(pred_input):
            util.disp_error('{} is not found'.format(pred_input))
            return
        pred_model = train_server['PredModel']
        if not os.path.isfile(pred_model):
            util.disp_error('{} is not found'.format(pred_model))
            return
        # AbnormalDataCode is expected to wrap the original exception's args
        # tuple; fall back gracefully if it carries a plain message instead.
        details = ac.args[0] if ac.args else ''
        if isinstance(details, (tuple, list)):
            details = details[0] if details else ''
        util.disp_error(str(details) + ' @' + pred_input)
    except ValueError:
        util.disp_error('Irregular data was found @' +
                        train_server['PredInputData'])
    except type_check.InvalidType as error:
        last_node = util.get_executed_last_node()
        util.disp_error(str(error.args) + ' @node: ' + last_node)
    except FileNotFoundError as error:
        # FileNotFoundError(message) (as raised for a missing prediction
        # input) has no filename/strerror; only OS-raised errors have them.
        if error.filename is None:
            util.disp_error(str(error))
        else:
            util.disp_error('{}: {}'.format(error.filename, error.strerror))
def commit(self):
    """Parse the text with ``data_type`` and persist it under ``key``.

    The parsed value is written to QSettings and ``TrainParamServer``.
    Text that fails to parse (ValueError) leaves both untouched.
    """
    try:
        parsed = self.data_type(self.text())
        self.settings.setValue(self.key, parsed)
        TrainParamServer()[self.key] = parsed
    except ValueError:
        pass
def __init__(self, label, window):
    """Configure the prediction-input file selector.

    The dialog filter offers images for image tasks, and CSV/NPZ/Python
    files otherwise.
    """
    super(PredInputDataConfig, self).__init__(label, window)
    self.direction = 'Input Data File is not selected.'
    image_task = 'Image' in TrainParamServer()['Task']
    self.filter = ('(*.jpg *.png);; Any (*.*)' if image_task
                   else '(*.csv *.npz *.py);; Any (*.*)')
class DataFileEdit(QtWidgets.QPushButton):
    """'Browse' button that selects a data file for the parameter ``key``.

    The chosen path is shown in a ``DataFileLabel``. ``commit`` stores it in
    both QSettings and ``TrainParamServer``.
    """

    def __init__(self, settings, parent, key):
        self.parent = parent
        self.settings = settings
        super(DataFileEdit, self).__init__('Browse')

        persisted = settings.value(key, type=str) or './'
        server = TrainParamServer()
        if key not in server.__dict__:
            self.value = persisted
            server[key] = persisted
        else:
            self.value = server[key]

        self.key = key
        self.label = DataFileLabel(settings, parent, key)
        self.label.setText(self.value)
        self.clicked.connect(self.open_dialog)

    def commit(self):
        """Persist the selected path to QSettings and TrainParamServer."""
        self.settings.setValue(self.key, self.value)
        TrainParamServer()[self.key] = self.value

    def open_dialog(self):
        """Ask for a data file, starting in the work directory."""
        start_dir = TrainParamServer().get_work_dir()
        chosen = QtWidgets.QFileDialog.getOpenFileName(
            self, 'Select Data File', start_dir,
            filter='(*.csv *.npz *.py);; Any (*.*)')[0]
        if not chosen:
            return
        self.value = chosen
        self.label.setText(chosen)
        self.parent.state_changed(0)

    def python_selected(self):
        """Return True when the selected path is a Python script."""
        return self.value.endswith('.py')
def update_node_list(self, text=''):
    """
    Filter the node classes by the typed text and show them, sorted by
    name, in ``self.listView``.

    :param text: filter string from the LineEdit. It is lower-cased and
        passed to each node class's ``matchHint``.
    :return: None
    """
    # NOTE(review): the first character of the text is dropped before
    # matching; confirm the LineEdit always supplies a one-character prefix.
    hint = text.lower()[1:]
    # Image nodes are only offered for image tasks.
    allow_image_nodes = 'Image' in TrainParamServer()['Task']
    matches = sorted(
        name for name, node_cls in NODECLASSES.items()
        if node_cls.matchHint(hint)
        and (allow_image_nodes or not node_cls.is_image_node))

    model = QStandardItemModel()
    for name in matches:
        entry = QStandardItem()
        entry.setText(name)
        entry.setToolTip(NODECLASSES[name].doc())
        model.appendRow(entry)
    self.listView.setModel(model)
def get_executed_last_node():
    """Name the network layer that was executing when the current exception was raised.

    Intended to be called inside an ``except`` block. The innermost frame of
    the active traceback is walked back to the outermost frame. Only frames
    from the generated network file are considered: the first ``_predict``
    frame wins, otherwise the last ``__call__`` frame is used. The first
    token of that source line, with ``self.`` removed, is returned.

    :return: the layer name, or ``'unknown'`` when there is no active
        exception or no matching frame/line is found.
    """
    tb = sys.exc_info()[2]
    if tb is None:
        return 'unknown'

    net_name = TrainParamServer().get_net_name()

    def get_last_lineno(stack):
        last_lineno_candidate = None
        for frame in stack:
            if frame.f_code.co_filename != net_name:
                continue
            if frame.f_code.co_name == '__call__':
                last_lineno_candidate = frame.f_lineno
            if frame.f_code.co_name == '_predict':
                return frame.f_lineno
        return last_lineno_candidate

    while tb.tb_next:
        tb = tb.tb_next
    stack = []
    f = tb.tb_frame
    while f:
        stack.append(f)
        f = f.f_back
    stack.reverse()  # outermost frame first

    lineno = get_last_lineno(stack)
    if lineno is None:
        return 'unknown'

    last_node = 'unknown'
    with open(net_name, 'r') as net_file:
        for i, line in enumerate(net_file, start=1):
            if i == lineno:
                last_node = line.strip().split(' ')[0].replace('self.', '')
                break
    return last_node
def __init__(self, path, mean, dtype=numpy.float32):
    """Dataset over the image list ``path`` that captures augmentation settings.

    :param path: image list file, resolved relative to the work directory
        by ``chainer.datasets.ImageDataset``.
    :param mean: mean image; stored as float32 (``astype('f')``).
    :param dtype: dtype passed along for later use.
    """
    server = TrainParamServer()
    self.base = chainer.datasets.ImageDataset(path, server.get_work_dir())
    self.mean = mean.astype('f')
    self.dtype = dtype
    # Augmentation settings are read once, at construction time.
    self.resize_width = server['ResizeWidth']
    self.resize_height = server['ResizeHeight']
    self.crop_edit = server['Crop']
    self.crop_width = server['CropWidth']
    self.crop_height = server['CropHeight']
    self.use_random_x_flip = server['UseRandomXFlip']
    self.use_random_y_flip = server['UseRandomYFlip']
    self.use_random_rotate = server['UseRandomRotation']
    self.pca_lighting = server['PCAlighting']
def open_dialog(self):
    """Ask for a directory (starting in the work directory) and store it.

    On a non-empty choice, the value and label are updated and
    ``parent.state_changed(0)`` is notified.
    """
    chosen = QtWidgets.QFileDialog.getExistingDirectory(
        self, 'Select Directory', TrainParamServer().get_work_dir())
    if not chosen:
        return
    self.value = chosen
    self.label.setText(chosen)
    self.parent.state_changed(0)
def run(self):
    """Train the generated network and report where the model was saved.

    Ensures ``<WorkDir>/result`` exists. When chainerui is available, the
    result directory is registered as a chainerui project, the chainerui
    server is started once, and the browser is opened on it. Image tasks
    build the image lists and mean image through ``ImageDataManager``;
    other tasks get their datasets from ``DataManager``.
    """
    train_server = TrainParamServer()
    result_dir = train_server['WorkDir'] + '/result'
    os.makedirs(result_dir, exist_ok=True)

    if _chainerui_available:
        # Pass argument lists (no shell) so directories or project names
        # containing spaces/shell metacharacters reach chainerui intact.
        try:
            subprocess.call(['chainerui', 'project', 'create',
                             '-d', result_dir,
                             '-n', train_server['ProjectName']])
            if self.chainerui_server is None:
                self.chainerui_server = subprocess.Popen(
                    ['chainerui', 'server'])
                time.sleep(0.5)
                webbrowser.open('http://localhost:5000/')
        except OSError as error:
            # The CLI is optional: keep training even if it cannot start.
            import logging
            logging.getLogger(__name__).warning(
                'Could not run chainerui: %s', error)

    if 'Image' in train_server['Task']:
        ImageDataManager().get_data_train()
        work_dir = train_server.get_work_dir()
        train_label_file = os.path.join(work_dir, 'train_label.txt')
        test_label_file = os.path.join(work_dir, 'test_label.txt')
        mean = numpy.load(os.path.join(work_dir, 'mean.npy'))
        train_data = PreprocessedDataset(train_label_file, mean)
        test_data = PreprocessedDataset(test_label_file, mean)
    else:
        train_data, test_data = DataManager().get_data_train()

    self.module.training_main(train_data, test_data, self.pbar,
                              cw_postprocess)
    util.disp_message('Training is finished. Model file is saved to ' +
                      train_server.get_model_name() + '.npz',
                      title='Training is finished')
def __init__(self, *args, settings=None):
    """Build the prediction window, wire its widgets, and restore options.

    Saved ``IncludingLabel``, ``PredClass`` and (for image tasks)
    ``SelectByDir`` values are restored from ``TrainParamServer``.
    Selecting by directory is only possible for image tasks, where it
    defaults to checked.
    """
    self.settings = settings
    super(PredictionWindow, self).__init__(*args)
    self.setupUi(self)

    self.input_sel_button.clicked.connect(self.set_input)
    self.input_config = PredInputDataConfig(self.input_data_name, self)
    self.output_sel_button.clicked.connect(self.set_output)
    self.output_config = PredOutputDataConfig(self.output_name, self)
    self.model_sel_button.clicked.connect(self.set_model)
    self.model_config = PredModelConfig(self.model_name, self)
    self.exe_button.clicked.connect(self.exe_prediction)
    self.including_label.stateChanged.connect(self.set_including_label)
    self.select_by_dir.stateChanged.connect(self.set_select_by_dir)

    server = TrainParamServer()
    if 'IncludingLabel' in server.__dict__:
        self.including_label.setChecked(server['IncludingLabel'])
    if 'PredClass' in server.__dict__:
        self.classification.setChecked(server['PredClass'])

    if 'Image' in server['Task']:
        self.select_by_dir.setEnabled(True)
        # Look the saved option up among the stored parameters (not inside
        # the Task string) so a previous choice is actually restored.
        if 'SelectByDir' in server.__dict__:
            self.select_by_dir.setChecked(server['SelectByDir'])
        else:
            self.select_by_dir.setChecked(True)
    else:
        self.select_by_dir.setEnabled(False)
        self.select_by_dir.setChecked(False)
def get_last_lineno(stack, net_name=None):
    """Return the generated-network line number responsible for an error.

    Only frames whose code lives in the network file are considered. The
    first ``_predict`` frame wins immediately; otherwise the line of the last
    ``__call__`` frame is returned.

    :param stack: frames to scan, in order (get_executed_last_node builds
        it outermost first).
    :param net_name: path of the network file; defaults to
        ``TrainParamServer().get_net_name()``.
    :return: a line number, or None when no matching frame exists.
    """
    if net_name is None:
        net_name = TrainParamServer().get_net_name()
    last_lineno_candidate = None
    for frame in stack:
        code = frame.f_code
        if code.co_filename != net_name:
            continue
        if code.co_name == '__call__':
            last_lineno_candidate = frame.f_lineno
        if code.co_name == '_predict':
            return frame.f_lineno
    return last_lineno_candidate
def __init__(self, *args, **kwargs):
    """Tab widget showing the Loss and Accuracy graphs of the last training.

    Images are taken from the result directory. If looking it up raises
    KeyError, they fall back to ``result/``.
    """
    super(ReportWidget, self).__init__(*args, **kwargs)
    self.setStyleSheet('''ReportWidget{background: rgb(55,55,55)} ''')

    def image_path(suffix, fallback):
        try:
            return TrainParamServer().get_result_dir() + suffix
        except KeyError:
            return fallback

    self.loss_widget = GraphWidget(
        image_path("/loss.png", "result/loss.png"), parent=self)
    self.addTab(self.loss_widget, 'Loss')

    self.acc_widget = GraphWidget(
        image_path("/accuracy.png", "result/accuracy.png"), parent=self)
    self.addTab(self.acc_widget, 'Accuracy')

    self.resize(200, 200)
def __call__(self, nodes, **kwargs):
    """Compile ``nodes`` into the generated network file.

    The file at ``TrainParamServer().get_net_name()`` receives the network,
    optimizer and trainer templates.

    :return: True when the file was written. False when there are no nodes
        (an error is shown) or when a compile stage produced nothing.
    """
    if not nodes:
        util.disp_error('Please place nodes and connect them'
                        ' before compilation.')
        return False
    init_impl = self.compile_init(nodes)
    if not init_impl:
        return False
    call_impl, pred_impl, lossID = self.compile_call(nodes)
    if not (call_impl and pred_impl):
        return False

    train_server = TrainParamServer()
    classification = 'Class' in train_server['Task']
    # Use a context manager so the file is flushed and closed
    # deterministically instead of whenever it is garbage-collected.
    with open(train_server.get_net_name(), 'w') as net_file:
        net_file.write(TEMPLATES['NetTemplate']()(
            train_server['NetName'], init_impl, call_impl, pred_impl,
            lossID, classification))
        net_file.write(TEMPLATES['OptimizerTemplate']()(train_server))
        net_file.write(TEMPLATES['TrainerTemplate']()(train_server))
    return True
def get_data_pred(self, including_label):
    """Load the data (and optionally labels) to predict on.

    If ``PredInputData`` ends with ``.py``, it is loaded as module
    ``data_getter``. Its ``main()`` must return ``(data, label)`` when
    ``including_label`` is true, and just the data otherwise. Any exception
    raised while loading or running it is re-raised as
    ``util.AbnormalDataCode(error.args)``. Other paths go through
    ``self.get_data_from_file``. The data is min-max scaled when
    ``use_minmax()`` is enabled.

    :return: ``(data, label)``.
    """
    server = TrainParamServer()
    source = server['PredInputData']
    if not source.endswith('.py'):
        data, label = self.get_data_from_file(source, including_label)
    else:
        loader = machinery.SourceFileLoader('data_getter', source)
        try:
            getter = loader.load_module()
            if including_label:
                data, label = getter.main()
            else:
                data, label = getter.main(), None
        except Exception as error:
            raise util.AbnormalDataCode(error.args)

    if server.use_minmax():
        data = self.minmax_scale(data)
    return data, label
def deserialize_pred_label(list_file=None):
    """Return the base names of the images listed in the prediction list.

    :param list_file: path of the list file; defaults to
        ``<work_dir>/pred_label.txt``.
    :return: file names (no directories) of the non-blank lines, in order.
    """
    if list_file is None:
        list_file = os.path.join(TrainParamServer().get_work_dir(),
                                 'pred_label.txt')
    image_files = []
    with open(list_file, 'r') as fr:
        for line in fr:
            line = line.strip()
            if line:
                # basename also handles '\\' separators on Windows, which
                # glob produces for the directory/file boundary there.
                image_files.append(os.path.basename(line))
    return image_files
def deserialize_label_conversion(label_conversion_file=None):
    """Read the label conversion table as ``{int_label_str: class_name}``.

    Each non-blank line holds a class name followed by a space and its
    integer label. The line is split on the *last* space, so class names
    containing spaces are preserved.

    :param label_conversion_file: path of the table; defaults to
        ``<work_dir>/label_conversion.txt``.
    :return: dict mapping the label string to the class name.
    :raises ValueError: for a non-blank line without a space.
    """
    if label_conversion_file is None:
        label_conversion_file = os.path.join(
            TrainParamServer().get_work_dir(), 'label_conversion.txt')
    label_to_class = {}
    with open(label_conversion_file, 'r') as fr:
        for line in fr:
            line = line.strip()
            if line:
                class_str, int_str = line.rsplit(' ', 1)
                label_to_class[int_str] = class_str
    return label_to_class
def open_dialog(self):
    """Ask for a data file (starting in the work directory) and store it.

    On a non-empty choice, the value and label are updated and
    ``parent.state_changed(0)`` is notified.
    """
    start_dir = TrainParamServer().get_work_dir()
    chosen = QtWidgets.QFileDialog.getOpenFileName(
        self, 'Select Data File', start_dir,
        filter='(*.csv *.npz *.py);; Any (*.*)')[0]
    if not chosen:
        return
    self.value = chosen
    self.label.setText(chosen)
    self.parent.state_changed(0)
def __init__(self, label, window, is_save=False, is_dir=False):
    """Base initializer for a path-selection config.

    The parameter key is the class name without its trailing ``Config``.
    If ``TrainParamServer`` already holds a value for that key, it is shown
    in ``label``.
    """
    # e.g. 'PredModelConfig' -> 'PredModel'
    self.param_name = type(self).__name__[:-len('Config')]
    self.label = label
    server = TrainParamServer()
    if self.param_name in server.__dict__:
        label.setText(server[self.param_name])
    self.window = window
    self.direction = ''
    self.filter = ''
    self.is_save = is_save
    self.is_dir = is_dir
def __init__(self, parent=None, painter=None):
    """Create the main window around the node ``painter`` widget.

    Restores window geometry and the last opened graph from QSettings,
    applies the initial configuration of every settings dialog, and reopens
    the last project file if it still exists.
    """
    super(MainWindow, self).__init__(parent)
    self.iconRoot = os.path.join(os.path.dirname(__file__), '../resources')
    self.settings = QtCore.QSettings('ChainerWing', 'ChainerWing')

    self.select_data_button = QtWidgets.QPushButton('')
    self.select_data_button.clicked.connect(self.open_data_config)
    self.select_data_button.setToolTip('Select training data')

    self.setupUi(self)
    self.setWindowIcon(
        QtGui.QIcon(os.path.join(self.iconRoot, 'appIcon.png')))

    # The default must be a QSize: a tuple made resize() raise TypeError,
    # which skipped the rest of the restore.
    try:
        self.resize(self.settings.value("size", QtCore.QSize(900, 700)))
        self.move(self.settings.value("pos", QtCore.QPoint(50, 50)))
    except TypeError:
        # Ignore persisted geometry of an unexpected type.
        pass
    # Read outside the geometry try so it is always bound.
    init_graph = self.settings.value("graph_file", '')

    self.setWindowTitle('ChainerWing')

    self.initActions()
    self.initMenus()

    painter.reportWidget = self.BottomWidget
    painter.set_settings(self.settings)
    painter.setAutoFillBackground(True)

    palette = self.palette()
    palette.setColor(painter.backgroundRole(), QtGui.QColor(70, 70, 70))
    painter.setPalette(palette)
    layout = QtWidgets.QGridLayout()
    layout.addWidget(painter)
    self.DrawArea.setLayout(layout)
    self.drawer = painter

    # to reflect initial configuration
    SettingsDialog(self, settings=self.settings).close()
    TrainDialog(self, settings=self.settings).close()
    ImageDataDialog(self, settings=self.settings).close()
    DataDialog(self, settings=self.settings).close()

    self.update_data_label()
    self.setupNodeLib()

    TrainParamServer()['ProjectName'] = 'New Project'
    # Open the last opened project JSON, if it still exists.
    if init_graph:
        try:
            self.load_graph(init_graph)
        except FileNotFoundError:
            pass