def test():
    """Denoise every test image and save noisy / denoised / ground-truth PNGs.

    All settings come from the module-level ``conf`` object: ``cuda`` (device
    string used when CUDA is available), ``data_path_test``,
    ``test_noise_param``, ``crop_img_size``, ``model_path_test``,
    ``img_channel`` and ``denoised_dir``. For each test image three files are
    written to ``conf.denoised_dir``: ``<name>-noisy.png``,
    ``<name>-denoised.png`` and ``<name>-ground_truth.png``.
    """
    device = torch.device(conf.cuda if torch.cuda.is_available() else "cpu")
    test_dataset = Testinging_Dataset(conf.data_path_test,
                                      conf.test_noise_param,
                                      conf.crop_img_size)
    # batch_size=1 with shuffle=False keeps batch_idx aligned with
    # dataset.image_list, which is used below to name the output files.
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    print('Loading model from: {}'.format(conf.model_path_test))
    model = UNet(in_channels=conf.img_channel, out_channels=conf.img_channel)
    print('loading model')
    # map_location lets a checkpoint saved on GPU load on a CPU-only host
    # (device already falls back to "cpu" above).
    model.load_state_dict(torch.load(conf.model_path_test, map_location=device))
    model.eval()
    model.to(device)
    result_dir = conf.denoised_dir
    os.makedirs(result_dir, exist_ok=True)
    # Inference only: no autograd graph is needed for any image.
    with torch.no_grad():
        for batch_idx, (source, img_cropped) in enumerate(test_loader):
            source_img = tvF.to_pil_image(source.squeeze(0))
            img_truth = img_cropped.squeeze(0).numpy().astype(np.uint8)
            denoised_img = model(source.to(device)).cpu()

            img_name = test_loader.dataset.image_list[batch_idx]
            fname = os.path.splitext(img_name)[0]

            # Clamp to [0, 1] before converting the float tensor to a PIL image.
            denoised_result = tvF.to_pil_image(
                torch.clamp(denoised_img.squeeze(0), 0, 1))

            source_img.save(os.path.join(result_dir, f'{fname}-noisy.png'))
            denoised_result.save(
                os.path.join(result_dir, f'{fname}-denoised.png'))
            io.imsave(os.path.join(result_dir, f'{fname}-ground_truth.png'),
                      img_truth)
Esempio n. 2
0
def main(args):
    """Predict with a two-stage UNet cascade and report overall metrics.

    ``model`` (weights from ``args.pretrain_net``) runs on each input batch.
    Its output is concatenated with the input along dim 1 and fed to
    ``model2`` (weights from ``args.pretrain_net2``). For every image, the
    per-pixel argmax of ``model2`` is compared with the label PNG under
    ``args.label_root_dir``. When the prediction has foreground, it is saved
    as a 0/255 PNG under ``args.preDir``. The accumulated confusion matrix
    and the ``evaluate`` results are printed at the end.

    Raises:
        Exception: if ``args.cuda`` is set but no GPU is available.
    """
    if args.cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")
    # Map checkpoint tensors to the device that will run them, so weights
    # saved on a GPU still load when running without --cuda.
    map_location = 'cuda' if args.cuda else 'cpu'
    model = UNet(n_channels=args.colordim, n_classes=args.num_class)
    model2 = UNet(n_channels=args.colordim2, n_classes=args.num_class2)
    if args.cuda:
        model = model.cuda()
        model2 = model2.cuda()
    model.load_state_dict(torch.load(args.pretrain_net,
                                     map_location=map_location))
    model2.load_state_dict(torch.load(args.pretrain_net2,
                                      map_location=map_location))
    model.eval()
    model2.eval()
    predDataset = generateDataset(args.pre_root_dir,
                                  args.img_size,
                                  args.colordim,
                                  isTrain=False)
    predLoader = DataLoader(dataset=predDataset,
                            batch_size=args.predictbatchsize,
                            num_workers=args.threads)
    with torch.no_grad():
        cm_w = np.zeros((2, 2))
        for batch_x, batch_name in predLoader:
            # Cast on every path: previously only the CUDA branch converted
            # to float, so CPU runs passed the loader's raw dtype to the model.
            batch_x = batch_x.float()
            if args.cuda:
                batch_x = batch_x.cuda()

            out1 = model(batch_x)
            out = model2(torch.cat((batch_x, out1), 1))
            _, pred_label = torch.max(out, 1)
            pred_label_np = pred_label.cpu().numpy()
            for name, pred_label_single in zip(batch_name, pred_label_np):
                predLabel_filename = args.preDir + '/' + name + '.png'
                label_filename = args.label_root_dir + name + '.png'
                label = io.imread(label_filename)
                cm = confusion_matrix(label.ravel(), pred_label_single.ravel())
                pred_label_single = np.where(pred_label_single > 0, 255, 0)
                print(np.max(pred_label_single))
                print(name)
                if np.max(pred_label_single) > 0:
                    io.imsave(predLabel_filename,
                              pred_label_single.astype(np.uint8))
                    # NOTE(review): only images with at least one predicted
                    # foreground pixel contribute to cm_w, so images
                    # predicted as all-background are excluded from the
                    # overall metrics. Confirm this is intended.
                    cm_w = cm_w + cm

        print(cm_w)
        OA_w, F1_w, IoU_w = evaluate(cm_w)
        print('OA_w = ' + str(OA_w) + ', F1_w = ' + str(F1_w) + ', IoU = ' +
              str(IoU_w))
Esempio n. 3
0
        img_type='tif',
        in_size=256)
    test_dataloader = torch.utils.data.DataLoader(
        SRRFDATASET, batch_size=batch_size, shuffle=True,
        pin_memory=True)  # better than for loop

    model = UNet(n_channels=3, n_classes=1)

    print("{} paramerters in total".format(
        sum(x.numel() for x in model.parameters())))
    model.cuda(cuda)
    model.load_state_dict(
        torch.load(
            "/home/star/0_code_lhj/DL-SIM-github/MODELS/UNet_SIM3_microtubule.pkl"
        ))
    # Evaluation mode for the whole loop. The loop previously called
    # model.train() on every batch, which undid this and ran prediction with
    # training-mode layers (e.g. BatchNorm/Dropout, if the UNet has them).
    model.eval()

    for batch_idx, items in enumerate(test_dataloader):
        image = items['image_in']
        image_name = items['image_name']
        print(image_name[0])

        # Reorder (N, H, W, C) -> (N, C, H, W); equivalent to swapping axes
        # (1, 3) and then (2, 3).
        image = image.permute(0, 3, 1, 2).float().cuda(cuda)

        pred = model(image)
        max_out = 15383.0
