コード例 #1
0
    def get_data_pred(self):
        """Write the list of image files to predict on and return its path.

        ``PredInputData`` may be a directory (images directly inside it or one
        sub-directory level below are collected) or a single file. The
        collected paths are written one per line to ``pred_label.txt`` in
        the work directory.

        :return: path of the written ``pred_label.txt``.
        :raises Exception: if the directory contains no image files.
        :raises FileNotFoundError: if ``PredInputData`` does not exist.
        """
        train_server = TrainParamServer()
        input_path = train_server['PredInputData']
        if os.path.isdir(input_path):
            image_files = []
            for ext in util.for_image_extensions():
                image_files += glob.glob(input_path + '/*.{}'.format(ext))
                image_files += glob.glob(input_path + '/*/*.{}'.format(ext))
            if not image_files:
                raise Exception('No image file in {}'.format(input_path))
        elif os.path.isfile(input_path):
            image_files = [input_path]
        else:
            raise FileNotFoundError(input_path + ' is not found.')

        pred_label_file = os.path.join(train_server.get_work_dir(),
                                       'pred_label.txt')
        # Write every collected path: one image file per line.
        with open(pred_label_file, 'w') as fw:
            for image in image_files:
                fw.write(image + '\n')

        return pred_label_file
コード例 #2
0
    def get_data_train(self):
        """Build the train/test image lists and the mean image.

        Writes ``label_conversion.txt``, ``train_label.txt`` and
        ``test_label.txt`` into the work directory, then calls
        ``compute_mean`` on the training images.

        When ``UseSameData`` is set, the test split is carved out of
        ``TrainData``; otherwise the test set is read from ``TestData``.
        """
        train_server = TrainParamServer()
        train_images, train_labels = self.get_all_images(
            train_server['TrainData'])

        if train_server['UseSameData']:
            # NOTE(review): the first ``TestDataRatio`` fraction of the
            # indices goes to the *training* split and the rest to the test
            # split; confirm this matches how the ratio is presented.
            split_idx = int(len(train_images) * train_server['TestDataRatio'])

            indices = numpy.arange(len(train_images))
            if train_server['Shuffle']:
                numpy.random.shuffle(indices)

            train_idx = indices[:split_idx]
            test_idx = indices[split_idx:]

            # Integer-array indexing requires get_all_images to return
            # numpy arrays.
            test_images = train_images[test_idx]
            test_labels = train_labels[test_idx]
            train_images = train_images[train_idx]
            train_labels = train_labels[train_idx]
        else:
            # A separate test set is configured: read it from 'TestData'
            # instead of re-reading the training data.
            test_images, test_labels = self.get_all_images(
                train_server['TestData'])

        work_dir = train_server.get_work_dir()
        all_labels = sorted(set(numpy.hstack((train_labels, test_labels))))
        label_conversion_file = os.path.join(work_dir,
                                             'label_conversion.txt')
        self.make_label_conversion_file(all_labels, label_conversion_file)

        train_label_file = os.path.join(work_dir, 'train_label.txt')
        self.make_image_list(train_images, train_labels, train_label_file)
        test_label_file = os.path.join(work_dir, 'test_label.txt')
        self.make_image_list(test_images, test_labels, test_label_file)

        self.compute_mean(train_images)
コード例 #3
0
    def compute_mean(self, images):
        """Compute the mean augmented training image and save it.

        Each image is read as float32, its axes are reordered with
        ``transpose(2, 0, 1)`` (presumably HWC -> CHW), and it is passed
        through ``augment_data`` with the current resize/crop/augmentation
        settings before being summed. The mean is written to ``mean.npy``
        in the work directory.

        :param images: sequence (list or numpy array) of image file paths.
        :raises ValueError: if ``images`` is empty.
        """
        print('compute mean image')
        # len() rather than truthiness: ``images`` may be a numpy array.
        n_images = len(images)
        if n_images == 0:
            raise ValueError('No training images to compute the mean from.')

        server = TrainParamServer()
        resize_width = server['ResizeWidth']
        resize_height = server['ResizeHeight']

        crop_edit = server['Crop']
        crop_width = server['CropWidth']
        crop_height = server['CropHeight']

        use_random_x_flip = server['UseRandomXFlip']
        use_random_y_flip = server['UseRandomYFlip']
        use_random_rotate = server['UseRandomRotation']
        pca_lighting = server['PCAlighting']

        # NOTE(review): the augmentation settings (random flips/rotation,
        # PCA lighting) are applied here as well, so the mean is taken over
        # augmented images; confirm this is intended.
        sum_image = 0
        for image_path in images:
            image = _read_image_as_array(image_path, numpy.float32)
            image = image.transpose(2, 0, 1).astype(numpy.float32)
            sum_image += augment_data(image, resize_width, resize_height,
                                      use_random_x_flip, use_random_y_flip,
                                      use_random_rotate, pca_lighting,
                                      crop_edit, crop_width, crop_height)

        mean_file = os.path.join(server.get_work_dir(), 'mean.npy')
        numpy.save(mean_file, sum_image / n_images)
