def get_model(fn_args):
    """
    This function defines a Keras model and returns the model as a Keras object.
    """
    base_model, layers, layer_names = create_base_model(
        name=constants.BACKBONE_NAME,
        weights=constants.PRETRAINED_WEIGHTS,
        height=constants.HEIGHT,
        width=constants.WIDTH,
        include_top=False,
        pooling=None)

    model = DANet(n_classes=constants.N_MODEL_CLASSES,
                  base_model=base_model,
                  output_layers=layers,
                  backbone_trainable=constants.BACKBONE_TRAINABLE,
                  height=constants.HEIGHT,
                  width=constants.WIDTH).model()

    opt = tf.keras.optimizers.SGD(learning_rate=0.2, momentum=0.9)
    metrics = [IOUScore(threshold=0.5)]
    categorical_focal_dice_loss = CategoricalFocalLoss(alpha=0.25,
                                                       gamma=2.0) + DiceLoss()

    model.compile(
        optimizer=opt,
        loss=categorical_focal_dice_loss,
        metrics=metrics,
    )

    return model
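# Note: the `+` between CategoricalFocalLoss(...) and DiceLoss() above relies on
# loss objects that overload addition (as the segmentation_models losses do). A
# minimal sketch of an equivalent combined loss as a plain Keras-compatible
# function -- the name `categorical_focal_dice` and the softmax-probability
# input are assumptions, not the library's API:
import tensorflow as tf

def categorical_focal_dice(alpha=0.25, gamma=2.0, smooth=1e-5):
    def loss(y_true, y_pred):
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
        # focal term: -alpha * (1 - p)^gamma * log(p), summed over classes
        focal = -y_true * alpha * tf.pow(1.0 - y_pred, gamma) * tf.math.log(y_pred)
        focal = tf.reduce_mean(tf.reduce_sum(focal, axis=-1))
        # dice term: 1 - 2|X*Y| / (|X| + |Y|), computed over the batch
        intersection = tf.reduce_sum(y_true * y_pred)
        dice = 1.0 - (2.0 * intersection + smooth) / (
            tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth)
        return focal + dice
    return loss
# usage sketch: model.compile(optimizer=opt, loss=categorical_focal_dice(), metrics=metrics)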
Example #2
    def __define_criterion(self,
                           class_weights,
                           delta_var,
                           delta_dist,
                           norm=2,
                           optimize_bg=False,
                           criterion='CE'):
        assert criterion in ['CE', 'Dice', 'Multi', None]

        smooth = 1.0

        # Discriminative Loss
        if self.use_instance_segmentation:
            self.criterion_discriminative = DiscriminativeLoss(
                delta_var, delta_dist, norm, self.usegpu)
            if self.usegpu:
                self.criterion_discriminative = \
                    self.criterion_discriminative.cuda()

        # FG Segmentation Loss
        if class_weights is not None:
            class_weights = self.__define_variable(
                torch.FloatTensor(class_weights))
            if criterion in ['CE', 'Multi']:
                self.criterion_ce = torch.nn.CrossEntropyLoss(class_weights)
            if criterion in ['Dice', 'Multi']:
                self.criterion_dice = DiceLoss(optimize_bg=optimize_bg,
                                               weight=class_weights,
                                               smooth=smooth)
        else:
            if criterion in ['CE', 'Multi']:
                self.criterion_ce = torch.nn.CrossEntropyLoss()
            if criterion in ['Dice', 'Multi']:
                self.criterion_dice = DiceLoss(optimize_bg=optimize_bg,
                                               smooth=smooth)

        # MSE Loss
        self.criterion_mse = torch.nn.MSELoss()

        if self.usegpu:
            if criterion in ['CE', 'Multi']:
                self.criterion_ce = self.criterion_ce.cuda()
            if criterion in ['Dice', 'Multi']:
                self.criterion_dice = self.criterion_dice.cuda()

            self.criterion_mse = self.criterion_mse.cuda()
Example #3
    def __init__(self,
                 model,
                 lr,
                 num_classes,
                 weight_ce,
                 weight_dice,
                 metrics=True):
        super().__init__()
        # model
        self.model = model

        # learning rate
        self.lr = lr

        # number of classes
        self.num_classes = num_classes

        # loss
        self.register_buffer('weight_ce', weight_ce)
        self.register_buffer('weight_dice', weight_dice)
        self.dice_loss = DiceLoss(weight=self.weight_dice)
        self.ce_loss = CrossEntropyLoss(weight=self.weight_ce)

        # save hyperparameters
        self.save_hyperparameters()

        # metrics
        self.metrics = metrics
        if self.metrics:
            self.f1_train = CustomMetric(metric=pl.metrics.functional.f1,
                                         metric_name='F1',
                                         num_classes=self.num_classes,
                                         average='none')
            self.f1_valid = CustomMetric(metric=pl.metrics.functional.f1,
                                         metric_name='F1',
                                         num_classes=self.num_classes,
                                         average='none')
            self.f1_test = CustomMetric(metric=pl.metrics.functional.f1,
                                        metric_name='F1',
                                        num_classes=self.num_classes,
                                        average='none')

            self.iou_train = CustomMetric(metric=pl.metrics.functional.iou,
                                          metric_name='IoU',
                                          num_classes=self.num_classes,
                                          reduction='none')

            self.iou_valid = CustomMetric(metric=pl.metrics.functional.iou,
                                          metric_name='IoU',
                                          num_classes=self.num_classes,
                                          reduction='none')

            self.iou_test = CustomMetric(metric=pl.metrics.functional.iou,
                                         metric_name='IoU',
                                         num_classes=self.num_classes,
                                         reduction='none')
Example #4
def select_loss(loss_function):
    if loss_function == 'bce':
        criterion = nn.BCELoss()
    elif loss_function == 'bce_logit':
        criterion = nn.BCEWithLogitsLoss()
    elif loss_function == 'dice':
        criterion = DiceLoss()
    elif loss_function == 'mse':
        criterion = nn.MSELoss()
    elif loss_function == 'l1':
        criterion = nn.L1Loss()
    elif loss_function == 'kl' or loss_function == 'jsd':
        criterion = nn.KLDivLoss()
    elif loss_function == 'Cldice':
        bce = nn.BCEWithLogitsLoss().cuda()
        dice = DiceLoss().cuda()
        criterion = ClDice(bce, dice, alpha=1, beta=1)
    else:
        raise ValueError('Not supported loss.')
    return criterion.cuda()
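# Note: DiceLoss in these snippets is project-specific and its source is not
# shown. A minimal binary Dice loss sketch, assuming the inputs are
# probabilities in [0, 1] (the class name MinimalDiceLoss is hypothetical):
import torch.nn as nn

class MinimalDiceLoss(nn.Module):
    """Dice loss: 1 - 2*|X*Y| / (|X| + |Y|), with additive smoothing."""

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        pred = pred.contiguous().view(-1)
        target = target.contiguous().view(-1)
        intersection = (pred * target).sum()
        return 1.0 - (2.0 * intersection + self.smooth) / (
            pred.sum() + target.sum() + self.smooth)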
Example #5
def get_loss(mask_size):
    # `loss` is read from the enclosing/global scope, not passed in
    if loss == 'bce-with-logits':
        return WeightedBCEWithLogitsLoss(mask_size)
    elif loss == 'dice':
        return DiceLoss()
    elif loss == 'cross-entropy':
        l = torch.nn.CrossEntropyLoss()
        l.__name__ = 'Cross Entropy Loss'
        return l
    else:
        raise NotImplementedError("Unknown loss: {}".format(loss))
Example #6
def loss_function_select(loss_function):
    if loss_function == 'bce':
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.Tensor([args.bce_weight])).cuda()
    elif loss_function == 'dice':
        criterion = DiceLoss().cuda()
    elif loss_function == 'mse':
        criterion = nn.MSELoss().cuda()
    elif loss_function == 'l1':
        criterion = nn.L1Loss().cuda()
    elif loss_function == 'kl' or loss_function == 'jsd':
        criterion = nn.KLDivLoss().cuda()
    elif loss_function == 'Cldice':
        bce = nn.BCEWithLogitsLoss(
            pos_weight=torch.Tensor([args.bce_weight])).cuda()
        dice = DiceLoss().cuda()
        criterion = ClDice(bce, dice, alpha=1, beta=1)
    else:
        raise ValueError('Not supported loss.')
    return criterion
Example #7
    def case_test(self, model, device, test_loader, case_id):
        model.eval()
        num_batches = len(test_loader)
        dice_loss = 0.
        criterion = DiceLoss()
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data).view_as(target)
                dice_loss += criterion(output, target)

        # DiceLoss is assumed to return a per-batch mean, so average over
        # batches (not dataset size), then convert the loss to a Dice score
        dice_loss /= num_batches
        dice = 1 - dice_loss
        return dice.item()
Example #8
def train(train_loader, student, teacher, optimizer, epoch):
    global global_step
    global best_dice

    # set criterion
    segmentation_criterion = DiceLoss(ignore=True)
    consistency_criterion = nn.MSELoss()

    # switch to train mode
    student.train()
    teacher.train()

    for i, (data, label) in enumerate(train_loader):
        # if we don't use cosine annealing, adjust the learning rate manually
        if not args.cosine_annealing:
            adjust_learning_rate(optimizer, epoch, i, len(train_loader))

        data = data.to(device)
        label = label.to(device)

        # get the result of the two models
        student_pred = student(data)
        teacher_pred = teacher(data)
        # block gradients through the teacher; it is updated via EMA instead
        teacher_pred = teacher_pred.detach()

        # calculate consistency criterion
        consistency_weight = get_current_consistency_weight(epoch)
        consistency_loss = consistency_weight * consistency_criterion(
            student_pred, teacher_pred)
        # calculate segmentation loss
        segmentation_loss = segmentation_criterion(student_pred, label)
        # combine them to get the final loss
        loss = consistency_loss + segmentation_loss

        # compute gradient and do optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        update_ema_variables(student, teacher, args.ema_decay, global_step)

        # print training information
        if i % args.print_freq == 0:
            LOG.info('Epoch: [{}][{}/{}]\tLoss: {:.3f}\t'.format(
                epoch, i, len(train_loader), loss.item()))
            LOG.info('Consistency Loss: {:.3f}\t'.format(
                consistency_loss.item()))
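# Note: update_ema_variables is not shown above. A sketch of the standard
# mean-teacher EMA update (an assumption following Tarvainen & Valpola's
# mean-teacher formulation, not necessarily this repository's code):
def update_ema_variables(student, teacher, ema_decay, global_step):
    # ramp the decay from 0 toward ema_decay early in training
    alpha = min(1.0 - 1.0 / (global_step + 1), ema_decay)
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        # teacher <- alpha * teacher + (1 - alpha) * student
        t_param.data.mul_(alpha).add_(s_param.data, alpha=1.0 - alpha)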
Example #9
    def __init__(self, model, args, train_dataset, eval_dataset,
                 compute_metrics, loss_type, loss_gamma):
        Trainer.__init__(self,
                         model=model,
                         args=args,
                         train_dataset=train_dataset,
                         eval_dataset=eval_dataset,
                         compute_metrics=compute_metrics)
        if loss_type == 'DiceLoss':
            self.loss_fct = DiceLoss()
        elif loss_type == 'FocalLoss':
            self.loss_fct = FocalLoss(gamma=loss_gamma)
        elif loss_type == 'LabelSmoothingCrossEntropy':
            self.loss_fct = LabelSmoothingCrossEntropy()
        elif loss_type == 'CrossEntropyLoss':
            self.loss_fct = CrossEntropyLoss()
        elif loss_type == 'CourageLoss':
            self.loss_fct = CourageLoss(gamma=loss_gamma)
        else:
            raise ValueError('Unsupported loss type: {}'.format(loss_type))
Example #10
def main():
    args = arg_parse()
    mode = args.mode
    lossf = args.lossf.lower()
    n_epochs = args.epochs
    opt = args.opt.lower()
    lr = args.lr
    hash_code = '_'.join(
        list(map(str, [args.hashcode, args.opt, args.lr, args.lossf])))
    device = 'cuda:' + args.gpu
    withlen = args.withlen.lower() == 'true'
    normed = args.normed.lower() == 'true'
    mu = args.mu
    alpha = args.alpha
    beta = args.beta
    modelpath = args.modelpath
    DEFAULT_BATCHSIZE = args.batchsize
    modeltype = args.model
    if mode == 'test' and not modelpath:
        raise ValueError("A model path is required for testing.")

    if mode == 'train':
        DRIVE_train = DRIVEDataset("train",
                                   DRIVE_train_imgs_original,
                                   DRIVE_train_groudTruth,
                                   DRIVE_train_narrowBand,
                                   patch_height,
                                   patch_width,
                                   N_subimgs,
                                   val_size=0)
        print(len(DRIVE_train))

        # DRIVE_valid = DRIVE_train.get_validation_dataset()
        DRIVE_train_load = \
            torch.utils.data.DataLoader(dataset=DRIVE_train,
                                        batch_size=DEFAULT_BATCHSIZE, shuffle=True)

        # DRIVE_val_load = \
        #     torch.utils.data.DataLoader(dataset=DRIVE_valid,
        #                                 batch_size=128, shuffle=True)
        DRIVE_valid = DRIVEDataset("test", DRIVE_test_imgs_original,
                                   DRIVE_test_groundTruth,
                                   DRIVE_test_narrowBand, patch_height,
                                   patch_width, None)
        DRIVE_val_load = \
            torch.utils.data.DataLoader(dataset=DRIVE_valid,
                                        batch_size=DEFAULT_BATCHSIZE, shuffle=False)

    else:
        DRIVE_test = DRIVEDataset("test", DRIVE_test_imgs_original,
                                  DRIVE_test_groundTruth,
                                  DRIVE_test_narrowBand, patch_height,
                                  patch_width, None)
        DRIVE_test_load = \
            torch.utils.data.DataLoader(dataset=DRIVE_test,
                                        batch_size=DEFAULT_BATCHSIZE, shuffle=False)

    shape = (1, 48, 48)
    # fall back to CPU when CUDA is unavailable (the parsed device is 'cuda:N')
    if not torch.cuda.is_available():
        device = 'cpu'
    if modeltype.lower() == 'fcn':
        model = FCN8s(n_class=1).to(device)
    else:
        model = UNet1024(shape).to(device)

    if lossf == 'bce':
        criterion = BinaryCrossEntropyLoss2d()
    elif lossf == 'dice':
        criterion = DiceLoss()
    elif lossf == 'contour':
        criterion = ContourLoss(device=device,
                                mu=mu,
                                alpha=alpha,
                                beta=beta,
                                normed=normed,
                                withlen=withlen)
    elif lossf == 'contour-v3':
        criterion = ContourLossV3(device=device,
                                  mu=mu,
                                  alpha=alpha,
                                  beta=beta,
                                  normed=normed,
                                  withlen=withlen)
    elif lossf == 'contour-v2':
        criterion = ContourLossV2(device=device,
                                  mu=mu,
                                  alpha=alpha,
                                  beta=beta,
                                  normed=normed,
                                  withlen=withlen)
    elif lossf == 'contour-v4':
        criterion = ContourLossV4(device=device,
                                  mu=mu,
                                  normed=normed,
                                  withlen=withlen)
    elif lossf == 'focal':
        criterion = FocalLoss()
    else:
        raise ValueError('Undefined loss type')

    if opt == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
    elif opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        raise ValueError('Unsupported optimizer: {}'.format(opt))

    # Saving history to csv; columns must match the `values` list built below
    header = [
        'epoch', 'train loss', 'train accuracy', 'val loss', 'val auc',
        'val accuracy', 'val jaccard', 'val sensitivity', 'val specificity',
        'val precision', 'val f1', 'val pr auc', 'val iou'
    ]
    save_file_name = f"../history/{hash_code}/history.csv"
    save_dir = f"../history/{hash_code}"

    # Saving images and models directories
    model_save_dir = f"../history/{hash_code}/saved_models"
    image_save_path = f"../history/{hash_code}/result_images"

    # Train
    if mode == 'train':
        print("Initializing Training!")
        min_loss = 1e9
        best_epoch = 0
        early_stop_count = 0
        max_count = 5
        for i in range(n_epochs):
            train_loss, train_acc = train_model(i, model, DRIVE_train_load,
                                                criterion, optimizer, device)
            print('Epoch', str(i + 1), 'Train loss:', train_loss, 'Train acc:',
                  train_acc)

            if (i + 1) % 5 == 0:
                # val_loss, val_acc = validate_model(model, DRIVE_val_load, criterion, device)
                # print('Epoch', str(i+1),
                #       'Val loss:', val_loss,
                #       'Val acc:', val_acc,
                #       )

                (val_loss, val_auc, val_accuracy, val_jaccard, val_sensitivity,
                 val_specitivity, val_precision, val_f1, val_pr_auc,
                 val_iou) = test_model(model, DRIVE_val_load, criterion,
                                       test_border_masks,
                                       f'{image_save_path}/{i+1}/', device)

                print('Epoch', str(i + 1), 'Val loss:', val_loss,
                      'Val acc:', val_accuracy)

                values = [
                    i + 1, train_loss, train_acc, val_loss, val_auc,
                    val_accuracy, val_jaccard, val_sensitivity,
                    val_specitivity, val_precision, val_f1, val_pr_auc, val_iou
                ]
                export_history(header, values, save_dir, save_file_name)

                if val_loss < min_loss:
                    early_stop_count = 0
                    min_loss = val_loss
                    best_epoch = i
                    save_model(model, model_save_dir, i + 1)
                else:
                    early_stop_count += 1
                    if early_stop_count > max_count:
                        print('Training has not improved since epoch {}\tBest loss: {}'
                              .format(best_epoch, min_loss))
                        break
    else:
        print("Initializing Testing!")
        model.load_state_dict(torch.load(modelpath))
        model.eval()
        test_auc, test_accuracy, test_jaccard, test_sensitivity, test_specitivity, \
            test_precision, test_f1, test_pr_auc \
            = test_model_img(model,
                             DRIVE_test_load,
                             test_border_masks,
                             f'{image_save_path}/',
                             device)
        print('Test auc: ', test_auc)
        print('Test accuracy: ', test_accuracy)
        print('Test jaccard: ', test_jaccard)
        print('Test sensitivity: ', test_sensitivity)
        print('Test specificity: ', test_specitivity)
        print('Test precision: ', test_precision)
        print('Test f1: ', test_f1)
        print('Test PR-AUC: ', test_pr_auc)
        print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
        print('AUC mean: {} - std: {}'.format(*mean_std(test_auc)))
        print('ACCURACY mean: {} - std: {}'.format(*mean_std(test_accuracy)))
        print('JACCARD mean: {} - std: {}'.format(*mean_std(test_jaccard)))
        print('SENS mean: {} - std: {}'.format(*mean_std(test_sensitivity)))
        print('SPEC mean: {} - std: {}'.format(*mean_std(test_specitivity)))
        print('PRECISION mean: {} - std: {}'.format(*mean_std(test_precision)))
        print('F1 mean: {} - std: {}'.format(*mean_std(test_f1)))
        print('PR AUC mean: {} - std: {}'.format(*mean_std(test_pr_auc)))
        print("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
Example #11
def train_net(net, options):

    data_path = options.data_path + '240dataset/'
    csv_file = options.data_path + 'new_train.csv'
    origin_spacing_data_path = options.data_path + 'origin_spacing_croped/'

    # z_size is the random crop size along the z-axis; set it larger if you
    # have enough GPU memory
    trainset = BrainDataset(csv_file,
                            data_path,
                            data_path,
                            mode='train',
                            z_size=40)
    trainLoader = data.DataLoader(trainset,
                                  batch_size=options.batch_size,
                                  shuffle=True,
                                  num_workers=0)

    test_data_list, test_label_list = load_test_data(origin_spacing_data_path)

    writer = SummaryWriter(options.log_path + options.unique_name)

    optimizer = optim.SGD(net.parameters(),
                          lr=options.lr,
                          momentum=0.9,
                          weight_decay=0.0005)

    org_weight = torch.FloatTensor(options.org_weight).unsqueeze(1).cuda()
    criterion_fl = FocalLoss(10, alpha=org_weight)
    criterion_dl = DiceLoss()

    best_dice = 0
    for epoch in range(options.epochs):
        print('Starting epoch {}/{}'.format(epoch + 1, options.epochs))
        epoch_loss = 0

        current_lr = multistep_lr_scheduler_with_warmup(
            optimizer,
            init_lr=options.lr,
            epoch=epoch,
            warmup_epoch=5,
            lr_decay_epoch=[200, 400],
            max_epoch=options.epochs,
            gamma=0.1)
        print('current lr:', current_lr)

        net.train()
        for i, (img, label, weight) in enumerate(trainLoader, 0):

            img = img.cuda()
            label = label.cuda()
            weight = weight.cuda()

            end = time.time()

            optimizer.zero_grad()

            result = net(img)

            if options.rlt > 0:
                loss = (criterion_fl(result, label, weight)
                        + options.rlt * criterion_dl(result, label, weight))
            else:
                loss = criterion_dl(result, label, weight)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            batch_time = time.time() - end
            print('batch loss: %.5f, batch_time:%.5f' %
                  (loss.item(), batch_time))
        print('[epoch %d] epoch loss: %.5f' % (epoch + 1, epoch_loss /
                                               (i + 1)))

        writer.add_scalar('Train/Loss', epoch_loss / (i + 1), epoch + 1)
        writer.add_scalar('LR', current_lr, epoch + 1)

        if not os.path.isdir('%s%s/' % (options.cp_path, options.unique_name)):
            os.mkdir('%s%s/' % (options.cp_path, options.unique_name))

        if (epoch + 1) % 10 == 0:
            torch.save(
                net.state_dict(), '%s%s/CP%d.pth' %
                (options.cp_path, options.unique_name, epoch + 1))

        avg_dice, dice_list = validation(net, test_data_list, test_label_list)
        writer.add_scalar('Test/AVG_Dice', avg_dice, epoch + 1)
        for idx in range(9):
            writer.add_scalar('Test/Dice%d' % (idx + 1), dice_list[idx],
                              epoch + 1)

        if avg_dice >= best_dice:
            best_dice = avg_dice
            torch.save(
                net.state_dict(),
                '%s%s/best.pth' % (options.cp_path, options.unique_name))

        print('save done')
        print('dice: %.5f/best dice: %.5f' % (avg_dice, best_dice))
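# Note: multistep_lr_scheduler_with_warmup is project code that returns the lr
# it just applied. A sketch with the same call signature -- an assumption about
# its behavior (linear warmup, then step decay at the milestone epochs):
def multistep_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch,
                                       lr_decay_epoch, max_epoch, gamma=0.1):
    # max_epoch is accepted for signature parity but unused in this sketch
    if epoch < warmup_epoch:
        lr = init_lr * (epoch + 1) / warmup_epoch
    else:
        lr = init_lr * gamma ** sum(epoch >= m for m in lr_decay_epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr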
Example #12
    valid_idx = fold[1]
    train_set = Dataset([folder_paths[i] for i in train_idx],
                        [folder_ids[i] for i in train_idx])
    valid_set = Dataset([folder_paths[i] for i in valid_idx],
                        [folder_ids[i] for i in valid_idx])
    train_loader = data.DataLoader(train_set, **params)
    valid_loader = data.DataLoader(valid_set, **params)

    # Model
    model = Modified3DUNet(in_channels, n_classes, base_n_filter)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    # Loss and optimizer
    #criterion = torch.nn.CrossEntropyLoss().to(device) # For cross entropy
    criterion = DiceLoss()  # For dice loss
    optimizer = torch.optim.Adam(model.parameters())

    # Load model and optimizer parameters if the training was interrupted and must be continued - need to also change epoch range in for loop
    # checkpoint = torch.load("/content/drive/My Drive/Brats2019/Model_Saves_KFold/Fold_1_Epoch_30_Train_Loss_0.0140_Valid_Loss_0.0137.tar")
    # model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    #model.train()
    for epoch in range(1, max_epochs + 1):
        start_time = time.time()
        train_losses = []
        for batch, labels in train_loader:
            # Data augmentation (currently disabled)
            # augmenter = DataAugment(batch, labels)
            # batch, labels = augmenter.augment()
Example #13
        'lr': 5e-3,
        'weight_decay': 0.00003
    },
    {
        'params': model.encoder.parameters(),
        'lr': 5e-3 / 10,
        'weight_decay': 0.00003
    },
])

optimizer = contrib.nn.Lookahead(base_optimizer)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       factor=0.2,
                                                       patience=5)
diceloss = DiceLoss()
loss_list = []
dice_list = []
for epoch in tqdm(range(EPOCHES)):
    # Train
    train_loss = 0
    for data in d_train:
        optimizer.zero_grad()
        img, mask = data
        img = img.to(DEVICE)
        mask = mask.to(DEVICE)

        outputs = model(img)
        dice = dice_coef(outputs, mask)
        loss = diceloss(outputs, mask)
        loss.backward()
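# Note: dice_coef above is a metric, not a loss, and its source is not shown. A
# minimal sketch, assuming the model outputs logits that are thresholded after
# a sigmoid (a hypothetical implementation, not this repository's code):
import torch

def dice_coef(outputs, mask, smooth=1e-5):
    with torch.no_grad():
        pred = (torch.sigmoid(outputs) > 0.5).float()
        intersection = (pred * mask).sum()
        return (2.0 * intersection + smooth) / (pred.sum() + mask.sum() + smooth)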
Example #14
def main():

    work_dir = os.path.join(args.work_dir, args.exp)
    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # train
    trn_image_root = os.path.join(args.trn_root, 'images')
    exam_ids = os.listdir(trn_image_root)
    random.shuffle(exam_ids)
    train_exam_ids = exam_ids

    #train_exam_ids = exam_ids[:int(len(exam_ids)*0.8)]
    #val_exam_ids = exam_ids[int(len(exam_ids) * 0.8):]

    # train_dataset
    trn_dataset = DatasetTrain(args.trn_root,
                               train_exam_ids,
                               options=args,
                               input_stats=[0.5, 0.5])
    trn_loader = torch.utils.data.DataLoader(trn_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.num_workers)

    # save input stats for later use
    np.save(os.path.join(work_dir, 'input_stats.npy'), trn_dataset.input_stats)

    # val
    val_image_root = os.path.join(args.val_root, 'images')
    val_exam = os.listdir(val_image_root)
    random.shuffle(val_exam)
    val_exam_ids = val_exam

    # val_dataset
    val_dataset = DatasetVal(args.val_root,
                             val_exam_ids,
                             options=args,
                             input_stats=trn_dataset.input_stats)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.num_workers)

    # make logger
    trn_logger = Logger(os.path.join(work_dir, 'train.log'))
    trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
    val_logger = Logger(os.path.join(work_dir, 'validation.log'))

    # model_select
    if args.model == 'unet':
        net = UNet3D(1,
                     1,
                     f_maps=args.f_maps,
                     depth_stride=args.depth_stride,
                     conv_layer_order=args.conv_layer_order,
                     num_groups=args.num_groups)

    else:
        raise ValueError('Unsupported network: {}'.format(args.model))

    # loss_select
    if args.loss_function == 'bce':
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.Tensor([args.bce_weight])).cuda()
    elif args.loss_function == 'dice':
        criterion = DiceLoss().cuda()
    elif args.loss_function == 'weight_bce':
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.FloatTensor([5])).cuda()
    else:
        raise ValueError('{} loss is not supported yet.'.format(
            args.loss_function))

    # optim_select
    if args.optim == 'sgd':
        optimizer = optim.SGD(net.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay,
                              nesterov=False)

    elif args.optim == 'adam':
        optimizer = optim.Adam(net.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise ValueError('{} optim is not supported yet.'.format(args.optim))

    net = nn.DataParallel(net).cuda()
    cudnn.benchmark = True

    lr_schedule = args.lr_schedule
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                  milestones=lr_schedule[:-1],
                                                  gamma=0.1)

    best_iou = 0
    for epoch in range(lr_schedule[-1]):

        train(trn_loader, net, criterion, optimizer, epoch, trn_logger,
              trn_raw_logger)
        iou = validate(val_loader, net, criterion, epoch, val_logger)

        lr_scheduler.step()

        # save model parameter
        is_best = iou > best_iou
        best_iou = max(iou, best_iou)
        checkpoint_filename = 'model_checkpoint_{:0>3}.pth'.format(epoch + 1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }, is_best, work_dir, checkpoint_filename)

    # visualize curve
    draw_curve(work_dir, trn_logger, val_logger)

    if args.inplace_test:
        # calc overall performance and save figures
        print('Test mode ...')
        main_test(model=net, args=args)
Example #15
def main():
    print(args.work_dir, args.exp)
    work_dir = os.path.join(args.work_dir, args.exp)

    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # 1.dataset
    train_filename = args.trn_root
    test_filename = args.test_root

    trainset = Segmentation_2d_data(train_filename)
    valiset = Segmentation_2d_data(test_filename)

    train_loader = data.DataLoader(trainset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
    valid_loader = data.DataLoader(valiset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)

    trn_logger = Logger(os.path.join(work_dir, 'train.log'))
    trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
    val_logger = Logger(os.path.join(work_dir, 'validation.log'))

    if args.model == 'unet':
        net = Unet2D(in_shape=(1, 512, 512), padding=args.padding_size,
                     momentum=args.batchnorm_momentum)
    elif args.model == 'unetcoord':
        net = Unet2D_coordconv(in_shape=(1, 512, 512), padding=args.padding_size,
                               momentum=args.batchnorm_momentum,
                               coordnumber=args.coordconv_no, radius=False)
    elif args.model == 'unetmultiinput':
        net = Unet2D_multiinput(in_shape=(1, 512, 512), padding=args.padding_size,
                                momentum=args.batchnorm_momentum)
    elif args.model == 'scse_block':
        net = Unet_sae(in_shape=(1, 512, 512), padding=args.padding_size,
                       momentum=args.batchnorm_momentum)
    else:
        raise ValueError('Unsupported network: {}'.format(args.model))

    # loss
    if args.loss_function == 'bce':
        criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([args.bce_weight])).cuda()
    elif args.loss_function == 'dice':
        criterion = DiceLoss().cuda()
    else:
        raise ValueError('{} loss is not supported yet.'.format(args.loss_function))

    # optim
    if args.optim_function == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(), lr=args.initial_lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optim_function == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.initial_lr,
                                     weight_decay=args.weight_decay)
    elif args.optim_function == 'radam':
        optimizer = RAdam(net.parameters(), lr=args.initial_lr,
                          weight_decay=args.weight_decay)
    else:
        raise ValueError('{} optimizer is not supported yet.'.format(
            args.optim_function))

    net = nn.DataParallel(net).cuda()

    cudnn.benchmark = True

    lr_schedule = args.lr_schedule
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                  milestones=lr_schedule[:-1],
                                                  gamma=0.1)
    best_iou = 0

    for epoch in range(lr_schedule[-1]):
        train(train_loader, net, criterion, optimizer, epoch, trn_logger, trn_raw_logger)

        iou = validate(valid_loader, net, criterion, epoch, val_logger)
        lr_scheduler.step()

        is_best = iou > best_iou
        best_iou = max(iou, best_iou)
        checkpoint_filename = 'model_checkpoint_{:0>3}.pth'.format(epoch + 1)
        save_checkpoint({'epoch': epoch + 1,
                         'state_dict': net.state_dict(),
                         'optimizer': optimizer.state_dict()},
                        is_best,
                        work_dir,
                        checkpoint_filename)

    draw_curve(work_dir, trn_logger, val_logger)
Example #16
def get_lossfn():
    bce = nn.BCEWithLogitsLoss()
    dice = DiceLoss(mode='binary', log_loss=True, smooth=1e-7)
    focal = losses.BinaryFocalLoss(alpha=0.25, reduced_threshold=0.5)
    criterion = ImanipLoss(bce, seglossA=dice, seglossB=focal)
    return criterion
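# Note: ImanipLoss is project-specific. A hedged sketch of a composite wrapper
# that sums a BCE term with the two segmentation losses -- an assumption about
# its behavior, with hypothetical weight parameters:
import torch.nn as nn

class ImanipLossSketch(nn.Module):
    def __init__(self, bce, seglossA, seglossB, w_bce=1.0, w_a=1.0, w_b=1.0):
        super().__init__()
        self.bce, self.seglossA, self.seglossB = bce, seglossA, seglossB
        self.w_bce, self.w_a, self.w_b = w_bce, w_a, w_b

    def forward(self, logits, target):
        # weighted sum of the classification and segmentation terms
        return (self.w_bce * self.bce(logits, target)
                + self.w_a * self.seglossA(logits, target)
                + self.w_b * self.seglossB(logits, target))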