Esempio n. 4
0
class Train(object):
    """Train a UNet segmentation model with Dice loss, validating each epoch.

    ``configs`` is a dict-like object read with ``.get``. Recognised keys and
    their defaults:

    - ``batch_size`` (16), ``epochs`` (100), ``lr`` (0.0001), ``workers`` (4)
    - ``device`` ("cuda"; CPU is used when CUDA is unavailable)
    - ``vis_images`` (200), ``vis_freq`` (10)
    - ``weights`` ("./weights") and ``logs`` ("./logs"); both are created
    - ``images_path`` ("./data")
    - ``is_resize`` (False), ``image_short_side`` (256), ``is_padding`` (False)
    - ``DateParallel`` / ``DataParallel`` (False): wrap the model in
      ``nn.DataParallel``
    - ``pre_train`` (False) and ``model_path``
      ('./weights/unet_idcard_adam.pth'): initial weights to load

    The best-DSC weights are saved to ``<weights>/unet_xia_adam.pth`` and the
    final weights to ``<weights>/unet_final.pth``.
    """

    def __init__(self, configs):
        # Numeric options are coerced, so string-typed config values work
        # with DataLoader, range() and the optimizer. The previous defaults
        # were themselves strings, which broke those calls.
        self.batch_size = int(configs.get("batch_size", 16))
        self.epochs = int(configs.get("epochs", 100))
        self.lr = float(configs.get("lr", 0.0001))

        device_args = configs.get("device", "cuda")
        self.device = torch.device(
            "cpu" if not torch.cuda.is_available() else device_args)

        self.workers = int(configs.get("workers", 4))

        self.vis_images = int(configs.get("vis_images", 200))
        self.vis_freq = int(configs.get("vis_freq", 10))

        self.weights = configs.get("weights", "./weights")
        os.makedirs(self.weights, exist_ok=True)

        self.logs = configs.get("logs", "./logs")
        os.makedirs(self.logs, exist_ok=True)

        self.images_path = configs.get("images_path", "./data")

        # These options belong to `configs`; they were read from an undefined
        # name `config`.
        self.is_resize = configs.get("is_resize", False)
        self.image_short_side = configs.get("image_short_side", 256)
        self.is_padding = configs.get("is_padding", False)

        # "DateParallel" is the historical (misspelled) key and is still
        # honoured; the correctly spelled "DataParallel" takes precedence.
        is_multi_gpu = configs.get("DataParallel",
                                   configs.get("DateParallel", False))

        pre_train = configs.get("pre_train", False)
        model_path = configs.get("model_path", './weights/unet_idcard_adam.pth')

        self.step = 0

        self.dsc_loss = DiceLoss()
        self.model = UNet(in_channels=Dataset.in_channels,
                          out_channels=Dataset.out_channels)
        if pre_train:
            state_dict = torch.load(model_path, map_location=self.device)
            # Weights saved by this class from a DataParallel-wrapped model
            # carry a "module." prefix. With strict=False, mismatched keys are
            # skipped silently, so strip the prefix before loading.
            self.model.load_state_dict(self._strip_module_prefix(state_dict),
                                       strict=False)

        if is_multi_gpu:
            self.model = nn.DataParallel(self.model)

        self.model.to(self.device)

        self.best_validation_dsc = 0.0

        self.loader_train, self.loader_valid = self.data_loaders()

        self.params = [p for p in self.model.parameters() if p.requires_grad]

        self.optimizer = optim.Adam(self.params,
                                    lr=self.lr,
                                    weight_decay=0.0005)
        self.scheduler = lr_scheduler.LR_Scheduler_Head(
            'poly', self.lr, self.epochs, len(self.loader_train))

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return ``state_dict`` without the ``"module."`` key prefix.

        nn.DataParallel adds this prefix to every key when the wrapped model
        is saved. The input is returned unchanged when no key has the prefix.
        """
        prefix = "module."
        if not any(key.startswith(prefix) for key in state_dict):
            return state_dict
        return {(key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()}

    def datasets(self):
        """Build the training and validation ``Dataset`` objects.

        Padding is applied to the training subset only, when enabled.
        """
        train_datasets = Dataset(
            images_dir=self.images_path,
            subset="train",
            transform=get_transforms(train=True),
            is_resize=self.is_resize,
            image_short_side=self.image_short_side,
            is_padding=self.is_padding)

        valid_datasets = Dataset(
            images_dir=self.images_path,
            subset="validation",
            transform=get_transforms(train=False),
            is_resize=self.is_resize,
            image_short_side=self.image_short_side,
            is_padding=False)
        return train_datasets, valid_datasets

    def data_loaders(self):
        """Return ``(train_loader, valid_loader)``.

        The training loader shuffles and drops the last incomplete batch. The
        validation loader uses batch size 1.
        """
        dataset_train, dataset_valid = self.datasets()

        loader_train = DataLoader(
            dataset_train,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.workers,
        )
        loader_valid = DataLoader(
            dataset_valid,
            batch_size=1,
            drop_last=False,
            num_workers=self.workers,
        )

        return loader_train, loader_valid

    @staticmethod
    def dsc_per_volume(validation_pred, validation_true):
        """Return one ``dsc`` score per (prediction, target) pair."""
        assert len(validation_pred) == len(validation_true)
        return [dsc(np.array([y_pred]), np.array([y_true]))
                for y_pred, y_true in zip(validation_pred, validation_true)]

    @staticmethod
    def get_logger(filename, verbosity=1, name=None):
        """Return a logger that writes to ``filename`` (truncated) and stderr.

        ``verbosity`` 0/1/2 selects DEBUG/INFO/WARNING.
        """
        level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
        formatter = logging.Formatter(
            "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
        )
        logger = logging.getLogger(name)
        logger.setLevel(level_dict[verbosity])

        fh = logging.FileHandler(filename, "w")
        fh.setFormatter(formatter)
        logger.addHandler(fh)

        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        logger.addHandler(sh)

        return logger

    def train_one_epoch(self, epoch):
        """Run one training epoch, calling the LR scheduler before each batch."""
        self.model.train()
        loss_train = []
        for i, (x, y_true) in enumerate(self.loader_train):
            self.scheduler(self.optimizer, i, epoch, self.best_validation_dsc)
            x, y_true = x.to(self.device), y_true.to(self.device)

            y_pred = self.model(x)
            loss = self.dsc_loss(y_pred, y_true)
            loss_train.append(loss.item())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Progress is counted in global steps, not per-epoch batches.
            if self.step % 200 == 0:
                print('Epoch:[{}/{}]\t iter:[{}]\t loss={:.5f}\t '.format(
                    epoch, self.epochs, i, loss_train[-1]))

            self.step += 1

    def eval_model(self, patience):
        """Validate the model and save it when the mean DSC improves.

        Weights are written to ``<weights>/unet_xia_adam.pth``. ``patience``
        is accepted for API compatibility; early stopping is not implemented.
        """
        self.model.eval()
        loss_valid = []

        validation_pred = []
        validation_true = []

        for x, y_true in self.loader_valid:
            x, y_true = x.to(self.device), y_true.to(self.device)

            with torch.no_grad():
                y_pred = self.model(x)
                loss = self.dsc_loss(y_pred, y_true)

            # Debug preview: overwrite result.png with the 0/255 mask of the
            # first sample / first channel of this batch, as explicit 8-bit.
            mask = ((y_pred > 0.5) * 255).cpu().numpy()[0][0].astype(np.uint8)
            cv2.imwrite('result.png', mask)

            loss_valid.append(loss.item())

            # Collect per-sample arrays (first axis) for the DSC computation.
            validation_pred.extend(y_pred.detach().cpu().numpy())
            validation_true.extend(y_true.detach().cpu().numpy())

        mean_dsc = np.mean(
            self.dsc_per_volume(
                validation_pred,
                validation_true,
            ))
        if mean_dsc > self.best_validation_dsc:
            self.best_validation_dsc = mean_dsc
            torch.save(self.model.state_dict(),
                       os.path.join(self.weights, "unet_xia_adam.pth"))
            print("Best validation mean DSC: {:4f}".format(
                self.best_validation_dsc))

    def main(self):
        """Train for ``self.epochs`` epochs, validating after each, then save final weights."""
        for epoch in tqdm(range(self.epochs), total=self.epochs):
            self.train_one_epoch(epoch)
            self.eval_model(patience=10)

        torch.save(self.model.state_dict(),
                   os.path.join(self.weights, "unet_final.pth"))
Esempio n. 5
0
    LE_512 = cropImage(LE_img, IMG_SHAPE[0], IMG_SHAPE[1])
    # sample_le[n] collects the n-th tile of every crop, each upsampled 2x.
    sample_le = {}
    for le_512 in LE_512:
        tiles = crop_prepare(le_512, CROP_STEP, IMG_SIZE)
        for n, img in enumerate(tiles):
            # Spline order 3 (bicubic) resize to twice IMG_SIZE, keeping the
            # input value range instead of rescaling it.
            img = transform.resize(img, (IMG_SIZE * 2, IMG_SIZE * 2),
                                   preserve_range=True, order=3)
            sample_le.setdefault(n, []).append(img)

    SNR_model = UNet(n_channels=15, n_classes=15)
    print("{} paramerters in total".format(sum(x.numel() for x in SNR_model.parameters())))
    SNR_model.cuda(cuda)
    SNR_model.load_state_dict(torch.load(SNR_model_path))
    SNR_model.eval()

    SIM_UNET = UNet(n_channels=15, n_classes=1)
    print("{} paramerters in total".format(sum(x.numel() for x in SIM_UNET.parameters())))
    SIM_UNET.cuda(cuda)
    SIM_UNET.load_state_dict(torch.load(SIM_UNET_model_path))
    SIM_UNET.eval()

    SRRFDATASET = ReconsDataset(
        img_dict=sample_le,
        transform=ToTensor(),
        in_norm=LE_in_norm,
        img_type=".tif",
        in_size=256,
    )
Esempio n. 6
0
def _forward_segmentation(model, net_name, aux_layers, orig_img):
    """Return the main prediction of ``model`` for ``orig_img``.

    SETR variants built with auxiliary layers (``aux_layers is not None``)
    return ``(pred, aux)``; the auxiliary output is discarded. Every other
    network returns the prediction directly.
    """
    if net_name in ("SETR-PUP", "SETR-MLA") and aux_layers is not None:
        pred, _ = model(orig_img)
        return pred
    return model(orig_img)


def _segmentation_loss(pred, mask_img, ce_loss, dice_loss=None):
    """Return cross-entropy, or the mean of cross-entropy and Dice.

    ``mask_img`` is cast to long for the cross-entropy term. When
    ``dice_loss`` is None, only cross-entropy is used.
    """
    loss_ce = ce_loss(pred, mask_img.long())
    if dice_loss is None:
        return loss_ce
    return 0.5 * (loss_ce + dice_loss(pred, mask_img, softmax=True))


def train(cont=False):
    """Train the segmentation network named by the module-level ``net``.

    Relies on module-level settings: ``device``, ``net``, ``CLASS_NUM``,
    ``data_dir``, ``IMG_DIM``, ``batch_size``, ``epoch_num``,
    ``iteration_num``, ``lrate``, ``momentum``, ``wdecay``,
    ``fine_tune_ratio``, ``use_dice_loss``, ``print_freq``,
    ``tensorboard_freq``, ``early_stop_tolerance``, ``best_ckpt_src`` and
    ``ckpt_src``.

    Each epoch is followed by validation. When the validation loss improves,
    a checkpoint is written to ``ckpt_src``. Training stops after
    ``early_stop_tolerance`` epochs without improvement.

    Learning rate per iteration ``it``:
    ``lrate * (1 - it / iteration_num) ** 0.9`` (clamped at 0), multiplied
    by 0.1 from epoch ``int(epochs * fine_tune_ratio)`` on.

    Args:
        cont: resume from ``best_ckpt_src`` via
            ``load_ckpt_continue_training`` instead of starting fresh.

    Raises:
        ValueError: if ``net`` is not a supported network name.
    """
    logger = get_logger()
    logger.info("(1) Initiating Training ... ")
    logger.info("Training on device: {}".format(device))

    # init model
    aux_layers = None
    if net == "SETR-PUP":
        aux_layers, model = get_SETR_PUP()
    elif net == "SETR-MLA":
        aux_layers, model = get_SETR_MLA()
    elif net == "TransUNet-Base":
        model = get_TransUNet_base()
    elif net == "TransUNet-Large":
        model = get_TransUNet_large()
    elif net == "UNet":
        model = UNet(CLASS_NUM)
    else:
        # Previously this fell through and failed later on an unbound `model`.
        raise ValueError("Unsupported net: {}".format(net))

    # for tensorboard tracking
    writer = SummaryWriter()

    # prepare dataset
    cluster_model = get_clustering_model(logger)
    train_dataset = CityscapeDataset(img_dir=data_dir,
                                     img_dim=IMG_DIM,
                                     mode="train",
                                     cluster_model=cluster_model)
    valid_dataset = CityscapeDataset(img_dir=data_dir,
                                     img_dim=IMG_DIM,
                                     mode="val",
                                     cluster_model=cluster_model)
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              shuffle=False)

    logger.info("(2) Dataset Initiated. ")

    # optimizer
    epochs = epoch_num if epoch_num > 0 else iteration_num // len(
        train_loader) + 1
    optim = SGD(model.parameters(),
                lr=lrate,
                momentum=momentum,
                weight_decay=wdecay)
    # Fine-tuning drop: from this epoch on the LR is multiplied by lr_drop.
    # The per-iteration update below sets the LR absolutely, so the drop is
    # applied there. A MultiStepLR would be overwritten on the next step.
    fine_tune_epoch = int(epochs * fine_tune_ratio)
    lr_drop = 0.1

    cur_epoch = 0
    best_loss = float('inf')
    epochs_since_improvement = 0

    # for continue training
    if cont:
        model, optim, cur_epoch, best_loss = load_ckpt_continue_training(
            best_ckpt_src, model, optim, logger)
        logger.info("Current best loss: {0}".format(best_loss))
    else:
        model = nn.DataParallel(model)
        model = model.to(device)

    logger.info("(3) Model Initiated ... ")
    logger.info("Training model: {}".format(net) + ". Training Started.")

    # loss
    ce_loss = CrossEntropyLoss()
    dice_loss = DiceLoss(CLASS_NUM) if use_dice_loss else None

    # On resume, continue the poly schedule instead of restarting at lrate.
    iter_count = cur_epoch * len(train_loader)
    epoch_bar = tqdm.tqdm(total=epochs,
                          desc="Epoch",
                          initial=cur_epoch,
                          position=0,
                          leave=True)
    logger.info("Total epochs: {0}. Starting from epoch {1}.".format(
        epochs, cur_epoch + 1))

    for epoch in range(cur_epoch, epochs):
        # Training.
        model.train()
        trainLossMeter = LossMeter()
        train_batch_bar = tqdm.tqdm(total=len(train_loader),
                                    desc="TrainBatch",
                                    position=0,
                                    leave=True)
        lr_scale = lr_drop if epoch >= fine_tune_epoch else 1.0

        for batch_num, (orig_img, mask_img) in enumerate(train_loader):
            orig_img = orig_img.float().to(device)
            mask_img = mask_img.float().to(device)

            # Poly decay computed from the base LR. The previous code
            # multiplied the *current* LR by the factor every step, which
            # compounded the decay, and produced a complex number once
            # iter_count exceeded iteration_num.
            poly = max(0.0, 1.0 - iter_count / iteration_num) ** 0.9
            for param_group in optim.param_groups:
                param_group['lr'] = lrate * lr_scale * poly

            pred = _forward_segmentation(model, net, aux_layers, orig_img)
            loss = _segmentation_loss(pred, mask_img, ce_loss, dice_loss)

            # Backward propagation and weight update
            optim.zero_grad()
            loss.backward()
            optim.step()
            iter_count += 1

            trainLossMeter.update(loss.item())

            if (batch_num + 1) % print_freq == 0:
                status = ('Epoch: [{0}][{1}/{2}]\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t').format(
                              epoch + 1, batch_num + 1, len(train_loader),
                              loss=trainLossMeter)
                logger.info(status)

            if (batch_num + 1) % tensorboard_freq == 0:
                writer.add_scalar(
                    'Train_Loss_{0}'.format(tensorboard_freq),
                    trainLossMeter.avg,
                    epoch * (len(train_loader) / tensorboard_freq) +
                    (batch_num + 1) / tensorboard_freq)
            train_batch_bar.update(1)
        train_batch_bar.close()

        writer.add_scalar('Train_Loss_epoch', trainLossMeter.avg, epoch)

        # Validation.
        model.eval()
        validLossMeter = LossMeter()
        valid_batch_bar = tqdm.tqdm(total=len(valid_loader),
                                    desc="ValidBatch",
                                    position=0,
                                    leave=True)
        with torch.no_grad():
            for batch_num, (orig_img, mask_img) in enumerate(valid_loader):
                orig_img = orig_img.float().to(device)
                mask_img = mask_img.float().to(device)

                pred = _forward_segmentation(model, net, aux_layers, orig_img)
                loss = _segmentation_loss(pred, mask_img, ce_loss, dice_loss)
                validLossMeter.update(loss.item())

                # Per-batch status. This block used to sit outside the loop
                # and therefore ran only once per epoch.
                if (batch_num + 1) % print_freq == 0:
                    status = ('Validation: [{0}][{1}/{2}]\t'
                              'Loss {loss.val:.4f} ({loss.avg:.4f})\t').format(
                                  epoch + 1, batch_num + 1, len(valid_loader),
                                  loss=validLossMeter)
                    logger.info(status)

                if (batch_num + 1) % tensorboard_freq == 0:
                    writer.add_scalar(
                        'Valid_Loss_{0}'.format(tensorboard_freq),
                        validLossMeter.avg,
                        epoch * (len(valid_loader) / tensorboard_freq) +
                        (batch_num + 1) / tensorboard_freq)
                valid_batch_bar.update(1)
        valid_batch_bar.close()

        valid_loss = validLossMeter.avg
        writer.add_scalar('Valid_Loss_epoch', valid_loss, epoch)
        logger.info("Validation Loss of epoch [{0}/{1}]: {2}\n".format(
            epoch + 1, epochs, valid_loss))

        # save checkpoint / early stopping
        if valid_loss < best_loss:
            best_loss = valid_loss
            epochs_since_improvement = 0
            state = {
                'epoch': epoch,
                'loss': best_loss,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optim.state_dict(),
            }
            torch.save(state, ckpt_src)
            logger.info("Checkpoint updated.")
        else:
            epochs_since_improvement += 1
            logger.info("Epochs since last improvement: %d\n" %
                        (epochs_since_improvement, ))
            if epochs_since_improvement == early_stop_tolerance:
                break  # early stopping.
        epoch_bar.update(1)
    epoch_bar.close()
    writer.close()
Esempio n. 7
0
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid, save_image
from PIL import Image
import os
import numpy as np
from PIL import Image
from unet_model import UNet
import random

# Generate images for every file in input_dir with a pretrained UNet
# generator and save them to output_dir.
# NOTE(review): `torch` and `transforms` are used below, but the visible
# import block only imports torch.utils.data and torchvision.utils. Confirm
# they are imported elsewhere.
input_dir = "../Test_Data/"
output_dir = "../Generated_Test/"
model_path = "models/model_gen_latest"

generator = UNet(n_channels=3, n_classes=2)
generator.load_state_dict(torch.load(model_path))
generator.eval()

# random.sample over the full listing visits every file exactly once, in a
# random order.
for filename in random.sample(os.listdir(input_dir),
                              len(os.listdir(input_dir))):

    img = Image.open(os.path.join(input_dir, filename))
    # img = normalize(img)
    # Resize, convert to a tensor, and add a leading batch dimension of 1
    # via torch.stack.
    img = torch.stack([
        transforms.Compose(
            [transforms.Resize((75, 210)),
             transforms.ToTensor()])(img)
    ])

    # NOTE(review): runs without torch.no_grad(), so an autograd graph is
    # built for every image.
    output_img = generator(img)
    save_image(output_img, output_dir + "/" + filename)
    print(filename)
    # NOTE(review): everything from here to the end of the loop looks like a
    # separate snippet merged in. `transform`, `device`, `ckpt_path` and
    # `vutils` are not defined in the visible code, and a new network is
    # built and its checkpoint reloaded on every iteration. Confirm intent.
    img = transform(img)
    img = img.unsqueeze(0)


    def get_layer_param(model):
        """Return the total number of elements across ``model``'s parameters."""
        return sum([torch.numel(param) for param in model.parameters()])


    net = UNet(1, 3).to(device)
    print(net)
    print('parameters:', get_layer_param(net))

    print("Loading checkpoint...")
    checkpoint = torch.load(ckpt_path)
    net.load_state_dict(checkpoint['net_state_dict'])
    net.eval()

    print("Starting Test...")
    # -----------------------------------------------------------
    # Initial batch
    data_A = img.to(device)
    # -----------------------------------------------------------
    # Generate fake img:
    fake_B = net(data_A)
    # -----------------------------------------------------------
    # Output training stats
    # vutils.save_image(data_A, os.path.join(samples_path, 'result', '%s_data_A.jpg' % str(i).zfill(6)),
    #                   padding=2, nrow=2, normalize=True)
    # Saved to the current directory as <filename without its last 4
    # characters>_fake_B_leaky.jpg.
    vutils.save_image(fake_B, os.path.join('./', '%s_fake_B_leaky.jpg' % filename[0:-4]),
                      padding=0, nrow=1, normalize=True)
