import random

import numpy as np


def getBatches(*args):
    """Worker loop: endlessly pushes (atlas, volume) minibatches into a queue."""
    q, p_number, config = args
    # Seed each worker process differently so they do not emit identical batches.
    random.seed(p_number)
    atlas, itk_atlas = loadAtlas(config)
    data = loadOASISData()
    train = data[:int(len(data) * config['split'])]
    volume_shape = config['resolution']
    # Everything after the validation slice of the training split is used for training.
    data_train = train[int(len(train) * config['validation']):]
    while True:
        minibatch = np.empty(shape=(config['batchsize'], *volume_shape, 2))
        for i in range(config['batchsize']):
            idx_volume = random.randrange(len(data_train))
            vol = readNormalizedVolumeByPath(data_train[idx_volume]['img'], itk_atlas)
            # Channel 0: the fixed atlas; channel 1: the moving subject volume.
            minibatch[i, :, :, :, 0] = atlas.reshape(volume_shape).astype("float32")
            minibatch[i, :, :, :, 1] = vol.reshape(volume_shape).astype("float32")
        q.put(minibatch)
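getBatches relies on two helpers that are not shown here: loadAtlas and readNormalizedVolumeByPath. A minimal sketch of the latter, assuming SimpleITK, resampling onto the atlas grid, and min-max intensity normalization (the name matches the call above, but the body is an assumption, not the original implementation):

import SimpleITK as sitk
import numpy as np


def readNormalizedVolumeByPath(path, reference_image):
    # Load the volume and resample it onto the atlas grid with an identity
    # transform so every subject shares the atlas resolution and orientation.
    # (Assumed behavior; the original helper is not shown.)
    image = sitk.ReadImage(path)
    resampled = sitk.Resample(image, reference_image, sitk.Transform(),
                              sitk.sitkLinear, 0.0)
    vol = sitk.GetArrayFromImage(resampled).astype("float32")
    # Min-max normalize to [0, 1], guarding against constant volumes.
    lo, hi = vol.min(), vol.max()
    return (vol - lo) / (hi - lo) if hi > lo else vol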
def getTestData(config):
    """Builds the full test set as one array of (atlas, volume) pairs."""
    atlas, itk_atlas = loadAtlas(config)
    data = loadOASISData()
    data_test = data[int(len(data) * config['split']):]
    volume_shape = config['resolution']
    n = len(data_test)
    test = np.empty(shape=(n, *volume_shape, 2))
    for i in range(n):
        vol = readNormalizedVolumeByPath(data_test[i]['img'], itk_atlas)
        test[i, :, :, :, 0] = atlas.reshape(volume_shape).astype("float32")
        test[i, :, :, :, 1] = vol.reshape(volume_shape).astype("float32")
    return test
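loadAtlas is also assumed; the code above expects it to return both a NumPy array (for building batches) and the ITK image (as the resampling reference). A sketch under those assumptions:

import SimpleITK as sitk


def loadAtlas(config):
    # Read the atlas named in the config (e.g. 'atlas.nii.gz') and return the
    # normalized array plus the SimpleITK image used as a resampling reference.
    # (Assumed behavior; the original helper is not shown.)
    itk_atlas = sitk.ReadImage(config['atlas'])
    atlas = sitk.GetArrayFromImage(itk_atlas).astype("float32")
    lo, hi = atlas.min(), atlas.max()
    if hi > lo:
        atlas = (atlas - lo) / (hi - lo)
    return atlas, itk_atlas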
def getValidationData(config):
    """Builds the validation set from the first slice of the training split."""
    atlas, itk_atlas = loadAtlas(config)
    data = loadOASISData()
    train = data[:int(len(data) * config['split'])]
    volume_shape = config['resolution']
    data_val = train[:int(len(train) * config['validation'])]
    n = len(data_val)
    val = np.empty(shape=(n, *volume_shape, 2))
    for i in range(n):
        vol = readNormalizedVolumeByPath(data_val[i]['img'], itk_atlas)
        val[i, :, :, :, 0] = atlas.reshape(volume_shape).astype("float32")
        val[i, :, :, :, 1] = vol.reshape(volume_shape).astype("float32")
    return val
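The training script below obtains its queue from DataGenerator.stream(2, 1, train_config). The stream helper itself is not shown; a minimal sketch, assuming its first two arguments are the worker count and the queue capacity (inferred from the call site) and that each worker runs getBatches with its own seed:

import multiprocessing as mp


def stream(n_workers, queue_size, config):
    # Spawn daemon workers that each run getBatches and push minibatches into
    # a bounded queue. Passing the worker index as p_number matches the
    # (q, p_number, config) unpacking and per-process seeding in getBatches.
    q = mp.Queue(maxsize=queue_size)
    processes = []
    for p_number in range(n_workers):
        p = mp.Process(target=getBatches, args=(q, p_number, config), daemon=True)
        p.start()
        processes.append(p)
    return q, processes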
import multiprocessing as mp
# mp.set_start_method("spawn")

train_config = {
    'batchsize': 1,
    'split': 0.9,
    'validation': 0.1,
    'half_res': True,
    'epochs': 200,
    'atlas': 'atlas.nii.gz',
    'model_output': 'model.pkl',
    'exponentialSteps': 7,
}
# Note: the generator functions read config['resolution'], which is not set
# here; presumably DataGenerator derives it elsewhere (e.g. from 'half_res').

# Number of volumes available for training after the test and validation splits.
training_elements = int(len(loadOASISData()) * train_config['split']
                        * (1 - train_config['validation']))

# Start the background data workers (presumably two processes feeding a queue
# of capacity one) and build the fixed validation set up front.
data_queue, processes = DataGenerator.stream(2, 1, train_config)
validation_data = DataGenerator.getValidationData(train_config)
validation_data_y = DataGenerator.inferYFromBatch(validation_data, train_config)


def train_generator():
    # Wrap the multiprocessing queue as a Python generator of (x, y) pairs.
    while True:
        minibatch = data_queue.get()
        yield minibatch, DataGenerator.inferYFromBatch(minibatch, train_config)
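inferYFromBatch is likewise not shown. For VoxelMorph-style atlas registration, a common choice is to use the fixed atlas channel as the similarity target and an all-zero displacement field to regularize the predicted flow; the sketch below assumes exactly that, which may differ from the original:

import numpy as np


def inferYFromBatch(batch, config):
    # Assumed VoxelMorph-style targets: channel 0 of the batch (the atlas) is
    # the reconstruction target, and a zero flow field feeds the smoothness
    # loss. (Assumption; the original DataGenerator is not shown.)
    fixed = batch[..., 0:1]
    zero_flow = np.zeros((*batch.shape[:-1], 3), dtype="float32")
    return [fixed, zero_flow]

With these pieces in place, a Keras model would typically consume the generator via model.fit(train_generator(), steps_per_epoch=training_elements, epochs=train_config['epochs'], validation_data=(validation_data, validation_data_y)), though the source does not show the training call itself.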