Example #1
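All examples below assume the following imports; the dosed module paths are an assumption based on the dosed project layout and may differ in your checkout. The test arguments (h5_directory, signals, events, records, cache_directory) are pytest fixtures supplied by the surrounding test suite.

import os
import shutil
import time

import torch

from dosed.datasets import BalancedEventDataset
from dosed.models import DOSED3
from dosed.trainers import trainers

# Builds the dataset twice with caching enabled and checks that the second
# construction, which reads the cache written by the first, is faster.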
def test_cache_is_faster(h5_directory, signals, events, records,
                         cache_directory):
    dataset_parameters = {
        "h5_directory": h5_directory,
        "signals": signals,
        "events": events,
        "window": 1,
        "downsampling_rate": 1,
        "records": None,
        "minimum_overlap": 0.5,
        "ratio_positive": 0.5,
    }

    # The first construction starts from a cold cache and has to build it.
    shutil.rmtree(cache_directory, ignore_errors=True)
    t1 = time.time()
    BalancedEventDataset(cache_data=True, **dataset_parameters)
    t1 = time.time() - t1

    # The second construction can read the cache written above.
    t2 = time.time()
    BalancedEventDataset(
        cache_data=True,
        **dataset_parameters,
    )
    t2 = time.time() - t2

    assert t2 < t1
Example #2
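# Builds the same dataset once with n_jobs=-1 (all cores) and once with
# n_jobs=1 and checks that the parallel build is faster.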
def test_parallel_is_faster(h5_directory, signals, events, records,
                            cache_directory):

    dataset_parameters = {
        "h5_directory": h5_directory,
        "signals": signals,
        "events": events,
        "window": 1,
        "fs": 64,
        "records": None,
        "minimum_overlap": 0.5,
        "ratio_positive": 0.5,
        "cache_data": False,
    }

    # Build once with all available workers...
    shutil.rmtree(cache_directory, ignore_errors=True)
    t1 = time.time()
    BalancedEventDataset(n_jobs=-1, **dataset_parameters)
    t1 = time.time() - t1

    # ...and once with a single worker; the serial build should be slower.
    shutil.rmtree(cache_directory, ignore_errors=True)
    t2 = time.time()
    BalancedEventDataset(
        n_jobs=1,
        **dataset_parameters,
    )
    t2 = time.time() - t2

    assert t2 > t1
Example #3
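# Helper used by the sampling tests: repeatedly extracts balanced samples
# for the record behind a given dataset index.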
def get_balanced_data(dataset, idx, number):
    data = []
    index_info = dataset.index_to_record_event[idx]
    for _ in range(number):
        data.append(
            dataset.extract_balanced_data(
                record=index_info["record"],
                max_index=index_info["max_index"],
                events_indexes=index_info["events_indexes"],
                no_events_indexes=index_info["no_events_indexes"],
            )
        )
    return data
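A minimal usage sketch (hypothetical values; dataset is a BalancedEventDataset built as in the other examples):

samples = get_balanced_data(dataset, idx=0, number=10)
assert len(samples) == 10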
Example #4
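# Verifies that a cache directory is created when cache_data=True and not
# created when cache_data=False.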
def test_cache_no_cache(h5_directory, signals, events, records,
                        cache_directory):
    dataset_parameters = {
        "h5_directory": h5_directory,
        "signals": signals,
        "events": events,
        "window": 1,
        "downsampling_rate": 1,
        "records": None,
        "minimum_overlap": 0.5,
        "ratio_positive": 0.5,
        "n_jobs": -1,
    }

    # With cache_data=False no cache directory should be created...
    shutil.rmtree(cache_directory, ignore_errors=True)
    BalancedEventDataset(cache_data=False, **dataset_parameters)
    assert not os.path.isdir(cache_directory)

    # ...with cache_data=True it should appear.
    BalancedEventDataset(
        cache_data=True,
        **dataset_parameters,
    )
    assert os.path.isdir(cache_directory)
Example #5
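# With ratio_positive=1 every sample should be drawn around an event.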
def test_balanced_dataset_ratio_1(h5_directory, signals, events, records):

    dataset = BalancedEventDataset(
        h5_directory=h5_directory,
        signals=signals,
        events=events,
        window=1,
        downsampling_rate=1,
        records=None,
        minimum_overlap=0.5,
        transformations=lambda x: x,
        ratio_positive=1,
    )

    signal, events_data = dataset[0]

    # Two signal channels and 64 samples for the one-second window.
    assert tuple(signal.shape) == (2, 64)
    assert events_data.shape[1] == 3

    # With ratio_positive=1 every sample is drawn around an event, so the
    # dataset length matches the total event count across records.
    number_of_events = sum(
        len(dataset.get_record_events(record)[0]) for record in records)
    assert number_of_events == len(dataset) == 103

    assert len(list(dataset.get_record_batch(records[0], 17))) == 22
Example #6
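# With ratio_positive=0 no sample should contain an event.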
def test_balanced_dataset_ratio_0(h5_directory, signals, events, records):
    dataset = BalancedEventDataset(
        h5_directory=h5_directory,
        signals=signals,
        events=events,
        window=1,
        downsampling_rate=1,
        records=None,
        minimum_overlap=0.5,
        transformations=lambda x: x,
        ratio_positive=0,
    )

    signal, events_data = dataset[0]

    # With ratio_positive=0 samples are drawn away from events, so no event
    # annotations come back, while the dataset length is unchanged.
    assert len(events_data) == 0
    assert len(dataset) == 103
Example #7
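# End-to-end smoke test: builds a dataset from the bundled HDF5 test files,
# trains a DOSED3 detector for two epochs, and runs prediction on the same
# data.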
def test_full_training():
    h5_directory = "./tests/test_files/h5/"

    window = 1  # in seconds

    signals = [{
        'h5_path': '/eeg_0',
        'fs': 64,
        'processing': {
            "type": "clip_and_normalize",
            "args": {
                "min_value": -150,
                "max_value": 150,
            }
        }
    }, {
        'h5_path': '/eeg_1',
        'fs': 64,
        'processing': {
            "type": "clip_and_normalize",
            "args": {
                "min_value": -150,
                "max_value": 150,
            }
        }
    }]

    events = [
        {
            "name": "spindle",
            "h5_path": "spindle",
        },
    ]

    # The loss computation runs on the GPU here; swap in torch.device("cpu")
    # if no CUDA device is available.
    device = torch.device("cuda")

    dataset = BalancedEventDataset(
        h5_directory=h5_directory,
        signals=signals,
        events=events,
        window=window,
        fs=64,
        minimum_overlap=0.5,
        transformations=lambda x: x,
        ratio_positive=0.5,
        n_jobs=-1,
    )

    # Default ("anchor") event sizes in samples: one second and half a second
    # at the dataset sampling rate.
    default_event_sizes = [1 * dataset.fs, 0.5 * dataset.fs]

    net = DOSED3(
        input_shape=dataset.input_shape,
        number_of_classes=dataset.number_of_classes,
        detection_parameters={
            "overlap_non_maximum_suppression": 0.5,
            "classification_threshold": 0.5,
        },
        default_event_sizes=default_event_sizes,
    )

    optimizer_parameters = {
        "lr": 5e-3,
        "weight_decay": 1e-8,
    }
    loss_specs = {
        "type": "worst_negative_mining",
        "parameters": {
            "number_of_classes": dataset.number_of_classes,
            "device": device,
        }
    }

    trainer = trainers["adam"](
        net,
        optimizer_parameters=optimizer_parameters,
        loss_specs=loss_specs,
        epochs=2,
    )

    # The same dataset serves as both training and validation set, which is
    # enough for this smoke test.
    best_net_train, best_metrics_train, best_threshold_train = trainer.train(
        dataset,
        dataset,
        batch_size=12,
    )

    best_net_train.predict_dataset(dataset, best_threshold_train, batch_size=2)
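These tests are written for pytest; assuming the usual layout of the test files, something like "pytest tests/ -k full_training" would run the last example on its own.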