Esempio n. 1
0
    def forward(self, loss_input_dict):
        """Compute the generalized cross-entropy (GCE) loss.

        For every voxel and class the term ``(1 - p ** q) / q * y`` is formed,
        where ``p`` is the predicted probability, ``y`` the (soft) ground truth
        and ``q`` is ``self.q``. Terms are summed over classes (optionally
        multiplied by class weights first) and then averaged over voxels
        (optionally as a pixel-weighted mean).

        Args:
            loss_input_dict: dict with keys ``'prediction'``,
                ``'ground_truth'``, ``'pixel_weight'``, ``'class_weight'`` and
                ``'softmax'`` (if truthy, softmax is applied over dim 1 of the
                prediction first).

        Returns:
            A scalar loss tensor.

        Raises:
            ValueError: if class (or pixel) weighting is enabled but the
                corresponding weight is ``None``.
        """
        predict = loss_input_dict['prediction']
        soft_y = loss_input_dict['ground_truth']
        pix_w = loss_input_dict['pixel_weight']
        cls_w = loss_input_dict['class_weight']
        softmax = loss_input_dict['softmax']

        if softmax:
            predict = nn.Softmax(dim=1)(predict)
        # Presumably flattens [B, C, ...] into [N, C] (N = number of voxels).
        predict = reshape_tensor_to_2D(predict)
        soft_y = reshape_tensor_to_2D(soft_y)
        gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y

        if self.enable_cls_weight:
            if cls_w is None:
                raise ValueError("Class weight is enabled but not defined")
            gce = torch.sum(gce * cls_w, dim=1)
        else:
            gce = torch.sum(gce, dim=1)  # one loss value per voxel, shape [N]

        if self.enable_pix_weight:
            if pix_w is None:
                raise ValueError("Pixel weight is enabled but not defined")
            pix_w = reshape_tensor_to_2D(pix_w)  # presumably [N, 1]
            # Drop the channel axis so the weight lines up with the [N]
            # per-voxel loss. Without this, [N] * [N, 1] broadcasts to an
            # [N, N] matrix: quadratic memory, and the ratio below collapses
            # to the unweighted sum of gce.
            pix_w = torch.squeeze(pix_w, dim=1)
            gce = torch.sum(gce * pix_w) / torch.sum(pix_w)
        else:
            gce = torch.mean(gce)
        return gce
Esempio n. 2
0
File: mse.py Progetto: zz10001/PyMIC
    def forward(self, loss_input_dict):
        """Compute a mean-squared-error style segmentation loss.

        The element-wise error comes from ``self.get_prediction_error``
        (presumably the squared difference; confirm against its definition).
        It is averaged over classes, or class-weighted and normalised by the
        sum of the class weights. It is then averaged over voxels, or
        pixel-weighted and normalised by the sum of the pixel weights.

        Args:
            loss_input_dict: dict with keys ``'prediction'``,
                ``'ground_truth'``, ``'pixel_weight'``, ``'class_weight'`` and
                ``'softmax'`` (if truthy, softmax is applied over dim 1 of the
                prediction first).

        Returns:
            A scalar loss tensor.

        Raises:
            ValueError: if class (or pixel) weighting is enabled but the
                corresponding weight is ``None``.
        """
        predict = loss_input_dict['prediction']
        soft_y = loss_input_dict['ground_truth']
        pix_w = loss_input_dict['pixel_weight']
        cls_w = loss_input_dict['class_weight']
        softmax = loss_input_dict['softmax']

        if softmax:
            predict = nn.Softmax(dim=1)(predict)
        # Presumably flattens [B, C, ...] into [N, C] (N = number of voxels).
        predict = reshape_tensor_to_2D(predict)
        soft_y = reshape_tensor_to_2D(soft_y)
        se = self.get_prediction_error(predict, soft_y)
        if self.enable_cls_weight:
            if cls_w is None:
                raise ValueError("Class weight is enabled but not defined")
            mse = torch.sum(se * cls_w, dim=1) / torch.sum(cls_w)
        else:
            mse = torch.mean(se, dim=1)  # one value per voxel, shape [N]
        if self.enable_pix_weight:
            if pix_w is None:
                raise ValueError("Pixel weight is enabled but not defined")
            pix_w = reshape_tensor_to_2D(pix_w)  # presumably [N, 1]
            # Drop the channel axis so the weight lines up with the [N]
            # per-voxel error. Without this, [N] * [N, 1] broadcasts to an
            # [N, N] matrix: quadratic memory, and the ratio below collapses
            # to the unweighted sum of mse.
            pix_w = torch.squeeze(pix_w, dim=1)
            mse = torch.sum(mse * pix_w) / torch.sum(pix_w)
        else:
            mse = torch.mean(mse)
        return mse