Esempio n. 9
0
class OnePredict(object):
    """Single-image segmentation predictor around a 3-in / 1-out UNet.

    ``params['model_path']`` is the weight file loaded (non-strictly) at
    construction. ``inference`` reads an image and resizes it so its sides
    are multiples of 32. It then runs the model, thresholds the output at
    ``self.threshold`` and draws the mask contours. Debug artefacts
    ``img.png``, ``mask.png`` and ``result2.png`` are written to the current
    working directory.
    """

    def __init__(self, params):
        self.params = params

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model_path = params['model_path']

        self.model = UNet(in_channels=3, out_channels=1)

        # Cut-off applied to the model output in post_process.
        self.threshold = 0.5

        self.resume()

        self.transform = get_transforms_3()

        self.is_resize = True
        self.image_short_side = 1024
        self.init_torch_tensor()
        self.model.eval()

    def init_torch_tensor(self):
        """Set ``self.device`` and the default tensor type (CUDA when available).

        NOTE: torch.set_default_tensor_type is process-wide; it affects every
        tensor created afterwards, not just this predictor.
        """
        torch.set_default_tensor_type('torch.FloatTensor')
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            self.device = torch.device('cpu')

    def resume(self):
        """Load weights from ``self.model_path`` (strict=False) onto ``self.device``."""
        self.model.load_state_dict(torch.load(self.model_path, map_location=self.device), strict=False)
        self.model.to(self.device)

    def resize_img(self, img):
        """Resize a PIL image so both sides are multiples of 32.

        The short side becomes ``self.image_short_side`` when
        ``self.is_resize`` is true. Otherwise it becomes the current short
        side rounded down to a multiple of 32 (at least 32, so tiny images do
        not collapse to 0 px). The long side is scaled by the same factor and
        rounded up to a multiple of 32.
        """
        width, height = img.size
        if self.is_resize:
            short_side = self.image_short_side
        else:
            short_side = max(32, min(width, height) // 32 * 32)
        if height < width:
            new_height = short_side
            new_width = int(math.ceil(new_height / height * width / 32) * 32)
        else:
            new_width = short_side
            new_height = int(math.ceil(new_width / width * height / 32) * 32)
        # LANCZOS is the filter ANTIALIAS aliased; ANTIALIAS was removed in
        # Pillow 10.
        return img.resize((new_width, new_height), Image.LANCZOS)

    def format_output(self):
        pass

    @staticmethod
    def pre_process(img):
        return img

    @staticmethod
    def pad_sample(img):
        """Zero-pad a PIL image to a square, splitting the padding across both sides."""
        a = img.size[0]
        b = img.size[1]
        if a == b:
            return img
        diff = (max(a, b) - min(a, b)) / 2.0
        if a > b:
            padding = (0, int(np.floor(diff)), 0, int(np.ceil(diff)))
        else:
            padding = (int(np.floor(diff)), 0, int(np.ceil(diff)), 0)

        img = ImageOps.expand(img, border=padding, fill=0)  # left, top, right, bottom

        assert img.size[0] == img.size[1]
        return img

    def post_process(self, preds, img):
        """Threshold ``preds``, save the mask and draw its contours on ``img``.

        ``img`` is the RGB image fed to the model. The contours are drawn in
        red on a BGR copy written to ``result2.png``; the mask is written to
        ``mask.png``. Returns an empty list (box extraction is not
        implemented).
        """
        # First sample / first channel as an 8-bit 0/255 mask.
        mask = ((preds > self.threshold) * 255).cpu().numpy()[0][0].astype(np.uint8)
        cv2.imwrite('mask.png', mask)

        # findContours returns 3 values in OpenCV 3 and 2 in OpenCV 4;
        # [-2] selects the contours in both.
        contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

        # OpenCV draws and writes in BGR order.
        canvas = cv2.cvtColor(np.array(img, np.uint8), cv2.COLOR_RGB2BGR)
        cv2.drawContours(canvas, contours, -1, (0, 0, 255), 1)
        cv2.imwrite('result2.png', canvas)

        return []

    @staticmethod
    def demo_visualize():
        pass

    def inference(self, img_path, is_visualize=True, is_format_output=False):
        """Segment the image at ``img_path`` and return ``post_process``'s boxes.

        ``is_visualize`` and ``is_format_output`` are accepted for API
        compatibility but are not used.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``img_path``.
        """
        # The second argument of cv2.imread is a read flag. Passing
        # cv2.COLOR_BGR2RGB there never converted the image, so read as BGR
        # and convert explicitly.
        bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if bgr is None:
            raise FileNotFoundError('Cannot read image: {}'.format(img_path))
        img = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)).convert("RGB")
        img = self.resize_img(img)
        ori_img = img
        img.save('img.png')
        print('111', np.array(img))
        img = self.transform(img)
        print('222', np.array(img))
        img = img.unsqueeze(0).to(self.device)

        with torch.no_grad():
            s1 = time.time()
            preds = self.model(img)
            print(preds)
            s2 = time.time()
            print(s2 - s1)
            boxes = self.post_process(preds, ori_img)
        return boxes
