Example #1
0
    def test_validate_bag(self):
        """LoadMIL.validate_bag accepts the first bag from the 'whole' pipeline."""
        generator = LoadMIL.bag_data_generator(
            pipeline='whole', verbose=False)
        # Each yielded item unpacks into exactly (bag, label); only the bag
        # is validated here.
        first_bag, _ = next(generator)
        self.assertTrue(LoadMIL.validate_bag(first_bag))
Example #2
0
    def test_bag_data_generator(self):
        """The first item yielded by the 'whole' pipeline holds a csr_matrix."""
        # Keep the generator bound to a local for the whole test body so it
        # is not finalized before the assertion runs.
        generator = LoadMIL.bag_data_generator(
            pipeline='whole', verbose=False)
        first_item = next(generator)
        # Element 0 of the yielded item is expected to be a sparse CSR bag.
        self.assertIsInstance(first_item[0], csr_matrix)
    def setUp(self):
        """Create the SVMC RBF MILES predictor and load a reference MIL bag.

        Builds the predictor from its classifier file, generates the
        dataclass / pydantic input fixtures via ``self._generate_data()``,
        then loads one bag through the 'whole' pipeline.
        """
        self.predictor = SVMCRBFMILESPredictor(
            classifier_filename=SVMC_rbf_classifier_filename)
        self._generate_data()

        # get_single_mil_bag returns a 3-tuple; presumably (raw dataframe,
        # bag, bag label) judging by the attribute names -- TODO confirm.
        loaded_bag = LoadMIL.get_single_mil_bag(pipeline='whole')
        self.dfraw_load, self.bag_load, self.bag_label_load = loaded_bag
Example #4
0
    def setUp(self):
        """Create the complement-NB predictor and load a categorical MIL bag.

        The predictor and the reference bag both use the 'categorical'
        pipeline. The dataclass / pydantic input fixtures are produced by
        ``self._generate_data()``.
        """
        pipeline_name = 'categorical'
        self.predictor = CompNBPredictor(
            classifier_filename=CompNB_classifier_filename,
            pipeline_type=pipeline_name)
        self._generate_data()

        # get_single_mil_bag returns a 3-tuple; presumably (raw dataframe,
        # bag, bag label) judging by the attribute names -- TODO confirm.
        loaded_bag = LoadMIL.get_single_mil_bag(pipeline=pipeline_name)
        self.dfraw_load, self.bag_load, self.bag_label_load = loaded_bag
Example #5
0
    def setUp(self):
        """Build web-form style raw inputs and load a reference MIL bag.

        The inputs mimic data gathered from a web form, so they do not carry
        every attribute that exists in the SQL database. The dataclass and
        pydantic variants are each turned into a one-row DataFrame built
        from the object's attribute dict.
        """
        self.input_data, self.input_data_list = generate_input_data()
        (self.input_data_pydantic,
         self.input_data_list_pydantic) = generate_input_data_pydantic()

        self.dfraw_input = pd.DataFrame(data=[self.input_data.__dict__])
        self.dfraw_input_pydantic = pd.DataFrame(
            data=[self.input_data_pydantic.__dict__])

        # get_single_mil_bag returns a 3-tuple; presumably (raw dataframe,
        # bag, bag label) judging by the attribute names -- TODO confirm.
        loaded_bag = LoadMIL.get_single_mil_bag(pipeline='whole')
        self.dfraw_load, self.bag_load, self.bag_label_load = loaded_bag
Example #6
0
from transform_mil import Transform
from dataclass_serving import RawInputData, RawInputDataPydantic

# Global declarations: SQL connection settings read from the [sql_server]
# section of an INI file. The path is relative to the current working dir.
# NOTE(review): ConfigParser.read() silently ignores a file it cannot open,
# so running from the wrong directory surfaces below as KeyError: 'sql_server'.
config = configparser.ConfigParser()
config.read(r'../extract/sql_config.ini')
server_name, driver_name, database_name = (
    config['sql_server'][option_name]
    for option_name in ('DEFAULT_SQL_SERVER_NAME',
                        'DEFAULT_SQL_DRIVER_NAME',
                        'DEFAULT_DATABASE_NAME'))

# Classifier file paths, relative to the current working directory.
CompNB_classifier_filename = r"./compNB_si.clf"
KNN_classifier_filename = r"./knn_si.clf"
MultiNB_classifier_filename = r"./multiNB_si.clf"
SVMCL1SI_classifier_filename = r"./svmc_l1_si.clf"
SVMCRBFSI_classifier_filename = r"./svmc_rbf_si.clf"

# NOTE(review): this rebinds the name ``LoadMIL`` from the class to an
# instance, so the class is no longer reachable under that name in this
# module. Code in this module presumably uses ``LoadMIL`` as the instance;
# confirm before renaming.
LoadMIL = LoadMIL(server_name, driver_name, database_name)

#%% Class definitions


def generate_input_data():
    # Construct raw data input
    # This is intended to test input gathered from a web form. Not all
    # Attributes that are present in a SQL database are present
    input_data = RawInputData(
        # Required numeric attributes
        DEVICEHI=122.0,
        DEVICELO=32.0,
        SIGNALHI=10,
        SIGNALLO=0,
        SLOPE=1.2104,
Example #7
0
    if _PROJECT_DIR not in sys.path:
        sys.path.insert(0, _PROJECT_DIR)
from mil_load import (load_mil_dataset_from_file, LoadMIL, bags_2_si,
                      bags_2_si_generator)

# Global declarations: SQL connection settings and feature-file names read
# from the [sql_server] section of an INI file. The path is relative to the
# current working directory.
# NOTE(review): ConfigParser.read() silently ignores a file it cannot open,
# so running from the wrong directory surfaces below as KeyError: 'sql_server'.
config = configparser.ConfigParser()
config.read(r'../extract/sql_config.ini')
(server_name, driver_name, database_name, numeric_feature_file,
 categorical_feature_file) = (
    config['sql_server'][option_name]
    for option_name in ('DEFAULT_SQL_SERVER_NAME',
                        'DEFAULT_SQL_DRIVER_NAME',
                        'DEFAULT_DATABASE_NAME',
                        'DEFAULT_NUMERIC_FILE_NAME',
                        'DEFAULT_CATEGORICAL_FILE_NAME'))

# Shared loader instance built from the connection settings above.
loadMIL = LoadMIL(server_name, driver_name, database_name)

#%%


def _filter_bags_by_size(X: Union[np.ndarray, csr_matrix, list],
                         min_instances: int, max_instances: int) -> np.ndarray:
    """Filter a set of bags by number of instances within the bag. If the 
    bag contains less than n_instances, then do not include that bag in the 
    returned index
    inputs
    -------
    X: (np.ndarray or iterable) of bags
    outputs
    --------
    index: (np.ndarray) index indicating where the number of instances per bag