Esempio n. 3
0
    def forward(self, loss_input_dict):
        """Cross-entropy between predicted probabilities and soft labels.

        Each voxel contributes ``sum_c -y_c * log(p_c)``, with each class term
        optionally scaled by a class weight. The per-voxel values are then
        averaged, either uniformly or as a pixel-weighted mean.

        Args:
            loss_input_dict: dict with keys ``'prediction'``,
                ``'ground_truth'``, ``'pixel_weight'``, ``'class_weight'`` and
                ``'softmax'`` (if truthy, softmax is applied over dim 1 of the
                prediction first).

        Returns:
            A scalar loss tensor.

        Raises:
            ValueError: if class (or pixel) weighting is enabled but the
                corresponding weight is ``None``.
        """
        prob = loss_input_dict['prediction']
        target = loss_input_dict['ground_truth']
        pixel_weight = loss_input_dict['pixel_weight']
        class_weight = loss_input_dict['class_weight']

        if loss_input_dict['softmax']:
            prob = nn.Softmax(dim=1)(prob)
        prob = reshape_tensor_to_2D(prob)
        target = reshape_tensor_to_2D(target)

        # NOTE(review): log(0) gives -inf, so a zero probability on a
        # zero-label class yields NaN here.
        voxel_terms = -target * torch.log(prob)
        if self.enable_cls_weight:
            if class_weight is None:
                raise ValueError("Class weight is enabled but not defined")
            voxel_terms = voxel_terms * class_weight
        voxel_loss = torch.sum(voxel_terms, dim=1)  # one value per voxel, [N]

        if not self.enable_pix_weight:
            return torch.mean(voxel_loss)
        if pixel_weight is None:
            raise ValueError("Pixel weight is enabled but not defined")
        # The reshaped weight is presumably [N, 1]; squeeze it to match [N].
        weight = torch.squeeze(reshape_tensor_to_2D(pixel_weight))
        return torch.sum(voxel_loss * weight) / torch.sum(weight)
Esempio n. 4
0
    def forward(self, loss_input_dict):
        """Dice-style loss with a ``|p - y| ** gamma`` numerator.

        Per class the loss is ``sum |p - y| ** gamma / (sum (p + y) + 1e-5)``,
        with both sums taken over voxels and optionally pixel-weighted. The
        per-class values are then averaged, either uniformly or class-weighted.
        ``gamma`` is ``self.gamma``.

        Args:
            loss_input_dict: dict with keys ``'prediction'`` (a tensor, or a
                list/tuple of which only the first element is used),
                ``'ground_truth'``, ``'pixel_weight'``, ``'class_weight'`` and
                ``'softmax'`` (if truthy, softmax is applied over dim 1 first).

        Returns:
            A scalar loss tensor.

        Raises:
            ValueError: if pixel (or class) weighting is enabled but the
                corresponding weight is ``None``.
        """
        prediction = loss_input_dict['prediction']
        target = loss_input_dict['ground_truth']
        pixel_weight = loss_input_dict['pixel_weight']
        class_weight = loss_input_dict['class_weight']
        apply_softmax = loss_input_dict['softmax']

        # A list/tuple presumably holds several network outputs; only the
        # first one is used.
        if isinstance(prediction, (list, tuple)):
            prediction = prediction[0]
        if apply_softmax:
            prediction = nn.Softmax(dim=1)(prediction)
        prediction = reshape_tensor_to_2D(prediction)
        target = reshape_tensor_to_2D(target)

        error = torch.pow(torch.abs(prediction - target), self.gamma)
        volume = prediction + target
        if self.enable_pix_weight:
            if pixel_weight is None:
                raise ValueError("Pixel weight is enabled but not defined")
            pixel_weight = reshape_tensor_to_2D(pixel_weight)
            error = error * pixel_weight
            volume = volume * pixel_weight
        # Reduce over voxels (dim 0) to one value per class; 1e-5 guards
        # against division by zero for absent classes.
        per_class = torch.sum(error, dim=0) / (torch.sum(volume, dim=0) + 1e-5)

        if not self.enable_cls_weight:
            return torch.mean(per_class)
        if class_weight is None:
            raise ValueError("Class weight is enabled but not defined")
        return (per_class * class_weight).sum() / class_weight.sum()
