Example #1
0
 def load_model(path):
     ExpUtils.print_verbose(self.HP,
                            "Loading weights ... ({})".format(path))
     with np.load(path) as f:  # if both paths are absolute and share the same prefix, join will merge that prefix
         param_values = [f['arr_%d' % i] for i in range(len(f.files))]
     L.layers.set_all_param_values(output_layer_for_loss, param_values)
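The snippet above restores Lasagne parameters stored as numbered arrays in an .npz file. A minimal sketch of the matching save side, under the assumption that output_layer_for_loss is the network's final Lasagne layer (the helper name save_weights is hypothetical):

import numpy as np
import lasagne as L

def save_weights(path, output_layer_for_loss):
    # get_all_param_values returns one numpy array per parameter
    param_values = L.layers.get_all_param_values(output_layer_for_loss)
    # np.savez stores them as arr_0, arr_1, ... which is exactly what the loading loop above reads back
    np.savez(path, *param_values)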
Example #2
0
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False):
    '''
    Run TractSeg

    :param data: input peaks (4D numpy array with shape [x,y,z,9])
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks"
    :param verbose: show debugging info
    :param dropout_sampling: create an uncertainty map via Monte Carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: threshold for converting the probability map to a binary map
    :param bundle_specific_threshold: use a lower threshold for bundles that need more sensitivity (CA, CST, FX)
    :param get_probs: output the raw probability map instead of a binary map
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
    '''
    start_time = time.time()

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    if bundle_specific_threshold:
        HP.GET_PROBS = True

    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep126.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v2.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DAugAll", "best_weights_ep16.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  #more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3*len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    data, seg_None, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE == "tract_segmentation" or HP.EXPERIMENT_TYPE == "endings_segmentation" or HP.EXPERIMENT_TYPE == "dm_regression":
        if single_orientation:     # mainly needed for testing because it needs less RAM
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
            else:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False, only_prediction=True)
        else:
            seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=0.3)  # set lower for more sensitivity
        #3 dir for Peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    # remove the following two lines to keep the super-resolution output
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)
    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))

    return seg
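A minimal usage sketch for run_tractseg, assuming a peaks NIfTI file (the file names here are hypothetical):

import nibabel as nib
import numpy as np

peaks_img = nib.load("peaks.nii.gz")                # assumed input: 4D peaks volume with shape [x, y, z, 9]
peaks = peaks_img.get_fdata().astype(np.float32)

seg = run_tractseg(peaks, output_type="tract_segmentation", threshold=0.5)   # [x, y, z, nr_of_bundles]

nib.save(nib.Nifti1Image(seg.astype(np.uint8), peaks_img.affine), "bundle_segmentations.nii.gz")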
Example #3
    def create_network(self):

        def train(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)
            loss = criterion(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None    #faster
            return loss.data[0], probs, f1

        def test(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            loss = criterion(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    #Actually is a pkl not a npz
                    PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
                except IOError:
                    print("\nERROR: Could not save weights because of IO Error\n")
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)


        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = 9
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = 3*self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = 33

        if torch.cuda.is_available():
            net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
            # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
        else:
            net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)
            # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)

        if self.HP.TRAIN:
            ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
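create_network only wires up closures and binds them to the model instance; a rough sketch of calling them with dummy data (shapes follow the comments in the code above; the random inputs and the assumption that BaseModel's constructor calls create_network are purely illustrative):

import numpy as np

model = BaseModel(HP)    # assumed to call create_network() internally, as in the examples above
X = np.random.randn(HP.BATCH_SIZE, 9, HP.INPUT_DIM[0], HP.INPUT_DIM[1]).astype(np.float32)
y = (np.random.rand(HP.BATCH_SIZE, HP.NR_OF_CLASSES, HP.INPUT_DIM[0], HP.INPUT_DIM[1]) > 0.5).astype(np.float32)

loss, probs, f1 = model.train(X, y)          # one optimisation step
val_loss, _, val_f1 = model.predict(X, y)    # note: self.predict is bound to the test() closure
probs = model.get_probs(X)                   # (bs, x, y, classes)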
Example #4
0
    def create_network(self):
        # torch.backends.cudnn.benchmark = True     #not faster

        def train(X, y, weight_factor=10):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)
            loss = criterion(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)

            if self.HP.USE_VISLOGGER:
                probs = outputs.data.cpu().numpy().transpose(
                    0, 2, 3, 1)  # (bs, x, y, classes)
            else:
                probs = None  #faster

            return loss.data[0], probs, f1

        def test(X, y, weight_factor=10):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            loss = criterion(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(
                0, 2, 3, 1)  # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    #Actually is a pkl not a npz
                    PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
                except IOError:
                    print("\nERROR: Could not save weights because of IO Error\n")
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)

        def print_current_lr():
            for param_group in optimizer.param_groups:
                ExpUtils.print_and_save(
                    self.HP,
                    "current learning rate: {}".format(param_group['lr']))

        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS
            # NR_OF_GRADIENTS = 9
            # NR_OF_GRADIENTS = 9 * 5
            # NR_OF_GRADIENTS = 9 * 9
            # NR_OF_GRADIENTS = 33
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            self.HP.NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
        else:
            self.HP.NR_OF_GRADIENTS = 33

        if self.HP.LOSS_FUNCTION == "soft_sample_dice":
            criterion = PytorchUtils.soft_sample_dice
            final_activation = "sigmoid"
        elif self.HP.LOSS_FUNCTION == "soft_batch_dice":
            criterion = PytorchUtils.soft_batch_dice
            final_activation = "sigmoid"
        else:
            # weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
            # weights[:, 5, :, :] *= 10     #CA
            # weights[:, 21, :, :] *= 10    #FX_left
            # weights[:, 22, :, :] *= 10    #FX_right
            # criterion = nn.BCEWithLogitsLoss(weight=weights)
            criterion = nn.BCEWithLogitsLoss()
            final_activation = None

        net = UNet(n_input_channels=NR_OF_GRADIENTS,
                   n_classes=self.HP.NR_OF_CLASSES,
                   n_filt=self.HP.UNET_NR_FILT,
                   batchnorm=self.HP.BATCH_NORM,
                   final_activation=final_activation)

        if torch.cuda.is_available():
            net = net.cuda()
        # else:
        #     net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT,
        #                batchnorm=self.HP.BATCH_NORM)

        # net = nn.DataParallel(net, device_ids=[0,1])

        # if self.HP.TRAIN:
        #     ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        if self.HP.OPTIMIZER == "Adamax":
            optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
        elif self.HP.OPTIMIZER == "Adam":
            #todo important: change
            # optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE)
            optimizer = Adam(net.parameters(),
                             lr=self.HP.LEARNING_RATE,
                             weight_decay=self.HP.WEIGHT_DECAY)
        else:
            raise ValueError("Optimizer not defined")
        # scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
        # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(
                self.HP, "Loading weights ... ({})".format(
                    join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
        self.print_current_lr = print_current_lr
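PytorchUtils.soft_sample_dice / soft_batch_dice are referenced above but not shown; this is a generic soft-Dice sketch of the kind of criterion those names suggest, not the project's exact implementation:

import torch

def soft_dice_loss(probs, targets, eps=1e-6):
    # probs, targets: (bs, classes, x, y); probs are assumed to already be sigmoid outputs
    dims = (2, 3)
    intersection = (probs * targets).sum(dims)
    denom = probs.sum(dims) + targets.sum(dims)
    dice = (2.0 * intersection + eps) / (denom + eps)   # per sample and per class
    return 1.0 - dice.mean()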
Example #5
    def create_network(self):
        # torch.backends.cudnn.benchmark = True     #not faster

        def train(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()

            outputs, intermediate = net(X)  # forward     # outputs: (bs, classes, x, y)

            loss = criterion(outputs, y)
            # loss = PytorchUtils.soft_dice(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)

            if self.HP.USE_VISLOGGER:
                probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            else:
                probs = None    #faster

            return loss.data[0], probs, f1, intermediate

        def test(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            loss = criterion(outputs, y)
            # loss = PytorchUtils.soft_dice(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    #Actually is a pkl not a npz
                    PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
                except IOError:
                    print("\nERROR: Could not save weights because of IO Error\n")
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)

        def print_current_lr():
            for param_group in optimizer.param_groups:
                ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))


        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = 9
            # NR_OF_GRADIENTS = 9 * 5
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = 3*self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = 33

        if torch.cuda.is_available():
            net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
        else:
            net = UNet(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)

        # net = nn.DataParallel(net, device_ids=[0,1])

        if self.HP.TRAIN:
            ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        criterion = nn.BCEWithLogitsLoss()
        optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
        # optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE)  #very slow (half speed of Adamax) -> strange
        # scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
        # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        #plot feature weights
        # weights = list(list(net.children())[0].children())[0].weight.cpu().data.numpy()   # sequential -> conv2d   # (64, 9, 3, 3)
        # weights = weights[:, 0:1, :, :]  # select one input channel to plot       # (64, 1, 3, 3)
        # weights = (weights*100).astype(np.uint8) # can not plot negative values (and if float only 0-1 allowed) -> not good: we remove negatives
        # plot_kernels(weights)

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
        self.print_current_lr = print_current_lr
        # self.scheduler = scheduler
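The commented-out block above refers to a plot_kernels helper for visualising first-layer convolution filters; a minimal matplotlib sketch of what such a helper could look like (hypothetical, not the project's implementation):

import matplotlib.pyplot as plt
import numpy as np

def plot_kernels(weights, cols=8):
    # weights: (n_filters, 1, k, k), e.g. the sliced first-layer conv weights from the comment above
    n = weights.shape[0]
    rows = int(np.ceil(n / float(cols)))
    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows))
    for i, ax in enumerate(np.array(axes).flat):
        if i < n:
            ax.imshow(weights[i, 0], cmap="gray")
        ax.axis("off")
    plt.show()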
Example #6
0
def run_tractseg(data, output_type="tract_segmentation", input_type="peaks",
                 single_orientation=False, verbose=False, dropout_sampling=False, threshold=0.5,
                 bundle_specific_threshold=False, get_probs=False, peak_threshold=0.1):
    '''
    Run TractSeg

    :param data: input peaks (4D numpy array with shape [x,y,z,9])
    :param output_type: "tract_segmentation" | "endings_segmentation" | "TOM" | "dm_regression"
    :param input_type: "peaks"
    :param verbose: show debugging info
    :param dropout_sampling: create an uncertainty map via Monte Carlo dropout (https://arxiv.org/abs/1506.02142)
    :param threshold: threshold for converting the probability map to a binary map
    :param bundle_specific_threshold: use a lower threshold for bundles that need more sensitivity (CA, CST, FX)
    :param get_probs: output the raw probability map instead of a binary map
    :param peak_threshold: all peaks shorter than peak_threshold will be set to zero
    :return: 4D numpy array with the output of tractseg
        for tract_segmentation:     [x,y,z,nr_of_bundles]
        for endings_segmentation:   [x,y,z,2*nr_of_bundles]
        for TOM:                    [x,y,z,3*nr_of_bundles]
    '''
    start_time = time.time()

    config = get_config_name(input_type, output_type)
    HP = getattr(importlib.import_module("tractseg.config.PretrainedModels." + config), "HP")()
    HP.VERBOSE = verbose
    HP.TRAIN = False
    HP.TEST = False
    HP.SEGMENT = False
    HP.GET_PROBS = get_probs
    HP.LOAD_WEIGHTS = True
    HP.DROPOUT_SAMPLING = dropout_sampling
    HP.THRESHOLD = threshold

    if bundle_specific_threshold:
        HP.GET_PROBS = True

    if input_type == "peaks":
        if HP.EXPERIMENT_TYPE == "tract_segmentation" and HP.DROPOUT_SAMPLING:
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_dropout_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DAugAll_Dropout", "best_weights_ep114.npz")
        elif HP.EXPERIMENT_TYPE == "tract_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/TractSeg_T1_12g90g270g_125mm_DAugAll", "best_weights_ep392.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888", "best_weights_ep247.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg72_888_SchizoFineT_lr001", "best_weights_ep186.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "TractSeg_12g90g270g_125mm_DS_DAugAll_RotMir", "best_weights_ep200.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v3.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "EndingsSeg_12g90g270g_125mm_DS_DAugAll", "best_weights_ep234.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "x_Pretrained_TractSeg_Models/Peaks20_12g90g270g_125mm_DAugSimp_constW5", "best_weights_ep441.npz")  #more oversegmentation with DAug
        elif HP.EXPERIMENT_TYPE == "dm_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_dm_regression_v1.npz")
            # HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes", "DmReg_12g90g270g_125mm_DAugAll_Ubuntu", "best_weights_ep80.npz")
    elif input_type == "T1":
        if HP.EXPERIMENT_TYPE == "tract_segmentation":
            # HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_tract_segmentation_v1.npz")
            HP.WEIGHTS_PATH = join(C.NETWORK_DRIVE, "hcp_exp_nodes/x_Pretrained_TractSeg_Models", "TractSeg_T1_125mm_DAugAll", "best_weights_ep142.npz")
        elif HP.EXPERIMENT_TYPE == "endings_segmentation":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_endings_segmentation_v1.npz")
        elif HP.EXPERIMENT_TYPE == "peak_regression":
            HP.WEIGHTS_PATH = join(C.TRACT_SEG_HOME, "pretrained_weights_peak_regression_v1.npz")
    print("Loading weights from: {}".format(HP.WEIGHTS_PATH))

    if HP.EXPERIMENT_TYPE == "peak_regression":
        HP.NR_OF_CLASSES = 3*len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])
    else:
        HP.NR_OF_CLASSES = len(ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    if HP.VERBOSE:
        print("Hyperparameters:")
        ExpUtils.print_HPs(HP)

    Utils.download_pretrained_weights(experiment_type=HP.EXPERIMENT_TYPE, dropout_sampling=HP.DROPOUT_SAMPLING)

    data = np.nan_to_num(data)
    # brain_mask = ImgUtils.simple_brain_mask(data)
    # if HP.VERBOSE:
    #     nib.save(nib.Nifti1Image(brain_mask, np.eye(4)), "otsu_brain_mask_DEBUG.nii.gz")

    if input_type == "T1":
        data = np.reshape(data, (data.shape[0], data.shape[1], data.shape[2], 1))
    data, seg_None, bbox, original_shape = DatasetUtils.crop_to_nonzero(data)
    data, transformation = DatasetUtils.pad_and_scale_img_to_square_img(data, target_size=HP.INPUT_DIM[0])

    model = BaseModel(HP)

    if HP.EXPERIMENT_TYPE == "tract_segmentation" or HP.EXPERIMENT_TYPE == "endings_segmentation" or HP.EXPERIMENT_TYPE == "dm_regression":
        if single_orientation:     # mainly needed for testing because it needs less RAM
            dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
            trainerSingle = Trainer(model, dataManagerSingle)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
            else:
                seg, img_y = trainerSingle.get_seg_single_img(HP, probs=False, scale_to_world_shape=False, only_prediction=True)
        else:
            seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
            if HP.DROPOUT_SAMPLING or HP.EXPERIMENT_TYPE == "dm_regression" or HP.GET_PROBS:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)
            else:
                seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=False)

    elif HP.EXPERIMENT_TYPE == "peak_regression":
        dataManagerSingle = DataManagerSingleSubjectByFile(HP, data=data)
        trainerSingle = Trainer(model, dataManagerSingle)
        seg, img_y = trainerSingle.get_seg_single_img(HP, probs=True, scale_to_world_shape=False, only_prediction=True)
        if bundle_specific_threshold:
            seg = ImgUtils.remove_small_peaks_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:], len_thr=0.3)
        else:
            seg = ImgUtils.remove_small_peaks(seg, len_thr=peak_threshold)
        #3 dir for Peaks -> not working (?)
        # seg_xyz, gt = DirectionMerger.get_seg_single_img_3_directions(HP, model, data=data, scale_to_world_shape=False, only_prediction=True)
        # seg = DirectionMerger.mean_fusion(HP.THRESHOLD, seg_xyz, probs=True)

    if bundle_specific_threshold and HP.EXPERIMENT_TYPE == "tract_segmentation":
        seg = ImgUtils.probs_to_binary_bundle_specific(seg, ExpUtils.get_bundle_names(HP.CLASSES)[1:])

    # remove the following two lines to keep the super-resolution output
    seg = DatasetUtils.cut_and_scale_img_back_to_original_img(seg, transformation)
    seg = DatasetUtils.add_original_zero_padding_again(seg, bbox, original_shape, HP.NR_OF_CLASSES)
    ExpUtils.print_verbose(HP, "Took {}s".format(round(time.time() - start_time, 2)))

    return seg
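ImgUtils.remove_small_peaks is not shown here; a sketch of the underlying idea (zero out peaks whose vector length falls below peak_threshold), assuming peaks are stored as consecutive (x, y, z) triples along the last axis:

import numpy as np

def remove_small_peaks_sketch(peaks, len_thr=0.1):
    # peaks: (x, y, z, 3*n_peaks); every consecutive triple along the last axis is one peak vector
    out = peaks.copy()
    for i in range(out.shape[3] // 3):
        vec = out[..., i * 3:(i + 1) * 3]      # view into out
        length = np.linalg.norm(vec, axis=-1)
        vec[length < len_thr] = 0              # zero out peaks shorter than the threshold
    return out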
Example #7
0
    def create_network(self):
        def train(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)
            loss = criterion(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  #faster
            return loss.data[0], probs, f1

        def test(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            loss = criterion(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(
                0, 2, 3, 1)  # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    #Actually is a pkl not a npz
                    PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
                except IOError:
                    print("\nERROR: Could not save weights because of IO Error\n")
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)

        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = 9
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = 33

        if torch.cuda.is_available():
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT).cuda()
            # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT).cuda()
        else:
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT)
            # net = UNet_Skip(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT)

        if self.HP.TRAIN:
            ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(
                self.HP, "Loading weights ... ({})".format(
                    join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
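PytorchUtils.f1_score_macro is used throughout but not shown; a generic per-class (macro) F1 sketch on thresholded outputs, which is only an assumption about its behaviour:

import torch

def f1_score_macro_sketch(y_true, y_pred, threshold=0.5, eps=1e-7):
    # y_true, y_pred: (bs, classes, x, y); returns one F1 value per class
    pred = (torch.sigmoid(y_pred) > threshold).float()
    true = (y_true > 0.5).float()
    dims = (0, 2, 3)
    tp = (pred * true).sum(dims)
    fp = (pred * (1 - true)).sum(dims)
    fn = ((1 - pred) * true).sum(dims)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return 2 * precision * recall / (precision + recall + eps)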
Example #8
0
    def create_network(self):
        # torch.backends.cudnn.benchmark = True     #not faster

        def train(X, y, weight_factor=10):
            X = torch.tensor(X, dtype=torch.float32).to(device)   # X: (bs, features, x, y)   y: (bs, classes, x, y)
            y = torch.tensor(y, dtype=torch.float32).to(device)

            optimizer.zero_grad()
            net.train()
            outputs, outputs_sigmoid = net(X)  # forward     # outputs: (bs, classes, x, y)

            if weight_factor > 1:
                # weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
                weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, y.shape[2], y.shape[3])).to(device)  # keep weights on the same device as y
                bundle_mask = y > 0
                weights[bundle_mask.data] *= weight_factor  # 10
                if self.HP.EXPERIMENT_TYPE == "peak_regression":
                    loss = criterion(outputs, y, weights)
                else:
                    loss = nn.BCEWithLogitsLoss(weight=weights)(outputs, y)
            else:
                if self.HP.LOSS_FUNCTION == "soft_sample_dice" or self.HP.LOSS_FUNCTION == "soft_batch_dice":
                    loss = criterion(outputs_sigmoid, y)
                    # loss = criterion(outputs_sigmoid, y) + nn.BCEWithLogitsLoss()(outputs, y)
                else:
                    loss = criterion(outputs, y)

            loss.backward()  # backward
            optimizer.step()  # optimise

            if self.HP.EXPERIMENT_TYPE == "peak_regression":
                # f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
                # f1_a = MetricUtils.calc_peak_dice_pytorch(self.HP, outputs.data, y.data, max_angle_error=self.HP.PEAK_DICE_THR)
                f1 = MetricUtils.calc_peak_length_dice_pytorch(self.HP, outputs.detach(), y.detach(),
                                                               max_angle_error=self.HP.PEAK_DICE_THR, max_length_error=self.HP.PEAK_DICE_LEN_THR)
                # f1 = (f1_a, f1_b)
            elif self.HP.EXPERIMENT_TYPE == "dm_regression":   #density map regression
                f1 = PytorchUtils.f1_score_macro(y.detach()>0.5, outputs.detach(), per_class=True)
            else:
                f1 = PytorchUtils.f1_score_macro(y.detach(), outputs_sigmoid.detach(), per_class=True, threshold=self.HP.THRESHOLD)

            if self.HP.USE_VISLOGGER:
                # probs = outputs_sigmoid.detach().cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
                probs = outputs_sigmoid
            else:
                probs = None    #faster

            return loss.item(), probs, f1


        def test(X, y, weight_factor=10):
            with torch.no_grad():
                X = torch.tensor(X, dtype=torch.float32).to(device)
                y = torch.tensor(y, dtype=torch.float32).to(device)

            if self.HP.DROPOUT_SAMPLING:
                net.train()
            else:
                net.train(False)
            outputs, outputs_sigmoid = net(X)  # forward

            if weight_factor > 1:
                # weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
                weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, y.shape[2], y.shape[3])).to(device)  # keep weights on the same device as y
                bundle_mask = y > 0
                weights[bundle_mask.data] *= weight_factor  # 10
                if self.HP.EXPERIMENT_TYPE == "peak_regression":
                    loss = criterion(outputs, y, weights)
                else:
                    loss = nn.BCEWithLogitsLoss(weight=weights)(outputs, y)
            else:
                if self.HP.LOSS_FUNCTION == "soft_sample_dice" or self.HP.LOSS_FUNCTION == "soft_batch_dice":
                    loss = criterion(outputs_sigmoid, y)
                    # loss = criterion(outputs_sigmoid, y) + nn.BCEWithLogitsLoss()(outputs, y)
                else:
                    loss = criterion(outputs, y)

            if self.HP.EXPERIMENT_TYPE == "peak_regression":
                # f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
                # f1_a = MetricUtils.calc_peak_dice_pytorch(self.HP, outputs.data, y.data, max_angle_error=self.HP.PEAK_DICE_THR)
                f1 = MetricUtils.calc_peak_length_dice_pytorch(self.HP, outputs.detach(), y.detach(),
                                                               max_angle_error=self.HP.PEAK_DICE_THR, max_length_error=self.HP.PEAK_DICE_LEN_THR)
                # f1 = (f1_a, f1_b)
            elif self.HP.EXPERIMENT_TYPE == "dm_regression":   #density map regression
                f1 = PytorchUtils.f1_score_macro(y.detach()>0.5, outputs.detach(), per_class=True)
            else:
                f1 = PytorchUtils.f1_score_macro(y.detach(), outputs_sigmoid.detach(), per_class=True, threshold=self.HP.THRESHOLD)

            if self.HP.USE_VISLOGGER:
                # probs = outputs_sigmoid.detach().cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
                probs = outputs_sigmoid
            else:
                probs = None  # faster

            return loss.item(), probs, f1


        def predict(X):
            with torch.no_grad():
                X = torch.tensor(X, dtype=torch.float32).to(device)

            if self.HP.DROPOUT_SAMPLING:
                net.train()
            else:
                net.train(False)
            outputs, outputs_sigmoid = net(X)  # forward
            if self.HP.EXPERIMENT_TYPE == "peak_regression" or self.HP.EXPERIMENT_TYPE == "dm_regression":
                probs = outputs.detach().cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            else:
                probs = outputs_sigmoid.detach().cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
            return probs


        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    #Actually is a pkl not a npz
                    PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
                except IOError:
                    print("\nERROR: Could not save weights because of IO Error\n")
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)

        def print_current_lr():
            for param_group in optimizer.param_groups:
                ExpUtils.print_and_save(self.HP, "current learning rate: {}".format(param_group['lr']))


        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS
            # NR_OF_GRADIENTS = 9
            # NR_OF_GRADIENTS = 9 * 5
            # NR_OF_GRADIENTS = 9 * 9
            # NR_OF_GRADIENTS = 33
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            self.HP.NR_OF_GRADIENTS = 3*self.HP.NR_OF_CLASSES
        else:
            self.HP.NR_OF_GRADIENTS = 33

        if self.HP.LOSS_FUNCTION == "soft_sample_dice":
            criterion = PytorchUtils.soft_sample_dice
        elif self.HP.LOSS_FUNCTION == "soft_batch_dice":
            criterion = PytorchUtils.soft_batch_dice
        elif self.HP.EXPERIMENT_TYPE == "peak_regression":
            criterion = PytorchUtils.angle_length_loss
        else:
            # weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
            # weights[:, 5, :, :] *= 10     #CA
            # weights[:, 21, :, :] *= 10    #FX_left
            # weights[:, 22, :, :] *= 10    #FX_right
            # criterion = nn.BCEWithLogitsLoss(weight=weights)
            criterion = nn.BCEWithLogitsLoss()

        NetworkClass = getattr(importlib.import_module("tractseg.models." + self.HP.MODEL), self.HP.MODEL)
        net = NetworkClass(n_input_channels=NR_OF_GRADIENTS, n_classes=self.HP.NR_OF_CLASSES, n_filt=self.HP.UNET_NR_FILT,
                   batchnorm=self.HP.BATCH_NORM, dropout=self.HP.USE_DROPOUT)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = net.to(device)

        # if self.HP.TRAIN:
        #     ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        if self.HP.OPTIMIZER == "Adamax":
            optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
        elif self.HP.OPTIMIZER == "Adam":
            optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE)
            # optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE, weight_decay=self.HP.WEIGHT_DECAY)
        else:
            raise ValueError("Optimizer not defined")

        if self.HP.LR_SCHEDULE:
            scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
            # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")
            self.scheduler = scheduler

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(self.HP, "Loading weights ... ({})".format(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        if self.HP.RESET_LAST_LAYER:
            # net.conv_5 = conv2d(self.HP.UNET_NR_FILT, self.HP.NR_OF_CLASSES, kernel_size=1, stride=1, padding=0, bias=True).to(device)
            net.conv_5 = nn.Conv2d(self.HP.UNET_NR_FILT, self.HP.NR_OF_CLASSES, kernel_size=1, stride=1, padding=0, bias=True).to(device)

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
        self.print_current_lr = print_current_lr
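With HP.DROPOUT_SAMPLING the network is kept in train() mode at inference time, so dropout stays active; a sketch of how several stochastic forward passes could then be reduced to a mean probability map and a per-voxel spread (the number of samples is an arbitrary assumption; model is a BaseModel instance as in the examples above):

import numpy as np

def mc_dropout_probs(model, X, n_samples=30):
    # model.get_probs is bound to the predict() closure above; with dropout active,
    # every call gives a slightly different probability map
    samples = np.stack([model.get_probs(X) for _ in range(n_samples)], axis=0)
    return samples.mean(axis=0), samples.std(axis=0)   # mean map and uncertainty estimate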
Example #9
0
 def __init__(self, HP, data):
     self.data = data
     self.HP = HP
     ExpUtils.print_verbose(self.HP, "Loading data from PREDICT_IMG input file")
Example #10
0
    def create_network(self):
        # torch.backends.cudnn.benchmark = True     #not faster

        def train(X, y, weight_factor=10):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)

            weights = torch.ones(
                (self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES,
                 self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
            bundle_mask = y > 0
            weights[bundle_mask.data] *= weight_factor  #10

            loss = criterion(outputs, y, Variable(weights))
            # loss = criterion1(outputs, y, Variable(weights)) + criterion2(outputs, y, Variable(weights))

            loss.backward()  # backward
            optimizer.step()  # optimise

            if self.HP.CALC_F1:
                # f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
                # f1_a = MetricUtils.calc_peak_dice_pytorch(self.HP, outputs.data, y.data, max_angle_error=self.HP.PEAK_DICE_THR)
                f1 = MetricUtils.calc_peak_length_dice_pytorch(
                    self.HP,
                    outputs.data,
                    y.data,
                    max_angle_error=self.HP.PEAK_DICE_THR,
                    max_length_error=self.HP.PEAK_DICE_LEN_THR)
                # f1 = (f1_a, f1_b)
            else:
                f1 = np.ones(outputs.shape[3])

            if self.HP.USE_VISLOGGER:
                probs = outputs.data.cpu().numpy().transpose(
                    0, 2, 3, 1)  # (bs, x, y, classes)
            else:
                # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)  # (bs, x, y, classes)
                probs = None  #faster

            return loss.data[0], probs, f1

        def test(X, y, weight_factor=10):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward

            weights = torch.ones(
                (self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES,
                 self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
            bundle_mask = y > 0
            weights[bundle_mask.data] *= weight_factor  #10

            loss = criterion(outputs, y, Variable(weights))
            # loss = criterion1(outputs, y, Variable(weights)) + criterion2(outputs, y, Variable(weights))

            if self.HP.CALC_F1:
                # f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
                # f1_a = MetricUtils.calc_peak_dice_pytorch(self.HP, outputs.data, y.data, max_angle_error=self.HP.PEAK_DICE_THR)
                f1 = MetricUtils.calc_peak_length_dice_pytorch(
                    self.HP,
                    outputs.data,
                    y.data,
                    max_angle_error=self.HP.PEAK_DICE_THR,
                    max_length_error=self.HP.PEAK_DICE_LEN_THR)
                # f1 = (f1_a, f1_b)
            else:
                f1 = np.ones(outputs.shape[3])

            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(
                0, 2, 3, 1)  # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    #Actually is a pkl not a npz
                    PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
                except IOError:
                    print("\nERROR: Could not save weights because of IO Error\n")
                self.HP.BEST_EPOCH = epoch_nr

            #Saving Last Epoch:
            # print("  Saving weights last epoch...")
            # for fl in glob.glob(join(self.HP.EXP_PATH, "weights_ep*")):  # remove weights from previous epochs
            #     os.remove(fl)
            # try:
            #     # Actually is a pkl not a npz
            #     PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "weights_ep" + str(epoch_nr) + ".npz"), unet=net)
            # except IOError:
            #     print("\nERROR: Could not save weights because of IO Error\n")
            # self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)

        def print_current_lr():
            for param_group in optimizer.param_groups:
                ExpUtils.print_and_save(
                    self.HP,
                    "current learning rate: {}".format(param_group['lr']))

        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = self.HP.NR_OF_GRADIENTS
            # NR_OF_GRADIENTS = 9 * 5
            # NR_OF_GRADIENTS = 9 * 9
            # NR_OF_GRADIENTS = 33
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = 33

        if torch.cuda.is_available():
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT).cuda()
        else:
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT)

        # if self.HP.TRAIN:
        #     ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        # criterion1 = PytorchUtils.MSE_weighted
        # criterion2 = PytorchUtils.angle_loss

        # criterion = PytorchUtils.MSE_weighted
        # criterion = PytorchUtils.angle_loss
        criterion = PytorchUtils.angle_length_loss

        optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(
                self.HP, "Loading weights ... ({})".format(
                    join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
        self.print_current_lr = print_current_lr
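save_model above only writes weights when the current epoch holds the best "f1_macro_validate" value so far (and that value exceeds 0.01); a small sketch of the metrics bookkeeping it relies on, assuming model is a BaseModel instance and using placeholder validation scores:

metrics = {"f1_macro_validate": []}
for epoch_nr, val_f1 in enumerate([0.3, 0.5, 0.4]):    # placeholder validation scores
    metrics["f1_macro_validate"].append(val_f1)
    model.save_model(metrics, epoch_nr)                # saves at epochs 0 and 1, skips epoch 2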
Example #11
0
 def __init__(self, HP, data):
     self.data = data
     self.HP = HP
     ExpUtils.print_verbose(self.HP,
                            "Loading data from PREDICT_IMG input file")
Example #12
    def create_network(self):
        # torch.backends.cudnn.benchmark = True     #not faster

        def train(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)
            loss = criterion(outputs, y)
            # loss = PytorchUtils.soft_dice(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)

            if self.HP.USE_VISLOGGER:
                probs = outputs.data.cpu().numpy().transpose(
                    0, 2, 3, 1)  # (bs, x, y, classes)
            else:
                probs = None  #faster

            return loss.data[0], probs, f1

        def test(X, y):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            loss = criterion(outputs, y)
            # loss = PytorchUtils.soft_dice(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1

        def predict(X):
            X = torch.from_numpy(X.astype(np.float32))
            if torch.cuda.is_available():
                X = Variable(X.cuda(), volatile=True)
            else:
                X = Variable(X, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            probs = outputs.data.cpu().numpy().transpose(
                0, 2, 3, 1)  # (bs, x, y, classes)
            return probs

        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    #Actually is a pkl not a npz
                    PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
                except IOError:
                    print("\nERROR: Could not save weights because of IO Error\n")
                self.HP.BEST_EPOCH = epoch_nr

        def load_model(path):
            PytorchUtils.load_checkpoint(path, unet=net)

        def print_current_lr():
            for param_group in optimizer.param_groups:
                ExpUtils.print_and_save(
                    self.HP,
                    "current learning rate: {}".format(param_group['lr']))

        if self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "single_direction":
            NR_OF_GRADIENTS = 9
            # NR_OF_GRADIENTS = 9 * 5
            # NR_OF_GRADIENTS = 9 * 9
            # NR_OF_GRADIENTS = 33
        elif self.HP.SEG_INPUT == "Peaks" and self.HP.TYPE == "combined":
            NR_OF_GRADIENTS = 3 * self.HP.NR_OF_CLASSES
        else:
            NR_OF_GRADIENTS = 33

        if torch.cuda.is_available():
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT).cuda()
        else:
            net = UNet(n_input_channels=NR_OF_GRADIENTS,
                       n_classes=self.HP.NR_OF_CLASSES,
                       n_filt=self.HP.UNET_NR_FILT)

        #Initialisation from U-Net Paper
        def weights_init(m):
            classname = m.__class__.__name__
            # Do not use with batchnorm -> has to be adapted for batchnorm
            if classname.find('Conv') != -1:
                N = m.in_channels * m.kernel_size[0] * m.kernel_size[0]
                std = math.sqrt(2. / N)
                m.weight.data.normal_(0.0, std)

        net.apply(weights_init)

        # net = nn.DataParallel(net, device_ids=[0,1])

        if self.HP.TRAIN:
            ExpUtils.print_and_save(self.HP, str(net), only_log=True)

        criterion = nn.BCEWithLogitsLoss()
        optimizer = Adamax(net.parameters(), lr=self.HP.LEARNING_RATE)
        # optimizer = Adam(net.parameters(), lr=self.HP.LEARNING_RATE)  #very slow (half speed of Adamax) -> strange
        # scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
        # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")

        if self.HP.LOAD_WEIGHTS:
            ExpUtils.print_verbose(
                self.HP, "Loading weights ... ({})".format(
                    join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH)))
            load_model(join(self.HP.EXP_PATH, self.HP.WEIGHTS_PATH))

        self.train = train
        self.predict = test
        self.get_probs = predict
        self.save_model = save_model
        self.load_model = load_model
        self.print_current_lr = print_current_lr
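weights_init above draws Conv weights from N(0, sqrt(2/N)) with N the fan-in, i.e. He initialisation; an equivalent sketch using PyTorch's built-in initialiser (a modern-API alternative, not what this snippet uses):

import torch.nn as nn

def weights_init_kaiming(m):
    if isinstance(m, nn.Conv2d):
        # fan_in + relu gives std = sqrt(2 / (in_channels * k * k)), matching weights_init above
        nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")

# net.apply(weights_init_kaiming)   # would replace net.apply(weights_init)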