Example #1
def do_epoch(dm, model, optimizer, criterion, lr_sched, mode='train'):
    dc = []
    L = []
    # lr_sched.step()
    if mode == 'train':
        model.train()
        for idx, sample in enumerate(dm):
            X = Variable(sample['sat'].to(device))
            Y = Variable(sample['mask'].to(device))

            optimizer.zero_grad()
            y_pred = model(X)
            y_pred = torch.nn.functional.sigmoid(y_pred)
            loss = criterion(y_pred, Y)
            loss.backward()
            optimizer.step()

            l = loss.detach().cpu().numpy()

            dc.append(dice_coeff(y_pred, Y))
            L.append(l)

            print('\r', 'Epoch: ', epoch,
                  'step:', idx, '|', len(dm),
                  'Loss: %.4f' % np.mean(L),
                  'Dice: %.4f' % np.mean(dc),
                  end=' ')

    elif mode == 'valid':
        model.eval()
        with torch.no_grad():
            for idx, sample in enumerate(dm):
                X = Variable(sample['sat'].cuda())
                Y = Variable(sample['mask'].cuda())

                y_pred = model(X)
                y_pred = torch.nn.functional.sigmoid(y_pred)

                dc.append(dice_coeff(y_pred, Y))

                print('\r', idx, '|', len(dm),
                      'Dice: %.4f' % np.mean(dc),
                      end=' ')

    if mode == 'valid':
        lr_sched.step(np.mean(dc))
    return np.mean(dc)
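None of these examples define dice_coeff itself. A minimal sketch of a binary soft-Dice helper compatible with the call in do_epoch above (probability maps against {0, 1} masks) might look like the following; the epsilon and the batch-mean reduction are assumptions, not the original implementation.

import torch

def dice_coeff(y_pred, y_true, eps=1e-7):
    # Soft Dice averaged over the batch: probabilities vs. binary masks.
    y_pred = y_pred.contiguous().view(y_pred.size(0), -1)
    y_true = y_true.contiguous().view(y_true.size(0), -1)
    intersection = (y_pred * y_true).sum(dim=1)
    denom = y_pred.sum(dim=1) + y_true.sum(dim=1)
    return ((2.0 * intersection + eps) / (denom + eps)).mean().item()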
Example #2
def make_train_step(idx, data, model, optimizer, criterion, meters):

    # get the inputs and wrap in Variable
    if torch.cuda.is_available():
        inputs = Variable(data['sat_img'].cuda())
        labels = Variable(data['map_img'].cuda())
    else:
        inputs = Variable(data['sat_img'])
        labels = Variable(data['map_img'])

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward
    # prob_map = model(inputs) # last activation was a sigmoid
    # outputs = (prob_map > 0.3).float()
    outputs = model(inputs)

    # note: the weighted (Lovász) loss expects logits rather than probabilities
    if args.lovasz_loss:
        loss, BCE_loss, DICE_loss = criterion(outputs, labels)
        outputs = torch.nn.functional.sigmoid(outputs)
    else:
        outputs = torch.nn.functional.sigmoid(outputs)
        loss, BCE_loss, DICE_loss = criterion(outputs, labels)

    # backward
    loss.backward()
    # https://github.com/asanakoy/kaggle_carvana_segmentation/blob/master/albu/src/train.py
    # torch.nn.utils.clip_grad_norm(model.parameters(), 1.)
    optimizer.step()

    meters["train_acc"].update(metrics.dice_coeff(outputs, labels),
                               outputs.size(0))
    meters["train_loss"].update(loss.data[0], outputs.size(0))
    meters["train_IoU"].update(metrics.jaccard_index(outputs, labels),
                               outputs.size(0))
    meters["train_BCE"].update(BCE_loss.data[0], outputs.size(0))
    meters["train_DICE"].update(DICE_loss.data[0], outputs.size(0))
    meters["outputs"] = outputs
    return meters
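A hypothetical way to drive make_train_step for one epoch; train_loader and the MetricTracker-style meters are assumed to be created elsewhere, mirroring Example #6.

meters = {name: metrics.MetricTracker()
          for name in ('train_acc', 'train_loss', 'train_IoU',
                       'train_BCE', 'train_DICE')}
for idx, data in enumerate(train_loader):
    meters = make_train_step(idx, data, model, optimizer, criterion, meters)
print('train loss: %.4f  dice: %.4f'
      % (meters['train_loss'].avg, meters['train_acc'].avg))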
Example #3
def validate(net, loader, device):
    net.eval()

    loss = 0
    with tqdm(total=len(loader),
              dynamic_ncols=True,
              desc="validation",
              unit="batch",
              leave=False) as pbar:
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)

            with torch.no_grad():
                output, _ = net(images)

            output = torch.sigmoid(output)
            output = (output > 0.5).float()

            loss += dice_coeff(output, masks).item()
            pbar.update()

    net.train()
    return loss / len(loader)
Example #4
def compile_model(model, num_classes, metrics, loss, lr):
    from keras.losses import binary_crossentropy
    from keras.losses import categorical_crossentropy

    from keras.metrics import binary_accuracy
    from keras.metrics import categorical_accuracy

    from keras.optimizers import Adam

    from metrics import dice_coeff
    from metrics import jaccard_index
    from metrics import class_jaccard_index
    from metrics import pixelwise_precision
    from metrics import pixelwise_sensitivity
    from metrics import pixelwise_specificity
    from metrics import pixelwise_recall

    from losses import focal_loss

    if isinstance(loss, str):
        if loss in {'ce', 'crossentropy'}:
            if num_classes == 1:
                loss = binary_crossentropy
            else:
                loss = categorical_crossentropy
        elif loss in {'focal', 'focal_loss'}:
            loss = focal_loss(num_classes)
        else:
            raise ValueError('unknown loss %s' % loss)

    if isinstance(metrics, str):
        metrics = [metrics, ]

    for i, metric in enumerate(metrics):
        if not isinstance(metric, str):
            continue
        elif metric == 'acc':
            metrics[i] = binary_accuracy if num_classes == 1 else categorical_accuracy
        elif metric == 'jaccard_index':
            metrics[i] = jaccard_index(num_classes)
        elif metric == 'jaccard_index0':
            metrics[i] = class_jaccard_index(0)
        elif metric == 'jaccard_index1':
            metrics[i] = class_jaccard_index(1)
        elif metric == 'jaccard_index2':
            metrics[i] = class_jaccard_index(2)
        elif metric == 'jaccard_index3':
            metrics[i] = class_jaccard_index(3)
        elif metric == 'jaccard_index4':
            metrics[i] = class_jaccard_index(4)
        elif metric == 'jaccard_index5':
            metrics[i] = class_jaccard_index(5)
        elif metric == 'dice_coeff':
            metrics[i] = dice_coeff(num_classes)
        elif metric == 'pixelwise_precision':
            metrics[i] = pixelwise_precision(num_classes)
        elif metric == 'pixelwise_sensitivity':
            metrics[i] = pixelwise_sensitivity(num_classes)
        elif metric == 'pixelwise_specificity':
            metrics[i] = pixelwise_specificity(num_classes)
        elif metric == 'pixelwise_recall':
            metrics[i] = pixelwise_recall(num_classes)
        else:
            raise ValueError('metric %s not recognized' % metric)

    model.compile(optimizer=Adam(lr=lr),
                  loss=loss,
                  metrics=metrics)
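A hypothetical call to compile_model; build_unet stands in for whatever function returns the uncompiled Keras model in the original project.

model = build_unet(input_shape=(256, 256, 3), num_classes=1)  # assumed model builder
compile_model(model,
              num_classes=1,
              metrics=['acc', 'dice_coeff', 'jaccard_index'],
              loss='crossentropy',
              lr=1e-4)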
Example #5
def test_model():
    MAX = 0
    save = 0
    # for index in range(100):
    model = Insensee_3Dunet(1).to(device)
    model.load_state_dict(torch.load('./3dunet_model_save/weights_390.pth'))
    test_dataset = MRIdataset.LiverDataset(MRIdataset.test_imagepath, MRIdataset.test_labelpath,
                                           MRIdataset.testimg_ids, MRIdataset.testlabel_ids)

    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

    train_dataset = MRIdataset.LiverDataset(MRIdataset.imagepath, MRIdataset.labelpath,
                                           MRIdataset.img_ids, MRIdataset.label_ids)

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=4)



    # train_iter = enumerate(train_loader)
    model.eval()
    LV_Dice = 0
    LV_Jac = 0
    RV_Dice = 0
    RV_Jac = 0
    Myo_Dice = 0
    Myo_Jac = 0
    i = 0

    # _, batch = train_iter.__next__()
    # _, batch = train_iter.__next__()
    # img, label= batch
    # i = 1
    # # print(img.shape)
    # # img =torch.squeeze(img)
    # # img = img.numpy().astype(float)
    # # img = img.transpose(1, 2, 0)
    # # new_image = nib.Nifti1Image(img, np.eye(4))
    # # nib.save(new_image, r'/home/peng/Desktop/CROP/train2test/trainimg_%d.nii.gz' % i)
    #
    # img = img.to(device)
    # pred = torch.argmax(model(img), 1)
    # pred = pred.cpu()
    # pred = torch.squeeze(pred)
    # pred = pred.numpy().astype(float)
    # pred = pred.transpose(1, 2, 0)
    #
    # label = torch.squeeze(label)
    # label = label.numpy().astype(float)
    # label = label.transpose(1, 2, 0)
    #
    #
    # new_image = nib.Nifti1Image(pred, np.eye(4))
    # nib.save(new_image, r'/home/peng/Desktop/CROP/train2test/pred_%d.nii.gz' % i)
    #
    # new_label = nib.Nifti1Image(label, np.eye(4))
    # nib.save(new_label, r'/home/peng/Desktop/CROP/train2test/label_%d.nii.gz' % i)

    for img, label in test_loader:
        img = img.to(device)
        pred = torch.argmax(model(img), 1)
        pred = pred.cpu()

        LV_dice, LV_jac, RV_dice, RV_jac, Myo_dice, Myo_jac = dice_coeff(pred, label)

        LV_Dice += LV_dice
        LV_Jac += LV_jac
        RV_Dice += RV_dice
        RV_Jac += RV_jac
        Myo_Dice += Myo_dice
        Myo_Jac += Myo_jac

        print('LV_Dice_%d:' % i, '%.6f' % LV_dice, '||',
              'RV_Dice_%d:' % i, '%.6f' % RV_dice, '||',
              'Myo_Dice_%d:' % i, '%.6f' % Myo_dice)

        i += 1
    print('===============================================')
    print('LV_Dice_avg:', LV_Dice / i, 'RV_Dice_avg:', RV_Dice / i, 'Myo_Dice_avg:', Myo_Dice / i)
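Here dice_coeff is unpacked into per-structure Dice and Jaccard scores, so it must be a multi-class variant. A sketch under the assumption that the integer label maps use 1 = LV, 2 = RV, 3 = Myo (the original metrics module is not shown):

import torch

def dice_coeff(pred, label, eps=1e-7):
    # Per-class Dice and Jaccard for integer label maps; class ids are assumed.
    scores = []
    for cls in (1, 2, 3):  # assumed: 1 = LV, 2 = RV, 3 = Myo
        p = (pred == cls).float()
        g = (label == cls).float()
        inter = (p * g).sum()
        dice = (2.0 * inter + eps) / (p.sum() + g.sum() + eps)
        jac = (inter + eps) / (p.sum() + g.sum() - inter + eps)
        scores += [dice.item(), jac.item()]
    return tuple(scores)  # (LV_dice, LV_jac, RV_dice, RV_jac, Myo_dice, Myo_jac)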
Example #6
def validation(valid_loader, model, criterion, logger, epoch_num):
    """

    Args:
        train_loader:
        model:
        criterion:
        optimizer:
        epoch:

    Returns:

    """
    # logging accuracy and loss
    valid_acc = metrics.MetricTracker()
    valid_loss = metrics.MetricTracker()
    valid_IoU = metrics.MetricTracker()
    valid_BCE = metrics.MetricTracker()
    valid_DICE = metrics.MetricTracker()

    log_iter = len(valid_loader) // logger.print_freq

    # switch to evaluate mode
    model.eval()

    # Iterate over data.
    for idx, data in enumerate(tqdm(valid_loader, desc='validation')):

        # get the inputs and wrap in Variable
        if torch.cuda.is_available():
            inputs = Variable(data['sat_img'].cuda(), volatile=True)
            labels = Variable(data['map_img'].cuda(), volatile=True)
        else:
            inputs = Variable(data['sat_img'], volatile=True)
            labels = Variable(data['map_img'], volatile=True)

        # forward
        # prob_map = model(inputs) # last activation was a sigmoid
        # outputs = (prob_map > 0.3).float()
        outputs = model(inputs)

        # note: the weighted (Lovász) loss expects logits rather than probabilities
        if args.lovasz_loss:
            loss, BCE_loss, DICE_loss = criterion(outputs, labels)
            outputs = torch.nn.functional.sigmoid(outputs)
        else:
            outputs = torch.nn.functional.sigmoid(outputs)
            loss, BCE_loss, DICE_loss = criterion(outputs, labels)

        valid_acc.update(metrics.dice_coeff(outputs, labels), outputs.size(0))
        valid_loss.update(loss.data[0], outputs.size(0))
        valid_IoU.update(metrics.jaccard_index(outputs, labels),
                         outputs.size(0))
        valid_BCE.update(BCE_loss.data[0], outputs.size(0))
        valid_DICE.update(DICE_loss.data[0], outputs.size(0))

        # tensorboard logging
        if idx % log_iter == 0:

            step = (epoch_num * logger.print_freq) + (idx / log_iter)

            # log accuracy and loss
            info = {
                'loss': valid_loss.avg,
                'accuracy': valid_acc.avg,
                'IoU': valid_IoU.avg
            }

            for tag, value in info.items():
                logger.scalar_summary(tag, value, step)

            # log the sample images
            log_img = [
                data_utils.show_tensorboard_image(data['sat_img'],
                                                  data['map_img'],
                                                  outputs,
                                                  as_numpy=True),
            ]
            logger.image_summary('valid_images', log_img, step)

    print(
        'Validation Loss: {:.4f} BCE: {:.4f} DICE: {:.4f} Acc: {:.4f} IoU: {:.4f}'
        .format(valid_loss.avg, valid_BCE.avg, valid_DICE.avg, valid_acc.avg,
                valid_IoU.avg))
    print()

    return {
        'valid_loss': valid_loss.avg,
        'valid_acc': valid_acc.avg,
        'valid_IoU': valid_IoU.avg,
        'valid_BCE': valid_BCE.avg,
        'valid_DICE': valid_DICE.avg
    }
Example #7
def val_model(model):
    test_dataset = MRIdataset.LiverDataset(MRIdataset.test_imagepath,
                                           MRIdataset.test_labelpath,
                                           MRIdataset.testimg_ids,
                                           MRIdataset.testlabel_ids, False)

    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=4)

    model.eval()
    LV_Dice = 0
    RV_Dice = 0
    Myo_Dice = 0
    i = 0

    for img, label, _, _ in test_loader:
        if i < 50:
            img = torch.squeeze(img)
            label = torch.squeeze(label)
            LV_dice = 0
            RV_dice = 0
            Myo_dice = 0
            for z in range(img.shape[0]):
                img_2d = img[z, :, :]
                label_2d = label[z, :, :]
                img_2d = torch.unsqueeze(img_2d, 0)
                img_2d = torch.unsqueeze(img_2d, 0)
                img_2d = img_2d.to(device)
                output = model(img_2d)
                # print(output.shape)
                pred = torch.argmax(output, 1)
                # print(pred.shape, label.shape)
                pred = pred.cpu()

                LV_dice_2d, LV_jac_2d, RV_dice_2d, RV_jac_2d, Myo_dice_2d, Myo_jac_2d = dice_coeff(
                    pred, label_2d)
                LV_dice += LV_dice_2d
                RV_dice += RV_dice_2d
                Myo_dice += Myo_dice_2d

            LV_dice /= img.shape[0]
            RV_dice /= img.shape[0]
            Myo_dice /= img.shape[0]

            LV_Dice += LV_dice
            RV_Dice += RV_dice
            Myo_Dice += Myo_dice

            # print('LV_Dice_%d:' % i, '%.6f' % LV_dice, '||', 'RV_Dice_%d:' % i, '%.6f' % RV_dice, '||'
            #       , 'Myo_Dice_%d:' % i, '%.6f' % Myo_dice)

            i += 1
    print('===============================================')
    LV_Dice_avg = LV_Dice / i
    RV_Dice_avg = RV_Dice / i
    Myo_Dice_avg = Myo_Dice / i
    Mean_metric = (LV_Dice_avg + RV_Dice_avg + Myo_Dice_avg) / 3
    print('Mean_metric:', Mean_metric)

    return Mean_metric
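val_model returns a single mean Dice, which makes it convenient as a checkpointing criterion. A hypothetical loop around it (train_one_epoch, num_epochs, and the save path are assumed names, not part of the original script):

best_metric = 0.0
for epoch in range(num_epochs):
    train_one_epoch(model)                      # assumed training step
    mean_metric = val_model(model)
    if mean_metric > best_metric:               # keep the best-scoring weights
        best_metric = mean_metric
        torch.save(model.state_dict(), './model_save/BEST_weights.pth')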
Example #8
def dice_loss(y_true, y_pred):
    loss = 1 - dice_coeff(y_true, y_pred)
    return loss
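dice_loss above simply inverts a Dice coefficient, so it needs a differentiable dice_coeff in the same (Keras) style. A minimal backend sketch, with the smoothing constant as an assumption:

from keras import backend as K

def dice_coeff(y_true, y_pred, smooth=1.0):
    # Soft Dice on flattened tensors; smooth avoids division by zero.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)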
Example #9
def test_model():
    MAX = 0
    save = 0
    # for index in range(100):
    # model = DeepSupervision_U_Net(1, 4).to(device)
    model = ResNetUNet(4).to(device)
    # print(model)
    model.load_state_dict(
        torch.load(
            '/home/laisong/MMs/LS/3dunet_model_save/BEST_weights_150.pth'))
    test_dataset = MRIdataset.LiverDataset(MRIdataset.test_imagepath,
                                           MRIdataset.test_labelpath,
                                           MRIdataset.testimg_ids,
                                           MRIdataset.testlabel_ids, False)

    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=4)

    train_dataset = MRIdataset.LiverDataset(MRIdataset.imagepath,
                                            MRIdataset.labelpath,
                                            MRIdataset.img_ids,
                                            MRIdataset.label_ids, False)

    train_loader = DataLoader(train_dataset,
                              batch_size=1,
                              shuffle=False,
                              num_workers=4)

    B_dataset = MRIdataset.LiverDataset(MRIdataset.B_imagepath,
                                        MRIdataset.B_labelpath,
                                        MRIdataset.Bimg_ids,
                                        MRIdataset.Blabel_ids, False)

    B_loader = DataLoader(B_dataset,
                          batch_size=1,
                          shuffle=False,
                          num_workers=4)

    # train_iter = enumerate(train_loader)
    model.eval()
    LV_Dice = 0
    LV_Jac = 0
    RV_Dice = 0
    RV_Jac = 0
    Myo_Dice = 0
    Myo_Jac = 0
    i = 0

    # _, batch = train_iter.__next__()
    # _, batch = train_iter.__next__()
    # img, label= batch
    # i = 1
    # # print(img.shape)
    # # img =torch.squeeze(img)
    # # img = img.numpy().astype(float)
    # # img = img.transpose(1, 2, 0)
    # # new_image = nib.Nifti1Image(img, np.eye(4))
    # # nib.save(new_image, r'/home/peng/Desktop/CROP/train2test/trainimg_%d.nii.gz' % i)
    #
    # img = img.to(device)
    # pred = torch.argmax(model(img), 1)
    # pred = pred.cpu()
    # pred = torch.squeeze(pred)
    # pred = pred.numpy().astype(float)
    # pred = pred.transpose(1, 2, 0)
    #
    # label = torch.squeeze(label)
    # label = label.numpy().astype(float)
    # label = label.transpose(1, 2, 0)
    #
    #
    # new_image = nib.Nifti1Image(pred, np.eye(4))
    # nib.save(new_image, r'/home/peng/Desktop/CROP/train2test/pred_%d.nii.gz' % i)
    #
    # new_label = nib.Nifti1Image(label, np.eye(4))
    # nib.save(new_label, r'/home/peng/Desktop/CROP/train2test/label_%d.nii.gz' % i)

    for img, label, _, _ in test_loader:
        if i < 50:
            img = torch.squeeze(img)
            label = torch.squeeze(label)
            LV_dice = 0
            RV_dice = 0
            Myo_dice = 0
            for z in range(img.shape[0]):
                img_2d = img[z, :, :]
                label_2d = label[z, :, :]
                img_2d = torch.unsqueeze(img_2d, 0)
                img_2d = torch.unsqueeze(img_2d, 0)
                img_2d = img_2d.to(device)
                output = model(img_2d)
                # print(output.shape)
                pred = torch.argmax(output, 1)
                # print(pred.shape, label.shape)
                pred = pred.cpu()

                LV_dice_2d, LV_jac_2d, RV_dice_2d, RV_jac_2d, Myo_dice_2d, Myo_jac_2d = dice_coeff(
                    pred, label_2d)
                LV_dice += LV_dice_2d
                RV_dice += RV_dice_2d
                Myo_dice += Myo_dice_2d

            LV_dice /= img.shape[0]
            RV_dice /= img.shape[0]
            Myo_dice /= img.shape[0]

            LV_Dice += LV_dice
            RV_Dice += RV_dice
            Myo_Dice += Myo_dice

            print('LV_Dice_%d:' % i, '%.6f' % LV_dice, '||', 'RV_Dice_%d:' % i,
                  '%.6f' % RV_dice, '||', 'Myo_Dice_%d:' % i,
                  '%.6f' % Myo_dice)

            i += 1
    print('===============================================')
    print('LV_Dice_avg:', LV_Dice / i, 'RV_Dice_avg:', RV_Dice / i,
          'Myo_Dice_avg:', Myo_Dice / i)