def generate_train_batch(self):
    subjects = self._data[0]
    subject_idx = int(random.uniform(0, len(subjects)))  # int() floors, so len(subjects)-1 is not needed

    for i in range(20):
        try:
            if np.random.random() < 0.5:
                data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                     "270g_125mm_peaks.nii.gz")).get_data()
            else:
                data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                     "90g_125mm_peaks.nii.gz")).get_data()

            # rnd_choice = np.random.random()
            # if rnd_choice < 0.33:
            #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
            # elif rnd_choice < 0.66:
            #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
            # else:
            #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()

            seg = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
            break
        except IOError:
            ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
    # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

    data = np.nan_to_num(data)  # needed, otherwise downstream processing fails on NaNs
    seg = np.nan_to_num(seg)

    data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)
    if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
        # By passing "HCP" with a lower resolution, scale_input_to_unet_shape automatically
        # downsamples the HCP-sized seg mask to the lower resolution.
        seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
    else:
        seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

    slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False, None)

    # Randomly sample slice orientation
    slice_direction = int(round(random.uniform(0, 2)))

    if slice_direction == 0:
        y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
        y = np.array(y).transpose(0, 3, 1, 2)  # class channel has to come before width and height for DataAugmentation -> (bs, nr_of_classes, x, y)
    elif slice_direction == 1:
        y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
        y = np.array(y).transpose(1, 3, 0, 2)
    elif slice_direction == 2:
        y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
        y = np.array(y).transpose(2, 3, 0, 1)

    sw = 5  # slice window (only odd numbers allowed)
    pad = int((sw - 1) / 2)

    data_pad = np.zeros((data.shape[0] + sw - 1, data.shape[1] + sw - 1, data.shape[2] + sw - 1,
                         data.shape[3])).astype(data.dtype)
    data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data  # padded with two slices of zeros on all sides

    batch = []
    for s_idx in slice_idxs:
        if slice_direction == 0:
            # (s_idx+2)-2:(s_idx+2)+3 = s_idx:s_idx+5
            x = data_pad[s_idx:s_idx + sw, pad:-pad, pad:-pad, :].astype(np.float32)  # (5, y, z, channels)
            x = np.array(x).transpose(0, 3, 1, 2)  # channel dim has to come before width and height for Unet (but after batches)
            x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
            batch.append(x)
        elif slice_direction == 1:
            x = data_pad[pad:-pad, s_idx:s_idx + sw, pad:-pad, :].astype(np.float32)  # (x, 5, z, channels)
            x = np.array(x).transpose(1, 3, 0, 2)  # channel dim has to come before width and height for Unet (but after batches)
            x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, x, z)
            batch.append(x)
        elif slice_direction == 2:
            x = data_pad[pad:-pad, pad:-pad, s_idx:s_idx + sw, :].astype(np.float32)  # (x, y, 5, channels)
            x = np.array(x).transpose(2, 3, 0, 1)  # channel dim has to come before width and height for Unet (but after batches)
            x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, x, y)
            batch.append(x)

    data_dict = {"data": np.array(batch),  # (batch_size, channels, x, y, [z])
                 "seg": y}                 # (batch_size, nr_of_classes, x, y, [z])
    return data_dict
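# --- Illustrative shape check (not part of the generator above): for sw = 5 the padded
# slice window is flattened into the channel dimension, matching the shape comments in
# generate_train_batch. The volume size (144^3, 9 peak channels) is a placeholder.
import numpy as np

data = np.zeros((144, 144, 144, 9), dtype=np.float32)  # (x, y, z, channels)
sw, pad = 5, 2
data_pad = np.zeros((data.shape[0] + sw - 1, data.shape[1] + sw - 1,
                     data.shape[2] + sw - 1, data.shape[3]), dtype=data.dtype)
data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data

s_idx = 70
x = data_pad[s_idx:s_idx + sw, pad:-pad, pad:-pad, :]  # (5, y, z, channels)
x = x.transpose(0, 3, 1, 2)                            # (5, channels, y, z)
x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3])
assert x.shape == (45, 144, 144)                       # 5*channels slices stacked as input channels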
def train(self, HP):
    if HP.USE_VISLOGGER:
        try:
            from trixi.logger.visdom import PytorchVisdomLogger
        except ImportError:
            raise ImportError("USE_VISLOGGER is set but trixi is not installed")
        trixi = PytorchVisdomLogger(port=8080, auto_start=True)

    ExpUtils.print_and_save(HP, socket.gethostname())

    epoch_times = []
    nr_of_updates = 0

    metrics = {}
    for type in ["train", "test", "validate"]:
        metrics_new = {
            "loss_" + type: [0],
            "f1_macro_" + type: [0],
        }
        metrics = dict(list(metrics.items()) + list(metrics_new.items()))

    for epoch_nr in range(HP.NUM_EPOCHS):
        start_time = time.time()
        # current_lr = HP.LEARNING_RATE * (HP.LR_DECAY ** epoch_nr)
        # current_lr = HP.LEARNING_RATE

        batch_gen_time = 0
        data_preparation_time = 0
        network_time = 0
        metrics_time = 0
        saving_time = 0
        plotting_time = 0

        batch_nr = {"train": 0, "test": 0, "validate": 0}

        if HP.LOSS_WEIGHT_LEN == -1:
            weight_factor = float(HP.LOSS_WEIGHT)
        else:
            if epoch_nr < HP.LOSS_WEIGHT_LEN:
                # weight_factor = -(9./100.) * epoch_nr + 10.   # ep0: 10 -> linear decrease -> ep100: 1
                weight_factor = -((HP.LOSS_WEIGHT - 1) / float(HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
                # weight_factor = -((HP.LOSS_WEIGHT-5)/float(HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
            else:
                weight_factor = 1.
                # weight_factor = 5.

        for type in ["train", "test", "validate"]:
            print_loss = []
            start_time_batch_gen = time.time()
            batch_generator = self.dataManager.get_batches(batch_size=HP.BATCH_SIZE, type=type,
                                                           subjects=getattr(HP, type.upper() + "_SUBJECTS"))
            batch_gen_time = time.time() - start_time_batch_gen
            # print("batch_gen_time: {}s".format(batch_gen_time))

            print("Start looping batches...")
            start_time_batch_part = time.time()
            for batch in batch_generator:  # fetching the next batch takes ~0.14s -> second largest time share after the model itself
                start_time_data_preparation = time.time()
                batch_nr[type] += 1

                x = batch["data"]  # (bs, nr_of_channels, x, y)
                y = batch["seg"]   # (bs, nr_of_classes, x, y)
                # with the new BatchGenerator y is no longer int but float -> fine for Pytorch but not for Lasagne
                # y = y.astype(HP.LABELS_TYPE)  # for bundle_peaks regression it is already float -> leaving this out saves 0.2s/batch

                data_preparation_time += time.time() - start_time_data_preparation
                # self.model.learning_rate.set_value(np.float32(current_lr))
                start_time_network = time.time()
                if type == "train":
                    nr_of_updates += 1
                    loss, probs, f1 = self.model.train(x, y, weight_factor=weight_factor)  # probs: (bs, x, y, nrClasses)
                    # loss, probs, f1, intermediate = self.model.train(x, y)
                elif type == "validate":
                    loss, probs, f1 = self.model.predict(x, y, weight_factor=weight_factor)
                elif type == "test":
                    loss, probs, f1 = self.model.predict(x, y, weight_factor=weight_factor)
                network_time += time.time() - start_time_network

                start_time_metrics = time.time()
                if HP.CALC_F1:
                    if HP.EXPERIMENT_TYPE == "peak_regression":
                        # The following two lines increase metrics_time by 30s (< 1s without them);
                        # time per batch increases by 1.5s
                        # y_flat = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                        # y_flat = np.reshape(y_flat, (-1, y_flat.shape[-1]))  # (bs*x*y, nr_of_classes)
                        # metrics = MetricUtils.calculate_metrics(metrics, y_flat, probs, loss, f1=np.mean(f1), type=type, threshold=HP.THRESHOLD,
                        #                                         f1_per_bundle={"CA": f1[5], "FX_left": f1[23], "FX_right": f1[24]})

                        # Numpy
                        # y_right_order = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                        # peak_f1 = MetricUtils.calc_peak_dice(HP, probs, y_right_order)
                        # peak_f1_mean = np.array([s for s in peak_f1.values()]).mean()

                        # Pytorch
                        peak_f1_mean = np.array([s for s in list(f1.values())]).mean()  # if f1 contains multiple bundles
                        metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean,
                                                                type=type, threshold=HP.THRESHOLD)

                        # Pytorch 2 F1
                        # peak_f1_mean_a = np.array([s for s in f1[0].values()]).mean()
                        # peak_f1_mean_b = np.array([s for s in f1[1].values()]).mean()
                        # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean_a, type=type, threshold=HP.THRESHOLD,
                        #                                         f1_per_bundle={"LenF1": peak_f1_mean_b})

                        # Single Bundle
                        # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"][0], type=type, threshold=HP.THRESHOLD,
                        #                                         f1_per_bundle={"Thr1": f1["CST_right"][1], "Thr2": f1["CST_right"][2]})
                        # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"], type=type, threshold=HP.THRESHOLD)
                    else:
                        metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=np.mean(f1),
                                                                type=type, threshold=HP.THRESHOLD)
                else:
                    metrics = MetricUtils.calculate_metrics_onlyLoss(metrics, loss, type=type)
                metrics_time += time.time() - start_time_metrics

                print_loss.append(loss)
                if batch_nr[type] % HP.PRINT_FREQ == 0:
                    time_batch_part = time.time() - start_time_batch_part
                    start_time_batch_part = time.time()
                    ExpUtils.print_and_save(HP, "{} Ep {}, Sp {}, loss {}, t print {}s, t batch {}s".format(
                        type, epoch_nr, batch_nr[type] * HP.BATCH_SIZE,
                        round(np.array(print_loss).mean(), 6), round(time_batch_part, 3),
                        round(time_batch_part / HP.PRINT_FREQ, 3)))
                    print_loss = []

                if HP.USE_VISLOGGER:
                    ExpUtils.plot_result_trixi(trixi, x, y, probs, loss, f1, epoch_nr)

        ###################################
        # Post Training tasks (each epoch)
        ###################################

        # Adapt LR
        if HP.LR_SCHEDULE:
            self.model.scheduler.step()
            # self.model.scheduler.step(np.mean(f1))
            self.model.print_current_lr()

        # Average loss per batch over entire epoch
        metrics = MetricUtils.normalize_last_element(metrics, batch_nr["train"], type="train")
        metrics = MetricUtils.normalize_last_element(metrics, batch_nr["validate"], type="validate")
        metrics = MetricUtils.normalize_last_element(metrics, batch_nr["test"], type="test")

        print("  Epoch {}, Average Epoch loss = {}".format(epoch_nr, metrics["loss_train"][-1]))
        print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

        # Save Weights
        start_time_saving = time.time()
        if HP.SAVE_WEIGHTS:
            self.model.save_model(metrics, epoch_nr)
        saving_time += time.time() - start_time_saving

        # Create Plots
        start_time_plotting = time.time()
        pickle.dump(metrics, open(join(HP.EXP_PATH, "metrics.pkl"), "wb"))  # "wb": write (overwrite) binary; binary mode only matters on Windows
        # for loading: pickle.load(open("metrics.pkl", "rb"))
        ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME)
        ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME, without_first_epochs=True)
        plotting_time += time.time() - start_time_plotting

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)
        ExpUtils.print_and_save(HP, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
        ExpUtils.print_and_save(HP, "  Epoch {}, time UNet: {}s".format(epoch_nr, network_time))
        ExpUtils.print_and_save(HP, "  Epoch {}, time metrics: {}s".format(epoch_nr, metrics_time))
        ExpUtils.print_and_save(HP, "  Epoch {}, time saving files: {}s".format(epoch_nr, saving_time))
        ExpUtils.print_and_save(HP, str(datetime.datetime.now()))

        # Add the element for the next epoch
        if epoch_nr < HP.NUM_EPOCHS - 1:
            metrics = MetricUtils.add_empty_element(metrics)

    ####################################
    # After all epochs
    ####################################
    with open(join(HP.EXP_PATH, "Hyperparameters.txt"), "a") as f:  # "a" for append
        f.write("\n\n")
        f.write("Average Epoch time: {}s".format(sum(epoch_times) / float(len(epoch_times))))

    return metrics
def generate_train_batch(self):
    subjects = self._data[0]
    subject_idx = int(random.uniform(0, len(subjects)))  # int() floors, so len(subjects)-1 is not needed

    for i in range(20):
        try:
            if self.HP.FEATURES_FILENAME == "12g90g270g":
                # if np.random.random() < 0.5:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                # else:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()

                rnd_choice = np.random.random()
                if rnd_choice < 0.33:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                         "270g_125mm_peaks.nii.gz")).get_data()
                elif rnd_choice < 0.66:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                         "90g_125mm_peaks.nii.gz")).get_data()
                else:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                         "12g_125mm_peaks.nii.gz")).get_data()
            elif self.HP.FEATURES_FILENAME == "T1_Peaks270g":
                peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                      "270g_125mm_peaks.nii.gz")).get_data()
                t1 = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                   "T1.nii.gz")).get_data()
                data = np.concatenate((peaks, t1), axis=3)
            elif self.HP.FEATURES_FILENAME == "T1_Peaks12g90g270g":
                rnd_choice = np.random.random()
                if rnd_choice < 0.33:
                    peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                          "270g_125mm_peaks.nii.gz")).get_data()
                elif rnd_choice < 0.66:
                    peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                          "90g_125mm_peaks.nii.gz")).get_data()
                else:
                    peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                          "12g_125mm_peaks.nii.gz")).get_data()
                t1 = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                   "T1.nii.gz")).get_data()
                data = np.concatenate((peaks, t1), axis=3)
            else:
                data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                     self.HP.FEATURES_FILENAME + ".nii.gz")).get_data()

            seg = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx],
                                self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
            break
        except IOError:
            ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
    # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

    data = np.nan_to_num(data)  # needed, otherwise downstream processing fails on NaNs
    seg = np.nan_to_num(seg)

    data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)

    if self.HP.LABELS_FILENAME not in ["bundle_peaks_11_808080", "bundle_peaks_20_808080", "bundle_peaks_808080",
                                       "bundle_masks_20_808080", "bundle_masks_72_808080"]:
        if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
            # By passing "HCP" with a lower resolution, scale_input_to_unet_shape automatically
            # downsamples the HCP-sized seg mask to the lower resolution.
            seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
        else:
            seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

    slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False, None)

    # Randomly sample slice orientation
    if self.HP.TRAINING_SLICE_DIRECTION == "xyz":
        slice_direction = int(round(random.uniform(0, 2)))
    else:
        slice_direction = 1  # always use Y

    if slice_direction == 0:
        x = data[slice_idxs, :, :].astype(np.float32)  # (batch_size, y, z, channels)
        y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
        x = np.array(x).transpose(0, 3, 1, 2)  # channel dim has to come before width and height for Unet (but after batches)
        y = np.array(y).transpose(0, 3, 1, 2)  # class channel has to come before width and height for DataAugmentation -> (bs, nr_of_classes, x, y)
    elif slice_direction == 1:
        x = data[:, slice_idxs, :].astype(np.float32)  # (x, batch_size, z, channels)
        y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
        x = np.array(x).transpose(1, 3, 0, 2)
        y = np.array(y).transpose(1, 3, 0, 2)
    elif slice_direction == 2:
        x = data[:, :, slice_idxs].astype(np.float32)  # (x, y, batch_size, channels)
        y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
        x = np.array(x).transpose(2, 3, 0, 1)
        y = np.array(y).transpose(2, 3, 0, 1)

    data_dict = {"data": x,  # (batch_size, channels, x, y, [z])
                 "seg": y}   # (batch_size, nr_of_classes, x, y, [z])
    return data_dict
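# --- Illustrative shape check (not part of the generator above): sampling slices along
# axis 1 (the "Y" direction) and moving the channel axis forward produces the
# (bs, channels, dim1, dim2) layout the Unet expects. Sizes are placeholders.
import numpy as np

data = np.zeros((144, 144, 144, 9), dtype=np.float32)  # (x, y, z, channels)
slice_idxs = np.random.choice(data.shape[1], 47, False, None)

x = data[:, slice_idxs, :].astype(np.float32)  # (x, batch_size, z, channels)
x = x.transpose(1, 3, 0, 2)                    # (batch_size, channels, x, z)
assert x.shape == (47, 9, 144, 144)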
def create_network(self):

    def train(X, y):
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
        else:
            X, y = Variable(X), Variable(y)
        optimizer.zero_grad()
        net.train()
        outputs = net(X)  # forward; outputs: (bs, classes, x, y)
        loss = criterion(outputs, y)
        loss.backward()   # backward
        optimizer.step()  # optimise
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        # probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        probs = None  # faster
        return loss.data[0], probs, f1

    def test(X, y):
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
        else:
            X, y = Variable(X, volatile=True), Variable(y, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        loss = criterion(outputs, y)
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        # probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        probs = None  # faster
        return loss.data[0], probs, f1

    def predict(X):
        X = torch.from_numpy(X.astype(np.float32))
        if torch.cuda.is_available():
            X = Variable(X.cuda(), volatile=True)
        else:
            X = Variable(X, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        return probs

    def save_model(metrics, epoch_nr):
        max_f1_idx = np.argmax(metrics["f1_macro_validate"])
        max_f1 = np.max(metrics["f1_macro_validate"])
        if epoch_nr == max_f1_idx and max_f1 > 0.01:
            # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
            print("  Saving weights...")
            for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                os.remove(fl)
            try:
                # actually a pkl, not a npz
                PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
            except IOError:
                print("\nERROR: Could not save weights because of IO Error\n")
            self.HP.BEST_EPOCH = epoch_nr

    def load_model(path):
        PytorchUtils.load_checkpoint(path, unet=net)

    if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
        NR_OF_GRADIENTS = 9
    elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
        NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
    else:
        NR_OF_GRADIENTS = 33

    if torch.cuda.is_available():
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
        # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
    else:
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)
        # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)

    if self.HP.TRAIN:
        ExpUtils.print_and_save(self.HP, str(net), only_log=True)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)

    if self.HP.LOAD_WEIGHTS:
        ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
        load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

    self.train = train
    self.predict = test
    self.get_probs = predict
    self.save_model = save_model
    self.load_model = load_model
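# --- Hedged sketch of what PytorchUtils.save_checkpoint / load_checkpoint likely wrap
# (assumption: torch.save / torch.load of the passed modules' state_dicts; the real
# helpers may store additional metadata):
import torch

def save_checkpoint(path, **modules):
    torch.save({name: module.state_dict() for name, module in modules.items()}, path)

def load_checkpoint(path, **modules):
    state = torch.load(path)
    for name, module in modules.items():
        module.load_state_dict(state[name])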
def create_network(self):
    # torch.backends.cudnn.benchmark = True  # not faster

    def train(X, y):
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
        else:
            X, y = Variable(X), Variable(y)
        optimizer.zero_grad()
        net.train()
        outputs, intermediate = net(X)  # forward; outputs: (bs, classes, x, y)
        loss = criterion(outputs, y)
        # loss = PytorchUtils.soft_dice(outputs, y)
        loss.backward()   # backward
        optimizer.step()  # optimise
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        if self.HP.USE_VISLOGGER:
            probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        else:
            probs = None  # faster
        return loss.data[0], probs, f1, intermediate

    def test(X, y):
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
        else:
            X, y = Variable(X, volatile=True), Variable(y, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        loss = criterion(outputs, y)
        # loss = PytorchUtils.soft_dice(outputs, y)
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        # probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        probs = None  # faster
        return loss.data[0], probs, f1

    def predict(X):
        X = torch.from_numpy(X.astype(np.float32))
        if torch.cuda.is_available():
            X = Variable(X.cuda(), volatile=True)
        else:
            X = Variable(X, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        return probs

    def save_model(metrics, epoch_nr):
        max_f1_idx = np.argmax(metrics["f1_macro_validate"])
        max_f1 = np.max(metrics["f1_macro_validate"])
        if epoch_nr == max_f1_idx and max_f1 > 0.01:
            # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
            print("  Saving weights...")
            for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                os.remove(fl)
            try:
                # actually a pkl, not a npz
                PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
            except IOError:
                print("\nERROR: Could not save weights because of IO Error\n")
            self.HP.BEST_EPOCH = epoch_nr

    def load_model(path):
        PytorchUtils.load_checkpoint(path, unet=net)

    def print_current_lr():
        for param_group in optimizer.param_groups:
            ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))

    if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
        NR_OF_GRADIENTS = 9
        # NR_OF_GRADIENTS = 9 * 5
    elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
        NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
    else:
        NR_OF_GRADIENTS = 33

    if torch.cuda.is_available():
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
    else:
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)
    # net = nn.DataParallel(net, device_ids=[0, 1])

    if self.HP.TRAIN:
        ExpUtils.print_and_save(self.HP, str(net), only_log=True)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
    # optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE)  # very slow (half speed of Adamax) -> strange
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")

    if self.HP.LOAD_WEIGHTS:
        ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
        load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

    # plot feature weights
    # weights = list(list(net.children())[0].children())[0].weight.cpu().data.numpy()  # sequential -> conv2d  # (64, 9, 3, 3)
    # weights = weights[:, 0:1, :, :]  # select one input channel to plot  # (64, 1, 3, 3)
    # weights = (weights * 100).astype(np.uint8)  # can not plot negative values (and floats only if in 0-1) -> not ideal: negatives are dropped
    # plot_kernels(weights)

    self.train = train
    self.predict = test
    self.get_probs = predict
    self.save_model = save_model
    self.load_model = load_model
    self.print_current_lr = print_current_lr
    # self.scheduler = scheduler
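# --- Hedged sketch of the plot_kernels helper referenced in the commented block above
# (assumption: a simple matplotlib grid over the first input channel of each filter):
import numpy as np
import matplotlib.pyplot as plt

def plot_kernels(weights, cols=8):
    # weights: (nr_filters, 1, k, k), uint8
    rows = int(np.ceil(weights.shape[0] / float(cols)))
    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows))
    for idx, ax in enumerate(np.array(axes).flat):
        ax.axis("off")
        if idx < weights.shape[0]:
            ax.imshow(weights[idx, 0], cmap="gray")
    plt.show()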
def train(self, HP):
    if HP.USE_VISLOGGER:
        nvl = Nvl(name="Training")

    ExpUtils.print_and_save(HP, socket.gethostname())

    epoch_times = []
    nr_of_updates = 0

    metrics = {}
    for type in ["train", "test", "validate"]:
        metrics_new = {
            "loss_" + type: [0],
            "f1_macro_" + type: [0],
        }
        metrics = dict(list(metrics.items()) + list(metrics_new.items()))

    for epoch_nr in range(HP.NUM_EPOCHS):
        start_time = time.time()
        # current_lr = HP.LEARNING_RATE * (HP.LR_DECAY ** epoch_nr)
        # current_lr = HP.LEARNING_RATE

        batch_gen_time = 0
        data_preparation_time = 0
        network_time = 0
        metrics_time = 0
        saving_time = 0
        plotting_time = 0

        batch_nr = {"train": 0, "test": 0, "validate": 0}

        if HP.LOSS_WEIGHT_LEN == -1:
            weight_factor = float(HP.LOSS_WEIGHT)
        else:
            if epoch_nr < HP.LOSS_WEIGHT_LEN:
                # weight_factor = -(9./100.) * epoch_nr + 10.   # ep0: 10 -> linear decrease -> ep100: 1
                weight_factor = -((HP.LOSS_WEIGHT - 1) / float(HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
                # weight_factor = -((HP.LOSS_WEIGHT-5)/float(HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
            else:
                weight_factor = 1.
                # weight_factor = 5.

        for type in ["train", "test", "validate"]:
            print_loss = []
            start_time_batch_gen = time.time()
            batch_generator = self.dataManager.get_batches(batch_size=HP.BATCH_SIZE, type=type,
                                                           subjects=getattr(HP, type.upper() + "_SUBJECTS"))
            batch_gen_time = time.time() - start_time_batch_gen
            # print("batch_gen_time: {}s".format(batch_gen_time))

            print("Start looping batches...")
            start_time_batch_part = time.time()
            for batch in batch_generator:  # fetching the next batch takes ~0.14s -> second largest time share after the UNet
                start_time_data_preparation = time.time()
                batch_nr[type] += 1

                x = batch["data"]  # (bs, nr_of_channels, x, y)
                y = batch["seg"]   # (bs, nr_of_classes, x, y)
                # with the new BatchGenerator y is no longer int but float -> fine for Pytorch but not for Lasagne
                # y = y.astype(HP.LABELS_TYPE)  # for bundle_peaks regression it is already float -> leaving this out saves 0.2s/batch

                data_preparation_time += time.time() - start_time_data_preparation
                # self.model.learning_rate.set_value(np.float32(current_lr))
                start_time_network = time.time()
                if type == "train":
                    nr_of_updates += 1
                    # the vislogger block below uses the intermediate feature activations, so the
                    # four-value variant of model.train() has to be used here (the three-value call
                    # would leave `intermediate` undefined)
                    loss, probs, f1, intermediate = self.model.train(x, y)  # probs: (bs, x, y, nrClasses)
                    # loss, probs, f1 = self.model.train(x, y, weight_factor=weight_factor)
                elif type == "validate":
                    loss, probs, f1 = self.model.predict(x, y, weight_factor=weight_factor)
                elif type == "test":
                    loss, probs, f1 = self.model.predict(x, y, weight_factor=weight_factor)
                network_time += time.time() - start_time_network

                start_time_metrics = time.time()
                if HP.CALC_F1:
                    if HP.LABELS_TYPE == np.int16:
                        metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=np.mean(f1),
                                                                type=type, threshold=HP.THRESHOLD)
                    else:  # Regression
                        # The following two lines increase metrics_time by 30s (< 1s without them);
                        # time per batch increases by 1.5s
                        # y_flat = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                        # y_flat = np.reshape(y_flat, (-1, y_flat.shape[-1]))  # (bs*x*y, nr_of_classes)
                        # metrics = MetricUtils.calculate_metrics(metrics, y_flat, probs, loss, f1=np.mean(f1), type=type, threshold=HP.THRESHOLD,
                        #                                         f1_per_bundle={"CA": f1[5], "FX_left": f1[23], "FX_right": f1[24]})

                        # Numpy
                        # y_right_order = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                        # peak_f1 = MetricUtils.calc_peak_dice(HP, probs, y_right_order)
                        # peak_f1_mean = np.array([s for s in peak_f1.values()]).mean()

                        # Pytorch
                        peak_f1_mean = np.array([s for s in list(f1.values())]).mean()  # if f1 contains multiple bundles
                        metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean,
                                                                type=type, threshold=HP.THRESHOLD)

                        # Pytorch 2 F1
                        # peak_f1_mean_a = np.array([s for s in f1[0].values()]).mean()
                        # peak_f1_mean_b = np.array([s for s in f1[1].values()]).mean()
                        # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean_a, type=type, threshold=HP.THRESHOLD,
                        #                                         f1_per_bundle={"LenF1": peak_f1_mean_b})

                        # Single Bundle
                        # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"][0], type=type, threshold=HP.THRESHOLD,
                        #                                         f1_per_bundle={"Thr1": f1["CST_right"][1], "Thr2": f1["CST_right"][2]})
                        # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"], type=type, threshold=HP.THRESHOLD)
                else:
                    metrics = MetricUtils.calculate_metrics_onlyLoss(metrics, loss, type=type)
                metrics_time += time.time() - start_time_metrics

                print_loss.append(loss)
                if batch_nr[type] % HP.PRINT_FREQ == 0:
                    time_batch_part = time.time() - start_time_batch_part
                    start_time_batch_part = time.time()
                    ExpUtils.print_and_save(HP, "{} Ep {}, Sp {}, loss {}, t print {}s, t batch {}s".format(
                        type, epoch_nr, batch_nr[type] * HP.BATCH_SIZE,
                        round(np.array(print_loss).mean(), 6), round(time_batch_part, 3),
                        round(time_batch_part / HP.PRINT_FREQ, 3)))
                    print_loss = []

                if HP.USE_VISLOGGER:
                    x_norm = (x - x.min()) / (x.max() - x.min())
                    nvl.show_images(x_norm[0:1, :, :, :].transpose((1, 0, 2, 3)), name="input batch",
                                    title="Input batch")  # all channels of one sample

                    probs_shaped = probs[:, :, :, 15:16].transpose((0, 3, 1, 2))  # (bs, 1, x, y)
                    probs_shaped_bin = (probs_shaped > 0.5).astype(np.int16)
                    nvl.show_images(probs_shaped, name="predictions", title="Predictions Probmap")
                    # nvl.show_images(probs_shaped_bin, name="predictions_binary", title="Predictions Binary")

                    # Show GT and prediction in one image (bundle: CST)
                    # GREEN: GT; RED: prediction (FP); YELLOW: prediction (TP)
                    combined = np.zeros((y.shape[0], 3, y.shape[2], y.shape[3]))
                    combined[:, 0:1, :, :] = probs_shaped_bin   # Red
                    combined[:, 1:2, :, :] = y[:, 15:16, :, :]  # Green
                    nvl.show_images(combined, name="predictions_combined", title="Combined")

                    # Show feature activations
                    contr_1_2 = intermediate[2].data.cpu().numpy()  # (bs, nr_feature_channels=64, x, y)
                    contr_1_2 = contr_1_2[0:1, :, :, :].transpose((1, 0, 2, 3))  # (nr_feature_channels=64, 1, x, y)
                    contr_1_2 = (contr_1_2 - contr_1_2.min()) / (contr_1_2.max() - contr_1_2.min())
                    nvl.show_images(contr_1_2, name="contr_1_2", title="contr_1_2")

                    # Show feature activations
                    contr_3_2 = intermediate[1].data.cpu().numpy()  # (bs, nr_feature_channels=64, x, y)
                    contr_3_2 = contr_3_2[0:1, :, :, :].transpose((1, 0, 2, 3))  # (nr_feature_channels=64, 1, x, y)
                    contr_3_2 = (contr_3_2 - contr_3_2.min()) / (contr_3_2.max() - contr_3_2.min())
                    nvl.show_images(contr_3_2, name="contr_3_2", title="contr_3_2")

                    # Show feature activations
                    deconv_2 = intermediate[0].data.cpu().numpy()  # (bs, nr_feature_channels=64, x, y)
                    deconv_2 = deconv_2[0:1, :, :, :].transpose((1, 0, 2, 3))  # (nr_feature_channels=64, 1, x, y)
                    deconv_2 = (deconv_2 - deconv_2.min()) / (deconv_2.max() - deconv_2.min())
                    nvl.show_images(deconv_2, name="deconv_2", title="deconv_2")

                    nvl.show_value(float(loss), name="loss")
                    nvl.show_value(float(np.mean(f1)), name="f1")

        ###################################
        # Post Training tasks (each epoch)
        ###################################

        # Adapt LR
        # self.model.scheduler.step()
        # self.model.scheduler.step(np.mean(f1))
        # self.model.print_current_lr()

        # Average loss per batch over entire epoch
        metrics = MetricUtils.normalize_last_element(metrics, batch_nr["train"], type="train")
        metrics = MetricUtils.normalize_last_element(metrics, batch_nr["validate"], type="validate")
        metrics = MetricUtils.normalize_last_element(metrics, batch_nr["test"], type="test")

        print("  Epoch {}, Average Epoch loss = {}".format(epoch_nr, metrics["loss_train"][-1]))
        print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

        # Save Weights
        start_time_saving = time.time()
        if HP.SAVE_WEIGHTS:
            self.model.save_model(metrics, epoch_nr)
        saving_time += time.time() - start_time_saving

        # Create Plots
        start_time_plotting = time.time()
        pickle.dump(metrics, open(join(HP.EXP_PATH, "metrics.pkl"), "wb"))  # "wb": write (overwrite) binary; binary mode only matters on Windows
        # for loading: pickle.load(open("metrics.pkl", "rb"))
        ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME)
        ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME, without_first_epochs=True)
        plotting_time += time.time() - start_time_plotting

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)
        ExpUtils.print_and_save(HP, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
        ExpUtils.print_and_save(HP, "  Epoch {}, time UNet: {}s".format(epoch_nr, network_time))
        ExpUtils.print_and_save(HP, "  Epoch {}, time metrics: {}s".format(epoch_nr, metrics_time))
        ExpUtils.print_and_save(HP, "  Epoch {}, time saving files: {}s".format(epoch_nr, saving_time))
        ExpUtils.print_and_save(HP, str(datetime.datetime.now()))

        # Add the element for the next epoch
        if epoch_nr < HP.NUM_EPOCHS - 1:
            metrics = MetricUtils.add_empty_element(metrics)

    ####################################
    # After all epochs
    ####################################
    with open(join(HP.EXP_PATH, "Hyperparameters.txt"), "a") as f:  # "a" for append
        f.write("\n\n")
        f.write("Average Epoch time: {}s".format(sum(epoch_times) / float(len(epoch_times))))

    return metrics
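# --- The vislogger block above min-max scales x, contr_1_2, contr_3_2 and deconv_2 with
# the same expression. A small helper (our suggestion, not part of the codebase) would
# remove the repetition; note it assumes a.max() > a.min():
def to_unit_range(a):
    return (a - a.min()) / (a.max() - a.min())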
def create_network(self):
    # torch.backends.cudnn.benchmark = True  # not faster

    def train(X, y):
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
        else:
            X, y = Variable(X), Variable(y)
        optimizer.zero_grad()
        net.train()
        outputs = net(X)  # forward; outputs: (bs, classes, x, y)
        loss = criterion(outputs, y)
        # loss = PytorchUtils.soft_dice(outputs, y)
        loss.backward()   # backward
        optimizer.step()  # optimise
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        if self.HP.USE_VISLOGGER:
            probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        else:
            probs = None  # faster
        return loss.data[0], probs, f1

    def test(X, y):
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        if torch.cuda.is_available():
            X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
        else:
            X, y = Variable(X, volatile=True), Variable(y, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        loss = criterion(outputs, y)
        # loss = PytorchUtils.soft_dice(outputs, y)
        f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
        # probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        probs = None  # faster
        return loss.data[0], probs, f1

    def predict(X):
        X = torch.from_numpy(X.astype(np.float32))
        if torch.cuda.is_available():
            X = Variable(X.cuda(), volatile=True)
        else:
            X = Variable(X, volatile=True)
        net.train(False)
        outputs = net(X)  # forward
        probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
        return probs

    def save_model(metrics, epoch_nr):
        max_f1_idx = np.argmax(metrics["f1_macro_validate"])
        max_f1 = np.max(metrics["f1_macro_validate"])
        if epoch_nr == max_f1_idx and max_f1 > 0.01:
            # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
            print("  Saving weights...")
            for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                os.remove(fl)
            try:
                # actually a pkl, not a npz
                PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
            except IOError:
                print("\nERROR: Could not save weights because of IO Error\n")
            self.HP.BEST_EPOCH = epoch_nr

    def load_model(path):
        PytorchUtils.load_checkpoint(path, unet=net)

    def print_current_lr():
        for param_group in optimizer.param_groups:
            ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))

    if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
        NR_OF_GRADIENTS = 9
        # NR_OF_GRADIENTS = 9 * 5
        # NR_OF_GRADIENTS = 9 * 9
        # NR_OF_GRADIENTS = 33
    elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
        NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
    else:
        NR_OF_GRADIENTS = 33

    if torch.cuda.is_available():
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
    else:
        net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)
    # net = nn.DataParallel(net, device_ids=[0, 1])

    if self.HP.TRAIN:
        ExpUtils.print_and_save(self.HP, str(net), only_log=True)

    # weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
    # weights[:, 5, :, :] *= 10   # CA
    # weights[:, 21, :, :] *= 10  # FX_left
    # weights[:, 22, :, :] *= 10  # FX_right
    # criterion = nn.BCEWithLogitsLoss(weight=weights)
    criterion = nn.BCEWithLogitsLoss()

    optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
    # optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE)  # very slow (half speed of Adamax) -> strange
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")

    if self.HP.LOAD_WEIGHTS:
        ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
        load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

    self.train = train
    self.predict = test
    self.get_probs = predict
    self.save_model = save_model
    self.load_model = load_model
    self.print_current_lr = print_current_lr
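# --- Hedged sketch of PytorchUtils.f1_score_macro as used above (assumption: the net
# outputs logits, since the loss is BCEWithLogitsLoss, so predictions are sigmoided and
# thresholded before a per-class F1; written for modern PyTorch (>= 0.4.1) and tensors
# of shape (bs, classes, x, y)):
import torch

def f1_score_macro(y_true, y_logits, per_class=True, threshold=0.5):
    y_pred = (torch.sigmoid(y_logits) > threshold).float()
    y_true = y_true.float()
    tp = (y_pred * y_true).sum(dim=(0, 2, 3))        # per-class true positives
    fp = (y_pred * (1 - y_true)).sum(dim=(0, 2, 3))  # per-class false positives
    fn = ((1 - y_pred) * y_true).sum(dim=(0, 2, 3))  # per-class false negatives
    f1 = 2 * tp / (2 * tp + fp + fn + 1e-7)
    return f1 if per_class else f1.mean()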