def split_dataset(dataset_path, output_path, val_size, test_size, nrows, random_seed):
    """ Runs the CLI. """
    # Pipeline: load the dataset without transformation (optionally only `nrows`
    # rows), split it into train/val/test stratified on md.COLUMN_LABEL_CAT,
    # pass each split through remove_extra_labels, and save each split under
    # `output_path` with the names 'train', 'val' and 'test'.
    logging.info('Loading dataset from "%s"...', dataset_path)
    raw = load_dataset(dataset_path=dataset_path, transform_data=False, nrows=nrows)

    # Unpacking into exactly three names enforces a 3-way split result.
    train, val, test = train_val_test_split(
        raw,
        val_size=val_size,
        test_size=test_size,
        stratify_col=md.COLUMN_LABEL_CAT,
        random_state=random_seed,
    )

    # Clean every split before saving any of them.
    cleaned = [remove_extra_labels(part) for part in (train, val, test)]
    for split_name, part in zip(('train', 'val', 'test'), cleaned):
        save_dataset(part, output_path, split_name)

    logging.info('Processing complete.')
def test_loaded_dataset_must_replace_invalid_value_with_nan(val_data):
    # The loaded test dataset is expected to contain exactly as many NaNs as the
    # `val_data` fixture has inf values plus negative values.
    loaded = load_dataset(conf.TEST_DATA_DIR)
    invalid_total = inf_value_count(val_data) + neg_value_count(val_data)
    assert invalid_total == nan_value_count(loaded)
def test_loaded_dataset_must_contain_label_is_attack():
    frame = load_dataset(conf.TEST_DATA_DIR)

    def rows_where(mask):
        # Number of rows selected by a boolean mask.
        return len(frame[mask])

    benign_total = rows_where(frame.label == 'Benign')
    attack_total = len(frame) - benign_total

    # Rows flagged 0 must match the benign count; rows flagged 1 the remainder.
    assert rows_where(frame.label_is_attack == 0) == benign_total
    assert rows_where(frame.label_is_attack == 1) == attack_total
def test_loaded_dataset_must_not_contain_negative_values_except_excluded_cols():
    # Columns whose negative values load_dataset is asked to keep.
    keep_negative = ['init_fwd_win_byts', 'init_bwd_win_byts']
    frame = load_dataset(conf.TEST_DATA_DIR, preserve_neg_value_cols=keep_negative)

    # Some negatives must survive, and only in the preserved columns.
    assert neg_value_count(frame) != 0
    remaining = set(negative_value_columns(frame))
    assert remaining == {'init_bwd_win_byts', 'init_fwd_win_byts'}
def test_loaded_dataset_must_omit_specified_columns():
    # A column listed in `omit_cols` must not appear in the loaded frame.
    dropped = 'dst_port'
    frame = load_dataset(conf.TEST_DATA_DIR, omit_cols=[dropped])
    assert dropped not in frame.columns
def test_loaded_dataset_must_contain_only_specified_columns():
    """Loading with ``use_cols`` must yield exactly the requested columns."""
    df = load_dataset(conf.TEST_DATA_DIR, use_cols=['dst_port'])
    # Assuming load_dataset returns a pandas DataFrame, `df.columns` is an Index.
    # `Index == list` compares element-wise: it returns an array, and it raises
    # ValueError when the lengths differ (e.g. extra columns were loaded).
    # Converting to a list yields a plain bool and a clean assertion failure.
    assert list(df.columns) == ['dst_port']
def test_loaded_dataset_must_contain_label_category():
    frame = load_dataset(conf.TEST_DATA_DIR)
    # The number of entries in value_counts() of label_cat must equal that of label.
    category_counts = frame.label_cat.value_counts()
    label_counts = frame.label.value_counts()
    assert len(category_counts) == len(label_counts)
def test_loaded_dataset_must_not_contain_negative_values():
    # With default options, no negative values may remain after loading.
    frame = load_dataset(conf.TEST_DATA_DIR)
    negatives = neg_value_count(frame)
    assert negatives == 0