Esempio n. 5
0
    def forward(self, loss_input_dict):
        """Compute a weighted soft Dice loss, ``1 - weighted mean Dice``.

        The image weight and the pixel weight are combined into one voxel-wise
        weight map, which is passed to ``get_classwise_dice``. The resulting
        per-class Dice scores are averaged with the class weights.

        Args:
            loss_input_dict: dict with keys ``'prediction'`` (a tensor, or a
                list/tuple of which only the first element is used),
                ``'ground_truth'``, ``'image_weight'`` (indexed along dim 0,
                one value per image), ``'pixel_weight'``, ``'class_weight'``
                and ``'softmax'`` (if truthy, softmax is applied over dim 1).
                Any of the three weights may be ``None``; that weighting is
                then skipped and a plain mean is used over classes when the
                class weight is absent.

        Returns:
            A scalar loss tensor.
        """
        predict = loss_input_dict['prediction']
        soft_y = loss_input_dict['ground_truth']
        img_w = loss_input_dict['image_weight']
        pix_w = loss_input_dict['pixel_weight']
        cls_w = loss_input_dict['class_weight']
        softmax = loss_input_dict['softmax']

        if isinstance(predict, (list, tuple)):
            predict = predict[0]
        tensor_dim = len(predict.size())
        if softmax:
            predict = nn.Softmax(dim=1)(predict)

        # Combine pixel weight and image weight into a single weight map.
        if img_w is not None:
            # Broadcast the per-image weight over the channel/spatial axes.
            if tensor_dim == 5:
                img_w = img_w[:, None, None, None, None]
            else:
                img_w = img_w[:, None, None, None]
            if pix_w is None:
                # NOTE(review): assumes a single-channel weight map laid out
                # like the prediction is what get_classwise_dice expects.
                pix_w = torch.ones_like(predict[:, :1])
            pix_w = pix_w * img_w
        if pix_w is not None:
            pix_w = reshape_tensor_to_2D(pix_w)

        predict = reshape_tensor_to_2D(predict)
        soft_y = reshape_tensor_to_2D(soft_y)
        # A None weight is accepted here (unweighted Dice).
        dice_score = get_classwise_dice(predict, soft_y, pix_w)

        if cls_w is None:
            average_dice = torch.mean(dice_score)
        else:
            weighted_dice = dice_score * cls_w
            average_dice = weighted_dice.sum() / cls_w.sum()
        dice_loss = 1.0 - average_dice
        return dice_loss
Esempio n. 6
0
    def forward(self, loss_input_dict):
        """Blend an exponential log-Dice term with an exponential weighted CE.

        Dice term: mean over classes of ``(-log(d')) ** gamma``, where ``d'``
        is the class-wise Dice score rescaled as ``0.01 + 0.98 * d``.
        CE term: per voxel, ``sum_c y_c * w_c * (-log(p'_c)) ** gamma``,
        averaged over voxels. Here ``p' = 0.01 + 0.98 * p`` and
        ``w_c = (mean_voxels(y_c) + 0.1) ** -0.5``, so rarer classes get
        larger weights.
        The result is ``w_dice * dice_term + (1 - w_dice) * ce_term``.

        Args:
            loss_input_dict: dict with keys ``'prediction'``,
                ``'ground_truth'`` and ``'softmax'`` (if truthy, softmax is
                applied over dim 1 of the prediction first).

        Returns:
            A scalar loss tensor.
        """
        prob = loss_input_dict['prediction']
        target = loss_input_dict['ground_truth']
        if loss_input_dict['softmax']:
            prob = nn.Softmax(dim=1)(prob)
        prob = reshape_tensor_to_2D(prob)
        target = reshape_tensor_to_2D(target)

        # Exponential logarithmic Dice term; the affine rescale keeps the
        # log argument away from 0.
        dice = 0.01 + get_classwise_dice(prob, target) * 0.98
        dice_term = torch.mean(torch.pow(-torch.log(dice), self.gamma))

        # Exponential cross-entropy term with frequency-based class weights.
        class_freq = torch.mean(target, dim=0)
        class_w = torch.pow(1.0 / (class_freq + 0.1), 0.5)
        clipped = 0.01 + prob * 0.98
        ce_map = class_w * torch.pow(-torch.log(clipped), self.gamma)
        ce_term = torch.mean(torch.sum(target * ce_map, dim=1))

        return dice_term * self.w_dice + ce_term * (1.0 - self.w_dice)
Esempio n. 7
0
    def forward(self, loss_input_dict):
        """Dice loss with an exponent: mean over classes of ``1 - d' ** (1/beta)``.

        ``d'`` is the unweighted class-wise Dice score rescaled as
        ``0.01 + 0.98 * d``. ``beta`` is ``self.beta``.

        Args:
            loss_input_dict: dict with keys ``'prediction'``,
                ``'ground_truth'`` and ``'softmax'`` (if truthy, softmax is
                applied over dim 1 of the prediction first).

        Returns:
            A scalar loss tensor.
        """
        prob = loss_input_dict['prediction']
        target = loss_input_dict['ground_truth']
        if loss_input_dict['softmax']:
            prob = nn.Softmax(dim=1)(prob)
        prob = reshape_tensor_to_2D(prob)
        target = reshape_tensor_to_2D(target)

        # For Dice in [0, 1] this maps the score into [0.01, 0.99].
        dice = 0.01 + get_classwise_dice(prob, target, None) * 0.98
        per_class_loss = 1.0 - torch.pow(dice, 1.0 / self.beta)
        return torch.mean(per_class_loss)