示例#1
0
def TestNewData(NewDataCsv, model_folder, result_save_path=''):
    '''
    Run a stored training pipeline on a new feature matrix and score it.

    No feature selection is performed here: the loaded container is only passed
    through the stored normalizer before being handed to the classifier.

    :param NewDataCsv: New radiomics feature matrix csv file path
    :param model_folder: The trained model path, forwarded to LoadTrainInfo
    :param result_save_path: Optional output folder. When non-empty, the
        metric, the predictions, the labels and a per-case csv table are
        written into it.
    :return: The metric returned by EstimateMetirc(predict, label)
    '''
    train_info = LoadTrainInfo(model_folder)

    # Normalization: load the new cases, then apply the stored normalizer.
    loaded_container = DataContainer()
    loaded_container.Load(NewDataCsv)
    new_data_container = train_info['normalizer'].Transform(loaded_container)

    # Model: only the second column of predict_proba is kept as the score.
    classifier = train_info['classifier']
    classifier.SetDataContainer(new_data_container)
    probability = classifier.GetModel().predict_proba(
        new_data_container.GetArray())
    predict = probability[:, 1]

    label = new_data_container.GetLabel()
    case_name = new_data_container.GetCaseName()

    # Header row followed by one [case, prediction, label] row per label.
    test_result_info = [['CaseName', 'Pred', 'Label']]
    test_result_info.extend(
        [case_name[row], predict[row], label[row]]
        for row in range(len(label)))

    metric = EstimateMetirc(predict, label)
    info = dict(metric)
    cv = CrossValidation()

    print(metric)
    print('\t')

    # Nothing to persist when no output folder was given.
    if not result_save_path:
        return metric

    cv.SaveResult(info, result_save_path)
    np.save(os.path.join(result_save_path, 'test_predict.npy'), predict)
    np.save(os.path.join(result_save_path, 'test_label.npy'), label)

    table_path = os.path.join(result_save_path, 'test_info.csv')
    with open(table_path, 'w', newline='') as csvfile:
        csv.writer(csvfile).writerows(test_result_info)

    return metric