コード例 #4
0
 def load_graph(self, override=''):
     """Load a project (graph and training settings) from a JSON file.

     :param override: path of the file to load. When empty, a file dialog
         is shown; cancelling it leaves everything unchanged.
     :raises FileNotFoundError: if the file does not exist (the start-up
         code catches this).
     """
     if override:
         file_name = override
     else:
         init_path = TrainParamServer().get_work_dir()
         file_name = QtWidgets.QFileDialog.getOpenFileName(
             self, 'Open File', init_path,
             filter='Chainer Wing Files (*.json);; Any (*.*)')[0]
     if not file_name:
         return
     logger.debug('Attempting to load graph: {}'.format(file_name))
     # Parse the file before clearing the editor so that a missing or
     # corrupted file does not wipe the graph that is currently open.
     with open(file_name, 'r') as fp:
         try:
             proj_dict = json.load(fp)
         except json.decoder.JSONDecodeError:
             util.disp_error(file_name + ' is corrupted.')
             return
     self.drawer.clear_all_nodes()
     if 'graph' in proj_dict:
         self.drawer.graph.load_from_dict(proj_dict['graph'])
         self.statusBar.showMessage(
             'Graph loaded from {}.'.format(file_name), 2000)
         logger.info('Successfully loaded graph: {}'.format(file_name))
     if 'train' in proj_dict:
         TrainParamServer().load_from_dict(proj_dict['train'])
     self.settings.setValue('graph_file', file_name)
     self.update_data_label()
     self.setupNodeLib()
     # Project name = file name without its directory and a trailing
     # '.json' only (str.replace would also strip '.json' mid-name).
     project_name = file_name.split('/')[-1]
     if project_name.endswith('.json'):
         project_name = project_name[:-len('.json')]
     TrainParamServer()['ProjectName'] = project_name
コード例 #5
0
    def __init__(self):
        """Load the generated network script and create the progress bar.

        The script at ``TrainParamServer().get_net_name()`` is loaded as the
        module ``net_run`` and kept in ``self.module``.
        """
        server = TrainParamServer()
        loader = machinery.SourceFileLoader('net_run', server.get_net_name())
        self.module = loader.load_module()

        # The progress bar has to be created only after the module file
        # has been loaded.
        epoch_count = server['Epoch']
        self.pbar = CWProgressBar(epoch_count)
        # Handle of a chainerui server process; None until one is started.
        self.chainerui_server = None
コード例 #6
0
ファイル: data_config.py プロジェクト: EroData/ChainerWing
 def __init__(self, settings, parent):
     """Combo box choosing the preprocessing applied to the input data.

     The initial index comes from the project's ``PreProcessor_idx`` when
     TrainParamServer has it, else from the ``PreProcessor`` setting. The
     selected text is published as the ``PreProcessor`` parameter.
     """
     self.parent = parent
     self.settings = settings
     super(PreProcessorEdit, self).__init__()
     self.addItems(('Do Nothing', 'MinMax Scale'))
     server = TrainParamServer()
     if 'PreProcessor_idx' not in server.__dict__:
         index = settings.value('PreProcessor', type=int)
     else:
         index = server['PreProcessor_idx']
     self.setCurrentIndex(index)
     server['PreProcessor'] = self.currentText()
