Example 1
def eval_seizure_model(model, model_type):
    example = io.StringIO(example_seizure_file)
    dataset = load_seizure_dataset(example, model_type)
    # The example file yields 3 samples (see Example 2), so one batch covers the whole set.
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=3,
                                         shuffle=False,
                                         num_workers=0)

    return model_eval(model, loader)
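A possible way to exercise this helper, using a hypothetical stand-in model: the 178-feature input width comes from the shape check in Example 2, while the 5-class output size is an assumption not stated in these examples.

import torch.nn as nn

# Hypothetical stand-in classifier: 178 input features -> 5 classes (class count assumed).
stub_model = nn.Sequential(nn.Linear(178, 5))

# model_eval is expected to consume the model and loader and return its evaluation results.
results = eval_seizure_model(stub_model, 'MLP')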
Example 2
def test_load_seizure_dataset_mlp():
    example = io.StringIO(example_seizure_file)
    tensor_dataset = load_seizure_dataset(example, 'MLP')

    expect(
        type(tensor_dataset) == torch.utils.data.dataset.TensorDataset,
        "it should return TensorDataset")

    data_tensor, target_tensor = tensor_dataset.tensors

    expect(data_tensor.size() == (3, 178), "data tensor does not have expected shape")
    expect(target_tensor.size() == (3, ), "target tensor does not have expected shape")
    assert_expectations()
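For reference, a minimal sketch of what load_seizure_dataset could look like for the 'MLP' case, consistent with the shapes asserted above. The CSV layout (a label column named 'y' with values 1..5 plus 178 feature columns) and the label shift are assumptions about the seizure data, not facts taken from these examples.

import pandas as pd
import torch
from torch.utils.data import TensorDataset

def load_seizure_dataset_mlp_sketch(path):
    # Assumed CSV layout: 178 feature columns plus a label column 'y' in {1, ..., 5}.
    df = pd.read_csv(path)
    data = torch.tensor(df.drop(columns=['y']).values, dtype=torch.float32)
    target = torch.tensor((df['y'] - 1).values, dtype=torch.long)  # shift labels to 0..4
    return TensorDataset(data, target)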
Example 3
PATH_TRAIN_FILE = "../data/seizure/seizure_train.csv"
PATH_VALID_FILE = "../data/seizure/seizure_validation.csv"
PATH_TEST_FILE = "../data/seizure/seizure_test.csv"

# Path for saving model
PATH_OUTPUT = "../output/seizure/"
os.makedirs(PATH_OUTPUT, exist_ok=True)

# Some parameters
MODEL_TYPE = 'MLP'  # TODO: Change this to 'MLP', 'CNN', or 'RNN' according to your task
NUM_EPOCHS = 3
BATCH_SIZE = 32
USE_CUDA = False  # Set to True if you want to train on a GPU
NUM_WORKERS = 0  # Number of worker processes used by DataLoader; adjust this to your machine spec.

train_dataset = load_seizure_dataset(PATH_TRAIN_FILE, MODEL_TYPE)
valid_dataset = load_seizure_dataset(PATH_VALID_FILE, MODEL_TYPE)
test_dataset = load_seizure_dataset(PATH_TEST_FILE, MODEL_TYPE)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True,
                                           num_workers=NUM_WORKERS)
valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                           batch_size=BATCH_SIZE,
                                           shuffle=False,
                                           num_workers=NUM_WORKERS)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=False,
                                          num_workers=NUM_WORKERS)
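With the loaders in place, a minimal training-loop sketch for the 'MLP' setting might look like the following. The model architecture, the 5-class output, the optimizer, and the loss are all assumptions standing in for whatever model and training utilities the assignment itself provides.

import torch.nn as nn
import torch.optim as optim

device = torch.device('cuda' if USE_CUDA and torch.cuda.is_available() else 'cpu')

# Hypothetical stand-in model: 178 EEG features -> 5 classes (class count assumed).
model = nn.Sequential(nn.Linear(178, 64), nn.ReLU(), nn.Linear(64, 5)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

for epoch in range(NUM_EPOCHS):
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()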