示例#2
0
class PrepareConnection(QWidget, Ui_Prepare):
    """Widget for loading, inspecting, cleaning and saving a feature matrix.

    The data being edited is held in ``self.data_container``; the feature table
    and the information text are rebuilt from it by ``UpdateTable``.
    """

    def __init__(self, parent=None):
        """Build the UI and wire the buttons to their handlers."""
        super(PrepareConnection, self).__init__(parent)
        self.setupUi(self)
        self.data_container = DataContainer()

        self.buttonLoad.clicked.connect(self.LoadData)
        self.buttonRemove.clicked.connect(self.RemoveNonValidValue)
        self.checkSeparate.clicked.connect(self.SetSeparateStatus)
        # The spin box is only enabled while the "separate" box is checked.
        self.spinBoxSeparate.setEnabled(False)

        self.buttonSave.clicked.connect(self.CheckAndSave)

    def UpdateTable(self):
        """Refill the feature table and the summary text from the container.

        Returns without touching the widgets when the container's array is
        empty.
        """
        array = self.data_container.GetArray()
        if array.size == 0:
            return

        # Fetch once instead of calling the getters again for every cell.
        case_name = self.data_container.GetCaseName()
        feature_name = self.data_container.GetFeatureName()
        label = self.data_container.GetLabel()

        # Column 0 shows the label; the features follow, shifted by one.
        header_name = deepcopy(feature_name)
        header_name.insert(0, 'Label')

        self.tableFeature.setRowCount(len(case_name))
        self.tableFeature.setColumnCount(len(header_name))
        self.tableFeature.setHorizontalHeaderLabels(header_name)
        self.tableFeature.setVerticalHeaderLabels(list(map(str, case_name)))

        for row_index in range(len(case_name)):
            self.tableFeature.setItem(
                row_index, 0, QTableWidgetItem(str(label[row_index])))
            for col_index in range(1, len(header_name)):
                self.tableFeature.setItem(
                    row_index, col_index,
                    QTableWidgetItem(str(array[row_index, col_index - 1])))

        text = "The number of cases: {:d}\n".format(len(case_name))
        text += "The number of features: {:d}\n".format(len(feature_name))
        if len(np.unique(label)) == 2:
            # With two distinct labels, the larger value counts as positive.
            positive_number = len(np.where(label == np.max(label))[0])
            negative_number = len(label) - positive_number
            text += "The number of positive samples: {:d}\n".format(
                positive_number)
            text += "The number of negative samples: {:d}\n".format(
                negative_number)
        self.textInformation.setText(text)

    def LoadData(self):
        """Ask for a csv file, load it into the container and refresh the view."""
        dlg = QFileDialog()
        file_name, _ = dlg.getOpenFileName(self,
                                           'Open CSV file',
                                           filter="csv files (*.csv)")
        if not file_name:
            # The dialog was cancelled: there is nothing to load.
            return

        try:
            self.data_container.Load(file_name)
        except Exception as reason:
            # Report the failure instead of silently swallowing every error
            # (a bare ``except`` would also trap KeyboardInterrupt/SystemExit).
            print('Error: ' + str(reason))
            QMessageBox.warning(self, "Warning", str(reason), QMessageBox.Ok)

        self.UpdateTable()

    def RemoveNonValidValue(self):
        """Drop invalid cases or invalid features, depending on the radio choice."""
        if self.radioRemoveNonvalidCases.isChecked():
            self.data_container.RemoveUneffectiveCases()
        elif self.radioRemoveNonvalidFeatures.isChecked():
            self.data_container.RemoveUneffectiveFeatures()

        self.UpdateTable()

    def SetSeparateStatus(self):
        """Enable the percentage spin box only while "separate" is checked."""
        self.spinBoxSeparate.setEnabled(self.checkSeparate.isChecked())

    def CheckAndSave(self):
        """Validate the container, then save it whole or split it into a folder.

        Empty data and non-valid numbers only produce a warning; for the
        latter the first offending cell is also selected in the table.
        """
        if self.data_container.IsEmpty():
            QMessageBox.warning(self, "Warning", "There is no data",
                                QMessageBox.Ok)
        elif self.data_container.HasNonValidNumber():
            QMessageBox.warning(self, "Warning", "There are nan items",
                                QMessageBox.Ok)
            non_valid_number_Index = \
                self.data_container.FindNonValidNumberIndex()
            # Temporarily switch the edit trigger while moving the selection,
            # then restore the previous setting.
            old_edit_triggers = self.tableFeature.editTriggers()
            self.tableFeature.setEditTriggers(QAbstractItemView.CurrentChanged)
            # +1 because table column 0 is the label column.
            self.tableFeature.setCurrentCell(non_valid_number_Index[0],
                                             non_valid_number_Index[1] + 1)
            self.tableFeature.setEditTriggers(old_edit_triggers)
        elif self.checkSeparate.isChecked():
            percentage_testing_data = self.spinBoxSeparate.value()
            folder_name = QFileDialog.getExistingDirectory(self, "Save data")
            if folder_name != '':
                data_seperate = DataSeparate.DataSeparate(
                    percentage_testing_data)
                data_seperate.Run(self.data_container, folder_name)
        else:
            file_name, _ = QFileDialog.getSaveFileName(
                self, "Save data", filter="csv files (*.csv)")
            if file_name != '':
                self.data_container.Save(file_name)