コード例 #7
0
 def __init__(self, settings, parent):
     """Combo box choosing how input images are cropped.

     The initial index comes from the project's ``Crop_idx`` when
     TrainParamServer has it, else from the ``Crop`` setting. The selected
     text is published as the ``Crop`` parameter.
     """
     self.parent = parent
     self.settings = settings
     super(CropEdit, self).__init__()
     self.addItems(('Do Nothing', 'Center Crop', 'Random Crop'))
     server = TrainParamServer()
     if 'Crop_idx' not in server.__dict__:
         index = settings.value('Crop', type=int)
     else:
         index = server['Crop_idx']
     self.setCurrentIndex(index)
     server['Crop'] = self.currentText()
コード例 #8
0
ファイル: data_config.py プロジェクト: EroData/ChainerWing
 def __init__(self, settings, parent, key):
     """Check box bound to the boolean training parameter ``key``.

     The project's value wins if TrainParamServer already has ``key``;
     otherwise the stored setting is used and published. The final check
     state is written back to TrainParamServer.
     """
     self.parent = parent
     self.settings = settings
     super(DataCheckBox, self).__init__()
     self.key = key
     stored = settings.value(key, type=bool)
     server = TrainParamServer()
     if key not in server.__dict__:
         server[key] = stored
         checked = stored
     else:
         checked = server[key]
     self.setChecked(checked)
     server[key] = self.isChecked()
コード例 #9
0
ファイル: data_config.py プロジェクト: EroData/ChainerWing
    def __init__(self, settings, parent, key, data_type=float):
        """Line edit bound to the training parameter ``key``.

        The project's value wins if TrainParamServer already has ``key``.
        Otherwise the stored setting (read as ``data_type``) is used, with
        falsy values (e.g. 0) replaced by 100, and it is published to
        TrainParamServer.
        """
        super(DataLineEdit, self).__init__()

        self.parent = parent
        self.settings = settings
        self.data_type = data_type
        self.key = key
        stored = settings.value(key, type=data_type) or 100
        server = TrainParamServer()
        if key not in server.__dict__:
            server[key] = stored
            shown = stored
        else:
            shown = server[key]
        self.setText(str(shown))
コード例 #10
0
 def __init__(self, settings, parent, key):
     """'Browse' button selecting a directory for the parameter ``key``.

     The initial value is the project's value if TrainParamServer has
     ``key``, else the stored setting (or './'), which is then published.
     """
     self.parent = parent
     self.settings = settings
     super(DataDirEdit, self).__init__('Browse')
     fallback = settings.value(key, type=str) or './'
     server = TrainParamServer()
     known = key in server.__dict__
     self.value = server[key] if known else fallback
     if not known:
         server[key] = fallback
     self.key = key
     self.label = DataFileLabel(settings, parent, key)
     self.label.setText(self.value)
     self.clicked.connect(self.open_dialog)
コード例 #11
0
 def update_report(self):
     """Rebuild the 'Loss' and 'Accuracy' tabs from the plot images.

     The first two tabs are removed, then new GraphWidgets are created for
     the images in TrainParamServer's result directory, falling back to
     ``result/`` when that lookup raises KeyError.
     """
     def plot_path(suffix, fallback):
         try:
             return TrainParamServer().get_result_dir() + suffix
         except KeyError:
             return fallback

     for _ in range(2):
         self.removeTab(0)
     self.loss_widget = GraphWidget(
         plot_path("/loss.png", "result/loss.png"), parent=self)
     self.addTab(self.loss_widget, 'Loss')
     self.acc_widget = GraphWidget(
         plot_path("/accuracy.png", "result/accuracy.png"), parent=self)
     self.addTab(self.acc_widget, 'Accuracy')
