def test_cache_is_faster(h5_directory, signals, events, records, cache_directory):
    """Constructing the dataset from a warm cache must beat the cold run.

    The first construction (empty cache) pays the full extraction cost;
    the second one reads back the cached data and should be faster.
    """
    params = {
        "h5_directory": h5_directory,
        "signals": signals,
        "events": events,
        "window": 1,
        "downsampling_rate": 1,
        "records": None,
        "minimum_overlap": 0.5,
        "ratio_positive": 0.5,
    }

    # Start from an empty cache so the first build is genuinely cold.
    shutil.rmtree(cache_directory, ignore_errors=True)

    start = time.time()
    BalancedEventDataset(cache_data=True, **params)
    cold_duration = time.time() - start

    start = time.time()
    BalancedEventDataset(cache_data=True, **params)
    warm_duration = time.time() - start

    assert warm_duration < cold_duration
def test_parallel_is_faster(h5_directory, signals, events, records, cache_directory):
    """Building the dataset with all cores (n_jobs=-1) must beat a single job.

    Caching is disabled and the cache directory is wiped before each run so
    both constructions do the same amount of work.
    """
    params = {
        "h5_directory": h5_directory,
        "signals": signals,
        "events": events,
        "window": 1,
        "fs": 64,
        "records": None,
        "minimum_overlap": 0.5,
        "ratio_positive": 0.5,
        "cache_data": False,
    }

    shutil.rmtree(cache_directory, ignore_errors=True)
    start = time.time()
    BalancedEventDataset(n_jobs=-1, **params)
    parallel_duration = time.time() - start

    shutil.rmtree(cache_directory, ignore_errors=True)
    start = time.time()
    BalancedEventDataset(n_jobs=1, **params)
    serial_duration = time.time() - start

    assert serial_duration > parallel_duration
def get_balanced_data(dataset, idx, number):
    """Draw `number` balanced samples from the record at index `idx`.

    Parameters
    ----------
    dataset : object exposing `index_to_record_event` (mapping of index to a
        dict with keys "record", "max_index", "events_indexes",
        "no_events_indexes") and an `extract_balanced_data` method.
    idx : key into `dataset.index_to_record_event` selecting the record.
    number : how many samples to draw (each draw calls
        `extract_balanced_data` once).

    Returns
    -------
    list of whatever `extract_balanced_data` returns, length `number`.
    """
    # Hoist the invariant record-info lookup out of the loop.
    info = dataset.index_to_record_event[idx]
    return [
        dataset.extract_balanced_data(
            record=info["record"],
            max_index=info["max_index"],
            events_indexes=info["events_indexes"],
            no_events_indexes=info["no_events_indexes"],
        )
        for _ in range(number)
    ]
def test_cache_no_cache(h5_directory, signals, events, records, cache_directory):
    """The cache directory must be created iff cache_data=True."""
    params = {
        "h5_directory": h5_directory,
        "signals": signals,
        "events": events,
        "window": 1,
        "downsampling_rate": 1,
        "records": None,
        "minimum_overlap": 0.5,
        "ratio_positive": 0.5,
        "n_jobs": -1,
    }

    shutil.rmtree(cache_directory, ignore_errors=True)

    # No caching requested: nothing should appear on disk.
    BalancedEventDataset(cache_data=False, **params)
    assert not os.path.isdir(cache_directory)

    # Caching requested: the directory must now exist.
    BalancedEventDataset(cache_data=True, **params)
    assert os.path.isdir(cache_directory)
def test_balanced_dataset_ratio_1(h5_directory, signals, events, records):
    """With ratio_positive=1 every sample contains an event.

    Checks sample shapes, that the dataset length equals the total event
    count (103 for the test fixtures), and the batch iterator length.
    """
    dataset = BalancedEventDataset(
        h5_directory=h5_directory,
        signals=signals,
        events=events,
        window=1,
        downsampling_rate=1,
        records=None,
        minimum_overlap=0.5,
        transformations=lambda x: x,
        ratio_positive=1,
    )

    sample_signal, sample_events = dataset[0]
    # 2 channels x 64 samples (1 s window at 64 Hz fixtures).
    assert tuple(sample_signal.shape) == (2, 64)
    assert sample_events.shape[1] == 3

    total_events = sum(
        len(dataset.get_record_events(record)[0]) for record in records
    )
    assert total_events == len(dataset) == 103

    assert len(list(dataset.get_record_batch(records[0], 17))) == 22
def test_balanced_dataset_ratio_0(h5_directory, signals, events, records):
    """With ratio_positive=0 samples carry no events, but length is unchanged."""
    dataset = BalancedEventDataset(
        h5_directory=h5_directory,
        signals=signals,
        events=events,
        window=1,
        downsampling_rate=1,
        records=None,
        minimum_overlap=0.5,
        transformations=lambda x: x,
        ratio_positive=0,
    )

    _, sample_events = dataset[0]
    assert len(sample_events) == 0
    assert len(dataset) == 103
def test_full_training():
    """End-to-end smoke test: build dataset, train DOSED3 for 2 epochs, predict.

    Requires the HDF5 fixtures under ./tests/test_files/h5/ and a CUDA device.
    """
    h5_directory = "./tests/test_files/h5/"
    window = 1  # in seconds

    def _eeg_signal(h5_path):
        # Each EEG channel is clipped/normalized to [-150, 150] at 64 Hz.
        return {
            "h5_path": h5_path,
            "fs": 64,
            "processing": {
                "type": "clip_and_normalize",
                "args": {
                    "min_value": -150,
                    "max_value": 150,
                },
            },
        }

    signals = [_eeg_signal("/eeg_0"), _eeg_signal("/eeg_1")]
    events = [
        {
            "name": "spindle",
            "h5_path": "spindle",
        },
    ]
    device = torch.device("cuda")

    dataset = BalancedEventDataset(
        h5_directory=h5_directory,
        signals=signals,
        events=events,
        window=window,
        fs=64,
        minimum_overlap=0.5,
        transformations=lambda x: x,
        ratio_positive=0.5,
        n_jobs=-1,
    )

    # Default anchor sizes: one full window and half a window, in samples.
    default_event_sizes = [1 * dataset.fs, 0.5 * dataset.fs]

    net = DOSED3(
        input_shape=dataset.input_shape,
        number_of_classes=dataset.number_of_classes,
        detection_parameters={
            "overlap_non_maximum_suppression": 0.5,
            "classification_threshold": 0.5,
        },
        default_event_sizes=default_event_sizes,
    )

    trainer = trainers["adam"](
        net,
        optimizer_parameters={
            "lr": 5e-3,
            "weight_decay": 1e-8,
        },
        loss_specs={
            "type": "worst_negative_mining",
            "parameters": {
                "number_of_classes": dataset.number_of_classes,
                "device": device,
            },
        },
        epochs=2,
    )

    # Train and validate on the same (tiny) dataset, then run inference.
    best_net, best_metrics, best_threshold = trainer.train(
        dataset,
        dataset,
        batch_size=12,
    )
    best_net.predict_dataset(dataset, best_threshold, batch_size=2)