示例#3
0
class PrepareConnection(QWidget, Ui_Prepare):
    """Widget for loading, cleaning, balancing and saving a feature matrix.

    ``self.data_container`` holds the data being edited. An optional testing
    reference container can be loaded; when it is non-empty, ``CheckAndSave``
    splits by that reference instead of by the spin-box percentage.
    """

    def __init__(self, parent=None):
        """Build the UI, wire the signals and set the initial widget states."""
        super(PrepareConnection, self).__init__(parent)
        self.setupUi(self)
        self.data_container = DataContainer()

        self.buttonLoad.clicked.connect(self.LoadData)
        self.buttonRemove.clicked.connect(self.RemoveNonValidValue)
        self.loadTestingReference.clicked.connect(
            self.LoadTestingReferenceDataContainer)
        self.clearTestingReference.clicked.connect(
            self.ClearTestingReferenceDataContainer)
        self.__testing_ref_data_container = DataContainer()
        self.checkSeparate.clicked.connect(self.SetSeparateStatus)

        self.spinBoxSeparate.setEnabled(False)
        self.logger = eclog(os.path.split(__file__)[-1]).GetLogger()

        # The testing-reference buttons stay disabled until "separate" is
        # checked (see SetSeparateStatus).
        self.loadTestingReference.setEnabled(False)
        self.clearTestingReference.setEnabled(False)

        self.buttonSave.clicked.connect(self.CheckAndSave)

    def UpdateTable(self):
        """Refill the feature table and the summary text from the container.

        Returns without touching the widgets when the container's array is
        empty.
        """
        array = self.data_container.GetArray()
        if array.size == 0:
            return

        # Fetch once instead of calling the getters again for every cell.
        case_name = self.data_container.GetCaseName()
        feature_name = self.data_container.GetFeatureName()
        label = self.data_container.GetLabel()

        # Column 0 shows the label; the features follow, shifted by one.
        header_name = deepcopy(feature_name)
        header_name.insert(0, 'Label')

        self.tableFeature.setRowCount(len(case_name))
        self.tableFeature.setColumnCount(len(header_name))
        self.tableFeature.setHorizontalHeaderLabels(header_name)
        self.tableFeature.setVerticalHeaderLabels(list(map(str, case_name)))

        for row_index in range(len(case_name)):
            self.tableFeature.setItem(
                row_index, 0, QTableWidgetItem(str(label[row_index])))
            for col_index in range(1, len(header_name)):
                self.tableFeature.setItem(
                    row_index, col_index,
                    QTableWidgetItem(str(array[row_index, col_index - 1])))

        text = "The number of cases: {:d}\n".format(len(case_name))
        text += "The number of features: {:d}\n".format(len(feature_name))
        if len(np.unique(label)) == 2:
            # With two distinct labels, the larger value counts as positive.
            positive_number = len(np.where(label == np.max(label))[0])
            negative_number = len(label) - positive_number
            text += "The number of positive samples: {:d}\n".format(
                positive_number)
            text += "The number of negative samples: {:d}\n".format(
                negative_number)
        self.textInformation.setText(text)

    def LoadData(self):
        """Ask for a csv file, load it into the container and refresh the view.

        OSError and ValueError raised by the load are logged and reported in a
        message box; the table refresh and button enabling still run after a
        failed load.
        """
        dlg = QFileDialog()
        file_name, _ = dlg.getOpenFileName(self,
                                           'Open CSV file',
                                           filter="csv files (*.csv)")
        if not file_name:
            # The dialog was cancelled: there is nothing to load or report.
            return

        try:
            self.data_container.Load(file_name)
            self.logger.info('Open the file ' + file_name + ' Succeed.')
        except OSError as reason:
            # NOTE(review): assumes GetLogger() returns a standard
            # logging.Logger, whose log() needs a level as first argument;
            # error() is used here as in the ValueError branch.
            self.logger.error('Open CSV file Error, The reason is ' +
                              str(reason))
            QMessageBox.about(self, 'Load data Error', str(reason))
            print('Error!' + str(reason))
        except ValueError:
            self.logger.error('Open CSV file ' + file_name +
                              ' Failed. because of value error.')
            QMessageBox.information(self, 'Error',
                                    'The selected data file mismatch.')
        self.UpdateTable()

        self.buttonRemove.setEnabled(True)
        self.buttonSave.setEnabled(True)

    def LoadTestingReferenceDataContainer(self):
        """Ask for a csv file and load it as the testing reference.

        On success the load button and the percentage spin box are disabled
        and the clear button is enabled.
        """
        dlg = QFileDialog()
        file_name, _ = dlg.getOpenFileName(self,
                                           'Open CSV file',
                                           filter="csv files (*.csv)")
        if not file_name:
            # The dialog was cancelled: keep the current reference and states.
            return

        try:
            self.__testing_ref_data_container.Load(file_name)
            self.loadTestingReference.setEnabled(False)
            self.clearTestingReference.setEnabled(True)
            self.spinBoxSeparate.setEnabled(False)
        except OSError as reason:
            # error() instead of a level-less log() call; see LoadData.
            self.logger.error('Load Testing Reference Error: ' + str(reason))
            QMessageBox.about(self, 'Load data Error', str(reason))
            print('Error!' + str(reason))
        except ValueError:
            self.logger.error('Open CSV file ' + file_name +
                              ' Failed. because of value error.')
            QMessageBox.information(self, 'Error',
                                    'The selected data file mismatch.')

    def ClearTestingReferenceDataContainer(self):
        """Replace the testing reference with an empty container."""
        self.__testing_ref_data_container = DataContainer()
        self.loadTestingReference.setEnabled(True)
        self.clearTestingReference.setEnabled(False)
        # NOTE(review): the spin box is left disabled here although the
        # percentage is what CheckAndSave uses once the reference is empty;
        # kept as in the original - confirm whether it should be re-enabled.
        self.spinBoxSeparate.setEnabled(False)

    def RemoveNonValidValue(self):
        """Drop invalid cases or invalid features, depending on the radio choice."""
        if self.radioRemoveNonvalidCases.isChecked():
            self.data_container.RemoveUneffectiveCases()
        elif self.radioRemoveNonvalidFeatures.isChecked():
            self.data_container.RemoveUneffectiveFeatures()

        self.UpdateTable()

    def SetSeparateStatus(self):
        """Enable the split controls only while "separate" is checked.

        The clear button is always disabled by this handler.
        """
        is_separate = self.checkSeparate.isChecked()
        self.spinBoxSeparate.setEnabled(is_separate)
        self.loadTestingReference.setEnabled(is_separate)
        self.clearTestingReference.setEnabled(False)

    def __SelectCellForEditing(self, row, column):
        """Move the table selection to (row, column) and restore the triggers."""
        old_edit_triggers = self.tableFeature.editTriggers()
        self.tableFeature.setEditTriggers(QAbstractItemView.CurrentChanged)
        self.tableFeature.setCurrentCell(row, column)
        self.tableFeature.setEditTriggers(old_edit_triggers)

    def CheckAndSave(self):
        """Validate the container, then balance and save (or split) the data.

        Empty data, a non-binary label and non-valid numbers only produce a
        warning; for the last two the offending cell is also selected.
        Otherwise constant features are removed, the chosen balancing method
        is applied and the result is written to the selected file or folder.
        """
        if self.data_container.IsEmpty():
            QMessageBox.warning(self, "Warning", "There is no data",
                                QMessageBox.Ok)
        elif not self.data_container.IsBinaryLabel():
            QMessageBox.warning(self, "Warning", "There are not 2 Labels",
                                QMessageBox.Ok)
            non_valid_number_Index = \
                self.data_container.FindNonValidLabelIndex()
            # Table column 0 is the label column.
            self.__SelectCellForEditing(non_valid_number_Index, 0)
        elif self.data_container.HasNonValidNumber():
            QMessageBox.warning(self, "Warning", "There are nan items",
                                QMessageBox.Ok)
            non_valid_number_Index = \
                self.data_container.FindNonValidNumberIndex()
            # +1 because table column 0 is the label column.
            self.__SelectCellForEditing(non_valid_number_Index[0],
                                        non_valid_number_Index[1] + 1)
        else:
            remove_features_with_same_value = RemoveSameFeatures()
            self.data_container = remove_features_with_same_value.Run(
                self.data_container)

            # Default balancer unless one of the sampling radios is checked.
            data_balance = DataBalance()
            if self.radioDownSampling.isChecked():
                data_balance = DownSampling()
            elif self.radioUpSampling.isChecked():
                data_balance = UpSampling()
            elif self.radioSmote.isChecked():
                data_balance = SmoteSampling()

            if self.checkSeparate.isChecked():
                folder_name = QFileDialog.getExistingDirectory(
                    self, "Save data")
                if folder_name != '':
                    data_separate = DataSeparate.DataSeparate()
                    try:
                        if self.__testing_ref_data_container.IsEmpty():
                            testing_data_percentage = \
                                self.spinBoxSeparate.value()
                            training_data_container, _ = \
                                data_separate.RunByTestingPercentage(
                                    self.data_container,
                                    testing_data_percentage, folder_name)
                        else:
                            training_data_container, _ = \
                                data_separate.RunByTestingReference(
                                    self.data_container,
                                    self.__testing_ref_data_container,
                                    folder_name)
                            if training_data_container.IsEmpty():
                                QMessageBox.information(
                                    self, 'Error',
                                    'The testing data does not match, please check the testing data '
                                    'really exists in current data')
                                return None
                        # Only the training part is balanced.
                        data_balance.Run(training_data_container,
                                         store_path=folder_name)
                    except Exception as e:
                        # UI boundary: log and show the failure instead of
                        # letting it escape the slot.
                        content = 'PrepareConnection, splitting failed: '
                        self.logger.error('{}{}'.format(content, str(e)))
                        QMessageBox.about(self, content, str(e))

            else:
                file_name, _ = QFileDialog.getSaveFileName(
                    self, "Save data", filter="csv files (*.csv)")
                if file_name != '':
                    data_balance.Run(self.data_container, store_path=file_name)
