Example #1
                      activation=None, task=args.task)

    optimizer = get_optimizer(optimizer=args.optimizer, lookahead=args.lookahead, model=model,
                              separate_decoder=args.separate_decoder, lr=args.lr, lr_e=args.lr_e)

    if args.scheduler == 'ReduceLROnPlateau':
        scheduler = ReduceLROnPlateau(optimizer, factor=0.6, patience=3)
    else:
        # Default to ReduceLROnPlateau as well, just with a stronger decay factor
        scheduler = ReduceLROnPlateau(optimizer, factor=0.3, patience=3)

    # Select the training criterion based on args.loss
    if args.loss == 'BCEDiceLoss':
        criterion = smp.utils.losses.BCEDiceLoss(eps=1.)
    elif args.loss == 'BCEJaccardLoss':
        criterion = smp.utils.losses.BCEJaccardLoss(eps=1.)
    elif args.loss == 'FocalLoss':
        criterion = FocalLoss()
    # elif args.loss == 'lovasz_softmax':
    #     criterion = lovasz_softmax()
    elif args.loss == 'BCEMulticlassDiceLoss':
        criterion = BCEMulticlassDiceLoss()
    elif args.loss == 'MulticlassDiceMetricCallback':
        criterion = MulticlassDiceMetricCallback()
    elif args.loss == 'BCE':
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = smp.utils.losses.BCEDiceLoss(eps=1.)

    if args.multigpu:
        model = nn.DataParallel(model)

    if args.task == 'segmentation':
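
The if/elif chain in Example #1 simply maps the string in args.loss to a loss instance, falling back to BCEDiceLoss for unknown names. The same selection can also be expressed as a dictionary lookup; the sketch below is only illustrative and assumes the same imports the example relies on (smp, FocalLoss, BCEMulticlassDiceLoss, MulticlassDiceMetricCallback, nn).

# Illustrative alternative to the if/elif chain above; assumes Example #1's imports.
LOSSES = {
    'BCEDiceLoss': lambda: smp.utils.losses.BCEDiceLoss(eps=1.),
    'BCEJaccardLoss': lambda: smp.utils.losses.BCEJaccardLoss(eps=1.),
    'FocalLoss': FocalLoss,
    'BCEMulticlassDiceLoss': BCEMulticlassDiceLoss,
    'MulticlassDiceMetricCallback': MulticlassDiceMetricCallback,
    'BCE': nn.BCEWithLogitsLoss,
}

def build_criterion(name):
    # Unknown names fall back to BCEDiceLoss, matching the original else branch.
    factory = LOSSES.get(name, lambda: smp.utils.losses.BCEDiceLoss(eps=1.))
    return factory()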
Example #2
    def train(self, epoch):
        self.net.train(True)

        # Linearly decay the learning rate over the final num_epochs_decay epochs
        if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
            self.decayed_lr -= (self.lr / float(self.num_epochs_decay))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.decayed_lr
            print('Epoch {}: decayed learning rate to {}.'.format(
                epoch, self.decayed_lr))

        epoch_loss = 0

        # Metric accumulators (only epoch_loss and length are updated in this loop)
        acc = 0.  # Accuracy
        SE = 0.  # Sensitivity (Recall)
        SP = 0.  # Specificity
        PC = 0.  # Precision
        F1 = 0.  # F1 Score
        JS = 0.  # Jaccard Similarity
        DC = 0.  # Dice Coefficient
        length = 0  # number of samples seen

        for i, (imgs, gts) in enumerate(tqdm(self.train_loader)):
            imgs = imgs.to(self.device)
            gts = gts.round().long().to(self.device)

            self.optimizer.zero_grad()

            outputs = self.net(imgs)

            # make sure shapes are the same by flattening them

            # weight = torch.tensor([1.,100.,100.,100.,50.,50.,80.,80.,50.,80.,80.,80.,50.,50.,70.,70.,70.,70.,
            #                        60.,60.,100.,100.,100.,]).to(self.device)

            #ce_loss = nn.CrossEntropyLoss(weight=weight,reduction='mean')(outputs, gts.reshape(-1,128,128,128))
            dice_loss = GeneralizedDiceLoss(sigmoid_normalization=False)(
                outputs, expand_as_one_hot(gts.reshape(-1, 128, 128, 128), 14))
            # bce_loss = torch.nn.BCEWithLogitsLoss()(outputs, gts)
            focal_loss = FocalLoss(num_class=14, alpha=None,
                                   gamma=1)(outputs,
                                            gts.reshape(-1, 128, 128, 128))

            loss = focal_loss  # the dice term above is computed but not added here
            #loss = focal_loss + dice_loss
            # Weight by batch size because the loss uses reduction='mean'
            epoch_loss += loss.item() * imgs.size(0)
            loss.backward()
            self.optimizer.step()

            # DC += iou(outputs.detach().cpu().squeeze().argmax(dim=1),gts.detach().cpu(),n_classes=14)*imgs.size(0)
            length += imgs.size(0)

        # DC = DC / length
        epoch_loss = epoch_loss / length
        # # Print the log info
        # print(
        #     'Epoch [%d/%d], Loss: %.4f, \n[Training] DC: %.4f' % (
        #         epoch + 1, self.num_epochs,
        #         epoch_loss,
        #          DC))
        print('Epoch {}: loss {:.4f}'.format(epoch, epoch_loss))
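
The decay block at the top of Example #2's train() implements a linear schedule: during the final num_epochs_decay epochs, the learning rate is reduced by lr / num_epochs_decay per epoch, reaching roughly zero at the last epoch. A small self-contained sketch of that arithmetic (the values below are illustrative, not taken from the example):

# Replays the linear learning-rate decay used in Example #2 (illustrative values).
lr = 1e-3              # initial learning rate
num_epochs = 100
num_epochs_decay = 20  # decay over the final 20 epochs

decayed_lr = lr
for epoch in range(num_epochs):
    if (epoch + 1) > (num_epochs - num_epochs_decay):
        decayed_lr -= lr / float(num_epochs_decay)
        print('epoch {:3d}: lr = {:.6f}'.format(epoch, decayed_lr))
# decayed_lr ends at approximately 0 (up to floating-point rounding).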