Code example #1
        def save_model(metrics, epoch_nr):
            max_f1_idx = np.argmax(metrics["f1_macro_validate"])
            max_f1 = np.max(metrics["f1_macro_validate"])
            if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
                print("  Saving weights...")
                for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
                    os.remove(fl)
                try:
                    # Actually a pkl, not an npz
                    PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
                except IOError:
                    print("\nERROR: Could not save weights because of IO Error\n")
                self.HP.BEST_EPOCH = epoch_nr

            #Saving Last Epoch:
            # print("  Saving weights last epoch...")
            # for fl in glob.glob(join(self.HP.EXP_PATH, "weights_ep*")):  # remove weights from previous epochs
            #     os.remove(fl)
            # try:
            #     # Actually is a pkl not a npz
            #     PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "weights_ep" + str(epoch_nr) + ".npz"), unet=net)
            # except IOError:
            #     print("\nERROR: Could not save weights because of IO Error\n")

            self.HP.BEST_EPOCH = epoch_nr
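The function above only writes a checkpoint when the current epoch holds the running maximum of the validation F1 (and that maximum is above a small sanity threshold), which keeps slow network-drive writes rare. A minimal, self-contained sketch of just that selection logic; the metrics dict and values are made up for illustration:

    import numpy as np

    def is_best_epoch(metrics, epoch_nr, min_f1=0.01):
        # True only if this epoch is the argmax of the validation F1 history
        # and the best value exceeds a small threshold (skips the random-init phase).
        scores = np.asarray(metrics["f1_macro_validate"])
        return epoch_nr == int(np.argmax(scores)) and float(np.max(scores)) > min_f1

    metrics = {"f1_macro_validate": [0.12, 0.35, 0.31]}
    print(is_best_epoch(metrics, 1))  # True: epoch 1 is the best so far
    print(is_best_epoch(metrics, 2))  # False: epoch 2 did not improve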
Code example #2
        def train(X, y, weight_factor=10):
            X = torch.tensor(X, dtype=torch.float32).to(device)  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            y = torch.tensor(y, dtype=torch.float32).to(device)

            optimizer.zero_grad()
            net.train()
            outputs, outputs_sigmoid = net(X)  # forward     # outputs: (bs, classes, x, y)

            if weight_factor > 1:
                # weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
                weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, y.shape[2], y.shape[3])).cuda()
                bundle_mask = y > 0
                weights[bundle_mask.data] *= weight_factor  # 10
                if self.HP.EXPERIMENT_TYPE == "peak_regression":
                    loss = criterion(outputs, y, weights)
                else:
                    loss = nn.BCEWithLogitsLoss(weight=weights)(outputs, y)
            else:
                if self.HP.LOSS_FUNCTION == "soft_sample_dice" or self.HP.LOSS_FUNCTION == "soft_batch_dice":
                    loss = criterion(outputs_sigmoid, y)
                    # loss = criterion(outputs_sigmoid, y) + nn.BCEWithLogitsLoss()(outputs, y)
                else:
                    loss = criterion(outputs, y)

            loss.backward()  # backward
            optimizer.step()  # optimise

            if self.HP.EXPERIMENT_TYPE == "peak_regression":
                # f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
                # f1_a = MetricUtils.calc_peak_dice_pytorch(self.HP, outputs.data, y.data, max_angle_error=self.HP.PEAK_DICE_THR)
                f1 = MetricUtils.calc_peak_length_dice_pytorch(
                    self.HP,
                    outputs.detach(),
                    y.detach(),
                    max_angle_error=self.HP.PEAK_DICE_THR,
                    max_length_error=self.HP.PEAK_DICE_LEN_THR)
                # f1 = (f1_a, f1_b)
            elif self.HP.EXPERIMENT_TYPE == "dm_regression":  #density map regression
                f1 = PytorchUtils.f1_score_macro(y.detach() > 0.5,
                                                 outputs.detach(),
                                                 per_class=True)
            else:
                f1 = PytorchUtils.f1_score_macro(y.detach(),
                                                 outputs_sigmoid.detach(),
                                                 per_class=True,
                                                 threshold=self.HP.THRESHOLD)

            if self.HP.USE_VISLOGGER:
                # probs = outputs_sigmoid.detach().cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
                probs = outputs_sigmoid
            else:
                probs = None  #faster

            return loss.item(), probs, f1
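The weight_factor branch above upweights every voxel that belongs to a bundle before the binary cross-entropy, which counteracts the strong foreground/background imbalance of tract masks. A minimal sketch of that weighting on toy tensors (shapes and values are illustrative, not TractSeg's real dimensions):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    logits = torch.randn(1, 2, 4, 4)                 # (bs, classes, x, y)
    y = (torch.rand(1, 2, 4, 4) > 0.8).float()       # sparse foreground labels

    weights = torch.ones_like(y)
    weights[y > 0] *= 10                             # upweight foreground voxels

    plain = nn.BCEWithLogitsLoss()(logits, y)
    weighted = nn.BCEWithLogitsLoss(weight=weights)(logits, y)
    print(plain.item(), weighted.item())             # foreground errors now dominate the loss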
Code example #3
File: UNet_Pytorch_SE.py  Project: doctoryfx/TractSeg
 def save_model(metrics, epoch_nr):
     max_f1_idx = np.argmax(metrics["f1_macro_validate"])
     max_f1 = np.max(metrics["f1_macro_validate"])
     if epoch_nr == max_f1_idx and max_f1 > 0.01:  # saving to network drives takes 5s (to local only 0.5s) -> do not save so often
         print("  Saving weights...")
         for fl in glob.glob(join(self.HP.EXP_PATH, "best_weights_ep*")):  # remove weights from previous epochs
             os.remove(fl)
         try:
             #Actually is a pkl not a npz
             PytorchUtils.save_checkpoint(join(self.HP.EXP_PATH, "best_weights_ep" + str(epoch_nr) + ".npz"), unet=net)
         except IOError:
             print("\nERROR: Could not save weights because of IO Error\n")
         self.HP.BEST_EPOCH = epoch_nr
Code example #4
        def train(X, y, weight_factor=10):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)

            weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
            bundle_mask = y > 0
            weights[bundle_mask.data] *= weight_factor  # 10

            criterion = nn.BCEWithLogitsLoss(weight=weights)
            loss = criterion(outputs, y)

            # loss = PytorchUtils.soft_dice(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)

            if self.HP.USE_VISLOGGER:
                probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            else:
                probs = None    #faster

            return loss.data[0], probs, f1
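Code examples #4 through #8 and #10 through #14 use the pre-0.4 PyTorch idioms (torch.autograd.Variable, bare .cuda(), loss.data[0]), whereas code example #2 is the updated variant. A short sketch of the modern equivalents, with a toy array standing in for the real batch:

    import numpy as np
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    X_np = np.random.rand(2, 3, 8, 8).astype(np.float32)

    # Old: X = Variable(torch.from_numpy(X_np).cuda())
    X = torch.from_numpy(X_np).to(device)   # Variable and Tensor were merged in PyTorch 0.4

    loss = (X * 2).sum()
    # Old: loss.data[0]   New: loss.item()
    print(loss.item())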
Code example #5
File: UNet_Pytorch.py  Project: shmp0722/TractSeg
        def train(X, y, weight_factor=10):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)
            loss = criterion(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)

            if self.HP.USE_VISLOGGER:
                probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
            else:
                probs = None  #faster

            return loss.data[0], probs, f1
Code example #6
        def train(X, y, weight_factor=10):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)
            loss = criterion(outputs, y)

            loss.backward()  # backward
            optimizer.step()  # optimise

            if self.HP.CALC_F1:
                f1 = PytorchUtils.f1_score_macro(y.data > self.HP.THRESHOLD, outputs.data, per_class=True, threshold=self.HP.THRESHOLD)
            else:
                f1 = np.ones(outputs.shape[3])

            if self.HP.USE_VISLOGGER:
                probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            else:
                # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)  # (bs, x, y, classes)
                probs = None    #faster

            return loss.data[0], probs, f1
Code example #7
        def test(X, y, weight_factor=10):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward

            weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
            bundle_mask = y > 0
            weights[bundle_mask.data] *= weight_factor  # 10

            criterion = nn.BCEWithLogitsLoss(weight=weights)
            loss = criterion(outputs, y)

            # loss = criterion(outputs, y)

            # loss = PytorchUtils.soft_dice(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1
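volatile=True was the pre-0.4 way to disable gradient tracking at test time; in current PyTorch the same effect comes from net.eval() plus a torch.no_grad() block. A minimal sketch (the conv layer is only a placeholder standing in for the U-Net):

    import torch
    import torch.nn as nn

    net = nn.Conv2d(3, 2, kernel_size=3, padding=1)   # placeholder model
    X = torch.randn(1, 3, 8, 8)

    net.eval()                    # equivalent of net.train(False)
    with torch.no_grad():         # equivalent of volatile=True
        outputs = net(X)
    print(outputs.shape)          # torch.Size([1, 2, 8, 8])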
Code example #8
        def train(X, y, weight_factor=10):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
            else:
                X, y = Variable(X), Variable(y)
            optimizer.zero_grad()
            net.train()
            outputs = net(X)  # forward     # outputs: (bs, classes, x, y)

            weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
            bundle_mask = y > 0
            weights[bundle_mask.data] *= weight_factor  # 10

            criterion = nn.BCEWithLogitsLoss(weight=weights)
            loss = criterion(outputs, y)

            # loss = PytorchUtils.soft_dice(outputs, y)
            loss.backward()  # backward
            optimizer.step()  # optimise
            f1 = PytorchUtils.f1_score_macro(y.data,
                                             outputs.data,
                                             per_class=True)

            if self.HP.USE_VISLOGGER:
                probs = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (bs, x, y, classes)
            else:
                probs = None  #faster

            return loss.data[0], probs, f1
Code example #9
File: MetricUtils.py  Project: doctoryfx/TractSeg
    def calc_peak_length_dice_pytorch(HP, y_pred, y_true, max_angle_error=[0.9], max_length_error=0.1):
        '''
        Ca

        :param y_pred:
        :param y_true:
        :param max_angle_error:  0.7 ->  angle error of 45° or less; 0.9 ->  angle error of 23° or less
                                 Can be list with several values -> calculate for several thresholds
        :return:
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Calculate the angle between two nd-arrays (array of vectors) along the last dimension

            without anything further: 1->0°, 0.9->23°, 0.7->45°, 0->90°
            np.arccos returns the angle in radians (90° corresponds to 0.5*pi)

            return: one dimension less than input
            '''
            return torch.abs(einsum('abcd,abcd->abc', a, b) / (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

        #Single threshold
        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            # if bundle == "CST_right":
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()      # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)

            lengths_pred = torch.norm(y_pred_bund, 2, -1)
            lengths_true = torch.norm(y_true_bund, 2, -1)
            lengths_binary = torch.abs(lengths_pred - lengths_true) < (max_length_error * lengths_true)
            lengths_binary = lengths_binary.view(-1)

            gt_binary = y_true_bund.sum(dim=-1) > 0
            gt_binary = gt_binary.view(-1)  # [bs*x*y]

            angles_binary = angles > max_angle_error[0]
            angles_binary = angles_binary.view(-1)

            combined = lengths_binary * angles_binary

            f1 = PytorchUtils.f1_score_binary(gt_binary, combined)
            score_per_bundle[bundle] = f1
        return score_per_bundle
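angle_last_dim returns the absolute cosine of the angle between corresponding peak vectors (1 for parallel, about 0.7 for 45°, 0 for orthogonal), and that value is thresholded directly rather than converted to degrees. A small numpy check of the same formula on toy vectors:

    import numpy as np

    def abs_cosine_last_dim(a, b, eps=1e-7):
        # |a . b| / (||a|| * ||b||), computed along the last axis
        dot = np.sum(a * b, axis=-1)
        return np.abs(dot) / (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + eps)

    a = np.array([[1.0, 0.0, 0.0], [1.0, 1.0, 0.0]])
    b = np.array([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
    print(abs_cosine_last_dim(a, b))   # ~[1.0, 0.707]: parallel vs. a 45° angle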
Code example #10
File: UNet_Pytorch_SE.py  Project: doctoryfx/TractSeg
 def test(X, y):
     X = torch.from_numpy(X.astype(np.float32))
     y = torch.from_numpy(y.astype(np.float32))
     if torch.cuda.is_available():
         X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
     else:
         X, y = Variable(X, volatile=True), Variable(y, volatile=True)
     net.train(False)
     outputs = net(X)  # forward
     loss = criterion(outputs, y)
     f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
     # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
     probs = None  # faster
     return loss.data[0], probs, f1
Code example #11
File: UNet_Pytorch_SE.py  Project: doctoryfx/TractSeg
 def train(X, y):
     X = torch.from_numpy(X.astype(np.float32))
     y = torch.from_numpy(y.astype(np.float32))
     if torch.cuda.is_available():
         X, y = Variable(X.cuda()), Variable(y.cuda())  # X: (bs, features, x, y)   y: (bs, classes, x, y)
     else:
         X, y = Variable(X), Variable(y)
     optimizer.zero_grad()
     net.train()
     outputs = net(X)  # forward     # outputs: (bs, classes, x, y)
     loss = criterion(outputs, y)
     loss.backward()  # backward
     optimizer.step()  # optimise
     f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
     # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
     probs = None    #faster
     return loss.data[0], probs, f1
Code example #12
File: UNet_Pytorch.py  Project: shmp0722/TractSeg
 def test(X, y, weight_factor=10):
     X = torch.from_numpy(X.astype(np.float32))
     y = torch.from_numpy(y.astype(np.float32))
     if torch.cuda.is_available():
         X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
     else:
         X, y = Variable(X, volatile=True), Variable(y, volatile=True)
     net.train(False)
     outputs = net(X)  # forward
     loss = criterion(outputs, y)
     f1 = PytorchUtils.f1_score_macro(y.data,
                                      outputs.data,
                                      per_class=True)
     # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
     probs = None  # faster
     return loss.data[0], probs, f1
Code example #13
        def test(X, y, weight_factor=10):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward
            loss = criterion(outputs, y)

            if self.HP.CALC_F1:
                f1 = PytorchUtils.f1_score_macro(y.data > self.HP.THRESHOLD, outputs.data, per_class=True, threshold=self.HP.THRESHOLD)
            else:
                f1 = np.ones(outputs.shape[3])

            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1
Code example #14
        def test(X, y, weight_factor=10):
            X = torch.from_numpy(X.astype(np.float32))
            y = torch.from_numpy(y.astype(np.float32))
            if torch.cuda.is_available():
                X, y = Variable(X.cuda(), volatile=True), Variable(y.cuda(), volatile=True)
            else:
                X, y = Variable(X, volatile=True), Variable(y, volatile=True)
            net.train(False)
            outputs = net(X)  # forward

            weights = torch.ones((self.HP.BATCH_SIZE, self.HP.NR_OF_CLASSES, self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[1])).cuda()
            bundle_mask = y > 0
            weights[bundle_mask.data] *= weight_factor  # 10

            criterion = nn.BCEWithLogitsLoss(weight=weights)
            loss = criterion(outputs, y)

            # loss = criterion(outputs, y)

            # loss = PytorchUtils.soft_dice(outputs, y)
            f1 = PytorchUtils.f1_score_macro(y.data, outputs.data, per_class=True)
            # probs = outputs.data.cpu().numpy().transpose(0,2,3,1)   # (bs, x, y, classes)
            probs = None  # faster
            return loss.data[0], probs, f1
Code example #15
File: MetricUtils.py  Project: silongGG/TractSeg
    def calc_peak_dice_pytorch(HP, y_pred, y_true, max_angle_error=[0.9]):
        '''
        Calculate angle between groundtruth and prediction and keep the voxels where
        angle is smaller than MAX_ANGLE_ERROR.

        From groundtruth generate a binary mask by selecting all voxels with len > 0.

        Calculate Dice from these 2 masks.

        -> Penalty on peaks outside of tract or if predicted peak=0
        -> no penalty on very small peaks that point in the right direction -> bad
        => Peak_dice can be high even if the peaks inside the tract are almost missing (almost 0)

        :param y_pred:
        :param y_true:
        :param max_angle_error:  0.7 ->  angle error of 45° or less; 0.9 ->  angle error of 23° or less
                                 Can be list with several values -> calculate for several thresholds
        :return:
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Calculate the angle between two nd-arrays (array of vectors) along the last dimension

            without anything further: 1->0°, 0.9->23°, 0.7->45°, 0->90°
            np.arccos returns the angle in radians (90° corresponds to 0.5*pi)

            return: one dimension less than input
            '''
            return torch.abs(
                einsum('abcd,abcd->abc', a, b) /
                (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

        #Single threshold
        if len(max_angle_error) == 1:
            score_per_bundle = {}
            bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
            for idx, bundle in enumerate(bundles):
                # if bundle == "CST_right":
                y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
                y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()  # [x,y,z,3]

                angles = angle_last_dim(y_pred_bund, y_true_bund)
                gt_binary = y_true_bund.sum(dim=-1) > 0
                gt_binary = gt_binary.view(-1)  # [bs*x*y]

                angles_binary = angles > max_angle_error[0]
                angles_binary = angles_binary.view(-1)

                f1 = PytorchUtils.f1_score_binary(gt_binary, angles_binary)
                score_per_bundle[bundle] = f1

            return score_per_bundle

        #multiple thresholds
        else:
            score_per_bundle = {}
            bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
            for idx, bundle in enumerate(bundles):
                # if bundle == "CST_right":
                y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
                y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()  # [x,y,z,3]

                angles = angle_last_dim(y_pred_bund, y_true_bund)
                gt_binary = y_true_bund.sum(dim=-1) > 0
                gt_binary = gt_binary.view(-1)  # [bs*x*y]

                score_per_bundle[bundle] = []
                for threshold in max_angle_error:
                    angles_binary = angles > threshold
                    angles_binary = angles_binary.view(-1)

                    f1 = PytorchUtils.f1_score_binary(gt_binary, angles_binary)
                    score_per_bundle[bundle].append(f1)

            return score_per_bundle
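PytorchUtils.f1_score_binary is used here as a Dice score: for two binary masks the F1 score and the Dice coefficient are the same quantity, 2*TP / (2*TP + FP + FN). A minimal sketch of that computation on flat boolean tensors (an assumption about what the helper computes, not its actual implementation):

    import torch

    def binary_f1(gt, pred, eps=1e-7):
        gt, pred = gt.float(), pred.float()
        tp = (gt * pred).sum()
        fp = ((1 - gt) * pred).sum()
        fn = (gt * (1 - pred)).sum()
        return (2 * tp / (2 * tp + fp + fn + eps)).item()

    gt = torch.tensor([1, 1, 0, 0, 1], dtype=torch.bool)
    pred = torch.tensor([1, 0, 0, 1, 1], dtype=torch.bool)
    print(binary_f1(gt, pred))   # 2*2 / (2*2 + 1 + 1) = 0.667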
Code example #16
File: UNet_Pytorch.py  Project: shmp0722/TractSeg
 def load_model(path):
     PytorchUtils.load_checkpoint(path, unet=net)
Code example #17
File: MetricUtils.py  Project: doctoryfx/TractSeg
    def calc_peak_dice_pytorch(HP, y_pred, y_true, max_angle_error=[0.9]):
        '''
        Calculate angle between groundtruth and prediction and keep the voxels where
        angle is smaller than MAX_ANGLE_ERROR.

        From groundtruth generate a binary mask by selecting all voxels with len > 0.

        Calculate Dice from these 2 masks.

        -> Penalty on peaks outside of tract or if predicted peak=0
        -> no penalty on very small peaks that point in the right direction -> bad
        => Peak_dice can be high even if the peaks inside the tract are almost missing (almost 0)

        :param y_pred:
        :param y_true:
        :param max_angle_error:  0.7 ->  angle error of 45° or less; 0.9 ->  angle error of 23° or less
                                 Can be list with several values -> calculate for several thresholds
        :return:
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Calculate the angle between two nd-arrays (array of vectors) along the last dimension

            without anything further: 1->0°, 0.9->23°, 0.7->45°, 0->90°
            np.arccos returns the angle in radians (90° corresponds to 0.5*pi)

            return: one dimension less than input
            '''
            return torch.abs(einsum('abcd,abcd->abc', a, b) / (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

        #Single threshold
        if len(max_angle_error) == 1:
            score_per_bundle = {}
            bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
            for idx, bundle in enumerate(bundles):
                # if bundle == "CST_right":
                y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
                y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()      # [x,y,z,3]

                angles = angle_last_dim(y_pred_bund, y_true_bund)
                gt_binary = y_true_bund.sum(dim=-1) > 0
                gt_binary = gt_binary.view(-1)  # [bs*x*y]

                angles_binary = angles > max_angle_error[0]
                angles_binary = angles_binary.view(-1)

                f1 = PytorchUtils.f1_score_binary(gt_binary, angles_binary)
                score_per_bundle[bundle] = f1

            return score_per_bundle

        #multiple thresholds
        else:
            score_per_bundle = {}
            bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
            for idx, bundle in enumerate(bundles):
                # if bundle == "CST_right":
                y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
                y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()  # [x,y,z,3]

                angles = angle_last_dim(y_pred_bund, y_true_bund)
                gt_binary = y_true_bund.sum(dim=-1) > 0
                gt_binary = gt_binary.view(-1)  # [bs*x*y]

                score_per_bundle[bundle] = []
                for threshold in max_angle_error:
                    angles_binary = angles > threshold
                    angles_binary = angles_binary.view(-1)

                    f1 = PytorchUtils.f1_score_binary(gt_binary, angles_binary)
                    score_per_bundle[bundle].append(f1)

            return score_per_bundle
Code example #18
File: UNet_Pytorch_SE.py  Project: doctoryfx/TractSeg
 def load_model(path):
     PytorchUtils.load_checkpoint(path, unet=net)
Code example #19
File: MetricUtils.py  Project: silongGG/TractSeg
    def calc_peak_length_dice_pytorch(HP,
                                      y_pred,
                                      y_true,
                                      max_angle_error=[0.9],
                                      max_length_error=0.1):
        '''
        Ca

        :param y_pred:
        :param y_true:
        :param max_angle_error:  0.7 ->  angle error of 45° or less; 0.9 ->  angle error of 23° or less
                                 Can be list with several values -> calculate for several thresholds
        :return:
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Calculate the angle between two nd-arrays (array of vectors) along the last dimension

            without anything further: 1->0°, 0.9->23°, 0.7->45°, 0->90°
            np.arccos returns the angle in radians (90° corresponds to 0.5*pi)

            return: one dimension less than input
            '''
            return torch.abs(
                einsum('abcd,abcd->abc', a, b) /
                (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

        #Single threshold
        score_per_bundle = {}
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            # if bundle == "CST_right":
            y_pred_bund = y_pred[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()
            y_true_bund = y_true[:, :, :, (idx * 3):(idx * 3) + 3].contiguous()  # [x,y,z,3]

            angles = angle_last_dim(y_pred_bund, y_true_bund)

            lengths_pred = torch.norm(y_pred_bund, 2, -1)
            lengths_true = torch.norm(y_true_bund, 2, -1)
            lengths_binary = torch.abs(lengths_pred - lengths_true) < (max_length_error * lengths_true)
            lengths_binary = lengths_binary.view(-1)

            gt_binary = y_true_bund.sum(dim=-1) > 0
            gt_binary = gt_binary.view(-1)  # [bs*x*y]

            angles_binary = angles > max_angle_error[0]
            angles_binary = angles_binary.view(-1)

            combined = lengths_binary * angles_binary

            f1 = PytorchUtils.f1_score_binary(gt_binary, combined)
            score_per_bundle[bundle] = f1
        return score_per_bundle