示例#4
0
if __name__ == '__main__':
    from FAE.DataContainer.DataContainer import DataContainer
    from FAE.FeatureAnalysis.Normalizer import NormalizerZeroCenter
    from FAE.FeatureAnalysis.Classifier import SVM, LR, LDA, LRLasso, GaussianProcess, NaiveBayes, DecisionTree, RandomForest, AE, AdaBoost
    import numpy as np

    train_data_container = DataContainer()
    train_data_container.Load(
        r'C:\MyCode\FAEGitHub\FAE\Example\withoutshape\non_balance_features.csv'
    )

    normalizer = NormalizerZeroCenter()
    train_data_container = normalizer.Run(train_data_container)

    data = train_data_container.GetArray()
    label = np.asarray(train_data_container.GetLabel())

    #     param_list = [
    # {"hidden_layer_sizes": [(30,), (100,)],
    # "solver": ["adam"],
    # "alpha": [0.0001, 0.001],
    # "learning_rate_init": [0.001, 0.01]}
    # ]
    #     from sklearn.model_selection import ParameterGrid
    #     pl = ParameterGrid(param_list)

    cv = CrossValidation5Folder()
    cv.SetClassifier(LR())
    # cv.SetClassifierParameterList(pl)
    train_metric, val_metric, test_metric = cv.Run(
        train_data_container,
示例#5
0
        return "To Remove the unbalance of the training data set, we applied an Tomek link after the " \
               "Synthetic Minority Oversampling TEchnique (SMOTE) to make positive/negative samples balance. "

    def Run(self, data_container, store_path=''):
        """Resample ``data_container`` with ``self._model`` and return the result.

        :param data_container: container providing ``GetData()`` and
            ``GetFeatureName()``.
        :param store_path: optional output location. An existing directory
            receives a file named ``'<self._name>_features.csv'``; any other
            non-empty value is passed to ``Save`` unchanged.
        :return: a new DataContainer built from the resampled data and labels.
            Every case in it is named ``'Generate<index>'``.
        """
        data, label, _, _ = data_container.GetData()

        # presumably self._model is an imbalanced-learn sampler: newer releases
        # only provide fit_resample, older ones only fit_sample, so use
        # whichever is available.
        resample = getattr(self._model, 'fit_resample', None)
        if resample is None:
            resample = self._model.fit_sample
        data_resampled, label_resampled = resample(data, label)

        new_case_name = [
            'Generate' + str(index) for index in range(data_resampled.shape[0])
        ]
        new_data_container = DataContainer(data_resampled, label_resampled,
                                           data_container.GetFeatureName(),
                                           new_case_name)
        if store_path != '':
            if os.path.isdir(store_path):
                new_data_container.Save(
                    os.path.join(store_path,
                                 '{}_features.csv'.format(self._name)))
            else:
                new_data_container.Save(store_path)
        return new_data_container


if __name__ == '__main__':
    # Manual check: print array shape and label sum before and after sampling.
    source_container = DataContainer()
    source_container.Load(r'..\..\Example\numeric_feature.csv')
    print(source_container.GetArray().shape,
          np.sum(source_container.GetLabel()))

    sampler = SmoteTomekSampling()
    balanced_container = sampler.Run(source_container)
    print(balanced_container.GetArray().shape,
          np.sum(balanced_container.GetLabel()))