コード例 #12
0
    def exe_prediction(self):
        """Run prediction with the current settings and display the result.

        The result is saved as CSV to ``PredOutputData`` when that is set.
        At most ``max_disp_rows`` rows are shown, with labels appended when
        the runner returns them. Known failures are reported through
        ``util.disp_error`` instead of propagating.
        """
        if TrainParamServer()['GPU'] and not util.check_cuda_available():
            return

        self.pred_progress.setText('Processing...')
        try:
            if 'Image' in TrainParamServer()['Task']:
                runner = ImagePredictionRunner()
            else:
                runner = PredictionRunner()
            result, label = runner.run(self.classification.isChecked(),
                                       self.including_label.isChecked())
            if 'PredOutputData' in TrainParamServer().__dict__:
                numpy.savetxt(TrainParamServer()['PredOutputData'],
                              result,
                              delimiter=",")
            max_rows = self.max_disp_rows.value()
            result = result[:max_rows, :]
            if label is not None:
                label = label[:max_rows, :]
                result = numpy.hstack((result, label))
            self.result_table.setModel(ResultTableModel(result))
            self.pred_progress.setText('Prediction Finished!')
        except KeyError as ke:
            # ke.args[0] is the missing key itself (a string).
            missing = ke.args[0] if ke.args else ''
            if missing == 'PredInputData':
                util.disp_error('Input Data for prediction is not set.')
            elif missing == 'PredModel':
                util.disp_error('Model for prediction is not set.')
            else:
                util.disp_error(str(missing))
        except util.AbnormalDataCode as ac:
            # The input may be a directory (image prediction), so test for
            # existence rather than for a regular file.
            if not os.path.exists(TrainParamServer()['PredInputData']):
                util.disp_error('{} is not found'.format(
                    TrainParamServer()['PredInputData']))
                return
            if not os.path.isfile(TrainParamServer()['PredModel']):
                util.disp_error('{} is not found'.format(
                    TrainParamServer()['PredModel']))
                return
            # ac.args[0] holds the args tuple of the wrapped exception.
            cause_args = ac.args[0] if ac.args else ()
            detail = str(cause_args[0]) if cause_args else 'Unknown error'
            util.disp_error(detail + ' @' +
                            TrainParamServer()['PredInputData'])
        except ValueError:
            util.disp_error('Illegal data was found @' +
                            TrainParamServer()['PredInputData'])
        except type_check.InvalidType as error:
            last_node = util.get_executed_last_node()
            util.disp_error(str(error.args) + ' @node: ' + last_node)
        except FileNotFoundError as error:
            # A FileNotFoundError raised with only a message has
            # filename/strerror set to None; show the message itself then.
            if error.filename is not None:
                util.disp_error('{}: {}'.format(error.filename,
                                                error.strerror))
            else:
                util.disp_error(str(error))
コード例 #13
0
ファイル: data_config.py プロジェクト: EroData/ChainerWing
 def commit(self):
     """Convert the text with ``data_type`` and store it.

     The converted value goes to ``settings`` and to TrainParamServer under
     ``self.key``. A ValueError (e.g. unparsable text) stops the commit
     silently.
     """
     try:
         converted = self.data_type(self.text())
         self.settings.setValue(self.key, converted)
         TrainParamServer()[self.key] = converted
     except ValueError:
         pass
コード例 #14
0
 def __init__(self, label, window):
     """Path selector for the prediction input data.

     The file-dialog filter offers image files for image tasks and
     csv/npz/py files otherwise.
     """
     super(PredInputDataConfig, self).__init__(label, window)
     self.direction = 'Input Data File is not selected.'
     image_task = 'Image' in TrainParamServer()['Task']
     self.filter = ('(*.jpg *.png);; Any (*.*)' if image_task
                    else '(*.csv *.npz *.py);; Any (*.*)')
コード例 #15
0
ファイル: data_config.py プロジェクト: EroData/ChainerWing
class DataFileEdit(QtWidgets.QPushButton):
    """'Browse' button that selects a data file for the parameter ``key``.

    The chosen path is shown in an attached DataFileLabel and written to
    settings and TrainParamServer by ``commit``.
    """

    def __init__(self, settings, parent, key):
        self.parent = parent
        self.settings = settings
        super(DataFileEdit, self).__init__('Browse')
        fallback = settings.value(key, type=str) or './'
        server = TrainParamServer()
        known = key in server.__dict__
        self.value = server[key] if known else fallback
        if not known:
            server[key] = fallback
        self.key = key
        self.label = DataFileLabel(settings, parent, key)
        self.label.setText(self.value)
        self.clicked.connect(self.open_dialog)

    def commit(self):
        """Persist the selected path to settings and TrainParamServer."""
        key, value = self.key, self.value
        self.settings.setValue(key, value)
        TrainParamServer()[key] = value

    def open_dialog(self):
        """Ask for a data file; if one is chosen, adopt and display it."""
        start_dir = TrainParamServer().get_work_dir()
        chosen = QtWidgets.QFileDialog.getOpenFileName(
            self, 'Select Data File', start_dir,
            filter='(*.csv *.npz *.py);; Any (*.*)')[0]
        if not chosen:
            return
        self.value = chosen
        self.label.setText(self.value)
        self.parent.state_changed(0)

    def python_selected(self):
        """Return True when the selected file is a Python script."""
        return self.value.endswith('.py')
