def _get_weights_for_this_epoch(Config, epoch_nr):
    """Compute (and log) the loss weight factor for epoch ``epoch_nr``.

    * ``Config.LOSS_WEIGHT is None``   -> ``None`` (no weighting)
    * ``Config.LOSS_WEIGHT_LEN == -1`` -> constant ``float(Config.LOSS_WEIGHT)``
    * otherwise the factor decreases linearly from ``LOSS_WEIGHT`` (epoch 0)
      towards 1 over ``LOSS_WEIGHT_LEN`` epochs and stays at 1 afterwards.

    The chosen value is written via ``exp_utils.print_and_save``.

    :param Config: config object providing LOSS_WEIGHT and LOSS_WEIGHT_LEN
    :param epoch_nr: zero-based epoch index
    :return: the weight factor (float) or None
    """
    start_weight = Config.LOSS_WEIGHT
    if start_weight is None:
        weight_factor = None
    else:
        # Only read LOSS_WEIGHT_LEN when a weight is actually configured.
        ramp_len = Config.LOSS_WEIGHT_LEN
        if ramp_len == -1:
            weight_factor = float(start_weight)
        elif epoch_nr < ramp_len:
            decrement_per_epoch = (start_weight - 1) / float(ramp_len)
            weight_factor = float(start_weight) - decrement_per_epoch * epoch_nr
        else:
            weight_factor = 1.
    exp_utils.print_and_save(Config, "Current weight_factor: {}".format(weight_factor))
    return weight_factor
def print_current_lr(self):
    """Log the current learning rate of each parameter group of ``self.optimizer``.

    One line per parameter group is written via ``exp_utils.print_and_save``.
    """
    group_lrs = (group['lr'] for group in self.optimizer.param_groups)
    for lr in group_lrs:
        exp_utils.print_and_save(self.Config, "current learning rate: {}".format(lr))