Esempio n. 10
0
class UNetObjPrior(nn.Module):
    """
    Wrapper around UNet that takes object priors (gaussians) and images
    as input.

    The image and the object prior are concatenated along dim 1 before being
    passed to the inner UNet. ``params`` must provide ``cuda`` and, for
    training, ``momentum``, ``lr``, ``weight_decay``, ``batch_size``,
    ``out_dir`` and ``num_epochs``.
    """
    def __init__(self, params, depth=5):
        super(UNetObjPrior, self).__init__()
        self.in_channels = 4
        self.model = UNet(1, self.in_channels, depth, cuda=params['cuda'])
        self.params = params
        self.device = torch.device('cuda' if params['cuda'] else 'cpu')

    def forward(self, im, obj_prior):
        x = torch.cat((im, obj_prior), dim=1)
        return self.model(x)

    def train(self, dataloader_train=True, dataloader_val=None):
        """Run the training loop, or set train/eval mode like ``nn.Module.train``.

        ``nn.Module.eval()`` and a parent module's ``train(mode)`` call
        ``self.train(<bool>)``. The original override required two
        dataloaders, so those calls raised TypeError. A bool (or no
        argument) is forwarded to ``nn.Module.train`` and returns ``self``.
        Otherwise the arguments are the training and validation dataloaders.
        """
        if isinstance(dataloader_train, bool):
            return super(UNetObjPrior, self).train(dataloader_train)
        return self._fit(dataloader_train, dataloader_val)

    def _fit(self, dataloader_train, dataloader_val):
        """Train with SGD + MSE, validating, logging and checkpointing each epoch."""
        best_loss = float("inf")

        dataloader_train.mode = 'train'
        dataloader_val.mode = 'val'
        dataloaders = {'train': dataloader_train, 'val': dataloader_val}

        optimizer = optim.SGD(self.model.parameters(),
                              momentum=self.params['momentum'],
                              lr=self.params['lr'],
                              weight_decay=self.params['weight_decay'])

        loggers = {
            'train': LossLogger('train', self.params['batch_size'],
                                len(dataloader_train), self.params['out_dir']),
            'val': LossLogger('val', self.params['batch_size'],
                              len(dataloader_val), self.params['out_dir']),
        }

        self.criterion = nn.MSELoss()

        for epoch in range(self.params['num_epochs']):
            print('Epoch {}/{}'.format(epoch, self.params['num_epochs'] - 1))
            print('-' * 10)

            # Each epoch has a training and a validation phase.
            for phase in ['train', 'val']:
                is_train = phase == 'train'
                if is_train:
                    self.model.train()
                else:
                    self.model.eval()

                for samp, data in enumerate(dataloaders[phase], start=1):
                    optimizer.zero_grad()
                    # Track history only while training.
                    with torch.set_grad_enabled(is_train):
                        out = self.forward(data.image, data.obj_prior)
                        loss = self.criterion(out, data.truth)
                        if is_train:
                            loss.backward()
                            optimizer.step()

                    loggers[phase].update(epoch, samp, loss.item())

                loggers[phase].print_epoch(epoch)

                if is_train:
                    self._save_preview(dataloaders['val'], epoch)
                else:
                    val_loss = loggers['val'].get_loss(epoch)
                    # A tie with the best loss so far also counts as best.
                    is_best = val_loss <= best_loss
                    if val_loss < best_loss:
                        best_loss = val_loss

                loggers[phase].save('log_{}.csv'.format(phase))

                if not is_train:
                    path = os.path.join(self.params['out_dir'],
                                        'checkpoint.pth.tar')
                    utls.save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'state_dict': self.model.state_dict(),
                            'best_loss': best_loss,
                            'optimizer': optimizer.state_dict()
                        },
                        is_best,
                        path=path)

    def _save_preview(self, dataloader, epoch):
        """Save image / prediction / truth for one sample under ``<out_dir>/previews``."""
        path = os.path.join(self.params['out_dir'], 'previews',
                            'epoch_{:04d}.jpg'.format(epoch))
        data = dataloader.sample_uniform()
        # Eval mode and no autograd. The preview must not build a graph or
        # update normalisation statistics (e.g. BatchNorm, if present) with a
        # validation sample.
        self.model.eval()
        with torch.no_grad():
            pred = self.forward(data.image, data.obj_prior)
        utls.save_tensors(data.image[0], pred[0, ...], data.truth[0], path)

    def load_checkpoint(self, path, device='gpu'):
        """Load ``checkpoint['state_dict']`` from ``path`` into the inner model.

        With any ``device`` other than ``'gpu'``, tensors are kept on the
        storage they are deserialised to (the map_location identity lambda).
        """
        if device != 'gpu':
            checkpoint = torch.load(path,
                                    map_location=lambda storage, loc: storage)
        else:
            checkpoint = torch.load(path)

        self.model.load_state_dict(checkpoint['state_dict'])