コード例 #16
0
ファイル: node_lib.py プロジェクト: EroData/ChainerWing
 def update_node_list(self, text=''):
     """
     Filter the node classes by the LineEdit text and send the sorted
     result to the registered NodeList widget.

     Image nodes are hidden unless the current task is an image task.
     :param text: string that is used for filtering the node list.
                  If '', display all Nodes.
     :return: None
     """
     # NOTE(review): the first character of the lower-cased text is dropped
     # before matching; confirm that callers prefix the query.
     query = text.lower()[1:]
     image_task = 'Image' in TrainParamServer()['Task']
     node_names = sorted(
         name for name, node_cls in NODECLASSES.items()
         if node_cls.matchHint(query)
         and (image_task or not node_cls.is_image_node))
     model = QStandardItemModel()
     for name in node_names:
         item = QStandardItem()
         item.setText(name)
         item.setToolTip(NODECLASSES[name].doc())
         model.appendRow(item)
     self.listView.setModel(model)
コード例 #17
0
ファイル: util.py プロジェクト: EroData/ChainerWing
def get_executed_last_node():
    """Return the name of the network node whose line executed last.

    Intended to be called while handling an exception. The frames of the
    current traceback (and their callers) are examined. Only frames from
    the generated network file ``TrainParamServer().get_net_name()`` are
    considered. The line of the outermost ``_predict`` frame is used when
    present, otherwise the line of the innermost ``__call__`` frame. The
    first whitespace-separated token of that source line, with ``self.``
    removed, is returned.

    :return: the node name, or ``'unknown'`` if it cannot be determined.
    """
    net_name = TrainParamServer().get_net_name()

    def get_last_lineno(stack):
        # ``stack`` is ordered outermost -> innermost (see below).
        last_lineno_candidate = None
        for frame in stack:
            if frame.f_code.co_filename != net_name:
                continue
            if frame.f_code.co_name == '__call__':
                last_lineno_candidate = frame.f_lineno
            if frame.f_code.co_name == '_predict':
                return frame.f_lineno
        return last_lineno_candidate

    tb = sys.exc_info()[2]
    if tb is None:  # not called while an exception is being handled
        return 'unknown'
    while tb.tb_next:
        tb = tb.tb_next
    stack = []
    f = tb.tb_frame
    while f:
        stack.append(f)
        f = f.f_back
    stack.reverse()

    lineno = get_last_lineno(stack)
    if lineno is None:
        return 'unknown'
    with open(net_name, 'r') as net_file:
        for i, line in enumerate(net_file, start=1):
            if i == lineno:
                last_node = line.strip().split(' ')[0]
                return last_node.replace('self.', '')
    return 'unknown'
コード例 #18
0
    def __init__(self, path, mean, dtype=numpy.float32):
        """Image dataset configured with the current augmentation settings.

        :param path: image list file passed to
            ``chainer.datasets.ImageDataset`` with the work directory as
            its ``root``.
        :param mean: mean image; stored as float32 (``astype('f')``).
        :param dtype: dtype stored for later use by the dataset.
        """
        server = TrainParamServer()
        self.base = chainer.datasets.ImageDataset(path,
                                                  server.get_work_dir())
        self.mean = mean.astype('f')
        self.dtype = dtype

        # Resize settings.
        self.resize_width, self.resize_height = (
            server['ResizeWidth'], server['ResizeHeight'])

        # Crop settings.
        self.crop_edit, self.crop_width, self.crop_height = (
            server['Crop'], server['CropWidth'], server['CropHeight'])

        # Augmentation switches.
        self.use_random_x_flip = server['UseRandomXFlip']
        self.use_random_y_flip = server['UseRandomYFlip']
        self.use_random_rotate = server['UseRandomRotation']
        self.pca_lighting = server['PCAlighting']
コード例 #19
0
 def open_dialog(self):
     """Ask for a directory; if one is chosen, adopt and display it."""
     start_dir = TrainParamServer().get_work_dir()
     chosen_dir = QtWidgets.QFileDialog.getExistingDirectory(
         self, 'Select Directory', start_dir)
     if not chosen_dir:
         return
     self.value = chosen_dir
     self.label.setText(self.value)
     self.parent.state_changed(0)