def train_model(Config, model, data_loader):
    """Train ``model`` for ``Config.NUM_EPOCHS`` epochs, logging and plotting metrics.

    Each epoch runs a training pass and a validation pass. If ``Config.ONLY_VAL``
    is set, only the validation pass runs and the function returns after the
    first epoch. Per-batch metrics are accumulated via ``_update_metrics`` and
    normalised per epoch. The LR scheduler is optionally stepped on a
    validation metric, and weights are optionally saved. Metrics are pickled
    and plotted into ``Config.EXP_PATH``.

    :param Config: experiment configuration object (only read here)
    :param model: model wrapper; ``train``/``test`` return ``(probs, metr_batch)``
    :param data_loader: object providing ``get_batch_generator``
    :return: the trained model
    """
    vis_logger = None
    if Config.USE_VISLOGGER:
        try:
            from trixi.logger.visdom import PytorchVisdomLogger
        except ImportError:
            # Previously the ImportError was swallowed and the constructor call
            # then failed with a NameError; disable visdom plotting instead.
            exp_utils.print_and_save(
                Config, "WARNING: trixi could not be imported; visdom logging disabled.")
        else:
            vis_logger = PytorchVisdomLogger(port=8080, auto_start=True)

    exp_utils.print_and_save(Config, socket.gethostname())
    epoch_times = []
    nr_of_updates = 0

    # One list per "<metric>_<split>"; each list gets one entry per epoch.
    metrics = {}
    for phase in ["train", "test", "validate"]:
        for metric in Config.METRIC_TYPES:
            metrics[metric + "_" + phase] = [0]

    # Generators are created once and reused across all epochs.
    batch_gen_train = data_loader.get_batch_generator(
        batch_size=Config.BATCH_SIZE, type="train", subjects=Config.TRAIN_SUBJECTS)
    batch_gen_val = data_loader.get_batch_generator(
        batch_size=Config.BATCH_SIZE, type="validate", subjects=Config.VALIDATE_SUBJECTS)

    for epoch_nr in range(Config.NUM_EPOCHS):
        start_time = time.time()
        timings = defaultdict(lambda: 0)
        batch_nr = defaultdict(lambda: 0)
        weight_factor = _get_weights_for_this_epoch(Config, epoch_nr)

        phases = ["validate"] if Config.ONLY_VAL else ["train", "validate"]
        for phase in phases:
            print_loss = []
            subjects = getattr(Config, phase.upper() + "_SUBJECTS")
            if Config.DIM == "2D":
                nr_of_samples = len(subjects) * Config.INPUT_DIM[0]
            else:
                nr_of_samples = len(subjects)
            # *Config.EPOCH_MULTIPLIER needed to have roughly same number of updates/batches as with 2D U-Net
            nr_batches = int(int(nr_of_samples / Config.BATCH_SIZE) * Config.EPOCH_MULTIPLIER)
            batch_gen = batch_gen_train if phase == "train" else batch_gen_val

            print("Start looping batches...")
            start_time_batch_part = time.time()
            for _ in range(nr_batches):
                batch = next(batch_gen)

                start_time_data_preparation = time.time()
                batch_nr[phase] += 1
                x = batch["data"]  # (bs, nr_of_channels, x, y)
                y = batch["seg"]  # (bs, nr_of_classes, x, y)
                timings["data_preparation_time"] += time.time() - start_time_data_preparation

                start_time_network = time.time()
                if phase == "train":
                    nr_of_updates += 1
                    probs, metr_batch = model.train(x, y, weight_factor=weight_factor)
                else:
                    probs, metr_batch = model.test(x, y, weight_factor=weight_factor)
                timings["network_time"] += time.time() - start_time_network

                start_time_metrics = time.time()
                metrics = _update_metrics(Config, metrics, metr_batch, phase)
                timings["metrics_time"] += time.time() - start_time_metrics

                print_loss.append(metr_batch["loss"])
                if batch_nr[phase] % Config.PRINT_FREQ == 0:
                    time_batch_part = time.time() - start_time_batch_part
                    start_time_batch_part = time.time()
                    exp_utils.print_and_save(
                        Config,
                        "{} Ep {}, Sp {}, loss {}, t print {}s, t batch {}s".format(
                            phase, epoch_nr, batch_nr[phase] * Config.BATCH_SIZE,
                            round(np.array(print_loss).mean(), 6),
                            round(time_batch_part, 3),
                            round(time_batch_part / Config.PRINT_FREQ, 3)))
                    print_loss = []

                if vis_logger is not None:
                    plot_utils.plot_result_trixi(vis_logger, x, y, probs, metr_batch["loss"],
                                                 metr_batch["f1_macro"], epoch_nr)

        ###################################
        # Post Training tasks (each epoch)
        ###################################

        if Config.ONLY_VAL:
            metrics = metric_utils.normalize_last_element(metrics, batch_nr["validate"], type="validate")
            print("f1 macro validate: {}".format(round(metrics["f1_macro_validate"][0], 4)))
            return model

        # Average loss per batch over entire epoch
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["train"], type="train")
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["validate"], type="validate")

        print("  Epoch {}, Average Epoch loss = {}".format(epoch_nr, metrics["loss_train"][-1]))
        print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

        # Adapt LR
        if Config.LR_SCHEDULE:
            if Config.LR_SCHEDULE_MODE == "min":
                model.scheduler.step(metrics["loss_validate"][-1])
            else:
                model.scheduler.step(metrics["f1_macro_validate"][-1])
            model.print_current_lr()

        # Save Weights
        start_time_saving = time.time()
        if Config.SAVE_WEIGHTS:
            model.save_model(metrics, epoch_nr, mode=Config.BEST_EPOCH_SELECTION)
        timings["saving_time"] += time.time() - start_time_saving

        # Create Plots
        start_time_plotting = time.time()
        # Context manager so the pickle file handle is closed every epoch.
        with open(join(Config.EXP_PATH, "metrics.pkl"), "wb") as metrics_file:
            pickle.dump(metrics, metrics_file)
        plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME,
                                   keys=["loss", "f1_macro"],
                                   types=["train", "validate"],
                                   selected_ax=["loss", "f1"],
                                   fig_name="metrics_all.png")
        plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME, without_first_epochs=True,
                                   keys=["loss", "f1_macro"],
                                   types=["train", "validate"],
                                   selected_ax=["loss", "f1"],
                                   fig_name="metrics.png")
        if "angle_err" in Config.METRIC_TYPES:
            plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME, without_first_epochs=True,
                                       keys=["loss", "angle_err"],
                                       types=["train", "validate"],
                                       selected_ax=["loss", "f1"],
                                       fig_name="metrics_angle.png")
        timings["plotting_time"] += time.time() - start_time_plotting

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)

        exp_utils.print_and_save(Config, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
        exp_utils.print_and_save(Config, "  Epoch {}, time UNet: {}s".format(epoch_nr, timings["network_time"]))
        exp_utils.print_and_save(Config, "  Epoch {}, time metrics: {}s".format(epoch_nr, timings["metrics_time"]))
        exp_utils.print_and_save(Config, "  Epoch {}, time saving files: {}s".format(epoch_nr, timings["saving_time"]))
        exp_utils.print_and_save(Config, str(datetime.datetime.now()))

        # Adding next Epoch
        if epoch_nr < Config.NUM_EPOCHS - 1:
            metrics = metric_utils.add_empty_element(metrics)

    # Guard against NUM_EPOCHS == 0, which previously raised ZeroDivisionError.
    if epoch_times:
        with open(join(Config.EXP_PATH, "Hyperparameters.txt"), "a") as f:
            f.write("\n\nAverage Epoch time: {}s".format(sum(epoch_times) / float(len(epoch_times))))

    # The other exit path (ONLY_VAL) and the sibling versions return the model too.
    return model
def train_model(Config, model, data_loader):
    """Train ``model`` (variant where ``train``/``test`` return ``(loss, probs, f1)``).

    For every epoch, the train, test and validate splits are each run through
    the model with a freshly created batch generator. Loss is accumulated per
    split, and F1 as well if ``Config.CALC_F1`` is set, then normalised per
    epoch. The LR scheduler is optionally stepped on a validation metric, and
    weights are optionally saved. Metrics are pickled and plotted into
    ``Config.EXP_PATH``.

    :param Config: experiment configuration object (only read here)
    :param model: model wrapper; ``train``/``test`` return ``(loss, probs, f1)``
    :param data_loader: object providing ``get_batch_generator``
    :return: the trained model
    """
    vis_logger = None
    if Config.USE_VISLOGGER:
        try:
            from trixi.logger.visdom import PytorchVisdomLogger
        except ImportError:
            # Previously the ImportError was swallowed and the constructor call
            # then failed with a NameError; disable visdom plotting instead.
            exp_utils.print_and_save(
                Config, "WARNING: trixi could not be imported; visdom logging disabled.")
        else:
            vis_logger = PytorchVisdomLogger(port=8080, auto_start=True)

    exp_utils.print_and_save(Config, socket.gethostname())
    epoch_times = []
    nr_of_updates = 0

    metrics = {}
    for phase in ["train", "test", "validate"]:
        metrics["loss_" + phase] = [0]
        metrics["f1_macro_" + phase] = [0]

    for epoch_nr in range(Config.NUM_EPOCHS):
        start_time = time.time()
        timings = defaultdict(lambda: 0)
        batch_nr = {"train": 0, "test": 0, "validate": 0}
        # Shared schedule helper; unlike the former inline copy it also handles
        # Config.LOSS_WEIGHT = None instead of crashing on float(None).
        weight_factor = _get_weights_for_this_epoch(Config, epoch_nr)

        for phase in ["train", "test", "validate"]:
            print_loss = []
            subjects = getattr(Config, phase.upper() + "_SUBJECTS")
            batch_gen = data_loader.get_batch_generator(
                batch_size=Config.BATCH_SIZE, type=phase, subjects=subjects)

            if Config.DIM == "2D":
                nr_of_samples = len(subjects) * Config.INPUT_DIM[0]
            else:
                nr_of_samples = len(subjects)
            # *Config.EPOCH_MULTIPLIER needed to have roughly same number of updates/batches as with 2D U-Net
            nr_batches = int(int(nr_of_samples / Config.BATCH_SIZE) * Config.EPOCH_MULTIPLIER)

            print("Start looping batches...")
            start_time_batch_part = time.time()
            for _ in range(nr_batches):
                batch = next(batch_gen)

                start_time_data_preparation = time.time()
                batch_nr[phase] += 1
                x = batch["data"]  # (bs, nr_of_channels, x, y)
                y = batch["seg"]  # (bs, nr_of_classes, x, y)
                timings["data_preparation_time"] += time.time() - start_time_data_preparation

                start_time_network = time.time()
                if phase == "train":
                    nr_of_updates += 1
                    loss, probs, f1 = model.train(x, y, weight_factor=weight_factor)
                else:
                    loss, probs, f1 = model.test(x, y, weight_factor=weight_factor)
                timings["network_time"] += time.time() - start_time_network

                start_time_metrics = time.time()
                if Config.CALC_F1:
                    if Config.EXPERIMENT_TYPE == "peak_regression":
                        # Here f1 is a dict of torch tensors (presumably one per
                        # bundle); move them to the CPU and average them.
                        f1_mean = np.array([s.to('cpu') for s in f1.values()]).mean()
                    else:
                        f1_mean = np.mean(f1)
                    metrics = metric_utils.calculate_metrics(
                        metrics, None, None, loss, f1=f1_mean,
                        type=phase, threshold=Config.THRESHOLD)
                else:
                    metrics = metric_utils.calculate_metrics_onlyLoss(metrics, loss, type=phase)
                timings["metrics_time"] += time.time() - start_time_metrics

                print_loss.append(loss)
                if batch_nr[phase] % Config.PRINT_FREQ == 0:
                    time_batch_part = time.time() - start_time_batch_part
                    start_time_batch_part = time.time()
                    exp_utils.print_and_save(
                        Config,
                        "{} Ep {}, Sp {}, loss {}, t print {}s, "
                        "t batch {}s".format(
                            phase, epoch_nr, batch_nr[phase] * Config.BATCH_SIZE,
                            round(np.array(print_loss).mean(), 6),
                            round(time_batch_part, 3),
                            round(time_batch_part / Config.PRINT_FREQ, 3)))
                    print_loss = []

                if vis_logger is not None:
                    plot_utils.plot_result_trixi(vis_logger, x, y, probs, loss, f1, epoch_nr)

        ###################################
        # Post Training tasks (each epoch)
        ###################################

        # Average loss per batch over entire epoch
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["train"], type="train")
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["validate"], type="validate")
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["test"], type="test")

        print("  Epoch {}, Average Epoch loss = {}".format(epoch_nr, metrics["loss_train"][-1]))
        print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

        # Adapt LR
        if Config.LR_SCHEDULE:
            if Config.LR_SCHEDULE_MODE == "min":
                model.scheduler.step(metrics["loss_validate"][-1])
            else:
                model.scheduler.step(metrics["f1_macro_validate"][-1])
            model.print_current_lr()

        # Save Weights
        start_time_saving = time.time()
        if Config.SAVE_WEIGHTS:
            model.save_model(metrics, epoch_nr)
        timings["saving_time"] += time.time() - start_time_saving

        # Create Plots
        start_time_plotting = time.time()
        # Context manager so the pickle file handle is closed every epoch.
        with open(join(Config.EXP_PATH, "metrics.pkl"), "wb") as metrics_file:
            pickle.dump(metrics, metrics_file)
        plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME)
        plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME, without_first_epochs=True)
        timings["plotting_time"] += time.time() - start_time_plotting

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)

        exp_utils.print_and_save(Config, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
        exp_utils.print_and_save(Config, "  Epoch {}, time UNet: {}s".format(epoch_nr, timings["network_time"]))
        exp_utils.print_and_save(Config, "  Epoch {}, time metrics: {}s".format(epoch_nr, timings["metrics_time"]))
        exp_utils.print_and_save(Config, "  Epoch {}, time saving files: {}s".format(epoch_nr, timings["saving_time"]))
        exp_utils.print_and_save(Config, str(datetime.datetime.now()))

        # Adding next Epoch
        if epoch_nr < Config.NUM_EPOCHS - 1:
            metrics = metric_utils.add_empty_element(metrics)

    ####################################
    # After all epochs
    ###################################
    # Guard against NUM_EPOCHS == 0, which previously raised ZeroDivisionError.
    if epoch_times:
        with open(join(Config.EXP_PATH, "Hyperparameters.txt"), "a") as f:  # a for append
            f.write("\n\n")
            f.write("Average Epoch time: {}s".format(sum(epoch_times) / float(len(epoch_times))))

    return model
def train_model(Config, model, data_loader):
    """Train ``model`` for ``Config.NUM_EPOCHS`` epochs, logging and plotting metrics.

    Each epoch runs a training pass and a validation pass. If ``Config.ONLY_VAL``
    is set, only the validation pass runs and the function returns after the
    first epoch. The per-batch metric dict returned by ``model.train`` or
    ``model.test`` is accumulated into ``metrics``, which holds one entry per
    epoch for each "<metric>_<split>" key. The LR scheduler is optionally
    stepped on a validation metric, and weights are optionally saved. Metrics
    are pickled and plotted into ``Config.EXP_PATH``.

    :param Config: experiment configuration object (only read here)
    :param model: model wrapper; ``train``/``test`` return ``(probs, metr_batch)``
        where ``metr_batch`` holds ``"loss"`` (and ``"f1_macro"`` when
        ``Config.CALC_F1`` is set)
    :param data_loader: object providing ``get_batch_generator``
    :return: the trained model
    """
    vis_logger = None
    if Config.USE_VISLOGGER:
        try:
            from trixi.logger.visdom import PytorchVisdomLogger
        except ImportError:
            # Previously the ImportError was swallowed and the constructor call
            # then failed with a NameError; disable visdom plotting instead.
            exp_utils.print_and_save(
                Config, "WARNING: trixi could not be imported; visdom logging disabled.")
        else:
            vis_logger = PytorchVisdomLogger(port=8080, auto_start=True)

    exp_utils.print_and_save(Config, socket.gethostname())
    epoch_times = []
    nr_of_updates = 0

    metrics = {}
    for phase in ["train", "test", "validate"]:
        for metric in Config.METRIC_TYPES:
            metrics[metric + "_" + phase] = [0]

    # Generators are created once and reused across all epochs.
    batch_gen_train = data_loader.get_batch_generator(
        batch_size=Config.BATCH_SIZE, type="train", subjects=Config.TRAIN_SUBJECTS)
    batch_gen_val = data_loader.get_batch_generator(
        batch_size=Config.BATCH_SIZE, type="validate", subjects=Config.VALIDATE_SUBJECTS)

    for epoch_nr in range(Config.NUM_EPOCHS):
        start_time = time.time()
        timings = defaultdict(lambda: 0)
        batch_nr = {"train": 0, "test": 0, "validate": 0}
        # Shared helper: same schedule and same "Current weight_factor" log line
        # as the formerly duplicated inline code.
        weight_factor = _get_weights_for_this_epoch(Config, epoch_nr)

        phases = ["validate"] if Config.ONLY_VAL else ["train", "validate"]
        for phase in phases:
            print_loss = []
            subjects = getattr(Config, phase.upper() + "_SUBJECTS")
            if Config.DIM == "2D":
                nr_of_samples = len(subjects) * Config.INPUT_DIM[0]
            else:
                nr_of_samples = len(subjects)
            # *Config.EPOCH_MULTIPLIER needed to have roughly same number of updates/batches as with 2D U-Net
            nr_batches = int(int(nr_of_samples / Config.BATCH_SIZE) * Config.EPOCH_MULTIPLIER)
            batch_gen = batch_gen_train if phase == "train" else batch_gen_val

            print("Start looping batches...")
            start_time_batch_part = time.time()
            for _ in range(nr_batches):
                batch = next(batch_gen)

                start_time_data_preparation = time.time()
                batch_nr[phase] += 1
                x = batch["data"]  # (bs, nr_of_channels, x, y)
                y = batch["seg"]  # (bs, nr_of_classes, x, y)
                timings["data_preparation_time"] += time.time() - start_time_data_preparation

                start_time_network = time.time()
                if phase == "train":
                    nr_of_updates += 1
                    probs, metr_batch = model.train(x, y, weight_factor=weight_factor)
                else:
                    probs, metr_batch = model.test(x, y, weight_factor=weight_factor)
                timings["network_time"] += time.time() - start_time_network

                start_time_metrics = time.time()
                if Config.CALC_F1:
                    # Reduce f1_macro to a single scalar before accumulating it.
                    if Config.EXPERIMENT_TYPE == "peak_regression":
                        # Here f1_macro is a dict of torch tensors; average them on the CPU.
                        metr_batch["f1_macro"] = np.array(
                            [s.to('cpu') for s in metr_batch["f1_macro"].values()]).mean()
                    else:
                        metr_batch["f1_macro"] = np.mean(metr_batch["f1_macro"])
                    metrics = metric_utils.add_to_metrics(metrics, metr_batch, phase, Config.METRIC_TYPES)
                else:
                    metrics = metric_utils.calculate_metrics_onlyLoss(metrics, metr_batch["loss"], type=phase)
                timings["metrics_time"] += time.time() - start_time_metrics

                print_loss.append(metr_batch["loss"])
                if batch_nr[phase] % Config.PRINT_FREQ == 0:
                    time_batch_part = time.time() - start_time_batch_part
                    start_time_batch_part = time.time()
                    exp_utils.print_and_save(
                        Config,
                        "{} Ep {}, Sp {}, loss {}, t print {}s, "
                        "t batch {}s".format(
                            phase, epoch_nr, batch_nr[phase] * Config.BATCH_SIZE,
                            round(np.array(print_loss).mean(), 6),
                            round(time_batch_part, 3),
                            round(time_batch_part / Config.PRINT_FREQ, 3)))
                    print_loss = []

                if vis_logger is not None:
                    plot_utils.plot_result_trixi(vis_logger, x, y, probs, metr_batch["loss"],
                                                 metr_batch["f1_macro"], epoch_nr)

        ###################################
        # Post Training tasks (each epoch)
        ###################################

        if Config.ONLY_VAL:
            metrics = metric_utils.normalize_last_element(metrics, batch_nr["validate"], type="validate")
            print("f1 macro validate: {}".format(round(metrics["f1_macro_validate"][0], 4)))
            return model

        # Average loss per batch over entire epoch
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["train"], type="train")
        metrics = metric_utils.normalize_last_element(metrics, batch_nr["validate"], type="validate")

        print("  Epoch {}, Average Epoch loss = {}".format(epoch_nr, metrics["loss_train"][-1]))
        print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

        # Adapt LR
        if Config.LR_SCHEDULE:
            if Config.LR_SCHEDULE_MODE == "min":
                model.scheduler.step(metrics["loss_validate"][-1])
            else:
                model.scheduler.step(metrics["f1_macro_validate"][-1])
            model.print_current_lr()

        # Save Weights
        start_time_saving = time.time()
        if Config.SAVE_WEIGHTS:
            model.save_model(metrics, epoch_nr, mode=Config.BEST_EPOCH_SELECTION)
        timings["saving_time"] += time.time() - start_time_saving

        # Create Plots
        start_time_plotting = time.time()
        # Context manager so the pickle file handle is closed every epoch.
        with open(join(Config.EXP_PATH, "metrics.pkl"), "wb") as metrics_file:
            pickle.dump(metrics, metrics_file)
        plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME,
                                   keys=["loss", "f1_macro"],
                                   types=["train", "validate"],
                                   selected_ax=["loss", "f1"],
                                   fig_name="metrics_all.png")
        plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME, without_first_epochs=True,
                                   keys=["loss", "f1_macro"],
                                   types=["train", "validate"],
                                   selected_ax=["loss", "f1"],
                                   fig_name="metrics.png")
        if "angle_err" in Config.METRIC_TYPES:
            plot_utils.create_exp_plot(metrics, Config.EXP_PATH, Config.EXP_NAME, without_first_epochs=True,
                                       keys=["loss", "angle_err"],
                                       types=["train", "validate"],
                                       selected_ax=["loss", "f1"],
                                       fig_name="metrics_angle.png")
        timings["plotting_time"] += time.time() - start_time_plotting

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)

        exp_utils.print_and_save(Config, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
        exp_utils.print_and_save(Config, "  Epoch {}, time UNet: {}s".format(epoch_nr, timings["network_time"]))
        exp_utils.print_and_save(Config, "  Epoch {}, time metrics: {}s".format(epoch_nr, timings["metrics_time"]))
        exp_utils.print_and_save(Config, "  Epoch {}, time saving files: {}s".format(epoch_nr, timings["saving_time"]))
        exp_utils.print_and_save(Config, str(datetime.datetime.now()))

        # Adding next Epoch
        if epoch_nr < Config.NUM_EPOCHS - 1:
            metrics = metric_utils.add_empty_element(metrics)

    ####################################
    # After all epochs
    ###################################
    # Guard against NUM_EPOCHS == 0, which previously raised ZeroDivisionError.
    if epoch_times:
        with open(join(Config.EXP_PATH, "Hyperparameters.txt"), "a") as f:  # a for append
            f.write("\n\n")
            f.write("Average Epoch time: {}s".format(sum(epoch_times) / float(len(epoch_times))))

    return model
def load_training_data(Config, subject):
    """
    Load data and labels for one subject from the training set.
    Cut and scale to make them have correct size.

    Loading of the feature file(s) is retried up to 20 times, sleeping 20s
    after each IOError. If the last attempt also fails, that IOError is
    re-raised. Previously the function crashed with UnboundLocalError in
    that case.

    :param Config: config class
    :param subject: subject id (string)
    :return: tuple ``(data, seg)`` of the scaled feature and label arrays
    :raises IOError: if the feature file(s) could not be loaded after all retries
    """
    nr_of_tries = 20
    subject_dir = join(C.DATA_PATH, Config.DATASET_FOLDER, subject)

    def _load_nifti(filename):
        # np.asanyarray(img.dataobj) is the documented replacement for nibabel's
        # deprecated get_data() (removed in nibabel 5): array in its stored dtype.
        return np.asanyarray(nib.load(join(subject_dir, filename)).dataobj)

    def _load_random_peaks():
        # Randomly pick one of the 270g / 90g / 12g peak files (same thresholds as before).
        rnd_choice = np.random.random()
        if rnd_choice < 0.33:
            return _load_nifti("270g_125mm_peaks.nii.gz")
        elif rnd_choice < 0.66:
            return _load_nifti("90g_125mm_peaks.nii.gz")
        return _load_nifti("12g_125mm_peaks.nii.gz")

    for i in range(nr_of_tries):
        try:
            if Config.FEATURES_FILENAME == "12g90g270g":
                data = _load_random_peaks()
            elif Config.FEATURES_FILENAME == "T1_Peaks270g":
                peaks = _load_nifti("270g_125mm_peaks.nii.gz")
                t1 = _load_nifti("T1.nii.gz")
                data = np.concatenate((peaks, t1), axis=3)
            elif Config.FEATURES_FILENAME == "T1_Peaks12g90g270g":
                peaks = _load_random_peaks()
                t1 = _load_nifti("T1.nii.gz")
                data = np.concatenate((peaks, t1), axis=3)
            else:
                data = _load_nifti(Config.FEATURES_FILENAME + ".nii.gz")
            break
        except IOError:
            if i == nr_of_tries - 1:
                # Out of retries: surface the real error instead of falling
                # through with `data` unbound.
                raise
            exp_utils.print_and_save(
                Config, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            exp_utils.print_and_save(Config, "Sleeping 20s")
            sleep(20)

    data = np.nan_to_num(data)  # Needed otherwise not working
    data = dataset_utils.scale_input_to_unet_shape(data, Config.DATASET, Config.RESOLUTION)  # (x, y, z, channels)

    seg = _load_nifti(Config.LABELS_FILENAME + ".nii.gz")
    seg = np.nan_to_num(seg)

    # Label files in this list are not rescaled here.
    if Config.LABELS_FILENAME not in (
            "bundle_peaks_11_808080", "bundle_peaks_20_808080", "bundle_peaks_808080",
            "bundle_masks_20_808080", "bundle_masks_72_808080", "bundle_peaks_Part1_808080",
            "bundle_peaks_Part2_808080", "bundle_peaks_Part3_808080", "bundle_peaks_Part4_808080"):
        if Config.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
            # By using "HCP" but lower resolution scale_input_to_unet_shape will automatically downsample the HCP sized seg_mask to the lower resolution
            seg = dataset_utils.scale_input_to_unet_shape(seg, "HCP", Config.RESOLUTION)
        else:
            seg = dataset_utils.scale_input_to_unet_shape(seg, Config.DATASET, Config.RESOLUTION)  # (x, y, z, classes)

    return data, seg