Example #1
0
    def __init__(self,
                 label_loss,
                 density_map_loss,
                 count_loss,
                 output_folder=None):
        """Set up a combined loss: soft Dice on segmentation output plus MSE
        terms on the density map and on the count.

        Args:
            label_loss: weight/flag for the Dice term — presumably a scalar
                weight; confirm against the forward() implementation.
            density_map_loss: weight/flag for the density-map MSE term.
            count_loss: weight/flag for the count MSE term.
            output_folder: optional directory for logging artifacts.
        """
        super(CountingDiceLoss, self).__init__()

        # Dice over the softmax output; background channel excluded,
        # small smoothing constant for numerical stability.
        dice_kwargs = {'batch_dice': False, 'smooth': 1e-5, 'do_bg': False}
        self.loss = SoftDiceLoss(softmax_helper, **dice_kwargs)

        # WeightedRobustCrossEntropyLoss([0.001, 0.999]) was an alternative here.
        self.loss_density_map = torch.nn.MSELoss()
        self.loss_n_ma = torch.nn.MSELoss()

        self.n = 0
        self.output_folder = output_folder

        self.label_loss = label_loss
        self.density_map_loss = density_map_loss
        self.count_loss = count_loss

        # Running per-call records of each loss component and input sizes.
        self.l_ = []
        self.l_dm = []
        self.l_n = []
        self.l_total = []
        self.sizes = []
Example #2
0
 def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
              unpack_data=True, deterministic=True, fp16=False):
     """Trainer variant: deep supervision off, plain SoftDiceLoss, 500-epoch cap.

     All constructor arguments are forwarded unchanged to the parent trainer.
     """
     super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                      deterministic, fp16, deep_supervision=False)
     # Dice loss without the background channel; tiny smoothing term.
     self.loss = SoftDiceLoss(batch_dice=self.batch_dice, smooth=1e-5, do_bg=False)
     self.max_num_epochs = 500
Example #3
0
 def __init__(self, alpha=.25, gamma=2):
     """Initialize SAWLoss with a soft Dice term and an L2 (MSE) term.

     Args:
         alpha: focal-style weighting factor — presumably used in forward();
             the original __init__ dropped it silently, so it is now stored.
         gamma: focal-style focusing exponent — likewise now stored.
     """
     super(SAWLoss, self).__init__()
     # Bug fix: alpha/gamma were accepted but never saved, making the
     # parameters unusable by the rest of the class. Store them.
     self.alpha = alpha
     self.gamma = gamma
     # Dice over softmax output; no background channel, small smoothing.
     self.dice = SoftDiceLoss(
         softmax_helper, **{
             'batch_dice': False,
             'smooth': 1e-5,
             'do_bg': False
         })
     self.l2 = MSELoss()
Example #4
0
 def __init__(self,
              plans_file,
              fold,
              output_folder=None,
              dataset_directory=None,
              batch_dice=True,
              stage=None,
              unpack_data=True,
              deterministic=True,
              fp16=False):
     """Trainer using softmax-activated SoftDiceLoss, background excluded.

     All constructor arguments are forwarded unchanged to the parent trainer.
     """
     super().__init__(plans_file, fold, output_folder, dataset_directory,
                      batch_dice, stage, unpack_data, deterministic, fp16)
     self.apply_nonlin = softmax_helper
     # Dice configuration: softmax nonlinearity, no background channel,
     # small smoothing term for numerical stability.
     dice_kwargs = {
         'apply_nonlin': self.apply_nonlin,
         'batch_dice': self.batch_dice,
         'smooth': 1e-5,
         'do_bg': False,
     }
     self.loss = SoftDiceLoss(**dice_kwargs)
Example #5
0
 def __init__(self,
              plans_file,
              fold,
              output_folder=None,
              dataset_directory=None,
              batch_dice=True,
              stage=None,
              unpack_data=True,
              deterministic=True,
              fp16=False):
     """Trainer using sigmoid-activated SoftDiceLoss with background kept.

     All constructor arguments are forwarded unchanged to the parent trainer.
     NOTE(review): sigmoid + do_bg=True + smooth=0 suggests a region-based /
     multi-label setup — confirm against the network's output head.
     """
     super().__init__(plans_file, fold, output_folder, dataset_directory,
                      batch_dice, stage, unpack_data, deterministic, fp16)
     self.loss = SoftDiceLoss(apply_nonlin=torch.sigmoid,
                              batch_dice=False,
                              do_bg=True,
                              smooth=0)
Example #6
0
 def __init__(self,
              plans_file,
              fold,
              output_folder=None,
              dataset_directory=None,
              batch_dice=True,
              stage=None,
              unpack_data=True,
              deterministic=True,
              fp16=False):
     """Trainer using softmax-activated SoftDiceLoss with background included.

     All constructor arguments are forwarded unchanged to the parent trainer.
     """
     super().__init__(plans_file, fold, output_folder, dataset_directory,
                      batch_dice, stage, unpack_data, deterministic, fp16)
     # Unlike the usual configuration, the background channel is kept
     # (do_bg=True) in the Dice computation.
     self.loss = SoftDiceLoss(apply_nonlin=softmax_helper,
                              batch_dice=self.batch_dice,
                              smooth=1e-5,
                              do_bg=True)
Example #7
0
 def initialize_optimizer_and_scheduler(self):
     """Register two (optimizer, loss) pairs on self.opt_loss:
     Adam at initial_lrs[0] with SoftDiceLoss, and Adam at initial_lrs[1]
     with MSELoss. No LR scheduler is created here.
     """
     # Dice over softmax output; no background channel, small smoothing.
     dice_loss = SoftDiceLoss(softmax_helper,
                              batch_dice=False,
                              smooth=1e-5,
                              do_bg=False)
     # parameters() yields a fresh generator each call — one per optimizer.
     adam_dice = torch.optim.Adam(self.network.parameters(), self.initial_lrs[0])
     adam_mse = torch.optim.Adam(self.network.parameters(), self.initial_lrs[1])
     self.opt_loss.append((adam_dice, dice_loss))
     self.opt_loss.append((adam_mse, torch.nn.MSELoss()))