コード例 #20
0
    def run(self):
        """Prepare the datasets and run ``training_main`` of the net module.

        If chainerui is available, the result directory is registered as a
        chainerui project, a chainerui server is started once, and a
        browser is opened on it. For image tasks, the datasets are built
        from the list files and the mean image produced by
        ``ImageDataManager().get_data_train()``.
        """
        train_server = TrainParamServer()
        result_dir = train_server['WorkDir'] + '/result'
        os.makedirs(result_dir, exist_ok=True)
        if _chainerui_available:
            # Argument lists (no shell): paths and project names containing
            # spaces or shell metacharacters reach chainerui verbatim, and
            # the Popen handle refers to the server itself, not a shell.
            try:
                subprocess.call(['chainerui', 'project', 'create',
                                 '-d', result_dir,
                                 '-n', str(train_server['ProjectName'])])
                if self.chainerui_server is None:
                    self.chainerui_server = subprocess.Popen(
                        ['chainerui', 'server'])
            except OSError:
                # The chainerui command could not be started; the
                # dashboard is optional, so train without it.
                pass
            else:
                time.sleep(0.5)  # give the server a moment to start
                webbrowser.open('http://localhost:5000/')

        if 'Image' in train_server['Task']:
            ImageDataManager().get_data_train()
            work_dir = train_server.get_work_dir()
            train_label_file = os.path.join(work_dir, 'train_label.txt')
            test_label_file = os.path.join(work_dir, 'test_label.txt')
            mean = numpy.load(os.path.join(work_dir, 'mean.npy'))
            train_data = PreprocessedDataset(train_label_file, mean)
            test_data = PreprocessedDataset(test_label_file, mean)
        else:
            train_data, test_data = DataManager().get_data_train()
        self.module.training_main(train_data, test_data, self.pbar,
                                  cw_postprocess)
        util.disp_message('Training is finished. Model file is saved to ' +
                          train_server.get_model_name() + '.npz',
                          title='Training is finished')
コード例 #21
0
    def __init__(self, *args, settings=None):
        """Set up the prediction window and restore its check boxes.

        Check-box states are restored from TrainParamServer when the
        corresponding entries exist. "Select by directory" is only enabled
        for image tasks, where it defaults to checked.
        """
        self.settings = settings
        super(PredictionWindow, self).__init__(*args)
        self.setupUi(self)

        self.input_sel_button.clicked.connect(self.set_input)
        self.input_config = PredInputDataConfig(self.input_data_name, self)
        self.output_sel_button.clicked.connect(self.set_output)
        self.output_config = PredOutputDataConfig(self.output_name, self)
        self.model_sel_button.clicked.connect(self.set_model)
        self.model_config = PredModelConfig(self.model_name, self)

        self.exe_button.clicked.connect(self.exe_prediction)
        self.including_label.stateChanged.connect(self.set_including_label)
        self.select_by_dir.stateChanged.connect(self.set_select_by_dir)

        if 'IncludingLabel' in TrainParamServer().__dict__:
            self.including_label.setChecked(
                TrainParamServer()['IncludingLabel'])
        if 'PredClass' in TrainParamServer().__dict__:
            self.classification.setChecked(TrainParamServer()['PredClass'])
        if 'Image' in TrainParamServer()['Task']:
            self.select_by_dir.setEnabled(True)
            # Look the saved choice up among the stored parameters, like
            # the check boxes above (it is not part of the Task string).
            if 'SelectByDir' in TrainParamServer().__dict__:
                self.select_by_dir.setChecked(
                    TrainParamServer()['SelectByDir'])
            else:
                self.select_by_dir.setChecked(True)
        else:
            self.select_by_dir.setEnabled(False)
            self.select_by_dir.setChecked(False)
コード例 #22
0
ファイル: util.py プロジェクト: EroData/ChainerWing
 def get_last_lineno(stack):
     """Return the relevant line number in the generated network file.

     Frames in ``stack`` are examined in the given order, and only those
     whose file is ``TrainParamServer().get_net_name()`` are considered.
     The line of the first ``_predict`` frame is returned immediately;
     otherwise the line of the last ``__call__`` frame seen is returned.

     :return: a line number, or None when no matching frame exists.
     """
     last_lineno_candidate = None
     for frame in stack:
         if frame.f_code.co_filename != TrainParamServer().get_net_name():
             continue
         if frame.f_code.co_name == '__call__':
             last_lineno_candidate = frame.f_lineno
         if frame.f_code.co_name == '_predict':
             return frame.f_lineno
     return last_lineno_candidate