def _drive_build_net(model_name):
    """Instantiate the network selected by the config ``model`` key.

    ``'UNet'`` and ``'UNet_cat'`` select those classes; any other value falls
    back to ``UNet_level4_our``. Every variant is built with one input channel
    and two output classes.
    """
    if model_name == 'UNet':
        return UNet(n_channels=1, n_classes=2)
    if model_name == 'UNet_cat':
        return UNet_cat(n_channels=1, n_classes=2)
    return UNet_level4_our(n_channels=1, n_classes=2)


def _drive_predict_patches(net, patches_imgs_test, patch_height, patch_width, device):
    """Run ``net`` over every test patch and return per-pixel class probabilities.

    Returns:
        Array of shape ``(n_patches, patch_height * patch_width, 2)`` holding
        the softmax over the two output channels for every pixel of every patch.
    """
    test_data = data.TensorDataset(torch.tensor(patches_imgs_test),
                                   torch.zeros(patches_imgs_test.shape[0]))
    test_loader = data.DataLoader(test_data, batch_size=1,
                                  pin_memory=(device.type == 'cuda'), shuffle=False)
    predictions = np.empty((patches_imgs_test.shape[0], patch_height * patch_width, 2))
    offset = 0
    with torch.no_grad():  # inference only: no autograd graph needed
        for images, _ in test_loader:
            out = net(images.float().to(device))
            # (B, C, H, W) -> (B, H, W, C) so the softmax runs over the class axis
            probs = F.softmax(out.permute(0, 2, 3, 1), dim=-1)
            # move to host memory before writing into the NumPy buffer
            probs = probs.reshape(-1, patch_height * patch_width, 2).cpu().numpy()
            predictions[offset:offset + probs.shape[0]] = probs
            offset += probs.shape[0]
    return predictions


