Example #1
    def generate_train_batch(self):
        subjects = self._data[0]
        subject_idx = int(
            random.uniform(0, len(subjects))
        )  # len(subjects)-1 not needed because int always rounds to floor

        for i in range(20):
            try:
                if np.random.random() < 0.5:
                    data = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx],
                             "270g_125mm_peaks.nii.gz")).get_data()
                else:
                    data = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx],
                             "90g_125mm_peaks.nii.gz")).get_data()

                # rnd_choice = np.random.random()
                # if rnd_choice < 0.33:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                # elif rnd_choice < 0.66:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
                # else:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()

                seg = nib.load(
                    join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                         subjects[subject_idx],
                         self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
                break
            except IOError:
                ExpUtils.print_and_save(
                    self.HP,
                    "\n\nWARNING: Could not load file. Trying again in 20s (Try number: "
                    + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
        # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

        data = np.nan_to_num(data)  # replace NaNs, otherwise later processing fails
        seg = np.nan_to_num(seg)

        data = DatasetUtils.scale_input_to_unet_shape(
            data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)
        if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
            # Passing "HCP" at a lower resolution makes scale_input_to_unet_shape downsample the HCP-sized seg mask to that resolution
            seg = DatasetUtils.scale_input_to_unet_shape(
                seg, "HCP", self.HP.RESOLUTION)
        else:
            seg = DatasetUtils.scale_input_to_unet_shape(
                seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

        slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False,
                                      None)

        # Randomly sample slice orientation
        # (note: round() makes direction 1 twice as likely as 0 or 2)
        slice_direction = int(round(random.uniform(0, 2)))

        if slice_direction == 0:
            y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
            y = np.array(y).transpose(
                0, 3, 1, 2
            )  # nr_classes channel has to come before width and height for DataAugmentation (bs, nr_of_classes, x, y)
        elif slice_direction == 1:
            y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
            y = np.array(y).transpose(1, 3, 0, 2)
        elif slice_direction == 2:
            y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
            y = np.array(y).transpose(2, 3, 0, 1)

        sw = 5  #slice_window (only odd numbers allowed)
        pad = int((sw - 1) / 2)

        data_pad = np.zeros(
            (data.shape[0] + sw - 1, data.shape[1] + sw - 1,
             data.shape[2] + sw - 1, data.shape[3])).astype(data.dtype)
        data_pad[
            pad:-pad, pad:-pad,
            pad:-pad, :] = data  #padded with two slices of zeros on all sides
        batch = []
        for s_idx in slice_idxs:
            if slice_direction == 0:
                #(s_idx+2)-2:(s_idx+2)+3 = s_idx:s_idx+5
                x = data_pad[s_idx:s_idx + sw, pad:-pad, pad:-pad, :].astype(
                    np.float32)  # (5, y, z, channels)
                x = np.array(x).transpose(
                    0, 3, 1, 2
                )  # channels dim has to be before width and height for Unet (but after batches)
                x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2],
                                   x.shape[3]))  # (5*channels, y, z)
                batch.append(x)
            elif slice_direction == 1:
                x = data_pad[pad:-pad, s_idx:s_idx + sw, pad:-pad, :].astype(
                    np.float32)  # (5, y, z, channels)
                x = np.array(x).transpose(
                    1, 3, 0, 2
                )  # channels dim has to be before width and height for Unet (but after batches)
                x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2],
                                   x.shape[3]))  # (5*channels, y, z)
                batch.append(x)
            elif slice_direction == 2:
                x = data_pad[pad:-pad, pad:-pad, s_idx:s_idx + sw, :].astype(
                    np.float32)  # (5, y, z, channels)
                x = np.array(x).transpose(
                    2, 3, 0, 1
                )  # channels dim has to be before width and height for Unet (but after batches)
                x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2],
                                   x.shape[3]))  # (5*channels, y, z)
                batch.append(x)
        data_dict = {
            "data": np.array(batch),  # (batch_size, channels, x, y, [z])
            "seg": y
        }  # (batch_size, channels, x, y, [z])

        return data_dict
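
The padding/transpose/reshape sequence above is the easiest part to get wrong, so here is a minimal, self-contained sketch of the same 5-slice-window extraction on a dummy volume (the sw/pad names mirror the example; the toy array sizes are made up for illustration):

import numpy as np

sw = 5               # slice window (odd), as in the example above
pad = (sw - 1) // 2  # 2 zero slices on each side

data = np.random.rand(10, 12, 14, 9).astype(np.float32)  # (x, y, z, channels)

data_pad = np.zeros((data.shape[0] + sw - 1, data.shape[1] + sw - 1,
                     data.shape[2] + sw - 1, data.shape[3]), dtype=data.dtype)
data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data

s_idx = 3  # one sampled slice along axis 0 (slice_direction == 0)
x = data_pad[s_idx:s_idx + sw, pad:-pad, pad:-pad, :]  # (5, y, z, channels)
x = x.transpose(0, 3, 1, 2)                            # (5, channels, y, z)
x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3])
print(x.shape)  # (45, 12, 14) -> (sw * channels, y, z)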
Example #2
    def train(self, HP):

        if HP.USE_VISLOGGER:
            try:
                from trixi.logger.visdom import PytorchVisdomLogger
            except ImportError:
                raise ImportError("USE_VISLOGGER=True requires the trixi package")
            trixi = PytorchVisdomLogger(port=8080, auto_start=True)

        ExpUtils.print_and_save(HP, socket.gethostname())

        epoch_times = []
        nr_of_updates = 0

        metrics = {}
        for type in ["train", "test", "validate"]:
            metrics_new = {
                "loss_" + type: [0],
                "f1_macro_" + type: [0],
            }
            metrics = dict(list(metrics.items()) + list(metrics_new.items()))

        for epoch_nr in range(HP.NUM_EPOCHS):
            start_time = time.time()
            # current_lr = HP.LEARNING_RATE * (HP.LR_DECAY ** epoch_nr)
            # current_lr = HP.LEARNING_RATE

            batch_gen_time = 0
            data_preparation_time = 0
            network_time = 0
            metrics_time = 0
            saving_time = 0
            plotting_time = 0

            batch_nr = {
                "train": 0,
                "test": 0,
                "validate": 0
            }

            if HP.LOSS_WEIGHT_LEN == -1:
                weight_factor = float(HP.LOSS_WEIGHT)
            else:
                if epoch_nr < HP.LOSS_WEIGHT_LEN:
                    # weight_factor = -(9./100.) * epoch_nr + 10.   #ep0: 10 -> linear decrease -> ep100: 1
                    weight_factor = -((HP.LOSS_WEIGHT-1)/float(HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
                    # weight_factor = -((HP.LOSS_WEIGHT-5)/float(HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
                else:
                    weight_factor = 1.
                    # weight_factor = 5.

            for type in ["train", "test", "validate"]:
                print_loss = []
                start_time_batch_gen = time.time()

                batch_generator = self.dataManager.get_batches(batch_size=HP.BATCH_SIZE,
                                                               type=type, subjects=getattr(HP, type.upper() + "_SUBJECTS"))
                batch_gen_time = time.time() - start_time_batch_gen
                # print("batch_gen_time: {}s".format(batch_gen_time))

                print("Start looping batches...")
                start_time_batch_part = time.time()
                for batch in batch_generator:                   # getting next batch takes around 0.14s -> second largest time share after the UNet!

                    start_time_data_preparation = time.time()
                    batch_nr[type] += 1

                    x = batch["data"] # (bs, nr_of_channels, x, y)
                    y = batch["seg"]  # (bs, nr_of_classes, x, y)
                    # with the new BatchGenerator y is float instead of int -> fine for Pytorch but not for Lasagne
                    # y = y.astype(HP.LABELS_TYPE)  # for bundle_peaks regression y is already float -> saves 0.2s/batch if left out

                    data_preparation_time += time.time() - start_time_data_preparation
                    # self.model.learning_rate.set_value(np.float32(current_lr))
                    start_time_network = time.time()
                    if type == "train":
                        nr_of_updates += 1
                        loss, probs, f1 = self.model.train(x, y, weight_factor=weight_factor)    # probs: # (bs, x, y, nrClasses)
                        # loss, probs, f1, intermediate = self.model.train(x, y)
                    elif type == "validate":
                        loss, probs, f1 = self.model.predict(x, y, weight_factor=weight_factor)
                    elif type == "test":
                        loss, probs, f1 = self.model.predict(x, y, weight_factor=weight_factor)
                    network_time += time.time() - start_time_network

                    start_time_metrics = time.time()

                    if HP.CALC_F1:
                        if HP.EXPERIMENT_TYPE == "peak_regression":
                            # The following two (commented) lines add ~30s to metrics_time (vs <1s without) and ~1.5s per batch:
                            # y_flat = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                            # y_flat = np.reshape(y_flat, (-1, y_flat.shape[-1]))  # (bs*x*y, nr_of_classes)
                            # metrics = MetricUtils.calculate_metrics(metrics, y_flat, probs, loss, f1=np.mean(f1), type=type, threshold=HP.THRESHOLD,
                            #                                         f1_per_bundle={"CA": f1[5], "FX_left": f1[23], "FX_right": f1[24]})

                            #Numpy
                            # y_right_order = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                            # peak_f1 = MetricUtils.calc_peak_dice(HP, probs, y_right_order)
                            # peak_f1_mean = np.array([s for s in peak_f1.values()]).mean()

                            #Pytorch
                            peak_f1_mean = np.array([s for s in list(f1.values())]).mean()  #if f1 for multiple bundles
                            metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean, type=type, threshold=HP.THRESHOLD)

                            #Pytorch 2 F1
                            # peak_f1_mean_a = np.array([s for s in f1[0].values()]).mean()
                            # peak_f1_mean_b = np.array([s for s in f1[1].values()]).mean()
                            # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean_a, type=type, threshold=HP.THRESHOLD,
                            #                                         f1_per_bundle={"LenF1": peak_f1_mean_b})

                            #Single Bundle
                            # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"][0], type=type, threshold=HP.THRESHOLD,
                            #                                         f1_per_bundle={"Thr1": f1["CST_right"][1], "Thr2": f1["CST_right"][2]})
                            # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"], type=type, threshold=HP.THRESHOLD)
                        else:
                            metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=np.mean(f1), type=type, threshold=HP.THRESHOLD)

                    else:
                        metrics = MetricUtils.calculate_metrics_onlyLoss(metrics, loss, type=type)

                    metrics_time += time.time() - start_time_metrics

                    print_loss.append(loss)
                    if batch_nr[type] % HP.PRINT_FREQ == 0:
                        time_batch_part = time.time() - start_time_batch_part
                        start_time_batch_part = time.time()
                        ExpUtils.print_and_save(HP, "{} Ep {}, Sp {}, loss {}, t print {}s, t batch {}s".format(type, epoch_nr,
                                                                batch_nr[type] * HP.BATCH_SIZE,
                                                                round(np.array(print_loss).mean(), 6), round(time_batch_part, 3),
                                                                round(time_batch_part / HP.PRINT_FREQ, 3)))
                        print_loss = []

                    if HP.USE_VISLOGGER:
                        ExpUtils.plot_result_trixi(trixi, x, y, probs, loss, f1, epoch_nr)


            ###################################
            # Post Training tasks (each epoch)
            ###################################

            #Adapt LR
            if HP.LR_SCHEDULE:
                self.model.scheduler.step()
                # self.model.scheduler.step(np.mean(f1))
                self.model.print_current_lr()

            # Average loss per batch over entire epoch
            metrics = MetricUtils.normalize_last_element(metrics, batch_nr["train"], type="train")
            metrics = MetricUtils.normalize_last_element(metrics, batch_nr["validate"], type="validate")
            metrics = MetricUtils.normalize_last_element(metrics, batch_nr["test"], type="test")

            print("  Epoch {}, Average Epoch loss = {}".format(epoch_nr, metrics["loss_train"][-1]))
            print("  Epoch {}, nr_of_updates {}".format(epoch_nr, nr_of_updates))

            # Save Weights
            start_time_saving = time.time()
            if HP.SAVE_WEIGHTS:
                self.model.save_model(metrics, epoch_nr)
            saving_time += time.time() - start_time_saving

            # Create Plots
            start_time_plotting = time.time()
            pickle.dump(metrics, open(join(HP.EXP_PATH, "metrics.pkl"), "wb")) # wb -> write (override) and binary (binary only needed on windows, on unix also works without) # for loading: pickle.load(open("metrics.pkl", "rb"))
            ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME)
            ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME, without_first_epochs=True)
            plotting_time += time.time() - start_time_plotting

            epoch_time = time.time() - start_time
            epoch_times.append(epoch_time)

            ExpUtils.print_and_save(HP, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
            ExpUtils.print_and_save(HP, "  Epoch {}, time UNet: {}s".format(epoch_nr, network_time))
            ExpUtils.print_and_save(HP, "  Epoch {}, time metrics: {}s".format(epoch_nr, metrics_time))
            ExpUtils.print_and_save(HP, "  Epoch {}, time saving files: {}s".format(epoch_nr, saving_time))
            ExpUtils.print_and_save(HP, str(datetime.datetime.now()))

            # Adding next Epoch
            if epoch_nr < HP.NUM_EPOCHS-1:
                metrics = MetricUtils.add_empty_element(metrics)


        ####################################
        # After all epochs
        ###################################
        with open(join(HP.EXP_PATH, "Hyperparameters.txt"), "a") as f:  # a for append
            f.write("\n\n")
            f.write("Average Epoch time: {}s".format(sum(epoch_times) / float(len(epoch_times))))

        return metrics
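
The weight_factor logic above linearly anneals the loss weight from HP.LOSS_WEIGHT down to 1 over the first HP.LOSS_WEIGHT_LEN epochs. A minimal sketch of the same schedule as a standalone function (loss_weight and loss_weight_len are illustrative stand-ins for the HP fields):

def get_weight_factor(epoch_nr, loss_weight, loss_weight_len):
    if loss_weight_len == -1:  # -1 -> constant weighting
        return float(loss_weight)
    if epoch_nr < loss_weight_len:
        # linear decrease: epoch 0 -> loss_weight, epoch loss_weight_len -> 1
        return -((loss_weight - 1) / float(loss_weight_len)) * epoch_nr + float(loss_weight)
    return 1.0

# e.g. loss_weight=10, loss_weight_len=100: epoch 0 -> 10.0, epoch 50 -> 5.5, epoch >= 100 -> 1.0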
Example #3
    def generate_train_batch(self):
        subjects = self._data[0]
        subject_idx = int(
            random.uniform(0, len(subjects))
        )  # len(subjects)-1 not needed because int always rounds to floor

        for i in range(20):
            try:
                if self.HP.FEATURES_FILENAME == "12g90g270g":
                    # if np.random.random() < 0.5:
                    #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                    # else:
                    #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()

                    rnd_choice = np.random.random()
                    if rnd_choice < 0.33:
                        data = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "270g_125mm_peaks.nii.gz")).get_data()
                    elif rnd_choice < 0.66:
                        data = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "90g_125mm_peaks.nii.gz")).get_data()
                    else:
                        data = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "12g_125mm_peaks.nii.gz")).get_data()
                elif self.HP.FEATURES_FILENAME == "T1_Peaks270g":
                    peaks = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx],
                             "270g_125mm_peaks.nii.gz")).get_data()
                    t1 = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx], "T1.nii.gz")).get_data()
                    data = np.concatenate((peaks, t1), axis=3)
                elif self.HP.FEATURES_FILENAME == "T1_Peaks12g90g270g":
                    rnd_choice = np.random.random()
                    if rnd_choice < 0.33:
                        peaks = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "270g_125mm_peaks.nii.gz")).get_data()
                    elif rnd_choice < 0.66:
                        peaks = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "90g_125mm_peaks.nii.gz")).get_data()
                    else:
                        peaks = nib.load(
                            join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                                 subjects[subject_idx],
                                 "12g_125mm_peaks.nii.gz")).get_data()
                    t1 = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx], "T1.nii.gz")).get_data()
                    data = np.concatenate((peaks, t1), axis=3)
                else:
                    data = nib.load(
                        join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                             subjects[subject_idx], self.HP.FEATURES_FILENAME +
                             ".nii.gz")).get_data()

                seg = nib.load(
                    join(C.DATA_PATH, self.HP.DATASET_FOLDER,
                         subjects[subject_idx],
                         self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
                break
            except IOError:
                ExpUtils.print_and_save(
                    self.HP,
                    "\n\nWARNING: Could not load file. Trying again in 20s (Try number: "
                    + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
        # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

        data = np.nan_to_num(data)  # replace NaNs, otherwise later processing fails
        seg = np.nan_to_num(seg)

        data = DatasetUtils.scale_input_to_unet_shape(
            data, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, channels)

        if self.HP.LABELS_FILENAME not in [
                "bundle_peaks_11_808080", "bundle_peaks_20_808080",
                "bundle_peaks_808080", "bundle_masks_20_808080",
                "bundle_masks_72_808080"
        ]:
            if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
                # Passing "HCP" at a lower resolution makes scale_input_to_unet_shape downsample the HCP-sized seg mask to that resolution
                seg = DatasetUtils.scale_input_to_unet_shape(
                    seg, "HCP", self.HP.RESOLUTION)
            else:
                seg = DatasetUtils.scale_input_to_unet_shape(
                    seg, self.HP.DATASET,
                    self.HP.RESOLUTION)  # (x, y, z, classes)

        slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False,
                                      None)

        # Randomly sample slice orientation
        if self.HP.TRAINING_SLICE_DIRECTION == "xyz":
            # (note: round() makes direction 1 twice as likely as 0 or 2)
            slice_direction = int(round(random.uniform(0, 2)))
        else:
            slice_direction = 1  # always use y

        if slice_direction == 0:
            x = data[slice_idxs, :, :].astype(
                np.float32)  # (batch_size, y, z, channels)
            y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
            x = np.array(x).transpose(
                0, 3, 1, 2
            )  # depth-channel has to be before width and height for Unet (but after batches)
            y = np.array(y).transpose(
                0, 3, 1, 2
            )  # nr_classes channel has to come before width and height for DataAugmentation (bs, nr_of_classes, x, y)
        elif slice_direction == 1:
            x = data[:, slice_idxs, :].astype(
                np.float32)  # (x, batch_size, z, channels)
            y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
            x = np.array(x).transpose(1, 3, 0, 2)
            y = np.array(y).transpose(1, 3, 0, 2)
        elif slice_direction == 2:
            x = data[:, :, slice_idxs].astype(
                np.float32)  # (x, y, batch_size, channels)
            y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
            x = np.array(x).transpose(2, 3, 0, 1)
            y = np.array(y).transpose(2, 3, 0, 1)

        data_dict = {
            "data": x,  # (batch_size, channels, x, y, [z])
            "seg": y
        }  # (batch_size, channels, x, y, [z])
        return data_dict
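
The direction-dependent transposes at the end are the core trick: whichever axis was sliced becomes the batch axis, and the channel axis moves in front of the two remaining spatial axes. A small sketch that checks the resulting shapes on a dummy volume (the toy sizes are arbitrary):

import numpy as np

data = np.random.rand(10, 12, 14, 9)  # (x, y, z, channels)
slice_idxs = np.array([1, 4, 7])      # batch of 3 slices

x0 = data[slice_idxs, :, :].transpose(0, 3, 1, 2)  # slice along x
print(x0.shape)  # (3, 9, 12, 14) -> (bs, channels, y, z)

x1 = data[:, slice_idxs, :].transpose(1, 3, 0, 2)  # slice along y
print(x1.shape)  # (3, 9, 10, 14) -> (bs, channels, x, z)

x2 = data[:, :, slice_idxs].transpose(2, 3, 0, 1)  # slice along z
print(x2.shape)  # (3, 9, 10, 12) -> (bs, channels, x, y)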
Example #4
    def print_current_lr():
        for param_group in optimizer.param_groups:
            ExpUtils.print_and_save(
                self.HP,
                "current learning rate: {}".format(param_group['lr']))
Example #5
    def create_network(self):

        def train(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)
            loss = criterion(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None    #faster
            return loss.data[0], probs, f1

        def test(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            loss = criterion(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    # actually saved as a pickle, not an npz
                    PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
                except IOError:
                    print("\nERROR: Could not save weights because of IO Error\n")
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)


        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = 9
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = 3*self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = 33

        if torch.cuda.is_available():
            net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
            # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
        else:
            net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)
            # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)

        if self.HP.TRAIN:
            ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
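
The Variable(..., volatile=True) and loss.data[0] idioms above are from the pre-0.4 PyTorch API. On current PyTorch the test step would look roughly like the sketch below (assuming net, criterion and PytorchUtils.f1_score_macro as defined above, and that f1_score_macro accepts plain tensors):

import numpy as np
import torch

def test(X, y):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    X = torch.from_numpy(X.astype(np.float32)).to(device)
    y = torch.from_numpy(y.astype(np.float32)).to(device)
    net.eval()             # replaces net.train(False)
    with torch.no_grad():  # replaces volatile=True
        outputs = net(X)
        loss = criterion(outputs, y)
    f1 = PytorchUtils.f1_score_macro(y, outputs, per_class=True)
    return loss.item(), None, f1  # loss.item() replaces loss.data[0]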
Example #6
    def create_network(self):
        # torch.backends.cudnn.benchmark = True     #not faster

        def train(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()

            outputs, intermediate = net(X)  # forward     # outputs: (bs, classes, x, y)

            loss = criterion(outputs, y)
            # loss = PytorchUtils.soft_dice(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)

            if self.HP.USE_VISLOGGER:
                probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            else:
                probs = None    #faster

            return loss.data[0], probs, f1, intermediate

        def test(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            loss = criterion(outputs, y)
            # loss = PytorchUtils.soft_dice(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    # actually saved as a pickle, not an npz
                    PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
                except IOError:
                    print("\nERROR: Could not save weights because of IO Error\n")
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)

        def print_current_lr():
            for param_group in optimizer.param_groups:
                ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))


        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = 9
            # NR_OF_GRADIENTS = 9 * 5
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = 3*self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = 33

        if torch.cuda.is_available():
            net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
        else:
            net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)

        # net = nn.DataParallel(net, device_ids=[0,1])

        if self.HP.TRAIN:
            ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        criterion = nn.BCEWithLogitsLoss()
        optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
        # optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE)  #very slow (half speed of Adamax) -> strange
        # scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
        # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        #plot feature weights
        # weights = list(list(net.children())[0].children())[0].weight.cpu().data.numpy()   # sequential -> conv2d   # (64, 9, 3, 3)
        # weights = weights[:, 0:1, :, :]  # select one input channel to plot       # (64, 1, 3, 3)
        # weights = (weights*100).astype(np.uint8) # can not plot negative values (and if float only 0-1 allowed) -> not good: we remove negatives
        # plot_kernels(weights)

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
        self.print_current_lr = print_current_lr
        # self.scheduler = scheduler
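
The commented-out block near the end hints at visualizing the first conv layer's kernels with plot_kernels. A minimal sketch of such a helper using matplotlib (the grid layout and grayscale colormap are illustrative choices, not the project's actual plot_kernels):

import matplotlib.pyplot as plt
import numpy as np

def plot_kernels(weights):
    # weights: (n_kernels, 1, kh, kw), e.g. one input channel of the first conv
    n = weights.shape[0]
    cols = 8
    rows = int(np.ceil(n / float(cols)))
    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows))
    for i, ax in enumerate(axes.flat):
        ax.axis("off")
        if i < n:
            ax.imshow(weights[i, 0], cmap="gray")
    plt.show()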
Example #7
    def print_current_lr():
        for param_group in optimizer.param_groups:
            ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))
Example #8
    def create_network(self):
        def train(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda(
                ))  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)
            loss = criterion(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  #faster
            return loss.data[0], probs, f1

        def test(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(),
                                volatile=True), Variable(y.cuda(),
                                                         volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            loss = criterion(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(
                0, 2, 3, 1)  # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")
                                    ):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    # actually saved as a pickle, not an npz
                    PytorchUtils.save_checkpoint(join(
                        self.HP.EXP_PATH,
                        "best_weights_ep" + str(epoch_nr) + ".npz"),
                                                 unet=net)
                except IOError:
                    print(
                        "\nERROR: Could not save weights because of IO Error\n"
                    )
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)

        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = 9
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = 33

        if torch.cuda.is_available():
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT).cuda()
            # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
        else:
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT)
            # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)

        if self.HP.TRAIN:
            ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(
                self.HP, "Loading weights ... ({})".format(
                    join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
Example #9
    def train(self, HP):

        if HP.USE_VISLOGGER:
            nvl = Nvl(name="Training")

        ExpUtils.print_and_save(HP, socket.gethostname())

        epoch_times = []
        nr_of_updates = 0

        metrics = {}
        for type in ["train", "test", "validate"]:
            metrics_new = {
                "loss_" + type: [0],
                "f1_macro_" + type: [0],
            }
            metrics = dict(list(metrics.items()) + list(metrics_new.items()))

        for epoch_nr in range(HP.NUM_EPOCHS):
            start_time = time.time()
            # current_lr = HP.LEARNING_RATE * (HP.LR_DECAY ** epoch_nr)
            # current_lr = HP.LEARNING_RATE

            batch_gen_time = 0
            data_preparation_time = 0
            network_time = 0
            metrics_time = 0
            saving_time = 0
            plotting_time = 0

            batch_nr = {"train": 0, "test": 0, "validate": 0}

            if HP.LOSS_WEIGHT_LEN == -1:
                weight_factor = float(HP.LOSS_WEIGHT)
            else:
                if epoch_nr < HP.LOSS_WEIGHT_LEN:
                    # weight_factor = -(9./100.) * epoch_nr + 10.   #ep0: 10 -> linear decrease -> ep100: 1
                    weight_factor = -((HP.LOSS_WEIGHT - 1) / float(
                        HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
                    # weight_factor = -((HP.LOSS_WEIGHT-5)/float(HP.LOSS_WEIGHT_LEN)) * epoch_nr + float(HP.LOSS_WEIGHT)
                else:
                    weight_factor = 1.
                    # weight_factor = 5.

            for type in ["train", "test", "validate"]:
                print_loss = []
                start_time_batch_gen = time.time()

                batch_generator = self.dataManager.get_batches(
                    batch_size=HP.BATCH_SIZE,
                    type=type,
                    subjects=getattr(HP,
                                     type.upper() + "_SUBJECTS"))
                batch_gen_time = time.time() - start_time_batch_gen
                # print("batch_gen_time: {}s".format(batch_gen_time))

                print("Start looping batches...")
                start_time_batch_part = time.time()
                for batch in batch_generator:  # getting next batch takes around 0.14s -> second largest time share after the UNet!

                    start_time_data_preparation = time.time()
                    batch_nr[type] += 1

                    x = batch["data"]  # (bs, nr_of_channels, x, y)
                    y = batch["seg"]  # (bs, nr_of_classes, x, y)
                    # with the new BatchGenerator y is float instead of int -> fine for Pytorch but not for Lasagne
                    # y = y.astype(HP.LABELS_TYPE)  # for bundle_peaks regression y is already float -> saves 0.2s/batch if left out

                    data_preparation_time += time.time(
                    ) - start_time_data_preparation
                    # self.model.learning_rate.set_value(np.float32(current_lr))
                    start_time_network = time.time()
                    if type == "train":
                        nr_of_updates += 1
                        loss, probs, f1 = self.model.train(
                            x, y, weight_factor=weight_factor
                        )  # probs: # (bs, x, y, nrClasses)
                        # loss, probs, f1, intermediate = self.model.train(x, y)
                    elif type == "validate":
                        loss, probs, f1 = self.model.predict(
                            x, y, weight_factor=weight_factor)
                    elif type == "test":
                        loss, probs, f1 = self.model.predict(
                            x, y, weight_factor=weight_factor)
                    network_time += time.time() - start_time_network

                    start_time_metrics = time.time()

                    if HP.CALC_F1:
                        if HP.LABELS_TYPE == np.int16:
                            metrics = MetricUtils.calculate_metrics(
                                metrics,
                                None,
                                None,
                                loss,
                                f1=np.mean(f1),
                                type=type,
                                threshold=HP.THRESHOLD)

                        else:  #Regression
                            # The following two (commented) lines add ~30s to metrics_time (vs <1s without) and ~1.5s per batch:
                            # y_flat = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                            # y_flat = np.reshape(y_flat, (-1, y_flat.shape[-1]))  # (bs*x*y, nr_of_classes)
                            # metrics = MetricUtils.calculate_metrics(metrics, y_flat, probs, loss, f1=np.mean(f1), type=type, threshold=HP.THRESHOLD,
                            #                                         f1_per_bundle={"CA": f1[5], "FX_left": f1[23], "FX_right": f1[24]})

                            #Numpy
                            # y_right_order = y.transpose(0, 2, 3, 1)  # (bs, x, y, nr_of_classes)
                            # peak_f1 = MetricUtils.calc_peak_dice(HP, probs, y_right_order)
                            # peak_f1_mean = np.array([s for s in peak_f1.values()]).mean()

                            #Pytorch
                            peak_f1_mean = np.array([
                                s for s in list(f1.values())
                            ]).mean()  #if f1 for multiple bundles
                            metrics = MetricUtils.calculate_metrics(
                                metrics,
                                None,
                                None,
                                loss,
                                f1=peak_f1_mean,
                                type=type,
                                threshold=HP.THRESHOLD)

                            #Pytorch 2 F1
                            # peak_f1_mean_a = np.array([s for s in f1[0].values()]).mean()
                            # peak_f1_mean_b = np.array([s for s in f1[1].values()]).mean()
                            # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=peak_f1_mean_a, type=type, threshold=HP.THRESHOLD,
                            #                                         f1_per_bundle={"LenF1": peak_f1_mean_b})

                            #Single Bundle
                            # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"][0], type=type, threshold=HP.THRESHOLD,
                            #                                         f1_per_bundle={"Thr1": f1["CST_right"][1], "Thr2": f1["CST_right"][2]})
                            # metrics = MetricUtils.calculate_metrics(metrics, None, None, loss, f1=f1["CST_right"], type=type, threshold=HP.THRESHOLD)
                    else:
                        metrics = MetricUtils.calculate_metrics_onlyLoss(
                            metrics, loss, type=type)

                    metrics_time += time.time() - start_time_metrics

                    print_loss.append(loss)
                    if batch_nr[type] % HP.PRINT_FREQ == 0:
                        time_batch_part = time.time() - start_time_batch_part
                        start_time_batch_part = time.time()
                        ExpUtils.print_and_save(
                            HP,
                            "{} Ep {}, Sp {}, loss {}, t print {}s, t batch {}s"
                            .format(type, epoch_nr,
                                    batch_nr[type] * HP.BATCH_SIZE,
                                    round(np.array(print_loss).mean(), 6),
                                    round(time_batch_part, 3),
                                    round(time_batch_part / HP.PRINT_FREQ, 3)))
                        print_loss = []

                    if HP.USE_VISLOGGER:
                        x_norm = (x - x.min()) / (x.max() - x.min())
                        nvl.show_images(
                            x_norm[0:1, :, :, :].transpose((1, 0, 2, 3)),
                            name="input batch",
                            title="Input batch")  # all channels of one sample
                        probs_shaped = probs[:, :, :, 15:16].transpose(
                            (0, 3, 1, 2))  # (bs, 1, x, y)
                        probs_shaped_bin = (probs_shaped > 0.5).astype(
                            np.int16)
                        nvl.show_images(probs_shaped,
                                        name="predictions",
                                        title="Predictions Probmap")
                        # nvl.show_images(probs_shaped_bin, name="predictions_binary", title="Predictions Binary")

                        # Show GT and Prediction in one image  (bundle: CST)
                        # GREEN: GT; RED: prediction (FP); YELLOW: prediction (TP)
                        combined = np.zeros(
                            (y.shape[0], 3, y.shape[2], y.shape[3]))
                        combined[:, 0:1, :, :] = probs_shaped_bin  #Red
                        combined[:, 1:2, :, :] = y[:, 15:16, :, :]  #Green
                        nvl.show_images(combined,
                                        name="predictions_combined",
                                        title="Combined")

                        #Show feature activations
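                        # NOTE: `intermediate` is only returned by the commented-out
                        # train() variant above; with the default train() call used
                        # here it would be undefined (NameError).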
                        contr_1_2 = intermediate[2].data.cpu().numpy(
                        )  # (bs, nr_feature_channels=64, x, y)
                        contr_1_2 = contr_1_2[0:1, :, :, :].transpose(
                            (1, 0, 2, 3))  # (nr_feature_channels=64, 1, x, y)
                        contr_1_2 = (contr_1_2 - contr_1_2.min()) / (
                            contr_1_2.max() - contr_1_2.min())
                        nvl.show_images(contr_1_2,
                                        name="contr_1_2",
                                        title="contr_1_2")

                        # Show feature activations
                        contr_3_2 = intermediate[1].data.cpu().numpy(
                        )  # (bs, nr_feature_channels=64, x, y)
                        contr_3_2 = contr_3_2[0:1, :, :, :].transpose(
                            (1, 0, 2, 3))  # (nr_feature_channels=64, 1, x, y)
                        contr_3_2 = (contr_3_2 - contr_3_2.min()) / (
                            contr_3_2.max() - contr_3_2.min())
                        nvl.show_images(contr_3_2,
                                        name="contr_3_2",
                                        title="contr_3_2")

                        # Show feature activations
                        deconv_2 = intermediate[0].data.cpu().numpy(
                        )  # (bs, nr_feature_channels=64, x, y)
                        deconv_2 = deconv_2[0:1, :, :, :].transpose(
                            (1, 0, 2, 3))  # (nr_feature_channels=64, 1, x, y)
                        deconv_2 = (deconv_2 - deconv_2.min()) / (
                            deconv_2.max() - deconv_2.min())
                        nvl.show_images(deconv_2,
                                        name="deconv_2",
                                        title="deconv_2")

                        nvl.show_value(float(loss), name="loss")
                        nvl.show_value(float(np.mean(f1)), name="f1")

            ###################################
            # Post Training tasks (each epoch)
            ###################################

            #Adapt LR
            # self.model.scheduler.step()
            # self.model.scheduler.step(np.mean(f1))
            # self.model.print_current_lr()

            # Average loss per batch over entire epoch
            metrics = MetricUtils.normalize_last_element(metrics,
                                                         batch_nr["train"],
                                                         type="train")
            metrics = MetricUtils.normalize_last_element(metrics,
                                                         batch_nr["validate"],
                                                         type="validate")
            metrics = MetricUtils.normalize_last_element(metrics,
                                                         batch_nr["test"],
                                                         type="test")

            print("  Epoch {}, Average Epoch loss = {}".format(
                epoch_nr, metrics["loss_train"][-1]))
            print("  Epoch {}, nr_of_updates {}".format(
                epoch_nr, nr_of_updates))

            # Save Weights
            start_time_saving = time.time()
            if HP.SAVE_WEIGHTS:
                self.model.save_model(metrics, epoch_nr)
            saving_time += time.time() - start_time_saving

            # Create Plots
            start_time_plotting = time.time()
            pickle.dump(
                metrics, open(join(HP.EXP_PATH, "metrics.pkl"), "wb")
            )  # wb -> write (override) and binary (binary only needed on windows, on unix also works without) # for loading: pickle.load(open("metrics.pkl", "rb"))
            ExpUtils.create_exp_plot(metrics, HP.EXP_PATH, HP.EXP_NAME)
            ExpUtils.create_exp_plot(metrics,
                                     HP.EXP_PATH,
                                     HP.EXP_NAME,
                                     without_first_epochs=True)
            plotting_time += time.time() - start_time_plotting

            epoch_time = time.time() - start_time
            epoch_times.append(epoch_time)

            ExpUtils.print_and_save(
                HP, "  Epoch {}, time total {}s".format(epoch_nr, epoch_time))
            ExpUtils.print_and_save(
                HP,
                "  Epoch {}, time UNet: {}s".format(epoch_nr, network_time))
            ExpUtils.print_and_save(
                HP,
                "  Epoch {}, time metrics: {}s".format(epoch_nr, metrics_time))
            ExpUtils.print_and_save(
                HP, "  Epoch {}, time saving files: {}s".format(
                    epoch_nr, saving_time))
            ExpUtils.print_and_save(HP, str(datetime.datetime.now()))

            # Adding next Epoch
            if epoch_nr < HP.NUM_EPOCHS - 1:
                metrics = MetricUtils.add_empty_element(metrics)

        ####################################
        # After all epochs
        ###################################
        with open(join(HP.EXP_PATH, "Hyperparameters.txt"),
                  "a") as f:  # a for append
            f.write("\n\n")
            f.write("Average Epoch time: {}s".format(
                sum(epoch_times) / float(len(epoch_times))))

        return metrics
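
The combined image built for the vislogger encodes agreement by color: the prediction goes into the red channel and the ground truth into the green channel, so overlap shows up yellow (TP), red alone is a false positive and green alone a false negative. The same convention on dummy binary masks:

import numpy as np

pred = np.random.rand(64, 64) > 0.5  # dummy binary prediction
gt = np.random.rand(64, 64) > 0.5    # dummy ground truth

combined = np.zeros((3, 64, 64))
combined[0] = pred  # red channel: prediction
combined[1] = gt    # green channel: ground truth
# red + green = yellow wherever prediction and GT agree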
Example #10
    def generate_train_batch(self):
        subjects = self._data[0]
        subject_idx = int(random.uniform(0, len(subjects)))     # len(subjects)-1 not needed because int always rounds to floor

        for i in range(20):
            try:
                if np.random.random() < 0.5:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                else:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()

                # rnd_choice = np.random.random()
                # if rnd_choice < 0.33:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                # elif rnd_choice < 0.66:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
                # else:
                #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()

                seg = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
                break
            except IOError:
                ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
        # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

        data = np.nan_to_num(data)    # replace NaNs, otherwise later processing fails
        seg = np.nan_to_num(seg)

        data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)    # (x, y, z, channels)
        if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
            # Passing "HCP" at a lower resolution makes scale_input_to_unet_shape downsample the HCP-sized seg mask to that resolution
            seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
        else:
            seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

        slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, False, None)

        # Randomly sample slice orientation
        # (note: round() makes direction 1 twice as likely as 0 or 2)
        slice_direction = int(round(random.uniform(0, 2)))

        if slice_direction == 0:
            y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
            y = np.array(y).transpose(0, 3, 1, 2)  # nr_classes channel has to come before width and height for DataAugmentation (bs, nr_of_classes, x, y)
        elif slice_direction == 1:
            y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
            y = np.array(y).transpose(1, 3, 0, 2)
        elif slice_direction == 2:
            y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
            y = np.array(y).transpose(2, 3, 0, 1)


        sw = 5 #slice_window (only odd numbers allowed)
        pad = int((sw-1) / 2)

        data_pad = np.zeros((data.shape[0]+sw-1, data.shape[1]+sw-1, data.shape[2]+sw-1, data.shape[3])).astype(data.dtype)
        data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data   #padded with two slices of zeros on all sides
        batch=[]
        for s_idx in slice_idxs:
            if slice_direction == 0:
                #(s_idx+2)-2:(s_idx+2)+3 = s_idx:s_idx+5
                x = data_pad[s_idx:s_idx+sw, pad:-pad, pad:-pad, :].astype(np.float32)      # (5, y, z, channels)
                x = np.array(x).transpose(0, 3, 1, 2)  # channels dim has to be before width and height for Unet (but after batches)
                x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (5*channels, y, z)
                batch.append(x)
            elif slice_direction == 1:
                x = data_pad[pad:-pad, s_idx:s_idx+sw, pad:-pad, :].astype(np.float32)  # (x, sw, z, channels)
                x = np.array(x).transpose(1, 3, 0, 2)  # channels dim has to be before width and height for Unet (but after batches)
                x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (sw*channels, x, z)
                batch.append(x)
            elif slice_direction == 2:
                x = data_pad[pad:-pad, pad:-pad, s_idx:s_idx+sw, :].astype(np.float32)  # (x, y, sw, channels)
                x = np.array(x).transpose(2, 3, 0, 1)  # channels dim has to be before width and height for Unet (but after batches)
                x = np.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))  # (sw*channels, x, y)
                batch.append(x)
        data_dict = {"data": np.array(batch),     # (batch_size, sw*channels, h, w)
                     "seg": y}                    # (batch_size, nr_of_classes, h, w)

        return data_dict
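For reference, a minimal self-contained NumPy sketch of the slice-window trick above (toy shapes and variable names, not taken from the project): a window of sw neighbouring slices is folded into the channel dimension so a 2D network still sees some through-plane context, and the symmetric zero padding guarantees a full window even at the volume border.

import numpy as np

sw = 5                     # slice window, must be odd
pad = (sw - 1) // 2        # zero slices added on each side

data = np.random.rand(10, 12, 14, 9).astype(np.float32)     # (x, y, z, channels)

data_pad = np.zeros((data.shape[0] + sw - 1, data.shape[1] + sw - 1,
                     data.shape[2] + sw - 1, data.shape[3]), dtype=data.dtype)
data_pad[pad:-pad, pad:-pad, pad:-pad, :] = data

s_idx = 0                                                    # even the first slice gets a full window
x = data_pad[s_idx:s_idx + sw, pad:-pad, pad:-pad, :]        # (sw, y, z, channels)
x = x.transpose(0, 3, 1, 2)                                  # (sw, channels, y, z)
x = x.reshape(sw * data.shape[3], data.shape[1], data.shape[2])   # (sw*channels, y, z)
print(x.shape)                                               # (45, 12, 14)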
Example #11
    def generate_train_batch(self):
        subjects = self._data[0]
        subject_idx = int(random.uniform(0, len(subjects)))     # int() rounds down, so len(subjects)-1 is not needed

        for i in range(20):
            try:
                if self.HP.FEATURES_FILENAME == "12g90g270g":
                    # if np.random.random() < 0.5:
                    #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                    # else:
                    #     data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()

                    rnd_choice = np.random.random()
                    if rnd_choice < 0.33:
                        data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                    elif rnd_choice < 0.66:
                        data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
                    else:
                        data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()
                elif self.HP.FEATURES_FILENAME == "T1_Peaks270g":
                    peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                    t1 = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "T1.nii.gz")).get_data()
                    data = np.concatenate((peaks, t1), axis=3)
                elif self.HP.FEATURES_FILENAME == "T1_Peaks12g90g270g":
                    rnd_choice = np.random.random()
                    if rnd_choice < 0.33:
                        peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "270g_125mm_peaks.nii.gz")).get_data()
                    elif rnd_choice < 0.66:
                        peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "90g_125mm_peaks.nii.gz")).get_data()
                    else:
                        peaks = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "12g_125mm_peaks.nii.gz")).get_data()
                    t1 = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], "T1.nii.gz")).get_data()
                    data = np.concatenate((peaks, t1), axis=3)
                else:
                    data = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], self.HP.FEATURES_FILENAME + ".nii.gz")).get_data()

                seg = nib.load(join(C.DATA_PATH, self.HP.DATASET_FOLDER, subjects[subject_idx], self.HP.LABELS_FILENAME + ".nii.gz")).get_data()
                break
            except IOError:
                ExpUtils.print_and_save(self.HP, "\n\nWARNING: Could not load file. Trying again in 20s (Try number: " + str(i) + ").\n\n")
            ExpUtils.print_and_save(self.HP, "Sleeping 20s")
            sleep(20)
        # ExpUtils.print_and_save(self.HP, "Successfully loaded input.")

        data = np.nan_to_num(data)    # replace NaNs -- later scaling/augmentation steps fail on NaN values
        seg = np.nan_to_num(seg)

        data = DatasetUtils.scale_input_to_unet_shape(data, self.HP.DATASET, self.HP.RESOLUTION)    # (x, y, z, channels)

        if self.HP.LABELS_FILENAME not in ["bundle_peaks_11_808080", "bundle_peaks_20_808080", "bundle_peaks_808080",
                                           "bundle_masks_20_808080", "bundle_masks_72_808080"]:
            if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
                # Passing "HCP" at a lower resolution makes scale_input_to_unet_shape downsample the HCP-sized seg mask to that resolution
                seg = DatasetUtils.scale_input_to_unet_shape(seg, "HCP", self.HP.RESOLUTION)
            else:
                seg = DatasetUtils.scale_input_to_unet_shape(seg, self.HP.DATASET, self.HP.RESOLUTION)  # (x, y, z, classes)

        slice_idxs = np.random.choice(data.shape[0], self.BATCH_SIZE, replace=False)

        # Randomly sample the slice orientation
        if self.HP.TRAINING_SLICE_DIRECTION == "xyz":
            slice_direction = int(round(random.uniform(0, 2)))  # NOTE: round() makes direction 1 twice as likely as 0 or 2
        else:
            slice_direction = 1  # always use Y

        if slice_direction == 0:
            x = data[slice_idxs, :, :].astype(np.float32)      # (batch_size, y, z, channels)
            y = seg[slice_idxs, :, :].astype(self.HP.LABELS_TYPE)
            x = np.array(x).transpose(0, 3, 1, 2)  # channels dim has to be before width and height for Unet (but after the batch dim)
            y = np.array(y).transpose(0, 3, 1, 2)  # class channel has to be before width and height for DataAugmentation (bs, nr_of_classes, x, y)
        elif slice_direction == 1:
            x = data[:, slice_idxs, :].astype(np.float32)      # (x, batch_size, z, channels)
            y = seg[:, slice_idxs, :].astype(self.HP.LABELS_TYPE)
            x = np.array(x).transpose(1, 3, 0, 2)
            y = np.array(y).transpose(1, 3, 0, 2)
        elif slice_direction == 2:
            x = data[:, :, slice_idxs].astype(np.float32)      # (x, y, batch_size, channels)
            y = seg[:, :, slice_idxs].astype(self.HP.LABELS_TYPE)
            x = np.array(x).transpose(2, 3, 0, 1)
            y = np.array(y).transpose(2, 3, 0, 1)

        data_dict = {"data": x,     # (batch_size, channels, h, w)
                     "seg": y}      # (batch_size, nr_of_classes, h, w)
        return data_dict
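The three transpose patterns above all normalize to the same layout; a minimal sketch (toy shapes, illustrative only) shows that whichever axis the batch of slices is taken from, the result comes out as (batch_size, channels, height, width):

import numpy as np

data = np.random.rand(8, 9, 10, 4)            # (x, y, z, channels)
slice_idxs = np.array([1, 3, 5])              # batch_size = 3

x0 = data[slice_idxs, :, :].transpose(0, 3, 1, 2)   # direction 0 -> (3, 4, 9, 10)
x1 = data[:, slice_idxs, :].transpose(1, 3, 0, 2)   # direction 1 -> (3, 4, 8, 10)
x2 = data[:, :, slice_idxs].transpose(2, 3, 0, 1)   # direction 2 -> (3, 4, 8, 9)
print(x0.shape, x1.shape, x2.shape)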
Example #12
    def create_network(self):
        # torch.backends.cudnn.benchmark = True     #not faster

        def train(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)
            loss = criterion(outputs, y)
            # loss = PytorchUtils.soft_dice(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)

            if self.HP.USE_VISLOGGER:
                probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
            else:
                probs = None  # faster (skips the GPU->CPU copy)

            return loss.data[0], probs, f1  # loss.data[0] is the pre-0.4 PyTorch idiom for loss.item()

        def test(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            loss = criterion(outputs, y)
            # loss = PytorchUtils.soft_dice(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # only save on a new best epoch; writing to a network drive takes ~5s (vs ~0.5s locally)
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    # despite the .npz extension this is actually a pickle file
                    PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
                except IOError:
                    print("\nERROR: Could not save weights because of an IO error\n")
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)

        def print_current_lr():
            for param_group in optimizer.param_groups:
                ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))

        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = 9
            # NR_OF_GRADIENTS = 9 * 5
            # NR_OF_GRADIENTS = 9 * 9
            # NR_OF_GRADIENTS = 33
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = 33

        if torch.cuda.is_available():
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT).cuda()
        else:
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT)

        # net = nn.DataParallel(net, device_ids=[0,1])

        if self.HP.TRAIN:
            ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        # weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
        # weights[:, 5, :, :] *= 10     #CA
        # weights[:, 21, :, :] *= 10    #FX_left
        # weights[:, 22, :, :] *= 10    #FX_right
        # criterion = nn.BCEWithLogitsLoss(weight=weights)
        criterion = nn.BCEWithLogitsLoss()

        optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
        # optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE)  # oddly runs at about half the speed of Adamax
        # scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
        # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(
                self.HP, "Loading weights ... ({})".format(
                    join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        self.train = train
        self.predict = test         # note: "predict" is bound to the test() closure (loss + f1)
        self.get_probs = predict    # "get_probs" is bound to predict() (probabilities only)
        self.save_model = save_model
        self.load_model = load_model
        self.print_current_lr = print_current_lr
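create_network() wires everything up as closures: train/test/predict capture net, optimizer and criterion from the enclosing scope and are then attached to the model object, so callers only ever see plain functions. A minimal sketch of the same pattern (hypothetical stand-in model and shapes; modern PyTorch calls replace the legacy Variable/volatile API shown above):

import numpy as np
import torch
import torch.nn as nn

net = nn.Conv2d(9, 72, kernel_size=1)         # toy stand-in for the UNet (9 input channels, 72 classes)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adamax(net.parameters(), lr=0.001)

def train(X, y):
    # the closure captures net/optimizer/criterion from the enclosing scope
    X = torch.from_numpy(X.astype(np.float32))
    y = torch.from_numpy(y.astype(np.float32))
    optimizer.zero_grad()
    net.train()
    outputs = net(X)                          # (bs, classes, x, y)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()
    return loss.item()

X = np.random.rand(2, 9, 16, 16)
y = (np.random.rand(2, 72, 16, 16) > 0.5).astype(np.float32)
print(train(X, y))                            # scalar BCE loss for one step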