コード例 #23
0
    def __init__(self, *args, **kwargs):
        """Tab widget showing the training loss and accuracy plots.

        Plot images are taken from TrainParamServer's result directory,
        falling back to ``result/`` when that lookup raises KeyError.
        """
        super(ReportWidget, self).__init__(*args, **kwargs)
        self.setStyleSheet('''ReportWidget{background: rgb(55,55,55)}
        ''')

        def plot_path(suffix, fallback):
            try:
                return TrainParamServer().get_result_dir() + suffix
            except KeyError:
                return fallback

        self.loss_widget = GraphWidget(
            plot_path("/loss.png", "result/loss.png"), parent=self)
        self.addTab(self.loss_widget, 'Loss')
        self.acc_widget = GraphWidget(
            plot_path("/accuracy.png", "result/accuracy.png"), parent=self)
        self.addTab(self.acc_widget, 'Accuracy')
        self.resize(200, 200)
コード例 #24
0
ファイル: compiler.py プロジェクト: EroData/ChainerWing
 def __call__(self, nodes, **kwargs):
     """Compile ``nodes`` into the network Python file.

     The network class, optimizer and trainer code generated from
     ``TEMPLATES`` are written to ``TrainParamServer().get_net_name()``.

     :param nodes: the nodes placed in the editor.
     :return: True on success; False when there are no nodes or compiling
         the init/call parts fails.
     """
     if not nodes:
         util.disp_error('Please place nodes and connect them'
                         ' before compilation.')
         return False
     init_impl = self.compile_init(nodes)
     if not init_impl:
         return False
     call_impl, pred_impl, lossID = self.compile_call(nodes)
     if not (call_impl and pred_impl):
         return False
     train_server = TrainParamServer()
     classification = 'Class' in train_server['Task']
     # Close (and thereby flush) the file before returning so the
     # generated code is complete when it is loaded later.
     with open(train_server.get_net_name(), 'w') as net_file:
         net_file.write(TEMPLATES['NetTemplate']()(
             train_server['NetName'], init_impl, call_impl, pred_impl,
             lossID, classification))
         net_file.write(TEMPLATES['OptimizerTemplate']()(train_server))
         net_file.write(TEMPLATES['TrainerTemplate']()(train_server))
     return True
コード例 #25
0
    def get_data_pred(self, including_label):
        """Load the data to run prediction on.

        ``PredInputData`` is either a ``.py`` script whose ``main()``
        returns the data (or ``(data, label)`` when ``including_label``),
        or a data file read by ``get_data_from_file``. The data is passed
        through ``minmax_scale`` when ``use_minmax()`` is true.

        :param including_label: whether the input also provides labels.
        :return: ``(data, label)``; for a script without labels, ``label``
            is None.
        :raises util.AbnormalDataCode: if loading or running the script
            fails.
        """
        train_server = TrainParamServer()
        input_path = train_server['PredInputData']
        if input_path.endswith('.py'):
            loader = machinery.SourceFileLoader('data_getter', input_path)
            try:
                module = loader.load_module()
                if including_label:
                    data, label = module.main()
                else:
                    data, label = module.main(), None
            except Exception as error:
                # Keep the original exception as __cause__ for debugging.
                raise util.AbnormalDataCode(error.args) from error
        else:
            data, label = self.get_data_from_file(input_path,
                                                  including_label)

        if train_server.use_minmax():
            data = self.minmax_scale(data)
        return data, label
コード例 #26
0
ファイル: util.py プロジェクト: EroData/ChainerWing
def deserialize_pred_label():
    """Return the file names listed in the work directory's pred_label.txt.

    Blank lines are skipped, and only the last path component of each
    entry is kept.

    :return: list of file names without their directory part.
    """
    image_files = []
    list_file = os.path.join(TrainParamServer().get_work_dir(),
                             'pred_label.txt')
    with open(list_file, 'r') as fr:
        for line in fr:
            line = line.strip()
            if line:
                # os.path.basename also understands '\\' separators on
                # Windows, where listed paths may contain them.
                image_files.append(os.path.basename(line))
    return image_files