def _drive_save_curve(x, y, auc_value, title, xlabel, ylabel, out_path):
    """Plot one curve with its AUC in the legend, save it to ``out_path`` and close it."""
    fig = plt.figure()
    plt.plot(x, y, '-', label='Area Under the Curve (AUC = %0.4f)' % auc_value)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend(loc="lower right")
    plt.savefig(out_path)
    # Release the figure; repeated test() calls would otherwise accumulate them.
    plt.close(fig)


def _drive_confusion_metrics(confusion):
    """Return ``(accuracy, specificity, sensitivity, precision)`` from a 2x2 matrix.

    ``confusion`` is indexed ``[true_label, predicted_label]`` (scikit-learn's
    ``confusion_matrix`` layout) with label 1 as the positive class. A ratio
    whose denominator is zero is reported as 0.
    """
    tn, fp = confusion[0, 0], confusion[0, 1]
    fn, tp = confusion[1, 0], confusion[1, 1]

    def _ratio(num, den):
        return float(num) / float(den) if float(den) != 0 else 0

    accuracy = _ratio(tn + tp, np.sum(confusion))
    specificity = _ratio(tn, tn + fp)
    sensitivity = _ratio(tp, tp + fn)
    precision = _ratio(tp, tp + fp)
    return accuracy, specificity, sensitivity, precision


def _drive_thin_vessel_scores(pred_imgs, gtruth_masks, se_size, thresh=0.5):
    """Per-image thin-vessel recall and ROC AUC.

    For each image, ``thick = opening(gt, square(se_size))`` is taken as the
    thick-vessel part of the ground truth and ``gt - thick`` as the thin part.
    The AUC is computed against the prediction zeroed over ``thick``; the
    recall is computed by ``thin_recall`` against the unmasked prediction.

    Returns:
        ``(recalls, aucs)``: two lists with one value per image, rounded to 4 places.
    """
    recalls = []
    aucs = []
    for j in range(pred_imgs.shape[0]):
        gt = gtruth_masks[j, 0, :, :]
        pred = pred_imgs[j, 0, :, :]
        thick = opening(gt, square(se_size))
        thin_gt = gt - thick
        # Copy: the slice is a view, and masking it in place would silently
        # overwrite the caller's prediction array (and ``pred`` below).
        thin_pred = pred.copy()
        thin_pred[thick == 1] = 0
        recalls.append(round(thin_recall(thin_gt, pred, thresh=thresh), 4))
        aucs.append(round(roc_auc_score(thin_gt.flatten(), thin_pred.flatten()), 4))
    return recalls, aucs


def test(experiment_path, test_epoch):
    """Evaluate a trained DRIVE segmentation checkpoint on the test images.

    Reads ``./<experiment_path>/<experiment_path>_config.txt``, predicts the
    test patches with the checkpoint ``<path_experiment>DRIVE_<test_epoch>epoch.pth``
    (``path_experiment`` is ``./<config experiment name>/``), and reassembles
    full images. Prediction/ground-truth visualizations, the predictions
    ``.npy`` file and the ROC / precision-recall plots are written to
    ``path_experiment``. A metrics summary is appended to
    ``test_performances_all_epochs.txt`` there.

    Runs on CUDA when available, otherwise on the CPU.

    Args:
        experiment_path: experiment directory name; also the config file prefix.
        test_epoch: epoch number identifying the checkpoint to evaluate.
    """
    # ========= CONFIG FILE TO READ FROM =======
    config = configparser.RawConfigParser()
    config.read('./' + experiment_path + '/' + experiment_path + '_config.txt')
    path_data = config.get('data paths', 'path_local')
    model_name = config.get('training settings', 'model')
    # original test images (full-image size; originals in average mode)
    DRIVE_test_imgs_original = path_data + config.get('data paths', 'test_imgs_original')
    test_imgs_orig = load_hdf5(DRIVE_test_imgs_original)
    full_img_height = test_imgs_orig.shape[2]
    full_img_width = test_imgs_orig.shape[3]
    # border (field-of-view) masks
    DRIVE_test_border_masks = path_data + config.get('data paths', 'test_border_masks')
    test_border_masks = load_hdf5(DRIVE_test_border_masks)
    # patch size and the stride used when patches overlap (average mode)
    patch_height = int(config.get('data attributes', 'patch_height'))
    patch_width = int(config.get('data attributes', 'patch_width'))
    stride_height = int(config.get('testing settings', 'stride_height'))
    stride_width = int(config.get('testing settings', 'stride_width'))
    assert (stride_height < patch_height and stride_width < patch_width)
    name_experiment = config.get('experiment name', 'name')
    path_experiment = './' + name_experiment + '/'
    Imgs_to_test = int(config.get('testing settings', 'full_images_to_test'))
    N_visual = int(config.get('testing settings', 'N_group_visual'))
    average_mode = config.getboolean('testing settings', 'average_mode')
    DRIVE_test_groundTruth = path_data + config.get('data paths', 'test_groundTruth')

    # ============ Load the data and divide it into patches
    new_height = new_width = masks_test = patches_masks_test = None
    if average_mode:
        patches_imgs_test, new_height, new_width, masks_test = get_data_testing_overlap(
            DRIVE_test_imgs_original=DRIVE_test_imgs_original,
            DRIVE_test_groudTruth=DRIVE_test_groundTruth,
            Imgs_to_test=Imgs_to_test,
            patch_height=patch_height,
            patch_width=patch_width,
            stride_height=stride_height,
            stride_width=stride_width)
    else:
        patches_imgs_test, patches_masks_test = get_data_testing_test(
            DRIVE_test_imgs_original=DRIVE_test_imgs_original,
            DRIVE_test_groudTruth=DRIVE_test_groundTruth,
            Imgs_to_test=Imgs_to_test,
            patch_height=patch_height,
            patch_width=patch_width)

    # ================ Run the prediction of the patches ==================================
    net = _drive_build_net(model_name)
    trained_model = path_experiment + 'DRIVE_' + str(test_epoch) + 'epoch.pth'
    print(trained_model)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.load_state_dict(torch.load(trained_model, map_location=device))
    net.eval()
    print('Finished loading model :' + trained_model)
    net = net.to(device)
    cudnn.benchmark = True
    predictions_out = _drive_predict_patches(net, patches_imgs_test,
                                             patch_height, patch_width, device)

    # ===== Convert the prediction arrays into corresponding images
    pred_patches_out = pred_to_imgs(predictions_out, patch_height, patch_width, "original")

    # ========== Reassemble and visualize the predicted images ====================
    if average_mode:
        pred_imgs_out = recompone_overlap(pred_patches_out, new_height, new_width,
                                          stride_height, stride_width)
        orig_imgs = my_PreProc(test_imgs_orig[0:pred_imgs_out.shape[0], :, :, :])  # originals
        gtruth_masks = masks_test  # ground truth masks
    else:
        # NOTE(review): 10 x 9 is presumably the patch grid per DRIVE image
        # for the configured patch size — TODO confirm / derive from config.
        pred_imgs_out = recompone(pred_patches_out, 10, 9)  # predictions
        orig_imgs = recompone(patches_imgs_test, 10, 9)  # originals
        gtruth_masks = recompone(patches_masks_test, 10, 9)  # masks

    # apply the border masks to the predictions (everything outside the FOV)
    kill_border(pred_imgs_out, test_border_masks)
    # back to the original image dimensions
    orig_imgs = orig_imgs[:, :, 0:full_img_height, 0:full_img_width]
    pred_imgs_out = pred_imgs_out[:, :, 0:full_img_height, 0:full_img_width]
    gtruth_masks = gtruth_masks[:, :, 0:full_img_height, 0:full_img_width]

    print("Orig imgs shape: " + str(orig_imgs.shape))
    print("pred imgs shape: " + str(pred_imgs_out.shape))
    print("Gtruth imgs shape: " + str(gtruth_masks.shape))
    np.save(path_experiment + 'pred_img_' + str(test_epoch) + "_epoch" + '.npy', pred_imgs_out)
    if average_mode:
        visualize(group_images(pred_imgs_out, N_visual),
                  path_experiment + "all_predictions_" + str(test_epoch) + "thresh_epoch")
    else:
        visualize(group_images(pred_imgs_out, N_visual),
                  path_experiment + "all_predictions_" + str(test_epoch) + "epoch_no_average")
    visualize(group_images(gtruth_masks, N_visual), path_experiment + "all_groundTruths")

    # ====== Evaluate the results
    print("\n\n========  Evaluate the results =======================")
    # predictions only inside the FOV
    y_scores, y_true = pred_only_FOV(pred_imgs_out, gtruth_masks, test_border_masks)

    # Area under the ROC curve
    fpr, tpr, _ = roc_curve(y_true, y_scores)
    AUC_ROC = roc_auc_score(y_true, y_scores)
    print("\nArea under the ROC curve: " + str(AUC_ROC))
    _drive_save_curve(fpr, tpr, AUC_ROC, 'ROC curve',
                      "FPR (False Positive Rate)", "TPR (True Positive Rate)",
                      path_experiment + "ROC.png")

    # Precision-recall curve; reversed so recall is increasing (non-negative AUC)
    pr_precision, pr_recall, _ = precision_recall_curve(y_true, y_scores)
    pr_precision = pr_precision[::-1]
    pr_recall = pr_recall[::-1]
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    AUC_prec_rec = trapezoid(pr_precision, pr_recall)
    print("\nArea under Precision-Recall curve: " + str(AUC_prec_rec))
    _drive_save_curve(pr_recall, pr_precision, AUC_prec_rec, 'Precision - Recall curve',
                      "Recall", "Precision",
                      path_experiment + "Precision_recall.png")

    # Confusion matrix at a fixed threshold
    threshold_confusion = 0.5
    print("\nConfusion matrix:  Custom threshold (for positive) of " + str(threshold_confusion))
    y_pred = (np.ravel(y_scores) >= threshold_confusion).astype(np.float64)
    confusion = confusion_matrix(y_true, y_pred)
    print(confusion)
    accuracy, specificity, sensitivity, precision = _drive_confusion_metrics(confusion)
    print("Global Accuracy: " + str(accuracy))
    print("Specificity: " + str(specificity))
    print("Sensitivity: " + str(sensitivity))
    print("Precision: " + str(precision))

    # Jaccard similarity index
    # NOTE(review): jaccard_similarity_score was removed in scikit-learn 0.23;
    # for 1-D binary labels it returned plain accuracy. jaccard_score (IoU of
    # class 1) is the replacement, but switching changes the logged values.
    jaccard_index = jaccard_similarity_score(y_true, y_pred, normalize=True)
    print("\nJaccard similarity score: " + str(jaccard_index))

    # F1 score
    F1_score = f1_score(y_true, y_pred, labels=None, average='binary', sample_weight=None)
    print("\nF1 score (F-measure): " + str(F1_score))

    # Thin-vessel evaluation with 3x3 and 2x2 structuring elements
    thin_3pixel_recall_indivi, thin_3pixel_auc_roc = _drive_thin_vessel_scores(
        pred_imgs_out, gtruth_masks, 3)
    thin_2pixel_recall_indivi, thin_2pixel_auc_roc = _drive_thin_vessel_scores(
        pred_imgs_out, gtruth_masks, 2)

    # Save the results
    with open(path_experiment + 'test_performances_all_epochs.txt', mode='a') as f:
        f.write("\n\n" + path_experiment + " test epoch:" + str(test_epoch)
                + '\naverage mode is:' + str(average_mode)
                + "\nArea under the ROC curve: %.4f" % (AUC_ROC)
                + "\nArea under Precision-Recall curve: %.4f" % (AUC_prec_rec)
                + "\nJaccard similarity score: %.4f" % (jaccard_index)
                + "\nF1 score (F-measure): %.4f" % (F1_score)
                + "\nConfusion matrix:"
                + str(confusion)
                + "\nACCURACY: %.4f" % (accuracy)
                + "\nSENSITIVITY: %.4f" % (sensitivity)
                + "\nSPECIFICITY: %.4f" % (specificity)
                + "\nPRECISION: %.4f" % (precision)
                + "\nthin 2vessels recall indivi:\n" + str(thin_2pixel_recall_indivi)
                + "\nthin 2vessels recall mean:%.4f" % (np.mean(thin_2pixel_recall_indivi))
                + "\nthin 2vessels auc indivi:\n" + str(thin_2pixel_auc_roc)
                + "\nthin 2vessels auc score mean:%.4f" % (np.mean(thin_2pixel_auc_roc))
                + "\nthin 3vessels recall indivi:\n" + str(thin_3pixel_recall_indivi)
                + "\nthin 3vessels recall mean:%.4f" % (np.mean(thin_3pixel_recall_indivi))
                + "\nthin 3vessels auc indivi:\n" + str(thin_3pixel_auc_roc)
                + "\nthin 3vessels auc score mean:%.4f" % (np.mean(thin_3pixel_auc_roc))
                )