コード例 #27
0
ファイル: util.py プロジェクト: EroData/ChainerWing
def deserialize_label_conversion():
    """Read label_conversion.txt from the work directory.

    Each non-blank line holds a class name and its integer id separated by
    a space. The id is taken as the last field, so class names containing
    spaces are kept intact.

    :return: dict mapping the id string to the class name.
    """
    label_to_class = {}
    label_conversion_file = os.path.join(TrainParamServer().get_work_dir(),
                                         'label_conversion.txt')
    with open(label_conversion_file, 'r') as fr:
        for line in fr:
            line = line.strip()
            if line:
                class_str, int_str = line.rsplit(' ', 1)
                label_to_class[int_str] = class_str
    return label_to_class
コード例 #28
0
ファイル: data_config.py プロジェクト: EroData/ChainerWing
 def open_dialog(self):
     """Ask for a data file; if one is chosen, adopt and display it."""
     start_dir = TrainParamServer().get_work_dir()
     chosen = QtWidgets.QFileDialog.getOpenFileName(
         self, 'Select Data File', start_dir,
         filter='(*.csv *.npz *.py);; Any (*.*)')[0]
     if not chosen:
         return
     self.value = chosen
     self.label.setText(self.value)
     self.parent.state_changed(0)
コード例 #29
0
 def __init__(self, label, window, is_save=False, is_dir=False):
     """Bind ``label`` to the TrainParamServer entry named after the class.

     The parameter name is the class name without its trailing 'Config'
     (e.g. ``PredModelConfig`` -> ``PredModel``). If the entry exists, its
     value is shown in ``label``.
     """
     suffix_len = len('Config')
     self.param_name = self.__class__.__name__[:-suffix_len]
     self.label = label
     server = TrainParamServer()
     if self.param_name in server.__dict__:
         self.label.setText(server[self.param_name])
     self.window = window
     self.direction = ''
     self.filter = ''
     self.is_save, self.is_dir = is_save, is_dir
コード例 #30
0
    def __init__(self, parent=None, painter=None):
        """Build the main window around the node editor ``painter``.

        Window geometry and the last opened graph file are restored from
        QSettings, the dialogs are opened and closed once to apply their
        initial configuration, and the last graph is reopened if its file
        still exists.

        :param painter: node editor widget placed in the drawing area. It
            is required despite the ``None`` default (attributes are set on
            it unconditionally).
        """
        super(MainWindow, self).__init__(parent)

        self.iconRoot = os.path.join(os.path.dirname(__file__), '../resources')
        self.settings = QtCore.QSettings('ChainerWing', 'ChainerWing')

        self.select_data_button = QtWidgets.QPushButton('')
        self.select_data_button.clicked.connect(self.open_data_config)
        self.select_data_button.setToolTip('Select training data')

        self.setupUi(self)

        self.setWindowIcon(
            QtGui.QIcon(os.path.join(self.iconRoot, 'appIcon.png')))

        # Read the last graph first so that it is defined even if restoring
        # the geometry below fails.
        init_graph = self.settings.value("graph_file", '')
        try:
            # QSize default: QWidget.resize does not accept a plain tuple.
            self.resize(self.settings.value("size", QtCore.QSize(900, 700)))
            self.move(self.settings.value("pos", QtCore.QPoint(50, 50)))
        except TypeError:
            # Stored geometry of an unexpected type: keep Qt's defaults.
            pass
        self.setWindowTitle('ChainerWing')

        self.initActions()
        self.initMenus()

        painter.reportWidget = self.BottomWidget
        painter.set_settings(self.settings)

        painter.setAutoFillBackground(True)
        p = self.palette()
        p.setColor(painter.backgroundRole(), QtGui.QColor(70, 70, 70))
        painter.setPalette(p)
        l = QtWidgets.QGridLayout()
        l.addWidget(painter)
        self.DrawArea.setLayout(l)
        self.drawer = painter

        # to reflect initial configuration
        SettingsDialog(self, settings=self.settings).close()
        TrainDialog(self, settings=self.settings).close()
        ImageDataDialog(self, settings=self.settings).close()
        DataDialog(self, settings=self.settings).close()
        self.update_data_label()
        self.setupNodeLib()

        # Reopen the last opened project file, if it still exists.
        TrainParamServer()['ProjectName'] = 'New Project'
        try:
            if init_graph:
                self.load_graph(init_graph)
        except FileNotFoundError:
            pass