Esempio n. 12
0
        tiles = crop_prepare(le_512, CROP_STEP, IMG_SIZE)
        for n, img in enumerate(tiles):
            if n not in sample_le:
                sample_le[n] = []
            img = transform.resize(img, (IMG_SIZE * 2, IMG_SIZE * 2),
                                   preserve_range=True,
                                   order=3)
            sample_le[n].append(img)

    SC_UNET = UNet(n_channels=15, n_classes=1)
    print("{} paramerters in total".format(
        sum(x.numel() for x in SC_UNET.parameters())))
    SC_UNET.cuda(cuda)
    SC_UNET.load_state_dict(torch.load(model_path))
    # SC_UNET.load_state_dict(torch.load(os.path.join(dir_path,"model","HE_HER_mito","HE_X2_HER_0825.pkl")))
    SC_UNET.eval()

    SRRFDATASET = ReconsDataset(img_dict=sample_he,
                                transform=ToTensor(),
                                in_norm=LE_in_norm,
                                img_type=".tif",
                                in_size=256)
    test_dataloader = torch.utils.data.DataLoader(
        SRRFDATASET, batch_size=1, shuffle=False,
        pin_memory=True)  # better than for loop
    result = np.zeros((256, 256, len(SRRFDATASET)))
    for batch_idx, items in enumerate(test_dataloader):
        image = items['image_in']
        image_idx = items['image_name']

        image = np.swapaxes(image, 1, 3)